├── data
│   ├── __init__.py
│   ├── trans.py
│   └── datasets.py
├── makePklDataset.py
├── losses.py
├── README.md
├── infer.py
├── utils.py
├── train.py
└── models.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
--------------------------------------------------------------------------------
/data/trans.py:
--------------------------------------------------------------------------------
1 | # import math
2 | import collections.abc
3 | import numpy as np
4 |
5 |
6 | class Base(object):
7 | def sample(self, *shape):
8 | return shape
9 |
10 | def tf(self, img, k=0):
11 | return img
12 |
13 | def __call__(self, img, dim=3, reuse=False):
14 | if not reuse:
15 | im = img if isinstance(img, np.ndarray) else img[0]
16 | shape = im.shape[1:dim+1]
17 | self.sample(*shape)
18 |
19 |         if isinstance(img, collections.abc.Sequence):
20 | return [self.tf(x, k) for k, x in enumerate(img)]
21 |
22 | return self.tf(img)
23 |
24 | def __str__(self):
25 | return 'Identity()'
26 |
27 | class Seg_norm(Base):
28 |     def __init__(self):
29 |         # raw LPBA label values; tf() remaps them to contiguous ids 0..54
30 | self.seg_table = np.array([0, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 61, 62,
31 | 63, 64, 65, 66, 67, 68, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 101, 102, 121, 122, 161, 162,
32 | 163, 164, 165, 166])
33 | def tf(self, img, k=0):
34 | if k == 0:
35 | return img
36 | img_out = np.zeros_like(img)
37 | for i in range(len(self.seg_table)):
38 | img_out[img == self.seg_table[i]] = i
39 | return img_out
40 |
41 |
42 | class NumpyType(Base):
43 | def __init__(self, types, num=-1):
44 | self.types = types # ('float32', 'int64')
45 | self.num = num
46 |
47 | def tf(self, img, k=0):
48 | if self.num > 0 and k >= self.num:
49 | return img
50 | # make this work with both Tensor and Numpy
51 | return img.astype(self.types[k])
52 |
53 | def __str__(self):
54 | s = ', '.join([str(s) for s in self.types])
55 | return 'NumpyType(({}))'.format(s)
56 |
57 |
--------------------------------------------------------------------------------
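
A minimal sketch of how these transforms are composed, mirroring train.py and infer.py: Seg_norm remaps the raw LPBA label values in seg_table to contiguous ids 0-54, and NumpyType casts each element of the (image, segmentation) pair.

    import numpy as np
    from torchvision import transforms
    from data import trans

    composed = transforms.Compose([trans.Seg_norm(),
                                   trans.NumpyType((np.float32, np.int16))])

    img = np.random.rand(1, 160, 192, 160).astype(np.float32)  # [C, H, W, D]
    seg = np.zeros((1, 160, 192, 160), dtype=np.float64)
    img, seg = composed([img, seg])  # Base.__call__ dispatches tf() per element
    print(img.dtype, seg.dtype)      # float32 int16
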
/makePklDataset.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import SimpleITK as sitk
3 | import numpy as np
4 | import glob
5 | from natsort import natsorted
6 | import os
7 |
8 | def pksave(img, label, save_path):
9 | with open(save_path, 'wb') as f:
10 | pickle.dump((img, label), f)
11 |
12 | def nii2arr(nii_img):
13 | return sitk.GetArrayFromImage(sitk.ReadImage(nii_img))
14 |
15 | def center(arr):
16 |     c = np.sort(np.nonzero(arr))[:, [0, -1]]  # per-axis first/last nonzero index
17 | return np.mean(c, axis=-1).astype('int16')
18 |
19 | def minmax(arr):
20 | return (arr-np.min(arr))/(np.max(arr)-np.min(arr))
21 |
22 | def cropByCenter(image,center,final_shape=(160,192,160)):
23 | c = center
24 | crop = np.array([s // 2 for s in final_shape])
25 | # 0 axis
26 | cropmin, cropmax = c[0] - crop[0], c[0] + crop[0]
27 | if cropmin < 0:
28 | cropmin = 0
29 | cropmax = final_shape[0]
30 | if cropmax > image.shape[0]:
31 | cropmax = image.shape[0]
32 | cropmin = image.shape[0] - final_shape[0]
33 | image = image[cropmin:cropmax, :, :]
34 | # 1 axis
35 | cropmin, cropmax = c[1] - crop[1], c[1] + crop[1]
36 | if cropmin < 0:
37 | cropmin = 0
38 | cropmax = final_shape[1]
39 | if cropmax > image.shape[1]:
40 | cropmax = image.shape[1]
41 | cropmin = image.shape[1] - final_shape[1]
42 | image = image[:, cropmin:cropmax, :]
43 |
44 | # 2 axis
45 | cropmin, cropmax = c[2] - crop[2], c[2] + crop[2]
46 | if cropmin < 0:
47 | cropmin = 0
48 | cropmax = final_shape[2]
49 | if cropmax > image.shape[2]:
50 | cropmax = image.shape[2]
51 | cropmin = image.shape[2] - final_shape[2]
52 | image = image[:, :, cropmin:cropmax]
53 | return image
54 |
55 | path_to_LPBA='/data/LPBA40/' # the path of the original dataset
56 | img_niis = natsorted(glob.glob(path_to_LPBA+'*/*/*skullstripped.img.gz'))
57 | label_niis = natsorted(glob.glob(path_to_LPBA+'*/*/*label.img.gz'))
58 | print(img_niis, label_niis)
59 |
60 | save_path = 'LPBA_data/'
61 | if not os.path.exists(save_path):
62 | os.makedirs(save_path)
63 |
64 | for i, nii in enumerate(zip(img_niis, label_niis)):
65 | print(nii)
66 | img_nii, label_nii = nii
67 | img, label = nii2arr(img_nii), nii2arr(label_nii)
68 | print(img.shape, label.shape)
69 |
70 | # crop by center
71 | c = center(img)
72 | img = cropByCenter(img, c)
73 | label = cropByCenter(label, c)
74 |
75 |     # normalize intensities to [0, 1]
76 | img = minmax(img).astype('float32')
77 | label = label.astype('uint16')
78 |     print(img.shape, np.unique(img), img.dtype, label.shape, np.unique(label), label.dtype)
79 | print(save_path+'subject_%02d.pkl'%(i+1))
80 | pksave(img,label, save_path=save_path+'subject_%02d.pkl'%(i+1))
81 |
--------------------------------------------------------------------------------
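
The three per-axis blocks in cropByCenter are identical up to the axis index; a loop-based equivalent could look like the sketch below (same clamped center-crop logic, not code from this repository).

    import numpy as np

    def crop_by_center(image, center, final_shape=(160, 192, 160)):
        # same clamped center-crop logic as cropByCenter above, one axis at a time
        slices = []
        for axis, size in enumerate(final_shape):
            half = size // 2
            lo, hi = center[axis] - half, center[axis] + half
            if lo < 0:                    # clamp at the lower border
                lo, hi = 0, size
            if hi > image.shape[axis]:    # clamp at the upper border
                hi, lo = image.shape[axis], image.shape[axis] - size
            slices.append(slice(lo, hi))
        return image[tuple(slices)]
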
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import math
5 |
6 | class Grad3d(torch.nn.Module):
7 | """
8 |     Smoothness (gradient) loss for a 3-D displacement field.
9 | """
10 |
11 | def __init__(self, penalty='l1', loss_mult=None):
12 | super(Grad3d, self).__init__()
13 | self.penalty = penalty
14 | self.loss_mult = loss_mult
15 |
16 | def forward(self, y_pred, y_true):
17 | dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
18 | dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
19 | dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])
20 |
21 | if self.penalty == 'l2':
22 | dy = dy * dy
23 | dx = dx * dx
24 | dz = dz * dz
25 |
26 | d = torch.mean(dx) + torch.mean(dy) + torch.mean(dz)
27 | grad = d / 3.0
28 |
29 | if self.loss_mult is not None:
30 | grad *= self.loss_mult
31 | return grad
32 |
33 |
34 | class NCC_vxm(torch.nn.Module):
35 | """
36 | Local (over window) normalized cross correlation loss.
37 | """
38 |
39 | def __init__(self, win=None):
40 | super(NCC_vxm, self).__init__()
41 | self.win = win
42 |
43 | def forward(self, y_true, y_pred):
44 |
45 | Ii = y_true
46 | Ji = y_pred
47 |
48 | # get dimension of volume
49 | # assumes Ii, Ji are sized [batch_size, *vol_shape, nb_feats]
50 | ndims = len(list(Ii.size())) - 2
51 | assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims
52 |
53 | # set window size
54 | win = [9] * ndims if self.win is None else self.win
55 |
56 | # compute filters
57 | sum_filt = torch.ones([1, 1, *win]).to("cuda")
58 |
59 | pad_no = math.floor(win[0] / 2)
60 |
61 | if ndims == 1:
62 | stride = (1)
63 | padding = (pad_no)
64 | elif ndims == 2:
65 | stride = (1, 1)
66 | padding = (pad_no, pad_no)
67 | else:
68 | stride = (1, 1, 1)
69 | padding = (pad_no, pad_no, pad_no)
70 |
71 | # get convolution function
72 | conv_fn = getattr(F, 'conv%dd' % ndims)
73 |
74 | # compute CC squares
75 | I2 = Ii * Ii
76 | J2 = Ji * Ji
77 | IJ = Ii * Ji
78 |
79 | I_sum = conv_fn(Ii, sum_filt, stride=stride, padding=padding)
80 | J_sum = conv_fn(Ji, sum_filt, stride=stride, padding=padding)
81 | I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
82 | J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
83 | IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)
84 |
85 | win_size = np.prod(win)
86 | u_I = I_sum / win_size
87 | u_J = J_sum / win_size
88 |
89 | cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
90 | I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
91 | J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
92 |
93 | cc = cross * cross / (I_var * J_var + 1e-5)
94 |
95 | return -torch.mean(cc)
--------------------------------------------------------------------------------
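
A quick smoke test for the two losses (a GPU is assumed, since NCC_vxm hard-codes its summation filter to "cuda"):

    import torch

    if torch.cuda.is_available():
        ncc, reg = NCC_vxm(), Grad3d(penalty='l2')
        fixed = torch.rand(1, 1, 32, 32, 32, device='cuda')
        moved = torch.rand(1, 1, 32, 32, 32, device='cuda')
        flow = torch.zeros(1, 3, 32, 32, 32, device='cuda')
        print(ncc(fixed, moved).item())  # in [-1, 0]; -1 is perfect local correlation
        print(reg(flow, fixed).item())   # 0.0 for a constant field (y_true is unused)
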
/README.md:
--------------------------------------------------------------------------------
1 | # Recursive Deformable Pyramid Network for Unsupervised Medical Image Registration (TMI2024)
2 |
3 | By Haiqiao Wang, Dong Ni, Yi Wang.
4 |
5 | Paper link: [[TMI]](https://ieeexplore.ieee.org/document/10423043)
6 |
7 |
8 |
9 |
10 | ## Description
11 | RDP is an unsupervised deformable registration method for brain MR images. It achieves precise alignment through a purely convolutional pyramid structure and a semantics-infused, progressively recursive inter-level looping strategy that models complex deformations, even when the input images are not pre-aligned.
12 |
13 | 
14 |
15 |
16 | ## Dataset
17 | The public datasets can be accessed at the following official addresses:
18 |
19 | LPBA [[link]](https://resource.loni.usc.edu/resources/atlases-downloads/)
20 |
21 | Mindboggle [[link]](https://osf.io/yhkde/)
22 |
23 | IXI [[link]](https://surfer.nmr.mgh.harvard.edu/pub/data/) [[freesurfer link]](https://surfer.nmr.mgh.harvard.edu/pub/data/ixi/)
24 |
25 | Note that we use the processed IXI dataset provided by FreeSurfer.
26 |
27 | ## Instructions
28 | For convenience, we share the preprocessed [LPBA](https://drive.usercontent.google.com/download?id=1mFzZDn2qPAiP1ByGZ7EbsvEmm6vrS5WO&export=download&authuser=0) dataset used in our experiments. After uncompressing it, simply set "LPBA_path" in `train.py` to the path of the extracted data. You can then run `train.py` to train the network and, once training finishes, run `infer.py` to test its performance.
29 |
30 | ## Citation
31 | If you find the code useful, please cite our paper.
32 | ```
33 | @ARTICLE{10423043,
34 | author={Wang, Haiqiao and Ni, Dong and Wang, Yi},
35 | journal={IEEE Transactions on Medical Imaging},
36 | title={Recursive Deformable Pyramid Network for Unsupervised Medical Image Registration},
37 | year={2024},
38 | volume={43},
39 | number={6},
40 | pages={2229-2240},
41 | keywords={Deformation;Decoding;Feature extraction;Deformable models;Training;Image resolution;Image registration;Deformable image registration;convolutional neural networks;brain MRI},
42 | doi={10.1109/TMI.2024.3362968}}
43 | ```
44 | The overall framework and some network components of the code are heavily based on [TransMorph](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration) and [VoxelMorph](https://github.com/voxelmorph/voxelmorph). We are very grateful for their contributions.
45 |
46 | The file `makePklDataset.py` shows how to build a .pkl dataset from the original LPBA data (a minimal loading sketch follows this README). If you have any other questions about the .pkl format, please refer to the GitHub page of [[TransMorph_on_IXI]](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/main/IXI/TransMorph_on_IXI.md).
47 |
48 | ## Baseline Methods
49 | PyTorch implementations of several baseline methods can be found at [[SmileCode]](https://github.com/ZAX130/SmileCode/tree/main).
50 |
51 | ## How can other datasets be used in this code?
52 | This is a common question; please refer to [ChangeDataset.md](https://github.com/ZAX130/ModeTv2/blob/main/ChangeDataset.md) for more information.
53 |
--------------------------------------------------------------------------------
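
On the .pkl format mentioned in the README: each subject_XX.pkl written by makePklDataset.py is simply a pickled (image, label) tuple. A minimal loading sketch (the path is illustrative):

    import pickle

    with open('LPBA_data/subject_01.pkl', 'rb') as f:  # illustrative path
        img, label = pickle.load(f)
    # img:   float32, min-max normalized to [0, 1], shape (160, 192, 160)
    # label: uint16 LPBA segmentation of the same shape
    print(img.shape, img.dtype, label.shape, label.dtype)
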
/infer.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os, losses, utils
3 | from torch.utils.data import DataLoader
4 | from data import datasets, trans
5 | import numpy as np
6 | import torch
7 | from torchvision import transforms
8 |
9 | import matplotlib.pyplot as plt
10 | from mpl_toolkits.mplot3d import axes3d
11 | from natsort import natsorted
12 | from models import RDP
13 | import random
14 | def same_seeds(seed):
15 | # Python built-in random module
16 | random.seed(seed)
17 | # Numpy
18 | np.random.seed(seed)
19 | # Torch
20 | torch.manual_seed(seed)
21 | if torch.cuda.is_available():
22 | torch.cuda.manual_seed(seed)
23 | torch.cuda.manual_seed_all(seed)
24 | # torch.backends.cudnn.deterministic = True
25 |
26 | same_seeds(24)
27 |
28 | class AverageMeter(object):
29 | """Computes and stores the average and current value"""
30 | def __init__(self):
31 | self.reset()
32 |
33 | def reset(self):
34 | self.val = 0
35 | self.avg = 0
36 | self.sum = 0
37 | self.count = 0
38 | self.vals = []
39 | self.std = 0
40 |
41 | def update(self, val, n=1):
42 | self.val = val
43 | self.sum += val * n
44 | self.count += n
45 | self.avg = self.sum / self.count
46 | self.vals.append(val)
47 | self.std = np.std(self.vals)
48 |
49 | def main():
50 |
51 | val_dir = '/LPBA_path/Val/'
52 | weights = [1, 1] # loss weights
53 | lr = 0.0001
54 | model_idx = -1
55 | model_folder = 'RDP_ncc_{}_reg_{}_lr_{}_54r/'.format(weights[0], weights[1], lr)
56 | model_dir = 'experiments/' + model_folder
57 |
58 | img_size = (160, 192, 160)
59 | model = RDP(img_size, channels=16)
60 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[model_idx])['state_dict']
61 | print('Best model: {}'.format(natsorted(os.listdir(model_dir))[model_idx]))
62 | model.load_state_dict(best_model)
63 | model.cuda()
64 | reg_model = utils.register_model(img_size, 'nearest')
65 | reg_model.cuda()
66 | test_composed = transforms.Compose([trans.Seg_norm(),
67 | trans.NumpyType((np.float32, np.int16)),
68 | ])
69 | test_set = datasets.LPBABrainInferDatasetS2S(glob.glob(val_dir + '*.pkl'), transforms=test_composed)
70 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)
71 | eval_dsc_def = AverageMeter()
72 | eval_dsc_raw = AverageMeter()
73 | eval_det = AverageMeter()
74 | with torch.no_grad():
75 | stdy_idx = 0
76 | for data in test_loader:
77 | model.eval()
78 | data = [t.cuda() for t in data]
79 | x = data[0]
80 | y = data[1]
81 | x_seg = data[2]
82 | y_seg = data[3]
83 |
84 | x_def, flow = model(x,y)
85 | def_out = reg_model([x_seg.cuda().float(), flow.cuda()])
86 | tar = y.detach().cpu().numpy()[0, 0, :, :, :]
87 | jac_det = utils.jacobian_determinant_vxm(flow.detach().cpu().numpy()[0, :, :, :, :])
88 | eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0))
89 | dsc_trans = utils.dice_val_VOI(def_out.long(), y_seg.long())
90 | dsc_raw = utils.dice_val_VOI(x_seg.long(), y_seg.long())
91 | print('Trans dsc: {:.4f}, Raw dsc: {:.4f}'.format(dsc_trans.item(),dsc_raw.item()))
92 | eval_dsc_def.update(dsc_trans.item(), x.size(0))
93 | eval_dsc_raw.update(dsc_raw.item(), x.size(0))
94 | stdy_idx += 1
95 | print('Deformed DSC: {:.3f} +- {:.3f}, Affine DSC: {:.3f} +- {:.3f}'.format(eval_dsc_def.avg,
96 | eval_dsc_def.std,
97 | eval_dsc_raw.avg,
98 | eval_dsc_raw.std))
99 |         print('Fraction of folded voxels (|J|<=0): {}, std: {}'.format(eval_det.avg, eval_det.std))
100 |
101 |
102 | if __name__ == '__main__':
103 | '''
104 | GPU configuration
105 | '''
106 | GPU_iden = 0
107 | GPU_num = torch.cuda.device_count()
108 | print('Number of GPU: ' + str(GPU_num))
109 | for GPU_idx in range(GPU_num):
110 | GPU_name = torch.cuda.get_device_name(GPU_idx)
111 | print(' GPU #' + str(GPU_idx) + ': ' + GPU_name)
112 | torch.cuda.set_device(GPU_iden)
113 | GPU_avai = torch.cuda.is_available()
114 | print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
115 |     print('Is the GPU available? ' + str(GPU_avai))
116 | main()
--------------------------------------------------------------------------------
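
Why model_idx = -1 selects the best checkpoint: train.py names checkpoints 'dsc{:.3f}.pth.tar', so natural sorting orders them by validation DSC and the last entry is the best one. A tiny illustration (the file names are hypothetical):

    from natsort import natsorted

    ckpts = ['dsc0.671.pth.tar', 'dsc0.712.pth.tar', 'dsc0.698.pth.tar']  # hypothetical names
    print(natsorted(ckpts)[-1])  # dsc0.712.pth.tar
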
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torch, sys
5 | from torch import nn
6 | import pystrum.pynd.ndutils as nd
7 |
8 | class AverageMeter(object):
9 | """Computes and stores the average and current value"""
10 | def __init__(self):
11 | self.reset()
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 | self.vals = []
19 | self.std = 0
20 |
21 | def update(self, val, n=1):
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 | self.vals.append(val)
27 | self.std = np.std(self.vals)
28 |
29 |
30 | class SpatialTransformer(nn.Module):
31 | """
32 | N-D Spatial Transformer
33 | """
34 |
35 | def __init__(self, size, mode='bilinear'):
36 | super().__init__()
37 |
38 | self.mode = mode
39 |
40 | # create sampling grid
41 | vectors = [torch.arange(0, s) for s in size]
42 |         grids = torch.meshgrid(vectors)  # newer PyTorch prefers torch.meshgrid(vectors, indexing='ij')
43 | grid = torch.stack(grids)
44 | grid = torch.unsqueeze(grid, 0)
45 | grid = grid.type(torch.FloatTensor).cuda()
46 |
47 | # registering the grid as a buffer cleanly moves it to the GPU, but it also
48 | # adds it to the state dict. this is annoying since everything in the state dict
49 | # is included when saving weights to disk, so the model files are way bigger
50 | # than they need to be. so far, there does not appear to be an elegant solution.
51 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
52 | self.register_buffer('grid', grid)
53 |
54 | def forward(self, src, flow):
55 | # new locations
56 | new_locs = self.grid + flow
57 | shape = flow.shape[2:]
58 |
59 | # need to normalize grid values to [-1, 1] for resampler
60 | for i in range(len(shape)):
61 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
62 |
63 | # move channels dim to last position
64 | # also not sure why, but the channels need to be reversed
65 | if len(shape) == 2:
66 | new_locs = new_locs.permute(0, 2, 3, 1)
67 | new_locs = new_locs[..., [1, 0]]
68 | elif len(shape) == 3:
69 | new_locs = new_locs.permute(0, 2, 3, 4, 1)
70 | new_locs = new_locs[..., [2, 1, 0]]
71 |
72 | return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
73 |
74 | class register_model(nn.Module):
75 | def __init__(self, img_size=(64, 256, 256), mode='bilinear'):
76 | super(register_model, self).__init__()
77 | self.spatial_trans = SpatialTransformer(img_size, mode)
78 |
79 | def forward(self, x):
80 | img = x[0].cuda()
81 | flow = x[1].cuda()
82 | out = self.spatial_trans(img, flow)
83 | return out
84 |
85 |
86 | def dice_val_VOI(y_pred, y_true):
87 | VOI_lbls = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
88 | 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
89 | 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
90 | 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
91 | 48, 49, 50, 51, 52, 53, 54]
92 |
93 | pred = y_pred.detach().cpu().numpy()[0, 0, ...]
94 | true = y_true.detach().cpu().numpy()[0, 0, ...]
95 | DSCs = np.zeros((len(VOI_lbls), 1))
96 | idx = 0
97 | for i in VOI_lbls:
98 | pred_i = pred == i
99 | true_i = true == i
100 | intersection = pred_i * true_i
101 | intersection = np.sum(intersection)
102 | union = np.sum(pred_i) + np.sum(true_i)
103 | dsc = (2.*intersection) / (union + 1e-5)
104 |         DSCs[idx] = dsc
105 | idx += 1
106 | return np.mean(DSCs)
107 |
108 | def jacobian_determinant_vxm(disp):
109 | """
110 | jacobian determinant of a displacement field.
111 | NB: to compute the spatial gradients, we use np.gradient.
112 | Parameters:
113 | disp: 2D or 3D displacement field of size [*vol_shape, nb_dims],
114 | where vol_shape is of len nb_dims
115 | Returns:
116 |         jacobian determinant map over the volume (one value per voxel)
117 | """
118 |
119 | # check inputs
120 | disp = disp.transpose(1, 2, 3, 0)
121 | volshape = disp.shape[:-1]
122 | nb_dims = len(volshape)
123 | assert len(volshape) in (2, 3), 'flow has to be 2D or 3D'
124 |
125 | # compute grid
126 | grid_lst = nd.volsize2ndgrid(volshape)
127 | grid = np.stack(grid_lst, len(volshape))
128 |
129 | # compute gradients
130 | J = np.gradient(disp + grid)
131 |
132 |     # 3D flow
133 | if nb_dims == 3:
134 | dx = J[0]
135 | dy = J[1]
136 | dz = J[2]
137 |
138 | # compute jacobian components
139 | Jdet0 = dx[..., 0] * (dy[..., 1] * dz[..., 2] - dy[..., 2] * dz[..., 1])
140 | Jdet1 = dx[..., 1] * (dy[..., 0] * dz[..., 2] - dy[..., 2] * dz[..., 0])
141 | Jdet2 = dx[..., 2] * (dy[..., 0] * dz[..., 1] - dy[..., 1] * dz[..., 0])
142 |
143 | return Jdet0 - Jdet1 + Jdet2
144 |
145 | else: # must be 2
146 |
147 | dfdx = J[0]
148 | dfdy = J[1]
149 |
150 | return dfdx[..., 0] * dfdy[..., 1] - dfdy[..., 0] * dfdx[..., 1]
151 |
--------------------------------------------------------------------------------
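
A sanity check for jacobian_determinant_vxm: a zero displacement field is the identity map, so the determinant should be 1 everywhere. A short sketch (requires pystrum, as imported above):

    import numpy as np

    disp = np.zeros((3, 8, 8, 8), dtype=np.float32)  # (3, D, H, W), as passed from infer.py
    jd = jacobian_determinant_vxm(disp)
    print(jd.shape, np.allclose(jd, 1.0))            # (8, 8, 8) True
    # infer.py reports folding as the fraction of voxels with jd <= 0
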
/data/datasets.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | import torch, sys
3 | from torch.utils.data import Dataset
4 | import matplotlib.pyplot as plt
5 | import pickle
6 | import numpy as np
7 |
8 | def pkload(fname):
9 | with open(fname, 'rb') as f:
10 | return pickle.load(f)
11 |
12 | class LPBABrainDatasetS2S(Dataset):
13 | def __init__(self, data_path, transforms):
14 | self.paths = data_path
15 | self.transforms = transforms
16 |
17 | def one_hot(self, img, C):
18 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
19 | for i in range(C):
20 | out[i,...] = img == i
21 | return out
22 |
23 | def __getitem__(self, index):
24 | x_index = index // (len(self.paths) - 1)
25 | s = index % (len(self.paths) - 1)
26 | y_index = s + 1 if s >= x_index else s
27 | path_x = self.paths[x_index]
28 | path_y = self.paths[y_index]
29 | x, x_seg = pkload(path_x)
30 | y, y_seg = pkload(path_y)
31 | #print(x.shape)
32 | #print(x.shape)
33 | #print(np.unique(y))
34 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155)
35 | # transforms work with nhwtc
36 | x, y = x[None, ...], y[None, ...]
37 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155)
38 | x,y = self.transforms([x, y])
39 | #y = self.one_hot(y, 2)
40 | #print(y.shape)
41 | #sys.exit(0)
42 |         x = np.ascontiguousarray(x)  # [channels, Height, Width, Depth]; DataLoader adds the batch dim
43 | y = np.ascontiguousarray(y)
44 | #plt.figure()
45 | #plt.subplot(1, 2, 1)
46 | #plt.imshow(x[0, :, :, 8], cmap='gray')
47 | #plt.subplot(1, 2, 2)
48 | #plt.imshow(y[0, :, :, 8], cmap='gray')
49 | #plt.show()
50 | #sys.exit(0)
51 | #y = np.squeeze(y, axis=0)
52 | x, y = torch.from_numpy(x), torch.from_numpy(y)
53 | return x, y
54 |
55 | def __len__(self):
56 | return len(self.paths)*(len(self.paths)-1)
57 |
58 |
59 | class LPBABrainInferDatasetS2S(Dataset):
60 | def __init__(self, data_path, transforms):
61 | self.paths = data_path
62 | self.transforms = transforms
63 |
64 | def one_hot(self, img, C):
65 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
66 | for i in range(C):
67 | out[i,...] = img == i
68 | return out
69 |
70 | def __getitem__(self, index):
71 | x_index = index//(len(self.paths)-1)
72 | s = index%(len(self.paths)-1)
73 | y_index = s+1 if s >= x_index else s
74 | path_x = self.paths[x_index]
75 | path_y = self.paths[y_index]
76 | # print(os.path.basename(path_x), os.path.basename(path_y))
77 | x, x_seg = pkload(path_x)
78 | y, y_seg = pkload(path_y)
79 | x, y = x[None, ...], y[None, ...]
80 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...]
81 | x, x_seg = self.transforms([x, x_seg])
82 | y, y_seg = self.transforms([y, y_seg])
83 |         x = np.ascontiguousarray(x)  # [channels, Height, Width, Depth]; DataLoader adds the batch dim
84 |         y = np.ascontiguousarray(y)
85 |         x_seg = np.ascontiguousarray(x_seg)  # [channels, Height, Width, Depth]
86 |         y_seg = np.ascontiguousarray(y_seg)
87 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg)
88 | return x, y, x_seg, y_seg
89 |
90 | def __len__(self):
91 | return len(self.paths)*(len(self.paths)-1)
92 |
93 |
94 | class LPBABrainHalfDatasetS2S(Dataset):
95 | def __init__(self, data_path, transforms):
96 | self.paths = data_path
97 | self.transforms = transforms
98 |
99 | def one_hot(self, img, C):
100 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
101 | for i in range(C):
102 | out[i,...] = img == i
103 | return out
104 | def half_pair(self,pair):
105 | return pair[0][::2,::2,::2], pair[1][::2,::2,::2]
106 |
107 | def __getitem__(self, index):
108 | x_index = index // (len(self.paths) - 1)
109 | s = index % (len(self.paths) - 1)
110 | y_index = s + 1 if s >= x_index else s
111 | path_x = self.paths[x_index]
112 | path_y = self.paths[y_index]
113 | x, x_seg = self.half_pair(pkload(path_x))
114 | y, y_seg = self.half_pair(pkload(path_y))
115 |
116 | #print(x.shape)
117 | #print(x.shape)
118 | #print(np.unique(y))
119 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155)
120 | # transforms work with nhwtc
121 | x, y = x[None, ...], y[None, ...]
122 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155)
123 | x,y = self.transforms([x, y])
124 | #y = self.one_hot(y, 2)
125 | #print(y.shape)
126 | #sys.exit(0)
127 |         x = np.ascontiguousarray(x)  # [channels, Height, Width, Depth]; DataLoader adds the batch dim
128 | y = np.ascontiguousarray(y)
129 | #plt.figure()
130 | #plt.subplot(1, 2, 1)
131 | #plt.imshow(x[0, :, :, 8], cmap='gray')
132 | #plt.subplot(1, 2, 2)
133 | #plt.imshow(y[0, :, :, 8], cmap='gray')
134 | #plt.show()
135 | #sys.exit(0)
136 | #y = np.squeeze(y, axis=0)
137 | x, y = torch.from_numpy(x), torch.from_numpy(y)
138 | return x, y
139 |
140 | def __len__(self):
141 | return len(self.paths)*(len(self.paths)-1)
142 |
143 |
144 | class LPBABrainHalfInferDatasetS2S(Dataset):
145 | def __init__(self, data_path, transforms):
146 | self.paths = data_path
147 | self.transforms = transforms
148 |
149 | def one_hot(self, img, C):
150 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
151 | for i in range(C):
152 | out[i,...] = img == i
153 | return out
154 | def half_pair(self,pair):
155 | return pair[0][::2,::2,::2], pair[1][::2,::2,::2]
156 | def __getitem__(self, index):
157 | x_index = index//(len(self.paths)-1)
158 | s = index%(len(self.paths)-1)
159 | y_index = s+1 if s >= x_index else s
160 | path_x = self.paths[x_index]
161 | path_y = self.paths[y_index]
162 | # print(os.path.basename(path_x), os.path.basename(path_y))
163 | x, x_seg = self.half_pair(pkload(path_x))
164 | y, y_seg = self.half_pair(pkload(path_y))
165 | x, y = x[None, ...], y[None, ...]
166 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...]
167 | x, x_seg = self.transforms([x, x_seg])
168 | y, y_seg = self.transforms([y, y_seg])
169 |         x = np.ascontiguousarray(x)  # [channels, Height, Width, Depth]; DataLoader adds the batch dim
170 |         y = np.ascontiguousarray(y)
171 |         x_seg = np.ascontiguousarray(x_seg)  # [channels, Height, Width, Depth]
172 |         y_seg = np.ascontiguousarray(y_seg)
173 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg)
174 | return x, y, x_seg, y_seg
175 |
176 | def __len__(self):
177 | return len(self.paths)*(len(self.paths)-1)
--------------------------------------------------------------------------------
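
The __getitem__ arithmetic in these datasets enumerates every ordered subject pair (x, y) with x != y, which is why __len__ is n*(n-1). A standalone check of the same index mapping:

    n = 4  # e.g. four subjects
    pairs = []
    for index in range(n * (n - 1)):
        x_index = index // (n - 1)
        s = index % (n - 1)
        y_index = s + 1 if s >= x_index else s
        pairs.append((x_index, y_index))
    assert len(set(pairs)) == n * (n - 1)
    assert all(a != b for a, b in pairs)  # no subject is paired with itself
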
/train.py:
--------------------------------------------------------------------------------
1 | import glob
2 | # from torch.utils.tensorboard import SummaryWriter
3 | import os, losses, utils
4 | import sys
5 | from torch.utils.data import DataLoader
6 | from data import datasets, trans
7 | import numpy as np
8 | import torch
9 | from torchvision import transforms
10 | from torch import optim
11 | import torch.nn as nn
12 | import matplotlib.pyplot as plt
13 | from natsort import natsorted
14 | from models import RDP
15 | import random
16 | def same_seeds(seed):
17 | # Python built-in random module
18 | random.seed(seed)
19 | # Numpy
20 | np.random.seed(seed)
21 | # Torch
22 | torch.manual_seed(seed)
23 | if torch.cuda.is_available():
24 | torch.cuda.manual_seed(seed)
25 | torch.cuda.manual_seed_all(seed)
26 | torch.backends.cudnn.benchmark = True
27 | # torch.backends.cudnn.deterministic = True
28 |
29 | same_seeds(24)
30 | class Logger(object):
31 | def __init__(self, save_dir):
32 | self.terminal = sys.stdout
33 | self.log = open(save_dir+"logfile.log", "a")
34 |
35 | def write(self, message):
36 | self.terminal.write(message)
37 | self.log.write(message)
38 |
39 | def flush(self):
40 | pass
41 |
42 | def main():
43 | batch_size = 1
44 |
45 | train_dir = '/LPBA_path/Train/'
46 | val_dir = '/LPBA_path/Val/'
47 | weights = [1, 1] # loss weights
48 | lr = 0.0001
49 | save_dir = 'RDP_ncc_{}_reg_{}_lr_{}_54r/'.format(*weights, lr)
50 | if not os.path.exists('experiments/' + save_dir):
51 | os.makedirs('experiments/' + save_dir)
52 | if not os.path.exists('logs/' + save_dir):
53 | os.makedirs('logs/' + save_dir)
54 | sys.stdout = Logger('logs/' + save_dir)
55 | f = open(os.path.join('logs/'+save_dir, 'losses and dice' + ".txt"), "a")
56 |
57 | epoch_start = 0
58 | max_epoch = 30
59 | img_size = (160,192,160)
60 | cont_training = False
61 |
62 | '''
63 | Initialize model
64 | '''
65 | model = RDP(img_size, channels=16)
66 | model.cuda()
67 |
68 | '''
69 | Initialize spatial transformation function
70 | '''
71 | reg_model = utils.register_model(img_size, 'nearest')
72 | reg_model.cuda()
73 |
74 |
75 | '''
76 | If continue from previous training
77 | '''
78 | if cont_training:
79 | model_dir = 'experiments/'+save_dir
80 | updated_lr = round(lr * np.power(1 - (epoch_start) / max_epoch,0.9),8)
81 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[-1])['state_dict']
82 | model.load_state_dict(best_model)
83 | print(model_dir + natsorted(os.listdir(model_dir))[-1])
84 | else:
85 | updated_lr = lr
86 |
87 | '''
88 | Initialize training
89 | '''
90 | train_composed = transforms.Compose([trans.NumpyType((np.float32, np.float32))])
91 |
92 | val_composed = transforms.Compose([trans.Seg_norm(),
93 | trans.NumpyType((np.float32, np.int16))])
94 | train_set = datasets.LPBABrainDatasetS2S(glob.glob(train_dir + '*.pkl'), transforms=train_composed)
95 | val_set = datasets.LPBABrainInferDatasetS2S(glob.glob(val_dir + '*.pkl'), transforms=val_composed)
96 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
97 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
98 |
99 | optimizer = optim.Adam(model.parameters(), lr=updated_lr, weight_decay=0, amsgrad=True)
100 | criterion = losses.NCC_vxm()
101 | criterions = [criterion]
102 | criterions += [losses.Grad3d(penalty='l2')]
103 | best_dsc = 0
104 | # writer = SummaryWriter(log_dir='logs/'+save_dir)
105 | for epoch in range(epoch_start, max_epoch):
106 | print('Training Starts')
107 | '''
108 | Training
109 | '''
110 | loss_all = utils.AverageMeter()
111 | idx = 0
112 | for data in train_loader:
113 | idx += 1
114 | model.train()
115 | adjust_learning_rate(optimizer, epoch, max_epoch, lr)
116 | data = [t.cuda() for t in data]
117 | x = data[0]
118 | y = data[1]
119 |
120 | output = model(x,y)
121 |
122 | loss = 0
123 | loss_vals = []
124 | for n, loss_function in enumerate(criterions):
125 | curr_loss = loss_function(output[n], y) * weights[n]
126 | loss_vals.append(curr_loss)
127 | loss += curr_loss
128 | loss_all.update(loss.item(), y.numel())
129 | optimizer.zero_grad()
130 | loss.backward()
131 | optimizer.step()
132 |
133 | print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}'.format(idx, len(train_loader), loss.item(), loss_vals[0].item(), loss_vals[1].item()))
134 |
135 | print('{} Epoch {} loss {:.4f}'.format(save_dir, epoch, loss_all.avg))
136 | print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg), file=f, end=' ')
137 | '''
138 | Validation
139 | '''
140 | eval_dsc = utils.AverageMeter()
141 | with torch.no_grad():
142 | for data in val_loader:
143 | model.eval()
144 | data = [t.cuda() for t in data]
145 | x = data[0]
146 | y = data[1]
147 | x_seg = data[2]
148 | y_seg = data[3]
149 | output = model(x,y)
150 | def_out = reg_model([x_seg.cuda().float(), output[1].cuda()])
151 | dsc = utils.dice_val_VOI(def_out.long(), y_seg.long())
152 | eval_dsc.update(dsc.item(), x.size(0))
153 | print(epoch, ':',eval_dsc.avg)
154 | best_dsc = max(eval_dsc.avg, best_dsc)
155 | print(eval_dsc.avg, file=f)
156 | save_checkpoint({
157 | 'epoch': epoch + 1,
158 | 'state_dict': model.state_dict(),
159 | 'best_dsc': best_dsc,
160 | 'optimizer': optimizer.state_dict(),
161 | }, save_dir='experiments/' + save_dir, filename='dsc{:.3f}.pth.tar'.format(eval_dsc.avg))
162 | loss_all.reset()
163 |
164 | def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
165 | for param_group in optimizer.param_groups:
166 | param_group['lr'] = round(INIT_LR * np.power(1 - (epoch) / MAX_EPOCHES, power), 8)
167 |
168 |
169 | def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
170 | torch.save(state, save_dir+filename)
171 | model_lists = natsorted(glob.glob(save_dir + '*'))
172 | while len(model_lists) > max_model_num:
173 | os.remove(model_lists[0])
174 | model_lists = natsorted(glob.glob(save_dir + '*'))
175 |
176 | if __name__ == '__main__':
177 |     '''
178 |     GPU configuration
179 |     '''
180 |
181 | GPU_iden = 0
182 | GPU_num = torch.cuda.device_count()
183 | print('Number of GPU: ' + str(GPU_num))
184 | for GPU_idx in range(GPU_num):
185 | GPU_name = torch.cuda.get_device_name(GPU_idx)
186 | print(' GPU #' + str(GPU_idx) + ': ' + GPU_name)
187 | torch.cuda.set_device(GPU_iden)
188 | GPU_avai = torch.cuda.is_available()
189 | print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
190 |     print('Is the GPU available? ' + str(GPU_avai))
191 | main()
--------------------------------------------------------------------------------
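
adjust_learning_rate applies polynomial decay (power 0.9) from INIT_LR toward 0 over max_epoch epochs. With the script's defaults (lr=1e-4, max_epoch=30), the schedule looks like this:

    import numpy as np

    lr, max_epoch, power = 0.0001, 30, 0.9
    for epoch in (0, 10, 20, 29):
        print(epoch, round(lr * np.power(1 - epoch / max_epoch, power), 8))
    # 0: 1e-04, 10: ~6.94e-05, 20: ~3.72e-05, 29: ~4.69e-06
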
/models.py:
--------------------------------------------------------------------------------
1 | '''
2 | VoxelMorph
3 | Original code retrieved from:
4 | https://github.com/voxelmorph/voxelmorph
5 |
6 | Original paper:
7 | Balakrishnan, G., Zhao, A., Sabuncu, M. R., Guttag, J., & Dalca, A. V. (2019).
8 | VoxelMorph: a learning framework for deformable medical image registration.
9 | IEEE transactions on medical imaging, 38(8), 1788-1800.
10 |
11 | Modified and tested by:
12 | Haiqiao Wang
13 | 2110246069@email.szu.edu.cn
14 | Shenzhen University
15 | '''
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as nnf
20 |
21 | import numpy as np
22 | from torch.distributions.normal import Normal
23 |
24 | class SpatialTransformer(nn.Module):
25 | """
26 | N-D Spatial Transformer
27 | """
28 |
29 | def __init__(self, size, mode='bilinear'):
30 | super().__init__()
31 |
32 | self.mode = mode
33 |
34 | # create sampling grid
35 | vectors = [torch.arange(0, s) for s in size]
36 |         grids = torch.meshgrid(vectors)  # newer PyTorch prefers torch.meshgrid(vectors, indexing='ij')
37 | grid = torch.stack(grids)
38 | grid = torch.unsqueeze(grid, 0)
39 | grid = grid.type(torch.FloatTensor)
40 |
41 | # registering the grid as a buffer cleanly moves it to the GPU, but it also
42 | # adds it to the state dict. this is annoying since everything in the state dict
43 | # is included when saving weights to disk, so the model files are way bigger
44 | # than they need to be. so far, there does not appear to be an elegant solution.
45 | # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
46 | self.register_buffer('grid', grid)
47 |
48 | def forward(self, src, flow):
49 | # new locations
50 | new_locs = self.grid + flow
51 | shape = flow.shape[2:]
52 |
53 | # need to normalize grid values to [-1, 1] for resampler
54 | for i in range(len(shape)):
55 | new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
56 |
57 | # move channels dim to last position
58 | # also not sure why, but the channels need to be reversed
59 | if len(shape) == 2:
60 | new_locs = new_locs.permute(0, 2, 3, 1)
61 | new_locs = new_locs[..., [1, 0]]
62 | elif len(shape) == 3:
63 | new_locs = new_locs.permute(0, 2, 3, 4, 1)
64 | new_locs = new_locs[..., [2, 1, 0]]
65 |
66 | return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
67 |
68 |
69 | class VecInt(nn.Module):
70 | """
71 | Integrates a vector field via scaling and squaring.
72 | """
73 |
74 | def __init__(self, inshape, nsteps=7):
75 | super().__init__()
76 |
77 | assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
78 | self.nsteps = nsteps
79 | self.scale = 1.0 / (2 ** self.nsteps)
80 | self.transformer = SpatialTransformer(inshape)
81 |
82 | def forward(self, vec):
83 | vec = vec * self.scale
84 | for _ in range(self.nsteps):
85 | vec = vec + self.transformer(vec, vec)
86 | return vec
87 |
88 |
89 | class ResizeTransform(nn.Module):
90 |
91 | def __init__(self, vel_resize, ndims):
92 | super().__init__()
93 | self.factor = 1.0 / vel_resize
94 | self.mode = 'linear'
95 | if ndims == 2:
96 | self.mode = 'bi' + self.mode
97 | elif ndims == 3:
98 | self.mode = 'tri' + self.mode
99 |
100 | def forward(self, x):
101 | if self.factor < 1:
102 | # resize first to save memory
103 | x = nnf.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)
104 | x = self.factor * x
105 |
106 | elif self.factor > 1:
107 | # multiply first to save memory
108 | x = self.factor * x
109 | x = nnf.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)
110 |
111 | # don't do anything if resize is 1
112 | return x
113 |
114 |
115 | class ConvBlock(nn.Module):
116 | """
117 |     Convolutional block followed by LeakyReLU.
118 | """
119 |
120 |     def __init__(self, ndims, in_channels, out_channels, kernel_size=3, stride=1, padding=1, alpha=0.1):
121 |         super().__init__()
122 |
123 |         Conv = getattr(nn, 'Conv%dd' % ndims)
124 |         self.main = Conv(in_channels, out_channels, kernel_size, stride, padding)
125 | self.activation = nn.LeakyReLU(alpha)
126 |
127 | def forward(self, x):
128 | out = self.main(x)
129 | out = self.activation(out)
130 | return out
131 |
132 | class ConvInsBlock(nn.Module):
133 | """
134 |     Convolutional block followed by instance normalization and LeakyReLU.
135 | """
136 |
137 |     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, alpha=0.1):
138 |         super().__init__()
139 |
140 |         self.main = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
141 | self.norm = nn.InstanceNorm3d(out_channels)
142 | self.activation = nn.LeakyReLU(alpha)
143 |
144 | def forward(self, x):
145 | out = self.main(x)
146 | out = self.norm(out)
147 | out = self.activation(out)
148 | return out
149 |
150 | class UpConvBlock(nn.Module):
151 | def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, alpha=0.1):
152 | super(UpConvBlock, self).__init__()
153 |
154 | self.upconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)
155 |
156 | self.actout = nn.Sequential(
157 | nn.InstanceNorm3d(out_channels),
158 | nn.LeakyReLU(alpha)
159 | )
160 | def forward(self, x):
161 | x = self.upconv(x)
162 | return self.actout(x)
163 |
164 | class ResBlock(nn.Module):
165 | """
166 | VoxRes module
167 | """
168 |
169 | def __init__(self, channel, alpha=0.1):
170 | super(ResBlock, self).__init__()
171 | self.block = nn.Sequential(
172 | nn.InstanceNorm3d(channel),
173 | nn.LeakyReLU(alpha),
174 | nn.Conv3d(channel, channel, kernel_size=3, padding=1)
175 | )
176 | self.actout = nn.Sequential(
177 | nn.InstanceNorm3d(channel),
178 | nn.LeakyReLU(alpha),
179 | )
180 | def forward(self, x):
181 | out = self.block(x) + x
182 | return self.actout(out)
183 |
184 |
185 | class Encoder(nn.Module):
186 | """
187 |     Convolutional encoder producing a four-level feature pyramid.
188 | """
189 |
190 | def __init__(self, in_channel=1, first_out_channel=16):
191 | super(Encoder, self).__init__()
192 |
193 | c = first_out_channel
194 | self.conv0 = ConvInsBlock(in_channel, c, 3, 1)
195 |
196 | self.conv1 = nn.Sequential(
197 | nn.Conv3d(c, 2*c, kernel_size=3, stride=2, padding=1),#80
198 | ResBlock(2*c)
199 | )
200 |
201 | self.conv2 = nn.Sequential(
202 | nn.Conv3d(2*c, 4*c, kernel_size=3, stride=2, padding=1),#40
203 | ResBlock(4*c)
204 | )
205 |
206 | self.conv3 = nn.Sequential(
207 | nn.Conv3d(4*c, 8*c, kernel_size=3, stride=2, padding=1),#20
208 | ResBlock(8*c)
209 | )
210 |
211 | def forward(self, x):
212 | out0 = self.conv0(x) # 1
213 | out1 = self.conv1(out0) # 1/2
214 | out2 = self.conv2(out1) # 1/4
215 | out3 = self.conv3(out2) # 1/8
216 |
217 | return [out0, out1, out2, out3]
218 |
219 | class CConv(nn.Module):
220 | def __init__(self, channel):
221 | super(CConv, self).__init__()
222 |
223 | c = channel
224 |
225 | self.conv = nn.Sequential(
226 | ConvInsBlock(c, c, 3, 1),
227 | ConvInsBlock(c, c, 3, 1)
228 | )
229 |
230 | def forward(self, float_fm, fixed_fm, d_fm):
231 | concat_fm = torch.cat([float_fm, fixed_fm, d_fm], dim=1)
232 | x = self.conv(concat_fm)
233 | return x
234 |
235 | class RDP(nn.Module):
236 | def __init__(self, inshape=(160,192,160), flow_multiplier=1.,in_channel=1, channels=16):
237 | super(RDP, self).__init__()
238 | self.flow_multiplier = flow_multiplier
239 | self.channels = channels
240 | self.step = 7
241 | self.inshape = inshape
242 |
243 | c = self.channels
244 | self.encoder_moving = Encoder(in_channel=in_channel, first_out_channel=c)
245 | self.encoder_fixed = Encoder(in_channel=in_channel, first_out_channel=c)
246 |
247 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
248 |         self.upsample_trilin = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
249 |
250 |
251 | self.warp = nn.ModuleList()
252 | self.diff = nn.ModuleList()
253 | for i in range(4):
254 | self.warp.append(SpatialTransformer([s // 2**i for s in inshape]))
255 | self.diff.append(VecInt([s // 2**i for s in inshape]))
256 |
257 | # bottleNeck
258 | self.cconv_4 = nn.Sequential(
259 | ConvInsBlock(16 * c, 8 * c, 3, 1),
260 | ConvInsBlock(8 * c, 8 * c, 3, 1)
261 | )
262 | # warp scale 2
263 | self.defconv4 = nn.Conv3d(8*c, 3, 3, 1, 1)
264 | self.defconv4.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv4.weight.shape))
265 | self.defconv4.bias = nn.Parameter(torch.zeros(self.defconv4.bias.shape))
266 | self.dconv4 = nn.Sequential(
267 | ConvInsBlock(3*8*c, 8*c),
268 | ConvInsBlock(8*c, 8*c)
269 | )
270 |
271 | self.upconv3 = UpConvBlock(8*c, 4*c, 4, 2)
272 | self.cconv_3 = CConv(3*4*c)
273 |
274 | # warp scale 1
275 | self.defconv3 = nn.Conv3d(3*4*c, 3, 3, 1, 1)
276 | self.defconv3.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv3.weight.shape))
277 | self.defconv3.bias = nn.Parameter(torch.zeros(self.defconv3.bias.shape))
278 | self.dconv3 = ConvInsBlock(3 * 4 * c, 4 * c)
279 |
280 | self.upconv2 = UpConvBlock(3*4*c, 2*c, 4, 2)
281 | self.cconv_2 = CConv(3*2*c)
282 |
283 | # warp scale 0
284 | self.defconv2 = nn.Conv3d(3*2*c, 3, 3, 1, 1)
285 | self.defconv2.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv2.weight.shape))
286 | self.defconv2.bias = nn.Parameter(torch.zeros(self.defconv2.bias.shape))
287 | self.dconv2 = ConvInsBlock(3 * 2 * c, 2 * c)
288 |
289 | self.upconv1 = UpConvBlock(3*2*c, c, 4, 2)
290 | self.cconv_1 = CConv(3*c)
291 |
292 | # decoder layers
293 | self.defconv1 = nn.Conv3d(3*c, 3, 3, 1, 1)
294 | self.defconv1.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv1.weight.shape))
295 | self.defconv1.bias = nn.Parameter(torch.zeros(self.defconv1.bias.shape))
296 | #self.dconv1 = ConvInsBlock(3 * c, c)
297 |
298 | def forward(self, moving, fixed):
299 |
300 | # encode stage
301 | M1, M2, M3, M4 = self.encoder_moving(moving)
302 | F1, F2, F3, F4 = self.encoder_fixed(fixed)
303 | # c=16, 2c, 4c, 8c # 160, 80, 40, 20
304 |
305 | # first dec layer
306 | C4 = torch.cat([F4, M4], dim=1)
307 | C4 = self.cconv_4(C4) # (1,128,20,24,20)
308 | flow = self.defconv4(C4) # (1,3,20,24,20)
309 | flow = self.diff[3](flow)
310 | warped = self.warp[3](M4, flow)
311 | C4 = self.dconv4(torch.cat([F4, warped, C4], dim=1))
312 | v = self.defconv4(C4) # (1,3,20,24,20)
313 | w = self.diff[3](v)
314 |
315 |
316 | D3 = self.upconv3(C4) # (1, 64, 40, 48, 40)
317 | flow = self.upsample_trilin(2*(self.warp[3](flow, w)+w))
318 | warped = self.warp[2](M3, flow) # (1, 64, 40, 48, 40)
319 | C3 = self.cconv_3(F3, warped, D3) # (1, 3 * 64, 40, 48, 40)
320 | v = self.defconv3(C3)
321 | w = self.diff[2](v)
322 | flow = self.warp[2](flow, w)+w
323 | warped = self.warp[2](M3, flow) # (1, 64, 40, 48, 40)
324 | D3 = self.dconv3(C3)
325 | C3 = self.cconv_3(F3, warped, D3) # (1, 3 * 64, 40, 48, 40)
326 | v = self.defconv3(C3)
327 | w = self.diff[2](v)
328 |
329 | D2 = self.upconv2(C3)
330 | flow = self.upsample_trilin(2*(self.warp[2](flow, w)+w))
331 | warped = self.warp[1](M2, flow)
332 | C2 = self.cconv_2(F2, warped, D2)
333 | v = self.defconv2(C2) # (1,3,80,96,80)
334 | w = self.diff[1](v)
335 | flow = self.warp[1](flow, w)+w
336 | warped = self.warp[1](M2, flow)
337 | D2 = self.dconv2(C2)
338 | C2 = self.cconv_2(F2, warped, D2)
339 | v = self.defconv2(C2) # (1,3,80,96,80)
340 | w = self.diff[1](v)
341 | flow = self.warp[1](flow, w)+w
342 | warped = self.warp[1](M2, flow)
343 | D2 = self.dconv2(C2)
344 | C2 = self.cconv_2(F2, warped, D2)
345 | v = self.defconv2(C2) # (1,3,80,96,80)
346 | w = self.diff[1](v)
347 |
348 |         D1 = self.upconv1(C2)  # (1,16,160,192,160)
349 |         flow = self.upsample_trilin(2*(self.warp[1](flow, w)+w))  # (1,3,160,192,160)
350 |         warped = self.warp[0](M1, flow)  # (1,16,160,192,160)
351 |         C1 = self.cconv_1(F1, warped, D1)  # (1,48,160,192,160)
352 |         v = self.defconv1(C1)
353 |         w = self.diff[0](v)
354 |         flow = self.warp[0](flow, w)+w  # (1,3,160,192,160)
355 |
356 | y_moved = self.warp[0](moving, flow)
357 |
358 | return y_moved, flow
359 |
360 | if __name__ == '__main__':
361 | size = (1, 1, 80, 96, 80)
362 | model = RDP(size[2:])
363 | # print(str(model))
364 | A = torch.ones(size)
365 | B = torch.ones(size)
366 | out, flow = model(A, B)
367 | print(out.shape, flow.shape)
368 |
--------------------------------------------------------------------------------
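
A small property check for VecInt's scaling-and-squaring: integrating a spatially constant velocity field should return (away from the volume borders) the same constant displacement:

    import torch

    inshape = (8, 8, 8)
    vecint = VecInt(inshape, nsteps=7)
    v = torch.full((1, 3, *inshape), 0.5)  # constant velocity field
    flow = vecint(v)
    print(flow[0, :, 4, 4, 4])             # ~tensor([0.5, 0.5, 0.5]) in the interior
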