├── README.md
├── config.py
└── model
    ├── UNet_S.py
    └── ABHG_FOR_UIC.py


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
The official implementation of "A Bayesian Network for Simultaneous Keyframe and Landmark Detection in Ultrasonic Cine".

---
# Citation
```
@article{FENG2024103228,
  title    = {A Bayesian network for simultaneous keyframe and landmark detection in ultrasonic cine},
  journal  = {Medical Image Analysis},
  volume   = {97},
  pages    = {103228},
  year     = {2024},
  issn     = {1361-8415},
  doi      = {https://doi.org/10.1016/j.media.2024.103228},
  url      = {https://www.sciencedirect.com/science/article/pii/S1361841524001531},
  author   = {Yong Feng and Jinzhu Yang and Meng Li and Lingzhi Tang and Song Sun and Yonghuai Wang},
  keywords = {Landmark detection, Keyframe detection, Bayesian network, Ultrasonic cine, Multi-task learning},
}
```

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import argparse
from distutils.util import strtobool  # note: distutils was removed from the standard library in Python 3.12

parser = argparse.ArgumentParser()

# Environment
parser.add_argument("--is_train", type=strtobool, default='true')
parser.add_argument("--tensorboard", type=strtobool, default='true')
parser.add_argument('--cpu', action='store_true', help='use cpu only')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument("--num_gpu", type=int, default=1)
parser.add_argument("--num_work", type=int, default=1)
parser.add_argument("--exp_dir", type=str, default="./ckpt")
parser.add_argument("--exp_load", type=str, default=None)

# Data
parser.add_argument("--data_dir", type=str, default="/mnt/sda")
parser.add_argument("--data_name", type=str, default="UIC")
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--rgb_range', type=int, default=1)

# Model
parser.add_argument('--uncertainty', default='normal',
                    choices=('normal', 'epistemic', 'aleatoric', 'combined'))
parser.add_argument('--in_channels', type=int, default=1)
parser.add_argument('--n_feats', type=int, default=32)
parser.add_argument('--var_weight', type=float, default=0.001)
parser.add_argument('--drop_rate', type=float, default=0.1)

# Train
parser.add_argument("--epochs", type=int, default=5000)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--decay", type=str, default='50-100-150-200')
parser.add_argument("--gamma", type=float, default=0.5)
parser.add_argument("--optimizer", type=str, default='adam',
                    choices=('sgd', 'adam', 'rmsprop'))
parser.add_argument("--weight_decay", type=float, default=1e-4)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--betas", type=tuple, default=(0.9, 0.999))
parser.add_argument("--epsilon", type=float, default=1e-8)

# Test
parser.add_argument('--n_samples', type=int, default=25)


def save_args(obj, defaults, kwargs):
    for k, v in defaults.items():
        if k in kwargs:
            v = kwargs[k]
        setattr(obj, k, v)


def get_config():
    config = parser.parse_args()
    return config
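
The repository ships no training or inference script, so the snippet below is only a minimal usage sketch of my own (not part of the repo): it parses an explicit flag list and runs the test-time forward pass of the two-stage model defined in `model/ABHG_FOR_UIC.py` further down, mirroring that file's `__main__` block. A CUDA device is assumed because the model files themselves move tensors to CUDA.

```python
# Minimal usage sketch (assumptions: run from the repository root, CUDA available).
import torch

from config import parser
from model.ABHG_FOR_UIC import UNET_ABHG

# Parse an explicit flag list instead of sys.argv; any flag defined above can be
# overridden here or on the command line.
cfg = parser.parse_args(["--drop_rate", "0.2", "--batch_size", "1"])

model = UNET_ABHG(cfg).cuda()
frame = torch.ones((1, 1, 256, 256), dtype=torch.float).cuda()  # one grayscale 256x256 frame

coarse, refined, var = model.forward_test(frame)
print(coarse.shape, refined.shape, var.shape)  # (1, 4, 2), (1, 1, 4, 2), (1, 1)
```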

--------------------------------------------------------------------------------
/model/UNet_S.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt

num_pt = 4  # number of landmarks


def make_model(args):
    return Unet(args)


class Unet(nn.Module):
    def __init__(self, config):
        super(Unet, self).__init__()
        filter_config = (32, 64, 128, 256)
        self.depth = len(filter_config)
        self.drop_rate = config.drop_rate
        in_channels = config.in_channels
        out_channels = num_pt

        self.encoders = nn.ModuleList()
        self.decoders_mean = nn.ModuleList()
        self.decoders_var = nn.ModuleList()

        # set up the number of conv-bn-relu blocks per module and the number of filters
        encoder_n_layers = (2, 2, 3, 3, 3)
        encoder_filter_config = (in_channels,) + filter_config
        decoder_n_layers = (3, 3, 3, 2, 1)
        decoder_filter_config = filter_config[::-1] + (filter_config[0],)

        self.bottom_conv = nn.Sequential(*[nn.Conv2d(256, 512, 3, 1, 1),
                                           nn.BatchNorm2d(512),
                                           nn.ReLU(),
                                           nn.Conv2d(512, 256, 3, 1, 1),
                                           nn.BatchNorm2d(256),
                                           nn.ReLU()])  # nn.ConvTranspose2d(512, 256, 3, 1, 1)

        for i in range(0, self.depth):
            # encoder architecture
            self.encoders.append(_Encoder(encoder_filter_config[i],
                                          encoder_filter_config[i + 1],
                                          encoder_n_layers[i]))

            # decoder architecture
            self.decoders_mean.append(_Decoder(decoder_filter_config[i],
                                               decoder_filter_config[i + 1],
                                               decoder_n_layers[i]))

        self.classifier_mean = nn.Conv2d(filter_config[0], 32, 3, 1, 1)
        self.classifier_mean1 = nn.Conv2d(32, 16, 3, 1, 1)
        self.classifier_mean2 = nn.Conv2d(16, out_channels, 1, 1)

    def forward(self, x):
        indices = []
        unpool_sizes = []
        feat = x
        feat_encoders = []
        # encoder path, keep track of pooling indices and feature sizes
        for i in range(0, self.depth):
            feat_ori, (feat, ind), size = self.encoders[i](feat)
            feat_encoders.append(feat_ori)
            # if i == 1:
            #     feat = F.dropout(feat, p=self.drop_rate, training=self.training)
            indices.append(ind)
            unpool_sizes.append(size)

        feat = self.bottom_conv(feat)
        feat_mean = feat
        feat_bottom = feat

        # decoder path, upsampling with the corresponding skip features
        for i in range(0, self.depth):
            feat_mean = self.decoders_mean[i](feat_mean, feat_encoders[self.depth - i - 1],
                                              indices[self.depth - 1 - i], unpool_sizes[self.depth - 1 - i])
            # feat_var = self.decoders_var[i](feat_var, indices[self.depth - 1 - i], unpool_sizes[self.depth - 1 - i])
            # if i == 0:
            #     feat_mean = F.dropout(feat_mean, p=self.drop_rate, training=True)
            #     feat_var = F.dropout(feat_var, p=self.drop_rate, training=True)

        output_mean = self.classifier_mean(feat_mean)
        output_mean1 = self.classifier_mean1(output_mean)
        output_mean2 = self.classifier_mean2(output_mean1)

        # spatial softmax: each of the num_pt channels becomes a normalised heatmap
        b, c, h, w = output_mean2.shape
        output_mean2 = F.softmax(output_mean2.view(b, c, -1), -1).view(b, c, h, w)

        results = {'mean': output_mean2, 'var': 0, 'feat': feat_encoders[0], 'bottom': feat_bottom}
        return results


class _Encoder(nn.Module):
    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        """Encoder layer follows VGG rules and keeps the max-pooling indices.

        Args:
            n_in_feat (int): number of input features
            n_out_feat (int): number of output features
            n_blocks (int): number of conv-batchnorm-relu blocks inside the encoder
        """
        super(_Encoder, self).__init__()

        layers = [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_out_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_out_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x):
        output = self.features(x)
        return output, F.max_pool2d(output, 2, 2, return_indices=True), output.size()


class _Decoder(nn.Module):
    """Decoder layer: upsamples with a transposed convolution and fuses the skip
    feature map of the corresponding encoder stage by channel concatenation.

    Args:
        n_in_feat (int): number of input features
        n_out_feat (int): number of output features
        n_blocks (int): number of conv-batchnorm-relu blocks inside the decoder
    """

    def __init__(self, n_in_feat, n_out_feat, n_blocks=2):
        super(_Decoder, self).__init__()

        self.up_conv = nn.ConvTranspose2d(n_in_feat, n_in_feat, 3, 2, 1, 1)

        layers = [nn.Conv2d(2 * n_in_feat, n_in_feat, 3, 1, 1),
                  nn.BatchNorm2d(n_in_feat),
                  nn.ReLU()]

        if n_blocks > 1:
            layers += [nn.Conv2d(n_in_feat, n_out_feat, 3, 1, 1),
                       nn.BatchNorm2d(n_out_feat),
                       nn.ReLU()]

        self.features = nn.Sequential(*layers)

    def forward(self, x, x_e, indices, size):
        x = torch.cat([x_e, self.up_conv(x)], 1)
        return self.features(x)


if __name__ == '__main__':
    import torchinfo
    from config import get_config

    BUNET = Unet(get_config())

    batch_size = 2
    torchinfo.summary(BUNET, input_size=(batch_size, 1, 256, 256))
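
A quick shape-check sketch (an addition for orientation, not part of the repository): with the default config, the four 2x2 max-pools reduce a 256x256 input to 16x16 at the bottleneck and the four stride-2 transposed convolutions restore full resolution, so `results['mean']` carries one 256x256 heatmap per landmark. The 32-channel `results['feat']` map is what the second stage later concatenates with the input frame (1 + 32 = 33 channels).

```python
# Shape-check sketch (assumes the repository root is on PYTHONPATH; CPU is enough here).
import torch
from config import parser
from model.UNet_S import Unet

net = Unet(parser.parse_args([]))       # only in_channels and drop_rate are read from the config
out = net(torch.randn(2, 1, 256, 256))

print(out['mean'].shape)    # torch.Size([2, 4, 256, 256])  spatially-normalised heatmaps, one per landmark
print(out['feat'].shape)    # torch.Size([2, 32, 256, 256]) first encoder features, reused by stage 2
print(out['bottom'].shape)  # torch.Size([2, 256, 16, 16])  bottleneck features
```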


--------------------------------------------------------------------------------
/model/ABHG_FOR_UIC.py:
--------------------------------------------------------------------------------
from model.UNet_S import Unet

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


num_pt = 4               # number of landmarks
num_patch = num_pt * 9   # 9 hyper-graph nodes (patches) per landmark
num_hyper = 16           # number of hyper-edges
patch_size = 8


def getcropedInputs(ROIs, inputs_origin, cropSize, useGPU=0):
    """Crop a (cropSize x cropSize) patch around each normalised landmark, zero-padding at the borders."""
    landmarks = ROIs.detach().cpu().numpy()
    landmarkNum = landmarks.shape[1]
    b, c, h, w = inputs_origin.size()

    cropSize = int(cropSize / 2)

    X, Y = landmarks[:, :, 0], landmarks[:, :, 1] + 1

    X, Y = np.round(X * (h - 1)).astype("int"), np.round(Y * (w - 1)).astype("int")

    cropedDICOMs = []
    for landmarkId in range(landmarkNum):
        x, y = X[:, landmarkId].clip(0, 255), Y[:, landmarkId].clip(0, 255)
        lx, ux, ly, uy = x - cropSize, x + cropSize, y - cropSize, y + cropSize
        lxx, uxx, lyy, uyy = (np.where(lx > 0, lx, 0), np.where(ux < h, ux, h),
                              np.where(ly > 0, ly, 0), np.where(uy < w, uy, w))
        # lxx, uxx, lyy, uyy = np.clip(lx, 0, 255), np.clip(ux, 0, 255), np.clip(ly, 0, 255), np.clip(uy, 0, 255)
        for b_id in range(b):
            cropedDICOM = inputs_origin[b_id:b_id + 1, :, lxx[b_id]: uxx[b_id], lyy[b_id]: uyy[b_id]]
            # ~ print ("check before", cropedDICOM.size())
            if lx[b_id] < 0:
                _, _, curentX, curentY = cropedDICOM.size()
                temTensor = torch.zeros(1, c, 0 - lx[b_id], curentY)
                if useGPU >= 0: temTensor = temTensor.cuda(useGPU)
                cropedDICOM = torch.cat((temTensor, cropedDICOM), 2)
            if ux[b_id] > h:
                _, _, curentX, curentY = cropedDICOM.size()
                temTensor = torch.zeros(1, c, ux[b_id] - h, curentY)
                if useGPU >= 0: temTensor = temTensor.cuda(useGPU)
                cropedDICOM = torch.cat((cropedDICOM, temTensor), 2)
            if ly[b_id] < 0:
                _, _, curentX, curentY = cropedDICOM.size()
                temTensor = torch.zeros(1, c, curentX, 0 - ly[b_id])
                if useGPU >= 0: temTensor = temTensor.cuda(useGPU)
                cropedDICOM = torch.cat((temTensor, cropedDICOM), 3)
            if uy[b_id] > w:
                _, _, curentX, curentY = cropedDICOM.size()
                temTensor = torch.zeros(1, c, curentX, uy[b_id] - w)
                if useGPU >= 0: temTensor = temTensor.cuda(useGPU)
                cropedDICOM = torch.cat((cropedDICOM, temTensor), 3)

            cropedDICOMs.append(cropedDICOM)

    # regroup the per-landmark, per-sample crops back into a batch: (B, landmarkNum, C, cropSize, cropSize)
    b_crops = []
    for i in range(b):
        croped_i = torch.stack(cropedDICOMs[i::b], 1)
        b_crops.append(croped_i)
    b_crops = torch.cat(b_crops, 0)

    return b_crops


def get_hg_node_features(d, landmarks, batch_size):
    if landmarks.shape[0] != batch_size:
        landmarks = landmarks.repeat(batch_size, 1, 1)  # B, num_lands, 2

    shifts = torch.tensor([[-patch_size, 0], [-patch_size, patch_size], [0, patch_size], [patch_size, patch_size],
                           [patch_size, 0], [patch_size, -patch_size], [0, patch_size], [-patch_size, -patch_size],
                           [0, 0]]).view(9, 2).to(landmarks.device)
    shifts = torch.true_divide(shifts, 255)

    # build the hyper-node landmarks: 9 shifted copies of every landmark, plus their image patches
    visual_features = []
    landmarks_hyper = torch.repeat_interleave(landmarks, 9, 1)

    for i in range(landmarks.shape[1]):
        for j in range(9):
            landmarks_hyper[:, i * j, :] = landmarks_hyper[:, i * j, :] + shifts[j]
            visual_features.append(getcropedInputs(landmarks_hyper[:, i * j, :].unsqueeze(1), d, cropSize=patch_size))

    visual_feature = torch.cat(visual_features, 1)
    # pairwise offsets between hyper-nodes serve as the shape feature
    init_landmark = landmarks_hyper[:, None, :, :] - landmarks_hyper[:, :, None, :]
    shape_feature = init_landmark.reshape(batch_size, landmarks_hyper.shape[1], -1)

    return visual_feature, shape_feature, landmarks_hyper


class ABHG(nn.Module):
    def __init__(self, in_ch, out_ch, num_patch=num_patch):
        super(ABHG, self).__init__()
        self.num_hyper = num_hyper
        # H: node-to-hyperedge incidence, T: node-to-landmark aggregation, W: hyperedge weights
        self.H = torch.nn.Parameter((torch.ones((num_patch, self.num_hyper), requires_grad=True) / num_patch),
                                    requires_grad=True)
        self.T = torch.nn.Parameter((torch.ones((num_pt, num_patch), requires_grad=True) / num_patch),
                                    requires_grad=True)
        self.W = torch.nn.Parameter(
            (torch.ones((self.num_hyper), requires_grad=True).view(1, self.num_hyper)) / self.num_hyper,
            requires_grad=True)
        self.linear1 = nn.Linear(in_ch, 2)
        self.relu1 = nn.LeakyReLU(inplace=True)
        self.linear2 = nn.Linear(in_ch, 9)

        self.H = nn.init.normal_(self.H)
        self.T = nn.init.normal_(self.T)

        self.shifts = torch.tensor([[-patch_size, 0], [-patch_size, patch_size], [0, patch_size], [patch_size, patch_size],
                                    [patch_size, 0], [patch_size, -patch_size], [0, patch_size], [-patch_size, -patch_size],
                                    [0, 0]]).view(9, 2)
        self.shifts = torch.true_divide(self.shifts, 255)

    def forward(self, node_feat):
        self.shifts = self.shifts.to(node_feat.device)

        nd = node_feat
        M1 = self.T @ self.H @ torch.diag(self.W[0]) @ torch.t(self.H)

        message = torch.matmul(M1, nd)
        x1 = F.dropout(message, p=0.1, training=True)  # MC dropout: kept active at test time
        x1 = self.linear1(x1)
        offset = torch.sigmoid(x1)  # B, 4, 2  step magnitude

        x2 = F.dropout(message, p=0.1, training=True)
        x2 = self.linear2(x2)  # B, 4, 9
        direction = F.softmax(x2, 2)  # attention over the 9 candidate shift directions

        final_off = offset * (direction @ self.shifts)

        return final_off
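
# Note (added for clarity; not part of the original source). With node features
# N of shape (B, 36, in_ch), the ABHG forward pass above computes
#
#   M     = T @ H @ diag(w) @ H.T              # (4, 36)   landmark <- hyper-node message matrix
#   msg   = M @ N                               # (B, 4, in_ch)
#   step  = sigmoid(linear1(dropout(msg)))      # (B, 4, 2)  per-landmark step magnitude
#   attn  = softmax(linear2(dropout(msg)), -1)  # (B, 4, 9)  attention over the 9 candidate shifts S
#   delta = step * (attn @ S)                   # (B, 4, 2)  coordinate update returned as final_off
#
# i.e. each landmark is moved by an attention-weighted mix of the nine shift
# directions, scaled by a learned, sigmoid-bounded step size.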


class Coord_fine(nn.Module):
    def __init__(self, steps=1, h_dim=972 + 2):
        super(Coord_fine, self).__init__()
        self.steps = steps

        self.hg = ABHG(h_dim, 2)

        self.patch_conv = nn.Conv2d(33, 33, patch_size)
        self.h_dim = 33  # +3
        self.coors = []

    def forward(self, d, x):
        bs = x.shape[0]
        updated_landmarks = x
        coors = []
        for step in range(self.steps):
            patch_feature, shape_feature, landmarks = get_hg_node_features(d, updated_landmarks, bs)
            b, num_land = patch_feature.shape[0], patch_feature.shape[1]
            patch_feature = patch_feature.view(-1, 32 + 1, patch_size, patch_size)
            patch_feature = self.patch_conv(patch_feature).view(b, num_land, self.h_dim)
            gin_feature = torch.cat([patch_feature, shape_feature, landmarks], -1)
            shift = self.hg(gin_feature)
            updated_landmarks = updated_landmarks + shift
            coors.append(updated_landmarks * 255)
        self.coors = coors

        return updated_landmarks * 255, self.coors


class UNET_ABHG(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.s1 = Unet(config)
        self.s2 = Coord_fine(h_dim=33 + num_pt * 9 * 2 + 2, steps=1)
        self.pt_num = num_pt

    def get_coordinates_from_coarse_heatmaps(self, predicted_heatmap):
        # soft-argmax: each landmark coordinate is the expectation of the pixel grid
        # (normalised to [0, 1]) under its spatially-normalised heatmap
        global_coordinate = torch.ones(256, 256, 2).float()
        for i in range(256):
            global_coordinate[i, :, 0] = global_coordinate[i, :, 0] * i
        for i in range(256):
            global_coordinate[:, i, 1] = global_coordinate[:, i, 1] * i
        global_coordinate = (global_coordinate * torch.tensor([1 / (256 - 1), 1 / (256 - 1)])).to(
            predicted_heatmap.device)

        num_pt = predicted_heatmap.shape[1]
        bs = predicted_heatmap.shape[0]
        global_coordinate_permute = global_coordinate.permute(2, 0, 1).unsqueeze(0)
        predict = [
            torch.sum((global_coordinate_permute * predicted_heatmap[:, i:i + 1]).view(bs, 2, -1), dim=-1).unsqueeze(1)
            for i in range(num_pt)]
        predict = torch.cat(predict, dim=1)
        return predict

    def forward_train(self, frame):
        # coarse location stage
        s1_results = self.s1(frame)
        frame_feats, global_heatmap = s1_results['feat'], s1_results['mean']
        global_coordinate = self.get_coordinates_from_coarse_heatmaps(global_heatmap)

        # add noise to the coarse coordinates so that ABHG is trained to correct imperfect estimates
        offset = torch.from_numpy(np.random.normal(loc=0, scale=patch_size / 256 / 3,
                                                   size=global_coordinate.size())).float().cuda()
        global_coordinate_ = global_coordinate + offset

        frame_feats_kf = torch.cat([frame, frame_feats], 1)

        outputs, coords_kf = self.s2(frame_feats_kf, global_coordinate_)

        return global_coordinate, outputs, global_heatmap, coords_kf, s1_results['bottom']

    def forward_test(self, frame):
        # coarse location stage
        s1_results = self.s1(frame)
        frame_feats, global_heatmap = s1_results['feat'], s1_results['mean']
        frame_feats = torch.cat([frame, frame_feats], 1)
        global_coordinate = self.get_coordinates_from_coarse_heatmaps(global_heatmap)

        # ABHG fine-tuning
        local_results_t = []
        # here are some differences from the training phase
        for i in range(2):  # number of MC-dropout samples
            outputs_series, coords = self.s2(frame_feats.view(-1, 33, 256, 256), global_coordinate)
            local_results_t.append(outputs_series.view(-1, 1, num_pt, 2))

        # variance across the MC-dropout samples serves as an uncertainty estimate
        var = torch.var(torch.stack(local_results_t), 0).sum((2, 3))
        coords_series = torch.stack(local_results_t).mean(0)
        coords_series = coords_series.view(-1, 2 * num_pt, 1, 1)

        return global_coordinate * 255, coords_series.view(1, -1, num_pt, 2), var


if __name__ == '__main__':
    from config import get_config

    model = UNET_ABHG(get_config()).cuda()
    frame = torch.ones((1, 1, 256, 256), dtype=torch.float).cuda()

    model.forward_test(frame)
--------------------------------------------------------------------------------
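
Because no training loop ships with the repository, the sketch below (an addition of mine; it assumes a CUDA device and that the repository root is on PYTHONPATH) only illustrates the training-time interface of `UNET_ABHG.forward_train` and the shapes it returns; the paper's loss terms are not reproduced here.

```python
# Training-interface sketch (assumptions: CUDA available, run from the repository root).
import torch
from config import parser
from model.ABHG_FOR_UIC import UNET_ABHG

model = UNET_ABHG(parser.parse_args([])).cuda()
frame = torch.randn(2, 1, 256, 256).cuda()

coarse, refined, heatmaps, refined_steps, bottom = model.forward_train(frame)
print(coarse.shape)    # torch.Size([2, 4, 2])        coarse landmark coordinates in [0, 1]
print(refined.shape)   # torch.Size([2, 4, 2])        ABHG-refined coordinates (scaled by 255)
print(heatmaps.shape)  # torch.Size([2, 4, 256, 256]) coarse-stage heatmaps
print(bottom.shape)    # torch.Size([2, 256, 16, 16]) bottleneck features from the first stage
print(len(refined_steps))  # one entry per refinement step (steps=1 by default)
```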