├── common ├── __init__.py ├── utils.py ├── test_subsample.py ├── args.py ├── evaluate.py └── subsample.py ├── experimental ├── __init__.py └── MANet │ ├── __init__.py │ ├── job.sh │ ├── train.py │ ├── mri_manet_dirs.yaml │ ├── RAdam.py │ └── module_MANet.py ├── fastmri ├── models │ ├── __init__.py │ ├── utils │ │ ├── solver.py │ │ ├── loss.py │ │ ├── transforms.py │ │ ├── comm.py │ │ └── dataloader.py │ ├── common.py │ └── MANet.py ├── data │ ├── __init__.py │ ├── volume_sampler.py │ ├── README.md │ ├── transforms.py │ ├── mri_data.py │ └── subsample.py ├── __init__.py ├── coil_combine.py ├── utils.py ├── losses.py ├── math.py ├── evaluate.py └── mri_module.py └── README.md /common/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /experimental/MANet/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /experimental/MANet/job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -n 10 3 | #SBATCH --gres=gpu:v100:2 4 | #SBATCH --time=48:00:00 5 | 6 | export PATH=/home/jc3/miniconda2/bin/:$PATH 7 | source activate pytorch-1.5.0 8 | 9 | 10 | python train.py -------------------------------------------------------------------------------- /fastmri/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | # from .unet import UNet 9 | # from .varnet import NormUnet, SensitivityModel, VarNet, VarNetBlock 10 | -------------------------------------------------------------------------------- /fastmri/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from .mri_data import SliceDataset 9 | # from .mri_data_ixi_jiangsu import IXIdataset 10 | # from .mri_data_ixi_lianying import IXIdataset 11 | # form .mri_data_ixi_jiangsu import 12 | -------------------------------------------------------------------------------- /fastmri/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | from .coil_combine import rss, rss_complex 9 | from .losses import SSIMLoss 10 | from .math import ( 11 | complex_abs, 12 | complex_abs_sq, 13 | complex_conj, 14 | complex_mul, 15 | fft2c, 16 | fftshift, 17 | ifft2c, 18 | ifftshift, 19 | roll, 20 | tensor_to_complex_np, 21 | ) 22 | from .mri_module import MriModule 23 | # from .mri_module_ixi import MriModule 24 | # from .mri_module_ixi_T1T2 import MriModule as MriModuleT1T2 25 | from .utils import save_reconstructions 26 | -------------------------------------------------------------------------------- /fastmri/models/utils/solver.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yuanhao Cai 3 | @date: 2020.03 4 | """ 5 | 6 | import torch.optim as optim 7 | 8 | def make_optimizer(cfg, model, num_gpu): 9 | if cfg.SOLVER.OPTIMIZER == 'Adam': 10 | optimizer = optim.Adam(model.parameters(), 11 | lr=cfg.SOLVER.BASE_LR * num_gpu, 12 | betas=(0.9, 0.999), eps=1e-08, 13 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 14 | elif cfg.SOLVER.OPTIMIZER == 'SGD': 15 | optimizer = optim.SGD(model.parameters(), lr=cfg.SOLVER.BASE_LR, 16 | momentum=cfg.SOLVER.MOMENTUM, 17 | weight_decay=cfg.SOLVER.WEIGHT_DECAY) 18 | 19 | return optimizer 20 | 21 | 22 | def make_lr_scheduler(cfg, optimizer): 23 | w_iters = cfg.SOLVER.WARMUP_ITERS 24 | w_fac = cfg.SOLVER.WARMUP_FACTOR 25 | max_iter = cfg.SOLVER.MAX_ITER 26 | lr_lambda = lambda iteration : w_fac + (1 - w_fac) * iteration / w_iters \ 27 | if iteration < w_iters \ 28 | else 1 - (iteration - w_iters) / (max_iter - w_iters) 29 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) 30 | 31 | return scheduler 32 | 33 | -------------------------------------------------------------------------------- /fastmri/coil_combine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | 10 | import fastmri 11 | 12 | 13 | def rss(data, dim=0): 14 | """ 15 | Compute the Root Sum of Squares (RSS). 16 | 17 | RSS is computed assuming that dim is the coil dimension. 18 | 19 | Args: 20 | data (torch.Tensor): The input tensor 21 | dim (int): The dimensions along which to apply the RSS transform 22 | 23 | Returns: 24 | torch.Tensor: The RSS value. 25 | """ 26 | return torch.sqrt((data ** 2).sum(dim)) 27 | 28 | 29 | def rss_complex(data, dim=0): 30 | """ 31 | Compute the Root Sum of Squares (RSS) for complex inputs. 32 | 33 | RSS is computed assuming that dim is the coil dimension. 34 | 35 | Args: 36 | data (torch.Tensor): The input tensor 37 | dim (int): The dimensions along which to apply the RSS transform 38 | 39 | Returns: 40 | torch.Tensor: The RSS value. 41 | """ 42 | return torch.sqrt(fastmri.complex_abs_sq(data).sum(dim)) 43 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree.
6 | """ 7 | import json 8 | 9 | import h5py 10 | 11 | 12 | def save_reconstructions(reconstructions, out_dir): 13 | """ 14 | Saves the reconstructions from a model into h5 files that is appropriate for submission 15 | to the leaderboard. 16 | 17 | Args: 18 | reconstructions (dict[str, np.array]): A dictionary mapping input filenames to 19 | corresponding reconstructions (of shape num_slices x height x width). 20 | out_dir (pathlib.Path): Path to the output directory where the reconstructions 21 | should be saved. 22 | """ 23 | out_dir.mkdir(exist_ok=True, parents=True) 24 | for fname, recons in reconstructions.items(): 25 | with h5py.File(out_dir / fname, 'w') as f: 26 | f.create_dataset('reconstruction', data=recons) 27 | 28 | 29 | def tensor_to_complex_np(data): 30 | """ 31 | Converts a complex torch tensor to numpy array. 32 | Args: 33 | data (torch.Tensor): Input data to be converted to numpy. 34 | 35 | Returns: 36 | np.array: Complex numpy version of data 37 | """ 38 | data = data.numpy() 39 | return data[..., 0] + 1j * data[..., 1] 40 | -------------------------------------------------------------------------------- /fastmri/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import h5py 9 | import os 10 | import torch 11 | 12 | def save_reconstructions(reconstructions, out_dir): 13 | """ 14 | Save reconstruction images. 15 | 16 | This function writes to h5 files that are appropriate for submission to the 17 | leaderboard. 18 | 19 | Args: 20 | reconstructions (dict[str, np.array]): A dictionary mapping input 21 | filenames to corresponding reconstructions (of shape num_slices x 22 | height x width). 23 | out_dir (pathlib.Path): Path to the output directory where the 24 | reconstructions should be saved. 25 | """ 26 | # out_dir.mkdir(exist_ok=True, parents=True) 27 | 28 | os.makedirs(str(out_dir), exist_ok=True) 29 | print(out_dir)#logs/unet/unet_demo/zpCar6X/unet_demo/reconstructions 30 | for fname, recons in reconstructions.items(): 31 | print(fname) 32 | with h5py.File(str(out_dir) + '/' + str(fname) + '.hdf5', "w") as f: 33 | print(fname) 34 | if isinstance(recons, list): 35 | recons = [r[1][None, ...] for r in recons] 36 | recons = torch.cat(recons, dim=0) 37 | f.create_dataset("reconstruction", data=recons) 38 | 39 | -------------------------------------------------------------------------------- /common/test_subsample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import numpy as np 9 | import pytest 10 | import torch 11 | 12 | from common.subsample import RandomMaskFunc 13 | 14 | 15 | @pytest.mark.parametrize("center_fracs, accelerations, batch_size, dim", [ 16 | ([0.2], [4], 4, 320), 17 | ([0.2, 0.4], [4, 8], 2, 368), 18 | ]) 19 | def test_random_mask_reuse(center_fracs, accelerations, batch_size, dim): 20 | mask_func = RandomMaskFunc(center_fracs, accelerations) 21 | shape = (batch_size, dim, dim, 2) 22 | mask1 = mask_func(shape, seed=123) 23 | mask2 = mask_func(shape, seed=123) 24 | mask3 = mask_func(shape, seed=123) 25 | assert torch.all(mask1 == mask2) 26 | assert torch.all(mask2 == mask3) 27 | 28 | 29 | @pytest.mark.parametrize("center_fracs, accelerations, batch_size, dim", [ 30 | ([0.2], [4], 4, 320), 31 | ([0.2, 0.4], [4, 8], 2, 368), 32 | ]) 33 | def test_random_mask_low_freqs(center_fracs, accelerations, batch_size, dim): 34 | mask_func = RandomMaskFunc(center_fracs, accelerations) 35 | shape = (batch_size, dim, dim, 2) 36 | mask = mask_func(shape, seed=123) 37 | mask_shape = [1 for _ in shape] 38 | mask_shape[-2] = dim 39 | assert list(mask.shape) == mask_shape 40 | 41 | num_low_freqs_matched = False 42 | for center_frac in center_fracs: 43 | num_low_freqs = int(round(dim * center_frac)) 44 | pad = (dim - num_low_freqs + 1) // 2 45 | if np.all(mask[pad:pad + num_low_freqs].numpy() == 1): 46 | num_low_freqs_matched = True 47 | assert num_low_freqs_matched 48 | -------------------------------------------------------------------------------- /fastmri/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SSIMLoss(nn.Module): 14 | """ 15 | SSIM loss module. 16 | """ 17 | 18 | def __init__(self, win_size=7, k1=0.01, k2=0.03): 19 | """ 20 | Args: 21 | win_size (int, default=7): Window size for SSIM calculation. 22 | k1 (float, default=0.1): k1 parameter for SSIM calculation. 23 | k2 (float, default=0.03): k2 parameter for SSIM calculation. 
24 | """ 25 | super().__init__() 26 | self.win_size = win_size 27 | self.k1, self.k2 = k1, k2 28 | self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size ** 2) 29 | NP = win_size ** 2 30 | self.cov_norm = NP / (NP - 1) 31 | 32 | def forward(self, X, Y, data_range): 33 | data_range = data_range[:, None, None, None] 34 | C1 = (self.k1 * data_range) ** 2 35 | C2 = (self.k2 * data_range) ** 2 36 | ux = F.conv2d(X, self.w) 37 | uy = F.conv2d(Y, self.w) 38 | uxx = F.conv2d(X * X, self.w) 39 | uyy = F.conv2d(Y * Y, self.w) 40 | uxy = F.conv2d(X * Y, self.w) 41 | vx = self.cov_norm * (uxx - ux * ux) 42 | vy = self.cov_norm * (uyy - uy * uy) 43 | vxy = self.cov_norm * (uxy - ux * uy) 44 | A1, A2, B1, B2 = ( 45 | 2 * ux * uy + C1, 46 | 2 * vxy + C2, 47 | ux ** 2 + uy ** 2 + C1, 48 | vx + vy + C2, 49 | ) 50 | D = B1 * B2 51 | S = (A1 * A2) / D 52 | 53 | return 1 - S.mean() 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MARIO 2 | # Deep multi-modal aggregation network for MR image reconstruction with auxiliary modality 3 | 4 | ## Dependencies 5 | * numpy==1.18.5 6 | * scikit_image==0.16.2 7 | * torchvision==0.8.1 8 | * torch==1.7.0 9 | * runstats==1.8.0 10 | * pytorch_lightning==1.0.6 11 | * h5py==2.10.0 12 | * PyYAML==5.4 13 | 14 | Our code is based on the fastMRI, more details can be find at https://github.com/facebookresearch/fastMRI. 15 | 16 | **Train** 17 | ```bash 18 | cd experimental/MANet/ 19 | sbatch job.sh 20 | ``` 21 | 22 | **Change other arguments that you can train your own model.** 23 | 24 | **The detailed parameter settings can be find in our arXiv paper.** 25 | 26 | Citation 27 | 28 | If you use our code in your project, please cite the arXiv papers: 29 | 30 | ``` 31 | @inproceedings{zbontar2018fastMRI, 32 | title={{fastMRI}: An Open Dataset and Benchmarks for Accelerated {MRI}}, 33 | author={Jure Zbontar and Florian Knoll and Anuroop Sriram and Tullie Murrell and Zhengnan Huang and Matthew J. Muckley and Aaron Defazio and Ruben Stern and Patricia Johnson and Mary Bruno and Marc Parente and Krzysztof J. Geras and Joe Katsnelson and Hersh Chandarana and Zizhao Zhang and Michal Drozdzal and Adriana Romero and Michael Rabbat and Pascal Vincent and Nafissa Yakubova and James Pinkerton and Duo Wang and Erich Owens and C. Lawrence Zitnick and Michael P. Recht and Daniel K. Sodickson and Yvonne W. 
Lui}, 34 | journal = {ArXiv e-prints}, 35 | archivePrefix = "arXiv", 36 | eprint = {1811.08839}, 37 | year={2018} 38 | } 39 | @article{feng2021multi, 40 | title={Deep multi-modal aggregation network for MR image reconstruction with auxiliary modality}, 41 | author={Feng, Chun-Mei and Fu, Huazhu and Zhou, Tianfei and Xu, Yong and Shao, Ling and Zhang, David}, 42 | journal={arXiv preprint arXiv:2110.08080}, 43 | year={2021} 44 | } 45 | ``` 46 | -------------------------------------------------------------------------------- /fastmri/models/utils/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yuanhao Cai 3 | @date: 2020.03 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | class JointsL2Loss(nn.Module): 10 | def __init__(self, has_ohkm=False, topk=8, thresh1=1, thresh2=0): 11 | super(JointsL2Loss, self).__init__() 12 | self.has_ohkm = has_ohkm 13 | self.topk = topk 14 | self.t1 = thresh1 15 | self.t2 = thresh2 16 | method = 'none' if self.has_ohkm else 'mean' 17 | self.calculate = nn.MSELoss(reduction=method) 18 | 19 | def forward(self, output, valid, label): 20 | assert output.shape == label.shape 21 | batch_size = output.size(0) 22 | keypoint_num = output.size(1) 23 | loss = 0 24 | 25 | for i in range(batch_size): 26 | pred = output[i].reshape(keypoint_num, -1) 27 | gt = label[i].reshape(keypoint_num, -1) 28 | 29 | if not self.has_ohkm: 30 | weight = torch.gt(valid[i], self.t1).float() 31 | gt = gt * weight 32 | 33 | tmp_loss = self.calculate(pred, gt) 34 | 35 | if self.has_ohkm: 36 | tmp_loss = tmp_loss.mean(dim=1) 37 | weight = torch.gt(valid[i].squeeze(), self.t2).float() 38 | tmp_loss = tmp_loss * weight 39 | topk_val, topk_id = torch.topk(tmp_loss, k=self.topk, dim=0, 40 | sorted=False) 41 | sample_loss = topk_val.mean(dim=0) 42 | else: 43 | sample_loss = tmp_loss 44 | 45 | loss = loss + sample_loss 46 | 47 | return loss / batch_size 48 | 49 | 50 | if __name__ == '__main__': 51 | a = torch.ones(1, 17, 12, 12) 52 | b = torch.ones(1, 17, 12, 12) 53 | c = torch.ones(1, 17, 1) * 2 54 | loss = JointsL2Loss() 55 | # loss = JointsL2Loss(has_ohkm=True) 56 | device = torch.device('cuda') 57 | a = a.to(device) 58 | b = b.to(device) 59 | c = c.to(device) 60 | loss = loss.to(device) 61 | res = loss(a, c, b) 62 | print(res) 63 | 64 | 65 | -------------------------------------------------------------------------------- /common/args.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | 11 | 12 | class Args(argparse.ArgumentParser): 13 | """ 14 | Defines global default arguments. 
15 | """ 16 | 17 | def __init__(self, **overrides): 18 | """ 19 | Args: 20 | **overrides (dict, optional): Keyword arguments used to override default argument values 21 | """ 22 | 23 | super().__init__(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | 25 | self.add_argument('--seed', default=42, type=int, help='Seed for random number generators') 26 | self.add_argument('--resolution', default=320, type=int, help='Resolution of images (brain \ 27 | challenge expects resolution of 384, knee resolution expcts resolution of 320') 28 | 29 | # Data parameters 30 | self.add_argument('--challenge', choices=['singlecoil', 'multicoil'], default='singlecoil', help='Which challenge') 31 | self.add_argument('--data-path', type=pathlib.Path, default='/home/jc3/Data/', help='Path to the dataset') 32 | self.add_argument('--sample-rate', type=float, default=1., 33 | help='Fraction of total volumes to include') 34 | 35 | # Mask parameters 36 | self.add_argument('--mask-type', choices=['random', 'equispaced'], default='random', 37 | help='The type of mask function to use') 38 | self.add_argument('--accelerations', nargs='+', default=[4], type=int, 39 | help='Ratio of k-space columns to be sampled. If multiple values are ' 40 | 'provided, then one of those is chosen uniformly at random for ' 41 | 'each volume.') 42 | self.add_argument('--center-fractions', nargs='+', default=[0.08], type=float, 43 | help='Fraction of low-frequency k-space columns to be sampled. Should ' 44 | 'have the same length as accelerations') 45 | 46 | # Override defaults with passed overrides 47 | self.set_defaults(**overrides) 48 | -------------------------------------------------------------------------------- /fastmri/models/utils/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yuanhao Cai 3 | @date: 2020.03 4 | """ 5 | 6 | import numpy as np 7 | import cv2 8 | 9 | 10 | def get_affine_transform(center, scale, rot, output_size): 11 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list): 12 | scale = np.array([scale, scale]) 13 | scale_tmp = scale * 200.0 14 | 15 | src_w = scale_tmp[0] 16 | dst_w = output_size[1] 17 | dst_h = output_size[0] 18 | 19 | rot_rad = np.pi * rot / 180 20 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 21 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 22 | 23 | src = np.zeros((3, 2), dtype=np.float32) 24 | dst = np.zeros((3, 2), dtype=np.float32) 25 | src[0, :] = center 26 | src[1, :] = center + src_dir 27 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 28 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 29 | 30 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 31 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 32 | 33 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 34 | 35 | return trans 36 | 37 | 38 | def affine_transform(pt, t): 39 | new_pt = np.array([pt[0], pt[1], 1.]) 40 | new_pt = np.dot(t, new_pt) 41 | return new_pt[:2] 42 | 43 | 44 | def get_3rd_point(a, b): 45 | direct = a - b 46 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 47 | 48 | 49 | def get_dir(src_point, rot_rad): 50 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 51 | 52 | src_result = [0, 0] 53 | src_result[0] = src_point[0] * cs - src_point[1] * sn 54 | src_result[1] = src_point[0] * sn + src_point[1] * cs 55 | 56 | return src_result 57 | 58 | 59 | def flip_back(output, pairs): 60 | output = output[:, :, :, ::-1] 61 | 62 | for pair in pairs: 63 | tmp = output[:, pair[0], :, :].copy() 64 | 
output[:, pair[0], :, :] = output[:, pair[1], :, :] 65 | output[:, pair[1], :, :] = tmp 66 | 67 | return output 68 | 69 | 70 | def flip_joints(joints, joints_vis, width, pairs): 71 | joints[:, 0] = width - joints[:, 0] - 1 72 | 73 | for pair in pairs: 74 | joints[pair[0], :], joints[pair[1], :] = \ 75 | joints[pair[1], :], joints[pair[0], :].copy() 76 | joints_vis[pair[0], :], joints_vis[pair[1], :] = \ 77 | joints_vis[pair[1], :], joints_vis[pair[0], :].copy() 78 | 79 | return joints, joints_vis 80 | -------------------------------------------------------------------------------- /fastmri/models/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2), bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 27 | bn=True, act=nn.ReLU(True)): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) -------------------------------------------------------------------------------- /experimental/MANet/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Chun-Mei Feng, Harbin Institute of Technology, Shenzhen / Inception Institute of Artifical Intelligence,UAE 4 | 5 | Our code bulid on fastMRI https://github.com/facebookresearch/fastMRI 6 | 7 | """ 8 | 9 | import pathlib 10 | import sys 11 | from argparse import ArgumentParser 12 | 13 | from pytorch_lightning import Trainer, seed_everything 14 | 15 | 16 | sys.path.append('/home/chunmeifeng/SANet/') 17 | 18 | from fastmri.data.mri_data import fetch_dir 19 | from module_SANet import SRModule 20 | 21 | 22 | def main(args): 23 | """Main training routine.""" 24 | # ------------------------ 25 | # 1 INIT LIGHTNING MODEL 26 | # ------------------------ 27 | seed_everything(args.seed) 28 | model = SRModule(**vars(args)) 29 | 30 | # ------------------------ 31 | # 2 INIT TRAINER 32 | # ------------------------ 33 | trainer = Trainer.from_argparse_args(args) 34 | 35 | # ------------------------ 36 | # 3 START TRAINING OR TEST 37 | # ------------------------ 38 | if args.mode == "train": 39 | trainer.fit(model) 40 | elif args.mode == "test": 41 | assert args.resume_from_checkpoint is not None 42 | trainer.test(model) 43 | else: 44 | raise ValueError(f"unrecognized mode {args.mode}") 45 | 46 | 47 | def build_args(): 48 | # ------------------------ 49 | # TRAINING ARGUMENTS 50 | # ------------------------ 51 | path_config = pathlib.Path.cwd() / "mri_manet_dirs.yaml" 52 | knee_path = fetch_dir("knee_path", path_config) 53 | logdir = fetch_dir("log_path", path_config) / "MANet" / "rec" 54 | 55 | 56 | parent_parser = ArgumentParser(add_help=False) 57 | 58 | parser = SRModule.add_model_specific_args(parent_parser) 59 | parser = Trainer.add_argparse_args(parser) 60 | 61 | num_gpus = 1 62 | backend = "ddp" 63 | batch_size = 4 if backend == "ddp" else num_gpus 64 | 65 | # module config 66 | config = dict( 67 | n_channels_in=1, 68 | n_channels_out=1, 69 | lr=0.001, 70 | lr_step_size=40, 71 | lr_gamma=0.1, 72 | weight_decay=0.0, 73 | data_path=data_path, 74 | exp_dir=logdir, 75 | exp_name="unet_demo", 76 | test_split="test", 77 | batch_size=batch_size, 78 | ixi_args=ixi_args, 79 | ) 80 | parser.set_defaults(**config) 81 | 82 | # trainer config 83 | parser.set_defaults( 84 | gpus=num_gpus, 85 | max_epochs=35, 86 | default_root_dir=logdir, 87 | replace_sampler_ddp=(backend != "ddp"), 88 | distributed_backend=backend, 89 | seed=42, 90 | deterministic=True, 91 | # resume_from_checkpoint = '/checkpoints/epoch=34.ckpt' 92 | ) 93 | 94 | 95 | parser.add_argument("--mode", default="train", type=str) 96 | args = parser.parse_args() 97 | 98 | return args 99 | 100 | 101 | def run_cli(): 102 | args = build_args() 103 | 104 | # --------------------- 105 | # RUN TRAINING 106 | # 
--------------------- 107 | main(args) 108 | 109 | 110 | if __name__ == "__main__": 111 | run_cli() 112 | -------------------------------------------------------------------------------- /fastmri/models/utils/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | 6 | import pickle 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def get_world_size(): 13 | if not dist.is_available(): 14 | return 1 15 | if not dist.is_initialized(): 16 | return 1 17 | return dist.get_world_size() 18 | 19 | 20 | def get_rank(): 21 | if not dist.is_available(): 22 | return 0 23 | if not dist.is_initialized(): 24 | return 0 25 | return dist.get_rank() 26 | 27 | 28 | def is_main_process(): 29 | return get_rank() == 0 30 | 31 | 32 | def synchronize(): 33 | """ 34 | Helper function to synchronize (barrier) among all processes when 35 | using distributed training 36 | """ 37 | if not dist.is_available(): 38 | return 39 | if not dist.is_initialized(): 40 | return 41 | world_size = dist.get_world_size() 42 | if world_size == 1: 43 | return 44 | dist.barrier() 45 | 46 | 47 | def all_gather(data): 48 | """ 49 | Run all_gather on arbitrary picklable data (not necessarily tensors) 50 | Args: 51 | data: any picklable object 52 | Returns: 53 | list[data]: list of data gathered from each rank 54 | """ 55 | world_size = get_world_size() 56 | if world_size == 1: 57 | return [data] 58 | 59 | # serialized to a Tensor 60 | buffer = pickle.dumps(data) 61 | storage = torch.ByteStorage.from_buffer(buffer) 62 | tensor = torch.ByteTensor(storage).to("cuda") 63 | 64 | # obtain Tensor size of each rank 65 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 66 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 67 | dist.all_gather(size_list, local_size) 68 | size_list = [int(size.item()) for size in size_list] 69 | max_size = max(size_list) 70 | 71 | # receiving Tensor from all ranks 72 | # we pad the tensor because torch all_gather does not support 73 | # gathering tensors of different shapes 74 | tensor_list = [] 75 | for _ in size_list: 76 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 77 | if local_size != max_size: 78 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 79 | tensor = torch.cat((tensor, padding), dim=0) 80 | dist.all_gather(tensor_list, tensor) 81 | 82 | data_list = [] 83 | for size, tensor in zip(size_list, tensor_list): 84 | buffer = tensor.cpu().numpy().tobytes()[:size] 85 | data_list.append(pickle.loads(buffer)) 86 | 87 | return data_list 88 | 89 | 90 | def reduce_dict(input_dict, average=True): 91 | """ 92 | Args: 93 | input_dict (dict): all the values will be reduced 94 | average (bool): whether to do average or sum 95 | Reduce the values in the dictionary from all processes so that process with rank 96 | 0 has the averaged results. Returns a dict with the same fields as 97 | input_dict, after reduction. 
98 | """ 99 | world_size = get_world_size() 100 | if world_size < 2: 101 | return input_dict 102 | with torch.no_grad(): 103 | names = [] 104 | values = [] 105 | # sort the keys so that they are consistent across processes 106 | for k in sorted(input_dict.keys()): 107 | names.append(k) 108 | values.append(input_dict[k]) 109 | values = torch.stack(values, dim=0) 110 | dist.reduce(values, dst=0) 111 | if dist.get_rank() == 0 and average: 112 | # only main process gets accumulated, so only divide by 113 | # world_size in this case 114 | values /= world_size 115 | reduced_dict = {k: v for k, v in zip(names, values)} 116 | return reduced_dict 117 | 118 | -------------------------------------------------------------------------------- /experimental/MANet/mri_manet_dirs.yaml: -------------------------------------------------------------------------------- 1 | 2 | #Data parameters 3 | dataset: "IXI" # Dataset type. currently only IXI supported. 4 | data_dir: "/home/jc3/Data/IXI/h5/" 5 | train_data_dir: "/home/jc3/Data/IXI/mat/train" # Training files dir, should contain hdf5 preprocessed data. 6 | val_data_dir: "/home/jc3/Data/IXI/mat/val" # Validation files dir, should contain hdf5 preprocessed data. 7 | output_dir: "./logs" # Directory to save checkpoints and tensorboard data. 8 | sampling_percentage: 30 # Sampling mask precentage (provided with the code 20%,30% and 50% sampling masks of 256X256). 9 | num_input_slices: 3 # Num of slices to use for input (3 means the predicted slice + previous slice + next slice). 10 | img_size: 256 # Input image size (256X256 for IXI). 11 | slice_range: [20,120] # Slices to use for training data. 12 | 13 | #Load checkpoint 14 | load_cp: 0 # 0 to start a new training or checkpoint path to load network weights. 15 | resume_training: 1 # 0 - Load only model weights , 1 - Load Weights + epoch number + optimizer and scheduler state. 16 | 17 | #Networks parameters 18 | bilinear: 1 # 1 - Use bilinear upsampling , 0 - Use up-conv. 19 | crop_center: 128 # Discriminator center crop size (128X128), to avoid classifying blank patches. 20 | 21 | #Training parameters 22 | lr: 0.001 # Learning rate 23 | epochs_n: 50 # Number of epochs 24 | batch_size: 64 # Batch size. Reduce if out of memory.Batch size of 32 256X256 images needs ~13GB memory. 25 | GAN_training: 1 # 1 - Use GAN training. 0 - No GAN (no discriminator training and adverserial loss) 26 | loss_weights: [1000, 1000, 5, 0.1 ] # Loss weighting [Imspac L2, Imspace L1, Kspace L2, GAN_Loss]. Losses are weighted to be roughly at the same scale. 27 | minmax_noise_val: [-0.01, 0.01] 28 | 29 | #Tensorboard 30 | tb_write_losses: 1 # Write losses and scalars to Tensorboard. 31 | tb_write_images: 1 # Write images to Tensorboard. 32 | 33 | #Runtime 34 | device: 'cuda' # For GPU training : 'cuda', for CPU training (not recomended!) 'cpu'. 35 | gpu_id: '4' # GPU ID to use. 36 | train_num_workers: 16 # Number of training dataset workers. Reduce if you are getting a shared memory error. 37 | val_num_workers: 4 # Number of validation dataset workers. Reduce if you are getting a shared memory error. 38 | 39 | #Predict parameters 40 | save_prediction: 1 # Save predicted images. 41 | save_path: "home/jc3/mycode/IXI_fastMRI/save_predictions/T2T1_Unet" # Path to save predictions 42 | visualize_images: 1 # Visualize predicted images. 43 | model: "model_path.pth" # Model checkpoint to use for prediction. 44 | predict_data_dir: "/home/jc3/Data/IXI_T2/h5/test" # Test set files dir, should contain hdf5 preprocessed data. 
45 | 46 | -------------------------------------------------------------------------------- /fastmri/data/volume_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import math 9 | 10 | import numpy as np 11 | import torch 12 | import torch.distributed as dist 13 | from torch.utils.data import Sampler 14 | 15 | 16 | class VolumeSampler(Sampler): 17 | """ 18 | Sampler for volumetric MRI data. 19 | 20 | Based on pytorch DistributedSampler, the difference is that all instances 21 | from the same MRI volume need to go to the same node for distributed 22 | training. Dataset example is a list of tuples (fname, instance), where 23 | fname is essentially the volume name (actually a filename). 24 | """ 25 | 26 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): 27 | """ 28 | Args: 29 | dataset (torch.utils.data.Dataset): An MRI dataset (e.g., SliceData). 30 | num_replicas (int, optional): Number of processes participating in 31 | distributed training. By default, :attr:`rank` is retrieved 32 | from the current distributed group. 33 | rank (int, optional): Rank of the current process within 34 | :attr:`num_replicas`. By default, :attr:`rank` is retrieved 35 | from the current distributed group. 36 | shuffle (bool, optional): If ``True`` (default), sampler will 37 | shuffle the indices. 38 | seed (int, optional): random seed used to shuffle the sampler if 39 | :attr:`shuffle=True`. This number should be identical across 40 | all processes in the distributed group. Default: ``0``. 41 | """ 42 | if num_replicas is None: 43 | if not dist.is_available(): 44 | raise RuntimeError("Requires distributed package to be available") 45 | num_replicas = dist.get_world_size() 46 | if rank is None: 47 | if not dist.is_available(): 48 | raise RuntimeError("Requires distributed package to be available") 49 | rank = dist.get_rank() 50 | self.dataset = dataset 51 | self.num_replicas = num_replicas 52 | self.rank = rank 53 | self.epoch = 0 54 | self.shuffle = shuffle 55 | self.seed = seed 56 | 57 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 58 | self.total_size = self.num_samples * self.num_replicas 59 | 60 | # get all file names and split them based on number of processes 61 | self.all_volume_names = np.array( 62 | sorted(list({example[0] for example in self.dataset.examples})) 63 | ) 64 | self.all_volumes_split = np.array_split( 65 | self.all_volume_names, self.num_replicas 66 | ) 67 | 68 | # get slice indices for each file name 69 | indices = [list() for _ in range(self.num_replicas)] 70 | 71 | for i, example in enumerate(self.dataset.examples): 72 | vname = example[0] 73 | for rank in range(self.num_replicas): 74 | if vname in self.all_volumes_split[rank]: 75 | indices[rank].append(i) 76 | 77 | # need to send equal number of samples to each process - take the max 78 | self.num_samples = max([len(l) for l in indices]) 79 | self.total_size = self.num_samples * self.num_replicas 80 | self.indices = indices[self.rank] 81 | 82 | def __iter__(self): 83 | if self.shuffle: 84 | # deterministically shuffle based on epoch and seed 85 | g = torch.Generator() 86 | g.manual_seed(self.seed + self.epoch) 87 | ordering = torch.randperm(len(self.indices), generator=g).tolist() 88 | indices = 
list(np.array(self.indices)[ordering]) 89 | else: 90 | indices = self.indices 91 | 92 | # add extra samples to match num_samples 93 | indices = indices + indices[: self.num_samples - len(indices)] 94 | assert len(indices) == self.num_samples 95 | 96 | return iter(indices) 97 | 98 | def __len__(self): 99 | return self.num_samples 100 | 101 | def set_epoch(self, epoch): 102 | self.epoch = epoch 103 | -------------------------------------------------------------------------------- /common/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | from argparse import ArgumentParser 11 | 12 | import h5py 13 | import numpy as np 14 | from runstats import Statistics 15 | from skimage.metrics import structural_similarity, peak_signal_noise_ratio 16 | from data import transforms 17 | 18 | 19 | def mse(gt, pred): 20 | """ Compute Mean Squared Error (MSE) """ 21 | return np.mean((gt - pred) ** 2) 22 | 23 | 24 | def nmse(gt, pred): 25 | """ Compute Normalized Mean Squared Error (NMSE) """ 26 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 27 | 28 | 29 | def psnr(gt, pred): 30 | """ Compute Peak Signal to Noise Ratio metric (PSNR) """ 31 | return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) 32 | 33 | 34 | def ssim(gt, pred): 35 | """ Compute Structural Similarity Index Metric (SSIM). """ 36 | return structural_similarity( 37 | gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max() 38 | ) 39 | 40 | 41 | METRIC_FUNCS = dict( 42 | MSE=mse, 43 | NMSE=nmse, 44 | PSNR=psnr, 45 | SSIM=ssim, 46 | ) 47 | 48 | 49 | class Metrics: 50 | """ 51 | Maintains running statistics for a given collection of metrics. 
52 | """ 53 | 54 | def __init__(self, metric_funcs): 55 | self.metrics = { 56 | metric: Statistics() for metric in metric_funcs 57 | } 58 | 59 | def push(self, target, recons): 60 | for metric, func in METRIC_FUNCS.items(): 61 | self.metrics[metric].push(func(target, recons)) 62 | 63 | def means(self): 64 | return { 65 | metric: stat.mean() for metric, stat in self.metrics.items() 66 | } 67 | 68 | def stddevs(self): 69 | return { 70 | metric: stat.stddev() for metric, stat in self.metrics.items() 71 | } 72 | 73 | def __repr__(self): 74 | means = self.means() 75 | stddevs = self.stddevs() 76 | metric_names = sorted(list(means)) 77 | return ' '.join( 78 | f'{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}' for name in metric_names 79 | ) 80 | 81 | 82 | def evaluate(args, recons_key): 83 | metrics = Metrics(METRIC_FUNCS) 84 | 85 | for tgt_file in args.target_path.iterdir(): 86 | with h5py.File(tgt_file, 'r') as target, h5py.File( 87 | args.predictions_path / tgt_file.name, 'r') as recons: 88 | if args.acquisition and args.acquisition != target.attrs['acquisition']: 89 | continue 90 | 91 | if args.acceleration and target.attrs['acceleration'] != args.acceleration: 92 | continue 93 | 94 | target = target[recons_key][()] 95 | recons = recons['reconstruction'][()] 96 | target = transforms.center_crop(target, (target.shape[-1], target.shape[-1])) 97 | recons = transforms.center_crop(recons, (target.shape[-1], target.shape[-1])) 98 | metrics.push(target, recons) 99 | return metrics 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 104 | parser.add_argument('--target-path', type=pathlib.Path, required=True, 105 | help='Path to the ground truth data') 106 | parser.add_argument('--predictions-path', type=pathlib.Path, required=True, 107 | help='Path to reconstructions') 108 | parser.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True, 109 | help='Which challenge') 110 | parser.add_argument('--acceleration', type=int, default=None) 111 | parser.add_argument('--acquisition', choices=['CORPD_FBK', 'CORPDFS_FBK', 'AXT1', 'AXT1PRE', 112 | 'AXT1POST', 'AXT2', 'AXFLAIR'], default=None, 113 | help='If set, only volumes of the specified acquisition type are used ' 114 | 'for evaluation. By default, all volumes are included.') 115 | args = parser.parse_args() 116 | 117 | recons_key = 'reconstruction_rss' if args.challenge == 'multicoil' else 'reconstruction_esc' 118 | metrics = evaluate(args, recons_key) 119 | print(metrics) 120 | -------------------------------------------------------------------------------- /fastmri/data/README.md: -------------------------------------------------------------------------------- 1 | # MRI Data Loader and Transforms 2 | 3 | This directory provides a reference data loader to read the fastMRI data one slice at a time and some useful data transforms to work with the data in PyTorch. 4 | 5 | Each partition (train, validation or test) of the fastMRI data is distributed as a set of HDF5 files, such that each HDF5 file contains data from one MR acquisition. The set of fields and attributes in these HDF5 files depends on the track (single-coil or multi-coil) and the data partition. 6 | 7 | ## 2020 fastMRI Challenge (Brain Data) 8 | 9 | For the 2020 fastMRI challenge, all data will be on the brain data with a multi-coil track. 10 | 11 | ### Multi-Coil Track (2020) 12 | 13 | * Training & Validation data: 14 | * `kspace`: Multi-coil k-space data. 
The shape of the kspace tensor is (number of slices, number of coils, height, width). 15 | * `reconstruction_rss`: Root-sum-of-squares reconstruction of the multi-coil k-space data cropped to the center. 16 | * Test data: 17 | * `kspace`: Undersampled multi-coil k-space. The shape of the kspace tensor is (number of slices, number of coils, height, width). 18 | * `mask` Defines the undersampled Cartesian k-space trajectory. The number of elements in the mask tensor is the same as the width of k-space. 19 | 20 | ## 2019 fastMRI Challenge (Knee Data) 21 | 22 | ### Single-Coil Track 23 | 24 | * Training & Validation data: 25 | * `kspace`: Emulated single-coil k-space data. The shape of the kspace tensor is (number of slices, height, width). 26 | * `reconstruction_rss`: Root-sum-of-squares reconstruction of the multi-coil k-space that was 27 | used to derive the emulated single-coil k-space cropped to the center 320 × 320 region. 28 | The shape of the reconstruction rss tensor is (number of slices, 320, 320). 29 | * `reconstruction_esc`: The inverse Fourier transform of the single-coil k-space data cropped to the center 320 × 320 region. The shape of the reconstruction esc tensor is (number of slices, 320, 320). 30 | * Test data: 31 | * `kspace`: Undersampled emulated single-coil k-space. The shape of the kspace tensor is (number of slices, height, width). 32 | * `mask`: Defines the undersampled Cartesian k-space trajectory. The number of elements in the mask tensor is the same as the width of k-space. 33 | 34 | ### Multi-Coil Track (2019) 35 | 36 | * Training & Validation data: 37 | * `kspace`: Multi-coil k-space data. The shape of the kspace tensor is (number of slices, number of coils, height, width). 38 | * `reconstruction_rss`: Root-sum-of-squares reconstruction of the multi-coil k-space data cropped to the center 320 × 320 region. The shape of the reconstruction rss tensor is (number of slices, 320, 320). 39 | * Test data: 40 | * `kspace`: Undersampled multi-coil k-space. The shape of the kspace tensor is (number of slices, number of coils, height, width). 41 | * `mask` Defines the undersampled Cartesian k-space trajectory. The number of elements in the mask tensor is the same as the width of k-space. 42 | 43 | ## Data Transforms 44 | 45 | `data.transforms` provides a number of useful data transformation functions that work with PyTorch tensors. 46 | 47 | ## Data Loader 48 | 49 | `fastmri.data.mri_data` provides a `SliceDataset` class to read one MR slice at a time. It takes as input 50 | a `transform` function or callable object that can be used transform the data into the format that 51 | you need. This makes the data loader versatile and can be used to run different kinds of 52 | reconstruction methods. 53 | 54 | The following is a simple example for how to use the data loader. For more concrete examples, 55 | please look at the baseline model code in the `models` directory. 
56 | 57 | ```python 58 | import pathlib 59 | from fastmri.data import subsample 60 | from fastmri.data import transforms, mri_data 61 | 62 | # Create a mask function 63 | mask_func = subsample.RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8]) 64 | 65 | def data_transform(kspace, mask, target, data_attributes, filename, slice_num): 66 | # Transform the data into appropriate format 67 | # Here we simply mask the k-space and return the result 68 | kspace = transforms.to_tensor(kspace) 69 | masked_kspace, _ = transforms.apply_mask(kspace, mask_func) 70 | return masked_kspace 71 | 72 | dataset = mri_data.SliceDataset( 73 | root=pathlib.Path('/private/home/mmuckley/data/fastmri_knee/singlecoil_train'), 74 | transform=data_transform, 75 | challenge='singlecoil' 76 | ) 77 | 78 | for masked_kspace in dataset: 79 | # Do reconstruction 80 | pass 81 | ``` 82 | -------------------------------------------------------------------------------- /fastmri/models/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yuanhao Cai 3 | @date: 2020.03 4 | """ 5 | 6 | import math 7 | 8 | import torch 9 | import torchvision.transforms as transforms 10 | 11 | from cvpack.dataset import torch_samplers 12 | 13 | from dataset.attribute import load_dataset 14 | from dataset.COCO.coco import COCODataset 15 | from dataset.MPII.mpii import MPIIDataset 16 | 17 | 18 | def get_train_loader( 19 | cfg, num_gpu, is_dist=True, is_shuffle=True, start_iter=0): 20 | # -------- get raw dataset interface -------- # 21 | normalize = transforms.Normalize(mean=cfg.INPUT.MEANS, std=cfg.INPUT.STDS) 22 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 23 | attr = load_dataset(cfg.DATASET.NAME) 24 | if cfg.DATASET.NAME == 'COCO': 25 | Dataset = COCODataset 26 | elif cfg.DATASET.NAME == 'MPII': 27 | Dataset = MPIIDataset 28 | dataset = Dataset(attr, 'train', transform) 29 | 30 | # -------- make samplers -------- # 31 | if is_dist: 32 | sampler = torch_samplers.DistributedSampler( 33 | dataset, shuffle=is_shuffle) 34 | elif is_shuffle: 35 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 36 | else: 37 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 38 | 39 | images_per_gpu = cfg.SOLVER.IMS_PER_GPU 40 | # images_per_gpu = cfg.SOLVER.IMS_PER_BATCH // num_gpu 41 | 42 | aspect_grouping = [1] if cfg.DATALOADER.ASPECT_RATIO_GROUPING else [] 43 | if aspect_grouping: 44 | batch_sampler = torch_samplers.GroupedBatchSampler( 45 | sampler, dataset, aspect_grouping, images_per_gpu, 46 | drop_uneven=False) 47 | else: 48 | batch_sampler = torch.utils.data.sampler.BatchSampler( 49 | sampler, images_per_gpu, drop_last=True)#False 50 | 51 | batch_sampler = torch_samplers.IterationBasedBatchSampler( 52 | batch_sampler, cfg.SOLVER.MAX_ITER, start_iter) 53 | 54 | # -------- make data_loader -------- # 55 | class BatchCollator(object): 56 | def __init__(self, size_divisible): 57 | self.size_divisible = size_divisible 58 | 59 | def __call__(self, batch): 60 | transposed_batch = list(zip(*batch)) 61 | images = torch.stack(transposed_batch[0], dim=0) 62 | valids = torch.stack(transposed_batch[1], dim=0) 63 | labels = torch.stack(transposed_batch[2], dim=0) 64 | 65 | return images, valids, labels 66 | 67 | data_loader = torch.utils.data.DataLoader( 68 | dataset, num_workers=cfg.DATALOADER.NUM_WORKERS, 69 | batch_sampler=batch_sampler, 70 | collate_fn=BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY), ) 71 | 72 | return data_loader 73 | 74 | 
75 | def get_test_loader(cfg, num_gpu, local_rank, stage, is_dist=True): 76 | # -------- get raw dataset interface -------- # 77 | normalize = transforms.Normalize(mean=cfg.INPUT.MEANS, std=cfg.INPUT.STDS) 78 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 79 | attr = load_dataset(cfg.DATASET.NAME) 80 | if cfg.DATASET.NAME == 'COCO': 81 | Dataset = COCODataset 82 | elif cfg.DATASET.NAME == 'MPII': 83 | Dataset = MPIIDataset 84 | dataset = Dataset(attr, stage, transform) 85 | 86 | # -------- split dataset to gpus -------- # 87 | num_data = dataset.__len__() 88 | num_data_per_gpu = math.ceil(num_data / num_gpu) 89 | st = local_rank * num_data_per_gpu 90 | ed = min(num_data, st + num_data_per_gpu) 91 | indices = range(st, ed) 92 | subset= torch.utils.data.Subset(dataset, indices) 93 | 94 | # -------- make samplers -------- # 95 | sampler = torch.utils.data.sampler.SequentialSampler(subset) 96 | 97 | images_per_gpu = cfg.TEST.IMS_PER_GPU 98 | 99 | batch_sampler = torch.utils.data.sampler.BatchSampler( 100 | sampler, images_per_gpu, drop_last=True)#False 101 | 102 | # -------- make data_loader -------- # 103 | class BatchCollator(object): 104 | def __init__(self, size_divisible): 105 | self.size_divisible = size_divisible 106 | 107 | def __call__(self, batch): 108 | transposed_batch = list(zip(*batch)) 109 | images = torch.stack(transposed_batch[0], dim=0) 110 | scores = list(transposed_batch[1]) 111 | centers = list(transposed_batch[2]) 112 | scales = list(transposed_batch[3]) 113 | image_ids = list(transposed_batch[4]) 114 | 115 | return images, scores, centers, scales, image_ids 116 | 117 | data_loader = torch.utils.data.DataLoader( 118 | subset, num_workers=cfg.DATALOADER.NUM_WORKERS, 119 | batch_sampler=batch_sampler, 120 | collate_fn=BatchCollator(cfg.DATALOADER.SIZE_DIVISIBILITY), ) 121 | data_loader.ori_dataset = dataset 122 | 123 | return data_loader 124 | -------------------------------------------------------------------------------- /fastmri/data/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def to_tensor(data): 13 | """ 14 | Convert numpy array to PyTorch tensor. 15 | 16 | For complex arrays, the real and imaginary parts are stacked along the last 17 | dimension. 18 | 19 | Args: 20 | data (np.array): Input numpy array. 21 | 22 | Returns: 23 | torch.Tensor: PyTorch version of data. 24 | """ 25 | if np.iscomplexobj(data): 26 | data = np.stack((data.real, data.imag), axis=-1) 27 | 28 | return torch.from_numpy(data) 29 | 30 | 31 | def tensor_to_complex_np(data): 32 | """ 33 | Converts a complex torch tensor to numpy array. 34 | 35 | Args: 36 | data (torch.Tensor): Input data to be converted to numpy. 37 | 38 | Returns: 39 | np.array: Complex numpy version of data. 40 | """ 41 | data = data.numpy() 42 | 43 | return data[..., 0] + 1j * data[..., 1] 44 | 45 | 46 | def apply_mask(data, mask_func, seed=None, padding=None): 47 | """ 48 | Subsample given k-space by multiplying with a mask. 49 | 50 | Args: 51 | data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where 52 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size 53 | 2 (for complex values). 
54 | mask_func (callable): A function that takes a shape (tuple of ints) and a random 55 | number seed and returns a mask. 56 | seed (int or 1-d array_like, optional): Seed for the random number generator. 57 | 58 | Returns: 59 | (tuple): tuple containing: 60 | masked data (torch.Tensor): Subsampled k-space data 61 | mask (torch.Tensor): The generated mask 62 | """ 63 | shape = np.array(data.shape) 64 | shape[:-3] = 1 65 | mask = mask_func(shape, seed) 66 | if padding is not None: 67 | mask[:, :, : padding[0]] = 0 68 | mask[:, :, padding[1] :] = 0 # padding value inclusive on right of zeros 69 | 70 | masked_data = data * mask + 0.0 # the + 0.0 removes the sign of the zeros 71 | 72 | return masked_data, mask 73 | 74 | 75 | def mask_center(x, mask_from, mask_to): 76 | mask = torch.zeros_like(x) 77 | mask[:, :, :, mask_from:mask_to] = x[:, :, :, mask_from:mask_to] 78 | 79 | return mask 80 | 81 | 82 | def center_crop(data, shape): 83 | """ 84 | Apply a center crop to the input real image or batch of real images. 85 | 86 | Args: 87 | data (torch.Tensor): The input tensor to be center cropped. It should 88 | have at least 2 dimensions and the cropping is applied along the 89 | last two dimensions. 90 | shape (int, int): The output shape. The shape should be smaller than 91 | the corresponding dimensions of data. 92 | 93 | Returns: 94 | torch.Tensor: The center cropped image. 95 | """ 96 | assert 0 < shape[0] <= data.shape[-2] 97 | assert 0 < shape[1] <= data.shape[-1] 98 | 99 | w_from = (data.shape[-2] - shape[0]) // 2 100 | h_from = (data.shape[-1] - shape[1]) // 2 101 | w_to = w_from + shape[0] 102 | h_to = h_from + shape[1] 103 | 104 | return data[..., w_from:w_to, h_from:h_to] 105 | 106 | 107 | def complex_center_crop(data, shape): 108 | """ 109 | Apply a center crop to the input image or batch of complex images. 110 | 111 | Args: 112 | data (torch.Tensor): The complex input tensor to be center cropped. It 113 | should have at least 3 dimensions and the cropping is applied along 114 | dimensions -3 and -2 and the last dimensions should have a size of 115 | 2. 116 | shape (int): The output shape. The shape should be smaller than 117 | the corresponding dimensions of data. 118 | 119 | Returns: 120 | torch.Tensor: The center cropped image 121 | """ 122 | assert 0 < shape[0] <= data.shape[-3] 123 | assert 0 < shape[1] <= data.shape[-2] 124 | 125 | w_from = (data.shape[-3] - shape[0]) // 2 #80 126 | h_from = (data.shape[-2] - shape[1]) // 2 #80 127 | w_to = w_from + shape[0] #240 128 | h_to = h_from + shape[1] #240 129 | 130 | return data[..., w_from:w_to, h_from:h_to, :] 131 | 132 | 133 | def center_crop_to_smallest(x, y): 134 | """ 135 | Apply a center crop on the larger image to the size of the smaller. 136 | 137 | The minimum is taken over dim=-1 and dim=-2. If x is smaller than y at 138 | dim=-1 and y is smaller than x at dim=-2, then the returned dimension will 139 | be a mixture of the two. 140 | 141 | Args: 142 | x (torch.Tensor): The first image. 143 | y (torch.Tensor): The second image 144 | 145 | Returns: 146 | tuple: tuple of tensors x and y, each cropped to the minimim size. 147 | """ 148 | smallest_width = min(x.shape[-1], y.shape[-1]) 149 | smallest_height = min(x.shape[-2], y.shape[-2]) 150 | x = center_crop(x, (smallest_height, smallest_width)) 151 | y = center_crop(y, (smallest_height, smallest_width)) 152 | 153 | return x, y 154 | 155 | 156 | def normalize(data, mean, stddev, eps=0.0): 157 | """ 158 | Normalize the given tensor. 
159 | 160 | Applies the formula (data - mean) / (stddev + eps). 161 | 162 | Args: 163 | data (torch.Tensor): Input data to be normalized. 164 | mean (float): Mean value. 165 | stddev (float): Standard deviation. 166 | eps (float, default=0.0): Added to stddev to prevent dividing by zero. 167 | 168 | Returns: 169 | torch.Tensor: Normalized tensor 170 | """ 171 | return (data - mean) / (stddev + eps) 172 | 173 | 174 | def normalize_instance(data, eps=0.0): 175 | """ 176 | Normalize the given tensor with instance norm/ 177 | 178 | Applies the formula (data - mean) / (stddev + eps), where mean and stddev 179 | are computed from the data itself. 180 | 181 | Args: 182 | data (torch.Tensor): Input data to be normalized 183 | eps (float): Added to stddev to prevent dividing by zero 184 | 185 | Returns: 186 | torch.Tensor: Normalized tensor 187 | """ 188 | mean = data.mean() 189 | std = data.std() 190 | 191 | return normalize(data, mean, std, eps), mean, std 192 | -------------------------------------------------------------------------------- /fastmri/math.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | 11 | def complex_mul(x, y): 12 | """ 13 | Complex multiplication. 14 | 15 | This multiplies two complex tensors assuming that they are both stored as 16 | real arrays with the last dimension being the complex dimension. 17 | 18 | Args: 19 | x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. 20 | y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. 21 | 22 | Returns: 23 | torch.Tensor: A PyTorch tensor with the last dimension of size 2. 24 | """ 25 | assert x.shape[-1] == y.shape[-1] == 2 26 | re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1] 27 | im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0] 28 | 29 | return torch.stack((re, im), dim=-1) 30 | 31 | 32 | def complex_conj(x): 33 | """ 34 | Complex conjugate. 35 | 36 | This applies the complex conjugate assuming that the input array has the 37 | last dimension as the complex dimension. 38 | 39 | Args: 40 | x (torch.Tensor): A PyTorch tensor with the last dimension of size 2. 41 | y (torch.Tensor): A PyTorch tensor with the last dimension of size 2. 42 | 43 | Returns: 44 | torch.Tensor: A PyTorch tensor with the last dimension of size 2. 45 | """ 46 | assert x.shape[-1] == 2 47 | 48 | return torch.stack((x[..., 0], -x[..., 1]), dim=-1) 49 | 50 | 51 | def fft2c(data): 52 | """ 53 | Apply centered 2 dimensional Fast Fourier Transform. 54 | 55 | Args: 56 | data (torch.Tensor): Complex valued input data containing at least 3 57 | dimensions: dimensions -3 & -2 are spatial dimensions and dimension 58 | -1 has size 2. All other dimensions are assumed to be batch 59 | dimensions. 60 | 61 | Returns: 62 | torch.Tensor: The FFT of the input. 63 | """ 64 | assert data.size(-1) == 2 65 | data = ifftshift(data, dim=(-3, -2)) 66 | data = torch.fft(data, 2, normalized=True) 67 | data = fftshift(data, dim=(-3, -2)) 68 | 69 | return data 70 | 71 | 72 | def ifft2c(data): 73 | """ 74 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 75 | 76 | Args: 77 | data (torch.Tensor): Complex valued input data containing at least 3 78 | dimensions: dimensions -3 & -2 are spatial dimensions and dimension 79 | -1 has size 2. 
All other dimensions are assumed to be batch 80 | dimensions. 81 | 82 | Returns: 83 | torch.Tensor: The IFFT of the input. 84 | """ 85 | assert data.size(-1) == 2 86 | data = ifftshift(data, dim=(-3, -2)) 87 | data = torch.ifft(data, 2, normalized=True) 88 | data = fftshift(data, dim=(-3, -2)) 89 | 90 | return data 91 | 92 | 93 | def complex_abs(data): 94 | """ 95 | Compute the absolute value of a complex valued input tensor. 96 | 97 | Args: 98 | data (torch.Tensor): A complex valued tensor, where the size of the 99 | final dimension should be 2. 100 | 101 | Returns: 102 | torch.Tensor: Absolute value of data. 103 | """ 104 | assert data.size(-1) == 2 105 | 106 | return (data ** 2).sum(dim=-1).sqrt() 107 | 108 | 109 | 110 | def complex_abs_numpy(data): 111 | assert data.shape[-1] == 2 112 | 113 | return np.sqrt(np.sum(data ** 2, axis=-1)) 114 | 115 | 116 | def complex_abs_sq(data):#multi coil 117 | """ 118 | Compute the squared absolute value of a complex tensor. 119 | 120 | Args: 121 | data (torch.Tensor): A complex valued tensor, where the size of the 122 | final dimension should be 2. 123 | 124 | Returns: 125 | torch.Tensor: Squared absolute value of data. 126 | """ 127 | assert data.size(-1) == 2 128 | return (data ** 2).sum(dim=-1) 129 | 130 | 131 | # Helper functions 132 | 133 | 134 | def roll(x, shift, dim): 135 | """ 136 | Similar to np.roll but applies to PyTorch Tensors. 137 | 138 | Args: 139 | x (torch.Tensor): A PyTorch tensor. 140 | shift (int): Amount to roll. 141 | dim (int): Which dimension to roll. 142 | 143 | Returns: 144 | torch.Tensor: Rolled version of x. 145 | """ 146 | if isinstance(shift, (tuple, list)): 147 | assert len(shift) == len(dim) 148 | for s, d in zip(shift, dim): 149 | x = roll(x, s, d) 150 | return x 151 | shift = shift % x.size(dim) 152 | if shift == 0: 153 | return x 154 | left = x.narrow(dim, 0, x.size(dim) - shift) 155 | right = x.narrow(dim, x.size(dim) - shift, shift) 156 | return torch.cat((right, left), dim=dim) 157 | 158 | 159 | def fftshift(x, dim=None): 160 | """ 161 | Similar to np.fft.fftshift but applies to PyTorch Tensors 162 | 163 | Args: 164 | x (torch.Tensor): A PyTorch tensor. 165 | dim (int): Which dimension to fftshift. 166 | 167 | Returns: 168 | torch.Tensor: fftshifted version of x. 169 | """ 170 | if dim is None: 171 | dim = tuple(range(x.dim())) 172 | shift = [dim // 2 for dim in x.shape] 173 | elif isinstance(dim, int): 174 | shift = x.shape[dim] // 2 175 | else: 176 | shift = [x.shape[i] // 2 for i in dim] 177 | 178 | return roll(x, shift, dim) 179 | 180 | 181 | def ifftshift(x, dim=None): 182 | """ 183 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 184 | 185 | Args: 186 | x (torch.Tensor): A PyTorch tensor. 187 | dim (int): Which dimension to ifftshift. 188 | 189 | Returns: 190 | torch.Tensor: ifftshifted version of x. 191 | """ 192 | if dim is None: 193 | dim = tuple(range(x.dim())) 194 | shift = [(dim + 1) // 2 for dim in x.shape] 195 | elif isinstance(dim, int): 196 | shift = (x.shape[dim] + 1) // 2 197 | else: 198 | shift = [(x.shape[i] + 1) // 2 for i in dim] 199 | 200 | return roll(x, shift, dim) 201 | 202 | 203 | def tensor_to_complex_np(data): 204 | """ 205 | Converts a complex torch tensor to numpy array. 206 | Args: 207 | data (torch.Tensor): Input data to be converted to numpy. 
208 | 209 | Returns: 210 | np.array: Complex numpy version of data 211 | """ 212 | data = data.numpy() 213 | return data[..., 0] + 1j * data[..., 1] 214 | -------------------------------------------------------------------------------- /fastmri/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | from argparse import ArgumentParser 11 | 12 | import h5py 13 | import numpy as np 14 | import pytorch_lightning 15 | from pytorch_lightning.metrics.metric import NumpyMetric, TensorMetric 16 | from runstats import Statistics 17 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 18 | from torch.distributed import ReduceOp 19 | 20 | from fastmri.data import transforms 21 | 22 | 23 | class MSE(NumpyMetric): 24 | """Calculates MSE and aggregates by summing across distr processes.""" 25 | 26 | def __init__(self, name="MSE", *args, **kwargs): 27 | super().__init__(name=name, *args, **kwargs) 28 | 29 | def forward(self, gt, pred): 30 | return mse(gt, pred) 31 | 32 | 33 | class NMSE(NumpyMetric): 34 | """Calculates NMSE and aggregates by summing across distr processes.""" 35 | 36 | def __init__(self, name="NMSE", *args, **kwargs): 37 | super().__init__(name=name, *args, **kwargs) 38 | 39 | def forward(self, gt, pred): 40 | return nmse(gt, pred) 41 | 42 | 43 | class PSNR(NumpyMetric): 44 | """Calculates PSNR and aggregates by summing across distr processes.""" 45 | 46 | def __init__(self, name="PSNR", *args, **kwargs): 47 | super().__init__(name=name, *args, **kwargs) 48 | 49 | def forward(self, gt, pred): 50 | return psnr(gt, pred) 51 | 52 | 53 | class SSIM(NumpyMetric): 54 | """Calculates SSIM and aggregates by summing across distr processes.""" 55 | 56 | def __init__(self, name="SSIM", *args, **kwargs): 57 | super().__init__(name=name, *args, **kwargs) 58 | 59 | def forward(self, gt, pred, maxval=None): 60 | return ssim(gt, pred, maxval=maxval) 61 | 62 | 63 | class DistributedMetricSum(TensorMetric): 64 | """Used for summing parameters across distr processes.""" 65 | 66 | def __init__(self, name="DistributedMetricSum", *args, **kwargs): 67 | super().__init__(name=name, *args, **kwargs) 68 | 69 | def forward(self, x): 70 | return x.clone() 71 | 72 | 73 | def mse(gt, pred): 74 | """Compute Mean Squared Error (MSE)""" 75 | return np.mean((gt - pred) ** 2) 76 | 77 | 78 | def nmse(gt, pred): 79 | """Compute Normalized Mean Squared Error (NMSE)""" 80 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 81 | 82 | 83 | def psnr(gt, pred): 84 | """Compute Peak Signal to Noise Ratio metric (PSNR)""" 85 | return peak_signal_noise_ratio(gt, pred, data_range=gt.max()) 86 | 87 | 88 | def ssim(gt, pred, maxval=None): 89 | """Compute Structural Similarity Index Metric (SSIM)""" 90 | maxval = gt.max() if maxval is None else maxval 91 | 92 | ssim = 0 93 | for slice_num in range(gt.shape[0]): 94 | ssim = ssim + structural_similarity( 95 | gt[slice_num], pred[slice_num], data_range=maxval 96 | ) 97 | 98 | ssim = ssim / gt.shape[0] 99 | 100 | return ssim 101 | 102 | 103 | METRIC_FUNCS = dict(MSE=mse, NMSE=nmse, PSNR=psnr, SSIM=ssim,) 104 | 105 | 106 | class Metrics(object): 107 | """ 108 | Maintains running statistics for a given collection of metrics. 
109 | """ 110 | 111 | def __init__(self, metric_funcs): 112 | """ 113 | Args: 114 | metric_funcs (dict): A dict where the keys are metric names and the 115 | values are Python functions for evaluating that metric. 116 | """ 117 | self.metrics = {metric: Statistics() for metric in metric_funcs} 118 | 119 | def push(self, target, recons): 120 | for metric, func in METRIC_FUNCS.items(): 121 | self.metrics[metric].push(func(target, recons)) 122 | 123 | def means(self): 124 | return {metric: stat.mean() for metric, stat in self.metrics.items()} 125 | 126 | def stddevs(self): 127 | return {metric: stat.stddev() for metric, stat in self.metrics.items()} 128 | 129 | def __repr__(self): 130 | means = self.means() 131 | stddevs = self.stddevs() 132 | metric_names = sorted(list(means)) 133 | return " ".join( 134 | f"{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}" 135 | for name in metric_names 136 | ) 137 | 138 | 139 | def evaluate(args, recons_key): 140 | metrics = Metrics(METRIC_FUNCS) 141 | 142 | for tgt_file in args.target_path.iterdir(): 143 | with h5py.File(tgt_file, "r") as target, h5py.File( 144 | args.predictions_path / tgt_file.name, "r" 145 | ) as recons: 146 | if args.acquisition and args.acquisition != target.attrs["acquisition"]: 147 | continue 148 | 149 | if args.acceleration and target.attrs["acceleration"] != args.acceleration: 150 | continue 151 | 152 | target = target[recons_key][()] 153 | recons = recons["reconstruction"][()] 154 | target = transforms.center_crop( 155 | target, (target.shape[-1], target.shape[-1]) 156 | ) 157 | recons = transforms.center_crop( 158 | recons, (target.shape[-1], target.shape[-1]) 159 | ) 160 | metrics.push(target, recons) 161 | 162 | return metrics 163 | 164 | 165 | if __name__ == "__main__": 166 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 167 | parser.add_argument( 168 | "--target-path", 169 | type=pathlib.Path, 170 | required=True, 171 | help="Path to the ground truth data", 172 | ) 173 | parser.add_argument( 174 | "--predictions-path", 175 | type=pathlib.Path, 176 | required=True, 177 | help="Path to reconstructions", 178 | ) 179 | parser.add_argument( 180 | "--challenge", 181 | choices=["singlecoil", "multicoil"], 182 | required=True, 183 | help="Which challenge", 184 | ) 185 | parser.add_argument("--acceleration", type=int, default=None) 186 | print ('corpd') 187 | parser.add_argument( 188 | "--acquisition", 189 | choices=[ 190 | "CORPD_FBK", 191 | "CORPDFS_FBK", 192 | "AXT1", 193 | "AXT1PRE", 194 | "AXT1POST", 195 | "AXT2", 196 | "AXFLAIR", 197 | ], 198 | default=CORPD_FBK, 199 | help="If set, only volumes of the specified acquisition type are used " 200 | "for evaluation. 
By default, all volumes are included.", 201 | ) 202 | args = parser.parse_args() 203 | 204 | recons_key = ( 205 | "reconstruction_rss" if args.challenge == "multicoil" else "reconstruction_esc" 206 | ) 207 | metrics = evaluate(args, recons_key) 208 | print(metrics) 209 | -------------------------------------------------------------------------------- /fastmri/data/mri_data.py: -------------------------------------------------------------------------------- 1 | 2 | import csv 3 | import os 4 | 5 | import logging 6 | import pickle 7 | import random 8 | import xml.etree.ElementTree as etree 9 | from pathlib import Path 10 | from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union 11 | from warnings import warn 12 | import pathlib 13 | 14 | import h5py 15 | import numpy as np 16 | import torch 17 | import yaml 18 | from torch.utils.data import Dataset 19 | 20 | def fetch_dir(key, data_config_file=pathlib.Path("fastmri_dirs.yaml")): 21 | """ 22 | Data directory fetcher. 23 | 24 | This is a brute-force simple way to configure data directories for a 25 | project. Simply overwrite the variables for `knee_path` and `brain_path` 26 | and this function will retrieve the requested subsplit of the data for use. 27 | 28 | Args: 29 | key (str): key to retrieve path from data_config_file. 30 | data_config_file (pathlib.Path, 31 | default=pathlib.Path("fastmri_dirs.yaml")): Default path config 32 | file. 33 | 34 | Returns: 35 | pathlib.Path: The path to the specified directory. 36 | """ 37 | if not data_config_file.is_file(): 38 | default_config = dict( 39 | knee_path="/home/chunmeifeng/Data/", 40 | brain_path="/home/chunmeifeng/Data/", 41 | # log_path="/home/chunmeifeng/experimental/MINet/", 42 | ) 43 | with open(data_config_file, "w") as f: 44 | yaml.dump(default_config, f) 45 | 46 | raise ValueError(f"Please populate {data_config_file} with directory paths.") 47 | 48 | with open(data_config_file, "r") as f: 49 | data_dir = yaml.safe_load(f)[key] 50 | 51 | data_dir = pathlib.Path(data_dir) 52 | 53 | if not data_dir.exists(): 54 | raise ValueError(f"Path {data_dir} from {data_config_file} does not exist.") 55 | 56 | return data_dir 57 | 58 | def et_query( 59 | root: etree.Element, 60 | qlist: Sequence[str], 61 | namespace: str = "http://www.ismrm.org/ISMRMRD", 62 | ) -> str: 63 | """ 64 | ElementTree query function. 65 | This can be used to query an xml document via ElementTree. It uses qlist 66 | for nested queries. 67 | Args: 68 | root: Root of the xml to search through. 69 | qlist: A list of strings for nested searches, e.g. ["Encoding", 70 | "matrixSize"] 71 | namespace: Optional; xml namespace to prepend query. 72 | Returns: 73 | The retrieved data as a string. 74 | """ 75 | s = "." 
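# The lines below build a namespaced ElementTree search string from qlist. For example,
# with qlist = ["encoding", "encodedSpace", "matrixSize", "x"] (as used by
# _retrieve_metadata further down), the query becomes
#   ".//ismrmrd_namespace:encoding//ismrmrd_namespace:encodedSpace//ismrmrd_namespace:matrixSize//ismrmrd_namespace:x"
# and root.find() resolves the "ismrmrd_namespace" prefix to the ISMRMRD namespace URI
# through the ns mapping.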
76 | prefix = "ismrmrd_namespace" 77 | 78 | ns = {prefix: namespace} 79 | 80 | for el in qlist: 81 | s = s + f"//{prefix}:{el}" 82 | 83 | value = root.find(s, ns) 84 | if value is None: 85 | raise RuntimeError("Element not found") 86 | 87 | return str(value.text) 88 | 89 | 90 | class SliceDataset(Dataset): 91 | def __init__( 92 | self, 93 | root, 94 | transform, 95 | challenge, 96 | sample_rate=1, 97 | dataset_cache_file=pathlib.Path("dataset_cache.pkl"), 98 | num_cols=None, 99 | mode='train', 100 | ): 101 | self.mode = mode 102 | 103 | #challenge 104 | if challenge not in ("singlecoil", "multicoil"): 105 | raise ValueError('challenge should be either "singlecoil" or "multicoil"') 106 | self.recons_key = ( 107 | "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" 108 | ) 109 | #transform 110 | self.transform = transform 111 | 112 | self.examples=[] 113 | 114 | self.cur_path=root 115 | self.csv_file=os.path.join(self.cur_path,"singlecoil_"+self.mode+"_split_less.csv") 116 | 117 | #读取CSV 118 | with open(self.csv_file,'r') as f: 119 | reader=csv.reader(f) 120 | 121 | for row in reader: 122 | pd_metadata, pd_num_slices = self._retrieve_metadata(os.path.join(self.cur_path,row[0]+'.h5')) 123 | 124 | pdfs_metadata, pdfs_num_slices = self._retrieve_metadata(os.path.join(self.cur_path, row[1]+'.h5')) 125 | 126 | for slice_id in range(min(pd_num_slices,pdfs_num_slices)): 127 | self.examples.append((os.path.join(self.cur_path, row[0]+'.h5'),os.path.join(self.cur_path, row[1]+'.h5') 128 | ,slice_id,pd_metadata,pdfs_metadata)) 129 | 130 | if sample_rate < 1: 131 | random.shuffle(self.examples) 132 | num_examples = round(len(self.examples) * sample_rate) 133 | 134 | self.examples=self.examples[0:num_examples] 135 | 136 | def __len__(self): 137 | return len(self.examples) 138 | 139 | def __getitem__(self, i): 140 | 141 | #读取pd 142 | pd_fname,pdfs_fname,slice,pd_metadata,pdfs_metadata = self.examples[i] 143 | 144 | with h5py.File(pd_fname, "r") as hf: 145 | pd_kspace = hf["kspace"][slice] 146 | 147 | pd_mask = np.asarray(hf["mask"]) if "mask" in hf else None 148 | 149 | pd_target = hf[self.recons_key][slice] if self.recons_key in hf else None 150 | 151 | attrs = dict(hf.attrs) 152 | 153 | attrs.update(pd_metadata) 154 | 155 | if self.transform is None: 156 | pd_sample = (pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) 157 | else: 158 | pd_sample = self.transform(pd_kspace, pd_mask, pd_target, attrs, pd_fname, slice) 159 | 160 | with h5py.File(pdfs_fname, "r") as hf: 161 | pdfs_kspace = hf["kspace"][slice] 162 | pdfs_mask = np.asarray(hf["mask"]) if "mask" in hf else None 163 | 164 | pdfs_target = hf[self.recons_key][slice] if self.recons_key in hf else None 165 | 166 | attrs = dict(hf.attrs) 167 | 168 | attrs.update(pdfs_metadata) 169 | 170 | if self.transform is None: 171 | pdfs_sample = (pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) 172 | else: 173 | pdfs_sample = self.transform(pdfs_kspace, pdfs_mask, pdfs_target, attrs, pdfs_fname, slice) 174 | 175 | 176 | return (pd_sample,pdfs_sample) 177 | 178 | def _retrieve_metadata(self, fname): 179 | with h5py.File(fname, "r") as hf: 180 | et_root = etree.fromstring(hf["ismrmrd_header"][()]) 181 | 182 | enc = ["encoding", "encodedSpace", "matrixSize"] 183 | enc_size = ( 184 | int(et_query(et_root, enc + ["x"])), 185 | int(et_query(et_root, enc + ["y"])), 186 | int(et_query(et_root, enc + ["z"])), 187 | ) 188 | rec = ["encoding", "reconSpace", "matrixSize"] 189 | recon_size = ( 190 | int(et_query(et_root, rec + 
["x"])), 191 | int(et_query(et_root, rec + ["y"])), 192 | int(et_query(et_root, rec + ["z"])), 193 | ) 194 | 195 | lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"] 196 | enc_limits_center = int(et_query(et_root, lims + ["center"])) 197 | enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1 198 | 199 | padding_left = enc_size[1] // 2 - enc_limits_center 200 | padding_right = padding_left + enc_limits_max 201 | 202 | num_slices = hf["kspace"].shape[0] 203 | 204 | metadata = { 205 | "padding_left": padding_left, 206 | "padding_right": padding_right, 207 | "encoding_size": enc_size, 208 | "recon_size": recon_size, 209 | } 210 | 211 | return metadata, num_slices 212 | -------------------------------------------------------------------------------- /common/subsample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): 13 | if mask_type_str == 'random': 14 | return RandomMaskFunc(center_fractions, accelerations) 15 | elif mask_type_str == 'equispaced': 16 | return EquispacedMaskFunc(center_fractions, accelerations) 17 | else: 18 | raise Exception(f"{mask_type_str} not supported") 19 | 20 | 21 | class MaskFunc(): 22 | def __init__(self, center_fractions, accelerations): 23 | """ 24 | Args: 25 | center_fractions (List[float]): Fraction of low-frequency columns to be retained. 26 | If multiple values are provided, then one of these numbers is chosen uniformly 27 | each time. 28 | 29 | accelerations (List[int]): Amount of under-sampling. This should have the same length 30 | as center_fractions. If multiple values are provided, then one of these is chosen 31 | uniformly each time. 32 | """ 33 | if len(center_fractions) != len(accelerations): 34 | raise ValueError('Number of center fractions should match number of accelerations') 35 | 36 | self.center_fractions = center_fractions 37 | self.accelerations = accelerations 38 | self.rng = np.random.RandomState() 39 | 40 | def choose_acceleration(self): 41 | choice = self.rng.randint(0, len(self.accelerations)) 42 | center_fraction = self.center_fractions[choice] 43 | acceleration = self.accelerations[choice] 44 | return center_fraction, acceleration 45 | 46 | 47 | class RandomMaskFunc(MaskFunc): 48 | """ 49 | RandomMaskFunc creates a sub-sampling mask of a given shape. 50 | 51 | The mask selects a subset of columns from the input k-space data. If the k-space data has N 52 | columns, the mask picks out: 53 | 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to 54 | low-frequencies 55 | 2. The other columns are selected uniformly at random with a probability equal to: 56 | prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). 57 | This ensures that the expected number of columns selected is equal to (N / acceleration) 58 | 59 | It is possible to use multiple center_fractions and accelerations, in which case one possible 60 | (center_fraction, acceleration) is chosen uniformly at random each time the RandomMaskFunc object is 61 | called. 
62 | 63 | For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there 64 | is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% 65 | probability that 8-fold acceleration with 4% center fraction is selected. 66 | """ 67 | 68 | def __init__(self, center_fractions, accelerations): 69 | """ 70 | Args: 71 | center_fractions (List[float]): Fraction of low-frequency columns to be retained. 72 | If multiple values are provided, then one of these numbers is chosen uniformly 73 | each time. 74 | 75 | accelerations (List[int]): Amount of under-sampling. This should have the same length 76 | as center_fractions. If multiple values are provided, then one of these is chosen 77 | uniformly each time. An acceleration of 4 retains 25% of the columns, but they may 78 | not be spaced evenly. 79 | """ 80 | if len(center_fractions) != len(accelerations): 81 | raise ValueError('Number of center fractions should match number of accelerations') 82 | 83 | self.center_fractions = center_fractions 84 | self.accelerations = accelerations 85 | self.rng = np.random.RandomState() 86 | 87 | def __call__(self, shape, seed=None): 88 | """ 89 | Args: 90 | shape (iterable[int]): The shape of the mask to be created. The shape should have 91 | at least 3 dimensions. Samples are drawn along the second last dimension. 92 | seed (int, optional): Seed for the random number generator. Setting the seed 93 | ensures the same mask is generated each time for the same shape. 94 | Returns: 95 | torch.Tensor: A mask of the specified shape. 96 | """ 97 | if len(shape) < 3: 98 | raise ValueError('Shape should have 3 or more dimensions') 99 | 100 | self.rng.seed(seed) 101 | num_cols = shape[-2] 102 | center_fraction, acceleration = self.choose_acceleration() 103 | 104 | # Create the mask 105 | num_low_freqs = int(round(num_cols * center_fraction)) 106 | prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) 107 | mask = self.rng.uniform(size=num_cols) < prob 108 | pad = (num_cols - num_low_freqs + 1) // 2 109 | mask[pad:pad + num_low_freqs] = True 110 | 111 | # Reshape the mask 112 | mask_shape = [1 for _ in shape] 113 | mask_shape[-2] = num_cols 114 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 115 | 116 | return mask 117 | 118 | class EquispacedMaskFunc(MaskFunc): 119 | """ 120 | EquispacedMaskFunc creates a sub-sampling mask of a given shape. 121 | 122 | The mask selects a subset of columns from the input k-space data. If the k-space data has N 123 | columns, the mask picks out: 124 | 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to 125 | low-frequencies 126 | 2. The other columns are selected with equal spacing at a proportion that reaches the 127 | desired acceleration rate taking into consideration the number of low frequencies. This 128 | ensures that the expected number of columns selected is equal to (N / acceleration) 129 | 130 | It is possible to use multiple center_fractions and accelerations, in which case one possible 131 | (center_fraction, acceleration) is chosen uniformly at random each time the EquispacedMaskFunc 132 | object is called. 133 | """ 134 | def __call__(self, shape, seed): 135 | """ 136 | Args: 137 | shape (iterable[int]): The shape of the mask to be created. The shape should have 138 | at least 3 dimensions. Samples are drawn along the second last dimension. 139 | seed (int, optional): Seed for the random number generator. 
Setting the seed 140 | ensures the same mask is generated each time for the same shape. 141 | Returns: 142 | torch.Tensor: A mask of the specified shape. 143 | """ 144 | if len(shape) < 3: 145 | raise ValueError('Shape should have 3 or more dimensions') 146 | 147 | self.rng.seed(seed) 148 | center_fraction, acceleration = self.choose_acceleration() 149 | num_cols = shape[-2] 150 | num_low_freqs = int(round(num_cols * center_fraction)) 151 | 152 | # Create the mask 153 | mask = np.zeros(num_cols, dtype=np.float32) 154 | pad = (num_cols - num_low_freqs + 1) // 2 155 | mask[pad:pad + num_low_freqs] = True 156 | 157 | # Determine acceleration rate by adjusting for the number of low frequencies 158 | adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols) 159 | offset = self.rng.randint(0, round(adjusted_accel)) 160 | 161 | accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) 162 | accel_samples = np.around(accel_samples).astype(np.uint) 163 | mask[accel_samples] = True 164 | 165 | # Reshape the mask 166 | mask_shape = [1 for _ in shape] 167 | mask_shape[-2] = num_cols 168 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 169 | 170 | return mask 171 | -------------------------------------------------------------------------------- /fastmri/data/subsample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import contextlib 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | @contextlib.contextmanager 15 | def temp_seed(rng, seed): 16 | state = rng.get_state() 17 | rng.seed(seed) 18 | try: 19 | yield 20 | finally: 21 | rng.set_state(state) 22 | 23 | 24 | def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): 25 | if mask_type_str == "random": 26 | return RandomMaskFunc(center_fractions, accelerations) 27 | elif mask_type_str == "equispaced": 28 | return EquispacedMaskFunc(center_fractions, accelerations) 29 | else: 30 | raise Exception(f"{mask_type_str} not supported") 31 | 32 | 33 | class MaskFunc(object): 34 | """ 35 | An object for GRAPPA-style sampling masks. 36 | 37 | This crates a sampling mask that densely samples the center while 38 | subsampling outer k-space regions based on the undersampling factor. 39 | """ 40 | 41 | def __init__(self, center_fractions, accelerations): 42 | """ 43 | Args: 44 | center_fractions (List[float]): Fraction of low-frequency columns to be 45 | retained. If multiple values are provided, then one of these 46 | numbers is chosen uniformly each time. 47 | accelerations (List[int]): Amount of under-sampling. This should have 48 | the same length as center_fractions. If multiple values are 49 | provided, then one of these is chosen uniformly each time. 
50 | """ 51 | if len(center_fractions) != len(accelerations): 52 | raise ValueError( 53 | "Number of center fractions should match number of accelerations" 54 | ) 55 | 56 | self.center_fractions = center_fractions 57 | self.accelerations = accelerations 58 | self.rng = np.random 59 | 60 | def choose_acceleration(self): 61 | """Choose acceleration based on class parameters.""" 62 | choice = self.rng.randint(0, len(self.accelerations)) 63 | center_fraction = self.center_fractions[choice] 64 | acceleration = self.accelerations[choice] 65 | 66 | return center_fraction, acceleration 67 | 68 | 69 | class RandomMaskFunc(MaskFunc): 70 | """ 71 | RandomMaskFunc creates a sub-sampling mask of a given shape. 72 | 73 | The mask selects a subset of columns from the input k-space data. If the 74 | k-space data has N columns, the mask picks out: 75 | 1. N_low_freqs = (N * center_fraction) columns in the center 76 | corresponding to low-frequencies. 77 | 2. The other columns are selected uniformly at random with a 78 | probability equal to: prob = (N / acceleration - N_low_freqs) / 79 | (N - N_low_freqs). This ensures that the expected number of columns 80 | selected is equal to (N / acceleration). 81 | 82 | It is possible to use multiple center_fractions and accelerations, in which 83 | case one possible (center_fraction, acceleration) is chosen uniformly at 84 | random each time the RandomMaskFunc object is called. 85 | 86 | For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], 87 | then there is a 50% probability that 4-fold acceleration with 8% center 88 | fraction is selected and a 50% probability that 8-fold acceleration with 4% 89 | center fraction is selected. 90 | """ 91 | 92 | def __call__(self, shape, seed=None): 93 | """ 94 | Create the mask. 95 | 96 | Args: 97 | shape (iterable[int]): The shape of the mask to be created. The 98 | shape should have at least 3 dimensions. Samples are drawn 99 | along the second last dimension. 100 | seed (int, optional): Seed for the random number generator. Setting 101 | the seed ensures the same mask is generated each time for the 102 | same shape. The random state is reset afterwards. 103 | 104 | Returns: 105 | torch.Tensor: A mask of the specified shape. 106 | """ 107 | if len(shape) < 3: 108 | raise ValueError("Shape should have 3 or more dimensions") 109 | 110 | with temp_seed(self.rng, seed): 111 | num_cols = shape[-2] 112 | center_fraction, acceleration = self.choose_acceleration() 113 | 114 | # create the mask 115 | num_low_freqs = int(round(num_cols * center_fraction)) 116 | prob = (num_cols / acceleration - num_low_freqs) / ( 117 | num_cols - num_low_freqs 118 | ) 119 | mask = self.rng.uniform(size=num_cols) < prob 120 | pad = (num_cols - num_low_freqs + 1) // 2 121 | mask[pad : pad + num_low_freqs] = True 122 | 123 | # reshape the mask 124 | mask_shape = [1 for _ in shape] 125 | mask_shape[-2] = num_cols 126 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 127 | 128 | return mask 129 | 130 | 131 | class EquispacedMaskFunc(MaskFunc): 132 | """ 133 | EquispacedMaskFunc creates a sub-sampling mask of a given shape. 134 | 135 | The mask selects a subset of columns from the input k-space data. If the 136 | k-space data has N columns, the mask picks out: 137 | 1. N_low_freqs = (N * center_fraction) columns in the center 138 | corresponding tovlow-frequencies. 139 | 2. 
The other columns are selected with equal spacing at a proportion 140 | that reaches the desired acceleration rate taking into consideration 141 | the number of low frequencies. This ensures that the expected number 142 | of columns selected is equal to (N / acceleration) 143 | 144 | It is possible to use multiple center_fractions and accelerations, in which 145 | case one possible (center_fraction, acceleration) is chosen uniformly at 146 | random each time the EquispacedMaskFunc object is called. 147 | 148 | Note that this function may not give equispaced samples (documented in 149 | https://github.com/facebookresearch/fastMRI/issues/54), which will require 150 | modifications to standard GRAPPA approaches. Nonetheless, this aspect of 151 | the function has been preserved to match the public multicoil data. 152 | """ 153 | 154 | def __call__(self, shape, seed): 155 | """ 156 | Args: 157 | shape (iterable[int]): The shape of the mask to be created. The 158 | shape should have at least 3 dimensions. Samples are drawn 159 | along the second last dimension. 160 | seed (int, optional): Seed for the random number generator. Setting 161 | the seed ensures the same mask is generated each time for the 162 | same shape. The random state is reset afterwards. 163 | 164 | Returns: 165 | torch.Tensor: A mask of the specified shape. 166 | """ 167 | if len(shape) < 3: 168 | raise ValueError("Shape should have 3 or more dimensions") 169 | 170 | with temp_seed(self.rng, seed): 171 | center_fraction, acceleration = self.choose_acceleration() 172 | num_cols = shape[-2] 173 | num_low_freqs = int(round(num_cols * center_fraction)) 174 | 175 | # create the mask 176 | mask = np.zeros(num_cols, dtype=np.float32) 177 | pad = (num_cols - num_low_freqs + 1) // 2 178 | mask[pad : pad + num_low_freqs] = True 179 | 180 | # determine acceleration rate by adjusting for the number of low frequencies 181 | adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / ( 182 | num_low_freqs * acceleration - num_cols 183 | ) 184 | offset = self.rng.randint(0, round(adjusted_accel)) 185 | 186 | accel_samples = np.arange(offset, num_cols - 1, adjusted_accel) 187 | accel_samples = np.around(accel_samples).astype(np.uint) 188 | mask[accel_samples] = True 189 | 190 | # reshape the mask 191 | mask_shape = [1 for _ in shape] 192 | mask_shape[-2] = num_cols 193 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 194 | 195 | return mask 196 | -------------------------------------------------------------------------------- /experimental/MANet/RAdam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 16 | 17 | self.degenerated_to_sgd = degenerated_to_sgd 18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 19 | for param in params: 20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != 
betas[1]): 21 | param['buffer'] = [[None, None, None] for _ in range(10)] 22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 23 | super(RAdam, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError('RAdam does not support sparse gradients') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | state['step'] += 1 62 | buffered = group['buffer'][int(state['step'] % 10)] 63 | if state['step'] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | else: 66 | buffered[0] = state['step'] 67 | beta2_t = beta2 ** state['step'] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 75 | elif self.degenerated_to_sgd: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | else: 78 | step_size = -1 79 | buffered[2] = step_size 80 | 81 | # more conservative since it's an approximated value 82 | if N_sma >= 5: 83 | if group['weight_decay'] != 0: 84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 85 | denom = exp_avg_sq.sqrt().add_(group['eps']) 86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 87 | p.data.copy_(p_data_fp32) 88 | elif step_size > 0: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 92 | p.data.copy_(p_data_fp32) 93 | 94 | return loss 95 | 96 | class PlainRAdam(Optimizer): 97 | 98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 99 | if not 0.0 <= lr: 100 | raise ValueError("Invalid learning rate: {}".format(lr)) 101 | if not 0.0 <= eps: 102 | raise ValueError("Invalid epsilon value: {}".format(eps)) 103 | if not 0.0 <= betas[0] < 1.0: 104 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 105 | if not 0.0 <= betas[1] < 1.0: 106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 107 | 108 | self.degenerated_to_sgd = degenerated_to_sgd 109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 110 | 111 | super(PlainRAdam, self).__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super(PlainRAdam, self).__setstate__(state) 115 | 116 | def step(self, closure=None): 117 | 118 | loss 
= None 119 | if closure is not None: 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | 124 | for p in group['params']: 125 | if p.grad is None: 126 | continue 127 | grad = p.grad.data.float() 128 | if grad.is_sparse: 129 | raise RuntimeError('RAdam does not support sparse gradients') 130 | 131 | p_data_fp32 = p.data.float() 132 | 133 | state = self.state[p] 134 | 135 | if len(state) == 0: 136 | state['step'] = 0 137 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 139 | else: 140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 142 | 143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 144 | beta1, beta2 = group['betas'] 145 | 146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 147 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 148 | 149 | state['step'] += 1 150 | beta2_t = beta2 ** state['step'] 151 | N_sma_max = 2 / (1 - beta2) - 1 152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 153 | 154 | 155 | # more conservative since it's an approximated value 156 | if N_sma >= 5: 157 | if group['weight_decay'] != 0: 158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 159 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 160 | denom = exp_avg_sq.sqrt().add_(group['eps']) 161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 162 | p.data.copy_(p_data_fp32) 163 | elif self.degenerated_to_sgd: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 166 | step_size = group['lr'] / (1 - beta1 ** state['step']) 167 | p_data_fp32.add_(-step_size, exp_avg) 168 | p.data.copy_(p_data_fp32) 169 | 170 | return loss 171 | 172 | 173 | class AdamW(Optimizer): 174 | 175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 176 | if not 0.0 <= lr: 177 | raise ValueError("Invalid learning rate: {}".format(lr)) 178 | if not 0.0 <= eps: 179 | raise ValueError("Invalid epsilon value: {}".format(eps)) 180 | if not 0.0 <= betas[0] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 182 | if not 0.0 <= betas[1] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 184 | 185 | defaults = dict(lr=lr, betas=betas, eps=eps, 186 | weight_decay=weight_decay, warmup = warmup) 187 | super(AdamW, self).__init__(params, defaults) 188 | 189 | def __setstate__(self, state): 190 | super(AdamW, self).__setstate__(state) 191 | 192 | def step(self, closure=None): 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | for group in self.param_groups: 198 | 199 | for p in group['params']: 200 | if p.grad is None: 201 | continue 202 | grad = p.grad.data.float() 203 | if grad.is_sparse: 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 205 | 206 | p_data_fp32 = p.data.float() 207 | 208 | state = self.state[p] 209 | 210 | if len(state) == 0: 211 | state['step'] = 0 212 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 214 | else: 215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 217 | 218 | exp_avg, exp_avg_sq = 
state['exp_avg'], state['exp_avg_sq'] 219 | beta1, beta2 = group['betas'] 220 | 221 | state['step'] += 1 222 | 223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 224 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 225 | 226 | denom = exp_avg_sq.sqrt().add_(group['eps']) 227 | bias_correction1 = 1 - beta1 ** state['step'] 228 | bias_correction2 = 1 - beta2 ** state['step'] 229 | 230 | if group['warmup'] > state['step']: 231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 232 | else: 233 | scheduled_lr = group['lr'] 234 | 235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 236 | 237 | if group['weight_decay'] != 0: 238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 239 | 240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 241 | 242 | p.data.copy_(p_data_fp32) 243 | 244 | return loss -------------------------------------------------------------------------------- /experimental/MANet/module_MANet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import hashlib 9 | import os 10 | from argparse import ArgumentParser 11 | 12 | import pytorch_lightning as pl 13 | import torch 14 | from torch.nn import functional as F 15 | 16 | import fastmri 17 | from fastmri import MriModule 18 | from fastmri.data import transforms 19 | from fastmri.data.subsample import create_mask_for_mask_type 20 | from fastmri.models.MANet import MANet 21 | import cv2 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | 25 | class SRModule(MriModule): 26 | """ 27 | Unet training module. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | n_channels_in=1, 33 | n_channels_out=2, 34 | lr=0.001, 35 | lr_step_size=40, 36 | lr_gamma=0.1, 37 | weight_decay=0.0, 38 | L=2, 39 | **kwargs, 40 | ): 41 | """ 42 | Args: 43 | in_chans (int): Number of channels in the input to the U-Net model. 44 | out_chans (int): Number of channels in the output to the U-Net 45 | model. 46 | chans (int): Number of output channels of the first convolution 47 | layer. 48 | num_pool_layers (int): Number of down-sampling and up-sampling 49 | layers. 50 | drop_prob (float): Dropout probability. 51 | mask_type (str): Type of mask from ("random", "equispaced"). 52 | center_fractions (list): Fraction of all samples to take from 53 | center (i.e., list of floats). 54 | accelerations (list): List of accelerations to apply (i.e., list 55 | of ints). 56 | lr (float): Learning rate. 57 | lr_step_size (int): Learning rate step size. 58 | lr_gamma (float): Learning rate gamma decay. 59 | weight_decay (float): Parameter for penalizing weights norm. 
60 | """ 61 | super().__init__(**kwargs) 62 | self.n_channels_in = n_channels_in 63 | self.n_channels_out = n_channels_out 64 | self.lr = lr 65 | self.lr_step_size = lr_step_size 66 | self.lr_gamma = lr_gamma 67 | self.weight_decay = weight_decay 68 | # self.unet = UNet 69 | self.UNet_k = Dense_Unet_k( 70 | in_chan=2, 71 | out_chan=2, 72 | filters=64, 73 | num_conv = 4,) 74 | self.UNet_img = Dense_Unet_img( 75 | in_chan=1, 76 | out_chan=1, 77 | filters=64, 78 | num_conv = 4,) 79 | 80 | def forward(self, target_Kspace_T1,masked_Kspaces_T2): 81 | return self.UNet_k(target_Kspace_T1,masked_Kspaces_T2) 82 | def forward2(self, input_T1_img,input_T2_img): 83 | return self.UNet_img(input_T1_img,input_T2_img) 84 | 85 | def training_step(self, batch, batch_idx): 86 | #T1 87 | masked_Kspace_T1 = batch['masked_Kspaces_T1'].cuda().float() 88 | target_Kspace_T1 = batch['target_Kspace_T1'].cuda().float() 89 | target_img_T1 = batch['target_img_T1'].cuda().float() 90 | maskedNot = batch['maskedNot'].cuda().float() 91 | #T2 92 | masked_Kspaces_T2 = batch['masked_Kspaces_T2'].cuda().float() 93 | target_Kspace_T2 = batch['target_Kspace_T2'].cuda().float() 94 | target_img_T2 = batch['target_img_T2'].cuda().float() 95 | fname = batch['fname'] 96 | slice_num = batch['slice_num'] 97 | 98 | masked_img_T2 = self.inverseFT(masked_Kspaces_T2).cuda().float() 99 | 100 | output_T1k,output_T2k = self.forward(target_Kspace_T1,masked_Kspaces_T2) 101 | loss_T1 = F.l1_loss(output_T1k, target_Kspace_T1) 102 | loss_T2 = F.l1_loss(output_T2k, target_Kspace_T2) 103 | loss_k = 0.1*loss_T1+0.9*loss_T2 104 | 105 | input_T1_img = self.inverseFT(output_T1k) 106 | input_T2 = self.inverseFT(output_T2k) 107 | 108 | output_T1img,output_T2 = self.forward2(input_T1_img,input_T2) 109 | loss_T1img = F.l1_loss(output_T1img, target_img_T1) 110 | loss_T2img = F.l1_loss(output_T2, target_img_T2) 111 | loss_img = 0.1*loss_T1img+0.9*loss_T2img 112 | 113 | loss = loss_k + loss_img 114 | 115 | logs = {"loss": loss.detach()} 116 | 117 | return dict(loss=loss, log=logs) 118 | 119 | def inverseFT(self, Kspace): 120 | Kspace = Kspace.permute(0, 2, 3, 1)#last dimension=2 121 | img_cmplx = torch.ifft(Kspace, 2) 122 | img = torch.sqrt(img_cmplx[:, :, :, 0]**2 + img_cmplx[:, :, :, 1]**2) 123 | img = img[:, None, :, :] 124 | return img 125 | 126 | def contrastStretching(self,img, saturated_pixel=0.004): 127 | """ constrast stretching according to imageJ 128 | http://homepages.inf.ed.ac.uk/rbf/HIPR2/stretch.htm""" 129 | values = np.sort(img, axis=None) 130 | nr_pixels = np.size(values) 131 | lim = int(np.round(saturated_pixel*nr_pixels)) 132 | v_min = values[lim] 133 | v_max = values[-lim-1] 134 | img = (img - v_min)*(255.0)/(v_max - v_min) 135 | img = np.minimum(255.0, np.maximum(0.0, img)) 136 | return img 137 | 138 | def fftshift(self, x, dim=None): 139 | 140 | if dim is None: 141 | dim = tuple(range(x.dim())) 142 | shift = [dim // 2 for dim in x.shape] 143 | elif isinstance(dim, int): 144 | shift = x.shape[dim] // 2 145 | else: 146 | shift = [x.shape[i] // 2 for i in dim] 147 | 148 | return roll(x, shift, dim) 149 | 150 | def imshow(self, img, title=""): 151 | """ Show image as grayscale. """ 152 | if img.dtype == np.complex64 or img.dtype == np.complex128: 153 | print('img is complex! 
Take absolute value.') 154 | img = np.abs(img) 155 | 156 | plt.figure() 157 | plt.imshow(img, cmap='gray', interpolation='nearest') 158 | plt.axis('off') 159 | plt.title(title) 160 | plt.show() 161 | 162 | 163 | def ifft2(self, kspace_cplx): 164 | return np.absolute(np.fft.ifft2(kspace_cplx))[None, :, :] 165 | 166 | def fft2(self, img): 167 | return np.fft.fftshift(np.fft.fft2(img)) 168 | 169 | def validation_step(self, batch, batch_idx): 170 | #T1 171 | masked_Kspace_T1 = batch['masked_Kspaces_T1'].cuda().float() 172 | target_Kspace_T1 = batch['target_Kspace_T1'].cuda().float() 173 | target_img_T1 = batch['target_img_T1'].cuda().float() 174 | maskedNot = batch['maskedNot'].cuda().float() 175 | masks = batch['masks'].cuda().float() 176 | 177 | #T2 178 | masked_Kspaces_T2 = batch['masked_Kspaces_T2'].cuda().float() 179 | target_Kspace_T2 = batch['target_Kspace_T2'].cuda().float() 180 | target_img_T2 = batch['target_img_T2'].cuda().float() 181 | fname = batch['fname'] 182 | slice_num = batch['slice_num'] 183 | 184 | masked_img_T2 = self.inverseFT(masked_Kspaces_T2).cuda().float() 185 | 186 | output_T1k,output_T2k = self.forward(target_Kspace_T1,masked_Kspaces_T2) 187 | 188 | loss_T1 = F.l1_loss(output_T1k, target_Kspace_T1) 189 | loss_T2 = F.l1_loss(output_T2k, target_Kspace_T2) 190 | loss_k = 0.1*loss_T1+0.9*loss_T2 191 | 192 | 193 | input_T1img = self.inverseFT(output_T1k)# 194 | input_T2 = self.inverseFT(output_T2k)# 195 | 196 | output_T1img,output_T2 = self.forward2(input_T1img,input_T2) 197 | loss_T1img = F.l1_loss(output_T1img, target_img_T1) 198 | loss_T2img = F.l1_loss(output_T2, target_img_T2) 199 | loss_img = 0.1*loss_T1img+0.9*loss_T2img 200 | loss = loss_k + loss_img 201 | 202 | fnumber = torch.zeros(len(fname), dtype=torch.long, device=output_T2.device) 203 | for i, fn in enumerate(fname): 204 | fnumber[i] = ( 205 | int(hashlib.sha256(fn.encode("utf-8")).hexdigest(), 16) % 10 ** 12 206 | ) 207 | 208 | return { 209 | "fname": fnumber, 210 | "slice": slice_num, 211 | # "output": output * std + mean, 212 | # "target": target * std + mean, 213 | "output_T2": output_T2, 214 | "target_im_T2": target_img_T2, 215 | "val_loss": loss, 216 | } 217 | 218 | def test_step(self, batch, batch_idx): 219 | #T1 220 | masked_Kspace_T1 = batch['masked_Kspaces_T1'].cuda().float()#masked_kspace: torch.Size([1, 2, 256, 256]) 221 | target_Kspace_T1 = batch['target_Kspace_T1'].cuda().float()# target_kspace: torch.Size([1, 2, 256, 256]) 222 | target_img_T1 = batch['target_img_T1'].cuda().float()#target_img: torch.Size([1, 1, 256, 256]) 223 | #T2 224 | masked_Kspaces_T2 = batch['masked_Kspaces_T2'].cuda().float()#masked_kspace: torch.Size([1, 2, 256, 256]) 225 | target_Kspace_T2 = batch['target_Kspace_T2'].cuda().float()# target_kspace: torch.Size([1, 2, 256, 256]) 226 | target_img_T2 = batch['target_img_T2'].cuda().float()#target_img: torch.Size([1, 1, 256, 256]) 227 | 228 | masked_img_T2 = self.inverseFT(masked_Kspaces_T2).cuda().float() 229 | output_T1, output_T2 = self(target_img_T1,masked_img_T2) 230 | 231 | fname = batch['fname'] 232 | slice_num = batch['slice_num'] 233 | fnumber = torch.zeros(len(fname), dtype=torch.long, device=output_T2.device) 234 | for i, fn in enumerate(fname): 235 | fnumber[i] = ( 236 | int(hashlib.sha256(fn.encode("utf-8")).hexdigest(), 16) % 10 ** 12 237 | ) 238 | 239 | return { 240 | "fname": fnumber, 241 | "slice": slice_num, 242 | "output_T2": output_T2, 243 | "target_im_T2": target_img_T2, 244 | "test_loss": F.l1_loss(output_T2, target_img_T2), 245 | } 246 | 247 | def 
configure_optimizers(self): 248 | optim = torch.optim.RMSprop( 249 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay, 250 | ) 251 | scheduler = torch.optim.lr_scheduler.StepLR( 252 | optim, self.lr_step_size, self.lr_gamma 253 | ) 254 | 255 | return [optim], [scheduler] 256 | 257 | 258 | @staticmethod 259 | def add_model_specific_args(parent_parser): # pragma: no-cover 260 | """ 261 | Define parameters that only apply to this model 262 | """ 263 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 264 | parser = MriModule.add_model_specific_args(parser) 265 | 266 | # param overwrites 267 | 268 | # network params 269 | parser.add_argument("--in_chans", default=1, type=int) 270 | parser.add_argument("--out_chans", default=1, type=int) 271 | parser.add_argument("--chans", default=1, type=int) 272 | parser.add_argument("--num_pool_layers", default=4, type=int) 273 | parser.add_argument("--drop_prob", default=0.0, type=float) 274 | 275 | # data params 276 | 277 | # training params (opt) 278 | parser.add_argument("--lr", default=0.001, type=float) 279 | parser.add_argument("--lr_step_size", default=40, type=int) 280 | parser.add_argument("--lr_gamma", default=0.1, type=float) 281 | parser.add_argument("--weight_decay", default=0.0, type=float) 282 | 283 | parser.add_argument('--ixi-args', type=dict) 284 | 285 | return parser 286 | 287 | -------------------------------------------------------------------------------- /fastmri/mri_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import pathlib 9 | from argparse import ArgumentParser 10 | from collections import defaultdict 11 | 12 | import numpy as np 13 | import pytorch_lightning as pl 14 | import torch 15 | import torchvision 16 | from pytorch_lightning import _logger as log 17 | from torch.utils.data import DataLoader, DistributedSampler 18 | from .math import complex_abs_numpy 19 | 20 | import fastmri 21 | from fastmri import evaluate 22 | from fastmri.data import SliceDataset 23 | from fastmri.data.volume_sampler import VolumeSampler 24 | from fastmri.evaluate import NMSE, PSNR, SSIM, DistributedMetricSum 25 | 26 | 27 | class MriModule(pl.LightningModule): 28 | """ 29 | Abstract super class for deep larning reconstruction models. 30 | 31 | This is a subclass of the LightningModule class from pytorch_lightning, 32 | with some additional functionality specific to fastMRI: 33 | - fastMRI data loaders 34 | - Evaluating reconstructions 35 | - Visualization 36 | - Saving test reconstructions 37 | 38 | To implement a new reconstruction model, inherit from this class and 39 | implement the following methods: 40 | - train_data_transform, val_data_transform, test_data_transform: 41 | Create and return data transformer objects for each data split 42 | - training_step, validation_step, test_step: 43 | Define what happens in one step of training, validation, and 44 | testing, respectively 45 | - configure_optimizers: 46 | Create and return the optimizers 47 | 48 | Other methods from LightningModule can be overridden as needed. 
49 | """ 50 | 51 | def __init__( 52 | self, 53 | data_path, 54 | challenge, 55 | exp_dir, 56 | exp_name, 57 | test_split="test", 58 | sample_rate=1.0, 59 | batch_size=1, 60 | num_workers=4, 61 | **kwargs, 62 | ): 63 | """ 64 | Args: 65 | data_path (pathlib.Path): Path to root data directory. For example, if 66 | knee/path is the root directory with subdirectories 67 | multicoil_train and multicoil_val, you would input knee/path for 68 | data_path. 69 | challenge (str): Name of challenge from ('multicoil', 'singlecoil'). 70 | exp_dir (pathlib.Path): Top directory for where you want to store log 71 | files. 72 | exp_name (str): Name of this experiment - this will store logs in 73 | exp_dir / {exp_name}. 74 | test_split (str): Name of test split from ("test", "challenge"). 75 | sample_rate (float, default=1.0): Fraction of models from the 76 | dataset to use. 77 | batch_size (int, default=1): Batch size. 78 | num_workers (int, default=4): Number of workers for PyTorch dataloader. 79 | """ 80 | super().__init__() 81 | 82 | self.data_path = data_path 83 | self.challenge = challenge 84 | self.exp_dir = exp_dir 85 | self.exp_name = exp_name 86 | self.test_split = test_split 87 | self.sample_rate = sample_rate 88 | self.batch_size = batch_size 89 | self.num_workers = num_workers 90 | 91 | self.NMSE = DistributedMetricSum(name="NMSE") 92 | self.SSIM = DistributedMetricSum(name="SSIM") 93 | self.PSNR = DistributedMetricSum(name="PSNR") 94 | self.ValLoss = DistributedMetricSum(name="ValLoss") 95 | self.TotExamples = DistributedMetricSum(name="TotExamples") 96 | 97 | def _create_data_loader(self, data_transform, data_partition, sample_rate=None): 98 | sample_rate = sample_rate or self.sample_rate 99 | dataset = SliceDataset( 100 | root=self.data_path / f"{self.challenge}_{data_partition}", 101 | transform=data_transform, 102 | sample_rate=sample_rate, 103 | challenge=self.challenge, 104 | mode=data_partition 105 | ) 106 | 107 | is_train = data_partition == "train" 108 | 109 | # ensure that entire volumes go to the same GPU in the ddp setting 110 | sampler = None 111 | if self.use_ddp: 112 | if is_train: 113 | sampler = DistributedSampler(dataset) 114 | else: 115 | sampler = VolumeSampler(dataset) 116 | 117 | dataloader = DataLoader( 118 | dataset=dataset, 119 | batch_size=self.batch_size, 120 | num_workers=self.num_workers, 121 | pin_memory=False, 122 | drop_last=is_train, 123 | sampler=sampler, 124 | ) 125 | 126 | return dataloader 127 | 128 | def train_data_transform(self): 129 | # raise NotImplementedError 130 | pass 131 | 132 | def train_dataloader(self): 133 | return self._create_data_loader( 134 | self.train_data_transform(), data_partition="train" 135 | ) 136 | 137 | def val_data_transform(self): 138 | # raise NotImplementedError 139 | pass 140 | 141 | def val_dataloader(self): 142 | # print ('666') 143 | return self._create_data_loader(self.val_data_transform(), data_partition="val") 144 | 145 | def test_data_transform(self): 146 | # raise NotImplementedError 147 | pass 148 | 149 | def test_dataloader(self): 150 | return self._create_data_loader( 151 | self.test_data_transform(), data_partition=self.test_split, sample_rate=1.0, 152 | ) 153 | 154 | def _visualize(self, val_outputs, val_targets): 155 | def _normalize(image): 156 | image = image[np.newaxis] 157 | image = image - image.min() 158 | return image / image.max() 159 | 160 | def _save_image(image, tag): 161 | grid = torchvision.utils.make_grid(torch.Tensor(image), nrow=4, pad_value=1) 162 | self.logger.experiment.add_image(tag, grid, 
self.global_step)
163 | 
164 |         # only process first size to simplify visualization.
165 |         visualize_size = val_outputs[0].shape
166 |         val_outputs = [x[0] for x in val_outputs if x.shape == visualize_size]
167 |         val_targets = [x[0] for x in val_targets if x.shape == visualize_size]
168 | 
169 |         num_logs = len(val_outputs)
170 |         assert num_logs == len(val_targets)
171 | 
172 |         num_viz_images = 16
173 |         step = (num_logs + num_viz_images - 1) // num_viz_images
174 |         outputs, targets = [], []
175 | 
176 |         for i in range(0, num_logs, step):
177 |             outputs.append(_normalize(val_outputs[i]))
178 |             targets.append(_normalize(val_targets[i]))
179 | 
180 |         outputs = np.stack(outputs)
181 |         targets = np.stack(targets)
182 |         _save_image(targets, "Target")
183 |         _save_image(outputs, "Reconstruction")
184 |         _save_image(np.abs(targets - outputs), "Error")
185 | 
186 |     def _visualize_val(self, val_outputs, val_targets, val_inputs):
187 |         def _normalize(image):
188 |             image = image[np.newaxis]
189 |             image = image - image.min()
190 |             return image / image.max()
191 | 
192 |         def _save_image(image, tag):
193 |             grid = torchvision.utils.make_grid(torch.Tensor(image), nrow=4, pad_value=1)
194 |             self.logger.experiment.add_image(tag, grid, self.global_step)
195 | 
196 |         # only process first size to simplify visualization.
197 |         visualize_size = val_outputs[0].shape
198 |         visualize_size_inputs = val_inputs[0].shape
199 |         val_outputs = [x[0] for x in val_outputs if x.shape == visualize_size]
200 |         val_targets = [x[0] for x in val_targets if x.shape == visualize_size]
201 |         val_inputs = [x[0] for x in val_inputs if x.shape == visualize_size_inputs]
202 | 
203 |         num_logs = len(val_outputs)
204 |         assert num_logs == len(val_targets)
205 |         assert num_logs == len(val_inputs)
206 | 
207 |         num_viz_images = 16
208 |         step = (num_logs + num_viz_images - 1) // num_viz_images
209 |         outputs, targets, inputs = [], [], []
210 | 
211 |         for i in range(0, num_logs, step):
212 |             outputs.append(_normalize(val_outputs[i]))
213 |             targets.append(_normalize(val_targets[i]))
214 |             inputs.append(_normalize(val_inputs[i]))
215 | 
216 |         outputs = np.stack(outputs)  # e.g. (2, 1, 1, 256, 256)
217 |         targets = np.stack(targets)  # e.g. (2, 1, 1, 256, 256)
218 |         inputs = np.stack(inputs)  # e.g. (2, 1, 1, 256, 256)
219 | 
220 |         _save_image(targets, "Target")
221 |         _save_image(outputs, "Reconstruction")
222 |         _save_image(inputs, "Input")
223 |         _save_image(np.abs(targets - outputs), "Error")
224 | 
225 |     def validation_step_end(self, val_logs):
226 |         device = val_logs["output"].device
227 |         # device = val_logs["output_k"].device  # k-space branch
228 |         # move to CPU to save GPU memory
229 |         val_logs = {key: value.cpu() for key, value in val_logs.items()}
230 |         val_logs["device"] = device
231 | 
232 |         return val_logs
233 | 
234 |     def validation_epoch_end(self, val_logs):
235 |         # assert val_logs[0]["output_im"].ndim == 3
236 |         device = val_logs[0]["device"]
237 | 
238 |         # run the visualizations
239 |         self._visualize_val(
240 |             val_outputs=[x["output"].numpy() for x in val_logs],
241 |             val_targets=[x["target"].numpy() for x in val_logs],
242 |             val_inputs=[x["input"].numpy() for x in val_logs],
243 |         )
244 | 
245 |         # aggregate losses
246 |         losses = []
247 |         outputs = defaultdict(list)
248 |         targets = defaultdict(list)
249 |         inputs = defaultdict(list)
250 | 
251 |         for val_log in val_logs:
252 |             losses.append(val_log["val_loss"])
253 |             for i, (fname, slice_ind) in enumerate(
254 |                 zip(val_log["fname"], val_log["slice"])
255 |             ):
256 |                 # need to check for duplicate slices
257 |                 if slice_ind not in [s for (s, _) in outputs[int(fname)]]:
258 |                     outputs[int(fname)].append((int(slice_ind), val_log["output"][i]))
259 |                     targets[int(fname)].append((int(slice_ind), val_log["target"][i]))
260 |                     inputs[int(fname)].append((int(slice_ind), val_log["input"][i]))
261 | 
262 |         # handle aggregation for distributed case with pytorch_lightning metrics
263 |         metrics = dict(val_loss=0, nmse=0, ssim=0, psnr=0)
264 |         for fname in outputs:
265 |             output = torch.stack([out for _, out in sorted(outputs[fname])]).numpy()
266 |             target = torch.stack([tgt for _, tgt in sorted(targets[fname])]).numpy()
267 |             inputs[fname] = torch.stack([inn for _, inn in sorted(inputs[fname])]).numpy()  # stack per volume so they can be saved below
268 | 
269 |             metrics["nmse"] = metrics["nmse"] + evaluate.nmse(target, output)
270 |             metrics["ssim"] = metrics["ssim"] + evaluate.ssim(target, output)
271 |             metrics["psnr"] = metrics["psnr"] + evaluate.psnr(target, output)
272 | 
273 |         # currently ddp reduction requires everything on CUDA, so we'll do this manually
274 |         metrics["nmse"] = self.NMSE(torch.tensor(metrics["nmse"]).to(device))
275 |         metrics["ssim"] = self.SSIM(torch.tensor(metrics["ssim"]).to(device))
276 |         metrics["psnr"] = self.PSNR(torch.tensor(metrics["psnr"]).to(device))
277 |         metrics["val_loss"] = self.ValLoss(torch.sum(torch.stack(losses)).to(device))
278 | 
279 |         num_examples = torch.tensor(len(outputs)).to(device)
280 |         tot_examples = self.TotExamples(num_examples)
281 | 
282 |         log_metrics = {
283 |             f"metrics/{metric}": values / tot_examples
284 |             for metric, values in metrics.items()
285 |         }
286 |         metrics = {metric: values / tot_examples for metric, values in metrics.items()}
287 |         print(tot_examples, device, metrics)
288 |         # save the stacked per-volume inputs (same layout test_epoch_end uses)
289 |         fastmri.save_reconstructions(
290 |             {str(fname): vol for fname, vol in inputs.items()},
291 |             self.exp_dir / self.exp_name / "bicubic",
292 |         )
293 | 
294 |         return dict(log=log_metrics, **metrics)
295 | 
296 |     def test_epoch_end(self, test_logs):
297 |         outputs = defaultdict(list)
298 | 
299 |         for log in test_logs:
300 |             for i, (fname, slice_num) in enumerate(zip(log["fname"], log["slice"])):
301 |                 outputs[fname].append((slice_num, log["output"][i]))
302 | 
303 |         for fname in outputs:
304 |             outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])
305 | 
306 |         fastmri.save_reconstructions(
307 |             outputs, self.exp_dir / self.exp_name / "bicubic"
308 |         )
309 | 
310 |         return dict()
311 | 
312 |     @staticmethod
313 |     def add_model_specific_args(parent_parser):  # pragma: no-cover
314 |         """
315 |         Define parameters that only apply to this model
316 |         """
317 |         parser = ArgumentParser(parents=[parent_parser])
318 | 
319 |         # data arguments
320 |         parser.add_argument(
321 |             "--data_path", default=pathlib.Path("Datasets/"), type=pathlib.Path
322 |         )
323 |         parser.add_argument(
324 |             "--challenge",
325 |             choices=["singlecoil", "multicoil"],
326 |             default="singlecoil",
327 |             type=str,
328 |         )
329 |         parser.add_argument(
330 |             "--sample_rate", default=1.0, type=float,
331 |         )
332 |         parser.add_argument(
333 |             "--batch_size", default=1, type=int,
334 |         )
335 |         parser.add_argument(
336 |             "--num_workers", default=4, type=int,
337 |         )
338 |         parser.add_argument(
339 |             "--seed", default=42, type=int,
340 |         )
341 | 
342 |         # logging params
343 |         parser.add_argument(
344 |             "--exp_dir", default=pathlib.Path("logs/"), type=pathlib.Path
345 |         )
346 |         parser.add_argument(
347 |             "--exp_name", default="my_experiment", type=str,
348 |         )
349 |         parser.add_argument(
350 |             "--test_split", default="test", type=str,
351 |         )
352 | 
353 |         return parser
354 | 
--------------------------------------------------------------------------------
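The argument group defined in add_model_specific_args above is meant to be chained onto a shared parent parser by a training script. The sketch below shows that wiring only as an illustration: the parent flag, the parsed values, and the use of the MriModule name (the Lightning module that defines these methods) are assumptions, not code from this repository.

# Hypothetical usage sketch -- not part of the repository.
from argparse import ArgumentParser

# A shared parent parser; add_help=False avoids a duplicate -h when it is
# passed through ArgumentParser(parents=[...]) inside add_model_specific_args.
parent_parser = ArgumentParser(add_help=False)
parent_parser.add_argument("--gpus", default=1, type=int)

parser = MriModule.add_model_specific_args(parent_parser)
args = parser.parse_args(["--challenge", "singlecoil", "--batch_size", "4"])
# exp_dir is a pathlib.Path, so it composes with the experiment name directly.
print(args.data_path, args.exp_dir / args.exp_name, args.num_workers)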
/fastmri/models/MANet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Single_level_densenet(nn.Module): 7 | def __init__(self,filters, num_conv = 4): 8 | super(Single_level_densenet, self).__init__() 9 | self.num_conv = num_conv 10 | self.conv_list = nn.ModuleList() 11 | self.bn_list = nn.ModuleList() 12 | for i in range(self.num_conv): 13 | self.conv_list.append(nn.Conv2d(filters,filters,3, padding = 1)) 14 | self.bn_list.append(nn.BatchNorm2d(filters)) 15 | 16 | def forward(self,x): 17 | outs = [] 18 | outs.append(x) 19 | for i in range(self.num_conv): 20 | temp_out = self.conv_list[i](outs[i]) 21 | if i > 0: 22 | for j in range(i): 23 | temp_out += outs[j] 24 | outs.append(F.relu(self.bn_list[i](temp_out))) 25 | out_final = outs[-1] 26 | del outs 27 | return out_final 28 | 29 | class Down_sample(nn.Module): 30 | def __init__(self,kernel_size = 2, stride = 2): 31 | super(Down_sample, self).__init__() 32 | self.down_sample_layer = nn.MaxPool2d(kernel_size, stride) 33 | 34 | def forward(self,x): 35 | y = self.down_sample_layer(x) 36 | return y,x 37 | 38 | class Upsample_n_Concat_1(nn.Module): 39 | def __init__(self,filters): 40 | super(Upsample_n_Concat_1, self).__init__() 41 | self.upsample_layer = nn.ConvTranspose2d(filters, filters, 4, padding = 1, stride = 2) 42 | self.conv = nn.Conv2d(128,filters,3, padding = 1) 43 | self.bn = nn.BatchNorm2d(filters) 44 | 45 | def forward(self,x,y): 46 | x = self.upsample_layer(x) 47 | x = torch.cat([x,y],dim = 1) 48 | x = F.relu(self.bn(self.conv(x))) 49 | return x 50 | 51 | class Upsample_n_Concat_2(nn.Module): 52 | def __init__(self,filters): 53 | super(Upsample_n_Concat_2, self).__init__() 54 | self.upsample_layer = nn.ConvTranspose2d(64, filters, 4, padding = 1, stride = 2) 55 | self.conv = nn.Conv2d(128,filters,3, padding = 1) 56 | self.bn = nn.BatchNorm2d(filters) 57 | 58 | def forward(self,x,y): 59 | x = self.upsample_layer(x) 60 | x = torch.cat([x,y],dim = 1) 61 | x = F.relu(self.bn(self.conv(x))) 62 | return x 63 | 64 | class Upsample_n_Concat_3(nn.Module): 65 | def __init__(self,filters): 66 | super(Upsample_n_Concat_3, self).__init__() 67 | self.upsample_layer = nn.ConvTranspose2d(64, filters, 4, padding = 1, stride = 2) 68 | self.conv = nn.Conv2d(128,filters,3, padding = 1) 69 | self.bn = nn.BatchNorm2d(filters) 70 | 71 | def forward(self,x,y): 72 | x = self.upsample_layer(x) 73 | x = torch.cat([x,y],dim = 1) 74 | x = F.relu(self.bn(self.conv(x))) 75 | return x 76 | 77 | class Upsample_n_Concat_4(nn.Module): 78 | def __init__(self,filters): 79 | super(Upsample_n_Concat_4, self).__init__() 80 | self.upsample_layer = nn.ConvTranspose2d(64, filters, 4, padding = 1, stride = 2) 81 | self.conv = nn.Conv2d(128,filters,3, padding = 1) 82 | self.bn = nn.BatchNorm2d(filters) 83 | 84 | def forward(self,x,y): 85 | x = self.upsample_layer(x) 86 | x = torch.cat([x,y],dim = 1) 87 | x = F.relu(self.bn(self.conv(x))) 88 | return x 89 | 90 | class Upsample_n_Concat_T1(nn.Module): 91 | def __init__(self,filters): 92 | super(Upsample_n_Concat_T1, self).__init__() 93 | self.upsample_layer = nn.ConvTranspose2d(filters, filters, 4, padding = 1, stride = 2) 94 | self.conv = nn.Conv2d(filters,filters,3, padding = 1) 95 | self.bn = nn.BatchNorm2d(filters) 96 | 97 | def forward(self,x): 98 | x = self.upsample_layer(x) 99 | x = F.relu(self.bn(self.conv(x))) 100 | return x 101 | class ChannelAttention(nn.Module): 102 | def 
__init__(self, in_planes, ratio=16): 103 | super(ChannelAttention, self).__init__() 104 | 105 | self.max_pool = nn.AdaptiveMaxPool2d(1) 106 | 107 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 108 | self.relu1 = nn.ReLU() 109 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 110 | 111 | self.sigmoid = nn.Sigmoid() 112 | def forward(self, x): 113 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 114 | out = max_out 115 | return self.sigmoid(out) 116 | 117 | class SpatialAttention(nn.Module): 118 | def __init__(self, kernel_size=7): 119 | super(SpatialAttention, self).__init__() 120 | 121 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 122 | padding = 3 if kernel_size == 7 else 1 123 | 124 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False) 125 | self.sigmoid = nn.Sigmoid() 126 | 127 | def forward(self, x): 128 | max_out, _ = torch.max(x, dim=1, keepdim=True) 129 | x=max_out 130 | x = self.conv1(x) 131 | return self.sigmoid(x) 132 | 133 | class Dense_Unet_k(nn.Module): 134 | def __init__(self, in_chan, out_chan, filters, num_conv = 4): #64 256 135 | super(Dense_Unet_k, self).__init__() 136 | self.conv1T1 = nn.Conv2d(in_chan,filters,1) 137 | self.conv1T2 = nn.Conv2d(4,filters,1) 138 | self.convdemD0 = nn.Conv2d(64,64,kernel_size=3,padding=1) 139 | self.convdemD1 = nn.Conv2d(64,64,kernel_size=3,padding=1) 140 | self.convdemD2 = nn.Conv2d(64,64,kernel_size=3,padding=1) 141 | self.convdemD3 = nn.Conv2d(64,64,kernel_size=3,padding=1) 142 | self.convdemU0 = nn.Conv2d(64,64,kernel_size=3,padding=1) 143 | self.convdemU1 = nn.Conv2d(64,64,kernel_size=3,padding=1) 144 | self.convdemU2 = nn.Conv2d(64,64,kernel_size=3,padding=1) 145 | self.convdemU3 = nn.Conv2d(64,64,kernel_size=3,padding=1) 146 | 147 | 148 | self.dT1_1 = Single_level_densenet(filters,num_conv ) 149 | self.downT1_1 = Down_sample() 150 | self.dT1_2 = Single_level_densenet(filters,num_conv ) 151 | self.downT1_2 = Down_sample() 152 | self.dT1_3 = Single_level_densenet(filters,num_conv ) 153 | self.downT1_3 = Down_sample() 154 | self.dT1_4 = Single_level_densenet(filters,num_conv ) 155 | self.downT1_4 = Down_sample() 156 | 157 | self.dT2_1 = Single_level_densenet(filters,num_conv ) 158 | self.downT2_1 = Down_sample() 159 | self.dT2_2 = Single_level_densenet(filters,num_conv ) 160 | self.downT2_2 = Down_sample() 161 | self.dT2_3 = Single_level_densenet(filters,num_conv ) 162 | self.downT2_3 = Down_sample() 163 | self.dT2_4 = Single_level_densenet(filters,num_conv ) 164 | self.downT2_4 = Down_sample() 165 | 166 | self.bottom_T1 = Single_level_densenet(filters,num_conv ) 167 | self.bottom_T2 = Single_level_densenet(filters,num_conv ) 168 | 169 | self.up4_T1 = Upsample_n_Concat_T1(filters) 170 | self.u4_T1 = Single_level_densenet(filters,num_conv ) 171 | self.up3_T1 = Upsample_n_Concat_T1(filters) 172 | self.u3_T1 = Single_level_densenet(filters,num_conv ) 173 | self.up2_T1 = Upsample_n_Concat_T1(filters) 174 | self.u2_T1 = Single_level_densenet(filters,num_conv ) 175 | self.up1_T1 = Upsample_n_Concat_T1(filters) 176 | self.u1_T1 = Single_level_densenet(filters,num_conv ) 177 | 178 | self.up4_T2 = Upsample_n_Concat_1(filters) 179 | self.u4_T2 = Single_level_densenet(filters,num_conv ) 180 | self.up3_T2 = Upsample_n_Concat_2(filters) 181 | self.u3_T2 = Single_level_densenet(filters,num_conv ) 182 | self.up2_T2 = Upsample_n_Concat_3(filters) 183 | self.u2_T2 = Single_level_densenet(filters,num_conv ) 184 | self.up1_T2 = Upsample_n_Concat_4(filters) 185 | self.u1_T2 = 
Single_level_densenet(filters,num_conv ) 186 | 187 | self.outconvT1 = nn.Conv2d(filters,out_chan, 1) 188 | self.outconvT2 = nn.Conv2d(64,out_chan, 1) 189 | 190 | self.atten_depth_channel_0=ChannelAttention(64) 191 | self.atten_depth_channel_1=ChannelAttention(64) 192 | self.atten_depth_channel_2=ChannelAttention(64) 193 | self.atten_depth_channel_3=ChannelAttention(64) 194 | 195 | self.atten_depth_channel_U_0=ChannelAttention(64) 196 | self.atten_depth_channel_U_1=ChannelAttention(64) 197 | self.atten_depth_channel_U_2=ChannelAttention(64) 198 | self.atten_depth_channel_U_3=ChannelAttention(64) 199 | 200 | self.atten_depth_spatial_0=SpatialAttention() 201 | self.atten_depth_spatial_1=SpatialAttention() 202 | self.atten_depth_spatial_2=SpatialAttention() 203 | self.atten_depth_spatial_3=SpatialAttention() 204 | 205 | self.atten_depth_spatial_U_0=SpatialAttention() 206 | self.atten_depth_spatial_U_1=SpatialAttention() 207 | self.atten_depth_spatial_U_2=SpatialAttention() 208 | self.atten_depth_spatial_U_3=SpatialAttention() 209 | 210 | 211 | 212 | 213 | 214 | def forward(self,T1, T2): 215 | 216 | T1_x1 = self.conv1T1(T1) 217 | T2 = torch.cat((T2,T1),dim=1) 218 | T2_x1 = self.conv1T2(T2) 219 | 220 | 221 | T1_x2,T1_y1 = self.downT1_1(self.dT1_1(T1_x1)) 222 | T2_x2,T2_y1 = self.downT2_1(self.dT2_1(T2_x1)) 223 | temp = T1_x2.mul(self.atten_depth_channel_0(T1_x2)) 224 | temp = temp.mul(self.atten_depth_spatial_0(temp)) 225 | T12_x2 = T2_x2.mul(temp)+T2_x2 226 | 227 | 228 | T1_x3,T1_y2 = self.downT1_1(self.dT1_2(T1_x2)) 229 | T2_x3,T2_y2 = self.downT2_1(self.dT2_2(T12_x2)) 230 | temp = T1_x3.mul(self.atten_depth_channel_1(T1_x3)) 231 | temp = temp.mul(self.atten_depth_spatial_1(temp)) 232 | T12_x3 = T2_x3.mul(temp)+T2_x3 233 | 234 | 235 | 236 | T1_x4,T1_y3 = self.downT1_1(self.dT1_3(T1_x3)) 237 | T2_x4,T2_y3 = self.downT2_1(self.dT2_3(T12_x3)) 238 | temp = T1_x4.mul(self.atten_depth_channel_2(T1_x4)) 239 | temp = temp.mul(self.atten_depth_spatial_2(temp)) 240 | T12_x4 = T2_x4.mul(temp)+T2_x4 241 | 242 | 243 | 244 | 245 | T1_x5,T1_y4 = self.downT1_1(self.dT1_4(T1_x4)) 246 | T2_x5,T2_y4 = self.downT2_1(self.dT2_4(T12_x4)) 247 | temp = T1_x5.mul(self.atten_depth_channel_3(T1_x5)) 248 | temp = temp.mul(self.atten_depth_spatial_3(temp)) 249 | T12_x5 = T2_x5.mul(temp)+T2_x5 250 | 251 | 252 | T1_x = self.bottom_T1(T1_x5) 253 | T2_x = self.bottom_T2(T12_x5) 254 | 255 | 256 | T1_1x = self.u4_T1(self.up4_T1(T1_x)) 257 | T2_1x = self.u4_T2(self.up4_T2(T2_x,T2_y4)) 258 | temp = T1_1x.mul(self.atten_depth_channel_U_0(T1_1x)) 259 | temp = temp.mul(self.atten_depth_spatial_U_0(temp)) 260 | T12_x = T2_1x.mul(temp)+T2_1x 261 | 262 | 263 | 264 | T1_2x = self.u3_T1(self.up3_T1(T1_1x)) 265 | T2_2x = self.u3_T2(self.up3_T2(T12_x,T2_y3)) 266 | temp = T1_2x.mul(self.atten_depth_channel_U_1(T1_2x)) 267 | temp = temp.mul(self.atten_depth_spatial_U_1(temp)) 268 | T12_x = T2_2x.mul(temp)+T2_2x 269 | 270 | T1_3x = self.u2_T1(self.up2_T1(T1_2x)) 271 | T2_3x = self.u2_T2(self.up2_T2(T12_x,T2_y2)) 272 | temp = T1_3x.mul(self.atten_depth_channel_U_2(T1_3x)) 273 | temp = temp.mul(self.atten_depth_spatial_U_2(temp)) 274 | T12_x = T2_3x.mul(temp)+T2_3x 275 | 276 | T1_4x = self.u1_T1(self.up1_T1(T1_3x)) 277 | T2_4x = self.u1_T2(self.up1_T2(T12_x,T2_y1)) 278 | temp = T1_4x.mul(self.atten_depth_channel_U_3(T1_4x)) 279 | temp = temp.mul(self.atten_depth_spatial_U_3(temp)) 280 | T12_x = T2_4x.mul(temp)+T2_4x 281 | 282 | T1 = self.outconvT1(T1_4x) 283 | T2 = self.outconvT2(T12_x) 284 | 285 | return T1,T2 286 | 287 | 288 | class 
Dense_Unet_img(nn.Module): 289 | def __init__(self, in_chan, out_chan, filters, num_conv = 4): 290 | super(Dense_Unet_img, self).__init__() 291 | self.conv1T1 = nn.Conv2d(in_chan,filters,1) 292 | self.conv1T2 = nn.Conv2d(2,filters,1) 293 | 294 | self.dT1_1 = Single_level_densenet_img(filters,num_conv ) 295 | self.downT1_1 = Down_sample_img() 296 | self.dT1_2 = Single_level_densenet_img(filters,num_conv ) 297 | self.downT1_2 = Down_sample_img() 298 | self.dT1_3 = Single_level_densenet_img(filters,num_conv ) 299 | self.downT1_3 = Down_sample_img() 300 | self.dT1_4 = Single_level_densenet_img(filters,num_conv ) 301 | self.downT1_4 = Down_sample_img() 302 | 303 | self.dT2_1 = Single_level_densenet_img(filters,num_conv ) 304 | self.downT2_1 = Down_sample_img() 305 | self.dT2_2 = Single_level_densenet_img(filters,num_conv ) 306 | self.downT2_2 = Down_sample_img() 307 | self.dT2_3 = Single_level_densenet_img(filters,num_conv ) 308 | self.downT2_3 = Down_sample_img() 309 | self.dT2_4 = Single_level_densenet_img(filters,num_conv ) 310 | self.downT2_4 = Down_sample_img() 311 | 312 | self.bottom_T1 = Single_level_densenet_img(filters,num_conv ) 313 | self.bottom_T2 = Single_level_densenet_img(filters,num_conv ) 314 | 315 | self.up4_T1 = Upsample_n_Concat_T1_img(filters) 316 | self.u4_T1 = Single_level_densenet_img(filters,num_conv ) 317 | self.up3_T1 = Upsample_n_Concat_T1_img(filters) 318 | self.u3_T1 = Single_level_densenet_img(filters,num_conv ) 319 | self.up2_T1 = Upsample_n_Concat_T1_img(filters) 320 | self.u2_T1 = Single_level_densenet_img(filters,num_conv ) 321 | self.up1_T1 = Upsample_n_Concat_T1_img(filters) 322 | self.u1_T1 = Single_level_densenet_img(filters,num_conv ) 323 | 324 | self.up4_T2 = Upsample_n_Concat_1_img(filters) 325 | self.u4_T2 = Single_level_densenet_img(filters,num_conv ) 326 | self.up3_T2 = Upsample_n_Concat_2_img(filters) 327 | self.u3_T2 = Single_level_densenet_img(filters,num_conv ) 328 | self.up2_T2 = Upsample_n_Concat_3_img(filters) 329 | self.u2_T2 = Single_level_densenet_img(filters,num_conv ) 330 | self.up1_T2 = Upsample_n_Concat_4_img(filters) 331 | self.u1_T2 = Single_level_densenet_img(filters,num_conv ) 332 | 333 | self.outconvT1 = nn.Conv2d(filters,out_chan, 1) 334 | self.outconvT2 = nn.Conv2d(64,out_chan, 1) 335 | #Components of DEM module 336 | self.atten_depth_channel_0=ChannelAttention_img(64) 337 | self.atten_depth_channel_1=ChannelAttention_img(64) 338 | self.atten_depth_channel_2=ChannelAttention_img(64) 339 | self.atten_depth_channel_3=ChannelAttention_img(64) 340 | 341 | self.atten_depth_channel_U_0=ChannelAttention_img(64) 342 | self.atten_depth_channel_U_1=ChannelAttention_img(64) 343 | self.atten_depth_channel_U_2=ChannelAttention_img(64) 344 | self.atten_depth_channel_U_3=ChannelAttention_img(64) 345 | 346 | self.atten_depth_spatial_0=SpatialAttention_img() 347 | self.atten_depth_spatial_1=SpatialAttention_img() 348 | self.atten_depth_spatial_2=SpatialAttention_img() 349 | self.atten_depth_spatial_3=SpatialAttention_img() 350 | 351 | self.atten_depth_spatial_U_0=SpatialAttention_img() 352 | self.atten_depth_spatial_U_1=SpatialAttention_img() 353 | self.atten_depth_spatial_U_2=SpatialAttention_img() 354 | self.atten_depth_spatial_U_3=SpatialAttention_img() 355 | 356 | 357 | 358 | 359 | 360 | def forward(self,T1, T2): 361 | 362 | T1_x1 = self.conv1T1(T1) 363 | T2 = torch.cat((T2,T1),dim=1) 364 | T2_x1 = self.conv1T2(T2) 365 | 366 | T1_x2,T1_y1 = self.downT1_1(self.dT1_1(T1_x1)) 367 | T2_x2,T2_y1 = self.downT2_1(self.dT2_1(T2_x1)) 368 | temp = 
T1_x2.mul(self.atten_depth_channel_0(T1_x2)) 369 | temp = temp.mul(self.atten_depth_spatial_0(temp)) 370 | T12_x2 = T2_x2.mul(temp)+T2_x2 371 | 372 | 373 | T1_x3,T1_y2 = self.downT1_1(self.dT1_2(T1_x2)) 374 | T2_x3,T2_y2 = self.downT2_1(self.dT2_2(T12_x2)) 375 | temp = T1_x3.mul(self.atten_depth_channel_1(T1_x3)) 376 | temp = temp.mul(self.atten_depth_spatial_1(temp)) 377 | T12_x3 = T2_x3.mul(temp)+T2_x3 378 | 379 | T1_x4,T1_y3 = self.downT1_1(self.dT1_3(T1_x3)) 380 | T2_x4,T2_y3 = self.downT2_1(self.dT2_3(T12_x3)) 381 | temp = T1_x4.mul(self.atten_depth_channel_2(T1_x4)) 382 | temp = temp.mul(self.atten_depth_spatial_2(temp)) 383 | T12_x4 = T2_x4.mul(temp)+T2_x4 384 | 385 | 386 | T1_x5,T1_y4 = self.downT1_1(self.dT1_4(T1_x4)) 387 | T2_x5,T2_y4 = self.downT2_1(self.dT2_4(T12_x4)) 388 | temp = T1_x5.mul(self.atten_depth_channel_3(T1_x5)) 389 | temp = temp.mul(self.atten_depth_spatial_3(temp)) 390 | T12_x5 = T2_x5.mul(temp)+T2_x5 391 | 392 | 393 | T1_x = self.bottom_T1(T1_x5) 394 | T2_x = self.bottom_T2(T12_x5) 395 | 396 | 397 | T1_1x = self.u4_T1(self.up4_T1(T1_x)) 398 | T2_1x = self.u4_T2(self.up4_T2(T2_x,T2_y4)) 399 | temp = T1_1x.mul(self.atten_depth_channel_U_0(T1_1x)) 400 | temp = temp.mul(self.atten_depth_spatial_U_0(temp)) 401 | T12_x = T2_1x.mul(temp)+T2_1x 402 | 403 | T1_2x = self.u3_T1(self.up3_T1(T1_1x)) 404 | T2_2x = self.u3_T2(self.up3_T2(T12_x,T2_y3)) 405 | temp = T1_2x.mul(self.atten_depth_channel_U_1(T1_2x)) 406 | temp = temp.mul(self.atten_depth_spatial_U_1(temp)) 407 | T12_x = T2_2x.mul(temp)+T2_2x 408 | 409 | 410 | T1_3x = self.u2_T1(self.up2_T1(T1_2x)) 411 | T2_3x = self.u2_T2(self.up2_T2(T12_x,T2_y2)) 412 | temp = T1_3x.mul(self.atten_depth_channel_U_2(T1_3x)) 413 | temp = temp.mul(self.atten_depth_spatial_U_2(temp)) 414 | T12_x = T2_3x.mul(temp)+T2_3x 415 | 416 | T1_4x = self.u1_T1(self.up1_T1(T1_3x)) 417 | T2_4x = self.u1_T2(self.up1_T2(T12_x,T2_y1)) 418 | temp = T1_4x.mul(self.atten_depth_channel_U_3(T1_4x)) 419 | temp = temp.mul(self.atten_depth_spatial_U_3(temp)) 420 | T12_x = T2_4x.mul(temp)+T2_4x 421 | 422 | T1 = self.outconvT1(T1_4x) 423 | T2 = self.outconvT2(T12_x) 424 | 425 | return T1,T2 426 | 427 | class Single_level_densenet_img(nn.Module): 428 | def __init__(self,filters, num_conv = 4): 429 | super(Single_level_densenet_img, self).__init__() 430 | self.num_conv = num_conv 431 | self.conv_list = nn.ModuleList() 432 | self.bn_list = nn.ModuleList() 433 | for i in range(self.num_conv): 434 | self.conv_list.append(nn.Conv2d(filters,filters,3, padding = 1)) 435 | self.bn_list.append(nn.BatchNorm2d(filters)) 436 | 437 | def forward(self,x): 438 | outs = [] 439 | outs.append(x) 440 | for i in range(self.num_conv): 441 | temp_out = self.conv_list[i](outs[i]) 442 | if i > 0: 443 | for j in range(i): 444 | temp_out += outs[j] 445 | outs.append(F.relu(self.bn_list[i](temp_out))) 446 | out_final = outs[-1] 447 | del outs 448 | return out_final 449 | 450 | class Down_sample_img(nn.Module): 451 | def __init__(self,kernel_size = 2, stride = 2): 452 | super(Down_sample_img, self).__init__() 453 | self.down_sample_layer = nn.MaxPool2d(kernel_size, stride) 454 | 455 | def forward(self,x): 456 | y = self.down_sample_layer(x) 457 | return y,x 458 | 459 | class Upsample_n_Concat_1_img(nn.Module): 460 | def __init__(self,filters): 461 | super(Upsample_n_Concat_1_img, self).__init__() 462 | self.upsample_layer = nn.ConvTranspose2d(filters, filters, 4, padding = 1, stride = 2) 463 | self.conv = nn.Conv2d(filters*2,filters,3, padding = 1) 464 | self.bn = nn.BatchNorm2d(filters) 
465 | 466 | def forward(self,x,y): 467 | x = self.upsample_layer(x) 468 | x = torch.cat([x,y],dim = 1) 469 | x = F.relu(self.bn(self.conv(x))) 470 | return x 471 | 472 | class Upsample_n_Concat_2_img(nn.Module): 473 | def __init__(self,filters): 474 | super(Upsample_n_Concat_2_img, self).__init__() 475 | self.upsample_layer = nn.ConvTranspose2d(64, filters, 4, padding = 1, stride = 2) 476 | self.conv = nn.Conv2d(128,filters,3, padding = 1) 477 | self.bn = nn.BatchNorm2d(filters) 478 | 479 | def forward(self,x,y): 480 | x = self.upsample_layer(x) 481 | x = torch.cat([x,y],dim = 1) 482 | x = F.relu(self.bn(self.conv(x))) 483 | return x 484 | 485 | class Upsample_n_Concat_3_img(nn.Module): 486 | def __init__(self,filters): 487 | super(Upsample_n_Concat_3_img, self).__init__() 488 | self.upsample_layer = nn.ConvTranspose2d(64, filters, 4, padding = 1, stride = 2) 489 | self.conv = nn.Conv2d(128,filters,3, padding = 1) 490 | self.bn = nn.BatchNorm2d(filters) 491 | 492 | def forward(self,x,y): 493 | x = self.upsample_layer(x) 494 | x = torch.cat([x,y],dim = 1) 495 | x = F.relu(self.bn(self.conv(x))) 496 | return x 497 | 498 | class Upsample_n_Concat_4_img(nn.Module): 499 | def __init__(self,filters): 500 | super(Upsample_n_Concat_4_img, self).__init__() 501 | self.upsample_layer = nn.ConvTranspose2d(64, filters, 4, padding = 1, stride = 2) 502 | self.conv = nn.Conv2d(128,filters,3, padding = 1) 503 | self.bn = nn.BatchNorm2d(filters) 504 | 505 | def forward(self,x,y): 506 | x = self.upsample_layer(x) 507 | x = torch.cat([x,y],dim = 1) 508 | x = F.relu(self.bn(self.conv(x))) 509 | return x 510 | 511 | class Upsample_n_Concat_T1_img(nn.Module): 512 | def __init__(self,filters): 513 | super(Upsample_n_Concat_T1_img, self).__init__() 514 | self.upsample_layer = nn.ConvTranspose2d(filters, filters, 4, padding = 1, stride = 2) 515 | self.conv = nn.Conv2d(filters,filters,3, padding = 1) 516 | self.bn = nn.BatchNorm2d(filters) 517 | 518 | def forward(self,x): 519 | x = self.upsample_layer(x) 520 | x = F.relu(self.bn(self.conv(x))) 521 | return x 522 | class ChannelAttention_img(nn.Module): 523 | def __init__(self, in_planes, ratio=16): 524 | super(ChannelAttention_img, self).__init__() 525 | 526 | self.max_pool = nn.AdaptiveMaxPool2d(1) 527 | 528 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 529 | self.relu1 = nn.ReLU() 530 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 531 | 532 | self.sigmoid = nn.Sigmoid() 533 | def forward(self, x): 534 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 535 | out = max_out 536 | return self.sigmoid(out) 537 | 538 | class SpatialAttention_img(nn.Module): 539 | def __init__(self, kernel_size=7): 540 | super(SpatialAttention_img, self).__init__() 541 | 542 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 543 | padding = 3 if kernel_size == 7 else 1 544 | 545 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False) 546 | self.sigmoid = nn.Sigmoid() 547 | 548 | def forward(self, x): 549 | max_out, _ = torch.max(x, dim=1, keepdim=True) 550 | x=max_out 551 | x = self.conv1(x) 552 | return self.sigmoid(x) 553 | --------------------------------------------------------------------------------
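For orientation, here is a minimal, hypothetical shape check for the two networks defined in /fastmri/models/MANet.py above. The channel and size choices are assumptions forced by the hard-coded layers: filters must be 64 (the attention modules and several Conv2d layers fix 64 channels), the concatenated T1/T2 stems expect 4 input channels for Dense_Unet_k and 2 for Dense_Unet_img, and H and W must be divisible by 16 to survive the four pooling stages. The T1/T2 data semantics in the comments follow the variable names and are likewise assumptions.

# Hypothetical smoke test -- not part of the repository.
import torch

from fastmri.models.MANet import Dense_Unet_k, Dense_Unet_img

# k-space branch: T1 and T2 each contribute 2 channels (assumed real/imag layout),
# matching the hard-coded Conv2d(4, filters, 1) stem after concatenation.
kspace_net = Dense_Unet_k(in_chan=2, out_chan=2, filters=64)
t1_k = torch.randn(1, 2, 256, 256)   # guidance T1 k-space (assumed layout)
t2_k = torch.randn(1, 2, 256, 256)   # undersampled T2 k-space (assumed layout)
t1_k_out, t2_k_out = kspace_net(t1_k, t2_k)

# image branch: single-channel magnitude images, matching Conv2d(2, filters, 1).
image_net = Dense_Unet_img(in_chan=1, out_chan=1, filters=64)
t1_im = torch.randn(1, 1, 256, 256)  # T1 image
t2_im = torch.randn(1, 1, 256, 256)  # zero-filled T2 image
t1_im_out, t2_im_out = image_net(t1_im, t2_im)

print(t2_k_out.shape)   # torch.Size([1, 2, 256, 256])
print(t2_im_out.shape)  # torch.Size([1, 1, 256, 256])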