├── README.md ├── configs └── config_ours.py ├── main_ours_resnet18_msc.py ├── main_ours_resnet50_msc.py ├── networks ├── ccfloss.py ├── ccfnet.py ├── ccfnet_resnet.py └── parts.py ├── requirements.txt └── utils ├── ccf_utils.py ├── datasets.py ├── dicom_utils.py ├── functions.py ├── series_utils.py ├── study_utils.py ├── transforms.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Scale Context-Guided Lumbar Spine Disease Identification with Coarse-to-Fine Localization and Classification 2 | 3 | ## Introduction 4 | This repository is the official PyTorch implementation of CCF-Net (Multi-Scale Context-Guided Lumbar Spine Disease Identification with Coarse-to-Fine Localization and Classification), ISBI 2022 (Oral). CCF-Net is also the runner-up solution of the 2020 [Spinal Disease Intelligent Diagnosis AI Challenge](https://tianchi.aliyun.com/competition/entrance/531796/information). The paper is available [here](https://ieeexplore.ieee.org/document/9761528). 5 | 6 | ![image](https://user-images.githubusercontent.com/24490441/158391952-a2841e9a-c8d0-426b-959f-03a92c62e955.png) 7 | 8 | ## Requirements 9 | - PyTorch>=1.1.0 10 | - CPU or GPU 11 | - Other packages can be installed with the following command: 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | ## Quick start 16 | The dataset is provided on the [TianChi platform](https://tianchi.aliyun.com/dataset/dataDetail?spm=5176.12281978.0.0.51947a4co21Um6&dataId=79463). 17 | 18 | Once the data is ready, you can start training with the following command: 19 | ``` 20 | python main_ours_resnet50_msc.py 21 | ``` 22 | 23 | ## Results 24 | | Method | Backbone | Params (M) | Flops (G) | L-Disc | L-Vertebra | C-Disc | C-Vertebra | Score | 25 | |---|---|---|---|---|---|---|---|---| 26 | | SimpleBaseline | ResNet18 | 15.38 | 33.23 | 87.81 | 86.11 | 89.26 | 71.71 | 70.70 | 27 | | SimpleBaseline | ResNet18-MSC | 6.86 | 33.50 | 91.80 | 92.94 | 88.20 | 74.47 | 75.13 | 28 | | SCN | ResNet18 | 26.57 | 42.73 | 88.56 | 88.77 | 89.26 | 71.18 | 71.18 | 29 | | SCN | ResNet18-MSC | 11.62 | 45.43 | 92.79 | 94.13 | 90.16 | 75.94 | 77.64 | 30 | | CCF-Net (ours) | ResNet18 | 9.51 | 11.20 | 89.69 | 89.23 | 89.32 | 76.23 | 74.05 | 31 | | CCF-Net (ours) | ResNet18-MSC | 4.83 | 12.00 | 94.75 | 94.71 | 90.88 | 79.16 | 80.50 | 32 | 33 | | Method | Backbone | Params (M) | Flops (G) | L-Disc | L-Vertebra | C-Disc | C-Vertebra | Score | 34 | |---|---|---|---|---|---|---|---|---| 35 | | SimpleBaseline | ResNet50 | 34.00 | 51.64 | 89.33 | 90.07 | 90.21 | 76.06 | 74.59 | 36 | | SimpleBaseline | ResNet50-MSC | 45.33 | 124.80 | 94.53 | 94.54 | 90.56 | 76.06 | 78.77 | 37 | | HRNet | W32 | 28.54 | 40.98 | 96.19 | 95.63 | 89.68 | 78.43 | 80.62 | 38 | | CCF-Net (ours) | ResNet50 | 23.60 | 21.49 | 90.43 | 90.24 | 90.06 | 76.35 | 75.13 | 39 | | CCF-Net (ours) | ResNet50-MSC | 10.79 | 22.09 | 85.68 | 96.41 | 90.59 | 77.46 | 80.64 | 40 | 41 | ![image](https://user-images.githubusercontent.com/24490441/158932972-0e6d9266-ba09-4216-96ae-c1150118fa69.png) 42 | 43 | ## Citation 44 | ``` 45 | @INPROCEEDINGS{9761528, 46 | author={Chen, Zifan and Zhao, Jie and Yu, Hao and Zhang, Yue and Zhang, Li}, 47 | booktitle={2022 IEEE 19th International Symposium on Biomedical Imaging (ISBI)}, 48 | title={Multi-Scale Context-Guided Lumbar Spine Disease Identification with Coarse-to-Fine Localization and Classification}, 49 | year={2022}, 50 | volume={}, 51 | number={}, 52 | pages={1-5}, 53 | doi={10.1109/ISBI52829.2022.9761528}} 54 | ``` 
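## Training with explicit arguments Both entry points accept command-line arguments (see the `argparse` block in `main_ours_resnet18_msc.py`) for the random seed, dataset paths, and k-fold selection. A minimal example, assuming the default `preliminary_dataset/` layout from `configs/config_ours.py`: ``` python main_ours_resnet18_msc.py --seed 0 --train_data_dir preliminary_dataset/lumbar_train201/ --train_json_file preliminary_dataset/lumbar_train201_annotation.json --valid_data_dir preliminary_dataset/lumbar_train201/ --valid_json_file preliminary_dataset/lumbar_train201_annotation.json --num_folds 4 --fold_id 1 ``` Logs are written to `output/ours_resnet18_msc/<fold_id>/log.log`, and each fold's weights are saved under `models/kfold/` (`config.kfold_model_dir`).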
55 | -------------------------------------------------------------------------------- /configs/config_ours.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import torch 7 | import numpy as np 8 | 9 | class Config(object): 10 | def __init__(self): 11 | super().__init__() 12 | # gpu | cpu settings 13 | self.gpus = [0, ] 14 | self.device = torch.device('cuda:{}'.format(self.gpus[0]) 15 | if torch.cuda.is_available() else 'cpu') 16 | if torch.cuda.is_available(): 17 | self.num_workers = 3 * len(self.gpus) 18 | else: 19 | self.num_workers = 0 20 | 21 | # processing data settings 22 | self.multi_processing = True 23 | self.is_merge_valid = True 24 | self.train_data_dir = 'preliminary_dataset/lumbar_train201/' 25 | self.train_json_file = 'preliminary_dataset/lumbar_train201_annotation.json' 26 | self.valid_data_dir = 'preliminary_dataset/lumbar_train201/' 27 | self.valid_json_file = 'preliminary_dataset/lumbar_train201_annotation.json' 28 | self.testB_data_dir = 'preliminary_dataset/lumbar_testB50/' 29 | self.testB_json_file = 'preliminary_dataset/testB50_series_map.json' 30 | self.testAB_data_dir = 'preliminary_dataset/lumbar_testAB100/' 31 | self.testAB_json_file = 'preliminary_dataset/lumbar_testAB100_annotation.json' 32 | self.kfold_model_dir = 'models/kfold/' 33 | 34 | # network settings 35 | self.backbone = 'resnet18' 36 | self.pretrained = 'models/imagenet/resnet18-5c106cde.pth' 37 | #self.pretrained = 'models/imagenet/resnet50-19c8e357.pth' 38 | 39 | # dataset settings 40 | self.num_rep = 10 # number of passes over the dataset per epoch 41 | self.num_classes = 11 # number of keypoint classes (5 vertebrae + 6 discs) 42 | self.num_v_classes = 2 43 | self.v_weights = 1. / np.array([191, 814]) 44 | self.v_numbers = [191, 814] 45 | self.num_d_classes = 4 46 | self.d_weights = 1. / np.array([574, 300, 285, 47]) 47 | self.d_numbers = [574, 300, 285, 47] 48 | self.num_d_v5_classes = 2 49 | self.d_v5_weights = 1. 
/ np.array([1142, 64]) 50 | self.d_v5_numbers = [1142, 64] 51 | self.sagittal_size = (512, 512) 52 | self.predict_size = (32, 32) 53 | self.transverse_size = (80, 80) 54 | self.stride = 16 55 | self.sigma = 16 56 | 57 | # training settings 58 | self.batch_size = 32 * len(self.gpus) 59 | self.epochs = 60 60 | self.display = 200 61 | self.lr = 3e-4 62 | self.weight_decay = 5e-4 63 | self.gamma = 0.96 64 | self.flooding_d = 0.0008 65 | self.test_size = 0.25 66 | 67 | # evaluate 68 | self.test_flip = True 69 | self.heat_weight = 1.0 70 | self.offset_weight = 2.0 71 | self.class_weight = 0.001 72 | self.max_dist = 6 73 | self.epsilon = 1e-5 74 | self.top_k = 1 75 | self.vote = 'average' 76 | self.train_threshold = 0.6 77 | self.threshold = 0.3 78 | self.metric = 'macro f1' 79 | self.identifications = ['L1', 'L2', 'L3', 'L4', 'L5', 'T12-L1', 'L1-L2', 'L2-L3', 'L3-L4', 'L4-L5', 'L5-S1'] 80 | self.classes = ['v2', 'v2', 'v2', 'v2', 'v2', 'v1', 'v1', 'v1', 'v2', 'v3', 'v3'] 81 | self.is_disc = [False, False, False, False, False, True, True, True, True, True, True] 82 | self.trans_matrix = None 83 | 84 | # augmentation settings 85 | self.sagittal_trans_settings = { 86 | 'size': self.sagittal_size, 87 | 'rotate_p': 0.8, 'max_angel': 45, 88 | 'shift_p': 0.8, 'max_shift_ratios': [0.2, 0.3], 89 | 'crop_p': 0.8, 'max_crop_ratios': [0.2, 0.3], 90 | 'intensity_p': 0.8, 'max_intensity_ratio': 0.3, 91 | 'bias_field_p': 0.8, 'order': 3, 'coefficients_range': [-0.5, 0.5], 92 | 'noise_p': 0., 'noise_mean': [-5, 5], 'noise_std': [0, 2], 93 | 'hflip_p': 0.5 94 | } 95 | self.transverse_trans_settings = { 96 | 'size': self.transverse_size, 97 | 'rotate_p': 0.8, 'max_angel': 25, 98 | 'shift_p': 0.8, 'max_shift_ratios': [0.1, 0.1], 99 | 'crop_p': 0.8, 'max_crop_ratios': [0.1, 0.1], 100 | 'intensity_p': 0.8, 'max_intensity_ratio': 0.3, 101 | 'bias_field_p': 0.8, 'order': 3, 'coefficients_range': [-0.5, 0.5], 102 | 'noise_p': 0., 'noise_mean': [-5, 5], 'noise_std': [0, 2], 103 | 'hflip_p': 0.5 104 | } 105 | 106 | # information (DICOM tags and label mappings) 107 | self.DICOM_TAG = {"studyUid": "0020|000d", 108 | "seriesUid": "0020|000e", 109 | "instanceUid": "0008|0018", 110 | "pixelSpacing": "0028|0030", 111 | "seriesDescription": "0008|103e", 112 | "imagePosition": "0020|0032", 113 | "imageOrientation": "0020|0037"} 114 | 115 | self.SPINAL_VERTEBRA_ID = {"L1": 0, "L2": 1, "L3": 2, "L4": 3, "L5": 4} 116 | self.SPINAL_VERTEBRA_DISEASE_ID = {"v1": 0, "v2": 1} 117 | self.SPINAL_DISC_ID = {"T12-L1": 0, "L1-L2": 1, "L2-L3": 2, "L3-L4": 3, "L4-L5": 4, "L5-S1": 5} 118 | self.SPINAL_DISC_DISEASE_ID = {"v1": 0, "v2": 1, "v3": 2, "v4": 3, "v5": 4} 119 | self.PADDING_VALUE: int = -1 120 | 121 | config = Config() 122 | -------------------------------------------------------------------------------- /main_ours_resnet18_msc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import json 7 | import torch 8 | import torch.nn as nn 9 | import time 10 | import numpy as np 11 | from configs.config_ours import config 12 | from networks.ccfnet import get_ccfnet 13 | from networks.ccfloss import CCFLoss 14 | from utils.functions import ( 15 | train, valid, testA, testB, bagging_train_valid 16 | ) 17 | from torch.utils.data import DataLoader 18 | from utils.utils import build_logging, distance 19 | from utils.study_utils import construct_studies 20 | from utils.datasets import SpineDataSet 21 | 
22 | from sklearn.model_selection import train_test_split 23 | import random 24 | import torch.backends.cudnn as cudnn 25 | from sklearn.model_selection import KFold 26 | import torchvision.transforms.functional as tf 27 | from utils.ccf_utils import ccf_decoder_prob 28 | from utils.utils import confusion_matrix, format_annotation, gen_annotation, cal_metrics, compute_my_metric 29 | from sklearn import metrics 30 | 31 | from thop import profile 32 | 33 | def compute_params_flops(config, model): 34 | flops, params = profile(model, inputs=(torch.randn(1, 3, *config.sagittal_size).to(config.device),)) 35 | config.logger.info('FLOPs = ' + str(flops/1000**3) + 'G') 36 | config.logger.info('Params = ' + str(params/1000**2) + 'M') 37 | 38 | def set_seed(SEED): 39 | torch.manual_seed(SEED) 40 | torch.cuda.manual_seed_all(SEED) 41 | np.random.seed(SEED) 42 | random.seed(SEED) 43 | torch.backends.cudnn.deterministic = True 44 | 45 | import argparse 46 | 47 | parser = argparse.ArgumentParser(description='CCFNet') 48 | parser.add_argument('--seed', default=0, type=int) 49 | parser.add_argument('--train_data_dir', default='', type=str) 50 | parser.add_argument('--train_json_file', default='', type=str) 51 | parser.add_argument('--valid_data_dir', default='', type=str) 52 | parser.add_argument('--valid_json_file', default='', type=str) 53 | parser.add_argument('--num_folds', default=4, type=int) 54 | parser.add_argument('--fold_id', default=1, type=int) 55 | args = parser.parse_args() 56 | 57 | def calc_weights(annotation): 58 | v_numbers = {i: np.zeros(2) for i in range(5)} 59 | d_numbers = {i: np.zeros(4) for i in range(6)} 60 | d_v5_numbers = {i: np.zeros(2) for i in range(6)} 61 | for key, value in annotation.items(): 62 | for i, (_, _, c) in enumerate(value[0]): 63 | v_numbers[i][int(c.item())] += 1 64 | for i, (_, _, c) in enumerate(value[1]): 65 | if c.item() >= 4: continue 66 | d_numbers[i][int(c.item())] += 1 67 | for i, (_, _, c) in enumerate(value[2]): 68 | if c.item() > 0: 69 | d_v5_numbers[i][1] += 1 70 | else: 71 | d_v5_numbers[i][0] += 1 72 | v_weights = 1. / np.sum(np.stack([value for value in v_numbers.values()], axis=0), axis=0) 73 | d_weights = 1. / np.sum(np.stack([value for value in d_numbers.values()], axis=0), axis=0) 74 | d_v5_weights = 1. 
/ np.sum(np.stack([value for value in d_v5_numbers.values()], axis=0), axis=0) 75 | return v_numbers, d_numbers, d_v5_numbers, v_weights, d_weights, d_v5_weights 76 | 77 | def prepare_data(): 78 | train_studies, train_annotation, train_counter = construct_studies( 79 | config, config.train_data_dir, config.train_json_file, multi_processing=config.multi_processing) 80 | valid_studies, valid_annotation, valid_counter = construct_studies( 81 | config, config.valid_data_dir, config.valid_json_file, multi_processing=config.multi_processing) 82 | train_studies.update(valid_studies) 83 | train_annotation.update(valid_annotation) 84 | 85 | studies = [(k, v) for k, v in train_studies.items()] 86 | random.seed(221) 87 | random.shuffle(studies) 88 | studies = [studies[0:50], studies[50:100], studies[100:150], studies[150:]] 89 | 90 | tmp_studies = [] 91 | for j in range(args.num_folds): 92 | if j == args.fold_id: continue 93 | tmp_studies.extend(studies[j]) 94 | train_k_studies = {k: v for i, (k, v) in enumerate(tmp_studies)} 95 | valid_k_studies = {k: v for i, (k, v) in enumerate(studies[args.fold_id])} 96 | train_k_annotation = {k: v for k, v in train_annotation.items() if k[0] in train_k_studies.keys()} 97 | valid_k_annotation = {k: v for k, v in train_annotation.items() if k[0] in valid_k_studies.keys()} 98 | 99 | 100 | print('Split dataset: {} train | {} valid'.format(len(train_k_studies), len(valid_k_studies))) 101 | valid_k_annotation = [anno for anno in 102 | json.load(open(config.train_json_file, 'r')) + json.load(open(config.valid_json_file, 'r')) 103 | if anno['studyUid'] in valid_k_studies.keys()] 104 | # json.dump(valid_k_annotation, open('spinaldiease_dataset/valid.json', 'w')) 105 | # valid_json_file = 'spinaldiease_dataset/valid.json' 106 | 107 | return (train_k_studies, train_k_annotation), (valid_k_studies, valid_k_annotation) 108 | 109 | def gen_annotation_prob(config, study, prediction, cls_prediction): 110 | z_index = study.t2_sagittal.instance_uids[study.t2_sagittal_middle_frame.instance_uid] 111 | point = [] 112 | for i, (coord, identification, cls, is_disc) in enumerate(zip( 113 | prediction, config.identifications, cls_prediction, config.is_disc)): 114 | point.append({ 115 | 'coord': coord.cpu().int().numpy().tolist(), 116 | 'tag': { 117 | 'identification': identification, 118 | 'disc' if is_disc else 'vertebra': cls, 119 | }, 120 | 'zIndex': z_index 121 | }) 122 | annotation = { 123 | 'studyUid': study.study_uid, 124 | 'data': [ 125 | { 126 | 'instanceUid': study.t2_sagittal_middle_frame.instance_uid, 127 | 'seriesUid': study.t2_sagittal_middle_frame.series_uid, 128 | 'annotation': [ 129 | { 130 | 'data': { 131 | 'point': point, 132 | } 133 | } 134 | ] 135 | } 136 | ] 137 | } 138 | return annotation 139 | 140 | def format_annotation_prob(annotations): 141 | output = {} 142 | for annotation in annotations: 143 | study_uid = annotation['studyUid'] 144 | series_uid = annotation['data'][0]['seriesUid'] 145 | instance_uid = annotation['data'][0]['instanceUid'] 146 | temp = {} 147 | for point in annotation['data'][0]['annotation'][0]['data']['point']: 148 | identification = point['tag']['identification'] 149 | coord = point['coord'] 150 | if 'disc' in point['tag']: 151 | disease = point['tag']['disc'] 152 | else: 153 | disease = point['tag']['vertebra'] 154 | if isinstance(disease, str): 155 | if 'v1' in disease: disease = 0.0 156 | else: disease = 1.0 157 | temp[identification] = { 158 | 'coord': coord, 159 | 'disease': disease, 160 | } 161 | output[study_uid] = { 162 | 
'seriesUid': series_uid, 163 | 'instanceUid': instance_uid, 164 | 'annotation': temp 165 | } 166 | return output 167 | 168 | def compute_auc(y_true, y_pred): 169 | if len(y_pred) == 0: return 0.0 170 | fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1) 171 | return metrics.auc(fpr, tpr) 172 | 173 | def compute_my_metric_prob(config, studies, predictions, annotations): 174 | disc_point_pred, disc_point_true = [], [] 175 | vertebra_point_pred, vertebra_point_true = [], [] 176 | 177 | disc_cls_pred, disc_cls_true = [], [] 178 | vertebra_cls_pred, vertebra_cls_true = [], [] 179 | 180 | for study_uid in studies.keys(): 181 | if study_uid not in annotations: 182 | print(study_uid, '++++++') 183 | continue 184 | annotation = annotations[study_uid] 185 | study = studies[study_uid] 186 | pixel_spacing = study.t2_sagittal_middle_frame.pixel_spacing 187 | pred_points = predictions[study_uid]['annotation'] 188 | for identification, gt_point in annotation['annotation'].items(): 189 | gt_coord = gt_point['coord'] 190 | gt_disease = gt_point['disease'] 191 | 192 | if identification not in pred_points: 193 | print(identification, pred_points, '-----------------') 194 | continue 195 | 196 | pred_coord = pred_points[identification]['coord'] 197 | pred_disease = pred_points[identification]['disease'] 198 | 199 | if '-' in identification: 200 | disc_point_true.append(1) 201 | else: 202 | vertebra_point_true.append(1) 203 | 204 | if distance(gt_coord, pred_coord, pixel_spacing) <= config.max_dist: 205 | if '-' in identification: # disc 206 | disc_point_pred.append(1) 207 | 208 | disc_cls_pred.append(pred_disease.item()) 209 | disc_cls_true.append(gt_disease) 210 | else: # vertebra 211 | vertebra_point_pred.append(1) 212 | 213 | vertebra_cls_pred.append(pred_disease.item()) 214 | vertebra_cls_true.append(gt_disease) 215 | else: 216 | if '-' in identification: 217 | disc_point_pred.append(0) 218 | else: 219 | vertebra_point_pred.append(0) 220 | 221 | metric_dict = { 222 | 'disc_kp_recall': 1.0 * sum(disc_point_pred) / sum(disc_point_true), 223 | 'vertebra_kp_recall': 1.0 * sum(vertebra_point_pred) / sum(vertebra_point_true), 224 | 'disc_cls_auc': compute_auc(disc_cls_true, disc_cls_pred) if len(set(disc_cls_true)) == 2 else 0.0, 225 | 'vertebra_cls_auc': compute_auc(vertebra_cls_true, vertebra_cls_pred) if len(set(vertebra_cls_true)) == 2 else 0.0, 226 | } 227 | metric_dict['kp_recall'] = (metric_dict['disc_kp_recall'] + metric_dict['vertebra_kp_recall']) / 2.0 228 | metric_dict['cls_auc'] = (metric_dict['disc_cls_auc'] + metric_dict['vertebra_cls_auc']) / 2.0 229 | metric_dict['score'] = metric_dict['kp_recall'] * metric_dict['cls_auc'] 230 | return metric_dict 231 | 232 | def evaluate(config, model, studies, eval_annotations, trans_model=None, class_tensors=None): 233 | model.eval() 234 | annotations = [] 235 | for study in studies.values(): 236 | kp_frame = study.t2_sagittal_middle_frame 237 | sagittal_image = tf.resize(kp_frame.image, config.sagittal_size) 238 | sagittal_image = tf.to_tensor(sagittal_image).unsqueeze(dim=0).float().to(config.device) 239 | with torch.no_grad(): 240 | predictions = model(sagittal_image) 241 | heatmaps, offymaps, offxmaps, clssmaps = predictions['result'] 242 | if config.test_flip: 243 | sagittal_image_flipped = np.flip(sagittal_image.detach().cpu().numpy(), 3).copy() 244 | sagittal_image_flipped = torch.from_numpy(sagittal_image_flipped).to(sagittal_image.device) 245 | predictions_flipped = model(sagittal_image_flipped) 246 | heatmaps_flipped, _, _, 
clssmaps_flipped = predictions_flipped['result'] 247 | heatmaps = (heatmaps + torch.from_numpy(np.flip(heatmaps_flipped.detach().cpu().numpy(), 3).copy()).to( 248 | heatmaps.device)) / 2. 249 | clssmaps = (clssmaps + torch.from_numpy(np.flip(clssmaps_flipped.detach().cpu().numpy(), 3).copy()).to( 250 | clssmaps.device)) / 2. 251 | heatmaps = torch.mean(heatmaps, dim=0, keepdim=True) 252 | offymaps = torch.mean(offymaps, dim=0, keepdim=True) 253 | offxmaps = torch.mean(offxmaps, dim=0, keepdim=True) 254 | clssmaps = torch.mean(clssmaps, dim=0, keepdim=True) 255 | prediction, _, cls_prediction = ccf_decoder_prob( 256 | config, heatmaps[0], offymaps[0], offxmaps[0], clssmaps[0], maskmap=None, tensor=True) 257 | height_ratio = config.sagittal_size[0] / kp_frame.size[1] 258 | width_ratio = config.sagittal_size[1] / kp_frame.size[0] 259 | ratio = torch.tensor([width_ratio, height_ratio], device=prediction.device) 260 | prediction = (prediction / ratio).round().float() 261 | annotation = gen_annotation_prob(config, study, prediction, cls_prediction) 262 | annotations.append(annotation) 263 | 264 | predictions = annotations 265 | predictions = format_annotation_prob(predictions) 266 | 267 | # with open(annotation_path, 'r') as file: 268 | # annotations = json.load(file) 269 | annotations = format_annotation_prob(eval_annotations) 270 | 271 | metric_dict = compute_my_metric_prob(config, studies, predictions, annotations) 272 | 273 | return metric_dict 274 | 275 | def train(config, epoch, model, loader, criterion, params, optimizer=None, train=True): 276 | model.train() if train else model.eval() 277 | epoch_records = {'time': []} 278 | num_batchs = len(loader) 279 | v_loss_weight, d_loss_weight, d_v5_loss_weight = params['v_loss_weight'], params['d_loss_weight'], params['d_v5_loss_weight'] 280 | print('v loss weight: {} | d loss weight: {} | d v5 loss weight: {}'.format(v_loss_weight, d_loss_weight, d_v5_loss_weight)) 281 | for batch_idx, (sagittal_images, heatmaps, offymaps, offxmaps, maskmaps, clssmaps, _, _) in enumerate(loader): 282 | start_time = time.time() 283 | images = sagittal_images.float().to(config.device) 284 | heat_targets = heatmaps.float().to(config.device) 285 | offy_targets = offymaps.float().to(config.device) 286 | offx_targets = offxmaps.float().to(config.device) 287 | masks = maskmaps.float().to(config.device) 288 | clss_targets = clssmaps.long().to(config.device) 289 | 290 | if train: 291 | predictions = model(images, heatmaps=heat_targets, offymaps=offy_targets, offxmaps=offx_targets) 292 | heatmap, offymap, offxmap, clssmap = predictions['result'] 293 | loss, info = criterion(heatmap, offymap, offxmap, clssmap, 294 | heat_targets, offy_targets, offx_targets, clss_targets, masks, 295 | v_loss_weight, d_loss_weight, d_v5_loss_weight, name='predict') 296 | optimizer.zero_grad() 297 | loss.backward() 298 | optimizer.step() 299 | else: 300 | with torch.no_grad(): 301 | predictions = model(images, heatmaps=heat_targets, offymaps=offy_targets, offxmaps=offx_targets) 302 | heatmap, offymap, offxmap, clssmap = predictions['result'] 303 | loss, info = criterion(heatmap, offymap, offxmap, clssmap, 304 | heat_targets, offy_targets, offx_targets, clss_targets, masks, 305 | v_loss_weight, d_loss_weight, d_v5_loss_weight, name='predict') 306 | 307 | for key, value in info.items(): 308 | if key not in epoch_records: epoch_records[key] = [] 309 | epoch_records[key].append(value) 310 | epoch_records['time'].append(time.time() - start_time) 311 | 312 | if (batch_idx and batch_idx % 
config.display == 0) or (batch_idx == num_batchs - 1): 313 | context = '[{}] EP:{:03d}\tTI:{:03d}/{:03d}\t'.format('T' if train else 'V', epoch, batch_idx, num_batchs) 314 | context_predict = '\t' 315 | context_class = '\t' 316 | for key, value in epoch_records.items(): 317 | if 'cls' in key: 318 | context_class += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value)) 319 | elif 'predict' in key: 320 | context_predict += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value)) 321 | else: 322 | context += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value)) 323 | print(context) 324 | print(context_predict) 325 | print(context_class) 326 | 327 | if not train: 328 | params['v_loss'].append(np.mean(epoch_records['predict_v_cls_loss'])) 329 | params['d_loss'].append(np.mean(epoch_records['predict_d_cls_loss'])) 330 | params['d_v5_loss'].append(np.mean(epoch_records['predict_d_v5_cls_loss'])) 331 | 332 | return np.mean(epoch_records['loss']), params, epoch_records 333 | 334 | def train_valid(config, k, train_k_studies, train_k_annotation, valid_k_studies, valid_k_annotation): 335 | print('Training Fold: {}'.format(k)) 336 | save_model_file = os.path.join(config.kfold_model_dir, '{}.pth.tar'.format(k)) 337 | v_numbers, d_numbers, d_v5_numbers, v_weights, d_weights, d_v5_weights = calc_weights(train_k_annotation) 338 | config.v_numbers = v_numbers 339 | config.d_numbers = d_numbers 340 | config.d_v5_numbers = d_v5_numbers 341 | config.v_weights = v_weights 342 | config.d_weights = d_weights 343 | config.d_v5_weights = d_v5_weights 344 | 345 | model = get_ccfnet(config).to(config.device) 346 | compute_params_flops(config, model) 347 | criterion = CCFLoss(config).to(config.device) 348 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay) 349 | #lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=config.gamma) 350 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2, eta_min=5e-5) 351 | train_k_dataset = SpineDataSet(config, train_k_studies, train_k_annotation) 352 | train_k_loader = DataLoader(train_k_dataset, num_workers=config.num_workers, 353 | batch_size=config.batch_size, shuffle=True, pin_memory=True, 354 | collate_fn=train_k_dataset.collate_fn) 355 | 356 | best_metric_dict = None 357 | params = {'v_loss': [], 'd_loss': [], 'd_v5_loss': [], 358 | 'v_loss_weight': config.class_weight, 'd_loss_weight': config.class_weight, 'd_v5_loss_weight': config.class_weight} 359 | for epoch in range(config.epochs): 360 | start_time = time.time() 361 | _, _, epoch_records = train(config, epoch, model, train_k_loader, criterion, params, optimizer=optimizer, train=True) 362 | metric_dict = evaluate(config, model, valid_k_studies, valid_k_annotation) 363 | if best_metric_dict is None or np.isnan(metric_dict['score']) or metric_dict['score'] >= best_metric_dict['score']: 364 | best_metric_dict = metric_dict 365 | 366 | config.logger.info(f'Epoch={epoch}') 367 | config.logger.info('=' * 50 + 'valid' + '=' * 50) 368 | for key, value in metric_dict.items(): 369 | config.logger.info('valid_'+key+': '+str(round(value, 4))) 370 | config.logger.info('=' * 50 + 'best' + '=' * 51) 371 | for key, value in best_metric_dict.items(): 372 | config.logger.info('best_'+key+': '+str(round(value, 4))) 373 | config.logger.info('=' * 105) 374 | config.logger.info('') 375 | 376 | lr_scheduler.step() 377 | print('=' * 60 + ' {:.4f} '.format(time.time() - start_time) + '=' * 60) 378 | 
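    # note: this saves the final-epoch weights; best_metric_dict above only tracks the best validation score, so no separate best-epoch checkpoint is written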
torch.save(model.state_dict(), save_model_file) 379 | print('saved model: {}'.format(save_model_file)) 380 | 381 | torch.cuda.empty_cache() 382 | 383 | return save_model_file, epoch_records 384 | 385 | def main(): 386 | set_seed(args.seed) 387 | config.train_data_dir = args.train_data_dir 388 | config.train_json_file = args.train_json_file 389 | config.valid_data_dir = args.valid_data_dir 390 | config.valid_json_file = args.valid_json_file 391 | 392 | config.output_dir = os.path.join('output', 'ours_resnet18_msc', str(args.fold_id)) 393 | os.makedirs(config.output_dir, exist_ok=True) 394 | config.logger = build_logging(os.path.join(config.output_dir, 'log.log')) 395 | 396 | (train_studies, train_annotation), \ 397 | (valid_studies, valid_annotation) = prepare_data() 398 | 399 | train_valid(config, 0, train_studies, train_annotation, valid_studies, valid_annotation) 400 | 401 | if __name__ == '__main__': 402 | main() 403 | -------------------------------------------------------------------------------- /main_ours_resnet50_msc.py: -------------------------------------------------------------------------------- 1 | from main_ours_resnet18_msc import * 2 | 3 | def main(): 4 | config.backbone = 'resnet50' 5 | config.pretrained = 'models/imagenet/resnet50-19c8e357.pth' 6 | config.batch_size = 12 * len(config.gpus) 7 | 8 | set_seed(args.seed) 9 | config.train_data_dir = args.train_data_dir 10 | config.train_json_file = args.train_json_file 11 | config.valid_data_dir = args.valid_data_dir 12 | config.valid_json_file = args.valid_json_file 13 | 14 | config.output_dir = os.path.join('output', 'ours_resnet50_msc', str(args.fold_id)) 15 | os.makedirs(config.output_dir, exist_ok=True) 16 | config.logger = build_logging(os.path.join(config.output_dir, 'log.log')) 17 | 18 | (train_studies, train_annotation), \ 19 | (valid_studies, valid_annotation) = prepare_data() 20 | 21 | train_valid(config, 0, train_studies, train_annotation, valid_studies, valid_annotation) 22 | 23 | if __name__ == '__main__': 24 | main() -------------------------------------------------------------------------------- /networks/ccfloss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import contextlib 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import torch.nn.functional as F 10 | 11 | class CBCrossEntropyLoss(nn.Module): 12 | def __init__(self, samples_per_class, device): 13 | super().__init__() 14 | self.samples_per_class = np.sum(np.stack([value for value in samples_per_class.values()], axis=0), axis=0) 15 | if not isinstance(self.samples_per_class, np.ndarray): 16 | self.samples_per_class = np.array(self.samples_per_class) 17 | self.betas_per_class = (self.samples_per_class - 1.) / self.samples_per_class 18 | self.ens_per_class = (1. - np.power(self.betas_per_class, self.samples_per_class)) / (1. - self.betas_per_class) 19 | self.weights_per_class = 1. 
/ (self.ens_per_class + 1e-5) 20 | self.base_criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(self.weights_per_class).float().to(device)) 21 | 22 | def forward(self, predictions, targets): 23 | return self.base_criterion(predictions, targets) 24 | 25 | class CBBCEWithLogitsLoss(nn.Module): 26 | def __init__(self, samples_per_class, device): 27 | super().__init__() 28 | self.samples_per_class = np.sum(np.stack([value for value in samples_per_class.values()], axis=0), axis=0) 29 | if not isinstance(self.samples_per_class, np.ndarray): 30 | self.samples_per_class = np.array(self.samples_per_class) 31 | self.betas_per_class = (self.samples_per_class - 1.) / self.samples_per_class 32 | self.ens_per_class = (1. - np.power(self.betas_per_class, self.samples_per_class)) / (1. - self.betas_per_class) # effective number of samples: E_n = (1 - beta^n) / (1 - beta) 33 | self.weights_per_class = 1. / (self.ens_per_class + 1e-5) 34 | self.pos_weight = self.weights_per_class[1] / (self.weights_per_class[0] + 1e-5) 35 | self.base_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.pos_weight]).float().to(device)) 36 | 37 | def forward(self, predictions, targets): 38 | return self.base_criterion(predictions, targets) 39 | 40 | class CCFHeatLoss(nn.Module): 41 | def __init__(self, config): 42 | super().__init__() 43 | self.criterion = nn.MSELoss(reduction='mean') 44 | 45 | def forward(self, predictions, targets, masks): 46 | B, C, _, _ = predictions.shape 47 | predictions = predictions * masks 48 | targets = targets * masks 49 | predictions = predictions.reshape((B, C, -1)).split(1, 1) 50 | targets = targets.reshape((B, C, -1)).split(1, 1) 51 | loss = 0 52 | for c in range(C): 53 | prediction = predictions[c].squeeze() 54 | target = targets[c].squeeze() 55 | loss += self.criterion(prediction, target) 56 | 57 | return loss / C 58 | 59 | class CCFOffsetLoss(nn.Module): 60 | def __init__(self, config): 61 | super().__init__() 62 | self.criterion = nn.SmoothL1Loss() 63 | self.threshold = config.train_threshold 64 | 65 | def forward(self, predictions, targets, masks): 66 | peaks = (masks.max(3))[0].max(2)[0] # per-channel peak of the target heatmap 67 | peaks = peaks.unsqueeze(2).unsqueeze(3) 68 | masks = masks - peaks 69 | masks[masks >= (-1 * self.threshold)] = 1 # keep only cells within train_threshold of the peak 70 | masks[masks < (-1 * self.threshold)] = 0 71 | predictions = predictions * masks 72 | targets = targets * masks 73 | return self.criterion(predictions, targets) 74 | 75 | class CCFLoss(nn.Module): 76 | def __init__(self, config): 77 | super().__init__() 78 | self.heat_weight = config.heat_weight 79 | self.offset_weight = config.offset_weight 80 | self.class_weight = config.class_weight 81 | self.heat_criterion = CCFHeatLoss(config) 82 | self.offset_criterion = CCFOffsetLoss(config) 83 | self.threshold = config.train_threshold 84 | 85 | self.v_cls_criterion = CBBCEWithLogitsLoss(config.v_numbers, config.device) 86 | d_numbers = {} 87 | for key in config.d_numbers.keys(): 88 | d_numbers[key] = np.asarray([config.d_numbers[key][0], sum(config.d_numbers[key][1:])]).astype(float) 89 | self.d_cls_criterion = CBBCEWithLogitsLoss(d_numbers, config.device) 90 | 91 | def forward(self, heat_predictions, offy_predictions, offx_predictions, clss_predictions, 92 | heat_targets, offy_targets, offx_targets, clss_targets, masks, 93 | v_loss_weight, d_loss_weight, d_v5_loss_weight=None, name='stage1'): 94 | heat_loss = self.heat_criterion(heat_predictions, heat_targets, masks) 95 | offy_loss = self.offset_criterion(offy_predictions, offy_targets, heat_targets) 96 | offx_loss = self.offset_criterion(offx_predictions, offx_targets, heat_targets) 97
| 98 | vx, vy, dx, dy = [], [], [], [] 99 | temp_targets = heat_targets.detach().view(heat_targets.shape[0], heat_targets.shape[1], -1) 100 | for b in range(temp_targets.shape[0]): 101 | for c in range(temp_targets.shape[1]): 102 | max_index = torch.argmax(temp_targets[b, c]).item() 103 | basic_grid_y, basic_grid_x = int(max_index // heat_targets.shape[3]), int( 104 | max_index % heat_targets.shape[3]) 105 | for oy, ox in ((0, 0),): 106 | grid_y = basic_grid_y + oy 107 | grid_x = basic_grid_x + ox 108 | if grid_y < 0 or grid_y >= clss_predictions.shape[2] or grid_x < 0 or grid_x >= \ 109 | clss_predictions.shape[3]: 110 | continue 111 | if c < 5: 112 | if clss_targets[b, c, grid_y, grid_x].item() < 0: continue 113 | vx.append(clss_predictions[b, c, grid_y, grid_x]) 114 | vy.append(clss_targets[b, c, grid_y, grid_x]) 115 | else: 116 | if clss_targets[b, c, grid_y, grid_x].item() < 0: continue 117 | dx.append(clss_predictions[b, c, grid_y, grid_x]) 118 | dy.append((clss_targets[b, c, grid_y, grid_x] >= 1).float()) 119 | vx = torch.stack(vx, dim=0) # (B * C) 120 | vy = torch.stack(vy, dim=0) # (B * C,) 121 | dx = torch.stack(dx, dim=0) # (B * C) 122 | dy = torch.stack(dy, dim=0) # (B * C,) 123 | v_cls_loss = self.v_cls_criterion(vx.float(), vy.float()) 124 | d_cls_loss = self.d_cls_criterion(dx.float(), dy.float()) 125 | 126 | loss = heat_loss * self.heat_weight + offy_loss * self.offset_weight + offx_loss * self.offset_weight \ 127 | + v_cls_loss * v_loss_weight + d_cls_loss * d_loss_weight 128 | info = {'loss': loss.item(), 129 | name + '_heat_loss': heat_loss.item(), 130 | name + '_offy_loss': offy_loss.item(), 131 | name + '_offx_loss': offx_loss.item(), 132 | name + '_v_cls_loss': v_cls_loss.item(), 133 | name + '_d_cls_loss': d_cls_loss.item()} 134 | return loss, info -------------------------------------------------------------------------------- /networks/ccfnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import math 8 | import torch.nn as nn 9 | import torch 10 | from collections import OrderedDict 11 | 12 | def init_weights(model, pretrained): 13 | if os.path.isfile(pretrained): 14 | print('=> init model weights as normal') 15 | pas = 0 16 | for name, m in model.named_modules(): 17 | if 'd_model' in name or 'v_model' in name: 18 | pas += 1 19 | continue 20 | # if isinstance(m, nn.Conv2d): 21 | # nn.init.normal_(m.weight, std=0.001) 22 | if isinstance(m, nn.Conv2d): 23 | nn.init.normal_(m.weight, std=0.001) 24 | for name, _ in m.named_parameters(): 25 | if name in ['bias']: 26 | nn.init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.BatchNorm2d): 28 | nn.init.constant_(m.weight, 1) 29 | nn.init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.ConvTranspose2d): 31 | nn.init.normal_(m.weight, std=0.001) 32 | for name, _ in m.named_parameters(): 33 | if name in ['bias']: 34 | nn.init.constant_(m.bias, 0) 35 | suc = 0 36 | model_state_dict = model.state_dict() 37 | loaded_state_dict = torch.load(pretrained) 38 | for key in loaded_state_dict.keys(): 39 | new_key = key.replace('module.', '') 40 | if new_key in model_state_dict and loaded_state_dict[key].shape == model_state_dict[new_key].shape: 41 | model_state_dict[new_key] = loaded_state_dict[key] 42 | suc += 1 43 | model.load_state_dict(model_state_dict, strict=True) 44 | print('=> loaded pretrained model {}: {}/{} [{}]'.format(pretrained, 
suc, len(model_state_dict.keys()), pas)) 45 | else: 46 | print('=> ImageNet pretrained model does not exist') 47 | return model 48 | 49 | def get_ccfnet(config): 50 | if 'resnet' in config.backbone: 51 | from networks.ccfnet_resnet import get_ccfnet 52 | model = get_ccfnet(config) 53 | else: 54 | print('Undefined model: {}'.format(config.backbone)) 55 | sys.exit() 56 | model = init_weights(model, config.pretrained) 57 | return model 58 | 59 | def get_ccfnet_resnet(config): 60 | if 'resnet' in config.backbone: 61 | from networks.ccfnet_resnet import get_ccfnet_resnet 62 | model = get_ccfnet_resnet(config) 63 | else: 64 | print('Undefined model: {}'.format(config.backbone)) 65 | sys.exit() 66 | model = init_weights(model, config.pretrained) 67 | return model -------------------------------------------------------------------------------- /networks/ccfnet_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from collections import OrderedDict 12 | from functools import partial 13 | from networks.parts import BasicBlock, Bottleneck, DetnetBottleneck, Block1, Block2, BN_MOMENTUM, HeadBlock 14 | 15 | class CCFNet(nn.Module): 16 | def __init__(self, config, block, layers, channels): 17 | super().__init__() 18 | self.inplanes = 64 19 | self.num_classes = config.num_classes 20 | self.num_v_classes = config.num_v_classes 21 | self.num_d_classes = config.num_d_classes 22 | self.num_d_v5_classes = config.num_d_v5_classes 23 | 24 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 25 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | self.layer1 = self._make_layer(block, 64, layers[0]) 29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 30 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 31 | #self.layer5 = self._make_detnet_layer(channels, channels) 32 | self.layer_conv1x1 = nn.Conv2d(channels, 256, kernel_size=1, padding=0, stride=1) 33 | #out_channels = self.num_classes * 3 + (5 + 6 * self.num_d_classes + 6) 34 | out_channels = self.num_classes * 3 + (5 + 6) 35 | self.stage1_headnet = HeadBlock(in_channels=256, out_channels=out_channels, num_classes=self.num_classes) 36 | 37 | def _make_layer(self, block, planes, blocks, stride=1): 38 | downsample = None 39 | if stride != 1 or self.inplanes != planes * block.expansion: 40 | downsample = nn.Sequential( 41 | nn.Conv2d(self.inplanes, planes * block.expansion, 42 | kernel_size=1, stride=stride, bias=False), 43 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), ) 44 | 45 | layers = [] 46 | layers.append(block(self.inplanes, planes, stride, downsample)) 47 | self.inplanes = planes * block.expansion 48 | for i in range(1, blocks): 49 | layers.append(block(self.inplanes, planes)) 50 | 51 | return nn.Sequential(*layers) 52 | 53 | def _make_detnet_layer(self, in_channels, out_channels): 54 | layers = [] 55 | layers.append(DetnetBottleneck(in_planes=in_channels, planes=out_channels, block_type='B')) 56 | layers.append(DetnetBottleneck(in_planes=out_channels, planes=out_channels, block_type='A')) 57 | layers.append(DetnetBottleneck(in_planes=out_channels, planes=out_channels, 
block_type='A')) 58 | return nn.Sequential(*layers) 59 | 60 | def _make_conv_bn_layer(self, in_channels, out_channels, kernel_size, stride, padding): 61 | layers = [] 62 | layers.append( 63 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)) 64 | layers.append(nn.BatchNorm2d(out_channels)) 65 | layers.append(nn.ReLU(inplace=True)) 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x, **kwargs): 69 | x = self.conv1(x) 70 | x = self.bn1(x) 71 | x = self.relu(x) 72 | x = self.maxpool(x) 73 | 74 | x1 = self.layer1(x) 75 | x2 = self.layer2(x1) 76 | x3 = self.layer3(x2) 77 | #x4 = self.layer5(x3) 78 | x3 = self.layer_conv1x1(x3) 79 | x4 = x3 80 | 81 | predictions = {} 82 | heatmap, offymap, offxmap, clssmap = self.stage1_headnet(x4) 83 | predictions['result'] = [heatmap, offymap, offxmap, clssmap] 84 | 85 | return predictions 86 | 87 | class CCFNetResNet(CCFNet): 88 | def __init__(self, config, block, layers, channels): 89 | super().__init__(config, block, layers, channels) 90 | self.inplanes = 64 91 | self.num_classes = config.num_classes 92 | self.num_v_classes = config.num_v_classes 93 | self.num_d_classes = config.num_d_classes 94 | self.num_d_v5_classes = config.num_d_v5_classes 95 | 96 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 97 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 100 | self.layer1 = self._make_layer(block, 64, layers[0]) 101 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 102 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 103 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 104 | out_channels = self.num_classes * 3 + (5 + 6) 105 | self.final_layer = nn.Conv2d(in_channels=512 if block == BasicBlock else 2048, 106 | out_channels=out_channels, 107 | kernel_size=1, stride=1, padding=0) 108 | 109 | def forward(self, x, **kwargs): 110 | x = self.conv1(x) 111 | x = self.bn1(x) 112 | x = self.relu(x) 113 | x = self.maxpool(x) 114 | 115 | x = self.layer1(x) 116 | x = self.layer2(x) 117 | x = self.layer3(x) 118 | x = self.layer4(x) 119 | 120 | x = self.final_layer(x) 121 | 122 | heatmap = x[:, :self.num_classes] 123 | offymap = x[:, self.num_classes:self.num_classes * 2] 124 | offxmap = x[:, self.num_classes * 2:self.num_classes * 3] 125 | clssmap = x[:, self.num_classes * 3:] 126 | 127 | predictions = {} 128 | predictions['result'] = [heatmap, offymap, offxmap, clssmap] 129 | 130 | return predictions 131 | 132 | resnet_spec = {'resnet18': (BasicBlock, [2, 2, 2, 2], 256), 133 | 'resnet50': (Bottleneck, [3, 4, 6, 3], 1024)} 134 | 135 | def get_ccfnet(config): 136 | block_class, layers, channels = resnet_spec[config.backbone] 137 | model = CCFNet(config, block_class, layers, channels) 138 | return model 139 | 140 | def get_ccfnet_resnet(config): 141 | block_class, layers, channels = resnet_spec[config.backbone] 142 | model = CCFNetResNet(config, block_class, layers, channels) 143 | return model 144 | -------------------------------------------------------------------------------- /networks/parts.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | BN_MOMENTUM = 0.1 10 | 11 | 12 
| def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, inplanes, planes, stride=1, downsample=None): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = conv3x3(inplanes, planes, stride) 23 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | 40 | if self.downsample is not None: 41 | residual = self.downsample(x) 42 | 43 | out += residual 44 | out = self.relu(out) 45 | 46 | return out 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 58 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 59 | bias=False) 60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 61 | momentum=BN_MOMENTUM) 62 | self.relu = nn.ReLU(inplace=True) 63 | if inplanes != planes * self.expansion: 64 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, 65 | stride=stride, padding=0, bias=False), 66 | nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)) 67 | else: 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | class DetnetBottleneck(nn.Module): 94 | expansion = 1 95 | 96 | def __init__(self, in_planes, planes, stride=1, block_type='A'): 97 | super(DetnetBottleneck, self).__init__() 98 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(planes) 100 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=2, bias=False, dilation=2) 101 | self.bn2 = nn.BatchNorm2d(planes) 102 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 103 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 104 | 105 | self.downsample = nn.Sequential() 106 | if stride != 1 or in_planes != self.expansion * planes or block_type == 'B': 107 | self.downsample = nn.Sequential( 108 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 109 | nn.BatchNorm2d(self.expansion * planes) 110 | ) 111 | 112 | self._init_weight() 113 | 114 | def _init_weight(self): 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 118 | elif isinstance(m, nn.BatchNorm2d): 119 | nn.init.constant_(m.weight, 1) 120 | nn.init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.Linear): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | 124 | def forward(self, x): 125 | out = F.relu(self.bn1(self.conv1(x))) 126 | out = F.relu(self.bn2(self.conv2(out))) 127 | out = self.bn3(self.conv3(out)) 128 | out += self.downsample(x) 129 | out = F.relu(out) 130 | return out 131 | 132 | class Block1(nn.Module): # multi-scale context via cascaded dilated convolutions 133 | def __init__(self, ch): 134 | super().__init__() 135 | self.dilate1 = nn.Conv2d(ch, ch, kernel_size=3, dilation=1, padding=1) 136 | self.dilate2 = nn.Conv2d(ch, ch, kernel_size=3, dilation=3, padding=3) 137 | self.dilate3 = nn.Conv2d(ch, ch, kernel_size=3, dilation=5, padding=5) 138 | self.conv1x1 = nn.Conv2d(ch, ch, kernel_size=1, dilation=1, padding=0) 139 | self.relu = nn.ReLU(inplace=True) 140 | 141 | def forward(self, x): 142 | dilate1_out = self.relu(self.dilate1(x)) 143 | dilate2_out = self.relu(self.conv1x1(self.dilate2(x))) 144 | dilate3_out = self.relu(self.conv1x1(self.dilate2(self.dilate1(x)))) 145 | dilate4_out = self.relu(self.conv1x1(self.dilate3(self.dilate2(self.dilate1(x))))) 146 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out 147 | return out 148 | 149 | class Block2(nn.Module): # pooling pyramid: appends one single-channel map per scale (3 extra channels) 150 | def __init__(self, ch): 151 | super().__init__() 152 | self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=2) 153 | self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 154 | self.pool2 = nn.MaxPool2d(kernel_size=[4, 4], stride=4) 155 | self.up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 156 | self.pool3 = nn.MaxPool2d(kernel_size=[8, 8], stride=8) 157 | self.up3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 158 | self.conv = nn.Conv2d(ch, 1, kernel_size=1, padding=0, bias=True) 159 | 160 | def forward(self, x): 161 | layer1 = self.up1(self.conv(self.pool1(x))) 162 | layer2 = self.up2(self.conv(self.pool2(x))) 163 | layer3 = self.up3(self.conv(self.pool3(x))) 164 | out = torch.cat([layer1, layer2, layer3, x], dim=1) 165 | return out 166 | 167 | class HeadBlock(nn.Module): 168 | def __init__(self, in_channels, out_channels, num_classes): 169 | super().__init__() 170 | self.num_classes = num_classes 171 | self.feat = nn.Sequential(Block1(in_channels), 172 | Block2(in_channels), 173 | Bottleneck(in_channels + 3, in_channels // Bottleneck.expansion)) 174 | self.out = nn.Conv2d(in_channels=in_channels // Bottleneck.expansion * Bottleneck.expansion, 175 | out_channels=out_channels, 176 | kernel_size=1, stride=1, padding=0) 177 | 178 | def forward(self, x): 179 | x = self.feat(x) 180 | x = self.out(x) 181 | 182 | heatmap = x[:, :self.num_classes] 183 | offymap = x[:, self.num_classes:self.num_classes * 2] 184 | offxmap = x[:, self.num_classes * 2:self.num_classes * 3] 185 | clssmap = x[:, self.num_classes * 3:] 186 | 187 | return heatmap, offymap, offxmap, clssmap -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.0.5 2 | matplotlib==3.2.2 3 | scikit_image==0.17.2 4 | numpy==1.16.2 5 | tqdm==4.47.0 6 | Pillow==7.2.0 7 | scikit_learn==0.23.2 8 | SimpleITK==1.2.4 9 | -------------------------------------------------------------------------------- /utils/ccf_utils.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import numpy as np 5 | import torch 6 | import itertools 7 | 8 | def _softmax(x): 9 | return np.exp(x - np.max(x)) / np.sum(np.exp(x - np.max(x))) # subtract the max for numerical stability 10 | 11 | def _sigmoid(x): 12 | return 1 / (1 + np.exp(-x)) 13 | 14 | def batch_ccf_analysis(heatmaps, clssmaps_target, clssmaps_predict): 15 | heatmaps = heatmaps.detach().cpu().numpy() 16 | clssmaps_target = clssmaps_target.detach().cpu().numpy() 17 | clssmaps_predict = clssmaps_predict.detach().cpu().numpy() 18 | B, C, H, W = heatmaps.shape 19 | records = [] 20 | for heatmap, clssmap_target, clssmap_predict in zip(heatmaps, clssmaps_target, clssmaps_predict): 21 | heatmap_reshaped = heatmap.reshape((C, -1)) 22 | idxs = heatmap_reshaped.argsort(1)[:, -1:] 23 | for i in range(0, 11): 24 | grid_y = int(idxs[i, -1] // W) 25 | grid_x = int(idxs[i, -1] % W) 26 | clss_target = clssmap_target[i, grid_y, grid_x] 27 | if i < 5: 28 | clss_predict = np.argmax(clssmap_predict[i * 2:(i + 1) * 2, grid_y, grid_x], axis=0) 29 | clss_score = _softmax(clssmap_predict[i * 2:(i + 1) * 2, grid_y, grid_x]) 30 | else: 31 | clss_predict = np.argmax(clssmap_predict[10 + (i - 5) * 4:10 + (i + 1 - 5) * 4, grid_y, grid_x], axis=0) 32 | clss_score = _softmax(clssmap_predict[10 + (i - 5) * 4:10 + (i + 1 - 5) * 4, grid_y, grid_x]) 33 | records.append([i, clss_target, clss_predict, clss_target == clss_predict, list(clss_score)]) 34 | return records 35 | 36 | def _gaussian(x, sigma): 37 | return np.exp(-x / (2 * sigma ** 2)) 38 | 39 | def ccf_encoder(config, gt_coords, gt_classes, tensor=True): 40 | if not isinstance(gt_coords, np.ndarray): 41 | gt_coords = gt_coords.detach().cpu().numpy() 42 | heatmap = np.zeros((config.num_classes, *config.predict_size)) 43 | offymap = np.zeros((config.num_classes, *config.predict_size)) 44 | offxmap = np.zeros((config.num_classes, *config.predict_size)) 45 | maskmap = np.zeros((config.num_classes, *config.predict_size)) 46 | clssmap = np.zeros((config.num_classes + 6, *config.predict_size)) 47 | 48 | gridmap = np.array(list(itertools.product(range(1, config.predict_size[0] + 1), 49 | range(1, config.predict_size[1] + 1)))) 50 | gridmap = (gridmap - 0.5) * config.stride 51 | 52 | for i, (gt_coord, gt_class) in enumerate(zip(gt_coords, gt_classes[:gt_coords.shape[0]])): 53 | distance = np.square((gt_coord[::-1] - gridmap)).sum(axis=-1) 54 | heatmap[i] = (_gaussian(distance, config.sigma).reshape(*config.predict_size)) 55 | heatmap[i] = heatmap[i] / heatmap[i].max() 56 | offset = ((gt_coord[::-1] - gridmap) / config.stride).reshape(*config.predict_size, -1) 57 | offymap[i] = offset[:, :, 0] 58 | offxmap[i] = offset[:, :, 1] 59 | clssmap[i] = gt_class 60 | 61 | if heatmap[i].max() >= config.threshold: 62 | maskmap[i] = 1 63 | 64 | for i, gt_class in enumerate(gt_classes[gt_coords.shape[0]:]): 65 | clssmap[gt_coords.shape[0] + i] = int(gt_class > 0) 66 | 67 | if tensor: 68 | heatmap = torch.from_numpy(heatmap) 69 | offymap = torch.from_numpy(offymap) 70 | offxmap = torch.from_numpy(offxmap) 71 | maskmap = torch.from_numpy(maskmap) 72 | clssmap = torch.from_numpy(clssmap) 73 | 74 | return heatmap, offymap, offxmap, maskmap, clssmap 75 | 76 | def ccf_decoder(config, heatmap, offymap, offxmap, clssmap, maskmap=None, tensor=True): 77 | assert len(heatmap.shape) == 3 78 | assert len(offymap.shape) == 3 79 | assert len(offxmap.shape) == 3 80 | if not isinstance(heatmap, 
np.ndarray): 81 | heatmap = heatmap.detach().cpu().numpy() 82 | if not isinstance(offymap, np.ndarray): 83 | offymap = offymap.detach().cpu().numpy() 84 | if not isinstance(offxmap, np.ndarray): 85 | offxmap = offxmap.detach().cpu().numpy() 86 | if (maskmap is not None) and (not isinstance(maskmap, np.ndarray)): 87 | maskmap = maskmap.detach().cpu().numpy() 88 | if (clssmap is not None) and (not isinstance(clssmap, np.ndarray)): 89 | clssmap = clssmap.detach().cpu().numpy() # (num_classes * 7, H, W) 90 | C, H, W = heatmap.shape 91 | assert C == config.num_classes 92 | heatmap_reshaped = heatmap.reshape((C, -1)) 93 | idxs = heatmap_reshaped.argsort(1)[:, -config.top_k:] 94 | scores = np.zeros((C, 1)) 95 | predictions = np.zeros((C, 2)) 96 | cls_predictions = np.array([1, 1, 1, 1, 1, 0, 0, 0, 1, 2, 2]) 97 | d_v5_cls_predictions = np.array([0, 0, 0, 0, 0, 0]) 98 | 99 | # v_thresholds = torch.Tensor([value[0] / (value[0] + value[1])for key, value in config.v_numbers.items()]) 100 | # d_v5_thresholds = torch.Tensor([value[0] / (value[0] + value[1])for key, value in config.d_v5_numbers.items()]) 101 | 102 | for i in range(C): 103 | if (maskmap is not None) and (maskmap[i].sum() <= 0): 104 | continue 105 | weight = 1 if config.vote == 'average' else heatmap_reshaped[i, idxs[i, -1]] 106 | grid_y = int(idxs[i, -1] // W) 107 | grid_x = int(idxs[i, -1] % W) 108 | predictions[i, 1] = (grid_y + offymap[i, grid_y, grid_x] + 0.5) * config.stride * weight 109 | predictions[i, 0] = (grid_x + offxmap[i, grid_y, grid_x] + 0.5) * config.stride * weight 110 | scores[i, 0] = heatmap_reshaped[i, idxs[i, -1]] 111 | num = weight 112 | if clssmap.shape[0]: 113 | if i < 5: 114 | cls_predictions[i] = (_sigmoid(clssmap[i, grid_y, grid_x]) >= 0.5) 115 | else: 116 | cls_predictions[i] = np.argmax(clssmap[5 + (i - 5) * 4:5 + (i + 1 - 5) * 4, grid_y, grid_x], axis=0) 117 | d_v5_cls_predictions[i - 5] = (_sigmoid(clssmap[29 + (i - 5), grid_y, grid_x]) >= 0.5) 118 | 119 | for j in range(config.top_k - 1): 120 | if heatmap_reshaped[i, idxs[i, -(j + 2)]] <= config.threshold: 121 | continue 122 | weight = 1 if config.vote == 'average' else heatmap_reshaped[i, idxs[i, -(j + 2)]] 123 | grid_y = int(idxs[i, -(j + 2)] // W) 124 | grid_x = int(idxs[i, -(j + 2)] % W) 125 | predictions[i, 1] += (grid_y + offymap[i, grid_y, grid_x] + 0.5) * config.stride * weight 126 | predictions[i, 0] += (grid_x + offxmap[i, grid_y, grid_x] + 0.5) * config.stride * weight 127 | scores[i, 0] += heatmap_reshaped[i, idxs[i, -(j + 2)]] 128 | num += weight 129 | predictions[i, 0] /= num 130 | predictions[i, 1] /= num 131 | scores[i, 0] /= num 132 | 133 | if tensor: 134 | predictions = torch.from_numpy(predictions) # (C, 2) 135 | scores = torch.from_numpy(scores) # (C, 1) 136 | cls_predictions = torch.from_numpy(cls_predictions) # (C,) 137 | d_v5_cls_predictions = torch.from_numpy(d_v5_cls_predictions) # (6,) 138 | 139 | return predictions, scores, cls_predictions, d_v5_cls_predictions 140 | 141 | def batch_ccf_decoder(config, heatmaps, offymaps, offxmaps, clssmaps, maskmaps=None, tensor=True): 142 | results = list( 143 | map(ccf_decoder, [config] * heatmaps.shape[0], heatmaps, offymaps, offxmaps, clssmaps, 144 | maskmaps if maskmaps is not None else [None] * heatmaps.shape[0], [tensor] * heatmaps.shape[0])) 145 | predictions, cls_predictions, d_v5_cls_predictions = [], [], [] 146 | for prediction, _, cls_prediction, d_v5_cls_prediction in results: 147 | predictions.append(prediction) 148 | cls_predictions.append(cls_prediction) 149 | 
        d_v5_cls_predictions.append(d_v5_cls_prediction)
150 |     predictions = torch.stack(predictions, dim=0)  # (B, 11, 2)
151 |     cls_predictions = torch.stack(cls_predictions, dim=0)  # (B, 11)
152 |     d_v5_cls_predictions = torch.stack(d_v5_cls_predictions, dim=0)  # (B, 6)
153 |     return predictions, cls_predictions, d_v5_cls_predictions
154 | 
155 | 
156 | def ccf_decoder_prob(config, heatmap, offymap, offxmap, clssmap, maskmap=None, tensor=True):
157 |     assert len(heatmap.shape) == 3
158 |     assert len(offymap.shape) == 3
159 |     assert len(offxmap.shape) == 3
160 |     if not isinstance(heatmap, np.ndarray):
161 |         heatmap = heatmap.detach().cpu().numpy()
162 |     if not isinstance(offymap, np.ndarray):
163 |         offymap = offymap.detach().cpu().numpy()
164 |     if not isinstance(offxmap, np.ndarray):
165 |         offxmap = offxmap.detach().cpu().numpy()
166 |     if (maskmap is not None) and (not isinstance(maskmap, np.ndarray)):
167 |         maskmap = maskmap.detach().cpu().numpy()
168 |     if (clssmap is not None) and (not isinstance(clssmap, np.ndarray)):
169 |         clssmap = clssmap.detach().cpu().numpy()  # (num_classes * 7, H, W)
170 |     C, H, W = heatmap.shape
171 |     assert C == config.num_classes
172 |     heatmap_reshaped = heatmap.reshape((C, -1))
173 |     idxs = heatmap_reshaped.argsort(1)[:, -config.top_k:]
174 |     scores = np.zeros((C, 1))
175 |     predictions = np.zeros((C, 2))
176 |     cls_predictions = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]).astype(float)  # np.float is removed in NumPy >= 1.24; the builtin float is the drop-in alias
177 | 
178 |     for i in range(C):
179 |         if (maskmap is not None) and (maskmap[i].sum() <= 0):
180 |             continue
181 |         weight = 1 if config.vote == 'average' else heatmap_reshaped[i, idxs[i, -1]]
182 |         grid_y = int(idxs[i, -1] // W)
183 |         grid_x = int(idxs[i, -1] % W)
184 |         predictions[i, 1] = (grid_y + offymap[i, grid_y, grid_x] + 0.5) * config.stride * weight
185 |         predictions[i, 0] = (grid_x + offxmap[i, grid_y, grid_x] + 0.5) * config.stride * weight
186 |         scores[i, 0] = heatmap_reshaped[i, idxs[i, -1]]
187 |         num = weight
188 |         if clssmap.shape[0]:
189 |             cls_predictions[i] = _sigmoid(clssmap[i, grid_y, grid_x])
190 | 
191 |         for j in range(config.top_k - 1):
192 |             if heatmap_reshaped[i, idxs[i, -(j + 2)]] <= config.threshold:
193 |                 continue
194 |             weight = 1 if config.vote == 'average' else heatmap_reshaped[i, idxs[i, -(j + 2)]]
195 |             grid_y = int(idxs[i, -(j + 2)] // W)
196 |             grid_x = int(idxs[i, -(j + 2)] % W)
197 |             predictions[i, 1] += (grid_y + offymap[i, grid_y, grid_x] + 0.5) * config.stride * weight
198 |             predictions[i, 0] += (grid_x + offxmap[i, grid_y, grid_x] + 0.5) * config.stride * weight
199 |             scores[i, 0] += heatmap_reshaped[i, idxs[i, -(j + 2)]]
200 |             num += weight
201 |         predictions[i, 0] /= num
202 |         predictions[i, 1] /= num
203 |         scores[i, 0] /= num
204 | 
205 |     if tensor:
206 |         predictions = torch.from_numpy(predictions)  # (C, 2)
207 |         scores = torch.from_numpy(scores)  # (C, 1)
208 |         cls_predictions = torch.from_numpy(cls_predictions)  # (C,)
209 | 
210 |     return predictions, scores, cls_predictions
--------------------------------------------------------------------------------
/utils/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import torch
6 | from typing import Dict, Any, Tuple
7 | from utils.study_utils import Study
8 | from torch.utils.data import Dataset
9 | from utils.dicom_utils import DICOM
10 | from utils.ccf_utils import ccf_encoder as cl_encoder  # the repo ships ccf_utils (above), not cl_utils; the alias keeps the cl_encoder call site below unchanged
11 | from glob import glob
12 | import os
13 | import numpy as np
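# Editorial note: `prim` below is a lazy, heap-based implementation of Prim's
# minimum-spanning-tree algorithm over a dense distance graph (a list of
# {neighbour: weight} dicts); it returns the MST edges and, what matters here,
# the order in which vertices are visited. `check_annotation` interleaves disc
# and vertebra keypoints, weights every pair by Euclidean distance, and uses
# that visit order to re-sort the annotations into a consistent chain along
# the spine, starting from the first disc (its call in __getitem__ is
# currently commented out). A minimal sketch of the visit-order behaviour on
# a hypothetical 3-point chain at heights 0, 1 and 3:
#
# >>> graph = [{1: 1.0, 2: 3.0}, {0: 1.0, 2: 2.0}, {0: 3.0, 1: 2.0}]
# >>> edges, order = prim(graph)
# >>> order
# [0, 1, 2]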
14 | from itertools import chain 15 | import heapq 16 | 17 | def prim(graph): 18 | n = len(graph) 19 | v = 0 20 | s = {v} 21 | edges = [] 22 | res = [] 23 | record = [v] 24 | for _ in range(n - 1): 25 | for u, w in graph[v].items(): 26 | heapq.heappush(edges, (w, v, u)) 27 | while edges: 28 | w, p, q = heapq.heappop(edges) 29 | if q not in s: 30 | s.add(q) 31 | record.append(q) 32 | res.append(((p, q), w)) 33 | v = q 34 | break 35 | return res, record 36 | 37 | def check_annotation(v_annotation, d_annotation, d_v5_annotation): 38 | def _distance(a, b): 39 | return (((a - b) ** 2).sum()) ** 0.5 40 | v_annotation = v_annotation.detach().cpu().numpy() 41 | d_annotation = d_annotation.detach().cpu().numpy() 42 | d_v5_annotation = d_v5_annotation.detach().cpu().numpy() 43 | temp = list(chain.from_iterable(zip(d_annotation[:-1, :2], v_annotation[:, :2]))) + [d_annotation[-1, :2]] 44 | graph = [[] for _ in range(len(temp))] 45 | for i in range(len(temp)): 46 | graph[i] = {} 47 | for j in range(len(temp)): 48 | if j == i: continue 49 | graph[i][j] = _distance(temp[i], temp[j]) 50 | _, path = prim(graph) 51 | temp = [temp[i] for i in path] 52 | for i in range(len(temp)): 53 | if i % 2 == 0: # d 54 | d_annotation[i//2, :2] = temp[i] 55 | d_v5_annotation[i//2, :2] = temp[i] 56 | else: 57 | v_annotation[i//2, :2] = temp[i] 58 | v_annotation = torch.from_numpy(v_annotation) 59 | d_annotation = torch.from_numpy(d_annotation) 60 | d_v5_annotation = torch.from_numpy(d_v5_annotation) 61 | return v_annotation, d_annotation, d_v5_annotation 62 | 63 | class SpineDataSet(Dataset): 64 | def __init__(self, 65 | config, 66 | studies: Dict[Any, Study], 67 | annotations: Dict[Any, Tuple[torch.Tensor, torch.Tensor]]): 68 | super().__init__() 69 | self.config = config 70 | self.num_rep = config.num_rep 71 | self.studies = studies 72 | self.annotations = [] 73 | for k, annotation in annotations.items(): 74 | study_uid, series_uid, instance_uid = k 75 | if study_uid not in self.studies: 76 | continue 77 | study = self.studies[study_uid] 78 | if series_uid in study and instance_uid in study[series_uid].instance_uids: 79 | self.annotations.append((k, annotation)) 80 | 81 | def __getitem__(self, item): 82 | item = item % len(self.annotations) 83 | key, (v_annotation, d_annotation, d_v5_annotation) = self.annotations[item] 84 | #v_annotation, d_annotation, d_v5_annotation = check_annotation(v_annotation, d_annotation, d_v5_annotation) 85 | return self.studies[key[0]], key, v_annotation, d_annotation, d_v5_annotation 86 | 87 | def collate_fn(self, data) -> (Tuple[torch.Tensor], Tuple[None]): 88 | sagittal_images = [] 89 | transverse_images = [] 90 | transverse_masks = [] 91 | hflip_flags = [] 92 | heatmaps, offymaps, offxmaps, maskmaps, clssmaps = [], [], [], [], [] 93 | for study, key, v_anno, d_anno, d_v5_anno in data: 94 | instance_uid = study[key[1]].instance_idxs[study[key[1]].instance_uids[key[2]] + np.random.randint(-1, 2)] 95 | dicom: DICOM = study[key[1]][instance_uid] 96 | # random_p = np.random.uniform(0, 1) 97 | # if random_p < 1.0: 98 | # instance_uid = study[key[1]].instance_idxs[study[key[1]].instance_uids[key[2]] + np.random.randint(-1, 2)] 99 | # dicom: DICOM = study[key[1]][instance_uid] 100 | # elif random_p < 0.8: 101 | # dicom: DICOM = study.t2_sagittal_middle_frame 102 | # else: 103 | # dicom: DICOM = study.t1_sagittal_middle_frame 104 | pixel_coord = torch.cat([v_anno[:, :2], d_anno[:, :2]], dim=0) # (11, 2) 105 | pixel_noise = torch.zeros_like(pixel_coord).float().uniform_( 106 | 
-self.config.max_dist / 4 / dicom.pixel_spacing[1].item(), 107 | self.config.max_dist / 4 / dicom.pixel_spacing[0].item()) 108 | pixel_coord = pixel_coord + pixel_noise 109 | pixel_class = torch.cat([v_anno[:, -1], d_anno[:, -1], d_v5_anno[:, -1]]) # (11+6,) 110 | sagittal_image, pixel_coord, details = dicom.transform(self.config.sagittal_trans_settings, pixel_coord, 111 | tensor=True, isaug=True) 112 | sagittal_images.append(sagittal_image) 113 | 114 | # transverse_image, transverse_mask = study.t2_transverse_k_nearest(self.config, True, d_anno[:, :2], 1) 115 | # transverse_image = transverse_image[:, 0, :] 116 | transverse_image = torch.zeros(6, 1, 96, 96) 117 | transverse_mask = torch.zeros(6) 118 | transverse_images.append(transverse_image) # (6, 1, 96, 96) 119 | transverse_masks.append(transverse_mask) 120 | 121 | heatmap, offymap, offxmap, maskmap, clssmap = cl_encoder(self.config, pixel_coord, pixel_class, tensor=True) 122 | heatmaps.append(heatmap) 123 | offymaps.append(offymap) 124 | offxmaps.append(offxmap) 125 | maskmaps.append(maskmap) 126 | clssmaps.append(clssmap) 127 | 128 | sagittal_images = torch.stack(sagittal_images, dim=0) 129 | transverse_images = torch.stack(transverse_images, dim=0) # (B, 6, 3, 80, 80) 130 | transverse_masks = torch.stack(transverse_masks, dim=0) # (B, 6) 131 | heatmaps = torch.stack(heatmaps, dim=0) 132 | offymaps = torch.stack(offymaps, dim=0) 133 | offxmaps = torch.stack(offxmaps, dim=0) 134 | maskmaps = torch.stack(maskmaps, dim=0) 135 | clssmaps = torch.stack(clssmaps, dim=0) 136 | 137 | return sagittal_images, heatmaps, offymaps, offxmaps, maskmaps, clssmaps, transverse_images, transverse_masks 138 | 139 | def __len__(self): 140 | return len(self.annotations) * self.num_rep 141 | -------------------------------------------------------------------------------- /utils/dicom_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import SimpleITK as sitk 7 | from typing import Union 8 | from PIL import Image 9 | import torchvision.transforms.functional as tf 10 | import numpy as np 11 | from utils.utils import ( 12 | lazy_property, 13 | str2tensor, 14 | unit_vector, 15 | unit_normal_vector 16 | ) 17 | from utils.transforms import ( 18 | random_rotate, 19 | random_shift, 20 | random_crop, 21 | random_intensity, 22 | random_bias_field, 23 | random_noise, 24 | random_hflip, 25 | resize 26 | ) 27 | 28 | 29 | class DICOM: 30 | def __init__(self, file_path, DICOM_TAG): 31 | self.file_path = file_path 32 | self.error_msg = '' 33 | self.DICOM_TAG = DICOM_TAG 34 | 35 | reader = sitk.ImageFileReader() 36 | reader.LoadPrivateTagsOn() 37 | reader.SetImageIO('GDCMImageIO') 38 | reader.SetFileName(file_path) 39 | try: 40 | reader.ReadImageInformation() 41 | except RuntimeError: 42 | pass 43 | 44 | self.study_uid: str = self._get_meta_data(reader, 'studyUid', self.DICOM_TAG, '') 45 | self.series_uid: str = self._get_meta_data(reader, 'seriesUid', self.DICOM_TAG, '') 46 | self.instance_uid: str = self._get_meta_data(reader, 'instanceUid', self.DICOM_TAG, '') 47 | self.series_description: str = self._get_meta_data(reader, 'seriesDescription', self.DICOM_TAG, '') 48 | self._pixel_spacing = self._get_meta_data(reader, 'pixelSpacing', self.DICOM_TAG, None) 49 | self._image_position = self._get_meta_data(reader, 'imagePosition', self.DICOM_TAG, None) 50 | self._image_orientation = 
self._get_meta_data(reader, 'imageOrientation', self.DICOM_TAG, None) 51 | 52 | try: 53 | image = reader.Execute() 54 | if image.GetNumberOfComponentsPerPixel() == 1: 55 | image = sitk.RescaleIntensity(image, 0, 255) 56 | if reader.GetMetaData('0028|0004').strip() == 'MONOCHROME1': 57 | image = sitk.InvertIntensity(image, maximum=255) 58 | image = sitk.Cast(image, sitk.sitkUInt8) 59 | img_x = sitk.GetArrayFromImage(image)[0] 60 | self.image: Image.Image = tf.to_pil_image(img_x) 61 | except RuntimeError: 62 | self.image = None 63 | 64 | def _get_meta_data(self, reader: sitk.ImageFileReader, key: str, DICOM_TAG: dict, 65 | failed_return: Union[None, str]) -> Union[None, str]: 66 | try: 67 | return reader.GetMetaData(DICOM_TAG[key]) 68 | except RuntimeError: 69 | return failed_return 70 | 71 | @lazy_property 72 | def pixel_spacing(self): 73 | if self._pixel_spacing is None: 74 | return torch.full([2, ], fill_value=np.nan) 75 | else: 76 | return str2tensor(self._pixel_spacing) 77 | 78 | @lazy_property 79 | def image_position(self): 80 | if self._image_position is None: 81 | return torch.full([3, ], fill_value=np.nan) 82 | else: 83 | return str2tensor(self._image_position) 84 | 85 | @lazy_property 86 | def image_orientation(self): 87 | if self._image_orientation is None: 88 | return torch.full([2, 3], fill_value=np.nan) 89 | else: 90 | return unit_vector(str2tensor(self._image_orientation).reshape(2, 3)) 91 | 92 | @lazy_property 93 | def unit_normal_vector(self): 94 | if self.image_orientation is None: 95 | return torch.full([3, ], fill_value=np.nan) 96 | else: 97 | return unit_normal_vector(self.image_orientation) 98 | 99 | @lazy_property 100 | def t_type(self): 101 | if 'T1' in self.series_description.upper(): 102 | return 'T1' 103 | elif 'T2' in self.series_description.upper(): 104 | return 'T2' 105 | else: 106 | return None 107 | 108 | @lazy_property 109 | def t_info(self): 110 | return self.series_description.upper() 111 | 112 | @lazy_property 113 | def plane(self): 114 | if torch.isnan(self.unit_normal_vector).all(): 115 | return None 116 | elif torch.matmul(self.unit_normal_vector, torch.tensor([0., 0., 1.])).abs() > 0.75: 117 | return 'transverse' 118 | elif torch.matmul(self.unit_normal_vector, torch.tensor([1., 0., 0.])).abs() > 0.75: 119 | return 'sagittal' 120 | elif torch.matmul(self.unit_normal_vector, torch.tensor([0., 1., 0.])).abs() > 0.75: 121 | return 'coronal' 122 | else: 123 | return None 124 | 125 | @lazy_property 126 | def mean(self): 127 | if self.image is None: 128 | return None 129 | else: 130 | return tf.to_tensor(self.image).mean() 131 | 132 | @property 133 | def size(self): 134 | if self.image is None: 135 | return None 136 | else: 137 | return self.image.size 138 | 139 | def pixel_coord2human_coord(self, coord: torch.Tensor) -> torch.Tensor: 140 | return torch.matmul(coord * self.pixel_spacing, self.image_orientation) + self.image_position 141 | 142 | def point_distance(self, human_coord: torch.Tensor) -> torch.Tensor: 143 | return torch.matmul(human_coord - self.image_position, self.unit_normal_vector).abs() 144 | 145 | def projection(self, human_coord: torch.Tensor) -> torch.Tensor: 146 | cos = torch.matmul(human_coord - self.image_position, self.image_orientation.transpose(0, 1)) 147 | return (cos / self.pixel_spacing).round() 148 | 149 | def transform(self, trans_dict, pixel_coord: torch.Tensor, tensor=True, isaug=True) -> ( 150 | torch.Tensor, torch.Tensor, dict): 151 | image = self.image 152 | pixel_spacing = self.pixel_spacing 153 | 154 | if 
len(pixel_coord.shape) == 1: 155 | pixel_coord = pixel_coord.unsqueeze(dim=0) 156 | 157 | details = {'is_hflip': False} 158 | if isaug and np.random.uniform(0, 1) < trans_dict['rotate_p']: 159 | image, pixel_coord = random_rotate(image, pixel_coord, trans_dict['max_angel']) 160 | if isaug and np.random.uniform(0, 1) < trans_dict['shift_p']: 161 | image, pixel_coord = random_shift(image, pixel_coord, trans_dict['max_shift_ratios']) 162 | if isaug and np.random.uniform(0, 1) < trans_dict['crop_p']: 163 | image, pixel_coord = random_crop(image, pixel_coord, trans_dict['max_crop_ratios']) 164 | if isaug and np.random.uniform(0, 1) < trans_dict['hflip_p']: 165 | image, pixel_coord = random_hflip(image, pixel_coord) 166 | details['is_hflip'] = True 167 | if isaug and np.random.uniform(0, 1) < trans_dict['intensity_p']: 168 | image = random_intensity(image, trans_dict['max_intensity_ratio']) 169 | if isaug and np.random.uniform(0, 1) < trans_dict['noise_p']: 170 | image = random_noise(image, trans_dict['noise_mean'], trans_dict['noise_std']) 171 | if isaug and np.random.uniform(0, 1) < trans_dict['bias_field_p']: 172 | image = random_bias_field(image, trans_dict['order'], trans_dict['coefficients_range']) 173 | # resize 174 | if trans_dict['size'] is not None: 175 | image, pixel_coord, pixel_spacing = resize(image, pixel_coord, pixel_spacing, trans_dict['size']) 176 | 177 | if tensor: 178 | image = tf.to_tensor(image) 179 | pixel_coord = pixel_coord.round().long() 180 | 181 | return image, pixel_coord, details 182 | 183 | def transverse_transform(self, trans_dict, pixel_coord: torch.Tensor, tensor=True, isaug=True) -> ( 184 | torch.Tensor, torch.Tensor): 185 | image = self.image 186 | pixel_spacing = self.pixel_spacing 187 | 188 | if trans_dict['size'] is not None: 189 | transverse_size = (int(trans_dict['size'][0] / pixel_spacing[0].item()), 190 | int(trans_dict['size'][1] / pixel_spacing[1].item())) 191 | image = tf.crop(image, 192 | int(pixel_coord[1].item() - transverse_size[0] // 2), 193 | int(pixel_coord[0].item() - transverse_size[1] // 2), 194 | transverse_size[0], 195 | transverse_size[1]) 196 | 197 | if len(pixel_coord.shape) == 1: 198 | pixel_coord = pixel_coord.unsqueeze(dim=0) 199 | 200 | if isaug and np.random.uniform(0, 1) < trans_dict['rotate_p']: 201 | image, pixel_coord = random_rotate(image, pixel_coord, trans_dict['max_angel']) 202 | if isaug and np.random.uniform(0, 1) < trans_dict['shift_p']: 203 | image, pixel_coord = random_shift(image, pixel_coord, trans_dict['max_shift_ratios']) 204 | if isaug and np.random.uniform(0, 1) < trans_dict['crop_p']: 205 | image, pixel_coord = random_crop(image, pixel_coord, trans_dict['max_crop_ratios']) 206 | if isaug and np.random.uniform(0, 1) < trans_dict['hflip_p']: 207 | image, pixel_coord = random_hflip(image, pixel_coord) 208 | if isaug and np.random.uniform(0, 1) < trans_dict['intensity_p']: 209 | image = random_intensity(image, trans_dict['max_intensity_ratio']) 210 | if isaug and np.random.uniform(0, 1) < trans_dict['noise_p']: 211 | image = random_noise(image, trans_dict['noise_mean'], trans_dict['noise_std']) 212 | if isaug and np.random.uniform(0, 1) < trans_dict['bias_field_p']: 213 | image = random_bias_field(image, trans_dict['order'], trans_dict['coefficients_range']) 214 | # resize 215 | if trans_dict['size'] is not None: 216 | image, pixel_coord, pixel_spacing = resize(image, pixel_coord, pixel_spacing, trans_dict['size']) 217 | 218 | if tensor: 219 | image = tf.to_tensor(image) 220 | pixel_coord = 
pixel_coord.round().long()
221 | 
222 |         return image, pixel_coord
--------------------------------------------------------------------------------
/utils/functions.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import os
6 | import json
7 | import torch
8 | import numpy as np
9 | from math import log
10 | from utils.ccf_utils import (  # the repo ships ccf_utils, not cl_utils; aliasing keeps the cl_* call sites below unchanged
11 |     ccf_decoder as cl_decoder, batch_ccf_decoder as batch_cl_decoder, batch_ccf_analysis as batch_cl_analysis
12 | )
13 | import torchvision.transforms.functional as tf
14 | from utils.utils import (
15 |     gen_annotation,
16 |     confusion_matrix,
17 |     confusion_matrix_with_image,
18 |     cal_metrics,
19 |     format_annotation
20 | )
21 | from tqdm import tqdm
22 | import time
23 | 
24 | def bagging_train_valid(config, epoch, model, loader, criterion, params, optimizer=None, train=True):
25 |     model.train() if train else model.eval()
26 |     debug_records = []
27 |     epoch_records = {'time': []}
28 |     num_batchs = len(loader)
29 |     v_loss_weight, d_loss_weight = params['v_loss_weight'], params['d_loss_weight']
30 |     print('v loss weight: {} | d loss weight: {}'.format(v_loss_weight, d_loss_weight))
31 |     for batch_idx, (sagittal_images, heatmaps, offymaps, offxmaps, maskmaps, clssmaps, _, _) in enumerate(loader):
32 |         start_time = time.time()
33 |         images = sagittal_images.float().to(config.device)
34 |         heat_targets = heatmaps.float().to(config.device)
35 |         offy_targets = offymaps.float().to(config.device)
36 |         offx_targets = offxmaps.float().to(config.device)
37 |         masks = maskmaps.float().to(config.device)
38 |         clss_targets = clssmaps.long().to(config.device)
39 | 
40 |         if train:
41 |             predictions = model(images, heatmaps=heat_targets, offymaps=offy_targets, offxmaps=offx_targets)
42 |             heatmap, offymap, offxmap, clssmap = predictions['result']
43 |             loss, info = criterion(heatmap, offymap, offxmap, clssmap,
44 |                                    heat_targets, offy_targets, offx_targets, clss_targets, masks,
45 |                                    v_loss_weight, d_loss_weight, name='predict')
46 |             loss = torch.abs(loss - config.flooding_d) + config.flooding_d  # "flooding": keeps the training loss from sinking below the floor config.flooding_d (Ishida et al., 2020)
47 | 
48 |             optimizer.zero_grad()
49 |             loss.backward()
50 |             optimizer.step()
51 |         else:
52 |             with torch.no_grad():
53 |                 predictions = model(images, heatmaps=heat_targets, offymaps=offy_targets, offxmaps=offx_targets)
54 |                 heatmap, offymap, offxmap, clssmap = predictions['result']
55 |                 loss, info = criterion(heatmap, offymap, offxmap, clssmap,
56 |                                        heat_targets, offy_targets, offx_targets, clss_targets, masks,
57 |                                        v_loss_weight, d_loss_weight, name='predict')
58 | 
59 |         for key, value in info.items():
60 |             if key not in epoch_records: epoch_records[key] = []
61 |             epoch_records[key].append(value)
62 |         epoch_records['time'].append(time.time() - start_time)
63 | 
64 |         batch_records = batch_cl_analysis(heat_targets, clss_targets, clssmap)
65 |         debug_records.extend(batch_records)
66 | 
67 |         if (batch_idx and batch_idx % config.display == 0) or (batch_idx == num_batchs - 1):
68 |             context = '[{}] EP:{:03d}\tTI:{:03d}/{:03d}\t'.format('T' if train else 'V', epoch, batch_idx, num_batchs)
69 |             context_predict = '\t'
70 |             context_class = '\t'
71 |             for key, value in epoch_records.items():
72 |                 if 'cls' in key:
73 |                     context_class += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value))
74 |                 elif 'predict' in key:
75 |                     context_predict += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value))
76 |                 else:
77 |                     context += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value))
78 |             print(context)
79 |             print(context_predict)
80 |             print(context_class)
81 | 
82 |     if not train:
83 |         params['v_loss'].append(np.mean(epoch_records['predict_v_cls_loss']))
84 |         params['d_loss'].append(np.mean(epoch_records['predict_d_cls_loss']))
85 | 
86 |     return np.mean(epoch_records['loss']), params, epoch_records
87 | 
88 | 
89 | 
90 | def train(config, epoch, model, train_loader, criterion, optimizer):
91 |     model.train()
92 |     debug_records = []
93 |     epoch_records = {'time': []}
94 |     num_batchs = len(train_loader)
95 |     for batch_idx, (sagittal_images, heatmaps, offymaps, offxmaps, maskmaps, clssmaps, _, _) in enumerate(
96 |             train_loader):
97 |         start_time = time.time()
98 |         images = sagittal_images.float().to(config.device)
99 |         heat_targets = heatmaps.float().to(config.device)
100 |         offy_targets = offymaps.float().to(config.device)
101 |         offx_targets = offxmaps.float().to(config.device)
102 |         masks = maskmaps.float().to(config.device)
103 |         clss_targets = clssmaps.long().to(config.device)
104 | 
105 |         predictions = model(images, heatmaps=heat_targets, offymaps=offy_targets, offxmaps=offx_targets)
106 | 
107 |         heatmap, offymap, offxmap, clssmap = predictions['result']
108 |         loss, info = criterion(heatmap, offymap, offxmap, clssmap,
109 |                                heat_targets, offy_targets, offx_targets, clss_targets, masks, name='predict')
110 | 
111 |         loss = torch.abs(loss - config.flooding_d) + config.flooding_d
112 | 
113 |         optimizer.zero_grad()
114 |         loss.backward()
115 |         optimizer.step()
116 | 
117 |         for key, value in info.items():
118 |             if key not in epoch_records: epoch_records[key] = []
119 |             epoch_records[key].append(value)
120 |         epoch_records['time'].append(time.time() - start_time)
121 | 
122 |         batch_records = batch_cl_analysis(heat_targets, clss_targets, clssmap)
123 |         debug_records.extend(batch_records)
124 | 
125 |         if (batch_idx and batch_idx % config.display == 0) or (batch_idx == num_batchs - 1):
126 |             context = 'EP:{:03d}\tTI:{:03d}/{:03d}\t'.format(epoch, batch_idx, num_batchs)
127 |             context_predict = '\t'
128 |             context_class = '\t'
129 |             for key, value in epoch_records.items():
130 |                 if 'cls' in key:
131 |                     context_class += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value))
132 |                 elif 'predict' in key:
133 |                     context_predict += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value))
134 |                 else:
135 |                     context += '{}:{:.6f}({:.6f})\t'.format(key, value[-1], np.mean(value))
136 |             print(context)
137 |             print(context_predict)
138 |             print(context_class)
139 | 
140 |     # with open('debug/{:03d}.txt'.format(epoch), 'w') as f:
141 |     #     for record in debug_records:
142 |     #         f.write(str(record)+'\n')
143 | 
144 |     return epoch_records
145 | 
146 | def compute_distance(P, Q, mode):
147 |     if mode == 'kl':
148 |         def _asymmetricKL(P_, Q_):
149 |             return np.sum(P_ * np.log(P_ / Q_))  # math.log only accepts scalars; P_ and Q_ are numpy arrays here
150 |         return (_asymmetricKL(P, Q) + _asymmetricKL(Q, P)) / 2.
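    # Editorial note: the 'kl' branch above is the symmetrised KL divergence
    # D(P, Q) = (KL(P || Q) + KL(Q || P)) / 2 and assumes P and Q are strictly
    # positive probability vectors. The 'mse' branch below is the one
    # `evaluate` actually uses (mode='mse'); because its result only feeds an
    # argmin over candidate class tensors, the fractional outer exponent is a
    # monotonic rescaling and does not change which class is selected.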
151 | elif mode == 'mse': 152 | return np.sum((P - Q) ** 2) ** 0.2 153 | 154 | def evaluate(config, model, studies, annotation_path=None, trans_model=None, class_tensors=None): 155 | model.eval() 156 | annotations = [] 157 | for study in studies.values(): 158 | kp_frame = study.t2_sagittal_middle_frame 159 | sagittal_image = tf.resize(kp_frame.image, config.sagittal_size) 160 | sagittal_image = tf.to_tensor(sagittal_image).unsqueeze(dim=0).float().to(config.device) 161 | with torch.no_grad(): 162 | predictions = model(sagittal_image) 163 | heatmaps, offymaps, offxmaps, clssmaps = predictions['result'] 164 | if config.test_flip: 165 | sagittal_image_flipped = np.flip(sagittal_image.detach().cpu().numpy(), 3).copy() 166 | sagittal_image_flipped = torch.from_numpy(sagittal_image_flipped).to(sagittal_image.device) 167 | predictions_flipped = model(sagittal_image_flipped) 168 | heatmaps_flipped, _, _, clssmaps_flipped = predictions_flipped['result'] 169 | heatmaps = (heatmaps + torch.from_numpy(np.flip(heatmaps_flipped.detach().cpu().numpy(), 3).copy()).to( 170 | heatmaps.device)) / 2. 171 | clssmaps = (clssmaps + torch.from_numpy(np.flip(clssmaps_flipped.detach().cpu().numpy(), 3).copy()).to( 172 | clssmaps.device)) / 2. 173 | heatmaps = torch.mean(heatmaps, dim=0, keepdim=True) 174 | offymaps = torch.mean(offymaps, dim=0, keepdim=True) 175 | offxmaps = torch.mean(offxmaps, dim=0, keepdim=True) 176 | clssmaps = torch.mean(clssmaps, dim=0, keepdim=True) 177 | prediction, _, cls_prediction, d_v5_cls_prediction = cl_decoder( 178 | config, heatmaps[0], offymaps[0], offxmaps[0], clssmaps[0], maskmap=None, tensor=True) 179 | if (trans_model is not None) and (class_tensors is not None): 180 | try: 181 | with torch.no_grad(): 182 | trans_images, _ = study.t2_transverse_k_nearest(config, False, prediction[5:, :2].float(), 1) 183 | trans_images = trans_images[:, 0, :].repeat(1, 3, 1, 1).float().to(config.device) # (6, 3, 96, 96) 184 | trans_outputs = model(trans_images, True) 185 | for d_index, (trans_image, trans_output) in enumerate(zip(trans_images, trans_outputs)): 186 | trans_image = trans_image.detach().cpu().numpy() 187 | if trans_image.sum() == 0: continue 188 | trans_output = trans_output.detach().cpu().numpy() 189 | distance = np.array([compute_distance(class_tensor, output, mode='mse') for class_tensor in class_tensors]) 190 | cls_prediction[5+d_index] = np.argmin(distance) 191 | except: 192 | pass 193 | height_ratio = config.sagittal_size[0] / kp_frame.size[1] 194 | width_ratio = config.sagittal_size[1] / kp_frame.size[0] 195 | ratio = torch.tensor([width_ratio, height_ratio], device=prediction.device) 196 | prediction = (prediction / ratio).round().float() 197 | annotation = gen_annotation(config, study, prediction, cls_prediction, d_v5_cls_prediction) 198 | annotations.append(annotation) 199 | 200 | if annotation_path is None: 201 | return annotations 202 | else: 203 | predictions = annotations 204 | with open(annotation_path, 'r') as file: 205 | annotations = json.load(file) 206 | 207 | annotations = format_annotation(annotations) 208 | 209 | matrix = confusion_matrix(config, studies, predictions, annotations) 210 | outputs = cal_metrics(matrix) 211 | 212 | i = 0 213 | while i < len(outputs) and outputs[i][0] != config.metric: 214 | i += 1 215 | if i < len(outputs): 216 | outputs = [outputs[i]] + outputs[:i] + outputs[i + 1:] 217 | 218 | return outputs 219 | 220 | 221 | def valid(config, model, valid_studies, annotation_path, trans_model, class_tensors): 222 | metrics_values = 
evaluate(config, model, valid_studies, annotation_path=annotation_path, trans_model=trans_model, class_tensors=class_tensors) 223 | metrics_value = metrics_values[1][1] 224 | for data in metrics_values: 225 | print('valid {}: '.format(data[0]), *data[1:]) 226 | return metrics_value 227 | 228 | 229 | def testA(config, model, testA_studies, submission_file, trans_model, class_tensors): 230 | predictions = evaluate(config, model, testA_studies, annotation_path=None, trans_model=trans_model, class_tensors=class_tensors) 231 | with open(submission_file, 'w') as file: 232 | json.dump(predictions, file) 233 | print('=' * 100) 234 | print('Generated submission file: {}'.format(submission_file)) 235 | 236 | 237 | def testB(config, model, testA_studies, submission_file, trans_model, class_tensors): 238 | return testA(config, model, testA_studies, submission_file, trans_model=trans_model, class_tensors=class_tensors) 239 | 240 | def vote_by_models(models_prediction, models_cls_prediction, models_d_v5_cls_prediction, tensor=True): 241 | prediction = np.mean(np.stack(models_prediction, axis=0), axis=0) 242 | models_cls_prediction = np.transpose(np.stack(models_cls_prediction, axis=0), [1, 0]) # (11, N) 243 | cls_prediction = np.array([np.argmax(np.bincount(line)) for line in models_cls_prediction]) 244 | models_d_v5_cls_prediction = np.transpose(np.stack(models_d_v5_cls_prediction, axis=0), [1, 0]) # (6, N) 245 | d_v5_cls_prediction = np.array([np.argmax(np.bincount(line)) for line in models_d_v5_cls_prediction]) 246 | if tensor: 247 | prediction = torch.from_numpy(prediction) 248 | cls_prediction = torch.from_numpy(cls_prediction) 249 | d_v5_cls_prediction = torch.from_numpy(d_v5_cls_prediction) 250 | return prediction, cls_prediction, d_v5_cls_prediction 251 | 252 | def evaluate_cross(config, models, studies, annotation_path=None, trans_model=None, class_tensors=None): 253 | for model in models: 254 | model.eval() 255 | annotations = [] 256 | for study in studies.values(): 257 | with torch.no_grad(): 258 | kp_frame = study.t2_sagittal_middle_frame 259 | sagittal_image = tf.resize(kp_frame.image, config.sagittal_size) 260 | sagittal_image = tf.to_tensor(sagittal_image).unsqueeze(dim=0).float().to(config.device) 261 | models_prediction, models_cls_prediction, models_d_v5_cls_prediction = [], [], [] 262 | models_heatmaps, models_offymaps, models_offxmaps, models_clssmaps = [], [], [], [] 263 | for model in models: 264 | predictions = model(sagittal_image) 265 | heatmaps, offymaps, offxmaps, clssmaps = predictions['result'] 266 | if config.test_flip: 267 | sagittal_image_flipped = np.flip(sagittal_image.detach().cpu().numpy(), 3).copy() 268 | sagittal_image_flipped = torch.from_numpy(sagittal_image_flipped).to(sagittal_image.device) 269 | predictions_flipped = model(sagittal_image_flipped) 270 | heatmaps_flipped, _, _, clssmaps_flipped = predictions_flipped['result'] 271 | heatmaps = (heatmaps + torch.from_numpy(np.flip(heatmaps_flipped.detach().cpu().numpy(), 3).copy()).to( 272 | heatmaps.device)) / 2. 273 | clssmaps = (clssmaps + torch.from_numpy(np.flip(clssmaps_flipped.detach().cpu().numpy(), 3).copy()).to( 274 | clssmaps.device)) / 2. 
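                # Editorial note: the ensemble fuses models along two routes.
                # Each model's maps (flip-averaged when config.test_flip is
                # set) are decoded individually, so class labels can later be
                # combined by majority vote in `vote_by_models`, while the
                # maps themselves are also accumulated so the keypoint
                # coordinates are decoded once from the element-wise mean
                # heatmap/offset maps of all models.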
275 |                 models_heatmaps.append(heatmaps[0])
276 |                 models_offymaps.append(offymaps[0])
277 |                 models_offxmaps.append(offxmaps[0])
278 |                 models_clssmaps.append(clssmaps[0])
279 |                 prediction, _, cls_prediction, d_v5_cls_prediction = cl_decoder(
280 |                     config, heatmaps[0], offymaps[0], offxmaps[0], clssmaps[0], maskmap=None, tensor=False)
281 |                 models_prediction.append(prediction)
282 |                 models_cls_prediction.append(cls_prediction)
283 |                 models_d_v5_cls_prediction.append(d_v5_cls_prediction)
284 |             models_heatmap = torch.mean(torch.stack(models_heatmaps, dim=0), dim=0)
285 |             models_offymap = torch.mean(torch.stack(models_offymaps, dim=0), dim=0)
286 |             models_offxmap = torch.mean(torch.stack(models_offxmaps, dim=0), dim=0)
287 |             models_clssmap = torch.mean(torch.stack(models_clssmaps, dim=0), dim=0)
288 |             _, cls_prediction, d_v5_cls_prediction = vote_by_models(models_prediction, models_cls_prediction,
289 |                                                                     models_d_v5_cls_prediction, tensor=True)
290 |             prediction, _, _, _ = cl_decoder(config, models_heatmap, models_offymap, models_offxmap,
291 |                                              models_clssmap, maskmap=None, tensor=True)
292 |             height_ratio = config.sagittal_size[0] / kp_frame.size[1]
293 |             width_ratio = config.sagittal_size[1] / kp_frame.size[0]
294 |             ratio = torch.tensor([width_ratio, height_ratio], device=prediction.device)
295 |             prediction = (prediction / ratio).round().float()
296 |             annotation = gen_annotation(config, study, prediction, cls_prediction, d_v5_cls_prediction)
297 |             annotations.append(annotation)
298 | 
299 |     if annotation_path is None:
300 |         return annotations
301 |     else:
302 |         predictions = annotations
303 |         with open(annotation_path, 'r') as file:
304 |             annotations = json.load(file)
305 | 
306 |         annotations = format_annotation(annotations)
307 | 
308 |         matrix = confusion_matrix(config, studies, predictions, annotations)
309 |         outputs = cal_metrics(matrix)
310 | 
311 |         i = 0
312 |         while i < len(outputs) and outputs[i][0] != config.metric:
313 |             i += 1
314 |         if i < len(outputs):
315 |             outputs = [outputs[i]] + outputs[:i] + outputs[i + 1:]
316 | 
317 |         return outputs
318 | 
319 | def testB_cross(config, models, testB_studies, submission_file, trans_model, class_tensors):
320 |     predictions = evaluate_cross(config, models, testB_studies, annotation_path=None, trans_model=trans_model, class_tensors=class_tensors)
321 |     with open(submission_file, 'w') as file:
322 |         json.dump(predictions, file)
323 |     print('=' * 100)
324 |     print('Generated submission file: {}'.format(submission_file))
325 | 
326 | 
--------------------------------------------------------------------------------
/utils/series_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import numpy as np
6 | import torch
7 | from collections import Counter
8 | from typing import List
9 | from utils.dicom_utils import DICOM
10 | from utils.utils import lazy_property
11 | from PIL import Image
12 | 
13 | 
14 | class Series(list):
15 |     def __init__(self, dicom_list: List[DICOM]):
16 |         planes = [dicom.plane for dicom in dicom_list]
17 |         plane_counter = Counter(planes)
18 |         self.plane = plane_counter.most_common(1)[0][0]
19 | 
20 |         if self.plane == 'transverse':
21 |             dim = 2
22 |         elif self.plane == 'sagittal':
23 |             dim = 0
24 |         elif self.plane == 'coronal':  # was a duplicated 'transverse' branch, leaving dim = 1 unreachable; coronal series sort along the y axis
25 |             dim = 1
26 |         else:
27 |             dim = None
28 | 
29 |         dicom_list = [dicom for dicom in dicom_list if dicom.plane == self.plane]
30 |         if dim is not None:
31 | 
dicom_list = sorted(dicom_list, key=lambda x: x.image_position[dim], reverse=True) 32 | if self.plane == 'sagittal': 33 | images = [np.asarray(dicom.image) for dicom in dicom_list if dicom.image is not None] 34 | if len(images) > 0: 35 | images = [images[0]] + images + [images[-1]] 36 | for idx in range(len(dicom_list)): 37 | try: 38 | dicom_list[idx].image = Image.fromarray( 39 | np.stack(images[idx + 1 - 1: idx + 1 + 2], axis=2).astype(np.uint8)) 40 | except: 41 | dicom_list[idx].image = Image.fromarray( 42 | np.stack([images[idx] for _ in range(3)], axis=2).astype(np.uint8)) 43 | super().__init__(dicom_list) 44 | self.instance_uids = {d.instance_uid: i for i, d in enumerate(self)} 45 | self.instance_idxs = {i: d.instance_uid for i, d in enumerate(self)} 46 | self.middle_frame_uid = None 47 | 48 | def __getitem__(self, item) -> DICOM: 49 | if isinstance(item, str): 50 | item = self.instance_uids[item] 51 | return super().__getitem__(item) 52 | 53 | @lazy_property 54 | def t_type(self): 55 | t_type_counter = Counter([d.t_type for d in self]) 56 | return t_type_counter.most_common(1)[0][0] 57 | 58 | @lazy_property 59 | def t_info(self): 60 | t_info_counter = Counter([d.t_info for d in self]) 61 | return t_info_counter.most_common(1)[0][0] 62 | 63 | @lazy_property 64 | def mean(self): 65 | output = 0 66 | i = 0 67 | for dicom in self: 68 | mean = dicom.mean 69 | if mean is None: 70 | continue 71 | output = i / (i + 1) * output + mean / (i + 1) 72 | i += 1 73 | return output 74 | 75 | @property 76 | def middle_frame(self) -> DICOM: 77 | if self.middle_frame_uid is not None: 78 | return self[self.middle_frame_uid] 79 | else: 80 | return self[len(self) // 2] 81 | 82 | def set_middle_frame(self, instance_uid): 83 | self.middle_frame_uid = instance_uid 84 | 85 | @property 86 | def image_positions(self): 87 | positions = [] 88 | for dicom in self: 89 | positions.append(dicom.image_position) 90 | return torch.stack(positions, dim=0) 91 | 92 | @property 93 | def unit_normal_vectors(self): 94 | vectors = [] 95 | for dicom in self: 96 | vectors.append(dicom.unit_normal_vector) 97 | return torch.stack(vectors, dim=0) 98 | 99 | @lazy_property 100 | def series_uid(self): 101 | study_uid_counter = Counter([d.series_uid for d in self]) 102 | return study_uid_counter.most_common(1)[0][0] 103 | 104 | @lazy_property 105 | def study_uid(self): 106 | study_uid_counter = Counter([d.study_uid for d in self]) 107 | return study_uid_counter.most_common(1)[0][0] 108 | 109 | def point_distance(self, coord: torch.Tensor): 110 | return torch.stack([dicom.point_distance(coord) for dicom in self], dim=1).squeeze() 111 | 112 | def k_nearest(self, coord: torch.Tensor, k, max_dist) -> List[List[DICOM]]: 113 | distance = self.point_distance(coord) 114 | indices = torch.argsort(distance, dim=1) 115 | if len(indices) == 1: 116 | return [[self[i] if distance[i] < max_dist else None for i in indices[:k]]] 117 | else: 118 | return [[self[i] if row_d[i] < max_dist else None for i in row[:k]] 119 | for row, row_d in zip(indices, distance)] 120 | -------------------------------------------------------------------------------- /utils/study_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import json 7 | import torch 8 | from tqdm import tqdm 9 | from multiprocessing import Pool, cpu_count 10 | from collections import Counter 11 | from typing import 
Dict, Union 12 | from utils.utils import lazy_property 13 | from utils.dicom_utils import DICOM 14 | from utils.series_utils import Series 15 | from utils.utils import read_annotation 16 | from torchvision.transforms import functional as tf 17 | 18 | 19 | class Study(dict): 20 | def __init__(self, config, study_dir, pool=None): 21 | dicom_list = [] 22 | if pool is not None: 23 | async_results = [] 24 | for dicom_name in os.listdir(study_dir): 25 | dicom_path = os.path.join(study_dir, dicom_name) 26 | async_results.append(pool.apply_async(DICOM, (dicom_path, config.DICOM_TAG,))) 27 | 28 | for async_result in async_results: 29 | async_result.wait() 30 | dicom = async_result.get() 31 | dicom_list.append(dicom) 32 | else: 33 | for dicom_name in os.listdir(study_dir): 34 | dicom_path = os.path.join(study_dir, dicom_name) 35 | dicom = DICOM(dicom_path, config.DICOM_TAG) 36 | dicom_list.append(dicom) 37 | 38 | dicom_dict = {} 39 | for dicom in dicom_list: 40 | series_uid = dicom.series_uid 41 | if series_uid not in dicom_dict: 42 | dicom_dict[series_uid] = [dicom] 43 | else: 44 | dicom_dict[series_uid].append(dicom) 45 | 46 | super().__init__({k: Series(v) for k, v in dicom_dict.items()}) 47 | 48 | self.t1_sagittal_uid = None 49 | self.t2_sagittal_uid = None 50 | self.t2_transverse_uid = None 51 | max_t1_sagittal_mean = 0 52 | max_t2_sagittal_mean = 0 53 | max_t2_transverse_mean = 0 54 | for series_uid, series in self.items(): 55 | if series.plane == 'sagittal' and series.t_type == 'T2': 56 | t2_sagittal_mean = series.mean 57 | if t2_sagittal_mean > max_t2_sagittal_mean: 58 | max_t2_sagittal_mean = t2_sagittal_mean 59 | self.t2_sagittal_uid = series_uid 60 | if series.plane == 'transverse' and series.t_type == 'T2': 61 | t2_transverse_mean = series.mean 62 | if t2_transverse_mean > max_t2_transverse_mean: 63 | max_t2_transverse_mean = t2_transverse_mean 64 | self.t2_transverse_uid = series_uid 65 | if series.plane == 'sagittal' and series.t_type == 'T1': 66 | t1_sagittal_mean = series.mean 67 | if t1_sagittal_mean > max_t1_sagittal_mean: 68 | max_t1_sagittal_mean = t1_sagittal_mean 69 | self.t1_sagittal_uid = series_uid 70 | 71 | if self.t2_sagittal_uid is None: 72 | for series_uid, series in self.items(): 73 | if series.plane == 'sagittal': 74 | t2_sagittal_mean = series.mean 75 | if t2_sagittal_mean > max_t2_sagittal_mean: 76 | max_t2_sagittal_mean = t2_sagittal_mean 77 | self.t2_sagittal_uid = series_uid 78 | 79 | if self.t2_transverse_uid is None: 80 | for series_uid, series in self.items(): 81 | if series.plane == 'transverse': 82 | t2_transverse_mean = series.mean 83 | if t2_transverse_mean > max_t2_transverse_mean: 84 | max_t2_transverse_mean = t2_transverse_mean 85 | self.t2_transverse_uid = series_uid 86 | 87 | if self.t1_sagittal_uid is None: 88 | self.t1_sagittal_uid = self.t2_sagittal_uid 89 | 90 | @lazy_property 91 | def study_uid(self): 92 | study_uid_counter = Counter([s.study_uid for s in self.values()]) 93 | return study_uid_counter.most_common(1)[0][0] 94 | 95 | @property 96 | def t2_sagittal(self) -> Union[None, Series]: 97 | if self.t2_sagittal_uid is None: 98 | return None 99 | else: 100 | return self[self.t2_sagittal_uid] 101 | 102 | @property 103 | def t1_sagittal(self) -> Union[None, Series]: 104 | if self.t1_sagittal_uid is None: 105 | return None 106 | else: 107 | return self[self.t1_sagittal_uid] 108 | 109 | @property 110 | def t2_transverse(self) -> Union[None, Series]: 111 | if self.t2_transverse_uid is None: 112 | return None 113 | else: 114 | return 
self[self.t2_transverse_uid] 115 | 116 | @property 117 | def t2_sagittal_middle_frame(self) -> Union[None, DICOM]: 118 | if self.t2_sagittal is None: 119 | return None 120 | else: 121 | return self.t2_sagittal.middle_frame 122 | 123 | @property 124 | def t1_sagittal_middle_frame(self) -> Union[None, DICOM]: 125 | if self.t1_sagittal is None: 126 | return None 127 | else: 128 | return self.t1_sagittal.middle_frame 129 | 130 | @property 131 | def t2_sagittal_middle_frame_off(self) -> Union[None, DICOM]: 132 | if self.t2_sagittal is None: 133 | return None 134 | else: 135 | return self.t2_sagittal.middle_frame_off 136 | 137 | def set_t2_sagittal_middle_frame(self, series_uid, instance_uid): 138 | assert series_uid in self 139 | self.t2_sagittal_uid = series_uid 140 | self.t2_sagittal.set_middle_frame(instance_uid) 141 | 142 | def _gen_t2_transverse_one_image_mask(self, config, point, dicom): 143 | if dicom is None: 144 | mask = True 145 | image = torch.zeros(1, *config.transverse_size) 146 | else: 147 | mask = False 148 | projection = dicom.projection(point) 149 | pixel_spacing = dicom.pixel_spacing 150 | image, projection = dicom.transform(config.transverse_trans_settings, projection, tensor=False, isaug=True) 151 | if len(projection.shape) == 2: 152 | projection = projection[0] 153 | transverse_size = (int(config.transverse_size[0] / pixel_spacing[0].item()), 154 | int(config.transverse_size[1] / pixel_spacing[1].item())) 155 | image = tf.crop(image, 156 | int(projection[1] - transverse_size[0] // 2), 157 | int(projection[0] - transverse_size[1] // 2), 158 | transverse_size[0], 159 | transverse_size[1]) 160 | image = tf.resize(image, config.transverse_size) 161 | image = tf.to_tensor(image) 162 | return image, mask 163 | 164 | def _gen_t2_transverse_batch_image_mask(self, config, point, series): 165 | results = list( 166 | map(self._gen_t2_transverse_one_image_mask, [config] * len(series), [point] * len(series), series)) 167 | temp_images, temp_masks = [], [] 168 | for image, mask in results: 169 | temp_images.append(image) 170 | temp_masks.append(mask) 171 | temp_images = torch.stack(temp_images, dim=0) 172 | return temp_images, temp_masks 173 | 174 | def t2_transverse_k_nearest(self, config, isaug, pixel_coord, k): 175 | if k <= 0 or self.t2_transverse is None: 176 | images = torch.zeros(pixel_coord.shape[0], k, 1, *config.transverse_size) 177 | masks = torch.zeros(*images.shape[:2], dtype=torch.bool) 178 | return images, masks 179 | human_coord = self.t2_sagittal_middle_frame.pixel_coord2human_coord(pixel_coord) 180 | dicoms = self.t2_transverse.k_nearest(human_coord, k, config.max_dist * 2) 181 | images = [] 182 | masks = [] 183 | for point, series in zip(human_coord, dicoms): 184 | temp_images = [] 185 | temp_masks = [] 186 | for dicom in series: 187 | if dicom is None: 188 | temp_masks.append(False) 189 | image = torch.zeros(1, *config.transverse_size) 190 | else: 191 | temp_masks.append(True) 192 | projection = dicom.projection(point) 193 | pixel_spacing = dicom.pixel_spacing 194 | image, projection = dicom.transverse_transform( 195 | config.transverse_trans_settings, projection, tensor=True, isaug=isaug) 196 | temp_images.append(image) 197 | temp_images = torch.stack(temp_images, dim=0) 198 | images.append(temp_images) 199 | masks.append(temp_masks) 200 | images = torch.stack(images, dim=0) 201 | masks = torch.tensor(masks, dtype=torch.bool) 202 | return images, masks 203 | 204 | def _construct_studies(config, data_dir, multiprocessing=False): 205 | studies: Dict[str, Study] = {} 
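    # Editorial note: with `multiprocessing` enabled, a process pool is passed
    # into each Study so that the SimpleITK reads in DICOM.__init__ run
    # concurrently; Study.__init__ then collects the async results. A minimal
    # usage sketch of the public wrapper `construct_studies` defined below
    # (paths are hypothetical):
    #
    # >>> studies, annotation, counter = construct_studies(
    # ...     config, 'data/lumbar_train', 'data/lumbar_train.json',
    # ...     multi_processing=True)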
206 | if multiprocessing: 207 | pool = Pool(cpu_count()) 208 | else: 209 | pool = None 210 | 211 | for study_name in tqdm(os.listdir(data_dir), ascii=True): 212 | study_dir = os.path.join(data_dir, study_name) 213 | study = Study(config, study_dir, pool) 214 | studies[study.study_uid] = study 215 | 216 | if pool is not None: 217 | pool.close() 218 | pool.join() 219 | 220 | return studies 221 | 222 | def _set_middle_frame(studies: Dict[str, Study], annotation): 223 | counter = { 224 | 't2_sagittal_not_found': [], 225 | 't2_sagittal_miss_match': [], 226 | 't2_sagittal_middle_frame_miss_match': [] 227 | } 228 | for k in annotation.keys(): 229 | if k[0] in studies: 230 | study = studies[k[0]] 231 | if study.t2_sagittal is None: 232 | counter['t2_sagittal_not_found'].append(study.study_uid) 233 | elif study.t2_sagittal_uid != k[1]: 234 | counter['t2_sagittal_miss_match'].append(study.study_uid) 235 | else: 236 | t2_sagittal = study.t2_sagittal 237 | gt_z_index = t2_sagittal.instance_uids[k[2]] 238 | middle_frame = t2_sagittal.middle_frame 239 | z_index = t2_sagittal.instance_uids[middle_frame.instance_uid] 240 | if abs(gt_z_index - z_index) > 1: 241 | counter['t2_sagittal_middle_frame_miss_match'].append(study.study_uid) 242 | study.set_t2_sagittal_middle_frame(k[1], k[2]) 243 | return counter 244 | 245 | def _testB_set_middle_frame(studies: Dict[str, Study], annotation): 246 | for anno_dict in annotation: 247 | if anno_dict['studyUid'] in studies: 248 | study = studies[anno_dict['studyUid']] 249 | study.t2_sagittal_uid = anno_dict['seriesUid'] 250 | return studies 251 | 252 | def construct_studies(config, data_dir, annotation_path=None, multi_processing=False): 253 | studies = _construct_studies(config, data_dir, multi_processing) 254 | 255 | if 'testB' in data_dir or 'round2test' in data_dir: 256 | annotation = json.load(open(annotation_path, 'r')) 257 | studies = _testB_set_middle_frame(studies, annotation) 258 | return studies 259 | elif annotation_path == '' or annotation_path is None: 260 | return studies 261 | else: 262 | annotation = read_annotation(config, annotation_path) 263 | counter = _set_middle_frame(studies, annotation) 264 | return studies, annotation, counter 265 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | from typing import Tuple 7 | from PIL import Image 8 | import torchvision.transforms.functional as tf 9 | import numpy as np 10 | import math 11 | 12 | 13 | def random_rotate(image: Image.Image, pixel_coord: torch.Tensor, max_angel: int) \ 14 | -> (Image.Image, torch.Tensor): 15 | angel = np.random.randint(-1 * max_angel, max_angel) 16 | center = torch.tensor(image.size, dtype=torch.float32) / 2 17 | image = tf.rotate(image, angel, fill=(0,) * (3 if len(np.shape(image)) == 3 else 1)) 18 | if angel != 0: 19 | angel = angel * math.pi / 180 20 | while len(center.shape) < len(pixel_coord.shape): 21 | center = center.unsqueeze(0) 22 | cos = math.cos(angel) 23 | sin = math.sin(angel) 24 | rotate_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32, device=pixel_coord.device) 25 | output = pixel_coord - center 26 | output = torch.matmul(output, rotate_mat) 27 | pixel_coord = output + center 28 | return image, pixel_coord 29 | 30 | 31 | def random_shift(image: Image.Image, pixel_coord: 
torch.Tensor, max_shift_ratios: tuple) \ 32 | -> (Image.Image, torch.Tensor): 33 | # max_shift_ratios: (height_ratio, width_ratio) 34 | height, width = image.size[1], image.size[0] 35 | center = (height // 2, width // 2) 36 | top_shift = int(height * np.random.uniform(-1 * max_shift_ratios[0], max_shift_ratios[0])) 37 | lef_shift = int(width * np.random.uniform(-1 * max_shift_ratios[1], max_shift_ratios[1])) 38 | image = tf.pad(image, (width // 2, height // 2, width // 2, height // 2)) 39 | center = (center[0] + height // 2 + top_shift, center[1] + width // 2 + lef_shift) 40 | image = tf.crop(image, center[0] - height // 2, center[1] - width // 2, height, width) 41 | pixel_coord[:, 0] -= lef_shift 42 | pixel_coord[:, 1] -= top_shift 43 | return image, pixel_coord 44 | 45 | 46 | def random_crop(image: Image.Image, pixel_coord: torch.Tensor, max_crop_ratios: tuple) \ 47 | -> (Image.Image, torch.Tensor): 48 | # max_crop_ratios: (height_ratio, width_ratio) 49 | height, width = image.size[1], image.size[0] 50 | center = (height // 2, width // 2) 51 | new_height = int(height * (1. + np.random.uniform(-1 * max_crop_ratios[0], max_crop_ratios[0]))) 52 | new_width = int(width * (1. + np.random.uniform(-1 * max_crop_ratios[1], max_crop_ratios[1]))) 53 | image = tf.pad(image, (width // 2, height // 2, width // 2, height // 2)) 54 | center = (center[0] + height // 2, center[1] + width // 2) 55 | image = tf.crop(image, center[0] - new_height // 2, center[1] - new_width // 2, new_height, new_width) 56 | pixel_coord[:, 0] += (new_width - width) // 2 57 | pixel_coord[:, 1] += (new_height - height) // 2 58 | return image, pixel_coord 59 | 60 | 61 | def random_intensity(image: Image.Image, max_intensity_ratio: float) \ 62 | -> Image.Image: 63 | np_image = np.asarray(image) 64 | np_image = np_image * (1. 
+ np.random.uniform(-1 * max_intensity_ratio, max_intensity_ratio)) 65 | np_image = np_image.clip(0, 255) 66 | image = Image.fromarray(np.uint8(np_image)) 67 | return image 68 | 69 | 70 | def random_hflip(image: Image.Image, pixel_coord: torch.Tensor) -> (Image.Image, torch.Tensor): 71 | height, width = image.size[1], image.size[0] 72 | image = tf.hflip(image) 73 | pixel_coord[:, 0] = width - pixel_coord[:, 0] + 1 74 | return image, pixel_coord 75 | 76 | 77 | def resize(image: Image.Image, pixel_coord: torch.Tensor, pixel_spacing: torch.Tensor, size: Tuple[int, int]) \ 78 | -> (Image.Image, torch.Tensor, torch.Tensor): 79 | height_ratio = size[0] / image.size[1] 80 | width_ratio = size[1] / image.size[0] 81 | ratio = torch.tensor([width_ratio, height_ratio]) 82 | image = tf.resize(image, size) 83 | pixel_coord = pixel_coord * ratio 84 | pixel_spacing = pixel_spacing / ratio 85 | return image, pixel_coord, pixel_spacing 86 | 87 | 88 | def random_bias_field(image: Image.Image, order: int = 3, coefficients_range: list = [-0.5, 0.5]) \ 89 | -> Image.Image: 90 | image = np.asarray(image) 91 | shape = np.asarray(image.shape[:2]) 92 | half_shape = shape / 2 93 | 94 | ranges = [np.arange(-n, n) for n in half_shape] 95 | 96 | bias_field = np.zeros(shape) 97 | y_mesh, x_mesh = np.asarray(np.meshgrid(*ranges)) 98 | 99 | y_mesh /= y_mesh.max() 100 | x_mesh /= x_mesh.max() 101 | 102 | random_coefficients = [] 103 | for y_order in range(0, order + 1): 104 | for x_order in range(0, order + 1 - y_order): 105 | number = np.random.uniform(*coefficients_range) 106 | random_coefficients.append(number) 107 | 108 | i = 0 109 | for y_order in range(order + 1): 110 | for x_order in range(order + 1 - y_order): 111 | random_coefficient = random_coefficients[i] 112 | new_map = (random_coefficient 113 | * y_mesh ** y_order 114 | * x_mesh ** x_order) 115 | bias_field += np.transpose(new_map, (1, 0)) 116 | i += 1 117 | bias_field = np.exp(bias_field).astype(np.float32) 118 | if len(np.shape(image)) == 3: 119 | image = image * bias_field[..., np.newaxis] 120 | else: 121 | image = image * bias_field 122 | image = Image.fromarray(np.uint8(image)) 123 | return image 124 | 125 | def random_noise(image: Image.Image, mean: list = [0, 5], std: list = [0, 2]) \ 126 | -> Image.Image: 127 | image = np.asarray(image) 128 | noise = np.random.randn(*image.shape) * np.random.uniform(*std) + np.random.uniform(*mean) 129 | image = Image.fromarray(np.uint8(image + noise)) 130 | return image -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import json 7 | import torch 8 | from typing import Dict, Tuple 9 | import pandas as pd 10 | 11 | pd.set_option('display.max_columns', None) 12 | pd.set_option('display.max_rows', None) 13 | pd.set_option('display.width', 5000) 14 | import os 15 | import logging 16 | import time 17 | from PIL import Image 18 | import numpy as np 19 | import matplotlib.pyplot as plt 20 | from sklearn.metrics import f1_score 21 | from copy import deepcopy 22 | 23 | def build_logging(filename): 24 | logging.basicConfig(level=logging.INFO, 25 | format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', 26 | datefmt='%m-%d %H:%M', 27 | filename=filename, 28 | filemode='w') 29 | console = logging.StreamHandler() 30 | console.setLevel(logging.INFO) 31 | 
formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') 32 | console.setFormatter(formatter) 33 | logger = logging.getLogger('') 34 | logger.addHandler(console) 35 | return logger 36 | 37 | def str2tensor(s: str) -> torch.Tensor: 38 | return torch.Tensor(list(map(float, s.split('\\')))) 39 | 40 | def unit_vector(tensor: torch.Tensor, dim=-1): 41 | norm = (tensor ** 2).sum(dim=dim, keepdim=True).sqrt() 42 | return tensor / norm 43 | 44 | def unit_normal_vector(orientation: torch.Tensor): 45 | temp1 = orientation[:, [1, 2, 0]] 46 | temp2 = orientation[:, [2, 0, 1]] 47 | output = temp1 * temp2[[1, 0]] 48 | output = output[0] - output[1] 49 | return unit_vector(output, dim=-1) 50 | 51 | def lazy_property(func): 52 | attr_name = "_lazy_" + func.__name__ 53 | 54 | @property 55 | def _lazy_property(self): 56 | if not hasattr(self, attr_name): 57 | setattr(self, attr_name, func(self)) 58 | return getattr(self, attr_name) 59 | 60 | return _lazy_property 61 | 62 | def read_annotation(config, path) -> Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]]: 63 | with open(path, 'r') as annotation_file: 64 | non_hit_count = {} 65 | annotation = {} 66 | for x in json.load(annotation_file): 67 | study_uid = x['studyUid'] 68 | 69 | # assert len(x['data']) == 1, (study_uid, len(x['data'])) 70 | data = x['data'][0] 71 | instance_uid = data['instanceUid'] 72 | series_uid = data['seriesUid'] 73 | 74 | # assert len(data['annotation']) == 1, (study_uid, len(data['annotation'])) 75 | points = data['annotation'][0]['data']['point'] 76 | 77 | vertebra_label = torch.full([len(config.SPINAL_VERTEBRA_ID), 3], 78 | config.PADDING_VALUE, dtype=torch.long) 79 | disc_label = torch.full([len(config.SPINAL_DISC_ID), 3], 80 | config.PADDING_VALUE, dtype=torch.long) 81 | disc_v5_label = torch.full([len(config.SPINAL_DISC_ID), 3], 82 | 0, dtype=torch.long) 83 | for point in points: 84 | identification = point['tag']['identification'] 85 | if identification in config.SPINAL_VERTEBRA_ID: 86 | position = config.SPINAL_VERTEBRA_ID[identification] 87 | diseases = point['tag']['vertebra'] 88 | 89 | vertebra_label[position, :2] = torch.tensor(point['coord']) 90 | for disease in diseases.split(','): 91 | if disease in config.SPINAL_VERTEBRA_DISEASE_ID: 92 | disease = config.SPINAL_VERTEBRA_DISEASE_ID[disease] 93 | vertebra_label[position, 2] = disease 94 | elif identification in config.SPINAL_DISC_ID: 95 | position = config.SPINAL_DISC_ID[identification] 96 | diseases = point['tag']['disc'] 97 | 98 | disc_label[position, :2] = torch.tensor(point['coord']) 99 | disc_v5_label[position, :2] = torch.tensor(point['coord']) 100 | for disease in diseases.split(','): 101 | if disease != 'v5' and disease in config.SPINAL_DISC_DISEASE_ID: 102 | disease = config.SPINAL_DISC_DISEASE_ID[disease] 103 | disc_label[position, 2] = disease 104 | elif disease == 'v5' and disease in config.SPINAL_DISC_DISEASE_ID: 105 | disease = config.SPINAL_DISC_DISEASE_ID[disease] 106 | disc_v5_label[position, 2] = disease 107 | elif identification in non_hit_count: 108 | non_hit_count[identification] += 1 109 | else: 110 | non_hit_count[identification] = 1 111 | 112 | annotation[study_uid, series_uid, instance_uid] = vertebra_label, disc_label, disc_v5_label 113 | if len(non_hit_count) > 0: 114 | print(non_hit_count) 115 | return annotation 116 | 117 | def cal_metrics(confusion_matrix: pd.DataFrame): 118 | key_point_recall = confusion_matrix.iloc[:-2].sum().sum() / confusion_matrix.sum().sum() 119 | precision = {col: 
def cal_metrics(confusion_matrix: pd.DataFrame):
    key_point_recall = confusion_matrix.iloc[:-2].sum().sum() / confusion_matrix.sum().sum()
    precision = {col: confusion_matrix.loc[col, col] / confusion_matrix.loc[col].sum() for col in confusion_matrix}
    recall = {col: confusion_matrix.loc[col, col] / confusion_matrix[col].sum() for col in confusion_matrix}
    f1 = {col: 2 * precision[col] * recall[col] / (precision[col] + recall[col]) for col in confusion_matrix}
    macro_f1 = sum(f1.values()) / len(f1)

    # restrict recall to the disease rows, i.e. ignore 'wrong' and 'not_hit'
    columns = confusion_matrix.columns
    recall_true_point = {col: confusion_matrix.loc[col, col] / confusion_matrix.loc[columns, col].sum()
                         for col in confusion_matrix}
    f1_true_point = {col: 2 * precision[col] * recall_true_point[col] / (precision[col] + recall_true_point[col])
                     for col in confusion_matrix}
    macro_f1_true_point = sum(f1_true_point.values()) / len(f1_true_point)
    output = [('macro f1', macro_f1), ('key point recall', key_point_recall),
              ('macro f1 (true point)', macro_f1_true_point)]
    output += sorted([(k + ' f1 (true point)', v, precision[k], recall[k]) for k, v in f1_true_point.items()],
                     key=lambda x: x[0])
    return output

def gen_annotation(config, study, prediction, cls_prediction, d_v5_cls_prediction):
    z_index = study.t2_sagittal.instance_uids[study.t2_sagittal_middle_frame.instance_uid]
    point = []
    for i, (coord, identification, cls, is_disc) in enumerate(zip(
            prediction, config.identifications, cls_prediction, config.is_disc)):
        point.append({
            'coord': coord.cpu().int().numpy().tolist(),
            'tag': {
                'identification': identification,
                'disc' if is_disc else 'vertebra': 'v' + str(cls.item() + 1)  # 0-based class index -> v1..vN
                # 'disc' if is_disc else 'vertebra': _parse_number(cls.item())
            },
            'zIndex': z_index
        })
        # discs follow the 5 vertebrae in config.identifications, so i - 5
        # indexes the disc-only v5 predictions
        if is_disc and d_v5_cls_prediction[i - 5].item() > 0:
            point[-1]['tag']['disc'] = point[-1]['tag']['disc'] + ',v5'
    annotation = {
        'studyUid': study.study_uid,
        'data': [
            {
                'instanceUid': study.t2_sagittal_middle_frame.instance_uid,
                'seriesUid': study.t2_sagittal_middle_frame.series_uid,
                'annotation': [
                    {
                        'data': {
                            'point': point,
                        }
                    }
                ]
            }
        ]
    }
    return annotation


def format_annotation(annotations):
    # flatten the nested challenge JSON into {study_uid: {seriesUid, instanceUid, annotation}}
    output = {}
    for annotation in annotations:
        study_uid = annotation['studyUid']
        series_uid = annotation['data'][0]['seriesUid']
        instance_uid = annotation['data'][0]['instanceUid']
        temp = {}
        for point in annotation['data'][0]['annotation'][0]['data']['point']:
            identification = point['tag']['identification']
            coord = point['coord']
            if 'disc' in point['tag']:
                disease = point['tag']['disc']
            else:
                disease = point['tag']['vertebra']
            if disease == '':
                disease = 'v1'  # an empty disease string defaults to v1
            temp[identification] = {
                'coord': coord,
                'disease': disease,
            }
        output[study_uid] = {
            'seriesUid': series_uid,
            'instanceUid': instance_uid,
            'annotation': temp
        }
    return output

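# compute_my_metric reproduces the challenge score: keypoint accuracy (predictions
# within config.max_dist of the ground truth, measured in physical units via the
# pixel spacing) multiplied by the mean of three F1 scores: binary vertebra
# (v1 vs. rest), 4-class disc (v1-v4, macro), and binary disc v5.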
def compute_my_metric(config, studies, predictions, annotations):
    total_points, true_points = 0, 0
    disc_4cls_dict = {'v1': 0, 'v2': 1, 'v3': 2, 'v4': 3}
    disc_4cls_pred, disc_4cls_true = [], []  # v1 v2 v3 v4
    disc_2cls_pred, disc_2cls_true = [], []  # v5
    vertebra_2cls_pred, vertebra_2cls_true = [], []

    for study_uid, annotation in annotations.items():
        study = studies[study_uid]
        pixel_spacing = study.t2_sagittal_middle_frame.pixel_spacing
        pred_points = predictions[study_uid]['annotation']
        for identification, gt_point in annotation['annotation'].items():
            gt_coord = gt_point['coord']
            gt_disease = gt_point['disease'].split(',')

            if identification not in pred_points:
                continue

            pred_coord = pred_points[identification]['coord']
            pred_disease = pred_points[identification]['disease'].split(',')

            total_points += 1

            # classification only counts for keypoints localized within max_dist
            if distance(gt_coord, pred_coord, pixel_spacing) <= config.max_dist:
                true_points += 1
                if '-' in identification:  # disc
                    tmp_pred_disease = deepcopy(pred_disease)
                    if 'v5' in pred_disease:
                        tmp_pred_disease.remove('v5')
                    tmp_gt_disease = deepcopy(gt_disease)
                    if 'v5' in gt_disease:
                        tmp_gt_disease.remove('v5')
                    if len(tmp_gt_disease) == 1:
                        disc_4cls_pred.append(disc_4cls_dict[tmp_pred_disease[0]])
                        disc_4cls_true.append(disc_4cls_dict[tmp_gt_disease[0]])
                    if 'v5' in pred_disease:
                        disc_2cls_pred.append(1)
                    else:
                        disc_2cls_pred.append(0)
                    if 'v5' in gt_disease:
                        disc_2cls_true.append(1)
                    else:
                        disc_2cls_true.append(0)
                else:  # vertebra
                    if 'v1' in pred_disease:
                        vertebra_2cls_pred.append(0)
                    else:
                        vertebra_2cls_pred.append(1)
                    if 'v1' in gt_disease:
                        vertebra_2cls_true.append(0)
                    else:
                        vertebra_2cls_true.append(1)

    metric_dict = {
        'keypoint_accuracy': 1.0 * true_points / total_points,
        'vertebra_f1_score': f1_score(vertebra_2cls_true, vertebra_2cls_pred),
        'disc_cls4_f1_score': f1_score(disc_4cls_true, disc_4cls_pred, average='macro'),
        'disc_cls2_f1_score': f1_score(disc_2cls_true, disc_2cls_pred),
    }

    metric_dict['class_avg_score'] = (metric_dict['vertebra_f1_score'] + metric_dict['disc_cls4_f1_score']
                                      + metric_dict['disc_cls2_f1_score']) / 3.0
    metric_dict['score'] = metric_dict['keypoint_accuracy'] * metric_dict['class_avg_score']
    return metric_dict

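# confusion_matrix_with_image fills the same confusion matrix as confusion_matrix
# below while also saving each study's T2 sagittal middle frame with its keypoints
# to visualize/<study_uid>.jpg: yellow/white marks a ground-truth/predicted hit,
# blue/red a mislocalized pair.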
def confusion_matrix_with_image(config, studies, predictions, annotations) -> pd.DataFrame:
    columns = ['disc_' + k for k in config.SPINAL_DISC_DISEASE_ID]
    columns += ['vertebra_' + k for k in config.SPINAL_VERTEBRA_DISEASE_ID]
    output = pd.DataFrame(config.epsilon, columns=columns, index=columns + ['wrong', 'not_hit'])

    os.makedirs('visualize', exist_ok=True)  # savefig below fails if the directory is missing
    predictions = format_annotation(predictions)
    for study_uid, annotation in annotations.items():
        study = studies[study_uid]
        pixel_spacing = study.t2_sagittal_middle_frame.pixel_spacing
        pred_points = predictions[study_uid]['annotation']
        image = study.t2_sagittal_middle_frame.image
        plt.imshow(image, cmap='gray')
        for identification, gt_point in annotation['annotation'].items():
            gt_coord = gt_point['coord']
            gt_disease = gt_point['disease']

            if '-' in identification:
                _type = 'disc_'
            else:
                _type = 'vertebra_'

            if identification not in pred_points:
                for d in gt_disease.split(','):
                    output.loc['not_hit', _type + d] += 1
                continue

            pred_coord = pred_points[identification]['coord']
            pred_disease = pred_points[identification]['disease']

            # plt.text(gt_coord[0]-40, gt_coord[1], gt_disease, c='b', fontsize=10)
            # plt.text(pred_coord[0]+10, pred_coord[1], pred_disease, c='r', fontsize=10)

            if distance(gt_coord, pred_coord, pixel_spacing) >= config.max_dist:
                plt.scatter(*gt_coord, c='b')
                plt.scatter(*pred_coord, c='r')
                for d in gt_disease.split(','):
                    output.loc['wrong', _type + d] += 1
            else:
                plt.scatter(*gt_coord, c='y')
                plt.scatter(*pred_coord, c='w')
                for d in gt_disease.split(','):
                    for dp in pred_disease.split(','):
                        output.loc[_type + dp, _type + d] += 1
        plt.savefig('visualize/{}.jpg'.format(study.study_uid))
        plt.close()

    print(output)
    return output


def confusion_matrix(config, studies, predictions, annotations) -> pd.DataFrame:
    # rows: predicted disease plus 'wrong' (mislocalized) and 'not_hit' (missed);
    # columns: ground-truth disease; cells start at config.epsilon to avoid division by zero
    columns = ['disc_' + k for k in config.SPINAL_DISC_DISEASE_ID]
    columns += ['vertebra_' + k for k in config.SPINAL_VERTEBRA_DISEASE_ID]
    output = pd.DataFrame(config.epsilon, columns=columns, index=columns + ['wrong', 'not_hit'])

    predictions = format_annotation(predictions)
    for study_uid, annotation in annotations.items():
        study = studies[study_uid]
        pixel_spacing = study.t2_sagittal_middle_frame.pixel_spacing
        pred_points = predictions[study_uid]['annotation']
        for identification, gt_point in annotation['annotation'].items():
            gt_coord = gt_point['coord']
            gt_disease = gt_point['disease']

            if '-' in identification:
                _type = 'disc_'
            else:
                _type = 'vertebra_'

            if identification not in pred_points:
                for d in gt_disease.split(','):
                    output.loc['not_hit', _type + d] += 1
                continue

            pred_coord = pred_points[identification]['coord']
            pred_disease = pred_points[identification]['disease']
            if distance(gt_coord, pred_coord, pixel_spacing) >= config.max_dist:
                for d in gt_disease.split(','):
                    output.loc['wrong', _type + d] += 1
            else:
                for d in gt_disease.split(','):
                    for dp in pred_disease.split(','):
                        output.loc[_type + dp, _type + d] += 1
                # for dp in pred_disease.split(','):
                #     if dp in gt_disease:
                #         output.loc[_type + dp, _type + dp] += 1
                #     else:
                #         output.loc[_type + dp, _type + gt_disease.split(',')[0]] += 1

    print(output)
    return output

def distance(coord0, coord1, pixel_spacing):
    # Euclidean distance in physical units: pixel offsets scaled by per-axis spacing
    x = (coord0[0] - coord1[0]) * pixel_spacing[0]
    y = (coord0[1] - coord1[1]) * pixel_spacing[1]
    output = math.sqrt(x ** 2 + y ** 2)
    return output
--------------------------------------------------------------------------------
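A minimal usage sketch of the helpers above (illustrative only: the file name, the keypoint, the spacing values, and `SIZE` are assumptions; the real input size lives in configs/config_ours.py):
```
import torch
from PIL import Image

from utils.transforms import random_hflip, random_noise, resize
from utils.utils import distance

SIZE = (256, 256)  # assumed (height, width) network input

image = Image.open('t2_sagittal.png').convert('L')  # hypothetical mid-sagittal T2 slice
pixel_coord = torch.tensor([[120.0, 200.0]])        # one (x, y) keypoint
pixel_spacing = torch.tensor([0.6875, 0.6875])      # assumed spacing per pixel along (x, y)

# resize rescales the keypoint and the physical spacing together with the image
image, pixel_coord, pixel_spacing = resize(image, pixel_coord, pixel_spacing, SIZE)
# random_hflip mutates the coordinates in place, so clone when the originals are still needed
image, flipped_coord = random_hflip(image, pixel_coord.clone())
image = random_noise(image)  # additive Gaussian noise

# distance converts a pixel offset into physical units via the spacing
print(distance(pixel_coord[0], flipped_coord[0], pixel_spacing))
```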