├── .gitignore ├── Command Cheatsheet.docx ├── LICENSE ├── README.md ├── active-mri.def ├── activemri ├── __init__.py ├── baselines │ ├── __init__.py │ ├── differential_sampler.py │ ├── evaluation.py │ ├── loupe.py │ ├── loupe_codes │ │ ├── evaluate.py │ │ ├── layers.py │ │ ├── reconstructors.py │ │ ├── samplers.py │ │ └── transforms.py │ ├── non_rl.py │ ├── sequential_sampling_codes │ │ ├── __init__.py │ │ ├── conv_sampler.py │ │ ├── joint_reconstructor.py │ │ └── sampler2d.py │ └── simple_baselines.py ├── data │ ├── __init__.py │ ├── real_brain_data.py │ ├── real_knee_data.py │ ├── singlecoil_knee_data.py │ ├── splits │ │ └── knee_singlecoil │ │ │ ├── test.txt │ │ │ └── val.txt │ └── transforms.py └── envs │ ├── __init__.py │ ├── envs.py │ ├── loupe_envs.py │ ├── masks.py │ └── util.py ├── docs ├── 1d_animation.gif ├── 2d_animation.gif ├── GETTING_STARTED.md ├── INSTALL.md └── teaser.png ├── examples ├── loupe │ ├── test_loupe.sh │ ├── train_loupe.py │ └── train_loupe.sh └── sequential │ ├── __init__.py │ ├── test_sequential.sh │ ├── train_sequential.py │ └── train_sequential.sh ├── figure_reproduction ├── 1d_data │ ├── 4x_lc_loupe_785FD0.pkl │ ├── 4x_lc_loupe_85A60D.pkl │ ├── 4x_lc_loupe_FB9BE7.pkl │ ├── 4x_lc_seq0_191014.pkl │ ├── 4x_lc_seq0_997E9B.pkl │ ├── 4x_lc_seq0_B2F31C.pkl │ ├── 4x_lc_seq1_0F3D13.pkl │ ├── 4x_lc_seq1_8D8ED1.pkl │ ├── 4x_lc_seq1_9BA1F7.pkl │ ├── 4x_lc_seq2_806982.pkl │ ├── 4x_lc_seq2_A0B684.pkl │ ├── 4x_lc_seq2_AB92C0.pkl │ ├── 4x_lc_seq4_2C2751.pkl │ ├── 4x_lc_seq4_8E14D6.pkl │ └── 4x_lc_seq4_C8B962.pkl ├── 2d_data │ ├── 16x_2d_loupe_1CB132.pkl │ ├── 16x_2d_loupe_CF3852.pkl │ ├── 16x_2d_loupe_F5C5C1.pkl │ ├── 16x_2d_seq0_04BC1F.pkl │ ├── 16x_2d_seq0_09DB18.pkl │ ├── 16x_2d_seq0_0F247E.pkl │ ├── 16x_2d_seq1_13BF86.pkl │ ├── 16x_2d_seq1_9CB984.pkl │ ├── 16x_2d_seq1_E154DC.pkl │ ├── 16x_2d_seq2_8934DD.pkl │ ├── 16x_2d_seq2_B11CEC.pkl │ ├── 16x_2d_seq2_B68595.pkl │ ├── 16x_2d_seq4_203258.pkl │ ├── 16x_2d_seq4_2450E9.pkl │ ├── 
16x_2d_seq4_BD196A.pkl │ ├── 4x_2d_loupe_72B682.pkl │ ├── 4x_2d_loupe_73A783.pkl │ ├── 4x_2d_loupe_C655FA.pkl │ ├── 4x_2d_seq0_076771.pkl │ ├── 4x_2d_seq0_350405.pkl │ ├── 4x_2d_seq0_9E225B.pkl │ ├── 4x_2d_seq1_748B84.pkl │ ├── 4x_2d_seq1_F3945C.pkl │ ├── 4x_2d_seq1_F8733A.pkl │ ├── 4x_2d_seq2_244A77.pkl │ ├── 4x_2d_seq2_93381A.pkl │ ├── 4x_2d_seq2_A4DA53.pkl │ ├── 4x_2d_seq4_08BBA4.pkl │ ├── 4x_2d_seq4_C32832.pkl │ ├── 4x_2d_seq4_D0D30F.pkl │ ├── 8x_2d_loupe_4F9069.pkl │ ├── 8x_2d_loupe_5C47BE.pkl │ ├── 8x_2d_loupe_D010FA.pkl │ ├── 8x_2d_seq0_130F17.pkl │ ├── 8x_2d_seq0_72A09A.pkl │ ├── 8x_2d_seq0_CB05F5.pkl │ ├── 8x_2d_seq1_848B51.pkl │ ├── 8x_2d_seq1_948938.pkl │ ├── 8x_2d_seq1_EC7415.pkl │ ├── 8x_2d_seq2_0B402E.pkl │ ├── 8x_2d_seq2_2956B7.pkl │ ├── 8x_2d_seq2_4F5297.pkl │ ├── 8x_2d_seq4_300BD3.pkl │ ├── 8x_2d_seq4_56CED5.pkl │ └── 8x_2d_seq4_7A9703.pkl ├── figures.ipynb └── teaser_data │ ├── left_test_iter=651_step0.pkl │ ├── left_test_iter=651_step1.pkl │ ├── left_test_iter=651_step2.pkl │ ├── left_test_iter=651_step3.pkl │ ├── right_test_iter=651_step0.pkl │ ├── right_test_iter=651_step1.pkl │ ├── right_test_iter=651_step2.pkl │ └── right_test_iter=651_step3.pkl ├── requirements.txt ├── resources ├── equispaced_4x_128.pt ├── spectrum_16x_128.pt ├── spectrum_4x_128.pt └── spectrum_8x_128.pt └── split_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints_old 2 | checkpoints 3 | temp 4 | .vscode 5 | misc 6 | *__pycache__* 7 | .ipynb_checkpoints 8 | .DS_Store 9 | .idea 10 | personal_scripts* 11 | notebooks 12 | *.orig 13 | docs/_build 14 | docs/build* 15 | *.mp4 16 | *egg-info* 17 | activemri/models/custom_* 18 | activemri/envs/custom_* 19 | datasets 20 | *.png 21 | *.jpg 22 | *.pth 23 | -------------------------------------------------------------------------------- /Command Cheatsheet.docx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/Command Cheatsheet.docx -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Tianwei Yin and Zihui Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # End-to-End Sequential Sampling and Reconstruction for MR Imaging 2 | 3 |

4 |

5 |

