├── __init__.py ├── reconstruction ├── __init__.py ├── data │ ├── __init__.py │ ├── volume_sampler.py │ ├── subsample.py │ └── mri_data.py ├── models │ ├── denoisers │ │ ├── __init__.py │ │ ├── kspace_net.py │ │ ├── norm_unet.py │ │ ├── unet.py │ │ └── mwcnn.py │ ├── __init__.py │ ├── recurrent_varnet.py │ ├── cinenet.py │ ├── varnet.py │ ├── recurrent_cinenet.py │ └── recurrent_xpdnet.py ├── pl_modules │ ├── __init__.py │ ├── cinenet_module.py │ ├── varnet_module.py │ ├── xpdnet_module.py │ └── data_module.py └── utils │ ├── __init__.py │ ├── coil_combine.py │ ├── evaluate.py │ ├── losses.py │ ├── padding.py │ ├── math.py │ └── fftc.py ├── docs └── crossnet.png ├── traintest_scripts ├── dirs_path.yaml ├── run_inference.py ├── varnet │ └── train_test_varnet.py ├── cinenet │ └── train_test_cinenet.py └── xpdnet │ └── train_test_xpdnet.py ├── requirements.txt ├── LICENSE ├── notebooks └── BART_setup.ipynb └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /reconstruction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/crossnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f78bono/deep-cine-cardiac-mri/HEAD/docs/crossnet.png -------------------------------------------------------------------------------- /traintest_scripts/dirs_path.yaml: -------------------------------------------------------------------------------- 1 | data_path: /path/to/data 2 | log_path: /root/traintest_scripts 3 | save_path: /root/results -------------------------------------------------------------------------------- /reconstruction/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .mri_data import SliceDataset, CombinedSliceDataset 2 | from .volume_sampler import VolumeSampler 3 | -------------------------------------------------------------------------------- /reconstruction/models/denoisers/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import Unet 2 | from .norm_unet import NormUnet, NormUnet3D 3 | from .mwcnn import MWCNN 4 | from .kspace_net import KSpaceCNN -------------------------------------------------------------------------------- /reconstruction/pl_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .mri_module import MriModule 2 | from .varnet_module import VarNetModule 3 | from .cinenet_module import CineNetModule 4 | from .xpdnet_module import XPDNetModule 5 | from .data_module import MriDataModule -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.0 2 | scipy==1.7.0 3 | scikit_image==0.18.0 4 | h5py==3.3.0 5 | imageio==2.4.1 6 | torchvision==0.10.0 7 | torchtext==0.10.0 8 | torchaudio==0.9.0 9 | torchmetrics==0.4.0 10 | torch==1.9.0 11 | pytorch_lightning==1.3.8 -------------------------------------------------------------------------------- /reconstruction/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .varnet import SensitivityModel, VarNet, VarNetBlock 2 | from .cinenet 
import CineNet, CineNetBlock 3 | from .xpdnet import XPDNet, XPDNetBlock 4 | from .recurrent_varnet import VarNet_RNN 5 | from .recurrent_cinenet import CineNet_RNN 6 | from .recurrent_xpdnet import XPDNet_RNN -------------------------------------------------------------------------------- /reconstruction/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .coil_combine import rss, rss_complex 2 | 3 | from .fftc import ( 4 | fft1c, 5 | ifft1c, 6 | fft2c, 7 | ifft2c, 8 | fftshift, 9 | ifftshift, 10 | roll, 11 | ) 12 | 13 | from .losses import SSIMLoss 14 | 15 | from .math import ( 16 | complex_abs, 17 | complex_abs_sq, 18 | complex_conj, 19 | complex_mul, 20 | tensor_to_complex_np, 21 | real_to_complex_multi_ch, 22 | complex_to_real_multi_ch, 23 | ) 24 | 25 | from .padding import pad_for_mwcnn, unpad_from_mwcnn 26 | -------------------------------------------------------------------------------- /reconstruction/utils/coil_combine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .math import complex_abs_sq 3 | 4 | 5 | def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor: 6 | """ 7 | Compute the Root Sum of Squares (RSS). 8 | 9 | RSS is computed assuming that dim is the coil dimension. 10 | 11 | Args: 12 | data: The input tensor 13 | dim: The dimensions along which to apply the RSS transform 14 | 15 | Returns: 16 | The RSS value. 17 | """ 18 | return torch.sqrt((data ** 2).sum(dim)) 19 | 20 | 21 | def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor: 22 | """ 23 | Compute the Root Sum of Squares (RSS) for complex inputs. 24 | 25 | RSS is computed assuming that dim is the coil dimension. 26 | 27 | Args: 28 | data: The input tensor 29 | dim: The dimensions along which to apply the RSS transform 30 | 31 | Returns: 32 | The RSS value. 33 | """ 34 | return torch.sqrt(complex_abs_sq(data).sum(dim)) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Francesco Bono 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
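As a quick illustration of the `reconstruction/utils` helpers exported above, here is a minimal, hypothetical sketch of a zero-filled reconstruction from multi-coil k-space; the tensor shape is illustrative, and complex data follows the repository convention of a real tensor with a trailing dimension of size 2.
```
import torch

from reconstruction.utils import ifft2c, rss_complex

# Illustrative multi-coil k-space sample: (n_coils, height, width, 2),
# with the last dimension holding the real and imaginary parts.
kspace = torch.randn(15, 200, 200, 2)

# Per-coil centered inverse FFT, then root-sum-of-squares over the coil dimension.
coil_images = ifft2c(kspace)
zero_filled = rss_complex(coil_images, dim=0)

print(zero_filled.shape)  # torch.Size([200, 200])
```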
22 | -------------------------------------------------------------------------------- /reconstruction/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 4 | 5 | 6 | def mse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: 7 | """Compute Mean Squared Error (MSE)""" 8 | return np.mean((gt - pred) ** 2) 9 | 10 | 11 | def nmse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: 12 | """Compute Normalized Mean Squared Error (NMSE)""" 13 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 14 | 15 | 16 | def psnr( 17 | gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None 18 | ) -> np.ndarray: 19 | """Compute Peak Signal to Noise Ratio metric (PSNR)""" 20 | if maxval is None: 21 | maxval = gt.max() 22 | return peak_signal_noise_ratio(gt, pred, data_range=maxval) 23 | 24 | 25 | def ssim( 26 | gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None 27 | ) -> np.ndarray: 28 | """Compute time-averaged Structural Similarity Index Metric (SSIM)""" 29 | if not gt.ndim == 3: 30 | raise ValueError("Unexpected number of dimensions in ground truth.") 31 | if not gt.ndim == pred.ndim: 32 | raise ValueError("Ground truth dimensions does not match pred.") 33 | 34 | maxval = gt.max() if maxval is None else maxval 35 | 36 | ssim = 0 37 | for slice_num in range(gt.shape[0]): 38 | ssim = ssim + structural_similarity( 39 | gt[slice_num], pred[slice_num], data_range=maxval 40 | ) 41 | 42 | return ssim / gt.shape[0] 43 | 44 | 45 | METRIC_FUNCS = dict( 46 | MSE=mse, 47 | NMSE=nmse, 48 | PSNR=psnr, 49 | SSIM=ssim, 50 | ) -------------------------------------------------------------------------------- /reconstruction/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SSIMLoss(nn.Module): 7 | """ 8 | Time-averaged SSIM loss module. 9 | """ 10 | 11 | def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03): 12 | """ 13 | Args: 14 | win_size: Window size for SSIM calculation. 15 | k1: k1 parameter for SSIM calculation. 16 | k2: k2 parameter for SSIM calculation. 17 | """ 18 | super().__init__() 19 | self.win_size = win_size 20 | self.k1, self.k2 = k1, k2 21 | self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size ** 2) 22 | NP = win_size ** 2 23 | self.cov_norm = NP / (NP - 1) 24 | 25 | def forward(self, Xt: torch.Tensor, Yt: torch.Tensor, data_range: torch.Tensor): 26 | assert isinstance(self.w, torch.Tensor) 27 | 28 | ssims = 0. 
29 | Nt = Xt.shape[2] 30 | 31 | for t in range(Nt): 32 | X = Xt[:,:,t,:] 33 | Y = Yt[:,:,t,:] 34 | data_range = torch.Tensor([Y.max()]).to(Y.device) 35 | 36 | data_range = data_range[:, None, None, None] 37 | C1 = (self.k1 * data_range) ** 2 38 | C2 = (self.k2 * data_range) ** 2 39 | ux = F.conv2d(X, self.w) # type: ignore 40 | uy = F.conv2d(Y, self.w) 41 | uxx = F.conv2d(X * X, self.w) 42 | uyy = F.conv2d(Y * Y, self.w) 43 | uxy = F.conv2d(X * Y, self.w) 44 | vx = self.cov_norm * (uxx - ux * ux) 45 | vy = self.cov_norm * (uyy - uy * uy) 46 | vxy = self.cov_norm * (uxy - ux * uy) 47 | A1, A2, B1, B2 = ( 48 | 2 * ux * uy + C1, 49 | 2 * vxy + C2, 50 | ux ** 2 + uy ** 2 + C1, 51 | vx + vy + C2, 52 | ) 53 | D = B1 * B2 54 | S = (A1 * A2) / D 55 | 56 | ssims += 1 - S.mean() 57 | 58 | return ssims / Nt 59 | -------------------------------------------------------------------------------- /reconstruction/models/denoisers/kspace_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class KSpaceCNN(nn.Module): 7 | """ 8 | A simple CNN model performing k-space interpolation of a buffer in the 9 | k-space correction module of XPDNet. The model architecture consists of 10 | consecutive convolutional layers, each followed by an activation function. 11 | """ 12 | def __init__( 13 | self, 14 | in_chans: int, 15 | out_chans: int, 16 | n_convs: int = 3, 17 | n_filters: int = 16, 18 | ): 19 | """ 20 | Args: 21 | in_chans: Number of channels in the input to the CNN model. 22 | out_chans: Number of channels in the output of the CNN model. 23 | n_convs: Number of consecutive convolutional layers. 24 | n_filters: Number of convolutional filters. 25 | """ 26 | super().__init__() 27 | 28 | self.in_chans = in_chans 29 | self.out_chans = out_chans 30 | self.n_convs = n_convs 31 | self.n_filters = n_filters 32 | 33 | convs = nn.ModuleList([ 34 | nn.Conv3d(self.in_chans, self.n_filters, 3, padding='same'), 35 | nn.ReLU(inplace=True), 36 | ]) 37 | for _ in range(1, self.n_convs-1): 38 | convs.append( 39 | nn.Conv3d(self.n_filters, self.n_filters, 3, padding='same'), 40 | ) 41 | convs.append(nn.ReLU(inplace=True)) 42 | convs.append(nn.Conv3d(self.n_filters, self.out_chans, 3, padding='same')) 43 | 44 | self.layers = nn.Sequential(*convs) 45 | 46 | 47 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 48 | """ 49 | Args: 50 | inputs: Input tensor of shape `(N, T, N_coils, H, W, in_chans)` 51 | Returns: 52 | Output tensor of shape `(N, T, N_coils, H, W, out_chans)` 53 | """ 54 | 55 | b, t, c, h, w, ch = inputs.shape 56 | 57 | outputs = inputs.permute(0,2,5,1,3,4).reshape(b*c, ch, t, h, w) 58 | outputs = self.layers(outputs) 59 | outputs = outputs.reshape(b, c, self.out_chans, t, h, w).permute(0,3,1,4,5,2) 60 | 61 | return outputs -------------------------------------------------------------------------------- /reconstruction/utils/padding.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | def pad_for_mwcnn(x: torch.Tensor, n_scales: int) -> Tuple[torch.Tensor, List[torch.Tensor]]: 8 | """ 9 | Pads a tensor for input to a multi-scale wavelet CNN. 10 | Padding is applied to the last two dimensions.
11 | 12 | Source: 13 | https://github.com/zaccharieramzi/fastmri-reproducible-benchmark/blob/master/fastmri_recon/models/utils/pad_for_pool.py 14 | 15 | Args: 16 | x: A PyTorch tensor with at least 2 dimensions. 17 | n_scales: Number of scales in multi-scale wavelet CNN. 18 | 19 | Returns: 20 | The padded tensor and the corresponding padding values. 21 | """ 22 | if x.dim() < 2: 23 | raise ValueError("Number of dimensions cannot be less than 2") 24 | problematic_dims = torch.tensor(x.shape[-2:]) 25 | 26 | k = torch.div(problematic_dims, 2**n_scales, rounding_mode='floor' ) 27 | n_pad = torch.where( 28 | torch.eq(torch.remainder(problematic_dims, 2**n_scales), 0), 29 | 0, 30 | (k+1) * 2**n_scales - problematic_dims 31 | ) 32 | 33 | padding_left = torch.where( 34 | torch.logical_or( 35 | torch.eq(torch.remainder(problematic_dims, 2), 0), 36 | torch.eq(n_pad, 0), 37 | ), 38 | torch.div(n_pad, 2, rounding_mode='floor'), 39 | 1 + torch.div(n_pad, 2, rounding_mode='floor'), 40 | ) 41 | padding_right = torch.div(n_pad, 2, rounding_mode='floor') 42 | 43 | paddings = [] 44 | for i in range(2): 45 | paddings += [padding_left[-1-i], padding_right[-1-i]] 46 | 47 | x_padded = F.pad(x, paddings) 48 | 49 | return x_padded, paddings 50 | 51 | 52 | 53 | def unpad_from_mwcnn(x: torch.Tensor, pad: List[torch.Tensor]) -> torch.Tensor: 54 | """ 55 | Unpads the output tensor from a multi-scale wavelet CNN. 56 | 57 | Args: 58 | x: A padded PyTorch tensor with at least 2 dimensions. 59 | pad: The corresponding left and right padding values, 60 | ordered from the last to second last dimensions in x. 61 | 62 | Returns: 63 | The unpadded tensor. 64 | """ 65 | if pad[1] == 0: 66 | return x[..., pad[2]:, pad[0]:] if pad[3] == 0 else x[..., pad[2]:-pad[3], pad[0]:] 67 | elif pad[3] == 0: 68 | return x[..., pad[2]:, pad[0]:] if pad[1] == 0 else x[..., pad[2]:, pad[0]:-pad[1]] 69 | else: 70 | return x[..., pad[2]:-pad[3], pad[0]:-pad[1]] -------------------------------------------------------------------------------- /traintest_scripts/run_inference.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pathlib import Path 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | import reconstruction as rec 10 | from reconstruction.data.transforms import center_crop_to_smallest 11 | 12 | 13 | class InferenceTransform(nn.Module): 14 | """ 15 | Data saving module for reconstruction output of the inference 16 | dataset. This is generally a subset of the test set and it is 17 | used for visualisation purposes. 18 | """ 19 | def __init__( 20 | self, 21 | model: nn.Module, 22 | model_type: str, 23 | save_path: Path, 24 | ): 25 | """ 26 | Args: 27 | model: Trained model used for dynamic reconstruction of 28 | MRI data. 29 | model_type: One of 'varnet', 'cinenet', 'xpdnet'. 30 | save_path: Path to directory where saving data will be 31 | stored. 32 | """ 33 | super(InferenceTransform, self).__init__() 34 | 35 | assert model_type in ['varnet', 'cinenet', 'xpdnet'], \ 36 | 'Wrong model_type arg.' 
37 | 38 | self.model_type = model_type 39 | self.save_path = save_path 40 | self.device = 'cuda' 41 | self.model = model.to(self.device).eval() 42 | 43 | def forward( 44 | self, 45 | masked_kspace: torch.Tensor, 46 | mask: torch.Tensor, 47 | target: torch.Tensor, 48 | fname: str, 49 | sens_maps: Optional[torch.Tensor]=None, 50 | ) -> float: 51 | 52 | # Image reconstruction of inference dataset using trained model 53 | model_time_start = time.time() 54 | masked_k = masked_kspace.to(self.device) 55 | mask = mask.to(self.device) 56 | if self.model_type == 'cinenet': 57 | sens_maps = sens_maps.to(self.device) 58 | output = self.model(masked_k, mask, sens_maps) 59 | else: 60 | output = self.model(masked_k, mask) 61 | model_time_end = time.time() 62 | output = output.cpu() 63 | 64 | # Generate zero-filled reconstruction for qualitative comparison 65 | scaling_factor = torch.sqrt(torch.prod(torch.as_tensor(masked_kspace.shape[-3:-1]))) 66 | images = rec.utils.ifft2c(masked_kspace, norm=None) * scaling_factor 67 | zero_filled = rec.utils.rss_complex(images, dim=2) 68 | 69 | # Crop all tensors to the same size (for visualisation) 70 | target, output = center_crop_to_smallest(target, output) 71 | target, zero_filled = center_crop_to_smallest(target, zero_filled) 72 | 73 | # Store ndarray-converted tensors to save_path 74 | target = target.numpy().astype('float32') 75 | output = output.numpy().astype('float32') 76 | zero_filled = zero_filled.numpy().astype('float32') 77 | 78 | np.save(str(self.save_path) + f'/target_{fname[0]}.npy', target[0]) 79 | np.save(str(self.save_path) + f'/output_{self.model_type}_{fname[0]}.npy', output[0]) 80 | np.save(str(self.save_path) + f'/zero_filled_{fname[0]}.npy', zero_filled[0]) 81 | 82 | return model_time_end - model_time_start 83 | -------------------------------------------------------------------------------- /reconstruction/utils/math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Complex multiplication. 8 | 9 | This multiplies two complex tensors assuming that they are both stored as 10 | real arrays with the last dimension being the complex dimension. 11 | 12 | Args: 13 | x: A PyTorch tensor with the last dimension of size 2. 14 | y: A PyTorch tensor with the last dimension of size 2. 15 | 16 | Returns: 17 | A PyTorch tensor with the last dimension of size 2. 18 | """ 19 | if not x.shape[-1] == y.shape[-1] == 2: 20 | raise ValueError("Tensors do not have separate complex dim.") 21 | 22 | re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] 23 | im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] 24 | 25 | return torch.stack((re, im), dim=-1) 26 | 27 | 28 | def complex_conj(x: torch.Tensor) -> torch.Tensor: 29 | """ 30 | Complex conjugate. 31 | 32 | This applies the complex conjugate assuming that the input array has the 33 | last dimension as the complex dimension. 34 | 35 | Args: 36 | x: A PyTorch tensor with the last dimension of size 2. 37 | y: A PyTorch tensor with the last dimension of size 2. 38 | 39 | Returns: 40 | A PyTorch tensor with the last dimension of size 2. 41 | """ 42 | if not x.shape[-1] == 2: 43 | raise ValueError("Tensor does not have separate complex dim.") 44 | 45 | return torch.stack((x[..., 0], -x[..., 1]), dim=-1) 46 | 47 | 48 | def complex_abs(data: torch.Tensor) -> torch.Tensor: 49 | """ 50 | Compute the absolute value of a complex valued input tensor. 
51 | 52 | Args: 53 | data: A complex valued tensor, where the size of the final dimension 54 | should be 2. 55 | 56 | Returns: 57 | Absolute value of data. 58 | """ 59 | if not data.shape[-1] == 2: 60 | raise ValueError("Tensor does not have separate complex dim.") 61 | 62 | return (data ** 2).sum(dim=-1).sqrt() 63 | 64 | 65 | def complex_abs_sq(data: torch.Tensor) -> torch.Tensor: 66 | """ 67 | Compute the squared absolute value of a complex tensor. 68 | 69 | Args: 70 | data: A complex valued tensor, where the size of the final dimension 71 | should be 2. 72 | 73 | Returns: 74 | Squared absolute value of data. 75 | """ 76 | if not data.shape[-1] == 2: 77 | raise ValueError("Tensor does not have separate complex dim.") 78 | 79 | return (data ** 2).sum(dim=-1) 80 | 81 | 82 | def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray: 83 | """ 84 | Converts a complex torch tensor to numpy array. 85 | 86 | Args: 87 | data: Input data to be converted to numpy. 88 | 89 | Returns: 90 | Complex numpy version of data. 91 | """ 92 | data = data.numpy() 93 | 94 | return data[..., 0] + 1j * data[..., 1] 95 | 96 | 97 | def real_to_complex_multi_ch(x: torch.Tensor, n: int) -> torch.Tensor: 98 | """ 99 | Real to complex tensor conversion. 100 | 101 | Converts a stack of n complex tensors, stored as a torch.float array 102 | with last dimension (channel dimension) of size 2n, into a single 103 | torch.complex tensor with n channels. 104 | 105 | Args: 106 | x: A torch.float-type tensor where the first n>=2 elements of the 107 | last dimension correspond to the real part and the last n>=2 108 | elements of the last dimension correspond to the imaginary 109 | part of the stacked complex tensors. 110 | n: The number of stacked complex tensors. 111 | 112 | Returns: 113 | A torch.complex-type tensor with the last dimension of size n. 114 | """ 115 | if not x.shape[-1] == 2*n: 116 | raise ValueError("Real and imaginary parts do not have the same size") 117 | 118 | return torch.complex(x[..., :n], x[..., n:]) 119 | 120 | 121 | def complex_to_real_multi_ch(x: torch.Tensor) -> torch.Tensor: 122 | """ 123 | Complex to real tensor conversion. 124 | 125 | Converts a torch.complex tensor with the last dimension >= 1 126 | into a torch.float tensor with stacked real and imaginary parts. 127 | 128 | Args: 129 | x: A torch.complex-type tensor with the last dimension >= 1. 130 | 131 | Returns: 132 | A torch.float-type tensor with last dimension double the size 133 | of that of x. 134 | """ 135 | return torch.cat([x.real, x.imag], dim=-1) -------------------------------------------------------------------------------- /reconstruction/data/volume_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | This source code is based on the fastMRI repository from Facebook AI 3 | Research and is used as a general framework to handle MRI data. Link: 4 | 5 | https://github.com/facebookresearch/fastMRI 6 | """ 7 | 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from reconstruction.data.mri_data import CombinedSliceDataset, SliceDataset 13 | from torch.utils.data import Sampler 14 | 15 | 16 | class VolumeSampler(Sampler): 17 | """ 18 | Sampler for volumetric MRI data. 19 | 20 | Based on pytorch DistributedSampler, the difference is that all instances 21 | from the same MRI volume need to go to the same node for distributed 22 | training. 
Dataset example is a list of tuples (fname, instance), where 23 | fname is essentially the volume name (actually a filename). 24 | """ 25 | 26 | def __init__( 27 | self, 28 | dataset: Union[CombinedSliceDataset, SliceDataset], 29 | num_replicas: Optional[int] = None, 30 | rank: Optional[int] = None, 31 | shuffle: bool = True, 32 | seed: int = 0, 33 | ): 34 | """ 35 | Args: 36 | dataset: An MRI dataset (e.g., SliceData). 37 | num_replicas: Number of processes participating in distributed 38 | training. By default, :attr:`rank` is retrieved from the 39 | current distributed group. 40 | rank: Rank of the current process within :attr:`num_replicas`. By 41 | default, :attr:`rank` is retrieved from the current distributed 42 | group. 43 | shuffle: If ``True`` (default), sampler will shuffle the indices. 44 | seed: random seed used to shuffle the sampler if 45 | :attr:`shuffle=True`. This number should be identical across 46 | all processes in the distributed group. 47 | """ 48 | if num_replicas is None: 49 | if not dist.is_available(): 50 | raise RuntimeError("Requires distributed package to be available") 51 | num_replicas = dist.get_world_size() 52 | if rank is None: 53 | if not dist.is_available(): 54 | raise RuntimeError("Requires distributed package to be available") 55 | rank = dist.get_rank() 56 | self.dataset = dataset 57 | self.num_replicas = num_replicas 58 | self.rank = rank 59 | self.epoch = 0 60 | self.shuffle = shuffle 61 | self.seed = seed 62 | 63 | # get all file names and split them based on number of processes 64 | self.all_volume_names = sorted( 65 | set(str(example[0]) for example in self.dataset.examples) 66 | ) 67 | self.all_volumes_split: List[List[str]] = [] 68 | for rank_num in range(self.num_replicas): 69 | self.all_volumes_split.append( 70 | [ 71 | self.all_volume_names[i] 72 | for i in range( 73 | rank_num, len(self.all_volume_names), self.num_replicas 74 | ) 75 | ] 76 | ) 77 | 78 | # get slice indices for each file name 79 | rank_indices: List[List[int]] = [[] for _ in range(self.num_replicas)] 80 | for i, example in enumerate(self.dataset.examples): 81 | vname = str(example[0]) 82 | for rank_num in range(self.num_replicas): 83 | if vname in self.all_volumes_split[rank_num]: 84 | rank_indices[rank_num].append(i) 85 | break 86 | 87 | # need to send equal number of samples to each process - take the max 88 | self.num_samples = max(len(indices) for indices in rank_indices) 89 | self.total_size = self.num_samples * self.num_replicas 90 | self.indices = rank_indices[self.rank] 91 | 92 | def __iter__(self): 93 | if self.shuffle: 94 | # deterministically shuffle based on epoch and seed 95 | g = torch.Generator() 96 | g.manual_seed(self.seed + self.epoch) 97 | ordering = torch.randperm(len(self.indices), generator=g).tolist() 98 | indices = [self.indices[i] for i in ordering] 99 | else: 100 | indices = self.indices 101 | 102 | # add extra samples to match num_samples 103 | repeat_times = self.num_samples // len(indices) 104 | indices = indices * repeat_times 105 | indices = indices + indices[: self.num_samples - len(indices)] 106 | assert len(indices) == self.num_samples 107 | 108 | return iter(indices) 109 | 110 | def __len__(self): 111 | return self.num_samples 112 | 113 | def set_epoch(self, epoch): 114 | self.epoch = epoch 115 | -------------------------------------------------------------------------------- /notebooks/BART_setup.ipynb: -------------------------------------------------------------------------------- 1 | 
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyNV++4NsZ3v3DQJXeWImJLX"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["This notebook shows how to set up a Python environment for the BART library on Google Colab.\n","\n","In this project, BART is used to estimate the coil sensitivity maps via ESPIRiT calibration, in order to generate ground truth coil-combined MRI images from the given dataset of k-space volumes.\n","\n","Running the following code before launching the train_test programs should be enough for things to work smoothly. If errors arise, please refer to the official BART documentation at:\n","- [BART Setup on Colab](https://github.com/mrirecon/bart-workshop/blob/master/ismrm2021/neural_networks/bart_neural_networks.ipynb)\n","- [BART Setup on local machine](https://mrirecon.github.io/bart/)"],"metadata":{"id":"li4Fk7lyuX47"}},{"cell_type":"markdown","metadata":{"id":"17qwMnVqFl9K"},"source":["# Setup BART for Colab"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1mY3soK3C8al","executionInfo":{"status":"ok","timestamp":1667862079062,"user_tz":-60,"elapsed":16,"user":{"displayName":"Francesco Bono","userId":"11189470803682651835"}},"outputId":"8d2900dd-b6b9-48e8-8abb-c62a78e920f8"},"source":["%%bash\n","\n","# Use CUDA 10.1 when on Tesla K80\n","\n","# Estimate GPU Type\n","GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader)\n","\n","echo \"GPU Type:\"\n","echo $GPU_NAME\n","\n","if [ \"Tesla K80\" = \"$GPU_NAME\" ];\n","then\n"," echo \"GPU type Tesla K80 does not support CUDA 11. Set CUDA to version 10.1.\"\n","\n"," # Change default CUDA to version 10.1\n"," cd /usr/local\n"," rm cuda\n"," ln -s cuda-10.1 cuda\n","else\n"," echo \"Current GPU supports default CUDA-11.\"\n"," echo \"No further actions are necessary.\"\n","fi\n","\n","echo \"GPU Information:\"\n","nvidia-smi --query-gpu=gpu_name,driver_version,memory.total --format=csv\n","nvcc --version"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["GPU Type:\n","Tesla T4\n","Current GPU supports default CUDA-11.\n","No further actions are necessary.\n","GPU Information:\n","name, driver_version, memory.total [MiB]\n","Tesla T4, 460.32.03, 15109 MiB\n","nvcc: NVIDIA (R) Cuda compiler driver\n","Copyright (c) 2005-2021 NVIDIA Corporation\n","Built on Sun_Feb_14_21:12:58_PST_2021\n","Cuda compilation tools, release 11.2, V11.2.152\n","Build cuda_11.2.r11.2/compiler.29618528_0\n"]}]},{"cell_type":"code","metadata":{"id":"MJrU8MCeDKl3"},"source":["%%bash\n","\n","# Install BARTs dependencies\n","apt-get install -y make gcc libfftw3-dev liblapacke-dev libpng-dev libopenblas-dev &> /dev/null\n","\n","# Clone Bart\n","[ -d /content/bart ] && rm -r /content/bart\n","git clone https://github.com/mrirecon/bart/ bart &> /dev/null"],"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%%bash\n","\n","cd bart\n","\n","# For long term support, we checkout the following tag:\n","git checkout tags/ISMRM21_NN\n","\n","# Define specifications \n","COMPILE_SPECS=\" PARALLEL=1\n"," CUDA=1\n"," CUDA_BASE=/usr/local/cuda\n"," CUDA_LIB=lib64\n"," OPENBLAS=1\n"," BLAS_THREADSAFE=1\"\n","\n","printf \"%s\\n\" $COMPILE_SPECS > Makefiles/Makefile.local\n","\n","make &> 
/dev/null"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MD7G5YoENj5w","executionInfo":{"status":"ok","timestamp":1667862149978,"user_tz":-60,"elapsed":56361,"user":{"displayName":"Francesco Bono","userId":"11189470803682651835"}},"outputId":"5f756861-11c9-4c8b-9223-d83659234521"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["Note: checking out 'tags/ISMRM21_NN'.\n","\n","You are in 'detached HEAD' state. You can look around, make experimental\n","changes and commit them, and you can discard any commits you make in this\n","state without impacting any branches by performing another checkout.\n","\n","If you want to create a new branch to retain commits you create, you may\n","do so (now or later) by using -b with the checkout command again. Example:\n","\n"," git checkout -b \n","\n","HEAD is now at 5287d1c5 Add tool to transform onehotencoded data to integer encoded\n"]}]},{"cell_type":"code","source":["%env TOOLBOX_PATH=/content/bart"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"yv6fkHGdMngD","executionInfo":{"status":"ok","timestamp":1667862149979,"user_tz":-60,"elapsed":14,"user":{"displayName":"Francesco Bono","userId":"11189470803682651835"}},"outputId":"ecb58a80-84f7-4281-f1ce-c04b9f451131"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["env: TOOLBOX_PATH=/content/bart\n"]}]},{"cell_type":"code","source":["import os\n","import sys\n","\n","os.environ['PATH'] = os.environ['TOOLBOX_PATH'] + \":\" + os.environ['PATH']\n","sys.path.append(os.environ['TOOLBOX_PATH'] + \"/python/\")"],"metadata":{"id":"tqNbYP6vOJ3t"},"execution_count":null,"outputs":[]}]} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep learning-based models for multi-coil cine cardiac MRI reconstruction 2 | ## Description 3 | 4 | This project presents novel methods to tackle the problem of dynamic MRI 5 | reconstruction (2D space + 1D time) of accelerated multi-coil cardiac data. 6 | 7 | The proposed methods are inspired by the recent deep-learning models 8 | - End-to-End Variational Net [[1]](#VarNet) 9 | - XPDNet [[2]](#XPDNet) 10 | 11 | that achieved state-of-the-art results in the [2020 FastMRI challenge](https://fastmri.org/leaderboards/challenge/) 12 | for static MRI reconstruction of the brain and the knee, and by 13 | - CineNet [[3]](#CineNet) 14 | 15 | for dynamic reconstruction of 2+1D cardiac data. 16 | 17 | The underlying architecture for all proposed models is based on the cross-domain 18 | network with unrolled optimisation. A high-level diagram of the architecture is 19 | shown below: 20 | 21 | 22 | 23 | The dynamic aspect of the MRI data is taken into account by adjusting the models 24 | to process 2+1D data volumes and exploit the inherent temporal redundancies of 25 | cine cardiac data. 
To this aim, each model was structured including five dynamic 26 | variants: 27 | - **2D** - static scenario, each temporal slice is processed independently 28 | - **3D** - each 2+1D volume is processed as a whole; model is isotropic towards 29 | each dimension 30 | - **XT** - two copies of each volume are rotated in the two 2D spatio-temporal 31 | domains (*x-t* plane and *y-t* plane), respectively, then processed by separate 32 | image correction networks and finally recombined 33 | - **XF** - same as for 'XT', but volume firstly undergoes an FFT along the temporal 34 | dimension to sparsify the domain 35 | - **CRNN** - leverage the time-dependency directly by implementing a recurrence 36 | along the temporal direction and also across the unrolled iterations. This variant 37 | is based on the Convolutional Recurrent Neural Network introduced by [[4]](#CRNN) 38 | 39 | ## Getting Started 40 | - Clone the repository: 41 | ``` 42 | git clone https://github.com/f78bono/deep-cine-cardiac-mri.git 43 | ``` 44 | - Create a new virtual environment and install the requirements in `./requirements.txt` 45 | - Set up a BART environment following the instruction in this [notebook](./notebooks/BART_setup.ipynb) 46 | - Change the paths in the following files according to your root directory: 47 | ``` 48 | # traintest_scripts/dirs_path.yaml 49 | data_path: /path/to/data 50 | log_path: /root/traintest_scripts 51 | save_path: /root/results 52 | 53 | # reconstruction/pl_modules/mri_module 54 | path_config = pathlib.Path("/root/traintest_scripts/dirs_path.yaml") 55 | 56 | # traintest_scripts/model/train_test_model.py 57 | path_config = pathlib.Path("/root/traintest_scripts/dirs_path.yaml") 58 | ``` 59 | 60 | ## Data Format 61 | The data set used in this project consists of breath-hold, retrospectively 62 | cardiac-gated, Cartesian, bSSFP multi-coil cine cardiac raw data. Each MRI volume 63 | is stored in a HDF5 file that contains the raw k-space data, with dimensions 64 | (number of slices, height, width, number of coils). 65 | 66 | The path to the data directory expects subdirectories `train/valid/test`. If a 67 | subdirectory `inference` is also included, the program can be given the option 68 | to store the model outputs for visualisation purposes. The code for dealing with 69 | the raw HDF5 files can be found in `./reconstruction/data/mri_data.py`. Adjust the 70 | following settings according to the specifics of the data set used: 71 | ``` 72 | scaling = 1e6 73 | crop_shape = (200, 200) 74 | crop_target = (180, 180) 75 | n_slices = 15 76 | filter_size = [0.7, 0., 0.3, 0.3] 77 | ``` 78 | 79 | ## Basic usage 80 | The scripts for training and testing each model can be found in the `traintest_scripts` 81 | subfolders. Running a script with no arguments, for example 82 | ``` 83 | python3 train_test_varnet.py 84 | ``` 85 | will start the training of the model (in this case XF-VarNet) with default parameters. 86 | Pass optional arguments to overwrite the default parameters, for example 87 | ``` 88 | python3 train_test_varnet.py \ 89 | --epochs 50 \ 90 | --save_checkpoint 1 \ 91 | --num_cascades 6 \ 92 | --dynamic_type CRNN 93 | ``` 94 | For testing use the command 95 | ``` 96 | python3 train_test_varnet.py --mode test --load_model 1 97 | ``` 98 | plus the arguments needed for the model under consideration. 99 | 100 | A detailed description of each argument can be found in the source code. 
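As a quick sanity check of the data layout described in the **Data Format** section above, the sketch below (file name hypothetical) opens one raw HDF5 volume and prints the shape of each stored dataset; the exact dataset keys depend on your data and are handled in `./reconstruction/data/mri_data.py`.
```
import h5py

# Hypothetical example file; adjust to your data_path and split subdirectory.
fname = "/path/to/data/train/volume_001.h5"

with h5py.File(fname, "r") as hf:
    for name, dset in hf.items():
        # Raw k-space is expected as (n_slices, height, width, n_coils).
        print(name, dset.shape, dset.dtype)
```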
101 | 102 | ## Evaluation 103 | A qualitative assessment of the models was perfomed by direct observation of the reconstructed 104 | output along with the absolute error between normalised target and output. Check this 105 | [notebook](./notebooks/reconstruction_visualisation.ipynb) for a visual comparison of 106 | the different model outputs. 107 | 108 | The quantitative metrics used for evaluation are: 109 | - Structural similarity index measure (**SSIM**) 110 | - Normalised mean square error (**NMSE**) 111 | - Peak signal-to-noise ratio (**PSNR**) 112 | 113 | These statistics are compiled automatically during testing and stored in the `results` 114 | folder. 115 | 116 | ## Citations 117 | ### VarNet 118 | [code](https://github.com/facebookresearch/fastMRI) [publication](https://link.springer.com/chapter/10.1007/978-3-030-59713-9_7) 119 | 120 | [1] Sriram, Anuroop et al. (2020). **"End-to-End Variational Networks for Accelerated MRI Reconstruction"**. In: *Medical Image Computing and Computer Assisted Intervention - MICCAI 2020*. Cham: Springer International Publishing, pp. 64 - 73. 121 | 122 | ### XPDNet 123 | [code](https://github.com/zaccharieramzi/fastmri-reproducible-benchmark) [arXiv](https://arxiv.org/abs/2010.07290) 124 | 125 | [2] Ramzi, Zaccharie et al. (2021). **"XPDNet for MRI Reconstruction: an application to the 2020 fastMRI challenge"**. In: *arXiv: 2010.07290*. 126 | 127 | ### CineNet 128 | [code](https://github.com/koflera/DynamicRadCineMRI) [publication](https://aapm.onlinelibrary.wiley.com/doi/10.1002/mp.14809) 129 | 130 | [3] Kofler, Andreas et al. (2021). **"An end-to-end-trainable iterative network architecture for accelerated radial multi-coil 2D cine MR image reconstruction"**. In: *Medical Physics* 48.5, pp. 2412 - 2425. 131 | 132 | ### CRNN 133 | [publication](https://ieeexplore.ieee.org/document/8425639/) 134 | 135 | [4] Qin, Chen et al. (2019). **"Convolutional Recurrent Neural Networks for Dynamic MR Image Reconstruction"**. In: *IEEE Transactions on Medical Imaging* 38.1, pp. 280 - 290. 136 | -------------------------------------------------------------------------------- /reconstruction/utils/fftc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import torch 3 | 4 | 5 | def fft1c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 6 | """ 7 | Apply centered 1 dimensional Fast Fourier Transform. 8 | 9 | Args: 10 | data: Complex valued input data containing at least 2 dimensions: 11 | dimension -2 is spatial dimension and dimension -1 has size 12 | 2. All other dimensions are assumed to be batch dimensions. 13 | norm: Normalization mode. See ``torch.fft.fft``. 14 | 15 | Returns: 16 | The FFT of the input. 17 | """ 18 | if not data.shape[-1] == 2: 19 | raise ValueError("Tensor does not have separate complex dim.") 20 | 21 | data = ifftshift(data, dim=[-2]) 22 | data = torch.view_as_real( 23 | torch.fft.fft( 24 | torch.view_as_complex(data), dim=-1, norm=norm 25 | ) 26 | ) 27 | data = fftshift(data, dim=[-2]) 28 | 29 | return data 30 | 31 | 32 | def ifft1c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 33 | """ 34 | Apply centered 1-dimensional Inverse Fast Fourier Transform. 35 | 36 | Args: 37 | data: Complex valued input data containing at least 2 dimensions: 38 | dimension -2 is spatial dimension and dimension -1 has size 39 | 2. All other dimensions are assumed to be batch dimensions. 40 | norm: Normalization mode. See ``torch.fft.ifft``. 
41 | 42 | Returns: 43 | The IFFT of the input. 44 | """ 45 | if not data.shape[-1] == 2: 46 | raise ValueError("Tensor does not have separate complex dim.") 47 | 48 | data = ifftshift(data, dim=[-2]) 49 | data = torch.view_as_real( 50 | torch.fft.ifft( 51 | torch.view_as_complex(data), dim=-1, norm=norm 52 | ) 53 | ) 54 | data = fftshift(data, dim=[-2]) 55 | 56 | return data 57 | 58 | 59 | def fft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 60 | """ 61 | Apply centered 2 dimensional Fast Fourier Transform. 62 | 63 | Args: 64 | data: Complex valued input data containing at least 3 dimensions: 65 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 66 | 2. All other dimensions are assumed to be batch dimensions. 67 | norm: Normalization mode. See ``torch.fft.fft``. 68 | 69 | Returns: 70 | The FFT of the input. 71 | """ 72 | if not data.shape[-1] == 2: 73 | raise ValueError("Tensor does not have separate complex dim.") 74 | 75 | data = ifftshift(data, dim=[-3, -2]) 76 | data = torch.view_as_real( 77 | torch.fft.fftn( 78 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 79 | ) 80 | ) 81 | data = fftshift(data, dim=[-3, -2]) 82 | 83 | return data 84 | 85 | 86 | def ifft2c(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 87 | """ 88 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 89 | 90 | Args: 91 | data: Complex valued input data containing at least 3 dimensions: 92 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 93 | 2. All other dimensions are assumed to be batch dimensions. 94 | norm: Normalization mode. See ``torch.fft.ifft``. 95 | 96 | Returns: 97 | The IFFT of the input. 98 | """ 99 | if not data.shape[-1] == 2: 100 | raise ValueError("Tensor does not have separate complex dim.") 101 | 102 | data = ifftshift(data, dim=[-3, -2]) 103 | data = torch.view_as_real( 104 | torch.fft.ifftn( 105 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 106 | ) 107 | ) 108 | data = fftshift(data, dim=[-3, -2]) 109 | 110 | return data 111 | 112 | 113 | 114 | 115 | 116 | # Helper functions 117 | 118 | 119 | def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor: 120 | """ 121 | Similar to roll but for only one dim. 122 | 123 | Args: 124 | x: A PyTorch tensor. 125 | shift: Amount to roll. 126 | dim: Which dimension to roll. 127 | 128 | Returns: 129 | Rolled version of x. 130 | """ 131 | shift = shift % x.size(dim) 132 | if shift == 0: 133 | return x 134 | 135 | left = x.narrow(dim, 0, x.size(dim) - shift) 136 | right = x.narrow(dim, x.size(dim) - shift, shift) 137 | 138 | return torch.cat((right, left), dim=dim) 139 | 140 | 141 | def roll( 142 | x: torch.Tensor, 143 | shift: List[int], 144 | dim: List[int], 145 | ) -> torch.Tensor: 146 | """ 147 | Similar to np.roll but applies to PyTorch Tensors. 148 | 149 | Args: 150 | x: A PyTorch tensor. 151 | shift: Amount to roll. 152 | dim: Which dimension to roll. 153 | 154 | Returns: 155 | Rolled version of x. 156 | """ 157 | if len(shift) != len(dim): 158 | raise ValueError("len(shift) must match len(dim)") 159 | 160 | for (s, d) in zip(shift, dim): 161 | x = roll_one_dim(x, s, d) 162 | 163 | return x 164 | 165 | 166 | def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 167 | """ 168 | Similar to np.fft.fftshift but applies to PyTorch Tensors 169 | 170 | Args: 171 | x: A PyTorch tensor. 172 | dim: Which dimension to fftshift. 173 | 174 | Returns: 175 | fftshifted version of x. 
176 | """ 177 | if dim is None: 178 | # this weird code is necessary for toch.jit.script typing 179 | dim = [0] * (x.dim()) 180 | for i in range(1, x.dim()): 181 | dim[i] = i 182 | 183 | # also necessary for torch.jit.script 184 | shift = [0] * len(dim) 185 | for i, dim_num in enumerate(dim): 186 | shift[i] = x.shape[dim_num] // 2 187 | 188 | return roll(x, shift, dim) 189 | 190 | 191 | def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 192 | """ 193 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 194 | 195 | Args: 196 | x: A PyTorch tensor. 197 | dim: Which dimension to ifftshift. 198 | 199 | Returns: 200 | ifftshifted version of x. 201 | """ 202 | if dim is None: 203 | # this weird code is necessary for toch.jit.script typing 204 | dim = [0] * (x.dim()) 205 | for i in range(1, x.dim()): 206 | dim[i] = i 207 | 208 | # also necessary for torch.jit.script 209 | shift = [0] * len(dim) 210 | for i, dim_num in enumerate(dim): 211 | shift[i] = (x.shape[dim_num] + 1) // 2 212 | 213 | return roll(x, shift, dim) 214 | -------------------------------------------------------------------------------- /reconstruction/models/denoisers/norm_unet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .unet import Unet 9 | 10 | 11 | 12 | class NormUnet(nn.Module): 13 | """ 14 | Normalized U-Net model. 15 | 16 | This is the same as a regular U-Net, but with normalization applied to the 17 | input before the U-Net. This keeps the values more numerically stable 18 | during training. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | chans: int, 24 | num_pools: int, 25 | in_chans: int = 2, 26 | out_chans: int = 2, 27 | drop_prob: float = 0.0, 28 | ): 29 | """ 30 | Args: 31 | chans: Number of output channels of the first convolution layer. 32 | num_pools: Number of down-sampling and up-sampling layers. 33 | in_chans: Number of channels in the input to the U-Net model. 34 | out_chans: Number of channels in the output to the U-Net model. 35 | drop_prob: Dropout probability. 
36 | """ 37 | super().__init__() 38 | 39 | self.unet = Unet( 40 | in_chans=in_chans, 41 | out_chans=out_chans, 42 | chans=chans, 43 | num_pool_layers=num_pools, 44 | drop_prob=drop_prob, 45 | dims=2, 46 | ) 47 | 48 | def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: 49 | b, c, h, w, two = x.shape 50 | assert two == 2 51 | return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w) 52 | 53 | def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: 54 | b, c2, h, w = x.shape 55 | assert c2 % 2 == 0 56 | c = c2 // 2 57 | return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() 58 | 59 | def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 60 | # group norm 61 | b, c, h, w = x.shape 62 | x = x.reshape(b, 2, c // 2 * h * w) 63 | 64 | mean = x.mean(dim=2).view(b, c, 1, 1) 65 | std = x.std(dim=2).view(b, c, 1, 1) 66 | 67 | x = x.view(b, c, h, w) 68 | 69 | return (x - mean) / std, mean, std 70 | 71 | def unnorm( 72 | self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor 73 | ) -> torch.Tensor: 74 | return x * std + mean 75 | 76 | def pad( 77 | self, x: torch.Tensor 78 | ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: 79 | _, _, h, w = x.shape 80 | w_mult = ((w - 1) | 15) + 1 81 | h_mult = ((h - 1) | 15) + 1 82 | w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] 83 | h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] 84 | x = F.pad(x, w_pad + h_pad) 85 | 86 | return x, (h_pad, w_pad, h_mult, w_mult) 87 | 88 | def unpad( 89 | self, 90 | x: torch.Tensor, 91 | h_pad: List[int], 92 | w_pad: List[int], 93 | h_mult: int, 94 | w_mult: int, 95 | ) -> torch.Tensor: 96 | return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] 97 | 98 | def forward(self, x: torch.Tensor) -> torch.Tensor: 99 | if not x.shape[-1] == 2: 100 | raise ValueError("Last dimension must be 2 for complex.") 101 | 102 | # get shapes for unet and normalize 103 | x = self.complex_to_chan_dim(x) 104 | x, mean, std = self.norm(x) 105 | x, pad_sizes = self.pad(x) 106 | 107 | x = self.unet(x) 108 | 109 | # get shapes back and unnormalize 110 | x = self.unpad(x, *pad_sizes) 111 | x = self.unnorm(x, mean, std) 112 | x = self.chan_complex_to_last_dim(x) 113 | 114 | return x 115 | 116 | 117 | class NormUnet3D(nn.Module): 118 | """ 119 | Normalized U-Net model for dynamic data (2D space + time). 120 | """ 121 | 122 | def __init__( 123 | self, 124 | chans: int, 125 | num_pools: int, 126 | in_chans: int = 2, 127 | out_chans: int = 2, 128 | drop_prob: float = 0.0, 129 | ): 130 | """ 131 | Args: 132 | chans: Number of output channels of the first convolution layer. 133 | num_pools: Number of down-sampling and up-sampling layers. 134 | in_chans: Number of channels in the input to the U-Net model. 135 | out_chans: Number of channels in the output to the U-Net model. 136 | drop_prob: Dropout probability. 
137 | """ 138 | super().__init__() 139 | 140 | self.unet = Unet( 141 | in_chans=in_chans, 142 | out_chans=out_chans, 143 | chans=chans, 144 | num_pool_layers=num_pools, 145 | drop_prob=drop_prob, 146 | dims=3, 147 | ) 148 | 149 | def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: 150 | b, c, t, h, w, two = x.shape 151 | assert two == 2 152 | return x.permute(0, 5, 1, 2, 3, 4).reshape(b, 2 * c, t, h, w) 153 | 154 | def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: 155 | b, c2, t, h, w = x.shape 156 | assert c2 % 2 == 0 157 | c = c2 // 2 158 | return x.view(b, 2, c, t, h, w).permute(0, 2, 3, 4, 5, 1).contiguous() 159 | 160 | def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 161 | # group norm 162 | b, c, t, h, w = x.shape 163 | x = x.reshape(b, 2, c // 2 * t * h * w) 164 | 165 | mean = x.mean(dim=2).view(b, c, 1, 1, 1) 166 | std = x.std(dim=2).view(b, c, 1, 1, 1) 167 | 168 | x = x.view(b, c, t, h, w) 169 | 170 | return (x - mean) / std, mean, std 171 | 172 | def unnorm( 173 | self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor 174 | ) -> torch.Tensor: 175 | return x * std + mean 176 | 177 | def pad( 178 | self, x: torch.Tensor 179 | ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: 180 | _, _, t, h, w = x.shape 181 | w_mult = ((w - 1) | 15) + 1 182 | h_mult = ((h - 1) | 15) + 1 183 | t_mult = ((t - 1) | 15) + 1 184 | w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] 185 | h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] 186 | t_pad = [math.floor((t_mult - t) / 2), math.ceil((t_mult - t) / 2)] 187 | 188 | x = F.pad(x, w_pad + h_pad + t_pad) 189 | return x, (t_pad, h_pad, w_pad, t_mult, h_mult, w_mult) 190 | 191 | def unpad( 192 | self, 193 | x: torch.Tensor, 194 | t_pad: List[int], 195 | h_pad: List[int], 196 | w_pad: List[int], 197 | t_mult: int, 198 | h_mult: int, 199 | w_mult: int, 200 | ) -> torch.Tensor: 201 | return x[..., t_pad[0] : t_mult - t_pad[1], h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] 202 | 203 | def forward(self, x: torch.Tensor) -> torch.Tensor: 204 | if not x.shape[-1] == 2: 205 | raise ValueError("Last dimension must be 2 for complex.") 206 | 207 | # get shapes for unet and normalize 208 | x = self.complex_to_chan_dim(x) 209 | x, mean, std = self.norm(x) 210 | x, pad_sizes = self.pad(x) 211 | 212 | x = self.unet(x) 213 | 214 | # get shapes back and unnormalize 215 | x = self.unpad(x, *pad_sizes) 216 | x = self.unnorm(x, mean, std) 217 | x = self.chan_complex_to_last_dim(x) 218 | 219 | return x -------------------------------------------------------------------------------- /reconstruction/pl_modules/cinenet_module.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | 4 | from reconstruction.data import transforms 5 | from reconstruction.utils import SSIMLoss 6 | from reconstruction.models import CineNet, CineNet_RNN 7 | from .mri_module import MriModule 8 | 9 | 10 | class CineNetModule(MriModule): 11 | """ 12 | Pytorch Lightning module for training CineNet. 13 | 14 | The architecture variations for dynamic MRI reconstruction are 15 | inspired by the deep learning network introduced in the following paper: 16 | 17 | A. Kofler et al. `An end-to-end-trainable iterative network architecture 18 | for accelerated radial multi-coil 2D cine MR image reconstruction.` 19 | In Medical Physics, 2021. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_cascades: int = 12, 25 | CG_iters: int = 4, 26 | chans: int = 18, 27 | pools: int = 4, 28 | dynamic_type: str = 'XF', 29 | weight_sharing: bool = False, 30 | lr: float = 0.0003, 31 | lr_step_size: int = 40, 32 | lr_gamma: float = 0.1, 33 | weight_decay: float = 0.0, 34 | **kwargs, 35 | ): 36 | """ 37 | Args: 38 | num_cascades: Number of alternations between CG and U-Net modules. 39 | CG_iters: Number of CG iterations in the CG module. 40 | chans: Number of channels for cascade U-Net. 41 | pools: Number of downsampling and upsampling layers for cascade U-Net. 42 | dynamic_type: Type of architecture adjustment for dynamic setting. 43 | weight_sharing: Optional setting in 'XF' or 'XT' dynamics mode, allowing 44 | U-Net to share the same parameters in both x-f and y-f planes. 45 | lr: Learning rate. 46 | lr_step_size: Learning rate step size. 47 | lr_gamma: Learning rate gamma decay. 48 | weight_decay: Parameter for penalizing weights norm. 49 | """ 50 | super().__init__(**kwargs) 51 | self.save_hyperparameters() 52 | 53 | self.num_cascades = num_cascades 54 | self.CG_iters = CG_iters 55 | self.pools = pools 56 | self.chans = chans 57 | self.dynamic_type = dynamic_type 58 | self.weight_sharing = weight_sharing 59 | self.lr = lr 60 | self.lr_step_size = lr_step_size 61 | self.lr_gamma = lr_gamma 62 | self.weight_decay = weight_decay 63 | 64 | assert self.dynamic_type in ['XF', 'XT', '2D', '3D', 'CRNN'], \ 65 | "dynamic_type argument must be one of 'XF', 'XT', '2D', '3D' or 'CRNN'" 66 | 67 | if self.dynamic_type == 'CRNN': 68 | self.cinenet = CineNet_RNN( 69 | num_cascades=self.num_cascades, 70 | CG_iters=self.CG_iters, 71 | chans=self.chans, 72 | ) 73 | else: 74 | self.cinenet = CineNet( 75 | num_cascades=self.num_cascades, 76 | CG_iters=self.CG_iters, 77 | chans=self.chans, 78 | pools=self.pools, 79 | dynamic_type=self.dynamic_type, 80 | weight_sharing = self.weight_sharing, 81 | ) 82 | 83 | self.loss = SSIMLoss() 84 | 85 | def forward(self, masked_kspace, mask, coils_maps): 86 | return self.cinenet(masked_kspace, mask, coils_maps) 87 | 88 | def training_step(self, batch, batch_idx): 89 | masked_kspace, mask, coils_maps, target, fname, slice_num, max_value, _ = batch 90 | 91 | output = self(masked_kspace, mask, coils_maps) 92 | target, output = transforms.center_crop_to_smallest(target, output) 93 | 94 | return { 95 | "batch_idx": batch_idx, 96 | "fname": fname, 97 | "slice_num": slice_num, 98 | "max_value": max_value, 99 | "output": output, 100 | "target": target, 101 | "loss": self.loss( 102 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 103 | ), 104 | } 105 | 106 | def validation_step(self, batch, batch_idx): 107 | masked_kspace, mask, coils_maps, target, fname, slice_num, max_value, _ = batch 108 | 109 | output = self.forward(masked_kspace, mask, coils_maps) 110 | target, output = transforms.center_crop_to_smallest(target, output) 111 | 112 | return { 113 | "batch_idx": batch_idx, 114 | "fname": fname, 115 | "slice_num": slice_num, 116 | "max_value": max_value, 117 | "output": output, 118 | "target": target, 119 | "val_loss": self.loss( 120 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 121 | ), 122 | } 123 | 124 | def test_step(self, batch, batch_idx): 125 | masked_kspace, mask, coils_maps, target, fname, slice_num, max_value, _ = batch 126 | 127 | output = self(masked_kspace, mask, coils_maps) 128 | target, output = transforms.center_crop_to_smallest(target, output) 129 | 130 | return { 131 | 
"batch_idx": batch_idx, 132 | "fname": fname, 133 | "slice_num": slice_num, 134 | "max_value": max_value, 135 | "output": output, 136 | "target": target, 137 | "test_loss": self.loss( 138 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 139 | ), 140 | } 141 | 142 | def configure_optimizers(self): 143 | optim = torch.optim.Adam( 144 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay 145 | ) 146 | scheduler = torch.optim.lr_scheduler.StepLR( 147 | optim, self.lr_step_size, self.lr_gamma 148 | ) 149 | 150 | return [optim], [scheduler] 151 | 152 | @staticmethod 153 | def add_model_specific_args(parent_parser): # pragma: no-cover 154 | """ 155 | Define parameters that only apply to this model 156 | """ 157 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 158 | parser = MriModule.add_model_specific_args(parser) 159 | 160 | # param overwrites 161 | 162 | # network params 163 | parser.add_argument( 164 | "--num_cascades", 165 | default=12, 166 | type=int, 167 | help="Number of alternations between CG and U-Net modules", 168 | ) 169 | parser.add_argument( 170 | "--CG_iters", 171 | default=4, 172 | type=int, 173 | help="Number of Conjugate Gradient iterations", 174 | ) 175 | parser.add_argument( 176 | "--pools", 177 | default=4, 178 | type=int, 179 | help="Number of U-Net pooling layers in CineNet blocks", 180 | ) 181 | parser.add_argument( 182 | "--chans", 183 | default=18, 184 | type=int, 185 | help="Number of channels for U-Net in CineNet blocks", 186 | ) 187 | parser.add_argument( 188 | "--dynamic_type", 189 | default='XF', 190 | type=str, 191 | help="""Architectural variation for dynamic reconstruction. 192 | Options are ['XF', 'XT', '2D', '3D', 'CRNN']""", 193 | ) 194 | parser.add_argument( 195 | "--weight_sharing", 196 | default=False, 197 | type=bool, 198 | help="Allows parameter sharing of U-Nets in x-f, y-f planes", 199 | ) 200 | 201 | # training params (opt) 202 | parser.add_argument( 203 | "--lr", default=0.0003, type=float, help="Adam learning rate" 204 | ) 205 | parser.add_argument( 206 | "--lr_step_size", 207 | default=40, 208 | type=int, 209 | help="Epoch at which to decrease step size", 210 | ) 211 | parser.add_argument( 212 | "--lr_gamma", 213 | default=0.1, 214 | type=float, 215 | help="Extent to which step size should be decreased", 216 | ) 217 | parser.add_argument( 218 | "--weight_decay", 219 | default=0.0, 220 | type=float, 221 | help="Strength of weight decay regularization", 222 | ) 223 | 224 | return parser 225 | -------------------------------------------------------------------------------- /reconstruction/models/denoisers/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Unet(nn.Module): 7 | """ 8 | PyTorch implementation of a U-Net model. 9 | 10 | O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks 11 | for biomedical image segmentation. In International Conference on Medical 12 | image computing and computer-assisted intervention, pages 234-241. 13 | Springer, 2015. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | chans: int = 32, 19 | num_pool_layers: int = 4, 20 | in_chans: int = 2, 21 | out_chans: int = 2, 22 | drop_prob: float = 0.0, 23 | dims: int = 2, 24 | ): 25 | """ 26 | Args: 27 | chans: Number of output channels of the first convolution layer. 28 | num_pool_layers: Number of down-sampling and up-sampling layers. 
29 | in_chans: Number of channels in the input to the U-Net model. 30 | out_chans: Number of channels in the output of the U-Net model. 31 | drop_prob: Dropout probability. 32 | dims: number of dimensions for convolutional operations (2 or 3). 33 | """ 34 | super().__init__() 35 | 36 | self.chans = chans 37 | self.num_pool_layers = num_pool_layers 38 | self.in_chans = in_chans 39 | self.out_chans = out_chans 40 | self.drop_prob = drop_prob 41 | self.dims = dims 42 | 43 | assert dims in [2, 3], \ 44 | "Dimensions must be either 2 or 3" 45 | 46 | if dims == 2: 47 | conv_op = nn.Conv2d 48 | if dims == 3: 49 | conv_op = nn.Conv3d 50 | 51 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob, dims)]) 52 | ch = chans 53 | for _ in range(num_pool_layers - 1): 54 | self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob, dims)) 55 | ch *= 2 56 | self.conv = ConvBlock(ch, ch * 2, drop_prob, dims) 57 | 58 | self.up_conv = nn.ModuleList() 59 | self.up_transpose_conv = nn.ModuleList() 60 | for _ in range(num_pool_layers - 1): 61 | self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch, dims)) 62 | self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob, dims)) 63 | ch //= 2 64 | 65 | self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch, dims)) 66 | self.up_conv.append( 67 | nn.Sequential( 68 | ConvBlock(ch * 2, ch, drop_prob, dims), 69 | conv_op(ch, self.out_chans, kernel_size=1, stride=1), 70 | ) 71 | ) 72 | 73 | def forward(self, image: torch.Tensor) -> torch.Tensor: 74 | """ 75 | Args: 76 | image: Input tensor of shape 77 | - `(N, in_chans, H, W)` if dims = 2 78 | - `(N, in_chans, T, H, W)` if dims = 3 79 | 80 | Returns: 81 | Output tensor of shape 82 | - `(N, out_chans, H, W)` if dims = 2 83 | - `(N, out_chans, T, H, W)` if dims = 3 84 | """ 85 | if self.dims == 2: 86 | pool_op = F.avg_pool2d 87 | if self.dims == 3: 88 | pool_op = F.avg_pool3d 89 | 90 | stack = [] 91 | output = image 92 | 93 | # apply down-sampling layers 94 | for layer in self.down_sample_layers: 95 | output = layer(output) 96 | stack.append(output) 97 | output = pool_op(output, kernel_size=2, stride=2, padding=0) 98 | 99 | output = self.conv(output) 100 | 101 | # apply up-sampling layers 102 | for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): 103 | downsample_layer = stack.pop() 104 | output = transpose_conv(output) 105 | 106 | # reflect pad if needed to handle odd input dimensions 107 | if self.dims == 2: 108 | padding = [0, 0, 0, 0] 109 | if self.dims == 3: 110 | padding = [0, 0, 0, 0, 0, 0] 111 | 112 | if output.shape[-1] != downsample_layer.shape[-1]: 113 | padding[1] = 1 # padding right 114 | if output.shape[-2] != downsample_layer.shape[-2]: 115 | padding[3] = 1 # padding bottom 116 | if self.dims == 3: 117 | if output.shape[-3] != downsample_layer.shape[-3]: 118 | padding[5] = 1 # padding temporal end 119 | if torch.sum(torch.tensor(padding)) != 0: 120 | output = F.pad(output, padding) 121 | 122 | output = torch.cat([output, downsample_layer], dim=1) 123 | output = conv(output) 124 | 125 | return output 126 | 127 | 128 | class ConvBlock(nn.Module): 129 | """ 130 | A Convolutional Block that consists of two convolution layers each followed by 131 | instance normalization, LeakyReLU activation and dropout. 132 | """ 133 | 134 | def __init__(self, in_chans: int, out_chans: int, drop_prob: float, dims: int): 135 | """ 136 | Args: 137 | in_chans: Number of channels in the input. 138 | out_chans: Number of channels in the output. 139 | drop_prob: Dropout probability. 
140 | dims: number of dimensions for convolutional operations (2 or 3). 141 | """ 142 | super().__init__() 143 | 144 | self.in_chans = in_chans 145 | self.out_chans = out_chans 146 | self.drop_prob = drop_prob 147 | self.dims = dims 148 | 149 | if self.dims == 2: 150 | conv_op = nn.Conv2d 151 | norm_op = nn.InstanceNorm2d 152 | drop_op = nn.Dropout2d 153 | 154 | if self.dims == 3: 155 | conv_op = nn.Conv3d 156 | norm_op = nn.InstanceNorm3d 157 | drop_op = nn.Dropout3d 158 | 159 | self.layers = nn.Sequential( 160 | conv_op(in_chans, out_chans, kernel_size=3, padding=1, bias=False), 161 | norm_op(out_chans), 162 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 163 | drop_op(drop_prob), 164 | conv_op(out_chans, out_chans, kernel_size=3, padding=1, bias=False), 165 | norm_op(out_chans), 166 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 167 | drop_op(drop_prob), 168 | ) 169 | 170 | def forward(self, image: torch.Tensor) -> torch.Tensor: 171 | """ 172 | Args: 173 | image: Input tensor of shape 174 | - `(N, in_chans, H, W)` if dims = 2 175 | - `(N, in_chans, T, H, W)` if dims = 3 176 | 177 | Returns: 178 | Output tensor of shape 179 | - `(N, out_chans, H, W)` if dims = 2 180 | - `(N, out_chans, T, H, W)` if dims = 3 181 | """ 182 | return self.layers(image) 183 | 184 | 185 | class TransposeConvBlock(nn.Module): 186 | """ 187 | A Transpose Convolutional Block that consists of one convolution transpose 188 | layers followed by instance normalization and LeakyReLU activation. 189 | """ 190 | 191 | def __init__(self, in_chans: int, out_chans: int, dims:int): 192 | """ 193 | Args: 194 | in_chans: Number of channels in the input. 195 | out_chans: Number of channels in the output. 196 | dims: number of dimensions for convolutional operations (2 or 3). 197 | """ 198 | super().__init__() 199 | 200 | self.in_chans = in_chans 201 | self.out_chans = out_chans 202 | self.dims = dims 203 | 204 | if self.dims == 2: 205 | up_conv_op = nn.ConvTranspose2d 206 | norm_op = nn.InstanceNorm2d 207 | 208 | if self.dims == 3: 209 | up_conv_op = nn.ConvTranspose3d 210 | norm_op = nn.InstanceNorm3d 211 | 212 | self.layers = nn.Sequential( 213 | up_conv_op( 214 | in_chans, out_chans, kernel_size=2, stride=2, bias=False 215 | ), 216 | norm_op(out_chans), 217 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 218 | ) 219 | 220 | def forward(self, image: torch.Tensor) -> torch.Tensor: 221 | """ 222 | Args: 223 | image: Input tensor of shape 224 | - `(N, in_chans, H, W)` if dims = 2 225 | - `(N, in_chans, T, H, W)` if dims = 3 226 | 227 | Returns: 228 | Output tensor of shape 229 | - `(N, out_chans, H*2, W*2)` if dims = 2 230 | - `(N, out_chans, T*2, H*2, W*2)` if dims = 3 231 | """ 232 | return self.layers(image) 233 | -------------------------------------------------------------------------------- /reconstruction/pl_modules/varnet_module.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | 4 | from reconstruction.data import transforms 5 | from reconstruction.utils import SSIMLoss 6 | from reconstruction.models import VarNet, VarNet_RNN 7 | from .mri_module import MriModule 8 | 9 | 10 | class VarNetModule(MriModule): 11 | """ 12 | Pytorch Lightning module for training VarNet. 13 | 14 | The architecture variations for dynamic MRI reconstruction are 15 | inspired by the End-to-End Variational Network for static MRI 16 | reconstruction, introduced in the following paper: 17 | 18 | A. Sriram et al. 
"End-to-end variational networks for accelerated MRI 19 | reconstruction". In International Conference on Medical Image Computing and 20 | Computer-Assisted Intervention, 2020. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | num_cascades: int = 12, 26 | pools: int = 4, 27 | chans: int = 18, 28 | sens_pools: int = 4, 29 | sens_chans: int = 8, 30 | dynamic_type: str = 'XF', 31 | weight_sharing: bool = False, 32 | lr: float = 0.0003, 33 | lr_step_size: int = 40, 34 | lr_gamma: float = 0.1, 35 | weight_decay: float = 0.0, 36 | **kwargs, 37 | ): 38 | """ 39 | Args: 40 | num_cascades: Number of cascades (i.e., layers) for variational 41 | network. 42 | pools: Number of downsampling and upsampling layers for cascade 43 | U-Net. 44 | chans: Number of channels for cascade U-Net. 45 | sens_pools: Number of downsampling and upsampling layers for 46 | sensitivity map U-Net. 47 | sens_chans: Number of channels for sensitivity map U-Net. 48 | dynamic_type: Type of architecture adjustment for dynamic setting. 49 | weight_sharing: Optional setting in 'XF' or 'XT' dynamics mode, allowing 50 | U-Net to share the same parameters in both x-f and y-f planes. 51 | lr: Learning rate. 52 | lr_step_size: Learning rate step size. 53 | lr_gamma: Learning rate gamma decay. 54 | weight_decay: Parameter for penalizing weights norm. 55 | """ 56 | super().__init__(**kwargs) 57 | self.save_hyperparameters() 58 | 59 | self.num_cascades = num_cascades 60 | self.pools = pools 61 | self.chans = chans 62 | self.sens_pools = sens_pools 63 | self.sens_chans = sens_chans 64 | self.dynamic_type = dynamic_type 65 | self.weight_sharing = weight_sharing 66 | self.lr = lr 67 | self.lr_step_size = lr_step_size 68 | self.lr_gamma = lr_gamma 69 | self.weight_decay = weight_decay 70 | 71 | assert self.dynamic_type in ['XF', 'XT', '2D', '3D', 'CRNN'], \ 72 | "dynamic_type argument must be one of 'XF', 'XT', '2D', '3D' or 'CRNN'" 73 | 74 | if self.dynamic_type == 'CRNN': 75 | self.varnet = VarNet_RNN( 76 | num_cascades=self.num_cascades, 77 | sens_chans=self.sens_chans, 78 | sens_pools=self.sens_pools, 79 | chans=self.chans, 80 | ) 81 | else: 82 | self.varnet = VarNet( 83 | num_cascades=self.num_cascades, 84 | sens_chans=self.sens_chans, 85 | sens_pools=self.sens_pools, 86 | chans=self.chans, 87 | pools=self.pools, 88 | dynamic_type=self.dynamic_type, 89 | weight_sharing = self.weight_sharing, 90 | ) 91 | 92 | self.loss = SSIMLoss() 93 | 94 | def forward(self, masked_kspace, mask): 95 | return self.varnet(masked_kspace, mask) 96 | 97 | def training_step(self, batch, batch_idx): 98 | masked_kspace, mask, target, fname, slice_num, max_value, _ = batch 99 | 100 | output = self(masked_kspace, mask) 101 | target, output = transforms.center_crop_to_smallest(target, output) 102 | 103 | return { 104 | "batch_idx": batch_idx, 105 | "fname": fname, 106 | "slice_num": slice_num, 107 | "max_value": max_value, 108 | "output": output, 109 | "target": target, 110 | "loss": self.loss( 111 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 112 | ), 113 | } 114 | 115 | def validation_step(self, batch, batch_idx): 116 | masked_kspace, mask, target, fname, slice_num, max_value, _ = batch 117 | 118 | output = self.forward(masked_kspace, mask) 119 | target, output = transforms.center_crop_to_smallest(target, output) 120 | 121 | return { 122 | "batch_idx": batch_idx, 123 | "fname": fname, 124 | "slice_num": slice_num, 125 | "max_value": max_value, 126 | "output": output, 127 | "target": target, 128 | "val_loss": self.loss( 129 | 
output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 130 | ), 131 | } 132 | 133 | def test_step(self, batch, batch_idx): 134 | masked_kspace, mask, target, fname, slice_num, max_value, _ = batch 135 | 136 | output = self(masked_kspace, mask) 137 | target, output = transforms.center_crop_to_smallest(target, output) 138 | 139 | return { 140 | "batch_idx": batch_idx, 141 | "fname": fname, 142 | "slice_num": slice_num, 143 | "max_value": max_value, 144 | "output": output, 145 | "target": target, 146 | "test_loss": self.loss( 147 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 148 | ), 149 | } 150 | 151 | def configure_optimizers(self): 152 | optim = torch.optim.Adam( 153 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay 154 | ) 155 | scheduler = torch.optim.lr_scheduler.StepLR( 156 | optim, self.lr_step_size, self.lr_gamma 157 | ) 158 | 159 | return [optim], [scheduler] 160 | 161 | @staticmethod 162 | def add_model_specific_args(parent_parser): # pragma: no-cover 163 | """ 164 | Define parameters that only apply to this model 165 | """ 166 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 167 | parser = MriModule.add_model_specific_args(parser) 168 | 169 | # param overwrites 170 | 171 | # network params 172 | parser.add_argument( 173 | "--num_cascades", 174 | default=12, 175 | type=int, 176 | help="Number of VarNet cascades", 177 | ) 178 | parser.add_argument( 179 | "--pools", 180 | default=4, 181 | type=int, 182 | help="Number of U-Net pooling layers in VarNet blocks", 183 | ) 184 | parser.add_argument( 185 | "--chans", 186 | default=18, 187 | type=int, 188 | help="Number of channels for U-Net in VarNet blocks", 189 | ) 190 | parser.add_argument( 191 | "--sens_pools", 192 | default=4, 193 | type=int, 194 | help="Number of pooling layers for sense map estimation U-Net in VarNet", 195 | ) 196 | parser.add_argument( 197 | "--sens_chans", 198 | default=8, 199 | type=float, 200 | help="Number of channels for sense map estimation U-Net in VarNet", 201 | ) 202 | parser.add_argument( 203 | "--dynamic_type", 204 | default='XF', 205 | type=str, 206 | help="""Architectural variation for dynamic reconstruction. 207 | Options are ['XF', 'XT', '2D', '3D', 'CRNN']""", 208 | ) 209 | parser.add_argument( 210 | "--weight_sharing", 211 | default=False, 212 | type=bool, 213 | help="Allows parameter sharing of U-Nets in x-f, y-f planes.", 214 | ) 215 | 216 | # training params (opt) 217 | parser.add_argument( 218 | "--lr", default=0.0003, type=float, help="Adam learning rate" 219 | ) 220 | parser.add_argument( 221 | "--lr_step_size", 222 | default=40, 223 | type=int, 224 | help="Epoch at which to decrease step size", 225 | ) 226 | parser.add_argument( 227 | "--lr_gamma", 228 | default=0.1, 229 | type=float, 230 | help="Extent to which step size should be decreased", 231 | ) 232 | parser.add_argument( 233 | "--weight_decay", 234 | default=0.0, 235 | type=float, 236 | help="Strength of weight decay regularization", 237 | ) 238 | 239 | return parser 240 | -------------------------------------------------------------------------------- /reconstruction/data/subsample.py: -------------------------------------------------------------------------------- 1 | """ 2 | This source code is based on the fastMRI repository from Facebook AI 3 | Research and is used as a general framework to handle MRI data. 
Link: 4 | 5 | https://github.com/facebookresearch/fastMRI 6 | """ 7 | 8 | import contextlib 9 | from typing import Optional, Sequence, Tuple, Union 10 | 11 | import numpy as np 12 | import torch 13 | 14 | 15 | @contextlib.contextmanager 16 | def temp_seed(rng: np.random.RandomState, seed: Optional[Union[int, Tuple[int, ...]]]): 17 | if seed is None: 18 | try: 19 | yield 20 | finally: 21 | pass 22 | else: 23 | state = rng.get_state() 24 | rng.seed(seed) 25 | try: 26 | yield 27 | finally: 28 | rng.set_state(state) 29 | 30 | 31 | class MaskFunc: 32 | """ 33 | An object for GRAPPA-style sampling masks. 34 | 35 | This creates a sampling mask that densely samples the center while 36 | subsampling outer k-space regions based on the undersampling factor. 37 | """ 38 | 39 | def __init__(self, center_fractions: Sequence[float], accelerations: Sequence[int]): 40 | """ 41 | Args: 42 | center_fractions: When using a random mask, number of low-frequency 43 | lines to retain. When using an equispaced mask, fraction of 44 | low-frequency lines to retain. 45 | If multiple values are provided, then one of these numbers is 46 | chosen uniformly each time. 47 | accelerations: Amount of under-sampling. This should have the same 48 | length as center_fractions. If multiple values are provided, 49 | then one of these is chosen uniformly each time. 50 | """ 51 | if not len(center_fractions) == len(accelerations): 52 | raise ValueError( 53 | "Number of center fractions should match number of accelerations" 54 | ) 55 | 56 | self.center_fractions = center_fractions 57 | self.accelerations = accelerations 58 | self.rng = np.random.RandomState() 59 | 60 | def __call__( 61 | self, shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None 62 | ) -> torch.Tensor: 63 | raise NotImplementedError 64 | 65 | def choose_acceleration(self): 66 | """Choose acceleration based on class parameters.""" 67 | choice = self.rng.randint(0, len(self.accelerations)) 68 | center_fraction = self.center_fractions[choice] 69 | acceleration = self.accelerations[choice] 70 | 71 | return center_fraction, acceleration 72 | 73 | 74 | 75 | class RandomMaskFunc(MaskFunc): 76 | """ 77 | RandomMaskFunc creates a Cartesian sub-sampling mask of a given shape, 78 | as implemented in 79 | "A Deep Cascade of Convolutional Neural Networks for Dynamic MR Image 80 | Reconstruction" by J. Schlemper et al. 81 | 82 | The mask selects a subset of rows from the input k-space data. If the 83 | k-space data has N rows, the mask picks out: 84 | 1. center_fraction rows in the center corresponding to low-frequencies. 85 | 2. The remaining rows are selected according to a tail-adjusted 86 | Gaussian probability density function. This ensures that the 87 | expected number of rows selected is equal to (N / acceleration). 88 | 89 | It is possible to use multiple center_fractions and accelerations, in which 90 | case one possible (center_fraction, acceleration) is chosen uniformly at 91 | random each time the RandomMaskFunc object is called. 92 | """ 93 | 94 | def __call__( 95 | self, shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None 96 | ) -> torch.Tensor: 97 | """ 98 | Create the mask. 99 | 100 | Args: 101 | shape: The shape of the mask to be created. The shape should have 102 | at least 3 dimensions. Samples are drawn along the third 103 | dimension. 104 | seed: Seed for the random number generator. Setting the seed 105 | ensures the same mask is generated each time for the same 106 | shape. The random state is reset afterwards.
107 | 108 | Returns: 109 | A mask of the specified shape. 110 | """ 111 | if len(shape) < 3: 112 | raise ValueError("Shape should have 3 or more dimensions") 113 | 114 | with temp_seed(self.rng, seed): 115 | sample_n, acc = self.choose_acceleration() 116 | 117 | N, Nc, Nx, Ny, Nch = shape 118 | 119 | # generate normal distribution 120 | normal_pdf = lambda length, sensitivity: np.exp(-sensitivity * (np.arange(length) - length / 2)**2) 121 | pdf_x = normal_pdf(Nx, 0.5/(Nx/10.)**2) 122 | lmda = Nx / (2.*acc) 123 | n_lines = int(Nx / acc) 124 | 125 | # add uniform distribution so that probability of sampling 126 | # high-frequency lines is non-zero 127 | pdf_x += lmda * 1./Nx 128 | 129 | if sample_n: 130 | # lines are never randomly sampled from the already 131 | # sampled center 132 | pdf_x[Nx//2 - sample_n//2 : Nx//2 + sample_n//2] = 0 133 | pdf_x /= np.sum(pdf_x) # normalise distribution 134 | n_lines -= sample_n 135 | 136 | mask = np.zeros((N, Nx)) 137 | for i in range(N): 138 | # select low-frequency lines according to pdf 139 | idx = np.random.choice(Nx, n_lines, False, pdf_x) 140 | mask[i, idx] = 1 141 | 142 | if sample_n: 143 | # central lines are always sampled 144 | mask[:, Nx//2-sample_n//2:Nx//2+sample_n//2] = 1 145 | 146 | # reshape the mask 147 | mask_shape = [1 for _ in shape] 148 | mask_shape[-3] = Nx 149 | mask_shape[0] = N 150 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 151 | return mask 152 | 153 | 154 | class EquispacedMaskFunc(MaskFunc): 155 | """ 156 | EquispacedMaskFunc creates a sub-sampling mask of a given shape. 157 | 158 | The mask selects a subset of columns from the input k-space data. If the 159 | k-space data has N rows, the mask picks out: 160 | 1. N_low_freqs = (N * center_fraction) rows in the center 161 | corresponding to low-frequencies. 162 | 2. The other rows are selected with equal spacing at a proportion 163 | that reaches the desired acceleration rate taking into consideration 164 | the number of low frequencies. This ensures that the expected number 165 | of rows selected is equal to (N / acceleration) 166 | 167 | It is possible to use multiple center_fractions and accelerations, in which 168 | case one possible (center_fraction, acceleration) is chosen uniformly at 169 | random each time the EquispacedMaskFunc object is called. 170 | """ 171 | 172 | def __call__( 173 | self, shape: Sequence[int], seed: Optional[Union[int, Tuple[int, ...]]] = None 174 | ) -> torch.Tensor: 175 | """ 176 | Args: 177 | shape: The shape of the mask to be created. The shape should have 178 | at least 3 dimensions. Samples are drawn along the third last 179 | dimension. 180 | seed: Seed for the random number generator. Setting the seed 181 | ensures the same mask is generated each time for the same 182 | shape. The random state is reset afterwards. 183 | 184 | Returns: 185 | A mask of the specified shape. 
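Example (a minimal usage sketch; the k-space shape and mask parameters are
illustrative assumptions, not values taken from the training scripts):

    >>> mask_func = EquispacedMaskFunc(center_fractions=[0.08], accelerations=[4])
    >>> mask = mask_func(shape=(1, 10, 160, 160, 2), seed=42)
    >>> mask.shape   # rows are retained along the third last dimension
    torch.Size([1, 1, 160, 1, 1])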
186 | """ 187 | if len(shape) < 3: 188 | raise ValueError("Shape should have 3 or more dimensions") 189 | 190 | with temp_seed(self.rng, seed): 191 | center_fraction, acceleration = self.choose_acceleration() 192 | num_rows = shape[-3] 193 | num_low_freqs = int(round(num_rows * center_fraction)) 194 | 195 | # create the mask 196 | mask = np.zeros(num_rows, dtype=np.float32) 197 | pad = (num_rows - num_low_freqs + 1) // 2 198 | mask[pad : pad + num_low_freqs] = True 199 | 200 | # determine acceleration rate by adjusting for the number of low frequencies 201 | adjusted_accel = (acceleration * (num_low_freqs - num_rows)) / ( 202 | num_low_freqs * acceleration - num_rows 203 | ) 204 | offset = self.rng.randint(0, round(adjusted_accel)) 205 | 206 | accel_samples = np.arange(offset, num_rows - 1, adjusted_accel) 207 | accel_samples = np.around(accel_samples).astype(np.uint) 208 | mask[accel_samples] = True 209 | 210 | # reshape the mask 211 | mask_shape = [1 for _ in shape] 212 | mask_shape[-3] = num_rows 213 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 214 | 215 | return mask 216 | 217 | 218 | def create_mask_for_mask_type( 219 | mask_type_str: str, 220 | center_fractions: Sequence[float], 221 | accelerations: Sequence[int], 222 | ) -> MaskFunc: 223 | """ 224 | Creates a mask of the specified type. 225 | 226 | Args: 227 | center_fractions: What fraction of the center of k-space to include. 228 | accelerations: What accelerations to apply. 229 | """ 230 | if mask_type_str == "random": 231 | return RandomMaskFunc(center_fractions, accelerations) 232 | elif mask_type_str == "equispaced": 233 | return EquispacedMaskFunc(center_fractions, accelerations) 234 | else: 235 | raise Exception(f"{mask_type_str} not supported") 236 | -------------------------------------------------------------------------------- /reconstruction/models/recurrent_varnet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import reconstruction as rec 10 | from .varnet import SensitivityModel 11 | 12 | 13 | class VarNet_RNN(nn.Module): 14 | """ 15 | A hybrid model for Dynamic MRI Reconstruction, inspired by combining 16 | the End-to-End Variational Network [1] and Recurrent Convolutional 17 | Neural Networks (RCNN) [2]. 18 | 19 | Reference papers: 20 | [1] A. Sriram et al. `End-to-end variational networks for accelerated MRI 21 | reconstruction`. In International Conference on Medical Image Computing and 22 | Computer-Assisted Intervention, 2020. 23 | [2] C. Qin et al. `Convolutional Recurrent Neural Networks for Dynamic MR 24 | Image Reconstruction`. In IEEE Transactions on Medical Imaging 38.1, 25 | pp. 280–290, 2019. 26 | """ 27 | def __init__( 28 | self, 29 | num_cascades: int = 12, 30 | sens_chans: int = 8, 31 | sens_pools: int = 4, 32 | chans: int = 18, 33 | ): 34 | """ 35 | Args: 36 | num_cascades: Number of cascades (i.e., layers) for variational 37 | network. 38 | sens_chans: Number of channels for sensitivity map U-Net. 39 | sens_pools Number of downsampling and upsampling layers for 40 | sensitivity map U-Net. 41 | chans: Number of channels for convolutional layers of the RCNN. 
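A minimal construction sketch (these are the same keyword arguments that
VarNetModule passes in its 'CRNN' branch; the values are illustrative):

    >>> net = VarNet_RNN(num_cascades=8, sens_chans=8, sens_pools=4, chans=18)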
42 | """ 43 | super(VarNet_RNN, self).__init__() 44 | 45 | self.num_cascades = num_cascades 46 | self.chans = chans 47 | 48 | self.sens_net = SensitivityModel(sens_chans, sens_pools) 49 | self.bcrnn = BCRNNlayer(input_size=2, hidden_size=self.chans, kernel_size=3) 50 | 51 | self.conv1_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 52 | self.conv1_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 53 | self.conv2_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 54 | self.conv2_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 55 | self.conv3_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 56 | self.conv3_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 57 | self.conv4_x = nn.Conv2d(self.chans, 2, 3, padding = 3//2) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.Softplus = nn.Softplus(1.) 61 | lambda_init = np.log(np.exp(1)-1.)/1. 62 | self.lambda_reg = nn.Parameter(torch.tensor(lambda_init*torch.ones(1),dtype=torch.float), 63 | requires_grad=True) 64 | 65 | def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 66 | """ 67 | Forward operator: from coil-combined image-space to k-space. 68 | """ 69 | return rec.utils.fft2c(rec.utils.complex_mul(x.permute(0,4,2,3,1).unsqueeze(2), sens_maps)) 70 | 71 | def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 72 | """ 73 | Backward operator: from k-space to coil-combined image-space. 74 | """ 75 | x = rec.utils.ifft2c(x) 76 | return rec.utils.complex_mul(x, rec.utils.complex_conj(sens_maps)).sum( 77 | dim=2, keepdim=False 78 | ).permute(0,4,2,3,1) # b, ch, h, w, t 79 | 80 | def data_consistency(self, 81 | x: torch.Tensor, 82 | ref_kspace: torch.Tensor, 83 | mask: torch.Tensor, 84 | sens_maps: torch.Tensor, 85 | ) -> torch.Tensor: 86 | 87 | current_kspace = self.sens_expand(x, sens_maps) 88 | v = self.Softplus(self.lambda_reg) 89 | dc = (1 - mask) * current_kspace + mask * (current_kspace + v * ref_kspace) / (1 + v) # b,t,c,h,w,ch 90 | return self.sens_reduce(dc, sens_maps) 91 | 92 | 93 | def forward(self, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 94 | """ 95 | Args: 96 | ref_kspace, mask: Input 6D tensors of shape `(b, t, c, h, w, ch)`. 97 | 98 | Returns: 99 | Output tensor of shape `(b, t, h, w)`. 
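The data-consistency step applied after each cascade (the data_consistency
method above) can be written, with v = Softplus(lambda_reg), as

    k_dc = (1 - M) * A(x) + M * (A(x) + v * k_ref) / (1 + v)

so unsampled k-space locations keep the network prediction A(x), while
sampled locations are softly pulled towards the measured data k_ref.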
100 | """ 101 | sens_maps = self.sens_net(ref_kspace, mask) 102 | current_kspace = ref_kspace.clone() 103 | x = self.sens_reduce(current_kspace, sens_maps) 104 | 105 | b, ch, h, w, t = x.size() 106 | size_h = [t*b, self.chans, h, w] 107 | 108 | # Initialise parameters of rcnn layers at the first iteration to zero 109 | net = {} 110 | rcnn_layers = 5 111 | for j in range(rcnn_layers-1): 112 | net['t0_x%d'%j] = Variable(torch.zeros(size_h)).cuda() 113 | 114 | # Recurrence through iterations 115 | for i in range(1, self.num_cascades + 1): 116 | 117 | x = x.permute(4,0,1,2,3) 118 | x = x.contiguous() 119 | 120 | net['t%d_x0' % (i-1)] = net['t%d_x0' % (i-1)].view(t, b, self.chans, h, w) 121 | net['t%d_x0'%i] = self.bcrnn(x, net['t%d_x0' % (i-1)]) 122 | net['t%d_x0'%i] = net['t%d_x0'%i].view(-1, self.chans, h, w) 123 | 124 | net['t%d_x1'%i] = self.conv1_x(net['t%d_x0'%i]) 125 | net['t%d_h1'%i] = self.conv1_h(net['t%d_x1'%(i-1)]) 126 | net['t%d_x1'%i] = self.relu(net['t%d_h1'%i] + net['t%d_x1'%i]) 127 | 128 | net['t%d_x2'%i] = self.conv2_x(net['t%d_x1'%i]) 129 | net['t%d_h2'%i] = self.conv2_h(net['t%d_x2'%(i-1)]) 130 | net['t%d_x2'%i] = self.relu(net['t%d_h2'%i] + net['t%d_x2'%i]) 131 | 132 | net['t%d_x3'%i] = self.conv3_x(net['t%d_x2'%i]) 133 | net['t%d_h3'%i] = self.conv3_h(net['t%d_x3'%(i-1)]) 134 | net['t%d_x3'%i] = self.relu(net['t%d_h3'%i] + net['t%d_x3'%i]) 135 | 136 | net['t%d_x4'%i] = self.conv4_x(net['t%d_x3'%i]) 137 | 138 | x = x.view(-1, ch, h, w) 139 | net['t%d_out'%i] = x + net['t%d_x4'%i] 140 | 141 | net['t%d_out'%i] = net['t%d_out'%i].view(-1, b, ch, h, w) 142 | net['t%d_out'%i] = net['t%d_out'%i].permute(1,2,3,4,0) 143 | net['t%d_out'%i].contiguous() 144 | 145 | net['t%d_out'%i] = self.data_consistency(net['t%d_out'%i], ref_kspace, mask, sens_maps) 146 | 147 | x = net['t%d_out'%i] 148 | 149 | out = net['t%d_out'%i] 150 | return rec.utils.complex_abs(out.permute(0,4,2,3,1)) 151 | 152 | 153 | class CRNNcell(nn.Module): 154 | """ 155 | Convolutional RNN cell that evolves over both time and iterations. 156 | """ 157 | def __init__( 158 | self, 159 | input_size: int, 160 | hidden_size: int, 161 | kernel_size: int, 162 | ): 163 | """ 164 | Args: 165 | input_size: Number of input channels 166 | hidden_size: Number of RCNN hidden layers channels 167 | kernel_size: Size of convolutional kernel 168 | """ 169 | super(CRNNcell, self).__init__() 170 | 171 | # Convolution for input 172 | self.i2h = nn.Conv2d(input_size, hidden_size, kernel_size, padding=kernel_size // 2) 173 | # Convolution for hidden states in temporal dimension 174 | self.h2h = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2) 175 | # Convolution for hidden states in iteration dimension 176 | self.ih2ih = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2) 177 | 178 | self.relu = nn.ReLU(inplace=True) 179 | 180 | def forward( 181 | self, 182 | input: torch.Tensor, 183 | hidden_iteration: torch.Tensor, 184 | hidden: torch.Tensor, 185 | ) -> torch.Tensor: 186 | """ 187 | Args: 188 | input: Input 4D tensor of shape `(b, ch, h, w)` 189 | hidden_iteration: hidden states in iteration dimension, 4d tensor of shape (b, hidden_size, h, w) 190 | hidden: hidden states in temporal dimension, 4d tensor of shape (b, hidden_size, h, w) 191 | Returns: 192 | Output tensor of shape `(b, hidden_size, h, w)`. 
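In equation form, the cell computes a single fused update (a restatement of
the code below):

    h_t = ReLU( W_i * x_t + W_h * h_{t-1} + W_ih * h_t^{prev_iter} )

where * denotes 2D convolution, h_{t-1} is the hidden state propagated along
the temporal dimension and h_t^{prev_iter} is the hidden state from the
previous unrolled iteration.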
193 | """ 194 | in_to_hid = self.i2h(input) 195 | hid_to_hid = self.h2h(hidden) 196 | ih_to_ih = self.ih2ih(hidden_iteration) 197 | 198 | hidden = self.relu(in_to_hid + hid_to_hid + ih_to_ih) 199 | 200 | return hidden 201 | 202 | 203 | class BCRNNlayer(nn.Module): 204 | """ 205 | Bidirectional Convolutional RNN layer 206 | """ 207 | def __init__( 208 | self, 209 | input_size: int, 210 | hidden_size: int, 211 | kernel_size: int, 212 | ): 213 | """ 214 | Args: 215 | input_size: Number of input channels 216 | hidden_size: Number of RCNN hidden layers channels 217 | kernel_size: Size of convolutional kernel 218 | """ 219 | super(BCRNNlayer, self).__init__() 220 | 221 | self.hidden_size = hidden_size 222 | self.CRNN_model = CRNNcell(input_size, self.hidden_size, kernel_size) 223 | 224 | def forward(self, input: torch.Tensor, hidden_iteration: torch.Tensor) -> torch.Tensor: 225 | """ 226 | Args: 227 | input: Input 5D tensor of shape `(t, b, ch, h, w)` 228 | hidden_iteration: hidden states (output of BCRNNlayer) from previous 229 | iteration, 5d tensor of shape (t, b, hidden_size, h, w) 230 | Returns: 231 | Output tensor of shape `(t, b, hidden_size, h, w)`. 232 | """ 233 | t, b, ch, h, w = input.shape 234 | size_h = [b, self.hidden_size, h, w] 235 | 236 | hid_init = Variable(torch.zeros(size_h)).cuda() 237 | output_f = [] 238 | output_b = [] 239 | 240 | # forward 241 | hidden = hid_init 242 | for i in range(t): 243 | hidden = self.CRNN_model(input[i], hidden_iteration[i], hidden) 244 | output_f.append(hidden) 245 | output_f = torch.cat(output_f) 246 | 247 | # backward 248 | hidden = hid_init 249 | for i in range(t): 250 | hidden = self.CRNN_model(input[t - i - 1], hidden_iteration[t - i -1], hidden) 251 | output_b.append(hidden) 252 | output_b = torch.cat(output_b[::-1]) 253 | 254 | output = output_f + output_b 255 | 256 | if b == 1: 257 | output = output.view(t, 1, self.hidden_size, h, w) 258 | 259 | return output 260 | -------------------------------------------------------------------------------- /traintest_scripts/varnet/train_test_varnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/path/to/source') 3 | 4 | import os 5 | import pathlib 6 | import time 7 | from argparse import ArgumentParser 8 | 9 | import numpy as np 10 | import torch 11 | import pytorch_lightning as pl 12 | 13 | from reconstruction.data import SliceDataset 14 | from reconstruction.data.mri_data import fetch_dir 15 | from reconstruction.data.subsample import create_mask_for_mask_type 16 | from reconstruction.data.transforms import VarNetDataTransform 17 | from reconstruction.pl_modules import MriDataModule, VarNetModule 18 | from traintest_scripts.run_inference import InferenceTransform 19 | 20 | 21 | 22 | def train_test_main(args, save_path): 23 | pl.seed_everything(args.seed) 24 | 25 | # ------------ 26 | # DATA SECTION 27 | # ------------ 28 | 29 | # This creates a k-space mask to subsample input data 30 | mask = create_mask_for_mask_type( 31 | args.mask_type, args.center_fractions, args.accelerations 32 | ) 33 | 34 | train_transform = VarNetDataTransform(mask_func=mask, use_seed=False) 35 | val_transform = VarNetDataTransform(mask_func=mask, use_seed=False) 36 | test_transform = VarNetDataTransform(mask_func=mask, use_seed=False) 37 | 38 | # Data module - this handles data loaders 39 | data_module = MriDataModule( 40 | data_path=args.data_path, 41 | train_transform=train_transform, 42 | val_transform=val_transform, 43 | 
test_transform=test_transform, 44 | combine_train_val=args.combine_train_val, 45 | test_split=args.test_split, 46 | test_path=args.test_path, 47 | sample_rate=args.sample_rate, 48 | use_dataset_cache_file=args.use_dataset_cache_file, 49 | batch_size=args.batch_size, 50 | num_workers=args.num_workers, 51 | distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")), 52 | ) 53 | 54 | # ------------- 55 | # MODEL SECTION 56 | # ------------- 57 | 58 | # Load model state dictionary (generally for testing) 59 | if args.load_model: 60 | checkpoint_dir = args.default_root_dir / "checkpoints" 61 | ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) 62 | if ckpt_list: 63 | last_ckpt_path = str(ckpt_list[-1]) 64 | print(f"Loading model from {last_ckpt_path}") 65 | model = VarNetModule.load_from_checkpoint(last_ckpt_path) 66 | else: 67 | raise ValueError("No checkpoint available") 68 | 69 | else: 70 | # Build model 71 | model = VarNetModule( 72 | num_cascades=args.num_cascades, 73 | pools=args.pools, 74 | chans=args.chans, 75 | sens_pools=args.sens_pools, 76 | sens_chans=args.sens_chans, 77 | dynamic_type=args.dynamic_type, 78 | weight_sharing=args.weight_sharing, 79 | lr=args.lr, 80 | lr_step_size=args.lr_step_size, 81 | lr_gamma=args.lr_gamma, 82 | weight_decay=args.weight_decay, 83 | ) 84 | 85 | # ------------------ 86 | # TRAIN-TEST SECTION 87 | # ------------------ 88 | 89 | trainer = pl.Trainer.from_argparse_args(args) 90 | 91 | if args.mode == "train": 92 | 93 | print("Training VarNet " 94 | f"{args.dynamic_type} with " 95 | f"{args.num_cascades} cascades for " 96 | f"{args.max_epochs} epochs.\nData is subsampled with a " 97 | f"{args.mask_type} mask, acceleration " 98 | f"{args.accelerations[0]}." 99 | ) 100 | 101 | start_time = time.perf_counter() 102 | trainer.fit(model, datamodule=data_module) 103 | end_time = time.perf_counter() 104 | 105 | print(f"Training time: {(end_time-start_time) / 3600.} hours") 106 | 107 | if args.save_checkpoint: 108 | trainer.save_checkpoint(args.default_root_dir / f"checkpoints/varnet.ckpt") 109 | print(f"Saving checkpoint in varnet_{args.dynamic_type}_acc{args.accelerations[0]}_ckpt") 110 | 111 | elif args.mode == "test": 112 | trainer.test(model, datamodule=data_module) 113 | else: 114 | raise ValueError(f"unrecognized mode {args.mode}") 115 | 116 | # ----------------- 117 | # INFERENCE SECTION 118 | # ----------------- 119 | 120 | if (args.mode == "test" and args.inference): 121 | 122 | inference_dataset = SliceDataset( 123 | root=args.data_path / "inference", transform=test_transform, 124 | ) 125 | dataloader = torch.utils.data.DataLoader(inference_dataset, num_workers=2) 126 | inf_transform = InferenceTransform(model, 'varnet', save_path) 127 | time_for_inference = 0 128 | 129 | print('Starting inference..............') 130 | 131 | for batch in dataloader: 132 | with torch.no_grad(): 133 | masked_kspace, mask, target, fname, _, _, _ = batch 134 | time_for_inference += inf_transform(masked_kspace, mask, target, fname) 135 | 136 | print(f"Elapsed time: {time_for_inference} seconds.") 137 | 138 | 139 | 140 | 141 | def build_args(): 142 | parser = ArgumentParser() 143 | 144 | # ---------- 145 | # BASIC ARGS 146 | # ---------- 147 | path_config = pathlib.Path("/root/traintest_scripts/dirs_path.yaml") 148 | backend = "dp" 149 | num_gpus = 2 if backend == "ddp" else 1 150 | batch_size = 1 151 | 152 | # Set defaults based on optional directory config 153 | data_path = fetch_dir("data_path", path_config) 154 | save_path = fetch_dir("save_path", 
path_config) 155 | default_root_dir = fetch_dir("log_path", path_config) / "varnet/varnet_logs" 156 | 157 | # ----------- 158 | # CLIENT ARGS 159 | # ----------- 160 | parser.add_argument( 161 | "--mode", 162 | default="train", 163 | choices=("train", "test"), 164 | type=str, 165 | help="Operation mode", 166 | ) 167 | 168 | parser.add_argument( 169 | "--epochs", 170 | default=150, 171 | type=int, 172 | help="Total number of epochs to train the model for", 173 | ) 174 | 175 | parser.add_argument( 176 | "--save_checkpoint", 177 | default=0, 178 | choices=(0, 1), 179 | type=int, 180 | help="Whether to save a checkpoint of the model at the end of training", 181 | ) 182 | 183 | parser.add_argument( 184 | "--resume_training", 185 | default=0, 186 | choices=(0, 1), 187 | type=int, 188 | help="Whether to resume training from the latest checkpoint", 189 | ) 190 | 191 | parser.add_argument( 192 | "--load_model", 193 | default=0, 194 | choices=(0, 1), 195 | type=int, 196 | help="Whether to load the latest model in checkpoint dir, to be used for testing", 197 | ) 198 | 199 | parser.add_argument( 200 | "--inference", 201 | default=1, 202 | choices=(0, 1), 203 | type=int, 204 | help="Whether to generate and save the reconstruction made by the trained model on an inference dataset", 205 | ) 206 | 207 | # Data transform params 208 | parser.add_argument( 209 | "--mask_type", 210 | choices=("random", "equispaced"), 211 | default="random", 212 | type=str, 213 | help="Type of k-space mask", 214 | ) 215 | parser.add_argument( 216 | "--center_fractions", 217 | nargs="+", 218 | default=[10], 219 | type=float, 220 | help="Number of central lines to use in mask", 221 | ) 222 | parser.add_argument( 223 | "--accelerations", 224 | nargs="+", 225 | default=[4], 226 | type=int, 227 | help="Acceleration rates to use for masks", 228 | ) 229 | 230 | 231 | # -------------- 232 | # MODULES CONFIG 233 | # -------------- 234 | 235 | # Data config 236 | parser = MriDataModule.add_data_specific_args(parser) 237 | parser.set_defaults( 238 | data_path=data_path, 239 | test_path=None, 240 | test_split="test", 241 | sample_rate=None, 242 | use_dataset_cache_file=True, 243 | combine_train_val=False, 244 | batch_size=batch_size, 245 | num_workers=4, 246 | ) 247 | 248 | # Model config 249 | parser = VarNetModule.add_model_specific_args(parser) 250 | parser.set_defaults( 251 | num_cascades=10, 252 | pools=3, 253 | chans=16, 254 | sens_pools=3, 255 | sens_chans=8, 256 | dynamic_type='XF', 257 | weight_sharing=False, 258 | lr=0.0001, 259 | lr_step_size=140, 260 | lr_gamma=0.01, 261 | weight_decay=0.0, 262 | ) 263 | 264 | args = parser.parse_args() 265 | 266 | # Configure checkpointing in checkpoint_dir 267 | checkpoint_dir = default_root_dir / "checkpoints" 268 | if not checkpoint_dir.exists(): 269 | checkpoint_dir.mkdir(parents=True) 270 | 271 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 272 | dirpath=default_root_dir / "checkpoints", 273 | filename=f"varnet_{args.dynamic_type}_acc{args.accelerations[0]}_ckpt", 274 | verbose=True, 275 | monitor="validation_loss", 276 | mode="min", 277 | ) 278 | 279 | resume_from_checkpoint_path = None 280 | if args.resume_training: 281 | ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) 282 | if ckpt_list: 283 | resume_from_checkpoint_path = str(ckpt_list[-1]) 284 | 285 | # Configure trainer options 286 | parser = pl.Trainer.add_argparse_args(parser) 287 | parser.set_defaults( 288 | gpus=num_gpus, # number of gpus to use 289 | replace_sampler_ddp=False, # this is necessary 
for volume dispatch during val 290 | accelerator=backend, # what distributed version to use 291 | seed=42, # random seed 292 | deterministic=True, # makes things slower, but deterministic 293 | default_root_dir=default_root_dir, # directory for logs and checkpoints 294 | max_epochs=args.epochs, 295 | callbacks=[checkpoint_callback], 296 | resume_from_checkpoint=resume_from_checkpoint_path, 297 | ) 298 | 299 | args = parser.parse_args() 300 | 301 | return args, save_path 302 | 303 | 304 | 305 | def run_main(): 306 | args, save_path = build_args() 307 | train_test_main(args, save_path) 308 | 309 | 310 | if __name__ == "__main__": 311 | run_main() 312 | -------------------------------------------------------------------------------- /traintest_scripts/cinenet/train_test_cinenet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/path/to/source') 3 | 4 | import os 5 | import pathlib 6 | import time 7 | from argparse import ArgumentParser 8 | 9 | import numpy as np 10 | import torch 11 | import pytorch_lightning as pl 12 | 13 | from reconstruction.data import SliceDataset 14 | from reconstruction.data.mri_data import fetch_dir 15 | from reconstruction.data.subsample import create_mask_for_mask_type 16 | from reconstruction.data.transforms import CineNetDataTransform 17 | from reconstruction.pl_modules import MriDataModule, CineNetModule 18 | from traintest_scripts.run_inference import InferenceTransform 19 | 20 | 21 | 22 | def train_test_main(args, save_path): 23 | pl.seed_everything(args.seed) 24 | 25 | # ------------ 26 | # DATA SECTION 27 | # ------------ 28 | 29 | # This creates a k-space mask to subsample input data 30 | mask = create_mask_for_mask_type( 31 | args.mask_type, args.center_fractions, args.accelerations 32 | ) 33 | 34 | train_transform = CineNetDataTransform(mask_func=mask, use_seed=False) 35 | val_transform = CineNetDataTransform(mask_func=mask, use_seed=False) 36 | test_transform = CineNetDataTransform(mask_func=mask, use_seed=False) 37 | 38 | # Data module - this handles data loaders 39 | data_module = MriDataModule( 40 | data_path=args.data_path, 41 | train_transform=train_transform, 42 | val_transform=val_transform, 43 | test_transform=test_transform, 44 | combine_train_val=args.combine_train_val, 45 | test_split=args.test_split, 46 | test_path=args.test_path, 47 | sample_rate=args.sample_rate, 48 | use_dataset_cache_file=args.use_dataset_cache_file, 49 | batch_size=args.batch_size, 50 | num_workers=args.num_workers, 51 | distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")), 52 | ) 53 | 54 | # ------------- 55 | # MODEL SECTION 56 | # ------------- 57 | 58 | # Load model state dictionary (generally for testing) 59 | if args.load_model: 60 | checkpoint_dir = args.default_root_dir / "checkpoints" 61 | ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) 62 | if ckpt_list: 63 | last_ckpt_path = str(ckpt_list[-1]) 64 | print(f"Loading model from {last_ckpt_path}") 65 | model = CineNetModule.load_from_checkpoint(last_ckpt_path) 66 | else: 67 | raise ValueError("No checkpoint available") 68 | 69 | else: 70 | # Build model 71 | model = CineNetModule( 72 | num_cascades=args.num_cascades, 73 | CG_iters=args.CG_iters, 74 | pools=args.pools, 75 | chans=args.chans, 76 | dynamic_type=args.dynamic_type, 77 | weight_sharing=args.weight_sharing, 78 | lr=args.lr, 79 | lr_step_size=args.lr_step_size, 80 | lr_gamma=args.lr_gamma, 81 | weight_decay=args.weight_decay, 82 | ) 83 | 84 | # 
------------------ 85 | # TRAIN-TEST SECTION 86 | # ------------------ 87 | 88 | trainer = pl.Trainer.from_argparse_args(args) 89 | 90 | if args.mode == "train": 91 | 92 | print("Training CineNet " 93 | f"{args.dynamic_type} with " 94 | f"{args.num_cascades} unrolled iterations, " 95 | f"{args.CG_iters} CG iterations, for " 96 | f"{args.max_epochs} epochs.\nData is subsampled with a " 97 | f"{args.mask_type} mask, acceleration " 98 | f"{args.accelerations[0]}." 99 | ) 100 | 101 | start_time = time.perf_counter() 102 | trainer.fit(model, datamodule=data_module) 103 | end_time = time.perf_counter() 104 | 105 | print(f"Training time: {(end_time-start_time) / 3600.} hours") 106 | 107 | if args.save_checkpoint: 108 | trainer.save_checkpoint(args.default_root_dir / f"checkpoints/cinenet.ckpt") 109 | print(f"Saving checkpoint in cinenet_{args.dynamic_type}_acc{args.accelerations[0]}_ckpt") 110 | 111 | elif args.mode == "test": 112 | trainer.test(model, datamodule=data_module) 113 | else: 114 | raise ValueError(f"unrecognized mode {args.mode}") 115 | 116 | # ----------------- 117 | # INFERENCE SECTION 118 | # ----------------- 119 | 120 | if (args.mode == "test" and args.inference): 121 | 122 | inference_dataset = SliceDataset( 123 | root=args.data_path / "inference", transform=test_transform, 124 | ) 125 | dataloader = torch.utils.data.DataLoader(inference_dataset, num_workers=4) 126 | inf_transform = InferenceTransform(model, 'cinenet', save_path) 127 | time_for_inference = 0 128 | 129 | print('Starting inference..............') 130 | 131 | for batch in dataloader: 132 | with torch.no_grad(): 133 | masked_kspace, mask, sens_maps, target, fname, _, _, _ = batch 134 | time_for_inference += inf_transform(masked_kspace, mask, target, fname, sens_maps) 135 | 136 | print(f"Elapsed time: {time_for_inference} seconds.") 137 | 138 | 139 | 140 | 141 | def build_args(): 142 | parser = ArgumentParser() 143 | 144 | # ---------- 145 | # BASIC ARGS 146 | # ---------- 147 | path_config = pathlib.Path("/root/traintest_scripts/dirs_path.yaml") 148 | backend = "dp" 149 | num_gpus = 2 if backend == "ddp" else 1 150 | batch_size = 1 151 | 152 | # Set defaults based on optional directory config 153 | data_path = fetch_dir("data_path", path_config) 154 | save_path = fetch_dir("save_path", path_config) 155 | default_root_dir = fetch_dir("log_path", path_config) / "cinenet/cinenet_logs" 156 | 157 | # ----------- 158 | # CLIENT ARGS 159 | # ----------- 160 | parser.add_argument( 161 | "--mode", 162 | default="train", 163 | choices=("train", "test"), 164 | type=str, 165 | help="Operation mode", 166 | ) 167 | 168 | parser.add_argument( 169 | "--epochs", 170 | default=150, 171 | type=int, 172 | help="Total number of epochs to train the model for", 173 | ) 174 | 175 | parser.add_argument( 176 | "--save_checkpoint", 177 | default=0, 178 | choices=(0, 1), 179 | type=int, 180 | help="Whether to save a checkpoint of the model at the end of training", 181 | ) 182 | 183 | parser.add_argument( 184 | "--resume_training", 185 | default=0, 186 | choices=(0, 1), 187 | type=int, 188 | help="Whether to resume training from the latest checkpoint", 189 | ) 190 | 191 | parser.add_argument( 192 | "--load_model", 193 | default=0, 194 | choices=(0, 1), 195 | type=int, 196 | help="Whether to load the latest model in checkpoint dir, to be used for testing", 197 | ) 198 | 199 | parser.add_argument( 200 | "--inference", 201 | default=1, 202 | choices=(0, 1), 203 | type=int, 204 | help="Whether to generate and save the reconstruction made by the 
trained model on an inference dataset", 205 | ) 206 | 207 | # Data transform params 208 | parser.add_argument( 209 | "--mask_type", 210 | choices=("random", "equispaced"), 211 | default="random", 212 | type=str, 213 | help="Type of k-space mask", 214 | ) 215 | parser.add_argument( 216 | "--center_fractions", 217 | nargs="+", 218 | default=[10], 219 | type=float, 220 | help="Number of central lines to use in mask", 221 | ) 222 | parser.add_argument( 223 | "--accelerations", 224 | nargs="+", 225 | default=[4], 226 | type=int, 227 | help="Acceleration rates to use for masks", 228 | ) 229 | 230 | 231 | # -------------- 232 | # MODULES CONFIG 233 | # -------------- 234 | 235 | # Data config 236 | parser = MriDataModule.add_data_specific_args(parser) 237 | parser.set_defaults( 238 | data_path=data_path, 239 | test_path=None, 240 | test_split="test", 241 | sample_rate=None, 242 | use_dataset_cache_file=True, 243 | combine_train_val=False, 244 | batch_size=batch_size, 245 | num_workers=4, 246 | ) 247 | 248 | # Model config 249 | parser = CineNetModule.add_model_specific_args(parser) 250 | parser.set_defaults( 251 | num_cascades=10, 252 | CG_iters=6, 253 | pools=3, 254 | chans=16, 255 | dynamic_type='XF', 256 | weight_sharing=False, 257 | lr=0.0001, 258 | lr_step_size=140, 259 | lr_gamma=0.01, 260 | weight_decay=0.0, 261 | ) 262 | 263 | args = parser.parse_args() 264 | 265 | # Configure checkpointing in checkpoint_dir 266 | checkpoint_dir = default_root_dir / "checkpoints" 267 | if not checkpoint_dir.exists(): 268 | checkpoint_dir.mkdir(parents=True) 269 | 270 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 271 | dirpath=default_root_dir / "checkpoints", 272 | filename=f"cinenet_{args.dynamic_type}_acc{args.accelerations[0]}_ckpt", 273 | verbose=True, 274 | monitor="validation_loss", 275 | mode="min", 276 | ) 277 | 278 | resume_from_checkpoint_path = None 279 | if args.resume_training: 280 | ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) 281 | if ckpt_list: 282 | resume_from_checkpoint_path = str(ckpt_list[-1]) 283 | 284 | # Configure trainer options 285 | parser = pl.Trainer.add_argparse_args(parser) 286 | parser.set_defaults( 287 | gpus=num_gpus, # number of gpus to use 288 | replace_sampler_ddp=False, # this is necessary for volume dispatch during val 289 | accelerator=backend, # what distributed version to use 290 | seed=42, # random seed 291 | deterministic=True, # makes things slower, but deterministic 292 | default_root_dir=default_root_dir, # directory for logs and checkpoints 293 | max_epochs=args.epochs, 294 | callbacks=[checkpoint_callback], 295 | resume_from_checkpoint=resume_from_checkpoint_path, 296 | ) 297 | 298 | args = parser.parse_args() 299 | 300 | return args, save_path 301 | 302 | 303 | 304 | def run_main(): 305 | args, save_path = build_args() 306 | train_test_main(args, save_path) 307 | 308 | 309 | if __name__ == "__main__": 310 | run_main() 311 | -------------------------------------------------------------------------------- /reconstruction/models/denoisers/mwcnn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | 8 | class MWCNN(nn.Module): 9 | """ 10 | PyTorch implementation of a Multi-scale Wavelet CNN model, based on the 11 | standard U-Net architecture where pooling operations have been replaced 12 | by discrete wavelet transforms. 
13 | 14 | Source: 15 | https://github.com/zaccharieramzi/fastmri-reproducible-benchmark/blob/master/fastmri_recon/models/subclassed_models/denoisers/mwcnn.py#L134 16 | """ 17 | def __init__( 18 | self, 19 | in_chans: int, 20 | out_chans: int, 21 | dims: int = 2, 22 | n_scales: int = 3, 23 | n_filters_per_scale: List[int] = [16, 32, 64], 24 | n_convs_per_scale: List[int] = [2, 2, 2], 25 | n_first_convs: int = 1, 26 | first_conv_n_filters: int = 16, 27 | res: bool = False, 28 | ): 29 | """ 30 | Args: 31 | in_chans: Number of channels in the input to the MWCNN model. 32 | out_chans: Number of channels in the output of the MWCNN model. 33 | dims: number of dimensions for convolutional operations (2 or 3). 34 | n_scales: Number of scales, i.e. number of pooling layers. 35 | n_filters_per_scale: Number of filters used by the convolutional 36 | layers at each scale. 37 | n_convs_per_scale: Number of convolutional layers per scale. 38 | n_first_convs: Number of convolutional layers at the start of 39 | the architecture, i.e. before pooling layers. 40 | first_conv_n_filters: Number of filters used by the inital 41 | convolutional layers. 42 | res: Whether to use a residual connection between input and output. 43 | """ 44 | super().__init__() 45 | 46 | self.in_chans = in_chans 47 | self.out_chans = out_chans 48 | self.dims = dims 49 | self.n_scales = n_scales 50 | self.n_filters_per_scale = n_filters_per_scale 51 | self.n_convs_per_scale = n_convs_per_scale 52 | self.n_first_convs = n_first_convs 53 | self.first_conv_n_filters = first_conv_n_filters 54 | self.res = res 55 | 56 | assert self.dims in [2, 3], \ 57 | "Dimensions must be either 2 or 3" 58 | 59 | if self.dims == 2: 60 | conv_op = nn.Conv2d 61 | if self.dims == 3: 62 | conv_op = nn.Conv3d 63 | 64 | # First and last convolutions block without pooling 65 | if self.n_first_convs > 0: 66 | self.first_convs = nn.ModuleList([ConvBlock( 67 | in_chans = self.in_chans, 68 | n_filters = self.first_conv_n_filters, 69 | dims = self.dims, 70 | )]) 71 | for _ in range(1, 2 * self.n_first_convs - 1): 72 | self.first_convs.append(ConvBlock( 73 | in_chans = self.first_conv_n_filters, 74 | n_filters = self.first_conv_n_filters, 75 | dims = self.dims, 76 | )) 77 | self.first_convs.append(conv_op( 78 | self.first_conv_n_filters, 79 | self.out_chans, 80 | kernel_size=3, 81 | padding='same', 82 | bias=True, 83 | )) 84 | 85 | # All convolution blocks during pooling/unpooling 86 | self.conv_blocks_per_scale = nn.ModuleList([ 87 | nn.ModuleList([ConvBlock( 88 | in_chans = self.chans_for_conv_for_scale(i_scale, i_conv)[0], 89 | n_filters = self.chans_for_conv_for_scale(i_scale, i_conv)[1], 90 | dims = self.dims, 91 | ) for i_conv in range(self.n_convs_per_scale[i_scale] * 2)]) 92 | for i_scale in range(self.n_scales) 93 | ]) 94 | 95 | if self.n_first_convs < 1: 96 | # Adjust last convolution of the last convolution block 97 | self.conv_blocks_per_scale[0][-1] = conv_op( 98 | self.n_filters_per_scale[0], 99 | 4 * self.out_chans, 100 | kernel_size=3, 101 | padding='same', 102 | bias=True, 103 | ) 104 | 105 | # Pooling operations 106 | self.pooling = DWT() 107 | self.unpooling = IWT() 108 | 109 | 110 | def chans_for_conv_for_scale(self, i_scale: int, i_conv: int) -> Tuple[int, int]: 111 | """ 112 | Returns input channels and number of filters for each convolution at each 113 | scale for both downsampling and upsampling sections of the network. 
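Worked example (using the default n_filters_per_scale=[16, 32, 64],
n_convs_per_scale=[2, 2, 2], first_conv_n_filters=16 and, for illustration,
out_chans=2): at scale 0 the four convolutions map
64 -> 16 -> 16 -> 16 -> 64 channels, and at scale 1 they map
64 -> 32 -> 32 -> 32 -> 64 channels. The factor of 4 at the scale boundaries
comes from the DWT/IWT pooling below, which stacks or unstacks the four
wavelet sub-bands along the channel dimension.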
114 | """ 115 | in_chans = self.n_filters_per_scale[i_scale] 116 | n_filters = self.n_filters_per_scale[i_scale] 117 | 118 | # Convolutions in downsampling section 119 | if i_conv == 0: 120 | if i_scale == 0: 121 | in_chans = 4 * self.first_conv_n_filters 122 | else: 123 | in_chans = 4 * self.n_filters_per_scale[i_scale-1] 124 | 125 | # Convolutions in upsampling section 126 | if i_conv == self.n_convs_per_scale[i_scale] * 2 - 1: 127 | if i_scale == 0: 128 | n_filters = max(4 * self.first_conv_n_filters, 4 * self.out_chans) 129 | else: 130 | n_filters = 4 * self.n_filters_per_scale[i_scale-1] 131 | 132 | return in_chans, n_filters 133 | 134 | 135 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 136 | """ 137 | Args: 138 | inputs: Input tensor of shape `(N,in_chans,H,W)` 139 | - 140 | Returns: 141 | Output tensor of shape `(N,out_chans,H,W)` 142 | """ 143 | last_feature_for_scale = [] 144 | current_feature = inputs 145 | 146 | # First convolutions 147 | if self.n_first_convs > 0: 148 | for conv in self.first_convs[:self.n_first_convs]: 149 | current_feature = conv(current_feature) 150 | first_conv_feature = current_feature 151 | 152 | # Downsampling section 153 | for i_scale in range(self.n_scales): 154 | current_feature = self.pooling(current_feature) 155 | n_convs = self.n_convs_per_scale[i_scale] 156 | for conv in self.conv_blocks_per_scale[i_scale][:n_convs]: 157 | current_feature = conv(current_feature) 158 | last_feature_for_scale.append(current_feature) 159 | 160 | # Upsampling section 161 | for i_scale in range(self.n_scales - 1, -1, -1): 162 | if i_scale != self.n_scales - 1: 163 | current_feature = self.unpooling(current_feature) 164 | current_feature = current_feature + last_feature_for_scale[i_scale] 165 | n_convs = self.n_convs_per_scale[i_scale] 166 | for conv in self.conv_blocks_per_scale[i_scale][n_convs:]: 167 | current_feature = conv(current_feature) 168 | current_feature = self.unpooling(current_feature) 169 | 170 | # Last convolution 171 | if self.n_first_convs > 0: 172 | current_feature = current_feature + first_conv_feature 173 | for conv in self.first_convs[self.n_first_convs:]: 174 | current_feature = conv(current_feature) 175 | if self.res: 176 | outputs = inputs + current_feature 177 | else: 178 | outputs = current_feature 179 | return outputs 180 | 181 | 182 | 183 | class ConvBlock(nn.Module): 184 | """ 185 | A Convolutional Block that consists of one convolution layer, followed by 186 | instance normalization and LeakyReLU activation. 187 | """ 188 | def __init__(self, in_chans: int, n_filters: int, dims: int): 189 | """ 190 | Args: 191 | in_chans: Number of channels in the input. 192 | n_filters: Number of convolutional filters. 193 | dims: Number of dimensions for convolutional operations (2 or 3). 194 | """ 195 | super().__init__() 196 | 197 | if dims == 2: 198 | conv_op = nn.Conv2d 199 | norm_op = nn.InstanceNorm2d 200 | 201 | if dims == 3: 202 | conv_op = nn.Conv3d 203 | norm_op = nn.InstanceNorm3d 204 | 205 | self.layers = nn.Sequential( 206 | conv_op(in_chans, n_filters, kernel_size=3, padding='same', bias=False), 207 | norm_op(n_filters), 208 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 209 | ) 210 | 211 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 212 | return self.layers(inputs) 213 | 214 | 215 | 216 | class DWT(nn.Module): 217 | """ 218 | A discrete wavelet transform used in the down-pooling operations of the 219 | MWCNN network. 
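Concretely, this is a Haar-type decomposition: an input of shape
(N, C, H, W) is split into the four sub-bands LL, HL, LH, HH, which are
stacked along the channel dimension while the spatial size is halved.

    >>> x = torch.randn(1, 16, 64, 64)
    >>> DWT()(x).shape
    torch.Size([1, 64, 32, 32])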
220 | """ 221 | def __init__(self): 222 | super().__init__() 223 | 224 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 225 | x01 = inputs[:, :, 0::2] / 2 226 | x02 = inputs[:, :, 1::2] / 2 227 | x1 = x01[..., 0::2] 228 | x2 = x02[..., 0::2] 229 | x3 = x01[..., 1::2] 230 | x4 = x02[..., 1::2] 231 | x_LL = x1 + x2 + x3 + x4 232 | x_HL = -x1 - x2 + x3 + x4 233 | x_LH = -x1 + x2 - x3 + x4 234 | x_HH = x1 - x2 - x3 + x4 235 | 236 | return torch.cat([x_LL, x_HL, x_LH, x_HH], dim=1) 237 | 238 | 239 | 240 | class IWT(nn.Module): 241 | """ 242 | A discrete inverse wavelet transform used in the up-pooling operations of the 243 | MWCNN network. 244 | """ 245 | def __init__(self): 246 | super().__init__() 247 | 248 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 249 | b, ch, h, w = inputs.shape 250 | new_ch = ch // 4 251 | 252 | x1 = inputs[:, 0:new_ch] / 2 253 | x2 = inputs[:, new_ch:2*new_ch] / 2 254 | x3 = inputs[:, 2*new_ch:3*new_ch] / 2 255 | x4 = inputs[:, 3*new_ch:4*new_ch] / 2 256 | 257 | outputs = torch.zeros([b, new_ch, 2*h, 2*w], dtype=inputs.dtype).cuda() 258 | outputs[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 259 | outputs[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 260 | outputs[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 261 | outputs[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 262 | 263 | return outputs 264 | 265 | -------------------------------------------------------------------------------- /reconstruction/models/cinenet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Callable 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import reconstruction as rec 9 | from reconstruction.data import transforms 10 | from .denoisers.unet import Unet 11 | 12 | 13 | 14 | class CineNet(nn.Module): 15 | """ 16 | An adaptation of the CineNet model for dynamic MRI reconstruction, 17 | which consists of alternating U-Net and Conjugate Gradient (CG) blocks. 18 | Reference paper: 19 | 20 | A. Kofler et al. `An end-to-end-trainable iterative network architecture 21 | for accelerated radial multi-coil 2D cine MR image reconstruction.` 22 | In Medical Physics, 2021. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | num_cascades: int = 12, 28 | CG_iters: int = 4, 29 | chans: int = 18, 30 | pools: int = 4, 31 | dynamic_type: str = 'XF', 32 | weight_sharing: bool = False, 33 | ): 34 | """ 35 | Args: 36 | num_cascades: Number of alternations between CG and U-Net modules. 37 | CG_iters: Number of CG iterations in the CG module. 38 | chans: Number of channels for cascade U-Net. 39 | pools: Number of downsampling and upsampling layers for cascade U-Net. 40 | dynamic_type: Type of architecture adjustment for dynamic setting. 41 | weight_sharing: Optional setting in 'XF' or 'XT' dynamics mode, allowing 42 | U-Net to share the same parameters in both x-f and y-f planes. 
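        Expected tensor layout (note added for clarity, inferred from the
        forward pass below): masked_kspace, mask and sens_maps are 6-D tensors
        of shape (batch, time, coils, height, width, 2), with the trailing
        dimension holding the real and imaginary parts; the network returns
        the magnitude image of shape (batch, time, height, width). A minimal
        call might look like:

            model = CineNet(num_cascades=2, CG_iters=4, dynamic_type='XF')
            image = model(masked_kspace, mask, sens_maps)   # (b, t, h, w)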
43 | """ 44 | super().__init__() 45 | 46 | if dynamic_type in ['XF', 'XT']: 47 | if weight_sharing: 48 | self.model = Unet(chans, pools, dims=2) 49 | else: 50 | self.model = nn.ModuleList([Unet(chans, pools, dims=2), Unet(chans, pools, dims=2)]) 51 | elif dynamic_type == '3D': 52 | self.model = Unet(chans, pools, dims=3) 53 | else: 54 | self.model = Unet(chans, pools, dims=2) 55 | 56 | self.cascades = nn.ModuleList( 57 | [CineNetBlock(self.model, CG_iters, dynamic_type, weight_sharing) for _ in range(num_cascades)] 58 | ) 59 | 60 | 61 | def forward(self, masked_kspace: torch.Tensor, mask: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 62 | 63 | # Coil-combined image, shape (b, t, 1, h, w, ch) 64 | image_pred = rec.utils.complex_mul( 65 | rec.utils.ifft2c(masked_kspace), rec.utils.complex_conj(sens_maps) 66 | ).sum(dim=2, keepdim=True) 67 | 68 | image_ref = image_pred.clone() 69 | 70 | for cascade in self.cascades: 71 | image_pred = cascade(image_pred, image_ref, mask, sens_maps) 72 | 73 | return rec.utils.complex_abs(image_pred.squeeze(2)) 74 | 75 | 76 | 77 | class CineNetBlock(nn.Module): 78 | """ 79 | Model block for CineNet with several temporal dynamics adjustments. 80 | A series of these blocks can be stacked to form the full network. 81 | """ 82 | 83 | def __init__(self, model: nn.Module, CG_iters: int, dynamic_type: str, weight_sharing: bool): 84 | """ 85 | Args: 86 | model: Module for UNet-type image denoiser component of CineNet. 87 | Its architecture depends on the specfic dynamics mode. 88 | CG_iters: Number of CG iterations in the CG module. 89 | dynamic_type: Type of architecture adjustment for dynamic setting. 90 | weight_sharing: Optional setting in 'XF' or 'XT' dynamics mode, allowing 91 | U-Net to share the same parameters in both x-f and y-f planes. 92 | """ 93 | super().__init__() 94 | 95 | self.model = model 96 | self.CG_iters = CG_iters 97 | self.dynamic_type = dynamic_type 98 | self.weight_sharing = weight_sharing 99 | 100 | # Regularisation parameter is learned during training 101 | self.Softplus = nn.Softplus(1.) 102 | lambda_init = np.log(np.exp(1)-1.)/1. 103 | self.lambda_reg = nn.Parameter(torch.tensor(lambda_init*torch.ones(1),dtype=torch.float), 104 | requires_grad=True) 105 | 106 | def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 107 | """ 108 | Forward operator: from coil-combined image-space to k-space. 109 | """ 110 | return rec.utils.fft2c(rec.utils.complex_mul(x, sens_maps)) 111 | 112 | def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 113 | """ 114 | Backward operator: from k-space to coil-combined image-space. 115 | """ 116 | x = rec.utils.ifft2c(x) 117 | return rec.utils.complex_mul(x, rec.utils.complex_conj(sens_maps)).sum( 118 | dim=2, keepdim=True, 119 | ) 120 | 121 | def HOperator(self, x: torch.Tensor, mask: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 122 | """ 123 | The operator H = A^H \circ A + \lambda_Reg * \Id, where A is the encoding matrix. 124 | This ensures data consistency. 
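        Note (added for clarity, not in the original file): ConjGrad below uses
        this operator to solve H x = b with b = A^H y + \lambda_Reg * z, where
        y is the measured k-space and z the U-Net output, i.e. the normal
        equations of argmin_x ||M A x - y||^2 + \lambda_Reg ||x - z||^2.
        \lambda_Reg is kept positive through the Softplus and is initialised
        with log(exp(1) - 1) so that Softplus(lambda_init) = 1 at the start of
        training:

            import numpy as np
            import torch
            lam0 = np.log(np.exp(1.) - 1.)
            torch.nn.Softplus(1.)(torch.tensor(lam0))   # tensor(1.0000)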
125 | """ 126 | # Forward operator 127 | k_coils = self.sens_expand(x, sens_maps) 128 | # Apply sampling mask 129 | k_masked = k_coils * mask + 0.0 130 | # Backward operator 131 | x_combined = self.sens_reduce(k_masked, sens_maps) 132 | # Result of H(x) 133 | return x_combined + self.Softplus(self.lambda_reg) * x 134 | 135 | 136 | def ConjGrad(self, x:torch.Tensor, b:torch.Tensor, mask:torch.Tensor, sens_maps:torch.Tensor, CG_iters:int)-> torch.Tensor: 137 | """ 138 | Conjugate Gradient method for solving the system Hx = b 139 | """ 140 | # x is the starting value, b the rhs 141 | r = self.HOperator(x, mask, sens_maps) 142 | r = b-r 143 | 144 | # Initialize p 145 | p = r.clone() 146 | 147 | # Old squared norm of residual 148 | sqnorm_r_old = torch.dot(r.flatten(), r.flatten()) 149 | 150 | for kiter in range(CG_iters): 151 | # Calculate H(p) 152 | d = self.HOperator(p, mask, sens_maps) 153 | 154 | # Calculate step size alpha; 155 | inner_p_d = torch.dot(p.flatten(), d.flatten()) 156 | alpha = sqnorm_r_old / inner_p_d 157 | 158 | # Perform step and calculate new residual 159 | x = torch.add(x, p, alpha = alpha.item()) 160 | r = torch.add(r, d, alpha = -alpha.item()) 161 | 162 | # New residual norm 163 | sqnorm_r_new = torch.dot(r.flatten(), r.flatten()) 164 | 165 | # Calculate beta and update the norm 166 | beta = sqnorm_r_new / sqnorm_r_old 167 | sqnorm_r_old = sqnorm_r_new 168 | 169 | p = torch.add(r, p, alpha = beta.item()) 170 | 171 | return x 172 | 173 | 174 | def xfyf_transform(self, image_combined: torch.Tensor) -> torch.Tensor: 175 | """ 176 | Separate input into two volumes in the rotated planes x-f and y-f 177 | (or x-t, y-t if in 'XT' dynamics mode). After being processed by 178 | their respective U-Nets, the volumes are then combined back into one. 
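        Shape walk-through (added for clarity, not in the original docstring),
        for an input of shape (b, t, h, w, 2):

            x-f branch: permute/view -> (b*h, 2, w, t)   # image rows become the batch
            y-f branch: permute/view -> (b*w, 2, h, t)   # image columns become the batch

        Each 2-D U-Net therefore sees a stack of spatio-temporal slices with
        the real and imaginary parts as its two input channels. The two
        outputs are reshaped back to (b, t, 1, h, w, 2) and averaged, the
        temporal FFT is inverted (in 'XF' mode) and the temporal mean is added
        back as a residual connection.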
179 | """ 180 | b, t, h, w, ch = image_combined.shape 181 | 182 | # Subtract the image temporal average for numerical stability 183 | image_temp = image_combined.clone() 184 | image_mean = torch.stack(t * [torch.mean(image_temp, dim=1)], dim=1) 185 | x = image_combined - image_mean 186 | 187 | if self.dynamic_type == 'XF': 188 | # Apply temporal FFT 189 | x = x.permute(0,2,3,1,4) # b,h,w,t,2 190 | x = rec.utils.fft1c(x) 191 | x = x.permute(0,3,1,2,4) # b,t,h,w,2 192 | 193 | # Reshape to xf, yf planes 194 | xf = x.clone().permute(0,2,4,3,1).view(b*h, 2, w, t) 195 | yf = x.clone().permute(0,3,4,2,1).view(b*w, 2, h, t) 196 | 197 | # UNet opearting on temporal transformed xf, yf-domain 198 | if self.weight_sharing: 199 | xf = self.model(xf) 200 | yf = self.model(yf) 201 | else: 202 | model_xf, model_yf = self.model 203 | xf = model_xf(xf) 204 | yf = model_yf(yf) 205 | 206 | # Reshape from xf, yf 207 | xf_r = xf.view(b,h,1,2,w,t).permute(0,5,2,1,4,3) # b,t,1,h,w,2 208 | yf_r = yf.view(b,w,1,2,h,t).permute(0,5,2,4,1,3) # b,t,1,h,w,2 209 | 210 | out = 0.5 * (xf_r + yf_r) 211 | 212 | if self.dynamic_type == 'XF': 213 | # Apply temporal IFFT 214 | out = out.permute(0,2,3,4,1,5) # b,1,h,w,t,2 215 | out = rec.utils.ifft1c(out) 216 | out = out.permute(0,4,1,2,3,5) # b,t,1,h,w,2 217 | 218 | # Residual connection 219 | return out + image_mean.unsqueeze(2) 220 | 221 | 222 | def forward( 223 | self, 224 | image_pred: torch.Tensor, 225 | image_ref: torch.Tensor, 226 | mask: torch.Tensor, 227 | sens_maps: torch.Tensor, 228 | ) -> torch.Tensor: 229 | 230 | # Prepare image for input to U-Net 231 | b, t, c, h, w, ch = image_pred.shape # c=1 (coil-combined image) 232 | 233 | if self.dynamic_type in ['XF', 'XT']: 234 | model_out = self.xfyf_transform(image_pred.squeeze(2)) 235 | 236 | if self.dynamic_type == '2D': 237 | # Batch dimension b=1. Make first dimension time so 238 | # that each slice is trained independently. This is 239 | # similar to static MRI reconstruction. 240 | 241 | # Input to model has shape (t, ch, h, w) 242 | image_in = image_pred.permute(0,1,2,5,3,4).reshape(b*t, c*ch, h, w) 243 | model_out = self.model(image_in).reshape(b, t, c, ch, h, w, 244 | ).permute(0,1,2,4,5,3) 245 | 246 | if self.dynamic_type == '3D': 247 | # In this mode the whole spatio-temporal volume is 248 | # processed by a 3D U-Net at once. 
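            # Shape bookkeeping (comment added for clarity): image_pred is
            # (b, t, 1, h, w, 2); the permute/reshape below moves the
            # real/imaginary axis into the channel position expected by the
            # 3-D U-Net, giving (b, 2, t, h, w), and the inverse permute
            # restores (b, t, 1, h, w, 2) before the conjugate-gradient step.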
249 | 250 | # Input to model has shape (b, ch, t, h, w) 251 | image_in = image_pred.permute(0,5,2,1,3,4).reshape(b, ch*c, t, h, w) 252 | model_out = self.model(image_in).reshape(b, ch, c, t, h, w, 253 | ).permute(0,3,2,4,5,1) 254 | 255 | return self.ConjGrad( 256 | model_out, image_ref + self.Softplus(self.lambda_reg) * model_out, mask, sens_maps, self.CG_iters 257 | ) 258 | 259 | -------------------------------------------------------------------------------- /traintest_scripts/xpdnet/train_test_xpdnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/path/to/source') 3 | 4 | import os 5 | import pathlib 6 | import time 7 | from argparse import ArgumentParser 8 | 9 | import numpy as np 10 | import torch 11 | import pytorch_lightning as pl 12 | 13 | from reconstruction.data import SliceDataset 14 | from reconstruction.data.mri_data import fetch_dir 15 | from reconstruction.data.subsample import create_mask_for_mask_type 16 | from reconstruction.data.transforms import XPDNetDataTransform 17 | from reconstruction.pl_modules import MriDataModule, XPDNetModule 18 | from traintest_scripts.run_inference import InferenceTransform 19 | 20 | 21 | 22 | def train_test_main(args, save_path): 23 | pl.seed_everything(args.seed) 24 | 25 | # ------------ 26 | # DATA SECTION 27 | # ------------ 28 | 29 | # This creates a k-space mask to subsample input data 30 | mask = create_mask_for_mask_type( 31 | args.mask_type, args.center_fractions, args.accelerations 32 | ) 33 | 34 | train_transform = XPDNetDataTransform(mask_func=mask, use_seed=False) 35 | val_transform = XPDNetDataTransform(mask_func=mask, use_seed=False) 36 | test_transform = XPDNetDataTransform(mask_func=mask, use_seed=False) 37 | 38 | # Data module - this handles data loaders 39 | data_module = MriDataModule( 40 | data_path=args.data_path, 41 | train_transform=train_transform, 42 | val_transform=val_transform, 43 | test_transform=test_transform, 44 | combine_train_val=args.combine_train_val, 45 | test_split=args.test_split, 46 | test_path=args.test_path, 47 | sample_rate=args.sample_rate, 48 | use_dataset_cache_file=args.use_dataset_cache_file, 49 | batch_size=args.batch_size, 50 | num_workers=args.num_workers, 51 | distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")), 52 | ) 53 | 54 | # ------------- 55 | # MODEL SECTION 56 | # ------------- 57 | 58 | # Load model state dictionary (generally for testing) 59 | if args.load_model: 60 | checkpoint_dir = args.default_root_dir / "checkpoints" 61 | ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) 62 | if ckpt_list: 63 | last_ckpt_path = str(ckpt_list[-1]) 64 | print(f"Loading model from {last_ckpt_path}") 65 | model = XPDNetModule.load_from_checkpoint(last_ckpt_path) 66 | else: 67 | raise ValueError("No checkpoint available") 68 | 69 | else: 70 | # Build model 71 | model = XPDNetModule( 72 | num_cascades=args.num_cascades, 73 | sens_chans=args.sens_chans, 74 | sens_pools=args.sens_pools, 75 | crnn_chans=args.crnn_chans, 76 | n_scales=args.n_scales, 77 | n_filters_per_scale=args.n_filters_per_scale, 78 | n_convs_per_scale=args.n_convs_per_scale, 79 | n_first_convs=args.n_first_convs, 80 | first_conv_n_filters=args.first_conv_n_filters, 81 | res=args.res, 82 | primal_only=args.primal_only, 83 | n_primal=args.n_primal, 84 | n_dual=args.n_dual, 85 | dynamic_type=args.dynamic_type, 86 | weight_sharing = args.weight_sharing, 87 | lr=args.lr, 88 | lr_step_size=args.lr_step_size, 89 | 
lr_gamma=args.lr_gamma, 90 | weight_decay=args.weight_decay, 91 | ) 92 | 93 | # ------------------ 94 | # TRAIN-TEST SECTION 95 | # ------------------ 96 | 97 | trainer = pl.Trainer.from_argparse_args(args) 98 | 99 | if args.mode == "train": 100 | 101 | print("Training XPDNet " 102 | f"{args.dynamic_type} with " 103 | f"{args.num_cascades} cascades for " 104 | f"{args.max_epochs} epochs.\nData is subsampled with a " 105 | f"{args.mask_type} mask, acceleration " 106 | f"{args.accelerations[0]}." 107 | ) 108 | 109 | start_time = time.perf_counter() 110 | trainer.fit(model, datamodule=data_module) 111 | end_time = time.perf_counter() 112 | 113 | print(f"Training time: {(end_time-start_time) / 3600.} hours") 114 | 115 | if args.save_checkpoint: 116 | trainer.save_checkpoint(args.default_root_dir / f"checkpoints/xpdnet.ckpt") 117 | print(f"Saving checkpoint in xpdnet_{args.dynamic_type}_acc{args.accelerations[0]}_ckpt") 118 | 119 | elif args.mode == "test": 120 | trainer.test(model, datamodule=data_module) 121 | else: 122 | raise ValueError(f"unrecognized mode {args.mode}") 123 | 124 | # ----------------- 125 | # INFERENCE SECTION 126 | # ----------------- 127 | 128 | if (args.mode == "test" and args.inference): 129 | 130 | inference_dataset = SliceDataset( 131 | root=args.data_path / "inference", transform=test_transform, 132 | ) 133 | dataloader = torch.utils.data.DataLoader(inference_dataset, num_workers=2) 134 | inf_transform = InferenceTransform(model, 'xpdnet', save_path) 135 | time_for_inference = 0 136 | 137 | print('Starting inference..............') 138 | 139 | for batch in dataloader: 140 | with torch.no_grad(): 141 | masked_kspace, mask, target, fname, _, _, _ = batch 142 | time_for_inference += inf_transform(masked_kspace, mask, target, fname) 143 | 144 | print(f"Elapsed time: {time_for_inference} seconds.") 145 | 146 | 147 | 148 | 149 | def build_args(): 150 | parser = ArgumentParser() 151 | 152 | # ---------- 153 | # BASIC ARGS 154 | # ---------- 155 | path_config = pathlib.Path("/root/traintest_scripts/dirs_path.yaml") 156 | backend = "dp" 157 | num_gpus = 2 if backend == "ddp" else 1 158 | batch_size = 1 159 | 160 | # Set defaults based on optional directory config 161 | data_path = fetch_dir("data_path", path_config) 162 | save_path = fetch_dir("save_path", path_config) 163 | default_root_dir = fetch_dir("log_path", path_config) / "xpdnet/xpdnet_logs" 164 | 165 | # ----------- 166 | # CLIENT ARGS 167 | # ----------- 168 | parser.add_argument( 169 | "--mode", 170 | default="train", 171 | choices=("train", "test"), 172 | type=str, 173 | help="Operation mode", 174 | ) 175 | 176 | parser.add_argument( 177 | "--epochs", 178 | default=150, 179 | type=int, 180 | help="Total number of epochs to train the model for", 181 | ) 182 | 183 | parser.add_argument( 184 | "--save_checkpoint", 185 | default=0, 186 | choices=(0, 1), 187 | type=int, 188 | help="Whether to save a checkpoint of the model at the end of training", 189 | ) 190 | 191 | parser.add_argument( 192 | "--resume_training", 193 | default=0, 194 | choices=(0, 1), 195 | type=int, 196 | help="Whether to resume training from the latest checkpoint", 197 | ) 198 | 199 | parser.add_argument( 200 | "--load_model", 201 | default=0, 202 | choices=(0, 1), 203 | type=int, 204 | help="Whether to load the latest model in checkpoint dir, to be used for testing", 205 | ) 206 | 207 | parser.add_argument( 208 | "--inference", 209 | default=1, 210 | choices=(0, 1), 211 | type=int, 212 | help="Whether to generate and save the reconstruction made 
by the trained model on an inference dataset", 213 | ) 214 | 215 | # Data transform params 216 | parser.add_argument( 217 | "--mask_type", 218 | choices=("random", "equispaced"), 219 | default="random", 220 | type=str, 221 | help="Type of k-space mask", 222 | ) 223 | parser.add_argument( 224 | "--center_fractions", 225 | nargs="+", 226 | default=[10], 227 | type=float, 228 | help="Number of central lines to use in mask", 229 | ) 230 | parser.add_argument( 231 | "--accelerations", 232 | nargs="+", 233 | default=[4], 234 | type=int, 235 | help="Acceleration rates to use for masks", 236 | ) 237 | 238 | 239 | # -------------- 240 | # MODULES CONFIG 241 | # -------------- 242 | 243 | # Data config 244 | parser = MriDataModule.add_data_specific_args(parser) 245 | parser.set_defaults( 246 | data_path=data_path, 247 | test_path=None, 248 | test_split="test", 249 | sample_rate=None, 250 | use_dataset_cache_file=True, 251 | combine_train_val=False, 252 | batch_size=batch_size, 253 | num_workers=4, 254 | ) 255 | 256 | # Model config 257 | parser = XPDNetModule.add_model_specific_args(parser) 258 | parser.set_defaults( 259 | num_cascades=9, 260 | sens_chans=8, 261 | sens_pools=3, 262 | crnn_chans=18, 263 | n_scales=3, 264 | n_filters_per_scale=[16, 32, 64], 265 | n_convs_per_scale=[2, 2, 2], 266 | n_first_convs=1, 267 | first_conv_n_filters=16, 268 | res=False, 269 | primal_only=True, 270 | n_primal=5, 271 | n_dual=1, 272 | dynamic_type='XF', 273 | weight_sharing=False, 274 | lr=0.0001, 275 | lr_step_size=140, 276 | lr_gamma=0.01, 277 | weight_decay=0.0, 278 | ) 279 | 280 | args = parser.parse_args() 281 | 282 | # Configure checkpointing in checkpoint_dir 283 | checkpoint_dir = default_root_dir / "checkpoints" 284 | if not checkpoint_dir.exists(): 285 | checkpoint_dir.mkdir(parents=True) 286 | 287 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 288 | dirpath=default_root_dir / "checkpoints", 289 | filename=f"xpdnet_{args.dynamic_type}_acc{args.accelerations[0]}_ckpt", 290 | verbose=True, 291 | monitor="validation_loss", 292 | mode="min", 293 | ) 294 | 295 | resume_from_checkpoint_path = None 296 | if args.resume_training: 297 | ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) 298 | if ckpt_list: 299 | resume_from_checkpoint_path = str(ckpt_list[-1]) 300 | 301 | # Configure trainer options 302 | parser = pl.Trainer.add_argparse_args(parser) 303 | parser.set_defaults( 304 | gpus=num_gpus, # number of gpus to use 305 | replace_sampler_ddp=False, # this is necessary for volume dispatch during val 306 | accelerator=backend, # what distributed version to use 307 | seed=42, # random seed 308 | deterministic=True, # makes things slower, but deterministic 309 | default_root_dir=default_root_dir, # directory for logs and checkpoints 310 | max_epochs=args.epochs, 311 | callbacks=[checkpoint_callback], 312 | resume_from_checkpoint=resume_from_checkpoint_path, 313 | ) 314 | 315 | args = parser.parse_args() 316 | 317 | return args, save_path 318 | 319 | 320 | 321 | def run_main(): 322 | args, save_path = build_args() 323 | train_test_main(args, save_path) 324 | 325 | 326 | if __name__ == "__main__": 327 | run_main() 328 | -------------------------------------------------------------------------------- /reconstruction/models/varnet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import 
reconstruction as rec 9 | from reconstruction.data import transforms 10 | from .denoisers.norm_unet import NormUnet, NormUnet3D 11 | 12 | 13 | 14 | class SensitivityModel(nn.Module): 15 | """ 16 | Model for learning sensitivity estimation from k-space data. 17 | 18 | This model applies an IFFT to multichannel k-space data and then a U-Net 19 | to the coil images to estimate coil sensitivities. It can be used with the 20 | end-to-end variational network. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | chans: int, 26 | num_pools: int, 27 | in_chans: int = 2, 28 | out_chans: int = 2, 29 | drop_prob: float = 0.0, 30 | ): 31 | """ 32 | Args: 33 | chans: Number of output channels of the first convolution layer. 34 | num_pools: Number of down-sampling and up-sampling layers. 35 | in_chans: Number of channels in the input to the U-Net model. 36 | out_chans: Number of channels in the output to the U-Net model. 37 | drop_prob: Dropout probability. 38 | """ 39 | super().__init__() 40 | 41 | self.norm_unet = NormUnet( 42 | chans, 43 | num_pools, 44 | in_chans=in_chans, 45 | out_chans=out_chans, 46 | drop_prob=drop_prob, 47 | ) 48 | 49 | def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: 50 | b, c, h, w, comp = x.shape 51 | return x.view(b * c, 1, h, w, comp), b 52 | 53 | def batch_chans_to_chan_dim(self, x: torch.Tensor, batch_size: int) -> torch.Tensor: 54 | bc, _, h, w, comp = x.shape 55 | c = bc // batch_size 56 | return x.view(batch_size, c, h, w, comp) 57 | 58 | def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor: 59 | return x / rec.utils.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1) 60 | 61 | 62 | def forward(self, masked_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 63 | # Get low frequency line locations to mask them out 64 | cent = mask.shape[-3] // 2 65 | left = torch.nonzero(mask[:,0,:].squeeze()[:cent] == 0)[-1] 66 | right = torch.nonzero(mask[:,0,:].squeeze()[cent:] == 0)[0] + cent 67 | num_low_freqs = right - left 68 | pad = (mask.shape[-3] - num_low_freqs + 1) // 2 69 | 70 | # Time-averaged k-space 71 | x = transforms.mask_center(torch.mean(masked_kspace, 1), pad, pad + num_low_freqs) 72 | 73 | # Convert to image space 74 | x = rec.utils.ifft2c(x) 75 | 76 | # Since batch=1, change batch dim to coil dim 77 | # to deal with each coil independently 78 | x, b = self.chans_to_batch_dim(x) 79 | 80 | # Estimate sensitivities 81 | x = self.norm_unet(x) 82 | x = self.batch_chans_to_chan_dim(x, b) 83 | x = self.divide_root_sum_of_squares(x) 84 | 85 | x = x.unsqueeze(1) 86 | return x 87 | 88 | 89 | 90 | 91 | class VarNet(nn.Module): 92 | """ 93 | An adaptation of the end-to-end variational network model for dynamic 94 | MRI reconstruction. Reference paper: 95 | 96 | `A. Sriram et al. "End-to-end variational networks for accelerated MRI 97 | reconstruction". In International Conference on Medical Image Computing and 98 | Computer-Assisted Intervention, 2020`. 99 | """ 100 | 101 | def __init__( 102 | self, 103 | num_cascades: int = 12, 104 | sens_chans: int = 8, 105 | sens_pools: int = 4, 106 | chans: int = 18, 107 | pools: int = 4, 108 | dynamic_type: str = 'XF', 109 | weight_sharing: bool = False, 110 | ): 111 | """ 112 | Args: 113 | num_cascades: Number of cascades (i.e., layers) for variational 114 | network. 115 | sens_chans: Number of channels for sensitivity map U-Net. 116 | sens_pools Number of downsampling and upsampling layers for 117 | sensitivity map U-Net. 118 | chans: Number of channels for cascade U-Net. 
119 | pools: Number of downsampling and upsampling layers for cascade U-Net. 120 | dynamic_type: Type of architecture adjustment for dynamic setting. 121 | weight_sharing: Optional setting in 'XF' or 'XT' dynamics mode, allowing 122 | U-Net to share the same parameters in both x-f and y-f planes. 123 | """ 124 | super().__init__() 125 | 126 | self.sens_net = SensitivityModel(sens_chans, sens_pools) 127 | 128 | if dynamic_type in ['XF', 'XT']: 129 | if weight_sharing: 130 | self.model = NormUnet(chans, pools) 131 | else: 132 | self.model = nn.ModuleList([NormUnet(chans, pools), NormUnet(chans, pools)]) 133 | elif dynamic_type == '3D': 134 | self.model = NormUnet3D(chans, pools) 135 | else: 136 | self.model = NormUnet(chans, pools) 137 | 138 | self.cascades = nn.ModuleList( 139 | [VarNetBlock(self.model, dynamic_type, weight_sharing) for _ in range(num_cascades)] 140 | ) 141 | 142 | 143 | def forward(self, masked_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 144 | sens_maps = self.sens_net(masked_kspace, mask) 145 | kspace_pred = masked_kspace.clone() 146 | 147 | for cascade in self.cascades: 148 | kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps) 149 | 150 | return rec.utils.complex_abs(rec.utils.complex_mul(rec.utils.ifft2c(kspace_pred), 151 | rec.utils.complex_conj(sens_maps)).sum(dim=2, keepdim=False)) 152 | 153 | 154 | class VarNetBlock(nn.Module): 155 | """ 156 | Model block for time dynamics-adjusted end-to-end variational network. 157 | A series of these blocks can be stacked to form the full variational network. 158 | """ 159 | 160 | def __init__(self, model: nn.Module, dynamic_type: str, weight_sharing: bool,): 161 | """ 162 | Args: 163 | model: Module for "regularization" component of variational 164 | network. Its architecture depends on the specfic dynamics mode. 165 | dynamic_type: Type of architecture adjustment for dynamic setting. 166 | weight_sharing: Optional setting in 'XF' or 'XT' dynamics mode, allowing 167 | U-Net to share the same parameters in both x-f and y-f planes. 168 | """ 169 | super().__init__() 170 | 171 | self.model = model 172 | self.dynamic_type = dynamic_type 173 | self.weight_sharing = weight_sharing 174 | 175 | # Regularisation parameter is learned during training 176 | self.Softplus = nn.Softplus(1.) 177 | lambda_init = np.log(np.exp(1)-1.)/1. 178 | self.lambda_reg = nn.Parameter(torch.tensor(lambda_init*torch.ones(1),dtype=torch.float), 179 | requires_grad=True) 180 | 181 | def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 182 | """ 183 | Forward operator: from coil-combined image-space to k-space. 184 | """ 185 | return rec.utils.fft2c(rec.utils.complex_mul(x, sens_maps)) 186 | 187 | def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 188 | """ 189 | Backward operator: from k-space to coil-combined image-space. 190 | """ 191 | x = rec.utils.ifft2c(x) 192 | return rec.utils.complex_mul(x, rec.utils.complex_conj(sens_maps)).sum( 193 | dim=2, keepdim=True, 194 | ) 195 | 196 | def xfyf_transform(self, image_combined: torch.Tensor) -> torch.Tensor: 197 | """ 198 | Separate input into two volumes in the rotated planes x-f and y-f 199 | (or x-t, y-t if in 'XT' dynamics mode). After being processed by 200 | their respective U-Nets, the volumes are then combined back into one. 
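        Note (added for clarity): this mirrors CineNetBlock.xfyf_transform,
        but here the slices are reshaped to (b*h, 1, w, t, 2) and
        (b*w, 1, h, t, 2), keeping the real/imaginary pair in the trailing
        dimension, because NormUnet consumes complex-valued tensors in the
        last-dimension-of-size-2 format rather than as two real channels.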
201 | """ 202 | b, t, h, w, ch = image_combined.shape 203 | 204 | # Subtract the image temporal average for numerical stability 205 | image_temp = image_combined.clone() 206 | image_mean = torch.stack(t * [torch.mean(image_temp, dim=1)], dim=1) 207 | x = image_combined - image_mean 208 | 209 | if self.dynamic_type == 'XF': 210 | # Apply temporal FFT 211 | x = x.permute(0,2,3,1,4) # b,h,w,t,2 212 | x = rec.utils.fft1c(x) 213 | x = x.permute(0,3,1,2,4) # b,t,h,w,2 214 | 215 | # Reshape to xf, yf planes 216 | xf = x.clone().permute(0,2,3,1,4).view(b*h, 1, w, t, 2) 217 | yf = x.clone().permute(0,3,2,1,4).view(b*w, 1, h, t, 2) 218 | 219 | # UNet opearting on temporal transformed xf, yf-domain 220 | if self.weight_sharing: 221 | xf = self.model(xf) 222 | yf = self.model(yf) 223 | else: 224 | model_xf, model_yf = self.model 225 | xf = model_xf(xf) 226 | yf = model_yf(yf) 227 | 228 | # Reshape from xf, yf 229 | xf_r = xf.view(b,h,1,w,t,2).permute(0,4,2,1,3,5) # b,t,1,h,w,2 230 | yf_r = yf.view(b,w,1,h,t,2).permute(0,4,2,3,1,5) # b,t,1,h,w,2 231 | 232 | out = 0.5 * (xf_r + yf_r) 233 | 234 | if self.dynamic_type == 'XF': 235 | # Apply temporal IFFT 236 | out = out.permute(0,2,3,4,1,5) # b,1,h,w,t,2 237 | out = rec.utils.ifft1c(out) 238 | out = out.permute(0,4,1,2,3,5) # b,t,1,h,w,2 239 | 240 | # Residual connection 241 | return out + image_mean.unsqueeze(2) 242 | 243 | 244 | def forward( 245 | self, 246 | current_kspace: torch.Tensor, 247 | ref_kspace: torch.Tensor, 248 | mask: torch.Tensor, 249 | sens_maps: torch.Tensor, 250 | ) -> torch.Tensor: 251 | 252 | # current_kspace: 6d tensor of shape (b, t, c, h, w, ch) 253 | image_combined = self.sens_reduce(current_kspace, sens_maps) 254 | 255 | if self.dynamic_type in ['XF', 'XT']: 256 | model_out = self.xfyf_transform(image_combined.squeeze(2)) 257 | model_term = self.sens_expand(model_out, sens_maps) 258 | 259 | if self.dynamic_type == '2D': 260 | # Batch dimension b=1. Make first dimension time so 261 | # that each slice is trained independently. This is 262 | # similar to static MRI reconstruction. 263 | 264 | # Input to model has shape (t, 1, h, w, ch) 265 | model_out = self.model(image_combined.squeeze(0)) 266 | model_term = self.sens_expand( 267 | model_out.unsqueeze(0), sens_maps # Add back batch dimension 268 | ) 269 | 270 | if self.dynamic_type == '3D': 271 | # In this mode the whole spatio-temporal volume is 272 | # processed by a 3D U-Net at once. 273 | 274 | # Input to model has shape (b, 1, t, h, w, ch) 275 | model_out = self.model(image_combined.permute(0,2,1,3,4,5)) 276 | model_term = self.sens_expand( 277 | model_out.permute(0,2,1,3,4,5), sens_maps 278 | ) 279 | 280 | # Data consistency step 281 | v = self.Softplus(self.lambda_reg) 282 | return (1 - mask) * model_term + mask * (model_term + v * ref_kspace) / (1 + v) 283 | -------------------------------------------------------------------------------- /reconstruction/models/recurrent_cinenet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import reconstruction as rec 10 | 11 | 12 | class CineNet_RNN(nn.Module): 13 | """ 14 | A hybrid model for Dynamic MRI Reconstruction, inspired by combining 15 | CineNet [1] and Recurrent Convolutional Neural Networks (RCNN) [2]. 16 | 17 | Reference papers: 18 | [1] A. Kofler et al. 
`An end-to-end-trainable iterative network architecture 19 | for accelerated radial multi-coil 2D cine MR image reconstruction.` 20 | In Medical Physics, 2021. 21 | [2] C. Qin et al. `Convolutional Recurrent Neural Networks for Dynamic MR 22 | Image Reconstruction`. In IEEE Transactions on Medical Imaging 38.1, 23 | pp. 280–290, 2019. 24 | """ 25 | def __init__( 26 | self, 27 | num_cascades: int = 10, 28 | CG_iters: int = 4, 29 | chans: int = 64, 30 | ): 31 | """ 32 | Args: 33 | num_cascades: Number of alternations between CG and RCNN modules. 34 | CG_iters: Number of CG iterations in the CG module. 35 | chans: Number of channels for convolutional layers of the RCNN. 36 | """ 37 | super(CineNet_RNN, self).__init__() 38 | 39 | self.num_cascades = num_cascades 40 | self.CG_iters = CG_iters 41 | self.chans = chans 42 | 43 | self.bcrnn = BCRNNlayer(input_size=2, hidden_size=self.chans, kernel_size=3) 44 | self.conv1_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 45 | self.conv1_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 46 | self.conv2_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 47 | self.conv2_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 48 | self.conv3_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 49 | self.conv3_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 50 | self.conv4_x = nn.Conv2d(self.chans, 2, 3, padding = 3//2) 51 | self.relu = nn.ReLU(inplace=True) 52 | 53 | self.Softplus = nn.Softplus(1.) 54 | lambda_init = np.log(np.exp(1)-1.)/1. 55 | self.lambda_reg = nn.Parameter(torch.tensor(lambda_init*torch.ones(1),dtype=torch.float), 56 | requires_grad=True) 57 | 58 | 59 | def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 60 | """ 61 | Forward operator: from coil-combined image-space to k-space. 62 | """ 63 | return rec.utils.fft2c(rec.utils.complex_mul(x, sens_maps)) 64 | 65 | def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 66 | """ 67 | Backward operator: from k-space to coil-combined image-space. 68 | """ 69 | x = rec.utils.ifft2c(x) 70 | return rec.utils.complex_mul(x, rec.utils.complex_conj(sens_maps)).sum( 71 | dim=2, keepdim=True, 72 | ) 73 | 74 | def HOperator(self, x: torch.Tensor, mask: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 75 | """ 76 | The operator H = A^H \circ A + \lambda_Reg * \Id, where A is the encoding matrix. 77 | This ensures data consistency. 
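        (Note added for clarity: this operator and the ConjGrad solver below
        are the same data-consistency machinery as in CineNetBlock; the
        difference in this model is that the per-cascade U-Net denoiser is
        replaced by a bidirectional convolutional RNN whose hidden state is
        propagated across both time frames and unrolled iterations.)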
78 | """ 79 | # Forward operator 80 | k_coils = self.sens_expand(x, sens_maps) 81 | # Apply sampling mask 82 | k_masked = k_coils * mask + 0.0 83 | # Backward operator 84 | x_combined = self.sens_reduce(k_masked, sens_maps) 85 | # Result of H(x) 86 | return x_combined + self.Softplus(self.lambda_reg) * x 87 | 88 | 89 | def ConjGrad(self, x:torch.Tensor, b:torch.Tensor, mask:torch.Tensor, sens_maps:torch.Tensor, CG_iters:int)-> torch.Tensor: 90 | """ 91 | Conjugate Gradient method for solving the system Hx = b 92 | """ 93 | # x is the starting value, b the rhs 94 | r = self.HOperator(x, mask, sens_maps) 95 | r = b-r 96 | 97 | # Initialize p 98 | p = r.clone() 99 | 100 | # Old squared norm of residual 101 | sqnorm_r_old = torch.dot(r.flatten(), r.flatten()) 102 | 103 | for kiter in range(CG_iters): 104 | # Calculate H(p) 105 | d = self.HOperator(p, mask, sens_maps) 106 | 107 | # Calculate step size alpha; 108 | inner_p_d = torch.dot(p.flatten(), d.flatten()) 109 | alpha = sqnorm_r_old / inner_p_d 110 | 111 | # Perform step and calculate new residual 112 | x = torch.add(x, p, alpha = alpha.item()) 113 | r = torch.add(r, d, alpha = -alpha.item()) 114 | 115 | # New residual norm 116 | sqnorm_r_new = torch.dot(r.flatten(), r.flatten()) 117 | 118 | # Calculate beta and update the norm 119 | beta = sqnorm_r_new / sqnorm_r_old 120 | sqnorm_r_old = sqnorm_r_new 121 | 122 | p = torch.add(r, p, alpha = beta.item()) 123 | 124 | return x 125 | 126 | 127 | def forward(self, ref_kspace: torch.Tensor, mask: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 128 | """ 129 | Args: 130 | ref_kspace, mask, sens_maps: tensors of shape `(b, t, c, h, w, ch)`. 131 | 132 | Returns: 133 | Output tensor of shape `(b, t, h, w)`. 134 | """ 135 | 136 | x_ref = self.sens_reduce(ref_kspace, sens_maps) 137 | 138 | x = x_ref.clone().squeeze(2).permute(0,4,2,3,1) 139 | b, ch, h, w, t = x.size() 140 | size_h = [t*b, self.chans, h, w] 141 | 142 | # Initialise parameters of rcnn layers at the first iteration to zero 143 | net = {} 144 | rcnn_layers = 5 145 | for j in range(rcnn_layers-1): 146 | net['t0_x%d'%j] = Variable(torch.zeros(size_h)).cuda() 147 | 148 | # Recurrence through iterations 149 | for i in range(1, self.num_cascades + 1): 150 | 151 | x = x.permute(4,0,1,2,3) 152 | x = x.contiguous() 153 | 154 | net['t%d_x0' % (i-1)] = net['t%d_x0' % (i-1)].view(t, b, self.chans, h, w) 155 | net['t%d_x0'%i] = self.bcrnn(x, net['t%d_x0'%(i-1)]) 156 | net['t%d_x0'%i] = net['t%d_x0'%i].view(-1, self.chans, h, w) 157 | 158 | net['t%d_x1'%i] = self.conv1_x(net['t%d_x0'%i]) 159 | net['t%d_h1'%i] = self.conv1_h(net['t%d_x1'%(i-1)]) 160 | net['t%d_x1'%i] = self.relu(net['t%d_h1'%i] + net['t%d_x1'%i]) 161 | 162 | net['t%d_x2'%i] = self.conv2_x(net['t%d_x1'%i]) 163 | net['t%d_h2'%i] = self.conv2_h(net['t%d_x2'%(i-1)]) 164 | net['t%d_x2'%i] = self.relu(net['t%d_h2'%i] + net['t%d_x2'%i]) 165 | 166 | net['t%d_x3'%i] = self.conv3_x(net['t%d_x2'%i]) 167 | net['t%d_h3'%i] = self.conv3_h(net['t%d_x3'%(i-1)]) 168 | net['t%d_x3'%i] = self.relu(net['t%d_h3'%i] + net['t%d_x3'%i]) 169 | 170 | net['t%d_x4'%i] = self.conv4_x(net['t%d_x3'%i]) 171 | 172 | x = x.view(-1, ch, h, w) 173 | net['t%d_out'%i] = x + net['t%d_x4'%i] 174 | 175 | net['t%d_out'%i] = net['t%d_out'%i].view(-1, b, ch, h, w) 176 | net['t%d_out'%i] = net['t%d_out'%i].permute(1,0,3,4,2).unsqueeze(2) 177 | net['t%d_out'%i].contiguous() 178 | 179 | net['t%d_out'%i] = self.ConjGrad( 180 | net['t%d_out'%i], x_ref + self.Softplus(self.lambda_reg) * net['t%d_out'%i], mask, sens_maps, 
self.CG_iters 181 | ) 182 | net['t%d_out'%i] = net['t%d_out'%i].squeeze(2).permute(0,4,2,3,1) 183 | 184 | x = net['t%d_out'%i] 185 | 186 | out = net['t%d_out'%i] 187 | return rec.utils.complex_abs(out.permute(0,4,2,3,1)) 188 | 189 | 190 | 191 | class CRNNcell(nn.Module): 192 | """ 193 | Convolutional RNN cell that evolves over both time and iterations. 194 | """ 195 | def __init__( 196 | self, 197 | input_size: int, 198 | hidden_size: int, 199 | kernel_size: int, 200 | ): 201 | """ 202 | Args: 203 | input_size: Number of input channels 204 | hidden_size: Number of RCNN hidden layers channels 205 | kernel_size: Size of convolutional kernel 206 | """ 207 | super(CRNNcell, self).__init__() 208 | 209 | # Convolution for input 210 | self.i2h = nn.Conv2d(input_size, hidden_size, kernel_size, padding=kernel_size // 2) 211 | # Convolution for hidden states in temporal dimension 212 | self.h2h = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2) 213 | # Convolution for hidden states in iteration dimension 214 | self.ih2ih = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2) 215 | 216 | self.relu = nn.ReLU(inplace=True) 217 | 218 | def forward( 219 | self, 220 | input: torch.Tensor, 221 | hidden_iteration: torch.Tensor, 222 | hidden: torch.Tensor, 223 | ) -> torch.Tensor: 224 | """ 225 | Args: 226 | input: Input 4D tensor of shape `(b, ch, h, w)` 227 | hidden_iteration: hidden states in iteration dimension, 4d tensor of shape (b, hidden_size, h, w) 228 | hidden: hidden states in temporal dimension, 4d tensor of shape (b, hidden_size, h, w) 229 | Returns: 230 | Output tensor of shape `(b, hidden_size, h, w)`. 231 | """ 232 | in_to_hid = self.i2h(input) 233 | hid_to_hid = self.h2h(hidden) 234 | ih_to_ih = self.ih2ih(hidden_iteration) 235 | 236 | hidden = self.relu(in_to_hid + hid_to_hid + ih_to_ih) 237 | 238 | return hidden 239 | 240 | 241 | class BCRNNlayer(nn.Module): 242 | """ 243 | Bidirectional Convolutional RNN layer 244 | """ 245 | def __init__( 246 | self, 247 | input_size: int, 248 | hidden_size: int, 249 | kernel_size: int, 250 | ): 251 | """ 252 | Args: 253 | input_size: Number of input channels 254 | hidden_size: Number of RCNN hidden layers channels 255 | kernel_size: Size of convolutional kernel 256 | """ 257 | super(BCRNNlayer, self).__init__() 258 | 259 | self.hidden_size = hidden_size 260 | self.CRNN_model = CRNNcell(input_size, self.hidden_size, kernel_size) 261 | 262 | def forward(self, input: torch.Tensor, hidden_iteration: torch.Tensor) -> torch.Tensor: 263 | """ 264 | Args: 265 | input: Input 5D tensor of shape `(t, b, ch, h, w)` 266 | hidden_iteration: hidden states (output of BCRNNlayer) from previous 267 | iteration, 5d tensor of shape (t, b, hidden_size, h, w) 268 | Returns: 269 | Output tensor of shape `(t, b, hidden_size, h, w)`. 
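        Example (illustrative sketch, not in the original file; the sizes are
        arbitrary and a GPU is assumed because hid_init is allocated with
        .cuda() below):

        >>> layer = BCRNNlayer(input_size=2, hidden_size=64, kernel_size=3).cuda()
        >>> x = torch.randn(25, 1, 2, 64, 64).cuda()       # (t, b, ch, h, w)
        >>> h_prev = torch.zeros(25, 1, 64, 64, 64).cuda() # hidden state from previous iteration
        >>> layer(x, h_prev).shape
        torch.Size([25, 1, 64, 64, 64])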
270 | """ 271 | t, b, ch, h, w = input.shape 272 | size_h = [b, self.hidden_size, h, w] 273 | 274 | hid_init = Variable(torch.zeros(size_h)).cuda() 275 | output_f = [] 276 | output_b = [] 277 | 278 | # forward 279 | hidden = hid_init 280 | for i in range(t): 281 | hidden = self.CRNN_model(input[i], hidden_iteration[i], hidden) 282 | output_f.append(hidden) 283 | output_f = torch.cat(output_f) 284 | 285 | # backward 286 | hidden = hid_init 287 | for i in range(t): 288 | hidden = self.CRNN_model(input[t - i - 1], hidden_iteration[t - i -1], hidden) 289 | output_b.append(hidden) 290 | output_b = torch.cat(output_b[::-1]) 291 | 292 | output = output_f + output_b 293 | 294 | if b == 1: 295 | output = output.view(t, 1, self.hidden_size, h, w) 296 | 297 | return output 298 | -------------------------------------------------------------------------------- /reconstruction/pl_modules/xpdnet_module.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from argparse import ArgumentParser 3 | import torch 4 | 5 | from reconstruction.data import transforms 6 | from reconstruction.utils import SSIMLoss 7 | from reconstruction.models import XPDNet, XPDNet_RNN 8 | from .mri_module import MriModule 9 | 10 | 11 | class XPDNetModule(MriModule): 12 | """ 13 | Pytorch Lightning module for training XPDNet. 14 | 15 | The architecture variations for dynamic MRI reconstruction are 16 | inspired by the XPDNet for static MRI reconstruction, introduced in 17 | the following paper: 18 | 19 | Z. Ramzi et al. "XPDNet for MRI Reconstruction: an application to the 20 | 2020 fastMRI challenge". arXiv: 2010.07290, 2021. 21 | """ 22 | def __init__( 23 | self, 24 | num_cascades: int = 12, 25 | sens_chans: int = 8, 26 | sens_pools: int = 4, 27 | crnn_chans: int = 18, 28 | n_scales: int = 3, 29 | n_filters_per_scale: List[int] = [16, 32, 64], 30 | n_convs_per_scale: List[int] = [2, 2, 2], 31 | n_first_convs: int = 1, 32 | first_conv_n_filters: int = 16, 33 | res: bool = False, 34 | primal_only: bool = True, 35 | n_primal: int = 5, 36 | n_dual: int = 1, 37 | dynamic_type: str = 'XF', 38 | weight_sharing: bool = False, 39 | lr: float = 0.0003, 40 | lr_step_size: int = 40, 41 | lr_gamma: float = 0.1, 42 | weight_decay: float = 0.0, 43 | **kwargs, 44 | ): 45 | """ 46 | Args: 47 | num_cascades: Number of unrolled iterations for XPDNet. 48 | sens_chans: Number of channels for sensitivity map U-Net. 49 | sens_pools Number of downsampling and upsampling layers for 50 | sensitivity map U-Net. 51 | crnn_chans: Hidden state size in CRNN XPDNet. 52 | n_scales: Number of scales, i.e. number of pooling layers, in 53 | image denoiser MWCNN. 54 | n_filters_per_scale: Number of filters used by the convolutional 55 | layers at each scale in image denoiser MWCNN. 56 | n_convs_per_scale: Number of convolutional layers per scale in 57 | image denoiser MWCNN. 58 | n_first_convs: Number of convolutional layers at the start of 59 | the architecture, i.e. before pooling layers, in image denoiser 60 | MWCNN. 61 | first_conv_n_filters: Number of filters used by the inital 62 | convolutional layer in image denoiser MWCNN. 63 | res: Whether to use a residual connection between input and output in 64 | image denoiser MWCNN. 65 | primal_only: Whether to generate a buffer in k-space or only in image 66 | space. 67 | n_primal: The size of the buffer in image-space. 68 | n_dual: The size of the buffer in k-space. 69 | dynamic_type: Type of architecture adjustment for dynamic setting. 
70 | weight_sharing: Optional setting in 'XF' or 'XT' dynamics mode, allowing 71 | image net to share the same parameters in both x-f and y-f planes. 72 | lr: Learning rate. 73 | lr_step_size: Learning rate step size. 74 | lr_gamma: Learning rate gamma decay. 75 | weight_decay: Parameter for penalizing weights norm. 76 | """ 77 | super().__init__(**kwargs) 78 | self.save_hyperparameters() 79 | 80 | self.num_cascades = num_cascades 81 | self.sens_chans = sens_chans 82 | self.sens_pools = sens_pools 83 | self.crnn_chans = crnn_chans 84 | self.n_scales = n_scales 85 | self.n_filters_per_scale = n_filters_per_scale 86 | self.n_convs_per_scale = n_convs_per_scale 87 | self.n_first_convs = n_first_convs 88 | self.first_conv_n_filters = first_conv_n_filters 89 | self.res = res 90 | self.primal_only = primal_only 91 | self.n_primal = n_primal 92 | self.n_dual = n_dual 93 | self.dynamic_type = dynamic_type 94 | self.weight_sharing = weight_sharing 95 | self.lr = lr 96 | self.lr_step_size = lr_step_size 97 | self.lr_gamma = lr_gamma 98 | self.weight_decay = weight_decay 99 | 100 | assert self.dynamic_type in ['XF', 'XT', '2D', 'CRNN'], \ 101 | "dynamic_type argument must be one of 'XF', 'XT', '2D' or 'CRNN'" 102 | 103 | if self.dynamic_type == 'CRNN': 104 | self.xpdnet = XPDNet_RNN( 105 | num_cascades=self.num_cascades, 106 | sens_chans=self.sens_chans, 107 | sens_pools=self.sens_pools, 108 | chans=self.crnn_chans, 109 | primal_only=self.primal_only, 110 | n_primal=self.n_primal, 111 | n_dual=self.n_dual, 112 | ) 113 | else: 114 | self.xpdnet = XPDNet( 115 | num_cascades=self.num_cascades, 116 | sens_chans=self.sens_chans, 117 | sens_pools=self.sens_pools, 118 | n_scales=self.n_scales, 119 | n_filters_per_scale=self.n_filters_per_scale, 120 | n_convs_per_scale=self.n_convs_per_scale, 121 | n_first_convs=self.n_first_convs, 122 | first_conv_n_filters=self.first_conv_n_filters, 123 | res=self.res, 124 | primal_only=self.primal_only, 125 | n_primal=self.n_primal, 126 | n_dual=self.n_dual, 127 | dynamic_type=self.dynamic_type, 128 | weight_sharing = self.weight_sharing, 129 | ) 130 | 131 | self.loss = SSIMLoss() 132 | 133 | def forward(self, masked_kspace, mask): 134 | return self.xpdnet(masked_kspace, mask) 135 | 136 | def training_step(self, batch, batch_idx): 137 | masked_kspace, mask, target, fname, slice_num, max_value, _ = batch 138 | 139 | output = self(masked_kspace, mask) 140 | target, output = transforms.center_crop_to_smallest(target, output) 141 | 142 | return { 143 | "batch_idx": batch_idx, 144 | "fname": fname, 145 | "slice_num": slice_num, 146 | "max_value": max_value, 147 | "output": output, 148 | "target": target, 149 | "loss": self.loss( 150 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 151 | ), 152 | } 153 | 154 | def validation_step(self, batch, batch_idx): 155 | masked_kspace, mask, target, fname, slice_num, max_value, _ = batch 156 | 157 | output = self.forward(masked_kspace, mask) 158 | target, output = transforms.center_crop_to_smallest(target, output) 159 | 160 | return { 161 | "batch_idx": batch_idx, 162 | "fname": fname, 163 | "slice_num": slice_num, 164 | "max_value": max_value, 165 | "output": output, 166 | "target": target, 167 | "val_loss": self.loss( 168 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 169 | ), 170 | } 171 | 172 | def test_step(self, batch, batch_idx): 173 | masked_kspace, mask, target, fname, slice_num, max_value, _ = batch 174 | 175 | output = self(masked_kspace, mask) 176 | target, output = 
transforms.center_crop_to_smallest(target, output) 177 | 178 | return { 179 | "batch_idx": batch_idx, 180 | "fname": fname, 181 | "slice_num": slice_num, 182 | "max_value": max_value, 183 | "output": output, 184 | "target": target, 185 | "test_loss": self.loss( 186 | output.unsqueeze(1), target.unsqueeze(1), data_range=max_value 187 | ), 188 | } 189 | 190 | def configure_optimizers(self): 191 | optim = torch.optim.Adam( 192 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay 193 | ) 194 | scheduler = torch.optim.lr_scheduler.StepLR( 195 | optim, self.lr_step_size, self.lr_gamma 196 | ) 197 | 198 | return [optim], [scheduler] 199 | 200 | @staticmethod 201 | def add_model_specific_args(parent_parser): # pragma: no-cover 202 | """ 203 | Define parameters that only apply to this model 204 | """ 205 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 206 | parser = MriModule.add_model_specific_args(parser) 207 | 208 | # param overwrites 209 | 210 | # network params 211 | parser.add_argument( 212 | "--num_cascades", 213 | default=12, 214 | type=int, 215 | help="Number of XPDNet cascades", 216 | ) 217 | parser.add_argument( 218 | "--sens_chans", 219 | default=8, 220 | type=int, 221 | help="Number of channels for sense map estimation U-Net in XPDNet", 222 | ) 223 | parser.add_argument( 224 | "--sens_pools", 225 | default=4, 226 | type=int, 227 | help="Number of pooling layers for sense map estimation U-Net in XPDNet", 228 | ) 229 | parser.add_argument( 230 | "--crnn_chans", 231 | default=18, 232 | type=int, 233 | help="Hidden state size in CRNN XPDNet", 234 | ) 235 | parser.add_argument( 236 | "--n_scales", 237 | default=3, 238 | type=int, 239 | help="Number of scales, i.e. number of pooling layers, in image denoiser module", 240 | ) 241 | parser.add_argument( 242 | "--n_filters_per_scale", 243 | nargs="+", 244 | default=[16, 32, 64], 245 | type=int, 246 | help="""Number of filters used by the convolutional layers 247 | at each scale in image denoiser module""", 248 | ) 249 | parser.add_argument( 250 | "--n_convs_per_scale", 251 | nargs="+", 252 | default=[2, 2, 2], 253 | type=int, 254 | help="""Number of convolutional layers per scale in 255 | image denoiser module""", 256 | ) 257 | parser.add_argument( 258 | "--n_first_convs", 259 | default=1, 260 | type=int, 261 | help="""Number of convolutional layers at the start of the architecture, 262 | i.e. before pooling layers, in image denoiser module""", 263 | ) 264 | parser.add_argument( 265 | "--first_conv_n_filters", 266 | default=16, 267 | type=int, 268 | help="Number of filters in the inital convolutional layers", 269 | ) 270 | parser.add_argument( 271 | "--res", 272 | default=False, 273 | type=bool, 274 | help="Whether to use a residual connection in image denoising module", 275 | ) 276 | parser.add_argument( 277 | "--primal_only", 278 | default=True, 279 | type=bool, 280 | help="Whether to generate a buffer in k-space or only in image-space", 281 | ) 282 | parser.add_argument( 283 | "--n_primal", 284 | default=5, 285 | type=int, 286 | help="The size of the buffer in image-space", 287 | ) 288 | parser.add_argument( 289 | "--n_dual", 290 | default=1, 291 | type=int, 292 | help="The size of the buffer in k-space", 293 | ) 294 | parser.add_argument( 295 | "--dynamic_type", 296 | default='XF', 297 | type=str, 298 | help="""Architectural variation for dynamic reconstruction. 
299 | Options are ['XF', 'XT', '2D', 'CRNN']""", 300 | ) 301 | parser.add_argument( 302 | "--weight_sharing", 303 | default=False, 304 | type=bool, 305 | help="Allows parameter sharing of MWCNN nets in x-f, y-f planes", 306 | ) 307 | 308 | # training params (opt) 309 | parser.add_argument( 310 | "--lr", default=0.0003, type=float, help="Adam learning rate" 311 | ) 312 | parser.add_argument( 313 | "--lr_step_size", 314 | default=40, 315 | type=int, 316 | help="Epoch at which to decrease step size", 317 | ) 318 | parser.add_argument( 319 | "--lr_gamma", 320 | default=0.1, 321 | type=float, 322 | help="Extent to which step size should be decreased", 323 | ) 324 | parser.add_argument( 325 | "--weight_decay", 326 | default=0.0, 327 | type=float, 328 | help="Strength of weight decay regularization", 329 | ) 330 | 331 | return parser 332 | -------------------------------------------------------------------------------- /reconstruction/pl_modules/data_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This source code is based on the fastMRI repository from Facebook AI 3 | Research and is used as a general framework to handle MRI data. Link: 4 | 5 | https://github.com/facebookresearch/fastMRI 6 | """ 7 | 8 | from argparse import ArgumentParser 9 | from pathlib import Path 10 | from typing import Callable, Optional, Union 11 | 12 | import torch 13 | import pytorch_lightning as pl 14 | 15 | from reconstruction.data import CombinedSliceDataset, SliceDataset, VolumeSampler 16 | 17 | 18 | def worker_init_fn(worker_id): 19 | """Handle random seeding for all mask_func.""" 20 | worker_info = torch.utils.data.get_worker_info() 21 | data: Union[ 22 | SliceDataset, CombinedSliceDataset 23 | ] = worker_info.dataset # pylint: disable=no-member 24 | 25 | # Check if we are using DDP 26 | is_ddp = False 27 | if torch.distributed.is_available(): 28 | if torch.distributed.is_initialized(): 29 | is_ddp = True 30 | 31 | # for NumPy random seed we need it to be in this range 32 | base_seed = worker_info.seed # pylint: disable=no-member 33 | 34 | if isinstance(data, CombinedSliceDataset): 35 | for i, dataset in enumerate(data.datasets): 36 | if dataset.transform.mask_func is not None: 37 | if ( 38 | is_ddp 39 | ): # DDP training: unique seed is determined by worker, device, dataset 40 | seed_i = ( 41 | base_seed 42 | - worker_info.id 43 | + torch.distributed.get_rank() 44 | * (worker_info.num_workers * len(data.datasets)) 45 | + worker_info.id * len(data.datasets) 46 | + i 47 | ) 48 | else: 49 | seed_i = ( 50 | base_seed 51 | - worker_info.id 52 | + worker_info.id * len(data.datasets) 53 | + i 54 | ) 55 | dataset.transform.mask_func.rng.seed(seed_i % (2 ** 32 - 1)) 56 | elif data.transform.mask_func is not None: 57 | if is_ddp: # DDP training: unique seed is determined by worker and device 58 | seed = base_seed + torch.distributed.get_rank() * worker_info.num_workers 59 | else: 60 | seed = base_seed 61 | data.transform.mask_func.rng.seed(seed % (2 ** 32 - 1)) 62 | 63 | 64 | class MriDataModule(pl.LightningDataModule): 65 | """ 66 | Data module class for the MRI datset used in this project. 67 | 68 | This class handles configurations for training on MRI data. It is set 69 | up to process configurations independently of training modules. 70 | 71 | For training with ddp be sure to set distributed_sampler=True to make sure 72 | that volumes are dispatched to the same GPU for the validation loop. 
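    Example (illustrative sketch, not part of the original docstring; it
    mirrors the way the training scripts in traintest_scripts configure the
    module, and the paths are placeholders):

        from pathlib import Path
        from reconstruction.data.subsample import create_mask_for_mask_type
        from reconstruction.data.transforms import XPDNetDataTransform

        mask = create_mask_for_mask_type("random", [10], [4])
        transform = XPDNetDataTransform(mask_func=mask, use_seed=False)
        dm = MriDataModule(
            data_path=Path("/path/to/data"),   # expects train/valid/test subdirectories
            train_transform=transform,
            val_transform=transform,
            test_transform=transform,
            batch_size=1,
            num_workers=4,
        )
        train_loader = dm.train_dataloader()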
73 | """ 74 | 75 | def __init__( 76 | self, 77 | data_path: Path, 78 | train_transform: Callable, 79 | val_transform: Callable, 80 | test_transform: Callable, 81 | combine_train_val: bool = False, 82 | test_split: str = "test", 83 | test_path: Optional[Path] = None, 84 | sample_rate: Optional[float] = None, 85 | volume_sample_rate: Optional[float] = None, 86 | use_dataset_cache_file: bool = True, 87 | batch_size: int = 1, 88 | num_workers: int = 4, 89 | distributed_sampler: bool = False, 90 | ): 91 | """ 92 | Args: 93 | data_path: Path to root data directory (with expected subdirectories 94 | train/valid/test). 95 | train_transform: A transform object for the training split. 96 | val_transform: A transform object for the validation split. 97 | test_transform: A transform object for the test split. 98 | combine_train_val: Whether to combine train and val splits into one 99 | large train dataset. 100 | test_split: Name of test split from ("test", "challenge"). 101 | test_path: An optional test path. Passing this overwrites data_path 102 | and test_split. 103 | sample_rate: Fraction of slices of the training data split to use. Can be 104 | set to less than 1.0 for rapid prototyping. If not set, it defaults to 1.0. 105 | To subsample the dataset either set sample_rate (sample by slice) or 106 | volume_sample_rate (sample by volume), but not both. 107 | volume_sample_rate: Fraction of volumes of the training data split to use. Can be 108 | set to less than 1.0 for rapid prototyping. If not set, it defaults to 1.0. 109 | To subsample the dataset either set sample_rate (sample by slice) or 110 | volume_sample_rate (sample by volume), but not both. 111 | use_dataset_cache_file: Whether to cache dataset metadata. This is 112 | very useful for large datasets. 113 | batch_size: Batch size. 114 | num_workers: Number of workers for PyTorch dataloader. 115 | distributed_sampler: Whether to use a distributed sampler. This 116 | should be set to True if training with ddp. 
117 | """ 118 | super().__init__() 119 | 120 | self.data_path = data_path 121 | self.train_transform = train_transform 122 | self.val_transform = val_transform 123 | self.test_transform = test_transform 124 | self.combine_train_val = combine_train_val 125 | self.test_split = test_split 126 | self.test_path = test_path 127 | self.sample_rate = sample_rate 128 | self.volume_sample_rate = volume_sample_rate 129 | self.use_dataset_cache_file = use_dataset_cache_file 130 | self.batch_size = batch_size 131 | self.num_workers = num_workers 132 | self.distributed_sampler = distributed_sampler 133 | 134 | def _create_data_loader( 135 | self, 136 | data_transform: Callable, 137 | data_partition: str, 138 | sample_rate: Optional[float] = None, 139 | volume_sample_rate: Optional[float] = None, 140 | ) -> torch.utils.data.DataLoader: 141 | if data_partition == "train": 142 | is_train = True 143 | sample_rate = self.sample_rate if sample_rate is None else sample_rate 144 | volume_sample_rate = ( 145 | self.volume_sample_rate 146 | if volume_sample_rate is None 147 | else volume_sample_rate 148 | ) 149 | else: 150 | is_train = False 151 | sample_rate = 1.0 152 | volume_sample_rate = None # default case, no subsampling 153 | 154 | # if desired, combine train and val together for the train split 155 | dataset: Union[SliceDataset, CombinedSliceDataset] 156 | if is_train and self.combine_train_val: 157 | data_paths = [ 158 | self.data_path / f"train", 159 | self.data_path / f"valid", 160 | ] 161 | data_transforms = [data_transform, data_transform] 162 | sample_rates, volume_sample_rates = None, None # default: no subsampling 163 | if sample_rate is not None: 164 | sample_rates = [sample_rate, sample_rate] 165 | if volume_sample_rate is not None: 166 | volume_sample_rates = [volume_sample_rate, volume_sample_rate] 167 | dataset = CombinedSliceDataset( 168 | roots=data_paths, 169 | transforms=data_transforms, 170 | sample_rates=sample_rates, 171 | volume_sample_rates=volume_sample_rates, 172 | use_dataset_cache=self.use_dataset_cache_file, 173 | ) 174 | else: 175 | if data_partition == "test" and self.test_path is not None: 176 | data_path = self.test_path 177 | else: 178 | data_path = self.data_path / f"{data_partition}" 179 | 180 | dataset = SliceDataset( 181 | root=data_path, 182 | transform=data_transform, 183 | sample_rate=sample_rate, 184 | volume_sample_rate=volume_sample_rate, 185 | use_dataset_cache=self.use_dataset_cache_file, 186 | ) 187 | 188 | # ensure that entire volumes go to the same GPU in the ddp setting 189 | sampler = None 190 | if self.distributed_sampler: 191 | if is_train: 192 | sampler = torch.utils.data.DistributedSampler(dataset) 193 | else: 194 | sampler = VolumeSampler(dataset) 195 | 196 | dataloader = torch.utils.data.DataLoader( 197 | dataset=dataset, 198 | batch_size=self.batch_size, 199 | num_workers=self.num_workers, 200 | worker_init_fn=worker_init_fn, 201 | sampler=sampler, 202 | ) 203 | 204 | return dataloader 205 | 206 | def prepare_data(self): 207 | # call dataset for each split one time to make sure the cache is set up on the 208 | # rank 0 ddp process. 
if not using cache, don't do this 209 | if self.use_dataset_cache_file: 210 | if self.test_path is not None: 211 | test_path = self.test_path 212 | else: 213 | test_path = self.data_path / f"test" 214 | data_paths = [ 215 | self.data_path / f"train", 216 | self.data_path / f"valid", 217 | test_path, 218 | ] 219 | data_transforms = [ 220 | self.train_transform, 221 | self.val_transform, 222 | self.test_transform, 223 | ] 224 | for i, (data_path, data_transform) in enumerate( 225 | zip(data_paths, data_transforms) 226 | ): 227 | sample_rate = self.sample_rate if i == 0 else 1.0 228 | volume_sample_rate = self.volume_sample_rate if i == 0 else None 229 | _ = SliceDataset( 230 | root=data_path, 231 | transform=data_transform, 232 | sample_rate=sample_rate, 233 | volume_sample_rate=volume_sample_rate, 234 | use_dataset_cache=self.use_dataset_cache_file, 235 | ) 236 | 237 | def train_dataloader(self): 238 | return self._create_data_loader(self.train_transform, data_partition="train") 239 | 240 | def val_dataloader(self): 241 | return self._create_data_loader( 242 | self.val_transform, data_partition="valid", sample_rate=1.0 243 | ) 244 | 245 | def test_dataloader(self): 246 | return self._create_data_loader( 247 | self.test_transform, 248 | data_partition=self.test_split, 249 | sample_rate=1.0, 250 | ) 251 | 252 | @staticmethod 253 | def add_data_specific_args(parent_parser): # pragma: no-cover 254 | """ 255 | Define parameters that only apply to this model 256 | """ 257 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 258 | 259 | # dataset arguments 260 | parser.add_argument( 261 | "--data_path", 262 | default=None, 263 | type=Path, 264 | help="Path to data root, expects subdirectories train/valid/test", 265 | ) 266 | parser.add_argument( 267 | "--test_path", 268 | default=None, 269 | type=Path, 270 | help="Path to data for test mode. This overwrites data_path and test_split", 271 | ) 272 | parser.add_argument( 273 | "--test_split", 274 | choices=("test", "challenge"), 275 | default="test", 276 | type=str, 277 | help="Which data split to use as test split", 278 | ) 279 | parser.add_argument( 280 | "--sample_rate", 281 | default=None, 282 | type=float, 283 | help="Fraction of slices in the dataset to use (train split only). If not given all will be used. Cannot set together with volume_sample_rate.", 284 | ) 285 | parser.add_argument( 286 | "--volume_sample_rate", 287 | default=None, 288 | type=float, 289 | help="Fraction of volumes of the dataset to use (train split only). If not given all will be used. 
Cannot set together with sample_rate.", 290 | ) 291 | parser.add_argument( 292 | "--use_dataset_cache_file", 293 | default=True, 294 | type=bool, 295 | help="Whether to cache dataset metadata in a pkl file", 296 | ) 297 | parser.add_argument( 298 | "--combine_train_val", 299 | default=False, 300 | type=bool, 301 | help="Whether to combine train and val splits for training", 302 | ) 303 | 304 | # data loader arguments 305 | parser.add_argument( 306 | "--batch_size", default=1, type=int, help="Data loader batch size" 307 | ) 308 | parser.add_argument( 309 | "--num_workers", 310 | default=4, 311 | type=float, 312 | help="Number of workers to use in data loader", 313 | ) 314 | 315 | return parser 316 | -------------------------------------------------------------------------------- /reconstruction/models/recurrent_xpdnet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import reconstruction as rec 10 | from .xpdnet import SensitivityModel, ForwardOperator, BackwardOperator 11 | from .denoisers.kspace_net import KSpaceCNN 12 | 13 | 14 | class XPDNet_RNN(nn.Module): 15 | """ 16 | A hybrid model for Dynamic MRI Reconstruction, inspired by combining 17 | XPDNet [1] and Recurrent Convolutional Neural Networks (RCNN) [2]. 18 | 19 | Reference papers: 20 | [1] Z. Ramzi et al. `XPDNet for MRI Reconstruction: an application to the 2020 fastMRI 21 | challenge`. arXiv: 2010.07290, 2021. 22 | [2] C. Qin et al. `Convolutional Recurrent Neural Networks for Dynamic MR 23 | Image Reconstruction`. In IEEE Transactions on Medical Imaging 38.1, 24 | pp. 280–290, 2019. 25 | """ 26 | def __init__( 27 | self, 28 | num_cascades: int = 12, 29 | sens_chans: int = 8, 30 | sens_pools: int = 4, 31 | chans: int = 18, 32 | primal_only: bool = True, 33 | n_primal: int = 5, 34 | n_dual: int = 1, 35 | ): 36 | """ 37 | Args: 38 | num_cascades: Number of cascades (i.e., layers) for variational 39 | network. 40 | sens_chans: Number of channels for sensitivity map U-Net. 41 | sens_pools Number of downsampling and upsampling layers for 42 | sensitivity map U-Net. 43 | chans: Number of channels for convolutional layers of the RCNN. 44 | primal_only: Whether to generate a buffer in k-space or only in image 45 | space. 46 | n_primal: The size of the buffer in image-space. 47 | n_dual: The size of the buffer in k-space. 
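        Example (illustrative; the input shapes follow the ``forward`` docstring
        below, and the recurrent hidden states are allocated on the GPU, so a
        CUDA device is assumed):

            model = XPDNet_RNN(num_cascades=12, chans=18, n_primal=5).cuda()
            # masked k-space and sampling mask, both of shape (b, t, c, h, w, 2)
            output = model(masked_kspace, mask)    # -> (b, t, h, w)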
48 | """ 49 | super(XPDNet_RNN, self).__init__() 50 | 51 | self.num_cascades = num_cascades 52 | self.chans = chans 53 | self.domain_sequence = 'KI' * num_cascades 54 | self.i_buffer_mode = True 55 | self.k_buffer_mode = not primal_only 56 | self.i_buffer_size = n_primal 57 | self.k_buffer_size = 1 if primal_only else n_dual 58 | 59 | self.backward_op = BackwardOperator(masked=True) 60 | self.forward_op = ForwardOperator(masked=True) 61 | 62 | self.sens_net = SensitivityModel(sens_chans, sens_pools) 63 | self.bcrnn = BCRNNlayer(input_size=2*(n_primal+1), hidden_size=self.chans, kernel_size=3) 64 | 65 | if not primal_only: 66 | self.kspace_net = nn.ModuleList([KSpaceCNN( 67 | in_chans = 2 * (n_dual+2), 68 | out_chans = 2 * n_dual, 69 | n_convs = 3, 70 | n_filters = 16, 71 | ) for _ in range(num_cascades)] 72 | ) 73 | else: 74 | self.kspace_net = [self.measurements_residual for _ in range(num_cascades)] 75 | 76 | 77 | self.conv1_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 78 | self.conv1_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 79 | self.conv2_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 80 | self.conv2_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 81 | self.conv3_x = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 82 | self.conv3_h = nn.Conv2d(self.chans, self.chans, 3, padding = 3//2) 83 | self.conv4_x = nn.Conv2d(self.chans, 2*n_primal, 3, padding = 3//2) 84 | self.relu = nn.ReLU(inplace=True) 85 | 86 | 87 | def measurements_residual(self, concat_kspace: torch.Tensor) -> torch.Tensor: 88 | current_kspace = torch.stack([concat_kspace[..., 0], concat_kspace[..., 2]], dim=-1) 89 | ref_kspace = torch.stack([concat_kspace[..., 1], concat_kspace[..., 3]], dim=-1) 90 | return current_kspace - ref_kspace 91 | 92 | 93 | def k_domain_correction( 94 | self, 95 | i_cascade: int, 96 | image_buffer: torch.Tensor, 97 | kspace_buffer: torch.Tensor, 98 | mask: torch.Tensor, 99 | sens_maps: torch.Tensor, 100 | ref_kspace: torch.Tensor 101 | ) -> torch.Tensor: 102 | """ 103 | Updates the kspace buffer and feeds it to the kspace net 104 | corresponding to the current unrolled iteration. 105 | """ 106 | 107 | forward_op_res = rec.utils.real_to_complex_multi_ch( 108 | self.forward_op(image_buffer, mask, sens_maps, self.i_buffer_size), 1, 109 | ) 110 | 111 | if self.k_buffer_mode: 112 | kspace_buffer = rec.utils.real_to_complex_multi_ch(kspace_buffer, self.k_buffer_size) 113 | kspace_buffer = torch.cat([kspace_buffer, forward_op_res], dim=-1) 114 | else: 115 | kspace_buffer = forward_op_res 116 | 117 | kspace_buffer = torch.cat( 118 | [kspace_buffer, 119 | rec.utils.real_to_complex_multi_ch(ref_kspace, 1)], 120 | dim=-1, 121 | ) 122 | 123 | kspace_buffer = rec.utils.complex_to_real_multi_ch(kspace_buffer) 124 | return self.kspace_net[i_cascade](kspace_buffer) 125 | 126 | 127 | def update_image_buffer( 128 | self, 129 | image_buffer: torch.Tensor, 130 | kspace_buffer: torch.Tensor, 131 | mask: torch.Tensor, 132 | sens_maps: torch.Tensor 133 | ) -> torch.Tensor: 134 | """ 135 | Updates the image buffer from the kspace buffer at the 136 | current unrolled iteration. 
137 | """ 138 | 139 | backward_op_res = rec.utils.real_to_complex_multi_ch( 140 | self.backward_op(kspace_buffer, mask, sens_maps, self.k_buffer_size), 1, 141 | ) 142 | 143 | if self.i_buffer_mode: 144 | image_buffer = rec.utils.real_to_complex_multi_ch(image_buffer, self.i_buffer_size) 145 | image_buffer = torch.cat([image_buffer, backward_op_res], dim=-1) 146 | else: 147 | image_buffer = backward_op_res 148 | 149 | image_buffer = rec.utils.complex_to_real_multi_ch(image_buffer) 150 | return image_buffer 151 | 152 | 153 | def forward(self, ref_kspace: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 154 | """ 155 | Args: 156 | ref_kspace, mask: Input 6D tensors of shape `(b, t, c, h, w, 2)`. 157 | 158 | Returns: 159 | Output tensor of shape `(b, t, h, w)`. 160 | """ 161 | 162 | sens_maps = self.sens_net(ref_kspace, mask) 163 | image = self.backward_op(ref_kspace, mask, sens_maps, 1) 164 | 165 | # Generate buffers in k-space and image-space 166 | kspace_buffer = torch.repeat_interleave(ref_kspace, self.k_buffer_size, dim=-1) 167 | image_buffer = torch.repeat_interleave(image, self.i_buffer_size, dim=-1) 168 | 169 | b, t, h, w, ch_primal = image_buffer.squeeze(2).size() 170 | ch = 2 * (self.i_buffer_size + 1) 171 | size_h = [t*b, self.chans, h, w] 172 | 173 | # Initialise parameters of rcnn layers at the first iteration to zero 174 | net = {} 175 | rcnn_layers = 5 176 | for j in range(rcnn_layers-1): 177 | net['t0_x%d'%j] = Variable(torch.zeros(size_h)).cuda() 178 | 179 | # Recurrence through iterations 180 | for i in range(1, self.num_cascades + 1): 181 | 182 | kspace_buffer = self.k_domain_correction( 183 | i-1, 184 | image_buffer, 185 | kspace_buffer, 186 | mask, 187 | sens_maps, 188 | ref_kspace, 189 | ) 190 | 191 | image_buffer = self.update_image_buffer( 192 | image_buffer, 193 | kspace_buffer, 194 | mask, 195 | sens_maps, 196 | ) 197 | 198 | x = image_buffer.squeeze(2).permute(1,0,4,2,3) # (t,b,ch,h,w) 199 | x = x.contiguous() 200 | 201 | net['t%d_x0' % (i-1)] = net['t%d_x0' % (i-1)].view(t, b, self.chans, h, w) 202 | net['t%d_x0'%i] = self.bcrnn(x, net['t%d_x0' % (i-1)]) 203 | net['t%d_x0'%i] = net['t%d_x0'%i].view(-1, self.chans, h, w) 204 | 205 | net['t%d_x1'%i] = self.conv1_x(net['t%d_x0'%i]) 206 | net['t%d_h1'%i] = self.conv1_h(net['t%d_x1'%(i-1)]) 207 | net['t%d_x1'%i] = self.relu(net['t%d_h1'%i] + net['t%d_x1'%i]) 208 | 209 | net['t%d_x2'%i] = self.conv2_x(net['t%d_x1'%i]) 210 | net['t%d_h2'%i] = self.conv2_h(net['t%d_x2'%(i-1)]) 211 | net['t%d_x2'%i] = self.relu(net['t%d_h2'%i] + net['t%d_x2'%i]) 212 | 213 | net['t%d_x3'%i] = self.conv3_x(net['t%d_x2'%i]) 214 | net['t%d_h3'%i] = self.conv3_h(net['t%d_x3'%(i-1)]) 215 | net['t%d_x3'%i] = self.relu(net['t%d_h3'%i] + net['t%d_x3'%i]) 216 | 217 | net['t%d_x4'%i] = self.conv4_x(net['t%d_x3'%i]) 218 | 219 | # Residual connection 220 | x_res = torch.cat( 221 | [x.view(-1, ch, h, w)[:, :self.i_buffer_size], 222 | x.view(-1, ch, h, w)[:, self.i_buffer_size+1: -1]], 223 | dim = 1, 224 | ) 225 | net['t%d_out'%i] = x_res + net['t%d_x4'%i] 226 | 227 | net['t%d_out'%i] = net['t%d_out'%i].view(t, b, 1, ch_primal, h, w) 228 | net['t%d_out'%i] = net['t%d_out'%i].permute(1,0,2,4,5,3) 229 | net['t%d_out'%i].contiguous() 230 | 231 | image_buffer = net['t%d_out'%i] 232 | 233 | 234 | out_image = torch.stack( 235 | [image_buffer[..., 0], image_buffer[..., self.i_buffer_size]], 236 | dim=-1, 237 | ) 238 | 239 | return rec.utils.complex_abs(out_image.squeeze(2)) 240 | 241 | 242 | class CRNNcell(nn.Module): 243 | """ 244 | Convolutional RNN cell that 
evolves over both time and iterations. 245 | """ 246 | def __init__( 247 | self, 248 | input_size: int, 249 | hidden_size: int, 250 | kernel_size: int, 251 | ): 252 | """ 253 | Args: 254 | input_size: Number of input channels 255 | hidden_size: Number of RCNN hidden layers channels 256 | kernel_size: Size of convolutional kernel 257 | """ 258 | super(CRNNcell, self).__init__() 259 | 260 | # Convolution for input 261 | self.i2h = nn.Conv2d(input_size, hidden_size, kernel_size, padding=kernel_size // 2) 262 | # Convolution for hidden states in temporal dimension 263 | self.h2h = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2) 264 | # Convolution for hidden states in iteration dimension 265 | self.ih2ih = nn.Conv2d(hidden_size, hidden_size, kernel_size, padding=kernel_size // 2) 266 | 267 | self.relu = nn.ReLU(inplace=True) 268 | 269 | def forward( 270 | self, 271 | input: torch.Tensor, 272 | hidden_iteration: torch.Tensor, 273 | hidden: torch.Tensor, 274 | ) -> torch.Tensor: 275 | """ 276 | Args: 277 | input: Input 4D tensor of shape `(b, ch, h, w)` 278 | hidden_iteration: hidden states in iteration dimension, 4d tensor of shape (b, hidden_size, h, w) 279 | hidden: hidden states in temporal dimension, 4d tensor of shape (b, hidden_size, h, w) 280 | Returns: 281 | Output tensor of shape `(b, hidden_size, h, w)`. 282 | """ 283 | in_to_hid = self.i2h(input) 284 | hid_to_hid = self.h2h(hidden) 285 | ih_to_ih = self.ih2ih(hidden_iteration) 286 | 287 | hidden = self.relu(in_to_hid + hid_to_hid + ih_to_ih) 288 | 289 | return hidden 290 | 291 | 292 | class BCRNNlayer(nn.Module): 293 | """ 294 | Bidirectional Convolutional RNN layer 295 | """ 296 | def __init__( 297 | self, 298 | input_size: int, 299 | hidden_size: int, 300 | kernel_size: int, 301 | ): 302 | """ 303 | Args: 304 | input_size: Number of input channels 305 | hidden_size: Number of RCNN hidden layers channels 306 | kernel_size: Size of convolutional kernel 307 | """ 308 | super(BCRNNlayer, self).__init__() 309 | 310 | self.hidden_size = hidden_size 311 | self.CRNN_model = CRNNcell(input_size, self.hidden_size, kernel_size) 312 | 313 | def forward(self, input: torch.Tensor, hidden_iteration: torch.Tensor) -> torch.Tensor: 314 | """ 315 | Args: 316 | input: Input 5D tensor of shape `(t, b, ch, h, w)` 317 | hidden_iteration: hidden states (output of BCRNNlayer) from previous 318 | iteration, 5d tensor of shape (t, b, hidden_size, h, w) 319 | Returns: 320 | Output tensor of shape `(t, b, hidden_size, h, w)`. 
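        Note (descriptive of the implementation below): the CRNN cell is run over
        the temporal dimension twice, once forwards and once backwards, starting
        from a zero hidden state each time, and the two passes are summed
        element-wise, so the output at each frame aggregates information from
        both earlier and later frames.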
321 | """ 322 | t, b, ch, h, w = input.shape 323 | size_h = [b, self.hidden_size, h, w] 324 | 325 | hid_init = Variable(torch.zeros(size_h)).cuda() 326 | output_f = [] 327 | output_b = [] 328 | 329 | # forward 330 | hidden = hid_init 331 | for i in range(t): 332 | hidden = self.CRNN_model(input[i], hidden_iteration[i], hidden) 333 | output_f.append(hidden) 334 | output_f = torch.cat(output_f) 335 | 336 | # backward 337 | hidden = hid_init 338 | for i in range(t): 339 | hidden = self.CRNN_model(input[t - i - 1], hidden_iteration[t - i -1], hidden) 340 | output_b.append(hidden) 341 | output_b = torch.cat(output_b[::-1]) 342 | 343 | output = output_f + output_b 344 | 345 | if b == 1: 346 | output = output.view(t, 1, self.hidden_size, h, w) 347 | 348 | return output 349 | -------------------------------------------------------------------------------- /reconstruction/data/mri_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This source code is based on the fastMRI repository from Facebook AI 3 | Research and is used as a general framework to handle MRI data. Link: 4 | 5 | https://github.com/facebookresearch/fastMRI 6 | """ 7 | 8 | import sys 9 | import os 10 | import logging 11 | import pickle 12 | import random 13 | import yaml 14 | import xml.etree.ElementTree as etree 15 | from pathlib import Path 16 | from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union 17 | from warnings import warn 18 | 19 | import h5py 20 | import numpy as np 21 | import torch 22 | from reconstruction.data import transforms 23 | 24 | 25 | # BART is a free image-reconstruction framework used to estimate 26 | # coils sensitivity maps via ESPIRiT to provide a target image. 27 | # For use in Colab Notebooks. Change according to the environment used. 28 | sys.path.append('/content/bart/python/') 29 | os.environ['LD_LIBRARY_PATH'] = "/usr/local/cuda/lib64" 30 | os.environ['TOOLBOX_PATH'] = "/content/bart" 31 | os.environ['OMP_NUM_THREADS']="4" 32 | os.environ['PATH'] = os.environ['TOOLBOX_PATH'] + ":" + os.environ['PATH'] 33 | sys.path.append(os.environ['TOOLBOX_PATH'] + "/python") 34 | 35 | import bart 36 | 37 | 38 | def fetch_dir( 39 | key: str, data_config_file: Union[str, Path, os.PathLike] = "dirs_path.yaml" 40 | ) -> Path: 41 | """ 42 | Data directory fetcher. 43 | 44 | This is a brute-force simple way to configure data directories for a 45 | project. 46 | 47 | Args: 48 | key: key to retrieve path from data_config_file. Expected to be in 49 | ("data_path", "log_path", "save_path"). 50 | data_config_file: Optional; Default path config file to fetch path 51 | from. 52 | 53 | Returns: 54 | The path to the specified directory. 55 | """ 56 | data_config_file = Path(data_config_file) 57 | if not data_config_file.is_file(): 58 | default_config = { 59 | "data_path": "/path/to/data", 60 | "log_path": "/root/traintest_scripts", 61 | "save_path": "/root/results", 62 | } 63 | with open(data_config_file, "w") as f: 64 | yaml.dump(default_config, f) 65 | 66 | data_dir = default_config[key] 67 | 68 | warn( 69 | f"Path config at {data_config_file.resolve()} does not exist. " 70 | "A template has been created for you. " 71 | "Please enter the directory paths for your system to have defaults." 72 | ) 73 | else: 74 | with open(data_config_file, "r") as f: 75 | data_dir = yaml.safe_load(f)[key] 76 | 77 | return Path(data_dir) 78 | 79 | 80 | class CombinedSliceDataset(torch.utils.data.Dataset): 81 | """ 82 | A container for combining slice datasets. 
83 | """ 84 | 85 | def __init__( 86 | self, 87 | roots: Sequence[Path], 88 | transforms: Optional[Sequence[Optional[Callable]]] = None, 89 | sample_rates: Optional[Sequence[Optional[float]]] = None, 90 | volume_sample_rates: Optional[Sequence[Optional[float]]] = None, 91 | use_dataset_cache: bool = False, 92 | dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.pkl", 93 | num_cols: Optional[Tuple[int]] = None, 94 | ): 95 | """ 96 | Args: 97 | roots: Paths to the datasets. 98 | transforms: Optional; A sequence of callable objects that 99 | preprocesses the raw data into appropriate form. The transform 100 | function should take 'kspace', 'target', 'attributes', 101 | 'filename', and 'slice' as inputs. 'target' may be null for 102 | test data. 103 | sample_rates: Optional; A sequence of floats between 0 and 1. 104 | This controls what fraction of the slices should be loaded. 105 | When creating subsampled datasets either set sample_rates 106 | (sample by slices) or volume_sample_rates (sample by volumes) 107 | but not both. 108 | volume_sample_rates: Optional; A sequence of floats between 0 and 1. 109 | This controls what fraction of the volumes should be loaded. 110 | When creating subsampled datasets either set sample_rates 111 | (sample by slices) or volume_sample_rates (sample by volumes) 112 | but not both. 113 | use_dataset_cache: Whether to cache dataset metadata. This is very 114 | useful for large datasets. 115 | dataset_cache_file: Optional; A file in which to cache dataset 116 | information for faster load times. 117 | num_cols: Optional; If provided, only slices with the desired 118 | number of columns will be considered. 119 | """ 120 | if sample_rates is not None and volume_sample_rates is not None: 121 | raise ValueError( 122 | "either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes) but not both" 123 | ) 124 | if transforms is None: 125 | transforms = [None] * len(roots) 126 | if sample_rates is None: 127 | sample_rates = [None] * len(roots) 128 | if volume_sample_rates is None: 129 | volume_sample_rates = [None] * len(roots) 130 | if not ( 131 | len(roots) 132 | == len(transforms) 133 | == len(sample_rates) 134 | == len(volume_sample_rates) 135 | ): 136 | raise ValueError( 137 | "Lengths of roots, transforms, sample_rates do not match" 138 | ) 139 | 140 | self.datasets = [] 141 | self.examples: List[Tuple[Path, int, Dict[str, object]]] = [] 142 | for i in range(len(roots)): 143 | self.datasets.append( 144 | SliceDataset( 145 | root=roots[i], 146 | transform=transforms[i], 147 | sample_rate=sample_rates[i], 148 | volume_sample_rate=volume_sample_rates[i], 149 | use_dataset_cache=use_dataset_cache, 150 | dataset_cache_file=dataset_cache_file, 151 | num_cols=num_cols, 152 | ) 153 | ) 154 | 155 | self.examples = self.examples + self.datasets[-1].examples 156 | 157 | def __len__(self): 158 | return sum(len(dataset) for dataset in self.datasets) 159 | 160 | def __getitem__(self, i): 161 | for dataset in self.datasets: 162 | if i < len(dataset): 163 | return dataset[i] 164 | else: 165 | i = i - len(dataset) 166 | 167 | 168 | class SliceDataset(torch.utils.data.Dataset): 169 | """ 170 | A PyTorch Dataset that provides access to MR image slices. 
171 | """ 172 | 173 | def __init__( 174 | self, 175 | root: Union[str, Path, os.PathLike], 176 | transform: Optional[Callable] = None, 177 | use_dataset_cache: bool = False, 178 | sample_rate: Optional[float] = None, 179 | volume_sample_rate: Optional[float] = None, 180 | dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.pkl", 181 | num_cols: Optional[Tuple[int]] = None, 182 | ): 183 | """ 184 | Args: 185 | root: Path to the dataset. 186 | transform: Optional; A callable object that pre-processes the raw 187 | data into appropriate form. The transform function should take 188 | 'kspace', 'target', 'attributes', 'filename', and 'slice' as 189 | inputs. 'target' may be null for test data. 190 | use_dataset_cache: Whether to cache dataset metadata. This is very 191 | useful for large datasets. 192 | sample_rate: Optional; A float between 0 and 1. This controls what fraction 193 | of the slices should be loaded. Defaults to 1 if no value is given. 194 | When creating a sampled dataset either set sample_rate (sample by slices) 195 | or volume_sample_rate (sample by volumes) but not both. 196 | volume_sample_rate: Optional; A float between 0 and 1. This controls what fraction 197 | of the volumes should be loaded. Defaults to 1 if no value is given. 198 | When creating a sampled dataset either set sample_rate (sample by slices) 199 | or volume_sample_rate (sample by volumes) but not both. 200 | dataset_cache_file: Optional; A file in which to cache dataset 201 | information for faster load times. 202 | num_cols: Optional; If provided, only slices with the desired 203 | number of columns will be considered. 204 | """ 205 | if sample_rate is not None and volume_sample_rate is not None: 206 | raise ValueError( 207 | "either set sample_rate (sample by slices) or volume_sample_rate (sample by volumes) but not both" 208 | ) 209 | 210 | self.dataset_cache_file = Path(dataset_cache_file) 211 | self.transform = transform 212 | self.examples = [] 213 | 214 | # set default sampling mode if none given 215 | if sample_rate is None: 216 | sample_rate = 1.0 217 | if volume_sample_rate is None: 218 | volume_sample_rate = 1.0 219 | 220 | # load dataset cache if we have and user wants to use it 221 | if self.dataset_cache_file.exists() and use_dataset_cache: 222 | with open(self.dataset_cache_file, "rb") as f: 223 | dataset_cache = pickle.load(f) 224 | else: 225 | dataset_cache = {} 226 | 227 | # check if our dataset is in the cache 228 | # if there, use that metadata, if not, then regenerate the metadata 229 | if dataset_cache.get(root) is None or not use_dataset_cache: 230 | files = list(Path(root).iterdir()) 231 | for fname in sorted(files): 232 | self.examples += [fname] 233 | 234 | if dataset_cache.get(root) is None and use_dataset_cache: 235 | dataset_cache[root] = self.examples 236 | logging.info(f"Saving dataset cache to {self.dataset_cache_file}.") 237 | with open(self.dataset_cache_file, "wb") as f: 238 | pickle.dump(dataset_cache, f) 239 | else: 240 | logging.info(f"Using dataset cache from {self.dataset_cache_file}.") 241 | self.examples = dataset_cache[root] 242 | 243 | # subsample if desired 244 | if sample_rate < 1.0: # sample by slice 245 | random.shuffle(self.examples) 246 | num_examples = round(len(self.examples) * sample_rate) 247 | self.examples = self.examples[:num_examples] 248 | elif volume_sample_rate < 1.0: # sample by volume 249 | vol_names = sorted(list(set([f[0].stem for f in self.examples]))) 250 | random.shuffle(vol_names) 251 | num_volumes = round(len(vol_names) * 
volume_sample_rate) 252 | sampled_vols = vol_names[:num_volumes] 253 | self.examples = [ 254 | example for example in self.examples if example[0].stem in sampled_vols 255 | ] 256 | 257 | if num_cols: 258 | self.examples = [ 259 | ex 260 | for ex in self.examples 261 | if ex[2]["encoding_size"][1] in num_cols # type: ignore 262 | ] 263 | 264 | def __len__(self): 265 | return len(self.examples) 266 | 267 | def __getitem__(self, i: int): 268 | fname = self.examples[i] 269 | dataslice = 0 # Unused for the current dataset 270 | 271 | with h5py.File(fname, "r") as hf: 272 | # Hardcoded data settings (change them here according to specifics of dataset) 273 | scaling = 1e6 274 | crop_shape = (200, 200) # If BART returns error try changing crop size 275 | crop_target = (180, 180) 276 | n_slices = 15 277 | filter_size = [0.7, 0., 0.3, 0.3] 278 | 279 | # Data dimension (Nt, Nx, Ny, Nc) 280 | # Nt: number of slices 281 | # (Nx, Ny): shape of k-space 282 | # Nc: number of coils 283 | kspace = np.array(hf["y"], dtype='complex64') * scaling 284 | 285 | # Cropping + slice selection 286 | kspace = kspace.transpose(0,3,1,2) 287 | scaling_factor = np.sqrt(np.prod(kspace.shape[-2:])) 288 | images = np.fft.fftshift(np.fft.ifftn(np.fft.ifftshift(kspace, axes=(-2,-1)), axes=(-2,-1), norm=None), axes=(-2,-1)) * scaling_factor 289 | images_cropped, images_filter = transforms.filtered_crop_center_and_slices(images, crop_shape, n_slices, filter_size) 290 | scaling_factor = np.sqrt(np.prod(images_filter.shape[-2:])) 291 | kspace = np.fft.ifftshift(np.fft.fftn(np.fft.fftshift(images_filter, axes=(-2,-1)), axes=(-2,-1), norm=None), axes=(-2,-1)) / scaling_factor 292 | kspace = kspace.transpose(0,2,3,1).astype('complex64') 293 | 294 | # Coils sensitivity maps estimation via ESPIRiT 295 | time_avg_kspace = np.mean(kspace, axis=0, keepdims=True) 296 | [calib, emaps] = bart.bart(2, 'ecalib -r 200', time_avg_kspace) 297 | sens = np.squeeze(calib[...,0]) # dimension (Nx, Ny, Nc) 298 | sens = sens.transpose(2,0,1) 299 | 300 | mask = np.asarray(hf["mask"]) if "mask" in hf else None 301 | kspace = kspace.transpose(0,3,1,2) 302 | target = np.abs(np.sum(images_filter * np.conjugate(np.expand_dims(sens, axis=0)), axis=1)).astype('float32') # dimension (Nt, Nx, Ny) 303 | target = transforms.center_crop(target, crop_target) 304 | 305 | attrs = {} # Unused for the current dataset 306 | 307 | if self.transform is None: 308 | sample = (kspace, mask, target, attrs, fname.name, dataslice) 309 | else: 310 | sample = self.transform(kspace, mask, target, attrs, fname.name, dataslice) 311 | 312 | return sample --------------------------------------------------------------------------------
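As a closing illustration (not part of the repository sources above), the sketch below shows one plausible way to combine `fetch_dir` and `SliceDataset` from this file. The data path, the `sample_rate` value, and the absence of a transform are placeholder choices, and indexing the dataset runs the BART/ESPIRiT step in `__getitem__`, so a working BART installation (as configured at the top of `mri_data.py`) is assumed.

```python
# Assumed import path; both names are defined in reconstruction/data/mri_data.py.
from reconstruction.data.mri_data import fetch_dir, SliceDataset

# Resolve the data root from dirs_path.yaml (a template file with placeholder
# paths is written if it does not exist yet).
data_root = fetch_dir("data_path", "dirs_path.yaml")

# Dataset over the training split; transform=None yields the raw
# (kspace, mask, target, attrs, fname, dataslice) tuples built in __getitem__.
dataset = SliceDataset(
    root=data_root / "train",
    transform=None,
    sample_rate=0.1,         # subsample 10% of slices for quick prototyping
    use_dataset_cache=True,  # cache file metadata in dataset_cache.pkl
)

# Indexing triggers BART's ESPIRiT coil-sensitivity estimation, so this
# requires the BART toolbox to be importable.
kspace, mask, target, attrs, fname, dataslice = dataset[0]
print(fname, kspace.shape, target.shape)
```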