├── .python-version
├── torch_harmonics
│   ├── README.md
│   ├── __init__.py
│   ├── pyproject.toml
│   ├── quadrature.py
│   └── _disco_convolution.py
├── fastmri
│   ├── poisson_cache
│   │   ├── poisson_16x.npy
│   │   ├── poisson_2x.npy
│   │   ├── poisson_32x.npy
│   │   ├── poisson_4x.npy
│   │   ├── poisson_6x.npy
│   │   └── poisson_8x.npy
│   ├── pyproject.toml
│   ├── __init__.py
│   ├── README.md
│   ├── coil_combine.py
│   ├── losses.py
│   ├── math_utils.py
│   ├── evaluate.py
│   ├── fftc.py
│   ├── datasets.py
│   └── subsample.py
├── models
│   ├── pyproject.toml
│   ├── README.md
│   ├── unet.py
│   ├── lightning
│   │   ├── varnet_module.py
│   │   ├── no_varnet_module.py
│   │   └── mri_module.py
│   ├── udno.py
│   ├── varnet.py
│   └── no_varnet.py
├── type_utils.py
├── scripts
│   ├── download_weights.py
│   ├── knee_multipatt.sh
│   └── gen_lmdb_dataset.py
├── pyproject.toml
├── .gitignore
├── README.md
└── main.py
/.python-version: -------------------------------------------------------------------------------- 1 | 3.10 2 | -------------------------------------------------------------------------------- /torch_harmonics/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_harmonics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastmri/poisson_cache/poisson_16x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/mri/HEAD/fastmri/poisson_cache/poisson_16x.npy -------------------------------------------------------------------------------- /fastmri/poisson_cache/poisson_2x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/mri/HEAD/fastmri/poisson_cache/poisson_2x.npy -------------------------------------------------------------------------------- /fastmri/poisson_cache/poisson_32x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/mri/HEAD/fastmri/poisson_cache/poisson_32x.npy -------------------------------------------------------------------------------- /fastmri/poisson_cache/poisson_4x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/mri/HEAD/fastmri/poisson_cache/poisson_4x.npy -------------------------------------------------------------------------------- /fastmri/poisson_cache/poisson_6x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/mri/HEAD/fastmri/poisson_cache/poisson_6x.npy -------------------------------------------------------------------------------- /fastmri/poisson_cache/poisson_8x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/mri/HEAD/fastmri/poisson_cache/poisson_8x.npy -------------------------------------------------------------------------------- /fastmri/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "fastmri" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [] 8 | -------------------------------------------------------------------------------- /models/pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "models" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [] 8 | -------------------------------------------------------------------------------- /type_utils.py: -------------------------------------------------------------------------------- 1 | def tuple_type(strings): 2 | """Parse a string like "(320,320)" into a tuple of ints (argparse helper).""" 3 | strings = strings.replace("(", "").replace(")", "").replace(" ", "") 4 | mapped_int = map(int, strings.split(",")) 5 | return tuple(mapped_int) 6 | -------------------------------------------------------------------------------- /scripts/download_weights.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import snapshot_download 2 | 3 | repo_id = "armeet/nomri" 4 | repo_path = snapshot_download(repo_id, local_dir="weights") 5 | 6 | print(f"Downloaded to: {repo_path}") 7 | -------------------------------------------------------------------------------- /torch_harmonics/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "torch-harmonics-local" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [] 8 | -------------------------------------------------------------------------------- /fastmri/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from .coil_combine import rss, rss_complex, mvue 9 | from .fftc import fft2c_new as fft2c 10 | from .fftc import fftshift 11 | from .fftc import ifft2c_new as ifft2c 12 | from .fftc import ifftshift, roll 13 | from .losses import SSIMLoss 14 | from .math_utils import ( 15 | complex_abs, 16 | complex_abs_sq, 17 | complex_conj, 18 | complex_mul, 19 | tensor_to_complex_np, 20 | ) 21 | -------------------------------------------------------------------------------- /scripts/knee_multipatt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | export WANDB_MODE=disabled 5 | export WANDB_API_KEY=*** 6 | export WANDB_DIR=/tmp/wandb 7 | export WANDB_CACHE_DIR=/tmp/wandb_cache 8 | 9 | uv run main.py --mode val \ 10 | --name knee_multipatt \ 11 | --model no_vn \ 12 | --num_cascades 6 \ 13 | --body_part knee \ 14 | --experiment release \ 15 | --crop_shape 320,320 \ 16 | --in_shape 320,320 \ 17 | --val_patterns equispaced_fraction magic random gaussian_2d poisson_2d radial_2d \ 18 | --val_accelerations 4 \ 19 | --sample_rate 0.1 \ 20 | --ckpt_path weights/knee_multipatt.ckpt 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "nomri" 3 | version = "1.0.0" 4 | description = "nomri - CVPR 2025" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [ 8 | "lightning-utilities==0.11.7", 9 | "lightning==2.4.0", 10 | "tabulate>=0.9.0", 11 | "torch==2.7.1", 12 | "torchmetrics==1.4.1", 13 | "torchvision", 14 | "wandb==0.17.9", 15 | "pytest>=8.3.5", 16 | "pandas>=2.2.3", 17 | "numpy<2", 18 | "matplotlib>=3.9.4", 19 | "sigpy>=0.1.27", 20 | "h5py>=3.13.0", 21 | "lmdb>=1.6.2", 22 | "runstats>=2.0.0", 23 | "scikit-image>=0.24.0", 24 | "fastmri", 25 | "models", 26 | "torch_harmonics", 27 | "pdbpp>=0.10.3", 28 | "huggingface-hub>=0.30.1", 29 | ] 30 | 31 | [tool.pyright] 32 | extraPaths = ["fastmri", "models", "torch_harmonics"] 33 | typeCheckingMode = "basic" 34 | 35 | [tool.uv.sources] 36 | fastmri = { workspace = true } 37 | models = { workspace = true } 38 | torch_harmonics = { path = "torch_harmonics" } 39 | 40 | [tool.uv.workspace] 41 | members = ["models", "fastmri", "torch_harmonics"] 42 | 43 | [dependency-groups] 44 | dev = [ 45 | "basedpyright>=1.28.3", 46 | "ruff>=0.11.2", 47 | ] 48 | -------------------------------------------------------------------------------- /fastmri/README.md: -------------------------------------------------------------------------------- 1 | # fastMRI 2 | 3 | At the time of research, the fastMRI code repo/library is licensed under an MIT 4 | license which has been copied into this directory under LICENSE.md. 5 | 6 | The original code is available at https://github.com/facebookresearch/fastMRI 7 | 8 | We make a number of necessary modifications that extend and modify some of the 9 | original behavior. 10 | 11 | - new non-rectangular mask functions and changes to mask data types / schemas 12 | * Poisson mask 13 | * Radial mask 14 | * Gaussian mask 15 | - updated versions of certain packages (wandb, etc.) 
16 | - working image logging 17 | - documentation 18 | - type annotations 19 | 20 | ## Cite 21 | 22 | Cite the original fastMRI arXiv paper: 23 | 24 | ```BibTeX 25 | @misc{zbontar2018fastMRI, 26 | title={{fastMRI}: An Open Dataset and Benchmarks for Accelerated {MRI}}, 27 | author={Jure Zbontar and Florian Knoll and Anuroop Sriram and Tullie Murrell and Zhengnan Huang and Matthew J. Muckley and Aaron Defazio and Ruben Stern and Patricia Johnson and Mary Bruno and Marc Parente and Krzysztof J. Geras and Joe Katsnelson and Hersh Chandarana and Zizhao Zhang and Michal Drozdzal and Adriana Romero and Michael Rabbat and Pascal Vincent and Nafissa Yakubova and James Pinkerton and Duo Wang and Erich Owens and C. Lawrence Zitnick and Michael P. Recht and Daniel K. Sodickson and Yvonne W. Lui}, 28 | journal = {ArXiv e-prints}, 29 | archivePrefix = "arXiv", 30 | eprint = {1811.08839}, 31 | year={2018} 32 | } 33 | ``` 34 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | ## Neural Operator Models 2 | `udno.py`: U-shaped DISCO Neural Operator (UDNO) 3 | * drop-in, resolution-invariant replacement for the U-Net 4 | * building block for `no_varnet` 5 | * uses `EquidistantDiscreteContinuousConv2d` from `torch_harmonics` 6 | 7 | `no_varnet.py`: Neural Operator model introduced in https://arxiv.org/abs/2410.16290 8 | 9 | 10 | If you use either the NO-VarNet (`no_varnet`) or the UDNO, please cite the following. 11 | 12 | ```bibtex 13 | @article{jatyani2024unified, 14 | title = {A Unified Model for Compressed Sensing MRI Across Undersampling Patterns}, 15 | author = {Jatyani, Armeet Singh and Wang, Jiayun and Wu, Zihui and Liu-Schiaffini, Miguel and Tolooshams, Bahareh and Anandkumar, Anima}, 16 | journal = {arXiv preprint arXiv:2410.16290}, 17 | year = {2024} 18 | } 19 | ``` 20 | 21 | ```bibtex 22 | @article{liu2024neural, 23 | title={Neural operators with localized integral and differential kernels}, 24 | author={Liu-Schiaffini, Miguel and Berner, Julius and Bonev, Boris and Kurth, Thorsten and Azizzadenesheli, Kamyar and Anandkumar, Anima}, 25 | journal={arXiv preprint arXiv:2402.16845}, 26 | year={2024} 27 | } 28 | ``` 29 | 30 | ## E2E-VarNet 31 | `varnet.py`: original E2E-VarNet model 32 | 33 | Please cite the original E2E-VarNet paper. 
34 | 35 | ```bibtex 36 | @inproceedings{sriram2020end, 37 | title={End-to-end variational networks for accelerated MRI reconstruction}, 38 | author={Sriram, Anuroop and Zbontar, Jure and Murrell, Tullie and Defazio, Aaron and Zitnick, C Lawrence and Yakubova, Nafissa and Knoll, Florian and Johnson, Patricia}, 39 | booktitle={Medical image computing and computer assisted intervention--MICCAI 2020: 23rd international conference, Lima, Peru, October 4--8, 2020, proceedings, part II 23}, 40 | pages={64--73}, 41 | year={2020}, 42 | organization={Springer} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Config files 2 | fastmri.yaml 3 | 4 | # Python specific 5 | __pycache__/ 6 | .pytest_cache/ 7 | *.py[cod] 8 | *.so 9 | *.egg-info/ 10 | *.pyo 11 | *.pyd 12 | 13 | # Virtual environments 14 | .env 15 | .venv 16 | env/ 17 | venv/ 18 | ENV/ 19 | env.bak/ 20 | venv.bak/ 21 | 22 | # Code in development 23 | ignore_**.py 24 | 25 | # Hidden / ignore folders 26 | hidden/ 27 | ignore/ 28 | hidden_** 29 | 30 | # Jupyter Notebook Checkpoints 31 | .ipynb_checkpoints 32 | 33 | # Data files 34 | data/ 35 | datasets/ 36 | *.csv 37 | *.tsv 38 | *.h5 39 | *.json 40 | *.xml 41 | *.parquet 42 | *.pkl 43 | 44 | # Model files 45 | *.h5 46 | *.ckpt 47 | *.tflite 48 | *.onnx 49 | *.pb 50 | *.pth 51 | *.pt 52 | *.joblib 53 | *.pkl 54 | 55 | # Logs and outputs 56 | logs/ 57 | wandb/ 58 | *.log 59 | *.out 60 | *.txt 61 | *.csv 62 | 63 | # Test dir 64 | !tests/**/*.txt 65 | !tests/datasets 66 | 67 | # Results 68 | results/ 69 | output/ 70 | runs/ 71 | outfig/ 72 | figs/*.png 73 | 74 | # SLURM 75 | slurm/ 76 | 77 | # Ignore files related to experiments 78 | experiments/ 79 | 80 | # Temporary files 81 | *.tmp 82 | *.temp 83 | *.swp 84 | *.swo 85 | 86 | # VS Code specific 87 | .vscode/ 88 | *.code-workspace 89 | 90 | # System files 91 | .DS_Store 92 | Thumbs.db 93 | 94 | # Environment files 95 | *.env 96 | 97 | # Ignore files from data processing tools 98 | *.dvc 99 | .dvc/ 100 | 101 | # PyTorch Lightning Logs 102 | lightning_logs/ 103 | 104 | # Ignore files generated by package managers 105 | Pipfile 106 | Pipfile.lock 107 | poetry.lock 108 | 109 | # TensorBoard logs 110 | logs/ 111 | events.out.tfevents.* 112 | 113 | # Checkpoints and weights 114 | checkpoints/ 115 | weights/ 116 | 117 | # Large file extensions 118 | *.tar.gz 119 | *.zip 120 | *.tar 121 | *.gz 122 | 123 | 124 | -------------------------------------------------------------------------------- /fastmri/coil_combine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | 10 | import fastmri 11 | 12 | 13 | def rss(data: torch.Tensor, dim: int = 0) -> torch.Tensor: 14 | """ 15 | Compute the Root Sum of Squares (RSS). 16 | 17 | The RSS is computed assuming that `dim` is the coil dimension. 18 | 19 | Parameters 20 | ---------- 21 | data : torch.Tensor 22 | The input tensor. 23 | dim : int, optional 24 | The dimension along which to apply the RSS transform (default is 0). 25 | 26 | Returns 27 | ------- 28 | torch.Tensor 29 | The computed RSS value. 
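A minimal usage sketch (shapes are illustrative, with the coil dimension first):

Examples
--------
>>> import torch
>>> coil_images = torch.randn(15, 320, 320)  # 15 coils of 320 x 320 magnitude images
>>> combined = rss(coil_images, dim=0)       # -> shape (320, 320)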
30 | """ 31 | return torch.sqrt((data**2).sum(dim)) 32 | 33 | 34 | def mvue(spatial_pred, sens_maps, dim: int = 0) -> torch.Tensor: 35 | spatial_pred = torch.view_as_complex(spatial_pred) 36 | sens_maps = torch.view_as_complex(sens_maps) 37 | 38 | numerator = torch.sum(spatial_pred * torch.conj(sens_maps), dim=dim) 39 | denominator = torch.sqrt(torch.sum(torch.square(torch.abs(sens_maps)), dim=dim)) 40 | res = numerator / denominator 41 | res = torch.abs(res) 42 | return res 43 | 44 | 45 | def rss_complex(data: torch.Tensor, dim: int = 0) -> torch.Tensor: 46 | """ 47 | Compute the Root Sum of Squares (RSS) for complex inputs. 48 | 49 | The RSS is computed assuming that `dim` is the coil dimension. 50 | 51 | Parameters 52 | ---------- 53 | data : torch.Tensor 54 | The input tensor containing complex values. 55 | dim : int, optional 56 | The dimension along which to apply the RSS transform (default is 0). 57 | 58 | Returns 59 | ------- 60 | torch.Tensor 61 | The computed RSS value for complex inputs. 62 | """ 63 | return torch.sqrt(fastmri.complex_abs_sq(data).sum(dim)) 64 | -------------------------------------------------------------------------------- /fastmri/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SSIMLoss(nn.Module): 14 | """ 15 | SSIM loss module. 16 | """ 17 | 18 | def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03): 19 | """ 20 | Initialize the Losses class. 21 | 22 | Parameters 23 | ---------- 24 | win_size : int, optional 25 | Window size for SSIM calculation. 26 | k1 : float, optional 27 | k1 parameter for SSIM calculation. 28 | k2 : float, optional 29 | k2 parameter for SSIM calculation. 
30 | """ 31 | super().__init__() 32 | self.win_size = win_size 33 | self.k1, self.k2 = k1, k2 34 | self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2) 35 | NP = win_size**2 36 | self.cov_norm = NP / (NP - 1) 37 | 38 | def forward( 39 | self, 40 | X: torch.Tensor, 41 | Y: torch.Tensor, 42 | data_range: torch.Tensor, 43 | reduced: bool = True, 44 | ): 45 | assert isinstance(self.w, torch.Tensor) 46 | 47 | data_range = data_range[:, None, None, None].to(X.device) 48 | C1 = (self.k1 * data_range) ** 2 49 | C2 = (self.k2 * data_range) ** 2 50 | 51 | # Compute means 52 | ux = F.conv2d(X, self.w) 53 | uy = F.conv2d(Y, self.w) 54 | 55 | # Compute variances 56 | uxx = F.conv2d(X * X, self.w) 57 | uyy = F.conv2d(Y * Y, self.w) 58 | uxy = F.conv2d(X * Y, self.w) 59 | 60 | # Compute covariances 61 | vx = self.cov_norm * (uxx - ux * ux) 62 | vy = self.cov_norm * (uyy - uy * uy) 63 | vxy = self.cov_norm * (uxy - ux * uy) 64 | 65 | # Compute SSIM components 66 | A1, A2 = 2 * ux * uy + C1, 2 * vxy + C2 67 | B1, B2 = ux**2 + uy**2 + C1, vx + vy + C2 68 | D = B1 * B2 69 | S = (A1 * A2) / D 70 | 71 | if reduced: 72 | return 1 - S.mean() 73 | else: 74 | return 1 - S 75 | 76 | 77 | if __name__ == "__main__": 78 | # Example usage 79 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 80 | 81 | # Create the SSIMLoss module and move it to the GPU 82 | ssim_loss = SSIMLoss().to(device) 83 | 84 | # Create example tensors and move them to the GPU 85 | X = torch.randn(4, 1, 256, 256).to(device) 86 | Y = torch.randn(4, 1, 256, 256).to(device) 87 | data_range = torch.rand(4).to(device) 88 | 89 | # Compute the loss 90 | loss = ssim_loss(X, Y, data_range) 91 | print(loss) 92 | -------------------------------------------------------------------------------- /fastmri/math_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 13 | """ 14 | Complex multiplication. 15 | 16 | Multiplies two complex tensors assuming that they are both stored as 17 | real arrays with the last dimension being the complex dimension. 18 | 19 | Parameters 20 | ---------- 21 | x : torch.Tensor 22 | A PyTorch tensor with the last dimension of size 2. 23 | y : torch.Tensor 24 | A PyTorch tensor with the last dimension of size 2. 25 | 26 | Returns 27 | ------- 28 | torch.Tensor 29 | A PyTorch tensor with the last dimension of size 2, representing 30 | the result of the complex multiplication. 31 | """ 32 | if not x.shape[-1] == y.shape[-1] == 2: 33 | raise ValueError("Tensors do not have separate complex dim.") 34 | 35 | re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] 36 | im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] 37 | 38 | return torch.stack((re, im), dim=-1) 39 | 40 | 41 | def complex_conj(x: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Complex conjugate. 44 | 45 | Applies the complex conjugate assuming that the input array has the 46 | last dimension as the complex dimension. 47 | 48 | Parameters 49 | ---------- 50 | x : torch.Tensor 51 | A PyTorch tensor with the last dimension of size 2. 
52 | 53 | Returns 54 | ------- 55 | torch.Tensor 56 | A PyTorch tensor with the last dimension of size 2, representing 57 | the complex conjugate of the input tensor. 58 | """ 59 | if not x.shape[-1] == 2: 60 | raise ValueError("Tensor does not have separate complex dim.") 61 | 62 | return torch.stack((x[..., 0], -x[..., 1]), dim=-1) 63 | 64 | 65 | def complex_abs(data: torch.Tensor) -> torch.Tensor: 66 | """ 67 | Compute the absolute value of a complex-valued input tensor. 68 | 69 | Parameters 70 | ---------- 71 | data : torch.Tensor 72 | A complex-valued tensor, where the size of the final dimension 73 | should be 2. 74 | 75 | Returns 76 | ------- 77 | torch.Tensor 78 | Absolute value of the input tensor. 79 | """ 80 | if not data.shape[-1] == 2: 81 | raise ValueError("Tensor does not have separate complex dim.") 82 | 83 | return (data**2).sum(dim=-1).sqrt() 84 | 85 | 86 | def complex_abs_sq(data: torch.Tensor) -> torch.Tensor: 87 | """ 88 | Compute the squared absolute value of a complex tensor. 89 | 90 | Parameters 91 | ---------- 92 | data : torch.Tensor 93 | A complex-valued tensor, where the size of the final dimension 94 | should be 2. 95 | 96 | Returns 97 | ------- 98 | torch.Tensor 99 | Squared absolute value of the input tensor. 100 | """ 101 | if not data.shape[-1] == 2: 102 | raise ValueError("Tensor does not have separate complex dim.") 103 | 104 | return (data**2).sum(dim=-1) 105 | 106 | 107 | def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray: 108 | """ 109 | Convert a complex PyTorch tensor to a NumPy array. 110 | 111 | Parameters 112 | ---------- 113 | data : torch.Tensor 114 | Input data to be converted to a NumPy array. 115 | 116 | Returns 117 | ------- 118 | np.ndarray 119 | A complex NumPy array version of the input tensor. 
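A minimal sketch (the trailing dimension of size 2 holds the real and imaginary parts, as elsewhere in this module):

Examples
--------
>>> import torch
>>> t = torch.randn(4, 4, 2)       # real-view complex tensor
>>> arr = tensor_to_complex_np(t)  # complex64 ndarray of shape (4, 4)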
120 | """ 121 | return torch.view_as_complex(data).numpy() 122 | -------------------------------------------------------------------------------- /scripts/gen_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transform SliceDataset into LMDB dataset (SliceDatasetLMDB) 3 | """ 4 | 5 | import argparse 6 | import os 7 | from pathlib import Path 8 | 9 | import lmdb 10 | import numpy as np 11 | import tqdm 12 | 13 | from fastmri.datasets import SliceDataset, SliceSample 14 | 15 | KNEE_COILS = 15 16 | BRAIN_COILS = 16 17 | 18 | 19 | def main(args): 20 | num_coils = KNEE_COILS 21 | contrast = None # don't filter by contrast for knee (there is only 1) 22 | if args.body_part == "brain": 23 | num_coils = BRAIN_COILS 24 | contrast = "T2" 25 | 26 | dataset = SliceDataset( 27 | args.body_part, 28 | partition=args.partition, 29 | complex=False, 30 | sample_rate=args.sample_rate, 31 | crop_shape=(320, 320), 32 | contrast=contrast, 33 | coils=num_coils, 34 | ) 35 | process( 36 | dataset, 37 | num_coils, 38 | args.out_path, 39 | ) 40 | 41 | 42 | def process( 43 | dataset: SliceDataset, 44 | coils: int, 45 | out_dir: Path | str, 46 | n_jobs=-1, 47 | chunk_size=10, 48 | ): 49 | N = len(dataset) 50 | 51 | kspace_arr = np.zeros((N, coils, 320, 320, 2), dtype=np.float32) # type: ignore 52 | rss_arr = np.zeros((N, 320, 320), dtype=np.float32) # type: ignore 53 | meta = [] 54 | 55 | N_actual = 0 56 | for i in tqdm.trange(N): 57 | sample: SliceSample = dataset[i] 58 | if sample == None: 59 | continue 60 | kspace = sample.masked_kspace 61 | target = sample.target # rss targets 62 | max_value = sample.max_value 63 | fname = sample.fname 64 | slice_num = sample.slice_num 65 | 66 | kspace_arr[N_actual] = kspace 67 | rss_arr[N_actual] = target 68 | meta.append((fname, slice_num, max_value)) 69 | 70 | N_actual += 1 71 | 72 | os.makedirs(out_dir, exist_ok=True) 73 | 74 | # Save kspace 75 | kspace_arr = kspace_arr[:N_actual] 76 | env = lmdb.open(f"{out_dir}/kspace", map_size=int(1e12), readahead=False) 77 | save2db(kspace_arr, env, 0) 78 | env.close() 79 | 80 | # Save rss target 81 | rss_arr = rss_arr[:N_actual] 82 | env = lmdb.open(f"{out_dir}/rss", map_size=int(1e12), readahead=False) 83 | save2db(rss_arr, env, 0) 84 | env.close() 85 | 86 | # Save meta 87 | np.save(f"{out_dir}/meta.npy", meta) 88 | 89 | 90 | def save2db(batch_data, env, cur): 91 | """ 92 | Args: 93 | - batch_data (np.ndarray): (batchsize, H, W) 94 | - env (lmdb.Environment): lmdb environment 95 | - cur (int): current index 96 | """ 97 | num_samples = batch_data.shape[0] 98 | with env.begin(write=True) as txn: 99 | for i in range(num_samples): 100 | key = f"{cur + i}".encode() 101 | txn.put(key, batch_data[i]) 102 | return cur + num_samples 103 | 104 | 105 | def parse_args(): 106 | parser = argparse.ArgumentParser( 107 | description="A script to convert SliceDataset samples into lmdb format." 
108 | ) 109 | parser.add_argument( 110 | "--body_part", 111 | "-bp", 112 | type=str, 113 | choices=["brain", "knee"], 114 | required=True, 115 | ) 116 | parser.add_argument( 117 | "--partition", 118 | "-p", 119 | type=str, 120 | choices=["train", "val"], 121 | required=True, 122 | ) 123 | parser.add_argument( 124 | "--out_path", 125 | "-o", 126 | type=str, 127 | required=True, 128 | ) 129 | parser.add_argument( 130 | "--sample_rate", 131 | type=float, 132 | default=1.0, 133 | required=False, 134 | ) 135 | return parser.parse_args() 136 | 137 | 138 | if __name__ == "__main__": 139 | main(parse_args()) 140 | -------------------------------------------------------------------------------- /fastmri/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | from argparse import ArgumentParser 11 | from typing import Optional 12 | 13 | import h5py 14 | import numpy as np 15 | from runstats import Statistics 16 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 17 | 18 | from fastmri import transforms 19 | 20 | 21 | def mse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: 22 | """Compute Mean Squared Error (MSE)""" 23 | return np.mean((gt - pred) ** 2) 24 | 25 | 26 | def nmse(gt: np.ndarray, pred: np.ndarray) -> np.ndarray: 27 | """Compute Normalized Mean Squared Error (NMSE)""" 28 | return np.array(np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2) 29 | 30 | 31 | def psnr( 32 | gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None 33 | ) -> np.ndarray: 34 | """Compute Peak Signal to Noise Ratio metric (PSNR)""" 35 | if maxval is None: 36 | maxval = gt.max() 37 | return peak_signal_noise_ratio(gt, pred, data_range=maxval) 38 | 39 | 40 | def ssim( 41 | gt: np.ndarray, pred: np.ndarray, maxval: Optional[float] = None 42 | ) -> np.ndarray: 43 | """Compute Structural Similarity Index Metric (SSIM)""" 44 | if not gt.ndim == 3: 45 | raise ValueError("Unexpected number of dimensions in ground truth.") 46 | if not gt.ndim == pred.ndim: 47 | raise ValueError("Ground truth dimensions does not match pred.") 48 | 49 | maxval = gt.max() if maxval is None else maxval 50 | 51 | ssim = np.array([0]) 52 | for slice_num in range(gt.shape[0]): 53 | ssim = ssim + structural_similarity( 54 | gt[slice_num], pred[slice_num], data_range=maxval 55 | ) 56 | 57 | return ssim / gt.shape[0] 58 | 59 | 60 | METRIC_FUNCS = dict( 61 | MSE=mse, 62 | NMSE=nmse, 63 | PSNR=psnr, 64 | SSIM=ssim, 65 | ) 66 | 67 | 68 | class Metrics: 69 | """ 70 | Maintains running statistics for a given collection of metrics. 71 | """ 72 | 73 | def __init__(self, metric_funcs): 74 | """ 75 | Parameters 76 | ---------- 77 | metric_funcs : dict 78 | A dictionary where the keys are metric names (as strings) and the values 79 | are Python functions for evaluating the corresponding metrics. 
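A minimal usage sketch with random volumes (illustrative only; inputs are `(num_slices, H, W)` arrays):

Examples
--------
>>> import numpy as np
>>> gt = np.random.rand(3, 64, 64)
>>> pred = gt + 0.01 * np.random.randn(3, 64, 64)
>>> metrics = Metrics(METRIC_FUNCS)
>>> metrics.push(gt, pred)
>>> means = metrics.means()  # dict of running means, keyed by metric name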
80 | """ 81 | 82 | self.metrics = {metric: Statistics() for metric in metric_funcs} 83 | 84 | def push(self, target, recons): 85 | for metric, func in METRIC_FUNCS.items(): 86 | self.metrics[metric].push(func(target, recons)) 87 | 88 | def means(self): 89 | return {metric: stat.mean() for metric, stat in self.metrics.items()} 90 | 91 | def stddevs(self): 92 | return {metric: stat.stddev() for metric, stat in self.metrics.items()} 93 | 94 | def __repr__(self): 95 | means = self.means() 96 | stddevs = self.stddevs() 97 | metric_names = sorted(list(means)) 98 | return " ".join( 99 | f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}" 100 | for name in metric_names 101 | ) 102 | 103 | 104 | def evaluate(args, recons_key): 105 | metrics = Metrics(METRIC_FUNCS) 106 | 107 | for tgt_file in args.target_path.iterdir(): 108 | with ( 109 | h5py.File(tgt_file, "r") as target, 110 | h5py.File(args.predictions_path / tgt_file.name, "r") as recons, 111 | ): 112 | if args.acquisition and args.acquisition != target.attrs["acquisition"]: 113 | continue 114 | 115 | if args.acceleration and target.attrs["acceleration"] != args.acceleration: 116 | continue 117 | 118 | target = target[recons_key][()] 119 | recons = recons["reconstruction"][()] 120 | target = transforms.center_crop( 121 | target, (target.shape[-1], target.shape[-1]) 122 | ) 123 | recons = transforms.center_crop( 124 | recons, (target.shape[-1], target.shape[-1]) 125 | ) 126 | metrics.push(target, recons) 127 | 128 | return metrics 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 133 | parser.add_argument( 134 | "--target-path", 135 | type=pathlib.Path, 136 | required=True, 137 | help="Path to the ground truth data", 138 | ) 139 | parser.add_argument( 140 | "--predictions-path", 141 | type=pathlib.Path, 142 | required=True, 143 | help="Path to reconstructions", 144 | ) 145 | parser.add_argument( 146 | "--challenge", 147 | choices=["singlecoil", "multicoil"], 148 | required=True, 149 | help="Which challenge", 150 | ) 151 | parser.add_argument("--acceleration", type=int, default=None) 152 | parser.add_argument( 153 | "--acquisition", 154 | choices=[ 155 | "CORPD_FBK", 156 | "CORPDFS_FBK", 157 | "AXT1", 158 | "AXT1PRE", 159 | "AXT1POST", 160 | "AXT2", 161 | "AXFLAIR", 162 | ], 163 | default=None, 164 | help=( 165 | "If set, only volumes of the specified acquisition type are used " 166 | "for evaluation. By default, all volumes are included." 167 | ), 168 | ) 169 | args = parser.parse_args() 170 | 171 | recons_key = ( 172 | "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc" 173 | ) 174 | metrics = evaluate(args, recons_key) 175 | print(metrics) 176 | -------------------------------------------------------------------------------- /fastmri/fftc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from typing import List, Optional 9 | 10 | import torch 11 | import torch.fft 12 | 13 | 14 | def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 15 | """ 16 | Apply a centered 2-dimensional Fast Fourier Transform (FFT). 17 | 18 | Parameters 19 | ---------- 20 | data : torch.Tensor 21 | Complex-valued input data containing at least 3 dimensions. 
22 | Dimensions -3 and -2 are spatial dimensions, and dimension -1 has size 2. 23 | All other dimensions are assumed to be batch dimensions. 24 | norm : str 25 | Normalization mode. Refer to `torch.fft.fft` for details on normalization options. 26 | 27 | Returns 28 | ------- 29 | torch.Tensor 30 | The FFT of the input data. 31 | """ 32 | 33 | if not data.shape[-1] == 2: 34 | raise ValueError("Tensor does not have separate complex dim.") 35 | 36 | data = ifftshift(data, dim=[-3, -2]) 37 | data = torch.view_as_real( 38 | torch.fft.fftn(  # type: ignore 39 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 40 | ) 41 | ) 42 | data = fftshift(data, dim=[-3, -2]) 43 | 44 | return data 45 | 46 | 47 | def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 48 | """ 49 | Apply a centered 2-dimensional Inverse Fast Fourier Transform (IFFT). 50 | 51 | Parameters 52 | ---------- 53 | data : torch.Tensor 54 | Complex-valued input data containing at least 3 dimensions. 55 | Dimensions -3 and -2 are spatial dimensions, and dimension -1 has size 2. 56 | All other dimensions are assumed to be batch dimensions. 57 | norm : str 58 | Normalization mode. Refer to `torch.fft.ifft` for details on normalization options. 59 | 60 | Returns 61 | ------- 62 | torch.Tensor 63 | The IFFT of the input data. 64 | """ 65 | 66 | if not data.shape[-1] == 2: 67 | raise ValueError("Tensor does not have separate complex dim.") 68 | 69 | data = ifftshift(data, dim=[-3, -2]) 70 | data = torch.view_as_real( 71 | torch.fft.ifftn(  # type: ignore 72 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 73 | ) 74 | ) 75 | data = fftshift(data, dim=[-3, -2]) 76 | 77 | return data 78 | 79 | 80 | # Helper functions 81 | 82 | 83 | def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor: 84 | """ 85 | Roll a PyTorch tensor along a specified dimension. 86 | 87 | This function is similar to `torch.roll` but operates on a single dimension. 88 | 89 | Parameters 90 | ---------- 91 | x : torch.Tensor 92 | The input tensor to be rolled. 93 | shift : int 94 | Amount to roll. 95 | dim : int 96 | The dimension along which to roll the tensor. 97 | 98 | Returns 99 | ------- 100 | torch.Tensor 101 | A tensor with the same shape as `x`, but rolled along the specified dimension. 102 | """ 103 | 104 | shift = shift % x.size(dim) 105 | if shift == 0: 106 | return x 107 | 108 | left = x.narrow(dim, 0, x.size(dim) - shift) 109 | right = x.narrow(dim, x.size(dim) - shift, shift) 110 | 111 | return torch.cat((right, left), dim=dim) 112 | 113 | 114 | def roll( 115 | x: torch.Tensor, 116 | shift: List[int], 117 | dim: List[int], 118 | ) -> torch.Tensor: 119 | """ 120 | Similar to np.roll but applies to PyTorch Tensors. 121 | 122 | Parameters 123 | ---------- 124 | x : torch.Tensor 125 | A PyTorch tensor. 126 | shift : list of int 127 | Amounts to roll along each dimension. 128 | dim : list of int 129 | Which dimensions to roll. 130 | 131 | Returns 132 | ------- 133 | torch.Tensor 134 | Rolled version of x. 135 | """ 136 | 137 | if len(shift) != len(dim): 138 | raise ValueError("len(shift) must match len(dim)") 139 | 140 | for s, d in zip(shift, dim): 141 | x = roll_one_dim(x, s, d) 142 | 143 | return x 144 | 145 | 146 | def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 147 | """ 148 | Similar to np.fft.fftshift but applies to PyTorch Tensors. 149 | 150 | Parameters 151 | ---------- 152 | x : torch.Tensor 153 | A PyTorch tensor. 154 | dim : list of int, optional 155 | Which dimension to apply fftshift. 
If None, the shift is applied to all dimensions (default is None). 156 | 157 | Returns 158 | ------- 159 | torch.Tensor 160 | fftshifted version of x. 161 | """ 162 | if dim is None: 163 | # this weird code is necessary for torch.jit.script typing 164 | dim = [0] * (x.dim()) 165 | for i in range(1, x.dim()): 166 | dim[i] = i 167 | 168 | # also necessary for torch.jit.script 169 | shift = [0] * len(dim) 170 | for i, dim_num in enumerate(dim): 171 | shift[i] = x.shape[dim_num] // 2 172 | 173 | return roll(x, shift, dim) 174 | 175 | 176 | def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 177 | """ 178 | Similar to np.fft.ifftshift but applies to PyTorch Tensors. 179 | 180 | Parameters 181 | ---------- 182 | x : torch.Tensor 183 | A PyTorch tensor. 184 | dim : list of int, optional 185 | Which dimension to apply ifftshift. If None, the shift is applied to all dimensions (default is None). 186 | 187 | Returns 188 | ------- 189 | torch.Tensor 190 | ifftshifted version of x. 191 | """ 192 | if dim is None: 193 | # this weird code is necessary for torch.jit.script typing 194 | dim = [0] * (x.dim()) 195 | for i in range(1, x.dim()): 196 | dim[i] = i 197 | 198 | # also necessary for torch.jit.script 199 | shift = [0] * len(dim) 200 | for i, dim_num in enumerate(dim): 201 | shift[i] = (x.shape[dim_num] + 1) // 2 202 | 203 | return roll(x, shift, dim) 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2410.16290-b31b1b.svg?style=flat-square&logo=arxiv)](https://arxiv.org/abs/2410.16290) 2 | [![](https://img.shields.io/badge/Blog-armeet.ca%2Fnomri-yellow?style=flat-square)](https://armeet.ca/nomri) 3 | 4 | # A Unified Model for Compressed Sensing MRI Across Undersampling Patterns 5 | 6 | > [**A Unified Model for Compressed Sensing MRI Across Undersampling Patterns**](https://arxiv.org/abs/2410.16290) 7 | > Armeet Singh Jatyani, Jiayun Wang, Aditi Chandrashekar, Zihui Wu, Miguel Liu-Schiaffini, Bahareh Tolooshams, Anima Anandkumar 8 | > *Paper at [CVPR 2025](https://cvpr.thecvf.com/Conferences/2025/AcceptedPapers)* 9 | 10 | [![hugging face](https://github.com/user-attachments/assets/b483ef88-0646-46c0-9dce-941c0fc31fbe)](https://huggingface.co/spaces/armeet/nomri) 11 | 12 | ![intro](https://github.com/user-attachments/assets/79aee2fd-0956-4a05-b6c8-2037618e47b1) 13 | 14 | > _**(a) Unified Model:** NO works across various undersampling patterns, unlike CNNs (e.g., [E2E-VarNet](#)) that need separate models for each._ \ 15 | > _**(b) Consistent Performance:** NO consistently outperforms CNNs, especially for 2× acceleration with a single unrolled cascade._ \ 16 | > _**(c) Resolution-Agnostic:** Maintains fixed kernel size regardless of image resolution, reducing aliasing risks._ \ 17 | > _**(d) Zero-Shot Super-Resolution:** Outperforms CNNs in reconstructing high-res MRIs without retraining._ 18 | 19 | ![super](https://github.com/user-attachments/assets/3675a80e-c05f-4d41-9fdf-531de0576751) 20 | 21 | > _**(a) Zero-Shot Extended FOV:** On 4x Gaussian undersampling, NO achieves higher PSNR and fewer artifacts than E2E-VN, despite both models being trained only on 160 x 160 FOV._ \ 22 | > _**(b) Zero-Shot Super-Resolution in Image Space:** For 2x radial with 640 x 640 input via bilinear upsampling, NO preserves quality while E2E-VN introduces artifacts._ 23 | 24 | ## Requirements 25 | We have tested 
training/inference on the following hardware/software versions; however, there is no reason it shouldn't work on slightly older driver/CUDA versions. 26 | - tested on RTX 4090 and A100 with CUDA 12.4 and NVML/Driver version 550 27 | - Ubuntu 22.04.3 LTS & SUSE Linux Enterprise Server 15 28 | - All Python packages are in `pyproject.toml` (see Setup) 29 | 30 | ## Setup 31 | 32 | We use `uv` for environment setup. It is 10-100x faster than vanilla pip and conda. If you don't have `uv`, please install it from [here](https://docs.astral.sh/uv/getting-started/installation/) (no sudo required). If you're on a Linux environment, you can install it with `curl -LsSf https://astral.sh/uv/install.sh | sh`. Of course, if you would like to use a virtual environment managed by vanilla Python or conda, all packages and their versions are provided in `pyproject.toml` under "dependencies." 33 | 34 | In the root directory, run: 35 | ```bash 36 | uv sync 37 | ``` 38 | 39 | `uv` will create a virtual environment for you and install all packages. 40 | 41 | Then you can activate the environment with: 42 | ```bash 43 | source .venv/bin/activate 44 | ``` 45 | Note that this step is optional: you can run scripts in this venv without activating it by using `uv run python script.py`, or the abbreviated `uv run script.py`. 46 | 47 | Then, to download the pretrained weights, run: 48 | ```bash 49 | uv run scripts/download_weights.py 50 | ``` 51 | This downloads the pretrained weights into the `weights/` directory. 52 | 53 | Finally, to run scripts, make them executable: 54 | ```bash 55 | chmod u+x scripts/* 56 | ``` 57 | 58 | Then you can run any script. For example: 59 | ```bash 60 | ./scripts/knee_multipatt.sh 61 | ``` 62 | 63 | By default, Weights & Biases (wandb) logging is disabled, so scripts will print results to stdout. If you want to visualize results in 64 | Weights & Biases, add your wandb API key at the top of the script. We log image predictions 65 | as well as PSNR, NMSE, and SSIM metrics for each epoch. 66 | ```bash 67 | export WANDB_API_KEY=*************** 68 | ``` 69 | 70 | Before you can begin training/inference, you will need to download and process the dataset. See the "Datasets" section below. 71 | 72 | ## Datasets 73 | 74 | We use the fastMRI dataset, which can be downloaded [here](https://fastmri.med.nyu.edu/). \ 75 | Dataset classes are provided in `fastmri/datasets.py`: 76 | - `SliceDatasetLMDB`: dataset in the significantly faster LMDB format 77 | - `SliceDataset`: dataset class for the original fastMRI dataset 78 | 79 | We convert the raw fastMRI HDF5-formatted samples into LMDB format, which speeds up 80 | training and validation considerably. Once you have downloaded the fastMRI dataset, 81 | run `scripts/gen_lmdb_dataset.py` to convert it into LMDB format. 82 | 83 | ```bash 84 | uv run scripts/gen_lmdb_dataset.py --body_part brain --partition val -o /path/to/lmdb/dataset 85 | ``` 86 | 87 | Do this for every dataset you need: (brain, knee) x (train, val). To choose a smaller subset for faster training/inference, add `--sample_rate 0.Xx`. 88 | 89 | By default we use the LMDB format. If you want to use the original `SliceDataset` class, you can swap out the dataset class in `main.py`, as sketched below. 
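For reference, a minimal sketch of that swap (illustrative only: the exact `SliceDatasetLMDB` constructor arguments depend on your `fastmri.yaml` paths, and the `SliceDataset` arguments below mirror those used in `scripts/gen_lmdb_dataset.py`):

```python
from fastmri.datasets import SliceDataset  # SliceDatasetLMDB is the default

# Original HDF5-backed dataset; arguments mirror scripts/gen_lmdb_dataset.py.
dataset = SliceDataset(
    "knee",                 # body part: "knee" or "brain"
    partition="val",        # "train" or "val"
    complex=False,
    sample_rate=0.1,        # optional: 10% subset for quicker runs
    crop_shape=(320, 320),
)
```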
90 | 91 | Finally, modify your `fastmri.yaml` with the correct dataset paths: 92 | 93 | ```yaml 94 | log_path: /tmp/logs 95 | checkpoint_path: /tmp/checkpoints 96 | 97 | lmdb: 98 | knee_train_path: **/**/**/knee_train_lmdb 99 | knee_val_path: **/**/**/knee_val_lmdb 100 | brain_train_path: **/**/**/brain_train_lmdb 101 | brain_val_path: **/**/**/brain_val_lmdb 102 | ``` 103 | 104 | ## Training and Validation 105 | 106 | `main.py` is used for both training and validation. We follow the original fastMRI repo 107 | and use Lightning. We provide both a plain PyTorch model, `models/no_varnet.py` (if you want 108 | a thinner abstraction), and a Lightning-wrapped `models/lightning/no_varnet_module.py` that 109 | makes distributed training across multiple GPUs easier. 110 | 111 | ## Citation 112 | 113 | If you found our work helpful or used any of our models (UDNO), please cite the following: 114 | ```bibtex 115 | @inproceedings{jatyani2025nomri, 116 | author = {Armeet Singh Jatyani* and Jiayun Wang* and Aditi Chandrashekar and Zihui Wu and Miguel Liu-Schiaffini and Bahareh Tolooshams and Anima Anandkumar}, 117 | title = {A Unified Model for Compressed Sensing MRI Across Undersampling Patterns}, 118 | booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR) Proceedings}, 119 | abbr = {CVPR}, 120 | year = {2025} 121 | } 122 | ``` 123 | 124 | ![paper_preview](https://github.com/user-attachments/assets/7e6adaa5-a5fa-4b68-bd8c-5279f6f643d7) 125 | -------------------------------------------------------------------------------- /torch_harmonics/quadrature.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # 1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | # 2. Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | # 3. Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | # 31 | 32 | import numpy as np 33 | 34 | 35 | def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False): 36 | if (grid != "equidistant") and periodic: 37 | raise ValueError("Periodic grid is only supported on equidistant grids.") 38 | 39 | # compute coordinates 40 | if grid == "equidistant": 41 | xlg, wlg = trapezoidal_weights(n, a=a, b=b, periodic=periodic) 42 | elif grid == "legendre-gauss": 43 | xlg, wlg = legendre_gauss_weights(n, a=a, b=b) 44 | elif grid == "lobatto": 45 | xlg, wlg = lobatto_weights(n, a=a, b=b) 46 | elif grid == "equiangular": 47 | xlg, wlg = clenshaw_curtiss_weights(n, a=a, b=b) 48 | else: 49 | raise ValueError(f"Unknown grid type {grid}") 50 | 51 | return xlg, wlg 52 | 53 | 54 | def _precompute_latitudes(nlat, grid="equiangular"): 55 | r""" 56 | Convenience routine to precompute latitudes 57 | """ 58 | 59 | # compute coordinates 60 | xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False) 61 | 62 | lats = np.flip(np.arccos(xlg)).copy() 63 | wlg = np.flip(wlg).copy() 64 | 65 | return lats, wlg 66 | 67 | 68 | def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False): 69 | r""" 70 | Helper routine which returns equidistant nodes with trapezoidal weights 71 | on the interval [a, b] 72 | """ 73 | 74 | xlg = np.linspace(a, b, n) 75 | wlg = (b - a) / (n - 1) * np.ones(n) 76 | 77 | if not periodic: 78 | wlg[0] *= 0.5 79 | wlg[-1] *= 0.5 80 | 81 | return xlg, wlg 82 | 83 | 84 | def legendre_gauss_weights(n, a=-1.0, b=1.0): 85 | r""" 86 | Helper routine which returns the Legendre-Gauss nodes and weights 87 | on the interval [a, b] 88 | """ 89 | 90 | xlg, wlg = np.polynomial.legendre.leggauss(n) 91 | xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5 92 | wlg = wlg * (b - a) * 0.5 93 | 94 | return xlg, wlg 95 | 96 | 97 | def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100): 98 | r""" 99 | Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights 100 | on the interval [a, b] 101 | """ 102 | 103 | wlg = np.zeros((n,)) 104 | tlg = np.zeros((n,)) 105 | tmp = np.zeros((n,)) 106 | 107 | # Vandermonde Matrix 108 | vdm = np.zeros((n, n)) 109 | 110 | # initialize Chebyshev nodes as first guess 111 | for i in range(n): 112 | tlg[i] = -np.cos(np.pi * i / (n - 1)) 113 | 114 | tmp = 2.0 115 | 116 | for i in range(maxiter): 117 | tmp = tlg 118 | 119 | vdm[:, 0] = 1.0 120 | vdm[:, 1] = tlg 121 | 122 | for k in range(2, n): 123 | vdm[:, k] = ( 124 | (2 * k - 1) * tlg * vdm[:, k - 1] - (k - 1) * vdm[:, k - 2] 125 | ) / k 126 | 127 | tlg = tmp - (tlg * vdm[:, n - 1] - vdm[:, n - 2]) / (n * vdm[:, n - 1]) 128 | 129 | if max(abs(tlg - tmp).flatten()) < tol: 130 | break 131 | 132 | wlg = 2.0 / ((n * (n - 1)) * (vdm[:, n - 1] ** 2)) 133 | 134 | # rescale 135 | tlg = (b - a) * 0.5 * tlg + (b + a) * 0.5 136 | wlg = wlg * (b - a) * 0.5 137 | 138 | return tlg, wlg 139 | 140 | 141 | def clenshaw_curtiss_weights(n, a=-1.0, b=1.0): 142 | r""" 143 | Computation of the Clenshaw-Curtis quadrature nodes and weights. 144 | This implementation follows 145 | 146 | [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018. 
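A small sanity-check sketch (Clenshaw-Curtis weights integrate constants exactly, so on [-1, 1] they should sum to 2 up to floating-point error):

Examples
--------
>>> nodes, weights = clenshaw_curtiss_weights(33, a=-1.0, b=1.0)
>>> total = float(weights.sum())  # ~2.0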
147 | """ 148 | 149 | assert n > 1 150 | 151 | tcc = np.cos(np.linspace(np.pi, 0, n)) 152 | 153 | if n == 2: 154 | wcc = np.array([1.0, 1.0]) 155 | else: 156 | n1 = n - 1 157 | N = np.arange(1, n1, 2) 158 | l = len(N) 159 | m = n1 - l 160 | 161 | v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)]) 162 | v = 0 - v[:-1] - v[-1:0:-1] 163 | 164 | g0 = -np.ones(n1) 165 | g0[l] = g0[l] + n1 166 | g0[m] = g0[m] + n1 167 | g = g0 / (n1**2 - 1 + (n1 % 2)) 168 | wcc = np.fft.ifft(v + g).real 169 | wcc = np.concatenate((wcc, wcc[:1])) 170 | 171 | # rescale 172 | tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5 173 | wcc = wcc * (b - a) * 0.5 174 | 175 | return tcc, wcc 176 | 177 | 178 | def fejer2_weights(n, a=-1.0, b=1.0): 179 | r""" 180 | Computation of the Fejer quadrature nodes and weights. 181 | This implementation follows 182 | 183 | [1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018. 184 | """ 185 | 186 | assert n > 2 187 | 188 | tcc = np.cos(np.linspace(np.pi, 0, n)) 189 | 190 | n1 = n - 1 191 | N = np.arange(1, n1, 2) 192 | l = len(N) 193 | m = n1 - l 194 | 195 | v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)]) 196 | v = 0 - v[:-1] - v[-1:0:-1] 197 | 198 | wcc = np.fft.ifft(v).real 199 | wcc = np.concatenate((wcc, wcc[:1])) 200 | 201 | # rescale 202 | tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5 203 | wcc = wcc * (b - a) * 0.5 204 | 205 | return tcc, wcc 206 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | class Unet(nn.Module): 14 | """ 15 | PyTorch implementation of a U-Net model. 16 | 17 | O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks 18 | for biomedical image segmentation. In International Conference on Medical 19 | image computing and computer-assisted intervention, pages 234–241. 20 | Springer, 2015. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | in_chans: int, 26 | out_chans: int, 27 | chans: int = 32, 28 | num_pool_layers: int = 4, 29 | drop_prob: float = 0.0, 30 | ): 31 | """ 32 | Parameters 33 | ---------- 34 | in_chans : int 35 | Number of channels in the input to the U-Net model. 36 | out_chans : int 37 | Number of channels in the output to the U-Net model. 38 | chans : int, optional 39 | Number of output channels of the first convolution layer. Default is 32. 40 | num_pool_layers : int, optional 41 | Number of down-sampling and up-sampling layers. Default is 4. 42 | drop_prob : float, optional 43 | Dropout probability. Default is 0.0. 
44 | """ 45 | super().__init__() 46 | 47 | self.in_chans = in_chans 48 | self.out_chans = out_chans 49 | self.chans = chans 50 | self.num_pool_layers = num_pool_layers 51 | self.drop_prob = drop_prob 52 | 53 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) 54 | ch = chans 55 | for _ in range(num_pool_layers - 1): 56 | self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) 57 | ch *= 2 58 | self.conv = ConvBlock(ch, ch * 2, drop_prob) 59 | 60 | self.up_conv = nn.ModuleList() 61 | self.up_transpose_conv = nn.ModuleList() 62 | for _ in range(num_pool_layers - 1): 63 | self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) 64 | self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) 65 | ch //= 2 66 | 67 | self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) 68 | self.up_conv.append( 69 | nn.Sequential( 70 | ConvBlock(ch * 2, ch, drop_prob), 71 | nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), 72 | ) 73 | ) 74 | 75 | def forward(self, image: torch.Tensor) -> torch.Tensor: 76 | """ 77 | Parameters 78 | ---------- 79 | image : torch.Tensor 80 | Input 4D tensor of shape `(N, in_chans, H, W)`. 81 | 82 | Returns 83 | ------- 84 | torch.Tensor 85 | Output tensor of shape `(N, out_chans, H, W)`. 86 | """ 87 | stack = [] 88 | output = image 89 | 90 | # apply down-sampling layers 91 | for layer in self.down_sample_layers: 92 | output = layer(output) 93 | stack.append(output) 94 | output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) 95 | 96 | output = self.conv(output) 97 | 98 | # apply up-sampling layers 99 | for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): 100 | downsample_layer = stack.pop() 101 | output = transpose_conv(output) 102 | 103 | # reflect pad on the right/botton if needed to handle odd input dimensions 104 | padding = [0, 0, 0, 0] 105 | if output.shape[-1] != downsample_layer.shape[-1]: 106 | padding[1] = 1 # padding right 107 | if output.shape[-2] != downsample_layer.shape[-2]: 108 | padding[3] = 1 # padding bottom 109 | if torch.sum(torch.tensor(padding)) != 0: 110 | output = F.pad(output, padding, "reflect") 111 | 112 | output = torch.cat([output, downsample_layer], dim=1) 113 | output = conv(output) 114 | 115 | return output 116 | 117 | 118 | class ConvBlock(nn.Module): 119 | """ 120 | A Convolutional Block that consists of two convolution layers each followed by 121 | instance normalization, LeakyReLU activation and dropout. 122 | """ 123 | 124 | def __init__(self, in_chans: int, out_chans: int, drop_prob: float): 125 | """ 126 | Parameters 127 | ---------- 128 | in_chans : int 129 | Number of channels in the input. 130 | out_chans : int 131 | Number of channels in the output. 132 | drop_prob : float 133 | Dropout probability. 134 | """ 135 | super().__init__() 136 | 137 | self.in_chans = in_chans 138 | self.out_chans = out_chans 139 | self.drop_prob = drop_prob 140 | 141 | self.layers = nn.Sequential( 142 | nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), 143 | nn.InstanceNorm2d(out_chans), 144 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 145 | nn.Dropout2d(drop_prob), 146 | nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), 147 | nn.InstanceNorm2d(out_chans), 148 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 149 | nn.Dropout2d(drop_prob), 150 | ) 151 | 152 | def forward(self, image: torch.Tensor) -> torch.Tensor: 153 | """ 154 | Parameters 155 | ---------- 156 | image : ndarray 157 | Input 4D tensor of shape `(N, in_chans, H, W)`. 
158 | 159 | Returns 160 | ------- 161 | torch.Tensor 162 | Output tensor of shape `(N, out_chans, H, W)`. 163 | """ 164 | return self.layers(image) 165 | 166 | 167 | class TransposeConvBlock(nn.Module): 168 | """ 169 | A Transpose Convolutional Block that consists of one transpose convolution 170 | layer followed by instance normalization and LeakyReLU activation. 171 | """ 172 | 173 | def __init__(self, in_chans: int, out_chans: int): 174 | """ 175 | Parameters 176 | ---------- 177 | in_chans : int 178 | Number of channels in the input. 179 | out_chans : int 180 | Number of channels in the output. 181 | """ 182 | super().__init__() 183 | 184 | self.in_chans = in_chans 185 | self.out_chans = out_chans 186 | 187 | self.layers = nn.Sequential( 188 | nn.ConvTranspose2d( 189 | in_chans, out_chans, kernel_size=2, stride=2, bias=False 190 | ), 191 | nn.InstanceNorm2d(out_chans), 192 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 193 | ) 194 | 195 | def forward(self, image: torch.Tensor) -> torch.Tensor: 196 | """ 197 | Parameters 198 | ---------- 199 | image : torch.Tensor 200 | Input 4D tensor of shape `(N, in_chans, H, W)`. 201 | 202 | Returns 203 | ------- 204 | torch.Tensor 205 | Output tensor of shape `(N, out_chans, H*2, W*2)`. 206 | """ 207 | return self.layers(image) 208 | -------------------------------------------------------------------------------- /models/lightning/varnet_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from argparse import ArgumentParser 9 | 10 | import torch 11 | 12 | import fastmri 13 | from fastmri import transforms 14 | from ..varnet import VarNet 15 | 16 | from .mri_module import MriModule 17 | 18 | 19 | class VarNetModule(MriModule): 20 | """ 21 | VarNet training module. 22 | 23 | This can be used to train variational networks from the paper: 24 | 25 | A. Sriram et al. End-to-end variational networks for accelerated MRI 26 | reconstruction. In International Conference on Medical Image Computing and 27 | Computer-Assisted Intervention, 2020. 28 | 29 | which was inspired by the earlier paper: 30 | 31 | K. Hammernik et al. Learning a variational network for reconstruction of 32 | accelerated MRI data. Magnetic Resonance in Medicine, 79(6):3055–3071, 2018. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | num_cascades: int = 12, 38 | pools: int = 4, 39 | chans: int = 18, 40 | sens_pools: int = 4, 41 | sens_chans: int = 8, 42 | lr: float = 0.0003, 43 | lr_step_size: int = 40, 44 | lr_gamma: float = 0.1, 45 | weight_decay: float = 0.0, 46 | **kwargs, 47 | ): 48 | """ 49 | Parameters 50 | ---------- 51 | num_cascades : int 52 | Number of cascades (i.e., layers) for the variational network. 53 | pools : int 54 | Number of downsampling and upsampling layers for the cascade U-Net. 55 | chans : int 56 | Number of channels for the cascade U-Net. 57 | sens_pools : int 58 | Number of downsampling and upsampling layers for the sensitivity map U-Net. 59 | sens_chans : int 60 | Number of channels for the sensitivity map U-Net. 61 | lr : float 62 | Learning rate. 63 | lr_step_size : int 64 | Learning rate step size. 65 | lr_gamma : float 66 | Learning rate gamma decay. 67 | weight_decay : float 68 | Parameter for penalizing weights norm. 
69 | num_sense_lines : int, optional 70 | Number of low-frequency lines to use for sensitivity map computation. 71 | Must be even or `None`. Default `None` will automatically compute the number 72 | from masks. Default behavior may cause some slices to use more low-frequency 73 | lines than others, when used in conjunction with e.g. the EquispacedMaskFunc 74 | defaults. To prevent this, either set `num_sense_lines`, or set 75 | `skip_low_freqs` and `skip_around_low_freqs` to `True` in the EquispacedMaskFunc. 76 | Note that setting this value may lead to undesired behavior when training on 77 | multiple accelerations simultaneously. 78 | """ 79 | super().__init__(**kwargs) 80 | self.save_hyperparameters() 81 | 82 | self.num_cascades = num_cascades 83 | self.pools = pools 84 | self.chans = chans 85 | self.sens_pools = sens_pools 86 | self.sens_chans = sens_chans 87 | self.lr = lr 88 | self.lr_step_size = lr_step_size 89 | self.lr_gamma = lr_gamma 90 | self.weight_decay = weight_decay 91 | 92 | self.varnet = VarNet( 93 | num_cascades=self.num_cascades, 94 | sens_chans=self.sens_chans, 95 | sens_pools=self.sens_pools, 96 | chans=self.chans, 97 | pools=self.pools, 98 | ) 99 | 100 | self.criterion = fastmri.SSIMLoss() 101 | self.num_params = sum(p.numel() for p in self.parameters()) 102 | 103 | def forward(self, masked_kspace, mask, num_low_frequencies): 104 | return self.varnet(masked_kspace, mask, num_low_frequencies) 105 | 106 | def training_step(self, batch, batch_idx): 107 | output = self.forward( 108 | batch.masked_kspace, batch.mask, batch.num_low_frequencies 109 | ) 110 | 111 | target, output = transforms.center_crop_to_smallest(batch.target, output) 112 | loss = self.criterion( 113 | output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value 114 | ) 115 | 116 | self.log("train_loss", loss, on_step=True, on_epoch=True) 117 | self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True) 118 | 119 | return loss 120 | 121 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 122 | dataloaders = self.trainer.val_dataloaders 123 | slug = list(dataloaders.keys())[dataloader_idx] 124 | 125 | 126 | output = self.forward( 127 | batch.masked_kspace, batch.mask, batch.num_low_frequencies 128 | ) 129 | 130 | target, output = transforms.center_crop_to_smallest(batch.target, output) 131 | 132 | loss = self.criterion( 133 | output.unsqueeze(1), 134 | target.unsqueeze(1), 135 | data_range=batch.max_value, 136 | ) 137 | 138 | return { 139 | "slug": slug, 140 | "fname": batch.fname, 141 | "slice_num": batch.slice_num, 142 | "max_value": batch.max_value, 143 | "output": output, 144 | "target": target, 145 | "val_loss": loss, 146 | } 147 | 148 | def configure_optimizers(self): 149 | optim = torch.optim.Adam( 150 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay 151 | ) 152 | scheduler = torch.optim.lr_scheduler.StepLR( 153 | optim, self.lr_step_size, self.lr_gamma 154 | ) 155 | 156 | return [optim], [scheduler] 157 | 158 | @staticmethod 159 | def add_model_specific_args(parent_parser): # pragma: no-cover 160 | """ 161 | Define parameters that only apply to this model 162 | """ 163 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 164 | parser = MriModule.add_model_specific_args(parser) 165 | 166 | # network params 167 | parser.add_argument( 168 | "--num_cascades", 169 | default=12, 170 | type=int, 171 | help="Number of VarNet cascades", 172 | ) 173 | parser.add_argument( 174 | "--pools", 175 | default=4, 176 | type=int, 177 |
help="Number of U-Net pooling layers in VarNet blocks", 178 | ) 179 | parser.add_argument( 180 | "--chans", 181 | default=18, 182 | type=int, 183 | help="Number of channels for U-Net in VarNet blocks", 184 | ) 185 | parser.add_argument( 186 | "--sens_pools", 187 | default=4, 188 | type=int, 189 | help=("Number of pooling layers for sense map estimation U-Net in VarNet"), 190 | ) 191 | parser.add_argument( 192 | "--sens_chans", 193 | default=8, 194 | type=float, 195 | help="Number of channels for sense map estimation U-Net in VarNet", 196 | ) 197 | 198 | # training params (opt) 199 | parser.add_argument( 200 | "--lr", default=0.0003, type=float, help="Adam learning rate" 201 | ) 202 | parser.add_argument( 203 | "--lr_step_size", 204 | default=40, 205 | type=int, 206 | help="Epoch at which to decrease step size", 207 | ) 208 | parser.add_argument( 209 | "--lr_gamma", 210 | default=0.1, 211 | type=float, 212 | help="Extent to which step size should be decreased", 213 | ) 214 | parser.add_argument( 215 | "--weight_decay", 216 | default=0.0, 217 | type=float, 218 | help="Strength of weight decay regularization", 219 | ) 220 | 221 | return parser 222 | -------------------------------------------------------------------------------- /models/lightning/no_varnet_module.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import Tuple 3 | 4 | import fastmri 5 | import torch 6 | from fastmri import transforms 7 | from models.lightning.mri_module import MriModule 8 | from models.no_varnet import NOVarnet 9 | from type_utils import tuple_type 10 | 11 | 12 | class NOVarnetModule(MriModule): 13 | """ 14 | NO-Varnet training module. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | num_cascades: int = 12, 20 | pools: int = 4, 21 | chans: int = 18, 22 | sens_pools: int = 4, 23 | sens_chans: int = 8, 24 | kno_pools: int = 4, 25 | kno_chans: int = 16, 26 | kno_radius_cutoff: float = 0.02, 27 | kno_kernel_shape: Tuple[int, int] = (6, 7), 28 | radius_cutoff: float = 0.02, 29 | kernel_shape: Tuple[int, int] = (6, 7), 30 | in_shape: Tuple[int, int] = (320, 320), 31 | use_dc_term: bool = True, 32 | lr: float = 0.0003, 33 | lr_step_size: int = 40, 34 | lr_gamma: float = 0.1, 35 | weight_decay: float = 0.0, 36 | reduction_method: str = "rss", 37 | skip_method: str = "add", 38 | **kwargs, 39 | ): 40 | """ 41 | Parameters 42 | ---------- 43 | num_cascades : int 44 | Number of cascades (i.e., layers) for the variational network. 45 | pools : int 46 | Number of downsampling and upsampling layers for the cascade U-Net. 47 | chans : int 48 | Number of channels for the cascade U-Net. 49 | sens_pools : int 50 | Number of downsampling and upsampling layers for the sensitivity map U-Net. 51 | sens_chans : int 52 | Number of channels for the sensitivity map U-Net. 53 | lr : float 54 | Learning rate. 55 | lr_step_size : int 56 | Learning rate step size. 57 | lr_gamma : float 58 | Learning rate gamma decay. 59 | weight_decay : float 60 | Parameter for penalizing weights norm. 
61 | """ 62 | super().__init__(**kwargs) 63 | self.save_hyperparameters() 64 | 65 | self.num_cascades = num_cascades 66 | self.pools = pools 67 | self.chans = chans 68 | self.sens_pools = sens_pools 69 | self.sens_chans = sens_chans 70 | self.kno_pools = kno_pools 71 | self.kno_chans = kno_chans 72 | self.kno_radius_cutoff = kno_radius_cutoff 73 | self.kno_kernel_shape = kno_kernel_shape 74 | self.radius_cutoff = radius_cutoff 75 | self.kernel_shape = kernel_shape 76 | self.in_shape = in_shape 77 | self.use_dc_term = use_dc_term 78 | self.lr = lr 79 | self.lr_step_size = lr_step_size 80 | self.lr_gamma = lr_gamma 81 | self.weight_decay = weight_decay 82 | self.reduction_method = reduction_method 83 | self.skip_method = skip_method 84 | 85 | self.model = NOVarnet( 86 | num_cascades=self.num_cascades, 87 | sens_chans=self.sens_chans, 88 | sens_pools=self.sens_pools, 89 | chans=self.chans, 90 | pools=self.pools, 91 | kno_chans=self.kno_chans, 92 | kno_pools=self.kno_pools, 93 | kno_radius_cutoff=self.kno_radius_cutoff, 94 | kno_kernel_shape=self.kno_kernel_shape, 95 | radius_cutoff=radius_cutoff, 96 | kernel_shape=kernel_shape, 97 | in_shape=in_shape, 98 | use_dc_term=use_dc_term, 99 | reduction_method=reduction_method, 100 | skip_method=skip_method, 101 | ) 102 | 103 | self.criterion = fastmri.SSIMLoss() 104 | self.num_params = sum(p.numel() for p in self.parameters()) 105 | 106 | def forward(self, masked_kspace, mask, num_low_frequencies): 107 | return self.model(masked_kspace, mask, num_low_frequencies) 108 | 109 | def training_step(self, batch, batch_idx): 110 | output = self.forward( 111 | batch.masked_kspace, batch.mask, batch.num_low_frequencies 112 | ) 113 | 114 | target, output = transforms.center_crop_to_smallest(batch.target, output) 115 | loss = self.criterion( 116 | output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value 117 | ) 118 | 119 | self.log("train_loss", loss, on_step=True, on_epoch=True) 120 | self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True) 121 | 122 | return loss 123 | 124 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 125 | dataloaders = self.trainer.val_dataloaders 126 | slug = list(dataloaders.keys())[dataloader_idx] 127 | 128 | output = self.forward( 129 | batch.masked_kspace, batch.mask, batch.num_low_frequencies 130 | ) 131 | 132 | target, output = transforms.center_crop_to_smallest(batch.target, output) 133 | 134 | loss = self.criterion( 135 | output.unsqueeze(1), 136 | target.unsqueeze(1), 137 | data_range=batch.max_value, 138 | ) 139 | 140 | return { 141 | "slug": slug, 142 | "fname": batch.fname, 143 | "slice_num": batch.slice_num, 144 | "max_value": batch.max_value, 145 | "output": output, 146 | "target": target, 147 | "val_loss": loss, 148 | } 149 | 150 | def configure_optimizers(self): 151 | optim = torch.optim.Adam( 152 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay 153 | ) 154 | scheduler = torch.optim.lr_scheduler.StepLR( 155 | optim, self.lr_step_size, self.lr_gamma 156 | ) 157 | 158 | return [optim], [scheduler] 159 | 160 | @staticmethod 161 | def add_model_specific_args(parent_parser): 162 | """ 163 | Define parameters that only apply to this model 164 | """ 165 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 166 | parser = MriModule.add_model_specific_args(parser) 167 | 168 | # network params 169 | parser.add_argument( 170 | "--num_cascades", 171 | default=12, 172 | type=int, 173 | help="Number of VarNet cascades", 174 | ) 175 | parser.add_argument( 176 | 
"--pools", 177 | default=4, 178 | type=int, 179 | help="Number of U-Net pooling layers in VarNet blocks", 180 | ) 181 | parser.add_argument( 182 | "--chans", 183 | default=18, 184 | type=int, 185 | help="Number of channels for U-Net in VarNet blocks", 186 | ) 187 | parser.add_argument( 188 | "--sens_pools", 189 | default=4, 190 | type=int, 191 | help=("Number of pooling layers for sense map estimation U-Net in VarNet"), 192 | ) 193 | parser.add_argument( 194 | "--sens_chans", 195 | default=8, 196 | type=float, 197 | help="Number of channels for sense map estimation U-Net in VarNet", 198 | ) 199 | parser.add_argument( 200 | "--kno_pools", 201 | default=4, 202 | type=int, 203 | help=("Number of pooling layers for KNO"), 204 | ) 205 | parser.add_argument( 206 | "--kno_chans", 207 | default=16, 208 | type=int, 209 | help="Number of channels for KNO", 210 | ) 211 | parser.add_argument( 212 | "--kno_radius_cutoff", 213 | default=0.02, 214 | type=float, 215 | required=False, 216 | help="KNO module radius_cutoff", 217 | ) 218 | parser.add_argument( 219 | "--kno_kernel_shape", 220 | default=(6, 7), 221 | type=tuple_type, 222 | required=False, 223 | help="KNO module kernel_shape. Ex: 6,7 (no spaces please)", 224 | ) 225 | parser.add_argument( 226 | "--radius_cutoff", 227 | default=0.02, 228 | type=float, 229 | required=False, 230 | help="DISCO module radius_cutoff", 231 | ) 232 | parser.add_argument( 233 | "--kernel_shape", 234 | default=(6, 7), 235 | type=tuple_type, 236 | required=False, 237 | help="DISCO module kernel_shape. Ex: 6,7 (no spaces please)", 238 | ) 239 | parser.add_argument( 240 | "--in_shape", 241 | default=(640, 320), 242 | type=tuple_type, 243 | required=True, 244 | help="Spatial dimensions of masked_kspace samples. Ex: 320,320 (no spaces)", 245 | ) 246 | parser.add_argument( 247 | "--use_dc_term", 248 | default=True, 249 | type=bool, 250 | help="Whether to use the DC term in the unrolled iterative update step", 251 | ) 252 | 253 | # training params (opt) 254 | parser.add_argument( 255 | "--lr", default=0.0003, type=float, help="Adam learning rate" 256 | ) 257 | parser.add_argument( 258 | "--lr_step_size", 259 | default=40, 260 | type=int, 261 | help="Epoch at which to decrease step size", 262 | ) 263 | parser.add_argument( 264 | "--lr_gamma", 265 | default=0.1, 266 | type=float, 267 | help="Extent to which step size should be decreased", 268 | ) 269 | parser.add_argument( 270 | "--weight_decay", 271 | default=0.0, 272 | type=float, 273 | help="Strength of weight decay regularization", 274 | ) 275 | parser.add_argument( 276 | "--reduction_method", 277 | default="batch", 278 | type=str, 279 | choices=["rss", "batch"], 280 | help="Reduction method used to reduce multi-channel k-space data before inpainting module. 
Read documentation of KNO for more information.", 281 | ) 282 | parser.add_argument( 283 | "--skip_method", 284 | default="replace", 285 | type=str, 286 | choices=["add_inv", "add", "concat", "replace"], 287 | help="Method for skip connection around inpainting module.", 288 | ) 289 | 290 | return parser 291 | -------------------------------------------------------------------------------- /models/udno.py: -------------------------------------------------------------------------------- 1 | """ 2 | U-shaped DISCO Neural Operator 3 | """ 4 | 5 | from typing import Tuple 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from torch_harmonics.convolution import ( 12 | EquidistantDiscreteContinuousConv2d as DISCO2d, 13 | ) 14 | 15 | 16 | class UDNO(nn.Module): 17 | """ 18 | U-shaped DISCO Neural Operator in PyTorch 19 | """ 20 | 21 | def __init__( 22 | self, 23 | in_chans: int, 24 | out_chans: int, 25 | radius_cutoff: float, 26 | chans: int = 32, 27 | num_pool_layers: int = 4, 28 | drop_prob: float = 0.0, 29 | in_shape: Tuple[int, int] = (320, 320), 30 | kernel_shape: Tuple[int, int] = (3, 4), 31 | ): 32 | """ 33 | Parameters 34 | ---------- 35 | in_chans : int 36 | Number of channels in the input to the U-Net model. 37 | out_chans : int 38 | Number of channels in the output to the U-Net model. 39 | radius_cutoff : float 40 | Controls the effective radius of the DISCO kernel. Values are 41 | between 0.0 and 1.0. The radius_cutoff is represented as a proportion 42 | of the normalized input space, to ensure that kernels are resolution 43 | invariant. 44 | chans : int, optional 45 | Number of output channels of the first DISCO layer. Default is 32. 46 | num_pool_layers : int, optional 47 | Number of down-sampling and up-sampling layers. Default is 4. 48 | drop_prob : float, optional 49 | Dropout probability. Default is 0.0. 50 | in_shape : Tuple[int, int] 51 | Shape of the input to the UDNO. This is required to dynamically 52 | compile DISCO kernels for resolution invariance. 53 | kernel_shape : Tuple[int, int], optional 54 | Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3 55 | rings and 4 anisotropic basis functions. Under the hood, each DISCO 56 | kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard 57 | 3x3 convolution kernel. 58 | 59 | Note: This is NOT kernel_size, as under the DISCO framework, 60 | kernels are dynamically compiled to support resolution invariance.
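Examples
--------
A minimal usage sketch; the shapes and hyperparameters below are illustrative, not prescriptive:

>>> model = UDNO(in_chans=2, out_chans=2, radius_cutoff=0.02, chans=32,
...              num_pool_layers=4, in_shape=(320, 320), kernel_shape=(3, 4))
>>> out = model(torch.randn(1, 2, 320, 320))  # -> shape (1, 2, 320, 320)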
61 | """ 62 | super().__init__() 63 | assert len(in_shape) == 2, "Input shape must be 2D" 64 | 65 | self.in_chans = in_chans 66 | self.out_chans = out_chans 67 | self.chans = chans 68 | self.num_pool_layers = num_pool_layers 69 | self.drop_prob = drop_prob 70 | self.in_shape = in_shape 71 | self.kernel_shape = kernel_shape 72 | 73 | self.down_sample_layers = nn.ModuleList( 74 | [ 75 | DISCOBlock( 76 | in_chans, 77 | chans, 78 | radius_cutoff, 79 | drop_prob, 80 | in_shape, 81 | kernel_shape, 82 | ) 83 | ] 84 | ) 85 | ch = chans 86 | shape = (in_shape[0] // 2, in_shape[1] // 2) 87 | radius_cutoff = radius_cutoff * 2 88 | for _ in range(num_pool_layers - 1): 89 | self.down_sample_layers.append( 90 | DISCOBlock( 91 | ch, 92 | ch * 2, 93 | radius_cutoff, 94 | drop_prob, 95 | in_shape=shape, 96 | kernel_shape=kernel_shape, 97 | ) 98 | ) 99 | ch *= 2 100 | shape = (shape[0] // 2, shape[1] // 2) 101 | radius_cutoff *= 2 102 | 103 | # test commit 104 | 105 | self.bottleneck = DISCOBlock( 106 | ch, 107 | ch * 2, 108 | radius_cutoff, 109 | drop_prob, 110 | in_shape=shape, 111 | kernel_shape=kernel_shape, 112 | ) 113 | 114 | self.up = nn.ModuleList() 115 | self.up_transpose = nn.ModuleList() 116 | for _ in range(num_pool_layers - 1): 117 | self.up_transpose.append( 118 | TransposeDISCOBlock( 119 | ch * 2, 120 | ch, 121 | radius_cutoff, 122 | in_shape=shape, 123 | kernel_shape=kernel_shape, 124 | ) 125 | ) 126 | shape = (shape[0] * 2, shape[1] * 2) 127 | radius_cutoff /= 2 128 | self.up.append( 129 | DISCOBlock( 130 | ch * 2, 131 | ch, 132 | radius_cutoff, 133 | drop_prob, 134 | in_shape=shape, 135 | kernel_shape=kernel_shape, 136 | ) 137 | ) 138 | ch //= 2 139 | 140 | self.up_transpose.append( 141 | TransposeDISCOBlock( 142 | ch * 2, 143 | ch, 144 | radius_cutoff, 145 | in_shape=shape, 146 | kernel_shape=kernel_shape, 147 | ) 148 | ) 149 | shape = (shape[0] * 2, shape[1] * 2) 150 | radius_cutoff /= 2 151 | self.up.append( 152 | nn.Sequential( 153 | DISCOBlock( 154 | ch * 2, 155 | ch, 156 | radius_cutoff, 157 | drop_prob, 158 | in_shape=shape, 159 | kernel_shape=kernel_shape, 160 | ), 161 | nn.Conv2d( 162 | ch, self.out_chans, kernel_size=1, stride=1 163 | ), # 1x1 conv is always res-invariant (pixel wise channel transformation) 164 | ) 165 | ) 166 | 167 | def forward(self, image: torch.Tensor) -> torch.Tensor: 168 | """ 169 | Parameters 170 | ---------- 171 | image : torch.Tensor 172 | Input 4D tensor of shape `(N, in_chans, H, W)`. 173 | 174 | Returns 175 | ------- 176 | torch.Tensor 177 | Output tensor of shape `(N, out_chans, H, W)`. 
178 | """ 179 | stack = [] 180 | output = image 181 | 182 | # apply down-sampling layers 183 | for layer in self.down_sample_layers: 184 | output = layer(output) 185 | stack.append(output) 186 | output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) 187 | 188 | output = self.bottleneck(output) 189 | 190 | # apply up-sampling layers 191 | for transpose, disco in zip(self.up_transpose, self.up): 192 | downsample_layer = stack.pop() 193 | output = transpose(output) 194 | 195 | # reflect pad on the right/botton if needed to handle odd input dimensions 196 | padding = [0, 0, 0, 0] 197 | if output.shape[-1] != downsample_layer.shape[-1]: 198 | padding[1] = 1 # padding right 199 | if output.shape[-2] != downsample_layer.shape[-2]: 200 | padding[3] = 1 # padding bottom 201 | if torch.sum(torch.tensor(padding)) != 0: 202 | output = F.pad(output, padding, "reflect") 203 | 204 | output = torch.cat([output, downsample_layer], dim=1) 205 | output = disco(output) 206 | 207 | return output 208 | 209 | 210 | class DISCOBlock(nn.Module): 211 | """ 212 | A DISCO Block that consists of two DISCO layers each followed by 213 | instance normalization, LeakyReLU activation and dropout. 214 | """ 215 | 216 | def __init__( 217 | self, 218 | in_chans: int, 219 | out_chans: int, 220 | radius_cutoff: float, 221 | drop_prob: float, 222 | in_shape: Tuple[int, int], 223 | kernel_shape: Tuple[int, int] = (3, 4), 224 | ): 225 | """ 226 | Parameters 227 | ---------- 228 | in_chans : int 229 | Number of channels in the input. 230 | out_chans : int 231 | Number of channels in the output. 232 | radius_cutoff : float 233 | Control the effective radius of the DISCO kernel. Values are 234 | between 0.0 and 1.0. The radius_cutoff is represented as a proportion 235 | of the normalized input space, to ensure that kernels are resolution 236 | invaraint. 237 | in_shape : Tuple[int] 238 | Unbatched spatial 2D shape of the input to this block. 239 | Rrequired to dynamically compile DISCO kernels for resolution invariance. 240 | kernel_shape : Tuple[int, int], optional 241 | Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3 242 | rings and 4 anisotropic basis functions. Under the hood, each DISCO 243 | kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard 244 | 3x3 convolution kernel. 245 | 246 | Note: This is NOT kernel_size, as under the DISCO framework, 247 | kernels are dynamically compiled to support resolution invariance. 248 | drop_prob : float 249 | Dropout probability. 250 | """ 251 | super().__init__() 252 | 253 | self.in_chans = in_chans 254 | self.out_chans = out_chans 255 | self.drop_prob = drop_prob 256 | 257 | self.layers = nn.Sequential( 258 | DISCO2d( 259 | in_chans, 260 | out_chans, 261 | kernel_shape=kernel_shape, 262 | in_shape=in_shape, 263 | bias=False, 264 | radius_cutoff=radius_cutoff, 265 | padding_mode="constant", 266 | ), 267 | nn.InstanceNorm2d(out_chans), 268 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 269 | nn.Dropout2d(drop_prob), 270 | DISCO2d( 271 | out_chans, 272 | out_chans, 273 | kernel_shape=kernel_shape, 274 | in_shape=in_shape, 275 | bias=False, 276 | radius_cutoff=radius_cutoff, 277 | padding_mode="constant", 278 | ), 279 | nn.InstanceNorm2d(out_chans), 280 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 281 | nn.Dropout2d(drop_prob), 282 | ) 283 | 284 | def forward(self, image: torch.Tensor) -> torch.Tensor: 285 | """ 286 | Parameters 287 | ---------- 288 | image : ndarray 289 | Input 4D tensor of shape `(N, in_chans, H, W)`. 
290 | 291 | Returns 292 | ------- 293 | torch.Tensor 294 | Output tensor of shape `(N, out_chans, H, W)`. 295 | """ 296 | return self.layers(image) 297 | 298 | 299 | class TransposeDISCOBlock(nn.Module): 300 | """ 301 | A transpose DISCO Block that consists of an up-sampling layer followed by a 302 | DISCO layer, instance normalization, and LeakyReLU activation. 303 | """ 304 | 305 | def __init__( 306 | self, 307 | in_chans: int, 308 | out_chans: int, 309 | radius_cutoff: float, 310 | in_shape: Tuple[int, int], 311 | kernel_shape: Tuple[int, int] = (3, 4), 312 | ): 313 | """ 314 | Parameters 315 | ---------- 316 | in_chans : int 317 | Number of channels in the input. 318 | out_chans : int 319 | Number of channels in the output. 320 | radius_cutoff : float 321 | Controls the effective radius of the DISCO kernel. Values are 322 | between 0.0 and 1.0. The radius_cutoff is represented as a proportion 323 | of the normalized input space, to ensure that kernels are resolution 324 | invariant. 325 | in_shape : Tuple[int, int] 326 | Unbatched spatial 2D shape of the input to this block. 327 | Required to dynamically compile DISCO kernels for resolution invariance. 328 | kernel_shape : Tuple[int, int], optional 329 | Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3 330 | rings and 4 anisotropic basis functions. Under the hood, each DISCO 331 | kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard 332 | 3x3 convolution kernel. 333 | 334 | Note: This is NOT kernel_size, as under the DISCO framework, 335 | kernels are dynamically compiled to support resolution invariance. 336 | """ 337 | super().__init__() 338 | 339 | self.in_chans = in_chans 340 | self.out_chans = out_chans 341 | 342 | self.layers = nn.Sequential( 343 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), 344 | DISCO2d( 345 | in_chans, 346 | out_chans, 347 | kernel_shape=kernel_shape, 348 | in_shape=(2 * in_shape[0], 2 * in_shape[1]), 349 | bias=False, 350 | radius_cutoff=(radius_cutoff / 2), 351 | padding_mode="constant", 352 | ), 353 | nn.InstanceNorm2d(out_chans), 354 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 355 | ) 356 | 357 | def forward(self, image: torch.Tensor) -> torch.Tensor: 358 | """ 359 | Parameters 360 | ---------- 361 | image : torch.Tensor 362 | Input 4D tensor of shape `(N, in_chans, H, W)`. 363 | 364 | Returns 365 | ------- 366 | torch.Tensor 367 | Output tensor of shape `(N, out_chans, H*2, W*2)`. 368 | """ 369 | return self.layers(image) 370 | -------------------------------------------------------------------------------- /models/varnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | import os 10 | from typing import List, Optional, Tuple 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | import fastmri 17 | from fastmri import transforms 18 | from models.unet import Unet 19 | 20 | 21 | class NormUnet(nn.Module): 22 | """ 23 | Normalized U-Net model. 24 | 25 | This is the same as a regular U-Net, but with normalization applied to the 26 | input before the U-Net. This keeps the values more numerically stable 27 | during training.
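The normalization (see ``norm``/``unnorm`` below) splits the channels into two groups, the real and imaginary parts, standardizes each group by its own mean and standard deviation, and restores those statistics after the U-Net pass.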
28 | """ 29 | 30 | def __init__( 31 | self, 32 | chans: int, 33 | num_pools: int, 34 | in_chans: int = 2, 35 | out_chans: int = 2, 36 | drop_prob: float = 0.0, 37 | ): 38 | """ 39 | 40 | Initialize the VarNet model. 41 | 42 | Parameters 43 | ---------- 44 | chans : int 45 | Number of output channels of the first convolution layer. 46 | num_pools : int 47 | Number of down-sampling and up-sampling layers. 48 | in_chans : int, optional 49 | Number of channels in the input to the U-Net model. Default is 2. 50 | out_chans : int, optional 51 | Number of channels in the output to the U-Net model. Default is 2. 52 | drop_prob : float, optional 53 | Dropout probability. Default is 0.0. 54 | """ 55 | super().__init__() 56 | 57 | self.unet = Unet( 58 | in_chans=in_chans, 59 | out_chans=out_chans, 60 | chans=chans, 61 | num_pool_layers=num_pools, 62 | drop_prob=drop_prob, 63 | ) 64 | 65 | def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: 66 | b, c, h, w, two = x.shape 67 | assert two == 2 68 | return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w) 69 | 70 | def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: 71 | b, c2, h, w = x.shape 72 | assert c2 % 2 == 0 73 | c = c2 // 2 74 | return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() 75 | 76 | def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 77 | # group norm 78 | b, c, h, w = x.shape 79 | x = x.view(b, 2, c // 2 * h * w) 80 | 81 | mean = x.mean(dim=2).view(b, 2, 1, 1) 82 | std = x.std(dim=2).view(b, 2, 1, 1) 83 | 84 | x = x.view(b, c, h, w) 85 | 86 | return (x - mean) / std, mean, std 87 | 88 | def unnorm( 89 | self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor 90 | ) -> torch.Tensor: 91 | return x * std + mean 92 | 93 | def pad( 94 | self, x: torch.Tensor 95 | ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: 96 | _, _, h, w = x.shape 97 | w_mult = ((w - 1) | 15) + 1 98 | h_mult = ((h - 1) | 15) + 1 99 | w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] 100 | h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] 101 | # TODO: fix this type when PyTorch fixes theirs 102 | # the documentation lies - this actually takes a list 103 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457 104 | # https://github.com/pytorch/pytorch/pull/16949 105 | x = F.pad(x, w_pad + h_pad) 106 | 107 | return x, (h_pad, w_pad, h_mult, w_mult) 108 | 109 | def unpad( 110 | self, 111 | x: torch.Tensor, 112 | h_pad: List[int], 113 | w_pad: List[int], 114 | h_mult: int, 115 | w_mult: int, 116 | ) -> torch.Tensor: 117 | return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] 118 | 119 | def forward(self, x: torch.Tensor) -> torch.Tensor: 120 | if not x.shape[-1] == 2: 121 | raise ValueError("Last dimension must be 2 for complex.") 122 | 123 | # get shapes for unet and normalize 124 | x = self.complex_to_chan_dim(x) 125 | x, mean, std = self.norm(x) 126 | x, pad_sizes = self.pad(x) 127 | 128 | x = self.unet(x) 129 | 130 | # get shapes back and unnormalize 131 | x = self.unpad(x, *pad_sizes) 132 | x = self.unnorm(x, mean, std) 133 | x = self.chan_complex_to_last_dim(x) 134 | 135 | return x 136 | 137 | 138 | class SensitivityModel(nn.Module): 139 | """ 140 | Model for learning sensitivity estimation from k-space data. 141 | 142 | This model applies an IFFT to multichannel k-space data and then a U-Net 143 | to the coil images to estimate coil sensitivities. 
It can be used with the 144 | end-to-end variational network. 145 | 146 | Input: multi-coil k-space data 147 | Output: multi-coil spatial domain sensitivity maps 148 | """ 149 | 150 | def __init__( 151 | self, 152 | chans: int, 153 | num_pools: int, 154 | in_chans: int = 2, 155 | out_chans: int = 2, 156 | drop_prob: float = 0.0, 157 | mask_center: bool = True, 158 | ): 159 | """ 160 | Parameters 161 | ---------- 162 | chans : int 163 | Number of output channels of the first convolution layer. 164 | num_pools : int 165 | Number of down-sampling and up-sampling layers. 166 | in_chans : int, optional 167 | Number of channels in the input to the U-Net model. Default is 2. 168 | out_chans : int, optional 169 | Number of channels in the output to the U-Net model. Default is 2. 170 | drop_prob : float, optional 171 | Dropout probability. Default is 0.0. 172 | mask_center : bool, optional 173 | Whether to mask center of k-space for sensitivity map calculation. 174 | Default is True. 175 | """ 176 | super().__init__() 177 | self.mask_center = mask_center 178 | self.norm_unet = NormUnet( 179 | chans, 180 | num_pools, 181 | in_chans=in_chans, 182 | out_chans=out_chans, 183 | drop_prob=drop_prob, 184 | ) 185 | 186 | def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: 187 | b, c, h, w, comp = x.shape 188 | 189 | return x.view(b * c, 1, h, w, comp), b 190 | 191 | def batch_chans_to_chan_dim( 192 | self, 193 | x: torch.Tensor, 194 | batch_size: int, 195 | ) -> torch.Tensor: 196 | bc, _, h, w, comp = x.shape 197 | c = bc // batch_size 198 | 199 | return x.view(batch_size, c, h, w, comp) 200 | 201 | def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor: 202 | return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1) 203 | 204 | def get_pad_and_num_low_freqs( 205 | self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None 206 | ) -> Tuple[torch.Tensor, torch.Tensor]: 207 | if num_low_frequencies is None or any( 208 | torch.any(t == 0) for t in num_low_frequencies 209 | ): 210 | # get low frequency line locations and mask them out 211 | squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8) 212 | cent = squeezed_mask.shape[1] // 2 213 | # running argmin returns the first non-zero 214 | left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1) 215 | right = torch.argmin(squeezed_mask[:, cent:], dim=1) 216 | num_low_frequencies_tensor = torch.max( 217 | 2 * torch.min(left, right), torch.ones_like(left) 218 | ) # force a symmetric center unless 1 219 | else: 220 | num_low_frequencies_tensor = num_low_frequencies * torch.ones( 221 | mask.shape[0], dtype=mask.dtype, device=mask.device 222 | ) 223 | 224 | pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2 225 | 226 | return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long) 227 | 228 | def forward( 229 | self, 230 | masked_kspace: torch.Tensor, 231 | mask: torch.Tensor, 232 | num_low_frequencies: Optional[int] = None, 233 | ) -> torch.Tensor: 234 | if self.mask_center: 235 | pad, num_low_freqs = self.get_pad_and_num_low_freqs( 236 | mask, num_low_frequencies 237 | ) 238 | masked_kspace = transforms.batched_mask_center( 239 | masked_kspace, pad, pad + num_low_freqs 240 | ) 241 | 242 | # convert to image space 243 | images, batches = self.chans_to_batch_dim(fastmri.ifft2c(masked_kspace)) 244 | 245 | # estimate sensitivities 246 | return self.divide_root_sum_of_squares( 247 | self.batch_chans_to_chan_dim(self.norm_unet(images), batches) 248 | ) 249 | 250 | 251 | class VarNet(nn.Module): 
252 | """ 253 | A full variational network model. 254 | 255 | This model applies a combination of soft data consistency with a U-Net 256 | regularizer. To use non-U-Net regularizers, use VarNetBlock. 257 | 258 | Input: multi-channel k-space data 259 | Output: single-channel RSS reconstructed image 260 | """ 261 | 262 | def __init__( 263 | self, 264 | num_cascades: int = 12, 265 | sens_chans: int = 8, 266 | sens_pools: int = 4, 267 | chans: int = 18, 268 | pools: int = 4, 269 | mask_center: bool = True, 270 | ): 271 | """ 272 | Parameters 273 | ---------- 274 | num_cascades : int 275 | Number of cascades (i.e., layers) for variational network. 276 | sens_chans : int 277 | Number of channels for sensitivity map U-Net. 278 | sens_pools : int 279 | Number of downsampling and upsampling layers for sensitivity map U-Net. 280 | chans : int 281 | Number of channels for cascade U-Net. 282 | pools : int 283 | Number of downsampling and upsampling layers for cascade U-Net. 284 | mask_center : bool 285 | Whether to mask center of k-space for sensitivity map calculation. 286 | """ 287 | 288 | super().__init__() 289 | 290 | self.sens_net = SensitivityModel( 291 | chans=sens_chans, 292 | num_pools=sens_pools, 293 | mask_center=mask_center, 294 | ) 295 | self.cascades = nn.ModuleList( 296 | [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)] 297 | ) 298 | 299 | def forward( 300 | self, 301 | masked_kspace: torch.Tensor, 302 | mask: torch.Tensor, 303 | num_low_frequencies: Optional[int] = None, 304 | ) -> torch.Tensor: 305 | sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) 306 | kspace_pred = masked_kspace.clone() 307 | for cascade in self.cascades: 308 | kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps) 309 | 310 | spatial_pred = fastmri.ifft2c(kspace_pred) 311 | 312 | # ---------> FIXME: CHANGE FOR MVUE MODE 313 | if self.training and os.getenv("MVUE") in ["yes", "1", "true", "True"]: 314 | combined_spatial = fastmri.mvue(spatial_pred, sens_maps, dim=1) 315 | else: 316 | spatial_pred_abs = fastmri.complex_abs(spatial_pred) 317 | combined_spatial = fastmri.rss(spatial_pred_abs, dim=1) 318 | return combined_spatial 319 | 320 | 321 | class VarNetBlock(nn.Module): 322 | """ 323 | Model block for end-to-end variational network (refinemnt module) 324 | 325 | This model applies a combination of soft data consistency with the input 326 | model as a regularizer. A series of these blocks can be stacked to form 327 | the full variational network. 328 | 329 | Input: multi-channel k-space data 330 | Output: multi-channel k-space data 331 | """ 332 | 333 | def __init__(self, model: nn.Module): 334 | """ 335 | Parameters 336 | ---------- 337 | model : nn.Module 338 | Module for "regularization" component of variational network. 
339 | """ 340 | super().__init__() 341 | 342 | self.model = model 343 | self.dc_weight = nn.Parameter(torch.ones(1)) 344 | 345 | def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 346 | """ 347 | Calculates F (x sens_maps) 348 | """ 349 | return fastmri.fft2c(fastmri.complex_mul(x, sens_maps)) 350 | 351 | def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: 352 | """ 353 | Calculates F^{-1}(x) \overline{sens_maps} 354 | where \overline{sens_maps} is the element-wise applied complex conjugate 355 | """ 356 | return fastmri.complex_mul( 357 | fastmri.ifft2c(x), fastmri.complex_conj(sens_maps) 358 | ).sum(dim=1, keepdim=True) 359 | 360 | def forward( 361 | self, 362 | current_kspace: torch.Tensor, 363 | ref_kspace: torch.Tensor, 364 | mask: torch.Tensor, 365 | sens_maps: torch.Tensor, 366 | ) -> torch.Tensor: 367 | """ 368 | Parameters 369 | ---------- 370 | current_kspace : torch.Tensor 371 | The current k-space data (frequency domain data) being processed by the network. 372 | ref_kspace : torch.Tensor 373 | The reference k-space data (measured data) used for data consistency. 374 | mask : torch.Tensor 375 | A binary mask indicating the locations in k-space where data consistency should be enforced. 376 | sens_maps : torch.Tensor 377 | Sensitivity maps for the different coils in parallel imaging. 378 | 379 | Returns 380 | ------- 381 | torch.Tensor 382 | The output k-space data after applying the variational network block. 383 | """ 384 | 385 | """ 386 | Model term: 387 | - Reduces the current k-space data using the sensitivity maps (inverse Fourier transform followed by element-wise multiplication and summation). 388 | - Applies the neural network model to the reduced data. 389 | - Expands the output of the model using the sensitivity maps (element-wise multiplication followed by Fourier transform). 390 | """ 391 | 392 | model_term = self.sens_expand( 393 | self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps 394 | ) 395 | 396 | """ 397 | Soft data consistency term: 398 | - Calculates the difference between current k-space and reference k-space where the mask is true. 399 | - Multiplies this difference by the data consistency weight. 400 | """ 401 | zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace) 402 | soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight 403 | 404 | # with data consistency term (removed for single cascade experiments) 405 | return current_kspace - soft_dc - model_term 406 | -------------------------------------------------------------------------------- /models/lightning/mri_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified for use in 3 | - minified and removed extraneous abstractions 4 | - updated to latest version of lightning 5 | 6 | Copyright (c) Facebook, Inc. and its affiliates. 7 | 8 | This source code is licensed under the MIT license found in the 9 | LICENSE file in the root directory of this source tree. 
10 | """ 11 | 12 | from argparse import ArgumentParser 13 | from collections import defaultdict 14 | from io import BytesIO 15 | 16 | import matplotlib 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | import torch 20 | from PIL import Image 21 | from torchmetrics.metric import Metric 22 | 23 | import lightning as L 24 | 25 | matplotlib.use("Agg") 26 | 27 | from fastmri import evaluate 28 | 29 | 30 | class DistributedMetricSum(Metric): 31 | def __init__(self, dist_sync_on_step=True): 32 | super().__init__(dist_sync_on_step=dist_sync_on_step) 33 | 34 | self.add_state("quantity", default=torch.tensor(0.0), dist_reduce_fx="sum") 35 | 36 | def update(self, batch: torch.Tensor): # type: ignore 37 | self.quantity += batch 38 | 39 | def compute(self): 40 | return self.quantity 41 | 42 | 43 | class MriModule(L.LightningModule): 44 | """ 45 | Abstract super class for deep learning reconstruction models. 46 | 47 | This is a subclass of the LightningModule class from lightning, 48 | with some additional functionality specific to fastMRI: 49 | - Evaluating reconstructions 50 | - Visualization 51 | 52 | To implement a new reconstruction model, inherit from this class and 53 | implement the following methods: 54 | - training_step: Define what happens in one step of training 55 | - validation_step: Define what happens in one step of validation 56 | - test_step: Define what happens in one step of testing 57 | - configure_optimizers: Create and return the optimizers 58 | 59 | Other methods from LightningModule can be overridden as needed. 60 | """ 61 | 62 | def __init__(self, num_log_images: int = 16): 63 | """ 64 | Initialize the MRI module. 65 | 66 | Parameters 67 | ---------- 68 | num_log_images : int, optional 69 | Number of images to log. Defaults to 16. 
70 | """ 71 | super().__init__() 72 | 73 | self.num_log_images = num_log_images 74 | self.val_log_indices = [1, 2, 3, 4, 5] 75 | self.val_batch_results = [] 76 | 77 | self.NMSE = DistributedMetricSum() 78 | self.SSIM = DistributedMetricSum() 79 | self.PSNR = DistributedMetricSum() 80 | self.ValLoss = DistributedMetricSum() 81 | self.TotExamples = DistributedMetricSum() 82 | self.TotSliceExamples = DistributedMetricSum() 83 | 84 | def log_image(self, name, image): 85 | if self.logger is not None: 86 | self.logger.log_image( 87 | key=f"{name}", images=[image], caption=[{self.global_step}] 88 | ) 89 | 90 | def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0): 91 | # breakpoint() 92 | val_logs = outputs 93 | 94 | mse_vals = defaultdict(dict) 95 | target_norms = defaultdict(dict) 96 | ssim_vals = defaultdict(dict) 97 | max_vals = dict() 98 | 99 | for i, fname in enumerate(val_logs["fname"]): 100 | if i == 0 and batch_idx in self.val_log_indices: 101 | key = f"val_images_idx_{batch_idx}" 102 | target = val_logs["target"][i].unsqueeze(0) 103 | output = val_logs["output"][i].unsqueeze(0) 104 | error = torch.abs(target - output) 105 | output = output / output.max() 106 | target = target / target.max() 107 | error = error / error.max() 108 | self.log_image(f"{key}/target", target) 109 | self.log_image(f"{key}/reconstruction", output) 110 | self.log_image(f"{key}/error", error) 111 | slice_num = int(val_logs["slice_num"][i].cpu()) 112 | 113 | maxval = val_logs["max_value"][i].cpu().numpy() 114 | output = val_logs["output"][i].cpu().numpy() 115 | target = val_logs["target"][i].cpu().numpy() 116 | mse_vals[fname][slice_num] = torch.tensor( 117 | evaluate.mse(target, output) 118 | ).view(1) 119 | target_norms[fname][slice_num] = torch.tensor( 120 | evaluate.mse(target, np.zeros_like(target)) 121 | ).view(1) 122 | ssim_vals[fname][slice_num] = torch.tensor( 123 | evaluate.ssim(target[None, ...], output[None, ...], maxval=maxval) 124 | ).view(1) 125 | max_vals[fname] = maxval 126 | 127 | self.val_batch_results.append( 128 | { 129 | "slug": val_logs["slug"], 130 | "val_loss": val_logs["val_loss"], 131 | "mse_vals": dict(mse_vals), 132 | "target_norms": dict(target_norms), 133 | "ssim_vals": dict(ssim_vals), 134 | "max_vals": max_vals, 135 | } 136 | ) 137 | 138 | def on_validation_epoch_end(self): 139 | val_logs = self.val_batch_results 140 | 141 | dataset_metrics = defaultdict( 142 | lambda: { 143 | "losses": [], 144 | "mse_vals": defaultdict(dict), 145 | "target_norms": defaultdict(dict), 146 | "ssim_vals": defaultdict(dict), 147 | "max_vals": dict(), 148 | } 149 | ) 150 | 151 | # use dict updates to handle duplicate slices 152 | for val_log in val_logs: 153 | slug = val_log["slug"] 154 | dataset_metrics[slug]["losses"].append(val_log["val_loss"].view(-1)) 155 | 156 | for k in val_log["mse_vals"].keys(): 157 | dataset_metrics[slug]["mse_vals"][k].update(val_log["mse_vals"][k]) 158 | for k in val_log["target_norms"].keys(): 159 | dataset_metrics[slug]["target_norms"][k].update( 160 | val_log["target_norms"][k] 161 | ) 162 | for k in val_log["ssim_vals"].keys(): 163 | dataset_metrics[slug]["ssim_vals"][k].update(val_log["ssim_vals"][k]) 164 | for k in val_log["max_vals"]: 165 | dataset_metrics[slug]["max_vals"][k] = val_log["max_vals"][k] 166 | 167 | metrics_to_plot = {"psnr": [], "ssim": [], "nmse": []} 168 | slugs = [] 169 | 170 | for slug, metrics_data in dataset_metrics.items(): 171 | mse_vals, target_norms, ssim_vals, max_vals, losses = ( 172 | metrics_data["mse_vals"], 173 | 
metrics_data["target_norms"], 174 | metrics_data["ssim_vals"], 175 | metrics_data["max_vals"], 176 | metrics_data["losses"], 177 | ) 178 | # check to make sure we have all files in all metrics 179 | assert ( 180 | mse_vals.keys() 181 | == target_norms.keys() 182 | == ssim_vals.keys() 183 | == max_vals.keys() 184 | ) 185 | 186 | # apply means across image volumes 187 | metrics = {"nmse": 0, "ssim": 0, "psnr": 0} 188 | metric_values = { 189 | "nmse": [], 190 | "ssim": [], 191 | "psnr": [], 192 | } # to store individual values for std 193 | local_examples = 0 194 | 195 | for fname in mse_vals.keys(): 196 | local_examples = local_examples + 1 197 | mse_val = torch.mean( 198 | torch.cat([v.view(-1) for _, v in mse_vals[fname].items()]) 199 | ) 200 | target_norm = torch.mean( 201 | torch.cat([v.view(-1) for _, v in target_norms[fname].items()]) 202 | ) 203 | nmse = mse_val / target_norm 204 | psnr = 20 * torch.log10( 205 | torch.tensor( 206 | max_vals[fname], 207 | dtype=mse_val.dtype, 208 | device=mse_val.device, 209 | ) 210 | ) - 10 * torch.log10(mse_val) 211 | ssim = torch.mean( 212 | torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()]) 213 | ) 214 | 215 | # Accumulate metric values 216 | metrics["nmse"] += nmse 217 | metrics["psnr"] += psnr 218 | metrics["ssim"] += ssim 219 | 220 | # Store individual metric values for std calculation 221 | metric_values["nmse"].append(nmse) 222 | metric_values["psnr"].append(psnr) 223 | metric_values["ssim"].append(ssim) 224 | 225 | # reduce across ddp via sum 226 | metrics["nmse"] = self.NMSE(metrics["nmse"]) 227 | metrics["ssim"] = self.SSIM(metrics["ssim"]) 228 | metrics["psnr"] = self.PSNR(metrics["psnr"]) 229 | 230 | tot_examples = self.TotExamples(torch.tensor(local_examples)) 231 | val_loss = self.ValLoss(torch.sum(torch.cat(losses))) # type: ignore 232 | tot_slice_examples = self.TotSliceExamples( 233 | torch.tensor(len(losses), dtype=torch.float) 234 | ) 235 | 236 | metrics_to_plot["nmse"].append( 237 | ( 238 | (metrics["nmse"] / tot_examples).item(), 239 | torch.std(torch.stack(metric_values["nmse"])).item(), 240 | ) 241 | ) 242 | metrics_to_plot["psnr"].append( 243 | ( 244 | (metrics["psnr"] / tot_examples).item(), 245 | torch.std(torch.stack(metric_values["psnr"])).item(), 246 | ) 247 | ) 248 | metrics_to_plot["ssim"].append( 249 | ( 250 | (metrics["ssim"] / tot_examples).item(), 251 | torch.std(torch.stack(metric_values["ssim"])).item(), 252 | ) 253 | ) 254 | slugs.append(slug) 255 | 256 | # Log the mean values 257 | self.log( 258 | f"{slug}--validation_loss", 259 | val_loss / tot_slice_examples, 260 | prog_bar=True, 261 | ) 262 | for metric, value in metrics.items(): 263 | self.log(f"{slug}--val_metrics_{metric}", value / tot_examples) 264 | 265 | # Calculate and log the standard deviation for each metric 266 | for metric, values in metric_values.items(): 267 | std_value = torch.std(torch.stack(values)) 268 | self.log(f"{slug}--val_metrics_{metric}_std", std_value) 269 | 270 | # generate graph 271 | # breakpoint() 272 | for metric_name, values in metrics_to_plot.items(): 273 | scores = [val[0] for val in values] 274 | std_devs = [val[1] for val in values] 275 | 276 | plt.figure(figsize=(10, 6)) 277 | plt.bar(slugs, scores, yerr=std_devs, capsize=5) 278 | plt.xlabel("Dataset Slug") 279 | plt.ylabel(f"{metric_name.upper()} Score") 280 | plt.title(f"{metric_name.upper()} per Dataset with Standard Deviation") 281 | plt.xticks(rotation=45) 282 | plt.tight_layout() 283 | 284 | # Save the plot 285 | buf = BytesIO() 286 | plt.savefig(buf, 
format="png") 287 | buf.seek(0) 288 | image = Image.open(buf) 289 | image_array = np.array(image) 290 | self.log_image(f"summary_plot_{metric_name}", image_array) 291 | buf.close() 292 | plt.close() 293 | 294 | def OLD_on_validation_epoch_end(self): 295 | val_logs = self.val_batch_results 296 | 297 | # aggregate losses 298 | losses = [] 299 | mse_vals = defaultdict(dict) 300 | target_norms = defaultdict(dict) 301 | ssim_vals = defaultdict(dict) 302 | max_vals = dict() 303 | 304 | # use dict updates to handle duplicate slices 305 | for val_log in val_logs: 306 | losses.append(val_log["val_loss"].view(-1)) 307 | 308 | for k in val_log["mse_vals"].keys(): 309 | mse_vals[k].update(val_log["mse_vals"][k]) 310 | for k in val_log["target_norms"].keys(): 311 | target_norms[k].update(val_log["target_norms"][k]) 312 | for k in val_log["ssim_vals"].keys(): 313 | ssim_vals[k].update(val_log["ssim_vals"][k]) 314 | for k in val_log["max_vals"]: 315 | max_vals[k] = val_log["max_vals"][k] 316 | 317 | # check to make sure we have all files in all metrics 318 | assert ( 319 | mse_vals.keys() 320 | == target_norms.keys() 321 | == ssim_vals.keys() 322 | == max_vals.keys() 323 | ) 324 | 325 | # apply means across image volumes 326 | metrics = {"nmse": 0, "ssim": 0, "psnr": 0} 327 | local_examples = 0 328 | for fname in mse_vals.keys(): 329 | local_examples = local_examples + 1 330 | mse_val = torch.mean( 331 | torch.cat([v.view(-1) for _, v in mse_vals[fname].items()]) 332 | ) 333 | target_norm = torch.mean( 334 | torch.cat([v.view(-1) for _, v in target_norms[fname].items()]) 335 | ) 336 | metrics["nmse"] = metrics["nmse"] + mse_val / target_norm 337 | metrics["psnr"] = ( 338 | metrics["psnr"] 339 | + 20 340 | * torch.log10( 341 | torch.tensor( 342 | max_vals[fname], 343 | dtype=mse_val.dtype, 344 | device=mse_val.device, 345 | ) 346 | ) 347 | - 10 * torch.log10(mse_val) 348 | ) 349 | metrics["ssim"] = metrics["ssim"] + torch.mean( 350 | torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()]) 351 | ) 352 | 353 | # reduce across ddp via sum 354 | metrics["nmse"] = self.NMSE(metrics["nmse"]) 355 | metrics["ssim"] = self.SSIM(metrics["ssim"]) 356 | metrics["psnr"] = self.PSNR(metrics["psnr"]) 357 | 358 | tot_examples = self.TotExamples(torch.tensor(local_examples)) 359 | val_loss = self.ValLoss(torch.sum(torch.cat(losses))) 360 | tot_slice_examples = self.TotSliceExamples( 361 | torch.tensor(len(losses), dtype=torch.float) 362 | ) 363 | 364 | self.log("validation_loss", val_loss / tot_slice_examples, prog_bar=True) 365 | for metric, value in metrics.items(): 366 | self.log(f"val_metrics_{metric}", value / tot_examples) 367 | 368 | @staticmethod 369 | def add_model_specific_args(parent_parser): # pragma: no-cover 370 | """ 371 | Define parameters that only apply to this model 372 | """ 373 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 374 | 375 | # logging params 376 | parser.add_argument( 377 | "--num_log_images", 378 | default=16, 379 | type=int, 380 | help="Number of images to log to Tensorboard", 381 | ) 382 | 383 | return parser 384 | -------------------------------------------------------------------------------- /fastmri/datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | import xml.etree.ElementTree as etree 3 | from pathlib import Path 4 | from typing import ( 5 | Any, 6 | Callable, 7 | Dict, 8 | List, 9 | Literal, 10 | NamedTuple, 11 | Optional, 12 | Sequence, 13 | Tuple, 14 | ) 15 | 16 | import h5py 17 | import 
lmdb 18 | import numpy as np 19 | import pandas as pd 20 | import torch 21 | import yaml 22 | 23 | import fastmri 24 | import fastmri.transforms as T 25 | 26 | 27 | class RawSample(NamedTuple): 28 | fname: Path 29 | slice_num: int 30 | metadata: Dict[str, Any] 31 | 32 | 33 | class SliceSample(NamedTuple): 34 | masked_kspace: torch.Tensor 35 | mask: torch.Tensor 36 | num_low_frequencies: int 37 | target: torch.Tensor 38 | max_value: float 39 | # attrs: Dict[str, Any] 40 | fname: str 41 | slice_num: int 42 | 43 | 44 | class SliceSampleMVUE(NamedTuple): 45 | masked_kspace: torch.Tensor 46 | mask: torch.Tensor 47 | num_low_frequencies: int 48 | target: torch.Tensor 49 | rss: torch.Tensor 50 | max_value: float 51 | # attrs: Dict[str, Any] 52 | fname: str 53 | slice_num: int 54 | 55 | 56 | def et_query( 57 | root: etree.Element, 58 | qlist: Sequence[str], 59 | namespace: str = "http://www.ismrm.org/ISMRMRD", 60 | ) -> str: 61 | """ 62 | Query an XML document using ElementTree. 63 | 64 | This function allows querying an XML document by specifying a root and a list of nested queries. 65 | It supports optional XML namespaces. 66 | 67 | Parameters 68 | ---------- 69 | root : ElementTree.Element 70 | The root element of the XML to search through. 71 | qlist : list of str 72 | A list of strings for nested searches, e.g., ["Encoding", "matrixSize"]. 73 | namespace : str, optional 74 | An optional XML namespace to prepend to the query (defaults to the ISMRMRD namespace). 75 | 76 | Returns 77 | ------- 78 | str 79 | The retrieved data as a string. 80 | """ 81 | 82 | s = "." 83 | prefix = "ismrmrd_namespace" 84 | 85 | ns = {prefix: namespace} 86 | 87 | for el in qlist: 88 | s = s + f"//{prefix}:{el}" 89 | 90 | value = root.find(s, ns) 91 | if value is None: 92 | raise RuntimeError("Element not found") 93 | 94 | return str(value.text) 95 | 96 | 97 | class SliceDataset(torch.utils.data.Dataset): 98 | """ 99 | A simplified PyTorch Dataset that provides access to multicoil MR image 100 | slices from the fastMRI dataset. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | # root: Optional[Path | str], 106 | body_part: Literal["knee", "brain"], 107 | partition: Literal["train", "val", "test"], 108 | mask_fns: Optional[List[Callable]] = None, 109 | sample_rate: float = 1.0, 110 | complex: bool = False, 111 | crop_shape: Tuple[int, int] = (320, 320), 112 | slug: str = "", 113 | contrast: Optional[Literal["T1", "T2"]] = None, 114 | coils: Optional[int] = None, 115 | ): 116 | """ 117 | Initializes the fastMRI multi-coil challenge dataset. 118 | 119 | Samples are individual 2D slices taken from k-space volume data. 120 | 121 | Parameters 122 | ---------- 123 | body_part : {'knee', 'brain'} 124 | The body part to analyze. 125 | partition : {'train', 'val', 'test'} 126 | The data partition type. 127 | mask_fns : list of callable, optional 128 | A list of masking functions to apply to samples. 129 | If multiple are given, a mask is randomly chosen for each sample. 130 | sample_rate : float, optional 131 | Fraction of data to sample, by default 1.0. 132 | complex : bool, optional 133 | Whether the k-space data should be returned complex-valued, by default False. 134 | If True, kspace values will be complex. 135 | If False, kspace values will be real-valued with a trailing dimension of size 2. 136 | crop_shape : tuple of two ints, optional 137 | The shape to center crop the k-space data, by default (320, 320). 138 | slug : str 139 | Dataset slug name. 140 | contrast : {'T1', 'T2'} 141 | If body_part is 'brain', the contrast of images to use.
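coils : int, optional
    If given, keep only the first `coils` coil channels; slices with fewer coils are skipped (``__getitem__`` returns None).

Examples
--------
A sketch of typical construction; the mask settings below are illustrative, and paths are resolved from ``fastmri.yaml``:

>>> from fastmri.subsample import create_mask_for_mask_type
>>> mask_fn = create_mask_for_mask_type("equispaced", [0.08], [4])
>>> dataset = SliceDataset(body_part="knee", partition="val",
...                        mask_fns=[mask_fn], sample_rate=0.1)
>>> sample = dataset[0]  # SliceSample, or None if the slice fails to load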
142 | """ 143 | 144 | with open("fastmri.yaml", "r") as file: 145 | config = yaml.safe_load(file) 146 | self.contrast = contrast 147 | self.slug = slug 148 | self.partition = partition 149 | self.body_part = body_part 150 | self.root = Path(config.get(f"{body_part}_path")) / f"multicoil_{partition}" 151 | self.mask_fns = mask_fns 152 | self.sample_rate = sample_rate 153 | self.raw_samples: List[RawSample] = self._load_samples() 154 | self.complex = complex 155 | self.crop_shape = crop_shape 156 | self.coils = coils 157 | 158 | def _load_samples(self): 159 | # Gather all files in the root directory 160 | if self.body_part == "brain" and self.contrast: 161 | files = list(self.root.glob(f"*{self.contrast}*.h5")) 162 | else: 163 | files = list(self.root.glob("*.h5")) 164 | raw_samples = [] 165 | 166 | # Load and process metadata from each file 167 | for fname in sorted(files): 168 | with h5py.File(fname, "r") as hf: 169 | metadata, num_slices = self._retrieve_metadata(fname) 170 | 171 | # Collect samples for each slice, discard first c slices, and last c slices 172 | c = 6 173 | for slice_num in range(num_slices): 174 | if c <= slice_num <= num_slices - c - 1: 175 | raw_samples.append(RawSample(fname, slice_num, metadata)) 176 | 177 | # Subsample if desired 178 | if self.sample_rate < 1.0: 179 | raw_samples = random.sample( 180 | raw_samples, int(len(raw_samples) * self.sample_rate) 181 | ) 182 | 183 | return raw_samples 184 | 185 | def _retrieve_metadata(self, fname): 186 | with h5py.File(fname, "r") as hf: 187 | et_root = etree.fromstring(hf["ismrmrd_header"][()]) 188 | 189 | enc = ["encoding", "encodedSpace", "matrixSize"] 190 | enc_size = ( 191 | int(et_query(et_root, enc + ["x"])), 192 | int(et_query(et_root, enc + ["y"])), 193 | int(et_query(et_root, enc + ["z"])), 194 | ) 195 | rec = ["encoding", "reconSpace", "matrixSize"] 196 | recon_size = ( 197 | int(et_query(et_root, rec + ["x"])), 198 | int(et_query(et_root, rec + ["y"])), 199 | int(et_query(et_root, rec + ["z"])), 200 | ) 201 | 202 | lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] 203 | enc_limits_center = int(et_query(et_root, lims + ["center"])) 204 | enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 205 | 206 | padding_left = enc_size[1] // 2 - enc_limits_center 207 | padding_right = padding_left + enc_limits_max 208 | 209 | num_slices = hf["kspace"].shape[0] 210 | 211 | metadata = { 212 | "padding_left": padding_left, 213 | "padding_right": padding_right, 214 | "encoding_size": enc_size, 215 | "recon_size": recon_size, 216 | **hf.attrs, 217 | } 218 | 219 | return metadata, num_slices 220 | 221 | def __len__(self): 222 | return len(self.raw_samples) 223 | 224 | def __getitem__(self, idx) -> SliceSample: 225 | try: 226 | raw_sample: RawSample = self.raw_samples[idx] 227 | fname, slice_num, metadata = raw_sample 228 | 229 | # load kspace and target 230 | with h5py.File(fname, "r") as hf: 231 | kspace = torch.tensor(hf["kspace"][()][slice_num]) 232 | if not self.complex: 233 | kspace = torch.view_as_real(kspace) 234 | if self.coils: 235 | if kspace.shape[0] < self.coils: 236 | return None 237 | kspace = kspace[: self.coils, :, :, :] 238 | target_key = ( 239 | "reconstruction_rss" 240 | if self.partition in ["train", "val"] 241 | else "reconstruction_esc" 242 | ) 243 | target = hf.get(target_key, None) 244 | if target is not None: 245 | target = torch.tensor(target[()][slice_num]) 246 | if self.body_part == "brain": 247 | target = T.center_crop(target, self.crop_shape) 248 | 249 | # center crop to 
enable collating for batching
250 |             if self.complex:
251 |                 # if complex, crop across dims: -2 and -1 (last 2)
252 |                 raise NotImplementedError("Not implemented for complex native")
253 |             else:
254 |                 # crop in image space, to not lose high-frequency information
255 |                 image = fastmri.ifft2c(kspace)
256 |                 image_cropped = T.complex_center_crop(image, self.crop_shape)
257 |                 kspace = fastmri.fft2c(image_cropped)
258 | 
259 |             # apply transform mask if there is one
260 |             if self.mask_fns:
261 |                 # choose a random mask
262 |                 mask_fn = random.choice(self.mask_fns)
263 |                 kspace, mask, num_low_frequencies = T.apply_mask(
264 |                     kspace,
265 |                     mask_fn,
266 |                     # seed=seed,
267 |                 )
268 |                 mask = mask.bool()
269 |             else:
270 |                 mask = torch.ones_like(kspace, dtype=torch.bool)
271 |                 num_low_frequencies = 0
272 |             sample = SliceSample(
273 |                 kspace,
274 |                 mask,
275 |                 num_low_frequencies,
276 |                 target,
277 |                 metadata["max"],
278 |                 fname.name,
279 |                 slice_num,
280 |             )
281 |             return sample
282 |         except Exception:  # failed slices yield None so they can be filtered out
283 |             return None
284 | 
285 | 
286 | class SliceDatasetLMDB(torch.utils.data.Dataset):
287 |     """
288 |     A simplified PyTorch Dataset that provides access to multicoil MR image
289 |     slices from the fastMRI dataset. Loads from LMDB saved samples.
290 |     """
291 | 
292 |     def __init__(
293 |         self,
294 |         body_part: Literal["knee", "brain"],
295 |         partition: Literal["train", "val", "test"],
296 |         root: Optional[Path | str] = None,
297 |         mask_fns: Optional[List[Callable]] = None,
298 |         sample_rate: float = 1.0,
299 |         complex: bool = False,
300 |         crop_shape: Tuple[int, int] = (320, 320),
301 |         slug: str = "",
302 |         coils: int = 15,
303 |     ):
304 |         """
305 |         Initializes the fastMRI multi-coil challenge dataset.
306 | 
307 |         Samples are individual 2D slices taken from k-space volume data.
308 | 
309 |         Parameters
310 |         ----------
311 |         body_part : {'knee', 'brain'}
312 |             The body part to analyze.
313 |         root : Path or str, optional
314 |             Root of the LMDB dataset. If not provided, the root is automatically
315 |             loaded directly from the fastmri.yaml config.
316 |         partition : {'train', 'val', 'test'}
317 |             The data partition type.
318 |         mask_fns : list of callable, optional
319 |             A list of masking functions to apply to samples.
320 |             If multiple are given, a mask is randomly chosen for each sample.
321 |         sample_rate : float, optional
322 |             Fraction of data to sample, by default 1.0.
323 |         complex : bool, optional
324 |             Whether the k-space data should be returned complex-valued, by default False.
325 |             If True, kspace values will be complex.
326 |             If False, kspace values will be real with a trailing dimension of size 2.
327 |         crop_shape : tuple of two ints, optional
328 |             The shape to center crop the k-space data, by default (320, 320).
329 |         slug : str
330 |             Dataset slug name.
331 |         """
332 | 
333 |         # set attrs
334 |         self.coils = coils
335 |         self.slug = slug
336 |         self.partition = partition
337 |         self.mask_fns = mask_fns
338 |         self.sample_rate = sample_rate
339 |         self.complex = complex
340 |         self.crop_shape = crop_shape
341 | 
342 |         # load lmdb info
343 |         if root:
344 |             if isinstance(root, str):
345 |                 root = Path(root)
346 |             assert root.exists(), "Provided root doesn't exist."
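            # note: the LMDB environments opened below are read-only with
            # lock=False / create=False, so multiple dataloader workers can
            # safely share them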
347 | self.root = root 348 | else: 349 | with open("fastmri.yaml", "r") as file: 350 | config = yaml.safe_load(file) 351 | self.root = Path(config["lmdb"][f"{body_part}_{partition}_path"]) 352 | self.meta = np.load(self.root / "meta.npy") 353 | self.kspace_env = lmdb.open( 354 | str(self.root / "kspace"), 355 | readonly=True, 356 | lock=False, 357 | create=False, 358 | ) 359 | self.kspace_txn = self.kspace_env.begin(write=False) 360 | self.rss_env = lmdb.open( 361 | str(self.root / "rss"), 362 | readonly=True, 363 | lock=False, 364 | create=False, 365 | ) 366 | self.rss_txn = self.rss_env.begin(write=False) 367 | self.length = self.kspace_txn.stat()["entries"] 368 | 369 | def __len__(self): 370 | return int(self.sample_rate * self.length) 371 | 372 | def __getitem__(self, idx) -> SliceSample: 373 | idx_key = str(idx).encode("utf-8") 374 | 375 | # load sample data 376 | kspace = torch.from_numpy( 377 | np.frombuffer(self.kspace_txn.get(idx_key), dtype=np.float32) 378 | .reshape(self.coils, 320, 320, 2) 379 | .copy() 380 | ) 381 | rss = torch.from_numpy( 382 | np.frombuffer(self.rss_txn.get(idx_key), dtype=np.float32) 383 | .reshape(320, 320) 384 | .copy() 385 | ) 386 | 387 | # crop in image space, to not lose high-frequency information 388 | if self.crop_shape and self.crop_shape != (320, 320): 389 | image = fastmri.ifft2c(kspace) 390 | image_cropped = T.complex_center_crop(image, self.crop_shape) 391 | kspace = fastmri.fft2c(image_cropped) 392 | rss = T.center_crop(rss, self.crop_shape) 393 | 394 | # load and apply mask 395 | if self.mask_fns: 396 | # choose a random mask 397 | mask_fn = random.choice(self.mask_fns) 398 | kspace, mask, num_low_frequencies = T.apply_mask( 399 | kspace, 400 | mask_fn, # type: ignore 401 | ) 402 | mask = mask.bool() 403 | else: 404 | mask = torch.ones_like(kspace, dtype=torch.bool) 405 | num_low_frequencies = 0 406 | 407 | # load metadata 408 | fname, slice_num, max_value = self.meta[idx] 409 | fname = str(fname) 410 | slice_num = int(slice_num) 411 | max_value = float(max_value) 412 | 413 | return SliceSample( 414 | kspace, 415 | mask, 416 | num_low_frequencies, 417 | rss, 418 | max_value, 419 | fname, 420 | slice_num, 421 | ) 422 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import itertools 3 | import logging 4 | import os 5 | import random 6 | import sys 7 | import traceback 8 | import uuid 9 | from argparse import ArgumentParser 10 | from pathlib import Path 11 | 12 | import lightning as L 13 | import lightning.pytorch as pl 14 | import torch 15 | import yaml 16 | from lightning.pytorch.loggers import WandbLogger 17 | from lightning.pytorch.utilities import CombinedLoader 18 | from torch.utils.data import DataLoader, Subset 19 | 20 | import wandb 21 | from fastmri.datasets import SliceDatasetLMDB 22 | from fastmri.subsample import create_mask_for_mask_type 23 | from models.lightning.no_varnet_module import NOVarnetModule 24 | from models.lightning.varnet_module import VarNetModule 25 | from type_utils import tuple_type 26 | 27 | SEED = 999 28 | 29 | 30 | def main(run_id, args): 31 | random.seed(SEED) 32 | torch.manual_seed(SEED) 33 | torch.cuda.manual_seed_all(SEED) 34 | pl.seed_everything(SEED) 35 | 36 | logger.info(f"Python version: {sys.version}") 37 | logger.info(f"PyTorch version: {torch.__version__}") 38 | logger.info(f"Lightning version: {pl.__version__}") # type: ignore 39 | logger.info(f"CUDA 
version: {torch.version.cuda}") # type: ignore 40 | logger.info(f"CUDNN version: {torch.backends.cudnn.version()}") 41 | logger.info(f"Wandb version: {wandb.__version__}") 42 | 43 | # torch config 44 | torch.set_float32_matmul_precision("highest") 45 | 46 | # load paths from fastmri.yaml 47 | with open("fastmri.yaml", "r") as f: 48 | fmri_config = yaml.safe_load(f) 49 | 50 | acceleration_to_fractions = { 51 | 1: 1, 52 | 2: 0.16, 53 | 4: 0.08, 54 | 6: 0.06, 55 | 8: 0.04, 56 | 16: 0.02, 57 | 32: 0.01, 58 | } 59 | 60 | # training setting 61 | if args.mode == "train": 62 | exp_train = { 63 | "wandb_name": args.name, 64 | "wandb_tags": [ 65 | "training", 66 | ], 67 | "mask_type": args.train_patterns, 68 | "center_fractions": [ 69 | acceleration_to_fractions[acc] for acc in args.train_accelerations 70 | ], 71 | "accelerations": args.train_accelerations, 72 | "stds": [70], 73 | } 74 | 75 | # create masks 76 | train_mask_fns = [] 77 | for mask in exp_train["mask_type"]: 78 | if mask == "gaussian_2d": 79 | mask = create_mask_for_mask_type( 80 | mask, 81 | exp_train["stds"], 82 | exp_train["accelerations"], 83 | ) 84 | elif mask == "radial_2d": 85 | mask = create_mask_for_mask_type( 86 | mask, 87 | None, 88 | exp_train["accelerations"], 89 | ) 90 | else: 91 | mask = create_mask_for_mask_type( 92 | mask, 93 | exp_train["center_fractions"], 94 | exp_train["accelerations"], 95 | ) 96 | train_mask_fns.append(mask) 97 | 98 | # validation setting 99 | exp_val = { 100 | "val_mask_type": args.val_patterns, 101 | "val_center_fractions": [ 102 | acceleration_to_fractions[acc] for acc in args.val_accelerations 103 | ], 104 | "val_accelerations": args.val_accelerations, 105 | "stds": [70], 106 | } 107 | 108 | val_mask_fns = [] 109 | for pattern, acc in itertools.product( 110 | exp_val["val_mask_type"], exp_val["val_accelerations"] 111 | ): 112 | if pattern == "gaussian_2d": 113 | mask = create_mask_for_mask_type( 114 | pattern, 115 | exp_val["stds"], 116 | [acc], 117 | ) 118 | elif pattern == "radial_2d": 119 | mask = create_mask_for_mask_type( 120 | pattern, 121 | None, # calculate num arms dynamically 122 | [acc], 123 | ) 124 | else: 125 | mask = create_mask_for_mask_type( 126 | pattern, 127 | [acceleration_to_fractions[acc]], 128 | [acc], 129 | ) 130 | val_mask_fns.append((f"{pattern}-{acc}x", mask)) 131 | if args.val_subset: 132 | val_datasets = [ 133 | Subset( 134 | SliceDatasetLMDB( 135 | args.body_part, 136 | partition="val", 137 | mask_fns=[fn], 138 | complex=False, 139 | sample_rate=1.0, 140 | crop_shape=(320, 320), 141 | slug=slug, 142 | coils=16 if args.body_part == "brain" else 15, 143 | ), 144 | args.val_subset, 145 | ) 146 | for slug, fn in val_mask_fns 147 | ] 148 | else: 149 | # for training, use smaller val 150 | val_samplerate = 0.1 if args.mode == "train" else args.sample_rate 151 | val_datasets = [ 152 | SliceDatasetLMDB( 153 | args.body_part, 154 | partition="val", 155 | mask_fns=[fn], 156 | complex=False, 157 | sample_rate=val_samplerate, 158 | crop_shape=(320, 320), 159 | slug=slug, 160 | coils=16 if args.body_part == "brain" else 15, 161 | ) 162 | for slug, fn in val_mask_fns 163 | ] 164 | val_dataloaders = { 165 | ds.slug: DataLoader( # type: ignore 166 | ds, 167 | batch_size=args.batch_size, 168 | shuffle=False, 169 | num_workers=args.num_workers, 170 | pin_memory=True, 171 | ) 172 | for ds in val_datasets 173 | } 174 | combined_val_dataloader = CombinedLoader(val_dataloaders, mode="sequential") 175 | 176 | # datasets & dataloaders 177 | if args.mode == "train": 178 | train_dataset = 
SliceDatasetLMDB( 179 | args.body_part, 180 | partition="train", 181 | mask_fns=train_mask_fns, # type: ignore 182 | complex=False, 183 | sample_rate=args.sample_rate, 184 | crop_shape=args.crop_shape, 185 | coils=16 if args.body_part == "brain" else 15, 186 | ) 187 | train_dataloader = DataLoader( 188 | train_dataset, 189 | batch_size=args.batch_size, 190 | shuffle=True, 191 | num_workers=args.num_workers, 192 | pin_memory=True, 193 | drop_last=True, 194 | ) 195 | 196 | # load model 197 | if args.model == "vn": 198 | module = VarNetModule( 199 | num_cascades=args.num_cascades, 200 | pools=args.pools, 201 | chans=args.chans, 202 | sens_pools=args.sens_pools, 203 | sens_chans=args.sens_chans, 204 | lr=args.lr, 205 | lr_step_size=args.lr_step_size, 206 | lr_gamma=args.lr_gamma, 207 | weight_decay=args.weight_decay, 208 | ) 209 | elif args.model == "no_vn": 210 | module = NOVarnetModule( 211 | num_cascades=args.num_cascades, 212 | pools=args.pools, 213 | chans=args.chans, 214 | sens_pools=args.sens_pools, 215 | sens_chans=args.sens_chans, 216 | kno_chans=args.kno_chans, 217 | kno_pools=args.kno_pools, 218 | kno_radius_cutoff=args.kno_radius_cutoff, 219 | kno_kernel_shape=args.kno_kernel_shape, 220 | radius_cutoff=args.radius_cutoff, 221 | kernel_shape=args.kernel_shape, 222 | in_shape=args.in_shape, 223 | use_dc_term=args.use_dc_term, 224 | lr=args.lr, 225 | lr_step_size=args.lr_step_size, 226 | lr_gamma=args.lr_gamma, 227 | weight_decay=args.weight_decay, 228 | reduction_method=args.reduction_method, 229 | skip_method=args.skip_method, 230 | ) 231 | else: 232 | raise NotImplementedError("model not implemented!") 233 | 234 | # Init cloud logger (wandb) 235 | if args.no_logs: 236 | wandb_logger = WandbLogger(mode="disabled") 237 | else: 238 | wandb_logger = WandbLogger( 239 | project="no-medical", 240 | log_model=False, 241 | dir=(Path(fmri_config["log_path"])), 242 | entity="armeet-team", # replace this with your wandb team name 243 | name=args.name, 244 | id=os.getenv("SLURM_JOB_ID", str(uuid.uuid4())), 245 | tags=args.wandb_tags, 246 | config={ 247 | **vars(args), 248 | "slurm_job_id": os.getenv("SLURM_JOB_ID", None), 249 | "num_params": f"{module.num_params:,}", 250 | "cuda_available": torch.cuda.is_available(), 251 | "cuda_device_count": torch.cuda.device_count(), 252 | }, 253 | ) 254 | 255 | # callbacks 256 | checkpoint_callback = pl.callbacks.ModelCheckpoint( # type: ignore 257 | dirpath=(Path(fmri_config["checkpoint_path"]) / run_id), 258 | filename="{epoch}", 259 | save_top_k=-1, 260 | every_n_epochs=1, 261 | verbose=True, 262 | ) 263 | 264 | # Trainer 265 | module.strict_loading = False 266 | trainer = L.Trainer( 267 | deterministic=True, 268 | accelerator="gpu", 269 | num_nodes=args.num_nodes, 270 | devices=args.devices, 271 | strategy="ddp" if args.num_nodes > 1 and args.devices > 1 else "auto", 272 | max_epochs=args.max_epochs, 273 | logger=[ 274 | wandb_logger, 275 | ], 276 | callbacks=[ 277 | checkpoint_callback, 278 | ], 279 | ) 280 | 281 | print("RUN_ID", run_id) 282 | if args.mode == "train": 283 | trainer.fit( 284 | model=module, 285 | train_dataloaders=train_dataloader, # type: ignore 286 | val_dataloaders=combined_val_dataloader, 287 | ckpt_path=args.ckpt_path, 288 | ) 289 | elif args.mode == "val": 290 | trainer.validate( 291 | model=module, 292 | dataloaders=combined_val_dataloader, 293 | ckpt_path=args.ckpt_path, 294 | ) 295 | else: 296 | raise ValueError("Invalid mode") 297 | 298 | 299 | def build_args(): 300 | parser = ArgumentParser() 301 | parser.add_argument( 302 | 
"--body_part", 303 | type=str, 304 | required=True, 305 | choices=["knee", "brain"], 306 | help="Whether to use knee or brain dataset", 307 | ) 308 | parser.add_argument( 309 | "--experiment", 310 | type=str, 311 | required=True, 312 | help="Wandb experiment group", 313 | ) 314 | parser.add_argument( 315 | "--ckpt_path", 316 | type=str, 317 | default=None, 318 | required=False, 319 | help="Resume from checkpoint at path", 320 | ) 321 | # trainer args 322 | parser.add_argument( 323 | "--max_epochs", 324 | required=False, 325 | type=int, 326 | default=100, 327 | help="Number of training epochs", 328 | ) 329 | parser.add_argument( 330 | "--num_nodes", 331 | required=False, 332 | type=int, 333 | default=1, 334 | help="Number of training nodes (machines)", 335 | ) 336 | parser.add_argument( 337 | "--devices", 338 | required=False, 339 | type=int, 340 | default=1, 341 | help="Number of training devices (gpus)", 342 | ) 343 | 344 | # script args 345 | parser.add_argument( 346 | "--mode", 347 | required=True, 348 | choices=["train", "val"], 349 | type=str, 350 | help="Mode of operation: train or validation", 351 | ) 352 | parser.add_argument( 353 | "--name", 354 | required=True, 355 | type=str, 356 | help="Wandb exp name", 357 | ) 358 | parser.add_argument( 359 | "--num_workers", 360 | default=2, 361 | type=int, 362 | help="Number of dataloader workers", 363 | ) 364 | parser.add_argument( 365 | "--batch_size", 366 | default=1, 367 | type=int, 368 | help="Batch size for run", 369 | ) 370 | 371 | # model, pattern config 372 | parser.add_argument( 373 | "--model", 374 | required=True, 375 | # choices=("vn", "simple_no", "no_vn"), 376 | type=str, 377 | help="Model architecture to train", 378 | ) 379 | 380 | # data subsampling args 381 | parser.add_argument( 382 | "--sample_rate", 383 | default=1.0, 384 | type=float, 385 | help="Sampling rate for the dataset (between 0.0 and 1.0)", 386 | ) 387 | parser.add_argument( 388 | "--crop_shape", 389 | default=(320, 320), 390 | type=tuple_type, 391 | help="The shape to center crop the k-space data, by default (320, 320).", 392 | ) 393 | parser.add_argument( 394 | "--train_accelerations", 395 | required=False, 396 | choices=(1, 2, 4, 6, 8, 16, 32), 397 | type=int, 398 | nargs="+", 399 | help="List of training accelerations, separated by spaces", 400 | ) 401 | parser.add_argument( 402 | "--train_patterns", 403 | required=False, 404 | nargs="+", 405 | default=[ 406 | "equispaced_fraction", 407 | "magic", 408 | "random", 409 | "gaussian_2d", 410 | "poisson_2d", 411 | "radial_2d", 412 | ], 413 | choices=( 414 | "equispaced_fraction", 415 | "magic", 416 | "random", 417 | "equispaced", 418 | "gaussian_2d", 419 | "poisson_2d", 420 | "radial_2d", 421 | ), 422 | type=str, 423 | help="List of training mask patterns, separated by spaces", 424 | ) 425 | parser.add_argument( 426 | "--val_accelerations", 427 | required=True, 428 | choices=(1, 2, 4, 6, 8, 16, 32), 429 | type=int, 430 | nargs="+", 431 | help="List of validation accelerations, separated by spaces", 432 | ) 433 | parser.add_argument( 434 | "--val_patterns", 435 | default=[ 436 | "equispaced_fraction", 437 | "magic", 438 | "random", 439 | "gaussian_2d", 440 | "poisson_2d", 441 | "radial_2d", 442 | ], 443 | nargs="+", 444 | choices=( 445 | "equispaced_fraction", 446 | "magic", 447 | "random", 448 | "gaussian_2d", 449 | "poisson_2d", 450 | "radial_2d", 451 | ), 452 | type=str, 453 | help="List of validation mask patterns, separated by spaces", 454 | ) 455 | 456 | parser.add_argument( 457 | "--val_subset", 458 | 
nargs="+", 459 | type=int, 460 | help="List of validation sample indices", 461 | ) 462 | 463 | # misc: logging, debugging 464 | parser.add_argument( 465 | "--wandb_tags", 466 | type=str, 467 | nargs="+", 468 | help="List of wandb tags to add, separated by spaces", 469 | ) 470 | 471 | parser.add_argument( 472 | "--no_logs", 473 | action="store_true", 474 | help="Disable logging if this flag is set", 475 | ) 476 | 477 | args, _ = parser.parse_known_args() 478 | modules = [ 479 | (VarNetModule, "vn"), 480 | (NOVarnetModule, "no_vn"), 481 | ] 482 | 483 | for module, model_name in modules: 484 | if args.model == model_name: 485 | parser = module.add_model_specific_args(parser) 486 | 487 | # hyperparams 488 | return parser.parse_args() 489 | 490 | 491 | def config_logger(run_id, file_logging=False): 492 | """ 493 | Configures logging to both the console and a file. 494 | """ 495 | # Create a logger 496 | logger = logging.getLogger() 497 | logger.setLevel(logging.INFO) 498 | 499 | # Define formatter for log messages 500 | formatter = logging.Formatter("%(asctime)s \t %(levelname)s \t %(message)s") 501 | 502 | # Create a handler for console output 503 | console_handler = logging.StreamHandler() 504 | console_handler.setLevel(logging.INFO) 505 | console_handler.setFormatter(formatter) 506 | 507 | # Create a handler for file output 508 | if file_logging: 509 | if not os.path.exists("logs"): 510 | os.makedirs("logs") 511 | file_handler = logging.FileHandler("logs/" + f"{run_id}.log") 512 | file_handler.setLevel(logging.INFO) 513 | file_handler.setFormatter(formatter) 514 | 515 | # Add both handlers to the logger 516 | logger.addHandler(console_handler) 517 | if file_logging: 518 | logger.addHandler(file_handler) # type: ignore 519 | 520 | def handle_uncaught_exception(exc_type, exc_value, exc_traceback): 521 | """ 522 | Logs uncaught exceptions with stack traces. 523 | """ 524 | allow_keyboard_interrupt = False 525 | if allow_keyboard_interrupt and issubclass(exc_type, KeyboardInterrupt): 526 | sys.__excepthook__(exc_type, exc_value, exc_traceback) 527 | return 528 | logger.error( 529 | "Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback) 530 | ) 531 | logger.warning( 532 | "Warning: An error occurred\n" 533 | + "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) 534 | ) 535 | 536 | sys.excepthook = handle_uncaught_exception 537 | 538 | 539 | if __name__ == "__main__": 540 | run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 541 | print("RUN_ID: " + run_id) 542 | args = build_args() 543 | config_logger(run_id) 544 | 545 | # local logger 546 | logger = logging.getLogger() 547 | logger.info("Training started") 548 | main(run_id, args) 549 | logger.info("Training completed") 550 | -------------------------------------------------------------------------------- /torch_harmonics/_disco_convolution.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. 4 | # SPDX-License-Identifier: BSD-3-Clause 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # 1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | # 2. 
Redistributions in binary form must reproduce the above copyright notice,
13 | #    this list of conditions and the following disclaimer in the documentation
14 | #    and/or other materials provided with the distribution.
15 | #
16 | # 3. Neither the name of the copyright holder nor the names of its
17 | #    contributors may be used to endorse or promote products derived from
18 | #    this software without specific prior written permission.
19 | #
20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | #
31 | 
32 | 
33 | import torch
34 | 
35 | # triton will only be available on cuda installations of pytorch
36 | import triton
37 | import triton.language as tl
38 | 
39 | BLOCK_SIZE_BATCH = 4
40 | BLOCK_SIZE_NZ = 8
41 | BLOCK_SIZE_POUT = 8
42 | 
43 | 
44 | @triton.jit
45 | def _disco_s2_contraction_kernel(
46 |     inz_ptr,
47 |     vnz_ptr,
48 |     nnz,
49 |     inz_stride_ii,
50 |     inz_stride_nz,
51 |     vnz_stride,
52 |     x_ptr,
53 |     batch_size,
54 |     nlat_in,
55 |     nlon_in,
56 |     x_stride_b,
57 |     x_stride_t,
58 |     x_stride_p,
59 |     y_ptr,
60 |     kernel_size,
61 |     nlat_out,
62 |     nlon_out,
63 |     y_stride_b,
64 |     y_stride_f,
65 |     y_stride_t,
66 |     y_stride_p,
67 |     pscale,
68 |     backward: tl.constexpr,
69 |     BLOCK_SIZE_BATCH: tl.constexpr,
70 |     BLOCK_SIZE_NZ: tl.constexpr,
71 |     BLOCK_SIZE_POUT: tl.constexpr,
72 | ):
73 |     """
74 |     Kernel for the sparse-dense contraction for the S2 DISCO convolution.
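
    The wrappers below launch this kernel on a grid of
    (cdiv(batch, BLOCK_SIZE_BATCH), 1, cdiv(nlon_out, BLOCK_SIZE_POUT)),
    where batch is the merged batch*channel dimension; the nonzero entries
    of the sparse psi tensor are not parallelized across programs but looped
    over in blocks of BLOCK_SIZE_NZ inside the kernel.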
75 | """ 76 | 77 | pid_batch = tl.program_id(0) 78 | pid_pout = tl.program_id(2) 79 | 80 | # pid_nz should always be 0 as we do not account for larger grids in this dimension 81 | pid_nz = tl.program_id(1) # should be always 0 82 | tl.device_assert(pid_nz == 0) 83 | 84 | # create the pointer block for pout 85 | pout = pid_pout * BLOCK_SIZE_POUT + tl.arange(0, BLOCK_SIZE_POUT) 86 | b = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH) 87 | 88 | # create pointer blocks for the psi datastructure 89 | iinz = tl.arange(0, BLOCK_SIZE_NZ) 90 | 91 | # get the initial pointers 92 | fout_ptrs = inz_ptr + iinz * inz_stride_nz 93 | tout_ptrs = inz_ptr + iinz * inz_stride_nz + inz_stride_ii 94 | tpnz_ptrs = inz_ptr + iinz * inz_stride_nz + 2 * inz_stride_ii 95 | vals_ptrs = vnz_ptr + iinz * vnz_stride 96 | 97 | # iterate in a blocked fashion over the non-zero entries 98 | for offs_nz in range(0, nnz, BLOCK_SIZE_NZ): 99 | # load input output latitude coordinate pairs 100 | fout = tl.load( 101 | fout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1 102 | ) 103 | tout = tl.load( 104 | tout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1 105 | ) 106 | tpnz = tl.load( 107 | tpnz_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1 108 | ) 109 | 110 | # load corresponding values 111 | vals = tl.load( 112 | vals_ptrs + offs_nz * vnz_stride, mask=(offs_nz + iinz < nnz), other=0.0 113 | ) 114 | 115 | # compute the shifted longitude coordinates p+p' to read in a coalesced fashion 116 | tnz = tpnz // nlon_in 117 | pnz = tpnz % nlon_in 118 | 119 | # make sure the value is not out of bounds 120 | tl.device_assert(fout < kernel_size) 121 | tl.device_assert(tout < nlat_out) 122 | tl.device_assert(tnz < nlat_in) 123 | tl.device_assert(pnz < nlon_in) 124 | 125 | # load corresponding portion of the input array 126 | x_ptrs = ( 127 | x_ptr 128 | + tnz[None, :, None] * x_stride_t 129 | + ((pnz[None, :, None] + pout[None, None, :] * pscale) % nlon_in) 130 | * x_stride_p 131 | + b[:, None, None] * x_stride_b 132 | ) 133 | y_ptrs = ( 134 | y_ptr 135 | + fout[None, :, None] * y_stride_f 136 | + tout[None, :, None] * y_stride_t 137 | + (pout[None, None, :] % nlon_out) * y_stride_p 138 | + b[:, None, None] * y_stride_b 139 | ) 140 | 141 | # precompute the mask 142 | mask = ( 143 | (b[:, None, None] < batch_size) and (offs_nz + iinz[None, :, None] < nnz) 144 | ) and (pout[None, None, :] < nlon_out) 145 | 146 | # do the actual computation. Backward is essentially just the same operation with swapped tensors. 147 | if not backward: 148 | x = tl.load(x_ptrs, mask=mask, other=0.0) 149 | y = vals[None, :, None] * x 150 | 151 | # store it to the output array 152 | tl.atomic_add(y_ptrs, y, mask=mask) 153 | else: 154 | y = tl.load(y_ptrs, mask=mask, other=0.0) 155 | x = vals[None, :, None] * y 156 | 157 | # store it to the output array 158 | tl.atomic_add(x_ptrs, x, mask=mask) 159 | 160 | 161 | def _disco_s2_contraction_fwd(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): 162 | """ 163 | Wrapper function for the triton implementation of the efficient DISCO convolution on the sphere. 164 | 165 | Parameters 166 | ---------- 167 | x: torch.Tensor 168 | Input signal on the sphere. Expects a tensor of shape batch_size x channels x nlat_in x nlon_in). 169 | psi : torch.Tensor 170 | Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in). 171 | nlon_out: int 172 | Number of longitude points the output should have. 
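
    Examples
    --------
    A shape-level sketch (``psi`` must be a coalesced sparse COO tensor, and
    ``nlon_in`` must be an integer multiple of ``nlon_out``)::

        x = torch.randn(2, 4, nlat_in, nlon_in, device="cuda")
        # psi: sparse COO, shape (kernel_size, nlat_out, nlat_in * nlon_in)
        y = _disco_s2_contraction_fwd(x, psi, nlon_out)
        # y.shape == (2, 4, kernel_size, nlat_out, nlon_out)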
173 | """ 174 | 175 | # check the shapes of all input tensors 176 | assert len(psi.shape) == 3 177 | assert len(x.shape) == 4 178 | assert psi.is_sparse, "Psi must be a sparse COO tensor" 179 | 180 | # TODO: check that Psi is also coalesced 181 | 182 | # get the dimensions of the problem 183 | kernel_size, nlat_out, n_in = psi.shape 184 | nnz = psi.indices().shape[-1] 185 | batch_size, n_chans, nlat_in, nlon_in = x.shape 186 | assert nlat_in * nlon_in == n_in 187 | 188 | # TODO: check that Psi index vector is of type long 189 | 190 | # make sure that the grid-points of the output grid fall onto the grid points of the input grid 191 | assert nlon_in % nlon_out == 0 192 | pscale = nlon_in // nlon_out 193 | 194 | # to simplify things, we merge batch and channel dimensions 195 | x = x.reshape(batch_size * n_chans, nlat_in, nlon_in) 196 | 197 | # prepare the output tensor 198 | y = torch.zeros( 199 | batch_size * n_chans, 200 | kernel_size, 201 | nlat_out, 202 | nlon_out, 203 | device=x.device, 204 | dtype=x.dtype, 205 | ) 206 | 207 | # determine the grid for the computation 208 | grid = ( 209 | triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH), 210 | 1, 211 | triton.cdiv(nlon_out, BLOCK_SIZE_POUT), 212 | ) 213 | 214 | # launch the kernel 215 | _disco_s2_contraction_kernel[grid]( 216 | psi.indices(), 217 | psi.values(), 218 | nnz, 219 | psi.indices().stride(-2), 220 | psi.indices().stride(-1), 221 | psi.values().stride(-1), 222 | x, 223 | batch_size * n_chans, 224 | nlat_in, 225 | nlon_in, 226 | x.stride(0), 227 | x.stride(-2), 228 | x.stride(-1), 229 | y, 230 | kernel_size, 231 | nlat_out, 232 | nlon_out, 233 | y.stride(0), 234 | y.stride(1), 235 | y.stride(-2), 236 | y.stride(-1), 237 | pscale, 238 | False, 239 | BLOCK_SIZE_BATCH, 240 | BLOCK_SIZE_NZ, 241 | BLOCK_SIZE_POUT, 242 | ) 243 | 244 | # reshape y back to expose the correct dimensions 245 | y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out) 246 | 247 | return y 248 | 249 | 250 | def _disco_s2_contraction_bwd(grad_y: torch.Tensor, psi: torch.Tensor, nlon_in: int): 251 | """ 252 | Backward pass for the triton implementation of the efficient DISCO convolution on the sphere. 253 | 254 | Parameters 255 | ---------- 256 | grad_y: torch.Tensor 257 | Input gradient on the sphere. Expects a tensor of shape batch_size x channels x kernel_size x nlat_out x nlon_out. 258 | psi : torch.Tensor 259 | Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in). 260 | nlon_in: int 261 | Number of longitude points the input used. 
Is required to infer the correct dimensions 262 | """ 263 | 264 | # check the shapes of all input tensors 265 | assert len(psi.shape) == 3 266 | assert len(grad_y.shape) == 5 267 | assert psi.is_sparse, "psi must be a sparse COO tensor" 268 | 269 | # TODO: check that Psi is also coalesced 270 | 271 | # get the dimensions of the problem 272 | kernel_size, nlat_out, n_in = psi.shape 273 | nnz = psi.indices().shape[-1] 274 | assert grad_y.shape[-2] == nlat_out 275 | assert grad_y.shape[-3] == kernel_size 276 | assert n_in % nlon_in == 0 277 | nlat_in = n_in // nlon_in 278 | batch_size, n_chans, _, _, nlon_out = grad_y.shape 279 | 280 | # make sure that the grid-points of the output grid fall onto the grid points of the input grid 281 | assert nlon_in % nlon_out == 0 282 | pscale = nlon_in // nlon_out 283 | 284 | # to simplify things, we merge batch and channel dimensions 285 | grad_y = grad_y.reshape(batch_size * n_chans, kernel_size, nlat_out, nlon_out) 286 | 287 | # prepare the output tensor 288 | grad_x = torch.zeros( 289 | batch_size * n_chans, nlat_in, nlon_in, device=grad_y.device, dtype=grad_y.dtype 290 | ) 291 | 292 | # determine the grid for the computation 293 | grid = ( 294 | triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH), 295 | 1, 296 | triton.cdiv(nlon_out, BLOCK_SIZE_POUT), 297 | ) 298 | 299 | # launch the kernel 300 | _disco_s2_contraction_kernel[grid]( 301 | psi.indices(), 302 | psi.values(), 303 | nnz, 304 | psi.indices().stride(-2), 305 | psi.indices().stride(-1), 306 | psi.values().stride(-1), 307 | grad_x, 308 | batch_size * n_chans, 309 | nlat_in, 310 | nlon_in, 311 | grad_x.stride(0), 312 | grad_x.stride(-2), 313 | grad_x.stride(-1), 314 | grad_y, 315 | kernel_size, 316 | nlat_out, 317 | nlon_out, 318 | grad_y.stride(0), 319 | grad_y.stride(1), 320 | grad_y.stride(-2), 321 | grad_y.stride(-1), 322 | pscale, 323 | True, 324 | BLOCK_SIZE_BATCH, 325 | BLOCK_SIZE_NZ, 326 | BLOCK_SIZE_POUT, 327 | ) 328 | 329 | # reshape y back to expose the correct dimensions 330 | grad_x = grad_x.reshape(batch_size, n_chans, nlat_in, nlon_in) 331 | 332 | return grad_x 333 | 334 | 335 | class _DiscoS2ContractionTriton(torch.autograd.Function): 336 | """ 337 | Helper function to make the triton implementation work with PyTorch autograd functionality 338 | """ 339 | 340 | @staticmethod 341 | def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int): 342 | ctx.save_for_backward(psi) 343 | ctx.nlon_in = x.shape[-1] 344 | 345 | return _disco_s2_contraction_fwd(x, psi, nlon_out) 346 | 347 | @staticmethod 348 | def backward(ctx, grad_output): 349 | (psi,) = ctx.saved_tensors 350 | grad_input = _disco_s2_contraction_bwd(grad_output, psi, ctx.nlon_in) 351 | grad_x = grad_psi = None 352 | 353 | return grad_input, None, None 354 | 355 | 356 | class _DiscoS2TransposeContractionTriton(torch.autograd.Function): 357 | """ 358 | Helper function to make the triton implementation work with PyTorch autograd functionality 359 | """ 360 | 361 | @staticmethod 362 | def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int): 363 | ctx.save_for_backward(psi) 364 | ctx.nlon_in = x.shape[-1] 365 | 366 | return _disco_s2_contraction_bwd(x, psi, nlon_out) 367 | 368 | @staticmethod 369 | def backward(ctx, grad_output): 370 | (psi,) = ctx.saved_tensors 371 | grad_input = _disco_s2_contraction_fwd(grad_output, psi, ctx.nlon_in) 372 | grad_x = grad_psi = None 373 | 374 | return grad_input, None, None 375 | 376 | 377 | def _disco_s2_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: 
int):
378 |     return _DiscoS2ContractionTriton.apply(x, psi, nlon_out)
379 | 
380 | 
381 | def _disco_s2_transpose_contraction_triton(
382 |     x: torch.Tensor, psi: torch.Tensor, nlon_out: int
383 | ):
384 |     return _DiscoS2TransposeContractionTriton.apply(x, psi, nlon_out)
385 | 
386 | 
387 | def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
388 |     """
389 |     Reference implementation of the custom contraction as described in [1]. This requires repeated
390 |     shifting of the input tensor, which can potentially be costly. For an efficient implementation
391 |     on GPU, make sure to use the custom kernel written in Triton.
392 |     """
393 |     assert len(psi.shape) == 3
394 |     assert len(x.shape) == 4
395 |     psi = psi.to(x.device)
396 | 
397 |     batch_size, n_chans, nlat_in, nlon_in = x.shape
398 |     kernel_size, nlat_out, _ = psi.shape
399 | 
400 |     assert psi.shape[-1] == nlat_in * nlon_in
401 |     assert nlon_in % nlon_out == 0
402 |     assert nlon_in >= nlat_out
403 |     pscale = nlon_in // nlon_out
404 | 
405 |     # add a dummy dimension for nkernel and move the batch and channel dims to the end
406 |     x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
407 |     x = x.expand(kernel_size, -1, -1, -1)
408 | 
409 |     y = torch.zeros(
410 |         nlon_out,
411 |         kernel_size,
412 |         nlat_out,
413 |         batch_size * n_chans,
414 |         device=x.device,
415 |         dtype=x.dtype,
416 |     )
417 | 
418 |     for pout in range(nlon_out):
419 |         # sparse contraction with psi
420 |         y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
421 |         # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
422 |         x = torch.roll(x, -pscale, dims=2)
423 | 
424 |     # reshape y back to expose the correct dimensions
425 |     y = y.permute(3, 1, 2, 0).reshape(
426 |         batch_size, n_chans, kernel_size, nlat_out, nlon_out
427 |     )
428 | 
429 |     return y
430 | 
431 | 
432 | def _disco_s2_transpose_contraction_torch(
433 |     x: torch.Tensor, psi: torch.Tensor, nlon_out: int
434 | ):
435 |     """
436 |     Reference implementation of the custom contraction as described in [1]. This requires repeated
437 |     shifting of the input tensor, which can potentially be costly. For an efficient implementation
438 |     on GPU, make sure to use the custom kernel written in Triton.
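
    Examples
    --------
    A shape-level sketch (here ``psi`` is a sparse COO tensor of shape
    (kernel_size, nlat_in, nlat_out * nlon_out), as the assertions below
    imply)::

        x = torch.randn(2, 4, kernel_size, nlat_in, nlon_in)
        y = _disco_s2_transpose_contraction_torch(x, psi, nlon_out)
        # y.shape == (2, 4, nlat_out, nlon_out)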
439 |     """
440 |     assert len(psi.shape) == 3
441 |     assert len(x.shape) == 5
442 |     psi = psi.to(x.device)
443 | 
444 |     batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
445 |     kernel_size, _, n_out = psi.shape
446 | 
447 |     assert psi.shape[-2] == nlat_in
448 |     assert n_out % nlon_out == 0
449 |     nlat_out = n_out // nlon_out
450 |     assert nlon_out >= nlat_in
451 |     pscale = nlon_out // nlon_in
452 | 
453 |     # we do a semi-transposition to facilitate the computation
454 |     inz = psi.indices()
455 |     tout = inz[2] // nlon_out
456 |     pout = inz[2] % nlon_out
457 |     # flip the axis of longitudes
458 |     pout = nlon_out - 1 - pout
459 |     tin = inz[1]
460 |     inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
461 |     psi_mod = torch.sparse_coo_tensor(
462 |         inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
463 |     )
464 | 
465 |     # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
466 |     x_ext = torch.zeros(
467 |         kernel_size,
468 |         nlat_in,
469 |         nlon_out,
470 |         batch_size * n_chans,
471 |         device=x.device,
472 |         dtype=x.dtype,
473 |     )
474 |     x_ext[:, :, ::pscale, :] = x.reshape(
475 |         batch_size * n_chans, kernel_size, nlat_in, nlon_in
476 |     ).permute(1, 2, 3, 0)
477 |     # we need to go backwards through the vector, so we flip the axis
478 |     x_ext = x_ext.contiguous()
479 | 
480 |     y = torch.zeros(
481 |         kernel_size,
482 |         nlon_out,
483 |         nlat_out,
484 |         batch_size * n_chans,
485 |         device=x.device,
486 |         dtype=x.dtype,
487 |     )
488 | 
489 |     for pout in range(nlon_out):
490 |         # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
491 |         # TODO: double-check why this has to happen first
492 |         x_ext = torch.roll(x_ext, -1, dims=2)
493 |         # sparse contraction with the modified psi
494 |         y[:, pout, :, :] = torch.bmm(
495 |             psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
496 |         )
497 | 
498 |     # sum over the kernel dimension and reshape to the correct output size
499 |     y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
500 | 
501 |     return y
502 | 
--------------------------------------------------------------------------------
/models/no_varnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Literal, Optional, Tuple
3 | 
4 | import fastmri
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from fastmri import transforms
9 | from models.udno import UDNO
10 | 
11 | 
12 | def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
13 |     """
14 |     Calculates F(x * sens_maps).
15 | 
16 |     Parameters
17 |     ----------
18 |     x : torch.Tensor
19 |         Single-channel image of shape (..., H, W, 2)
20 |     sens_maps : torch.Tensor
21 |         Sensitivity maps (image space)
22 | 
23 |     Returns
24 |     -------
25 |     torch.Tensor
26 |         Result of the operation F(x * sens_maps)
27 |     """
28 |     return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
29 | 
30 | 
31 | def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
32 |     """
33 |     Calculates F^{-1}(k) * conj(sens_maps),
34 |     where conj(sens_maps) is the element-wise applied complex conjugate.
35 | 
36 |     Parameters
37 |     ----------
38 |     k : torch.Tensor
39 |         Multi-channel k-space of shape (B, C, H, W, 2)
40 |     sens_maps : torch.Tensor
41 |         Sensitivity maps (image space)
42 | 
43 |     Returns
44 |     -------
45 |     torch.Tensor
46 |         Result of the operation F^{-1}(k) * conj(sens_maps)
47 |     """
48 |     return fastmri.complex_mul(fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)).sum(
49 |         dim=1, keepdim=True
50 |     )
51 | 
52 | 
53 | def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
54 |     """Reshapes batched multi-channel samples into multiple single channel samples.
55 | 
56 |     Parameters
57 |     ----------
58 |     x : torch.Tensor
59 |         x has shape (b, c, h, w, 2)
60 | 
61 |     Returns
62 |     -------
63 |     Tuple[torch.Tensor, int]
64 |         tensor of shape (b * c, 1, h, w, 2), b
65 |     """
66 |     b, c, h, w, comp = x.shape
67 |     return x.view(b * c, 1, h, w, comp), b
68 | 
69 | 
70 | def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
71 |     """Reshapes batched independent samples into original multi-channel samples.
72 | 
73 |     Parameters
74 |     ----------
75 |     x : torch.Tensor
76 |         tensor of shape (b * c, 1, h, w, 2)
77 |     batch_size : int
78 |         batch size
79 | 
80 |     Returns
81 |     -------
82 |     torch.Tensor
83 |         original multi-channel tensor of shape (b, c, h, w, 2)
84 |     """
85 |     bc, _, h, w, comp = x.shape
86 |     c = bc // batch_size
87 |     return x.view(batch_size, c, h, w, comp)
88 | 
89 | 
90 | class NormUDNO(nn.Module):
91 |     """
92 |     Normalized UDNO model.
93 | 
94 |     Inputs are normalized before the UDNO for numerically stable training.
95 |     """
96 | 
97 |     def __init__(
98 |         self,
99 |         chans: int,
100 |         num_pool_layers: int,
101 |         radius_cutoff: float,
102 |         in_shape: Tuple[int, int],
103 |         kernel_shape: Tuple[int, int],
104 |         in_chans: int = 2,
105 |         out_chans: int = 2,
106 |         drop_prob: float = 0.0,
107 |     ):
108 |         """
109 |         Initialize the NormUDNO model.
110 | 
111 |         Parameters
112 |         ----------
113 |         chans : int
114 |             Number of output channels of the first convolution layer.
115 |         num_pool_layers : int
116 |             Number of down-sampling and up-sampling layers.
117 |         in_chans : int, optional
118 |             Number of channels in the input to the U-Net model. Default is 2.
119 |         out_chans : int, optional
120 |             Number of channels in the output to the U-Net model. Default is 2.
121 |         drop_prob : float, optional
122 |             Dropout probability. Default is 0.0.
123 |         """
124 |         super().__init__()
125 | 
126 |         self.udno = UDNO(
127 |             in_chans=in_chans,
128 |             out_chans=out_chans,
129 |             radius_cutoff=radius_cutoff,
130 |             chans=chans,
131 |             num_pool_layers=num_pool_layers,
132 |             drop_prob=drop_prob,
133 |             in_shape=in_shape,
134 |             kernel_shape=kernel_shape,
135 |         )
136 | 
137 |     def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
138 |         b, c, h, w, two = x.shape
139 |         assert two == 2
140 |         return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
141 | 
142 |     def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
143 |         b, c2, h, w = x.shape
144 |         assert c2 % 2 == 0
145 |         c = c2 // 2
146 |         return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
147 | 
148 |     def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
149 |         # group norm
150 |         b, c, h, w = x.shape
151 |         x = x.view(b, 2, c // 2 * h * w)
152 | 
153 |         mean = x.mean(dim=2).view(b, 2, 1, 1)
154 |         std = x.std(dim=2).view(b, 2, 1, 1)
155 | 
156 |         x = x.view(b, c, h, w)
157 | 
158 |         return (x - mean) / std, mean, std
159 | 
160 |     def norm_new(
161 |         self, x: torch.Tensor
162 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
163 |         # FIXME: not working, wip
164 |         # group norm
165 |         b, c, h, w = x.shape
166 |         num_groups = 2
167 |         assert c % num_groups == 0, (
168 |             f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
169 | ) 170 | 171 | x = x.view(b, num_groups, c // num_groups * h * w) 172 | 173 | mean = x.mean(dim=2).view(b, num_groups, 1, 1) 174 | std = x.std(dim=2).view(b, num_groups, 1, 1) 175 | print(x.shape, mean.shape, std.shape) 176 | 177 | x = x.view(b, c, h, w) 178 | mean = ( 179 | mean.view(b, num_groups, 1, 1) 180 | .repeat(1, c // num_groups, h, w) 181 | .view(b, c, h, w) 182 | ) 183 | std = ( 184 | std.view(b, num_groups, 1, 1) 185 | .repeat(1, c // num_groups, h, w) 186 | .view(b, c, h, w) 187 | ) 188 | 189 | return (x - mean) / std, mean, std 190 | 191 | def unnorm( 192 | self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor 193 | ) -> torch.Tensor: 194 | return x * std + mean 195 | 196 | def pad( 197 | self, x: torch.Tensor 198 | ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: 199 | _, _, h, w = x.shape 200 | w_mult = ((w - 1) | 15) + 1 201 | h_mult = ((h - 1) | 15) + 1 202 | w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] 203 | h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] 204 | # TODO: fix this type when PyTorch fixes theirs 205 | # the documentation lies - this actually takes a list 206 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457 207 | # https://github.com/pytorch/pytorch/pull/16949 208 | x = F.pad(x, w_pad + h_pad) 209 | 210 | return x, (h_pad, w_pad, h_mult, w_mult) 211 | 212 | def unpad( 213 | self, 214 | x: torch.Tensor, 215 | h_pad: List[int], 216 | w_pad: List[int], 217 | h_mult: int, 218 | w_mult: int, 219 | ) -> torch.Tensor: 220 | return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] 221 | 222 | def forward(self, x: torch.Tensor) -> torch.Tensor: 223 | if not x.shape[-1] == 2: 224 | raise ValueError("Last dimension must be 2 for complex.") 225 | 226 | chans = x.shape[1] 227 | if chans == 2: 228 | # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug 229 | x = self.complex_to_chan_dim(x) 230 | x = self.udno(x) 231 | return self.chan_complex_to_last_dim(x) 232 | 233 | # get shapes for unet and normalize 234 | x = self.complex_to_chan_dim(x) 235 | x, mean, std = self.norm(x) 236 | x, pad_sizes = self.pad(x) 237 | 238 | x = self.udno(x) 239 | 240 | # get shapes back and unnormalize 241 | x = self.unpad(x, *pad_sizes) 242 | x = self.unnorm(x, mean, std) 243 | x = self.chan_complex_to_last_dim(x) 244 | 245 | return x 246 | 247 | 248 | class SensitivityModel(nn.Module): 249 | """ 250 | Learn sensitivity maps 251 | """ 252 | 253 | def __init__( 254 | self, 255 | chans: int, 256 | num_pools: int, 257 | radius_cutoff: float, 258 | in_shape: Tuple[int, int], 259 | kernel_shape: Tuple[int, int], 260 | in_chans: int = 2, 261 | out_chans: int = 2, 262 | drop_prob: float = 0.0, 263 | mask_center: bool = True, 264 | ): 265 | """ 266 | Parameters 267 | ---------- 268 | chans : int 269 | Number of output channels of the first convolution layer. 270 | num_pools : int 271 | Number of down-sampling and up-sampling layers. 272 | in_chans : int, optional 273 | Number of channels in the input to the U-Net model. Default is 2. 274 | out_chans : int, optional 275 | Number of channels in the output to the U-Net model. Default is 2. 276 | drop_prob : float, optional 277 | Dropout probability. Default is 0.0. 278 | mask_center : bool, optional 279 | Whether to mask center of k-space for sensitivity map calculation. 280 | Default is True. 
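
        Examples
        --------
        A shape-level sketch (argument values are illustrative, not tuned)::

            sens_net = SensitivityModel(
                chans=8,
                num_pools=4,
                radius_cutoff=0.01,
                in_shape=(320, 320),
                kernel_shape=(3, 4),
            )
            # masked_kspace: (B, C, H, W, 2); mask: boolean k-space mask
            sens_maps = sens_net(masked_kspace, mask)  # (B, C, H, W, 2)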
281 |         """
282 |         super().__init__()
283 |         self.mask_center = mask_center
284 |         self.norm_udno = NormUDNO(
285 |             chans,
286 |             num_pools,
287 |             radius_cutoff,
288 |             in_shape,
289 |             kernel_shape,
290 |             in_chans=in_chans,
291 |             out_chans=out_chans,
292 |             drop_prob=drop_prob,
293 |         )
294 | 
295 |     def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
296 |         return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
297 | 
298 |     def get_pad_and_num_low_freqs(
299 |         self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
300 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
301 |         if num_low_frequencies is None or any(
302 |             torch.any(t == 0) for t in num_low_frequencies
303 |         ):
304 |             # get low frequency line locations and mask them out
305 |             squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
306 |             cent = squeezed_mask.shape[1] // 2
307 |             # argmin returns the index of the first zero (unsampled) entry
308 |             left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
309 |             right = torch.argmin(squeezed_mask[:, cent:], dim=1)
310 |             num_low_frequencies_tensor = torch.max(
311 |                 2 * torch.min(left, right), torch.ones_like(left)
312 |             )  # force a symmetric center unless 1
313 |         else:
314 |             num_low_frequencies_tensor = num_low_frequencies * torch.ones(
315 |                 mask.shape[0], dtype=mask.dtype, device=mask.device
316 |             )
317 | 
318 |         pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
319 | 
320 |         return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
321 | 
322 |     def forward(
323 |         self,
324 |         masked_kspace: torch.Tensor,
325 |         mask: torch.Tensor,
326 |         num_low_frequencies: Optional[int] = None,
327 |     ) -> torch.Tensor:
328 |         if self.mask_center:
329 |             pad, num_low_freqs = self.get_pad_and_num_low_freqs(
330 |                 mask, num_low_frequencies
331 |             )
332 |             masked_kspace = transforms.batched_mask_center(
333 |                 masked_kspace, pad, pad + num_low_freqs
334 |             )
335 | 
336 |         # convert to image space
337 |         images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
338 | 
339 |         # estimate sensitivities
340 |         return self.divide_root_sum_of_squares(
341 |             batch_chans_to_chan_dim(self.norm_udno(images), batches)
342 |         )
343 | 
344 | 
345 | class VarNetBlock(nn.Module):
346 |     """
347 |     Model block for iterative refinement of k-space data.
348 | 
349 |     This model applies a combination of soft data consistency with the input
350 |     model as a regularizer. A series of these blocks can be stacked to form
351 |     the full variational network.
352 | 
353 |     aka Refinement Module in Fig 1
354 |     """
355 | 
356 |     def __init__(self, model: nn.Module):
357 |         """
358 |         Args:
359 |             model: Module for "regularization" component of variational
360 |                 network.
361 |         """
362 |         super().__init__()
363 | 
364 |         self.model = model
365 |         self.dc_weight = nn.Parameter(torch.ones(1))
366 | 
367 |     def forward(
368 |         self,
369 |         current_kspace: torch.Tensor,
370 |         ref_kspace: torch.Tensor,
371 |         mask: torch.Tensor,
372 |         sens_maps: torch.Tensor,
373 |         use_dc_term: bool = True,
374 |     ) -> torch.Tensor:
375 |         """
376 |         Args:
377 |             current_kspace: The current k-space data (frequency domain data)
378 |                 being processed by the network. (torch.Tensor)
379 |             ref_kspace: Original subsampled k-space data from which we are
380 |                 reconstructing the image (the reference k-space). (torch.Tensor)
381 |             mask: A binary mask indicating the locations in k-space where
382 |                 data consistency should be enforced. (torch.Tensor)
383 |             sens_maps: Sensitivity maps for the different coils in parallel
384 |                 imaging.
(torch.Tensor) 385 | """ 386 | 387 | # model-term see orange box of Fig 1 in E2E-VarNet paper! 388 | # multi channel k-space -> single channel image-space 389 | b, c, h, w, _ = current_kspace.shape 390 | 391 | if c == 30: 392 | # get kspace and inpainted kspace 393 | kspace = current_kspace[:, :15, :, :, :] 394 | in_kspace = current_kspace[:, 15:, :, :, :] 395 | # convert to image space 396 | image = sens_reduce(kspace, sens_maps) 397 | in_image = sens_reduce(in_kspace, sens_maps) 398 | # concatenate both onto each other 399 | reduced_image = torch.cat([image, in_image], dim=1) 400 | else: 401 | reduced_image = sens_reduce(current_kspace, sens_maps) 402 | 403 | # single channel image-space 404 | refined_image = self.model(reduced_image) 405 | 406 | # single channel image-space -> multi channel k-space 407 | model_term = sens_expand(refined_image, sens_maps) 408 | 409 | # only use first 15 channels (masked_kspace) in the update 410 | # current_kspace = current_kspace[:, :15, :, :, :] 411 | 412 | if not use_dc_term: 413 | return current_kspace - model_term 414 | 415 | """ 416 | Soft data consistency term: 417 | - Calculates the difference between current k-space and reference k-space where the mask is true. 418 | - Multiplies this difference by the data consistency weight. 419 | """ 420 | # dc_term: see green box of Fig 1 in E2E-VarNet paper! 421 | zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace) 422 | soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight 423 | return current_kspace - soft_dc - model_term 424 | 425 | 426 | class NOVarnet(nn.Module): 427 | """ 428 | Neural Operator model for MRI reconstruction. 429 | 430 | Uses a variational architecture (iterative updates) with a learned sensitivity 431 | model. All operations are resolution invariant employing neural operator 432 | modules (UDNO). 433 | """ 434 | 435 | def __init__( 436 | self, 437 | num_cascades: int = 12, 438 | sens_chans: int = 8, 439 | sens_pools: int = 4, 440 | chans: int = 18, 441 | pools: int = 4, 442 | kno_chans: int = 16, 443 | kno_pools: int = 4, 444 | kno_radius_cutoff: float = 0.02, 445 | kno_kernel_shape: Tuple[int, int] = (6, 7), 446 | radius_cutoff: float = 0.01, 447 | kernel_shape: Tuple[int, int] = (3, 4), 448 | in_shape: Tuple[int, int] = (640, 320), 449 | mask_center: bool = True, 450 | use_dc_term: bool = True, 451 | reduction_method: Literal["batch", "rss"] = "rss", 452 | skip_method: Literal["replace", "add", "add_inv", "concat"] = "add", 453 | ): 454 | """ 455 | Parameters 456 | ---------- 457 | num_cascades : int 458 | Number of cascades (i.e., layers) for variational network. 459 | sens_chans : int 460 | Number of channels for sensitivity map U-Net. 461 | sens_pools : int 462 | Number of downsampling and upsampling layers for sensitivity map U-Net. 463 | chans : int 464 | Number of channels for cascade U-Net. 465 | pools : int 466 | Number of downsampling and upsampling layers for cascade U-Net. 467 | mask_center : bool 468 | Whether to mask center of k-space for sensitivity map calculation. 469 | use_dc_term : bool 470 | Whether to use the data consistency term. 471 | reduction_method : "batch" or "rss" 472 | Method for reducing sensitivity maps to single channel. 473 | "batch" reduces to single channel by stacking channels. 474 | "rss" reduces to single channel by root sum of squares. 
475 | skip_method : "replace" or "add" or "add_inv" or "concat" 476 | "replace" replaces the input with the output of the KNO 477 | "add" adds the output of the KNO to the input 478 | "add_inv" adds the output of the KNO to the input (only where samples are missing) 479 | "concat" concatenates the output of the KNO to the input 480 | """ 481 | 482 | super().__init__() 483 | 484 | self.sens_net = SensitivityModel( 485 | sens_chans, 486 | sens_pools, 487 | radius_cutoff, 488 | in_shape, 489 | kernel_shape, 490 | mask_center=mask_center, 491 | ) 492 | self.kno = NormUDNO( 493 | kno_chans, 494 | kno_pools, 495 | in_shape=in_shape, 496 | radius_cutoff=radius_cutoff, 497 | kernel_shape=kernel_shape, 498 | # radius_cutoff=kno_radius_cutoff, 499 | # kernel_shape=kno_kernel_shape, 500 | in_chans=2, 501 | out_chans=2, 502 | ) 503 | self.cascades = nn.ModuleList( 504 | [ 505 | VarNetBlock( 506 | NormUDNO( 507 | chans, 508 | pools, 509 | radius_cutoff, 510 | in_shape, 511 | kernel_shape, 512 | in_chans=( 513 | 4 if skip_method == "concat" and cascade_idx == 0 else 2 514 | ), 515 | out_chans=2, 516 | ) 517 | ) 518 | for cascade_idx in range(num_cascades) 519 | ] 520 | ) 521 | self.use_dc_term = use_dc_term 522 | self.reduction_method = reduction_method 523 | self.skip_method = skip_method 524 | 525 | def forward( 526 | self, 527 | masked_kspace: torch.Tensor, 528 | mask: torch.Tensor, 529 | num_low_frequencies: Optional[int] = None, 530 | ) -> torch.Tensor: 531 | # (B, C, X, Y, 2) 532 | sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) 533 | 534 | # reduce before inpainting 535 | if self.reduction_method == "rss": 536 | # (B, 1, H, W, 2) single channel image space 537 | x_reduced = sens_reduce(masked_kspace, sens_maps) 538 | # (B, 1, H, W, 2) 539 | k_reduced = fastmri.fft2c(x_reduced) 540 | elif self.reduction_method == "batch": 541 | k_reduced, b = chans_to_batch_dim(masked_kspace) 542 | 543 | # inpainting 544 | if self.skip_method == "replace": 545 | kspace_pred = self.kno(k_reduced) 546 | elif self.skip_method == "add_inv": 547 | # FIXME: this is not correct (mask has shape B, 1, H, W, 2 and self.gno(k_reduced) has shape B*C, 1, H, W, 2) 548 | kspace_pred = k_reduced.clone() + (~mask * self.kno(k_reduced)) 549 | elif self.skip_method == "add": 550 | kspace_pred = k_reduced.clone() + self.kno(k_reduced) 551 | elif self.skip_method == "concat": 552 | kspace_pred = torch.cat([k_reduced.clone(), self.kno(k_reduced)], dim=1) 553 | else: 554 | raise NotImplementedError("skip_method not implemented") 555 | 556 | # expand after inpainting 557 | if self.reduction_method == "rss": 558 | if self.skip_method == "concat": 559 | # kspace_pred is (B, 2, H, W, 2) 560 | kspace = kspace_pred[:, :1, :, :, :] 561 | in_kspace = kspace_pred[:, 1:, :, :, :] 562 | # B, 2C, H, W, 2 563 | kspace_pred = torch.cat( 564 | [sens_expand(kspace, sens_maps), sens_expand(in_kspace, sens_maps)], 565 | dim=1, 566 | ) 567 | else: 568 | # (B, 1, H, W, 2) -> (B, C, H, W, 2) multi-channel k space 569 | kspace_pred = sens_expand(kspace_pred, sens_maps) 570 | elif self.reduction_method == "batch": 571 | # (B, C, H, W, 2) multi-channel k space 572 | if self.skip_method == "concat": 573 | kspace = kspace_pred[:, :1, :, :, :] 574 | in_kspace = kspace_pred[:, 1:, :, :, :] 575 | # B, 2C, H, W, 2 576 | kspace_pred = torch.cat( 577 | [ 578 | batch_chans_to_chan_dim(kspace, b), 579 | batch_chans_to_chan_dim(in_kspace, b), 580 | ], 581 | dim=1, 582 | ) 583 | else: 584 | kspace_pred = batch_chans_to_chan_dim(kspace_pred, b) 585 | 
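        # cascade update, E2E-VarNet style: with M the sampling mask, k_ref
        # the measured k-space, eta_t the learned dc_weight, and G_t the
        # sens_reduce -> UDNO -> sens_expand refinement, each block computes
        #   k^{t+1} = k^t - eta_t * M * (k^t - k_ref) - G_t(k^t)
        # (the soft data-consistency term is dropped when use_dc_term=False)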
586 |         # iterative update
587 |         for cascade in self.cascades:
588 |             kspace_pred = cascade(
589 |                 kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
590 |             )
591 | 
592 |         spatial_pred = fastmri.ifft2c(kspace_pred)
593 |         spatial_pred_abs = fastmri.complex_abs(spatial_pred)
594 |         combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
595 | 
596 |         return combined_spatial
597 | 
--------------------------------------------------------------------------------
/fastmri/subsample.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | 
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 | 
8 | import os
9 | from typing import Dict, Optional, Sequence, Tuple, Union
10 | 
11 | import numpy as np
12 | import torch
13 | import torch.distributions as D
14 | from sigpy.mri import poisson, radial, spiral
15 | 
16 | 
17 | class MaskFunc:
18 |     """
19 |     An object for GRAPPA-style sampling masks.
20 | 
21 |     This creates a sampling mask that densely samples the center while
22 |     subsampling outer k-space regions based on the undersampling factor.
23 | 
24 |     When called, ``MaskFunc`` uses internal functions to create a mask by (1)
25 |     creating a mask for the k-space center, (2) creating a mask outside of the
26 |     k-space center, and (3) combining them into a total mask. The internals are
27 |     handled by ``sample_mask``, which calls ``calculate_center_mask`` for (1)
28 |     and ``calculate_acceleration_mask`` for (2). The combination is executed
29 |     in the ``MaskFunc`` ``__call__`` function.
30 | 
31 |     If you would like to implement a new mask, simply subclass ``MaskFunc``
32 |     and overwrite the ``sample_mask`` logic. See examples in ``RandomMaskFunc``
33 |     and ``EquispacedMaskFunc``.
34 |     """
35 | 
36 |     def __init__(
37 |         self,
38 |         center_fractions: Sequence[float],
39 |         accelerations: Sequence[int],
40 |         allow_any_combination: bool = False,
41 |         seed: Optional[int] = None,
42 |     ):
43 |         """
44 |         Args:
45 |             center_fractions: Fraction of low-frequency columns to be retained.
46 |                 If multiple values are provided, then one of these numbers is
47 |                 chosen uniformly each time.
48 |             accelerations: Amount of under-sampling. This should have the same
49 |                 length as center_fractions. If multiple values are provided,
50 |                 then one of these is chosen uniformly each time.
51 |             allow_any_combination: Whether to allow cross combinations of
52 |                 elements from ``center_fractions`` and ``accelerations``.
53 |             seed: Seed for starting the internal random number generator of the
54 |                 ``MaskFunc``.
55 |         """
56 |         if len(center_fractions) != len(accelerations) and not allow_any_combination:
57 |             raise ValueError(
58 |                 "Number of center fractions should match number of"
59 |                 " accelerations if allow_any_combination is False."
60 |             )
61 | 
62 |         self.center_fractions = center_fractions
63 |         self.accelerations = accelerations
64 |         self.allow_any_combination = allow_any_combination
65 |         self.rng = np.random.RandomState(seed)
66 | 
67 |     def __call__(
68 |         self,
69 |         shape: Sequence[int],
70 |         offset: Optional[int] = None,
71 |         seed: Optional[Union[int, Tuple[int, ...]]] = None,
72 |     ) -> Tuple[torch.Tensor, int]:
73 |         """
74 |         Sample and return a k-space mask.
75 | 
76 |         Args:
77 |             shape: Shape of k-space.
78 |             offset: Offset from 0 to begin mask (for equispaced masks). If no
79 |                 offset is given, then one is selected randomly.
80 |             seed: Seed for random number generator for reproducibility.
81 | 
82 |         Returns:
83 |             A 2-tuple containing 1) the k-space mask and 2) the number of
84 |             center frequency lines.
85 |         """
86 |         if len(shape) < 3:
87 |             raise ValueError("Shape should have 3 or more dimensions")
88 | 
89 |         center_mask, accel_mask, num_low_frequencies = self.sample_mask(shape, offset)
90 |         # combine masks together
91 |         return torch.max(center_mask, accel_mask), num_low_frequencies
92 | 
93 |     def sample_mask(
94 |         self,
95 |         shape: Sequence[int],
96 |         offset: Optional[int],
97 |     ) -> Tuple[torch.Tensor, torch.Tensor, int]:
98 |         """
99 |         Sample a new k-space mask.
100 | 
101 |         This function samples and returns two components of a k-space mask: 1)
102 |         the center mask (e.g., for sensitivity map calculation) and 2) the
103 |         acceleration mask (for the edge of k-space). Both of these masks, as
104 |         well as the integer count of low frequency samples, are returned.
105 | 
106 |         Args:
107 |             shape: Shape of the k-space to subsample.
108 |             offset: Offset from 0 to begin mask (for equispaced masks).
109 | 
110 |         Returns:
111 |             A 3-tuple containing 1) the mask for the center of k-space, 2) the
112 |             mask for the high frequencies of k-space, and 3) the integer count
113 |             of low frequency samples.
114 |         """
115 |         num_cols = shape[-2]
116 |         center_fraction, acceleration = self.choose_acceleration()
117 |         num_low_frequencies = round(num_cols * center_fraction)
118 |         center_mask = self.reshape_mask(
119 |             self.calculate_center_mask(shape, num_low_frequencies), shape
120 |         )
121 |         acceleration_mask = self.reshape_mask(
122 |             self.calculate_acceleration_mask(
123 |                 num_cols, acceleration, offset, num_low_frequencies
124 |             ),
125 |             shape,
126 |         )
127 |         return center_mask, acceleration_mask, num_low_frequencies
128 | 
129 |     def reshape_mask(self, mask: Union[np.ndarray, torch.Tensor], shape: Sequence[int]) -> torch.Tensor:
130 |         """Reshape mask to desired output shape."""
131 |         if len(mask.shape) == 1:
132 |             mask = torch.as_tensor(mask)
133 |             mask_num_freqs = len(mask)
134 |             mask = mask.reshape(1, 1, mask_num_freqs, 1)
135 |             mask = mask.expand(shape)
136 |         return mask
137 | 
138 |     def reshape_mask_old(self, mask: np.ndarray, shape: Sequence[int]) -> torch.Tensor:
139 |         """Reshape mask to desired output shape."""
140 |         num_cols = shape[-2]
141 |         mask_shape = [1 for s in shape]
142 |         mask_shape[-2] = num_cols
143 | 
144 |         return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
145 | 
146 |     def calculate_acceleration_mask(
147 |         self,
148 |         num_cols: int,
149 |         acceleration: int,
150 |         offset: Optional[int],
151 |         num_low_frequencies: int,
152 |     ) -> np.ndarray:
153 |         """
154 |         Produce mask for non-central acceleration lines.
155 | 
156 |         Args:
157 |             num_cols: Number of columns of k-space (2D subsampling).
158 |             acceleration: Desired acceleration rate.
159 |             offset: Offset from 0 to begin masking (for equispaced masks).
160 |             num_low_frequencies: Integer count of low-frequency lines sampled.
161 | 
162 |         Returns:
163 |             A mask for the high spatial frequencies of k-space.
164 |         """
165 |         raise NotImplementedError
166 | 
167 |     def calculate_center_mask(
168 |         self, shape: Sequence[int], num_low_freqs: int
169 |     ) -> np.ndarray:
170 |         """
171 |         Build center mask based on number of low frequencies.
172 | 
173 |         Args:
174 |             shape: Shape of k-space to mask.
175 |             num_low_freqs: Number of low-frequency lines to sample.
176 | 
177 |         Returns:
178 |             A mask for the low spatial frequencies of k-space.
179 | """ 180 | num_cols = shape[-2] 181 | mask = np.zeros(num_cols, dtype=np.float32) 182 | pad = (num_cols - num_low_freqs + 1) // 2 183 | mask[pad : pad + num_low_freqs] = 1 184 | assert mask.sum() == num_low_freqs 185 | 186 | return mask 187 | 188 | def choose_acceleration(self): 189 | """Choose acceleration based on class parameters.""" 190 | if self.allow_any_combination: 191 | return self.rng.choice(self.center_fractions), self.rng.choice( 192 | self.accelerations 193 | ) 194 | else: 195 | choice = self.rng.randint(len(self.center_fractions)) 196 | return self.center_fractions[choice], self.accelerations[choice] 197 | 198 | 199 | class RandomMaskFunc(MaskFunc): 200 | """ 201 | Creates a random sub-sampling mask of a given shape. 202 | 203 | The mask selects a subset of columns from the input k-space data. If the 204 | k-space data has N columns, the mask picks out: 205 | 1. N_low_freqs = (N * center_fraction) columns in the center 206 | corresponding to low-frequencies. 207 | 2. The other columns are selected uniformly at random with a 208 | probability equal to: prob = (N / acceleration - N_low_freqs) / 209 | (N - N_low_freqs). This ensures that the expected number of columns 210 | selected is equal to (N / acceleration). 211 | 212 | It is possible to use multiple center_fractions and accelerations, in which 213 | case one possible (center_fraction, acceleration) is chosen uniformly at 214 | random each time the ``RandomMaskFunc`` object is called. 215 | 216 | For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], 217 | then there is a 50% probability that 4-fold acceleration with 8% center 218 | fraction is selected and a 50% probability that 8-fold acceleration with 4% 219 | center fraction is selected. 220 | """ 221 | 222 | def calculate_acceleration_mask( 223 | self, 224 | num_cols: int, 225 | acceleration: int, 226 | offset: Optional[int], 227 | num_low_frequencies: int, 228 | ) -> np.ndarray: 229 | prob = (num_cols / acceleration - num_low_frequencies) / ( 230 | num_cols - num_low_frequencies 231 | ) 232 | 233 | return self.rng.uniform(size=num_cols) < prob 234 | 235 | 236 | class EquiSpacedMaskFunc(MaskFunc): 237 | """ 238 | Sample data with equally-spaced k-space lines. 239 | 240 | The lines are spaced exactly evenly, as is done in standard GRAPPA-style 241 | acquisitions. This means that with a densely-sampled center, 242 | ``acceleration`` will be greater than the true acceleration rate. 243 | """ 244 | 245 | def calculate_acceleration_mask( 246 | self, 247 | num_cols: int, 248 | acceleration: int, 249 | offset: Optional[int], 250 | num_low_frequencies: int, 251 | ) -> np.ndarray: 252 | """ 253 | Produce mask for non-central acceleration lines. 254 | 255 | Args: 256 | num_cols: Number of columns of k-space (2D subsampling). 257 | acceleration: Desired acceleration rate. 258 | offset: Offset from 0 to begin masking. If no offset is specified, 259 | then one is selected randomly. 260 | num_low_frequencies: Not used. 261 | 262 | Returns: 263 | A mask for the high spatial frequencies of k-space. 264 | """ 265 | if offset is None: 266 | offset = self.rng.randint(0, high=round(acceleration)) 267 | 268 | mask = np.zeros(num_cols, dtype=np.float32) 269 | mask[offset::acceleration] = 1 270 | 271 | return mask 272 | 273 | 274 | class EquispacedMaskFractionFunc(MaskFunc): 275 | """ 276 | Equispaced mask with approximate acceleration matching. 277 | 278 | The mask selects a subset of columns from the input k-space data. 
If the
279 |     k-space data has N columns, the mask picks out:
280 |         1. N_low_freqs = (N * center_fraction) columns in the center
281 |            corresponding to low-frequencies.
282 |         2. The other columns are selected with equal spacing at a proportion
283 |            that reaches the desired acceleration rate, taking into consideration
284 |            the number of low frequencies. This ensures that the expected number
285 |            of columns selected is equal to (N / acceleration).
286 | 
287 |     It is possible to use multiple center_fractions and accelerations, in which
288 |     case one possible (center_fraction, acceleration) is chosen uniformly at
289 |     random each time the ``EquispacedMaskFractionFunc`` object is called.
290 | 
291 |     Note that this function may not give equispaced samples (documented in
292 |     https://github.com/facebookresearch/fastMRI/issues/54), which will require
293 |     modifications to standard GRAPPA approaches. Nonetheless, this aspect of
294 |     the function has been preserved to match the public multicoil data.
295 |     """
296 | 
297 |     def calculate_acceleration_mask(
298 |         self,
299 |         num_cols: int,
300 |         acceleration: int,
301 |         offset: Optional[int],
302 |         num_low_frequencies: int,
303 |     ) -> np.ndarray:
304 |         """
305 |         Produce mask for non-central acceleration lines.
306 | 
307 |         Args:
308 |             num_cols: Number of columns of k-space (2D subsampling).
309 |             acceleration: Desired acceleration rate.
310 |             offset: Offset from 0 to begin masking. If no offset is specified,
311 |                 then one is selected randomly.
312 |             num_low_frequencies: Number of low frequencies. Used to adjust mask
313 |                 to exactly match the target acceleration.
314 | 
315 |         Returns:
316 |             A mask for the high spatial frequencies of k-space.
317 |         """
318 |         # determine acceleration rate by adjusting for the number of low frequencies
319 |         adjusted_accel = (acceleration * (num_low_frequencies - num_cols)) / (
320 |             num_low_frequencies * acceleration - num_cols
321 |         )
322 |         if offset is None:
323 |             offset = self.rng.randint(0, high=round(adjusted_accel))
324 | 
325 |         mask = np.zeros(num_cols, dtype=np.float32)
326 |         accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)
327 |         accel_samples = np.around(accel_samples).astype(np.uint)
328 |         mask[accel_samples] = 1.0
329 | 
330 |         return mask
331 | 
332 | 
333 | class MagicMaskFunc(MaskFunc):
334 |     """
335 |     Masking function for exploiting conjugate symmetry via offset-sampling.
336 | 
337 |     This function applies the mask described in the following paper:
338 | 
339 |     Defazio, A. (2019). Offset Sampling Improves Deep Learning based
340 |     Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
341 |     arXiv:1912.01101.
342 | 
343 |     It is essentially an equispaced mask with an offset for the opposite side
344 |     of k-space. Since MRI images often exhibit approximate conjugate k-space
345 |     symmetry, this mask is generally more efficient than a standard equispaced
346 |     mask.
347 | 
348 |     Similarly to ``EquiSpacedMaskFunc``, this mask will usually undershoot the
349 |     target acceleration rate.
350 |     """
351 | 
352 |     def calculate_acceleration_mask(
353 |         self,
354 |         num_cols: int,
355 |         acceleration: int,
356 |         offset: Optional[int],
357 |         num_low_frequencies: int,
358 |     ) -> np.ndarray:
359 |         """
360 |         Produce mask for non-central acceleration lines.
361 | 
362 |         Args:
363 |             num_cols: Number of columns of k-space (2D subsampling).
364 |             acceleration: Desired acceleration rate.
365 |             offset: Offset from 0 to begin masking. If no offset is specified,
366 |                 then one is selected randomly.
367 |             num_low_frequencies: Not used.
368 | 
369 |         Returns:
370 |             A mask for the high spatial frequencies of k-space.
371 |         """
372 |         if offset is None:
373 |             offset = self.rng.randint(0, high=acceleration)
374 | 
375 |         if offset % 2 == 0:
376 |             offset_pos = offset + 1
377 |             offset_neg = offset + 2
378 |         else:
379 |             offset_pos = offset - 1 + 3
380 |             offset_neg = offset - 1 + 0
381 | 
382 |         poslen = (num_cols + 1) // 2
383 |         neglen = num_cols - (num_cols + 1) // 2
384 |         mask_positive = np.zeros(poslen, dtype=np.float32)
385 |         mask_negative = np.zeros(neglen, dtype=np.float32)
386 | 
387 |         mask_positive[offset_pos::acceleration] = 1
388 |         mask_negative[offset_neg::acceleration] = 1
389 |         mask_negative = np.flip(mask_negative)
390 | 
391 |         mask = np.concatenate((mask_positive, mask_negative))
392 | 
393 |         return np.fft.fftshift(mask)  # shift mask and return
394 | 
395 | 
396 | class MagicMaskFractionFunc(MagicMaskFunc):
397 |     """
398 |     Masking function for exploiting conjugate symmetry via offset-sampling.
399 | 
400 |     This function applies the mask described in the following paper:
401 | 
402 |     Defazio, A. (2019). Offset Sampling Improves Deep Learning based
403 |     Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
404 |     arXiv:1912.01101.
405 | 
406 |     It is essentially an equispaced mask with an offset for the opposite side
407 |     of k-space. Since MRI images often exhibit approximate conjugate k-space
408 |     symmetry, this mask is generally more efficient than a standard equispaced
409 |     mask.
410 | 
411 |     Similarly to ``EquispacedMaskFractionFunc``, this method exactly matches
412 |     the target acceleration by adjusting the offsets.
413 |     """
414 | 
415 |     def sample_mask(
416 |         self,
417 |         shape: Sequence[int],
418 |         offset: Optional[int],
419 |     ) -> Tuple[torch.Tensor, torch.Tensor, int]:
420 |         """
421 |         Sample a new k-space mask.
422 | 
423 |         This function samples and returns two components of a k-space mask: 1)
424 |         the center mask (e.g., for sensitivity map calculation) and 2) the
425 |         acceleration mask (for the edge of k-space). Both of these masks, as
426 |         well as the integer count of low frequency samples, are returned.
427 | 
428 |         Args:
429 |             shape: Shape of the k-space to subsample.
430 |             offset: Offset from 0 to begin mask (for equispaced masks).
431 | 
432 |         Returns:
433 |             A 3-tuple containing 1) the mask for the center of k-space, 2) the
434 |             mask for the high frequencies of k-space, and 3) the integer count
435 |             of low frequency samples.
436 |         """
437 |         num_cols = shape[-2]
438 |         fraction_low_freqs, acceleration = self.choose_acceleration()
439 | 
440 |         num_low_frequencies = round(num_cols * fraction_low_freqs)
441 | 
442 |         # bound the number of low frequencies between 1 and target columns
443 |         target_columns_to_sample = round(num_cols / acceleration)
444 |         num_low_frequencies = max(min(num_low_frequencies, target_columns_to_sample), 1)
445 | 
446 |         # adjust acceleration rate based on target acceleration.
447 |         adjusted_target_columns_to_sample = (
448 |             target_columns_to_sample - num_low_frequencies
449 |         )
450 |         adjusted_acceleration = 0
451 |         if adjusted_target_columns_to_sample > 0:
452 |             adjusted_acceleration = round(num_cols / adjusted_target_columns_to_sample)
453 | 
454 |         center_mask = self.reshape_mask(
455 |             self.calculate_center_mask(shape, num_low_frequencies), shape
456 |         )
457 |         accel_mask = self.reshape_mask(
458 |             self.calculate_acceleration_mask(
459 |                 num_cols, adjusted_acceleration, offset, num_low_frequencies
460 |             ),
461 |             shape,
462 |         )
463 | 
464 |         return center_mask, accel_mask, num_low_frequencies
465 | 
466 | 
467 | class Gaussian2DMaskFunc(MaskFunc):
468 |     """Gaussian 2D masking.
469 | 
470 |     Samples k-space locations from a 2D Gaussian centered on k-space, so low
471 |     frequencies are sampled densely and high frequencies only sparsely.
472 |     """
473 | 
474 |     def __init__(
475 |         self,
476 |         accelerations: Sequence[int],
477 |         stds: Sequence[float],
478 |         seed: Optional[int] = None,
479 |     ):
480 |         """Initialize a Gaussian 2D mask.
481 | 
482 |         Args:
483 |             accelerations (Sequence[int]): list of acceleration factors; when
484 |                 generating a mask, an acceleration factor from this list will be chosen
485 |             stds (Sequence[float]): list of torch.Normal scale (~std) values to choose from
486 |             seed (Optional[int], optional): Seed for selecting mask parameters. Defaults to None.
487 |         """
488 |         self.rng = np.random.RandomState(seed)
489 |         self.accelerations = accelerations
490 |         self.stds = stds
491 | 
492 |     def __call__(
493 |         self,
494 |         shape: Sequence[int],
495 |         offset: Optional[int] = None,
496 |         seed: Optional[Union[int, Tuple[int, ...]]] = None,
497 |     ) -> Tuple[torch.Tensor, float]:
498 |         if len(shape) < 3:
499 |             raise ValueError("Shape should have 3 or more dimensions")
500 | 
501 |         acceleration = self.rng.choice(self.accelerations)
502 |         std = self.rng.choice(self.stds)
503 | 
504 |         x, y = shape[-3], shape[-2]
505 |         mean_x = x // 2
506 |         mean_y = y // 2
507 |         num_samples_collected = 0
508 | 
509 |         dist = D.Normal(
510 |             loc=torch.tensor([mean_x, mean_y], dtype=torch.float32),
511 |             scale=std,
512 |         )
513 | 
514 |         N = (
515 |             int(1 / acceleration * x * y) + 10000
516 |         )  # oversample: duplicate coordinates would otherwise miss the target rate
517 |         sample_x, sample_y = (
518 |             torch.zeros(N, dtype=torch.int),
519 |             torch.zeros(N, dtype=torch.int),
520 |         )
521 | 
522 |         while num_samples_collected < N:
523 |             samples = dist.sample((N,))  # type: ignore
524 |             valid_samples = (
525 |                 (samples[:, 0] >= 0)
526 |                 & (samples[:, 0] < x)
527 |                 & (samples[:, 1] >= 0)
528 |                 & (samples[:, 1] < y)
529 |             )
530 | 
531 |             valid_x = samples[valid_samples, 0].int()
532 |             valid_y = samples[valid_samples, 1].int()
533 | 
534 |             num_to_take = min(N - num_samples_collected, valid_x.size(0))
535 |             sample_x[num_samples_collected : num_samples_collected + num_to_take] = (
536 |                 valid_x[:num_to_take]
537 |             )
538 |             sample_y[num_samples_collected : num_samples_collected + num_to_take] = (
539 |                 valid_y[:num_to_take]
540 |             )
541 |             num_samples_collected += num_to_take
542 | 
543 |         mask = torch.zeros((x, y))
544 |         mask[sample_x, sample_y] = 1.0
545 | 
546 |         # broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
547 |         mask = mask.unsqueeze(-1)  # (x, y, 1)
548 |         mask = mask.unsqueeze(0)  # (1, x, y, 1)
549 |         mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
550 | 
551 |         # num_low_freqs doesn't make sense here, so just return std (a number);
552 |         # returning None doesn't work since we can't stack for multiple batches
553 |         return mask, std
554 | 
555 | 
556 | class Poisson2DMaskFunc(MaskFunc):
557 |     """
558 |     Variable Density Poisson Disk Sampling
559 |     https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.poisson.html#sigpy.mri.poisson
560 |     """
561 | 
562 |     def __init__(
563 |         self,
564 |         accelerations: Sequence[int],
565 |         stds: None,
566 |         seed: Optional[int] = None,
567 |         use_cache: bool = True,
568 |     ):
569 |         """Initialize a VDPD (Poisson) mask.
570 | 
571 |         Args:
572 |             accelerations (Sequence[int]): list of acceleration factors to
573 |                 choose from
574 |             stds: Dummy param. Do not pass a value. Defaults to None.
575 |             seed (Optional[int], optional): Seed for selecting mask params.
576 |                 Defaults to None.
577 |         """
578 |         self.rng = np.random.RandomState(seed)
579 |         self.accelerations = accelerations
580 |         self.use_cache = use_cache
581 |         if use_cache:
582 |             self.cache: Dict[int, np.ndarray] = dict()
583 |             for acc in accelerations:
584 |                 assert os.path.exists(f"fastmri/poisson_cache/poisson_{acc}x.npy")
585 |                 self.cache[acc] = np.load(f"fastmri/poisson_cache/poisson_{acc}x.npy")
586 | 
587 |     def __call__(
588 |         self,
589 |         shape: Sequence[int],
590 |         offset: Optional[int] = None,
591 |         seed: Optional[Union[int, Tuple[int, ...]]] = None,
592 |     ) -> Tuple[torch.Tensor, float]:
593 |         if self.use_cache:
594 |             acceleration = self.rng.choice(self.accelerations)
595 |             return torch.from_numpy(self.cache[acceleration]), 1.0  # type: ignore
596 |         if len(shape) < 3:
597 |             raise ValueError("Shape should have 3 or more dimensions")
598 | 
599 |         acceleration = self.rng.choice(self.accelerations)
600 |         x, y = shape[-3], shape[-2]
601 | 
602 |         mask = poisson(img_shape=(x, y), accel=acceleration, dtype=np.float32)
603 |         mask = torch.from_numpy(mask)
604 | 
605 |         # broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size
606 |         mask = mask.unsqueeze(-1)  # (x, y, 1)
607 |         mask = mask.unsqueeze(0)  # (1, x, y, 1)
608 |         mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()
609 | 
610 |         # num low freqs doesn't make sense here, so we return an arbitrary value
611 |         return mask, 100.0
612 | 
613 | 
614 | class Radial2DMaskFunc(MaskFunc):
615 |     """
616 |     Radial trajectory MRI masking method.
617 |     https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.radial.html#sigpy.mri.radial
618 |     """
619 | 
620 |     def __init__(
621 |         self,
622 |         accelerations: Sequence[int],
623 |         arms: Optional[Sequence[int]],
624 |         seed: Optional[int] = None,
625 |     ):
626 |         """
627 |         Initialize a radial mask.
628 | 
629 |         Args:
630 |             accelerations (Sequence[int]): list of acceleration factors to
631 |                 choose from
632 |             arms: Number of radial arms.
633 |             seed (Optional[int], optional): Seed for selecting mask params.
634 |                 Defaults to None.
635 | """ 636 | self.rng = np.random.RandomState(seed) 637 | self.accelerations = accelerations 638 | self.arms = arms 639 | 640 | def __call__( 641 | self, 642 | shape: Sequence[int], 643 | offset: Optional[int] = None, 644 | seed: Optional[Union[int, Tuple[int, ...]]] = None, 645 | ) -> Tuple[torch.Tensor, torch.Tensor, int]: 646 | if len(shape) < 3: 647 | raise ValueError("Shape should have 3 or more dimensions") 648 | 649 | acceleration = self.rng.choice(self.accelerations) 650 | x, y = shape[-3], shape[-2] 651 | npoints = int(x * y * (1 / acceleration)) 652 | if self.arms: 653 | arms = self.rng.choice(self.arms) 654 | else: 655 | points_per_arm = x // 3 656 | arms = npoints // points_per_arm 657 | 658 | # calculate radial parameters to satisfy acceleration factor 659 | ntr = arms # num radial lines 660 | nro = npoints // arms # num points on each radial line 661 | ndim = 2 # 2D 662 | 663 | # gen trajectory w/ shape (ntr, nro, ndim) 664 | traj = radial( 665 | coord_shape=[ntr, nro, ndim], 666 | img_shape=(x, y), 667 | golden=True, 668 | dtype=int, 669 | ) 670 | 671 | mask = torch.zeros(x, y, dtype=torch.float32) 672 | x_coords = traj[..., 0].flatten() + (x // 2) 673 | y_coords = traj[..., 1].flatten() + (y // 2) 674 | mask[x_coords, y_coords] = 1.0 675 | 676 | # broadcasting mask (x, y) --> (N, x, y, C) C=2, N=batch_size 677 | mask = mask.unsqueeze(-1) # (x, y, 1) 678 | mask = mask.unsqueeze(0) # (1, x, y, 1) 679 | mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone() 680 | 681 | # num low freqs doesn't make sense here, so we return arbitrary value 1.0 682 | return mask, 100.0 683 | 684 | 685 | class Spiral2DMaskFunc(MaskFunc): 686 | """ 687 | Spiral trajectory MRI masking method. 688 | https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.spiral.html#sigpy.mri.spiral 689 | """ 690 | 691 | def __init__( 692 | self, 693 | accelerations: Sequence[int], 694 | arms: Sequence[int], 695 | seed: Optional[int] = None, 696 | ): 697 | """ 698 | initialize Radial mask 699 | 700 | Args: 701 | accelerations (Sequence[int]): list of acceleration factors to 702 | choose from 703 | arms: Number of radial arms. 704 | seed (Optional[int], optional): Seed for selecting mask params. 705 | Defaults to None. 706 | """ 707 | self.rng = np.random.RandomState(seed) 708 | self.accelerations = accelerations 709 | self.arms = arms 710 | 711 | def __call__( 712 | self, 713 | shape: Sequence[int], 714 | offset: Optional[int] = None, 715 | seed: Optional[Union[int, Tuple[int, ...]]] = None, 716 | ) -> Tuple[torch.Tensor, torch.Tensor, int]: 717 | raise (NotImplementedError("Spiral2D not implemented")) 718 | 719 | 720 | def create_mask_for_mask_type( 721 | mask_type_str: str, 722 | center_fractions: Optional[Sequence], 723 | accelerations: Sequence[int], 724 | ) -> MaskFunc: 725 | """ 726 | Creates a mask of the specified type. 727 | 728 | Args: 729 | center_fractions: What fraction of the center of k-space to include. 730 | accelerations: What accelerations to apply. 731 | 732 | Returns: 733 | A mask func for the target mask type. 
734 | """ 735 | if mask_type_str == "random": 736 | return RandomMaskFunc(center_fractions, accelerations) 737 | elif mask_type_str == "equispaced": 738 | return EquiSpacedMaskFunc(center_fractions, accelerations) 739 | elif mask_type_str == "equispaced_fraction": 740 | return EquispacedMaskFractionFunc(center_fractions, accelerations) 741 | elif mask_type_str == "magic": 742 | return MagicMaskFunc(center_fractions, accelerations) 743 | elif mask_type_str == "magic_fraction": 744 | return MagicMaskFractionFunc(center_fractions, accelerations) 745 | elif mask_type_str == "gaussian_2d": 746 | return Gaussian2DMaskFunc( 747 | stds=center_fractions, 748 | accelerations=accelerations, 749 | ) 750 | elif mask_type_str == "poisson_2d": 751 | return Poisson2DMaskFunc( 752 | accelerations=accelerations, 753 | stds=None, 754 | ) 755 | elif mask_type_str == "radial_2d": 756 | return Radial2DMaskFunc( 757 | accelerations=accelerations, 758 | arms=([int(arm) for arm in center_fractions] if center_fractions else None), 759 | ) 760 | elif mask_type_str == "spiral_2d": 761 | raise NotImplementedError("spiral_2d not implemented") 762 | else: 763 | raise ValueError(f"{mask_type_str} not supported") 764 | --------------------------------------------------------------------------------