├── LICENSE
├── README.md
├── code
│   ├── __pycache__
│   │   └── utils.cpython-37.pyc
│   ├── data_utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── dataloader.cpython-37.pyc
│   │   │   └── transforms.cpython-37.pyc
│   │   ├── dataloader.py
│   │   └── transforms.py
│   ├── main_baseline.py
│   ├── main_yolol.py
│   ├── models
│   │   ├── UNet.py
│   │   ├── VNet.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── UNet.cpython-37.pyc
│   │   │   ├── VNet.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── losses.cpython-37.pyc
│   │   └── losses.py
│   └── utils.py
└── images
    ├── cover.png
    ├── problem1.png
    ├── problem2.png
    └── table.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Tao He
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Anchor Ball Regression Model for Large-Scale 3D Skull Landmark Detection
2 |
3 | 
4 |
5 | ## 1. Introduction
6 | ### 1.1 What for?
7 |
8 | In this work, we identify two limitations that hinder progress in 3D skull landmark detection:
9 | * 1. There is no standard benchmark dataset for evaluating automatic landmark detection models. A review of advanced models from 2018 to 2023, listed in the table, shows that they were trained on private datasets with varying data sizes, data types, evaluation metrics, and numbers of landmarks.
10 |
11 |
12 | * 2. Most studies collected data only at the pre- or postoperative stage. In a real-world clinical environment, however, the model must be robust enough to meet clinical demands with diverse data, and a clinical landmarking evaluation is necessary at both the pre- and postoperative stages. Unfortunately, most models only detect a fixed number of landmarks on standard CT or CBCT volumes.
13 |
14 |
15 | **The references can be found in our paper (coming soon)!**
16 |
17 |
18 |
19 | The Mandibular Molar Landmarking (MML) project aims to identify the anatomical locations of the crowns and roots of the second and third mandibular molars. The task has two main challenges:
20 |
21 | * Mandibular molars have varying numbers of roots due to variations in molar growth.
22 |
23 |
24 |
25 | * Mandibular molars can be damaged by dental diseases, trauma, or surgery.
26 |
27 |
28 |
29 |
30 | ### 1.2 Highlights
31 | * We created a large-scale benchmark dataset consisting of 648 CT volumes for evaluating 3D skull landmark detection. This dataset is publicly available and is, to the best of our knowledge, the largest public dataset.
32 | * MML requires models that are robust in clinical environments and are capable of detecting arbitrary landmarks on pre-operative or post-operative CT volumes, meeting real clinical needs.
33 | * We compared baseline deep learning methods in three aspects: landmark regression models, training losses, and neural network structures. An anchor ball regression (ABR) model inspired by YOLOv3 surpassed the other baselines. The model combines landmark regression and classification losses for network training, yielding better performance than the usual heatmap and offset regression methods (see the sketch below).
34 |
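A minimal sketch of the anchor-ball idea, following `LandmarkProposal` in `code/data_utils/transforms.py`: every downsampled feature-map cell inside a ball of radius `anchor` around a landmark becomes a positive proposal storing the cell index plus the offset to the true location. The function name, grid size, and exact proposal layout below are illustrative assumptions, not a verbatim excerpt:

```
# Hypothetical sketch of anchor-ball proposal generation (not the exact repo code).
import numpy as np

def anchor_ball_proposals(landmark, anchor=1.0, shrink=4.0, grid=(32, 32, 16)):
    """Collect feature-map cells lying inside the anchor ball around one landmark."""
    lm = np.asarray(landmark, dtype=np.float32) / shrink    # to feature-map scale
    lo = np.maximum(0, np.floor(lm - anchor).astype(int))
    hi = np.minimum(np.array(grid) - 1, np.ceil(lm + anchor).astype(int))
    proposals = []
    for c in range(lo[0], hi[0] + 1):
        for h in range(lo[1], hi[1] + 1):
            for w in range(lo[2], hi[2] + 1):
                off = lm - np.array([c, h, w])              # regression target
                if np.dot(off, off) <= anchor * anchor:     # inside the ball?
                    proposals.append((c, h, w, *off))
    return proposals

print(len(anchor_ball_proposals([50.0, 50.0, 25.0])))       # a handful of positive cells
```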
35 |
36 | ## 2. Preparation
37 | ### 2.1 Requirements
38 | - python >= 3.7
39 | - pytorch >= 1.10.0
40 | - CUDA 10 or higher
41 | - numpy
42 | - pandas
43 | - scipy
44 | - pynrrd (imported as `nrrd`)
45 | - time (Python standard library; no installation needed)
46 |
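The pure-Python dependencies can be installed with pip (PyPI package names below; install PyTorch itself following the official instructions for your CUDA version):

```
pip install numpy pandas scipy pynrrd
```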
47 | ### 2.2 Data Preparation
48 |
49 | The dataset is available at https://drive.google.com/file/d/1NGsBbqXZLDlkiSJtDQdyMlXzgnkFoVON/view?usp=sharing
50 | * Data division
51 | ```
52 | - mmld_dataset/train # 458 samples for training
53 | - mmld_dataset/val # 100 samples for validation
54 | - mmld_dataset/test # 100 samples for testing
55 | ```
56 | * Data format
57 | ```
58 | - *_volume.nrrd # 3D volumes
59 | - *_label.npy # landmarks
60 | - *_spacing.npy # CT spacings, used for calculating MRE
61 | ```
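For reference, a single sample can be loaded as in `code/data_utils/dataloader.py`. A minimal sketch, where the stem `case_0001` is a hypothetical file name:

```
import numpy as np
import nrrd  # provided by the pynrrd package

volume, header = nrrd.read("mmld_dataset/train/case_0001_volume.nrrd")
landmarks = np.load("mmld_dataset/train/case_0001_label.npy")   # one (c, h, w) row per landmark
spacing = np.load("mmld_dataset/train/case_0001_spacing.npy")   # mm per voxel, used for MRE
print(volume.shape, landmarks.shape, spacing)
```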
62 |
63 | ## 3. Train and Test
64 | ### 3.1 Network Training
65 |
66 | * Training with different network backbones
67 | ```
68 | python main_yolol.py --model_name PVNet # network training using backbone PVNet
69 | python main_yolol.py --model_name PUNet3D # network training using backbone PUNet3D
70 | python main_yolol.py --model_name PResidualUNet3D # network training using backbone PResidualUNet3D
71 | ```
72 |
73 | * Training with different GPUs
74 | ```
75 | python main_yolol.py --gpu 0 # training with 1 gpu
76 | python main_yolol.py --gpu 0,1,2,3 # training with 4 gpus
77 | ```
78 |
79 | ### 3.2 Fine-tuning from a pretrained checkpoint
80 | ```
81 | python main_yolol.py --resume ../SavePath/yolol/model.ckpt
82 | ```
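The checkpoint is a plain `torch.save` dictionary (see the save call in `main_yolol.py`, with keys `epoch`, `save_dir`, `state_dict`, `optimizer`, and `args`). A minimal inspection sketch, assuming a constructed model `net`:

```
import torch

ckpt = torch.load("../SavePath/yolol/model.ckpt", map_location="cpu")
print(ckpt["epoch"])                      # epoch at which the checkpoint was saved
net.load_state_dict(ckpt["state_dict"])   # restore the network weights
```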
83 |
84 | ### 3.3 Metric calculation
85 | ```
86 | python main_yolol.py --test_flag 0 --resume ../SavePath/yolol/model.ckpt # calculate MRE and SDR on the validation set
87 | python main_yolol.py --test_flag 1 --resume ../SavePath/yolol/model.ckpt # calculate MRE and SDR on the test set
88 | ```
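Here MRE is the mean radial error and SDR-t the successful detection rate at threshold t. Assuming the usual definitions (the per-volume spacing converts voxel distances to millimetres, and t ∈ {2, 2.5, 3, 4} mm as in the test code), for predicted landmarks $p_i$ and ground truths $g_i$:

$$\mathrm{MRE}=\frac{1}{N}\sum_{i=1}^{N}\lVert p_i-g_i\rVert_2,\qquad \mathrm{SDR}\text{-}t=\frac{1}{N}\sum_{i=1}^{N}\mathbf{1}\!\left[\lVert p_i-g_i\rVert_2\le t\right]$$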
89 |
90 | ### 3.4 Training the baseline heatmap regression model
91 | ```
92 | python main_baseline.py # network training for baseline heatmap regression model
93 | ```
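The baseline regresses a per-landmark Gaussian ball as its target. A minimal numpy sketch, following `LandMarkToGaussianHeatMap` in `code/data_utils/transforms.py` (the shape and center below are arbitrary example values):

```
import numpy as np

def gaussian_heatmap(shape=(128, 128, 64), center=(64, 64, 32), R=20.0):
    zz, yy, xx = np.meshgrid(*[np.arange(s) for s in shape], indexing="ij")
    d2 = (zz - center[0]) ** 2 + (yy - center[1]) ** 2 + (xx - center[2]) ** 2
    m = (d2 <= R * R).astype(np.float32)   # ball mask of radius R
    heat = np.exp(-d2 / (R * R) / 0.2)     # Gaussian falloff
    return 2 * heat * m + m - 1            # 2 at the landmark, -1 outside the ball

hm = gaussian_heatmap()
print(hm.max(), hm.min())  # ~2.0 and -1.0
```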
94 |
95 | ## 4. Leaderboard (updated 2023/06/15)
96 |
97 | ### The ACC, F1, MRE, and SDR on the MINI subset.
98 |
99 | | **Models** | **ACC(%)** | **F1(%)** | **MRE±Std(mm)** | **SDR-2mm(%)** | **SDR-2.5mm(%)** | **SDR-3mm(%)** | **SDR-4mm(%)** |
100 | | :----------- | :---------- | :--------- | :--------------- | :-------------- | :---------------- | :-------------- | :-------------- |
101 | | Our Baseline | 93.04 | 94.98 | 2.26±1.26 | 61.89 | 74.86 | 82.43 | 91.89 |
102 | | placeholder | | | | | | | |
103 | | placeholder | | | | | | | |
104 | | placeholder | | | | | | | |
105 |
106 |
107 | ### The MRE and SDR on the whole dataset.
108 |
109 |
110 | | **Models** | **MRE±Std(mm)** | **SDR-2mm(%)** | **SDR-2.5mm(%)** | **SDR-3mm(%)** | **SDR-4mm(%)** |
111 | | :----------- | :--------------- | :-------------- | :---------------- | :-------------- | :-------------- |
112 | | Our Baseline | 1.70±0.72 | 76.43 | 86.45 | 90.91 | 95.20 |
113 | | placeholder | | | | | |
114 | | placeholder | | | | | |
115 | | placeholder | | | | | |
116 |
117 |
118 |
119 |
120 | ## 5. Contact
121 |
122 |
123 | Institution: Intelligent Medical Center, Sichuan University
124 |
125 | email: tao_he@scu.edu.cn; taohescu@gmail.com
126 |
127 | ## 6. Citation (coming soon)
128 |
129 |
--------------------------------------------------------------------------------
/code/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/code/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__init__.py
--------------------------------------------------------------------------------
/code/data_utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/data_utils/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/code/data_utils/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/data_utils/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/code/data_utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import nrrd
5 |
6 | class Molar3D(Dataset):
7 | def __init__(self, transform=None, phase='train', parent_path=None, data_type="full"):
8 |
9 | self.data_files = []
10 | self.label_files = []
11 | self.spacing = []
12 |
13 | cur_path = os.path.join(parent_path, str(phase))
14 | for file_name in os.listdir(cur_path):
15 | if file_name.endswith('_volume.nrrd'):
16 | cur_file_abbr = file_name.split("_volume")[0]
17 |
18 |                 if data_type == "full":
19 |                     _label = np.load(os.path.join(cur_path, cur_file_abbr+"_label.npy"))
20 |                     if np.any(np.sum(_label, 1) < 0):  # negative coordinates mark missing landmarks; keep complete samples only
21 |                         continue
22 |                 if data_type == "mini":
23 |                     _label = np.load(os.path.join(cur_path, cur_file_abbr+"_label.npy"))
24 |                     if np.all(np.sum(_label, 1) > 0):  # keep only samples with at least one missing landmark
25 |                         continue
26 |
27 | self.data_files.append(os.path.join(cur_path, cur_file_abbr+"_volume.nrrd"))
28 | self.label_files.append(os.path.join(cur_path, cur_file_abbr+"_label.npy"))
29 | self.spacing.append(os.path.join(cur_path, cur_file_abbr+"_spacing.npy"))
30 |
31 | self.transform = transform
32 |         print('loaded %d samples for the %s set' % (len(self.data_files), phase))
33 |
34 | def __len__(self):
35 | L = len(self.data_files)
36 | return L
37 |
38 | def __getitem__(self, index):
39 | _img, _ = nrrd.read(self.data_files[index])
40 | _landmark = np.load(self.label_files[index])
41 | _spacing = np.load(self.spacing[index])
42 | sample = {'image': _img, 'landmarks': _landmark, 'spacing':_spacing}
43 | if self.transform is not None:
44 | sample = self.transform(sample)
45 | return sample
46 |
47 |     def __str__(self):
48 |         return 'Molar3D dataset with %d samples' % len(self.data_files)
49 |
50 |
--------------------------------------------------------------------------------
/code/data_utils/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.ndimage import zoom
4 |
5 |
6 | def zoomout_imgandlandmark(img, landmarks, rate):
7 | new_img = zoom(img, rate, order=1)
8 | new_landmarks = []
9 | for position in landmarks:
10 | position_c = position[0] * rate[0]
11 | position_h = position[1] * rate[1]
12 | position_w = position[2] * rate[2]
13 | new_landmarks.append(np.array([position_c, position_h, position_w]))
14 | return new_img, np.array(new_landmarks)
15 |
16 |
17 | class RandomCrop(object):
18 | def __init__(self, min_rate=0.6, size=[128,128,64]):
19 | self.size = np.array(size)
20 | self.min_rate = min_rate
21 |
22 | def __call__(self, sample):
23 | img = sample['image']
24 | landmarks = sample['landmarks']
25 | min_ = np.ones((3,)) * 1000
26 | max_ = np.zeros((3,))
27 | for landmark in landmarks:
28 | for i in range(3):
29 |                 # a large negative value marks a nonexistent landmark
30 |                 if np.mean(landmark) < -100:
31 | continue
32 | if min_[i] > landmark[i]:
33 | min_[i] = landmark[i]
34 | if max_[i] < landmark[i]:
35 | max_[i] = landmark[i]
36 |
37 |         # set the maximum zoom rate according to the min and max of the landmarks
38 | zoom_max = [self.size[0]/(max_[0]-min_[0])-0.02, self.size[1]/(max_[1]-min_[1])-0.02, self.size[2]/(max_[2]-min_[2])-0.04]
39 |
40 | ######################### zoom out #############################
41 | random_rate0 = np.random.uniform(self.min_rate, min(zoom_max[0], 1))
42 | random_rate1 = np.random.uniform(self.min_rate, min(zoom_max[1], 1))
43 | random_rate2 = np.random.uniform(self.min_rate, min(zoom_max[2], 1))
44 |         if zoom_max[0] < self.min_rate: random_rate0 = min(zoom_max[0], 1)
45 |         if zoom_max[1] < self.min_rate: random_rate1 = min(zoom_max[1], 1)
46 |         if zoom_max[2] < self.min_rate: random_rate2 = min(zoom_max[2], 1)
47 |         img, landmarks = zoomout_imgandlandmark(img, landmarks, [random_rate0, random_rate1, random_rate2])
48 |
49 |         # recompute the landmark bounds on the zoomed volume
50 |         min_ = np.ones((3,)) * 1000
51 |         max_ = np.zeros((3,))
52 |         for landmark in landmarks:
53 |             for i in range(3):
54 |                 if np.mean(landmark) < -100:
55 |                     continue
56 |                 if min_[i] > landmark[i]:
57 | min_[i] = landmark[i]
58 | if max_[i] < landmark[i]:
59 | max_[i] = landmark[i]
60 | ######################### cropping ###############################
61 | begin_=(min_+max_)/2.-self.size/2.
62 | bc = max(0, begin_[0]); ec = min(min_[0], img.shape[0]-self.size[0])
63 | bh = max(0, begin_[1]); eh = min(min_[1], img.shape[1]-self.size[1])
64 | bw = max(0, begin_[2]); ew = min(min_[2], img.shape[2]-self.size[2])
65 | if ec - bc < 1:
66 | ec += 1
67 | if eh - bh < 1:
68 | eh += 1
69 | if ew - bw < 1:
70 | ew += 1
71 | cc = np.random.randint(bc, ec)
72 | ch = np.random.randint(bh, eh)
73 | cw = np.random.randint(bw, ew)
74 | # random crop here
75 | cur_crop_img = img[cc:(cc+self.size[0]), ch:(ch+self.size[1]), cw:(cw+self.size[2])]
76 |
77 | if(cur_crop_img.shape[0]!=self.size[0] or cur_crop_img.shape[1]!=self.size[1] or cur_crop_img.shape[2]!=self.size[2]):
78 | print(cc, ch, cw)
79 | print(img.shape)
80 | print(cur_crop_img.shape)
81 |             print('got an invalid crop size')
82 | pre_new_landmarks = []
83 | for landmark in landmarks:
84 | cur_landmark = landmark - np.array([cc, ch, cw])
85 | pre_new_landmarks.append(cur_landmark)
86 | sample['landmarks'] = np.array(pre_new_landmarks)
87 | sample['image'] = cur_crop_img
88 | return sample
89 |
90 |
91 | class LandmarkProposal(object):
92 | def __init__(self, size=[128,128,64], shrink=4., anchors=[0.5, 0.75, 1., 1.25], max_num=400):
93 | self.size = size
94 | self.shrink = shrink
95 | self.anchors = anchors
96 | self.max_num = max_num # setting a fixed anchor number for minibatch
97 |
98 | def __call__(self, sample):
99 | landmarks = sample['landmarks']
100 | landmarks = landmarks / self.shrink # shrinking the landmark coordinates
101 | proposals = []
102 |
103 | for idx, anchor in enumerate(self.anchors):
104 | proposal = []
105 | for ldx, landmark in enumerate(landmarks):
106 | if np.mean(landmark) < -100:
107 |                     cur_ldx = -1 - ldx # a negative index marks a nonexistent landmark
108 | proposal.append([0,0,0,0,0,0,cur_ldx])
109 | continue
110 | else:
111 | cur_ldx = ldx
112 |
113 |                 # if a landmark exists, calculate its proposals
114 | cl_min = landmark - anchor
115 | cl_max = landmark + anchor
116 | c = max(0, int(cl_min[0]))
117 | max_c = int(np.ceil(cl_max[0])); max_w = int(np.ceil(cl_max[1])); max_h = int(np.ceil(cl_max[2]))
118 |                 while(c <= max_c and c < int(self.size[0] / self.shrink)):
119 |                     h = max(0, int(cl_min[1]))
120 |                     while(h <= max_w and h < int(self.size[1] / self.shrink)):
121 |                         w = max(0, int(cl_min[2]))
122 |                         while(w <= max_h and w < int(self.size[2] / self.shrink)):
123 |                             # offsets from the grid cell to the true landmark
124 |                             dc = landmark[0] - c
125 |                             dh = landmark[1] - h
126 |                             dw = landmark[2] - w
127 |                             # keep the cell if it lies inside the anchor ball
128 |                             if dc * dc + dh * dh + dw * dw <= anchor * anchor:
129 |                                 proposal.append([c, h, w, dc, dh, dw, cur_ldx])
130 |                             w += 1
131 |                         h += 1
132 |                     c += 1
133 |             # if getting too many proposals, truncate the list
134 |             if (len(proposal) >= self.max_num):
135 | print("too many proposals were found !")
136 | proposal = proposal[:self.max_num]
137 |             # if there are fewer proposals, pad the tensor to a fixed size
138 |             if len(proposal) < self.max_num:
139 |                 for _ in range(self.max_num - len(proposal)):
140 |                     proposal.append([0, 0, 0, 0, 0, 0, -100])
141 |             proposals.append(np.array(proposal))
142 |
143 |         sample['proposals'] = np.stack(proposals, 0).astype(np.float32)
144 |         return sample
145 |
146 |
147 | class CenterCrop(object):
148 |     def __init__(self, size=[128, 128, 64]):
149 |         self.size = np.array(size)
150 |
151 |     def __call__(self, sample):
152 |         img = sample['image']
153 |         landmarks = sample['landmarks']
154 |         min_ = np.ones((3,)) * 1000
155 |         max_ = np.zeros((3,))
156 |         for landmark in landmarks:
157 |             for i in range(3):
158 |                 if np.mean(landmark) < -100: continue
159 |                 if min_[i] > landmark[i]:
160 | min_[i] = landmark[i]
161 | if max_[i] < landmark[i]:
162 | max_[i] = landmark[i]
163 | zoom_max = [self.size[0]/(max_[0]-min_[0])-0.02, self.size[1]/(max_[1]-min_[1])-0.02, self.size[2]/(max_[2]-min_[2])-0.04]
164 |
165 | ######################### zoom out #############################
166 | random_rate0 = min(zoom_max[0], 1)
167 | random_rate1 = min(zoom_max[1], 1)
168 | random_rate2 = min(zoom_max[2], 1)
169 | img, landmarks = zoomout_imgandlandmark(img, landmarks, [random_rate0,random_rate1,random_rate2])
170 |
171 | min_ = np.ones((3,)) * 1000
172 | max_ = np.zeros((3,))
173 | for landmark in landmarks:
174 | for i in range(3):
175 | if np.mean(landmark)< -100:
176 | continue
177 | if min_[i] > landmark[i]:
178 | min_[i] = landmark[i]
179 | if max_[i] < landmark[i]:
180 | max_[i] = landmark[i]
181 | # import pdb; pdb.set_trace()
182 | begin = ((max_ + min_) /2 - self.size/2 ).astype("int32")
183 | begin[0] = max(0, min(begin[0], img.shape[0]-self.size[0]) )
184 | begin[1] = max(0, min(begin[1], img.shape[1]-self.size[1]))
185 | begin[2] = max(0, min(begin[2], img.shape[2]-self.size[2]))
186 |
187 | if begin[0]+self.size[0] > img.shape[0] or begin[1]+self.size[1] > img.shape[1] or begin[2]+self.size[2] > img.shape[2]:
188 | print("find a very small landmark , error !!!!!")
189 | # center crop here
190 | sample["image"] = img[begin[0]:begin[0]+self.size[0], begin[1]:begin[1]+self.size[1], begin[2]:begin[2]+self.size[2]]
191 | landmarks[:, 0] = landmarks[:, 0] - begin[0]
192 | landmarks[:, 1] = landmarks[:, 1] - begin[1]
193 | landmarks[:, 2] = landmarks[:, 2] - begin[2]
194 | sample["landmarks"] = landmarks
195 | return sample
196 |
197 |
198 | class Normalize(object):
199 | def __init__(self):
200 | pass
201 |
202 | def __call__(self, sample):
203 | img = np.array(sample['image']).astype(np.float32)
204 | img /= 255.0
205 | sample['image'] = img
206 | return sample
207 |
208 |
209 | class LandMarkToGaussianHeatMap(object):
210 | def __init__(self, R=20., img_size=(128,128,64), n_class=14, GPU=None):
211 | self.R = R # gaussian heatmap radius
212 | self.GPU = GPU
213 |
214 | # generate index in three views: length, width, height
215 | c_row = np.array([i for i in range(img_size[0])])
216 | c_matrix = np.stack([c_row] * img_size[1], 1)
217 | c_matrix = np.stack([c_matrix] * img_size[2], 2)
218 | c_matrix = np.stack([c_matrix] * n_class, 0)
219 |
220 | h_row = np.array([i for i in range(img_size[1])])
221 | h_matrix = np.stack([h_row] * img_size[0], 0)
222 | h_matrix = np.stack([h_matrix] * img_size[2], 2)
223 | h_matrix = np.stack([h_matrix] * n_class, 0)
224 |
225 | w_row = np.array([i for i in range(img_size[2])])
226 | w_matrix = np.stack([w_row] * img_size[0], 0)
227 | w_matrix = np.stack([w_matrix] * img_size[1], 1)
228 | w_matrix = np.stack([w_matrix] * n_class, 0)
229 | if GPU is not None:
230 | self.c_matrix = torch.tensor(c_matrix).float().to(self.GPU)
231 | self.h_matrix = torch.tensor(h_matrix).float().to(self.GPU)
232 | self.w_matrix = torch.tensor(w_matrix).float().to(self.GPU)
233 |
234 | def __call__(self, landmarks):
235 | n_landmark = landmarks.shape[1]
236 | batch_size = landmarks.shape[0]
237 |
238 | if self.GPU is not None:
239 |             # generate the mask of voxels within radius R of each landmark
240 | mask = torch.sqrt(
241 | torch.pow(
242 | self.c_matrix -
243 | torch.tensor(np.expand_dims(np.expand_dims(landmarks[:, :, 0:1], 3),4)).float().to(self.GPU), 2) + torch.pow(
244 | self.h_matrix -
245 | torch.tensor(np.expand_dims(np.expand_dims(landmarks[:, :, 1:2], 3),4)).float(
246 | ).to(self.GPU), 2) + torch.pow(
247 | self.w_matrix -
248 | torch.tensor(np.expand_dims(np.expand_dims(landmarks[:, :, 2:3], 3),4)
249 | ).float().to(self.GPU), 2)) <= self.R
250 |
251 | # generate the heatmap with Gaussian distribution
252 | # the maximum value is 2, the min value is -1
253 | cur_heatmap = torch.exp(-((
254 | torch.pow(
255 | self.c_matrix - torch.tensor(
256 | np.expand_dims(np.expand_dims(landmarks[:, :, 0:1], 3),4)).float().to(
257 | self.GPU), 2) + torch.pow(
258 | self.h_matrix - torch.tensor(
259 | np.expand_dims(np.expand_dims(landmarks[:, :, 1:2], 3),4)).float().to(
260 | self.GPU), 2) + torch.pow(
261 | self.w_matrix - torch.tensor(
262 | np.expand_dims(np.expand_dims(landmarks[:, :, 2:3],
263 | 3),4)).float().to(self.GPU), 2)) /
264 | (self.R * self.R) / 0.2))
265 | heatmap = 2 * cur_heatmap * mask.float() + mask.float() - 1
266 | return heatmap
267 |
268 |
269 | class ToTensor(object):
270 | def __init__(self):
271 | pass
272 |
273 | def __call__(self, sample):
274 | img = np.array(sample['image']).astype(np.float32)
275 | img = np.expand_dims(img, 0)
276 | sample['image'] = img
277 | sample['landmarks'] = sample['landmarks'].astype(np.float32)
278 | return sample
279 |
280 |
281 |
--------------------------------------------------------------------------------
/code/main_baseline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import numpy as np
5 | import pandas as pd
6 |
7 | import logging
8 |
9 | import torch
10 | from torch.nn import DataParallel
11 | from torch.backends import cudnn
12 | from torch import optim
13 | from torchvision import transforms
14 | from torch.utils.data import DataLoader
15 |
16 | from data_utils.dataloader import Molar3D
17 | import data_utils.transforms as tr
18 | from utils import setgpu, metric
19 | from data_utils.transforms import LandMarkToGaussianHeatMap
20 | from models.losses import HNM_heatmap
21 | from models.VNet import VNet
22 | from models.UNet import UNet3D, ResidualUNet3D
23 |
24 | # super parameters settings here
25 | parser = argparse.ArgumentParser(description='PyTorch landmarking baseline heatmap regression')
26 | # the network backbone settings
27 | parser.add_argument('--model_name',metavar='MODEL',default='VNet',type=str, choices=['VNet', 'UNet3D', 'ResidualUNet3D'])
28 | # the maximum training epochs
29 | parser.add_argument('--epochs',default=200,type=int,metavar='N')
30 | # the beginning epoch
31 | parser.add_argument('--start_epoch',default=1,type=int)
32 | # the batch size, default 4 for one GPU
33 | parser.add_argument('-b','--batch_size',default=4,type=int)
34 | # the initial learning rate
35 | parser.add_argument('--lr','--learning_rate',default=0.001,type=float)
36 | # the path for loading pretrained model parameters
37 | parser.add_argument('--resume',default='',type=str)
38 | # the weight decay
39 | parser.add_argument('--weight-decay','--wd',default=0.0005,type=float)
40 | # the path to save the model parameters
41 | parser.add_argument('--save_dir',default='../SavePath/baseline',type=str)
42 | # the settings of gpus, multiGPU can use '0,1' or '0,1,2,3'
43 | parser.add_argument('--gpu', default='0', type=str)
44 | # the early-stopping patience (epochs without improvement)
45 | parser.add_argument('--patient',default=20,type=int)
46 | # the loss HNM_heatmap for baseline heatmap regression, HNM_propmap for yolol
47 | parser.add_argument('--loss_name', default='HNM_heatmap',type=str)
48 | # the path of dataset
49 | # before training please download the dataset and put it in "../mmld_dataset"
50 | parser.add_argument('--data_path',
51 | default='../mmld_dataset',
52 | type=str,
53 | metavar='N',
54 | help='data path')
55 | # the classes
56 | parser.add_argument('--n_class',default=14,type=int, help='number of landmarks 14')
57 | # the radius of gaussian heatmap's mask
58 | parser.add_argument('-R','--focus_radius', default=20,type=int)
59 | # the test flag | -1 for train, 0 for eval, 1 for test |
60 | parser.add_argument('--test_flag',default=-1,type=int, choices=[-1, 0, 1])
61 |
62 |
63 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64 | def main(args):
65 | cudnn.benchmark = True
66 | setgpu(args.gpu)
67 | ########################### model init #############################################
68 | net = globals()[args.model_name](n_class=args.n_class)
69 | loss = globals()[args.loss_name](R=args.focus_radius)
70 |
71 | start_epoch = args.start_epoch
72 | save_dir = args.save_dir
73 | logging.info(args)
74 | if args.resume:
75 | checkpoint = torch.load(args.resume)
76 | start_epoch = checkpoint['epoch'] + 1
77 | net.load_state_dict(checkpoint['state_dict'])
78 |
79 | net = net.to(DEVICE)
80 | loss = loss.to(DEVICE)
81 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all':
82 | net = DataParallel(net)
83 |
84 | # using Adam optimizer for network training
85 | optimizer = torch.optim.Adam(net.parameters(),
86 | lr=args.lr,
87 | betas=(0.9, 0.98),
88 | weight_decay=args.weight_decay)
89 | # the lr decayed with rate 0.98 each epoch
90 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98, last_epoch=-1)
91 |
92 |
93 | ########################## network testing ########################################
94 | # if the test_flag > -1, calculate the MRE and SDR (%) for val and test set
95 | if args.test_flag > -1:
96 | args.batch_size = 1
97 |
98 | if args.test_flag == 0:
99 | test_transform = transforms.Compose([
100 | tr.Normalize(),
101 | tr.ToTensor(),
102 | ])
103 | phase = 'val'
104 | else:
105 | test_transform = transforms.Compose([
106 |             tr.CenterCrop(), # center crop for testing
107 | tr.Normalize(),
108 | tr.ToTensor(),
109 | ])
110 | phase = 'test'
111 | test_dataset = Molar3D(transform=test_transform,
112 | phase=phase,
113 | parent_path=args.data_path)
114 | testloader = DataLoader(test_dataset,
115 | batch_size=1,
116 | shuffle=False,
117 | num_workers=4)
118 | test(testloader, net)
119 | return
120 |
121 |
122 | # generate Gaussian Heatmap using pytorch GPU tensor
123 | l2h = LandMarkToGaussianHeatMap(R=args.focus_radius,
124 | n_class=args.n_class,
125 | GPU=DEVICE,
126 | img_size=(128,128,64))
127 |
128 | ########################## data preparation ########################################
129 | # if the test_flag <= -1, begin network training
130 | # train set and validation set preprocessing
131 | train_transform = transforms.Compose([
132 |         tr.RandomCrop(), # zoom and random crop for data augmentation
133 | tr.Normalize(),
134 | tr.ToTensor(),
135 | ])
136 | train_dataset = Molar3D(transform=train_transform,
137 | phase='train',
138 | parent_path=args.data_path)
139 | trainloader = DataLoader(train_dataset,
140 | batch_size=args.batch_size,
141 | shuffle=True,
142 | num_workers=8)
143 |
144 | eval_transform = transforms.Compose([
145 | tr.CenterCrop(), # center crop for validation
146 | tr.Normalize(),
147 | tr.ToTensor(),
148 | ])
149 | eval_dataset = Molar3D(transform=eval_transform,
150 | phase='val',
151 | parent_path=args.data_path)
152 | evalloader = DataLoader(eval_dataset,
153 | batch_size=args.batch_size,
154 | shuffle=False,
155 | num_workers=8)
156 |
157 |
158 | ########################## network training ##########################################
159 | # begin training here
160 | break_flag = 0. # counting for early stop
161 | low_loss = 100.
162 | total_loss = []
163 |
164 | for epoch in range(start_epoch, args.epochs + 1):
165 | # train in one epoch
166 | train(trainloader, net, loss, epoch, optimizer, l2h)
167 | if optimizer.param_groups[0]['lr'] > args.lr * 0.03:
168 | scheduler.step()
169 |
170 | # validation in one epoch
171 | break_flag += 1
172 | eval_loss = evaluation(evalloader, net, loss, epoch, l2h)
173 | total_loss.append(eval_loss)
174 | if low_loss > eval_loss:
175 | low_loss = eval_loss
176 | break_flag = 0
177 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all':
178 | state_dict = net.module.state_dict()
179 | else:
180 | state_dict = net.state_dict()
181 | torch.save(
182 | {
183 | 'epoch': epoch,
184 | 'save_dir': save_dir,
185 | 'state_dict': state_dict,
186 | 'optimizer': optimizer.state_dict(),
187 | 'args': args
188 | }, os.path.join(save_dir, 'model.ckpt'))
189 | logging.info(
190 |                 '************************ model saved successfully ************************** !\n'
191 | )
192 |
193 | if break_flag >args.patient:
194 | break
195 |
196 |
197 | def train(data_loader, net, loss, epoch, optimizer, l2h):
198 | start_time = time.time()
199 | net.train()
200 | total_train_loss = []
201 | for i, sample in enumerate(data_loader):
202 | data = sample['image']
203 | landmark = sample['landmarks']
204 | heatmap_batch = l2h(landmark)
205 | data = data.to(DEVICE)
206 | heatmap = net(data)
207 | optimizer.zero_grad()
208 | cur_loss = loss(heatmap, heatmap_batch)
209 | total_train_loss.append(cur_loss.item())
210 | cur_loss.backward()
211 | optimizer.step()
212 |
213 | logging.info(
214 | 'Train--Epoch[%d], lr[%.6f], total loss: [%.6f], time: %.1f s!'
215 | % (epoch, optimizer.param_groups[0]['lr'], np.mean(total_train_loss), time.time() - start_time))
216 |
217 |
218 | def evaluation(dataloader, net, loss, epoch, l2h):
219 | start_time = time.time()
220 | net.eval()
221 | total_loss = []
222 |
223 | with torch.no_grad():
224 | for i, sample in enumerate(dataloader):
225 | data = sample['image']
226 | landmark = sample['landmarks']
227 | heatmap_batch = l2h(landmark)
228 | data = data.to(DEVICE)
229 | heatmap= net(data)
230 | cur_loss = loss(heatmap, heatmap_batch)
231 | total_loss.append(cur_loss.item())
232 |
233 | logging.info(
234 | 'Eval--Epoch[%d], total loss: [%.6f], time: %.1f s!'
235 | % (epoch, np.mean(total_loss), time.time() - start_time))
236 | logging.info(
237 | '***************************************************************************'
238 | )
239 | return np.mean(total_loss)
240 |
241 |
242 | def test(dataloader, net):
243 | start_time = time.time()
244 | net.eval()
245 | total_mre = []
246 | total_mean_mre = []
247 | N = 0
248 | total_hits = np.zeros((8, 14))
249 | with torch.no_grad():
250 | for i, sample in enumerate(dataloader):
251 | data = sample['image']
252 | landmarks = sample['landmarks']
253 | spacing = sample['spacing']
254 | data = data.to(DEVICE)
255 | heatmap = net(data)
256 |
257 | mre, hits = metric(heatmap.cpu().numpy(),
258 | spacing.numpy(),
259 | landmarks.cpu().numpy())
260 | total_hits += hits
261 | total_mre.append(np.array(mre))
262 | N += data.shape[0]
263 | cur_mre = []
264 | for cdx in range(len(mre[0])):
265 | if mre[0][cdx]>0:
266 | cur_mre.append(mre[0][cdx])
267 | total_mean_mre.append(np.mean(cur_mre))
268 | print("#: No.", i, "--the current MRE is [%.4f] "%np.mean(cur_mre))
269 | total_mre = np.concatenate(total_mre, 0)
270 |
271 |
272 | ################################ molar print##############################################
273 | names = [
274 | 'L0','La', 'Lb', 'Lc', 'Ld', 'Le', 'Lf', 'R0', 'Ra','Rb','Rc','Rd','Re','Rf'
275 | ]
276 |
277 | IDs = ["MRE", "SD", "2.0", "2.5", "3.0", "4."]
278 | form = {"metric": IDs}
279 | mre = []
280 | sd = []
281 | cur_hits = total_hits[:4] / total_hits[4:]
282 |
283 | ############################## each class mre ##############################################
284 | for i, name in enumerate(names):
285 | cur_mre = []
286 | for j in range(total_mre.shape[0]):
287 | if total_mre[j,i] > 0:
288 | cur_mre.append(total_mre[j,i])
289 | cur_mre = np.array(cur_mre)
290 | mre.append(np.mean(cur_mre))
291 | sd.append(np.sqrt(np.sum(pow(np.array(cur_mre) - np.mean(cur_mre), 2)) / (N-1)))
292 |
293 | ########################### total mre ######################################################
294 | mre = np.stack(mre, 0)
295 | sd = np.stack(sd, 0)
296 | total = np.stack([mre, sd], 0)
297 | total = np.concatenate([total, cur_hits], 0)
298 | for i, name in enumerate(names):
299 | form[name] = total[:, i]
300 | df = pd.DataFrame(form, columns = form.keys())
301 | df.to_excel( 'baseline_test.xlsx', index = False, header=True)
302 |
303 | ########################### total mre ######################################################
304 | mmre = np.mean(total_mean_mre)
305 | sd = np.sqrt(np.sum(pow(np.array(total_mean_mre) - mmre, 2)) / (N-1))
306 |
307 | total_hits = np.sum(total_hits, 1)
308 | logging.info(
309 | 'Test-- MRE: [%.2f] + SD: [%.2f], 2.0 mm: [%.4f], 2.5 mm: [%.4f], 3.0 mm: [%.4f], 4.0 mm: [%.4f], using time: %.1f s!' %(
310 | mmre, sd,
311 | total_hits[0] / total_hits[4],
312 | total_hits[1] / total_hits[5],
313 | total_hits[2] / total_hits[6],
314 | total_hits[3] / total_hits[7],
315 | time.time()-start_time))
316 | logging.info(
317 | '***************************************************************************'
318 | )
319 |
320 | if __name__ == '__main__':
321 | global args
322 | args = parser.parse_args()
323 | if not os.path.exists(args.save_dir):
324 | os.makedirs(args.save_dir)
325 | args.save_dir = os.path.join(args.save_dir, args.model_name)
326 | if not os.path.exists(args.save_dir):
327 | os.makedirs(args.save_dir)
328 |
329 | logging.basicConfig(level=logging.INFO,
330 | format='%(asctime)s,%(lineno)d: %(message)s\n',
331 | datefmt='%Y-%m-%d(%a)%H:%M:%S',
332 | filename=os.path.join(args.save_dir, 'log.txt'),
333 | filemode='a')
334 | console = logging.StreamHandler()
335 | console.setLevel(logging.INFO)
336 | logging.getLogger().addHandler(console)
337 | main(args)
338 |
339 |
--------------------------------------------------------------------------------
/code/main_yolol.py:
--------------------------------------------------------------------------------
1 |
2 | # package here
3 | import argparse
4 | import os
5 | import time
6 | import numpy as np
7 | import logging
8 | import pandas as pd
9 |
10 | import torch
11 | from torch.nn import DataParallel
12 | from torch.backends import cudnn
13 | from torch import optim
14 | from torchvision import transforms
15 | from torch.utils.data import DataLoader
16 |
17 | from data_utils.dataloader import Molar3D
18 | import data_utils.transforms as tr
19 | from utils import setgpu, metric_proposal
20 | from models.losses import HNM_propmap
21 |
22 | from models.VNet import PVNet
23 | from models.UNet import PUNet3D , PResidualUNet3D
24 |
25 |
26 | # super parameters settings here
27 | parser = argparse.ArgumentParser(description='PyTorch Robust Mandibular Molar Landmark Detection')
28 | # the network backbone settings
29 | parser.add_argument('--model_name',metavar='MODEL',default='PVNet',type=str, choices=['PVNet', 'PUNet3D', 'PResidualUNet3D'])
30 | # the maximum training epochs
31 | parser.add_argument('--epochs',default=200,type=int,metavar='N')
32 | # the beginning epoch
33 | parser.add_argument('--start_epoch',default=1,type=int)
34 | # the batch size, default 4 for one GPU
35 | parser.add_argument('-b','--batch_size',default=4,type=int)
36 | # the initial learning rate
37 | parser.add_argument('--lr','--learning_rate',default=0.001,type=float)
38 | # the path for loading pretrained model parameters
39 | parser.add_argument('--resume',default='',type=str)
40 | # the weight decay
41 | parser.add_argument('--weight-decay','--wd',default=0.0005,type=float)
42 | # the path to save the model parameters
43 | parser.add_argument('--save_dir',default='../SavePath/yolol',type=str)
44 | # the settings of gpus, multiGPU can use '0,1' or '0,1,2,3'
45 | parser.add_argument('--gpu', default='0', type=str)
46 | # the early-stopping patience (epochs without improvement)
47 | parser.add_argument('--patient',default=20,type=int)
48 | # the loss HNM_heatmap for baseline heatmap regression, HNM_propmap for yolol
49 | parser.add_argument('--loss_name', default='HNM_propmap',type=str)
50 | # the path of dataset
51 | # before training please download the dataset and put it in "../mmld_dataset"
52 | parser.add_argument('--data_path',
53 | default='../mmld_dataset',
54 | type=str,
55 | metavar='N',
56 | help='data path')
57 | # the classes
58 | parser.add_argument('--n_class',default=14,type=int, help='number of landmarks 14')
59 | # the downsample times
60 | parser.add_argument('--shrink',default=4,type=int,metavar='shrink')
61 | # the anchor balls default r=[0.5u, 0.75u, 1u, 1.25u]
62 | parser.add_argument('--anchors',
63 |                     default=[0.5, 0.75, 1., 1.25],
64 |                     type=float, nargs='+',
65 |                     metavar='anchors',
66 |                     help='the anchor ball radii to predict')
67 | # the test flag | -1 for train, 0 for eval, 1 for test |
68 | parser.add_argument('--test_flag',default=-1,type=int, choices=[-1, 0, 1])
69 | # the data type | full for samples with complete landmarks | mini for the mini subset with incomplete landmarks | all for the default dataset
70 | parser.add_argument('--data_type', default='all',type=str)
71 |
72 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73 |
74 | def main(args):
75 | logging.info(args)
76 | cudnn.benchmark = True
77 | setgpu(args.gpu)
78 |
79 | ########################### model init #############################################
80 | net = globals()[args.model_name](n_class=args.n_class, n_anchor=len(args.anchors))
81 | loss = globals()[args.loss_name](n_class=args.n_class, device=DEVICE)
82 |
83 | start_epoch = args.start_epoch
84 | save_dir = args.save_dir
85 | logging.info(args)
86 | if args.resume:
87 | checkpoint = torch.load(args.resume)
88 | start_epoch = checkpoint['epoch'] + 1
89 | net.load_state_dict(checkpoint['state_dict'])
90 |
91 | net = net.to(DEVICE)
92 | loss = loss.to(DEVICE)
93 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all':
94 | net = DataParallel(net)
95 |
96 | # using Adam optimizer for network training
97 | optimizer = torch.optim.Adam(net.parameters(),
98 | lr=args.lr,
99 | betas=(0.9, 0.98),
100 | weight_decay=args.weight_decay)
101 | # the lr decayed with rate 0.98 each epoch
102 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98, last_epoch=-1)
103 |
104 |
105 | ########################## network testing ########################################
106 | # if the test_flag > -1, calculate the MRE and SDR (%) for val and test set
107 | if args.test_flag > -1:
108 | args.batch_size = 1
109 | if args.test_flag == 0:
110 | test_transform = transforms.Compose([
111 | tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors),
112 | tr.Normalize(),
113 | tr.ToTensor(),
114 | ])
115 | phase = 'val'
116 | else:
117 | test_transform = transforms.Compose([
118 |             tr.CenterCrop(), # center crop for testing
119 | tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors),
120 | tr.Normalize(),
121 | tr.ToTensor(),
122 | ])
123 | phase = 'test'
124 | test_dataset = Molar3D(transform=test_transform,
125 | phase=phase,
126 | parent_path=args.data_path,
127 | data_type=args.data_type)
128 |
129 | testloader = DataLoader(test_dataset,
130 | batch_size=args.batch_size,
131 | shuffle=False,
132 | num_workers=4)
133 | test(testloader, net, args)
134 | return
135 |
136 |
137 | ########################## data preparation ########################################
138 | # if the test_flag <= -1, begin network training
139 | # train set and validation set preprocessing
140 | train_transform = transforms.Compose([
141 |         tr.RandomCrop(), # zoom and random crop for data augmentation
142 | tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors), # generate the anchor proposal
143 | tr.Normalize(),
144 | tr.ToTensor(),
145 | ])
146 | train_dataset = Molar3D(transform=train_transform,
147 | phase='train',
148 | parent_path=args.data_path,
149 | data_type = args.data_type)
150 | trainloader = DataLoader(train_dataset,
151 | batch_size=args.batch_size,
152 | shuffle=True,
153 | num_workers=8)
154 |
155 | eval_transform = transforms.Compose([
156 | tr.CenterCrop(), # center crop for validation
157 | tr.LandmarkProposal(shrink=args.shrink, anchors=args.anchors),
158 | tr.Normalize(),
159 | tr.ToTensor(),
160 | ])
161 | eval_dataset = Molar3D(transform=eval_transform,
162 | phase='val',
163 | parent_path=args.data_path,
164 | data_type=args.data_type)
165 | evalloader = DataLoader(eval_dataset,
166 | batch_size=args.batch_size,
167 | shuffle=False,
168 | num_workers=8)
169 |
170 |
171 | ########################## network training ##########################################
172 | # begin training here
173 | break_flag = 0. # counting for early stop
174 | low_loss = 100.
175 | total_loss = []
176 |
177 | for epoch in range(start_epoch, args.epochs + 1):
178 | # train in one epoch
179 | train(trainloader, net, loss, epoch, optimizer)
180 | if optimizer.param_groups[0]['lr'] > args.lr * 0.03:
181 | scheduler.step()
182 |
183 | # validation in one epoch
184 | break_flag += 1
185 | eval_loss = evaluation(evalloader, net, loss, epoch)
186 | total_loss.append(eval_loss)
187 | if low_loss > eval_loss:
188 | low_loss = eval_loss
189 | break_flag = 0
190 | if len(args.gpu.split(',')) > 1 or args.gpu == 'all':
191 | state_dict = net.module.state_dict()
192 | else:
193 | state_dict = net.state_dict()
194 | torch.save(
195 | {
196 | 'epoch': epoch,
197 | 'save_dir': save_dir,
198 | 'state_dict': state_dict,
199 | 'optimizer': optimizer.state_dict(),
200 | 'args': args
201 | }, os.path.join(save_dir, 'model.ckpt'))
202 | logging.info(
203 |                 '************************ model saved successfully ************************** !\n'
204 | )
205 |
206 | if break_flag > args.patient:
207 | break
208 |
209 |
210 | def train(data_loader, net, loss, epoch, optimizer):
211 | start_time = time.time()
212 | net.train()
213 | total_train_loss = []
214 | for i, sample in enumerate(data_loader):
215 | data = sample['image']
216 | proposals = sample['proposals']
217 | data = data.to(DEVICE)
218 | proposals = proposals.to(DEVICE)
219 | proposal_map = net(data)
220 | optimizer.zero_grad()
221 | cur_loss = loss(proposal_map, proposals)
222 | total_train_loss.append(cur_loss.item())
223 | cur_loss.backward()
224 | optimizer.step()
225 |
226 | logging.info(
227 | 'Train--Epoch[%d], lr[%.6f], total loss: [%.6f], time: %.1f s!'
228 | % (epoch, optimizer.param_groups[0]['lr'], np.mean(total_train_loss), time.time() - start_time))
229 |
230 |
231 | def evaluation(dataloader, net, loss, epoch):
232 | start_time = time.time()
233 | net.eval()
234 | total_loss = []
235 | with torch.no_grad():
236 | for i, sample in enumerate(dataloader):
237 | data = sample['image']
238 | proposals = sample['proposals']
239 | data = data.to(DEVICE)
240 | proposals = proposals.to(DEVICE)
241 | proposal_map = net(data)
242 | cur_loss = loss(proposal_map, proposals)
243 | total_loss.append(cur_loss.item())
244 |
245 | logging.info(
246 | 'Eval--Epoch[%d], total loss: [%.6f], time: %.1f s!'
247 | % (epoch, np.mean(total_loss), time.time() - start_time))
248 | logging.info(
249 | '***************************************************************************'
250 | )
251 | return np.mean(total_loss)
252 |
253 |
254 | def test(dataloader, net, args):
255 | start_time = time.time()
256 | net.eval()
257 | total_mre = []
258 | total_mean_mre = []
259 | N = 0
260 | total_hits = np.zeros((8, args.n_class))
261 | with torch.no_grad():
262 | for i, sample in enumerate(dataloader):
263 | data = sample['image']
264 | landmarks = sample['landmarks']
265 | spacing = sample['spacing']
266 | data = data.to(DEVICE)
267 | proposal_map = net(data)
268 | mre, hits = metric_proposal(proposal_map, spacing.numpy(),
269 | landmarks.numpy(), shrink=args.shrink, anchors=args.anchors,
270 | n_class=args.n_class)
271 | total_hits += hits
272 | total_mre.append(np.array(mre))
273 | N += data.shape[0]
274 | cur_mre = []
275 | for cdx in range(len(mre[0])):
276 | if mre[0][cdx]>0:
277 | cur_mre.append(mre[0][cdx])
278 | total_mean_mre.append(np.mean(cur_mre))
279 | print("#: No.", i, "--the current MRE is [%.4f] "%np.mean(cur_mre))
280 | total_mre = np.concatenate(total_mre, 0)
281 |
282 |
283 | ################################# molar print ##############################################
284 | names = [
285 | 'L0','La', 'Lb', 'Lc', 'Ld', 'Le', 'Lf', 'R0', 'Ra','Rb','Rc','Rd','Re','Rf'
286 | ]
287 | IDs = ["MRE", "SD", "2.0", "2.5", "3.0", "4."]
288 | form = {"metric": IDs}
289 | mre = []
290 | sd = []
291 | cur_hits = total_hits[:4] / total_hits[4:]
292 |
293 | ############################## each class mre ##############################################
294 | for i, name in enumerate(names):
295 | cur_mre = []
296 | for j in range(total_mre.shape[0]):
297 | if total_mre[j,i] > 0:
298 | cur_mre.append(total_mre[j,i])
299 | cur_mre = np.array(cur_mre)
300 | mre.append(np.mean(cur_mre))
301 | sd.append(np.sqrt(np.sum(pow(np.array(cur_mre) - np.mean(cur_mre), 2)) / (N-1)))
302 |
303 | mre = np.stack(mre, 0)
304 | sd = np.stack(sd, 0)
305 | total = np.stack([mre, sd], 0)
306 |
307 | total = np.concatenate([total, cur_hits], 0)
308 | for i, name in enumerate(names):
309 | form[name] = total[:, i]
310 | df = pd.DataFrame(form, columns = form.keys())
311 | # write each landmark MRE to xlsx file
312 | df.to_excel( 'yolol_test.xlsx', index = False, header=True)
313 |
314 | ########################### total mre ######################################################
315 | mmre = np.mean(total_mean_mre)
316 | sd = np.sqrt(np.sum(pow(np.array(total_mean_mre) - mmre, 2)) / (N-1))
317 |
318 | total_hits = np.sum(total_hits, 1)
319 | logging.info(
320 | 'Test-- MRE: [%.2f] + SD: [%.2f], 2.0 mm: [%.4f], 2.5 mm: [%.4f], 3.0 mm: [%.4f], 4.0 mm: [%.4f], using time: %.1f s!' %(
321 | mmre, sd,
322 | total_hits[0] / total_hits[4],
323 | total_hits[1] / total_hits[5],
324 | total_hits[2] / total_hits[6],
325 | total_hits[3] / total_hits[7],
326 | time.time()-start_time))
327 | logging.info(
328 | '***************************************************************************'
329 | )
330 |
331 | if __name__ == '__main__':
332 | global args
333 | args = parser.parse_args()
334 | if not os.path.exists(args.save_dir):
335 | os.makedirs(args.save_dir)
336 | args.save_dir = os.path.join(args.save_dir, args.model_name)
337 | if not os.path.exists(args.save_dir):
338 | os.makedirs(args.save_dir)
339 |
340 | logging.basicConfig(level=logging.INFO,
341 | format='%(asctime)s,%(lineno)d: %(message)s\n',
342 | datefmt='%Y-%m-%d(%a)%H:%M:%S',
343 | filename=os.path.join(args.save_dir, 'log.txt'),
344 | filemode='a')
345 | console = logging.StreamHandler()
346 | console.setLevel(logging.INFO)
347 | logging.getLogger().addHandler(console)
348 | main(args)
349 |
350 |
--------------------------------------------------------------------------------
/code/models/UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | import importlib
5 |
6 |
7 | def create_feature_maps(init_channel_number, number_of_fmaps):
8 | return [init_channel_number * 2 ** k for k in range(number_of_fmaps)]
9 |
10 |
11 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
12 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
13 |
14 |
15 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
16 | """
17 |     Create a list of modules which together constitute a single conv layer with non-linearity
18 | and optional batchnorm/groupnorm.
19 | Args:
20 | in_channels (int): number of input channels
21 | out_channels (int): number of output channels
22 | order (string): order of things, e.g.
23 | 'cr' -> conv + ReLU
24 | 'gcr' -> groupnorm + conv + ReLU
25 | 'cl' -> conv + LeakyReLU
26 | 'ce' -> conv + ELU
27 | 'bcr' -> batchnorm + conv + ReLU
28 | num_groups (int): number of groups for the GroupNorm
29 | padding (int): add zero-padding to the input
30 | Return:
31 | list of tuple (name, module)
32 | """
33 | assert 'c' in order, "Conv layer MUST be present"
34 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
35 |
36 | modules = []
37 | for i, char in enumerate(order):
38 | if char == 'r':
39 | modules.append(('ReLU', nn.ReLU(inplace=True)))
40 | elif char == 'l':
41 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
42 | elif char == 'e':
43 | modules.append(('ELU', nn.ELU(inplace=True)))
44 | elif char == 'c':
45 |             # add learnable bias only in the absence of batchnorm/groupnorm
46 | bias = not ('g' in order or 'b' in order)
47 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
48 | elif char == 'g':
49 | is_before_conv = i < order.index('c')
50 | if is_before_conv:
51 | num_channels = in_channels
52 | else:
53 | num_channels = out_channels
54 |
55 | # use only one group if the given number of groups is greater than the number of channels
56 | if num_channels < num_groups:
57 | num_groups = 1
58 |
59 | assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
60 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
61 | elif char == 'b':
62 | is_before_conv = i < order.index('c')
63 | if is_before_conv:
64 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
65 | else:
66 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
67 | else:
68 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
69 |
70 | return modules
71 |
72 |
73 | class SingleConv(nn.Sequential):
74 | """
75 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
76 | of operations can be specified via the `order` parameter
77 | Args:
78 | in_channels (int): number of input channels
79 | out_channels (int): number of output channels
80 | kernel_size (int): size of the convolving kernel
81 | order (string): determines the order of layers, e.g.
82 | 'cr' -> conv + ReLU
83 | 'crg' -> conv + ReLU + groupnorm
84 | 'cl' -> conv + LeakyReLU
85 | 'ce' -> conv + ELU
86 | num_groups (int): number of groups for the GroupNorm
87 | """
88 |
89 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
90 | super(SingleConv, self).__init__()
91 |
92 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
93 | self.add_module(name, module)
94 |
95 |
96 | class DoubleConv(nn.Sequential):
97 | """
98 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
99 | We use (Conv3d+ReLU+GroupNorm3d) by default.
100 | This can be changed however by providing the 'order' argument, e.g. in order
101 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
102 | Use padded convolutions to make sure that the output (H_out, W_out) is the same
103 | as (H_in, W_in), so that you don't have to crop in the decoder path.
104 | Args:
105 | in_channels (int): number of input channels
106 | out_channels (int): number of output channels
107 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
108 | kernel_size (int): size of the convolving kernel
109 | order (string): determines the order of layers, e.g.
110 | 'cr' -> conv + ReLU
111 | 'crg' -> conv + ReLU + groupnorm
112 | 'cl' -> conv + LeakyReLU
113 | 'ce' -> conv + ELU
114 | num_groups (int): number of groups for the GroupNorm
115 | """
116 |
117 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
118 | super(DoubleConv, self).__init__()
119 | if encoder:
120 | # we're in the encoder path
121 | conv1_in_channels = in_channels
122 | conv1_out_channels = out_channels // 2
123 |
124 | if conv1_out_channels < in_channels:
125 | conv1_out_channels = in_channels
126 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
127 | else:
128 | # we're in the decoder path, decrease the number of channels in the 1st convolution
129 | conv1_in_channels, conv1_out_channels = in_channels, out_channels
130 | conv2_in_channels, conv2_out_channels = out_channels, out_channels
131 | # conv1
132 | self.add_module('SingleConv1',
133 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
134 | # conv2
135 | self.add_module('SingleConv2',
136 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))
137 |
138 |
139 | class ExtResNetBlock(nn.Module):
140 | """
141 | Basic UNet block consisting of a SingleConv followed by the residual block.
142 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
143 | of output channels is compatible with the residual block that follows.
144 | This block can be used instead of standard DoubleConv in the Encoder module.
145 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf
146 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
147 | """
148 |
149 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
150 | super(ExtResNetBlock, self).__init__()
151 |
152 | # first convolution
153 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
154 | # residual block
155 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
156 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
157 | n_order = order
158 | for c in 'rel':
159 | n_order = n_order.replace(c, '')
160 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
161 | num_groups=num_groups)
162 |
163 | # create non-linearity separately
164 | if 'l' in order:
165 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
166 | elif 'e' in order:
167 | self.non_linearity = nn.ELU(inplace=True)
168 | else:
169 | self.non_linearity = nn.ReLU(inplace=True)
170 |
171 | def forward(self, x):
172 | # apply first convolution and save the output as a residual
173 | out = self.conv1(x)
174 | residual = out
175 |
176 | # residual block
177 | out = self.conv2(out)
178 | out = self.conv3(out)
179 |
180 | out += residual
181 | out = self.non_linearity(out)
182 |
183 | return out
184 |
185 |
186 | class Encoder(nn.Module):
187 | """
188 | A single module from the encoder path consisting of the optional max
189 | pooling layer (one may specify the MaxPool kernel_size to be different
190 |     than the standard (2,2,2), e.g. if the volumetric data is anisotropic;
191 |     make sure to use a complementary scale_factor in the decoder path) followed by
192 | a DoubleConv module.
193 | Args:
194 | in_channels (int): number of input channels
195 | out_channels (int): number of output channels
196 | conv_kernel_size (int): size of the convolving kernel
197 | apply_pooling (bool): if True use MaxPool3d before DoubleConv
198 | pool_kernel_size (tuple): the size of the window to take a max over
199 | pool_type (str): pooling layer: 'max' or 'avg'
200 | basic_module(nn.Module): either ResNetBlock or DoubleConv
201 | conv_layer_order (string): determines the order of layers
202 | in `DoubleConv` module. See `DoubleConv` for more info.
203 | num_groups (int): number of groups for the GroupNorm
204 | """
205 |
206 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
207 | pool_kernel_size=(2, 2, 2), pool_type='avg', basic_module=DoubleConv, conv_layer_order='crg',
208 | num_groups=8):
209 | super(Encoder, self).__init__()
210 | ###################################################################
211 | assert pool_type in ['max', 'avg']
212 | if apply_pooling:
213 | if pool_type == 'max':
214 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
215 | else:
216 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
217 | else:
218 | self.pooling = None
219 |
220 | self.basic_module = basic_module(in_channels, out_channels,
221 | encoder=True,
222 | kernel_size=conv_kernel_size,
223 | order=conv_layer_order,
224 | num_groups=num_groups)
225 |
226 | def forward(self, x):
227 | if self.pooling is not None:
228 | x = self.pooling(x)
229 | x = self.basic_module(x)
230 | return x
231 |
232 |
233 | class Decoder(nn.Module):
234 | """
235 | A single module for decoder path consisting of the upsample layer
236 | (either learned ConvTranspose3d or interpolation) followed by a DoubleConv
237 | module.
238 | Args:
239 | in_channels (int): number of input channels
240 | out_channels (int): number of output channels
241 | kernel_size (int): size of the convolving kernel
242 | scale_factor (tuple): used as the multiplier for the image H/W/D in
243 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
244 | from the corresponding encoder
245 | basic_module(nn.Module): either ResNetBlock or DoubleConv
246 | conv_layer_order (string): determines the order of layers
247 | in `DoubleConv` module. See `DoubleConv` for more info.
248 | num_groups (int): number of groups for the GroupNorm
249 | """
250 |
251 | def __init__(self, in_channels, out_channels, kernel_size=3,
252 | scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='crg', num_groups=8):
253 | super(Decoder, self).__init__()
254 | if basic_module == DoubleConv:
255 | # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling
256 | self.upsample = None
257 | else:
258 | # otherwise use ConvTranspose3d (bear in mind your GPU memory)
259 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder
260 | # (D_out = (D_in − 1) × stride[0] − 2 × padding[0] + kernel_size[0] + output_padding[0])
261 | # also scale the number of channels from in_channels to out_channels so that summation joining
262 | # works correctly
263 | self.upsample = nn.ConvTranspose3d(in_channels,
264 | out_channels,
265 | kernel_size=kernel_size,
266 | stride=scale_factor,
267 | padding=1,
268 | output_padding=1)
269 | # adapt the number of in_channels for the ExtResNetBlock
270 | in_channels = out_channels
271 |
272 | self.basic_module = basic_module(in_channels, out_channels,
273 | encoder=False,
274 | kernel_size=kernel_size,
275 | order=conv_layer_order,
276 | num_groups=num_groups)
277 |
278 | def forward(self, encoder_features, x):
279 | if self.upsample is None:
280 | # use nearest neighbor interpolation and concatenation joining
281 | output_size = encoder_features.size()[2:]
282 | x = F.interpolate(x, size=output_size, mode='nearest')
283 | # concatenate encoder_features (encoder path) with the upsampled input across channel dimension
284 | x = torch.cat((encoder_features, x), dim=1)
285 | else:
286 | # use ConvTranspose3d and summation joining
287 | x = self.upsample(x)
288 | x += encoder_features
289 |
290 | x = self.basic_module(x)
291 | return x
292 |
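# --- annotation: a minimal shape sketch (not part of the original file). ---
# With DoubleConv as basic_module the decoder upsamples `x` by nearest
# interpolation to the encoder feature size and concatenates along the
# channel dimension, so in_channels must equal the sum of both inputs'
# channel counts (this matches how UNet3D below builds its decoders).
def _demo_decoder_shapes():
    dec = Decoder(in_channels=64 + 128, out_channels=64)
    encoder_features = torch.randn(1, 64, 32, 32, 32)
    x = torch.randn(1, 128, 16, 16, 16)
    print(tuple(dec(encoder_features, x).shape))  # expected: (1, 64, 32, 32, 32)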
293 |
294 | class FinalConv(nn.Sequential):
295 | """
296 | A module consisting of a SingleConv layer (Conv3d+ReLU+GroupNorm3d by
297 | default) followed by a final 1x1 convolution which reduces the number
298 | of channels to 'out_channels'. Note that the SingleConv keeps
299 | in_channels unchanged; only the final 1x1 convolution changes the width.
300 | The layer composition can be changed via the 'order' argument, e.g. in
301 | order to use Conv3d+BatchNorm3d+ReLU pass order='cbr'.
302 | Args:
303 | in_channels (int): number of input channels
304 | out_channels (int): number of output channels
305 | kernel_size (int): size of the convolving kernel
306 | order (string): determines the order of layers, e.g.
307 | 'cr' -> conv + ReLU
308 | 'crg' -> conv + ReLU + groupnorm
309 | num_groups (int): number of groups for the GroupNorm
310 | """
311 |
312 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
313 | super(FinalConv, self).__init__()
314 |
315 | # conv1
316 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))
317 |
318 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels
319 | final_conv = nn.Conv3d(in_channels, out_channels, 1)
320 | self.add_module('final_conv', final_conv)
321 |
322 |
323 |
324 | class UNet3D(nn.Module):
325 | """
326 | 3D U-Net model from
327 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
328 | <https://arxiv.org/abs/1606.06650>`.
329 | Args:
330 | n_class (int): number of output channels (segmentation masks/heatmaps);
331 | note that the output channels might correspond to either
332 | different semantic classes or to different binary segmentation masks.
333 | It's up to the user of the class to interpret the output channels and
334 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
335 | or BCEWithLogitsLoss (two-class) respectively)
336 | in_channels (int): number of input channels
337 | f_maps (int, tuple): number of feature maps at each level of the encoder;
338 | if it's an integer, the number of feature maps follows the
339 | geometric progression f_maps * 2^k, k=0,1,2,3
340 | layer_order (string): determines the order of layers
341 | in `SingleConv` module, e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
342 | See `SingleConv` for more info
343 | num_groups (int): number of groups for the GroupNorm
344 | Note: there is no final activation layer; the network outputs raw
345 | logits, and Sigmoid/Softmax must be applied externally at inference
346 | time (e.g. before visualisation or metric computation).
347 | """
348 |
349 | def __init__(self, n_class, in_channels=1, f_maps=32, layer_order='cgr', num_groups=8,
350 | **kwargs):
351 | super(UNet3D, self).__init__()
352 |
353 | # the network has no final activation layer: it outputs raw logits,
354 | # which should be normalized (Sigmoid/Softmax) externally at inference time
355 | out_channels = n_class
356 | if isinstance(f_maps, int):
357 | # use 4 levels in the encoder path as suggested in the paper
358 | f_maps = create_feature_maps(f_maps, number_of_fmaps=4)
359 |
360 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)`
361 | # uses DoubleConv as a basic_module for the Encoder
362 | encoders = []
363 | for i, out_feature_num in enumerate(f_maps):
364 | if i == 0:
365 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv,
366 | conv_layer_order=layer_order, num_groups=num_groups)
367 | else:
368 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv,
369 | conv_layer_order=layer_order, num_groups=num_groups)
370 | encoders.append(encoder)
371 |
372 | self.encoders = nn.ModuleList(encoders)
373 |
374 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
375 | # uses DoubleConv as a basic_module for the Decoder
376 | decoders = []
377 | reversed_f_maps = list(reversed(f_maps))
378 | for i in range(len(reversed_f_maps) - 1):
379 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
380 | out_feature_num = reversed_f_maps[i + 1]
381 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv,
382 | conv_layer_order=layer_order, num_groups=num_groups)
383 | decoders.append(decoder)
384 |
385 | self.decoders = nn.ModuleList(decoders)
386 |
387 | # in the last layer a 1×1 convolution reduces the number of output
388 | # channels to the number of labels
389 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
390 |
391 |
392 | def forward(self, x):
393 | # encoder part
394 | encoders_features = []
395 | for encoder in self.encoders:
396 | x = encoder(x)
397 | # reverse the encoder outputs to be aligned with the decoder
398 | encoders_features.insert(0, x)
399 |
400 | # remove the last encoder's output from the list
401 | # !!remember: it's the 1st in the list
402 | encoders_features = encoders_features[1:]
403 |
404 | # decoder part
405 | for decoder, encoder_features in zip(self.decoders, encoders_features):
406 | # pass the output from the corresponding encoder and the output
407 | # of the previous decoder
408 | x = decoder(encoder_features, x)
409 |
410 | x = self.final_conv(x)
411 |
412 | # apply the final activation (i.e. Sigmoid or Softmax) only at test time; during training/evaluation
413 | # the network outputs logits and it's up to the user to normalize them before visualising
414 | # with TensorBoard or computing validation metrics
415 |
416 | return x
417 |
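# --- annotation: a minimal usage sketch (not part of the original file). ---
# With an integer f_maps, create_feature_maps (defined earlier in this file)
# is assumed to produce the doubling progression [32, 64, 128, 256]. The
# forward pass returns raw logits; apply Sigmoid/Softmax yourself at
# inference time (e.g. a per-channel sigmoid for heatmap-style outputs).
def _demo_unet3d():
    net = UNet3D(n_class=14, in_channels=1, f_maps=32)
    x = torch.randn(1, 1, 64, 64, 64)
    logits = net(x)
    print(tuple(logits.shape))  # expected: (1, 14, 64, 64, 64)
    probs = torch.sigmoid(logits)  # hypothetical post-processing step
    return probs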
418 |
419 | class PUNet3D(nn.Module):
420 |
421 | def __init__(self, n_class, n_anchor=4, in_channels=1, f_maps=32, layer_order='cgr', num_groups=8,
422 | **kwargs):
423 | super(PUNet3D, self).__init__()
424 | self.n_anchor = n_anchor
425 | self.n_class = n_class
426 | if isinstance(f_maps, int):
427 | # use 4 levels in the encoder path as suggested in the paper
428 | f_maps = create_feature_maps(f_maps, number_of_fmaps=4)
429 |
430 | encoders = []
431 | for i, out_feature_num in enumerate(f_maps):
432 | if i == 0:
433 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv,
434 | conv_layer_order=layer_order, num_groups=num_groups)
435 | else:
436 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv,
437 | conv_layer_order=layer_order, num_groups=num_groups)
438 | encoders.append(encoder)
439 | self.encoders = nn.ModuleList(encoders)
440 |
441 | decoders = []
442 | reversed_f_maps = list(reversed(f_maps))
443 | for i in range(len(reversed_f_maps) - 3):
444 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
445 | out_feature_num = reversed_f_maps[i + 1]
446 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv,
447 | conv_layer_order=layer_order, num_groups=num_groups)
448 | decoders.append(decoder)
449 | self.decoders = nn.ModuleList(decoders)
450 | self.early_down1 = nn.Conv3d(f_maps[0], f_maps[2], kernel_size=1, stride=4)
451 | self.early_down2 = nn.Conv3d(f_maps[1], f_maps[2], kernel_size=1, stride=2)
452 | self.pre_layer = nn.Conv3d(3*f_maps[2], n_anchor*(3+n_class), kernel_size=1, stride=1)
453 |
454 |
455 | def forward(self, x):
456 | # encoder part
457 | encoders_features = []
458 |
459 | for encoder in self.encoders:
460 | x = encoder(x)
462 | # reverse the encoder outputs to be aligned with the decoder
463 | encoders_features.insert(0, x)
464 |
465 | encoders_features = encoders_features[1:]
466 | for decoder, encoder_features in zip(self.decoders, encoders_features):
467 | x = decoder(encoder_features, x)
469 |
470 | early_out1 = self.early_down1(encoders_features[-1])
471 | early_out2 = self.early_down2(encoders_features[-2])
472 | out = self.pre_layer(torch.cat([early_out1, early_out2, x], 1))
473 | out = out.reshape(out.shape[0], self.n_anchor, 3+self.n_class, out.shape[2], out.size(3), out.size(4))
474 | out = out.permute(0,3,4,5,1,2)
475 | return out
476 |
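# --- annotation: a minimal shape sketch (not part of the original file). ---
# PUNet3D truncates the decoder after one upsampling step and fuses it with
# strided 1x1x1 projections of the two earliest encoder outputs, so the
# proposal map lives on a grid shrunk 4x relative to the input. The last
# dimension packs 3 regression offsets (squashed with tanh downstream)
# followed by n_class scores, matching HNM_propmap in losses.py and
# metric_proposal in utils.py.
def _demo_punet3d_proposal_map():
    net = PUNet3D(n_class=14, n_anchor=4)
    x = torch.randn(1, 1, 64, 64, 64)
    out = net(x)
    print(tuple(out.shape))  # expected: (1, 16, 16, 16, 4, 17)
    offsets, scores = out[..., :3], out[..., 3:]
    return offsets, scores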
477 |
478 | class ResidualUNet3D(nn.Module):
479 | """
480 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
481 | Uses ExtResNetBlock as the basic building block instead of DoubleConv, and summation joining instead
482 | of concatenation joining. Since the model effectively becomes a residual net, in theory it allows for a deeper UNet.
483 | Args:
484 | n_class (int): number of output channels (segmentation masks/heatmaps);
485 | note that the output channels might correspond to either
486 | different semantic classes or to different binary segmentation masks.
487 | It's up to the user of the class to interpret the output channels and
488 | use the proper loss criterion during training (i.e. CrossEntropyLoss
489 | (multi-class) or BCEWithLogitsLoss (two-class) respectively)
490 | in_channels (int): number of input channels
491 | f_maps (int, tuple): number of feature maps at each level of the encoder;
492 | if it's an integer, the number of feature maps follows the
493 | geometric progression f_maps * 2^k, k=0,...,4
494 | (i.e. five encoder levels, see create_feature_maps)
495 | conv_layer_order (string): determines the order of layers
496 | in `SingleConv` module, e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
497 | See `SingleConv` for more info
498 | num_groups (int): number of groups for the GroupNorm
499 | Note: there is no final activation layer; the network outputs raw
500 | logits, and Sigmoid/Softmax must be applied externally at inference
501 | time (e.g. before visualisation or metric computation).
502 | """
503 |
504 |
505 | def __init__(self, n_class, in_channels=1, f_maps=32, conv_layer_order='cge', num_groups=8,
506 | **kwargs):
507 | super(ResidualUNet3D, self).__init__()
508 | out_channels = n_class
509 | # the network has no final activation layer: it outputs raw logits,
510 | # which should be normalized (Sigmoid/Softmax) externally at inference time
511 |
512 | if isinstance(f_maps, int):
513 | # use 5 levels in the encoder path as suggested in the paper
514 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5)
515 |
516 | # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)`
517 | # uses ExtResNetBlock as a basic_module for the Encoder
518 | encoders = []
519 | for i, out_feature_num in enumerate(f_maps):
520 | if i == 0:
521 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=ExtResNetBlock,
522 | conv_layer_order=conv_layer_order, num_groups=num_groups)
523 | else:
524 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=ExtResNetBlock,
525 | conv_layer_order=conv_layer_order, num_groups=num_groups)
526 | encoders.append(encoder)
527 |
528 | self.encoders = nn.ModuleList(encoders)
529 |
530 | # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
531 | # uses ExtResNetBlock as a basic_module for the Decoder
532 | decoders = []
533 | reversed_f_maps = list(reversed(f_maps))
534 | for i in range(len(reversed_f_maps) - 1):
535 | decoder = Decoder(reversed_f_maps[i], reversed_f_maps[i + 1], basic_module=ExtResNetBlock,
536 | conv_layer_order=conv_layer_order, num_groups=num_groups)
537 | decoders.append(decoder)
538 |
539 | self.decoders = nn.ModuleList(decoders)
540 |
541 | # in the last layer a 1×1 convolution reduces the number of output
542 | # channels to the number of labels
543 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
544 |
545 |
546 | def forward(self, x):
547 | # encoder part
548 | encoders_features = []
549 | for encoder in self.encoders:
550 | x = encoder(x)
551 | # reverse the encoder outputs to be aligned with the decoder
552 | encoders_features.insert(0, x)
553 |
554 | # remove the last encoder's output from the list
555 | # !!remember: it's the 1st in the list
556 | encoders_features = encoders_features[1:]
557 |
558 | # decoder part
559 | for decoder, encoder_features in zip(self.decoders, encoders_features):
560 | # pass the output from the corresponding encoder and the output
561 | # of the previous decoder
562 | x = decoder(encoder_features, x)
563 |
564 | x = self.final_conv(x)
565 |
566 | # apply the final activation (i.e. Sigmoid or Softmax) only for prediction. During training the network
567 | # outputs logits and it's up to the user to normalize them before visualising with TensorBoard or computing validation metrics
568 |
569 | return x
570 |
571 |
572 | class PResidualUNet3D(nn.Module):
573 | def __init__(self, n_class, n_anchor, in_channels=1, f_maps=32, conv_layer_order='cge', num_groups=8,
574 | **kwargs):
575 | super(PResidualUNet3D, self).__init__()
576 | self.n_class = n_class
577 | self.n_anchor = n_anchor
578 |
579 | if isinstance(f_maps, int):
580 | f_maps = create_feature_maps(f_maps, number_of_fmaps=5)
581 | encoders = []
582 | for i, out_feature_num in enumerate(f_maps):
583 | if i == 0:
584 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=ExtResNetBlock,
585 | conv_layer_order=conv_layer_order, num_groups=num_groups)
586 | else:
587 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=ExtResNetBlock,
588 | conv_layer_order=conv_layer_order, num_groups=num_groups)
589 | encoders.append(encoder)
590 |
591 | self.encoders = nn.ModuleList(encoders)
592 |
593 | decoders = []
594 | reversed_f_maps = list(reversed(f_maps))
595 | for i in range(len(reversed_f_maps) - 3):
596 | decoder = Decoder(reversed_f_maps[i], reversed_f_maps[i + 1], basic_module=ExtResNetBlock,
597 | conv_layer_order=conv_layer_order, num_groups=num_groups)
598 | decoders.append(decoder)
599 |
600 | self.decoders = nn.ModuleList(decoders)
601 | self.early_down1 = nn.Conv3d(f_maps[0], f_maps[2], kernel_size=1, stride=4)
602 | self.early_down2 = nn.Conv3d(f_maps[1], f_maps[2], kernel_size=1, stride=2)
603 | self.pre_layer = nn.Conv3d(3*f_maps[2], n_anchor*(3+n_class), kernel_size=1, stride=1)
604 |
605 | def forward(self, x):
606 | # encoder part
607 | encoders_features = []
608 | for encoder in self.encoders:
609 | x = encoder(x)
610 |
611 | encoders_features.insert(0, x)
612 | encoders_features = encoders_features[1:]
613 |
614 | for decoder, encoder_features in zip(self.decoders, encoders_features):
615 | x = decoder(encoder_features, x)
616 |
617 | early_out1 = self.early_down1(encoders_features[-1])
618 | early_out2 = self.early_down2(encoders_features[-2])
619 | out = self.pre_layer(torch.cat([early_out1, early_out2, x], 1))
620 | out = out.reshape(out.shape[0], self.n_anchor, 3+self.n_class, out.shape[2], out.size(3), out.size(4))
621 | out = out.permute(0,3,4,5,1,2)
622 | return out
623 |
624 |
--------------------------------------------------------------------------------
/code/models/VNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def passthrough(x, **kwargs):
7 | return x
8 |
9 | def ELUCons(elu, nchan):
10 | if elu:
11 | return nn.ELU(inplace=True)
12 | else:
13 | return nn.PReLU(nchan)  # nn.ReLU(nchan) would misuse nchan as the `inplace` flag
14 |
15 |
16 | class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
17 | def forward(self, input):
18 | return F.batch_norm(
19 | input, self.running_mean, self.running_var, self.weight, self.bias,
20 | True, self.momentum, self.eps)
21 |
22 |
23 | class LUConv(nn.Module):
24 | def __init__(self, nchan, elu):
25 | super(LUConv, self).__init__()
26 | self.relu1 = ELUCons(elu, nchan)
27 | self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
28 | self.bn1 = ContBatchNorm3d(nchan)
29 |
30 | def forward(self, x):
31 | out = self.relu1(self.bn1(self.conv1(x)))
32 | return out
33 |
34 |
35 | def _make_nConv(nchan, depth, elu):
36 | layers = []
37 | for _ in range(depth):
38 | layers.append(LUConv(nchan, elu))
39 | return nn.Sequential(*layers)
40 |
41 |
42 | class InputTransition(nn.Module):
43 | def __init__(self, outChans, elu):
44 | super(InputTransition, self).__init__()
45 | self.conv1 = nn.Conv3d(1, outChans, kernel_size=5, padding=2)
46 | self.bn1 = ContBatchNorm3d(outChans)
47 | self.relu1 = ELUCons(elu, outChans)
48 |
49 | def forward(self, x):
50 | out = self.bn1(self.conv1(x))
51 | xrep = x.repeat(1, out.size(1), 1, 1, 1)  # replicate the single input channel for the residual add
52 | out = self.relu1(torch.add(out, xrep))
53 | return out
54 |
55 |
56 | class DownTransition(nn.Module):
57 | def __init__(self, inChans, nConvs, elu, dropout=False):
58 | super(DownTransition, self).__init__()
59 | outChans = 2*inChans
60 | self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
61 | self.bn1 = ContBatchNorm3d(outChans)
62 | self.do1 = passthrough
63 | self.relu1 = ELUCons(elu, outChans)
64 | self.relu2 = ELUCons(elu, outChans)
65 | if dropout:
66 | self.do1 = nn.Dropout3d()
67 | self.ops = _make_nConv(outChans, nConvs, elu)
68 |
69 | def forward(self, x):
70 | down = self.relu1(self.bn1(self.down_conv(x)))
71 | out = self.do1(down)
72 | out = self.ops(out)
73 | out = self.relu2(torch.add(out, down))
74 | return out
75 |
76 |
77 | class UpTransition(nn.Module):
78 | def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
79 | super(UpTransition, self).__init__()
80 | self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
81 | self.bn1 = ContBatchNorm3d(outChans // 2)
82 | self.do1 = passthrough
83 | self.do2 = nn.Dropout3d()
84 | self.relu1 = ELUCons(elu, outChans // 2)
85 | self.relu2 = ELUCons(elu, outChans)
86 | if dropout:
87 | self.do1 = nn.Dropout3d()
88 | self.ops = _make_nConv(outChans, nConvs, elu)
89 |
90 | def forward(self, x, skipx):
91 | out = self.do1(x)
92 | skipxdo = self.do2(skipx)
93 | out = self.relu1(self.bn1(self.up_conv(out)))
94 | xcat = torch.cat((out, skipxdo), 1)
95 | out = self.ops(xcat)
96 | out = self.relu2(torch.add(out, xcat))
97 | return out
98 |
99 |
100 | class OutputTransition(nn.Module):
101 | def __init__(self, inChans, elu, nll, n_class):
102 | super(OutputTransition, self).__init__()
103 | self.conv1 = nn.Conv3d(inChans, n_class, kernel_size=3, padding=1)
104 |
105 | def forward(self, x):
106 | out = self.conv1(x)
107 | return out
108 |
109 |
110 | class VNet(nn.Module):
111 | def __init__(self, n_class, elu=True, nll=False):
112 | super(VNet, self).__init__()
113 | self.in_tr = InputTransition(16, elu)
114 | self.down_tr32 = DownTransition(16, 1, elu)
115 | self.down_tr64 = DownTransition(32, 2, elu)
116 | self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
117 | self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
118 |
119 | self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
120 | self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
121 | self.up_tr64 = UpTransition(128, 64, 1, elu)
122 | self.up_tr32 = UpTransition(64, 32, 1, elu)
123 | self.out_tr = OutputTransition(32, elu, nll, n_class)
124 |
125 |
126 | def forward(self, x):
127 | out16 = self.in_tr(x)
128 | out32 = self.down_tr32(out16)
129 | out64 = self.down_tr64(out32)
130 | out128 = self.down_tr128(out64)
131 | out256 = self.down_tr256(out128)
132 |
133 | out = self.up_tr256(out256, out128)
134 | out = self.up_tr128(out, out64)
135 | out = self.up_tr64(out, out32)
136 | out = self.up_tr32(out, out16)
137 | out = self.out_tr(out)
138 |
139 | return out
140 |
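# --- annotation: a minimal usage sketch (not part of the original file). ---
# Four stride-2 down transitions and four matching up transitions mean the
# spatial input size is assumed to be divisible by 16; the output keeps the
# input resolution with n_class channels of raw logits.
def _demo_vnet():
    net = VNet(n_class=14)
    x = torch.randn(1, 1, 64, 64, 64)
    print(tuple(net(x).shape))  # expected: (1, 14, 64, 64, 64)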
141 |
142 | class PVNet(nn.Module):
143 | def __init__(self, n_class, n_anchor=4, elu=True, nll=False):
144 | super(PVNet, self).__init__()
145 | self.in_tr = InputTransition(16, elu)
146 | self.down_tr32 = DownTransition(16, 1, elu)
147 | self.down_tr64 = DownTransition(32, 2, elu)
148 | self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
149 | self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
150 |
151 | self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
152 | self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
153 | self.n_anchor = n_anchor
154 | self.n_class = n_class
155 |
156 | self.early_down1 = nn.Conv3d(16, 64, kernel_size=1, stride=4)
157 | self.early_down2 = nn.Conv3d(32, 64, kernel_size=1, stride=2)
158 | self.pre_layer = nn.Conv3d(64+64+128, n_anchor*(3+n_class), kernel_size=1, stride=1)
159 |
160 | def forward(self, x):
161 | out16 = self.in_tr(x)
162 | out32 = self.down_tr32(out16)
163 | out64 = self.down_tr64(out32)
164 | out128 = self.down_tr128(out64)
165 | out256 = self.down_tr256(out128)
166 |
167 | out = self.up_tr256(out256, out128)
168 | out = self.up_tr128(out, out64)
169 |
170 | early_out1 = self.early_down1(out16)
171 | early_out2 = self.early_down2(out32)
172 | out = self.pre_layer(torch.cat([early_out1, early_out2, out], 1))
173 | out = out.reshape(out.shape[0], self.n_anchor, 3+self.n_class, out.shape[2], out.size(3), out.size(4))
174 | out = out.permute(0,3,4,5,1,2)
175 | return out
176 |
--------------------------------------------------------------------------------
/code/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__init__.py
--------------------------------------------------------------------------------
/code/models/__pycache__/UNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/UNet.cpython-37.pyc
--------------------------------------------------------------------------------
/code/models/__pycache__/VNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/VNet.cpython-37.pyc
--------------------------------------------------------------------------------
/code/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/code/models/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/code/models/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/code/models/losses.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | # HNM_heatmap loss for heatmap regression
7 | class HNM_heatmap(nn.Module):
8 | def __init__(self, R=20):
9 | super(HNM_heatmap, self).__init__()
10 | self.R = R
11 | self.regressionLoss = nn.SmoothL1Loss(reduction='mean')
12 |
13 | def forward(self, heatmap, target_heatmap):
14 | loss = 0
15 | batch_size = heatmap.size(0)
16 | n_class = heatmap.size(1)
17 | heatmap = heatmap.reshape(batch_size, n_class, -1)
18 | target_heatmap = target_heatmap.reshape(batch_size, n_class, -1)
19 | for i in range(batch_size):
20 | for j in range(n_class):
21 | # count the voxels inside the landmark mask (background voxels are -1)
22 | select_number = torch.sum(
23 | target_heatmap[i, j] >= 0).int().item()
24 |
25 | if select_number <= 0:
26 | # if the landmark does not exist, use a fixed number of voxels for hard negative mining
27 | select_number = int(self.R * self.R * self.R / 8)
28 | else:
29 | # if the landmark exists, regress the voxels inside its mask
30 | _, cur_idx = torch.topk(
31 | target_heatmap[i, j], select_number)
32 | predict_pos = heatmap[i, j].index_select(0, cur_idx)
33 | target_pos = target_heatmap[i, j].index_select(0, cur_idx)
34 | loss += self.regressionLoss(predict_pos, target_pos)
35 |
36 | # using hard negative mining for background voxels
37 | # the default background voxel is -1
38 | mask_neg = 1 - target_heatmap[i, j]
39 | neg_number = torch.sum(
40 | target_heatmap[i, j] < 0).int().item()
41 | _, neg_idx = torch.topk(mask_neg, neg_number)
42 | predict_neg = heatmap[i, j].index_select(0, neg_idx)
43 | _, cur_idx = torch.topk(predict_neg, select_number)
44 | hard_idx = neg_idx.index_select(0, cur_idx)  # map topk positions back to voxel indices
45 | predict_neg = heatmap[i, j].index_select(0, hard_idx)
46 | target_neg = target_heatmap[i, j].index_select(0, hard_idx)
47 | loss += self.regressionLoss(predict_neg, target_neg)
48 | return loss / (batch_size * n_class)
49 |
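# --- annotation: a minimal usage sketch (not part of the original file). ---
# Per the comments above, background voxels of the target heatmap are -1 and
# voxels inside a landmark mask are >= 0; the blob below is a hypothetical
# stand-in for a real Gaussian heatmap target.
def _demo_hnm_heatmap():
    loss_fn = HNM_heatmap(R=20)
    pred = torch.randn(1, 14, 32, 32, 32)
    target = -torch.ones(1, 14, 32, 32, 32)   # everything background
    target[0, 0, 10:14, 10:14, 10:14] = 0.9    # one existing landmark blob
    print(loss_fn(pred, target).item())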
50 |
51 | # HNM_propmap loss for yolol model training
52 | class HNM_propmap(nn.Module):
53 | def __init__(self, n_class=14, lambda_hnm=0.2, lambda_noobj=0.001, device=None):
54 | super(HNM_propmap, self).__init__()
55 | self.regressionLoss = nn.SmoothL1Loss() # regression loss
56 | self.bceLoss = nn.BCEWithLogitsLoss() # classification loss
57 | self.n_class = n_class
58 | self.lambda_hnm = lambda_hnm # the weight for hard negative mining
59 | self.lambda_noobj = lambda_noobj # weight of the regularization term that suppresses background activations
60 | self.device = device
61 | self.hard_num = 256 # number of hard negatives selected per missing landmark
62 |
63 | def forward(self, proposal_map, proposals):
64 | loss = 0
65 | batch_size = proposal_map.size(0)
66 |
67 | cl_pred_pos = []
68 | cl_pred_neg = []
69 | reg_pred = []
70 | reg_target = []
71 | hard_neg_count = np.zeros((self.n_class, )).astype("int32")
72 | hard_neg_pred = []
73 | for i in range(batch_size):
74 | for anchor_idx, proposal in enumerate(proposals[i]):
75 | for bbox in proposal:
76 | c=int(bbox[0]); w=int(bbox[1]); h=int(bbox[2])
77 | # a label of -100 indicates a padded proposal;
78 | # see class LandmarkProposal in data_utils/transforms.py for details
79 | if bbox[-1] == -100:
80 | break
81 | elif bbox[-1] >= 0:
82 | # if the landmark exists, collect the prediction and the relative-coordinate target
83 | cl_pred_pos.append(proposal_map[i, c, w, h, anchor_idx, int(3+bbox[-1]):int(4+bbox[-1])])
84 | cl_pred_neg.append(proposal_map[i, c, w, h, anchor_idx, 3:int(3+bbox[-1])])
85 | cl_pred_neg.append(proposal_map[i, c, w, h, anchor_idx, int(4+bbox[-1]):])
86 | reg_pred.append(proposal_map[i, c, w, h, anchor_idx, :3])
87 | reg_target.append(bbox[3:-1])
88 | else:
89 | # if the landmark does not exist, record its class for hard negative mining
90 | hard_neg_count[-1-int(bbox[-1].item())] += 1
91 |
92 | # select hard negative voxels for the missing landmarks
93 | for i in range(self.n_class):
94 | if hard_neg_count[i] != 0:
95 | cur_negative = proposal_map[:,:,:,:,:,3+i].reshape(-1)
96 | _, neg_idx = torch.topk(cur_negative, hard_neg_count[i]*self.hard_num)
97 | hard_neg_pred.append(cur_negative[neg_idx])
98 |
99 |
100 | cl_pred_pos = torch.cat(cl_pred_pos, 0)
101 | cl_pred_neg = torch.cat(cl_pred_neg, 0)
102 | ################## classification loss for positive ############################
103 | cl_pos_loss = self.bceLoss(cl_pred_pos, torch.ones((cl_pred_pos.shape[0],)).to(self.device))
104 | ################## classification loss for negative ######################
105 | cl_neg_loss = 1/(self.n_class-1) * self.bceLoss(cl_pred_neg, torch.zeros((cl_pred_neg.shape[0],)).to(self.device))
106 |
107 | ################# classification loss for hard negative #########################
108 | cl_hard_neg_loss = 0
109 | if len(hard_neg_pred) > 0:
110 | hard_neg_pred = torch.cat(hard_neg_pred, 0)
111 | cl_hard_neg_loss += self.lambda_hnm * self.bceLoss(
112 | hard_neg_pred, torch.zeros((hard_neg_pred.shape[0],)).to(self.device))
113 |
114 | ################### classification loss for regularization ######################
115 | regu_neg_loss = self.lambda_noobj*self.bceLoss(proposal_map,
116 | torch.zeros_like(proposal_map).to(self.device))
117 |
118 | ################################## regression ###################################
119 | reg_loss = self.regressionLoss(torch.tanh(torch.stack(reg_pred, 0)), torch.stack(reg_target, 0))
120 | loss += cl_pos_loss + cl_neg_loss + cl_hard_neg_loss + regu_neg_loss + reg_loss
121 | return loss
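
# --- annotation (not part of the original file): expected `proposals` layout,
# inferred from the loop in forward() above; the authoritative construction
# lives in class LandmarkProposal in data_utils/transforms.py.
#   proposals[i][anchor_idx] is a sequence of bboxes with
#       bbox = [c, w, h, *relative_offsets, label]
#   where label >= 0              -> class id of an existing landmark,
#         -n_class <= label <= -1 -> missing landmark of class (-1 - label),
#         label == -100           -> padding; terminates the bbox list.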
--------------------------------------------------------------------------------
/code/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | def setgpu(gpus):
6 | if gpus=='all':
7 | gpus = '0,1,2,3'
8 | print('using gpu '+gpus)
9 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus
10 | return len(gpus.split(','))
11 |
12 |
13 | def metric(heatmap, spacing, landmarks):
14 | N = heatmap.shape[0]
15 | n_class = heatmap.shape[1]
16 | total_mre = []
17 | max_num = 500
18 | hits = np.zeros((8, n_class))
19 |
20 | for j in range(N):
21 | cur_mre_group = []
22 | for i in range(n_class):
23 | max_count = 0
24 | group_rate = 0.999
25 | if np.max(heatmap[j,i])>0:
26 | # lower the threshold until at least max_num voxels are selected
27 | while max_count < max_num:
28 | h_score_idxs = np.where(heatmap[j, i] >= np.max(heatmap[j, i])*group_rate)
29 | group_rate = group_rate - 0.1
30 | max_count = len(h_score_idxs[0])
31 | else:
32 | # no positive response: fall back to a loose threshold
33 | h_score_idxs = np.where(heatmap[j, i] >= np.max(heatmap[j, i]) * 1.5)
34 |
35 | h_predict_location = np.array(
36 | [np.mean(h_score_idxs[0]), np.mean(h_score_idxs[1]), np.mean(h_score_idxs[2])])
37 |
38 | cur_mre = np.linalg.norm(
39 | np.array(landmarks[j,i] - h_predict_location)*spacing, ord=2)
40 |
41 | if np.mean(landmarks[j, i])>0:
42 | cur_mre_group.append(cur_mre)
43 | hits[4:, i] += 1
44 | if cur_mre <= 2.0:
45 | hits[0, i] += 1
46 | if cur_mre <= 2.5:
47 | hits[1, i] += 1
48 | if cur_mre <= 3.:
49 | hits[2, i] += 1
50 | if cur_mre <= 4.:
51 | hits[3, i] += 1
52 | else:
53 | cur_mre_group.append(-1)
54 | total_mre.append(np.array(cur_mre_group))
55 |
56 | return total_mre, hits
57 |
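# --- annotation: a small helper sketch (not part of the original file). ---
# In `hits`, rows 0-3 count predictions within 2 / 2.5 / 3 / 4 mm and rows
# 4-7 each count the landmarks that actually exist, so the success detection
# rate (SDR) per class is an elementwise ratio:
def _sdr_from_hits(hits):
    with np.errstate(divide='ignore', invalid='ignore'):
        return hits[:4] / hits[4:]  # rows: SDR@2mm, @2.5mm, @3mm, @4mm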
58 |
59 | def min_distance_voting(landmarks):
60 | min_dis = 1000000
61 | min_landmark = landmarks[0]
62 | for landmark in landmarks:
63 | cur_dis = 0
64 | for sub_landmark in landmarks:
65 | cur_dis += np.linalg.norm(
66 | np.array(landmark - sub_landmark), ord=2)
67 | if cur_dis < min_dis:
68 | min_dis = cur_dis
69 | min_landmark = landmark
70 | return min_landmark
71 |
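# --- annotation: a minimal usage sketch (not part of the original file). ---
# min_distance_voting returns the candidate with the smallest summed L2
# distance to all other candidates (a medoid), which makes the final
# prediction robust to a few outlier proposals:
def _demo_min_distance_voting():
    pts = np.array([[0., 0., 0.], [1., 1., 1.], [50., 50., 50.]])
    print(min_distance_voting(pts))  # expected: [1. 1. 1.]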
72 |
73 | def metric_proposal(proposal_map, spacing,
74 | landmarks, shrink=4., anchors=[0.5, 1, 1.5, 2], n_class=14):
75 | # number of candidates selected to vote for one landmark;
76 | # can be fine-tuned according to the number of anchors
77 | select_number = 15
78 |
79 | batch_size = proposal_map.size(0)
80 | c = proposal_map.size(1)
81 | w = proposal_map.size(2)
82 | h = proposal_map.size(3)
83 | n_anchor = proposal_map.size(4)
84 | total_mre = []
85 | hits = np.zeros((8, n_class))
86 |
87 | for j in range(batch_size):
88 | cur_mre_group = []
89 | for idx in range(n_class):
90 | ####### from proposal map to landmarks (note: top-k runs over the whole batch; assumes batch_size == 1) #######
91 | proposal_map_vector = proposal_map[:,:,:,:,:,3+idx].reshape(-1)
92 | mask = torch.zeros_like(proposal_map_vector)
93 | _, cur_idx = torch.topk(
94 | proposal_map_vector, select_number)
95 | mask[cur_idx] = 1
96 | mask_tensor = mask.reshape((batch_size, c, w, h, n_anchor, -1))
97 | select_index = np.where(mask_tensor.cpu().numpy()==1)
98 |
99 | # get predicted position
100 | pred_pos = []
101 | for i in range(len(select_index[0])):
102 | cur_pos = []
103 | cur_batch = select_index[0][i]
104 | cur_c = select_index[1][i]
105 | cur_w = select_index[2][i]
106 | cur_h = select_index[3][i]
107 | cur_anchor = select_index[4][i]
108 | cur_predict = torch.tanh(proposal_map[cur_batch, cur_c, cur_w, cur_h, cur_anchor, :3]).cpu().numpy()
109 |
110 | cur_pos.append( (np.array([cur_c, cur_w, cur_h]) + cur_predict*anchors[cur_anchor])*shrink )
111 | pred_pos.append(cur_pos)
112 | pred_pos = np.array(pred_pos)
113 |
114 | cur_mre = np.linalg.norm(
115 | (np.array(landmarks[j,idx] - min_distance_voting(pred_pos)))*spacing[j], ord=2)
116 | if np.mean(landmarks[j, idx])>0:
117 | cur_mre_group.append(cur_mre)
118 | hits[4:, idx] += 1
119 | # count SDR hits only for existing landmarks (as in metric() above)
120 | if cur_mre <= 2.0:
121 | hits[0, idx] += 1
122 | if cur_mre <= 2.5:
123 | hits[1, idx] += 1
124 | if cur_mre <= 3.:
125 | hits[2, idx] += 1
126 | if cur_mre <= 4.:
127 | hits[3, idx] += 1
128 | else:
129 | # if the landmark does not exist, skip MRE/SDR and record -1
130 | cur_mre_group.append(-1)
131 | total_mre.append(np.array(cur_mre_group))
132 | return total_mre, hits
133 |
--------------------------------------------------------------------------------
/images/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/cover.png
--------------------------------------------------------------------------------
/images/problem1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/problem1.png
--------------------------------------------------------------------------------
/images/problem2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/problem2.png
--------------------------------------------------------------------------------
/images/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ithet1007/mmld_code/c060b775bc4e44622a51cbe64af0deeea47250ac/images/table.png
--------------------------------------------------------------------------------