├── .gitignore
├── Command Cheatsheet.docx
├── LICENSE
├── README.md
├── active-mri.def
├── activemri
│   ├── __init__.py
│   ├── baselines
│   │   ├── __init__.py
│   │   ├── differential_sampler.py
│   │   ├── evaluation.py
│   │   ├── loupe.py
│   │   ├── loupe_codes
│   │   │   ├── evaluate.py
│   │   │   ├── layers.py
│   │   │   ├── reconstructors.py
│   │   │   ├── samplers.py
│   │   │   └── transforms.py
│   │   ├── non_rl.py
│   │   ├── sequential_sampling_codes
│   │   │   ├── __init__.py
│   │   │   ├── conv_sampler.py
│   │   │   ├── joint_reconstructor.py
│   │   │   └── sampler2d.py
│   │   └── simple_baselines.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── real_brain_data.py
│   │   ├── real_knee_data.py
│   │   ├── singlecoil_knee_data.py
│   │   ├── splits
│   │   │   └── knee_singlecoil
│   │   │       ├── test.txt
│   │   │       └── val.txt
│   │   └── transforms.py
│   └── envs
│       ├── __init__.py
│       ├── envs.py
│       ├── loupe_envs.py
│       ├── masks.py
│       └── util.py
├── docs
│   ├── 1d_animation.gif
│   ├── 2d_animation.gif
│   ├── GETTING_STARTED.md
│   ├── INSTALL.md
│   └── teaser.png
├── examples
│   ├── loupe
│   │   ├── test_loupe.sh
│   │   ├── train_loupe.py
│   │   └── train_loupe.sh
│   └── sequential
│       ├── __init__.py
│       ├── test_sequential.sh
│       ├── train_sequential.py
│       └── train_sequential.sh
├── figure_reproduction
│   ├── 1d_data
│   │   ├── 4x_lc_loupe_785FD0.pkl
│   │   ├── 4x_lc_loupe_85A60D.pkl
│   │   ├── 4x_lc_loupe_FB9BE7.pkl
│   │   ├── 4x_lc_seq0_191014.pkl
│   │   ├── 4x_lc_seq0_997E9B.pkl
│   │   ├── 4x_lc_seq0_B2F31C.pkl
│   │   ├── 4x_lc_seq1_0F3D13.pkl
│   │   ├── 4x_lc_seq1_8D8ED1.pkl
│   │   ├── 4x_lc_seq1_9BA1F7.pkl
│   │   ├── 4x_lc_seq2_806982.pkl
│   │   ├── 4x_lc_seq2_A0B684.pkl
│   │   ├── 4x_lc_seq2_AB92C0.pkl
│   │   ├── 4x_lc_seq4_2C2751.pkl
│   │   ├── 4x_lc_seq4_8E14D6.pkl
│   │   └── 4x_lc_seq4_C8B962.pkl
│   ├── 2d_data
│   │   ├── 16x_2d_loupe_1CB132.pkl
│   │   ├── 16x_2d_loupe_CF3852.pkl
│   │   ├── 16x_2d_loupe_F5C5C1.pkl
│   │   ├── 16x_2d_seq0_04BC1F.pkl
│   │   ├── 16x_2d_seq0_09DB18.pkl
│   │   ├── 16x_2d_seq0_0F247E.pkl
│   │   ├── 16x_2d_seq1_13BF86.pkl
│   │   ├── 16x_2d_seq1_9CB984.pkl
│   │   ├── 16x_2d_seq1_E154DC.pkl
│   │   ├── 16x_2d_seq2_8934DD.pkl
│   │   ├── 16x_2d_seq2_B11CEC.pkl
│   │   ├── 16x_2d_seq2_B68595.pkl
│   │   ├── 16x_2d_seq4_203258.pkl
│   │   ├── 16x_2d_seq4_2450E9.pkl
│   │   ├── 16x_2d_seq4_BD196A.pkl
│   │   ├── 4x_2d_loupe_72B682.pkl
│   │   ├── 4x_2d_loupe_73A783.pkl
│   │   ├── 4x_2d_loupe_C655FA.pkl
│   │   ├── 4x_2d_seq0_076771.pkl
│   │   ├── 4x_2d_seq0_350405.pkl
│   │   ├── 4x_2d_seq0_9E225B.pkl
│   │   ├── 4x_2d_seq1_748B84.pkl
│   │   ├── 4x_2d_seq1_F3945C.pkl
│   │   ├── 4x_2d_seq1_F8733A.pkl
│   │   ├── 4x_2d_seq2_244A77.pkl
│   │   ├── 4x_2d_seq2_93381A.pkl
│   │   ├── 4x_2d_seq2_A4DA53.pkl
│   │   ├── 4x_2d_seq4_08BBA4.pkl
│   │   ├── 4x_2d_seq4_C32832.pkl
│   │   ├── 4x_2d_seq4_D0D30F.pkl
│   │   ├── 8x_2d_loupe_4F9069.pkl
│   │   ├── 8x_2d_loupe_5C47BE.pkl
│   │   ├── 8x_2d_loupe_D010FA.pkl
│   │   ├── 8x_2d_seq0_130F17.pkl
│   │   ├── 8x_2d_seq0_72A09A.pkl
│   │   ├── 8x_2d_seq0_CB05F5.pkl
│   │   ├── 8x_2d_seq1_848B51.pkl
│   │   ├── 8x_2d_seq1_948938.pkl
│   │   ├── 8x_2d_seq1_EC7415.pkl
│   │   ├── 8x_2d_seq2_0B402E.pkl
│   │   ├── 8x_2d_seq2_2956B7.pkl
│   │   ├── 8x_2d_seq2_4F5297.pkl
│   │   ├── 8x_2d_seq4_300BD3.pkl
│   │   ├── 8x_2d_seq4_56CED5.pkl
│   │   └── 8x_2d_seq4_7A9703.pkl
│   ├── figures.ipynb
│   └── teaser_data
│       ├── left_test_iter=651_step0.pkl
│       ├── left_test_iter=651_step1.pkl
│       ├── left_test_iter=651_step2.pkl
│       ├── left_test_iter=651_step3.pkl
│       ├── right_test_iter=651_step0.pkl
│       ├── right_test_iter=651_step1.pkl
│       ├── right_test_iter=651_step2.pkl
│       └── right_test_iter=651_step3.pkl
├── requirements.txt
├── resources
│   ├── equispaced_4x_128.pt
│   ├── spectrum_16x_128.pt
│   ├── spectrum_4x_128.pt
│   └── spectrum_8x_128.pt
└── split_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints_old
2 | checkpoints
3 | temp
4 | .vscode
5 | misc
6 | *__pycache__*
7 | .ipynb_checkpoints
8 | .DS_Store
9 | .idea
10 | personal_scripts*
11 | notebooks
12 | *.orig
13 | docs/_build
14 | docs/build*
15 | *.mp4
16 | *egg-info*
17 | activemri/models/custom_*
18 | activemri/envs/custom_*
19 | datasets
20 | *.png
21 | *.jpg
22 | *.pth
23 |
--------------------------------------------------------------------------------
/Command Cheatsheet.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/Command%20Cheatsheet.docx
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Tianwei Yin and Zihui Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-End Sequential Sampling and Reconstruction for MR Imaging
2 |
3 |
4 |
5 |
6 |
7 | > [**End-to-End Sequential Sampling and Reconstruction for MR Imaging**](http://arxiv.org/abs/2105.06460),
8 | > Tianwei Yin*, Zihui Wu*, He Sun, Adrian V. Dalca, Yisong Yue, Katherine L. Bouman (*equal contributions)
9 | > *arXiv technical report ([arXiv 2105.06460](http://arxiv.org/abs/2105.06460))*
10 |
11 |
12 | @article{yin2021end,
13 | title={End-to-End Sequential Sampling and Reconstruction for MR Imaging},
14 |       author={Yin, Tianwei and Wu, Zihui and Sun, He and Dalca, Adrian V. and Yue, Yisong and Bouman, Katherine L.},
15 |       journal={arXiv preprint arXiv:2105.06460},
16 | year={2021},
17 | }
18 |
19 |
20 | ## Contact
21 | Any questions or suggestions are welcome!
22 |
23 | Tianwei Yin [yintianwei@utexas.edu](mailto:yintianwei@utexas.edu)
24 | Zihui Wu [zwu2@caltech.edu](mailto:zwu2@caltech.edu)
25 |
26 | ## Abstract
27 | Accelerated MRI shortens acquisition time by subsampling in the measurement k-space. Recovering a high-fidelity anatomical image from subsampled measurements requires close cooperation between two components: (1) a sampler that chooses the subsampling pattern and (2) a reconstructor that recovers images from incomplete measurements. In this paper, we leverage the sequential nature of MRI measurements, and propose a fully differentiable framework that jointly learns a sequential sampling policy simultaneously with a reconstruction strategy. This co-designed framework is able to adapt during acquisition in order to capture the most informative measurements for a particular target (Figure 1). Experimental results on the fastMRI knee dataset demonstrate that the proposed approach successfully utilizes intermediate information during the sampling process to boost reconstruction performance. In particular, our proposed method outperforms the current state-of-the-art baseline on up to 96.96% of test samples. We also investigate the individual and collective benefits of the sequential sampling and co-design strategies.
28 |
29 | ## Main Results
30 |
31 | #### Line-constrained Sampling
32 |
33 | | Model   | Acceleration | SSIM |
34 | |---------|---------------|------|
35 | | Loupe | 4x | 89.5 |
36 | | Seq1 | 4x | 90.8 |
37 | | Seq4 | 4x | 91.2 |
38 |
39 |
40 | #### 2D Point Sampling
41 |
42 | | Model   | Acceleration | SSIM |
43 | |---------|---------------|------|
44 | | Loupe | 4x | 92.4 |
45 | | Seq1 | 4x | 92.7 |
46 | | Seq4 | 4x | 92.9 |
47 |
48 | ## Installation
49 |
50 | Please refer to [INSTALL](docs/INSTALL.md) to set up libraries.
51 |
52 | ## Getting Started
53 |
54 | Please refer to [GETTING_STARTED](docs/GETTING_STARTED.md) to prepare the data and follow the instructions to reproduce all results.
55 | You can use [figures.ipynb](figure_reproduction/figures.ipynb) to reproduce all the figures in the paper.
56 |
57 | ## License
58 |
59 | Seq-MRI is released under the MIT license (see [LICENSE](LICENSE)). It is developed based on a forked version of [active-mri-acquisition](https://github.com/facebookresearch/active-mri-acquisition). We thank the original authors for their great codebase.
60 |
--------------------------------------------------------------------------------
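A note on reproducing: the `examples/` directory in the tree above ships launcher scripts for both models. A minimal sketch of how they would be invoked (hypothetical invocation; the exact arguments live inside the scripts and in docs/GETTING_STARTED.md):

```bash
# assumes the fastMRI knee data has been prepared per docs/GETTING_STARTED.md
bash examples/loupe/train_loupe.sh            # train the LOUPE baseline
bash examples/sequential/train_sequential.sh  # train the sequential co-design model
bash examples/sequential/test_sequential.sh   # evaluate a trained sequential model
```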
/active-mri.def:
--------------------------------------------------------------------------------
1 | bootstrap: docker
2 | from: jupyter/scipy-notebook
3 |
4 | %post
5 | apt-get update
6 | apt-get -y upgrade
7 | apt-get clean
8 |
9 | # Install anaconda if it is not installed yet
10 | if [ ! -d /opt/conda ]; then
11 | wget https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh \
12 | -O ~/conda.sh && \
13 | bash ~/conda.sh -b -p /opt/conda && \
14 | rm ~/conda.sh
15 | fi
16 |
17 | # Set anaconda path
18 | export PATH=/opt/conda/bin:$PATH
19 |
20 |     # Update conda; NOTE: for some reason this doesn't actually update conda at the moment...
21 | conda update -y -n base conda
22 |
23 | # Download alternative version of python if needed (default is 3.8)
24 | conda install -y python=3.7
25 |
26 | # Install conda packages; -y is used to silently install
27 | conda config --add channels conda-forge
28 |
29 | conda install -y numpy
30 | conda install -y scipy
31 | conda install -y joblib
32 | conda install -y tqdm
33 | conda install -y vim
34 | conda install -y pynfft
35 |
36 | conda clean --tarballs
37 |
38 | # Install git and pip3
39 | apt-get -y install git-all
40 | apt-get -y install python3-pip
41 |
42 | %environment
43 | export PYTHONPATH=/opt/conda/lib/python3.7/site-packages:$PYTHONPATH
44 |
45 | export LC_ALL=C
46 |
--------------------------------------------------------------------------------
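A usage note for the Singularity definition file above (not part of the file itself): it would typically be built and entered as follows, assuming a Singularity/Apptainer installation; the image name is arbitrary.

```bash
sudo singularity build active-mri.sif active-mri.def   # build the container image
singularity exec active-mri.sif python --version       # run a command inside it
```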
/activemri/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from . import data, envs
7 |
8 | __all__ = ["data", "envs"]
9 |
--------------------------------------------------------------------------------
/activemri/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import abc
7 | from typing import Any, Dict, List
8 |
9 |
10 | class Policy:
11 | """ A basic policy interface. """
12 |
13 | def __init__(self, *args, **kwargs):
14 | pass
15 |
16 | @abc.abstractmethod
17 | def get_action(self, obs: Dict[str, Any], **kwargs: Any) -> List[int]:
18 | """ Returns a list of actions for a batch of observations. """
19 | pass
20 |
21 | def __call__(self, obs: Dict[str, Any], **kwargs: Any) -> List[int]:
22 | return self.get_action(obs, **kwargs)
23 |
24 |
25 | from .simple_baselines import (
26 | RandomPolicy,
27 | RandomLowBiasPolicy,
28 | LowestIndexPolicy,
29 | OneStepGreedyOracle,
30 | )
31 | from .evaluation import evaluate
32 |
33 | __all__ = [
34 | "RandomPolicy",
35 | "RandomLowBiasPolicy",
36 | "LowestIndexPolicy",
37 | "OneStepGreedyOracle",
38 | "CVPR19Evaluator",
39 | "DDQN",
40 | "DDQNTrainer",
41 | "evaluate",
42 | ]
43 |
--------------------------------------------------------------------------------
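To make the `Policy` interface above concrete, here is a minimal illustrative sketch (not part of the package) of a custom policy; it assumes `obs["mask"]` carries a batch of binary masks over k-space columns, with 1 marking already-sampled columns:

```python
import numpy as np

from activemri.baselines import Policy


class FirstUnsampledPolicy(Policy):
    """Hypothetical policy: always request the lowest-index unsampled column."""

    def get_action(self, obs, **kwargs):
        mask = np.asarray(obs["mask"])  # assumed shape (batch, num_cols)
        # flatnonzero(row == 0)[0] is the first column not yet sampled
        return [int(np.flatnonzero(row == 0)[0]) for row in mask]
```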
/activemri/baselines/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from typing import Any, Dict, List, Optional, Tuple
7 | import argparse
8 | import os
9 | import numpy as np
10 | import numpy.matlib
11 | import activemri.baselines as baselines
12 | import activemri.envs as envs
13 |
14 | import matplotlib.pyplot as plt
15 |
16 |
17 | def evaluate(
18 | args: argparse.Namespace,
19 | env: envs.envs.ActiveMRIEnv,
20 | policy: baselines.Policy,
21 | num_episodes: int,
22 | seed: int,
23 | split: str,
24 | verbose: Optional[bool] = False,
25 | ) -> Tuple[Dict[str, np.ndarray], List[Tuple[Any, Any]]]:
26 | env.seed(seed)
27 | if split == "test":
28 | env.set_test()
29 | elif split == "val":
30 | env.set_val()
31 | else:
32 | raise ValueError(f"Invalid evaluation split: {split}.")
33 |
34 | score_keys = env.score_keys()
35 | all_scores = dict(
36 | (k, np.zeros((num_episodes * env.num_parallel_episodes, env.budget + 1)))
37 | for k in score_keys
38 | )
39 | all_img_ids = []
40 | trajectories_written = 0
41 | for episode in range(num_episodes):
42 | print('Now running episode: '+str(episode))
43 | step = 0
44 | obs, meta = env.reset()
45 | # plt.imshow(obs["reconstruction"][0,:,:,0])
46 | # plt.savefig('random.png')
47 | # input("Press Enter to continue...")
48 | if not obs:
49 | break # no more images
50 | # in case the last batch is smaller
51 | actual_batch_size = len(obs["reconstruction"])
52 | if verbose:
53 | msg = ", ".join(
54 | [
55 | f"({meta['fname'][i]}, {meta['slice_id'][i]})"
56 | for i in range(actual_batch_size)
57 | ]
58 | )
59 | print(f"Read images: {msg}")
60 | for i in range(actual_batch_size):
61 | all_img_ids.append((meta["fname"][i], meta["slice_id"][i]))
62 | batch_idx = slice(
63 | trajectories_written, trajectories_written + actual_batch_size
64 | )
65 | for k in score_keys:
66 | all_scores[k][batch_idx, step] = meta["current_score"][k]
67 | trajectories_written += actual_batch_size
68 | all_done = False
69 | while not all_done:
70 | step += 1
71 | action = policy.get_action(obs)
72 | obs, reward, done, meta, gt = env.step(action)
73 |
74 | for k in score_keys:
75 | all_scores[k][batch_idx, step] = meta["current_score"][k]
76 | all_done = all(done)
77 |
78 | # visualize the first image
79 | if episode == 0:
80 | print('Visualize')
81 | num_row, num_col = obs["reconstruction"][0,:,:,0].cpu().detach().numpy().shape
82 |
83 | kspace_t = obs["reconstruction"][0,:,:,:].fft(2,normalized=False)
84 | kspace_t = (kspace_t ** 2).sum(dim=-1).sqrt().cpu().detach().numpy()
85 | action_t = np.zeros((num_row, num_col))
86 | action_play = action[0] if type(action)==list else action
87 | action_t[:,action_play] = 1
88 | reward_play = reward[0]
89 | # gt["gt_kspace"] is a List
90 | gt_kspace_t = np.sqrt((gt["gt_kspace"][0] ** 2).sum(-1))
91 |
92 | eval_vis = {"phi_t": obs["reconstruction"][0,:,:,0].cpu().detach().numpy(),
93 | "kspace_t": kspace_t,
94 | "mask_t": np.matlib.repmat(obs["mask"][0:1,:].cpu().detach().numpy(),num_row,1),
95 | "action_t": action_t,
96 | "action_t_val": action_play,
97 | "reward_t_val": reward_play,
98 | "gt_t": gt["gt"][0,:,:,0].numpy(),
99 | "gt_kspace_t": gt_kspace_t
100 | }
101 | if 'output_dir' not in vars(args):
102 | args.output_dir = args.checkpoints_dir
103 | visualize_and_save(eval_vis, step, episode, args.budget, args.output_dir)
104 |
105 | if step == 1:
106 | os.makedirs(args.output_dir+'mask/', exist_ok=True)
107 |
108 |
109 | left = obs['mask'][0][:len(obs['mask'][0])//2].numpy()
110 | right = np.flip(obs['mask'][0][len(obs['mask'][0])//2:].numpy())
111 | visualize_symmetry(left, right, step, episode, env.budget, args.output_dir)
112 |
113 | elif all_done:
114 | left = obs['mask'][0][:len(obs['mask'][0])//2].numpy()
115 | right = np.flip(obs['mask'][0][len(obs['mask'][0])//2:].numpy())
116 | visualize_symmetry(left, right, step, episode, env.budget, args.output_dir)
117 |
118 | # print('-------')
119 | # print(step)
120 | # print(obs['mask'])
121 | # print(list(map(int,left+right)))
122 | # print(meta["current_score"][k])
123 | # input("Press Enter to continue...")
124 | if episode % 10 == 0:
125 | np.save(os.path.join(args.output_dir, "scores_0-{}.npy".format(episode)), all_scores)
126 |
127 | for k in score_keys:
128 | all_scores[k] = all_scores[k][: len(all_img_ids), :]
129 | return all_scores, all_img_ids
130 |
131 |
132 | def visualize_and_save(eval_vis: Dict[str, Any], step: int, episode: int, num_train_steps: int, checkpoints_dir: str):
133 | _, num_col = eval_vis["phi_t"].shape
134 |
135 | fig, ax = plt.subplots(nrows=2, ncols=4, figsize=[20, 8])
136 |
137 | cmap = "viridis" if num_col == 32 else "gray"
138 |
139 | sp1 = ax[0, 0].imshow(eval_vis["phi_t"], cmap=cmap)
140 | sp2 = ax[0, 1].imshow(np.fft.fftshift(np.log(eval_vis["kspace_t"])))
141 | sp3 = ax[0, 2].imshow(eval_vis["gt_t"], cmap=cmap)
142 | sp4 = ax[0, 3].imshow(np.fft.fftshift(np.log(eval_vis["gt_kspace_t"])))
143 | sp5 = ax[1, 0].imshow(np.fft.fftshift(eval_vis["mask_t"]-eval_vis["action_t"]))
144 | sp6 = ax[1, 1].imshow(np.fft.fftshift(eval_vis["action_t"]))
145 | sp7 = ax[1, 2].imshow(np.fft.fftshift(eval_vis["mask_t"]))
146 | ax[0, 0].title.set_text('Recon. at time t (phi_t)')
147 | ax[0, 1].title.set_text('Log k-space of phi_t')
148 | ax[0, 2].title.set_text('Ground truth of phi_t')
149 | ax[0, 3].title.set_text('GT log k-space of phi_t')
150 | ax[1, 0].title.set_text('Mask at time t-1')
151 | ax[1, 1].title.set_text('Action at time t')
152 | ax[1, 2].title.set_text('Mask at time t')
153 | fig.colorbar(sp1, ax=ax[0, 0])
154 | fig.colorbar(sp2, ax=ax[0, 1])
155 | fig.colorbar(sp3, ax=ax[0, 2])
156 | fig.colorbar(sp4, ax=ax[0, 3])
157 | fig.colorbar(sp5, ax=ax[1, 0])
158 | fig.colorbar(sp6, ax=ax[1, 1])
159 | fig.colorbar(sp7, ax=ax[1, 2])
160 | ax[1, 3].text(0, 0.8, r'action_t: ', fontsize=15)
161 | ax[1, 3].text(0.1, 0.65, str(eval_vis["action_t_val"])+' in [1,'+str(num_col)+']', fontsize=15)
162 | ax[1, 3].text(0.1, 0.5, str((eval_vis["action_t_val"]+num_col/2)%num_col)+' after fftshift', fontsize=15)
163 | ax[1, 3].text(0, 0.3, r'reward_t: ', fontsize=15)
164 | ax[1, 3].text(0.1, 0.15, str(eval_vis["reward_t_val"])+' (SSIM)', fontsize=15)
165 | ax[1, 3].axis('off')
166 | plt.suptitle('Step = [{}/{}]'.format(step, num_train_steps), fontsize=20)
167 |
168 | plt.savefig(checkpoints_dir+'episode='+str(episode)+'_step='+str(step)+'.png')
169 | plt.close()
170 |
171 | def visualize_symmetry(left, right, step, episode, budget, save_dir):
172 | plt.plot(list(map(int,left+right)),'.')
173 | plt.xlabel('Columns')
174 | plt.ylabel('2: Both, 1: One, 0: Neither')
175 | plt.title('Step = [{}/{}]\nEpisode = {}'.format(step, budget, 1+episode))
176 | plt.savefig(save_dir+'mask/episode='+str(1+episode)+'_step='+str(step)+'.png')
177 | plt.close()
178 |
--------------------------------------------------------------------------------
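A wiring sketch for the evaluation loop above (hypothetical: environment construction is omitted, `RandomPolicy` is assumed to take no constructor arguments, and the argparse namespace carries the `budget` and `output_dir` fields that `evaluate` reads):

```python
import argparse

from activemri.baselines import RandomPolicy, evaluate


def run_random_eval(env, budget, output_dir):
    # env: an activemri.envs.envs.ActiveMRIEnv instance (construction omitted here)
    args = argparse.Namespace(budget=budget, output_dir=output_dir)
    scores, img_ids = evaluate(args, env, RandomPolicy(), num_episodes=2,
                               seed=0, split="val", verbose=True)
    return {k: v.mean(axis=0) for k, v in scores.items()}  # mean score per step
```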
/activemri/baselines/loupe.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | import time
4 | import numpy as np
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from tensorboardX import SummaryWriter
10 |
11 | import activemri.envs.loupe_envs as loupe_envs
12 | from activemri.baselines.loupe_codes.samplers import *
13 | from activemri.baselines.loupe_codes.reconstructors import *
14 | from activemri.baselines.loupe_codes.layers import *
15 | from activemri.baselines.loupe_codes.transforms import *
16 | from ..envs.util import compute_ssim_torch, compute_psnr_torch
17 | import os
18 | from torch.autograd import Function
19 | import matplotlib.pyplot as plt
20 |
21 | from typing import Any, Dict, List, Optional, Tuple
22 |
23 |
24 | class LOUPE(nn.Module):
25 | """
26 |     Reimplementation of the LOUPE (https://arxiv.org/abs/1907.11374) sampling-reconstruction framework
27 |     with a straight-through estimator (https://arxiv.org/abs/1308.3432).
28 |
29 |     The model has two components: a learned probability mask (sampler) and a U-Net reconstructor.
30 |
31 | Args:
32 | in_chans (int): Number of channels in the input to the reconstructor (2 for complex image, 1 for real image).
33 | out_chans (int): Number of channels in the output to the reconstructor (default is 1 for real image).
34 | chans (int): Number of output channels of the first convolution layer.
35 | num_pool_layers (int): Number of down-sampling and up-sampling layers.
36 | drop_prob (float): Dropout probability.
37 |         shape ([int, int]): Shape of the reconstructed image
38 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to
39 | deterministic state.
40 | sparsity (float): Predefined sparsity of the learned probability mask. 1 / acceleration_ratio
41 | line_constrained (bool): Sample kspace measurements column by column
42 | conjugate_mask (bool): For real image, the corresponding kspace measurements have conjugate symmetry property
43 | (point reflection). Therefore, the information in the left half of the kspace image is the same as the
44 | other half. To take advantage of this, we can force the model to only sample right half of the kspace
45 | (when conjugate_mask is set to True)
46 | preselect (bool): preselect DC components
47 | bi_dir (bool): sample from both vertical and horizontal lines
48 | """
49 | def __init__(
50 | self,
51 | in_chans: int,
52 | out_chans: int,
53 | chans: int = 64,
54 | num_pool_layers: int = 4,
55 | drop_prob: float = 0,
56 | shape: List[int] = [320, 320],
57 | slope: float = 5,
58 | sparsity: float = 0.25,
59 | line_constrained: bool = False,
60 | conjugate_mask: bool = False,
61 | preselect: bool = False,
62 | preselect_num: int = 2,
63 | bi_dir: bool = False,
64 | random: bool = False,
65 | poisson: bool = False,
66 | spectrum: bool = False,
67 | equispaced: bool = False
68 | ):
69 | super().__init__()
70 | assert conjugate_mask is False
71 |
72 |         self.preselect = preselect
73 |
74 | if not line_constrained:
75 | sparsity = (sparsity - preselect_num**2 / (shape[0]*shape[1])) if preselect else sparsity
76 | else:
77 | sparsity = (sparsity - preselect_num / shape[1]) if preselect else sparsity
78 |
79 |         # for backward compatibility
80 | self.samplers = nn.ModuleList()
81 |
82 | if bi_dir:
83 | assert 0
84 | self.samplers.append(BiLOUPESampler(shape, slope, sparsity, line_constrained, conjugate_mask, preselect, preselect_num))
85 | else:
86 | self.samplers.append(LOUPESampler(shape, slope, sparsity, line_constrained, conjugate_mask, preselect, preselect_num,
87 | random=random, poisson=poisson, spectrum=spectrum, equispaced=equispaced))
88 |
89 | self.reconstructor = LOUPEUNet(in_chans, out_chans, chans, num_pool_layers, drop_prob)
90 |
91 | self.sparsity = sparsity
92 | self.conjugate_mask = conjugate_mask
93 | self.data_for_vis = {}
94 |
95 | if in_chans == 1:
96 |             assert self.conjugate_mask, "Reconstructor (denoiser) only takes the real part of the ifft output"
97 |
98 | def forward(self, target, kspace, seed=None):
99 | """
100 | Args:
101 | kspace (torch.Tensor): Input tensor of shape NHWC (kspace data)
102 | Returns:
103 |             (dict): 'output' (NCHW reconstruction), 'energy' (negative entropy of the mask), 'mask' (binarized sampling mask)
104 | """
105 |
106 | # choose kspace sampling location
107 | # masked_kspace: NHWC
108 | masked_kspace, mask, neg_entropy, data_to_vis_sampler = self.samplers[0](kspace, self.sparsity)
109 |
110 | self.data_for_vis.update(data_to_vis_sampler)
111 |
112 | # Inverse Fourier Transform to get zero filled solution
113 | # NHWC to NCHW
114 | zero_filled_recon = transforms.fftshift(transforms.ifft2(masked_kspace),dim=(1,2)).permute(0, -1, 1, 2)
115 |
116 | if self.conjugate_mask:
117 | # only the real part of the ifft output is the same as the original image
118 | # when you only sample in one half of the kspace
119 | recon = self.reconstructor(zero_filled_recon[:,0:1,:,:])
120 | else:
121 | recon = self.reconstructor(zero_filled_recon, 0)
122 |
123 | self.data_for_vis.update({'input': target[0,0,:,:].cpu().detach().numpy(),
124 | 'kspace': transforms.complex_abs(transforms.fftshift(kspace[0,:,:,:],dim=(0,1))).cpu().detach().numpy(),
125 | 'masked_kspace': transforms.complex_abs(transforms.fftshift(masked_kspace[0,:,:,:],dim=(0,1))).cpu().detach().numpy(),
126 | 'zero_filled_recon': zero_filled_recon[0,0,:,:].cpu().detach().numpy(),
127 | 'recon': recon[0,0,:,:].cpu().detach().numpy()})
128 |
129 | pred_dict = {'output': recon.norm(dim=1, keepdim=True), 'energy': neg_entropy, 'mask': mask}
130 |
131 | return pred_dict
132 |
133 | def loss(self, pred_dict, target_dict, meta, loss_type):
134 | """
135 | Args:
136 | pred_dict:
137 | output: reconstructed image from downsampled kspace measurement
138 | energy: negative entropy of the probability mask
139 |                 mask: the binarized sampling mask (used for visualization)
140 |
141 | target_dict:
142 | target: original fully sampled image
143 |
144 | meta:
145 | recon_weight: weight of reconstruction loss
146 | entropy_weight: weight of the entropy loss (to encourage exploration)
147 | """
148 | target = target_dict['target']
149 | pred = pred_dict['output']
150 | energy = pred_dict['energy']
151 |
152 | if loss_type == 'l1':
153 | reconstruction_loss = F.l1_loss(pred, target, size_average=True)
154 | elif loss_type == 'ssim':
155 | reconstruction_loss = -torch.mean(compute_ssim_torch(pred, target))
156 | elif loss_type == 'psnr':
157 | reconstruction_loss = - torch.mean(compute_psnr_torch(pred, target))
158 | else:
159 | raise NotImplementedError
160 |
161 | entropy_loss = torch.mean(energy)
162 |
163 | loss = entropy_loss * meta['entropy_weight'] + reconstruction_loss * meta['recon_weight']
164 |
165 | log_dict = {'Total Loss': loss.item(), 'Entropy': entropy_loss.item(), 'Reconstruction': reconstruction_loss.item()}
166 |
167 | return loss, log_dict
168 |
169 | def show_mask(self):
170 | """
171 | Return:
172 | (np.ndarray) the learned undersampling mask shape HW
173 | """
174 | H, W = self.samplers[0].shape
175 | pseudo_image = torch.zeros(1, H, W, 1)
176 |
177 | return self.samplers[0].mask(pseudo_image, self.sparsity)
178 |
179 | def visualize_and_save(self, options, epoch, data_for_vis_name):
180 | fig, ax = plt.subplots(nrows=2, ncols=5, figsize=[18, 6])
181 |
182 | cmap = "viridis" if options.resolution[1] == 32 else "gray"
183 |
184 | sp1 = ax[0, 0].imshow(self.data_for_vis['input'], cmap=cmap)
185 | sp2 = ax[0, 1].imshow(np.log(self.data_for_vis['kspace']))
186 | sp3 = ax[0, 2].imshow(np.log(self.data_for_vis['masked_kspace']))
187 | sp4 = ax[0, 3].imshow(self.data_for_vis['zero_filled_recon'], cmap=cmap)
188 | sp5 = ax[0, 4].imshow(self.data_for_vis['recon'], cmap=cmap)
189 | sp6 = ax[1, 0].imshow(self.data_for_vis['prob_mask'], aspect='auto')
190 | sp7 = ax[1, 1].imshow(self.data_for_vis['rescaled_mask'], aspect='auto')
191 | sp8 = ax[1, 2].imshow(self.data_for_vis['binarized_mask'], aspect='auto')
192 | ax[0, 0].title.set_text('Input image')
193 | ax[0, 1].title.set_text('Log k-space of the input')
194 | ax[0, 2].title.set_text('Undersampled log k-space')
195 | ax[0, 3].title.set_text('Zero-filled reconstruction')
196 | ax[0, 4].title.set_text('Reconstruction')
197 | ax[1, 0].title.set_text('Probabilistic mask')
198 | ax[1, 1].title.set_text('Rescaled mask')
199 | ax[1, 2].title.set_text('Binary mask')
200 | fig.colorbar(sp1, ax=ax[0, 0])
201 | fig.colorbar(sp2, ax=ax[0, 1])
202 | fig.colorbar(sp3, ax=ax[0, 2])
203 | fig.colorbar(sp4, ax=ax[0, 3])
204 | fig.colorbar(sp5, ax=ax[0, 4])
205 | fig.colorbar(sp6, ax=ax[1, 0])
206 | fig.colorbar(sp7, ax=ax[1, 1])
207 | fig.colorbar(sp8, ax=ax[1, 2])
208 | ax[1, 3].axis('off')
209 | ax[1, 4].axis('off')
210 | plt.suptitle('Epoch = [{}/{}]'.format(1 + epoch, options.num_epochs), fontsize=20)
211 |
212 | if not os.path.isdir(options.visualization_dir):
213 | os.mkdir(options.visualization_dir)
214 |
215 | plt.savefig(str(options.visualization_dir)+'/'+data_for_vis_name+'.png')
216 | plt.close()
217 |
--------------------------------------------------------------------------------
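A smoke-test sketch for the `LOUPE` module above (hypothetical shapes; assumes the legacy `tensor.fft` API of the PyTorch version pinned by this repo, with k-space stored NHWC and a trailing real/imaginary channel):

```python
import torch

from activemri.baselines.loupe import LOUPE

model = LOUPE(in_chans=2, out_chans=1, shape=[128, 128], sparsity=0.25,
              line_constrained=True, preselect=True, preselect_num=2)
target = torch.rand(1, 1, 128, 128)  # NCHW ground-truth image
kspace = torch.rand(1, 128, 128, 2)  # NHWC k-space (real/imag channels)
pred = model(target, kspace)
print(pred["output"].shape, pred["mask"].shape)  # reconstruction and binary mask
```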
/activemri/baselines/loupe_codes/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import argparse
9 | import pathlib
10 | from argparse import ArgumentParser
11 |
12 | import h5py
13 | import numpy as np
14 | from runstats import Statistics
15 | from skimage.metrics import structural_similarity, peak_signal_noise_ratio
16 | from fastmri.data import transforms
17 |
18 |
19 | def mse(gt, pred):
20 | """ Compute Mean Squared Error (MSE) """
21 | return np.mean((gt - pred) ** 2)
22 |
23 |
24 | def nmse(gt, pred):
25 | """ Compute Normalized Mean Squared Error (NMSE) """
26 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2
27 |
28 |
29 | def psnr(gt, pred):
30 | """ Compute Peak Signal to Noise Ratio metric (PSNR) """
31 | return peak_signal_noise_ratio(gt, pred, data_range=gt.max())
32 |
33 |
34 | def ssim(gt, pred):
35 | """ Compute Structural Similarity Index Metric (SSIM). """
36 | return structural_similarity(
37 | gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max()
38 | )
39 |
40 |
41 | METRIC_FUNCS = dict(
42 | MSE=mse,
43 | NMSE=nmse,
44 | PSNR=psnr,
45 | SSIM=ssim,
46 | )
47 |
48 |
49 | class Metrics:
50 | """
51 | Maintains running statistics for a given collection of metrics.
52 | """
53 |
54 | def __init__(self, metric_funcs):
55 | self.metrics = {
56 | metric: Statistics() for metric in metric_funcs
57 | }
58 |
59 | def push(self, target, recons):
60 | for metric, func in METRIC_FUNCS.items():
61 | self.metrics[metric].push(func(target, recons))
62 |
63 | def means(self):
64 | return {
65 | metric: stat.mean() for metric, stat in self.metrics.items()
66 | }
67 |
68 | def stddevs(self):
69 | return {
70 | metric: stat.stddev() for metric, stat in self.metrics.items()
71 | }
72 |
73 | def __repr__(self):
74 | means = self.means()
75 | stddevs = self.stddevs()
76 | metric_names = sorted(list(means))
77 | return ' '.join(
78 | f'{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}' for name in metric_names
79 | )
80 |
81 |
82 | def evaluate(args, recons_key):
83 | metrics = Metrics(METRIC_FUNCS)
84 |
85 | for tgt_file in args.target_path.iterdir():
86 | with h5py.File(tgt_file, 'r') as target, h5py.File(
87 | args.predictions_path / tgt_file.name, 'r') as recons:
88 | if args.acquisition and args.acquisition != target.attrs['acquisition']:
89 | continue
90 |
91 | if args.acceleration and target.attrs['acceleration'] != args.acceleration:
92 | continue
93 |
94 | target = target[recons_key][()]
95 | recons = recons['reconstruction'][()]
96 | target = transforms.center_crop(target, (target.shape[-1], target.shape[-1]))
97 | recons = transforms.center_crop(recons, (target.shape[-1], target.shape[-1]))
98 | metrics.push(target, recons)
99 | return metrics
100 |
101 |
102 | if __name__ == '__main__':
103 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
104 | parser.add_argument('--target-path', type=pathlib.Path, required=True,
105 | help='Path to the ground truth data')
106 | parser.add_argument('--predictions-path', type=pathlib.Path, required=True,
107 | help='Path to reconstructions')
108 | parser.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True,
109 | help='Which challenge')
110 | parser.add_argument('--acceleration', type=int, default=None)
111 | parser.add_argument('--acquisition', choices=['CORPD_FBK', 'CORPDFS_FBK', 'AXT1', 'AXT1PRE',
112 | 'AXT1POST', 'AXT2', 'AXFLAIR'], default=None,
113 | help='If set, only volumes of the specified acquisition type are used '
114 | 'for evaluation. By default, all volumes are included.')
115 | args = parser.parse_args()
116 |
117 | recons_key = 'reconstruction_rss' if args.challenge == 'multicoil' else 'reconstruction_esc'
118 | metrics = evaluate(args, recons_key)
119 | print(metrics)
120 |
--------------------------------------------------------------------------------
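Command-line usage implied by the argparse setup above (paths are placeholders):

```bash
python activemri/baselines/loupe_codes/evaluate.py \
    --target-path /path/to/ground_truth \
    --predictions-path /path/to/reconstructions \
    --challenge singlecoil
```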
/activemri/baselines/loupe_codes/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Portion of this code is from fastmri(https://github.com/facebookresearch/fastMRI)
3 |
4 | Copyright (c) Facebook, Inc. and its affiliates.
5 |
6 | Licensed under the MIT License.
7 | """
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from torch.autograd import Function
13 | from activemri.baselines.loupe_codes import transforms
14 |
15 | class ConvBlock(nn.Module):
16 | """
17 | A Convolutional Block that consists of two convolution layers each followed by
18 | instance normalization, relu activation and dropout.
19 | """
20 | def __init__(self, in_chans, out_chans, drop_prob):
21 | """
22 | Args:
23 | in_chans (int): Number of channels in the input.
24 | out_chans (int): Number of channels in the output.
25 | drop_prob (float): Dropout probability.
26 | """
27 | super(ConvBlock, self).__init__()
28 |
29 | self.in_chans = in_chans
30 | self.out_chans = out_chans
31 | self.drop_prob = drop_prob
32 |
33 | self.layers = nn.Sequential(
34 | nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1),
35 | nn.InstanceNorm2d(out_chans),
36 | nn.ReLU(),
37 | nn.Dropout2d(drop_prob),
38 | nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1),
39 | nn.InstanceNorm2d(out_chans),
40 | nn.ReLU(),
41 | nn.Dropout2d(drop_prob)
42 | )
43 |
44 | def forward(self, input):
45 | """
46 | Args:
47 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
48 |
49 | Returns:
50 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
51 | """
52 | return self.layers(input)
53 |
54 | def __repr__(self):
55 | return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \
56 | f'drop_prob={self.drop_prob})'
57 |
58 |
59 | class ProbMask(nn.Module):
60 | """
61 |     A learnable probabilistic mask with the same shape as the kspace measurement.
62 |     This learned mask samples measurements in the whole kspace.
63 | """
64 | def __init__(self, shape=[320, 320], slope=5, preselect=False, preselect_num=0):
65 | """
66 |         shape ([int, int]): Shape of the reconstructed image
67 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to
68 | deterministic state.
69 | """
70 | super(ProbMask, self).__init__()
71 |
72 | self.slope = slope
73 | self.preselect = preselect
74 | self.preselect_num_one_side = preselect_num // 2
75 |
76 | init_tensor = self._slope_random_uniform(shape)
77 | self.mask = nn.Parameter(init_tensor)
78 |
79 | def forward(self, input):
80 | """
81 | Args:
82 | input (torch.Tensor): Input tensor of shape NHWC
83 |
84 | Returns:
85 | (torch.Tensor): Output tensor of shape NHWC
86 | """
87 | logits = self.mask.view(1, input.shape[1], input.shape[2], 1)
88 | return torch.sigmoid(self.slope * logits)
89 |
90 | def _slope_random_uniform(self, shape, eps=1e-2):
91 | """
92 | uniform random sampling mask
93 | """
94 | temp = torch.zeros(shape).uniform_(eps, 1-eps)
95 |
96 | # logit with slope factor
97 | logits = -torch.log(1./temp-1.) / self.slope
98 |
99 | logits = logits.reshape(1, shape[0], shape[1], 1)
100 |
101 |         if self.preselect_num_one_side > 0:  # guard: with 0, a "-0:" slice would cover the whole axis
102 |             s = self.preselect_num_one_side
103 |             logits[:, :s, :s] = logits[:, :s, -s:] = -1e2
104 |             logits[:, -s:, :s] = logits[:, -s:, -s:] = -1e2
105 |
106 | return logits
107 |
108 | class HalfProbMask(nn.Module):
109 | """
110 |     A learnable probabilistic mask with the same shape as half of the kspace measurement (to force the model
111 | to not take conjugate symmetry points)
112 | """
113 | def __init__(self, shape=[320, 320], slope=5):
114 | """
115 |         shape ([int, int]): Shape of the reconstructed image
116 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to
117 | deterministic state.
118 | """
119 | super(HalfProbMask, self).__init__()
120 |
121 | self.slope = slope
122 | init_tensor = self._slope_random_uniform(shape)
123 | self.mask = nn.Parameter(init_tensor)
124 |
125 | def forward(self, input):
126 | """
127 | Args:
128 | input (torch.Tensor): Input tensor of shape NHWC
129 |
130 | Returns:
131 | (torch.Tensor): Output tensor of shape NHWC
132 | """
133 | mask = torch.sigmoid(self.slope * self.mask).to(input.device).view(1, input.shape[1], input.shape[2]//2, 1) # only half of the kspace
134 |
135 | zero_mask = torch.zeros((1, input.shape[1], input.shape[2], 1))
136 | zero_mask[:, :, :input.shape[2]//2] = mask
137 |
138 | return zero_mask.to(input.device)
139 |
140 | def _slope_random_uniform(self, shape, eps=1e-2):
141 | """
142 | uniform random sampling mask with the shape as half of the kspace measurement
143 | """
144 | temp = torch.zeros([shape[0], shape[1]//2]).uniform_(eps, 1-eps)
145 |
146 | # logit with slope factor
147 | return -torch.log(1./temp-1.) / self.slope
148 |
149 | class LineConstrainedProbMask(nn.Module):
150 | """
151 |     A learnable probabilistic mask with the same shape as the kspace measurement.
152 |     The mask is constrained to include whole kspace lines in the readout direction.
153 | """
154 | def __init__(self, shape=[32], slope=5, preselect=False, preselect_num=2):
155 | super(LineConstrainedProbMask, self).__init__()
156 |
157 | if preselect:
158 | length = shape[0] - preselect_num
159 | else:
160 | length = shape[0]
161 |
162 | self.preselect_num = preselect_num
163 | self.preselect = preselect
164 | self.slope = slope
165 | init_tensor = self._slope_random_uniform(length)
166 | self.mask = nn.Parameter(init_tensor)
167 |
168 | def forward(self, input, eps=1e-10):
169 | """
170 | Args:
171 | input (torch.Tensor): Input tensor of shape NHWC
172 |
173 | Returns:
174 | (torch.Tensor): Output tensor of shape NHWC
175 | """
176 | logits = self.mask
177 | mask = torch.sigmoid(self.slope * logits).view(1, 1, self.mask.shape[0], 1)
178 |
179 | if self.preselect:
180 |             if self.preselect_num % 2 == 0:
181 | zeros = torch.zeros(1, 1, self.preselect_num // 2, 1).to(input.device)
182 | mask = torch.cat([zeros, mask, zeros], dim=2)
183 | else:
184 | raise NotImplementedError()
185 |
186 | return mask
187 |
188 | def _slope_random_uniform(self, shape, eps=1e-2):
189 | """
190 | uniform random sampling mask with the same shape as the kspace measurement
191 | """
192 | temp = torch.zeros(shape).uniform_(eps, 1-eps)
193 |
194 | # logit with slope factor
195 | return -torch.log(1./temp-1.) / self.slope
196 |
197 | class BiLineConstrainedProbMask(LineConstrainedProbMask):
198 | def __init__(self, shape, slope, preselect, preselect_num):
199 | super().__init__(shape=shape, slope=slope, preselect=preselect, preselect_num=preselect_num)
200 |
201 | def forward(self, input, eps=1e-10):
202 | """
203 | Args:
204 | input (torch.Tensor): Input tensor of shape NHWC
205 |
206 | Returns:
207 | (torch.Tensor): Output tensor of shape NHWC
208 | """
209 | logits = self.mask
210 | mask = torch.sigmoid(self.slope * logits).view(1, 1, self.mask.shape[0], 1)
211 |
212 | if self.preselect:
213 | zeros = torch.zeros(1, 1, self.preselect_num, 1).to(input.device)
214 | mask = torch.cat([zeros, mask], dim=2)
215 | return mask
216 |
217 | class HalfLineConstrainedProbMask(nn.Module):
218 | """
219 |     A learnable probabilistic mask with the same shape as half of the kspace measurement (to force the model
220 | to not take conjugate symmetry points).
221 | The mask is constrained to include whole kspace lines in the readout direction
222 | """
223 | def __init__(self, shape=[32], slope=5):
224 | """
225 |         shape ([int, int]): Shape of the reconstructed image
226 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to
227 | deterministic state.
228 | """
229 | super(HalfLineConstrainedProbMask, self).__init__()
230 |
231 | self.slope = slope
232 | init_tensor = self._slope_random_uniform(shape[0])
233 | self.mask = nn.Parameter(init_tensor)
234 |
235 | def forward(self, input):
236 | """
237 | Args:
238 | input (torch.Tensor): Input tensor of shape NHWC
239 |
240 | Returns:
241 | (torch.Tensor): Output tensor of shape NHWC
242 | """
243 | mask = torch.sigmoid(self.slope * self.mask).to(input.device).view(1, 1, input.shape[2]//2, 1) # only half of the kspace
244 |
245 | zero_mask = torch.zeros((1, 1, input.shape[2], 1))
246 | zero_mask[:, :, :input.shape[2]//2] = mask
247 |
248 | return zero_mask.to(input.device)
249 |
250 | def _slope_random_uniform(self, shape, eps=1e-2):
251 | """
252 | uniform random sampling mask with the shape as half of the kspace measurement
253 | """
254 | temp = torch.zeros(shape//2).uniform_(eps, 1-eps)
255 |
256 | # logit with slope factor
257 | return -torch.log(1./temp-1.) / self.slope
258 |
259 |
260 | def RescaleProbMap(batch_x, sparsity):
261 | """
262 | Rescale Probability Map
263 | given a prob map x, rescales it so that it obtains the desired sparsity
264 |
265 | if mean(x) > sparsity, then rescaling is easy: x' = x * sparsity / mean(x)
266 | if mean(x) < sparsity, one can basically do the same thing by rescaling
267 | (1-x) appropriately, then taking 1 minus the result.
268 | """
269 | batch_size = len(batch_x)
270 | ret = []
271 | for i in range(batch_size):
272 | x = batch_x[i:i+1]
273 | xbar = torch.mean(x)
274 | r = sparsity / (xbar)
275 | beta = (1-sparsity) / (1-xbar)
276 |
277 |         # compute adjustment
278 | le = torch.le(r, 1).float()
279 | ret.append(le * x * r + (1-le) * (1 - (1 - x) * beta))
280 |
281 | return torch.cat(ret, dim=0)
282 |
283 |
284 | class ThresholdRandomMaskSigmoidV1(Function):
285 | def __init__(self):
286 | """
287 | Straight through estimator.
288 | The forward step stochastically binarizes the probability mask.
289 |         The backward step estimates the non-differentiable > operator using a sigmoid with a large slope (10).
290 | """
291 | super(ThresholdRandomMaskSigmoidV1, self).__init__()
292 |
293 | @staticmethod
294 | def forward(ctx, input):
295 | batch_size = len(input)
296 | probs = []
297 | results = []
298 |
299 | for i in range(batch_size):
300 | x = input[i:i+1]
301 |
302 | count = 0
303 | while True:
304 | prob = x.new(x.size()).uniform_()
305 | result = (x > prob).float()
306 |
307 | if torch.isclose(torch.mean(result), torch.mean(x), atol=1e-3):
308 | break
309 |
310 | count += 1
311 |
312 | if count > 1000:
313 | print(torch.mean(prob), torch.mean(result), torch.mean(x))
314 | assert 0
315 |
316 | probs.append(prob)
317 | results.append(result)
318 |
319 | results = torch.cat(results, dim=0)
320 | probs = torch.cat(probs, dim=0)
321 | ctx.save_for_backward(input, probs)
322 |
323 | return results
324 |
325 | @staticmethod
326 | def backward(ctx, grad_output):
327 | slope = 10
328 | input, prob = ctx.saved_tensors
329 |
330 | # derivative of sigmoid function
331 | current_grad = slope * torch.exp(-slope * (input - prob)) / torch.pow((torch.exp(-slope*(input-prob))+1), 2)
332 |
333 | return current_grad * grad_output
334 |
335 | def MaximumBinarize(input):
336 | batch_size = len(input)
337 | results = []
338 |
339 | for i in range(batch_size):
340 | x = input[i:i+1]
341 | num = torch.sum(x).round().int()
342 |
343 | indices = torch.topk(x.reshape(-1), k=num)[1]
344 |
345 | mask = torch.zeros_like(x).reshape(-1)
346 |
347 | mask[indices] = 1
348 |
349 | mask = mask.reshape(*x.shape)
350 |
351 | results.append(mask)
352 |
353 | results = torch.cat(results, dim=0)
354 |
355 | return results
356 |
--------------------------------------------------------------------------------
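A short sketch (illustrative shapes) of how `RescaleProbMap` and the straight-through binarizer above compose: rescaling pins the mean of the probability map to the target sparsity, and the binarizer's forward pass resamples until the binary mask matches that mean to the 1e-3 tolerance coded in `forward`:

```python
import torch

from activemri.baselines.loupe_codes.layers import (
    RescaleProbMap,
    ThresholdRandomMaskSigmoidV1,
)

prob = torch.rand(1, 128, 128, 1)      # raw probability map, NHWC
rescaled = RescaleProbMap(prob, 0.25)  # mean is now exactly 0.25
binary = ThresholdRandomMaskSigmoidV1.apply(rescaled)
print(rescaled.mean().item(), binary.mean().item())  # both ~0.25
```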
/activemri/baselines/loupe_codes/reconstructors.py:
--------------------------------------------------------------------------------
1 | """
2 | Portion of this code is from fastmri(https://github.com/facebookresearch/fastMRI)
3 |
4 | Copyright (c) Facebook, Inc. and its affiliates.
5 |
6 | Licensed under the MIT License.
7 | """
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from activemri.baselines.loupe_codes.transforms import *
13 | from activemri.baselines.loupe_codes.layers import *
14 |
15 | class LOUPEUNet(nn.Module):
16 | """
17 | PyTorch implementation of a U-Net model.
18 | This is based on:
19 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks
20 | for biomedical image segmentation. In International Conference on Medical image
21 | computing and computer-assisted intervention, pages 234–241. Springer, 2015.
22 |
23 |     The model takes a real- or complex-valued image and uses a U-Net to denoise it.
24 |     A residual connection is applied to stabilize training.
25 | """
26 | def __init__(self,
27 | in_chans,
28 | out_chans,
29 | chans,
30 | num_pool_layers,
31 | drop_prob,
32 | # mask_length,
33 | bi_dir=False,
34 | old_recon=False,
35 | with_uncertainty=False):
36 | """
37 | Args:
38 | in_chans (int): Number of channels in the input to the U-Net model.
39 | out_chans (int): Number of channels in the output to the U-Net model.
40 | chans (int): Number of output channels of the first convolution layer.
41 | num_pool_layers (int): Number of down-sampling and up-sampling layers.
42 | drop_prob (float): Dropout probability.
43 | """
44 | super().__init__()
45 | self.old_recon = old_recon
46 | if old_recon:
47 | assert 0
48 | in_chans = in_chans+1 # add mask dim and old reconstruction dim
49 |
50 | self.with_uncertainty = with_uncertainty
51 |
52 | if with_uncertainty:
53 | out_chans = out_chans+1
54 |
55 | self.in_chans = in_chans
56 | self.out_chans = out_chans
57 | self.chans = chans
58 | self.num_pool_layers = num_pool_layers
59 | self.drop_prob = drop_prob
60 |
61 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
62 | ch = chans
63 | for i in range(num_pool_layers - 1):
64 | self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)]
65 | ch *= 2
66 | self.conv = ConvBlock(ch, ch, drop_prob)
67 |
68 | self.up_sample_layers = nn.ModuleList()
69 | for i in range(num_pool_layers - 1):
70 | self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, drop_prob)]
71 | ch //= 2
72 | self.up_sample_layers += [ConvBlock(ch * 2, ch, drop_prob)]
73 | self.conv2 = nn.Sequential(
74 | nn.Conv2d(ch, ch // 2, kernel_size=1),
75 | nn.Conv2d(ch // 2, out_chans, kernel_size=1),
76 | nn.Conv2d(out_chans, out_chans, kernel_size=1),
77 | )
78 |
79 | nn.init.normal_(self.conv2[-1].weight, mean=0, std=0.001)
80 | self.conv2[-1].bias.data.fill_(0)
81 |
82 | def forward(self, input, old_recon=None, eps=1e-8):
83 | # input: NCHW
84 | # output: NCHW
85 |
86 | if self.old_recon:
87 | assert 0
88 | output = torch.cat([input, old_recon], dim=1)
89 | else:
90 | output = input
91 |
92 | stack = []
93 | # print(input.shape, mask.shape)
94 |
95 | # Apply down-sampling layers
96 | for layer in self.down_sample_layers:
97 | output = layer(output)
98 | stack.append(output)
99 | output = F.max_pool2d(output, kernel_size=2)
100 |
101 | output = self.conv(output)
102 |
103 | # Apply up-sampling layers
104 | for layer in self.up_sample_layers:
105 | downsample_layer = stack.pop()
106 | layer_size = (downsample_layer.shape[-2], downsample_layer.shape[-1])
107 | output = F.interpolate(output, size=layer_size, mode='bilinear', align_corners=False)
108 | output = torch.cat([output, downsample_layer], dim=1)
109 | output = layer(output)
110 |
111 | out_conv2 = self.conv2(output)
112 |
113 | img_residual = out_conv2[:, :1]
114 |
115 | if self.with_uncertainty:
116 | map = out_conv2[:, 1:]
117 | else:
118 | map = torch.zeros_like(out_conv2)
119 |
120 | if self.old_recon:
121 | return img_residual + old_recon
122 | else:
123 | return img_residual + torch.norm(input, dim=1, keepdim=True)
--------------------------------------------------------------------------------
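Shape sketch for `LOUPEUNet` above (illustrative values): a two-channel zero-filled complex image goes in, and the network returns its one-channel residual added to the input magnitude:

```python
import torch

from activemri.baselines.loupe_codes.reconstructors import LOUPEUNet

net = LOUPEUNet(in_chans=2, out_chans=1, chans=64, num_pool_layers=4, drop_prob=0.0)
zero_filled = torch.rand(4, 2, 128, 128)  # NCHW zero-filled recon (real/imag)
recon = net(zero_filled)
print(recon.shape)  # torch.Size([4, 1, 128, 128])
```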
/activemri/baselines/loupe_codes/samplers.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from activemri.baselines.loupe_codes.layers import *
6 |
7 | from activemri.baselines.loupe_codes import transforms
8 |
9 | from torch.autograd import Function
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import sigpy
13 | import sigpy.mri
14 |
15 | class LOUPESampler(nn.Module):
16 | """
17 | LOUPE Sampler
18 | """
19 | def __init__(self, shape=[320, 320], slope=5, sparsity=0.25, line_constrained=False,
20 | conjugate_mask=False, preselect=False, preselect_num=2, random=False, poisson=False,
21 | spectrum=False, equispaced=False):
22 | """
23 |         shape ([int, int]): Shape of the reconstructed image
24 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to
25 | deterministic state.
26 | sparsity (float): Predefined sparsity of the learned probability mask. 1 / acceleration_ratio
27 | line_constrained (bool): Sample kspace measurements column by column
28 | conjugate_mask (bool): For real image, the corresponding kspace measurements have conjugate symmetry property
29 | (point reflection). Therefore, the information in the left half of the kspace image is the same as the
30 | other half. To take advantage of this, we can force the model to only sample right half of the kspace
31 | (when conjugate_mask is set to True)
32 | preselect: preselect center regions
33 | """
34 | super().__init__()
35 |
36 | assert conjugate_mask is False
37 |
38 | # probability mask
39 | if line_constrained:
40 | self.gen_mask = LineConstrainedProbMask(shape, slope, preselect=preselect, preselect_num=preselect_num)
41 | else:
42 | self.gen_mask = ProbMask(shape, slope, preselect_num=preselect_num)
43 |
44 | self.rescale = RescaleProbMap
45 | self.binarize = ThresholdRandomMaskSigmoidV1.apply # FIXME
46 |
47 |         self.preselect = preselect
48 | self.preselect_num_one_side = preselect_num // 2
49 | self.shape = shape
50 | self.line_constrained = line_constrained
51 | self.random_baseline = random
52 | self.poisson_baseline = poisson
53 | self.spectrum_baseline = spectrum
54 | self.equispaced_baseline = equispaced
55 |
56 | if self.poisson_baseline:
57 | self.acc = 1 / (sparsity + (self.preselect_num_one_side*2)**2 / (128*128))
58 | print("generate variable density mask with acceleration {}".format(self.acc))
59 |
60 | if self.spectrum_baseline:
61 | acc = 1 / (sparsity + (self.preselect_num_one_side*2)**2 / (128*128))
62 | print("generate spectrum mask with acceleration {}".format(acc))
63 | mask = torch.load('resources/spectrum_{}x_128.pt'.format(int(acc)))
64 | mask = mask.reshape(1, 128, 128, 1).float()
65 | self.spectrum_mask = nn.Parameter(mask, requires_grad=False)
66 |
67 | if self.equispaced_baseline:
68 | acc = 1 / (sparsity + (self.preselect_num_one_side)*2 / (128))
69 |
70 | assert self.line_constrained
71 | assert acc == 4
72 | mask = torch.load('resources/equispaced_4x_128.pt').reshape(1, 128, 128, 1).float()
73 | self.equispaced_mask = nn.Parameter(mask, requires_grad=False)
74 |
75 | def _gen_poisson_mask(self):
76 | mask = sigpy.mri.poisson((128, 128), self.acc, dtype='int32', crop_corner=False)
77 | mask = torch.tensor(mask).reshape(1, 128, 128, 1).float()
78 | return mask
79 |
80 | def _mask_neg_entropy(self, mask, eps=1e-10):
81 | # negative of pixel wise entropy
82 | entropy = mask * torch.log(mask+eps) + (1-mask) * torch.log(1-mask+eps)
83 | return entropy
84 |
85 | def forward(self, kspace, sparsity):
86 | # kspace: NHWC
87 | # sparsity (float)
88 | prob_mask = self.gen_mask(kspace)
89 |
90 | if self.random_baseline:
91 | prob_mask = torch.ones_like(prob_mask) / 4
92 |
93 | if not self.line_constrained:
94 | prob_mask[:, :self.preselect_num_one_side, :self.preselect_num_one_side] = 0
95 | prob_mask[:, :self.preselect_num_one_side, -self.preselect_num_one_side:] = 0
96 | prob_mask[:, -self.preselect_num_one_side:, :self.preselect_num_one_side] = 0
97 | prob_mask[:, -self.preselect_num_one_side:, -self.preselect_num_one_side:] = 0
98 | else:
99 | prob_mask[..., :self.preselect_num_one_side, :] = 0
100 | prob_mask[..., -self.preselect_num_one_side:, :] = 0
101 |
102 |
103 | if not self.preselect:
104 | rescaled_mask = self.rescale(prob_mask, sparsity)
105 | binarized_mask = self.binarize(rescaled_mask)
106 | else:
107 | rescaled_mask = self.rescale(prob_mask, sparsity)
108 | if self.training:
109 | binarized_mask = self.binarize(rescaled_mask)
110 | else:
111 | binarized_mask = self.binarize(rescaled_mask)
112 |
113 | if not self.line_constrained:
114 | binarized_mask[:, :self.preselect_num_one_side, :self.preselect_num_one_side] = 1
115 | binarized_mask[:, :self.preselect_num_one_side, -self.preselect_num_one_side:] = 1
116 | binarized_mask[:, -self.preselect_num_one_side:, :self.preselect_num_one_side] = 1
117 | binarized_mask[:, -self.preselect_num_one_side:, -self.preselect_num_one_side:] = 1
118 | else:
119 | binarized_mask[..., :self.preselect_num_one_side, :] = 1
120 | binarized_mask[..., -self.preselect_num_one_side:, :] = 1
121 |
122 | neg_entropy = self._mask_neg_entropy(rescaled_mask)
123 |
124 | if self.poisson_baseline:
125 | assert not self.line_constrained
126 | binarized_mask = transforms.fftshift(self._gen_poisson_mask(), dim=(1,2)).to(kspace.device)
127 |
128 | if self.spectrum_baseline:
129 | assert not self.line_constrained
130 | binarized_mask = self.spectrum_mask # DC are in the corners
131 |
132 | if self.equispaced_baseline:
133 | binarized_mask = transforms.fftshift(self.equispaced_mask, dim=(1, 2))
134 |
135 | masked_kspace = binarized_mask * kspace
136 |
137 | data_to_vis_sampler = {'prob_mask': transforms.fftshift(prob_mask[0,:,:,0],dim=(0,1)).cpu().detach().numpy(),
138 | 'rescaled_mask': transforms.fftshift(rescaled_mask[0,:,:,0],dim=(0,1)).cpu().detach().numpy(),
139 | 'binarized_mask': transforms.fftshift(binarized_mask[0,:,:,0],dim=(0,1)).cpu().detach().numpy()}
140 |
141 | return masked_kspace, binarized_mask, neg_entropy, data_to_vis_sampler
142 |
143 | class BiLOUPESampler(nn.Module):
144 | """
145 | LOUPE Sampler
146 | """
147 | def __init__(self, shape=[320, 320], slope=5, sparsity=0.25, line_constrained=False,
148 | conjugate_mask=False, preselect=False, preselect_num=2):
149 | """
150 |         shape ([int, int]): Shape of the reconstructed image
151 | slope (float): Slope for the Loupe probability mask. Larger slopes make the mask converge faster to
152 | deterministic state.
153 | sparsity (float): Predefined sparsity of the learned probability mask. 1 / acceleration_ratio
154 | line_constrained (bool): Sample kspace measurements column by column
155 | conjugate_mask (bool): For real image, the corresponding kspace measurements have conjugate symmetry property
156 | (point reflection). Therefore, the information in the left half of the kspace image is the same as the
157 | other half. To take advantage of this, we can force the model to only sample right half of the kspace
158 | (when conjugate_mask is set to True)
159 | preselect: preselect center regions
160 | """
161 | super().__init__()
162 |
163 | assert conjugate_mask is False
164 | assert line_constrained
165 |
166 | # probability mask
167 | if line_constrained:
168 | self.gen_mask = BiLineConstrainedProbMask([shape[0]*2], slope, preselect=preselect, preselect_num=preselect_num)
169 | else:
170 | assert 0
171 | self.gen_mask = ProbMask(shape, slope)
172 |
173 | self.rescale = RescaleProbMap
174 | self.binarize = ThresholdRandomMaskSigmoidV1.apply # FIXME
175 |
176 |         self.preselect = preselect
177 | self.preselect_num = preselect_num
178 | self.shape = shape
179 |
180 | def _mask_neg_entropy(self, mask, eps=1e-10):
181 | # negative of pixel wise entropy
182 | entropy = mask * torch.log(mask+eps) + (1-mask) * torch.log(1-mask+eps)
183 | return entropy
184 |
185 | def forward(self, kspace, sparsity):
186 | # kspace: NHWC
187 | # sparsity (float)
188 | prob_mask = self.gen_mask(kspace)
189 |
190 | batch_size = kspace.shape[0]
191 |
192 | if not self.preselect:
193 | assert 0
194 | else:
195 | rescaled_mask = self.rescale(prob_mask, sparsity/2)
196 | if self.training:
197 | binarized_mask = self.binarize(rescaled_mask)
198 | else:
199 | binarized_mask = self.binarize(rescaled_mask)
200 |
201 | # always preselect vertical lines
202 | binarized_vertical_mask, binarized_horizontal_mask = torch.chunk(binarized_mask, dim=2, chunks=2)
203 |
204 | binarized_horizontal_mask = binarized_horizontal_mask.transpose(1, 2)
205 |
206 | binarized_mask = torch.clamp(binarized_vertical_mask + binarized_horizontal_mask, max=1, min=0)
207 |
208 | binarized_mask[..., :self.preselect_num, :] = 1
209 |
210 | masked_kspace = binarized_mask * kspace
211 | neg_entropy = self._mask_neg_entropy(rescaled_mask)
212 |
213 | # for visualization purpose
214 | vertical_mask, horizontal_mask = torch.chunk(prob_mask.reshape(1, -1), dim=-1, chunks=2)
215 |         prob_mask = vertical_mask.reshape(1, 1, 1, -1) + horizontal_mask.reshape(1, 1, -1, 1)
216 | 
217 |         rescaled_vertical_mask, rescaled_horizontal_mask = torch.chunk(rescaled_mask.reshape(1, -1), dim=-1, chunks=2)
218 |         rescaled_mask = rescaled_vertical_mask.reshape(1, 1, 1, -1) + rescaled_horizontal_mask.reshape(1, 1, -1, 1)
219 |
220 |
221 |         data_to_vis_sampler = {'prob_mask': transforms.fftshift(prob_mask[0, 0], dim=(0, 1)).cpu().detach().numpy(),
222 |                                'rescaled_mask': transforms.fftshift(rescaled_mask[0, 0], dim=(0, 1)).cpu().detach().numpy(),
223 |                                'binarized_mask': transforms.fftshift(binarized_mask[0, :, :, 0], dim=(0, 1)).cpu().detach().numpy()}
224 |
225 | return masked_kspace, binarized_mask, neg_entropy, data_to_vis_sampler
226 |
--------------------------------------------------------------------------------
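
Note: RescaleProbMap and ThresholdRandomMaskSigmoidV1, used by the samplers above, live in loupe_codes/layers.py, which is not reproduced in this dump. The following is a minimal sketch of what the two steps typically look like in LOUPE-style samplers; the names rescale_prob_map and BinarizeST are illustrative, and the exact surrogate gradient in layers.py may differ.

import torch

def rescale_prob_map(x, sparsity):
    # standard LOUPE rescaling: move the mean of the probability map to the
    # target sparsity while keeping every value inside [0, 1]
    xbar = x.mean()
    ratio = sparsity / xbar
    beta = (1 - sparsity) / (1 - xbar)
    le = (ratio <= 1).float()
    return le * x * ratio + (1 - le) * (1 - (1 - x) * beta)

class BinarizeST(torch.autograd.Function):
    # stochastic binarization with a sigmoid-shaped surrogate gradient
    @staticmethod
    def forward(ctx, prob, slope=5.0):
        thresh = torch.rand_like(prob)
        ctx.save_for_backward(prob, thresh)
        ctx.slope = slope
        return (prob > thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        prob, thresh = ctx.saved_tensors
        s = torch.sigmoid(ctx.slope * (prob - thresh))
        return grad_output * ctx.slope * s * (1 - s), None

prob = torch.rand(1, 32, 32, 1)
rescaled = rescale_prob_map(prob, sparsity=0.25)
mask = BinarizeST.apply(rescaled)  # roughly 25% ones in expectation
print(rescaled.mean().item(), mask.mean().item())
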
/activemri/baselines/loupe_codes/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import numpy as np
9 | import torch
11 |
12 |
13 | def to_tensor(data):
14 | """
15 | Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts
16 | are stacked along the last dimension.
17 |
18 | Args:
19 | data (np.array): Input numpy array
20 |
21 | Returns:
22 | torch.Tensor: PyTorch version of data
23 | """
24 | if np.iscomplexobj(data):
25 | data = np.stack((data.real, data.imag), axis=-1)
26 | return torch.from_numpy(data)
27 |
28 |
29 | def apply_mask(data, mask_func, seed=None):
30 | """
31 | Subsample given k-space by multiplying with a mask.
32 |
33 | Args:
34 | data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where
35 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size
36 | 2 (for complex values).
37 | mask_func (callable): A function that takes a shape (tuple of ints) and a random
38 | number seed and returns a mask.
39 | seed (int or 1-d array_like, optional): Seed for the random number generator.
40 |
41 | Returns:
42 | (tuple): tuple containing:
43 | masked data (torch.Tensor): Subsampled k-space data
44 | mask (torch.Tensor): The generated mask
45 | """
46 | shape = np.array(data.shape)
47 | shape[:-3] = 1
48 | mask = mask_func(shape, seed).to(data.device)
49 | return torch.where(mask == 0, torch.Tensor([0]).to(data.device), data), mask
50 |
51 |
52 | def fft2(data):
53 | """
54 |     Apply a 2-dimensional Fast Fourier Transform (not centered; callers apply the shifts as needed).
55 |
56 | Args:
57 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
58 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
59 | assumed to be batch dimensions.
60 |
61 | Returns:
62 | torch.Tensor: The FFT of the input.
63 | """
64 | assert data.size(-1) == 2
65 | data = data.fft(2, normalized=False)
66 | return data
67 |
68 |
69 | def ifft2(data):
70 | """
71 |     Apply a 2-dimensional Inverse Fast Fourier Transform (not centered; callers apply the shifts as needed).
72 |
73 | Args:
74 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
75 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
76 | assumed to be batch dimensions.
77 |
78 | Returns:
79 | torch.Tensor: The IFFT of the input.
80 | """
81 | assert data.size(-1) == 2
82 | data = torch.ifft(data, 2, normalized=False)
83 | return data
84 |
85 |
86 | def complex_abs(data):
87 | """
88 | Compute the absolute value of a complex valued input tensor.
89 |
90 | Args:
91 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension
92 | should be 2.
93 |
94 | Returns:
95 | torch.Tensor: Absolute value of data
96 | """
97 | assert data.size(-1) == 2
98 | return (data ** 2).sum(dim=-1).sqrt()
99 |
100 |
101 | def root_sum_of_squares(data, dim=0):
102 | """
103 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor.
104 |
105 | Args:
106 | data (torch.Tensor): The input tensor
107 |         dim (int): The dimension along which to apply the RSS transform
108 |
109 | Returns:
110 | torch.Tensor: The RSS value
111 | """
112 | return torch.sqrt((data ** 2).sum(dim))
113 |
114 |
115 | def center_crop(data, shape):
116 | """
117 | Apply a center crop to the input real image or batch of real images.
118 |
119 | Args:
120 | data (torch.Tensor): The input tensor to be center cropped. It should have at
121 | least 2 dimensions and the cropping is applied along the last two dimensions.
122 | shape (int, int): The output shape. The shape should be smaller than the
123 | corresponding dimensions of data.
124 |
125 | Returns:
126 | torch.Tensor: The center cropped image
127 | """
128 | assert 0 < shape[0] <= data.shape[-2]
129 | assert 0 < shape[1] <= data.shape[-1]
130 | w_from = (data.shape[-2] - shape[0]) // 2
131 | h_from = (data.shape[-1] - shape[1]) // 2
132 | w_to = w_from + shape[0]
133 | h_to = h_from + shape[1]
134 | return data[..., w_from:w_to, h_from:h_to]
135 |
136 |
137 | def complex_center_crop(data, shape):
138 | """
139 | Apply a center crop to the input image or batch of complex images.
140 |
141 | Args:
142 | data (torch.Tensor): The complex input tensor to be center cropped. It should
143 | have at least 3 dimensions and the cropping is applied along dimensions
144 |             -3 and -2 and the last dimension should have size 2.
145 | shape (int, int): The output shape. The shape should be smaller than the
146 | corresponding dimensions of data.
147 |
148 | Returns:
149 | torch.Tensor: The center cropped image
150 | """
151 | assert 0 < shape[0] <= data.shape[-3]
152 | assert 0 < shape[1] <= data.shape[-2]
153 | w_from = (data.shape[-3] - shape[0]) // 2
154 | h_from = (data.shape[-2] - shape[1]) // 2
155 | w_to = w_from + shape[0]
156 | h_to = h_from + shape[1]
157 | return data[..., w_from:w_to, h_from:h_to, :]
158 |
159 |
160 | def normalize(data, mean, stddev, eps=0.):
161 | """
162 | Normalize the given tensor using:
163 | (data - mean) / (stddev + eps)
164 |
165 | Args:
166 | data (torch.Tensor): Input data to be normalized
167 | mean (float): Mean value
168 | stddev (float): Standard deviation
169 | eps (float): Added to stddev to prevent dividing by zero
170 |
171 | Returns:
172 | torch.Tensor: Normalized tensor
173 | """
174 | return (data - mean) / (stddev + eps)
175 |
176 |
177 | def normalize_instance(data, eps=0.):
178 | """
179 | Normalize the given tensor using:
180 | (data - mean) / (stddev + eps)
181 | where mean and stddev are computed from the data itself.
182 |
183 | Args:
184 | data (torch.Tensor): Input data to be normalized
185 | eps (float): Added to stddev to prevent dividing by zero
186 |
187 | Returns:
188 | torch.Tensor: Normalized tensor
189 | """
190 | mean = data.mean()
191 | std = data.std()
192 | return normalize(data, mean, std, eps), mean, std
193 |
194 |
195 | # Helper functions
196 |
197 | def roll(x, shift, dim):
198 | """
199 | Similar to np.roll but applies to PyTorch Tensors
200 | """
201 | if isinstance(shift, (tuple, list)):
202 | assert len(shift) == len(dim)
203 | for s, d in zip(shift, dim):
204 | x = roll(x, s, d)
205 | return x
206 | shift = shift % x.size(dim)
207 | if shift == 0:
208 | return x
209 | left = x.narrow(dim, 0, x.size(dim) - shift)
210 | right = x.narrow(dim, x.size(dim) - shift, shift)
211 | return torch.cat((right, left), dim=dim)
212 |
213 |
214 | def fftshift(x, dim=None):
215 | """
216 | Similar to np.fft.fftshift but applies to PyTorch Tensors
217 | """
218 | if dim is None:
219 | dim = tuple(range(x.dim()))
220 |         shift = [d // 2 for d in x.shape]  # d avoids shadowing the dim argument
221 | elif isinstance(dim, int):
222 | shift = x.shape[dim] // 2
223 | else:
224 | shift = [x.shape[i] // 2 for i in dim]
225 | return roll(x, shift, dim)
226 |
227 |
228 | def ifftshift(x, dim=None):
229 | """
230 | Similar to np.fft.ifftshift but applies to PyTorch Tensors
231 | """
232 | if dim is None:
233 | dim = tuple(range(x.dim()))
234 |         shift = [(d + 1) // 2 for d in x.shape]  # d avoids shadowing the dim argument
235 | elif isinstance(dim, int):
236 | shift = (x.shape[dim] + 1) // 2
237 | else:
238 | shift = [(x.shape[i] + 1) // 2 for i in dim]
239 | return roll(x, shift, dim)
240 |
241 |
242 | def to_kspace(x):
243 |     image = torch.cat([x, torch.zeros_like(x)], dim=-1)  # real image -> complex tensor with zero imaginary part
244 | kspace = fft2(image)
245 |
246 | return kspace
247 |
248 | def rl_input_to_loupe_input(x, kspace):
249 | assert x.size(-1) == 2
250 | assert kspace.size(-1) == 2
251 |
252 |     x = x[..., 0:1].permute((0, 3, 1, 2))
253 | kspace = fft2(fftshift(ifft2(kspace), dim=(1, 2)))
254 |
255 | return x, kspace
256 |
257 | def add_gaussian_noise(args, kspace, mean=0., std=1.):
258 | # kspace: 32*32*2
259 |     noise = mean + torch.randn(kspace.size(), device=kspace.device) * std  # sample noise on the same device as kspace
260 | kspace = kspace + noise
261 | return kspace
262 |
263 | def data_consistency(recon, zero_filled_recon, mask):
264 |     mask = mask.permute(0, 2, 3, 1)
265 |     pred_kspace = ifftshift(recon.permute(0, 2, 3, 1), dim=(1, 2)).fft(2, normalized=False)
266 |     fuse_kspace = torch.where((1-mask).byte(), pred_kspace, torch.tensor(0.0).to(mask.device))  # keep predictions only at unsampled locations
267 |     new_img = fftshift(fuse_kspace.ifft(2, normalized=False), dim=(1, 2)).permute(0, -1, 1, 2)
268 | 
269 |     return new_img + zero_filled_recon  # the zero-filled recon contributes the measured values at sampled locations
270 |
--------------------------------------------------------------------------------
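
A quick sanity check (not part of the repo) for the roll-based shift helpers above: they should agree with their NumPy counterparts, and ifftshift should invert fftshift even on odd-sized axes, where the two shifts differ by one element.

import numpy as np
import torch

from activemri.baselines.loupe_codes.transforms import fftshift, ifftshift

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
assert np.array_equal(fftshift(x, dim=(0, 1)).numpy(),
                      np.fft.fftshift(x.numpy(), axes=(0, 1)))
assert np.array_equal(ifftshift(x, dim=(0, 1)).numpy(),
                      np.fft.ifftshift(x.numpy(), axes=(0, 1)))

# for odd lengths fftshift shifts by n // 2 and ifftshift by (n + 1) // 2,
# so the round trip is the identity
y = torch.arange(5, dtype=torch.float32)
assert torch.equal(ifftshift(fftshift(y, dim=0), dim=0), y)
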
/activemri/baselines/non_rl.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch, argparse
3 | from typing import List
4 | import activemri.envs.envs as mri_envs
5 | import numpy as np
6 | import logging
7 | import os
8 | from torch.nn import functional as F
9 | from tensorboardX import SummaryWriter
10 | import time
11 | from .loupe_codes.evaluate import Metrics, METRIC_FUNCS
12 | from .loupe_codes import transforms
13 | import activemri.envs.loupe_envs as loupe_envs
14 | import pickle
15 | from .differential_sampler import SequentialUnet
16 | from .loupe import LOUPE
17 | import random
18 | import shutil
19 | import matplotlib.pyplot as plt
20 | from ..envs.util import compute_gaussian_nll_loss, compute_ssim_torch, compute_psnr_torch
21 |
22 | def build_optimizer(lr, weight_decay, model_parameters, type='adam'):
23 | if type == 'adam':
24 | optimizer = torch.optim.Adam(model_parameters, lr, weight_decay=weight_decay)
25 | elif type == 'sgd':
26 | optimizer = torch.optim.SGD(model_parameters, lr, weight_decay=weight_decay)
27 | else:
28 | raise NotImplementedError()
29 |
30 | return optimizer
31 |
32 | def build_lr_scheduler(options, optimizer):
33 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, options.lr_step_size, options.lr_gamma)
34 | return scheduler
35 |
36 | def build_model(options):
37 | if options.model == 'LOUPE':
38 | model = LOUPE(
39 | in_chans=options.input_chans,
40 | out_chans=options.output_chans,
41 | chans=options.num_chans,
42 | num_pool_layers=options.num_pools,
43 | drop_prob=options.drop_prob,
44 | sparsity=1.0/options.accelerations[0],
45 | shape=options.resolution,
46 | conjugate_mask=options.conjugate_mask,
47 | line_constrained=options.line_constrained,
48 | bi_dir=options.bi_dir,
49 | preselect_num=options.preselect_num,
50 | preselect=options.preselect,
51 |             random=getattr(options, 'random', False),
52 |             poisson=getattr(options, 'poisson', False),
53 |             spectrum=getattr(options, 'spectrum', False),
54 |             equispaced=getattr(options, 'equispaced', False)).to(options.device)
55 | elif options.model == 'SequentialSampling':
56 | model = SequentialUnet(
57 | in_chans=options.input_chans,
58 | out_chans=options.output_chans,
59 | chans=options.num_chans,
60 | num_pool_layers=options.num_pools,
61 | drop_prob=options.drop_prob,
62 | sparsity=1.0/options.accelerations[0],
63 | shape=options.resolution,
64 | conjugate_mask=options.conjugate_mask,
65 | num_step=options.num_step,
66 | preselect=options.preselect,
67 | bi_direction=options.bi_dir,
68 | preselect_num=options.preselect_num,
69 | binary_sampler=options.binary_sampler,
70 | clamp=options.clamp,
71 | line_constrained=options.line_constrained,
72 | old_recon=options.old_recon,
73 | with_uncertainty=options.uncertainty_loss,
74 | detach_kspace=options.detach_kspace,
75 | fixed_input=options.fixed_input,
76 |             pretrained_recon=getattr(options, 'pretrained_recon', False)).to(options.device)
77 | else:
78 | raise NotImplementedError()
79 |
80 | return model
81 |
82 | def get_mask_stats(masks):
83 | masks = np.array(masks)
84 |
85 | return np.mean(masks, axis=0), np.std(masks, axis=0)
86 |
87 |
88 | class NonRLTrainer:
89 |     """Differentiable sampler trainer for active MRI acquisition.
90 | 
91 |     Configuration for the trainer is provided by the argument ``options``.
93 |
94 | Args:
95 | options(``argparse.Namespace``): Options for the trainer.
96 | env(``activemri.envs.ActiveMRIEnv``): Env for which the policy is trained.
97 | device(``torch.device``): Device to use.
98 | """
99 | def __init__(
100 | self,
101 | options: argparse.Namespace,
102 | env: mri_envs.ActiveMRIEnv,
103 | device: torch.device
104 | ):
105 | self.options = options
106 | self.env = env
107 | self.device = device
108 | self.model = build_model(self.options)
109 | if options.data_parallel:
110 | self.model = torch.nn.DataParallel(self.model)
111 |
112 | self.optimizer = build_optimizer(self.options.lr, self.options.weight_decay, self.model.parameters())
113 |
114 | self.train_loader, self.dev_loader, self.display_loader, _ = self.env._setup_data_handlers()
115 |
116 | self.scheduler = build_lr_scheduler(self.options, self.optimizer)
117 | self.best_dev_loss = np.inf
118 | self.epoch = 0
119 | self.start_epoch = 0
120 | self.end_epoch = options.num_epochs
121 |
122 | # setup saving, writer, and logging
123 | options.exp_dir.mkdir(parents=True, exist_ok=True)
124 |
125 | with open(os.path.join(str(options.exp_dir), 'args.pkl'), "wb") as f:
126 | pickle.dump(options.__dict__, f)
127 |
128 | options.visualization_dir.mkdir(parents=True, exist_ok=True)
129 | self.writer = SummaryWriter(log_dir=options.exp_dir / 'summary')
130 |
131 | logging.basicConfig(level=logging.INFO)
132 | self.logger = logging.getLogger(__name__)
133 | self.logger.info(self.options)
134 | self.logger.info(self.model)
135 |
136 | def load_checkpoint_if_needed(self):
137 | if self.options.resume:
138 | self.load()
139 |
140 | def evaluate(self):
141 | self.model.eval()
142 | losses = []
143 | sparsity = []
144 | targets, preds = [], []
145 | metrics = Metrics(METRIC_FUNCS)
146 | start = time.perf_counter()
147 |
148 | with torch.no_grad():
149 | for iter, data in enumerate(self.dev_loader):
150 | # input: [batch_size, num_channels, height, width] denoted as NCHW in other places
151 | # label: label of the current image (0~9 for mnist/fashion-mnist) default: -1
152 | # target: a copy of the input image for computing reconstruction loss in [NCHW]
153 | kspace, _, input, label, *ignored = data
154 |
155 | # adapt data to loupe
156 | target = input.clone().detach()
157 | target = transforms.complex_abs(target).unsqueeze(1)
158 |
159 | input = input.to(self.options.device)
160 | target = target.to(self.options.device)
161 | kspace = kspace.to(self.options.device)
162 | # label = label.to(self.options.device)
163 |
164 | pred_dict = self.model(target, kspace)
165 |
166 |                 if iter == 0:  # (self.epoch+1) % 1 == 0 is always true, so this visualizes once per epoch
167 | data_for_vis_name = 'eval_epoch=' + str(self.epoch+1)
168 | self.model.visualize_and_save(self.options, self.epoch, data_for_vis_name)
169 |
170 | output = pred_dict['output']
171 | # only use the last reconstructed image to compute loss
172 | if isinstance(output, list):
173 | output = output[-1]
174 |
175 | target_dict = {'target': target, 'label': label, 'kspace':kspace}
176 | meta = {'entropy_weight': self.options.entropy_weight, 'recon_weight': self.options.recon_weight,
177 | 'uncertainty_weight': 0, 'kspace_weight': self.options.kspace_weight}
178 |
179 | loss, log_dict = self.model.loss(pred_dict, target_dict, meta, self.options.loss_type)
180 |
181 | mask = pred_dict['mask']
182 | sparsity.append(torch.mean(mask).item())
183 | losses.append(loss.item())
184 |
185 | # target: 16*1*32*32
186 | # output: 16*1*32*32
187 | target = target.cpu().numpy()
188 | pred = output.cpu().numpy()
189 |
190 | for t, p in zip(target, pred):
191 | metrics.push(t, p)
192 |
193 | print(metrics)
194 | self.writer.add_scalar('Dev_MSE', metrics.means()['MSE'], self.epoch)
195 | self.writer.add_scalar('Dev_NMSE', metrics.means()['NMSE'], self.epoch)
196 | self.writer.add_scalar('Dev_PSNR', metrics.means()['PSNR'], self.epoch)
197 | self.writer.add_scalar('Dev_SSIM', metrics.means()['SSIM'], self.epoch)
198 |
199 | self.writer.add_scalar('Dev_Loss', np.mean(losses), self.epoch)
200 |
201 | return np.mean(losses), np.mean(sparsity), time.perf_counter() - start
202 |
203 | def train_epoch(self):
204 | self.model.train()
205 | losses = []
206 | targets, preds = [], []
207 | metrics = Metrics(METRIC_FUNCS)
208 | avg_loss = 0.
209 | start_epoch = start_iter = time.perf_counter()
210 | global_step = self.epoch * len(self.train_loader)
211 |
212 | for iter, data in enumerate(self.train_loader):
213 | # self.scheduler.step()
214 | # input: [batch_size, num_channels, height, width] denoted as NCHW in other places
215 | # label: label of the current image (0~9 for mnist/fashion-mnist) default: -1
216 | # target: a copy of the input image for computing reconstruction loss in [NCHW]
217 |             kspace, _, input, label, *ignored = data
218 |
219 | # adapt data to loupe
220 | target = input.clone().detach()
221 | target = transforms.complex_abs(target).unsqueeze(1)
222 |
223 | input = input.to(self.options.device)
224 | target = target.to(self.options.device)
225 | kspace = kspace.to(self.options.device)
226 | # label = label.to(self.options.device)
227 |
228 | """if self.options.noise_type == 'gaussian':
229 | kspace = transforms.add_gaussian_noise(self.options, kspace, mean=0., std=self.options.noise_level)
230 | """
231 |
232 | pred_dict = self.model(target, kspace)
233 |
234 |             if iter == 0:  # (self.epoch+1) % 1 == 0 is always true, so this visualizes once per epoch
235 | data_for_vis_name = 'train_epoch={}_iter={}'.format(str(self.epoch+1), str(iter+1))
236 | self.model.visualize_and_save(self.options, self.epoch, data_for_vis_name)
237 |
238 | output = pred_dict['output']
239 | target_dict = {'target': target, 'label': label, 'kspace': kspace}
240 | meta = {'entropy_weight': self.options.entropy_weight, 'recon_weight': self.options.recon_weight,
241 | 'kspace_weight': self.options.kspace_weight,
242 |                     'uncertainty_weight': getattr(self.options, 'uncertainty_weight', 0)}
243 |
244 | loss, log_dict = self.model.loss(pred_dict, target_dict, meta, self.options.loss_type)
245 |
246 | self.optimizer.zero_grad()
247 | loss.backward()
248 | # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
249 |
250 | self.optimizer.step()
251 |
252 | self.writer.add_scalar('Train_Loss', loss.item(), global_step + iter)
253 |
254 | losses.append(loss.item())
255 |
256 | # target: 16*1*32*32
257 | # output: 16*1*32*32
258 |
259 | if isinstance(output, list):
260 | output = output[-1]
261 |
262 | target = target.cpu().detach().numpy()
263 | pred = output.cpu().detach().numpy()
264 |
265 | if iter % self.options.report_interval == 0:
266 | self.logger.info(
267 | f'Epoch = [{1 + self.epoch:3d}/{self.options.num_epochs:3d}] '
268 | f'Iter = [{iter:4d}/{len(self.train_loader):4d}] '
269 | f'Time = {time.perf_counter() - start_iter:.4f}s',
270 | )
271 | for key, val in log_dict.items():
272 | print('{} = {}'.format(key, val))
273 |
274 | start_iter = time.perf_counter()
275 |
276 | for t, p in zip(target, pred):
277 | metrics.push(t, p)
278 |
279 | print(metrics)
280 | self.writer.add_scalar('Train_MSE', metrics.means()['MSE'], self.epoch)
281 | self.writer.add_scalar('Train_NMSE', metrics.means()['NMSE'], self.epoch)
282 | self.writer.add_scalar('Train_PSNR', metrics.means()['PSNR'], self.epoch)
283 | self.writer.add_scalar('Train_SSIM', metrics.means()['SSIM'], self.epoch)
284 |
285 | return np.mean(np.array(losses)), time.perf_counter() - start_epoch
286 |
287 | def _train_loupe(self):
288 | for epoch in range(self.start_epoch, self.end_epoch):
289 | self.epoch = epoch
290 | train_loss, train_time = self.train_epoch()
291 | self.scheduler.step(epoch)
292 | dev_loss, mean_sparsity, dev_time = self.evaluate()
293 |
294 | is_new_best = dev_loss < self.best_dev_loss
295 | self.best_dev_loss = min(self.best_dev_loss, dev_loss)
296 | if self.options.save_model:
297 | self.save_model(is_new_best)
298 | self.logger.info(
299 | f'Epoch = [{1 + self.epoch:4d}/{self.options.num_epochs:4d}] TrainLoss = {train_loss:.4g} '
300 | f'DevLoss = {dev_loss:.4g} MeanSparsity = {mean_sparsity:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s',
301 | )
302 | self.writer.close()
303 |
304 | def __call__(self):
305 | self.load_checkpoint_if_needed()
306 | return self._train_loupe()
307 |
308 |     def save_model(self, is_new_best):
309 | exp_dir = self.options.exp_dir
310 | torch.save(
311 | {
312 | 'epoch': self.epoch,
313 | 'options': self.options,
314 | 'model': self.model.state_dict(),
315 | 'optimizer': self.optimizer.state_dict(),
316 | 'best_dev_loss': self.best_dev_loss,
317 | 'exp_dir': exp_dir
318 | },
319 | f = exp_dir / 'model.pt'
320 | )
321 | if is_new_best:
322 | shutil.copyfile(exp_dir / 'model.pt', exp_dir / 'best_model.pt')
323 |
324 |
325 | def load(self):
326 | self.model = build_model(self.options)
327 | if self.options.data_parallel:
328 | self.model = torch.nn.DataParallel(self.model)
329 |
330 | checkpoint1 = torch.load(self.options.checkpoint1)
331 | print("Load checkpoint {} with loss {}".format(checkpoint1['epoch'], checkpoint1['best_dev_loss']))
332 |
333 | self.model.load_state_dict(checkpoint1['model'])
334 |
335 | self.optimizer = build_optimizer(self.options.lr, self.options.weight_decay, self.model.parameters())
336 | self.optimizer.load_state_dict(checkpoint1['optimizer'])
337 |
338 | self.best_dev_loss = checkpoint1['best_dev_loss']
339 | self.start_epoch = checkpoint1['epoch'] + 1
340 | del checkpoint1
341 |
342 |
343 | class NonRLTester:
344 | def __init__(
345 | self,
346 | env: loupe_envs.LOUPEDataEnv,
347 | exp_dir: str,
348 | options: argparse.Namespace,
349 | label_range: List[int]
350 | ):
351 | self.env = env
352 | # load options and model
353 |         self.load(os.path.join(exp_dir, 'best_model.pt'))
354 | self.options.label_range = label_range
355 | self.options.exp_dir = exp_dir
356 | self.options.test_visual_frequency = options.test_visual_frequency
357 | self.options.visualization_dir = options.visualization_dir
358 | self.options.batch_size = 1
359 | _, _, self.dev_loader, _ = self.env._setup_data_handlers()
360 |
361 | # setup saving and logging
362 | self.options.exp_dir.mkdir(parents=True, exist_ok=True)
363 |
364 | logging.basicConfig(level=logging.INFO)
365 | self.logger = logging.getLogger(__name__)
366 | self.logger.info(self.options)
367 | self.logger.info(self.model)
368 |
369 | def evaluate(self):
370 | self.model.eval()
371 | losses = []
372 | sparsity = []
373 | targets, preds = [], []
374 | metrics = Metrics(METRIC_FUNCS)
375 | masks = defaultdict(list)
376 | sample_image = dict()
377 |
378 | with torch.no_grad():
379 | for iter, data in enumerate(self.dev_loader):
380 | # input: [batch_size, num_channels, height, width] denoted as NCHW in other places
381 | # label: label of the current image (0~9 for mnist/fashion-mnist) default: -1
382 | # target: a copy of the input image for computing reconstruction loss in [NCHW]
383 | kspace, _, input, label, *ignored = data
384 |
385 | # adapt data to loupe
386 | target = input.clone().detach()
387 | target = transforms.complex_abs(target).unsqueeze(1)
388 |
389 | input = input.to(self.options.device)
390 | target = target.to(self.options.device)
391 | kspace = kspace.to(self.options.device)
392 |
393 | """if self.options.noise_type == 'gaussian':
394 | kspace = transforms.add_gaussian_noise(self.options, kspace, mean=0., std=self.options.noise_level)
395 | """
396 |
397 | pred_dict = self.model(target, kspace)
398 |
399 | if iter % self.options.test_visual_frequency == 0:
400 | data_for_vis_name = 'test_iter=' + str(iter+1)
401 | print("visualize {}".format(data_for_vis_name))
402 | self.model.visualize_and_save(self.options, iter, data_for_vis_name)
403 |
404 | output = pred_dict['output']
405 | if isinstance(output, list):
406 | output = output[-1]
407 |
408 | loss = compute_ssim_torch(output, target) # F.l1_loss(output, target, size_average=True)
409 |
410 | mask = pred_dict['mask']
411 | # masks[label.item()].append(transforms.fftshift(mask[0, 0], dim=(0, 1)).cpu().detach().numpy())
412 |
413 | sparsity.append(torch.mean(mask).item())
414 | losses.append(loss.item())
415 |
416 | targets.extend(target.cpu().numpy())
417 | preds.extend(output.cpu().numpy())
418 |
419 | # sample_image[label.item()] = target[0, 0].cpu().detach().numpy()
420 |
421 | """for key, val in masks.items():
422 | # get average mask and standard deviation
423 | mask_mean, mask_std = get_mask_stats(val)
424 | mask_mean[16] = 0
425 | mask_mean[:, 16] = 0
426 | fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[10, 3])
427 | sp1 = ax[0].imshow(sample_image[key], cmap='viridis')
428 | sp2 = ax[1].imshow(mask_mean)
429 | sp3 = ax[2].imshow(mask_std)
430 | fig.colorbar(sp1, ax=ax[0])
431 | fig.colorbar(sp2, ax=ax[1])
432 | fig.colorbar(sp3, ax=ax[2])
433 |
434 | plt.savefig(str(self.options.visualization_dir)+'/label_{}_mask_stats.png'.format(key))
435 | plt.close()
436 | """
437 |
438 | print("Done Prediction")
439 |
440 | for t, p in zip(targets, preds):
441 | metrics.push(t, p)
442 |
443 | return losses, np.mean(sparsity), metrics
444 |
445 | def __call__(self):
446 | dev_loss, sparsity, metrics = self.evaluate()
447 |
448 |         print("SSIM {} STD {} Sparsity {}".format(np.mean(dev_loss), np.std(dev_loss), sparsity))  # evaluate() computes SSIM, not L1
449 |
450 | with open(os.path.join(self.options.exp_dir, 'statistics.txt'), 'w') as f:
451 |             f.write('SSIM {} +- {}\n'.format(np.mean(dev_loss), np.std(dev_loss)))
452 | f.write(str(metrics))
453 |
454 | print(metrics)
455 |
456 | with open(os.path.join(self.options.exp_dir, 'loss.pkl'), 'wb') as f:
457 | pickle.dump(dev_loss, f)
458 |
459 | def load(self, checkpoint_file):
460 | checkpoint = torch.load(checkpoint_file)
461 | print("Load checkpoint {} with loss {}".format(checkpoint['epoch'], checkpoint['best_dev_loss']))
462 | self.options = checkpoint['options']
463 | random.seed(self.options.seed)
464 | np.random.seed(self.options.seed)
465 | torch.manual_seed(self.options.seed)
466 |
467 |
468 | if 'bi_dir' not in self.options.__dict__:
469 | self.options.bi_dir = False
470 |
471 | if 'clamp' not in self.options.__dict__:
472 | self.options.clamp = 100
473 |
474 | if 'fixed_input' not in self.options.__dict__:
475 | self.options.fixed_input = False
476 |
477 | print('clamp value {}'.format(self.options.clamp))
478 |
479 | self.model = build_model(self.options)
480 | self.model.eval()
481 | self.model.load_state_dict(checkpoint['model'])
482 |
483 |         torch.save({'model': self.model.reconstructor.state_dict()}, os.path.join(os.path.dirname(checkpoint_file), 'best_recon.pt'))
484 |
485 | # checkpoint = torch.load('checkpoints/Reconstructors/dicom_ckpt/PG_DICOM_Knee_Reconstructor_UNet/best_model.pt', map_location='cpu')
486 | # self.model.reconstructor.load_state_dict(checkpoint['model'])
487 |
--------------------------------------------------------------------------------
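
For reference, a minimal, self-contained sketch of the checkpoint round trip implemented by NonRLTrainer.save_model and NonRLTrainer.load above (the toy model and file name are hypothetical):

import torch

model = torch.nn.Linear(4, 4)  # stand-in for build_model(options)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# save: bundle everything needed to resume training
torch.save(
    {
        'epoch': 3,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_dev_loss': 0.12,
    },
    'model.pt',
)

# load: restore both states and resume from the following epoch
checkpoint = torch.load('model.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
best_dev_loss = checkpoint['best_dev_loss']
start_epoch = checkpoint['epoch'] + 1
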
/activemri/baselines/sequential_sampling_codes/__init__.py:
--------------------------------------------------------------------------------
1 | from .conv_sampler import ConvSamplerSmall, LineConstrainedSampler
--------------------------------------------------------------------------------
/activemri/baselines/sequential_sampling_codes/conv_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from activemri.baselines.loupe_codes import transforms
5 |
6 | # ResNet code is modified from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
7 | # Licensed under MIT License
8 |
9 | class BasicBlock(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, in_planes, planes, stride=1):
13 | super(BasicBlock, self).__init__()
14 | self.conv1 = nn.Conv2d(
15 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
16 | self.bn1 = nn.InstanceNorm2d(planes)
17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
18 | stride=1, padding=1, bias=False)
19 | self.bn2 = nn.InstanceNorm2d(planes)
20 |
21 | self.shortcut = nn.Sequential()
22 | if stride != 1 or in_planes != self.expansion*planes:
23 | self.shortcut = nn.Sequential(
24 | nn.Conv2d(in_planes, self.expansion*planes,
25 | kernel_size=1, stride=stride, bias=False),
26 | nn.InstanceNorm2d(self.expansion*planes)
27 | )
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 | out += self.shortcut(x)
33 | out = F.relu(out)
34 | return out
35 |
36 | class Bottleneck(nn.Module):
37 | expansion = 4
38 |
39 | def __init__(self, in_planes, planes, stride=1):
40 | super(Bottleneck, self).__init__()
41 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
42 | self.bn1 = nn.InstanceNorm2d(planes)
43 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
44 | stride=stride, padding=1, bias=False)
45 | self.bn2 = nn.InstanceNorm2d(planes)
46 | self.conv3 = nn.Conv2d(planes, self.expansion *
47 | planes, kernel_size=1, bias=False)
48 | self.bn3 = nn.InstanceNorm2d(self.expansion*planes)
49 |
50 | self.shortcut = nn.Sequential()
51 | if stride != 1 or in_planes != self.expansion*planes:
52 | self.shortcut = nn.Sequential(
53 | nn.Conv2d(in_planes, self.expansion*planes,
54 | kernel_size=1, stride=stride, bias=False),
55 | nn.InstanceNorm2d(self.expansion*planes)
56 | )
57 |
58 | def forward(self, x):
59 | out = F.relu(self.bn1(self.conv1(x)))
60 | out = F.relu(self.bn2(self.conv2(out)))
61 | out = self.bn3(self.conv3(out))
62 | out += self.shortcut(x)
63 | out = F.relu(out)
64 | return out
65 |
66 | class ResNet(nn.Module):
67 | def __init__(self, block, num_blocks, num_classes=10, fixed_input=False):
68 | super(ResNet, self).__init__()
69 | self.in_planes = 64
70 | # down sample layers
71 |
72 | self.conv1 = nn.Conv2d(5, 64, kernel_size=3,
73 | stride=1, padding=1, bias=False)
74 | self.bn1 = nn.InstanceNorm2d(64)
75 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
76 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
77 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
78 |
79 | # up sample layers
80 |
81 | self.up_convs = nn.Sequential(
82 | nn.ConvTranspose2d(256, 128, 2, stride=2, bias=False),
83 | nn.InstanceNorm2d(128),
84 | nn.ReLU(inplace=True),
85 | nn.ConvTranspose2d(128, 64, 2, stride=2, bias=False),
86 | nn.InstanceNorm2d(64),
87 | nn.ReLU(inplace=True),
88 | )
89 |
90 | self.conv_last = nn.Sequential(
91 | nn.Conv2d(64, 16, kernel_size=3, padding=1, stride=1),
92 | nn.InstanceNorm2d(16),
93 | nn.ReLU(inplace=True),
94 | nn.Conv2d(16, 1, kernel_size=1, stride=1)
95 | )
96 |
97 | self.fixed_input = fixed_input
98 | if fixed_input:
99 | print('generate random input tensor')
100 | fixed_input_tensor = torch.randn(size=[1, 5, 128, 128])
101 | self.fixed_input_tensor = nn.Parameter(fixed_input_tensor, requires_grad=False)
102 |
103 |
104 | def _make_layer(self, block, planes, num_blocks, stride):
105 | strides = [stride] + [1]*(num_blocks-1)
106 | layers = []
107 | for stride in strides:
108 | layers.append(block(self.in_planes, planes, stride))
109 | self.in_planes = planes * block.expansion
110 | return nn.Sequential(*layers)
111 |
112 | def forward(self, x, mask):
113 | """
114 |         This function takes the observed k-space data and sampling trajectories as input
115 |         and outputs the next sampling probability mask. We additionally mask out all
116 |         previously sampled locations.
117 | input: NHWC
118 | mask: N1HW
119 | """
120 | # NHWC -> NCHW
121 | x = x.permute(0, -1, 1, 2)
122 |
123 | # concatenate with previous sampling mask
124 | x = torch.cat([x, mask], dim=1)
125 |
126 | if self.fixed_input:
127 | x = self.fixed_input_tensor.repeat(x.shape[0], 1, 1, 1).to(x.device)
128 |
129 | # down convs
130 | out = F.relu(self.bn1(self.conv1(x)))
131 | out = self.layer1(out)
132 | out = self.layer2(out)
133 | out = self.layer3(out)
134 |
135 | out = self.up_convs(out)
136 | out = self.conv_last(out)
137 |
138 | out = torch.relu(out) / torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1, 1, 1)
139 |
140 |         # No need to select an already-sampled measurement
141 | new_mask = out * (1-mask)
142 |
143 | return new_mask
144 |
145 | def ConvSamplerSmall(fixed_input=False):
146 | return ResNet(BasicBlock, [1, 1, 1], fixed_input=fixed_input)
147 |
148 |
149 |
150 | class LineConstrainedSampler(nn.Module):
151 | """
152 | A line constrained convolutional sampler which uses the same architecture
153 | as the evaluator in https://arxiv.org/pdf/1902.03051.pdf
154 | In this module, we first create a pseudo image containing spectral maps
155 | corresponding to every kspace line. We then pass this pseudo image into
156 | a classification network to predict the probability of selecting
157 | each kspace line. Refer to the above paper for details
158 | """
159 |
160 | def __init__(self, shape=[32]):
161 | super().__init__()
162 | self.mask_conv = nn.Linear(shape[0], 6)
163 | self.middle_convs = nn.Sequential(
164 | nn.Conv2d(6+shape[0], 256, kernel_size=3, stride=2, padding=1),
165 | nn.InstanceNorm2d(256),
166 | nn.LeakyReLU(0.2),
167 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
168 | nn.InstanceNorm2d(512),
169 | nn.LeakyReLU(0.2),
170 | nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
171 | nn.InstanceNorm2d(1024),
172 | nn.LeakyReLU(0.2),
173 | )
174 | self.conv_last = nn.Linear(1024, shape[0])
175 |
176 | # 1xWxHxWxC
177 | # spectral maps of each kspace column
178 |         self.masks = torch.zeros(1, shape[1], shape[0], shape[1], 1, requires_grad=False).cuda()  # NOTE: assumes a CUDA device is available
179 | for i in range(shape[1]):
180 | self.masks[:, i, :, i] = 1
181 |
182 | print("Finish Mask Initialization")
183 |
184 | def _to_spectral(self, kspace):
185 | """
186 | Args:
187 | kspace (torch.Tensor): Already sampled measurements shape NHWC
188 |
189 | Returns:
190 | spectral_map (torch.Tensor): NWHW
191 | """
192 | spectral_map = transforms.complex_abs(transforms.fftshift(transforms.ifft2(kspace.unsqueeze(1) * self.masks)))
193 |
194 | return spectral_map
195 |
196 | def forward(self, kspace, mask):
197 | """
198 | Args:
199 | kspace (torch.Tensor): Already sampled measurements shape NHWC
200 | mask (torch.Tensor): previous sampling trajectories shape N1HW
201 |
202 | Returns:
203 | new_mask (torch.Tensor): probability map of next iteration's sampling locations
204 | shape N1HW <- broadcasted from N11W
205 | """
206 | N, C, H, W = mask.shape
207 | spectral_map = self._to_spectral(kspace)
208 | # mask = mask[:, 0, 0]
209 | mask = (mask.sum(dim=-2) == H).float().squeeze()
210 |
211 | mask_embedding = self.mask_conv(mask).reshape(N, 6, 1, 1).repeat(1, 1, H, W)
212 |
213 | spectral_map = torch.cat([spectral_map, mask_embedding], dim=1)
214 |
215 | out = self.middle_convs(spectral_map)
216 |
217 | # global average pooling
218 | out = out.mean([-2, -1])
219 | out = self.conv_last(out)
220 |
221 | out = torch.sigmoid(out)
222 |
223 |         # No need to select an already-sampled measurement
224 | new_mask = out * (1-mask)
225 |
226 |         # Don't use repeat here; PyTorch broadcasts automatically.
227 | return new_mask.reshape(N, 1, 1, W)
228 |
229 | class KspaceLineConstrainedSampler(nn.Module):
230 | """
231 |     A line-constrained sampler that predicts per-line sampling probabilities
232 |     directly from the observed k-space measurements and the current mask.
233 |     Unlike LineConstrainedSampler above, it does not build spectral maps:
234 |     the k-space and the mask are concatenated, flattened, and passed through
235 |     a fully connected network that outputs one score per candidate k-space
236 |     line.
237 | """
238 |
239 | def __init__(self, in_chans, out_chans, clamp=100, with_uncertainty=False, fixed_input=False):
240 | super().__init__()
241 | """self.middle_convs = nn.Sequential(
242 | nn.Conv2d(1+2, 64, kernel_size=3, stride=2, padding=1),
243 | nn.InstanceNorm2d(64),
244 | nn.LeakyReLU(0.2),
245 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
246 | nn.InstanceNorm2d(128),
247 | nn.LeakyReLU(0.2),
248 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
249 | nn.InstanceNorm2d(256),
250 | nn.LeakyReLU(0.2),
251 | )"""
252 |
253 | if with_uncertainty:
254 | self.flatten_size = int((in_chans) **2 * 6)
255 | else:
256 | self.flatten_size = int((in_chans) **2 * 5)
257 |
258 | self.conv_last = nn.Sequential(
259 | nn.Linear(self.flatten_size, 512),
260 | nn.ReLU(inplace=True),
261 | nn.Linear(512, 512),
262 | nn.ReLU(inplace=True),
263 | nn.Linear(512, 512),
264 | nn.ReLU(inplace=True),
265 | nn.Linear(512, 512),
266 | nn.ReLU(inplace=True),
267 | nn.Linear(512, out_chans)
268 | )
269 |
270 | self.clamp = clamp
271 | self.with_uncertainty = with_uncertainty
272 | self.fixed_input = fixed_input
273 |
274 | if fixed_input:
275 | print("Generate random input tensor")
276 | fixed_input_tensor = torch.randn(size=[1, 5, in_chans, in_chans])
277 | self.fixed_input_tensor = nn.Parameter(fixed_input_tensor, requires_grad=False)
278 |
279 | print("Finish Mask Initialization")
280 |
281 | def forward(self, kspace, mask, uncertainty_map=None):
282 | """
283 | Args:
284 | kspace (torch.Tensor): Already sampled measurements shape NHWC
285 | mask (torch.Tensor): previous sampling trajectories shape N1HW
286 |
287 | Returns:
288 | new_mask (torch.Tensor): probability map of next iteration's sampling locations
289 | shape N1HW <- broadcasted from N11W
290 | """
291 | N, C, H, W = mask.shape
292 |
293 | if self.with_uncertainty:
294 |             assert 0  # uncertainty-conditioned sampling is currently disabled
295 | feat_map = torch.cat([kspace.permute(0, 3, 1, 2), uncertainty_map], dim=1)
296 | else:
297 | feat_map = torch.cat([kspace.permute(0, 3, 1, 2), mask], dim=1)
298 |
299 | if self.fixed_input:
300 | feat_map = self.fixed_input_tensor.repeat(N, 1, 1, 1).to(kspace.device)
301 |
302 | out = feat_map.flatten(start_dim=1)
303 |
304 | out = self.conv_last(out)
305 |
306 | out = F.softplus(out)
307 |
308 | # out = out / torch.clamp(out / torch.clamp(torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1), min=1e-10), max=1-1e-10, min=1e-10)
309 | out = out / torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1)
310 |
311 |         # No need to select an already-sampled measurement
312 | if out.shape[-1] == mask.shape[-1]:
313 | vertical_mask = (mask.sum(dim=-2) == H).float().reshape(N, -1)
314 | new_mask = out * (1-vertical_mask)
315 | else:
316 | # verify horizontal / vertical separately
317 | vertical_mask = (mask.sum(dim=-2) == H).float().reshape(N, -1)
318 |             horizontal_mask = (mask.transpose(-2, -1).sum(dim=-2) == W).float().reshape(N, -1)
319 |
320 | length = out.shape[1] // 2
321 | new_mask = torch.zeros_like(out)
322 | new_mask[:, :length] = out[:, :length] * (1-vertical_mask)
323 | new_mask[:, length:] = out[:, length:] * (1-horizontal_mask)
324 |
325 |         # Don't use repeat here; PyTorch broadcasts automatically.
326 | return new_mask.reshape(N, 1, 1, -1)
327 |
328 |
329 | class BiLineConstrainedSampler(nn.Module):
330 | """
331 |     A bidirectional wrapper around KspaceLineConstrainedSampler: the inner
332 |     sampler is built with out_chans = in_chans * 2, so it predicts sampling
333 |     probabilities for both vertical and horizontal k-space lines from the
334 |     observed measurements and the current mask. Refer to the evaluator in
335 |     https://arxiv.org/pdf/1902.03051.pdf for the architecture this family
336 |     of samplers is based on.
337 | """
338 |
339 | def __init__(self, in_chans, clamp, with_uncertainty=False, fixed_input=False):
340 | super().__init__()
341 |
342 | self.sampler = KspaceLineConstrainedSampler(in_chans=in_chans, out_chans=in_chans*2, clamp=clamp,
343 | with_uncertainty=with_uncertainty, fixed_input=fixed_input)
344 |
345 | def forward(self, kspace, mask, uncertainty_map=None):
346 | """
347 | Args:
348 | kspace (torch.Tensor): Already sampled measurements shape NHWC
349 | mask (torch.Tensor): previous sampling trajectories shape N1HW
350 |
351 | Returns:
352 | new_mask (torch.Tensor): probability map of next iteration's sampling locations
353 | shape N1HW <- broadcasted from N11W
354 | """
355 | return self.sampler.forward(kspace, mask, uncertainty_map)
356 |
--------------------------------------------------------------------------------
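
Two small tricks recur in the samplers above: a 2D mask is collapsed to a per-line indicator by checking whether a column sums to the full height, and already-sampled lines are suppressed by multiplying the predicted scores with (1 - mask). A toy illustration with made-up shapes and values:

import torch

N, H, W = 1, 4, 4
mask = torch.zeros(N, 1, H, W)
mask[..., :, 1] = 1                         # column 1 is fully sampled

# a column is a sampled "vertical line" iff all H of its entries are 1
vertical = (mask.sum(dim=-2) == H).float()  # shape N x 1 x W
print(vertical)                             # tensor([[[0., 1., 0., 0.]]])

scores = torch.full((N, 1, W), 0.5)         # predicted per-line probabilities
new_mask = scores * (1 - vertical)          # zero out the sampled column
print(new_mask)                             # tensor([[[0.5000, 0.0000, 0.5000, 0.5000]]])
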
/activemri/baselines/sequential_sampling_codes/joint_reconstructor.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/activemri/baselines/sequential_sampling_codes/joint_reconstructor.py
--------------------------------------------------------------------------------
/activemri/baselines/sequential_sampling_codes/sampler2d.py:
--------------------------------------------------------------------------------
1 | from activemri.baselines.sequential_sampling_codes.conv_sampler import BasicBlock
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from activemri.baselines.loupe_codes import transforms
6 | from activemri.baselines.loupe_codes.reconstructors import ConvBlock
7 |
8 | class Sampler2D(nn.Module):
9 | def __init__(self, num_blocks=[1, 1, 1], fixed_input=False):
10 | super(Sampler2D, self).__init__()
11 |
12 | in_chans = 5
13 | out_chans = 1
14 | chans = 64
15 | num_pool_layers = 4
16 | drop_prob = 0
17 |
18 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
19 | ch = chans
20 | for i in range(num_pool_layers - 1):
21 | self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)]
22 | ch *= 2
23 | self.conv = ConvBlock(ch, ch, drop_prob)
24 |
25 | self.up_sample_layers = nn.ModuleList()
26 | for i in range(num_pool_layers - 1):
27 | self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, drop_prob)]
28 | ch //= 2
29 | self.up_sample_layers += [ConvBlock(ch * 2, ch, drop_prob)]
30 | self.conv2 = nn.Sequential(
31 | nn.Conv2d(ch, ch // 2, kernel_size=1),
32 | nn.Conv2d(ch // 2, out_chans, kernel_size=1),
33 | nn.Conv2d(out_chans, out_chans, kernel_size=1),
34 | )
35 |
36 | self.fixed_input = fixed_input
37 | if fixed_input:
38 | print('generate random input tensor')
39 | fixed_input_tensor = torch.randn(size=[1, 5, 128, 128])
40 | self.fixed_input_tensor = nn.Parameter(fixed_input_tensor, requires_grad=False)
41 |
42 |     def _make_layer(self, block, planes, num_blocks, stride):  # NOTE: unused ResNet leftover; Sampler2D never defines self.in_planes
43 | strides = [stride] + [1]*(num_blocks-1)
44 | layers = []
45 | for stride in strides:
46 | layers.append(block(self.in_planes, planes, stride))
47 | self.in_planes = planes * block.expansion
48 | return nn.Sequential(*layers)
49 |
50 | def forward(self, x, mask):
51 | """
52 |         This function takes the observed k-space data and sampling trajectories as input
53 |         and outputs the next sampling probability mask. We additionally mask out all
54 |         previously sampled locations.
55 | input: NHWC
56 | mask: N1HW
57 | """
58 | # NHWC -> NCHW
59 | x = x.permute(0, -1, 1, 2)
60 |
61 | # concatenate with previous sampling mask
62 | x = torch.cat([x, mask], dim=1)
63 |
64 |         if self.fixed_input:
65 |             x = self.fixed_input_tensor.repeat(x.shape[0], 1, 1, 1).to(x.device)  # match batch size and device, as in the ResNet sampler in conv_sampler.py
66 |
67 | output = x
68 |
69 | stack = []
70 | # print(input.shape, mask.shape)
71 |
72 | # Apply down-sampling layers
73 | for layer in self.down_sample_layers:
74 | output = layer(output)
75 | stack.append(output)
76 | output = F.max_pool2d(output, kernel_size=2)
77 |
78 | output = self.conv(output)
79 |
80 | # Apply up-sampling layers
81 | for layer in self.up_sample_layers:
82 | downsample_layer = stack.pop()
83 | layer_size = (downsample_layer.shape[-2], downsample_layer.shape[-1])
84 | output = F.interpolate(output, size=layer_size, mode='bilinear', align_corners=False)
85 | output = torch.cat([output, downsample_layer], dim=1)
86 | output = layer(output)
87 |
88 | out = self.conv2(output)
89 |
90 | out = F.softplus(out)
91 |
92 | out = out / torch.max(out.reshape(out.shape[0], -1), dim=1)[0].reshape(-1, 1, 1, 1)
93 |
94 |         # No need to select an already-sampled measurement
95 | new_mask = out * (1-mask)
96 |
97 | return new_mask
98 |
--------------------------------------------------------------------------------
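
The tail of Sampler2D.forward turns raw network outputs into per-pixel selection probabilities: softplus makes the scores strictly positive, and dividing by each sample's maximum squashes them into (0, 1]. A standalone illustration of that normalization:

import torch
import torch.nn.functional as F

out = torch.randn(2, 1, 4, 4)              # raw network scores
out = F.softplus(out)                      # strictly positive
per_sample_max = torch.max(out.reshape(out.shape[0], -1), dim=1)[0]
out = out / per_sample_max.reshape(-1, 1, 1, 1)

assert out.min().item() > 0.0              # softplus keeps everything positive
assert abs(out.max().item() - 1.0) < 1e-6  # each sample's peak is rescaled to 1
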
/activemri/baselines/simple_baselines.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | activemri.baselines.simple_baselines.py
8 | =======================================
9 | Simple baselines for active MRI acquisition.
10 | """
11 | from typing import Any, Dict, List, Optional
12 |
13 | import numpy as np
14 | import torch
15 |
16 | import activemri.envs
17 |
18 | from . import Policy
19 |
20 |
21 | class RandomPolicy(Policy):
22 | """A policy representing random k-space selection.
23 |
24 | Returns one of the valid actions uniformly at random.
25 |
26 | Args:
27 | seed(optional(int)): The seed to use for the random number generator, which is
28 | based on ``torch.Generator()``.
29 | """
30 |
31 | def __init__(self, seed: Optional[int] = None):
32 | super().__init__()
33 | self.rng = torch.Generator()
34 | if seed:
35 | self.rng.manual_seed(seed)
36 |
37 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]:
38 | """Returns a random action without replacement.
39 |
40 | Args:
41 | obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`.
42 |
43 | Returns:
44 | list(int): A list of random k-space column indices, one per batch element in
45 | the observation. The indices are sampled from the set of inactive (0) columns
46 | on each batch element.
47 | """
48 | return (
49 | (obs["mask"].logical_not().float() + 1e-6)
50 | .multinomial(1, generator=self.rng)
51 | .squeeze()
52 | .tolist()
53 | )
54 |
55 |
56 | class RandomLowBiasPolicy(Policy):
57 | def __init__(
58 | self, acceleration: float, centered: bool = True, seed: Optional[int] = None
59 | ):
60 | super().__init__()
61 | self.acceleration = acceleration
62 | self.centered = centered
63 | self.rng = np.random.RandomState(seed)
64 |
65 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]:
66 | mask = obs["mask"].squeeze().cpu().numpy()
67 | new_mask = self._cartesian_mask(mask)
68 | action = (new_mask - mask).argmax(axis=1)
69 | return action.tolist()
70 |
71 | @staticmethod
72 | def _normal_pdf(length: int, sensitivity: float):
73 | return np.exp(-sensitivity * (np.arange(length) - length / 2) ** 2)
74 |
75 | def _cartesian_mask(self, current_mask: np.ndarray) -> np.ndarray:
76 | batch_size, image_width = current_mask.shape
77 | pdf_x = RandomLowBiasPolicy._normal_pdf(
78 | image_width, 0.5 / (image_width / 10.0) ** 2
79 | )
80 | pdf_x = np.expand_dims(pdf_x, axis=0)
81 | lmda = image_width / (2.0 * self.acceleration)
82 | # add uniform distribution
83 | pdf_x += lmda * 1.0 / image_width
84 | # remove previously chosen columns
85 | # note that pdf_x designed for centered masks
86 | new_mask = (
87 | np.fft.ifftshift(current_mask, axes=1)
88 | if not self.centered
89 | else current_mask.copy()
90 | )
91 | pdf_x = pdf_x * np.logical_not(new_mask)
92 | # normalize probabilities and choose accordingly
93 | pdf_x /= pdf_x.sum(axis=1, keepdims=True)
94 | indices = [
95 | self.rng.choice(image_width, 1, False, pdf_x[i]).item()
96 | for i in range(batch_size)
97 | ]
98 | new_mask[range(batch_size), indices] = 1
99 | if not self.centered:
100 | new_mask = np.fft.ifftshift(new_mask, axes=1)
101 | return new_mask
102 |
103 |
104 | class LowestIndexPolicy(Policy):
105 | """A policy that represents low-to-high frequency k-space selection.
106 |
107 | Args:
108 | alternate_sides(bool): If ``True`` the indices of selected actions will alternate
109 | between the sides of the mask. For example, for an image with 100
110 | columns, and non-centered k-space, the order will be 0, 99, 1, 98, 2, 97, ..., etc.
111 | For the same size and centered, the order will be 49, 50, 48, 51, 47, 52, ..., etc.
112 |
113 | centered(bool): If ``True`` (default), low frequencies are in the center of the mask.
114 | Otherwise, they are in the edges of the mask.
115 | """
116 |
117 | def __init__(
118 | self,
119 | alternate_sides: bool,
120 | centered: bool = True,
121 | ):
122 | super().__init__()
123 | self.alternate_sides = alternate_sides
124 | self.centered = centered
125 | self.bottom_side = True
126 |
127 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]:
128 | """Returns a random action without replacement.
129 |
130 | Args:
131 | obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`.
132 |
133 | Returns:
134 | list(int): A list of k-space column indices, one per batch element in
135 | the observation, equal to the lowest non-active k-space column in their
136 | corresponding observation masks.
137 | """
138 | mask = obs["mask"].squeeze().cpu().numpy()
139 | new_mask = self._get_new_mask(mask)
140 | action = (new_mask - mask).argmax(axis=1)
141 | return action.tolist()
142 |
143 | def _get_new_mask(self, current_mask: np.ndarray) -> np.ndarray:
144 |         # The code below assumes the mask is non-centered
145 | new_mask = (
146 | np.fft.ifftshift(current_mask, axes=1)
147 | if self.centered
148 | else current_mask.copy()
149 | )
150 | if self.bottom_side:
151 | idx = np.arange(new_mask.shape[1], 0, -1)
152 | else:
153 | idx = np.arange(new_mask.shape[1])
154 | if self.alternate_sides:
155 | self.bottom_side = not self.bottom_side
156 | # Next line finds the first non-zero index (from edge to center) and returns it
157 | indices = (np.logical_not(new_mask) * idx).argmax(axis=1)
158 | indices = np.expand_dims(indices, axis=1)
159 | new_mask[range(new_mask.shape[0]), indices] = 1
160 | if self.centered:
161 | new_mask = np.fft.ifftshift(new_mask, axes=1)
162 | return new_mask
163 |
164 |
165 | class OneStepGreedyOracle(Policy):
166 | """A policy that returns the k-space column leading to best reconstruction score.
167 |
168 | Args:
169 | env(``activemri.envs.ActiveMRIEnv``): The environment for which the policy is computed
170 | for.
171 | metric(str): The name of the score metric to use (must be in ``env.score_keys()``).
172 | num_samples(optional(int)): If given, only ``num_samples`` random actions will be
173 | tested. Defaults to ``None``, which means that method will consider all actions.
174 | rng(``numpy.random.RandomState``): A random number generator to use for sampling.
175 | """
176 |
177 | def __init__(
178 | self,
179 | env: activemri.envs.ActiveMRIEnv,
180 | metric: str,
181 | num_samples: Optional[int] = None,
182 | rng: Optional[np.random.RandomState] = None,
183 | ):
184 | assert metric in env.score_keys()
185 | super().__init__()
186 | self.env = env
187 | self.metric = metric
188 | self.num_samples = num_samples
189 | self.rng = rng if rng is not None else np.random.RandomState()
190 |
191 | def get_action(self, obs: Dict[str, Any], **_kwargs) -> List[int]:
192 | """Returns a one-step greedy action maximizing reconstruction score.
193 |
194 | Args:
195 | obs(dict(str, any)): As returned by :class:`activemri.envs.ActiveMRIEnv`.
196 |
197 | Returns:
198 | list(int): A list of k-space column indices, one per batch element in
199 | the observation, equal to the action that maximizes reconstruction score
200 |             (e.g., SSIM or negative MSE).
201 | """
202 | mask = obs["mask"]
203 | batch_size = mask.shape[0]
204 | all_action_lists = []
205 | for i in range(batch_size):
206 | available_actions = mask[i].logical_not().nonzero().squeeze().tolist()
207 | self.rng.shuffle(available_actions)
208 | if len(available_actions) < self.num_samples:
209 | # Add dummy actions to try if num of samples is higher than the
210 | # number of inactive columns in this mask
211 | available_actions.extend(
212 | [0] * (self.num_samples - len(available_actions))
213 | )
214 | all_action_lists.append(available_actions)
215 |
216 | all_scores = np.zeros((batch_size, self.num_samples))
217 | for i in range(self.num_samples):
218 | batch_action_to_try = [action_list[i] for action_list in all_action_lists]
219 | obs, new_score = self.env.try_action(batch_action_to_try)
220 | all_scores[:, i] = new_score[self.metric]
221 | if self.metric in ["mse", "nmse"]:
222 | all_scores *= -1
223 | else:
224 | assert self.metric in ["ssim", "psnr"]
225 |
226 | best_indices = all_scores.argmax(axis=1)
227 | action = []
228 | for i in range(batch_size):
229 | action.append(all_action_lists[i][best_indices[i]])
230 | return action
231 |
--------------------------------------------------------------------------------
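
To make the multinomial trick in RandomPolicy.get_action concrete, here is a toy run (the mask values are made up). logical_not turns inactive (0) columns into weight 1, and the small epsilon only exists so multinomial never sees an all-zero row (e.g. a fully sampled mask), at the cost of a vanishingly small chance of drawing an active column in that edge case:

import torch

mask = torch.tensor([[1, 0, 0, 1],
                     [0, 1, 1, 1]])
weights = mask.logical_not().float() + 1e-6
actions = weights.multinomial(1).squeeze().tolist()
print(actions)  # e.g. [2, 0]: one inactive column index per batch element
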
/activemri/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from typing import Any, Dict, List
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from . import singlecoil_knee_data
12 | from . import transforms
13 |
14 | __all__ = ["singlecoil_knee_data", "transforms"]
15 |
16 |
17 | def transform_template(
18 | kspace: List[np.ndarray] = None,
19 | mask: torch.Tensor = None,
20 | ground_truth: torch.Tensor = None,
21 | attrs: List[Dict[str, Any]] = None,
22 | fname: List[str] = None,
23 | slice_id: List[int] = None,
24 | ):
25 | """Template for transform functions.
26 |
27 | Args:
28 | - kspace(list(np.ndarray)): A list of complex numpy arrays, one per k-space in the batch.
29 | The length is the ``batch_size``, and array shapes are ``H x W x 2`` for single coil data,
30 | and ``C x H x W x 2`` for multicoil data, where ``H`` denotes k-space height, ``W``
31 | denotes k-space width, and ``C`` is the number of coils. Note that the width can differ
32 | between batch elements, if ``num_cols`` is set to a tuple when creating the environment.
33 | - mask(torch.Tensor): A tensor of binary column masks, where 1s indicate that the
34 | corresponding k-space column should be selected. The shape is ``batch_size x 1 x maxW``,
35 | for single coil data, and ``batch_size x 1 x 1 x maxW`` for multicoil data. Here ``maxW``
36 | is the maximum k-space width returned by the environment.
37 | - ground_truth(torch.Tensor): A tensor of ground truth 2D images. The shape is
38 | ``batch_size x 320 x 320``.
39 | - attrs(list(dict)): A list of dictionaries with the attributes read from the fastMRI for
40 | each image.
41 | - fname(list(str)): A list of the filenames where the images where read from.
42 | - slice_id(list(int)): A list with the slice ids in the files where each image was read
43 | from.
44 |
45 | Returns:
46 | tuple(Any...): A tuple with any number of inputs required by the reconstructor model.
47 |
48 | """
49 | pass
50 |
--------------------------------------------------------------------------------
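
As a concrete, hypothetical instance of transform_template (not part of the repo), a transform for single-coil data might stack real and imaginary parts, apply the column mask, and return whatever inputs the reconstructor expects:

import numpy as np
import torch

def masked_kspace_transform(kspace, mask, ground_truth, attrs, fname, slice_id):
    # assumes single-coil data where all batch elements share the same width
    kspace_t = torch.stack(
        [torch.from_numpy(np.stack((k.real, k.imag), axis=-1)) for k in kspace]
    )  # batch_size x H x W x 2
    # mask is batch_size x 1 x maxW; broadcast it over rows and the complex dim
    masked = kspace_t * mask.unsqueeze(-1).to(kspace_t.dtype)
    return masked, mask, ground_truth
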
/activemri/data/real_brain_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # from activemri.baselines.loupe_codes.transforms import normalize
7 | import pathlib
8 | from typing import Callable, List, Optional, Tuple
9 |
10 | import fastmri
11 | import h5py
12 | import numpy as np
13 | import torch.utils.data
14 | import activemri
15 | import scipy.ndimage as ndimage
16 |
17 | class RealBrainData(torch.utils.data.Dataset):
18 | # This is the same as fastMRI singlecoil_knee, except we provide a custom test split
19 | # and also normalize images by the mean norm of the k-space over training data
20 | # KSPACE_WIDTH = 368
21 | # KSPACE_HEIGHT = 640
22 | # START_PADDING = 166
23 | # END_PADDING = 202
24 | # CENTER_CROP_SIZE = 320
25 |
26 | def __init__(
27 | self,
28 | root: pathlib.Path,
29 | image_shape: Tuple[int, int],
30 | transform: Callable,
31 | noise_type: str,
32 | noise_level: float = 5e-5,
33 | num_cols: Optional[int] = None,
34 | num_volumes: Optional[int] = None,
35 | num_rand_slices: Optional[int] = None,
36 | custom_split: Optional[str] = None,
37 | random_rotate=False
38 | ):
39 | self.image_shape = image_shape
40 | self.transform = transform
41 | self.noise_type = noise_type
42 | self.noise_level = noise_level
43 | self.examples: List[Tuple[pathlib.PurePath, int]] = []
44 |
45 |
46 | self.num_rand_slices = num_rand_slices
47 | self.rng = np.random.RandomState(1234)
48 | self.recons_key = 'reconstruction_rss'
49 |
50 | files = []
51 | for fname in list(pathlib.Path(root).iterdir()):
52 | data = h5py.File(fname, "r")
53 | if 'reconstruction_rss' not in data.keys():
54 | continue
55 | files.append(fname)
56 |
57 | self.train_mode = False
58 |
59 | """if custom_split is not None:
60 | split_info = []
61 | with open(f"activemri/data/splits/knee_singlecoil/{custom_split}.txt") as f:
62 | for line in f:
63 | split_info.append(line.rsplit("\n")[0])
64 | files = [f for f in files if f.name in split_info]
65 | else:
66 | self.train_mode = True
67 | """
68 | if num_volumes is not None:
69 | self.rng.shuffle(files)
70 | files = files[:num_volumes]
71 |
72 | for volume_i, fname in enumerate(sorted(files)):
73 | data = h5py.File(fname, "r")
74 | # kspace = data["kspace"]
75 |
76 | if num_rand_slices is None:
77 | num_slices = data['reconstruction_rss'].shape[0]
78 | self.examples += [(fname, slice_id) for slice_id in range(num_slices)]
79 | else:
80 |                 raise NotImplementedError("num_rand_slices unsupported here: no k-space in these files")
81 | slice_ids = list(range(kspace.shape[0]))
82 | self.rng.seed(seed=volume_i)
83 | self.rng.shuffle(slice_ids)
84 | self.examples += [
85 | (fname, slice_id) for slice_id in slice_ids[:num_rand_slices]
86 | ]
87 |
88 | self.random_rotate = random_rotate
89 | if self.random_rotate:
90 | np.random.seed(42)
91 |             self.random_angles = np.random.random(len(self)) * 30 - 15  # degrees in [-15, 15)
92 |
93 | def center_crop(self, data, shape):
94 | """
95 | (Same as the one in activemri/baselines/policy_gradient_codes/helpers/transforms.py)
96 | Apply a center crop to the input real image or batch of real images.
97 |
98 | Args:
99 | data (torch.Tensor): The input tensor to be center cropped. It should have at
100 | least 2 dimensions and the cropping is applied along the last two dimensions.
101 | shape (int, int): The output shape. The shape should be smaller than the
102 | corresponding dimensions of data.
103 |
104 | Returns:
105 | torch.Tensor: The center cropped image
106 | """
107 | assert 0 < shape[0] <= data.shape[-2]
108 | assert 0 < shape[1] <= data.shape[-1]
109 | w_from = (data.shape[-2] - shape[0]) // 2
110 | h_from = (data.shape[-1] - shape[1]) // 2
111 | w_to = w_from + shape[0]
112 | h_to = h_from + shape[1]
113 | return data[..., w_from:w_to, h_from:h_to]
114 |
115 | def __len__(self):
116 | return len(self.examples)
117 |
118 | def __getitem__(self, i):
119 | fname, slice_id = self.examples[i]
120 | with h5py.File(fname, 'r') as data:
121 |
122 | # kspace = data["kspace"][slice_id]
123 | # kspace = np.stack([kspace.real, kspace.imag], axis=-1)
124 | # if self.random_rotate:
125 | # kspace = ndimage.rotate(kspace, self.random_angles[i], reshape=False, mode='nearest')
126 |
127 | #kspace = torch.from_numpy(kspace).permute(2, 0, 1)
128 | #kspace = self.center_crop(kspace, self.image_shape).permute(1, 2, 0)
129 |
130 | #kspace = fastmri.ifftshift(kspace, dim=(0, 1))
131 | target = torch.from_numpy(data['reconstruction_rss'][slice_id]).unsqueeze(-1)
132 | target = torch.cat([target, torch.zeros_like(target)], dim=-1)
133 | target = self.center_crop(target.permute(2, 0, 1), self.image_shape).permute(1, 2, 0)
134 |
135 | kspace = fastmri.ifftshift(target, dim=(0, 1)).fft(2, normalized=False)
136 |
137 |
138 | # target = torch.ifft(kspace, 2, normalized=False)
139 | #target = fastmri.ifftshift(target, dim=(0, 1))
140 |
141 |
142 | # Normalize using mean of k-space in training data
143 | # target /= 7.072103529760345e-07
144 | # kspace /= 7.072103529760345e-07
145 |
146 | kspace = kspace.numpy()
147 | target = target.numpy()
148 |
149 | return self.transform(
150 | kspace,
151 | torch.zeros(kspace.shape[1]),
152 | target,
153 | dict(data.attrs),
154 | fname.name,
155 | slice_id
156 | )
157 |
158 | """from activemri.envs.envs import ActiveMRIEnv
159 | data = RealBrainData(
160 | 'datasets/brain/train_no_kspace',
161 | (128, 128),
162 | ActiveMRIEnv._void_transform,
163 | noise_type='none'
164 | )
165 |
166 | data[0]
167 | """
--------------------------------------------------------------------------------
/activemri/data/real_knee_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | # from activemri.baselines.loupe_codes.transforms import normalize
7 | import pathlib
8 | from typing import Callable, List, Optional, Tuple
9 |
10 | import fastmri
11 | import h5py
12 | import numpy as np
13 | import torch.utils.data
14 | import activemri
15 | import scipy.ndimage as ndimage
16 |
17 | class RealKneeData(torch.utils.data.Dataset):
18 | # This is the same as fastMRI singlecoil_knee, except we provide a custom test split
19 | # and also normalize images by the mean norm of the k-space over training data
20 | # KSPACE_WIDTH = 368
21 | # KSPACE_HEIGHT = 640
22 | # START_PADDING = 166
23 | # END_PADDING = 202
24 | # CENTER_CROP_SIZE = 320
25 |
26 | def __init__(
27 | self,
28 | root: pathlib.Path,
29 | image_shape: Tuple[int, int],
30 | transform: Callable,
31 | noise_type: str,
32 | noise_level: float = 5e-5,
33 | num_cols: Optional[int] = None,
34 | num_volumes: Optional[int] = None,
35 | num_rand_slices: Optional[int] = None,
36 | custom_split: Optional[str] = None,
37 | random_rotate=False
38 | ):
39 | self.image_shape = image_shape
40 | self.transform = transform
41 | self.noise_type = noise_type
42 | self.noise_level = noise_level
43 | self.examples: List[Tuple[pathlib.PurePath, int]] = []
44 |
45 | self.num_rand_slices = num_rand_slices
46 | self.rng = np.random.RandomState(1234)
47 | self.recons_key = 'reconstruction_rss'
48 |
49 | files = []
50 | for fname in list(pathlib.Path(root).iterdir()):
51 | data = h5py.File(fname, "r")
52 | if 'reconstruction_rss' not in data.keys():
53 | continue
54 | files.append(fname)
55 |
56 | self.train_mode = False
57 |
58 | if custom_split is not None:
59 | split_info = []
60 | with open(f"activemri/data/splits/knee_singlecoil/{custom_split}.txt") as f:
61 | for line in f:
62 | split_info.append(line.rsplit("\n")[0])
63 | files = [f for f in files if f.name in split_info]
64 | else:
65 | self.train_mode = True
66 |
67 | if num_volumes is not None:
68 | self.rng.shuffle(files)
69 | files = files[:num_volumes]
70 |
71 | for volume_i, fname in enumerate(sorted(files)):
72 | data = h5py.File(fname, "r")
73 | kspace = data["kspace"]
74 |
75 | if num_rand_slices is None:
76 | num_slices = kspace.shape[0]
77 | self.examples += [(fname, slice_id) for slice_id in range(num_slices)]
78 | else:
79 | slice_ids = list(range(kspace.shape[0]))
80 | self.rng.seed(seed=volume_i)
81 | self.rng.shuffle(slice_ids)
82 | self.examples += [
83 | (fname, slice_id) for slice_id in slice_ids[:num_rand_slices]
84 | ]
85 |
86 | self.random_rotate = random_rotate
87 | if self.random_rotate:
88 | np.random.seed(42)
89 |             self.random_angles = np.random.random(len(self)) * 30 - 15  # degrees in [-15, 15)
90 |
91 | def center_crop(self, data, shape):
92 | """
93 | (Same as the one in activemri/baselines/policy_gradient_codes/helpers/transforms.py)
94 | Apply a center crop to the input real image or batch of real images.
95 |
96 | Args:
97 | data (torch.Tensor): The input tensor to be center cropped. It should have at
98 | least 2 dimensions and the cropping is applied along the last two dimensions.
99 | shape (int, int): The output shape. The shape should be smaller than the
100 | corresponding dimensions of data.
101 |
102 | Returns:
103 | torch.Tensor: The center cropped image
104 | """
105 | assert 0 < shape[0] <= data.shape[-2]
106 | assert 0 < shape[1] <= data.shape[-1]
107 | w_from = (data.shape[-2] - shape[0]) // 2
108 | h_from = (data.shape[-1] - shape[1]) // 2
109 | w_to = w_from + shape[0]
110 | h_to = h_from + shape[1]
111 | return data[..., w_from:w_to, h_from:h_to]
112 |
113 | def __len__(self):
114 | return len(self.examples)
115 |
116 | def __getitem__(self, i):
117 | fname, slice_id = self.examples[i]
118 | with h5py.File(fname, 'r') as data:
119 |
120 | kspace = data["kspace"][slice_id]
121 | kspace = np.stack([kspace.real, kspace.imag], axis=-1)
122 | if self.random_rotate:
123 | kspace = ndimage.rotate(kspace, self.random_angles[i], reshape=False, mode='nearest')
124 |
125 | kspace = torch.from_numpy(kspace).permute(2, 0, 1)
126 | kspace = self.center_crop(kspace, self.image_shape).permute(1, 2, 0)
127 |
128 | kspace = fastmri.ifftshift(kspace, dim=(0, 1))
129 |
130 | target = torch.ifft(kspace, 2, normalized=False)
131 | target = fastmri.ifftshift(target, dim=(0, 1))
132 |
133 |
134 | # Normalize using mean of k-space in training data
135 | target /= 7.072103529760345e-07
136 | kspace /= 7.072103529760345e-07
137 |
138 | kspace = kspace.numpy()
139 | target = target.numpy()
140 |
141 | return self.transform(
142 | kspace,
143 | torch.zeros(kspace.shape[1]),
144 | target,
145 | dict(data.attrs),
146 | fname.name,
147 | slice_id
148 | )
149 |
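150 | # Usage sketch mirroring the commented RealBrainData example above
151 | # (illustrative; the path is hypothetical):
152 | # from activemri.envs.envs import ActiveMRIEnv
153 | # data = RealKneeData(
154 | #     'datasets/knee/knee_singlecoil_val',
155 | #     (128, 128),
156 | #     ActiveMRIEnv._void_transform,
157 | #     noise_type='none',
158 | #     custom_split='val',
159 | # )
160 | # data[0]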
--------------------------------------------------------------------------------
/activemri/data/singlecoil_knee_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import pathlib
7 | from typing import Callable, List, Optional, Tuple
8 |
9 | import fastmri
10 | import h5py
11 | import numpy as np
12 | import torch.utils.data
13 |
14 |
15 | # -----------------------------------------------------------------------------
16 | # Single coil knee dataset (as used in MICCAI'20)
17 | # -----------------------------------------------------------------------------
18 | class MICCAI2020Data(torch.utils.data.Dataset):
19 | # This is the same as fastMRI singlecoil_knee, except we provide a custom test split
20 | # and also normalize images by the mean norm of the k-space over training data
21 | KSPACE_WIDTH = 368
22 | KSPACE_HEIGHT = 640
23 | START_PADDING = 166
24 | END_PADDING = 202
25 | CENTER_CROP_SIZE = 320
26 |
27 | def __init__(
28 | self,
29 | root: pathlib.Path,
30 | transform: Callable,
31 | num_cols: Optional[int] = None,
32 | num_volumes: Optional[int] = None,
33 | num_rand_slices: Optional[int] = None,
34 | custom_split: Optional[str] = None,
35 | ):
36 | self.transform = transform
37 | self.examples: List[Tuple[pathlib.PurePath, int]] = []
38 |
39 | self.num_rand_slices = num_rand_slices
40 | self.rng = np.random.RandomState(1234)
41 |
42 | files = []
43 | for fname in list(pathlib.Path(root).iterdir()):
44 | data = h5py.File(fname, "r")
45 | if num_cols is not None and data["kspace"].shape[2] != num_cols:
46 | continue
47 | files.append(fname)
48 |
49 | if custom_split is not None:
50 | split_info = []
51 | with open(f"activemri/data/splits/knee_singlecoil/{custom_split}.txt") as f:
52 | for line in f:
53 | split_info.append(line.rsplit("\n")[0])
54 | files = [f for f in files if f.name in split_info]
55 |
56 | if num_volumes is not None:
57 | self.rng.shuffle(files)
58 | files = files[:num_volumes]
59 |
60 | for volume_i, fname in enumerate(sorted(files)):
61 | data = h5py.File(fname, "r")
62 | kspace = data["kspace"]
63 |
64 | if num_rand_slices is None:
65 | num_slices = kspace.shape[0]
66 | self.examples += [(fname, slice_id) for slice_id in range(num_slices)]
67 | else:
68 | slice_ids = list(range(kspace.shape[0]))
69 | self.rng.seed(seed=volume_i)
70 | self.rng.shuffle(slice_ids)
71 | self.examples += [
72 | (fname, slice_id) for slice_id in slice_ids[:num_rand_slices]
73 | ]
74 |
75 | def __len__(self):
76 | return len(self.examples)
77 |
78 | def __getitem__(self, i):
79 | fname, slice_id = self.examples[i]
80 | with h5py.File(fname, "r") as data:
81 | kspace = data["kspace"][slice_id]
82 | kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
83 | kspace = fastmri.ifftshift(kspace, dim=(0, 1))
84 | target = torch.ifft(kspace, 2, normalized=False)
85 | target = fastmri.ifftshift(target, dim=(0, 1))
86 | # Normalize using mean of k-space in training data
87 | target /= 7.072103529760345e-07
88 | kspace /= 7.072103529760345e-07
89 |
90 | # Environment expects numpy arrays. The code above was used with an older
91 | # version of the environment to generate the results of the MICCAI'20 paper.
92 | # So, to keep this consistent with the version in the paper, we convert
93 | # the tensors back to numpy rather than changing the original code.
94 | kspace = kspace.numpy()
95 | target = target.numpy()
96 | return self.transform(
97 | kspace,
98 | torch.zeros(kspace.shape[1]),
99 | target,
100 | dict(data.attrs),
101 | fname.name,
102 | slice_id,
103 | )
104 |
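105 | # Usage sketch (illustrative; the path is hypothetical). `_void_transform`
106 | # is assumed here to pass the raw tuple through unchanged.
107 | # from activemri.envs.envs import ActiveMRIEnv
108 | # data = MICCAI2020Data(
109 | #     pathlib.Path('datasets/knee/knee_singlecoil_val'),
110 | #     ActiveMRIEnv._void_transform,
111 | #     custom_split='val',
112 | # )
113 | # data[0]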
--------------------------------------------------------------------------------
/activemri/data/splits/knee_singlecoil/test.txt:
--------------------------------------------------------------------------------
1 | file1001126.h5
2 | file1001930.h5
3 | file1002436.h5
4 | file1002382.h5
5 | file1001585.h5
6 | file1001506.h5
7 | file1001057.h5
8 | file1002021.h5
9 | file1000831.h5
10 | file1001344.h5
11 | file1000007.h5
12 | file1001938.h5
13 | file1001381.h5
14 | file1001365.h5
15 | file1001458.h5
16 | file1001726.h5
17 | file1000280.h5
18 | file1002145.h5
19 | file1001759.h5
20 | file1000976.h5
21 | file1001064.h5
22 | file1000108.h5
23 | file1000291.h5
24 | file1001689.h5
25 | file1001191.h5
26 | file1000903.h5
27 | file1001440.h5
28 | file1001184.h5
29 | file1001119.h5
30 | file1002351.h5
31 | file1001643.h5
32 | file1001893.h5
33 | file1001968.h5
34 | file1001566.h5
35 | file1001850.h5
36 | file1000660.h5
37 | file1000593.h5
38 | file1001763.h5
39 | file1002546.h5
40 | file1000697.h5
41 | file1000190.h5
42 | file1000273.h5
43 | file1001144.h5
44 | file1000538.h5
45 | file1000635.h5
46 | file1000769.h5
47 | file1001262.h5
48 | file1001851.h5
49 | file1000942.h5
50 |
--------------------------------------------------------------------------------
/activemri/data/splits/knee_singlecoil/val.txt:
--------------------------------------------------------------------------------
1 | file1002340.h5
2 | file1000182.h5
3 | file1001655.h5
4 | file1000972.h5
5 | file1001338.h5
6 | file1000476.h5
7 | file1002252.h5
8 | file1000591.h5
9 | file1000858.h5
10 | file1001202.h5
11 | file1002159.h5
12 | file1001163.h5
13 | file1001497.h5
14 | file1000196.h5
15 | file1001331.h5
16 | file1000477.h5
17 | file1001651.h5
18 | file1000464.h5
19 | file1000625.h5
20 | file1000033.h5
21 | file1000041.h5
22 | file1000000.h5
23 | file1000389.h5
24 | file1001668.h5
25 | file1002451.h5
26 | file1000990.h5
27 | file1001077.h5
28 | file1002389.h5
29 | file1001298.h5
30 | file1002570.h5
31 | file1001834.h5
32 | file1001955.h5
33 | file1001715.h5
34 | file1001444.h5
35 | file1002214.h5
36 | file1002187.h5
37 | file1001170.h5
38 | file1000926.h5
39 | file1000818.h5
40 | file1000702.h5
41 | file1001289.h5
42 | file1000528.h5
43 | file1000759.h5
44 | file1001862.h5
45 | file1001339.h5
46 | file1001090.h5
47 | file1002067.h5
48 | file1001221.h5
49 |
--------------------------------------------------------------------------------
/activemri/data/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | activemri.data.transforms.py
8 | ====================================
9 | Transform functions to process fastMRI data for reconstruction models.
10 | """
11 | from typing import Tuple, Union
12 |
13 | import fastmri
14 | import fastmri.data.transforms as fastmri_transforms
15 | import numpy as np
16 | import torch
17 |
18 | import activemri.data.singlecoil_knee_data as scknee_data
19 |
20 | TensorType = Union[np.ndarray, torch.Tensor]
21 |
22 |
23 | def add_gaussian_noise(kspace, mean=0., std=1.):
24 | # kspace: 32*32*2
25 | noise = mean + torch.randn(kspace.size()) * std
26 | kspace = kspace + noise
27 | return kspace
28 |
29 |
30 | def to_magnitude(tensor: torch.Tensor, dim: int) -> torch.Tensor:
31 | return (tensor ** 2).sum(dim=dim) ** 0.5
32 |
33 |
34 | def center_crop(x: TensorType, shape: Tuple[int, int]) -> TensorType:
35 | """Center crops a tensor to the desired 2D shape.
36 |
37 | Args:
38 | x(union(``torch.Tensor``, ``np.ndarray``)): The tensor to crop.
39 | Shape should be ``(batch_size, height, width)``.
40 | shape(tuple(int,int)): The desired shape to crop to.
41 |
42 | Returns:
43 | (union(``torch.Tensor``, ``np.ndarray``)): The cropped tensor.
44 | """
45 | assert len(x.shape) == 3
46 | assert 0 < shape[0] <= x.shape[1]
47 | assert 0 < shape[1] <= x.shape[2]
48 | h_from = (x.shape[1] - shape[0]) // 2
49 | w_from = (x.shape[2] - shape[1]) // 2
50 |     w_to = w_from + shape[1]  # shape[1] is the target width
51 |     h_to = h_from + shape[0]  # shape[0] is the target height
52 | x = x[:, h_from:h_to, w_from:w_to]
53 | return x
54 |
55 |
56 | def ifft_permute_maybe_shift(
57 | x: torch.Tensor, normalized: bool = False, ifft_shift: bool = False
58 | ) -> torch.Tensor:
59 | x = x.permute(0, 2, 3, 1)
60 | y = torch.ifft(x, 2, normalized=normalized)
61 | if ifft_shift:
62 | y = fastmri.ifftshift(y, dim=(1, 2))
63 | return y.permute(0, 3, 1, 2)
64 |
65 |
66 | def raw_transform_miccai2020(kspace=None, mask=None, **_kwargs):
67 |     """Transform to produce input for the reconstructor used in Pineda et al., MICCAI'20.
68 | 
69 |     Produces a zero-filled reconstruction and a mask that serve as input to models of type
70 |     :class:`activemri.models.cvpr19_reconstructor.CVPR19Reconstructor`. The mask is almost
71 | equal to the mask passed as argument, except that high-frequency padding columns are set
72 | to 1, and the mask is reshaped to be compatible with the reconstructor.
73 |
74 | Args:
75 | kspace(``np.ndarray``): The array containing the k-space data returned by the dataset.
76 | mask(``torch.Tensor``): The masks to apply to the k-space.
77 |
78 | Returns:
79 | tuple: A tuple containing:
80 |             - ``torch.Tensor``: The zero-filled reconstruction that will be passed to the
81 |                 reconstructor.
82 | - ``torch.Tensor``: The mask to use as input to the reconstructor.
83 | """
84 | # alter mask to always include the highest frequencies that include padding
85 | mask[
86 | :,
87 | :,
88 | scknee_data.MICCAI2020Data.START_PADDING : scknee_data.MICCAI2020Data.END_PADDING,
89 | ] = 1
90 | mask = mask.unsqueeze(1)
91 |
92 | all_kspace = []
93 | for ksp in kspace:
94 | all_kspace.append(torch.from_numpy(ksp).permute(2, 0, 1))
95 | k_space = torch.stack(all_kspace)
96 |
97 | masked_true_k_space = torch.where(
98 | mask.byte(),
99 | k_space,
100 | torch.tensor(0.0).to(mask.device),
101 | )
102 | reconstructor_input = ifft_permute_maybe_shift(masked_true_k_space, ifft_shift=True)
103 | return reconstructor_input, mask
104 |
105 |
106 | # Based on
107 | # https://github.com/facebookresearch/fastMRI/blob/master/experimental/unet/unet_module.py
108 | def _base_fastmri_unet_transform(
109 | kspace,
110 | mask,
111 | ground_truth,
112 | attrs,
113 | which_challenge="singlecoil",
114 | ):
115 | kspace = fastmri_transforms.to_tensor(kspace)
116 |
117 | mask = mask[..., : kspace.shape[-2]] # accounting for variable size masks
118 | masked_kspace = kspace * mask.unsqueeze(-1) + 0.0
119 |
120 | # inverse Fourier transform to get zero filled solution
121 | image = fastmri.ifft2c(masked_kspace)
122 |
123 | # crop input to correct size
124 | if ground_truth is not None:
125 | crop_size = (ground_truth.shape[-2], ground_truth.shape[-1])
126 | else:
127 | crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
128 |
129 | # check for FLAIR 203
130 | if image.shape[-2] < crop_size[1]:
131 | crop_size = (image.shape[-2], image.shape[-2])
132 |
133 | # noinspection PyTypeChecker
134 | image = fastmri_transforms.complex_center_crop(image, crop_size)
135 |
136 | # absolute value
137 | image = fastmri.complex_abs(image)
138 |
139 | # apply Root-Sum-of-Squares if multicoil data
140 | if which_challenge == "multicoil":
141 | image = fastmri.rss(image)
142 |
143 | # normalize input
144 | image, mean, std = fastmri_transforms.normalize_instance(image, eps=1e-11)
145 | image = image.clamp(-6, 6)
146 |
147 | return image.unsqueeze(0), mean, std
148 |
149 |
150 | def _batched_fastmri_unet_transform(
151 | kspace, mask, ground_truth, attrs, which_challenge="singlecoil"
152 | ):
153 | batch_size = len(kspace)
154 | images, means, stds = [], [], []
155 | for i in range(batch_size):
156 | image, mean, std = _base_fastmri_unet_transform(
157 | kspace[i],
158 | mask[i],
159 | ground_truth[i],
160 | attrs[i],
161 | which_challenge=which_challenge,
162 | )
163 | images.append(image)
164 | means.append(mean)
165 | stds.append(std)
166 | return torch.stack(images), torch.stack(means), torch.stack(stds)
167 |
168 |
169 | # noinspection PyUnusedLocal
170 | def fastmri_unet_transform_singlecoil(
171 | kspace=None, mask=None, ground_truth=None, attrs=None, fname=None, slice_id=None
172 | ):
173 | """
174 | Transform to use as input to fastMRI's Unet model for singlecoil data.
175 |
176 | This is an adapted version of the code found in
177 |     `fastMRI <https://github.com/facebookresearch/fastMRI/blob/master/experimental/unet/unet_module.py>`_.
178 | """
179 | return _batched_fastmri_unet_transform(
180 | kspace, mask, ground_truth, attrs, "singlecoil"
181 | )
182 |
183 |
184 | # noinspection PyUnusedLocal
185 | def fastmri_unet_transform_multicoil(
186 | kspace=None, mask=None, ground_truth=None, attrs=None, fname=None, slice_id=None
187 | ):
188 | """Transform to use as input to fastMRI's Unet model for multicoil data.
189 |
190 | This is an adapted version of the code found in
191 |     `fastMRI <https://github.com/facebookresearch/fastMRI/blob/master/experimental/unet/unet_module.py>`_.
192 | """
193 | return _batched_fastmri_unet_transform(
194 | kspace, mask, ground_truth, attrs, "multicoil"
195 | )
196 |
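197 | 
198 | # Usage sketch for center_crop (illustrative):
199 | # >>> x = torch.randn(4, 640, 368)   # batch of 2D images
200 | # >>> center_crop(x, (320, 320)).shape
201 | # torch.Size([4, 320, 320])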
--------------------------------------------------------------------------------
/activemri/envs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | __all__ = [
7 | "ActiveMRIEnv",
8 | "MICCAI2020Env",
9 | "FastMRIEnv",
10 | "SingleCoilKneeEnv",
11 | "MultiCoilKneeEnv",
12 | "FashionMNISTEnv",
13 | "DicomKneeEnv",
14 | "RealKneeEnv"
15 | ]
16 |
17 | from .envs import (
18 | ActiveMRIEnv,
19 | FastMRIEnv,
20 | MICCAI2020Env,
21 | MultiCoilKneeEnv,
22 | SingleCoilKneeEnv,
23 | FashionMNISTEnv,
24 | DicomKneeEnv,
25 | RealKneeEnv
26 | )
27 |
--------------------------------------------------------------------------------
/activemri/envs/loupe_envs.py:
--------------------------------------------------------------------------------
1 | import fastmri.data
2 | import pathlib
3 | import numpy as np
4 | import torch
5 | import torch.utils.data
6 | from torchvision import datasets
7 | from torchvision import transforms as TF
8 | from torch.utils.data import DataLoader, Subset, random_split
9 |
10 | import activemri.data.singlecoil_knee_data as scknee_data
11 | import activemri.data.transforms
12 | import activemri.envs.masks
13 | import activemri.envs.util
14 | import fastmri
15 | from activemri.data.real_knee_data import RealKneeData
16 | from activemri.data.real_brain_data import RealBrainData
17 | from typing import (
18 | Any,
19 | Callable,
20 | Dict,
21 | Iterator,
22 | List,
23 | Mapping,
24 | Optional,
25 | Sequence,
26 | Sized,
27 | Tuple,
28 | Union,
29 | )
30 | from activemri.envs.envs import ActiveMRIEnv, _env_collate_fn
31 |
32 | def transform(
33 | kspace: torch.Tensor,
34 | mask: torch.Tensor,
35 | target: torch.Tensor,
36 | attrs: List[Dict[str, Any]],
37 | fname: List[str],
38 | slice_id: List[int],
39 | ) -> Tuple:
40 | label = attrs
41 | return torch.from_numpy(kspace), mask, torch.from_numpy(target), label
42 |
43 | class LOUPEDataEnv():
44 | def __init__(self, options):
45 | self._data_location = options.data_path
46 | self.options = options
47 |
48 | def _setup_data_handlers(self):
49 | train_data, val_data, test_data = self._create_datasets()
50 |
51 | display_data = [val_data[i] for i in range(0, len(val_data), len(val_data) // 16)]
52 |
53 | train_loader = DataLoader(
54 | dataset=train_data,
55 | batch_size=self.options.batch_size,
56 | shuffle=True,
57 | num_workers=8,
58 | pin_memory=True,
59 | drop_last=True
60 | )
61 | val_loader = DataLoader(
62 | dataset=val_data,
63 | batch_size=self.options.batch_size,
64 | num_workers=8,
65 | shuffle=False,
66 | pin_memory=True
67 | )
68 | test_loader = DataLoader(
69 | dataset=test_data,
70 | batch_size=self.options.batch_size,
71 | num_workers=8,
72 | shuffle=False,
73 | pin_memory=True
74 | )
75 | display_loader = DataLoader(
76 | dataset=display_data,
77 | batch_size=16,
78 | num_workers=8,
79 | shuffle=False,
80 | pin_memory=True
81 | )
82 | return train_loader, val_loader, test_loader, display_loader
83 |
84 |
85 | class LOUPEActiveMRIEnv(LOUPEDataEnv):
86 | def __init__(self, options):
87 | super().__init__(options)
88 |
89 | def _create_datasets(self):
90 | root_path = pathlib.Path(self._data_location)
91 | train_path = root_path / "knee_singlecoil_train"
92 | val_and_test_path = root_path / "knee_singlecoil_val"
93 |
94 | train_data = scknee_data.MICCAI2020Data(
95 | train_path,
96 | ActiveMRIEnv._void_transform,
97 | num_cols=self.options.resolution[1],
98 | )
99 | val_data = scknee_data.MICCAI2020Data(
100 | val_and_test_path,
101 | ActiveMRIEnv._void_transform,
102 | custom_split="val",
103 | num_cols=self.options.resolution[1],
104 | )
105 | test_data = scknee_data.MICCAI2020Data(
106 | val_and_test_path,
107 | ActiveMRIEnv._void_transform,
108 | custom_split="test",
109 | num_cols=self.options.resolution[1],
110 | )
111 | return train_data, val_data, test_data
112 |
113 | class LOUPERealKspaceEnv(LOUPEDataEnv):
114 | def __init__(self, options):
115 | super().__init__(options)
116 |
117 | def _create_datasets(self):
118 | root_path = pathlib.Path(self._data_location)
119 | train_path = root_path / "knee_singlecoil_train"
120 | val_path = root_path / "knee_singlecoil_val"
121 | test_path = root_path / "knee_singlecoil_val"
122 |
123 | train_data = RealKneeData(
124 | train_path,
125 | self.options.resolution,
126 | ActiveMRIEnv._void_transform,
127 | self.options.noise_type,
128 | self.options.noise_level,
129 | random_rotate=self.options.random_rotate
130 | )
131 | val_data = RealKneeData(
132 | val_path,
133 | self.options.resolution,
134 | ActiveMRIEnv._void_transform,
135 | self.options.noise_type,
136 | self.options.noise_level,
137 | custom_split='val',
138 | random_rotate=self.options.random_rotate
139 | )
140 | test_data = RealKneeData(
141 | test_path,
142 | self.options.resolution,
143 | ActiveMRIEnv._void_transform,
144 | self.options.noise_type,
145 | self.options.noise_level,
146 | custom_split='test',
147 | random_rotate=self.options.random_rotate
148 | )
149 |
150 | return train_data, val_data, test_data
151 |
152 | class LoupeBrainEnv(LOUPEDataEnv):
153 | def __init__(self, options):
154 | super().__init__(options)
155 |
156 | def _create_datasets(self):
157 | root_path = pathlib.Path(self._data_location)
158 | train_path = root_path / "train_no_kspace"
159 | val_path = root_path / "val_no_kspace"
160 | test_path = root_path / "test_no_kspace"
161 |
162 | train_data = RealBrainData(
163 | train_path,
164 | self.options.resolution,
165 | ActiveMRIEnv._void_transform,
166 | self.options.noise_type,
167 | self.options.noise_level,
168 | random_rotate=self.options.random_rotate
169 | )
170 | val_data = RealBrainData(
171 | val_path,
172 | self.options.resolution,
173 | ActiveMRIEnv._void_transform,
174 | self.options.noise_type,
175 | self.options.noise_level,
176 | custom_split='val',
177 | random_rotate=self.options.random_rotate
178 | )
179 | test_data = RealBrainData(
180 | test_path,
181 | self.options.resolution,
182 | ActiveMRIEnv._void_transform,
183 | self.options.noise_type,
184 | self.options.noise_level,
185 | custom_split='test',
186 | random_rotate=self.options.random_rotate
187 | )
188 |
189 | return train_data, val_data, test_data
190 |
191 | class LOUPEFashionMNISTEnv(LOUPEDataEnv):
192 | def __init__(self, options):
193 | super().__init__(options)
194 |
195 |     def get_same_index(target, label):  # note: defined without self; behaves as a static helper
196 | label_indices = []
197 |
198 | for i in range(len(target)):
199 | if target[i] in label:
200 | label_indices.append(i)
201 |
202 | return label_indices
203 |
204 | def _create_datasets(self):
205 | train_val_set = FashionMNISTData(args=self.options, root=self._data_location, transform=transform,
206 | noise_type=self.options.noise_type, noise_level=self.options.noise_level, custom_split='train',
207 | label_range=self.options.label_range)
208 | test_set = FashionMNISTData(args=self.options, root=self._data_location, transform=transform,
209 | noise_type=self.options.noise_type, noise_level=self.options.noise_level, custom_split='test',
210 | label_range=self.options.label_range)
211 |
212 |
213 | val_set_size = int(len(train_val_set)/4)
214 | train_set_size = len(train_val_set) - val_set_size
215 | train_set, val_set = random_split(train_val_set, [train_set_size, val_set_size], generator=torch.Generator().manual_seed(42))
216 |
217 |
218 | if self.options.sample_rate < 1:
219 | # create a random subset
220 | train_set = random_split(train_set, [int(self.options.sample_rate * len(train_set)),
221 | len(train_set)-int(self.options.sample_rate * len(train_set))], generator=torch.Generator().manual_seed(42))[0]
222 |
223 | val_set = random_split(val_set, [int(self.options.sample_rate * len(val_set)),
224 | len(val_set)-int(self.options.sample_rate * len(val_set))], generator=torch.Generator().manual_seed(42))[0]
225 |
226 | return train_set, val_set, test_set
227 |
228 | class LOUPESyntheticEnv(LOUPEDataEnv):
229 | def __init__(self, options):
230 | super().__init__(options)
231 |
232 | def _create_datasets(self):
233 | train_set = SyntheticData(root=self._data_location, split='train')
234 | val_set = SyntheticData(root=self._data_location, split='val')
235 | test_set = SyntheticData(root=self._data_location, split='test')
236 |
237 | return train_set, val_set, test_set
238 |
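239 | # Usage sketch (illustrative): `options` is assumed to provide data_path,
240 | # batch_size, resolution, noise_type, noise_level and random_rotate.
241 | # env = LOUPERealKspaceEnv(options)
242 | # train_loader, val_loader, test_loader, display_loader = env._setup_data_handlers()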
--------------------------------------------------------------------------------
/activemri/envs/masks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """
7 | activemri.envs.masks.py
8 | ====================================
9 | Utilities to generate and manipulate active acquisition masks.
10 | """
11 | from typing import Any, Dict, List, Optional, Sequence, Tuple
12 |
13 | import fastmri
14 | import numpy as np
15 | import torch
16 |
17 |
18 | def update_masks_from_indices(
19 | masks: torch.Tensor, indices: Sequence[int]
20 | ) -> torch.Tensor:
21 | assert masks.shape[0] == len(indices)
22 | new_masks = masks.clone()
23 | for i in range(len(indices)):
24 | new_masks[i, ..., indices[i]] = 1
25 | return new_masks
26 |
27 |
28 | def check_masks_complete(masks: torch.Tensor) -> List[bool]:
29 | done = []
30 | for mask in masks:
31 | done.append(mask.bool().all().item())
32 | return done
33 |
34 |
35 | def sample_low_frequency_mask(
36 | mask_args: Dict[str, Any],
37 | kspace_shapes: List[Tuple[int, ...]],
38 | rng: np.random.RandomState,
39 | attrs: Optional[List[Dict[str, Any]]] = None,
40 | ) -> torch.Tensor:
41 | """Samples low frequency masks.
42 |
43 | Returns masks that contain some number of the lowest k-space frequencies active.
44 | The number of frequencies doesn't have to be the same for all masks in the batch, and
45 | it can also be a random number, depending on the given ``mask_args``. Active columns
46 | will be represented as 1s in the mask, and inactive columns as 0s.
47 |
48 | The distribution and shape of the masks can be controlled by ``mask_args``. This is a
49 | dictionary with the following keys:
50 |
51 | - *"max_width"(int)*: The maximum width of the masks.
52 |     - *"min_cols"(int)*: The minimum number of low-frequency columns to activate per side.
53 |     - *"max_cols"(int)*: The maximum number of low-frequency columns to activate
54 |         per side (inclusive).
55 | - *"width_dim"(int)*: Indicates which of the dimensions in ``kspace_shapes``
56 | corresponds to the k-space width.
57 | - *"centered"(bool)*: Specifies if the low frequencies are in the center of the
58 | k-space (``True``) or on the edges (``False``).
59 | - *"apply_attrs_padding"(optional(bool))*: If ``True``, the function will read
60 | keys ``"padding_left"`` and ``"padding_right"`` from ``attrs`` and set all
61 | corresponding high-frequency columns to 1.
62 |
63 | The number of 1s in the effective region of the mask (see next paragraph) is sampled
64 | between ``mask_args["min_cols"]`` and ``mask_args["max_cols"]`` (inclusive).
65 | The number of dimensions for the mask tensor will be ``mask_args["width_dim"] + 2``.
66 | The size will be ``[batch_size, 1, ..., 1, mask_args["max_width"]]``. For example, with
67 | ``mask_args["width_dim"] = 1`` and ``mask_args["max_width"] = 368``, output tensor
68 | has shape ``[batch_size, 1, 368]``.
69 |
70 |     This function supports simultaneously sampling masks for k-spaces with different numbers
71 |     of columns. This is controlled by the argument ``kspace_shapes``. From this list, the
72 |     function will obtain 1) ``batch_size = len(kspace_shapes)``, and 2) the width of the
73 |     k-space for each element in the batch. The i-th mask will have
74 |     ``kspace_shapes[i][mask_args["width_dim"]]``
75 |     *effective* columns.
76 |
77 |
78 | Note:
79 | The mask tensor returned will always have
80 | ``mask_args["max_width"]`` columns. However, for any element ``i``
81 | s.t. ``kspace_shapes[i][mask_args["width_dim"]] < mask_args["max_width"]``, the
82 | function will then pad the extra k-space columns with 1s. The rest of the columns
83 | will be filled out as if the mask has the same width as that indicated by
84 |         ``kspace_shapes[i]``.
85 |
86 | Args:
87 | mask_args(dict(str,any)): Specifies configuration options for the masks, as explained
88 | above.
89 |
90 | kspace_shapes(list(tuple(int,...))): Specifies the shapes of the k-space data on
91 | which this mask will be applied, as explained above.
92 |
93 | rng(``np.random.RandomState``): A random number generator to sample the masks.
94 |
95 |         attrs(optional(list(dict))): Used to determine any high-frequency padding. Each entry
96 |             must contain keys ``"padding_left"`` and ``"padding_right"``.
97 |
98 | Returns:
99 | ``torch.Tensor``: The generated low frequency masks.
100 |
101 | """
102 | batch_size = len(kspace_shapes)
103 | num_cols = [shape[mask_args["width_dim"]] for shape in kspace_shapes]
104 | mask = torch.zeros(batch_size, mask_args["max_width"])
105 | num_low_freqs = rng.randint(
106 | mask_args["min_cols"], mask_args["max_cols"] + 1, size=batch_size
107 | )
108 | for i in range(batch_size):
109 | # If padding needs to be accounted for, only add low frequency lines
110 | # beyond the padding
111 | if attrs and mask_args.get("apply_attrs_padding", False):
112 | padding_left = attrs[i]["padding_left"]
113 | padding_right = attrs[i]["padding_right"]
114 | else:
115 | padding_left, padding_right = 0, num_cols[i]
116 |
117 | pad = (num_cols[i] - 2 * num_low_freqs[i] + 1) // 2
118 | mask[i, pad : pad + 2 * num_low_freqs[i]] = 1
119 | mask[i, :padding_left] = 1
120 | mask[i, padding_right : num_cols[i]] = 1
121 |
122 | if not mask_args["centered"]:
123 | mask[i, : num_cols[i]] = fastmri.ifftshift(mask[i, : num_cols[i]])
124 | mask[i, num_cols[i] : mask_args["max_width"]] = 1
125 |
126 | mask_shape = [batch_size] + [1] * (mask_args["width_dim"] + 1)
127 | mask_shape[mask_args["width_dim"] + 1] = mask_args["max_width"]
128 | return mask.view(*mask_shape)
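129 | 
130 | 
131 | # Usage sketch (illustrative): centered low-frequency masks for two k-spaces
132 | # of widths 368 and 320; shapes follow the docstring above.
133 | # >>> rng = np.random.RandomState(0)
134 | # >>> mask_args = {"max_width": 368, "min_cols": 10, "max_cols": 20,
135 | # ...              "width_dim": 1, "centered": True}
136 | # >>> masks = sample_low_frequency_mask(mask_args, [(640, 368), (640, 320)], rng)
137 | # >>> tuple(masks.shape)
138 | # (2, 1, 368)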
129 |
--------------------------------------------------------------------------------
/activemri/envs/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import importlib
7 | import json
8 | import pathlib
9 |
10 | from typing import Dict, Tuple
11 |
12 | import numpy as np
13 | import skimage.metrics
14 | import torch
15 |
16 |
17 | def get_user_dir() -> pathlib.Path:
18 | # return pathlib.Path.home() / ".activemri"
19 | return pathlib.Path.cwd() / ".activemri"
20 |
21 |
22 | def maybe_create_datacache_dir() -> pathlib.Path:
23 | datacache_dir = get_user_dir() / "__datacache__"
24 | if not datacache_dir.is_dir():
25 | datacache_dir.mkdir()
26 | return datacache_dir
27 |
28 |
29 | def get_defaults_json() -> Tuple[Dict[str, str], str]:
30 | defaults_path = get_user_dir() / "defaults.json"
31 | if not pathlib.Path.exists(defaults_path):
32 | parent = defaults_path.parents[0]
33 | parent.mkdir(exist_ok=True)
34 | content = {"data_location": "", "saved_models_dir": ""}
35 | with defaults_path.open("w", encoding="utf-8") as f:
36 | json.dump(content, f)
37 | else:
38 | with defaults_path.open("r", encoding="utf-8") as f:
39 | content = json.load(f)
40 | return content, str(defaults_path)
41 |
42 |
43 | def import_object_from_str(classname: str):
44 |     # rsplit splits off the object name from the module path in one step
45 |     the_module, the_object = classname.rsplit(".", 1)
46 | module = importlib.import_module(the_module)
47 | return getattr(module, the_object)
48 |
49 |
50 | def compute_ssim(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray:
51 | ssims = []
52 | for i in range(xs.shape[0]):
53 | ssim = skimage.metrics.structural_similarity(
54 | xs[i].cpu().numpy(),
55 | ys[i].cpu().numpy(),
56 | data_range=ys[i].cpu().numpy().max(),
57 | )
58 | ssims.append(ssim)
59 | return np.array(ssims, dtype=np.float32)
60 |
61 |
62 | def compute_psnr(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray:
63 | psnrs = []
64 | for i in range(xs.shape[0]):
65 | psnr = skimage.metrics.peak_signal_noise_ratio(
66 | xs[i].cpu().numpy(),
67 | ys[i].cpu().numpy(),
68 | data_range=ys[i].cpu().numpy().max(),
69 | )
70 | psnrs.append(psnr)
71 | return np.array(psnrs, dtype=np.float32)
72 |
73 |
74 | def compute_mse(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray:
75 | dims = tuple(range(1, len(xs.shape)))
76 | return np.mean((ys.cpu().numpy() - xs.cpu().numpy()) ** 2, axis=dims)
77 |
78 | def compute_nmse(xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray:
79 | ys_numpy = ys.cpu().numpy()
80 | nmses = []
81 | for i in range(xs.shape[0]):
82 | x = xs[i].cpu().numpy()
83 | y = ys_numpy[i]
84 | nmse = np.linalg.norm(y - x) ** 2 / np.linalg.norm(y) ** 2
85 | nmses.append(nmse)
86 | return np.array(nmses, dtype=np.float32)
87 |
88 | def compute_mse_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
89 | dims = tuple(range(1, len(xs.shape)))
90 | return torch.mean((ys - xs) ** 2, dim=dims)
91 |
92 | def compute_nmse_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
93 | dims = tuple(range(1, len(xs.shape)))
94 | nmse = torch.linalg.norm(ys-xs, dim=dims)**2 / torch.linalg.norm(ys, dim=dims)**2
95 |
96 | return nmse
97 |
98 | from torch import nn
99 | import torch.nn.functional as F
100 |
101 | class SSIMLoss(nn.Module):
102 | """
103 | SSIM loss module.
104 | """
105 |
106 | def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
107 | """
108 | Args:
109 | win_size: Window size for SSIM calculation.
110 | k1: k1 parameter for SSIM calculation.
111 | k2: k2 parameter for SSIM calculation.
112 | """
113 | super().__init__()
114 | self.win_size = win_size
115 | self.k1, self.k2 = k1, k2
116 | self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size ** 2)
117 | NP = win_size ** 2
118 | self.cov_norm = NP / (NP - 1)
119 |
120 | def forward(self, X: torch.Tensor, Y: torch.Tensor, data_range: torch.Tensor):
121 | assert isinstance(self.w, torch.Tensor)
122 |
123 | data_range = data_range[:, None, None, None]
124 | C1 = (self.k1 * data_range) ** 2
125 | C2 = (self.k2 * data_range) ** 2
126 |         ux = F.conv2d(X, self.w)  # type: ignore
127 |         uy = F.conv2d(Y, self.w)
128 | uxx = F.conv2d(X * X, self.w)
129 | uyy = F.conv2d(Y * Y, self.w)
130 | uxy = F.conv2d(X * Y, self.w)
131 | vx = self.cov_norm * (uxx - ux * ux)
132 | vy = self.cov_norm * (uyy - uy * uy)
133 | vxy = self.cov_norm * (uxy - ux * uy)
134 | A1, A2, B1, B2 = (
135 | 2 * ux * uy + C1,
136 | 2 * vxy + C2,
137 | ux ** 2 + uy ** 2 + C1,
138 | vx + vy + C2,
139 | )
140 | D = B1 * B2
141 | S = (A1 * A2) / D
142 |
143 | dims = tuple(range(1, len(X.shape)))
144 |
145 | return S.mean(dim=dims)
146 |
147 | SSIM = SSIMLoss()
148 |
149 | def compute_ssim_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
150 | global SSIM
151 | SSIM = SSIM.to(xs.device)
152 | data_range = [y.max() for y in ys]
153 | data_range = torch.stack(data_range, dim=0)
154 |
155 | return SSIM(xs, ys, data_range=data_range.detach())
156 |
157 | def compute_psnr_torch(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
158 | mse = compute_mse_torch(xs, ys)
159 | data_range = [y.max() for y in ys]
160 | data_range = torch.stack(data_range, dim=0)
161 |
162 | return 10 * torch.log10((data_range ** 2) / mse)
163 |
164 | def compute_gaussian_nll_loss(reconstruction, target, logvar):
165 |     l2 = F.mse_loss(reconstruction, target, reduction='none')
166 |     # Clamp logvar so the variance stays in [exp(-9.2), exp(1.609)] ~= [1e-4, 5],
167 |     # for numerical stability
168 |     logvar = logvar.clamp(min=-9.2, max=1.609)
169 | one_over_var = torch.exp(-logvar)
170 |
171 |
172 | assert len(l2) == len(logvar)
173 | return 0.5 * (one_over_var * l2 + logvar)
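174 | 
175 | 
176 | # Usage sketch (illustrative): compute_ssim_torch expects N x 1 x H x W batches
177 | # and returns one SSIM value per example.
178 | # >>> xs, ys = torch.rand(2, 1, 128, 128), torch.rand(2, 1, 128, 128)
179 | # >>> compute_ssim_torch(xs, ys).shape
180 | # torch.Size([2])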
174 |
--------------------------------------------------------------------------------
/docs/1d_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/docs/1d_animation.gif
--------------------------------------------------------------------------------
/docs/2d_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/docs/2d_animation.gif
--------------------------------------------------------------------------------
/docs/GETTING_STARTED.md:
--------------------------------------------------------------------------------
1 | ## Getting Started with SeqMRI
2 |
3 | ### Prerequisites
4 |
5 | - Follow [INSTALL.md](INSTALL.md) to install all required libraries.
6 | - Download the FastMRI single coil knee data from [FastMRI](https://fastmri.med.nyu.edu/)
7 |
8 | ### Download data and organise as follows
9 |
10 | ```
11 | # For knee dataset
12 | └── datasets
13 | ├── knee
14 | ├── knee_singlecoil_train
15 | ├── knee_singlecoil_val
16 | ```
17 |
18 | ### Train & Evaluate in Command Line
19 |
20 | #### Loupe Training
21 |
22 | Please refer to [train_loupe.sh](../examples/loupe/train_loupe.sh) for details of each argument.
23 |
24 | ```bash
25 | # 4x line constrained sampling
26 | bash examples/loupe/train_loupe.sh 1 4 ssim real-knee 4 5e-5 10 0.5 128 0 cuda:0 1 0 0 0 0
27 |
28 | # 4x 2d point sampling
29 | bash examples/loupe/train_loupe.sh 1 22 ssim real-knee 4 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0
30 |
31 | # 8x 2d point sampling
32 | bash examples/loupe/train_loupe.sh 1 16 ssim real-knee 8 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0
33 |
34 | # 16x 2d point sampling
35 | bash examples/loupe/train_loupe.sh 1 12 ssim real-knee 16 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0
36 | ```
37 |
38 | ```bash
39 | # 4x line constrained sampling
40 | bash examples/loupe/train_loupe.sh 1 4 ssim brain 4 5e-5 10 0.5 128 0 cuda:0 1 0 0 0 0
41 |
42 | # 4x 2d point sampling
43 | bash examples/loupe/train_loupe.sh 1 22 ssim brain 4 1e-3 10 0.5 128 0 cuda:0 0 0 0 0 0
44 |
45 | ```
46 |
47 |
48 | #### Loupe Evaluation
49 |
50 |
51 | ```bash
52 | bash examples/loupe/test_loupe.sh EXP_DIR real-knee
53 | ```
54 |
55 | Remember to replace EXP_DIR with the path to the directory that contains the saved checkpoint.
56 |
57 |
58 | #### Sequential Sampling Training
59 |
60 | ```bash
61 | # 4x line constrained sampling
62 | bash examples/sequential/train_sequential.sh NUM_STEP 1 4 ssim real-knee 4 5e-5 cuda:0 10 0.5 128 0 1
63 |
64 | # 4x 2d point sampling
65 | bash examples/sequential/train_sequential.sh NUM_STEP 1 22 ssim real-knee 4 1e-3 cuda:0 10 0.5 128 0 0
66 |
67 | # 8x 2d point sampling
68 | bash examples/sequential/train_sequential.sh NUM_STEP 1 16 ssim real-knee 8 1e-3 cuda:0 10 0.5 128 0 0
69 |
70 | # 16x 2d point sampling
71 | bash examples/sequential/train_sequential.sh NUM_STEP 1 12 ssim real-knee 16 1e-3 cuda:0 10 0.5 128 0 0
72 | ```
73 |
74 | Remember to set NUM_STEP to a value in [1, 2, 4] for sequential sampling.
75 |
76 | #### Sequential Sampling Evaluation
77 |
78 | ```bash
79 | bash examples/sequential/test_sequential.sh EXP_DIR real-knee
80 | ```
81 |
82 | ### Model Zoo
83 |
84 | #### Line-constrained Sampling
85 |
86 | | Model | Acceleration | SSIM | Link |
87 | |---------|---------------|------|------|
88 | | Loupe | 4x | 89.5 | [URL](https://drive.google.com/drive/folders/1A-JFRd5KJ_HoCd2gYePjln67YzcsTiK5?usp=sharing) |
89 | | Seq1 | 4x | 90.8 | [URL](https://drive.google.com/drive/folders/1vcIaIdSnlDPElbQm8kusBOfxR8FfMlzc?usp=sharing) |
90 | | Seq4 | 4x | 91.2 | [URL](https://drive.google.com/drive/folders/1Y_fvnne5Gx0zaXFC0ANZYnlun7CeJ2Kv?usp=sharing) |
91 |
92 |
93 | #### 2D Point Sampling
94 |
95 | | Model | Acceleration | SSIM | Link |
96 | |---------|---------------|------|------|
97 | | Loupe | 4x | 92.4 | [URL](https://drive.google.com/drive/folders/1cTpc1V8EuLVyZ4iy3EIW_XhEzgmiecgN?usp=sharing) |
98 | | Seq1 | 4x | 92.7 | [URL](https://drive.google.com/drive/folders/1ptKDYk7Dbff9kOJBXUPkpmLqoNoPA4_z?usp=sharing) |
99 | | Seq4 | 4x | 92.9 | [URL](https://drive.google.com/drive/folders/1KG8vzruVlJkxlyywZUXDkQdXGyFJzaNB?usp=sharing) |
100 |
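101 | To evaluate a downloaded checkpoint, point the test scripts at the directory that
102 | contains it (the paths below are hypothetical):
103 | 
104 | ```bash
105 | bash examples/loupe/test_loupe.sh checkpoints/loupe_4x real-knee
106 | bash examples/sequential/test_sequential.sh checkpoints/seq4_4x real-knee
107 | ```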
--------------------------------------------------------------------------------
/docs/INSTALL.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 | This installation guide shows you how to set up the environment for running our code using conda or a Singularity container.
3 |
4 | First clone the ActiveMRI repository
5 | ```
6 | git clone https://github.com/tianweiy/SeqMRI.git
7 | cd SeqMRI
8 | ```
9 | Then create and activate a new conda environment
10 | ```
11 | conda create --name activemri python=3.7
12 | conda activate activemri
13 | ```
14 | Install PyTorch
15 | ```
16 | conda install pytorch=1.7.1 torchvision cudatoolkit=10.2 -c pytorch
17 | ```
18 | Install all requirements
19 | ```
20 | pip install -r requirements.txt
21 | ```
22 | ## Singularity Container
23 | To use the Singularity Container, run the following piece of code before creating the virtual environment.
24 |
25 | Login and enter API access token
26 | ```
27 | singularity remote login
28 | ```
29 | Build the image to a .sif file
30 | ```
31 | singularity build --remote active-mri.sif active-mri.def
32 | ```
33 | Run a singularity shell
34 | ```
35 | singularity shell --nv active-mri.sif
36 | ```
37 | Proceed with the installation instructions above.
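38 | 
39 | As an optional sanity check, verify that PyTorch is installed and detects the GPU:
40 | ```
41 | python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
42 | ```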
--------------------------------------------------------------------------------
/docs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/docs/teaser.png
--------------------------------------------------------------------------------
/examples/loupe/test_loupe.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | EXP_DIR=$1
4 | DATASET_NAME=$2
5 |
6 | python examples/loupe/train_loupe.py \
7 | --exp-dir ${EXP_DIR} \
8 | --dataset-name ${DATASET_NAME} \
9 | --model LOUPE \
10 | --input_chans 2 \
11 | --test
--------------------------------------------------------------------------------
/examples/loupe/train_loupe.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import argparse
7 | import pathlib
8 | import random
9 | import numpy as np
10 | import torch
11 | import uuid
12 | import activemri.envs.loupe_envs as loupe_envs
13 | from activemri.baselines.non_rl import NonRLTrainer, NonRLTester
14 | import matplotlib
15 | matplotlib.use('Agg')
16 |
17 |
18 | if __name__ == "__main__":
19 | parser = argparse.ArgumentParser(description='MRI Reconstruction Example')
20 | parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers')
21 | parser.add_argument('--num-step', type=int, default=2, help='Number of LSTM iterations')
22 | parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability')
23 | parser.add_argument('--num-chans', type=int, default=64, help='Number of U-Net channels')
24 |
25 | parser.add_argument('--batch-size', default=16, type=int, help='Mini batch size')
26 | parser.add_argument('--num-epochs', type=int, default=50, help='Number of training epochs')
27 | parser.add_argument('--noise-type', type=str, default='none', help='Type of additive noise to measurements')
28 | parser.add_argument('--noise-level', type=float, default=0, help='Noise level')
29 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
30 | parser.add_argument('--lr-step-size', type=int, default=40,
31 | help='Period of learning rate decay')
32 | parser.add_argument('--lr-gamma', type=float, default=0.1,
33 | help='Multiplicative factor of learning rate decay')
34 | parser.add_argument('--weight-decay', type=float, default=0.,
35 | help='Strength of weight decay regularization')
36 | parser.add_argument('--report-interval', type=int, default=100, help='Period of loss reporting')
37 | parser.add_argument('--data-parallel', action='store_true',
38 | help='If set, use multiple GPUs using data parallelism')
39 | parser.add_argument('--device', type=str, default='cuda',
40 | help='Which device to train on. Set to "cuda" to use the GPU')
41 | parser.add_argument('--exp-dir', type=pathlib.Path, required=True,
42 | help='Path where model and results should be saved')
43 | parser.add_argument('--checkpoint1', type=str,
44 | help='Path to an existing checkpoint. Used along with "--resume"')
45 | parser.add_argument('--entropy_weight', type=float, default=0.0,
46 | help='weight for the entropy/diversity loss')
47 | parser.add_argument('--recon_weight', type=float, default=1.0,
48 |                         help='weight for the reconstruction loss')
49 | parser.add_argument('--sparsity_weight', type=float, default=0.0,
50 | help='weight for the sparsity loss')
51 |     parser.add_argument('--save-model', type=lambda s: s.lower() in ('true', '1'), default=False, help='save model every iteration or not')
52 |
53 | parser.add_argument('--seed', default=42, type=int, help='Seed for random number generators')
54 | parser.add_argument('--resolution', default=[128, 128], nargs='+', type=int, help='Resolution of images')
55 |
56 | parser.add_argument('--dataset-name', type=str, choices=['fashion-mnist', 'dicom-knee', 'real-knee', 'brain'],
57 | required=True, help='name of the dataset')
58 | parser.add_argument('--sample-rate', type=float, default=1.,
59 | help='Fraction of total volumes to include')
60 |
61 | # Mask parameters
62 | parser.add_argument('--accelerations', nargs='+', default=[4], type=float,
63 | help='Ratio of k-space columns to be sampled. If multiple values are '
64 | 'provided, then one of those is chosen uniformly at random for '
65 | 'each volume.')
66 | parser.add_argument('--label_range', nargs='+', type=int, help='train using images of specific class')
67 | parser.add_argument('--model', type=str, help='name of the model to run', required=True)
68 |     parser.add_argument('--input_chans', type=int, choices=[1, 2], required=True, help='number of input channels: 1 for real images, 2 for complex images')
69 | parser.add_argument('--output_chans', type=int, default=1, help='number of output channels. One for real image')
70 | parser.add_argument('--line-constrained', type=int, default=0)
71 | parser.add_argument('--unet', action='store_true')
72 | parser.add_argument('--conjugate_mask', action='store_true', help='force loupe model to use conjugate symmetry.')
73 | parser.add_argument('--bi-dir', type=int, default=0)
74 | parser.add_argument('--loss_type', type=str, choices=['l1', 'ssim', 'psnr'], default='l1')
75 | parser.add_argument('--test_visual_frequency', type=int, default=1000)
76 | parser.add_argument('--test', action='store_true')
77 | parser.add_argument('--preselect', type=int, default=0)
78 | parser.add_argument('--preselect_num', type=int, default=2)
79 | parser.add_argument('--random_rotate', type=int, default=0)
80 | parser.add_argument('--random_baseline', type=int, default=0)
81 | parser.add_argument('--poisson', type=int, default=0)
82 | parser.add_argument('--spectrum', type=int, default=0)
83 | parser.add_argument("--equispaced", type=int, default=0)
84 |
85 |
86 | args = parser.parse_args()
87 | args.equispaced = args.equispaced > 0
88 | args.spectrum = args.spectrum > 0
89 | args.poisson = args.poisson > 0
90 | args.random = args.random_baseline > 0
91 | args.random_rotate = args.random_rotate > 0
92 | args.kspace_weight = 0
93 | args.line_constrained = args.line_constrained > 0
94 |
95 | if args.checkpoint1 is not None:
96 | args.resume = True
97 | else:
98 | args.resume = False
99 |
100 | noise_str = ''
101 |     if args.noise_type == 'none':
102 | noise_str = '_no_noise_'
103 | else:
104 | noise_str = '_' + args.noise_type + str(args.noise_level) + '_'
105 |
106 | if args.preselect > 0:
107 | args.preselect = True
108 | else:
109 | args.preselect = False
110 |
111 |     if args.bi_dir > 0:
112 | args.bi_dir = True
113 | else:
114 | args.bi_dir = False
115 |
116 |     if str(args.exp_dir) == 'auto':
117 | args.exp_dir =('checkpoints/'+args.dataset_name + '_' + str(float(args.accelerations[0])) +
118 | 'x_' + args.model + '_bi_dir_{}'.format(args.bi_dir)+ '_preselect_{}'.format(args.preselect) +
119 | noise_str + 'lr=' + str(args.lr) + '_bs=' + str(args.batch_size) + '_loss_type='+args.loss_type +
120 | '_epochs=' + str(args.num_epochs))
121 |
122 | args.exp_dir = pathlib.Path(args.exp_dir+'_uuid_'+uuid.uuid4().hex.upper()[0:6])
123 |
124 | print('save logs to {}'.format(args.exp_dir))
125 |
126 |
127 | args.visualization_dir = args.exp_dir / 'visualizations'
128 |
129 | if args.test:
130 | args.batch_size = 1
131 |
132 | if args.dataset_name == 'real-knee':
133 | args.data_path = 'datasets/knee'
134 | # args.resolution = [128, 128]
135 | env = loupe_envs.LOUPERealKspaceEnv(args)
136 | elif args.dataset_name == 'brain':
137 | args.data_path = 'datasets/brain'
138 | env = loupe_envs.LoupeBrainEnv(args)
139 | else:
140 | raise NotImplementedError
141 |
142 | # set random seeds
143 | random.seed(args.seed)
144 | np.random.seed(args.seed)
145 | torch.manual_seed(args.seed)
146 |
147 | if args.test:
148 | policy = NonRLTester(env, args.exp_dir, args, None)
149 | else:
150 | policy = NonRLTrainer(args, env, torch.device(args.device))
151 | policy()
152 |
--------------------------------------------------------------------------------
/examples/loupe/train_loupe.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | PRE_SELECT=$1 # whether to preselect center measurements before further sampling; should be 1 for all experiments
4 | PRE_SELECT_NUM=$2 # number of preselected lines / pixels in the 1d / 2d sampling case
5 | LOSS_TYPE=$3 # loss used for training. Options: [l1, ssim, psnr]
6 | DATASET_NAME=$4 # real-knee
7 | ACC=$5 # acceleration ratio
8 | LR=$6 # learning rate
9 | LR_STEP=${7} # epoch for learning rate decay
10 | GAMMA=${8} # learning rate decay ratio
11 | RESOLUTION=${9} # image resolution, default to [128, 128]
12 | ROTATE=${10} # rotate the k-space and corresponding image target
13 | DEVICE=${11} # GPU device id
14 | LINE_CONSTRAINED=${12} # use 1d sampling
15 | RANDOM_BASELINE=${13} # random sampling baseline
16 | POISSON=${14} # poisson sampling baseline
17 | SPECTRUM=${15} # spectrum sampling baseline
18 | EQUISPACED=${16} # equispaced sampling baseline
19 |
20 | python examples/loupe/train_loupe.py \
21 | --exp-dir auto \
22 | --dataset-name ${DATASET_NAME} \
23 | --accelerations ${ACC} \
24 | --model LOUPE \
25 | --input_chans 2 \
26 | --line-constrained ${LINE_CONSTRAINED} \
27 | --batch-size 16 \
28 | --noise-type gaussian \
29 | --loss_type "${LOSS_TYPE}" \
30 | --save-model True \
31 | --preselect ${PRE_SELECT} \
32 | --preselect_num ${PRE_SELECT_NUM} \
33 | --lr ${LR} \
34 | --lr-step-size ${LR_STEP} \
35 | --lr-gamma ${GAMMA} \
36 | --resolution ${RESOLUTION} ${RESOLUTION} \
37 | --random_rotate ${ROTATE} \
38 | --device ${DEVICE} \
39 | --random_baseline ${RANDOM_BASELINE} \
40 | --poisson ${POISSON} \
41 | --spectrum ${SPECTRUM} \
42 | --equispaced ${EQUISPACED}
43 |
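44 | # A hypothetical invocation (the argument values below are illustrative placeholders, not tuned settings):
45 | #   bash examples/loupe/train_loupe.sh 1 2 ssim real-knee 4 1e-3 40 0.5 128 0 cuda 1 0 0 0 0
46 | # i.e. preselect 2 center lines, SSIM loss, real-knee at 4x, lr 1e-3 decayed by 0.5 at epoch 40,
47 | # 128x128 resolution, no rotation, device "cuda", 1d line-constrained sampling, all baselines off.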
--------------------------------------------------------------------------------
/examples/sequential/__init__.py:
--------------------------------------------------------------------------------
1 | # from .classifiers import *  # commented out: no "classifiers" module exists in this package
--------------------------------------------------------------------------------
/examples/sequential/test_sequential.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | EXP_DIR=$1
4 | DATASET_NAME=$2
5 |
6 | python examples/sequential/train_sequential.py \
7 | --exp-dir ${EXP_DIR} \
8 | --dataset-name ${DATASET_NAME} \
9 | --model SequentialSampling \
10 | --input_chans 2 \
11 | --test
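12 |
13 | # A hypothetical invocation (both arguments are placeholders):
14 | #   bash examples/sequential/test_sequential.sh checkpoints/my_experiment real-knee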
--------------------------------------------------------------------------------
/examples/sequential/train_sequential.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import argparse
7 | import pathlib
8 | import random
9 | import numpy as np
10 | import torch
11 | import os
12 | import activemri.envs.loupe_envs as loupe_envs
13 | from activemri.baselines.non_rl import NonRLTrainer, NonRLTester
14 | import matplotlib
15 | matplotlib.use('Agg')
16 |
17 |
18 | if __name__ == "__main__":
19 | parser = argparse.ArgumentParser(description='MRI Reconstruction Example')
20 | parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers')
21 | parser.add_argument('--num-step', type=int, default=2, help='Number of LSTM iterations')
22 | parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability')
23 | parser.add_argument('--num-chans', type=int, default=64, help='Number of U-Net channels')
24 |
25 | parser.add_argument('--batch-size', default=4, type=int, help='Mini batch size')
26 | parser.add_argument('--num-epochs', type=int, default=50, help='Number of training epochs')
27 | parser.add_argument('--noise-type', type=str, default='none', help='Type of additive noise to measurements')
28 | parser.add_argument('--noise-level', type=float, default=5e-5, help='Noise level')
29 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
30 | parser.add_argument('--lr-step-size', type=int, default=40,
31 | help='Period of learning rate decay')
32 | parser.add_argument('--lr-gamma', type=float, default=0.1,
33 | help='Multiplicative factor of learning rate decay')
34 | parser.add_argument('--weight-decay', type=float, default=0.,
35 | help='Strength of weight decay regularization')
36 | parser.add_argument('--report-interval', type=int, default=100, help='Period of loss reporting')
37 | parser.add_argument('--data-parallel', action='store_true',
38 | help='If set, use multiple GPUs using data parallelism')
39 | parser.add_argument('--device', type=str, default='cuda',
40 | help='Which device to train on. Set to "cuda" to use the GPU')
41 | parser.add_argument('--exp-dir', type=pathlib.Path, required=True,
42 | help='Path where model and results should be saved')
43 | parser.add_argument('--checkpoint1', type=str,
44 | help='Path to an existing checkpoint. Used along with "--resume"')
45 | parser.add_argument('--entropy_weight', type=float, default=0.0,
46 | help='weight for the entropy/diversity loss')
47 | parser.add_argument('--recon_weight', type=float, default=1.0,
48 | help='weight for the reconstruction loss')
49 | parser.add_argument('--sparsity_weight', type=float, default=0.0,
50 | help='weight for the sparsity loss')
51 | parser.add_argument('--save-model', type=bool, default=False, help='save model every iteration or not')
52 |
53 | parser.add_argument('--seed', default=42, type=int, help='Seed for random number generators')
54 | parser.add_argument('--resolution', default=[128,128], nargs='+', type=int, help='Resolution of images')
55 |
56 | # Data parameters
57 | parser.add_argument('--data-path', type=pathlib.Path, required=False,
58 | help='Path to the dataset')
59 | parser.add_argument('--dataset-name', type=str, choices=['fashion-mnist', 'dicom-knee', 'real-knee', 'brain', 'synthetic'],
60 | required=True, help='name of the dataset')
61 | parser.add_argument('--sample-rate', type=float, default=1.,
62 | help='Fraction of total volumes to include')
63 |
64 | # Mask parameters
65 | parser.add_argument('--accelerations', nargs='+', default=[4], type=float,
66 | help='Ratio of k-space columns to be sampled. If multiple values are '
67 | 'provided, then one of those is chosen uniformly at random for '
68 | 'each volume.')
69 | parser.add_argument('--label_range', nargs='+', type=int, help='train using images of a specific class', default=None)
70 | parser.add_argument('--model', type=str, help='name of the model to run', required=True)
71 | parser.add_argument('--input_chans', type=int, choices=[1, 2], required=True, help='number of input channels. 1 for a real image, 2 for a complex image')
72 | parser.add_argument('--output_chans', type=int, default=1, help='number of output channels. 1 for a real image')
73 | parser.add_argument('--line-constrained', type=int, default=0)
74 | parser.add_argument('--unet', action='store_true')
75 | parser.add_argument('--preselect', type=int, default=0)
76 | parser.add_argument('--conjugate_mask', action='store_true', help='force loupe model to use conjugate symmetry.')
77 | parser.add_argument('--loss_type', type=str, choices=['l1', 'ssim', 'psnr', 'xentropy'], default='l1')
78 | parser.add_argument('--test_visual_frequency', type=int, default=1000)
79 | parser.add_argument('--test', action='store_true')
80 | parser.add_argument('--bi-dir', type=int, default=0)
81 | parser.add_argument('--preselect_num', type=int, default=2)
82 | parser.add_argument('--binary_sampler', type=int, default=0)
83 | parser.add_argument('--clamp', type=float, default=1e10)
84 | parser.add_argument('--old_recon', type=int, default=0)
85 | parser.add_argument('--uncertainty_loss', type=int, choices=[0, 1], default=0)
86 | parser.add_argument('--uncertainty_weight', type=float, default=0)
87 | parser.add_argument('--detach_kspace', type=int, default=0)
88 | parser.add_argument('--random_rotate', type=int, default=0)
89 | parser.add_argument('--kspace_weight', type=float, default=0)
90 | parser.add_argument('--pretrained_recon', type=int, default=0)
91 |
92 | args = parser.parse_args()
93 |
94 | args.pretrained_recon = args.pretrained_recon > 0
95 | args.random_rotate = args.random_rotate > 0
96 | args.line_constrained = args.line_constrained > 0
97 |
98 | if args.detach_kspace == 1:
99 | args.detach_kspace = True
100 | else:
101 | args.detach_kspace = False
102 |
103 | if args.uncertainty_loss == 1:
104 | args.uncertainty_loss = True
105 | else:
106 | args.uncertainty_loss = False
107 |
108 | if args.old_recon > 0:
109 | args.old_recon = True
110 | else:
111 | args.old_recon = False
112 |
113 | if args.checkpoint1 is not None:
114 | args.resume = True
115 | else:
116 | args.resume = False
117 |
118 | noise_str = ''
119 | if args.noise_type == 'none':
120 | noise_str = '_no_noise_'
121 | else:
122 | noise_str = '_' + args.noise_type + str(args.noise_level) + '_'
123 |
124 | if args.preselect > 0:
125 | args.preselect = True
126 | else:
127 | args.preselect = False
128 |
129 | if args.bi_dir > 0:
130 | args.bi_dir = True
131 | else:
132 | args.bi_dir = False
133 |
134 | if args.binary_sampler > 0:
135 | args.binary_sampler = True
136 | else:
137 | args.binary_sampler = False
138 |
139 |
140 | import uuid  # short unique suffix for experiment directories
141 |
142 | if str(args.exp_dir) == 'auto':
143 | if os.name == 'nt':
144 | args.exp_dir = pathlib.Path('checkpoints/'+uuid.uuid4().hex.upper()[0:6])
145 | else:
146 | args.exp_dir = ('checkpoints/'+args.dataset_name + '_' +
147 | str(float(args.accelerations[0])) + 'x_' + args.model + '_bi_dir_{}'.format(args.bi_dir) + '_step_{}'.format(args.num_step)
148 | + '_preselect_{}'.format(args.preselect) + noise_str + 'lr=' + str(args.lr)
149 | + '_bs=' + str(args.batch_size) + '_loss_type='+args.loss_type + '_epochs=' + str(args.num_epochs))
150 |
151 | args.exp_dir = pathlib.Path(str(args.exp_dir)+'_uuid_'+uuid.uuid4().hex.upper()[0:6])
152 |
153 | print('save logs to {}'.format(args.exp_dir))
154 |
155 | args.visualization_dir = args.exp_dir / 'visualizations'
156 |
157 | if args.num_step == 0:
158 | args.fixed_input = True
159 | args.num_step = 1
160 | else:
161 | args.fixed_input = False
162 |
163 | if args.test:
164 | args.batch_size = 1
165 |
166 | if args.dataset_name == 'fashion-mnist':
167 | args.data_path = 'datasets/fashion-mnist'
168 | # args.resolution = [64, 64]
169 | env = loupe_envs.LOUPEFashionMNISTEnv(args)
170 | elif args.dataset_name == 'dicom-knee':
171 | args.data_path = 'datasets/knee'
172 | # args.resolution = [128, 128]
173 | env = loupe_envs.LOUPEDICOMEnv(args)
174 | elif args.dataset_name == 'brain':
175 | args.data_path = 'datasets/brain'
176 | env = loupe_envs.LoupeBrainEnv(args)
177 | elif args.dataset_name == 'real-knee':
178 | args.data_path = 'datasets/knee'
179 | # args.resolution = [128, 128]
180 | env = loupe_envs.LOUPERealKspaceEnv(args)
181 |
182 | elif args.dataset_name == 'synthetic':
183 | args.data_path = 'datasets/synthetic'
184 | # args.resolution = [48, 48]
185 | env = loupe_envs.LOUPESyntheticEnv(args)
186 | else:
187 | raise NotImplementedError
188 | # set random seeds
189 | random.seed(args.seed)
190 | np.random.seed(args.seed)
191 | torch.manual_seed(args.seed)
192 |
193 | if args.test:
194 | policy = NonRLTester(env, args.exp_dir, args, None)
195 | else:
196 | policy = NonRLTrainer(args, env, torch.device(args.device))
197 | policy()
198 |
--------------------------------------------------------------------------------
/examples/sequential/train_sequential.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | NUM_STEP=$1 # number of sequential sampling steps (excluding preselection)
4 | PRE_SELECT=$2 # whether to preselect center measurements before further sampling
5 | PRE_SELECT_NUM=$3 # number of preselected lines / pixels in the 1d / 2d sampling case
6 | LOSS_TYPE=$4 # loss used for training. Options: [l1, ssim, psnr]
7 | DATASET_NAME=$5 # real-knee
8 | ACCELERATIONS=$6 # acceleration ratio
9 | LR=$7 # learning rate
10 | DEVICE=${8} # GPU device id
11 | LR_STEP=${9} # epoch for learning rate decay
12 | GAMMA=${10} # learning rate decay ratio
13 | RESOLUTION=${11} # image resolution, default to [128, 128]
14 | ROTATE=${12} # rotate the k-space and corresponding image target
15 | LINE_CONSTRAINED=${13} # use 1d sampling
16 |
17 | python examples/sequential/train_sequential.py \
18 | --exp-dir auto \
19 | --dataset-name ${DATASET_NAME} \
20 | --accelerations ${ACCELERATIONS} \
21 | --model SequentialSampling \
22 | --input_chans 2 \
23 | --line-constrained ${LINE_CONSTRAINED} \
24 | --unet \
25 | --save-model True \
26 | --batch-size 16 \
27 | --noise-type gaussian \
28 | --loss_type "${LOSS_TYPE}" \
29 | --num-step ${NUM_STEP} \
30 | --preselect ${PRE_SELECT} \
31 | --preselect_num ${PRE_SELECT_NUM} \
32 | --lr ${LR} \
33 | --device ${DEVICE} \
34 | --lr-step-size ${LR_STEP} \
35 | --lr-gamma ${GAMMA} \
36 | --resolution ${RESOLUTION} ${RESOLUTION} \
37 | --random_rotate ${ROTATE}
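38 |
39 | # A hypothetical invocation (the argument values below are illustrative placeholders, not tuned settings):
40 | #   bash examples/sequential/train_sequential.sh 4 1 2 ssim real-knee 4 1e-3 cuda 40 0.1 128 0 1
41 | # i.e. 4 sequential sampling steps, preselect 2 center lines, SSIM loss, real-knee at 4x,
42 | # lr 1e-3 decayed by 0.1 at epoch 40, 128x128 resolution, no rotation, 1d line-constrained sampling.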
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_loupe_785FD0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_loupe_785FD0.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_loupe_85A60D.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_loupe_85A60D.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_loupe_FB9BE7.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_loupe_FB9BE7.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq0_191014.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq0_191014.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq0_997E9B.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq0_997E9B.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq0_B2F31C.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq0_B2F31C.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq1_0F3D13.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq1_0F3D13.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq1_8D8ED1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq1_8D8ED1.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq1_9BA1F7.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq1_9BA1F7.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq2_806982.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq2_806982.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq2_A0B684.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq2_A0B684.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq2_AB92C0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq2_AB92C0.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq4_2C2751.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq4_2C2751.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq4_8E14D6.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq4_8E14D6.pkl
--------------------------------------------------------------------------------
/figure_reproduction/1d_data/4x_lc_seq4_C8B962.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/1d_data/4x_lc_seq4_C8B962.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_loupe_1CB132.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_loupe_1CB132.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_loupe_CF3852.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_loupe_CF3852.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_loupe_F5C5C1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_loupe_F5C5C1.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq0_04BC1F.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq0_04BC1F.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq0_09DB18.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq0_09DB18.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq0_0F247E.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq0_0F247E.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq1_13BF86.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq1_13BF86.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq1_9CB984.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq1_9CB984.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq1_E154DC.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq1_E154DC.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq2_8934DD.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq2_8934DD.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq2_B11CEC.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq2_B11CEC.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq2_B68595.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq2_B68595.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq4_203258.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq4_203258.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq4_2450E9.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq4_2450E9.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/16x_2d_seq4_BD196A.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/16x_2d_seq4_BD196A.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_loupe_72B682.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_loupe_72B682.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_loupe_73A783.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_loupe_73A783.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_loupe_C655FA.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_loupe_C655FA.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq0_076771.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq0_076771.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq0_350405.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq0_350405.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq0_9E225B.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq0_9E225B.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq1_748B84.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq1_748B84.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq1_F3945C.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq1_F3945C.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq1_F8733A.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq1_F8733A.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq2_244A77.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq2_244A77.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq2_93381A.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq2_93381A.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq2_A4DA53.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq2_A4DA53.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq4_08BBA4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq4_08BBA4.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq4_C32832.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq4_C32832.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/4x_2d_seq4_D0D30F.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/4x_2d_seq4_D0D30F.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_loupe_4F9069.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_loupe_4F9069.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_loupe_5C47BE.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_loupe_5C47BE.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_loupe_D010FA.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_loupe_D010FA.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq0_130F17.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq0_130F17.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq0_72A09A.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq0_72A09A.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq0_CB05F5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq0_CB05F5.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq1_848B51.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq1_848B51.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq1_948938.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq1_948938.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq1_EC7415.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq1_EC7415.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq2_0B402E.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq2_0B402E.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq2_2956B7.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq2_2956B7.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq2_4F5297.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq2_4F5297.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq4_300BD3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq4_300BD3.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq4_56CED5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq4_56CED5.pkl
--------------------------------------------------------------------------------
/figure_reproduction/2d_data/8x_2d_seq4_7A9703.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/2d_data/8x_2d_seq4_7A9703.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/left_test_iter=651_step0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step0.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/left_test_iter=651_step1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step1.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/left_test_iter=651_step2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step2.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/left_test_iter=651_step3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/left_test_iter=651_step3.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/right_test_iter=651_step0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step0.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/right_test_iter=651_step1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step1.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/right_test_iter=651_step2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step2.pkl
--------------------------------------------------------------------------------
/figure_reproduction/teaser_data/right_test_iter=651_step3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/figure_reproduction/teaser_data/right_test_iter=651_step3.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.18.5
2 | scikit_image>=0.16.2
3 | pytest>=5.4.3
4 | filelock>=3.0.12
5 | gym>=0.17.2
6 | tensorboardX>=2.1
7 | imageio-ffmpeg>=0.4.2
8 | jupyter
9 | nbsphinx>=0.7.1
10 | sphinx>=3.2.1
11 | sphinxcontrib-napoleon>=0.7
12 | sphinxcontrib-osexample>=0.1.1
13 | sphinx-rtd-theme>=0.5.0
14 | sigpy
15 | fastmri@git+https://github.com/facebookresearch/fastMRI.git
--------------------------------------------------------------------------------
/resources/equispaced_4x_128.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/equispaced_4x_128.pt
--------------------------------------------------------------------------------
/resources/spectrum_16x_128.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/spectrum_16x_128.pt
--------------------------------------------------------------------------------
/resources/spectrum_4x_128.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/spectrum_4x_128.pt
--------------------------------------------------------------------------------
/resources/spectrum_8x_128.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianweiy/SeqMRI/930c056284ab5538881cb6cf4138a107944fd29c/resources/spectrum_8x_128.pt
--------------------------------------------------------------------------------
/split_data.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import h5py
3 |
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from tqdm import tqdm
7 |
8 | train_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/train')
9 | val_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/val')
10 |
11 | def remove_kspace(path):
12 | for fname in tqdm(list(path.iterdir())):
13 | if not fname.name.endswith('.h5'):
14 | continue # Skip non-.h5 entries (e.g. directories)
15 | new_dir = fname.parent.parent / pathlib.Path(str(fname.parent.name) + '_no_kspace')
16 | if not new_dir.exists():
17 | new_dir.mkdir(parents=False)
18 | new_filename = new_dir / fname.name
19 | if new_filename.exists():
20 | continue # Skip already done files
21 | # Use context managers so both files are closed even if copying fails.
22 | with h5py.File(fname, 'r') as f, h5py.File(new_filename, 'w') as fn:
23 | for at in f.attrs:
24 | fn.attrs[at] = f.attrs[at]
25 | for dat in f:
26 | if dat == 'kspace':
27 | continue
28 | f.copy(dat, fn)
29 |
30 | # Run the calls below to remove the stored k-space from the multicoil brain .h5 files, which will save on I/O later.
31 | # We don't need the multicoil k-space since we will construct singlecoil k-space from the ground truth images.
32 | # Commented out for safety.
33 |
34 | # remove_kspace(train_dir)
35 | # remove_kspace(val_dir)
36 |
37 | def create_train_test_split(orig_train_dir, target_test_dir, test_frac):
38 | """
39 | Creates a train and test split from the provided training data. Works by
40 | moving random volumes from the training directory to a new test directory.
41 |
42 | WARNING: Only use this function once to create the required datasets!
43 | """
44 | import shutil
45 | np.random.seed(0)
46 |
47 | files = sorted(list(orig_train_dir.iterdir()))
48 | target_test_dir.mkdir(parents=False, exist_ok=False)
49 |
50 | permutation = np.random.permutation(len(files))
51 | test_indices = permutation[:int(len(files) * test_frac)]
52 | test_files = list(np.array(files)[test_indices])
53 |
54 | for i, file in enumerate(test_files):
55 | print("Moving file {}/{}".format(i + 1, len(test_files)))
56 | shutil.move(file, target_test_dir / file.name)
57 |
58 |
59 | def count_slices(data_dir, dataset):
60 | vol_count, slice_count = 0, 0
61 | for fname in data_dir.iterdir():
62 | with h5py.File(fname, 'r') as data:
63 | if dataset == 'brain':
64 | gt = data['reconstruction_rss'][()]
65 | vol_count += 1
66 | slice_count += gt.shape[0]
67 | print(f'{vol_count} volumes, {slice_count} slices')
68 |
69 | # For both Knee and Brain data, split off 20% of train as test
70 | dataset = 'brain' # or 'knee'
71 | train_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/train_no_kspace')
72 | val_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/val_no_kspace')
73 | test_dir = pathlib.Path('/home/tianweiy/ActiveMRI-Release/datasets/brain/test_no_kspace')
74 |
75 | test_frac = 0.2
76 |
77 | # Run this to split off test_frac of the train data into test data.
78 | # Commented out for safety.
79 |
80 | # create_train_test_split(train_dir, test_dir, test_frac)
81 |
82 | count_slices(train_dir, dataset)
83 | count_slices(val_dir, dataset)
84 | count_slices(test_dir, dataset)
85 |
--------------------------------------------------------------------------------