├── ConvDecoder_architecture
│   └── ConvDecoder.JPG
├── ConvDecoder_for_MRI.ipynb
├── ConvDecoder_vs_DIP_vs_DD_multicoil.ipynb
├── ConvDecoder_vs_Unet_multicoil.ipynb
├── DIP_UNET_models
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── common.cpython-36.pyc
│   │   ├── common.cpython-37.pyc
│   │   ├── downsampler.cpython-36.pyc
│   │   ├── mri_model.cpython-36.pyc
│   │   ├── resnet.cpython-36.pyc
│   │   ├── skip.cpython-36.pyc
│   │   ├── skip.cpython-37.pyc
│   │   ├── skip_decoder.cpython-36.pyc
│   │   ├── texture_nets.cpython-36.pyc
│   │   └── unet.cpython-36.pyc
│   ├── common.py
│   ├── skip.py
│   └── unet_and_tv
│       ├── .ipynb_checkpoints
│       │   └── run_bart-checkpoint.ipynb
│       ├── README.md
│       ├── __pycache__
│       │   ├── mri_model.cpython-36.pyc
│       │   ├── mri_model.cpython-37.pyc
│       │   ├── train_unet.cpython-36.pyc
│       │   ├── train_unet.cpython-37.pyc
│       │   ├── unet_model.cpython-36.pyc
│       │   └── unet_model.cpython-37.pyc
│       ├── common
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   ├── __init__.cpython-37.pyc
│       │   │   ├── args.cpython-36.pyc
│       │   │   ├── args.cpython-37.pyc
│       │   │   ├── evaluate.cpython-36.pyc
│       │   │   ├── evaluate.cpython-37.pyc
│       │   │   ├── subsample.cpython-36.pyc
│       │   │   ├── subsample.cpython-37.pyc
│       │   │   ├── utils.cpython-36.pyc
│       │   │   └── utils.cpython-37.pyc
│       │   ├── args.py
│       │   ├── evaluate.py
│       │   ├── subsample.py
│       │   ├── test_subsample.py
│       │   └── utils.py
│       ├── data
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   ├── __init__.cpython-37.pyc
│       │   │   ├── mri_data.cpython-36.pyc
│       │   │   ├── mri_data.cpython-37.pyc
│       │   │   ├── transforms.cpython-36.pyc
│       │   │   └── transforms.cpython-37.pyc
│       │   ├── mri_data.py
│       │   ├── test_transforms.py
│       │   └── transforms.py
│       ├── mri_model.py
│       ├── run_bart.ipynb
│       ├── run_bart_val.py
│       ├── train_unet.py
│       ├── unet_model.py
│       └── uuu.zip
├── LICENSE
├── README.md
├── UNET_trained
│   └── epoch=49.ckpt
├── common
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── args.cpython-36.pyc
│   │   ├── args.cpython-37.pyc
│   │   ├── evaluate.cpython-36.pyc
│   │   ├── evaluate.cpython-37.pyc
│   │   ├── subsample.cpython-36.pyc
│   │   ├── subsample.cpython-37.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-37.pyc
│   ├── args.py
│   ├── evaluate.py
│   ├── subsample.py
│   ├── test_subsample.py
│   └── utils.py
├── demo_helper
│   └── helpers.py
├── include
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── compression.cpython-36.pyc
│   │   ├── decoder.cpython-36.pyc
│   │   ├── decoder_conv.cpython-36.pyc
│   │   ├── decoder_conv.cpython-37.pyc
│   │   ├── decoder_parallel.cpython-36.pyc
│   │   ├── decoder_parallel2.cpython-36.pyc
│   │   ├── decoder_parallel_conv.cpython-36.pyc
│   │   ├── decoder_parallel_conv.cpython-37.pyc
│   │   ├── decoder_skip.cpython-36.pyc
│   │   ├── decoder_skip.cpython-37.pyc
│   │   ├── decoder_skip2.cpython-36.pyc
│   │   ├── fit.cpython-36.pyc
│   │   ├── fit.cpython-37.pyc
│   │   ├── helpers.cpython-36.pyc
│   │   ├── helpers.cpython-37.pyc
│   │   ├── mri_helpers.cpython-36.pyc
│   │   ├── mri_helpers.cpython-37.pyc
│   │   ├── transforms.cpython-36.pyc
│   │   ├── transforms.cpython-37.pyc
│   │   ├── visualize.cpython-36.pyc
│   │   └── wavelet.cpython-36.pyc
│   ├── decoder_conv.py
│   ├── decoder_parallel_conv.py
│   ├── decoder_skip.py
│   ├── fit.py
│   ├── helpers.py
│   ├── mri_helpers.py
│   ├── onedim.py
│   ├── pytorch_ssim
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       └── __init__.cpython-36.pyc
│   └── transforms.py
├── out_of_distribution_image
│   └── cameraman.png
├── robustness_to_distribution_shift.ipynb
└── visualize_layers_singlecoil.ipynb

/ConvDecoder_architecture/ConvDecoder.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/ConvDecoder_architecture/ConvDecoder.JPG -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/downsampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/downsampler.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/mri_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/mri_model.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/skip.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/skip.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/skip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/skip.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/skip_decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/skip_decoder.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/texture_nets.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/texture_nets.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/__pycache__/unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/unet.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | #from .downsampler import Downsampler 5 | 6 | def add_module(self, module): 7 | self.add_module(str(len(self) + 1), module) 8 | 9 | torch.nn.Module.add = add_module 10 | 11 | class Concat(nn.Module): 12 | def __init__(self, dim, *args): 13 | super(Concat, self).__init__() 14 | self.dim = dim 15 | 16 | for idx, module in enumerate(args): 17 | self.add_module(str(idx), module) 18 | 19 | def forward(self, input): 20 | inputs = [] 21 | for module in self._modules.values(): 22 | inputs.append(module(input)) 23 | 24 | inputs_shapes2 = [x.shape[2] for x in inputs] 25 | inputs_shapes3 = [x.shape[3] for x in inputs] 26 | 27 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)): 28 | inputs_ = inputs 29 | else: 30 | target_shape2 = min(inputs_shapes2) 31 | target_shape3 = min(inputs_shapes3) 32 | 33 | inputs_ = [] 34 | for inp in inputs: 35 | diff2 = (inp.size(2) - target_shape2) // 2 36 | diff3 = (inp.size(3) - target_shape3) // 2 37 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]) 38 | 39 | return torch.cat(inputs_, dim=self.dim) 40 | 41 | def __len__(self): 42 | return len(self._modules) 43 | 44 | 45 | class GenNoise(nn.Module): 46 | def __init__(self, dim2): 47 | super(GenNoise, self).__init__() 48 | self.dim2 = dim2 49 | 50 | def forward(self, input): 51 | a = list(input.size()) 52 | a[1] = self.dim2 53 | # print (input.data.type()) 54 | 55 | b = torch.zeros(a).type_as(input.data) 56 | b.normal_() 57 | 58 | x = torch.autograd.Variable(b) 59 | 60 | return x 61 | 62 | 63 | class Swish(nn.Module): 64 | """ 65 | https://arxiv.org/abs/1710.05941 66 | The hype was so huge that I could not help but try it 67 | """ 68 | def __init__(self): 69 | super(Swish, self).__init__() 70 | self.s = nn.Sigmoid() 71 | 72 | def forward(self, x): 73 | return x * self.s(x) 74 | 75 | 76 | def act(act_fun = 'LeakyReLU'): 77 | ''' 78 | Either string defining an activation function or module (e.g. 
nn.ReLU) 79 | ''' 80 | if isinstance(act_fun, str): 81 | if act_fun == 'LeakyReLU': 82 | return nn.LeakyReLU(0.2, inplace=True) 83 | elif act_fun == 'ReLU': 84 | return nn.ReLU() 85 | elif act_fun == 'Swish': 86 | return Swish() 87 | elif act_fun == 'ELU': 88 | return nn.ELU() 89 | elif act_fun == 'none': 90 | return nn.Sequential() 91 | else: 92 | assert False 93 | else: 94 | return act_fun() 95 | 96 | 97 | def bn(num_features): 98 | return nn.BatchNorm2d(num_features) 99 | 100 | 101 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'): 102 | downsampler = None 103 | if stride != 1 and downsample_mode != 'stride': 104 | 105 | if downsample_mode == 'avg': 106 | downsampler = nn.AvgPool2d(stride, stride) 107 | elif downsample_mode == 'max': 108 | downsampler = nn.MaxPool2d(stride, stride) 109 | elif downsample_mode in ['lanczos2', 'lanczos3']: 110 | downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True) 111 | else: 112 | assert False 113 | 114 | stride = 1 115 | 116 | padder = None 117 | to_pad = int((kernel_size - 1) / 2) 118 | if pad == 'reflection': 119 | padder = nn.ReflectionPad2d(to_pad) 120 | to_pad = 0 121 | 122 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 123 | 124 | 125 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 126 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /DIP_UNET_models/skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | 5 | def skip( 6 | out_size,num_input_channels=2, num_output_channels=3, 7 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 8 | filter_size_down=3, filter_size_up=3, filter_skip_size=1, 9 | need_sigmoid=True, need_bias=True, 10 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 11 | need1x1_up=True): 12 | """Assembles encoder-decoder with skip connections. 13 | 14 | Arguments: 15 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. 
nn.ReLU) 16 | pad (string): zero|reflection (default: 'zero') 17 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 18 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 19 | 20 | """ 21 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 22 | 23 | n_scales = len(num_channels_down) 24 | 25 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 26 | upsample_mode = [upsample_mode]*n_scales 27 | 28 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 29 | downsample_mode = [downsample_mode]*n_scales 30 | 31 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 32 | filter_size_down = [filter_size_down]*n_scales 33 | 34 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 35 | filter_size_up = [filter_size_up]*n_scales 36 | 37 | last_scale = n_scales - 1 38 | 39 | cur_depth = None 40 | 41 | model = nn.Sequential() 42 | model_tmp = model 43 | 44 | input_depth = num_input_channels 45 | for i in range(len(num_channels_down)): 46 | 47 | deeper = nn.Sequential() 48 | skip = nn.Sequential() 49 | 50 | if num_channels_skip[i] != 0: 51 | model_tmp.add(Concat(1, skip, deeper)) 52 | else: 53 | model_tmp.add(deeper) 54 | 55 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 56 | 57 | if num_channels_skip[i] != 0: 58 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 59 | skip.add(bn(num_channels_skip[i])) 60 | skip.add(act(act_fun)) 61 | 62 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 63 | 64 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 65 | deeper.add(bn(num_channels_down[i])) 66 | deeper.add(act(act_fun)) 67 | 68 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 69 | deeper.add(bn(num_channels_down[i])) 70 | deeper.add(act(act_fun)) 71 | 72 | deeper_main = nn.Sequential() 73 | 74 | if i == len(num_channels_down) - 1: 75 | # The deepest 76 | k = num_channels_down[i] 77 | else: 78 | deeper.add(deeper_main) 79 | k = num_channels_up[i + 1] 80 | 81 | if i == 0: 82 | deeper.add(nn.Upsample(size=out_size, mode=upsample_mode[i])) 83 | else: 84 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 85 | 86 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 87 | model_tmp.add(bn(num_channels_up[i])) 88 | model_tmp.add(act(act_fun)) 89 | 90 | 91 | if need1x1_up: 92 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 93 | model_tmp.add(bn(num_channels_up[i])) 94 | model_tmp.add(act(act_fun)) 95 | 96 | input_depth = num_channels_down[i] 97 | model_tmp = deeper_main 98 | 99 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 100 | if need_sigmoid: 101 | model.add(nn.Sigmoid()) 102 | 103 | return model 104 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/README.md: -------------------------------------------------------------------------------- 1 | ## U-Net Model for MRI Reconstruction 2 | 3 | This directory contains a reference U-Net implementation for MRI reconstruction 4 | in PyTorch. 
5 | 6 | To start training the model, run: 7 | ```bash 8 | python models/unet/train_unet.py --mode train --challenge CHALLENGE --data-path DATA --exp unet --mask-type MASK_TYPE 9 | ``` 10 | where `CHALLENGE` is either `singlecoil` or `multicoil`, and `MASK_TYPE` is either `random` (for knee) 11 | or `equispaced` (for brain). Training logs and checkpoints are saved in the `experiments/unet` directory. 12 | 13 | To run the model on test data: 14 | ```bash 15 | python models/unet/train_unet.py --mode test --challenge CHALLENGE --data-path DATA --exp unet --out-dir reconstructions --checkpoint MODEL 16 | ``` 17 | where `MODEL` is the path to the model checkpoint from `experiments/unet/version_0/checkpoints/`. 18 | 19 | The outputs will be saved to the `reconstructions` directory, which can be uploaded for submission. 20 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/args.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | 11 | 12 | class Args(argparse.ArgumentParser): 13 | """ 14 | Defines global default arguments. 15 | """ 16 | 17 | def __init__(self, **overrides): 18 | """ 19 | Args: 20 | **overrides (dict, optional): Keyword arguments used to override default argument values 21 | """ 22 | 23 | super().__init__(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | 25 | self.add_argument('--seed', default=42, type=int, help='Seed for random number generators') 26 | self.add_argument('--resolution', default=320, type=int, help='Resolution of images') 27 | 28 | # Data parameters 29 | self.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True, 30 | help='Which challenge') 31 | self.add_argument('--data-path', type=pathlib.Path, required=True, 32 | help='Path to the dataset') 33 | self.add_argument('--sample-rate', type=float, default=1., 34 | help='Fraction of total volumes to include') 35 | 36 | # Mask parameters 37 | self.add_argument('--accelerations', nargs='+', default=[4, 8], type=int, 38 | help='Ratio of k-space columns to be sampled. If multiple values are ' 39 | 'provided, then one of those is chosen uniformly at random for ' 40 | 'each volume.') 41 | self.add_argument('--center-fractions', nargs='+', default=[0.08, 0.04], type=float, 42 | help='Fraction of low-frequency k-space columns to be sampled. Should ' 43 | 'have the same length as accelerations') 44 | 45 | # Override defaults with passed overrides 46 | self.set_defaults(**overrides) 47 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | from argparse import ArgumentParser 11 | 12 | import h5py 13 | import numpy as np 14 | from runstats import Statistics 15 | from skimage.measure import compare_psnr, compare_ssim 16 | 17 | 18 | def mse(gt, pred): 19 | """ Compute Mean Squared Error (MSE) """ 20 | return np.mean((gt - pred) ** 2) 21 | 22 | 23 | def nmse(gt, pred): 24 | """ Compute Normalized Mean Squared Error (NMSE) """ 25 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 26 | 27 | 28 | def psnr(gt, pred): 29 | """ Compute Peak Signal to Noise Ratio metric (PSNR) """ 30 | return compare_psnr(gt, pred, data_range=gt.max()) 31 | 32 | 33 | def ssim(gt, pred): 34 | """ Compute Structural Similarity Index Metric (SSIM). """ 35 | return compare_ssim( 36 | gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max() 37 | ) 38 | 39 | 40 | METRIC_FUNCS = dict( 41 | MSE=mse, 42 | NMSE=nmse, 43 | PSNR=psnr, 44 | SSIM=ssim, 45 | ) 46 | 47 | 48 | class Metrics: 49 | """ 50 | Maintains running statistics for a given collection of metrics. 51 | """ 52 | 53 | def __init__(self, metric_funcs): 54 | self.metrics = { 55 | metric: Statistics() for metric in metric_funcs 56 | } 57 | 58 | def push(self, target, recons): 59 | for metric, func in METRIC_FUNCS.items(): 60 | self.metrics[metric].push(func(target, recons)) 61 | 62 | def means(self): 63 | return { 64 | metric: stat.mean() for metric, stat in self.metrics.items() 65 | } 66 | 67 | def stddevs(self): 68 | return { 69 | metric: stat.stddev() for metric, stat in self.metrics.items() 70 | } 71 | 72 | def __repr__(self): 73 | means = self.means() 74 | stddevs = self.stddevs() 75 | metric_names = sorted(list(means)) 76 | return ' '.join( 77 | f'{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}' for name in metric_names 78 | ) 79 | 80 | 81 | def evaluate(args, recons_key): 82 | metrics = Metrics(METRIC_FUNCS) 83 | 84 | for tgt_file in args.target_path.iterdir(): 85 | with h5py.File(tgt_file) as target, h5py.File( 86 | args.predictions_path / tgt_file.name) as recons: 87 | if args.acquisition and args.acquisition != target.attrs['acquisition']: 88 | continue 89 | target = target[recons_key].value 90 | recons = recons['reconstruction'].value 91 | metrics.push(target, recons) 92 | return metrics 93 | 94 | 95 | if __name__ == '__main__': 96 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 97 | parser.add_argument('--target-path', type=pathlib.Path, required=True, 98 | help='Path to the ground truth data') 99 | parser.add_argument('--predictions-path', type=pathlib.Path, required=True, 100 | help='Path to reconstructions') 101 | parser.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True, 102 | help='Which challenge') 103 | parser.add_argument('--acquisition', choices=['CORPD_FBK', 'CORPDFS_FBK'], default=None, 104 | help='If set, only volumes of the specified acquisition type are used ' 105 | 'for evaluation. By default, all volumes are included.') 106 | args = parser.parse_args() 107 | 108 | recons_key = 'reconstruction_rss' if args.challenge == 'multicoil' else 'reconstruction_esc' 109 | metrics = evaluate(args, recons_key) 110 | print(metrics) 111 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/subsample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): 12 | if mask_type_str == 'random': 13 | return MaskFunc(center_fractions, accelerations) 14 | elif mask_type_str == 'equispaced': 15 | return EquispacedMaskFunc(center_fractions, accelerations) 16 | else: 17 | raise Exception(f"{mask_type_str} not supported") 18 | 19 | class MaskFunc: 20 | """ 21 | MaskFunc creates a sub-sampling mask of a given shape. 22 | 23 | The mask selects a subset of columns from the input k-space data. If the k-space data has N 24 | columns, the mask picks out: 25 | 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to 26 | low-frequencies 27 | 2. The other columns are selected uniformly at random with a probability equal to: 28 | prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). 29 | This ensures that the expected number of columns selected is equal to (N / acceleration) 30 | 31 | It is possible to use multiple center_fractions and accelerations, in which case one possible 32 | (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is 33 | called. 34 | 35 | For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there 36 | is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% 37 | probability that 8-fold acceleration with 4% center fraction is selected. 38 | """ 39 | 40 | def __init__(self, center_fractions, accelerations): 41 | """ 42 | Args: 43 | center_fractions (List[float]): Fraction of low-frequency columns to be retained. 44 | If multiple values are provided, then one of these numbers is chosen uniformly 45 | each time. 46 | 47 | accelerations (List[int]): Amount of under-sampling. This should have the same length 48 | as center_fractions. If multiple values are provided, then one of these is chosen 49 | uniformly each time. An acceleration of 4 retains 25% of the columns, but they may 50 | not be spaced evenly. 51 | """ 52 | if len(center_fractions) != len(accelerations): 53 | raise ValueError('Number of center fractions should match number of accelerations') 54 | 55 | self.center_fractions = center_fractions 56 | self.accelerations = accelerations 57 | self.rng = np.random.RandomState() 58 | 59 | def __call__(self, shape, seed=None): 60 | """ 61 | Args: 62 | shape (iterable[int]): The shape of the mask to be created. The shape should have 63 | at least 3 dimensions. Samples are drawn along the second last dimension. 64 | seed (int, optional): Seed for the random number generator. Setting the seed 65 | ensures the same mask is generated each time for the same shape. 66 | Returns: 67 | torch.Tensor: A mask of the specified shape. 
68 | """ 69 | if len(shape) < 3: 70 | raise ValueError('Shape should have 3 or more dimensions') 71 | 72 | self.rng.seed(seed) 73 | num_cols = shape[-2] 74 | 75 | choice = self.rng.randint(0, len(self.accelerations)) 76 | center_fraction = self.center_fractions[choice] 77 | acceleration = self.accelerations[choice] 78 | 79 | # Create the mask 80 | num_low_freqs = int(round(num_cols * center_fraction)) 81 | prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) 82 | mask = self.rng.uniform(size=num_cols) < prob 83 | pad = (num_cols - num_low_freqs + 1) // 2 84 | mask[pad:pad + num_low_freqs] = True 85 | 86 | # Reshape the mask 87 | mask_shape = [1 for _ in shape] 88 | mask_shape[-2] = num_cols 89 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 90 | 91 | return mask 92 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/test_subsample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import pytest 10 | import torch 11 | 12 | from common.subsample import MaskFunc 13 | 14 | 15 | @pytest.mark.parametrize("center_fracs, accelerations, batch_size, dim", [ 16 | ([0.2], [4], 4, 320), 17 | ([0.2, 0.4], [4, 8], 2, 368), 18 | ]) 19 | def test_mask_reuse(center_fracs, accelerations, batch_size, dim): 20 | mask_func = MaskFunc(center_fracs, accelerations) 21 | shape = (batch_size, dim, dim, 2) 22 | mask1 = mask_func(shape, seed=123) 23 | mask2 = mask_func(shape, seed=123) 24 | mask3 = mask_func(shape, seed=123) 25 | assert torch.all(mask1 == mask2) 26 | assert torch.all(mask2 == mask3) 27 | 28 | 29 | @pytest.mark.parametrize("center_fracs, accelerations, batch_size, dim", [ 30 | ([0.2], [4], 4, 320), 31 | ([0.2, 0.4], [4, 8], 2, 368), 32 | ]) 33 | def test_mask_low_freqs(center_fracs, accelerations, batch_size, dim): 34 | mask_func = MaskFunc(center_fracs, accelerations) 35 | shape = (batch_size, dim, dim, 2) 36 | mask = mask_func(shape, seed=123) 37 | mask_shape = [1 for _ in shape] 38 | mask_shape[-2] = dim 39 | assert list(mask.shape) == mask_shape 40 | 41 | num_low_freqs_matched = False 42 | for center_frac in center_fracs: 43 | num_low_freqs = int(round(dim * center_frac)) 44 | pad = (dim - num_low_freqs + 1) // 2 45 | if np.all(mask[pad:pad + num_low_freqs].numpy() == 1): 46 | num_low_freqs_matched = True 47 | assert num_low_freqs_matched 48 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | import json 8 | 9 | import h5py 10 | 11 | 12 | def save_reconstructions(reconstructions, out_dir): 13 | """ 14 | Saves the reconstructions from a model into h5 files that is appropriate for submission 15 | to the leaderboard. 16 | 17 | Args: 18 | reconstructions (dict[str, np.array]): A dictionary mapping input filenames to 19 | corresponding reconstructions (of shape num_slices x height x width). 
20 | out_dir (pathlib.Path): Path to the output directory where the reconstructions 21 | should be saved. 22 | """ 23 | out_dir.mkdir(exist_ok=True) 24 | for fname, recons in reconstructions.items(): 25 | with h5py.File(out_dir / fname, 'w') as f: 26 | f.create_dataset('reconstruction', data=recons) 27 | 28 | 29 | def tensor_to_complex_np(data): 30 | """ 31 | Converts a complex torch tensor to numpy array. 32 | Args: 33 | data (torch.Tensor): Input data to be converted to numpy. 34 | 35 | Returns: 36 | np.array: Complex numpy version of data 37 | """ 38 | data = data.numpy() 39 | return data[..., 0] + 1j * data[..., 1] 40 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/README.md: -------------------------------------------------------------------------------- 1 | ## MRI Data Loader and Transforms 2 | 3 | This directory provides a reference data loader to read the fastMRI data one slice at a time and 4 | some useful data transforms to work with the data in PyTorch. 5 | 6 | Each partition (train, validation or test) of the fastMRI data is distributed as a set of HDF5 7 | files, such that each HDF5 file contains data from one MR acquisition. The set of fields and 8 | attributes in these HDF5 files depends on the track (single-coil or multi-coil) and the data 9 | partition. 10 | 11 | #### Single-Coil Track 12 | 13 | * Training & Validation data: 14 | * `kspace`: Emulated single-coil k-space data. The shape of the kspace tensor is 15 | (number of slices, height, width). 16 | * `reconstruction_rss`: Root-sum-of-squares reconstruction of the multi-coil k-space that was 17 | used to derive the emulated single-coil k-space cropped to the center 320 × 320 region. 18 | The shape of the reconstruction rss tensor is (number of slices, 320, 320). 19 | * `reconstruction_esc`: The inverse Fourier transform of the single-coil k-space data cropped 20 | to the center 320 × 320 region. The shape of the reconstruction esc tensor is (number of 21 | slices, 320, 320). 22 | * Test data: 23 | * `kspace`: Undersampled emulated single-coil k-space. The shape of the kspace tensor is 24 | (number of slices, height, width). 25 | * `mask`: Defines the undersampled Cartesian k-space trajectory. The number of elements in 26 | the mask tensor is the same as the width of k-space. 27 | 28 | 29 | #### Multi-Coil Track 30 | 31 | * Training & Validation data: 32 | * `kspace`: Multi-coil k-space data. The shape of the kspace tensor is 33 | (number of slices, number of coils, height, width). 34 | * `reconstruction_rss`: Root-sum-of-squares reconstruction of the multi-coil k-space 35 | data cropped to the center 320 × 320 region. The shape of the reconstruction rss tensor 36 | is (number of slices, 320, 320). 37 | * Test data: 38 | * `kspace`: Undersampled multi-coil k-space. The shape of the kspace tensor is 39 | (number of slices, number of coils, height, width). 40 | * `mask` Defines the undersampled Cartesian k-space trajectory. The number of elements in 41 | the mask tensor is the same as the width of k-space. 42 | 43 | 44 | ### Data Transforms 45 | 46 | `data/transforms.py` provides a number of useful data transformation functions that work with 47 | PyTorch tensors. 48 | 49 | 50 | #### Data Loader 51 | 52 | `data/mri_data.py` provides a `SliceData` class to read one MR slice at a time. It takes as input 53 | a `transform` function or callable object that can be used transform the data into the format that 54 | you need. 
This makes the data loader versatile and can be used to run different kinds of 55 | reconstruction methods. 56 | 57 | The following is a simple example for how to use the data loader. For more concrete examples, 58 | please look at the baseline model code in the `models` directory. 59 | 60 | ```python 61 | import pathlib 62 | from common import subsample 63 | from data import transforms, mri_data 64 | 65 | # Create a mask function 66 | mask_func = subsample.RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8]) 67 | 68 | def data_transform(kspace, target, data_attributes, filename, slice_num): 69 | # Transform the data into appropriate format 70 | # Here we simply mask the k-space and return the result 71 | kspace = transforms.to_tensor(kspace) 72 | masked_kspace, _ = transforms.apply_mask(kspace, mask_func) 73 | return masked_kspace 74 | 75 | dataset = mri_data.SliceData( 76 | root=pathlib.Path('path/to/data'), 77 | transform=data_transform, 78 | challenge='singlecoil' 79 | ) 80 | 81 | for masked_kspace in dataset: 82 | # Do reconstruction 83 | pass 84 | ``` 85 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-36.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-36.pyc 
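The example in `data/README.md` above stops once the k-space has been masked. As a rough sketch of how the pieces in this directory fit together, the snippet below masks one slice and forms the zero-filled reconstruction, using `MaskFunc` from `common/subsample.py` (shown above) together with `to_tensor`, `apply_mask`, `ifft2`, `complex_abs`, and `center_crop` from `data/transforms.py` (shown further below). The dummy k-space array and its shape are made up for illustration, and the code assumes the old function-style `torch.fft`/`torch.ifft` API (PyTorch before 1.8) that these files call. Note that the README snippet refers to `subsample.RandomMaskFunc`, while the bundled `subsample.py` names the random mask class `MaskFunc`.

```python
import numpy as np

from common.subsample import MaskFunc   # random Cartesian mask, defined in common/subsample.py
from data import transforms             # fastMRI-style tensor utilities from data/transforms.py

# Dummy single-coil k-space slice (height x width, complex). In practice this is what
# SliceData passes to the `transform` callable for each slice.
kspace_np = (np.random.randn(640, 368) + 1j * np.random.randn(640, 368)).astype(np.complex64)

# 4x acceleration, keeping 8% of the columns around the center of k-space.
mask_func = MaskFunc(center_fractions=[0.08], accelerations=[4])

kspace = transforms.to_tensor(kspace_np)        # (640, 368, 2) real-valued tensor
masked_kspace, mask = transforms.apply_mask(kspace, mask_func, seed=42)

# Zero-filled reconstruction: inverse FFT of the under-sampled k-space, magnitude, center crop.
image = transforms.complex_abs(transforms.ifft2(masked_kspace))
image = transforms.center_crop(image, (320, 320))
print(image.shape)  # torch.Size([320, 320])
```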
-------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/mri_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import pathlib 9 | import random 10 | 11 | import h5py 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class SliceData(Dataset): 16 | """ 17 | A PyTorch Dataset that provides access to MR image slices. 18 | """ 19 | 20 | def __init__(self, root, transform, challenge, sample_rate=1): 21 | """ 22 | Args: 23 | root (pathlib.Path): Path to the dataset. 24 | transform (callable): A callable object that pre-processes the raw data into 25 | appropriate form. The transform function should take 'kspace', 'target', 26 | 'attributes', 'filename', and 'slice' as inputs. 'target' may be null 27 | for test data. 28 | challenge (str): "singlecoil" or "multicoil" depending on which challenge to use. 29 | sample_rate (float, optional): A float between 0 and 1. This controls what fraction 30 | of the volumes should be loaded. 31 | """ 32 | if challenge not in ('singlecoil', 'multicoil'): 33 | raise ValueError('challenge should be either "singlecoil" or "multicoil"') 34 | 35 | self.transform = transform 36 | self.recons_key = 'reconstruction_esc' if challenge == 'singlecoil' \ 37 | else 'reconstruction_rss' 38 | 39 | self.examples = [] 40 | files = list(pathlib.Path(root).iterdir()) 41 | if sample_rate < 1: 42 | random.shuffle(files) 43 | num_files = round(len(files) * sample_rate) 44 | files = files[:num_files] 45 | for fname in sorted(files): 46 | kspace = h5py.File(fname, 'r')['kspace'] 47 | num_slices = kspace.shape[0] 48 | self.examples += [(fname, slice) for slice in range(num_slices)] 49 | 50 | def __len__(self): 51 | return len(self.examples) 52 | 53 | def __getitem__(self, i): 54 | fname, slice = self.examples[i] 55 | with h5py.File(fname, 'r') as data: 56 | kspace = data['kspace'][slice] 57 | target = data[self.recons_key][slice] if self.recons_key in data else None 58 | return self.transform(kspace, target, data.attrs, fname.name, slice) 59 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/test_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import numpy as np 9 | import pytest 10 | import torch 11 | 12 | from common import utils 13 | from common.subsample import RandomMaskFunc 14 | from data import transforms 15 | 16 | 17 | def create_input(shape): 18 | input = np.arange(np.product(shape)).reshape(shape) 19 | input = torch.from_numpy(input).float() 20 | return input 21 | 22 | 23 | @pytest.mark.parametrize('shape, center_fractions, accelerations', [ 24 | ([4, 32, 32, 2], [0.08], [4]), 25 | ([2, 64, 64, 2], [0.04, 0.08], [8, 4]), 26 | ]) 27 | def test_apply_mask(shape, center_fractions, accelerations): 28 | mask_func = RandomMaskFunc(center_fractions, accelerations) 29 | expected_mask = mask_func(shape, seed=123) 30 | input = create_input(shape) 31 | output, mask = transforms.apply_mask(input, mask_func, seed=123) 32 | assert output.shape == input.shape 33 | assert mask.shape == expected_mask.shape 34 | assert np.all(expected_mask.numpy() == mask.numpy()) 35 | assert np.all(np.where(mask.numpy() == 0, 0, output.numpy()) == output.numpy()) 36 | 37 | 38 | @pytest.mark.parametrize('shape', [ 39 | [3, 3], 40 | [4, 6], 41 | [10, 8, 4], 42 | ]) 43 | def test_fft2(shape): 44 | shape = shape + [2] 45 | input = create_input(shape) 46 | out_torch = transforms.fft2(input).numpy() 47 | out_torch = out_torch[..., 0] + 1j * out_torch[..., 1] 48 | 49 | input_numpy = utils.tensor_to_complex_np(input) 50 | input_numpy = np.fft.ifftshift(input_numpy, (-2, -1)) 51 | out_numpy = np.fft.fft2(input_numpy, norm='ortho') 52 | out_numpy = np.fft.fftshift(out_numpy, (-2, -1)) 53 | assert np.allclose(out_torch, out_numpy) 54 | 55 | 56 | @pytest.mark.parametrize('shape', [ 57 | [3, 3], 58 | [4, 6], 59 | [10, 8, 4], 60 | ]) 61 | def test_ifft2(shape): 62 | shape = shape + [2] 63 | input = create_input(shape) 64 | out_torch = transforms.ifft2(input).numpy() 65 | out_torch = out_torch[..., 0] + 1j * out_torch[..., 1] 66 | 67 | input_numpy = utils.tensor_to_complex_np(input) 68 | input_numpy = np.fft.ifftshift(input_numpy, (-2, -1)) 69 | out_numpy = np.fft.ifft2(input_numpy, norm='ortho') 70 | out_numpy = np.fft.fftshift(out_numpy, (-2, -1)) 71 | assert np.allclose(out_torch, out_numpy) 72 | 73 | 74 | @pytest.mark.parametrize('shape', [ 75 | [3, 3], 76 | [4, 6], 77 | [10, 8, 4], 78 | ]) 79 | def test_complex_abs(shape): 80 | shape = shape + [2] 81 | input = create_input(shape) 82 | out_torch = transforms.complex_abs(input).numpy() 83 | input_numpy = utils.tensor_to_complex_np(input) 84 | out_numpy = np.abs(input_numpy) 85 | assert np.allclose(out_torch, out_numpy) 86 | 87 | 88 | @pytest.mark.parametrize('shape, dim', [ 89 | [[3, 3], 0], 90 | [[4, 6], 1], 91 | [[10, 8, 4], 2], 92 | ]) 93 | def test_root_sum_of_squares(shape, dim): 94 | input = create_input(shape) 95 | out_torch = transforms.root_sum_of_squares(input, dim).numpy() 96 | out_numpy = np.sqrt(np.sum(input.numpy() ** 2, dim)) 97 | assert np.allclose(out_torch, out_numpy) 98 | 99 | 100 | @pytest.mark.parametrize('shape, target_shape', [ 101 | [[10, 10], [4, 4]], 102 | [[4, 6], [2, 4]], 103 | [[8, 4], [4, 4]], 104 | ]) 105 | def test_center_crop(shape, target_shape): 106 | input = create_input(shape) 107 | out_torch = transforms.center_crop(input, target_shape).numpy() 108 | assert list(out_torch.shape) == target_shape 109 | 110 | 111 | @pytest.mark.parametrize('shape, target_shape', [ 112 | [[10, 10], [4, 4]], 113 | [[4, 6], [2, 4]], 114 | [[8, 4], [4, 4]], 115 | ]) 116 | def test_complex_center_crop(shape, target_shape): 117 | shape = shape + [2] 118 | input = create_input(shape) 119 
| out_torch = transforms.complex_center_crop(input, target_shape).numpy() 120 | assert list(out_torch.shape) == target_shape + [2, ] 121 | 122 | 123 | @pytest.mark.parametrize('shape, mean, stddev', [ 124 | [[10, 10], 0, 1], 125 | [[4, 6], 4, 10], 126 | [[8, 4], 2, 3], 127 | ]) 128 | def test_normalize(shape, mean, stddev): 129 | input = create_input(shape) 130 | output = transforms.normalize(input, mean, stddev).numpy() 131 | assert np.isclose(output.mean(), (input.numpy().mean() - mean) / stddev) 132 | assert np.isclose(output.std(), input.numpy().std() / stddev) 133 | 134 | 135 | @pytest.mark.parametrize('shape', [ 136 | [10, 10], 137 | [20, 40, 30], 138 | ]) 139 | def test_normalize_instance(shape): 140 | input = create_input(shape) 141 | output, mean, stddev = transforms.normalize_instance(input) 142 | output = output.numpy() 143 | assert np.isclose(input.numpy().mean(), mean, rtol=1e-2) 144 | assert np.isclose(input.numpy().std(), stddev, rtol=1e-2) 145 | assert np.isclose(output.mean(), 0, rtol=1e-2, atol=1e-3) 146 | assert np.isclose(output.std(), 1, rtol=1e-2, atol=1e-3) 147 | 148 | 149 | @pytest.mark.parametrize('shift, dim', [ 150 | (0, 0), 151 | (1, 0), 152 | (-1, 0), 153 | (100, 0), 154 | ((1, 2), (1, 2)), 155 | ]) 156 | @pytest.mark.parametrize('shape', [ 157 | [5, 6, 2], 158 | [3, 4, 5], 159 | ]) 160 | def test_roll(shift, dim, shape): 161 | input = np.arange(np.product(shape)).reshape(shape) 162 | out_torch = transforms.roll(torch.from_numpy(input), shift, dim).numpy() 163 | out_numpy = np.roll(input, shift, dim) 164 | assert np.allclose(out_torch, out_numpy) 165 | 166 | 167 | @pytest.mark.parametrize('shape', [ 168 | [5, 3], 169 | [2, 4, 6], 170 | ]) 171 | def test_fftshift(shape): 172 | input = np.arange(np.product(shape)).reshape(shape) 173 | out_torch = transforms.fftshift(torch.from_numpy(input)).numpy() 174 | out_numpy = np.fft.fftshift(input) 175 | assert np.allclose(out_torch, out_numpy) 176 | 177 | 178 | @pytest.mark.parametrize('shape', [ 179 | [5, 3], 180 | [2, 4, 5], 181 | [2, 7, 5], 182 | ]) 183 | def test_ifftshift(shape): 184 | input = np.arange(np.product(shape)).reshape(shape) 185 | out_torch = transforms.ifftshift(torch.from_numpy(input)).numpy() 186 | out_numpy = np.fft.ifftshift(input) 187 | assert np.allclose(out_torch, out_numpy) 188 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/data/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def to_tensor(data): 13 | """ 14 | Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts 15 | are stacked along the last dimension. 16 | 17 | Args: 18 | data (np.array): Input numpy array 19 | 20 | Returns: 21 | torch.Tensor: PyTorch version of data 22 | """ 23 | if np.iscomplexobj(data): 24 | data = np.stack((data.real, data.imag), axis=-1) 25 | return torch.from_numpy(data) 26 | 27 | 28 | def apply_mask(data, mask_func, seed=None): 29 | """ 30 | Subsample given k-space by multiplying with a mask. 31 | 32 | Args: 33 | data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where 34 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size 35 | 2 (for complex values). 
36 | mask_func (callable): A function that takes a shape (tuple of ints) and a random 37 | number seed and returns a mask. 38 | seed (int or 1-d array_like, optional): Seed for the random number generator. 39 | 40 | Returns: 41 | (tuple): tuple containing: 42 | masked data (torch.Tensor): Subsampled k-space data 43 | mask (torch.Tensor): The generated mask 44 | """ 45 | shape = np.array(data.shape) 46 | shape[:-3] = 1 47 | mask = mask_func(shape, seed) 48 | return torch.where(mask == 0, torch.Tensor([0]), data), mask 49 | 50 | 51 | def fft2(data): 52 | """ 53 | Apply centered 2 dimensional Fast Fourier Transform. 54 | 55 | Args: 56 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 57 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 58 | assumed to be batch dimensions. 59 | 60 | Returns: 61 | torch.Tensor: The FFT of the input. 62 | """ 63 | assert data.size(-1) == 2 64 | data = ifftshift(data, dim=(-3, -2)) 65 | data = torch.fft(data, 2, normalized=True) 66 | data = fftshift(data, dim=(-3, -2)) 67 | return data 68 | 69 | 70 | def ifft2(data): 71 | """ 72 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 73 | 74 | Args: 75 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 76 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 77 | assumed to be batch dimensions. 78 | 79 | Returns: 80 | torch.Tensor: The IFFT of the input. 81 | """ 82 | assert data.size(-1) == 2 83 | data = ifftshift(data, dim=(-3, -2)) 84 | data = torch.ifft(data, 2, normalized=True) 85 | data = fftshift(data, dim=(-3, -2)) 86 | return data 87 | 88 | 89 | def complex_abs(data): 90 | """ 91 | Compute the absolute value of a complex valued input tensor. 92 | 93 | Args: 94 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension 95 | should be 2. 96 | 97 | Returns: 98 | torch.Tensor: Absolute value of data 99 | """ 100 | assert data.size(-1) == 2 101 | return (data ** 2).sum(dim=-1).sqrt() 102 | 103 | 104 | def root_sum_of_squares(data, dim=0): 105 | """ 106 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor. 107 | 108 | Args: 109 | data (torch.Tensor): The input tensor 110 | dim (int): The dimensions along which to apply the RSS transform 111 | 112 | Returns: 113 | torch.Tensor: The RSS value 114 | """ 115 | return torch.sqrt((data ** 2).sum(dim)) 116 | 117 | 118 | def center_crop(data, shape): 119 | """ 120 | Apply a center crop to the input real image or batch of real images. 121 | 122 | Args: 123 | data (torch.Tensor): The input tensor to be center cropped. It should have at 124 | least 2 dimensions and the cropping is applied along the last two dimensions. 125 | shape (int, int): The output shape. The shape should be smaller than the 126 | corresponding dimensions of data. 127 | 128 | Returns: 129 | torch.Tensor: The center cropped image 130 | """ 131 | assert 0 < shape[0] <= data.shape[-2] 132 | assert 0 < shape[1] <= data.shape[-1] 133 | w_from = (data.shape[-2] - shape[0]) // 2 134 | h_from = (data.shape[-1] - shape[1]) // 2 135 | w_to = w_from + shape[0] 136 | h_to = h_from + shape[1] 137 | return data[..., w_from:w_to, h_from:h_to] 138 | 139 | 140 | def complex_center_crop(data, shape): 141 | """ 142 | Apply a center crop to the input image or batch of complex images. 143 | 144 | Args: 145 | data (torch.Tensor): The complex input tensor to be center cropped. 
It should 146 | have at least 3 dimensions and the cropping is applied along dimensions 147 | -3 and -2 and the last dimensions should have a size of 2. 148 | shape (int, int): The output shape. The shape should be smaller than the 149 | corresponding dimensions of data. 150 | 151 | Returns: 152 | torch.Tensor: The center cropped image 153 | """ 154 | assert 0 < shape[0] <= data.shape[-3] 155 | assert 0 < shape[1] <= data.shape[-2] 156 | w_from = (data.shape[-3] - shape[0]) // 2 157 | h_from = (data.shape[-2] - shape[1]) // 2 158 | w_to = w_from + shape[0] 159 | h_to = h_from + shape[1] 160 | return data[..., w_from:w_to, h_from:h_to, :] 161 | 162 | 163 | def normalize(data, mean, stddev, eps=0.): 164 | """ 165 | Normalize the given tensor using: 166 | (data - mean) / (stddev + eps) 167 | 168 | Args: 169 | data (torch.Tensor): Input data to be normalized 170 | mean (float): Mean value 171 | stddev (float): Standard deviation 172 | eps (float): Added to stddev to prevent dividing by zero 173 | 174 | Returns: 175 | torch.Tensor: Normalized tensor 176 | """ 177 | return (data - mean) / (stddev + eps) 178 | 179 | 180 | def normalize_instance(data, eps=0.): 181 | """ 182 | Normalize the given tensor using: 183 | (data - mean) / (stddev + eps) 184 | where mean and stddev are computed from the data itself. 185 | 186 | Args: 187 | data (torch.Tensor): Input data to be normalized 188 | eps (float): Added to stddev to prevent dividing by zero 189 | 190 | Returns: 191 | torch.Tensor: Normalized tensor 192 | """ 193 | mean = data.mean() 194 | std = data.std() 195 | return normalize(data, mean, std, eps), mean, std 196 | 197 | 198 | # Helper functions 199 | 200 | def roll(x, shift, dim): 201 | """ 202 | Similar to np.roll but applies to PyTorch Tensors 203 | """ 204 | if isinstance(shift, (tuple, list)): 205 | assert len(shift) == len(dim) 206 | for s, d in zip(shift, dim): 207 | x = roll(x, s, d) 208 | return x 209 | shift = shift % x.size(dim) 210 | if shift == 0: 211 | return x 212 | left = x.narrow(dim, 0, x.size(dim) - shift) 213 | right = x.narrow(dim, x.size(dim) - shift, shift) 214 | return torch.cat((right, left), dim=dim) 215 | 216 | 217 | def fftshift(x, dim=None): 218 | """ 219 | Similar to np.fft.fftshift but applies to PyTorch Tensors 220 | """ 221 | if dim is None: 222 | dim = tuple(range(x.dim())) 223 | shift = [dim // 2 for dim in x.shape] 224 | elif isinstance(dim, int): 225 | shift = x.shape[dim] // 2 226 | else: 227 | shift = [x.shape[i] // 2 for i in dim] 228 | return roll(x, shift, dim) 229 | 230 | 231 | def ifftshift(x, dim=None): 232 | """ 233 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 234 | """ 235 | if dim is None: 236 | dim = tuple(range(x.dim())) 237 | shift = [(dim + 1) // 2 for dim in x.shape] 238 | elif isinstance(dim, int): 239 | shift = (x.shape[dim] + 1) // 2 240 | else: 241 | shift = [(x.shape[i] + 1) // 2 for i in dim] 242 | return roll(x, shift, dim) 243 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/mri_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from collections import defaultdict 9 | 10 | import numpy as np 11 | import pytorch_lightning as pl 12 | import torch 13 | import torchvision 14 | from torch.utils.data import DistributedSampler, DataLoader 15 | 16 | from common import evaluate 17 | from common.utils import save_reconstructions 18 | from .data.mri_data import SliceData 19 | 20 | 21 | class MRIModel(pl.LightningModule): 22 | """ 23 | Abstract super class for Deep Learning based reconstruction models. 24 | This is a subclass of the LightningModule class from pytorch_lightning, with 25 | some additional functionality specific to fastMRI: 26 | - fastMRI data loaders 27 | - Evaluating reconstructions 28 | - Visualization 29 | - Saving test reconstructions 30 | 31 | To implement a new reconstruction model, inherit from this class and implement the 32 | following methods: 33 | - train_data_transform, val_data_transform, test_data_transform: 34 | Create and return data transformer objects for each data split 35 | - training_step, validation_step, test_step: 36 | Define what happens in one step of training, validation and testing respectively 37 | - configure_optimizers: 38 | Create and return the optimizers 39 | Other methods from LightningModule can be overridden as needed. 40 | """ 41 | 42 | def __init__(self, hparams): 43 | super().__init__() 44 | self.hparams = hparams 45 | 46 | def _create_data_loader(self, data_transform, data_partition, sample_rate=None): 47 | sample_rate = sample_rate or self.hparams.sample_rate 48 | dataset = SliceData( 49 | root=self.hparams.data_path / f'{self.hparams.challenge}_{data_partition}', 50 | transform=data_transform, 51 | sample_rate=sample_rate, 52 | challenge=self.hparams.challenge 53 | ) 54 | sampler = DistributedSampler(dataset) 55 | return DataLoader( 56 | dataset=dataset, 57 | batch_size=self.hparams.batch_size, 58 | num_workers=0, 59 | pin_memory=True, 60 | sampler=sampler, 61 | ) 62 | 63 | def train_data_transform(self): 64 | raise NotImplementedError 65 | 66 | @pl.data_loader 67 | def train_dataloader(self): 68 | return self._create_data_loader(self.train_data_transform(), data_partition='train') 69 | 70 | def val_data_transform(self): 71 | raise NotImplementedError 72 | 73 | @pl.data_loader 74 | def val_dataloader(self): 75 | return self._create_data_loader(self.val_data_transform(), data_partition='val') 76 | 77 | def test_data_transform(self): 78 | raise NotImplementedError 79 | 80 | @pl.data_loader 81 | def test_dataloader(self): 82 | return self._create_data_loader(self.test_data_transform(), data_partition='test', sample_rate=1.) 
83 | 84 | def _evaluate(self, val_logs): 85 | losses = [] 86 | outputs = defaultdict(list) 87 | targets = defaultdict(list) 88 | for log in val_logs: 89 | losses.append(log['val_loss'].cpu().numpy()) 90 | for i, (fname, slice) in enumerate(zip(log['fname'], log['slice'])): 91 | outputs[fname].append((slice, log['output'][i])) 92 | targets[fname].append((slice, log['target'][i])) 93 | metrics = dict(val_loss=losses, nmse=[], ssim=[], psnr=[]) 94 | for fname in outputs: 95 | output = np.stack([out for _, out in sorted(outputs[fname])]) 96 | target = np.stack([tgt for _, tgt in sorted(targets[fname])]) 97 | metrics['nmse'].append(evaluate.nmse(target, output)) 98 | metrics['ssim'].append(evaluate.ssim(target, output)) 99 | metrics['psnr'].append(evaluate.psnr(target, output)) 100 | metrics = {metric: np.mean(values) for metric, values in metrics.items()} 101 | print(metrics, '\n') 102 | return dict(log=metrics, **metrics) 103 | 104 | def _visualize(self, val_logs): 105 | def _normalize(image): 106 | image = image[np.newaxis] 107 | image -= image.min() 108 | return image / image.max() 109 | 110 | def _save_image(image, tag): 111 | grid = torchvision.utils.make_grid(torch.Tensor(image), nrow=4, pad_value=1) 112 | self.logger.experiment.add_image(tag, grid) 113 | 114 | # Only process first size to simplify visualization. 115 | visualize_size = val_logs[0]['output'].shape 116 | val_logs = [x for x in val_logs if x['output'].shape == visualize_size] 117 | num_logs = len(val_logs) 118 | num_viz_images = 16 119 | step = (num_logs + num_viz_images - 1) // num_viz_images 120 | outputs, targets = [], [] 121 | for i in range(0, num_logs, step): 122 | outputs.append(_normalize(val_logs[i]['output'][0])) 123 | targets.append(_normalize(val_logs[i]['target'][0])) 124 | outputs = np.stack(outputs) 125 | targets = np.stack(targets) 126 | _save_image(targets, 'Target') 127 | _save_image(outputs, 'Reconstruction') 128 | _save_image(np.abs(targets - outputs), 'Error') 129 | 130 | def validation_end(self, val_logs): 131 | self._visualize(val_logs) 132 | return self._evaluate(val_logs) 133 | 134 | def test_end(self, test_logs): 135 | outputs = defaultdict(list) 136 | for log in test_logs: 137 | for i, (fname, slice) in enumerate(zip(log['fname'], log['slice'])): 138 | outputs[fname].append((slice, log['output'][i])) 139 | for fname in outputs: 140 | outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])]) 141 | save_reconstructions(outputs, self.hparams.exp_dir / self.hparams.exp / 'reconstructions') 142 | return dict() 143 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/run_bart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.insert(0,'/root/bart-0.5.00/python/')\n", 11 | "\n", 12 | "\n", 13 | "import logging\n", 14 | "import multiprocessing\n", 15 | "import pathlib\n", 16 | "import random\n", 17 | "import time\n", 18 | "from collections import defaultdict\n", 19 | "\n", 20 | "import numpy as np\n", 21 | "import torch\n", 22 | "\n", 23 | "import bart\n", 24 | "from common import utils\n", 25 | "from common.args import Args\n", 26 | "from common.subsample import create_mask_for_mask_type\n", 27 | "from common.utils import tensor_to_complex_np\n", 28 | "from data import transforms\n", 29 | "from data.mri_data import SliceData\n", 30 | 
"\n", 31 | "logging.basicConfig(level=logging.INFO)\n", 32 | "logger = logging.getLogger(__name__)\n", 33 | "\n", 34 | "\n", 35 | "class DataTransform:\n", 36 | " \"\"\"\n", 37 | " Data Transformer that masks input k-space.\n", 38 | " \"\"\"\n", 39 | "\n", 40 | " def __init__(self, mask_func):\n", 41 | " \"\"\"\n", 42 | " Args:\n", 43 | " mask_func (common.subsample.MaskFunc): A function that can create a mask of\n", 44 | " appropriate shape.\n", 45 | " \"\"\"\n", 46 | " self.mask_func = mask_func\n", 47 | "\n", 48 | " def __call__(self, kspace, target, attrs, fname, slice):\n", 49 | " \"\"\"\n", 50 | " Args:\n", 51 | " kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil\n", 52 | " data or (rows, cols, 2) for single coil data.\n", 53 | " target (numpy.array, optional): Target image\n", 54 | " attrs (dict): Acquisition related information stored in the HDF5 object.\n", 55 | " fname (str): File name\n", 56 | " slice (int): Serial number of the slice.\n", 57 | " Returns:\n", 58 | " (tuple): tuple containing:\n", 59 | " masked_kspace (torch.Tensor): Sub-sampled k-space with the same shape as kspace.\n", 60 | " fname (str): File name containing the current data item\n", 61 | " slice (int): The index of the current slice in the volume\n", 62 | " \"\"\"\n", 63 | " kspace = transforms.to_tensor(kspace)\n", 64 | " seed = tuple(map(ord, fname))\n", 65 | " # Apply mask to raw k-space\n", 66 | " masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)\n", 67 | " return masked_kspace, fname, slice\n", 68 | "\n", 69 | "\n", 70 | "def create_data_loader(args):\n", 71 | " dev_mask = create_mask_for_mask_type(args.mask_type, args.center_fractions, args.accelerations)\n", 72 | " data = SliceData(\n", 73 | " root=args.data_path + str(f'{args.challenge}_val'),\n", 74 | " transform=DataTransform(dev_mask),\n", 75 | " challenge=args.challenge,\n", 76 | " sample_rate=args.sample_rate\n", 77 | " )\n", 78 | " return data\n", 79 | "\n", 80 | "\n", 81 | "def cs_total_variation(args, kspace):\n", 82 | " \"\"\"\n", 83 | " Run ESPIRIT coil sensitivity estimation and Total Variation Minimization based\n", 84 | " reconstruction algorithm using the BART toolkit.\n", 85 | " \"\"\"\n", 86 | "\n", 87 | " if args.challenge == 'singlecoil':\n", 88 | " kspace = kspace.unsqueeze(0)\n", 89 | " kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0)\n", 90 | " kspace = tensor_to_complex_np(kspace)\n", 91 | "\n", 92 | " # Estimate sensitivity maps\n", 93 | " sens_maps = bart.bart(1, f'ecalib -d0 -m1', kspace)\n", 94 | "\n", 95 | " # Use Total Variation Minimization to reconstruct the image\n", 96 | " pred = bart.bart(\n", 97 | " 1, f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}', kspace, sens_maps\n", 98 | " )\n", 99 | " pred = torch.from_numpy(np.abs(pred[0]))\n", 100 | "\n", 101 | " # Crop the predicted image to selected resolution if bigger\n", 102 | " smallest_width = min(args.resolution, pred.shape[-1])\n", 103 | " smallest_height = min(args.resolution, pred.shape[-2])\n", 104 | " return transforms.center_crop(pred, (smallest_height, smallest_width))\n", 105 | "\n", 106 | "\n", 107 | "def run_model(i):\n", 108 | " masked_kspace, fname, slice = data[i]\n", 109 | " prediction = cs_total_variation(args, masked_kspace)\n", 110 | " return fname, slice, prediction\n", 111 | "\n", 112 | "\n", 113 | "def main():\n", 114 | " if args.num_procs == 0:\n", 115 | " start_time = time.perf_counter()\n", 116 | " outputs = []\n", 117 | " for i in range(len(data)):\n", 118 | " 
outputs.append(run_model(i))\n", 119 | " save_outputs([run_model(i)], args.output_path)\n", 120 | " time_taken = time.perf_counter() - start_time\n", 121 | " else:\n", 122 | " with multiprocessing.Pool(args.num_procs) as pool:\n", 123 | " start_time = time.perf_counter()\n", 124 | " outputs = pool.map(run_model, range(len(data)))\n", 125 | " time_taken = time.perf_counter() - start_time\n", 126 | " save_outputs(outputs, args.output_path)\n", 127 | " logging.info(f'Run Time = {time_taken:}s')\n", 128 | " \n", 129 | "\n", 130 | "\n", 131 | "import json\n", 132 | "\n", 133 | "import h5py\n", 134 | "\n", 135 | "\n", 136 | "def save_reconstructions(reconstructions, out_dir):\n", 137 | " \"\"\"\n", 138 | " Saves the reconstructions from a model into h5 files that is appropriate for submission\n", 139 | " to the leaderboard.\n", 140 | "\n", 141 | " Args:\n", 142 | " reconstructions (dict[str, np.array]): A dictionary mapping input filenames to\n", 143 | " corresponding reconstructions (of shape num_slices x height x width).\n", 144 | " out_dir (pathlib.Path): Path to the output directory where the reconstructions\n", 145 | " should be saved.\n", 146 | " \"\"\"\n", 147 | " for fname, recons in reconstructions.items():\n", 148 | " with h5py.File(out_dir + fname, 'w') as f:\n", 149 | " f.create_dataset('reconstruction', data=recons)\n", 150 | " \n", 151 | "def save_outputs(outputs, output_path):\n", 152 | " reconstructions = defaultdict(list)\n", 153 | " for fname, slice, pred in outputs:\n", 154 | " reconstructions[fname].append((slice, pred))\n", 155 | " reconstructions = {\n", 156 | " fname: np.stack([pred for _, pred in sorted(slice_preds)])\n", 157 | " for fname, slice_preds in reconstructions.items()\n", 158 | " }\n", 159 | " save_reconstructions(reconstructions, output_path)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 3, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "class Args():\n", 169 | " def __init__(self,ch,path,rate,acc,cent,outpath,iters,reg,procs,maskt,seed,res):\n", 170 | " self.challenge = ch\n", 171 | " self.data_path = path\n", 172 | " self.sample_rate = rate\n", 173 | " self.accelerations = acc\n", 174 | " self.center_fractions = cent\n", 175 | " self.output_path = outpath\n", 176 | " self.num_iters = iters\n", 177 | " self.reg_wt = reg\n", 178 | " self.num_procs = procs\n", 179 | " self.mask_type = maskt\n", 180 | " self.seed = seed\n", 181 | " self.resolution = res\n", 182 | "args = Args(\"multicoil\",\"/hdd/\",1,[4],[0.07],\"/root/multires_deep_decoder/mri/FINAL/TV/\",100,0.01,4,\"random\",42,320)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 4, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "import os\n", 192 | "import subprocess\n", 193 | "import sys\n", 194 | "\n", 195 | "os.environ['TOOLBOX_PATH'] = \"/root/bart-0.5.00/\" # visible in this process + all children\n", 196 | "random.seed(args.seed)\n", 197 | "np.random.seed(args.seed)\n", 198 | "torch.manual_seed(args.seed)\n", 199 | "\n", 200 | "dataset = create_data_loader(args)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 4, 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "text/plain": [ 211 | "((PosixPath('/hdd/multicoil_val/file1000017.h5'), 7), 7135, 3)" 212 | ] 213 | }, 214 | "execution_count": 4, 215 | "metadata": {}, 216 | "output_type": "execute_result" 217 | } 218 | ], 219 | "source": [ 220 | "dataset.examples[80],len(dataset),len(dataset[0])" 221 | ] 
222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 6, 226 | "metadata": { 227 | "scrolled": true 228 | }, 229 | "outputs": [ 230 | { 231 | "name": "stderr", 232 | "output_type": "stream", 233 | "text": [ 234 | "INFO:root:Run Time = 333.7682563047856s\n", [notebook lines 235-431: further repeated "INFO:root:Run Time = ..." log entries from this cell omitted] 432 | "INFO:root:Run Time = 488.95933836884797s\n" 433 | ] 434 | } 435 | ], 436 | "source": [ 437 | "this_data = []\n", 438 | "prev_slicenu = -1\n", 439 | "for i,d in enumerate(dataset):\n", 440 | "    if dataset.examples[i][1] > prev_slicenu: \n", 441 | "        this_data.append(d)\n", 442 | "
prev_slicenu = dataset.examples[i][1]\n", 443 | " else:\n", 444 | " data = this_data\n", 445 | " main()\n", 446 | " this_data = [d]\n", 447 | " prev_slicenu = 0\n", 448 | " if i == len(dataset) - 1:\n", 449 | " data = this_data\n", 450 | " main()" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [] 459 | } 460 | ], 461 | "metadata": { 462 | "kernelspec": { 463 | "display_name": "Python 3", 464 | "language": "python", 465 | "name": "python3" 466 | }, 467 | "language_info": { 468 | "codemirror_mode": { 469 | "name": "ipython", 470 | "version": 3 471 | }, 472 | "file_extension": ".py", 473 | "mimetype": "text/x-python", 474 | "name": "python", 475 | "nbconvert_exporter": "python", 476 | "pygments_lexer": "ipython3", 477 | "version": "3.6.7" 478 | } 479 | }, 480 | "nbformat": 4, 481 | "nbformat_minor": 4 482 | } 483 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/run_bart_val.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 5 | """ 6 | 7 | import logging 8 | import multiprocessing 9 | import pathlib 10 | import random 11 | import time 12 | from collections import defaultdict 13 | 14 | import numpy as np 15 | import torch 16 | 17 | import bart 18 | from common import utils 19 | from common.args import Args 20 | from common.subsample import create_mask_for_mask_type 21 | from common.utils import tensor_to_complex_np 22 | from data import transforms 23 | from data.mri_data import SliceData 24 | 25 | logging.basicConfig(level=logging.INFO) 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class DataTransform: 30 | """ 31 | Data Transformer that masks input k-space. 32 | """ 33 | 34 | def __init__(self, mask_func): 35 | """ 36 | Args: 37 | mask_func (common.subsample.MaskFunc): A function that can create a mask of 38 | appropriate shape. 39 | """ 40 | self.mask_func = mask_func 41 | 42 | def __call__(self, kspace, target, attrs, fname, slice): 43 | """ 44 | Args: 45 | kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil 46 | data or (rows, cols, 2) for single coil data. 47 | target (numpy.array, optional): Target image 48 | attrs (dict): Acquisition related information stored in the HDF5 object. 49 | fname (str): File name 50 | slice (int): Serial number of the slice. 51 | Returns: 52 | (tuple): tuple containing: 53 | masked_kspace (torch.Tensor): Sub-sampled k-space with the same shape as kspace. 
54 | fname (str): File name containing the current data item 55 | slice (int): The index of the current slice in the volume 56 | """ 57 | kspace = transforms.to_tensor(kspace) 58 | seed = tuple(map(ord, fname)) 59 | # Apply mask to raw k-space 60 | masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed) 61 | return masked_kspace, fname, slice 62 | 63 | 64 | def create_data_loader(args): 65 | dev_mask = create_mask_for_mask_type(args.mask_type, args.center_fractions, args.accelerations) 66 | data = SliceData( 67 | root=args.data_path / f'{args.challenge}_val', 68 | transform=DataTransform(dev_mask), 69 | challenge=args.challenge, 70 | sample_rate=args.sample_rate 71 | ) 72 | return data 73 | 74 | 75 | def cs_total_variation(args, kspace): 76 | """ 77 | Run ESPIRIT coil sensitivity estimation and Total Variation Minimization based 78 | reconstruction algorithm using the BART toolkit. 79 | """ 80 | 81 | if args.challenge == 'singlecoil': 82 | kspace = kspace.unsqueeze(0) 83 | kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0) 84 | kspace = tensor_to_complex_np(kspace) 85 | 86 | # Estimate sensitivity maps 87 | sens_maps = bart.bart(1, f'ecalib -d0 -m1', kspace) 88 | 89 | # Use Total Variation Minimization to reconstruct the image 90 | pred = bart.bart( 91 | 1, f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}', kspace, sens_maps 92 | ) 93 | pred = torch.from_numpy(np.abs(pred[0])) 94 | 95 | # Crop the predicted image to selected resolution if bigger 96 | smallest_width = min(args.resolution, pred.shape[-1]) 97 | smallest_height = min(args.resolution, pred.shape[-2]) 98 | return transforms.center_crop(pred, (smallest_height, smallest_width)) 99 | 100 | 101 | def run_model(i): 102 | masked_kspace, fname, slice = data[i] 103 | prediction = cs_total_variation(args, masked_kspace) 104 | return fname, slice, prediction 105 | 106 | 107 | def main(): 108 | if args.num_procs == 0: 109 | start_time = time.perf_counter() 110 | outputs = [] 111 | for i in range(len(data)): 112 | outputs.append(run_model(i)) 113 | time_taken = time.perf_counter() - start_time 114 | else: 115 | with multiprocessing.Pool(args.num_procs) as pool: 116 | start_time = time.perf_counter() 117 | outputs = pool.map(run_model, range(len(data))) 118 | time_taken = time.perf_counter() - start_time 119 | logging.info(f'Run Time = {time_taken:}s') 120 | save_outputs(outputs, args.output_path) 121 | 122 | 123 | def save_outputs(outputs, output_path): 124 | reconstructions = defaultdict(list) 125 | for fname, slice, pred in outputs: 126 | reconstructions[fname].append((slice, pred)) 127 | reconstructions = { 128 | fname: np.stack([pred for _, pred in sorted(slice_preds)]) 129 | for fname, slice_preds in reconstructions.items() 130 | } 131 | utils.save_reconstructions(reconstructions, output_path) 132 | 133 | 134 | if __name__ == '__main__': 135 | parser = Args() 136 | parser.add_argument('--output-path', type=pathlib.Path, default=None, 137 | help='Path to save the reconstructions to') 138 | parser.add_argument('--num-iters', type=int, default=200, 139 | help='Number of iterations to run the reconstruction algorithm') 140 | parser.add_argument('--reg-wt', type=float, default=0.01, 141 | help='Regularization weight parameter') 142 | parser.add_argument('--num-procs', type=int, default=20, 143 | help='Number of processes. 
Set to 0 to disable multiprocessing.') 144 | parser.add_argument('--mask-type',default='random') 145 | args = parser.parse_args() 146 | 147 | random.seed(args.seed) 148 | np.random.seed(args.seed) 149 | torch.manual_seed(args.seed) 150 | 151 | data = create_data_loader(args) 152 | main() -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/train_unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import pathlib 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | from pytorch_lightning import Trainer 14 | from pytorch_lightning.logging import TestTubeLogger 15 | from torch.nn import functional as F 16 | from torch.optim import RMSprop 17 | 18 | from .common.args import Args 19 | from .common.subsample import create_mask_for_mask_type 20 | from .data import transforms 21 | from .mri_model import MRIModel 22 | from .unet_model import UnetModel 23 | 24 | 25 | class DataTransform: 26 | """ 27 | Data Transformer for training U-Net models. 28 | """ 29 | 30 | def __init__(self, resolution, which_challenge, mask_func=None, use_seed=True): 31 | """ 32 | Args: 33 | mask_func (common.subsample.MaskFunc): A function that can create a mask of 34 | appropriate shape. 35 | resolution (int): Resolution of the image. 36 | which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset. 37 | use_seed (bool): If true, this class computes a pseudo random number generator seed 38 | from the filename. This ensures that the same mask is used for all the slices of 39 | a given volume every time. 40 | """ 41 | if which_challenge not in ('singlecoil', 'multicoil'): 42 | raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') 43 | self.mask_func = mask_func 44 | self.resolution = resolution 45 | self.which_challenge = which_challenge 46 | self.use_seed = use_seed 47 | 48 | def __call__(self, kspace, target, attrs, fname, slice): 49 | """ 50 | Args: 51 | kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil 52 | data or (rows, cols, 2) for single coil data. 53 | target (numpy.array): Target image 54 | attrs (dict): Acquisition related information stored in the HDF5 object. 55 | fname (str): File name 56 | slice (int): Serial number of the slice. 57 | Returns: 58 | (tuple): tuple containing: 59 | image (torch.Tensor): Zero-filled input image. 60 | target (torch.Tensor): Target image converted to a torch Tensor. 61 | mean (float): Mean value used for normalization. 62 | std (float): Standard deviation value used for normalization. 
63 | """ 64 | kspace = transforms.to_tensor(kspace) 65 | # Apply mask 66 | if self.mask_func: 67 | seed = None if not self.use_seed else tuple(map(ord, fname)) 68 | masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed) 69 | else: 70 | masked_kspace = kspace 71 | 72 | # Inverse Fourier Transform to get zero filled solution 73 | image = transforms.ifft2(masked_kspace) 74 | # Crop input image to given resolution if larger 75 | smallest_width = min(self.resolution, image.shape[-2]) 76 | smallest_height = min(self.resolution, image.shape[-3]) 77 | if target is not None: 78 | smallest_width = min(smallest_width, target.shape[-1]) 79 | smallest_height = min(smallest_height, target.shape[-2]) 80 | crop_size = (smallest_height, smallest_width) 81 | image = transforms.complex_center_crop(image, crop_size) 82 | # Absolute value 83 | image = transforms.complex_abs(image) 84 | # Apply Root-Sum-of-Squares if multicoil data 85 | if self.which_challenge == 'multicoil': 86 | image = transforms.root_sum_of_squares(image) 87 | # Normalize input 88 | image, mean, std = transforms.normalize_instance(image, eps=1e-11) 89 | image = image.clamp(-6, 6) 90 | # Normalize target 91 | if target is not None: 92 | target = transforms.to_tensor(target) 93 | target = transforms.center_crop(target, crop_size) 94 | target = transforms.normalize(target, mean, std, eps=1e-11) 95 | target = target.clamp(-6, 6) 96 | else: 97 | target = torch.Tensor([0]) 98 | return image, target, mean, std, fname, slice 99 | 100 | 101 | class UnetMRIModel(MRIModel): 102 | def __init__(self, hparams): 103 | super().__init__(hparams) 104 | self.unet = UnetModel( 105 | in_chans=1, 106 | out_chans=1, 107 | chans=hparams.num_chans, 108 | num_pool_layers=hparams.num_pools, 109 | drop_prob=hparams.drop_prob 110 | ) 111 | 112 | def forward(self, input): 113 | return self.unet(input.unsqueeze(1)).squeeze(1) 114 | 115 | def training_step(self, batch, batch_idx): 116 | input, target, mean, std, _, _ = batch 117 | output = self.forward(input) 118 | loss = F.l1_loss(output, target) 119 | logs = {'loss': loss.item()} 120 | return dict(loss=loss, log=logs) 121 | 122 | def validation_step(self, batch, batch_idx): 123 | input, target, mean, std, fname, slice = batch 124 | output = self.forward(input) 125 | mean = mean.unsqueeze(1).unsqueeze(2) 126 | std = std.unsqueeze(1).unsqueeze(2) 127 | return { 128 | 'fname': fname, 129 | 'slice': slice, 130 | 'output': (output * std + mean).cpu().numpy(), 131 | 'target': (target * std + mean).cpu().numpy(), 132 | 'val_loss': F.l1_loss(output, target), 133 | } 134 | 135 | def test_step(self, batch, batch_idx): 136 | input, _, mean, std, fname, slice = batch 137 | output = self.forward(input) 138 | mean = mean.unsqueeze(1).unsqueeze(2) 139 | std = std.unsqueeze(1).unsqueeze(2) 140 | return { 141 | 'fname': fname, 142 | 'slice': slice, 143 | 'output': (output * std + mean).cpu().numpy(), 144 | } 145 | 146 | def configure_optimizers(self): 147 | optim = RMSprop(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) 148 | scheduler = torch.optim.lr_scheduler.StepLR(optim, self.hparams.lr_step_size, self.hparams.lr_gamma) 149 | return [optim], [scheduler] 150 | 151 | def train_data_transform(self): 152 | mask = create_mask_for_mask_type(self.hparams.mask_type, self.hparams.center_fractions, 153 | self.hparams.accelerations) 154 | return DataTransform(self.hparams.resolution, self.hparams.challenge, mask, use_seed=False) 155 | 156 | def val_data_transform(self): 157 | mask = 
create_mask_for_mask_type(self.hparams.mask_type, self.hparams.center_fractions, 158 | self.hparams.accelerations) 159 | return DataTransform(self.hparams.resolution, self.hparams.challenge, mask) 160 | 161 | def test_data_transform(self): 162 | return DataTransform(self.hparams.resolution, self.hparams.challenge) 163 | 164 | @staticmethod 165 | def add_model_specific_args(parser): 166 | parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers') 167 | parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability') 168 | parser.add_argument('--num-chans', type=int, default=32, help='Number of U-Net channels') 169 | parser.add_argument('--batch-size', default=16, type=int, help='Mini batch size') 170 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 171 | parser.add_argument('--lr-step-size', type=int, default=40, 172 | help='Period of learning rate decay') 173 | parser.add_argument('--lr-gamma', type=float, default=0.1, 174 | help='Multiplicative factor of learning rate decay') 175 | parser.add_argument('--weight-decay', type=float, default=0., 176 | help='Strength of weight decay regularization') 177 | parser.add_argument('--mask_type',default='random') 178 | return parser 179 | 180 | 181 | def create_trainer(args, logger): 182 | return Trainer( 183 | #num_nodes=1, 184 | logger=logger, 185 | default_save_path=args.exp_dir, 186 | checkpoint_callback=True, 187 | max_nb_epochs=args.num_epochs, 188 | gpus=args.gpus, 189 | distributed_backend='ddp', 190 | check_val_every_n_epoch=1, 191 | val_check_interval=1., 192 | early_stop_callback=False 193 | ) 194 | 195 | 196 | def main(args): 197 | if args.mode == 'train': 198 | load_version = 0 if args.resume else None 199 | logger = TestTubeLogger(save_dir=args.exp_dir, name=args.exp, version=load_version) 200 | trainer = create_trainer(args, logger) 201 | model = UnetMRIModel(args) 202 | trainer.fit(model) 203 | else: # args.mode == 'test' 204 | assert args.checkpoint is not None 205 | model = UnetMRIModel.load_from_checkpoint(str(args.checkpoint)) 206 | model.hparams.sample_rate = 1. 207 | trainer = create_trainer(args, logger=False) 208 | trainer.test(model) 209 | 210 | 211 | if __name__ == '__main__': 212 | parser = Args() 213 | parser.add_argument('--mode', choices=['train', 'test'], default='train') 214 | parser.add_argument('--num-epochs', type=int, default=50, help='Number of training epochs') 215 | parser.add_argument('--gpus', type=int, default=1) 216 | parser.add_argument('--exp-dir', type=pathlib.Path, default='experiments', 217 | help='Path where model and results should be saved') 218 | parser.add_argument('--exp', type=str, help='Name of the experiment') 219 | parser.add_argument('--checkpoint', type=pathlib.Path, 220 | help='Path to pre-trained model. Use with --mode test') 221 | parser.add_argument('--resume', action='store_true', 222 | help='If set, resume the training from a previous model checkpoint. ') 223 | parser = UnetMRIModel.add_model_specific_args(parser) 224 | args = parser.parse_args() 225 | random.seed(args.seed) 226 | np.random.seed(args.seed) 227 | torch.manual_seed(args.seed) 228 | main(args) 229 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/unet_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | """ 15 | A Convolutional Block that consists of two convolution layers each followed by 16 | instance normalization, LeakyReLU activation and dropout. 17 | """ 18 | 19 | def __init__(self, in_chans, out_chans, drop_prob): 20 | """ 21 | Args: 22 | in_chans (int): Number of channels in the input. 23 | out_chans (int): Number of channels in the output. 24 | drop_prob (float): Dropout probability. 25 | """ 26 | super().__init__() 27 | 28 | self.in_chans = in_chans 29 | self.out_chans = out_chans 30 | self.drop_prob = drop_prob 31 | 32 | self.layers = nn.Sequential( 33 | nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), 34 | nn.InstanceNorm2d(out_chans), 35 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 36 | nn.Dropout2d(drop_prob), 37 | nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), 38 | nn.InstanceNorm2d(out_chans), 39 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 40 | nn.Dropout2d(drop_prob) 41 | ) 42 | 43 | def forward(self, input): 44 | """ 45 | Args: 46 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] 47 | 48 | Returns: 49 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] 50 | """ 51 | return self.layers(input) 52 | 53 | def __repr__(self): 54 | return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \ 55 | f'drop_prob={self.drop_prob})' 56 | 57 | 58 | class TransposeConvBlock(nn.Module): 59 | """ 60 | A Transpose Convolutional Block that consists of one convolution transpose layers followed by 61 | instance normalization and LeakyReLU activation. 62 | """ 63 | 64 | def __init__(self, in_chans, out_chans): 65 | """ 66 | Args: 67 | in_chans (int): Number of channels in the input. 68 | out_chans (int): Number of channels in the output. 69 | """ 70 | super().__init__() 71 | 72 | self.in_chans = in_chans 73 | self.out_chans = out_chans 74 | 75 | self.layers = nn.Sequential( 76 | nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False), 77 | nn.InstanceNorm2d(out_chans), 78 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 79 | ) 80 | 81 | def forward(self, input): 82 | """ 83 | Args: 84 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] 85 | 86 | Returns: 87 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] 88 | """ 89 | return self.layers(input) 90 | 91 | def __repr__(self): 92 | return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans})' 93 | 94 | 95 | class UnetModel(nn.Module): 96 | """ 97 | PyTorch implementation of a U-Net model. 98 | 99 | This is based on: 100 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks 101 | for biomedical image segmentation. In International Conference on Medical image 102 | computing and computer-assisted intervention, pages 234–241. Springer, 2015. 103 | """ 104 | 105 | def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob): 106 | """ 107 | Args: 108 | in_chans (int): Number of channels in the input to the U-Net model. 109 | out_chans (int): Number of channels in the output to the U-Net model. 110 | chans (int): Number of output channels of the first convolution layer. 
111 | num_pool_layers (int): Number of down-sampling and up-sampling layers. 112 | drop_prob (float): Dropout probability. 113 | """ 114 | super().__init__() 115 | 116 | self.in_chans = in_chans 117 | self.out_chans = out_chans 118 | self.chans = chans 119 | self.num_pool_layers = num_pool_layers 120 | self.drop_prob = drop_prob 121 | 122 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) 123 | ch = chans 124 | for i in range(num_pool_layers - 1): 125 | self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)] 126 | ch *= 2 127 | self.conv = ConvBlock(ch, ch * 2, drop_prob) 128 | 129 | self.up_conv = nn.ModuleList() 130 | self.up_transpose_conv = nn.ModuleList() 131 | for i in range(num_pool_layers - 1): 132 | self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)] 133 | self.up_conv += [ConvBlock(ch * 2, ch, drop_prob)] 134 | ch //= 2 135 | 136 | self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)] 137 | self.up_conv += [ 138 | nn.Sequential( 139 | ConvBlock(ch * 2, ch, drop_prob), 140 | nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), 141 | )] 142 | 143 | def forward(self, input): 144 | """ 145 | Args: 146 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] 147 | 148 | Returns: 149 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] 150 | """ 151 | stack = [] 152 | output = input 153 | 154 | # Apply down-sampling layers 155 | for i, layer in enumerate(self.down_sample_layers): 156 | output = layer(output) 157 | stack.append(output) 158 | output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) 159 | 160 | output = self.conv(output) 161 | 162 | # Apply up-sampling layers 163 | for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): 164 | downsample_layer = stack.pop() 165 | output = transpose_conv(output) 166 | 167 | # Reflect pad on the right/botton if needed to handle odd input dimensions. 168 | padding = [0, 0, 0, 0] 169 | if output.shape[-1] != downsample_layer.shape[-1]: 170 | padding[1] = 1 # Padding right 171 | if output.shape[-2] != downsample_layer.shape[-2]: 172 | padding[3] = 1 # Padding bottom 173 | if sum(padding) != 0: 174 | output = F.pad(output, padding, "reflect") 175 | 176 | output = torch.cat([output, downsample_layer], dim=1) 177 | output = conv(output) 178 | 179 | return output 180 | -------------------------------------------------------------------------------- /DIP_UNET_models/unet_and_tv/uuu.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/uuu.zip -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvDecoder 2 | 3 | Check out our colab-demo for a quick example on how the decoder works for multi-coil accelerated MRI reconstruction: 4 | 5 | [![Explore ConvDecoder in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1xu_NS6ClikkOM1TTPL7EDqOjQZCvCvlL#offline=true&sandboxMode=true)
6 | 7 | 8 |

9 | 10 |
11 | This repository provides code for reproducing the results in the paper: 12 | 13 | **''Accelerated MRI with Un-trained Neural Networks''** by Mohammad Zalbagi Darestani and Reinhard Heckel 14 | 15 | Code by: Mohammad Zalbagi Darestani (mz35@rice.edu) and Reinhard Heckel (rh43@rice.edu) 16 | *** 17 | 18 | The aim of the code is to investigate the capability of different un-trained methods, including our proposed ConvDecoder, for the MRI acceleration problem. The task is to recover a high-quality image from only a few measurements. We provide experiments to: 19 | 20 | (i) compare ConvDecoder with U-net, a popular trained baseline for medical imaging, on the FastMRI validation set (**ConvDecoder_vs_Unet_multicoil.ipynb**), 21 | 22 | (ii) compare ConvDecoder with Deep Decoder and Deep Image Prior, two popular un-trained methods for standard inverse problems, again on the FastMRI dataset (**ConvDecoder_vs_DIP_vs_DD_multicoil.ipynb**), 23 | 24 | (iii) compare ConvDecoder with U-net on an out-of-distribution sample to demonstrate the robustness of un-trained methods to a distribution shift at inference time (**robustness_to_distribution_shift.ipynb**), 25 | 26 | (iv) and finally, visualize the outputs of the ConvDecoder layers to illustrate how ConvDecoder, as a convolutional generator, builds up a detailed representation of an image (**visualize_layers_singlecoil.ipynb**). 27 | 28 | ### List of contents 29 | * [Setup and installation](#Setup-and-installation)
30 | * [Dataset](#Dataset)
31 | * [Running the code](#Running-the-code)
32 | * [References](#References)
33 | * [License](#License) 34 | *** 35 | 36 | # Setup and installation 37 | On a normal computer, it takes approximately 10 minutes to install all the required software and packages. 38 | 39 | ### OS requirements 40 | The code has been tested on the following operating system: 41 | 42 | Linux: Ubuntu 16.04.5 43 | 44 | ### Python dependencies 45 | To reproduce the results by running each of the Jupyter notebooks, the following software is required. Assuming the experiments are run in a Docker container or on a Linux machine, the libraries and packages below need to be installed. 46 | 47 | apt-get update 48 | apt-get install python3.6 # --> or any other system-specific command for installing python3 on your system. 49 | pip install jupyter 50 | pip install numpy 51 | pip install matplotlib 52 | pip install sigpy 53 | pip install h5py 54 | pip install scikit-image 55 | pip install runstats 56 | pip install pytorch_msssim 57 | pip install pytorch-lightning==0.7.5 58 | pip install test-tube 59 | pip install Pillow 60 | 61 | If pip does not come with the version of Python you installed, install pip manually from [here](https://ehmatthes.github.io/pcc/chapter_12/installing_pip.html). Also, install PyTorch from [here](https://pytorch.org/) according to your system specifications. 62 | 63 | # Dataset 64 | All the experiments are performed on the [FastMRI](https://fastmri.org/dataset) dataset, except for the experiment measuring robustness to out-of-distribution samples, which is performed on the cameraman test image. 65 | 66 | # Running the code 67 | You may simply clone this repository and run each notebook to reproduce the results. **Note** that if you intend to run the code on MRI data, you need to download the [FastMRI](https://fastmri.org/dataset) dataset and change the **data path** (where the measurements are loaded) in each notebook accordingly (for MRI data, all of our experiments are performed on the validation sets, either single-coil or multi-coil). 68 | 69 | # References 70 | Code for training the U-net is taken from [here](https://github.com/facebookresearch/fastMRI/tree/master/models/unet).
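The repository also ships a trained U-net checkpoint, **UNET_trained/epoch=49.ckpt**, in the PyTorch-Lightning format pinned in the setup section. If you want to peek at it outside the notebooks (which otherwise handle model loading for you), a minimal sketch, assuming only that PyTorch is installed and that you run it from the repository root, is:

```python
import torch

# A Lightning checkpoint is a plain Python dict saved with torch.save;
# its 'state_dict' entry holds the trained U-net weights.
ckpt = torch.load("UNET_trained/epoch=49.ckpt", map_location="cpu")
print(list(ckpt.keys()))
print(len(ckpt["state_dict"]), "weight tensors in the U-net state dict")
```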
71 | Code for the Deep Decoder and Deep Image Prior architectures is taken from [repo1](https://github.com/reinhardh/supplement_deep_decoder) and [repo2](https://github.com/DmitryUlyanov/deep-image-prior), respectively. 72 | 73 | # License 74 | This project is covered by the **Apache 2.0 License**. 75 | 76 | ## Citation 77 | If you find our work useful in your research, please cite: 78 | ``` 79 | @article{darestani2021accelerated, 80 | author = {Zalbagi Darestani, Mohammad and Heckel, Reinhard}, 81 | title = {Accelerated MRI with Un-trained Neural Networks}, 82 | journal = {IEEE Transactions on Computational Imaging}, 83 | volume={7}, 84 | pages={724--733}, 85 | year={2021} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /UNET_trained/epoch=49.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/UNET_trained/epoch=49.ckpt -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /common/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/args.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/args.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/args.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/args.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/evaluate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/evaluate.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/evaluate.cpython-37.pyc: --------------------------------------------------------------------------------
/common/__pycache__/subsample.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/subsample.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/subsample.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/subsample.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/common/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /common/args.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | 11 | 12 | class Args(argparse.ArgumentParser): 13 | """ 14 | Defines global default arguments. 15 | """ 16 | 17 | def __init__(self, **overrides): 18 | """ 19 | Args: 20 | **overrides (dict, optional): Keyword arguments used to override default argument values 21 | """ 22 | 23 | super().__init__(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | 25 | self.add_argument('--seed', default=42, type=int, help='Seed for random number generators') 26 | self.add_argument('--resolution', default=320, type=int, help='Resolution of images') 27 | 28 | # Data parameters 29 | self.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True, 30 | help='Which challenge') 31 | self.add_argument('--data-path', type=pathlib.Path, required=True, 32 | help='Path to the dataset') 33 | self.add_argument('--sample-rate', type=float, default=1., 34 | help='Fraction of total volumes to include') 35 | 36 | # Mask parameters 37 | self.add_argument('--accelerations', nargs='+', default=[4, 8], type=int, 38 | help='Ratio of k-space columns to be sampled. If multiple values are ' 39 | 'provided, then one of those is chosen uniformly at random for ' 40 | 'each volume.') 41 | self.add_argument('--center-fractions', nargs='+', default=[0.08, 0.04], type=float, 42 | help='Fraction of low-frequency k-space columns to be sampled. Should ' 43 | 'have the same length as accelerations') 44 | 45 | # Override defaults with passed overrides 46 | self.set_defaults(**overrides) 47 | -------------------------------------------------------------------------------- /common/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | from argparse import ArgumentParser 11 | 12 | import h5py 13 | import scipy 14 | import numpy as np 15 | from runstats import Statistics 16 | from skimage.measure import compare_psnr, compare_ssim 17 | 18 | 19 | def mse(gt, pred): 20 | """ Compute Mean Squared Error (MSE) """ 21 | return np.mean((gt - pred) ** 2) 22 | 23 | 24 | def nmse(gt, pred): 25 | """ Compute Normalized Mean Squared Error (NMSE) """ 26 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2 27 | 28 | 29 | def psnr(gt, pred): 30 | """ Compute Peak Signal to Noise Ratio metric (PSNR) """ 31 | return compare_psnr(gt, pred, data_range=gt.max()) 32 | 33 | 34 | def ssim(gt, pred): 35 | """ Compute Structural Similarity Index Metric (SSIM). """ 36 | return compare_ssim( 37 | gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max() 38 | ) 39 | 40 | def vifp_mscale(ref, dist,sigma_nsq=1,eps=1e-10): 41 | ### from https://github.com/aizvorski/video-quality/blob/master/vifp.py 42 | sigma_nsq = sigma_nsq ### tune this for your dataset to get reasonable numbers 43 | eps = eps 44 | 45 | num = 0.0 46 | den = 0.0 47 | for scale in range(1, 5): 48 | 49 | N = 2**(4-scale+1) + 1 50 | sd = N/5.0 51 | 52 | if (scale > 1): 53 | ref = scipy.ndimage.gaussian_filter(ref, sd) 54 | dist = scipy.ndimage.gaussian_filter(dist, sd) 55 | ref = ref[::2, ::2] 56 | dist = dist[::2, ::2] 57 | 58 | mu1 = scipy.ndimage.gaussian_filter(ref, sd) 59 | mu2 = scipy.ndimage.gaussian_filter(dist, sd) 60 | mu1_sq = mu1 * mu1 61 | mu2_sq = mu2 * mu2 62 | mu1_mu2 = mu1 * mu2 63 | sigma1_sq = scipy.ndimage.gaussian_filter(ref * ref, sd) - mu1_sq 64 | sigma2_sq = scipy.ndimage.gaussian_filter(dist * dist, sd) - mu2_sq 65 | sigma12 = scipy.ndimage.gaussian_filter(ref * dist, sd) - mu1_mu2 66 | 67 | sigma1_sq[sigma1_sq<0] = 0 68 | sigma2_sq[sigma2_sq<0] = 0 69 | 70 | g = sigma12 / (sigma1_sq + eps) 71 | sv_sq = sigma2_sq - g * sigma12 72 | 73 | g[sigma1_sq desired_factor + tolerance: 148 | mask_func = MaskFunc(center_fractions=[cent], accelerations=[desired_factor]) # Create the mask function object 149 | masked_kspace, mask = apply_mask(slice_ksp_torchtensor, mask_func=mask_func) # Apply the mask to k-space 150 | mask1d = var_to_np(mask)[0,:,0] 151 | undersampling_factor = len(mask1d) / sum(mask1d) 152 | 153 | mask1d = var_to_np(mask)[0,:,0] 154 | 155 | # The provided mask and data have last dim of 368, but the actual data is smaller. 156 | # To prevent forcing the network to learn outside the data region, we force the mask to 0 there. 157 | mask1d[:mask1d.shape[-1]//2-160] = 0 158 | mask1d[mask1d.shape[-1]//2+160:] =0 159 | mask2d = np.repeat(mask1d[None,:], slice_ksp.shape[1], axis=0).astype(int) # Turning 1D Mask into 2D that matches data dimensions 160 | mask2d = np.pad(mask2d,((0,),((slice_ksp.shape[-1]-mask2d.shape[-1])//2,)),mode='constant') # Zero padding to make sure dimensions match up 161 | mask = to_tensor( np.array( [[mask2d[0][np.newaxis].T]] ) ).type(dtype).detach().cpu() 162 | return mask, mask1d, mask2d 163 | 164 | def apply_mask(data, mask_func = None, mask = None, seed=None): 165 | """ 166 | ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri 167 | Subsample given k-space by multiplying with a mask. 168 | 169 | Args: 170 | data (torch.Tensor): The input k-space data. 
This should have at least 3 dimensions, where 171 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size 172 | 2 (for complex values). 173 | mask_func (callable): A function that takes a shape (tuple of ints) and a random 174 | number seed and returns a mask. 175 | seed (int or 1-d array_like, optional): Seed for the random number generator. 176 | 177 | Returns: 178 | (tuple): tuple containing: 179 | masked data (torch.Tensor): Subsampled k-space data 180 | mask (torch.Tensor): The generated mask 181 | """ 182 | shape = np.array(data.shape) 183 | shape[:-3] = 1 184 | if mask is None: 185 | mask = mask_func(shape, seed) 186 | return data * mask, mask 187 | 188 | def fft(input, signal_ndim, normalized=False): 189 | # This function is called from the fft2 function below 190 | if signal_ndim < 1 or signal_ndim > 3: 191 | print("Signal ndim out of range, was", signal_ndim, "but expected a value between 1 and 3, inclusive") 192 | return 193 | 194 | dims = (-1) 195 | if signal_ndim == 2: 196 | dims = (-2, -1) 197 | if signal_ndim == 3: 198 | dims = (-3, -2, -1) 199 | 200 | norm = "backward" 201 | if normalized: 202 | norm = "ortho" 203 | 204 | return torch.view_as_real(torch.fft.fftn(torch.view_as_complex(input), dim=dims, norm=norm)) 205 | 206 | def ifft(input, signal_ndim, normalized=False): 207 | # This function is called from the ifft2 function below 208 | if signal_ndim < 1 or signal_ndim > 3: 209 | print("Signal ndim out of range, was", signal_ndim, "but expected a value between 1 and 3, inclusive") 210 | return 211 | 212 | dims = (-1) 213 | if signal_ndim == 2: 214 | dims = (-2, -1) 215 | if signal_ndim == 3: 216 | dims = (-3, -2, -1) 217 | 218 | norm = "backward" 219 | if normalized: 220 | norm = "ortho" 221 | 222 | return torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(input), dim=dims, norm=norm)) 223 | 224 | def fft2(data): 225 | """ 226 | ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri 227 | Apply centered 2 dimensional Fast Fourier Transform. It calls the fft function above to make it compatible with the latest version of pytorch. 228 | 229 | Args: 230 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 231 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 232 | assumed to be batch dimensions. 233 | 234 | Returns: 235 | torch.Tensor: The FFT of the input. 236 | """ 237 | assert data.size(-1) == 2 238 | data = ifftshift(data, dim=(-3, -2)) 239 | data = fft(data, 2, normalized=True) 240 | data = fftshift(data, dim=(-3, -2)) 241 | return data 242 | 243 | 244 | def ifft2(data): 245 | """ 246 | ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri 247 | Apply centered 2-dimensional Inverse Fast Fourier Transform. It calls the ifft function above to make it compatible with the latest version of pytorch. 248 | 249 | Args: 250 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 251 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 252 | assumed to be batch dimensions. 253 | 254 | Returns: 255 | torch.Tensor: The IFFT of the input. 
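    Example (illustrative): for a centered k-space tensor `kspace` whose last dimension has size 2,
    `image = complex_abs(ifft2(kspace))` yields the corresponding magnitude image.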
256 | """ 257 | assert data.size(-1) == 2 258 | data = ifftshift(data, dim=(-3, -2)) 259 | data = ifft(data, 2, normalized=True) 260 | data = fftshift(data, dim=(-3, -2)) 261 | return data 262 | 263 | 264 | def complex_abs(data): 265 | """ 266 | ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri 267 | Compute the absolute value of a complex valued input tensor. 268 | 269 | Args: 270 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension 271 | should be 2. 272 | 273 | Returns: 274 | torch.Tensor: Absolute value of data 275 | """ 276 | assert data.size(-1) == 2 277 | return (data ** 2).sum(dim=-1).sqrt() 278 | 279 | def fftshift(x, dim=None): 280 | """ 281 | ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri 282 | Similar to np.fft.fftshift but applies to PyTorch Tensors 283 | """ 284 | if dim is None: 285 | dim = tuple(range(x.dim())) 286 | shift = [dim // 2 for dim in x.shape] 287 | elif isinstance(dim, int): 288 | shift = x.shape[dim] // 2 289 | else: 290 | shift = [x.shape[i] // 2 for i in dim] 291 | return roll(x, shift, dim) 292 | 293 | 294 | def ifftshift(x, dim=None): 295 | """ 296 | ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri 297 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 298 | """ 299 | if dim is None: 300 | dim = tuple(range(x.dim())) 301 | shift = [(dim + 1) // 2 for dim in x.shape] 302 | elif isinstance(dim, int): 303 | shift = (x.shape[dim] + 1) // 2 304 | else: 305 | shift = [(x.shape[i] + 1) // 2 for i in dim] 306 | return roll(x, shift, dim) 307 | 308 | def roll(x, shift, dim): 309 | """ 310 | ref: https://github.com/facebookresearch/fastMRI/tree/master/fastmri 311 | Similar to np.roll but applies to PyTorch Tensors 312 | """ 313 | if isinstance(shift, (tuple, list)): 314 | assert len(shift) == len(dim) 315 | for s, d in zip(shift, dim): 316 | x = roll(x, s, d) 317 | return x 318 | shift = shift % x.size(dim) 319 | if shift == 0: 320 | return x 321 | left = x.narrow(dim, 0, x.size(dim) - shift) 322 | right = x.narrow(dim, x.size(dim) - shift, shift) 323 | return torch.cat((right, left), dim=dim) 324 | -------------------------------------------------------------------------------- /include/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | from .decoder_parallel_conv import * 3 | from .decoder_conv import * 4 | from .decoder_skip import * 5 | from .fit import * 6 | from .helpers import * 7 | from .mri_helpers import * -------------------------------------------------------------------------------- /include/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/compression.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/compression.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_conv.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_conv.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_parallel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_parallel.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_parallel2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_parallel2.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_parallel_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_parallel_conv.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_parallel_conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_parallel_conv.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_skip.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_skip.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_skip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_skip.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder_skip2.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/decoder_skip2.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/fit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/fit.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/fit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/fit.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/helpers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/helpers.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/helpers.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/mri_helpers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/mri_helpers.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/mri_helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/mri_helpers.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/visualize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/visualize.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/wavelet.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/__pycache__/wavelet.cpython-36.pyc -------------------------------------------------------------------------------- /include/decoder_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from copy import copy 5 | 6 | def add_module(self, module): 7 | self.add_module(str(len(self) + 1), module) 8 | 9 | torch.nn.Module.add = add_module 10 | 11 | class conv_model(nn.Module): 12 | def __init__(self, num_layers, strides, num_channels, num_output_channels, hidden_size, upsample_mode, act_fun,sig=None, bn_affine=True, skips=False,intermeds=None,bias=False,need_lin_comb=False,need_last=False,kernel_size=3): 13 | super(conv_model, self).__init__() 14 | 15 | self.num_layers = num_layers 16 | self.hidden_size = hidden_size 17 | self.upsample_mode = upsample_mode 18 | self.act_fun = act_fun 19 | self.sig= sig 20 | self.skips = skips 21 | self.intermeds = intermeds 22 | self.layer_inds = [] # record index of the layers that generate output in the sequential mode (after each BatchNorm) 23 | self.combinations = None # this holds input of the last layer which is upsampled versions of previous layers 24 | 25 | cntr = 1 26 | net1 = nn.Sequential() 27 | for i in range(num_layers-1): 28 | 29 | net1.add(nn.Upsample(size=hidden_size[i], mode=upsample_mode))#,align_corners=True)) 30 | cntr += 1 31 | 32 | conv = nn.Conv2d(num_channels, num_channels, kernel_size, strides[i], padding=(kernel_size-1)//2, bias=bias) 33 | net1.add(conv) 34 | cntr += 1 35 | 36 | #net1.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 37 | net1.add(act_fun) 38 | cntr += 1 39 | 40 | if need_lin_comb: 41 | net1.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 42 | #net1.add(act_fun) 43 | cntr += 1 44 | 45 | net1.add(nn.Conv2d(num_channels, num_channels, 1, 1, padding=0, bias=bias)) 46 | cntr += 1 47 | 48 | #net1.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 49 | net1.add(act_fun) 50 | cntr += 1 51 | 52 | #net1.add(act_fun) 53 | net1.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 54 | if i != num_layers - 2: # penultimate layer will automatically be concatenated if skip connection option is chosen 55 | self.layer_inds.append(cntr) 56 | cntr += 1 57 | 58 | net2 = nn.Sequential() 59 | 60 | nic = num_channels 61 | if skips: 62 | nic = num_channels*( sum(intermeds)+1 ) 63 | 64 | if need_last: 65 | net2.add( nn.Conv2d(nic, num_channels, kernel_size, strides[i], padding=(kernel_size-1)//2, bias=bias) ) 66 | net2.add(act_fun) 67 | net2.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 68 | nic = num_channels 69 | 70 | net2.add(nn.Conv2d(nic, num_output_channels, 1, 1, padding=0, bias=bias)) 71 | 72 | if sig is not None: 73 | net2.add(self.sig) 74 | 75 | self.net1 = net1 76 | self.net2 = net2 77 | 78 | def forward(self, x, scale_out=1): 79 | out1 = self.net1(x) 80 | if self.skips: 81 | intermed_outs = [] 82 | for i,c in enumerate(self.net1): 83 | if i+1 in self.layer_inds: 84 | f = self.net1[:i+1] 85 | intermed_outs.append(f(x)) 86 | intermed_outs = [intermed_outs[i] for i in range(len(intermed_outs)) if self.intermeds[i]] 87 | intermed_outs = [self.up_sample(io) for io in intermed_outs] 88 | out1 = torch.cat(intermed_outs+[out1],1) 89 | self.combinations = copy(out1) 90 | out2 = self.net2(out1) 91 | return out2*scale_out 92 | def up_sample(self,img): 93 | samp_block = nn.Upsample(size=self.hidden_size[-1], 
mode=self.upsample_mode)#,align_corners=True) 94 | img = samp_block(img) 95 | return img 96 | 97 | def convdecoder( 98 | out_size = [256,256], 99 | in_size = [16,16], 100 | num_output_channels=3, 101 | num_layers=6, 102 | strides=[1]*6, 103 | num_channels=64, 104 | need_sigmoid=True, 105 | pad='reflection', 106 | upsample_mode='bilinear', 107 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 108 | bn_before_act = False, 109 | bn_affine = True, 110 | skips = True, 111 | intermeds=None, 112 | nonlin_scales=False, 113 | bias=False, 114 | need_lin_comb=False, 115 | need_last=False, 116 | kernel_size=3, 117 | ): 118 | 119 | 120 | scale_x,scale_y = (out_size[0]/in_size[0])**(1./(num_layers-1)), (out_size[1]/in_size[1])**(1./(num_layers-1)) 121 | if nonlin_scales: 122 | xscales = np.ceil( np.linspace(scale_x * in_size[0],out_size[0],num_layers-1) ) 123 | yscales = np.ceil( np.linspace(scale_y * in_size[1],out_size[1],num_layers-1) ) 124 | hidden_size = [(int(x),int(y)) for (x,y) in zip(xscales,yscales)] 125 | else: 126 | hidden_size = [(int(np.ceil(scale_x**n * in_size[0])), 127 | int(np.ceil(scale_y**n * in_size[1]))) for n in range(1, (num_layers-1))] + [out_size] 128 | print(hidden_size) 129 | if need_sigmoid: 130 | sig = nn.Sigmoid() 131 | #sig = nn.Tanh() 132 | #sig = nn.Softmax() 133 | else: 134 | sig = None 135 | 136 | model = conv_model(num_layers, strides, num_channels, num_output_channels, hidden_size, 137 | upsample_mode=upsample_mode, 138 | act_fun=act_fun, 139 | sig=sig, 140 | bn_affine=bn_affine, 141 | skips=skips, 142 | intermeds=intermeds, 143 | bias=bias, 144 | need_lin_comb=need_lin_comb, 145 | need_last = need_last, 146 | kernel_size=kernel_size,) 147 | return model -------------------------------------------------------------------------------- /include/decoder_parallel_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def add_module(self, module): 6 | self.add_module(str(len(self) + 1), module) 7 | 8 | torch.nn.Module.add = add_module 9 | 10 | class catc_model(nn.Module): 11 | def __init__(self, decoders_numlayers_list,decoders_last_channels,num_channels, num_output_channels, 12 | upsample_mode,act_fun,hidden_size,sig=None,bn_affine=True,bias=True,need_lin_comb=False, 13 | need_last=False,kernel_size=[3]*3): 14 | super(catc_model, self).__init__() 15 | 16 | self.sig = sig 17 | nets = [] 18 | M = max(decoders_numlayers_list) 19 | for n,num_layers in enumerate(decoders_numlayers_list): 20 | nc = num_channels 21 | net = nn.Sequential() 22 | for i in range(num_layers-1): 23 | net.add(nn.Upsample(size=hidden_size[n][i], mode=upsample_mode)) 24 | net.add(nn.Conv2d(num_channels, nc, kernel_size[n], 1, padding=(kernel_size[n]-1)//2, bias=bias)) 25 | net.add(nn.BatchNorm2d( nc, affine=bn_affine)) 26 | net.add(act_fun) 27 | 28 | if need_lin_comb: 29 | temp = nn.Sequential() 30 | temp.add(nn.Conv2d(num_channels, num_channels, 1, 1, padding=0, bias=bias)) 31 | temp.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 32 | temp.add(act_fun) 33 | net.add(temp) 34 | 35 | nc = num_channels 36 | if need_last: 37 | temp = nn.Sequential() 38 | temp.add( nn.Conv2d(nc, decoders_last_channels[n], 1, 1, padding=0, bias=bias) ) 39 | temp.add(nn.BatchNorm2d( decoders_last_channels[n], affine=bn_affine)) 40 | temp.add(act_fun) 41 | net.add(temp) 42 | nc = decoders_last_channels[n] 43 | net.add(nn.Conv2d(nc, decoders_last_channels[n], 1, 1, padding=0, bias=bias)) 44 | if self.sig is 
not None: 45 | net.add(self.sig) 46 | 47 | nets.append(net) 48 | del(net) 49 | 50 | self.net1 = nets[0] 51 | self.net2 = nets[1] 52 | self.net3 = nets[2] 53 | 54 | net4 = nn.Sequential() 55 | nc = sum(decoders_last_channels) 56 | if need_last: 57 | net4.add(nn.Conv2d(nc,num_output_channels,1,1,padding=0,bias=bias)) 58 | net4.add(act_fun) 59 | net4.add(nn.BatchNorm2d( num_output_channels, affine=bn_affine)) 60 | nc = num_output_channels 61 | net4.add(nn.Conv2d(nc,num_output_channels,1,1,padding=0,bias=bias)) 62 | self.net4 = net4 63 | 64 | def forward(self,x,scale_out=1): 65 | out1 = self.net1(x) 66 | out2 = self.net2(x) 67 | out3 = self.net3(x) 68 | 69 | last_inp = torch.cat([out1,out2,out3],1) 70 | out = self.net4(last_inp) 71 | if self.sig is not None: 72 | out = self.sig(out) 73 | return out*scale_out 74 | def parcdecoder(out_size = [256,256], 75 | in_size = [16,16], 76 | num_output_channels=3, 77 | num_channels=128, 78 | decoders_numlayers_list = [2,4,6], # (ascending order) determines the number of layers per each decoder in the parallel structure 79 | decoders_last_channels = [20,20,20], # last layer channel contribution of each decoder 80 | need_sigmoid=True, 81 | upsample_mode='bilinear', 82 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 83 | bn_affine = True, 84 | nonlin_scales=False, 85 | bias=True, 86 | kernel_size=[3]*3, 87 | need_lin_comb=True, 88 | need_last=True, 89 | ): 90 | 91 | hidden_size = [] 92 | for num_layers in decoders_numlayers_list: 93 | scale_x,scale_y = (out_size[0]/in_size[0])**(1./(num_layers-1)), (out_size[1]/in_size[1])**(1./(num_layers-1)) 94 | if nonlin_scales: 95 | xscales = np.ceil( np.linspace(scale_x * in_size[0],out_size[0],num_layers-1) ) 96 | yscales = np.ceil( np.linspace(scale_y * in_size[1],out_size[1],num_layers-1) ) 97 | h_s = [(int(x),int(y)) for (x,y) in zip(xscales,yscales)] 98 | else: 99 | h_s = [(int(np.ceil(scale_x**n * in_size[0])), 100 | int(np.ceil(scale_y**n * in_size[1]))) for n in range(1, (num_layers-1))] + [out_size] 101 | hidden_size.append(h_s) 102 | print(hidden_size) 103 | 104 | if need_sigmoid: 105 | sig = nn.Sigmoid() 106 | else: 107 | sig = None 108 | 109 | model = catc_model(decoders_numlayers_list, 110 | decoders_last_channels, 111 | num_channels, 112 | num_output_channels, 113 | upsample_mode, 114 | act_fun, 115 | hidden_size, 116 | sig = sig, 117 | bn_affine = bn_affine, 118 | bias=bias, 119 | kernel_size=kernel_size, 120 | need_lin_comb=need_lin_comb, 121 | need_last=need_last 122 | ) 123 | 124 | return model -------------------------------------------------------------------------------- /include/decoder_skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from copy import copy 5 | 6 | def add_module(self, module): 7 | self.add_module(str(len(self) + 1), module) 8 | 9 | torch.nn.Module.add = add_module 10 | 11 | class skip_model(nn.Module): 12 | def __init__(self, num_layers, num_channels, num_output_channels, hidden_size, upsample_mode, act_fun,sig=None, bn_affine=True, skips=False,need_pad=True,need_last=True): 13 | super(skip_model, self).__init__() 14 | 15 | self.num_layers = num_layers 16 | #self.upsamp = nn.Upsample(scale_factor=2, mode=upsample_mode) 17 | self.hidden_size = hidden_size 18 | self.upsample_mode = upsample_mode 19 | self.act_fun = act_fun 20 | self.sig= sig 21 | self.skips = skips 22 | self.layer_inds = [] # record index of the layers that generate output in the sequential mode 
(after each BatchNorm) 23 | self.combinations = None # this holds input of the last layer which is upsampled versions of previous layers 24 | 25 | cntr = 1 26 | net1 = nn.Sequential() 27 | for i in range(num_layers-1): 28 | 29 | if need_pad: 30 | net1.add(nn.ReflectionPad2d(0)) 31 | net1.add(nn.Conv2d(num_channels, num_channels, 1, 1, padding=0, bias=False)) 32 | cntr += 1 33 | 34 | net1.add(nn.Upsample(size=hidden_size[i], mode=upsample_mode,align_corners=True)) 35 | cntr += 1 36 | 37 | net1.add(act_fun) 38 | cntr += 1 39 | net1.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 40 | if i != num_layers - 2: 41 | self.layer_inds.append(cntr) 42 | cntr += 1 43 | 44 | net2 = nn.Sequential() 45 | nic = num_channels 46 | if skips: 47 | nic = num_channels*(num_layers-1) 48 | if need_last: 49 | net2.add( nn.Conv2d(nic, num_channels, 1, 1, padding=0, bias=False) ) 50 | net2.add(act_fun) 51 | net2.add(nn.BatchNorm2d( num_channels, affine=bn_affine)) 52 | nic = num_channels 53 | if need_pad: 54 | net2.add(nn.ReflectionPad2d(0)) 55 | net2.add(nn.Conv2d(nic, num_output_channels, 1, 1, padding=0, bias=False)) 56 | if sig is not None: 57 | net2.add(self.sig) 58 | 59 | self.net1 = net1 60 | self.net2 = net2 61 | 62 | def forward(self, x, scale_out=1): 63 | out1 = self.net1(x) 64 | if self.skips: 65 | intermed_outs = [] 66 | for i,c in enumerate(self.net1): 67 | if i+1 in self.layer_inds: 68 | f = self.net1[:i+1] 69 | intermed_outs.append(f(x)) 70 | 71 | intermed_outs = [self.up_sample(io,i+1) for i,io in enumerate(intermed_outs)] 72 | 73 | out1 = torch.cat(intermed_outs+[out1],1) 74 | self.combinations = copy(out1) 75 | out2 = self.net2(out1) 76 | return out2*scale_out 77 | def up_sample(self,img,layer_ind): 78 | if layer_ind != self.num_layers-1: 79 | samp_block = nn.Upsample(size=self.hidden_size[-1], mode=self.upsample_mode)#,align_corners=True) 80 | img = samp_block(img) 81 | return img 82 | 83 | def skipdecoder( 84 | out_size = [256,256], 85 | in_size = [16,16], 86 | num_output_channels=3, 87 | num_layers=6, 88 | num_channels=64, 89 | need_sigmoid=True, 90 | need_pad=False, 91 | pad='reflection', 92 | upsample_mode='bilinear', 93 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 94 | bn_before_act = False, 95 | bn_affine = True, 96 | skips = True, 97 | nonlin_scales=False, 98 | need_last=True, 99 | ): 100 | 101 | 102 | scale_x,scale_y = (out_size[0]/in_size[0])**(1./(num_layers-1)), (out_size[1]/in_size[1])**(1./(num_layers-1)) 103 | if nonlin_scales: 104 | xscales = np.ceil( np.linspace(scale_x * in_size[0],out_size[0],num_layers-1) ) 105 | yscales = np.ceil( np.linspace(scale_y * in_size[1],out_size[1],num_layers-1) ) 106 | hidden_size = [(int(x),int(y)) for (x,y) in zip(xscales,yscales)] 107 | else: 108 | hidden_size = [(int(np.ceil(scale_x**n * in_size[0])), 109 | int(np.ceil(scale_y**n * in_size[1]))) for n in range(1, (num_layers-1))] + [out_size] 110 | print(hidden_size) 111 | if need_sigmoid: 112 | sig = nn.Sigmoid() 113 | else: 114 | sig = None 115 | 116 | model = skip_model(num_layers, num_channels, num_output_channels, hidden_size, 117 | upsample_mode=upsample_mode, 118 | act_fun=act_fun, 119 | sig=sig, 120 | bn_affine=bn_affine, 121 | skips=skips, 122 | need_pad=need_pad, 123 | need_last=need_last,) 124 | return model -------------------------------------------------------------------------------- /include/fit.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.optim 4 | 
import copy 5 | import numpy as np 6 | from scipy.linalg import hadamard 7 | from skimage.metrics import structural_similarity as ssim 8 | 9 | from .helpers import * 10 | from .mri_helpers import * 11 | from .transforms import * 12 | 13 | dtype = torch.cuda.FloatTensor 14 | #dtype = torch.FloatTensor 15 | 16 | 17 | def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=500): 18 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 19 | lr = init_lr * (0.65**(epoch // lr_decay_epoch)) 20 | 21 | if epoch % lr_decay_epoch == 0: 22 | print('LR is set to {}'.format(lr)) 23 | 24 | for param_group in optimizer.param_groups: 25 | param_group['lr'] = lr 26 | 27 | return optimizer 28 | 29 | def sqnorm(a): 30 | return np.sum( a*a ) 31 | 32 | def get_distances(initial_maps,final_maps): 33 | results = [] 34 | for a,b in zip(initial_maps,final_maps): 35 | res = sqnorm(a-b)/(sqnorm(a) + sqnorm(b)) 36 | results += [res] 37 | return(results) 38 | 39 | def get_weights(net): 40 | weights = [] 41 | for m in net.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | weights += [m.weight.data.cpu().numpy()] 44 | return weights 45 | 46 | class MSLELoss(torch.nn.Module): 47 | def __init__(self): 48 | super(MSLELoss,self).__init__() 49 | 50 | def forward(self,x,y): 51 | criterion = nn.MSELoss() 52 | loss = torch.log(criterion(x, y)) 53 | return loss 54 | 55 | def fit(net, 56 | img_noisy_var, 57 | num_channels, 58 | img_clean_var, 59 | num_iter = 5000, 60 | LR = 0.01, 61 | OPTIMIZER='adam', 62 | opt_input = False, 63 | reg_noise_std = 0, 64 | reg_noise_decayevery = 100000, 65 | mask_var = None, 66 | mask = None, 67 | apply_f = None, 68 | lr_decay_epoch = 0, 69 | net_input = None, 70 | net_input_gen = "random", 71 | lsimg = None, 72 | target_img = None, 73 | find_best=False, 74 | weight_decay=0, 75 | upsample_mode = "bilinear", 76 | totalupsample = 1, 77 | loss_type="MSE", 78 | output_gradients=False, 79 | output_weights=False, 80 | show_images=False, 81 | plot_after=None, 82 | in_size=None, 83 | retain_graph = False, 84 | scale_out=1, 85 | ): 86 | 87 | if net_input is not None: 88 | print("input provided") 89 | else: 90 | 91 | if upsample_mode=="bilinear": 92 | # feed uniform noise into the network 93 | totalupsample = 2**len(num_channels) 94 | width = int(img_clean_var.data.shape[2]/totalupsample) 95 | height = int(img_clean_var.data.shape[3]/totalupsample) 96 | elif upsample_mode=="deconv": 97 | # feed uniform noise into the network 98 | totalupsample = 2**(len(num_channels)-1) 99 | width = int(img_clean_var.data.shape[2]/totalupsample) 100 | height = int(img_clean_var.data.shape[3]/totalupsample) 101 | elif upsample_mode=="free": 102 | width,height = in_size 103 | 104 | 105 | shape = [1,num_channels[0], width, height] 106 | print("input shape: ", shape) 107 | net_input = Variable(torch.zeros(shape)).type(dtype) 108 | net_input.data.uniform_() 109 | net_input.data *= 1./10 110 | 111 | net_input = net_input.type(dtype) 112 | net_input_saved = net_input.data.clone() 113 | noise = net_input.data.clone() 114 | p = [x for x in net.parameters() ] 115 | 116 | if(opt_input == True): # optimizer over the input as well 117 | net_input.requires_grad = True 118 | p += [net_input] 119 | 120 | mse_wrt_noisy = np.zeros(num_iter) 121 | mse_wrt_truth = np.zeros(num_iter) 122 | 123 | 124 | if OPTIMIZER == 'SGD': 125 | print("optimize with SGD", LR) 126 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9,weight_decay=weight_decay) 127 | elif OPTIMIZER == 'adam': 128 | print("optimize with adam", LR) 
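# Note: the parameter list p assembled above holds the network weights (plus net_input when opt_input=True);
# the optimizer chosen here minimizes the measurement loss defined in closure() below over those parameters,
# and a nonzero weight_decay adds an L2 penalty on them.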
129 | optimizer = torch.optim.Adam(p, lr=LR,weight_decay=weight_decay) 130 | elif OPTIMIZER == 'LBFGS': 131 | print("optimize with LBFGS", LR) 132 | optimizer = torch.optim.LBFGS(p, lr=LR) 133 | elif OPTIMIZER == "adagrad": 134 | print("optimize with adagrad", LR) 135 | optimizer = torch.optim.Adagrad(p, lr=LR,weight_decay=weight_decay) 136 | 137 | if loss_type=="MSE": 138 | mse = torch.nn.MSELoss() 139 | if loss_type == "MSLE": 140 | mse = MSLELoss() 141 | if loss_type=="L1": 142 | mse = nn.L1Loss() 143 | 144 | if find_best: 145 | best_net = copy.deepcopy(net) 146 | best_mse = 1000000.0 147 | 148 | nconvnets = 0 149 | for p in list(filter(lambda p: len(p.data.shape)>2, net.parameters())): 150 | nconvnets += 1 151 | 152 | out_grads = np.zeros((nconvnets,num_iter)) 153 | 154 | init_weights = get_weights(net) 155 | out_weights = np.zeros(( len(init_weights) ,num_iter)) 156 | 157 | out_imgs = np.zeros((1,1)) 158 | 159 | if plot_after is not None: 160 | try: 161 | out_img_np = net( net_input_saved.type(dtype),scale_out=scale_out ).data.cpu().numpy()[0] 162 | except: 163 | out_img_np = net( net_input_saved.type(dtype) ).data.cpu().numpy()[0] 164 | out_imgs = np.zeros( (len(plot_after),) + out_img_np.shape ) 165 | 166 | PSNRs = [] 167 | SSIMs = [] 168 | norm_ratio = [] 169 | for i in range(num_iter): 170 | """if i<=300: 171 | for param_group in optimizer.param_groups: 172 | param_group['lr'] = LR*20 173 | else: 174 | for param_group in optimizer.param_groups: 175 | param_group['lr'] = LR 176 | """ 177 | if lr_decay_epoch is not 0: 178 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=lr_decay_epoch) 179 | if reg_noise_std > 0: 180 | if i % reg_noise_decayevery == 0: 181 | reg_noise_std *= 0.7 182 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 183 | 184 | 185 | def closure(): 186 | 187 | ### adjust scaling 188 | """if i <= num_iter: 189 | out = net(net_input.type(dtype),scale_out=1) 190 | out_chs = out.data.cpu().numpy()[0] 191 | out_imgs = channels2imgs(out_chs) 192 | orignorm = np.linalg.norm( root_sum_of_squares2(var_to_np(lsimg)) ) 193 | recnorm = np.linalg.norm( root_sum_of_squares2(out_imgs) ) 194 | scale_out = orignorm / recnorm 195 | ### 196 | if i == num_iter-1: 197 | print(scale_out) 198 | """ 199 | optimizer.zero_grad() 200 | try: 201 | out = net(net_input.type(dtype),scale_out=scale_out) 202 | except: 203 | out = net(net_input.type(dtype)) 204 | 205 | # training loss 206 | if mask_var is not None: 207 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 208 | elif apply_f: 209 | loss = mse( apply_f(out,mask) , img_noisy_var ) 210 | else: 211 | loss = mse(out, img_noisy_var) 212 | 213 | loss.backward(retain_graph=retain_graph) 214 | 215 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 216 | 217 | # the actual loss 218 | true_loss = mse( Variable(out.data, requires_grad=False).type(dtype), img_clean_var.type(dtype) ) 219 | mse_wrt_truth[i] = true_loss.data.cpu().numpy() 220 | 221 | if output_gradients: 222 | for ind,p in enumerate(list(filter(lambda p: p.grad is not None and len(p.data.shape)>2, net.parameters()))): 223 | out_grads[ind,i] = p.grad.data.norm(2).item() 224 | #print(p.grad.data.norm(2).item()) 225 | #su += p.grad.data.norm(2).item() 226 | #mse_wrt_noisy[i] = su 227 | 228 | if i % 100 == 0: 229 | if lsimg is not None: 230 | ### compute ssim and psnr ### 231 | out_chs = out.data.cpu().numpy()[0] 232 | out_imgs = channels2imgs(out_chs) 233 | # least squares reconstruciton 234 | orig = crop_center2( 
root_sum_of_squares2(var_to_np(lsimg)) , 320,320) 235 | 236 | # deep decoder reconstruction 237 | rec = crop_center2(root_sum_of_squares2(out_imgs),320,320) 238 | 239 | ssim_const = ssim(orig,rec,data_range=orig.max()) 240 | SSIMs.append(ssim_const) 241 | 242 | psnr_const = psnr(orig,rec,np.max(orig)) 243 | PSNRs.append(psnr_const) 244 | 245 | norm_ratio.append( np.linalg.norm(root_sum_of_squares2(out_imgs)) / np.linalg.norm(root_sum_of_squares2(var_to_np(lsimg))) ) 246 | ### ### 247 | 248 | trloss = loss.data 249 | true_loss = true_loss.data 250 | try: 251 | out2 = net(Variable(net_input_saved).type(dtype),scale_out=scale_out) 252 | except: 253 | out2 = net(Variable(net_input_saved).type(dtype)) 254 | loss2 = mse(out2, img_clean_var).data 255 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f' % (i, trloss,true_loss,loss2), '\r', end='') 256 | 257 | if show_images: 258 | if i % 50 == 0: 259 | print(i) 260 | try: 261 | out_img_np = net( ni.type(dtype),scale_out=scale_out ).data.cpu().numpy()[0] 262 | except: 263 | out_img_np = net( ni.type(dtype) ).data.cpu().numpy()[0] 264 | myimgshow(plt,out_img_np) 265 | plt.show() 266 | 267 | if plot_after is not None: 268 | if i in plot_after: 269 | try: 270 | out_imgs[ plot_after.index(i) ,:] = net( net_input_saved.type(dtype),scale_out=scale_out ).data.cpu().numpy()[0] 271 | except: 272 | out_imgs[ plot_after.index(i) ,:] = net( net_input_saved.type(dtype),scale_out=scale_out ).data.cpu().numpy()[0] 273 | if output_weights: 274 | out_weights[:,i] = np.array( get_distances( init_weights, get_weights(net) ) ) 275 | 276 | return loss 277 | 278 | loss = optimizer.step(closure) 279 | 280 | if find_best: 281 | # if training loss improves by at least one percent, we found a new best net 282 | lossval = loss.data 283 | if best_mse > 1.005*lossval: 284 | best_mse = lossval 285 | best_net = copy.deepcopy(net) 286 | if opt_input: 287 | best_ni = net_input.data.clone() 288 | else: 289 | best_ni = net_input_saved.clone() 290 | 291 | 292 | if find_best: 293 | net = best_net 294 | net_input_saved = best_ni 295 | if output_gradients and output_weights: 296 | return scale_out,SSIMs,PSNRs,norm_ratio,mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_grads 297 | elif output_gradients: 298 | return scale_out,SSIMs,PSNRs,norm_ratio,mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_grads 299 | elif output_weights: 300 | return scale_out,SSIMs,PSNRs,norm_ratio,mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_weights 301 | elif plot_after is not None: 302 | return scale_out,SSIMs,PSNRs,norm_ratio,mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_imgs 303 | else: 304 | return scale_out,SSIMs,PSNRs,norm_ratio,mse_wrt_noisy, mse_wrt_truth,net_input_saved, net 305 | 306 | 307 | 308 | 309 | 310 | 311 | ### weight regularization 312 | #if orth_reg > 0: 313 | # for name, param in net.named_parameters(): 314 | # consider all the conv weights, but the last one which only combines colors 315 | # if '.1.weight' in name and str( len(net)-1 ) not in name: 316 | # param_flat = param.view(param.shape[0], -1) 317 | # sym = torch.mm(param_flat, torch.t(param_flat)) 318 | # sym -= Variable(torch.eye(param_flat.shape[0])).type(dtype) 319 | # loss = loss + (orth_reg * sym.sum().type(dtype) ) 320 | ### 321 | 322 | def fit_multiple(net, 323 | imgs, # list of images [ [1, color channels, W, H] ] 324 | num_channels, 325 | num_iter = 5000, 326 | LR = 0.01, 327 | find_best=False, 328 | upsample_mode="bilinear", 329 | ): 330 | # generate netinputs 331 | # 
feed uniform noise into the network 332 | nis = [] 333 | for i in range(len(imgs)): 334 | if upsample_mode=="bilinear": 335 | # feed uniform noise into the network 336 | totalupsample = 2**len(num_channels) 337 | elif upsample_mode=="deconv": 338 | # feed uniform noise into the network 339 | totalupsample = 2**(len(num_channels)-1) 340 | #totalupsample = 2**len(num_channels) 341 | width = int(imgs[0].data.shape[2]/totalupsample) 342 | height = int(imgs[0].data.shape[3]/totalupsample) 343 | shape = [1 ,num_channels[0], width, height] 344 | print("shape: ", shape) 345 | net_input = Variable(torch.zeros(shape)) 346 | net_input.data.uniform_() 347 | net_input.data *= 1./10 348 | nis.append(net_input) 349 | 350 | # learnable parameters are the weights 351 | p = [x for x in net.parameters() ] 352 | 353 | mse_wrt_noisy = np.zeros(num_iter) 354 | 355 | optimizer = torch.optim.Adam(p, lr=LR) 356 | 357 | mse = torch.nn.MSELoss() #.type(dtype) 358 | 359 | if find_best: 360 | best_net = copy.deepcopy(net) 361 | best_mse = 1000000.0 362 | 363 | for i in range(num_iter): 364 | 365 | def closure(): 366 | optimizer.zero_grad() 367 | 368 | #loss = np_to_var(np.array([0.0])) 369 | out = net(nis[0].type(dtype)) 370 | loss = mse(out, imgs[0].type(dtype)) 371 | #for img,ni in zip(imgs,nis): 372 | for j in range(1,len(imgs)): 373 | #out = net(ni.type(dtype)) 374 | #loss += mse(out, img.type(dtype)) 375 | out = net(nis[j].type(dtype)) 376 | loss += mse(out, imgs[j].type(dtype)) 377 | 378 | #out = net(nis[0].type(dtype)) 379 | #out2 = net(nis[1].type(dtype)) 380 | #loss = mse(out, imgs[0].type(dtype)) + mse(out2, imgs[1].type(dtype)) 381 | 382 | loss.backward() 383 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 384 | 385 | if i % 10 == 0: 386 | print ('Iteration %05d Train loss %f' % (i, loss.data), '\r', end='') 387 | return loss 388 | 389 | loss = optimizer.step(closure) 390 | 391 | if find_best: 392 | # if training loss improves by at least one percent, we found a new best net 393 | if best_mse > 1.005*loss.data: 394 | best_mse = loss.data 395 | best_net = copy.deepcopy(net) 396 | 397 | if find_best: 398 | net = best_net 399 | return mse_wrt_noisy, nis, net 400 | 401 | -------------------------------------------------------------------------------- /include/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import PIL 9 | import numpy as np 10 | 11 | from torch.autograd import Variable 12 | 13 | import random 14 | import numpy as np 15 | import torch 16 | import matplotlib.pyplot as plt 17 | 18 | from PIL import Image 19 | import PIL 20 | 21 | from torch.autograd import Variable 22 | 23 | def myimgshow(plt,img): 24 | if(img.shape[0] == 1): 25 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='none') 26 | else: 27 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1),interpolation='none') 28 | 29 | def load_and_crop(imgname,target_width=512,target_height=512): 30 | ''' 31 | imgname: string of image location 32 | load an image, and center-crop if the image is large enough, else return none 33 | ''' 34 | img = Image.open(imgname) 35 | width, height = img.size 36 | if width <= target_width or height <= target_height: 37 | return None 38 | 39 | left = (width - target_width)/2 40 | top = (height - target_height)/2 41 | right = (width + target_width)/2 42 | bottom = (height + target_height)/2 43 | 44 | return img.crop((left, top, right, 
bottom)) 45 | 46 | def save_np_img(img,filename): 47 | if(img.shape[0] == 1): 48 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='nearest') 49 | else: 50 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1)) 51 | plt.axis('off') 52 | plt.savefig(filename, bbox_inches='tight') 53 | plt.close() 54 | 55 | def np_to_tensor(img_np): 56 | '''Converts image in numpy.array to torch.Tensor. 57 | 58 | From C x W x H [0..1] to C x W x H [0..1] 59 | ''' 60 | return torch.from_numpy(img_np) 61 | 62 | def np_to_var(img_np, dtype = torch.cuda.FloatTensor): 63 | '''Converts image in numpy.array to torch.Variable. 64 | 65 | From C x W x H [0..1] to 1 x C x W x H [0..1] 66 | ''' 67 | return Variable(np_to_tensor(img_np)[None, :]) 68 | 69 | def var_to_np(img_var): 70 | '''Converts an image in torch.Variable format to np.array. 71 | 72 | From 1 x C x W x H [0..1] to C x W x H [0..1] 73 | ''' 74 | return img_var.data.cpu().numpy()[0] 75 | 76 | 77 | def pil_to_np(img_PIL): 78 | '''Converts image in PIL format to np.array. 79 | 80 | From W x H x C [0...255] to C x W x H [0..1] 81 | ''' 82 | ar = np.array(img_PIL) 83 | 84 | if len(ar.shape) == 3: 85 | ar = ar.transpose(2,0,1) 86 | else: 87 | ar = ar[None, ...] 88 | 89 | return ar.astype(np.float32) / 255. 90 | 91 | 92 | def rgb2ycbcr(img): 93 | #out = color.rgb2ycbcr( img.transpose(1, 2, 0) ) 94 | #return out.transpose(2,0,1)/256. 95 | r,g,b = img[0],img[1],img[2] 96 | y = 0.299*r+0.587*g+0.114*b 97 | cb = 0.5 - 0.168736*r - 0.331264*g + 0.5*b 98 | cr = 0.5 + 0.5*r - 0.418588*g - 0.081312*b 99 | return np.array([y,cb,cr]) 100 | 101 | def ycbcr2rgb(img): 102 | #out = color.ycbcr2rgb( 256.*img.transpose(1, 2, 0) ) 103 | #return (out.transpose(2,0,1) - np.min(out))/(np.max(out)-np.min(out)) 104 | y,cb,cr = img[0],img[1],img[2] 105 | r = y + 1.402*(cr-0.5) 106 | g = y - 0.344136*(cb-0.5) - 0.714136*(cr-0.5) 107 | b = y + 1.772*(cb - 0.5) 108 | return np.array([r,g,b]) 109 | 110 | 111 | 112 | def mse(x_hat,x_true,maxv=1.): 113 | x_hat = x_hat.flatten() 114 | x_true = x_true.flatten() 115 | mse = np.mean(np.square(x_hat-x_true)) 116 | energy = np.mean(np.square(x_true)) 117 | return mse/energy 118 | 119 | def psnr(x_hat,x_true,maxv=1.): 120 | x_hat = x_hat.flatten() 121 | x_true = x_true.flatten() 122 | mse=np.mean(np.square(x_hat-x_true)) 123 | psnr_ = 10.*np.log(maxv**2/mse)/np.log(10.) 
124 | return psnr_ 125 | 126 | def num_param(net): 127 | s = sum([np.prod(list(p.size())) for p in net.parameters()]); 128 | return s 129 | #print('Number of params: %d' % s) 130 | 131 | def rgb2gray(rgb): 132 | r, g, b = rgb[0,:,:], rgb[1,:,:], rgb[2,:,:] 133 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 134 | return np.array([gray]) 135 | 136 | def savemtx_for_logplot(A,filename = "exp.dat"): 137 | ind = sorted(list(set([int(i) for i in np.geomspace(1, len(A[0])-1 ,num=700)]))) 138 | A = [ [a[i] for i in ind] for a in A] 139 | X = np.array([ind] + A) 140 | np.savetxt(filename, X.T, delimiter=' ') 141 | 142 | 143 | def get_imgnet_imgs(num_samples = 100, path = '../imagenet/',verbose=False): 144 | perm = [i for i in range(1,50000)] 145 | random.Random(4).shuffle(perm) 146 | siz = 512 147 | file = open("exp_imgnet_imgs.txt","w") 148 | 149 | imgs = [] 150 | sampled = 0 151 | imgslist = [] 152 | for imgnr in perm: 153 | # prepare and select image 154 | # Format is: ILSVRC2012_val_00024995.JPEG 155 | imgnr_str = str(imgnr).zfill(8) 156 | imgname = path + 'ILSVRC2012_val_' + imgnr_str + ".JPEG" 157 | img = load_and_crop(imgname,target_width=512,target_height=512) 158 | if img is None: # then the image could not be croped to 512x512 159 | continue 160 | 161 | img_np = pil_to_np(img) 162 | 163 | if img_np.shape[0] != 3: # we only want to consider color images 164 | continue 165 | if verbose: 166 | imgslist += ['ILSVRC2012_val_' + imgnr_str + ".JPEG"] 167 | print("cp ", imgname, "./imgs") 168 | imgs += [img_np] 169 | sampled += 1 170 | if sampled >= num_samples: 171 | break 172 | if verbose: 173 | print(imgslist) 174 | return imgs 175 | 176 | 177 | 178 | -------------------------------------------------------------------------------- /include/mri_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import PIL 9 | import numpy as np 10 | 11 | from torch.autograd import Variable 12 | 13 | import random 14 | import numpy as np 15 | import torch 16 | import matplotlib.pyplot as plt 17 | 18 | from PIL import Image 19 | import PIL 20 | 21 | from torch.autograd import Variable 22 | dtype = torch.cuda.FloatTensor 23 | 24 | from . 
import transforms as transform 25 | from .helpers import var_to_np,np_to_var 26 | 27 | import numpy 28 | import scipy.signal 29 | import scipy.ndimage 30 | 31 | 32 | def ksp2measurement(ksp): 33 | return np_to_var( np.transpose( np.array([np.real(ksp),np.imag(ksp)]) , (1, 2, 3, 0)) ) 34 | 35 | def lsreconstruction(measurement,mode='both'): 36 | # measurement has dimension (1, num_slices, x, y, 2) 37 | fimg = transform.ifft2(measurement) 38 | normimag = torch.norm(fimg[:,:,:,:,0]) 39 | normreal = torch.norm(fimg[:,:,:,:,1]) 40 | #print("real/img parts: ",normimag, normreal) 41 | if mode == 'both': 42 | return torch.sqrt(fimg[:,:,:,:,0]**2 + fimg[:,:,:,:,1]**2) 43 | elif mode == 'real': 44 | return torch.tensor(fimg[:,:,:,:,0]) #torch.sqrt(fimg[:,:,:,:,0]**2) 45 | elif mode == 'imag': 46 | return torch.sqrt(fimg[:,:,:,:,1]**2) 47 | 48 | def root_sum_of_squares2(lsimg): 49 | out = np.zeros(lsimg[0].shape) 50 | for img in lsimg: 51 | out += img**2 52 | return np.sqrt(out) 53 | 54 | def crop_center2(img,cropx,cropy): 55 | y,x = img.shape 56 | startx = x//2-(cropx//2) 57 | starty = y//2-(cropy//2) 58 | return img[starty:starty+cropy,startx:startx+cropx] 59 | 60 | def channels2imgs(out): 61 | sh = out.shape 62 | chs = int(sh[0]/2) 63 | imgs = np.zeros( (chs,sh[1],sh[2]) ) 64 | for i in range(chs): 65 | imgs[i] = np.sqrt( out[2*i]**2 + out[2*i+1]**2 ) 66 | return imgs 67 | 68 | def forwardm(img,mask): 69 | # img has dimension (2*num_slices, x,y) 70 | # output has dimension (1, num_slices, x, y, 2) 71 | mask = np_to_var(mask)[0].type(dtype) 72 | s = img.shape 73 | ns = int(s[1]/2) # number of slices 74 | fimg = Variable( torch.zeros( (s[0],ns,s[2],s[3],2 ) ) ).type(dtype) 75 | for i in range(ns): 76 | fimg[0,i,:,:,0] = img[0,2*i,:,:] 77 | fimg[0,i,:,:,1] = img[0,2*i+1,:,:] 78 | Fimg = transform.fft2(fimg) # dim: (1,num_slices,x,y,2) 79 | for i in range(ns): 80 | Fimg[0,i,:,:,0] *= mask 81 | Fimg[0,i,:,:,1] *= mask 82 | return Fimg 83 | 84 | def get_scale_factor(net,num_channels,in_size,slice_ksp,scale_out=1,scale_type="norm"): 85 | ### get norm of deep decoder output 86 | # get net input, scaling of that is irrelevant 87 | shape = [1,num_channels, in_size[0], in_size[1]] 88 | ni = Variable(torch.zeros(shape)).type(dtype) 89 | ni.data.uniform_() 90 | # generate random image 91 | try: 92 | out_chs = net( ni.type(dtype),scale_out=scale_out ).data.cpu().numpy()[0] 93 | except: 94 | out_chs = net( ni.type(dtype) ).data.cpu().numpy()[0] 95 | out_imgs = channels2imgs(out_chs) 96 | out_img_tt = transform.root_sum_of_squares( torch.tensor(out_imgs) , dim=0) 97 | 98 | ### get norm of least-squares reconstruction 99 | ksp_tt = transform.to_tensor(slice_ksp) 100 | orig_tt = transform.ifft2(ksp_tt) # Apply Inverse Fourier Transform to get the complex image 101 | orig_imgs_tt = transform.complex_abs(orig_tt) # Compute absolute value to get a real image 102 | orig_img_tt = transform.root_sum_of_squares(orig_imgs_tt, dim=0) 103 | orig_img_np = orig_img_tt.cpu().numpy() 104 | 105 | if scale_type == "norm": 106 | s = np.linalg.norm(out_img_tt) / np.linalg.norm(orig_img_np) 107 | if scale_type == "mean": 108 | s = (out_img_tt.mean() / orig_img_np.mean()).numpy()[np.newaxis][0] 109 | return s,ni 110 | 111 | -------------------------------------------------------------------------------- /include/onedim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim 5 | from torch.autograd import Variable 6 | 
import matplotlib.pyplot as plt 7 | import copy 8 | 9 | 10 | 11 | dtype = torch.FloatTensor # This code is meant for CPU 12 | 13 | def add_module(self, module): 14 | self.add_module(str(len(self) + 1), module) 15 | 16 | torch.nn.Module.add = add_module 17 | 18 | 19 | def conv1(in_f, out_f, kernel_size, stride=1, pad='zero'): 20 | padder = None 21 | to_pad = int((kernel_size - 1) / 2) 22 | if pad == 'reflection': 23 | padder = nn.ReflectionPad2d(to_pad) 24 | to_pad = 0 25 | 26 | convolver = nn.Conv1d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=False) 27 | 28 | layers = filter(lambda x: x is not None, [padder, convolver]) 29 | return nn.Sequential(*layers) 30 | 31 | 32 | 33 | # Define the upsampling matrices 34 | def get_upsample_matrix(k, identity=False, upsample_mode='linear'): 35 | # Returns a 2*k-1 x k numpy array corresponding to an upsampling matrix 36 | 37 | if identity: 38 | return np.eye(k) 39 | U = np.zeros((2*k-1, k)) 40 | for i in range(k): 41 | U[2*i, i] = 1 42 | 43 | if i < k-1: 44 | if upsample_mode=='linear': 45 | U[2*i+1, [i, (i+1) % k]] = [1./2, 1./2] 46 | elif upsample_mode=='convex0.7-0.3': 47 | U[2*i+1, [i, (i+1) % k]] = [0.7, 0.3] 48 | elif upsample_mode=='convex0.75-0.25': 49 | U[2*i+1, [i, (i+1) % k]] = [0.75, 0.25] 50 | return U 51 | 52 | 53 | 54 | class Upsample_Module(nn.Module): 55 | # Only works for batch size 1. Works for any number of channels 56 | 57 | def __init__(self, upsample_mode='linear'): 58 | super(Upsample_Module,self).__init__() 59 | self.upsample_mode=upsample_mode 60 | 61 | def forward(self, x): 62 | n = x.shape[2] 63 | U = Variable(torch.Tensor(get_upsample_matrix(n, upsample_mode=self.upsample_mode))) 64 | return torch.stack([torch.t(U.matmul(torch.t(x[0,...])))], 0) 65 | 66 | 67 | 68 | 69 | def decoder_1d( 70 | num_output_channels=1, 71 | num_channels_up=[128]*5, 72 | filter_size_up=1, 73 | need_sigmoid=False, 74 | pad='zero', 75 | upsample_mode='linear', 76 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 77 | need_bn=True, 78 | ): 79 | 80 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 81 | n_scales = len(num_channels_up) 82 | #print('n_scales = %d' %n_scales) 83 | 84 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 85 | filter_size_up = [filter_size_up]*n_scales 86 | 87 | model = nn.Sequential() 88 | 89 | for i in range(len(num_channels_up)-1): 90 | 91 | if upsample_mode!='none' and i!=0: 92 | if upsample_mode=='MatrixUpsample': 93 | model.add(Upsample_Module()) 94 | elif upsample_mode=='MatrixUpsampleConvex0.7-0.3': 95 | model.add(Upsample_Module(upsample_mode='convex0.7-0.3')) 96 | elif upsample_mode=='MatrixUpsampleConvex0.75-0.25': 97 | model.add(Upsample_Module(upsample_mode='convex0.75-0.25')) 98 | elif upsample_mode=='nnUpsampleDouble': 99 | model.add(nn.Upsample(scale_factor=2.0, mode='linear', align_corners=False)) 100 | elif upsample_mode=='nearest': 101 | model.add(nn.Upsample(scale_factor=2.0, mode='nearest')) 102 | 103 | model.add(conv1( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 104 | if i != len(num_channels_up)-1: 105 | if need_bn: 106 | model.add(nn.BatchNorm1d( num_channels_up[i+1] )) 107 | model.add(act_fun) 108 | 109 | model.add(conv1( num_channels_up[-1], num_output_channels, 1, pad=pad)) 110 | 111 | if need_sigmoid: 112 | model.add(nn.Sigmoid()) 113 | 114 | return model 115 | 116 | 117 | def fit_1d(net, 118 | img_noisy_var, 119 | num_channels, 120 | img_clean_var, 121 | net_input, # Passing in the net_input 
is required 122 | num_iter = 5000, 123 | LR = 0.01, 124 | OPTIMIZER='adam', 125 | opt_input = False, 126 | reg_noise_std = 0, 127 | reg_noise_decayevery = 100000, 128 | mask_var = None, 129 | apply_f = None, 130 | decaylr = False, 131 | net_input_gen = "random", 132 | plot_output_every = None, 133 | ): 134 | 135 | net_input_saved = net_input.data.clone() 136 | noise = net_input.data.clone() 137 | p = [x for x in net.parameters() ] 138 | 139 | if(opt_input == True): 140 | net_input.requires_grad = True 141 | p += [net_input] 142 | 143 | mse_wrt_noisy = np.zeros(num_iter) 144 | mse_wrt_truth = np.zeros(num_iter) 145 | 146 | if OPTIMIZER == 'SGD': 147 | print("optimize with SGD", LR) 148 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9) 149 | elif OPTIMIZER == 'adam': 150 | print("optimize with adam", LR) 151 | optimizer = torch.optim.Adam(p, lr=LR) 152 | 153 | mse = torch.nn.MSELoss() #.type(dtype) 154 | noise_energy = mse(img_noisy_var, img_clean_var) 155 | 156 | 157 | 158 | for i in range(num_iter): 159 | if decaylr is True: 160 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=100) 161 | if reg_noise_std > 0: 162 | if i % reg_noise_decayevery == 0: 163 | reg_noise_std *= 0.7 164 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 165 | optimizer.zero_grad() 166 | out = net(net_input.type(dtype)) 167 | 168 | 169 | 170 | # training loss 171 | if mask_var is not None: 172 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 173 | elif apply_f: 174 | loss = mse( apply_f(out) , img_noisy_var ) 175 | else: 176 | loss = mse(out, img_noisy_var) 177 | loss.backward() 178 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 179 | if mse_wrt_noisy[i] == np.min(mse_wrt_noisy[:i+1]): 180 | best_net = copy.deepcopy(net) 181 | best_mse_wrt_noisy = mse_wrt_noisy[i] 182 | 183 | # the actual loss 184 | true_loss = mse(Variable(out.data, requires_grad=False), img_clean_var) 185 | mse_wrt_truth[i] = true_loss.data.cpu().numpy() 186 | if i % 10 == 0: 187 | out2 = net(Variable(net_input_saved).type(dtype)) 188 | loss2 = mse(out2, img_clean_var) 189 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f Noise Energy %f' 190 | % (i, loss.data.item(),true_loss.data.item(),loss2.data.item(),noise_energy.data.item()), '\r', end='') 191 | if plot_output_every and (i % plot_output_every==1): 192 | out3 = net(Variable(net_input_saved).type(dtype)) 193 | ax = plt.figure(figsize=(12,5)) 194 | plt.plot(out3[0,0,:].data.numpy(), '.b') 195 | plt.plot(img_clean_var[0,0,:].data.numpy(), '-r') 196 | plt.show() 197 | optimizer.step() 198 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, best_net, best_mse_wrt_noisy # Didn't implement case wehere there is noise in signal 199 | 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /include/pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, 
window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | """def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average)""" 74 | -------------------------------------------------------------------------------- /include/pytorch_ssim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/include/pytorch_ssim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /include/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def to_tensor(data): 13 | """ 14 | Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts 15 | are stacked along the last dimension. 16 | 17 | Args: 18 | data (np.array): Input numpy array 19 | 20 | Returns: 21 | torch.Tensor: PyTorch version of data 22 | """ 23 | if np.iscomplexobj(data): 24 | data = np.stack((data.real, data.imag), axis=-1) 25 | return torch.from_numpy(data) 26 | 27 | 28 | def apply_mask(data, mask_func = None, mask = None, seed=None): 29 | """ 30 | Subsample given k-space by multiplying with a mask. 31 | 32 | Args: 33 | data (torch.Tensor): The input k-space data. 
This should have at least 3 dimensions, where 34 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size 35 | 2 (for complex values). 36 | mask_func (callable): A function that takes a shape (tuple of ints) and a random 37 | number seed and returns a mask. 38 | seed (int or 1-d array_like, optional): Seed for the random number generator. 39 | 40 | Returns: 41 | (tuple): tuple containing: 42 | masked data (torch.Tensor): Subsampled k-space data 43 | mask (torch.Tensor): The generated mask 44 | """ 45 | shape = np.array(data.shape) 46 | shape[:-3] = 1 47 | if mask is None: 48 | mask = mask_func(shape, seed) 49 | return data * mask, mask 50 | 51 | 52 | def fft2(data): 53 | """ 54 | Apply centered 2 dimensional Fast Fourier Transform. 55 | 56 | Args: 57 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 58 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 59 | assumed to be batch dimensions. 60 | 61 | Returns: 62 | torch.Tensor: The FFT of the input. 63 | """ 64 | assert data.size(-1) == 2 65 | data = ifftshift(data, dim=(-3, -2)) 66 | data = torch.fft(data, 2, normalized=True) 67 | data = fftshift(data, dim=(-3, -2)) 68 | return data 69 | 70 | 71 | def ifft2(data): 72 | """ 73 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 74 | 75 | Args: 76 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 77 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 78 | assumed to be batch dimensions. 79 | 80 | Returns: 81 | torch.Tensor: The IFFT of the input. 82 | """ 83 | assert data.size(-1) == 2 84 | data = ifftshift(data, dim=(-3, -2)) 85 | data = torch.ifft(data, 2, normalized=True) 86 | data = fftshift(data, dim=(-3, -2)) 87 | return data 88 | 89 | 90 | def complex_abs(data): 91 | """ 92 | Compute the absolute value of a complex valued input tensor. 93 | 94 | Args: 95 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension 96 | should be 2. 97 | 98 | Returns: 99 | torch.Tensor: Absolute value of data 100 | """ 101 | assert data.size(-1) == 2 102 | return (data ** 2).sum(dim=-1).sqrt() 103 | 104 | 105 | def root_sum_of_squares(data, dim=0): 106 | """ 107 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor. 108 | 109 | Args: 110 | data (torch.Tensor): The input tensor 111 | dim (int): The dimensions along which to apply the RSS transform 112 | 113 | Returns: 114 | torch.Tensor: The RSS value 115 | """ 116 | return torch.sqrt((data ** 2).sum(dim)) 117 | 118 | 119 | def center_crop(data, shape): 120 | """ 121 | Apply a center crop to the input real image or batch of real images. 122 | 123 | Args: 124 | data (torch.Tensor): The input tensor to be center cropped. It should have at 125 | least 2 dimensions and the cropping is applied along the last two dimensions. 126 | shape (int, int): The output shape. The shape should be smaller than the 127 | corresponding dimensions of data. 
128 | 129 | Returns: 130 | torch.Tensor: The center cropped image 131 | """ 132 | assert 0 < shape[0] <= data.shape[-2] 133 | assert 0 < shape[1] <= data.shape[-1] 134 | w_from = (data.shape[-2] - shape[0]) // 2 135 | h_from = (data.shape[-1] - shape[1]) // 2 136 | w_to = w_from + shape[0] 137 | h_to = h_from + shape[1] 138 | return data[..., w_from:w_to, h_from:h_to] 139 | 140 | 141 | def complex_center_crop(data, shape): 142 | """ 143 | Apply a center crop to the input image or batch of complex images. 144 | 145 | Args: 146 | data (torch.Tensor): The complex input tensor to be center cropped. It should 147 | have at least 3 dimensions and the cropping is applied along dimensions 148 | -3 and -2 and the last dimensions should have a size of 2. 149 | shape (int, int): The output shape. The shape should be smaller than the 150 | corresponding dimensions of data. 151 | 152 | Returns: 153 | torch.Tensor: The center cropped image 154 | """ 155 | assert 0 < shape[0] <= data.shape[-3] 156 | assert 0 < shape[1] <= data.shape[-2] 157 | w_from = (data.shape[-3] - shape[0]) // 2 158 | h_from = (data.shape[-2] - shape[1]) // 2 159 | w_to = w_from + shape[0] 160 | h_to = h_from + shape[1] 161 | return data[..., w_from:w_to, h_from:h_to, :] 162 | 163 | 164 | def normalize(data, mean, stddev, eps=0.): 165 | """ 166 | Normalize the given tensor using: 167 | (data - mean) / (stddev + eps) 168 | 169 | Args: 170 | data (torch.Tensor): Input data to be normalized 171 | mean (float): Mean value 172 | stddev (float): Standard deviation 173 | eps (float): Added to stddev to prevent dividing by zero 174 | 175 | Returns: 176 | torch.Tensor: Normalized tensor 177 | """ 178 | return (data - mean) / (stddev + eps) 179 | 180 | 181 | def normalize_instance(data, eps=0.): 182 | """ 183 | Normalize the given tensor using: 184 | (data - mean) / (stddev + eps) 185 | where mean and stddev are computed from the data itself. 
186 | 187 | Args: 188 | data (torch.Tensor): Input data to be normalized 189 | eps (float): Added to stddev to prevent dividing by zero 190 | 191 | Returns: 192 | torch.Tensor: Normalized tensor 193 | """ 194 | mean = data.mean() 195 | std = data.std() 196 | return normalize(data, mean, std, eps), mean, std 197 | 198 | 199 | # Helper functions 200 | 201 | def roll(x, shift, dim): 202 | """ 203 | Similar to np.roll but applies to PyTorch Tensors 204 | """ 205 | if isinstance(shift, (tuple, list)): 206 | assert len(shift) == len(dim) 207 | for s, d in zip(shift, dim): 208 | x = roll(x, s, d) 209 | return x 210 | shift = shift % x.size(dim) 211 | if shift == 0: 212 | return x 213 | left = x.narrow(dim, 0, x.size(dim) - shift) 214 | right = x.narrow(dim, x.size(dim) - shift, shift) 215 | return torch.cat((right, left), dim=dim) 216 | 217 | 218 | def fftshift(x, dim=None): 219 | """ 220 | Similar to np.fft.fftshift but applies to PyTorch Tensors 221 | """ 222 | if dim is None: 223 | dim = tuple(range(x.dim())) 224 | shift = [dim // 2 for dim in x.shape] 225 | elif isinstance(dim, int): 226 | shift = x.shape[dim] // 2 227 | else: 228 | shift = [x.shape[i] // 2 for i in dim] 229 | return roll(x, shift, dim) 230 | 231 | 232 | def ifftshift(x, dim=None): 233 | """ 234 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 235 | """ 236 | if dim is None: 237 | dim = tuple(range(x.dim())) 238 | shift = [(dim + 1) // 2 for dim in x.shape] 239 | elif isinstance(dim, int): 240 | shift = (x.shape[dim] + 1) // 2 241 | else: 242 | shift = [(x.shape[i] + 1) // 2 for i in dim] 243 | return roll(x, shift, dim) 244 | -------------------------------------------------------------------------------- /out_of_distribution_image/cameraman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/out_of_distribution_image/cameraman.png --------------------------------------------------------------------------------
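Usage sketch (not from the repository): the include modules above are designed to compose. transforms.py supplies the centered FFT pair and the complex-tensor utilities, mri_helpers.py converts a raw multicoil k-space slice into the two-channel measurement that the decoder is fitted against, and helpers.py together with include/pytorch_ssim provide the evaluation metrics. The snippet below is a minimal sketch of that pipeline under stated assumptions: it builds a zero-filled (least-squares) baseline from an undersampled slice and scores it against the fully sampled reconstruction. The k-space file name and the hand-rolled column mask are hypothetical stand-ins for whatever mask_func would normally be handed to transforms.apply_mask, the script is assumed to run from the repository root so the include package imports resolve, and transforms.fft2/ifft2 call the functional torch.fft/torch.ifft API, which requires a PyTorch release older than 1.8.

# Illustrative sketch, not repository code. Assumes PyTorch < 1.8 and that this
# script is launched from the repository root. The k-space file and the mask
# below are hypothetical placeholders.
import numpy as np
import torch

from include import transforms as transform
from include.helpers import psnr
from include.mri_helpers import ksp2measurement   # note: mri_helpers sets dtype = torch.cuda.FloatTensor at import
from include.pytorch_ssim import SSIM

# One multicoil k-space slice, complex valued, shape (num_coils, kx, ky).
slice_ksp = np.load("knee_slice_ksp.npy")          # hypothetical file

# Hand-rolled undersampling mask: every 4th phase-encode column plus a fully
# sampled band of 30 low-frequency columns in the center of k-space.
num_cols = slice_ksp.shape[-1]
mask1d = np.zeros(num_cols)
mask1d[::4] = 1
mask1d[num_cols // 2 - 15 : num_cols // 2 + 15] = 1
mask = np.tile(mask1d, (slice_ksp.shape[-2], 1))   # shape (kx, ky)

# Two-channel measurement with real/imag stacked last: (1, num_coils, kx, ky, 2).
measurement = ksp2measurement(slice_ksp)
mask_tt = torch.from_numpy(mask).to(measurement.dtype)[None, None, :, :, None]
masked = measurement * mask_tt

# Zero-filled baseline: inverse FFT, per-coil magnitude, root-sum-of-squares combine.
zero_filled = transform.root_sum_of_squares(
    transform.complex_abs(transform.ifft2(masked)), dim=1)[0]
reference = transform.root_sum_of_squares(
    transform.complex_abs(transform.ifft2(measurement)), dim=1)[0]

# Score with the psnr helper from helpers.py and SSIM from include/pytorch_ssim.
scale = reference.max()
zf, ref = (zero_filled / scale).float(), (reference / scale).float()
print("PSNR:", psnr(zf.numpy(), ref.numpy(), maxv=1.))
print("SSIM:", SSIM(window_size=11)(zf[None, None], ref[None, None]).item())

For a ConvDecoder reconstruction rather than a zero-filled baseline, the same masked measurement is what forwardm in mri_helpers.py re-synthesizes from the network output during fitting, and channels2imgs, root_sum_of_squares2, and crop_center2 turn the fitted network's 2*num_coils output channels into a magnitude image that can be scored with the same psnr/SSIM calls.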