├── MRI_multicoil_deep_decoder_accelerate.ipynb
├── README.md
├── compressive_sensing_example_convergence.ipynb
├── include
│   ├── __init__.py
│   ├── decoder.py
│   ├── fit.py
│   ├── helpers.py
│   └── visualize.py
└── test_data
    ├── art.jpeg
    └── grass.jpg
/README.md:
--------------------------------------------------------------------------------
# Compressive sensing with un-trained neural networks: Gradient descent finds the smoothest approximation

This repository provides code for reproducing the figures in the paper:

**"Compressive sensing with un-trained neural networks: Gradient descent finds the smoothest approximation"** by Reinhard Heckel and Mahdi Soltanolkotabi. Contact: [reinhard.heckel@gmail.com](mailto:reinhard.heckel@gmail.com)


## Organization

- Figure 1: compressive_sensing_example_convergence.ipynb
- Figure 5: MRI_multicoil_deep_decoder_accelerate.ipynb

## Installation

The code is written in Python and relies on PyTorch. The following libraries are required:
- python 3
- pytorch
- torchvision
- numpy
- scipy
- scikit-image (skimage)
- matplotlib
- pillow
- jupyter

Jupyter can be installed via:
```
conda install jupyter
```
The remaining libraries can be installed the same way (with conda or pip).

