├── ConvDecoder_architecture
│   └── ConvDecoder.JPG
├── ConvDecoder_for_MRI.ipynb
├── ConvDecoder_vs_DIP_vs_DD_multicoil.ipynb
├── ConvDecoder_vs_Unet_multicoil.ipynb
├── DIP_UNET_models
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── common.cpython-36.pyc
│   │   ├── common.cpython-37.pyc
│   │   ├── downsampler.cpython-36.pyc
│   │   ├── mri_model.cpython-36.pyc
│   │   ├── resnet.cpython-36.pyc
│   │   ├── skip.cpython-36.pyc
│   │   ├── skip.cpython-37.pyc
│   │   ├── skip_decoder.cpython-36.pyc
│   │   ├── texture_nets.cpython-36.pyc
│   │   └── unet.cpython-36.pyc
│   ├── common.py
│   ├── skip.py
│   └── unet_and_tv
│       ├── .ipynb_checkpoints
│       │   └── run_bart-checkpoint.ipynb
│       ├── README.md
│       ├── __pycache__
│       │   ├── mri_model.cpython-36.pyc
│       │   ├── mri_model.cpython-37.pyc
│       │   ├── train_unet.cpython-36.pyc
│       │   ├── train_unet.cpython-37.pyc
│       │   ├── unet_model.cpython-36.pyc
│       │   └── unet_model.cpython-37.pyc
│       ├── common
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   ├── __init__.cpython-37.pyc
│       │   │   ├── args.cpython-36.pyc
│       │   │   ├── args.cpython-37.pyc
│       │   │   ├── evaluate.cpython-36.pyc
│       │   │   ├── evaluate.cpython-37.pyc
│       │   │   ├── subsample.cpython-36.pyc
│       │   │   ├── subsample.cpython-37.pyc
│       │   │   ├── utils.cpython-36.pyc
│       │   │   └── utils.cpython-37.pyc
│       │   ├── args.py
│       │   ├── evaluate.py
│       │   ├── subsample.py
│       │   ├── test_subsample.py
│       │   └── utils.py
│       ├── data
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   ├── __init__.cpython-37.pyc
│       │   │   ├── mri_data.cpython-36.pyc
│       │   │   ├── mri_data.cpython-37.pyc
│       │   │   ├── transforms.cpython-36.pyc
│       │   │   └── transforms.cpython-37.pyc
│       │   ├── mri_data.py
│       │   ├── test_transforms.py
│       │   └── transforms.py
│       ├── mri_model.py
│       ├── run_bart.ipynb
│       ├── run_bart_val.py
│       ├── train_unet.py
│       ├── unet_model.py
│       └── uuu.zip
├── LICENSE
├── README.md
├── UNET_trained
│   └── epoch=49.ckpt
├── common
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── args.cpython-36.pyc
│   │   ├── args.cpython-37.pyc
│   │   ├── evaluate.cpython-36.pyc
│   │   ├── evaluate.cpython-37.pyc
│   │   ├── subsample.cpython-36.pyc
│   │   ├── subsample.cpython-37.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-37.pyc
│   ├── args.py
│   ├── evaluate.py
│   ├── subsample.py
│   ├── test_subsample.py
│   └── utils.py
├── demo_helper
│   └── helpers.py
├── include
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── compression.cpython-36.pyc
│   │   ├── decoder.cpython-36.pyc
│   │   ├── decoder_conv.cpython-36.pyc
│   │   ├── decoder_conv.cpython-37.pyc
│   │   ├── decoder_parallel.cpython-36.pyc
│   │   ├── decoder_parallel2.cpython-36.pyc
│   │   ├── decoder_parallel_conv.cpython-36.pyc
│   │   ├── decoder_parallel_conv.cpython-37.pyc
│   │   ├── decoder_skip.cpython-36.pyc
│   │   ├── decoder_skip.cpython-37.pyc
│   │   ├── decoder_skip2.cpython-36.pyc
│   │   ├── fit.cpython-36.pyc
│   │   ├── fit.cpython-37.pyc
│   │   ├── helpers.cpython-36.pyc
│   │   ├── helpers.cpython-37.pyc
│   │   ├── mri_helpers.cpython-36.pyc
│   │   ├── mri_helpers.cpython-37.pyc
│   │   ├── transforms.cpython-36.pyc
│   │   ├── transforms.cpython-37.pyc
│   │   ├── visualize.cpython-36.pyc
│   │   └── wavelet.cpython-36.pyc
│   ├── decoder_conv.py
│   ├── decoder_parallel_conv.py
│   ├── decoder_skip.py
│   ├── fit.py
│   ├── helpers.py
│   ├── mri_helpers.py
│   ├── onedim.py
│   ├── pytorch_ssim
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       └── __init__.cpython-36.pyc
│   └── transforms.py
├── out_of_distribution_image
│   └── cameraman.png
├── robustness_to_distribution_shift.ipynb
└── visualize_layers_singlecoil.ipynb
/ConvDecoder_architecture/ConvDecoder.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/ConvDecoder_architecture/ConvDecoder.JPG
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/common.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/common.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/common.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/common.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/downsampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/downsampler.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/mri_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/mri_model.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/skip.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/skip.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/skip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/skip.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/skip_decoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/skip_decoder.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/texture_nets.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/texture_nets.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/__pycache__/unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/__pycache__/unet.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | #from .downsampler import Downsampler
5 |
6 | def add_module(self, module):
7 | self.add_module(str(len(self) + 1), module)
8 |
9 | torch.nn.Module.add = add_module
10 |
11 | class Concat(nn.Module):
12 | def __init__(self, dim, *args):
13 | super(Concat, self).__init__()
14 | self.dim = dim
15 |
16 | for idx, module in enumerate(args):
17 | self.add_module(str(idx), module)
18 |
19 | def forward(self, input):
20 | inputs = []
21 | for module in self._modules.values():
22 | inputs.append(module(input))
23 |
24 | inputs_shapes2 = [x.shape[2] for x in inputs]
25 | inputs_shapes3 = [x.shape[3] for x in inputs]
26 |
27 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
28 | inputs_ = inputs
29 | else:
30 | target_shape2 = min(inputs_shapes2)
31 | target_shape3 = min(inputs_shapes3)
32 |
33 | inputs_ = []
34 | for inp in inputs:
35 | diff2 = (inp.size(2) - target_shape2) // 2
36 | diff3 = (inp.size(3) - target_shape3) // 2
37 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])
38 |
39 | return torch.cat(inputs_, dim=self.dim)
40 |
41 | def __len__(self):
42 | return len(self._modules)
43 |
44 |
45 | class GenNoise(nn.Module):
46 | def __init__(self, dim2):
47 | super(GenNoise, self).__init__()
48 | self.dim2 = dim2
49 |
50 | def forward(self, input):
51 | a = list(input.size())
52 | a[1] = self.dim2
53 | # print (input.data.type())
54 |
55 | b = torch.zeros(a).type_as(input.data)
56 | b.normal_()
57 |
58 | x = torch.autograd.Variable(b)
59 |
60 | return x
61 |
62 |
63 | class Swish(nn.Module):
64 | """
65 | https://arxiv.org/abs/1710.05941
66 | The hype was so huge that I could not help but try it
67 | """
68 | def __init__(self):
69 | super(Swish, self).__init__()
70 | self.s = nn.Sigmoid()
71 |
72 | def forward(self, x):
73 | return x * self.s(x)
74 |
75 |
76 | def act(act_fun = 'LeakyReLU'):
77 | '''
78 | Either string defining an activation function or module (e.g. nn.ReLU)
79 | '''
80 | if isinstance(act_fun, str):
81 | if act_fun == 'LeakyReLU':
82 | return nn.LeakyReLU(0.2, inplace=True)
83 | elif act_fun == 'ReLU':
84 | return nn.ReLU()
85 | elif act_fun == 'Swish':
86 | return Swish()
87 | elif act_fun == 'ELU':
88 | return nn.ELU()
89 | elif act_fun == 'none':
90 | return nn.Sequential()
91 | else:
92 | assert False
93 | else:
94 | return act_fun()
95 |
96 |
97 | def bn(num_features):
98 | return nn.BatchNorm2d(num_features)
99 |
100 |
101 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
102 | downsampler = None
103 | if stride != 1 and downsample_mode != 'stride':
104 |
105 | if downsample_mode == 'avg':
106 | downsampler = nn.AvgPool2d(stride, stride)
107 | elif downsample_mode == 'max':
108 | downsampler = nn.MaxPool2d(stride, stride)
109 | elif downsample_mode in ['lanczos2', 'lanczos3']:
110 | downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
111 | else:
112 | assert False
113 |
114 | stride = 1
115 |
116 | padder = None
117 | to_pad = int((kernel_size - 1) / 2)
118 | if pad == 'reflection':
119 | padder = nn.ReflectionPad2d(to_pad)
120 | to_pad = 0
121 |
122 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)
123 |
124 |
125 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
126 | return nn.Sequential(*layers)
--------------------------------------------------------------------------------
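A minimal usage sketch for the `conv`, `bn`, and `act` helpers above (not part of the repository; it assumes the repository root is on `sys.path` so that `DIP_UNET_models.common` is importable):

```python
import torch
from DIP_UNET_models.common import conv, bn, act

# Reflection-padded 3x3 convolution, followed by batch norm and LeakyReLU(0.2).
block = torch.nn.Sequential(
    conv(2, 16, kernel_size=3, pad='reflection'),
    bn(16),
    act('LeakyReLU'),
)

x = torch.randn(1, 2, 64, 64)
print(block(x).shape)  # torch.Size([1, 16, 64, 64])
```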
/DIP_UNET_models/skip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .common import *
4 |
5 | def skip(
6 | out_size,num_input_channels=2, num_output_channels=3,
7 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4],
8 | filter_size_down=3, filter_size_up=3, filter_skip_size=1,
9 | need_sigmoid=True, need_bias=True,
10 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU',
11 | need1x1_up=True):
12 | """Assembles encoder-decoder with skip connections.
13 |
14 | Arguments:
15 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
16 | pad (string): zero|reflection (default: 'zero')
17 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest')
18 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride')
19 |
20 | """
21 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip)
22 |
23 | n_scales = len(num_channels_down)
24 |
25 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) :
26 | upsample_mode = [upsample_mode]*n_scales
27 |
28 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)):
29 | downsample_mode = [downsample_mode]*n_scales
30 |
31 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) :
32 | filter_size_down = [filter_size_down]*n_scales
33 |
34 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) :
35 | filter_size_up = [filter_size_up]*n_scales
36 |
37 | last_scale = n_scales - 1
38 |
39 | cur_depth = None
40 |
41 | model = nn.Sequential()
42 | model_tmp = model
43 |
44 | input_depth = num_input_channels
45 | for i in range(len(num_channels_down)):
46 |
47 | deeper = nn.Sequential()
48 | skip = nn.Sequential()
49 |
50 | if num_channels_skip[i] != 0:
51 | model_tmp.add(Concat(1, skip, deeper))
52 | else:
53 | model_tmp.add(deeper)
54 |
55 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])))
56 |
57 | if num_channels_skip[i] != 0:
58 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad))
59 | skip.add(bn(num_channels_skip[i]))
60 | skip.add(act(act_fun))
61 |
62 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part))
63 |
64 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]))
65 | deeper.add(bn(num_channels_down[i]))
66 | deeper.add(act(act_fun))
67 |
68 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
69 | deeper.add(bn(num_channels_down[i]))
70 | deeper.add(act(act_fun))
71 |
72 | deeper_main = nn.Sequential()
73 |
74 | if i == len(num_channels_down) - 1:
75 | # The deepest
76 | k = num_channels_down[i]
77 | else:
78 | deeper.add(deeper_main)
79 | k = num_channels_up[i + 1]
80 |
81 | if i == 0:
82 | deeper.add(nn.Upsample(size=out_size, mode=upsample_mode[i]))
83 | else:
84 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))
85 |
86 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
87 | model_tmp.add(bn(num_channels_up[i]))
88 | model_tmp.add(act(act_fun))
89 |
90 |
91 | if need1x1_up:
92 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
93 | model_tmp.add(bn(num_channels_up[i]))
94 | model_tmp.add(act(act_fun))
95 |
96 | input_depth = num_channels_down[i]
97 | model_tmp = deeper_main
98 |
99 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
100 | if need_sigmoid:
101 | model.add(nn.Sigmoid())
102 |
103 | return model
104 |
--------------------------------------------------------------------------------
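A minimal construction sketch for the `skip` network above (not part of the repository; it assumes the repository root is on `sys.path` so that `DIP_UNET_models` is importable as a package):

```python
import torch
from DIP_UNET_models.skip import skip

# Five-scale encoder-decoder with skip connections, as used for DIP-style fitting.
net = skip(
    out_size=(256, 256),      # the first decoder upsampling targets this size
    num_input_channels=2,
    num_output_channels=2,
    need_sigmoid=False,
)

z = 0.1 * torch.zeros(1, 2, 256, 256).normal_()  # fixed random input, as in DIP
print(net(z).shape)  # expected: torch.Size([1, 2, 256, 256])
```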
/DIP_UNET_models/unet_and_tv/README.md:
--------------------------------------------------------------------------------
1 | ## U-Net Model for MRI Reconstruction
2 |
3 | This directory contains a reference U-Net implementation for MRI reconstruction
4 | in PyTorch.
5 |
6 | To start training the model, run:
7 | ```bash
8 | python models/unet/train_unet.py --mode train --challenge CHALLENGE --data-path DATA --exp unet --mask-type MASK_TYPE
9 | ```
10 | where `CHALLENGE` is either `singlecoil` or `multicoil`, and `MASK_TYPE` is either `random` (for knee)
11 | or `equispaced` (for brain). Training logs and checkpoints are saved in the `experiments/unet` directory.
12 |
13 | To run the model on test data:
14 | ```bash
15 | python models/unet/train_unet.py --mode test --challenge CHALLENGE --data-path DATA --exp unet --out-dir reconstructions --checkpoint MODEL
16 | ```
17 | where `MODEL` is the path to the model checkpoint from `experiments/unet/version_0/checkpoints/`.
18 |
19 | The outputs will be saved to the `reconstructions` directory, which can be uploaded for submission.
20 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/mri_model.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/train_unet.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/__pycache__/unet_model.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/args.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/evaluate.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/subsample.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/common/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/args.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import argparse
9 | import pathlib
10 |
11 |
12 | class Args(argparse.ArgumentParser):
13 | """
14 | Defines global default arguments.
15 | """
16 |
17 | def __init__(self, **overrides):
18 | """
19 | Args:
20 | **overrides (dict, optional): Keyword arguments used to override default argument values
21 | """
22 |
23 | super().__init__(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
24 |
25 | self.add_argument('--seed', default=42, type=int, help='Seed for random number generators')
26 | self.add_argument('--resolution', default=320, type=int, help='Resolution of images')
27 |
28 | # Data parameters
29 | self.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True,
30 | help='Which challenge')
31 | self.add_argument('--data-path', type=pathlib.Path, required=True,
32 | help='Path to the dataset')
33 | self.add_argument('--sample-rate', type=float, default=1.,
34 | help='Fraction of total volumes to include')
35 |
36 | # Mask parameters
37 | self.add_argument('--accelerations', nargs='+', default=[4, 8], type=int,
38 | help='Ratio of k-space columns to be sampled. If multiple values are '
39 | 'provided, then one of those is chosen uniformly at random for '
40 | 'each volume.')
41 | self.add_argument('--center-fractions', nargs='+', default=[0.08, 0.04], type=float,
42 | help='Fraction of low-frequency k-space columns to be sampled. Should '
43 | 'have the same length as accelerations')
44 |
45 | # Override defaults with passed overrides
46 | self.set_defaults(**overrides)
47 |
--------------------------------------------------------------------------------
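A short usage sketch for the `Args` parser above (not part of the repository; it assumes the `unet_and_tv` directory is on the path, matching the `from common ...` imports used by the tests, and uses a hypothetical dataset location):

```python
from common.args import Args

parser = Args(resolution=320)  # keyword overrides replace the declared defaults
args = parser.parse_args([
    '--challenge', 'multicoil',
    '--data-path', '/path/to/fastMRI',   # hypothetical dataset location
    '--accelerations', '4',
    '--center-fractions', '0.08',
])
print(args.challenge, args.resolution, args.accelerations, args.center_fractions)
# multicoil 320 [4] [0.08]
```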
/DIP_UNET_models/unet_and_tv/common/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import argparse
9 | import pathlib
10 | from argparse import ArgumentParser
11 |
12 | import h5py
13 | import numpy as np
14 | from runstats import Statistics
15 | from skimage.measure import compare_psnr, compare_ssim
16 |
17 |
18 | def mse(gt, pred):
19 | """ Compute Mean Squared Error (MSE) """
20 | return np.mean((gt - pred) ** 2)
21 |
22 |
23 | def nmse(gt, pred):
24 | """ Compute Normalized Mean Squared Error (NMSE) """
25 | return np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2
26 |
27 |
28 | def psnr(gt, pred):
29 | """ Compute Peak Signal to Noise Ratio metric (PSNR) """
30 | return compare_psnr(gt, pred, data_range=gt.max())
31 |
32 |
33 | def ssim(gt, pred):
34 | """ Compute Structural Similarity Index Metric (SSIM). """
35 | return compare_ssim(
36 | gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max()
37 | )
38 |
39 |
40 | METRIC_FUNCS = dict(
41 | MSE=mse,
42 | NMSE=nmse,
43 | PSNR=psnr,
44 | SSIM=ssim,
45 | )
46 |
47 |
48 | class Metrics:
49 | """
50 | Maintains running statistics for a given collection of metrics.
51 | """
52 |
53 | def __init__(self, metric_funcs):
54 | self.metrics = {
55 | metric: Statistics() for metric in metric_funcs
56 | }
57 |
58 | def push(self, target, recons):
59 | for metric, func in METRIC_FUNCS.items():
60 | self.metrics[metric].push(func(target, recons))
61 |
62 | def means(self):
63 | return {
64 | metric: stat.mean() for metric, stat in self.metrics.items()
65 | }
66 |
67 | def stddevs(self):
68 | return {
69 | metric: stat.stddev() for metric, stat in self.metrics.items()
70 | }
71 |
72 | def __repr__(self):
73 | means = self.means()
74 | stddevs = self.stddevs()
75 | metric_names = sorted(list(means))
76 | return ' '.join(
77 | f'{name} = {means[name]:.4g} +/- {2 * stddevs[name]:.4g}' for name in metric_names
78 | )
79 |
80 |
81 | def evaluate(args, recons_key):
82 | metrics = Metrics(METRIC_FUNCS)
83 |
84 | for tgt_file in args.target_path.iterdir():
85 | with h5py.File(tgt_file) as target, h5py.File(
86 | args.predictions_path / tgt_file.name) as recons:
87 | if args.acquisition and args.acquisition != target.attrs['acquisition']:
88 | continue
89 | target = target[recons_key].value
90 | recons = recons['reconstruction'].value
91 | metrics.push(target, recons)
92 | return metrics
93 |
94 |
95 | if __name__ == '__main__':
96 | parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
97 | parser.add_argument('--target-path', type=pathlib.Path, required=True,
98 | help='Path to the ground truth data')
99 | parser.add_argument('--predictions-path', type=pathlib.Path, required=True,
100 | help='Path to reconstructions')
101 | parser.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True,
102 | help='Which challenge')
103 | parser.add_argument('--acquisition', choices=['CORPD_FBK', 'CORPDFS_FBK'], default=None,
104 | help='If set, only volumes of the specified acquisition type are used '
105 | 'for evaluation. By default, all volumes are included.')
106 | args = parser.parse_args()
107 |
108 | recons_key = 'reconstruction_rss' if args.challenge == 'multicoil' else 'reconstruction_esc'
109 | metrics = evaluate(args, recons_key)
110 | print(metrics)
111 |
--------------------------------------------------------------------------------
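A small sketch of the per-volume metrics above on synthetic data (not part of the repository; it assumes `common.evaluate` is importable, which in turn requires an older scikit-image that still provides `compare_psnr`/`compare_ssim`):

```python
import numpy as np
from common.evaluate import mse, nmse, psnr

gt = np.random.rand(10, 320, 320)               # ground-truth volume
pred = gt + 0.01 * np.random.randn(*gt.shape)   # noisy "reconstruction"

print(f'MSE  = {mse(gt, pred):.3e}')
print(f'NMSE = {nmse(gt, pred):.3e}')
print(f'PSNR = {psnr(gt, pred):.2f} dB')
```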
/DIP_UNET_models/unet_and_tv/common/subsample.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import numpy as np
9 | import torch
10 |
11 | def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations):
12 | if mask_type_str == 'random':
13 | return MaskFunc(center_fractions, accelerations)
14 | elif mask_type_str == 'equispaced':
15 | return EquispacedMaskFunc(center_fractions, accelerations)
16 | else:
17 | raise Exception(f"{mask_type_str} not supported")
18 |
19 | class MaskFunc:
20 | """
21 | MaskFunc creates a sub-sampling mask of a given shape.
22 |
23 | The mask selects a subset of columns from the input k-space data. If the k-space data has N
24 | columns, the mask picks out:
25 | 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to
26 | low-frequencies
27 | 2. The other columns are selected uniformly at random with a probability equal to:
28 | prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs).
29 | This ensures that the expected number of columns selected is equal to (N / acceleration)
30 |
31 | It is possible to use multiple center_fractions and accelerations, in which case one possible
32 | (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is
33 | called.
34 |
35 | For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there
36 | is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50%
37 | probability that 8-fold acceleration with 4% center fraction is selected.
38 | """
39 |
40 | def __init__(self, center_fractions, accelerations):
41 | """
42 | Args:
43 | center_fractions (List[float]): Fraction of low-frequency columns to be retained.
44 | If multiple values are provided, then one of these numbers is chosen uniformly
45 | each time.
46 |
47 | accelerations (List[int]): Amount of under-sampling. This should have the same length
48 | as center_fractions. If multiple values are provided, then one of these is chosen
49 | uniformly each time. An acceleration of 4 retains 25% of the columns, but they may
50 | not be spaced evenly.
51 | """
52 | if len(center_fractions) != len(accelerations):
53 | raise ValueError('Number of center fractions should match number of accelerations')
54 |
55 | self.center_fractions = center_fractions
56 | self.accelerations = accelerations
57 | self.rng = np.random.RandomState()
58 |
59 | def __call__(self, shape, seed=None):
60 | """
61 | Args:
62 | shape (iterable[int]): The shape of the mask to be created. The shape should have
63 | at least 3 dimensions. Samples are drawn along the second last dimension.
64 | seed (int, optional): Seed for the random number generator. Setting the seed
65 | ensures the same mask is generated each time for the same shape.
66 | Returns:
67 | torch.Tensor: A mask of the specified shape.
68 | """
69 | if len(shape) < 3:
70 | raise ValueError('Shape should have 3 or more dimensions')
71 |
72 | self.rng.seed(seed)
73 | num_cols = shape[-2]
74 |
75 | choice = self.rng.randint(0, len(self.accelerations))
76 | center_fraction = self.center_fractions[choice]
77 | acceleration = self.accelerations[choice]
78 |
79 | # Create the mask
80 | num_low_freqs = int(round(num_cols * center_fraction))
81 | prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
82 | mask = self.rng.uniform(size=num_cols) < prob
83 | pad = (num_cols - num_low_freqs + 1) // 2
84 | mask[pad:pad + num_low_freqs] = True
85 |
86 | # Reshape the mask
87 | mask_shape = [1 for _ in shape]
88 | mask_shape[-2] = num_cols
89 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
90 |
91 | return mask
92 |
--------------------------------------------------------------------------------
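A short sketch of drawing an undersampling mask with `MaskFunc` above (not part of the repository; it assumes the `unet_and_tv` directory is on the path so that `common.subsample` is importable):

```python
from common.subsample import MaskFunc

mask_func = MaskFunc(center_fractions=[0.08], accelerations=[4])

# The shape must have at least 3 dimensions; columns are selected along the
# second-to-last dimension (here 368 k-space columns).
mask = mask_func(shape=(1, 640, 368, 2), seed=42)
print(mask.shape)               # torch.Size([1, 1, 368, 1])
print(float(mask.sum()) / 368)  # roughly 0.25, i.e. 4-fold acceleration
```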
/DIP_UNET_models/unet_and_tv/common/test_subsample.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import numpy as np
9 | import pytest
10 | import torch
11 |
12 | from common.subsample import MaskFunc
13 |
14 |
15 | @pytest.mark.parametrize("center_fracs, accelerations, batch_size, dim", [
16 | ([0.2], [4], 4, 320),
17 | ([0.2, 0.4], [4, 8], 2, 368),
18 | ])
19 | def test_mask_reuse(center_fracs, accelerations, batch_size, dim):
20 | mask_func = MaskFunc(center_fracs, accelerations)
21 | shape = (batch_size, dim, dim, 2)
22 | mask1 = mask_func(shape, seed=123)
23 | mask2 = mask_func(shape, seed=123)
24 | mask3 = mask_func(shape, seed=123)
25 | assert torch.all(mask1 == mask2)
26 | assert torch.all(mask2 == mask3)
27 |
28 |
29 | @pytest.mark.parametrize("center_fracs, accelerations, batch_size, dim", [
30 | ([0.2], [4], 4, 320),
31 | ([0.2, 0.4], [4, 8], 2, 368),
32 | ])
33 | def test_mask_low_freqs(center_fracs, accelerations, batch_size, dim):
34 | mask_func = MaskFunc(center_fracs, accelerations)
35 | shape = (batch_size, dim, dim, 2)
36 | mask = mask_func(shape, seed=123)
37 | mask_shape = [1 for _ in shape]
38 | mask_shape[-2] = dim
39 | assert list(mask.shape) == mask_shape
40 |
41 | num_low_freqs_matched = False
42 | for center_frac in center_fracs:
43 | num_low_freqs = int(round(dim * center_frac))
44 | pad = (dim - num_low_freqs + 1) // 2
45 | if np.all(mask[pad:pad + num_low_freqs].numpy() == 1):
46 | num_low_freqs_matched = True
47 | assert num_low_freqs_matched
48 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/common/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 | import json
8 |
9 | import h5py
10 |
11 |
12 | def save_reconstructions(reconstructions, out_dir):
13 | """
14 |     Saves the reconstructions from a model into h5 files that are appropriate for submission
15 | to the leaderboard.
16 |
17 | Args:
18 | reconstructions (dict[str, np.array]): A dictionary mapping input filenames to
19 | corresponding reconstructions (of shape num_slices x height x width).
20 | out_dir (pathlib.Path): Path to the output directory where the reconstructions
21 | should be saved.
22 | """
23 | out_dir.mkdir(exist_ok=True)
24 | for fname, recons in reconstructions.items():
25 | with h5py.File(out_dir / fname, 'w') as f:
26 | f.create_dataset('reconstruction', data=recons)
27 |
28 |
29 | def tensor_to_complex_np(data):
30 | """
31 | Converts a complex torch tensor to numpy array.
32 | Args:
33 | data (torch.Tensor): Input data to be converted to numpy.
34 |
35 | Returns:
36 | np.array: Complex numpy version of data
37 | """
38 | data = data.numpy()
39 | return data[..., 0] + 1j * data[..., 1]
40 |
--------------------------------------------------------------------------------
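A short sketch for the helpers above (not part of the repository; it assumes `common.utils` is importable and uses a hypothetical output filename):

```python
import pathlib
import numpy as np
import torch
from common.utils import save_reconstructions, tensor_to_complex_np

data = torch.randn(10, 320, 320, 2)           # last dimension holds (real, imag)
complex_np = tensor_to_complex_np(data)
print(complex_np.dtype, complex_np.shape)     # complex64 (10, 320, 320)

recons = {'volume_demo.h5': np.abs(complex_np)}   # hypothetical filename
save_reconstructions(recons, pathlib.Path('reconstructions'))
```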
/DIP_UNET_models/unet_and_tv/data/README.md:
--------------------------------------------------------------------------------
1 | ## MRI Data Loader and Transforms
2 |
3 | This directory provides a reference data loader to read the fastMRI data one slice at a time and
4 | some useful data transforms to work with the data in PyTorch.
5 |
6 | Each partition (train, validation or test) of the fastMRI data is distributed as a set of HDF5
7 | files, such that each HDF5 file contains data from one MR acquisition. The set of fields and
8 | attributes in these HDF5 files depends on the track (single-coil or multi-coil) and the data
9 | partition.
10 |
11 | #### Single-Coil Track
12 |
13 | * Training & Validation data:
14 | * `kspace`: Emulated single-coil k-space data. The shape of the kspace tensor is
15 | (number of slices, height, width).
16 | * `reconstruction_rss`: Root-sum-of-squares reconstruction of the multi-coil k-space that was
17 | used to derive the emulated single-coil k-space cropped to the center 320 × 320 region.
18 | The shape of the reconstruction rss tensor is (number of slices, 320, 320).
19 | * `reconstruction_esc`: The inverse Fourier transform of the single-coil k-space data cropped
20 | to the center 320 × 320 region. The shape of the reconstruction esc tensor is (number of
21 | slices, 320, 320).
22 | * Test data:
23 | * `kspace`: Undersampled emulated single-coil k-space. The shape of the kspace tensor is
24 | (number of slices, height, width).
25 | * `mask`: Defines the undersampled Cartesian k-space trajectory. The number of elements in
26 | the mask tensor is the same as the width of k-space.
27 |
28 |
29 | #### Multi-Coil Track
30 |
31 | * Training & Validation data:
32 | * `kspace`: Multi-coil k-space data. The shape of the kspace tensor is
33 | (number of slices, number of coils, height, width).
34 | * `reconstruction_rss`: Root-sum-of-squares reconstruction of the multi-coil k-space
35 | data cropped to the center 320 × 320 region. The shape of the reconstruction rss tensor
36 | is (number of slices, 320, 320).
37 | * Test data:
38 | * `kspace`: Undersampled multi-coil k-space. The shape of the kspace tensor is
39 | (number of slices, number of coils, height, width).
40 |   * `mask`: Defines the undersampled Cartesian k-space trajectory. The number of elements in
41 | the mask tensor is the same as the width of k-space.
42 |
43 |
44 | ### Data Transforms
45 |
46 | `data/transforms.py` provides a number of useful data transformation functions that work with
47 | PyTorch tensors.
48 |
49 |
50 | #### Data Loader
51 |
52 | `data/mri_data.py` provides a `SliceData` class to read one MR slice at a time. It takes as input
53 | a `transform` function or callable object that can be used to transform the data into the format
54 | that you need. This makes the data loader versatile, so it can be used to run different kinds of
55 | reconstruction methods.
56 |
57 | The following is a simple example for how to use the data loader. For more concrete examples,
58 | please look at the baseline model code in the `models` directory.
59 |
60 | ```python
61 | import pathlib
62 | from common import subsample
63 | from data import transforms, mri_data
64 |
65 | # Create a mask function
66 | mask_func = subsample.RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8])
67 |
68 | def data_transform(kspace, target, data_attributes, filename, slice_num):
69 | # Transform the data into appropriate format
70 | # Here we simply mask the k-space and return the result
71 | kspace = transforms.to_tensor(kspace)
72 | masked_kspace, _ = transforms.apply_mask(kspace, mask_func)
73 | return masked_kspace
74 |
75 | dataset = mri_data.SliceData(
76 | root=pathlib.Path('path/to/data'),
77 | transform=data_transform,
78 | challenge='singlecoil'
79 | )
80 |
81 | for masked_kspace in dataset:
82 | # Do reconstruction
83 | pass
84 | ```
85 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/mri_data.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/data/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/mri_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import pathlib
9 | import random
10 |
11 | import h5py
12 | from torch.utils.data import Dataset
13 |
14 |
15 | class SliceData(Dataset):
16 | """
17 | A PyTorch Dataset that provides access to MR image slices.
18 | """
19 |
20 | def __init__(self, root, transform, challenge, sample_rate=1):
21 | """
22 | Args:
23 | root (pathlib.Path): Path to the dataset.
24 | transform (callable): A callable object that pre-processes the raw data into
25 | appropriate form. The transform function should take 'kspace', 'target',
26 | 'attributes', 'filename', and 'slice' as inputs. 'target' may be null
27 | for test data.
28 | challenge (str): "singlecoil" or "multicoil" depending on which challenge to use.
29 | sample_rate (float, optional): A float between 0 and 1. This controls what fraction
30 | of the volumes should be loaded.
31 | """
32 | if challenge not in ('singlecoil', 'multicoil'):
33 | raise ValueError('challenge should be either "singlecoil" or "multicoil"')
34 |
35 | self.transform = transform
36 | self.recons_key = 'reconstruction_esc' if challenge == 'singlecoil' \
37 | else 'reconstruction_rss'
38 |
39 | self.examples = []
40 | files = list(pathlib.Path(root).iterdir())
41 | if sample_rate < 1:
42 | random.shuffle(files)
43 | num_files = round(len(files) * sample_rate)
44 | files = files[:num_files]
45 | for fname in sorted(files):
46 | kspace = h5py.File(fname, 'r')['kspace']
47 | num_slices = kspace.shape[0]
48 | self.examples += [(fname, slice) for slice in range(num_slices)]
49 |
50 | def __len__(self):
51 | return len(self.examples)
52 |
53 | def __getitem__(self, i):
54 | fname, slice = self.examples[i]
55 | with h5py.File(fname, 'r') as data:
56 | kspace = data['kspace'][slice]
57 | target = data[self.recons_key][slice] if self.recons_key in data else None
58 | return self.transform(kspace, target, data.attrs, fname.name, slice)
59 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/test_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import numpy as np
9 | import pytest
10 | import torch
11 |
12 | from common import utils
13 | from common.subsample import RandomMaskFunc
14 | from data import transforms
15 |
16 |
17 | def create_input(shape):
18 | input = np.arange(np.product(shape)).reshape(shape)
19 | input = torch.from_numpy(input).float()
20 | return input
21 |
22 |
23 | @pytest.mark.parametrize('shape, center_fractions, accelerations', [
24 | ([4, 32, 32, 2], [0.08], [4]),
25 | ([2, 64, 64, 2], [0.04, 0.08], [8, 4]),
26 | ])
27 | def test_apply_mask(shape, center_fractions, accelerations):
28 | mask_func = RandomMaskFunc(center_fractions, accelerations)
29 | expected_mask = mask_func(shape, seed=123)
30 | input = create_input(shape)
31 | output, mask = transforms.apply_mask(input, mask_func, seed=123)
32 | assert output.shape == input.shape
33 | assert mask.shape == expected_mask.shape
34 | assert np.all(expected_mask.numpy() == mask.numpy())
35 | assert np.all(np.where(mask.numpy() == 0, 0, output.numpy()) == output.numpy())
36 |
37 |
38 | @pytest.mark.parametrize('shape', [
39 | [3, 3],
40 | [4, 6],
41 | [10, 8, 4],
42 | ])
43 | def test_fft2(shape):
44 | shape = shape + [2]
45 | input = create_input(shape)
46 | out_torch = transforms.fft2(input).numpy()
47 | out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]
48 |
49 | input_numpy = utils.tensor_to_complex_np(input)
50 | input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
51 | out_numpy = np.fft.fft2(input_numpy, norm='ortho')
52 | out_numpy = np.fft.fftshift(out_numpy, (-2, -1))
53 | assert np.allclose(out_torch, out_numpy)
54 |
55 |
56 | @pytest.mark.parametrize('shape', [
57 | [3, 3],
58 | [4, 6],
59 | [10, 8, 4],
60 | ])
61 | def test_ifft2(shape):
62 | shape = shape + [2]
63 | input = create_input(shape)
64 | out_torch = transforms.ifft2(input).numpy()
65 | out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]
66 |
67 | input_numpy = utils.tensor_to_complex_np(input)
68 | input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
69 | out_numpy = np.fft.ifft2(input_numpy, norm='ortho')
70 | out_numpy = np.fft.fftshift(out_numpy, (-2, -1))
71 | assert np.allclose(out_torch, out_numpy)
72 |
73 |
74 | @pytest.mark.parametrize('shape', [
75 | [3, 3],
76 | [4, 6],
77 | [10, 8, 4],
78 | ])
79 | def test_complex_abs(shape):
80 | shape = shape + [2]
81 | input = create_input(shape)
82 | out_torch = transforms.complex_abs(input).numpy()
83 | input_numpy = utils.tensor_to_complex_np(input)
84 | out_numpy = np.abs(input_numpy)
85 | assert np.allclose(out_torch, out_numpy)
86 |
87 |
88 | @pytest.mark.parametrize('shape, dim', [
89 | [[3, 3], 0],
90 | [[4, 6], 1],
91 | [[10, 8, 4], 2],
92 | ])
93 | def test_root_sum_of_squares(shape, dim):
94 | input = create_input(shape)
95 | out_torch = transforms.root_sum_of_squares(input, dim).numpy()
96 | out_numpy = np.sqrt(np.sum(input.numpy() ** 2, dim))
97 | assert np.allclose(out_torch, out_numpy)
98 |
99 |
100 | @pytest.mark.parametrize('shape, target_shape', [
101 | [[10, 10], [4, 4]],
102 | [[4, 6], [2, 4]],
103 | [[8, 4], [4, 4]],
104 | ])
105 | def test_center_crop(shape, target_shape):
106 | input = create_input(shape)
107 | out_torch = transforms.center_crop(input, target_shape).numpy()
108 | assert list(out_torch.shape) == target_shape
109 |
110 |
111 | @pytest.mark.parametrize('shape, target_shape', [
112 | [[10, 10], [4, 4]],
113 | [[4, 6], [2, 4]],
114 | [[8, 4], [4, 4]],
115 | ])
116 | def test_complex_center_crop(shape, target_shape):
117 | shape = shape + [2]
118 | input = create_input(shape)
119 | out_torch = transforms.complex_center_crop(input, target_shape).numpy()
120 | assert list(out_torch.shape) == target_shape + [2, ]
121 |
122 |
123 | @pytest.mark.parametrize('shape, mean, stddev', [
124 | [[10, 10], 0, 1],
125 | [[4, 6], 4, 10],
126 | [[8, 4], 2, 3],
127 | ])
128 | def test_normalize(shape, mean, stddev):
129 | input = create_input(shape)
130 | output = transforms.normalize(input, mean, stddev).numpy()
131 | assert np.isclose(output.mean(), (input.numpy().mean() - mean) / stddev)
132 | assert np.isclose(output.std(), input.numpy().std() / stddev)
133 |
134 |
135 | @pytest.mark.parametrize('shape', [
136 | [10, 10],
137 | [20, 40, 30],
138 | ])
139 | def test_normalize_instance(shape):
140 | input = create_input(shape)
141 | output, mean, stddev = transforms.normalize_instance(input)
142 | output = output.numpy()
143 | assert np.isclose(input.numpy().mean(), mean, rtol=1e-2)
144 | assert np.isclose(input.numpy().std(), stddev, rtol=1e-2)
145 | assert np.isclose(output.mean(), 0, rtol=1e-2, atol=1e-3)
146 | assert np.isclose(output.std(), 1, rtol=1e-2, atol=1e-3)
147 |
148 |
149 | @pytest.mark.parametrize('shift, dim', [
150 | (0, 0),
151 | (1, 0),
152 | (-1, 0),
153 | (100, 0),
154 | ((1, 2), (1, 2)),
155 | ])
156 | @pytest.mark.parametrize('shape', [
157 | [5, 6, 2],
158 | [3, 4, 5],
159 | ])
160 | def test_roll(shift, dim, shape):
161 | input = np.arange(np.product(shape)).reshape(shape)
162 | out_torch = transforms.roll(torch.from_numpy(input), shift, dim).numpy()
163 | out_numpy = np.roll(input, shift, dim)
164 | assert np.allclose(out_torch, out_numpy)
165 |
166 |
167 | @pytest.mark.parametrize('shape', [
168 | [5, 3],
169 | [2, 4, 6],
170 | ])
171 | def test_fftshift(shape):
172 | input = np.arange(np.product(shape)).reshape(shape)
173 | out_torch = transforms.fftshift(torch.from_numpy(input)).numpy()
174 | out_numpy = np.fft.fftshift(input)
175 | assert np.allclose(out_torch, out_numpy)
176 |
177 |
178 | @pytest.mark.parametrize('shape', [
179 | [5, 3],
180 | [2, 4, 5],
181 | [2, 7, 5],
182 | ])
183 | def test_ifftshift(shape):
184 | input = np.arange(np.product(shape)).reshape(shape)
185 | out_torch = transforms.ifftshift(torch.from_numpy(input)).numpy()
186 | out_numpy = np.fft.ifftshift(input)
187 | assert np.allclose(out_torch, out_numpy)
188 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/data/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import numpy as np
9 | import torch
10 |
11 |
12 | def to_tensor(data):
13 | """
14 | Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts
15 | are stacked along the last dimension.
16 |
17 | Args:
18 | data (np.array): Input numpy array
19 |
20 | Returns:
21 | torch.Tensor: PyTorch version of data
22 | """
23 | if np.iscomplexobj(data):
24 | data = np.stack((data.real, data.imag), axis=-1)
25 | return torch.from_numpy(data)
26 |
27 |
28 | def apply_mask(data, mask_func, seed=None):
29 | """
30 | Subsample given k-space by multiplying with a mask.
31 |
32 | Args:
33 | data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where
34 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size
35 | 2 (for complex values).
36 | mask_func (callable): A function that takes a shape (tuple of ints) and a random
37 | number seed and returns a mask.
38 | seed (int or 1-d array_like, optional): Seed for the random number generator.
39 |
40 | Returns:
41 | (tuple): tuple containing:
42 | masked data (torch.Tensor): Subsampled k-space data
43 | mask (torch.Tensor): The generated mask
44 | """
45 | shape = np.array(data.shape)
46 | shape[:-3] = 1
47 | mask = mask_func(shape, seed)
48 | return torch.where(mask == 0, torch.Tensor([0]), data), mask
49 |
50 |
51 | def fft2(data):
52 | """
53 | Apply centered 2 dimensional Fast Fourier Transform.
54 |
55 | Args:
56 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
57 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
58 | assumed to be batch dimensions.
59 |
60 | Returns:
61 | torch.Tensor: The FFT of the input.
62 | """
63 | assert data.size(-1) == 2
64 | data = ifftshift(data, dim=(-3, -2))
65 | data = torch.fft(data, 2, normalized=True)
66 | data = fftshift(data, dim=(-3, -2))
67 | return data
68 |
69 |
70 | def ifft2(data):
71 | """
72 | Apply centered 2-dimensional Inverse Fast Fourier Transform.
73 |
74 | Args:
75 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
76 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
77 | assumed to be batch dimensions.
78 |
79 | Returns:
80 | torch.Tensor: The IFFT of the input.
81 | """
82 | assert data.size(-1) == 2
83 | data = ifftshift(data, dim=(-3, -2))
84 | data = torch.ifft(data, 2, normalized=True)
85 | data = fftshift(data, dim=(-3, -2))
86 | return data
87 |
88 |
89 | def complex_abs(data):
90 | """
91 | Compute the absolute value of a complex valued input tensor.
92 |
93 | Args:
94 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension
95 | should be 2.
96 |
97 | Returns:
98 | torch.Tensor: Absolute value of data
99 | """
100 | assert data.size(-1) == 2
101 | return (data ** 2).sum(dim=-1).sqrt()
102 |
103 |
104 | def root_sum_of_squares(data, dim=0):
105 | """
106 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor.
107 |
108 | Args:
109 | data (torch.Tensor): The input tensor
110 | dim (int): The dimensions along which to apply the RSS transform
111 |
112 | Returns:
113 | torch.Tensor: The RSS value
114 | """
115 | return torch.sqrt((data ** 2).sum(dim))
116 |
117 |
118 | def center_crop(data, shape):
119 | """
120 | Apply a center crop to the input real image or batch of real images.
121 |
122 | Args:
123 | data (torch.Tensor): The input tensor to be center cropped. It should have at
124 | least 2 dimensions and the cropping is applied along the last two dimensions.
125 | shape (int, int): The output shape. The shape should be smaller than the
126 | corresponding dimensions of data.
127 |
128 | Returns:
129 | torch.Tensor: The center cropped image
130 | """
131 | assert 0 < shape[0] <= data.shape[-2]
132 | assert 0 < shape[1] <= data.shape[-1]
133 | w_from = (data.shape[-2] - shape[0]) // 2
134 | h_from = (data.shape[-1] - shape[1]) // 2
135 | w_to = w_from + shape[0]
136 | h_to = h_from + shape[1]
137 | return data[..., w_from:w_to, h_from:h_to]
138 |
139 |
140 | def complex_center_crop(data, shape):
141 | """
142 | Apply a center crop to the input image or batch of complex images.
143 |
144 | Args:
145 | data (torch.Tensor): The complex input tensor to be center cropped. It should
146 | have at least 3 dimensions and the cropping is applied along dimensions
147 | -3 and -2 and the last dimensions should have a size of 2.
148 | shape (int, int): The output shape. The shape should be smaller than the
149 | corresponding dimensions of data.
150 |
151 | Returns:
152 | torch.Tensor: The center cropped image
153 | """
154 | assert 0 < shape[0] <= data.shape[-3]
155 | assert 0 < shape[1] <= data.shape[-2]
156 | w_from = (data.shape[-3] - shape[0]) // 2
157 | h_from = (data.shape[-2] - shape[1]) // 2
158 | w_to = w_from + shape[0]
159 | h_to = h_from + shape[1]
160 | return data[..., w_from:w_to, h_from:h_to, :]
161 |
162 |
163 | def normalize(data, mean, stddev, eps=0.):
164 | """
165 | Normalize the given tensor using:
166 | (data - mean) / (stddev + eps)
167 |
168 | Args:
169 | data (torch.Tensor): Input data to be normalized
170 | mean (float): Mean value
171 | stddev (float): Standard deviation
172 | eps (float): Added to stddev to prevent dividing by zero
173 |
174 | Returns:
175 | torch.Tensor: Normalized tensor
176 | """
177 | return (data - mean) / (stddev + eps)
178 |
179 |
180 | def normalize_instance(data, eps=0.):
181 | """
182 | Normalize the given tensor using:
183 | (data - mean) / (stddev + eps)
184 | where mean and stddev are computed from the data itself.
185 |
186 | Args:
187 | data (torch.Tensor): Input data to be normalized
188 | eps (float): Added to stddev to prevent dividing by zero
189 |
190 | Returns:
191 | torch.Tensor: Normalized tensor
192 | """
193 | mean = data.mean()
194 | std = data.std()
195 | return normalize(data, mean, std, eps), mean, std
196 |
197 |
198 | # Helper functions
199 |
200 | def roll(x, shift, dim):
201 | """
202 | Similar to np.roll but applies to PyTorch Tensors
203 | """
204 | if isinstance(shift, (tuple, list)):
205 | assert len(shift) == len(dim)
206 | for s, d in zip(shift, dim):
207 | x = roll(x, s, d)
208 | return x
209 | shift = shift % x.size(dim)
210 | if shift == 0:
211 | return x
212 | left = x.narrow(dim, 0, x.size(dim) - shift)
213 | right = x.narrow(dim, x.size(dim) - shift, shift)
214 | return torch.cat((right, left), dim=dim)
215 |
216 |
217 | def fftshift(x, dim=None):
218 | """
219 | Similar to np.fft.fftshift but applies to PyTorch Tensors
220 | """
221 | if dim is None:
222 | dim = tuple(range(x.dim()))
223 | shift = [dim // 2 for dim in x.shape]
224 | elif isinstance(dim, int):
225 | shift = x.shape[dim] // 2
226 | else:
227 | shift = [x.shape[i] // 2 for i in dim]
228 | return roll(x, shift, dim)
229 |
230 |
231 | def ifftshift(x, dim=None):
232 | """
233 | Similar to np.fft.ifftshift but applies to PyTorch Tensors
234 | """
235 | if dim is None:
236 | dim = tuple(range(x.dim()))
237 | shift = [(dim + 1) // 2 for dim in x.shape]
238 | elif isinstance(dim, int):
239 | shift = (x.shape[dim] + 1) // 2
240 | else:
241 | shift = [(x.shape[i] + 1) // 2 for i in dim]
242 | return roll(x, shift, dim)
243 |
--------------------------------------------------------------------------------
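A minimal usage sketch (not part of the repository) showing how the transforms above compose for one multi-coil slice. It assumes an older PyTorch (< 1.8) where torch.fft / torch.ifft are still functions, and that it is run from the unet_and_tv directory so that "from data import transforms" resolves, as in run_bart.ipynb; the shapes are purely illustrative.

import torch
from data import transforms

kspace = torch.randn(15, 640, 368, 2)                   # dummy k-space: (coils, rows, cols, real/imag)
image = transforms.ifft2(kspace)                        # centered 2D inverse FFT, same shape
magnitude = transforms.complex_abs(image)               # (coils, rows, cols)
combined = transforms.root_sum_of_squares(magnitude)    # coil-combined image, (rows, cols)
cropped = transforms.center_crop(combined, (320, 320))  # cropped reconstruction target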
/DIP_UNET_models/unet_and_tv/mri_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | from collections import defaultdict
9 |
10 | import numpy as np
11 | import pytorch_lightning as pl
12 | import torch
13 | import torchvision
14 | from torch.utils.data import DistributedSampler, DataLoader
15 |
16 | from common import evaluate
17 | from common.utils import save_reconstructions
18 | from .data.mri_data import SliceData
19 |
20 |
21 | class MRIModel(pl.LightningModule):
22 | """
23 | Abstract super class for Deep Learning based reconstruction models.
24 | This is a subclass of the LightningModule class from pytorch_lightning, with
25 | some additional functionality specific to fastMRI:
26 | - fastMRI data loaders
27 | - Evaluating reconstructions
28 | - Visualization
29 | - Saving test reconstructions
30 |
31 | To implement a new reconstruction model, inherit from this class and implement the
32 | following methods:
33 | - train_data_transform, val_data_transform, test_data_transform:
34 | Create and return data transformer objects for each data split
35 | - training_step, validation_step, test_step:
36 | Define what happens in one step of training, validation and testing respectively
37 | - configure_optimizers:
38 | Create and return the optimizers
39 | Other methods from LightningModule can be overridden as needed.
40 | """
41 |
42 | def __init__(self, hparams):
43 | super().__init__()
44 | self.hparams = hparams
45 |
46 | def _create_data_loader(self, data_transform, data_partition, sample_rate=None):
47 | sample_rate = sample_rate or self.hparams.sample_rate
48 | dataset = SliceData(
49 | root=self.hparams.data_path / f'{self.hparams.challenge}_{data_partition}',
50 | transform=data_transform,
51 | sample_rate=sample_rate,
52 | challenge=self.hparams.challenge
53 | )
54 | sampler = DistributedSampler(dataset)
55 | return DataLoader(
56 | dataset=dataset,
57 | batch_size=self.hparams.batch_size,
58 | num_workers=0,
59 | pin_memory=True,
60 | sampler=sampler,
61 | )
62 |
63 | def train_data_transform(self):
64 | raise NotImplementedError
65 |
66 | @pl.data_loader
67 | def train_dataloader(self):
68 | return self._create_data_loader(self.train_data_transform(), data_partition='train')
69 |
70 | def val_data_transform(self):
71 | raise NotImplementedError
72 |
73 | @pl.data_loader
74 | def val_dataloader(self):
75 | return self._create_data_loader(self.val_data_transform(), data_partition='val')
76 |
77 | def test_data_transform(self):
78 | raise NotImplementedError
79 |
80 | @pl.data_loader
81 | def test_dataloader(self):
82 | return self._create_data_loader(self.test_data_transform(), data_partition='test', sample_rate=1.)
83 |
84 | def _evaluate(self, val_logs):
85 | losses = []
86 | outputs = defaultdict(list)
87 | targets = defaultdict(list)
88 | for log in val_logs:
89 | losses.append(log['val_loss'].cpu().numpy())
90 | for i, (fname, slice) in enumerate(zip(log['fname'], log['slice'])):
91 | outputs[fname].append((slice, log['output'][i]))
92 | targets[fname].append((slice, log['target'][i]))
93 | metrics = dict(val_loss=losses, nmse=[], ssim=[], psnr=[])
94 | for fname in outputs:
95 | output = np.stack([out for _, out in sorted(outputs[fname])])
96 | target = np.stack([tgt for _, tgt in sorted(targets[fname])])
97 | metrics['nmse'].append(evaluate.nmse(target, output))
98 | metrics['ssim'].append(evaluate.ssim(target, output))
99 | metrics['psnr'].append(evaluate.psnr(target, output))
100 | metrics = {metric: np.mean(values) for metric, values in metrics.items()}
101 | print(metrics, '\n')
102 | return dict(log=metrics, **metrics)
103 |
104 | def _visualize(self, val_logs):
105 | def _normalize(image):
106 | image = image[np.newaxis]
107 | image -= image.min()
108 | return image / image.max()
109 |
110 | def _save_image(image, tag):
111 | grid = torchvision.utils.make_grid(torch.Tensor(image), nrow=4, pad_value=1)
112 | self.logger.experiment.add_image(tag, grid)
113 |
114 | # Only process first size to simplify visualization.
115 | visualize_size = val_logs[0]['output'].shape
116 | val_logs = [x for x in val_logs if x['output'].shape == visualize_size]
117 | num_logs = len(val_logs)
118 | num_viz_images = 16
119 | step = (num_logs + num_viz_images - 1) // num_viz_images
120 | outputs, targets = [], []
121 | for i in range(0, num_logs, step):
122 | outputs.append(_normalize(val_logs[i]['output'][0]))
123 | targets.append(_normalize(val_logs[i]['target'][0]))
124 | outputs = np.stack(outputs)
125 | targets = np.stack(targets)
126 | _save_image(targets, 'Target')
127 | _save_image(outputs, 'Reconstruction')
128 | _save_image(np.abs(targets - outputs), 'Error')
129 |
130 | def validation_end(self, val_logs):
131 | self._visualize(val_logs)
132 | return self._evaluate(val_logs)
133 |
134 | def test_end(self, test_logs):
135 | outputs = defaultdict(list)
136 | for log in test_logs:
137 | for i, (fname, slice) in enumerate(zip(log['fname'], log['slice'])):
138 | outputs[fname].append((slice, log['output'][i]))
139 | for fname in outputs:
140 | outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])
141 | save_reconstructions(outputs, self.hparams.exp_dir / self.hparams.exp / 'reconstructions')
142 | return dict()
143 |
--------------------------------------------------------------------------------
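As a rough sketch (not part of the repository), this is roughly what a concrete subclass of MRIModel has to provide according to the docstring above; the bodies are placeholders only, and the repository's actual implementation of these hooks is UnetMRIModel in train_unet.py further below.

class MyMRIModel(MRIModel):
    def train_data_transform(self):
        ...  # return a data transform for the training split
    def val_data_transform(self):
        ...  # return a data transform for the validation split
    def test_data_transform(self):
        ...  # return a data transform for the test split
    def training_step(self, batch, batch_idx):
        ...  # compute and return dict(loss=..., log=...)
    def validation_step(self, batch, batch_idx):
        ...  # return fname, slice, output, target and val_loss for _evaluate/_visualize
    def test_step(self, batch, batch_idx):
        ...  # return fname, slice and output for test_end
    def configure_optimizers(self):
        ...  # return the optimizer(s) and, optionally, scheduler(s)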
/DIP_UNET_models/unet_and_tv/run_bart.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "sys.path.insert(0,'/root/bart-0.5.00/python/')\n",
11 | "\n",
12 | "\n",
13 | "import logging\n",
14 | "import multiprocessing\n",
15 | "import pathlib\n",
16 | "import random\n",
17 | "import time\n",
18 | "from collections import defaultdict\n",
19 | "\n",
20 | "import numpy as np\n",
21 | "import torch\n",
22 | "\n",
23 | "import bart\n",
24 | "from common import utils\n",
25 | "from common.args import Args\n",
26 | "from common.subsample import create_mask_for_mask_type\n",
27 | "from common.utils import tensor_to_complex_np\n",
28 | "from data import transforms\n",
29 | "from data.mri_data import SliceData\n",
30 | "\n",
31 | "logging.basicConfig(level=logging.INFO)\n",
32 | "logger = logging.getLogger(__name__)\n",
33 | "\n",
34 | "\n",
35 | "class DataTransform:\n",
36 | " \"\"\"\n",
37 | " Data Transformer that masks input k-space.\n",
38 | " \"\"\"\n",
39 | "\n",
40 | " def __init__(self, mask_func):\n",
41 | " \"\"\"\n",
42 | " Args:\n",
43 | " mask_func (common.subsample.MaskFunc): A function that can create a mask of\n",
44 | " appropriate shape.\n",
45 | " \"\"\"\n",
46 | " self.mask_func = mask_func\n",
47 | "\n",
48 | " def __call__(self, kspace, target, attrs, fname, slice):\n",
49 | " \"\"\"\n",
50 | " Args:\n",
51 | " kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil\n",
52 | " data or (rows, cols, 2) for single coil data.\n",
53 | " target (numpy.array, optional): Target image\n",
54 | " attrs (dict): Acquisition related information stored in the HDF5 object.\n",
55 | " fname (str): File name\n",
56 | " slice (int): Serial number of the slice.\n",
57 | " Returns:\n",
58 | " (tuple): tuple containing:\n",
59 | " masked_kspace (torch.Tensor): Sub-sampled k-space with the same shape as kspace.\n",
60 | " fname (str): File name containing the current data item\n",
61 | " slice (int): The index of the current slice in the volume\n",
62 | " \"\"\"\n",
63 | " kspace = transforms.to_tensor(kspace)\n",
64 | " seed = tuple(map(ord, fname))\n",
65 | " # Apply mask to raw k-space\n",
66 | " masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)\n",
67 | " return masked_kspace, fname, slice\n",
68 | "\n",
69 | "\n",
70 | "def create_data_loader(args):\n",
71 | " dev_mask = create_mask_for_mask_type(args.mask_type, args.center_fractions, args.accelerations)\n",
72 | " data = SliceData(\n",
73 | " root=args.data_path + str(f'{args.challenge}_val'),\n",
74 | " transform=DataTransform(dev_mask),\n",
75 | " challenge=args.challenge,\n",
76 | " sample_rate=args.sample_rate\n",
77 | " )\n",
78 | " return data\n",
79 | "\n",
80 | "\n",
81 | "def cs_total_variation(args, kspace):\n",
82 | " \"\"\"\n",
83 | " Run ESPIRIT coil sensitivity estimation and Total Variation Minimization based\n",
84 | " reconstruction algorithm using the BART toolkit.\n",
85 | " \"\"\"\n",
86 | "\n",
87 | " if args.challenge == 'singlecoil':\n",
88 | " kspace = kspace.unsqueeze(0)\n",
89 | " kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0)\n",
90 | " kspace = tensor_to_complex_np(kspace)\n",
91 | "\n",
92 | " # Estimate sensitivity maps\n",
93 | " sens_maps = bart.bart(1, f'ecalib -d0 -m1', kspace)\n",
94 | "\n",
95 | " # Use Total Variation Minimization to reconstruct the image\n",
96 | " pred = bart.bart(\n",
97 | " 1, f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}', kspace, sens_maps\n",
98 | " )\n",
99 | " pred = torch.from_numpy(np.abs(pred[0]))\n",
100 | "\n",
101 | " # Crop the predicted image to selected resolution if bigger\n",
102 | " smallest_width = min(args.resolution, pred.shape[-1])\n",
103 | " smallest_height = min(args.resolution, pred.shape[-2])\n",
104 | " return transforms.center_crop(pred, (smallest_height, smallest_width))\n",
105 | "\n",
106 | "\n",
107 | "def run_model(i):\n",
108 | " masked_kspace, fname, slice = data[i]\n",
109 | " prediction = cs_total_variation(args, masked_kspace)\n",
110 | " return fname, slice, prediction\n",
111 | "\n",
112 | "\n",
113 | "def main():\n",
114 | " if args.num_procs == 0:\n",
115 | " start_time = time.perf_counter()\n",
116 | " outputs = []\n",
117 | " for i in range(len(data)):\n",
118 | " outputs.append(run_model(i))\n",
119 | " save_outputs([run_model(i)], args.output_path)\n",
120 | " time_taken = time.perf_counter() - start_time\n",
121 | " else:\n",
122 | " with multiprocessing.Pool(args.num_procs) as pool:\n",
123 | " start_time = time.perf_counter()\n",
124 | " outputs = pool.map(run_model, range(len(data)))\n",
125 | " time_taken = time.perf_counter() - start_time\n",
126 | " save_outputs(outputs, args.output_path)\n",
127 | " logging.info(f'Run Time = {time_taken:}s')\n",
128 | " \n",
129 | "\n",
130 | "\n",
131 | "import json\n",
132 | "\n",
133 | "import h5py\n",
134 | "\n",
135 | "\n",
136 | "def save_reconstructions(reconstructions, out_dir):\n",
137 | " \"\"\"\n",
138 | " Saves the reconstructions from a model into h5 files that is appropriate for submission\n",
139 | " to the leaderboard.\n",
140 | "\n",
141 | " Args:\n",
142 | " reconstructions (dict[str, np.array]): A dictionary mapping input filenames to\n",
143 | " corresponding reconstructions (of shape num_slices x height x width).\n",
144 | " out_dir (pathlib.Path): Path to the output directory where the reconstructions\n",
145 | " should be saved.\n",
146 | " \"\"\"\n",
147 | " for fname, recons in reconstructions.items():\n",
148 | " with h5py.File(out_dir + fname, 'w') as f:\n",
149 | " f.create_dataset('reconstruction', data=recons)\n",
150 | " \n",
151 | "def save_outputs(outputs, output_path):\n",
152 | " reconstructions = defaultdict(list)\n",
153 | " for fname, slice, pred in outputs:\n",
154 | " reconstructions[fname].append((slice, pred))\n",
155 | " reconstructions = {\n",
156 | " fname: np.stack([pred for _, pred in sorted(slice_preds)])\n",
157 | " for fname, slice_preds in reconstructions.items()\n",
158 | " }\n",
159 | " save_reconstructions(reconstructions, output_path)"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 3,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "class Args():\n",
169 | " def __init__(self,ch,path,rate,acc,cent,outpath,iters,reg,procs,maskt,seed,res):\n",
170 | " self.challenge = ch\n",
171 | " self.data_path = path\n",
172 | " self.sample_rate = rate\n",
173 | " self.accelerations = acc\n",
174 | " self.center_fractions = cent\n",
175 | " self.output_path = outpath\n",
176 | " self.num_iters = iters\n",
177 | " self.reg_wt = reg\n",
178 | " self.num_procs = procs\n",
179 | " self.mask_type = maskt\n",
180 | " self.seed = seed\n",
181 | " self.resolution = res\n",
182 | "args = Args(\"multicoil\",\"/hdd/\",1,[4],[0.07],\"/root/multires_deep_decoder/mri/FINAL/TV/\",100,0.01,4,\"random\",42,320)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 4,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "import os\n",
192 | "import subprocess\n",
193 | "import sys\n",
194 | "\n",
195 | "os.environ['TOOLBOX_PATH'] = \"/root/bart-0.5.00/\" # visible in this process + all children\n",
196 | "random.seed(args.seed)\n",
197 | "np.random.seed(args.seed)\n",
198 | "torch.manual_seed(args.seed)\n",
199 | "\n",
200 | "dataset = create_data_loader(args)"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 4,
206 | "metadata": {},
207 | "outputs": [
208 | {
209 | "data": {
210 | "text/plain": [
211 | "((PosixPath('/hdd/multicoil_val/file1000017.h5'), 7), 7135, 3)"
212 | ]
213 | },
214 | "execution_count": 4,
215 | "metadata": {},
216 | "output_type": "execute_result"
217 | }
218 | ],
219 | "source": [
220 | "dataset.examples[80],len(dataset),len(dataset[0])"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 6,
226 | "metadata": {
227 | "scrolled": true
228 | },
229 | "outputs": [
230 | {
231 | "name": "stderr",
232 | "output_type": "stream",
233 | "text": [
234 | "INFO:root:Run Time = 333.7682563047856s\n",
235 | "INFO:root:Run Time = 340.6625652872026s\n",
236 | "INFO:root:Run Time = 301.1880727428943s\n",
237 | "INFO:root:Run Time = 274.9473853390664s\n",
238 | "INFO:root:Run Time = 395.30600419454277s\n",
239 | "INFO:root:Run Time = 338.08261120319366s\n",
240 | "INFO:root:Run Time = 357.19963121786714s\n",
241 | "INFO:root:Run Time = 415.9970267675817s\n",
242 | "INFO:root:Run Time = 351.20377369225025s\n",
243 | "INFO:root:Run Time = 335.272654408589s\n",
244 | "INFO:root:Run Time = 320.07918173260987s\n",
245 | "INFO:root:Run Time = 337.77712962403893s\n",
246 | "INFO:root:Run Time = 350.6248935814947s\n",
247 | "INFO:root:Run Time = 319.2754284348339s\n",
248 | "INFO:root:Run Time = 339.5878656320274s\n",
249 | "INFO:root:Run Time = 308.0043003074825s\n",
250 | "INFO:root:Run Time = 358.6736814174801s\n",
251 | "INFO:root:Run Time = 358.03209225833416s\n",
252 | "INFO:root:Run Time = 369.3443151284009s\n",
253 | "INFO:root:Run Time = 376.01454484649s\n",
254 | "INFO:root:Run Time = 327.3354206569493s\n",
255 | "INFO:root:Run Time = 311.5621940083802s\n",
256 | "INFO:root:Run Time = 315.0435768086463s\n",
257 | "INFO:root:Run Time = 325.7221032604575s\n",
258 | "INFO:root:Run Time = 291.2186458352953s\n",
259 | "INFO:root:Run Time = 424.02862623520195s\n",
260 | "INFO:root:Run Time = 380.9776658322662s\n",
261 | "INFO:root:Run Time = 251.65318821929395s\n",
262 | "INFO:root:Run Time = 340.7662496250123s\n",
263 | "INFO:root:Run Time = 467.76796199940145s\n",
264 | "INFO:root:Run Time = 311.6906248535961s\n",
265 | "INFO:root:Run Time = 534.1931731365621s\n",
266 | "INFO:root:Run Time = 298.5228198878467s\n",
267 | "INFO:root:Run Time = 305.9466844946146s\n",
268 | "INFO:root:Run Time = 354.2080086879432s\n",
269 | "INFO:root:Run Time = 425.45318215712905s\n",
270 | "INFO:root:Run Time = 444.71551893651485s\n",
271 | "INFO:root:Run Time = 395.4665974806994s\n",
272 | "INFO:root:Run Time = 450.11666345596313s\n",
273 | "INFO:root:Run Time = 348.74674895592034s\n",
274 | "INFO:root:Run Time = 424.314414229244s\n",
275 | "INFO:root:Run Time = 429.9421614725143s\n",
276 | "INFO:root:Run Time = 523.5043416228145s\n",
277 | "INFO:root:Run Time = 407.0870214384049s\n",
278 | "INFO:root:Run Time = 447.2023910563439s\n",
279 | "INFO:root:Run Time = 460.6034576483071s\n",
280 | "INFO:root:Run Time = 394.97820459865034s\n",
281 | "INFO:root:Run Time = 376.82941082306206s\n",
282 | "INFO:root:Run Time = 370.4078102763742s\n",
283 | "INFO:root:Run Time = 429.3003299552947s\n",
284 | "INFO:root:Run Time = 442.9237650651485s\n",
285 | "INFO:root:Run Time = 492.6802581641823s\n",
286 | "INFO:root:Run Time = 456.0776470042765s\n",
287 | "INFO:root:Run Time = 309.43278669193387s\n",
288 | "INFO:root:Run Time = 440.0848300922662s\n",
289 | "INFO:root:Run Time = 432.87664224021137s\n",
290 | "INFO:root:Run Time = 430.8312914352864s\n",
291 | "INFO:root:Run Time = 421.66501943580806s\n",
292 | "INFO:root:Run Time = 504.81006176769733s\n",
293 | "INFO:root:Run Time = 538.9878176227212s\n",
294 | "INFO:root:Run Time = 450.25816937349737s\n",
295 | "INFO:root:Run Time = 468.013383731246s\n",
296 | "INFO:root:Run Time = 398.28537249937654s\n",
297 | "INFO:root:Run Time = 517.1249352190644s\n",
298 | "INFO:root:Run Time = 337.58466753922403s\n",
299 | "INFO:root:Run Time = 390.9817121960223s\n",
300 | "INFO:root:Run Time = 336.4889906011522s\n",
301 | "INFO:root:Run Time = 472.28560002334416s\n",
302 | "INFO:root:Run Time = 474.4735741727054s\n",
303 | "INFO:root:Run Time = 394.586749330163s\n",
304 | "INFO:root:Run Time = 460.2016584146768s\n",
305 | "INFO:root:Run Time = 424.244948470965s\n",
306 | "INFO:root:Run Time = 444.0486744288355s\n",
307 | "INFO:root:Run Time = 465.4690223969519s\n",
308 | "INFO:root:Run Time = 355.617677655071s\n",
309 | "INFO:root:Run Time = 471.27355794236064s\n",
310 | "INFO:root:Run Time = 389.24966777674854s\n",
311 | "INFO:root:Run Time = 460.87415464036167s\n",
312 | "INFO:root:Run Time = 440.83318879827857s\n",
313 | "INFO:root:Run Time = 432.70478705503047s\n",
314 | "INFO:root:Run Time = 406.13156732544303s\n",
315 | "INFO:root:Run Time = 481.7247441355139s\n",
316 | "INFO:root:Run Time = 385.23040606081486s\n",
317 | "INFO:root:Run Time = 436.92768172733486s\n",
318 | "INFO:root:Run Time = 585.8447301853448s\n",
319 | "INFO:root:Run Time = 542.3469700664282s\n",
320 | "INFO:root:Run Time = 528.0813769325614s\n",
321 | "INFO:root:Run Time = 479.7424617353827s\n",
322 | "INFO:root:Run Time = 466.7871928848326s\n",
323 | "INFO:root:Run Time = 338.71273868344724s\n",
324 | "INFO:root:Run Time = 441.5500371027738s\n",
325 | "INFO:root:Run Time = 428.8800290077925s\n",
326 | "INFO:root:Run Time = 481.04396928846836s\n",
327 | "INFO:root:Run Time = 462.27772130444646s\n",
328 | "INFO:root:Run Time = 486.94645626842976s\n",
329 | "INFO:root:Run Time = 472.66510405763984s\n",
330 | "INFO:root:Run Time = 382.8884541783482s\n",
331 | "INFO:root:Run Time = 526.5198707617819s\n",
332 | "INFO:root:Run Time = 516.3673057667911s\n",
333 | "INFO:root:Run Time = 430.80366771668196s\n",
334 | "INFO:root:Run Time = 505.9666231442243s\n",
335 | "INFO:root:Run Time = 431.4829865563661s\n",
336 | "INFO:root:Run Time = 328.0189682934433s\n",
337 | "INFO:root:Run Time = 566.4797005280852s\n",
338 | "INFO:root:Run Time = 414.8851204998791s\n",
339 | "INFO:root:Run Time = 432.82284201681614s\n",
340 | "INFO:root:Run Time = 431.45999561063945s\n",
341 | "INFO:root:Run Time = 392.26196658983827s\n",
342 | "INFO:root:Run Time = 419.4514162931591s\n",
343 | "INFO:root:Run Time = 447.83714538253844s\n",
344 | "INFO:root:Run Time = 424.4092899840325s\n",
345 | "INFO:root:Run Time = 416.14934803172946s\n",
346 | "INFO:root:Run Time = 431.89119616523385s\n",
347 | "INFO:root:Run Time = 384.00734995119274s\n",
348 | "INFO:root:Run Time = 499.7653317414224s\n",
349 | "INFO:root:Run Time = 486.80337627232075s\n",
350 | "INFO:root:Run Time = 463.53818341344595s\n",
351 | "INFO:root:Run Time = 462.03550822660327s\n",
352 | "INFO:root:Run Time = 385.84180288761854s\n",
353 | "INFO:root:Run Time = 533.4132130518556s\n",
354 | "INFO:root:Run Time = 424.44709764048457s\n",
355 | "INFO:root:Run Time = 391.3850141931325s\n",
356 | "INFO:root:Run Time = 571.5244209095836s\n",
357 | "INFO:root:Run Time = 398.41338567994535s\n",
358 | "INFO:root:Run Time = 379.29579119198024s\n",
359 | "INFO:root:Run Time = 377.02867659181356s\n",
360 | "INFO:root:Run Time = 471.6786704082042s\n",
361 | "INFO:root:Run Time = 331.51744497194886s\n",
362 | "INFO:root:Run Time = 443.2164211887866s\n",
363 | "INFO:root:Run Time = 336.92970431409776s\n",
364 | "INFO:root:Run Time = 397.283165872097s\n",
365 | "INFO:root:Run Time = 396.1671455092728s\n",
366 | "INFO:root:Run Time = 330.75096010789275s\n",
367 | "INFO:root:Run Time = 447.02856631204486s\n",
368 | "INFO:root:Run Time = 390.8315053060651s\n",
369 | "INFO:root:Run Time = 392.9724945295602s\n",
370 | "INFO:root:Run Time = 507.23743664473295s\n",
371 | "INFO:root:Run Time = 288.0143873449415s\n",
372 | "INFO:root:Run Time = 496.09178671613336s\n",
373 | "INFO:root:Run Time = 430.7832391001284s\n",
374 | "INFO:root:Run Time = 418.55701276659966s\n",
375 | "INFO:root:Run Time = 510.4864778164774s\n",
376 | "INFO:root:Run Time = 437.0698127951473s\n",
377 | "INFO:root:Run Time = 301.17781374789774s\n",
378 | "INFO:root:Run Time = 379.37366267479956s\n",
379 | "INFO:root:Run Time = 420.580105535686s\n",
380 | "INFO:root:Run Time = 487.29186703264713s\n",
381 | "INFO:root:Run Time = 550.3076649773866s\n",
382 | "INFO:root:Run Time = 419.5049721375108s\n",
383 | "INFO:root:Run Time = 411.9790954552591s\n",
384 | "INFO:root:Run Time = 371.1687485575676s\n",
385 | "INFO:root:Run Time = 370.06326948851347s\n",
386 | "INFO:root:Run Time = 494.05042749270797s\n",
387 | "INFO:root:Run Time = 488.94175343960524s\n",
388 | "INFO:root:Run Time = 428.4196085240692s\n",
389 | "INFO:root:Run Time = 451.088182432577s\n",
390 | "INFO:root:Run Time = 446.9839248675853s\n",
391 | "INFO:root:Run Time = 450.20198553428054s\n",
392 | "INFO:root:Run Time = 412.72148968465626s\n",
393 | "INFO:root:Run Time = 551.8574121091515s\n",
394 | "INFO:root:Run Time = 714.7977640237659s\n",
395 | "INFO:root:Run Time = 515.068151017651s\n",
396 | "INFO:root:Run Time = 595.8476344123483s\n",
397 | "INFO:root:Run Time = 617.6524196490645s\n",
398 | "INFO:root:Run Time = 600.2221750151366s\n",
399 | "INFO:root:Run Time = 734.3725633509457s\n",
400 | "INFO:root:Run Time = 459.1810085400939s\n",
401 | "INFO:root:Run Time = 546.2445830088109s\n",
402 | "INFO:root:Run Time = 518.3114790972322s\n",
403 | "INFO:root:Run Time = 392.59291512332857s\n",
404 | "INFO:root:Run Time = 484.23065080679953s\n",
405 | "INFO:root:Run Time = 590.5678717344999s\n",
406 | "INFO:root:Run Time = 503.79120522737503s\n",
407 | "INFO:root:Run Time = 552.8279358427972s\n",
408 | "INFO:root:Run Time = 727.5677454862744s\n",
409 | "INFO:root:Run Time = 584.7055314276367s\n",
410 | "INFO:root:Run Time = 594.3992878198624s\n",
411 | "INFO:root:Run Time = 626.2585807479918s\n",
412 | "INFO:root:Run Time = 631.4217041246593s\n",
413 | "INFO:root:Run Time = 589.3917139805853s\n",
414 | "INFO:root:Run Time = 519.05597788468s\n",
415 | "INFO:root:Run Time = 759.129485655576s\n",
416 | "INFO:root:Run Time = 530.3207853045315s\n",
417 | "INFO:root:Run Time = 657.3508451338857s\n",
418 | "INFO:root:Run Time = 802.8821316771209s\n",
419 | "INFO:root:Run Time = 589.222182802856s\n",
420 | "INFO:root:Run Time = 471.22330814413726s\n",
421 | "INFO:root:Run Time = 711.6352181173861s\n",
422 | "INFO:root:Run Time = 660.118361864239s\n",
423 | "INFO:root:Run Time = 545.1319549642503s\n",
424 | "INFO:root:Run Time = 604.8600967172533s\n",
425 | "INFO:root:Run Time = 605.9351858440787s\n",
426 | "INFO:root:Run Time = 569.5467841047794s\n",
427 | "INFO:root:Run Time = 750.0642936062068s\n",
428 | "INFO:root:Run Time = 712.3209960609674s\n",
429 | "INFO:root:Run Time = 572.0809261798859s\n",
430 | "INFO:root:Run Time = 590.359125174582s\n",
431 | "INFO:root:Run Time = 676.413792964071s\n",
432 | "INFO:root:Run Time = 488.95933836884797s\n"
433 | ]
434 | }
435 | ],
436 | "source": [
437 | "this_data = []\n",
438 | "prev_slicenu = -1\n",
439 | "for i,d in enumerate(dataset):\n",
440 | " if dataset.examples[i][1] > prev_slicenu: \n",
441 | " this_data.append(d)\n",
442 | " prev_slicenu = dataset.examples[i][1]\n",
443 | " else:\n",
444 | " data = this_data\n",
445 | " main()\n",
446 | " this_data = [d]\n",
447 | " prev_slicenu = 0\n",
448 | " if i == len(dataset) - 1:\n",
449 | " data = this_data\n",
450 | " main()"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "metadata": {},
457 | "outputs": [],
458 | "source": []
459 | }
460 | ],
461 | "metadata": {
462 | "kernelspec": {
463 | "display_name": "Python 3",
464 | "language": "python",
465 | "name": "python3"
466 | },
467 | "language_info": {
468 | "codemirror_mode": {
469 | "name": "ipython",
470 | "version": 3
471 | },
472 | "file_extension": ".py",
473 | "mimetype": "text/x-python",
474 | "name": "python",
475 | "nbconvert_exporter": "python",
476 | "pygments_lexer": "ipython3",
477 | "version": "3.6.7"
478 | }
479 | },
480 | "nbformat": 4,
481 | "nbformat_minor": 4
482 | }
483 |
--------------------------------------------------------------------------------
/DIP_UNET_models/unet_and_tv/run_bart_val.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | This source code is licensed under the MIT license found in the
4 | LICENSE file in the root directory of this source tree.
5 | """
6 |
7 | import logging
8 | import multiprocessing
9 | import pathlib
10 | import random
11 | import time
12 | from collections import defaultdict
13 |
14 | import numpy as np
15 | import torch
16 |
17 | import bart
18 | from common import utils
19 | from common.args import Args
20 | from common.subsample import create_mask_for_mask_type
21 | from common.utils import tensor_to_complex_np
22 | from data import transforms
23 | from data.mri_data import SliceData
24 |
25 | logging.basicConfig(level=logging.INFO)
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | class DataTransform:
30 | """
31 | Data Transformer that masks input k-space.
32 | """
33 |
34 | def __init__(self, mask_func):
35 | """
36 | Args:
37 | mask_func (common.subsample.MaskFunc): A function that can create a mask of
38 | appropriate shape.
39 | """
40 | self.mask_func = mask_func
41 |
42 | def __call__(self, kspace, target, attrs, fname, slice):
43 | """
44 | Args:
45 | kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
46 | data or (rows, cols, 2) for single coil data.
47 | target (numpy.array, optional): Target image
48 | attrs (dict): Acquisition related information stored in the HDF5 object.
49 | fname (str): File name
50 | slice (int): Serial number of the slice.
51 | Returns:
52 | (tuple): tuple containing:
53 | masked_kspace (torch.Tensor): Sub-sampled k-space with the same shape as kspace.
54 | fname (str): File name containing the current data item
55 | slice (int): The index of the current slice in the volume
56 | """
57 | kspace = transforms.to_tensor(kspace)
58 | seed = tuple(map(ord, fname))
59 | # Apply mask to raw k-space
60 | masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
61 | return masked_kspace, fname, slice
62 |
63 |
64 | def create_data_loader(args):
65 | dev_mask = create_mask_for_mask_type(args.mask_type, args.center_fractions, args.accelerations)
66 | data = SliceData(
67 | root=args.data_path / f'{args.challenge}_val',
68 | transform=DataTransform(dev_mask),
69 | challenge=args.challenge,
70 | sample_rate=args.sample_rate
71 | )
72 | return data
73 |
74 |
75 | def cs_total_variation(args, kspace):
76 | """
77 | Run ESPIRIT coil sensitivity estimation and Total Variation Minimization based
78 | reconstruction algorithm using the BART toolkit.
79 | """
80 |
81 | if args.challenge == 'singlecoil':
82 | kspace = kspace.unsqueeze(0)
83 | kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0)
84 | kspace = tensor_to_complex_np(kspace)
85 |
86 | # Estimate sensitivity maps
87 | sens_maps = bart.bart(1, f'ecalib -d0 -m1', kspace)
88 |
89 | # Use Total Variation Minimization to reconstruct the image
90 | pred = bart.bart(
91 | 1, f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}', kspace, sens_maps
92 | )
93 | pred = torch.from_numpy(np.abs(pred[0]))
94 |
95 | # Crop the predicted image to selected resolution if bigger
96 | smallest_width = min(args.resolution, pred.shape[-1])
97 | smallest_height = min(args.resolution, pred.shape[-2])
98 | return transforms.center_crop(pred, (smallest_height, smallest_width))
99 |
100 |
101 | def run_model(i):
102 | masked_kspace, fname, slice = data[i]
103 | prediction = cs_total_variation(args, masked_kspace)
104 | return fname, slice, prediction
105 |
106 |
107 | def main():
108 | if args.num_procs == 0:
109 | start_time = time.perf_counter()
110 | outputs = []
111 | for i in range(len(data)):
112 | outputs.append(run_model(i))
113 | time_taken = time.perf_counter() - start_time
114 | else:
115 | with multiprocessing.Pool(args.num_procs) as pool:
116 | start_time = time.perf_counter()
117 | outputs = pool.map(run_model, range(len(data)))
118 | time_taken = time.perf_counter() - start_time
119 | logging.info(f'Run Time = {time_taken:}s')
120 | save_outputs(outputs, args.output_path)
121 |
122 |
123 | def save_outputs(outputs, output_path):
124 | reconstructions = defaultdict(list)
125 | for fname, slice, pred in outputs:
126 | reconstructions[fname].append((slice, pred))
127 | reconstructions = {
128 | fname: np.stack([pred for _, pred in sorted(slice_preds)])
129 | for fname, slice_preds in reconstructions.items()
130 | }
131 | utils.save_reconstructions(reconstructions, output_path)
132 |
133 |
134 | if __name__ == '__main__':
135 | parser = Args()
136 | parser.add_argument('--output-path', type=pathlib.Path, default=None,
137 | help='Path to save the reconstructions to')
138 | parser.add_argument('--num-iters', type=int, default=200,
139 | help='Number of iterations to run the reconstruction algorithm')
140 | parser.add_argument('--reg-wt', type=float, default=0.01,
141 | help='Regularization weight parameter')
142 | parser.add_argument('--num-procs', type=int, default=20,
143 | help='Number of processes. Set to 0 to disable multiprocessing.')
144 | parser.add_argument('--mask-type',default='random')
145 | args = parser.parse_args()
146 |
147 | random.seed(args.seed)
148 | np.random.seed(args.seed)
149 | torch.manual_seed(args.seed)
150 |
151 | data = create_data_loader(args)
152 | main()
--------------------------------------------------------------------------------
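A small illustration (not from the repository) of what save_outputs does with the (fname, slice, prediction) tuples returned by run_model: predictions are grouped per file and stacked in slice order before being written out; the file names and shapes below are made up.

import numpy as np
from collections import defaultdict

outputs = [('file1.h5', 1, np.zeros((320, 320))),   # dummy (fname, slice, prediction) tuples
           ('file1.h5', 0, np.ones((320, 320))),
           ('file2.h5', 0, np.ones((320, 320)))]

reconstructions = defaultdict(list)
for fname, slice, pred in outputs:
    reconstructions[fname].append((slice, pred))
reconstructions = {fname: np.stack([pred for _, pred in sorted(slice_preds)])
                   for fname, slice_preds in reconstructions.items()}

# reconstructions['file1.h5'] now has shape (2, 320, 320), with slice 0 first.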
/DIP_UNET_models/unet_and_tv/train_unet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import pathlib
9 | import random
10 |
11 | import numpy as np
12 | import torch
13 | from pytorch_lightning import Trainer
14 | from pytorch_lightning.logging import TestTubeLogger
15 | from torch.nn import functional as F
16 | from torch.optim import RMSprop
17 |
18 | from .common.args import Args
19 | from .common.subsample import create_mask_for_mask_type
20 | from .data import transforms
21 | from .mri_model import MRIModel
22 | from .unet_model import UnetModel
23 |
24 |
25 | class DataTransform:
26 | """
27 | Data Transformer for training U-Net models.
28 | """
29 |
30 | def __init__(self, resolution, which_challenge, mask_func=None, use_seed=True):
31 | """
32 | Args:
33 | mask_func (common.subsample.MaskFunc): A function that can create a mask of
34 | appropriate shape.
35 | resolution (int): Resolution of the image.
36 | which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset.
37 | use_seed (bool): If true, this class computes a pseudo random number generator seed
38 | from the filename. This ensures that the same mask is used for all the slices of
39 | a given volume every time.
40 | """
41 | if which_challenge not in ('singlecoil', 'multicoil'):
42 | raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"')
43 | self.mask_func = mask_func
44 | self.resolution = resolution
45 | self.which_challenge = which_challenge
46 | self.use_seed = use_seed
47 |
48 | def __call__(self, kspace, target, attrs, fname, slice):
49 | """
50 | Args:
51 | kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
52 | data or (rows, cols, 2) for single coil data.
53 | target (numpy.array): Target image
54 | attrs (dict): Acquisition related information stored in the HDF5 object.
55 | fname (str): File name
56 | slice (int): Serial number of the slice.
57 | Returns:
58 | (tuple): tuple containing:
59 | image (torch.Tensor): Zero-filled input image.
60 | target (torch.Tensor): Target image converted to a torch Tensor.
61 | mean (float): Mean value used for normalization.
62 | std (float): Standard deviation value used for normalization.
63 | """
64 | kspace = transforms.to_tensor(kspace)
65 | # Apply mask
66 | if self.mask_func:
67 | seed = None if not self.use_seed else tuple(map(ord, fname))
68 | masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
69 | else:
70 | masked_kspace = kspace
71 |
72 | # Inverse Fourier Transform to get zero filled solution
73 | image = transforms.ifft2(masked_kspace)
74 | # Crop input image to given resolution if larger
75 | smallest_width = min(self.resolution, image.shape[-2])
76 | smallest_height = min(self.resolution, image.shape[-3])
77 | if target is not None:
78 | smallest_width = min(smallest_width, target.shape[-1])
79 | smallest_height = min(smallest_height, target.shape[-2])
80 | crop_size = (smallest_height, smallest_width)
81 | image = transforms.complex_center_crop(image, crop_size)
82 | # Absolute value
83 | image = transforms.complex_abs(image)
84 | # Apply Root-Sum-of-Squares if multicoil data
85 | if self.which_challenge == 'multicoil':
86 | image = transforms.root_sum_of_squares(image)
87 | # Normalize input
88 | image, mean, std = transforms.normalize_instance(image, eps=1e-11)
89 | image = image.clamp(-6, 6)
90 | # Normalize target
91 | if target is not None:
92 | target = transforms.to_tensor(target)
93 | target = transforms.center_crop(target, crop_size)
94 | target = transforms.normalize(target, mean, std, eps=1e-11)
95 | target = target.clamp(-6, 6)
96 | else:
97 | target = torch.Tensor([0])
98 | return image, target, mean, std, fname, slice
99 |
100 |
101 | class UnetMRIModel(MRIModel):
102 | def __init__(self, hparams):
103 | super().__init__(hparams)
104 | self.unet = UnetModel(
105 | in_chans=1,
106 | out_chans=1,
107 | chans=hparams.num_chans,
108 | num_pool_layers=hparams.num_pools,
109 | drop_prob=hparams.drop_prob
110 | )
111 |
112 | def forward(self, input):
113 | return self.unet(input.unsqueeze(1)).squeeze(1)
114 |
115 | def training_step(self, batch, batch_idx):
116 | input, target, mean, std, _, _ = batch
117 | output = self.forward(input)
118 | loss = F.l1_loss(output, target)
119 | logs = {'loss': loss.item()}
120 | return dict(loss=loss, log=logs)
121 |
122 | def validation_step(self, batch, batch_idx):
123 | input, target, mean, std, fname, slice = batch
124 | output = self.forward(input)
125 | mean = mean.unsqueeze(1).unsqueeze(2)
126 | std = std.unsqueeze(1).unsqueeze(2)
127 | return {
128 | 'fname': fname,
129 | 'slice': slice,
130 | 'output': (output * std + mean).cpu().numpy(),
131 | 'target': (target * std + mean).cpu().numpy(),
132 | 'val_loss': F.l1_loss(output, target),
133 | }
134 |
135 | def test_step(self, batch, batch_idx):
136 | input, _, mean, std, fname, slice = batch
137 | output = self.forward(input)
138 | mean = mean.unsqueeze(1).unsqueeze(2)
139 | std = std.unsqueeze(1).unsqueeze(2)
140 | return {
141 | 'fname': fname,
142 | 'slice': slice,
143 | 'output': (output * std + mean).cpu().numpy(),
144 | }
145 |
146 | def configure_optimizers(self):
147 | optim = RMSprop(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
148 | scheduler = torch.optim.lr_scheduler.StepLR(optim, self.hparams.lr_step_size, self.hparams.lr_gamma)
149 | return [optim], [scheduler]
150 |
151 | def train_data_transform(self):
152 | mask = create_mask_for_mask_type(self.hparams.mask_type, self.hparams.center_fractions,
153 | self.hparams.accelerations)
154 | return DataTransform(self.hparams.resolution, self.hparams.challenge, mask, use_seed=False)
155 |
156 | def val_data_transform(self):
157 | mask = create_mask_for_mask_type(self.hparams.mask_type, self.hparams.center_fractions,
158 | self.hparams.accelerations)
159 | return DataTransform(self.hparams.resolution, self.hparams.challenge, mask)
160 |
161 | def test_data_transform(self):
162 | return DataTransform(self.hparams.resolution, self.hparams.challenge)
163 |
164 | @staticmethod
165 | def add_model_specific_args(parser):
166 | parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers')
167 | parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability')
168 | parser.add_argument('--num-chans', type=int, default=32, help='Number of U-Net channels')
169 | parser.add_argument('--batch-size', default=16, type=int, help='Mini batch size')
170 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
171 | parser.add_argument('--lr-step-size', type=int, default=40,
172 | help='Period of learning rate decay')
173 | parser.add_argument('--lr-gamma', type=float, default=0.1,
174 | help='Multiplicative factor of learning rate decay')
175 | parser.add_argument('--weight-decay', type=float, default=0.,
176 | help='Strength of weight decay regularization')
177 | parser.add_argument('--mask_type',default='random')
178 | return parser
179 |
180 |
181 | def create_trainer(args, logger):
182 | return Trainer(
183 | #num_nodes=1,
184 | logger=logger,
185 | default_save_path=args.exp_dir,
186 | checkpoint_callback=True,
187 | max_nb_epochs=args.num_epochs,
188 | gpus=args.gpus,
189 | distributed_backend='ddp',
190 | check_val_every_n_epoch=1,
191 | val_check_interval=1.,
192 | early_stop_callback=False
193 | )
194 |
195 |
196 | def main(args):
197 | if args.mode == 'train':
198 | load_version = 0 if args.resume else None
199 | logger = TestTubeLogger(save_dir=args.exp_dir, name=args.exp, version=load_version)
200 | trainer = create_trainer(args, logger)
201 | model = UnetMRIModel(args)
202 | trainer.fit(model)
203 | else: # args.mode == 'test'
204 | assert args.checkpoint is not None
205 | model = UnetMRIModel.load_from_checkpoint(str(args.checkpoint))
206 | model.hparams.sample_rate = 1.
207 | trainer = create_trainer(args, logger=False)
208 | trainer.test(model)
209 |
210 |
211 | if __name__ == '__main__':
212 | parser = Args()
213 | parser.add_argument('--mode', choices=['train', 'test'], default='train')
214 | parser.add_argument('--num-epochs', type=int, default=50, help='Number of training epochs')
215 | parser.add_argument('--gpus', type=int, default=1)
216 | parser.add_argument('--exp-dir', type=pathlib.Path, default='experiments',
217 | help='Path where model and results should be saved')
218 | parser.add_argument('--exp', type=str, help='Name of the experiment')
219 | parser.add_argument('--checkpoint', type=pathlib.Path,
220 | help='Path to pre-trained model. Use with --mode test')
221 | parser.add_argument('--resume', action='store_true',
222 | help='If set, resume the training from a previous model checkpoint. ')
223 | parser = UnetMRIModel.add_model_specific_args(parser)
224 | args = parser.parse_args()
225 | random.seed(args.seed)
226 | np.random.seed(args.seed)
227 | torch.manual_seed(args.seed)
228 | main(args)
229 |
--------------------------------------------------------------------------------
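A rough sketch (not from the repository) of what DataTransform.__call__ above produces for a single dummy multi-coil slice. It assumes DataTransform and create_mask_for_mask_type are importable (e.g. with the unet_and_tv directory on the path, as in run_bart.ipynb); the mask settings and shapes are illustrative only.

import numpy as np
from common.subsample import create_mask_for_mask_type

mask_func = create_mask_for_mask_type('random', [0.07], [4])    # center fractions, accelerations
transform = DataTransform(resolution=320, which_challenge='multicoil', mask_func=mask_func)

kspace = np.random.randn(15, 640, 368, 2).astype(np.float32)    # dummy multi-coil k-space
target = np.random.randn(640, 368).astype(np.float32)           # dummy ground-truth image
image, target, mean, std, fname, slice = transform(kspace, target, {}, 'file.h5', 0)
# image: zero-filled, RSS-combined, instance-normalized network input of shape (320, 320)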
/DIP_UNET_models/unet_and_tv/unet_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F
11 |
12 |
13 | class ConvBlock(nn.Module):
14 | """
15 | A Convolutional Block that consists of two convolution layers each followed by
16 | instance normalization, LeakyReLU activation and dropout.
17 | """
18 |
19 | def __init__(self, in_chans, out_chans, drop_prob):
20 | """
21 | Args:
22 | in_chans (int): Number of channels in the input.
23 | out_chans (int): Number of channels in the output.
24 | drop_prob (float): Dropout probability.
25 | """
26 | super().__init__()
27 |
28 | self.in_chans = in_chans
29 | self.out_chans = out_chans
30 | self.drop_prob = drop_prob
31 |
32 | self.layers = nn.Sequential(
33 | nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
34 | nn.InstanceNorm2d(out_chans),
35 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
36 | nn.Dropout2d(drop_prob),
37 | nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
38 | nn.InstanceNorm2d(out_chans),
39 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
40 | nn.Dropout2d(drop_prob)
41 | )
42 |
43 | def forward(self, input):
44 | """
45 | Args:
46 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
47 |
48 | Returns:
49 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
50 | """
51 | return self.layers(input)
52 |
53 | def __repr__(self):
54 | return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \
55 | f'drop_prob={self.drop_prob})'
56 |
57 |
58 | class TransposeConvBlock(nn.Module):
59 | """
60 | A Transpose Convolutional Block that consists of one transpose convolution layer followed by
61 | instance normalization and LeakyReLU activation.
62 | """
63 |
64 | def __init__(self, in_chans, out_chans):
65 | """
66 | Args:
67 | in_chans (int): Number of channels in the input.
68 | out_chans (int): Number of channels in the output.
69 | """
70 | super().__init__()
71 |
72 | self.in_chans = in_chans
73 | self.out_chans = out_chans
74 |
75 | self.layers = nn.Sequential(
76 | nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2, bias=False),
77 | nn.InstanceNorm2d(out_chans),
78 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
79 | )
80 |
81 | def forward(self, input):
82 | """
83 | Args:
84 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
85 |
86 | Returns:
87 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
88 | """
89 | return self.layers(input)
90 |
91 | def __repr__(self):
92 | return f'TransposeConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans})'
93 |
94 |
95 | class UnetModel(nn.Module):
96 | """
97 | PyTorch implementation of a U-Net model.
98 |
99 | This is based on:
100 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks
101 | for biomedical image segmentation. In International Conference on Medical image
102 | computing and computer-assisted intervention, pages 234–241. Springer, 2015.
103 | """
104 |
105 | def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob):
106 | """
107 | Args:
108 | in_chans (int): Number of channels in the input to the U-Net model.
109 | out_chans (int): Number of channels in the output of the U-Net model.
110 | chans (int): Number of output channels of the first convolution layer.
111 | num_pool_layers (int): Number of down-sampling and up-sampling layers.
112 | drop_prob (float): Dropout probability.
113 | """
114 | super().__init__()
115 |
116 | self.in_chans = in_chans
117 | self.out_chans = out_chans
118 | self.chans = chans
119 | self.num_pool_layers = num_pool_layers
120 | self.drop_prob = drop_prob
121 |
122 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)])
123 | ch = chans
124 | for i in range(num_pool_layers - 1):
125 | self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)]
126 | ch *= 2
127 | self.conv = ConvBlock(ch, ch * 2, drop_prob)
128 |
129 | self.up_conv = nn.ModuleList()
130 | self.up_transpose_conv = nn.ModuleList()
131 | for i in range(num_pool_layers - 1):
132 | self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)]
133 | self.up_conv += [ConvBlock(ch * 2, ch, drop_prob)]
134 | ch //= 2
135 |
136 | self.up_transpose_conv += [TransposeConvBlock(ch * 2, ch)]
137 | self.up_conv += [
138 | nn.Sequential(
139 | ConvBlock(ch * 2, ch, drop_prob),
140 | nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
141 | )]
142 |
143 | def forward(self, input):
144 | """
145 | Args:
146 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
147 |
148 | Returns:
149 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
150 | """
151 | stack = []
152 | output = input
153 |
154 | # Apply down-sampling layers
155 | for i, layer in enumerate(self.down_sample_layers):
156 | output = layer(output)
157 | stack.append(output)
158 | output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
159 |
160 | output = self.conv(output)
161 |
162 | # Apply up-sampling layers
163 | for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
164 | downsample_layer = stack.pop()
165 | output = transpose_conv(output)
166 |
167 | # Reflect pad on the right/bottom if needed to handle odd input dimensions.
168 | padding = [0, 0, 0, 0]
169 | if output.shape[-1] != downsample_layer.shape[-1]:
170 | padding[1] = 1 # Padding right
171 | if output.shape[-2] != downsample_layer.shape[-2]:
172 | padding[3] = 1 # Padding bottom
173 | if sum(padding) != 0:
174 | output = F.pad(output, padding, "reflect")
175 |
176 | output = torch.cat([output, downsample_layer], dim=1)
177 | output = conv(output)
178 |
179 | return output
180 |
--------------------------------------------------------------------------------
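A quick sanity-check sketch (not from the repository) that instantiates UnetModel with the default hyperparameters from train_unet.py and confirms the output shape matches the input; the 320x320 input size is arbitrary.

import torch
from unet_model import UnetModel

model = UnetModel(in_chans=1, out_chans=1, chans=32, num_pool_layers=4, drop_prob=0.0)
x = torch.randn(1, 1, 320, 320)   # (batch, channels, height, width)
with torch.no_grad():
    y = model(x)
print(y.shape)                    # expected: torch.Size([1, 1, 320, 320])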
/DIP_UNET_models/unet_and_tv/uuu.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MLI-lab/ConvDecoder/db3a13cbaffe436bc07870b93c6a1d7b47b44f85/DIP_UNET_models/unet_and_tv/uuu.zip
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConvDecoder
2 |
3 | Check out our Colab demo for a quick example of how the decoder works for multi-coil accelerated MRI reconstruction:
4 |
5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1xu_NS6ClikkOM1TTPL7EDqOjQZCvCvlL#offline=true&sandboxMode=true)
6 |
7 |
8 |
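For readers who want the gist before opening the notebook, here is a minimal, hedged sketch of the idea: an untrained convolutional decoder is fit to undersampled k-space measurements by optimizing only the network weights against a data-consistency loss. This is not the repository's exact code; the layer sizes, variable names, and the toy single-coil `mask`/`kspace` tensors are illustrative assumptions.

```python
# Illustrative sketch: fit an untrained convolutional decoder to masked k-space.
import torch
import torch.nn as nn

def conv_decoder(out_channels=2, channels=64, layers=5, out_size=(320, 320)):
    """Stack of bilinear-upsampling + 3x3 conv + ReLU + BatchNorm blocks."""
    blocks = []
    for _ in range(layers):
        blocks += [
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(channels),
        ]
    blocks += [
        nn.Upsample(size=out_size, mode="bilinear", align_corners=False),
        nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*blocks)

# Toy measurement: randomly undersampled 2D k-space of a 320x320 image
# (2 output channels hold the real and imaginary parts of the image).
mask = (torch.rand(320, 320) < 0.25).float()
kspace = torch.randn(320, 320, dtype=torch.complex64) * mask

net = conv_decoder()
z = torch.randn(1, 64, 10, 10)              # fixed random input, never updated
opt = torch.optim.Adam(net.parameters(), lr=1e-2)

for step in range(500):
    opt.zero_grad()
    out = net(z)                            # (1, 2, 320, 320)
    img = torch.complex(out[0, 0], out[0, 1])
    pred_kspace = torch.fft.fft2(img) * mask        # forward model: FFT + mask
    loss = (pred_kspace - kspace).abs().pow(2).mean()  # data-consistency loss
    loss.backward()
    opt.step()
```

After fitting, `net(z)` yields the reconstructed image; no training data is used, which is the point of the untrained-network approach demonstrated in the notebooks above.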