├── 3D_LessNet ├── L2ss_2_Chan_12_Smth_5.0_LR_0.0001_Val │ └── DiceVal_0.75091_Epoch_000000490.pth ├── L2ss_2_Chan_16_Smth_5.0_LR_0.0001_Val │ └── DiceVal_0.75407_Epoch_000000416.pth ├── L2ss_2_Chan_24_Smth_5.0_LR_0.0001_Val │ └── DiceVal_0.75697_Epoch_000000467.pth ├── L2ss_2_Chan_8_Smth_5.0_LR_0.0001_Val │ └── DiceVal_0.74657_Epoch_000000494.pth ├── Models.py ├── compute_dsc_jet_from_quantiResult.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── data_utils.cpython-38.pyc │ │ ├── datasets.cpython-38.pyc │ │ ├── rand.cpython-38.pyc │ │ └── trans.cpython-38.pyc │ ├── data_utils.py │ ├── datasets.py │ ├── rand.py │ └── trans.py ├── infer_bilinear.py ├── label_info.txt ├── train.py └── utils.py ├── 3D_LessNet_Diff ├── L2ss_2_Chan_12_Smth_2.0_LR_0.0001_Val │ └── DiceVal_0.75488_Epoch_000000348.pth ├── L2ss_2_Chan_16_Smth_2.0_LR_0.0001_Val │ └── DiceVal_0.75676_Epoch_000000423.pth ├── L2ss_2_Chan_24_Smth_2.0_LR_0.0001_Val │ └── DiceVal_0.75726_Epoch_000000408.pth ├── L2ss_2_Chan_8_Smth_2.0_LR_0.0001_Val │ └── DiceVal_0.74972_Epoch_000000498.pth ├── Models.py ├── compute_dsc_jet_from_quantiResult.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── data_utils.cpython-38.pyc │ │ ├── datasets.cpython-38.pyc │ │ ├── rand.cpython-38.pyc │ │ └── trans.cpython-38.pyc │ ├── data_utils.py │ ├── datasets.py │ ├── rand.py │ └── trans.py ├── infer_bilinear.py ├── label_info.txt ├── train.py └── utils.py └── README.md /3D_LessNet/L2ss_2_Chan_12_Smth_5.0_LR_0.0001_Val/DiceVal_0.75091_Epoch_000000490.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/L2ss_2_Chan_12_Smth_5.0_LR_0.0001_Val/DiceVal_0.75091_Epoch_000000490.pth -------------------------------------------------------------------------------- 
/3D_LessNet/L2ss_2_Chan_16_Smth_5.0_LR_0.0001_Val/DiceVal_0.75407_Epoch_000000416.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/L2ss_2_Chan_16_Smth_5.0_LR_0.0001_Val/DiceVal_0.75407_Epoch_000000416.pth -------------------------------------------------------------------------------- /3D_LessNet/L2ss_2_Chan_24_Smth_5.0_LR_0.0001_Val/DiceVal_0.75697_Epoch_000000467.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/L2ss_2_Chan_24_Smth_5.0_LR_0.0001_Val/DiceVal_0.75697_Epoch_000000467.pth -------------------------------------------------------------------------------- /3D_LessNet/L2ss_2_Chan_8_Smth_5.0_LR_0.0001_Val/DiceVal_0.74657_Epoch_000000494.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/L2ss_2_Chan_8_Smth_5.0_LR_0.0001_Val/DiceVal_0.74657_Epoch_000000494.pth -------------------------------------------------------------------------------- /3D_LessNet/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | 7 | 8 | class UNet(nn.Module): 9 | def __init__(self, in_channel, n_classes, start_channel): 10 | self.in_channel = in_channel 11 | self.n_classes = n_classes 12 | self.start_channel = start_channel 13 | 14 | bias_opt = True 15 | 16 | super(UNet, self).__init__() 17 | self.eninput = self.encoder(self.in_channel, self.start_channel * 4, bias=bias_opt) 18 | 19 | self.dc1 = self.encoder(self.start_channel * 4, self.start_channel * 4, kernel_size=3, 20 | stride=1, bias=bias_opt) 21 | self.dc2 = 
self.encoder(self.start_channel * 4, self.start_channel * 4, kernel_size=3, stride=1, bias=bias_opt) 22 | self.dc3 = self.encoder(self.start_channel * 3 + 6, self.start_channel * 3, kernel_size=3, 23 | stride=1, bias=bias_opt) 24 | self.dc4 = self.encoder(self.start_channel * 3, self.start_channel * 3, kernel_size=3, stride=1, bias=bias_opt) 25 | self.dc5 = self.encoder(self.start_channel * 2 + 6, self.start_channel * 2, kernel_size=3, 26 | stride=1, bias=bias_opt) 27 | self.dc6 = self.encoder(self.start_channel * 2, self.start_channel * 2, kernel_size=3, stride=1, bias=bias_opt) 28 | self.dc7 = self.encoder(self.start_channel * 1 + 2, self.start_channel * 1, kernel_size=3, 29 | stride=1, bias=bias_opt) 30 | self.dc8 = self.encoder(self.start_channel * 1, self.start_channel * 1, kernel_size=3, stride=1, bias=bias_opt) 31 | self.dc9 = self.outputs(self.start_channel * 1, self.n_classes, kernel_size=3, stride=1, padding=1, bias=False) 32 | 33 | self.up2 = self.decoder(self.start_channel * 4, self.start_channel * 3) 34 | self.up3 = self.decoder(self.start_channel * 3, self.start_channel * 2) 35 | self.up4 = self.decoder(self.start_channel * 2, self.start_channel * 1) 36 | 37 | self.maxp_1_2 = nn.MaxPool3d(2, stride=2) 38 | self.maxp_1_4 = nn.MaxPool3d(4, stride=4) 39 | self.maxp_1_8 = nn.MaxPool3d(8, stride=8) 40 | self.avgp_1_2 = nn.AvgPool3d(2, stride=2) 41 | self.avgp_1_4 = nn.AvgPool3d(4, stride=4) 42 | self.avgp_1_8 = nn.AvgPool3d(8, stride=8) 43 | 44 | def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, batchnorm=False): 45 | if batchnorm: 46 | layer = nn.Sequential( 47 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 48 | nn.BatchNorm3d(out_channels), 49 | nn.LeakyReLU()) 50 | else: 51 | layer = nn.Sequential( 52 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 53 | nn.LeakyReLU()) 54 | return layer 55 | 56 | def decoder(self, 
in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True): 57 | layer = nn.Sequential( 58 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, 59 | padding=padding, output_padding=output_padding, bias=bias), 60 | nn.LeakyReLU()) 61 | return layer 62 | 63 | def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 64 | bias=False, batchnorm=False): 65 | if batchnorm: 66 | layer = nn.Sequential( 67 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 68 | nn.BatchNorm2d(out_channels), 69 | nn.Tanh()) 70 | else: 71 | layer = nn.Sequential( 72 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 73 | nn.Softsign()) 74 | return layer 75 | 76 | 77 | def forward(self, x, y): 78 | x_avg_1_8 = self.avgp_1_8(x) 79 | y_avg_1_8 = self.avgp_1_8(y) 80 | x_max_1_8 = self.maxp_1_8(x) 81 | y_max_1_8 = self.maxp_1_8(y) 82 | x_min_1_8 = -self.maxp_1_8(-x) 83 | y_min_1_8 = -self.maxp_1_8(-y) 84 | x_in = torch.cat((x_avg_1_8, y_avg_1_8, x_max_1_8, y_max_1_8, x_min_1_8, y_min_1_8), 1) 85 | d0 = self.eninput(x_in) 86 | 87 | # x_1 = torch.nn.functional.interpolate(x, size=[40, 48], mode='bilinear') 88 | # y_1 = torch.nn.functional.interpolate(y, size=[40, 48], mode='bilinear') 89 | # x_1_minus = torch.nn.functional.interpolate(x-y, size=[40, 48], mode='bilinear') 90 | # y_1_minus = torch.nn.functional.interpolate(y-x, size=[40, 48], mode='bilinear') 91 | # x_1_minus = torch.nn.functional.interpolate(x-y, size=[40, 48], mode='bilinear') 92 | # y_1_minus = torch.nn.functional.interpolate(y-x, size=[40, 48], mode='bilinear') 93 | # e0 = self.ec1(e0) 94 | 95 | # e1 = self.ec2(e0) 96 | # e1 = self.ec3(e1) 97 | 98 | # e2 = self.ec4(e1) 99 | # e2 = self.ec5(e2) 100 | 101 | # e3 = self.ec6(e2) 102 | # e3 = self.ec7(e3) 103 | 104 | # e4 = self.ec8(e3) 105 | # e4 = self.ec9(e4) 106 | 107 | # d0 = torch.cat((self.up1(e4), e3), 1) 108 | 
109 | d0 = self.dc1(d0) 110 | d0 = self.dc2(d0) 111 | 112 | d1 = self.up2(d0) 113 | x_avg_1_4 = self.avgp_1_4(x) 114 | y_avg_1_4 = self.avgp_1_4(y) 115 | x_max_1_4 = self.maxp_1_4(x) 116 | y_max_1_4 = self.maxp_1_4(y) 117 | x_min_1_4 = -self.maxp_1_4(-x) 118 | y_min_1_4 = -self.maxp_1_4(-y) 119 | # print(x_avg_1_4.shape) 120 | # print(x_avg_1_4[0,0,20:30,20:30], x_max_1_4[0,0,20:30,20:30], x_min_1_4[0,0,20:30,20:30]) 121 | # assert 0==1 122 | # x_in = torch.cat((), 1) 123 | d1 = torch.cat((d1, x_avg_1_4, y_avg_1_4, x_max_1_4, y_max_1_4, x_min_1_4, y_min_1_4),1) 124 | 125 | d1 = self.dc3(d1) 126 | d1 = self.dc4(d1) 127 | d2 = self.up3(d1) 128 | 129 | 130 | # x_1_2 = torch.nn.functional.interpolate(x, size=[80, 96], mode='bilinear') 131 | # y_1_2 = torch.nn.functional.interpolate(y, size=[80, 96], mode='bilinear') 132 | # x_1_2_ = torch.nn.functional.interpolate(x-y, size=[80, 96], mode='bilinear') 133 | # y_1_2_ = torch.nn.functional.interpolate(y-x, size=[80, 96], mode='bilinear') 134 | x_avg_1_2 = self.avgp_1_2(x) 135 | y_avg_1_2 = self.avgp_1_2(y) 136 | x_max_1_2 = self.maxp_1_2(x) 137 | y_max_1_2 = self.maxp_1_2(y) 138 | x_min_1_2 = -self.maxp_1_2(-x) 139 | y_min_1_2 = -self.maxp_1_2(-y) 140 | d2 = torch.cat((d2, x_avg_1_2, y_avg_1_2, x_max_1_2, y_max_1_2, x_min_1_2, y_min_1_2), 1) 141 | 142 | 143 | d2 = self.dc5(d2) 144 | d2 = self.dc6(d2) 145 | d3 = self.up4(d2) 146 | 147 | d3 = torch.cat((d3, x, y), 1) 148 | 149 | d3 = self.dc7(d3) 150 | d3 = self.dc8(d3) 151 | 152 | f_xy = self.dc9(d3) 153 | 154 | 155 | return f_xy 156 | 157 | 158 | class SpatialTransform(nn.Module): 159 | def __init__(self): 160 | super(SpatialTransform, self).__init__() 161 | def forward(self, mov_image, flow, mod = 'bilinear'): 162 | d2, h2, w2 = mov_image.shape[-3:] 163 | 164 | grid_d, grid_h, grid_w = torch.meshgrid([torch.linspace(-1, 1, d2), torch.linspace(-1, 1, h2), torch.linspace(-1, 1, w2)]) 165 | 166 | grid_h = grid_h.to(flow.device).float() 167 | grid_d = 
grid_d.to(flow.device).float() 168 | grid_w = grid_w.to(flow.device).float() 169 | grid_d = nn.Parameter(grid_d, requires_grad=False) 170 | grid_w = nn.Parameter(grid_w, requires_grad=False) 171 | grid_h = nn.Parameter(grid_h, requires_grad=False) 172 | flow_d = flow[:,:,:,:,0] 173 | flow_h = flow[:,:,:,:,1] 174 | flow_w = flow[:,:,:,:,2] 175 | #Softsign 176 | #disp_d = (grid_d + (flow_d * 2 / d2)).squeeze(1) 177 | #disp_h = (grid_h + (flow_h * 2 / h2)).squeeze(1) 178 | #disp_w = (grid_w + (flow_w * 2 / w2)).squeeze(1) 179 | 180 | # Remove Channel Dimension 181 | disp_d = (grid_d + (flow_d)).squeeze(1) 182 | disp_h = (grid_h + (flow_h)).squeeze(1) 183 | disp_w = (grid_w + (flow_w)).squeeze(1) 184 | sample_grid = torch.stack((disp_w, disp_h, disp_d), 4) # shape (N, D, H, W, 3) 185 | warped = torch.nn.functional.grid_sample(mov_image, sample_grid, mode = mod, align_corners = True) 186 | 187 | return warped 188 | 189 | class DiffeomorphicTransform(nn.Module): 190 | def __init__(self, time_step=7): 191 | super(DiffeomorphicTransform, self).__init__() 192 | self.time_step = time_step 193 | 194 | def forward(self, flow): 195 | 196 | # print(flow.shape) 197 | d2, h2, w2 = flow.shape[-3:] 198 | grid_d, grid_h, grid_w = torch.meshgrid([torch.linspace(-1, 1, d2), torch.linspace(-1, 1, h2), torch.linspace(-1, 1, w2)]) 199 | grid_h = grid_h.to(flow.device).float() 200 | grid_d = grid_d.to(flow.device).float() 201 | grid_w = grid_w.to(flow.device).float() 202 | grid_d = nn.Parameter(grid_d, requires_grad=False) 203 | grid_w = nn.Parameter(grid_w, requires_grad=False) 204 | grid_h = nn.Parameter(grid_h, requires_grad=False) 205 | flow = flow / (2 ** self.time_step) 206 | 207 | for i in range(self.time_step): 208 | flow_d = flow[:,0,:,:,:] 209 | flow_h = flow[:,1,:,:,:] 210 | flow_w = flow[:,2,:,:,:] 211 | disp_d = (grid_d + flow_d).squeeze(1) 212 | disp_h = (grid_h + flow_h).squeeze(1) 213 | disp_w = (grid_w + flow_w).squeeze(1) 214 | 215 | deformation = torch.stack((disp_w, 
disp_h, disp_d), 4) # shape (N, D, H, W, 3) 216 | flow = flow + torch.nn.functional.grid_sample(flow, deformation, mode='bilinear', padding_mode="border", align_corners = True) 217 | return flow 218 | 219 | def smoothloss(y_pred): 220 | #print('smoothloss y_pred.shape ',y_pred.shape) 221 | #[N,3,D,H,W] 222 | d2, h2, w2 = y_pred.shape[-3:] 223 | dy = torch.abs(y_pred[:,:,1:, :, :] - y_pred[:,:, :-1, :, :]) / 2 * d2 224 | dx = torch.abs(y_pred[:,:,:, 1:, :] - y_pred[:,:, :, :-1, :]) / 2 * h2 225 | dz = torch.abs(y_pred[:,:,:, :, 1:] - y_pred[:,:, :, :, :-1]) / 2 * w2 226 | return (torch.mean(dx * dx)+torch.mean(dy*dy)+torch.mean(dz*dz))/3.0 227 | 228 | """ 229 | Normalized local cross-correlation function in Pytorch. Modified from https://github.com/voxelmorph/voxelmorph. 230 | """ 231 | class NCC(torch.nn.Module): 232 | """ 233 | local (over window) normalized cross correlation 234 | """ 235 | def __init__(self, win= 9, eps=1e-5): 236 | super(NCC, self).__init__() 237 | self.win_raw = win 238 | self.eps = eps 239 | self.win = win 240 | 241 | def forward(self, I, J): 242 | ndims = 3 243 | win_size = self.win_raw 244 | self.win = [self.win_raw] * ndims 245 | 246 | weight_win_size = self.win_raw 247 | weight = torch.ones((1, 1, weight_win_size, weight_win_size, weight_win_size), device=I.device, requires_grad=False) 248 | conv_fn = F.conv3d 249 | 250 | # compute CC squares 251 | I2 = I*I 252 | J2 = J*J 253 | IJ = I*J 254 | 255 | # compute filters 256 | # compute local sums via convolution 257 | I_sum = conv_fn(I, weight, padding=int(win_size/2)) 258 | J_sum = conv_fn(J, weight, padding=int(win_size/2)) 259 | I2_sum = conv_fn(I2, weight, padding=int(win_size/2)) 260 | J2_sum = conv_fn(J2, weight, padding=int(win_size/2)) 261 | IJ_sum = conv_fn(IJ, weight, padding=int(win_size/2)) 262 | 263 | # compute cross correlation 264 | win_size = np.prod(self.win) 265 | u_I = I_sum/win_size 266 | u_J = J_sum/win_size 267 | 268 | cross = IJ_sum - u_J*I_sum - u_I*J_sum + 
u_I*u_J*win_size 269 | I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size 270 | J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size 271 | 272 | cc = cross * cross / (I_var * J_var + self.eps) 273 | 274 | # return negative cc. 275 | return -1.0 * torch.mean(cc) 276 | 277 | class MSE: 278 | """ 279 | Mean squared error loss. 280 | """ 281 | 282 | def loss(self, y_true, y_pred): 283 | return torch.mean((y_true - y_pred) ** 2) 284 | 285 | class SAD: 286 | """ 287 | Mean squared error loss. 288 | """ 289 | 290 | def loss(self, y_true, y_pred): 291 | return torch.mean(torch.abs(y_true - y_pred)) 292 | -------------------------------------------------------------------------------- /3D_LessNet/compute_dsc_jet_from_quantiResult.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv, sys 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import matplotlib.font_manager as font_manager 6 | from scipy.stats import wilcoxon, ttest_rel, ttest_ind 7 | 8 | outstruct = ['Brain-Stem', 'Thalamus', 'Cerebellum-Cortex', 'Cerebral-White-Matter', 'Cerebellum-White-Matter', 'Putamen', 'VentralDC', 'Pallidum', 'Caudate', 'Lateral-Ventricle', 'Hippocampus', 9 | '3rd-Ventricle', '4th-Ventricle', 'Amygdala', 'Cerebral-Cortex', 'CSF', 'choroid-plexus'] 10 | exp_data = np.zeros((len(outstruct), 115)) 11 | stct_i = 0 12 | file_dir = './Quantitative_Results/' 13 | for stct in outstruct: 14 | tar_idx = [] 15 | with open(file_dir+'L2ss_2_Chan_16_Smth_5.0_LR_0.0001_Test_Bilinear.csv', "r") as f: 16 | reader = csv.reader(f, delimiter="\t") 17 | for i, line in enumerate(reader): 18 | if i == 1: 19 | names = line[0].split(',') 20 | idx = 0 21 | for item in names: 22 | if stct in item: 23 | tar_idx.append(idx) 24 | idx += 1 25 | elif i>1: 26 | if line[0].split(',')[1]=='': 27 | continue 28 | val = 0 29 | for lr_i in tar_idx: 30 | vals = line[0].split(',') 31 | val += float(vals[lr_i]) 32 | val = val/len(tar_idx) 33 | 
exp_data[stct_i, i-2] = val 34 | stct_i+=1 35 | # all_dsc.append(exp_data.mean(axis=0)) 36 | print(exp_data.mean()) 37 | print(exp_data.std()) 38 | my_list = [] 39 | with open(file_dir+'L2ss_2_Chan_16_Smth_5.0_LR_0.0001_Test_Bilinear.csv', newline='') as f: 40 | reader = csv.reader(f) 41 | my_list = [row[-1] for row in reader] 42 | my_list = my_list[2:] 43 | my_list = np.array([float(i) for i in my_list])*100 44 | print('jec_det: {:.3f} +- {:.3f}'.format(my_list.mean(), my_list.std())) 45 | -------------------------------------------------------------------------------- /3D_LessNet/data/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np -------------------------------------------------------------------------------- /3D_LessNet/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet/data/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/data/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet/data/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/data/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet/data/__pycache__/rand.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/data/__pycache__/rand.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet/data/__pycache__/trans.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet/data/__pycache__/trans.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | import numpy as np 4 | import torch 5 | 6 | M = 2 ** 32 - 1 7 | 8 | 9 | def init_fn(worker): 10 | seed = torch.LongTensor(1).random_().item() 11 | seed = (seed + worker) % M 12 | np.random.seed(seed) 13 | random.seed(seed) 14 | 15 | 16 | def add_mask(x, mask, dim=1): 17 | mask = mask.unsqueeze(dim) 18 | shape = list(x.shape); 19 | shape[dim] += 21 20 | new_x = x.new(*shape).zero_() 21 | new_x = new_x.scatter_(dim, mask, 1.0) 22 | s = [slice(None)] * len(shape) 23 | s[dim] = slice(21, None) 24 | new_x[s] = x 25 | return new_x 26 | 27 | 28 | def sample(x, size): 29 | # https://gist.github.com/yoavram/4134617 30 | i = random.sample(range(x.shape[0]), size) 31 | return torch.tensor(x[i], dtype=torch.int16) 32 | # x = np.random.permutation(x) 33 | # return torch.tensor(x[:size]) 34 | 35 | 36 | def pkload(fname): 37 | with open(fname, 'rb') as f: 38 | return pickle.load(f) 39 | 40 | 41 | _shape = (240, 240, 155) 42 | 43 | 44 | def get_all_coords(stride): 45 | return torch.tensor( 46 | np.stack([v.reshape(-1) for v in 47 | np.meshgrid( 48 | *[stride // 2 + np.arange(0, s, stride) for s in _shape], 49 | indexing='ij')], 50 | -1), dtype=torch.int16) 51 | 52 | 53 | _zero = 
torch.tensor([0]) 54 | 55 | 56 | def gen_feats(): 57 | x, y, z = 240, 240, 155 58 | feats = np.stack( 59 | np.meshgrid( 60 | np.arange(x), np.arange(y), np.arange(z), 61 | indexing='ij'), -1).astype('float32') 62 | shape = np.array([x, y, z]) 63 | feats -= shape / 2.0 64 | feats /= shape 65 | 66 | return feats -------------------------------------------------------------------------------- /3D_LessNet/data/datasets.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import torch, sys 3 | from torch.utils.data import Dataset 4 | from .data_utils import pkload 5 | import matplotlib.pyplot as plt 6 | 7 | import numpy as np 8 | 9 | 10 | class IXIBrainDataset(Dataset): 11 | def __init__(self, data_path, atlas_path, transforms): 12 | self.paths = data_path 13 | self.atlas_path = atlas_path 14 | self.transforms = transforms 15 | 16 | def one_hot(self, img, C): 17 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 18 | for i in range(C): 19 | out[i,...] = img == i 20 | return out 21 | 22 | def __getitem__(self, index): 23 | path = self.paths[index] 24 | x, x_seg = pkload(self.atlas_path) 25 | y, y_seg = pkload(path) 26 | #print(x.shape) 27 | #print(x.shape) 28 | #print(np.unique(y)) 29 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155) 30 | # transforms work with nhwtc 31 | x, y = x[None, ...], y[None, ...] 
32 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155) 33 | x,y = self.transforms([x, y]) 34 | #y = self.one_hot(y, 2) 35 | #print(y.shape) 36 | #sys.exit(0) 37 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 38 | y = np.ascontiguousarray(y) 39 | #plt.figure() 40 | #plt.subplot(1, 2, 1) 41 | #plt.imshow(x[0, :, :, 8], cmap='gray') 42 | #plt.subplot(1, 2, 2) 43 | #plt.imshow(y[0, :, :, 8], cmap='gray') 44 | #plt.show() 45 | #sys.exit(0) 46 | #y = np.squeeze(y, axis=0) 47 | x, y = torch.from_numpy(x), torch.from_numpy(y) 48 | return x, y 49 | 50 | def __len__(self): 51 | return len(self.paths) 52 | 53 | 54 | class IXIBrainInferDataset(Dataset): 55 | def __init__(self, data_path, atlas_path, transforms): 56 | self.atlas_path = atlas_path 57 | self.paths = data_path 58 | self.transforms = transforms 59 | 60 | def one_hot(self, img, C): 61 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 62 | for i in range(C): 63 | out[i,...] = img == i 64 | return out 65 | 66 | def __getitem__(self, index): 67 | path = self.paths[index] 68 | x, x_seg = pkload(self.atlas_path) 69 | y, y_seg = pkload(path) 70 | x, y = x[None, ...], y[None, ...] 71 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...] 
72 | x, x_seg = self.transforms([x, x_seg]) 73 | y, y_seg = self.transforms([y, y_seg]) 74 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 75 | y = np.ascontiguousarray(y) 76 | x_seg = np.ascontiguousarray(x_seg) # [Bsize,channelsHeight,,Width,Depth] 77 | y_seg = np.ascontiguousarray(y_seg) 78 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg) 79 | return x, y, x_seg, y_seg 80 | 81 | def __len__(self): 82 | return len(self.paths) -------------------------------------------------------------------------------- /3D_LessNet/data/rand.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class Uniform(object): 5 | def __init__(self, a, b): 6 | self.a = a 7 | self.b = b 8 | 9 | def sample(self): 10 | return random.uniform(self.a, self.b) 11 | 12 | 13 | class Gaussian(object): 14 | def __init__(self, mean, std): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def sample(self): 19 | return random.gauss(self.mean, self.std) 20 | 21 | 22 | class Constant(object): 23 | def __init__(self, val): 24 | self.val = val 25 | 26 | def sample(self): 27 | return self.val -------------------------------------------------------------------------------- /3D_LessNet/data/trans.py: -------------------------------------------------------------------------------- 1 | # import math 2 | import random 3 | import collections 4 | import numpy as np 5 | import torch, sys, random, math 6 | from scipy import ndimage 7 | 8 | from .rand import Constant, Uniform, Gaussian 9 | from scipy.ndimage import rotate 10 | from skimage.transform import rescale, resize 11 | 12 | class Base(object): 13 | def sample(self, *shape): 14 | return shape 15 | 16 | def tf(self, img, k=0): 17 | return img 18 | 19 | def __call__(self, img, dim=3, reuse=False): # class -> func() 20 | # image: nhwtc 21 | # shape: no first dim 22 | if not reuse: 23 | im = img if isinstance(img, 
np.ndarray) else img[0] 24 | # how to know if the last dim is channel?? 25 | # nhwtc vs nhwt?? 26 | shape = im.shape[1:dim+1] 27 | # print(dim,shape) # 3, (240,240,155) 28 | self.sample(*shape) 29 | 30 | if isinstance(img, collections.Sequence): 31 | return [self.tf(x, k) for k, x in enumerate(img)] # img:k=0,label:k=1 32 | 33 | return self.tf(img) 34 | 35 | def __str__(self): 36 | return 'Identity()' 37 | 38 | Identity = Base 39 | 40 | # gemetric transformations, need a buffers 41 | # first axis is N 42 | class Rot90(Base): 43 | def __init__(self, axes=(0, 1)): 44 | self.axes = axes 45 | 46 | for a in self.axes: 47 | assert a > 0 48 | 49 | def sample(self, *shape): 50 | shape = list(shape) 51 | i, j = self.axes 52 | 53 | # shape: no first dim 54 | i, j = i-1, j-1 55 | shape[i], shape[j] = shape[j], shape[i] 56 | 57 | return shape 58 | 59 | def tf(self, img, k=0): 60 | return np.rot90(img, axes=self.axes) 61 | 62 | def __str__(self): 63 | return 'Rot90(axes=({}, {})'.format(*self.axes) 64 | 65 | # class RandomRotion(Base): 66 | # def __init__(self, angle=20):# angle :in degress, float, [0,360] 67 | # assert angle >= 0.0 68 | # self.axes = (0,1) # 只对HW方向进行旋转 69 | # self.angle = angle # 70 | # self.buffer = None 71 | # 72 | # def sample(self, *shape):# shape : [H,W,D] 73 | # shape = list(shape) 74 | # self.buffer = round(np.random.uniform(low=-self.angle,high=self.angle),2) # 2个小数点 75 | # if self.buffer < 0: 76 | # self.buffer += 180 77 | # return shape 78 | # 79 | # def tf(self, img, k=0): # img shape [1,H,W,D,c] while label shape is [1,H,W,D] 80 | # return ndimage.rotate(img, angle=self.buffer, reshape=False) 81 | # 82 | # def __str__(self): 83 | # return 'RandomRotion(axes=({}, {}),Angle:{}'.format(*self.axes,self.buffer) 84 | 85 | class RandomRotion(Base): 86 | def __init__(self,angle_spectrum=10): 87 | assert isinstance(angle_spectrum,int) 88 | # axes = [(2, 1), (3, 1),(3, 2)] 89 | axes = [(1, 0), (2, 1),(2, 0)] 90 | self.angle_spectrum = angle_spectrum 91 | 
self.axes = axes 92 | 93 | def sample(self,*shape): 94 | self.axes_buffer = self.axes[np.random.choice(list(range(len(self.axes))))] # choose the random direction 95 | self.angle_buffer = np.random.randint(-self.angle_spectrum, self.angle_spectrum) # choose the random direction 96 | return list(shape) 97 | 98 | def tf(self, img, k=0): 99 | """ Introduction: The rotation function supports the shape [H,W,D,C] or shape [H,W,D] 100 | :param img: if x, shape is [1,H,W,D,c]; if label, shape is [1,H,W,D] 101 | :param k: if x, k=0; if label, k=1 102 | """ 103 | bsize = img.shape[0] 104 | 105 | for bs in range(bsize): 106 | if k == 0: 107 | # [[H,W,D], ...] 108 | # print(img.shape) # (1, 128, 128, 128, 4) 109 | channels = [rotate(img[bs,:,:,:,c], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1) for c in 110 | range(img.shape[4])] 111 | img[bs,...] = np.stack(channels, axis=-1) 112 | 113 | if k == 1: 114 | img[bs,...] = rotate(img[bs,...], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1) 115 | 116 | return img 117 | 118 | def __str__(self): 119 | return 'RandomRotion(axes={},Angle:{}'.format(self.axes_buffer,self.angle_buffer) 120 | 121 | 122 | class Flip(Base): 123 | def __init__(self, axis=0): 124 | self.axis = axis 125 | 126 | def tf(self, img, k=0): 127 | return np.flip(img, self.axis) 128 | 129 | def __str__(self): 130 | return 'Flip(axis={})'.format(self.axis) 131 | 132 | class RandomFlip(Base): 133 | # mirror flip across all x,y,z 134 | def __init__(self,axis=0): 135 | # assert axis == (1,2,3) # For both data and label, it has to specify the axis. 
136 | self.axis = (1,2,3) 137 | self.x_buffer = None 138 | self.y_buffer = None 139 | self.z_buffer = None 140 | 141 | def sample(self, *shape): 142 | self.x_buffer = np.random.choice([True,False]) 143 | self.y_buffer = np.random.choice([True,False]) 144 | self.z_buffer = np.random.choice([True,False]) 145 | return list(shape) # the shape is not changed 146 | 147 | def tf(self,img,k=0): # img shape is (1, 240, 240, 155, 4) 148 | if self.x_buffer: 149 | img = np.flip(img,axis=self.axis[0]) 150 | if self.y_buffer: 151 | img = np.flip(img,axis=self.axis[1]) 152 | if self.z_buffer: 153 | img = np.flip(img,axis=self.axis[2]) 154 | return img 155 | 156 | 157 | class RandSelect(Base): 158 | def __init__(self, prob=0.5, tf=None): 159 | self.prob = prob 160 | self.ops = tf if isinstance(tf, collections.Sequence) else (tf, ) 161 | self.buff = False 162 | 163 | def sample(self, *shape): 164 | self.buff = random.random() < self.prob 165 | 166 | if self.buff: 167 | for op in self.ops: 168 | shape = op.sample(*shape) 169 | 170 | return shape 171 | 172 | def tf(self, img, k=0): 173 | if self.buff: 174 | for op in self.ops: 175 | img = op.tf(img, k) 176 | return img 177 | 178 | def __str__(self): 179 | if len(self.ops) == 1: 180 | ops = str(self.ops[0]) 181 | else: 182 | ops = '[{}]'.format(', '.join([str(op) for op in self.ops])) 183 | return 'RandSelect({}, {})'.format(self.prob, ops) 184 | 185 | 186 | class CenterCrop(Base): 187 | def __init__(self, size): 188 | self.size = size 189 | self.buffer = None 190 | 191 | def sample(self, *shape): 192 | size = self.size 193 | start = [(s -size)//2 for s in shape] 194 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start] 195 | return [size] * len(shape) 196 | 197 | def tf(self, img, k=0): 198 | # print(img.shape)#(1, 240, 240, 155, 4) 199 | return img[tuple(self.buffer)] 200 | # return img[self.buffer] 201 | 202 | def __str__(self): 203 | return 'CenterCrop({})'.format(self.size) 204 | 205 | class 
CenterCropBySize(CenterCrop): 206 | def sample(self, *shape): 207 | assert len(self.size) == 3 # random crop [H,W,T] from img [240,240,155] 208 | if not isinstance(self.size, list): 209 | size = list(self.size) 210 | else: 211 | size = self.size 212 | start = [(s-i)//2 for i, s in zip(size, shape)] 213 | self.buffer = [slice(None)] + [slice(s, s+i) for i, s in zip(size, start)] 214 | return size 215 | 216 | def __str__(self): 217 | return 'CenterCropBySize({})'.format(self.size) 218 | 219 | class RandCrop(CenterCrop): 220 | def sample(self, *shape): 221 | size = self.size 222 | start = [random.randint(0, s-size) for s in shape] 223 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start] 224 | return [size]*len(shape) 225 | 226 | def __str__(self): 227 | return 'RandCrop({})'.format(self.size) 228 | 229 | 230 | class RandCrop3D(CenterCrop): 231 | def sample(self, *shape): # shape : [240,240,155] 232 | assert len(self.size)==3 # random crop [H,W,T] from img [240,240,155] 233 | if not isinstance(self.size,list): 234 | size = list(self.size) 235 | else: 236 | size = self.size 237 | start = [random.randint(0, s-i) for i,s in zip(size,shape)] 238 | self.buffer = [slice(None)] + [slice(s, s+k) for s,k in zip(start,size)] 239 | return size 240 | 241 | def __str__(self): 242 | return 'RandCrop({})'.format(self.size) 243 | 244 | # for data only 245 | class RandomIntensityChange(Base): 246 | def __init__(self,factor): 247 | shift,scale = factor 248 | assert (shift >0) and (scale >0) 249 | self.shift = shift 250 | self.scale = scale 251 | 252 | def tf(self,img,k=0): 253 | if k==1: 254 | return img 255 | 256 | shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,img.shape[1],1,1,img.shape[4]]) # [-0.1,+0.1] 257 | scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,img.shape[1],1,1,img.shape[4]]) # [0.9,1.1) 258 | # shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,1,1,img.shape[3],img.shape[4]]) # [-0.1,+0.1] 259 | # 
scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,1,1,img.shape[3],img.shape[4]]) # [0.9,1.1) 260 | return img * scale_factor + shift_factor 261 | 262 | def __str__(self): 263 | return 'random intensity shift per channels on the input image, including' 264 | 265 | class RandomGammaCorrection(Base): 266 | def __init__(self,factor): 267 | lower, upper = factor 268 | assert (lower >0) and (upper >0) 269 | self.lower = lower 270 | self.upper = upper 271 | 272 | def tf(self,img,k=0): 273 | if k==1: 274 | return img 275 | img = img + np.min(img) 276 | img_max = np.max(img) 277 | img = img/img_max 278 | factor = random.choice(np.arange(self.lower, self.upper, 0.1)) 279 | gamma = random.choice([1, factor]) 280 | if gamma == 1: 281 | return img 282 | img = img ** gamma * img_max 283 | img = (img - img.mean())/img.std() 284 | return img 285 | 286 | def __str__(self): 287 | return 'random intensity shift per channels on the input image, including' 288 | 289 | class MinMax_norm(Base): 290 | def __init__(self, ): 291 | a = None 292 | 293 | def tf(self, img, k=0): 294 | if k == 1: 295 | return img 296 | img = (img - img.min()) / (img.max()-img.min()) 297 | return img 298 | 299 | class Seg_norm(Base): 300 | def __init__(self, ): 301 | a = None 302 | self.seg_table = np.array([0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 303 | 28, 30, 31, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62, 304 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255]) 305 | def tf(self, img, k=0): 306 | if k == 0: 307 | return img 308 | img_out = np.zeros_like(img) 309 | for i in range(len(self.seg_table)): 310 | img_out[img == self.seg_table[i]] = i 311 | return img_out 312 | 313 | class Resize_img(Base): 314 | def __init__(self, shape): 315 | self.shape = shape 316 | 317 | def tf(self, img, k=0): 318 | if k == 1: 319 | img = resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]), 320 | anti_aliasing=False, order=0) 321 | else: 322 | img 
= resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]), 323 | anti_aliasing=False, order=3) 324 | return img 325 | 326 | class Pad(Base): 327 | def __init__(self, pad): # [0,0,0,5,0] 328 | self.pad = pad 329 | self.px = tuple(zip([0]*len(pad), pad)) 330 | 331 | def sample(self, *shape): 332 | 333 | shape = list(shape) 334 | 335 | # shape: no first dim 336 | for i in range(len(shape)): 337 | shape[i] += self.pad[i+1] 338 | 339 | return shape 340 | 341 | def tf(self, img, k=0): 342 | #nhwtc, nhwt 343 | dim = len(img.shape) 344 | return np.pad(img, self.px[:dim], mode='constant') 345 | 346 | def __str__(self): 347 | return 'Pad(({}, {}, {}))'.format(*self.pad) 348 | 349 | class Pad3DIfNeeded(Base): 350 | def __init__(self, shape, value=0, mask_value=0): # [0,0,0,5,0] 351 | self.shape = shape 352 | self.value = value 353 | self.mask_value = mask_value 354 | 355 | def tf(self, img, k=0): 356 | pad = [(0,0)] 357 | if k==0: 358 | img_shape = img.shape[1:-1] 359 | else: 360 | img_shape = img.shape[1:] 361 | for i, t in zip(img_shape, self.shape): 362 | if i < t: 363 | diff = t-i 364 | pad.append((math.ceil(diff/2),math.floor(diff/2))) 365 | else: 366 | pad.append((0,0)) 367 | if k == 0: 368 | pad.append((0,0)) 369 | pad = tuple(pad) 370 | if k==0: 371 | return np.pad(img, pad, mode='constant', constant_values=img.min()) 372 | else: 373 | return np.pad(img, pad, mode='constant', constant_values=self.mask_value) 374 | 375 | def __str__(self): 376 | return 'Pad(({}, {}, {}))'.format(*self.pad) 377 | 378 | class Noise(Base): 379 | def __init__(self, dim, sigma=0.1, channel=True, num=-1): 380 | self.dim = dim 381 | self.sigma = sigma 382 | self.channel = channel 383 | self.num = num 384 | 385 | def tf(self, img, k=0): 386 | if self.num > 0 and k >= self.num: 387 | return img 388 | 389 | if self.channel: 390 | #nhwtc, hwtc, hwt 391 | shape = [1] if len(img.shape) < self.dim+2 else [img.shape[-1]] 392 | else: 393 | shape = img.shape 394 | return img * 
np.exp(self.sigma * torch.randn(shape, dtype=torch.float32).numpy()) 395 | 396 | def __str__(self): 397 | return 'Noise()' 398 | 399 | 400 | # dim could come from shape 401 | class GaussianBlur(Base): 402 | def __init__(self, dim, sigma=Constant(1.5), app=-1): 403 | # 1.5 pixel 404 | self.dim = dim 405 | self.sigma = sigma 406 | self.eps = 0.001 407 | self.app = app 408 | 409 | def tf(self, img, k=0): 410 | if self.num > 0 and k >= self.num: 411 | return img 412 | 413 | # image is nhwtc 414 | for n in range(img.shape[0]): 415 | sig = self.sigma.sample() 416 | # sample each channel saperately to avoid correlations 417 | if sig > self.eps: 418 | if len(img.shape) == self.dim+2: 419 | C = img.shape[-1] 420 | for c in range(C): 421 | img[n,..., c] = ndimage.gaussian_filter(img[n, ..., c], sig) 422 | elif len(img.shape) == self.dim+1: 423 | img[n] = ndimage.gaussian_filter(img[n], sig) 424 | else: 425 | raise ValueError('image shape is not supported') 426 | 427 | return img 428 | 429 | def __str__(self): 430 | return 'GaussianBlur()' 431 | 432 | 433 | class ToNumpy(Base): 434 | def __init__(self, num=-1): 435 | self.num = num 436 | 437 | def tf(self, img, k=0): 438 | if self.num > 0 and k >= self.num: 439 | return img 440 | return img.numpy() 441 | 442 | def __str__(self): 443 | return 'ToNumpy()' 444 | 445 | 446 | class ToTensor(Base): 447 | def __init__(self, num=-1): 448 | self.num = num 449 | 450 | def tf(self, img, k=0): 451 | if self.num > 0 and k >= self.num: 452 | return img 453 | 454 | return torch.from_numpy(img) 455 | 456 | def __str__(self): 457 | return 'ToTensor' 458 | 459 | 460 | class TensorType(Base): 461 | def __init__(self, types, num=-1): 462 | self.types = types # ('torch.float32', 'torch.int64') 463 | self.num = num 464 | 465 | def tf(self, img, k=0): 466 | if self.num > 0 and k >= self.num: 467 | return img 468 | # make this work with both Tensor and Numpy 469 | return img.type(self.types[k]) 470 | 471 | def __str__(self): 472 | s = ', 
'.join([str(s) for s in self.types]) 473 | return 'TensorType(({}))'.format(s) 474 | 475 | 476 | class NumpyType(Base): 477 | def __init__(self, types, num=-1): 478 | self.types = types # ('float32', 'int64') 479 | self.num = num 480 | 481 | def tf(self, img, k=0): 482 | if self.num > 0 and k >= self.num: 483 | return img 484 | # make this work with both Tensor and Numpy 485 | return img.astype(self.types[k]) 486 | 487 | def __str__(self): 488 | s = ', '.join([str(s) for s in self.types]) 489 | return 'NumpyType(({}))'.format(s) 490 | 491 | 492 | class Normalize(Base): 493 | def __init__(self, mean=0.0, std=1.0, num=-1): 494 | self.mean = mean 495 | self.std = std 496 | self.num = num 497 | 498 | def tf(self, img, k=0): 499 | if self.num > 0 and k >= self.num: 500 | return img 501 | img -= self.mean 502 | img /= self.std 503 | return img 504 | 505 | def __str__(self): 506 | return 'Normalize()' 507 | 508 | 509 | class Compose(Base): 510 | def __init__(self, ops): 511 | if not isinstance(ops, collections.Sequence): 512 | ops = ops, 513 | self.ops = ops 514 | 515 | def sample(self, *shape): 516 | for op in self.ops: 517 | shape = op.sample(*shape) 518 | 519 | def tf(self, img, k=0): 520 | #is_tensor = isinstance(img, torch.Tensor) 521 | #if is_tensor: 522 | # img = img.numpy() 523 | 524 | for op in self.ops: 525 | # print(op,img.shape,k) 526 | img = op.tf(img, k) # do not use op(img) here 527 | 528 | #if is_tensor: 529 | # img = np.ascontiguousarray(img) 530 | # img = torch.from_numpy(img) 531 | 532 | return img 533 | 534 | def __str__(self): 535 | ops = ', '.join([str(op) for op in self.ops]) 536 | return 'Compose([{}])'.format(ops) -------------------------------------------------------------------------------- /3D_LessNet/infer_bilinear.py: -------------------------------------------------------------------------------- 1 | import os, utils 2 | import glob 3 | import sys 4 | from argparse import ArgumentParser 5 | import numpy as np 6 | import torch.nn.functional 
as F 7 | import torch 8 | from torchvision import transforms 9 | from Models import * 10 | # from Functions import TrainDataset 11 | from torch.utils.data import DataLoader 12 | import torch.utils.data as Data 13 | from data import datasets, trans 14 | from natsort import natsorted 15 | import csv 16 | parser = ArgumentParser() 17 | parser.add_argument("--lr", type=float, 18 | dest="lr", default=1e-4, help="learning rate") 19 | parser.add_argument("--bs", type=int, 20 | dest="bs", default=1, help="batch_size") 21 | parser.add_argument("--iteration", type=int, 22 | dest="iteration", default=320001, 23 | help="number of total iterations") 24 | parser.add_argument("--smth_labda", type=float, 25 | dest="smth_labda", default=0.02, 26 | help="labda loss: suggested range 0.1 to 10") 27 | parser.add_argument("--checkpoint", type=int, 28 | dest="checkpoint", default=403, 29 | help="frequency of saving models") 30 | parser.add_argument("--start_channel", type=int, 31 | dest="start_channel", default=8, 32 | help="number of start channels") 33 | parser.add_argument("--trainingset", type=int, 34 | dest="trainingset", default=4, 35 | help="1 Half : 200 Images, 2 The other Half 200 Images 3 All 400 Images") 36 | parser.add_argument("--using_l2", type=int, 37 | dest="using_l2", 38 | default=1, 39 | help="using l2 or not") 40 | opt = parser.parse_args() 41 | 42 | lr = opt.lr 43 | bs = opt.bs 44 | iteration = opt.iteration 45 | start_channel = opt.start_channel 46 | n_checkpoint = opt.checkpoint 47 | smooth = opt.smth_labda 48 | trainingset = opt.trainingset 49 | using_l2 = opt.using_l2 50 | 51 | 52 | def main():  # evaluate the selected checkpoint on the test pairs; append per-structure Dice and the fraction of non-positive Jacobian determinants to Quantitative_Results/<model_dir>_Test.csv 53 | use_cuda = True 54 | device = torch.device("cuda" if use_cuda else "cpu") 55 | transform = SpatialTransform().cuda() 56 | diff_transform = DiffeomorphicTransform(time_step=7).cuda()  # currently unused: the diffeomorphic step at line 100 is commented out 57 | atlas_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/atlas.pkl' 58 | test_dir =
'/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/Test/' 59 | model_idx = -1  # DiceVal_*.pth names natsort in ascending validation Dice (fixed-width {:.4f} prefix), so -1 picks the highest-Dice checkpoint 60 | model_dir = './L2ss_{}_Chan_{}_Smth_{}_LR_{}_Val/'.format(using_l2,start_channel,smooth,lr) 61 | label_names = utils.process_label()  # class index -> structure name (named to avoid shadowing the builtin dict) 62 | if not os.path.exists('Quantitative_Results/'): 63 | os.makedirs('Quantitative_Results/') 64 | if os.path.exists('Quantitative_Results/'+model_dir[:-1]+'_Test.csv'): 65 | os.remove('Quantitative_Results/'+model_dir[:-1]+'_Test.csv') 66 | csv_writter(model_dir[:-1], 'Quantitative_Results/' + model_dir[:-1]+'_Test') 67 | line = '' 68 | for i in range(46): 69 | line = line + ',' + label_names[i] 70 | csv_writter(line +','+'non_jec', 'Quantitative_Results/' + model_dir[:-1]+'_Test') 71 | 72 | 73 | model = UNet(6, 3, start_channel).cuda() 74 | ckpt_name = natsorted([f for f in os.listdir(model_dir) if f.endswith('.pth')])[model_idx]  # only checkpoints: train.py also writes Loss.npy into model_dir, which natsorts after DiceVal_* and would otherwise be picked by [-1] 75 | print(model_dir + ckpt_name) 76 | best_model = torch.load(model_dir + ckpt_name)#['state_dict'] 77 | model.load_state_dict(best_model) 78 | model.cuda() 79 | # reg_model = utils.register_model(config.img_size, 'nearest') 80 | # reg_model.cuda() 81 | test_composed = transforms.Compose([trans.Seg_norm(), 82 | trans.NumpyType((np.float32, np.int16)), 83 | ]) 84 | test_set = datasets.IXIBrainInferDataset(glob.glob(test_dir + '*.pkl'), atlas_dir, transforms=test_composed) 85 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=True) 86 | eval_dsc_def = utils.AverageMeter() 87 | eval_dsc_raw = utils.AverageMeter() 88 | eval_det = utils.AverageMeter() 89 | with torch.no_grad(): 90 | stdy_idx = 0 91 | for data in test_loader: 92 | model.eval() 93 | data = [t.cuda() for t in data] 94 | x = data[0] 95 | y = data[1] 96 | x_seg = data[2] 97 | y_seg = data[3] 98 | 99 | v_xy = model(x.float().to(device), y.float().to(device)) 100 | # Dv_xy = diff_transform(v_xy) 101 | Dv_xy = v_xy 102 | # def_out = reg_model([x_seg.cuda().float(), flow.cuda()]) 103 | def_out= transform(x_seg.float().to(device),
Dv_xy.permute(0, 2, 3, 4, 1), mod = 'nearest') 104 | tar = y.detach().cpu().numpy()[0, 0, :, :, :] 105 | # print(f_xy.shape) #[1, 3, 160, 192, 224] 106 | dd, hh, ww = Dv_xy.shape[-3:]  # NOTE(review): the size/2 scaling below appears to convert a flow in normalized [-1, 1] grid units into voxel displacements for the Jacobian; confirm against SpatialTransform 107 | Dv_xy = Dv_xy.detach().cpu().numpy() 108 | Dv_xy[:,0,:,:,:] = Dv_xy[:,0,:,:,:] * dd / 2 109 | Dv_xy[:,1,:,:,:] = Dv_xy[:,1,:,:,:] * hh / 2 110 | Dv_xy[:,2,:,:,:] = Dv_xy[:,2,:,:,:] * ww / 2 111 | # jac_det = utils.jacobian_determinant_vxm(f_xy.detach().cpu().numpy()[0, :, :, :, :]) 112 | jac_det = utils.jacobian_determinant_vxm(Dv_xy[0, :, :, :, :]) 113 | line = utils.dice_val_substruct(def_out.long(), y_seg.long(), stdy_idx) 114 | line = line +','+str(np.sum(jac_det <= 0)/np.prod(tar.shape)) 115 | csv_writter(line, 'Quantitative_Results/' + model_dir[:-1]+'_Test') 116 | eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0)) 117 | print('det < 0: {}'.format(np.sum(jac_det <= 0) / np.prod(tar.shape))) 118 | dsc_trans = utils.dice_val(def_out.long(), y_seg.long(), 46) 119 | dsc_raw = utils.dice_val(x_seg.long(), y_seg.long(), 46) 120 | print('Trans dsc: {:.4f}, Raw dsc: {:.4f}'.format(dsc_trans.item(),dsc_raw.item())) 121 | eval_dsc_def.update(dsc_trans.item(), x.size(0)) 122 | eval_dsc_raw.update(dsc_raw.item(), x.size(0)) 123 | stdy_idx += 1 124 | 125 | print('Deformed DSC: {:.3f} +- {:.3f}, Affine DSC: {:.3f} +- {:.3f}'.format(eval_dsc_def.avg, 126 | eval_dsc_def.std, 127 | eval_dsc_raw.avg, 128 | eval_dsc_raw.std)) 129 | print('deformed det: {}, std: {}'.format(eval_det.avg, eval_det.std)) 130 | 131 | def csv_writter(line, name):  # append `line` followed by a newline to `<name>.csv` 132 | with open(name+'.csv', 'a') as file: 133 | file.write(line) 134 | file.write('\n') 135 | 136 | if __name__ == '__main__': 137 | ''' 138 | GPU configuration 139 | ''' 140 | # GPU_iden = 1 141 | # GPU_num = torch.cuda.device_count() 142 | # print('Number of GPU: ' + str(GPU_num)) 143 | # for GPU_idx in range(GPU_num): 144 | # GPU_name = torch.cuda.get_device_name(GPU_idx) 145 | # print(' GPU #' + str(GPU_idx) + ': ' + GPU_name) 146 |
# torch.cuda.set_device(GPU_iden) 147 | # GPU_avai = torch.cuda.is_available() 148 | # print('Currently using: ' + torch.cuda.get_device_name(GPU_iden)) 149 | # print('If the GPU is available? ' + str(GPU_avai)) 150 | main() 151 | -------------------------------------------------------------------------------- /3D_LessNet/label_info.txt: -------------------------------------------------------------------------------- 1 | 0 Unknown 0 0 0 0 2 | 1 Left-Cerebral-Exterior 70 130 180 0 3 | 2 Left-Cerebral-White-Matter 245 245 245 0 4 | 3 Left-Cerebral-Cortex 205 62 78 0 5 | 4 Left-Lateral-Ventricle 120 18 134 0 6 | 5 Left-Inf-Lat-Vent 196 58 250 0 7 | 6 Left-Cerebellum-Exterior 0 148 0 0 8 | 7 Left-Cerebellum-White-Matter 220 248 164 0 9 | 8 Left-Cerebellum-Cortex 230 148 34 0 10 | 9 Left-Thalamus 0 118 14 0 11 | 10 Left-Thalamus-Proper* 0 118 14 0 12 | 11 Left-Caudate 122 186 220 0 13 | 12 Left-Putamen 236 13 176 0 14 | 13 Left-Pallidum 12 48 255 0 15 | 14 3rd-Ventricle 204 182 142 0 16 | 15 4th-Ventricle 42 204 164 0 17 | 16 Brain-Stem 119 159 176 0 18 | 17 Left-Hippocampus 220 216 20 0 19 | 18 Left-Amygdala 103 255 255 0 20 | 19 Left-Insula 80 196 98 0 21 | 20 Left-Operculum 60 58 210 0 22 | 21 Line-1 60 58 210 0 23 | 22 Line-2 60 58 210 0 24 | 23 Line-3 60 58 210 0 25 | 24 CSF 60 60 60 0 26 | 25 Left-Lesion 255 165 0 0 27 | 26 Left-Accumbens-area 255 165 0 0 28 | 27 Left-Substancia-Nigra 0 255 127 0 29 | 28 Left-VentralDC 165 42 42 0 30 | 29 Left-undetermined 135 206 235 0 31 | 30 Left-vessel 160 32 240 0 32 | 31 Left-choroid-plexus 0 200 200 0 33 | 32 Left-F3orb 100 50 100 0 34 | 33 Left-lOg 135 50 74 0 35 | 34 Left-aOg 122 135 50 0 36 | 35 Left-mOg 51 50 135 0 37 | 36 Left-pOg 74 155 60 0 38 | 37 Left-Stellate 120 62 43 0 39 | 38 Left-Porg 74 155 60 0 40 | 39 Left-Aorg 122 135 50 0 41 | 40 Right-Cerebral-Exterior 70 130 180 0 42 | 41 Right-Cerebral-White-Matter 245 245 245 0 43 | 42 Right-Cerebral-Cortex 205 62 78 0 44 | 43 Right-Lateral-Ventricle 120 18 134 0 45
| 44 Right-Inf-Lat-Vent 196 58 250 0 46 | 45 Right-Cerebellum-Exterior 0 148 0 0 47 | 46 Right-Cerebellum-White-Matter 220 248 164 0 48 | 47 Right-Cerebellum-Cortex 230 148 34 0 49 | 48 Right-Thalamus 0 118 14 0 50 | 49 Right-Thalamus-Proper* 0 118 14 0 51 | 50 Right-Caudate 122 186 220 0 52 | 51 Right-Putamen 236 13 176 0 53 | 52 Right-Pallidum 13 48 255 0 54 | 53 Right-Hippocampus 220 216 20 0 55 | 54 Right-Amygdala 103 255 255 0 56 | 55 Right-Insula 80 196 98 0 57 | 56 Right-Operculum 60 58 210 0 58 | 57 Right-Lesion 255 165 0 0 59 | 58 Right-Accumbens-area 255 165 0 0 60 | 59 Right-Substancia-Nigra 0 255 127 0 61 | 60 Right-VentralDC 165 42 42 0 62 | 61 Right-undetermined 135 206 235 0 63 | 62 Right-vessel 160 32 240 0 64 | 63 Right-choroid-plexus 0 200 221 0 65 | 64 Right-F3orb 100 50 100 0 66 | 65 Right-lOg 135 50 74 0 67 | 66 Right-aOg 122 135 50 0 68 | 67 Right-mOg 51 50 135 0 69 | 68 Right-pOg 74 155 60 0 70 | 69 Right-Stellate 120 62 43 0 71 | 70 Right-Porg 74 155 60 0 72 | 71 Right-Aorg 122 135 50 0 73 | 72 5th-Ventricle 120 190 150 0 74 | 73 Left-Interior 122 135 50 0 75 | 74 Right-Interior 122 135 50 0 76 | 77 | 77 WM-hypointensities 200 70 255 0 78 | 78 Left-WM-hypointensities 255 148 10 0 79 | 79 Right-WM-hypointensities 255 148 10 0 80 | 80 non-WM-hypointensities 164 108 226 0 81 | 81 Left-non-WM-hypointensities 164 108 226 0 82 | 82 Right-non-WM-hypointensities 164 108 226 0 83 | 83 Left-F1 255 218 185 0 84 | 84 Right-F1 255 218 185 0 85 | 85 Optic-Chiasm 234 169 30 0 86 | 192 Corpus_Callosum 250 255 50 0 87 | 88 | 86 Left_future_WMSA 200 120 255 0 89 | 87 Right_future_WMSA 200 121 255 0 90 | 88 future_WMSA 200 122 255 0 91 | 92 | 93 | 96 Left-Amygdala-Anterior 205 10 125 0 94 | 97 Right-Amygdala-Anterior 205 10 125 0 95 | 98 Dura 160 32 240 0 96 | 97 | 100 Left-wm-intensity-abnormality 124 140 178 0 98 | 101 Left-caudate-intensity-abnormality 125 140 178 0 99 | 102 Left-putamen-intensity-abnormality 126 140 178 0 100 | 103 
Left-accumbens-intensity-abnormality 127 140 178 0 101 | 104 Left-pallidum-intensity-abnormality 124 141 178 0 102 | 105 Left-amygdala-intensity-abnormality 124 142 178 0 103 | 106 Left-hippocampus-intensity-abnormality 124 143 178 0 104 | 107 Left-thalamus-intensity-abnormality 124 144 178 0 105 | 108 Left-VDC-intensity-abnormality 124 140 179 0 106 | 109 Right-wm-intensity-abnormality 124 140 178 0 107 | 110 Right-caudate-intensity-abnormality 125 140 178 0 108 | 111 Right-putamen-intensity-abnormality 126 140 178 0 109 | 112 Right-accumbens-intensity-abnormality 127 140 178 0 110 | 113 Right-pallidum-intensity-abnormality 124 141 178 0 111 | 114 Right-amygdala-intensity-abnormality 124 142 178 0 112 | 115 Right-hippocampus-intensity-abnormality 124 143 178 0 113 | 116 Right-thalamus-intensity-abnormality 124 144 178 0 114 | 117 Right-VDC-intensity-abnormality 124 140 179 0 115 | 116 | 118 Epidermis 255 20 147 0 117 | 119 Conn-Tissue 205 179 139 0 118 | 120 SC-Fat-Muscle 238 238 209 0 119 | 121 Cranium 200 200 200 0 120 | 122 CSF-SA 74 255 74 0 121 | 123 Muscle 238 0 0 0 122 | 124 Ear 0 0 139 0 123 | 125 Adipose 173 255 47 0 124 | 126 Spinal-Cord 133 203 229 0 125 | 127 Soft-Tissue 26 237 57 0 126 | 128 Nerve 34 139 34 0 127 | 129 Bone 30 144 255 0 128 | 130 Air 147 19 173 0 129 | 131 Orbital-Fat 238 59 59 0 130 | 132 Tongue 221 39 200 0 131 | 133 Nasal-Structures 238 174 238 0 132 | 134 Globe 255 0 0 0 133 | 135 Teeth 72 61 139 0 134 | 136 Left-Caudate-Putamen 21 39 132 0 135 | 137 Right-Caudate-Putamen 21 39 132 0 136 | 138 Left-Claustrum 65 135 20 0 137 | 139 Right-Claustrum 65 135 20 0 138 | 140 Cornea 134 4 160 0 139 | 142 Diploe 221 226 68 0 140 | 143 Vitreous-Humor 255 255 254 0 141 | 144 Lens 52 209 226 0 142 | 145 Aqueous-Humor 239 160 223 0 143 | 146 Outer-Table 70 130 180 0 144 | 147 Inner-Table 70 130 181 0 145 | 148 Periosteum 139 121 94 0 146 | 149 Endosteum 224 224 224 0 147 | 150 R-C-S 255 0 0 0 148 | 151 Iris 205 205 0 0 149 | 152 
SC-Adipose-Muscle 238 238 209 0 150 | 153 SC-Tissue 139 121 94 0 151 | 154 Orbital-Adipose 238 59 59 0 152 | 153 | 155 Left-IntCapsule-Ant 238 59 59 0 154 | 156 Right-IntCapsule-Ant 238 59 59 0 155 | 157 Left-IntCapsule-Pos 62 10 205 0 156 | 158 Right-IntCapsule-Pos 62 10 205 0 157 | 158 | # These labels are for babies/children 159 | 159 Left-Cerebral-WM-unmyelinated 0 118 14 0 160 | 160 Right-Cerebral-WM-unmyelinated 0 118 14 0 161 | 161 Left-Cerebral-WM-myelinated 220 216 21 0 162 | 162 Right-Cerebral-WM-myelinated 220 216 21 0 163 | 163 Left-Subcortical-Gray-Matter 122 186 220 0 164 | 164 Right-Subcortical-Gray-Matter 122 186 220 0 165 | 165 Skull 120 120 120 0 166 | 166 Posterior-fossa 14 48 255 0 167 | 167 Scalp 166 42 42 0 168 | 168 Hematoma 121 18 134 0 169 | 169 Left-Basal-Ganglia 236 13 127 0 170 | 176 Right-Basal-Ganglia 236 13 126 0 171 | 172 | # Label names and colors for Brainstem consituents 173 | # No. Label Name: R G B A 174 | 170 brainstem 119 159 176 0 175 | 171 DCG 119 0 176 0 176 | 172 Vermis 119 100 176 0 177 | 173 Midbrain 242 104 76 0 178 | 174 Pons 206 195 58 0 179 | 175 Medulla 119 159 176 0 180 | 177 Vermis-White-Matter 119 50 176 0 181 | 178 SCP 142 182 0 0 182 | 179 Floculus 19 100 176 0 183 | 184 | 180 Left-Cortical-Dysplasia 73 61 139 0 185 | 181 Right-Cortical-Dysplasia 73 62 139 0 186 | 182 CblumNodulus 10 100 176 0 187 | 188 | 193 Left-hippocampal_fissure 0 196 255 0 189 | 194 Left-CADG-head 255 164 164 0 190 | 195 Left-subiculum 196 196 0 0 191 | 196 Left-fimbria 0 100 255 0 192 | 197 Right-hippocampal_fissure 128 196 164 0 193 | 198 Right-CADG-head 0 126 75 0 194 | 199 Right-subiculum 128 96 64 0 195 | 200 Right-fimbria 0 50 128 0 196 | 201 alveus 255 204 153 0 197 | 202 perforant_pathway 255 128 128 0 198 | 203 parasubiculum 255 255 0 0 199 | 204 presubiculum 64 0 64 0 200 | 205 subiculum 0 0 255 0 201 | 206 CA1 255 0 0 0 202 | 207 CA2 128 128 255 0 203 | 208 CA3 0 128 0 0 204 | 209 CA4 196 160 128 0 205 | 210 GC-DG 32 200 255 0 
206 | 211 HATA 128 255 128 0 207 | 212 fimbria 204 153 204 0 208 | 213 lateral_ventricle 121 17 136 0 209 | 214 molecular_layer_HP 128 0 0 0 210 | 215 hippocampal_fissure 128 32 255 0 211 | 216 entorhinal_cortex 255 204 102 0 212 | 217 molecular_layer_subiculum 128 128 128 0 213 | 218 Amygdala 104 255 255 0 214 | 219 Cerebral_White_Matter 0 226 0 0 215 | 220 Cerebral_Cortex 205 63 78 0 216 | 221 Inf_Lat_Vent 197 58 250 0 217 | 222 Perirhinal 33 150 250 0 218 | 223 Cerebral_White_Matter_Edge 226 0 0 0 219 | 224 Background 100 100 100 0 220 | 225 Ectorhinal 197 150 250 0 221 | 226 HP_tail 170 170 255 0 222 | 223 | 250 Fornix 255 0 0 0 224 | 251 CC_Posterior 0 0 64 0 225 | 252 CC_Mid_Posterior 0 0 112 0 226 | 253 CC_Central 0 0 160 0 227 | 254 CC_Mid_Anterior 0 0 208 0 228 | 255 CC_Anterior 0 0 255 0 229 | -------------------------------------------------------------------------------- /3D_LessNet/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | from argparse import ArgumentParser 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import torch 8 | from torchvision import transforms 9 | from Models import * 10 | # from Functions import TrainDataset 11 | import torch.utils.data as Data 12 | from data import datasets, trans 13 | from natsort import natsorted 14 | import csv 15 | parser = ArgumentParser() 16 | parser.add_argument("--lr", type=float, 17 | dest="lr", default=1e-4, help="learning rate") 18 | parser.add_argument("--bs", type=int, 19 | dest="bs", default=1, help="batch_size") 20 | parser.add_argument("--iteration", type=int, 21 | dest="iteration", default=320001, 22 | help="number of total iterations") 23 | parser.add_argument("--smth_labda", type=float, 24 | dest="smth_labda", default=0.02, 25 | help="labda loss: suggested range 0.1 to 10") 26 | parser.add_argument("--checkpoint", type=int, 27 | dest="checkpoint", default=403, 28 | help="frequency of saving
models") 29 | parser.add_argument("--start_channel", type=int, 30 | dest="start_channel", default=8, 31 | help="number of start channels") 32 | parser.add_argument("--trainingset", type=int, 33 | dest="trainingset", default=4, 34 | help="1 Half : 200 Images, 2 The other Half 200 Images 3 All 400 Images") 35 | parser.add_argument("--using_l2", type=int, 36 | dest="using_l2", 37 | default=1, 38 | help="using l2 or not") 39 | opt = parser.parse_args() 40 | 41 | lr = opt.lr 42 | bs = opt.bs 43 | iteration = opt.iteration 44 | start_channel = opt.start_channel 45 | n_checkpoint = opt.checkpoint 46 | smooth = opt.smth_labda 47 | trainingset = opt.trainingset 48 | using_l2 = opt.using_l2 49 | 50 | def dice(pred1, truth1):  # mean Dice over the 30 VOI class indices (Seg_norm-remapped labels) between two label volumes 51 | VOI_lbls = [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 34, 36] 52 | dice_35=np.zeros(len(VOI_lbls)) 53 | index = 0 54 | for k in VOI_lbls: 55 | #print(k) 56 | truth = truth1 == k 57 | pred = pred1 == k 58 | intersection = np.sum(pred * truth) * 2.0 59 | # print(intersection) 60 | dice_35[index]=intersection / (np.sum(pred) + np.sum(truth) + 1e-5)  # eps: a label absent from both volumes would otherwise give 0/0 = nan and poison the mean (utils.dice_val_VOI uses the same eps) 61 | index = index + 1 62 | return np.mean(dice_35) 63 | 64 | def save_checkpoint(state, save_dir, save_filename, max_model_num=10):  # save `state`, then delete the lowest-natsorted checkpoints until at most max_model_num remain 65 | torch.save(state, save_dir + save_filename) 66 | model_lists = natsorted(glob.glob(save_dir + '*.pth'))  # only checkpoints: train() also writes Loss.npy into save_dir, which would otherwise occupy a slot 67 | # print(model_lists) 68 | while len(model_lists) > max_model_num: 69 | os.remove(model_lists[0]) 70 | model_lists = natsorted(glob.glob(save_dir + '*.pth')) 71 | 72 | def train():  # train UNet on (X, Y) pairs; at step 1 and every n_checkpoint steps, compute validation Dice, log it to csv_name and checkpoint via save_checkpoint 73 | use_cuda = True 74 | device = torch.device("cuda" if use_cuda else "cpu") 75 | 76 | atlas_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/atlas.pkl' 77 | train_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/Train/' 78 | val_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/Val/' 79 | # train_composed = transforms.Compose([trans.RandomFlip(0), 80 | #
trans.NumpyType((np.float32, np.float32)), 81 | # ]) 82 | 83 | train_composed = transforms.Compose([trans.NumpyType((np.float32, np.float32)), 84 | ]) 85 | 86 | val_composed = transforms.Compose([trans.Seg_norm(), # remap FreeSurfer label values to contiguous class indices 0..45 87 | trans.NumpyType((np.float32, np.int16))]) 88 | train_set = datasets.IXIBrainDataset(glob.glob(train_dir + '*.pkl'), atlas_dir, transforms=train_composed) 89 | val_set = datasets.IXIBrainInferDataset(glob.glob(val_dir + '*.pkl'), atlas_dir, transforms=val_composed) 90 | train_loader = Data.DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=4, pin_memory=True) 91 | val_loader = Data.DataLoader(val_set, batch_size=bs, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) 92 | 93 | 94 | model = UNet(6, 3, start_channel).to(device) 95 | if using_l2 == 1: 96 | loss_similarity = MSE().loss 97 | elif using_l2 == 0: 98 | loss_similarity = SAD().loss 99 | elif using_l2 == 2: 100 | loss_similarity = NCC() 101 | else: raise ValueError('--using_l2 must be 0 (SAD), 1 (MSE) or 2 (NCC), got {}'.format(using_l2)) 102 | loss_smooth = smoothloss 103 | transform = SpatialTransform().to(device) 104 | diff_transform = DiffeomorphicTransform(time_step=7).to(device) 105 | 106 | 107 | for param in transform.parameters(): 108 | param.requires_grad = False 109 | param.volatile = True 110 | 111 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 112 | model_dir = './L2ss_{}_Chan_{}_Smth_{}_LR_{}_Val/'.format(using_l2,start_channel,smooth,lr) 113 | csv_name = 'L2ss_{}_Chan_{}_Smth_{}_LR_{}.csv'.format(using_l2,start_channel,smooth,lr) 114 | assert os.path.exists(csv_name) ==0 115 | assert os.path.isdir(model_dir) ==0 116 | f = open(csv_name, 'w') 117 | with f: 118 | fnames = ['Index','Dice'] 119 | writer = csv.DictWriter(f, fieldnames=fnames) 120 | writer.writeheader() 121 | 122 | if not os.path.isdir(model_dir): 123 | os.mkdir(model_dir) 124 | 125 | lossall = np.zeros((3, iteration + 1))  # column `step` is written for step = 1..iteration, so iteration + 1 columns are needed (iteration columns raised IndexError on the last step) 126 | step = 1 127 | epoch = 0 128 | while step <= iteration: 129 | for X, Y in train_loader: 130 | 131 | X =
X.to(device).float() 132 | Y = Y.to(device).float() 133 | 134 | f_xy = model(X, Y) 135 | # D_f_xy = diff_transform(f_xy) 136 | D_f_xy = f_xy 137 | X_Y = transform(X, D_f_xy.permute(0, 2, 3, 4, 1)) 138 | 139 | loss1 = loss_similarity(Y, X_Y) 140 | loss5 = loss_smooth(f_xy) 141 | loss = loss1 + smooth * loss5 142 | 143 | optimizer.zero_grad() 144 | loss.backward() 145 | optimizer.step() 146 | 147 | lossall[:,step] = np.array([loss.item(),loss1.item(),loss5.item()]) 148 | sys.stdout.write("\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" -smo "{3:.4f}" '.format(step, loss.item(),loss1.item(),loss5.item())) 149 | sys.stdout.flush() 150 | 151 | if (step % n_checkpoint == 0) or (step == 1): 152 | with torch.no_grad(): 153 | Dices_Validation = [] 154 | for data in val_loader: 155 | model.eval() 156 | xv = data[0] 157 | yv = data[1] 158 | xv_seg = data[2] 159 | yv_seg = data[3] 160 | vf_xy = model(xv.float().to(device), yv.float().to(device)) 161 | warped_xv_seg= transform(xv_seg.float().to(device), vf_xy.permute(0, 2, 3, 4, 1), mod = 'nearest') 162 | for bs_index in range(bs): 163 | dice_bs=dice(warped_xv_seg[bs_index,...].data.cpu().numpy().copy(),yv_seg[bs_index,...].data.cpu().numpy().copy()) 164 | Dices_Validation.append(dice_bs) 165 | modelname = 'DiceVal_{:.4f}_Epoch_{:04d}.pth'.format(np.mean(Dices_Validation), epoch) 166 | f = open(csv_name, 'a') 167 | with f: 168 | writer = csv.writer(f) 169 | writer.writerow([epoch, np.mean(Dices_Validation)]) 170 | save_checkpoint(model.state_dict(), model_dir, modelname) 171 | model.train()  # the validation loop above switched the model to eval mode; restore training mode before the next step 172 | # torch.save(model.state_dict(), model_dir + modelname) 173 | np.save(model_dir + 'Loss.npy', lossall) 174 | step += 1 175 | 176 | if step > iteration: 177 | break 178 | print("one epoch pass") 179 | epoch = epoch + 1 180 | np.save(model_dir + 'Loss.npy', lossall) 181 | if __name__ == '__main__': 182 | train() 183 | --------------------------------------------------------------------------------
/3D_LessNet/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch, sys 5 | from torch import nn 6 | import pystrum.pynd.ndutils as nd 7 | from scipy.ndimage import gaussian_filter 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | self.vals = [] 20 | self.std = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | self.vals.append(val) 28 | self.std = np.std(self.vals) 29 | 30 | def pad_image(img, target_size): 31 | rows_to_pad = max(target_size[0] - img.shape[2], 0) 32 | cols_to_pad = max(target_size[1] - img.shape[3], 0) 33 | slcs_to_pad = max(target_size[2] - img.shape[4], 0) 34 | padded_img = F.pad(img, (0, slcs_to_pad, 0, cols_to_pad, 0, rows_to_pad), "constant", 0) 35 | return padded_img 36 | 37 | class SpatialTransformer(nn.Module): 38 | """ 39 | N-D Spatial Transformer 40 | """ 41 | 42 | def __init__(self, size, mode='bilinear'): 43 | super().__init__() 44 | 45 | self.mode = mode 46 | 47 | # create sampling grid 48 | vectors = [torch.arange(0, s) for s in size] 49 | grids = torch.meshgrid(vectors) 50 | grid = torch.stack(grids) 51 | grid = torch.unsqueeze(grid, 0) 52 | grid = grid.type(torch.FloatTensor).cuda() 53 | 54 | # registering the grid as a buffer cleanly moves it to the GPU, but it also 55 | # adds it to the state dict. this is annoying since everything in the state dict 56 | # is included when saving weights to disk, so the model files are way bigger 57 | # than they need to be. so far, there does not appear to be an elegant solution. 
58 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict 59 | self.register_buffer('grid', grid) 60 | 61 | def forward(self, src, flow): 62 | # new locations 63 | new_locs = self.grid + flow 64 | shape = flow.shape[2:] 65 | 66 | # need to normalize grid values to [-1, 1] for resampler 67 | for i in range(len(shape)): 68 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5) 69 | 70 | # move channels dim to last position 71 | # also not sure why, but the channels need to be reversed 72 | if len(shape) == 2: 73 | new_locs = new_locs.permute(0, 2, 3, 1) 74 | new_locs = new_locs[..., [1, 0]] 75 | elif len(shape) == 3: 76 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 77 | new_locs = new_locs[..., [2, 1, 0]] 78 | 79 | return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode) 80 | 81 | class register_model(nn.Module): 82 | def __init__(self, img_size=(64, 256, 256), mode='bilinear'): 83 | super(register_model, self).__init__() 84 | self.spatial_trans = SpatialTransformer(img_size, mode) 85 | 86 | def forward(self, x): 87 | img = x[0].cuda() 88 | flow = x[1].cuda() 89 | out = self.spatial_trans(img, flow) 90 | return out 91 | 92 | def dice_val(y_pred, y_true, num_clus): 93 | y_pred = nn.functional.one_hot(y_pred, num_classes=num_clus) 94 | y_pred = torch.squeeze(y_pred, 1) 95 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() 96 | y_true = nn.functional.one_hot(y_true, num_classes=num_clus) 97 | y_true = torch.squeeze(y_true, 1) 98 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous() 99 | intersection = y_pred * y_true 100 | intersection = intersection.sum(dim=[2, 3, 4]) 101 | union = y_pred.sum(dim=[2, 3, 4]) + y_true.sum(dim=[2, 3, 4]) 102 | dsc = (2.*intersection) / (union + 1e-5) 103 | return torch.mean(torch.mean(dsc, dim=1)) 104 | 105 | def dice_val_VOI(y_pred, y_true): 106 | VOI_lbls = [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 34, 
36] 107 | pred = y_pred.detach().cpu().numpy()[0, 0, ...] 108 | true = y_true.detach().cpu().numpy()[0, 0, ...] 109 | DSCs = np.zeros((len(VOI_lbls), 1)) 110 | idx = 0 111 | for i in VOI_lbls: 112 | pred_i = pred == i 113 | true_i = true == i 114 | intersection = pred_i * true_i 115 | intersection = np.sum(intersection) 116 | union = np.sum(pred_i) + np.sum(true_i) 117 | dsc = (2.*intersection) / (union + 1e-5) 118 | DSCs[idx] =dsc 119 | idx += 1 120 | return np.mean(DSCs) 121 | 122 | def jacobian_determinant_vxm(disp): 123 | """ 124 | jacobian determinant of a displacement field. 125 | NB: to compute the spatial gradients, we use np.gradient. 126 | Parameters: 127 | disp: 2D or 3D displacement field of size [*vol_shape, nb_dims], 128 | where vol_shape is of len nb_dims 129 | Returns: 130 | jacobian determinant (scalar) 131 | """ 132 | 133 | # check inputs 134 | disp = disp.transpose(1, 2, 3, 0) 135 | volshape = disp.shape[:-1] 136 | nb_dims = len(volshape) 137 | assert len(volshape) in (2, 3), 'flow has to be 2D or 3D' 138 | 139 | # compute grid 140 | grid_lst = nd.volsize2ndgrid(volshape) 141 | grid = np.stack(grid_lst, len(volshape)) 142 | 143 | # compute gradients 144 | J = np.gradient(disp + grid) 145 | 146 | # 3D glow 147 | if nb_dims == 3: 148 | dx = J[0] 149 | dy = J[1] 150 | dz = J[2] 151 | 152 | # compute jacobian components 153 | Jdet0 = dx[..., 0] * (dy[..., 1] * dz[..., 2] - dy[..., 2] * dz[..., 1]) 154 | Jdet1 = dx[..., 1] * (dy[..., 0] * dz[..., 2] - dy[..., 2] * dz[..., 0]) 155 | Jdet2 = dx[..., 2] * (dy[..., 0] * dz[..., 1] - dy[..., 1] * dz[..., 0]) 156 | 157 | return Jdet0 - Jdet1 + Jdet2 158 | 159 | else: # must be 2 160 | 161 | dfdx = J[0] 162 | dfdy = J[1] 163 | 164 | return dfdx[..., 0] * dfdy[..., 1] - dfdy[..., 0] * dfdx[..., 1] 165 | 166 | import re 167 | def process_label(): 168 | #process labeling information for FreeSurfer 169 | seg_table = [0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 170 | 28, 30, 31, 41, 42, 43, 
44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62, 171 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255] 172 | 173 | 174 | file1 = open('label_info.txt', 'r') 175 | Lines = file1.readlines() 176 | dict = {} 177 | seg_i = 0 178 | seg_look_up = [] 179 | for seg_label in seg_table: 180 | for line in Lines: 181 | line = re.sub(' +', ' ',line).split(' ') 182 | try: 183 | int(line[0]) 184 | except: 185 | continue 186 | if int(line[0]) == seg_label: 187 | seg_look_up.append([seg_i, int(line[0]), line[1]]) 188 | dict[seg_i] = line[1] 189 | seg_i += 1 190 | return dict 191 | 192 | def write2csv(line, name): 193 | with open(name+'.csv', 'a') as file: 194 | file.write(line) 195 | file.write('\n') 196 | 197 | def dice_val_substruct(y_pred, y_true, std_idx): 198 | with torch.no_grad(): 199 | y_pred = nn.functional.one_hot(y_pred, num_classes=46) 200 | y_pred = torch.squeeze(y_pred, 1) 201 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() 202 | y_true = nn.functional.one_hot(y_true, num_classes=46) 203 | y_true = torch.squeeze(y_true, 1) 204 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous() 205 | y_pred = y_pred.detach().cpu().numpy() 206 | y_true = y_true.detach().cpu().numpy() 207 | 208 | line = 'p_{}'.format(std_idx) 209 | for i in range(46): 210 | pred_clus = y_pred[0, i, ...] 211 | true_clus = y_true[0, i, ...] 
212 | intersection = pred_clus * true_clus 213 | intersection = intersection.sum() 214 | union = pred_clus.sum() + true_clus.sum() 215 | dsc = (2.*intersection) / (union + 1e-5) 216 | line = line+','+str(dsc) 217 | return line 218 | 219 | def dice(y_pred, y_true, ): 220 | intersection = y_pred * y_true 221 | intersection = np.sum(intersection) 222 | union = np.sum(y_pred) + np.sum(y_true) 223 | dsc = (2.*intersection) / (union + 1e-5) 224 | return dsc 225 | 226 | def smooth_seg(binary_img, sigma=1.5, thresh=0.4): 227 | binary_img = gaussian_filter(binary_img.astype(np.float32()), sigma=sigma) 228 | binary_img = binary_img > thresh 229 | return binary_img 230 | 231 | def get_mc_preds(net, inputs, mc_iter: int = 25): 232 | """Convenience fn. for MC integration for uncertainty estimation. 233 | Args: 234 | net: DIP model (can be standard, MFVI or MCDropout) 235 | inputs: input to net 236 | mc_iter: number of MC samples 237 | post_processor: process output of net before computing loss (e.g. downsampler in SR) 238 | mask: multiply output and target by mask before computing loss (for inpainting) 239 | """ 240 | img_list = [] 241 | flow_list = [] 242 | with torch.no_grad(): 243 | for _ in range(mc_iter): 244 | img, flow = net(inputs) 245 | img_list.append(img) 246 | flow_list.append(flow) 247 | return img_list, flow_list 248 | 249 | def calc_uncert(tar, img_list): 250 | sqr_diffs = [] 251 | for i in range(len(img_list)): 252 | sqr_diff = (img_list[i] - tar)**2 253 | sqr_diffs.append(sqr_diff) 254 | uncert = torch.mean(torch.cat(sqr_diffs, dim=0)[:], dim=0, keepdim=True) 255 | return uncert 256 | 257 | def calc_error(tar, img_list): 258 | sqr_diffs = [] 259 | for i in range(len(img_list)): 260 | sqr_diff = (img_list[i] - tar)**2 261 | sqr_diffs.append(sqr_diff) 262 | uncert = torch.mean(torch.cat(sqr_diffs, dim=0)[:], dim=0, keepdim=True) 263 | return uncert 264 | 265 | def get_mc_preds_w_errors(net, inputs, target, mc_iter: int = 25): 266 | """Convenience fn. 
for MC integration for uncertainty estimation. 267 | Args: 268 | net: DIP model (can be standard, MFVI or MCDropout) 269 | inputs: input to net 270 | mc_iter: number of MC samples 271 | post_processor: process output of net before computing loss (e.g. downsampler in SR) 272 | mask: multiply output and target by mask before computing loss (for inpainting) 273 | """ 274 | img_list = [] 275 | flow_list = [] 276 | MSE = nn.MSELoss() 277 | err = [] 278 | with torch.no_grad(): 279 | for _ in range(mc_iter): 280 | img, flow = net(inputs) 281 | img_list.append(img) 282 | flow_list.append(flow) 283 | err.append(MSE(img, target).item()) 284 | return img_list, flow_list, err 285 | 286 | def get_diff_mc_preds(net, inputs, mc_iter: int = 25): 287 | """Convenience fn. for MC integration for uncertainty estimation. 288 | Args: 289 | net: DIP model (can be standard, MFVI or MCDropout) 290 | inputs: input to net 291 | mc_iter: number of MC samples 292 | post_processor: process output of net before computing loss (e.g. 
downsampler in SR) 293 | mask: multiply output and target by mask before computing loss (for inpainting) 294 | """ 295 | img_list = [] 296 | flow_list = [] 297 | disp_list = [] 298 | with torch.no_grad(): 299 | for _ in range(mc_iter): 300 | img, _, flow, disp = net(inputs) 301 | img_list.append(img) 302 | flow_list.append(flow) 303 | disp_list.append(disp) 304 | return img_list, flow_list, disp_list 305 | 306 | def uncert_regression_gal(img_list, reduction = 'mean'): 307 | img_list = torch.cat(img_list, dim=0) 308 | mean = img_list[:,:-1].mean(dim=0, keepdim=True) 309 | ale = img_list[:,-1:].mean(dim=0, keepdim=True) 310 | epi = torch.var(img_list[:,:-1], dim=0, keepdim=True) 311 | #if epi.shape[1] == 3: 312 | epi = epi.mean(dim=1, keepdim=True) 313 | uncert = ale + epi 314 | if reduction == 'mean': 315 | return ale.mean().item(), epi.mean().item(), uncert.mean().item() 316 | elif reduction == 'sum': 317 | return ale.sum().item(), epi.sum().item(), uncert.sum().item() 318 | else: 319 | return ale.detach(), epi.detach(), uncert.detach() 320 | 321 | def uceloss(errors, uncert, n_bins=15, outlier=0.0, range=None): 322 | device = errors.device 323 | if range == None: 324 | bin_boundaries = torch.linspace(uncert.min().item(), uncert.max().item(), n_bins + 1, device=device) 325 | else: 326 | bin_boundaries = torch.linspace(range[0], range[1], n_bins + 1, device=device) 327 | bin_lowers = bin_boundaries[:-1] 328 | bin_uppers = bin_boundaries[1:] 329 | 330 | errors_in_bin_list = [] 331 | avg_uncert_in_bin_list = [] 332 | prop_in_bin_list = [] 333 | 334 | uce = torch.zeros(1, device=device) 335 | for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): 336 | # Calculated |uncertainty - error| in each bin 337 | in_bin = uncert.gt(bin_lower.item()) * uncert.le(bin_upper.item()) 338 | prop_in_bin = in_bin.float().mean() # |Bm| / n 339 | prop_in_bin_list.append(prop_in_bin) 340 | if prop_in_bin.item() > outlier: 341 | errors_in_bin = errors[in_bin].float().mean() # err() 342 | 
avg_uncert_in_bin = uncert[in_bin].mean() # uncert() 343 | uce += torch.abs(avg_uncert_in_bin - errors_in_bin) * prop_in_bin 344 | 345 | errors_in_bin_list.append(errors_in_bin) 346 | avg_uncert_in_bin_list.append(avg_uncert_in_bin) 347 | 348 | err_in_bin = torch.tensor(errors_in_bin_list, device=device) 349 | avg_uncert_in_bin = torch.tensor(avg_uncert_in_bin_list, device=device) 350 | prop_in_bin = torch.tensor(prop_in_bin_list, device=device) 351 | 352 | return uce, err_in_bin, avg_uncert_in_bin, prop_in_bin -------------------------------------------------------------------------------- /3D_LessNet_Diff/L2ss_2_Chan_12_Smth_2.0_LR_0.0001_Val/DiceVal_0.75488_Epoch_000000348.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/L2ss_2_Chan_12_Smth_2.0_LR_0.0001_Val/DiceVal_0.75488_Epoch_000000348.pth -------------------------------------------------------------------------------- /3D_LessNet_Diff/L2ss_2_Chan_16_Smth_2.0_LR_0.0001_Val/DiceVal_0.75676_Epoch_000000423.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/L2ss_2_Chan_16_Smth_2.0_LR_0.0001_Val/DiceVal_0.75676_Epoch_000000423.pth -------------------------------------------------------------------------------- /3D_LessNet_Diff/L2ss_2_Chan_24_Smth_2.0_LR_0.0001_Val/DiceVal_0.75726_Epoch_000000408.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/L2ss_2_Chan_24_Smth_2.0_LR_0.0001_Val/DiceVal_0.75726_Epoch_000000408.pth -------------------------------------------------------------------------------- /3D_LessNet_Diff/L2ss_2_Chan_8_Smth_2.0_LR_0.0001_Val/DiceVal_0.74972_Epoch_000000498.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/L2ss_2_Chan_8_Smth_2.0_LR_0.0001_Val/DiceVal_0.74972_Epoch_000000498.pth -------------------------------------------------------------------------------- /3D_LessNet_Diff/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | 7 | 8 | class UNet(nn.Module): 9 | def __init__(self, in_channel, n_classes, start_channel): 10 | self.in_channel = in_channel 11 | self.n_classes = n_classes 12 | self.start_channel = start_channel 13 | 14 | bias_opt = True 15 | 16 | super(UNet, self).__init__() 17 | self.eninput = self.encoder(self.in_channel, self.start_channel * 4, bias=bias_opt) 18 | 19 | self.dc1 = self.encoder(self.start_channel * 4, self.start_channel * 4, kernel_size=3, 20 | stride=1, bias=bias_opt) 21 | self.dc2 = self.encoder(self.start_channel * 4, self.start_channel * 4, kernel_size=3, stride=1, bias=bias_opt) 22 | self.dc3 = self.encoder(self.start_channel * 3 + 6, self.start_channel * 3, kernel_size=3, 23 | stride=1, bias=bias_opt) 24 | self.dc4 = self.encoder(self.start_channel * 3, self.start_channel * 3, kernel_size=3, stride=1, bias=bias_opt) 25 | self.dc5 = self.encoder(self.start_channel * 2 + 6, self.start_channel * 2, kernel_size=3, 26 | stride=1, bias=bias_opt) 27 | self.dc6 = self.encoder(self.start_channel * 2, self.start_channel * 2, kernel_size=3, stride=1, bias=bias_opt) 28 | self.dc7 = self.encoder(self.start_channel * 1 + 2, self.start_channel * 1, kernel_size=3, 29 | stride=1, bias=bias_opt) 30 | self.dc8 = self.encoder(self.start_channel * 1, self.start_channel * 1, kernel_size=3, stride=1, bias=bias_opt) 31 | self.dc9 = self.outputs(self.start_channel * 1, self.n_classes, kernel_size=3, stride=1, 
padding=1, bias=False) 32 | 33 | self.up2 = self.decoder(self.start_channel * 4, self.start_channel * 3) 34 | self.up3 = self.decoder(self.start_channel * 3, self.start_channel * 2) 35 | self.up4 = self.decoder(self.start_channel * 2, self.start_channel * 1) 36 | 37 | self.maxp_1_2 = nn.MaxPool3d(2, stride=2) 38 | self.maxp_1_4 = nn.MaxPool3d(4, stride=4) 39 | self.maxp_1_8 = nn.MaxPool3d(8, stride=8) 40 | self.avgp_1_2 = nn.AvgPool3d(2, stride=2) 41 | self.avgp_1_4 = nn.AvgPool3d(4, stride=4) 42 | self.avgp_1_8 = nn.AvgPool3d(8, stride=8) 43 | 44 | def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, batchnorm=False): 45 | if batchnorm: 46 | layer = nn.Sequential( 47 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 48 | nn.BatchNorm3d(out_channels), 49 | nn.LeakyReLU()) 50 | else: 51 | layer = nn.Sequential( 52 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 53 | nn.LeakyReLU()) 54 | return layer 55 | 56 | def decoder(self, in_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True): 57 | layer = nn.Sequential( 58 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, 59 | padding=padding, output_padding=output_padding, bias=bias), 60 | nn.LeakyReLU()) 61 | return layer 62 | 63 | def outputs(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 64 | bias=False, batchnorm=False): 65 | if batchnorm: 66 | layer = nn.Sequential( 67 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 68 | nn.BatchNorm2d(out_channels), 69 | nn.Tanh()) 70 | else: 71 | layer = nn.Sequential( 72 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 73 | nn.Softsign()) 74 | return layer 75 | 76 | 77 | def forward(self, x, y): 78 | x_avg_1_8 = self.avgp_1_8(x) 79 | y_avg_1_8 = self.avgp_1_8(y) 80 | x_max_1_8 
= self.maxp_1_8(x) 81 | y_max_1_8 = self.maxp_1_8(y) 82 | x_min_1_8 = -self.maxp_1_8(-x) 83 | y_min_1_8 = -self.maxp_1_8(-y) 84 | x_in = torch.cat((x_avg_1_8, y_avg_1_8, x_max_1_8, y_max_1_8, x_min_1_8, y_min_1_8), 1) 85 | d0 = self.eninput(x_in) 86 | 87 | # x_1 = torch.nn.functional.interpolate(x, size=[40, 48], mode='bilinear') 88 | # y_1 = torch.nn.functional.interpolate(y, size=[40, 48], mode='bilinear') 89 | # x_1_minus = torch.nn.functional.interpolate(x-y, size=[40, 48], mode='bilinear') 90 | # y_1_minus = torch.nn.functional.interpolate(y-x, size=[40, 48], mode='bilinear') 91 | # x_1_minus = torch.nn.functional.interpolate(x-y, size=[40, 48], mode='bilinear') 92 | # y_1_minus = torch.nn.functional.interpolate(y-x, size=[40, 48], mode='bilinear') 93 | # e0 = self.ec1(e0) 94 | 95 | # e1 = self.ec2(e0) 96 | # e1 = self.ec3(e1) 97 | 98 | # e2 = self.ec4(e1) 99 | # e2 = self.ec5(e2) 100 | 101 | # e3 = self.ec6(e2) 102 | # e3 = self.ec7(e3) 103 | 104 | # e4 = self.ec8(e3) 105 | # e4 = self.ec9(e4) 106 | 107 | # d0 = torch.cat((self.up1(e4), e3), 1) 108 | 109 | d0 = self.dc1(d0) 110 | d0 = self.dc2(d0) 111 | 112 | d1 = self.up2(d0) 113 | x_avg_1_4 = self.avgp_1_4(x) 114 | y_avg_1_4 = self.avgp_1_4(y) 115 | x_max_1_4 = self.maxp_1_4(x) 116 | y_max_1_4 = self.maxp_1_4(y) 117 | x_min_1_4 = -self.maxp_1_4(-x) 118 | y_min_1_4 = -self.maxp_1_4(-y) 119 | # print(x_avg_1_4.shape) 120 | # print(x_avg_1_4[0,0,20:30,20:30], x_max_1_4[0,0,20:30,20:30], x_min_1_4[0,0,20:30,20:30]) 121 | # assert 0==1 122 | # x_in = torch.cat((), 1) 123 | d1 = torch.cat((d1, x_avg_1_4, y_avg_1_4, x_max_1_4, y_max_1_4, x_min_1_4, y_min_1_4),1) 124 | 125 | d1 = self.dc3(d1) 126 | d1 = self.dc4(d1) 127 | d2 = self.up3(d1) 128 | 129 | 130 | # x_1_2 = torch.nn.functional.interpolate(x, size=[80, 96], mode='bilinear') 131 | # y_1_2 = torch.nn.functional.interpolate(y, size=[80, 96], mode='bilinear') 132 | # x_1_2_ = torch.nn.functional.interpolate(x-y, size=[80, 96], mode='bilinear') 133 | # y_1_2_ = 
torch.nn.functional.interpolate(y-x, size=[80, 96], mode='bilinear') 134 | x_avg_1_2 = self.avgp_1_2(x) 135 | y_avg_1_2 = self.avgp_1_2(y) 136 | x_max_1_2 = self.maxp_1_2(x) 137 | y_max_1_2 = self.maxp_1_2(y) 138 | x_min_1_2 = -self.maxp_1_2(-x) 139 | y_min_1_2 = -self.maxp_1_2(-y) 140 | d2 = torch.cat((d2, x_avg_1_2, y_avg_1_2, x_max_1_2, y_max_1_2, x_min_1_2, y_min_1_2), 1) 141 | 142 | 143 | d2 = self.dc5(d2) 144 | d2 = self.dc6(d2) 145 | d3 = self.up4(d2) 146 | 147 | d3 = torch.cat((d3, x, y), 1) 148 | 149 | d3 = self.dc7(d3) 150 | d3 = self.dc8(d3) 151 | 152 | f_xy = self.dc9(d3) 153 | 154 | 155 | return f_xy 156 | 157 | 158 | class SpatialTransform(nn.Module): 159 | def __init__(self): 160 | super(SpatialTransform, self).__init__() 161 | def forward(self, mov_image, flow, mod = 'bilinear'): 162 | d2, h2, w2 = mov_image.shape[-3:] 163 | 164 | grid_d, grid_h, grid_w = torch.meshgrid([torch.linspace(-1, 1, d2), torch.linspace(-1, 1, h2), torch.linspace(-1, 1, w2)]) 165 | 166 | grid_h = grid_h.to(flow.device).float() 167 | grid_d = grid_d.to(flow.device).float() 168 | grid_w = grid_w.to(flow.device).float() 169 | grid_d = nn.Parameter(grid_d, requires_grad=False) 170 | grid_w = nn.Parameter(grid_w, requires_grad=False) 171 | grid_h = nn.Parameter(grid_h, requires_grad=False) 172 | flow_d = flow[:,:,:,:,0] 173 | flow_h = flow[:,:,:,:,1] 174 | flow_w = flow[:,:,:,:,2] 175 | #Softsign 176 | #disp_d = (grid_d + (flow_d * 2 / d2)).squeeze(1) 177 | #disp_h = (grid_h + (flow_h * 2 / h2)).squeeze(1) 178 | #disp_w = (grid_w + (flow_w * 2 / w2)).squeeze(1) 179 | 180 | # Remove Channel Dimension 181 | disp_d = (grid_d + (flow_d)).squeeze(1) 182 | disp_h = (grid_h + (flow_h)).squeeze(1) 183 | disp_w = (grid_w + (flow_w)).squeeze(1) 184 | sample_grid = torch.stack((disp_w, disp_h, disp_d), 4) # shape (N, D, H, W, 3) 185 | warped = torch.nn.functional.grid_sample(mov_image, sample_grid, mode = mod, align_corners = True) 186 | 187 | return warped 188 | 189 | class 
DiffeomorphicTransform(nn.Module): 190 | def __init__(self, time_step=7): 191 | super(DiffeomorphicTransform, self).__init__() 192 | self.time_step = time_step 193 | 194 | def forward(self, flow): 195 | 196 | # print(flow.shape) 197 | d2, h2, w2 = flow.shape[-3:] 198 | grid_d, grid_h, grid_w = torch.meshgrid([torch.linspace(-1, 1, d2), torch.linspace(-1, 1, h2), torch.linspace(-1, 1, w2)]) 199 | grid_h = grid_h.to(flow.device).float() 200 | grid_d = grid_d.to(flow.device).float() 201 | grid_w = grid_w.to(flow.device).float() 202 | grid_d = nn.Parameter(grid_d, requires_grad=False) 203 | grid_w = nn.Parameter(grid_w, requires_grad=False) 204 | grid_h = nn.Parameter(grid_h, requires_grad=False) 205 | flow = flow / (2 ** self.time_step) 206 | 207 | for i in range(self.time_step): 208 | flow_d = flow[:,0,:,:,:] 209 | flow_h = flow[:,1,:,:,:] 210 | flow_w = flow[:,2,:,:,:] 211 | disp_d = (grid_d + flow_d).squeeze(1) 212 | disp_h = (grid_h + flow_h).squeeze(1) 213 | disp_w = (grid_w + flow_w).squeeze(1) 214 | 215 | deformation = torch.stack((disp_w, disp_h, disp_d), 4) # shape (N, D, H, W, 3) 216 | flow = flow + torch.nn.functional.grid_sample(flow, deformation, mode='bilinear', padding_mode="border", align_corners = True) 217 | return flow 218 | 219 | def smoothloss(y_pred): 220 | #print('smoothloss y_pred.shape ',y_pred.shape) 221 | #[N,3,D,H,W] 222 | d2, h2, w2 = y_pred.shape[-3:] 223 | dy = torch.abs(y_pred[:,:,1:, :, :] - y_pred[:,:, :-1, :, :]) / 2 * d2 224 | dx = torch.abs(y_pred[:,:,:, 1:, :] - y_pred[:,:, :, :-1, :]) / 2 * h2 225 | dz = torch.abs(y_pred[:,:,:, :, 1:] - y_pred[:,:, :, :, :-1]) / 2 * w2 226 | return (torch.mean(dx * dx)+torch.mean(dy*dy)+torch.mean(dz*dz))/3.0 227 | 228 | """ 229 | Normalized local cross-correlation function in Pytorch. Modified from https://github.com/voxelmorph/voxelmorph. 
230 | """ 231 | class NCC(torch.nn.Module): 232 | """ 233 | local (over window) normalized cross correlation 234 | """ 235 | def __init__(self, win= 9, eps=1e-5): 236 | super(NCC, self).__init__() 237 | self.win_raw = win 238 | self.eps = eps 239 | self.win = win 240 | 241 | def forward(self, I, J): 242 | ndims = 3 243 | win_size = self.win_raw 244 | self.win = [self.win_raw] * ndims 245 | 246 | weight_win_size = self.win_raw 247 | weight = torch.ones((1, 1, weight_win_size, weight_win_size, weight_win_size), device=I.device, requires_grad=False) 248 | conv_fn = F.conv3d 249 | 250 | # compute CC squares 251 | I2 = I*I 252 | J2 = J*J 253 | IJ = I*J 254 | 255 | # compute filters 256 | # compute local sums via convolution 257 | I_sum = conv_fn(I, weight, padding=int(win_size/2)) 258 | J_sum = conv_fn(J, weight, padding=int(win_size/2)) 259 | I2_sum = conv_fn(I2, weight, padding=int(win_size/2)) 260 | J2_sum = conv_fn(J2, weight, padding=int(win_size/2)) 261 | IJ_sum = conv_fn(IJ, weight, padding=int(win_size/2)) 262 | 263 | # compute cross correlation 264 | win_size = np.prod(self.win) 265 | u_I = I_sum/win_size 266 | u_J = J_sum/win_size 267 | 268 | cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size 269 | I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size 270 | J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size 271 | 272 | cc = cross * cross / (I_var * J_var + self.eps) 273 | 274 | # return negative cc. 275 | return -1.0 * torch.mean(cc) 276 | 277 | class MSE: 278 | """ 279 | Mean squared error loss. 280 | """ 281 | 282 | def loss(self, y_true, y_pred): 283 | return torch.mean((y_true - y_pred) ** 2) 284 | 285 | class SAD: 286 | """ 287 | Mean squared error loss. 
288 | """ 289 | 290 | def loss(self, y_true, y_pred): 291 | return torch.mean(torch.abs(y_true - y_pred)) 292 | -------------------------------------------------------------------------------- /3D_LessNet_Diff/compute_dsc_jet_from_quantiResult.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv, sys 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import matplotlib.font_manager as font_manager 6 | from scipy.stats import wilcoxon, ttest_rel, ttest_ind 7 | 8 | outstruct = ['Brain-Stem', 'Thalamus', 'Cerebellum-Cortex', 'Cerebral-White-Matter', 'Cerebellum-White-Matter', 'Putamen', 'VentralDC', 'Pallidum', 'Caudate', 'Lateral-Ventricle', 'Hippocampus', 9 | '3rd-Ventricle', '4th-Ventricle', 'Amygdala', 'Cerebral-Cortex', 'CSF', 'choroid-plexus'] 10 | exp_data = np.zeros((len(outstruct), 115)) 11 | stct_i = 0 12 | file_dir = './Quantitative_Results/' 13 | for stct in outstruct: 14 | tar_idx = [] 15 | with open(file_dir+'L2ss_2_Chan_16_Smth_5.0_LR_0.0001_Test_Bilinear.csv', "r") as f: 16 | reader = csv.reader(f, delimiter="\t") 17 | for i, line in enumerate(reader): 18 | if i == 1: 19 | names = line[0].split(',') 20 | idx = 0 21 | for item in names: 22 | if stct in item: 23 | tar_idx.append(idx) 24 | idx += 1 25 | elif i>1: 26 | if line[0].split(',')[1]=='': 27 | continue 28 | val = 0 29 | for lr_i in tar_idx: 30 | vals = line[0].split(',') 31 | val += float(vals[lr_i]) 32 | val = val/len(tar_idx) 33 | exp_data[stct_i, i-2] = val 34 | stct_i+=1 35 | # all_dsc.append(exp_data.mean(axis=0)) 36 | print(exp_data.mean()) 37 | print(exp_data.std()) 38 | my_list = [] 39 | with open(file_dir+'L2ss_2_Chan_16_Smth_5.0_LR_0.0001_Test_Bilinear.csv', newline='') as f: 40 | reader = csv.reader(f) 41 | my_list = [row[-1] for row in reader] 42 | my_list = my_list[2:] 43 | my_list = np.array([float(i) for i in my_list])*100 44 | print('jec_det: {:.3f} +- {:.3f}'.format(my_list.mean(), my_list.std())) 45 
| -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/data/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/data/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/__pycache__/rand.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/data/__pycache__/rand.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/__pycache__/trans.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xi-jia/LessNet/7cc82b99bb210da453d5e43075b94d4d05dc922a/3D_LessNet_Diff/data/__pycache__/trans.cpython-38.pyc -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | import numpy as np 4 | import torch 5 | 6 | M = 2 ** 32 - 1 7 | 8 | 9 | def init_fn(worker): 10 | seed = torch.LongTensor(1).random_().item() 11 | seed = (seed + worker) % M 12 | np.random.seed(seed) 13 | random.seed(seed) 14 | 15 | 16 | def add_mask(x, mask, dim=1): 17 | mask = mask.unsqueeze(dim) 18 | shape = list(x.shape); 19 | shape[dim] += 21 20 | new_x = x.new(*shape).zero_() 21 | new_x = new_x.scatter_(dim, mask, 1.0) 22 | s = [slice(None)] * len(shape) 23 | s[dim] = slice(21, None) 24 | new_x[s] = x 25 | return new_x 26 | 27 | 28 | def sample(x, size): 29 | # https://gist.github.com/yoavram/4134617 30 | i = random.sample(range(x.shape[0]), size) 31 | return torch.tensor(x[i], dtype=torch.int16) 32 | # x = np.random.permutation(x) 33 | # return torch.tensor(x[:size]) 34 | 35 | 36 | def pkload(fname): 37 | with open(fname, 'rb') as f: 38 | return pickle.load(f) 39 | 40 | 41 | _shape = (240, 240, 155) 42 | 43 | 44 | def get_all_coords(stride): 45 | return torch.tensor( 46 | np.stack([v.reshape(-1) for v in 47 | np.meshgrid( 48 | *[stride // 2 + np.arange(0, s, stride) for s in _shape], 49 | indexing='ij')], 50 | -1), dtype=torch.int16) 51 | 52 | 53 | _zero = torch.tensor([0]) 54 | 55 | 56 | def gen_feats(): 57 | x, y, z = 240, 240, 155 58 | feats = np.stack( 59 | np.meshgrid( 60 | np.arange(x), np.arange(y), np.arange(z), 61 | indexing='ij'), -1).astype('float32') 62 | shape = np.array([x, y, z]) 63 | feats -= shape / 2.0 64 | feats /= shape 65 | 66 | return feats -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/datasets.py: 
-------------------------------------------------------------------------------- 1 | import os, glob 2 | import torch, sys 3 | from torch.utils.data import Dataset 4 | from .data_utils import pkload 5 | import matplotlib.pyplot as plt 6 | 7 | import numpy as np 8 | 9 | 10 | class IXIBrainDataset(Dataset): 11 | def __init__(self, data_path, atlas_path, transforms): 12 | self.paths = data_path 13 | self.atlas_path = atlas_path 14 | self.transforms = transforms 15 | 16 | def one_hot(self, img, C): 17 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 18 | for i in range(C): 19 | out[i,...] = img == i 20 | return out 21 | 22 | def __getitem__(self, index): 23 | path = self.paths[index] 24 | x, x_seg = pkload(self.atlas_path) 25 | y, y_seg = pkload(path) 26 | #print(x.shape) 27 | #print(x.shape) 28 | #print(np.unique(y)) 29 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155) 30 | # transforms work with nhwtc 31 | x, y = x[None, ...], y[None, ...] 32 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155) 33 | x,y = self.transforms([x, y]) 34 | #y = self.one_hot(y, 2) 35 | #print(y.shape) 36 | #sys.exit(0) 37 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 38 | y = np.ascontiguousarray(y) 39 | #plt.figure() 40 | #plt.subplot(1, 2, 1) 41 | #plt.imshow(x[0, :, :, 8], cmap='gray') 42 | #plt.subplot(1, 2, 2) 43 | #plt.imshow(y[0, :, :, 8], cmap='gray') 44 | #plt.show() 45 | #sys.exit(0) 46 | #y = np.squeeze(y, axis=0) 47 | x, y = torch.from_numpy(x), torch.from_numpy(y) 48 | return x, y 49 | 50 | def __len__(self): 51 | return len(self.paths) 52 | 53 | 54 | class IXIBrainInferDataset(Dataset): 55 | def __init__(self, data_path, atlas_path, transforms): 56 | self.atlas_path = atlas_path 57 | self.paths = data_path 58 | self.transforms = transforms 59 | 60 | def one_hot(self, img, C): 61 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 62 | for i in range(C): 63 | out[i,...] 
= img == i 64 | return out 65 | 66 | def __getitem__(self, index): 67 | path = self.paths[index] 68 | x, x_seg = pkload(self.atlas_path) 69 | y, y_seg = pkload(path) 70 | x, y = x[None, ...], y[None, ...] 71 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...] 72 | x, x_seg = self.transforms([x, x_seg]) 73 | y, y_seg = self.transforms([y, y_seg]) 74 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 75 | y = np.ascontiguousarray(y) 76 | x_seg = np.ascontiguousarray(x_seg) # [Bsize,channelsHeight,,Width,Depth] 77 | y_seg = np.ascontiguousarray(y_seg) 78 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg) 79 | return x, y, x_seg, y_seg 80 | 81 | def __len__(self): 82 | return len(self.paths) -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/rand.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class Uniform(object): 5 | def __init__(self, a, b): 6 | self.a = a 7 | self.b = b 8 | 9 | def sample(self): 10 | return random.uniform(self.a, self.b) 11 | 12 | 13 | class Gaussian(object): 14 | def __init__(self, mean, std): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def sample(self): 19 | return random.gauss(self.mean, self.std) 20 | 21 | 22 | class Constant(object): 23 | def __init__(self, val): 24 | self.val = val 25 | 26 | def sample(self): 27 | return self.val -------------------------------------------------------------------------------- /3D_LessNet_Diff/data/trans.py: -------------------------------------------------------------------------------- 1 | # import math 2 | import random 3 | import collections 4 | import numpy as np 5 | import torch, sys, random, math 6 | from scipy import ndimage 7 | 8 | from .rand import Constant, Uniform, Gaussian 9 | from scipy.ndimage import rotate 10 | from skimage.transform import rescale, resize 11 | 12 | class Base(object): 13 | 
def sample(self, *shape): 14 | return shape 15 | 16 | def tf(self, img, k=0): 17 | return img 18 | 19 | def __call__(self, img, dim=3, reuse=False): # class -> func() 20 | # image: nhwtc 21 | # shape: no first dim 22 | if not reuse: 23 | im = img if isinstance(img, np.ndarray) else img[0] 24 | # how to know if the last dim is channel?? 25 | # nhwtc vs nhwt?? 26 | shape = im.shape[1:dim+1] 27 | # print(dim,shape) # 3, (240,240,155) 28 | self.sample(*shape) 29 | 30 | if isinstance(img, collections.Sequence): 31 | return [self.tf(x, k) for k, x in enumerate(img)] # img:k=0,label:k=1 32 | 33 | return self.tf(img) 34 | 35 | def __str__(self): 36 | return 'Identity()' 37 | 38 | Identity = Base 39 | 40 | # gemetric transformations, need a buffers 41 | # first axis is N 42 | class Rot90(Base): 43 | def __init__(self, axes=(0, 1)): 44 | self.axes = axes 45 | 46 | for a in self.axes: 47 | assert a > 0 48 | 49 | def sample(self, *shape): 50 | shape = list(shape) 51 | i, j = self.axes 52 | 53 | # shape: no first dim 54 | i, j = i-1, j-1 55 | shape[i], shape[j] = shape[j], shape[i] 56 | 57 | return shape 58 | 59 | def tf(self, img, k=0): 60 | return np.rot90(img, axes=self.axes) 61 | 62 | def __str__(self): 63 | return 'Rot90(axes=({}, {})'.format(*self.axes) 64 | 65 | # class RandomRotion(Base): 66 | # def __init__(self, angle=20):# angle :in degress, float, [0,360] 67 | # assert angle >= 0.0 68 | # self.axes = (0,1) # 只对HW方向进行旋转 69 | # self.angle = angle # 70 | # self.buffer = None 71 | # 72 | # def sample(self, *shape):# shape : [H,W,D] 73 | # shape = list(shape) 74 | # self.buffer = round(np.random.uniform(low=-self.angle,high=self.angle),2) # 2个小数点 75 | # if self.buffer < 0: 76 | # self.buffer += 180 77 | # return shape 78 | # 79 | # def tf(self, img, k=0): # img shape [1,H,W,D,c] while label shape is [1,H,W,D] 80 | # return ndimage.rotate(img, angle=self.buffer, reshape=False) 81 | # 82 | # def __str__(self): 83 | # return 'RandomRotion(axes=({}, 
{}),Angle:{}'.format(*self.axes,self.buffer) 84 | 85 | class RandomRotion(Base): 86 | def __init__(self,angle_spectrum=10): 87 | assert isinstance(angle_spectrum,int) 88 | # axes = [(2, 1), (3, 1),(3, 2)] 89 | axes = [(1, 0), (2, 1),(2, 0)] 90 | self.angle_spectrum = angle_spectrum 91 | self.axes = axes 92 | 93 | def sample(self,*shape): 94 | self.axes_buffer = self.axes[np.random.choice(list(range(len(self.axes))))] # choose the random direction 95 | self.angle_buffer = np.random.randint(-self.angle_spectrum, self.angle_spectrum) # choose the random direction 96 | return list(shape) 97 | 98 | def tf(self, img, k=0): 99 | """ Introduction: The rotation function supports the shape [H,W,D,C] or shape [H,W,D] 100 | :param img: if x, shape is [1,H,W,D,c]; if label, shape is [1,H,W,D] 101 | :param k: if x, k=0; if label, k=1 102 | """ 103 | bsize = img.shape[0] 104 | 105 | for bs in range(bsize): 106 | if k == 0: 107 | # [[H,W,D], ...] 108 | # print(img.shape) # (1, 128, 128, 128, 4) 109 | channels = [rotate(img[bs,:,:,:,c], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1) for c in 110 | range(img.shape[4])] 111 | img[bs,...] = np.stack(channels, axis=-1) 112 | 113 | if k == 1: 114 | img[bs,...] = rotate(img[bs,...], self.angle_buffer, axes=self.axes_buffer, reshape=False, order=0, mode='constant', cval=-1) 115 | 116 | return img 117 | 118 | def __str__(self): 119 | return 'RandomRotion(axes={},Angle:{}'.format(self.axes_buffer,self.angle_buffer) 120 | 121 | 122 | class Flip(Base): 123 | def __init__(self, axis=0): 124 | self.axis = axis 125 | 126 | def tf(self, img, k=0): 127 | return np.flip(img, self.axis) 128 | 129 | def __str__(self): 130 | return 'Flip(axis={})'.format(self.axis) 131 | 132 | class RandomFlip(Base): 133 | # mirror flip across all x,y,z 134 | def __init__(self,axis=0): 135 | # assert axis == (1,2,3) # For both data and label, it has to specify the axis. 
136 | self.axis = (1,2,3) 137 | self.x_buffer = None 138 | self.y_buffer = None 139 | self.z_buffer = None 140 | 141 | def sample(self, *shape): 142 | self.x_buffer = np.random.choice([True,False]) 143 | self.y_buffer = np.random.choice([True,False]) 144 | self.z_buffer = np.random.choice([True,False]) 145 | return list(shape) # the shape is not changed 146 | 147 | def tf(self,img,k=0): # img shape is (1, 240, 240, 155, 4) 148 | if self.x_buffer: 149 | img = np.flip(img,axis=self.axis[0]) 150 | if self.y_buffer: 151 | img = np.flip(img,axis=self.axis[1]) 152 | if self.z_buffer: 153 | img = np.flip(img,axis=self.axis[2]) 154 | return img 155 | 156 | 157 | class RandSelect(Base): 158 | def __init__(self, prob=0.5, tf=None): 159 | self.prob = prob 160 | self.ops = tf if isinstance(tf, collections.Sequence) else (tf, ) 161 | self.buff = False 162 | 163 | def sample(self, *shape): 164 | self.buff = random.random() < self.prob 165 | 166 | if self.buff: 167 | for op in self.ops: 168 | shape = op.sample(*shape) 169 | 170 | return shape 171 | 172 | def tf(self, img, k=0): 173 | if self.buff: 174 | for op in self.ops: 175 | img = op.tf(img, k) 176 | return img 177 | 178 | def __str__(self): 179 | if len(self.ops) == 1: 180 | ops = str(self.ops[0]) 181 | else: 182 | ops = '[{}]'.format(', '.join([str(op) for op in self.ops])) 183 | return 'RandSelect({}, {})'.format(self.prob, ops) 184 | 185 | 186 | class CenterCrop(Base): 187 | def __init__(self, size): 188 | self.size = size 189 | self.buffer = None 190 | 191 | def sample(self, *shape): 192 | size = self.size 193 | start = [(s -size)//2 for s in shape] 194 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start] 195 | return [size] * len(shape) 196 | 197 | def tf(self, img, k=0): 198 | # print(img.shape)#(1, 240, 240, 155, 4) 199 | return img[tuple(self.buffer)] 200 | # return img[self.buffer] 201 | 202 | def __str__(self): 203 | return 'CenterCrop({})'.format(self.size) 204 | 205 | class 
CenterCropBySize(CenterCrop): 206 | def sample(self, *shape): 207 | assert len(self.size) == 3 # random crop [H,W,T] from img [240,240,155] 208 | if not isinstance(self.size, list): 209 | size = list(self.size) 210 | else: 211 | size = self.size 212 | start = [(s-i)//2 for i, s in zip(size, shape)] 213 | self.buffer = [slice(None)] + [slice(s, s+i) for i, s in zip(size, start)] 214 | return size 215 | 216 | def __str__(self): 217 | return 'CenterCropBySize({})'.format(self.size) 218 | 219 | class RandCrop(CenterCrop): 220 | def sample(self, *shape): 221 | size = self.size 222 | start = [random.randint(0, s-size) for s in shape] 223 | self.buffer = [slice(None)] + [slice(s, s+size) for s in start] 224 | return [size]*len(shape) 225 | 226 | def __str__(self): 227 | return 'RandCrop({})'.format(self.size) 228 | 229 | 230 | class RandCrop3D(CenterCrop): 231 | def sample(self, *shape): # shape : [240,240,155] 232 | assert len(self.size)==3 # random crop [H,W,T] from img [240,240,155] 233 | if not isinstance(self.size,list): 234 | size = list(self.size) 235 | else: 236 | size = self.size 237 | start = [random.randint(0, s-i) for i,s in zip(size,shape)] 238 | self.buffer = [slice(None)] + [slice(s, s+k) for s,k in zip(start,size)] 239 | return size 240 | 241 | def __str__(self): 242 | return 'RandCrop({})'.format(self.size) 243 | 244 | # for data only 245 | class RandomIntensityChange(Base): 246 | def __init__(self,factor): 247 | shift,scale = factor 248 | assert (shift >0) and (scale >0) 249 | self.shift = shift 250 | self.scale = scale 251 | 252 | def tf(self,img,k=0): 253 | if k==1: 254 | return img 255 | 256 | shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,img.shape[1],1,1,img.shape[4]]) # [-0.1,+0.1] 257 | scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,img.shape[1],1,1,img.shape[4]]) # [0.9,1.1) 258 | # shift_factor = np.random.uniform(-self.shift,self.shift,size=[1,1,1,img.shape[3],img.shape[4]]) # [-0.1,+0.1] 259 | # 
scale_factor = np.random.uniform(1.0 - self.scale, 1.0 + self.scale,size=[1,1,1,img.shape[3],img.shape[4]]) # [0.9,1.1) 260 | return img * scale_factor + shift_factor 261 | 262 | def __str__(self): 263 | return 'random intensity shift per channels on the input image, including' 264 | 265 | class RandomGammaCorrection(Base): 266 | def __init__(self,factor): 267 | lower, upper = factor 268 | assert (lower >0) and (upper >0) 269 | self.lower = lower 270 | self.upper = upper 271 | 272 | def tf(self,img,k=0): 273 | if k==1: 274 | return img 275 | img = img + np.min(img) 276 | img_max = np.max(img) 277 | img = img/img_max 278 | factor = random.choice(np.arange(self.lower, self.upper, 0.1)) 279 | gamma = random.choice([1, factor]) 280 | if gamma == 1: 281 | return img 282 | img = img ** gamma * img_max 283 | img = (img - img.mean())/img.std() 284 | return img 285 | 286 | def __str__(self): 287 | return 'random intensity shift per channels on the input image, including' 288 | 289 | class MinMax_norm(Base): 290 | def __init__(self, ): 291 | a = None 292 | 293 | def tf(self, img, k=0): 294 | if k == 1: 295 | return img 296 | img = (img - img.min()) / (img.max()-img.min()) 297 | return img 298 | 299 | class Seg_norm(Base): 300 | def __init__(self, ): 301 | a = None 302 | self.seg_table = np.array([0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 303 | 28, 30, 31, 41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62, 304 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255]) 305 | def tf(self, img, k=0): 306 | if k == 0: 307 | return img 308 | img_out = np.zeros_like(img) 309 | for i in range(len(self.seg_table)): 310 | img_out[img == self.seg_table[i]] = i 311 | return img_out 312 | 313 | class Resize_img(Base): 314 | def __init__(self, shape): 315 | self.shape = shape 316 | 317 | def tf(self, img, k=0): 318 | if k == 1: 319 | img = resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]), 320 | anti_aliasing=False, order=0) 321 | else: 322 | img 
= resize(img, (img.shape[0], self.shape[0], self.shape[1], self.shape[2]), 323 | anti_aliasing=False, order=3) 324 | return img 325 | 326 | class Pad(Base): 327 | def __init__(self, pad): # [0,0,0,5,0] 328 | self.pad = pad 329 | self.px = tuple(zip([0]*len(pad), pad)) 330 | 331 | def sample(self, *shape): 332 | 333 | shape = list(shape) 334 | 335 | # shape: no first dim 336 | for i in range(len(shape)): 337 | shape[i] += self.pad[i+1] 338 | 339 | return shape 340 | 341 | def tf(self, img, k=0): 342 | #nhwtc, nhwt 343 | dim = len(img.shape) 344 | return np.pad(img, self.px[:dim], mode='constant') 345 | 346 | def __str__(self): 347 | return 'Pad(({}, {}, {}))'.format(*self.pad) 348 | 349 | class Pad3DIfNeeded(Base): 350 | def __init__(self, shape, value=0, mask_value=0): # [0,0,0,5,0] 351 | self.shape = shape 352 | self.value = value 353 | self.mask_value = mask_value 354 | 355 | def tf(self, img, k=0): 356 | pad = [(0,0)] 357 | if k==0: 358 | img_shape = img.shape[1:-1] 359 | else: 360 | img_shape = img.shape[1:] 361 | for i, t in zip(img_shape, self.shape): 362 | if i < t: 363 | diff = t-i 364 | pad.append((math.ceil(diff/2),math.floor(diff/2))) 365 | else: 366 | pad.append((0,0)) 367 | if k == 0: 368 | pad.append((0,0)) 369 | pad = tuple(pad) 370 | if k==0: 371 | return np.pad(img, pad, mode='constant', constant_values=img.min()) 372 | else: 373 | return np.pad(img, pad, mode='constant', constant_values=self.mask_value) 374 | 375 | def __str__(self): 376 | return 'Pad(({}, {}, {}))'.format(*self.pad) 377 | 378 | class Noise(Base): 379 | def __init__(self, dim, sigma=0.1, channel=True, num=-1): 380 | self.dim = dim 381 | self.sigma = sigma 382 | self.channel = channel 383 | self.num = num 384 | 385 | def tf(self, img, k=0): 386 | if self.num > 0 and k >= self.num: 387 | return img 388 | 389 | if self.channel: 390 | #nhwtc, hwtc, hwt 391 | shape = [1] if len(img.shape) < self.dim+2 else [img.shape[-1]] 392 | else: 393 | shape = img.shape 394 | return img * 
np.exp(self.sigma * torch.randn(shape, dtype=torch.float32).numpy()) 395 | 396 | def __str__(self): 397 | return 'Noise()' 398 | 399 | 400 | # dim could come from shape 401 | class GaussianBlur(Base): 402 | def __init__(self, dim, sigma=Constant(1.5), app=-1): 403 | # 1.5 pixel 404 | self.dim = dim 405 | self.sigma = sigma 406 | self.eps = 0.001 407 | self.app = app 408 | 409 | def tf(self, img, k=0): 410 | if self.num > 0 and k >= self.num: 411 | return img 412 | 413 | # image is nhwtc 414 | for n in range(img.shape[0]): 415 | sig = self.sigma.sample() 416 | # sample each channel saperately to avoid correlations 417 | if sig > self.eps: 418 | if len(img.shape) == self.dim+2: 419 | C = img.shape[-1] 420 | for c in range(C): 421 | img[n,..., c] = ndimage.gaussian_filter(img[n, ..., c], sig) 422 | elif len(img.shape) == self.dim+1: 423 | img[n] = ndimage.gaussian_filter(img[n], sig) 424 | else: 425 | raise ValueError('image shape is not supported') 426 | 427 | return img 428 | 429 | def __str__(self): 430 | return 'GaussianBlur()' 431 | 432 | 433 | class ToNumpy(Base): 434 | def __init__(self, num=-1): 435 | self.num = num 436 | 437 | def tf(self, img, k=0): 438 | if self.num > 0 and k >= self.num: 439 | return img 440 | return img.numpy() 441 | 442 | def __str__(self): 443 | return 'ToNumpy()' 444 | 445 | 446 | class ToTensor(Base): 447 | def __init__(self, num=-1): 448 | self.num = num 449 | 450 | def tf(self, img, k=0): 451 | if self.num > 0 and k >= self.num: 452 | return img 453 | 454 | return torch.from_numpy(img) 455 | 456 | def __str__(self): 457 | return 'ToTensor' 458 | 459 | 460 | class TensorType(Base): 461 | def __init__(self, types, num=-1): 462 | self.types = types # ('torch.float32', 'torch.int64') 463 | self.num = num 464 | 465 | def tf(self, img, k=0): 466 | if self.num > 0 and k >= self.num: 467 | return img 468 | # make this work with both Tensor and Numpy 469 | return img.type(self.types[k]) 470 | 471 | def __str__(self): 472 | s = ', 
'.join([str(s) for s in self.types]) 473 | return 'TensorType(({}))'.format(s) 474 | 475 | 476 | class NumpyType(Base): 477 | def __init__(self, types, num=-1): 478 | self.types = types # ('float32', 'int64') 479 | self.num = num 480 | 481 | def tf(self, img, k=0): 482 | if self.num > 0 and k >= self.num: 483 | return img 484 | # make this work with both Tensor and Numpy 485 | return img.astype(self.types[k]) 486 | 487 | def __str__(self): 488 | s = ', '.join([str(s) for s in self.types]) 489 | return 'NumpyType(({}))'.format(s) 490 | 491 | 492 | class Normalize(Base): 493 | def __init__(self, mean=0.0, std=1.0, num=-1): 494 | self.mean = mean 495 | self.std = std 496 | self.num = num 497 | 498 | def tf(self, img, k=0): 499 | if self.num > 0 and k >= self.num: 500 | return img 501 | img -= self.mean 502 | img /= self.std 503 | return img 504 | 505 | def __str__(self): 506 | return 'Normalize()' 507 | 508 | 509 | class Compose(Base): 510 | def __init__(self, ops): 511 | if not isinstance(ops, collections.Sequence): 512 | ops = ops, 513 | self.ops = ops 514 | 515 | def sample(self, *shape): 516 | for op in self.ops: 517 | shape = op.sample(*shape) 518 | 519 | def tf(self, img, k=0): 520 | #is_tensor = isinstance(img, torch.Tensor) 521 | #if is_tensor: 522 | # img = img.numpy() 523 | 524 | for op in self.ops: 525 | # print(op,img.shape,k) 526 | img = op.tf(img, k) # do not use op(img) here 527 | 528 | #if is_tensor: 529 | # img = np.ascontiguousarray(img) 530 | # img = torch.from_numpy(img) 531 | 532 | return img 533 | 534 | def __str__(self): 535 | ops = ', '.join([str(op) for op in self.ops]) 536 | return 'Compose([{}])'.format(ops) -------------------------------------------------------------------------------- /3D_LessNet_Diff/infer_bilinear.py: -------------------------------------------------------------------------------- 1 | import os, utils 2 | import glob 3 | import sys 4 | from argparse import ArgumentParser 5 | import numpy as np 6 | import 
torch.nn.functional as F 7 | import torch 8 | from torchvision import transforms 9 | from Models import * 10 | # from Functions import TrainDataset 11 | import torch.utils.data as Data 12 | from data import datasets, trans 13 | from natsort import natsorted 14 | from torch.utils.data import DataLoader 15 | import csv 16 | parser = ArgumentParser() 17 | parser.add_argument("--lr", type=float, 18 | dest="lr", default=1e-4, help="learning rate") 19 | parser.add_argument("--bs", type=int, 20 | dest="bs", default=1, help="batch_size") 21 | parser.add_argument("--iteration", type=int, 22 | dest="iteration", default=320001, 23 | help="number of total iterations") 24 | parser.add_argument("--smth_labda", type=float, 25 | dest="smth_labda", default=0.02, 26 | help="labda loss: suggested range 0.1 to 10") 27 | parser.add_argument("--checkpoint", type=int, 28 | dest="checkpoint", default=403, 29 | help="frequency of saving models") 30 | parser.add_argument("--start_channel", type=int, 31 | dest="start_channel", default=8, 32 | help="number of start channels") 33 | parser.add_argument("--trainingset", type=int, 34 | dest="trainingset", default=4, 35 | help="1 Half : 200 Images, 2 The other Half 200 Images 3 All 400 Images") 36 | parser.add_argument("--using_l2", type=int, 37 | dest="using_l2", 38 | default=1, 39 | help="using l2 or not") 40 | opt = parser.parse_args() 41 | 42 | lr = opt.lr 43 | bs = opt.bs 44 | iteration = opt.iteration 45 | start_channel = opt.start_channel 46 | n_checkpoint = opt.checkpoint 47 | smooth = opt.smth_labda 48 | trainingset = opt.trainingset 49 | using_l2 = opt.using_l2 50 | 51 | 52 | def main(): 53 | use_cuda = True 54 | device = torch.device("cuda" if use_cuda else "cpu") 55 | transform = SpatialTransform().cuda() 56 | diff_transform = DiffeomorphicTransform(time_step=7).cuda() 57 | atlas_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/atlas.pkl' 58 | test_dir = 
'/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/Test/' 59 | model_idx = -1 60 | model_dir = './L2ss_{}_Chan_{}_Smth_{}_LR_{}_Val/'.format(using_l2,start_channel,smooth,lr) 61 | dict = utils.process_label() 62 | if not os.path.exists('Quantitative_Results/'): 63 | os.makedirs('Quantitative_Results/') 64 | if os.path.exists('Quantitative_Results/'+model_dir[:-1]+'_Test.csv'): 65 | os.remove('Quantitative_Results/'+model_dir[:-1]+'_Test.csv') 66 | csv_writter(model_dir[:-1], 'Quantitative_Results/' + model_dir[:-1]+'_Test') 67 | line = '' 68 | for i in range(46): 69 | line = line + ',' + dict[i] 70 | csv_writter(line +','+'non_jec', 'Quantitative_Results/' + model_dir[:-1]+'_Test') 71 | 72 | 73 | model = UNet(6, 3, start_channel).cuda() 74 | 75 | print(model_dir + natsorted(os.listdir(model_dir))[model_idx]) 76 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[model_idx])#['state_dict'] 77 | model.load_state_dict(best_model) 78 | model.cuda() 79 | # reg_model = utils.register_model(config.img_size, 'nearest') 80 | # reg_model.cuda() 81 | test_composed = transforms.Compose([trans.Seg_norm(), 82 | trans.NumpyType((np.float32, np.int16)), 83 | ]) 84 | test_set = datasets.IXIBrainInferDataset(glob.glob(test_dir + '*.pkl'), atlas_dir, transforms=test_composed) 85 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=True) 86 | eval_dsc_def = utils.AverageMeter() 87 | eval_dsc_raw = utils.AverageMeter() 88 | eval_det = utils.AverageMeter() 89 | with torch.no_grad(): 90 | stdy_idx = 0 91 | for data in test_loader: 92 | model.eval() 93 | data = [t.cuda() for t in data] 94 | x = data[0] 95 | y = data[1] 96 | x_seg = data[2] 97 | y_seg = data[3] 98 | 99 | v_xy = model(x.float().to(device), y.float().to(device)) 100 | Dv_xy = diff_transform(v_xy) 101 | # Dv_xy = v_xy 102 | # def_out = reg_model([x_seg.cuda().float(), flow.cuda()]) 103 | def_out= transform(x_seg.float().to(device), 
Dv_xy.permute(0, 2, 3, 4, 1), mod = 'nearest') 104 | tar = y.detach().cpu().numpy()[0, 0, :, :, :] 105 | # print(f_xy.shape) #[1, 3, 160, 192, 224] 106 | dd, hh, ww = Dv_xy.shape[-3:] 107 | Dv_xy = Dv_xy.detach().cpu().numpy() 108 | Dv_xy[:,0,:,:,:] = Dv_xy[:,0,:,:,:] * dd / 2 109 | Dv_xy[:,1,:,:,:] = Dv_xy[:,1,:,:,:] * hh / 2 110 | Dv_xy[:,2,:,:,:] = Dv_xy[:,2,:,:,:] * ww / 2 111 | # jac_det = utils.jacobian_determinant_vxm(f_xy.detach().cpu().numpy()[0, :, :, :, :]) 112 | jac_det = utils.jacobian_determinant_vxm(Dv_xy[0, :, :, :, :]) 113 | line = utils.dice_val_substruct(def_out.long(), y_seg.long(), stdy_idx) 114 | line = line +','+str(np.sum(jac_det <= 0)/np.prod(tar.shape)) 115 | csv_writter(line, 'Quantitative_Results/' + model_dir[:-1]+'_Test') 116 | eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0)) 117 | print('det < 0: {}'.format(np.sum(jac_det <= 0) / np.prod(tar.shape))) 118 | dsc_trans = utils.dice_val(def_out.long(), y_seg.long(), 46) 119 | dsc_raw = utils.dice_val(x_seg.long(), y_seg.long(), 46) 120 | print('Trans dsc: {:.4f}, Raw dsc: {:.4f}'.format(dsc_trans.item(),dsc_raw.item())) 121 | eval_dsc_def.update(dsc_trans.item(), x.size(0)) 122 | eval_dsc_raw.update(dsc_raw.item(), x.size(0)) 123 | stdy_idx += 1 124 | 125 | print('Deformed DSC: {:.3f} +- {:.3f}, Affine DSC: {:.3f} +- {:.3f}'.format(eval_dsc_def.avg, 126 | eval_dsc_def.std, 127 | eval_dsc_raw.avg, 128 | eval_dsc_raw.std)) 129 | print('deformed det: {}, std: {}'.format(eval_det.avg, eval_det.std)) 130 | 131 | def csv_writter(line, name): 132 | with open(name+'.csv', 'a') as file: 133 | file.write(line) 134 | file.write('\n') 135 | 136 | if __name__ == '__main__': 137 | ''' 138 | GPU configuration 139 | ''' 140 | # GPU_iden = 1 141 | # GPU_num = torch.cuda.device_count() 142 | # print('Number of GPU: ' + str(GPU_num)) 143 | # for GPU_idx in range(GPU_num): 144 | # GPU_name = torch.cuda.get_device_name(GPU_idx) 145 | # print(' GPU #' + str(GPU_idx) + ': ' + GPU_name) 146 
| # torch.cuda.set_device(GPU_iden) 147 | # GPU_avai = torch.cuda.is_available() 148 | # print('Currently using: ' + torch.cuda.get_device_name(GPU_iden)) 149 | # print('If the GPU is available? ' + str(GPU_avai)) 150 | main() 151 | -------------------------------------------------------------------------------- /3D_LessNet_Diff/label_info.txt: -------------------------------------------------------------------------------- 1 | 0 Unknown 0 0 0 0 2 | 1 Left-Cerebral-Exterior 70 130 180 0 3 | 2 Left-Cerebral-White-Matter 245 245 245 0 4 | 3 Left-Cerebral-Cortex 205 62 78 0 5 | 4 Left-Lateral-Ventricle 120 18 134 0 6 | 5 Left-Inf-Lat-Vent 196 58 250 0 7 | 6 Left-Cerebellum-Exterior 0 148 0 0 8 | 7 Left-Cerebellum-White-Matter 220 248 164 0 9 | 8 Left-Cerebellum-Cortex 230 148 34 0 10 | 9 Left-Thalamus 0 118 14 0 11 | 10 Left-Thalamus-Proper* 0 118 14 0 12 | 11 Left-Caudate 122 186 220 0 13 | 12 Left-Putamen 236 13 176 0 14 | 13 Left-Pallidum 12 48 255 0 15 | 14 3rd-Ventricle 204 182 142 0 16 | 15 4th-Ventricle 42 204 164 0 17 | 16 Brain-Stem 119 159 176 0 18 | 17 Left-Hippocampus 220 216 20 0 19 | 18 Left-Amygdala 103 255 255 0 20 | 19 Left-Insula 80 196 98 0 21 | 20 Left-Operculum 60 58 210 0 22 | 21 Line-1 60 58 210 0 23 | 22 Line-2 60 58 210 0 24 | 23 Line-3 60 58 210 0 25 | 24 CSF 60 60 60 0 26 | 25 Left-Lesion 255 165 0 0 27 | 26 Left-Accumbens-area 255 165 0 0 28 | 27 Left-Substancia-Nigra 0 255 127 0 29 | 28 Left-VentralDC 165 42 42 0 30 | 29 Left-undetermined 135 206 235 0 31 | 30 Left-vessel 160 32 240 0 32 | 31 Left-choroid-plexus 0 200 200 0 33 | 32 Left-F3orb 100 50 100 0 34 | 33 Left-lOg 135 50 74 0 35 | 34 Left-aOg 122 135 50 0 36 | 35 Left-mOg 51 50 135 0 37 | 36 Left-pOg 74 155 60 0 38 | 37 Left-Stellate 120 62 43 0 39 | 38 Left-Porg 74 155 60 0 40 | 39 Left-Aorg 122 135 50 0 41 | 40 Right-Cerebral-Exterior 70 130 180 0 42 | 41 Right-Cerebral-White-Matter 245 245 245 0 43 | 42 Right-Cerebral-Cortex 205 62 78 0 44 | 43 Right-Lateral-Ventricle 120 18 134 
0 45 | 44 Right-Inf-Lat-Vent 196 58 250 0 46 | 45 Right-Cerebellum-Exterior 0 148 0 0 47 | 46 Right-Cerebellum-White-Matter 220 248 164 0 48 | 47 Right-Cerebellum-Cortex 230 148 34 0 49 | 48 Right-Thalamus 0 118 14 0 50 | 49 Right-Thalamus-Proper* 0 118 14 0 51 | 50 Right-Caudate 122 186 220 0 52 | 51 Right-Putamen 236 13 176 0 53 | 52 Right-Pallidum 13 48 255 0 54 | 53 Right-Hippocampus 220 216 20 0 55 | 54 Right-Amygdala 103 255 255 0 56 | 55 Right-Insula 80 196 98 0 57 | 56 Right-Operculum 60 58 210 0 58 | 57 Right-Lesion 255 165 0 0 59 | 58 Right-Accumbens-area 255 165 0 0 60 | 59 Right-Substancia-Nigra 0 255 127 0 61 | 60 Right-VentralDC 165 42 42 0 62 | 61 Right-undetermined 135 206 235 0 63 | 62 Right-vessel 160 32 240 0 64 | 63 Right-choroid-plexus 0 200 221 0 65 | 64 Right-F3orb 100 50 100 0 66 | 65 Right-lOg 135 50 74 0 67 | 66 Right-aOg 122 135 50 0 68 | 67 Right-mOg 51 50 135 0 69 | 68 Right-pOg 74 155 60 0 70 | 69 Right-Stellate 120 62 43 0 71 | 70 Right-Porg 74 155 60 0 72 | 71 Right-Aorg 122 135 50 0 73 | 72 5th-Ventricle 120 190 150 0 74 | 73 Left-Interior 122 135 50 0 75 | 74 Right-Interior 122 135 50 0 76 | 77 | 77 WM-hypointensities 200 70 255 0 78 | 78 Left-WM-hypointensities 255 148 10 0 79 | 79 Right-WM-hypointensities 255 148 10 0 80 | 80 non-WM-hypointensities 164 108 226 0 81 | 81 Left-non-WM-hypointensities 164 108 226 0 82 | 82 Right-non-WM-hypointensities 164 108 226 0 83 | 83 Left-F1 255 218 185 0 84 | 84 Right-F1 255 218 185 0 85 | 85 Optic-Chiasm 234 169 30 0 86 | 192 Corpus_Callosum 250 255 50 0 87 | 88 | 86 Left_future_WMSA 200 120 255 0 89 | 87 Right_future_WMSA 200 121 255 0 90 | 88 future_WMSA 200 122 255 0 91 | 92 | 93 | 96 Left-Amygdala-Anterior 205 10 125 0 94 | 97 Right-Amygdala-Anterior 205 10 125 0 95 | 98 Dura 160 32 240 0 96 | 97 | 100 Left-wm-intensity-abnormality 124 140 178 0 98 | 101 Left-caudate-intensity-abnormality 125 140 178 0 99 | 102 Left-putamen-intensity-abnormality 126 140 178 0 100 | 103 
Left-accumbens-intensity-abnormality 127 140 178 0 101 | 104 Left-pallidum-intensity-abnormality 124 141 178 0 102 | 105 Left-amygdala-intensity-abnormality 124 142 178 0 103 | 106 Left-hippocampus-intensity-abnormality 124 143 178 0 104 | 107 Left-thalamus-intensity-abnormality 124 144 178 0 105 | 108 Left-VDC-intensity-abnormality 124 140 179 0 106 | 109 Right-wm-intensity-abnormality 124 140 178 0 107 | 110 Right-caudate-intensity-abnormality 125 140 178 0 108 | 111 Right-putamen-intensity-abnormality 126 140 178 0 109 | 112 Right-accumbens-intensity-abnormality 127 140 178 0 110 | 113 Right-pallidum-intensity-abnormality 124 141 178 0 111 | 114 Right-amygdala-intensity-abnormality 124 142 178 0 112 | 115 Right-hippocampus-intensity-abnormality 124 143 178 0 113 | 116 Right-thalamus-intensity-abnormality 124 144 178 0 114 | 117 Right-VDC-intensity-abnormality 124 140 179 0 115 | 116 | 118 Epidermis 255 20 147 0 117 | 119 Conn-Tissue 205 179 139 0 118 | 120 SC-Fat-Muscle 238 238 209 0 119 | 121 Cranium 200 200 200 0 120 | 122 CSF-SA 74 255 74 0 121 | 123 Muscle 238 0 0 0 122 | 124 Ear 0 0 139 0 123 | 125 Adipose 173 255 47 0 124 | 126 Spinal-Cord 133 203 229 0 125 | 127 Soft-Tissue 26 237 57 0 126 | 128 Nerve 34 139 34 0 127 | 129 Bone 30 144 255 0 128 | 130 Air 147 19 173 0 129 | 131 Orbital-Fat 238 59 59 0 130 | 132 Tongue 221 39 200 0 131 | 133 Nasal-Structures 238 174 238 0 132 | 134 Globe 255 0 0 0 133 | 135 Teeth 72 61 139 0 134 | 136 Left-Caudate-Putamen 21 39 132 0 135 | 137 Right-Caudate-Putamen 21 39 132 0 136 | 138 Left-Claustrum 65 135 20 0 137 | 139 Right-Claustrum 65 135 20 0 138 | 140 Cornea 134 4 160 0 139 | 142 Diploe 221 226 68 0 140 | 143 Vitreous-Humor 255 255 254 0 141 | 144 Lens 52 209 226 0 142 | 145 Aqueous-Humor 239 160 223 0 143 | 146 Outer-Table 70 130 180 0 144 | 147 Inner-Table 70 130 181 0 145 | 148 Periosteum 139 121 94 0 146 | 149 Endosteum 224 224 224 0 147 | 150 R-C-S 255 0 0 0 148 | 151 Iris 205 205 0 0 149 | 152 
SC-Adipose-Muscle 238 238 209 0 150 | 153 SC-Tissue 139 121 94 0 151 | 154 Orbital-Adipose 238 59 59 0 152 | 153 | 155 Left-IntCapsule-Ant 238 59 59 0 154 | 156 Right-IntCapsule-Ant 238 59 59 0 155 | 157 Left-IntCapsule-Pos 62 10 205 0 156 | 158 Right-IntCapsule-Pos 62 10 205 0 157 | 158 | # These labels are for babies/children 159 | 159 Left-Cerebral-WM-unmyelinated 0 118 14 0 160 | 160 Right-Cerebral-WM-unmyelinated 0 118 14 0 161 | 161 Left-Cerebral-WM-myelinated 220 216 21 0 162 | 162 Right-Cerebral-WM-myelinated 220 216 21 0 163 | 163 Left-Subcortical-Gray-Matter 122 186 220 0 164 | 164 Right-Subcortical-Gray-Matter 122 186 220 0 165 | 165 Skull 120 120 120 0 166 | 166 Posterior-fossa 14 48 255 0 167 | 167 Scalp 166 42 42 0 168 | 168 Hematoma 121 18 134 0 169 | 169 Left-Basal-Ganglia 236 13 127 0 170 | 176 Right-Basal-Ganglia 236 13 126 0 171 | 172 | # Label names and colors for Brainstem consituents 173 | # No. Label Name: R G B A 174 | 170 brainstem 119 159 176 0 175 | 171 DCG 119 0 176 0 176 | 172 Vermis 119 100 176 0 177 | 173 Midbrain 242 104 76 0 178 | 174 Pons 206 195 58 0 179 | 175 Medulla 119 159 176 0 180 | 177 Vermis-White-Matter 119 50 176 0 181 | 178 SCP 142 182 0 0 182 | 179 Floculus 19 100 176 0 183 | 184 | 180 Left-Cortical-Dysplasia 73 61 139 0 185 | 181 Right-Cortical-Dysplasia 73 62 139 0 186 | 182 CblumNodulus 10 100 176 0 187 | 188 | 193 Left-hippocampal_fissure 0 196 255 0 189 | 194 Left-CADG-head 255 164 164 0 190 | 195 Left-subiculum 196 196 0 0 191 | 196 Left-fimbria 0 100 255 0 192 | 197 Right-hippocampal_fissure 128 196 164 0 193 | 198 Right-CADG-head 0 126 75 0 194 | 199 Right-subiculum 128 96 64 0 195 | 200 Right-fimbria 0 50 128 0 196 | 201 alveus 255 204 153 0 197 | 202 perforant_pathway 255 128 128 0 198 | 203 parasubiculum 255 255 0 0 199 | 204 presubiculum 64 0 64 0 200 | 205 subiculum 0 0 255 0 201 | 206 CA1 255 0 0 0 202 | 207 CA2 128 128 255 0 203 | 208 CA3 0 128 0 0 204 | 209 CA4 196 160 128 0 205 | 210 GC-DG 32 200 255 0 
206 | 211 HATA 128 255 128 0 207 | 212 fimbria 204 153 204 0 208 | 213 lateral_ventricle 121 17 136 0 209 | 214 molecular_layer_HP 128 0 0 0 210 | 215 hippocampal_fissure 128 32 255 0 211 | 216 entorhinal_cortex 255 204 102 0 212 | 217 molecular_layer_subiculum 128 128 128 0 213 | 218 Amygdala 104 255 255 0 214 | 219 Cerebral_White_Matter 0 226 0 0 215 | 220 Cerebral_Cortex 205 63 78 0 216 | 221 Inf_Lat_Vent 197 58 250 0 217 | 222 Perirhinal 33 150 250 0 218 | 223 Cerebral_White_Matter_Edge 226 0 0 0 219 | 224 Background 100 100 100 0 220 | 225 Ectorhinal 197 150 250 0 221 | 226 HP_tail 170 170 255 0 222 | 223 | 250 Fornix 255 0 0 0 224 | 251 CC_Posterior 0 0 64 0 225 | 252 CC_Mid_Posterior 0 0 112 0 226 | 253 CC_Central 0 0 160 0 227 | 254 CC_Mid_Anterior 0 0 208 0 228 | 255 CC_Anterior 0 0 255 0 229 | -------------------------------------------------------------------------------- /3D_LessNet_Diff/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | from argparse import ArgumentParser 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import torch 8 | from torchvision import transforms 9 | from Models import * 10 | # from Functions import TrainDataset 11 | import torch.utils.data as Data 12 | from data import datasets, trans 13 | from natsort import natsorted 14 | import csv 15 | parser = ArgumentParser() 16 | parser.add_argument("--lr", type=float, 17 | dest="lr", default=1e-4, help="learning rate") 18 | parser.add_argument("--bs", type=int, 19 | dest="bs", default=1, help="batch_size") 20 | parser.add_argument("--iteration", type=int, 21 | dest="iteration", default=320001, 22 | help="number of total iterations") 23 | parser.add_argument("--smth_labda", type=float, 24 | dest="smth_labda", default=0.02, 25 | help="labda loss: suggested range 0.1 to 10") 26 | parser.add_argument("--checkpoint", type=int, 27 | dest="checkpoint", default=403, 28 | help="frequency of 
saving models") 29 | parser.add_argument("--start_channel", type=int, 30 | dest="start_channel", default=8, 31 | help="number of start channels") 32 | parser.add_argument("--trainingset", type=int, 33 | dest="trainingset", default=4, 34 | help="1 Half : 200 Images, 2 The other Half 200 Images 3 All 400 Images") 35 | parser.add_argument("--using_l2", type=int, 36 | dest="using_l2", 37 | default=1, 38 | help="using l2 or not") 39 | opt = parser.parse_args() 40 | 41 | lr = opt.lr 42 | bs = opt.bs 43 | iteration = opt.iteration 44 | start_channel = opt.start_channel 45 | n_checkpoint = opt.checkpoint 46 | smooth = opt.smth_labda 47 | trainingset = opt.trainingset 48 | using_l2 = opt.using_l2 49 | 50 | def dice(pred1, truth1): 51 | VOI_lbls = [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 34, 36] 52 | dice_35=np.zeros(len(VOI_lbls)) 53 | index = 0 54 | for k in VOI_lbls: 55 | #print(k) 56 | truth = truth1 == k 57 | pred = pred1 == k 58 | intersection = np.sum(pred * truth) * 2.0 59 | # print(intersection) 60 | dice_35[index]=intersection / (np.sum(pred) + np.sum(truth)) 61 | index = index + 1 62 | return np.mean(dice_35) 63 | 64 | def save_checkpoint(state, save_dir, save_filename, max_model_num=10): 65 | torch.save(state, save_dir + save_filename) 66 | model_lists = natsorted(glob.glob(save_dir + '*')) 67 | # print(model_lists) 68 | while len(model_lists) > max_model_num: 69 | os.remove(model_lists[0]) 70 | model_lists = natsorted(glob.glob(save_dir + '*')) 71 | 72 | def train(): 73 | use_cuda = True 74 | device = torch.device("cuda" if use_cuda else "cpu") 75 | 76 | atlas_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/atlas.pkl' 77 | train_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/Train/' 78 | val_dir = '/bask/projects/d/duanj-ai-imaging/UvT/TransMorph_Xi/IXI_Mine/IXI_data/Val/' 79 | train_composed = transforms.Compose([trans.RandomFlip(0), 80 | 
trans.NumpyType((np.float32, np.float32)), 81 | ]) 82 | 83 | # train_composed = transforms.Compose([trans.NumpyType((np.float32, np.float32)), 84 | # ]) 85 | 86 | val_composed = transforms.Compose([trans.Seg_norm(), #rearrange segmentation label to 1 to 46 87 | trans.NumpyType((np.float32, np.int16))]) 88 | train_set = datasets.IXIBrainDataset(glob.glob(train_dir + '*.pkl'), atlas_dir, transforms=train_composed) 89 | val_set = datasets.IXIBrainInferDataset(glob.glob(val_dir + '*.pkl'), atlas_dir, transforms=val_composed) 90 | train_loader = Data.DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=4, pin_memory=True) 91 | val_loader = Data.DataLoader(val_set, batch_size=bs, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) 92 | 93 | 94 | model = UNet(6, 3, start_channel).to(device) 95 | if using_l2 == 1: 96 | loss_similarity = MSE().loss 97 | elif using_l2 == 0: 98 | loss_similarity = SAD().loss 99 | elif using_l2 == 2: 100 | loss_similarity = NCC() 101 | loss_smooth = smoothloss 102 | 103 | transform = SpatialTransform().to(device) 104 | diff_transform = DiffeomorphicTransform(time_step=7).to(device) 105 | 106 | 107 | for param in transform.parameters(): 108 | param.requires_grad = False 109 | param.volatile = True 110 | 111 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 112 | model_dir = './L2ss_{}_Chan_{}_Smth_{}_LR_{}_Val/'.format(using_l2,start_channel,smooth,lr) 113 | csv_name = 'L2ss_{}_Chan_{}_Smth_{}_LR_{}.csv'.format(using_l2,start_channel,smooth,lr) 114 | assert os.path.exists(csv_name) ==0 115 | assert os.path.isdir(model_dir) ==0 116 | f = open(csv_name, 'w') 117 | with f: 118 | fnames = ['Index','Dice'] 119 | writer = csv.DictWriter(f, fieldnames=fnames) 120 | writer.writeheader() 121 | 122 | if not os.path.isdir(model_dir): 123 | os.mkdir(model_dir) 124 | 125 | lossall = np.zeros((3, iteration)) 126 | step = 1 127 | epoch = 0 128 | while step <= iteration: 129 | for X, Y in train_loader: 130 | 131 | X = 
X.to(device).float() 132 | Y = Y.to(device).float() 133 | 134 | f_xy = model(X, Y) 135 | D_f_xy = diff_transform(f_xy) 136 | # D_f_xy = f_xy 137 | X_Y = transform(X, D_f_xy.permute(0, 2, 3, 4, 1)) 138 | 139 | loss1 = loss_similarity(Y, X_Y) 140 | loss5 = loss_smooth(f_xy) 141 | loss = loss1 + smooth * loss5 142 | 143 | optimizer.zero_grad() 144 | loss.backward() 145 | optimizer.step() 146 | 147 | lossall[:,step] = np.array([loss.item(),loss1.item(),loss5.item()]) 148 | sys.stdout.write("\r" + 'step "{0}" -> training loss "{1:.4f}" - sim "{2:.4f}" -smo "{3:.4f}" '.format(step, loss.item(),loss1.item(),loss5.item())) 149 | sys.stdout.flush() 150 | 151 | if (step % n_checkpoint == 0) or (step == 1): 152 | with torch.no_grad(): 153 | Dices_Validation = [] 154 | for data in val_loader: 155 | model.eval() 156 | xv = data[0] 157 | yv = data[1] 158 | xv_seg = data[2] 159 | yv_seg = data[3] 160 | vf_xy = model(xv.float().to(device), yv.float().to(device)) 161 | D_vf_xy = diff_transform(vf_xy) 162 | warped_xv_seg= transform(xv_seg.float().to(device), D_vf_xy.permute(0, 2, 3, 4, 1), mod = 'nearest') 163 | for bs_index in range(bs): 164 | dice_bs=dice(warped_xv_seg[bs_index,...].data.cpu().numpy().copy(),yv_seg[bs_index,...].data.cpu().numpy().copy()) 165 | Dices_Validation.append(dice_bs) 166 | modelname = 'DiceVal_{:.4f}_Epoch_{:04d}.pth'.format(np.mean(Dices_Validation), epoch) 167 | f = open(csv_name, 'a') 168 | with f: 169 | writer = csv.writer(f) 170 | writer.writerow([epoch, np.mean(Dices_Validation)]) 171 | save_checkpoint(model.state_dict(), model_dir, modelname) 172 | # modelname = 'Epoch_{:09d}.pth'.format(epoch) 173 | # torch.save(model.state_dict(), model_dir + modelname) 174 | np.save(model_dir + 'Loss.npy', lossall) 175 | step += 1 176 | 177 | if step > iteration: 178 | break 179 | print("one epoch pass") 180 | epoch = epoch + 1 181 | np.save(model_dir + '/Loss.npy', lossall) 182 | 183 | train() 184 | 
-------------------------------------------------------------------------------- /3D_LessNet_Diff/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch, sys 5 | from torch import nn 6 | import pystrum.pynd.ndutils as nd 7 | from scipy.ndimage import gaussian_filter 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | self.vals = [] 20 | self.std = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | self.vals.append(val) 28 | self.std = np.std(self.vals) 29 | 30 | def pad_image(img, target_size): 31 | rows_to_pad = max(target_size[0] - img.shape[2], 0) 32 | cols_to_pad = max(target_size[1] - img.shape[3], 0) 33 | slcs_to_pad = max(target_size[2] - img.shape[4], 0) 34 | padded_img = F.pad(img, (0, slcs_to_pad, 0, cols_to_pad, 0, rows_to_pad), "constant", 0) 35 | return padded_img 36 | 37 | class SpatialTransformer(nn.Module): 38 | """ 39 | N-D Spatial Transformer 40 | """ 41 | 42 | def __init__(self, size, mode='bilinear'): 43 | super().__init__() 44 | 45 | self.mode = mode 46 | 47 | # create sampling grid 48 | vectors = [torch.arange(0, s) for s in size] 49 | grids = torch.meshgrid(vectors) 50 | grid = torch.stack(grids) 51 | grid = torch.unsqueeze(grid, 0) 52 | grid = grid.type(torch.FloatTensor).cuda() 53 | 54 | # registering the grid as a buffer cleanly moves it to the GPU, but it also 55 | # adds it to the state dict. this is annoying since everything in the state dict 56 | # is included when saving weights to disk, so the model files are way bigger 57 | # than they need to be. so far, there does not appear to be an elegant solution. 
58 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict 59 | self.register_buffer('grid', grid) 60 | 61 | def forward(self, src, flow): 62 | # new locations 63 | new_locs = self.grid + flow 64 | shape = flow.shape[2:] 65 | 66 | # need to normalize grid values to [-1, 1] for resampler 67 | for i in range(len(shape)): 68 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5) 69 | 70 | # move channels dim to last position 71 | # also not sure why, but the channels need to be reversed 72 | if len(shape) == 2: 73 | new_locs = new_locs.permute(0, 2, 3, 1) 74 | new_locs = new_locs[..., [1, 0]] 75 | elif len(shape) == 3: 76 | new_locs = new_locs.permute(0, 2, 3, 4, 1) 77 | new_locs = new_locs[..., [2, 1, 0]] 78 | 79 | return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode) 80 | 81 | class register_model(nn.Module): 82 | def __init__(self, img_size=(64, 256, 256), mode='bilinear'): 83 | super(register_model, self).__init__() 84 | self.spatial_trans = SpatialTransformer(img_size, mode) 85 | 86 | def forward(self, x): 87 | img = x[0].cuda() 88 | flow = x[1].cuda() 89 | out = self.spatial_trans(img, flow) 90 | return out 91 | 92 | def dice_val(y_pred, y_true, num_clus): 93 | y_pred = nn.functional.one_hot(y_pred, num_classes=num_clus) 94 | y_pred = torch.squeeze(y_pred, 1) 95 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() 96 | y_true = nn.functional.one_hot(y_true, num_classes=num_clus) 97 | y_true = torch.squeeze(y_true, 1) 98 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous() 99 | intersection = y_pred * y_true 100 | intersection = intersection.sum(dim=[2, 3, 4]) 101 | union = y_pred.sum(dim=[2, 3, 4]) + y_true.sum(dim=[2, 3, 4]) 102 | dsc = (2.*intersection) / (union + 1e-5) 103 | return torch.mean(torch.mean(dsc, dim=1)) 104 | 105 | def dice_val_VOI(y_pred, y_true): 106 | VOI_lbls = [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 34, 
36] 107 | pred = y_pred.detach().cpu().numpy()[0, 0, ...] 108 | true = y_true.detach().cpu().numpy()[0, 0, ...] 109 | DSCs = np.zeros((len(VOI_lbls), 1)) 110 | idx = 0 111 | for i in VOI_lbls: 112 | pred_i = pred == i 113 | true_i = true == i 114 | intersection = pred_i * true_i 115 | intersection = np.sum(intersection) 116 | union = np.sum(pred_i) + np.sum(true_i) 117 | dsc = (2.*intersection) / (union + 1e-5) 118 | DSCs[idx] =dsc 119 | idx += 1 120 | return np.mean(DSCs) 121 | 122 | def jacobian_determinant_vxm(disp): 123 | """ 124 | jacobian determinant of a displacement field. 125 | NB: to compute the spatial gradients, we use np.gradient. 126 | Parameters: 127 | disp: 2D or 3D displacement field of size [*vol_shape, nb_dims], 128 | where vol_shape is of len nb_dims 129 | Returns: 130 | jacobian determinant (scalar) 131 | """ 132 | 133 | # check inputs 134 | disp = disp.transpose(1, 2, 3, 0) 135 | volshape = disp.shape[:-1] 136 | nb_dims = len(volshape) 137 | assert len(volshape) in (2, 3), 'flow has to be 2D or 3D' 138 | 139 | # compute grid 140 | grid_lst = nd.volsize2ndgrid(volshape) 141 | grid = np.stack(grid_lst, len(volshape)) 142 | 143 | # compute gradients 144 | J = np.gradient(disp + grid) 145 | 146 | # 3D glow 147 | if nb_dims == 3: 148 | dx = J[0] 149 | dy = J[1] 150 | dz = J[2] 151 | 152 | # compute jacobian components 153 | Jdet0 = dx[..., 0] * (dy[..., 1] * dz[..., 2] - dy[..., 2] * dz[..., 1]) 154 | Jdet1 = dx[..., 1] * (dy[..., 0] * dz[..., 2] - dy[..., 2] * dz[..., 0]) 155 | Jdet2 = dx[..., 2] * (dy[..., 0] * dz[..., 1] - dy[..., 1] * dz[..., 0]) 156 | 157 | return Jdet0 - Jdet1 + Jdet2 158 | 159 | else: # must be 2 160 | 161 | dfdx = J[0] 162 | dfdy = J[1] 163 | 164 | return dfdx[..., 0] * dfdy[..., 1] - dfdy[..., 0] * dfdx[..., 1] 165 | 166 | import re 167 | def process_label(): 168 | #process labeling information for FreeSurfer 169 | seg_table = [0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 170 | 28, 30, 31, 41, 42, 43, 
44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 62, 171 | 63, 72, 77, 80, 85, 251, 252, 253, 254, 255] 172 | 173 | 174 | file1 = open('label_info.txt', 'r') 175 | Lines = file1.readlines() 176 | dict = {} 177 | seg_i = 0 178 | seg_look_up = [] 179 | for seg_label in seg_table: 180 | for line in Lines: 181 | line = re.sub(' +', ' ',line).split(' ') 182 | try: 183 | int(line[0]) 184 | except: 185 | continue 186 | if int(line[0]) == seg_label: 187 | seg_look_up.append([seg_i, int(line[0]), line[1]]) 188 | dict[seg_i] = line[1] 189 | seg_i += 1 190 | return dict 191 | 192 | def write2csv(line, name): 193 | with open(name+'.csv', 'a') as file: 194 | file.write(line) 195 | file.write('\n') 196 | 197 | def dice_val_substruct(y_pred, y_true, std_idx): 198 | with torch.no_grad(): 199 | y_pred = nn.functional.one_hot(y_pred, num_classes=46) 200 | y_pred = torch.squeeze(y_pred, 1) 201 | y_pred = y_pred.permute(0, 4, 1, 2, 3).contiguous() 202 | y_true = nn.functional.one_hot(y_true, num_classes=46) 203 | y_true = torch.squeeze(y_true, 1) 204 | y_true = y_true.permute(0, 4, 1, 2, 3).contiguous() 205 | y_pred = y_pred.detach().cpu().numpy() 206 | y_true = y_true.detach().cpu().numpy() 207 | 208 | line = 'p_{}'.format(std_idx) 209 | for i in range(46): 210 | pred_clus = y_pred[0, i, ...] 211 | true_clus = y_true[0, i, ...] 
212 | intersection = pred_clus * true_clus 213 | intersection = intersection.sum() 214 | union = pred_clus.sum() + true_clus.sum() 215 | dsc = (2.*intersection) / (union + 1e-5) 216 | line = line+','+str(dsc) 217 | return line 218 | 219 | def dice(y_pred, y_true, ): 220 | intersection = y_pred * y_true 221 | intersection = np.sum(intersection) 222 | union = np.sum(y_pred) + np.sum(y_true) 223 | dsc = (2.*intersection) / (union + 1e-5) 224 | return dsc 225 | 226 | def smooth_seg(binary_img, sigma=1.5, thresh=0.4): 227 | binary_img = gaussian_filter(binary_img.astype(np.float32()), sigma=sigma) 228 | binary_img = binary_img > thresh 229 | return binary_img 230 | 231 | def get_mc_preds(net, inputs, mc_iter: int = 25): 232 | """Convenience fn. for MC integration for uncertainty estimation. 233 | Args: 234 | net: DIP model (can be standard, MFVI or MCDropout) 235 | inputs: input to net 236 | mc_iter: number of MC samples 237 | post_processor: process output of net before computing loss (e.g. downsampler in SR) 238 | mask: multiply output and target by mask before computing loss (for inpainting) 239 | """ 240 | img_list = [] 241 | flow_list = [] 242 | with torch.no_grad(): 243 | for _ in range(mc_iter): 244 | img, flow = net(inputs) 245 | img_list.append(img) 246 | flow_list.append(flow) 247 | return img_list, flow_list 248 | 249 | def calc_uncert(tar, img_list): 250 | sqr_diffs = [] 251 | for i in range(len(img_list)): 252 | sqr_diff = (img_list[i] - tar)**2 253 | sqr_diffs.append(sqr_diff) 254 | uncert = torch.mean(torch.cat(sqr_diffs, dim=0)[:], dim=0, keepdim=True) 255 | return uncert 256 | 257 | def calc_error(tar, img_list): 258 | sqr_diffs = [] 259 | for i in range(len(img_list)): 260 | sqr_diff = (img_list[i] - tar)**2 261 | sqr_diffs.append(sqr_diff) 262 | uncert = torch.mean(torch.cat(sqr_diffs, dim=0)[:], dim=0, keepdim=True) 263 | return uncert 264 | 265 | def get_mc_preds_w_errors(net, inputs, target, mc_iter: int = 25): 266 | """Convenience fn. 
for MC integration for uncertainty estimation. 267 | Args: 268 | net: DIP model (can be standard, MFVI or MCDropout) 269 | inputs: input to net 270 | mc_iter: number of MC samples 271 | post_processor: process output of net before computing loss (e.g. downsampler in SR) 272 | mask: multiply output and target by mask before computing loss (for inpainting) 273 | """ 274 | img_list = [] 275 | flow_list = [] 276 | MSE = nn.MSELoss() 277 | err = [] 278 | with torch.no_grad(): 279 | for _ in range(mc_iter): 280 | img, flow = net(inputs) 281 | img_list.append(img) 282 | flow_list.append(flow) 283 | err.append(MSE(img, target).item()) 284 | return img_list, flow_list, err 285 | 286 | def get_diff_mc_preds(net, inputs, mc_iter: int = 25): 287 | """Convenience fn. for MC integration for uncertainty estimation. 288 | Args: 289 | net: DIP model (can be standard, MFVI or MCDropout) 290 | inputs: input to net 291 | mc_iter: number of MC samples 292 | post_processor: process output of net before computing loss (e.g. 
downsampler in SR) 293 | mask: multiply output and target by mask before computing loss (for inpainting) 294 | """ 295 | img_list = [] 296 | flow_list = [] 297 | disp_list = [] 298 | with torch.no_grad(): 299 | for _ in range(mc_iter): 300 | img, _, flow, disp = net(inputs) 301 | img_list.append(img) 302 | flow_list.append(flow) 303 | disp_list.append(disp) 304 | return img_list, flow_list, disp_list 305 | 306 | def uncert_regression_gal(img_list, reduction = 'mean'): 307 | img_list = torch.cat(img_list, dim=0) 308 | mean = img_list[:,:-1].mean(dim=0, keepdim=True) 309 | ale = img_list[:,-1:].mean(dim=0, keepdim=True) 310 | epi = torch.var(img_list[:,:-1], dim=0, keepdim=True) 311 | #if epi.shape[1] == 3: 312 | epi = epi.mean(dim=1, keepdim=True) 313 | uncert = ale + epi 314 | if reduction == 'mean': 315 | return ale.mean().item(), epi.mean().item(), uncert.mean().item() 316 | elif reduction == 'sum': 317 | return ale.sum().item(), epi.sum().item(), uncert.sum().item() 318 | else: 319 | return ale.detach(), epi.detach(), uncert.detach() 320 | 321 | def uceloss(errors, uncert, n_bins=15, outlier=0.0, range=None): 322 | device = errors.device 323 | if range == None: 324 | bin_boundaries = torch.linspace(uncert.min().item(), uncert.max().item(), n_bins + 1, device=device) 325 | else: 326 | bin_boundaries = torch.linspace(range[0], range[1], n_bins + 1, device=device) 327 | bin_lowers = bin_boundaries[:-1] 328 | bin_uppers = bin_boundaries[1:] 329 | 330 | errors_in_bin_list = [] 331 | avg_uncert_in_bin_list = [] 332 | prop_in_bin_list = [] 333 | 334 | uce = torch.zeros(1, device=device) 335 | for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): 336 | # Calculated |uncertainty - error| in each bin 337 | in_bin = uncert.gt(bin_lower.item()) * uncert.le(bin_upper.item()) 338 | prop_in_bin = in_bin.float().mean() # |Bm| / n 339 | prop_in_bin_list.append(prop_in_bin) 340 | if prop_in_bin.item() > outlier: 341 | errors_in_bin = errors[in_bin].float().mean() # err() 342 | 
avg_uncert_in_bin = uncert[in_bin].mean() # uncert() 343 | uce += torch.abs(avg_uncert_in_bin - errors_in_bin) * prop_in_bin 344 | 345 | errors_in_bin_list.append(errors_in_bin) 346 | avg_uncert_in_bin_list.append(avg_uncert_in_bin) 347 | 348 | err_in_bin = torch.tensor(errors_in_bin_list, device=device) 349 | avg_uncert_in_bin = torch.tensor(avg_uncert_in_bin_list, device=device) 350 | prop_in_bin = torch.tensor(prop_in_bin_list, device=device) 351 | 352 | return uce, err_in_bin, avg_uncert_in_bin, prop_in_bin -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training & Testing 2 | 3 | We note that the pre-trained models are provided for easier reproduction of our reported results. 4 | 5 | ``` 6 | CUDA_VISIBLE_DEVICES=0 python train.py --start_channel 8 --using_l2 2 --smth_labda 5.0 --lr 1e-4 --trainingset 4 --checkpoint 403 --iteration 201501 7 | 8 | CUDA_VISIBLE_DEVICES=0 python infer_bilinear.py --start_channel 8 --using_l2 2 --smth_labda 5.0 --lr 1e-4 --trainingset 4 --checkpoint 403 --iteration 201501 9 | 10 | python compute_dsc_jet_from_quantiResult.py 11 | ``` 12 | # Changes 13 | 14 | In the original arXiv draft, I used the notations LessNet_4, 6, 8, 12, and 16 (LessNet_C) to denote different variants of LessNet, assuming a starting channel of 4C that is progressively reduced to 3C, 2C, and C. 15 | However, I recently noticed that in the code implementation, the starting channels were actually set to 8×, 6×, 4×, and 2× a start channel. 16 | To align the notation with the actual implementation, the model names should be updated to LessNet_8, 12, 16, 24, and 32. But the computational results remain unchanged. 
17 | 18 | For instance, the changes in TABLE IV will be: 19 | 20 | | Was | Now | Unchanged | Unchanged | Unchanged | Unchanged | Unchanged | 21 | |-----|-----|:---------:|:---------:|:------:|:-----------:|:------------:| 22 | | C | C | Parameter | Mult-Adds | Memory | Dice | J<0 | 23 | | 4 | 8 | 44,336 | 152 | 10.31 | 0.749±0.040 | 0.808±0.389 | 24 | | 6 | 12 | 96,264 | 327 | 15.23 | 0.757±0.040 | 0.749±0.376 | 25 | | 8 | 16 | 168,032 | 570 | 20.16 | 0.761±0.039 | 0.742±0.353 | 26 | | 12 | 24 | 371,088 | 1250 | 30.00 | 0.766±0.039 | 0.852±0.392 | 27 | | 16 | 32 | 653,504 | 2200 | 39.84 | 0.768±0.039 | 0.830±0.397 | 28 | 29 | **Only the naming was incorrect; the values reported are correct.** 30 | I have updated the code accordingly to reflect these changes. 31 | 32 | # Acknowledgment 33 | 34 | We note that parts of the code are adapted from [IC-Net](https://github.com/zhangjun001/ICNet), [SYM-Net](https://github.com/cwmok/Fast-Symmetric-Diffeomorphic-Image-Registration-with-Convolutional-Neural-Networks), and [TransMorph](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration) (in chronological order of publication). 35 | 36 | The preprocessed IXI data can be found in [TransMorph](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration). 37 | --------------------------------------------------------------------------------