├── .ipynb_checkpoints ├── Compressed sensing-checkpoint.ipynb └── MRI-checkpoint.ipynb ├── ASGLD_PR.py ├── Compressed sensing.ipynb ├── MRI.ipynb ├── PRImageSet ├── .DS_Store ├── Natural256 │ ├── .directory │ ├── barbara.png │ ├── boat.png │ ├── cameraman.png │ ├── couple.png │ ├── peppers.png │ └── streamandbridge.png └── Unatural256 │ ├── Butterfly.png │ ├── Ecoli.png │ ├── PillarsofCreation.png │ ├── Pollen.png │ ├── TadpoleGalaxy.png │ └── Yeast.png ├── README.md ├── cs_nonoise.py ├── data ├── CS │ └── Set11 │ │ ├── Monarch.tif │ │ ├── Parrots.tif │ │ ├── barbara.tif │ │ ├── boats.tif │ │ ├── cameraman.tif │ │ ├── fingerprint.tif │ │ ├── flinstones.tif │ │ ├── foreman.tif │ │ ├── house.tif │ │ ├── lena256.tif │ │ └── peppers256.tif ├── MRI │ └── 321.png └── mask │ ├── mask1d0.25.mat │ ├── mask1s0.25.mat │ └── mask2r0.25.mat ├── models ├── __init__.py ├── __pycache__ │ ├── Qskip.cpython-37.pyc │ ├── Qskip.cpython-38.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── common.cpython-37.pyc │ ├── common.cpython-38.pyc │ ├── decoder.cpython-37.pyc │ ├── decoder.cpython-38.pyc │ ├── downsampler.cpython-37.pyc │ ├── downsampler.cpython-38.pyc │ ├── drop_skip.cpython-37.pyc │ ├── drop_skip.cpython-38.pyc │ ├── fcn.cpython-37.pyc │ ├── fcn.cpython-38.pyc │ ├── multi_skip.cpython-37.pyc │ ├── multi_skip.cpython-38.pyc │ ├── resnet.cpython-37.pyc │ ├── resnet.cpython-38.pyc │ ├── skip.cpython-37.pyc │ ├── skip.cpython-38.pyc │ ├── texture_nets.cpython-37.pyc │ ├── texture_nets.cpython-38.pyc │ ├── unet.cpython-37.pyc │ └── unet.cpython-38.pyc ├── common.py ├── downsampler.py ├── resnet.py ├── skip.py └── unet.py ├── requirements.txt └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── common_mri.cpython-37.pyc ├── common_mri.cpython-38.pyc ├── common_utils.cpython-37.pyc ├── common_utils.cpython-38.pyc ├── denoising_utils.cpython-37.pyc ├── denoising_utils.cpython-38.pyc ├── inpainting_utils.cpython-38.pyc ├── sr_utils.cpython-37.pyc ├── sr_utils.cpython-38.pyc ├── transform_mri.cpython-37.pyc └── transform_mri.cpython-38.pyc ├── common_utils.py ├── denoising_utils.py └── transform_mri.py /ASGLD_PR.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | from __future__ import print_function 4 | import argparse 5 | from datetime import datetime 6 | import os, sys 7 | import cv2 8 | import numpy as np 9 | from models import * 10 | import math 11 | import glob 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim 15 | from torch.optim.lr_scheduler import MultiStepLR 16 | import scipy.io as sio 17 | 18 | from skimage.measure.simple_metrics import compare_psnr 19 | from skimage.measure import compare_ssim 20 | from utils.common_utils import * 21 | from torch.distributions import Poisson 22 | import matplotlib.pyplot as plt 23 | 24 | parser = argparse.ArgumentParser(description="PR_Poisson_noise") 25 | parser.add_argument("--imgs-dir", type=str, default='Unatural256', help="directory of testing images") 26 | parser.add_argument("--masktype", type=str, default='bipolar', help="type of mask") 27 | parser.add_argument("--noisetype", type=str, default='poisson', help="type of noise") 28 | parser.add_argument("--nummask", type=int, default=3, help="number of masks") 29 | parser.add_argument("--gpuid", type=int, default=0, help="gpu id") 30 | parser.add_argument("--numits", type=int, default=10000,help="Number of training iterations") 31 | 32 | 
opts = parser.parse_args()
33 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opts.gpuid)
34 | torch.set_num_threads(8)
35 | 
36 | torch.backends.cudnn.enabled = True
37 | torch.backends.cudnn.benchmark = True
38 | dtype = torch.cuda.FloatTensor
39 | 
40 | # inject Gaussian noise into every conv weight (the Langevin/SGLD perturbation)
41 | def add_noise(model, nlevel):
42 |     for n in [x for x in model.parameters() if len(x.size()) == 4]:
43 |         noise = torch.randn(n.size())*nlevel
44 |         noise = noise.type(dtype)
45 |         n.data = n.data + noise
46 | 
47 | class Logger(object):
48 |     def __init__(self, fileN="Default.log"):
49 |         self.terminal = sys.__stdout__
50 |         self.log = open(fileN, "a+")
51 | 
52 |     def write(self, message):
53 |         self.terminal.write(message)
54 |         self.log.write(message)
55 |         self.flush()
56 |         # self.close()
57 |     def flush(self):
58 |         self.log.flush()
59 | 
60 | 
61 | input_depth = 32
62 | lr = 0.01  # learning rate
63 | 
64 | INPUT = 'noise'
65 | 
66 | reg_noise_std = 0.01  # perturbing the network input helps performance
67 | psnr_trace = []
68 | psnr_noavg_trace = []
69 | loss_trace = []
70 | trace = {}
71 | 
72 | burnin_iter = 2000
73 | weight_decay = 5e-8  # helps performance
74 | 
75 | sgld_mean_each = 0
76 | sgld_mean_tmp = 0
77 | 
78 | class Ax_pr_cdp(nn.Module):
79 |     def __init__(self, c, h, w, masktype):
80 |         super(Ax_pr_cdp, self).__init__()
81 |         self.c = c
82 |         self.h = h
83 |         self.w = w
84 |         if masktype == 'uniform':
85 |             ang = 2*torch.acos(torch.zeros(1)).item()*2*torch.rand(c, h, w, 1)  # uniform phase in [0, 2*pi)
86 |             self.mask = torch.cat((torch.cos(ang), torch.sin(ang)), 3).cuda()
87 |         elif masktype == 'bipolar':
88 |             self.mask = 2.0*(torch.bernoulli(0.5*torch.ones(c, h, w, 2))-0.5).cuda()
89 | 
90 |     def forward(self, img):
91 |         img_s = img.reshape(1, self.h, self.w, 1).repeat(self.c, 1, 1, 2)*self.mask
92 |         meas = torch.fft(img_s, 2, normalized=True)
93 |         meas = torch.sqrt(meas[:, :, :, 0]**2 + meas[:, :, :, 1]**2)
94 |         return meas
95 | 
96 | if opts.noisetype == 'poisson':
97 |     alpha = 9  # noise level, chosen from [9, 27, 81]
98 |     if opts.imgs_dir == 'Natural256':
99 |         nlevel = 1e-10  # tunable; depends on the noise level and the accuracy of the estimated noise
100 |     else:
101 |         nlevel = 1e-12  # the smaller the alpha value, the smaller nlevel should be
102 | elif opts.noisetype == 'awgn':
103 |     alpha = 10  # [10, 15, 20]
104 |     if opts.imgs_dir == 'Natural256':
105 |         nlevel = 1e-5  # tunable; depends on the noise level and the accuracy of the estimated noise
106 |     else:
107 |         nlevel = 1e-7  # generally, for Unatural256 images a smaller nlevel gives better performance
108 | else:
109 |     alpha = 0
110 |     nlevel = 0.0001
111 | 
112 | file_name = glob.glob('./PRImageSet/'+opts.imgs_dir+'/*.png')
113 | data_num = len(file_name)
114 | psnr_a = np.zeros((data_num,))
115 | ssim_a = np.zeros((data_num,))
116 | 
117 | 
118 | psnr_aver = 0.0
119 | ssim_aver = 0.0
120 | num_iter = opts.numits
121 | for fi, Img_Name in enumerate(file_name):
122 |     MODEL_PATH = './Github_PR_est_noise_poi_FINAL/PR_Recon_Results/%s_%s/%s_alpha%2d/%s/' % (opts.noisetype, opts.masktype, opts.imgs_dir, alpha, Img_Name[Img_Name.rfind('/')+1:-4])  # one folder per noise type, mask type, image set, noise level and image
123 |     if not os.path.exists(MODEL_PATH):
124 |         os.makedirs(MODEL_PATH)
125 |     sgld_mean_each = 0
126 |     sgld_mean_tmp = 0
127 |     img = np.array(cv2.imread(Img_Name, -1), dtype=np.float32)/255.
128 |     if img.ndim == 2:
129 |         Img = np.expand_dims(img, axis=0)
130 |     else:
131 |         Img = img.transpose(2,0,1)
132 |     c,w,h = Img.shape
133 |     Ax = Ax_pr_cdp(opts.nummask, w, h, opts.masktype)
134 |     Img = np.expand_dims(Img, axis=1)
135 |     Img_tensor = torch.FloatTensor(Img).cuda()
136 | 
137 |     meas_nf = Ax(Img_tensor)
138 |     if opts.noisetype == 'poisson':
139 |         meas_nf_before = meas_nf
140 |         intensity_noise = alpha/255.0*(meas_nf)*torch.randn(meas_nf.shape).cuda()
141 | 
142 |         y2 = meas_nf**2 + intensity_noise
143 |         y2 = torch.clamp(y2, min=0.0)
144 | 
145 |         Measure = torch.sqrt(y2)
146 |         error = torch.mean((Measure - meas_nf_before)**2)  # oracle noise level, not accessible in practice
147 |         error_est = (alpha/255.0)**2/4  #/(Measure.shape[0]*Measure.shape[1]*Measure.shape[2])
148 |         sys.stdout = Logger(MODEL_PATH+'results.txt')
149 |         print('error {:f} error_est {:f}'.format(error, error_est))
150 |     elif opts.noisetype == 'awgn':
151 |         noise_std = torch.randn(meas_nf.shape).cuda()
152 |         noise = noise_std*torch.norm(meas_nf)/torch.norm(noise_std)/float(np.sqrt(10.0**(alpha/10.0)))
153 | 
154 |         y2 = meas_nf + noise
155 |         Measure = torch.clamp(y2, min=0.0)
156 |         error = torch.mean((Measure - meas_nf)**2)
157 |         error_est = torch.mean(Measure**2)/(float(np.sqrt(10.0**(alpha/10.0))))**2
158 |         sys.stdout = Logger(MODEL_PATH+'results.txt')
159 |         print('error {:f} error_est {:f}'.format(error, error_est))
160 |     else:
161 |         sys.stdout = Logger(MODEL_PATH+'results.txt')
162 |         print('no noise')
163 |         Measure = meas_nf
164 |         error = 0
165 | 
166 | 
167 |     img_name = Img_Name
168 |     net_input = get_noise(input_depth, INPUT, (w,h)).type(dtype).detach()
169 | 
170 |     NET_TYPE = 'skip'  # choices: skip, UNet, ResNet
171 |     net = get_net(input_depth, NET_TYPE, 'reflection',
172 |                   skip_n33d=128,
173 |                   skip_n33u=128,
174 |                   skip_n11=4,
175 |                   num_scales=5,
176 |                   upsample_mode='bilinear',
177 |                   n_channels=c).type(dtype)
178 | 
179 |     # Losses
180 |     mse = torch.nn.MSELoss().type(dtype)
181 | 
182 | 
183 |     # Define closure and optimize
184 | 
185 |     def closure():
186 |         global i, num_iter, net_input, psnr_trace, psnr_noavg_trace, loss_trace, trace, sgld_mean_tmp, sgld_mean_each
187 | 
188 |         if reg_noise_std > 0:
189 |             net_input = net_input_saved + (noise.normal_() * reg_noise_std)
190 |         else:
191 |             net_input = net_input_saved
192 | 
193 |         Img_rec = net(net_input)
194 | 
195 |         net_output = Ax(Img_rec)
196 | 
197 | 
198 |         total_loss = mse(net_output, Measure)
199 | 
200 | 
201 |         total_loss.backward()
202 | 
203 |         if i > burnin_iter:
204 |             sgld_mean_each += Img_rec.detach()  # detach so the running mean does not retain the autograd graph
205 |             sgld_mean_tmp = sgld_mean_each / (i-burnin_iter)
206 |         else:
207 |             sgld_mean_tmp = Img_rec
208 | 
209 |         if (i + 1) % 10 == 0:
210 |             psnr_noavg = compare_psnr(np.squeeze(torch_to_np(Img_rec)), img, 1.)
211 |             ssim_noavg = compare_ssim(np.squeeze(torch_to_np(Img_rec)), img, data_range=1.)
212 |             psnr = compare_psnr(np.squeeze(torch_to_np(sgld_mean_tmp)), img, 1.)
213 |             ssim = compare_ssim(np.squeeze(torch_to_np(sgld_mean_tmp)), img, data_range=1.)
214 | now = datetime.now() 215 | sys.stdout = Logger(MODEL_PATH+'results.txt') 216 | print(img_name, "loss in ", i + 1, ":", total_loss.item(),"psnr_noavg:",psnr_noavg, "psnr:",psnr, now.strftime("%H:%M:%S")) 217 | 218 | loss_trace = np.append(loss_trace,total_loss.item()) 219 | psnr_trace= np.append(psnr_trace,psnr) 220 | psnr_noavg_trace = np.append(psnr_noavg_trace,psnr_noavg) 221 | 222 | trace['loss'] = loss_trace 223 | 224 | trace['psnr'] = psnr_trace 225 | trace['psnr_noavg'] = psnr_noavg_trace 226 | 227 | sio.savemat(MODEL_PATH+'trace.mat',trace) 228 | 229 | if i == num_iter - 1: 230 | # save the image 231 | plt.plot(psnr_trace) 232 | plt.plot(psnr_noavg_trace) 233 | plt.savefig(os.path.join(MODEL_PATH,'comp.png')) 234 | if (i + 1) % 500 == 0 : 235 | cv2.imwrite(MODEL_PATH+img_name[img_name.rfind('/')+1:-4] + '_bipolar_poisson_%2d_itn%d_%0.4f_%0.2f.png'%(alpha,i+1,psnr,ssim), np.int32(255*np.squeeze(torch_to_np(sgld_mean_tmp)))) 236 | 237 | i += 1 238 | 239 | return total_loss 240 | 241 | 242 | net_input_saved = net_input.detach().clone() 243 | noise = net_input.detach().clone() 244 | i = 0 245 | 246 | sys.stdout = Logger(MODEL_PATH+'results.txt') 247 | print('Starting optimization with ASGLD') 248 | optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay) 249 | scheduler = MultiStepLR(optimizer, milestones=[5000, 6000, 7000], gamma=0.5) # learning rates 250 | stop_iter = -1 251 | for j in range(num_iter): 252 | scheduler.step(j) 253 | optimizer.zero_grad() 254 | loss = closure() 255 | optimizer.step() 256 | nlevel_ = nlevel*math.exp(30*((error_est/loss)-1)) 257 | add_noise(net, nlevel_) 258 | 259 | 260 | 261 | 262 | 263 | Img_rec = np.squeeze(torch_to_np(sgld_mean_tmp)) 264 | 265 | torch.save({ 266 | 'net': net, 267 | 'net_input': net_input, 268 | 'iters': i + 1, 269 | 'net_state_dict': net.state_dict(), 270 | }, MODEL_PATH +img_name[img_name.rfind('/')+1:-4]+ 'checkpoint_best.pth') 271 | 272 | psnr_single = compare_psnr(img.reshape(h,w),Img_rec.reshape(h,w),1.) 273 | ssim_single = compare_ssim(img.reshape(h,w),Img_rec.reshape(h,w),data_range = 1.) #compare_ssim(Img.reshape(h,w),Img_rec/255.0,data_range = 1.) 
274 | psnr_aver+=psnr_single 275 | ssim_aver+=ssim_single 276 | psnr_a[fi]=psnr_single 277 | ssim_a[fi]=ssim_single 278 | sys.stdout = Logger(MODEL_PATH+'psnr.txt') 279 | print("alpha:", alpha, img_name, "psnr:",psnr_single,ssim_single) 280 | 281 | psnr_aver /= len(file_name) 282 | ssim_aver /= len(file_name) 283 | sys.stdout = Logger(MODEL_PATH+'psnr.txt') 284 | print("alpha:", alpha, "average psnr over the image set:", psnr_aver,"average ssim over the image set:", ssim_aver,'\n') 285 | sio.savemat('./Github_PR_est_noise_poi_FINAL/PR_Recon_Results/%s_PSNR_SSIM_%s_%s_DIP.mat'%(opts.imgs_dir,opts.masktype,opts.noisetype), {'psnr_a':psnr_a,'ssim_a':ssim_a}) 286 | 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /PRImageSet/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/.DS_Store -------------------------------------------------------------------------------- /PRImageSet/Natural256/.directory: -------------------------------------------------------------------------------- 1 | [Dolphin] 2 | PreviewsShown=true 3 | Timestamp=2021,3,2,14,16,1 4 | Version=3 5 | -------------------------------------------------------------------------------- /PRImageSet/Natural256/barbara.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Natural256/barbara.png -------------------------------------------------------------------------------- /PRImageSet/Natural256/boat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Natural256/boat.png -------------------------------------------------------------------------------- /PRImageSet/Natural256/cameraman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Natural256/cameraman.png -------------------------------------------------------------------------------- /PRImageSet/Natural256/couple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Natural256/couple.png -------------------------------------------------------------------------------- /PRImageSet/Natural256/peppers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Natural256/peppers.png -------------------------------------------------------------------------------- /PRImageSet/Natural256/streamandbridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Natural256/streamandbridge.png -------------------------------------------------------------------------------- /PRImageSet/Unatural256/Butterfly.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Unatural256/Butterfly.png
--------------------------------------------------------------------------------
/PRImageSet/Unatural256/Ecoli.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Unatural256/Ecoli.png
--------------------------------------------------------------------------------
/PRImageSet/Unatural256/PillarsofCreation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Unatural256/PillarsofCreation.png
--------------------------------------------------------------------------------
/PRImageSet/Unatural256/Pollen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Unatural256/Pollen.png
--------------------------------------------------------------------------------
/PRImageSet/Unatural256/TadpoleGalaxy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Unatural256/TadpoleGalaxy.png
--------------------------------------------------------------------------------
/PRImageSet/Unatural256/Yeast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/PRImageSet/Unatural256/Yeast.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adaptive Sampling
2 | This is the code for the CVPR 2022 paper "Self-supervised Deep Image Restoration via Adaptive Stochastic Gradient Langevin Dynamics".
3 | ## Usage instructions
4 | The main entry points are ```ASGLD_PR.py``` (phase retrieval), ```cs_nonoise.py``` (compressed sensing) and the ```Compressed sensing.ipynb```/```MRI.ipynb``` notebooks; an example phase-retrieval run is given below.
5 | ## Installation instructions, including any requirements
6 | See ```requirements.txt``` for the dependent packages and libraries.
7 | 
8 | + Use ```virtualenv``` to construct the virtual environment
9 | ```python
10 | pip3 install virtualenv
11 | virtualenv --no-site-packages --python=python3 ASGLD
12 | source ASGLD/bin/activate # enter the environment
13 | pip3 install -r requirements.txt # install the dependencies
14 | deactivate
15 | ```
16 | ## Citation
17 | If you find our work useful in your research or publication, please cite it:
18 | 
19 | ```
20 | @inproceedings{wang2022self,
21 |   title={Self-supervised deep image restoration via adaptive stochastic gradient langevin dynamics},
22 |   author={Wang, Weixi and Li, Ji and Ji, Hui},
23 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
24 |   pages={1989--1998},
25 |   year={2022}
26 | }
27 | ```
28 | 
29 | ## Phase Retrieval
30 | 
31 | If you have any questions about the code, please feel free to contact [Ji Li](mailto:matliji@nus.edu.sg).
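
A minimal example run, sketched from the command-line flags defined at the top of ```ASGLD_PR.py``` (the values shown are the script defaults: bipolar masks, Poisson noise, 3 masks, 10000 iterations; reconstructions and logs are written under ```Github_PR_est_noise_poi_FINAL/PR_Recon_Results/```):
```
python ASGLD_PR.py --imgs-dir Unatural256 --masktype bipolar --noisetype poisson --nummask 3 --gpuid 0 --numits 10000
```
For compressed sensing, ```cs_nonoise.py``` takes no flags: it reads the images in ```data/CS/Set11/``` and writes results to ```result/Compressed_sensing/no_noise/``` (create this directory first, since the script does not create it):
```
python cs_nonoise.py
```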
32 | 33 | -------------------------------------------------------------------------------- /cs_nonoise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch.nn.functional as F 4 | fname = 'data/CS/Set11/' 5 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 7 | 8 | import numpy as np 9 | from models import * 10 | 11 | import torch 12 | import torch.optim 13 | 14 | from skimage.measure import compare_psnr 15 | from utils.denoising_utils import * 16 | 17 | torch.backends.cudnn.enabled = True 18 | torch.backends.cudnn.benchmark =True 19 | dtype = torch.cuda.FloatTensor 20 | 21 | imsize =-1 22 | PLOT = False 23 | psnr_total=[] 24 | Imgname=sorted(os.listdir(fname)) 25 | kk=0 26 | from scipy.linalg import sqrtm 27 | n = 33*33 28 | m = int(n/2.5) 29 | A = torch.empty(n,m).normal_(0, 1) 30 | B=torch.mm(torch.transpose(A,1,0),A) 31 | B=torch.inverse(B) 32 | B=sqrtm(B) 33 | B=torch.Tensor(B) 34 | A=torch.mm(A,B) 35 | A=A.type(dtype) 36 | for imgn in Imgname: 37 | kk=kk+1 38 | img_pil = crop_image(get_image(fname+imgn, imsize)[0], d=1) 39 | img_np=pil_to_np(img_pil) 40 | img_np = img_np 41 | img_var=torch.tensor(img_np).type(dtype) 42 | block_size=33 43 | c,w,h = img_var.shape 44 | pad_right = block_size - w%block_size 45 | pad_bottom = block_size - h%block_size 46 | padd = (0,pad_bottom,0,pad_right) 47 | img_var=img_var.unsqueeze(0) 48 | def forwardm(Img,Phi_input,pad,block_size): 49 | Img_pad = F.pad(Img, pad, mode='constant', value=0) 50 | 51 | p,c,w,h = Img_pad.size() 52 | Img_col = torch.reshape(Img_pad,(p,c,-1,block_size,h)) 53 | n = Img_col.size()[2] 54 | Img_col = Img_col.reshape((p,c,n,block_size,-1,block_size)) 55 | Img_col = Img_col.permute(0,1,2,4,3,5) 56 | Img_col = Img_col.reshape(p,c,-1,block_size*block_size) 57 | 58 | Img_cs = torch.matmul(Img_col, Phi_input) 59 | return Img_cs 60 | measurement=forwardm(img_var,A,padd,33) 61 | 62 | reg_noise_std = 1./40. 
63 | LR = 0.002 64 | 65 | show_every = 100 66 | exp_weight=0.999 67 | 68 | num_iter = 20000 69 | input_depth = 16 70 | figsize = 4 71 | INPUT = 'noise' # 'meshgrid' 72 | pad = 'reflection' 73 | OPT_OVER = 'net' # 'net,input' 74 | 75 | net = get_net(input_depth,'skip', pad, 76 | skip_n33d=128, 77 | skip_n33u=128, 78 | skip_n11=4, 79 | num_scales=5, 80 | n_channels=1, 81 | upsample_mode='bilinear').type(dtype) 82 | net_input = get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach() 83 | # Compute number of parameters 84 | s = sum([np.prod(list(p.size())) for p in net.parameters()]); 85 | print ('Number of params: %d' % s) 86 | 87 | # Loss 88 | mse = torch.nn.MSELoss().type(dtype) 89 | img_torch = np_to_torch(img_np).type(dtype) 90 | net_input_saved = net_input.detach().clone() 91 | noise = net_input.detach().clone() 92 | out_avg = None 93 | 94 | last_net = None 95 | psrn_noisy_last = 0 96 | i = 0 97 | burn_in=5000 98 | psrn_gt_sm=0 99 | 100 | def closure(): 101 | 102 | global i, out_avg, psrn_noisy_last, last_net, net_input,img_torch,psrn_gt_sm,loss_data 103 | 104 | if reg_noise_std > 0: 105 | net_input = net_input_saved + (noise.normal_() * reg_noise_std) 106 | 107 | out = net(net_input) 108 | #out_2 = out[:,3:6,:,:] 109 | #out = out[:,0:3,:,:] 110 | # Smoothing 111 | if i>burn_in: 112 | if out_avg is None: 113 | out_avg=out.detach() 114 | else: 115 | #out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight) 116 | out_avg = out_avg * (i-burn_in-1)/(i-burn_in) + out.detach() * 1/(i-burn_in) 117 | #noise_disturb= get_noise(3, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach() 118 | total_loss = mse(forwardm(out,A,padd,33),measurement) 119 | total_loss.backward() 120 | psrn_gt = compare_psnr(img_np, out.detach().cpu().numpy()[0]) 121 | if i % show_every==0: 122 | print('psnr:',psrn_gt) 123 | if i>burn_in: 124 | psrn_gt_sm = compare_psnr(img_np, out_avg.detach().cpu().numpy()[0]) 125 | if i % show_every==0 and i>burn_in: 126 | print(imgn) 127 | print(kk) 128 | print(i) 129 | print(psrn_gt_sm) 130 | print(psnr_total) 131 | print(sum(psnr_total)/kk) 132 | i += 1 133 | 134 | return total_loss 135 | 136 | 137 | optimizer = Adam(0.000001,[{'params':net.parameters()}], lr=LR) # for 40% sampling ratio, 138 | #optimizer = Adam(0.000002,[{'params':net.parameters()}], lr=LR) for 25% sampling ratio, 139 | #optimizer = Adam(0.000003,[{'params':net.parameters()}], lr=LR) for 10% sampling ratio, 140 | for _ in range(num_iter): 141 | nlevel=1 142 | optimizer.zero_grad() 143 | closure() 144 | optimizer.step(1) 145 | path='result/Compressed_sensing/no_noise/' 146 | path=path+imgn 147 | out_img=torch.clamp(out_avg.squeeze().cpu(),0.,1.) 
148 | out_img=out_img.detach().numpy() 149 | print(out_img.shape) 150 | cv2.imwrite(path,255*out_img) 151 | psnr_total.append(psrn_gt_sm) 152 | print(psnr_total) 153 | print(sum(psnr_total)/11) -------------------------------------------------------------------------------- /data/CS/Set11/Monarch.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/Monarch.tif -------------------------------------------------------------------------------- /data/CS/Set11/Parrots.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/Parrots.tif -------------------------------------------------------------------------------- /data/CS/Set11/barbara.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/barbara.tif -------------------------------------------------------------------------------- /data/CS/Set11/boats.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/boats.tif -------------------------------------------------------------------------------- /data/CS/Set11/cameraman.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/cameraman.tif -------------------------------------------------------------------------------- /data/CS/Set11/fingerprint.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/fingerprint.tif -------------------------------------------------------------------------------- /data/CS/Set11/flinstones.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/flinstones.tif -------------------------------------------------------------------------------- /data/CS/Set11/foreman.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/foreman.tif -------------------------------------------------------------------------------- /data/CS/Set11/house.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/house.tif -------------------------------------------------------------------------------- /data/CS/Set11/lena256.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/lena256.tif -------------------------------------------------------------------------------- /data/CS/Set11/peppers256.tif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/CS/Set11/peppers256.tif -------------------------------------------------------------------------------- /data/MRI/321.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/MRI/321.png -------------------------------------------------------------------------------- /data/mask/mask1d0.25.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/mask/mask1d0.25.mat -------------------------------------------------------------------------------- /data/mask/mask1s0.25.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/mask/mask1s0.25.mat -------------------------------------------------------------------------------- /data/mask/mask2r0.25.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/data/mask/mask2r0.25.mat -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .skip import skip 2 | from .resnet import ResNet 3 | from .unet import UNet 4 | 5 | import torch.nn as nn 6 | 7 | 8 | def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, downsample_mode='stride',p=0.1): 9 | if NET_TYPE == 'ResNet': 10 | # TODO 11 | net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False) 12 | elif NET_TYPE == 'skip': 13 | net = skip(input_depth, n_channels, num_channels_down = [skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d, 14 | num_channels_up = [skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u, 15 | num_channels_skip = [skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11, 16 | upsample_mode=upsample_mode, downsample_mode=downsample_mode, 17 | need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun) 18 | 19 | elif NET_TYPE =='UNet': 20 | net = UNet(num_input_channels=input_depth, num_output_channels=3, 21 | feature_scale=4, more_layers=0, concat_x=False, 22 | upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True) 23 | elif NET_TYPE == 'identity': 24 | assert input_depth == 3 25 | net = nn.Sequential() 26 | else: 27 | assert False 28 | 29 | return net -------------------------------------------------------------------------------- /models/__pycache__/Qskip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/Qskip.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/Qskip.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/Qskip.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/decoder.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/decoder.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/downsampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/downsampler.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/downsampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/downsampler.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/drop_skip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/drop_skip.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/drop_skip.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/drop_skip.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/fcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/fcn.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/fcn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/fcn.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/multi_skip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/multi_skip.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/multi_skip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/multi_skip.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/skip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/skip.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/skip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/skip.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/texture_nets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/texture_nets.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/texture_nets.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/texture_nets.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/unet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/models/__pycache__/unet.cpython-38.pyc -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from siren_pytorch import Sine 5 | act_sin = Sine(1.) 6 | from .downsampler import Downsampler 7 | 8 | def add_module(self, module): 9 | self.add_module(str(len(self) + 1), module) 10 | 11 | torch.nn.Module.add = add_module 12 | 13 | class Concat(nn.Module): 14 | def __init__(self, dim, *args): 15 | super(Concat, self).__init__() 16 | self.dim = dim 17 | 18 | for idx, module in enumerate(args): 19 | self.add_module(str(idx), module) 20 | 21 | def forward(self, input): 22 | inputs = [] 23 | for module in self._modules.values(): 24 | inputs.append(module(input)) 25 | 26 | inputs_shapes2 = [x.shape[2] for x in inputs] 27 | inputs_shapes3 = [x.shape[3] for x in inputs] 28 | 29 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)): 30 | inputs_ = inputs 31 | else: 32 | target_shape2 = min(inputs_shapes2) 33 | target_shape3 = min(inputs_shapes3) 34 | 35 | inputs_ = [] 36 | for inp in inputs: 37 | diff2 = (inp.size(2) - target_shape2) // 2 38 | diff3 = (inp.size(3) - target_shape3) // 2 39 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]) 40 | 41 | return torch.cat(inputs_, dim=self.dim) 42 | 43 | def __len__(self): 44 | return len(self._modules) 45 | 46 | 47 | class GenNoise(nn.Module): 48 | def __init__(self, dim2): 49 | super(GenNoise, self).__init__() 50 | self.dim2 = dim2 51 | 52 | def forward(self, input): 53 | a = list(input.size()) 54 | a[1] = self.dim2 55 | # print (input.data.type()) 56 | 57 | b = torch.zeros(a).type_as(input.data) 58 | b.normal_() 59 | 60 | x = torch.autograd.Variable(b) 61 | 62 | return x 63 | 64 | 65 | class Swish(nn.Module): 66 | """ 67 | https://arxiv.org/abs/1710.05941 68 | The hype was so huge that I could not help but try it 69 | """ 70 | def __init__(self): 71 | super(Swish, self).__init__() 72 | self.s = nn.Sigmoid() 73 | 74 | def forward(self, x): 75 | return x * self.s(x) 76 | 77 | 78 | def act(act_fun = 'LeakyReLU'): 79 | ''' 80 | Either string defining an activation function or module (e.g. 
nn.ReLU) 81 | ''' 82 | if isinstance(act_fun, str): 83 | if act_fun == 'LeakyReLU': 84 | return nn.LeakyReLU(0.2, inplace=True) 85 | elif act_fun == 'Swish': 86 | return Swish() 87 | elif act_fun == 'ELU': 88 | return nn.ELU() 89 | elif act_fun == 'Sin': 90 | return act_sin 91 | elif act_fun == 'none': 92 | return nn.Sequential() 93 | else: 94 | assert False 95 | else: 96 | return act_fun() 97 | 98 | 99 | def bn(num_features): 100 | return nn.BatchNorm2d(num_features) 101 | 102 | 103 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'): 104 | downsampler = None 105 | if stride != 1 and downsample_mode != 'stride': 106 | 107 | if downsample_mode == 'avg': 108 | downsampler = nn.AvgPool2d(stride, stride) 109 | elif downsample_mode == 'max': 110 | downsampler = nn.MaxPool2d(stride, stride) 111 | elif downsample_mode in ['lanczos2', 'lanczos3']: 112 | downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True) 113 | else: 114 | assert False 115 | 116 | stride = 1 117 | 118 | padder = None 119 | to_pad = int((kernel_size - 1) / 2) 120 | if pad == 'reflection': 121 | padder = nn.ReflectionPad2d(to_pad) 122 | to_pad = 0 123 | 124 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 125 | 126 | 127 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 128 | return nn.Sequential(*layers) 129 | 130 | class multi(nn.Module): 131 | def __init__(self,channels): 132 | super(multi, self).__init__() 133 | layers = [] 134 | layers.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1, bias=False)) 135 | layers.append(nn.ReLU(inplace=True)) 136 | self.net = nn.Sequential(*layers) 137 | def forward(self, input): 138 | return input*self.net(input) -------------------------------------------------------------------------------- /models/downsampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Downsampler(nn.Module): 6 | ''' 7 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 8 | ''' 9 | def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False): 10 | super(Downsampler, self).__init__() 11 | 12 | assert phase in [0, 0.5], 'phase should be 0 or 0.5' 13 | 14 | if kernel_type == 'lanczos2': 15 | support = 2 16 | kernel_width = 4 * factor + 1 17 | kernel_type_ = 'lanczos' 18 | 19 | elif kernel_type == 'lanczos3': 20 | support = 3 21 | kernel_width = 6 * factor + 1 22 | kernel_type_ = 'lanczos' 23 | 24 | elif kernel_type == 'gauss12': 25 | kernel_width = 7 26 | sigma = 1/2 27 | kernel_type_ = 'gauss' 28 | 29 | elif kernel_type == 'gauss1sq2': 30 | kernel_width = 9 31 | sigma = 1./np.sqrt(2) 32 | kernel_type_ = 'gauss' 33 | 34 | elif kernel_type in ['lanczos', 'gauss', 'box']: 35 | kernel_type_ = kernel_type 36 | 37 | else: 38 | assert False, 'wrong name kernel' 39 | 40 | 41 | # note that `kernel width` will be different to actual size for phase = 1/2 42 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma) 43 | 44 | downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0) 45 | downsampler.weight.data[:] = 0 46 | downsampler.bias.data[:] = 0 47 | 48 | kernel_torch = torch.from_numpy(self.kernel) 49 | for i in range(n_planes): 50 | 
downsampler.weight.data[i, i] = kernel_torch 51 | 52 | self.downsampler_ = downsampler 53 | 54 | if preserve_size: 55 | 56 | if self.kernel.shape[0] % 2 == 1: 57 | pad = int((self.kernel.shape[0] - 1) / 2.) 58 | else: 59 | pad = int((self.kernel.shape[0] - factor) / 2.) 60 | 61 | self.padding = nn.ReplicationPad2d(pad) 62 | 63 | self.preserve_size = preserve_size 64 | 65 | def forward(self, input): 66 | if self.preserve_size: 67 | x = self.padding(input) 68 | else: 69 | x= input 70 | self.x = x 71 | return self.downsampler_(x) 72 | 73 | def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None): 74 | assert kernel_type in ['lanczos', 'gauss', 'box'] 75 | 76 | # factor = float(factor) 77 | if phase == 0.5 and kernel_type != 'box': 78 | kernel = np.zeros([kernel_width - 1, kernel_width - 1]) 79 | else: 80 | kernel = np.zeros([kernel_width, kernel_width]) 81 | 82 | 83 | if kernel_type == 'box': 84 | assert phase == 0.5, 'Box filter is always half-phased' 85 | kernel[:] = 1./(kernel_width * kernel_width) 86 | 87 | elif kernel_type == 'gauss': 88 | assert sigma, 'sigma is not specified' 89 | assert phase != 0.5, 'phase 1/2 for gauss not implemented' 90 | 91 | center = (kernel_width + 1.)/2. 92 | print(center, kernel_width) 93 | sigma_sq = sigma * sigma 94 | 95 | for i in range(1, kernel.shape[0] + 1): 96 | for j in range(1, kernel.shape[1] + 1): 97 | di = (i - center)/2. 98 | dj = (j - center)/2. 99 | kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq)) 100 | kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq) 101 | elif kernel_type == 'lanczos': 102 | assert support, 'support is not specified' 103 | center = (kernel_width + 1) / 2. 104 | 105 | for i in range(1, kernel.shape[0] + 1): 106 | for j in range(1, kernel.shape[1] + 1): 107 | 108 | if phase == 0.5: 109 | di = abs(i + 0.5 - center) / factor 110 | dj = abs(j + 0.5 - center) / factor 111 | else: 112 | di = abs(i - center) / factor 113 | dj = abs(j - center) / factor 114 | 115 | 116 | pi_sq = np.pi * np.pi 117 | 118 | val = 1 119 | if di != 0: 120 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) 121 | val = val / (np.pi * np.pi * di * di) 122 | 123 | if dj != 0: 124 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) 125 | val = val / (np.pi * np.pi * dj * dj) 126 | 127 | kernel[i - 1][j - 1] = val 128 | 129 | 130 | else: 131 | assert False, 'wrong method name' 132 | 133 | kernel /= kernel.sum() 134 | 135 | return kernel 136 | 137 | #a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True) 138 | 139 | 140 | 141 | 142 | 143 | 144 | ################# 145 | # Learnable downsampler 146 | 147 | # KS = 32 148 | # dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor)) 149 | 150 | # class Apply(nn.Module): 151 | # def __init__(self, what, dim, *args): 152 | # super(Apply, self).__init__() 153 | # self.dim = dim 154 | 155 | # self.what = what 156 | 157 | # def forward(self, input): 158 | # inputs = [] 159 | # for i in range(input.size(self.dim)): 160 | # inputs.append(self.what(input.narrow(self.dim, i, 1))) 161 | 162 | # return torch.cat(inputs, dim=self.dim) 163 | 164 | # def __len__(self): 165 | # return len(self._modules) 166 | 167 | # downs = Apply(dow, 1) 168 | # downs.type(dtype)(net_input.type(dtype)).size() 169 | -------------------------------------------------------------------------------- /models/resnet.py: 
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from numpy.random import normal
4 | from numpy.linalg import svd
5 | from math import sqrt
6 | import torch.nn.init
7 | from .common import *
8 | 
9 | class ResidualSequential(nn.Sequential):
10 |     def __init__(self, *args):
11 |         super(ResidualSequential, self).__init__(*args)
12 | 
13 |     def forward(self, x):
14 |         out = super(ResidualSequential, self).forward(x)
15 |         # print(x.size(), out.size())
16 |         x_ = None
17 |         if out.size(2) != x.size(2) or out.size(3) != x.size(3):
18 |             diff2 = x.size(2) - out.size(2)
19 |             diff3 = x.size(3) - out.size(3)
20 |             # center-crop the input; slice indices must be integers in Python 3, hence //
21 |             x_ = x[:, :, diff2 // 2:out.size(2) + diff2 // 2, diff3 // 2:out.size(3) + diff3 // 2]
22 |         else:
23 |             x_ = x
24 |         return out + x_
25 | 
26 |     def eval(self):
27 |         # delegate to the standard nn.Module implementation,
28 |         # which recursively switches all submodules to eval mode
29 |         return super(ResidualSequential, self).eval()
30 | 
31 | 
32 | 
33 | def get_block(num_channels, norm_layer, act_fun):
34 |     layers = [
35 |         nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False),
36 |         norm_layer(num_channels, affine=True),
37 |         act(act_fun),
38 |         nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False),
39 |         norm_layer(num_channels, affine=True),
40 |     ]
41 |     return layers
42 | 
43 | 
44 | class ResNet(nn.Module):
45 |     def __init__(self, num_input_channels, num_output_channels, num_blocks, num_channels, need_residual=True, act_fun='LeakyReLU', need_sigmoid=True, norm_layer=nn.BatchNorm2d, pad='reflection'):
46 |         '''
47 |         pad = 'start|zero|replication'
48 |         '''
49 |         super(ResNet, self).__init__()
50 | 
51 |         if need_residual:
52 |             s = ResidualSequential
53 |         else:
54 |             s = nn.Sequential
55 | 
56 |         stride = 1
57 |         # First layers
58 |         layers = [
59 |             # nn.ReplicationPad2d(num_blocks * 2 * stride + 3),
60 |             conv(num_input_channels, num_channels, 3, stride=1, bias=True, pad=pad),
61 |             act(act_fun)
62 |         ]
63 |         # Residual blocks
64 |         # layers_residual = []
65 |         for i in range(num_blocks):
66 |             layers += [s(*get_block(num_channels, norm_layer, act_fun))]
67 | 
68 |         layers += [
69 |             nn.Conv2d(num_channels, num_channels, 3, 1, 1),
70 |             norm_layer(num_channels, affine=True)
71 |         ]
72 | 
73 |         # if need_residual:
74 |         #    layers += [ResidualSequential(*layers_residual)]
75 |         # else:
76 |         #    layers += [Sequential(*layers_residual)]
77 | 
78 |         # if factor >= 2:
79 |         #     # Do upsampling if needed
80 |         #     layers += [
81 |         #         nn.Conv2d(num_channels, num_channels *
82 |         #         factor ** 2, 3, 1),
83 |         #         nn.PixelShuffle(factor),
84 |         #         act(act_fun)
85 |         #     ]
86 |         layers += [
87 |             conv(num_channels, num_output_channels, 3, 1, bias=True, pad=pad),
88 |             nn.Sigmoid()
89 |         ]
90 |         self.model = nn.Sequential(*layers)
91 | 
92 |     def forward(self, input):
93 |         return self.model(input)
94 | 
95 |     def eval(self):
96 |         self.model.eval()
97 | 
--------------------------------------------------------------------------------
/models/skip.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .common import *
3 | 
4 | def skip(
5 |         num_input_channels=2, num_output_channels=3, 
6 |         num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 
7 |         filter_size_down=3, filter_size_up=3, filter_skip_size=1,
8 |         need_sigmoid=True, need_bias=True, 
9 |         pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 
10 |         need1x1_up=True):
11 |     """Assembles encoder-decoder with skip connections.
12 | 13 | Arguments: 14 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) 15 | pad (string): zero|reflection (default: 'zero') 16 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 17 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 18 | 19 | """ 20 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 21 | 22 | n_scales = len(num_channels_down) 23 | 24 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 25 | upsample_mode = [upsample_mode]*n_scales 26 | 27 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 28 | downsample_mode = [downsample_mode]*n_scales 29 | 30 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 31 | filter_size_down = [filter_size_down]*n_scales 32 | 33 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 34 | filter_size_up = [filter_size_up]*n_scales 35 | 36 | last_scale = n_scales - 1 37 | 38 | cur_depth = None 39 | 40 | model = nn.Sequential() 41 | model_tmp = model 42 | 43 | input_depth = num_input_channels 44 | for i in range(len(num_channels_down)): 45 | 46 | deeper = nn.Sequential() 47 | skip = nn.Sequential() 48 | 49 | if num_channels_skip[i] != 0: 50 | model_tmp.add(Concat(1, skip, deeper)) 51 | else: 52 | model_tmp.add(deeper) 53 | 54 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 55 | 56 | if num_channels_skip[i] != 0: 57 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 58 | skip.add(bn(num_channels_skip[i])) 59 | skip.add(act(act_fun)) 60 | 61 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 62 | 63 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 64 | deeper.add(bn(num_channels_down[i])) 65 | deeper.add(act(act_fun)) 66 | 67 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 68 | deeper.add(bn(num_channels_down[i])) 69 | deeper.add(act(act_fun)) 70 | 71 | deeper_main = nn.Sequential() 72 | 73 | if i == len(num_channels_down) - 1: 74 | # The deepest 75 | k = num_channels_down[i] 76 | else: 77 | deeper.add(deeper_main) 78 | k = num_channels_up[i + 1] 79 | 80 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 81 | 82 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 83 | model_tmp.add(bn(num_channels_up[i])) 84 | model_tmp.add(act(act_fun)) 85 | 86 | 87 | if need1x1_up: 88 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 89 | model_tmp.add(bn(num_channels_up[i])) 90 | model_tmp.add(act(act_fun)) 91 | 92 | input_depth = num_channels_down[i] 93 | model_tmp = deeper_main 94 | 95 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 96 | if need_sigmoid: 97 | model.add(nn.Sigmoid()) 98 | 99 | return model 100 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .common import * 6 | 7 | class ListModule(nn.Module): 8 | def __init__(self, *args): 9 | super(ListModule, self).__init__() 10 | idx 
= 0 11 | for module in args: 12 | self.add_module(str(idx), module) 13 | idx += 1 14 | 15 | def __getitem__(self, idx): 16 | if idx >= len(self._modules): 17 | raise IndexError('index {} is out of range'.format(idx)) 18 | if idx < 0: 19 | idx = len(self) + idx 20 | 21 | it = iter(self._modules.values()) 22 | for i in range(idx): 23 | next(it) 24 | return next(it) 25 | 26 | def __iter__(self): 27 | return iter(self._modules.values()) 28 | 29 | def __len__(self): 30 | return len(self._modules) 31 | 32 | class UNet(nn.Module): 33 | ''' 34 | upsample_mode in ['deconv', 'nearest', 'bilinear'] 35 | pad in ['zero', 'replication', 'none'] 36 | ''' 37 | def __init__(self, num_input_channels=3, num_output_channels=3, 38 | feature_scale=4, more_layers=0, concat_x=False, 39 | upsample_mode='deconv', pad='zero', norm_layer=nn.InstanceNorm2d, need_sigmoid=True, need_bias=True): 40 | super(UNet, self).__init__() 41 | 42 | self.feature_scale = feature_scale 43 | self.more_layers = more_layers 44 | self.concat_x = concat_x 45 | 46 | 47 | filters = [64, 128, 256, 512, 1024] 48 | filters = [x // self.feature_scale for x in filters] 49 | 50 | self.start = unetConv2(num_input_channels, filters[0] if not concat_x else filters[0] - num_input_channels, norm_layer, need_bias, pad) 51 | 52 | self.down1 = unetDown(filters[0], filters[1] if not concat_x else filters[1] - num_input_channels, norm_layer, need_bias, pad) 53 | self.down2 = unetDown(filters[1], filters[2] if not concat_x else filters[2] - num_input_channels, norm_layer, need_bias, pad) 54 | self.down3 = unetDown(filters[2], filters[3] if not concat_x else filters[3] - num_input_channels, norm_layer, need_bias, pad) 55 | self.down4 = unetDown(filters[3], filters[4] if not concat_x else filters[4] - num_input_channels, norm_layer, need_bias, pad) 56 | 57 | # more downsampling layers 58 | if self.more_layers > 0: 59 | self.more_downs = [ 60 | unetDown(filters[4], filters[4] if not concat_x else filters[4] - num_input_channels , norm_layer, need_bias, pad) for i in range(self.more_layers)] 61 | self.more_ups = [unetUp(filters[4], upsample_mode, need_bias, pad, same_num_filt =True) for i in range(self.more_layers)] 62 | 63 | self.more_downs = ListModule(*self.more_downs) 64 | self.more_ups = ListModule(*self.more_ups) 65 | 66 | self.up4 = unetUp(filters[3], upsample_mode, need_bias, pad) 67 | self.up3 = unetUp(filters[2], upsample_mode, need_bias, pad) 68 | self.up2 = unetUp(filters[1], upsample_mode, need_bias, pad) 69 | self.up1 = unetUp(filters[0], upsample_mode, need_bias, pad) 70 | 71 | self.final = conv(filters[0], num_output_channels, 1, bias=need_bias, pad=pad) 72 | 73 | if need_sigmoid: 74 | self.final = nn.Sequential(self.final, nn.Sigmoid()) 75 | 76 | def forward(self, inputs): 77 | 78 | # Downsample 79 | downs = [inputs] 80 | down = nn.AvgPool2d(2, 2) 81 | for i in range(4 + self.more_layers): 82 | downs.append(down(downs[-1])) 83 | 84 | in64 = self.start(inputs) 85 | if self.concat_x: 86 | in64 = torch.cat([in64, downs[0]], 1) 87 | 88 | down1 = self.down1(in64) 89 | if self.concat_x: 90 | down1 = torch.cat([down1, downs[1]], 1) 91 | 92 | down2 = self.down2(down1) 93 | if self.concat_x: 94 | down2 = torch.cat([down2, downs[2]], 1) 95 | 96 | down3 = self.down3(down2) 97 | if self.concat_x: 98 | down3 = torch.cat([down3, downs[3]], 1) 99 | 100 | down4 = self.down4(down3) 101 | if self.concat_x: 102 | down4 = torch.cat([down4, downs[4]], 1) 103 | 104 | if self.more_layers > 0: 105 | prevs = [down4] 106 | for kk, d in enumerate(self.more_downs): 107 | 
107 |                 # extra downsampling stages past the standard four
108 |                 out = d(prevs[-1])
109 |                 if self.concat_x:
110 |                     out = torch.cat([out, downs[kk + 5]], 1)
111 | 
112 |                 prevs.append(out)
113 | 
114 |             up_ = self.more_ups[-1](prevs[-1], prevs[-2])
115 |             for idx in range(self.more_layers - 1):
116 |                 l = self.more_ups[self.more_layers - idx - 2]
117 |                 up_ = l(up_, prevs[self.more_layers - idx - 2])
118 |         else:
119 |             up_ = down4
120 | 
121 |         up4 = self.up4(up_, down3)
122 |         up3 = self.up3(up4, down2)
123 |         up2 = self.up2(up3, down1)
124 |         up1 = self.up1(up2, in64)
125 | 
126 |         return self.final(up1)
127 | 
128 | 
129 | 
130 | class unetConv2(nn.Module):
131 |     def __init__(self, in_size, out_size, norm_layer, need_bias, pad):
132 |         super(unetConv2, self).__init__()
133 | 
134 | 
135 |         if norm_layer is not None:
136 |             self.conv1 = nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad),
137 |                                        norm_layer(out_size),
138 |                                        nn.ReLU(),)
139 |             self.conv2 = nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad),
140 |                                        norm_layer(out_size),
141 |                                        nn.ReLU(),)
142 |         else:
143 |             self.conv1 = nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad),
144 |                                        nn.ReLU(),)
145 |             self.conv2 = nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad),
146 |                                        nn.ReLU(),)
147 |     def forward(self, inputs):
148 |         outputs = self.conv1(inputs)
149 |         outputs = self.conv2(outputs)
150 |         return outputs
151 | 
152 | 
153 | class unetDown(nn.Module):
154 |     def __init__(self, in_size, out_size, norm_layer, need_bias, pad):
155 |         super(unetDown, self).__init__()
156 |         self.conv = unetConv2(in_size, out_size, norm_layer, need_bias, pad)
157 |         self.down = nn.MaxPool2d(2, 2)
158 | 
159 |     def forward(self, inputs):
160 |         outputs = self.down(inputs)
161 |         outputs = self.conv(outputs)
162 |         return outputs
163 | 
164 | 
165 | class unetUp(nn.Module):
166 |     def __init__(self, out_size, upsample_mode, need_bias, pad, same_num_filt=False):
167 |         super(unetUp, self).__init__()
168 | 
169 |         num_filt = out_size if same_num_filt else out_size * 2
170 |         if upsample_mode == 'deconv':
171 |             self.up = nn.ConvTranspose2d(num_filt, out_size, 4, stride=2, padding=1)
172 |             self.conv = unetConv2(out_size * 2, out_size, None, need_bias, pad)
173 |         elif upsample_mode == 'bilinear' or upsample_mode == 'nearest':
174 |             self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode=upsample_mode),
175 |                                     conv(num_filt, out_size, 3, bias=need_bias, pad=pad))
176 |             self.conv = unetConv2(out_size * 2, out_size, None, need_bias, pad)
177 |         else:
178 |             assert False, 'unknown upsample_mode: %s' % upsample_mode
179 | 
180 |     def forward(self, inputs1, inputs2):
181 |         in1_up = self.up(inputs1)
182 |         # crop the skip connection if its spatial size disagrees with the upsampled map
183 |         if (inputs2.size(2) != in1_up.size(2)) or (inputs2.size(3) != in1_up.size(3)):
184 |             diff2 = (inputs2.size(2) - in1_up.size(2)) // 2
185 |             diff3 = (inputs2.size(3) - in1_up.size(3)) // 2
186 |             inputs2_ = inputs2[:, :, diff2 : diff2 + in1_up.size(2), diff3 : diff3 + in1_up.size(3)]
187 |         else:
188 |             inputs2_ = inputs2
189 | 
190 |         output = self.conv(torch.cat([in1_up, inputs2_], 1))
191 | 
192 |         return output
193 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | torchvision==0.7.0
3 | numpy==1.21.4
4 | scikit-image==0.17.2
5 | scipy==1.5.0
6 | Pillow==8.4.0
--------------------------------------------------------------------------------
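
A minimal usage sketch for the UNet above (not part of the repository; the channel counts and the 256x256 input are illustrative assumptions, and the spatial size must be divisible by 16 because of the four pooling stages):

    import torch
    from models.unet import UNet

    net = UNet(num_input_channels=1, num_output_channels=1)  # defaults: feature_scale=4, 'deconv' upsampling, sigmoid output
    x = torch.randn(1, 1, 256, 256)
    y = net(x)                                               # -> torch.Size([1, 1, 256, 256]), values in (0, 1)
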
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wang-weixi/Adaptive_Sampling/7141e27a5579e97b7161a457cc9f4fd5000867dc/utils/__init__.py
--------------------------------------------------------------------------------
/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | from PIL import Image
5 | 
6 | import math
7 | import matplotlib.pyplot as plt
8 | 
9 | def crop_image(img, d=32):
10 |     '''Make dimensions divisible by `d`'''
11 | 
12 |     new_size = (img.size[0] - img.size[0] % d,
13 |                 img.size[1] - img.size[1] % d)
14 | 
15 |     bbox = [
16 |         int((img.size[0] - new_size[0])/2),
17 |         int((img.size[1] - new_size[1])/2),
18 |         int((img.size[0] + new_size[0])/2),
19 |         int((img.size[1] + new_size[1])/2),
20 |     ]
21 | 
22 |     img_cropped = img.crop(bbox)
23 |     return img_cropped
24 | 
25 | def get_params(opt_over, net, net_kernel, downsampler=None):
26 |     '''Returns parameters that we want to optimize over.
27 | 
28 |     Args:
29 |         opt_over: comma separated list, e.g. "net,net_kernel" or "net"
30 |         net: main network
31 |         net_kernel: auxiliary kernel network (used when opt_over contains "net_kernel")
32 |     '''
33 |     opt_over_list = opt_over.split(',')
34 |     params = []
35 | 
36 |     for opt in opt_over_list:
37 | 
38 |         if opt == 'net':
39 |             params += [x for x in net.parameters()]
40 |         elif opt == 'down':
41 |             assert downsampler is not None
42 |             params += [x for x in downsampler.parameters()]
43 |         elif opt == 'net_kernel':
44 |             params += [x for x in net_kernel.parameters()]
45 |         else:
46 |             assert False, 'unknown option in opt_over: %s' % opt
47 | 
48 |     return params
49 | 
50 | def get_image_grid(images_np, nrow=8):
51 |     '''Creates a grid from a list of images by concatenating them.'''
52 |     images_torch = [torch.from_numpy(x) for x in images_np]
53 |     torch_grid = torchvision.utils.make_grid(images_torch, nrow)
54 | 
55 |     return torch_grid.numpy()
56 | 
57 | def plot_image_grid(images_np, nrow=8, factor=1, interpolation='lanczos'):
58 |     """Draws images in a grid
59 | 
60 |     Args:
61 |         images_np: list of images, each image is np.array of size 3xHxW or 1xHxW
62 |         nrow: how many images will be in one row
63 |         factor: size of the plt.figure
64 |         interpolation: interpolation used in plt.imshow
65 |     """
66 |     n_channels = max(x.shape[0] for x in images_np)
67 |     assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"
68 | 
69 |     images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
70 | 
71 |     grid = get_image_grid(images_np, nrow)
72 | 
73 |     plt.figure(figsize=(len(images_np) + factor, 12 + factor))
74 | 
75 |     if images_np[0].shape[0] == 1:
76 |         plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
77 |     else:
78 |         plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
79 | 
80 |     plt.show()
81 | 
82 |     return grid
83 | 
84 | def load(path):
85 |     """Load PIL image."""
86 |     img = Image.open(path)
87 |     return img
88 | 
89 | def get_image(path, imsize=-1):
90 |     """Load an image and resize to a specific size.
91 | 
92 |     Args:
93 |         path: path to image
94 |         imsize: tuple or scalar with dimensions; -1 for `no resize`
95 |     """
96 |     img = load(path)
97 | 
98 |     if isinstance(imsize, int):
99 |         imsize = (imsize, imsize)
100 | 
101 |     if imsize[0] != -1 and img.size != imsize:
102 |         if imsize[0] > img.size[0]:
103 |             img = img.resize(imsize, Image.BICUBIC)
104 |         else:
105 |             img = img.resize(imsize, Image.ANTIALIAS)  # valid for the pinned Pillow==8.4.0 (alias of LANCZOS)
106 | 
107 |     img_np = pil_to_np(img)
108 | 
109 |     return img, img_np
110 | 
111 | 
112 | 
113 | def fill_noise(x, noise_type):
114 |     """Fills tensor `x` with noise of type `noise_type`."""
115 |     if noise_type == 'u':
116 |         x.uniform_()
117 |     elif noise_type == 'n':
118 |         x.normal_()
119 |     else:
120 |         assert False, 'unknown noise_type: %s' % noise_type
121 | 
122 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10):
123 |     """Returns a torch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`)
124 |     initialized in a specific way.
125 |     Args:
126 |         input_depth: number of channels in the tensor
127 |         method: `noise` for filling the tensor with noise; `meshgrid` for np.meshgrid
128 |         spatial_size: spatial size of the tensor to initialize
129 |         noise_type: 'u' for uniform; 'n' for normal
130 |         var: factor the noise is multiplied by; essentially a standard-deviation scaler
131 |     """
132 |     if isinstance(spatial_size, int):
133 |         spatial_size = (spatial_size, spatial_size)
134 |     if method == 'noise':
135 |         shape = [1, input_depth, spatial_size[0], spatial_size[1]]
136 |         net_input = torch.zeros(shape)
137 | 
138 |         fill_noise(net_input, noise_type)
139 |         net_input *= var
140 |     elif method == 'meshgrid':
141 |         assert input_depth == 2
142 |         X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1))
143 |         meshgrid = np.concatenate([X[None, :], Y[None, :]])
144 |         net_input = np_to_torch(meshgrid)
145 |     else:
146 |         assert False, 'unknown method: %s' % method
147 | 
148 |     return net_input
149 | 
150 | def pil_to_np(img_PIL):
151 |     '''Converts image in PIL format to np.array.
152 | 
153 |     From W x H x C [0...255] to C x W x H [0..1]
154 |     '''
155 |     ar = np.array(img_PIL)
156 | 
157 |     if len(ar.shape) == 3:
158 |         ar = ar.transpose(2, 0, 1)
159 |     else:
160 |         ar = ar[None, ...]
161 | 
162 |     return ar.astype(np.float32) / 255.
163 | 
164 | def np_to_pil(img_np):
165 |     '''Converts image in np.array format to PIL image.
166 | 
167 |     From C x W x H [0..1] to W x H x C [0...255]
168 |     '''
169 |     ar = np.clip(img_np*255, 0, 255).astype(np.uint8)
170 | 
171 |     if img_np.shape[0] == 1:
172 |         ar = ar[0]
173 |     else:
174 |         ar = ar.transpose(1, 2, 0)
175 | 
176 |     return Image.fromarray(ar)
177 | 
178 | def np_to_torch(img_np):
179 |     '''Converts image in numpy.array to torch.Tensor.
180 | 
181 |     From C x W x H [0..1] to 1 x C x W x H [0..1] (adds a batch dimension)
182 |     '''
183 |     return torch.from_numpy(img_np)[None, :]
184 | 
185 | def torch_to_np(img_var):
186 |     '''Converts an image in torch.Tensor format to np.array.
187 | 
188 |     From 1 x C x W x H [0..1] to C x W x H [0..1]
189 |     '''
190 |     return img_var.detach().cpu().numpy()[0]
191 | 
192 | 
193 | class Adam(torch.optim.Optimizer):
194 |     r"""Implements the Adam algorithm, extended with loss-scaled gradient-noise injection (the ASGLD sampler used by this repository).
195 | 
196 |     It has been proposed in `Adam: A Method for Stochastic Optimization`_.
197 | 
198 |     Arguments:
199 |         params (iterable): iterable of parameters to optimize or dicts defining
200 |             parameter groups
201 |         lr (float, optional): learning rate (default: 1e-3)
202 |         betas (Tuple[float, float], optional): coefficients used for computing
203 |             running averages of gradient and its square (default: (0.9, 0.999))
204 |         eps (float, optional): term added to the denominator to improve
205 |             numerical stability (default: 1e-8)
206 |         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
207 |         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
208 |             algorithm from the paper `On the Convergence of Adam and Beyond`_
209 |             (default: False)
210 |         noise_level (float): scale of the perturbation injected into the first-moment estimate in step()
211 |     .. _Adam\: A Method for Stochastic Optimization:
212 |         https://arxiv.org/abs/1412.6980
213 |     .. _On the Convergence of Adam and Beyond:
214 |         https://openreview.net/forum?id=ryQu7f-RZ
215 |     """
216 | 
217 |     def __init__(self, noise_level, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
218 |                  weight_decay=0, amsgrad=False):
219 |         if not 0.0 <= lr:
220 |             raise ValueError("Invalid learning rate: {}".format(lr))
221 |         if not 0.0 <= eps:
222 |             raise ValueError("Invalid epsilon value: {}".format(eps))
223 |         if not 0.0 <= betas[0] < 1.0:
224 |             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
225 |         if not 0.0 <= betas[1] < 1.0:
226 |             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
227 |         if not 0.0 <= weight_decay:
228 |             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
229 |         defaults = dict(lr=lr, betas=betas, eps=eps,
230 |                         weight_decay=weight_decay, amsgrad=amsgrad)
231 |         super(Adam, self).__init__(params, defaults)
232 |         self.noise_level = noise_level
233 |     def __setstate__(self, state):
234 |         super(Adam, self).__setstate__(state)
235 |         for group in self.param_groups:
236 |             group.setdefault('amsgrad', False)
237 | 
238 |     @torch.no_grad()
239 |     # noise_type in {'gaussian', 'poisson', 'bernoulli', 'levy'}
240 |     def step(self, loss_data, noise_type='gaussian', closure=None):
241 |         """Performs a single optimization step.
242 | 
243 |         Arguments:
244 |             closure (callable, optional): A closure that reevaluates the model
245 |                 and returns the loss.
246 |         """
247 |         loss = None
248 | 
249 |         if closure is not None:
250 |             with torch.enable_grad():
251 |                 loss = closure()
252 | 
253 |         for group in self.param_groups:
254 |             for p in group['params']:
255 |                 if p.grad is None:
256 |                     continue
257 |                 grad = p.grad
258 |                 ssize = grad.size()
259 | 
260 |                 if grad.is_sparse:
261 |                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
262 |                 amsgrad = group['amsgrad']
263 | 
264 |                 # allocate the perturbation buffer on the GPU (a CUDA device is assumed throughout)
265 |                 noise = torch.cuda.FloatTensor(ssize)
266 |                 # shape the perturbation according to noise_type
267 |                 if noise_type == 'poisson':
268 |                     torch.randn(ssize, out=noise)
269 |                     poisson_noise = torch.poisson(grad)  # torch.poisson expects non-negative rates
270 |                     if torch.norm(poisson_noise) > torch.norm(grad):
271 |                         poisson_noise = poisson_noise - grad
272 |                     noise = 10*poisson_noise
273 |                 if noise_type == 'gaussian':
274 |                     torch.randn(ssize, out=noise)
275 |                 if noise_type == 'bernoulli':
276 |                     torch.ones(ssize, out=noise)
277 |                     noise = 0.1*noise
278 |                     torch.bernoulli(noise, out=noise)  # each entry kept with probability 0.1
279 |                     noise = -noise*grad
280 |                 if noise_type == 'levy':
281 |                     torch.randn(ssize, out=noise)
282 |                     poisson_noise = torch.poisson(grad)
283 |                     if torch.norm(poisson_noise) > torch.norm(grad):
284 |                         poisson_noise = poisson_noise - grad
285 |                     noise = 30*poisson_noise + noise  # heavy-tailed: Poisson plus Gaussian
286 | 
287 |                 state = self.state[p]
288 | 
289 |                 # State initialization
290 |                 if len(state) == 0:
291 |                     state['step'] = 0
292 |                     # Exponential moving average of gradient values
293 |                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
294 |                     # Exponential moving average of squared gradient values
295 |                     state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
296 |                     if amsgrad:
297 |                         # Maintains max of all exp. moving avg. of sq. grad. values
298 |                         state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
299 | 
300 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
301 |                 if amsgrad:
302 |                     max_exp_avg_sq = state['max_exp_avg_sq']
303 |                 beta1, beta2 = group['betas']
304 | 
305 |                 state['step'] += 1
306 |                 bias_correction1 = 1 - beta1 ** state['step']
307 |                 bias_correction2 = 1 - beta2 ** state['step']
308 | 
309 |                 if group['weight_decay'] != 0:
310 |                     grad = grad.add(p, alpha=group['weight_decay'])
311 | 
312 |                 # Decay the first and second moment running average coefficients
313 |                 # (the noise is added out-of-place below, so it perturbs this step only and is not accumulated in the stored state)
314 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
315 |                 exp_avg = exp_avg + (1 - beta1)*loss_data*self.noise_level*noise
316 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
317 |                 if amsgrad:
318 |                     # Maintains the maximum of all 2nd moment running avg. till now
319 |                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
320 |                     # Use the max. for normalizing running avg. of gradient
321 |                     denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
322 |                 else:
323 |                     denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
324 | 
325 |                 step_size = group['lr'] / bias_correction1
326 | 
327 |                 p.addcdiv_(exp_avg, denom, value=-step_size)
328 |         return loss
--------------------------------------------------------------------------------
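
A sketch of how the noise-injecting Adam above can drive a deep-image-prior style fit (illustrative only: the actual loop lives in ASGLD_PR.py; a CUDA device is required because step() allocates its noise buffer with torch.cuda.FloatTensor, and the target here is a random stand-in):

    import torch
    from models.unet import UNet
    from utils.common_utils import get_noise, Adam

    net = UNet(num_input_channels=32, num_output_channels=1).cuda()
    net_input = get_noise(32, 'noise', (256, 256)).cuda()   # 1 x 32 x 256 x 256 uniform noise, scaled by var=0.1
    target = torch.rand(1, 1, 256, 256).cuda()              # stand-in for a measurement-consistency target

    opt = Adam(noise_level=1e-5, params=net.parameters(), lr=0.01, weight_decay=5e-8)
    mse = torch.nn.MSELoss()

    for it in range(100):
        opt.zero_grad()
        loss = mse(net(net_input), target)
        loss.backward()
        opt.step(loss_data=loss.item(), noise_type='gaussian')  # injects (1-beta1)*loss*noise_level*noise into the first moment
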
/utils/denoising_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .common_utils import *
3 | 
4 | 
5 | 
6 | def get_noisy_image(img_np, sigma):
7 |     """Adds Gaussian noise to an image.
8 | 
9 |     Args:
10 |         img_np: image, np.array with values from 0 to 1
11 |         sigma: std of the noise
12 |     """
13 |     img_noisy_np = np.clip(img_np + np.random.normal(scale=sigma, size=img_np.shape), 0, 1).astype(np.float32)
14 |     img_noisy_pil = np_to_pil(img_noisy_np)
15 | 
16 |     return img_noisy_pil, img_noisy_np
17 | 
18 | def get_blur_image(img_path):
19 |     """Synthesizes a blurred image by averaging fifteen frames.
20 | 
21 |     Args:
22 |         img_path: directory holding frames named 1.png ... 15.png
23 |     Returns:
24 |         (PIL image, np.array): the frame average, clipped to [0, 1]
25 |     """
26 |     imsize = -1
27 |     # accumulate frame 1, then frames 2..15 (each cropped to dimensions divisible by 32)
28 |     acc = crop_image(get_image(img_path + '/1.png', imsize)[0], d=32)
29 |     acc = pil_to_np(acc)
30 |     for i in range(2, 16):
31 |         img_pil = crop_image(get_image(img_path + '/' + str(i) + '.png', imsize)[0], d=32)
32 |         acc = acc + pil_to_np(img_pil)
33 | 
34 |     img_blur_np = np.clip(acc / 15, 0, 1).astype(np.float32)
35 |     img_blur_pil = np_to_pil(img_blur_np)
36 | 
37 |     return img_blur_pil, img_blur_np
--------------------------------------------------------------------------------
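
A short sketch of get_noisy_image above (the Set11 path comes from this repository's data tree; sigma and the output filename are illustrative):

    from utils.common_utils import get_image
    from utils.denoising_utils import get_noisy_image

    _, img_np = get_image('data/CS/Set11/cameraman.tif', imsize=-1)       # C x H x W float array in [0, 1]
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma=25/255.)
    img_noisy_pil.save('cameraman_noisy.png')
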
/utils/transform_mri.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | def to_tensor(data):
5 |     """
6 |     Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts
7 |     are stacked along the last dimension.
8 |     Args:
9 |         data (np.array): Input numpy array
10 |     Returns:
11 |         torch.Tensor: PyTorch version of data
12 |     """
13 |     if np.iscomplexobj(data):
14 |         data = np.stack((data.real, data.imag), axis=-1)
15 |     return torch.from_numpy(data)
16 | 
17 | 
18 | def apply_mask(data, mask_func=None, mask=None, seed=None):
19 |     """
20 |     Subsample given k-space by multiplying with a mask.
21 |     Args:
22 |         data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where
23 |             dimensions -3 and -2 are the spatial dimensions, and the final dimension has size
24 |             2 (for complex values).
25 |         mask_func (callable): A function that takes a shape (tuple of ints) and a random
26 |             number seed and returns a mask (ignored when `mask` is passed directly).
27 |         seed (int or 1-d array_like, optional): Seed for the random number generator.
28 |     Returns:
29 |         (tuple): tuple containing:
30 |             masked data (torch.Tensor): Subsampled k-space data
31 |             mask (torch.Tensor): The generated mask
32 |     """
33 |     shape = np.array(data.shape)
34 |     shape[:-3] = 1
35 |     if mask is None:
36 |         mask = mask_func(shape, seed)
37 |     return data * mask, mask
38 | 
39 | 
40 | def fft2(data):
41 |     """
42 |     Apply centered 2-dimensional Fast Fourier Transform.
43 |     Args:
44 |         data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
45 |             -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
46 |             assumed to be batch dimensions.
47 |     Returns:
48 |         torch.Tensor: The FFT of the input.
49 |     """
50 |     assert data.size(-1) == 2
51 |     data = ifftshift(data, dim=(-3, -2))
52 |     data = torch.fft(data, 2, normalized=True)  # legacy function-style API; valid for the pinned torch==1.6.0
53 |     data = fftshift(data, dim=(-3, -2))
54 |     return data
55 | 
56 | 
57 | def ifft2(data):
58 |     """
59 |     Apply centered 2-dimensional Inverse Fast Fourier Transform.
60 |     Args:
61 |         data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
62 |             -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
63 |             assumed to be batch dimensions.
64 |     Returns:
65 |         torch.Tensor: The IFFT of the input.
66 |     """
67 |     assert data.size(-1) == 2
68 |     data = ifftshift(data, dim=(-3, -2))
69 |     data = torch.ifft(data, 2, normalized=True)
70 |     data = fftshift(data, dim=(-3, -2))
71 |     return data
72 | 
73 | 
74 | def complex_abs(data):
75 |     """
76 |     Compute the absolute value of a complex valued input tensor.
77 |     Args:
78 |         data (torch.Tensor): A complex valued tensor, where the size of the final dimension
79 |             should be 2.
80 |     Returns:
81 |         torch.Tensor: Absolute value of data
82 |     """
83 |     assert data.size(-1) == 2
84 |     return (data ** 2).sum(dim=-1).sqrt()
85 | 
86 | 
87 | def root_sum_of_squares(data, dim=0):
88 |     """
89 |     Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor.
90 |     Args:
91 |         data (torch.Tensor): The input tensor
92 |         dim (int): The dimension along which to apply the RSS transform
93 |     Returns:
94 |         torch.Tensor: The RSS value
95 |     """
96 |     return torch.sqrt((data ** 2).sum(dim))
97 | 
98 | 
99 | def center_crop(data, shape):
100 |     """
101 |     Apply a center crop to the input real image or batch of real images.
102 |     Args:
103 |         data (torch.Tensor): The input tensor to be center cropped. It should have at
104 |             least 2 dimensions and the cropping is applied along the last two dimensions.
105 |         shape (int, int): The output shape. The shape should be smaller than the
106 |             corresponding dimensions of data.
107 |     Returns:
108 |         torch.Tensor: The center cropped image
109 |     """
110 |     assert 0 < shape[0] <= data.shape[-2]
111 |     assert 0 < shape[1] <= data.shape[-1]
112 |     w_from = (data.shape[-2] - shape[0]) // 2
113 |     h_from = (data.shape[-1] - shape[1]) // 2
114 |     w_to = w_from + shape[0]
115 |     h_to = h_from + shape[1]
116 |     return data[..., w_from:w_to, h_from:h_to]
117 | 
118 | 
119 | def complex_center_crop(data, shape):
120 |     """
121 |     Apply a center crop to the input image or batch of complex images.
122 |     Args:
123 |         data (torch.Tensor): The complex input tensor to be center cropped. It should
124 |             have at least 3 dimensions and the cropping is applied along dimensions
125 |             -3 and -2, and the last dimension should have a size of 2.
126 |         shape (int, int): The output shape. The shape should be smaller than the
127 |             corresponding dimensions of data.
128 |     Returns:
129 |         torch.Tensor: The center cropped image
130 |     """
131 |     assert 0 < shape[0] <= data.shape[-3]
132 |     assert 0 < shape[1] <= data.shape[-2]
133 |     w_from = (data.shape[-3] - shape[0]) // 2
134 |     h_from = (data.shape[-2] - shape[1]) // 2
135 |     w_to = w_from + shape[0]
136 |     h_to = h_from + shape[1]
137 |     return data[..., w_from:w_to, h_from:h_to, :]
138 | 
139 | 
140 | def normalize(data, mean, stddev, eps=0.):
141 |     """
142 |     Normalize the given tensor using:
143 |         (data - mean) / (stddev + eps)
144 |     Args:
145 |         data (torch.Tensor): Input data to be normalized
146 |         mean (float): Mean value
147 |         stddev (float): Standard deviation
148 |         eps (float): Added to stddev to prevent dividing by zero
149 |     Returns:
150 |         torch.Tensor: Normalized tensor
151 |     """
152 |     return (data - mean) / (stddev + eps)
153 | 
154 | 
155 | def normalize_instance(data, eps=0.):
156 |     """
157 |     Normalize the given tensor using:
158 |         (data - mean) / (stddev + eps)
159 |     where mean and stddev are computed from the data itself.
160 |     Args:
161 |         data (torch.Tensor): Input data to be normalized
162 |         eps (float): Added to stddev to prevent dividing by zero
163 |     Returns:
164 |         torch.Tensor: Normalized tensor
165 |     """
166 |     mean = data.mean()
167 |     std = data.std()
168 |     return normalize(data, mean, std, eps), mean, std
169 | 
170 | 
171 | # Helper functions
172 | 
173 | def roll(x, shift, dim):
174 |     """
175 |     Similar to np.roll but applies to PyTorch Tensors
176 |     """
177 |     if isinstance(shift, (tuple, list)):
178 |         assert len(shift) == len(dim)
179 |         for s, d in zip(shift, dim):
180 |             x = roll(x, s, d)
181 |         return x
182 |     shift = shift % x.size(dim)
183 |     if shift == 0:
184 |         return x
185 |     left = x.narrow(dim, 0, x.size(dim) - shift)
186 |     right = x.narrow(dim, x.size(dim) - shift, shift)
187 |     return torch.cat((right, left), dim=dim)
188 | 
189 | 
190 | def fftshift(x, dim=None):
191 |     """
192 |     Similar to np.fft.fftshift but applies to PyTorch Tensors
193 |     """
194 |     if dim is None:
195 |         dim = tuple(range(x.dim()))
196 |         shift = [d // 2 for d in x.shape]
197 |     elif isinstance(dim, int):
198 |         shift = x.shape[dim] // 2
199 |     else:
200 |         shift = [x.shape[i] // 2 for i in dim]
201 |     return roll(x, shift, dim)
202 | 
203 | 
204 | def ifftshift(x, dim=None):
205 |     """
206 |     Similar to np.fft.ifftshift but applies to PyTorch Tensors
207 |     """
208 |     if dim is None:
209 |         dim = tuple(range(x.dim()))
210 |         shift = [(d + 1) // 2 for d in x.shape]
211 |     elif isinstance(dim, int):
212 |         shift = (x.shape[dim] + 1) // 2
213 |     else:
214 |         shift = [(x.shape[i] + 1) // 2 for i in dim]
215 |     return roll(x, shift, dim)
--------------------------------------------------------------------------------
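
To close, a sketch exercising the k-space helpers above under the pinned torch==1.6.0 (the function-style torch.fft/torch.ifft calls in fft2/ifft2 were removed in torch 1.8); the image and the random column mask are synthetic stand-ins for the .mat masks shipped in data/mask:

    import torch
    from utils.transform_mri import fft2, ifft2, complex_abs, apply_mask

    image = torch.randn(256, 256, 2)                  # H x W x 2, real/imag stacked as in to_tensor
    kspace = fft2(image)                              # centered 2-D FFT

    mask = (torch.rand(1, 256, 1) < 0.25).float()     # keep roughly 25% of the k-space columns
    masked_kspace, mask = apply_mask(kspace, mask=mask)

    recon = complex_abs(ifft2(masked_kspace))         # zero-filled magnitude reconstruction, 256 x 256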