├── data
│   ├── __init__.py
│   ├── trans.py
│   └── datasets.py
├── makePklDataset.py
├── losses.py
├── README.md
├── infer.py
├── utils.py
├── train.py
└── models.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
--------------------------------------------------------------------------------
/data/trans.py:
--------------------------------------------------------------------------------
1 | # import math
2 | import collections.abc
3 | import numpy as np
4 |
5 |
6 | class Base(object):
7 |     def sample(self, *shape):
8 |         return shape
9 |
10 |     def tf(self, img, k=0):
11 |         return img
12 |
13 |     def __call__(self, img, dim=3, reuse=False):
14 |         if not reuse:
15 |             im = img if isinstance(img, np.ndarray) else img[0]
16 |             shape = im.shape[1:dim+1]
17 |             self.sample(*shape)
18 |
19 |         if isinstance(img, collections.abc.Sequence):
20 |             return [self.tf(x, k) for k, x in enumerate(img)]
21 |
22 |         return self.tf(img)
23 |
24 |     def __str__(self):
25 |         return 'Identity()'
26 |
27 | class Seg_norm(Base):
28 |     def __init__(self):
29 |         # LPBA40 label IDs, mapped below onto contiguous indices 0..54
30 |         self.seg_table = np.array([0, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 61, 62,
31 |                                    63, 64, 65, 66, 67, 68, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 101, 102, 121, 122, 161, 162,
32 |                                    163, 164, 165, 166])
33 |     def tf(self, img, k=0):
34 |         if k == 0:
35 |             return img
36 |         img_out = np.zeros_like(img)
37 |         for i in range(len(self.seg_table)):
38 |             img_out[img == self.seg_table[i]] = i
39 |         return img_out
40 |
41 |
42 | class NumpyType(Base):
43 |     def __init__(self, types, num=-1):
44 |         self.types = types  # ('float32', 'int64')
45 |         self.num = num
46 |
47 |     def tf(self, img, k=0):
48 |         if self.num > 0 and k >= self.num:
49 |             return img
50 |         # make this work with both Tensor and Numpy
51 |         return img.astype(self.types[k])
52 |
53 |     def __str__(self):
54 |         s = ', '.join([str(s) for s in self.types])
55 |         return 'NumpyType(({}))'.format(s)
56 |
57 |
--------------------------------------------------------------------------------
/makePklDataset.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import SimpleITK as sitk
3 | import numpy as np
4 | import glob
5 | from natsort import natsorted
6 | import os
7 |
8 | def pksave(img, label, save_path):
9 |     with open(save_path, 'wb') as f:
10 |         pickle.dump((img, label), f)
11 |
12 | def nii2arr(nii_img):
13 |     return sitk.GetArrayFromImage(sitk.ReadImage(nii_img))
14 |
15 | def center(arr):
16 |     c = np.sort(np.nonzero(arr))[:,[0,-1]]
17 |     return np.mean(c, axis=-1).astype('int16')
18 |
19 | def minmax(arr):
20 |     return (arr-np.min(arr))/(np.max(arr)-np.min(arr))
21 |
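# cropByCenter below cuts a final_shape window centred on `center` (the
# midpoint of the nonzero bounding box from center() above), clamping the
# window to the volume bounds on each axis; it assumes every input axis is
# at least final_shape long. E.g. on an axis of length 200 with centre 30
# and target size 160, the raw window [-50, 110) is shifted to [0, 160).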
22 | def cropByCenter(image, center, final_shape=(160,192,160)):
23 |     c = center
24 |     crop = np.array([s // 2 for s in final_shape])
25 |     # 0 axis
26 |     cropmin, cropmax = c[0] - crop[0], c[0] + crop[0]
27 |     if cropmin < 0:
28 |         cropmin = 0
29 |         cropmax = final_shape[0]
30 |     if cropmax > image.shape[0]:
31 |         cropmax = image.shape[0]
32 |         cropmin = image.shape[0] - final_shape[0]
33 |     image = image[cropmin:cropmax, :, :]
34 |     # 1 axis
35 |     cropmin, cropmax = c[1] - crop[1], c[1] + crop[1]
36 |     if cropmin < 0:
37 |         cropmin = 0
38 |         cropmax = final_shape[1]
39 |     if cropmax > image.shape[1]:
40 |         cropmax = image.shape[1]
41 |         cropmin = image.shape[1] - final_shape[1]
42 |     image = image[:, cropmin:cropmax, :]
43 |
44 |     # 2 axis
45 |     cropmin, cropmax = c[2] - crop[2], c[2] + crop[2]
46 |     if cropmin < 0:
47 |         cropmin = 0
48 |         cropmax = final_shape[2]
49 |     if cropmax > image.shape[2]:
50 |         cropmax = image.shape[2]
51 |         cropmin = image.shape[2] - final_shape[2]
52 |     image = image[:, :, cropmin:cropmax]
53 |     return image
54 |
55 | path_to_LPBA='/data/LPBA40/' # the path of the original dataset
56 | img_niis = natsorted(glob.glob(path_to_LPBA+'*/*/*skullstripped.img.gz'))
57 | label_niis = natsorted(glob.glob(path_to_LPBA+'*/*/*label.img.gz'))
58 | print(img_niis, label_niis)
59 |
60 | save_path = 'LPBA_data/'
61 | if not os.path.exists(save_path):
62 |     os.makedirs(save_path)
63 |
64 | for i, nii in enumerate(zip(img_niis, label_niis)):
65 |     print(nii)
66 |     img_nii, label_nii = nii
67 |     img, label = nii2arr(img_nii), nii2arr(label_nii)
68 |     print(img.shape, label.shape)
69 |
70 |     # crop by center
71 |     c = center(img)
72 |     img = cropByCenter(img, c)
73 |     label = cropByCenter(label, c)
74 |
75 |     #norm
76 |     img = minmax(img).astype('float32')
77 |     label = label.astype('uint16')
78 |     print(img.shape, np.unique(img), img.dtype, label.shape, np.unique(label), label.dtype)
79 |     print(save_path+'subject_%02d.pkl'%(i+1))
80 |     pksave(img, label, save_path=save_path+'subject_%02d.pkl'%(i+1))
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import math
5 |
6 | class Grad3d(torch.nn.Module):
7 |     """
8 |     N-D gradient loss.
9 |     """
10 |
11 |     def __init__(self, penalty='l1', loss_mult=None):
12 |         super(Grad3d, self).__init__()
13 |         self.penalty = penalty
14 |         self.loss_mult = loss_mult
15 |
16 |     def forward(self, y_pred, y_true):
17 |         dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
18 |         dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
19 |         dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])
20 |
21 |         if self.penalty == 'l2':
22 |             dy = dy * dy
23 |             dx = dx * dx
24 |             dz = dz * dz
25 |
26 |         d = torch.mean(dx) + torch.mean(dy) + torch.mean(dz)
27 |         grad = d / 3.0
28 |
29 |         if self.loss_mult is not None:
30 |             grad *= self.loss_mult
31 |         return grad
32 |
33 |
34 | class NCC_vxm(torch.nn.Module):
35 |     """
36 |     Local (over window) normalized cross correlation loss.
37 |     """
38 |
39 |     def __init__(self, win=None):
40 |         super(NCC_vxm, self).__init__()
41 |         self.win = win
42 |
43 |     def forward(self, y_true, y_pred):
44 |
45 |         Ii = y_true
46 |         Ji = y_pred
47 |
48 |         # get dimension of volume
49 |         # assumes Ii, Ji are sized [batch_size, nb_feats, *vol_shape]
50 |         ndims = len(list(Ii.size())) - 2
51 |         assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims
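        # Local NCC: per-voxel means and (co)variances are taken over a
        # win**ndims neighbourhood using a box filter (all-ones kernel), and
        #   cc = (sum(IJ) - win_size * u_I * u_J)**2 / (I_var * J_var)
        # is averaged over the volume and negated so it can be minimised.
        # The expressions below are the expanded forms of these windowed
        # statistics; note the all-ones filter is created on "cuda", so this
        # loss assumes GPU tensors.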
52 |
53 |         # set window size
54 |         win = [9] * ndims if self.win is None else self.win
55 |
56 |         # compute filters
57 |         sum_filt = torch.ones([1, 1, *win]).to("cuda")
58 |
59 |         pad_no = math.floor(win[0] / 2)
60 |
61 |         if ndims == 1:
62 |             stride = (1)
63 |             padding = (pad_no)
64 |         elif ndims == 2:
65 |             stride = (1, 1)
66 |             padding = (pad_no, pad_no)
67 |         else:
68 |             stride = (1, 1, 1)
69 |             padding = (pad_no, pad_no, pad_no)
70 |
71 |         # get convolution function
72 |         conv_fn = getattr(F, 'conv%dd' % ndims)
73 |
74 |         # compute CC squares
75 |         I2 = Ii * Ii
76 |         J2 = Ji * Ji
77 |         IJ = Ii * Ji
78 |
79 |         I_sum = conv_fn(Ii, sum_filt, stride=stride, padding=padding)
80 |         J_sum = conv_fn(Ji, sum_filt, stride=stride, padding=padding)
81 |         I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
82 |         J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
83 |         IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)
84 |
85 |         win_size = np.prod(win)
86 |         u_I = I_sum / win_size
87 |         u_J = J_sum / win_size
88 |
89 |         cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
90 |         I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
91 |         J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
92 |
93 |         cc = cross * cross / (I_var * J_var + 1e-5)
94 |
95 |         return -torch.mean(cc)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Recursive Deformable Pyramid Network for Unsupervised Medical Image Registration (TMI2024)
2 |
3 | By Haiqiao Wang, Dong Ni, Yi Wang.
4 |
5 | Paper link: [[TMI]](https://ieeexplore.ieee.org/document/10423043)
6 |
7 |
8 |
9 |
10 | ## Description
11 | An unsupervised deformable registration method for brain MR images. It models complex deformations with a pure convolutional pyramid structure and a semantics-infused progressive recursive inter-level looping strategy, achieving precise alignment even without pre-alignment of the brain MR images.
12 |
13 | ![Figure 1](https://github.com/ZAX130/RDP/assets/43944700/66c3058f-7d9c-499c-8017-40c62240f4d7)
14 |
15 |
16 | ## Dataset
17 | The public datasets used in this work are officially available at:
18 |
19 | LPBA [[link]](https://resource.loni.usc.edu/resources/atlases-downloads/)
20 |
21 | Mindboggle [[link]](https://osf.io/yhkde/)
22 |
23 | IXI [[link]](https://surfer.nmr.mgh.harvard.edu/pub/data/) [[freesurfer link]](https://surfer.nmr.mgh.harvard.edu/pub/data/ixi/)
24 |
25 | Note that we use the preprocessed IXI dataset provided by FreeSurfer.
26 |
27 | ## Instructions
28 | For convenience, we are sharing the preprocessed [LPBA](https://drive.usercontent.google.com/download?id=1mFzZDn2qPAiP1ByGZ7EbsvEmm6vrS5WO&export=download&authuser=0) dataset used in our experiments. Once it is uncompressed, set "LPBA_path" in `train.py` to the path of the extracted data. Then run `train.py` to train the network; after training, run `infer.py` to evaluate its performance.
29 |
30 | ## Citation
31 | If you find the code useful, please cite our paper.
32 | ``` 33 | @ARTICLE{10423043, 34 | author={Wang, Haiqiao and Ni, Dong and Wang, Yi}, 35 | journal={IEEE Transactions on Medical Imaging}, 36 | title={Recursive Deformable Pyramid Network for Unsupervised Medical Image Registration}, 37 | year={2024}, 38 | volume={43}, 39 | number={6}, 40 | pages={2229-2240}, 41 | keywords={Deformation;Decoding;Feature extraction;Deformable models;Training;Image resolution;Image registration;Deformable image registration;convolutional neural networks;brain MRI}, 42 | doi={10.1109/TMI.2024.3362968}} 43 | ``` 44 | The overall framework and some network components of the code are heavily based on [TransMorph](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration) and [VoxelMorph](https://github.com/voxelmorph/voxelmorph). We are very grateful for their contributions. 45 | 46 | The file `makePklDataset.py` shows how to make a pkl dataset from the original LPBA dataset. If you have any other questions about the .pkl format, please refer to the github page of [[TransMorph_on_IXI]](https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/main/IXI/TransMorph_on_IXI.md). 47 | 48 | ## Baseline Methods 49 | Several PyTorch implementations of some baseline methods can be found at [[SmileCode]](https://github.com/ZAX130/SmileCode/tree/main). 50 | 51 | ## How can other datasets be used in this code? 52 | This is a common question, and please refer to the github page of [ChangeDataset.md](https://github.com/ZAX130/ModeTv2/blob/main/ChangeDataset.md) for more information. 53 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os, losses, utils 3 | from torch.utils.data import DataLoader 4 | from data import datasets, trans 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d import axes3d 11 | from natsort import natsorted 12 | from models import RDP 13 | import random 14 | def same_seeds(seed): 15 | # Python built-in random module 16 | random.seed(seed) 17 | # Numpy 18 | np.random.seed(seed) 19 | # Torch 20 | torch.manual_seed(seed) 21 | if torch.cuda.is_available(): 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | # torch.backends.cudnn.deterministic = True 25 | 26 | same_seeds(24) 27 | 28 | class AverageMeter(object): 29 | """Computes and stores the average and current value""" 30 | def __init__(self): 31 | self.reset() 32 | 33 | def reset(self): 34 | self.val = 0 35 | self.avg = 0 36 | self.sum = 0 37 | self.count = 0 38 | self.vals = [] 39 | self.std = 0 40 | 41 | def update(self, val, n=1): 42 | self.val = val 43 | self.sum += val * n 44 | self.count += n 45 | self.avg = self.sum / self.count 46 | self.vals.append(val) 47 | self.std = np.std(self.vals) 48 | 49 | def main(): 50 | 51 | val_dir = '/LPBA_path/Val/' 52 | weights = [1, 1] # loss weights 53 | lr = 0.0001 54 | model_idx = -1 55 | model_folder = 'RDP_ncc_{}_reg_{}_lr_{}_54r/'.format(weights[0], weights[1], lr) 56 | model_dir = 'experiments/' + model_folder 57 | 58 | img_size = (160, 192, 160) 59 | model = RDP(img_size, channels=16) 60 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[model_idx])['state_dict'] 61 | print('Best model: {}'.format(natsorted(os.listdir(model_dir))[model_idx])) 62 | model.load_state_dict(best_model) 63 | model.cuda() 64 | 
reg_model = utils.register_model(img_size, 'nearest')
65 |     reg_model.cuda()
66 |     test_composed = transforms.Compose([trans.Seg_norm(),
67 |                                         trans.NumpyType((np.float32, np.int16)),
68 |                                         ])
69 |     test_set = datasets.LPBABrainInferDatasetS2S(glob.glob(val_dir + '*.pkl'), transforms=test_composed)
70 |     test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)
71 |     eval_dsc_def = AverageMeter()
72 |     eval_dsc_raw = AverageMeter()
73 |     eval_det = AverageMeter()
74 |     with torch.no_grad():
75 |         stdy_idx = 0
76 |         for data in test_loader:
77 |             model.eval()
78 |             data = [t.cuda() for t in data]
79 |             x = data[0]
80 |             y = data[1]
81 |             x_seg = data[2]
82 |             y_seg = data[3]
83 |
84 |             x_def, flow = model(x, y)
85 |             def_out = reg_model([x_seg.cuda().float(), flow.cuda()])
86 |             tar = y.detach().cpu().numpy()[0, 0, :, :, :]
87 |             jac_det = utils.jacobian_determinant_vxm(flow.detach().cpu().numpy()[0, :, :, :, :])
88 |             eval_det.update(np.sum(jac_det <= 0) / np.prod(tar.shape), x.size(0))
89 |             dsc_trans = utils.dice_val_VOI(def_out.long(), y_seg.long())
90 |             dsc_raw = utils.dice_val_VOI(x_seg.long(), y_seg.long())
91 |             print('Trans dsc: {:.4f}, Raw dsc: {:.4f}'.format(dsc_trans.item(), dsc_raw.item()))
92 |             eval_dsc_def.update(dsc_trans.item(), x.size(0))
93 |             eval_dsc_raw.update(dsc_raw.item(), x.size(0))
94 |             stdy_idx += 1
95 |         print('Deformed DSC: {:.3f} +- {:.3f}, Affine DSC: {:.3f} +- {:.3f}'.format(eval_dsc_def.avg,
96 |                                                                                     eval_dsc_def.std,
97 |                                                                                     eval_dsc_raw.avg,
98 |                                                                                     eval_dsc_raw.std))
99 |         print('deformed det: {}, std: {}'.format(eval_det.avg, eval_det.std))
100 |
101 |
102 | if __name__ == '__main__':
103 |     '''
104 |     GPU configuration
105 |     '''
106 |     GPU_iden = 0
107 |     GPU_num = torch.cuda.device_count()
108 |     print('Number of GPU: ' + str(GPU_num))
109 |     for GPU_idx in range(GPU_num):
110 |         GPU_name = torch.cuda.get_device_name(GPU_idx)
111 |         print('     GPU #' + str(GPU_idx) + ': ' + GPU_name)
112 |     torch.cuda.set_device(GPU_iden)
113 |     GPU_avai = torch.cuda.is_available()
114 |     print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
115 |     print('Is the GPU available? ' + str(GPU_avai))
116 |     main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torch, sys
5 | from torch import nn
6 | import pystrum.pynd.ndutils as nd
7 |
8 | class AverageMeter(object):
9 |     """Computes and stores the average and current value"""
10 |     def __init__(self):
11 |         self.reset()
12 |
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 |         self.vals = []
19 |         self.std = 0
20 |
21 |     def update(self, val, n=1):
22 |         self.val = val
23 |         self.sum += val * n
24 |         self.count += n
25 |         self.avg = self.sum / self.count
26 |         self.vals.append(val)
27 |         self.std = np.std(self.vals)
28 |
29 |
30 | class SpatialTransformer(nn.Module):
31 |     """
32 |     N-D Spatial Transformer
33 |     """
34 |
35 |     def __init__(self, size, mode='bilinear'):
36 |         super().__init__()
37 |
38 |         self.mode = mode
39 |
40 |         # create sampling grid
41 |         vectors = [torch.arange(0, s) for s in size]
42 |         grids = torch.meshgrid(vectors)
43 |         grid = torch.stack(grids)
44 |         grid = torch.unsqueeze(grid, 0)
45 |         grid = grid.type(torch.FloatTensor).cuda()
46 |
47 |         # registering the grid as a buffer cleanly moves it to the GPU, but it also
48 |         # adds it to the state dict. this is annoying since everything in the state dict
49 |         # is included when saving weights to disk, so the model files are way bigger
50 |         # than they need to be. so far, there does not appear to be an elegant solution.
51 |         # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
52 |         self.register_buffer('grid', grid)
53 |
54 |     def forward(self, src, flow):
55 |         # new locations
56 |         new_locs = self.grid + flow
57 |         shape = flow.shape[2:]
58 |
59 |         # need to normalize grid values to [-1, 1] for resampler
60 |         for i in range(len(shape)):
61 |             new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
62 |
63 |         # move channels dim to last position
64 |         # grid_sample expects sampling coordinates in (x, y, z) order, so reverse the channel order
65 |         if len(shape) == 2:
66 |             new_locs = new_locs.permute(0, 2, 3, 1)
67 |             new_locs = new_locs[..., [1, 0]]
68 |         elif len(shape) == 3:
69 |             new_locs = new_locs.permute(0, 2, 3, 4, 1)
70 |             new_locs = new_locs[..., [2, 1, 0]]
71 |
72 |         return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
73 |
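# register_model below is a thin wrapper around SpatialTransformer so a
# (volume, flow) pair can be warped outside the network. infer.py and
# train.py construct it with mode='nearest' when warping segmentation maps,
# since nearest-neighbour resampling keeps integer label IDs intact
# (bilinear interpolation would blend neighbouring labels).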
74 | class register_model(nn.Module):
75 |     def __init__(self, img_size=(64, 256, 256), mode='bilinear'):
76 |         super(register_model, self).__init__()
77 |         self.spatial_trans = SpatialTransformer(img_size, mode)
78 |
79 |     def forward(self, x):
80 |         img = x[0].cuda()
81 |         flow = x[1].cuda()
82 |         out = self.spatial_trans(img, flow)
83 |         return out
84 |
85 |
86 | def dice_val_VOI(y_pred, y_true):
87 |     VOI_lbls = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
88 |                 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
89 |                 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
90 |                 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
91 |                 48, 49, 50, 51, 52, 53, 54]
92 |
93 |     pred = y_pred.detach().cpu().numpy()[0, 0, ...]
94 |     true = y_true.detach().cpu().numpy()[0, 0, ...]
95 |     DSCs = np.zeros((len(VOI_lbls), 1))
96 |     idx = 0
97 |     for i in VOI_lbls:
98 |         pred_i = pred == i
99 |         true_i = true == i
100 |         intersection = pred_i * true_i
101 |         intersection = np.sum(intersection)
102 |         union = np.sum(pred_i) + np.sum(true_i)
103 |         dsc = (2.*intersection) / (union + 1e-5)
104 |         DSCs[idx] = dsc
105 |         idx += 1
106 |     return np.mean(DSCs)
107 |
108 | def jacobian_determinant_vxm(disp):
109 |     """
110 |     jacobian determinant of a displacement field.
111 |     NB: to compute the spatial gradients, we use np.gradient.
112 |     Parameters:
113 |         disp: displacement field of size [nb_dims, *vol_shape] (channels first,
114 |             as produced by the network); transposed to channels-last internally
115 |     Returns:
116 |         jacobian determinant volume of shape vol_shape (one value per voxel)
117 |     """
118 |
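    # The returned volume holds det(J) at every voxel; infer.py summarises it
    # as np.sum(jac_det <= 0) / np.prod(tar.shape), i.e. the fraction of
    # voxels where the deformation folds onto itself.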
119 |     # check inputs
120 |     disp = disp.transpose(1, 2, 3, 0)
121 |     volshape = disp.shape[:-1]
122 |     nb_dims = len(volshape)
123 |     assert len(volshape) in (2, 3), 'flow has to be 2D or 3D'
124 |
125 |     # compute grid
126 |     grid_lst = nd.volsize2ndgrid(volshape)
127 |     grid = np.stack(grid_lst, len(volshape))
128 |
129 |     # compute gradients
130 |     J = np.gradient(disp + grid)
131 |
132 |     # 3D flow
133 |     if nb_dims == 3:
134 |         dx = J[0]
135 |         dy = J[1]
136 |         dz = J[2]
137 |
138 |         # compute jacobian components
139 |         Jdet0 = dx[..., 0] * (dy[..., 1] * dz[..., 2] - dy[..., 2] * dz[..., 1])
140 |         Jdet1 = dx[..., 1] * (dy[..., 0] * dz[..., 2] - dy[..., 2] * dz[..., 0])
141 |         Jdet2 = dx[..., 2] * (dy[..., 0] * dz[..., 1] - dy[..., 1] * dz[..., 0])
142 |
143 |         return Jdet0 - Jdet1 + Jdet2
144 |
145 |     else: # must be 2
146 |
147 |         dfdx = J[0]
148 |         dfdy = J[1]
149 |
150 |         return dfdx[..., 0] * dfdy[..., 1] - dfdy[..., 0] * dfdx[..., 1]
151 |
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | import torch, sys
3 | from torch.utils.data import Dataset
4 | import matplotlib.pyplot as plt
5 | import pickle
6 | import numpy as np
7 |
8 | def pkload(fname):
9 |     with open(fname, 'rb') as f:
10 |         return pickle.load(f)
11 |
12 | class LPBABrainDatasetS2S(Dataset):
13 |     def __init__(self, data_path, transforms):
14 |         self.paths = data_path
15 |         self.transforms = transforms
16 |
17 |     def one_hot(self, img, C):
18 |         out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
19 |         for i in range(C):
20 |             out[i,...] = img == i
21 |         return out
22 |
23 |     def __getitem__(self, index):
24 |         x_index = index // (len(self.paths) - 1)
25 |         s = index % (len(self.paths) - 1)
26 |         y_index = s + 1 if s >= x_index else s
27 |         path_x = self.paths[x_index]
28 |         path_y = self.paths[y_index]
29 |         x, x_seg = pkload(path_x)
30 |         y, y_seg = pkload(path_y)
31 |         #print(x.shape)
32 |         #print(x.shape)
33 |         #print(np.unique(y))
34 |         # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155)
35 |         # transforms work with nhwtc
36 |         x, y = x[None, ...], y[None, ...]
37 |         # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155)
38 |         x, y = self.transforms([x, y])
39 |         #y = self.one_hot(y, 2)
40 |         #print(y.shape)
41 |         #sys.exit(0)
42 |         x = np.ascontiguousarray(x)  # [Bsize, channels, Height, Width, Depth]
43 |         y = np.ascontiguousarray(y)
44 |         #plt.figure()
45 |         #plt.subplot(1, 2, 1)
46 |         #plt.imshow(x[0, :, :, 8], cmap='gray')
47 |         #plt.subplot(1, 2, 2)
48 |         #plt.imshow(y[0, :, :, 8], cmap='gray')
49 |         #plt.show()
50 |         #sys.exit(0)
51 |         #y = np.squeeze(y, axis=0)
52 |         x, y = torch.from_numpy(x), torch.from_numpy(y)
53 |         return x, y
54 |
55 |     def __len__(self):
56 |         return len(self.paths)*(len(self.paths)-1)
57 |
58 |
59 | class LPBABrainInferDatasetS2S(Dataset):
60 |     def __init__(self, data_path, transforms):
61 |         self.paths = data_path
62 |         self.transforms = transforms
63 |
64 |
65 |     def one_hot(self, img, C):
66 |         out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
67 |         for i in range(C):
68 |             out[i,...] 
= img == i 68 | return out 69 | 70 | def __getitem__(self, index): 71 | x_index = index//(len(self.paths)-1) 72 | s = index%(len(self.paths)-1) 73 | y_index = s+1 if s >= x_index else s 74 | path_x = self.paths[x_index] 75 | path_y = self.paths[y_index] 76 | # print(os.path.basename(path_x), os.path.basename(path_y)) 77 | x, x_seg = pkload(path_x) 78 | y, y_seg = pkload(path_y) 79 | x, y = x[None, ...], y[None, ...] 80 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...] 81 | x, x_seg = self.transforms([x, x_seg]) 82 | y, y_seg = self.transforms([y, y_seg]) 83 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 84 | y = np.ascontiguousarray(y) 85 | x_seg = np.ascontiguousarray(x_seg) # [Bsize,channelsHeight,,Width,Depth] 86 | y_seg = np.ascontiguousarray(y_seg) 87 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg) 88 | return x, y, x_seg, y_seg 89 | 90 | def __len__(self): 91 | return len(self.paths)*(len(self.paths)-1) 92 | 93 | 94 | class LPBABrainHalfDatasetS2S(Dataset): 95 | def __init__(self, data_path, transforms): 96 | self.paths = data_path 97 | self.transforms = transforms 98 | 99 | def one_hot(self, img, C): 100 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 101 | for i in range(C): 102 | out[i,...] = img == i 103 | return out 104 | def half_pair(self,pair): 105 | return pair[0][::2,::2,::2], pair[1][::2,::2,::2] 106 | 107 | def __getitem__(self, index): 108 | x_index = index // (len(self.paths) - 1) 109 | s = index % (len(self.paths) - 1) 110 | y_index = s + 1 if s >= x_index else s 111 | path_x = self.paths[x_index] 112 | path_y = self.paths[y_index] 113 | x, x_seg = self.half_pair(pkload(path_x)) 114 | y, y_seg = self.half_pair(pkload(path_y)) 115 | 116 | #print(x.shape) 117 | #print(x.shape) 118 | #print(np.unique(y)) 119 | # print(x.shape, y.shape)#(240, 240, 155) (240, 240, 155) 120 | # transforms work with nhwtc 121 | x, y = x[None, ...], y[None, ...] 122 | # print(x.shape, y.shape)#(1, 240, 240, 155) (1, 240, 240, 155) 123 | x,y = self.transforms([x, y]) 124 | #y = self.one_hot(y, 2) 125 | #print(y.shape) 126 | #sys.exit(0) 127 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 128 | y = np.ascontiguousarray(y) 129 | #plt.figure() 130 | #plt.subplot(1, 2, 1) 131 | #plt.imshow(x[0, :, :, 8], cmap='gray') 132 | #plt.subplot(1, 2, 2) 133 | #plt.imshow(y[0, :, :, 8], cmap='gray') 134 | #plt.show() 135 | #sys.exit(0) 136 | #y = np.squeeze(y, axis=0) 137 | x, y = torch.from_numpy(x), torch.from_numpy(y) 138 | return x, y 139 | 140 | def __len__(self): 141 | return len(self.paths)*(len(self.paths)-1) 142 | 143 | 144 | class LPBABrainHalfInferDatasetS2S(Dataset): 145 | def __init__(self, data_path, transforms): 146 | self.paths = data_path 147 | self.transforms = transforms 148 | 149 | def one_hot(self, img, C): 150 | out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3])) 151 | for i in range(C): 152 | out[i,...] = img == i 153 | return out 154 | def half_pair(self,pair): 155 | return pair[0][::2,::2,::2], pair[1][::2,::2,::2] 156 | def __getitem__(self, index): 157 | x_index = index//(len(self.paths)-1) 158 | s = index%(len(self.paths)-1) 159 | y_index = s+1 if s >= x_index else s 160 | path_x = self.paths[x_index] 161 | path_y = self.paths[y_index] 162 | # print(os.path.basename(path_x), os.path.basename(path_y)) 163 | x, x_seg = self.half_pair(pkload(path_x)) 164 | y, y_seg = self.half_pair(pkload(path_y)) 165 | x, y = x[None, ...], y[None, ...] 
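        # x/y are now (1, H, W, D): the [None] above inserts the channel axis
        # the transform pipeline expects; the same is done for the label
        # volumes below before images and labels are transformed as pairs.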
166 | x_seg, y_seg= x_seg[None, ...], y_seg[None, ...] 167 | x, x_seg = self.transforms([x, x_seg]) 168 | y, y_seg = self.transforms([y, y_seg]) 169 | x = np.ascontiguousarray(x)# [Bsize,channelsHeight,,Width,Depth] 170 | y = np.ascontiguousarray(y) 171 | x_seg = np.ascontiguousarray(x_seg) # [Bsize,channelsHeight,,Width,Depth] 172 | y_seg = np.ascontiguousarray(y_seg) 173 | x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg) 174 | return x, y, x_seg, y_seg 175 | 176 | def __len__(self): 177 | return len(self.paths)*(len(self.paths)-1) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | # from torch.utils.tensorboard import SummaryWriter 3 | import os, losses, utils 4 | import sys 5 | from torch.utils.data import DataLoader 6 | from data import datasets, trans 7 | import numpy as np 8 | import torch 9 | from torchvision import transforms 10 | from torch import optim 11 | import torch.nn as nn 12 | import matplotlib.pyplot as plt 13 | from natsort import natsorted 14 | from models import RDP 15 | import random 16 | def same_seeds(seed): 17 | # Python built-in random module 18 | random.seed(seed) 19 | # Numpy 20 | np.random.seed(seed) 21 | # Torch 22 | torch.manual_seed(seed) 23 | if torch.cuda.is_available(): 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | torch.backends.cudnn.benchmark = True 27 | # torch.backends.cudnn.deterministic = True 28 | 29 | same_seeds(24) 30 | class Logger(object): 31 | def __init__(self, save_dir): 32 | self.terminal = sys.stdout 33 | self.log = open(save_dir+"logfile.log", "a") 34 | 35 | def write(self, message): 36 | self.terminal.write(message) 37 | self.log.write(message) 38 | 39 | def flush(self): 40 | pass 41 | 42 | def main(): 43 | batch_size = 1 44 | 45 | train_dir = '/LPBA_path/Train/' 46 | val_dir = '/LPBA_path/Val/' 47 | weights = [1, 1] # loss weights 48 | lr = 0.0001 49 | save_dir = 'RDP_ncc_{}_reg_{}_lr_{}_54r/'.format(*weights, lr) 50 | if not os.path.exists('experiments/' + save_dir): 51 | os.makedirs('experiments/' + save_dir) 52 | if not os.path.exists('logs/' + save_dir): 53 | os.makedirs('logs/' + save_dir) 54 | sys.stdout = Logger('logs/' + save_dir) 55 | f = open(os.path.join('logs/'+save_dir, 'losses and dice' + ".txt"), "a") 56 | 57 | epoch_start = 0 58 | max_epoch = 30 59 | img_size = (160,192,160) 60 | cont_training = False 61 | 62 | ''' 63 | Initialize model 64 | ''' 65 | model = RDP(img_size, channels=16) 66 | model.cuda() 67 | 68 | ''' 69 | Initialize spatial transformation function 70 | ''' 71 | reg_model = utils.register_model(img_size, 'nearest') 72 | reg_model.cuda() 73 | 74 | 75 | ''' 76 | If continue from previous training 77 | ''' 78 | if cont_training: 79 | model_dir = 'experiments/'+save_dir 80 | updated_lr = round(lr * np.power(1 - (epoch_start) / max_epoch,0.9),8) 81 | best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[-1])['state_dict'] 82 | model.load_state_dict(best_model) 83 | print(model_dir + natsorted(os.listdir(model_dir))[-1]) 84 | else: 85 | updated_lr = lr 86 | 87 | ''' 88 | Initialize training 89 | ''' 90 | train_composed = transforms.Compose([trans.NumpyType((np.float32, np.float32))]) 91 | 92 | val_composed = transforms.Compose([trans.Seg_norm(), 93 | trans.NumpyType((np.float32, np.int16))]) 94 | train_set = datasets.LPBABrainDatasetS2S(glob.glob(train_dir + 
'*.pkl'), transforms=train_composed) 95 | val_set = datasets.LPBABrainInferDatasetS2S(glob.glob(val_dir + '*.pkl'), transforms=val_composed) 96 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) 97 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) 98 | 99 | optimizer = optim.Adam(model.parameters(), lr=updated_lr, weight_decay=0, amsgrad=True) 100 | criterion = losses.NCC_vxm() 101 | criterions = [criterion] 102 | criterions += [losses.Grad3d(penalty='l2')] 103 | best_dsc = 0 104 | # writer = SummaryWriter(log_dir='logs/'+save_dir) 105 | for epoch in range(epoch_start, max_epoch): 106 | print('Training Starts') 107 | ''' 108 | Training 109 | ''' 110 | loss_all = utils.AverageMeter() 111 | idx = 0 112 | for data in train_loader: 113 | idx += 1 114 | model.train() 115 | adjust_learning_rate(optimizer, epoch, max_epoch, lr) 116 | data = [t.cuda() for t in data] 117 | x = data[0] 118 | y = data[1] 119 | 120 | output = model(x,y) 121 | 122 | loss = 0 123 | loss_vals = [] 124 | for n, loss_function in enumerate(criterions): 125 | curr_loss = loss_function(output[n], y) * weights[n] 126 | loss_vals.append(curr_loss) 127 | loss += curr_loss 128 | loss_all.update(loss.item(), y.numel()) 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | 133 | print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}'.format(idx, len(train_loader), loss.item(), loss_vals[0].item(), loss_vals[1].item())) 134 | 135 | print('{} Epoch {} loss {:.4f}'.format(save_dir, epoch, loss_all.avg)) 136 | print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg), file=f, end=' ') 137 | ''' 138 | Validation 139 | ''' 140 | eval_dsc = utils.AverageMeter() 141 | with torch.no_grad(): 142 | for data in val_loader: 143 | model.eval() 144 | data = [t.cuda() for t in data] 145 | x = data[0] 146 | y = data[1] 147 | x_seg = data[2] 148 | y_seg = data[3] 149 | output = model(x,y) 150 | def_out = reg_model([x_seg.cuda().float(), output[1].cuda()]) 151 | dsc = utils.dice_val_VOI(def_out.long(), y_seg.long()) 152 | eval_dsc.update(dsc.item(), x.size(0)) 153 | print(epoch, ':',eval_dsc.avg) 154 | best_dsc = max(eval_dsc.avg, best_dsc) 155 | print(eval_dsc.avg, file=f) 156 | save_checkpoint({ 157 | 'epoch': epoch + 1, 158 | 'state_dict': model.state_dict(), 159 | 'best_dsc': best_dsc, 160 | 'optimizer': optimizer.state_dict(), 161 | }, save_dir='experiments/' + save_dir, filename='dsc{:.3f}.pth.tar'.format(eval_dsc.avg)) 162 | loss_all.reset() 163 | 164 | def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9): 165 | for param_group in optimizer.param_groups: 166 | param_group['lr'] = round(INIT_LR * np.power(1 - (epoch) / MAX_EPOCHES, power), 8) 167 | 168 | 169 | def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8): 170 | torch.save(state, save_dir+filename) 171 | model_lists = natsorted(glob.glob(save_dir + '*')) 172 | while len(model_lists) > max_model_num: 173 | os.remove(model_lists[0]) 174 | model_lists = natsorted(glob.glob(save_dir + '*')) 175 | 176 | if __name__ == '__main__': 177 | ''' 178 | 179 | GPU configuration 180 | ''' 181 | GPU_iden = 0 182 | GPU_num = torch.cuda.device_count() 183 | print('Number of GPU: ' + str(GPU_num)) 184 | for GPU_idx in range(GPU_num): 185 | GPU_name = torch.cuda.get_device_name(GPU_idx) 186 | print(' GPU #' + str(GPU_idx) + ': ' + GPU_name) 187 | torch.cuda.set_device(GPU_iden) 188 | GPU_avai = 
torch.cuda.is_available()
189 |     print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
190 |     print('Is the GPU available? ' + str(GPU_avai))
191 |     main()
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | '''
2 | VoxelMorph
3 | Original code retrieved from:
4 | https://github.com/voxelmorph/voxelmorph
5 |
6 | Original paper:
7 | Balakrishnan, G., Zhao, A., Sabuncu, M. R., Guttag, J., & Dalca, A. V. (2019).
8 | VoxelMorph: a learning framework for deformable medical image registration.
9 | IEEE transactions on medical imaging, 38(8), 1788-1800.
10 |
11 | Modified and tested by:
12 | Haiqiao Wang
13 | 2110246069@email.szu.edu.cn
14 | Shenzhen University
15 | '''
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as nnf
20 |
21 | import numpy as np
22 | from torch.distributions.normal import Normal
23 |
24 | class SpatialTransformer(nn.Module):
25 |     """
26 |     N-D Spatial Transformer
27 |     """
28 |
29 |     def __init__(self, size, mode='bilinear'):
30 |         super().__init__()
31 |
32 |         self.mode = mode
33 |
34 |         # create sampling grid
35 |         vectors = [torch.arange(0, s) for s in size]
36 |         grids = torch.meshgrid(vectors)
37 |         grid = torch.stack(grids)
38 |         grid = torch.unsqueeze(grid, 0)
39 |         grid = grid.type(torch.FloatTensor)
40 |
41 |         # registering the grid as a buffer cleanly moves it to the GPU, but it also
42 |         # adds it to the state dict. this is annoying since everything in the state dict
43 |         # is included when saving weights to disk, so the model files are way bigger
44 |         # than they need to be. so far, there does not appear to be an elegant solution.
45 |         # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
46 |         self.register_buffer('grid', grid)
47 |
48 |     def forward(self, src, flow):
49 |         # new locations
50 |         new_locs = self.grid + flow
51 |         shape = flow.shape[2:]
52 |
53 |         # need to normalize grid values to [-1, 1] for resampler
54 |         for i in range(len(shape)):
55 |             new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
56 |
57 |         # move channels dim to last position
58 |         # grid_sample expects sampling coordinates in (x, y, z) order, so reverse the channel order
59 |         if len(shape) == 2:
60 |             new_locs = new_locs.permute(0, 2, 3, 1)
61 |             new_locs = new_locs[..., [1, 0]]
62 |         elif len(shape) == 3:
63 |             new_locs = new_locs.permute(0, 2, 3, 4, 1)
64 |             new_locs = new_locs[..., [2, 1, 0]]
65 |
66 |         return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
67 |
68 |
69 | class VecInt(nn.Module):
70 |     """
71 |     Integrates a vector field via scaling and squaring.
72 |     """
73 |
74 |     def __init__(self, inshape, nsteps=7):
75 |         super().__init__()
76 |
77 |         assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
78 |         self.nsteps = nsteps
79 |         self.scale = 1.0 / (2 ** self.nsteps)
80 |         self.transformer = SpatialTransformer(inshape)
81 |
82 |     def forward(self, vec):
83 |         vec = vec * self.scale
84 |         for _ in range(self.nsteps):
85 |             vec = vec + self.transformer(vec, vec)
86 |         return vec
87 |
88 |
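# VecInt integrates a stationary velocity field by scaling and squaring: the
# field is divided by 2**nsteps so each step is small, then composed with
# itself nsteps times via the spatial transformer, approximating the
# exponential map phi = exp(v) and keeping the deformation close to
# diffeomorphic. Schematically (with warp = SpatialTransformer):
#
#     vec = vec / 2**nsteps
#     for _ in range(nsteps):
#         vec = vec + warp(vec, vec)   # compose the field with itself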
89 | class ResizeTransform(nn.Module):
90 |
91 |     def __init__(self, vel_resize, ndims):
92 |         super().__init__()
93 |         self.factor = 1.0 / vel_resize
94 |         self.mode = 'linear'
95 |         if ndims == 2:
96 |             self.mode = 'bi' + self.mode
97 |         elif ndims == 3:
98 |             self.mode = 'tri' + self.mode
99 |
100 |     def forward(self, x):
101 |         if self.factor < 1:
102 |             # resize first to save memory
103 |             x = nnf.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)
104 |             x = self.factor * x
105 |
106 |         elif self.factor > 1:
107 |             # multiply first to save memory
108 |             x = self.factor * x
109 |             x = nnf.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)
110 |
111 |         # don't do anything if resize is 1
112 |         return x
113 |
114 |
115 | class ConvBlock(nn.Module):
116 |     """
117 |     Convolution followed by LeakyReLU activation.
118 |     """
119 |
120 |     def __init__(self, ndims, in_channels, out_channels, kernel_size=3, stride=1, padding=1, alpha=0.1):
121 |         super().__init__()
122 |
123 |         Conv = getattr(nn, 'Conv%dd' % ndims)
124 |         self.main = Conv(in_channels, out_channels, kernel_size, stride, padding)
125 |         self.activation = nn.LeakyReLU(alpha)
126 |
127 |     def forward(self, x):
128 |         out = self.main(x)
129 |         out = self.activation(out)
130 |         return out
131 |
132 | class ConvInsBlock(nn.Module):
133 |     """
134 |     Convolution followed by InstanceNorm and LeakyReLU.
135 |     """
136 |
137 |     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, alpha=0.1):
138 |         super().__init__()
139 |
140 |         self.main = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
141 |         self.norm = nn.InstanceNorm3d(out_channels)
142 |         self.activation = nn.LeakyReLU(alpha)
143 |
144 |     def forward(self, x):
145 |         out = self.main(x)
146 |         out = self.norm(out)
147 |         out = self.activation(out)
148 |         return out
149 |
150 | class UpConvBlock(nn.Module):
151 |     def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, alpha=0.1):
152 |         super(UpConvBlock, self).__init__()
153 |
154 |         self.upconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)
155 |
156 |         self.actout = nn.Sequential(
157 |             nn.InstanceNorm3d(out_channels),
158 |             nn.LeakyReLU(alpha)
159 |         )
160 |     def forward(self, x):
161 |         x = self.upconv(x)
162 |         return self.actout(x)
163 |
164 | class ResBlock(nn.Module):
165 |     """
166 |     VoxRes module
167 |     """
168 |
169 |     def __init__(self, channel, alpha=0.1):
170 |         super(ResBlock, self).__init__()
171 |         self.block = nn.Sequential(
172 |             nn.InstanceNorm3d(channel),
173 |             nn.LeakyReLU(alpha),
174 |             nn.Conv3d(channel, channel, kernel_size=3, padding=1)
175 |         )
176 |         self.actout = nn.Sequential(
177 |             nn.InstanceNorm3d(channel),
178 |             nn.LeakyReLU(alpha),
179 |         )
180 |     def forward(self, x):
181 |         out = self.block(x) + x
182 |         return self.actout(out)
183 |
184 |
185 | class Encoder(nn.Module):
186 |     """
187 |     Convolutional feature encoder (full, 1/2, 1/4 and 1/8 resolutions).
188 |     """
189 |
190 |     def __init__(self, in_channel=1, first_out_channel=16):
191 |         super(Encoder, self).__init__()
192 |
193 |         c = first_out_channel
194 |         self.conv0 = ConvInsBlock(in_channel, c, 3, 1)
195 |
196 |         self.conv1 = nn.Sequential(
197 |             nn.Conv3d(c, 2*c, kernel_size=3, stride=2, padding=1),#80
198 |             ResBlock(2*c)
199 |         )
200 |
201 |         self.conv2 = nn.Sequential(
202 |             nn.Conv3d(2*c, 4*c, kernel_size=3, stride=2, padding=1),#40
203 |             ResBlock(4*c)
204 |         )
205 |
206 |         self.conv3 = nn.Sequential(
207 |             nn.Conv3d(4*c, 8*c, kernel_size=3, stride=2, padding=1),#20
208 |             ResBlock(8*c)
209 |         )
210 |
211 |     def forward(self, x):
212 |         out0 = self.conv0(x)  # 1
213 |         out1 = self.conv1(out0)  # 1/2
214 |         out2 = self.conv2(out1)  # 1/4
215 |         out3 = self.conv3(out2)  # 1/8
216 |
217 |         return [out0, out1, out2, out3]
218 |
219 | class CConv(nn.Module):
220 |     def __init__(self, channel):
221 |         super(CConv, self).__init__()
222 |
223 |         c = channel
224 |
225 |         self.conv = nn.Sequential(
226 |             ConvInsBlock(c, c, 3, 1),
227 |             ConvInsBlock(c, c, 3, 1)
228 |         )
229 |
230 |     def forward(self, float_fm, fixed_fm, d_fm):
231 |         concat_fm = torch.cat([float_fm, fixed_fm, d_fm], dim=1)
232 |         x = self.conv(concat_fm)
233 |         return x
234 |
235 | class RDP(nn.Module):
236 |     def __init__(self, inshape=(160,192,160), flow_multiplier=1., in_channel=1, channels=16):
237 |         super(RDP, self).__init__()
238 |         self.flow_multiplier = flow_multiplier
239 |         self.channels = channels
240 |         self.step = 7
241 |         self.inshape = inshape
242 |
243 |         c = self.channels
244 |         self.encoder_moving = Encoder(in_channel=in_channel, first_out_channel=c)
245 |         self.encoder_fixed = Encoder(in_channel=in_channel, first_out_channel=c)
246 |
247 |         self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
248 |         self.upsample_trilin = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
249 |
250 |
251 |         self.warp = nn.ModuleList()
252 |         self.diff = 
nn.ModuleList() 253 | for i in range(4): 254 | self.warp.append(SpatialTransformer([s // 2**i for s in inshape])) 255 | self.diff.append(VecInt([s // 2**i for s in inshape])) 256 | 257 | # bottleNeck 258 | self.cconv_4 = nn.Sequential( 259 | ConvInsBlock(16 * c, 8 * c, 3, 1), 260 | ConvInsBlock(8 * c, 8 * c, 3, 1) 261 | ) 262 | # warp scale 2 263 | self.defconv4 = nn.Conv3d(8*c, 3, 3, 1, 1) 264 | self.defconv4.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv4.weight.shape)) 265 | self.defconv4.bias = nn.Parameter(torch.zeros(self.defconv4.bias.shape)) 266 | self.dconv4 = nn.Sequential( 267 | ConvInsBlock(3*8*c, 8*c), 268 | ConvInsBlock(8*c, 8*c) 269 | ) 270 | 271 | self.upconv3 = UpConvBlock(8*c, 4*c, 4, 2) 272 | self.cconv_3 = CConv(3*4*c) 273 | 274 | # warp scale 1 275 | self.defconv3 = nn.Conv3d(3*4*c, 3, 3, 1, 1) 276 | self.defconv3.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv3.weight.shape)) 277 | self.defconv3.bias = nn.Parameter(torch.zeros(self.defconv3.bias.shape)) 278 | self.dconv3 = ConvInsBlock(3 * 4 * c, 4 * c) 279 | 280 | self.upconv2 = UpConvBlock(3*4*c, 2*c, 4, 2) 281 | self.cconv_2 = CConv(3*2*c) 282 | 283 | # warp scale 0 284 | self.defconv2 = nn.Conv3d(3*2*c, 3, 3, 1, 1) 285 | self.defconv2.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv2.weight.shape)) 286 | self.defconv2.bias = nn.Parameter(torch.zeros(self.defconv2.bias.shape)) 287 | self.dconv2 = ConvInsBlock(3 * 2 * c, 2 * c) 288 | 289 | self.upconv1 = UpConvBlock(3*2*c, c, 4, 2) 290 | self.cconv_1 = CConv(3*c) 291 | 292 | # decoder layers 293 | self.defconv1 = nn.Conv3d(3*c, 3, 3, 1, 1) 294 | self.defconv1.weight = nn.Parameter(Normal(0, 1e-5).sample(self.defconv1.weight.shape)) 295 | self.defconv1.bias = nn.Parameter(torch.zeros(self.defconv1.bias.shape)) 296 | #self.dconv1 = ConvInsBlock(3 * c, c) 297 | 298 | def forward(self, moving, fixed): 299 | 300 | # encode stage 301 | M1, M2, M3, M4 = self.encoder_moving(moving) 302 | F1, F2, F3, F4 = self.encoder_fixed(fixed) 303 | # c=16, 2c, 4c, 8c # 160, 80, 40, 20 304 | 305 | # first dec layer 306 | C4 = torch.cat([F4, M4], dim=1) 307 | C4 = self.cconv_4(C4) # (1,128,20,24,20) 308 | flow = self.defconv4(C4) # (1,3,20,24,20) 309 | flow = self.diff[3](flow) 310 | warped = self.warp[3](M4, flow) 311 | C4 = self.dconv4(torch.cat([F4, warped, C4], dim=1)) 312 | v = self.defconv4(C4) # (1,3,20,24,20) 313 | w = self.diff[3](v) 314 | 315 | 316 | D3 = self.upconv3(C4) # (1, 64, 40, 48, 40) 317 | flow = self.upsample_trilin(2*(self.warp[3](flow, w)+w)) 318 | warped = self.warp[2](M3, flow) # (1, 64, 40, 48, 40) 319 | C3 = self.cconv_3(F3, warped, D3) # (1, 3 * 64, 40, 48, 40) 320 | v = self.defconv3(C3) 321 | w = self.diff[2](v) 322 | flow = self.warp[2](flow, w)+w 323 | warped = self.warp[2](M3, flow) # (1, 64, 40, 48, 40) 324 | D3 = self.dconv3(C3) 325 | C3 = self.cconv_3(F3, warped, D3) # (1, 3 * 64, 40, 48, 40) 326 | v = self.defconv3(C3) 327 | w = self.diff[2](v) 328 | 329 | D2 = self.upconv2(C3) 330 | flow = self.upsample_trilin(2*(self.warp[2](flow, w)+w)) 331 | warped = self.warp[1](M2, flow) 332 | C2 = self.cconv_2(F2, warped, D2) 333 | v = self.defconv2(C2) # (1,3,80,96,80) 334 | w = self.diff[1](v) 335 | flow = self.warp[1](flow, w)+w 336 | warped = self.warp[1](M2, flow) 337 | D2 = self.dconv2(C2) 338 | C2 = self.cconv_2(F2, warped, D2) 339 | v = self.defconv2(C2) # (1,3,80,96,80) 340 | w = self.diff[1](v) 341 | flow = self.warp[1](flow, w)+w 342 | warped = self.warp[1](M2, flow) 343 | D2 = self.dconv2(C2) 344 | C2 = self.cconv_2(F2, 
warped, D2)
345 |         v = self.defconv2(C2)  # (1,3,80,96,80)
346 |         w = self.diff[1](v)
347 |
348 |         D1 = self.upconv1(C2)  # (1,16,160,192,160)
349 |         flow = self.upsample_trilin(2*(self.warp[1](flow, w)+w))  # (1,3,160,192,160)
350 |         warped = self.warp[0](M1, flow)  # (1,16,160,192,160)
351 |         C1 = self.cconv_1(F1, warped, D1)  # (1,48,160,192,160)
352 |         v = self.defconv1(C1)
353 |         w = self.diff[0](v)
354 |         flow = self.warp[0](flow, w)+w  # (1,3,160,192,160)
355 |
356 |         y_moved = self.warp[0](moving, flow)
357 |
358 |         return y_moved, flow
359 |
360 | if __name__ == '__main__':
361 |     size = (1, 1, 80, 96, 80)
362 |     model = RDP(size[2:])
363 |     # print(str(model))
364 |     A = torch.ones(size)
365 |     B = torch.ones(size)
366 |     out, flow = model(A, B)
367 |     print(out.shape, flow.shape)
368 |
--------------------------------------------------------------------------------
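As a minimal end-to-end sanity check (not one of the repository files above), the network and losses can be wired together the same way `train.py` does. A CUDA device is assumed, since `losses.NCC_vxm` creates its window filter on `"cuda"`; the small `(32, 32, 32)` volume is illustrative only and, like the real data, must be divisible by 8 along each axis:

```python
import torch
from models import RDP
from losses import NCC_vxm, Grad3d

size = (1, 1, 32, 32, 32)                 # (batch, channel, D, H, W)
moving = torch.rand(size).cuda()
fixed = torch.rand(size).cuda()

model = RDP(inshape=size[2:], channels=16).cuda()
moved, flow = model(moving, fixed)        # warped image and displacement field

sim = NCC_vxm()(fixed, moved)             # similarity term (negative local NCC)
reg = Grad3d(penalty='l2')(flow, fixed)   # smoothness term on the flow
loss = 1 * sim + 1 * reg                  # loss weights [1, 1], as in train.py
print(moved.shape, flow.shape, loss.item())
```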