The code that reproduces the MRI experiment uses a few functions from the fastMRI repository to load the k-space data. To obtain them, download the code from the fastMRI repository [https://github.com/facebookresearch/fastMRI](https://github.com/facebookresearch/fastMRI) and copy its `data` and `common` folders into the cs_deep_decoder repository.
30 | 31 | ## Citation 32 | ``` 33 | @inproceedings{heckel_compressive_2020, 34 | author = {Reinhard Heckel and Mahdi Soltanolkotabi}, 35 | title = {Compressive sensing with un-trained neural networks: {Gradient} descent finds the smoothest approximation}, 36 | booktitle = { {International} {Conference} on {Machine} {Learning} }, 37 | year = {2020}, 38 | } 39 | ``` 40 | 41 | ## Licence 42 | 43 | All files are provided under the terms of the Apache License, Version 2.0 44 | -------------------------------------------------------------------------------- /include/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import * 2 | #from .decoder_conv import * 3 | from .visualize import * 4 | from .fit import * 5 | from .helpers import * 6 | -------------------------------------------------------------------------------- /include/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def add_module(self, module): 6 | self.add_module(str(len(self) + 1), module) 7 | 8 | torch.nn.Module.add = add_module 9 | 10 | 11 | def conv(in_f, out_f, kernel_size, stride=1, pad='zero',bias=False): 12 | padder = None 13 | to_pad = int((kernel_size - 1) / 2) 14 | if pad == 'reflection': 15 | padder = nn.ReflectionPad2d(to_pad) 16 | to_pad = 0 17 | 18 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 19 | 20 | layers = filter(lambda x: x is not None, [padder, convolver]) 21 | return nn.Sequential(*layers) 22 | 23 | 24 | def deepdecoder( 25 | in_size, 26 | out_size, 27 | num_output_channels=3, 28 | num_channels=[128]*5, 29 | filter_size=1, 30 | need_sigmoid=True, 31 | pad ='reflection', 32 | upsample_mode='bilinear', 33 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 34 | bn_before_act = False, 35 | bn_affine = True, 36 | bias=False, 37 | last_noup=False, # if true have a last extra 
conv-relu-bn layer without the upsampling before linearly combining them 38 | ): 39 | 40 | depth = len(num_channels) 41 | scale_x,scale_y = (out_size[0]/in_size[0])**(1./depth), (out_size[1]/in_size[1])**(1./depth) 42 | hidden_size = [(int(np.ceil(scale_x**n * in_size[0])), 43 | int(np.ceil(scale_y**n * in_size[1]))) for n in range(1, depth)] + [out_size] 44 | 45 | print(hidden_size) 46 | 47 | num_channels = num_channels + [num_channels[-1],num_channels[-1]] 48 | 49 | n_scales = len(num_channels) 50 | 51 | if not (isinstance(filter_size, list) or isinstance(filter_size, tuple)) : 52 | filter_size = [filter_size]*n_scales 53 | 54 | model = nn.Sequential() 55 | 56 | for i in range(len(num_channels)-2): 57 | model.add(conv( num_channels[i], num_channels[i+1], filter_size[i], 1, pad=pad, bias=bias)) 58 | if upsample_mode!='none' and i != len(num_channels)-2: 59 | # align_corners: from pytorch.org: if True, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels. 
Default: False 60 | # default seems to work slightly better 61 | model.add(nn.Upsample(size=hidden_size[i], mode=upsample_mode,align_corners=False)) 62 | 63 | if(bn_before_act): 64 | model.add(nn.BatchNorm2d( num_channels[i+1] ,affine=bn_affine)) 65 | if act_fun is not None: 66 | model.add(act_fun) 67 | if not bn_before_act: 68 | model.add(nn.BatchNorm2d( num_channels[i+1], affine=bn_affine)) 69 | 70 | if last_noup: 71 | model.add(conv( num_channels[-2], num_channels[-1], filter_size[-2], 1, pad=pad, bias=bias)) 72 | model.add(act_fun) 73 | model.add(nn.BatchNorm2d( num_channels[-1], affine=bn_affine)) 74 | 75 | model.add(conv( num_channels[-1], num_output_channels, 1, pad=pad,bias=bias)) 76 | if need_sigmoid: 77 | model.add(nn.Sigmoid()) 78 | 79 | return model 80 | 81 | 82 | 83 | 84 | def decodernw( 85 | num_output_channels=3, 86 | num_channels_up=[128]*5, 87 | filter_size_up=1, 88 | need_sigmoid=True, 89 | need_tanh=False, 90 | pad ='reflection', 91 | upsample_mode='bilinear', 92 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 93 | bn_before_act = False, 94 | bn_affine = True, 95 | bn = True, 96 | upsample_first = True, 97 | bias=False 98 | ): 99 | 100 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 101 | n_scales = len(num_channels_up) 102 | 103 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 104 | filter_size_up = [filter_size_up]*n_scales 105 | model = nn.Sequential() 106 | 107 | 108 | for i in range(len(num_channels_up)-1): 109 | 110 | if upsample_first: 111 | model.add(conv( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad, bias=bias)) 112 | if upsample_mode!='none' and i != len(num_channels_up)-2: 113 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 114 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 115 | else: 116 | if upsample_mode!='none' and i!=0: 117 | model.add(nn.Upsample(scale_factor=2, 
mode=upsample_mode)) 118 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 119 | model.add(conv( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad,bias=bias)) 120 | 121 | if i != len(num_channels_up)-1: 122 | if(bn_before_act and bn): 123 | model.add(nn.BatchNorm2d( num_channels_up[i+1] ,affine=bn_affine)) 124 | if act_fun is not None: 125 | model.add(act_fun) 126 | if( (not bn_before_act) and bn): 127 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 128 | 129 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad,bias=bias)) 130 | if need_sigmoid: 131 | model.add(nn.Sigmoid()) 132 | elif need_tanh: 133 | model.add(nn.Tanh()) 134 | 135 | return model 136 | 137 | -------------------------------------------------------------------------------- /include/fit.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.optim 4 | import copy 5 | import numpy as np 6 | from scipy.linalg import hadamard 7 | 8 | from .helpers import * 9 | 10 | dtype = torch.cuda.FloatTensor 11 | #dtype = torch.FloatTensor 12 | 13 | from data import transforms as transform 14 | 15 | def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=500): 16 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 17 | lr = init_lr * (0.65**(epoch // lr_decay_epoch)) 18 | 19 | if epoch % lr_decay_epoch == 0: 20 | print('LR is set to {}'.format(lr)) 21 | 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | 25 | return optimizer 26 | 27 | def sqnorm(a): 28 | return np.sum( a*a ) 29 | 30 | def get_distances(initial_maps,final_maps): 31 | results = [] 32 | for a,b in zip(initial_maps,final_maps): 33 | res = sqnorm(a-b)/(sqnorm(a) + sqnorm(b)) 34 | results += [res] 35 | return(results) 36 | 37 | def get_weights(net): 38 | weights = [] 39 | for m in 
net.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | weights += [m.weight.data.cpu().numpy()] 42 | return weights 43 | 44 | def channels2imgs(out): 45 | sh = out.shape 46 | chs = int(sh[0]/2) 47 | imgs = np.zeros( (chs,sh[1],sh[2]) ) 48 | for i in range(chs): 49 | imgs[i] = np.sqrt( out[2*i]**2 + out[2*i+1]**2 ) 50 | return imgs 51 | 52 | def fit(net, 53 | img_noisy_var, 54 | num_channels, 55 | img_clean_var, 56 | num_iter = 5000, 57 | LR = 0.01, 58 | OPTIMIZER='adam', 59 | opt_input = False, 60 | reg_noise_std = 0, 61 | reg_noise_decayevery = 100000, 62 | mask_var = None, 63 | apply_f = None, 64 | lr_decay_epoch = 0, 65 | net_input = None, 66 | net_input_gen = "random", 67 | find_best=False, 68 | weight_decay=0, 69 | upsample_mode = "bilinear", 70 | totalupsample = 1, 71 | loss_type="MSE", 72 | output_gradients=False, 73 | output_weights=False, 74 | show_images=False, 75 | plot_after=None, 76 | in_size=None, 77 | MRI_multicoil_reference=None, 78 | ): 79 | 80 | if net_input is not None: 81 | print("input provided") 82 | else: 83 | 84 | if upsample_mode=="bilinear": 85 | # feed uniform noise into the network 86 | totalupsample = 2**len(num_channels) 87 | width = int(img_clean_var.data.shape[2]/totalupsample) 88 | height = int(img_clean_var.data.shape[3]/totalupsample) 89 | elif upsample_mode=="deconv": 90 | # feed uniform noise into the network 91 | totalupsample = 2**(len(num_channels)-1) 92 | width = int(img_clean_var.data.shape[2]/totalupsample) 93 | height = int(img_clean_var.data.shape[3]/totalupsample) 94 | elif upsample_mode=="free": 95 | width,height = in_size 96 | 97 | 98 | shape = [1,num_channels[0], width, height] 99 | print("input shape: ", shape) 100 | net_input = Variable(torch.zeros(shape)).type(dtype) 101 | net_input.data.uniform_() 102 | net_input.data *= 1./10 103 | 104 | net_input = net_input.type(dtype) 105 | net_input_saved = net_input.data.clone() 106 | noise = net_input.data.clone() 107 | p = [x for x in net.parameters() ] 108 | 109 | 
if(opt_input == True): # optimizer over the input as well 110 | net_input.requires_grad = True 111 | p += [net_input] 112 | 113 | mse_wrt_noisy = np.zeros(num_iter) 114 | mse_wrt_truth = np.zeros(num_iter) 115 | 116 | print( "init norm: ", np.linalg.norm( net( net_input.type(dtype) ).data.cpu().numpy()[0] ) ) 117 | print( "orig img norm: ", np.linalg.norm( img_clean_var.data.cpu().numpy() )) 118 | 119 | if OPTIMIZER == 'SGD': 120 | print("optimize with SGD", LR) 121 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9,weight_decay=weight_decay) 122 | elif OPTIMIZER == 'adam': 123 | print("optimize with adam", LR) 124 | optimizer = torch.optim.Adam(p, lr=LR,weight_decay=weight_decay) 125 | elif OPTIMIZER == 'LBFGS': 126 | print("optimize with LBFGS", LR) 127 | optimizer = torch.optim.LBFGS(p, lr=LR) 128 | 129 | if loss_type=="MSE": 130 | mse = torch.nn.MSELoss() #.type(dtype) 131 | if loss_type=="L1": 132 | mse = nn.L1Loss() 133 | 134 | if find_best: 135 | best_net = copy.deepcopy(net) 136 | best_mse = 1000000.0 137 | 138 | nconvnets = 0 139 | for p in list(filter(lambda p: len(p.data.shape)>2, net.parameters())): 140 | nconvnets += 1 141 | 142 | out_grads = np.zeros((nconvnets,num_iter)) 143 | 144 | init_weights = get_weights(net) 145 | out_weights = np.zeros(( len(init_weights) ,num_iter)) 146 | 147 | out_imgs = np.zeros((1,1)) 148 | 149 | if plot_after is not None: 150 | out_img_np = net( net_input_saved.type(dtype) ).data.cpu().numpy()[0] 151 | out_imgs = np.zeros( (len(plot_after),) + out_img_np.shape ) 152 | 153 | for i in range(num_iter): 154 | 155 | if lr_decay_epoch is not 0: 156 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=lr_decay_epoch) 157 | if reg_noise_std > 0: 158 | if i % reg_noise_decayevery == 0: 159 | reg_noise_std *= 0.7 160 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 161 | 162 | def closure(): 163 | optimizer.zero_grad() 164 | out = net(net_input.type(dtype)) 165 | 166 | # training loss 
167 | if mask_var is not None: 168 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 169 | elif apply_f: 170 | loss = mse( apply_f(out) , img_noisy_var ) 171 | else: 172 | loss = mse(out, img_noisy_var) 173 | 174 | loss.backward() 175 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 176 | 177 | 178 | # the actual loss 179 | true_loss = mse( Variable(out.data, requires_grad=False).type(dtype), img_clean_var.type(dtype) ) 180 | mse_wrt_truth[i] = true_loss.data.cpu().numpy() 181 | 182 | if MRI_multicoil_reference is not None: 183 | out_chs = net( net_input.type(dtype) ).data.cpu().numpy()[0] 184 | out_imgs = channels2imgs(out_chs) 185 | out_img_np = transform.root_sum_of_squares( torch.tensor(out_imgs) , dim=0).numpy() 186 | mse_wrt_truth[i] = np.linalg.norm(MRI_multicoil_reference - out_img_np) 187 | 188 | if output_gradients: 189 | for ind,p in enumerate(list(filter(lambda p: p.grad is not None and len(p.data.shape)>2, net.parameters()))): 190 | out_grads[ind,i] = p.grad.data.norm(2).item() 191 | #print(p.grad.data.norm(2).item()) 192 | #su += p.grad.data.norm(2).item() 193 | #mse_wrt_noisy[i] = su 194 | 195 | if i % 10 == 0: 196 | out2 = net(Variable(net_input_saved).type(dtype)) 197 | loss2 = mse(out2, img_clean_var) 198 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f' % (i, loss.data,mse_wrt_truth[i],loss2.data), '\r', end='') 199 | 200 | if show_images: 201 | if i % 50 == 0: 202 | print(i) 203 | out_img_np = net( ni.type(dtype) ).data.cpu().numpy()[0] 204 | myimgshow(plt,out_img_np) 205 | plt.show() 206 | 207 | if plot_after is not None: 208 | if i in plot_after: 209 | out_imgs[ plot_after.index(i) ,:] = net( net_input_saved.type(dtype) ).data.cpu().numpy()[0] 210 | 211 | if output_weights: 212 | out_weights[:,i] = np.array( get_distances( init_weights, get_weights(net) ) ) 213 | 214 | return loss 215 | 216 | loss = optimizer.step(closure) 217 | 218 | if find_best: 219 | # if training loss improves by at least one percent, we found 
a new best net 220 | if best_mse > 1.005*loss.data: 221 | best_mse = loss.data 222 | best_net = copy.deepcopy(net) 223 | 224 | 225 | if find_best: 226 | net = best_net 227 | if output_gradients and output_weights: 228 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_grads 229 | elif output_gradients: 230 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_grads 231 | elif output_weights: 232 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_weights 233 | elif plot_after is not None: 234 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_imgs 235 | else: 236 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net 237 | -------------------------------------------------------------------------------- /include/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import PIL 9 | import numpy as np 10 | 11 | from torch.autograd import Variable 12 | 13 | import random 14 | import numpy as np 15 | import torch 16 | import matplotlib.pyplot as plt 17 | 18 | from PIL import Image 19 | import PIL 20 | 21 | from torch.autograd import Variable 22 | 23 | def myimgshow(plt,img): 24 | if(img.shape[0] == 1): 25 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='none') 26 | else: 27 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1),interpolation='none') 28 | 29 | def load_and_crop(imgname,target_width=512,target_height=512): 30 | ''' 31 | imgname: string of image location 32 | load an image, and center-crop if the image is large enough, else return none 33 | ''' 34 | img = Image.open(imgname) 35 | width, height = img.size 36 | if width <= target_width or height <= target_height: 37 | return None 38 | 39 | left = (width - target_width)/2 40 | top = (height - target_height)/2 41 | right = (width + target_width)/2 42 | bottom = (height + 
target_height)/2 43 | 44 | return img.crop((left, top, right, bottom)) 45 | 46 | def save_np_img(img,filename): 47 | if(img.shape[0] == 1): 48 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='nearest') 49 | else: 50 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1)) 51 | plt.axis('off') 52 | plt.savefig(filename, bbox_inches='tight') 53 | plt.close() 54 | 55 | def np_to_tensor(img_np): 56 | '''Converts image in numpy.array to torch.Tensor. 57 | 58 | From C x W x H [0..1] to C x W x H [0..1] 59 | ''' 60 | return torch.from_numpy(img_np) 61 | 62 | def np_to_var(img_np, dtype = torch.cuda.FloatTensor): 63 | '''Converts image in numpy.array to torch.Variable. 64 | 65 | From C x W x H [0..1] to 1 x C x W x H [0..1] 66 | ''' 67 | return Variable(np_to_tensor(img_np)[None, :]) 68 | 69 | def var_to_np(img_var): 70 | '''Converts an image in torch.Variable format to np.array. 71 | 72 | From 1 x C x W x H [0..1] to C x W x H [0..1] 73 | ''' 74 | return img_var.data.cpu().numpy()[0] 75 | 76 | 77 | def pil_to_np(img_PIL): 78 | '''Converts image in PIL format to np.array. 79 | 80 | From W x H x C [0...255] to C x W x H [0..1] 81 | ''' 82 | ar = np.array(img_PIL) 83 | 84 | if len(ar.shape) == 3: 85 | ar = ar.transpose(2,0,1) 86 | else: 87 | ar = ar[None, ...] 88 | 89 | return ar.astype(np.float32) / 255. 90 | 91 | 92 | def rgb2ycbcr(img): 93 | #out = color.rgb2ycbcr( img.transpose(1, 2, 0) ) 94 | #return out.transpose(2,0,1)/256. 
95 | r,g,b = img[0],img[1],img[2] 96 | y = 0.299*r+0.587*g+0.114*b 97 | cb = 0.5 - 0.168736*r - 0.331264*g + 0.5*b 98 | cr = 0.5 + 0.5*r - 0.418588*g - 0.081312*b 99 | return np.array([y,cb,cr]) 100 | 101 | def ycbcr2rgb(img): 102 | #out = color.ycbcr2rgb( 256.*img.transpose(1, 2, 0) ) 103 | #return (out.transpose(2,0,1) - np.min(out))/(np.max(out)-np.min(out)) 104 | y,cb,cr = img[0],img[1],img[2] 105 | r = y + 1.402*(cr-0.5) 106 | g = y - 0.344136*(cb-0.5) - 0.714136*(cr-0.5) 107 | b = y + 1.772*(cb - 0.5) 108 | return np.array([r,g,b]) 109 | 110 | 111 | 112 | def mse(x_hat,x_true,maxv=1.): 113 | x_hat = x_hat.flatten() 114 | x_true = x_true.flatten() 115 | mse = np.mean(np.square(x_hat-x_true)) 116 | energy = np.mean(np.square(x_true)) 117 | return mse/energy 118 | 119 | def psnr(x_hat,x_true,maxv=1.): 120 | x_hat = x_hat.flatten() 121 | x_true = x_true.flatten() 122 | mse=np.mean(np.square(x_hat-x_true)) 123 | psnr_ = 10.*np.log(maxv**2/mse)/np.log(10.) 124 | return psnr_ 125 | 126 | def num_param(net): 127 | s = sum([np.prod(list(p.size())) for p in net.parameters()]); 128 | return s 129 | #print('Number of params: %d' % s) 130 | 131 | def rgb2gray(rgb): 132 | r, g, b = rgb[0,:,:], rgb[1,:,:], rgb[2,:,:] 133 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 134 | return np.array([gray]) 135 | 136 | def savemtx_for_logplot(A,filename = "exp.dat"): 137 | ind = sorted(list(set([int(i) for i in np.geomspace(1, len(A[0])-1 ,num=700)]))) 138 | A = [ [a[i] for i in ind] for a in A] 139 | X = np.array([ind] + A) 140 | np.savetxt(filename, X.T, delimiter=' ') 141 | 142 | 143 | def get_imgnet_imgs(num_samples = 100, path = '../imagenet/',verbose=False): 144 | perm = [i for i in range(1,50000)] 145 | random.Random(4).shuffle(perm) 146 | siz = 512 147 | file = open("exp_imgnet_imgs.txt","w") 148 | 149 | imgs = [] 150 | sampled = 0 151 | imgslist = [] 152 | for imgnr in perm: 153 | # prepare and select image 154 | # Format is: ILSVRC2012_val_00024995.JPEG 155 | imgnr_str = 
str(imgnr).zfill(8) 156 | imgname = path + 'ILSVRC2012_val_' + imgnr_str + ".JPEG" 157 | img = load_and_crop(imgname,target_width=512,target_height=512) 158 | if img is None: # then the image could not be croped to 512x512 159 | continue 160 | 161 | img_np = pil_to_np(img) 162 | 163 | if img_np.shape[0] != 3: # we only want to consider color images 164 | continue 165 | if verbose: 166 | imgslist += ['ILSVRC2012_val_' + imgnr_str + ".JPEG"] 167 | print("cp ", imgname, "./imgs") 168 | imgs += [img_np] 169 | sampled += 1 170 | if sampled >= num_samples: 171 | break 172 | if verbose: 173 | print(imgslist) 174 | return imgs 175 | 176 | 177 | 178 | -------------------------------------------------------------------------------- /include/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from torch.autograd import Variable 3 | import torch 4 | import torch.optim 5 | import numpy as np 6 | from collections import Iterable 7 | 8 | 9 | dtype = torch.cuda.FloatTensor 10 | #dtype = torch.FloatTensor 11 | 12 | def save_np_img(img,filename): 13 | if(img.shape[0] == 1): 14 | plt.imshow(np.clip(img[0],0,1)) 15 | else: 16 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1)) 17 | plt.axis('off') 18 | plt.savefig(filename, bbox_inches='tight') 19 | plt.close() 20 | 21 | def apply_until(net_input,net,n = 100): 22 | # applies function by funtion of a network 23 | for i,fun in enumerate(net): 24 | if i>=n: 25 | break 26 | if i==0: 27 | out = fun(net_input.type(dtype)) 28 | else: 29 | out = fun(out) 30 | print(i, "last func. 
applied:", net[i-1]) 31 | if n == 0: 32 | return net_input 33 | else: 34 | return out 35 | 36 | 37 | from math import ceil 38 | 39 | 40 | # given a lists of images as np-arrays, plot them as a row# given 41 | def plot_image_grid(imgs,nrows=10): 42 | ncols = ceil( len(imgs)/nrows ) 43 | nrows = min(nrows,len(imgs)) 44 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,figsize=(ncols,nrows),squeeze=False) 45 | for i, row in enumerate(axes): 46 | for j, ax in enumerate(row): 47 | ax.imshow(imgs[j*nrows+i], cmap='Greys_r', interpolation='none') 48 | ax.get_xaxis().set_visible(False) 49 | ax.get_yaxis().set_visible(False) 50 | fig.tight_layout(pad=0.1) 51 | return fig 52 | 53 | def save_tensor(out,filename,nrows=8): 54 | imgs = [img for img in out.data.cpu().numpy()[0]] 55 | fig = plot_image_grid(imgs,nrows=nrows) 56 | plt.savefig(filename) 57 | plt.close() 58 | 59 | def plot_kernels(tensor): 60 | if not len(tensor.shape)==4: 61 | raise Exception("assumes a 4D tensor") 62 | num_kernels = tensor.shape[0] 63 | fig = plt.figure(figsize=(tensor.shape[0],tensor.shape[1])) 64 | for i in range(tensor.shape[0]): 65 | for j in range(tensor.shape[1]): 66 | ax1 = fig.add_subplot(tensor.shape[0],tensor.shape[1],1+i*tensor.shape[0]+j) 67 | ax1.imshow(tensor[i][j]) 68 | ax1.axis('off') 69 | ax1.set_xticklabels([]) 70 | ax1.set_yticklabels([]) 71 | 72 | plt.subplots_adjust(wspace=0.1, hspace=0.1) 73 | plt.show() 74 | 75 | def plot_tensor(out,nrows=8): 76 | imgs = [img for img in out.data.cpu().numpy()[0]] 77 | fig = plot_image_grid(imgs,nrows=nrows) 78 | plt.show() 79 | -------------------------------------------------------------------------------- /test_data/art.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/cs_deep_decoder/2fe0653d3a8c291b1282f8cfcbb63c6958332724/test_data/art.jpeg -------------------------------------------------------------------------------- 
/test_data/grass.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/cs_deep_decoder/2fe0653d3a8c291b1282f8cfcbb63c6958332724/test_data/grass.jpg --------------------------------------------------------------------------------