├── LICENSE ├── README.md ├── code ├── __pycache__ │ └── utils.cpython-37.pyc ├── data_utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── dataloader.cpython-37.pyc │ │ └── transforms.cpython-37.pyc │ ├── dataloader.py │ └── transforms.py ├── main_baseline.py ├── main_yolol.py ├── models │ ├── UNet.py │ ├── VNet.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── UNet.cpython-37.pyc │ │ ├── VNet.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── losses.cpython-37.pyc │ └── losses.py └── utils.py └── images ├── cover.png ├── problem1.png ├── problem2.png └── table.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tao He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Anchor Ball Regression Model for Large-Scale 3D Skull Landmark Detection 2 | 3 | ![alt text](images/cover.png "Title") 4 | 5 | ## 1. Introduction 6 | ### 1.1 What for? 7 | 8 | In this work, we identify two limitations that hinder research on 3D skull landmark detection: 9 | 1. There is no standard benchmark dataset for evaluating automatic landmark detection models. We reviewed advanced models published from 2018 to 2023, listed in the Table, and found that they were trained on private datasets that differ in data size, data type, evaluation metrics, and number of landmarks. 10 | 11 | 12 | 2. Most studies collected data only at the preoperative or the postoperative stage. In a real-world clinical environment, however, a model must be robust enough to handle diverse data, so landmarking has to be evaluated on both preoperative and postoperative scans. Unfortunately, most models only identify a fixed number of landmarks on standard CT or CBCT volumes. 13 | 14 | 15 | **The references can be found in our paper (coming soon).** 16 | 17 | 18 | 19 | The Mandibular Molar Landmarking (MML) project aims to locate anatomical landmarks on the crowns and roots of the second and third mandibular molars. The task has two main challenges: 20 | 21 | * The number of roots differs between mandibular molars because molar growth varies. 22 | 23 | 24 | 25 | * Mandibular molars can be damaged by dental disease, trauma, or surgery. 26 | 27 | 28 | 29 | 30 | ### 1.2 Highlights 31 | * We created a large-scale benchmark dataset consisting of 648 CT volumes for evaluating 3D skull landmark detection.
This dataset is publicly available and is, to the best of our knowledge, the largest public dataset for this task. 32 | * MML requires models that are robust in clinical environments and can detect an arbitrary set of landmarks on preoperative or postoperative CT volumes, as real clinical practice demands. 33 | * We compared baseline deep learning methods in three aspects: landmark regression models, training losses, and neural network architectures. An Anchor Ball Regression (ABR) model inspired by YOLOv3 outperformed the other baselines. The ABR model is trained with a combination of landmark regression and classification losses and performs better than the usual heatmap and offset regression methods. 34 | 35 | 36 | ## 2. Preparation 37 | ### 2.1 Requirements 38 | - Python >= 3.7 39 | - PyTorch >= 1.10.0 40 | - CUDA 10 or higher 41 | - numpy 42 | - pandas 43 | - scipy 44 | - pynrrd (imported as `nrrd`) 45 | - torchvision 46 | 47 | ### 2.2 Data Preparation 48 | 49 | The dataset is available at https://drive.google.com/file/d/1NGsBbqXZLDlkiSJtDQdyMlXzgnkFoVON/view?usp=sharing 50 | * Data division 51 | ``` 52 | - mmld_dataset/train # 458 samples for training 53 | - mmld_dataset/val # 100 samples for validation 54 | - mmld_dataset/test # 100 samples for testing 55 | ``` 56 | * Data format 57 | ``` 58 | - *_volume.nrrd # 3D CT volumes 59 | - *_label.npy # landmark coordinates 60 | - *_spacing.npy # voxel spacings of each CT volume, used for calculating the MRE 61 | ``` 62 | 63 | ## 3.
Train and Test 64 | ### 3.1 Network Training 65 | 66 | * Training with different network backbones 67 | ``` 68 | python main_yolol.py --model_name PVNet # network training using backbone PVNet 69 | python main_yolol.py --model_name PUNet3D # network training using backbone PUNet3D 70 | python main_yolol.py --model_name PResidualUNet3D # network training using backbone PResidualUNet3D 71 | ``` 72 | 73 | * Training with different GPUs 74 | ``` 75 | python main_yolol.py --gpu 0 # training with 1 GPU 76 | python main_yolol.py --gpu 0,1,2,3 # training with 4 GPUs 77 | ``` 78 | 79 | ### 3.2 Fine-tuning from a pretrained checkpoint 80 | ``` 81 | python main_yolol.py --resume ../SavePath/yolol/model.ckpt 82 | ``` 83 | 84 | ### 3.3 Metric computation 85 | ``` 86 | python main_yolol.py --test_flag 0 --resume ../SavePath/yolol/model.ckpt # calculate MRE and SDR on the validation set 87 | python main_yolol.py --test_flag 1 --resume ../SavePath/yolol/model.ckpt # calculate MRE and SDR on the test set 88 | ``` 89 | 90 | ### 3.4 Training the baseline heatmap regression model 91 | ``` 92 | python main_baseline.py # network training for the baseline heatmap regression model 93 | ``` 94 | 95 | ## 4. Leaderboard (updated 2023/06/15) 96 | 97 | ### The ACC, F1, MRE, and SDR on the MINI subset. 98 | 99 | | **Models** | **ACC(%)** | **F1(%)** | **MRE±Std(mm)** | **SDR-2mm(%)** | **SDR-2.5mm(%)** | **SDR-3mm(%)** | **SDR-4mm(%)** | 100 | | :----------- | :---------- | :--------- | :--------------- | :-------------- | :---------------- | :-------------- | :-------------- | 101 | | Our Baseline | 93.04 | 94.98 | 2.26±1.26 | 61.89 | 74.86 | 82.43 | 91.89 | 102 | | placeholder | | | | | | | | 103 | | placeholder | | | | | | | | 104 | | placeholder | | | | | | | | 105 | 106 | 107 | ### The MRE and SDR on the whole dataset.
108 | 109 | 110 | | **Models** | **MRE±Std(mm)** | **SDR-2mm(%)** | **SDR-2.5mm(%)** | **SDR-3mm(%)** | **SDR-4mm(%)** | 111 | | :----------- | :--------------- | :-------------- | :---------------- | :-------------- | :-------------- | 112 | | Our Baseline | 1.70±0.72 | 76.43 | 86.45 | 90.91 | 95.20 | 113 | | placeholder | | | | | | 114 | | placeholder | | | | | | 115 | | placeholder | | | | | | 116 | 117 | 118 | 119 | 120 | ## 5. Contact 121 | 122 | 123 | Institution: Intelligent Medical Center, Sichuan University 124 | 125 | email: tao_he@scu.edu.cn; taohescu@gmail.com 126 | 127 | ## 6. Citation (coming soon) 128 | 129 | -------------------------------------------------------------------------------- /code/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /code/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__init__.py -------------------------------------------------------------------------------- /code/data_utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/data_utils/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /code/data_utils/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /code/data_utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import nrrd 5 | 6 | class Molar3D(Dataset): 7 | def __init__(self, transform=None, phase='train', parent_path=None, data_type="full"): 8 | 9 | self.data_files = [] 10 | self.label_files = [] 11 | self.spacing = [] 12 | 13 | cur_path = os.path.join(parent_path, str(phase)) 14 | for file_name in os.listdir(cur_path): 15 | if file_name.endswith('_volume.nrrd'): 16 | cur_file_abbr = file_name.split("_volume")[0] 17 | 18 | if data_type == "full": 19 | _label = np.load(os.path.join(cur_path, cur_file_abbr+"_label.npy")) 20 | if np.any(np.sum(_label,1)<0): 21 | continue 22 | if data_type == "mini": 23 | _label = np.load(os.path.join(cur_path, cur_file_abbr+"_label.npy")) 24 | if np.all(np.sum(_label,1)>0): 25 | continue 26 | 27 | self.data_files.append(os.path.join(cur_path, cur_file_abbr+"_volume.nrrd")) 28 | self.label_files.append(os.path.join(cur_path, cur_file_abbr+"_label.npy")) 29 | self.spacing.append(os.path.join(cur_path, cur_file_abbr+"_spacing.npy")) 30 | 31 | self.transform = transform 32 | print('the data length is %d, for %s' % (len(self.data_files), phase)) 33 | 34 | def __len__(self): 35 | L = len(self.data_files) 36 | return L 37 | 38 | def __getitem__(self, index): 39 | _img, 
_ = nrrd.read(self.data_files[index]) 40 | _landmark = np.load(self.label_files[index]) 41 | _spacing = np.load(self.spacing[index]) 42 | sample = {'image': _img, 'landmarks': _landmark, 'spacing':_spacing} 43 | if self.transform is not None: 44 | sample = self.transform(sample) 45 | return sample 46 | 47 | def __str__(self): 48 | return '%s(%d samples)' % (self.__class__.__name__, len(self.data_files)) 49 | 50 | -------------------------------------------------------------------------------- /code/data_utils/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.ndimage import zoom 4 | 5 | 6 | def zoomout_imgandlandmark(img, landmarks, rate): 7 | new_img = zoom(img, rate, order=1) 8 | new_landmarks = [] 9 | for position in landmarks: 10 | position_c = position[0] * rate[0] 11 | position_h = position[1] * rate[1] 12 | position_w = position[2] * rate[2] 13 | new_landmarks.append(np.array([position_c, position_h, position_w])) 14 | return new_img, np.array(new_landmarks) 15 | 16 | 17 | class RandomCrop(object): 18 | def __init__(self, min_rate=0.6, size=[128,128,64]): 19 | self.size = np.array(size) 20 | self.min_rate = min_rate 21 | 22 | def __call__(self, sample): 23 | img = sample['image'] 24 | landmarks = sample['landmarks'] 25 | min_ = np.ones((3,)) * 1000 26 | max_ = np.zeros((3,)) 27 | for landmark in landmarks: 28 | for i in range(3): 29 | # a very small value (mean < -100) marks a non-existent landmark 30 | if np.mean(landmark)< -100: 31 | continue 32 | if min_[i] > landmark[i]: 33 | min_[i] = landmark[i] 34 | if max_[i] < landmark[i]: 35 | max_[i] = landmark[i] 36 | 37 | # set the maximum zoom rate from the extent (min and max) of the existing landmarks 38 | zoom_max = [self.size[0]/(max_[0]-min_[0])-0.02, self.size[1]/(max_[1]-min_[1])-0.02, self.size[2]/(max_[2]-min_[2])-0.04] 39 | 40 | ######################### zoom out ############################# 41 | random_rate0 = np.random.uniform(self.min_rate, min(zoom_max[0], 1)) 42 |
random_rate1 = np.random.uniform(self.min_rate, min(zoom_max[1], 1)) 43 | random_rate2 = np.random.uniform(self.min_rate, min(zoom_max[2], 1)) 44 | if zoom_max[0] landmark[i]: 57 | min_[i] = landmark[i] 58 | if max_[i] < landmark[i]: 59 | max_[i] = landmark[i] 60 | ######################### cropping ############################### 61 | begin_=(min_+max_)/2.-self.size/2. 62 | bc = max(0, begin_[0]); ec = min(min_[0], img.shape[0]-self.size[0]) 63 | bh = max(0, begin_[1]); eh = min(min_[1], img.shape[1]-self.size[1]) 64 | bw = max(0, begin_[2]); ew = min(min_[2], img.shape[2]-self.size[2]) 65 | if ec - bc < 1: 66 | ec += 1 67 | if eh - bh < 1: 68 | eh += 1 69 | if ew - bw < 1: 70 | ew += 1 71 | cc = np.random.randint(bc, ec) 72 | ch = np.random.randint(bh, eh) 73 | cw = np.random.randint(bw, ew) 74 | # random crop here 75 | cur_crop_img = img[cc:(cc+self.size[0]), ch:(ch+self.size[1]), cw:(cw+self.size[2])] 76 | 77 | if(cur_crop_img.shape[0]!=self.size[0] or cur_crop_img.shape[1]!=self.size[1] or cur_crop_img.shape[2]!=self.size[2]): 78 | print(cc, ch, cw) 79 | print(img.shape) 80 | print(cur_crop_img.shape) 81 | print('get a error crop img') 82 | pre_new_landmarks = [] 83 | for landmark in landmarks: 84 | cur_landmark = landmark - np.array([cc, ch, cw]) 85 | pre_new_landmarks.append(cur_landmark) 86 | sample['landmarks'] = np.array(pre_new_landmarks) 87 | sample['image'] = cur_crop_img 88 | return sample 89 | 90 | 91 | class LandmarkProposal(object): 92 | def __init__(self, size=[128,128,64], shrink=4., anchors=[0.5, 0.75, 1., 1.25], max_num=400): 93 | self.size = size 94 | self.shrink = shrink 95 | self.anchors = anchors 96 | self.max_num = max_num # setting a fixed anchor number for minibatch 97 | 98 | def __call__(self, sample): 99 | landmarks = sample['landmarks'] 100 | landmarks = landmarks / self.shrink # shrinking the landmark coordinates 101 | proposals = [] 102 | 103 | for idx, anchor in enumerate(self.anchors): 104 | proposal = [] 105 | for ldx, landmark 
in enumerate(landmarks): 106 | if np.mean(landmark) < -100: 107 | cur_ldx = -1 - ldx # negative number indicates nonexist landmarks 108 | proposal.append([0,0,0,0,0,0,cur_ldx]) 109 | continue 110 | else: 111 | cur_ldx = ldx 112 | 113 | # if a landmark exist, calculate the proposals 114 | cl_min = landmark - anchor 115 | cl_max = landmark + anchor 116 | c = max(0, int(cl_min[0])) 117 | max_c = int(np.ceil(cl_max[0])); max_w = int(np.ceil(cl_max[1])); max_h = int(np.ceil(cl_max[2])) 118 | while(c<=max_c and c=self.max_num): 135 | print("too many proposals were found !") 136 | proposal = proposal[:self.max_num] 137 | # if getting less proposals, padding the tensor 138 | if len(proposal) landmark[i]: 160 | min_[i] = landmark[i] 161 | if max_[i] < landmark[i]: 162 | max_[i] = landmark[i] 163 | zoom_max = [self.size[0]/(max_[0]-min_[0])-0.02, self.size[1]/(max_[1]-min_[1])-0.02, self.size[2]/(max_[2]-min_[2])-0.04] 164 | 165 | ######################### zoom out ############################# 166 | random_rate0 = min(zoom_max[0], 1) 167 | random_rate1 = min(zoom_max[1], 1) 168 | random_rate2 = min(zoom_max[2], 1) 169 | img, landmarks = zoomout_imgandlandmark(img, landmarks, [random_rate0,random_rate1,random_rate2]) 170 | 171 | min_ = np.ones((3,)) * 1000 172 | max_ = np.zeros((3,)) 173 | for landmark in landmarks: 174 | for i in range(3): 175 | if np.mean(landmark)< -100: 176 | continue 177 | if min_[i] > landmark[i]: 178 | min_[i] = landmark[i] 179 | if max_[i] < landmark[i]: 180 | max_[i] = landmark[i] 181 | # import pdb; pdb.set_trace() 182 | begin = ((max_ + min_) /2 - self.size/2 ).astype("int32") 183 | begin[0] = max(0, min(begin[0], img.shape[0]-self.size[0]) ) 184 | begin[1] = max(0, min(begin[1], img.shape[1]-self.size[1])) 185 | begin[2] = max(0, min(begin[2], img.shape[2]-self.size[2])) 186 | 187 | if begin[0]+self.size[0] > img.shape[0] or begin[1]+self.size[1] > img.shape[1] or begin[2]+self.size[2] > img.shape[2]: 188 | print("find a very small landmark , 
error !!!!!") 189 | # center crop here 190 | sample["image"] = img[begin[0]:begin[0]+self.size[0], begin[1]:begin[1]+self.size[1], begin[2]:begin[2]+self.size[2]] 191 | landmarks[:, 0] = landmarks[:, 0] - begin[0] 192 | landmarks[:, 1] = landmarks[:, 1] - begin[1] 193 | landmarks[:, 2] = landmarks[:, 2] - begin[2] 194 | sample["landmarks"] = landmarks 195 | return sample 196 | 197 | 198 | class Normalize(object): 199 | def __init__(self): 200 | pass 201 | 202 | def __call__(self, sample): 203 | img = np.array(sample['image']).astype(np.float32) 204 | img /= 255.0 205 | sample['image'] = img 206 | return sample 207 | 208 | 209 | class LandMarkToGaussianHeatMap(object): 210 | def __init__(self, R=20., img_size=(128,128,64), n_class=14, GPU=None): 211 | self.R = R # gaussian heatmap radius 212 | self.GPU = GPU 213 | 214 | # generate index in three views: length, width, height 215 | c_row = np.array([i for i in range(img_size[0])]) 216 | c_matrix = np.stack([c_row] * img_size[1], 1) 217 | c_matrix = np.stack([c_matrix] * img_size[2], 2) 218 | c_matrix = np.stack([c_matrix] * n_class, 0) 219 | 220 | h_row = np.array([i for i in range(img_size[1])]) 221 | h_matrix = np.stack([h_row] * img_size[0], 0) 222 | h_matrix = np.stack([h_matrix] * img_size[2], 2) 223 | h_matrix = np.stack([h_matrix] * n_class, 0) 224 | 225 | w_row = np.array([i for i in range(img_size[2])]) 226 | w_matrix = np.stack([w_row] * img_size[0], 0) 227 | w_matrix = np.stack([w_matrix] * img_size[1], 1) 228 | w_matrix = np.stack([w_matrix] * n_class, 0) 229 | if GPU is not None: 230 | self.c_matrix = torch.tensor(c_matrix).float().to(self.GPU) 231 | self.h_matrix = torch.tensor(h_matrix).float().to(self.GPU) 232 | self.w_matrix = torch.tensor(w_matrix).float().to(self.GPU) 233 | 234 | def __call__(self, landmarks): 235 | n_landmark = landmarks.shape[1] 236 | batch_size = landmarks.shape[0] 237 | 238 | if self.GPU is not None: 239 | # generate the mask inside the mask with radius R 240 | mask = 
torch.sqrt( 241 | torch.pow( 242 | self.c_matrix - 243 | torch.tensor(np.expand_dims(np.expand_dims(landmarks[:, :, 0:1], 3),4)).float().to(self.GPU), 2) + torch.pow( 244 | self.h_matrix - 245 | torch.tensor(np.expand_dims(np.expand_dims(landmarks[:, :, 1:2], 3),4)).float( 246 | ).to(self.GPU), 2) + torch.pow( 247 | self.w_matrix - 248 | torch.tensor(np.expand_dims(np.expand_dims(landmarks[:, :, 2:3], 3),4) 249 | ).float().to(self.GPU), 2)) <= self.R 250 | 251 | # generate the heatmap with Gaussian distribution 252 | # the maximum value is 2, the min value is -1 253 | cur_heatmap = torch.exp(-(( 254 | torch.pow( 255 | self.c_matrix - torch.tensor( 256 | np.expand_dims(np.expand_dims(landmarks[:, :, 0:1], 3),4)).float().to( 257 | self.GPU), 2) + torch.pow( 258 | self.h_matrix - torch.tensor( 259 | np.expand_dims(np.expand_dims(landmarks[:, :, 1:2], 3),4)).float().to( 260 | self.GPU), 2) + torch.pow( 261 | self.w_matrix - torch.tensor( 262 | np.expand_dims(np.expand_dims(landmarks[:, :, 2:3], 263 | 3),4)).float().to(self.GPU), 2)) / 264 | (self.R * self.R) / 0.2)) 265 | heatmap = 2 * cur_heatmap * mask.float() + mask.float() - 1 266 | return heatmap 267 | 268 | 269 | class ToTensor(object): 270 | def __init__(self): 271 | pass 272 | 273 | def __call__(self, sample): 274 | img = np.array(sample['image']).astype(np.float32) 275 | img = np.expand_dims(img, 0) 276 | sample['image'] = img 277 | sample['landmarks'] = sample['landmarks'].astype(np.float32) 278 | return sample 279 | 280 | 281 | -------------------------------------------------------------------------------- /code/main_baseline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import logging 8 | 9 | import torch 10 | from torch.nn import DataParallel 11 | from torch.backends import cudnn 12 | from torch import optim 13 | from torchvision import transforms 14 | from 
torch.utils.data import DataLoader 15 | 16 | from data_utils.dataloader import Molar3D 17 | import data_utils.transforms as tr 18 | from utils import setgpu, metric 19 | from data_utils.transforms import LandMarkToGaussianHeatMap 20 | from models.losses import HNM_heatmap 21 | from models.VNet import VNet 22 | from models.UNet import UNet3D, ResidualUNet3D 23 | 24 | # hyperparameter settings 25 | parser = argparse.ArgumentParser(description='PyTorch landmarking baseline heatmap regression') 26 | # the network backbone settings 27 | parser.add_argument('--model_name',metavar='MODEL',default='VNet',type=str, choices=['VNet', 'UNet3D', 'ResidualUNet3D']) 28 | # the maximum training epochs 29 | parser.add_argument('--epochs',default=200,type=int,metavar='N') 30 | # the starting epoch 31 | parser.add_argument('--start_epoch',default=1,type=int) 32 | # the batch size, default 4 for one GPU 33 | parser.add_argument('-b','--batch_size',default=4,type=int) 34 | # the initial learning rate 35 | parser.add_argument('--lr','--learning_rate',default=0.001,type=float) 36 | # the path for loading pretrained model parameters 37 | parser.add_argument('--resume',default='',type=str) 38 | # the weight decay 39 | parser.add_argument('--weight-decay','--wd',default=0.0005,type=float) 40 | # the path to save the model parameters 41 | parser.add_argument('--save_dir',default='../SavePath/baseline',type=str) 42 | # the GPU settings; for multiple GPUs use '0,1' or '0,1,2,3' 43 | parser.add_argument('--gpu', default='0', type=str) 44 | # the early-stopping patience (in epochs) 45 | parser.add_argument('--patient',default=20,type=int) 46 | # the loss: HNM_heatmap for baseline heatmap regression, HNM_propmap for yolol 47 | parser.add_argument('--loss_name', default='HNM_heatmap',type=str) 48 | # the dataset path 49 | # before training, download the dataset and put it in "../mmld_dataset" 50 | parser.add_argument('--data_path', 51 | default='../mmld_dataset', 52 | type=str, 53 | metavar='N', 54 |
| help='data path') 55 | # the number of landmark classes 56 | parser.add_argument('--n_class',default=14,type=int, help='number of landmarks 14') 57 | # the radius of the Gaussian heatmap's mask 58 | parser.add_argument('-R','--focus_radius', default=20,type=int) 59 | # the test flag: -1 for training, 0 for evaluation on the val set, 1 for testing on the test set 60 | parser.add_argument('--test_flag',default=-1,type=int, choices=[-1, 0, 1]) 61 | 62 | 63 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | def main(args): 65 | cudnn.benchmark = True 66 | setgpu(args.gpu) 67 | ########################### model init ############################################# 68 | net = globals()[args.model_name](n_class=args.n_class) 69 | loss = globals()[args.loss_name](R=args.focus_radius) 70 | 71 | start_epoch = args.start_epoch 72 | save_dir = args.save_dir 73 | logging.info(args) 74 | if args.resume: 75 | checkpoint = torch.load(args.resume, map_location=DEVICE) 76 | start_epoch = checkpoint['epoch'] + 1 77 | net.load_state_dict(checkpoint['state_dict']) 78 | 79 | net = net.to(DEVICE) 80 | loss = loss.to(DEVICE) 81 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all': 82 | net = DataParallel(net) 83 | 84 | # use the Adam optimizer for network training 85 | optimizer = torch.optim.Adam(net.parameters(), 86 | lr=args.lr, 87 | betas=(0.9, 0.98), 88 | weight_decay=args.weight_decay) 89 | # decay the lr by a factor of 0.98 every epoch 90 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98, last_epoch=-1) 91 | 92 | 93 | ########################## network testing ######################################## 94 | # if test_flag > -1, calculate the MRE and SDR (%) on the val or test set 95 | if args.test_flag > -1: 96 | args.batch_size = 1 97 | 98 | if args.test_flag == 0: 99 | test_transform = transforms.Compose([ 100 | tr.Normalize(), 101 | tr.ToTensor(), 102 | ]) 103 | phase = 'val' 104 | else: 105 | test_transform = transforms.Compose([ 106 | tr.CenterCrop(), # center crop for testing 107 | tr.Normalize(), 108 | tr.ToTensor(), 109 | ]) 110 |
phase = 'test' 111 | test_dataset = Molar3D(transform=test_transform, 112 | phase=phase, 113 | parent_path=args.data_path) 114 | testloader = DataLoader(test_dataset, 115 | batch_size=1, 116 | shuffle=False, 117 | num_workers=4) 118 | test(testloader, net) 119 | return 120 | 121 | 122 | # generate Gaussian heatmaps as PyTorch tensors on DEVICE 123 | l2h = LandMarkToGaussianHeatMap(R=args.focus_radius, 124 | n_class=args.n_class, 125 | GPU=DEVICE, 126 | img_size=(128,128,64)) 127 | 128 | ########################## data preparation ######################################## 129 | # if test_flag <= -1, start network training 130 | # preprocessing for the training and validation sets 131 | train_transform = transforms.Compose([ 132 | tr.RandomCrop(), # random zoom and crop for data augmentation 133 | tr.Normalize(), 134 | tr.ToTensor(), 135 | ]) 136 | train_dataset = Molar3D(transform=train_transform, 137 | phase='train', 138 | parent_path=args.data_path) 139 | trainloader = DataLoader(train_dataset, 140 | batch_size=args.batch_size, 141 | shuffle=True, 142 | num_workers=8) 143 | 144 | eval_transform = transforms.Compose([ 145 | tr.CenterCrop(), # center crop for validation 146 | tr.Normalize(), 147 | tr.ToTensor(), 148 | ]) 149 | eval_dataset = Molar3D(transform=eval_transform, 150 | phase='val', 151 | parent_path=args.data_path) 152 | evalloader = DataLoader(eval_dataset, 153 | batch_size=args.batch_size, 154 | shuffle=False, 155 | num_workers=8) 156 | 157 | 158 | ########################## network training ########################################## 159 | # begin training here 160 | break_flag = 0. # epochs since the last improvement, for early stopping 161 | low_loss = 100.
162 | total_loss = [] 163 | 164 | for epoch in range(start_epoch, args.epochs + 1): 165 | # train for one epoch 166 | train(trainloader, net, loss, epoch, optimizer, l2h) 167 | if optimizer.param_groups[0]['lr'] > args.lr * 0.03: 168 | scheduler.step() 169 | 170 | # validate after each epoch 171 | break_flag += 1 172 | eval_loss = evaluation(evalloader, net, loss, epoch, l2h) 173 | total_loss.append(eval_loss) 174 | if low_loss > eval_loss: 175 | low_loss = eval_loss 176 | break_flag = 0 177 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all': 178 | state_dict = net.module.state_dict() 179 | else: 180 | state_dict = net.state_dict() 181 | torch.save( 182 | { 183 | 'epoch': epoch, 184 | 'save_dir': save_dir, 185 | 'state_dict': state_dict, 186 | 'optimizer': optimizer.state_dict(), 187 | 'args': args 188 | }, os.path.join(save_dir, 'model.ckpt')) 189 | logging.info( 190 | '************************ model saved successfully ************************** !\n' 191 | ) 192 | 193 | if break_flag > args.patient: 194 | break 195 | 196 | 197 | def train(data_loader, net, loss, epoch, optimizer, l2h): 198 | start_time = time.time() 199 | net.train() 200 | total_train_loss = [] 201 | for i, sample in enumerate(data_loader): 202 | data = sample['image'] 203 | landmark = sample['landmarks'] 204 | heatmap_batch = l2h(landmark) 205 | data = data.to(DEVICE) 206 | heatmap = net(data) 207 | optimizer.zero_grad() 208 | cur_loss = loss(heatmap, heatmap_batch) 209 | total_train_loss.append(cur_loss.item()) 210 | cur_loss.backward() 211 | optimizer.step() 212 | 213 | logging.info( 214 | 'Train--Epoch[%d], lr[%.6f], total loss: [%.6f], time: %.1f s!'
215 | % (epoch, optimizer.param_groups[0]['lr'], np.mean(total_train_loss), time.time() - start_time)) 216 | 217 | 218 | def evaluation(dataloader, net, loss, epoch, l2h): 219 | start_time = time.time() 220 | net.eval() 221 | total_loss = [] 222 | 223 | with torch.no_grad(): 224 | for i, sample in enumerate(dataloader): 225 | data = sample['image'] 226 | landmark = sample['landmarks'] 227 | heatmap_batch = l2h(landmark) 228 | data = data.to(DEVICE) 229 | heatmap= net(data) 230 | cur_loss = loss(heatmap, heatmap_batch) 231 | total_loss.append(cur_loss.item()) 232 | 233 | logging.info( 234 | 'Eval--Epoch[%d], total loss: [%.6f], time: %.1f s!' 235 | % (epoch, np.mean(total_loss), time.time() - start_time)) 236 | logging.info( 237 | '***************************************************************************' 238 | ) 239 | return np.mean(total_loss) 240 | 241 | 242 | def test(dataloader, net): 243 | start_time = time.time() 244 | net.eval() 245 | total_mre = [] 246 | total_mean_mre = [] 247 | N = 0 248 | total_hits = np.zeros((8, 14)) 249 | with torch.no_grad(): 250 | for i, sample in enumerate(dataloader): 251 | data = sample['image'] 252 | landmarks = sample['landmarks'] 253 | spacing = sample['spacing'] 254 | data = data.to(DEVICE) 255 | heatmap = net(data) 256 | 257 | mre, hits = metric(heatmap.cpu().numpy(), 258 | spacing.numpy(), 259 | landmarks.cpu().numpy()) 260 | total_hits += hits 261 | total_mre.append(np.array(mre)) 262 | N += data.shape[0] 263 | cur_mre = [] 264 | for cdx in range(len(mre[0])): 265 | if mre[0][cdx]>0: 266 | cur_mre.append(mre[0][cdx]) 267 | total_mean_mre.append(np.mean(cur_mre)) 268 | print("#: No.", i, "--the current MRE is [%.4f] "%np.mean(cur_mre)) 269 | total_mre = np.concatenate(total_mre, 0) 270 | 271 | 272 | ################################ molar print############################################## 273 | names = [ 274 | 'L0','La', 'Lb', 'Lc', 'Ld', 'Le', 'Lf', 'R0', 'Ra','Rb','Rc','Rd','Re','Rf' 275 | ] 276 | 277 | IDs = ["MRE", 
"SD", "2.0", "2.5", "3.0", "4."] 278 | form = {"metric": IDs} 279 | mre = [] 280 | sd = [] 281 | cur_hits = total_hits[:4] / total_hits[4:] 282 | 283 | ############################## each class mre ############################################## 284 | for i, name in enumerate(names): 285 | cur_mre = [] 286 | for j in range(total_mre.shape[0]): 287 | if total_mre[j,i] > 0: 288 | cur_mre.append(total_mre[j,i]) 289 | cur_mre = np.array(cur_mre) 290 | mre.append(np.mean(cur_mre)) 291 | sd.append(np.sqrt(np.sum(pow(np.array(cur_mre) - np.mean(cur_mre), 2)) / (N-1))) 292 | 293 | ########################### total mre ###################################################### 294 | mre = np.stack(mre, 0) 295 | sd = np.stack(sd, 0) 296 | total = np.stack([mre, sd], 0) 297 | total = np.concatenate([total, cur_hits], 0) 298 | for i, name in enumerate(names): 299 | form[name] = total[:, i] 300 | df = pd.DataFrame(form, columns = form.keys()) 301 | df.to_excel( 'baseline_test.xlsx', index = False, header=True) 302 | 303 | ########################### total mre ###################################################### 304 | mmre = np.mean(total_mean_mre) 305 | sd = np.sqrt(np.sum(pow(np.array(total_mean_mre) - mmre, 2)) / (N-1)) 306 | 307 | total_hits = np.sum(total_hits, 1) 308 | logging.info( 309 | 'Test-- MRE: [%.2f] + SD: [%.2f], 2.0 mm: [%.4f], 2.5 mm: [%.4f], 3.0 mm: [%.4f], 4.0 mm: [%.4f], using time: %.1f s!' 
%( 310 | mmre, sd, 311 | total_hits[0] / total_hits[4], 312 | total_hits[1] / total_hits[5], 313 | total_hits[2] / total_hits[6], 314 | total_hits[3] / total_hits[7], 315 | time.time()-start_time)) 316 | logging.info( 317 | '***************************************************************************' 318 | ) 319 | 320 | if __name__ == '__main__': 321 | global args 322 | args = parser.parse_args() 323 | if not os.path.exists(args.save_dir): 324 | os.makedirs(args.save_dir) 325 | args.save_dir = os.path.join(args.save_dir, args.model_name) 326 | if not os.path.exists(args.save_dir): 327 | os.makedirs(args.save_dir) 328 | 329 | logging.basicConfig(level=logging.INFO, 330 | format='%(asctime)s,%(lineno)d: %(message)s\n', 331 | datefmt='%Y-%m-%d(%a)%H:%M:%S', 332 | filename=os.path.join(args.save_dir, 'log.txt'), 333 | filemode='a') 334 | console = logging.StreamHandler() 335 | console.setLevel(logging.INFO) 336 | logging.getLogger().addHandler(console) 337 | main(args) 338 | 339 | -------------------------------------------------------------------------------- /code/main_yolol.py: -------------------------------------------------------------------------------- 1 | 2 | # imports 3 | import argparse 4 | import os 5 | import time 6 | import numpy as np 7 | import logging 8 | import pandas as pd 9 | 10 | import torch 11 | from torch.nn import DataParallel 12 | from torch.backends import cudnn 13 | from torch import optim 14 | from torchvision import transforms 15 | from torch.utils.data import DataLoader 16 | 17 | from data_utils.dataloader import Molar3D 18 | import data_utils.transforms as tr 19 | from utils import setgpu, metric_proposal 20 | from models.losses import HNM_propmap 21 | 22 | from models.VNet import PVNet 23 | from models.UNet import PUNet3D, PResidualUNet3D 24 | 25 | 26 | # hyperparameter settings 27 | parser = argparse.ArgumentParser(description='PyTorch Robust Mandibular Molar Landmark Detection') 28 | # the network backbone settings
parser.add_argument('--model_name',metavar='MODEL',default='PVNet',type=str, choices=['PVNet', 'PUNet3D', 'PResidualUNet3D']) 30 | # the maximum training epochs 31 | parser.add_argument('--epochs',default=200,type=int,metavar='N') 32 | # the beginning epoch 33 | parser.add_argument('--start_epoch',default=1,type=int) 34 | # the batch size, default 4 for one GPU 35 | parser.add_argument('-b','--batch_size',default=4,type=int) 36 | # the initial learning rate 37 | parser.add_argument('--lr','--learning_rate',default=0.001,type=float) 38 | # the path for loading pretrained model parameters 39 | parser.add_argument('--resume',default='',type=str) 40 | # the weight decay 41 | parser.add_argument('--weight-decay','--wd',default=0.0005,type=float) 42 | # the path to save the model parameters 43 | parser.add_argument('--save_dir',default='../SavePath/yolol',type=str) 44 | # the settings of gpus, multiGPU can use '0,1' or '0,1,2,3' 45 | parser.add_argument('--gpu', default='0', type=str) 46 | # early-stopping patience: training stops after more than this many epochs without a lower validation loss 47 | parser.add_argument('--patient',default=20,type=int) 48 | # the loss HNM_heatmap for baseline heatmap regression, HNM_propmap for yolol 49 | parser.add_argument('--loss_name', default='HNM_propmap',type=str) 50 | # the path of dataset 51 | # before training please download the dataset and put it in "../mmld_dataset" 52 | parser.add_argument('--data_path', 53 | default='../mmld_dataset', 54 | type=str, 55 | metavar='N', 56 | help='data path') 57 | # the number of landmark classes 58 | parser.add_argument('--n_class',default=14,type=int, help='number of landmarks 14') 59 | # the downsampling factor of the output grid 60 | parser.add_argument('--shrink',default=4,type=int,metavar='shrink') 61 | # the anchor balls default r=[0.5u, 0.75u, 1u, 1.25u] 62 | parser.add_argument('--anchors', 63 | default=[0.5, 0.75, 1., 1.25], 64 | type=float, nargs='+', # e.g. --anchors 0.5 0.75 1.0 1.25 65 | metavar='anchors', 66 | help='the anchor balls to predict') 67 | # the test flag | -1 for train, 0 for eval, 1 for test | 68 |
parser.add_argument('--test_flag',default=-1,type=int, choices=[-1, 0, 1]) 69 | # the data type | full for dataset with complete landmarks | mini for mini dataset with incomplete landmarks | all for the default dataset 70 | parser.add_argument('--data_type', default='all',type=str) 71 | 72 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | 74 | def main(args): 75 | logging.info(args) 76 | cudnn.benchmark = True 77 | setgpu(args.gpu) 78 | 79 | ########################### model init ############################################# 80 | net = globals()[args.model_name](n_class=args.n_class, n_anchor=len(args.anchors)) 81 | loss = globals()[args.loss_name](n_class=args.n_class, device=DEVICE) 82 | 83 | start_epoch = args.start_epoch 84 | save_dir = args.save_dir 85 | logging.info(args) 86 | if args.resume: 87 | checkpoint = torch.load(args.resume) 88 | start_epoch = checkpoint['epoch'] + 1 89 | net.load_state_dict(checkpoint['state_dict']) 90 | 91 | net = net.to(DEVICE) 92 | loss = loss.to(DEVICE) 93 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all': 94 | net = DataParallel(net) 95 | 96 | # using Adam optimizer for network training 97 | optimizer = torch.optim.Adam(net.parameters(), 98 | lr=args.lr, 99 | betas=(0.9, 0.98), 100 | weight_decay=args.weight_decay) 101 | # decay the lr by a factor of 0.98 each epoch 102 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98, last_epoch=-1) 103 | 104 | 105 | ########################## network testing ######################################## 106 | # if the test_flag > -1, calculate the MRE and SDR (%) for val and test set 107 | if args.test_flag > -1: 108 | args.batch_size = 1 109 | if args.test_flag == 0: 110 | test_transform = transforms.Compose([ 111 | tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors), 112 | tr.Normalize(), 113 | tr.ToTensor(), 114 | ]) 115 | phase = 'val' 116 | else: 117 | test_transform = transforms.Compose([ 118 | tr.CenterCrop(), # center crop for testing 119 |
tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors), 120 | tr.Normalize(), 121 | tr.ToTensor(), 122 | ]) 123 | phase = 'test' 124 | test_dataset = Molar3D(transform=test_transform, 125 | phase=phase, 126 | parent_path=args.data_path, 127 | data_type=args.data_type) 128 | 129 | testloader = DataLoader(test_dataset, 130 | batch_size=args.batch_size, 131 | shuffle=False, 132 | num_workers=4) 133 | test(testloader, net, args) 134 | return 135 | 136 | 137 | ########################## data preparation ######################################## 138 | # if test_flag == -1, train the network 139 | # train set and validation set preprocessing 140 | train_transform = transforms.Compose([ 141 | tr.RandomCrop(), # zoom and random crop for data augmentation 142 | tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors), # generate the anchor proposals 143 | tr.Normalize(), 144 | tr.ToTensor(), 145 | ]) 146 | train_dataset = Molar3D(transform=train_transform, 147 | phase='train', 148 | parent_path=args.data_path, 149 | data_type = args.data_type) 150 | trainloader = DataLoader(train_dataset, 151 | batch_size=args.batch_size, 152 | shuffle=True, 153 | num_workers=8) 154 | 155 | eval_transform = transforms.Compose([ 156 | tr.CenterCrop(), # center crop for validation 157 | tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors), 158 | tr.Normalize(), 159 | tr.ToTensor(), 160 | ]) 161 | eval_dataset = Molar3D(transform=eval_transform, 162 | phase='val', 163 | parent_path=args.data_path, 164 | data_type=args.data_type) 165 | evalloader = DataLoader(eval_dataset, 166 | batch_size=args.batch_size, 167 | shuffle=False, 168 | num_workers=8) 169 | 170 | 171 | ########################## network training ########################################## 172 | # begin training here 173 | break_flag = 0. # epochs since the last validation-loss improvement (early stopping) 174 | low_loss = 100.
175 | total_loss = [] 176 | 177 | for epoch in range(start_epoch, args.epochs + 1): 178 | # train for one epoch 179 | train(trainloader, net, loss, epoch, optimizer) 180 | if optimizer.param_groups[0]['lr'] > args.lr * 0.03: # stop decaying once the lr falls to 3% of its initial value 181 | scheduler.step() 182 | 183 | # validate after each epoch 184 | break_flag += 1 185 | eval_loss = evaluation(evalloader, net, loss, epoch) 186 | total_loss.append(eval_loss) 187 | if low_loss > eval_loss: 188 | low_loss = eval_loss 189 | break_flag = 0 190 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all': 191 | state_dict = net.module.state_dict() 192 | else: 193 | state_dict = net.state_dict() 194 | torch.save( 195 | { 196 | 'epoch': epoch, 197 | 'save_dir': save_dir, 198 | 'state_dict': state_dict, 199 | 'optimizer': optimizer.state_dict(), 200 | 'args': args 201 | }, os.path.join(save_dir, 'model.ckpt')) 202 | logging.info( 203 | '************************ model saved successfully ************************** !\n' 204 | ) 205 | 206 | if break_flag > args.patient: 207 | break 208 | 209 | 210 | def train(data_loader, net, loss, epoch, optimizer): 211 | start_time = time.time() 212 | net.train() 213 | total_train_loss = [] 214 | for i, sample in enumerate(data_loader): 215 | data = sample['image'] 216 | proposals = sample['proposals'] 217 | data = data.to(DEVICE) 218 | proposals = proposals.to(DEVICE) 219 | proposal_map = net(data) 220 | optimizer.zero_grad() 221 | cur_loss = loss(proposal_map, proposals) 222 | total_train_loss.append(cur_loss.item()) 223 | cur_loss.backward() 224 | optimizer.step() 225 | 226 | logging.info( 227 | 'Train--Epoch[%d], lr[%.6f], total loss: [%.6f], time: %.1f s!'
228 | % (epoch, optimizer.param_groups[0]['lr'], np.mean(total_train_loss), time.time() - start_time)) 229 | 230 | 231 | def evaluation(dataloader, net, loss, epoch): 232 | start_time = time.time() 233 | net.eval() 234 | total_loss = [] 235 | with torch.no_grad(): 236 | for i, sample in enumerate(dataloader): 237 | data = sample['image'] 238 | proposals = sample['proposals'] 239 | data = data.to(DEVICE) 240 | proposals = proposals.to(DEVICE) 241 | proposal_map = net(data) 242 | cur_loss = loss(proposal_map, proposals) 243 | total_loss.append(cur_loss.item()) 244 | 245 | logging.info( 246 | 'Eval--Epoch[%d], total loss: [%.6f], time: %.1f s!' 247 | % (epoch, np.mean(total_loss), time.time() - start_time)) 248 | logging.info( 249 | '***************************************************************************' 250 | ) 251 | return np.mean(total_loss) 252 | 253 | 254 | def test(dataloader, net, args): 255 | start_time = time.time() 256 | net.eval() 257 | total_mre = [] 258 | total_mean_mre = [] 259 | N = 0 260 | total_hits = np.zeros((8, args.n_class)) 261 | with torch.no_grad(): 262 | for i, sample in enumerate(dataloader): 263 | data = sample['image'] 264 | landmarks = sample['landmarks'] 265 | spacing = sample['spacing'] 266 | data = data.to(DEVICE) 267 | proposal_map = net(data) 268 | mre, hits = metric_proposal(proposal_map, spacing.numpy(), 269 | landmarks.numpy(), shrink=args.shrink, anchors=args.anchors, 270 | n_class=args.n_class) 271 | total_hits += hits 272 | total_mre.append(np.array(mre)) 273 | N += data.shape[0] 274 | cur_mre = [] 275 | for cdx in range(len(mre[0])): 276 | if mre[0][cdx]>0: 277 | cur_mre.append(mre[0][cdx]) 278 | total_mean_mre.append(np.mean(cur_mre)) 279 | print("#: No.", i, "--the current MRE is [%.4f] "%np.mean(cur_mre)) 280 | total_mre = np.concatenate(total_mre, 0) 281 | 282 | 283 | ################################# molar print ############################################## 284 | names = [ 285 | 'L0','La', 'Lb', 'Lc', 'Ld', 'Le', 
'Lf', 'R0', 'Ra','Rb','Rc','Rd','Re','Rf' 286 | ] 287 | IDs = ["MRE", "SD", "2.0", "2.5", "3.0", "4.0"] 288 | form = {"metric": IDs} 289 | mre = [] 290 | sd = [] 291 | cur_hits = total_hits[:4] / total_hits[4:] 292 | 293 | ############################## each class mre ############################################## 294 | for i, name in enumerate(names): 295 | cur_mre = [] 296 | for j in range(total_mre.shape[0]): 297 | if total_mre[j,i] > 0: 298 | cur_mre.append(total_mre[j,i]) 299 | cur_mre = np.array(cur_mre) 300 | mre.append(np.mean(cur_mre)) 301 | sd.append(np.std(cur_mre, ddof=1)) # sample SD over the volumes in which this landmark is present 302 | 303 | mre = np.stack(mre, 0) 304 | sd = np.stack(sd, 0) 305 | total = np.stack([mre, sd], 0) 306 | 307 | total = np.concatenate([total, cur_hits], 0) 308 | for i, name in enumerate(names): 309 | form[name] = total[:, i] 310 | df = pd.DataFrame(form, columns = form.keys()) 311 | # write the per-landmark MRE, SD and SDR table to an xlsx file 312 | df.to_excel( 'yolol_test.xlsx', index = False, header=True) 313 | 314 | ########################### total mre ###################################################### 315 | mmre = np.mean(total_mean_mre) 316 | sd = np.sqrt(np.sum(pow(np.array(total_mean_mre) - mmre, 2)) / (N-1)) 317 | 318 | total_hits = np.sum(total_hits, 1) 319 | logging.info( 320 | 'Test-- MRE: [%.2f] + SD: [%.2f], 2.0 mm: [%.4f], 2.5 mm: [%.4f], 3.0 mm: [%.4f], 4.0 mm: [%.4f], using time: %.1f s!'
%( 321 | mmre, sd, 322 | total_hits[0] / total_hits[4], 323 | total_hits[1] / total_hits[5], 324 | total_hits[2] / total_hits[6], 325 | total_hits[3] / total_hits[7], 326 | time.time()-start_time)) 327 | logging.info( 328 | '***************************************************************************' 329 | ) 330 | 331 | if __name__ == '__main__': 332 | global args 333 | args = parser.parse_args() 334 | if not os.path.exists(args.save_dir): 335 | os.makedirs(args.save_dir) 336 | args.save_dir = os.path.join(args.save_dir, args.model_name) 337 | if not os.path.exists(args.save_dir): 338 | os.makedirs(args.save_dir) 339 | 340 | logging.basicConfig(level=logging.INFO, 341 | format='%(asctime)s,%(lineno)d: %(message)s\n', 342 | datefmt='%Y-%m-%d(%a)%H:%M:%S', 343 | filename=os.path.join(args.save_dir, 'log.txt'), 344 | filemode='a') 345 | console = logging.StreamHandler() 346 | console.setLevel(logging.INFO) 347 | logging.getLogger().addHandler(console) 348 | main(args) 349 | 350 | -------------------------------------------------------------------------------- /code/models/UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | import importlib 5 | 6 | 7 | def create_feature_maps(init_channel_number, number_of_fmaps): 8 | return [init_channel_number * 2 ** k for k in range(number_of_fmaps)] 9 | 10 | 11 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): 12 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 13 | 14 | 15 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): 16 | """ 17 | Create a list of modules which together constitute a single conv layer with non-linearity 18 | and optional batchnorm/groupnorm.
19 | Args: 20 | in_channels (int): number of input channels 21 | out_channels (int): number of output channels 22 | order (string): order of operations, e.g. 23 | 'cr' -> conv + ReLU 24 | 'gcr' -> groupnorm + conv + ReLU 25 | 'cl' -> conv + LeakyReLU 26 | 'ce' -> conv + ELU 27 | 'bcr' -> batchnorm + conv + ReLU 28 | num_groups (int): number of groups for the GroupNorm 29 | padding (int): add zero-padding to the input 30 | Return: 31 | list of tuple (name, module) 32 | """ 33 | assert 'c' in order, "Conv layer MUST be present" 34 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' 35 | 36 | modules = [] 37 | for i, char in enumerate(order): 38 | if char == 'r': 39 | modules.append(('ReLU', nn.ReLU(inplace=True))) 40 | elif char == 'l': 41 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) 42 | elif char == 'e': 43 | modules.append(('ELU', nn.ELU(inplace=True))) 44 | elif char == 'c': 45 | # add learnable bias only in the absence of batchnorm/groupnorm 46 | bias = not ('g' in order or 'b' in order) 47 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) 48 | elif char == 'g': 49 | is_before_conv = i < order.index('c') 50 | if is_before_conv: 51 | num_channels = in_channels 52 | else: 53 | num_channels = out_channels 54 | 55 | # use only one group if the given number of groups is greater than the number of channels 56 | if num_channels < num_groups: 57 | num_groups = 1 58 | 59 | assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups.
num_channels={num_channels}, num_groups={num_groups}' 60 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) 61 | elif char == 'b': 62 | is_before_conv = i < order.index('c') 63 | if is_before_conv: 64 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) 65 | else: 66 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) 67 | else: 68 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") 69 | 70 | return modules 71 | 72 | 73 | class SingleConv(nn.Sequential): 74 | """ 75 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order 76 | of operations can be specified via the `order` parameter 77 | Args: 78 | in_channels (int): number of input channels 79 | out_channels (int): number of output channels 80 | kernel_size (int): size of the convolving kernel 81 | order (string): determines the order of layers, e.g. 82 | 'cr' -> conv + ReLU 83 | 'crg' -> conv + ReLU + groupnorm 84 | 'cl' -> conv + LeakyReLU 85 | 'ce' -> conv + ELU 86 | num_groups (int): number of groups for the GroupNorm 87 | """ 88 | 89 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): 90 | super(SingleConv, self).__init__() 91 | 92 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): 93 | self.add_module(name, module) 94 | 95 | 96 | class DoubleConv(nn.Sequential): 97 | """ 98 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). 99 | We use (Conv3d+ReLU+GroupNorm3d) by default. 100 | This can be changed however by providing the 'order' argument, e.g. in order 101 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'. 102 | Use padded convolutions to make sure that the output (H_out, W_out) is the same 103 | as (H_in, W_in), so that you don't have to crop in the decoder path. 
104 | Args: 105 | in_channels (int): number of input channels 106 | out_channels (int): number of output channels 107 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder 108 | kernel_size (int): size of the convolving kernel 109 | order (string): determines the order of layers, e.g. 110 | 'cr' -> conv + ReLU 111 | 'crg' -> conv + ReLU + groupnorm 112 | 'cl' -> conv + LeakyReLU 113 | 'ce' -> conv + ELU 114 | num_groups (int): number of groups for the GroupNorm 115 | """ 116 | 117 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): 118 | super(DoubleConv, self).__init__() 119 | if encoder: 120 | # we're in the encoder path 121 | conv1_in_channels = in_channels 122 | conv1_out_channels = out_channels // 2 123 | 124 | if conv1_out_channels < in_channels: 125 | conv1_out_channels = in_channels 126 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels 127 | else: 128 | # we're in the decoder path, decrease the number of channels in the 1st convolution 129 | conv1_in_channels, conv1_out_channels = in_channels, out_channels 130 | conv2_in_channels, conv2_out_channels = out_channels, out_channels 131 | # conv1 132 | self.add_module('SingleConv1', 133 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) 134 | # conv2 135 | self.add_module('SingleConv2', 136 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) 137 | 138 | 139 | class ExtResNetBlock(nn.Module): 140 | """ 141 | Basic UNet block consisting of a SingleConv followed by the residual block. 142 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number 143 | of output channels is compatible with the residual block that follows. 144 | This block can be used instead of standard DoubleConv in the Encoder module. 
145 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf 146 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 147 | """ 148 | 149 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): 150 | super(ExtResNetBlock, self).__init__() 151 | 152 | # first convolution 153 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 154 | # residual block 155 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 156 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 157 | n_order = order 158 | for c in 'rel': 159 | n_order = n_order.replace(c, '') 160 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, 161 | num_groups=num_groups) 162 | 163 | # create non-linearity separately 164 | if 'l' in order: 165 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) 166 | elif 'e' in order: 167 | self.non_linearity = nn.ELU(inplace=True) 168 | else: 169 | self.non_linearity = nn.ReLU(inplace=True) 170 | 171 | def forward(self, x): 172 | # apply first convolution and save the output as a residual 173 | out = self.conv1(x) 174 | residual = out 175 | 176 | # residual block 177 | out = self.conv2(out) 178 | out = self.conv3(out) 179 | 180 | out += residual 181 | out = self.non_linearity(out) 182 | 183 | return out 184 | 185 | 186 | class Encoder(nn.Module): 187 | """ 188 | A single module from the encoder path consisting of the optional max 189 | pooling layer (one may specify the MaxPool kernel_size to be different 190 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic 191 | (make sure to use complementary scale_factor in the decoder path) followed by 192 | a DoubleConv module. 
193 | Args: 194 | in_channels (int): number of input channels 195 | out_channels (int): number of output channels 196 | conv_kernel_size (int): size of the convolving kernel 197 | apply_pooling (bool): if True apply the pooling layer (see pool_type) before the basic module 198 | pool_kernel_size (tuple): the size of the pooling window 199 | pool_type (str): pooling layer: 'max' or 'avg' 200 | basic_module(nn.Module): either ExtResNetBlock or DoubleConv 201 | conv_layer_order (string): determines the order of layers 202 | in `DoubleConv` module. See `DoubleConv` for more info. 203 | num_groups (int): number of groups for the GroupNorm 204 | """ 205 | 206 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, 207 | pool_kernel_size=(2, 2, 2), pool_type='avg', basic_module=DoubleConv, conv_layer_order='crg', 208 | num_groups=8): 209 | super(Encoder, self).__init__() 210 | ################################################################### 211 | assert pool_type in ['max', 'avg'] 212 | if apply_pooling: 213 | if pool_type == 'max': 214 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) 215 | else: 216 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) 217 | else: 218 | self.pooling = None 219 | 220 | self.basic_module = basic_module(in_channels, out_channels, 221 | encoder=True, 222 | kernel_size=conv_kernel_size, 223 | order=conv_layer_order, 224 | num_groups=num_groups) 225 | 226 | def forward(self, x): 227 | if self.pooling is not None: 228 | x = self.pooling(x) 229 | x = self.basic_module(x) 230 | return x 231 | 232 | 233 | class Decoder(nn.Module): 234 | """ 235 | A single module for decoder path consisting of the upsample layer 236 | (either learned ConvTranspose3d or interpolation) followed by a DoubleConv 237 | module.
238 | Args: 239 | in_channels (int): number of input channels 240 | out_channels (int): number of output channels 241 | kernel_size (int): size of the convolving kernel 242 | scale_factor (tuple): used as the multiplier for the image H/W/D in 243 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation 244 | from the corresponding encoder 245 | basic_module(nn.Module): either ResNetBlock or DoubleConv 246 | conv_layer_order (string): determines the order of layers 247 | in `DoubleConv` module. See `DoubleConv` for more info. 248 | num_groups (int): number of groups for the GroupNorm 249 | """ 250 | 251 | def __init__(self, in_channels, out_channels, kernel_size=3, 252 | scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='crg', num_groups=8): 253 | super(Decoder, self).__init__() 254 | if basic_module == DoubleConv: 255 | # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling 256 | self.upsample = None 257 | else: 258 | # otherwise use ConvTranspose3d (bear in mind your GPU memory) 259 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder 260 | # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) 261 | # also scale the number of channels from in_channels to out_channels so that summation joining 262 | # works correctly 263 | self.upsample = nn.ConvTranspose3d(in_channels, 264 | out_channels, 265 | kernel_size=kernel_size, 266 | stride=scale_factor, 267 | padding=1, 268 | output_padding=1) 269 | # adapt the number of in_channels for the ExtResNetBlock 270 | in_channels = out_channels 271 | 272 | self.basic_module = basic_module(in_channels, out_channels, 273 | encoder=False, 274 | kernel_size=kernel_size, 275 | order=conv_layer_order, 276 | num_groups=num_groups) 277 | 278 | def forward(self, encoder_features, x): 279 | if self.upsample is None: 280 | # use nearest neighbor interpolation and concatenation 
joining 281 | output_size = encoder_features.size()[2:] 282 | x = F.interpolate(x, size=output_size, mode='nearest') 283 | # concatenate encoder_features (encoder path) with the upsampled input across channel dimension 284 | x = torch.cat((encoder_features, x), dim=1) 285 | else: 286 | # use ConvTranspose3d and summation joining 287 | x = self.upsample(x) 288 | x += encoder_features 289 | 290 | x = self.basic_module(x) 291 | return x 292 | 293 | 294 | class FinalConv(nn.Sequential): 295 | """ 296 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution 297 | which reduces the number of channels to 'out_channels'. 298 | The first convolution keeps 'in_channels' channels; only the 1x1 convolution changes the channel count. 299 | We use (Conv3d+ReLU+GroupNorm3d) by default. 300 | This can be changed however by providing the 'order' argument, e.g. in order 301 | to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. 302 | Args: 303 | in_channels (int): number of input channels 304 | out_channels (int): number of output channels 305 | kernel_size (int): size of the convolving kernel 306 | order (string): determines the order of layers, e.g. 307 | 'cr' -> conv + ReLU 308 | 'crg' -> conv + ReLU + groupnorm 309 | num_groups (int): number of groups for the GroupNorm 310 | """ 311 | 312 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): 313 | super(FinalConv, self).__init__() 314 | 315 | # conv1 316 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) 317 | 318 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels 319 | final_conv = nn.Conv3d(in_channels, out_channels, 1) 320 | self.add_module('final_conv', final_conv) 321 | 322 | 323 | 324 | class UNet3D(nn.Module): 325 | """ 326 | 3DUnet model from 327 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 328 | `.
329 | Args: 330 | in_channels (int): number of input channels 331 | n_class (int): number of output channels (out_channels = n_class); 332 | Note that the out_channels might correspond to either 333 | different semantic classes or to different binary segmentation masks. 334 | It's up to the user of the class to interpret the out_channels and 335 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) 336 | or BCEWithLogitsLoss (two-class) respectively) 337 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 338 | of feature maps is given by the geometric progression: f_maps * 2^k, k=0,1,2,3 339 | Note: no final activation (nn.Sigmoid / nn.Softmax) is applied; forward() returns the raw 340 | output of the final 1x1 convolution, so train with a loss that expects logits 341 | (e.g. BCEWithLogitsLoss or CrossEntropyLoss). 342 | layer_order (string): determines the order of layers 343 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 344 | See `SingleConv` for more info 345 | (with the default f_maps=32 the encoder levels use 32, 64, 128 and 256 feature maps) 346 | num_groups (int): number of groups for the GroupNorm 347 | """ 348 | 349 | def __init__(self, n_class, in_channels=1, f_maps=32, layer_order='cgr', num_groups=8, 350 | **kwargs): 351 | super(UNet3D, self).__init__() 352 | 353 | # one output channel per landmark class 354 | 355 | out_channels = n_class 356 | if isinstance(f_maps, int): 357 | # use 4 levels in the encoder path as suggested in the paper 358 | f_maps = create_feature_maps(f_maps, number_of_fmaps=4) 359 | 360 | # create encoder path consisting of Encoder modules.
The length of the encoder is equal to `len(f_maps)` 361 | # uses DoubleConv as a basic_module for the Encoder 362 | encoders = [] 363 | for i, out_feature_num in enumerate(f_maps): 364 | if i == 0: 365 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv, 366 | conv_layer_order=layer_order, num_groups=num_groups) 367 | else: 368 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv, 369 | conv_layer_order=layer_order, num_groups=num_groups) 370 | encoders.append(encoder) 371 | 372 | self.encoders = nn.ModuleList(encoders) 373 | 374 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1` 375 | # uses DoubleConv as a basic_module for the Decoder 376 | decoders = [] 377 | reversed_f_maps = list(reversed(f_maps)) 378 | for i in range(len(reversed_f_maps) - 1): 379 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 380 | out_feature_num = reversed_f_maps[i + 1] 381 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv, 382 | conv_layer_order=layer_order, num_groups=num_groups) 383 | decoders.append(decoder) 384 | 385 | self.decoders = nn.ModuleList(decoders) 386 | 387 | # in the last layer a 1×1 convolution reduces the number of output 388 | # channels to the number of labels 389 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 390 | 391 | 392 | def forward(self, x): 393 | # encoder part 394 | encoders_features = [] 395 | for encoder in self.encoders: 396 | x = encoder(x) 397 | # reverse the encoder outputs to be aligned with the decoder 398 | encoders_features.insert(0, x) 399 | 400 | # remove the last encoder's output from the list 401 | # !!remember: it's the 1st in the list 402 | encoders_features = encoders_features[1:] 403 | 404 | # decoder part 405 | for decoder, encoder_features in zip(self.decoders, encoders_features): 406 | # pass the output from the corresponding encoder and the output 407 | # 
of the previous decoder 408 | x = decoder(encoder_features, x) 409 | 410 | x = self.final_conv(x) 411 | 412 | # no final activation (e.g. Sigmoid or Softmax) is applied: the network outputs logits, and it's up to the user 413 | # to normalize them before visualizing with tensorboard 414 | # or computing a validation metric 415 | 416 | return x 417 | 418 | 419 | class PUNet3D(nn.Module): 420 | 421 | def __init__(self, n_class, n_anchor=4, in_channels=1, f_maps=32, layer_order='cgr', num_groups=8, 422 | **kwargs): 423 | super(PUNet3D, self).__init__() 424 | self.n_anchor = n_anchor 425 | self.n_class = n_class 426 | if isinstance(f_maps, int): 427 | # use 4 levels in the encoder path as suggested in the paper 428 | f_maps = create_feature_maps(f_maps, number_of_fmaps=4) 429 | 430 | encoders = [] 431 | for i, out_feature_num in enumerate(f_maps): 432 | if i == 0: 433 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv, 434 | conv_layer_order=layer_order, num_groups=num_groups) 435 | else: 436 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv, 437 | conv_layer_order=layer_order, num_groups=num_groups) 438 | encoders.append(encoder) 439 | self.encoders = nn.ModuleList(encoders) 440 | 441 | decoders = [] 442 | reversed_f_maps = list(reversed(f_maps)) 443 | for i in range(len(reversed_f_maps) - 3): # with 4 levels this builds a single decoder (1/8 -> 1/4 resolution) 444 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 445 | out_feature_num = reversed_f_maps[i + 1] 446 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv, 447 | conv_layer_order=layer_order, num_groups=num_groups) 448 | decoders.append(decoder) 449 | self.decoders = nn.ModuleList(decoders) 450 | self.early_down1 = nn.Conv3d(f_maps[0], f_maps[2], kernel_size=1, stride=4) # full-res level-0 features -> 1/4 grid 451 | self.early_down2 = nn.Conv3d(f_maps[1], f_maps[2], kernel_size=1, stride=2) # 1/2-res level-1 features -> 1/4 grid 452 | self.pre_layer = nn.Conv3d(3*f_maps[2], n_anchor*(3+n_class), kernel_size=1,
stride=1) 453 | 454 | 455 | def forward(self, x): 456 | # encoder part 457 | encoders_features = [] 458 | 459 | for encoder in self.encoders: 460 | x = encoder(x) 461 | # print("encoder", x.shape) 462 | # reverse the encoder outputs to be aligned with the decoder 463 | encoders_features.insert(0, x) 464 | 465 | encoders_features = encoders_features[1:] # drop the deepest (bottleneck) features, which are already in x 466 | for decoder, encoder_features in zip(self.decoders, encoders_features): 467 | x = decoder(encoder_features, x) 468 | # print("decoder", x.shape) 469 | 470 | early_out1 = self.early_down1(encoders_features[-1]) 471 | early_out2 = self.early_down2(encoders_features[-2]) 472 | out = self.pre_layer(torch.cat([early_out1, early_out2, x], 1)) 473 | out = out.reshape(out.shape[0], self.n_anchor, 3+self.n_class, out.shape[2], out.size(3), out.size(4)) # split channels into (n_anchor, 3 + n_class) 474 | out = out.permute(0,3,4,5,1,2) # -> (B, D, H, W, n_anchor, 3 + n_class) 475 | return out 476 | 477 | 478 | class ResidualUNet3D(nn.Module): 479 | """ 480 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 481 | Uses ExtResNetBlock instead of DoubleConv as a basic building block as well as summation joining instead 482 | of concatenation joining. Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 483 | Args: 484 | in_channels (int): number of input channels 485 | out_channels (int): number of output segmentation masks; 486 | Note that the out_channels might correspond to either 487 | different semantic classes or to different binary segmentation masks. 488 | It's up to the user of the class to interpret the out_channels and 489 | use the proper loss criterion during training (i.e.
NLLLoss (multi-class) 490 | or BCELoss (two-class) respectively) 491 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 492 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4,5 493 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 494 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 495 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 496 | conv_layer_order (string): determines the order of layers 497 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 498 | See `SingleConv` for more info 499 | init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64 500 | num_groups (int): number of groups for the GroupNorm 501 | skip_final_activation (bool): if True, skips the final normalization layer (sigmoid/softmax) and returns the 502 | logits directly 503 | """ 504 | 505 | def __init__(self, n_class, in_channels=1, f_maps=32, conv_layer_order='cge', num_groups=8, 506 | **kwargs): 507 | super(ResidualUNet3D, self).__init__() 508 | out_channels = n_class 509 | # Set testing mode to false by default. It has to be set to true in test mode, otherwise the `final_activation` 510 | # layer won't be applied 511 | 512 | if isinstance(f_maps, int): 513 | # use 5 levels in the encoder path as suggested in the paper 514 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5) 515 | 516 | # create encoder path consisting of Encoder modules. 
The length of the encoder is equal to `len(f_maps)` 517 | # uses ExtResNetBlock as a basic_module for the Encoder 518 | encoders = [] 519 | for i, out_feature_num in enumerate(f_maps): 520 | if i == 0: 521 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=ExtResNetBlock, 522 | conv_layer_order=conv_layer_order, num_groups=num_groups) 523 | else: 524 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=ExtResNetBlock, 525 | conv_layer_order=conv_layer_order, num_groups=num_groups) 526 | encoders.append(encoder) 527 | 528 | self.encoders = nn.ModuleList(encoders) 529 | 530 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1` 531 | # uses ExtResNetBlock as a basic_module for the Decoder 532 | decoders = [] 533 | reversed_f_maps = list(reversed(f_maps)) 534 | for i in range(len(reversed_f_maps) - 1): 535 | decoder = Decoder(reversed_f_maps[i], reversed_f_maps[i + 1], basic_module=ExtResNetBlock, 536 | conv_layer_order=conv_layer_order, num_groups=num_groups) 537 | decoders.append(decoder) 538 | 539 | self.decoders = nn.ModuleList(decoders) 540 | 541 | # in the last layer a 1×1 convolution reduces the number of output 542 | # channels to the number of labels 543 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 544 | 545 | 546 | def forward(self, x): 547 | # encoder part 548 | encoders_features = [] 549 | for encoder in self.encoders: 550 | x = encoder(x) 551 | # reverse the encoder outputs to be aligned with the decoder 552 | encoders_features.insert(0, x) 553 | 554 | # remove the last encoder's output from the list 555 | # !!remember: it's the 1st in the list 556 | encoders_features = encoders_features[1:] 557 | 558 | # decoder part 559 | for decoder, encoder_features in zip(self.decoders, encoders_features): 560 | # pass the output from the corresponding encoder and the output 561 | # of the previous decoder 562 | x = decoder(encoder_features, x) 563 | 

        x = self.final_conv(x)

        # no final activation (Sigmoid or Softmax) is applied: the network outputs logits and it's up to the user
        # to normalize them before visualising with tensorboard or computing a validation metric

        return x


class PResidualUNet3D(nn.Module):
    def __init__(self, n_class, n_anchor, in_channels=1, f_maps=32, conv_layer_order='cge', num_groups=8,
                 **kwargs):
        super(PResidualUNet3D, self).__init__()
        self.n_class = n_class
        self.n_anchor = n_anchor

        if isinstance(f_maps, int):
            f_maps = create_feature_maps(f_maps, number_of_fmaps=5)
        encoders = []
        for i, out_feature_num in enumerate(f_maps):
            if i == 0:
                encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=ExtResNetBlock,
                                  conv_layer_order=conv_layer_order, num_groups=num_groups)
            else:
                encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=ExtResNetBlock,
                                  conv_layer_order=conv_layer_order, num_groups=num_groups)
            encoders.append(encoder)

        self.encoders = nn.ModuleList(encoders)

        decoders = []
        reversed_f_maps = list(reversed(f_maps))
        for i in range(len(reversed_f_maps) - 3):
            decoder = Decoder(reversed_f_maps[i], reversed_f_maps[i + 1], basic_module=ExtResNetBlock,
                              conv_layer_order=conv_layer_order, num_groups=num_groups)
            decoders.append(decoder)

        self.decoders = nn.ModuleList(decoders)
        self.early_down1 = nn.Conv3d(f_maps[0], f_maps[2], kernel_size=1, stride=4)
        self.early_down2 = nn.Conv3d(f_maps[1], f_maps[2], kernel_size=1, stride=2)
        self.pre_layer = nn.Conv3d(3*f_maps[2], n_anchor*(3+n_class), kernel_size=1, stride=1)

    def forward(self, x):
        # encoder part
        encoders_features = []
        for encoder in self.encoders:
            x = encoder(x)

            encoders_features.insert(0, x)
        encoders_features = encoders_features[1:]

        for decoder, encoder_features in zip(self.decoders, encoders_features):
            x = decoder(encoder_features, x)

        early_out1 = self.early_down1(encoders_features[-1])
        early_out2 = self.early_down2(encoders_features[-2])
        out = self.pre_layer(torch.cat([early_out1, early_out2, x], 1))
        # (B, n_anchor*(3+n_class), D, H, W) -> (B, D, H, W, n_anchor, 3+n_class)
        out = out.reshape(out.shape[0], self.n_anchor, 3+self.n_class, out.shape[2], out.size(3), out.size(4))
        out = out.permute(0,3,4,5,1,2)
        return out

--------------------------------------------------------------------------------
/code/models/VNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def passthrough(x, **kwargs):
    return x

def ELUCons(elu, nchan):
    if elu:
        return nn.ELU(inplace=True)
    else:
        return nn.ReLU(inplace=True)


class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    # always normalizes with the statistics of the current batch (training=True), also in eval mode
    def forward(self, input):
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            True, self.momentum, self.eps)


class LUConv(nn.Module):
    def __init__(self, nchan, elu):
        super(LUConv, self).__init__()
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        return out


def _make_nConv(nchan, depth, elu):
    layers = []
    for _ in range(depth):
        layers.append(LUConv(nchan, elu))
    return nn.Sequential(*layers)


class InputTransition(nn.Module):
    def __init__(self, outChans, elu):
        super(InputTransition, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(16)
        self.relu1 = ELUCons(elu, 16)

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        x16 = torch.cat([x]*16, 1)
        out = self.relu1(torch.add(out, x16))
        return out


class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs, elu, dropout=False):
        super(DownTransition, self).__init__()
        outChans = 2*inChans
        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.do1 = passthrough
        self.relu1 = ELUCons(elu, outChans)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x):
        down = self.relu1(self.bn1(self.down_conv(x)))
        out = self.do1(down)
        out = self.ops(out)
        out = self.relu2(torch.add(out, down))
        return out


class UpTransition(nn.Module):
    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super(UpTransition, self).__init__()
        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans // 2)
        self.do1 = passthrough
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, outChans // 2)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        out = self.do1(x)
        skipxdo = self.do2(skipx)
        out = self.relu1(self.bn1(self.up_conv(out)))
        xcat = torch.cat((out, skipxdo), 1)
        out = self.ops(xcat)
        out = self.relu2(torch.add(out, xcat))
        return out


class OutputTransition(nn.Module):
    def __init__(self, inChans, elu, nll, n_class):
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3d(inChans, n_class, kernel_size=3, padding=1)

    def forward(self, x):
        out = self.conv1(x)
        return out


class VNet(nn.Module):
    def __init__(self, n_class, elu=True, nll=False):
        super(VNet, self).__init__()
        self.in_tr = InputTransition(16, elu)
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=False)

        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        self.up_tr32 = UpTransition(64, 32, 1, elu)
        self.out_tr = OutputTransition(32, elu, nll, n_class)


    def forward(self, x):
        out16 = self.in_tr(x)
        out32 = self.down_tr32(out16)
        out64 = self.down_tr64(out32)
        out128 = self.down_tr128(out64)
        out256 = self.down_tr256(out128)

        out = self.up_tr256(out256, out128)
        out = self.up_tr128(out, out64)
        out = self.up_tr64(out, out32)
        out = self.up_tr32(out, out16)
        out = self.out_tr(out)

        return out


class PVNet(nn.Module):
    def __init__(self, n_class, n_anchor=4, elu=True, nll=False):
        super(PVNet, self).__init__()
        self.in_tr = InputTransition(16, elu)
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=False)

        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
        self.n_anchor = n_anchor
        self.n_class = n_class

        self.early_down1 = nn.Conv3d(16, 64, kernel_size=1, stride=4)
        self.early_down2 = nn.Conv3d(32, 64, kernel_size=1, stride=2)
        self.pre_layer = nn.Conv3d(64+64+128, n_anchor*(3+n_class), kernel_size=1, stride=1)

    def forward(self, x):
        out16 = self.in_tr(x)
        out32 = self.down_tr32(out16)
        out64 = self.down_tr64(out32)
        out128 = self.down_tr128(out64)
        out256 = self.down_tr256(out128)

        out = self.up_tr256(out256, out128)
        out = self.up_tr128(out, out64)

        early_out1 = self.early_down1(out16)
        early_out2 = self.early_down2(out32)
        out = self.pre_layer(torch.cat([early_out1, early_out2, out], 1))
        out = out.reshape(out.shape[0], self.n_anchor, 3+self.n_class, out.shape[2], out.size(3), out.size(4))
        out = out.permute(0,3,4,5,1,2)
        return out

--------------------------------------------------------------------------------
/code/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__init__.py
--------------------------------------------------------------------------------
/code/models/__pycache__/UNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/UNet.cpython-37.pyc
--------------------------------------------------------------------------------
/code/models/__pycache__/VNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/VNet.cpython-37.pyc
--------------------------------------------------------------------------------
/code/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/__init__.cpython-37.pyc
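The proposal variants of the networks (`PUNet3D`, `PResidualUNet3D`, `PVNet`) end with a 1×1×1 `pre_layer` that emits `n_anchor*(3+n_class)` channels, which are reshaped to `(B, n_anchor, 3+n_class, D, H, W)` and permuted to `(B, D, H, W, n_anchor, 3+n_class)`. Entry `[b, d, h, w, a, k]` of the resulting proposal map therefore comes from conv channel `a*(3+n_class) + k`: `k < 3` are the coordinate offsets and `k >= 3` are the per-landmark logits that `HNM_propmap` and `metric_proposal` index. The following is a minimal sketch of that index bookkeeping on nested Python lists instead of torch tensors; `conv_channel` and `to_proposal_layout` are hypothetical helpers for illustration, not functions from this repository.

```python
from typing import List

# Nested-list stand-in for a tensor: a volume is indexed [d][h][w].
Volume = List[List[List[float]]]


def conv_channel(anchor: int, k: int, n_class: int) -> int:
    """Channel of the pre_layer output that ends up at proposal entry [..., anchor, k]."""
    return anchor * (3 + n_class) + k


def to_proposal_layout(conv_out: List[Volume], n_anchor: int, n_class: int):
    """Rearrange one sample's conv output (channels x D x H x W) into a nested list
    indexed [d][h][w][anchor][k], mirroring reshape(...) + permute(0, 3, 4, 5, 1, 2)
    with the batch dimension dropped."""
    if len(conv_out) != n_anchor * (3 + n_class):
        raise ValueError("expected n_anchor * (3 + n_class) channels")
    depth = len(conv_out[0])
    height = len(conv_out[0][0])
    width = len(conv_out[0][0][0])
    return [[[[[conv_out[conv_channel(a, k, n_class)][d][h][w]
                for k in range(3 + n_class)]   # k < 3: offsets, k >= 3: landmark logits
               for a in range(n_anchor)]
              for w in range(width)]
             for h in range(height)]
            for d in range(depth)]


# Toy example: 2 anchors, 1 landmark class, a 2x2x2 grid; each value encodes
# its source position as channel*1000 + d*100 + h*10 + w.
n_anchor, n_class = 2, 1
conv_out = [[[[ch * 1000 + d * 100 + h * 10 + w for w in range(2)]
              for h in range(2)]
             for d in range(2)]
            for ch in range(n_anchor * (3 + n_class))]
proposal = to_proposal_layout(conv_out, n_anchor, n_class)
print(proposal[1][0][1][1][3])  # anchor 1, first landmark logit, taken from channel 7
```

Because the anchor dimension is the slower-varying one in the reshape, all `3+n_class` values of one anchor are contiguous in the conv output; swapping the two sizes in `reshape` would silently mix offsets and logits across anchors.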
--------------------------------------------------------------------------------
/code/models/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/code/models/losses.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np

# HNM_heatmap loss for heatmap regression
class HNM_heatmap(nn.Module):
    def __init__(self, R=20):
        super(HNM_heatmap, self).__init__()
        self.R = R
        self.regressionLoss = nn.SmoothL1Loss(reduction='mean')

    def forward(self, heatmap, target_heatmap):
        loss = 0
        batch_size = heatmap.size(0)
        n_class = heatmap.size(1)
        heatmap = heatmap.reshape(batch_size, n_class, -1)
        target_heatmap = target_heatmap.reshape(batch_size, n_class, -1)
        for i in range(batch_size):
            for j in range(n_class):
                # count the foreground voxels of the target heatmap (background voxels are -1)
                select_number = torch.sum(
                    target_heatmap[i, j] >= 0).int().item()

                if select_number <= 0:
                    # if the landmark does not exist, mine a fixed number of hard negatives
                    select_number = int(self.R * self.R * self.R / 8)
                else:
                    # if the landmark exists, regress the voxels inside its mask
                    _, cur_idx = torch.topk(
                        target_heatmap[i, j], select_number)
                    predict_pos = heatmap[i, j].index_select(0, cur_idx)
                    target_pos = target_heatmap[i, j].index_select(0, cur_idx)
                    loss += self.regressionLoss(predict_pos, target_pos)

                # use hard negative mining for background voxels
                # the default background voxel value is -1
                mask_neg = 1 - target_heatmap[i, j]
                neg_number = torch.sum(
                    target_heatmap[i, j] < 0).int().item()
                _, neg_idx = torch.topk(mask_neg, neg_number)
                predict_neg = heatmap[i, j].index_select(0, neg_idx)
                # the hardest negatives are the background voxels with the highest predictions;
                # topk returns positions within predict_neg, so map them back through neg_idx
                _, hard_idx = torch.topk(predict_neg, select_number)
                cur_idx = neg_idx.index_select(0, hard_idx)
                predict_neg = heatmap[i, j].index_select(0, cur_idx)
                target_neg = target_heatmap[i, j].index_select(0, cur_idx)
                loss += self.regressionLoss(predict_neg, target_neg)
        return loss / (batch_size * n_class)


# HNM_propmap loss for yolol model training
class HNM_propmap(nn.Module):
    def __init__(self, n_class=14, lambda_hnm=0.2, lambda_noobj=0.001, device=None):  # 0.2
        super(HNM_propmap, self).__init__()
        self.regressionLoss = nn.SmoothL1Loss()  # regression loss
        self.bceLoss = nn.BCEWithLogitsLoss()  # classification loss
        self.n_class = n_class
        self.lambda_hnm = lambda_hnm  # weight of the hard negative mining term
        self.lambda_noobj = lambda_noobj  # weight of the regularization term that deactivates the background
        self.device = device
        self.hard_num = 256  # number of hard negatives mined per missing landmark

    def forward(self, proposal_map, proposals):
        loss = 0
        batch_size = proposal_map.size(0)

        cl_pred_pos = []
        cl_pred_neg = []
        reg_pred = []
        reg_target = []
        hard_neg_count = np.zeros((self.n_class, )).astype("int32")
        hard_neg_pred = []
        for i in range(batch_size):
            for anchor_idx, proposal in enumerate(proposals[i]):
                for bbox in proposal:
                    c, w, h = int(bbox[0]), int(bbox[1]), int(bbox[2])
                    # -100 indicates a padded proposal
                    # for details, see class LandmarkProposal in data_utils/transforms.py
                    if bbox[-1] == -100:
                        break
                    elif bbox[-1] >= 0:
                        # if the landmark exists, collect its class logits and the predicted/target relative coordinates
                        cl_pred_pos.append(proposal_map[i, c, w, h, anchor_idx, int(3+bbox[-1]):int(4+bbox[-1])])
                        cl_pred_neg.append(proposal_map[i, c, w, h, anchor_idx, 3:int(3+bbox[-1])])
                        cl_pred_neg.append(proposal_map[i, c, w, h, anchor_idx, int(4+bbox[-1]):])
                        reg_pred.append(proposal_map[i, c, w, h, anchor_idx, :3])
                        reg_target.append(bbox[3:-1])
                    else:
                        # if the landmark does not exist, count it for hard negative mining
                        hard_neg_count[-1-int(bbox[-1].item())] += 1

        # select hard negative voxels for missing landmarks
        for i in range(self.n_class):
            if hard_neg_count[i] != 0:
                cur_negative = proposal_map[:,:,:,:,:,3+i].reshape(-1)
                _, neg_idx = torch.topk(cur_negative, int(hard_neg_count[i])*self.hard_num)
                hard_neg_pred.append(cur_negative[neg_idx])


        cl_pred_pos = torch.cat(cl_pred_pos, 0)
        cl_pred_neg = torch.cat(cl_pred_neg, 0)
        ################## classification loss for positive ############################
        cl_pos_loss = self.bceLoss(cl_pred_pos, torch.ones((cl_pred_pos.shape[0],)).to(self.device))
        ################## classification loss for negative ######################
        cl_neg_loss = 1/(self.n_class-1) * self.bceLoss(cl_pred_neg, torch.zeros((cl_pred_neg.shape[0],)).to(self.device))

        ################# classification loss for hard negative #########################
        cl_hard_neg_loss = 0
        if len(hard_neg_pred) > 0:
            hard_neg_pred = torch.cat(hard_neg_pred, 0)
            cl_hard_neg_loss += self.lambda_hnm*self.bceLoss(
                hard_neg_pred, torch.zeros((hard_neg_pred.shape[0],)).to(self.device))

        ################### classification loss for regularization ######################
        regu_neg_loss = self.lambda_noobj*self.bceLoss(proposal_map,
                                                       torch.zeros_like(proposal_map).to(self.device))

        ################################## regression ###################################
        reg_loss = self.regressionLoss(torch.tanh(torch.stack(reg_pred, 0)), torch.stack(reg_target, 0))
        loss += cl_pos_loss + cl_neg_loss + cl_hard_neg_loss + regu_neg_loss + reg_loss
        return loss
--------------------------------------------------------------------------------
/code/utils.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np

def setgpu(gpus):
    if gpus=='all':
        gpus = '0,1,2,3'
    print('using gpu '+gpus)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    return len(gpus.split(','))


def metric(heatmap, spacing, landmarks):
    N = heatmap.shape[0]
    n_class = heatmap.shape[1]
    total_mre = []
    max_num = 500
    hits = np.zeros((8, n_class))

    for j in range(N):
        cur_mre_group = []
        for i in range(n_class):
            max_count = 0
            group_rate = 0.999
            if np.max(heatmap[j,i])>0:
                while max_count < max_num:
                    h_score_idxs = np.where(
                        heatmap[j, i] >= np.max(heatmap[j, i])*group_rate)
                    group_rate = group_rate - 0.1
                    max_count = len(h_score_idxs[0])
            else:
                h_score_idxs = np.where(
                    heatmap[j, i] >= np.max(heatmap[j, i])*(1+0.5))

            h_predict_location = np.array(
                [np.mean(h_score_idxs[0]), np.mean(h_score_idxs[1]), np.mean(h_score_idxs[2])])

            cur_mre = np.linalg.norm(
                np.array(landmarks[j,i] - h_predict_location)*spacing, ord=2)

            if np.mean(landmarks[j, i])>0:
                cur_mre_group.append(cur_mre)
                hits[4:, i] += 1
                if cur_mre <= 2.0:
                    hits[0, i] += 1
                if cur_mre <= 2.5:
                    hits[1, i] += 1
                if cur_mre <= 3.:
                    hits[2, i] += 1
                if cur_mre <= 4.:
                    hits[3, i] += 1
            else:
                cur_mre_group.append(-1)
        total_mre.append(np.array(cur_mre_group))

    return total_mre, hits


def min_distance_voting(landmarks):
    min_dis = 1000000
    min_landmark = landmarks[0]
    for landmark in landmarks:
        cur_dis = 0
        for sub_landmark in landmarks:
            cur_dis += np.linalg.norm(
                np.array(landmark - sub_landmark), ord=2)
        if cur_dis < min_dis:
            min_dis = cur_dis
            min_landmark = landmark
    return min_landmark


def metric_proposal(proposal_map, spacing,
                    landmarks, shrink=4., anchors=[0.5, 1, 1.5, 2], n_class=14):
    # number of candidate positions that vote for each landmark;
    # can be fine-tuned according to the number of anchors
    select_number = 15

    batch_size = proposal_map.size(0)
    c = proposal_map.size(1)
    w = proposal_map.size(2)
    h = proposal_map.size(3)
    n_anchor = proposal_map.size(4)
    total_mre = []
    hits = np.zeros((8, n_class))

    for j in range(batch_size):
        cur_mre_group = []
        for idx in range(n_class):
            #################### from proposal map to landmarks #########################
            # rank the candidates of sample j only, so that other samples in the batch cannot vote
            proposal_map_vector = proposal_map[j, :, :, :, :, 3+idx].reshape(-1)
            mask = torch.zeros_like(proposal_map_vector)
            _, cur_idx = torch.topk(
                proposal_map_vector, select_number)
            mask[cur_idx] = 1
            mask_tensor = mask.reshape((c, w, h, n_anchor))
            select_index = np.where(mask_tensor.cpu().numpy()==1)

            # get predicted position
            pred_pos = []
            for i in range(len(select_index[0])):
                cur_pos = []
                cur_c = select_index[0][i]
                cur_w = select_index[1][i]
                cur_h = select_index[2][i]
                cur_anchor = select_index[3][i]
                cur_predict = torch.tanh(proposal_map[j, cur_c, cur_w, cur_h, cur_anchor, :3]).cpu().numpy()

                cur_pos.append( (np.array([cur_c, cur_w, cur_h]) + cur_predict*anchors[cur_anchor])*shrink )
                pred_pos.append(cur_pos)
            pred_pos = np.array(pred_pos)

            cur_mre = np.linalg.norm(
                (np.array(landmarks[j,idx] - min_distance_voting(pred_pos)))*spacing[j], ord=2)

            if np.mean(landmarks[j, idx])>0:
                cur_mre_group.append(cur_mre)
                hits[4:, idx] += 1
                if cur_mre <= 2.0:
                    hits[0, idx] += 1
                if cur_mre <= 2.5:
                    hits[1, idx] += 1
                if cur_mre <= 3.:
                    hits[2, idx] += 1
                if cur_mre <= 4.:
                    hits[3, idx] += 1
            else:
                # if the landmark does not exist, do not compute MRE or SDR; use -1 to mark it
                cur_mre_group.append(-1)
        total_mre.append(np.array(cur_mre_group))
    return total_mre, hits

--------------------------------------------------------------------------------
/images/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/cover.png
--------------------------------------------------------------------------------
/images/problem1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/problem1.png
--------------------------------------------------------------------------------
/images/problem2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/problem2.png
--------------------------------------------------------------------------------
/images/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/table.png
--------------------------------------------------------------------------------
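`metric_proposal` in `utils.py` turns a proposal map back into a landmark position in two steps. First, each of the `select_number` highest-scoring (cell, anchor) entries is decoded as `(cell + tanh(offset) * anchors[a]) * shrink`. Second, `min_distance_voting` keeps the candidate whose summed Euclidean distance to all other candidates is smallest. The error of that candidate is then scaled by the voxel spacing and compared with the 2.0/2.5/3.0/4.0 thresholds (in the units of `spacing`, presumably mm). The sketch below restates these steps in plain Python for a single sample and a single landmark. `decode_candidate`, `radial_error`, `within_thresholds` and the sample coordinates are illustrative assumptions, not part of the repository.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def _l2(p: Sequence[float], q: Sequence[float]) -> float:
    """Euclidean distance between two points of equal dimension."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def decode_candidate(cell: Sequence[int], raw_offset: Sequence[float],
                     anchor: float, shrink: float = 4.0) -> Point:
    """Decode one (cell, anchor) entry: squash the raw offsets with tanh, scale them
    by the anchor size, add them to the cell index, then upscale by `shrink`."""
    return tuple((c + math.tanh(o) * anchor) * shrink for c, o in zip(cell, raw_offset))


def min_distance_voting(points: Sequence[Point]) -> Point:
    """Keep the candidate with the smallest summed distance to all candidates
    (ties go to the earliest candidate, as with the strict `<` in utils.py)."""
    return min(points, key=lambda p: sum(_l2(p, q) for q in points))


def radial_error(pred: Sequence[float], target: Sequence[float],
                 spacing: Sequence[float]) -> float:
    """Distance between prediction and target after scaling each axis by the voxel spacing."""
    return _l2([p * s for p, s in zip(pred, spacing)],
               [t * s for t, s in zip(target, spacing)])


def within_thresholds(error: float,
                      thresholds: Sequence[float] = (2.0, 2.5, 3.0, 4.0)) -> List[bool]:
    """One flag per threshold, mirroring what rows 0-3 of `hits` accumulate."""
    return [error <= t for t in thresholds]


cells = [(10, 12, 8), (11, 12, 8), (9, 12, 8), (30, 2, 20)]  # the last cell is an outlier
candidates = [decode_candidate(c, (0.0, 0.0, 0.0), anchor=1.0) for c in cells]
best = min_distance_voting(candidates)
error = radial_error(best, (41.0, 48.0, 32.0), (0.5, 0.5, 0.5))
print(best, error, within_thresholds(error))
```

Voting on the summed distance acts like a geometric median: the isolated outlier contributes a large distance to every candidate but cannot win itself, so the cluster of consistent predictions decides the final position.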