├── README.md
├── acdc_data_preparation.py
├── libs
│   ├── configs
│   │   └── config_acdc.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── acdc_dataset.py
│   │   ├── augment.py
│   │   ├── collate_batch.py
│   │   ├── gen_acdcjson.py
│   │   ├── generate_FilesList.py
│   │   ├── generate_personList.py
│   │   └── joint_augment.py
│   ├── losses
│   │   ├── ce_ieLoss.py
│   │   ├── create_losses.py
│   │   ├── df_loss.py
│   │   ├── mag_angle_loss.py
│   │   └── surface_loss.py
│   └── network
│       ├── __init__.py
│       ├── train_functions.py
│       ├── unet.py
│       └── unet_df.py
├── pipeline.png
├── tools
│   ├── MultiShellScripts.py
│   ├── _init_paths.py
│   ├── test_acdc_leadboard.py
│   ├── test_df_vis.py
│   ├── test_utils.py
│   ├── train.py
│   └── train_utils
│       ├── __pycache__
│       │   ├── _init_paths.cpython-36.pyc
│       │   ├── _init_paths.cpython-37.pyc
│       │   ├── base_dir.cpython-36.pyc
│       │   ├── test_3D.cpython-36.pyc
│       │   ├── test_utils.cpython-36.pyc
│       │   ├── train_utils.cpython-36.pyc
│       │   └── train_utils.cpython-37.pyc
│       └── train_utils.py
└── utils
    ├── comm.py
    ├── dice3D.py
    ├── direct_field
    │   ├── __pycache__
    │   │   ├── df_cardia.cpython-36.pyc
    │   │   ├── df_cardia.cpython-37.pyc
    │   │   ├── utils_df.cpython-36.pyc
    │   │   └── utils_df.cpython-37.pyc
    │   ├── df_cardia.py
    │   └── utils_df.py
    ├── image_list.py
    ├── init_net.py
    ├── metrics.py
    ├── utils_loss.py
    └── vis_utils.py

/README.md:
--------------------------------------------------------------------------------
# DirectionalFeature

This repository contains the code for the paper "**Learning Directional Feature Maps for Cardiac MRI Segmentation**" (MICCAI 2020), https://arxiv.org/abs/2007.11349

![](./pipeline.png)
## Citation

Please cite the following work in your publications if it helps your research:

```
@inproceedings{cheng2020learning,
  title={Learning directional feature maps for cardiac mri segmentation},
  author={Cheng, Feng and Chen, Cheng and Wang, Yukang and Shi, Heshui and Cao, Yukun and Tu, Dandan and Zhang, Changzheng and Xu, Yongchao},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={108--117},
  year={2020},
  organization={Springer}
}
```
## Usage

### ACDC Data Preparation
1. Register and download the ACDC-2017 dataset from https://www.creatis.insa-lyon.fr/Challenge/acdc/index.html
2. Create a folder named **ACDC_DataSet** outside the project and copy the dataset into it.
3. From the project folder, open acdc_data_preparation.py.
4. In the file, point the path to the ACDC training dataset: ```complete_data_path = '../../ACDC_DataSet/training'```.
5. Run the script acdc_data_preparation.py.
6. The processed training data is written to a folder named *processed_acdc_dataset* outside the project.
7. Run ./libs/datasets/gen_acdcjson.py to generate the data lists for ACDC training and validation.

### Training
```
cd ./tools
python -m torch.distributed.launch --nproc_per_node 4 --master_port $RANDOM train.py --batch_size 12 --mgpus 0,1,2,3 --output_dir logs/... --train_with_eval
```
--------------------------------------------------------------------------------
/acdc_data_preparation.py:
--------------------------------------------------------------------------------
import numpy as np
import os, sys, shutil, time, re
import h5py
import skimage.morphology as morph
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation
import pickle
# For ROI extraction
import skimage.transform
from scipy.fftpack import fftn, ifftn
from skimage.feature import peak_local_max, canny
from skimage.transform import hough_circle
# Nifti processing
import nibabel as nib
from collections import OrderedDict
import errno
np.random.seed(42)

# Helper functions
## Heart Metrics
def heart_metrics(seg_3Dmap, voxel_size, classes=[3, 1, 2]):
    """
    Compute the volume of each class
    """
    # Loop over each class of the input image
    volumes = []
    for c in classes:
        # Copy the gt image so as not to alter the input
        seg_3Dmap_copy = np.copy(seg_3Dmap)
        seg_3Dmap_copy[seg_3Dmap_copy != c] = 0

        # Clip the values to compute the volumes
        seg_3Dmap_copy = np.clip(seg_3Dmap_copy, 0, 1)

        # Compute volume
        volume = seg_3Dmap_copy.sum() * np.prod(voxel_size) / 1000.
        volumes += [volume]
    return volumes

def ejection_fraction(ed_vol, es_vol):
    """
    Calculate ejection fraction
    """
    stroke_vol = ed_vol - es_vol
    return (float(stroke_vol) / float(ed_vol)) * 100  # builtin float: np.float was removed in NumPy >= 1.24

def myocardialmass(myocardvol):
    """
    Specific gravity of heart muscle (1.05 g/ml)
    """
    return myocardvol * 1.05

def imshow(*args, **kwargs):
    """ Handy function to show multiple plots in one row, possibly with different cmaps and titles
    Usage:
    imshow(img1, title="myPlot")
    imshow(img1,img2, title=['title1','title2'])
    imshow(img1,img2, cmap='hot')
    imshow(img1,img2,cmap=['gray','Blues']) """
    cmap = kwargs.get('cmap', 'gray')
    title = kwargs.get('title', '')
    if len(args) == 0:
        raise ValueError("No images given to imshow")
    elif len(args) == 1:
        plt.title(title)
        plt.imshow(args[0], interpolation='none')
    else:
        n = len(args)
        if type(cmap) == str:
            cmap = [cmap]*n
        if type(title) == str:
            title = [title]*n
        plt.figure(figsize=(n*5, 10))
        for i in range(n):
            plt.subplot(1, n, i+1)
            plt.title(title[i])
            plt.imshow(args[i], cmap[i])
    plt.show()

def plot_roi(data4D, roi_center, roi_radii):
    """
    Do the animation of the full heart volume
    """
    x_roi_center, y_roi_center = roi_center[0], roi_center[1]
    x_roi_radius, y_roi_radius = roi_radii[0], roi_radii[1]
    print('nslices', data4D.shape[2])

    zslices = data4D.shape[2]
    tframes = data4D.shape[3]

    slice_cnt = 0
    for slice in [data4D[:, :, z, :] for z in range(zslices)]:
        outdata = np.swapaxes(np.swapaxes(slice[:, :, :], 0, 2), 1, 2)
        roi_mask = np.zeros_like(outdata[0])
        roi_mask[x_roi_center - x_roi_radius:x_roi_center + x_roi_radius,
                 y_roi_center - y_roi_radius:y_roi_center + y_roi_radius] = 1

        # dim everything inside the ROI to make the box visible
        outdata[:, roi_mask > 0.5] = 0.8 * outdata[:, roi_mask > 0.5]

        fig = plt.figure(1)
        fig.canvas.set_window_title('slice_No' + str(slice_cnt))
        slice_cnt += 1

        def init_out():
            im.set_data(outdata[0])

        def animate_out(i):
            im.set_data(outdata[i])
            return im

        im = fig.gca().imshow(outdata[0], cmap='gray')
        anim = animation.FuncAnimation(fig, animate_out, init_func=init_out, frames=tframes, interval=50)
        anim.save('Cine_MRI_SAX_%d.mp4' % slice_cnt, fps=50, extra_args=['-vcodec', 'libx264'])
        plt.show()

def plot_4D(data4D):
    """
    Do the animation of the full heart volume
    """
    print('nslices', data4D.shape[2])
    zslices = data4D.shape[2]
    tframes = data4D.shape[3]

    slice_cnt = 0
    for slice in [data4D[:, :, z, :] for z in range(zslices)]:
        outdata = np.swapaxes(np.swapaxes(slice[:, :, :], 0, 2), 1, 2)
        fig = plt.figure(1)
        fig.canvas.set_window_title('slice_No' + str(slice_cnt))
        slice_cnt += 1

        def init_out():
            im.set_data(outdata[0])

        def animate_out(i):
            im.set_data(outdata[i])
            return im

        im = fig.gca().imshow(outdata[0], cmap='gray')
        anim = animation.FuncAnimation(fig, animate_out, init_func=init_out, frames=tframes, interval=50)
        plt.show()


def multilabel_split(image_tensor):
    """
    image_tensor : Batch * H * W
    Split multilabel images and return a stack of images
    Returns: Tensor of shape: Batch * H * W * n_class (4D tensor)
    # TODO: Be careful when using this code: labels need to be
    defined explicitly beforehand, as this code does not handle
    missing labels
    So far, this function is okay as it considers the full volume for
    finding the unique labels
    """
    labels = np.unique(image_tensor)
    batch_size = image_tensor.shape[0]
    out_shape = image_tensor.shape + (len(labels),)
    image_tensor_4D = np.zeros(out_shape, dtype='uint8')
    for i in range(batch_size):  # range, not the Python-2-only xrange
        cnt = 0
        shape = image_tensor.shape[1:3] + (len(labels),)
        temp = np.ones(shape, dtype='uint8')
        for label in labels:
            temp[..., cnt] = np.where(image_tensor[i] == label, temp[..., cnt], 0)
            cnt += 1
        image_tensor_4D[i] = temp
    return image_tensor_4D

def save_data(data, filename, out_path):
    out_filename = os.path.join(out_path, filename)
    with open(out_filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print('saved to %s' % out_filename)

def load_pkl(path):
    with open(path, 'rb') as f:
        obj = pickle.load(f)
    return obj

### Stratified Sampling of data

# Refer:
# http://www.echopedia.org/wiki/Left_Ventricular_Dimensions
# https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html
# https://en.wikipedia.org/wiki/Body_surface_area
# 30 normal subjects - NOR
NORMAL = 'NOR'
# 30 patients with previous myocardial infarction
# (ejection fraction of the left ventricle lower than 40% and several myocardial segments with abnormal contraction) - MINF
MINF = 'MINF'
# 30 patients with dilated cardiomyopathy
# (diastolic left ventricular volume >100 mL/m2 and an ejection fraction of the left ventricle lower than 40%) - DCM
DCM = 'DCM'
# 30 patients with hypertrophic cardiomyopathy
# (left ventricular cardiac mass higher than 110 g/m2,
# several myocardial segments with a thickness higher than 15 mm in diastole and a normal ejection fraction) - HCM
HCM = 'HCM'
# 30 patients with abnormal right ventricle (volume of the right ventricular
# cavity higher than 110 mL/m2 or ejection fraction of the right ventricle lower than 40%) - RV
RV = 'RV'

def copy(src, dest):
    """
    Copy function
    """
    try:
        shutil.copytree(src, dest, ignore=shutil.ignore_patterns())
    except OSError as e:
        # If the error was caused because the source wasn't a directory
        if e.errno == errno.ENOTDIR:
            shutil.copy(src, dest)
        else:
            print('Directory not copied. Error: %s' % e)

def read_patient_cfg(path):
    """
    Reads patient data from the Info.cfg file and returns a dictionary
    """
    patient_info = {}
    with open(os.path.join(path, 'Info.cfg')) as f_in:
        for line in f_in:
            l = line.rstrip().split(": ")
            patient_info[l[0]] = l[1]
    return patient_info

def group_patient_cases(src_path, out_path, force=False):
    """ Group the patient data according to cardiac pathology """

    cases = sorted(next(os.walk(src_path))[1])
    dest_path = os.path.join(out_path, 'Patient_Groups')
    if force and os.path.exists(dest_path):  # guard: rmtree on a missing path raises
        shutil.rmtree(dest_path)
    if os.path.exists(dest_path):
        return dest_path

    os.makedirs(dest_path)
    os.mkdir(os.path.join(dest_path, NORMAL))
    os.mkdir(os.path.join(dest_path, MINF))
    os.mkdir(os.path.join(dest_path, DCM))
    os.mkdir(os.path.join(dest_path, HCM))
    os.mkdir(os.path.join(dest_path, RV))

    for case in cases:
        full_path = os.path.join(src_path, case)
        copy(full_path, os.path.join(dest_path,
             read_patient_cfg(full_path)['Group'], case))

def generate_train_validate_test_set(src_path, dest_path):
    """
    Split the data 70:15:15 into train/validation/test sets
    arg: src_path: input data path
    """
    SPLIT_TRAIN = 0.7
    SPLIT_VALID = 0.15

    dest_path = os.path.join(dest_path, 'dataset')
    if os.path.exists(dest_path):
        shutil.rmtree(dest_path)
    os.makedirs(os.path.join(dest_path, 'train_set'))
    os.makedirs(os.path.join(dest_path, 'validation_set'))
    os.makedirs(os.path.join(dest_path, 'test_set'))
    # print (src_path)
    groups = next(os.walk(src_path))[1]
    for group in groups:
        group_path = next(os.walk(os.path.join(src_path, group)))[0]
        patient_folders = next(os.walk(group_path))[1]
        np.random.shuffle(patient_folders)
        train_ = patient_folders[0:int(SPLIT_TRAIN*len(patient_folders))]
        valid_ = patient_folders[int(SPLIT_TRAIN*len(patient_folders)):
                                 int((SPLIT_TRAIN+SPLIT_VALID)*len(patient_folders))]
        test_ = patient_folders[int((SPLIT_TRAIN+SPLIT_VALID)*len(patient_folders)):]
        for patient in train_:
            folder_path = os.path.join(group_path, patient)
            copy(folder_path, os.path.join(dest_path, 'train_set', patient))

        for patient in valid_:
            folder_path = os.path.join(group_path, patient)
            copy(folder_path, os.path.join(dest_path, 'validation_set', patient))

        for patient in test_:
            folder_path = os.path.join(group_path, patient)
            copy(folder_path, os.path.join(dest_path, 'test_set', patient))

# Fourier-Hough Transform Based ROI Extraction
def extract_roi_fft(data4D, pixel_spacing, minradius_mm=15, maxradius_mm=45, kernel_width=5,
                    center_margin=8, num_peaks=10, num_circles=20, radstep=2):
    """
    Returns center and radii of the ROI region in (i, j) format
    """
    # Data shape:
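# Over the cardiac cycle the heart is the dominant periodically moving
# structure, so the magnitude of the first temporal Fourier harmonic computed
# below highlights moving pixels; circular Hough voting on that map then
# localizes the approximately circular ventricles.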
291 | # radius of the smallest and largest circles in mm estimated from the train set 292 | # convert to pixel counts 293 | 294 | pixel_spacing_X, pixel_spacing_Y, _,_ = pixel_spacing 295 | minradius = int(minradius_mm / pixel_spacing_X) 296 | maxradius = int(maxradius_mm / pixel_spacing_Y) 297 | 298 | ximagesize = data4D.shape[0] 299 | yimagesize = data4D.shape[1] 300 | zslices = data4D.shape[2] 301 | tframes = data4D.shape[3] 302 | xsurface = np.tile(range(ximagesize), (yimagesize, 1)).T 303 | ysurface = np.tile(range(yimagesize), (ximagesize, 1)) 304 | lsurface = np.zeros((ximagesize, yimagesize)) 305 | 306 | allcenters = [] 307 | allaccums = [] 308 | allradii = [] 309 | 310 | for slice in range(zslices): 311 | ff1 = fftn([data4D[:,:,slice, t] for t in range(tframes)]) 312 | fh = np.absolute(ifftn(ff1[1, :, :])) 313 | fh[fh < 0.1 * np.max(fh)] = 0.0 314 | image = 1. * fh / np.max(fh) 315 | # find hough circles and detect two radii 316 | edges = canny(image, sigma=3) 317 | hough_radii = np.arange(minradius, maxradius, radstep) 318 | # print hough_radii 319 | hough_res = hough_circle(edges, hough_radii) 320 | if hough_res.any(): 321 | centers = [] 322 | accums = [] 323 | radii = [] 324 | 325 | for radius, h in zip(hough_radii, hough_res): 326 | # For each radius, extract num_peaks circles 327 | peaks = peak_local_max(h, num_peaks=num_peaks) 328 | centers.extend(peaks) 329 | accums.extend(h[peaks[:, 0], peaks[:, 1]]) 330 | radii.extend([radius] * num_peaks) 331 | 332 | # Keep the most prominent num_circles circles 333 | sorted_circles_idxs = np.argsort(accums)[::-1][:num_circles] 334 | 335 | for idx in sorted_circles_idxs: 336 | center_x, center_y = centers[idx] 337 | allcenters.append(centers[idx]) 338 | allradii.append(radii[idx]) 339 | allaccums.append(accums[idx]) 340 | brightness = accums[idx] 341 | lsurface = lsurface + brightness * np.exp( 342 | -((xsurface - center_x) ** 2 + (ysurface - center_y) ** 2) / kernel_width ** 2) 343 | 344 | lsurface = lsurface / lsurface.max() 345 | # select most likely ROI center 346 | roi_center = np.unravel_index(lsurface.argmax(), lsurface.shape) 347 | 348 | # determine ROI radius 349 | roi_x_radius = 0 350 | roi_y_radius = 0 351 | for idx in range(len(allcenters)): 352 | xshift = np.abs(allcenters[idx][0] - roi_center[0]) 353 | yshift = np.abs(allcenters[idx][1] - roi_center[1]) 354 | if (xshift <= center_margin) & (yshift <= center_margin): 355 | roi_x_radius = np.max((roi_x_radius, allradii[idx] + xshift)) 356 | roi_y_radius = np.max((roi_y_radius, allradii[idx] + yshift)) 357 | 358 | if roi_x_radius > 0 and roi_y_radius > 0: 359 | roi_radii = roi_x_radius, roi_y_radius 360 | else: 361 | roi_radii = None 362 | 363 | return roi_center, roi_radii 364 | 365 | # Stddev-Hough Transform Based ROI Extraction 366 | def extract_roi_stddev(data4D, pixel_spacing, minradius_mm=15, maxradius_mm=45, kernel_width=5, 367 | center_margin=8, num_peaks=10, num_circles=20, radstep=2): 368 | """ 369 | Returns center and radii of ROI region in (i,j) format 370 | """ 371 | # Data shape: 372 | # radius of the smallest and largest circles in mm estimated from the train set 373 | # convert to pixel counts 374 | 375 | pixel_spacing_X, pixel_spacing_Y, _,_ = pixel_spacing 376 | minradius = int(minradius_mm / pixel_spacing_X) 377 | maxradius = int(maxradius_mm / pixel_spacing_Y) 378 | 379 | ximagesize = data4D.shape[0] 380 | yimagesize = data4D.shape[1] 381 | zslices = data4D.shape[2] 382 | tframes = data4D.shape[3] 383 | xsurface = np.tile(range(ximagesize), (yimagesize, 
1)).T 384 | ysurface = np.tile(range(yimagesize), (ximagesize, 1)) 385 | lsurface = np.zeros((ximagesize, yimagesize)) 386 | 387 | allcenters = [] 388 | allaccums = [] 389 | allradii = [] 390 | 391 | for slice in range(zslices): 392 | ff1 = np.array([data4D[:,:,slice, t] for t in range(tframes)]) 393 | fh = np.std(ff1, axis=0) 394 | fh[fh < 0.1 * np.max(fh)] = 0.0 395 | image = 1. * fh / np.max(fh) 396 | # find hough circles and detect two radii 397 | edges = canny(image, sigma=3) 398 | hough_radii = np.arange(minradius, maxradius, radstep) 399 | # print hough_radii 400 | hough_res = hough_circle(edges, hough_radii) 401 | if hough_res.any(): 402 | centers = [] 403 | accums = [] 404 | radii = [] 405 | for radius, h in zip(hough_radii, hough_res): 406 | # For each radius, extract num_peaks circles 407 | peaks = peak_local_max(h, num_peaks=num_peaks) 408 | centers.extend(peaks) 409 | accums.extend(h[peaks[:, 0], peaks[:, 1]]) 410 | radii.extend([radius] * num_peaks) 411 | 412 | # Keep the most prominent num_circles circles 413 | sorted_circles_idxs = np.argsort(accums)[::-1][:num_circles] 414 | 415 | for idx in sorted_circles_idxs: 416 | center_x, center_y = centers[idx] 417 | allcenters.append(centers[idx]) 418 | allradii.append(radii[idx]) 419 | allaccums.append(accums[idx]) 420 | brightness = accums[idx] 421 | lsurface = lsurface + brightness * np.exp( 422 | -((xsurface - center_x) ** 2 + (ysurface - center_y) ** 2) / kernel_width ** 2) 423 | 424 | lsurface = lsurface / lsurface.max() 425 | # select most likely ROI center 426 | roi_center = np.unravel_index(lsurface.argmax(), lsurface.shape) 427 | 428 | # determine ROI radius 429 | roi_x_radius = 0 430 | roi_y_radius = 0 431 | for idx in range(len(allcenters)): 432 | xshift = np.abs(allcenters[idx][0] - roi_center[0]) 433 | yshift = np.abs(allcenters[idx][1] - roi_center[1]) 434 | if (xshift <= center_margin) & (yshift <= center_margin): 435 | roi_x_radius = np.max((roi_x_radius, allradii[idx] + xshift)) 436 | roi_y_radius = np.max((roi_y_radius, allradii[idx] + yshift)) 437 | 438 | if roi_x_radius > 0 and roi_y_radius > 0: 439 | roi_radii = roi_x_radius, roi_y_radius 440 | else: 441 | roi_radii = None 442 | 443 | return roi_center, roi_radii 444 | 445 | 446 | class Dataset(object): 447 | def __init__(self, directory, subdir): 448 | # type: (object, object) -> object 449 | self.patient_data = {} 450 | self.directory = directory 451 | self.name = subdir 452 | 453 | def _filename(self, file): 454 | return os.path.join(self.directory, self.name, file) 455 | 456 | def load_nii(self, img_path): 457 | """ 458 | Function to load a 'nii' or 'nii.gz' file, The function returns 459 | everyting needed to save another 'nii' or 'nii.gz' 460 | in the same dimensional space, i.e. the affine matrix and the header 461 | 462 | Parameters 463 | ---------- 464 | 465 | img_path: string 466 | String with the path of the 'nii' or 'nii.gz' image file name. 467 | 468 | Returns 469 | ------- 470 | Three element, the first is a numpy array of the image values, 471 | the second is the affine transformation of the image, and the 472 | last one is the header of the image. 
473 | """ 474 | nimg = nib.load(self._filename(img_path)) 475 | return nimg.get_data(), nimg.affine, nimg.header 476 | 477 | def read_patient_info_data(self): 478 | """ 479 | Reads patient data in the cfg file from patient folder 480 | using Info.cfg 481 | """ 482 | print (self._filename('Info.cfg')) 483 | with open(self._filename('Info.cfg')) as f_in: 484 | for line in f_in: 485 | l = line.rstrip().split(": ") 486 | self.patient_data[l[0]] = l[1] 487 | 488 | def read_patient_data(self, mode='train', roi_detect=True): 489 | """ 490 | Reads patient data in the cfg file and returns a dictionary and 491 | extract End diastole and End Systole image from patient folder 492 | using Info.cfg 493 | """ 494 | self.read_patient_info_data() 495 | # Read patient Number 496 | m = re.match("patient(\d{3})", self.name) 497 | patient_No = int(m.group(1)) 498 | # Read Diastole frame Number 499 | ED_frame_No = int(self.patient_data['ED']) 500 | ed_img = "patient%03d_frame%02d.nii.gz" %(patient_No, ED_frame_No) 501 | ed, affine, hdr = self.load_nii(ed_img) 502 | # Read Systole frame Number 503 | ES_frame_No = int(self.patient_data['ES']) 504 | es_img = "patient%03d_frame%02d.nii.gz" %(patient_No, ES_frame_No) 505 | es, _, _ = self.load_nii(es_img) 506 | # Save Images: 507 | self.patient_data['ED_VOL'] = ed 508 | self.patient_data['ES_VOL'] = es 509 | 510 | # Header Info for saving 511 | header_info ={'affine':affine, 'hdr': hdr} 512 | self.patient_data['header'] = header_info 513 | if mode == 'reader': 514 | # Read a particular volume number in 4D image 515 | img_4d_name = "patient%03d_4d.nii.gz"%patient_No 516 | # Load data 517 | img_4D, _, hdr = self.load_nii(img_4d_name) 518 | self.patient_data['4D'] = img_4D 519 | 520 | ed_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" %(patient_No, ED_frame_No)) 521 | es_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" %(patient_No, ES_frame_No)) 522 | ed_lv, ed_rv, ed_myo = heart_metrics(ed_gt, hdr.get_zooms()) 523 | es_lv, es_rv, es_myo = heart_metrics(es_gt, hdr.get_zooms()) 524 | ef_lv = ejection_fraction(ed_lv, es_lv) 525 | ef_rv = ejection_fraction(ed_rv, es_rv) 526 | heart_param = {'EDV_LV': ed_lv, 'EDV_RV': ed_rv, 'ESV_LV': es_lv, 'ESV_RV': es_rv, 527 | 'ED_MYO': ed_myo, 'ES_MYO': es_myo, 'EF_LV': ef_lv, 'EF_RV': ef_rv} 528 | self.patient_data['HP'] = heart_param 529 | self.patient_data['ED_GT'] = ed_gt 530 | self.patient_data['ES_GT'] = es_gt 531 | return 532 | 533 | if mode == 'train': 534 | ed_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" %(patient_No, ED_frame_No)) 535 | es_gt, _, _ = self.load_nii("patient%03d_frame%02d_gt.nii.gz" %(patient_No, ES_frame_No)) 536 | ed_lv, ed_rv, ed_myo = heart_metrics(ed_gt, hdr.get_zooms()) 537 | es_lv, es_rv, es_myo = heart_metrics(es_gt, hdr.get_zooms()) 538 | ef_lv = ejection_fraction(ed_lv, es_lv) 539 | ef_rv = ejection_fraction(ed_rv, es_rv) 540 | heart_param = {'EDV_LV': ed_lv, 'EDV_RV': ed_rv, 'ESV_LV': es_lv, 'ESV_RV': es_rv, 541 | 'ED_MYO': ed_myo, 'ES_MYO': es_myo, 'EF_LV': ef_lv, 'EF_RV': ef_rv} 542 | self.patient_data['HP'] = heart_param 543 | self.patient_data['ED_GT'] = ed_gt 544 | self.patient_data['ES_GT'] = es_gt 545 | 546 | if mode == 'tester': 547 | # Read a particular volume number in 4D image 548 | img_4d_name = "patient%03d_4d.nii.gz"%patient_No 549 | # Load data 550 | img_4D, _, hdr = self.load_nii(img_4d_name) 551 | self.patient_data['4D'] = img_4D 552 | 553 | if roi_detect: 554 | # Read a particular volume number in 4D image 555 | img_4d_name = 
"patient%03d_4d.nii.gz"%patient_No 556 | # Load data 557 | img_4D, _, hdr = self.load_nii(img_4d_name) 558 | c, r = extract_roi_stddev(img_4D, hdr.get_zooms()) 559 | self.patient_data['roi_center'], self.patient_data['roi_radii']=c,r 560 | self.patient_data['4D'] = img_4D 561 | # print c, r 562 | # plot_roi(img_4D, c,r) 563 | 564 | def convert_nii_np(data_path, mode, roi_detect): 565 | """ 566 | Prepare a dictionary of dataset and save it as numpy file 567 | """ 568 | patient_fulldata = OrderedDict() 569 | print (data_path) 570 | patient_folders = next(os.walk(data_path))[1] 571 | for patient in tqdm(sorted(patient_folders)): 572 | # print (patient) 573 | dset = Dataset(data_path, patient) 574 | dset.read_patient_data(mode=mode, roi_detect=roi_detect) 575 | patient_fulldata[dset.name] = dset.patient_data 576 | return patient_fulldata 577 | 578 | if __name__ == '__main__': 579 | start_time = time.time() 580 | # Path to ACDC training database 581 | complete_data_path = '../../ACDC_DataSet/training' 582 | dest_path = '../../processed_acdc_dataset' 583 | group_path = '../../processed_acdc_dataset/Patient_Groups' 584 | 585 | # Training dataset 586 | train_dataset = '../../processed_acdc_dataset/dataset/train_set' 587 | validation_dataset = '../../processed_acdc_dataset/dataset/validation_set' 588 | test_dataset = '../../processed_acdc_dataset/dataset/test_set' 589 | out_path_train = '../../processed_acdc_dataset/pickled/full_data' 590 | hdf5_out_path = '../../processed_acdc_dataset/hdf5_files' 591 | #Final Test dataset 592 | final_testing_dataset = '../../ACDC_DataSet/testing' 593 | out_path_test = '../../processed_acdc_dataset/pickled/final_test' 594 | 595 | # First perform stratified sampling 596 | group_patient_cases(complete_data_path, dest_path) 597 | generate_train_validate_test_set(group_path, dest_path) 598 | print("---Time taken to stratify the dataset %s seconds ---" % (time.time() - start_time)) 599 | 600 | print ('ROI->ED->ES train dataset') 601 | if not os.path.exists(out_path_train): 602 | os.makedirs(out_path_train) 603 | os.makedirs(out_path_test) 604 | 605 | train_dataset = convert_nii_np(train_dataset, mode='train', roi_detect=True) 606 | save_data(train_dataset, 'train_set.pkl', out_path_train) 607 | print("---Processing Training dataset %s seconds ---" % (time.time() - start_time)) 608 | validation_dataset = convert_nii_np(validation_dataset, mode='train', roi_detect=True) 609 | save_data(validation_dataset, 'validation_set.pkl', out_path_train) 610 | print("---Processing Training dataset %s seconds ---" % (time.time() - start_time)) 611 | test_dataset = convert_nii_np(test_dataset, mode='train', roi_detect=True) 612 | save_data(test_dataset, 'test_set.pkl', out_path_train) 613 | print("---Processing Training dataset %s seconds ---" % (time.time() - start_time)) 614 | 615 | print ('ROI->ED->ES test dataset') 616 | final_test_dataset = convert_nii_np(final_testing_dataset, mode='test', roi_detect=True) 617 | save_data(final_test_dataset, 'final_testing_data.pkl', out_path_test) 618 | print("---Processing final testing dataset %s seconds ---" % (time.time() - start_time)) 619 | 620 | # Generate 2D HDF5 files 621 | modes = ['train_set', 'validation_set', 'test_set'] 622 | for mode in modes: 623 | if os.path.exists(os.path.join(hdf5_out_path, mode)): 624 | shutil.rmtree(os.path.join(hdf5_out_path, mode)) 625 | os.makedirs(os.path.join(hdf5_out_path, mode)) 626 | patient_data = load_pkl(os.path.join(out_path_train, mode+'.pkl')) 627 | for patient_id in tqdm(patient_data.keys()): 
628 | # print (patient_id) 629 | _id = patient_id[-3:] 630 | n_slices = patient_data[patient_id]['ED_VOL'].shape[2] 631 | # print (n_slices) 632 | for slice in range(n_slices): 633 | # ED frames 634 | group = patient_data[patient_id]['Group'] 635 | slice_str ='_%02d_'%slice 636 | roi_center = (patient_data[patient_id]['roi_center'][1], patient_data[patient_id]['roi_center'][0]) 637 | hp = h5py.File(os.path.join(hdf5_out_path, mode, 'P_'+_id+'_ED'+slice_str+group+'.hdf5'),'w') 638 | hp.create_dataset('image', data=patient_data[patient_id]['ED_VOL'][:,:,slice].T) 639 | hp.create_dataset('label', data=patient_data[patient_id]['ED_GT'][:,:,slice].T) 640 | hp.create_dataset('roi_center', data=roi_center) 641 | hp.create_dataset('roi_radii', data=patient_data[patient_id]['roi_radii']) 642 | hp.create_dataset('pixel_spacing', data=patient_data[patient_id]['header']['hdr'].get_zooms()) 643 | hp.close() 644 | # ES frames 645 | hp = h5py.File(os.path.join(hdf5_out_path, mode, 'P_'+_id+'_ES'+slice_str+group+'.hdf5'),'w') 646 | hp.create_dataset('image', data=patient_data[patient_id]['ES_VOL'][:,:,slice].T) 647 | hp.create_dataset('label', data=patient_data[patient_id]['ES_GT'][:,:,slice].T) 648 | hp.create_dataset('roi_center', data=roi_center) 649 | hp.create_dataset('roi_radii', data=patient_data[patient_id]['roi_radii']) 650 | hp.create_dataset('pixel_spacing', data=patient_data[patient_id]['header']['hdr'].get_zooms()) 651 | hp.close() 652 | print("---Time taken to generate hdf5 files %s seconds ---" % (time.time() - start_time)) -------------------------------------------------------------------------------- /libs/configs/config_acdc.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import numpy as np 3 | 4 | __C = edict() 5 | cfg = __C 6 | 7 | 8 | # ============== general training config ===================== 9 | __C.TRAIN = edict() 10 | 11 | # __C.TRAIN.NET = "unet.U_Net" 12 | __C.TRAIN.NET = "unet_df.U_NetDF" 13 | 14 | __C.TRAIN.LR = 0.001 15 | __C.TRAIN.LR_CLIP = 0.00001 16 | __C.TRAIN.DECAY_STEP_LIST = [60, 100, 150, 180] 17 | __C.TRAIN.LR_DECAY = 0.5 18 | 19 | __C.TRAIN.GRAD_NORM_CLIP = 1.0 20 | 21 | __C.TRAIN.OPTIMIZER = 'adam' 22 | __C.TRAIN.WEIGHT_DECAY = 0 # "L2 regularization coeff [default: 0.0]" 23 | __C.TRAIN.MOMENTUM = 0.9 24 | 25 | # =============== model config ======================== 26 | __C.MODEL = edict() 27 | 28 | __C.MODEL.SELFEATURE = True 29 | __C.MODEL.SHIFT_N = 1 30 | __C.MODEL.AUXSEG = True 31 | 32 | # ================= dataset config ========================== 33 | __C.DATASET = edict() 34 | 35 | __C.DATASET.NAME = "acdc" 36 | __C.DATASET.MEAN = 63.19523533061758 37 | __C.DATASET.STD = 70.74166957523165 38 | 39 | __C.DATASET.NUM_CLASS = 4 40 | 41 | __C.DATASET.DF_USED = True 42 | __C.DATASET.DF_NORM = True 43 | __C.DATASET.BOUNDARY = False 44 | 45 | __C.DATASET.TRAIN_LIST = "libs/datasets/jsonLists/acdcList/train.json" 46 | __C.DATASET.TEST_LIST = "libs/datasets/jsonLists/acdcList/test.json" 47 | 48 | __C.DATASET.TEST_PERSON_LIST = "libs/datasets/personList/AcdcTestPersonCarname.json" 49 | -------------------------------------------------------------------------------- /libs/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cardia_dataset import CardiaDataset 2 | from .acdc_dataset import AcdcDataset 3 | from .lvsc_dataset import LvscDataset -------------------------------------------------------------------------------- 
/libs/datasets/acdc_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision import transforms as T 4 | 5 | import os 6 | import json 7 | import numpy as np 8 | from PIL import Image 9 | import h5py 10 | 11 | import sys 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(os.path.join(BASE_DIR, '../../')) 14 | from utils.direct_field.df_cardia import direct_field 15 | from libs.datasets import augment as standard_aug 16 | from utils.direct_field.utils_df import class2dist 17 | 18 | 19 | class AcdcDataset(Dataset): 20 | def __init__(self, data_list, df_used=False, joint_augment=None, augment=None, target_augment=None, df_norm=True, boundary=False): 21 | self.joint_augment = joint_augment 22 | self.augment = augment 23 | self.target_augment = target_augment 24 | self.data_list = data_list 25 | self.df_used = df_used 26 | self.df_norm = df_norm 27 | self.boundary = boundary 28 | 29 | with open(data_list, 'r') as f: 30 | self.data_infos = json.load(f) 31 | 32 | def __len__(self): 33 | return len(self.data_infos) 34 | 35 | def __getitem__(self,index): 36 | img = h5py.File(self.data_infos[index],'r')['image'] 37 | gt = h5py.File(self.data_infos[index],'r')['label'] 38 | # print(np.unique(gt)) 39 | img = np.array(img)[:,:,None].astype(np.float32) 40 | gt = np.array(gt)[:,:,None].astype(np.float32) 41 | # print(np.unique(gt)) 42 | 43 | if self.joint_augment is not None: 44 | img, gt = self.joint_augment(img, gt) 45 | if self.augment is not None: 46 | img = self.augment(img) 47 | if self.target_augment is not None: 48 | gt = self.target_augment(gt) 49 | 50 | if self.df_used: 51 | gt_df = direct_field(gt.numpy()[0], norm=self.df_norm) 52 | gt_df = torch.from_numpy(gt_df) 53 | else: 54 | gt_df = None 55 | 56 | if self.boundary: 57 | dist_map = torch.from_numpy(class2dist(gt.numpy()[0], C=4)) 58 | else: 59 | dist_map = None 60 | 61 | return img, gt, gt_df, dist_map 62 | 63 | 64 | -------------------------------------------------------------------------------- /libs/datasets/augment.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import random 4 | import torch 5 | import numpy as np 6 | from PIL import Image, ImageEnhance 7 | 8 | class Compose(): 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | 12 | def __call__(self, img): 13 | for t in self.transforms: 14 | img = t(img) 15 | return img 16 | 17 | class to_Tensor(): 18 | def __call__(self,arr): 19 | if len(np.array(arr).shape) == 2: 20 | arr = np.array(arr)[:,:,None] 21 | arr = torch.from_numpy(np.array(arr).transpose(2,0,1)) 22 | return arr 23 | 24 | def imresize(im, size, interp='bilinear'): 25 | if interp == 'nearest': 26 | resample = Image.NEAREST 27 | elif interp == 'bilinear': 28 | resample = Image.BILINEAR 29 | elif interp == 'bicubic': 30 | resample = Image.BICUBIC 31 | else: 32 | raise Exception('resample method undefined!') 33 | return im.resize(size, resample) 34 | 35 | class To_PIL_Image(): 36 | def __call__(self, img): 37 | return to_pil_image(img) 38 | 39 | class normalize(): 40 | def __init__(self,mean,std): 41 | self.mean = torch.tensor(mean) 42 | self.std = torch.tensor(std) 43 | def __call__(self,img): 44 | self.mean = torch.as_tensor(self.mean,dtype=img.dtype,device=img.device) 45 | self.std = torch.as_tensor(self.std,dtype=img.dtype,device=img.device) 46 | return (img-self.mean)/self.std 47 | 48 | 
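# Illustrative composition of the image-only transforms in this module (assumed
# usage with hypothetical factor values, mirroring torchvision-style pipelines;
# paired image/mask transforms live in joint_augment.py):
#   aug = Compose([To_PIL_Image(), RandomBrightness(0.7, 1.3),
#                  RandomContrast(0.7, 1.3), to_Tensor()])
#   img = aug(img)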
class RandomVerticalFlip(): 49 | def __init__(self, prob): 50 | self.prob = prob 51 | 52 | def __call__(self, img): 53 | if random.random() < self.prob: 54 | if isinstance(img, Image.Image): 55 | return img.transpose(Image.FLIP_TOP_BOTTOM) 56 | if isinstance(img, np.ndarray): 57 | return np.flip(img, axis=0) 58 | return img 59 | 60 | class RandomHorizontallyFlip(): 61 | def __init__(self, prob=0.5): 62 | self.prob = prob 63 | 64 | def __call__(self, img): 65 | if random.random() < self.prob: 66 | if isinstance(img, Image.Image): 67 | return img.transpose(Image.FLIP_LEFT_RIGHT) 68 | if isinstance(img, np.ndarray): 69 | return np.flip(img, axis=1) 70 | return img 71 | 72 | class RandomRotate(): 73 | def __init__(self, degree, prob=0.5): 74 | self.prob = prob 75 | self.degree = degree 76 | 77 | def __call__(self, img, interpolation=Image.BILINEAR): 78 | if random.random() < self.prob: 79 | rotate_detree = random.random() * 2 * self.degree - self.degree 80 | return img.rotate(rotate_detree, interpolation) 81 | return img 82 | 83 | class RandomBrightness(): 84 | def __init__(self, min_factor, max_factor, prob=0.5): 85 | """ :param min_factor: The value between 0.0 and max_factor 86 | that define the minimum adjustment of image brightness. 87 | The value 0.0 gives a black image,The value 1.0 gives the original image, value bigger than 1.0 gives more bright image. 88 | :param max_factor: A value should be bigger than min_factor. 89 | that define the maximum adjustment of image brightness. 90 | The value 0.0 gives a black image, value 1.0 gives the original image, value bigger than 1.0 gives more bright image. 91 | 92 | """ 93 | self.prob = prob 94 | self.min_factor = min_factor 95 | self.max_factor = max_factor 96 | 97 | # def __brightness(self, img, factor): 98 | # return img * (1.0 - factor) + img * factor 99 | 100 | # def __call__(self, img): 101 | # if random.random() < self.prob: 102 | # factor = np.random.uniform(self.min_factor, self.max_factor) 103 | # return self.__brightness(img, factor) 104 | 105 | def __call__(self, img): 106 | if random.random() < self.prob: 107 | factor = np.random.uniform(self.min_factor, self.max_factor) 108 | enhancer_brightness = ImageEnhance.Brightness(img) 109 | return enhancer_brightness.enhance(factor) 110 | 111 | return img 112 | 113 | class RandomContrast(): 114 | def __init__(self, min_factor, max_factor, prob=0.5): 115 | """ :param min_factor: The value between 0.0 and max_factor 116 | that define the minimum adjustment of image contrast. 117 | The value 0.0 gives s solid grey image, value 1.0 gives the original image. 118 | :param max_factor: A value should be bigger than min_factor. 119 | that define the maximum adjustment of image contrast. 120 | The value 0.0 gives s solid grey image, value 1.0 gives the original image. 121 | """ 122 | self.prob = prob 123 | self.min_factor = min_factor 124 | self.max_factor = max_factor 125 | 126 | def __call__(self, img): 127 | if random.random() < self.prob: 128 | factor = np.random.uniform(self.min_factor, self.max_factor) 129 | enhance_contrast = ImageEnhance.Contrast(img) 130 | return enhance_contrast.enhance(factor) 131 | return img 132 | 133 | def to_pil_image(pic, mode=None): 134 | """Convert a tensor or an ndarray to PIL Image. 135 | 136 | See :class:`~torchvision.transforms.ToPIlImage` for more details. 137 | 138 | Args: 139 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 140 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 141 | 142 | .. 
_PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 143 | 144 | Returns: 145 | PIL Image: Image converted to PIL Image. 146 | """ 147 | # if not(_is_numpy_image(pic) or _is_tensor_image(pic)): 148 | # raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 149 | 150 | npimg = pic 151 | if isinstance(pic, torch.FloatTensor): 152 | pic = pic.mul(255).byte() 153 | if torch.is_tensor(pic): 154 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 155 | 156 | if not isinstance(npimg, np.ndarray): 157 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 158 | 'not {}'.format(type(npimg))) 159 | 160 | if npimg.shape[2] == 1: 161 | expected_mode = None 162 | npimg = npimg[:, :, 0] 163 | if npimg.dtype == np.uint8: 164 | expected_mode = 'L' 165 | elif npimg.dtype == np.int16: 166 | expected_mode = 'I;16' 167 | elif npimg.dtype == np.int32: 168 | expected_mode = 'I' 169 | elif npimg.dtype == np.float32: 170 | expected_mode = 'F' 171 | if mode is not None and mode != expected_mode: 172 | raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" 173 | .format(mode, np.dtype, expected_mode)) 174 | mode = expected_mode 175 | 176 | elif npimg.shape[2] == 4: 177 | permitted_4_channel_modes = ['RGBA', 'CMYK'] 178 | if mode is not None and mode not in permitted_4_channel_modes: 179 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 180 | 181 | if mode is None and npimg.dtype == np.uint8: 182 | mode = 'RGBA' 183 | else: 184 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 185 | if mode is not None and mode not in permitted_3_channel_modes: 186 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 187 | if mode is None and npimg.dtype == np.uint8: 188 | mode = 'RGB' 189 | 190 | if mode is None: 191 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 192 | 193 | return Image.fromarray(npimg, mode=mode) 194 | -------------------------------------------------------------------------------- /libs/datasets/collate_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from utils.image_list import to_image_list 3 | 4 | class BatchCollator(object): 5 | """ 6 | From a list of samples from the dataset, 7 | returns the batched images and targets. 
    This should be passed to the DataLoader
    """

    def __init__(self, size_divisible=0, df_used=False, boundary=False):
        self.size_divisible = size_divisible
        self.df_used = df_used
        self.boundary = boundary

    def __call__(self, batch):
        transposed_batch = list(zip(*batch))
        images = to_image_list(transposed_batch[0], self.size_divisible)
        targets = to_image_list(transposed_batch[1], self.size_divisible)

        dfs = to_image_list(transposed_batch[2], self.size_divisible) if self.df_used else None

        dist_maps = to_image_list(transposed_batch[3], self.size_divisible) if self.boundary else None

        return images, targets, dfs, dist_maps

--------------------------------------------------------------------------------
/libs/datasets/gen_acdcjson.py:
--------------------------------------------------------------------------------
import os
import nibabel as nib
import numpy as np
import json

'''
DCM,HCM,MINF,NOR,RV
'''

def load_nii(nii_path):
    data = nib.load(nii_path)
    img = data.get_data()
    affine = data.affine
    header = data.header
    return img, affine, header


def read_Infocfg(cfg_path):
    patient_info = {}
    with open(cfg_path) as f_in:
        for line in f_in:
            l = line.rstrip().split(":")
            # l is the [key, value] pair for one patient_info entry
            patient_info[l[0]] = l[1]
    # print(patient_info)
    '''
    ['ED', ' 1']
    ['ES', ' 12']
    ['Group', ' DCM']
    ['Height', ' 184.0']
    ['NbFrame', ' 30']
    ['Weight', ' 95.0']
    {'ED': ' 1', 'ES': ' 12', 'Group': ' DCM', 'Height': ' 184.0', 'NbFrame': ' 30', 'Weight': ' 95.0'}
    '''
    return patient_info

def read_json(fpath):
    with open(fpath, 'r') as f:
        obj = json.load(f)
    return obj

def write_json(obj, fpath):
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4)

def write_json_append(obj, fpath):
    with open(fpath, 'a') as f:
        json.dump(obj, f, indent=4)

def gen_alldatalist(path):
    filelist = []
    for dir in os.listdir(path):
        for file in os.listdir(os.path.join(path, dir)):  # use the argument, not the global root_path
            filelist.append(os.path.join(path, dir, file))
    # the length of the filelist is 1902
    # print(len(filelist))
    filelist.sort()
    out_dir = os.path.dirname(os.path.abspath(__file__))
    # /home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets
    write_json(filelist, os.path.join(out_dir, "./acdcjson/ACDCDataList.json"))

def gen_every_kind_datalist(kind_path):
    filelist = []
    for file in os.listdir(kind_path):
        filelist.append(os.path.join(kind_path, file))
    filelist.sort()
    out_dir = os.path.dirname(os.path.abspath(__file__))
    # /home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets
    write_json(filelist, os.path.join(out_dir, "./acdcjson/{}DataList.json".format(kind_path[-2:])))

def generate_train_test_list(json_file_path):
    # json_file = "/home/fcheng/Cardia/DataList.json"
    # json_file = "/home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets/ACDCDataList.json"
    fileslist = read_json(json_file_path)

    nums = len(fileslist)
    train_ind = set(np.random.choice(nums, size=int(np.ceil(0.8 * nums)), replace=False))
    test_ind = set(np.arange(nums)) - train_ind

    test_ind = list(test_ind)
    test_ind.sort()

    train_list = [fileslist[fl] for fl in
train_ind] 92 | test_list = [fileslist[fl] for fl in test_ind] 93 | 94 | out_dir = os.path.dirname(os.path.abspath(__file__)) 95 | print(out_dir) 96 | write_json(train_list, os.path.join(out_dir, "./acdcjson/RVtrain.json")) 97 | write_json(test_list, os.path.join(out_dir, "./acdcjson/RVtest.json")) 98 | write_json_append(train_list, os.path.join(out_dir, "./acdcjson/train.json")) 99 | write_json_append(test_list, os.path.join(out_dir, "./acdcjson/test.json")) 100 | # write_json(test_list, "/home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets/differentkind/kuodatest.json") 101 | 102 | 103 | if __name__ == "__main__": 104 | # read_Infocfg(os.path.join(root_path,"./DCM/patient001/Info.cfg")) 105 | ''' 106 | a,b,c = load_nii(os.path.join(root_path,"./patient001/patient001_frame01.nii.gz")) 107 | print(np.unique(a)) 108 | 109 | print(c) 110 | ''' 111 | root_path = "/home/ffbian/chencheng/MICCAIACDC2017/processed_acdc_dataset/hdf5_files/Bykind" 112 | kind_path = "/home/ffbian/chencheng/MICCAIACDC2017/processed_acdc_dataset/hdf5_files/Bykind/RV" 113 | 114 | # gen_alldatalist(root_path) 115 | json_file_path = "/home/ffbian/chencheng/MICCAIACDC2017/mycode/libs/dataset/acdcjson/RVDataList.json" 116 | 117 | # generate_train_test_list(json_file_path) 118 | # gen_every_kind_datalist(kind_path) 119 | result = read_json(json_file_path) 120 | print(result) -------------------------------------------------------------------------------- /libs/datasets/generate_FilesList.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | 5 | def read_json(fpath): 6 | with open(fpath,'r') as f: 7 | obj = json.load(f) 8 | return obj 9 | 10 | def write_json(obj, fpath): 11 | with open(fpath,'w') as f: 12 | json.dump(obj,f,indent=4) 13 | 14 | def generate_fileslist(): 15 | root = "/home/ffbian/chencheng/XieheCardiac/npydata/" 16 | 17 | cars = os.listdir(root) 18 | cars.sort() 19 | 20 | fileslist = [] # [(path, index_time), (), ...] 
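# Expected layout walked below: npydata/<disease>/<person>/imgs/<slice>.npy,
# where each .npy array is H x W x T; one (path, time_index) entry is appended
# per time frame.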
    # disease category
    for car in cars:
        persons = os.listdir(os.path.join(root, car))
        persons.sort()

        # individual patient
        for person in persons:
            sliceds = os.listdir(os.path.join(root, car, person, "imgs"))
            sliceds.sort()

            # slice position
            for sliced in sliceds:
                file_p = os.path.join(root, car, person, "imgs", sliced)
                npy = np.load(file_p)
                time_n = npy.shape[-1]  # 25, 20, 11, 50

                # time frames
                for i in range(time_n):
                    fileslist.append((file_p, i))

    out_dir = os.path.dirname(os.path.abspath(__file__))
    # write_json(fileslist, os.path.join(out_dir, "DataList.json"))
    print(len(fileslist))

def generate_train_test_list():
    json_file = "/home/fcheng/Cardia/DataList.json"
    fileslist = read_json(json_file)

    nums = len(fileslist)
    train_ind = set(np.random.choice(nums, size=int(np.ceil(0.8*nums)), replace=False))
    test_ind = set(np.arange(nums)) - train_ind

    test_ind = list(test_ind)
    test_ind.sort()

    train_list = [fileslist[fl] for fl in train_ind]
    test_list = [fileslist[fl] for fl in test_ind]

    out_dir = os.path.dirname(os.path.abspath(__file__))
    write_json(train_list, os.path.join(out_dir, "train.json"))
    write_json(test_list, os.path.join(out_dir, "test.json"))

def generate_N_list(N=50000):
    json_file = "/home/fcheng/Cardia/source_code/libs/datasets/DataList.json"
    fileslist = read_json(json_file)

    nums = len(fileslist)
    train_ind = set(np.random.choice(nums, size=N, replace=False))
    test_ind = set(np.arange(nums)) - train_ind

    test_ind = list(test_ind)
    test_ind.sort()

    train_list = [fileslist[fl] for fl in train_ind]
    test_list = [fileslist[fl] for fl in test_ind]

    out_dir = os.path.dirname(os.path.abspath(__file__))
    write_json(train_list, os.path.join(out_dir, "train_{}.json".format(N)))
    write_json(test_list, os.path.join(out_dir, "test_{}.json".format(N)))

def gene_uniform_List(ratio=0.8, N=None):
    root = "/home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets/differentkind/"
    cars = os.listdir(root)
    cars = [c for c in cars if "test" not in c]
    cars.sort()

    train_List = []
    test_List = []

    total_num = 0
    for json_f in cars:
        json_list = read_json(os.path.join(root, json_f))
        total_num += len(json_list)

    if N is not None:
        ratio = N / total_num

    train_num = 0
    for json_f in cars:
        json_list = read_json(os.path.join(root, json_f))
        train_num += np.ceil(ratio*len(json_list))
        print(json_f, len(json_list), np.ceil(ratio*len(json_list)))

        ta_ind = set(np.random.choice(len(json_list), size=int(np.ceil(ratio*len(json_list))), replace=False))
        te_ind = set(np.arange(len(json_list))) - ta_ind

        ta_ind = list(ta_ind)
        ta_ind.sort()
        te_ind = list(te_ind)
        te_ind.sort()

        train_List += [json_list[i] for i in ta_ind]
        test_List += [json_list[i] for i in te_ind]

    print(total_num, train_num)
    print(len(train_List), len(test_List))

    out_dir = os.path.dirname(os.path.abspath(__file__))
    write_json(train_List, os.path.join(out_dir, "train_{}.json".format(N)))
    write_json(test_List, os.path.join(out_dir, "test_{}.json".format(N)))

if __name__ == "__main__":
    # generate_fileslist()
    # generate_train_test_list()
    # generate_N_list(N=30000)
gene_uniform_List(N=30000) -------------------------------------------------------------------------------- /libs/datasets/generate_personList.py: -------------------------------------------------------------------------------- 1 | import json, h5py 2 | # acdc_test_json = "/home/ffbian/chencheng/MICCAIACDC2017/mycode/libs/dataset/acdcjson/test.json" 3 | # acdc_train_json = "/home/ffbian/chencheng/MICCAIACDC2017/mycode/libs/dataset/acdcjson/train.json" 4 | acdc_test_json = "/home/ffbian/chencheng/MICCAIACDC2017/mycode/libs/dataset/acdcjson/Dense_TestList.json" 5 | acdc_train_json = "/root/chengfeng/Cardiac/source_code/libs/datasets/jsonLists/acdcList/Dense_TrainList.json" 6 | 7 | 8 | def func(path): 9 | with open(path, 'r') as f: 10 | test_list = json.load(f) 11 | print(len(test_list)) 12 | 13 | name_list = set() 14 | for tl in test_list: 15 | a0, a1 = tl.split('/')[-2:] 16 | a1 = "_".join(a1.split('_')[:2]) 17 | name_list.add(a0+'-'+a1) 18 | 19 | name_list = list(name_list) 20 | name_list.sort() 21 | return name_list 22 | 23 | name_list = func(acdc_train_json) 24 | # name_list = func(acdc_test_json) 25 | # name_list = name_list_train + name_list_test 26 | name_list.sort() 27 | 28 | # outfile = "./AcdcPersonCarname.txt" 29 | # with open(outfile, "w") as f: 30 | # # for nl in name_list: 31 | # f.writelines('\n'.join(name_list)) 32 | 33 | outfile = "./AcdcDenseTrainPersonCarname.json" 34 | with open(outfile, "w") as f: 35 | json.dump(name_list, f) 36 | 37 | -------------------------------------------------------------------------------- /libs/datasets/joint_augment.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import random 4 | import numbers 5 | from PIL import Image, ImageEnhance, ImageOps 6 | from .augment import to_pil_image 7 | import torch 8 | from torchvision.transforms import functional as F 9 | 10 | 11 | class Compose(): 12 | def __init__(self, transforms): 13 | self.transforms = transforms 14 | 15 | def __call__(self, img, mask): 16 | assert img.size == mask.size 17 | for t in self.transforms: 18 | img, mask = t(img, mask) 19 | return img, mask 20 | 21 | class To_Tensor(): 22 | def __call__(self, arr, arr2): 23 | if len(np.array(arr).shape) == 2: 24 | arr = np.array(arr)[:,:,None] 25 | if len(np.array(arr2).shape) == 2: 26 | arr2 = np.array(arr2)[:, :, None] 27 | arr = torch.from_numpy(np.array(arr).transpose(2,0,1)) 28 | arr2 = torch.from_numpy(np.array(arr2).transpose(2,0,1)) 29 | return arr, arr2 30 | 31 | class To_PIL_Image(): 32 | def __call__(self, img, mask): 33 | img = to_pil_image(img) 34 | mask = to_pil_image(mask) 35 | return img, mask 36 | 37 | class RandomVerticalFlip(): 38 | def __init__(self, prob): 39 | self.prob = prob 40 | 41 | def __call__(self, img, mask): 42 | if random.random() < self.prob: 43 | if isinstance(img, Image.Image): 44 | return img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM) 45 | if isinstance(img, np.ndarray): 46 | return np.flip(img, axis=0), np.flip(mask, axis=0) 47 | return img, mask 48 | 49 | class RandomHorizontallyFlip(): 50 | def __init__(self, prob=0.5): 51 | self.prob = prob 52 | 53 | def __call__(self, img, mask): 54 | if random.random() < self.prob: 55 | if isinstance(img, Image.Image): 56 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 57 | if isinstance(img, np.ndarray): 58 | return np.flip(img, axis=1), np.flip(mask, axis=1) 59 | return img, mask 60 | 61 | class RandomRotate(): 62 | def __init__(self, 
degrees, prob=0.5): 63 | self.prob = prob 64 | self.degrees = degrees 65 | 66 | def __call__(self, img, mask): 67 | if random.random() < self.prob: 68 | rotate_detree = random.uniform(self.degrees[0], self.degrees[1]) 69 | return img.rotate(rotate_detree, Image.BILINEAR), mask.rotate(rotate_detree, Image.NEAREST) 70 | return img, mask 71 | 72 | class FixResize(): 73 | def __init__(self, size): 74 | if isinstance(size, numbers.Number): 75 | self.size = (int(size), int(size)) 76 | else: 77 | self.size = size 78 | 79 | def __call__(self, img=None, mask=None): 80 | if mask is None: 81 | return img.resize(self.size, Image.BILINEAR) 82 | if img is None: 83 | return mask.resize(self.size, Image.NEAREST) 84 | return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST) 85 | 86 | class Scale(object): 87 | def __init__(self, size): 88 | self.size = size 89 | 90 | def __call__(self, img, mask): 91 | assert img.size == mask.size 92 | w, h = img.size 93 | if (w >= h and w == self.size) or (h >= w and h == self.size): 94 | return img, mask 95 | if w > h: 96 | ow = self.size 97 | oh = int(self.size * h / w) 98 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 99 | else: 100 | oh = self.size 101 | ow = int(self.size * w / h) 102 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 103 | 104 | class RandomCrop(object): 105 | def __init__(self, size, padding=0): 106 | if isinstance(size, numbers.Number): 107 | self.size = (int(size), int(size)) 108 | else: 109 | self.size = size 110 | self.padding = padding 111 | 112 | def __call__(self, img, mask): 113 | if self.padding > 0: 114 | img = ImageOps.expand(img, border=self.padding, fill=0) 115 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 116 | 117 | assert img.size == mask.size 118 | w, h = img.size 119 | th, tw = self.size 120 | if w == tw and h == th: 121 | return img, mask 122 | if w < tw or h < th: 123 | return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST) 124 | 125 | x1 = random.randint(0, w - tw) 126 | y1 = random.randint(0, h - th) 127 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 128 | 129 | class RandomSized(object): 130 | def __init__(self, size): 131 | self.size = size 132 | self.scale = Scale(self.size) 133 | self.crop = RandomCrop(self.size) 134 | 135 | def __call__(self, img, mask): 136 | assert img.size == mask.size 137 | 138 | w = int(random.uniform(0.5, 2) * img.size[0]) 139 | h = int(random.uniform(0.5, 2) * img.size[1]) 140 | 141 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 142 | 143 | return self.crop(*self.scale(img, mask)) 144 | 145 | class ScaleRatio(): 146 | def __init__(self, scale_factor=1): 147 | self.scale_factor = scale_factor 148 | 149 | def __call__(self, img, interpolation): 150 | w, h = img.size 151 | new_h = int(h * self.scale_factor) 152 | new_w = int(w * self.scale_factor) 153 | return img.resize((new_w, new_h), interpolation) 154 | 155 | class RandomScale(): 156 | def __init__(self, min_factor=0.8, max_factor=1.2, prob=0.5): 157 | self.min_factor = min_factor 158 | self.max_factor = max_factor 159 | self.prob = prob 160 | 161 | def __scale(self, img, scale_factor, interpolation): 162 | w, h = img.size 163 | new_h = int(h * scale_factor) 164 | new_w = int(w * scale_factor) 165 | return img.resize((new_w, new_h), interpolation) 166 | 167 | def __call__(self, img, mask): 168 | if random.random() < self.prob: 169 | 
factor = np.random.uniform(self.min_factor, self.max_factor) 170 | return self.__scale(img, factor, Image.BILINEAR), self.__scale(mask, factor, Image.NEAREST) 171 | return img, mask 172 | 173 | 174 | class Resize(object): 175 | def __init__(self, min_size, max_size): 176 | if not isinstance(min_size, (list, tuple)): 177 | min_size = (min_size,) 178 | self.min_size = min_size 179 | self.max_size = max_size 180 | 181 | # modified from torchvision to add support for max size 182 | def get_size(self, image_size): 183 | w, h = image_size 184 | size = random.choice(self.min_size) 185 | max_size = self.max_size 186 | if max_size is not None: 187 | min_original_size = float(min((w, h))) 188 | max_original_size = float(max((w, h))) 189 | if max_original_size / min_original_size * size > max_size: 190 | size = int(round(max_size * min_original_size / max_original_size)) 191 | 192 | if (w <= h and w == size) or (h <= w and h == size): 193 | return (h, w) 194 | 195 | if w < h: 196 | ow = size 197 | oh = int(size * h / w) 198 | else: 199 | oh = size 200 | ow = int(size * w / h) 201 | 202 | return (oh, ow) 203 | 204 | def __call__(self, image, target=None): 205 | size = self.get_size(image.size) 206 | image = F.resize(image, size) 207 | if isinstance(target, list): 208 | target = [t.resize(image.size) for t in target] 209 | elif target is None: 210 | return image 211 | else: 212 | target = target.resize(image.size, Image.NEAREST) 213 | return image, target 214 | 215 | 216 | class RandomAffine(object): 217 | """Random affine transformation of the image keeping center invariant 218 | 219 | Args: 220 | degrees (sequence or float or int): Range of degrees to select from. 221 | If degrees is a number instead of sequence like (min, max), the range of degrees 222 | will be (-degrees, +degrees). Set to 0 to desactivate rotations. 223 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 224 | and vertical translations. For example translate=(a, b), then horizontal shift 225 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is 226 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. 227 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is 228 | randomly sampled from the range a <= scale <= b. Will keep original scale by default. 229 | shear (sequence or float or int, optional): Range of degrees to select from. 230 | If degrees is a number instead of sequence like (min, max), the range of degrees 231 | will be (-degrees, +degrees). Will not apply shear by default 232 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 233 | An optional resampling filter. 234 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 235 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 236 | fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) 237 | """ 238 | 239 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0, prob=0.5): 240 | self.prob = prob 241 | if isinstance(degrees, numbers.Number): 242 | if degrees < 0: 243 | raise ValueError("If degrees is a single number, it must be positive.") 244 | self.degrees = (-degrees, degrees) 245 | else: 246 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 247 | "degrees should be a list or tuple and it must be of length 2." 
248 | self.degrees = degrees 249 | 250 | if translate is not None: 251 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 252 | "translate should be a list or tuple and it must be of length 2." 253 | for t in translate: 254 | if not (0.0 <= t <= 1.0): 255 | raise ValueError("translation values should be between 0 and 1") 256 | self.translate = translate 257 | 258 | if scale is not None: 259 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 260 | "scale should be a list or tuple and it must be of length 2." 261 | for s in scale: 262 | if s <= 0: 263 | raise ValueError("scale values should be positive") 264 | self.scale = scale 265 | 266 | if shear is not None: 267 | if isinstance(shear, numbers.Number): 268 | if shear < 0: 269 | raise ValueError("If shear is a single number, it must be positive.") 270 | self.shear = (-shear, shear) 271 | else: 272 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 273 | "shear should be a list or tuple and it must be of length 2." 274 | self.shear = shear 275 | else: 276 | self.shear = shear 277 | 278 | self.resample = resample 279 | self.fillcolor = fillcolor 280 | 281 | @staticmethod 282 | def get_params(degrees, translate, scale_ranges, shears, img_size): 283 | """Get parameters for affine transformation 284 | 285 | Returns: 286 | sequence: params to be passed to the affine transformation 287 | """ 288 | angle = random.uniform(degrees[0], degrees[1]) 289 | if translate is not None: 290 | max_dx = translate[0] * img_size[0] 291 | max_dy = translate[1] * img_size[1] 292 | translations = (np.round(random.uniform(-max_dx, max_dx)), 293 | np.round(random.uniform(-max_dy, max_dy))) 294 | else: 295 | translations = (0, 0) 296 | 297 | if scale_ranges is not None: 298 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 299 | else: 300 | scale = 1.0 301 | 302 | if shears is not None: 303 | shear = random.uniform(shears[0], shears[1]) 304 | else: 305 | shear = 0.0 306 | 307 | return angle, translations, scale, shear 308 | 309 | def __call__(self, img, mask): 310 | """ 311 | img (PIL Image): Image to be transformed. 312 | 313 | Returns: 314 | PIL Image: Affine transformed image. 
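        (added usage sketch, not in the original docstring) one set of affine
        parameters is sampled per call and applied to both inputs, so image
        and mask stay geometrically aligned:
            t = RandomAffine(degrees=10, translate=(0.1, 0.1), prob=1.0)
            img_t, mask_t = t(img, mask)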
315 | """ 316 | if random.random() < self.prob: 317 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 318 | return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor), F.affine(mask, *ret, resample=self.resample, fillcolor=self.fillcolor) 319 | return img, mask 320 | 321 | def __repr__(self): 322 | s = '{name}(degrees={degrees}' 323 | if self.translate is not None: 324 | s += ', translate={translate}' 325 | if self.scale is not None: 326 | s += ', scale={scale}' 327 | if self.shear is not None: 328 | s += ', shear={shear}' 329 | if self.resample > 0: 330 | s += ', resample={resample}' 331 | if self.fillcolor != 0: 332 | s += ', fillcolor={fillcolor}' 333 | s += ')' 334 | d = dict(self.__dict__) 335 | d['resample'] = _pil_interpolation_to_str[d['resample']] 336 | return s.format(name=self.__class__.__name__, **d) 337 | -------------------------------------------------------------------------------- /libs/losses/ce_ieLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class CE_IELoss(nn.Module): 5 | """ 6 | CrossEntropy Loss with Information Entropy as a regularization 7 | """ 8 | def __init__(self, eps=0.5, reduction='mean'): 9 | super(CE_IELoss, self).__init__() 10 | self.eps = eps 11 | self.nll = nn.NLLLoss(reduction=reduction) 12 | self.softmax = nn.Softmax(1) 13 | 14 | def update_eps(self): 15 | self.eps = self.eps * 0.1 16 | 17 | def forward(self, outputs, labels): 18 | """ 19 | :param outputs: [b, c] 20 | :param labels: [b,] 21 | :return: a loss (Variable) 22 | """ 23 | outputs = self.softmax(outputs) # probabilities 24 | ce = self.nll(outputs.log(), labels) 25 | reg = outputs * outputs.log() 26 | reg = reg.sum(1).mean() 27 | loss_total = ce + reg * self.eps 28 | return loss_total #, ce, reg -------------------------------------------------------------------------------- /libs/losses/create_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from libs.losses.df_loss import EuclideanLossWithOHEM 5 | from libs.losses.mag_angle_loss import EuclideanAngleLossWithOHEM 6 | from libs.losses.surface_loss import SurfaceLoss 7 | 8 | class Total_loss(): 9 | def __init__(self, boundary=False): 10 | self.df_loss = EuclideanAngleLossWithOHEM() 11 | self.boundary = boundary 12 | if boundary: 13 | self.boundary_loss = SurfaceLoss(idc=[1,2,3]) 14 | 15 | def __call__(self, net_logit, dist_maps, df_out, gts_df, gts): 16 | df_loss = self.df_loss(df_out, gts_df, gts[:, None, ...]) 17 | 18 | if self.boundary: 19 | net_prob = nn.functional.softmax(net_logit, dim=1) 20 | b_loss = self.boundary_loss(net_prob, dist_maps, gts) 21 | else: 22 | b_loss = torch.tensor([0.], device=net_logit.device) 23 | 24 | return df_loss, b_loss 25 | 26 | -------------------------------------------------------------------------------- /libs/losses/df_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | def LossSegDF(net_ret, data, device="cuda"): 8 | net_out, df_out = net_ret 9 | 10 | _, gts, gts_df = data 11 | gts = torch.squeeze(gts, 1).to(device).long() 12 | gts_df = gts_df.to(device).long() 13 | 14 | # segmentation Loss 15 | seg_loss = F.cross_entropy(net_out, gts) 16 | 17 | # direction field Loss 18 | df_loss = F.mse_loss(df_out, gts_df) 19 | 20 
| total_loss = seg_loss + df_loss 21 | return total_loss 22 | 23 | class EuclideanLossWithOHEM(nn.Module): 24 | def __init__(self, npRatio=3): 25 | super(EuclideanLossWithOHEM, self).__init__() 26 | self.npRatio = npRatio 27 | 28 | def __cal_weight(self, gt): 29 | _, H, W = gt.shape # N=1 30 | labels = torch.unique(gt, sorted=True)[1:] 31 | weight = torch.zeros((H, W), dtype=torch.float, device=gt.device) 32 | posRegion = gt[0, ...] > 0 33 | posCount = torch.sum(posRegion) 34 | if posCount != 0: 35 | segRemain = 0 36 | for segi in labels: 37 | overlap_segi = gt[0, ...] == segi 38 | overlapCount_segi = torch.sum(overlap_segi) 39 | if overlapCount_segi == 0: continue 40 | segRemain = segRemain + 1 41 | segAve = float(posCount) / segRemain 42 | for segi in labels: 43 | overlap_segi = gt[0, ...] == segi 44 | overlapCount_segi = torch.sum(overlap_segi, dtype=torch.float) 45 | if overlapCount_segi == 0: continue 46 | pixAve = segAve / overlapCount_segi 47 | weight = weight * (~overlap_segi).to(torch.float) + pixAve * overlap_segi.to(torch.float) 48 | # weight = weight[None] 49 | return weight 50 | 51 | def forward(self, pred, gt_df, gt, weight=None): 52 | """ pred: (N, C, H, W) 53 | gt_df: (N, C, H, W) 54 | gt: (N, 1, H, W) 55 | """ 56 | # L1 and L2 distance 57 | N, _, H, W = pred.shape 58 | distL1 = pred - gt_df 59 | distL2 = distL1 ** 2 60 | 61 | if weight is None: 62 | weight = torch.zeros((N, H, W), device=pred.device) 63 | for i in range(N): 64 | weight[i] = self.__cal_weight(gt[i]) 65 | 66 | # the amount of positive and negtive pixels 67 | regionPos = (weight > 0).to(torch.float) 68 | regionNeg = (weight == 0).to(torch.float) 69 | sumPos = torch.sum(regionPos, dim=(1,2)) # (N,) 70 | sumNeg = torch.sum(regionNeg, dim=(1,2)) 71 | 72 | # the amount of hard negative pixels 73 | sumhardNeg = torch.min(self.npRatio * sumPos, sumNeg).to(torch.int) # (N,) 74 | 75 | # set loss on ~(top - sumhardNeg) negative pixels to 0 76 | lossNeg = (distL2[:,0,...] + distL2[:, 1, ...]) * regionNeg 77 | lossFlat = torch.flatten(lossNeg, start_dim=1) # (N, ...) 78 | arg = torch.argsort(lossFlat, dim=1) 79 | for i in range(N): 80 | lossFlat[i, arg[i, :-sumhardNeg[i]]] = 0 81 | lossHard = lossFlat.view(lossNeg.shape) 82 | 83 | # weight for positive and negative pixels 84 | weightPos = torch.zeros_like(pred) 85 | weightNeg = torch.zeros_like(pred) 86 | 87 | weightPos = torch.stack([weight, weight], dim=1) 88 | 89 | weightNeg[:,0,...] = (lossHard != 0).to(torch.float32) 90 | weightNeg[:,1,...] = (lossHard != 0).to(torch.float32) 91 | 92 | # total loss 93 | total_loss = torch.sum((distL1 ** 2) * (weightPos + weightNeg)) / N / 2. / torch.sum(weightPos + weightNeg) 94 | 95 | return total_loss 96 | 97 | 98 | 99 | if __name__ == "__main__": 100 | import os 101 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 102 | criterion = EuclideanLossWithOHEM() 103 | for i in range(100): 104 | pred = torch.randn((32, 2, 224, 224)).cuda() 105 | gt_df = torch.randn((32, 2, 224, 224)).cuda() 106 | gt = torch.randint(0, 4, (32, 1, 224, 224)).cuda() 107 | 108 | loss = criterion(100*gt_df, gt_df, gt) 109 | print("{:6} loss:{}".format(i, loss)) -------------------------------------------------------------------------------- /libs/losses/mag_angle_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | def cart2polar(coord): 7 | """ coord: (N, 2, ...) 8 | """ 9 | x = coord[:, 0, ...] 10 | y = coord[:, 1, ...] 
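    # (added comment) torch.atan only returns angles in (-pi/2, pi/2); the two
    # corrections below add pi when x < 0 and 2*pi in the fourth quadrant
    # (x > 0, y < 0), so theta ends up in [0, 2*pi) and the returned
    # theta / (2*pi) is a normalized angle in [0, 1).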
11 | 12 | theta = torch.atan(y / (x + 1e-12)) 13 | 14 | theta = theta + (x < 0).to(coord.dtype) * math.pi 15 | theta = theta + ((x > 0).to(coord.dtype) * (y < 0).to(coord.dtype)) * 2 * math.pi 16 | return theta / (2 * math.pi) 17 | 18 | class EuclideanAngleLossWithOHEM(nn.Module): 19 | def __init__(self, npRatio=3): 20 | super(EuclideanAngleLossWithOHEM, self).__init__() 21 | self.npRatio = npRatio 22 | 23 | def __cal_weight(self, gt): 24 | _, H, W = gt.shape # N=1 25 | labels = torch.unique(gt, sorted=True)[1:] 26 | weight = torch.zeros((H, W), dtype=torch.float, device=gt.device) 27 | posRegion = gt[0, ...] > 0 28 | posCount = torch.sum(posRegion) 29 | if posCount != 0: 30 | segRemain = 0 31 | for segi in labels: 32 | overlap_segi = gt[0, ...] == segi 33 | overlapCount_segi = torch.sum(overlap_segi) 34 | if overlapCount_segi == 0: continue 35 | segRemain = segRemain + 1 36 | segAve = float(posCount) / segRemain 37 | for segi in labels: 38 | overlap_segi = gt[0, ...] == segi 39 | overlapCount_segi = torch.sum(overlap_segi, dtype=torch.float) 40 | if overlapCount_segi == 0: continue 41 | pixAve = segAve / overlapCount_segi 42 | weight = weight * (~overlap_segi).to(torch.float) + pixAve * overlap_segi.to(torch.float) 43 | # weight = weight[None] 44 | return weight 45 | 46 | def forward(self, pred, gt_df, gt, weight=None): 47 | """ pred: (N, C, H, W) 48 | gt_df: (N, C, H, W) 49 | gt: (N, 1, H, W) 50 | """ 51 | # L1 and L2 distance 52 | N, _, H, W = pred.shape 53 | distL1 = pred - gt_df 54 | distL2 = distL1 ** 2 55 | 56 | theta_p = cart2polar(pred) 57 | theta_g = cart2polar(gt_df) 58 | angleDistL1 = theta_g - theta_p 59 | 60 | 61 | if weight is None: 62 | weight = torch.zeros((N, H, W), device=pred.device) 63 | for i in range(N): 64 | weight[i] = self.__cal_weight(gt[i]) 65 | 66 | # the amount of positive and negtive pixels 67 | regionPos = (weight > 0).to(torch.float) 68 | regionNeg = (weight == 0).to(torch.float) 69 | sumPos = torch.sum(regionPos, dim=(1,2)) # (N,) 70 | sumNeg = torch.sum(regionNeg, dim=(1,2)) 71 | 72 | # the amount of hard negative pixels 73 | sumhardNeg = torch.min(self.npRatio * sumPos, sumNeg).to(torch.int) # (N,) 74 | 75 | # angle loss on ~(top - sumhardNeg) negative pixels to 0 76 | angleLossNeg = (angleDistL1 ** 2) * regionNeg 77 | angleLossNegFlat = torch.flatten(angleLossNeg, start_dim=1) # (N, ...) 78 | 79 | 80 | # set loss on ~(top - sumhardNeg) negative pixels to 0 81 | lossNeg = (distL2[:,0,...] + distL2[:, 1, ...]) * regionNeg 82 | lossFlat = torch.flatten(lossNeg, start_dim=1) # (N, ...) 83 | 84 | # l2-norm distance and angle distance 85 | lossFlat = lossFlat + angleLossNegFlat 86 | arg = torch.argsort(lossFlat, dim=1) 87 | for i in range(N): 88 | lossFlat[i, arg[i, :-sumhardNeg[i]]] = 0 89 | lossHard = lossFlat.view(lossNeg.shape) 90 | 91 | # weight for positive and negative pixels 92 | weightPos = torch.zeros_like(gt, dtype=pred.dtype) 93 | weightNeg = torch.zeros_like(gt, dtype=pred.dtype) 94 | 95 | weightPos = weight.clone() 96 | 97 | weightNeg[:,0,...] = (lossHard != 0).to(torch.float32) 98 | 99 | # total loss 100 | total_loss = torch.sum(((distL2[:,0,...] + distL2[:, 1, ...]) + angleDistL1 ** 2) * 101 | (weightPos + weightNeg)) / N / 2. 
/ torch.sum(weightPos + weightNeg) 102 | 103 | return total_loss 104 | 105 | 106 | 107 | if __name__ == "__main__": 108 | import os 109 | import torch.nn as nn 110 | import torch.optim as optim 111 | 112 | os.environ["CUDA_VISIBLE_DEVICES"] = "4" 113 | criterion = EuclideanAngleLossWithOHEM() 114 | 115 | # models = nn.Sequential(nn.Conv2d(2, 2, 1), 116 | # nn.ReLU()) 117 | # models.to(device="cuda") 118 | 119 | # epoch_n = 200 120 | # learning_rate = 1e-4 121 | 122 | # optimizer = optim.Adam(params=models.parameters(), lr=learning_rate) 123 | 124 | # for i in range(100): 125 | # pred = torch.randn((32, 2, 224, 224)).cuda() 126 | # gt_df = torch.randn((32, 2, 224, 224)).cuda() 127 | # gt = torch.randint(0, 4, (32, 1, 224, 224)).cuda() 128 | 129 | # pred = models(gt_df) 130 | # loss = criterion(pred, gt_df, gt) 131 | # optimizer.zero_grad() 132 | # loss.backward() 133 | # optimizer.step() 134 | 135 | # print("{:6} loss:{}".format(i, loss)) 136 | 137 | for i in range(100): 138 | pred = torch.randn((32, 2, 224, 224)).cuda() 139 | gt_df = torch.randn((32, 2, 224, 224)).cuda() 140 | gt = torch.randint(0, 4, (32, 1, 224, 224)).cuda() 141 | 142 | loss = criterion(-gt_df, gt_df, gt) 143 | print("{:6} loss:{}".format(i, loss)) 144 | -------------------------------------------------------------------------------- /libs/losses/surface_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import List, Set, Iterable 4 | from utils.utils_loss import class2one_hot 5 | 6 | def uniq(a: Tensor) -> Set: 7 | return set(torch.unique(a.cpu()).numpy()) 8 | 9 | def sset(a: Tensor, sub: Iterable) -> bool: 10 | return uniq(a).issubset(sub) 11 | 12 | def simplex(t: Tensor, axis=1) -> bool: 13 | _sum = t.sum(axis).type(torch.float32) 14 | _ones = torch.ones_like(_sum, dtype=torch.float32) 15 | return torch.allclose(_sum, _ones) 16 | 17 | def one_hot(t: Tensor, axis=1) -> bool: 18 | return simplex(t, axis) and sset(t, [0, 1]) 19 | 20 | 21 | class SurfaceLoss(): 22 | def __init__(self, **kwargs): 23 | # Self.idc is used to filter out some classes of the target mask. Use fancy indexing 24 | self.idc: List[int] = kwargs["idc"] 25 | print(f"Initialized {self.__class__.__name__} with {kwargs}") 26 | 27 | def __call__(self, probs: Tensor, dist_maps: Tensor, gts: Tensor) -> Tensor: 28 | assert simplex(probs) 29 | assert not one_hot(dist_maps) 30 | 31 | dist_maps = dist_maps.to(probs.device) 32 | 33 | pc = probs[:, self.idc, ...] 34 | dc = dist_maps[:, self.idc, ...] 35 | 36 | multipled = torch.einsum("bcwh,bcwh->bcwh", pc, dc) 37 | 38 | # gc = class2one_hot(gts)[:, self.idc, ...] 
39 | # multipled = torch.einsum("bcwh,bcwh->bcwh", pc - gc, dc) 40 | 41 | loss = multipled.mean() 42 | 43 | return loss -------------------------------------------------------------------------------- /libs/network/__init__.py: -------------------------------------------------------------------------------- 1 | # from libs.network.unet_df import U_NetDF 2 | # from libs.network.unet import U_Net 3 | from .unet import U_Net 4 | from .unet_df import U_NetDF 5 | from .unet_dual import Unet_Dual 6 | from .resdf_unet import Resnet18_DfUnet 7 | from .FCMultiScaleResidualDenseNet import FCMultiScaleResidualDenseNet as DenseNet 8 | from .densenet_df import DenseNet_DF -------------------------------------------------------------------------------- /libs/network/train_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import math 7 | import cv2 8 | from collections import namedtuple 9 | from utils.metrics import dice, cal_hausdorff_distance 10 | from utils.vis_utils import batchToColorImg, masks_to_contours 11 | 12 | def model_fn_decorator(): 13 | ModelReturn = namedtuple("ModelReturn", ["loss", "tb_dict", "disp_dict"]) 14 | 15 | def model_fn(model, data, criterion, perfermance=False, vis=False, device="cuda", epoch=0, num_class=4): 16 | # imgs, gts, _ = data 17 | imgs, gts = data[:2] 18 | 19 | imgs = imgs.to(device) 20 | gts = torch.squeeze(gts, 1).to(device) 21 | 22 | net_out = model(imgs) 23 | 24 | loss = criterion(net_out[0], gts.long()) 25 | 26 | tb_dict = {} 27 | disp_dict = {} 28 | tb_dict.update({"loss": loss.item()}) 29 | disp_dict.update({"loss": loss.item()}) 30 | 31 | if perfermance: 32 | gts_ = gts.unsqueeze(1) 33 | 34 | net_out = F.softmax(net_out[0], dim=1) 35 | _, preds = torch.max(net_out, 1) 36 | preds = preds.unsqueeze(1) 37 | cal_perfer(make_one_hot(preds, num_class), make_one_hot(gts_, num_class), tb_dict) 38 | 39 | 40 | return ModelReturn(loss, tb_dict, disp_dict) 41 | 42 | return model_fn 43 | 44 | def model_DF_decorator(): 45 | ModelReturn = namedtuple("ModelReturn", ["loss", "tb_dict", "disp_dict"]) 46 | 47 | def model_fn(model, data, criterion=None, perfermance=False, vis=False, device="cuda", epoch=0, num_class=4): 48 | imgs, gts = data[:2] 49 | gts_df, dist_maps = data[2:] 50 | 51 | imgs = imgs.to(device) 52 | gts = torch.squeeze(gts, 1).to(device).long() 53 | gts_df = gts_df.to(device) 54 | 55 | net_out = model(imgs) 56 | seg_out, df_out = net_out[:2] 57 | 58 | # add Auxiliary Segmentation 59 | if len(net_out) >= 3 and net_out[2] is not None: 60 | auxseg_out = net_out[2] 61 | auxseg_loss = F.cross_entropy(auxseg_out, gts) 62 | else: 63 | auxseg_loss = torch.tensor([0.], dtype=torch.float32, device=device) 64 | 65 | 66 | # loss = criterion(net_out, gts.long()) 67 | # segmentation Loss 68 | seg_loss = F.cross_entropy(seg_out, gts) 69 | 70 | # direction field Loss 71 | df_loss, boundary_loss = criterion(seg_out, dist_maps, df_out, gts_df, gts) 72 | 73 | alpha = 1.0 74 | loss = alpha*(seg_loss + 1. 
* df_loss + 0.1*auxseg_loss) + (1.-alpha)*boundary_loss
75 | 
76 |         tb_dict = {}
77 |         disp_dict = {}
78 |         tb_dict.update({"loss": loss.item(), "seg_loss": alpha*seg_loss.item(), "df_loss": alpha*1.*df_loss.item(),
79 |                         "boundary_loss": (1.-alpha)*boundary_loss.item(), "auxseg_loss": alpha*0.1*auxseg_loss.item()})
80 |         disp_dict.update({"loss": loss.item()})
81 | 
82 |         if perfermance:
83 |             gts_ = gts.unsqueeze(1)
84 | 
85 |             seg_out = F.softmax(seg_out, dim=1)
86 |             _, preds = torch.max(seg_out, 1)
87 |             preds = preds.unsqueeze(1)
88 |             cal_perfer(make_one_hot(preds, num_class), make_one_hot(gts_, num_class), tb_dict)
89 | 
90 |         if vis:
91 |             # visualize the direction field
92 |             # vis_dict = {}
93 |             gt_df = gts_df.cpu().numpy()
94 |             _, angle_gt = cv2.cartToPolar(gt_df[:, 0,...], gt_df[:, 1,...])
95 |             angle_gt = batchToColorImg(angle_gt, minv=0, maxv=2*math.pi).transpose(0, 3, 1, 2)
96 | 
97 |             df_map = df_out.cpu().numpy()
98 |             mag, angle_df = cv2.cartToPolar(df_map[:, 0,...], df_map[:, 1,...])
99 |             angle_df = batchToColorImg(angle_df, minv=0, maxv=2*math.pi).transpose(0, 3, 1, 2)
100 |             mag = batchToColorImg(mag).transpose(0, 3, 1, 2)
101 | 
102 |             tb_dict.update({"vis": [angle_gt, mag, angle_df]})
103 | 
104 | 
105 |         return ModelReturn(loss, tb_dict, disp_dict)
106 | 
107 |     return model_fn
108 | 
109 | 
110 | 
111 | def cal_perfer(preds, masks, tb_dict):
112 |     LV_dice = [] # 1
113 |     MYO_dice = [] # 2
114 |     RV_dice = [] # 3
115 |     LV_hausdorff = []
116 |     MYO_hausdorff = []
117 |     RV_hausdorff = []
118 | 
119 |     for i in range(preds.shape[0]):
120 |         LV_dice.append(dice(preds[i,1,:,:],masks[i,1,:,:]))
121 |         RV_dice.append(dice(preds[i, 3, :, :], masks[i, 3, :, :]))
122 |         MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
123 | 
124 |         LV_hausdorff.append(cal_hausdorff_distance(preds[i,1,:,:],masks[i,1,:,:]))
125 |         RV_hausdorff.append(cal_hausdorff_distance(preds[i,3,:,:],masks[i,3,:,:]))
126 |         MYO_hausdorff.append(cal_hausdorff_distance(preds[i,2,:,:],masks[i,2,:,:]))
127 | 
128 |     tb_dict.update({"LV_dice": np.mean(LV_dice)})
129 |     tb_dict.update({"RV_dice": np.mean(RV_dice)})
130 |     tb_dict.update({"MYO_dice": np.mean(MYO_dice)})
131 |     tb_dict.update({"LV_hausdorff": np.mean(LV_hausdorff)})
132 |     tb_dict.update({"RV_hausdorff": np.mean(RV_hausdorff)})
133 |     tb_dict.update({"MYO_hausdorff": np.mean(MYO_hausdorff)})
134 | 
135 | def make_one_hot(input, num_classes):
136 |     """Convert class index tensor to one hot encoding tensor.
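    (added example, not in the original docstring) for num_classes=4,
    a (2, 1, 4, 4) index map becomes a (2, 4, 4, 4) one-hot map:
    >>> make_one_hot(torch.zeros(2, 1, 4, 4).long(), 4).shape
    torch.Size([2, 4, 4, 4])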
137 | Args: 138 | input: A tensor of shape [N, 1, *] 139 | num_classes: An int of number of class 140 | Returns: 141 | A tensor of shape [N, num_classes, *] 142 | """ 143 | shape = np.array(input.shape) 144 | shape[1] = num_classes 145 | shape = tuple(shape) 146 | result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1) 147 | # result = result.scatter_(1, input.cpu(), 1) 148 | 149 | return result 150 | -------------------------------------------------------------------------------- /libs/network/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | def init_weights(net, init_type='normal', gain=0.02): 8 | def init_func(m): 9 | classname = m.__class__.__name__ 10 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 11 | if init_type == 'normal': 12 | init.normal_(m.weight.data, 0.0, gain) 13 | elif init_type == 'xavier': 14 | init.xavier_normal_(m.weight.data, gain=gain) 15 | elif init_type == 'kaiming': 16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 17 | elif init_type == 'orthogonal': 18 | init.orthogonal_(m.weight.data, gain=gain) 19 | else: 20 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 21 | if hasattr(m, 'bias') and m.bias is not None: 22 | init.constant_(m.bias.data, 0.0) 23 | elif classname.find('BatchNorm2d') != -1: 24 | init.normal_(m.weight.data, 1.0, gain) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | print('initialize network with %s' % init_type) 28 | net.apply(init_func) 29 | 30 | class conv_block(nn.Module): 31 | def __init__(self,ch_in,ch_out): 32 | super(conv_block,self).__init__() 33 | self.conv = nn.Sequential( 34 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 35 | nn.BatchNorm2d(ch_out), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 38 | nn.BatchNorm2d(ch_out), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | 43 | def forward(self,x): 44 | x = self.conv(x) 45 | return x 46 | 47 | class up_conv(nn.Module): 48 | def __init__(self,ch_in,ch_out): 49 | super(up_conv,self).__init__() 50 | self.up = nn.Sequential( 51 | # nn.Upsample(scale_factor=2), 52 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 53 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 54 | nn.BatchNorm2d(ch_out), 55 | nn.ReLU(inplace=True) 56 | ) 57 | 58 | def forward(self,x): 59 | x = self.up(x) 60 | return x 61 | 62 | class U_Net(nn.Module): 63 | def __init__(self,img_ch=1,num_class=4, selfeat=False): 64 | super(U_Net,self).__init__() 65 | 66 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 67 | 68 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 69 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 70 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 71 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 72 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 73 | 74 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 75 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 76 | 77 | self.Up4 = up_conv(ch_in=512,ch_out=256) 78 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 79 | 80 | self.Up3 = up_conv(ch_in=256,ch_out=128) 81 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 82 | 83 | self.Up2 = up_conv(ch_in=128,ch_out=64) 84 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 85 | 86 | self.Conv_1x1 = 
nn.Conv2d(64,num_class,kernel_size=1,stride=1,padding=0)
87 | 
88 | 
89 |     def forward(self,x):
90 |         # encoding path
91 |         x1 = self.Conv1(x)
92 | 
93 |         x2 = self.Maxpool(x1)
94 |         x2 = self.Conv2(x2)
95 | 
96 |         x3 = self.Maxpool(x2)
97 |         x3 = self.Conv3(x3)
98 | 
99 |         x4 = self.Maxpool(x3)
100 |         x4 = self.Conv4(x4)
101 | 
102 |         x5 = self.Maxpool(x4)
103 |         x5 = self.Conv5(x5)
104 | 
105 |         # decoding + concat path
106 |         d5 = self.Up5(x5)
107 |         d5 = torch.cat((x4,d5),dim=1)
108 | 
109 |         d5 = self.Up_conv5(d5)
110 | 
111 |         d4 = self.Up4(d5)
112 |         d4 = torch.cat((x3,d4),dim=1)
113 |         d4 = self.Up_conv4(d4)
114 | 
115 |         d3 = self.Up3(d4)
116 |         d3 = torch.cat((x2,d3),dim=1)
117 |         d3 = self.Up_conv3(d3)
118 | 
119 |         d2 = self.Up2(d3)
120 |         d2 = torch.cat((x1,d2),dim=1)
121 |         d2 = self.Up_conv2(d2)
122 | 
123 |         d1 = self.Conv_1x1(d2)
124 | 
125 |         return [d1]
126 | 
127 | 
128 | 
129 | if __name__ == "__main__":
130 |     def opCounter(model):
131 |         type_size = 4 # float
132 |         params = list(model.parameters())
133 |         k = 0
134 |         for i in params:
135 |             l = 1
136 |             print("Shape of this layer: " + str(list(i.size())))
137 |             for j in i.size():
138 |                 l *= j
139 |             print("Parameter count of this layer: " + str(l))
140 |             k = k + l
141 |         print("Total parameter count: " + str(k))
142 |         print('Model {} : params: {:4f}M'.format(model._get_name(), k * type_size / 1000 / 1000))
143 | 
144 |     model = U_Net()
145 |     opCounter(model)
146 | 
--------------------------------------------------------------------------------
/libs/network/unet_df.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | from libs.network.unet import conv_block, up_conv
7 | 
8 | class SelFuseFeature(nn.Module):
9 |     def __init__(self, in_channels, shift_n=5, n_class=4, auxseg=False):
10 |         super(SelFuseFeature, self).__init__()
11 | 
12 |         self.shift_n = shift_n
13 |         self.n_class = n_class
14 |         self.auxseg = auxseg
15 |         self.fuse_conv = nn.Sequential(nn.Conv2d(in_channels*2, in_channels, kernel_size=1, padding=0),
16 |                                        nn.BatchNorm2d(in_channels),
17 |                                        nn.ReLU(inplace=True),
18 |                                        )
19 |         if auxseg:
20 |             self.auxseg_conv = nn.Conv2d(in_channels, self.n_class, 1)
21 | 
22 | 
23 |     def forward(self, x, df):
24 |         N, _, H, W = df.shape
25 |         mag = torch.sqrt(torch.sum(df ** 2, dim=1))
26 |         greater_mask = mag > 0.5
27 |         greater_mask = torch.stack([greater_mask, greater_mask], dim=1)
28 |         df[~greater_mask] = 0
29 | 
30 |         scale = 1.
31 | 
32 |         grid = torch.stack(torch.meshgrid(torch.arange(H), torch.arange(W)), dim=0)
33 |         grid = grid.expand(N, -1, -1, -1).to(x.device, dtype=torch.float).requires_grad_()
34 |         grid = grid + scale * df
35 | 
36 |         grid = grid.permute(0, 2, 3, 1).transpose(1, 2)
37 |         grid_ = grid + 0.
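        # (added comment) F.grid_sample below expects sampling locations
        # normalized to [-1, 1]; grid_ keeps a copy of the pixel-space
        # coordinates while the next two lines rescale grid in place.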
38 | grid[...,0] = 2*grid_[..., 0] / (H-1) - 1 39 | grid[...,1] = 2*grid_[..., 1] / (W-1) - 1 40 | 41 | # features = [] 42 | select_x = x.clone() 43 | for _ in range(self.shift_n): 44 | select_x = F.grid_sample(select_x, grid, mode='bilinear', padding_mode='border') 45 | # features.append(select_x) 46 | # select_x = torch.mean(torch.stack(features, dim=0), dim=0) 47 | # features.append(select_x.detach().cpu().numpy()) 48 | # np.save("/root/chengfeng/Cardiac/source_code/logs/acdc_logs/logs_temp/feature.npy", np.array(features)) 49 | if self.auxseg: 50 | auxseg = self.auxseg_conv(x) 51 | else: 52 | auxseg = None 53 | 54 | select_x = self.fuse_conv(torch.cat([x, select_x], dim=1)) 55 | return [select_x, auxseg] 56 | 57 | class U_NetDF(nn.Module): 58 | def __init__(self,img_ch=1,num_class=4, selfeat=False, shift_n=5, auxseg=False): 59 | super(U_NetDF,self).__init__() 60 | 61 | self.selfeat = selfeat 62 | self.shift_n = shift_n 63 | 64 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 65 | 66 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 67 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 68 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 69 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 70 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 71 | 72 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 73 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 74 | 75 | self.Up4 = up_conv(ch_in=512,ch_out=256) 76 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 77 | 78 | self.Up3 = up_conv(ch_in=256,ch_out=128) 79 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 80 | 81 | self.Up2 = up_conv(ch_in=128,ch_out=64) 82 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 83 | 84 | # Direct Field 85 | self.ConvDf_1x1 = nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0) 86 | 87 | if selfeat: 88 | self.SelDF = SelFuseFeature(64, auxseg=auxseg, shift_n=shift_n) 89 | 90 | self.Conv_1x1 = nn.Conv2d(64,num_class,kernel_size=1,stride=1,padding=0) 91 | 92 | def forward(self, inputs): 93 | x = inputs 94 | 95 | # encoding path 96 | x1 = self.Conv1(x) 97 | 98 | x2 = self.Maxpool(x1) 99 | x2 = self.Conv2(x2) 100 | 101 | x3 = self.Maxpool(x2) 102 | x3 = self.Conv3(x3) 103 | 104 | x4 = self.Maxpool(x3) 105 | x4 = self.Conv4(x4) 106 | 107 | x5 = self.Maxpool(x4) 108 | x5 = self.Conv5(x5) 109 | 110 | # decoding + concat path 111 | d5 = self.Up5(x5) 112 | d5 = torch.cat((x4,d5),dim=1) 113 | 114 | d5 = self.Up_conv5(d5) 115 | 116 | d4 = self.Up4(d5) 117 | d4 = torch.cat((x3,d4),dim=1) 118 | d4 = self.Up_conv4(d4) 119 | 120 | d3 = self.Up3(d4) 121 | d3 = torch.cat((x2,d3),dim=1) 122 | d3 = self.Up_conv3(d3) 123 | 124 | # df = self.ConvDf_1x1(d3) 125 | # # df = F.interpolate(inputs[1], size=d3.shape[-2:], mode='bilinear', align_corners=True) 126 | # if self.selfeat: 127 | # d3 = self.SelDF(d3, df) 128 | 129 | 130 | d2 = self.Up2(d3) 131 | d2 = torch.cat((x1,d2),dim=1) 132 | d2 = self.Up_conv2(d2) 133 | 134 | # Direct Field 135 | df = self.ConvDf_1x1(d2) 136 | # df = None 137 | if self.selfeat: 138 | d2_auxseg = self.SelDF(d2, df) 139 | d2, auxseg = d2_auxseg[:2] 140 | else: 141 | auxseg = None 142 | 143 | # df = F.interpolate(df, size=x.shape[-2:], mode='bilinear', align_corners=True) 144 | d1 = self.Conv_1x1(d2) 145 | 146 | return [d1, df, auxseg] 147 | 148 | 149 | 150 | if __name__ == "__main__": 151 | 152 | a = torch.randn(1, 1, 224, 224) 153 | 154 | model = U_NetDF(selfeat=True) 155 | 156 | out = model(a) 157 | print(out[0].shape, out[1].shape, out[2].shape) 158 | 159 | 
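    # (added sketch, not part of the original demo) downstream code would
    # typically turn the raw logits in out[0] into per-pixel class labels:
    probs = F.softmax(out[0], dim=1)   # (1, 4, 224, 224) class probabilities
    labels = probs.argmax(dim=1)       # (1, 224, 224) hard labels in {0,...,3}
    print(labels.shape)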
--------------------------------------------------------------------------------
/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/pipeline.png
--------------------------------------------------------------------------------
/tools/MultiShellScripts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | 
4 | def test_df(log_dir, epoch_i=0, best_model=False):
5 |     OUTPUT_DIR = os.path.join(log_dir, "eval")
6 |     if best_model:
7 |         MODEL_PATH = os.path.join(log_dir, "ckpt", "model_best.pth")
8 |     else:
9 |         MODEL_PATH = os.path.join(log_dir, "ckpt", "checkpoint_epoch_{}.pth".format(epoch_i))
10 | 
11 |     commands = "python tools/test_df.py --used_df U_NetDF --selfeat --mgpus 6 --model_path1 {} \
12 |                --output_dir {} --log_file ../log_evaluation_vis.txt --vis".format(MODEL_PATH, OUTPUT_DIR)
13 |     os.system(commands)
14 | 
15 | def train():
16 |     os.system("python -m torch.distributed.launch --nproc_per_node 2 --master_port $RANDOM tools/train.py --batch_size 24 --mgpus 2,3 --output_dir logs/acdc_logs/log_temp --train_with_eval")
17 | 
18 | if __name__ == "__main__":
19 |     parser = argparse.ArgumentParser(description="arg parser")
20 |     parser.add_argument("--scrip", type=str, default=None, help="which script to run")
21 |     args = parser.parse_args()
22 | 
23 |     if args.scrip == "train":
24 |         train()
25 |     elif args.scrip == "test_df":
26 |         test_df("logs/acdc_logs/logs_256_supcat_auxseg_thresh0.1/", best_model=False, epoch_i=118)
--------------------------------------------------------------------------------
/tools/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.append(os.path.join(BASE_DIR, '../'))
5 | sys.path.append(os.path.join(BASE_DIR, '../utils'))
6 | sys.path.append(os.path.join(BASE_DIR, '../libs'))
--------------------------------------------------------------------------------
/tools/test_acdc_leadboard.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import nibabel as nib
5 | import time
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | import _init_paths
12 | from libs.network import U_Net, U_NetDF
13 | from utils.image_list import to_image_list
14 | import libs.datasets.augment as standard_augment
15 | import libs.datasets.joint_augment as joint_augment
16 | 
17 | def progress_bar(curr_idx, max_idx, time_step, repeat_elem = "_"):
18 |     max_equals = 55
19 |     step_ms = int(time_step*1000)
20 |     num_equals = int(curr_idx*max_equals/float(max_idx))
21 |     len_reverse = len('Step:%d ms| %d/%d ['%(step_ms, curr_idx, max_idx)) + num_equals
22 |     sys.stdout.write("Step:%d ms|%d/%d [%s]" %(step_ms, curr_idx, max_idx, " " * max_equals,))
23 |     sys.stdout.flush()
24 |     sys.stdout.write("\b" * (max_equals+1))  # backspaces: redraw the bar in place
25 |     sys.stdout.write(repeat_elem * num_equals)
26 |     sys.stdout.write("\b"*len_reverse)
27 |     sys.stdout.flush()
28 |     if curr_idx == max_idx:
29 |         print('\n')
30 | 
31 | def load_nii(img_path):
32 |     """
33 |     Function to load a 'nii' or 'nii.gz' file. The function returns
34 |     everything needed to save another 'nii' or 'nii.gz'
35 |     in the same dimensional space, i.e.
the affine matrix and the header 36 | 37 | Parameters 38 | ---------- 39 | 40 | img_path: string 41 | String with the path of the 'nii' or 'nii.gz' image file name. 42 | 43 | Returns 44 | ------- 45 | Three element, the first is a numpy array of the image values, 46 | the second is the affine transformation of the image, and the 47 | last one is the header of the image. 48 | """ 49 | nimg = nib.load(img_path) 50 | return nimg.get_fdata(), nimg.affine, nimg.header 51 | 52 | def save_nii(vol, affine, hdr, path, prefix, suffix): 53 | vol = nib.Nifti1Image(vol, affine, hdr) 54 | vol.set_data_dtype(np.uint8) 55 | nib.save(vol, os.path.join(path, prefix+'_'+suffix + ".nii.gz")) 56 | 57 | 58 | def get_person_names(root_path=None): 59 | persons_name = os.listdir(root_path) 60 | persons_name = [pn for pn in persons_name if "patient" in pn] 61 | persons_name.sort() 62 | return persons_name 63 | 64 | def get_patient_data(patient, root_path): 65 | patient_data = {} 66 | infocfg_p = os.path.join(root_path, patient, "Info.cfg") 67 | 68 | with open(infocfg_p) as f_in: 69 | for line in f_in: 70 | l = line.rstrip().split(": ") 71 | patient_data[l[0]] = l[1] 72 | 73 | ed_path = os.path.join(root_path, patient, "%s_frame%02d.nii.gz" % (patient, int(patient_data['ED']))) 74 | es_path = os.path.join(root_path, patient, "%s_frame%02d.nii.gz" % (patient, int(patient_data['ES']))) 75 | img_4d_path = os.path.join(root_path, patient, "{}_4d.nii.gz".format(patient)) 76 | # ed_gt_path = os.path.join(root_path, patient, "%s_frame%02d_gt.nii.gz" % (patient, int(patient_data['ED']))) 77 | # es_gt_path = os.path.join(root_path, patient, "%s_frame%02d_gt.nii.gz" % (patient, int(patient_data['ES']))) 78 | 79 | ed, affine, hdr = load_nii(ed_path) 80 | patient_data['ED_VOL'] = np.swapaxes(ed, 0, -1) 81 | patient_data['3D_affine'] = affine 82 | patient_data['3D_hdr'] = hdr 83 | 84 | es, _, _ = load_nii(es_path) # (w, h, slices) 85 | patient_data['ES_VOL'] = np.swapaxes(es, 0, -1) 86 | 87 | img_4d, affine_4d, hdr_4d = load_nii(img_4d_path) # (w, h, slices, times) 88 | patient_data['4D'] = np.swapaxes(img_4d, 0, 1) 89 | patient_data['4D_affine'] = affine_4d 90 | patient_data['4D_hdr'] = hdr_4d 91 | 92 | patient_data['size'] = img_4d.shape[:2][::-1] 93 | patient_data['pid'] = patient 94 | 95 | # ed_gt = load_nii(ed_gt_path) 96 | # patient_data['ED_GT'] = np.swapaxes(ed_gt, 0, 1) 97 | 98 | # es_gt = load_nii(es_gt_path) 99 | # patient_data['ES_GT'] = np.swapaxes(es_gt, 0, 1) 100 | return patient_data 101 | 102 | def test_it(model, data, device="cuda", used_df=True): 103 | model.eval() 104 | imgs = data 105 | 106 | imgs = imgs.to(device) 107 | # gts = gts.to(device) 108 | 109 | net_out = model(imgs) 110 | if used_df: 111 | preds_out = net_out[0] 112 | preds_df = net_out[1] 113 | else: 114 | preds_out = net_out[0] 115 | preds_df = None 116 | preds_out = nn.functional.softmax(preds_out, dim=1) 117 | _, preds = torch.max(preds_out, 1) 118 | preds = preds.unsqueeze(1) # (N, 1, *) 119 | 120 | return preds, preds_df 121 | 122 | def transform(imgs): 123 | mean = 63.19523533061758 124 | std = 70.74166957523165 125 | trans = standard_augment.Compose([standard_augment.To_PIL_Image(), 126 | # joint_augment.RandomAffine(0,translate=(0.125, 0.125)), 127 | # joint_augment.RandomRotate((-180,180)), 128 | # joint_augment.FixResize(224), 129 | standard_augment.to_Tensor(), 130 | standard_augment.normalize([mean], [std]), 131 | ]) 132 | return trans(imgs) 133 | 134 | def test_voxel(model, imgs, used_df, multi_batches=False, resize=None): 135 | """ 
imgs: (slices, H, W) 136 | preds: (slices, 1, H, W) 137 | """ 138 | imgs = imgs[..., None].astype(np.float32) 139 | B, _, _, C = imgs.shape 140 | 141 | if multi_batches: 142 | data, origin_shape = to_image_list(imgs, size_divisible=32, return_size=True) 143 | preds, _ = test_it(model, data) 144 | 145 | # for j in range(imgs.shape[0]): 146 | # preds[j, ...] = pred.cpu().numpy()[j, :, :origin_shape[j][0], :origin_shape[j][1]] 147 | else: 148 | preds = torch.zeros(B, C, resize[0], resize[1]) 149 | for j, pt in enumerate(imgs): 150 | data = [transform(pt)] 151 | data, origin_shape = to_image_list(data, size_divisible=32, return_size=True) 152 | pred, _ = test_it(model, data, used_df=used_df) 153 | preds[j, ...] = pred[0, 0, :origin_shape[0][0], :origin_shape[0][1]] 154 | 155 | if resize is not None: 156 | # preds = F.interpolate(preds, size=resize, mode='nearest') 157 | preds = preds[..., :resize[0], :resize[1]] 158 | 159 | return preds.cpu().numpy()[:, 0, ...] 160 | 161 | def create_model(model_name, selfeat): 162 | if model_name == 'U_NetDF': 163 | model = U_NetDF(selfeat=selfeat, auxseg=True) 164 | elif model_name == 'U_Net': 165 | model = U_Net() 166 | # elif model_name == 'Resnet18_DfUnet': 167 | # model = Resnet18_DfUnet() 168 | # elif model_name == 'DenseNet': 169 | # model = DenseNet() 170 | # elif model_name == 'DenseNet_DF': 171 | # model = DenseNet_DF(selfeat=selfeat) 172 | 173 | return model 174 | 175 | def test(mgpus, model_name, model_path, used_df, selfeat, log_path): 176 | 177 | model = create_model(model_name, selfeat=selfeat) 178 | if mgpus is not None and len(mgpus) > 2: 179 | model = nn.DataParallel(model) 180 | model.cuda() 181 | 182 | checkpoint = torch.load(model_path, map_location='cpu') 183 | model.load_state_dict(checkpoint['model_state']) 184 | 185 | root_path = "MICCAIACDC2017/ACDC_DataSet/testing/testing/" 186 | root_path = "/root/ACDC_DataSet/testing/testing/" 187 | persons_name = get_person_names(root_path) 188 | for j, pn in enumerate(persons_name): 189 | s_time = time.time() 190 | patient_data = get_patient_data(pn, root_path) 191 | 192 | # (slices, h, w) 193 | es_pred = test_voxel(model, patient_data['ES_VOL'], used_df=used_df, resize=patient_data['size']) 194 | ed_pred = test_voxel(model, patient_data['ED_VOL'], used_df=used_df, resize=patient_data['size']) 195 | es_pred = np.transpose(es_pred, (2, 1, 0)) 196 | ed_pred = np.transpose(ed_pred, (2, 1, 0)) 197 | 198 | img_4D = patient_data['4D'] 199 | h, w, s, t = img_4D.shape 200 | pred_4D = np.zeros((w, h, s, t)) 201 | for i in range(img_4D.shape[-1]): 202 | pred = test_voxel(model, np.transpose(img_4D[...,i], (2, 0, 1)), used_df=used_df, resize=patient_data['size']) 203 | pred_4D[..., i] = np.transpose(pred, (2, 1, 0)) 204 | 205 | save_path = os.path.join(log_path, "all_predictions") 206 | os.makedirs(save_path, exist_ok=True) 207 | CheckSizeAndSaveVolume(pred_4D, patient_data, save_path) 208 | progress_bar(j%(len(persons_name)+1), len(persons_name),time.time()-s_time) 209 | 210 | 211 | def CheckSizeAndSaveVolume(seg_4D, patient_data, save_path): 212 | """ 213 | TODO: 214 | """ 215 | prefix = patient_data['pid'] 216 | suffix = '4D' 217 | 218 | # save_nii(seg_4D, patient_data['4D_affine'], patient_data['4D_hdr'], save_path, prefix, suffix) 219 | suffix = 'ED' 220 | ED_phase_n = int(patient_data['ED']) 221 | ED_pred = seg_4D[:,:,:,ED_phase_n] 222 | save_nii(ED_pred.astype(np.uint8), patient_data['3D_affine'], patient_data['3D_hdr'], save_path, prefix, suffix) 223 | 224 | suffix = 'ES' 225 | ES_phase_n = 
int(patient_data['ES']) 226 | ES_pred = seg_4D[:,:,:,ES_phase_n] 227 | save_nii(ES_pred.astype(np.uint8), patient_data['3D_affine'], patient_data['3D_hdr'], save_path, prefix, suffix) 228 | 229 | # ED_GT = patient_data.get('ED_GT', None) 230 | results = [] 231 | return results 232 | 233 | 234 | 235 | if __name__ == "__main__": 236 | # get_person_names() 237 | import argparse 238 | parser = argparse.ArgumentParser(description="arg parser") 239 | parser.add_argument('--mgpus', type=str, default='0', required=False, help='whether to use multiple gpu') 240 | parser.add_argument('--used_df', type=str, default='True', help='whether to use df') 241 | parser.add_argument('--model', type=str, default='', help='whether to use df') 242 | parser.add_argument('--selfeat', type=bool, default=True, help='whether to use feature select') 243 | parser.add_argument('--model_path', type=str, default=None, help='whether to train with evaluation') 244 | parser.add_argument('--output_dir', type=str, default="logs/acdc_logs/logs_supcat_auxseg/predtions", required=False, help='specify an output directory if needed') 245 | parser.add_argument('--log_file', type=str, default="../log_predtion.txt", help="the file to write logging") 246 | 247 | args = parser.parse_args() 248 | if args.mgpus is not None: 249 | os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus 250 | 251 | model_path = "logs/acdc_logs/logs_supcat_auxseg/ckpt/checkpoint_epoch_70.pth" 252 | 253 | test(args.mgpus, "U_NetDF", model_path, args.used_df, args.selfeat, args.output_dir) 254 | -------------------------------------------------------------------------------- /tools/test_df_vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | import cv2 9 | import logging 10 | import math 11 | 12 | import _init_paths 13 | from libs.network import U_Net, U_NetDF 14 | from libs.datasets import AcdcDataset 15 | import libs.datasets.joint_augment as joint_augment 16 | import libs.datasets.augment as standard_augment 17 | from libs.datasets.collate_batch import BatchCollator 18 | 19 | from libs.configs.config_acdc import cfg 20 | from train_utils.train_utils import load_checkpoint 21 | from utils.metrics import dice 22 | from utils.vis_utils import mask2png, masks_to_contours, apply_mask, img_mask_png 23 | 24 | parser = argparse.ArgumentParser(description="arg parser") 25 | parser.add_argument('--used_df', type=str, default=False, help='whether to use df') 26 | parser.add_argument('--selfeat', action='store_true', default=False, help='whether to use feature select') 27 | parser.add_argument('--mgpus', type=str, default=None, required=True, help='whether to use multiple gpu') 28 | parser.add_argument('--model_path1', type=str, default=None, help='whether to train with evaluation') 29 | parser.add_argument('--model_path2', type=str, default=None, help='whether to train with evaluation') 30 | parser.add_argument('--batch_size', type=int, default=1, help='batch size for evaluation') 31 | parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader') 32 | parser.add_argument('--output_dir', type=str, default=None, required=True, help='specify an output directory if needed') 33 | parser.add_argument('--log_file', type=str, default="../log_evalation.txt", help="the file to write logging") 34 | parser.add_argument('--vis', action='store_true', default=False, 
help="whether to save test result images")
35 | args = parser.parse_args()
36 | 
37 | if args.mgpus is not None:
38 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus
39 | 
40 | def create_logger(log_file):
41 |     log_format = '%(asctime)s %(levelname)5s %(message)s'
42 |     logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)
43 |     console = logging.StreamHandler()
44 |     console.setLevel(logging.DEBUG)
45 |     console.setFormatter(logging.Formatter(log_format))
46 |     logging.getLogger(__name__).addHandler(console)
47 |     return logging.getLogger(__name__)
48 | 
49 | def create_dataloader():
50 |     eval_transform = joint_augment.Compose([
51 |         joint_augment.To_PIL_Image(),
52 |         joint_augment.FixResize(256),
53 |         joint_augment.To_Tensor()])
54 |     evalImg_transform = standard_augment.Compose([
55 |         standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])
56 | 
57 |     if cfg.DATASET.NAME == "acdc":
58 |         test_set = AcdcDataset(cfg.DATASET.TEST_LIST, df_used=True, joint_augment=eval_transform,
59 |                                augment=evalImg_transform)
60 | 
61 |     test_loader = DataLoader(test_set, batch_size=1, pin_memory=True,
62 |                              num_workers=args.workers, shuffle=False,
63 |                              collate_fn=BatchCollator(size_divisible=32, df_used=True))
64 |     return test_loader, test_set
65 | 
66 | def cal_perfer(preds, masks, tb_dict):
67 |     LV_dice = [] # 1
68 |     MYO_dice = [] # 2
69 |     RV_dice = [] # 3
70 | 
71 |     for i in range(preds.shape[0]):
72 |         LV_dice.append(dice(preds[i,1,:,:],masks[i,1,:,:]))
73 |         RV_dice.append(dice(preds[i, 3, :, :], masks[i, 3, :, :]))
74 |         MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
75 |         # LV_dice.append(dice(preds[i, 3,:,:],masks[i,1,:,:]))
76 |         # RV_dice.append(dice(preds[i, 1, :, :], masks[i, 3, :, :]))
77 |         # MYO_dice.append(dice(preds[i, 2, :, :], masks[i, 2, :, :]))
78 | 
79 |     tb_dict["LV_dice"].append(np.mean(LV_dice))
80 |     tb_dict["RV_dice"].append(np.mean(RV_dice))
81 |     tb_dict["MYO_dice"].append(np.mean(MYO_dice))
82 |     return np.mean(LV_dice), np.mean(RV_dice), np.mean(MYO_dice)
83 | 
84 | def make_one_hot(input, num_classes):
85 |     """Convert class index tensor to one hot encoding tensor.
86 | Args: 87 | input: A tensor of shape [N, 1, *] 88 | num_classes: An int of number of class 89 | Returns: 90 | A tensor of shape [N, num_classes, *] 91 | """ 92 | shape = np.array(input.shape) 93 | shape[1] = num_classes 94 | shape = tuple(shape) 95 | result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1) 96 | # result = result.scatter_(1, input.cpu(), 1) 97 | 98 | return result 99 | 100 | def test_it(model, data, device="cuda"): 101 | model.eval() 102 | imgs, gts = data[:2] 103 | gts_df = data[2] 104 | 105 | imgs = imgs.to(device) 106 | gts = gts.to(device) 107 | 108 | net_out = model(imgs) 109 | if len(net_out) > 1: 110 | preds_out = net_out[0] 111 | preds_df = net_out[1] 112 | else: 113 | preds_out = net_out[0] 114 | preds_df = None 115 | preds_out = nn.functional.softmax(preds_out, dim=1) 116 | _, preds = torch.max(preds_out, 1) 117 | preds = preds.unsqueeze(1) # (N, 1, *) 118 | 119 | return preds, preds_df 120 | 121 | def vis_it(pred, gt, img=None, filename=None, infos=None): 122 | h, w = pred.shape 123 | # gt_contours = masks_to_contours(gt) 124 | # mask = np.hstack([pred, np.zeros((h, 1)), gt]) 125 | # gt_contours = np.hstack([gt_contours, np.zeros((h, 1)), np.zeros_like(gt)]) 126 | # im_rgb = mask2png(mask).astype(np.int16) 127 | # im_rgb[:, w, :] = [255, 255, 255] 128 | # im_rgb = apply_mask(im_rgb, gt_contours, [255, 255, 255], 0.8) 129 | pred_im = mask2png(pred).astype(np.int16) 130 | gt_im = mask2png(gt).astype(np.int16) 131 | 132 | img = (img - img.min()) / (img.max() - img.min()) * 255 133 | img = np.stack([img, img, img], axis=2) 134 | # img = img_mask_png(img, gt, alpha=0.1) 135 | 136 | # im_rgb = np.hstack([im_rgb, 255*np.ones((h, 1, 3)), img]) 137 | 138 | # cv2.putText(im_rgb, "prediction", (2,h-4), 139 | # fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1) 140 | # cv2.putText(im_rgb, "ground truth", (w, h-4), 141 | # fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1) 142 | 143 | # st_pos = 15 144 | # if infos is not None: 145 | # for info in infos: 146 | # cv2.putText(im_rgb, info+": {}".format(infos[info]), (2, st_pos), 147 | # fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.4, color=(255, 255, 255), thickness=1) 148 | # st_pos += 10 149 | 150 | cv2.imwrite(filename+"_img.png", img[:,:,::-1]) 151 | cv2.imwrite(filename+"_pred.png", pred_im[:,:,::-1]) 152 | cv2.imwrite(filename+"_gt.png", gt_im[:,:,::-1]) 153 | 154 | def vis_df(pred_df, gt_df, filename, infos=None): 155 | _, h, w = pred_df.shape 156 | 157 | # save .npy files 158 | np.save(filename+'.npy', [pred_df, gt_df]) 159 | 160 | theta = np.arctan2(gt_df[1,...], gt_df[0,...]) 161 | degree_gt = (theta - theta.min()) / (theta.max() - theta.min()) * 255 162 | # degree_gt = theta * 360 163 | mag_gt = np.sum(gt_df ** 2, axis=0, keepdims=False) 164 | mag_gt = mag_gt / mag_gt.max() * 255 165 | 166 | theta = np.arctan2(pred_df[1,...], pred_df[0,...]) 167 | degree_df = (theta - theta.min()) / (theta.max() - theta.min()) * 255 168 | # degree_df = theta * 360 169 | magnitude = np.sum(pred_df ** 2, axis=0, keepdims=False) 170 | magnitude = magnitude / magnitude.max() * 255 171 | 172 | im = np.hstack([magnitude, np.zeros((h, 1)), mag_gt, np.zeros((h, 1)), degree_df, np.zeros((h, 1)), degree_gt]).astype(np.uint8) 173 | im = cv2.applyColorMap(im, cv2.COLORMAP_JET) 174 | cv2.imwrite(filename+"_df_pred_mag.png", im[:h, :w, ...]) 175 | cv2.imwrite(filename+"_df_gt_mag.png", im[:h, w+1:2*w+1, ...]) 176 | cv2.imwrite(filename+"_df_pred_degree.png", im[:h, 
2*w+2:3*w+2, ...]) 177 | cv2.imwrite(filename+"_df_gt_degree.png", im[:h, 3*w+3:, ...]) 178 | 179 | 180 | def test(): 181 | root_result_dir = args.output_dir 182 | os.makedirs(root_result_dir, exist_ok=True) 183 | 184 | log_file = os.path.join(root_result_dir, args.log_file) 185 | logger = create_logger(log_file) 186 | 187 | for key, val in vars(args).items(): 188 | logger.info("{:16} {}".format(key, val)) 189 | 190 | # create dataset & dataloader & network 191 | if args.used_df == 'U_NetDF': 192 | model = U_NetDF(selfeat=args.selfeat, num_class=4, auxseg=True) 193 | elif args.used_df == 'U_Net': 194 | model = U_Net(num_class=4) 195 | 196 | if args.mgpus is not None and len(args.mgpus) > 2: 197 | model = nn.DataParallel(model) 198 | model.cuda() 199 | 200 | test_loader, test_set = create_dataloader() 201 | 202 | checkpoint = torch.load(args.model_path1, map_location='cpu') 203 | model.load_state_dict(checkpoint['model_state']) 204 | 205 | dice_dict = {"LV_dice": [], 206 | "RV_dice": [], 207 | "MYO_dice": []} 208 | for i, data in enumerate(test_loader): 209 | if i != 23: continue 210 | # i = 5405 211 | # data = test_set[5405] 212 | # data = [data[0][None], data[1][None], data[2][None]] 213 | 214 | pred, pred_df = test_it(model, data[:3]) 215 | 216 | _, gt, gt_df = data[:3] 217 | gt = gt.to("cuda") 218 | 219 | L, R, MYO = cal_perfer(make_one_hot(pred, 4), make_one_hot(gt, 4), dice_dict) 220 | 221 | data_info = test_set.data_infos[i] 222 | if args.vis: 223 | # if 0.7 <= (L + R + MYO) / 3 < 0.8: 224 | vis_it(pred.cpu().numpy()[0, 0], gt.cpu().numpy()[0, 0], data[0].cpu().numpy()[0, 0], 225 | filename=os.path.join(root_result_dir, str(i))) 226 | if pred_df is not None: 227 | vis_df(pred_df.detach().cpu().numpy()[0], gt_df.cpu().numpy()[0], 228 | filename=os.path.join(root_result_dir, str(i))) 229 | 230 | print("\r{}/{} {:.0%} {}".format(i, len(test_set), i/len(test_set), 231 | np.mean(list(dice_dict.values()))), end="") 232 | print() 233 | 234 | logger.info("2D Dice Metirc:") 235 | logger.info("Total {}".format(len(test_set))) 236 | logger.info("LV_dice: {}".format(np.mean(dice_dict["LV_dice"]))) 237 | logger.info("RV_dice: {}".format(np.mean(dice_dict["RV_dice"]))) 238 | logger.info("MYO_dice: {}".format(np.mean(dice_dict["MYO_dice"]))) 239 | logger.info("Mean_dice: {}".format(np.mean(list(dice_dict.values())))) 240 | 241 | 242 | if __name__ == "__main__": 243 | test() -------------------------------------------------------------------------------- /tools/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import h5py 4 | import numpy as np 5 | import math 6 | 7 | from utils.image_list import to_image_list 8 | 9 | # acdc data 10 | def personTo4Ddata(personname, test_list): 11 | sliceofp = [] 12 | for tl in test_list: 13 | if '/'.join(personname.split('-')) in tl: 14 | sliceofp.append(tl) 15 | 16 | imgs = [[], []] 17 | gts = [[], []] 18 | for ti, time_i in enumerate(["ES", "ED"]): 19 | time_path = [] 20 | 21 | for sp in sliceofp: 22 | if time_i in sp: 23 | time_path.append(sp) 24 | 25 | for tp in time_path: 26 | imgs[ti].append(h5py.File(tp, 'r')['image']) 27 | gts[ti].append(h5py.File(tp, 'r')['label']) 28 | 29 | imgs = np.array(imgs).transpose(1,2,3,0) 30 | gts = np.array(gts).transpose(1,2,3,0) 31 | return imgs, gts 32 | 33 | def test_it(model, data, device="cuda", used_df=False): 34 | model.eval() 35 | imgs = data 36 | 37 | imgs = imgs.to(device) 38 | # gts = gts.to(device) 39 | 40 | net_out = 
model(imgs) 41 | if used_df: 42 | preds_out = net_out[0] 43 | preds_df = net_out[1] 44 | else: 45 | preds_out = net_out[0] 46 | preds_df = None 47 | preds_out = nn.functional.softmax(preds_out, dim=1) 48 | _, preds = torch.max(preds_out, 1) 49 | preds = preds.unsqueeze(1) # (N, 1, *) 50 | 51 | return preds, preds_df 52 | 53 | def test_person(model, imgs, multi_batches=False, used_df=False): 54 | """ imgs: (times, slices, H, W) 55 | preds: (times, slices, H, W) 56 | """ 57 | preds = [] 58 | for i in range(len(imgs)): 59 | preds_timei = np.zeros([imgs[i].size(0), imgs[i].size(2), imgs[i].size(3)]) 60 | 61 | if multi_batches: 62 | batch_size = 32 63 | for bs in range(math.ceil(len(imgs[i]) / batch_size)): 64 | st = batch_size * bs 65 | end = st + batch_size if (st+batch_size) <= len(imgs[i]) else len(imgs[i]) 66 | # data, origin_shape = to_image_list(imgs[i][st:st+batch_size], size_divisible=32, return_size=True) 67 | data = imgs[i][st:end] 68 | origin_shape = imgs[i].shape[-2:] 69 | 70 | pred, _ = test_it(model, data, used_df=used_df) 71 | preds_timei[st:end, ...] = pred.cpu().numpy()[:, 0, :origin_shape[0], :origin_shape[1]] 72 | # =========================== 73 | # data, origin_shape = to_image_list(imgs[i], size_divisible=32, return_size=True) 74 | # pred, _ = test_it(model, data, used_df=used_df) 75 | 76 | # for j in range(imgs[i].shape[0]): 77 | # preds_timei[j, ...] = pred.cpu().numpy()[j, :, :origin_shape[j][0], :origin_shape[j][1]] 78 | else: 79 | for j, pt in enumerate(imgs[i]): 80 | data = [pt] 81 | data, origin_shape = to_image_list(data, size_divisible=32, return_size=True) 82 | pred, _ = test_it(model, data) 83 | preds_timei[j, ...] = pred.cpu().numpy()[0, 0, :origin_shape[0][0], :origin_shape[0][1]] 84 | 85 | preds.append(preds_timei) 86 | 87 | return preds 88 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import importlib 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader 11 | import torchvision 12 | from tensorboardX import SummaryWriter 13 | 14 | import _init_paths 15 | from libs.configs.config_acdc import cfg 16 | 17 | from libs.datasets import AcdcDataset 18 | from libs.datasets import joint_augment as joint_augment 19 | from libs.datasets import augment as standard_augment 20 | from libs.datasets.collate_batch import BatchCollator 21 | # from libs.losses.df_loss import EuclideanLossWithOHEM 22 | # from libs.losses.surface_loss import SurfaceLoss 23 | from libs.losses.create_losses import Total_loss 24 | import train_utils.train_utils as train_utils 25 | from train_utils.train_utils import load_checkpoint 26 | from utils.init_net import init_weights 27 | from utils.comm import get_rank, synchronize 28 | 29 | 30 | parser = argparse.ArgumentParser(description="arg parser") 31 | parser.add_argument("--local_rank", type=int, default=0, required=True, help="device_ids of DistributedDataParallel") 32 | parser.add_argument("--batch_size", type=int, default=32, required=False, help="batch size for training") 33 | parser.add_argument("--epochs", type=int, default=50, required=False, help="Number of epochs to train for") 34 | parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader') 35 | parser.add_argument("--ckpt_save_interval", type=int, default=5, 
help="number of training epochs") 36 | parser.add_argument('--output_dir', type=str, default=None, help='specify an output directory if needed') 37 | parser.add_argument('--mgpus', type=str, default=None, help='whether to use multiple gpu') 38 | parser.add_argument("--ckpt", type=str, default=None, help="continue training from this checkpoint") 39 | parser.add_argument('--train_with_eval', action='store_true', default=False, help='whether to train with evaluation') 40 | args = parser.parse_args() 41 | 42 | FILE_DIR = os.path.dirname(os.path.abspath(__file__)) 43 | 44 | if args.mgpus is not None: 45 | os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus 46 | 47 | def create_logger(log_file, dist_rank): 48 | if dist_rank > 0: 49 | logger = logging.getLogger(__name__) 50 | logger.setLevel(logging.WARNING) 51 | return logger 52 | log_format = '%(asctime)s %(levelname)5s %(message)s' 53 | logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file) 54 | console = logging.StreamHandler() 55 | console.setLevel(logging.DEBUG) 56 | console.setFormatter(logging.Formatter(log_format)) 57 | logging.getLogger(__name__).addHandler(console) 58 | return logging.getLogger(__name__) 59 | 60 | def create_dataloader(logger): 61 | train_joint_transform = joint_augment.Compose([ 62 | joint_augment.To_PIL_Image(), 63 | joint_augment.RandomAffine(0,translate=(0.125, 0.125)), 64 | joint_augment.RandomRotate((-180,180)), 65 | joint_augment.FixResize(256) 66 | ]) 67 | transform = standard_augment.Compose([ 68 | standard_augment.to_Tensor(), 69 | standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])]) 70 | target_transform = standard_augment.Compose([ 71 | standard_augment.to_Tensor()]) 72 | 73 | if cfg.DATASET.NAME == 'acdc': 74 | train_set = AcdcDataset(data_list=cfg.DATASET.TRAIN_LIST, 75 | df_used=cfg.DATASET.DF_USED, df_norm=cfg.DATASET.DF_NORM, 76 | boundary=cfg.DATASET.BOUNDARY, 77 | joint_augment=train_joint_transform, 78 | augment=transform, target_augment=target_transform) 79 | 80 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, 81 | num_replicas=dist.get_world_size(), rank=dist.get_rank()) 82 | train_loader = DataLoader(train_set, batch_size=args.batch_size, pin_memory=True, 83 | num_workers=args.workers, shuffle=False, sampler=train_sampler, 84 | collate_fn=BatchCollator(size_divisible=32, df_used=cfg.DATASET.DF_USED, 85 | boundary=cfg.DATASET.BOUNDARY)) 86 | 87 | if args.train_with_eval: 88 | eval_transform = joint_augment.Compose([ 89 | joint_augment.To_PIL_Image(), 90 | joint_augment.FixResize(256), 91 | joint_augment.To_Tensor()]) 92 | evalImg_transform = standard_augment.Compose([ 93 | standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])]) 94 | 95 | if cfg.DATASET.NAME == 'acdc': 96 | test_set = AcdcDataset(data_list=cfg.DATASET.TEST_LIST, 97 | df_used=cfg.DATASET.DF_USED, df_norm=cfg.DATASET.DF_NORM, 98 | boundary=cfg.DATASET.BOUNDARY, 99 | joint_augment=eval_transform, 100 | augment=evalImg_transform) 101 | 102 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_set, 103 | num_replicas=dist.get_world_size(), rank=dist.get_rank()) 104 | test_loader = DataLoader(test_set, batch_size=args.batch_size, pin_memory=True, 105 | num_workers=args.workers, shuffle=False, sampler=test_sampler, 106 | collate_fn=BatchCollator(size_divisible=32, df_used=cfg.DATASET.DF_USED, 107 | boundary=cfg.DATASET.BOUNDARY)) 108 | else: 109 | test_loader = None 110 | 111 | return train_loader, test_loader 112 | 113 | def 
create_optimizer(model): 114 | if cfg.TRAIN.OPTIMIZER == "adam": 115 | optimizer = optim.Adam(model.parameters(), lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY) 116 | elif cfg.TRAIN.OPTIMIZER == "sgd": 117 | optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY, 118 | momentum=cfg.TRAIN.MOMENTUM) 119 | else: 120 | raise NotImplementedError 121 | return optimizer 122 | 123 | def create_scheduler(model, optimizer, total_steps, last_epoch): 124 | def lr_lbmd(cur_epoch): 125 | cur_decay = 1 126 | for decay_step in cfg.TRAIN.DECAY_STEP_LIST: 127 | if cur_epoch >= decay_step: 128 | cur_decay = cur_decay * cfg.TRAIN.LR_DECAY 129 | return max(cur_decay, cfg.TRAIN.LR_CLIP / cfg.TRAIN.LR) 130 | 131 | lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lbmd, last_epoch=last_epoch) 132 | return lr_scheduler 133 | 134 | def create_model(cfg): 135 | network = cfg.TRAIN.NET 136 | 137 | module = 'libs.network.' + network[:network.rfind('.')] 138 | model = network[network.rfind('.')+1:] 139 | 140 | mod = importlib.import_module(module) 141 | mod_func = importlib.import_module('libs.network.train_functions') 142 | net_func = getattr(mod, model) 143 | 144 | net = net_func(num_class=cfg.DATASET.NUM_CLASS) 145 | if network == 'unet.U_Net': 146 | train_func = getattr(mod_func, 'model_fn_decorator') 147 | elif network == 'unet_df.U_NetDF': 148 | net = net_func(selfeat=cfg.MODEL.SELFEATURE, num_class=cfg.DATASET.NUM_CLASS, shift_n=cfg.MODEL.SHIFT_N, auxseg=cfg.MODEL.AUXSEG) 149 | train_func = getattr(mod_func, 'model_DF_decorator') 150 | 151 | return net, train_func 152 | 153 | def train(): 154 | torch.cuda.set_device(args.local_rank) 155 | dist.init_process_group(backend="nccl", init_method="env://") 156 | synchronize() 157 | 158 | # create dataloader & network & optimizer 159 | model, model_fn_decorator = create_model(cfg) 160 | init_weights(model, init_type='kaiming') 161 | # model.to('cuda') 162 | model.cuda() 163 | model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) 164 | 165 | root_result_dir = args.output_dir 166 | os.makedirs(root_result_dir, exist_ok=True) 167 | 168 | log_file = os.path.join(root_result_dir, "log_train.txt") 169 | logger = create_logger(log_file, get_rank()) 170 | logger.info("**********************Start logging**********************") 171 | 172 | # log to file 173 | gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys() else 'ALL' 174 | logger.info("CUDA_VISIBLE_DEVICES=%s" % gpu_list) 175 | 176 | for key, val in vars(args).items(): 177 | logger.info("{:16} {}".format(key, val)) 178 | 179 | logger.info("***********************config infos**********************") 180 | for key, val in vars(cfg).items(): 181 | logger.info("{:16} {}".format(key, val)) 182 | 183 | # log tensorboard 184 | if get_rank() == 0: 185 | tb_log = SummaryWriter(log_dir=os.path.join(root_result_dir, "tensorboard")) 186 | else: 187 | tb_log = None 188 | 189 | 190 | train_loader, test_loader = create_dataloader(logger) 191 | 192 | optimizer = create_optimizer(model) 193 | 194 | # load checkpoint if it is possible 195 | start_epoch = it = best_res = 0 196 | last_epoch = -1 197 | if args.ckpt is not None: 198 | pure_model = model.module if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) else model 199 | it, start_epoch, best_res = load_checkpoint(pure_model, optimizer, args.ckpt, logger) 200 | last_epoch = start_epoch + 1 201 | 202 | 
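    # (The LambdaLR assembled below applies stepwise decay: the base LR is multiplied
    # by cfg.TRAIN.LR_DECAY at every epoch listed in cfg.TRAIN.DECAY_STEP_LIST, with a
    # floor of cfg.TRAIN.LR_CLIP / cfg.TRAIN.LR. A worked example with hypothetical
    # values LR=1e-3, LR_DECAY=0.1, DECAY_STEP_LIST=[30, 40], LR_CLIP=1e-6: epochs 0-29
    # train at 1e-3, epochs 30-39 at 1e-4, and epochs 40 onward at 1e-5.)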
lr_scheduler = create_scheduler(model, optimizer, total_steps=len(train_loader)*args.epochs, 203 | last_epoch=last_epoch) 204 | 205 | if cfg.DATASET.DF_USED: 206 | criterion = Total_loss(boundary=cfg.DATASET.BOUNDARY) 207 | else: 208 | criterion = nn.CrossEntropyLoss() 209 | 210 | 211 | # start training 212 | logger.info('**********************Start training**********************') 213 | ckpt_dir = os.path.join(root_result_dir, "ckpt") 214 | os.makedirs(ckpt_dir, exist_ok=True) 215 | trainer = train_utils.Trainer(model, 216 | model_fn=model_fn_decorator(), 217 | criterion=criterion, 218 | optimizer=optimizer, 219 | ckpt_dir=ckpt_dir, 220 | lr_scheduler=lr_scheduler, 221 | model_fn_eval=model_fn_decorator(), 222 | tb_log=tb_log, 223 | logger=logger, 224 | eval_frequency=1, 225 | grad_norm_clip=cfg.TRAIN.GRAD_NORM_CLIP, 226 | cfg=cfg) 227 | 228 | trainer.train(start_it=it, 229 | start_epoch=start_epoch, 230 | n_epochs=args.epochs, 231 | train_loader=train_loader, 232 | test_loader=test_loader, 233 | ckpt_save_interval=args.ckpt_save_interval, 234 | lr_scheduler_each_iter=False, 235 | best_res=best_res) 236 | 237 | logger.info('**********************End training**********************') 238 | 239 | 240 | # python -m torch.distributed.launch --nproc_per_node 2 --master_port $RANDOM tools/train.py --batch_size 20 --mgpus 2,3 --output_dir logs/... --train_with_eval 241 | if __name__ == "__main__": 242 | train() 243 | 244 | 245 | -------------------------------------------------------------------------------- /tools/train_utils/__pycache__/_init_paths.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/tools/train_utils/__pycache__/_init_paths.cpython-36.pyc -------------------------------------------------------------------------------- /tools/train_utils/__pycache__/_init_paths.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/tools/train_utils/__pycache__/_init_paths.cpython-37.pyc -------------------------------------------------------------------------------- /tools/train_utils/__pycache__/base_dir.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/tools/train_utils/__pycache__/base_dir.cpython-36.pyc -------------------------------------------------------------------------------- /tools/train_utils/__pycache__/test_3D.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/tools/train_utils/__pycache__/test_3D.cpython-36.pyc -------------------------------------------------------------------------------- /tools/train_utils/__pycache__/test_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/tools/train_utils/__pycache__/test_utils.cpython-36.pyc -------------------------------------------------------------------------------- /tools/train_utils/__pycache__/train_utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/tools/train_utils/__pycache__/train_utils.cpython-36.pyc -------------------------------------------------------------------------------- /tools/train_utils/__pycache__/train_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/tools/train_utils/__pycache__/train_utils.cpython-37.pyc -------------------------------------------------------------------------------- /tools/train_utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | import numpy as np 6 | import json 7 | 8 | from utils.comm import get_world_size, get_rank 9 | import utils.metrics as metrics 10 | from utils.image_list import to_image_list 11 | from tools.test_utils import personTo4Ddata, test_person 12 | from libs.datasets import joint_augment as joint_augment 13 | from libs.datasets import augment as standard_augment 14 | 15 | def save_checkpoint(state, filename='checkpoint', is_best=False): 16 | filename = '{}.pth'.format(filename) 17 | torch.save(state, filename) 18 | if is_best: 19 | torch.save(state, os.path.join(os.path.dirname(filename), "model_best.pth")) 20 | 21 | def checkpoint_state(model=None, optimizer=None, epoch=None, it=None, performance=0.): 22 | optim_state = optimizer.state_dict() if optimizer is not None else None 23 | if model is not None: 24 | if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)): 25 | model_state = model.module.state_dict() 26 | else: 27 | model_state = model.state_dict() 28 | else: 29 | model_state = None 30 | 31 | return {'epoch': epoch, 'it': it, 'model_state': model_state, 'optimizer_state': optim_state, 'performance': performance} 32 | 33 | def load_checkpoint(model=None, optimizer=None, filename="checkpoint", logger=None): 34 | if os.path.isfile(filename): 35 | if logger is not None: 36 | logger.info("==> Loading from checkpoint '{}'".format(filename)) 37 | checkpoint = torch.load(filename, map_location="cpu") 38 | epoch = checkpoint['epoch'] if 'epoch' in checkpoint.keys() else -1 39 | it = checkpoint.get('it', 0.0) 40 | performance = checkpoint.get('performance', 0.) 
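        # The checkpoint is the plain dict written by checkpoint_state() below:
        # {'epoch', 'it', 'model_state', 'optimizer_state', 'performance'}; the
        # conditional and the .get() calls tolerate checkpoints saved without some keys.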
41 | if model is not None and checkpoint['model_state'] is not None: 42 | model.load_state_dict(checkpoint['model_state']) 43 | if optimizer is not None and checkpoint['optimizer_state'] is not None: 44 | optimizer.load_state_dict(checkpoint['optimizer_state']) 45 | if logger is not None: 46 | logger.info("==> Done") 47 | else: 48 | raise FileNotFoundError 49 | 50 | return it, epoch, performance 51 | 52 | class Trainer(): 53 | def __init__(self, model, model_fn, criterion, optimizer, ckpt_dir, lr_scheduler, model_fn_eval, 54 | tb_log, logger, eval_frequency=1, grad_norm_clip=1.0, cfg=None): 55 | self.model, self.model_fn, self.optimizer, self.model_fn_eval = model, model_fn, optimizer, model_fn_eval 56 | 57 | self.criterion = criterion 58 | self.lr_scheduler = lr_scheduler 59 | self.ckpt_dir = ckpt_dir 60 | self.tb_log = tb_log 61 | self.logger = logger 62 | self.eval_frequency = eval_frequency 63 | self.grad_norm_clip = grad_norm_clip 64 | self.cfg = cfg 65 | self.caches_4D = {} 66 | 67 | def _train_it(self, batch, epoch=0): 68 | self.model.train() 69 | 70 | self.optimizer.zero_grad() 71 | loss, tb_dict, disp_dict = self.model_fn(self.model, batch, self.criterion, perfermance=False, epoch=epoch) 72 | 73 | loss.backward(retain_graph=True) 74 | self.optimizer.step() 75 | return loss.item(), tb_dict, disp_dict 76 | 77 | def eval_epoch(self, d_loader): 78 | self.model.eval() 79 | 80 | eval_dict = {} 81 | total_loss = 0 82 | 83 | # eval one epoch 84 | if get_rank() == 0: print("evaluating...") 85 | sel_num = np.random.choice(len(d_loader), size=1) 86 | for i, data in enumerate(d_loader, 0): 87 | self.optimizer.zero_grad() 88 | vis = True if i == sel_num else False 89 | 90 | loss, tb_dict, disp_dict = self.model_fn_eval(self.model, data, self.criterion, perfermance=True, vis=vis) 91 | 92 | total_loss += loss.item() 93 | 94 | for k, v in tb_dict.items(): 95 | if "vis" not in k: 96 | eval_dict[k] = eval_dict.get(k, 0) + v 97 | else: 98 | eval_dict[k] = v 99 | if get_rank() == 0: print("\r{}/{} {:.0%}\r".format(i, len(d_loader), i/len(d_loader)), end='') 100 | if get_rank() == 0: print() 101 | 102 | for k, v in tb_dict.items(): 103 | if "vis" not in k: 104 | eval_dict[k] = eval_dict.get(k, 0) / (i + 1) 105 | 106 | return total_loss / (i+1), eval_dict, disp_dict 107 | 108 | def train(self, start_it, start_epoch, n_epochs, train_loader, test_loader=None, 109 | ckpt_save_interval=5, lr_scheduler_each_iter=False, best_res=0): 110 | eval_frequency = self.eval_frequency if self.eval_frequency else 1 111 | 112 | it = start_it 113 | for epoch in range(start_epoch, n_epochs): 114 | if self.lr_scheduler is not None: 115 | self.lr_scheduler.step(epoch) 116 | 117 | for cur_it, batch in enumerate(train_loader): 118 | cur_lr = self.lr_scheduler.get_lr()[0] 119 | 120 | loss, tb_dict, disp_dict = self._train_it(batch, epoch) 121 | it += 1 122 | 123 | # print infos 124 | if get_rank() == 0: 125 | print("Epoch/train:{}({:.0%})/{}({:.0%})".format(epoch, epoch/n_epochs, 126 | cur_it, cur_it/len(train_loader)), end="") 127 | for k, v in disp_dict.items(): 128 | print(", ", k+": {:.6}".format(v), end="") 129 | print("") 130 | 131 | # tensorboard logs 132 | if self.tb_log is not None: 133 | self.tb_log.add_scalar("train_loss", loss, it) 134 | self.tb_log.add_scalar("learning_rate", cur_lr, it) 135 | for key, val in tb_dict.items(): 136 | self.tb_log.add_scalar('train_'+key, val, it) 137 | 138 | # save trained model 139 | trained_epoch = epoch 140 | # if trained_epoch % ckpt_save_interval == 0: 141 | # ckpt_name = 
os.path.join(self.ckpt_dir, "checkpoint_epoch_%d" % trained_epoch) 142 | # save_checkpoint(checkpoint_state(self.model, self.optimizer, trained_epoch, it), 143 | # filename=ckpt_name) 144 | 145 | # eval one epoch 146 | if (epoch % eval_frequency) == 0 and (test_loader is not None): 147 | with torch.set_grad_enabled(False): 148 | val_loss, eval_dict, disp_dict = self.eval_epoch(test_loader) 149 | # mean_3D = self.metric_3D(self.model, self.cfg) 150 | 151 | if self.tb_log is not None: 152 | for key, val in eval_dict.items(): 153 | if "vis" not in key: 154 | self.tb_log.add_scalar("val_"+key, val, it) 155 | else: 156 | self.tb_log.add_images("df_gt", val[0], it, dataformats="NCHW") 157 | self.tb_log.add_images("df_pred", val[2], it, dataformats="NCHW") 158 | self.tb_log.add_images("df_magnitude", val[1], it, dataformats="NCHW") 159 | 160 | # save model and best model 161 | if get_rank() == 0: 162 | # cal 3D dice 163 | # if self.tb_log is not None: 164 | # for k, v in mean_3D.items(): 165 | # self.tb_log.add_scalar("val_3D_"+k, v, it) 166 | 167 | res = np.mean([eval_dict["LV_dice"], eval_dict["RV_dice"], eval_dict["MYO_dice"]]) 168 | # res = np.mean([mean_3D["LV_dice"], mean_3D["RV_dice"], mean_3D["MYO_dice"]]) 169 | self.logger.info("Epoch {} mean 2D dice: {}".format(epoch, res)) 170 | if best_res != 0: 171 | _, _, best_res = load_checkpoint(filename=os.path.join(self.ckpt_dir, "model_best.pth")) 172 | is_best = res > best_res 173 | best_res = max(res, best_res) 174 | 175 | ckpt_name = os.path.join(self.ckpt_dir, "checkpoint_epoch_%d" % trained_epoch) 176 | save_checkpoint(checkpoint_state(self.model, self.optimizer, trained_epoch, it, performance=res), 177 | filename=ckpt_name, is_best=is_best) 178 | 179 | def metric_3D(self, model, cfg): 180 | p_json = cfg.DATASET.TEST_PERSON_LIST 181 | datadir_4D = "/root/ACDC_DataSet/4dData" 182 | 183 | with open(p_json, "r") as f: 184 | persons = json.load(f) 185 | 186 | total_segMetrics = {"dice": [[], [], []], 187 | "hausdorff": [[], [], []]} 188 | for i, p in enumerate(persons): 189 | # imgs, gts = personTo4Ddata(p, val_list) 190 | if p in self.caches_4D.keys(): 191 | imgs, gts = self.caches_4D[p] 192 | else: 193 | imgs = np.load(os.path.join(datadir_4D, p.split('-')[1], '4d_data.npy')) 194 | gts = np.load(os.path.join(datadir_4D, p.split('-')[1], '4d_gt.npy')) 195 | self.caches_4D[p] = [imgs, gts] 196 | 197 | imgs, gts = imgs.astype(np.float32)[..., None, :], gts.astype(np.float32)[..., None,:] 198 | imgs, gts = joint_transform(imgs, gts, cfg) 199 | gts = [gt[:, 0, ...].numpy() for gt in gts] 200 | 201 | preds = test_person(model, imgs, multi_batches=True, used_df=cfg.DATASET.DF_USED) # (times, slices, H, W) 202 | 203 | segMetrics = {"dice": [], "hausdorff": []} 204 | for j in range(len(preds)): 205 | segMetrics["dice"].append(metrics.dice3D(preds[j], gts[j], gts[j].shape)) 206 | segMetrics["hausdorff"].append(metrics.hd_3D(preds[j], gts[j])) 207 | 208 | for k, v in segMetrics.items(): 209 | segMetrics[k] = np.array(v).reshape((-1, 3)) 210 | 211 | for k, v in total_segMetrics.items(): 212 | for j in range(3): 213 | total_segMetrics[k][j] += segMetrics[k][:, j].tolist() 214 | # person i is done 215 | if get_rank() == 0: print("\r{}/{} {:.0%}\r".format(i, len(persons), i/len(persons)), end='') 216 | if get_rank() == 0: print() 217 | 218 | mean = {} 219 | for k, v in total_segMetrics.items(): # columns of v follow the dice3D class order [3, 1, 2] -> [LV, RV, MYO] 220 | mean.update({"LV_"+k: np.mean(v[0])}) 221 | mean.update({"MYO_"+k: np.mean(v[2])}) 222 | mean.update({"RV_"+k: np.mean(v[1])}) 223 | return mean 224 | 
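# Minimal usage sketch for the checkpoint helpers above (illustrative only; the toy
# model, optimizer and /tmp path are assumptions, not part of the training pipeline):
def _checkpoint_roundtrip_example():
    net = nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    state = checkpoint_state(net, opt, epoch=3, it=100, performance=0.85)
    save_checkpoint(state, filename="/tmp/example_ckpt")  # writes /tmp/example_ckpt.pth
    it, epoch, perf = load_checkpoint(net, opt, "/tmp/example_ckpt.pth")
    assert (it, epoch, perf) == (100, 3, 0.85)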
225 | def transform(imgs, cfg): 226 | trans = standard_augment.Compose([standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD]), 227 | ]) 228 | return trans(imgs) 229 | def joint_transform(imgs, gts, cfg): 230 | trans = joint_augment.Compose([joint_augment.To_PIL_Image(), 231 | # joint_augment.RandomAffine(0,translate=(0.125, 0.125)), 232 | # joint_augment.RandomRotate((-180,180)), 233 | joint_augment.FixResize(256), 234 | joint_augment.To_Tensor() 235 | ]) 236 | S, H, W, C, T = gts.shape 237 | trans_imgs = [None] * T 238 | trans_gts = [None] * T 239 | for i in range(T): 240 | trans_imgs[i], trans_gts[i] = [], [] 241 | for j in range(S): 242 | t0, t1 = trans(imgs[j,...,i], gts[j,...,i]) 243 | trans_imgs[i].append(transform(t0, cfg)) 244 | trans_gts[i].append(t1) 245 | 246 | aligned_imgs = [] 247 | aligned_gts = [] 248 | for i in range(T): 249 | aligned_imgs.append(to_image_list(trans_imgs[i], size_divisible=32)) 250 | aligned_gts.append(to_image_list(trans_gts[i], size_divisible=32)) 251 | 252 | return aligned_imgs, aligned_gts 253 | 254 | -------------------------------------------------------------------------------- /utils/comm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import pickle 4 | 5 | def get_world_size(): 6 | if not dist.is_available(): 7 | return 1 8 | if not dist.is_initialized(): 9 | return 1 10 | return dist.get_world_size() 11 | 12 | def get_rank(): 13 | if not dist.is_available(): 14 | return 0 15 | if not dist.is_initialized(): 16 | return 0 17 | return dist.get_rank() 18 | 19 | def synchronize(): 20 | """ 21 | Helper function to synchronize (barrier) among all processes when 22 | using distributed training 23 | """ 24 | if not dist.is_available(): 25 | return 26 | if not dist.is_initialized(): 27 | return 28 | world_size = dist.get_world_size() 29 | if world_size == 1: 30 | return 31 | dist.barrier() 32 | 33 | def all_gather(data): 34 | """ 35 | Run all_gather on arbitrary picklable data (not necessarily tensors) 36 | Args: 37 | data: any picklable object 38 | Returns: 39 | list[data]: list of data gathered from each rank 40 | """ 41 | world_size = get_world_size() 42 | if world_size == 1: 43 | return [data] 44 | 45 | # serialized to a Tensor 46 | buffer = pickle.dumps(data) 47 | storage = torch.ByteStorage.from_buffer(buffer) 48 | tensor = torch.ByteTensor(storage).to("cuda") 49 | 50 | # obtain Tensor size of each rank 51 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 52 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 53 | dist.all_gather(size_list, local_size) 54 | size_list = [int(size.item()) for size in size_list] 55 | max_size = max(size_list) 56 | 57 | # receiving Tensor from all ranks 58 | # we pad the tensor because torch all_gather does not support 59 | # gathering tensors of different shapes 60 | tensor_list = [] 61 | for _ in size_list: 62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 63 | if local_size != max_size: 64 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 65 | tensor = torch.cat((tensor, padding), dim=0) 66 | dist.all_gather(tensor_list, tensor) 67 | 68 | data_list = [] 69 | for size, tensor in zip(size_list, tensor_list): 70 | buffer = tensor.cpu().numpy().tobytes()[:size] 71 | data_list.append(pickle.loads(buffer)) 72 | 73 | return data_list -------------------------------------------------------------------------------- /utils/dice3D.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | author: Clément Zotti (clement.zotti@usherbrooke.ca) 3 | date: April 2017 4 | 5 | DESCRIPTION : 6 | The script provides helper functions to handle the nifti image format: 7 | - load_nii() 8 | - save_nii() 9 | 10 | to generate metrics for two images: 11 | - metrics() 12 | 13 | And it is callable from the command line (see below). 14 | Each function provided in this script has comments to understand 15 | how they work. 16 | 17 | HOW-TO: 18 | 19 | This script was tested for python 3.4. 20 | 21 | First, you need to install the required packages with 22 | pip install -r requirements.txt 23 | 24 | After the installation, you have two ways of running this script: 25 | 1) python metrics.py ground_truth/patient001_ED.nii.gz prediction/patient001_ED.nii.gz 26 | 2) python metrics.py ground_truth/ prediction/ 27 | 28 | The first option will print in the console the dice and volume of each class for the given image. 29 | The second option will output a csv file where each image will have the dice and volume of each class. 30 | 31 | 32 | Link: http://acdc.creatis.insa-lyon.fr 33 | 34 | """ 35 | 36 | import os 37 | from glob import glob 38 | import time 39 | import re 40 | import argparse 41 | import nibabel as nib 42 | # import pandas as pd 43 | from medpy.metric.binary import hd, dc 44 | import numpy as np 45 | 46 | 47 | 48 | HEADER = ["Name", "Dice LV", "Volume LV", "Err LV(ml)", 49 | "Dice RV", "Volume RV", "Err RV(ml)", 50 | "Dice MYO", "Volume MYO", "Err MYO(ml)"] 51 | 52 | # 53 | # Utils functions used to sort strings into a natural order 54 | # 55 | def conv_int(i): 56 | return int(i) if i.isdigit() else i 57 | 58 | 59 | def natural_order(sord): 60 | """ 61 | Sort a (list,tuple) of strings into natural order. 62 | 63 | Ex: 64 | 65 | ['1','10','2'] -> ['1','2','10'] 66 | 67 | ['abc1def','ab10d','b2c','ab1d'] -> ['ab1d','ab10d', 'abc1def', 'b2c'] 68 | 69 | """ 70 | if isinstance(sord, tuple): 71 | sord = sord[0] 72 | return [conv_int(c) for c in re.split(r'(\d+)', sord)] 73 | 74 | 75 | # 76 | # Utils function to load and save nifti files with the nibabel package 77 | # 78 | def load_nii(img_path): 79 | """ 80 | Function to load a 'nii' or 'nii.gz' file. The function returns 81 | everything needed to save another 'nii' or 'nii.gz' 82 | in the same dimensional space, i.e. the affine matrix and the header 83 | 84 | Parameters 85 | ---------- 86 | 87 | img_path: string 88 | String with the path of the 'nii' or 'nii.gz' image file name. 89 | 90 | Returns 91 | ------- 92 | Three elements: the first is a numpy array of the image values, 93 | the second is the affine transformation of the image, and the 94 | last one is the header of the image. 95 | """ 96 | nimg = nib.load(img_path) 97 | return nimg.get_data(), nimg.affine, nimg.header 98 | 99 | 100 | def save_nii(img_path, data, affine, header): 101 | """ 102 | Function to save a 'nii' or 'nii.gz' file. 103 | 104 | Parameters 105 | ---------- 106 | 107 | img_path: string 108 | Path to save the image, should end with '.nii' or '.nii.gz'. 109 | 110 | data: np.array 111 | Numpy array of the image data. 112 | 113 | affine: list of list or np.array 114 | The affine transformation to save with the image. 115 | 116 | header: nib.Nifti1Header 117 | The header that defines everything about the data 118 | (please check nibabel documentation). 
119 | """ 120 | nimg = nib.Nifti1Image(data, affine=affine, header=header) 121 | nimg.to_filename(img_path) 122 | 123 | 124 | # 125 | # Functions to process files, directories and metrics 126 | # 127 | def metrics(img_gt, img_pred, voxel_size): 128 | """ 129 | Function to compute the metrics between two segmentation maps given as input. 130 | 131 | Parameters 132 | ---------- 133 | img_gt: np.array 134 | Array of the ground truth segmentation map. 135 | 136 | img_pred: np.array 137 | Array of the predicted segmentation map. 138 | 139 | voxel_size: list, tuple or np.array 140 | The size of a voxel of the images used to compute the volumes. 141 | 142 | Return 143 | ------ 144 | A list of metrics in this order, [Dice LV, Volume LV, Err LV(ml), 145 | Dice RV, Volume RV, Err RV(ml), Dice MYO, Volume MYO, Err MYO(ml)] 146 | """ 147 | 148 | if img_gt.ndim != img_pred.ndim: 149 | raise ValueError("The arrays 'img_gt' and 'img_pred' should have the " 150 | "same dimension, {} against {}".format(img_gt.ndim, 151 | img_pred.ndim)) 152 | 153 | res = [] 154 | # Loop on each classes of the input images 155 | for c in [3, 1, 2]: 156 | # Copy the gt image to not alterate the input 157 | gt_c_i = np.copy(img_gt) 158 | gt_c_i[gt_c_i != c] = 0 159 | 160 | # Copy the pred image to not alterate the input 161 | pred_c_i = np.copy(img_pred) 162 | pred_c_i[pred_c_i != c] = 0 163 | 164 | # Clip the value to compute the volumes 165 | gt_c_i = np.clip(gt_c_i, 0, 1) 166 | pred_c_i = np.clip(pred_c_i, 0, 1) 167 | 168 | # Compute the Dice 169 | dice = dc(gt_c_i, pred_c_i) 170 | 171 | # Compute volume 172 | volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000. 173 | volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000. 174 | 175 | # res += [dice, volpred, volpred-volgt] 176 | res += [dice] 177 | 178 | return res 179 | 180 | 181 | def compute_metrics_on_files(path_gt, path_pred): 182 | """ 183 | Function to give the metrics for two files 184 | 185 | Parameters 186 | ---------- 187 | 188 | path_gt: string 189 | Path of the ground truth image. 190 | 191 | path_pred: string 192 | Path of the predicted image. 193 | """ 194 | gt, _, header = load_nii(path_gt) 195 | pred, _, _ = load_nii(path_pred) 196 | zooms = header.get_zooms() 197 | 198 | name = os.path.basename(path_gt) 199 | name = name.split('.')[0] 200 | res = metrics(gt, pred, zooms) 201 | res = ["{:.3f}".format(r) for r in res] 202 | 203 | formatting = "{:>14}, {:>7}, {:>9}, {:>10}, {:>7}, {:>9}, {:>10}, {:>8}, {:>10}, {:>11}" 204 | print(formatting.format(*HEADER)) 205 | print(formatting.format(name, *res)) 206 | 207 | 208 | def compute_metrics_on_directories(dir_gt, dir_pred): 209 | """ 210 | Function to generate a csv file for each images of two directories. 211 | 212 | Parameters 213 | ---------- 214 | 215 | path_gt: string 216 | Directory of the ground truth segmentation maps. 217 | 218 | path_pred: string 219 | Directory of the predicted segmentation maps. 
220 | """ 221 | lst_gt = sorted(glob(os.path.join(dir_gt, '*')), key=natural_order) 222 | lst_pred = sorted(glob(os.path.join(dir_pred, '*')), key=natural_order) 223 | 224 | res = [] 225 | for p_gt, p_pred in zip(lst_gt, lst_pred): 226 | if os.path.basename(p_gt) != os.path.basename(p_pred): 227 | raise ValueError("The two files don't have the same name" 228 | " {}, {}.".format(os.path.basename(p_gt), 229 | os.path.basename(p_pred))) 230 | 231 | gt, _, header = load_nii(p_gt) 232 | pred, _, _ = load_nii(p_pred) 233 | zooms = header.get_zooms() 234 | res.append(metrics(gt, pred, zooms)) 235 | 236 | lst_name_gt = [os.path.basename(gt).split(".")[0] for gt in lst_gt] 237 | res = [[n,] + r for r, n in zip(res, lst_name_gt)] 238 | df = pd.DataFrame(res, columns=HEADER) 239 | df.to_csv("results_{}.csv".format(time.strftime("%Y%m%d_%H%M%S")), index=False) 240 | 241 | def main(path_gt, path_pred): 242 | """ 243 | Main function to select which method to apply on the input parameters. 244 | """ 245 | if os.path.isfile(path_gt) and os.path.isfile(path_pred): 246 | compute_metrics_on_files(path_gt, path_pred) 247 | elif os.path.isdir(path_gt) and os.path.isdir(path_pred): 248 | compute_metrics_on_directories(path_gt, path_pred) 249 | else: 250 | raise ValueError( 251 | "The paths given needs to be two directories or two files.") 252 | 253 | 254 | if __name__ == "__main__": 255 | # parser = argparse.ArgumentParser( 256 | # description="Script to compute ACDC challenge metrics.") 257 | # parser.add_argument("GT_IMG", type=str, help="Ground Truth image") 258 | # parser.add_argument("PRED_IMG", type=str, help="Predicted image") 259 | # args = parser.parse_args() 260 | # main(args.GT_IMG, args.PRED_IMG) 261 | 262 | ############################################################## 263 | gt = np.random.randint(0, 4, size=(224, 224, 100)) 264 | print(np.unique(gt)) 265 | pred = np.array(gt) 266 | pred[pred==2] = 3 267 | result = metrics(gt, pred, voxel_size=(224, 224, 100)) 268 | print(result) -------------------------------------------------------------------------------- /utils/direct_field/__pycache__/df_cardia.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/utils/direct_field/__pycache__/df_cardia.cpython-36.pyc -------------------------------------------------------------------------------- /utils/direct_field/__pycache__/df_cardia.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/utils/direct_field/__pycache__/df_cardia.cpython-37.pyc -------------------------------------------------------------------------------- /utils/direct_field/__pycache__/utils_df.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/utils/direct_field/__pycache__/utils_df.cpython-36.pyc -------------------------------------------------------------------------------- /utils/direct_field/__pycache__/utils_df.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-feng/DirectionalFeature/4476c0053044319404828268ed5e0775bb1d881c/utils/direct_field/__pycache__/utils_df.cpython-37.pyc 
-------------------------------------------------------------------------------- /utils/direct_field/df_cardia.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | import math 4 | import cv2 5 | from PIL import Image 6 | 7 | def direct_field(a, norm=True): 8 | """ a: np.ndarray, (h, w) 9 | """ 10 | if a.ndim == 3: 11 | a = np.squeeze(a) 12 | 13 | h, w = a.shape 14 | 15 | a_Image = Image.fromarray(a) 16 | a = a_Image.resize((w, h), Image.NEAREST) 17 | a = np.array(a) 18 | 19 | accumulation = np.zeros((2, h, w), dtype=np.float32) 20 | for i in np.unique(a)[1:]: 21 | # b, ind = ndimage.distance_transform_edt(a==i, return_indices=True) 22 | # c = np.indices((h, w)) 23 | # diff = c - ind 24 | # dr = np.sqrt(np.sum(diff ** 2, axis=0)) 25 | 26 | img = (a == i).astype(np.uint8) 27 | dst, labels = cv2.distanceTransformWithLabels(img, cv2.DIST_L2, cv2.DIST_MASK_PRECISE, labelType=cv2.DIST_LABEL_PIXEL) 28 | index = np.copy(labels) 29 | index[img > 0] = 0 30 | place = np.argwhere(index > 0) 31 | nearCord = place[labels-1,:] 32 | x = nearCord[:, :, 0] 33 | y = nearCord[:, :, 1] 34 | nearPixel = np.zeros((2, h, w)) 35 | nearPixel[0,:,:] = x 36 | nearPixel[1,:,:] = y 37 | grid = np.indices(img.shape) 38 | grid = grid.astype(float) 39 | diff = grid - nearPixel 40 | if norm: 41 | dr = np.sqrt(np.sum(diff**2, axis = 0)) 42 | else: 43 | dr = np.ones_like(img) 44 | 45 | # direction = np.zeros((2, h, w), dtype=np.float32) 46 | # direction[0, b>0] = np.divide(diff[0, b>0], dr[b>0]) 47 | # direction[1, b>0] = np.divide(diff[1, b>0], dr[b>0]) 48 | 49 | direction = np.zeros((2, h, w), dtype=np.float32) 50 | direction[0, img>0] = np.divide(diff[0, img>0], dr[img>0]) 51 | direction[1, img>0] = np.divide(diff[1, img>0], dr[img>0]) 52 | 53 | accumulation[:, img>0] = 0 54 | accumulation = accumulation + direction 55 | 56 | # mag, angle = cv2.cartToPolar(accumulation[0, ...], accumulation[1, ...]) 57 | # for l in np.unique(a)[1:]: 58 | # mag_i = mag[a==l].astype(float) 59 | # t = 1 / mag_i * mag_i.max() 60 | # mag[a==l] = t 61 | # x, y = cv2.polarToCart(mag, angle) 62 | # accumulation = np.stack([x, y], axis=0) 63 | 64 | return accumulation 65 | 66 | 67 | if __name__ == "__main__": 68 | import matplotlib.pyplot as plt 69 | # gt_p = "/home/ffbian/chencheng/XieheCardiac/npydata/dianfen/16100000/gts/16100000_CINE_segmented_SAX_b3.npy" 70 | # gt = np.load(gt_p)[..., 9] # uint8 71 | # print(gt.shape) 72 | 73 | # a_Image = Image.fromarray(gt) 74 | # a = a_Image.resize((224, 224), Image.NEAREST) 75 | # a = np.array(a) # uint8 76 | # print(a.shape, np.unique(a)) 77 | 78 | # # plt.imshow(a) 79 | # # plt.show() 80 | 81 | # ############################################################ 82 | # direction = direct_field(gt) 83 | 84 | # theta = np.arctan2(direction[1,...], direction[0,...]) 85 | # degree = theta * 180 / math.pi 86 | # degree = (degree + 180) / 360 87 | 88 | # plt.imshow(degree) 89 | # plt.show() 90 | 91 | ######################################################## 92 | import json, time, pdb, h5py 93 | data_list = "/home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets/train_new.json" 94 | data_list = "/root/chengfeng/Cardiac/source_code/libs/datasets/jsonLists/acdcList/Dense_TestList.json" 95 | with open(data_list, 'r') as f: 96 | data_infos = json.load(f) 97 | 98 | mag_stat = [] 99 | st = time.time() 100 | for i, di in enumerate(data_infos): 101 | # img_p, times_idx = di 102 | # gt_p = img_p.replace("/imgs/", "/gts/") 103 | # 
gt = np.load(gt_p)[..., times_idx] 104 | 105 | img = h5py.File(di,'r')['image'] 106 | gt = h5py.File(di,'r')['label'] 107 | gt = np.array(gt).astype(np.float32) 108 | 109 | print(gt.shape) 110 | direction = direct_field(gt, False) 111 | # theta = np.arctan2(direction[1,...], direction[0,...]) 112 | mag, angle = cv2.cartToPolar(direction[0, ...], direction[1, ...]) 113 | # degree = theta * 180 / math.pi 114 | # degree = (degree + 180) / 360 115 | degree = angle / (2 * math.pi) * 255 116 | # degree = (theta - theta.min()) / (theta.max() - theta.min()) * 255 117 | # mag = np.sqrt(np.sum(direction ** 2, axis=0, keepdims=False)) 118 | 119 | 120 | # 归一化 121 | # for l in np.unique(gt)[1:]: 122 | # mag_i = mag[gt==l].astype(float) 123 | # # mag[gt==l] = 1. - mag[gt==l] / np.max(mag[gt==l]) 124 | # t = (mag_i - mag_i.min()) / (mag_i.max() - mag_i.min()) 125 | # mag[gt==l] = np.exp(-10*t) 126 | # print(mag_i.max(), mag_i.min()) 127 | 128 | # for l in np.unique(gt)[1:]: 129 | # mag_i = mag[gt==l].astype(float) 130 | # t = 1 / (mag_i) * mag_i.max() 131 | # # t = np.exp(-0.8*mag_i) * mag_i.max() 132 | # # t = 1 / np.sqrt(mag_i+1) * mag_i.max() 133 | # mag[gt==l] = t 134 | # # print(mag_i.max(), mag_i.min()) 135 | 136 | # mag[mag>0] = 2 * np.exp(-0.8*(mag[mag>0]-1)) 137 | # mag[mag>0] = 2 * np.exp(0.8*(mag[mag>0]-1)) 138 | 139 | 140 | mag_stat.append(mag.max()) 141 | # pdb.set_trace() 142 | 143 | # plt.imshow(degree) 144 | # plt.show() 145 | 146 | ###################### 147 | fig, axs = plt.subplots(1, 3) 148 | axs[0].imshow(degree) 149 | axs[1].imshow(gt) 150 | axs[2].imshow(mag) 151 | plt.show() 152 | 153 | ###################### 154 | if i % 100 == 0: 155 | print("\r\r{}/{} {:.4}s".format(i+1, len(data_infos), time.time()-st)) 156 | print() 157 | 158 | print("total time: ", time.time()-st) 159 | print("Average time: ", (time.time()-st) / len(data_infos)) 160 | # total time: 865.811030626297 161 | # Average time: 0.012969593759126428 162 | 163 | plt.hist(mag_stat) 164 | plt.show() 165 | print(mag_stat) -------------------------------------------------------------------------------- /utils/direct_field/utils_df.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.ndimage import distance_transform_edt as distance 4 | from utils.utils_loss import one_hot, simplex, class2one_hot 5 | 6 | def one_hot2dist(seg: np.ndarray) -> np.ndarray: 7 | assert one_hot(torch.Tensor(seg), axis=0) 8 | C: int = len(seg) 9 | 10 | res = np.zeros_like(seg) 11 | for c in range(C): 12 | posmask = seg[c].astype(np.bool) 13 | 14 | if posmask.any(): 15 | negmask = ~posmask 16 | res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask 17 | return res 18 | 19 | def class2dist(seg: np.ndarray, C=4) -> np.ndarray: 20 | """ res: (C, H, W) 21 | """ 22 | if seg.ndim == 2: 23 | seg_tensor = torch.Tensor(seg) 24 | elif seg.ndim == 3: 25 | seg_tensor = torch.Tensor(seg[0]) 26 | elif seg.ndim == 4: 27 | seg_tensor = torch.Tensor(seg[0, ..., 0]) 28 | 29 | seg_onehot = class2one_hot(seg_tensor, C).to(torch.float32) 30 | 31 | assert simplex(seg_onehot) 32 | res = one_hot2dist(seg_onehot[0].numpy()) 33 | return res 34 | 35 | -------------------------------------------------------------------------------- /utils/image_list.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
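# Note: to_image_list (below) batches variable-sized (C, H, W) tensors by zero-padding
# each one to the per-dimension maximum, rounded up to a multiple of size_divisible;
# e.g. slices of 224x224 and 200x180 with size_divisible=32 are padded into one
# (2, C, 224, 224) tensor, and the original shapes are returned when return_size=True.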
2 | from __future__ import division 3 | 4 | import torch 5 | import numpy as np 6 | 7 | class ImageList(object): 8 | """ 9 | Structure that holds a list of images (of possibly 10 | varying sizes) as a single tensor. 11 | This works by padding the images to the same size, 12 | and storing in a field the original sizes of each image 13 | """ 14 | 15 | def __init__(self, tensors, image_sizes): 16 | """ 17 | Arguments: 18 | tensors (tensor) 19 | image_sizes (list[tuple[int, int]]) 20 | """ 21 | self.tensors = tensors 22 | self.image_sizes = image_sizes 23 | 24 | def to(self, *args, **kwargs): 25 | cast_tensor = self.tensors.to(*args, **kwargs) 26 | return ImageList(cast_tensor, self.image_sizes) 27 | 28 | 29 | def to_image_list(tensors, size_divisible=0, return_size=False): 30 | """ 31 | tensors can be an ImageList, a torch.Tensor or 32 | an iterable of Tensors. It can't be a numpy array. 33 | When tensors is an iterable of Tensors, it pads 34 | the Tensors with zeros so that they have the same 35 | shape 36 | """ 37 | if isinstance(tensors, torch.Tensor) and size_divisible > 0: 38 | tensors = [tensors] 39 | 40 | if isinstance(tensors, ImageList): 41 | return tensors 42 | elif isinstance(tensors, torch.Tensor): 43 | # single tensor shape can be inferred 44 | if tensors.dim() == 3: 45 | tensors = tensors[None] 46 | assert tensors.dim() == 4 47 | image_sizes = [tensor.shape[-2:] for tensor in tensors] 48 | return ImageList(tensors, image_sizes) 49 | elif isinstance(tensors, (tuple, list, np.ndarray)): 50 | max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) 51 | 52 | # TODO Ideally, just remove this and let me model handle arbitrary 53 | # input sizs 54 | if size_divisible > 0: 55 | import math 56 | 57 | stride = size_divisible 58 | max_size = list(max_size) 59 | max_size[1] = int(math.ceil(max_size[1] / stride) * stride) 60 | max_size[2] = int(math.ceil(max_size[2] / stride) * stride) 61 | max_size = tuple(max_size) 62 | 63 | batch_shape = (len(tensors),) + max_size 64 | batched_imgs = tensors[0].new(*batch_shape).zero_() 65 | for img, pad_img in zip(tensors, batched_imgs): 66 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 67 | 68 | image_sizes = [im.shape[-2:] for im in tensors] 69 | 70 | # return ImageList(batched_imgs, image_sizes) 71 | if return_size: 72 | return batched_imgs, image_sizes 73 | else: 74 | return batched_imgs 75 | else: 76 | raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) 77 | 78 | -------------------------------------------------------------------------------- /utils/init_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def init_weights(net, init_type='normal', gain=0.02): 4 | def init_func(m): 5 | classname = m.__class__.__name__ 6 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 7 | if init_type == 'normal': 8 | nn.init.normal_(m.weight.data, 0.0, gain) 9 | elif init_type == 'xavier': 10 | nn.init.xavier_normal_(m.weight.data, gain=gain) 11 | elif init_type == 'kaiming': 12 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 13 | elif init_type == 'orthogonal': 14 | nn.init.orthogonal_(m.weight.data, gain=gain) 15 | else: 16 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 17 | if hasattr(m, 'bias') and m.bias is not None: 18 | nn.init.constant_(m.bias.data, 0.0) 19 | elif classname.find('BatchNorm2d') != -1: 20 | 
nn.init.normal_(m.weight.data, 1.0, gain) 21 | nn.init.constant_(m.bias.data, 0.0) 22 | 23 | print('initialize network with %s' % init_type) 24 | net.apply(init_func) -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from hausdorff import hausdorff_distance 4 | from medpy.metric.binary import hd, dc 5 | 6 | def dice(pred, target): 7 | pred = pred.contiguous() 8 | target = target.contiguous() 9 | smooth = 0.00001 10 | 11 | # intersection = (pred * target).sum(dim=2).sum(dim=2) 12 | pred_flat = pred.view(1, -1) 13 | target_flat = target.view(1, -1) 14 | 15 | intersection = (pred_flat * target_flat).sum().item() 16 | 17 | # loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))) 18 | dice = (2 * intersection + smooth) / (pred_flat.sum().item() + target_flat.sum().item() + smooth) 19 | return dice 20 | 21 | def dice3D(img_gt, img_pred, voxel_size): 22 | """ 23 | Function to compute the metrics between two segmentation maps given as input. 24 | 25 | Parameters 26 | ---------- 27 | img_gt: np.array 28 | Array of the ground truth segmentation map. 29 | 30 | img_pred: np.array 31 | Array of the predicted segmentation map. 32 | 33 | voxel_size: list, tuple or np.array 34 | The size of a voxel of the images used to compute the volumes. 35 | 36 | Return 37 | ------ 38 | A list of metrics in this order, [Dice LV, Volume LV, Err LV(ml), 39 | Dice RV, Volume RV, Err RV(ml), Dice MYO, Volume MYO, Err MYO(ml)] 40 | """ 41 | 42 | if img_gt.ndim != img_pred.ndim: 43 | raise ValueError("The arrays 'img_gt' and 'img_pred' should have the " 44 | "same dimension, {} against {}".format(img_gt.ndim, 45 | img_pred.ndim)) 46 | 47 | res = [] 48 | # Loop on each classes of the input images 49 | for c in [3, 1, 2]: 50 | # Copy the gt image to not alterate the input 51 | gt_c_i = np.copy(img_gt) 52 | gt_c_i[gt_c_i != c] = 0 53 | 54 | # Copy the pred image to not alterate the input 55 | pred_c_i = np.copy(img_pred) 56 | pred_c_i[pred_c_i != c] = 0 57 | 58 | # Clip the value to compute the volumes 59 | gt_c_i = np.clip(gt_c_i, 0, 1) 60 | pred_c_i = np.clip(pred_c_i, 0, 1) 61 | 62 | # Compute the Dice 63 | dice = dc(gt_c_i, pred_c_i) 64 | 65 | # Compute volume 66 | # volpred = pred_c_i.sum() * np.prod(voxel_size) / 1000. 67 | # volgt = gt_c_i.sum() * np.prod(voxel_size) / 1000. 68 | 69 | # res += [dice, volpred, volpred-volgt] 70 | res += [dice] 71 | 72 | return res 73 | 74 | def hd_3D(img_pred, img_gt, labels=[3, 1, 2]): 75 | res = [] 76 | for c in labels: 77 | gt_c_i = np.copy(img_gt) 78 | gt_c_i[gt_c_i != c] = 0 79 | 80 | pred_c_i = np.copy(img_pred) 81 | pred_c_i[pred_c_i != c] = 0 82 | 83 | gt_c_i = np.clip(gt_c_i, 0, 1) 84 | pred_c_i = np.clip(pred_c_i, 0, 1) 85 | 86 | if np.sum(pred_c_i) == 0 or np.sum(gt_c_i) == 0: 87 | hausdorff = 0 88 | else: 89 | hausdorff = hd(pred_c_i, gt_c_i) 90 | 91 | res += [hausdorff] 92 | 93 | return res 94 | 95 | def cal_hausdorff_distance(pred,target): 96 | 97 | pred = np.array(pred.contiguous()) 98 | target = np.array(target.contiguous()) 99 | result = hausdorff_distance(pred,target,distance="euclidean") 100 | 101 | return result 102 | 103 | def make_one_hot(input, num_classes): 104 | """Convert class index tensor to one hot encoding tensor. 
105 | Args: 106 | input: A tensor of shape [N, 1, *] 107 | num_classes: An int of number of class 108 | Returns: 109 | A tensor of shape [N, num_classes, *] 110 | """ 111 | shape = np.array(input.shape) 112 | shape[1] = num_classes 113 | shape = tuple(shape) 114 | result = torch.zeros(shape).scatter_(1, input.cpu().long(), 1) 115 | # result = result.scatter_(1, input.cpu(), 1) 116 | 117 | return result 118 | 119 | def match_pred_gt(pred, gt): 120 | """ pred: (1, C, H, W) 121 | gt: (1, C, H, W) 122 | """ 123 | gt_labels = torch.unique(gt, sorted=True)[1:] 124 | pred_labels = torch.unique(pred, sorted=True)[1:] 125 | 126 | if len(gt_labels) != 0 and len(pred_labels) != 0: 127 | dice_Matrix = torch.zeros((len(pred_labels), len(gt_labels))) 128 | for i, pl in enumerate(pred_labels): 129 | pred_i = torch.tensor(pred==pl, dtype=torch.float) 130 | for j, gl in enumerate(gt_labels): 131 | dice_Matrix[i, j] = dice(make_one_hot(pred_i, 2)[0], make_one_hot(gt==gl, 2)[0]) 132 | 133 | # max_axis0 = np.max(dice_Matrix, axis=0) 134 | max_arg0 = np.argmax(dice_Matrix, axis=0) 135 | else: 136 | return torch.zeros_like(pred) 137 | 138 | pred_match = torch.zeros_like(pred) 139 | for i, arg in enumerate(max_arg0): 140 | pred_match[pred==pred_labels[arg]] = i + 1 141 | return pred_match 142 | 143 | if __name__ == "__main__": 144 | npy_path = "/home/fcheng/Cardia/source_code/logs/logs_df_50000/eval_pp_test/200.npy" 145 | pred_df, gt_df = np.load(npy_p) 146 | -------------------------------------------------------------------------------- /utils/utils_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import List, Set, Iterable 4 | 5 | def uniq(a: Tensor) -> Set: 6 | return set(torch.unique(a.cpu()).numpy()) 7 | 8 | def sset(a: Tensor, sub: Iterable) -> bool: 9 | return uniq(a).issubset(sub) 10 | 11 | def simplex(t: Tensor, axis=1) -> bool: 12 | _sum = t.sum(axis).type(torch.float32) 13 | _ones = torch.ones_like(_sum, dtype=torch.float32) 14 | return torch.allclose(_sum, _ones) 15 | 16 | def one_hot(t: Tensor, axis=1) -> bool: 17 | return simplex(t, axis) and sset(t, [0, 1]) 18 | 19 | # switch between representations 20 | def probs2class(probs: Tensor) -> Tensor: 21 | b, _, w, h = probs.shape # type: Tuple[int, int, int, int] 22 | assert simplex(probs) 23 | 24 | res = probs.argmax(dim=1) 25 | assert res.shape == (b, w, h) 26 | 27 | return res 28 | 29 | def class2one_hot(seg: Tensor, C: int) -> Tensor: 30 | if len(seg.shape) == 2: # Only w, h, used by the dataloader 31 | seg = seg.unsqueeze(dim=0) 32 | assert sset(seg, list(range(C))) 33 | 34 | b, w, h = seg.shape # type: Tuple[int, int, int] 35 | 36 | res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32) 37 | assert res.shape == (b, C, w, h) 38 | assert one_hot(res) 39 | 40 | return res 41 | 42 | def probs2one_hot(probs: Tensor) -> Tensor: 43 | _, C, _, _ = probs.shape 44 | assert simplex(probs) 45 | 46 | res = class2one_hot(probs2class(probs), C) 47 | assert res.shape == probs.shape 48 | assert one_hot(res) 49 | 50 | return res -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import colorsys 3 | import random 4 | import cv2 5 | 6 | def get_n_hls_colors(num): 7 | hls_colors = [] 8 | i = 0 9 | step = 360.0 / num 10 | while i < 360: 11 | h = i 12 | s = 90 + random.random() * 10 13 | l = 
50 + random.random() * 10 14 | _hlsc = [h / 360.0, l / 100.0, s / 100.0] 15 | hls_colors.append(_hlsc) 16 | i += step 17 | 18 | return hls_colors 19 | 20 | def ncolors(num): 21 | rgb_colors = [] 22 | if num < 1: 23 | return np.array(rgb_colors) 24 | hls_colors = get_n_hls_colors(num) 25 | for hlsc in hls_colors: 26 | _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2]) 27 | rgb_colors.append([_r, _g, _b]) 28 | # r, g, b = [int(x * 255.0) for x in (_r, _g, _b)] 29 | # rgb_colors.append([r, g, b]) 30 | 31 | return np.array(rgb_colors) 32 | 33 | def random_colors(N, bright=True): 34 | """ 35 | Generate random colors. 36 | To get visually distinct colors, generate them in HSV space then 37 | convert to RGB. 38 | """ 39 | brightness = 1.0 if bright else 0.7 40 | hsv = [(i / N, 1, brightness) for i in range(N)] 41 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 42 | # random.shuffle(colors) 43 | return colors 44 | 45 | def mask2png(mask, file_name=None, suffix="png"): 46 | """ mask: (w, h) 47 | img_rgb: (w, h, rgb) 48 | """ 49 | nums = np.unique(mask)[1:] 50 | if len(nums) < 1: 51 | colors = np.array([[0,0,0]]) 52 | else: 53 | # colors = ncolors(len(nums)) 54 | colors = (np.array(random_colors(len(nums))) * 255).astype(int) 55 | colors = np.insert(colors, 0, [0,0,0], 0) 56 | 57 | # ensure the values in mask are contiguous 1..N 58 | mask_ordered = np.zeros_like(mask) 59 | for cnt, l in enumerate(nums, 1): 60 | mask_ordered[mask==l] = cnt 61 | 62 | im_rgb = colors[mask_ordered.astype(int)] 63 | if file_name is not None: 64 | cv2.imwrite(file_name+"."+suffix, im_rgb[:, :, ::-1]) 65 | return im_rgb 66 | 67 | def apply_mask(image, mask, color, alpha=0.5, scale=1): 68 | """Apply the given mask to the image. 69 | """ 70 | for c in range(3): 71 | image[:, :, c] = np.where(mask == 1, 72 | image[:, :, c] * 73 | (1 - alpha) + alpha * color[c] * scale, 74 | image[:, :, c]) 75 | return image 76 | 77 | def img_mask_png(image, mask, file_name=None, alpha=0.5, suffix="png"): 78 | """ mask: (h, w) 79 | image: (h, w, rgb) 80 | """ 81 | nums = np.unique(mask)[1:] 82 | if len(nums) < 1: 83 | colors = np.array([[0,0,0]]) 84 | else: 85 | colors = ncolors(len(nums)) 86 | colors = np.insert(colors, 0, [0,0,0], 0) 87 | 88 | # ensure the values in mask are contiguous 1..N 89 | mask_ordered = np.zeros_like(mask) 90 | for cnt, l in enumerate(nums, 1): 91 | mask_ordered[mask==l] = cnt 92 | 93 | # mask_rgb = colors[mask_ordered.astype(int)] 94 | mix_im = image.copy() 95 | for i in np.unique(mask_ordered)[1:]: 96 | mix_im = apply_mask(mix_im, mask_ordered==i, colors[int(i)], alpha=alpha, scale=255) 97 | 98 | if file_name is not None: 99 | cv2.imwrite(file_name+"."+suffix, mix_im[:, :, ::-1]) 100 | return mix_im 101 | 102 | def _find_contour(mask): 103 | # _, contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) # vertices only 104 | _, contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 105 | 106 | cont = np.zeros_like(mask) 107 | for contour in contours: 108 | cont[contour[:,:,1], contour[:,:,0]] = 1 109 | return cont 110 | 111 | def masks_to_contours(masks): 112 | # the mask may contain multiple regions 113 | nums = np.unique(masks)[1:] 114 | cont_mask = np.zeros_like(masks) 115 | for i in nums: 116 | cont_mask += _find_contour(masks==i) 117 | return (cont_mask>0).astype(int) 118 | 119 | def batchToColorImg(batch, minv=None, maxv=None, scale=255.): 120 | """ batch: (N, H, W, C) 121 | """ 122 | if batch.ndim == 3: 123 | N, H, W = batch.shape 124 | elif batch.ndim == 4: 125 | N, H, W, _ = batch.shape 126 | colorImg = 
np.zeros(shape=(N, H, W, 3)) 127 | for i in range(N): 128 | if minv is None: 129 | a = (batch[i] - batch[i].min()) / (batch[i].max() - batch[i].min()) * 255 130 | else: 131 | a = (batch[i] - minv) / (maxv - minv) * scale 132 | a = cv2.applyColorMap(a.astype(np.uint8), cv2.COLORMAP_JET) 133 | colorImg[i, ...] = a[..., ::-1] / 255. 134 | return colorImg 135 | 136 | if __name__ == "__main__": 137 | a = np.zeros((100, 100)) 138 | a[0:5, 3:8] = 1 139 | a[75:85, 85:95] = 2 140 | 141 | # colors = ncolors(2)[::-1] 142 | colors = np.array(random_colors(2)) * 255 143 | colors = np.insert(colors, 0, [0, 0, 0], 0) 144 | 145 | b = colors[a.astype(int)].astype(np.uint8) 146 | import cv2, skimage 147 | # skimage.io.imsave("test_io.png", b) 148 | # # cv2.imwrite("test.jpg", b[:, :, ::-1]) 149 | # print() 150 | 151 | # mask2png(a, "test") 152 | 153 | ################################################ 154 | # img_mask_png(b, a, "test") 155 | 156 | ############################################# 157 | # cont_mask = find_contours(a) 158 | # print() 159 | # # skimage.io.imsave("test_cont.png", cont_mask) 160 | # b[cont_mask>0, :] = [255, 255, 255] 161 | # skimage.io.imsave("test_cont.png", b) 162 | 163 | gt0 = skimage.io.imread("gt0.png", as_gray=False) 164 | print() 165 | gt0[gt0==54] = 0 166 | # cont_mask = find_contours(gt0==237) # Array([ 18, 54, 73, 237], dtype=uint8) 167 | # cont_mask += find_contours(gt0==18) 168 | # cont_mask += find_contours(gt0==73) 169 | cont_mask = masks_to_contours(gt0) 170 | 171 | colors = np.array(random_colors(1)) * 255 172 | colors = np.insert(colors, 0, [0, 0, 0], 0) 173 | cont_mask = colors[cont_mask.astype(int)].astype(np.uint8) 174 | 175 | skimage.io.imsave("test_cont.png", cont_mask) 176 | --------------------------------------------------------------------------------
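# Usage sketch for utils/vis_utils.py (illustrative; the toy mask is an assumption and
# the calls assume the OpenCV 3.x findContours API that _find_contour targets):
import numpy as np
from utils.vis_utils import mask2png, masks_to_contours

mask = np.zeros((64, 64), dtype=np.uint8)
mask[8:24, 8:24] = 1      # first structure
mask[40:56, 40:56] = 2    # second structure
rgb = mask2png(mask)                  # (64, 64, 3) array, one random colour per label
contours = masks_to_contours(mask)    # binary outline map of both regions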