├── .gitignore ├── LICENSE.md ├── README.md ├── complex_layers ├── __init__.py ├── cmplx_activation.py ├── cmplx_conv.py ├── cmplx_dropout.py ├── cmplx_fc.py ├── cmplx_upsample.py └── radial_bn.py ├── complex_net ├── __init__.py ├── cmplx_blocks.py └── cmplx_unet.py ├── configs.py ├── main.py └── utils ├── __init__.py ├── dataset.py ├── loss.py └── polar_transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # PyBuilder 58 | target/ 59 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 7 | rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit 8 | persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the 11 | Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 14 | WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 15 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 16 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | PyTorch implementation of complex convolutional network for Magnetic Reasonance Imaging (MRI) reconstruction 3 | 4 | 5 | If you find this code useful, please cite the following paper: 6 | 7 | @article{el2020deep, 8 | title={Deep complex convolutional network for fast reconstruction of 3D late gadolinium enhancement cardiac MRI}, 9 | author={El-Rewaidy, Hossam and Neisius, Ulf and Mancio, Jennifer and Kucukseymen, Selcuk and Rodriguez, Jennifer and Paskavitz, Amanda and Menze, Bjoern and Nezafat, Reza}, 10 | journal={NMR in Biomedicine}, 11 | volume={33}, 12 | number={7}, 13 | pages={e4312}, 14 | year={2020}, 15 | publisher={Wiley Online Library} 16 | } 17 | -------------------------------------------------------------------------------- /complex_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cmplx_conv import * 2 | from .cmplx_dropout import * 3 | from .cmplx_fc import * 4 | from .cmplx_activation import * 5 | from .cmplx_upsample import * 6 | from .radial_bn import * 7 | -------------------------------------------------------------------------------- /complex_layers/cmplx_activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | from utils.polar_transforms import ( 5 | convert_polar_to_cylindrical, 6 | convert_cylindrical_to_polar 7 | ) 8 | 9 | 10 | class CReLU(nn.ReLU): 11 | def __init__(self, inplace: bool=False): 12 | super(CReLU, self).__init__(inplace) 13 | 14 | 15 | class ModReLU(nn.Module): 16 | def __init__(self, in_channels, inplace=True): 17 | """ModReLU 18 | 19 | Parameters 20 | ---------- 21 | in_channels : int 22 | The number of input channels. 
23 | inplace : bool 24 | If True, the input is modified. 25 | """ 26 | super(ModReLU, self).__init__() 27 | self.inplace = inplace 28 | self.in_channels = in_channels 29 | self.b = Parameter(torch.Tensor(in_channels), requires_grad=True) 30 | self.reset_parameters() 31 | self.relu = nn.ReLU(self.inplace) 32 | 33 | def reset_parameters(self): 34 | self.b.data.uniform_(-0.1, 0.1) 35 | 36 | def forward(self, input): 37 | real, imag = torch.unbind(input, -1) 38 | mag, phase = convert_cylindrical_to_polar(real, imag) 39 | brdcst_b = torch.swapaxes(torch.broadcast_to(self.b, mag.shape), -1, 1) 40 | mag = self.relu(mag + brdcst_b) 41 | real, imag = convert_polar_to_cylindrical(mag, phase) 42 | output = torch.stack((real, imag), dim=-1) 43 | return output 44 | 45 | 46 | class ZReLU(nn.Module): 47 | def __init__(self): 48 | super(ZReLU, self).__init__() 49 | 50 | def forward(self, input): 51 | real, imag = torch.unbind(input, dim=-1) 52 | mag, phase = convert_cylindrical_to_polar(real, imag) 53 | 54 | phase = torch.stack([phase, phase], dim=-1) 55 | output = torch.where(phase >= 0.0, input, torch.tensor(0.0).to(input.device)) 56 | output = torch.where(phase <= np.pi / 2, output, torch.tensor(0.0).to(input.device)) 57 | 58 | return output 59 | -------------------------------------------------------------------------------- /complex_layers/cmplx_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ComplexConv(nn.Module): 6 | def __init__( 7 | self, 8 | rank, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | stride=1, 13 | padding=0, 14 | output_padding=0, 15 | dilation=1, 16 | groups=1, 17 | bias=True, 18 | conv_transposed=False 19 | ): 20 | super(ComplexConv, self).__init__() 21 | self.rank = rank 22 | self.in_channels = in_channels 23 | self.out_channels = out_channels 24 | self.kernel_size = kernel_size 25 | self.stride = stride 26 | self.padding = padding 27 | 
self.output_padding = output_padding 28 | self.dilation = dilation 29 | self.groups = groups 30 | self.bias = bias 31 | self.conv_transposed = conv_transposed 32 | 33 | self.conv_args = { 34 | "in_channels": self.in_channels, 35 | "out_channels": self.out_channels, 36 | "kernel_size": self.kernel_size, 37 | "stride": self.stride, 38 | "padding": self.padding, 39 | "groups": self.groups, 40 | "bias": self.bias 41 | } 42 | 43 | if self.conv_transposed: 44 | self.conv_args["output_padding"] = self.output_padding 45 | else: 46 | self.conv_args["dilation"] = self.dilation 47 | 48 | self.conv_func = {1: nn.Conv1d if not self.conv_transposed else nn.ConvTranspose1d, 49 | 2: nn.Conv2d if not self.conv_transposed else nn.ConvTranspose2d, 50 | 3: nn.Conv3d if not self.conv_transposed else nn.ConvTranspose3d}[self.rank] 51 | 52 | self.real_conv = self.conv_func(**self.conv_args) 53 | self.imag_conv = self.conv_func(**self.conv_args) 54 | 55 | def forward(self, input): 56 | ''' 57 | Considering a complex-valued input z = x + iy to be convolved by complex-valued filter h = a + ib 58 | where Output O = z * h, where * is a complex convolution operator, then O = x*a + i(x*b)+ i(y*a) - y*b 59 | so we need to calculate each of the 4 convolution operations in the previous equation, 60 | one simple way to implement this as two conolution layers, one layer for the real weights (a) 61 | and the other for imaginary weights (b), this can be done by concatenating both real and imaginary 62 | parts of the input and convolve over both of them as follows: 63 | c_r = [x; y] * a , and c_i= [x; y] * b, so that 64 | O_real = c_r[real] - c_i[real], and O_imag = c_r[imag] - c_i[imag] 65 | ''' 66 | input_real, input_imag = torch.unbind(input, dim=-1) 67 | 68 | output_real = self.real_conv(input_real) - self.imag_conv(input_imag) 69 | output_imag = self.real_conv(input_imag) + self.imag_conv(input_real) 70 | 71 | output = torch.stack([output_real, output_imag], dim=-1) 72 | return output 73 | 74 | 75 
| class ComplexConv1d(ComplexConv): 76 | """Applies a 1D Complex convolution over an input signal composed of several input 77 | planes. 78 | 79 | Args: 80 | in_channels (int): Number of channels in the input image 81 | out_channels (int): Number of channels produced by the convolution 82 | kernel_size (int or tuple): Size of the convolving kernel 83 | stride (int or tuple, optional): Stride of the convolution. Default: 1 84 | padding (int or tuple, optional): Zero-padding added to both sides of 85 | the input. Default: 0 86 | dilation (int or tuple, optional): Spacing between kernel 87 | elements. Default: 1 88 | groups (int, optional): Number of blocked connections from input 89 | channels to output channels. Default: 1 90 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 91 | Shape: 92 | - Input: :math:`(N, C_{in}, L_{in}, 2)` 93 | - Output: :math:`(N, C_{out}, L_{out}, 2)` 94 | Attributes: 95 | weight (Tensor): the learnable weights of the module of shape 96 | (out_channels, in_channels, kernel_size, 2) 97 | bias (Tensor): the learnable bias of the module of shape 98 | (out_channels, 2) 99 | """ 100 | 101 | def __init__(self, in_channels, 102 | out_channels, 103 | kernel_size, 104 | stride=1, 105 | padding=0, 106 | dilation=1, 107 | groups=1, 108 | bias=True): 109 | super(ComplexConv1d, self).__init__( 110 | rank=1, 111 | in_channels=in_channels, 112 | out_channels=out_channels, 113 | kernel_size=kernel_size, 114 | stride=stride, 115 | padding=padding, 116 | dilation=dilation, 117 | groups=groups, 118 | bias=bias 119 | ) 120 | 121 | 122 | class ComplexConv2d(ComplexConv): 123 | """Applies a 2D Complex convolution over an input signal composed of several input 124 | planes. 
125 | 126 | Args: 127 | in_channels (int): Number of channels in the input image 128 | out_channels (int): Number of channels produced by the convolution 129 | kernel_size (int or tuple): Size of the convolving kernel 130 | stride (int or tuple, optional): Stride of the convolution. Default: 1 131 | padding (int or tuple, optional): Zero-padding added to both sides of 132 | the input. Default: 0 133 | dilation (int or tuple, optional): Spacing between kernel 134 | elements. Default: 1 135 | groups (int, optional): Number of blocked connections from input 136 | channels to output channels. Default: 1 137 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 138 | Shape: 139 | - Input: :math:`(N, C_{in}, L_{in}, 2)` 140 | - Output: :math:`(N, C_{out}, L_{out}, 2)` 141 | Attributes: 142 | weight (Tensor): the learnable weights of the module of shape 143 | (out_channels, in_channels, kernel_size, 2) 144 | bias (Tensor): the learnable bias of the module of shape 145 | (out_channels, 2) 146 | """ 147 | 148 | def __init__(self, in_channels, 149 | out_channels, 150 | kernel_size, 151 | stride=(1, 1), 152 | padding=(0, 0), 153 | dilation=(1, 1), 154 | groups=1, 155 | bias=True): 156 | super(ComplexConv2d, self).__init__( 157 | rank=2, 158 | in_channels=in_channels, 159 | out_channels=out_channels, 160 | kernel_size=kernel_size, 161 | stride=stride, 162 | padding=padding, 163 | dilation=dilation, 164 | groups=groups, 165 | bias=bias 166 | ) 167 | 168 | 169 | class ComplexConv3d(ComplexConv): 170 | """Applies a 3D complex convolution over an input signal composed of several input 171 | planes. 172 | Args: 173 | in_channels (int): Number of channels in the input image 174 | out_channels (int): Number of channels produced by the convolution 175 | kernel_size (int or tuple): Size of the convolving kernel 176 | stride (int or tuple, optional): Stride of the convolution. 
Default: 1 177 | padding (int or tuple, optional): Zero-padding added to both sides of 178 | the input. Default: 0 179 | dilation (int or tuple, optional): Spacing between kernel 180 | elements. Default: 1 181 | groups (int, optional): Number of blocked connections from input 182 | channels to output channels. Default: 1 183 | bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` 184 | Shape: 185 | - Input: :math:`(N, C_{in}, L_{in}, 2)` 186 | - Output: :math:`(N, C_{out}, L_{out}, 2)` 187 | Attributes: 188 | weight (Tensor): the learnable weights of the module of shape 189 | (out_channels, in_channels, kernel_size, 2) 190 | bias (Tensor): the learnable bias of the module of shape 191 | (out_channels, 2) 192 | """ 193 | 194 | def __init__(self, in_channels, 195 | out_channels, 196 | kernel_size, 197 | stride=(1, 1, 1), 198 | padding=(0, 0, 0), 199 | dilation=(1, 1, 1), 200 | groups=1, 201 | bias=True): 202 | super(ComplexConv3d, self).__init__( 203 | rank=3, 204 | in_channels=in_channels, 205 | out_channels=out_channels, 206 | kernel_size=kernel_size, 207 | stride=stride, 208 | padding=padding, 209 | dilation=dilation, 210 | groups=groups, 211 | bias=bias 212 | ) 213 | -------------------------------------------------------------------------------- /complex_layers/cmplx_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ComplexDropout(nn.Module): 6 | 7 | def __init__(self, rank, p=0.5, inplace=True): 8 | super(ComplexDropout, self).__init__() 9 | if p < 0 or p > 1: 10 | raise ValueError("dropout probability has to be between 0 and 1, " 11 | "but got {}".format(p)) 12 | self.rank = rank 13 | self.p = p 14 | self.inplace = inplace 15 | 16 | def extra_repr(self): 17 | inplace_str = ', inplace' if self.inplace else '' 18 | return 'p={}{}'.format(self.p, inplace_str) 19 | 20 | def forward(self, input): 21 | if not self.training or 
self.p == 0: 22 | return input 23 | 24 | if self.p == 1: 25 | return torch.FloatTensor(input.shape).to(input.device).zero_() 26 | 27 | msk = torch.FloatTensor(input.shape[:-1]).to(input.device).uniform_() > self.p 28 | msk = torch.stack([msk, msk], dim=-1) 29 | 30 | output = input * msk.to(torch.float32) 31 | 32 | return output 33 | 34 | 35 | class ComplexDropout1d(ComplexDropout): 36 | r"""Randomly zeroes whole channels of the complex input tensor. 37 | The channels to zero are randomized on every forward call. 38 | Usually the input comes from :class:`nn.Conv3d` modules. 39 | Args: 40 | p (float, optional): probability of an element to be zeroed. 41 | inplace (bool, optional): If set to ``True``, will do this operation 42 | in-place 43 | Shape: 44 | - Input: :math:`(N, C, D, H, W, 2)` 45 | - Output: :math:`(N, C, D, H, W, 2)` (same shape as input) 46 | """ 47 | def __init__(self, p=0.5, inplace=False): 48 | super(ComplexDropout1d, self).__init__( 49 | rank=1, 50 | p=p, 51 | inplace=inplace 52 | ) 53 | 54 | 55 | class ComplexDropout2d(ComplexDropout): 56 | r"""Randomly zeroes whole channels of the complex input tensor. 57 | The channels to zero-out are randomized on every forward call. 58 | Usually the input comes from :class:`nn.Conv2d` modules. 59 | Args: 60 | p (float, optional): probability of an element to be zero-ed. 61 | inplace (bool, optional): If set to ``True``, will do this operation 62 | in-place 63 | Shape: 64 | - Input: :math:`(N, C, H, W, 2)` 65 | - Output: :math:`(N, C, H, W, 2)` (same shape as input) 66 | 67 | """ 68 | def __init__(self, p=0.5, inplace=False): 69 | super(ComplexDropout2d, self).__init__( 70 | rank=2, 71 | p=p, 72 | inplace=inplace 73 | ) 74 | 75 | 76 | class ComplexDropout3d(ComplexDropout): 77 | r"""Randomly zeroes whole channels of the complex input tensor. 78 | The channels to zero are randomized on every forward call. 79 | Usually the input comes from :class:`nn.Conv3d` modules. 
80 | Args: 81 | p (float, optional): probability of an element to be zeroed. 82 | inplace (bool, optional): If set to ``True``, will do this operation 83 | in-place 84 | Shape: 85 | - Input: :math:`(N, C, D, H, W, 2)` 86 | - Output: :math:`(N, C, D, H, W, 2)` (same shape as input) 87 | 88 | """ 89 | def __init__(self, p=0.5, inplace=False): 90 | super(ComplexDropout3d, self).__init__( 91 | rank=3, 92 | p=p, 93 | inplace=inplace 94 | ) 95 | 96 | 97 | if __name__ == '__main__': 98 | x = torch.rand((2, 2, 8, 8, 2)) 99 | print(f"non-zero elements: {len(x.nonzero())}") 100 | dropout = ComplexDropout2d(p=0.5) 101 | y = dropout(x) 102 | print(f"non-zero elements: {len(y.nonzero())}") 103 | -------------------------------------------------------------------------------- /complex_layers/cmplx_fc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ComplexLinear(nn.Module): 6 | def __init__(self, in_features, out_features, bias=True): 7 | super(ComplexLinear, self).__init__() 8 | self.real_linear = nn.Linear(in_features=in_features, out_features=out_features, bias=bias) 9 | self.imag_linear = nn.Linear(in_features=in_features, out_features=out_features, bias=bias) 10 | 11 | def forward(self, input): 12 | real, imag = torch.unbind(input, dim=-1) 13 | 14 | real_out = self.real_linear(real) - self.imag_linear(imag) 15 | imag_out = self.real_linear(imag) + self.imag_linear(real) 16 | 17 | output = torch.stack((real_out, imag_out), dim=-1) 18 | 19 | return output 20 | -------------------------------------------------------------------------------- /complex_layers/cmplx_upsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ComplexUpsample(nn.Module): 6 | def __init__( 7 | self, 8 | size=None, 9 | scale_factor=None, 10 | mode='nearest', 11 | align_corners=False, 12 | 
recompute_scale_factor=False, 13 | ): 14 | """Upsample layer for complex inputs. 15 | 16 | Parameters 17 | ---------- 18 | size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional): 19 | output spatial sizes 20 | scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional): 21 | multiplier for spatial size. Has to match input size if it is a tuple. 22 | mode (str, optional): the upsampling algorithm: one of ``'nearest'``, 23 | ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. 24 | Default: ``'nearest'`` 25 | align_corners (bool, optional): if ``True``, the corner pixels of the input 26 | and output tensors are aligned, and thus preserving the values at 27 | those pixels. This only has effect when :attr:`mode` is 28 | ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``. 29 | Default: ``False`` 30 | recompute_scale_factor (bool, optional): recompute the scale_factor for use in the 31 | interpolation calculation. If `recompute_scale_factor` is ``True``, then 32 | `scale_factor` must be passed in and `scale_factor` is used to compute the 33 | output `size`. The computed output `size` will be used to infer new scales for 34 | the interpolation. Note that when `scale_factor` is floating-point, it may differ 35 | from the recomputed `scale_factor` due to rounding and precision issues. 36 | If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will 37 | be used directly for interpolation. 
38 | """ 39 | super(ComplexUpsample, self).__init__() 40 | self.upsample = nn.Upsample( 41 | size=size, 42 | scale_factor=scale_factor, 43 | mode=mode, 44 | align_corners=align_corners, 45 | recompute_scale_factor=recompute_scale_factor 46 | ) 47 | 48 | def forward(self, input): 49 | real, imag = torch.unbind(input, dim=-1) 50 | output = torch.stack((self.upsample(real), self.upsample(imag)), dim=-1) 51 | return output 52 | -------------------------------------------------------------------------------- /complex_layers/radial_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.polar_transforms import ( 4 | convert_cylindrical_to_polar, 5 | convert_polar_to_cylindrical, 6 | ) 7 | 8 | 9 | class RadialNorm(nn.Module): 10 | 11 | def __init__( 12 | self, 13 | rank, 14 | num_features, 15 | t=5, 16 | eps=1e-5, 17 | momentum=0.1, 18 | affine=True, 19 | track_running_stats=True, 20 | polar=False 21 | ): 22 | """Radial Batch Normalization 23 | 24 | Parameters 25 | ---------- 26 | rank : int 27 | The spatial dimension of the input tensor. 28 | num_features : int 29 | The number of features of the input tensor. 30 | t : float 31 | The threshold for the normalization. 32 | eps : float 33 | The epsilon for the normalization. 34 | momentum : float 35 | The momentum for the normalization. 36 | affine : bool 37 | If True, this module has learnable affine parameters. 38 | track_running_stats : bool 39 | If True, this module tracks the running mean and variance, 40 | and during training time uses the running mean and variance to normalize. 41 | During testing time, this module uses the mean and variance of the 42 | input statistics to normalize. 43 | polar : bool 44 | If True, the input is in the polar form (magnitude and phase). 45 | If False, the input is in the cylindrical form (real and imag). 
46 | """ 47 | super(RadialNorm, self).__init__() 48 | self.rank = rank 49 | self.num_features = num_features 50 | self.t = t 51 | self.eps = eps 52 | self.momentum = momentum 53 | self.affine = affine 54 | self.track_running_stats = track_running_stats 55 | self.polar = polar 56 | 57 | bns = { 58 | 1: nn.BatchNorm1d, 59 | 2: nn.BatchNorm2d, 60 | 3: nn.BatchNorm3d 61 | } 62 | self.bn_func = bns[self.rank]( 63 | num_features=num_features, 64 | eps=eps, 65 | momentum=momentum, 66 | affine=affine, 67 | track_running_stats=track_running_stats 68 | ) 69 | 70 | def forward(self, input): 71 | real, imag = torch.unbind(input, -1) 72 | mag, phase = convert_cylindrical_to_polar(real, imag) if not self.polar \ 73 | else (real, imag) 74 | 75 | # normalize the magnitude (see paper: El-Rewaidy et al. "Deep complex 76 | # convolutional network for fast reconstruction of 3D late gadolinium 77 | # enhancement cardiac MRI", NMR in Biomedicne, 2020) 78 | # Normalize the radius to be around self.t (i.e. 5 std) (1 also works fine) 79 | norm_mag = self.bn_func(mag) + self.t 80 | 81 | output_real, output_imag = convert_polar_to_cylindrical(norm_mag, phase) \ 82 | if not self.polar else (norm_mag, phase) 83 | output = torch.stack((output_real, output_imag), dim=-1) 84 | 85 | return output 86 | 87 | 88 | class RadialBatchNorm1d(RadialNorm): 89 | r"""Applies Radial Batch Normalization over a 2D and 3D input (a mini-batch of 1D 90 | complex inputs with optional additional channel dimension) 91 | 92 | Parameters 93 | ---------- 94 | num_features: :math:`C` from an expected input of size 95 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 96 | eps: a value added to the denominator for numerical stability. 97 | Default: 1e-5 98 | momentum: the value used for the running_mean and running_var 99 | computation. Can be set to ``None`` for cumulative moving average 100 | (i.e. simple average). 
Default: 0.1 101 | affine: a boolean value that when set to ``True``, this module has 102 | learnable affine parameters. Default: ``True`` 103 | track_running_stats: a boolean value that when set to ``True``, this 104 | module tracks the running mean and variance, and when set to ``False``, 105 | this module does not track such statistics and always uses batch 106 | statistics in both training and eval modes. Default: ``True`` 107 | 108 | Shape: 109 | - Input: :math:`(N, C)` or :math:`(N, C, L, 2)` 110 | - Output: :math:`(N, C)` or :math:`(N, C, L, 2)` (same shape as input) 111 | """ 112 | def __init__(self, 113 | num_features, 114 | t=5, 115 | eps=1e-5, 116 | momentum=0.1, 117 | affine=True, 118 | track_running_stats=True, 119 | polar=False): 120 | super(RadialBatchNorm1d, self).__init__( 121 | rank=1, 122 | num_features=num_features, 123 | t=t, 124 | eps=eps, 125 | momentum=momentum, 126 | affine=affine, 127 | track_running_stats=track_running_stats, 128 | polar=polar 129 | ) 130 | 131 | 132 | class RadialBatchNorm2d(RadialNorm): 133 | r"""Applies Radial Batch Normalization over a 5D complex input (a mini-batch of 2D 134 | complex inputs with optional additional channel dimension) 135 | 136 | Parameters 137 | ---------- 138 | num_features: :math:`C` from an expected input of size 139 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 140 | eps: a value added to the denominator for numerical stability. 141 | Default: 1e-5 142 | momentum: the value used for the running_mean and running_var 143 | computation. Can be set to ``None`` for cumulative moving average 144 | (i.e. simple average). Default: 0.1 145 | affine: a boolean value that when set to ``True``, this module has 146 | learnable affine parameters. 
Default: ``True`` 147 | track_running_stats: a boolean value that when set to ``True``, this 148 | module tracks the running mean and variance, and when set to ``False``, 149 | this module does not track such statistics and always uses batch 150 | statistics in both training and eval modes. Default: ``True`` 151 | 152 | Shape: 153 | - Input: :math:`(N, C, H, W, 2)` 154 | - Output: :math:`(N, C, H, W, 2)` (same shape as input) 155 | """ 156 | 157 | def __init__(self, 158 | num_features, 159 | t=5, 160 | eps=1e-5, 161 | momentum=0.1, 162 | affine=True, 163 | track_running_stats=True, 164 | polar=False): 165 | super(RadialBatchNorm2d, self).__init__( 166 | rank=2, 167 | num_features=num_features, 168 | t=t, 169 | eps=eps, 170 | momentum=momentum, 171 | affine=affine, 172 | track_running_stats=track_running_stats, 173 | polar=polar 174 | ) 175 | 176 | 177 | class RadialBatchNorm3d(RadialNorm): 178 | r"""Applies Radial Batch Normalization over a 6D complex input (a mini-batch of 3D 179 | complex inputs with optional additional channel dimension) 180 | 181 | Parameters 182 | ---------- 183 | num_features: :math:`C` from an expected input of size 184 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 185 | eps: a value added to the denominator for numerical stability. 186 | Default: 1e-5 187 | momentum: the value used for the running_mean and running_var 188 | computation. Can be set to ``None`` for cumulative moving average 189 | (i.e. simple average). Default: 0.1 190 | affine: a boolean value that when set to ``True``, this module has 191 | learnable affine parameters. Default: ``True`` 192 | track_running_stats: a boolean value that when set to ``True``, this 193 | module tracks the running mean and variance, and when set to ``False``, 194 | this module does not track such statistics and always uses batch 195 | statistics in both training and eval modes. 
Default: ``True`` 196 | 197 | Shape: 198 | - Input: :math:`(N, C, D, H, W, 2)` 199 | - Output: :math:`(N, C, D, H, W, 2)` (same shape as input) 200 | """ 201 | 202 | def __init__(self, 203 | num_features, 204 | t=5, 205 | eps=1e-5, 206 | momentum=0.1, 207 | affine=True, 208 | track_running_stats=True, 209 | polar=False): 210 | super(RadialBatchNorm3d, self).__init__( 211 | rank=3, 212 | num_features=num_features, 213 | t=t, 214 | eps=eps, 215 | momentum=momentum, 216 | affine=affine, 217 | track_running_stats=track_running_stats, 218 | polar=polar 219 | ) 220 | -------------------------------------------------------------------------------- /complex_net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/helrewaidy/deep-complex-convolutional-network/380930aef0dde3629a2ea27d37227e819e08a447/complex_net/__init__.py -------------------------------------------------------------------------------- /complex_net/cmplx_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from complex_layers.cmplx_conv import ComplexConv1d, ComplexConv2d, ComplexConv3d 5 | from complex_layers.cmplx_activation import CReLU, ModReLU, ZReLU 6 | from complex_layers.cmplx_upsample import ComplexUpsample 7 | from complex_layers.radial_bn import RadialBatchNorm1d, RadialBatchNorm2d, RadialBatchNorm3d 8 | from complex_layers.cmplx_dropout import ComplexDropout 9 | from configs import config 10 | 11 | 12 | def complex_conv(in_ch, out_ch, **kwargs): 13 | conv = { 14 | 1: ComplexConv1d, 15 | 2: ComplexConv2d, 16 | 3: ComplexConv3d 17 | }[config.spatial_dimentions] 18 | if 'kernel_size' not in kwargs: 19 | kwargs['kernel_size'] = config.kernel_size 20 | return conv( 21 | in_ch, 22 | out_ch, 23 | bias=config.bias, 24 | **kwargs 25 | ) 26 | 27 | 28 | def activation(in_channels=None, **kwargs): 29 | if 
config.activation == 'modReLU': 30 | kwargs['in_channels'] = in_channels 31 | return { 32 | 'CReLU': CReLU, 33 | 'modReLU': ModReLU, 34 | 'ZReLU': ZReLU 35 | }[config.activation](**kwargs) 36 | 37 | 38 | def batch_norm(in_channels=None, **kwargs): 39 | bn = { 40 | 1: RadialBatchNorm1d, 41 | 2: RadialBatchNorm2d, 42 | 3: RadialBatchNorm3d 43 | }[config.spatial_dimentions] 44 | return bn(in_channels, t=config.bn_t, **kwargs) 45 | 46 | 47 | class DoubleConv(nn.Module): 48 | def __init__(self, in_ch, out_ch): 49 | super(DoubleConv, self).__init__() 50 | self.conv = nn.Sequential( 51 | complex_conv(in_ch, out_ch, padding=1), 52 | batch_norm(out_ch), 53 | activation(out_ch), 54 | ComplexDropout(config.dropout_ratio), 55 | complex_conv(out_ch, out_ch, padding=1), 56 | batch_norm(out_ch), 57 | activation(out_ch), 58 | ComplexDropout(config.dropout_ratio) 59 | ) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | return x 64 | 65 | 66 | class DownConv(nn.Module): 67 | def __init__(self, in_ch): 68 | super(DownConv, self).__init__() 69 | self.conv = nn.Sequential( 70 | complex_conv(in_ch, in_ch, stride=2, padding=1), 71 | batch_norm(in_ch), 72 | activation(in_ch) 73 | ) 74 | 75 | def forward(self, x): 76 | x = self.conv(x) 77 | return x 78 | 79 | 80 | class InConv(nn.Module): 81 | def __init__(self, in_ch, out_ch): 82 | super(InConv, self).__init__() 83 | self.conv = DoubleConv(in_ch, out_ch) 84 | 85 | def forward(self, x): 86 | x = self.conv(x) 87 | return x 88 | 89 | 90 | class Down(nn.Module): 91 | def __init__(self, in_ch, out_ch): 92 | super(Down, self).__init__() 93 | self.down_conv = DownConv(in_ch) 94 | self.double_conv = DoubleConv(in_ch, out_ch) 95 | 96 | def forward(self, x): 97 | down_x = self.down_conv(x) 98 | x = self.double_conv(down_x) 99 | return x, down_x 100 | 101 | 102 | class BottleNeck(nn.Module): 103 | def __init__(self, in_ch, out_ch, residual_connection=True): 104 | super(BottleNeck, self).__init__() 105 | self.residual_connection = 
residual_connection 106 | self.down_conv = DownConv(in_ch) 107 | self.double_conv = nn.Sequential( 108 | complex_conv(in_ch, 2 * in_ch, padding=1), 109 | batch_norm(2 * in_ch), 110 | activation(2 * in_ch), 111 | ComplexDropout(config.dropout_ratio), 112 | complex_conv(2 * in_ch, out_ch, padding=1), 113 | batch_norm(out_ch), 114 | activation(out_ch), 115 | ComplexDropout(config.dropout_ratio) 116 | ) 117 | 118 | def forward(self, x): 119 | down_x = self.down_conv(x) 120 | if self.residual_connection: 121 | x = self.double_conv(down_x) + down_x 122 | else: 123 | x = self.double_conv(down_x) 124 | 125 | return x 126 | 127 | 128 | class Up(nn.Module): 129 | def __init__(self, in_ch, out_ch): 130 | super(Up, self).__init__() 131 | 132 | self.up = ComplexUpsample(scale_factor=2, mode='bilinear') 133 | 134 | self.conv = nn.Sequential( 135 | complex_conv(in_ch * 2, in_ch, padding=1), 136 | batch_norm(in_ch), 137 | activation(in_ch), 138 | ComplexDropout(config.dropout_ratio), 139 | complex_conv(in_ch, out_ch, padding=1), 140 | batch_norm(out_ch), 141 | activation(out_ch), 142 | ComplexDropout(config.dropout_ratio) 143 | ) 144 | 145 | def forward(self, x1, x2): 146 | x1 = self.up(x1) 147 | diffX = x1.size()[2] - x2.size()[2] 148 | diffY = x1.size()[3] - x2.size()[3] 149 | x2 = F.pad(x2, (diffX // 2, int(diffX / 2), 150 | diffY // 2, int(diffY / 2))) 151 | x = torch.cat([x2, x1], dim=1) 152 | x = self.conv(x) 153 | return x 154 | 155 | class OutConv(nn.Module): 156 | def __init__(self, in_ch, out_ch): 157 | super(OutConv, self).__init__() 158 | self.conv = complex_conv(in_ch, out_ch, kernel_size=1) 159 | 160 | def forward(self, x): 161 | x = self.conv(x) 162 | return x 163 | -------------------------------------------------------------------------------- /complex_net/cmplx_unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import complex_net.cmplx_blocks as unet_cmplx 3 | from configs import config 4 | 5 | 
class CUNet(nn.Module):
    """Complex-valued U-Net: two Down stages, a bottleneck and three Up stages.

    Parameters
    ----------
    in_channels : int
        Number of channels of the network input.
    out_channels : int
        Number of channels produced by the final ``OutConv``.
    """

    def __init__(self, in_channels, out_channels):
        super(CUNet, self).__init__()
        self.inc = unet_cmplx.InConv(in_channels, 64)
        self.down1 = unet_cmplx.Down(64, 128)
        self.down2 = unet_cmplx.Down(128, 256)
        self.bottleneck = unet_cmplx.BottleNeck(256, 256, False)
        self.up2 = unet_cmplx.Up(256, 128)
        self.up3 = unet_cmplx.Up(128, 64)
        self.up4 = unet_cmplx.Up(64, 64)
        self.ouc = unet_cmplx.OutConv(64, out_channels)

    def forward(self, x):
        x0 = x
        x1 = self.inc(x)
        # The result of each Down stage is unpacked into two values; only the
        # first one is used as the next input / skip connection.
        x2, _ = self.down1(x1)
        x3, _ = self.down2(x2)
        x4 = self.bottleneck(x3)
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # NOTE(review): the global residual is added to the 64-channel decoder
        # output (before OutConv), so it only broadcasts when in_channels is 1
        # (or 64). Confirm this placement is intended.
        x = x + x0 if config.unet_global_residual_conn else x
        x = self.ouc(x)

        return x

# --------------------------------------------------------------------------------
# /configs.py:
# --------------------------------------------------------------------------------


class Config(object):
    """Global Config class"""

    # Dataset configs
    spatial_dimentions = 2
    # Per-sample shape WITHOUT the channel axis: (x_dim, y_dim, real-imag).
    # get_dataloaders prepends in_channels to it.
    input_shape = (256, 256, 2)
    in_channels = 1
    out_channels = 1
    batch_size = 2
    data_loaders_num_workers = 4

    # Training configs
    learning_rate = 0.01
    models_dir = 'models/'
    workspace_dir = 'workspace/'
    num_epochs = 50
    normalize_input = True

    # Model configs
    unet_global_residual_conn = False
    kernel_size = 3
    bias = False
    activation = 'CReLU'
    activation_params = {
        'inplace': True,
    }
    bn_t = 5
    dropout_ratio = 0.0


config = Config()

# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
from cmath import inf
import shutil
import os

import torch
import torch.nn.modules.loss as Loss
from torch import optim
import numpy as np

from complex_net.cmplx_unet import CUNet
from complex_net.cmplx_blocks import batch_norm
from utils.dataset import get_dataloaders
from utils.loss import SSIM
from configs import config
import logging
from complex_layers.radial_bn import RadialNorm


logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%d-%b-%y %H:%M:%S',
    level=logging.INFO
)


def set_seeds(seed):
    """Set the seeds for reproducibility

    Parameters
    ----------
    seed : int
        The seed to set.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


def get_device():
    """Get device

    Returns
    -------
    device : torch.device
        CUDA if available, otherwise CPU.
    """
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_epoch(net, optimizer, loss_criterion, tr_dataloader, epoch):
    """Train for one epoch of the data

    Parameters
    ----------
    net : torch.nn.Module
        The network to train.
    optimizer : torch.optim.Optimizer
        The optimizer to use.
    loss_criterion : torch.nn.Module
        The loss criterion to use.
    tr_dataloader : torch.utils.data.DataLoader
        The training data loader.
    epoch : int
        The epoch number.

    Returns
    -------
    avg_loss : float
        The average loss for the epoch.
    net : torch.nn.Module
        The trained network.
    optimizer : torch.optim.Optimizer
        The optimizer.
    """
    avg_loss = 0.0
    net.train()
    device = get_device()
    # NOTE(review): a fresh normaliser is built on every call and is never
    # optimised. It is moved to the data's device so that any parameters or
    # buffers it holds do not cause a CPU/GPU mismatch.
    radial_normalizer = batch_norm(
        in_channels=config.in_channels,
    ).to(device)
    for itt, (inputs, targets) in enumerate(tr_dataloader):
        # The dataset yields float64 arrays; the network runs in float32.
        X = inputs.float().to(device)
        y = targets.float().to(device)

        if config.normalize_input:
            X = radial_normalizer(X)
            y = radial_normalizer(y)

        y_pred = net(X)

        loss = loss_criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss += loss.item() / len(tr_dataloader)

        logging.info('Epoch: {0} - Itter: {1}/{2} - loss: {3:.6f}'.format(
            epoch, itt, len(tr_dataloader), loss.item())
        )

    return avg_loss, net, optimizer


def validate(net, loss_criterion, val_dataloader, epoch):
    """Validate the model on the validation set

    Parameters
    ----------
    net : torch.nn.Module
        The network to validate. It is left in eval mode on return.
    loss_criterion : torch.nn.Module
        The loss criterion to use.
    val_dataloader : torch.utils.data.DataLoader
        The validation data loader.
    epoch : int
        The epoch number.

    Returns
    -------
    avg_loss : float
        The average loss for the epoch.
    avg_ssim : float
        The average SSIM (computed on magnitudes) for the epoch.
    """
    avg_loss = 0.0
    avg_ssim = 0.0
    ssim_criterion = SSIM()
    device = get_device()
    radial_normalizer = batch_norm(
        in_channels=config.in_channels,
    ).to(device)
    # Magnitude of a tensor whose last axis holds (real, imag).
    mag = lambda x: (x[..., 0] ** 2 + x[..., 1] ** 2) ** 0.5
    # Switch dropout / normalisation layers to inference behaviour;
    # train_epoch switches back with net.train().
    net.eval()
    with torch.no_grad():
        for itt, (inputs, targets) in enumerate(val_dataloader):
            X = inputs.float().to(device)
            y = targets.float().to(device)

            if config.normalize_input:
                X = radial_normalizer(X)
                y = radial_normalizer(y)

            y_pred = net(X)

            loss = loss_criterion(y_pred, y)
            ssim = ssim_criterion(mag(y_pred), mag(y))

            avg_loss += loss.item() / len(val_dataloader)
            avg_ssim += ssim.item() / len(val_dataloader)

            logging.info('Epoch: {0} - Itter: {1}/{2} - loss: {3:.6f} - SSIM: {4:.6f}'.format(
                epoch, itt, len(val_dataloader), loss.item(), ssim.item())
            )
    return avg_loss, avg_ssim


def train(net, optimizer, loss_criterion, tr_dataloader, val_dataloader):
    """Train the network, validating and checkpointing after every epoch.

    Parameters
    ----------
    net : torch.nn.Module
        The network to train.
    optimizer : torch.optim.Optimizer
        The optimizer to use.
    loss_criterion : torch.nn.Module
        The loss criterion to use.
    tr_dataloader : torch.utils.data.DataLoader
        The training data loader.
    val_dataloader : torch.utils.data.DataLoader
        The validation data loader.
    """
    best_loss = inf
    for epoch in range(config.num_epochs):
        logging.info(f'Training epoch {epoch}/{config.num_epochs}...')

        optimizer = adjust_learning_rate(epoch, optimizer)

        # Training
        avg_tr_loss, net, optimizer = train_epoch(
            net, optimizer, loss_criterion, tr_dataloader, epoch
        )
        logging.info(f'Epoch {epoch} - Avg. training loss: {avg_tr_loss:.3f}')

        # Validation
        avg_vld_loss, avg_vld_ssim = validate(net, loss_criterion, val_dataloader, epoch)
        logging.info(f'Epoch {epoch} - Avg. validation loss: {avg_vld_loss:.3f}, SSIM: {avg_vld_ssim:.3f}')

        # Remember the best validation loss so that only real improvements
        # overwrite the "best" checkpoint.
        is_best = avg_vld_loss < best_loss
        if is_best:
            best_loss = avg_vld_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': 'complexnet',
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            is_best=is_best,
            filename=os.path.join(config.models_dir, 'checkpoint.pth')
        )
        logging.info('Model Saved!')


def save_checkpoint(state, is_best, filename='checkpoint.pth'):
    """Save a checkpoint

    Parameters
    ----------
    state : dict
        The state to save.
    is_best : bool
        Whether this is the best model; if so the checkpoint is also copied
        to ``model_best.pth`` in the same directory as ``filename``.
    filename : str
        The filename to save the checkpoint to.
    """
    torch.save(state, filename)
    if is_best:
        best_path = os.path.join(os.path.dirname(filename), 'model_best.pth')
        shutil.copyfile(filename, best_path)


def adjust_learning_rate(epoch, optimizer):
    """Sets the learning rate to the initial LR decayed by 10 every 20 epochs

    Parameters
    ----------
    epoch : int
        The epoch number.
    optimizer : torch.optim.Optimizer
        The optimizer (its param groups are updated in place).

    Returns
    -------
    optimizer : torch.optim.Optimizer
        The optimizer.
    """
    lr = config.learning_rate * (0.1 ** (epoch // 20))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer


if __name__ == '__main__':
    set_seeds(222)
    tr_dataloader, val_dataloader = get_dataloaders()
    net = CUNet(config.in_channels, config.out_channels)
    if torch.cuda.is_available():
        net = torch.nn.DataParallel(net).cuda()

    optimizer = optim.Adam(net.parameters(), lr=config.learning_rate)
    loss_criterion = Loss.MSELoss()
    os.makedirs(config.models_dir, exist_ok=True)

    train(net, optimizer, loss_criterion, tr_dataloader, val_dataloader)

# --------------------------------------------------------------------------------
# /utils/__init__.py:
# --------------------------------------------------------------------------------
from .dataset import *
from .loss import *
from .polar_transforms import *

# --------------------------------------------------------------------------------
# /utils/dataset.py:
# --------------------------------------------------------------------------------
from typing import Tuple, List
import numpy as np
from torch.utils.data import Dataset, DataLoader
from configs import config


def create_cmplx_dataset(data_shape: Tuple) -> np.ndarray:
    """Create a random array standing in for complex image data.

    Parameters
    ----------
    data_shape : Tuple
        The shape of the data (presumably with real/imag on the last axis).

    Returns
    -------
    np.ndarray
        Array of shape ``data_shape`` with values drawn uniformly from
        [0, 1) by ``np.random.rand``.
    """
    return np.random.rand(*data_shape)


def get_dataloaders() -> Tuple[DataLoader, DataLoader]:
    """Get the dataset loaders.

    Each sample has shape ``(config.in_channels, *config.input_shape)``.

    Returns
    -------
    Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]
        The training (80 samples, shuffled) and validation (20 samples)
        data loaders.
    """
    input_shape = (config.in_channels, *config.input_shape)
    tr_dataset = DataGenerator(input_shape, list(range(80)))
    vld_dataset = DataGenerator(input_shape, list(range(20)))

    tr_data_loader = DataLoader(
        tr_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.data_loaders_num_workers
    )

    vld_data_loader = DataLoader(
        vld_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.data_loaders_num_workers
    )

    return tr_data_loader, vld_data_loader


class DataGenerator(Dataset):
    def __init__(self, input_shape: Tuple, data_indicies: List):
        """Initialize the DataGenerator class.

        Parameters
        ----------
        input_shape : Tuple
            The shape of one sample: channel axis first, real/imag axis
            last, spatial axes in between.
        data_indicies : List
            Sample indices; only their count is used, because every sample
            is generated randomly.
        """
        self.input_shape = input_shape
        self.data_indicies = data_indicies

    def __len__(self):
        return len(self.data_indicies)

    def __getitem__(self, index):
        """Generate one random sample ``X`` and its low-resolution target ``y``.

        ``y`` is ``X`` decimated by ``usr`` along every spatial axis (all axes
        except the first and the last), then brought back to
        ``self.input_shape`` by nearest-neighbour repetition.
        """
        usr = 4
        X = create_cmplx_dataset(self.input_shape)
        y = X.copy()
        for axis in range(1, len(self.input_shape) - 1):
            size = self.input_shape[axis]
            # Keep every usr-th sample along this axis ...
            low_res = np.take(y, np.arange(0, size, usr), axis=axis)
            # ... repeat each kept sample usr times and crop back to the
            # original length (repeat overshoots when size % usr != 0).
            y = np.repeat(low_res, usr, axis=axis)
            y = np.take(y, np.arange(size), axis=axis)
        return X, y

# --------------------------------------------------------------------------------
# /utils/loss.py:
# --------------------------------------------------------------------------------
"""
SSIM is borrowed from https://github.com/Po-Hsun-Su/pytorch-ssim
"""

from math import exp

import torch
import torch.nn.functional as F
from torch.autograd import Variable


def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1."""
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian window (sigma 1.5)."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True, full=False):
    """Compute SSIM between two 4-D image batches with a per-channel window.

    NOTE: the convolutions use zero padding, so the SSIM map only covers
    positions where the whole window fits; ``window_size`` itself is not
    used inside this function.
    """
    padd = 0

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    # Contrast-structure term, returned only when full=True.
    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


class SSIM(torch.nn.Module):
    """SSIM module that caches its Gaussian window between calls."""

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Reuse the cached window only if channel count and tensor type
        # (which includes the device) still match; otherwise rebuild it.
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True, full=False):
    """Functional SSIM; the window is shrunk to fit images smaller than ``window_size``."""
    (_, channel, height, width) = img1.size()

    real_size = min(window_size, height, width)
    window = create_window(real_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, real_size, channel, size_average, full=full)

# --------------------------------------------------------------------------------
# /utils/polar_transforms.py:
# --------------------------------------------------------------------------------
import torch
from typing import Tuple


def convert_polar_to_cylindrical(
    magnitude: torch.Tensor,
    phase: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert the polar representation (i.e. magnitude and phase) to
    cylindrical representation (i.e. real and imaginary)

    Parameters
    ----------
    magnitude : torch.Tensor
        The magnitude of the complex tensor
    phase : torch.Tensor
        The phase of the complex tensor

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The real and imaginary part of the complex tensor
    """
    real = magnitude * torch.cos(phase)
    imag = magnitude * torch.sin(phase)
    return real, imag


def convert_cylindrical_to_polar(
    real: torch.Tensor,
    imag: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert the cylindrical representation (i.e. real and imaginary) to
    polar representation (i.e. magnitude and phase)

    Parameters
    ----------
    real : torch.Tensor
        The real part of the complex tensor
    imag : torch.Tensor
        The imaginary part of the complex tensor

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The magnitude and phase of the complex tensor
    """
    mag = (real ** 2 + imag ** 2) ** (0.5)
    phase = torch.atan2(imag, real)
    phase[phase.ne(phase)] = 0.0  # remove NANs if any (NaN != NaN)
    return mag, phase

# --------------------------------------------------------------------------------