6 | 7 | > [**End-to-End Sequential Sampling and Reconstruction for MR Imaging**](http://arxiv.org/abs/2105.06460), 8 | > Tianwei Yin*, Zihui Wu*, He Sun, Adrian V. Dalca, Yisong Yue, Katherine L. Bouman (*equal contributions) 9 | > *arXiv technical report ([arXiv 2105.06460](http://arxiv.org/abs/2105.06460))* 10 | 11 | 12 | @article{yin2021end, 13 | title={End-to-End Sequential Sampling and Reconstruction for MR Imaging}, 14 | author={Yin, Tianwei and Wu, Zihui and Sun, He and Dalca, Adrian V. and Yue, Yisong and Bouman, Katherine L.}, 15 | journal={arXiv preprint}, 16 | year={2021}, 17 | } 18 | 19 | 20 | ## Contact 21 | Any questions or suggestions are welcome! 22 | 23 | Tianwei Yin [yintianwei@utexas.edu](mailto:yintianwei@utexas.edu) 24 | Zihui Wu [zwu2@caltech.edu](mailto:zwu2@caltech.edu) 25 | 26 | ## Abstract 27 | Accelerated MRI shortens acquisition time by subsampling in the measurement k-space. Recovering a high-fidelity anatomical image from subsampled measurements requires close cooperation between two components: (1) a sampler that chooses the subsampling pattern and (2) a reconstructor that recovers images from incomplete measurements. In this paper, we leverage the sequential nature of MRI measurements, and propose a fully differentiable framework that jointly learns a sequential sampling policy simultaneously with a reconstruction strategy. This co-designed framework is able to adapt during acquisition in order to capture the most informative measurements for a particular target (Figure 1). Experimental results on the fastMRI knee dataset demonstrate that the proposed approach successfully utilizes intermediate information during the sampling process to boost reconstruction performance. In particular, our proposed method outperforms the current state-of-the-art baseline on up to 96.96% of test samples. We also investigate the individual and collective benefits of the sequential sampling and co-design strategies.
28 | 29 | ## Main Results 30 | 31 | #### Line-constrained Sampling 32 | 33 | | Model | Acceleration | SSIM | 34 | |---------|---------------|------| 35 | | Loupe | 4x | 89.5 | 36 | | Seq1 | 4x | 90.8 | 37 | | Seq4 | 4x | 91.2 | 38 | 39 | 40 | #### 2D Point Sampling 41 | 42 | | Model | Acceleration | SSIM | 43 | |---------|---------------|------| 44 | | Loupe | 4x | 92.4 | 45 | | Seq1 | 4x | 92.7 | 46 | | Seq4 | 4x | 92.9 | 47 | 48 | ## Installation 49 | 50 | Please refer to [INSTALL](docs/INSTALL.md) to set up libraries. 51 | 52 | ## Getting Started 53 | 54 | Please refer to [GETTING_STARTED](docs/GETTING_STARTED.md) to prepare the data and follow the instructions to reproduce all results. 55 | You can use [figures.ipynb](figure_reproduction/figures.ipynb) to reproduce all the figures in the paper. 56 | 57 | ## License 58 | 59 | Seq-MRI is released under the MIT license (see [LICENSE](LICENSE)). It is built on a fork of [active-mri-acquisition](https://github.com/facebookresearch/active-mri-acquisition). We thank the original authors for their great codebase. 60 | -------------------------------------------------------------------------------- /active-mri.def: -------------------------------------------------------------------------------- 1 | bootstrap: docker 2 | from: jupyter/scipy-notebook 3 | 4 | %post 5 | apt-get update 6 | apt-get -y upgrade 7 | apt-get clean 8 | 9 | # Install anaconda if it is not installed yet 10 | if [ ! -d /opt/conda ]; then 11 | wget https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh \ 12 | -O ~/conda.sh && \ 13 | bash ~/conda.sh -b -p /opt/conda && \ 14 | rm ~/conda.sh 15 | fi 16 | 17 | # Set anaconda path 18 | export PATH=/opt/conda/bin:$PATH 19 | 20 | # Update conda; NOTE: for some reason this doesn't actually update conda at the moment...
21 | conda update -y -n base conda 22 | 23 | # Download alternative version of python if needed (default is 3.8) 24 | conda install -y python=3.7 25 | 26 | # Install conda packages; -y is used to silently install 27 | conda config --add channels conda-forge 28 | 29 | conda install -y numpy 30 | conda install -y scipy 31 | conda install -y joblib 32 | conda install -y tqdm 33 | conda install -y vim 34 | conda install -y pynfft 35 | 36 | conda clean --tarballs 37 | 38 | # Install git and pip3 39 | apt-get -y install git-all 40 | apt-get -y install python3-pip 41 | 42 | %environment 43 | export PYTHONPATH=/opt/conda/lib/python3.7/site-packages:$PYTHONPATH 44 | 45 | export LC_ALL=C 46 | -------------------------------------------------------------------------------- /activemri/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import data, envs 7 | 8 | __all__ = ["data", "envs", "experimental"] 9 | -------------------------------------------------------------------------------- /activemri/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import abc 7 | from typing import Any, Dict, List 8 | 9 | 10 | class Policy: 11 | """ A basic policy interface. """ 12 | 13 | def __init__(self, *args, **kwargs): 14 | pass 15 | 16 | @abc.abstractmethod 17 | def get_action(self, obs: Dict[str, Any], **kwargs: Any) -> List[int]: 18 | """ Returns a list of actions for a batch of observations. 
""" 19 | pass 20 | 21 | def __call__(self, obs: Dict[str, Any], **kwargs: Any) -> List[int]: 22 | return self.get_action(obs, **kwargs) 23 | 24 | 25 | from .simple_baselines import ( 26 | RandomPolicy, 27 | RandomLowBiasPolicy, 28 | LowestIndexPolicy, 29 | OneStepGreedyOracle, 30 | ) 31 | from .evaluation import evaluate 32 | 33 | __all__ = [ 34 | "RandomPolicy", 35 | "RandomLowBiasPolicy", 36 | "LowestIndexPolicy", 37 | "OneStepGreedyOracle", 38 | "CVPR19Evaluator", 39 | "DDQN", 40 | "DDQNTrainer", 41 | "evaluate", 42 | ] 43 | -------------------------------------------------------------------------------- /activemri/baselines/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Any, Dict, List, Optional, Tuple 7 | 8 | import numpy as np 9 | import argparse 10 | import numpy.matlib 11 | import activemri.baselines as baselines 12 | import activemri.envs as envs 13 | 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | def evaluate( 18 | args: argparse.Namespace, 19 | env: envs.envs.ActiveMRIEnv, 20 | policy: baselines.Policy, 21 | num_episodes: int, 22 | seed: int, 23 | split: str, 24 | verbose: Optional[bool] = False, 25 | ) -> Tuple[Dict[str, np.ndarray], List[Tuple[Any, Any]]]: 26 | env.seed(seed) 27 | if split == "test": 28 | env.set_test() 29 | elif split == "val": 30 | env.set_val() 31 | else: 32 | raise ValueError(f"Invalid evaluation split: {split}.") 33 | 34 | score_keys = env.score_keys() 35 | all_scores = dict( 36 | (k, np.zeros((num_episodes * env.num_parallel_episodes, env.budget + 1))) 37 | for k in score_keys 38 | ) 39 | all_img_ids = [] 40 | trajectories_written = 0 41 | for episode in range(num_episodes): 42 | print('Now running episode: '+str(episode)) 43 | step = 0 44 | obs, meta = 
env.reset() 45 | # plt.imshow(obs["reconstruction"][0,:,:,0]) 46 | # plt.savefig('random.png') 47 | # input("Press Enter to continue...") 48 | if not obs: 49 | break # no more images 50 | # in case the last batch is smaller 51 | actual_batch_size = len(obs["reconstruction"]) 52 | if verbose: 53 | msg = ", ".join( 54 | [ 55 | f"({meta['fname'][i]}, {meta['slice_id'][i]})" 56 | for i in range(actual_batch_size) 57 | ] 58 | ) 59 | print(f"Read images: {msg}") 60 | for i in range(actual_batch_size): 61 | all_img_ids.append((meta["fname"][i], meta["slice_id"][i])) 62 | batch_idx = slice( 63 | trajectories_written, trajectories_written + actual_batch_size 64 | ) 65 | for k in score_keys: 66 | all_scores[k][batch_idx, step] = meta["current_score"][k] 67 | trajectories_written += actual_batch_size 68 | all_done = False 69 | while not all_done: 70 | step += 1 71 | action = policy.get_action(obs) 72 | obs, reward, done, meta, gt = env.step(action) 73 | 74 | for k in score_keys: 75 | all_scores[k][batch_idx, step] = meta["current_score"][k] 76 | all_done = all(done) 77 | 78 | # visualize the first image 79 | if episode == 0: 80 | print('Visualize') 81 | num_row, num_col = obs["reconstruction"][0,:,:,0].cpu().detach().numpy().shape 82 | 83 | kspace_t = obs["reconstruction"][0,:,:,:].fft(2,normalized=False) 84 | kspace_t = (kspace_t ** 2).sum(dim=-1).sqrt().cpu().detach().numpy() 85 | action_t = np.zeros((num_row, num_col)) 86 | action_play = action[0] if type(action)==list else action 87 | action_t[:,action_play] = 1 88 | reward_play = reward[0] 89 | # gt["gt_kspace"] is a List 90 | gt_kspace_t = np.sqrt((gt["gt_kspace"][0] ** 2).sum(-1)) 91 | 92 | eval_vis = {"phi_t": obs["reconstruction"][0,:,:,0].cpu().detach().numpy(), 93 | "kspace_t": kspace_t, 94 | "mask_t": np.matlib.repmat(obs["mask"][0:1,:].cpu().detach().numpy(),num_row,1), 95 | "action_t": action_t, 96 | "action_t_val": action_play, 97 | "reward_t_val": reward_play, 98 | "gt_t": gt["gt"][0,:,:,0].numpy(), 99 | 
"gt_kspace_t": gt_kspace_t 100 | } 101 | if not 'output_dir' in vars(args).keys(): 102 | args.output_dir = args.checkpoints_dir 103 | visualize_and_save(eval_vis, step, episode, args.budget, args.output_dir) 104 | 105 | if step == 1: 106 | import os 107 | os.makedirs(args.output_dir+'mask/', exist_ok=True) 108 | 109 | left = obs['mask'][0][:len(obs['mask'][0])//2].numpy() 110 | right = np.flip(obs['mask'][0][len(obs['mask'][0])//2:].numpy()) 111 | visualize_symmetry(left, right, step, episode, env.budget, args.output_dir) 112 | 113 | elif all_done: 114 | left = obs['mask'][0][:len(obs['mask'][0])//2].numpy() 115 | right = np.flip(obs['mask'][0][len(obs['mask'][0])//2:].numpy()) 116 | visualize_symmetry(left, right, step, episode, env.budget, args.output_dir) 117 | 118 | # print('-------') 119 | # print(step) 120 | # print(obs['mask']) 121 | # print(list(map(int,left+right))) 122 | # print(meta["current_score"][k]) 123 | # input("Press Enter to continue...") 124 | if episode % 10 == 0: 125 | np.save(os.path.join(args.output_dir, "scores_0-{}.npy".format(episode)), all_scores) 126 | 127 | for k in score_keys: 128 | all_scores[k] = all_scores[k][: len(all_img_ids), :] 129 | return all_scores, all_img_ids 130 | 131 | 132 | def visualize_and_save(eval_vis: Dict[str, Any], step: int, episode: int, num_train_steps: int, checkpoints_dir: str): 133 | _, num_col = eval_vis["phi_t"].shape 134 | 135 | fig, ax = plt.subplots(nrows=2, ncols=4, figsize=[20, 8]) 136 | 137 | cmap = "viridis" if num_col == 32 else "gray" 138 | 139 | sp1 = ax[0, 0].imshow(eval_vis["phi_t"], cmap=cmap) 140 | sp2 = ax[0, 1].imshow(np.fft.fftshift(np.log(eval_vis["kspace_t"]))) 141 | sp3 = ax[0, 2].imshow(eval_vis["gt_t"], cmap=cmap) 142 | sp4 = ax[0, 3].imshow(np.fft.fftshift(np.log(eval_vis["gt_kspace_t"]))) 143 | sp5 = ax[1, 0].imshow(np.fft.fftshift(eval_vis["mask_t"]-eval_vis["action_t"])) 144 | sp6 = ax[1, 1].imshow(np.fft.fftshift(eval_vis["action_t"])) 145 | sp7 = ax[1, 
2].imshow(np.fft.fftshift(eval_vis["mask_t"])) 146 | ax[0, 0].title.set_text('Recon. at time t (phi_t)') 147 | ax[0, 1].title.set_text('Log k-space of phi_t') 148 | ax[0, 2].title.set_text('Ground truth of phi_t') 149 | ax[0, 3].title.set_text('GT log k-space of phi_t') 150 | ax[1, 0].title.set_text('Mask at time t-1') 151 | ax[1, 1].title.set_text('Action at time t') 152 | ax[1, 2].title.set_text('Mask at time t') 153 | fig.colorbar(sp1, ax=ax[0, 0]) 154 | fig.colorbar(sp2, ax=ax[0, 1]) 155 | fig.colorbar(sp3, ax=ax[0, 2]) 156 | fig.colorbar(sp4, ax=ax[0, 3]) 157 | fig.colorbar(sp5, ax=ax[1, 0]) 158 | fig.colorbar(sp6, ax=ax[1, 1]) 159 | fig.colorbar(sp7, ax=ax[1, 2]) 160 | ax[1, 3].text(0, 0.8, r'action_t: ', fontsize=15) 161 | ax[1, 3].text(0.1, 0.65, str(eval_vis["action_t_val"])+' in [1,'+str(num_col)+']', fontsize=15) 162 | ax[1, 3].text(0.1, 0.5, str((eval_vis["action_t_val"]+num_col/2)%num_col)+' after fftshift', fontsize=15) 163 | ax[1, 3].text(0, 0.3, r'reward_t: ', fontsize=15) 164 | ax[1, 3].text(0.1, 0.15, str(eval_vis["reward_t_val"])+' (SSIM)', fontsize=15) 165 | ax[1, 3].axis('off') 166 | plt.suptitle('Step = [{}/{}]'.format(step, num_train_steps), fontsize=20) 167 | 168 | plt.savefig(checkpoints_dir+'episode='+str(episode)+'_step='+str(step)+'.png') 169 | plt.close() 170 | 171 | def visualize_symmetry(left, right, step, episode, budget, save_dir): 172 | plt.plot(list(map(int,left+right)),'.') 173 | plt.xlabel('Columns') 174 | plt.ylabel('2: Both, 1: One, 0: Neither') 175 | plt.title('Step = [{}/{}]\nEpisode = {}'.format(step, budget, 1+episode)) 176 | plt.savefig(save_dir+'mask/episode='+str(1+episode)+'_step='+str(step)+'.png') 177 | plt.close() 178 | -------------------------------------------------------------------------------- /activemri/baselines/loupe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | import time 4 | import numpy as np 5 | 6 | import torch 7 | from 
torch import nn 8 | from torch.nn import functional as F 9 | from tensorboardX import SummaryWriter 10 | 11 | import activemri.envs.loupe_envs as loupe_envs 12 | from activemri.baselines.loupe_codes.samplers import * 13 | from activemri.baselines.loupe_codes.reconstructors import * 14 | from activemri.baselines.loupe_codes.layers import * 15 | from activemri.baselines.loupe_codes.transforms import * 16 | from ..envs.util import compute_ssim_torch, compute_psnr_torch 17 | import os 18 | from torch.autograd import Function 19 | import matplotlib.pyplot as plt 20 | 21 | from typing import Any, Dict, List, Optional, Tuple 22 | 23 | 24 | class LOUPE(nn.Module): 25 | """ 26 | Reimplementation of Loupe (https://arxiv.org/abs/1907.11374) sampling-reconstruction framework 27 | with straight through estimator (https://arxiv.org/abs/1308.3432). 28 | 29 | The model gets two components: A learned probability mask (Sampler) and a UNet reconstructor. 30 | 31 | Args: 32 | in_chans (int): Number of channels in the input to the reconstructor (2 for complex image, 1 for real image). 33 | out_chans (int): Number of channels in the output to the reconstructor (default is 1 for real image). 34 | chans (int): Number of output channels of the first convolution layer. 35 | num_pool_layers (int): Number of down-sampling and up-sampling layers. 36 | drop_prob (float): Dropout probability. 37 | shape ([int. int]): Shape of the reconstructed image 38 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to 39 | deterministic state. 40 | sparsity (float): Predefined sparsity of the learned probability mask. 1 / acceleration_ratio 41 | line_constrained (bool): Sample kspace measurements column by column 42 | conjugate_mask (bool): For real image, the corresponding kspace measurements have conjugate symmetry property 43 | (point reflection). Therefore, the information in the left half of the kspace image is the same as the 44 | other half. 
To take advantage of this, we can force the model to only sample right half of the kspace 45 | (when conjugate_mask is set to True) 46 | preselect (bool): preselect DC components 47 | bi_dir (bool): sample from both vertical and horizontal lines 48 | """ 49 | def __init__( 50 | self, 51 | in_chans: int, 52 | out_chans: int, 53 | chans: int = 64, 54 | num_pool_layers: int = 4, 55 | drop_prob: float = 0, 56 | shape: List[int] = [320, 320], 57 | slope: float = 5, 58 | sparsity: float = 0.25, 59 | line_constrained: bool = False, 60 | conjugate_mask: bool = False, 61 | preselect: bool = False, 62 | preselect_num: int = 2, 63 | bi_dir: bool = False, 64 | random: bool = False, 65 | poisson: bool = False, 66 | spectrum: bool = False, 67 | equispaced: bool = False 68 | ): 69 | super().__init__() 70 | assert conjugate_mask is False 71 | 72 | self.preselect =preselect 73 | 74 | if not line_constrained: 75 | sparsity = (sparsity - preselect_num**2 / (shape[0]*shape[1])) if preselect else sparsity 76 | else: 77 | sparsity = (sparsity - preselect_num / shape[1]) if preselect else sparsity 78 | 79 | # for backward compatability 80 | self.samplers = nn.ModuleList() 81 | 82 | if bi_dir: 83 | assert 0 84 | self.samplers.append(BiLOUPESampler(shape, slope, sparsity, line_constrained, conjugate_mask, preselect, preselect_num)) 85 | else: 86 | self.samplers.append(LOUPESampler(shape, slope, sparsity, line_constrained, conjugate_mask, preselect, preselect_num, 87 | random=random, poisson=poisson, spectrum=spectrum, equispaced=equispaced)) 88 | 89 | self.reconstructor = LOUPEUNet(in_chans, out_chans, chans, num_pool_layers, drop_prob) 90 | 91 | self.sparsity = sparsity 92 | self.conjugate_mask = conjugate_mask 93 | self.data_for_vis = {} 94 | 95 | if in_chans == 1: 96 | assert self.conjugate_mask, "Reconstructor (denoiser) only take the real part of the ifft output" 97 | 98 | def forward(self, target, kspace, seed=None): 99 | """ 100 | Args: 101 | kspace (torch.Tensor): Input tensor of 
shape NHWC (kspace data) 102 | Returns: 103 | (torch.Tensor): Output tensor of shape NCHW (reconstructed image ) 104 | """ 105 | 106 | # choose kspace sampling location 107 | # masked_kspace: NHWC 108 | masked_kspace, mask, neg_entropy, data_to_vis_sampler = self.samplers[0](kspace, self.sparsity) 109 | 110 | self.data_for_vis.update(data_to_vis_sampler) 111 | 112 | # Inverse Fourier Transform to get zero filled solution 113 | # NHWC to NCHW 114 | zero_filled_recon = transforms.fftshift(transforms.ifft2(masked_kspace),dim=(1,2)).permute(0, -1, 1, 2) 115 | 116 | if self.conjugate_mask: 117 | # only the real part of the ifft output is the same as the original image 118 | # when you only sample in one half of the kspace 119 | recon = self.reconstructor(zero_filled_recon[:,0:1,:,:]) 120 | else: 121 | recon = self.reconstructor(zero_filled_recon, 0) 122 | 123 | self.data_for_vis.update({'input': target[0,0,:,:].cpu().detach().numpy(), 124 | 'kspace': transforms.complex_abs(transforms.fftshift(kspace[0,:,:,:],dim=(0,1))).cpu().detach().numpy(), 125 | 'masked_kspace': transforms.complex_abs(transforms.fftshift(masked_kspace[0,:,:,:],dim=(0,1))).cpu().detach().numpy(), 126 | 'zero_filled_recon': zero_filled_recon[0,0,:,:].cpu().detach().numpy(), 127 | 'recon': recon[0,0,:,:].cpu().detach().numpy()}) 128 | 129 | pred_dict = {'output': recon.norm(dim=1, keepdim=True), 'energy': neg_entropy, 'mask': mask} 130 | 131 | return pred_dict 132 | 133 | def loss(self, pred_dict, target_dict, meta, loss_type): 134 | """ 135 | Args: 136 | pred_dict: 137 | output: reconstructed image from downsampled kspace measurement 138 | energy: negative entropy of the probability mask 139 | mask: the binazried sampling mask (used for visualization) 140 | 141 | target_dict: 142 | target: original fully sampled image 143 | 144 | meta: 145 | recon_weight: weight of reconstruction loss 146 | entropy_weight: weight of the entropy loss (to encourage exploration) 147 | """ 148 | target = 
target_dict['target'] 149 | pred = pred_dict['output'] 150 | energy = pred_dict['energy'] 151 | 152 | if loss_type == 'l1': 153 | reconstruction_loss = F.l1_loss(pred, target, size_average=True) 154 | elif loss_type == 'ssim': 155 | reconstruction_loss = -torch.mean(compute_ssim_torch(pred, target)) 156 | elif loss_type == 'psnr': 157 | reconstruction_loss = - torch.mean(compute_psnr_torch(pred, target)) 158 | else: 159 | raise NotImplementedError 160 | 161 | entropy_loss = torch.mean(energy) 162 | 163 | loss = entropy_loss * meta['entropy_weight'] + reconstruction_loss * meta['recon_weight'] 164 | 165 | log_dict = {'Total Loss': loss.item(), 'Entropy': entropy_loss.item(), 'Reconstruction': reconstruction_loss.item()} 166 | 167 | return loss, log_dict 168 | 169 | def show_mask(self): 170 | """ 171 | Return: 172 | (np.ndarray) the learned undersampling mask shape HW 173 | """ 174 | H, W = self.samplers[0].shape 175 | pseudo_image = torch.zeros(1, H, W, 1) 176 | 177 | return self.samplers[0].mask(pseudo_image, self.sparsity) 178 | 179 | def visualize_and_save(self, options, epoch, data_for_vis_name): 180 | fig, ax = plt.subplots(nrows=2, ncols=5, figsize=[18, 6]) 181 | 182 | cmap = "viridis" if options.resolution[1] == 32 else "gray" 183 | 184 | sp1 = ax[0, 0].imshow(self.data_for_vis['input'], cmap=cmap) 185 | sp2 = ax[0, 1].imshow(np.log(self.data_for_vis['kspace'])) 186 | sp3 = ax[0, 2].imshow(np.log(self.data_for_vis['masked_kspace'])) 187 | sp4 = ax[0, 3].imshow(self.data_for_vis['zero_filled_recon'], cmap=cmap) 188 | sp5 = ax[0, 4].imshow(self.data_for_vis['recon'], cmap=cmap) 189 | sp6 = ax[1, 0].imshow(self.data_for_vis['prob_mask'], aspect='auto') 190 | sp7 = ax[1, 1].imshow(self.data_for_vis['rescaled_mask'], aspect='auto') 191 | sp8 = ax[1, 2].imshow(self.data_for_vis['binarized_mask'], aspect='auto') 192 | ax[0, 0].title.set_text('Input image') 193 | ax[0, 1].title.set_text('Log k-space of the input') 194 | ax[0, 2].title.set_text('Undersampled log 
k-space') 195 | ax[0, 3].title.set_text('Zero-filled reconstruction') 196 | ax[0, 4].title.set_text('Reconstruction') 197 | ax[1, 0].title.set_text('Probabilistic mask') 198 | ax[1, 1].title.set_text('Rescaled mask') 199 | ax[1, 2].title.set_text('Binary mask') 200 | fig.colorbar(sp1, ax=ax[0, 0]) 201 | fig.colorbar(sp2, ax=ax[0, 1]) 202 | fig.colorbar(sp3, ax=ax[0, 2]) 203 | fig.colorbar(sp4, ax=ax[0, 3]) 204 | fig.colorbar(sp5, ax=ax[0, 4]) 205 | fig.colorbar(sp6, ax=ax[1, 0]) 206 | fig.colorbar(sp7, ax=ax[1, 1]) 207 | fig.colorbar(sp8, ax=ax[1, 2]) 208 | ax[1, 3].axis('off') 209 | ax[1, 4].axis('off') 210 | plt.suptitle('Epoch = [{}/{}]'.format(1 + epoch, options.num_epochs), fontsize=20) 211 | 212 | if not os.path.isdir(options.visualization_dir): 213 | os.mkdir(options.visualization_dir) 214 | 215 | plt.savefig(str(options.visualization_dir)+'/'+data_for_vis_name+'.png') 216 | plt.close() 217 | -------------------------------------------------------------------------------- /activemri/baselines/loupe_codes/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | from argparse import ArgumentParser 11 | 12 | import h5py 13 | import numpy as np 14 | from runstats import Statistics 15 | from skimage.metrics import structural_similarity, peak_signal_noise_ratio 16 | from fastmri.data import transforms 17 | 18 | 19 | def mse(gt, pred): 20 | """ Compute Mean Squared Error (MSE) """ 21 | return np.mean((gt - pred) ** 2) 22 | 23 | 24 | def nmse(gt, pred): 25 | """ Compute Normalized Mean Squared Error (NMSE) """ 26 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 27 | 28 | 29 | def psnr(gt, pred): 30 | """ Compute Peak Signal to Noise Ratio metric (PSNR) """ 31 | return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) 32 | 33 | 34 | def ssim(gt, pred): 35 | """ Compute Structural Similarity Index Metric (SSIM). """ 36 | return structural_similarity( 37 | gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max() 38 | ) 39 | 40 | 41 | METRIC_FUNCS = dict( 42 | MSE=mse, 43 | NMSE=nmse, 44 | PSNR=psnr, 45 | SSIM=ssim, 46 | ) 47 | 48 | 49 | class Metrics: 50 | """ 51 | Maintains running statistics for a given collection of metrics. 
52 | """ 53 | 54 | def __init__(self, metric_funcs): 55 | self.metrics = { 56 | metric: Statistics() for metric in metric_funcs 57 | } 58 | 59 | def push(self, target, recons): 60 | for metric, func in METRIC_FUNCS.items(): 61 | self.metrics[metric].push(func(target, recons)) 62 | 63 | def means(self): 64 | return { 65 | metric: stat.mean() for metric, stat in self.metrics.items() 66 | } 67 | 68 | def stddevs(self): 69 | return { 70 | metric: stat.stddev() for metric, stat in self.metrics.items() 71 | } 72 | 73 | def __repr__(self): 74 | means = self.means() 75 | stddevs = self.stddevs() 76 | metric_names = sorted(list(means)) 77 | return ' '.join( 78 | f'{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}' for name in metric_names 79 | ) 80 | 81 | 82 | def evaluate(args, recons_key): 83 | metrics = Metrics(METRIC_FUNCS) 84 | 85 | for tgt_file in args.target_path.iterdir(): 86 | with h5py.File(tgt_file, 'r') as target, h5py.File( 87 | args.predictions_path / tgt_file.name, 'r') as recons: 88 | if args.acquisition and args.acquisition != target.attrs['acquisition']: 89 | continue 90 | 91 | if args.acceleration and target.attrs['acceleration'] != args.acceleration: 92 | continue 93 | 94 | target = target[recons_key][()] 95 | recons = recons['reconstruction'][()] 96 | target = transforms.center_crop(target, (target.shape[-1], target.shape[-1])) 97 | recons = transforms.center_crop(recons, (target.shape[-1], target.shape[-1])) 98 | metrics.push(target, recons) 99 | return metrics 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 104 | parser.add_argument('--target-path', type=pathlib.Path, required=True, 105 | help='Path to the ground truth data') 106 | parser.add_argument('--predictions-path', type=pathlib.Path, required=True, 107 | help='Path to reconstructions') 108 | parser.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True, 109 | help='Which challenge') 
110 | parser.add_argument('--acceleration', type=int, default=None) 111 | parser.add_argument('--acquisition', choices=['CORPD_FBK', 'CORPDFS_FBK', 'AXT1', 'AXT1PRE', 112 | 'AXT1POST', 'AXT2', 'AXFLAIR'], default=None, 113 | help='If set, only volumes of the specified acquisition type are used ' 114 | 'for evaluation. By default, all volumes are included.') 115 | args = parser.parse_args() 116 | 117 | recons_key = 'reconstruction_rss' if args.challenge == 'multicoil' else 'reconstruction_esc' 118 | metrics = evaluate(args, recons_key) 119 | print(metrics) 120 | -------------------------------------------------------------------------------- /activemri/baselines/loupe_codes/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Portion of this code is from fastmri(https://github.com/facebookresearch/fastMRI) 3 | 4 | Copyright (c) Facebook, Inc. and its affiliates. 5 | 6 | Licensed under the MIT License. 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch.autograd import Function 13 | from activemri.baselines.loupe_codes import transforms 14 | 15 | class ConvBlock(nn.Module): 16 | """ 17 | A Convolutional Block that consists of two convolution layers each followed by 18 | instance normalization, relu activation and dropout. 19 | """ 20 | def __init__(self, in_chans, out_chans, drop_prob): 21 | """ 22 | Args: 23 | in_chans (int): Number of channels in the input. 24 | out_chans (int): Number of channels in the output. 25 | drop_prob (float): Dropout probability. 
26 | """ 27 | super(ConvBlock, self).__init__() 28 | 29 | self.in_chans = in_chans 30 | self.out_chans = out_chans 31 | self.drop_prob = drop_prob 32 | 33 | self.layers = nn.Sequential( 34 | nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1), 35 | nn.InstanceNorm2d(out_chans), 36 | nn.ReLU(), 37 | nn.Dropout2d(drop_prob), 38 | nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1), 39 | nn.InstanceNorm2d(out_chans), 40 | nn.ReLU(), 41 | nn.Dropout2d(drop_prob) 42 | ) 43 | 44 | def forward(self, input): 45 | """ 46 | Args: 47 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] 48 | 49 | Returns: 50 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] 51 | """ 52 | return self.layers(input) 53 | 54 | def __repr__(self): 55 | return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \ 56 | f'drop_prob={self.drop_prob})' 57 | 58 | 59 | class ProbMask(nn.Module): 60 | """ 61 | A learnable probablistic mask with the same shape as the kspace measurement. 62 | This learned mask samples measurements in the whole kspace 63 | """ 64 | def __init__(self, shape=[320, 320], slope=5, preselect=False, preselect_num=0): 65 | """ 66 | shape ([int. int]): Shape of the reconstructed image 67 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to 68 | deterministic state. 
69 | """ 70 | super(ProbMask, self).__init__() 71 | 72 | self.slope = slope 73 | self.preselect = preselect 74 | self.preselect_num_one_side = preselect_num // 2 75 | 76 | init_tensor = self._slope_random_uniform(shape) 77 | self.mask = nn.Parameter(init_tensor) 78 | 79 | def forward(self, input): 80 | """ 81 | Args: 82 | input (torch.Tensor): Input tensor of shape NHWC 83 | 84 | Returns: 85 | (torch.Tensor): Output tensor of shape NHWC 86 | """ 87 | logits = self.mask.view(1, input.shape[1], input.shape[2], 1) 88 | return torch.sigmoid(self.slope * logits) 89 | 90 | def _slope_random_uniform(self, shape, eps=1e-2): 91 | """ 92 | uniform random sampling mask 93 | """ 94 | temp = torch.zeros(shape).uniform_(eps, 1-eps) 95 | 96 | # logit with slope factor 97 | logits = -torch.log(1./temp-1.) / self.slope 98 | 99 | logits = logits.reshape(1, shape[0], shape[1], 1) 100 | 101 | logits[:, :self.preselect_num_one_side, :self.preselect_num_one_side] = -1e2 102 | logits[:, :self.preselect_num_one_side, -self.preselect_num_one_side:] = -1e2 103 | logits[:, -self.preselect_num_one_side:, :self.preselect_num_one_side] = -1e2 104 | logits[:, -self.preselect_num_one_side:, -self.preselect_num_one_side:] = -1e2 105 | 106 | return logits 107 | 108 | class HalfProbMask(nn.Module): 109 | """ 110 | A learnable probablistic mask with the same shape as half of the kspace measurement (to force the model 111 | to not take conjugate symmetry points) 112 | """ 113 | def __init__(self, shape=[320, 320], slope=5): 114 | """ 115 | shape ([int. int]): Shape of the reconstructed image 116 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to 117 | deterministic state. 
118 | """ 119 | super(HalfProbMask, self).__init__() 120 | 121 | self.slope = slope 122 | init_tensor = self._slope_random_uniform(shape) 123 | self.mask = nn.Parameter(init_tensor) 124 | 125 | def forward(self, input): 126 | """ 127 | Args: 128 | input (torch.Tensor): Input tensor of shape NHWC 129 | 130 | Returns: 131 | (torch.Tensor): Output tensor of shape NHWC 132 | """ 133 | mask = torch.sigmoid(self.slope * self.mask).to(input.device).view(1, input.shape[1], input.shape[2]//2, 1) # only half of the kspace 134 | 135 | zero_mask = torch.zeros((1, input.shape[1], input.shape[2], 1)) 136 | zero_mask[:, :, :input.shape[2]//2] = mask 137 | 138 | return zero_mask.to(input.device) 139 | 140 | def _slope_random_uniform(self, shape, eps=1e-2): 141 | """ 142 | uniform random sampling mask with the shape as half of the kspace measurement 143 | """ 144 | temp = torch.zeros([shape[0], shape[1]//2]).uniform_(eps, 1-eps) 145 | 146 | # logit with slope factor 147 | return -torch.log(1./temp-1.) / self.slope 148 | 149 | class LineConstrainedProbMask(nn.Module): 150 | """ 151 | A learnable probablistic mask with the same shape as the kspace measurement. 
152 | The mask is constrinaed to include whole kspace lines in the readout direction 153 | """ 154 | def __init__(self, shape=[32], slope=5, preselect=False, preselect_num=2): 155 | super(LineConstrainedProbMask, self).__init__() 156 | 157 | if preselect: 158 | length = shape[0] - preselect_num 159 | else: 160 | length = shape[0] 161 | 162 | self.preselect_num = preselect_num 163 | self.preselect = preselect 164 | self.slope = slope 165 | init_tensor = self._slope_random_uniform(length) 166 | self.mask = nn.Parameter(init_tensor) 167 | 168 | def forward(self, input, eps=1e-10): 169 | """ 170 | Args: 171 | input (torch.Tensor): Input tensor of shape NHWC 172 | 173 | Returns: 174 | (torch.Tensor): Output tensor of shape NHWC 175 | """ 176 | logits = self.mask 177 | mask = torch.sigmoid(self.slope * logits).view(1, 1, self.mask.shape[0], 1) 178 | 179 | if self.preselect: 180 | if self.preselect_num % 2 ==0: 181 | zeros = torch.zeros(1, 1, self.preselect_num // 2, 1).to(input.device) 182 | mask = torch.cat([zeros, mask, zeros], dim=2) 183 | else: 184 | raise NotImplementedError() 185 | 186 | return mask 187 | 188 | def _slope_random_uniform(self, shape, eps=1e-2): 189 | """ 190 | uniform random sampling mask with the same shape as the kspace measurement 191 | """ 192 | temp = torch.zeros(shape).uniform_(eps, 1-eps) 193 | 194 | # logit with slope factor 195 | return -torch.log(1./temp-1.) 
/ self.slope 196 | 197 | class BiLineConstrainedProbMask(LineConstrainedProbMask): 198 | def __init__(self, shape, slope, preselect, preselect_num): 199 | super().__init__(shape=shape, slope=slope, preselect=preselect, preselect_num=preselect_num) 200 | 201 | def forward(self, input, eps=1e-10): 202 | """ 203 | Args: 204 | input (torch.Tensor): Input tensor of shape NHWC 205 | 206 | Returns: 207 | (torch.Tensor): Output tensor of shape NHWC 208 | """ 209 | logits = self.mask 210 | mask = torch.sigmoid(self.slope * logits).view(1, 1, self.mask.shape[0], 1) 211 | 212 | if self.preselect: 213 | zeros = torch.zeros(1, 1, self.preselect_num, 1).to(input.device) 214 | mask = torch.cat([zeros, mask], dim=2) 215 | return mask 216 | 217 | class HalfLineConstrainedProbMask(nn.Module): 218 | """ 219 | A learnable probablistic mask with the same shape as half of the kspace measurement (to force the model 220 | to not take conjugate symmetry points). 221 | The mask is constrained to include whole kspace lines in the readout direction 222 | """ 223 | def __init__(self, shape=[32], slope=5): 224 | """ 225 | shape ([int. int]): Shape of the reconstructed image 226 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to 227 | deterministic state. 
228 | """ 229 | super(HalfLineConstrainedProbMask, self).__init__() 230 | 231 | self.slope = slope 232 | init_tensor = self._slope_random_uniform(shape[0]) 233 | self.mask = nn.Parameter(init_tensor) 234 | 235 | def forward(self, input): 236 | """ 237 | Args: 238 | input (torch.Tensor): Input tensor of shape NHWC 239 | 240 | Returns: 241 | (torch.Tensor): Output tensor of shape NHWC 242 | """ 243 | mask = torch.sigmoid(self.slope * self.mask).to(input.device).view(1, 1, input.shape[2]//2, 1) # only half of the kspace 244 | 245 | zero_mask = torch.zeros((1, 1, input.shape[2], 1)) 246 | zero_mask[:, :, :input.shape[2]//2] = mask 247 | 248 | return zero_mask.to(input.device) 249 | 250 | def _slope_random_uniform(self, shape, eps=1e-2): 251 | """ 252 | uniform random sampling mask with the shape as half of the kspace measurement 253 | """ 254 | temp = torch.zeros(shape//2).uniform_(eps, 1-eps) 255 | 256 | # logit with slope factor 257 | return -torch.log(1./temp-1.) / self.slope 258 | 259 | 260 | def RescaleProbMap(batch_x, sparsity): 261 | """ 262 | Rescale Probability Map 263 | given a prob map x, rescales it so that it obtains the desired sparsity 264 | 265 | if mean(x) > sparsity, then rescaling is easy: x' = x * sparsity / mean(x) 266 | if mean(x) < sparsity, one can basically do the same thing by rescaling 267 | (1-x) appropriately, then taking 1 minus the result. 268 | """ 269 | batch_size = len(batch_x) 270 | ret = [] 271 | for i in range(batch_size): 272 | x = batch_x[i:i+1] 273 | xbar = torch.mean(x) 274 | r = sparsity / (xbar) 275 | beta = (1-sparsity) / (1-xbar) 276 | 277 | # compute adjucement 278 | le = torch.le(r, 1).float() 279 | ret.append(le * x * r + (1-le) * (1 - (1 - x) * beta)) 280 | 281 | return torch.cat(ret, dim=0) 282 | 283 | 284 | class ThresholdRandomMaskSigmoidV1(Function): 285 | def __init__(self): 286 | """ 287 | Straight through estimator. 288 | The forward step stochastically binarizes the probability mask. 
289 | The backward step estimate the non differentiable > operator using sigmoid with large slope (10). 290 | """ 291 | super(ThresholdRandomMaskSigmoidV1, self).__init__() 292 | 293 | @staticmethod 294 | def forward(ctx, input): 295 | batch_size = len(input) 296 | probs = [] 297 | results = [] 298 | 299 | for i in range(batch_size): 300 | x = input[i:i+1] 301 | 302 | count = 0 303 | while True: 304 | prob = x.new(x.size()).uniform_() 305 | result = (x > prob).float() 306 | 307 | if torch.isclose(torch.mean(result), torch.mean(x), atol=1e-3): 308 | break 309 | 310 | count += 1 311 | 312 | if count > 1000: 313 | print(torch.mean(prob), torch.mean(result), torch.mean(x)) 314 | assert 0 315 | 316 | probs.append(prob) 317 | results.append(result) 318 | 319 | results = torch.cat(results, dim=0) 320 | probs = torch.cat(probs, dim=0) 321 | ctx.save_for_backward(input, probs) 322 | 323 | return results 324 | 325 | @staticmethod 326 | def backward(ctx, grad_output): 327 | slope = 10 328 | input, prob = ctx.saved_tensors 329 | 330 | # derivative of sigmoid function 331 | current_grad = slope * torch.exp(-slope * (input - prob)) / torch.pow((torch.exp(-slope*(input-prob))+1), 2) 332 | 333 | return current_grad * grad_output 334 | 335 | def MaximumBinarize(input): 336 | batch_size = len(input) 337 | results = [] 338 | 339 | for i in range(batch_size): 340 | x = input[i:i+1] 341 | num = torch.sum(x).round().int() 342 | 343 | indices = torch.topk(x.reshape(-1), k=num)[1] 344 | 345 | mask = torch.zeros_like(x).reshape(-1) 346 | 347 | mask[indices] = 1 348 | 349 | mask = mask.reshape(*x.shape) 350 | 351 | results.append(mask) 352 | 353 | results = torch.cat(results, dim=0) 354 | 355 | return results 356 | -------------------------------------------------------------------------------- /activemri/baselines/loupe_codes/reconstructors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Portion of this code is from 
fastmri(https://github.com/facebookresearch/fastMRI) 3 | 4 | Copyright (c) Facebook, Inc. and its affiliates. 5 | 6 | Licensed under the MIT License. 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from activemri.baselines.loupe_codes.transforms import * 13 | from activemri.baselines.loupe_codes.layers import * 14 | 15 | class LOUPEUNet(nn.Module): 16 | """ 17 | PyTorch implementation of a U-Net model. 18 | This is based on: 19 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks 20 | for biomedical image segmentation. In International Conference on Medical image 21 | computing and computer-assisted intervention, pages 234–241. Springer, 2015. 22 | 23 | The model takes a real or complex value image and use a UNet to denoise the image. 24 | A residual connection is applied to stablize training. 25 | """ 26 | def __init__(self, 27 | in_chans, 28 | out_chans, 29 | chans, 30 | num_pool_layers, 31 | drop_prob, 32 | # mask_length, 33 | bi_dir=False, 34 | old_recon=False, 35 | with_uncertainty=False): 36 | """ 37 | Args: 38 | in_chans (int): Number of channels in the input to the U-Net model. 39 | out_chans (int): Number of channels in the output to the U-Net model. 40 | chans (int): Number of output channels of the first convolution layer. 41 | num_pool_layers (int): Number of down-sampling and up-sampling layers. 42 | drop_prob (float): Dropout probability. 
43 | """ 44 | super().__init__() 45 | self.old_recon = old_recon 46 | if old_recon: 47 | assert 0 48 | in_chans = in_chans+1 # add mask dim and old reconstruction dim 49 | 50 | self.with_uncertainty = with_uncertainty 51 | 52 | if with_uncertainty: 53 | out_chans = out_chans+1 54 | 55 | self.in_chans = in_chans 56 | self.out_chans = out_chans 57 | self.chans = chans 58 | self.num_pool_layers = num_pool_layers 59 | self.drop_prob = drop_prob 60 | 61 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) 62 | ch = chans 63 | for i in range(num_pool_layers - 1): 64 | self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)] 65 | ch *= 2 66 | self.conv = ConvBlock(ch, ch, drop_prob) 67 | 68 | self.up_sample_layers = nn.ModuleList() 69 | for i in range(num_pool_layers - 1): 70 | self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, drop_prob)] 71 | ch //= 2 72 | self.up_sample_layers += [ConvBlock(ch * 2, ch, drop_prob)] 73 | self.conv2 = nn.Sequential( 74 | nn.Conv2d(ch, ch // 2, kernel_size=1), 75 | nn.Conv2d(ch // 2, out_chans, kernel_size=1), 76 | nn.Conv2d(out_chans, out_chans, kernel_size=1), 77 | ) 78 | 79 | nn.init.normal_(self.conv2[-1].weight, mean=0, std=0.001) 80 | self.conv2[-1].bias.data.fill_(0) 81 | 82 | def forward(self, input, old_recon=None, eps=1e-8): 83 | # input: NCHW 84 | # output: NCHW 85 | 86 | if self.old_recon: 87 | assert 0 88 | output = torch.cat([input, old_recon], dim=1) 89 | else: 90 | output = input 91 | 92 | stack = [] 93 | # print(input.shape, mask.shape) 94 | 95 | # Apply down-sampling layers 96 | for layer in self.down_sample_layers: 97 | output = layer(output) 98 | stack.append(output) 99 | output = F.max_pool2d(output, kernel_size=2) 100 | 101 | output = self.conv(output) 102 | 103 | # Apply up-sampling layers 104 | for layer in self.up_sample_layers: 105 | downsample_layer = stack.pop() 106 | layer_size = (downsample_layer.shape[-2], downsample_layer.shape[-1]) 107 | output = F.interpolate(output, 
size=layer_size, mode='bilinear', align_corners=False) 108 | output = torch.cat([output, downsample_layer], dim=1) 109 | output = layer(output) 110 | 111 | out_conv2 = self.conv2(output) 112 | 113 | img_residual = out_conv2[:, :1] 114 | 115 | if self.with_uncertainty: 116 | map = out_conv2[:, 1:] 117 | else: 118 | map = torch.zeros_like(out_conv2) 119 | 120 | if self.old_recon: 121 | return img_residual + old_recon 122 | else: 123 | return img_residual + torch.norm(input, dim=1, keepdim=True) -------------------------------------------------------------------------------- /activemri/baselines/loupe_codes/samplers.py: -------------------------------------------------------------------------------- 1 | from enum import auto 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from activemri.baselines.loupe_codes.layers import * 6 | 7 | from activemri.baselines.loupe_codes import transforms 8 | 9 | from torch.autograd import Function 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import sigpy 13 | import sigpy.mri 14 | 15 | class LOUPESampler(nn.Module): 16 | """ 17 | LOUPE Sampler 18 | """ 19 | def __init__(self, shape=[320, 320], slope=5, sparsity=0.25, line_constrained=False, 20 | conjugate_mask=False, preselect=False, preselect_num=2, random=False, poisson=False, 21 | spectrum=False, equispaced=False): 22 | """ 23 | shape ([int. int]): Shape of the reconstructed image 24 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to 25 | deterministic state. 26 | sparsity (float): Predefined sparsity of the learned probability mask. 1 / acceleration_ratio 27 | line_constrained (bool): Sample kspace measurements column by column 28 | conjugate_mask (bool): For real image, the corresponding kspace measurements have conjugate symmetry property 29 | (point reflection). Therefore, the information in the left half of the kspace image is the same as the 30 | other half. 
To take advantage of this, we can force the model to only sample right half of the kspace 31 | (when conjugate_mask is set to True) 32 | preselect: preselect center regions 33 | """ 34 | super().__init__() 35 | 36 | assert conjugate_mask is False 37 | 38 | # probability mask 39 | if line_constrained: 40 | self.gen_mask = LineConstrainedProbMask(shape, slope, preselect=preselect, preselect_num=preselect_num) 41 | else: 42 | self.gen_mask = ProbMask(shape, slope, preselect_num=preselect_num) 43 | 44 | self.rescale = RescaleProbMap 45 | self.binarize = ThresholdRandomMaskSigmoidV1.apply # FIXME 46 | 47 | self.preselect =preselect 48 | self.preselect_num_one_side = preselect_num // 2 49 | self.shape = shape 50 | self.line_constrained = line_constrained 51 | self.random_baseline = random 52 | self.poisson_baseline = poisson 53 | self.spectrum_baseline = spectrum 54 | self.equispaced_baseline = equispaced 55 | 56 | if self.poisson_baseline: 57 | self.acc = 1 / (sparsity + (self.preselect_num_one_side*2)**2 / (128*128)) 58 | print("generate variable density mask with acceleration {}".format(self.acc)) 59 | 60 | if self.spectrum_baseline: 61 | acc = 1 / (sparsity + (self.preselect_num_one_side*2)**2 / (128*128)) 62 | print("generate spectrum mask with acceleration {}".format(acc)) 63 | mask = torch.load('resources/spectrum_{}x_128.pt'.format(int(acc))) 64 | mask = mask.reshape(1, 128, 128, 1).float() 65 | self.spectrum_mask = nn.Parameter(mask, requires_grad=False) 66 | 67 | if self.equispaced_baseline: 68 | acc = 1 / (sparsity + (self.preselect_num_one_side)*2 / (128)) 69 | 70 | assert self.line_constrained 71 | assert acc == 4 72 | mask = torch.load('resources/equispaced_4x_128.pt').reshape(1, 128, 128, 1).float() 73 | self.equispaced_mask = nn.Parameter(mask, requires_grad=False) 74 | 75 | def _gen_poisson_mask(self): 76 | mask = sigpy.mri.poisson((128, 128), self.acc, dtype='int32', crop_corner=False) 77 | mask = torch.tensor(mask).reshape(1, 128, 128, 1).float() 78 | 
return mask 79 | 80 | def _mask_neg_entropy(self, mask, eps=1e-10): 81 | # negative of pixel wise entropy 82 | entropy = mask * torch.log(mask+eps) + (1-mask) * torch.log(1-mask+eps) 83 | return entropy 84 | 85 | def forward(self, kspace, sparsity): 86 | # kspace: NHWC 87 | # sparsity (float) 88 | prob_mask = self.gen_mask(kspace) 89 | 90 | if self.random_baseline: 91 | prob_mask = torch.ones_like(prob_mask) / 4 92 | 93 | if not self.line_constrained: 94 | prob_mask[:, :self.preselect_num_one_side, :self.preselect_num_one_side] = 0 95 | prob_mask[:, :self.preselect_num_one_side, -self.preselect_num_one_side:] = 0 96 | prob_mask[:, -self.preselect_num_one_side:, :self.preselect_num_one_side] = 0 97 | prob_mask[:, -self.preselect_num_one_side:, -self.preselect_num_one_side:] = 0 98 | else: 99 | prob_mask[..., :self.preselect_num_one_side, :] = 0 100 | prob_mask[..., -self.preselect_num_one_side:, :] = 0 101 | 102 | 103 | if not self.preselect: 104 | rescaled_mask = self.rescale(prob_mask, sparsity) 105 | binarized_mask = self.binarize(rescaled_mask) 106 | else: 107 | rescaled_mask = self.rescale(prob_mask, sparsity) 108 | if self.training: 109 | binarized_mask = self.binarize(rescaled_mask) 110 | else: 111 | binarized_mask = self.binarize(rescaled_mask) 112 | 113 | if not self.line_constrained: 114 | binarized_mask[:, :self.preselect_num_one_side, :self.preselect_num_one_side] = 1 115 | binarized_mask[:, :self.preselect_num_one_side, -self.preselect_num_one_side:] = 1 116 | binarized_mask[:, -self.preselect_num_one_side:, :self.preselect_num_one_side] = 1 117 | binarized_mask[:, -self.preselect_num_one_side:, -self.preselect_num_one_side:] = 1 118 | else: 119 | binarized_mask[..., :self.preselect_num_one_side, :] = 1 120 | binarized_mask[..., -self.preselect_num_one_side:, :] = 1 121 | 122 | neg_entropy = self._mask_neg_entropy(rescaled_mask) 123 | 124 | if self.poisson_baseline: 125 | assert not self.line_constrained 126 | binarized_mask = 
transforms.fftshift(self._gen_poisson_mask(), dim=(1,2)).to(kspace.device) 127 | 128 | if self.spectrum_baseline: 129 | assert not self.line_constrained 130 | binarized_mask = self.spectrum_mask # DC are in the corners 131 | 132 | if self.equispaced_baseline: 133 | binarized_mask = transforms.fftshift(self.equispaced_mask, dim=(1, 2)) 134 | 135 | masked_kspace = binarized_mask * kspace 136 | 137 | data_to_vis_sampler = {'prob_mask': transforms.fftshift(prob_mask[0,:,:,0],dim=(0,1)).cpu().detach().numpy(), 138 | 'rescaled_mask': transforms.fftshift(rescaled_mask[0,:,:,0],dim=(0,1)).cpu().detach().numpy(), 139 | 'binarized_mask': transforms.fftshift(binarized_mask[0,:,:,0],dim=(0,1)).cpu().detach().numpy()} 140 | 141 | return masked_kspace, binarized_mask, neg_entropy, data_to_vis_sampler 142 | 143 | class BiLOUPESampler(nn.Module): 144 | """ 145 | LOUPE Sampler 146 | """ 147 | def __init__(self, shape=[320, 320], slope=5, sparsity=0.25, line_constrained=False, 148 | conjugate_mask=False, preselect=False, preselect_num=2): 149 | """ 150 | shape ([int. int]): Shape of the reconstructed image 151 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to 152 | deterministic state. 153 | sparsity (float): Predefined sparsity of the learned probability mask. 1 / acceleration_ratio 154 | line_constrained (bool): Sample kspace measurements column by column 155 | conjugate_mask (bool): For real image, the corresponding kspace measurements have conjugate symmetry property 156 | (point reflection). Therefore, the information in the left half of the kspace image is the same as the 157 | other half. 
To take advantage of this, we can force the model to only sample right half of the kspace 158 | (when conjugate_mask is set to True) 159 | preselect: preselect center regions 160 | """ 161 | super().__init__() 162 | 163 | assert conjugate_mask is False 164 | assert line_constrained 165 | 166 | # probability mask 167 | if line_constrained: 168 | self.gen_mask = BiLineConstrainedProbMask([shape[0]*2], slope, preselect=preselect, preselect_num=preselect_num) 169 | else: 170 | assert 0 171 | self.gen_mask = ProbMask(shape, slope) 172 | 173 | self.rescale = RescaleProbMap 174 | self.binarize = ThresholdRandomMaskSigmoidV1.apply # FIXME 175 | 176 | self.preselect =preselect 177 | self.preselect_num = preselect_num 178 | self.shape = shape 179 | 180 | def _mask_neg_entropy(self, mask, eps=1e-10): 181 | # negative of pixel wise entropy 182 | entropy = mask * torch.log(mask+eps) + (1-mask) * torch.log(1-mask+eps) 183 | return entropy 184 | 185 | def forward(self, kspace, sparsity): 186 | # kspace: NHWC 187 | # sparsity (float) 188 | prob_mask = self.gen_mask(kspace) 189 | 190 | batch_size = kspace.shape[0] 191 | 192 | if not self.preselect: 193 | assert 0 194 | else: 195 | rescaled_mask = self.rescale(prob_mask, sparsity/2) 196 | if self.training: 197 | binarized_mask = self.binarize(rescaled_mask) 198 | else: 199 | binarized_mask = self.binarize(rescaled_mask) 200 | 201 | # always preselect vertical lines 202 | binarized_vertical_mask, binarized_horizontal_mask = torch.chunk(binarized_mask, dim=2, chunks=2) 203 | 204 | binarized_horizontal_mask = binarized_horizontal_mask.transpose(1, 2) 205 | 206 | binarized_mask = torch.clamp(binarized_vertical_mask + binarized_horizontal_mask, max=1, min=0) 207 | 208 | binarized_mask[..., :self.preselect_num, :] = 1 209 | 210 | masked_kspace = binarized_mask * kspace 211 | neg_entropy = self._mask_neg_entropy(rescaled_mask) 212 | 213 | # for visualization purpose 214 | vertical_mask, horizontal_mask = torch.chunk(prob_mask.reshape(1, 
-1), dim=-1, chunks=2) 215 | prob_mask =vertical_mask.reshape(1, 1, 1, -1)+horizontal_mask.reshape(1, 1, -1, 1) 216 | 217 | rescaled_vertical_mask, rescaled_horizontal_mask = torch.chunk(rescaled_mask.reshape(1, -1), dim=-1, chunks=2) 218 | rescaled_mask = rescaled_vertical_mask.reshape(1, 1, 1, -1)+rescaled_horizontal_mask.reshape(1, 1, -1, 1) 219 | 220 | 221 | data_to_vis_sampler = {'prob_mask': transforms.fftshift(prob_mask[0, 0],dim=(0,1)).cpu().detach().numpy(), 222 | 'rescaled_mask': transforms.fftshift(rescaled_mask[0, 0],dim=(0,1)).cpu().detach().numpy(), 223 | 'binarized_mask': transforms.fftshift(binarized_mask[0,:,:,0],dim=(0,1)).cpu().detach().numpy()} 224 | 225 | return masked_kspace, binarized_mask, neg_entropy, data_to_vis_sampler 226 | -------------------------------------------------------------------------------- /activemri/baselines/loupe_codes/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from torch._C import FUSE_ADD_RELU 11 | 12 | 13 | def to_tensor(data): 14 | """ 15 | Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts 16 | are stacked along the last dimension. 17 | 18 | Args: 19 | data (np.array): Input numpy array 20 | 21 | Returns: 22 | torch.Tensor: PyTorch version of data 23 | """ 24 | if np.iscomplexobj(data): 25 | data = np.stack((data.real, data.imag), axis=-1) 26 | return torch.from_numpy(data) 27 | 28 | 29 | def apply_mask(data, mask_func, seed=None): 30 | """ 31 | Subsample given k-space by multiplying with a mask. 32 | 33 | Args: 34 | data (torch.Tensor): The input k-space data. 
This should have at least 3 dimensions, where 35 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size 36 | 2 (for complex values). 37 | mask_func (callable): A function that takes a shape (tuple of ints) and a random 38 | number seed and returns a mask. 39 | seed (int or 1-d array_like, optional): Seed for the random number generator. 40 | 41 | Returns: 42 | (tuple): tuple containing: 43 | masked data (torch.Tensor): Subsampled k-space data 44 | mask (torch.Tensor): The generated mask 45 | """ 46 | shape = np.array(data.shape) 47 | shape[:-3] = 1 48 | mask = mask_func(shape, seed).to(data.device) 49 | return torch.where(mask == 0, torch.Tensor([0]).to(data.device), data), mask 50 | 51 | 52 | def fft2(data): 53 | """ 54 | Apply centered 2 dimensional Fast Fourier Transform. 55 | 56 | Args: 57 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 58 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 59 | assumed to be batch dimensions. 60 | 61 | Returns: 62 | torch.Tensor: The FFT of the input. 63 | """ 64 | assert data.size(-1) == 2 65 | data = data.fft(2, normalized=False) 66 | return data 67 | 68 | 69 | def ifft2(data): 70 | """ 71 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 72 | 73 | Args: 74 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 75 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 76 | assumed to be batch dimensions. 77 | 78 | Returns: 79 | torch.Tensor: The IFFT of the input. 80 | """ 81 | assert data.size(-1) == 2 82 | data = torch.ifft(data, 2, normalized=False) 83 | return data 84 | 85 | 86 | def complex_abs(data): 87 | """ 88 | Compute the absolute value of a complex valued input tensor. 89 | 90 | Args: 91 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension 92 | should be 2. 
93 | 94 | Returns: 95 | torch.Tensor: Absolute value of data 96 | """ 97 | assert data.size(-1) == 2 98 | return (data ** 2).sum(dim=-1).sqrt() 99 | 100 | 101 | def root_sum_of_squares(data, dim=0): 102 | """ 103 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor. 104 | 105 | Args: 106 | data (torch.Tensor): The input tensor 107 | dim (int): The dimensions along which to apply the RSS transform 108 | 109 | Returns: 110 | torch.Tensor: The RSS value 111 | """ 112 | return torch.sqrt((data ** 2).sum(dim)) 113 | 114 | 115 | def center_crop(data, shape): 116 | """ 117 | Apply a center crop to the input real image or batch of real images. 118 | 119 | Args: 120 | data (torch.Tensor): The input tensor to be center cropped. It should have at 121 | least 2 dimensions and the cropping is applied along the last two dimensions. 122 | shape (int, int): The output shape. The shape should be smaller than the 123 | corresponding dimensions of data. 124 | 125 | Returns: 126 | torch.Tensor: The center cropped image 127 | """ 128 | assert 0 < shape[0] <= data.shape[-2] 129 | assert 0 < shape[1] <= data.shape[-1] 130 | w_from = (data.shape[-2] - shape[0]) // 2 131 | h_from = (data.shape[-1] - shape[1]) // 2 132 | w_to = w_from + shape[0] 133 | h_to = h_from + shape[1] 134 | return data[..., w_from:w_to, h_from:h_to] 135 | 136 | 137 | def complex_center_crop(data, shape): 138 | """ 139 | Apply a center crop to the input image or batch of complex images. 140 | 141 | Args: 142 | data (torch.Tensor): The complex input tensor to be center cropped. It should 143 | have at least 3 dimensions and the cropping is applied along dimensions 144 | -3 and -2 and the last dimensions should have a size of 2. 145 | shape (int, int): The output shape. The shape should be smaller than the 146 | corresponding dimensions of data. 
147 | 148 | Returns: 149 | torch.Tensor: The center cropped image 150 | """ 151 | assert 0 < shape[0] <= data.shape[-3] 152 | assert 0 < shape[1] <= data.shape[-2] 153 | w_from = (data.shape[-3] - shape[0]) // 2 154 | h_from = (data.shape[-2] - shape[1]) // 2 155 | w_to = w_from + shape[0] 156 | h_to = h_from + shape[1] 157 | return data[..., w_from:w_to, h_from:h_to, :] 158 | 159 | 160 | def normalize(data, mean, stddev, eps=0.): 161 | """ 162 | Normalize the given tensor using: 163 | (data - mean) / (stddev + eps) 164 | 165 | Args: 166 | data (torch.Tensor): Input data to be normalized 167 | mean (float): Mean value 168 | stddev (float): Standard deviation 169 | eps (float): Added to stddev to prevent dividing by zero 170 | 171 | Returns: 172 | torch.Tensor: Normalized tensor 173 | """ 174 | return (data - mean) / (stddev + eps) 175 | 176 | 177 | def normalize_instance(data, eps=0.): 178 | """ 179 | Normalize the given tensor using: 180 | (data - mean) / (stddev + eps) 181 | where mean and stddev are computed from the data itself. 
182 | 183 | Args: 184 | data (torch.Tensor): Input data to be normalized 185 | eps (float): Added to stddev to prevent dividing by zero 186 | 187 | Returns: 188 | torch.Tensor: Normalized tensor 189 | """ 190 | mean = data.mean() 191 | std = data.std() 192 | return normalize(data, mean, std, eps), mean, std 193 | 194 | 195 | # Helper functions 196 | 197 | def roll(x, shift, dim): 198 | """ 199 | Similar to np.roll but applies to PyTorch Tensors 200 | """ 201 | if isinstance(shift, (tuple, list)): 202 | assert len(shift) == len(dim) 203 | for s, d in zip(shift, dim): 204 | x = roll(x, s, d) 205 | return x 206 | shift = shift % x.size(dim) 207 | if shift == 0: 208 | return x 209 | left = x.narrow(dim, 0, x.size(dim) - shift) 210 | right = x.narrow(dim, x.size(dim) - shift, shift) 211 | return torch.cat((right, left), dim=dim) 212 | 213 | 214 | def fftshift(x, dim=None): 215 | """ 216 | Similar to np.fft.fftshift but applies to PyTorch Tensors 217 | """ 218 | if dim is None: 219 | dim = tuple(range(x.dim())) 220 | shift = [dim // 2 for dim in x.shape] 221 | elif isinstance(dim, int): 222 | shift = x.shape[dim] // 2 223 | else: 224 | shift = [x.shape[i] // 2 for i in dim] 225 | return roll(x, shift, dim) 226 | 227 | 228 | def ifftshift(x, dim=None): 229 | """ 230 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 231 | """ 232 | if dim is None: 233 | dim = tuple(range(x.dim())) 234 | shift = [(dim + 1) // 2 for dim in x.shape] 235 | elif isinstance(dim, int): 236 | shift = (x.shape[dim] + 1) // 2 237 | else: 238 | shift = [(x.shape[i] + 1) // 2 for i in dim] 239 | return roll(x, shift, dim) 240 | 241 | 242 | def to_kspace(x): 243 | image = torch.cat([x, torch.zeros_like(x)], dim=-1) 244 | kspace = fft2(image) 245 | 246 | return kspace 247 | 248 | def rl_input_to_loupe_input(x, kspace): 249 | assert x.size(-1) == 2 250 | assert kspace.size(-1) == 2 251 | 252 | x = x[...,0:1].permute((0,3,1,2)) 253 | kspace = fft2(fftshift(ifft2(kspace), dim=(1, 2))) 254 | 255 
| return x, kspace 256 | 257 | def add_gaussian_noise(args, kspace, mean=0., std=1.): 258 | # kspace: 32*32*2 259 | noise = mean + torch.randn(kspace.size()) * std 260 | kspace = kspace + noise 261 | return kspace 262 | 263 | def data_consistency(recon, zero_filled_recon, mask): 264 | mask = mask.permute(0, 2, 3, 1) 265 | pred_kspace = ifftshift(recon.permute(0, 2, 3, 1), dim=(1, 2)).fft(2, normalized=False) 266 | fuse_kspace = torch.where((1-mask).byte(), pred_kspace, torch.tensor(0.0).to(mask.device)) 267 | new_img = fftshift(fuse_kspace.ifft(2, normalized=False), dim=(1, 2)).permute(0, -1, 1, 2) 268 | 269 | return new_img + zero_filled_recon 270 | -------------------------------------------------------------------------------- /activemri/baselines/non_rl.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch, argparse 3 | from typing import List 4 | import activemri.envs.envs as mri_envs 5 | import numpy as np 6 | import logging 7 | import os 8 | from torch.nn import functional as F 9 | from tensorboardX import SummaryWriter 10 | import time 11 | from .loupe_codes.evaluate import Metrics, METRIC_FUNCS 12 | from .loupe_codes import transforms 13 | import activemri.envs.loupe_envs as loupe_envs 14 | import pickle 15 | from .differential_sampler import SequentialUnet 16 | from .loupe import LOUPE 17 | import random 18 | import shutil 19 | import matplotlib.pyplot as plt 20 | from ..envs.util import compute_gaussian_nll_loss, compute_ssim_torch, compute_psnr_torch 21 | 22 | def build_optimizer(lr, weight_decay, model_parameters, type='adam'): 23 | if type == 'adam': 24 | optimizer = torch.optim.Adam(model_parameters, lr, weight_decay=weight_decay) 25 | elif type == 'sgd': 26 | optimizer = torch.optim.SGD(model_parameters, lr, weight_decay=weight_decay) 27 | else: 28 | raise NotImplementedError() 29 | 30 | return optimizer 31 | 32 | def build_lr_scheduler(options, optimizer): 33 | 
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, options.lr_step_size, options.lr_gamma) 34 | return scheduler 35 | 36 | def build_model(options): 37 | if options.model == 'LOUPE': 38 | model = LOUPE( 39 | in_chans=options.input_chans, 40 | out_chans=options.output_chans, 41 | chans=options.num_chans, 42 | num_pool_layers=options.num_pools, 43 | drop_prob=options.drop_prob, 44 | sparsity=1.0/options.accelerations[0], 45 | shape=options.resolution, 46 | conjugate_mask=options.conjugate_mask, 47 | line_constrained=options.line_constrained, 48 | bi_dir=options.bi_dir, 49 | preselect_num=options.preselect_num, 50 | preselect=options.preselect, 51 | random=options.random if 'random' in options.__dict__ else False, 52 | poisson=options.poisson if 'poisson' in options.__dict__ else False, 53 | spectrum=options.spectrum if 'spectrum' in options.__dict__ else False, 54 | equispaced=options.equispaced if 'equispaced' in options.__dict__ else False,).to(options.device) 55 | elif options.model == 'SequentialSampling': 56 | model = SequentialUnet( 57 | in_chans=options.input_chans, 58 | out_chans=options.output_chans, 59 | chans=options.num_chans, 60 | num_pool_layers=options.num_pools, 61 | drop_prob=options.drop_prob, 62 | sparsity=1.0/options.accelerations[0], 63 | shape=options.resolution, 64 | conjugate_mask=options.conjugate_mask, 65 | num_step=options.num_step, 66 | preselect=options.preselect, 67 | bi_direction=options.bi_dir, 68 | preselect_num=options.preselect_num, 69 | binary_sampler=options.binary_sampler, 70 | clamp=options.clamp, 71 | line_constrained=options.line_constrained, 72 | old_recon=options.old_recon, 73 | with_uncertainty=options.uncertainty_loss, 74 | detach_kspace=options.detach_kspace, 75 | fixed_input=options.fixed_input, 76 | pretrained_recon=options.pretrained_recon if 'pretrained_recon' in options.__dict__ else False).to(options.device) 77 | else: 78 | raise NotImplementedError() 79 | 80 | return model 81 | 82 | def get_mask_stats(masks): 
83 | masks = np.array(masks) 84 | 85 | return np.mean(masks, axis=0), np.std(masks, axis=0) 86 | 87 | 88 | class NonRLTrainer: 89 | """Differentiabl Sampler Trainer for active MRI acquisition. 90 | 91 | Configuration for the trainer is provided by argument ``options``. Must contain the 92 | following fields: 93 | 94 | Args: 95 | options(``argparse.Namespace``): Options for the trainer. 96 | env(``activemri.envs.ActiveMRIEnv``): Env for which the policy is trained. 97 | device(``torch.device``): Device to use. 98 | """ 99 | def __init__( 100 | self, 101 | options: argparse.Namespace, 102 | env: mri_envs.ActiveMRIEnv, 103 | device: torch.device 104 | ): 105 | self.options = options 106 | self.env = env 107 | self.device = device 108 | self.model = build_model(self.options) 109 | if options.data_parallel: 110 | self.model = torch.nn.DataParallel(self.model) 111 | 112 | self.optimizer = build_optimizer(self.options.lr, self.options.weight_decay, self.model.parameters()) 113 | 114 | self.train_loader, self.dev_loader, self.display_loader, _ = self.env._setup_data_handlers() 115 | 116 | self.scheduler = build_lr_scheduler(self.options, self.optimizer) 117 | self.best_dev_loss = np.inf 118 | self.epoch = 0 119 | self.start_epoch = 0 120 | self.end_epoch = options.num_epochs 121 | 122 | # setup saving, writer, and logging 123 | options.exp_dir.mkdir(parents=True, exist_ok=True) 124 | 125 | with open(os.path.join(str(options.exp_dir), 'args.pkl'), "wb") as f: 126 | pickle.dump(options.__dict__, f) 127 | 128 | options.visualization_dir.mkdir(parents=True, exist_ok=True) 129 | self.writer = SummaryWriter(log_dir=options.exp_dir / 'summary') 130 | 131 | logging.basicConfig(level=logging.INFO) 132 | self.logger = logging.getLogger(__name__) 133 | self.logger.info(self.options) 134 | self.logger.info(self.model) 135 | 136 | def load_checkpoint_if_needed(self): 137 | if self.options.resume: 138 | self.load() 139 | 140 | def evaluate(self): 141 | self.model.eval() 142 | losses = 
[] 143 | sparsity = [] 144 | targets, preds = [], [] 145 | metrics = Metrics(METRIC_FUNCS) 146 | start = time.perf_counter() 147 | 148 | with torch.no_grad(): 149 | for iter, data in enumerate(self.dev_loader): 150 | # input: [batch_size, num_channels, height, width] denoted as NCHW in other places 151 | # label: label of the current image (0~9 for mnist/fashion-mnist) default: -1 152 | # target: a copy of the input image for computing reconstruction loss in [NCHW] 153 | kspace, _, input, label, *ignored = data 154 | 155 | # adapt data to loupe 156 | target = input.clone().detach() 157 | target = transforms.complex_abs(target).unsqueeze(1) 158 | 159 | input = input.to(self.options.device) 160 | target = target.to(self.options.device) 161 | kspace = kspace.to(self.options.device) 162 | # label = label.to(self.options.device) 163 | 164 | pred_dict = self.model(target, kspace) 165 | 166 | if (self.epoch == 0 or (self.epoch+1) % 1 == 0) and iter == 0: 167 | data_for_vis_name = 'eval_epoch=' + str(self.epoch+1) 168 | self.model.visualize_and_save(self.options, self.epoch, data_for_vis_name) 169 | 170 | output = pred_dict['output'] 171 | # only use the last reconstructed image to compute loss 172 | if isinstance(output, list): 173 | output = output[-1] 174 | 175 | target_dict = {'target': target, 'label': label, 'kspace':kspace} 176 | meta = {'entropy_weight': self.options.entropy_weight, 'recon_weight': self.options.recon_weight, 177 | 'uncertainty_weight': 0, 'kspace_weight': self.options.kspace_weight} 178 | 179 | loss, log_dict = self.model.loss(pred_dict, target_dict, meta, self.options.loss_type) 180 | 181 | mask = pred_dict['mask'] 182 | sparsity.append(torch.mean(mask).item()) 183 | losses.append(loss.item()) 184 | 185 | # target: 16*1*32*32 186 | # output: 16*1*32*32 187 | target = target.cpu().numpy() 188 | pred = output.cpu().numpy() 189 | 190 | for t, p in zip(target, pred): 191 | metrics.push(t, p) 192 | 193 | print(metrics) 194 | 
self.writer.add_scalar('Dev_MSE', metrics.means()['MSE'], self.epoch) 195 | self.writer.add_scalar('Dev_NMSE', metrics.means()['NMSE'], self.epoch) 196 | self.writer.add_scalar('Dev_PSNR', metrics.means()['PSNR'], self.epoch) 197 | self.writer.add_scalar('Dev_SSIM', metrics.means()['SSIM'], self.epoch) 198 | 199 | self.writer.add_scalar('Dev_Loss', np.mean(losses), self.epoch) 200 | 201 | return np.mean(losses), np.mean(sparsity), time.perf_counter() - start 202 | 203 | def train_epoch(self): 204 | self.model.train() 205 | losses = [] 206 | targets, preds = [], [] 207 | metrics = Metrics(METRIC_FUNCS) 208 | avg_loss = 0. 209 | start_epoch = start_iter = time.perf_counter() 210 | global_step = self.epoch * len(self.train_loader) 211 | 212 | for iter, data in enumerate(self.train_loader): 213 | # self.scheduler.step() 214 | # input: [batch_size, num_channels, height, width] denoted as NCHW in other places 215 | # label: label of the current image (0~9 for mnist/fashion-mnist) default: -1 216 | # target: a copy of the input image for computing reconstruction loss in [NCHW] 217 | kspace, _, input, label, *ignored= data 218 | 219 | # adapt data to loupe 220 | target = input.clone().detach() 221 | target = transforms.complex_abs(target).unsqueeze(1) 222 | 223 | input = input.to(self.options.device) 224 | target = target.to(self.options.device) 225 | kspace = kspace.to(self.options.device) 226 | # label = label.to(self.options.device) 227 | 228 | """if self.options.noise_type == 'gaussian': 229 | kspace = transforms.add_gaussian_noise(self.options, kspace, mean=0., std=self.options.noise_level) 230 | """ 231 | 232 | pred_dict = self.model(target, kspace) 233 | 234 | if (self.epoch == 0 or (self.epoch+1) % 1 == 0) and iter == 0: 235 | data_for_vis_name = 'train_epoch={}_iter={}'.format(str(self.epoch+1), str(iter+1)) 236 | self.model.visualize_and_save(self.options, self.epoch, data_for_vis_name) 237 | 238 | output = pred_dict['output'] 239 | target_dict = {'target': 
target, 'label': label, 'kspace': kspace} 240 | meta = {'entropy_weight': self.options.entropy_weight, 'recon_weight': self.options.recon_weight, 241 | 'kspace_weight': self.options.kspace_weight, 242 | 'uncertainty_weight': self.options.uncertainty_weight if 'uncertainty_weight' in self.options.__dict__ else 0} 243 | 244 | loss, log_dict = self.model.loss(pred_dict, target_dict, meta, self.options.loss_type) 245 | 246 | self.optimizer.zero_grad() 247 | loss.backward() 248 | # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10) 249 | 250 | self.optimizer.step() 251 | 252 | self.writer.add_scalar('Train_Loss', loss.item(), global_step + iter) 253 | 254 | losses.append(loss.item()) 255 | 256 | # target: 16*1*32*32 257 | # output: 16*1*32*32 258 | 259 | if isinstance(output, list): 260 | output = output[-1] 261 | 262 | target = target.cpu().detach().numpy() 263 | pred = output.cpu().detach().numpy() 264 | 265 | if iter % self.options.report_interval == 0: 266 | self.logger.info( 267 | f'Epoch = [{1 + self.epoch:3d}/{self.options.num_epochs:3d}] ' 268 | f'Iter = [{iter:4d}/{len(self.train_loader):4d}] ' 269 | f'Time = {time.perf_counter() - start_iter:.4f}s', 270 | ) 271 | for key, val in log_dict.items(): 272 | print('{} = {}'.format(key, val)) 273 | 274 | start_iter = time.perf_counter() 275 | 276 | for t, p in zip(target, pred): 277 | metrics.push(t, p) 278 | 279 | print(metrics) 280 | self.writer.add_scalar('Train_MSE', metrics.means()['MSE'], self.epoch) 281 | self.writer.add_scalar('Train_NMSE', metrics.means()['NMSE'], self.epoch) 282 | self.writer.add_scalar('Train_PSNR', metrics.means()['PSNR'], self.epoch) 283 | self.writer.add_scalar('Train_SSIM', metrics.means()['SSIM'], self.epoch) 284 | 285 | return np.mean(np.array(losses)), time.perf_counter() - start_epoch 286 | 287 | def _train_loupe(self): 288 | for epoch in range(self.start_epoch, self.end_epoch): 289 | self.epoch = epoch 290 | train_loss, train_time = self.train_epoch() 291 | 
self.scheduler.step(epoch) 292 | dev_loss, mean_sparsity, dev_time = self.evaluate() 293 | 294 | is_new_best = dev_loss < self.best_dev_loss 295 | self.best_dev_loss = min(self.best_dev_loss, dev_loss) 296 | if self.options.save_model: 297 | self.save_model(is_new_best) 298 | self.logger.info( 299 | f'Epoch = [{1 + self.epoch:4d}/{self.options.num_epochs:4d}] TrainLoss = {train_loss:.4g} ' 300 | f'DevLoss = {dev_loss:.4g} MeanSparsity = {mean_sparsity:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s', 301 | ) 302 | self.writer.close() 303 | 304 | def __call__(self): 305 | self.load_checkpoint_if_needed() 306 | return self._train_loupe() 307 | 308 | def save_model(self,is_new_best): 309 | exp_dir = self.options.exp_dir 310 | torch.save( 311 | { 312 | 'epoch': self.epoch, 313 | 'options': self.options, 314 | 'model': self.model.state_dict(), 315 | 'optimizer': self.optimizer.state_dict(), 316 | 'best_dev_loss': self.best_dev_loss, 317 | 'exp_dir': exp_dir 318 | }, 319 | f = exp_dir / 'model.pt' 320 | ) 321 | if is_new_best: 322 | shutil.copyfile(exp_dir / 'model.pt', exp_dir / 'best_model.pt') 323 | 324 | 325 | def load(self): 326 | self.model = build_model(self.options) 327 | if self.options.data_parallel: 328 | self.model = torch.nn.DataParallel(self.model) 329 | 330 | checkpoint1 = torch.load(self.options.checkpoint1) 331 | print("Load checkpoint {} with loss {}".format(checkpoint1['epoch'], checkpoint1['best_dev_loss'])) 332 | 333 | self.model.load_state_dict(checkpoint1['model']) 334 | 335 | self.optimizer = build_optimizer(self.options.lr, self.options.weight_decay, self.model.parameters()) 336 | self.optimizer.load_state_dict(checkpoint1['optimizer']) 337 | 338 | self.best_dev_loss = checkpoint1['best_dev_loss'] 339 | self.start_epoch = checkpoint1['epoch'] + 1 340 | del checkpoint1 341 | 342 | 343 | class NonRLTester: 344 | def __init__( 345 | self, 346 | env: loupe_envs.LOUPEDataEnv, 347 | exp_dir: str, 348 | options: argparse.Namespace, 349 | 
label_range: List[int] 350 | ): 351 | self.env = env 352 | # load options and model 353 | self.load(os.path.join(exp_dir , 'best_model.pt')) 354 | self.options.label_range = label_range 355 | self.options.exp_dir = exp_dir 356 | self.options.test_visual_frequency = options.test_visual_frequency 357 | self.options.visualization_dir = options.visualization_dir 358 | self.options.batch_size = 1 359 | _, _, self.dev_loader, _ = self.env._setup_data_handlers() 360 | 361 | # setup saving and logging 362 | self.options.exp_dir.mkdir(parents=True, exist_ok=True) 363 | 364 | logging.basicConfig(level=logging.INFO) 365 | self.logger = logging.getLogger(__name__) 366 | self.logger.info(self.options) 367 | self.logger.info(self.model) 368 | 369 | def evaluate(self): 370 | self.model.eval() 371 | losses = [] 372 | sparsity = [] 373 | targets, preds = [], [] 374 | metrics = Metrics(METRIC_FUNCS) 375 | masks = defaultdict(list) 376 | sample_image = dict() 377 | 378 | with torch.no_grad(): 379 | for iter, data in enumerate(self.dev_loader): 380 | # input: [batch_size, num_channels, height, width] denoted as NCHW in other places 381 | # label: label of the current image (0~9 for mnist/fashion-mnist) default: -1 382 | # target: a copy of the input image for computing reconstruction loss in [NCHW] 383 | kspace, _, input, label, *ignored = data 384 | 385 | # adapt data to loupe 386 | target = input.clone().detach() 387 | target = transforms.complex_abs(target).unsqueeze(1) 388 | 389 | input = input.to(self.options.device) 390 | target = target.to(self.options.device) 391 | kspace = kspace.to(self.options.device) 392 | 393 | """if self.options.noise_type == 'gaussian': 394 | kspace = transforms.add_gaussian_noise(self.options, kspace, mean=0., std=self.options.noise_level) 395 | """ 396 | 397 | pred_dict = self.model(target, kspace) 398 | 399 | if iter % self.options.test_visual_frequency == 0: 400 | data_for_vis_name = 'test_iter=' + str(iter+1) 401 | print("visualize 
{}".format(data_for_vis_name)) 402 | self.model.visualize_and_save(self.options, iter, data_for_vis_name) 403 | 404 | output = pred_dict['output'] 405 | if isinstance(output, list): 406 | output = output[-1] 407 | 408 | loss = compute_ssim_torch(output, target) # F.l1_loss(output, target, size_average=True) 409 | 410 | mask = pred_dict['mask'] 411 | # masks[label.item()].append(transforms.fftshift(mask[0, 0], dim=(0, 1)).cpu().detach().numpy()) 412 | 413 | sparsity.append(torch.mean(mask).item()) 414 | losses.append(loss.item()) 415 | 416 | targets.extend(target.cpu().numpy()) 417 | preds.extend(output.cpu().numpy()) 418 | 419 | # sample_image[label.item()] = target[0, 0].cpu().detach().numpy() 420 | 421 | """for key, val in masks.items(): 422 | # get average mask and standard deviation 423 | mask_mean, mask_std = get_mask_stats(val) 424 | mask_mean[16] = 0 425 | mask_mean[:, 16] = 0 426 | fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[10, 3]) 427 | sp1 = ax[0].imshow(sample_image[key], cmap='viridis') 428 | sp2 = ax[1].imshow(mask_mean) 429 | sp3 = ax[2].imshow(mask_std) 430 | fig.colorbar(sp1, ax=ax[0]) 431 | fig.colorbar(sp2, ax=ax[1]) 432 | fig.colorbar(sp3, ax=ax[2]) 433 | 434 | plt.savefig(str(self.options.visualization_dir)+'/label_{}_mask_stats.png'.format(key)) 435 | plt.close() 436 | """ 437 | 438 | print("Done Prediction") 439 | 440 | for t, p in zip(targets, preds): 441 | metrics.push(t, p) 442 | 443 | return losses, np.mean(sparsity), metrics 444 | 445 | def __call__(self): 446 | dev_loss, sparsity, metrics = self.evaluate() 447 | 448 | print("L1 Loss {} STD {} Sparsity {}".format(np.mean(dev_loss), np.std(dev_loss), sparsity)) 449 | 450 | with open(os.path.join(self.options.exp_dir, 'statistics.txt'), 'w') as f: 451 | f.write('L1 Loss {} +- {}\n'.format(np.mean(dev_loss), np.std(dev_loss))) 452 | f.write(str(metrics)) 453 | 454 | print(metrics) 455 | 456 | with open(os.path.join(self.options.exp_dir, 'loss.pkl'), 'wb') as f: 457 | 
pickle.dump(dev_loss, f) 458 | 459 | def load(self, checkpoint_file): 460 | checkpoint = torch.load(checkpoint_file) 461 | print("Load checkpoint {} with loss {}".format(checkpoint['epoch'], checkpoint['best_dev_loss'])) 462 | self.options = checkpoint['options'] 463 | random.seed(self.options.seed) 464 | np.random.seed(self.options.seed) 465 | torch.manual_seed(self.options.seed) 466 | 467 | 468 | if 'bi_dir' not in self.options.__dict__: 469 | self.options.bi_dir = False 470 | 471 | if 'clamp' not in self.options.__dict__: 472 | self.options.clamp = 100 473 | 474 | if 'fixed_input' not in self.options.__dict__: 475 | self.options.fixed_input = False 476 | 477 | print('clamp value {}'.format(self.options.clamp)) 478 | 479 | self.model = build_model(self.options) 480 | self.model.eval() 481 | self.model.load_state_dict(checkpoint['model']) 482 | 483 | torch.save({'model':self.model.reconstructor.state_dict()}, os.path.join(*checkpoint_file.split('/')[:-1], 'best_recon.pt')) 484 | 485 | # checkpoint = torch.load('checkpoints/Reconstructors/dicom_ckpt/PG_DICOM_Knee_Reconstructor_UNet/best_model.pt', map_location='cpu') 486 | # self.model.reconstructor.load_state_dict(checkpoint['model']) 487 | -------------------------------------------------------------------------------- /activemri/baselines/sequential_sampling_codes/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_sampler import ConvSamplerSmall, LineConstrainedSampler -------------------------------------------------------------------------------- /activemri/baselines/sequential_sampling_codes/conv_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from activemri.baselines.loupe_codes import transforms 5 | 6 | # ResNet code is modified from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 7 | # Licensed 
under MIT License 8 | 9 | class BasicBlock(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = nn.Conv2d( 15 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.InstanceNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 18 | stride=1, padding=1, bias=False) 19 | self.bn2 = nn.InstanceNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion*planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion*planes, 25 | kernel_size=1, stride=stride, bias=False), 26 | nn.InstanceNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | class Bottleneck(nn.Module): 37 | expansion = 4 38 | 39 | def __init__(self, in_planes, planes, stride=1): 40 | super(Bottleneck, self).__init__() 41 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 42 | self.bn1 = nn.InstanceNorm2d(planes) 43 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 44 | stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.InstanceNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion * 47 | planes, kernel_size=1, bias=False) 48 | self.bn3 = nn.InstanceNorm2d(self.expansion*planes) 49 | 50 | self.shortcut = nn.Sequential() 51 | if stride != 1 or in_planes != self.expansion*planes: 52 | self.shortcut = nn.Sequential( 53 | nn.Conv2d(in_planes, self.expansion*planes, 54 | kernel_size=1, stride=stride, bias=False), 55 | nn.InstanceNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | 
return out 65 | 66 | class ResNet(nn.Module): 67 | def __init__(self, block, num_blocks, num_classes=10, fixed_input=False): 68 | super(ResNet, self).__init__() 69 | self.in_planes = 64 70 | # down sample layers 71 | 72 | self.conv1 = nn.Conv2d(5, 64, kernel_size=3, 73 | stride=1, padding=1, bias=False) 74 | self.bn1 = nn.InstanceNorm2d(64) 75 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 76 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 77 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 78 | 79 | # up sample layers 80 | 81 | self.up_convs = nn.Sequential( 82 | nn.ConvTranspose2d(256, 128, 2, stride=2, bias=False), 83 | nn.InstanceNorm2d(128), 84 | nn.ReLU(inplace=True), 85 | nn.ConvTranspose2d(128, 64, 2, stride=2, bias=False), 86 | nn.InstanceNorm2d(64), 87 | nn.ReLU(inplace=True), 88 | ) 89 | 90 | self.conv_last = nn.Sequential( 91 | nn.Conv2d(64, 16, kernel_size=3, padding=1, stride=1), 92 | nn.InstanceNorm2d(16), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(16, 1, kernel_size=1, stride=1) 95 | ) 96 | 97 | self.fixed_input = fixed_input 98 | if fixed_input: 99 | print('generate random input tensor') 100 | fixed_input_tensor = torch.randn(size=[1, 5, 128, 128]) 101 | self.fixed_input_tensor = nn.Parameter(fixed_input_tensor, requires_grad=False) 102 | 103 | 104 | def _make_layer(self, block, planes, num_blocks, stride): 105 | strides = [stride] + [1]*(num_blocks-1) 106 | layers = [] 107 | for stride in strides: 108 | layers.append(block(self.in_planes, planes, stride)) 109 | self.in_planes = planes * block.expansion 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x, mask): 113 | """ 114 | This function takes the observed kspace data and sampling trajectories as input 115 | and output next sampling probability mask. We additionally mask out all 116 | previous sampled locations. 
117 | input: NHWC 118 | mask: N1HW 119 | """ 120 | # NHWC -> NCHW 121 | x = x.permute(0, -1, 1, 2) 122 | 123 | # concatenate with previous sampling mask 124 | x = torch.cat([x, mask], dim=1) 125 | 126 | if self.fixed_input: 127 | x = self.fixed_input_tensor.repeat(x.shape[0], 1, 1, 1).to(x.device) 128 | 129 | # down convs 130 | out = F.relu(self.bn1(self.conv1(x))) 131 | out = self.layer1(out) 132 | out = self.layer2(out) 133 | out = self.layer3(out) 134 | 135 | out = self.up_convs(out) 136 | out = self.conv_last(out) 137 | 138 | out = torch.relu(out) / torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1, 1, 1) 139 | 140 | # Don't need to select an existed measurement 141 | new_mask = out * (1-mask) 142 | 143 | return new_mask 144 | 145 | def ConvSamplerSmall(fixed_input=False): 146 | return ResNet(BasicBlock, [1, 1, 1], fixed_input=fixed_input) 147 | 148 | 149 | 150 | class LineConstrainedSampler(nn.Module): 151 | """ 152 | A line constrained convolutional sampler which uses the same architecture 153 | as the evaluator in https://arxiv.org/pdf/1902.03051.pdf 154 | In this module, we first create a pseudo image containing spectral maps 155 | corresponding to every kspace line. We then pass this pseudo image into 156 | a classification network to predict the probability of selecting 157 | each kspace line. 
Refer to the above paper for details 158 | """ 159 | 160 | def __init__(self, shape=[32]): 161 | super().__init__() 162 | self.mask_conv = nn.Linear(shape[0], 6) 163 | self.middle_convs = nn.Sequential( 164 | nn.Conv2d(6+shape[0], 256, kernel_size=3, stride=2, padding=1), 165 | nn.InstanceNorm2d(256), 166 | nn.LeakyReLU(0.2), 167 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), 168 | nn.InstanceNorm2d(512), 169 | nn.LeakyReLU(0.2), 170 | nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), 171 | nn.InstanceNorm2d(1024), 172 | nn.LeakyReLU(0.2), 173 | ) 174 | self.conv_last = nn.Linear(1024, shape[0]) 175 | 176 | # 1xWxHxWxC 177 | # spectral maps of each kspace column 178 | self.masks = torch.zeros(1, shape[1], shape[0], shape[1], 1, requires_grad=False).cuda() 179 | for i in range(shape[1]): 180 | self.masks[:, i, :, i] = 1 181 | 182 | print("Finish Mask Initialization") 183 | 184 | def _to_spectral(self, kspace): 185 | """ 186 | Args: 187 | kspace (torch.Tensor): Already sampled measurements shape NHWC 188 | 189 | Returns: 190 | spectral_map (torch.Tensor): NWHW 191 | """ 192 | spectral_map = transforms.complex_abs(transforms.fftshift(transforms.ifft2(kspace.unsqueeze(1) * self.masks))) 193 | 194 | return spectral_map 195 | 196 | def forward(self, kspace, mask): 197 | """ 198 | Args: 199 | kspace (torch.Tensor): Already sampled measurements shape NHWC 200 | mask (torch.Tensor): previous sampling trajectories shape N1HW 201 | 202 | Returns: 203 | new_mask (torch.Tensor): probability map of next iteration's sampling locations 204 | shape N1HW <- broadcasted from N11W 205 | """ 206 | N, C, H, W = mask.shape 207 | spectral_map = self._to_spectral(kspace) 208 | # mask = mask[:, 0, 0] 209 | mask = (mask.sum(dim=-2) == H).float().squeeze() 210 | 211 | mask_embedding = self.mask_conv(mask).reshape(N, 6, 1, 1).repeat(1, 1, H, W) 212 | 213 | spectral_map = torch.cat([spectral_map, mask_embedding], dim=1) 214 | 215 | out = self.middle_convs(spectral_map) 216 | 
217 | # global average pooling 218 | out = out.mean([-2, -1]) 219 | out = self.conv_last(out) 220 | 221 | out = torch.sigmoid(out) 222 | 223 | # Don't need to select an existed measurement 224 | new_mask = out * (1-mask) 225 | 226 | # don't use repeat here. Pytorch can automatically broadcast. 227 | return new_mask.reshape(N, 1, 1, W) 228 | 229 | class KspaceLineConstrainedSampler(nn.Module): 230 | """ 231 | A line constrained convolutional sampler which uses the same architecture 232 | as the evaluator in https://arxiv.org/pdf/1902.03051.pdf 233 | In this module, we first create a pseudo image containing spectral maps 234 | corresponding to every kspace line. We then pass this pseudo image into 235 | a classification network to predict the probability of selecting 236 | each kspace line. Refer to the above paper for details 237 | """ 238 | 239 | def __init__(self, in_chans, out_chans, clamp=100, with_uncertainty=False, fixed_input=False): 240 | super().__init__() 241 | """self.middle_convs = nn.Sequential( 242 | nn.Conv2d(1+2, 64, kernel_size=3, stride=2, padding=1), 243 | nn.InstanceNorm2d(64), 244 | nn.LeakyReLU(0.2), 245 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 246 | nn.InstanceNorm2d(128), 247 | nn.LeakyReLU(0.2), 248 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 249 | nn.InstanceNorm2d(256), 250 | nn.LeakyReLU(0.2), 251 | )""" 252 | 253 | if with_uncertainty: 254 | self.flatten_size = int((in_chans) **2 * 6) 255 | else: 256 | self.flatten_size = int((in_chans) **2 * 5) 257 | 258 | self.conv_last = nn.Sequential( 259 | nn.Linear(self.flatten_size, 512), 260 | nn.ReLU(inplace=True), 261 | nn.Linear(512, 512), 262 | nn.ReLU(inplace=True), 263 | nn.Linear(512, 512), 264 | nn.ReLU(inplace=True), 265 | nn.Linear(512, 512), 266 | nn.ReLU(inplace=True), 267 | nn.Linear(512, out_chans) 268 | ) 269 | 270 | self.clamp = clamp 271 | self.with_uncertainty = with_uncertainty 272 | self.fixed_input = fixed_input 273 | 274 | if fixed_input: 275 
| print("Generate random input tensor") 276 | fixed_input_tensor = torch.randn(size=[1, 5, in_chans, in_chans]) 277 | self.fixed_input_tensor = nn.Parameter(fixed_input_tensor, requires_grad=False) 278 | 279 | print("Finish Mask Initialization") 280 | 281 | def forward(self, kspace, mask, uncertainty_map=None): 282 | """ 283 | Args: 284 | kspace (torch.Tensor): Already sampled measurements shape NHWC 285 | mask (torch.Tensor): previous sampling trajectories shape N1HW 286 | 287 | Returns: 288 | new_mask (torch.Tensor): probability map of next iteration's sampling locations 289 | shape N1HW <- broadcasted from N11W 290 | """ 291 | N, C, H, W = mask.shape 292 | 293 | if self.with_uncertainty: 294 | assert 0 295 | feat_map = torch.cat([kspace.permute(0, 3, 1, 2), uncertainty_map], dim=1) 296 | else: 297 | feat_map = torch.cat([kspace.permute(0, 3, 1, 2), mask], dim=1) 298 | 299 | if self.fixed_input: 300 | feat_map = self.fixed_input_tensor.repeat(N, 1, 1, 1).to(kspace.device) 301 | 302 | out = feat_map.flatten(start_dim=1) 303 | 304 | out = self.conv_last(out) 305 | 306 | out = F.softplus(out) 307 | 308 | # out = out / torch.clamp(out / torch.clamp(torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1), min=1e-10), max=1-1e-10, min=1e-10) 309 | out = out / torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1) 310 | 311 | # Don't need to select an existed measurement 312 | if out.shape[-1] == mask.shape[-1]: 313 | vertical_mask = (mask.sum(dim=-2) == H).float().reshape(N, -1) 314 | new_mask = out * (1-vertical_mask) 315 | else: 316 | # verify horizontal / vertical separately 317 | vertical_mask = (mask.sum(dim=-2) == H).float().reshape(N, -1) 318 | horizontal_mask = (mask.transpose(-2, -1).sum(dim=-2)==W).float().reshape(N, -1) 319 | 320 | length = out.shape[1] // 2 321 | new_mask = torch.zeros_like(out) 322 | new_mask[:, :length] = out[:, :length] * (1-vertical_mask) 323 | new_mask[:, length:] = out[:, length:] * (1-horizontal_mask) 324 | 325 | 
# don't use repeat here. Pytorch can automatically broadcast. 326 | return new_mask.reshape(N, 1, 1, -1) 327 | 328 | 329 | class BiLineConstrainedSampler(nn.Module): 330 | """ 331 | A line constrained convolutional sampler which uses the same architecture 332 | as the evaluator in https://arxiv.org/pdf/1902.03051.pdf 333 | In this module, we first create a pseudo image containing spectral maps 334 | corresponding to every kspace line. We then pass this pseudo image into 335 | a classification network to predict the probability of selecting 336 | each kspace line. Refer to the above paper for details 337 | """ 338 | 339 | def __init__(self, in_chans, clamp, with_uncertainty=False, fixed_input=False): 340 | super().__init__() 341 | 342 | self.sampler = KspaceLineConstrainedSampler(in_chans=in_chans, out_chans=in_chans*2, clamp=clamp, 343 | with_uncertainty=with_uncertainty, fixed_input=fixed_input) 344 | 345 | def forward(self, kspace, mask, uncertainty_map=None): 346 | """ 347 | Args: 348 | kspace (torch.Tensor): Already sampled measurements shape NHWC 349 | mask (torch.Tensor): previous sampling trajectories shape N1HW 350 | 351 | Returns: 352 | new_mask (torch.Tensor): probability map of next iteration's sampling locations 353 | shape N1HW <- broadcasted from N11W 354 | """ 355 | return self.sampler.forward(kspace, mask, uncertainty_map) 356 | -------------------------------------------------------------------------------- /activemri/baselines/sequential_sampling_codes/joint_reconstructor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/activemri/baselines/sequential_sampling_codes/joint_reconstructor.py -------------------------------------------------------------------------------- /activemri/baselines/sequential_sampling_codes/sampler2d.py: -------------------------------------------------------------------------------- 1 | from 
activemri.baselines.sequential_sampling_codes.conv_sampler import BasicBlock 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from activemri.baselines.loupe_codes import transforms 6 | from activemri.baselines.loupe_codes.reconstructors import ConvBlock 7 | 8 | class Sampler2D(nn.Module): 9 | def __init__(self, num_blocks=[1, 1, 1], fixed_input=False): 10 | super(Sampler2D, self).__init__() 11 | 12 | in_chans = 5 13 | out_chans = 1 14 | chans = 64 15 | num_pool_layers = 4 16 | drop_prob = 0 17 | 18 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) 19 | ch = chans 20 | for i in range(num_pool_layers - 1): 21 | self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)] 22 | ch *= 2 23 | self.conv = ConvBlock(ch, ch, drop_prob) 24 | 25 | self.up_sample_layers = nn.ModuleList() 26 | for i in range(num_pool_layers - 1): 27 | self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, drop_prob)] 28 | ch //= 2 29 | self.up_sample_layers += [ConvBlock(ch * 2, ch, drop_prob)] 30 | self.conv2 = nn.Sequential( 31 | nn.Conv2d(ch, ch // 2, kernel_size=1), 32 | nn.Conv2d(ch // 2, out_chans, kernel_size=1), 33 | nn.Conv2d(out_chans, out_chans, kernel_size=1), 34 | ) 35 | 36 | self.fixed_input = fixed_input 37 | if fixed_input: 38 | print('generate random input tensor') 39 | fixed_input_tensor = torch.randn(size=[1, 5, 128, 128]) 40 | self.fixed_input_tensor = nn.Parameter(fixed_input_tensor, requires_grad=False) 41 | 42 | def _make_layer(self, block, planes, num_blocks, stride): 43 | strides = [stride] + [1]*(num_blocks-1) 44 | layers = [] 45 | for stride in strides: 46 | layers.append(block(self.in_planes, planes, stride)) 47 | self.in_planes = planes * block.expansion 48 | return nn.Sequential(*layers) 49 | 50 | def forward(self, x, mask): 51 | """ 52 | This function takes the observed kspace data and sampling trajectories as input 53 | and output next sampling probability mask.
# [review] sampler2d.py -- Sampler2D (rows 8-97).
# A U-Net-style scorer built from:
#   - four ConvBlock encoder stages (5 -> 64 -> 128 -> 256 -> 512 channels),
#     each followed by a 2x max-pool in forward();
#   - a bottleneck ConvBlock;
#   - four decoder ConvBlocks with bilinear upsampling and skip connections;
#   - three 1x1 convolutions down to one channel.
# forward() steps:
#   1. Feed cat(x.permute(0, -1, 1, 2), mask), or the stored
#      fixed_input_tensor when fixed_input is set, through the network.
#   2. Apply softplus.
#   3. Divide each sample by its maximum.
#   4. Zero already-sampled locations via `out * (1-mask)`.
# NOTE(review): the first ConvBlock expects in_chans = 5 (row 12), so x is
#   presumably N x H x W x 4 (plus one mask channel) -- confirm with caller.
# NOTE(review): with fixed_input=True the input becomes the stored
#   1 x 5 x 128 x 128 tensor whatever the batch or image size. The batch
#   dimension only returns through broadcasting in `out * (1-mask)`, so this
#   path presumably assumes 128 x 128 masks.
#   KspaceLineConstrainedSampler.forward repeats its fixed tensor over the
#   batch instead.
# NOTE(review): `_make_layer` reads self.in_planes (row 46) before the only
#   assignment (row 47), and __init__ never sets it, so calling it would
#   raise AttributeError. Together with the unused `num_blocks=[1, 1, 1]`
#   (a mutable default) and the unused BasicBlock / transforms imports, this
#   looks like dead code.
We additionally mask out all 54 | previous sampled locations. 55 | input: NHWC 56 | mask: N1HW 57 | """ 58 | # NHWC -> NCHW 59 | x = x.permute(0, -1, 1, 2) 60 | 61 | # concatenate with previous sampling mask 62 | x = torch.cat([x, mask], dim=1) 63 | 64 | if self.fixed_input: 65 | x = self.fixed_input_tensor 66 | 67 | output = x 68 | 69 | stack = [] 70 | # print(input.shape, mask.shape) 71 | 72 | # Apply down-sampling layers 73 | for layer in self.down_sample_layers: 74 | output = layer(output) 75 | stack.append(output) 76 | output = F.max_pool2d(output, kernel_size=2) 77 | 78 | output = self.conv(output) 79 | 80 | # Apply up-sampling layers 81 | for layer in self.up_sample_layers: 82 | downsample_layer = stack.pop() 83 | layer_size = (downsample_layer.shape[-2], downsample_layer.shape[-1]) 84 | output = F.interpolate(output, size=layer_size, mode='bilinear', align_corners=False) 85 | output = torch.cat([output, downsample_layer], dim=1) 86 | output = layer(output) 87 | 88 | out = self.conv2(output) 89 | 90 | out = F.softplus(out) 91 | 92 | out = out / torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1, 1, 1) 93 | 94 | # Don't need to select an existed measurement 95 | new_mask = out * (1-mask) 96 | 97 | return new_mask 98 | -------------------------------------------------------------------------------- /activemri/baselines/simple_baselines.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | activemri.baselines.simple_baselines.py 8 | ======================================= 9 | Simple baselines for active MRI acquisition. 10 | """ 11 | from typing import Any, Dict, List, Optional 12 | 13 | import numpy as np 14 | import torch 15 | 16 | import activemri.envs 17 | 18 | from .
import Policy 19 | 20 | 21 | class RandomPolicy(Policy): 22 | """A policy representing random k-space selection. 23 | 24 | Returns one of the valid actions uniformly at random. 25 | 26 | Args: 27 | seed(optional(int)): The seed to use for the random number generator, which is 28 | based on ``torch.Generator()``. 29 | """ 30 | 31 | def __init__(self, seed: Optional[int] = None): 32 | super().__init__() 33 | self.rng = torch.Generator() 34 | if seed: 35 | self.rng.manual_seed(seed) 36 | 37 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]: 38 | """Returns a random action without replacement. 39 | 40 | Args: 41 | obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`. 42 | 43 | Returns: 44 | list(int): A list of random k-space column indices, one per batch element in 45 | the observation. The indices are sampled from the set of inactive (0) columns 46 | on each batch element. 47 | """ 48 | return ( 49 | (obs["mask"].logical_not().float() + 1e-6) 50 | .multinomial(1, generator=self.rng) 51 | .squeeze() 52 | .tolist() 53 | ) 54 | 55 | 56 | class RandomLowBiasPolicy(Policy): 57 | def __init__( 58 | self, acceleration: float, centered: bool = True, seed: Optional[int] = None 59 | ): 60 | super().__init__() 61 | self.acceleration = acceleration 62 | self.centered = centered 63 | self.rng = np.random.RandomState(seed) 64 | 65 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]: 66 | mask = obs["mask"].squeeze().cpu().numpy() 67 | new_mask = self._cartesian_mask(mask) 68 | action = (new_mask - mask).argmax(axis=1) 69 | return action.tolist() 70 | 71 | @staticmethod 72 | def _normal_pdf(length: int, sensitivity: float): 73 | return np.exp(-sensitivity * (np.arange(length) - length / 2) ** 2) 74 | 75 | def _cartesian_mask(self, current_mask: np.ndarray) -> np.ndarray: 76 | batch_size, image_width = current_mask.shape 77 | pdf_x = RandomLowBiasPolicy._normal_pdf( 78 | image_width, 0.5 / (image_width / 10.0) ** 2 79 | ) 80 |
pdf_x = np.expand_dims(pdf_x, axis=0) 81 | lmda = image_width / (2.0 * self.acceleration) 82 | # add uniform distribution 83 | pdf_x += lmda * 1.0 / image_width 84 | # remove previously chosen columns 85 | # note that pdf_x designed for centered masks 86 | new_mask = ( 87 | np.fft.ifftshift(current_mask, axes=1) 88 | if not self.centered 89 | else current_mask.copy() 90 | ) 91 | pdf_x = pdf_x * np.logical_not(new_mask) 92 | # normalize probabilities and choose accordingly 93 | pdf_x /= pdf_x.sum(axis=1, keepdims=True) 94 | indices = [ 95 | self.rng.choice(image_width, 1, False, pdf_x[i]).item() 96 | for i in range(batch_size) 97 | ] 98 | new_mask[range(batch_size), indices] = 1 99 | if not self.centered: 100 | new_mask = np.fft.ifftshift(new_mask, axes=1) 101 | return new_mask 102 | 103 | 104 | class LowestIndexPolicy(Policy): 105 | """A policy that represents low-to-high frequency k-space selection. 106 | 107 | Args: 108 | alternate_sides(bool): If ``True`` the indices of selected actions will alternate 109 | between the sides of the mask. For example, for an image with 100 110 | columns, and non-centered k-space, the order will be 0, 99, 1, 98, 2, 97, ..., etc. 111 | For the same size and centered, the order will be 49, 50, 48, 51, 47, 52, ..., etc. 112 | 113 | centered(bool): If ``True`` (default), low frequencies are in the center of the mask. 114 | Otherwise, they are in the edges of the mask. 115 | """ 116 | 117 | def __init__( 118 | self, 119 | alternate_sides: bool, 120 | centered: bool = True, 121 | ): 122 | super().__init__() 123 | self.alternate_sides = alternate_sides 124 | self.centered = centered 125 | self.bottom_side = True 126 | 127 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]: 128 | """Returns a random action without replacement. 129 | 130 | Args: 131 | obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`.
132 | 133 | Returns: 134 | list(int): A list of k-space column indices, one per batch element in 135 | the observation, equal to the lowest non-active k-space column in their 136 | corresponding observation masks. 137 | """ 138 | mask = obs["mask"].squeeze().cpu().numpy() 139 | new_mask = self._get_new_mask(mask) 140 | action = (new_mask - mask).argmax(axis=1) 141 | return action.tolist() 142 | 143 | def _get_new_mask(self, current_mask: np.ndarray) -> np.ndarray: 144 | # The code below assumes mask in non centered 145 | new_mask = ( 146 | np.fft.ifftshift(current_mask, axes=1) 147 | if self.centered 148 | else current_mask.copy() 149 | ) 150 | if self.bottom_side: 151 | idx = np.arange(new_mask.shape[1], 0, -1) 152 | else: 153 | idx = np.arange(new_mask.shape[1]) 154 | if self.alternate_sides: 155 | self.bottom_side = not self.bottom_side 156 | # Next line finds the first non-zero index (from edge to center) and returns it 157 | indices = (np.logical_not(new_mask) * idx).argmax(axis=1) 158 | indices = np.expand_dims(indices, axis=1) 159 | new_mask[range(new_mask.shape[0]), indices] = 1 160 | if self.centered: 161 | new_mask = np.fft.ifftshift(new_mask, axes=1) 162 | return new_mask 163 | 164 | 165 | class OneStepGreedyOracle(Policy): 166 | """A policy that returns the k-space column leading to best reconstruction score. 167 | 168 | Args: 169 | env(``activemri.envs.ActiveMRIEnv``): The environment for which the policy is computed 170 | for. 171 | metric(str): The name of the score metric to use (must be in ``env.score_keys()``). 172 | num_samples(optional(int)): If given, only ``num_samples`` random actions will be 173 | tested. Defaults to ``None``, which means that method will consider all actions. 174 | rng(``numpy.random.RandomState``): A random number generator to use for sampling.
175 | """ 176 | 177 | def __init__( 178 | self, 179 | env: activemri.envs.ActiveMRIEnv, 180 | metric: str, 181 | num_samples: Optional[int] = None, 182 | rng: Optional[np.random.RandomState] = None, 183 | ): 184 | assert metric in env.score_keys() 185 | super().__init__() 186 | self.env = env 187 | self.metric = metric 188 | self.num_samples = num_samples 189 | self.rng = rng if rng is not None else np.random.RandomState() 190 | 191 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]: 192 | """Returns a one-step greedy action maximizing reconstruction score. 193 | 194 | Args: 195 | obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`. 196 | 197 | Returns: 198 | list(int): A list of k-space column indices, one per batch element in 199 | the observation, equal to the action that maximizes reconstruction score 200 | (e.g, SSIM or negative MSE). 201 | """ 202 | mask = obs["mask"] 203 | batch_size = mask.shape[0] 204 | all_action_lists = [] 205 | for i in range(batch_size): 206 | available_actions = mask[i].logical_not().nonzero().squeeze().tolist() 207 | self.rng.shuffle(available_actions) 208 | if len(available_actions) < self.num_samples: 209 | # Add dummy actions to try if num of samples is higher than the 210 | # number of inactive columns in this mask 211 | available_actions.extend( 212 | [0] * (self.num_samples - len(available_actions)) 213 | ) 214 | all_action_lists.append(available_actions) 215 | 216 | all_scores = np.zeros((batch_size, self.num_samples)) 217 | for i in range(self.num_samples): 218 | batch_action_to_try = [action_list[i] for action_list in all_action_lists] 219 | obs, new_score = self.env.try_action(batch_action_to_try) 220 | all_scores[:, i] = new_score[self.metric] 221 | if self.metric in ["mse", "nmse"]: 222 | all_scores *= -1 223 | else: 224 | assert self.metric in ["ssim", "psnr"] 225 | 226 | best_indices = all_scores.argmax(axis=1) 227 | action = [] 228 | for i in range(batch_size): 229 |
action.append(all_action_lists[i][best_indices[i]]) 230 | return action 231 | -------------------------------------------------------------------------------- /activemri/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Any, Dict, List 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from . import singlecoil_knee_data 12 | from . import transforms 13 | 14 | __all__ = ["singlecoil_knee_data", "transforms"] 15 | 16 | 17 | def transform_template( 18 | kspace: List[np.ndarray] = None, 19 | mask: torch.Tensor = None, 20 | ground_truth: torch.Tensor = None, 21 | attrs: List[Dict[str, Any]] = None, 22 | fname: List[str] = None, 23 | slice_id: List[int] = None, 24 | ): 25 | """Template for transform functions. 26 | 27 | Args: 28 | - kspace(list(np.ndarray)): A list of complex numpy arrays, one per k-space in the batch. 29 | The length is the ``batch_size``, and array shapes are ``H x W x 2`` for single coil data, 30 | and ``C x H x W x 2`` for multicoil data, where ``H`` denotes k-space height, ``W`` 31 | denotes k-space width, and ``C`` is the number of coils. Note that the width can differ 32 | between batch elements, if ``num_cols`` is set to a tuple when creating the environment. 33 | - mask(torch.Tensor): A tensor of binary column masks, where 1s indicate that the 34 | corresponding k-space column should be selected. The shape is ``batch_size x 1 x maxW``, 35 | for single coil data, and ``batch_size x 1 x 1 x maxW`` for multicoil data. Here ``maxW`` 36 | is the maximum k-space width returned by the environment. 37 | - ground_truth(torch.Tensor): A tensor of ground truth 2D images. The shape is 38 | ``batch_size x 320 x 320``.
39 | - attrs(list(dict)): A list of dictionaries with the attributes read from the fastMRI for 40 | each image. 41 | - fname(list(str)): A list of the filenames where the images where read from. 42 | - slice_id(list(int)): A list with the slice ids in the files where each image was read 43 | from. 44 | 45 | Returns: 46 | tuple(Any...): A tuple with any number of inputs required by the reconstructor model. 47 | 48 | """ 49 | pass 50 | -------------------------------------------------------------------------------- /activemri/data/real_brain_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # from activemri.baselines.loupe_codes.transforms import normalize 7 | import pathlib 8 | from typing import Callable, List, Optional, Tuple 9 | 10 | import fastmri 11 | import h5py 12 | import numpy as np 13 | import torch.utils.data 14 | import activemri 15 | import scipy.ndimage as ndimage 16 | 17 | class RealBrainData(torch.utils.data.Dataset): 18 | # This is the same as fastMRI singlecoil_knee, except we provide a custom test split 19 | # and also normalize images by the mean norm of the k-space over training data 20 | # KSPACE_WIDTH = 368 21 | # KSPACE_HEIGHT = 640 22 | # START_PADDING = 166 23 | # END_PADDING = 202 24 | # CENTER_CROP_SIZE = 320 25 | 26 | def __init__( 27 | self, 28 | root: pathlib.Path, 29 | image_shape: Tuple[int, int], 30 | transform: Callable, 31 | noise_type: str, 32 | noise_level: float = 5e-5, 33 | num_cols: Optional[int] = None, 34 | num_volumes: Optional[int] = None, 35 | num_rand_slices: Optional[int] = None, 36 | custom_split: Optional[str] = None, 37 | random_rotate=False 38 | ): 39 | self.image_shape = image_shape 40 | self.transform = transform 41 | self.noise_type = noise_type 42 | self.noise_level =
noise_level 43 | self.examples: List[Tuple[pathlib.PurePath, int]] = [] 44 | 45 | 46 | self.num_rand_slices = num_rand_slices 47 | self.rng = np.random.RandomState(1234) 48 | self.recons_key = 'reconstruction_rss' 49 | 50 | files = [] 51 | for fname in list(pathlib.Path(root).iterdir()): 52 | data = h5py.File(fname, "r") 53 | if 'reconstruction_rss' not in data.keys(): 54 | continue 55 | files.append(fname) 56 | 57 | self.train_mode = False 58 | 59 | """if custom_split is not None: 60 | split_info = [] 61 | with open(f"activemri/data/splits/knee_singlecoil/{custom_split}.txt") as f: 62 | for line in f: 63 | split_info.append(line.rsplit("\n")[0]) 64 | files = [f for f in files if f.name in split_info] 65 | else: 66 | self.train_mode = True 67 | """ 68 | if num_volumes is not None: 69 | self.rng.shuffle(files) 70 | files = files[:num_volumes] 71 | 72 | for volume_i, fname in enumerate(sorted(files)): 73 | data = h5py.File(fname, "r") 74 | # kspace = data["kspace"] 75 | 76 | if num_rand_slices is None: 77 | num_slices = data['reconstruction_rss'].shape[0] 78 | self.examples += [(fname, slice_id) for slice_id in range(num_slices)] 79 | else: 80 | assert 0 81 | slice_ids = list(range(kspace.shape[0])) 82 | self.rng.seed(seed=volume_i) 83 | self.rng.shuffle(slice_ids) 84 | self.examples += [ 85 | (fname, slice_id) for slice_id in slice_ids[:num_rand_slices] 86 | ] 87 | 88 | self.random_rotate = random_rotate 89 | if self.random_rotate: 90 | np.random.seed(42) 91 | self.random_angles = np.random.random(len(self)) * 30- 15 92 | 93 | def center_crop(self, data, shape): 94 | """ 95 | (Same as the one in activemri/baselines/policy_gradient_codes/helpers/transforms.py) 96 | Apply a center crop to the input real image or batch of real images.
97 | 98 | Args: 99 | data (torch.Tensor): The input tensor to be center cropped. It should have at 100 | least 2 dimensions and the cropping is applied along the last two dimensions. 101 | shape (int, int): The output shape. The shape should be smaller than the 102 | corresponding dimensions of data. 103 | 104 | Returns: 105 | torch.Tensor: The center cropped image 106 | """ 107 | assert 0 < shape[0] <= data.shape[-2] 108 | assert 0 < shape[1] <= data.shape[-1] 109 | w_from = (data.shape[-2] - shape[0]) // 2 110 | h_from = (data.shape[-1] - shape[1]) // 2 111 | w_to = w_from + shape[0] 112 | h_to = h_from + shape[1] 113 | return data[..., w_from:w_to, h_from:h_to] 114 | 115 | def __len__(self): 116 | return len(self.examples) 117 | 118 | def __getitem__(self, i): 119 | fname, slice_id = self.examples[i] 120 | with h5py.File(fname, 'r') as data: 121 | 122 | # kspace = data["kspace"][slice_id] 123 | # kspace = np.stack([kspace.real, kspace.imag], axis=-1) 124 | # if self.random_rotate: 125 | # kspace = ndimage.rotate(kspace, self.random_angles[i], reshape=False, mode='nearest') 126 | 127 | #kspace = torch.from_numpy(kspace).permute(2, 0, 1) 128 | #kspace = self.center_crop(kspace, self.image_shape).permute(1, 2, 0) 129 | 130 | #kspace = fastmri.ifftshift(kspace, dim=(0, 1)) 131 | target = torch.from_numpy(data['reconstruction_rss'][slice_id]).unsqueeze(-1) 132 | target = torch.cat([target, torch.zeros_like(target)], dim=-1) 133 | target = self.center_crop(target.permute(2, 0, 1), self.image_shape).permute(1, 2, 0) 134 | 135 | kspace = fastmri.ifftshift(target, dim=(0, 1)).fft(2, normalized=False) 136 | 137 | 138 | # target = torch.ifft(kspace, 2, normalized=False) 139 | #target = fastmri.ifftshift(target, dim=(0, 1)) 140 | 141 | 142 | # Normalize using mean of k-space in training data 143 | # target /= 7.072103529760345e-07 144 | # kspace /= 7.072103529760345e-07 145 | 146 | kspace = kspace.numpy() 147 | target = target.numpy() 148 | 149 | return self.transform( 150 | kspace, 151 | 
torch.zeros(kspace.shape[1]), 152 | target, 153 | dict(data.attrs), 154 | fname.name, 155 | slice_id 156 | ) 157 | 158 | """from activemri.envs.envs import ActiveMRIEnv 159 | data = RealBrainData( 160 |
'datasets/brain/train_no_kspace', 161 | (128, 128), 162 | ActiveMRIEnv._void_transform, 163 | noise_type='none' 164 | ) 165 | 166 | data[0] 167 | """ -------------------------------------------------------------------------------- /activemri/data/real_knee_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # from activemri.baselines.loupe_codes.transforms import normalize 7 | import pathlib 8 | from typing import Callable, List, Optional, Tuple 9 | 10 | import fastmri 11 | import h5py 12 | import numpy as np 13 | import torch.utils.data 14 | import activemri 15 | import scipy.ndimage as ndimage 16 | 17 | class RealKneeData(torch.utils.data.Dataset): 18 | # This is the same as fastMRI singlecoil_knee, except we provide a custom test split 19 | # and also normalize images by the mean norm of the k-space over training data 20 | # KSPACE_WIDTH = 368 21 | # KSPACE_HEIGHT = 640 22 | # START_PADDING = 166 23 | # END_PADDING = 202 24 | # CENTER_CROP_SIZE = 320 25 | 26 | def __init__( 27 | self, 28 | root: pathlib.Path, 29 | image_shape: Tuple[int, int], 30 | transform: Callable, 31 | noise_type: str, 32 | noise_level: float = 5e-5, 33 | num_cols: Optional[int] = None, 34 | num_volumes: Optional[int] = None, 35 | num_rand_slices: Optional[int] = None, 36 | custom_split: Optional[str] = None, 37 | random_rotate=False 38 | ): 39 | self.image_shape = image_shape 40 | self.transform = transform 41 | self.noise_type = noise_type 42 | self.noise_level = noise_level 43 | self.examples: List[Tuple[pathlib.PurePath, int]] = [] 44 | 45 | self.num_rand_slices = num_rand_slices 46 | self.rng = np.random.RandomState(1234) 47 | self.recons_key = 'reconstruction_rss' 48 | 49 | files = [] 50 | for fname in list(pathlib.Path(root).iterdir()): 51 | data =
h5py.File(fname, "r") 52 | if 'reconstruction_rss' not in data.keys(): 53 | continue 54 | files.append(fname) 55 | 56 | self.train_mode = False 57 | 58 | if custom_split is not None: 59 | split_info = [] 60 | with open(f"activemri/data/splits/knee_singlecoil/{custom_split}.txt") as f: 61 | for line in f: 62 | split_info.append(line.rsplit("\n")[0]) 63 | files = [f for f in files if f.name in split_info] 64 | else: 65 | self.train_mode = True 66 | 67 | if num_volumes is not None: 68 | self.rng.shuffle(files) 69 | files = files[:num_volumes] 70 | 71 | for volume_i, fname in enumerate(sorted(files)): 72 | data = h5py.File(fname, "r") 73 | kspace = data["kspace"] 74 | 75 | if num_rand_slices is None: 76 | num_slices = kspace.shape[0] 77 | self.examples += [(fname, slice_id) for slice_id in range(num_slices)] 78 | else: 79 | slice_ids = list(range(kspace.shape[0])) 80 | self.rng.seed(seed=volume_i) 81 | self.rng.shuffle(slice_ids) 82 | self.examples += [ 83 | (fname, slice_id) for slice_id in slice_ids[:num_rand_slices] 84 | ] 85 | 86 | self.random_rotate = random_rotate 87 | if self.random_rotate: 88 | np.random.seed(42) 89 | self.random_angles = np.random.random(len(self)) * 30- 15 90 | 91 | def center_crop(self, data, shape): 92 | """ 93 | (Same as the one in activemri/baselines/policy_gradient_codes/helpers/transforms.py) 94 | Apply a center crop to the input real image or batch of real images. 95 | 96 | Args: 97 | data (torch.Tensor): The input tensor to be center cropped. It should have at 98 | least 2 dimensions and the cropping is applied along the last two dimensions. 99 | shape (int, int): The output shape. The shape should be smaller than the 100 | corresponding dimensions of data.
101 | 102 | Returns: 103 | torch.Tensor: The center cropped image 104 | """ 105 | assert 0 < shape[0] <= data.shape[-2] 106 | assert 0 < shape[1] <= data.shape[-1] 107 | w_from = (data.shape[-2] - shape[0]) // 2 108 | h_from = (data.shape[-1] - shape[1]) // 2 109 | w_to = w_from + shape[0] 110 | h_to = h_from + shape[1] 111 | return data[..., w_from:w_to, h_from:h_to] 112 | 113 | def __len__(self): 114 | return len(self.examples) 115 | 116 | def __getitem__(self, i): 117 | fname, slice_id = self.examples[i] 118 | with h5py.File(fname, 'r') as data: 119 | 120 | kspace = data["kspace"][slice_id] 121 | kspace = np.stack([kspace.real, kspace.imag], axis=-1) 122 | if self.random_rotate: 123 | kspace = ndimage.rotate(kspace, self.random_angles[i], reshape=False, mode='nearest') 124 | 125 | kspace = torch.from_numpy(kspace).permute(2, 0, 1) 126 | kspace = self.center_crop(kspace, self.image_shape).permute(1, 2, 0) 127 | 128 | kspace = fastmri.ifftshift(kspace, dim=(0, 1)) 129 | 130 | target = torch.ifft(kspace, 2, normalized=False) 131 | target = fastmri.ifftshift(target, dim=(0, 1)) 132 | 133 | 134 | # Normalize using mean of k-space in training data 135 | target /= 7.072103529760345e-07 136 | kspace /= 7.072103529760345e-07 137 | 138 | kspace = kspace.numpy() 139 | target = target.numpy() 140 | 141 | return self.transform( 142 | kspace, 143 | torch.zeros(kspace.shape[1]), 144 | target, 145 | dict(data.attrs), 146 | fname.name, 147 | slice_id 148 | ) 149 | -------------------------------------------------------------------------------- /activemri/data/singlecoil_knee_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
# [review] Notes for the span above: simple_baselines.py (rows 18-230),
# data/__init__.py, real_brain_data.py and real_knee_data.py.
#
# RandomPolicy: samples one column per batch element with torch.multinomial
#   over (logical_not(mask) + 1e-6). Inactive columns are equally likely;
#   active ones keep a tiny 1e-6 weight and can (rarely) be re-selected.
#   NOTE(review): `if seed:` (row 34) silently ignores seed=0.
#     `if seed is not None:` matches the Optional[int] signature.
#   NOTE(review): `.squeeze().tolist()` returns a bare int, not a list, when
#     the batch has one element.
#
# RandomLowBiasPolicy: draws one new column per batch element, restricted to
#   inactive columns, from a mixture of:
#     - a Gaussian profile (std image_width / 10, centred at image_width / 2);
#     - a uniform floor of lmda / image_width, where
#       lmda = image_width / (2 * acceleration).
#   NOTE(review): for centered=False the mask is ifftshift-ed into centred
#     layout and ifftshift-ed again on the way back (rows 87 and 100). Moving
#     DC to the centre is fftshift; the two calls only cancel for even widths.
#   NOTE(review): obs["mask"].squeeze() collapses a batch of one to 1-D. That
#     breaks `batch_size, image_width = current_mask.shape` here and the
#     axis-1 operations in LowestIndexPolicy._get_new_mask.
#
# LowestIndexPolicy: picks the lowest (bottom side) or highest (top side)
#   inactive column in non-centred layout, alternating sides when
#   alternate_sides is set. Its get_action docstring ("Returns a random
#   action") is copied from RandomPolicy; the choice here is deterministic.
#   NOTE(review): rows 158-159 expand `indices` to (B, 1) and index with
#     range(B). This broadcasts to a B x B index and marks every chosen column
#     in every row, so for B > 1 get_action's argmax can return another batch
#     element's column. Indexing with the 1-D `indices` avoids this.
#   NOTE(review): for centered=True the same ifftshift/ifftshift round trip
#     (rows 146 and 161) is only exact for even widths.
#
# OneStepGreedyOracle: for each batch element it
#   1. shuffles the inactive columns;
#   2. tries the first num_samples of them with env.try_action, padding with
#      action 0 when fewer are available;
#   3. returns the best action per `metric`. mse/nmse scores are negated so
#      that argmax picks the lowest error.
#   NOTE(review): the documented default num_samples=None ("consider all
#     actions") fails at `len(available_actions) < self.num_samples` with a
#     TypeError. None needs to be resolved to the column count first.
#   NOTE(review): `.nonzero().squeeze().tolist()` yields a bare int when only
#     one column is inactive, and rng.shuffle then fails.
#
# transform_template: documentation-only signature for dataset transforms
#   (body is `pass`). Typo on row 41: "where the images where read from"
#   should read "were read from".
#
# RealBrainData: indexes every (file, slice) whose HDF5 file has a
#   'reconstruction_rss' dataset. __getitem__ center-crops that image to
#   image_shape (adding a zero imaginary channel) and synthesizes k-space
#   from it via ifftshift + .fft(2).
#   NOTE(review): the class comment (custom test split, k-space normalization)
#     is copied from the knee loader. Here the split filter sits in an unused
#     string literal, the normalization is commented out, and random_angles is
#     never read.
#   NOTE(review): the num_rand_slices branch is `assert 0` followed by code
#     that reads an undefined `kspace`. Raise NotImplementedError instead.
#
# RealKneeData: same indexing (plus the custom_split filter). __getitem__:
#   1. reads a k-space slice;
#   2. optionally rotates it by a per-example angle in [-15, 15) degrees;
#   3. center-crops it to image_shape;
#   4. inverse-FFTs it to the target image;
#   5. divides both by 7.072103529760345e-07 (commented as the training-data
#      k-space mean).
#
# NOTE(review), both loaders:
#   - The h5py.File handles opened in __init__ (rows 52/73 and 51/72) are
#     never closed explicitly; use `with h5py.File(...) as data:`.
#   - Every entry of `root` is opened, so a non-HDF5 file there raises.
#   - num_cols, noise_type and noise_level are accepted but unused in the
#     visible code.
#   - np.random.seed(42) reseeds the global NumPy RNG.
#   - The knee split path is relative to the current working directory.
#   - Tensor.fft / torch.ifft are the pre-1.8 FFT API; verify the pinned
#     torch version.
5 | 6 | import pathlib 7 | from typing import Callable, List, Optional, Tuple 8 | 9 | import fastmri 10 | import h5py 11 | import numpy as np 12 | import torch.utils.data 13 | 14 | 15 | # ----------------------------------------------------------------------------- 16 | # Single coil knee dataset (as used in MICCAI'20) 17 | # ----------------------------------------------------------------------------- 18 | class MICCAI2020Data(torch.utils.data.Dataset): 19 | # This is the same as fastMRI singlecoil_knee, except we provide a custom test split 20 | # and also normalize images by the mean norm of the k-space over training data 21 | KSPACE_WIDTH = 368 22 | KSPACE_HEIGHT = 640 23 | START_PADDING = 166 24 | END_PADDING = 202 25 | CENTER_CROP_SIZE = 320 26 | 27 | def __init__( 28 | self, 29 | root: pathlib.Path, 30 | transform: Callable, 31 | num_cols: Optional[int] = None, 32 | num_volumes: Optional[int] = None, 33 | num_rand_slices: Optional[int] = None, 34 | custom_split: Optional[str] = None, 35 | ): 36 | self.transform = transform 37 | self.examples: List[Tuple[pathlib.PurePath, int]] = [] 38 | 39 | self.num_rand_slices = num_rand_slices 40 | self.rng = np.random.RandomState(1234) 41 | 42 | files = [] 43 | for fname in list(pathlib.Path(root).iterdir()): 44 | data = h5py.File(fname, "r") 45 | if num_cols is not None and data["kspace"].shape[2] != num_cols: 46 | continue 47 | files.append(fname) 48 | 49 | if custom_split is not None: 50 | split_info = [] 51 | with open(f"activemri/data/splits/knee_singlecoil/{custom_split}.txt") as f: 52 | for line in f: 53 | split_info.append(line.rsplit("\n")[0]) 54 | files = [f for f in files if f.name in split_info] 55 | 56 | if num_volumes is not None: 57 | self.rng.shuffle(files) 58 | files = files[:num_volumes] 59 | 60 | for volume_i, fname in enumerate(sorted(files)): 61 | data = h5py.File(fname, "r") 62 | kspace = data["kspace"] 63 | 64 | if num_rand_slices is None: 65 | num_slices = kspace.shape[0] 66 | self.examples +=
[(fname, slice_id) for slice_id in range(num_slices)] 67 | else: 68 | slice_ids = list(range(kspace.shape[0])) 69 | self.rng.seed(seed=volume_i) 70 | self.rng.shuffle(slice_ids) 71 | self.examples += [ 72 | (fname, slice_id) for slice_id in slice_ids[:num_rand_slices] 73 | ] 74 | 75 | def __len__(self): 76 | return len(self.examples) 77 | 78 | def __getitem__(self, i): 79 | fname, slice_id = self.examples[i] 80 | with h5py.File(fname, "r") as data: 81 | kspace = data["kspace"][slice_id] 82 | kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1)) 83 | kspace = fastmri.ifftshift(kspace, dim=(0, 1)) 84 | target = torch.ifft(kspace, 2, normalized=False) 85 | target = fastmri.ifftshift(target, dim=(0, 1)) 86 | # Normalize using mean of k-space in training data 87 | target /= 7.072103529760345e-07 88 | kspace /= 7.072103529760345e-07 89 | 90 | # Environment expects numpy arrays. The code above was used with an older 91 | # version of the environment to generate the results of the MICCAI'20 paper. 92 | # So, to keep this consistent with the version in the paper, we convert 93 | # the tensors back to numpy rather than changing the original code.
# [review] singlecoil_knee_data.py -- MICCAI2020Data (rows 18-103).
# __init__ indexes (file, slice) pairs from the HDF5 files in `root`. It can
# optionally keep only:
#   - files whose k-space width (kspace.shape[2]) equals num_cols;
#   - file names listed in splits/knee_singlecoil/<custom_split>.txt;
#   - a random subset of num_volumes files (self.rng, seeded with 1234);
#   - num_rand_slices random slices per volume (rng reseeded with the
#     volume index).
# __getitem__ steps:
#   1. Read one k-space slice and stack real/imag on a trailing axis.
#   2. Inverse-FFT it (with ifftshift before and after) to get the target.
#   3. Divide both by 7.072103529760345e-07.
#   4. Pass the numpy arrays to self.transform.
# NOTE(review): the h5py.File handles opened in __init__ (rows 44 and 61) are
#   never closed explicitly; wrap them in `with h5py.File(...) as data:`.
# NOTE(review): the split path is relative to the current working directory,
#   and torch.ifft is the pre-1.8 FFT API (verify the pinned torch version).
# The test.txt split that follows lists 49 file names; val.txt continues
# past this span.
94 | kspace = kspace.numpy() 95 | target = target.numpy() 96 | return self.transform( 97 | kspace, 98 | torch.zeros(kspace.shape[1]), 99 | target, 100 | dict(data.attrs), 101 | fname.name, 102 | slice_id, 103 | ) 104 | -------------------------------------------------------------------------------- /activemri/data/splits/knee_singlecoil/test.txt: -------------------------------------------------------------------------------- 1 | file1001126.h5 2 | file1001930.h5 3 | file1002436.h5 4 | file1002382.h5 5 | file1001585.h5 6 | file1001506.h5 7 | file1001057.h5 8 | file1002021.h5 9 | file1000831.h5 10 | file1001344.h5 11 | file1000007.h5 12 | file1001938.h5 13 | file1001381.h5 14 | file1001365.h5 15 | file1001458.h5 16 | file1001726.h5 17 | file1000280.h5 18 | file1002145.h5 19 | file1001759.h5 20 | file1000976.h5 21 | file1001064.h5 22 | file1000108.h5 23 | file1000291.h5 24 | file1001689.h5 25 | file1001191.h5 26 | file1000903.h5 27 | file1001440.h5 28 | file1001184.h5 29 | file1001119.h5 30 | file1002351.h5 31 | file1001643.h5 32 | file1001893.h5 33 | file1001968.h5 34 | file1001566.h5 35 | file1001850.h5 36 | file1000660.h5 37 | file1000593.h5 38 | file1001763.h5 39 | file1002546.h5 40 | file1000697.h5 41 | file1000190.h5 42 | file1000273.h5 43 | file1001144.h5 44 | file1000538.h5 45 | file1000635.h5 46 | file1000769.h5 47 | file1001262.h5 48 | file1001851.h5 49 | file1000942.h5 50 | -------------------------------------------------------------------------------- /activemri/data/splits/knee_singlecoil/val.txt: -------------------------------------------------------------------------------- 1 | file1002340.h5 2 | file1000182.h5 3 | file1001655.h5 4 | file1000972.h5 5 | file1001338.h5 6 | file1000476.h5 7 | file1002252.h5 8 | file1000591.h5 9 | file1000858.h5 10 | file1001202.h5 11 | file1002159.h5 12 | file1001163.h5 13 | file1001497.h5 14 | file1000196.h5 15 | file1001331.h5 16 | file1000477.h5 17 | file1001651.h5 18 | file1000464.h5 19 | file1000625.h5 20 |
file1000033.h5 21 | file1000041.h5 22 | file1000000.h5 23 | file1000389.h5 24 | file1001668.h5 25 | file1002451.h5 26 | file1000990.h5 27 | file1001077.h5 28 | file1002389.h5 29 | file1001298.h5 30 | file1002570.h5 31 | file1001834.h5 32 | file1001955.h5 33 | file1001715.h5 34 | file1001444.h5 35 | file1002214.h5 36 | file1002187.h5 37 | file1001170.h5 38 | file1000926.h5 39 | file1000818.h5 40 | file1000702.h5 41 | file1001289.h5 42 | file1000528.h5 43 | file1000759.h5 44 | file1001862.h5 45 | file1001339.h5 46 | file1001090.h5 47 | file1002067.h5 48 | file1001221.h5 49 | -------------------------------------------------------------------------------- /activemri/data/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | activemri.data.transforms.py 8 | ==================================== 9 | Transform functions to process fastMRI data for reconstruction models. 10 | """ 11 | from typing import Tuple, Union 12 | 13 | import fastmri 14 | import fastmri.data.transforms as fastmri_transforms 15 | import numpy as np 16 | import torch 17 | 18 | import activemri.data.singlecoil_knee_data as scknee_data 19 | 20 | TensorType = Union[np.ndarray, torch.Tensor] 21 | 22 | 23 | def add_gaussian_noise(kspace, mean=0., std=1.): 24 | # kspace: 32*32*2 25 | noise = mean + torch.randn(kspace.size()) * std 26 | kspace = kspace + noise 27 | return kspace 28 | 29 | 30 | def to_magnitude(tensor: torch.Tensor, dim: int) -> torch.Tensor: 31 | return (tensor ** 2).sum(dim=dim) ** 0.5 32 | 33 | 34 | def center_crop(x: TensorType, shape: Tuple[int, int]) -> TensorType: 35 | """Center crops a tensor to the desired 2D shape. 36 | 37 | Args: 38 | x(union(``torch.Tensor``, ``np.ndarray``)): The tensor to crop.
39 | Shape should be ``(batch_size, height, width)``. 40 | shape(tuple(int,int)): The desired shape to crop to. 41 | 42 | Returns: 43 | (union(``torch.Tensor``, ``np.ndarray``)): The cropped tensor. 44 | """ 45 | assert len(x.shape) == 3 46 | assert 0 < shape[0] <= x.shape[1] 47 | assert 0 < shape[1] <= x.shape[2] 48 | h_from = (x.shape[1] - shape[0]) // 2 49 | w_from = (x.shape[2] - shape[1]) // 2 50 | w_to = w_from + shape[0] 51 | h_to = h_from + shape[1] 52 | x = x[:, h_from:h_to, w_from:w_to] 53 | return x 54 | 55 | 56 | def ifft_permute_maybe_shift( 57 | x: torch.Tensor, normalized: bool = False, ifft_shift: bool = False 58 | ) -> torch.Tensor: 59 | x = x.permute(0, 2, 3, 1) 60 | y = torch.ifft(x, 2, normalized=normalized) 61 | if ifft_shift: 62 | y = fastmri.ifftshift(y, dim=(1, 2)) 63 | return y.permute(0, 3, 1, 2) 64 | 65 | 66 | def raw_transform_miccai2020(kspace=None, mask=None, **_kwargs): 67 | """Transform to produce input for reconstructor used in `Pineda et al. MICCAI'20 `_. 68 | 69 | Produces a zero-filled reconstruction and a mask that serve as a input to models of type 70 | :class:`activemri.models.cvpr10_reconstructor.CVPR19Reconstructor`. The mask is almost 71 | equal to the mask passed as argument, except that high-frequency padding columns are set 72 | to 1, and the mask is reshaped to be compatible with the reconstructor. 73 | 74 | Args: 75 | kspace(``np.ndarray``): The array containing the k-space data returned by the dataset. 76 | mask(``torch.Tensor``): The masks to apply to the k-space. 77 | 78 | Returns: 79 | tuple: A tuple containing: 80 | - ``torch.Tensor``: The zero-filled reconstructor that will be passed to the 81 | reconstructor. 82 | - ``torch.Tensor``: The mask to use as input to the reconstructor.
# [review] transforms.py rows 23-103. _base_fastmri_unet_transform (row 108
# onward) continues beyond this chunk and is not annotated here.
#
# add_gaussian_noise: returns kspace + (mean + std * standard-normal noise).
#   NOTE(review): torch.randn(kspace.size()) is created on the CPU, so a CUDA
#     `kspace` fails with a device mismatch. torch.randn_like(kspace) would
#     follow the input's device.
#
# to_magnitude: square root of the sum of squares along `dim`.
#
# center_crop: NOTE(review) rows 50-51 swap the crop sizes.
#   - h_from is derived from shape[0], but h_to = h_from + shape[1]; likewise
#     w_to = w_from + shape[0].
#   - Square crops are unaffected. For x of shape (1, 4, 6) and shape=(2, 4),
#     the result is (1, 3, 2) instead of (1, 2, 4).
#   - Fix: h_to = h_from + shape[0] and w_to = w_from + shape[1].
#
# ifft_permute_maybe_shift: moves the channel axis last (presumably the
#   real/imag pair for torch.ifft), applies torch.ifft over the two spatial
#   axes, optionally ifftshifts them, and moves the channel axis back.
#
# raw_transform_miccai2020:
#   1. Forces mask columns START_PADDING:END_PADDING to 1.
#   2. Zero-fills the stacked k-space outside the mask via torch.where.
#   3. Returns (ifftshifted zero-filled image, mask.unsqueeze(1)).
#   NOTE(review): the padding columns are written into the caller's `mask`
#     tensor in place; clone it first if callers reuse the mask.
#   NOTE(review): torch.where receives mask.byte(). Recent PyTorch expects a
#     bool condition (mask.bool()); verify against the pinned version.
#   NOTE(review): docstring nits.
#     - The reST link after "MICCAI'20" has lost its URL.
#     - "cvpr10_reconstructor" presumably means cvpr19 (verify).
#     - "zero-filled reconstructor" should read "reconstruction".
83 | """ 84 | # alter mask to always include the highest frequencies that include padding 85 | mask[ 86 | :, 87 | :, 88 | scknee_data.MICCAI2020Data.START_PADDING : scknee_data.MICCAI2020Data.END_PADDING, 89 | ] = 1 90 | mask = mask.unsqueeze(1) 91 | 92 | all_kspace = [] 93 | for ksp in kspace: 94 | all_kspace.append(torch.from_numpy(ksp).permute(2, 0, 1)) 95 | k_space = torch.stack(all_kspace) 96 | 97 | masked_true_k_space = torch.where( 98 | mask.byte(), 99 | k_space, 100 | torch.tensor(0.0).to(mask.device), 101 | ) 102 | reconstructor_input = ifft_permute_maybe_shift(masked_true_k_space, ifft_shift=True) 103 | return reconstructor_input, mask 104 | 105 | 106 | # Based on 107 | # https://github.com/facebookresearch/fastMRI/blob/master/experimental/unet/unet_module.py 108 | def _base_fastmri_unet_transform( 109 | kspace, 110 | mask, 111 | ground_truth, 112 | attrs, 113 | which_challenge="singlecoil", 114 | ): 115 | kspace = fastmri_transforms.to_tensor(kspace) 116 | 117 | mask = mask[..., : kspace.shape[-2]] # accounting for variable size masks 118 | masked_kspace = kspace * mask.unsqueeze(-1) + 0.0 119 | 120 | # inverse Fourier transform to get zero filled solution 121 | image = fastmri.ifft2c(masked_kspace) 122 | 123 | # crop input to correct size 124 | if ground_truth is not None: 125 | crop_size = (ground_truth.shape[-2], ground_truth.shape[-1]) 126 | else: 127 | crop_size = (attrs["recon_size"][0], attrs["recon_size"][1]) 128 | 129 | # check for FLAIR 203 130 | if image.shape[-2] < crop_size[1]: 131 | crop_size = (image.shape[-2], image.shape[-2]) 132 | 133 | # noinspection PyTypeChecker 134 | image = fastmri_transforms.complex_center_crop(image, crop_size) 135 | 136 | # absolute value 137 | image = fastmri.complex_abs(image) 138 | 139 | # apply Root-Sum-of-Squares if multicoil data 140 | if which_challenge == "multicoil": 141 | image = fastmri.rss(image) 142 | 143 | # normalize input 144 | image, mean, std = fastmri_transforms.normalize_instance(image,
eps=1e-11) 145 | image = image.clamp(-6, 6) 146 | 147 | return image.unsqueeze(0), mean, std 148 | 149 | 150 | def _batched_fastmri_unet_transform( 151 | kspace, mask, ground_truth, attrs, which_challenge="singlecoil" 152 | ): 153 | batch_size = len(kspace) 154 | images, means, stds = [], [], [] 155 | for i in range(batch_size): 156 | image, mean, std = _base_fastmri_unet_transform( 157 | kspace[i], 158 | mask[i], 159 | ground_truth[i], 160 | attrs[i], 161 | which_challenge=which_challenge, 162 | ) 163 | images.append(image) 164 | means.append(mean) 165 | stds.append(std) 166 | return torch.stack(images), torch.stack(means), torch.stack(stds) 167 | 168 | 169 | # noinspection PyUnusedLocal 170 | def fastmri_unet_transform_singlecoil( 171 | kspace=None, mask=None, ground_truth=None, attrs=None, fname=None, slice_id=None 172 | ): 173 | """ 174 | Transform to use as input to fastMRI's Unet model for singlecoil data. 175 | 176 | This is an adapted version of the code found in 177 | `fastMRI `_. 178 | """ 179 | return _batched_fastmri_unet_transform( 180 | kspace, mask, ground_truth, attrs, "singlecoil" 181 | ) 182 | 183 | 184 | # noinspection PyUnusedLocal 185 | def fastmri_unet_transform_multicoil( 186 | kspace=None, mask=None, ground_truth=None, attrs=None, fname=None, slice_id=None 187 | ): 188 | """Transform to use as input to fastMRI's Unet model for multicoil data. 189 | 190 | This is an adapted version of the code found in 191 | `fastMRI `_. 192 | """ 193 | return _batched_fastmri_unet_transform( 194 | kspace, mask, ground_truth, attrs, "multicoil" 195 | ) 196 | -------------------------------------------------------------------------------- /activemri/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | __all__ = [ 7 | "ActiveMRIEnv", 8 | "MICCAI2020Env", 9 | "FastMRIEnv", 10 | "SingleCoilKneeEnv", 11 | "MultiCoilKneeEnv", 12 | "FashionMNISTEnv", 13 | "DicomKneeEnv", 14 | "RealKneeEnv" 15 | ] 16 | 17 | from .envs import ( 18 | ActiveMRIEnv, 19 | FastMRIEnv, 20 | MICCAI2020Env, 21 | MultiCoilKneeEnv, 22 | SingleCoilKneeEnv, 23 | FashionMNISTEnv, 24 | DicomKneeEnv, 25 | RealKneeEnv 26 | ) 27 | -------------------------------------------------------------------------------- /activemri/envs/loupe_envs.py: -------------------------------------------------------------------------------- 1 | import fastmri.data 2 | import pathlib 3 | import numpy as np 4 | import torch 5 | import torch.utils.data 6 | from torchvision import datasets 7 | from torchvision import transforms as TF 8 | from torch.utils.data import DataLoader, Subset, random_split 9 | 10 | import activemri.data.singlecoil_knee_data as scknee_data 11 | import activemri.data.transforms 12 | import activemri.envs.masks 13 | import activemri.envs.util 14 | import fastmri 15 | from activemri.data.real_knee_data import RealKneeData 16 | from activemri.data.real_brain_data import RealBrainData 17 | from typing import ( 18 | Any, 19 | Callable, 20 | Dict, 21 | Iterator, 22 | List, 23 | Mapping, 24 | Optional, 25 | Sequence, 26 | Sized, 27 | Tuple, 28 | Union, 29 | ) 30 | from activemri.envs.envs import ActiveMRIEnv, _env_collate_fn 31 | 32 | def transform( 33 | kspace: torch.Tensor, 34 | mask: torch.Tensor, 35 | target: torch.Tensor, 36 | attrs: List[Dict[str, Any]], 37 | fname: List[str], 38 | slice_id: List[int], 39 | ) -> Tuple: 40 | label = attrs 41 | return torch.from_numpy(kspace), mask, torch.from_numpy(target), label 42 | 43 | class LOUPEDataEnv(): 44 | def __init__(self, options): 45 | self._data_location = options.data_path 46 | self.options = options 47 | 48 | def _setup_data_handlers(self): 49 | train_data, val_data, test_data = self._create_datasets() 50 | 51 | display_data = [val_data[i] for i in 
range(0, len(val_data), len(val_data) // 16)] 52 | 53 | train_loader = DataLoader( 54 | dataset=train_data, 55 | batch_size=self.options.batch_size, 56 | shuffle=True, 57 | num_workers=8, 58 | pin_memory=True, 59 | drop_last=True 60 | ) 61 | val_loader = DataLoader( 62 | dataset=val_data, 63 | batch_size=self.options.batch_size, 64 | num_workers=8, 65 | shuffle=False, 66 | pin_memory=True 67 | ) 68 | test_loader = DataLoader( 69 | dataset=test_data, 70 | batch_size=self.options.batch_size, 71 | num_workers=8, 72 | shuffle=False, 73 | pin_memory=True 74 | ) 75 | display_loader = DataLoader( 76 | dataset=display_data, 77 | batch_size=16, 78 | num_workers=8, 79 | shuffle=False, 80 | pin_memory=True 81 | ) 82 | return train_loader, val_loader, test_loader, display_loader 83 | 84 | 85 | class LOUPEActiveMRIEnv(LOUPEDataEnv): 86 | def __init__(self, options): 87 | super().__init__(options) 88 | 89 | def _create_datasets(self): 90 | root_path = pathlib.Path(self._data_location) 91 | train_path = root_path / "knee_singlecoil_train" 92 | val_and_test_path = root_path / "knee_singlecoil_val" 93 | 94 | train_data = scknee_data.MICCAI2020Data( 95 | train_path, 96 | ActiveMRIEnv._void_transform, 97 | num_cols=self.options.resolution[1], 98 | ) 99 | val_data = scknee_data.MICCAI2020Data( 100 | val_and_test_path, 101 | ActiveMRIEnv._void_transform, 102 | custom_split="val", 103 | num_cols=self.options.resolution[1], 104 | ) 105 | test_data = scknee_data.MICCAI2020Data( 106 | val_and_test_path, 107 | ActiveMRIEnv._void_transform, 108 | custom_split="test", 109 | num_cols=self.options.resolution[1], 110 | ) 111 | return train_data, val_data, test_data 112 | 113 | class LOUPERealKspaceEnv(LOUPEDataEnv): 114 | def __init__(self, options): 115 | super().__init__(options) 116 | 117 | def _create_datasets(self): 118 | root_path = pathlib.Path(self._data_location) 119 | train_path = root_path / "knee_singlecoil_train" 120 | val_path = root_path / "knee_singlecoil_val" 121 | test_path = 
root_path / "knee_singlecoil_val" 122 | 123 | train_data = RealKneeData( 124 | train_path, 125 | self.options.resolution, 126 | ActiveMRIEnv._void_transform, 127 | self.options.noise_type, 128 | self.options.noise_level, 129 | random_rotate=self.options.random_rotate 130 | ) 131 | val_data = RealKneeData( 132 | val_path, 133 | self.options.resolution, 134 | ActiveMRIEnv._void_transform, 135 | self.options.noise_type, 136 | self.options.noise_level, 137 | custom_split='val', 138 | random_rotate=self.options.random_rotate 139 | ) 140 | test_data = RealKneeData( 141 | test_path, 142 | self.options.resolution, 143 | ActiveMRIEnv._void_transform, 144 | self.options.noise_type, 145 | self.options.noise_level, 146 | custom_split='test', 147 | random_rotate=self.options.random_rotate 148 | ) 149 | 150 | return train_data, val_data, test_data 151 | 152 | class LoupeBrainEnv(LOUPEDataEnv): 153 | def __init__(self, options): 154 | super().__init__(options) 155 | 156 | def _create_datasets(self): 157 | root_path = pathlib.Path(self._data_location) 158 | train_path = root_path / "train_no_kspace" 159 | val_path = root_path / "val_no_kspace" 160 | test_path = root_path / "test_no_kspace" 161 | 162 | train_data = RealBrainData( 163 | train_path, 164 | self.options.resolution, 165 | ActiveMRIEnv._void_transform, 166 | self.options.noise_type, 167 | self.options.noise_level, 168 | random_rotate=self.options.random_rotate 169 | ) 170 | val_data = RealBrainData( 171 | val_path, 172 | self.options.resolution, 173 | ActiveMRIEnv._void_transform, 174 | self.options.noise_type, 175 | self.options.noise_level, 176 | custom_split='val', 177 | random_rotate=self.options.random_rotate 178 | ) 179 | test_data = RealBrainData( 180 | test_path, 181 | self.options.resolution, 182 | ActiveMRIEnv._void_transform, 183 | self.options.noise_type, 184 | self.options.noise_level, 185 | custom_split='test', 186 | random_rotate=self.options.random_rotate 187 | ) 188 | 189 | return train_data, val_data, 
test_data 190 | 191 | class LOUPEFashionMNISTEnv(LOUPEDataEnv): 192 | def __init__(self, options): 193 | super().__init__(options) 194 | 195 | def get_same_index(target, label): 196 | label_indices = [] 197 | 198 | for i in range(len(target)): 199 | if target[i] in label: 200 | label_indices.append(i) 201 | 202 | return label_indices 203 | 204 | def _create_datasets(self): 205 | train_val_set = FashionMNISTData(args=self.options, root=self._data_location, transform=transform, 206 | noise_type=self.options.noise_type, noise_level=self.options.noise_level, custom_split='train', 207 | label_range=self.options.label_range) 208 | test_set = FashionMNISTData(args=self.options, root=self._data_location, transform=transform, 209 | noise_type=self.options.noise_type, noise_level=self.options.noise_level, custom_split='test', 210 | label_range=self.options.label_range) 211 | 212 | 213 | val_set_size = int(len(train_val_set)/4) 214 | train_set_size = len(train_val_set) - val_set_size 215 | train_set, val_set = random_split(train_val_set, [train_set_size, val_set_size], generator=torch.Generator().manual_seed(42)) 216 | 217 | 218 | if self.options.sample_rate < 1: 219 | # create a random subset 220 | train_set = random_split(train_set, [int(self.options.sample_rate * len(train_set)), 221 | len(train_set)-int(self.options.sample_rate * len(train_set))], generator=torch.Generator().manual_seed(42))[0] 222 | 223 | val_set = random_split(val_set, [int(self.options.sample_rate * len(val_set)), 224 | len(val_set)-int(self.options.sample_rate * len(val_set))], generator=torch.Generator().manual_seed(42))[0] 225 | 226 | return train_set, val_set, test_set 227 | 228 | class LOUPESyntheticEnv(LOUPEDataEnv): 229 | def __init__(self, options): 230 | super().__init__(options) 231 | 232 | def _create_datasets(self): 233 | train_set = SyntheticData(root=self._data_location, split='train') 234 | val_set = SyntheticData(root=self._data_location, split='val') 235 | test_set = 
SyntheticData(root=self._data_location, split='test') 236 | 237 | return train_set, val_set, test_set 238 | -------------------------------------------------------------------------------- /activemri/envs/masks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | activemri.envs.masks.py 8 | ==================================== 9 | Utilities to generate and manipulate active acquisition masks. 10 | """ 11 | from typing import Any, Dict, List, Optional, Sequence, Tuple 12 | 13 | import fastmri 14 | import numpy as np 15 | import torch 16 | 17 | 18 | def update_masks_from_indices( 19 | masks: torch.Tensor, indices: Sequence[int] 20 | ) -> torch.Tensor: 21 | assert masks.shape[0] == len(indices) 22 | new_masks = masks.clone() 23 | for i in range(len(indices)): 24 | new_masks[i, ..., indices[i]] = 1 25 | return new_masks 26 | 27 | 28 | def check_masks_complete(masks: torch.Tensor) -> List[bool]: 29 | done = [] 30 | for mask in masks: 31 | done.append(mask.bool().all().item()) 32 | return done 33 | 34 | 35 | def sample_low_frequency_mask( 36 | mask_args: Dict[str, Any], 37 | kspace_shapes: List[Tuple[int, ...]], 38 | rng: np.random.RandomState, 39 | attrs: Optional[List[Dict[str, Any]]] = None, 40 | ) -> torch.Tensor: 41 | """Samples low frequency masks. 42 | 43 | Returns masks that contain some number of the lowest k-space frequencies active. 44 | The number of frequencies doesn't have to be the same for all masks in the batch, and 45 | it can also be a random number, depending on the given ``mask_args``. Active columns 46 | will be represented as 1s in the mask, and inactive columns as 0s. 47 | 48 | The distribution and shape of the masks can be controlled by ``mask_args``. 
This is a 49 | dictionary with the following keys: 50 | 51 | - *"max_width"(int)*: The maximum width of the masks. 52 | - *"min_cols"(int)*: The minimum number of low frequencies columns to activate per side. 53 | - *"max_cols"(int)*: The maximum number of low frequencies columns to activate 54 | per side (inclusive). 55 | - *"width_dim"(int)*: Indicates which of the dimensions in ``kspace_shapes`` 56 | corresponds to the k-space width. 57 | - *"centered"(bool)*: Specifies if the low frequencies are in the center of the 58 | k-space (``True``) or on the edges (``False``). 59 | - *"apply_attrs_padding"(optional(bool))*: If ``True``, the function will read 60 | keys ``"padding_left"`` and ``"padding_right"`` from ``attrs`` and set all 61 | corresponding high-frequency columns to 1. 62 | 63 | The number of 1s in the effective region of the mask (see next paragraph) is sampled 64 | between ``mask_args["min_cols"]`` and ``mask_args["max_cols"]`` (inclusive). 65 | The number of dimensions for the mask tensor will be ``mask_args["width_dim"] + 2``. 66 | The size will be ``[batch_size, 1, ..., 1, mask_args["max_width"]]``. For example, with 67 | ``mask_args["width_dim"] = 1`` and ``mask_args["max_width"] = 368``, output tensor 68 | has shape ``[batch_size, 1, 368]``. 69 | 70 | This function supports simultaneously sampling masks for k-space of different number of 71 | columns. This is controlled by argument ``kspace_shapes``. From this list, the function will 72 | obtain 1) ``batch_size = len(kspace_shapes``), and 2) the width of the k-spaces for 73 | each element in the batch. The i-th mask will have 74 | ``kspace_shapes[item][mask_args["width_dim"]]`` 75 | *effective* columns. 76 | 77 | 78 | Note: 79 | The mask tensor returned will always have 80 | ``mask_args["max_width"]`` columns. However, for any element ``i`` 81 | s.t. ``kspace_shapes[i][mask_args["width_dim"]] < mask_args["max_width"]``, the 82 | function will then pad the extra k-space columns with 1s. 
The rest of the columns 83 | will be filled out as if the mask has the same width as that indicated by 84 | ``kspace_shape[i]``. 85 | 86 | Args: 87 | mask_args(dict(str,any)): Specifies configuration options for the masks, as explained 88 | above. 89 | 90 | kspace_shapes(list(tuple(int,...))): Specifies the shapes of the k-space data on 91 | which this mask will be applied, as explained above. 92 | 93 | rng(``np.random.RandomState``): A random number generator to sample the masks. 94 | 95 | attrs(dict(str,int)): Used to determine any high-frequency padding. It must contain 96 | keys ``"padding_left"`` and ``"padding_right"``. 97 | 98 | Returns: 99 | ``torch.Tensor``: The generated low frequency masks. 100 | 101 | """ 102 | batch_size = len(kspace_shapes) 103 | num_cols = [shape[mask_args["width_dim"]] for shape in kspace_shapes] 104 | mask = torch.zeros(batch_size, mask_args["max_width"]) 105 | num_low_freqs = rng.randint( 106 | mask_args["min_cols"], mask_args["max_cols"] + 1, size=batch_size 107 | ) 108 | for i in range(batch_size): 109 | # If padding needs to be accounted for, only add low frequency lines 110 | # beyond the padding 111 | if attrs and mask_args.get("apply_attrs_padding", False): 112 | padding_left = attrs[i]["padding_left"] 113 | padding_right = attrs[i]["padding_right"] 114 | else: 115 | padding_left, padding_right = 0, num_cols[i] 116 | 117 | pad = (num_cols[i] - 2 * num_low_freqs[i] + 1) // 2 118 | mask[i, pad : pad + 2 * num_low_freqs[i]] = 1 119 | mask[i, :padding_left] = 1 120 | mask[i, padding_right : num_cols[i]] = 1 121 | 122 | if not mask_args["centered"]: 123 | mask[i, : num_cols[i]] = fastmri.ifftshift(mask[i, : num_cols[i]]) 124 | mask[i, num_cols[i] : mask_args["max_width"]] = 1 125 | 126 | mask_shape = [batch_size] + [1] * (mask_args["width_dim"] + 1) 127 | mask_shape[mask_args["width_dim"] + 1] = mask_args["max_width"] 128 | return mask.view(*mask_shape) 129 | 
-------------------------------------------------------------------------------- /activemri/envs/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import json 8 | import pathlib 9 | 10 | from typing import Dict, Tuple 11 | 12 | import numpy as np 13 | import skimage.metrics 14 | import torch 15 | 16 | 17 | def get_user_dir() -> pathlib.Path: 18 | # return pathlib.Path.home() / ".activemri" 19 | return pathlib.Path.cwd() / ".activemri" 20 | 21 | 22 | def maybe_create_datacache_dir() -> pathlib.Path: 23 | datacache_dir = get_user_dir() / "__datacache__" 24 | if not datacache_dir.is_dir(): 25 | datacache_dir.mkdir() 26 | return datacache_dir 27 | 28 | 29 | def get_defaults_json() -> Tuple[Dict[str, str], str]: 30 | defaults_path = get_user_dir() / "defaults.json" 31 | if not pathlib.Path.exists(defaults_path): 32 | parent = defaults_path.parents[0] 33 | parent.mkdir(exist_ok=True) 34 | content = {"data_location": "", "saved_models_dir": ""} 35 | with defaults_path.open("w", encoding="utf-8") as f: 36 | json.dump(content, f) 37 | else: 38 | with defaults_path.open("r", encoding="utf-8") as f: 39 | content = json.load(f) 40 | return content, str(defaults_path) 41 | 42 | 43 | def import_object_from_str(classname: str): 44 | the_module, the_object = classname.rsplit(".", 1) 45 | the_object = classname.split(".")[-1] 46 | module = importlib.import_module(the_module) 47 | return getattr(module, the_object) 48 | 49 | 50 | def compute_ssim(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray: 51 | ssims = [] 52 | for i in range(xs.shape[0]): 53 | ssim = skimage.metrics.structural_similarity( 54 | xs[i].cpu().numpy(), 55 | ys[i].cpu().numpy(), 56 | data_range=ys[i].cpu().numpy().max(), 57 | ) 58 | ssims.append(ssim) 59 | 
return np.array(ssims, dtype=np.float32) 60 | 61 | 62 | def compute_psnr(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray: 63 | psnrs = [] 64 | for i in range(xs.shape[0]): 65 | psnr = skimage.metrics.peak_signal_noise_ratio( 66 | xs[i].cpu().numpy(), 67 | ys[i].cpu().numpy(), 68 | data_range=ys[i].cpu().numpy().max(), 69 | ) 70 | psnrs.append(psnr) 71 | return np.array(psnrs, dtype=np.float32) 72 | 73 | 74 | def compute_mse(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray: 75 | dims = tuple(range(1, len(xs.shape))) 76 | return np.mean((ys.cpu().numpy() - xs.cpu().numpy()) ** 2, axis=dims) 77 | 78 | def compute_nmse(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray: 79 | ys_numpy = ys.cpu().numpy() 80 | nmses = [] 81 | for i in range(xs.shape[0]): 82 | x = xs[i].cpu().numpy() 83 | y = ys_numpy[i] 84 | nmse = np.linalg.norm(y - x) ** 2 / np.linalg.norm(y) ** 2 85 | nmses.append(nmse) 86 | return np.array(nmses, dtype=np.float32) 87 | 88 | def compute_mse_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor: 89 | dims = tuple(range(1, len(xs.shape))) 90 | return torch.mean((ys - xs) ** 2, dim=dims) 91 | 92 | def compute_nmse_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor: 93 | dims = tuple(range(1, len(xs.shape))) 94 | nmse = torch.linalg.norm(ys-xs, dim=dims)**2 / torch.linalg.norm(ys, dim=dims)**2 95 | 96 | return nmse 97 | 98 | from torch import nn 99 | import torch.nn.functional as F 100 | 101 | class SSIMLoss(nn.Module): 102 | """ 103 | SSIM loss module. 104 | """ 105 | 106 | def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03): 107 | """ 108 | Args: 109 | win_size: Window size for SSIM calculation. 110 | k1: k1 parameter for SSIM calculation. 111 | k2: k2 parameter for SSIM calculation. 
112 | """ 113 | super().__init__() 114 | self.win_size = win_size 115 | self.k1, self.k2 = k1, k2 116 | self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size ** 2) 117 | NP = win_size ** 2 118 | self.cov_norm = NP / (NP - 1) 119 | 120 | def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor): 121 | assert isinstance(self.w, torch.Tensor) 122 | 123 | data_range = data_range[:, None, None, None] 124 | C1 = (self.k1 * data_range) ** 2 125 | C2 = (self.k2 * data_range) ** 2 126 | ux = F.conv2d(X, self.w) # typing: ignore 127 | uy = F.conv2d(Y, self.w) # 128 | uxx = F.conv2d(X * X, self.w) 129 | uyy = F.conv2d(Y * Y, self.w) 130 | uxy = F.conv2d(X * Y, self.w) 131 | vx = self.cov_norm * (uxx - ux * ux) 132 | vy = self.cov_norm * (uyy - uy * uy) 133 | vxy = self.cov_norm * (uxy - ux * uy) 134 | A1, A2, B1, B2 = ( 135 | 2 * ux * uy + C1, 136 | 2 * vxy + C2, 137 | ux ** 2 + uy ** 2 + C1, 138 | vx + vy + C2, 139 | ) 140 | D = B1 * B2 141 | S = (A1 * A2) / D 142 | 143 | dims = tuple(range(1, len(X.shape))) 144 | 145 | return S.mean(dim=dims) 146 | 147 | SSIM = SSIMLoss() 148 | 149 | def compute_ssim_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor: 150 | global SSIM 151 | SSIM = SSIM.to(xs.device) 152 | data_range = [y.max() for y in ys] 153 | data_range = torch.stack(data_range, dim=0) 154 | 155 | return SSIM(xs, ys, data_range=data_range.detach()) 156 | 157 | def compute_psnr_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor: 158 | mse = compute_mse_torch(xs, ys) 159 | data_range = [y.max() for y in ys] 160 | data_range = torch.stack(data_range, dim=0) 161 | 162 | return 10 * torch.log10((data_range ** 2) / mse) 163 | 164 | def compute_gaussian_nll_loss(reconstruction, target, logvar): 165 | l2 = F.mse_loss(reconstruction, target, reduce=False) 166 | # Clip logvar to make variance in [0.0001, 0.1], for numerical stability 167 | 168 | logvar = logvar.clamp(min=-9.2, max=1.609) 169 | one_over_var = torch.exp(-logvar) 
170 | 171 | 172 | assert len(l2) == len(logvar) 173 | return 0.5 * (one_over_var * l2 + logvar) 174 | -------------------------------------------------------------------------------- /docs/1d_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/docs/1d_animation.gif -------------------------------------------------------------------------------- /docs/2d_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/docs/2d_animation.gif -------------------------------------------------------------------------------- /docs/GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | ## Getting Started with SeqMRI 2 | 3 | ### Prerequisite 4 | 5 | - Follow [INSTALL.md](INSTALL.md) to install all required libraries. 6 | - Download the FastMRI single coil knee data from [FastMRI](https://fastmri.med.nyu.edu/) 7 | 8 | ### Download data and organise as follows 9 | 10 | ``` 11 | # For knee dataset 12 | └── datasets 13 | ├── knee 14 | ├── knee_singlecoil_train 15 | ├── knee_singlecoil_val 16 | ``` 17 | 18 | ### Train & Evaluate in Command Line 19 | 20 | #### Loupe Training 21 | 22 | Please refer to [train_loupe.sh](../examples/loupe/train_loupe.sh) for details of each arguments. 
23 | 24 | ```bash 25 | # 4x line constrained sampling 26 | bash examples/loupe/train_loupe.sh 1 4 ssim real-knee 4 5e-5 10 0.5 128 0 cuda:0 1 0 0 0 0 27 | 28 | # 4x 2d point sampling 29 | bash examples/loupe/train_loupe.sh 1 22 ssim real-knee 4 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0 30 | 31 | # 8x 2d point sampling 32 | bash examples/loupe/train_loupe.sh 1 16 ssim real-knee 8 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0 33 | 34 | # 16x 2d point sampling 35 | bash examples/loupe/train_loupe.sh 1 12 ssim real-knee 16 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0 36 | ``` 37 | 38 | ```bash 39 | # 4x line constrained sampling 40 | bash examples/loupe/train_loupe.sh 1 4 ssim brain 4 5e-5 10 0.5 128 0 cuda:0 1 0 0 0 0 41 | 42 | # 4x 2d point sampling 43 | bash examples/loupe/train_loupe.sh 1 22 ssim brain 4 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0 44 | 45 | ``` 46 | 47 | 48 | #### Loupe Evaluation 49 | 50 | 51 | ```bash 52 | bash examples/loupe/test_loupe.sh EXP_DIR real-knee 53 | ``` 54 | 55 | Remember to replace EXP_DIR with the path to the directory that contains the saved checkpoint. 56 | 57 | 58 | #### Sequential Sampling Training 59 | 60 | ```bash 61 | # 4x line constrained sampling 62 | bash examples/sequential/train_sequential.sh NUM_STEP 1 4 ssim real-knee 4 5e-5 cuda:0 10 0.5 128 0 1 63 | 64 | # 4x 2d point sampling 65 | bash examples/sequential/train_sequential.sh NUM_STEP 1 22 ssim real-knee 4 1e-3 cuda:0 10 0.5 128 0 0 66 | 67 | # 8x 2d point sampling 68 | bash examples/sequential/train_sequential.sh NUM_STEP 1 16 ssim real-knee 8 1e-3 cuda:0 10 0.5 128 0 0 69 | 70 | # 16x 2d point sampling 71 | bash examples/sequential/train_sequential.sh NUM_STEP 1 12 ssim real-knee 16 1e-3 cuda:0 10 0.5 128 0 0 72 | ``` 73 | 74 | Remember to set NUM_STEP to one of 1, 2, or 4 for sequential sampling.
75 | 76 | #### Sequential Sampling Evaluation 77 | 78 | ```bash 79 | bash examples/sequential/test_sequential.sh EXP_DIR real-knee 80 | ``` 81 | 82 | ### Model Zoo 83 | 84 | #### Line-constrained Sampling 85 | 86 | | Model | Acceleration | SSIM | Link | 87 | |---------|---------------|------|------| 88 | | Loupe | 4x | 89.5 | [URL](https://drive.google.com/drive/folders/1A-JFRd5KJ_HoCd2gYePjln67YzcsTiK5?usp=sharing) | 89 | | Seq1 | 4x | 90.8 | [URL](https://drive.google.com/drive/folders/1vcIaIdSnlDPElbQm8kusBOfxR8FfMlzc?usp=sharing) | 90 | | Seq4 | 4x | 91.2 | [URL](https://drive.google.com/drive/folders/1Y_fvnne5Gx0zaXFC0ANZYnlun7CeJ2Kv?usp=sharing) | 91 | 92 | 93 | #### 2D Point Sampling 94 | 95 | | Model | Acceleration | SSIM | Link | 96 | |---------|---------------|------|------| 97 | | Loupe | 4x | 92.4 | [URL](https://drive.google.com/drive/folders/1cTpc1V8EuLVyZ4iy3EIW_XhEzgmiecgN?usp=sharing) | 98 | | Seq1 | 4x | 92.7 | [URL](https://drive.google.com/drive/folders/1ptKDYk7Dbff9kOJBXUPkpmLqoNoPA4_z?usp=sharing) | 99 | | Seq4 | 4x | 92.9 | [URL](https://drive.google.com/drive/folders/1KG8vzruVlJkxlyywZUXDkQdXGyFJzaNB?usp=sharing) | 100 | -------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | This installation guide shows you how to set up the environment for running our code using conda or Singularity container.
3 | 4 | First, clone the SeqMRI repository 5 | ``` 6 | git clone https://github.com/tianweiy/SeqMRI.git 7 | cd SeqMRI 8 | ``` 9 | Then create and activate a new conda environment 10 | ``` 11 | conda create --name activemri python=3.7 12 | conda activate activemri 13 | ``` 14 | Install PyTorch 15 | ``` 16 | conda install pytorch=1.7.1 torchvision cudatoolkit=10.2 -c pytorch 17 | ``` 18 | Install all requirements 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | ## Singularity Container 23 | To use the Singularity Container, run the following commands before creating the virtual environment. 24 | 25 | Login and enter API access token 26 | ``` 27 | singularity remote login 28 | ``` 29 | Build the image to a .sif file 30 | ``` 31 | singularity build --remote active-mri.sif active-mri.def 32 | ``` 33 | Run a singularity shell 34 | ``` 35 | singularity shell --nv active-mri.sif 36 | ``` 37 | Then proceed with the installation instructions above. -------------------------------------------------------------------------------- /examples/loupe/test_loupe.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | EXP_DIR=$1 4 | DATASET_NAME=$2 5 | 6 | python examples/loupe/train_loupe.py \ 7 | --exp-dir ${EXP_DIR} \ 8 | --dataset-name ${DATASET_NAME} \ 9 | --model LOUPE \ 10 | --input_chans 2 \ 11 | --test -------------------------------------------------------------------------------- /examples/loupe/train_loupe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import pathlib 8 | import random 9 | import numpy as np 10 | import torch 11 | import uuid 12 | import activemri.envs.loupe_envs as loupe_envs 13 | from activemri.baselines.non_rl import NonRLTrainer, NonRLTester 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser(description='MRI Reconstruction Example') 20 | parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers') 21 | parser.add_argument('--num-step', type=int, default=2, help='Number of LSTM iterations') 22 | parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability') 23 | parser.add_argument('--num-chans', type=int, default=64, help='Number of U-Net channels') 24 | 25 | parser.add_argument('--batch-size', default=16, type=int, help='Mini batch size') 26 | parser.add_argument('--num-epochs', type=int, default=50, help='Number of training epochs') 27 | parser.add_argument('--noise-type', type=str, default='none', help='Type of additive noise to measurements') 28 | parser.add_argument('--noise-level', type=float, default=0, help='Noise level') 29 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 30 | parser.add_argument('--lr-step-size', type=int, default=40, 31 | help='Period of learning rate decay') 32 | parser.add_argument('--lr-gamma', type=float, default=0.1, 33 | help='Multiplicative factor of learning rate decay') 34 | parser.add_argument('--weight-decay', type=float, default=0., 35 | help='Strength of weight decay regularization') 36 | parser.add_argument('--report-interval', type=int, default=100, help='Period of loss reporting') 37 | parser.add_argument('--data-parallel', action='store_true', 38 | help='If set, use multiple GPUs using data parallelism') 39 | 
parser.add_argument('--device', type=str, default='cuda', 40 | help='Which device to train on. Set to "cuda" to use the GPU') 41 | parser.add_argument('--exp-dir', type=pathlib.Path, required=True, 42 | help='Path where model and results should be saved') 43 | parser.add_argument('--checkpoint1', type=str, 44 | help='Path to an existing checkpoint. Used along with "--resume"') 45 | parser.add_argument('--entropy_weight', type=float, default=0.0, 46 | help='weight for the entropy/diversity loss') 47 | parser.add_argument('--recon_weight', type=float, default=1.0, 48 | help='weight for the reconsturction loss') 49 | parser.add_argument('--sparsity_weight', type=float, default=0.0, 50 | help='weight for the sparsity loss') 51 | parser.add_argument('--save-model', type=bool, default=False, help='save model every iteration or not') 52 | 53 | parser.add_argument('--seed', default=42, type=int, help='Seed for random number generators') 54 | parser.add_argument('--resolution', default=[128, 128], nargs='+', type=int, help='Resolution of images') 55 | 56 | parser.add_argument('--dataset-name', type=str, choices=['fashion-mnist', 'dicom-knee', 'real-knee', 'brain'], 57 | required=True, help='name of the dataset') 58 | parser.add_argument('--sample-rate', type=float, default=1., 59 | help='Fraction of total volumes to include') 60 | 61 | # Mask parameters 62 | parser.add_argument('--accelerations', nargs='+', default=[4], type=float, 63 | help='Ratio of k-space columns to be sampled. If multiple values are ' 64 | 'provided, then one of those is chosen uniformly at random for ' 65 | 'each volume.') 66 | parser.add_argument('--label_range', nargs='+', type=int, help='train using images of specific class') 67 | parser.add_argument('--model', type=str, help='name of the model to run', required=True) 68 | parser.add_argument('--input_chans', type=int, choices=[1, 2], required=True, help='number of input channels. 
One for real image, 2 for compelx image') 69 | parser.add_argument('--output_chans', type=int, default=1, help='number of output channels. One for real image') 70 | parser.add_argument('--line-constrained', type=int, default=0) 71 | parser.add_argument('--unet', action='store_true') 72 | parser.add_argument('--conjugate_mask', action='store_true', help='force loupe model to use conjugate symmetry.') 73 | parser.add_argument('--bi-dir', type=int, default=0) 74 | parser.add_argument('--loss_type', type=str, choices=['l1', 'ssim', 'psnr'], default='l1') 75 | parser.add_argument('--test_visual_frequency', type=int, default=1000) 76 | parser.add_argument('--test', action='store_true') 77 | parser.add_argument('--preselect', type=int, default=0) 78 | parser.add_argument('--preselect_num', type=int, default=2) 79 | parser.add_argument('--random_rotate', type=int, default=0) 80 | parser.add_argument('--random_baseline', type=int, default=0) 81 | parser.add_argument('--poisson', type=int, default=0) 82 | parser.add_argument('--spectrum', type=int, default=0) 83 | parser.add_argument("--equispaced", type=int, default=0) 84 | 85 | 86 | args = parser.parse_args() 87 | args.equispaced = args.equispaced > 0 88 | args.spectrum = args.spectrum > 0 89 | args.poisson = args.poisson > 0 90 | args.random = args.random_baseline > 0 91 | args.random_rotate = args.random_rotate > 0 92 | args.kspace_weight = 0 93 | args.line_constrained = args.line_constrained > 0 94 | 95 | if args.checkpoint1 is not None: 96 | args.resume = True 97 | else: 98 | args.resume = False 99 | 100 | noise_str = '' 101 | if args.noise_type is 'none': 102 | noise_str = '_no_noise_' 103 | else: 104 | noise_str = '_' + args.noise_type + str(args.noise_level) + '_' 105 | 106 | if args.preselect > 0: 107 | args.preselect = True 108 | else: 109 | args.preselect = False 110 | 111 | if args.bi_dir > 0 : 112 | args.bi_dir = True 113 | else: 114 | args.bi_dir = False 115 | 116 | if str(args.exp_dir) is 'auto': 117 | 
args.exp_dir =('checkpoints/'+args.dataset_name + '_' + str(float(args.accelerations[0])) + 118 | 'x_' + args.model + '_bi_dir_{}'.format(args.bi_dir)+ '_preselect_{}'.format(args.preselect) + 119 | noise_str + 'lr=' + str(args.lr) + '_bs=' + str(args.batch_size) + '_loss_type='+args.loss_type + 120 | '_epochs=' + str(args.num_epochs)) 121 | 122 | args.exp_dir = pathlib.Path(args.exp_dir+'_uuid_'+uuid.uuid4().hex.upper()[0:6]) 123 | 124 | print('save logs to {}'.format(args.exp_dir)) 125 | 126 | 127 | args.visualization_dir = args.exp_dir / 'visualizations' 128 | 129 | if args.test: 130 | args.batch_size = 1 131 | 132 | if args.dataset_name == 'real-knee': 133 | args.data_path = 'datasets/knee' 134 | # args.resolution = [128, 128] 135 | env = loupe_envs.LOUPERealKspaceEnv(args) 136 | elif args.dataset_name == 'brain': 137 | args.data_path = 'datasets/brain' 138 | env = loupe_envs.LoupeBrainEnv(args) 139 | else: 140 | raise NotImplementedError 141 | 142 | # set random seeds 143 | random.seed(args.seed) 144 | np.random.seed(args.seed) 145 | torch.manual_seed(args.seed) 146 | 147 | if args.test: 148 | policy = NonRLTester(env, args.exp_dir, args, None) 149 | else: 150 | policy = NonRLTrainer(args, env, torch.device(args.device)) 151 | policy() 152 | -------------------------------------------------------------------------------- /examples/loupe/train_loupe.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | PRE_SELECT=$1 # do we preselect center measurements before future sampling, should be True for all experiments 4 | PRE_SELECT_NUM=$2 # number of preselect line / pixels in 1d / 2d sampling case 5 | LOSS_TYPE=$3 # loss used for training. 
Options: [l1, ssim, psnr] 6 | DATASET_NAME=$4 # real-knee 7 | ACC=$5 # acceleration ratio 8 | LR=$6 # learning rate 9 | LR_STEP=${7} # epoch for learning rate decay 10 | GAMMA=${8} # learning rate decay ratio 11 | RESOLUTION=${9} # image resolution, applied to both dimensions (e.g. 128) 12 | ROTATE=${10} # rotate the k-space and corresponding image target 13 | DEVICE=${11} # GPU device id 14 | LINE_CONSTRAINED=${12} # use 1d sampling 15 | RANDOM_BASELINE=${13} # random sampling baseline 16 | POISSON=${14} # poisson sampling baseline 17 | SPECTRUM=${15} # baseline flag, forwarded as --spectrum 18 | EQUISPACED=${16} # baseline flag, forwarded as --equispaced 19 | 20 | python examples/loupe/train_loupe.py \ 21 | --exp-dir auto \ 22 | --dataset-name ${DATASET_NAME} \ 23 | --accelerations ${ACC} \ 24 | --model LOUPE \ 25 | --input_chans 2 \ 26 | --line-constrained ${LINE_CONSTRAINED} \ 27 | --batch-size 16 \ 28 | --noise-type gaussian \ 29 | --loss_type "${LOSS_TYPE}" \ 30 | --save-model True \ 31 | --preselect ${PRE_SELECT} \ 32 | --preselect_num ${PRE_SELECT_NUM} \ 33 | --lr ${LR} \ 34 | --lr-step-size ${LR_STEP} \ 35 | --lr-gamma ${GAMMA} \ 36 | --resolution ${RESOLUTION} ${RESOLUTION} \ 37 | --random_rotate ${ROTATE} \ 38 | --device ${DEVICE} \ 39 | --random_baseline ${RANDOM_BASELINE} \ 40 | --poisson ${POISSON} \ 41 | --spectrum ${SPECTRUM} \ 42 | --equispaced ${EQUISPACED} 43 | -------------------------------------------------------------------------------- /examples/sequential/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifiers import *  # NOTE(review): no classifiers module is listed under examples/sequential in the repo tree, so importing this package would likely fail — TODO confirm -------------------------------------------------------------------------------- /examples/sequential/test_sequential.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | EXP_DIR=$1 4 | DATASET_NAME=$2 5 | 6 | python examples/sequential/train_sequential.py \ 7 | --exp-dir ${EXP_DIR} \ 8 | --dataset-name ${DATASET_NAME} \ 9 | --model SequentialSampling \ 10 | --input_chans 2 \ 11 | --test
-------------------------------------------------------------------------------- /examples/sequential/train_sequential.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import pathlib 8 | import random 9 | import numpy as np 10 | import torch 11 | import os 12 | import activemri.envs.loupe_envs as loupe_envs 13 | from activemri.baselines.non_rl import NonRLTrainer, NonRLTester 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser(description='MRI Reconstruction Example') 20 | parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers') 21 | parser.add_argument('--num-step', type=int, default=2, help='Number of LSTM iterations') 22 | parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability') 23 | parser.add_argument('--num-chans', type=int, default=64, help='Number of U-Net channels') 24 | 25 | parser.add_argument('--batch-size', default=4, type=int, help='Mini batch size') 26 | parser.add_argument('--num-epochs', type=int, default=50, help='Number of training epochs') 27 | parser.add_argument('--noise-type', type=str, default='none', help='Type of additive noise to measurements') 28 | parser.add_argument('--noise-level', type=float, default=5e-5, help='Noise level') 29 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 30 | parser.add_argument('--lr-step-size', type=int, default=40, 31 | help='Period of learning rate decay') 32 | parser.add_argument('--lr-gamma', type=float, default=0.1, 33 | help='Multiplicative factor of learning rate decay') 34 | parser.add_argument('--weight-decay', type=float, default=0., 35 | help='Strength of weight decay 
regularization') 36 | parser.add_argument('--report-interval', type=int, default=100, help='Period of loss reporting') 37 | parser.add_argument('--data-parallel', action='store_true', 38 | help='If set, use multiple GPUs using data parallelism') 39 | parser.add_argument('--device', type=str, default='cuda', 40 | help='Which device to train on. Set to "cuda" to use the GPU') 41 | parser.add_argument('--exp-dir', type=pathlib.Path, required=True, 42 | help='Path where model and results should be saved') 43 | parser.add_argument('--checkpoint1', type=str, 44 | help='Path to an existing checkpoint. Used along with "--resume"') 45 | parser.add_argument('--entropy_weight', type=float, default=0.0, 46 | help='weight for the entropy/diversity loss') 47 | parser.add_argument('--recon_weight', type=float, default=1.0, 48 | help='weight for the reconsturction loss') 49 | parser.add_argument('--sparsity_weight', type=float, default=0.0, 50 | help='weight for the sparsity loss') 51 | parser.add_argument('--save-model', type=bool, default=False, help='save model every iteration or not') 52 | 53 | parser.add_argument('--seed', default=42, type=int, help='Seed for random number generators') 54 | parser.add_argument('--resolution', default=[128,128], nargs='+', type=int, help='Resolution of images') 55 | 56 | # Data parameters 57 | parser.add_argument('--data-path', type=pathlib.Path, required=False, 58 | help='Path to the dataset') 59 | parser.add_argument('--dataset-name', type=str, choices=['fashion-mnist', 'dicom-knee', 'real-knee', 'brain'], 60 | required=True, help='name of the dataset') 61 | parser.add_argument('--sample-rate', type=float, default=1., 62 | help='Fraction of total volumes to include') 63 | 64 | # Mask parameters 65 | parser.add_argument('--accelerations', nargs='+', default=[4], type=float, 66 | help='Ratio of k-space columns to be sampled. 
If multiple values are ' 67 | 'provided, then one of those is chosen uniformly at random for ' 68 | 'each volume.') 69 | parser.add_argument('--label_range', nargs='+', type=int, help='train using images of specific class', default=None) 70 | parser.add_argument('--model', type=str, help='name of the model to run', required=True) 71 | parser.add_argument('--input_chans', type=int, choices=[1, 2], required=True, help='number of input channels. One for real image, 2 for compelx image') 72 | parser.add_argument('--output_chans', type=int, default=1, help='number of output channels. One for real image') 73 | parser.add_argument('--line-constrained', type=int, default=0) 74 | parser.add_argument('--unet', action='store_true') 75 | parser.add_argument('--preselect', type=int, default=0) 76 | parser.add_argument('--conjugate_mask', action='store_true', help='force loupe model to use conjugate symmetry.') 77 | parser.add_argument('--loss_type', type=str, choices=['l1', 'ssim', 'psnr', 'xentropy'], default='l1') 78 | parser.add_argument('--test_visual_frequency', type=int, default=1000) 79 | parser.add_argument('--test', action='store_true') 80 | parser.add_argument('--bi-dir', type=int, default=0) 81 | parser.add_argument('--preselect_num', type=int, default=2) 82 | parser.add_argument('--binary_sampler', type=int, default=0) 83 | parser.add_argument('--clamp', type=float, default=1e10) 84 | parser.add_argument('--old_recon', type=int, default=0) 85 | parser.add_argument('--uncertainty_loss', type=int, choices=[0, 1], default=0) 86 | parser.add_argument('--uncertainty_weight', type=float, default=0) 87 | parser.add_argument('--detach_kspace', type=int, default=0) 88 | parser.add_argument('--random_rotate', type=int, default=0) 89 | parser.add_argument('--kspace_weight', type=float, default=0) 90 | parser.add_argument('--pretrained_recon', type=int, default=0) 91 | 92 | args = parser.parse_args() 93 | 94 | args.pretrained_recon = args.pretrained_recon >0 95 | 
args.random_rotate = args.random_rotate > 0 96 | args.line_constrained = args.line_constrained > 0 97 | 98 | if args.detach_kspace == 1: 99 | args.detach_kspace = True 100 | else: 101 | args.detach_kspace = False 102 | 103 | if args.uncertainty_loss ==1: 104 | args.uncertainty_loss = True 105 | else: 106 | args.uncertainty_loss = False 107 | 108 | if args.old_recon > 0: 109 | args.old_recon = True 110 | else: 111 | args.old_recon = False 112 | 113 | if args.checkpoint1 is not None: 114 | args.resume = True 115 | else: 116 | args.resume = False 117 | 118 | noise_str = '' 119 | if args.noise_type is 'none': 120 | noise_str = '_no_noise_' 121 | else: 122 | noise_str = '_' + args.noise_type + str(args.noise_level) + '_' 123 | 124 | if args.preselect > 0: 125 | args.preselect = True 126 | else: 127 | args.preselect = False 128 | 129 | if args.bi_dir > 0 : 130 | args.bi_dir = True 131 | else: 132 | args.bi_dir = False 133 | 134 | if args.binary_sampler > 0: 135 | args.binary_sampler = True 136 | else: 137 | args.bianry_sampler = False 138 | 139 | 140 | import uuid; # uuid.uuid4().hex.upper()[0:6] 141 | 142 | if str(args.exp_dir) is 'auto': 143 | if os.name == 'nt': 144 | args.exp_dir = pathlib.Path('checkpoints/'+uuid.uuid4().hex.upper()[0:6]) 145 | else: 146 | args.exp_dir = ('checkpoints/'+args.dataset_name + '_' + 147 | str(float(args.accelerations[0])) + 'x_' + args.model + '_bi_dir_{}'.format(args.bi_dir) + '_step_{}'.format(args.num_step) 148 | + '_preselect_{}'.format(args.preselect) + noise_str + 'lr=' + str(args.lr) 149 | + '_bs=' + str(args.batch_size) + '_loss_type='+args.loss_type + '_epochs=' + str(args.num_epochs)) 150 | 151 | args.exp_dir = pathlib.Path(args.exp_dir+'_uuid_'+uuid.uuid4().hex.upper()[0:6]) 152 | 153 | print('save logs to {}'.format(args.exp_dir)) 154 | 155 | args.visualization_dir = args.exp_dir / 'visualizations' 156 | 157 | if args.num_step == 0: 158 | args.fixed_input = True 159 | args.num_step = 1 160 | else: 161 | args.fixed_input = 
False 162 | 163 | if args.test: 164 | args.batch_size = 1 165 | 166 | if args.dataset_name == 'fashion-mnist': 167 | args.data_path = 'datasets/fashion-mnist' 168 | # args.resolution = [64, 64] 169 | env = loupe_envs.LOUPEFashionMNISTEnv(args) 170 | elif args.dataset_name == 'dicom-knee': 171 | args.data_path = 'datasets/knee' 172 | # args.resolution = [128, 128] 173 | env = loupe_envs.LOUPEDICOMEnv(args) 174 | elif args.dataset_name == 'brain': 175 | args.data_path = 'datasets/brain' 176 | env = loupe_envs.LoupeBrainEnv(args) 177 | elif args.dataset_name == 'real-knee': 178 | args.data_path = 'datasets/knee' 179 | # args.resolution = [128, 128] 180 | env = loupe_envs.LOUPERealKspaceEnv(args) 181 | 182 | elif args.dataset_name == 'synthetic': 183 | args.data_path = 'datasets/synthetic' 184 | # args.resolution = [48, 48] 185 | env = loupe_envs.LOUPESyntheticEnv(args) 186 | else: 187 | raise NotImplementedError 188 | # set random seeds 189 | random.seed(args.seed) 190 | np.random.seed(args.seed) 191 | torch.manual_seed(args.seed) 192 | 193 | if args.test: 194 | policy = NonRLTester(env, args.exp_dir, args, None) 195 | else: 196 | policy = NonRLTrainer(args, env, torch.device(args.device)) 197 | policy() 198 | -------------------------------------------------------------------------------- /examples/sequential/train_sequential.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | NUM_STEP=$1 # number of sequential sampling steps (excluding preselection) 4 | PRE_SELECT=$2 5 | PRE_SELECT_NUM=$3 # number of preselect line / pixels in 1d / 2d sampling case 6 | LOSS_TYPE=$4 # loss used for training. 
Options: [l1, ssim, psnr] 7 | DATASET_NAME=$5 # real-knee 8 | ACCELERATIONS=$6 # acceleration ratio 9 | LR=$7 # learning rate 10 | DEVICE=${8} # GPU device id 11 | LR_STEP=${9} # epoch for learning rate decay 12 | GAMMA=${10} # learning rate decay ratio 13 | RESOLUTION=${11} # image resolution, applied to both dimensions (e.g. 128) 14 | ROTATE=${12} # rotate the k-space and corresponding image target 15 | LINE_CONSTRAINED=${13} # use 1d sampling 16 | 17 | python examples/sequential/train_sequential.py \ 18 | --exp-dir auto \ 19 | --dataset-name ${DATASET_NAME} \ 20 | --accelerations $ACCELERATIONS\ 21 | --model SequentialSampling \ 22 | --input_chans 2 \ 23 | --line-constrained ${LINE_CONSTRAINED}\ 24 | --unet \ 25 | --save-model True \ 26 | --batch-size 16 \ 27 | --noise-type gaussian \ 28 | --loss_type "${LOSS_TYPE}" \ 29 | --num-step ${NUM_STEP} \ 30 | --preselect ${PRE_SELECT} \ 31 | --preselect_num ${PRE_SELECT_NUM} \ 32 | --lr ${LR} \ 33 | --device ${DEVICE} \ 34 | --lr-step-size ${LR_STEP} \ 35 | --lr-gamma ${GAMMA} \ 36 | --resolution ${RESOLUTION} ${RESOLUTION} \ 37 | --random_rotate ${ROTATE} -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_loupe_785FD0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_loupe_785FD0.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_loupe_85A60D.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_loupe_85A60D.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_loupe_FB9BE7.pkl:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_loupe_FB9BE7.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq0_191014.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq0_191014.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq0_997E9B.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq0_997E9B.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq0_B2F31C.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq0_B2F31C.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq1_0F3D13.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq1_0F3D13.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq1_8D8ED1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq1_8D8ED1.pkl 
-------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq1_9BA1F7.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq1_9BA1F7.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq2_806982.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq2_806982.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq2_A0B684.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq2_A0B684.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq2_AB92C0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq2_AB92C0.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq4_2C2751.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq4_2C2751.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq4_8E14D6.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq4_8E14D6.pkl -------------------------------------------------------------------------------- /figure_reproduction/1d_data/4x_lc_seq4_C8B962.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq4_C8B962.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_loupe_1CB132.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_loupe_1CB132.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_loupe_CF3852.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_loupe_CF3852.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_loupe_F5C5C1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_loupe_F5C5C1.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq0_04BC1F.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq0_04BC1F.pkl 
-------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq0_09DB18.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq0_09DB18.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq0_0F247E.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq0_0F247E.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq1_13BF86.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq1_13BF86.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq1_9CB984.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq1_9CB984.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq1_E154DC.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq1_E154DC.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq2_8934DD.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq2_8934DD.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq2_B11CEC.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq2_B11CEC.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq2_B68595.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq2_B68595.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq4_203258.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq4_203258.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq4_2450E9.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq4_2450E9.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/16x_2d_seq4_BD196A.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq4_BD196A.pkl 
-------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_loupe_72B682.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_loupe_72B682.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_loupe_73A783.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_loupe_73A783.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_loupe_C655FA.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_loupe_C655FA.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq0_076771.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq0_076771.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq0_350405.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq0_350405.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq0_9E225B.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq0_9E225B.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq1_748B84.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq1_748B84.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq1_F3945C.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq1_F3945C.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq1_F8733A.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq1_F8733A.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq2_244A77.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq2_244A77.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq2_93381A.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq2_93381A.pkl -------------------------------------------------------------------------------- 
/figure_reproduction/2d_data/4x_2d_seq2_A4DA53.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq2_A4DA53.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq4_08BBA4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq4_08BBA4.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq4_C32832.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq4_C32832.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/4x_2d_seq4_D0D30F.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq4_D0D30F.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_loupe_4F9069.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_loupe_4F9069.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_loupe_5C47BE.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_loupe_5C47BE.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_loupe_D010FA.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_loupe_D010FA.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq0_130F17.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq0_130F17.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq0_72A09A.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq0_72A09A.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq0_CB05F5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq0_CB05F5.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq1_848B51.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq1_848B51.pkl -------------------------------------------------------------------------------- 
/figure_reproduction/2d_data/8x_2d_seq1_948938.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq1_948938.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq1_EC7415.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq1_EC7415.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq2_0B402E.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq2_0B402E.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq2_2956B7.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq2_2956B7.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq2_4F5297.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq2_4F5297.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq4_300BD3.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq4_300BD3.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq4_56CED5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq4_56CED5.pkl -------------------------------------------------------------------------------- /figure_reproduction/2d_data/8x_2d_seq4_7A9703.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq4_7A9703.pkl -------------------------------------------------------------------------------- /figure_reproduction/teaser_data/left_test_iter=651_step0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step0.pkl -------------------------------------------------------------------------------- /figure_reproduction/teaser_data/left_test_iter=651_step1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step1.pkl -------------------------------------------------------------------------------- /figure_reproduction/teaser_data/left_test_iter=651_step2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step2.pkl 
-------------------------------------------------------------------------------- /figure_reproduction/teaser_data/left_test_iter=651_step3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step3.pkl -------------------------------------------------------------------------------- /figure_reproduction/teaser_data/right_test_iter=651_step0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step0.pkl -------------------------------------------------------------------------------- /figure_reproduction/teaser_data/right_test_iter=651_step1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step1.pkl -------------------------------------------------------------------------------- /figure_reproduction/teaser_data/right_test_iter=651_step2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step2.pkl -------------------------------------------------------------------------------- /figure_reproduction/teaser_data/right_test_iter=651_step3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step3.pkl -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy>=1.18.5 2 | scikit_image>=0.16.2 3 | pytest>=5.4.3 4 | filelock>=3.0.12 5 | gym>=0.17.2 6 | tensorboardX>=2.1 7 | imageio-ffmpeg>=0.4.2 8 | jupyter 9 | nbsphinx>=0.7.1 10 | sphinx>=3.2.1 11 | sphinxcontrib-napoleon>=0.7 12 | sphinxcontrib-osexample>=0.1.1 13 | sphinx-rtd-theme>=0.5.0 14 | sigpy 15 | fastmri@git+https://github.com/facebookresearch/fastMRI.git -------------------------------------------------------------------------------- /resources/equispaced_4x_128.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/equispaced_4x_128.pt -------------------------------------------------------------------------------- /resources/spectrum_16x_128.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/spectrum_16x_128.pt -------------------------------------------------------------------------------- /resources/spectrum_4x_128.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/spectrum_4x_128.pt -------------------------------------------------------------------------------- /resources/spectrum_8x_128.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/spectrum_8x_128.pt -------------------------------------------------------------------------------- /split_data.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import h5py 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from tqdm 
import tqdm 7 | 8 | train_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/train') 9 | val_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/val') 10 | 11 | def remove_kspace(path): 12 | for fname in tqdm(list(path.iterdir())): 13 | if not fname.name.endswith('.h5'): 14 | continue # Skip anything that is not an .h5 file (e.g. directories) 15 | new_dir = fname.parent.parent / pathlib.Path(str(fname.parent.name) + '_no_kspace') 16 | if not new_dir.exists(): 17 | new_dir.mkdir(parents=False) 18 | new_filename = new_dir / fname.name 19 | if new_filename.exists(): 20 | continue # Skip already done files 21 | f = h5py.File(fname, 'r') 22 | fn = h5py.File(new_filename, 'w') 23 | for at in f.attrs: 24 | fn.attrs[at] = f.attrs[at] 25 | for dat in f: 26 | if dat == 'kspace': 27 | continue 28 | f.copy(dat, fn) 29 | 30 | # Run the calls below to remove the stored kspace from the multicoil brain .h5 files, which will save on I/O later. 31 | # We don't need the multicoil kspace since we will construct singlecoil kspace from the ground truth images. 32 | # NOTE: unlike create_train_test_split below, these calls are NOT commented out; they run on every execution. 33 | 34 | remove_kspace(train_dir) 35 | remove_kspace(val_dir) 36 | 37 | def create_train_test_split(orig_train_dir, target_test_dir, test_frac): 38 | """ 39 | Creates a train and test split from the provided training data. Works by 40 | moving random volumes from the training directory to a new test directory. 41 | 42 | WARNING: Only use this function once to create the required datasets!
43 | """ 44 | import shutil 45 | np.random.seed(0) 46 | 47 | files = sorted(list(orig_train_dir.iterdir())) 48 | target_test_dir.mkdir(parents=False, exist_ok=False) 49 | 50 | permutation = np.random.permutation(len(files)) 51 | test_indices = permutation[:int(len(files) * test_frac)] 52 | test_files = list(np.array(files)[test_indices]) 53 | 54 | for i, file in enumerate(test_files): 55 | print("Moving file {}/{}".format(i + 1, len(test_files))) 56 | shutil.move(file, target_test_dir / file.name) 57 | 58 | 59 | def count_slices(data_dir, dataset): 60 | vol_count, slice_count = 0, 0 61 | for fname in data_dir.iterdir(): 62 | with h5py.File(fname, 'r') as data: 63 | if dataset == 'brain': 64 | gt = data['reconstruction_rss'][()] 65 | vol_count += 1 66 | slice_count += gt.shape[0] 67 | print(f'{vol_count} volumes, {slice_count} slices') 68 | 69 | # For both Knee and Brain data, split off 20% of train as test 70 | dataset = 'brain' # count_slices currently only handles 'brain' 71 | train_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/train_no_kspace') 72 | val_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/val_no_kspace') 73 | test_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/test_no_kspace') 74 | 75 | test_frac = 0.2 76 | 77 | # Run this to split off test_frac of the train data into test data. 78 | # Commented out for safety. 79 | 80 | # create_train_test_split(train_dir, test_dir, test_frac) 81 | 82 | count_slices(train_dir, dataset) 83 | count_slices(val_dir, dataset) 84 | count_slices(test_dir, dataset) 85 | --------------------------------------------------------------------------------