├── Create_Weight.py ├── LICENSE ├── README.md ├── configs ├── btrfly.yaml └── defaults.py ├── datasets └── prepare_val.py ├── figures └── 1.png ├── input.py ├── models ├── __init__.py ├── btrfly_net.py └── eb_discriminator.py ├── test.py ├── train.py └── utils ├── checkpoint.py ├── data.py ├── inference.py ├── logger.py ├── metrics.py ├── misc.py └── trainer.py /Create_Weight.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import scipy.io as io 4 | import nibabel as nib 5 | import numpy as np 6 | 7 | 8 | INPUT = '/shenlab/local/zhenghan/pj_wx/BtrflyNet/datasets/Label2D/train/' 9 | OUTPUT = '/shenlab/local/zhenghan/pj_wx/BtrflyNet/datasets/Label2D/weight/' 10 | INPUT_SEG = '/shenlab/local/zhenghan/original/seg/' 11 | 12 | threshold = 0.6 13 | 14 | def get_train_list(path): 15 | file_list_train = [] 16 | 17 | for root, dirs, files in os.walk(path): 18 | for file in files: 19 | if file.startswith('.'): 20 | continue 21 | else: 22 | if (file[0:-4] == 'verse113') | (file[0:-4] == 'verse104') | (file[0:-4] == 'verse201'): 23 | continue 24 | else: 25 | file_list_train.append(os.path.join(file)) 26 | 27 | wholesize = len(file_list_train) 28 | 29 | return file_list_train, wholesize 30 | 31 | def get_gt_heatmap(file_list_train): 32 | front = [] 33 | side = [] 34 | for file in file_list_train: 35 | path = INPUT + file 36 | front.append((io.loadmat(path))['front']) 37 | side.append((io.loadmat(path))['side']) 38 | return front, side 39 | 40 | def get_seg_data(file_list_train): 41 | seg = [] 42 | 43 | for file in file_list_train: 44 | path = INPUT_SEG + file[0:-4] + '_seg.nii' 45 | imgseg = nib.load(path) 46 | seg.append(imgseg.get_fdata()) 47 | return seg 48 | 49 | def frequency_statistics(gt_heatmap_list): 50 | label_pixel = np.zeros(25) 51 | #background_pixel = np.zeros(25) 52 | all_pixel = 0 53 | for gt_heatmap in gt_heatmap_list: 54 | all_pixel += gt_heatmap.shape[0] * gt_heatmap.shape[1] ##* gt_heatmap.shape[2] 55 | #channel_pixel = gt_heatmap.shape[0] * gt_heatmap.shape[1] 56 | gt_label = np.zeros((gt_heatmap.shape[0],gt_heatmap.shape[1])) 57 | for i in range(24): 58 | label = i + 1 59 | gt_channel = np.where(gt_heatmap[:,:,label] > threshold, 1, 0) 60 | gt_label = np.where(gt_heatmap[:,:,label] > threshold, 1, gt_label) 61 | label_pixel[label] += gt_channel.sum() 62 | #background_pixel[label] += (channel_pixel - gt_label.sum()) 63 | background = np.where(gt_label == 0, 1, 0) 64 | label_pixel[0] += background.sum() 65 | 66 | label_freq = label_pixel / float(all_pixel) 67 | #background_freq = background_pixel / float(all_pixel) 68 | 69 | return label_freq 70 | 71 | def get_save_weighted_para(file_list, front_heatmap_list, side_heatmap_list, front_freq, side_freq, save_path): 72 | front_medium = np.median(front_freq) 73 | side_medium = np.median(side_freq) 74 | idx = 0 75 | for file in file_list: 76 | front_heatmap = front_heatmap_list[idx] 77 | side_heatmap = side_heatmap_list[idx] 78 | outpath = save_path + file 79 | 80 | 81 | w, h, c = front_heatmap.shape[0], front_heatmap.shape[1], front_heatmap.shape[2] 82 | front_weight = np.zeros(c) 83 | #gt_label = np.zeros((front_heatmap.shape[0], front_heatmap.shape[1])) 84 | for label in range(25): 85 | #label = i + 1 86 | front_weight[label] = front_medium/front_freq[label] 87 | #gt_label = np.where(front_heatmap[:,:,label] > threshold, 1, gt_label) 88 | #front_weight[:,:,0]= np.where(gt_label == 0, front_medium/front_freq[0], 0) 89 | 90 | w, h, c = side_heatmap.shape[0], 
side_heatmap.shape[1], side_heatmap.shape[2] 91 | side_weight = np.zeros(c) 92 | #gt_label = np.zeros((side_heatmap.shape[0], side_heatmap.shape[1])) 93 | for label in range(25): 94 | #label = i + 1 95 | side_weight[label] = side_medium / side_freq[label] 96 | #gt_label = np.where(side_heatmap[:, :, label] > threshold, 1, gt_label) 97 | #side_weight[:, :, 0] = np.where(gt_label == 0, side_medium / side_freq[0], 0) 98 | 99 | 100 | io.savemat(outpath,{'front':front_weight, 'side':side_weight}) 101 | idx += 1 102 | 103 | 104 | 105 | 106 | 107 | 108 | file_list_train, train_len = get_train_list(INPUT) 109 | front, side = get_gt_heatmap(file_list_train) 110 | #seg = get_seg_data(file_list_train) 111 | front_label_freq = frequency_statistics(front) 112 | side_label_freq = frequency_statistics(side) 113 | front_label_medium = np.median(front_label_freq) 114 | side_label_medium = np.median(side_label_freq) 115 | front_weight = front_label_medium / front_label_freq 116 | side_weight = side_label_medium / side_label_freq 117 | io.savemat(OUTPUT + 'train_weight.mat',{'front':front_weight, 'side':side_weight}) 118 | 119 | #get_save_weighted_para(file_list_train, front, side, front_label_freq, side_label_freq, OUTPUT) 120 | 121 | 122 | """ 123 | testfront, testside, testsegfront, testsegside = front[1], side[1], np.max(seg[1],axis=0), np.max(seg[1],axis=2) 124 | 125 | #segfront_20 = np.where(testsegfront == 20, 1, 0) 126 | #segside_20 = np.where(testsegside == 20, 1, 0) 127 | for i in range(7): 128 | index = i + 18 129 | front_20 = testfront[:, :, index] 130 | side_20 = testside[:, :, index] 131 | front_20_r = np.where(testsegfront == index, front_20, 3) 132 | side_20_r = np.where(testsegside == index, side_20, 3) 133 | front_min = np.min(front_20_r) 134 | side_min = np.min(side_20_r) 135 | print(front_min,side_min) 136 | """ 137 | a = 1 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Xin Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Btrfly-Net-Pytorch 2 | This repository implements Butterfly Net (for vertebra localization) in PyTorch 1.0 or higher. 
3 | 4 | Paper: [Sekuboyina A, et al.: Btrfly Net: Vertebrae Labelling with Energy-based Adversarial Learning of Local Spine Prior. In: MICCAI. (2018) ](https://arxiv.org/abs/1804.01307v2) 5 | 6 | 7 | 8 | ## Highlights 9 |
10 | - **Error Correction**: In view of some mistakes related to the structure of Btrfly Net in the original paper (Sekuboyina A, et al.), we correct Figure 2 of the paper and build the correct structure in PyTorch after contacting the authors several times.
11 | - **Method Improvement**: We propose a new position-inference method to improve the performance of the model. During training you can see two kinds of id rate results, one using the paper's method and one using the method we propose. *Roughly speaking, we obtain the final vertebra positions from a weighted average of the positions inferred from the two 2D heat maps, instead of from their direct outer product. This improves the id rate by nearly 5% in our experiments; a minimal sketch of the idea is given at the end of the Training section below.*
12 | - **Configs**: We use [YACS](https://pypi.org/project/yacs/) to manage the parameters, including the devices, the hyperparameters of the model and some directory paths. See configs/btrfly.yaml; you can change them freely.
13 | - **Smooth and Enjoyable Training Procedure**: We save the states of the model, optimizer, scheduler and training iteration, so you can stop training and resume exactly from the saved point without changing your training `CMD`.
14 | - **TensorboardX**: We support tensorboardX, and the log directory is outputs/tf_logs. If you don't want to use it, just set the parameter `--use_tensorboard` to `0`, as described in the "Train Using Your Parameters" section.
15 | - **Evaluating during training**: Evaluate your model every `eval_step` to check whether performance is improving.
16 | 17 | ## Dataset 18 | 19 | See [Verse2019 challenge](https://verse2019.grand-challenge.org/Data/) for more information. 20 | 21 | The original data directory should have the following structure. 22 |
23 | ``` 24 | OriginalPath 25 | |__ raw 26 | |_ 001.nii 27 | |_ 002.nii 28 | |_ ... 29 | |__ seg 30 | |_ 001_seg.nii 31 | |_ 002_seg.nii 32 | |_ ... 33 | |__ pos 34 | |_ 001_ctd.json 35 | |_ 002_ctd.json 36 | |_ ... 37 | ``` 38 |
39 | - Each data sample in the dataset consists of the following files: 40 | - - 1. verse.nii.gz - Image 41 | 2. verse_seg.nii.gz - Segmentation Mask 42 | 3. verse_ctd.json - Centroid annotations 43 | 4. verse_snapshot - A PNG overview of the annotations. 44 |
45 | - The images need NOT be in the same orientation. Their spacing need NOT be the same. However, an image and its corresponding mask will be in the same orientation. 46 |
47 | - Both masks and centroids are linked with the label values [1-24], corresponding to the vertebrae [C1-L5]. Some cases might contain the label 25 (L6).
48 | - The centroid annotations are given with respect to coordinate axes fixed on an isotropic (1 mm) scan in a (P, I, R) or (A, S, L) orientation, described as: 49 | - - 1. Origin at Superior (S) - Right (R) - Anterior (A) 50 | 2. 'X' corresponds to the S -> I direction 51 | 3. 'Y' corresponds to the A -> P direction 52 | 4. 'Z' corresponds to the R -> L direction 53 | 5. 'label' corresponds to the vertebral label 54 |
55 | ## Training 56 | 57 | ### Train Directly 58 | 59 | ```python 60 | python train.py 61 | ``` 62 | 63 | ### Train Using Your Parameters 64 | 65 | For example, you can change the `save_step` by running 66 | 67 | ``` 68 | python train.py --save_step 10 69 | ``` 70 | 71 | See train.py for more changeable parameters.
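
### Changing Parameters via the Config File

Besides command-line arguments, the hyperparameters and paths listed in `configs/defaults.py` can be changed persistently in `configs/btrfly.yaml`, which train.py merges over the defaults via `cfg.merge_from_file`. For example (the values below are placeholders for illustration, not recommended settings):

```yaml
SOLVER:
  LR: 1.0e-4       # overrides _C.SOLVER.LR from configs/defaults.py
  BATCH_SIZE: 4
TEST:
  BATCH_SIZE: 4
```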
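
### Sketch: Weighted-Average Position Inference

The snippet below is a minimal NumPy sketch of the weighted-average inference idea described in the Highlights section, written only for illustration: the helper name `infer_positions`, the array shapes and the axis assignment are assumptions of this example, not the repository's API. The actual inference is done by the `pred_pos_3` function used in test.py and utils/inference.py, which additionally accounts for padding, cropping, spacing and orientation.

```python
import numpy as np

def infer_positions(heat_sag, heat_cor):
    """Toy illustration: fuse per-vertebra sagittal and coronal heat maps.

    heat_sag, heat_cor: arrays of shape (24, H, W); rows of both views are
    assumed to run along the shared vertical axis. Returns (24, 3) positions
    and a (24,) confidence score.
    """
    positions = np.zeros((24, 3))
    scores = np.zeros(24)
    for v in range(24):
        # Peak location and peak response of vertebra v in each 2D projection.
        r_s, c_s = np.unravel_index(np.argmax(heat_sag[v]), heat_sag[v].shape)
        r_c, c_c = np.unravel_index(np.argmax(heat_cor[v]), heat_cor[v].shape)
        w_s = heat_sag[v, r_s, c_s]
        w_c = heat_cor[v, r_c, c_c]
        # Each in-plane coordinate comes from its own view; the shared vertical
        # coordinate is a confidence-weighted average of the two views, instead
        # of the argmax of the outer product of the two heat maps.
        vertical = (w_s * r_s + w_c * r_c) / (w_s + w_c + 1e-8)
        positions[v] = (c_c, c_s, vertical)  # axis order is illustrative only
        scores[v] = w_s * w_c
    return positions, scores
```

A view with a weak or ambiguous peak therefore contributes less to the shared coordinate of that vertebra.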
72 | 73 | ## Result 74 | 75 | The prediction result will be saved as a `.pth` file in the `pred_list` directory. You can set the parameter `is_test` in test.py to 0 or 1 to determine if the trained model is used for validation set or test set. 76 | 77 | -------------------------------------------------------------------------------- /configs/btrfly.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | DEVICE: 'cuda' 3 | USE_GAN: 0 4 | IMAGE_SIZE: 256 5 | USE_BN: 1 6 | CHANNELS: (1, 32, 64, 128, 256, 256, 512, 1024, 512, 512, 256, 128, 64, 25) 7 | 8 | SOLVER: 9 | LR: 3e-4 10 | WEIGHT_DECAY: 1e-3 11 | MAX_ITER: 500000 12 | BATCH_SIZE: 2 13 | SAVE_NUM: 5 14 | 15 | TEST: 16 | BATCH_SIZE: 2 17 | DEVICE: 'cuda' 18 | 19 | ORIGINAL_PATH: '/shenlab/local/zhenghan/original/' 20 | MAT_DIR_TRAIN: 'datasets/Label2DRes/train/' 21 | INPUT_IMG_DIR_TRAIN: 'datasets/JPEGRes/train/' 22 | 23 | 24 | CROP_INFO_DIR: 'datasets/Crop_info/' 25 | MAT_DIR_VAL: 'datasets/Label2DRes/val/' 26 | INPUT_IMG_DIR_VAL: 'datasets/JPEGRes/val/' 27 | 28 | MAT_DIR_TEST: 'datasets/Label2DRes/test/' 29 | INPUT_IMG_DIR_TEST: 'datasets/JPEGRes/test/' 30 | 31 | OUTPUT_DIR: 'outputs_16/' 32 | TRAIN_WEIGHT: "/shenlab/local/zhenghan/pj_wx/BtrflyNet/datasets/Label2D/weight/train_weight_8.mat" -------------------------------------------------------------------------------- /configs/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | 5 | _C.TRAIN_WEIGHT = "/shenlab/local/zhenghan/pj_wx/BtrflyNet/datasets/Label2D/weight/train_weight.mat" 6 | 7 | _C.MODEL = CN() 8 | _C.MODEL.DEVICE = "cuda" 9 | _C.MODEL.USE_GAN = 0 10 | _C.MODEL.IMAGE_SIZE = 512 11 | _C.MODEL.USE_BN = 1 12 | _C.MODEL.CHANNELS = (1, 32, 64, 128, 256, 256, 512, 1024, 512, 512, 256, 128, 64, 25) 13 | 14 | _C.SOLVER = CN() 15 | _C.SOLVER.LR = 1e-3 16 | _C.SOLVER.WEIGHT_DECAY = 1e-3 17 | _C.SOLVER.MAX_ITER = 120000 18 | _C.SOLVER.BATCH_SIZE = 16 19 | _C.SOLVER.SAVE_NUM = 25 20 | 21 | _C.TEST = CN() 22 | _C.TEST.BATCH_SIZE = 10 23 | _C.TEST.DEVICE = 'cuda' 24 | 25 | _C.ORIGINAL_PATH = '/shenlab/local/zhenghan/original/' 26 | _C.MAT_DIR_TRAIN = "datasets/Label2D/train/" 27 | _C.INPUT_IMG_DIR_TRAIN = "datasets/JPEGUncover_train/" 28 | 29 | _C.CROP_INFO_DIR = 'datasets/Crop_info' 30 | _C.MAT_DIR_VAL = 'datasets/VOC2007/Label2D_val/' 31 | _C.INPUT_IMG_DIR_VAL = 'datasets/VOC2007/JPEGUncover_val/' 32 | 33 | _C.MAT_DIR_TEST = 'datasets/' 34 | _C.INPUT_IMG_DIR_TEST = 'datasets/' 35 | 36 | _C.OUTPUT_DIR = "outputs" 37 | 38 | cfg = _C -------------------------------------------------------------------------------- /datasets/prepare_val.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import os 4 | 5 | file_list = glob.glob(pathname="Label2D/*.mat") 6 | rand_num = np.random.rand(len(file_list)) 7 | 8 | #os.system("mv " + "Label2D/val_backup/* " + "Label2D/") 9 | #os.system("mv " + "Label2D/train_backup/* " + "Label2D/") 10 | #os.system("mv " + "JPEGCrop/val/* " + "JPEGCrop/") 11 | #os.system("mv " + "JPEGCrop/train/* " + "JPEGCrop/") 12 | 13 | for i in range(len(file_list)): 14 | if rand_num[i] <= 0.25: 15 | os.system("mv " + file_list[i] + " Label2D/val_backup/") 16 | #os.system("mv " + "JPEGCrop/" + file_list[i][8:16] + "*.jpg " + "JPEGCrop/val/") 17 | else: 18 | os.system("mv " + file_list[i] + " Label2D/train_backup/") 19 | #os.system("mv " + "JPEGCrop/" + 
file_list[i][8:16] + "*.jpg " + "JPEGCrop/train/") 20 | -------------------------------------------------------------------------------- /figures/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wxdrizzle/Btrfly-Net-Pytorch/bd59a02bf94cce235a47f7ddb5e689327d0de435/figures/1.png -------------------------------------------------------------------------------- /input.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import nibabel as nib 4 | import SimpleITK as sitk 5 | import json 6 | import scipy.io as io 7 | import imageio 8 | 9 | 10 | def get_verse_list(path): 11 | """ 12 | get the list of the vertebrate data 13 | form eg---> verse006.nii 14 | :param path: 15 | :return: 16 | """ 17 | file_list = [] 18 | file_list_train = [] 19 | file_list_vali = [] 20 | file_list_train_pos = [] 21 | file_list_vali_pos = [] 22 | for root, dirs, files in os.walk(path): 23 | for file in files: 24 | if file.startswith('.'): 25 | continue 26 | else: 27 | if (file == 'verse113.nii') | (file == 'verse104.nii') | (file == 'verse201.nii'): 28 | continue 29 | else: 30 | file_list.append(os.path.join(file)) 31 | 32 | 33 | wholesize = len(file_list) 34 | index = range(wholesize) 35 | train_len = wholesize 36 | vali_len = wholesize - train_len 37 | train_index = index[0:train_len] 38 | 39 | vali_index = index[train_len:] 40 | train_index = sorted(train_index) 41 | vali_index = sorted(vali_index) 42 | for idx in train_index: 43 | file_list_train.append(file_list[idx]) 44 | file_list_train_pos.append(file_list[idx][0:-4] + '_ctd.json') 45 | for idx in vali_index: 46 | file_list_vali.append(file_list[idx]) 47 | file_list_vali_pos.append(file_list[idx][0:-4] + '_ctd.json') 48 | file_list_train.sort() 49 | file_list_train_pos.sort() 50 | return file_list_train, file_list_train_pos, train_len 51 | 52 | def get_centroid_pos(raw, w, h, c, fileJson ): 53 | """ 54 | :param raw: from sitk 55 | :param w: from nib 56 | :param h: 57 | :param c: 58 | :param labelidx: 59 | :return: 60 | """ 61 | Dic = {0: 'Z', 1: 'Y', 2: 'X'} 62 | direction = np.round(list(raw.GetDirection())) 63 | direc0 = direction[0:7:3] 64 | direc1 = direction[1:8:3] 65 | direc2 = direction[2:9:3] 66 | dim0char = Dic[(np.argwhere((np.abs(direc0 )) == 1))[0][0]] 67 | dim1char = Dic[(np.argwhere((np.abs(direc1 )) == 1))[0][0]] 68 | dim2char = Dic[(np.argwhere((np.abs(direc2 )) == 1))[0][0]] 69 | resolution = raw.GetSpacing() 70 | label = fileJson['label'] 71 | if np.sum(direc0) == -1: 72 | if dim0char == 'X': 73 | dim0 = fileJson['X']/resolution[0] 74 | else: 75 | dim0 = w - fileJson[dim0char]/resolution[0] 76 | else: 77 | if dim0char == 'X': 78 | dim0 = w - fileJson['X']/resolution[0] 79 | else: 80 | dim0 = fileJson[dim0char]/resolution[0] 81 | 82 | if np.sum(direc1) == -1: 83 | if dim1char == 'X': 84 | dim1 = fileJson['X']/resolution[1] 85 | else: 86 | dim1 = h - fileJson[dim1char]/resolution[1] 87 | else: 88 | if dim1char == 'X': 89 | dim1 = h - fileJson['X']/resolution[1] 90 | else: 91 | dim1 = fileJson[dim1char]/resolution[1] 92 | 93 | if np.sum(direc2) == -1: 94 | if dim2char == 'X': 95 | dim2 = fileJson['X']/resolution[2] 96 | else: 97 | dim2 = c - fileJson[dim2char]/resolution[2] 98 | else: 99 | if dim2char == 'X': 100 | dim2 = c - fileJson['X']/resolution[2] 101 | else: 102 | dim2 = fileJson[dim2char]/resolution[2] 103 | 104 | return label, int(dim0), int(dim1), int(dim2) 105 | 106 | def image_mode(img_path): 
107 | Dic = {0:'Z', 1:'Y', 2:'X'} 108 | img = sitk.ReadImage(img_path) 109 | direction = np.round(list(img.GetDirection())) 110 | direc0 = direction[0:7:3] 111 | direc1 = direction[1:8:3] 112 | direc2 = direction[2:9:3] 113 | 114 | dim0_char = Dic[(np.argwhere((np.abs(np.round(direc0))) == 1))[0][0]] 115 | dim1_char = Dic[(np.argwhere((np.abs(np.round(direc1))) == 1))[0][0]] 116 | dim2_char = Dic[(np.argwhere((np.abs(np.round(direc2))) == 1))[0][0]] 117 | 118 | return dim0_char, dim1_char, dim2_char 119 | 120 | def prepare_SSD_input(img_path='/shenlab/local/zhenghan/original/', is_test=False): 121 | """ 122 | generate heat map and save it to "datasets/VOC2007/Label2D/" as .mat file 123 | :input: 124 | img_path: path of original directory 125 | """ 126 | resolution = 2.0 127 | nii_path = 'test/' if is_test else 'raw/' 128 | raw_file_list, pos_file_list, file_num = get_verse_list(img_path + nii_path) 129 | for i in range(file_num): 130 | raw_file_name, pos_file_name = raw_file_list[i], pos_file_list[i] 131 | print(raw_file_name) 132 | 133 | direc = image_mode(img_path + nii_path + raw_file_name) 134 | if (direc != ('Z', 'Y', 'X')) & (direc != ('Y', 'X', 'Z')): 135 | raise Exception('Unknown direction!') 136 | 137 | img_handle_nib = nib.load(img_path + nii_path + raw_file_name) 138 | 139 | img_handle_sitk = sitk.ReadImage(img_path + 'raw/' + raw_file_name) 140 | spacing = img_handle_sitk.GetSpacing() 141 | w, h, c = img_handle_sitk.GetSize() 142 | 143 | #new_size_w = round(w * spacing[0] / resolution) 144 | #new_size_h = round(h * spacing[1] / resolution) 145 | #new_size_c = round(c * spacing[2] / resolution) 146 | new_size_w, new_size_h, new_size_c = w, h, c 147 | 148 | 149 | img_raw = img_handle_nib.get_fdata() 150 | #img_same_res = transform.resize(img_raw, (new_size_w, new_size_h, new_size_c)) 151 | img_same_res = img_raw 152 | imageio.imwrite("../SSD/datasets/VOC2007/JPEGImages/" + raw_file_name[:-4] + "_0.jpg", np.max(img_same_res, axis=0)) 153 | imageio.imwrite("../SSD/datasets/VOC2007/JPEGImages/" + raw_file_name[:-4] + "_1.jpg", np.max(img_same_res, axis=1)) 154 | imageio.imwrite("../SSD/datasets/VOC2007/JPEGImages/" + raw_file_name[:-4] + "_2.jpg", np.max(img_same_res, axis=2)) 155 | 156 | def prepare_heat_map(img_path='/shenlab/local/zhenghan/original/'): 157 | resolution = 2.0 158 | raw_file_list, pos_file_list, file_num = get_verse_list(img_path + 'raw') 159 | for i in range(file_num): 160 | raw_file_name, pos_file_name = raw_file_list[i], pos_file_list[i] 161 | print(raw_file_name) 162 | 163 | direc = image_mode(img_path + 'raw/' + raw_file_name) 164 | if (direc != ('Z', 'Y', 'X')) & (direc != ('Y', 'X', 'Z')): 165 | raise Exception('Unknown direction!') 166 | 167 | img_handle_nib = nib.load(img_path + 'raw/' + raw_file_name) 168 | 169 | img_handle_sitk = sitk.ReadImage(img_path + 'raw/' + raw_file_name) 170 | spacing = img_handle_sitk.GetSpacing() 171 | w, h, c = img_handle_sitk.GetSize() 172 | 173 | if direc == ('Z', 'Y', 'X'): 174 | new_size_side_0 = new_size_h 175 | new_size_side_1 = new_size_c 176 | new_size_front_0 = new_size_w 177 | new_size_front_1 = new_size_c 178 | elif direc == ('Y', 'X', 'Z'): 179 | new_size_side_0 = new_size_w 180 | new_size_side_1 = new_size_h 181 | new_size_front_0 = new_size_h 182 | new_size_front_1 = new_size_c 183 | else: 184 | raise Exception('Unknown direction!') 185 | 186 | pos_file_handle = open(img_path + 'pos/' + pos_file_name, "rb") 187 | pos_file_json = json.load(pos_file_handle) 188 | 189 | label_side = np.zeros((new_size_side_0, 
new_size_side_1, 25)) 190 | label_front = np.zeros((new_size_front_0, new_size_front_1, 25)) 191 | 192 | x_side, y_side = np.meshgrid(range(new_size_side_0), range(new_size_side_1), indexing='ij') 193 | x_front, y_front = np.meshgrid(range(new_size_front_0), range(new_size_front_1), indexing='ij') 194 | 195 | for idx in pos_file_json: 196 | print(idx) 197 | label, location_x, location_y, location_z = get_centroid_pos(img_handle_sitk, w, h, c, idx) 198 | #location_x = location_x * spacing[0] / resolution 199 | #location_y = location_y * spacing[1] / resolution 200 | #location_z = location_z * spacing[2] / resolution 201 | if label == 25: 202 | continue 203 | if direc == ('Z', 'Y', 'X'): 204 | label_side[:, :, label] = np.exp(-((x_side - location_y) ** 2 + (y_side-location_z) ** 2) * 0.02) 205 | label_front[:, :, label] = np.exp(-((x_front - location_x) ** 2 + (y_front - location_z) ** 2) * 0.02) 206 | else: 207 | label_side[:, :, label] = np.exp(-((x_side - location_x) ** 2 + (y_side - location_y) ** 2) * 0.02) 208 | label_front[:, :, label] = np.exp(-((x_front - location_y) ** 2 + (y_front - location_z) ** 2) * 0.02) 209 | 210 | label_side[:, :, 0] = 1 - np.max(label_side[:, :, 1:25], axis=2) 211 | label_front[:, :, 0] = 1 - np.max(label_front[:, :, 1:25], axis=2) 212 | 213 | assert label_side.shape[0] in (w, h, c) 214 | assert label_side.shape[1] in (w, h, c) 215 | assert label_front.shape[0] in (w, h, c) 216 | assert label_front.shape[1] in (w, h, c) 217 | assert img_same_res.shape == (w, h, c) 218 | 219 | imageio.imwrite('datasets/Snapshot/' + raw_file_name[0:8] + '_side.jpg', np.sum(label_side[:, :, 1:25], axis=2)) 220 | imageio.imwrite('datasets/Snapshot/' + raw_file_name[0:8] + '_front.jpg', np.sum(label_front[:, :, 1:25], axis=2)) 221 | imageio.imwrite('datasets/Snapshot/' + raw_file_name[0:8] + '_side_bkgd.jpg', label_side[:, :, 0]) 222 | imageio.imwrite('datasets/Snapshot/' + raw_file_name[0:8] + '_front_bkgd.jpg', label_front[:, :, 0]) 223 | 224 | io.savemat('datasets/Label2D/' + raw_file_name[0:8] + ".mat", 225 | {"side": label_side, "front": label_front}) 226 | 227 | 228 | 229 | 230 | if __name__ == '__main__': 231 | prepare_SSD_input() -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .btrfly_net import BtrflyNet 2 | from .eb_discriminator import EBGAN 3 | 4 | 5 | def build_model(cfg, name="Btrfly"): 6 | if name == "EBGAN": 7 | return EBGAN() 8 | return BtrflyNet(cfg) -------------------------------------------------------------------------------- /models/btrfly_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def crop(input1, input2): 5 | assert input1.shape[0] == input2.shape[0] 6 | assert input1.shape[2] - input2.shape[2] in (0, 1) 7 | assert input1.shape[3] - input2.shape[3] in (0, 1) 8 | 9 | return (input1[:, :, :input2.shape[2], :input2.shape[3]], input2) 10 | 11 | class conv_blk(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, drop_out): 13 | super(conv_blk, self).__init__() 14 | self.blk = nn.Sequential( 15 | nn.Conv2d( 16 | in_channels=in_channels, out_channels=out_channels, 17 | kernel_size=kernel_size, stride=stride, padding=padding, 18 | ), 19 | nn.ReLU(inplace=True), 20 | nn.Dropout(0.5 if drop_out else 0), 21 | nn.BatchNorm2d(num_features=out_channels), 22 | ) 23 | 24 | def forward(self, 
input): 25 | output = self.blk(input) 26 | return output 27 | 28 | class deconv_blk(nn.Module): 29 | def __init__(self, in_channels): 30 | super(deconv_blk, self).__init__() 31 | self.blk = nn.Sequential( 32 | nn.ConvTranspose2d( 33 | in_channels=in_channels, out_channels=in_channels, 34 | kernel_size=4, stride=2, padding=1, 35 | ), 36 | nn.ReLU(inplace=True), 37 | ) 38 | 39 | def forward(self, input): 40 | output = self.blk(input) 41 | return output 42 | 43 | class green(nn.Module): 44 | def __init__(self, cfg, pos): 45 | super(green, self).__init__() 46 | self.conv =conv_blk(in_channels=cfg.MODEL.CHANNELS[pos], 47 | out_channels=cfg.MODEL.CHANNELS[pos+1], 48 | kernel_size=1 if pos == 12 else 3, 49 | stride=1, 50 | padding=0 if pos == 12 else 1, 51 | drop_out=True if pos in (4, 5) else False, 52 | ) 53 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 54 | 55 | def forward(self, input): 56 | output_side = self.conv(input) 57 | output_side_pad = nn.functional.pad(output_side, (0, (output_side.shape[3] % 2), 0, (output_side.shape[2] % 2), 0, 0, 0, 0)) 58 | output_main = self.pool(output_side_pad) 59 | return output_main, output_side 60 | 61 | class purple(nn.Module): 62 | def __init__(self, cfg, pos): 63 | super(purple, self).__init__() 64 | self.deconv = deconv_blk(in_channels=cfg.MODEL.CHANNELS[pos]) 65 | self.conv = conv_blk(in_channels=cfg.MODEL.CHANNELS[pos] + (cfg.MODEL.CHANNELS[13 - pos] if pos < 9 else cfg.MODEL.CHANNELS[12 - pos]), 66 | out_channels=cfg.MODEL.CHANNELS[pos+1], 67 | kernel_size=3, stride=1, padding=1, drop_out=False, 68 | ) 69 | 70 | def forward(self, input_main, input_side): 71 | output = self.deconv(input_main) 72 | 73 | output = torch.cat(crop(output, input_side), dim=1) 74 | output = self.conv(output) 75 | return output 76 | 77 | class red(nn.Module): 78 | def __init__(self, cfg, pos): 79 | super(red, self).__init__() 80 | self.blk = nn.Conv2d( 81 | in_channels=cfg.MODEL.CHANNELS[pos], 82 | out_channels=cfg.MODEL.CHANNELS[pos+1], 83 | kernel_size=1,stride=1, padding=0, 84 | ) 85 | 86 | def forward(self, input): 87 | output = self.blk(input) 88 | return output 89 | 90 | 91 | 92 | class in_arm(nn.Module): 93 | def __init__(self, cfg): 94 | super(in_arm, self).__init__() 95 | self.green0 = green(cfg, pos=0) 96 | self.green1 = green(cfg, pos=1) 97 | self.green2 = green(cfg, pos=2) 98 | 99 | def forward(self, input): 100 | output_main, output_side_0 = self.green0(input) 101 | output_main, output_side_1 = self.green1(output_main) 102 | output_main, output_side_2 = self.green2(output_main) 103 | return output_main, output_side_0, output_side_1, output_side_2 104 | 105 | 106 | class out_arm(nn.Module): 107 | def __init__(self, cfg): 108 | super(out_arm, self).__init__() 109 | self.purple0 = purple(cfg, pos=9) 110 | self.purple1 = purple(cfg, pos=10) 111 | self.purple2 = purple(cfg, pos=11) 112 | self.red = red(cfg, pos=12) 113 | 114 | def forward(self, input_main, input_side_2, input_side_1, input_side_0): 115 | output = self.purple0(input_main, input_side_2) 116 | output = self.purple1(output, input_side_1) 117 | output = self.purple2(output, input_side_0) 118 | output = self.red(output) 119 | return output 120 | 121 | 122 | class body(nn.Module): 123 | def __init__(self, cfg): 124 | super(body, self).__init__() 125 | self.green0 = green(cfg, pos=4) 126 | self.green1 = green(cfg, pos=5) 127 | self.conv = conv_blk(in_channels=cfg.MODEL.CHANNELS[6], 128 | out_channels=cfg.MODEL.CHANNELS[7], 129 | kernel_size=3, stride=1, padding=1, drop_out=True, 130 | ) 131 | 
self.purple0 = purple(cfg, pos=7) 132 | self.purple1 = purple(cfg, pos=8) 133 | 134 | def forward(self, input_sag, input_cor): 135 | output = torch.cat(crop(input_sag, input_cor), dim=1) 136 | output, side_0 = self.green0(output) 137 | output, side_1 = self.green1(output) 138 | output = self.conv(output) 139 | output = self.purple0(output, side_1) 140 | output = self.purple1(output, side_0) 141 | return output 142 | 143 | 144 | class BtrflyNet(nn.Module): 145 | def __init__(self, cfg): 146 | super(BtrflyNet, self).__init__() 147 | self.input_arm_sag = in_arm(cfg) 148 | self.input_arm_cor = in_arm(cfg) 149 | self.body = body(cfg) 150 | self.output_arm_sag = out_arm(cfg) 151 | self.output_arm_cor = out_arm(cfg) 152 | 153 | def forward(self, sag, cor): 154 | sag_body, sag_side0, sag_side1, sag_side2 = self.input_arm_sag(sag) 155 | cor_body, cor_side0, cor_side1, cor_side2 = self.input_arm_cor(cor) 156 | body_out = self.body(sag_body, cor_body) 157 | out_sag = self.output_arm_sag(body_out, sag_side2, sag_side1, sag_side0) 158 | out_cor = self.output_arm_cor(body_out, cor_side2, cor_side1, cor_side0) 159 | return out_sag, out_cor -------------------------------------------------------------------------------- /models/eb_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class EBGAN(nn.Module): 5 | """ 6 | 7 | """ 8 | def __init__(self): 9 | super(EBGAN, self).__init__() 10 | 11 | self.avgpool3d = nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 12 | self.architecture = nn.Sequential( 13 | # 14 | nn.Conv3d(in_channels=1, out_channels=5, kernel_size=(5, 5, 5), padding=(2, 2, 2), stride=(1, 1, 1)), 15 | nn.LeakyReLU(), 16 | nn.BatchNorm3d(5), 17 | # 18 | nn.Conv3d(in_channels=5, out_channels=10, kernel_size=(5, 5, 5), padding=(2, 2, 2), stride=(1, 1, 1)), 19 | nn.AvgPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)), 20 | nn.LeakyReLU(), 21 | nn.BatchNorm3d(10), 22 | # 23 | nn.Conv3d(in_channels=10, out_channels=10, kernel_size=(5, 5, 5), padding=(2, 4, 4), dilation=(1, 2, 2)), 24 | nn.AvgPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)), 25 | nn.LeakyReLU(), 26 | nn.BatchNorm3d(10), 27 | # 28 | nn.Conv3d(in_channels=10, out_channels=10, kernel_size=(5, 5, 5), padding=(2, 4, 4), dilation=(1, 2, 2)), 29 | nn.LeakyReLU(), 30 | nn.BatchNorm3d(10), 31 | # 32 | nn.ConvTranspose3d(in_channels=10, out_channels=5, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)), 33 | nn.LeakyReLU(), 34 | nn.BatchNorm3d(5), 35 | # 36 | nn.ConvTranspose3d(in_channels=5, out_channels=5, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)), 37 | nn.LeakyReLU(), 38 | nn.BatchNorm3d(5), 39 | # 40 | nn.Conv3d(in_channels=5, out_channels=1, kernel_size=(1,1,1)) 41 | 42 | ) 43 | 44 | 45 | def forward(self, input): 46 | input2 = input[:, 1:25, :, :].view(input.shape[0], 1, 24, input.shape[2], input.shape[3]) 47 | reduce_input = self.avgpool3d(input2) 48 | output = self.architecture(reduce_input) 49 | D = pow((reduce_input - output),2) 50 | #D.view(-1)#D.sum(-1) 51 | return D.view(-1).sum(-1) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from configs.defaults import cfg 5 | from utils.misc import mkdir 6 | from utils.logger import * 7 | from utils.checkpoint import CheckPointer 8 | from models import build_model 9 | from input import 
get_verse_list 10 | from utils.data import * 11 | from torchvision import transforms, utils 12 | import scipy.misc 13 | from torch.utils.data import DataLoader 14 | from utils.trainer import do_train 15 | import imageio 16 | from utils.metrics import * 17 | import torch.nn.functional as F 18 | import scipy.io 19 | 20 | @torch.no_grad() 21 | def pred(cfg): 22 | device = torch.device(cfg.TEST.DEVICE) 23 | model = build_model(cfg).to(device) 24 | lr = cfg.SOLVER.LR 25 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 26 | arguments = {"iteration": 0, "epoch": 0} 27 | checkpointer = CheckPointer(model, optimizer, cfg.OUTPUT_DIR) 28 | extra_checkpoint_data = checkpointer.load(is_val=True) 29 | arguments.update(extra_checkpoint_data) 30 | 31 | model.eval() 32 | is_test = 0 33 | dataset = ProjectionDataset(cfg=cfg, mat_dir=cfg.MAT_DIR_TEST if is_test else cfg.MAT_DIR_VAL 34 | , input_img_dir=cfg.INPUT_IMG_DIR_TEST if is_test else cfg.INPUT_IMG_DIR_VAL, 35 | transform=transforms.Compose([ToTensor()])) 36 | test_loader = DataLoader(dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, num_workers=4) 37 | mkdir(os.path.join(cfg.OUTPUT_DIR, "jpg_val")) 38 | #os.system("rm " + os.path.join(cfg.OUTPUT_DIR, "jpg_val/*")) 39 | name_list = [] 40 | whole_step_list = [] 41 | score_list = [] 42 | position_cor_list = [] 43 | position_sag_list = [] 44 | 45 | if is_test == 0: 46 | val_file_list = glob.glob(cfg.MAT_DIR_VAL + '*.mat') 47 | val_file_list.sort() 48 | 49 | gt_label_list = [] 50 | for idx in range(len(val_file_list)): 51 | gt_label_list.append(json.load( 52 | open(cfg.ORIGINAL_PATH + 'pos/' + val_file_list[idx][len(cfg.MAT_DIR_VAL):-4] + '_ctd.json', "rb"))) 53 | 54 | for idx, sample in enumerate(test_loader): 55 | print(idx) 56 | input_cor = sample["input_cor"].float().to(device) 57 | input_sag = sample["input_sag"].float().to(device) 58 | sag_pad = sample["sag_pad"] 59 | cor_pad = sample["cor_pad"] 60 | if is_test == 0: 61 | gt_cor = sample["gt_cor"].float().to(device) 62 | gt_sag = sample["gt_sag"].float().to(device) 63 | output_sag, output_cor = model(input_sag, input_cor) 64 | 65 | for batch_num in range(input_cor.shape[0]): 66 | output_sag[batch_num, :, :sag_pad[2][batch_num], :] = 0 67 | output_sag[batch_num, :, :, output_sag.shape[3] - sag_pad[1][batch_num]:] = 0 68 | output_sag[batch_num, :, output_sag.shape[2] - sag_pad[3][batch_num]:, :] = 0 69 | output_sag[batch_num, :, :, :sag_pad[0][batch_num]] = 0 70 | 71 | output_cor[batch_num, :, :cor_pad[2][batch_num], :] = 0 72 | output_cor[batch_num, :, :, output_cor.shape[3] - cor_pad[1][batch_num]:] = 0 73 | output_cor[batch_num, :, output_cor.shape[2] - cor_pad[3][batch_num]:, :] = 0 74 | output_cor[batch_num, :, :, :cor_pad[0][batch_num]] = 0 75 | 76 | if is_test: 77 | for j in range(input_cor.shape[0]): 78 | imageio.imwrite(cfg.OUTPUT_DIR + "jpg_val/" + sample['name'][j] + "_input_cor.jpg", torch.squeeze(input_cor[j, :, :]).cpu().detach().numpy()) 79 | imageio.imwrite(cfg.OUTPUT_DIR + "jpg_val/" + sample['name'][j] + "_input_sag.jpg", torch.squeeze(input_sag[j, :, :]).cpu().detach().numpy()) 80 | imageio.imwrite(cfg.OUTPUT_DIR + "jpg_val/" + sample['name'][j] + "_output_cor.jpg", 30 * np.max(torch.squeeze(output_cor[j, 1:25, :, :]).cpu().detach().numpy(), axis=0)) 81 | imageio.imwrite(cfg.OUTPUT_DIR + "jpg_val/" + sample['name'][j] + "_output_sag.jpg", 30 * np.max(torch.squeeze(output_sag[j, 1:25, :, :]).cpu().detach().numpy(), axis=0)) 82 | 83 | 84 | #for c_num in range(24): 85 | #output_sag[:, c_num + 1, :, :] = output_sag[:, 
c_num+1, :, :] * (output_sag[:, 0, :, :].max() - output_sag[:, 0, :, :]) 86 | #output_cor[:, c_num + 1, :, :] = output_cor[:, c_num + 1, :, :] * (output_cor[:, 0, :, :].max() - output_cor[:, 0, :, :]) 87 | position, position_batch_cor , position_batch_sag = pred_pos_3(device, output_sag[:, 1:25, :, :], output_cor[:, 1:25, :, :], sample['direction'], 88 | sample['crop_info'], sample['spacing'], sample['cor_pad'], sample['sag_pad']) 89 | # position = pred_pos(device, output_sag[:, 1:25, :, :], output_cor[:, 1:25, :, :], sample['direction'], 90 | # sample['crop_info'], sample['spacing'], sample['cor_pad'], sample['sag_pad']) 91 | 92 | if idx == 0: 93 | for step in range(position.shape[0]): 94 | whole_step_list.append([]) 95 | score_list.append([]) 96 | 97 | for j in range(input_sag.shape[0]): 98 | position_cor_list.append(position_batch_cor[j, :, :]) 99 | position_sag_list.append(position_batch_sag[j, :, :]) 100 | 101 | for j in range(input_sag.shape[0]): 102 | for step in range(position.shape[0]): 103 | whole_step_list[step].append( 104 | create_centroid_pos([sample['direction_sitk'][0][j], sample['direction_sitk'][1][j], sample['direction_sitk'][2][j], 105 | sample['direction_sitk'][3][j], sample['direction_sitk'][4][j], sample['direction_sitk'][5][j], 106 | sample['direction_sitk'][6][j], sample['direction_sitk'][7][j], sample['direction_sitk'][8][j]], 107 | [sample['spacing'][0][j], sample['spacing'][1][j], sample['spacing'][2][j]], 108 | [sample['size_raw'][0][j], sample['size_raw'][1][j], sample['size_raw'][2][j]], 109 | position[step, j, :, 0:3]) 110 | ) 111 | score_list[step].append(position[step, j, :, 3]) 112 | name_list.append(sample["name"][j]) 113 | 114 | id_rate = list(range(position.shape[0])) 115 | id_rate_gt = list(range(position.shape[0])) 116 | if is_test == 0: 117 | for step in range(position.shape[0]): 118 | id_rate[step], id_rate_gt[step] = Get_Identification_Rate(gt_label_list, whole_step_list[step]) 119 | 120 | if is_test: 121 | torch.save({"pred_list": whole_step_list[0], 'score': score_list[0], 122 | 'pred_cor_list': position_cor_list, 'pred_sag_list':position_sag_list, 123 | 'name':name_list}, "pred_list/pred_test.pth") 124 | else: 125 | print("id_rate: ", id_rate) 126 | print("id_rate_gt: ", id_rate_gt) 127 | torch.save({"pred_list": whole_step_list[0], 'score': score_list[0], 'name': name_list, 'gt_list': gt_label_list, 128 | 'pred_cor_list': position_cor_list, 'pred_sag_list':position_sag_list}, 129 | "pred_list/pred.pth") 130 | 131 | 132 | 133 | 134 | def main(): 135 | torch.cuda.empty_cache() 136 | # some configs, including yaml file 137 | parser = argparse.ArgumentParser(description='Btrfly Net Training with Pytorch') 138 | parser.add_argument( 139 | "--config_file", 140 | default="configs/btrfly.yaml", 141 | metavar="FILE", 142 | help="path to config file", 143 | type=str, 144 | ) 145 | parser.add_argument("--log_step", default=1, type=int, help="print logs every log_step") 146 | parser.add_argument("--save_step", default=50, type=int, help="save checkpoint every save_step") 147 | parser.add_argument("--eval_step", default=10, type=int, help="evaluate dataset every eval_step, disabled if eval_step <= 0") 148 | parser.add_argument("--use_tensorboard", default=1, type=int, help="use visdom to illustrate training process, unless use_visdom == 0") 149 | args = parser.parse_args() 150 | 151 | # enable inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware 152 | # so it helps increase training speed 153 | if torch.cuda.is_available(): 154 | 
torch.backends.cudnn.benchmark = True 155 | 156 | # use YACS as the config manager, see https://github.com/rbgirshick/yacs for more info 157 | # cfg contains all the configs set by configs/defaults and overrided by config_file (see line 13) 158 | cfg.merge_from_file(args.config_file) 159 | cfg.freeze() 160 | # make output directory designated by OUTPUT_DIR if necessary 161 | if cfg.OUTPUT_DIR: 162 | mkdir(cfg.OUTPUT_DIR) 163 | 164 | # set up 2 loggers 165 | # logger_all can print time and logger's name 166 | # logger_message only print message 167 | # it will print info to stdout and to OUTPUT_DIR/log.txt (way: append) 168 | logger_all = setup_colorful_logger( 169 | "main", 170 | save_dir=os.path.join(cfg.OUTPUT_DIR, 'log.txt'), 171 | format="include_other_info") 172 | logger_message = setup_colorful_logger( 173 | "main_message", 174 | save_dir=os.path.join(cfg.OUTPUT_DIR, 'log.txt'), 175 | format="only_message") 176 | 177 | # print config info (cfg and args) 178 | # args are obtained by command line 179 | # cfg is obtained by yaml file and defaults.py in configs/ 180 | separator(logger_message) 181 | logger_message.warning(" ---------------------------------------") 182 | logger_message.warning("| Your config: |") 183 | logger_message.warning(" ---------------------------------------") 184 | logger_message.info(args) 185 | logger_message.warning(" ---------------------------------------") 186 | logger_message.warning("| Running with entire config: |") 187 | logger_message.warning(" ---------------------------------------") 188 | logger_message.info(cfg) 189 | separator(logger_message) 190 | 191 | pred(cfg) 192 | 193 | if __name__ == '__main__': 194 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from configs.defaults import cfg 4 | from utils.misc import mkdir 5 | from utils.logger import * 6 | from utils.checkpoint import CheckPointer 7 | from models import build_model 8 | from utils.data import * 9 | from torchvision import transforms 10 | from torch.utils.data import DataLoader 11 | from utils.trainer import do_train 12 | 13 | def train(cfg, args): 14 | # set default device 15 | device = torch.device(cfg.MODEL.DEVICE) 16 | # build Butterfly Net as [model] 17 | model = build_model(cfg, name="Btrfly").to(device) 18 | 19 | #build discriminator nets as [model_D1] and [model_D2] if necessary 20 | model_D1, model_D2 = None, None 21 | if cfg.MODEL.USE_GAN: 22 | model_D1 = build_model(cfg, name="EBGAN").to(device) 23 | model_D2 = build_model(cfg, name="EBGAN").to(device) 24 | print(model_D1) 25 | 26 | #if you need to visualize the Net, uncomment these codes 27 | """ 28 | input1 = torch.rand(3, 1, 128, 128) 29 | input2 = torch.rand(3, 1, 128, 128) 30 | with SummaryWriter(comment='BtrflyNet') as w: 31 | w.add_graph(model, (input1, input2, )) 32 | """ 33 | 34 | # learning rate 35 | lr = cfg.SOLVER.LR 36 | # optimizer of [model] 37 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 38 | #optimizers of [model_D1] and [model D2] if necessary 39 | optimizer_D1, optimizer_D2 = None, None 40 | if cfg.MODEL.USE_GAN: 41 | optimizer_D1 = torch.optim.Adam(model_D1.parameters(), lr=lr) 42 | optimizer_D2 = torch.optim.Adam(model_D2.parameters(), lr=lr) 43 | 44 | # update [checkpointer] if necessary 45 | # except iteration and epoch numbers, 46 | # [arguments] also has a list which contains 
the information of the best several models, 47 | # including their numbers and their validation losses 48 | arguments = {"iteration": 0, "epoch": 0, "list_loss_val": {}} 49 | checkpointer = CheckPointer(model, optimizer, cfg.OUTPUT_DIR) 50 | extra_checkpoint_data = checkpointer.load() 51 | arguments.update(extra_checkpoint_data) 52 | 53 | # build training set from the directory designated by cfg 54 | dataset = ProjectionDataset(cfg=cfg, 55 | mat_dir=cfg.MAT_DIR_TRAIN, 56 | input_img_dir=cfg.INPUT_IMG_DIR_TRAIN, 57 | transform=transforms.Compose([ToTensor()]), 58 | ) 59 | train_loader = DataLoader(dataset, batch_size=cfg.SOLVER.BATCH_SIZE, shuffle=True, num_workers=4) 60 | 61 | 62 | return do_train(cfg, args, model, model_D1, model_D2, train_loader, optimizer, optimizer_D1, optimizer_D2, checkpointer, device, arguments) 63 | 64 | 65 | 66 | 67 | def main(): 68 | torch.cuda.empty_cache() 69 | # some configs, including yaml file 70 | parser = argparse.ArgumentParser(description='Btrfly Net Training with Pytorch') 71 | parser.add_argument( 72 | "--config_file", 73 | default="configs/btrfly.yaml", 74 | metavar="FILE", 75 | help="path to config file", 76 | type=str, 77 | ) 78 | parser.add_argument("--log_step", default=1, type=int, help="print logs every log_step") 79 | parser.add_argument("--save_step", default=5, type=int, help="save checkpoint every save_step") 80 | parser.add_argument("--eval_step", default=5, type=int, help="evaluate dataset every eval_step, disabled if eval_step <= 0") 81 | parser.add_argument("--use_tensorboard", default=1, type=int, help="use visdom to illustrate training process, unless use_visdom == 0") 82 | parser.add_argument("--train_from_no_checkpoint", default=1, type=int, help="train_from_no_checkpoint") 83 | args = parser.parse_args() 84 | 85 | # enable inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware 86 | # so it helps increase training speed 87 | if torch.cuda.is_available(): 88 | torch.backends.cudnn.benchmark = True 89 | 90 | # use YACS as the config manager, see https://github.com/rbgirshick/yacs for more info 91 | # cfg contains all the configs set by configs/defaults and overrided by config_file 92 | cfg.merge_from_file(args.config_file) 93 | cfg.freeze() 94 | # make output directory designated by OUTPUT_DIR if necessary 95 | if cfg.OUTPUT_DIR: 96 | mkdir(cfg.OUTPUT_DIR) 97 | # if you need, this removes the results related to last training 98 | if args.train_from_no_checkpoint: 99 | os.system("rm -r " + os.path.join(cfg.OUTPUT_DIR, "*")) 100 | 101 | # logger_message help print message 102 | # it will also print info to stdout and to OUTPUT_DIR/log.txt (way: append) 103 | logger_message = setup_colorful_logger( 104 | "main_message", 105 | save_dir=os.path.join(cfg.OUTPUT_DIR, 'log.txt'), 106 | format="only_message") 107 | 108 | # print config info (cfg and args) 109 | # args are obtained by command line 110 | # cfg is obtained by yaml file and defaults.py in configs/ 111 | separator(logger_message) 112 | logger_message.warning(" ---------------------------------------") 113 | logger_message.warning("| Your config: |") 114 | logger_message.warning(" ---------------------------------------") 115 | logger_message.info(args) 116 | logger_message.warning(" ---------------------------------------") 117 | logger_message.warning("| Running with entire config: |") 118 | logger_message.warning(" ---------------------------------------") 119 | logger_message.info(cfg) 120 | separator(logger_message) 121 | 122 | train(cfg, args) 123 | 124 | 
if __name__ == '__main__': 125 | main() -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | from utils.logger import * 2 | import os 3 | import torch 4 | 5 | class CheckPointer: 6 | _last_checkpoint_name = 'last_checkpoint.txt' 7 | _best_checkpoint_name = 'best_checkpoint.txt' 8 | def __init__(self, model, optimizer, save_dir=""): 9 | self.model = model 10 | self.optimizer = optimizer 11 | self.save_dir = save_dir 12 | self.logger = setup_colorful_logger("checkpointer", save_dir=os.path.join(save_dir, 'log.txt'), format="include_other_info") 13 | 14 | def save(self, name, is_last, is_best, **kwargs): 15 | if not self.save_dir: 16 | return 17 | 18 | data = {} 19 | data['model'] = self.model.state_dict() 20 | if self.optimizer is not None: 21 | data["optimizer"] = self.optimizer.state_dict() 22 | data.update(kwargs) 23 | 24 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 25 | torch.save(data, save_file) 26 | self.tag_last_checkpoint(save_file, is_last=is_last, is_best=is_best) 27 | 28 | def load(self, f=None, use_latest=True, is_val=False): 29 | if self.has_checkpoint() and use_latest: 30 | if is_val: 31 | f = self.get_checkpoint_file(is_val=True) 32 | else: 33 | f = self.get_checkpoint_file() 34 | if not f: 35 | self.logger.warning("No checkpoint found.") 36 | return {} 37 | 38 | self.logger.warning("Loading checkpoint from {}".format(f)) 39 | checkpoint = torch.load(f, map_location=torch.device("cpu")) 40 | model = self.model 41 | model.load_state_dict(checkpoint.pop("model")) 42 | if "optimizer" in checkpoint and self.optimizer: 43 | self.logger.warning("Loading optimizer from {}".format(f)) 44 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 45 | return checkpoint 46 | 47 | 48 | def has_checkpoint(self): 49 | save_file = os.path.join(self.save_dir, self._last_checkpoint_name) 50 | return os.path.exists(save_file) 51 | 52 | def get_checkpoint_file(self, is_val=False): 53 | if is_val: 54 | save_file = os.path.join(self.save_dir, self._best_checkpoint_name) 55 | else: 56 | save_file = os.path.join(self.save_dir, self._last_checkpoint_name) 57 | try: 58 | with open (save_file, 'r') as f: 59 | last_saved = f.read() 60 | last_saved = last_saved.strip() 61 | except IOError: 62 | last_saved = "" 63 | return last_saved 64 | 65 | def tag_last_checkpoint(self, last_filename, is_last, is_best): 66 | if is_last: 67 | save_file = os.path.join(self.save_dir, self._last_checkpoint_name) 68 | with open(save_file, "w") as f: 69 | f.write(last_filename) 70 | elif is_best: 71 | save_file = os.path.join(self.save_dir, self._best_checkpoint_name) 72 | with open(save_file, "w") as f: 73 | f.write(last_filename) -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | from input import * 3 | import SimpleITK as sitk 4 | import torch.nn.functional as func 5 | 6 | def get_image_mode_length(img_path): 7 | Dic = {0:'Z', 1:'Y', 2:'X'} 8 | img = sitk.ReadImage(img_path) 9 | direction = np.round(list(img.GetDirection())) 10 | direc0 = direction[0:7:3] 11 | direc1 = direction[1:8:3] 12 | direc2 = direction[2:9:3] 13 | 14 | dim0_char = Dic[(np.argwhere((np.abs(np.round(direc0))) == 1))[0][0]] 15 | dim1_char = Dic[(np.argwhere((np.abs(np.round(direc1))) == 1))[0][0]] 16 | dim2_char = 
Dic[(np.argwhere((np.abs(np.round(direc2))) == 1))[0][0]] 17 | 18 | if [dim0_char, dim1_char, dim2_char] == ['Y', 'X', 'Z']: 19 | dimlength = 1 20 | elif [dim0_char, dim1_char, dim2_char] == ['Z', 'Y', 'X']: 21 | dimlength = 2 22 | 23 | return dimlength 24 | 25 | class ProjectionDataset(Dataset.Dataset): 26 | def __init__(self, cfg, mat_dir=None, input_img_dir=None, transform=None): 27 | """ 28 | :param mat_dir: (string) directory with all the .mat file containing heat maps 29 | :param input_img_dir: (string) directory with all the images 30 | :param transform: (callable, optional) optional transform to be applied on a sample 31 | """ 32 | self.cfg = cfg 33 | self.mat_dir = mat_dir 34 | self.mat_list = glob.glob(pathname=os.path.join(mat_dir, "*.mat")) 35 | self.mat_list.sort() 36 | self.not_test = 'test' not in self.mat_list[0] 37 | 38 | self.input_img_dir = input_img_dir 39 | self.input_img_list = glob.glob(pathname=os.path.join(input_img_dir, "*.jpg")) 40 | self.input_img_list.sort() 41 | 42 | self.raw_file_dir = cfg.ORIGINAL_PATH + 'raw/' if self.not_test else cfg.ORIGINAL_PATH + 'test/' 43 | self.raw_file_list, self.pos_file_list, self.file_num = get_verse_list(self.raw_file_dir) 44 | 45 | self.crop_info_file_dir = cfg.CROP_INFO_DIR 46 | 47 | if "train" in input_img_dir: 48 | self.no_bg_file_dir = glob.glob(pathname=input_img_dir + "../train/*.jpg") 49 | self.no_bg_file_dir.sort() 50 | elif "val" in input_img_dir: 51 | self.no_bg_file_dir = glob.glob(pathname=input_img_dir + "../val/*.jpg") 52 | self.no_bg_file_dir.sort() 53 | else: 54 | pass 55 | 56 | if ("test" not in mat_dir) & (len(self.input_img_list) != 2 * len(self.mat_list)): 57 | raise Exception("Length error! The number of imgs should be 2 times the number of mat files.") 58 | 59 | self.transform = transform 60 | 61 | def __len__(self): 62 | return len(self.mat_list) 63 | 64 | def __getitem__(self, idx): 65 | res = 1.0 66 | not_test = 'test' not in self.mat_list[0] 67 | name = self.mat_list[idx][-12:-4] 68 | input_cor = imageio.imread(self.input_img_list[2 * idx]) 69 | input_sag = imageio.imread(self.input_img_list[2 * idx + 1]) 70 | 71 | label_2D_front = loadmat(self.mat_list[idx])["front"] 72 | label_2D_side = loadmat(self.mat_list[idx])["side"] 73 | raw_sitk = sitk.ReadImage(self.raw_file_dir + self.mat_list[idx][-12:-4] + ".nii") 74 | direction_sitk = raw_sitk.GetDirection() 75 | direction = image_mode(self.raw_file_dir + self.mat_list[idx][-12:-4] + ".nii") 76 | spacing = raw_sitk.GetSpacing() 77 | size_raw = raw_sitk.GetSize() 78 | crop_info = loadmat(self.crop_info_file_dir + self.mat_list[idx][-12:]) 79 | 80 | sample = {'not_test': not_test, 'input_cor': input_cor, 'input_sag': input_sag, 'gt_cor': label_2D_front, 81 | 'gt_sag': label_2D_side, 82 | 'direction_sitk': direction_sitk, 'direction': direction, 'spacing': spacing, 83 | 'size_raw': size_raw, 'crop_info': crop_info, 'name': name} 84 | 85 | if self.transform: 86 | sample = self.transform(sample) 87 | 88 | return sample 89 | 90 | class ToTensor(object): 91 | def __call__(self, sample): 92 | input_cor, input_sag, gt_cor, gt_sag = sample['input_cor'], sample['input_sag'], sample['gt_cor'], sample['gt_sag'] 93 | gt_cor = gt_cor.transpose((2, 0, 1)) if sample['not_test'] else gt_cor 94 | gt_sag = gt_sag.transpose((2, 0, 1)) if sample['not_test'] else gt_sag 95 | 96 | d0_uni = max(608, input_cor.shape[0], input_sag.shape[0]) 97 | d1_uni = max(608, input_cor.shape[1], input_sag.shape[1]) 98 | cor_pad_d0 = d0_uni - input_cor.shape[0] 99 | cor_pad_d1 = d1_uni - 
input_cor.shape[1] 100 | sag_pad_d0 = d0_uni - input_sag.shape[0] 101 | sag_pad_d1 = d1_uni - input_sag.shape[1] 102 | 103 | cor_pad = (cor_pad_d1 // 2 + (cor_pad_d1 % 2), cor_pad_d1 // 2, cor_pad_d0 // 2 + (cor_pad_d0 % 2), cor_pad_d0 // 2) 104 | sag_pad = (sag_pad_d1 // 2 + (sag_pad_d1 % 2), sag_pad_d1 // 2, sag_pad_d0 // 2 + (sag_pad_d0 % 2), sag_pad_d0 // 2) 105 | 106 | 107 | input_cor_padded = func.pad(torch.from_numpy(input_cor), cor_pad) 108 | input_sag_padded = func.pad(torch.from_numpy(input_sag), sag_pad) 109 | 110 | input_cor_padded = input_cor_padded.reshape((1, input_cor_padded.shape[0], input_cor_padded.shape[1])) 111 | input_sag_padded = input_sag_padded.reshape((1, input_sag_padded.shape[0], input_sag_padded.shape[1])) 112 | 113 | if sample['not_test']: 114 | gt_cor_padded = func.pad(torch.from_numpy(gt_cor), cor_pad) 115 | gt_sag_padded = func.pad(torch.from_numpy(gt_sag), sag_pad) 116 | 117 | assert input_cor_padded.shape[1] == d0_uni, (input_cor_padded.shape[1], d0_uni) 118 | assert input_cor_padded.shape[2] == d1_uni, (input_cor_padded.shape[2], d1_uni) 119 | assert input_sag_padded.shape[1] == d0_uni, (input_sag_padded.shape[1], d0_uni) 120 | assert input_sag_padded.shape[2] == d1_uni, (input_cor_padded.shape[2], d1_uni) 121 | 122 | if sample['not_test']: 123 | return {'input_cor':input_cor_padded, 'input_sag': input_sag_padded, 124 | 'gt_cor':gt_cor_padded, 'gt_sag':gt_sag_padded, 125 | 'direction_sitk':sample['direction_sitk'], 'direction':sample['direction'], 126 | 'spacing':sample['spacing'], 'size_raw':sample['size_raw'], 'crop_info':sample['crop_info'], 127 | "cor_pad":cor_pad, "sag_pad":sag_pad, "name":sample["name"]} 128 | else: 129 | return {'input_cor':input_cor_padded, 'input_sag': input_sag_padded, 130 | 'direction_sitk':sample['direction_sitk'], 'direction':sample['direction'], 131 | 'spacing':sample['spacing'], 'size_raw':sample['size_raw'], 'crop_info':sample['crop_info'], 132 | "cor_pad":cor_pad, "sag_pad":sag_pad, "name":sample["name"]} 133 | 134 | 135 | -------------------------------------------------------------------------------- /utils/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from configs.defaults import cfg 5 | from utils.misc import mkdir 6 | from utils.logger import * 7 | from utils.checkpoint import CheckPointer 8 | from models import build_model 9 | from input import get_verse_list 10 | from utils import trainer 11 | from utils.metrics import * 12 | from utils.data import * 13 | from torchvision import transforms, utils 14 | import scipy.misc 15 | from torch.utils.data import DataLoader 16 | import glob 17 | import torch.nn.functional as F 18 | 19 | @torch.no_grad() 20 | def do_evaluation(cfg, model, summary_writer, global_step): 21 | device = torch.device(cfg.MODEL.DEVICE) 22 | model.eval() 23 | 24 | w = loadmat(cfg.TRAIN_WEIGHT) 25 | w_front, w_side = torch.Tensor(w["front"]).to(device), torch.Tensor(w["side"]).to(device) 26 | 27 | dataset = ProjectionDataset(cfg=cfg, mat_dir=cfg.MAT_DIR_VAL, input_img_dir=cfg.INPUT_IMG_DIR_VAL, 28 | transform=transforms.Compose([ToTensor()])) 29 | val_loader = DataLoader(dataset, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, num_workers=4) 30 | 31 | val_loss = 0 32 | val_num = 0 33 | whole_step_list = [] 34 | whole_step_list_softmax = [] 35 | whole_step_list_norm = [] 36 | val_file_list = glob.glob(cfg.MAT_DIR_VAL + '*.mat') 37 | val_file_list.sort() 38 | 39 | gt_label_list = [] 40 | for idx in 
range(len(val_file_list)): 41 | gt_label_list.append(json.load(open(cfg.ORIGINAL_PATH + 'pos/' + val_file_list[idx][len(cfg.MAT_DIR_VAL):-4] + '_ctd.json', "rb"))) 42 | 43 | for idx, sample in enumerate(val_loader): 44 | input_cor = sample["input_cor"].float().to(device) 45 | input_sag = sample["input_sag"].float().to(device) 46 | gt_cor = sample["gt_cor"].float().to(device) 47 | gt_sag = sample["gt_sag"].float().to(device) 48 | cor_pad = sample["cor_pad"] 49 | sag_pad = sample["sag_pad"] 50 | 51 | output_sag, output_cor = model(input_sag, input_cor) 52 | 53 | for batch_num in range(gt_cor.shape[0]): 54 | output_sag[batch_num, :, :sag_pad[2][batch_num], :] = 0 55 | output_sag[batch_num, :, :, output_sag.shape[3] - sag_pad[1][batch_num]:] = 0 56 | output_sag[batch_num, :, output_sag.shape[2] - sag_pad[3][batch_num]:, :] = 0 57 | output_sag[batch_num, :, :, :sag_pad[0][batch_num]] = 0 58 | 59 | output_cor[batch_num, :, :cor_pad[2][batch_num], :] = 0 60 | output_cor[batch_num, :, :, output_cor.shape[3] - cor_pad[1][batch_num]:] = 0 61 | output_cor[batch_num, :, output_cor.shape[2] - cor_pad[3][batch_num]:, :] = 0 62 | output_cor[batch_num, :, :, :cor_pad[0][batch_num]] = 0 63 | 64 | for i in range(output_sag.shape[0]): 65 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "side_input", 66 | (input_sag[i, :, :, :]-torch.max(input_sag[i, :, :, :]))/(torch.max(input_sag[i, :, :, :])-torch.min(input_sag[i, :, :, :])), 67 | global_step=global_step) 68 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "front_input", 69 | (input_cor[i, :, :, :]-torch.max(input_cor[i, :, :, :]))/(torch.max(input_cor[i, :, :, :])-torch.min(input_cor[i, :, :, :])), 70 | global_step=global_step) 71 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "side_output", 72 | torch.max(output_sag[i, 1:25, :, :], dim=0)[0].view(1, output_sag.shape[2], output_sag.shape[3]), 73 | global_step=global_step) 74 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "front_output", 75 | torch.max(output_cor[i, 1:25, :, :], dim=0)[0].view(1, output_cor.shape[2], output_cor.shape[3]), 76 | global_step=global_step) 77 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "side_output_bkgd", 78 | output_sag[i, 0, :, :].view(1, output_sag.shape[2], output_sag.shape[3]), 79 | global_step=global_step) 80 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "front_output_bkgd", 81 | output_cor[i, 0, :, :].view(1, output_cor.shape[2],output_cor.shape[3]), 82 | global_step=global_step) 83 | 84 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "side_gt", 85 | torch.max(gt_sag[i, 1:25, :, :], dim=0)[0].view(1, gt_sag.shape[2], gt_sag.shape[3]), 86 | global_step=global_step) 87 | summary_writer.add_image(str(idx) + "_" + str(i) + "_" + "front_gt", 88 | torch.max(gt_cor[i, 1:25, :, :], dim=0)[0].view(1, gt_cor.shape[2], gt_cor.shape[3]), 89 | global_step=global_step) 90 | 91 | 92 | position = pred_pos_3(device, output_sag[:, 1:25, :, :], output_cor[:, 1:25, :, :], sample['direction'], sample['crop_info'], sample['spacing'], sample['cor_pad'], sample['sag_pad'])[0] 93 | 94 | if idx == 0: 95 | for step in range(position.shape[0]): 96 | whole_step_list.append([]) 97 | 98 | for j in range(input_sag.shape[0]): 99 | for step in range(position.shape[0]): 100 | whole_step_list[step].append( 101 | create_centroid_pos([sample['direction_sitk'][0][j], sample['direction_sitk'][1][j], sample['direction_sitk'][2][j], 102 | sample['direction_sitk'][3][j], sample['direction_sitk'][4][j], 
sample['direction_sitk'][5][j], 103 | sample['direction_sitk'][6][j], sample['direction_sitk'][7][j], sample['direction_sitk'][8][j]], 104 | [sample['spacing'][0][j], sample['spacing'][1][j], sample['spacing'][2][j]], 105 | [sample['size_raw'][0][j], sample['size_raw'][1][j], sample['size_raw'][2][j]], 106 | position[step, j, :, :]) 107 | ) 108 | 109 | val_loss = val_loss + trainer.compute_loss(gt_sag[:, :, :, :], gt_cor[:, :, :, :], output_sag, output_cor, w_front, w_side, device, sag_pad, cor_pad) 110 | val_num += gt_cor.size(0) 111 | 112 | 113 | 114 | id_rate = list(range(position.shape[0])) 115 | id_rate_gt = list(range(position.shape[0])) 116 | for step in range(position.shape[0]): 117 | id_rate[step], id_rate_gt[step] = Get_Identification_Rate(gt_label_list, whole_step_list[step]) 118 | 119 | summary_writer.add_scalar('id_rate_val', max(id_rate), global_step=global_step) 120 | summary_writer.add_scalar('id_rate_val_gt', max(id_rate_gt), global_step=global_step) 121 | 122 | if val_num != len(glob.glob(pathname=cfg.MAT_DIR_VAL + "*.mat")): 123 | raise Exception("Validation number is not equal to sum of batch sizes!") 124 | 125 | return val_loss.item() / val_num, id_rate, id_rate_gt -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import colorlog 5 | 6 | log_colors = { 7 | 'DEBUG': 'white', 8 | 'INFO': '', 9 | 'WARNING': 'cyan', 10 | 'ERROR': 'red', 11 | 'CRITICAL': 'purple', 12 | } 13 | 14 | 15 | def setup_colorful_logger(name, save_dir=None, format="only_message"): 16 | """ 17 | set up colorful logger (colors are defined in log_colors[type: dict]) 18 | :param name: logger's name 19 | :param save_dir: save log info to a txt file 20 | :param format: only message will be printed if it is a str named "only_message" 21 | :return: the logger 22 | """ 23 | logger = colorlog.getLogger(name) 24 | logger.setLevel(logging.DEBUG) 25 | stream_handler = colorlog.StreamHandler(stream=sys.stdout) 26 | 27 | if format != "only_message": 28 | formatter = colorlog.ColoredFormatter( 29 | fmt="%(log_color)s[%(asctime)s] %(name)s: %(message)s", 30 | datefmt='%m-%d %H:%M:%S', 31 | log_colors=log_colors, 32 | ) 33 | else: 34 | formatter = colorlog.ColoredFormatter( 35 | fmt="%(log_color)s%(message)s", 36 | datefmt='%m-%d %H:%M:%S', 37 | log_colors=log_colors, 38 | ) 39 | 40 | stream_handler.setLevel(logging.DEBUG) 41 | stream_handler.setFormatter(formatter) 42 | 43 | logger.addHandler(stream_handler) 44 | if save_dir: 45 | fh = logging.FileHandler(save_dir) 46 | fh.setLevel(logging.DEBUG) 47 | fh.setFormatter(formatter) 48 | logger.addHandler(fh) 49 | 50 | return logger 51 | 52 | 53 | # draw a separator 54 | def separator(logger): 55 | logger.info( 56 | "———————————————————————————————————————————————————————————————————————————————————————————————————————") 57 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import cv2 5 | 6 | Dic = {0:'Z',1:'Y',2:'X'} 7 | 8 | @torch.no_grad() 9 | def decom(whole_dict): 10 | label = [] 11 | dim_X = [] 12 | dim_Y = [] 13 | dim_Z = [] 14 | for index in range(len(whole_dict)): 15 | current = whole_dict[index] 16 | label.append(current['label']) 17 | dim_X.append(current['X']) 18 | dim_Y.append(current['Y']) 
19 | dim_Z.append(current['Z']) 20 | return label, dim_X, dim_Y, dim_Z 21 | 22 | @torch.no_grad() 23 | def create_centroid_pos(Direction, Spacing, Size, position): 24 | # dim0, dim1,dim2, label): 25 | """ 26 | 27 | :param Direction,Spacing, Size: from sitk raw.GetDirection(),GetSpacing(),GetSize() 28 | :param position:[24,3] 29 | :return: 30 | """ 31 | direction = np.round(list(Direction)) 32 | direc0 = direction[0:7:3] 33 | direc1 = direction[1:8:3] 34 | direc2 = direction[2:9:3] 35 | dim0char = Dic[(np.argwhere((np.abs(direc0)) == 1))[0][0]] 36 | dim1char = Dic[(np.argwhere((np.abs(direc1)) == 1))[0][0]] 37 | dim2char = Dic[(np.argwhere((np.abs(direc2)) == 1))[0][0]] 38 | resolution = Spacing 39 | w, h, c = Size[0], Size[1], Size[2] 40 | jsonlist = [] 41 | for i in range(24): 42 | dim0, dim1, dim2 = position[i:i + 1, 0], position[i:i + 1, 1], position[i:i + 1, 2] 43 | if dim0 >= 0: 44 | label = i + 1 45 | if np.sum(direc0) == -1: 46 | if dim0char == 'X': 47 | Jsondim0 = dim0 * resolution[0] 48 | else: 49 | Jsondim0 = (w - dim0) * resolution[0] 50 | else: 51 | if dim0char == 'X': 52 | Jsondim0 = (w - dim0) * resolution[0] 53 | else: 54 | Jsondim0 = dim0 * resolution[0] 55 | 56 | if np.sum(direc1) == -1: 57 | if dim1char == 'X': 58 | Jsondim1 = dim1 * resolution[1] 59 | else: 60 | Jsondim1 = (h - dim1) * resolution[1] 61 | else: 62 | if dim1char == 'X': 63 | Jsondim1 = (h - dim1) * resolution[1] 64 | else: 65 | Jsondim1 = dim1 * resolution[1] 66 | 67 | if np.sum(direc2) == -1: 68 | if dim2char == 'X': 69 | Jsondim2 = dim2 * resolution[2] 70 | else: 71 | Jsondim2 = (c - dim2) * resolution[2] 72 | else: 73 | if dim2char == 'X': 74 | Jsondim2 = (c - dim2) * resolution[2] 75 | else: 76 | Jsondim2 = dim2 * resolution[2] 77 | jsonlist.append({dim0char: Jsondim0, dim1char: Jsondim1, dim2char: Jsondim2, 'label': label}) 78 | 79 | return jsonlist 80 | 81 | @torch.no_grad() 82 | def Get_Identification_Rate(ground_truth_list: object, pred_list: object) -> object: 83 | """ 84 | 85 | :param ground_truth_list: dict-->{'X':XXX, 'Y':XXX, 'Z':XXX} 86 | :param pred_list: 87 | :return: 88 | """ 89 | correctpred = 0 90 | whole_number = 0 91 | whole_number_gt = 0 92 | for i in range(len(pred_list)): 93 | label_GT, dim_X_GT, dim_Y_GT, dim_Z_GT = decom(ground_truth_list[i]) 94 | label_PRED, dim_X_PRED, dim_Y_PRED, dim_Z_PRED = decom(pred_list[i]) 95 | whole_number += len(label_PRED) 96 | whole_number_gt += len(label_GT) 97 | for idx in range(len(label_PRED)): 98 | label_c = label_PRED[idx] 99 | if label_c in label_GT: 100 | pos = label_GT.index(label_c) 101 | dif_X = dim_X_GT[pos] - dim_X_PRED[idx] 102 | dif_Y = dim_Y_GT[pos] - dim_Y_PRED[idx] 103 | dif_Z = dim_Z_GT[pos] - dim_Z_PRED[idx] 104 | distance = pow((pow(dif_X, 2) + pow(dif_Y, 2) + pow(dif_Z, 2)), 0.5) 105 | if distance < 20: 106 | correctpred += 1 107 | 108 | iden_rate = correctpred / whole_number if whole_number != 0 else 0 109 | 110 | return iden_rate, correctpred / whole_number_gt 111 | 112 | @torch.no_grad() 113 | def Get_Localisation_distance(ground_truth, pred): 114 | """ 115 | 116 | :param ground_truth: from each subject 117 | :param pred: 118 | :return: 119 | """ 120 | hit = 0 121 | distance = 0 122 | label_GT, dim_X_GT, dim_Y_GT, dim_Z_GT = decom(ground_truth) 123 | label_PRED, dim_X_PRED, dim_Y_PRED, dim_Z_PRED = decom(pred) 124 | for idx in range(len(label_PRED)): 125 | label_c = label_PRED[idx] 126 | if label_c in label_GT: 127 | hit += 1 128 | pos = label_GT.index(label_c) 129 | dif_X = dim_X_GT[pos] - dim_X_PRED[idx] 130 | dif_Y = 
dim_Y_GT[pos] - dim_Y_PRED[idx] 131 | dif_Z = dim_Z_GT[pos] - dim_Z_PRED[idx] 132 | distance += pow((pow(dif_X, 2) + pow(dif_Y, 2) + pow(dif_Z, 2)), 0.5) 133 | if hit == 0: 134 | print('ALL MISSED') 135 | loc_dis = [] 136 | else: 137 | loc_dis = distance / hit 138 | 139 | return loc_dis 140 | 141 | @torch.no_grad() 142 | def Get_Recall_AND_Precision(ground_truth, pred): 143 | """ 144 | 145 | :param ground_truth: from each subject 146 | :param pred: 147 | :return: 148 | """ 149 | hit = 0 150 | label_GT, dim_X_GT, dim_Y_GT, dim_Z_GT = decom(ground_truth) 151 | label_PRED, dim_X_PRED, dim_Y_PRED, dim_Z_PRED = decom(pred) 152 | GT_length = len(label_GT) 153 | PRED_length = len(label_PRED) 154 | for idx in range(PRED_length): 155 | label_c = label_PRED[idx] 156 | if label_c in label_GT: 157 | pos = label_GT.index(label_c) 158 | dif_X = dim_X_GT[pos] - dim_X_PRED[idx] 159 | dif_Y = dim_Y_GT[pos] - dim_Y_PRED[idx] 160 | dif_Z = dim_Z_GT[pos] - dim_Z_PRED[idx] 161 | distance = pow((pow(dif_X, 2) + pow(dif_Y, 2) + pow(dif_Z, 2)), 0.5) 162 | if distance < 20: 163 | hit += 1 164 | Recall = hit / GT_length 165 | Precision = hit / PRED_length 166 | 167 | return Recall, Precision 168 | 169 | @torch.no_grad() 170 | def pred_pos(device, output_sag_batch, output_cor_batch, direction, crop_info, spacing, cor_pad, sag_pad): 171 | """ 172 | Compute the tensor product between output_sag and output_cor, 173 | then use argmax to find the position of the ith vertebra in channel i. 174 | Let's say the original 3D data has a shape of (B, C, d0, d1, d2), normally with C=25. 175 | Parameters: 176 | output_sag & output_cor: output of the Btrfly Net 177 | direc: it should be ('Z', 'Y', 'X') or ('Y', 'X', 'Z'), indicating the direction of the subject 178 | Return: 179 | a (BxCx3) tensor about the positions of the bones of every subjects in the batch 180 | """ 181 | if output_sag_batch.shape[:2] != output_cor_batch.shape[:2]: 182 | raise Exception("output_sag and output_cor have different batch sizes or channel numbers!") 183 | B, C = output_sag_batch.shape[0], output_sag_batch.shape[1] 184 | # threshold to reduce noise 185 | threshold_noise = 0 186 | threshold_label = torch.from_numpy(np.arange(0, 0.4, 0.01)).float() 187 | position_batch = torch.Tensor(len(threshold_label), B, C, 4) 188 | resolution = 1.0 189 | ori_d0, ori_d1, ori_d2 = np.zeros(B), np.zeros(B), np.zeros(B) 190 | for i in range(B): 191 | direc = (direction[0][i], direction[1][i], direction[2][i]) 192 | if (direc != ('Z', 'Y', 'X')) & (direc != ('Y', 'X', 'Z')): 193 | raise Exception('Unknown direction!') 194 | # select ith subject 195 | output_cor = output_cor_batch[i, :, :, :] 196 | output_sag = output_sag_batch[i, :, :, :] 197 | 198 | # reduce the noise according to threshold 199 | reduce_noise_sag = torch.where(output_sag < threshold_noise, torch.full_like(output_sag, 0), output_sag) 200 | reduce_noise_cor = torch.where(output_cor < threshold_noise, torch.full_like(output_cor, 0), output_cor) 201 | max_value, max_idx = torch.zeros(24), torch.zeros(24) 202 | 203 | if direc == ('Z', 'Y', 'X'): 204 | # sag:(C, d1, d2), cor:(C, d0, d2) 205 | if (output_sag.shape[2] != output_cor.shape[2]): 206 | raise Exception("sag and cor should have an identical size in the last dimension!") 207 | d0, d1, d2 = output_cor.shape[1], output_sag.shape[1], output_sag.shape[2] 208 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - cor_pad[2][i] - cor_pad[3][i], d1 - sag_pad[2][i] - sag_pad[3][i], d2 - cor_pad[0][i] - cor_pad[1][i] 209 | #extend them to (d0, d1, d2) 210 | for c_num in
range(24): 211 | 212 | reduce_noise_sag_one_cha = reduce_noise_sag[c_num, sag_pad[2][i]:d1-sag_pad[3][i], sag_pad[0][i]:d2-sag_pad[1][i]] 213 | reduce_noise_cor_one_cha = reduce_noise_cor[c_num, cor_pad[2][i]:d0-cor_pad[3][i], cor_pad[0][i]:d2-cor_pad[1][i]] 214 | assert reduce_noise_cor_one_cha.shape[1] == reduce_noise_sag_one_cha.shape[1] 215 | 216 | reduce_noise_sag_one_cha = reduce_noise_sag_one_cha.unsqueeze(0).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 217 | reduce_noise_cor_one_cha = reduce_noise_cor_one_cha.unsqueeze(1).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 218 | 219 | product = reduce_noise_cor_one_cha * reduce_noise_sag_one_cha 220 | # find maximum value for each batch and channel 221 | max_value[c_num], max_idx[c_num] = torch.max(product.view(-1), dim=0) 222 | else: 223 | # sag:(C, d0, d1), cor:(C, d1, d2) 224 | if (output_sag.shape[2] != output_cor.shape[1]): 225 | raise Exception("sag and cor should have an identical size in some dimension!") 226 | d0, d1, d2 = output_sag.shape[1], output_sag.shape[2], output_cor.shape[2] 227 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - sag_pad[2][i] - sag_pad[3][i], d1 - sag_pad[0][i] - sag_pad[1][i], d2 - cor_pad[0][i] - cor_pad[1][i] 228 | #extend them to (d0, d1, d2) 229 | for c_num in range(24): 230 | 231 | reduce_noise_sag_one_cha = reduce_noise_sag[c_num, sag_pad[2][i]:d0 - sag_pad[3][i], sag_pad[0][i]:d1 - sag_pad[1][i]] 232 | reduce_noise_cor_one_cha = reduce_noise_cor[c_num, cor_pad[2][i]:d1 - cor_pad[3][i], cor_pad[0][i]:d2 - cor_pad[1][i]] 233 | assert reduce_noise_cor_one_cha.shape[0] == reduce_noise_sag_one_cha.shape[1] 234 | 235 | reduce_noise_sag_one_cha = reduce_noise_sag_one_cha.unsqueeze(2).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 236 | reduce_noise_cor_one_cha = reduce_noise_cor_one_cha.unsqueeze(0).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 237 | product = reduce_noise_cor_one_cha * reduce_noise_sag_one_cha 238 | # find maximum value for each batch and channel 239 | max_value[c_num], max_idx[c_num] = torch.max(product.view(-1), dim=0) 240 | 241 | # translate the indexes to 3D form 242 | max_idx_x, max_idx_y, max_idx_z = -torch.ones(24), -torch.ones(24), -torch.ones(24) 243 | for c_num in range(24): 244 | max_idx_x[c_num], max_idx_y[c_num], max_idx_z[c_num] = \ 245 | max_idx[c_num] // (ori_d1[i] * ori_d2[i]), \ 246 | (max_idx[c_num] % (ori_d1[i] * ori_d2[i])) // ori_d2[i], \ 247 | (max_idx[c_num] % (ori_d1[i] * ori_d2[i])) % ori_d2[i] 248 | for step in range(len(threshold_label)): 249 | position_batch[step, i, :, 0] = (max_idx_x.float() * resolution / spacing[0][i] + crop_info['displace'][i, 0, 0])\ 250 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 251 | - (max_value <= threshold_label[step]).float() 252 | position_batch[step, i, :, 1] = (max_idx_y.float() * resolution / spacing[1][i] + crop_info['displace'][i, 0, 1])\ 253 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 254 | - (max_value <= threshold_label[step]).float() 255 | position_batch[step, i, :, 2] = (max_idx_z.float() * resolution / spacing[2][i] + crop_info['displace'][i, 0, 2])\ 256 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 257 | - (max_value <= threshold_label[step]).float() 258 | position_batch[step, i, :, 3] = max_value 259 | 260 | return position_batch 261 | 262 | 263 | @torch.no_grad() 264 | def pred_pos_2(device, output_sag_batch, output_cor_batch, direction, crop_info, spacing, cor_pad, sag_pad): 265 | """ 266 | Compute the tensor product between output_sag and 
output_cor, 267 | then use argmax to find the position of the ith vertebra in channel i. 268 | Let's say the original 3D data has a shape of (B, C, d0, d1, d2), normally with C=25. 269 | Parameters: 270 | output_sag & output_cor: output of the Btrfly Net 271 | direc: it should be ('Z', 'Y', 'X') or ('Y', 'X', 'Z'), indicating the direction of the subject 272 | Return: 273 | a (BxCx3) tensor about the positions of the bones of every subjects in the batch 274 | """ 275 | if output_sag_batch.shape[:2] != output_cor_batch.shape[:2]: 276 | raise Exception("output_sag and output_cor have different batch sizes or channel numbers!") 277 | B, C = output_sag_batch.shape[0], output_sag_batch.shape[1] 278 | # threshold to reduce noise 279 | threshold_noise = 0 280 | threshold_label = torch.from_numpy(np.arange(0, 0.4, 0.01)).float() 281 | position_batch = torch.Tensor(len(threshold_label), B, C, 4) 282 | resolution = 2.0 283 | ori_d0, ori_d1, ori_d2 = np.zeros(B), np.zeros(B), np.zeros(B) 284 | for i in range(B): 285 | direc = (direction[0][i], direction[1][i], direction[2][i]) 286 | if (direc != ('Z', 'Y', 'X')) & (direc != ('Y', 'X', 'Z')): 287 | raise Exception('Unknown direction!') 288 | # select ith subject 289 | output_cor = output_cor_batch[i, :, :, :] 290 | output_sag = output_sag_batch[i, :, :, :] 291 | 292 | # reduce the noise according to threshold 293 | reduce_noise_sag = torch.where(output_sag < threshold_noise, torch.full_like(output_sag, 0), output_sag) 294 | reduce_noise_cor = torch.where(output_cor < threshold_noise, torch.full_like(output_cor, 0), output_cor) 295 | max_value, max_idx = torch.zeros(24), torch.zeros(24) 296 | max_idx_x, max_idx_y, max_idx_z = torch.zeros(24), torch.zeros(24), torch.zeros(24) 297 | if direc == ('Z', 'Y', 'X'): 298 | # sag:(C, d1, d2), cor:(C, d0, d2) 299 | if (output_sag.shape[2] != output_cor.shape[2]): 300 | raise Exception("sag and cor should have an identical size in the last dimension!") 301 | d0, d1, d2 = output_cor.shape[1], output_sag.shape[1], output_sag.shape[2] 302 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - cor_pad[2][i] - cor_pad[3][i], d1 - sag_pad[2][i] - sag_pad[3][i], d2 - cor_pad[0][i] - cor_pad[1][i] 303 | reduce_noise_sag_no_padding = reduce_noise_sag[:, sag_pad[2][i]:d1-sag_pad[3][i], sag_pad[0][i]:d2-sag_pad[1][i]] 304 | reduce_noise_cor_no_padding = reduce_noise_cor[:, cor_pad[2][i]:d0-cor_pad[3][i], cor_pad[0][i]:d2-cor_pad[1][i]] 305 | # (24) 306 | max_value_sag, max_idx_sag = torch.max(reduce_noise_sag_no_padding.contiguous().view(reduce_noise_sag_no_padding.shape[0], -1), dim=1) 307 | max_value_cor, max_idx_cor = torch.max(reduce_noise_cor_no_padding.contiguous().view(reduce_noise_cor_no_padding.shape[0], -1), dim=1) 308 | max_idx_sag_x, max_idx_sag_y = max_idx_sag // ori_d2[i], max_idx_sag % ori_d2[i] 309 | max_idx_cor_x, max_idx_cor_y = max_idx_cor // ori_d2[i], max_idx_cor % ori_d2[i] 310 | for c_num in range(24): 311 | if max_value_sag[c_num] > max_value_cor[c_num]: 312 | max_idx_x[c_num], max_idx_y[c_num], max_idx_z[c_num] = max_idx_cor_x[c_num], max_idx_sag_x[c_num], max_idx_sag_y[c_num] 313 | max_value[c_num] = max_value_sag[c_num] 314 | else: 315 | max_idx_x[c_num], max_idx_y[c_num], max_idx_z[c_num] = max_idx_cor_x[c_num], max_idx_sag_x[c_num], max_idx_cor_y[c_num] 316 | max_value[c_num] = max_value_cor[c_num] 317 | else: 318 | # sag:(C, d0, d1), cor:(C, d1, d2) 319 | if (output_sag.shape[2] != output_cor.shape[1]): 320 | raise Exception("sag and cor should have an identical size in some dimension!") 321 | d0, d1, d2 = 
output_sag.shape[1], output_sag.shape[2], output_cor.shape[2] 322 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - sag_pad[2][i] - sag_pad[3][i], d1 - sag_pad[0][i] - sag_pad[1][i], d2 - cor_pad[0][i] - cor_pad[1][i] 323 | reduce_noise_sag_no_padding = reduce_noise_sag[:, sag_pad[2][i]:d1 - sag_pad[3][i], sag_pad[0][i]:d2 - sag_pad[1][i]] 324 | reduce_noise_cor_no_padding = reduce_noise_cor[:, cor_pad[2][i]:d0 - cor_pad[3][i], cor_pad[0][i]:d2 - cor_pad[1][i]] 325 | # (24) 326 | max_value_sag, max_idx_sag = torch.max(reduce_noise_sag_no_padding.contiguous().view(reduce_noise_sag_no_padding.shape[0], -1), dim=1) 327 | max_value_cor, max_idx_cor = torch.max(reduce_noise_cor_no_padding.contiguous().view(reduce_noise_cor_no_padding.shape[0], -1), dim=1) 328 | max_idx_sag_x, max_idx_sag_y = max_idx_sag // ori_d1[i], max_idx_sag % ori_d1[i] 329 | max_idx_cor_x, max_idx_cor_y = max_idx_cor // ori_d2[i], max_idx_cor % ori_d2[i] 330 | for c_num in range(24): 331 | if max_value_sag[c_num] > max_value_cor[c_num]: 332 | max_idx_x[c_num], max_idx_y[c_num], max_idx_z[c_num] = max_idx_sag_x[c_num], max_idx_sag_y[c_num], max_idx_cor_y[c_num] 333 | max_value[c_num] = max_value_sag[c_num] 334 | else: 335 | max_idx_x[c_num], max_idx_y[c_num], max_idx_z[c_num] = max_idx_sag_x[c_num], max_idx_cor_x[c_num], max_idx_cor_y[c_num] 336 | max_value[c_num] = max_value_cor[c_num] 337 | 338 | for step in range(len(threshold_label)): 339 | position_batch[step, i, :, 0] = (max_idx_x.float() * resolution / spacing[0][i] + crop_info['displace'][i, 0, 0])\ 340 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 341 | - (max_value <= threshold_label[step]).float() 342 | position_batch[step, i, :, 1] = (max_idx_y.float() * resolution / spacing[1][i] + crop_info['displace'][i, 0, 1])\ 343 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 344 | - (max_value <= threshold_label[step]).float() 345 | position_batch[step, i, :, 2] = (max_idx_z.float() * resolution / spacing[2][i] + crop_info['displace'][i, 0, 2])\ 346 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 347 | - (max_value <= threshold_label[step]).float() 348 | position_batch[step, i, :, 3] = max_value 349 | 350 | return position_batch 351 | 352 | @torch.no_grad() 353 | def pred_pos_3(device, output_sag_batch, output_cor_batch, direction, crop_info, spacing, cor_pad, sag_pad): 354 | """ 355 | Compute the tensor product between output_sag and output_cor, 356 | then use argmax to find the position of the ith vertebra in channel i. 357 | Let's say the original 3D data has a shape of (B, C, d0, d1, d2), normally with C=25. 
358 | Parameters: 359 | output_sag & output_cor: output of the Btrfly Net 360 | direc: it should be ('Z', 'Y', 'X') or ('Y', 'X', 'Z'), indicating the direction of the subject 361 | Return: 362 | a (BxCx3) tensor about the positions of the bones of every subjects in the batch 363 | """ 364 | if output_sag_batch.shape[:2] != output_cor_batch.shape[:2]: 365 | raise Exception("output_sag and output_cor have different batch sizes or channel numbers!") 366 | B, C = output_sag_batch.shape[0], output_sag_batch.shape[1] 367 | f_size = 7 368 | # threshold to reduce noise 369 | threshold_noise = 0 370 | threshold_label = torch.from_numpy(np.arange(0, 0.2, 0.005)).float() 371 | position_batch = torch.Tensor(len(threshold_label), B, C, 4) 372 | position_batch_sag = torch.Tensor(B, C, 3) 373 | position_batch_cor = torch.Tensor(B, C, 3) 374 | resolution = 1.0 375 | ori_d0, ori_d1, ori_d2 = np.zeros(B), np.zeros(B), np.zeros(B) 376 | for i in range(B): 377 | direc = (direction[0][i], direction[1][i], direction[2][i]) 378 | if (direc != ('Z', 'Y', 'X')) & (direc != ('Y', 'X', 'Z')): 379 | raise Exception('Unknown direction!') 380 | # select ith subject 381 | output_cor = output_cor_batch[i, :, :, :] 382 | output_sag = output_sag_batch[i, :, :, :] 383 | 384 | # reduce the noise according to threshold 385 | reduce_noise_sag = torch.where(output_sag < threshold_noise, torch.full_like(output_sag, 0), output_sag) 386 | reduce_noise_cor = torch.where(output_cor < threshold_noise, torch.full_like(output_cor, 0), output_cor) 387 | max_value, max_idx = torch.zeros(24), torch.zeros(24) 388 | 389 | if direc == ('Z', 'Y', 'X'): 390 | # sag:(C, d1, d2), cor:(C, d0, d2) 391 | if (output_sag.shape[2] != output_cor.shape[2]): 392 | raise Exception("sag and cor should have an identical size in the last dimension!") 393 | d0, d1, d2 = output_cor.shape[1], output_sag.shape[1], output_sag.shape[2] 394 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - cor_pad[2][i] - cor_pad[3][i], d1 - sag_pad[2][i] - sag_pad[3][i], d2 - cor_pad[0][i] - cor_pad[1][i] 395 | 396 | 397 | reduce_noise_sag_no_padding = reduce_noise_sag[:, sag_pad[2][i]:d1 - sag_pad[3][i], 398 | sag_pad[0][i]:d2 - sag_pad[1][i]] 399 | reduce_noise_cor_no_padding = reduce_noise_cor[:, cor_pad[2][i]:d0 - cor_pad[3][i], 400 | cor_pad[0][i]:d2 - cor_pad[1][i]] 401 | 402 | # for c_num in range(24): 403 | # 404 | # reduce_noise_sag_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_sag_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 405 | # reduce_noise_cor_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_cor_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 406 | 407 | # (24) 408 | max_value_sag, max_idx_sag = torch.max( 409 | reduce_noise_sag_no_padding.contiguous().view(reduce_noise_sag_no_padding.shape[0], -1), dim=1) 410 | max_value_cor, max_idx_cor = torch.max( 411 | reduce_noise_cor_no_padding.contiguous().view(reduce_noise_cor_no_padding.shape[0], -1), dim=1) 412 | max_idx_sag_x, max_idx_sag_y = max_idx_sag // ori_d2[i], max_idx_sag % ori_d2[i] 413 | max_idx_cor_x, max_idx_cor_y = max_idx_cor // ori_d2[i], max_idx_cor % ori_d2[i] 414 | 415 | #extend them to (d0, d1, d2) 416 | for c_num in range(24): 417 | reduce_noise_sag_one_cha = reduce_noise_sag[c_num, sag_pad[2][i]:d1-sag_pad[3][i], sag_pad[0][i]:d2-sag_pad[1][i]] 418 | 419 | reduce_noise_cor_one_cha = reduce_noise_cor[c_num, cor_pad[2][i]:d0-cor_pad[3][i], cor_pad[0][i]:d2-cor_pad[1][i]] 420 | assert reduce_noise_cor_one_cha.shape[1] == 
reduce_noise_sag_one_cha.shape[1] 421 | 422 | reduce_noise_sag_one_cha = reduce_noise_sag_one_cha.unsqueeze(0).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 423 | reduce_noise_cor_one_cha = reduce_noise_cor_one_cha.unsqueeze(1).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 424 | 425 | product = reduce_noise_cor_one_cha * reduce_noise_sag_one_cha 426 | # find maximum value for each batch and channel 427 | max_value[c_num], max_idx[c_num] = torch.max(product.view(-1), dim=0) 428 | else: 429 | # sag:(C, d0, d1), cor:(C, d1, d2) 430 | if (output_sag.shape[2] != output_cor.shape[1]): 431 | raise Exception("sag and cor should have an identical size in some dimension!") 432 | d0, d1, d2 = output_sag.shape[1], output_sag.shape[2], output_cor.shape[2] 433 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - sag_pad[2][i] - sag_pad[3][i], d1 - sag_pad[0][i] - sag_pad[1][i], d2 - cor_pad[0][i] - cor_pad[1][i] 434 | 435 | reduce_noise_sag_no_padding = reduce_noise_sag[:, sag_pad[2][i]:d1 - sag_pad[3][i], 436 | sag_pad[0][i]:d2 - sag_pad[1][i]] 437 | reduce_noise_cor_no_padding = reduce_noise_cor[:, cor_pad[2][i]:d0 - cor_pad[3][i], 438 | cor_pad[0][i]:d2 - cor_pad[1][i]] 439 | 440 | # for c_num in range(24): 441 | # 442 | # reduce_noise_sag_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_sag_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 443 | # reduce_noise_cor_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_cor_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 444 | 445 | # (24) 446 | max_value_sag, max_idx_sag = torch.max( 447 | reduce_noise_sag_no_padding.contiguous().view(reduce_noise_sag_no_padding.shape[0], -1), dim=1) 448 | max_value_cor, max_idx_cor = torch.max( 449 | reduce_noise_cor_no_padding.contiguous().view(reduce_noise_cor_no_padding.shape[0], -1), dim=1) 450 | 451 | 452 | 453 | max_idx_sag_x, max_idx_sag_y = max_idx_sag // ori_d1[i], max_idx_sag % ori_d1[i] 454 | max_idx_cor_x, max_idx_cor_y = max_idx_cor // ori_d2[i], max_idx_cor % ori_d2[i] 455 | 456 | #extend them to (d0, d1, d2) 457 | for c_num in range(24): 458 | 459 | reduce_noise_sag_one_cha = reduce_noise_sag[c_num, sag_pad[2][i]:d0 - sag_pad[3][i], sag_pad[0][i]:d1 - sag_pad[1][i]] 460 | reduce_noise_cor_one_cha = reduce_noise_cor[c_num, cor_pad[2][i]:d1 - cor_pad[3][i], cor_pad[0][i]:d2 - cor_pad[1][i]] 461 | assert reduce_noise_cor_one_cha.shape[0] == reduce_noise_sag_one_cha.shape[1] 462 | 463 | reduce_noise_sag_one_cha = reduce_noise_sag_one_cha.unsqueeze(2).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 464 | reduce_noise_cor_one_cha = reduce_noise_cor_one_cha.unsqueeze(0).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 465 | product = reduce_noise_cor_one_cha * reduce_noise_sag_one_cha 466 | # find maximum value for each batch and channel 467 | max_value[c_num], max_idx[c_num] = torch.max(product.view(-1), dim=0) 468 | 469 | # translate the indexes to 3D form 470 | max_idx_x, max_idx_y, max_idx_z = -torch.ones(24), -torch.ones(24), -torch.ones(24) 471 | for c_num in range(24): 472 | max_idx_x[c_num], max_idx_y[c_num], max_idx_z[c_num] = \ 473 | max_idx[c_num] // (ori_d1[i] * ori_d2[i]), \ 474 | (max_idx[c_num] % (ori_d1[i] * ori_d2[i])) // ori_d2[i], \ 475 | (max_idx[c_num] % (ori_d1[i] * ori_d2[i])) % ori_d2[i] 476 | 477 | if direc == ('Z', 'Y', 'X'): 478 | for c_num in range(24): 479 | max_idx_z[c_num] = (max_idx_sag_y[c_num] * max_value_sag[c_num] + max_idx_cor_y[c_num] * max_value_cor[c_num]) / \ 480 | 
(max_value_sag[c_num]+max_value_cor[c_num]) 481 | # if max_value_sag[c_num] > max_value_cor[c_num]: 482 | # max_idx_z[c_num] = max_idx_sag_y[c_num] 483 | # else: 484 | # max_idx_z[c_num] = max_idx_cor_y[c_num] 485 | 486 | else: 487 | for c_num in range(24): 488 | max_idx_y[c_num] = (max_idx_sag_y[c_num] * max_value_sag[c_num] + max_idx_cor_x[c_num] * max_value_cor[c_num]) / \ 489 | (max_value_sag[c_num]+ max_value_cor[c_num]) 490 | # if max_value_sag[c_num] > max_value_cor[c_num]: 491 | # max_idx_y[c_num] = max_idx_sag_y[c_num] 492 | # else: 493 | # max_idx_y[c_num] = max_idx_cor_x[c_num] 494 | 495 | for step in range(len(threshold_label)): 496 | position_batch[step, i, :, 0] = (max_idx_x.float() * resolution / spacing[0][i] + crop_info['displace'][i, 0, 0])\ 497 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 498 | - (max_value <= threshold_label[step]).float() 499 | position_batch[step, i, :, 1] = (max_idx_y.float() * resolution / spacing[1][i] + crop_info['displace'][i, 0, 1])\ 500 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 501 | - (max_value <= threshold_label[step]).float() 502 | position_batch[step, i, :, 2] = (max_idx_z.float() * resolution / spacing[2][i] + crop_info['displace'][i, 0, 2])\ 503 | * (2 * (max_value > threshold_label[step]).float() - 1) \ 504 | - (max_value <= threshold_label[step]).float() 505 | position_batch[step, i, :, 3] = max_value 506 | 507 | position_batch_sag[i, :, 0] = max_idx_sag_x 508 | position_batch_sag[i, :, 1] = max_idx_sag_y 509 | position_batch_sag[i, :, 2] = max_value_sag 510 | position_batch_cor[i, :, 0] = max_idx_cor_x 511 | position_batch_cor[i, :, 1] = max_idx_cor_y 512 | position_batch_cor[i, :, 2] = max_value_cor 513 | 514 | return position_batch, position_batch_cor , position_batch_sag 515 | 516 | 517 | @torch.no_grad() 518 | def pred_pos_4(device, output_sag_batch, output_cor_batch, direction, crop_info, spacing, cor_pad, sag_pad): 519 | """ 520 | Compute the tensor product between output_sag and output_cor, 521 | then use argmax to find the position of the ith vertebra in channel i. 522 | Let's say the original 3D data has a shape of (B, C, d0, d1, d2), normally with C=25. 
523 | Parameters: 524 | output_sag & output_cor: output of the Btrfly Net 525 | direc: it should be ('Z', 'Y', 'X') or ('Y', 'X', 'Z'), indicating the direction of the subject 526 | Return: 527 | a (BxCx3) tensor about the positions of the bones of every subjects in the batch 528 | """ 529 | if output_sag_batch.shape[:2] != output_cor_batch.shape[:2]: 530 | raise Exception("output_sag and output_cor have different batch sizes or channel numbers!") 531 | B, C = output_sag_batch.shape[0], output_sag_batch.shape[1] 532 | f_size = 7 533 | # threshold to reduce noise 534 | threshold_noise = 0 535 | threshold_label = torch.from_numpy(np.arange(0, 0.4, 0.01)).float() 536 | position_batch = torch.Tensor(len(threshold_label), B, C, 4) 537 | position_batch_sag = torch.Tensor(B, C, 2) 538 | position_batch_cor = torch.Tensor(B, C, 2) 539 | resolution = 1.0 540 | ori_d0, ori_d1, ori_d2 = np.zeros(B), np.zeros(B), np.zeros(B) 541 | for i in range(B): 542 | direc = (direction[0][i], direction[1][i], direction[2][i]) 543 | if (direc != ('Z', 'Y', 'X')) & (direc != ('Y', 'X', 'Z')): 544 | raise Exception('Unknown direction!') 545 | # select ith subject 546 | output_cor = output_cor_batch[i, :, :, :] 547 | output_sag = output_sag_batch[i, :, :, :] 548 | 549 | # reduce the noise according to threshold 550 | reduce_noise_sag = torch.where(output_sag < threshold_noise, torch.full_like(output_sag, 0), output_sag) 551 | reduce_noise_cor = torch.where(output_cor < threshold_noise, torch.full_like(output_cor, 0), output_cor) 552 | max_value, max_idx, max_cor_num = torch.zeros(24), torch.zeros(24), torch.zeros(24) 553 | 554 | if direc == ('Z', 'Y', 'X'): 555 | # sag:(C, d1, d2), cor:(C, d0, d2) 556 | if (output_sag.shape[2] != output_cor.shape[2]): 557 | raise Exception("sag and cor should have an identical size in the last dimension!") 558 | d0, d1, d2 = output_cor.shape[1], output_sag.shape[1], output_sag.shape[2] 559 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - cor_pad[2][i] - cor_pad[3][i], d1 - sag_pad[2][i] - sag_pad[3][ 560 | i], d2 - cor_pad[0][i] - cor_pad[1][i] 561 | 562 | reduce_noise_sag_no_padding = reduce_noise_sag[:, sag_pad[2][i]:d1 - sag_pad[3][i], 563 | sag_pad[0][i]:d2 - sag_pad[1][i]] 564 | reduce_noise_cor_no_padding = reduce_noise_cor[:, cor_pad[2][i]:d0 - cor_pad[3][i], 565 | cor_pad[0][i]:d2 - cor_pad[1][i]] 566 | 567 | # for c_num in range(24): 568 | # 569 | # reduce_noise_sag_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_sag_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 570 | # reduce_noise_cor_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_cor_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 571 | 572 | # (24) 573 | max_value_sag, max_idx_sag = torch.max( 574 | reduce_noise_sag_no_padding.contiguous().view(reduce_noise_sag_no_padding.shape[0], -1), dim=1) 575 | max_value_cor, max_idx_cor = torch.max( 576 | reduce_noise_cor_no_padding.contiguous().view(reduce_noise_cor_no_padding.shape[0], -1), dim=1) 577 | max_idx_sag_x, max_idx_sag_y = max_idx_sag // ori_d2[i], max_idx_sag % ori_d2[i] 578 | max_idx_cor_x, max_idx_cor_y = max_idx_cor // ori_d2[i], max_idx_cor % ori_d2[i] 579 | 580 | # extend them to (d0, d1, d2) 581 | for c_num_sag in range(24): 582 | reduce_noise_sag_one_cha = reduce_noise_sag_no_padding[c_num_sag, :, :].unsqueeze(0).expand(int(ori_d0[i]), int(ori_d1[i]), int(ori_d2[i])) 583 | for c_num_cor in range(24): 584 | reduce_noise_cor_one_cha = reduce_noise_cor_no_padding[c_num_cor, :, 
:].unsqueeze(1).expand(int(ori_d0[i]), int(ori_d1[i]), 585 | int(ori_d2[i])) 586 | 587 | product = reduce_noise_cor_one_cha * reduce_noise_sag_one_cha 588 | 589 | # find maximum value for each batch and channel 590 | max_value_tmp, max_idx_tmp = torch.max(product.view(-1), dim=0) 591 | if max_value_tmp.cpu() > max_value[c_num_sag]: 592 | max_value[c_num_sag], max_idx[c_num_sag], max_cor_num[c_num_sag] = max_value_tmp, max_idx_tmp, c_num_cor 593 | else: 594 | # sag:(C, d0, d1), cor:(C, d1, d2) 595 | if (output_sag.shape[2] != output_cor.shape[1]): 596 | raise Exception("sag and cor should have an identical size in some dimension!") 597 | d0, d1, d2 = output_sag.shape[1], output_sag.shape[2], output_cor.shape[2] 598 | ori_d0[i], ori_d1[i], ori_d2[i] = d0 - sag_pad[2][i] - sag_pad[3][i], d1 - sag_pad[0][i] - sag_pad[1][ 599 | i], d2 - cor_pad[0][i] - cor_pad[1][i] 600 | 601 | reduce_noise_sag_no_padding = reduce_noise_sag[:, sag_pad[2][i]:d1 - sag_pad[3][i], 602 | sag_pad[0][i]:d2 - sag_pad[1][i]] 603 | reduce_noise_cor_no_padding = reduce_noise_cor[:, cor_pad[2][i]:d0 - cor_pad[3][i], 604 | cor_pad[0][i]:d2 - cor_pad[1][i]] 605 | 606 | # for c_num in range(24): 607 | # 608 | # reduce_noise_sag_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_sag_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 609 | # reduce_noise_cor_no_padding[c_num, :, :] = torch.tensor(cv2.medianBlur(reduce_noise_cor_no_padding[c_num, :, :].cpu().numpy(), f_size)).to(device) 610 | 611 | # (24) 612 | max_value_sag, max_idx_sag = torch.max( 613 | reduce_noise_sag_no_padding.contiguous().view(reduce_noise_sag_no_padding.shape[0], -1), dim=1) 614 | max_value_cor, max_idx_cor = torch.max( 615 | reduce_noise_cor_no_padding.contiguous().view(reduce_noise_cor_no_padding.shape[0], -1), dim=1) 616 | 617 | max_idx_sag_x, max_idx_sag_y = max_idx_sag // ori_d1[i], max_idx_sag % ori_d1[i] 618 | max_idx_cor_x, max_idx_cor_y = max_idx_cor // ori_d2[i], max_idx_cor % ori_d2[i] 619 | 620 | # extend them to (d0, d1, d2) 621 | for c_num_sag in range(24): 622 | reduce_noise_sag_one_cha = reduce_noise_sag_no_padding[c_num_sag, :, :].unsqueeze(2).expand(int(ori_d0[i]), int(ori_d1[i]), 623 | int(ori_d2[i])) 624 | for c_num_cor in range(24): 625 | reduce_noise_cor_one_cha = reduce_noise_cor_no_padding[c_num_cor, :, :].unsqueeze(0).expand(int(ori_d0[i]), int(ori_d1[i]), 626 | int(ori_d2[i])) 627 | 628 | product = reduce_noise_cor_one_cha * reduce_noise_sag_one_cha 629 | 630 | # find maximum value for each batch and channel 631 | max_value_tmp, max_idx_tmp = torch.max(product.view(-1), dim=0) 632 | if max_value_tmp.cpu() > max_value[c_num_sag]: 633 | max_value[c_num_sag], max_idx[c_num_sag], max_cor_num[c_num_sag] = max_value_tmp, max_idx_tmp, c_num_cor 634 | 635 | 636 | 637 | # translate the indexes to 3D form 638 | max_idx_x, max_idx_y, max_idx_z = -torch.ones(24), -torch.ones(24), -torch.ones(24) 639 | for c_num in range(24): 640 | max_idx_x[c_num], max_idx_y[c_num], max_idx_z[c_num] = \ 641 | max_idx[c_num] // (ori_d1[i] * ori_d2[i]), \ 642 | (max_idx[c_num] % (ori_d1[i] * ori_d2[i])) // ori_d2[i], \ 643 | (max_idx[c_num] % (ori_d1[i] * ori_d2[i])) % ori_d2[i] 644 | 645 | if direc == ('Z', 'Y', 'X'): 646 | for c_num in range(24): 647 | max_idx_z[c_num] = (max_idx_sag_y[c_num] * max_value_sag[c_num] + max_idx_cor_y[c_num] * max_value_cor[ 648 | c_num]) / \ 649 | (max_value_sag[c_num] + max_value_cor[c_num]) 650 | # if max_value_sag[c_num] > max_value_cor[c_num]: 651 | # max_idx_z[c_num] = max_idx_sag_y[c_num] 652 | # 
else: 653 | # max_idx_z[c_num] = max_idx_cor_y[c_num] 654 | 655 | else: 656 | for c_num in range(24): 657 | max_idx_y[c_num] = (max_idx_sag_y[c_num] * max_value_sag[c_num] + max_idx_cor_x[c_num] * max_value_cor[ 658 | c_num]) / \ 659 | (max_value_sag[c_num] + max_value_cor[c_num]) 660 | # if max_value_sag[c_num] > max_value_cor[c_num]: 661 | # max_idx_y[c_num] = max_idx_sag_y[c_num] 662 | # else: 663 | # max_idx_y[c_num] = max_idx_cor_x[c_num] 664 | 665 | for step in range(len(threshold_label)): 666 | position_batch[step, i, :, 0] = (max_idx_x.float() * resolution / spacing[0][i] + crop_info['displace'][ 667 | i, 0, 0]) \ 668 | * (2 * (max_value >= threshold_label[step]).float() - 1) \ 669 | - (max_value < threshold_label[step]).float() 670 | position_batch[step, i, :, 1] = (max_idx_y.float() * resolution / spacing[1][i] + crop_info['displace'][ 671 | i, 0, 1]) \ 672 | * (2 * (max_value >= threshold_label[step]).float() - 1) \ 673 | - (max_value < threshold_label[step]).float() 674 | position_batch[step, i, :, 2] = (max_idx_z.float() * resolution / spacing[2][i] + crop_info['displace'][ 675 | i, 0, 2]) \ 676 | * (2 * (max_value >= threshold_label[step]).float() - 1) \ 677 | - (max_value < threshold_label[step]).float() 678 | position_batch[step, i, :, 3] = max_value 679 | position_batch_sag[i, :, 0] = max_idx_sag_x 680 | position_batch_sag[i, :, 1] = max_idx_sag_y 681 | position_batch_cor[i, :, 0] = max_idx_cor_x 682 | position_batch_cor[i, :, 1] = max_idx_cor_y 683 | 684 | return position_batch, position_batch_cor, position_batch_sag -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | 4 | def mkdir(path): 5 | try: 6 | os.makedirs(path) 7 | except OSError as e: 8 | if e.errno != errno.EEXIST: 9 | raise -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import datetime 3 | import os 4 | import time 5 | from utils.inference import * 6 | import glob 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as func 10 | from scipy.io import loadmat 11 | from utils.metrics import * 12 | import cv2 13 | 14 | 15 | def compute_loss(gt_sag, gt_cor, output_sag, output_cor, w_front, w_side, device, sag_pad, cor_pad): 16 | # gt_sag_segment = torch.FloatTensor(gt_sag.size()).to(device) 17 | # gt_cor_segment = torch.FloatTensor(gt_cor.size()).to(device) 18 | # 19 | # gt_sag_segment[:, 1:25, :, :] = torch.where(gt_sag[:, 1:25, :, :] > 0.6, 20 | # torch.full_like(gt_sag[:, 1:25, :, :], 1), 21 | # torch.full_like(gt_sag[:, 1:25, :, :], 0)) 22 | # gt_sag_segment[:, 0, :, :] = torch.where(gt_sag[:, 0, :, :] <= 0.4, 23 | # torch.full_like(gt_sag[:, 0, :, :], 1), 24 | # torch.full_like(gt_sag[:, 0, :, :], 0)) 25 | # gt_cor_segment[:, 1:25, :, :] = torch.where(gt_cor[:, 1:25, :, :] > 0.6, 26 | # torch.full_like(gt_cor[:, 1:25, :, :], 1), 27 | # torch.full_like(gt_cor[:, 1:25, :, :], 0)) 28 | # gt_cor_segment[:, 0, :, :] = torch.where(gt_cor[:, 0, :, :] <= 0.4, 29 | # torch.full_like(gt_cor[:, 0, :, :], 1), 30 | # torch.full_like(gt_cor[:, 0, :, :], 0)) 31 | 32 | loss_MSE_sag = torch.sum(torch.pow((gt_sag - output_sag), 2)) 33 | loss_MSE_cor = torch.sum(torch.pow((gt_cor - output_cor), 2)) 34 | 35 | product_sag = -func.log_softmax(output_sag, dim=1) * func.softmax(gt_sag, dim=1) 36 | 
product_cor = -func.log_softmax(output_cor, dim=1) * func.softmax(gt_cor, dim=1) 37 | for batch_num in range(gt_cor.shape[0]): 38 | product_sag[batch_num, :, :sag_pad[2][batch_num], :] = 0 39 | product_sag[batch_num, :, :, product_sag.shape[3] - sag_pad[1][batch_num]:] = 0 40 | product_sag[batch_num, :, product_sag.shape[2] - sag_pad[3][batch_num]:, :] = 0 41 | product_sag[batch_num, :, :, :sag_pad[0][batch_num]] = 0 42 | 43 | product_cor[batch_num, :, :cor_pad[2][batch_num], :] = 0 44 | product_cor[batch_num, :, :, product_cor.shape[3] - cor_pad[1][batch_num]:] = 0 45 | product_cor[batch_num, :, product_cor.shape[2] - cor_pad[3][batch_num]:, :] = 0 46 | product_cor[batch_num, :, :, :cor_pad[0][batch_num]] = 0 47 | 48 | 49 | 50 | loss_cross_entropy_sag = torch.sum(torch.sum(torch.sum(torch.sum(product_sag, dim=2), dim=2), dim=0) * w_side) 51 | loss_cross_entropy_cor = torch.sum(torch.sum(torch.sum(torch.sum(product_cor, dim=2), dim=2), dim=0) * w_front) 52 | 53 | return loss_MSE_sag + loss_MSE_cor + loss_cross_entropy_cor + loss_cross_entropy_sag 54 | 55 | 56 | 57 | def do_train(cfg, args, model, model_D1, model_D2, data_loader, optimizer, optimizer_D1, optimizer_D2, checkpointer, device, arguments): 58 | # 59 | logger = setup_colorful_logger("trainer", save_dir=os.path.join(cfg.OUTPUT_DIR, 'log.txt'), format="include_other_info") 60 | logger.warning("Start training ...") 61 | logger_val = setup_colorful_logger("evaluator", save_dir=os.path.join(cfg.OUTPUT_DIR, 'log.txt'), format="include_other_info") 62 | w = loadmat(cfg.TRAIN_WEIGHT) 63 | w_front, w_side = torch.Tensor(w["front"]).to(device), torch.Tensor(w["side"]).to(device) 64 | 65 | model.train() 66 | if None not in (model_D1, model_D2, optimizer_D1, optimizer_D2): 67 | m = torch.tensor(32).to(device) 68 | model_D1.train() 69 | model_D2.train() 70 | if args.use_tensorboard: 71 | import tensorboardX 72 | summary_writer = tensorboardX.SummaryWriter(log_dir=os.path.join(cfg.OUTPUT_DIR, 'tf_logs')) 73 | else: 74 | summary_writer = None 75 | 76 | max_iter = cfg.SOLVER.MAX_ITER 77 | iteration = arguments["iteration"] 78 | start_epoch = arguments["epoch"] 79 | list_loss_val = arguments["list_loss_val"] 80 | 81 | start_training_time = time.time() 82 | for epoch in range(round(max_iter/len(data_loader)))[start_epoch+1:]: 83 | arguments["epoch"] = epoch 84 | loss_show = 0 85 | if None not in (model_D1, model_D2, optimizer_D1, optimizer_D2): 86 | loss_show_D1 = 0 87 | loss_show_D2 = 0 88 | ins_num = 0 89 | for idx, sample in enumerate(data_loader): 90 | iteration = iteration + 1 91 | arguments["iteration"] = iteration 92 | 93 | input_cor_padded = sample["input_cor"].float().to(device) 94 | input_sag_padded = sample["input_sag"].float().to(device) 95 | gt_cor = sample["gt_cor"].float().to(device) 96 | gt_sag = sample["gt_sag"].float().to(device) 97 | cor_pad = sample["cor_pad"] 98 | sag_pad = sample["sag_pad"] 99 | 100 | output_sag, output_cor = model(input_sag_padded, input_cor_padded) 101 | 102 | for batch_num in range(gt_cor.shape[0]): 103 | output_sag[batch_num, :, :sag_pad[2][batch_num], :] = 0 104 | output_sag[batch_num, :, :, output_sag.shape[3] - sag_pad[1][batch_num]:] = 0 105 | output_sag[batch_num, :, output_sag.shape[2] - sag_pad[3][batch_num]:, :] = 0 106 | output_sag[batch_num, :, :, :sag_pad[0][batch_num]] = 0 107 | 108 | output_cor[batch_num, :, :cor_pad[2][batch_num], :] = 0 109 | output_cor[batch_num, :, :, output_cor.shape[3] - cor_pad[1][batch_num]:] = 0 110 | output_cor[batch_num, :, output_cor.shape[2] - 
cor_pad[3][batch_num]:, :] = 0 111 | output_cor[batch_num, :, :, :cor_pad[0][batch_num]] = 0 112 | 113 | if None not in (model_D1, model_D2, optimizer_D1, optimizer_D2): 114 | output_fake_D1 = model_D1(output_sag.detach()) 115 | output_fake_D2 = model_D2(output_cor.detach()) 116 | output_gt_D1 = model_D1(gt_sag.detach()) 117 | output_gt_D2 = model_D2(gt_cor.detach()) 118 | loss_D1 = output_gt_D1 + torch.max(torch.tensor(0).float().to(device), m - output_fake_D1) 119 | loss_D2 = output_gt_D2 + torch.max(torch.tensor(0).float().to(device), m - output_fake_D2) 120 | 121 | loss_show_D1 += loss_D1.item() 122 | loss_show_D2 += loss_D2.item() 123 | 124 | ins_num += gt_cor.size(0) 125 | 126 | if None not in (model_D1, model_D2, optimizer_D1, optimizer_D2): 127 | optimizer_D1.zero_grad() 128 | loss_D1.backward() 129 | optimizer_D1.step() 130 | 131 | optimizer_D2.zero_grad() 132 | loss_D2.backward() 133 | optimizer_D2.step() 134 | 135 | 136 | loss_G = compute_loss(gt_sag, gt_cor, output_sag, output_cor, w_front, w_side, device, sag_pad, cor_pad) 137 | if None not in (model_D1, model_D2, optimizer_D1, optimizer_D2): 138 | loss_G = loss_G + model_D1(output_sag) + model_D2(output_cor) 139 | loss_show += loss_G.item() 140 | optimizer.zero_grad() 141 | loss_G.backward() 142 | optimizer.step() 143 | 144 | 145 | if ins_num != len(glob.glob(pathname=cfg.MAT_DIR_TRAIN + "*.mat")): 146 | raise Exception("Instance number is not equal to sum of batch sizes!") 147 | 148 | if epoch % args.log_step == 0: 149 | if None in (model_D1, model_D2, optimizer_D1, optimizer_D2): 150 | logger.info("epoch: {epoch:05d}, iter: {iter:06d}, loss_G: {loss_G}" 151 | .format(epoch=epoch, iter=iteration, loss_G=loss_show/ins_num)) 152 | else: 153 | logger.info("epoch: {epoch:05d}, iter: {iter:06d}, loss_G: {loss_G}, loss_D1: {loss_D1}, loss_D2: {loss_D2}" 154 | .format(epoch=epoch, iter=iteration, loss_G=loss_show/ins_num, loss_D1=loss_show_D1/ins_num, loss_D2=loss_show_D2/ins_num)) 155 | if summary_writer: 156 | summary_writer.add_scalar('loss_G', loss_show/ins_num, global_step=iteration) 157 | if None not in (model_D1, model_D2, optimizer_D1, optimizer_D2): 158 | summary_writer.add_scalar('loss_D1', loss_show_D1 / ins_num, global_step=iteration) 159 | summary_writer.add_scalar('loss_D2', loss_show_D2 / ins_num, global_step=iteration) 160 | 161 | 162 | if args.eval_step > 0 and epoch % args.eval_step == 0 and not iteration == max_iter: 163 | loss_val, id_rate, id_rate_gt = do_evaluation(cfg, model, summary_writer, iteration) 164 | logger_val.error("epoch: {epoch:05d}, iter: {iter:06d}, evaluation_loss: {loss}, \nid_rate: {id_rate}, \nid_rate_gt: {id_rate_gt}, " 165 | .format(epoch=epoch, iter=iteration, loss=loss_val, id_rate=id_rate, id_rate_gt=id_rate_gt)) 166 | best_id_rate_gt = - max(id_rate_gt) 167 | max_loss_iter = max(list_loss_val, key=list_loss_val.get) if len(list_loss_val) else 999 168 | min_loss_iter = min(list_loss_val, key=list_loss_val.get) if len(list_loss_val) else -1 169 | if len(list_loss_val) == 0: 170 | logger_val.warning("Have no saved model, saving first model_{:06d}. ".format(iteration)) 171 | checkpointer.save("model_{:06d}".format(iteration), is_last=False, is_best=True, **arguments) 172 | list_loss_val[str(iteration)] = best_id_rate_gt 173 | elif len(list_loss_val) < cfg.SOLVER.SAVE_NUM: 174 | if list_loss_val[min_loss_iter] > best_id_rate_gt: 175 | logger_val.warning("Have saved {:02d} models, " 176 | "saving newest (best) model_{:06d}. 
".format(len(list_loss_val), iteration)) 177 | checkpointer.save("model_{:06d}".format(iteration), is_last=False, is_best=True, **arguments) 178 | else: 179 | logger_val.warning("Have saved {:02d} models, " 180 | "saving newest (NOT best) model_{:06d}. ".format(len(list_loss_val), iteration)) 181 | checkpointer.save("model_{:06d}".format(iteration), is_last=False, is_best=False, **arguments) 182 | list_loss_val[str(iteration)] = best_id_rate_gt 183 | else: 184 | if list_loss_val[max_loss_iter] >= best_id_rate_gt: 185 | if list_loss_val[min_loss_iter] > best_id_rate_gt: 186 | logger_val.warning("Have saved {:02d} models, " 187 | "deleting the worst saved model_{:06d} and " 188 | "saving newest (best) model_{:06d}. ".format(cfg.SOLVER.SAVE_NUM, int(max_loss_iter), iteration)) 189 | checkpointer.save("model_{:06d}".format(iteration), is_last = False, is_best=True, **arguments) 190 | else: 191 | logger_val.warning("Have saved {:02d} models, " 192 | "deleting the worst saved model_{:06d} and " 193 | "saving newest (NOT best) model_{:06d}. ".format(cfg.SOLVER.SAVE_NUM, int(max_loss_iter), iteration)) 194 | checkpointer.save("model_{:06d}".format(iteration), is_last=False, is_best=False, **arguments) 195 | del list_loss_val[max_loss_iter] 196 | os.system("rm " + cfg.OUTPUT_DIR + "model_{:06d}.pth".format(int(max_loss_iter))) 197 | list_loss_val[str(iteration)] = best_id_rate_gt 198 | else: 199 | logger_val.warning("Have saved {:02d} models, " 200 | "newest model_{:06d} is the worst. " 201 | "No model is saved or deleted in the best-model list. ".format(cfg.SOLVER.SAVE_NUM, iteration)) 202 | os.system("rm " + cfg.OUTPUT_DIR + "model_last.pth") 203 | checkpointer.save("model_last", is_last=True, is_best=False, **arguments) 204 | 205 | if summary_writer: 206 | summary_writer.add_scalar('val_loss', loss_val, global_step=iteration) 207 | model.train() 208 | 209 | if iteration > max_iter: 210 | break 211 | 212 | checkpointer.save("model_final", **arguments) 213 | # compute training time 214 | total_training_time = int(time.time() - start_training_time) 215 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 216 | logger.warning("Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max_iter)) 217 | return model --------------------------------------------------------------------------------