├── README.md
├── csnet_model
│   ├── cs_net_sample0.2.pth
│   ├── cs_net_sample0.3.pth
│   └── cs_net_sample0.4.pth
├── data
│   ├── mask_20
│   │   └── mask.mat
│   ├── mask_30
│   │   └── mask0.3.mat
│   ├── mask_40
│   │   └── mask0.4.mat
│   ├── test
│   │   └── new01.mat
│   ├── train
│   │   └── new01.mat
│   └── validate
│       └── new01.mat
├── network
│   └── CSNet_Layers.py
├── requirements.txt
├── test.py
├── torchpwl
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── pwl.cpython-38.pyc
│   │   └── pwl_test.cpython-38.pyc
│   ├── pwl.py
│   └── pwl_test.py
├── train.py
└── utils
    ├── dataset.py
    ├── fftc.py
    ├── metric.py
    └── my_loss.py

/README.md:
--------------------------------------------------------------------------------
1 | # ADMM-CSNet
2 | 
3 | ***********************************************************************************************************
4 | 
5 | Code for the model in "ADMM-CSNet: A Deep Learning Approach for Image Compressive Sensing" (TPAMI 2019)
6 | 
7 | If you use this code, please cite our paper:
8 | 
9 | [1] Yan Yang, Jian Sun, Huibin Li, Zongben Xu. ADMM-CSNet: A Deep Learning Approach for Image Compressive Sensing (TPAMI 2019).
10 | 
11 | http://gr.xjtu.edu.cn/web/jiansun/publications
12 | 
13 | All rights are reserved by the authors.
14 | 
15 | Yan Yang -2021/11/04. For more details or training data, feel free to contact: yangyan92@stu.xjtu.edu.cn
16 | 
17 | 
18 | ***********************************************************************************************************
19 | 
20 | 
21 | 
22 | ## Data link:
23 | https://drive.google.com/drive/folders/1UhQ01pdmO11Agc5sM61Mt7KQTN9LytNt
24 | 
25 | Download the data and organise the directory as follows:
26 | ```
27 | ### For dataset
28 | └── data
29 |     ├── train
30 |     ├── test
31 |     └── validate
32 | ```
33 | ## Installation
34 | This installation guide shows you how to set up the environment for running our code using conda.
35 | 
36 | First clone the ADMM-CSNet repository:
37 | ```
38 | git clone https://github.com/lixing0810/Pytorch_ADMM-CSNet.git
39 | cd Pytorch_ADMM-CSNet
40 | ```
41 | Then create and activate a new conda environment:
42 | ```
43 | conda create --name ADMM_CSNET python=3.9
44 | conda activate ADMM_CSNET
45 | ```
46 | Install PyTorch:
47 | ```
48 | pip install torch torchvision
49 | ```
50 | Install the remaining requirements:
51 | ```
52 | pip install -r requirements.txt
53 | ```
54 | 
55 | ## Usage:
56 | 
57 | 1. The nonlinear layer of the PyTorch version of ADMM-CSNet is built on the torchpwl package. After installing the requirements, replace the torchpwl package in your environment with the modified copy provided in the torchpwl folder of this repository.
58 | 
59 | 2. We provide one sample MR image per split and the sampling masks in the data directory; please replace them with the full dataset downloaded from the Google Drive link above.
60 |    The full dataset contains 100 training, 50 test, and 50 validation slices.
61 |    The mask_20 directory holds the 20% sampling mask, and likewise for mask_30 and mask_40 (a loading sketch is shown below, after this list).
62 | 
63 | 3. The PyTorch version of ADMM-CSNet is trained end to end. We provide the final models for the 20%, 30%, and 40% sampling rates in the csnet_model directory; just change the model name in test.py to select one.
64 | 
65 | Test:
66 | ```
67 | python test.py
68 | ```
69 | 
70 | 4. For retraining, change data_dir and the mask name in train.py; checkpoints are saved in the logs_csnet directory.
71 | 
72 | Training:
73 | ```
74 | python train.py
75 | ```
76 | 
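For reference, here is a minimal sketch of how the mask, a checkpoint, and one test slice fit together (it mirrors test.py; the .mat key is 'mask' for the 20% mask, and a random complex tensor stands in for a real slice from data/test):
```python
import torch
from scipy.io import loadmat
from network.CSNet_Layers import CSNetADMMLayer
from utils.fftc import ifftshift

# load the 20% sampling mask and undo the centered layout to match torch.fft
mask = ifftshift(torch.Tensor(loadmat('data/mask_20/mask.mat')['mask'])).cuda()

model = CSNetADMMLayer(mask).cuda()
model.load_state_dict(torch.load('csnet_model/cs_net_sample0.2.pth'))
model.eval()

# stand-in for one 256x256 complex MR slice from data/test
label = torch.randn(1, 1, 256, 256, dtype=torch.complex64).cuda()
with torch.no_grad():
    recon = model(torch.fft.fft2(label))  # the model masks the full k-space internally
```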
77 | 5. Results
78 | Notice:
79 | The PyTorch version of ADMM-CSNet uses the Adam optimizer, which is readily available and trains faster on large datasets; the MATLAB version uses L-BFGS, which may account for small differences in the final results.
80 | 
81 | | Sampling Mask | PSNR |
82 | | ------ | ------ |
83 | | 0.2 | 31.354 |
84 | | 0.3 | 34.365 |
85 | | 0.4 | 37.153 |
86 | 
87 | ***********************************************************************************************************
--------------------------------------------------------------------------------
/csnet_model/cs_net_sample0.2.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/csnet_model/cs_net_sample0.2.pth
--------------------------------------------------------------------------------
/csnet_model/cs_net_sample0.3.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/csnet_model/cs_net_sample0.3.pth
--------------------------------------------------------------------------------
/csnet_model/cs_net_sample0.4.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/csnet_model/cs_net_sample0.4.pth
--------------------------------------------------------------------------------
/data/mask_20/mask.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/data/mask_20/mask.mat
--------------------------------------------------------------------------------
/data/mask_30/mask0.3.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/data/mask_30/mask0.3.mat
--------------------------------------------------------------------------------
/data/mask_40/mask0.4.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/data/mask_40/mask0.4.mat
--------------------------------------------------------------------------------
/data/test/new01.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/data/test/new01.mat
--------------------------------------------------------------------------------
/data/train/new01.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/data/train/new01.mat
--------------------------------------------------------------------------------
/data/validate/new01.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/data/validate/new01.mat
--------------------------------------------------------------------------------
/network/CSNet_Layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torchpwl
4 | from scipy.io import loadmat
5 | from os.path import join
6 | import os
7 | from utils.fftc import *
8 | import torch
9 | 
10 | 
11 | class CSNetADMMLayer(nn.Module):
12 |     def __init__(
13 |         self,
14 |         mask,
15 |         in_channels: int = 1,
16 |         out_channels: int = 128,
17 |         kernel_size: int = 5
18 | 
19 |     ):
20 |         """
21 |         Args:
22 |             mask: k-space sampling mask; in_channels, out_channels and kernel_size configure the convolution blocks.
23 |         """
24 |         super(CSNetADMMLayer, self).__init__()
25 | 
26 |         self.rho = nn.Parameter(torch.tensor([0.1]), requires_grad=True)
27 |         self.gamma = nn.Parameter(torch.tensor([1.0]), requires_grad=True)
28 |         self.mask = mask
29 |         self.re_org_layer = ReconstructionOriginalLayer(self.rho, self.mask)
30 |         self.conv1_layer = ConvolutionLayer1(in_channels, out_channels, kernel_size)
31 |         self.nonlinear_layer = NonlinearLayer()
32 |         self.conv2_layer = ConvolutionLayer2(out_channels, in_channels, kernel_size)
33 |         self.min_layer = MinusLayer()
34 |         self.multiple_org_layer = MultipleOriginalLayer(self.gamma)
35 |         self.re_update_layer = ReconstructionUpdateLayer(self.rho, self.mask)
36 |         self.add_layer = AdditionalLayer()
37 |         self.multiple_update_layer = MultipleUpdateLayer(self.gamma)
38 |         self.re_final_layer = ReconstructionFinalLayer(self.rho, self.mask)
39 |         layers = []
40 | 
41 |         layers.append(self.re_org_layer)
42 |         layers.append(self.conv1_layer)
43 |         layers.append(self.nonlinear_layer)
44 |         layers.append(self.conv2_layer)
45 |         layers.append(self.min_layer)
46 |         layers.append(self.multiple_org_layer)
47 | 
48 |         for i in range(8):  # the later stages reuse the same sub-layers, so all stages share weights
49 |             layers.append(self.re_update_layer)
50 |             layers.append(self.add_layer)
51 |             layers.append(self.conv1_layer)
52 |             layers.append(self.nonlinear_layer)
53 |             layers.append(self.conv2_layer)
54 |             layers.append(self.min_layer)
55 |             layers.append(self.multiple_update_layer)
56 | 
57 |         layers.append(self.re_update_layer)
58 |         layers.append(self.add_layer)
59 |         layers.append(self.conv1_layer)
60 |         layers.append(self.nonlinear_layer)
61 |         layers.append(self.conv2_layer)
62 |         layers.append(self.min_layer)
63 |         layers.append(self.multiple_update_layer)
64 | 
65 |         layers.append(self.re_final_layer)
66 | 
67 |         self.cs_net = nn.Sequential(*layers)
68 |         self.reset_parameters()
69 | 
70 |     def reset_parameters(self):
71 |         self.conv1_layer.conv.weight = torch.nn.init.normal_(self.conv1_layer.conv.weight, mean=0, std=1)
72 |         self.conv2_layer.conv.weight = torch.nn.init.normal_(self.conv2_layer.conv.weight, mean=0, std=1)
73 |         self.conv1_layer.conv.weight.data = self.conv1_layer.conv.weight.data * 0.025
74 |         self.conv2_layer.conv.weight.data = self.conv2_layer.conv.weight.data * 0.025
75 | 
76 |     def forward(self, x):
77 |         y = torch.mul(x, self.mask)  # undersampled k-space measurements
78 |         x = self.cs_net(y)
79 |         x = torch.fft.ifft2(y + (1 - self.mask) * torch.fft.fft2(x))  # data consistency: keep the sampled k-space entries of y
80 |         return x
81 | 
82 | 
83 | # reconstruction original layers
84 | class ReconstructionOriginalLayer(nn.Module):
85 |     def __init__(self, rho, mask):
86 |         super(ReconstructionOriginalLayer, self).__init__()
87 |         self.rho = rho
88 |         self.mask = mask
89 | 
90 |     def forward(self, x):
91 |         mask = self.mask
92 |         denom = torch.add(mask.cuda(), self.rho)
93 |         a = 1e-6
94 |         value = torch.full(denom.size(), a).cuda()
95 |         denom = torch.where(denom == 0, value, denom)  # avoid division by zero
96 |         orig_output1 = torch.div(1, denom)
97 | 
98 |         orig_output2 = torch.mul(x, orig_output1)
99 |         orig_output3 = torch.fft.ifft2(orig_output2)
100 |         # define data dict
101 |         cs_data = dict()
102 |         cs_data['input'] = x
103 |         cs_data['conv1_input'] = orig_output3
104 |         return cs_data
105 | 
106 | 
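# The reconstruction layers implement the closed-form ADMM x-update in k-space,
#     x = F^{-1}( (y + rho * F(z - beta)) / (mask + rho) ),
# where y is the sampled k-space data ('input'), z is the current denoised
# estimate ('minus_output'), and beta is the scaled multiplier ('multi_output').
# In the original layer above, z and beta are still zero, which leaves
# x = F^{-1}( y / (mask + rho) ); the update and final layers below use the full form.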
107 | # reconstruction middle layers
108 | class ReconstructionUpdateLayer(nn.Module):
109 |     def __init__(self, rho, mask):
110 |         super(ReconstructionUpdateLayer, self).__init__()
111 |         self.rho = rho
112 |         self.mask = mask
113 | 
114 |     def forward(self, x):
115 |         minus_output = x['minus_output']
116 |         multiple_output = x['multi_output']
117 |         input = x['input']
118 |         mask = self.mask
119 |         number = torch.add(input, self.rho * torch.fft.fft2(torch.sub(minus_output, multiple_output)))  # numerator of the x-update
120 |         denom = torch.add(mask.cuda(), self.rho)
121 |         a = 1e-6
122 |         value = torch.full(denom.size(), a).cuda()
123 |         denom = torch.where(denom == 0, value, denom)  # avoid division by zero
124 |         orig_output1 = torch.div(1, denom)
125 |         orig_output2 = torch.mul(number, orig_output1)
126 |         orig_output3 = torch.fft.ifft2(orig_output2)
127 |         x['re_mid_output'] = orig_output3
128 |         return x
129 | 
130 | 
131 | # reconstruction final layer
132 | class ReconstructionFinalLayer(nn.Module):
133 |     def __init__(self, rho, mask):
134 |         super(ReconstructionFinalLayer, self).__init__()
135 |         self.rho = rho
136 |         self.mask = mask
137 | 
138 |     def forward(self, x):
139 |         minus_output = x['minus_output']
140 |         multiple_output = x['multi_output']
141 |         input = x['input']
142 |         mask = self.mask
143 |         number = torch.add(input, self.rho * torch.fft.fft2(torch.sub(minus_output, multiple_output)))
144 |         denom = torch.add(mask.cuda(), self.rho)
145 |         a = 1e-6
146 |         value = torch.full(denom.size(), a).cuda()
147 |         denom = torch.where(denom == 0, value, denom)
148 |         orig_output1 = torch.div(1, denom)
149 |         orig_output2 = torch.mul(number, orig_output1)
150 |         orig_output3 = torch.fft.ifft2(orig_output2)
151 |         x['re_final_output'] = orig_output3
152 |         return x['re_final_output']
153 | 
154 | 
155 | # multiplier original layer (initializes the scaled Lagrange multiplier)
156 | class MultipleOriginalLayer(nn.Module):
157 |     def __init__(self, gamma):
158 |         super(MultipleOriginalLayer, self).__init__()
159 |         self.gamma = gamma
160 | 
161 |     def forward(self, x):
162 |         org_output = x['conv1_input']
163 |         minus_output = x['minus_output']
164 |         output = torch.mul(self.gamma, torch.sub(org_output, minus_output))
165 |         x['multi_output'] = output
166 |         return x
167 | 
168 | 
169 | # multiplier middle layer
170 | class MultipleUpdateLayer(nn.Module):
171 |     def __init__(self, gamma):
172 |         super(MultipleUpdateLayer, self).__init__()
173 |         self.gamma = gamma
174 | 
175 |     def forward(self, x):
176 |         multiple_output = x['multi_output']
177 |         re_mid_output = x['re_mid_output']
178 |         minus_output = x['minus_output']
179 |         output = torch.add(multiple_output, torch.mul(self.gamma, torch.sub(re_mid_output, minus_output)))
180 |         x['multi_output'] = output
181 |         return x
182 | 
183 | 
184 | # convolution layer
185 | class ConvolutionLayer1(nn.Module):
186 |     def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
187 |         super(ConvolutionLayer1, self).__init__()
188 |         self.in_channels = in_channels
189 |         self.out_channels = out_channels
190 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=int((kernel_size-1)/2), stride=1, dilation=1, bias=True)
191 | 
192 |     def forward(self, x):
193 |         conv1_input = x['conv1_input']
194 |         real = self.conv(conv1_input.real)
195 |         imag = self.conv(conv1_input.imag)
196 |         output = torch.complex(real, imag)
197 |         x['conv1_output'] = output
198 |         return x
199 | 
200 | 
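# ConvolutionLayer1 above and ConvolutionLayer2 below filter the real and
# imaginary parts of a complex feature map with the same real-valued kernels;
# by linearity this equals a complex convolution with real weights:
#     conv(a + i*b) = conv(a) + i*conv(b).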
201 | # convolution layer
202 | class ConvolutionLayer2(nn.Module):
203 |     def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
204 |         super(ConvolutionLayer2, self).__init__()
205 |         self.in_channels = in_channels
206 |         self.out_channels = out_channels
207 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2),
208 |                               stride=1, dilation=1, bias=True)
209 | 
210 |     def forward(self, x):
211 |         nonlinear_output = x['nonlinear_output']
212 |         real = self.conv(nonlinear_output.real)
213 |         imag = self.conv(nonlinear_output.imag)
214 |         output = torch.complex(real, imag)
215 | 
216 |         x['conv2_output'] = output
217 |         return x
218 | 
219 | 
220 | # nonlinear layer
221 | class NonlinearLayer(nn.Module):
222 |     def __init__(self):
223 |         super(NonlinearLayer, self).__init__()
224 |         self.pwl = torchpwl.PWL(num_channels=128, num_breakpoints=101)  # learnable piecewise-linear activation (see torchpwl/pwl.py)
225 | 
226 |     def forward(self, x):
227 |         conv1_output = x['conv1_output']
228 |         y_real = self.pwl(conv1_output.real)
229 |         y_imag = self.pwl(conv1_output.imag)
230 |         output = torch.complex(y_real, y_imag)
231 |         x['nonlinear_output'] = output
232 |         return x
233 | 
234 | 
235 | # minus layer
236 | class MinusLayer(nn.Module):
237 |     def __init__(self):
238 |         super(MinusLayer, self).__init__()
239 | 
240 |     def forward(self, x):
241 |         minus_input = x['conv1_input']
242 |         conv2_output = x['conv2_output']
243 |         output = torch.sub(minus_input, conv2_output)
244 |         x['minus_output'] = output
245 |         return x
246 | 
247 | 
248 | # additional layer
249 | class AdditionalLayer(nn.Module):
250 |     def __init__(self):
251 |         super(AdditionalLayer, self).__init__()
252 | 
253 |     def forward(self, x):
254 |         mid_output = x['re_mid_output']
255 |         multi_output = x['multi_output']
256 |         output = torch.add(mid_output, multi_output)
257 |         x['conv1_input'] = output
258 |         return x
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.11
2 | scipy>=1.7.0
3 | h5py==3.5.0
4 | tensorboardX>=2.1.0
5 | torchpwl>=0.1.0
6 | packaging>=20.8
7 | six>=1.15.0
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | ADMM-CSNet test example (v1) with MR slices
3 | By Yan Yang, Jian Sun, Huibin Li, Zongben Xu
4 | 
5 | Please cite the below paper for the code:
6 | 
7 | Yan Yang, Jian Sun, Huibin Li, Zongben Xu. ADMM-CSNet: A Deep Learning Approach for Image Compressive Sensing,
8 | TPAMI(2019).
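Running this script loads the 20% sampling mask from data/mask_20 and the matching checkpoint cs_net_sample0.2.pth, then prints the average test loss and PSNR over the test set; change the mask directory and the model name together to evaluate the 30% or 40% models.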
9 | """ 10 | from __future__ import print_function, division 11 | import os 12 | import argparse 13 | from network.CSNet_Layers import CSNetADMMLayer 14 | from utils.dataset import get_test_data 15 | import torch.utils.data as data 16 | from utils.my_loss import MyLoss 17 | from utils.metric import complex_psnr 18 | import gc 19 | from scipy.io import loadmat 20 | from utils.fftc import * 21 | from os.path import join 22 | 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 24 | 25 | if __name__ == '__main__': 26 | 27 | ############################################################################### 28 | # parameters 29 | ############################################################################### 30 | parser = argparse.ArgumentParser(description=' main ') 31 | parser.add_argument('--data_dir', default='data/', type=str, 32 | help='directory of data') 33 | parser.add_argument('--batch_size', default=1, type=int, help='batch size') 34 | parser.add_argument('--outf', type=str, default='csnet_model', help='path of log files') 35 | args = parser.parse_args() 36 | 37 | ############################################################################### 38 | # load data info 39 | ############################################################################### 40 | test = get_test_data(args.data_dir) 41 | test_loader = data.DataLoader(dataset=test, batch_size=args.batch_size, shuffle=False, num_workers=4, 42 | pin_memory=False) 43 | 44 | ############################################################################### 45 | # mask 46 | ############################################################################### 47 | dir = 'data/mask_20' 48 | data = loadmat(join(dir, os.listdir(dir)[0])) 49 | mask_data = data['mask'] 50 | mask = ifftshift(torch.Tensor(mask_data)).cuda() 51 | 52 | ############################################################################### 53 | # Build model 54 | ############################################################################### 55 | print('Loading model ...\n') 56 | model = CSNetADMMLayer(mask).cuda() 57 | model.load_state_dict(torch.load(os.path.join(args.outf, 'cs_net_sample0.2.pth'))) 58 | model.eval() 59 | 60 | ############################################################################### 61 | # loss 62 | ############################################################################### 63 | criterion = MyLoss().cuda() 64 | 65 | ############################################################################### 66 | # test 67 | ############################################################################### 68 | test_err = 0 69 | test_psnr = 0 70 | test_batches = 0 71 | for batch , (label,num) in enumerate(test_loader): 72 | gc.collect() 73 | with torch.no_grad(): 74 | full_kspace = torch.fft.fft2(label.cuda()) 75 | test_output = model(full_kspace) 76 | test_loss_normal = criterion(test_output, label.cuda()) 77 | test_err += test_loss_normal.item() 78 | test_batches += 1 79 | test_psnr_value = complex_psnr(abs(test_output).cpu().numpy(), abs(label).cpu().numpy(), 80 | peak='normalized') 81 | test_psnr += test_psnr_value 82 | test_err /= test_batches 83 | test_psnr /= test_batches 84 | print("test_loss ", test_err) 85 | print("test_psnr ", test_psnr) 86 | -------------------------------------------------------------------------------- /torchpwl/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["PWL", "MonoPWL", "PointPWL", "MonoPointPWL", "Calibrator"] 2 | __version__ = "0.1.1" 3 | 4 | from .pwl import PointPWL, 
MonoPointPWL, PWL, MonoPWL, Calibrator 5 | -------------------------------------------------------------------------------- /torchpwl/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/torchpwl/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /torchpwl/__pycache__/pwl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/torchpwl/__pycache__/pwl.cpython-38.pyc -------------------------------------------------------------------------------- /torchpwl/__pycache__/pwl_test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixing0810/Pytorch_ADMM-CSNet/0f37fdf3323dafedd5eb78b4bac1ee63283d5069/torchpwl/__pycache__/pwl_test.cpython-38.pyc -------------------------------------------------------------------------------- /torchpwl/pwl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_monotonicity(monotonicity, num_channels): 5 | if isinstance(monotonicity, (int, float)): 6 | if not monotonicity in (-1, 0, 1): 7 | raise ValueError("monotonicity must be one of -1, 0, +1") 8 | return monotonicity * torch.ones(num_channels) 9 | else: 10 | if not (isinstance(monotonicity, torch.Tensor) and list(monotonicity.shape) == [num_channels]): 11 | raise ValueError("monotonicity must be either an int or a tensor with shape [num_channels]") 12 | if not torch.all( 13 | torch.eq(monotonicity, 0) | torch.eq(monotonicity, 1) | torch.eq(monotonicity, -1) 14 | ).item(): 15 | raise ValueError("monotonicity must be one of -1, 0, +1") 16 | return monotonicity.float() 17 | 18 | 19 | class BasePWL(torch.nn.Module): 20 | def __init__(self, num_breakpoints): 21 | super(BasePWL, self).__init__() 22 | if not num_breakpoints >= 1: 23 | raise ValueError( 24 | "Piecewise linear function only makes sense when you have 1 or more breakpoints." 
25 | ) 26 | self.num_breakpoints = num_breakpoints 27 | 28 | def slope_at(self, x): 29 | dx = 1e-3 30 | return -(self.forward(x) - self.forward(x + dx)) / dx 31 | 32 | 33 | def calibrate1d(x, xp, yp): 34 | """ 35 | x: [N, C] 36 | xp: [C, K] 37 | yp: [C, K] 38 | """ 39 | x_breakpoints = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((x.shape[0], 1, 1))], dim=2) 40 | num_x_points = xp.shape[1] 41 | sorted_x_breakpoints, x_indices = torch.sort(x_breakpoints, dim=2) 42 | x_idx = torch.argmin(x_indices, dim=2) 43 | cand_start_idx = x_idx - 1 44 | start_idx = torch.where( 45 | torch.eq(x_idx, 0), 46 | torch.tensor(1, device=x.device), 47 | torch.where( 48 | torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx, 49 | ), 50 | ) 51 | end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) 52 | start_x = torch.gather(sorted_x_breakpoints, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) 53 | end_x = torch.gather(sorted_x_breakpoints, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) 54 | start_idx2 = torch.where( 55 | torch.eq(x_idx, 0), 56 | torch.tensor(0, device=x.device), 57 | torch.where( 58 | torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx, 59 | ), 60 | ) 61 | y_positions_expanded = yp.unsqueeze(0).expand(x.shape[0], -1, -1) 62 | start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) 63 | end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) 64 | cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x + 1e-7) 65 | return cand 66 | 67 | 68 | class Calibrator(torch.nn.Module): 69 | def __init__(self, keypoints, monotonicity, missing_value=11.11): 70 | """ 71 | Calibrates input to the output range of [-0.5*monotonicity, 0.5*monotonicity]. 72 | The output is always monotonic with respect to the input. 73 | Recommended to use Adam for training. The calibrator is initalized as a straight line. 74 | 75 | value <= keypoint[0] will map to -0.5*monotonicity. 76 | value >= keypoint[-1] will map to 0.5*monotonicity. 77 | value == missing_value will map to a learnable value (within the standard output range). 78 | Each channel is independently calibrated and can have its own keypoints. 79 | Note: monotonicity and keypoints are not trainable, they remain fixed, only the calibration output at 80 | each keypoint is trainable. 81 | 82 | keypoints: tensor with shape [C, K], where K > 2 83 | monotonicity: tensor with shape [C] 84 | missing_value: float 85 | """ 86 | super(Calibrator, self).__init__() 87 | xp = torch.tensor(keypoints, dtype=torch.float32) 88 | self.register_buffer("offset", xp[:, :1].clone().detach()) 89 | self.register_buffer("scale", (xp[:, -1:] - self.offset).clone().detach()) 90 | xp = (xp - self.offset) / self.scale 91 | self.register_buffer("keypoints", xp) 92 | self.register_buffer("monotonicity", torch.tensor(monotonicity, dtype=torch.float32).unsqueeze(0)) 93 | self.missing_value = missing_value 94 | yp = xp[:, 1:] - xp[:, :-1] 95 | # [C, K - 1] 96 | self.yp = torch.nn.Parameter(yp, requires_grad=True) 97 | # [1, C] 98 | self.missing_y = torch.nn.Parameter(torch.zeros_like(xp[:, 0]).unsqueeze(0), requires_grad=True) 99 | 100 | def forward(self, x): 101 | """Calibrates input x tensor. 
x has shape [BATCH_SIZE, C].""" 102 | missing = torch.zeros_like(x) + torch.tanh(self.missing_y) / 2.0 103 | yp = torch.cumsum(torch.abs(self.yp) + 1e-9, dim=1) 104 | xp = self.keypoints 105 | last_val = yp[:, -1:] 106 | yp = torch.cat([torch.zeros_like(last_val), yp / last_val], dim=1) 107 | x_transformed = torch.clamp((x - self.offset) / self.scale, 0.0, 1.0) 108 | calibrated = calibrate1d(x_transformed, xp, yp) - 0.5 109 | return self.monotonicity * torch.where(x == self.missing_value, missing, calibrated) 110 | 111 | 112 | class BasePWLX(BasePWL): 113 | def __init__(self, num_channels, num_breakpoints, num_x_points): 114 | super(BasePWLX, self).__init__(num_breakpoints) 115 | self.num_channels = num_channels 116 | self.num_x_points = num_x_points 117 | # self.x_positions = torch.nn.Parameter(torch.Tensor(self.num_channels, self.num_x_points)) 118 | self.x_positions = torch.Tensor(self.num_channels, self.num_x_points) 119 | self._reset_x_points() 120 | 121 | def _reset_x_points(self): 122 | # torch.nn.init.normal_(self.x_positions, std=0.000001) 123 | # torch.nn.init.zeros_(self.x_positions) 124 | self.x_positions = torch.linspace(-1,1,self.num_x_points).unsqueeze(0).expand(self.num_channels, self.num_x_points) 125 | 126 | def get_x_positions(self): 127 | return self.x_positions 128 | 129 | def get_sorted_x_positions(self): 130 | return torch.sort(self.get_x_positions(), dim=1)[0] 131 | 132 | def get_spreads(self): 133 | sorted_x_positions = self.get_sorted_x_positions() 134 | return (torch.roll(sorted_x_positions, shifts=-1, dims=1) - sorted_x_positions)[:, :-1] 135 | 136 | def unpack_input(self, x): 137 | shape = list(x.shape) 138 | if len(shape) == 2: 139 | return x 140 | elif len(shape) < 2: 141 | raise ValueError( 142 | "Invalid input, the input to the PWL module must have at least 2 dimensions with channels at dimension dim(1)." 
143 | ) 144 | assert shape[1] == self.num_channels, ( 145 | "Invalid input, the size of dim(1) must be equal to num_channels (%d)" % self.num_channels 146 | ) 147 | x = torch.transpose(x, 1, len(shape) - 1) 148 | assert x.shape[-1] == self.num_channels 149 | return x.reshape(-1, self.num_channels) 150 | 151 | def repack_input(self, unpacked, old_shape): 152 | old_shape = list(old_shape) 153 | if len(old_shape) == 2: 154 | return unpacked 155 | transposed_shape = old_shape[:] 156 | transposed_shape[1] = old_shape[-1] 157 | transposed_shape[-1] = old_shape[1] 158 | unpacked = unpacked.view(*transposed_shape) 159 | return torch.transpose(unpacked, 1, len(old_shape) - 1) 160 | 161 | 162 | class BasePointPWL(BasePWLX): 163 | def get_y_positions(self): 164 | raise NotImplementedError() 165 | 166 | def forward(self, x): 167 | old_shape = x.shape 168 | x = self.unpack_input(x) 169 | cand = calibrate1d(x, self.get_x_positions(), self.get_y_positions()) 170 | return self.repack_input(cand, old_shape) 171 | 172 | 173 | class PointPWL(BasePointPWL): 174 | def __init__(self, num_channels, num_breakpoints): 175 | super(PointPWL, self).__init__(num_channels, num_breakpoints, num_x_points=num_breakpoints + 1) 176 | self.y_positions = torch.nn.Parameter(torch.Tensor(self.num_channels, self.num_x_points)) 177 | self._reset_params() 178 | 179 | def _reset_params(self): 180 | BasePWLX._reset_x_points(self) 181 | with torch.no_grad(): 182 | self.y_positions.copy_(self.get_sorted_x_positions()) 183 | 184 | def get_x_positions(self): 185 | return self.x_positions 186 | 187 | def get_y_positions(self): 188 | return self.y_positions 189 | 190 | 191 | class MonoPointPWL(BasePointPWL): 192 | def __init__(self, num_channels, num_breakpoints, monotonicity=1): 193 | super(MonoPointPWL, self).__init__(num_channels, num_breakpoints, num_x_points=num_breakpoints + 1) 194 | self.y_starts = torch.nn.Parameter(torch.Tensor(self.num_channels)) 195 | self.y_deltas = torch.nn.Parameter(torch.Tensor(self.num_channels, self.num_breakpoints)) 196 | self.register_buffer("monotonicity", get_monotonicity(monotonicity, num_channels)) 197 | self._reset_params() 198 | 199 | def _reset_params(self): 200 | BasePWLX._reset_x_points(self) 201 | with torch.no_grad(): 202 | sorted_x_positions = self.get_sorted_x_positions() 203 | mono_mul = torch.where( 204 | torch.eq(self.monotonicity, 0.0), 205 | torch.tensor(1.0, device=self.monotonicity.device), 206 | self.monotonicity, 207 | ) 208 | self.y_starts.copy_(sorted_x_positions[:, 0] * mono_mul) 209 | spreads = self.get_spreads() 210 | self.y_deltas.copy_(spreads * mono_mul.unsqueeze(1)) 211 | 212 | def get_x_positions(self): 213 | return self.x_positions 214 | 215 | def get_y_positions(self): 216 | starts = self.y_starts.unsqueeze(1) 217 | deltas = torch.where( 218 | torch.eq(self.monotonicity, 0.0).unsqueeze(1), 219 | self.y_deltas, 220 | torch.abs(self.y_deltas) * self.monotonicity.unsqueeze(1), 221 | ) 222 | return torch.cat([starts, starts + torch.cumsum(deltas, dim=1)], dim=1) 223 | 224 | 225 | class BaseSlopedPWL(BasePWLX): 226 | def get_biases(self): 227 | raise NotImplementedError() 228 | 229 | def get_slopes(self): 230 | raise NotImplementedError() 231 | 232 | def forward(self, x): 233 | old_shape = x.shape 234 | x = self.unpack_input(x) 235 | bs = x.shape[0] 236 | sorted_x_positions = self.get_sorted_x_positions().cuda() 237 | skips = torch.roll(sorted_x_positions, shifts=-1, dims=1) - sorted_x_positions 238 | slopes = self.get_slopes() 239 | skip_deltas = skips * slopes[:, 1:] 240 | 
biases = self.get_biases().unsqueeze(1)
241 |         cumsums = torch.cumsum(skip_deltas, dim=1)[:, :-1]
242 | 
243 |         betas = torch.cat([biases, biases, cumsums + biases], dim=1)
244 |         breakpoints = torch.cat([sorted_x_positions[:, 0].unsqueeze(1), sorted_x_positions], dim=1)
245 | 
246 |         # find the index of the first breakpoint smaller than x
247 |         # TODO(pdabkowski) improve the implementation
248 |         s = x.unsqueeze(2) - sorted_x_positions.unsqueeze(0)
249 |         # discard larger breakpoints
250 |         s = torch.where(s < 0, torch.tensor(float("inf"), device=x.device), s)
251 |         b_ids = torch.where(
252 |             sorted_x_positions[:, 0].unsqueeze(0) <= x,
253 |             torch.argmin(s, dim=2) + 1,
254 |             torch.tensor(0, device=x.device),
255 |         ).unsqueeze(2)
256 | 
257 |         selected_betas = torch.gather(betas.unsqueeze(0).expand(bs, -1, -1), dim=2, index=b_ids).squeeze(2)
258 |         selected_breakpoints = torch.gather(
259 |             breakpoints.unsqueeze(0).expand(bs, -1, -1), dim=2, index=b_ids
260 |         ).squeeze(2)
261 |         selected_slopes = torch.gather(slopes.unsqueeze(0).expand(bs, -1, -1), dim=2, index=b_ids).squeeze(2)
262 |         cand = selected_betas + (x - selected_breakpoints) * selected_slopes
263 |         return self.repack_input(cand, old_shape)
264 | 
265 | 
266 | class PWL(BaseSlopedPWL):
267 |     r"""Piecewise Linear Function (PWL) module.
268 | 
269 |     The module takes the Tensor of (N, num_channels, ...) shape and returns the processed Tensor of the same shape.
270 |     Each entry in the input tensor is processed by the PWL function. There are num_channels separate PWL functions,
271 |     the PWL function used depends on the channel.
272 | 
273 |     In this modified copy the x coordinates of the breakpoints are initialized evenly spaced on [-1, 1] (see
274 |     _reset_x_points). You may want to use your own custom initialization depending on the use-case, as the
275 |     optimization is quite sensitive to the initialization of breakpoints. As long as your data is normalized
276 |     (zero mean, unit variance) the default initialization should be fine.
277 | 
278 |     Arguments:
279 |         num_channels (int): number of channels (or features) that this PWL should process. Each channel
280 |             will get its own PWL function.
281 |         num_breakpoints (int): number of PWL breakpoints. Total number of segments constructing the PWL is
282 |             given by num_breakpoints + 1. This value is shared by all the PWL channels in this module.
283 |     """
284 | 
285 |     def __init__(self, num_channels, num_breakpoints):
286 |         super(PWL, self).__init__(num_channels, num_breakpoints, num_x_points=num_breakpoints)
287 |         self.slopes = torch.nn.Parameter(torch.Tensor(self.num_channels, self.num_breakpoints + 1))
288 |         self.biases = torch.nn.Parameter(torch.Tensor(self.num_channels))
289 |         self._reset_params()
290 | 
291 |     def _reset_params(self):
292 |         BasePWLX._reset_x_points(self)
293 |         torch.nn.init.ones_(self.slopes)
294 |         self.slopes.data[:, :(self.num_breakpoints + 1)//2] = 0.0  # modified init: the left half of the slopes start flat
295 | 
296 |         with torch.no_grad():
297 |             self.biases.copy_(torch.zeros_like(self.biases))
298 | 
299 | 
300 |     def get_biases(self):
301 |         return self.biases
302 | 
303 |     def get_x_positions(self):
304 |         return self.x_positions
305 | 
306 |     def get_slopes(self):
307 |         return self.slopes
308 | 
309 | 
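# Usage sketch (this is how NonlinearLayer in network/CSNet_Layers.py uses PWL
# as a learnable per-channel activation, applied elementwise):
#     pwl = PWL(num_channels=128, num_breakpoints=101)
#     y = pwl(torch.randn(4, 128, 32, 32))  # output shape equals input shape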
310 | class MonoPWL(PWL):
311 |     r"""Piecewise Linear Function (PWL) module with the monotonicity constraint.
312 | 
313 |     The module takes the Tensor of (N, num_channels, ...) shape and returns the processed Tensor of the same shape.
314 |     Each entry in the input tensor is processed by the PWL function. There are num_channels separate PWL functions,
315 |     the PWL function used depends on the channel. Each PWL is guaranteed to have the requested monotonicity.
316 | 
317 |     As in PWL, the x coordinates of the breakpoints are initialized evenly spaced on [-1, 1] (see
318 |     _reset_x_points). You may want to use your own custom initialization depending on the use-case, as the
319 |     optimization is quite sensitive to the initialization of breakpoints. As long as your data is normalized
320 |     (zero mean, unit variance) the default initialization should be fine.
321 | 
322 |     Arguments:
323 |         num_channels (int): number of channels (or features) that this PWL should process. Each channel
324 |             will get its own PWL function.
325 |         num_breakpoints (int): number of PWL breakpoints. Total number of segments constructing the PWL is
326 |             given by num_breakpoints + 1. This value is shared by all the PWL channels in this module.
327 |         monotonicity (int, Tensor): Monotonicity constraint, the monotonicity can be either +1 (increasing),
328 |             0 (no constraint) or -1 (decreasing). You can provide either an int to set the constraint
329 |             for all the channels or a long Tensor of shape [num_channels]. All the entries must be in -1, 0, +1.
330 |     """
331 | 
332 |     def __init__(self, num_channels, num_breakpoints, monotonicity=1):
333 |         super(MonoPWL, self).__init__(num_channels=num_channels, num_breakpoints=num_breakpoints)
334 |         self.register_buffer("monotonicity", get_monotonicity(monotonicity, self.num_channels))
335 |         with torch.no_grad():
336 |             mono_mul = torch.where(
337 |                 torch.eq(self.monotonicity, 0.0),
338 |                 torch.tensor(1.0, device=self.monotonicity.device),
339 |                 self.monotonicity,
340 |             )
341 |             self.biases.copy_(self.biases * mono_mul)
342 | 
343 |     def get_slopes(self):
344 |         return torch.where(
345 |             torch.eq(self.monotonicity, 0.0).unsqueeze(1),
346 |             self.slopes,
347 |             torch.abs(self.slopes) * self.monotonicity.unsqueeze(1),
348 |         )
349 | 
350 | 
--------------------------------------------------------------------------------
/torchpwl/pwl_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | 
5 | from .pwl import PointPWL, MonoPointPWL, PWL, MonoPWL  # the sloped variants are exported as PWL and MonoPWL in pwl.py
6 | 
7 | TOLERANCE = 1e-4
8 | 
9 | torch.manual_seed(11)
10 | 
11 | 
12 | def get_x(num_channels, batch_size=37, std=3.):
13 |     return torch.Tensor(batch_size, num_channels).normal_(mean=0., std=std)
14 | 
15 | 
16 | @pytest.mark.parametrize("pwl_module",
17 |                          [PointPWL, MonoPointPWL, PWL, MonoPWL])
18 | @pytest.mark.parametrize("num_channels", [1, 3, 5])
19 | @pytest.mark.parametrize("num_breakpoints", [1, 7, 11])
20 | def test_pwl_init(pwl_module, num_channels, num_breakpoints):
21 |     module = pwl_module(
22 |         num_channels=num_channels, num_breakpoints=num_breakpoints)
23 |     x = get_x(num_channels)
24 |     y = module(x)
25 | 
26 | 
27 | @pytest.mark.parametrize("pwl_module",
28 |                          [PWL, MonoPWL, MonoPointPWL])
29 | @pytest.mark.parametrize("num_channels", [1, 3])
30 | @pytest.mark.parametrize("num_breakpoints", [1, 7])
31 | def test_pwl_default_init_response(pwl_module, num_channels, num_breakpoints):
32 |     module = pwl_module(
33 |         num_channels=num_channels, num_breakpoints=num_breakpoints)
34 |     x = get_x(num_channels)
35 |     y = module(x)
36 |     # Should initialize to y = x by default (the modified slope init in this repository's pwl.py may break this for PWL/MonoPWL).
37 |     expected_y = x
38 |     assert torch.max(torch.abs(y - expected_y)) < TOLERANCE
39 | 
40 | 
41 | @pytest.mark.parametrize("pwl_module", [MonoPWL])
42 | @pytest.mark.parametrize("num_channels", [1, 3])
43 | @pytest.mark.parametrize("num_breakpoints", [1, 7])
44 | @pytest.mark.parametrize("monotonicity", [-1, 0, 1])
45 | def test_pwl_default_init_mono_response(pwl_module, num_channels,
46 |                                         num_breakpoints, monotonicity):
47 |     module = pwl_module(
48 |         num_channels=num_channels,
49 |         num_breakpoints=num_breakpoints,
50 |         monotonicity=monotonicity)
51 |     x = get_x(num_channels)
52 |     y = module(x)
53 |     # Should initialize to y = x if monotonicity is 1 or 0, otherwise y = -x
54 |     expected_y = x if monotonicity in (1, 0) else -x
55 |     assert torch.max(torch.abs(y - expected_y)) < TOLERANCE
56 | 
57 | 
58 | @pytest.mark.parametrize("pwl_module", [MonoPWL])
59 | @pytest.mark.parametrize("num_channels", [1, 3])
60 | @pytest.mark.parametrize("num_breakpoints", [1, 7])
61 | def test_pwl_default_init_multi_mono_response(pwl_module, num_channels,
62 |                                               num_breakpoints):
63 |     monotonicity = torch.Tensor(num_channels).normal_(std=100).long() % 3 - 1
64 |     module = pwl_module(
65 |         num_channels=num_channels,
66 |         num_breakpoints=num_breakpoints,
67 |         monotonicity=monotonicity)
68 |     x = get_x(num_channels)
69 |     y = module(x)
70 |     # Should initialize to y = x if monotonicity is 1 or 0, otherwise y = -x
71 |     expected_y = torch.where(torch.eq(monotonicity, -1).unsqueeze(0), -x, x)
72 |     assert torch.max(torch.abs(y - expected_y)) < TOLERANCE
73 | 
74 | 
75 | @pytest.mark.parametrize("pwl_module", [PWL, MonoPWL])
76 | @pytest.mark.parametrize("num_channels", [1, 3])
77 | @pytest.mark.parametrize("num_breakpoints", [1, 7])
78 | def test_pwl_gradient_flows(pwl_module, num_channels, num_breakpoints):
79 |     module = pwl_module(
80 |         num_channels=num_channels, num_breakpoints=num_breakpoints)
81 |     x = get_x(num_channels)
82 |     x.requires_grad = True
83 |     y = module(x)
84 |     torch.sum(y).backward()
85 |     expected_grad = 1.
86 |     assert torch.max(torch.abs(x.grad - expected_grad)) < TOLERANCE
87 | 
88 | 
89 | @pytest.mark.parametrize("pwl_module", [PWL, MonoPWL])
90 | @pytest.mark.parametrize("num_channels", [1, 3])
91 | @pytest.mark.parametrize("num_breakpoints", [1, 7])
92 | def test_pwl_sloped_correct_num_breakpoints(pwl_module, num_channels,
93 |                                             num_breakpoints):
94 |     module = pwl_module(
95 |         num_channels=num_channels, num_breakpoints=num_breakpoints)
96 |     assert list(module.get_sorted_x_positions().shape) == [
97 |         num_channels, num_breakpoints
98 |     ]
99 | 
100 | 
101 | @pytest.mark.parametrize("pwl_module",
102 |                          [PWL, MonoPWL, PointPWL, MonoPointPWL])
103 | @pytest.mark.parametrize("num_channels", [1, 3])
104 | @pytest.mark.parametrize("num_breakpoints", [1, 2, 3, 4])
105 | def test_pwl_is_continuous(pwl_module, num_channels, num_breakpoints):
106 |     module = pwl_module(
107 |         num_channels=num_channels, num_breakpoints=num_breakpoints)
108 |     with torch.no_grad():
109 |         for parameter in module.parameters():
110 |             parameter.normal_()
111 |     x = torch.linspace(
112 |         -4., 4., steps=10000).unsqueeze(1).expand(-1, num_channels)
113 |     y = module(x)
114 |     dy = torch.roll(y, shifts=-1, dims=0) - y
115 |     dx = torch.roll(x, shifts=-1, dims=0) - x
116 |     grad = dy / dx
117 |     if isinstance(module, (PointPWL, MonoPointPWL)):
118 |         allowed_grad = torch.max(4 / module.get_spreads())
119 |     else:
120 |         allowed_grad = 4
121 |     assert torch.max(abs(grad)) < allowed_grad
122 | 
123 | 
124 | @pytest.mark.parametrize("pwl_module", [PWL, MonoPWL])
125 | @pytest.mark.parametrize("num_channels", [1, 3])
126 | @pytest.mark.parametrize("num_breakpoints", [1, 2, 3, 4])
127 | @pytest.mark.parametrize(
128 |     "optimizer_fn",
129 |     [
130 |         #lambda params: torch.optim.SGD(params=params, lr=0.1, momentum=0.5),
131 |         lambda params: torch.optim.Adam(params=params, lr=0.2),
132 |     ])
133 | def test_pwl_fits(pwl_module, num_channels, num_breakpoints, optimizer_fn):
134 |     module = pwl_module(
135 |         num_channels=num_channels, num_breakpoints=num_breakpoints)
136 |     bs = 128
137 |     opt = optimizer_fn(module.parameters())
138 |     steps = 4000
139 |     loss_ = 0
140 |     desired_loss = 0.02
141 |     for step in range(steps):
142 |         x = torch.Tensor(np.random.normal(0, scale=2, size=(bs, num_channels)))
143 |         expected_y = torch.Tensor(
144 |             np.random.normal(0, scale=0.1, size=(bs, num_channels)) +
145 |             np.where(x > 0.2, x, 0.2))
146 | 
147 |         y = module(x)
148 |         loss = torch.mean((expected_y - y)**2)
149 |         opt.zero_grad()
150 |         loss.backward()
151 |         opt.step()
152 |         if step % 10 == 0:
153 |             print(loss.item())
154 |         loss_ = 0.8 * loss_ + loss.item() * 0.2
155 |         if loss_ < desired_loss:
156 |             break
157 |     assert loss_ < desired_loss
158 | 
159 | 
160 | @pytest.mark.parametrize("pwl_module", [PWL, MonoPWL])
161 | @pytest.mark.parametrize("input_shape", [
162 |     (11, 5, 7, 3),
163 |     (11, 6),
164 |     (11, 1),
165 |     (11, 1, 1, 1),
166 |     (5, 1, 2, 1),
167 |     (5, 2, 2, 1),
168 |     (5, 2, 2),
169 | ])
170 | def test_input_packing(pwl_module, input_shape):
171 |     num_channels = input_shape[1]
172 |     b = pwl_module(num_channels=num_channels, num_breakpoints=2)
173 |     inp = torch.Tensor(*input_shape).normal_()
174 |     unpacked_inp = b.unpack_input(inp)
175 |     assert unpacked_inp.shape[1] == num_channels
176 |     assert len(unpacked_inp.shape) == 2
177 |     inp_restored = b.repack_input(unpacked_inp, inp.shape)
178 |     assert list(inp_restored.shape) == list(inp.shape)
179 |     assert torch.max(torch.abs(inp_restored - inp)).item() < TOLERANCE
180 | 
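# The suite above targets the bundled torchpwl copy; assuming pytest is
# installed, it can be run from the repository root with:
#     python -m pytest torchpwl/pwl_test.py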
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | ADMM-CSNet training example (v1) with MR slices
3 | By Yan Yang, Jian Sun, Huibin Li, Zongben Xu
4 | 
5 | Please cite the below paper for the code:
6 | 
7 | Yan Yang, Jian Sun, Huibin Li, Zongben Xu. ADMM-CSNet: A Deep Learning Approach for Image Compressive Sensing,
8 | TPAMI(2019).
9 | """
10 | from __future__ import print_function, division
11 | import sys
12 | import os
13 | import torch
14 | import argparse
15 | from network.CSNet_Layers import CSNetADMMLayer
16 | from utils.dataset import get_data
17 | import torch.utils.data as data
18 | from utils.my_loss import MyLoss
19 | import time
20 | from utils.metric import complex_psnr
21 | from tensorboardX import SummaryWriter
22 | import gc
23 | import torchvision.utils as utils
24 | from scipy.io import loadmat
25 | from utils.fftc import *
26 | from os.path import join
27 | 
28 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
29 | 
30 | if __name__ == '__main__':
31 | 
32 |     ###############################################################################
33 |     # parameters
34 |     ###############################################################################
35 |     parser = argparse.ArgumentParser(description=' main ')
36 |     parser.add_argument('--data_dir', default='data/', type=str,
37 |                         help='directory of data')
38 |     parser.add_argument('--batch_size', default=1, type=int, help='batch size')
39 |     parser.add_argument('--num_epoch', default=1000, type=int, help='number of epochs')
40 |     parser.add_argument('--outf', type=str, default='logs_csnet', help='path of log files')
41 |     args = parser.parse_args()
42 | 
43 |     ###############################################################################
44 |     # callable methods
45 |     ###############################################################################
46 | 
47 |     def adjust_learning_rate(opt, epo, lr):
48 |         """Halves the initial learning rate every 50 epochs"""
49 |         lr = lr * (0.5 ** (epo // 50))
50 |         for param_group in opt.param_groups:
51 |             param_group['lr'] = lr
52 | 
53 | 
54 |     ###############################################################################
55 |     # dataset
56 |     ###############################################################################
57 |     train, test, validate = get_data(args.data_dir)
58 |     len_train, len_test, len_validate = len(train), len(test), len(validate)
59 |     print("len_train: ", len_train, "\tlen_test:", len_test, "\tlen_validate:", len_validate)
60 |     train_loader = data.DataLoader(dataset=train, batch_size=args.batch_size, shuffle=False, num_workers=4,
61 |                                    pin_memory=False)
62 |     test_loader = data.DataLoader(dataset=test, batch_size=args.batch_size, shuffle=False, num_workers=4,
63 |                                   pin_memory=False)
64 |     valid_loader = data.DataLoader(dataset=validate, batch_size=args.batch_size, shuffle=False, num_workers=4,
65 |                                    pin_memory=False)
66 | 
67 |     ###############################################################################
68 |     # mask
69 |     ###############################################################################
70 |     dir = 'data/mask_20'
71 |     data = loadmat(join(dir, os.listdir(dir)[0]))
72 |     mask_data = data['mask']
73 |     mask = ifftshift(torch.Tensor(mask_data)).cuda()
74 | 
75 |     ###############################################################################
76 |     # ADMM-CSNET model
77 |     ###############################################################################
78 |     
model = CSNetADMMLayer(mask).cuda() 79 | 80 | ############################################################################### 81 | # Adam optimizer 82 | ############################################################################### 83 | optimizer = torch.optim.Adam(model.parameters()) 84 | 85 | ############################################################################### 86 | # self-define loss 87 | ############################################################################### 88 | criterion = MyLoss().cuda() 89 | 90 | writer = SummaryWriter(args.outf) 91 | ############################################################################### 92 | # train 93 | ############################################################################### 94 | print("start training...") 95 | start_time = time.time() 96 | for epoch in range(0, args.num_epoch + 1): 97 | total_loss_org = 0 98 | train_batches = 0 99 | train_psnr = 0 100 | adjust_learning_rate(optimizer, epoch, lr=0.002) 101 | # ===================train========================== 102 | for batch_idx, (label, num) in enumerate(train_loader): 103 | full_kspace = torch.fft.fft2(label.cuda()) 104 | output = model(full_kspace) 105 | optimizer.zero_grad() 106 | loss_normal = criterion(output, label.cuda()) 107 | loss_normal.backward() 108 | optimizer.step() 109 | total_loss_org += loss_normal.data.item() 110 | train_batches += 1 111 | train_psnr_value = complex_psnr(abs(output).cpu().detach().numpy(), abs(label).cpu().detach().numpy(), 112 | peak='normalized') 113 | train_psnr += train_psnr_value 114 | print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" % 115 | (epoch + 1, batch_idx + 1, len(train_loader), total_loss_org / (batch_idx + 1), 116 | train_psnr / (batch_idx + 1))) 117 | train_psnr /= train_batches 118 | total_loss_org /= train_batches 119 | print("train_loss: ", total_loss_org) 120 | print("train_psnr: ", train_psnr) 121 | writer.add_scalar('psnr on train data', train_psnr, epoch) 122 | if epoch % 10 == 0: 123 | torch.save(model.state_dict(), 124 | os.path.join(args.outf, 'model{}.pth'.format(epoch))) 125 | model.eval() 126 | ############################################################################### 127 | # validate 128 | ############################################################################### 129 | validate_err = 0 130 | validate_psnr = 0 131 | validate_batches = 0 132 | with torch.no_grad(): 133 | for batch_idx, (label, num) in enumerate(valid_loader): 134 | gc.collect() 135 | torch.cuda.empty_cache() 136 | full_kspace = torch.fft.fft2(label.cuda()) 137 | val_output = model(full_kspace) 138 | validate_loss_normal = criterion(val_output, label.cuda()) 139 | validate_err += validate_loss_normal.item() 140 | validate_batches += 1 141 | valid_psnr_value = complex_psnr(abs(val_output).cpu().numpy(), abs(label).cpu().numpy(), 142 | peak='normalized') 143 | validate_psnr += valid_psnr_value 144 | if epoch % 10 == 0: 145 | resconstructed_image = utils.make_grid(abs(val_output.data.squeeze().cpu()), nrow=5, normalize=True, 146 | scale_each=True) 147 | writer.add_image('reconstructed image', resconstructed_image, epoch) 148 | 149 | validate_err /= validate_batches 150 | validate_psnr /= validate_batches 151 | print("valid_loss ", validate_err) 152 | print("valid_psnr ", validate_psnr) 153 | writer.add_scalar('psnr on valid data', validate_psnr, epoch) 154 | ############################################################################### 155 | # test 156 | ############################################################################### 157 
| test_err = 0 158 | test_psnr = 0 159 | test_batches = 0 160 | model.eval() 161 | for batch_idx, (label, num) in enumerate(test_loader): 162 | gc.collect() 163 | with torch.no_grad(): 164 | full_kspace = torch.fft.fft2(label.cuda()) 165 | test_output = model(full_kspace) 166 | test_loss_normal = criterion(test_output, label.cuda()) 167 | test_err += test_loss_normal.item() 168 | test_batches += 1 169 | test_psnr_value = complex_psnr(abs(test_output).cpu().numpy(), abs(label).cpu().numpy(), 170 | peak='normalized') 171 | test_psnr += test_psnr_value 172 | test_err /= test_batches 173 | test_psnr /= test_batches 174 | print("test_loss ", test_err) 175 | print("test_psnr ", test_psnr) 176 | writer.add_scalar('psnr on test data', test_psnr, epoch) 177 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import torch.utils.data as data 3 | import os 4 | import os.path 5 | from scipy.io import loadmat 6 | import re 7 | import numpy as np 8 | import torch 9 | import h5py 10 | 11 | 12 | class DataSet(data.Dataset): 13 | 14 | def __init__(self, dir): 15 | imgs = os.listdir(dir) 16 | imgs.sort() 17 | self.imgs = [os.path.join(dir, img) for img in imgs] 18 | self.num = [re.sub("\D", "", img) for img in imgs] 19 | 20 | def __getitem__(self, index): 21 | num = self.num[index] 22 | img_path = self.imgs[index] 23 | data = loadmat(img_path) 24 | 25 | data_label = data['data'][0][0][0][0][0][0] 26 | label_real = torch.Tensor(data_label.real) 27 | label_real = label_real.view(1, 256, 256) 28 | label_imag = torch.Tensor(data_label.imag) 29 | label_imag = label_imag.view(1, 256, 256) 30 | 31 | label = torch.complex(label_real, label_imag) 32 | return label, num 33 | 34 | def __len__(self): 35 | return len(self.imgs) 36 | 37 | 38 | def get_data(load_root): 39 | 40 | train = load_root + 'train' 41 | test = load_root + 'test' 42 | validate = load_root + 'validate' 43 | train_data = DataSet(train) 44 | test_data = DataSet(test) 45 | validate_data = DataSet(validate) 46 | return train_data, test_data, validate_data 47 | 48 | 49 | def get_test_data(load_root): 50 | 51 | test = load_root + 'test' 52 | test_data = DataSet(test) 53 | return test_data -------------------------------------------------------------------------------- /utils/fftc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | This source code is licensed under the MIT license found in the 4 | LICENSE file in the root directory of this source tree. 5 | """ 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | from packaging import version 11 | 12 | if version.parse(torch.__version__) >= version.parse("1.7.0"): 13 | import torch.fft # type: ignore 14 | 15 | 16 | def fft2c_old(data: torch.Tensor) -> torch.Tensor: 17 | """ 18 | Apply centered 2 dimensional Fast Fourier Transform. 19 | Args: 20 | data: Complex valued input data containing at least 3 dimensions: 21 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 22 | 2. All other dimensions are assumed to be batch dimensions. 23 | Returns: 24 | The FFT of the input. 
25 | """ 26 | if not data.shape[-1] == 2: 27 | raise ValueError("Tensor does not have separate complex dim.") 28 | 29 | data = ifftshift(data, dim=[-3, -2]) 30 | data = torch.fft(data, 2, normalized=True) 31 | data = fftshift(data, dim=[-3, -2]) 32 | 33 | return data 34 | 35 | 36 | def ifft2c_old(data: torch.Tensor) -> torch.Tensor: 37 | """ 38 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 39 | Args: 40 | data: Complex valued input data containing at least 3 dimensions: 41 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 42 | 2. All other dimensions are assumed to be batch dimensions. 43 | Returns: 44 | The IFFT of the input. 45 | """ 46 | if not data.shape[-1] == 2: 47 | raise ValueError("Tensor does not have separate complex dim.") 48 | 49 | data = ifftshift(data, dim=[-3, -2]) 50 | data = torch.ifft(data, 2, normalized=True) 51 | data = fftshift(data, dim=[-3, -2]) 52 | 53 | return data 54 | 55 | 56 | def fft2c_new(data: torch.Tensor) -> torch.Tensor: 57 | """ 58 | Apply centered 2 dimensional Fast Fourier Transform. 59 | Args: 60 | data: Complex valued input data containing at least 3 dimensions: 61 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 62 | 2. All other dimensions are assumed to be batch dimensions. 63 | Returns: 64 | The FFT of the input. 65 | """ 66 | if not data.shape[-1] == 2: 67 | raise ValueError("Tensor does not have separate complex dim.") 68 | 69 | data = ifftshift(data, dim=[-3, -2]) 70 | data = torch.view_as_real( 71 | torch.fft.fftn( # type: ignore 72 | torch.view_as_complex(data), dim=(-2, -1), norm="ortho" 73 | ) 74 | ) 75 | data = fftshift(data, dim=[-3, -2]) 76 | 77 | return data 78 | 79 | 80 | def ifft2c_new(data: torch.Tensor) -> torch.Tensor: 81 | """ 82 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 83 | Args: 84 | data: Complex valued input data containing at least 3 dimensions: 85 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 86 | 2. All other dimensions are assumed to be batch dimensions. 87 | Returns: 88 | The IFFT of the input. 89 | """ 90 | if not data.shape[-1] == 2: 91 | raise ValueError("Tensor does not have separate complex dim.") 92 | 93 | data = ifftshift(data, dim=[-3, -2]) 94 | data = torch.view_as_real( 95 | torch.fft.ifftn( # type: ignore 96 | torch.view_as_complex(data), dim=(-2, -1), norm="ortho" 97 | ) 98 | ) 99 | data = fftshift(data, dim=[-3, -2]) 100 | 101 | return data 102 | 103 | 104 | # Helper functions 105 | 106 | 107 | def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor: 108 | """ 109 | Similar to roll but for only one dim. 110 | Args: 111 | x: A PyTorch tensor. 112 | shift: Amount to roll. 113 | dim: Which dimension to roll. 114 | Returns: 115 | Rolled version of x. 116 | """ 117 | shift = shift % x.size(dim) 118 | if shift == 0: 119 | return x 120 | 121 | left = x.narrow(dim, 0, x.size(dim) - shift) 122 | right = x.narrow(dim, x.size(dim) - shift, shift) 123 | 124 | return torch.cat((right, left), dim=dim) 125 | 126 | 127 | def roll( 128 | x: torch.Tensor, 129 | shift: List[int], 130 | dim: List[int], 131 | ) -> torch.Tensor: 132 | """ 133 | Similar to np.roll but applies to PyTorch Tensors. 134 | Args: 135 | x: A PyTorch tensor. 136 | shift: Amount to roll. 137 | dim: Which dimension to roll. 138 | Returns: 139 | Rolled version of x. 
140 | """ 141 | if len(shift) != len(dim): 142 | raise ValueError("len(shift) must match len(dim)") 143 | 144 | for (s, d) in zip(shift, dim): 145 | x = roll_one_dim(x, s, d) 146 | 147 | return x 148 | 149 | 150 | def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 151 | """ 152 | Similar to np.fft.fftshift but applies to PyTorch Tensors 153 | Args: 154 | x: A PyTorch tensor. 155 | dim: Which dimension to fftshift. 156 | Returns: 157 | fftshifted version of x. 158 | """ 159 | if dim is None: 160 | # this weird code is necessary for toch.jit.script typing 161 | dim = [0] * (x.dim()) 162 | for i in range(1, x.dim()): 163 | dim[i] = i 164 | 165 | # also necessary for torch.jit.script 166 | shift = [0] * len(dim) 167 | for i, dim_num in enumerate(dim): 168 | shift[i] = x.shape[dim_num] // 2 169 | 170 | return roll(x, shift, dim) 171 | 172 | 173 | def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor: 174 | """ 175 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 176 | Args: 177 | x: A PyTorch tensor. 178 | dim: Which dimension to ifftshift. 179 | Returns: 180 | ifftshifted version of x. 181 | """ 182 | if dim is None: 183 | # this weird code is necessary for toch.jit.script typing 184 | dim = [0] * (x.dim()) 185 | for i in range(1, x.dim()): 186 | dim[i] = i 187 | 188 | # also necessary for torch.jit.script 189 | shift = [0] * len(dim) 190 | for i, dim_num in enumerate(dim): 191 | shift[i] = (x.shape[dim_num] + 1) // 2 192 | 193 | return roll(x, shift, dim) -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import math 4 | 5 | def mse(x, y): 6 | return np.mean(np.abs(x - y)**2) 7 | 8 | 9 | def psnr(x, y): 10 | ''' 11 | Measures the PSNR of recon w.r.t x. 12 | Image must be of either integer (0, 256) or float value (0,1) 13 | :param x: [m,n] 14 | :param y: [m,n] 15 | :return: 16 | ''' 17 | assert x.shape == y.shape 18 | assert x.dtype == y.dtype or np.issubdtype(x.dtype, np.float) \ 19 | and np.issubdtype(y.dtype, np.float) 20 | if x.dtype == np.uint8: 21 | max_intensity = 256 22 | else: 23 | max_intensity = 1 24 | 25 | mse = np.sum((x - y) ** 2).astype(float) / x.size 26 | return 20 * np.log10(max_intensity) - 10 * np.log10(mse) 27 | 28 | 29 | def complex_psnr(x, y, peak='normalized'): 30 | ''' 31 | x: reference image 32 | y: reconstructed image 33 | peak: normalised or max 34 | Notice that ``abs'' squares 35 | Be careful with the order, since peak intensity is taken from the reference 36 | image (taking from reconstruction yields a different value). 
37 |     '''
38 |     # a = np.abs(x - y) ** 2
39 |     mse = np.mean(np.abs(x - y)**2)
40 |     if peak == 'max':
41 |         return 10*np.log10(np.max(np.abs(x))**2/mse)
42 |     else:
43 |         return 10*np.log10(1./mse + 1e-5)  # assumes peak intensity 1 for 'normalized' inputs
44 | 
45 | def nrmse(outputs, targets):
46 |     """
47 |     Normalized root-mean square error
48 |     :param outputs: Module's outputs
49 |     :param targets: Target signal to be learned
50 |     :return: Normalized root-mean square deviation
51 |     """
52 |     # Flatten tensors
53 |     outputs = outputs.reshape(-1)
54 |     targets = targets.reshape(-1)
55 | 
56 |     # Check dim
57 |     if outputs.size() != targets.size():
58 |         raise ValueError(u"Outputs and targets tensors don't have the same number of elements")
59 |     # end if
60 | 
61 |     # Normalization with N-1
62 |     var = torch.std(targets) ** 2
63 | 
64 |     # Error
65 |     error = (targets - outputs) ** 2
66 | 
67 |     # Return
68 |     return float(math.sqrt(torch.mean(error) / var))
--------------------------------------------------------------------------------
/utils/my_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class MyLoss(torch.nn.Module):
4 |     def __init__(self):
5 |         super(MyLoss, self).__init__()
6 | 
7 |     def forward(self, output, target):
8 |         # normalized Frobenius-norm error: ||output - target||_F / ||target||_F
9 |         return torch.norm((output - target), 'fro') / torch.norm(target, 'fro')
10 | 
--------------------------------------------------------------------------------