├── .gitignore ├── LICENSE ├── README.md ├── common ├── fista_svd.py ├── fista_svd_demo_2d.ipynb ├── fista_svd_demo_3d.ipynb ├── forward_model.py ├── helper_functions │ ├── helper_functions.py │ ├── tv_approx_haar_cp.py │ └── tv_approx_haar_np.py ├── process_psf_for_svd.ipynb └── svd_model.py ├── data ├── 3D_data_simulated │ ├── blurred_4cells-Copy1.mat │ └── cellcool (59)flipudlr-Copy1.mat ├── fista3D-cellcool.mat ├── fista3D-fourCells.mat ├── hydra3.jpg └── real_data │ ├── interesting_bear.mat │ └── resTargetZ_1.mat ├── pytorch ├── 3D deconvolution demo (pretrained).ipynb ├── debug.ipynb ├── environment.yml ├── helper.py ├── models │ ├── dataset.py │ ├── dataset_tiff.py │ ├── modules.py │ ├── modules3d.py │ ├── unet.py │ ├── unet2.py │ ├── unet3d.py │ ├── unet_parts.py │ └── wiener_model.py ├── training_code.ipynb └── training_script.py └── tensorflow ├── 2D deconvolution demo (pretrained).ipynb ├── environment.yml ├── forward_model.py ├── forward_model.pyc ├── models ├── dataset.py ├── layers.py └── model_2d.py ├── training_code.ipynb ├── utils.py └── utils.pyc /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | */.ipynb_checkpoints/* 3 | __pycache__/ 4 | ipython_config.py 5 | 6 | .python-version 7 | *.png 8 | *.jpg 9 | *.jpeg 10 | *.tiff 11 | *.npy 12 | *.pt 13 | /tensorflow/saved_models/* 14 | /pytorch/saved_models/* 15 | /pytorch/saved_data/* 16 | /data/multiWienerPSFStack_40z_aligned.mat 17 | /pytorch/2D deconvolution demo (pretrained).ipynb -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Waller Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. 
Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiWienerNet 2 | ## Deep learning for fast spatially-varying deconvolution 3 | 4 | ### [Project Page](https://waller-lab.github.io/MultiWienerNet/) | [Paper](https://doi.org/10.1364/OPTICA.442438) 5 | 6 | ### Setup: 7 | Clone this project using: 8 | ``` 9 | git clone https://github.com/Waller-Lab/MultiWienerNet.git 10 | ``` 11 | 12 | Install dependencies. We provide code both in Tensorflow and in Pytorch. 
The Tensorflow version only contains an implementation for 2D deconvolutions, whereas the Pytorch version contains both 2D and 3D deconvolutions. 13 | 14 | If using Pytorch, install the dependencies (from the pytorch/ directory) as: 15 | 16 | ``` 17 | conda env create -f environment.yml 18 | source activate multiwiener_pytorch 19 | ``` 20 | 21 | If using Tensorflow, install the dependencies (from the tensorflow/ directory) as: 22 | 23 | ``` 24 | conda env create -f environment.yml 25 | source activate multiwiener_tf 26 | ``` 27 | 28 | ## Using pre-trained models 29 | We provide an example of a pre-trained MultiWienerNet for fast 2D deconvolutions as well as compressive 3D deconvolutions from 2D measurements. These examples are based on data for Randoscope3D. To adapt this model to your own data, please see below. 30 | 31 | 32 | ### Loading in pretrained models 33 | The pre-trained models can be downloaded [here (pytorch)](https://drive.google.com/drive/folders/1teIPp2q2ce0l9FjYe0LuC9c-Rpq2fA8x?usp=sharing) and [here (tensorflow)](https://drive.google.com/drive/folders/1E3bye75ovDvfKsDG4IMe_hzo5wQU1zTP?usp=sharing). 34 | Please download these and place them in pytorch/saved_models and tensorflow/saved_models, respectively. 35 | 36 | ### Loading in data 37 | We provide a few small examples of 2D and 3D data in /data/. 38 | You can download the full datasets that we used for training here: [2D](https://drive.google.com/drive/folders/199awM1qqQDqScgeI_HF65CG9PyjUWHGH?usp=sharing), [3D](https://drive.google.com/drive/folders/1QxtvjhCjnq9PtS9qMn5TVtSbg5sck3Ju?usp=sharing). 39 | You also need to download the PSFs [here](https://drive.google.com/drive/folders/103q6fND3W7hH-TCkCRv6Ho0xfgyScbvK?usp=sharing) and add them to the /data folder. 40 | 41 | ## Training for your own microscope/imaging system 42 | 43 | ### Characterize your imaging system forward model 44 | To retrain MultiWienerNet for your own imaging system, you first need to simulate realistic measurements from your imaging system to create sharp/blurred image pairs.
If you already have a spatially-varying model for your imaging system (e.g. in Zemax), you can use it. If you do not have a model for your spatially-varying imaging system, we suggest the following calibration approach: 45 | 46 | * Scan a bead on an 8x8 grid across your microscope/imaging system's field of view. Repeat this for each depth plane of interest. 47 | * Run your data through our low-rank code [here](https://github.com/Waller-Lab/MultiWienerNet/tree/main/common/process_psf_for_svd.ipynb) 48 | 49 | ### Create a dataset 50 | You can simulate data using the low-rank forward model as shown in the SVD notebook above, or use your own field-varying forward model to simulate measurements. 51 | ### Train your network 52 | We have provided demo notebooks for 2D imaging in tensorflow [here](https://github.com/Waller-Lab/MultiWienerNet/blob/main/tensorflow/2D%20deconvolution%20demo%20(pretrained).ipynb) and for single-shot 3D imaging in pytorch [here](https://github.com/Waller-Lab/MultiWienerNet/blob/main/pytorch/3D%20deconvolution%20demo%20(pretrained).ipynb); the training code itself is in tensorflow/training_code.ipynb and pytorch/training_code.ipynb.
53 | -------------------------------------------------------------------------------- /common/fista_svd.py: -------------------------------------------------------------------------------- 1 | import sys 2 | global device 3 | device= sys.argv[1] 4 | sys.path.append('helper_functions/') 5 | import forward_model as fm 6 | 7 | if device == 'GPU': 8 | import cupy as np 9 | import tv_approx_haar_cp as tv 10 | 11 | print('device = ', device, ', using GPU and cupy') 12 | else: 13 | import numpy as np 14 | import tv_approx_haar_np as tv 15 | print('device = ', device, ', using CPU and numpy') 16 | 17 | import helper_functions.helper_functions as fc 18 | import numpy as numpy 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | 23 | class fista_svd(): 24 | def __init__(self, H, weights, crop_indices, obj_type = '2D_svd'): 25 | 26 | ## Initialize constants 27 | self.DIMS0 = weights.shape[0] # Image Dimensions 28 | self.DIMS1 = weights.shape[1] # Image Dimensions 29 | 30 | self.py = int((self.DIMS0)//2) # Pad size 31 | self.px = int((self.DIMS1)//2) # Pad size 32 | 33 | # FFT of point spread function 34 | self.H = H#np.expand_dims(np.fft.fft2((np.fft.ifftshift(self.pad(h), axes = (0,1))), axes = (0,1)), -1) 35 | self.Hconj = np.conj(self.H) 36 | 37 | self.weights = weights 38 | self.crop_indices = crop_indices 39 | #self.mask = mask 40 | 41 | if obj_type == '2D_svd': 42 | self.Apow = fm.A_2d_svd_power 43 | self.A = fm.A_2d_svd_crop 44 | self.Aadj = fm.A_2d_adj_svd 45 | self.pad = fm.pad2d 46 | self.im_dims = [weights.shape[0]*2, weights.shape[1]*2] 47 | elif obj_type == '2D': 48 | self.Apow = fm.A_2d_power 49 | self.A = fm.A_2d_crop 50 | self.Aadj = fm.A_2d_adj 51 | self.pad = fm.pad2d 52 | self.im_dims = [weights.shape[0]*2, weights.shape[1]*2] 53 | elif obj_type == '3D': 54 | self.Apow = fm.A_3d_power 55 | self.A = fm.A_3d_crop 56 | self.Aadj = fm.A_3d_adj_fista 57 | self.pad = fm.pad2d 58 | self.im_dims = [weights.shape[0]*2, weights.shape[1]*2, H.shape[2]] 59 | elif 
obj_type == '3D_svd': 60 | self.Apow = fm.A_3d_svd_power 61 | self.A = fm.A_3d_svd_crop 62 | self.Aadj = fm.A_3d_adj_svd 63 | self.pad = fm.pad2d 64 | self.im_dims = [weights.shape[0]*2, weights.shape[1]*2, H.shape[2]] 65 | else: 66 | print('invalid object type') 67 | 68 | # Calculate the eigenvalue to set the step size 69 | maxeig = self.power_iteration(self.Apow, (self.im_dims[0:2]), 10) 70 | self.L = maxeig* 2 #45 71 | 72 | 73 | self.prox_method = 'tv' # options: 'non-neg', 'tv', 'native' 74 | 75 | # Define soft-thresholding constants 76 | self.tau = .5 # Native sparsity tuning parameter 77 | self.tv_lambda = 0.00005 # TV tuning parameter 78 | self.tv_lambdaw = 0.00005 # TV tuning parameter for wavelength 79 | self.lowrank_lambda = 0.00005 # Low rank tuning parameter 80 | 81 | 82 | # Number of iterations of FISTA 83 | self.iters = 500 84 | 85 | self.show_recon_progress = True # Display the intermediate results 86 | self.print_every = 20 # Sets how often to print the image 87 | 88 | self.l_data = [] 89 | self.l_tv = [] 90 | self.obj_type = obj_type 91 | 92 | # Power iteration to calculate eigenvalue 93 | def power_iteration(self, A, sample_vect_shape, num_iters): 94 | bk = np.random.randn(*sample_vect_shape) 95 | for i in range(0, num_iters): 96 | bk1 = A(bk, self.H, self.weights,fm.pad2d) 97 | bk1_norm = np.linalg.norm(bk1) 98 | 99 | bk = bk1/bk1_norm 100 | Mx = A(bk,self.H, self.weights,fm.pad2d) 101 | xx = np.transpose(np.dot(bk.ravel(), bk.ravel())) 102 | eig_b = np.transpose(bk.ravel()).dot(Mx.ravel())/xx 103 | 104 | return eig_b 105 | 106 | # Helper functions for forward model 107 | def crop(self,x): 108 | return x[self.py:-self.py, self.px:-self.px] 109 | 110 | def pad(self,x): 111 | if len(x.shape) == 2: 112 | out = np.pad(x, ([self.py, self.py], [self.px,self.px]), mode = 'constant') 113 | elif len(x.shape) == 3: 114 | out = np.pad(x, ([self.py, self.py], [self.px,self.px], [0, 0]), mode = 'constant') 115 | return out 116 | 117 | def soft_thresh(self, x, 
tau): 118 | out = np.maximum(np.abs(x)- tau, 0) 119 | out = out*np.sign(x) 120 | return out 121 | 122 | def prox(self,x): 123 | if self.prox_method == 'tv': 124 | x = 0.5*(np.maximum(x,0) + tv.tv3dApproxHaar(x, self.tv_lambda/self.L, self.tv_lambdaw)) 125 | if self.prox_method == 'native': 126 | x = 0.5*(np.maximum(x,0) + self.soft_thresh(x, self.tau)) 127 | if self.prox_method == 'non-neg': 128 | x = np.maximum(x,0) 129 | return x 130 | 131 | def tv(self, x): 132 | d = np.zeros_like(x) 133 | d[0:-1,:] = (x[0:-1,:] - x[1:, :])**2 134 | d[:,0:-1] = d[:,0:-1] + (x[:,0:-1] - x[:,1:])**2 135 | return np.sum(np.sqrt(d)) 136 | 137 | def loss(self,x,err): 138 | if self.prox_method == 'tv': 139 | self.l_data.append(np.linalg.norm(err)**2) 140 | self.l_tv.append(2*self.tv_lambda/self.L * self.tv(x)) 141 | 142 | l = np.linalg.norm(err)**2 + 2*self.tv_lambda/self.L * self.tv(x) 143 | if self.prox_method == 'native': 144 | l = np.linalg.norm(err)**2 + 2*self.tv_lambda/self.L * np.linalg.norm(x.ravel(), 1) 145 | if self.prox_method == 'non-neg': 146 | l = np.linalg.norm(err)**2 147 | return l 148 | 149 | # Main FISTA update 150 | def fista_update(self, vk, tk, xk, inputs): 151 | 152 | error = self.A(vk, self.H, self.weights, self.pad, self.crop_indices) - inputs 153 | grads = self.Aadj(self.Hconj,self.weights,error,self.pad) 154 | 155 | xup = self.prox(vk - 1/self.L * grads) 156 | tup = 1 + np.sqrt(1 + 4*tk**2)/2 157 | vup = xup + (tk-1)/tup * (xup-xk) 158 | 159 | return vup, tup, xup, self.loss(xup, error) 160 | 161 | 162 | # Run FISTA 163 | def run(self, inputs): 164 | 165 | # Initialize variables to zero 166 | xk = np.zeros((self.im_dims)) 167 | vk = np.zeros((self.im_dims)) 168 | tk = 1.0 169 | 170 | llist = [] 171 | 172 | # Start FISTA loop 173 | for i in range(0,self.iters): 174 | vk, tk, xk, l = self.fista_update(vk, tk, xk, inputs) 175 | 176 | if device == 'GPU': 177 | l =l.get() 178 | 179 | llist.append(l) 180 | 181 | # Print out the intermediate results and the loss 
182 | if self.show_recon_progress== True and i%self.print_every == 0: 183 | print('iteration: ', i, ' loss: ', l) 184 | if device == 'GPU': 185 | out_img = np.asnumpy(self.crop(xk)) 186 | else: 187 | out_img = self.crop(xk) 188 | 189 | if len(out_img.shape)==3: 190 | fc_img = numpy.max(numpy.real(out_img),-1) 191 | else: 192 | fc_img = out_img 193 | 194 | 195 | plt.figure(figsize = (10,3)) 196 | plt.subplot(1,2,1), plt.imshow(fc_img/numpy.max(fc_img)); plt.title('Reconstruction') 197 | plt.subplot(1,2,2), plt.plot(llist); plt.title('Loss') 198 | plt.show() 199 | self.out_img = out_img 200 | xout = self.crop(xk) 201 | return xout, llist -------------------------------------------------------------------------------- /common/forward_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io 3 | def pad2d (x): 4 | Ny=x.shape[0] 5 | Nx=x.shape[1] 6 | return np.pad(x,((Ny//2,Ny//2),(Nx//2,Nx//2)), mode = 'constant')#, constant_values=(0)) 7 | 8 | def pad4d(x): 9 | Ny=x.shape[0] 10 | Nx=x.shape[1] 11 | return np.pad(x,((Ny//2,Ny//2),(Nx//2,Nx//2),(0,0),(0,0)), mode = 'constant')#, constant_values=(0)) 12 | 13 | def crop4d(x,rcL,rcU,ccL,ccU): 14 | return x[rcL:rcU,ccL:ccU,:,:] 15 | 16 | def crop2d(x,rcL,rcU,ccL,ccU): 17 | return x[rcL:rcU,ccL:ccU] 18 | 19 | def nocrop(x): 20 | return x 21 | 22 | def nopad(x): 23 | return x 24 | 25 | def A_2d_svd(x,H,weights,pad,mode='shift_variant'): #NOTE, H is already padded outside to save memory 26 | x=pad(x) 27 | #Y=np.zeros((x.shape[0],x.shape[1])) 28 | Y=np.zeros_like(x) 29 | 30 | if (mode =='shift_variant'): 31 | for r in range (0,weights.shape[2]): 32 | X=np.fft.fft2((np.multiply(pad(weights[:,:,r]),x))) 33 | Y=Y+ np.multiply(X,H[:,:,r]) 34 | 35 | return np.real((np.fft.ifftshift(np.fft.ifft2(Y)))) 36 | 37 | def A_2d_svd_power(x,H,weights,pad,mode='shift_variant'): #NOTE, H is already padded outside to save memory 38 | Y=np.zeros_like(x) 39 | 40 | if (mode 
=='shift_variant'): 41 | for r in range (0,weights.shape[2]): 42 | X=np.fft.fft2((np.multiply(pad(weights[:,:,r]),x))) 43 | Y=Y+ X*H[:,:,r] 44 | 45 | return np.real((np.fft.ifftshift(np.fft.ifft2(Y)))) 46 | 47 | def A_2d_svd_crop(x,H,weights,pad,crop_indices,mode='shift_variant'): #NOTE, H is already padded outside to save memory 48 | #x=pad(x) 49 | #Y=np.zeros((x.shape[0],x.shape[1])) 50 | Y=np.zeros_like(x) 51 | 52 | if (mode =='shift_variant'): 53 | for r in range (0,weights.shape[2]): 54 | X=np.fft.fft2((np.multiply(pad(weights[:,:,r]),x))) 55 | Y=Y+ np.multiply(X,H[:,:,r]) 56 | 57 | return crop2d(np.real((np.fft.ifftshift(np.fft.ifft2(Y)))),*crop_indices) 58 | 59 | def A_2d(x,H,weights,pad,crop_indices): 60 | X=np.fft.fft2((pad(x))) 61 | Y=np.multiply(X,H) 62 | 63 | return np.real((np.fft.ifftshift(np.fft.ifft2(Y)))) 64 | 65 | def A_2d_crop(x,H,weights,pad,crop_indices): 66 | X=np.fft.fft2(x) 67 | Y=np.multiply(X,H) 68 | 69 | return crop2d(np.real((np.fft.ifftshift(np.fft.ifft2(Y)))),*crop_indices) 70 | 71 | def A_2d_power(x,H,weights,pad,mode='shift_variant'): #NOTE, H is already padded outside to save memory 72 | 73 | X=np.fft.fft2(x) 74 | Y=np.multiply(X,H) 75 | 76 | return np.real((np.fft.ifftshift(np.fft.ifft2(Y)))) 77 | 78 | def A_2d_adj_svd(Hconj,weights,y,pad): 79 | y=pad(y) 80 | x=np.zeros_like(y) 81 | #x=np.zeros((y.shape[0],y.shape[1])) 82 | for r in range (0, weights.shape[2]): 83 | x=x+np.multiply(pad(weights[:,:,r]),(np.real(np.fft.ifftshift(np.fft.ifft2(np.multiply(Hconj[:,:,r], np.fft.fft2((y)))))))) 84 | #note the weights are real so we dont take the complex conjugate of it, which is the adjoint of the diag 85 | return x 86 | 87 | def A_2d_adj(Hconj,weights,y,pad): 88 | x=(np.real(np.fft.ifftshift(np.fft.ifft2(np.multiply(Hconj, np.fft.fft2((pad(y)))))))) 89 | 90 | return x 91 | 92 | def A_3d_power(v,H,weights,pad): 93 | #h is the psf stack 94 | #x is the variable to convolve with h 95 | 96 | B=np.zeros_like(v) 97 | for z in range 
(H.shape[2]): 98 | B=B+np.multiply(H[:,:,z],np.fft.fft2(v)) 99 | 100 | return np.real((np.fft.ifftshift(np.fft.ifft2(B)))) 101 | 102 | def A_3d_crop(v,H,weights,pad, crop_indices): 103 | #h is the psf stack 104 | #x is the variable to convolve with h 105 | 106 | B=np.zeros_like(v[...,0]) 107 | for z in range (H.shape[2]): 108 | B=B+np.multiply(H[:,:,z],np.fft.fft2(v[:,:,z])) 109 | 110 | return crop2d(np.real((np.fft.ifftshift(np.fft.ifft2(B)))),*crop_indices) 111 | 112 | def A_3d(x,h,pad): 113 | #h is the psf stack 114 | #x is the variable to convolve with h 115 | x=pad(x) 116 | B=np.zeros((x.shape[0],x.shape[1])) 117 | 118 | 119 | for z in range (0,h.shape[2]): 120 | #X=np.fft.fft2((np.multiply(pad(weights[:,:,z]),x))) 121 | B=B+ np.multiply(np.fft.fft2(x[:,:,z]),np.fft.fft2(pad(h[:,:,z]))) 122 | 123 | return np.real((np.fft.ifftshift(np.fft.ifft2(B)))) 124 | 125 | def A_3d_svd_power(v,H,weights,pad): 126 | #alpha is Ny-Nx-Nz-Nr, weights 127 | #v is Ny-Nx-Nz 128 | #H is Ny-Nx-Nz-Nr 129 | # b= sum_r (sum_z (h**alpra.*v)) 130 | #b=np.zeros((v.shape[0],v.shape[1])) 131 | b=np.zeros_like(v) 132 | for r in range (H.shape[3]): 133 | for z in range (H.shape[2]): 134 | b=b+np.multiply(H[:,:,z,r],np.fft.fft2(np.multiply(v,pad(weights[:,:,z,r])))) 135 | 136 | return np.real(np.fft.ifftshift(np.fft.ifft2(b))) 137 | 138 | def A_3d_svd_crop(v,H,weights,pad, crop_indices): 139 | #alpha is Ny-Nx-Nz-Nr, weights 140 | #v is Ny-Nx-Nz 141 | #H is Ny-Nx-Nz-Nr 142 | # b= sum_r (sum_z (h**alpra.*v)) 143 | b=np.zeros_like(v[:,:,0]) 144 | #b=np.zeros((v.shape[0],v.shape[1])) 145 | for r in range (H.shape[3]): 146 | for z in range (H.shape[2]): 147 | b=b+np.multiply(H[:,:,z,r],np.fft.fft2(np.multiply(v[:,:,z],pad(weights[:,:,z,r])))) 148 | 149 | return crop2d(np.real(np.fft.ifftshift(np.fft.ifft2(b))),*crop_indices) 150 | 151 | def A_3d_svd(v,alpha,H,pad): 152 | #alpha is Ny-Nx-Nz-Nr, weights 153 | #v is Ny-Nx-Nz 154 | #H is Ny-Nx-Nz-Nr 155 | # b= sum_r (sum_z (h**alpra.*v)) 156 | 
b=np.zeros((v.shape[0],v.shape[1])) 157 | for r in range (H.shape[3]): 158 | for z in range (H.shape[2]): 159 | b=b+np.multiply(H[:,:,z,r],np.fft.fft2(np.multiply(v[:,:,z],alpha[:,:,z,r]))) 160 | 161 | return np.real(np.fft.ifftshift(np.fft.ifft2(b))) 162 | 163 | 164 | def A_3d_adj_fista(Hconj,alpha,x,pad): 165 | y=np.zeros_like(Hconj) 166 | B=np.fft.fft2(pad(x)) 167 | for z in range(alpha.shape[2]): 168 | y[:,:,z]= np.real(np.fft.ifftshift(np.fft.ifft2(np.multiply(B,Hconj[:,:,z])))) 169 | 170 | return y 171 | 172 | def A_3d_adj(Hconj,alpha,x,pad): 173 | y=np.zeros_like(h) 174 | X=np.fft.fft2(pad(x)) 175 | for z in range(h.shape[2]): 176 | H=np.conj(np.fft.fft2(pad(h[:,:,z]))) 177 | y[:,:,z]=np.real(np.fft.ifftshift(np.fft.ifft2(np.multiply(H,X)))) 178 | return y 179 | 180 | # def A_2d_adj_svd(Hconj,weights,y,pad): 181 | # y=pad(y) 182 | # x=np.zeros_like(y) 183 | # #x=np.zeros((y.shape[0],y.shape[1])) 184 | # for r in range (0, weights.shape[2]): 185 | # x=x+np.multiply(pad(weights[:,:,r]),(np.real(np.fft.ifftshift(np.fft.ifft2(np.multiply(Hconj[:,:,r], np.fft.fft2((y)))))))) 186 | # #note the weights are real so we dont take the complex conjugate of it, which is the adjoint of the diag 187 | # return x 188 | 189 | #def A_3d_adj_svd(b,alpha,Hconj,pad): 190 | def A_3d_adj_svd(Hconj,alpha,x,pad): 191 | y=np.zeros_like(Hconj[...,0]) 192 | B=np.fft.fft2(pad(x)) 193 | for z in range(alpha.shape[2]): 194 | for r in range(alpha.shape[3]): 195 | y[:,:,z]=y[:,:,z]+ pad(alpha[:,:,z,r])* np.real(np.fft.ifftshift(np.fft.ifft2(np.multiply(B,Hconj[:,:,z,r])))) 196 | 197 | return y 198 | 199 | def grad(v): 200 | return np.array(np.gradient(v)) #returns gradient in x and in y 201 | 202 | 203 | def grad_adj(v): #adj of gradient is negative divergence 204 | z = np.zeros((n,n)) + 1j 205 | z -= np.gradient(v[0,:,:])[0] 206 | z -= np.gradient(v[1,:,:])[1] 207 | return z 208 | 209 | def sim_data(im,H,weights,crop_indices): 210 | mu=0 211 | sigma=np.random.rand(1)*0.02+0.005 #abit much 
maybe 0.04 best0.04+0.01 212 | PEAK=np.random.rand(1)*1000+50 213 | 214 | I=im/np.max(im) 215 | #I[I<0.12]=0 216 | sim=crop2d(A_2d_svd(I,H,weights,pad2d),*crop_indices) 217 | sim=sim/np.max(sim) 218 | sim=np.maximum(sim,0) 219 | 220 | p_noise = np.random.poisson(sim * PEAK)/PEAK 221 | 222 | g_noise= np.random.normal(mu, sigma, 648*486) 223 | g_noise=np.reshape(g_noise,(486,648)) 224 | sim=sim+g_noise+p_noise 225 | sim=sim/np.max(sim) 226 | sim=np.maximum(sim,0) 227 | sim=sim/np.max(sim) 228 | return sim 229 | 230 | 231 | # load in forward model weights 232 | def load_weights(): 233 | h=scipy.io.loadmat('../data/SVD_2_5um_PSF_5um_1_ds4_dsz1_comps_green_SubAvg.mat') 234 | weights=scipy.io.loadmat('../data/SVD_2_5um_PSF_5um_1_ds4_dsz1_weights_green_SubAvg.mat') 235 | 236 | depth_plane=0 #NOTE Z here is 1 less than matlab file as python zero index. So this is z31 in matlab 237 | 238 | h=h['array_out'] 239 | weights=weights['array_out'] 240 | # make sure its (x,y,z,r) 241 | h=np.swapaxes(h,2,3) 242 | weights=np.swapaxes(weights,2,3) 243 | 244 | h=h[:,:,depth_plane,:] 245 | weights=weights[:,:,depth_plane,:] 246 | 247 | # Normalize weights to have maximum sum through rank of 1 248 | weights_norm = np.max(np.sum(weights[weights.shape[0]//2-1,weights.shape[1]//2-1,:],0)) 249 | weights = weights/weights_norm; 250 | 251 | #normalize by norm of all stack. 
Can also try normalizing by max of all stack or by norm of each slice 252 | h=h/np.linalg.norm(np.ravel(h)) 253 | 254 | # padded values for 2D 255 | 256 | ccL = h.shape[1]//2 257 | ccU = 3*h.shape[1]//2 258 | rcL = h.shape[0]//2 259 | rcU = 3*h.shape[0]//2 260 | 261 | H=np.ndarray((h.shape[0]*2,h.shape[1]*2,h.shape[2]), dtype=complex) 262 | Hconj=np.ndarray((h.shape[0]*2,h.shape[1]*2,h.shape[2]),dtype=complex) 263 | for i in range (h.shape[2]): 264 | H[:,:,i]=(np.fft.fft2(pad2d(h[:,:,i]))) 265 | Hconj[:,:,i]=(np.conj(H[:,:,i])) 266 | return H,weights,[rcL,rcU,ccL,ccU] 267 | 268 | # load in forward model weights 269 | def load_weights_3d(path_psfs, path_weights): 270 | h=scipy.io.loadmat(path_psfs) 271 | weights=scipy.io.loadmat(path_weights) 272 | 273 | h=h['array_out'] 274 | weights=weights['array_out'] 275 | #make the shape, xyzr 276 | h=np.swapaxes(h,2,3) 277 | weights=np.swapaxes(weights,2,3) 278 | h=h[:,:,::2,:] 279 | h=h[:,:,0:32,:] 280 | weights=weights[:,:,::2,:] 281 | weights=weights[:,:,0:32,:] 282 | 283 | # Normalize weights to have maximum sum through rank of 1 284 | weights_norm = np.sum(weights[weights.shape[0]//2-1,weights.shape[1]//2-1,:],0).max() 285 | weights = weights/weights_norm; 286 | 287 | #normalize by norm of all stack. 
Can also try normalizing by max of all stack or by norm of each slice 288 | h=h/(np.linalg.norm(h.ravel())) 289 | 290 | ccL = h.shape[1]//2 291 | ccU = 3*h.shape[1]//2 292 | rcL = h.shape[0]//2 293 | rcU = 3*h.shape[0]//2 294 | 295 | crop_indices = [rcL,rcU,ccL,ccU] 296 | 297 | H=np.fft.fft2(pad4d(h), axes = (0,1)) 298 | Hconj=np.conj(H) 299 | 300 | return H,weights,crop_indices 301 | -------------------------------------------------------------------------------- /common/helper_functions/helper_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib 4 | import scipy.io 5 | from IPython.core.display import display, HTML 6 | from ipywidgets import interact, widgets, fixed 7 | 8 | import sys 9 | sys.path.append('helper_functions/') 10 | 11 | 12 | def plotf2(r, img, ttl, sz): 13 | plt.title(ttl+' {}'.format(r)) 14 | plt.imshow(img[:,:,r], vmin = np.min(img), vmax = np.max(img)); 15 | plt.axis('off'); 16 | fig = plt.gcf() 17 | fig.set_size_inches(sz) 18 | plt.show(); 19 | return 20 | 21 | def plt3D(img, title = '', size = (5,5)): 22 | interact(plotf2, 23 | r=widgets.IntSlider(min=0,max=np.shape(img)[-1]-1,step=1,value=1), 24 | img = fixed(img), 25 | continuous_update= False, 26 | ttl = fixed(title), 27 | sz = fixed(size)); 28 | 29 | 30 | 31 | def plotf22(r, img, ttl, sz): 32 | plt.title(ttl+' {}'.format(r)) 33 | plt.imshow(img[:,:,r], vmin = np.min(img[:,:,r]), vmax = np.max(img[:,:,r])); 34 | plt.axis('off'); 35 | fig = plt.gcf() 36 | fig.set_size_inches(sz) 37 | plt.show(); 38 | return 39 | 40 | def plt3D2(img, title = '', size = (5,5)): 41 | interact(plotf22, 42 | r=widgets.IntSlider(min=0,max=np.shape(img)[-1]-1,step=1,value=1), 43 | img = fixed(img), 44 | continuous_update= False, 45 | ttl = fixed(title), 46 | sz = fixed(size)); 47 | 48 | def crop(x): 49 | DIMS0 = x.shape[0]//2 # Image Dimensions 50 | DIMS1 = x.shape[1]//2 # Image Dimensions 51 | 
52 | PAD_SIZE0 = int((DIMS0)//2) # Pad size 53 | PAD_SIZE1 = int((DIMS1)//2) # Pad size 54 | 55 | C01 = PAD_SIZE0; C02 = PAD_SIZE0 + DIMS0 # Crop indices 56 | C11 = PAD_SIZE1; C12 = PAD_SIZE1 + DIMS1 # Crop indices 57 | return x[C01:C02, C11:C12,:] 58 | 59 | def pre_plot(x): 60 | x = np.fliplr(np.flipud(x)) 61 | x = x/np.max(x) 62 | x = np.clip(x, 0,1) 63 | return x 64 | -------------------------------------------------------------------------------- /common/helper_functions/tv_approx_haar_cp.py: -------------------------------------------------------------------------------- 1 | import cupy as np 2 | 3 | def soft_py(x, tau): 4 | #print('tau', tau, 'x', np.max(np.abs(x))) 5 | threshed = np.maximum(np.abs(x)-tau, 0) 6 | threshed = threshed*np.sign(x) 7 | return threshed 8 | 9 | def ht3(x, ax, shift, thresh): 10 | C = 1./np.sqrt(2.) 11 | 12 | if shift == True: 13 | x = np.roll(x, -1, axis = ax) 14 | if ax == 0: 15 | w1 = C*(x[1::2] + x[0::2]) 16 | w2 = soft_py(C*(x[1::2] - x[0::2]), thresh) 17 | elif ax == 1: 18 | w1 = C*(x[:, 1::2] + x[:, 0::2]) 19 | w2 = soft_py(C*(x[:,1::2] - x[:,0::2]), thresh) 20 | elif ax == 2: 21 | w1 = C*(x[:,:,1::2] + x[:,:, 0::2]) 22 | w2 = soft_py(C*(x[:,:,1::2] - x[:,:,0::2]), thresh) 23 | return w1, w2 24 | 25 | def iht3(w1, w2, ax, shift, shape): 26 | 27 | C = 1./np.sqrt(2.) 28 | y = np.zeros(shape) 29 | 30 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 31 | if ax == 0: 32 | y[0::2] = x1 33 | y[1::2] = x2 34 | 35 | if ax == 1: 36 | y[:, 0::2] = x1 37 | y[:, 1::2] = x2 38 | if ax == 2: 39 | y[:, :, 0::2] = x1 40 | y[:, :, 1::2] = x2 41 | 42 | 43 | if shift == True: 44 | y = np.roll(y, 1, axis = ax) 45 | return y 46 | 47 | 48 | def iht3_py2(w1, w2, ax, shift, shape): 49 | 50 | C = 1./np.sqrt(2.) 
51 | y = np.zeros(shape) 52 | 53 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 54 | 55 | ind = ax + 2; 56 | y = np.reshape(np.concatenate([np.expand_dims(x1, ind), np.expand_dims(x2, ind)], axis = ind), shape) 57 | 58 | 59 | if shift == True: 60 | y = np.roll(y, 1, axis = ax+1) 61 | return y 62 | 63 | def tv3dApproxHaar(x, tau, alpha): 64 | 65 | D = len(x.shape) # D =3 for 3D and 2 for 2D 66 | fact = np.sqrt(2)*2 67 | 68 | 69 | 70 | sqeezed = False 71 | if x.shape[-1] == 1: 72 | x = x[...,0] 73 | sqeezed = True 74 | D = 2 75 | 76 | thresh = D*tau*fact 77 | 78 | y = np.zeros_like(x) 79 | for ax in range(0,len(x.shape)): 80 | if ax ==2: 81 | t_scale = alpha 82 | else: 83 | t_scale = 1; 84 | 85 | w0, w1 = ht3(x, ax, False, thresh*t_scale) 86 | w2, w3 = ht3(x, ax, True, thresh*t_scale) 87 | 88 | t1 = iht3(w0, w1, ax, False, x.shape) 89 | t2 = iht3(w2, w3, ax, True, x.shape) 90 | y = y + t1 + t2 91 | 92 | y = y/(2*D) 93 | 94 | if sqeezed == True: 95 | y = y[..., np.newaxis] 96 | return y 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /common/helper_functions/tv_approx_haar_np.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def soft_py(x, tau): 4 | threshed = np.maximum(np.abs(x)-tau, 0) 5 | threshed = threshed*np.sign(x) 6 | return threshed 7 | 8 | def ht3(x, ax, shift, thresh): 9 | C = 1./np.sqrt(2.) 10 | 11 | if shift == True: 12 | x = np.roll(x, -1, axis = ax) 13 | if ax == 0: 14 | w1 = C*(x[1::2,...] + x[0::2, ...]) 15 | w2 = soft_py(C*(x[1::2,...] - x[0::2, ...]), thresh) 16 | elif ax == 1: 17 | w1 = C*(x[:, 1::2] + x[:, 0::2]) 18 | w2 = soft_py(C*(x[:,1::2] - x[:,0::2]), thresh) 19 | elif ax == 2: 20 | w1 = C*(x[:,:,1::2] + x[:,:, 0::2]) 21 | w2 = soft_py(C*(x[:,:,1::2] - x[:,:,0::2]), thresh) 22 | return w1, w2 23 | 24 | def iht3(w1, w2, ax, shift, shape): 25 | 26 | C = 1./np.sqrt(2.) 
27 | y = np.zeros(shape) 28 | 29 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 30 | if ax == 0: 31 | y[0::2, ...] = x1 32 | y[1::2, ...] = x2 33 | 34 | if ax == 1: 35 | y[:, 0::2] = x1 36 | y[:, 1::2] = x2 37 | if ax == 2: 38 | y[:, :, 0::2] = x1 39 | y[:, :, 1::2] = x2 40 | 41 | 42 | if shift == True: 43 | y = np.roll(y, 1, axis = ax) 44 | return y 45 | 46 | 47 | def iht3_py2(w1, w2, ax, shift, shape): 48 | 49 | C = 1./np.sqrt(2.) 50 | y = np.zeros(shape) 51 | 52 | x1 = C*(w1 - w2); x2 = C*(w1 + w2); 53 | 54 | ind = ax + 2; 55 | y = np.reshape(np.concatenate([np.expand_dims(x1, ind), np.expand_dims(x2, ind)], axis = ind), shape) 56 | 57 | 58 | if shift == True: 59 | y = np.roll(y, 1, axis = ax+1) 60 | return y 61 | 62 | def tv3dApproxHaar(x, tau, alpha): 63 | D = len(x.shape) # D =3 for 3D and 2 for 2D 64 | fact = np.sqrt(2)*2 65 | 66 | thresh = D*tau*fact 67 | 68 | sqeezed = False 69 | if x.shape[-1] == 1: 70 | x = x[...,0] 71 | sqeezed = True 72 | 73 | y = np.zeros_like(x) 74 | for ax in range(0,len(x.shape)): 75 | if ax ==2: 76 | t_scale = alpha 77 | else: 78 | t_scale = 1; 79 | 80 | w0, w1 = ht3(x, ax, False, thresh*t_scale) 81 | w2, w3 = ht3(x, ax, True, thresh*t_scale) 82 | 83 | t1 = iht3(w0, w1, ax, False, x.shape) 84 | t2 = iht3(w2, w3, ax, True, x.shape) 85 | y = y + t1 + t2 86 | 87 | y = y/(2*D) 88 | if sqeezed == True: 89 | y = y[..., np.newaxis] 90 | return y 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /common/svd_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.fftpack import dct, idct 3 | 4 | import scipy.io 5 | from scipy.sparse import csr_matrix 6 | from scipy.sparse.linalg import svds 7 | 8 | import matplotlib.pyplot as plt 9 | from scipy.interpolate import griddata 10 | 11 | 12 | 13 | def register_psfs(stack,ref_im,dct_on=True): 14 | 15 | [Ny, Nx] = stack[:,:,0].shape; 16 | vec = lambda x: x.ravel() 17 | pad2d = lambda 
x: np.pad(x,((Ny//2,Ny//2),(Nx//2,Nx//2)),'constant', constant_values=(0)) 18 | fftcorr = lambda x,y:np.fft.ifft2(np.fft.fft2(pad2d(x))*np.conj(np.fft.fft2(np.fft.ifftshift(pad2d(y))))); 19 | M = stack.shape[2] 20 | Si = lambda x,si:np.roll(np.roll(x,si[0],axis=0),si[1],axis=1); 21 | 22 | pr = Ny + 1; 23 | pc = Nx + 1; # Relative centers of all correlations 24 | 25 | yi_reg = 0*stack; #Registered stack 26 | pad = lambda x:x; 27 | crop = lambda x:x; 28 | pad2d = lambda x:np.pad(x,((Ny//2,Ny//2),(Nx//2,Nx//2)),'constant', constant_values=(0)) 29 | crop2d = lambda x: x[Ny//2:3*Ny//2,Nx//2:3*Nx//2]; 30 | 31 | 32 | # % Normalize the stack first 33 | stack_norm = np.zeros((1,M)); 34 | stack_dct = stack*1; 35 | ref_norm = np.linalg.norm(ref_im,'fro'); 36 | for m in range (M): 37 | stack_norm[0,m] = np.linalg.norm(stack_dct[:,:,m],'fro'); 38 | stack_dct[:,:,m] = stack_dct[:,:,m]/stack_norm[0,m]; 39 | stack[:,:,m] = stack[:,:,m]/ref_norm; 40 | 41 | ref_im = ref_im/ref_norm; 42 | 43 | 44 | # ######### 45 | si={} 46 | 47 | # # % Do fft registration 48 | 49 | 50 | if dct_on: 51 | print('Removing background\n') 52 | for n in range (stack_dct.shape[2]): 53 | im = stack_dct[:,:,n]; 54 | bg_dct = dct(im); 55 | bg_dct[0:19,0:19] = 0; 56 | 57 | stack_dct[:,:,n] = idct(np.reshape(bg_dct,im.shape)); 58 | 59 | 60 | print('done\n') 61 | roi=np.zeros((Ny,Nx)) 62 | print('registering\n') 63 | good_count = 0; 64 | 65 | for m in range (M): 66 | 67 | corr_im = np.real(fftcorr(stack_dct[:,:,m],ref_im)); 68 | 69 | if np.max(corr_im) < .01: 70 | print('image %i has poor match. 
Skipping\n',m); 71 | else: 72 | 73 | [r,c] =np.unravel_index(np.argmax(corr_im),(2*Ny,2*Nx)) 74 | 75 | si[good_count] = [-(r-pr),-(c-pc)]; 76 | 77 | 78 | W = crop2d(Si(np.logical_not(pad2d(np.logical_not(roi))),-np.array(si[good_count]))); 79 | 80 | bg_estimate = np.sum(np.sum(W*stack[:,:,m]))/np.maximum(np.count_nonzero(roi),1)*0; 81 | im_reg = ref_norm*crop(Si(pad(stack[:,:,m]-bg_estimate),si[good_count])); 82 | 83 | 84 | yi_reg[:,:,good_count] = im_reg; 85 | good_count = good_count + 1; 86 | 87 | 88 | yi_reg = yi_reg[:,:,0:good_count]; 89 | 90 | 91 | print('done registering\n') 92 | 93 | 94 | return yi_reg,si 95 | 96 | def calc_svd(yi_reg,si,rnk): 97 | [Ny, Nx] = yi_reg[:,:,0].shape; 98 | print('creating matrix\n') 99 | Mgood = yi_reg.shape[2]; 100 | ymat = np.zeros((Ny*Nx,Mgood)); 101 | ymat=yi_reg.reshape(( yi_reg.shape[0]* yi_reg.shape[1], yi_reg.shape[2]),order='F') 102 | 103 | print('done\n') 104 | 105 | print('starting svd...\n') 106 | 107 | print('check values of ymat') 108 | [u,s,v] = svds(ymat,rnk); 109 | 110 | 111 | comps = np.reshape(u,[Ny, Nx,rnk],order='F'); 112 | vt = v*1 113 | # s=np.flip(s) 114 | # vt=np.flipud(vt) 115 | weights = np.zeros((Mgood,rnk)); 116 | for m in range (Mgood): 117 | for r in range(rnk): 118 | weights[m,r]=s[r]*vt[r,m] 119 | 120 | 121 | # si_mat = reshape(cell2mat(si)',[2,Mgood]); 122 | xq = np.arange(-Nx/2,Nx/2); 123 | yq = np.arange(-Ny/2,Ny/2); 124 | [Xq, Yq] = np.meshgrid(xq,yq); 125 | 126 | weights_interp = np.zeros((Ny, Nx,rnk)); 127 | xi=[] 128 | yi=[] 129 | si_list=list(si.values()) 130 | 131 | for i in range(len(si_list)): 132 | xi.append(si_list[i][0]) 133 | yi.append(si_list[i][1]) 134 | 135 | print('interpolating...\n') 136 | for r in range(rnk): 137 | # interpolant_r = scatteredInterpolant(si_mat(2,:)', si_mat(1,:)', weights(:,r),'natural','nearest'); 138 | # weights_interp(:,:,r) = rot90(interpolant_r(Xq,Yq),2); 139 | weights_interp[:,:,r]=griddata((xi,yi),weights[:,r],(Xq,Yq),method='nearest') 140 | 141 | 
print('done\n\n') 142 | 143 | return np.flip(comps,-1), np.flip(weights_interp,-1) 144 | 145 | -------------------------------------------------------------------------------- /data/3D_data_simulated/blurred_4cells-Copy1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/data/3D_data_simulated/blurred_4cells-Copy1.mat -------------------------------------------------------------------------------- /data/3D_data_simulated/cellcool (59)flipudlr-Copy1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/data/3D_data_simulated/cellcool (59)flipudlr-Copy1.mat -------------------------------------------------------------------------------- /data/fista3D-cellcool.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/data/fista3D-cellcool.mat -------------------------------------------------------------------------------- /data/fista3D-fourCells.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/data/fista3D-fourCells.mat -------------------------------------------------------------------------------- /data/hydra3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/data/hydra3.jpg -------------------------------------------------------------------------------- /data/real_data/interesting_bear.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/data/real_data/interesting_bear.mat -------------------------------------------------------------------------------- /data/real_data/resTargetZ_1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/data/real_data/resTargetZ_1.mat -------------------------------------------------------------------------------- /pytorch/3D deconvolution demo (pretrained).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import torch\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "%matplotlib inline\n", 16 | "#from ipywidgets import interact, interactive, fixed, interact_manual\n", 17 | "from IPython.display import clear_output\n", 18 | "\n", 19 | "import os, sys, glob, cv2, hdf5storage, time\n", 20 | "import torch.nn as nn\n", 21 | "\n", 22 | "from torchvision import transforms\n", 23 | "import scipy.io\n", 24 | "\n", 25 | "import models.dataset as ds\n", 26 | "import helper as hp\n", 27 | "\n", 28 | "import matplotlib as mpl\n", 29 | "mpl.rc('image', cmap='inferno')\n", 30 | "\n", 31 | "\n", 32 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 33 | "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", 34 | "device = 'cuda:0'\n", 35 | "dtype = torch.cuda.FloatTensor" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "!gpustat" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# MultiWienerNet 3D Deconvolution Demo\n", 52 | "\n", 53 | "In this 
Jupyter Notebook, we take a pretrained MultiWienerNet and demonstrate fast spatially-varying deconvolutions using both simulated and real data. We compare the performance against a pre-trained U-Net, WienerNet (non-spatially-varying), and spatially-varying FISTA. " 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Load in saved models" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# Filepaths to saved models\n", 70 | "multiwiener_file_path='saved_models/trained_multiwiener3D/'\n", 71 | "unet_file_path='saved_models/trained_unet3D/'\n", 72 | "wiener_file_path='saved_models/trained_wiener3D/'" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "unet_model = hp.load_pretrained_model(unet_file_path,model_type = 'unet', device = device)\n", 82 | "wiener_model = hp.load_pretrained_model(wiener_file_path, model_type = 'wiener', device = device)\n", 83 | "multiwiener_model = hp.load_pretrained_model(multiwiener_file_path, model_type = 'multiwiener', device = device)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## Load in data " 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "## CLEAN UP\n", 100 | "down_size = ds.downsize(ds=.75)\n", 101 | "to_tensor = ds.ToTensor()\n", 102 | "add_noise=ds.AddNoise()\n", 103 | "\n", 104 | "filepath_gt = '../data/3D_data_simulated/'\n", 105 | "\n", 106 | "filepath_all=glob.glob(filepath_gt+'*')\n", 107 | "filepath_test=filepath_all\n", 108 | "\n", 109 | "dataset_test = ds.MiniscopeDataset(filepath_test, transform = transforms.Compose([down_size,add_noise,to_tensor]))" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | 
"## Run deconvolution for simulated data" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "### Load in measurement" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "img_ind = 1 # We provide 2 sample images: 0 and 1 \n", 133 | "sample_batched = dataset_test.__getitem__(img_ind)\n", 134 | "meas_np = hp.to_np(sample_batched['meas'])\n", 135 | "sample_batched['meas'] = sample_batched['meas'].unsqueeze(0)\n", 136 | "\n", 137 | "plt.imshow(meas_np);\n", 138 | "plt.title('measurement');\n", 139 | "print('measurement shape:', meas_np.shape)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "### Deconvolve! " 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "t_list = []\n", 156 | "with torch.no_grad():\n", 157 | " t0 = time.time()\n", 158 | " out_unet = unet_model(sample_batched['meas'].repeat(1,1,32,1,1).to(device))\n", 159 | " t_list.append(time.time() - t0)\n", 160 | " \n", 161 | " t0 = time.time()\n", 162 | " out_wiener = wiener_model((sample_batched['meas']).to(device))\n", 163 | " t_list.append(time.time() - t0)\n", 164 | " \n", 165 | " t0 = time.time()\n", 166 | " out_multiwiener = multiwiener_model((sample_batched['meas']).to(device))\n", 167 | " t_list.append(time.time() - t0)\n", 168 | " \n", 169 | "recon_titles = ['Unet', 'WienerNet', 'MultiWienerNet (Ours)']\n", 170 | "recon_list = [out_unet, out_wiener, out_multiwiener]" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "### Plot results" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "gt_np = hp.to_np(sample_batched['im_gt'].unsqueeze(0))\n", 187 | 
"recons_np = []\n", 188 | "for i in range(0,len(recon_list)):\n", 189 | " recons_np.append(hp.to_np(recon_list[i]))\n", 190 | "\n", 191 | "f, ax = plt.subplots(1, 4, figsize=(15,15))\n", 192 | "ax[0].imshow(hp.max_proj(gt_np))\n", 193 | "ax[0].set_title('Ground Truth')\n", 194 | "for i in range(0,len(recons_np)):\n", 195 | " ax[i+1].imshow(hp.max_proj(recons_np[i]))\n", 196 | " ax[i+1].set_title(recon_titles[i])\n", 197 | " \n", 198 | "for i in range(0,len(recons_np)):\n", 199 | " print(recon_titles[i], ': ', np.round(t_list[i],2),'s, PSNR: ', np.round(hp.calc_psnr(gt_np, recons_np[i]),2))" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "out_np = recons_np[-1]\n", 209 | "def plot_slider(x):\n", 210 | " f, ax = plt.subplots(1, 4, figsize=(15,15))\n", 211 | " plt.title('Reconstruction: frame %d'%(x))\n", 212 | " \n", 213 | " ax[0].imshow(gt_np[x],vmin=0, vmax=np.max(gt_np))\n", 214 | " ax[0].set_title('Ground Truth, frame %d'%(x))\n", 215 | " ax[0].axis('off')\n", 216 | " for i in range(0,len(recons_np)):\n", 217 | " ax[i+1].imshow(recons_np[i][x], vmin=0, vmax=np.max(recons_np[i]))\n", 218 | " ax[i+1].set_title(recon_titles[i])\n", 219 | " ax[i+1].axis('off')\n", 220 | " \n", 221 | " return x\n", 222 | "\n", 223 | "\n", 224 | "interactive(plot_slider,x=(0,out_np.shape[0]-1,1))" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": {}, 230 | "source": [ 231 | "### Compare against spatially-varying FISTA" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "#compare to fista\n", 241 | "saved_fista = [ 'fista3D-fourCells.mat', 'fista3D-cellcool.mat',]\n", 242 | "\n", 243 | "Ifista=scipy.io.loadmat('../data/' + saved_fista[img_ind])\n", 244 | "Ifista=Ifista['xhat_out']\n", 245 | "Ifista=Ifista.transpose([2,0,1])/np.max(Ifista)\n", 246 | "\n", 247 | "f, 
ax = plt.subplots(1, 2, figsize=(10,5))\n", 248 | "ax[0].imshow(hp.max_proj(Ifista))\n", 249 | "ax[0].set_title('FISTA result')\n", 250 | "ax[1].imshow(hp.max_proj(recons_np[-1]))\n", 251 | "ax[1].set_title(recon_titles[-1])\n", 252 | "\n", 253 | "print('FISTA PSNR: ', np.round(hp.calc_psnr(gt_np, Ifista),2))" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "## Run deconvolution for real data" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "img_ind = 0 # 0: resolution target, 1: waterbear\n", 270 | "\n", 271 | "loaded_meas = glob.glob('../data/real_data/*')\n", 272 | "meas_loaded = scipy.io.loadmat(loaded_meas[img_ind])['b']" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "meas_loaded.shape" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "meas=meas_loaded[18:466,4:644]\n", 291 | "meas= cv2.resize(meas, (0,0), fx=0.75, fy=0.75) \n", 292 | "meas_tensor=torch.tensor(meas, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(0)\n", 293 | "plt.imshow(meas)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "with torch.no_grad():\n", 303 | " meas_t = meas_tensor.repeat(1,1,32,1,1)\n", 304 | " out_unet = unet_model(meas_t.to(device))\n", 305 | " out_wiener = wiener_model((meas_t).to(device))\n", 306 | " out_multiwiener = multiwiener_model((meas_t).to(device))\n", 307 | " \n", 308 | " recon_titles = ['Unet', 'WienerNet', 'MultiWienerNet (Ours)']\n", 309 | " recon_list = [out_unet, out_wiener, out_multiwiener]" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | 
"metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "\n", 319 | "with torch.no_grad():\n", 320 | " out_wiener = wiener_model.wiener_model(meas_t.to(device))\n", 321 | "\n", 322 | " out_multiwiener = multiwiener_model.wiener_model(meas_t.to(device))\n", 323 | "\n", 324 | " \n", 325 | "plt.imshow(out_multiwiener[0,4,0].detach().cpu().numpy()); plt.colorbar()\n" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "recons_np = []\n", 335 | "for i in range(0,len(recon_list)):\n", 336 | " recons_np.append(hp.to_np(recon_list[i]))\n", 337 | "\n", 338 | "f, ax = plt.subplots(1, 3, figsize=(15,15))\n", 339 | "for i in range(0,len(recons_np)):\n", 340 | " if img_ind == 0:\n", 341 | " ax[i].imshow(recons_np[i][1])\n", 342 | " else:\n", 343 | " ax[i].imshow(hp.max_proj(recons_np[i]))\n", 344 | " ax[i].set_title(recon_titles[i])" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "def plot_slider(x):\n", 354 | " f, ax = plt.subplots(1, 3, figsize=(15,15))\n", 355 | " plt.title('Reconstruction: frame %d'%(x))\n", 356 | " \n", 357 | " for i in range(0,len(recons_np)):\n", 358 | " ax[i].imshow(recons_np[i][x], vmin=0, vmax=np.max(recons_np[i]))\n", 359 | " ax[i].axis('off')\n", 360 | " \n", 361 | " if i ==0:\n", 362 | " ax[i].set_title('Unet, frame %d'%(x))\n", 363 | " else:\n", 364 | " ax[i].set_title(recon_titles[i])\n", 365 | " \n", 366 | " return x\n", 367 | "\n", 368 | "\n", 369 | "interactive(plot_slider,x=(0,out_np.shape[0]-1,1))" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "metadata": {}, 375 | "source": [ 376 | "## Run deconvolution movie for real data" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | 
"waterbear=hdf5storage.loadmat('/media/lahvahndata/Kyrollos/LearnedMiniscope3D/real_data/waterbear_all.mat') \n", 386 | "waterbear=waterbear['b']\n", 387 | "waterbear=(waterbear)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "waterbear.shape" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "meas=waterbear[18:466,4:644,:]\n", 406 | "meas= cv2.resize(meas, (0,0), fx=0.75, fy=0.75) \n", 407 | "meas_tensor=torch.tensor(meas, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(0)\n", 408 | "plt.imshow(meas[...,0])" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "def plot_slider(x):\n", 418 | " plt.title('Reconstruction: frame %d'%(x))\n", 419 | " plt.axis('off')\n", 420 | " plt.imshow(meas[...,x])\n", 421 | " return x\n", 422 | "\n", 423 | "\n", 424 | "interactive(plot_slider,x=(0,meas.shape[-1]-1,1))" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "\n", 434 | "out_bear_xy=[]\n", 435 | "out_bear_yz=[]\n", 436 | "for t in range(30):\n", 437 | " \n", 438 | " print('processing image: ', t, end='\\r')\n", 439 | " with torch.no_grad():\n", 440 | " out_waterbear=multiwiener_model(meas_tensor[...,t]) #.repeat(1,1,32,1,1)\n", 441 | " out_waterbear_np = out_waterbear.detach().cpu().numpy()[0,0]\n", 442 | " \n", 443 | " out_bear_xy.append(np.max(out_waterbear_np,0))\n", 444 | " out_bear_yz.append(np.max(out_waterbear_np,2))\n", 445 | " \n", 446 | " \n", 447 | "# plt.imshow(out_bear_xy[-1])\n", 448 | "# plt.title(t)\n", 449 | "# plt.show()\n", 450 | "# clear_output(wait=True)\n" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 
null, 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [ 459 | "out_bear_xy=np.array(out_bear_xy)\n", 460 | "out_bear_yz=np.array(out_bear_yz)\n", 461 | "# test=test.transpose([1,2,0])" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "def plot_slider(x):\n", 471 | " f, ax = plt.subplots(1, 3, figsize=(15,3))\n", 472 | " \n", 473 | " \n", 474 | " ax[0].imshow(meas[...,x], vmin=0, vmax=np.max(meas))\n", 475 | " ax[1].imshow(out_bear_xy[x], vmin=0, vmax=np.max(out_bear_xy))\n", 476 | " ax[2].imshow(out_bear_yz[x].transpose(), vmin=0, vmax=np.max(out_bear_yz))\n", 477 | " \n", 478 | " ax[0].set_title('Measurement')\n", 479 | " ax[1].set_title('Reconstruction: frame %d'%(x))\n", 480 | " \n", 481 | " ax[0].axis('off')\n", 482 | " ax[1].axis('off')\n", 483 | " ax[2].axis('off')\n", 484 | " \n", 485 | " \n", 486 | " return x\n", 487 | "\n", 488 | "\n", 489 | "interactive(plot_slider,x=(0,out_bear_xy.shape[0]-1,1))" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "## Visualize Learned PSFs" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "learned_psfs_wiener_np=wiener_model.wiener_model.psfs.detach().cpu().numpy()" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": {}, 512 | "outputs": [], 513 | "source": [ 514 | "def plot_slider(x):\n", 515 | " plt.title('Reconstruction: frame %d'%(x))\n", 516 | " plt.axis('off')\n", 517 | " plt.imshow(learned_psfs_wiener_np[x])\n", 518 | " return x\n", 519 | "\n", 520 | "\n", 521 | "interactive(plot_slider,x=(0,learned_psfs_wiener_np.shape[0]-1,1))" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [ 530 | 
"learned_psfs_np=multiwiener_model.wiener_model.psfs.detach().cpu().numpy()\n", 531 | "learned_Ks_np=multiwiener_model.wiener_model.Ks.detach().cpu().numpy()" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": null, 537 | "metadata": {}, 538 | "outputs": [], 539 | "source": [ 540 | "def plot_slider(x):\n", 541 | " plt.title('Reconstruction: frame %d'%(x))\n", 542 | " plt.axis('off')\n", 543 | " plt.imshow(learned_psfs_np[4][x])\n", 544 | " return x\n", 545 | "\n", 546 | "\n", 547 | "interactive(plot_slider,x=(0,learned_psfs_np.shape[1]-1,1))" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": null, 553 | "metadata": {}, 554 | "outputs": [], 555 | "source": [ 556 | "x=20\n", 557 | "plt.imshow(np.abs(learned_psfs_np[8][x]-learned_psfs_np[0][x])); plt.colorbar()" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [] 566 | } 567 | ], 568 | "metadata": { 569 | "kernelspec": { 570 | "display_name": "multiwiener_torch3", 571 | "language": "python", 572 | "name": "multiwiener_torch3" 573 | }, 574 | "language_info": { 575 | "codemirror_mode": { 576 | "name": "ipython", 577 | "version": 3 578 | }, 579 | "file_extension": ".py", 580 | "mimetype": "text/x-python", 581 | "name": "python", 582 | "nbconvert_exporter": "python", 583 | "pygments_lexer": "ipython3", 584 | "version": "3.8.12" 585 | } 586 | }, 587 | "nbformat": 4, 588 | "nbformat_minor": 4 589 | } 590 | -------------------------------------------------------------------------------- /pytorch/debug.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import torch, torch.optim\n", 11 | "import torch.nn.functional as F\n", 12 | "torch.backends.cudnn.enabled = True\n", 13 | 
"torch.backends.cudnn.benchmark =True\n", 14 | "dtype = torch.cuda.FloatTensor\n", 15 | "import os, sys, json, glob\n", 16 | "from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM\n", 17 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 18 | "\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "\n", 21 | "import random\n", 22 | "\n", 23 | "import skimage.io\n", 24 | "import torch.nn as nn\n", 25 | "import argparse\n", 26 | "\n", 27 | "from torch.utils.data import Dataset, DataLoader\n", 28 | "from torchvision import transforms, utils\n", 29 | "import cv2\n", 30 | "import models.wiener_model as wm\n", 31 | "import models.dataset as ds\n", 32 | "from PIL import Image\n", 33 | "import helper as hp\n", 34 | "\n", 35 | "import scipy.io" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "parser = argparse.ArgumentParser(description='Process some integers.')\n", 45 | "parser.add_argument('--data_type', default='2D')\n", 46 | "parser.add_argument('--network', default='multiwiener') #'wiener' or 'unet' or 'multiwiener'\n", 47 | "parser.add_argument('--id', default='new_unet2') #some identifier\n", 48 | "parser.add_argument('--loss_type', default='l1') \n", 49 | "parser.add_argument('--device', default='0') \n", 50 | "parser.add_argument('--psf_num', default=9, type=int)\n", 51 | "parser.add_argument('--psf_ds', default=0.75, type=float)\n", 52 | "parser.add_argument('--epochs', default=10000, type=int)\n", 53 | "parser.add_argument('--lr', default=1e-4, type=float) \n", 54 | "parser.add_argument('--batch_size', default=4, type=int) \n", 55 | "parser.add_argument('--load_path',default=None)\n", 56 | "parser.add_argument('--save_checkponts',default=True)\n", 57 | "\n", 58 | "#args = parser.parse_args()\n", 59 | "args = parser.parse_args(''.split())\n", 60 | "\n", 61 | "os.environ['CUDA_VISIBLE_DEVICES'] = args.device" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | 
"execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "choosing 9 psfs\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "\n", 79 | "\n", 80 | "# for 3D-UNet multiwiener\n", 81 | "registered_psfs_path = '../data/multiWienerPSFStack_40z_aligned.mat'\n", 82 | "psfs = scipy.io.loadmat(registered_psfs_path)\n", 83 | "psfs=psfs['multiWienerPSFStack_40z']\n", 84 | "\n", 85 | "if args.data_type == '3D':\n", 86 | " if args.network=='wiener' or args.network=='unet':\n", 87 | " psfs=hp.pre_process_psfs(psfs)[:,:,4]\n", 88 | " Ks=np.ones((32,1,1))\n", 89 | " print('choosing 1 psfs')\n", 90 | "\n", 91 | " elif args.network=='multiwiener':\n", 92 | " Ks=np.ones((args.psf_num,32,1,1))\n", 93 | " if args.psf_num==9:\n", 94 | " print('choosing 9 psfs')\n", 95 | " psfs=hp.pre_process_psfs(psfs)\n", 96 | " else:\n", 97 | " print('invalid network')\n", 98 | " psfs = hp.downsize_psf(psfs)\n", 99 | "else: #2D\n", 100 | " if args.network=='wiener' or args.network=='unet':\n", 101 | " psfs=hp.pre_process_psfs_2d(psfs)[:,:,4, 0]\n", 102 | " Ks= 1.\n", 103 | " print('choosing 1 psfs')\n", 104 | "\n", 105 | " elif args.network=='multiwiener':\n", 106 | " Ks=np.ones((args.psf_num,1,1))\n", 107 | " if args.psf_num==9:\n", 108 | " print('choosing 9 psfs')\n", 109 | " psfs=hp.pre_process_psfs_2d(psfs)[...,0]\n", 110 | " psfs = psfs.transpose(2,0,1)\n", 111 | " else:\n", 112 | " print('invalid network')\n", 113 | "\n", 114 | " \n", 115 | "down_size = ds.downsize(ds=args.psf_ds)\n", 116 | "to_tensor = ds.ToTensor()\n", 117 | "add_noise=ds.AddNoise()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 6, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "total number of images 22126\n", 130 | "training images: 17700 testing images: 4426\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "\n", 136 | "\n", 137 | "if 
args.data_type == '3D':\n", 138 | " filepath_gt = '/home/kyrollos/LearnedMiniscope3D/Data3D/Training_data_all/' \n", 139 | "else:\n", 140 | " filepath_gt = '/home/kyrollos/LearnedMiniscope3D/Data/Target/'\n", 141 | " filepath_meas = '/home/kyrollos/LearnedMiniscope3D/Data/Train/'\n", 142 | "\n", 143 | "\n", 144 | "filepath_all=glob.glob(filepath_gt+'*')\n", 145 | "random.Random(8).shuffle(filepath_all)\n", 146 | "print('total number of images',len(filepath_all))\n", 147 | "total_num_images = len(filepath_all)\n", 148 | "num_test = 0.2 # 20% test\n", 149 | "filepath_train=filepath_all[0:int(total_num_images*(1-num_test))]\n", 150 | "filepath_test=filepath_all[int(total_num_images*(1-num_test)):]\n", 151 | "\n", 152 | "print('training images:', len(filepath_train), \n", 153 | " 'testing images:', len(filepath_test))\n", 154 | "\n", 155 | "if args.data_type == '3D':\n", 156 | " dataset_train = ds.MiniscopeDataset(filepath_train, transform = transforms.Compose([down_size,add_noise,to_tensor]))\n", 157 | " dataset_test = ds.MiniscopeDataset(filepath_test, transform = transforms.Compose([down_size,add_noise,to_tensor]))\n", 158 | "else:\n", 159 | " dataset_train = ds.MiniscopeDataset_2D(filepath_train, filepath_meas, transform = transforms.Compose([ds.crop2d(),ds.ToTensor2d()]))\n", 160 | " dataset_test = ds.MiniscopeDataset_2D(filepath_test, filepath_meas, transform = transforms.Compose([ds.crop2d(),ds.ToTensor2d()]))\n", 161 | "\n", 162 | "\n", 163 | "dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size,\n", 164 | " shuffle=True, num_workers=1)\n", 165 | "\n", 166 | "dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size,\n", 167 | " shuffle=False, num_workers=1)\n", 168 | "\n", 169 | "device = 'cuda:0'\n", 170 | "\n", 171 | "if args.data_type == '3D':\n", 172 | " from models.unet3d import Unet\n", 173 | " unet_model = Unet(n_channel_in=args.psf_num, n_channel_out=1, residual=False, down='conv', up='tconv', 
activation='selu').to(device)\n", 174 | "\n", 175 | " if args.network == 'multiwiener' or args.network == 'wiener':\n", 176 | " wiener_model=wm.WienerDeconvolution3D(psfs,Ks).to(device)\n", 177 | " model=wm.MyEnsemble(wiener_model,unet_model)\n", 178 | " else:\n", 179 | " model = unet_model\n", 180 | "else: #2D\n", 181 | " from models.unet2 import UNet\n", 182 | " if args.network == 'multiwiener':\n", 183 | " num_in_channels = args.psf_num\n", 184 | " else:\n", 185 | " num_in_channels = 1\n", 186 | " \n", 187 | " #unet_model = Unet(n_channel_in=num_in_channels, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu').to(device)\n", 188 | " unet_model = UNet(n_channels=num_in_channels, n_classes=1).to(device)\n", 189 | "\n", 190 | " if args.network == 'multiwiener' or args.network == 'wiener':\n", 191 | " wiener_model=wm.WienerDeconvolution3D(psfs,Ks).to(device)\n", 192 | " model=wm.MyEnsemble(wiener_model,unet_model)\n", 193 | " else:\n", 194 | " model = unet_model\n" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "if args.load_path is not None:\n", 204 | " model.load_state_dict(torch.load('saved_data/'+args.load_path,map_location=torch.device(device)))\n", 205 | " print('loading saved model')\n", 206 | "\n", 207 | "\n", 208 | "loss_fn = torch.nn.L1Loss()\n", 209 | "ssim_loss = SSIM(win_size=11, win_sigma=1.5, data_range=1, size_average=True, channel=1)\n", 210 | "\n", 211 | "optimizer = torch.optim.Adam(model.parameters(), lr = args.lr)\n", 212 | "\n", 213 | "if args.save_checkponts == True:\n", 214 | " filepath_save = 'saved_data/' +\"_\".join((list(vars(args).values()))[0:5]) + \"/\"\n", 215 | "\n", 216 | " if not os.path.exists(filepath_save):\n", 217 | " os.makedirs(filepath_save)\n", 218 | "\n", 219 | " with open(filepath_save + 'args.json', 'w') as fp:\n", 220 | " json.dump(vars(args), fp)\n", 221 | " \n" 222 | ] 223 | }, 224 | { 225 | 
"cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "name": "stderr", 231 | "output_type": "stream", 232 | "text": [ 233 | "/home/kyrollos/anaconda3/envs/torch_anaconda/lib/python3.6/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448233824/work/c10/core/TensorImpl.h:1156.)\n", 234 | " return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" 235 | ] 236 | }, 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "loss for testing set 0 1106 78.378226496279242976\n", 242 | "loss for testing set 1 1106 71.941382549703122068\n", 243 | "loss for testing set 2 1106 74.401428252458570695\n", 244 | "loss for testing set 3 1106 71.045484699308877864\n", 245 | "epoch: 4 batch: 1800 loss: 0.108448676764965064\r" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "best_loss=27e7\n", 251 | "\n", 252 | "for itr in range(0,args.epochs):\n", 253 | " for i_batch, sample_batched in enumerate(dataloader_train):\n", 254 | " optimizer.zero_grad()\n", 255 | " #sample_batched['meas'] = sample_batched['meas']*255.\n", 256 | " #out = model(sample_batched['meas'].repeat(1,32,1,1)[...,18:466,4:644].unsqueeze(0).to(device))\n", 257 | " if args.network=='unet' and args.data_type == '3D':\n", 258 | " out = model(sample_batched['meas'].repeat(1,1,32,1,1).to(device))\n", 259 | " else:\n", 260 | " out = model(sample_batched['meas'].to(device))\n", 261 | "\n", 262 | " if args.loss_type=='l1':\n", 263 | " loss = loss_fn(out, sample_batched['im_gt'].to(device))\n", 264 | " else:\n", 265 | " #loss = loss_fn(out, sample_batched['im_gt'].to(device))+ (1- ms_ssim( out[0], sample_batched['im_gt'][0].to(device), data_range=1, size_average=False))\n", 266 | " 
\n", 267 | " loss = loss_fn(out, sample_batched['im_gt'].to(device)) + (1-ssim_loss(out, sample_batched['im_gt'].to(device)))\n", 268 | " loss.backward()\n", 269 | " optimizer.step()\n", 270 | " \n", 271 | " model.wiener_model.Ks.data = model.wiener_model.Ks.data.clamp(0,1e9)\n", 272 | " \n", 273 | " if i_batch %100 ==0:\n", 274 | " print('epoch: ', itr, ' batch: ', i_batch, ' loss: ', loss.item(), end='\\r')\n", 275 | "\n", 276 | " #break \n", 277 | " if args.data_type == '3D':\n", 278 | " out_np = np.max(out.detach().cpu().numpy()[0,0],0)\n", 279 | " gt_np = np.max(sample_batched['im_gt'].detach().cpu().numpy()[0,0],0)\n", 280 | " meas_np = np.max(sample_batched['meas'].detach().cpu().numpy()[0,0],0)\n", 281 | " else:\n", 282 | " out_np = out.detach().cpu().numpy()[0][0]\n", 283 | " gt_np = sample_batched['im_gt'].detach().cpu().numpy()[0][0]\n", 284 | " meas_np = sample_batched['meas'].detach().cpu().numpy()[0][0]\n", 285 | "\n", 286 | "\n", 287 | " if args.save_checkponts == True:\n", 288 | " torch.save(model.state_dict(), filepath_save + 'model_noval.pt')\n", 289 | " \n", 290 | " if itr%1==0:\n", 291 | " total_loss=0\n", 292 | " for i_batch, sample_batched in enumerate(dataloader_test):\n", 293 | " sample_batched['meas'] = sample_batched['meas']*255.\n", 294 | " with torch.no_grad():\n", 295 | " if args.network=='unet' and args.data_type == '3D':\n", 296 | " out = model(sample_batched['meas'].repeat(1,1,32,1,1).to(device))\n", 297 | " else:\n", 298 | " out = model(sample_batched['meas'].to(device))\n", 299 | " if args.loss_type=='l1':\n", 300 | " loss = loss_fn(out, sample_batched['im_gt'].to(device))\n", 301 | " else:\n", 302 | " loss = loss_fn(out, sample_batched['im_gt'].to(device)) + (1-ssim_loss(out, sample_batched['im_gt'].to(device)))\n", 303 | " \n", 304 | " #loss = loss_fn(out, sample_batched['im_gt'].to(device))+(1- ms_ssim( out, sample_batched['im_gt'][0].to(device), data_range=1, size_average=False))\n", 305 | " \n", 306 | " \n", 307 | " 
total_loss+=loss.item()\n", 308 | " \n", 309 | " print('loss for testing set ',itr,' ',i_batch, total_loss)\n", 310 | " \n", 311 | " #break\n", 312 | " \n", 313 | " if args.save_checkponts == True:\n", 314 | " im_gt = Image.fromarray((np.clip(gt_np/np.max(gt_np),0,1)*255).astype(np.uint8))\n", 315 | " im = Image.fromarray((np.clip(out_np/np.max(out_np),0,1)*255).astype(np.uint8))\n", 316 | " im.save(filepath_save + str(itr) + '.png')\n", 317 | " im_gt.save(filepath_save + 'gt.png')\n", 318 | " \n", 319 | " \n", 320 | " if total_loss=0.5 9 | - python=3.8 10 | - pip 11 | - nb_conda 12 | - numpy=1.18 13 | - matplotlib 14 | - scipy 15 | - opencv 16 | - imageio 17 | - ipykernel 18 | - ipython 19 | - scikit-image 20 | - hdf5storage 21 | - ipywidgets 22 | 23 | - pip: 24 | - pytorch-msssim==0.2.1 25 | - tifffile 26 | 27 | -------------------------------------------------------------------------------- /pytorch/helper.py: -------------------------------------------------------------------------------- 1 | import argparse, json, math 2 | import scipy.io 3 | import numpy as np 4 | import cv2 5 | #import models.wiener_model as wm 6 | import torch 7 | 8 | def to_np(x): 9 | x = x.detach().cpu().numpy()[0,0] 10 | x=x/np.max(x) 11 | return x 12 | 13 | def max_proj(x, axis = 0): 14 | return np.max(x,axis) 15 | 16 | def mean_proj(x, axis = 0): 17 | return np.mean(x,axis) 18 | 19 | def calc_psnr(Iin,Itarget): 20 | 21 | mse=np.mean(np.square(Iin-Itarget)) 22 | return 10*math.log10(1/mse) 23 | 24 | def load_saved_args(model_file_path): 25 | parser = argparse.ArgumentParser(description='Process some integers.') 26 | parser.add_argument('--data_type', default='3D') 27 | parser.add_argument('--num_images', default='multi') #'single' or 'multi 28 | parser.add_argument('--network', default='combined') #'combined' or 'unet' 29 | parser.add_argument('--device', default='0') 30 | parser.add_argument('--epochs', default=10000, type=int) 31 | 32 | args = parser.parse_args("--device 1".split()) 
33 | 34 | with open(model_file_path+'args.json', "r") as f: 35 | args.__dict__=json.load(f) 36 | return args 37 | 38 | def initialize_psfs(model_file_path): 39 | 40 | args = load_saved_args(model_file_path) 41 | # for 3D-UNet multiwiener 42 | registered_psfs_path = '../data/multiWienerPSFStack_40z_aligned.mat' 43 | psfs = scipy.io.loadmat(registered_psfs_path) 44 | psfs=psfs['multiWienerPSFStack_40z'] 45 | if args.data_type == '3D': 46 | if args.network=='wiener' or args.network=='unet': 47 | psfs=psfs[18:466,4:644,4,0:32] 48 | Ks=np.ones((32,1,1)) 49 | 50 | elif args.network=='multiwiener': 51 | Ks=np.ones((args.psf_num,32,1,1)) 52 | if args.psf_num==9: 53 | print('choosing 9 psfs') 54 | psfs=psfs[18:466,4:644,:,0:32] 55 | 56 | elif args.psf_num==4: 57 | print('choosing 4 psfs') 58 | psfs=psfs[18:466,4:644,2:6,0:32] 59 | else: 60 | print('invalid psf num') 61 | 62 | psfs_ds=np.zeros((int(psfs.shape[0]*args.psf_ds),int(psfs.shape[1]*args.psf_ds),*psfs.shape[2:])) 63 | 64 | if args.psf_num>1: 65 | for p in range(psfs.shape[2]): 66 | psfs_ds[:,:,p,:]=cv2.resize(psfs[:,:,p], (0,0), fx=args.psf_ds, fy=args.psf_ds) 67 | psfs_ds=np.transpose(psfs_ds,[2,3,0,1]) 68 | 69 | else: 70 | psfs_ds=cv2.resize(psfs, (0,0), fx=args.psf_ds, fy=args.psf_ds) 71 | 72 | psfs_ds=np.transpose(psfs_ds,[2,0,1]) 73 | 74 | psfs_ds=psfs_ds/np.max(psfs_ds) 75 | psfs=psfs_ds 76 | 77 | else: #2D 78 | if args.network=='wiener' or args.network=='unet': 79 | psfs=pre_process_psfs_2d(psfs)[:,:,4, 0] 80 | Ks= 1. 
81 | print('choosing 1 psfs') 82 | 83 | elif args.network=='multiwiener': 84 | Ks=np.ones((args.psf_num,1,1)) 85 | if args.psf_num==9: 86 | print('choosing 9 psfs') 87 | psfs=pre_process_psfs_2d(psfs)[...,0] 88 | psfs = psfs.transpose(2,0,1) 89 | else: 90 | print('invalid network') 91 | 92 | 93 | psfs = psfs/np.max(psfs) 94 | 95 | return psfs,Ks, args 96 | 97 | def load_pretrained_model(filepath, model_type = 'unet', device = 'cuda:0'): 98 | from models.unet3d import Unet 99 | import models.wiener_model as wm 100 | if model_type == 'unet': 101 | model = Unet(n_channel_in=1, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu').to(device) 102 | 103 | elif model_type == 'wiener': 104 | psfs_wiener,Ks_wiener, args_wiener = initialize_psfs(filepath) 105 | unet_model = Unet(n_channel_in=1, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu').to(device) 106 | wiener_stage=wm.WienerDeconvolution3D(psfs_wiener,Ks_wiener).to(device) 107 | model=wm.MyEnsemble(wiener_stage,unet_model) 108 | 109 | elif model_type == 'multiwiener': 110 | psfs_multiwiener,Ks_multiwiener, args_multi= initialize_psfs(filepath) 111 | unet_stage = Unet(n_channel_in=9, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu').to(device) 112 | multiwiener_stage=wm.WienerDeconvolution3D(psfs_multiwiener,Ks_multiwiener).to(device) 113 | model=wm.MyEnsemble(multiwiener_stage,unet_stage) 114 | 115 | model.load_state_dict(torch.load(filepath+'model.pt',map_location=torch.device(device))) 116 | return model 117 | 118 | def load_pretrained_model_2d(filepath, model_type = 'unet', device = 'cuda:0', load_model = True): 119 | from models.unet import Unet 120 | import models.wiener_model as wm 121 | if model_type == 'unet': 122 | model = Unet(n_channel_in=1, n_channel_out=1).to(device)#, residual=False, down='conv', up='nearest', activation='relu').to(device) 123 | 124 | elif model_type == 'wiener': 125 | psfs_wiener,Ks_wiener, args_wiener = 
initialize_psfs(filepath) 126 | #unet_model = Unet(n_channel_in=1, n_channel_out=1).to(device)#, residual=False, down='conv', up='nearest',activation='relu').to(device) 127 | 128 | unet_model = Unet(n_channel_in=1, n_channel_out=1, residual=True, down='conv', up='nearest',activation='relu').to(device) 129 | wiener_stage=wm.WienerDeconvolution3D(psfs_wiener,Ks_wiener).to(device) 130 | model=wm.MyEnsemble2d(wiener_stage,unet_model) 131 | 132 | elif model_type == 'multiwiener': 133 | psfs_multiwiener,Ks_multiwiener, args_multi= initialize_psfs(filepath) 134 | #unet_stage = Unet(n_channel_in=9, n_channel_out=1).to(device)#, residual=False, down='conv', up='nearest',activation='relu').to(device) 135 | unet_stage = Unet(n_channel_in=9, n_channel_out=1, residual=True, down='conv', up='nearest',activation='relu').to(device) 136 | multiwiener_stage=wm.WienerDeconvolution3D(psfs_multiwiener,Ks_multiwiener).to(device) 137 | model=wm.MyEnsemble2d(multiwiener_stage,unet_stage) 138 | 139 | if load_model == True: 140 | model.load_state_dict(torch.load(filepath+'model.pt',map_location=torch.device(device))) 141 | return model 142 | 143 | def pre_process_psfs(x): 144 | # Use this to make the image size a power of 2 for the network. 145 | # Change these numbers according to your image size 146 | x = x[18:466,4:644,:,0:32] 147 | return x 148 | def pre_process_psfs_2d(x): 149 | # Use this to make the image size a power of 2 for the network. 
150 | # Change these numbers according to your image size 151 | x = x[18:466,4:644] 152 | return x 153 | 154 | def downsize_psf(psfs): 155 | psfs_ds=np.zeros((int(psfs.shape[0]*args.psf_ds),int(psfs.shape[1]*args.psf_ds),*psfs.shape[2:])) 156 | if args.psf_num>1: 157 | for p in range(psfs.shape[2]): 158 | psfs_ds[:,:,p,:]=cv2.resize(psfs[:,:,p], (0,0), fx=args.psf_ds, fy=args.psf_ds) 159 | psfs_ds=np.transpose(psfs_ds,[2,3,0,1]) 160 | else: 161 | psfs_ds=cv2.resize(psfs, (0,0), fx=args.psf_ds, fy=args.psf_ds) 162 | psfs_ds=np.transpose(psfs_ds,[2,0,1]) 163 | 164 | psfs_ds=psfs_ds/np.max(psfs_ds) 165 | return psfs_ds -------------------------------------------------------------------------------- /pytorch/models/dataset.py: -------------------------------------------------------------------------------- 1 | import skimage.io 2 | import glob 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision import transforms, utils 5 | import cv2 6 | import torch 7 | import numpy as np 8 | import scipy.io 9 | 10 | class MiniscopeDataset(Dataset): 11 | """Face Landmarks dataset.""" 12 | 13 | def __init__(self, all_files, transform=None): 14 | 15 | 16 | self.all_files_gt = all_files 17 | 18 | 19 | self.transform = transform 20 | 21 | def __len__(self): 22 | return len(self.all_files_gt) 23 | 24 | def __getitem__(self, idx): 25 | 26 | im_all = scipy.io.loadmat(self.all_files_gt[idx]) 27 | 28 | im_gt = im_all['gt'] 29 | 30 | 31 | im_meas = im_all['meas'] 32 | 33 | sample = {'im_gt': im_gt.astype('float32')/np.max(im_gt), 'meas': im_meas.astype('float32')/np.max(im_meas)} 34 | 35 | 36 | if self.transform: 37 | sample = self.transform(sample) 38 | 39 | return sample 40 | 41 | 42 | class MiniscopeDataset_2D(Dataset): 43 | """Face Landmarks dataset.""" 44 | 45 | def __init__(self, all_files, filepath_meas, transform=None): 46 | 47 | 48 | self.all_files_gt = all_files 49 | self.filepath_meas = filepath_meas 50 | 51 | self.transform = transform 52 | 53 | def 
__len__(self): 54 | return len(self.all_files_gt) 55 | 56 | def __getitem__(self, idx): 57 | 58 | im_gt = skimage.io.imread(self.all_files_gt[idx]) 59 | im_meas = skimage.io.imread(self.filepath_meas+self.all_files_gt[idx].split('/')[-1]) 60 | 61 | sample = {'im_gt': im_gt.astype('float32')/255., 'meas': im_meas.astype('float32')/255.} 62 | 63 | 64 | if self.transform: 65 | sample = self.transform(sample) 66 | 67 | return sample 68 | 69 | class MiniscopeDataset_backup(Dataset): 70 | """Face Landmarks dataset.""" 71 | 72 | def __init__(self, filepath_meas, filepath_gt, transform=None): 73 | 74 | self.filepath_meas = filepath_meas 75 | self.filepath_gt = filepath_gt 76 | self.all_files_gt = glob.glob(filepath_gt + '*.tiff') 77 | 78 | self.transform = transform 79 | 80 | def __len__(self): 81 | return len(self.all_files_gt) 82 | 83 | def __getitem__(self, idx): 84 | 85 | im_gt = skimage.io.imread(self.filepath_gt + str(idx) + '.tiff') 86 | im_meas = skimage.io.imread(self.filepath_meas + str(idx) + '.png') 87 | 88 | sample = {'im_gt': im_gt.astype('float32')/255., 'meas': im_meas.astype('float32')/255.} 89 | 90 | 91 | if self.transform: 92 | sample = self.transform(sample) 93 | 94 | return sample 95 | class crop2d(object): 96 | """Convert ndarrays in sample to Tensors.""" 97 | 98 | def __call__(self, sample): 99 | 100 | im_gt, meas = sample['im_gt'], sample['meas'] 101 | meas=meas[18:466,4:644] 102 | im_gt=im_gt[18:466,4:644] 103 | return {'im_gt': im_gt, 104 | 'meas': meas} 105 | 106 | class downsize(object): 107 | """Convert ndarrays in sample to Tensors.""" 108 | 109 | def __init__(self, ds=0.5): 110 | self.ds=ds 111 | 112 | 113 | def __call__(self, sample): 114 | 115 | im_gt, meas = sample['im_gt'], sample['meas'] 116 | meas=meas[18:466,4:644] 117 | meas= cv2.resize(meas, (0,0), fx=self.ds, fy=self.ds) 118 | im_gt=im_gt[18:466,4:644] 119 | im_gt= cv2.resize(im_gt, (0,0), fx=self.ds, fy=self.ds) 120 | im_gt=im_gt.transpose([2,0,1]) 121 | return {'im_gt': 
im_gt/np.max(im_gt), 122 | 'meas': meas/np.max(meas)} 123 | class ToTensor2d(object): 124 | """Convert ndarrays in sample to Tensors.""" 125 | 126 | def __call__(self, sample): 127 | im_gt, meas = sample['im_gt'], sample['meas'] 128 | 129 | return {'im_gt': torch.from_numpy(im_gt).unsqueeze(0), 130 | 'meas': torch.from_numpy(meas).unsqueeze(0)} 131 | class ToTensor(object): 132 | """Convert ndarrays in sample to Tensors.""" 133 | 134 | def __call__(self, sample): 135 | im_gt, meas = sample['im_gt'], sample['meas'] 136 | 137 | return {'im_gt': torch.from_numpy(im_gt).unsqueeze(0), 138 | 'meas': torch.from_numpy(meas).unsqueeze(0).unsqueeze(0)} 139 | 140 | 141 | class AddNoise(object): 142 | """adds noise""" 143 | 144 | def __call__(self, sample): 145 | mu=0 146 | sigma=np.random.rand(1)*0.02+0.005 #abit much maybe 0.04 best0.04+0.01 147 | PEAK=np.random.rand(1)*4500+500 148 | p_noise = np.random.poisson(sample['meas'] * PEAK)/PEAK 149 | 150 | g_noise= np.random.normal(mu, sigma, sample['meas'].shape) 151 | sample['meas']=(g_noise+p_noise).astype('float32') 152 | sample['meas']=sample['meas']/np.max(sample['meas']) 153 | sample['meas']=np.maximum(sample['meas'],0) #rethink negative noise 154 | 155 | 156 | return sample -------------------------------------------------------------------------------- /pytorch/models/dataset_tiff.py: -------------------------------------------------------------------------------- 1 | import skimage.io 2 | import glob 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision import transforms, utils 5 | import cv2 6 | import torch 7 | import numpy as np 8 | 9 | class MiniscopeDataset(Dataset): 10 | """Face Landmarks dataset.""" 11 | 12 | def __init__(self, all_files,filepath_meas, transform=None): 13 | 14 | 15 | self.all_files_gt = all_files 16 | self.filepath_meas = filepath_meas 17 | # self.all_files_gt = glob.glob(filepath_gt + '*.tiff') 18 | 19 | self.transform = transform 20 | 21 | def __len__(self): 22 | return 
len(self.all_files_gt) 23 | 24 | def __getitem__(self, idx): 25 | 26 | im_gt = skimage.io.imread(self.all_files_gt[idx]) 27 | im_meas = skimage.io.imread(self.filepath_meas+self.all_files_gt[idx].split('/')[-1].split('.tiff')[0]+'.png') 28 | 29 | sample = {'im_gt': im_gt.astype('float32')/255., 'meas': im_meas.astype('float32')/255.} 30 | 31 | 32 | if self.transform: 33 | sample = self.transform(sample) 34 | 35 | return sample 36 | 37 | 38 | 39 | 40 | class MiniscopeDataset_backup(Dataset): 41 | """Face Landmarks dataset.""" 42 | 43 | def __init__(self, filepath_meas, filepath_gt, transform=None): 44 | 45 | self.filepath_meas = filepath_meas 46 | self.filepath_gt = filepath_gt 47 | self.all_files_gt = glob.glob(filepath_gt + '*.tiff') 48 | 49 | self.transform = transform 50 | 51 | def __len__(self): 52 | return len(self.all_files_gt) 53 | 54 | def __getitem__(self, idx): 55 | 56 | im_gt = skimage.io.imread(self.filepath_gt + str(idx) + '.tiff') 57 | im_meas = skimage.io.imread(self.filepath_meas + str(idx) + '.png') 58 | 59 | sample = {'im_gt': im_gt.astype('float32')/255., 'meas': im_meas.astype('float32')/255.} 60 | 61 | 62 | if self.transform: 63 | sample = self.transform(sample) 64 | 65 | return sample 66 | class downsize(object): 67 | """Convert ndarrays in sample to Tensors.""" 68 | 69 | def __init__(self, ds=0.5): 70 | self.ds=ds 71 | 72 | 73 | def __call__(self, sample): 74 | 75 | im_gt, meas = sample['im_gt'], sample['meas'] 76 | meas=meas[18:466,4:644] 77 | meas= cv2.resize(meas, (0,0), fx=self.ds, fy=self.ds) 78 | im_gt=im_gt[0:32,18:466,4:644] 79 | im_gt= cv2.resize(im_gt.transpose([1,2,0]), (0,0), fx=self.ds, fy=self.ds) 80 | im_gt=im_gt.transpose([2,0,1]) 81 | return {'im_gt': im_gt, 82 | 'meas': meas} 83 | class ToTensor(object): 84 | """Convert ndarrays in sample to Tensors.""" 85 | 86 | def __call__(self, sample): 87 | im_gt, meas = sample['im_gt'], sample['meas'] 88 | 89 | return {'im_gt': torch.from_numpy(im_gt).unsqueeze(0), 90 | 'meas': 
torch.from_numpy(meas).unsqueeze(0).unsqueeze(0)} -------------------------------------------------------------------------------- /pytorch/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ConvBlock(nn.Module): 5 | def __init__(self, in_channels, out_channels, dropout=False, norm='batch', residual=True, activation='leakyrelu', transpose=False): 6 | super(ConvBlock, self).__init__() 7 | self.dropout = dropout 8 | self.residual = residual 9 | self.activation = activation 10 | self.transpose = transpose 11 | 12 | if self.dropout: 13 | self.dropout1 = nn.Dropout2d(p=0.05) 14 | self.dropout2 = nn.Dropout2d(p=0.05) 15 | 16 | self.norm1 = None 17 | self.norm2 = None 18 | if norm == 'batch': 19 | self.norm1 = nn.BatchNorm2d(out_channels) 20 | self.norm2 = nn.BatchNorm2d(out_channels) 21 | elif norm == 'instance': 22 | self.norm1 = nn.InstanceNorm2d(out_channels, affine=True) 23 | self.norm2 = nn.InstanceNorm2d(out_channels, affine=True) 24 | elif norm == 'mixed': 25 | self.norm1 = nn.BatchNorm2d(out_channels, affine=True) 26 | self.norm2 = nn.InstanceNorm2d(out_channels, affine=True) 27 | 28 | if self.transpose: 29 | self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, padding=1) 30 | self.conv2 = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, padding=1) 31 | else: 32 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 33 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 34 | 35 | if self.activation == 'relu': 36 | self.actfun1 = nn.ReLU() 37 | self.actfun2 = nn.ReLU() 38 | elif self.activation == 'leakyrelu': 39 | self.actfun1 = nn.LeakyReLU() 40 | self.actfun2 = nn.LeakyReLU() 41 | elif self.activation == 'elu': 42 | self.actfun1 = nn.ELU() 43 | self.actfun2 = nn.ELU() 44 | elif self.activation == 'selu': 45 | self.actfun1 = nn.SELU() 46 | self.actfun2 = nn.SELU() 47 | 48 | def 
forward(self, x): 49 | ox = x 50 | 51 | x = self.conv1(x) 52 | 53 | if self.dropout: 54 | x = self.dropout1(x) 55 | 56 | if self.norm1: 57 | x = self.norm1(x) 58 | 59 | x = self.actfun1(x) 60 | 61 | x = self.conv2(x) 62 | 63 | if self.dropout: 64 | x = self.dropout2(x) 65 | 66 | if self.norm2: 67 | x = self.norm2(x) 68 | 69 | if self.residual: 70 | x[:, 0:min(ox.shape[1], x.shape[1]), :, :] += ox[:, 0:min(ox.shape[1], x.shape[1]), :, :] 71 | 72 | x = self.actfun2(x) 73 | 74 | # print("shapes: x:%s ox:%s " % (x.shape,ox.shape)) 75 | 76 | return x 77 | 78 | 79 | -------------------------------------------------------------------------------- /pytorch/models/modules3d.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ConvBlock(nn.Module): 5 | def __init__(self, in_channels, out_channels, dropout=False, norm='batch', residual=True, activation='leakyrelu', transpose=False): 6 | super(ConvBlock, self).__init__() 7 | self.dropout = dropout 8 | self.residual = residual 9 | self.activation = activation 10 | self.transpose = transpose 11 | 12 | if self.dropout: 13 | self.dropout1 = nn.Dropout3d(p=0.05) 14 | self.dropout2 = nn.Dropout3d(p=0.05) 15 | 16 | self.norm1 = None 17 | self.norm2 = None 18 | if norm == 'batch': 19 | self.norm1 = nn.BatchNorm3d(out_channels) 20 | self.norm2 = nn.BatchNorm3d(out_channels) 21 | elif norm == 'instance': 22 | self.norm1 = nn.InstanceNorm3d(out_channels, affine=True) 23 | self.norm2 = nn.InstanceNorm3d(out_channels, affine=True) 24 | elif norm == 'mixed': 25 | self.norm1 = nn.BatchNorm3d(out_channels, affine=True) 26 | self.norm2 = nn.InstanceNorm3d(out_channels, affine=True) 27 | 28 | if self.transpose: 29 | self.conv1 = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=3, padding=1) 30 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=3, padding=1) 31 | else: 32 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, 
padding=1) 33 | self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1) 34 | 35 | if self.activation == 'relu': 36 | self.actfun1 = nn.ReLU() 37 | self.actfun2 = nn.ReLU() 38 | elif self.activation == 'leakyrelu': 39 | self.actfun1 = nn.LeakyReLU() 40 | self.actfun2 = nn.LeakyReLU() 41 | elif self.activation == 'elu': 42 | self.actfun1 = nn.ELU() 43 | self.actfun2 = nn.ELU() 44 | elif self.activation == 'selu': 45 | self.actfun1 = nn.SELU() 46 | self.actfun2 = nn.SELU() 47 | 48 | def forward(self, x): 49 | ox = x 50 | 51 | x = self.conv1(x) 52 | 53 | if self.dropout: 54 | x = self.dropout1(x) 55 | 56 | if self.norm1: 57 | x = self.norm1(x) 58 | 59 | x = self.actfun1(x) 60 | 61 | x = self.conv2(x) 62 | 63 | if self.dropout: 64 | x = self.dropout2(x) 65 | 66 | if self.norm2: 67 | x = self.norm2(x) 68 | 69 | if self.residual: 70 | x[:, 0:min(ox.shape[1], x.shape[1]), :, :] += ox[:, 0:min(ox.shape[1], x.shape[1]), :, :] 71 | 72 | x = self.actfun2(x) 73 | 74 | # print("shapes: x:%s ox:%s " % (x.shape,ox.shape)) 75 | 76 | return x -------------------------------------------------------------------------------- /pytorch/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.modules import ConvBlock 6 | 7 | 8 | class Unet(nn.Module): 9 | def __init__(self, n_channel_in=1, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu'): 10 | super(Unet, self).__init__() 11 | 12 | self.residual = residual 13 | 14 | if down == 'maxpool': 15 | self.down1 = nn.MaxPool2d(kernel_size=2) 16 | self.down2 = nn.MaxPool2d(kernel_size=2) 17 | self.down3 = nn.MaxPool2d(kernel_size=2) 18 | self.down4 = nn.MaxPool2d(kernel_size=2) 19 | elif down == 'avgpool': 20 | self.down1 = nn.AvgPool2d(kernel_size=2) 21 | self.down2 = nn.AvgPool2d(kernel_size=2) 22 | self.down3 = nn.AvgPool2d(kernel_size=2) 23 | self.down4 = 
nn.AvgPool2d(kernel_size=2) 24 | elif down == 'conv': 25 | self.down1 = nn.Conv2d(32, 32, kernel_size=2, stride=2, groups=32) 26 | self.down2 = nn.Conv2d(64, 64, kernel_size=2, stride=2, groups=64) 27 | self.down3 = nn.Conv2d(128, 128, kernel_size=2, stride=2, groups=128) 28 | self.down4 = nn.Conv2d(256, 256, kernel_size=2, stride=2, groups=256) 29 | 30 | self.down1.weight.data = 0.01 * self.down1.weight.data + 0.25 31 | self.down2.weight.data = 0.01 * self.down2.weight.data + 0.25 32 | self.down3.weight.data = 0.01 * self.down3.weight.data + 0.25 33 | self.down4.weight.data = 0.01 * self.down4.weight.data + 0.25 34 | 35 | self.down1.bias.data = 0.01 * self.down1.bias.data + 0 36 | self.down2.bias.data = 0.01 * self.down2.bias.data + 0 37 | self.down3.bias.data = 0.01 * self.down3.bias.data + 0 38 | self.down4.bias.data = 0.01 * self.down4.bias.data + 0 39 | 40 | if up == 'bilinear' or up == 'nearest': 41 | self.up1 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 42 | self.up2 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 43 | self.up3 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 44 | self.up4 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 45 | elif up == 'tconv': 46 | self.up1 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2, groups=256) 47 | self.up2 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2, groups=128) 48 | self.up3 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, groups=64) 49 | self.up4 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2, groups=32) 50 | 51 | self.up1.weight.data = 0.01 * self.up1.weight.data + 0.25 52 | self.up2.weight.data = 0.01 * self.up2.weight.data + 0.25 53 | self.up3.weight.data = 0.01 * self.up3.weight.data + 0.25 54 | self.up4.weight.data = 0.01 * self.up4.weight.data + 0.25 55 | 56 | self.up1.bias.data = 0.01 * self.up1.bias.data + 0 57 | self.up2.bias.data = 0.01 * self.up2.bias.data + 0 58 | self.up3.bias.data = 0.01 * 
self.up3.bias.data + 0 59 | self.up4.bias.data = 0.01 * self.up4.bias.data + 0 60 | 61 | self.conv1 = ConvBlock(n_channel_in, 32, residual, activation) 62 | self.conv2 = ConvBlock(32, 64, residual, activation) 63 | self.conv3 = ConvBlock(64, 128, residual, activation) 64 | self.conv4 = ConvBlock(128, 256, residual, activation) 65 | 66 | self.conv5 = ConvBlock(256, 256, residual, activation) 67 | 68 | self.conv6 = ConvBlock(2 * 256, 128, residual, activation) 69 | self.conv7 = ConvBlock(2 * 128, 64, residual, activation) 70 | self.conv8 = ConvBlock(2 * 64, 32, residual, activation) 71 | self.conv9 = ConvBlock(2 * 32, n_channel_out, residual, activation) 72 | 73 | if self.residual: 74 | self.convres = ConvBlock(n_channel_in, n_channel_out, residual, activation) 75 | 76 | def forward(self, x): 77 | c0 = x 78 | c1 = self.conv1(x) 79 | x = self.down1(c1) 80 | c2 = self.conv2(x) 81 | x = self.down2(c2) 82 | c3 = self.conv3(x) 83 | x = self.down3(c3) 84 | c4 = self.conv4(x) 85 | x = self.down4(c4) 86 | x = self.conv5(x) 87 | x = self.up1(x) 88 | # print("shapes: c0:%sx:%s c4:%s " % (c0.shape,x.shape,c4.shape)) 89 | x = torch.cat([x, c4], 1) # x[:,0:128]*x[:,128:256], 90 | x = self.conv6(x) 91 | x = self.up2(x) 92 | x = torch.cat([x, c3], 1) # x[:,0:64]*x[:,64:128], 93 | x = self.conv7(x) 94 | x = self.up3(x) 95 | x = torch.cat([x, c2], 1) # x[:,0:32]*x[:,32:64], 96 | x = self.conv8(x) 97 | x = self.up4(x) 98 | x = torch.cat([x, c1], 1) # x[:,0:16]*x[:,16:32], 99 | x = self.conv9(x) 100 | if self.residual: 101 | x = torch.add(x, self.convres(c0)) 102 | 103 | return x -------------------------------------------------------------------------------- /pytorch/models/unet2.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | from .unet_parts import * 4 | 5 | 6 | class UNet(nn.Module): 7 | def __init__(self, n_channels, n_classes, bilinear=True): 8 | super(UNet, 
self).__init__() 9 | self.n_channels = n_channels 10 | self.n_classes = n_classes 11 | self.bilinear = bilinear 12 | 13 | self.inc = DoubleConv(n_channels, 64) 14 | self.down1 = Down(64, 128) 15 | self.down2 = Down(128, 256) 16 | self.down3 = Down(256, 512) 17 | factor = 2 if bilinear else 1 18 | self.down4 = Down(512, 1024 // factor) 19 | self.up1 = Up(1024, 512 // factor, bilinear) 20 | self.up2 = Up(512, 256 // factor, bilinear) 21 | self.up3 = Up(256, 128 // factor, bilinear) 22 | self.up4 = Up(128, 64, bilinear) 23 | self.outc = OutConv(64, n_classes) 24 | 25 | def forward(self, x): 26 | x1 = self.inc(x) 27 | x2 = self.down1(x1) 28 | x3 = self.down2(x2) 29 | x4 = self.down3(x3) 30 | x5 = self.down4(x4) 31 | x = self.up1(x5, x4) 32 | x = self.up2(x, x3) 33 | x = self.up3(x, x2) 34 | x = self.up4(x, x1) 35 | logits = self.outc(x) 36 | return logits -------------------------------------------------------------------------------- /pytorch/models/unet3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.modules3d import ConvBlock 6 | 7 | 8 | class Unet(nn.Module): 9 | def __init__(self, n_channel_in=1, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu'): 10 | super(Unet, self).__init__() 11 | 12 | self.residual = residual 13 | 14 | if down == 'maxpool': 15 | self.down1 = nn.MaxPool3d(kernel_size=2) 16 | self.down2 = nn.MaxPool3d(kernel_size=2) 17 | self.down3 = nn.MaxPool3d(kernel_size=2) 18 | self.down4 = nn.MaxPool3d(kernel_size=2) 19 | elif down == 'avgpool': 20 | self.down1 = nn.AvgPool3d(kernel_size=2) 21 | self.down2 = nn.AvgPool3d(kernel_size=2) 22 | self.down3 = nn.AvgPool3d(kernel_size=2) 23 | self.down4 = nn.AvgPool3d(kernel_size=2) 24 | elif down == 'conv': 25 | self.down1 = nn.Conv3d(32, 32, kernel_size=2, stride=2, groups=32) 26 | self.down2 = nn.Conv3d(64, 64, kernel_size=2, stride=2, groups=64) 
27 | self.down3 = nn.Conv3d(128, 128, kernel_size=2, stride=2, groups=128) 28 | self.down4 = nn.Conv3d(256, 256, kernel_size=2, stride=2, groups=256) 29 | 30 | self.down1.weight.data = 0.01 * self.down1.weight.data + 0.25 31 | self.down2.weight.data = 0.01 * self.down2.weight.data + 0.25 32 | self.down3.weight.data = 0.01 * self.down3.weight.data + 0.25 33 | self.down4.weight.data = 0.01 * self.down4.weight.data + 0.25 34 | 35 | self.down1.bias.data = 0.01 * self.down1.bias.data + 0 36 | self.down2.bias.data = 0.01 * self.down2.bias.data + 0 37 | self.down3.bias.data = 0.01 * self.down3.bias.data + 0 38 | self.down4.bias.data = 0.01 * self.down4.bias.data + 0 39 | 40 | if up == 'bilinear' or up == 'nearest': 41 | self.up1 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 42 | self.up2 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 43 | self.up3 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 44 | self.up4 = lambda x: nn.functional.interpolate(x, mode=up, scale_factor=2) 45 | elif up == 'tconv': 46 | self.up1 = nn.ConvTranspose3d(256, 256, kernel_size=2, stride=2, groups=256) 47 | self.up2 = nn.ConvTranspose3d(128, 128, kernel_size=2, stride=2, groups=128) 48 | self.up3 = nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2, groups=64) 49 | self.up4 = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2, groups=32) 50 | 51 | self.up1.weight.data = 0.01 * self.up1.weight.data + 0.25 52 | self.up2.weight.data = 0.01 * self.up2.weight.data + 0.25 53 | self.up3.weight.data = 0.01 * self.up3.weight.data + 0.25 54 | self.up4.weight.data = 0.01 * self.up4.weight.data + 0.25 55 | 56 | self.up1.bias.data = 0.01 * self.up1.bias.data + 0 57 | self.up2.bias.data = 0.01 * self.up2.bias.data + 0 58 | self.up3.bias.data = 0.01 * self.up3.bias.data + 0 59 | self.up4.bias.data = 0.01 * self.up4.bias.data + 0 60 | 61 | self.conv1 = ConvBlock(n_channel_in, 32, residual, activation) 62 | self.conv2 = ConvBlock(32, 64, residual, 
activation) 63 | self.conv3 = ConvBlock(64, 128, residual, activation) 64 | self.conv4 = ConvBlock(128, 256, residual, activation) 65 | 66 | self.conv5 = ConvBlock(256, 256, residual, activation) 67 | 68 | self.conv6 = ConvBlock(2 * 256, 128, residual, activation) 69 | self.conv7 = ConvBlock(2 * 128, 64, residual, activation) 70 | self.conv8 = ConvBlock(2 * 64, 32, residual, activation) 71 | self.conv9 = ConvBlock(2 * 32, n_channel_out, residual, activation) 72 | 73 | if self.residual: 74 | self.convres = ConvBlock(n_channel_in, n_channel_out, residual, activation) 75 | 76 | def forward(self, x): 77 | c0 = x 78 | c1 = self.conv1(x) 79 | x = self.down1(c1) 80 | c2 = self.conv2(x) 81 | x = self.down2(c2) 82 | c3 = self.conv3(x) 83 | x = self.down3(c3) 84 | c4 = self.conv4(x) 85 | x = self.down4(c4) 86 | x = self.conv5(x) 87 | x = self.up1(x) 88 | x = torch.cat([x, c4], 1) 89 | x = self.conv6(x) 90 | x = self.up2(x) 91 | x = torch.cat([x, c3], 1) 92 | x = self.conv7(x) 93 | x = self.up3(x) 94 | x = torch.cat([x, c2], 1) 95 | x = self.conv8(x) 96 | x = self.up4(x) 97 | x = torch.cat([x, c1], 1) 98 | x = self.conv9(x) 99 | if self.residual: 100 | x = torch.add(x, self.convres(c0)) 101 | 102 | return x -------------------------------------------------------------------------------- /pytorch/models/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, 
kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | def forward(self, x1, x2): 57 | x1 = self.up(x1) 58 | # input is CHW 59 | diffY = x2.size()[2] - x1.size()[2] 60 | diffX = x2.size()[3] - x1.size()[3] 61 | 62 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 63 | diffY // 2, diffY - diffY // 2]) 64 | # if you have padding issues, see 65 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 66 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 67 | x = torch.cat([x2, x1], dim=1) 68 | return self.conv(x) 69 | 70 | 71 | class OutConv(nn.Module): 72 | def __init__(self, in_channels, out_channels): 73 | super(OutConv, self).__init__() 74 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 75 | 76 | def forward(self, x): 77 | return self.conv(x) 
# File: /pytorch/models/wiener_model.py
import torch.nn as nn
import torch
import numpy as np


class _WienerDeconvolutionBase(nn.Module):
    """Shared learnable frequency-domain Wiener filter.

    MultiWienerDeconvolution3D and WienerDeconvolution3D previously carried
    identical copies of this code; both now inherit it. The learnable tensors
    are still registered under the names ``psfs`` and ``Ks``, so ``state_dict``
    keys (and previously saved checkpoints) are unchanged.
    """

    # Multiplier applied to Ks in the denominator. It is also a class attribute
    # so instances unpickled from versions without ``k_scale`` still work.
    k_scale = 100.0

    def __init__(self, initial_psfs, initial_Ks, k_scale=100.0):
        """Create the learnable PSFs and regularization constants.

        Args:
            initial_psfs: PSFs as an array, tensor or nested sequence. The FFT
                runs over the last two axes, so the shape is (..., H, W),
                e.g. (C, H, W) for C PSFs.
            initial_Ks: regularizer(s), broadcastable against (..., H, W),
                e.g. (C, 1, 1) or a plain float.
            k_scale: multiplier applied to Ks in the denominator. Defaults to
                100, the value that used to be hard-coded.
        """
        super().__init__()
        # as_tensor + detach/clone accepts numpy arrays, scalars or tensors
        # (no torch.tensor(tensor) copy warning) and never aliases the
        # caller's buffer.
        psfs = torch.as_tensor(initial_psfs, dtype=torch.float32).detach().clone()
        Ks = torch.as_tensor(initial_Ks, dtype=torch.float32).detach().clone()

        self.psfs = nn.Parameter(psfs, requires_grad=True)
        # TODO: Ks is meant to be nonnegative (the TensorFlow layer uses
        # constraint=tf.nn.relu), but no constraint is enforced here.
        self.Ks = nn.Parameter(Ks, requires_grad=True)
        self.k_scale = float(k_scale)

    def forward(self, y):
        """Deconvolve ``y`` with every PSF using a Wiener filter.

        Args:
            y: tensor whose last two dims match (or broadcast with) the PSFs'
                last two dims, e.g. (N, C, H, W) with psfs of shape (C, H, W).

        Returns:
            Real float32 tensor with the broadcast shape of ``y`` and ``psfs``.

        No zero-padding is applied, so the deconvolution is circular.
        """
        # Casting to complex64 keeps the result float32 whatever y's dtype.
        Y = torch.fft.fft2(y.type(torch.complex64))
        H = torch.fft.fft2(self.psfs)
        # Wiener filter: conj(H) * Y / (|H|^2 + k_scale * K)
        X = (torch.conj(H) * Y) / (torch.square(torch.abs(H)) + self.k_scale * self.Ks)
        # ifftshift over the two spatial dims re-centres the result
        # (presumably assumes the PSFs are centred in their frame - confirm).
        return torch.real(torch.fft.ifftshift(torch.fft.ifft2(X), dim=(-2, -1)))

    def get_config(self):
        """Return the constructor arguments needed to rebuild this module.

        ``nn.Module`` has no ``get_config`` (that is a Keras API), so the old
        ``super().get_config()`` call always raised AttributeError.
        Parameters are detached and moved to the CPU before ``.numpy()``,
        which otherwise fails for tensors that require grad or live on a GPU.
        """
        return {
            'initial_psfs': self.psfs.detach().cpu().numpy(),
            'initial_Ks': self.Ks.detach().cpu().numpy(),
            'k_scale': self.k_scale,
        }


class MultiWienerDeconvolution3D(_WienerDeconvolutionBase):
    """
    Performs Wiener Deconvolution in the frequency domain for each psf.

    Input: initial_psfs of shape (..., H, W), e.g. (C, H, W) for C psfs;
    initial_Ks broadcastable against it, e.g. (C, 1, 1).
    """


class WienerDeconvolution3D(_WienerDeconvolutionBase):
    """
    Performs Wiener Deconvolution in the frequency domain for each psf.

    Input: initial_psfs of shape (..., H, W), e.g. (C, H, W) or a single
    (H, W) psf; initial_Ks broadcastable against it, e.g. (C, 1, 1) or a float.
    """
120 | 121 | 122 | class MyEnsemble2d(nn.Module): 123 | def __init__(self, wiener_model, unet_model): 124 | super(MyEnsemble2d, self).__init__() 125 | self.wiener_model = wiener_model 126 | self.unet_model = unet_model 127 | def forward(self, x): 128 | wiener_output = self.wiener_model(x) 129 | wiener_output = wiener_output/torch.max(wiener_output) 130 | final_output = self.unet_model(wiener_output) 131 | return final_output 132 | 133 | class MyEnsemble(nn.Module): 134 | def __init__(self, wiener_model, unet_model): 135 | super(MyEnsemble, self).__init__() 136 | self.wiener_model = wiener_model 137 | self.unet_model = unet_model 138 | def forward(self, x): 139 | wiener_output = self.wiener_model(x) 140 | final_output = self.unet_model(wiener_output) 141 | return final_output -------------------------------------------------------------------------------- /pytorch/training_code.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import torch, torch.optim\n", 14 | "import torch.nn.functional as F\n", 15 | "torch.backends.cudnn.enabled = True\n", 16 | "torch.backends.cudnn.benchmark =True\n", 17 | "dtype = torch.cuda.FloatTensor\n", 18 | "import os, sys, json, glob\n", 19 | "from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM\n", 20 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 21 | "\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "\n", 24 | "import random\n", 25 | "\n", 26 | "import skimage.io\n", 27 | "import torch.nn as nn\n", 28 | "import argparse\n", 29 | "\n", 30 | "from torch.utils.data import Dataset, DataLoader\n", 31 | "from torchvision import transforms, utils\n", 32 | "import cv2\n", 33 | "import models.wiener_model as wm\n", 34 | "import models.dataset as ds\n", 35 | "from 
PIL import Image\n", 36 | "import helper as hp\n", 37 | "\n", 38 | "import scipy.io" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "\u001b[1m\u001b[37mwaller-fuoco\u001b[m Fri Jan 21 16:37:05 2022\r\n", 51 | "\u001b[36m[0]\u001b[m \u001b[34mNVIDIA GeForce GTX 1080 Ti\u001b[m |\u001b[1m\u001b[31m 91'C\u001b[m, \u001b[1m\u001b[32m100 %\u001b[m | \u001b[36m\u001b[1m\u001b[33m 9956\u001b[m / \u001b[33m11176\u001b[m MB | \u001b[1m\u001b[30mkyrollos\u001b[m(\u001b[33m9945M\u001b[m) \u001b[1m\u001b[30mgdm\u001b[m(\u001b[33m6M\u001b[m)\r\n", 52 | "\u001b[36m[1]\u001b[m \u001b[34mNVIDIA TITAN X (Pascal) \u001b[m |\u001b[1m\u001b[31m 61'C\u001b[m, \u001b[32m 0 %\u001b[m | \u001b[36m\u001b[1m\u001b[33m 4\u001b[m / \u001b[33m12196\u001b[m MB |\r\n", 53 | "\u001b[36m[2]\u001b[m \u001b[34mNVIDIA TITAN Xp \u001b[m |\u001b[1m\u001b[31m 88'C\u001b[m, \u001b[1m\u001b[32m 99 %\u001b[m | \u001b[36m\u001b[1m\u001b[33m11973\u001b[m / \u001b[33m12196\u001b[m MB | \u001b[1m\u001b[30mtiffany\u001b[m(\u001b[33m11969M\u001b[m)\r\n", 54 | "\u001b[36m[3]\u001b[m \u001b[34mNVIDIA TITAN Xp \u001b[m |\u001b[31m 32'C\u001b[m, \u001b[32m 0 %\u001b[m | \u001b[36m\u001b[1m\u001b[33m 4\u001b[m / \u001b[33m12196\u001b[m MB |\r\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "!gpustat" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "# Training code for 2D & 3D spatially-varying deconvolutions" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "parser = argparse.ArgumentParser(description='Process some integers.')\n", 76 | "parser.add_argument('--data_type', default='2D')\n", 77 | "parser.add_argument('--network', default='multiwiener') #'wiener' or 'unet' or 'multiwiener'\n", 78 | "parser.add_argument('--id', default='') 
#some identifier\n", 79 | "parser.add_argument('--loss_type', default='l1') \n", 80 | "parser.add_argument('--device', default='0') \n", 81 | "parser.add_argument('--psf_num', default=9, type=int)\n", 82 | "parser.add_argument('--psf_ds', default=0.75, type=float)\n", 83 | "parser.add_argument('--epochs', default=10000, type=int)\n", 84 | "parser.add_argument('--lr', default=1e-4, type=float) \n", 85 | "parser.add_argument('--load_path',default=None)\n", 86 | "parser.add_argument('--save_checkponts',default=True)\n", 87 | "\n", 88 | "args = parser.parse_args(''.split())\n", 89 | "\n", 90 | "os.environ['CUDA_VISIBLE_DEVICES'] = args.device" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 4, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# for 3D-UNet multiwiener\n", 100 | "registered_psfs_path = '../data/multiWienerPSFStack_40z_aligned.mat'\n", 101 | "psfs = scipy.io.loadmat(registered_psfs_path)\n", 102 | "psfs=psfs['multiWienerPSFStack_40z']" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "choosing 9 psfs\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "if args.data_type == '3D':\n", 120 | " if args.network=='wiener' or args.network=='unet':\n", 121 | " psfs=hp.pre_process_psfs(psfs)[:,:,4]\n", 122 | " Ks=np.ones((32,1,1))\n", 123 | " print('choosing 1 psfs')\n", 124 | "\n", 125 | " elif args.network=='multiwiener':\n", 126 | " Ks=np.ones((args.psf_num,32,1,1))\n", 127 | " if args.psf_num==9:\n", 128 | " print('choosing 9 psfs')\n", 129 | " psfs=hp.pre_process_psfs(psfs)\n", 130 | " else:\n", 131 | " print('invalid network')\n", 132 | " psfs = hp.downsize_psf(psfs)\n", 133 | "else: #2D\n", 134 | " if args.network=='wiener' or args.network=='unet':\n", 135 | " psfs=hp.pre_process_psfs_2d(psfs)[:,:,4, 0]\n", 136 | " Ks= 1.\n", 137 | " print('choosing 1 psfs')\n", 
138 | "\n", 139 | " elif args.network=='multiwiener':\n", 140 | " Ks=np.ones((args.psf_num,1,1))\n", 141 | " if args.psf_num==9:\n", 142 | " print('choosing 9 psfs')\n", 143 | " psfs=hp.pre_process_psfs_2d(psfs)[...,0]\n", 144 | " psfs = psfs.transpose(2,0,1)\n", 145 | " else:\n", 146 | " print('invalid network')\n", 147 | "\n", 148 | " \n" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "## Make dataset and dataloader for training data" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 6, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "total number of images 22126\n", 168 | "training images: 17700 testing images: 4426\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "down_size = ds.downsize(ds=args.psf_ds)\n", 174 | "to_tensor = ds.ToTensor()\n", 175 | "add_noise=ds.AddNoise()\n", 176 | "\n", 177 | "if args.data_type == '3D':\n", 178 | " filepath_gt = '/home/kyrollos/LearnedMiniscope3D/Data3D/Training_data_all/' \n", 179 | "else:\n", 180 | " filepath_gt = '/home/kyrollos/LearnedMiniscope3D/Data/Target/'\n", 181 | " filepath_meas = '/home/kyrollos/LearnedMiniscope3D/Data/Train/'\n", 182 | "\n", 183 | "\n", 184 | "filepath_all=glob.glob(filepath_gt+'*')\n", 185 | "random.Random(8).shuffle(filepath_all)\n", 186 | "print('total number of images',len(filepath_all))\n", 187 | "total_num_images = len(filepath_all)\n", 188 | "num_test = 0.2 # 20% test\n", 189 | "filepath_train=filepath_all[0:int(total_num_images*(1-num_test))]\n", 190 | "filepath_test=filepath_all[int(total_num_images*(1-num_test)):]\n", 191 | "\n", 192 | "print('training images:', len(filepath_train), \n", 193 | " 'testing images:', len(filepath_test))\n", 194 | "\n", 195 | "if args.data_type == '3D':\n", 196 | " dataset_train = ds.MiniscopeDataset(filepath_train, transform = transforms.Compose([down_size,add_noise,to_tensor]))\n", 197 | " 
dataset_test = ds.MiniscopeDataset(filepath_test, transform = transforms.Compose([down_size,add_noise,to_tensor]))\n", 198 | "else:\n", 199 | " dataset_train = ds.MiniscopeDataset_2D(filepath_train, filepath_meas, transform = transforms.Compose([ds.crop2d(),ds.ToTensor2d()]))\n", 200 | " dataset_test = ds.MiniscopeDataset_2D(filepath_test, filepath_meas, transform = transforms.Compose([ds.crop2d(),ds.ToTensor2d()]))\n", 201 | "\n", 202 | "\n", 203 | "dataloader_train = DataLoader(dataset_train, batch_size=1,\n", 204 | " shuffle=True, num_workers=1)\n", 205 | "\n", 206 | "dataloader_test = DataLoader(dataset_test, batch_size=1,\n", 207 | " shuffle=False, num_workers=1)\n", 208 | "\n", 209 | "device = 'cuda:0'\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "## Define model and optimizer" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "if args.data_type == '3D':\n", 226 | " from models.unet3d import Unet\n", 227 | " unet_model = Unet(n_channel_in=args.psf_num, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu').to(device)\n", 228 | "\n", 229 | " if args.network == 'multiwiener' or args.network == 'wiener':\n", 230 | " wiener_model=wm.WienerDeconvolution3D(psfs,Ks).to(device)\n", 231 | " model=wm.MyEnsemble(wiener_model,unet_model)\n", 232 | " else:\n", 233 | " model = unet_model\n", 234 | "else: #2D\n", 235 | " from models.unet import Unet\n", 236 | " if args.network == 'multiwiener':\n", 237 | " num_in_channels = args.psf_num\n", 238 | " else:\n", 239 | " num_in_channels = 1\n", 240 | " \n", 241 | " \n", 242 | " unet_model = Unet(n_channel_in=num_in_channels, n_channel_out=1, residual=False, down='conv', up='tconv', activation='selu').to(device)\n", 243 | "\n", 244 | " if args.network == 'multiwiener' or args.network == 'wiener':\n", 245 | " 
wiener_model=wm.WienerDeconvolution3D(psfs,Ks).to(device)\n", 246 | " model=wm.MyEnsemble(wiener_model,unet_model)\n", 247 | " else:\n", 248 | " model = unet_model\n", 249 | "\n", 250 | " \n", 251 | "if args.load_path is not None:\n", 252 | " model.load_state_dict(torch.load('saved_data/'+args.load_path,map_location=torch.device(device)))\n", 253 | " print('loading saved model')\n", 254 | "\n", 255 | "\n", 256 | "loss_fn = torch.nn.L1Loss()\n", 257 | "\n", 258 | "optimizer = torch.optim.Adam(model.parameters(), lr = args.lr)\n", 259 | "\n" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "if args.save_checkponts == True:\n", 269 | " filepath_save = 'saved_data/' +\"_\".join((list(vars(args).values()))[0:5]) + \"/\"\n", 270 | "\n", 271 | " if not os.path.exists(filepath_save):\n", 272 | " os.makedirs(filepath_save)\n", 273 | "\n", 274 | " with open(filepath_save + 'args.json', 'w') as fp:\n", 275 | " json.dump(vars(args), fp)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "\n", 285 | "\n", 286 | "best_loss=27e7\n", 287 | "\n", 288 | "for itr in range(0,args.epochs):\n", 289 | " for i_batch, sample_batched in enumerate(dataloader_train):\n", 290 | " optimizer.zero_grad()\n", 291 | " #out = model(sample_batched['meas'].repeat(1,32,1,1)[...,18:466,4:644].unsqueeze(0).to(device))\n", 292 | " if args.network=='unet' and args.data_type == '3D':\n", 293 | " out = model(sample_batched['meas'].repeat(1,1,32,1,1).to(device))\n", 294 | " else:\n", 295 | " out = model(sample_batched['meas'].to(device))\n", 296 | "\n", 297 | " if args.loss_type=='l1':\n", 298 | " loss = loss_fn(out, sample_batched['im_gt'].to(device))\n", 299 | " else:\n", 300 | " loss = loss_fn(out, sample_batched['im_gt'].to(device))+(1- ms_ssim( out[0], sample_batched['im_gt'][0].to(device), data_range=1, 
size_average=False))\n", 301 | " loss.backward()\n", 302 | " optimizer.step()\n", 303 | " print('epoch: ', itr, ' batch: ', i_batch, ' loss: ', loss.item(), end='\\r')\n", 304 | "\n", 305 | " #break \n", 306 | " if args.data_type == '3D':\n", 307 | " out_np = np.max(out.detach().cpu().numpy()[0,0],0)\n", 308 | " gt_np = np.max(sample_batched['im_gt'].detach().cpu().numpy()[0,0],0)\n", 309 | " meas_np = np.max(sample_batched['meas'].detach().cpu().numpy()[0,0],0)\n", 310 | " else:\n", 311 | " out_np = out.detach().cpu().numpy()[0][0]\n", 312 | " gt_np = sample_batched['im_gt'].detach().cpu().numpy()[0][0]\n", 313 | " meas_np = sample_batched['meas'].detach().cpu().numpy()[0][0]\n", 314 | "\n", 315 | " f, ax = plt.subplots(1, 3, figsize=(15,15))\n", 316 | " ax[0].imshow(gt_np)\n", 317 | " ax[1].imshow(meas_np)\n", 318 | " ax[2].imshow(out_np)\n", 319 | " plt.show()\n", 320 | "\n", 321 | " if args.save_checkponts == True:\n", 322 | " torch.save(model.state_dict(), filepath_save + 'model_noval.pt')\n", 323 | " \n", 324 | " if itr%1==0:\n", 325 | " total_loss=0\n", 326 | " for i_batch, sample_batched in enumerate(dataloader_test):\n", 327 | " with torch.no_grad():\n", 328 | " if args.network=='unet' and args.data_type == '3D':\n", 329 | " out = model(sample_batched['meas'].repeat(1,1,32,1,1).to(device))\n", 330 | " else:\n", 331 | " out = model(sample_batched['meas'].to(device))\n", 332 | " if args.loss_type=='l1':\n", 333 | " loss = loss_fn(out, sample_batched['im_gt'].to(device))\n", 334 | " else:\n", 335 | " loss = loss_fn(out, sample_batched['im_gt'].to(device))+(1- ms_ssim( out[0], sample_batched['im_gt'][0].to(device), data_range=1, size_average=False))\n", 336 | " \n", 337 | " \n", 338 | " total_loss+=loss.item()\n", 339 | " \n", 340 | " print('loss for testing set ',itr,' ',i_batch, total_loss)\n", 341 | " \n", 342 | " #break\n", 343 | " \n", 344 | " if args.save_checkponts == True:\n", 345 | " im_gt = 
Image.fromarray((np.clip(gt_np/np.max(gt_np),0,1)*255).astype(np.uint8))\n", 346 | " im = Image.fromarray((np.clip(out_np/np.max(out_np),0,1)*255).astype(np.uint8))\n", 347 | " im.save(filepath_save + str(itr) + '.png')\n", 348 | " im_gt.save(filepath_save + 'gt.png')\n", 349 | " \n", 350 | " \n", 351 | " if total_loss (N, H, W, C) 59 | x = tf.transpose(x, [0, 2, 3, 1]) 60 | 61 | x = crop_2d_tf(x) 62 | # x = tf.concat([x,y_im],-1) 63 | 64 | return x 65 | 66 | def get_config(self): 67 | config = super().get_config().copy() 68 | config.update({ 69 | 'initial_psfs': self.psfs.numpy(), 70 | 'initial_Ks': self.Ks.numpy() 71 | }) 72 | return config 73 | 74 | 75 | class MultiWienerDeconvolutionWFourier(layers.Layer): 76 | """ 77 | Performs Wiener Deconvolution in the frequency domain for each psf. 78 | 79 | Input: initial_psfs of shape (Y, X, C), initial_K has shape (1, 1, C) for each psf. 80 | """ 81 | 82 | def __init__(self, initial_psfs, initial_Ks): 83 | super(MultiWienerDeconvolutionWFourier, self).__init__() 84 | initial_psfs = tf.dtypes.cast(initial_psfs, dtype=tf.float32) 85 | initial_Ks = tf.dtypes.cast(initial_Ks, dtype=tf.float32) 86 | 87 | self.psfs = tf.Variable(initial_value=initial_psfs, trainable=True) 88 | self.Ks = tf.Variable(initial_value=initial_Ks, constraint=tf.nn.relu, trainable=True) # K is constrained to be nonnegative 89 | self.fourier_weights = tf.Variable(initial_value=tf.ones([initial_psfs.shape[0],initial_psfs.shape[1],1]), trainable=True) 90 | 91 | def call(self, y): 92 | # Y preprocessing, Y is shape (N, H, W, C) 93 | # y_im=y 94 | _, h, w, _ = y.shape 95 | y = tf.dtypes.cast(y, dtype=tf.complex64) 96 | 97 | 98 | # Pad Y 99 | padding = ((0, 0), 100 | (int(np.ceil(h / 2)), int(np.floor(h / 2))), 101 | (int(np.ceil(w / 2)), int(np.floor(w / 2))), 102 | (0, 0)) 103 | y_fourier=tf.transpose(y, perm=[0, 3, 1, 2]) 104 | y = tf.pad(y, paddings=padding) 105 | 106 | # Temporarily transpose y since we cannot specify axes for fft2d 107 | y = 
tf.transpose(y, perm=[0, 3, 1, 2]) # Y is now shape (N, C, H, W) 108 | Y=tf.signal.fft2d(y) 109 | Y_fourier=tf.signal.fft2d(y_fourier) 110 | 111 | 112 | 113 | 114 | # Components preprocessing, psfs is shape (H, W, C) 115 | psf = tf.dtypes.cast(self.psfs, dtype=tf.complex64) 116 | h_psf, w_psf, _ = psf.shape 117 | 118 | # Pad psf 119 | padding_psf = ( 120 | (int(np.ceil(h_psf / 2)), int(np.floor(h_psf / 2))), 121 | (int(np.ceil(w_psf / 2)), int(np.floor(w_psf / 2))), 122 | (0, 0)) 123 | 124 | H_sum = tf.pad(psf, paddings=padding_psf) 125 | 126 | H_sum = tf.transpose(H_sum, perm=[2, 0, 1]) # H_sum is now shape (C, H, W) 127 | H_sum = tf.signal.fft2d(H_sum) 128 | 129 | Ks = tf.transpose(self.Ks, [2, 0, 1]) # Ks is now shape (C, 1, 1) 130 | 131 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 132 | x=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 133 | 134 | #fourier map 135 | 136 | fourier_weights = tf.transpose(self.fourier_weights, perm=[2,0, 1]) 137 | fourier_weights = tf.dtypes.cast(fourier_weights, dtype=tf.complex64) 138 | fourier_map=tf.math.multiply(Y_fourier,fourier_weights) 139 | # fourier_inv=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(fourier_map), axes=(2, 3)))) 140 | fourier_inv=tf.math.real(tf.signal.ifft2d(fourier_map)) 141 | 142 | fourier_inv=tf.transpose(fourier_inv, [0, 2, 3, 1]) 143 | # print(fourier_inv.shape) 144 | 145 | 146 | # x goes from shape (N, C, H, W) -> (N, H, W, C) 147 | x = tf.transpose(x, [0, 2, 3, 1]) 148 | 149 | x = crop_2d_tf(x) 150 | # print(x.shape) 151 | x = tf.concat([x,fourier_inv],-1) 152 | 153 | return x 154 | 155 | def get_config(self): 156 | config = super().get_config().copy() 157 | config.update({ 158 | 'initial_psfs': self.psfs.numpy(), 159 | 'initial_Ks': self.Ks.numpy() 160 | }) 161 | return config 162 | 163 | class WienerDeconvolution(layers.Layer): 164 | """ 165 | Performs Wiener Deconvolution in frequency domain. 
166 | PSF, K are learnable parameters. K is enforced to be nonnegative everywhere. 167 | 168 | Input: initial_psf of shape (Y, X), initial_K is a scalar. 169 | """ 170 | def __init__(self, initial_psf, initial_K): 171 | # def __init__(self): 172 | 173 | super(WienerDeconvolution, self).__init__() 174 | initial_psf = tf.dtypes.cast(initial_psf, dtype=tf.float32) 175 | initial_K = tf.dtypes.cast(initial_K, dtype=tf.float32) 176 | 177 | self.psf = tf.Variable(initial_value=initial_psf, trainable=True) 178 | self.K = tf.Variable(initial_value=initial_K, constraint=tf.nn.relu, trainable=True) # K is constrained to be nonnegative 179 | 180 | def call(self, y): 181 | # Y preprocessing, Y is shape (N, H, W, C) 182 | y_im=y 183 | 184 | _, h, w, _ = y.shape 185 | y = tf.squeeze(tf.dtypes.cast(y, dtype=tf.complex64), axis=-1) # Remove channel dimension 186 | 187 | # Pad Y 188 | padding = ((0, 0), 189 | (int(np.ceil(h / 2)), int(np.floor(h / 2))), 190 | (int(np.ceil(w / 2)), int(np.floor(w / 2)))) 191 | y = tf.pad(y, paddings=padding) 192 | Y=tf.signal.fft2d(y) 193 | 194 | # PSF preprocessing, psf is shape (H, W) 195 | psf = tf.dtypes.cast(self.psf, dtype=tf.complex64) 196 | h_psf, w_psf = psf.shape 197 | 198 | # Pad psf 199 | padding_psf = ( 200 | (int(np.ceil(h_psf / 2)), int(np.floor(h_psf / 2))), 201 | (int(np.ceil(w_psf / 2)), int(np.floor(w_psf / 2)))) 202 | H_sum = tf.pad(psf, paddings=padding_psf) 203 | H_sum=tf.signal.fft2d(H_sum) 204 | 205 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*self.K, dtype=tf.complex64) 206 | x=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(1, 2)))) 207 | 208 | x = crop_2d_tf(x) 209 | 210 | x=x[..., None] # Add channel dimension 211 | # x = tf.concat([x,y_im],-1) 212 | 213 | return x 214 | 215 | def get_config(self): 216 | config = super().get_config().copy() 217 | config.update({ 218 | 'initial_psf': self.psf.numpy(), 219 | 'initial_K': self.K.numpy() 220 | }) 221 | return config 222 | 
223 | 224 | 225 | class WienerDeconvolution3D(layers.Layer): 226 | """ 227 | Performs Wiener Deconvolution in the frequency domain for each psf. 228 | 229 | Input: initial_psfs of shape (Y, X, C), initial_K has shape (1, 1, C) for each psf. 230 | """ 231 | 232 | def __init__(self, initial_psfs, initial_Ks): 233 | super(WienerDeconvolution3D, self).__init__() 234 | initial_psfs = tf.dtypes.cast(initial_psfs, dtype=tf.float32) 235 | initial_Ks = tf.dtypes.cast(initial_Ks, dtype=tf.float32) 236 | 237 | self.psfs = tf.Variable(initial_value=initial_psfs, trainable=True) 238 | self.Ks = tf.Variable(initial_value=initial_Ks, constraint=tf.nn.relu, trainable=True) # K is constrained to be nonnegative 239 | 240 | def call(self, y): 241 | # Y preprocessing, Y is shape (N, H, W, C) 242 | # y_im=y 243 | _, h, w, _ = y.shape 244 | y = tf.dtypes.cast(y, dtype=tf.complex64) 245 | 246 | 247 | # Pad Y 248 | padding = ((0, 0), 249 | (int(np.ceil(h / 2)), int(np.floor(h / 2))), 250 | (int(np.ceil(w / 2)), int(np.floor(w / 2))), 251 | (0, 0)) 252 | y = tf.pad(y, paddings=padding) 253 | 254 | # Temporarily transpose y since we cannot specify axes for fft2d 255 | y = tf.transpose(y, perm=[0, 3, 1, 2]) # Y is now shape (N, C, H, W) 256 | Y=tf.signal.fft2d(y) 257 | 258 | # Components preprocessing, psfs is shape (H, W, C) 259 | psf = tf.dtypes.cast(self.psfs, dtype=tf.complex64) 260 | h_psf, w_psf, _ = psf.shape 261 | 262 | # Pad psf 263 | padding_psf = ( 264 | (int(np.ceil(h_psf / 2)), int(np.floor(h_psf / 2))), 265 | (int(np.ceil(w_psf / 2)), int(np.floor(w_psf / 2))), 266 | (0, 0)) 267 | 268 | H_sum = tf.pad(psf, paddings=padding_psf) 269 | 270 | H_sum = tf.transpose(H_sum, perm=[2, 0, 1]) # H_sum is now shape (C, H, W) 271 | H_sum = tf.signal.fft2d(H_sum) 272 | 273 | Ks = tf.transpose(self.Ks, [2, 0, 1]) # Ks is now shape (C, 1, 1) 274 | 275 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 276 | 
x=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 277 | 278 | # x goes from shape (N, C, H, W) -> (N, H, W, C) 279 | x = tf.transpose(x, [0, 2, 3, 1]) 280 | 281 | x = crop_2d_tf(x) 282 | # x = tf.concat([x,y_im],-1) 283 | 284 | return x 285 | 286 | def get_config(self): 287 | config = super().get_config().copy() 288 | config.update({ 289 | 'initial_psfs': self.psfs.numpy(), 290 | 'initial_Ks': self.Ks.numpy() 291 | }) 292 | return config 293 | 294 | class MultiWienerDeconvolution3D(layers.Layer): 295 | """ 296 | Performs Wiener Deconvolution in the frequency domain for each psf. 297 | 298 | Input: initial_psfs of shape (Y, X, C), initial_K has shape (1, 1, C) for each psf. 299 | """ 300 | 301 | def __init__(self, initial_psfs, initial_Ks): 302 | super(MultiWienerDeconvolution3D, self).__init__() 303 | initial_psfs = tf.dtypes.cast(initial_psfs, dtype=tf.float32) 304 | initial_Ks = tf.dtypes.cast(initial_Ks, dtype=tf.float32) 305 | 306 | self.psfs = tf.Variable(initial_value=initial_psfs, trainable=True) 307 | self.Ks = tf.Variable(initial_value=initial_Ks, constraint=tf.nn.relu, trainable=True) # K is constrained to be nonnegative 308 | 309 | def call(self, y): 310 | # Y preprocessing, Y is shape (N, H, W, C) 311 | # print(y.shape) 312 | y=tf.image.resize(y, [y.shape[1]//2,y.shape[2]//2]) 313 | _, h, w, _ = y.shape 314 | y = tf.dtypes.cast(y, dtype=tf.complex64) 315 | 316 | # print(y.shape) 317 | # Pad Y 318 | padding = ((0, 0), 319 | (int(np.ceil(h / 2)), int(np.floor(h / 2))), 320 | (int(np.ceil(w / 2)), int(np.floor(w / 2))), 321 | (0, 0)) 322 | y = tf.pad(y, paddings=padding) 323 | 324 | # Temporarily transpose y since we cannot specify axes for fft2d 325 | y = tf.transpose(y, perm=[0, 3, 1, 2]) # Y is now shape (N, C, H, W) 326 | Y=tf.signal.fft2d(y) 327 | 328 | # Components preprocessing, psfs is shape (H, W, C) 329 | psf = tf.dtypes.cast(self.psfs, dtype=tf.complex64) 330 | h_psf, w_psf, _,_ = psf.shape 331 | 332 | # Pad psf 333 | 
padding_psf = ( 334 | (int(np.ceil(h_psf / 2)), int(np.floor(h_psf / 2))), 335 | (int(np.ceil(w_psf / 2)), int(np.floor(w_psf / 2))), 336 | (0, 0),(0,0)) 337 | 338 | 339 | H_sum = tf.pad(psf, paddings=padding_psf) 340 | 341 | 342 | H_sum = tf.transpose(H_sum, perm=[3,2, 0, 1]) # H_sum is now shape (C, H, W) 343 | H_sum_all = tf.signal.fft2d(H_sum) 344 | 345 | Ks_all = tf.transpose(self.Ks, [3,2, 0, 1]) # Ks is now shape (C, 1, 1) 346 | 347 | ##for one xy location 348 | H_sum=H_sum_all[:,0,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 349 | Ks=Ks_all[:,0,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 350 | 351 | 352 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 353 | x1=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 354 | 355 | ##for one xy location 356 | H_sum=H_sum_all[:,1,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 357 | Ks=Ks_all[:,1,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 358 | 359 | 360 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 361 | x2=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 362 | 363 | ##for one xy location 364 | H_sum=H_sum_all[:,2,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 365 | Ks=Ks_all[:,2,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 366 | 367 | 368 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 369 | x3=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 370 | 371 | ##for one xy location 372 | H_sum=H_sum_all[:,3,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 373 | Ks=Ks_all[:,3,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 374 | 375 | 376 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 377 | x4=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 378 | 379 | ##for one xy location 380 | H_sum=H_sum_all[:,4,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 381 
| Ks=Ks_all[:,4,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 382 | 383 | 384 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 385 | x5=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 386 | 387 | ##for one xy location 388 | H_sum=H_sum_all[:,5,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 389 | Ks=Ks_all[:,5,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 390 | 391 | 392 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 393 | x6=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 394 | 395 | ##for one xy location 396 | H_sum=H_sum_all[:,6,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 397 | Ks=Ks_all[:,6,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 398 | 399 | 400 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 401 | x7=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 402 | 403 | ##for one xy location 404 | H_sum=H_sum_all[:,7,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 405 | Ks=Ks_all[:,7,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 406 | 407 | 408 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 409 | x8=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 410 | 411 | ##for one xy location 412 | H_sum=H_sum_all[:,8,:,:]#tf.expand_dims(H_sum_all[:,0,:,:],0) 413 | Ks=Ks_all[:,8,:,:]#tf.expand_dims(Ks_all[:,0,:,:],0) 414 | 415 | 416 | X=(tf.math.conj(H_sum)*Y) / tf.dtypes.cast(tf.math.square(tf.math.abs(H_sum))+1000*Ks, dtype=tf.complex64) 417 | x9=tf.math.real((tf.signal.ifftshift(tf.signal.ifft2d(X), axes=(2, 3)))) 418 | 419 | # x goes from shape (N, C, H, W) -> (N, H, W, C) 420 | x=tf.concat([x1,x2,x3,x4,x5,x6,x7,x8,x9],1) 421 | x = tf.transpose(x, [0, 2, 3, 1]) 422 | x = crop_2d_tf(x) 423 | # x = tf.concat([x,y_im],-1) 424 | 425 | return x 426 | 427 | def get_config(self): 428 | config = 
super().get_config().copy() 429 | config.update({ 430 | 'initial_psfs': self.psfs.numpy(), 431 | 'initial_Ks': self.Ks.numpy() 432 | }) 433 | return config -------------------------------------------------------------------------------- /tensorflow/models/model_2d.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers 3 | from models.layers import * 4 | 5 | def conv2d_block(x, filters, kernel_size, padding='same', dilation_rate=1, batch_norm=True, activation='relu'): 6 | """ 7 | Applies Conv2D - BN - ReLU block. 8 | """ 9 | x = layers.Conv2D(filters, kernel_size, padding=padding, use_bias=False)(x) 10 | 11 | if batch_norm: 12 | x = layers.BatchNormalization()(x) 13 | 14 | if activation is not None: 15 | x = layers.Activation(activation)(x) 16 | 17 | return x 18 | 19 | 20 | def encoder_block(x, filters, kernel_size, padding='same', dilation_rate=1, pooling='max'): 21 | """ 22 | Encoder block used in contracting path of UNet. 23 | """ 24 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 25 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 26 | x_skip = x 27 | # print(x.shape) 28 | if pooling == 'max': 29 | x = layers.MaxPooling2D(pool_size=(2, 2))(x) 30 | elif pooling == 'average': 31 | x = layers.AveragePooling2D(pool_size=(2, 2))(x) 32 | else: 33 | assert False, 'Pooling layer {} not implemented'.format(pooling) 34 | 35 | return x, x_skip 36 | 37 | 38 | def decoder_block(x, x_skip, filters, kernel_size, padding='same', dilation_rate=1): 39 | """ 40 | Decoder block used in expansive path of UNet. 
41 | """ 42 | x = layers.UpSampling2D(size=(2, 2), interpolation='nearest')(x) 43 | 44 | # Calculate cropping for down_tensor to concatenate with x 45 | 46 | if x_skip is not None: 47 | _, h2, w2, _ = x_skip.shape 48 | _, h1, w1, _ = x.shape 49 | h_diff, w_diff = h2 - h1, w2 - w1 50 | 51 | cropping = ((int(np.ceil(h_diff / 2)), int(np.floor(h_diff / 2))), 52 | (int(np.ceil(w_diff / 2)), int(np.floor(w_diff / 2)))) 53 | x_skip = layers.Cropping2D(cropping=cropping)(x_skip) 54 | x = layers.concatenate([x, x_skip], axis=3) 55 | 56 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 57 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 58 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 59 | 60 | return x 61 | 62 | 63 | ################################################################################################################################################ 64 | 65 | 66 | def decoder_block_resize(x, x_skip, filters, kernel_size, padding='same', dilation_rate=1): 67 | """ 68 | Decoder block used in expansive path of UNet. Unlike before, this block resizes the skip connections rather than cropping. 
69 | """ 70 | # print(x.shape) 71 | # print(x_skip.shape[1:3]) 72 | x = tf.image.resize(x, x_skip.shape[1:3], method='nearest') 73 | 74 | x = layers.concatenate([x, x_skip], axis=3) 75 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 76 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 77 | x = conv2d_block(x, filters, kernel_size, padding, dilation_rate, batch_norm=True, activation='relu') 78 | 79 | return x 80 | 81 | 82 | def UNet(height, width, encoding_cs=[24, 64, 128, 256, 512, 1024], 83 | center_cs=1024, 84 | decoding_cs=[512, 256, 128, 64, 24, 24], 85 | skip_connections=[True, True, True, True, True, False]): 86 | 87 | """ 88 | Basic UNet which does not require cropping. 89 | 90 | Inputs: 91 | - height: input height 92 | - width: input width 93 | - encoding_cs: list of channels along contracting path 94 | - decoding_cs: list of channels along expansive path 95 | """ 96 | 97 | inputs = tf.keras.Input((height, width, 1)) 98 | 99 | x = inputs 100 | 101 | skips = [] 102 | 103 | # Contracting path 104 | for c in encoding_cs: 105 | x, x_skip = encoder_block(x, c, kernel_size=3, padding='same', dilation_rate=1, pooling='average') 106 | skips.append(x_skip) 107 | 108 | skips = list(reversed(skips)) 109 | 110 | # Center 111 | x = conv2d_block(x, center_cs, kernel_size=3, padding='same') 112 | 113 | # Expansive path 114 | for i, c in enumerate(decoding_cs): 115 | if skip_connections[i]: 116 | x = decoder_block_resize(x, skips[i], c, kernel_size=3, padding='same', dilation_rate=1) 117 | else: 118 | x = decoder_block(x, None, c, kernel_size=3, padding='same', dilation_rate=1) 119 | 120 | # Classify 121 | x = layers.Conv2D(filters=1, kernel_size=1, use_bias=True, activation='relu')(x) 122 | # outputs=x 123 | outputs = tf.squeeze(x, axis=3) 124 | 125 | model = tf.keras.Model(inputs=[inputs], outputs=[outputs]) 126 | 127 | return model 128 | 129 | 130 | 131 | def 
UNet_multiwiener_resize(height, width, initial_psfs, initial_Ks, 132 | encoding_cs=[24, 64, 128, 256, 512, 1024], 133 | center_cs=1024, 134 | decoding_cs=[512, 256, 128, 64, 24, 24], 135 | skip_connections=[True, True, True, True, True, True]): 136 | """ 137 | Multiwiener UNet which doesn't require cropping. 138 | 139 | Inputs: 140 | - height: input height 141 | - width: input width 142 | - initial_psfs: preinitialized psfs 143 | - initial_Ks: regularization terms for Wiener deconvolutions 144 | - encoding_cs: list of channels along contracting path 145 | - decoding_cs: list of channels along expansive path 146 | - skip_connections: list of boolean to determine whether to concatenate with decoding channel at that index 147 | """ 148 | 149 | inputs = tf.keras.Input((height, width, 1)) 150 | 151 | x = inputs 152 | 153 | # Multi-Wiener deconvolutions 154 | x = MultiWienerDeconvolution(initial_psfs, initial_Ks)(x) 155 | 156 | skips = [] 157 | 158 | # Contracting path 159 | for c in encoding_cs: 160 | x, x_skip = encoder_block(x, c, kernel_size=3, padding='same', dilation_rate=1, pooling='average') 161 | skips.append(x_skip) 162 | 163 | skips = list(reversed(skips)) 164 | 165 | # Center 166 | x = conv2d_block(x, center_cs, kernel_size=3, padding='same') 167 | 168 | # Expansive path 169 | for i, c in enumerate(decoding_cs): 170 | if skip_connections[i]: 171 | x = decoder_block_resize(x, skips[i], c, kernel_size=3, padding='same', dilation_rate=1) 172 | else: 173 | x = decoder_block(x, None, c, kernel_size=3, padding='same', dilation_rate=1) 174 | 175 | # Classify 176 | x = layers.Conv2D(filters=1, kernel_size=1, use_bias=True, activation='relu')(x) 177 | outputs = tf.squeeze(x, axis=3) 178 | 179 | model = tf.keras.Model(inputs=[inputs], outputs=[outputs]) 180 | 181 | return model 182 | 183 | 184 | 185 | def UNet_wiener(height, width, initial_psf, initial_K, 186 | encoding_cs=[24, 64, 128, 256, 512, 1024], 187 | center_cs=1024, 188 | decoding_cs=[512, 256, 128, 64, 24, 
24], 189 | skip_connections=[True, True, True, True, True, True]): 190 | """ 191 | Single Wiener UNet which doesn't require cropping. 192 | 193 | Inputs: 194 | - height: input height 195 | - width: input width 196 | - initial_psf: preinitialized psf 197 | - initial_K: regularization term for Wiener deconvolution 198 | - encoding_cs: list of channels along contracting path 199 | - decoding_cs: list of channels along expansive path 200 | - skip_connections: list of boolean to determine whether to concatenate with decoding channel at that index 201 | """ 202 | 203 | inputs = tf.keras.Input((height, width, 1)) 204 | 205 | x = inputs 206 | 207 | # Multi-Wiener deconvolutions 208 | x = WienerDeconvolution(initial_psf, initial_K)(x) 209 | 210 | skips = [] 211 | 212 | # Contracting path 213 | for c in encoding_cs: 214 | x, x_skip = encoder_block(x, c, kernel_size=3, padding='same', dilation_rate=1, pooling='average') 215 | skips.append(x_skip) 216 | 217 | skips = list(reversed(skips)) 218 | 219 | # Center 220 | x = conv2d_block(x, center_cs, kernel_size=3, padding='same') 221 | 222 | # Expansive path 223 | for i, c in enumerate(decoding_cs): 224 | if skip_connections[i]: 225 | x = decoder_block_resize(x, skips[i], c, kernel_size=3, padding='same', dilation_rate=1) 226 | else: 227 | x = decoder_block(x, None, c, kernel_size=3, padding='same', dilation_rate=1) 228 | 229 | # Classify 230 | x = layers.Conv2D(filters=1, kernel_size=1, use_bias=True, activation='relu')(x) 231 | outputs = tf.squeeze(x, axis=3) 232 | 233 | model = tf.keras.Model(inputs=[inputs], outputs=[outputs]) 234 | 235 | return model 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /tensorflow/training_code.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext 
autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "import os, glob\n", 13 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' \n", 14 | "\n", 15 | "import numpy as np\n", 16 | "import scipy.io\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import tensorflow as tf\n", 19 | "import tensorflow.keras as keras\n", 20 | "from tensorflow.keras import backend as K\n", 21 | "from tensorflow.keras import layers\n", 22 | "from tensorflow.keras.callbacks import ModelCheckpoint\n", 23 | "\n", 24 | "import matplotlib as mpl\n", 25 | "mpl.rc('image', cmap='inferno')\n", 26 | "\n", 27 | "import models.model_2d as mod\n", 28 | "import forward_model as fm\n", 29 | "import utils as ut" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "!gpustat" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 48 | "\n", 49 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\" " 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "# Training code for 2D spatially-varying deconvolutions" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Make dataset and dataloader for training data" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "AUTOTUNE = tf.data.experimental.AUTOTUNE\n", 73 | "batch_size = 2" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "target_dir = '/home/kyrollos/LearnedMiniscope3D/Data/Target/' # path to objects (ground truth)\n", 83 | "input_dir = '/home/kyrollos/LearnedMiniscope3D/Data/Train/' # path to simulated measurements (inputs to deconv.)\n", 84 | "\n", 85 | "target_path = 
sorted(glob.glob(target_dir + '*'))\n", 86 | "input_path = sorted(glob.glob(input_dir + '*'))\n", 87 | "\n", 88 | "image_count=len(os.listdir(target_dir))\n", 89 | "print(image_count) " 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# Create a first dataset of file paths and labels\n", 99 | "dataset = tf.data.Dataset.from_tensor_slices((input_path, target_path))\n", 100 | "dataset = dataset.shuffle(image_count, reshuffle_each_iteration=False)\n", 101 | "\n", 102 | "\n", 103 | "# Split into train/validation\n", 104 | "val_size = int(image_count * 0.25)\n", 105 | "train_ds = dataset.skip(val_size)\n", 106 | "val_ds = dataset.take(val_size)\n", 107 | "\n", 108 | "print(tf.data.experimental.cardinality(train_ds).numpy())\n", 109 | "print(tf.data.experimental.cardinality(val_ds).numpy())\n", 110 | "\n", 111 | "train_ds = train_ds.map(ut.parse_function, num_parallel_calls=AUTOTUNE)\n", 112 | "val_ds = val_ds.map(ut.parse_function, num_parallel_calls=AUTOTUNE)\n", 113 | "\n", 114 | "train_ds = ut.configure_for_performance(train_ds,batch_size)\n", 115 | "val_ds = ut.configure_for_performance(val_ds,batch_size)\n", 116 | "\n", 117 | "print(tf.data.experimental.cardinality(train_ds).numpy())\n", 118 | "print(tf.data.experimental.cardinality(val_ds).numpy())" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "# visualize data to make sure all is good\n", 128 | "input_batch, target_batch = next(iter(val_ds))\n", 129 | "f, ax = plt.subplots(1, 2, figsize=(15,15))\n", 130 | "\n", 131 | "ax[0].imshow(input_batch[0,:,:,0], vmax = 1)\n", 132 | "ax[0].set_title('Input Data')\n", 133 | "\n", 134 | "ax[1].imshow(target_batch[0,:,:,0], vmax = 1)\n", 135 | "ax[1].set_title('Target Data')\n", 136 | "\n", 137 | "print(input_batch[0,:,:,0].shape)" 138 | ] 139 | }, 140 | { 141 | "cell_type":
"markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "# load in Psfs and initialize network to train\n", 145 | "\n", 146 | "Here we initialize with 9 PSFs taken from different parts in the field of view" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "# choose network type to train\n", 156 | "model_type='multiwiener' # choices are 'multiwiener', 'wiener', 'unet'\n", 157 | "filter_init_path = '../data/multiWienerPSFStack_40z_aligned.mat' # initialize with 9 PSFs\n", 158 | "filter_key = 'multiWienerPSFStack_40z' # key to load in" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "if model_type=='unet':\n", 168 | " model =mod.UNet(486, 648, \n", 169 | " encoding_cs=[24, 64, 128, 256, 512, 1024],\n", 170 | " center_cs=1024,\n", 171 | " decoding_cs=[512, 256, 128, 64, 24, 24],\n", 172 | " skip_connections=[True, True, True, True, True, False])\n", 173 | "elif model_type=='wiener':\n", 174 | "\n", 175 | " registered_psfs_path = filter_init_path\n", 176 | " psfs = scipy.io.loadmat(registered_psfs_path)\n", 177 | " psfs=psfs[filter_key]\n", 178 | " psfs=psfs[:,:,0,0]\n", 179 | " psfs=psfs/np.max(psfs)\n", 180 | " \n", 181 | " Ks=1\n", 182 | "\n", 183 | " model = mod.UNet_wiener(486, 648, psfs, Ks, \n", 184 | " encoding_cs=[24, 64, 128, 256, 512, 1024],\n", 185 | " center_cs=1024,\n", 186 | " decoding_cs=[512, 256, 128, 64, 24, 24],\n", 187 | " skip_connections=[True, True, True, True, True, False])\n", 188 | " \n", 189 | " print(psfs.shape, 1)\n", 190 | " \n", 191 | "elif model_type=='multiwiener':\n", 192 | " registered_psfs_path = filter_init_path\n", 193 | " psfs = scipy.io.loadmat(registered_psfs_path)\n", 194 | " psfs=psfs[filter_key]\n", 195 | " \n", 196 | " psfs=psfs[:,:,:,0]\n", 197 | " psfs=psfs/np.max(psfs)\n", 198 | " \n", 199 | " Ks =np.ones((1,1,9))\n", 
200 | " \n", 201 | " model =mod.UNet_multiwiener_resize(486, 648, psfs, Ks, \n", 202 | " encoding_cs=[24, 64, 128, 256, 512, 1024],\n", 203 | " center_cs=1024,\n", 204 | " decoding_cs=[512, 256, 128, 64, 24, 24],\n", 205 | " skip_connections=[True, True, True, True, True, False])\n", 206 | " \n", 207 | " print('initialized filter shape:', psfs.shape, 'initialized K shape:', Ks.shape)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "model.build((None, 486, 648, 1))\n", 217 | "\n", 218 | "model.summary()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "# Train" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "## Training with TF.Dataset\n", 235 | "initial_learning_rate = 1e-4\n", 236 | "optimizer = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate, beta_1=0.9, beta_2=0.999, amsgrad=False) #1e-3 diverges\n", 237 | "\n", 238 | "# Keep results for plotting\n", 239 | "train_loss_results = []\n", 240 | "train_accuracy_results = []\n", 241 | "validtate_loss_results=[]\n", 242 | "num_epochs = 1000\n", 243 | "loss_func=ut.SSIMLoss_l1\n", 244 | "learning_rate_counter=0\n", 245 | "for epoch in range(num_epochs):\n", 246 | " validation_loss_avg=tf.keras.metrics.Mean()\n", 247 | " epoch_loss_avg = tf.keras.metrics.Mean()\n", 248 | " epoch_accuracy = tf.keras.metrics.MeanSquaredError()\n", 249 | "\n", 250 | " # Training loop\n", 251 | " iter_num=0\n", 252 | " for x, y in train_ds:\n", 253 | " # Optimize the model\n", 254 | " loss_value, grads = ut.grad(model,loss_func, x, y)\n", 255 | " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", 256 | "\n", 257 | " # Track progress\n", 258 | " epoch_loss_avg.update_state(loss_value) # Add current batch loss\n", 259 | "\n", 260 | " 
epoch_accuracy.update_state(y, model(x)) \n", 261 | " # Print every 1\n", 262 | " if iter_num % 1 == 0:\n", 263 | " print(\"Epoch {:03d}: Step: {:03d}, Loss: {:.3f}, MSE: {:.3}\".format(epoch, iter_num,epoch_loss_avg.result(),\n", 264 | " epoch_accuracy.result()),end='\\r')\n", 265 | " iter_num=iter_num+1\n", 266 | " \n", 267 | " \n", 268 | "\n", 269 | " # End epoch\n", 270 | " train_loss_results.append(epoch_loss_avg.result())\n", 271 | " train_accuracy_results.append(epoch_accuracy.result())\n", 272 | "\n", 273 | "\n", 274 | " for x_val, y_val in val_ds:\n", 275 | " val_loss_value= loss_func(model, x_val, y_val)\n", 276 | " validation_loss_avg.update_state(val_loss_value)\n", 277 | " \n", 278 | " \n", 279 | " validtate_loss_results.append(validation_loss_avg.result()) \n", 280 | " if epoch % 1 == 0:\n", 281 | " print(\"Epoch {:03d}: MSE: {:.3}, Training Loss: {:.3f}, Validation Loss: {:.3f}\".format(epoch,\n", 282 | " epoch_accuracy.result(), epoch_loss_avg.result(), \n", 283 | " validation_loss_avg.result()))" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "# model.load_weights('./saved_models/multiwiener')" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "# test on validation data\n", 302 | "input_batch, target_batch = next(iter(val_ds))\n", 303 | "imnum=1\n", 304 | "f, ax = plt.subplots(1, 2, figsize=(15,15))\n", 305 | "ax[0].imshow((target_batch[imnum,:,:,0]))\n", 306 | "ax[0].set_title('Target Data')\n", 307 | "\n", 308 | "test=model(input_batch[imnum,:,:,0].numpy().reshape((1,486, 648,1)))\n", 309 | "ax[1].set_title('recon')\n", 310 | "ax[1].imshow(test[0,:,:])\n" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "Once training is working, save your model using: \n", 318 | "\n", 319 | " 
model.save_weights('./saved_models/model_name')\n", 320 | "\n", 321 | "You can save after training is complete, or periodically throughout epochs." 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [] 330 | } 331 | ], 332 | "metadata": { 333 | "kernelspec": { 334 | "display_name": "eager-latest2", 335 | "language": "python", 336 | "name": "homekyrollosanaconda3env" 337 | }, 338 | "language_info": { 339 | "codemirror_mode": { 340 | "name": "ipython", 341 | "version": 3 342 | }, 343 | "file_extension": ".py", 344 | "mimetype": "text/x-python", 345 | "name": "python", 346 | "nbconvert_exporter": "python", 347 | "pygments_lexer": "ipython3", 348 | "version": "3.5.6" 349 | } 350 | }, 351 | "nbformat": 4, 352 | "nbformat_minor": 4 353 | } 354 | -------------------------------------------------------------------------------- /tensorflow/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math 4 | 5 | def normalize(x): 6 | """ 7 | Normalizes numpy array to [0, 1]. 8 | """ 9 | a = np.min(x) 10 | b = np.max(x) 11 | return (x - a) / (b - a) 12 | 13 | def pad_2d(x, mode='constant'): 14 | """ 15 | Pads 2d array x before FFT convolution. 16 | """ 17 | _, _, h, w = x.shape 18 | padding = ((0, 0), (0, 0), 19 | (int(np.ceil(h / 2)), int(np.floor(h / 2))), 20 | (int(np.ceil(w / 2)), int(np.floor(w / 2)))) 21 | x = np.pad(x, pad_width=padding, mode=mode) 22 | return x 23 | 24 | def pad_2d_tf(x, mode='CONSTANT', axes=(-2, -1)): 25 | """ 26 | Fix pad_2d to allow variable length dimensions. 
27 | """ 28 | n_dim = len(x.shape) 29 | axes = np.array(axes) % n_dim 30 | 31 | h, w = np.array(x.shape)[axes] 32 | padding = np.array([(0, 0)] * n_dim) 33 | padding[axes] = (int(np.ceil(h / 2)), int(np.floor(h / 2))), (int(np.ceil(w / 2)), int(np.floor(w / 2))) 34 | 35 | x = tf.pad(x, paddings=padding, mode=mode) 36 | return x 37 | 38 | 39 | 40 | def crop_2d(v): 41 | """ 42 | Crops 2d array x after FFT convolution. Inverse of pad2d. 43 | """ 44 | h, w = v.shape 45 | h1, h2 = int(np.ceil(h / 4)), h - int(np.floor(h / 4)) 46 | w1, w2 = int(np.ceil(w / 4)), w - int(np.floor(w / 4)) 47 | return v[h1:h2, w1:w2] 48 | 49 | def crop_2d_tf(v): 50 | """ 51 | Crops 2d array v after FFT convolution. Inverse of pad2d. 52 | """ 53 | n_dim = len(v.shape) 54 | 55 | if n_dim == 2: 56 | h, w = v.shape 57 | h1, h2 = int(np.ceil(h / 4)), h - int(np.floor(h / 4)) 58 | w1, w2 = int(np.ceil(w / 4)), w - int(np.floor(w / 4)) 59 | return v[h1:h2, w1:w2] 60 | elif n_dim == 3: 61 | _, h, w = v.shape 62 | h1, h2 = int(np.ceil(h / 4)), h - int(np.floor(h / 4)) 63 | w1, w2 = int(np.ceil(w / 4)), w - int(np.floor(w / 4)) 64 | return v[:, h1:h2, w1:w2] 65 | elif n_dim == 4: 66 | _, h, w, _ = v.shape 67 | h1, h2 = int(np.ceil(h / 4)), h - int(np.floor(h / 4)) 68 | w1, w2 = int(np.ceil(w / 4)), w - int(np.floor(w / 4)) 69 | return v[:, h1:h2, w1:w2, :] 70 | 71 | elif n_dim == 5: 72 | _, h, w, _,_ = v.shape 73 | h1, h2 = int(np.ceil(h / 4)), h - int(np.floor(h / 4)) 74 | w1, w2 = int(np.ceil(w / 4)), w - int(np.floor(w / 4)) 75 | return v[:, h1:h2, w1:w2, :,:] 76 | 77 | 78 | def calc_psnr(Iin,Itarget): 79 | 80 | mse=np.mean(np.square(Iin-Itarget)) 81 | return 10*math.log10(1/mse) 82 | 83 | 84 | def parse_function(inputname, outputname): 85 | 86 | # Read an image from a file 87 | input_string = tf.io.read_file(inputname) 88 | # Decode it into a dense vector 89 | input_decoded = tf.cast(tf.image.decode_png(input_string, channels=1), tf.float32) 90 | # Resize it to fixed shape 91 | # input_resized = 
tf.image.resize(input_decoded, [img_height, img_width]) 92 | input_normalized = input_decoded / 255.0 93 | 94 | # Read an image from a file 95 | output_string = tf.io.read_file(outputname) 96 | # Decode it into a dense vector 97 | output_decoded = tf.cast(tf.image.decode_png(output_string, channels=1), tf.float32) 98 | # Resize it to fixed shape 99 | # Normalize it from [0, 255] to [0.0, 1.0] 100 | output_normalized = output_decoded / 255.0 101 | 102 | return input_normalized, output_normalized 103 | 104 | 105 | def configure_for_performance(ds,batch_size): #shuffte, batch, and have batches available asap 106 | ds = ds.cache() 107 | ds = ds.shuffle(buffer_size=1000) 108 | ds = ds.batch(batch_size) 109 | return ds 110 | 111 | 112 | 113 | def SSIMLoss(y_true, y_pred): 114 | y_true = y_true[..., np.newaxis] 115 | y_pred = y_pred[..., np.newaxis] 116 | 117 | return 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0)) 118 | 119 | def SSIMLoss_l1(y_true, y_pred): 120 | y_true = y_true[..., np.newaxis] 121 | y_pred = y_pred[..., np.newaxis] 122 | L1=tf.reduce_mean(tf.abs(y_true-y_pred)) 123 | 124 | return (1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0)))+L1 125 | 126 | 127 | 128 | def grad(model, myloss,inputs, targets): 129 | with tf.GradientTape() as tape: 130 | loss_value = myloss(model, inputs, targets) 131 | return loss_value, tape.gradient(loss_value, model.trainable_variables) 132 | 133 | 134 | def SSIMLoss(y_true, y_pred): 135 | y_true = y_true[..., np.newaxis] 136 | y_pred = y_pred[..., np.newaxis] 137 | 138 | return 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0)) 139 | 140 | 141 | 142 | def SSIMLoss_l1(model,x,y_true): 143 | y_pred=model(x) 144 | y_pred = tf.expand_dims(y_pred, -1) 145 | loss_l1 = tf.reduce_mean(tf.abs(y_pred - y_true), axis=-1) 146 | loss_ssim=1.0 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0)) 147 | return loss_l1+loss_ssim 148 | -------------------------------------------------------------------------------- 
/tensorflow/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/MultiWienerNet/f49a38e74a73bcf58f91a46d3ff0d2360d213283/tensorflow/utils.pyc --------------------------------------------------------------------------------