├── LICENSE
├── README.md
├── convert_landmark_dense.py
├── convert_landmark_dirlab300.py
├── images
│   ├── groupreg_flowchart.png
│   ├── res_1.png
│   ├── res_2.png
│   ├── res_3.png
│   ├── res_4.png
│   ├── res_5.png
│   └── res_6.png
├── model
│   ├── loss.py
│   ├── regnet.py
│   ├── unet.py
│   └── util.py
├── registration_dirlab.py
└── utils
    └── structure.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2020, vincentme
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GroupRegNet

Implementation of [GroupRegNet: A Groupwise One-shot Deep Learning-based 4D Image Registration Method](https://iopscience.iop.org/article/10.1088/1361-6560/abd956).

Zhang, Y., Wu, X., Gach, H. M., Li, H. H., & Yang, D. (2021). GroupRegNet: a groupwise one-shot deep learning-based 4D image registration method. Physics in Medicine & Biology.

GroupRegNet is an unsupervised deep learning-based DIR method that employs both groupwise registration and a one-shot strategy to register 4D medical images and then to determine all pairwise deformation vector fields (DVFs).

## Requirements

- PyTorch
- SimpleITK: to read mhd files
- tqdm (the logging module comes from the Python standard library)

## Usage

To evaluate GroupRegNet with `registration_dirlab.py`, the [DIR-Lab](https://www.dir-lab.com/index.html) dataset is required. The original data needs to be converted into mhd format.

To convert the original landmarks of the Landmark300 and LandmarkDense sets, run `convert_landmark_dirlab300.py` and `convert_landmark_dense.py`, respectively.
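The repository does not ship a converter for the images themselves. Below is a minimal sketch of such a conversion, assuming the DIR-Lab 4DCT volumes are headerless 16-bit little-endian raw files; the helper name `img_to_mhd`, the example file names, and the Case 1 dimensions/spacing (taken from the DIR-Lab case tables) are assumptions, not part of this repository:

```python
import numpy as np
import SimpleITK as sitk

def img_to_mhd(img_path, mhd_path, shape = (94, 256, 256), spacing = (0.97, 0.97, 2.5)):
    # Assumption: DIR-Lab raw volumes are headerless 16-bit integers; the exact
    # dtype (int16 vs uint16) should be verified per case.
    # shape is (d, h, w); spacing is (w, h, d) in mm, as SimpleITK expects.
    image = np.fromfile(img_path, dtype = np.int16).reshape(shape)
    sitk_image = sitk.GetImageFromArray(image)
    sitk_image.SetSpacing(spacing)
    sitk.WriteImage(sitk_image, mhd_path)

img_to_mhd('case1_T00_s.img', 'case1_T00.mhd')  # hypothetical file names
```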

## Overall structure

![groupreg_flowchart](images/groupreg_flowchart.png)

## Result

![res_1](images/res_1.png)


![res_2](images/res_2.png)


![res_3](images/res_3.png)

Sizes, shapes, and locations of the contoured tumor targets, shown in violet shading, in coronal views of the EI phases of three patient cases.
![res_4](images/res_4.png)

Comparison of the tracked targets in ten phases by GroupRegNet and manual contouring of cases 1 and 3. The images are shown in coronal views, and the horizontal line in each figure is at the same height for visual reference.
![res_5](images/res_5.png)

![res_6](images/res_6.png)
--------------------------------------------------------------------------------
/convert_landmark_dense.py:
--------------------------------------------------------------------------------
import numpy as np
import torch, os
from scipy import io


mat_file_list = [file for file in os.listdir('.') if file.endswith('mat')]

for mat_file in mat_file_list:
    case = mat_file.split('.')[0]
    dirqa = io.loadmat(mat_file)
    landmark_00 = dirqa['landmark_EI'].astype(np.float32) - 1. # change to 0-based indexing
    landmark_50 = dirqa['landmark_EE'].astype(np.float32) - 1.
    landmark_00[:, [0, 1]] = landmark_00[:, [1, 0]] # swap the first two columns to get (w, h, d) order in the last dimension
    landmark_50[:, [0, 1]] = landmark_50[:, [1, 0]]
    disp_00_50 = landmark_50 - landmark_00 # (n, 3)
    landmark = {'landmark_00':landmark_00, 'landmark_50':landmark_50, 'disp_00_50':disp_00_50}
    torch.save(landmark, f'{case}_00_50.pt')
--------------------------------------------------------------------------------
/convert_landmark_dirlab300.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

landmark_00 = np.genfromtxt('Case1_300_T00_xyz.txt', dtype = np.int64) - 1 # change to 0-based indexing
landmark_50 = np.genfromtxt('Case1_300_T50_xyz.txt', dtype = np.int64) - 1 # (n, 3), (w, h, d) order in the last dimension
disp_00_50 = (landmark_50 - landmark_00).astype(np.float32) # (n, 3)

landmark = {'landmark_00':landmark_00, 'landmark_50':landmark_50, 'disp_00_50':disp_00_50}
torch.save(landmark, 'Case1_300_00_50.pt')
--------------------------------------------------------------------------------
/images/groupreg_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincentme/GroupRegNet/88a91ad0fbfe82800e5d9e48de999fcaff45c023/images/groupreg_flowchart.png
--------------------------------------------------------------------------------
/images/res_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincentme/GroupRegNet/88a91ad0fbfe82800e5d9e48de999fcaff45c023/images/res_1.png
--------------------------------------------------------------------------------
/images/res_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincentme/GroupRegNet/88a91ad0fbfe82800e5d9e48de999fcaff45c023/images/res_2.png
--------------------------------------------------------------------------------
/images/res_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincentme/GroupRegNet/88a91ad0fbfe82800e5d9e48de999fcaff45c023/images/res_3.png
--------------------------------------------------------------------------------
/images/res_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincentme/GroupRegNet/88a91ad0fbfe82800e5d9e48de999fcaff45c023/images/res_4.png
--------------------------------------------------------------------------------
/images/res_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincentme/GroupRegNet/88a91ad0fbfe82800e5d9e48de999fcaff45c023/images/res_5.png
--------------------------------------------------------------------------------
/images/res_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincentme/GroupRegNet/88a91ad0fbfe82800e5d9e48de999fcaff45c023/images/res_6.png
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.nn as nn
import scipy.ndimage
import numpy as np

torch.backends.cudnn.deterministic = True

class NCC(nn.Module):
    '''
    Calculate the local normalized cross-correlation coefficient between two images.

    Parameters
    ----------
    dim : int
        Dimension of the input images.
    windows_size : int
        Side length of the square window used to calculate the local NCC.
    '''
    def __init__(self, dim, windows_size = 11):
        super().__init__()
        assert dim in (2, 3)
        self.dim = dim
        self.num_stab_const = 1e-4 # numerical stability constant

        self.windows_size = windows_size

        self.pad = windows_size//2
        self.window_volume = windows_size**self.dim
        if self.dim == 2:
            self.conv = F.conv2d
        elif self.dim == 3:
            self.conv = F.conv3d

    def forward(self, I, J):
        '''
        Parameters
        ----------
        I and J : (n, 1, h, w) or (n, 1, d, h, w)
            Torch tensors of the same spatial shape. The number of images in the first dimension can differ, in which case broadcasting is used.

        Returns
        -------
        NCC : scalar
            Negated average local normalized cross-correlation coefficient (lower means more similar).
        '''
        try:
            I_sum = self.conv(I, self.sum_filter, padding = self.pad)
        except AttributeError:
            # lazily create the box (sum) filter on the first call so that it matches the input dtype and device
            self.sum_filter = torch.ones([1, 1] + [self.windows_size, ]*self.dim, dtype = I.dtype, device = I.device)
            I_sum = self.conv(I, self.sum_filter, padding = self.pad)

        J_sum = self.conv(J, self.sum_filter, padding = self.pad) # (n, 1, h, w) or (n, 1, d, h, w)
        I2_sum = self.conv(I*I, self.sum_filter, padding = self.pad)
        J2_sum = self.conv(J*J, self.sum_filter, padding = self.pad)
        IJ_sum = self.conv(I*J, self.sum_filter, padding = self.pad)

        cross = torch.clamp(IJ_sum - I_sum*J_sum/self.window_volume, min = self.num_stab_const)
        I_var = torch.clamp(I2_sum - I_sum**2/self.window_volume, min = self.num_stab_const)
        J_var = torch.clamp(J2_sum - J_sum**2/self.window_volume, min = self.num_stab_const)

        cc = cross/((I_var*J_var)**0.5)

        return -torch.mean(cc)



def smooth_loss(disp, image):
    '''
    Calculate the edge-aware smoothness loss: the mean absolute forward difference of the displacement field, weighted by exp(-|forward difference of the image|) so that displacement discontinuities are penalized less at image edges.

    Parameters
    ----------
    disp : (n, 2, h, w) or (n, 3, d, h, w)
        displacement field

    image : (n, 1, d, h, w) or (1, 1, d, h, w)
        image whose gradients modulate the penalty

    '''

    image_shape = disp.shape
    dim = len(image_shape[2:])

    d_disp = torch.zeros((image_shape[0], dim) + tuple(image_shape[1:]), dtype = disp.dtype, device = disp.device)
    d_image = torch.zeros((image_shape[0], dim) + tuple(image_shape[1:]), dtype = disp.dtype, device = disp.device)

    # forward difference
    if dim == 2:
        d_disp[:, 1, :, :-1, :] = (disp[:, :, 1:, :] - disp[:, :, :-1, :])
        d_disp[:, 0, :, :, :-1] = (disp[:, :, :, 1:] - disp[:, :, :, :-1])
        d_image[:, 1, :, :-1, :] = (image[:, :, 1:, :] - image[:, :, :-1, :])
        d_image[:, 0, :, :, :-1] = (image[:, :, :, 1:] - image[:, :, :, :-1])

    elif dim == 3:
        d_disp[:, 2, :, :-1, :, :] = (disp[:, :, 1:, :, :] - disp[:, :, :-1, :, :])
        d_disp[:, 1, :, :, :-1, :] = (disp[:, :, :, 1:, :] - disp[:, :, :, :-1, :])
        d_disp[:, 0, :, :, :, :-1] = (disp[:, :, :, :, 1:] - disp[:, :, :, :, :-1])

        d_image[:, 2, :, :-1, :, :] = (image[:, :, 1:, :, :] - image[:, :, :-1, :, :])
        d_image[:, 1, :, :, :-1, :] = (image[:, :, :, 1:, :] - image[:, :, :, :-1, :])
        d_image[:, 0, :, :, :, :-1] = (image[:, :, :, :, 1:] - image[:, :, :, :, :-1])

    loss = torch.mean(torch.sum(torch.abs(d_disp), dim = 2, keepdim = True)*torch.exp(-torch.abs(d_image)))

    return loss


--------------------------------------------------------------------------------
/model/regnet.py:
--------------------------------------------------------------------------------
from . import unet
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging


class RegNet_single(nn.Module):
    '''
    Groupwise implicit template CNN registration method.

    Parameters
    ----------
    dim : int
        Dimension of the input image.
    n : int
        Number of images in the group.
    scale : float, optional
        Factor by which the input is downsampled before being fed to the U-Net. The default is 1.
    depth : int, optional
        Depth of the network. The maximum number of channels will be 2**(depth - 1) times the initial_channels. The default is 5.
    initial_channels : int, optional
        Number of initial channels. The default is 64.
    normalization : bool, optional
        Whether to add instance normalization after activation. The default is True.
    '''
    def __init__(self, dim, n, scale = 1, depth = 5, initial_channels = 64, normalization = True):

        super().__init__()
        assert dim in (2, 3)
        self.dim = dim
        self.n = n
        self.scale = scale

        self.unet = unet.UNet(in_channels = n, out_channels = dim*n, dim = dim, depth = depth, initial_channels = initial_channels, normalization = normalization)
        self.spatial_transform = SpatialTransformer(self.dim)

    def forward(self, input_image):
        '''
        Parameters
        ----------
        input_image : (n, 1, h, w) or (n, 1, d, h, w)
            The first dimension contains the grouped input images.

        Returns
        -------
        warped_input_image : (n, 1, h, w) or (n, 1, d, h, w)
            Warped input image.
        template : (1, 1, h, w) or (1, 1, d, h, w)
            Implicit template image derived by averaging the warped_input_image.
        disp_t2i : (n, 2, h, w) or (n, 3, d, h, w)
            Flow field from implicit template to input image. The starting point of the displacement is on the regular grid defined on the implicit template and the ending point corresponds to the same structure in the input image.
        warped_template : (n, 1, h, w) or (n, 1, d, h, w)
            Warped template images that should match the original input image.
        disp_i2t : (n, 2, h, w) or (n, 3, d, h, w)
            Flow field from input image to implicit template. The starting point of the displacement is on the regular grid defined on the input image and the ending point corresponds to the same structure in the implicit template.
        '''

        original_image_shape = input_image.shape[2:]

        if self.scale < 1:
            scaled_image = F.interpolate(torch.transpose(input_image, 0, 1), scale_factor = self.scale, align_corners = True, mode = 'bilinear' if self.dim == 2 else 'trilinear', recompute_scale_factor = False) # (1, n, h, w) or (1, n, d, h, w)
        else:
            scaled_image = torch.transpose(input_image, 0, 1)

        scaled_image_shape = scaled_image.shape[2:]
        scaled_disp_t2i = torch.squeeze(self.unet(scaled_image), 0).reshape(self.n, self.dim, *scaled_image_shape) # (n, 2, h, w) or (n, 3, d, h, w)
        if self.scale < 1:
            disp_t2i = torch.nn.functional.interpolate(scaled_disp_t2i, size = original_image_shape, mode = 'bilinear' if self.dim == 2 else 'trilinear', align_corners = True)
        else:
            disp_t2i = scaled_disp_t2i

        warped_input_image = self.spatial_transform(input_image, disp_t2i) # (n, 1, h, w) or (n, 1, d, h, w)
        template = torch.mean(warped_input_image, 0, keepdim = True) # (1, 1, h, w) or (1, 1, d, h, w)

        if self.scale < 1:
            scaled_template = torch.nn.functional.interpolate(template, size = scaled_image_shape, mode = 'bilinear' if self.dim == 2 else 'trilinear', align_corners = True)
        else:
            scaled_template = template
        res = {'disp_t2i':disp_t2i, 'scaled_disp_t2i':scaled_disp_t2i, 'warped_input_image':warped_input_image, 'template':template, 'scaled_template':scaled_template}
        return res

class RegNet_pairwise(nn.Module):
    '''
    Pairwise CNN registration method.

    Parameters
    ----------
    dim : int
        Dimension of the input image.
    scale : float, optional
        Factor by which the input is downsampled before being fed to the U-Net. The default is 1.
    depth : int, optional
        Depth of the network. The maximum number of channels will be 2**(depth - 1) times the initial_channels. The default is 5.
    initial_channels : int, optional
        Number of initial channels. The default is 64.
    normalization : bool, optional
        Whether to add instance normalization after activation. The default is True.
    '''
    def __init__(self, dim, scale = 1, depth = 5, initial_channels = 64, normalization = True):

        super().__init__()
        assert dim in (2, 3)
        self.dim = dim
        self.scale = scale

        self.unet = unet.UNet(in_channels = 2, out_channels = dim, dim = dim, depth = depth, initial_channels = initial_channels, normalization = normalization)
        self.spatial_transform = SpatialTransformer(self.dim)

    def forward(self, fixed_image, moving_image):
        '''
        Parameters
        ----------
        fixed_image, moving_image : (h, w) or (d, h, w)
            Fixed and moving image to be registered

        Returns
        -------
        warped_moving_image : (h, w) or (d, h, w)
            Warped moving image that should match the fixed image.
        disp : (2, h, w) or (3, d, h, w)
            Flow field from fixed image to moving image.
        scaled_disp : (2, h', w') or (3, d', h', w')
            Flow field at the downsampled (scaled) resolution, before upsampling.
        '''

        original_image_shape = fixed_image.shape
        input_image = torch.unsqueeze(torch.stack((fixed_image, moving_image), dim = 0), 0) # (1, 2, h, w) or (1, 2, d, h, w)

        if self.scale < 1:
            scaled_image = F.interpolate(input_image, scale_factor = self.scale, align_corners = True, mode = 'bilinear' if self.dim == 2 else 'trilinear', recompute_scale_factor = False) # (1, 2, h, w) or (1, 2, d, h, w)
        else:
            scaled_image = input_image

        scaled_image_shape = scaled_image.shape[2:]
        scaled_disp = torch.squeeze(self.unet(scaled_image), 0).reshape(self.dim, *scaled_image_shape) # (2, h, w) or (3, d, h, w)
        if self.scale < 1:
            disp = torch.nn.functional.interpolate(torch.unsqueeze(scaled_disp, 0), size = original_image_shape, mode = 'bilinear' if self.dim == 2 else 'trilinear', align_corners = True)
        else:
            disp = torch.unsqueeze(scaled_disp, 0)

        warped_moving_image = self.spatial_transform(input_image[:, 1:], disp).squeeze() # (h, w) or (d, h, w)

        res = {'disp':disp.squeeze(0), 'scaled_disp':scaled_disp.squeeze(0), 'warped_moving_image':warped_moving_image}
        return res

class SpatialTransformer(nn.Module):
    # 2D or 3D spatial transformer network to calculate the warped moving image


    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.grid_dict = {}
        self.norm_coeff_dict = {}

    def forward(self, input_image, flow):
        '''
        input_image: (n, 1, h, w) or (n, 1, d, h, w)
        flow: (n, 2, h, w) or (n, 3, d, h, w)

        return:
            warped moving image, (n, 1, h, w) or (n, 1, d, h, w)
        '''
        img_shape = input_image.shape[2:]
        if img_shape in self.grid_dict:
            grid = self.grid_dict[img_shape]
            norm_coeff = self.norm_coeff_dict[img_shape]
        else:
            grids = torch.meshgrid([torch.arange(0, s) for s in img_shape])
            grid = torch.stack(grids[::-1], dim = 0) # 2 x h x w or 3 x d x h x w, the data in the second dimension is in the order of [w, h, d]
            grid = torch.unsqueeze(grid, 0)
            grid = grid.to(dtype = flow.dtype, device = flow.device)
            norm_coeff = 2./(torch.tensor(img_shape[::-1], dtype = flow.dtype, device = flow.device) - 1.) # the coefficients to map image coordinates to [-1, 1]
            self.grid_dict[img_shape] = grid
            self.norm_coeff_dict[img_shape] = norm_coeff
            logging.info(f'\nAdd grid shape {tuple(img_shape)}')
        new_grid = grid + flow

        if self.dim == 2:
            new_grid = new_grid.permute(0, 2, 3, 1) # n x h x w x 2
        elif self.dim == 3:
            new_grid = new_grid.permute(0, 2, 3, 4, 1) # n x d x h x w x 3

        if len(input_image) != len(new_grid):
            # broadcast along the batch dimension so that one image can be warped by n flows (or one flow can warp n images)
            batch = max(len(input_image), len(new_grid))
            input_image = input_image.expand(batch, *input_image.shape[1:])
            new_grid = new_grid.expand(batch, *new_grid.shape[1:])

        warped_input_img = F.grid_sample(input_image, new_grid*norm_coeff - 1., mode = 'bilinear', align_corners = True, padding_mode = 'border')
        return warped_input_img

--------------------------------------------------------------------------------
/model/unet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class UNet(nn.Module):
    '''
    U-net implementation with modifications.
        1. Works for input of 2D or 3D
        2. Change batch normalization to instance normalization

    Adapted from https://github.com/jvanvugt/pytorch-unet/blob/master/unet.py


    Parameters
    ----------
    in_channels : int
        number of input channels.
    out_channels : int
        number of output channels.
    dim : (2 or 3), optional
        The dimension of the input data. The default is 2.
    depth : int, optional
        Depth of the network. The maximum number of channels will be 2**(depth - 1) times the initial_channels. The default is 5.
    initial_channels : int, optional
        Number of initial channels. The default is 32.
    normalization : bool, optional
        Whether to add instance normalization after activation. The default is True.
    '''
    def __init__(self, in_channels, out_channels, dim = 2, depth = 5, initial_channels = 32, normalization = True):

        super().__init__()
        assert dim in (2, 3)
        self.dim = dim

        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(self.depth):
            current_channels = 2**i*initial_channels
            self.down_path.append(ConvBlock(prev_channels, current_channels, dim, normalization))
            prev_channels = current_channels

        self.up_path = nn.ModuleList()
        for i in reversed(range(self.depth - 1)):
            current_channels = 2**i*initial_channels
            self.up_path.append(UpBlock(prev_channels, current_channels, dim, normalization))
            prev_channels = current_channels

        if dim == 2:
            self.last = nn.Conv2d(prev_channels, out_channels, kernel_size = 1)
        elif dim == 3:
            self.last = nn.Conv3d(prev_channels, out_channels, kernel_size = 1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i < self.depth - 1:
                blocks.append(x)
                x = F.interpolate(x, scale_factor = 0.5, mode = 'bilinear' if self.dim == 2 else 'trilinear', align_corners = True, recompute_scale_factor = False)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dim, normalization, LeakyReLU_slope = 0.2):
        super().__init__()
        block = []
        if dim == 2:
            block.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding = 1))
            if normalization:
                block.append(nn.InstanceNorm2d(out_channels))
            block.append(nn.LeakyReLU(LeakyReLU_slope))
        elif dim == 3:
            block.append(nn.Conv3d(in_channels, out_channels, kernel_size=3, padding = 1))
            if normalization:
                block.append(nn.InstanceNorm3d(out_channels))
            block.append(nn.LeakyReLU(LeakyReLU_slope))
        else:
            raise ValueError(f'dim should be 2 or 3, got {dim}')
        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out

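# Usage sketch (illustrative addition, not part of the original file): a ConvBlock
# preserves spatial size (kernel 3, padding 1) and only changes the channel count.
# For example, assuming a 3D input:
#     block = ConvBlock(in_channels = 1, out_channels = 32, dim = 3, normalization = True)
#     y = block(torch.rand(2, 1, 16, 32, 32))   # -> (2, 32, 16, 32, 32)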

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dim, normalization):
        super().__init__()
        self.dim = dim
        if dim == 2:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 1)
        elif dim == 3:
            self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 1)
        self.conv_block = ConvBlock(in_channels, out_channels, dim, normalization)


    def forward(self, x, skip):
        x_up = F.interpolate(x, skip.shape[2:], mode = 'bilinear' if self.dim == 2 else 'trilinear', align_corners = True)
        x_up_conv = self.conv(x_up)
        out = torch.cat([x_up_conv, skip], 1)
        out = self.conv_block(out)
        return out

--------------------------------------------------------------------------------
/model/util.py:
--------------------------------------------------------------------------------
from . import regnet
import numpy as np
import torch

class StopCriterion(object):
    def __init__(self, stop_std = 0.001, query_len = 100, num_min_iter = 200):
        self.query_len = query_len
        self.stop_std = stop_std
        self.loss_list = []
        self.loss_min = 1.
        self.loss_min_i = 0 # index of the minimum loss, initialized so that stop() is safe before any new minimum is recorded
        self.num_min_iter = num_min_iter

    def add(self, loss):
        self.loss_list.append(loss)
        if loss < self.loss_min:
            self.loss_min = loss
            self.loss_min_i = len(self.loss_list)

    def stop(self):
        # return True if the stop criteria are met
        query_list = self.loss_list[-self.query_len:]
        query_std = np.std(query_list)
        if query_std < self.stop_std and self.loss_list[-1] - self.loss_min < self.stop_std/3. and len(self.loss_list) > self.loss_min_i and len(self.loss_list) > self.num_min_iter:
            return True
        else:
            return False

class CalcDisp(object):
    def __init__(self, dim, calc_device = 'cuda'):
        self.device = torch.device(calc_device)
        self.dim = dim
        self.spatial_transformer = regnet.SpatialTransformer(dim = dim)

    def inverse_disp(self, disp, threshold = 0.01, max_iteration = 20):
        '''
        Compute the inverse field. Implementation of "A simple fixed-point approach to invert a deformation field".

        disp : (n, 2, h, w) or (n, 3, d, h, w) or (2, h, w) or (3, d, h, w)
            displacement field
        '''
        forward_disp = disp.detach().to(device = self.device)
        if disp.ndim < self.dim + 2:
            forward_disp = torch.unsqueeze(forward_disp, 0)
        backward_disp = torch.zeros_like(forward_disp)
        backward_disp_old = backward_disp.clone()
        for i in range(max_iteration):
            backward_disp = -self.spatial_transformer(forward_disp, backward_disp)
            diff = torch.max(torch.abs(backward_disp - backward_disp_old)).item()
            if diff < threshold:
                break
            backward_disp_old = backward_disp.clone()
        if disp.ndim < self.dim + 2:
            backward_disp = torch.squeeze(backward_disp, 0)

        return backward_disp

    def compose_disp(self, disp_i2t, disp_t2i, mode = 'corr'):
        '''
        Compute the composition field.

        disp_i2t: (n, 3, d, h, w)
            displacement field from the input image to the template

        disp_t2i: (n, 3, d, h, w)
            displacement field from the template to the input image

        mode: string, default 'corr'
            'corr' means composing only corresponding displacement fields along the batch dimension; the result shape is the same as the input, (n, 3, d, h, w)
            'all' means generating all pairs of composed displacement fields; the result shape is (n, n, 3, d, h, w)
        '''
        disp_i2t_t = disp_i2t.detach().to(device = self.device)
        disp_t2i_t = disp_t2i.detach().to(device = self.device)
        if disp_i2t.ndim < self.dim + 2:
            disp_i2t_t = torch.unsqueeze(disp_i2t_t, 0)
        if disp_t2i.ndim < self.dim + 2:
            disp_t2i_t = torch.unsqueeze(disp_t2i_t, 0)

        if mode == 'corr':
            composed_disp = self.spatial_transformer(disp_t2i_t, disp_i2t_t) + disp_i2t_t # (n, 2, h, w) or (n, 3, d, h, w)
        elif mode == 'all':
            assert len(disp_i2t_t) == len(disp_t2i_t)
            n, _, *image_shape = disp_i2t_t.shape
            disp_i2t_nxn = torch.repeat_interleave(torch.unsqueeze(disp_i2t_t, 1), n, 1) # (n, n, 2, h, w) or (n, n, 3, d, h, w)
            disp_i2t_nn = disp_i2t_nxn.reshape(n*n, self.dim, *image_shape) # (n*n, 2, h, w) or (n*n, 3, d, h, w), the order in the first dimension is [0_T, 0_T, ..., 0_T, 1_T, 1_T, ..., 1_T, ..., n_T, n_T, ..., n_T]
            disp_t2i_nn = torch.repeat_interleave(torch.unsqueeze(disp_t2i_t, 0), n, 0).reshape(n*n, self.dim, *image_shape) # (n*n, 2, h, w) or (n*n, 3, d, h, w), the order in the first dimension is [0_T, 1_T, ..., n_T, 0_T, 1_T, ..., n_T, ..., 0_T, 1_T, ..., n_T]
            composed_disp = self.spatial_transformer(disp_t2i_nn, disp_i2t_nn).reshape(n, n, self.dim, *image_shape) + disp_i2t_nxn # (n, n, 2, h, w) or (n, n, 3, d, h, w)
        else:
            raise ValueError(f"mode should be 'corr' or 'all', got '{mode}'")
        if disp_i2t.ndim < self.dim + 2 and disp_t2i.ndim < self.dim + 2:
            composed_disp = torch.squeeze(composed_disp)
        return composed_disp

class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)
--------------------------------------------------------------------------------
/registration_dirlab.py:
--------------------------------------------------------------------------------
import model.regnet, model.loss, model.util, utils.structure
import torch, os
import SimpleITK as sitk
import matplotlib.pyplot as plt; plot_dpi = 300
import numpy as np
import logging, tqdm
logging.basicConfig(level=logging.INFO, format = '%(levelname)s: %(message)s')
from scipy import interpolate

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

case = 1
crop_range = [slice(0, 83), slice(43, 200), slice(10, 250)]
pixel_spacing = np.array([0.97, 0.97, 2.5], dtype = np.float32)


# case = 2
# crop_range = [slice(5, 98), slice(30, 195), slice(8, 243)]
# pixel_spacing = np.array([1.16, 1.16, 2.5], dtype = np.float32)


# case = 3
# crop_range = [slice(0, 95), slice(42, 209), slice(10, 248)]
# pixel_spacing = np.array([1.15, 1.15, 2.5], dtype = np.float32)


# case = 4
# crop_range = [slice(0, 90), slice(45, 209), slice(11, 242)]
# pixel_spacing = np.array([1.13, 1.13, 2.5], dtype = np.float32)


# case = 5
# crop_range = [slice(0, 90), slice(60, 222), slice(16, 237)]
# pixel_spacing = np.array([1.10, 1.10, 2.5], dtype = np.float32)


# case = 6
# crop_range = [slice(10, 107), slice(144, 328), slice(132, 426)]
# pixel_spacing = np.array([0.97, 0.97, 2.5], dtype = np.float32)


# case = 7
# crop_range = [slice(13, 108), slice(141, 331), slice(114, 423)]
# pixel_spacing = np.array([0.97, 0.97, 2.5], dtype = np.float32)


# case = 8
# crop_range = [slice(18, 118), slice(84, 299), slice(113, 390)]
# pixel_spacing = np.array([0.97, 0.97, 2.5], dtype = np.float32)

# case = 9
# crop_range = [slice(0, 70), slice(126, 334), slice(128, 390)]
# pixel_spacing = np.array([0.97, 0.97, 2.5], dtype = np.float32)


# case = 10
# crop_range = [slice(0, 90), slice(119, 333), slice(140, 382)]
# pixel_spacing = np.array([0.97, 0.97, 2.5], dtype = np.float32)



data_folder = f'/data/dirlab/Case{case}Pack/Image_MHD/'
landmark_file = f'/data/dirlab/Case{case}Pack/ExtremePhases/case{case}_00_50.pt'
states_folder = '/result/general_reg/dirlab/'
config = dict(
    dim = 3, # dimension of the input image
    intensity_scale_const = 1000., # (image - intensity_shift_const)/intensity_scale_const
    intensity_shift_const = 1000.,
    # scale = 0.7,
    scale = 0.5,
    initial_channels = 32,
    depth = 4,
    max_num_iteration = 3000,
    normalization = True, # whether to use the normalization layer
    learning_rate = 1e-2,
    smooth_reg = 1e-3,
    cyclic_reg = 1e-2,
    ncc_window_size = 5,
    load = False,
    load_optimizer = False,
    group_index_list = None,
    pair_disp_indexes = [0, 5],
    pair_disp_calc_interval = 20,
    stop_std = 0.0007,
    stop_query_len = 100,
    )
config = utils.structure.Struct(**config)

landmark_info = torch.load(landmark_file)
landmark_disp = landmark_info['disp_00_50'] # (w, h, d) order in the last dimension
landmark_00 = landmark_info['landmark_00']
landmark_50 = landmark_info['landmark_50']
crop_min = np.min(np.concatenate((landmark_00, landmark_50), axis = 0), axis = 0) - 8
crop_max = np.max(np.concatenate((landmark_00, landmark_50), axis = 0), axis = 0) + 8
print(crop_min)
print(crop_max)

image_file_list = sorted([file_name for file_name in os.listdir(data_folder) if file_name.lower().endswith('mhd')])
image_list = [sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(data_folder, file_name))) for file_name in image_file_list]
input_image = torch.stack([torch.from_numpy(image)[None] for image in image_list], 0)
if config.group_index_list is not None:
    input_image = input_image[config.group_index_list]

input_image = (input_image - config.intensity_shift_const)/config.intensity_scale_const

input_image = input_image[:, :, crop_range[0], crop_range[1], crop_range[2]]
image_shape = np.array(input_image.shape[2:]) # (d, h, w)
num_image = input_image.shape[0] # number of images in the group
regnet = model.regnet.RegNet_single(dim = config.dim, n = num_image, scale = config.scale, depth = config.depth, initial_channels = config.initial_channels, normalization = config.normalization)

ncc_loss = model.loss.NCC(config.dim, config.ncc_window_size)
regnet = regnet.to(device)
input_image = input_image.to(device)
ncc_loss = ncc_loss.to(device)
optimizer = torch.optim.Adam(regnet.parameters(), lr = config.learning_rate)
calcdisp = model.util.CalcDisp(dim = config.dim, calc_device = 'cuda')

if config.load:
    state_file = os.path.join(states_folder, config.load)
    if os.path.exists(state_file):
        states = torch.load(state_file, map_location = device)
        regnet.load_state_dict(states['model'])
        if config.load_optimizer:
            optimizer.load_state_dict(states['optimizer'])
            logging.info(f'load model and optimizer state {config.load}.pth')
        else:
            logging.info(f'load model state {config.load}.pth')

grid_tuple = [np.arange(grid_length, dtype = np.float32) for grid_length in image_shape]
landmark_00_converted = np.flip(landmark_00, axis = 1) - np.array([crop_range[0].start, crop_range[1].start, crop_range[2].start], dtype = np.float32) # from (w, h, d) to (d, h, w) order, shifted by the crop offset


diff_stats = []
stop_criterion = model.util.StopCriterion(stop_std = config.stop_std, query_len = config.stop_query_len)
pbar = tqdm.tqdm(range(config.max_num_iteration))
for i in pbar:
    optimizer.zero_grad()
    res = regnet(input_image)

    total_loss = 0.
    if 'disp_i2t' in res:
        simi_loss = (ncc_loss(res['warped_input_image'], res['template']) + ncc_loss(input_image, res['warped_template']))/2.
    else:
        simi_loss = ncc_loss(res['warped_input_image'], res['template'])
    total_loss += simi_loss


    if config.smooth_reg > 0:
        if 'disp_i2t' in res:
            # smooth_loss needs an image to weight the penalty; the scaled template is used for both directions here
            smooth_loss = (model.loss.smooth_loss(res['scaled_disp_t2i'], res['scaled_template']) + model.loss.smooth_loss(res['scaled_disp_i2t'], res['scaled_template']))/2.
        else:
            smooth_loss = model.loss.smooth_loss(res['scaled_disp_t2i'], res['scaled_template'])
        total_loss += config.smooth_reg*smooth_loss
        smooth_loss_item = smooth_loss.item()
    else:
        smooth_loss_item = 0

    if config.cyclic_reg > 0:
        if 'disp_i2t' in res:
            cyclic_loss = ((torch.mean((torch.sum(res['scaled_disp_t2i'], 0))**2))**0.5 + (torch.mean((torch.sum(res['scaled_disp_i2t'], 0))**2))**0.5)/2.
        else:
            cyclic_loss = (torch.mean((torch.sum(res['scaled_disp_t2i'], 0))**2))**0.5
        total_loss += config.cyclic_reg*cyclic_loss
        cyclic_loss_item = cyclic_loss.item()
    else:
        cyclic_loss_item = 0

    total_loss.backward()
    optimizer.step()

    stop_criterion.add(simi_loss.item())
    if stop_criterion.stop():
        break

    pbar.set_description(f'{i}, simi. loss {simi_loss.item():.4f}, smooth loss {smooth_loss_item:.3f}, cyclic loss {cyclic_loss_item:.3f}')

    if i % config.pair_disp_calc_interval == 0:
        if 'disp_i2t' in res:
            disp_i2t = res['disp_i2t'][config.pair_disp_indexes]
        else:
            disp_i2t = calcdisp.inverse_disp(res['disp_t2i'][config.pair_disp_indexes])
        composed_disp = calcdisp.compose_disp(disp_i2t, res['disp_t2i'][config.pair_disp_indexes], mode = 'all')
        composed_disp_np = composed_disp.cpu().numpy() # (2, 2, 3, d, h, w)

        inter = interpolate.RegularGridInterpolator(grid_tuple, np.moveaxis(composed_disp_np[0, 1], 0, -1))
        calc_landmark_disp = inter(landmark_00_converted)

        diff = (np.sum(((calc_landmark_disp - landmark_disp)*pixel_spacing)**2, 1))**0.5
        diff_stats.append([i, np.mean(diff), np.std(diff)])
        print(f'\ndiff: {np.mean(diff):.2f}+-{np.std(diff):.2f}({np.max(diff):.2f})')


if 'disp_i2t' in res:
    disp_i2t = res['disp_i2t'][config.pair_disp_indexes]
else:
    disp_i2t = calcdisp.inverse_disp(res['disp_t2i'][config.pair_disp_indexes])
composed_disp = calcdisp.compose_disp(disp_i2t, res['disp_t2i'][config.pair_disp_indexes], mode = 'all')
composed_disp_np = composed_disp.cpu().numpy() # (2, 2, 3, d, h, w)
inter = interpolate.RegularGridInterpolator(grid_tuple, np.moveaxis(composed_disp_np[0, 1], 0, -1))
calc_landmark_disp = inter(landmark_00_converted)

diff = (np.sum(((calc_landmark_disp - landmark_disp)*pixel_spacing)**2, 1))**0.5
diff_stats.append([i, np.mean(diff), np.std(diff)])
print(f'\ndiff: {np.mean(diff):.2f}+-{np.std(diff):.2f}({np.max(diff):.2f})')
diff_stats = np.array(diff_stats)


res['composed_disp_np'] = composed_disp_np
states = {'config': config, 'model': regnet.state_dict(), 'optimizer': optimizer.state_dict(), 'registration_result':res, 'loss_list':stop_criterion.loss_list, 'diff_stats':diff_stats}
index = len([file for file in os.listdir(states_folder) if file.endswith('pth')])
states_file = f'reg_dirlab_case{case}_{index:03d}.pth'
torch.save(states, os.path.join(states_folder, states_file))

logging.info(f'save model and optimizer state {states_file}')


plt.figure(dpi = plot_dpi)
plt.plot(stop_criterion.loss_list, label = 'simi')
plt.title('similarity loss vs iteration')
--------------------------------------------------------------------------------
/utils/structure.py:
--------------------------------------------------------------------------------
class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)
--------------------------------------------------------------------------------
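Below is a minimal end-to-end sketch (not part of the repository) of the one-shot groupwise optimization loop implemented in `registration_dirlab.py`, run on a tiny synthetic group; the sizes, iteration count, and hyperparameters are made up for illustration:

```python
import torch
import model.regnet, model.loss

n, d, h, w = 4, 16, 48, 48                       # tiny synthetic 4D group (4 phases)
input_image = torch.rand(n, 1, d, h, w)
regnet = model.regnet.RegNet_single(dim = 3, n = n, scale = 0.5, depth = 3, initial_channels = 8)
ncc_loss = model.loss.NCC(dim = 3, windows_size = 5)
optimizer = torch.optim.Adam(regnet.parameters(), lr = 1e-2)

for i in range(10):                              # a few iterations, for illustration only
    optimizer.zero_grad()
    res = regnet(input_image)
    # similarity to the implicit template plus edge-aware smoothness regularization
    total_loss = ncc_loss(res['warped_input_image'], res['template']) \
        + 1e-3*model.loss.smooth_loss(res['scaled_disp_t2i'], res['scaled_template'])
    total_loss.backward()
    optimizer.step()

print(res['disp_t2i'].shape)                     # (4, 3, 16, 48, 48): template -> each phase
```

Pairwise DVFs between any two phases can then be obtained from these template-to-phase fields with `model.util.CalcDisp.inverse_disp` and `compose_disp`, as done near the end of `registration_dirlab.py`.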