├── ADMM.py ├── GD.py ├── LICENSE ├── README.md ├── admm_config.yml ├── environment.yml ├── gd_config.yml ├── rpi └── preview.py ├── test_images ├── cal_logo_rgb.png ├── dog_rgb.jpg ├── google_chrome_logo_rgb.png └── spiral_bw.gif └── tutorial ├── ADMM.ipynb ├── GD.ipynb ├── psf_sample.tif └── rawdata_hand_sample.tif /ADMM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.fft as fft 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | import yaml 6 | 7 | def loadData(show_im=True): 8 | psf = Image.open(psfname) 9 | psf = np.array(psf, dtype='float32') 10 | data = Image.open(imgname) 11 | data = np.array(data, dtype='float32') 12 | 13 | """In the picamera, there is a non-trivial background 14 | (even in the dark) that must be subtracted""" 15 | bg = np.mean(psf[5:15,5:15]) 16 | psf -= bg 17 | data -= bg 18 | 19 | """Resize to a more manageable size to do reconstruction on. 20 | Because resizing is downsampling, it is subject to aliasing 21 | (artifacts produced by the periodic nature of sampling). Demosaicing is an attempt 22 | to account for/reduce the aliasing caused. In this application, we do the simplest 23 | possible demosaicing algorithm: smoothing/blurring the image with a box filter""" 24 | 25 | def resize(img, factor): 26 | num = int(-np.log2(factor)) 27 | for i in range(num): 28 | img = 0.25*(img[::2,::2,...]+img[1::2,::2,...]+img[::2,1::2,...]+img[1::2,1::2,...]) 29 | return img 30 | 31 | 32 | psf = resize(psf, f) 33 | data = resize(data, f) 34 | 35 | """Now we normalize the images so they have the same total power. 
Technically not a 36 | necessary step, but the optimal hyperparameters are a function of the total power in 37 | the PSF (among other things), so it makes sense to standardize it""" 38 | 39 | psf /= np.linalg.norm(psf.ravel()) 40 | data /= np.linalg.norm(data.ravel()) 41 | 42 | if show_im: 43 | fig1 = plt.figure() 44 | plt.imshow(psf, cmap='gray') 45 | plt.title('PSF') 46 | fig2 = plt.figure() 47 | plt.imshow(data, cmap='gray') 48 | plt.title('Raw data') 49 | plt.show() 50 | return psf, data 51 | 52 | def U_update(eta, image_est, tau): 53 | return SoftThresh(Psi(image_est) + eta/mu2, tau/mu2) 54 | 55 | 56 | def SoftThresh(x, tau): 57 | # numpy automatically applies functions to each element of the array 58 | return np.sign(x)*np.maximum(0, np.abs(x) - tau) 59 | 60 | 61 | def Psi(v): 62 | return np.stack((np.roll(v,1,axis=0) - v, np.roll(v, 1, axis=1) - v), axis=2) 63 | 64 | 65 | def X_update(xi, image_est, H_fft, sensor_reading, X_divmat): 66 | return X_divmat * (xi + mu1*M(image_est, H_fft) + CT(sensor_reading)) 67 | 68 | 69 | def M(vk, H_fft): 70 | return np.real(fft.fftshift(fft.ifft2(fft.fft2(fft.ifftshift(vk))*H_fft))) 71 | 72 | 73 | def C(M): 74 | # Image stored as matrix (row-column rather than x-y) 75 | top = (full_size[0] - sensor_size[0])//2 76 | bottom = (full_size[0] + sensor_size[0])//2 77 | left = (full_size[1] - sensor_size[1])//2 78 | right = (full_size[1] + sensor_size[1])//2 79 | return M[top:bottom,left:right] 80 | 81 | def CT(b): 82 | v_pad = (full_size[0] - sensor_size[0])//2 83 | h_pad = (full_size[1] - sensor_size[1])//2 84 | return np.pad(b, ((v_pad, v_pad), (h_pad, h_pad)), 'constant',constant_values=(0,0)) 85 | 86 | 87 | def precompute_X_divmat(): 88 | """Only call this function once! 
89 | Store it in a variable and only use that variable 90 | during every update step""" 91 | return 1./(CT(np.ones(sensor_size)) + mu1) 92 | 93 | def W_update(rho, image_est): 94 | return np.maximum(rho/mu3 + image_est, 0) 95 | 96 | def r_calc(w, rho, u, eta, x, xi, H_fft): 97 | return (mu3*w - rho)+PsiT(mu2*u - eta) + MT(mu1*x - xi, H_fft) 98 | 99 | def V_update(w, rho, u, eta, x, xi, H_fft, R_divmat): 100 | freq_space_result = R_divmat*fft.fft2( fft.ifftshift(r_calc(w, rho, u, eta, x, xi, H_fft)) ) 101 | return np.real(fft.fftshift(fft.ifft2(freq_space_result))) 102 | 103 | def PsiT(U): 104 | diff1 = np.roll(U[...,0],-1,axis=0) - U[...,0] 105 | diff2 = np.roll(U[...,1],-1,axis=1) - U[...,1] 106 | return diff1 + diff2 107 | 108 | def MT(x, H_fft): 109 | x_zeroed = fft.ifftshift(x) 110 | return np.real(fft.fftshift(fft.ifft2(fft.fft2(x_zeroed) * np.conj(H_fft)))) 111 | 112 | def precompute_PsiTPsi(): 113 | PsiTPsi = np.zeros(full_size) 114 | PsiTPsi[0,0] = 4 115 | PsiTPsi[0,1] = PsiTPsi[1,0] = PsiTPsi[0,-1] = PsiTPsi[-1,0] = -1 116 | PsiTPsi = fft.fft2(PsiTPsi) 117 | return PsiTPsi 118 | 119 | 120 | def precompute_R_divmat(H_fft, PsiTPsi): 121 | """Only call this function once! 122 | Store it in a variable and only use that variable 123 | during every update step""" 124 | MTM_component = mu1*(np.abs(np.conj(H_fft)*H_fft)) 125 | PsiTPsi_component = mu2*np.abs(PsiTPsi) 126 | id_component = mu3 127 | """This matrix is a mask in frequency space. 
So we will only use 128 | it on images that have already been transformed via an fft""" 129 | return 1./(MTM_component + PsiTPsi_component + id_component) 130 | 131 | def xi_update(xi, V, H_fft, X): 132 | return xi + mu1*(M(V,H_fft) - X) 133 | 134 | def eta_update(eta, V, U): 135 | return eta + mu2*(Psi(V) - U) 136 | 137 | def rho_update(rho, V, W): 138 | return rho + mu3*(V - W) 139 | 140 | 141 | def init_Matrices(H_fft): 142 | X = np.zeros(full_size) 143 | U = np.zeros((full_size[0], full_size[1], 2)) 144 | V = np.zeros(full_size) 145 | W = np.zeros(full_size) 146 | 147 | xi = np.zeros_like(M(V,H_fft)) 148 | eta = np.zeros_like(Psi(V)) 149 | rho = np.zeros_like(W) 150 | return X,U,V,W,xi,eta,rho 151 | 152 | 153 | def precompute_H_fft(psf): 154 | return fft.fft2(fft.ifftshift(CT(psf))) 155 | 156 | def ADMM_Step(X,U,V,W,xi,eta,rho, precomputed): 157 | H_fft, data, X_divmat, R_divmat = precomputed 158 | U = U_update(eta, V, tau) 159 | X = X_update(xi, V, H_fft, data, X_divmat) 160 | V = V_update(W, rho, U, eta, X, xi, H_fft, R_divmat) 161 | W = W_update(rho, V) 162 | xi = xi_update(xi, V, H_fft, X) 163 | eta = eta_update(eta, V, U) 164 | rho = rho_update(rho, V, W) 165 | 166 | return X,U,V,W,xi,eta,rho 167 | 168 | 169 | def runADMM(psf, data): 170 | H_fft = precompute_H_fft(psf) 171 | X,U,V,W,xi,eta,rho = init_Matrices(H_fft) 172 | X_divmat = precompute_X_divmat() 173 | PsiTPsi = precompute_PsiTPsi() 174 | R_divmat = precompute_R_divmat(H_fft, PsiTPsi) 175 | 176 | for i in range(iters): 177 | X,U,V,W,xi,eta,rho = ADMM_Step(X,U,V,W,xi,eta,rho, [H_fft, data, X_divmat, R_divmat]) 178 | if i % disp_pic == 0: 179 | print(i) 180 | image = C(V) 181 | image[image<0] = 0 182 | f = plt.figure(1) 183 | plt.imshow(image, cmap='gray') 184 | plt.title('Reconstruction after iteration {}'.format(i)) 185 | plt.show() 186 | return image 187 | 188 | 189 | 190 | if __name__ == "__main__": 191 | ### Reading in params from config file (don't mess with parameter names!) 
192 | params = yaml.load(open("admm_config.yml")) 193 | for k,v in params.items(): 194 | exec(k + "=v") 195 | 196 | ### Loading images and initializing the required arrays 197 | psf, data = loadData(True) 198 | sensor_size = np.array(psf.shape) 199 | full_size = 2*sensor_size 200 | 201 | ### Running the algorithm 202 | final_im = runADMM(psf, data) 203 | plt.imshow(final_im, cmap='gray') 204 | plt.title('Final reconstructed image after {} iterations'.format(iters)) 205 | plt.show() 206 | saveim = input('Save final image? (y/n) ') 207 | if saveim == 'y': 208 | filename = input('Name of file: ') 209 | plt.imshow(final_im, cmap='gray') 210 | plt.axis('off') 211 | plt.savefig(filename+'.png', bbox_inches='tight') 212 | 213 | -------------------------------------------------------------------------------- /GD.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import numpy.fft as fft 4 | import matplotlib.pyplot as plt 5 | from PIL import Image 6 | import yaml 7 | 8 | 9 | def loaddata(show_im=True): 10 | psf = Image.open(psfname) 11 | psf = np.array(psf, dtype='float32') 12 | data = Image.open(imgname) 13 | data = np.array(data, dtype='float32') 14 | 15 | """In the picamera, there is a non-trivial background 16 | (even in the dark) that must be subtracted""" 17 | bg = np.mean(psf[5:15,5:15]) 18 | psf -= bg 19 | data -= bg 20 | 21 | """Resize to a more manageable size to do reconstruction on. 22 | Because resizing is downsampling, it is subject to aliasing 23 | (artifacts produced by the periodic nature of sampling). Demosaicing is an attempt 24 | to account for/reduce the aliasing caused. 
In this application, we do the simplest 25 | possible demosaicing algorithm: smoothing/blurring the image with a box filter""" 26 | 27 | def resize(img, factor): 28 | num = int(-np.log2(factor)) 29 | for i in range(num): 30 | img = 0.25*(img[::2,::2,...]+img[1::2,::2,...]+img[::2,1::2,...]+img[1::2,1::2,...]) 31 | return img 32 | 33 | psf = resize(psf, f) 34 | data = resize(data, f) 35 | 36 | 37 | """ nmormalizing copy from shreyas""" 38 | psf /= np.linalg.norm(psf.ravel()) 39 | data /= np.linalg.norm(data.ravel()) 40 | 41 | if show_im: 42 | fig1 = plt.figure() 43 | plt.imshow(psf, cmap='gray') 44 | plt.title('PSF') 45 | plt.show() 46 | fig2 = plt.figure() 47 | plt.imshow(data, cmap='gray') 48 | plt.title('Raw data') 49 | plt.show() 50 | return psf, data 51 | 52 | def initMatrices(h): 53 | pixel_start = (np.max(h) + np.min(h))/2 54 | x = np.ones(h.shape)*pixel_start 55 | 56 | init_shape = h.shape 57 | padded_shape = [nextPow2(2*n - 1) for n in init_shape] 58 | starti = (padded_shape[0]- init_shape[0])//2 59 | endi = starti + init_shape[0] 60 | startj = (padded_shape[1]//2) - (init_shape[1]//2) 61 | endj = startj + init_shape[1] 62 | hpad = np.zeros(padded_shape) 63 | hpad[starti:endi, startj:endj] = h 64 | 65 | H = fft.fft2(hpad, norm="ortho") 66 | Hadj = np.conj(H) 67 | 68 | def crop(X): 69 | return X[starti:endi, startj:endj] 70 | 71 | def pad(v): 72 | vpad = np.zeros(padded_shape).astype(np.complex64) 73 | vpad[starti:endi, startj:endj] = v 74 | return vpad 75 | 76 | utils = [crop, pad] 77 | v = np.real(pad(x)) 78 | 79 | return H, Hadj, v, utils 80 | 81 | def nextPow2(n): 82 | return int(2**np.ceil(np.log2(n))) 83 | 84 | def grad(Hadj, H, vk, b, crop, pad): 85 | Av = calcA(H, vk, crop) 86 | diff = Av - b 87 | return np.real(calcAHerm(Hadj, diff, pad)) 88 | 89 | def calcA(H, vk, crop): 90 | Vk = fft.fft2(vk, norm="ortho") 91 | return crop(fft.ifftshift(fft.ifft2(H*Vk, norm="ortho"))) 92 | 93 | def calcAHerm(Hadj, diff, pad): 94 | xpad = pad(diff) 95 | X = 
fft.fft2(xpad, norm="ortho") 96 | return fft.ifftshift(fft.ifft2(Hadj*X, norm="ortho")) 97 | 98 | 99 | def grad_descent(h, b): 100 | H, Hadj, v, utils = initMatrices(h) 101 | crop = utils[0] 102 | pad = utils[1] 103 | 104 | alpha = np.real(2/(np.max(Hadj * H))) 105 | iterations = 0 106 | 107 | def non_neg(xi): 108 | xi = np.maximum(xi,0) 109 | return xi 110 | 111 | #proj = lambda x: x #Do no projection 112 | proj = non_neg #Enforce nonnegativity at every gradient step. Comment out as needed. 113 | 114 | 115 | parent_var = [H, Hadj, b, crop, pad, alpha, proj] 116 | 117 | vk = v 118 | 119 | 120 | 121 | #### uncomment for Nesterov momentum update #### 122 | #p = 0 123 | #mu = 0.9 124 | ################################################ 125 | 126 | 127 | 128 | #### uncomment for FISTA update ################ 129 | tk = 1 130 | xk = v 131 | ################################################ 132 | 133 | for iterations in range(iters): 134 | 135 | # uncomment for regular GD update 136 | #vk = gd_update(vk, parent_var) 137 | 138 | # uncomment for Nesterov momentum update 139 | #vk, p = nesterov_update(vk, p, mu, parent_var) 140 | 141 | # uncomment for FISTA update 142 | vk, tk, xk = fista_update(vk, tk, xk, parent_var) 143 | 144 | if iterations % disp_pic == 0: 145 | print(iterations) 146 | image = proj(crop(vk)) 147 | f = plt.figure(1) 148 | plt.imshow(image, cmap='gray') 149 | plt.title('Reconstruction after iteration {}'.format(iterations)) 150 | plt.show() 151 | 152 | 153 | return proj(crop(vk)) 154 | 155 | def gd_update(vk, parent_var): 156 | H, Hadj, b, crop, pad, alpha, proj = parent_var 157 | 158 | gradient = grad(Hadj, H, vk, b, crop, pad) 159 | vk -= alpha*gradient 160 | vk = proj(vk) 161 | 162 | return xk 163 | 164 | def nesterov_update(vk, p, mu, parent_var): 165 | H, Hadj, b, crop, pad, alpha, proj = parent_var 166 | 167 | p_prev = p 168 | gradient = grad(Hadj, H, vk, b, crop, pad) 169 | p = mu*p - alpha*gradient 170 | vk += -mu*p_prev + (1+mu)*p 171 | vk = 
proj(vk) 172 | 173 | return vk, p 174 | 175 | def fista_update(vk, tk, xk, parent_var): 176 | H, Hadj, b, crop, pad, alpha, proj = parent_var 177 | 178 | x_k1 = xk 179 | gradient = grad(Hadj, H, vk, b, crop, pad) 180 | vk -= alpha*gradient 181 | xk = proj(vk) 182 | t_k1 = (1+np.sqrt(1+4*tk**2))/2 183 | vk = xk+(tk-1)/t_k1*(xk - x_k1) 184 | tk = t_k1 185 | 186 | return vk, tk, xk 187 | 188 | 189 | if __name__ == "__main__": 190 | ### Reading in params from config file (don't mess with parameter names!) 191 | params = yaml.load(open("gd_config.yml")) 192 | for k,v in params.items(): 193 | exec(k + "=v") 194 | 195 | psf, data = loaddata() 196 | final_im = grad_descent(psf, data) 197 | print(iters) 198 | plt.imshow(final_im, cmap='gray') 199 | plt.title('Final reconstruction after {} iterations'.format(iters)) 200 | plt.show() 201 | saveim = input('Save final image? (y/n) ') 202 | if saveim == 'y': 203 | filename = input('Name of file: ') 204 | plt.imshow(final_im, cmap='gray') 205 | plt.axis('off') 206 | plt.savefig(filename+'.png', bbox_inches='tight') 207 | 208 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Waller Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffuserCam-Tutorial 2 | #### See our [full tutorial](https://waller-lab.github.io/DiffuserCam/tutorial) for complete guides on setting up the DiffuserCam hardware and installing and running the software. 3 | Below is an overview of the organization of this repo. 4 |

5 | 6 | #### Home Directory 7 | The base directory contains Python code for processing DiffuserCam raw data with two algorithms, gradient descent (`GD.py`) and alternating direction method of multipliers (`ADMM.py`). The corresponding `.yml` files should be modified to include the file paths of the PSF image and of the raw data that is to be processed. 8 | 9 | #### Rpi Folder 10 | This folder contains Python code for previewing and capturing raw images using a Raspberry Pi camera. 11 | 12 | #### Tutorial Folder 13 | This folder contains Jupyter (IPython) notebooks that walk the user step-by-step through the two algorithms, gradient descent (`GD.ipynb`) and alternating direction method of multipliers (`ADMM.ipynb`). Sample test data is included. 14 | 15 | #### Test_Images Folder 16 | This folder contains sample images that you can place on a phone or laptop screen for testing your Raspberry Pi DiffuserCam. We recommend you start with `spiral_bw.gif`. 17 | 18 | 19 | -------------------------------------------------------------------------------- /admm_config.yml: -------------------------------------------------------------------------------- 1 | psfname: "./images/psf_box_exp8.tif" #path to psf image 2 | imgname: "./images/baffle_hand.tif" #path to raw data image file 3 | f: 0.25 #Downsampling factor (must be decimal, must be of form 1/2^k where k is positive integer) 4 | disp_pic: 4 #Number of iterations after which we display intermediate reconstruction 5 | mu1: 1.0e-6 #Decimal point is REQUIRED if using scientific notation 6 | mu2: 1.0e-5 7 | mu3: 4.0e-5 8 | tau: 0.0001 9 | iters: 1 10 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: diffuser_cam 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3 6 | - numpy 7 | - matplotlib 8 | - pillow 9 | - pip 10 | - pyyaml 11 |
-------------------------------------------------------------------------------- /gd_config.yml: -------------------------------------------------------------------------------- 1 | psfname: "./images/psf_box_exp8.tif" #Path to PSF image 2 | imgname: "./images/baffle_hand.tif" #Path to raw data image 3 | f: 0.125 #Downsampling factor (must be decimal, must be 1/2^k where k is positive integer) 4 | iters: 100 #Number of iterations 5 | disp_pic: 20 #Number of iterations after which we display intermediate reconstruction -------------------------------------------------------------------------------- /rpi/preview.py: -------------------------------------------------------------------------------- 1 | import picamera 2 | import picamera.array 3 | import numpy as np 4 | from PIL import Image 5 | 6 | if __name__== '__main__': 7 | camera = picamera.PiCamera() 8 | camera.resolution = camera.MAX_RESOLUTION 9 | camera.start_preview(resolution=(410,313),fullscreen=False,window=(20,20,820,616)) 10 | camera.exposure_mode = 'auto' 11 | 12 | for i in range(1): 13 | customize = input('Change shutter speed? 
(y/[n])') 14 | if customize == 'y': 15 | speed = int(input('shutter speed (mus) : ')) 16 | camera.shutter_speed = speed 17 | input('Press enter to take picture ') 18 | stream = picamera.array.PiBayerArray(camera) 19 | camera.capture(stream, 'jpeg', bayer=True) 20 | filename = input('Name of file: ') 21 | arr = np.sum(stream.array,axis=2).astype(np.uint8) 22 | img = Image.fromarray(arr) 23 | img.save(filename) 24 | -------------------------------------------------------------------------------- /test_images/cal_logo_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/cal_logo_rgb.png -------------------------------------------------------------------------------- /test_images/dog_rgb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/dog_rgb.jpg -------------------------------------------------------------------------------- /test_images/google_chrome_logo_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/google_chrome_logo_rgb.png -------------------------------------------------------------------------------- /test_images/spiral_bw.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/spiral_bw.gif -------------------------------------------------------------------------------- /tutorial/GD.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | 
"source": [ 7 | "### Gradient Descent and FISTA" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "$\\newcommand\\measurementvec{\\mathbf{b}}\n", 15 | "\\newcommand\\measurementmtx{\\mathbf{A}}\n", 16 | "\\newcommand\\imagevec{\\mathbf{v}}\n", 17 | "\\newcommand\\psf{\\mathbf{h}}\n", 18 | "\\newcommand{\\crop}{\\mathbf{C}}\n", 19 | "\\newcommand\\full{\\mathbf{A}}\n", 20 | "\\newcommand{\\ftpsf}{\\mathbf{H}}$\n", 21 | "Gradient descent is an iterative algorithm that finds the minimum of a convex function by following the slope \"downhill\" until it reaches a minimum. To solve the minimization problem\n", 22 | "\\begin{equation*}\n", 23 | " \\operatorname{minimize} g(\\mathbf{x}),\n", 24 | "\\end{equation*}\n", 25 | "we find the gradient of $g$ wrt $\\mathbf{x}$, $\\nabla_\\mathbf{x} g$, and use the property that the gradient always points in the direction of steepest _ascent_. In order to minimize $g$, we go the other direction:\n", 26 | "$$\\begin{align*}\n", 27 | " \\mathbf{x}_0 &= \\text{ initial guess} \\\\\n", 28 | " \\mathbf{x}_{k+1} &\\leftarrow \\mathbf{x}_k - \\alpha_k \\nabla g(\\mathbf{x}_k),\n", 29 | "\\end{align*}$$\n", 30 | "where $\\alpha$ is a step size that determines how far in the descent direction we go at each iteration.\n", 31 | "\n", 32 | "Applied to our problem:\n", 33 | "$$\\begin{align*}\n", 34 | " g(\\imagevec) &= \\frac{1}{2} \\|\\full\\imagevec- \\measurementvec \\|_2^2 \\\\\n", 35 | " \\nabla_\\imagevec g(\\imagevec) &= \\full^H (\\full\\imagevec-\\measurementvec),\n", 36 | "\\end{align*}$$\n", 37 | "where $\\full^H$ is the adjoint of $\\full$, $\\measurementvec$ is the sensor measurement and $\\imagevec$ is the image of the scene.\n", 38 | "\n", 39 | "We use more efficient variants of this algorithm, like Nesterov Momentum and FISTA, both of which are shown below. 
\n", 40 | "\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "#### Loading and preparing our images" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 1, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import numpy as np\n", 57 | "import numpy.fft as fft\n", 58 | "import matplotlib.pyplot as plt\n", 59 | "from IPython import display\n", 60 | "from PIL import Image\n", 61 | "\n", 62 | "%matplotlib inline" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "The code takes in two grayscale images: a point spread function (PSF) $\\texttt{psfname}$ and a sensor measurement $\\texttt{imgname}$. The images can be downsampled by a factor $f$, which must be of the form $1/2^k$ for some non-negative integer $k$ (typically between 1/2 and 1/8). " 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "psfname = \"./psf_sample.tif\"\n", 79 | "imgname = \"./rawdata_hand_sample.tif\"\n", 80 | "\n", 81 | "# Downsampling factor (used to shrink images)\n", 82 | "f = 1/8 \n", 83 | "\n", 84 | "# Number of iterations\n", 85 | "iters = 100" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def loaddata(show_im=True):\n", 95 | " psf = Image.open(psfname)\n", 96 | " psf = np.array(psf, dtype='float32')\n", 97 | " data = Image.open(imgname)\n", 98 | " data = np.array(data, dtype='float32')\n", 99 | " \n", 100 | " \"\"\"In the picamera, there is a non-trivial background \n", 101 | " (even in the dark) that must be subtracted\"\"\"\n", 102 | " bg = np.mean(psf[5:15,5:15]) \n", 103 | " psf -= bg\n", 104 | " data -= bg\n", 105 | " \n", 106 | " \"\"\"Resize to a more manageable size to do reconstruction on.
\n", 107 | " Because resizing is downsampling, it is subject to aliasing \n", 108 | " (artifacts produced by the periodic nature of sampling). Demosaicing is an attempt\n", 109 | " to account for/reduce the aliasing caused. In this application, we do the simplest\n", 110 | " possible demosaicing algorithm: smoothing/blurring the image with a box filter\"\"\"\n", 111 | " \n", 112 | " def resize(img, factor):\n", 113 | " num = int(-np.log2(factor))\n", 114 | " for i in range(num):\n", 115 | " img = 0.25*(img[::2,::2,...]+img[1::2,::2,...]+img[::2,1::2,...]+img[1::2,1::2,...])\n", 116 | " return img \n", 117 | " \n", 118 | " psf = resize(psf, f)\n", 119 | " data = resize(data, f)\n", 120 | " \n", 121 | " \n", 122 | " \"\"\"Now we normalize the images so they have the same total power. Technically not a\n", 123 | " necessary step, but the optimal hyperparameters are a function of the total power in \n", 124 | " the PSF (among other things), so it makes sense to standardize it\"\"\"\n", 125 | " \n", 126 | " psf /= np.linalg.norm(psf.ravel())\n", 127 | " data /= np.linalg.norm(data.ravel())\n", 128 | " \n", 129 | " if show_im:\n", 130 | " fig1 = plt.figure()\n", 131 | " plt.imshow(psf, cmap='gray')\n", 132 | " plt.title('PSF')\n", 133 | " display.display(fig1)\n", 134 | " fig2 = plt.figure()\n", 135 | " plt.imshow(data, cmap='gray')\n", 136 | " plt.title('Raw data')\n", 137 | " display.display(fig2)\n", 138 | " return psf, data\n", 139 | " " 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "### Calculating convolutions using $\\texttt{fft}$\n", 147 | "We want to calculate convolutions efficiently. To do this, we use the \"fast fourier transform\" $\\texttt{fft2}$ which computes the Discrete Fourier Transform (DFT). The convolution theorem for DFTs only holds for circular convolutions. 
We can still recover a linear convolution by first padding the input images and then cropping the output of the inverse DFT:\n", 148 | "\\begin{equation}\n", 149 | "h*x=\\mathcal{F}^{-1}[\\mathcal{F}[h]\\cdot\\mathcal{F}[x]] = \\texttt{crop}\\left[\\ \\texttt{DFT}^{-1}\\left\\{\\ \\texttt{DFT} [\\ \\texttt{pad}[h]\\ ]\\cdot\\texttt{DFT}[\\ \\texttt{pad}[x]\\ ]\\ \\right\\} \\ \\right]\n", 150 | "\\end{equation}\n", 151 | "\n", 152 | "Recovering the linear convolution correctly requires that we double the dimensions of our images. To take full advantage of the speed of the $\\texttt{fft2}$ algorithm, we actually pad to $\\texttt{full_size}$, the smallest power of two that is at least that doubled size.\n", 153 | "\n", 154 | "We have chosen $\\texttt{full_size}$ in such a way that it provides enough padding to make circular and linear convolutions look the same after being cropped back down to $\\texttt{sensor_size}$. That way, the \"sensor crop\" due to the sensor's finite size and the \"fft crop\" above are the same, and we just need one crop function." 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "Along with initialization, we compute $\\texttt{H} = \\texttt{fft2}(\\texttt{hpad})$ and $\\texttt{Hadj} = \\texttt{H}^*$, which are constant matrices that will be needed to calculate the action of $\\measurementmtx$ and $\\measurementmtx^H$ at every iteration. " 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "Lastly, we must take into account one more practical difference. In imaging, we often treat the center of the image as the origin of the coordinate system. This is theoretically convenient, but fft algorithms assume the origin of the image is the top left pixel. The magnitude of the fft doesn't change because of this distinction, but the phase does, since it is sensitive to shifts in real space.
An example with the simplest function, a delta function, is displayed below. In order to correct this problem, we use $\\texttt{ifftshift}$ to move the origin of an image to the top left corner and $\\texttt{fftshift}$ to move the origin from the top left corner to the center. " 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 4, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "data": { 178 | "image/png": "\n", 179 | "text/plain": [ 180 | "
" 181 | ] 182 | }, 183 | "metadata": {}, 184 | "output_type": "display_data" 185 | }, 186 | { 187 | "data": { 188 | "image/png": "\n", 189 | "text/plain": [ 190 | "
" 191 | ] 192 | }, 193 | "metadata": {}, 194 | "output_type": "display_data" 195 | } 196 | ], 197 | "source": [ 198 | "def no_shift():\n", 199 | " delta = np.zeros((5,5))\n", 200 | " delta[2][2] = 1\n", 201 | " fft_mag = np.abs(fft.fft2(delta))\n", 202 | " fft_arg = np.angle(fft.fft2(delta))\n", 203 | " \n", 204 | " fig, ax = plt.subplots(nrows=1, ncols=3)\n", 205 | " fig.tight_layout()\n", 206 | " ax[0].imshow(delta, cmap='gray')\n", 207 | " ax[0].set_title('Delta function in \\n real space')\n", 208 | "\n", 209 | " ax[1].imshow(fft_mag,vmin=-3,vmax=3,cmap='gray')\n", 210 | " ax[1].set_title('Magnitude of FT of \\n a delta function')\n", 211 | " \n", 212 | " ax[2].imshow(fft_arg,vmin=-3,vmax=3,cmap='gray')\n", 213 | " ax[2].set_title('Phase of FT of \\n delta function')\n", 214 | " \n", 215 | "no_shift() \n", 216 | "\n", 217 | "def shift():\n", 218 | " delta = np.zeros((5,5))\n", 219 | " delta[2][2] = 1\n", 220 | " delta_shifted = fft.ifftshift(delta)\n", 221 | " fft_mag = np.abs(fft.fft2(delta_shifted))\n", 222 | " fft_arg = np.angle(fft.fft2(delta_shifted))\n", 223 | " \n", 224 | " fig2, ax2 = plt.subplots(nrows=1, ncols=3)\n", 225 | " fig2.tight_layout()\n", 226 | " ax2[0].imshow(delta_shifted, cmap='gray')\n", 227 | " ax2[0].set_title('Delta function shifted in \\n real space')\n", 228 | "\n", 229 | " ax2[1].imshow(fft_mag,vmin=-3,vmax=3,cmap='gray')\n", 230 | " ax2[1].set_title('Magnitude of FT of a \\n shifted delta function')\n", 231 | " \n", 232 | " ax2[2].imshow(fft_arg,vmin=-3,vmax=3,cmap='gray')\n", 233 | " ax2[2].set_title('Phase of FT of a \\n shifted delta function')\n", 234 | " \n", 235 | "shift()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "For this notebook and the ADMM notebook, we follow the following convention so we don't have to worry about this issue again:\n", 243 | "1. 
All images in _real_ space are stored with the origin in the center (so they can be displayed correctly)\n", 244 | "2. All images in _Fourier_ space are stored with the origin in the top left corner (so they can be used for processing correctly)\n", 245 | "3. The above rules mean that, to perform a convolution between two real space images $h$ and $x$, we do $$\\texttt{fftshift}\\left( \\texttt{ifft}\\left[\\texttt{fft}[\\texttt{ifftshift}(h)] \\cdot \\texttt{fft}[\\texttt{ifftshift}(x)] \\right] \\right)$$ instead of $$\\texttt{ifft}\\left[\\texttt{fft}[h] \\cdot \\texttt{fft}[x]\\right]$$\n", 246 | "The rules imply that if we store the Fourier transform of $h$ for future use, instead of storing $\\texttt{fft}[h]$, we store $\\texttt{fft}[\\texttt{ifftshift}(h)]$." 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 5, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "def initMatrices(h):\n", 256 | " pixel_start = (np.max(h) + np.min(h))/2\n", 257 | " x = np.ones(h.shape)*pixel_start\n", 258 | "\n", 259 | " init_shape = h.shape\n", 260 | " padded_shape = [nextPow2(2*n - 1) for n in init_shape]\n", 261 | " starti = (padded_shape[0]- init_shape[0])//2\n", 262 | " endi = starti + init_shape[0]\n", 263 | " startj = (padded_shape[1]//2) - (init_shape[1]//2)\n", 264 | " endj = startj + init_shape[1]\n", 265 | " hpad = np.zeros(padded_shape)\n", 266 | " hpad[starti:endi, startj:endj] = h\n", 267 | "\n", 268 | " H = fft.fft2(fft.ifftshift(hpad), norm=\"ortho\")\n", 269 | " Hadj = np.conj(H)\n", 270 | "\n", 271 | " def crop(X):\n", 272 | " return X[starti:endi, startj:endj]\n", 273 | "\n", 274 | " def pad(v):\n", 275 | " vpad = np.zeros(padded_shape).astype(np.complex64)\n", 276 | " vpad[starti:endi, startj:endj] = v\n", 277 | " return vpad\n", 278 | "\n", 279 | " utils = [crop, pad]\n", 280 | " v = np.real(pad(x))\n", 281 | " \n", 282 | " return H, Hadj, v, utils\n", 283 | "\n", 284 | "def nextPow2(n):\n", 285 | " return int(2**np.ceil(np.log2(n)))" 286 | ] 287 | }, 288 | { 289 | "cell_type":
"markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "#### Computing the gradient" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "The most important step in Gradient Descent is calculating the gradient\n", 300 | "$$ \\nabla_\\imagevec \\ g(\\imagevec) = \\full^H (\\full\\imagevec-\\measurementvec)$$\n", 301 | "We do this in 2 steps:\n", 302 | "1. We compute the action of $\\full$ on $\\imagevec$, using $\\texttt{calcA}$\n", 303 | "2. We compute the action of $\\full^H$ on $\\texttt{diff} = \\texttt{Av-b}$ using $\\texttt{calcAHerm}$
\n", 304 | "\n", 305 | "Here, $\\texttt{vk}$ is the current padded estimate of the scene and $\\texttt{b}$ is the sensor measurement.\n" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 6, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "def grad(Hadj, H, vk, b, crop, pad):\n", 315 | " Av = calcA(H, vk, crop)\n", 316 | " diff = Av - b\n", 317 | " return np.real(calcAHerm(Hadj, diff, pad))" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "We write $\\full$ as:\n", 325 | "$$ \\full\\imagevec \\iff \\mathrm{crop} \\left[ \\mathcal{F}^{-1} \\left\\{\\mathcal{F}(h) \\cdot \\mathcal{F}(v)\\right\\} \\right]$$\n", 326 | "In code, this becomes\n", 327 | "\\begin{align*} \n", 328 | "\\texttt{calcA}(\\texttt{vk}) & = \\texttt{crop}\\ (\\texttt{ifft}\\ (\\texttt{fft}(\\texttt{hpad}) \\cdot \\texttt{fft}(\\texttt{vk})\\ )\\ )\\\\\n", 329 | "& = \\texttt{crop}\\ (\\texttt{ifft}\\ (\\texttt{H} \\cdot \\texttt{Vk}))\n", 330 | "\\end{align*}\n", 331 | "where $\\cdot$ represents point-wise multiplication" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 7, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "def calcA(H, vk, crop):\n", 341 | " Vk = fft.fft2(fft.ifftshift(vk))\n", 342 | " return crop(fft.fftshift(fft.ifft2(H*Vk)))" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": {}, 348 | "source": [ 349 | "We first pad $\\texttt{diff}$ , giving us $\\texttt{xpad}$, then we take the 2D fourier transform, $\\texttt{X} = \\mathcal{F}(\\texttt{xpad})$. 
The action of the adjoint of $\\full$ is\n", 350 | "\n", 351 | "$$ \\full^H \\mathbf{x} \\iff \\mathcal{F}^{-1} \\left\\{ \\mathcal{F}(\\psf)^* \\cdot \\mathcal{F}( \\operatorname{pad}\\left[\\mathbf{x}\\right]) \\right\\}$$\n", 352 | "This becomes\n", 353 | "\\begin{align*}\n", 354 | "\\texttt{calcAHerm}(\\texttt{diff}) &= \\texttt{ifft}\\ (\\ (\\texttt{fft}(\\texttt{h}))^* \\cdot \\texttt{fft}\\ (\\texttt{pad}(\\texttt{diff}))\\ ) \\\\\n", 355 | "& = \\texttt{ifft}\\ (\\texttt{Hadj} \\cdot \\texttt{X})\n", 356 | "\\end{align*}" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 8, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "def calcAHerm(Hadj, diff, pad):\n", 366 | " xpad = pad(diff)\n", 367 | " X = fft.fft2(fft.ifftshift(xpad))\n", 368 | " return fft.fftshift(fft.ifft2(Hadj*X))" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": {}, 374 | "source": [ 375 | "#### Putting it all together" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 382 | "This is the main function, which calculates the gradients and updates our estimation of the scene:\n", 383 | "\\begin{align*}\n", 384 | " &\\imagevec_0 = \\text{ anything} \\\\\n", 385 | " &\\text{for } k = 0 \\text{ to num_iters:}\\\\\n", 386 | " &\\quad \\quad \\imagevec_{k+1} \\leftarrow \\texttt{gradient_update}(\\imagevec_k) \\\\\n", 387 | "\\end{align*} \n", 388 | "\n", 389 | "There are different ways of doing the gradient update. 
The three we will show are regular GD, Nesterov momentum update, and FISTA.\n", 390 | "\n", 391 | "To guarantee convergence, we set the step size to be \n", 392 | "\\begin{align*} \n", 393 | "\\alpha_k &< \\frac{2}{\\|\\measurementmtx^H \\measurementmtx\\|_2} \\approx \\frac{2}{\\lambda_{max}(\\mathbf{M}^H\\mathbf{M})}\n", 394 | "\\end{align*}\n", 395 | "\n", 396 | "To calculate this, we use the property that $\\mathbf{M}$ is diagonalized by the Fourier transform:\n", 397 | "$$\\begin{align*}\n", 398 | "\\mathbf{M}^H\\mathbf{M} &= \\left(\\mathbf{F}^{-1} \\mathrm{diag}(\\mathbf{Fh}) \\ \\mathbf{F}\\right)^H \\ \\mathbf{F}^{-1} \\mathrm{diag}(\\mathbf{Fh}) \\ \\mathbf{F} \\\\\n", 399 | "&= \\mathbf{F}^{-1} \\mathrm{diag}(\\mathbf{Fh})^* \\ \\mathrm{diag}(\\mathbf{Fh}) \\ \\mathbf{F} \\\\\n", 400 | "\\lambda_{max}(\\mathbf{M}^H\\mathbf{M}) &= \\max \\left(\\mathrm{diag}(\\mathbf{Fh})^* \\ \\mathrm{diag}(\\mathbf{Fh}) \\right)\n", 401 | "\\end{align*}$$\n", 402 | "In code, we have\n", 403 | "\\begin{align*}\n", 404 | "\\alpha = \\frac{1.8}{\\texttt{max} \\left(\\texttt{Hadj} \\cdot \\texttt{H}\\right)}\n", 405 | "\\end{align*}\n", 406 | "\n", 407 | "\n", 408 | "Since we are dealing with images, one constraint on the reconstructed image $\\imagevec_k$ is that all the entries have to be non-negative. We enforce this with projected gradient descent. The projection function $\\texttt{proj}$ we use enforces non-negativity: it projects $\\texttt{vk}$ onto the non-negative orthant by clipping every negative entry to zero. 
" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 9, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "def grad_descent(h, b):\n", 418 | " H, Hadj, v, utils = initMatrices(h)\n", 419 | " crop = utils[0]\n", 420 | " pad = utils[1]\n", 421 | " \n", 422 | " alpha = np.real(1.8/(np.max(Hadj * H)))\n", 423 | " iterations = 0\n", 424 | " \n", 425 | " def non_neg(xi):\n", 426 | " xi = np.maximum(xi,0)\n", 427 | " return xi\n", 428 | " \n", 429 | " #proj = lambda x:x #Do no projection\n", 430 | " proj = non_neg #Enforce nonnegativity at every gradient step. Comment out as needed.\n", 431 | "\n", 432 | "\n", 433 | " parent_var = [H, Hadj, b, crop, pad, alpha, proj]\n", 434 | " \n", 435 | " vk = v\n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " #### uncomment for Nesterov momentum update #### \n", 440 | " #p = 0\n", 441 | " #mu = 0.9\n", 442 | " ################################################\n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " #### uncomment for FISTA update ################\n", 447 | " tk = 1\n", 448 | " xk = v\n", 449 | " ################################################\n", 450 | " \n", 451 | " for iterations in range(iters): \n", 452 | " \n", 453 | " # uncomment for regular GD update\n", 454 | " #vk = gd_update(vk, parent_var)\n", 455 | " \n", 456 | " # uncomment for Nesterov momentum update \n", 457 | " #vk, p = nesterov_update(vk, p, mu, parent_var)\n", 458 | " \n", 459 | " # uncomment for FISTA update\n", 460 | " vk, tk, xk = fista_update(vk, tk, xk, parent_var)\n", 461 | "\n", 462 | " if iterations % 10 == 0:\n", 463 | " image = proj(crop(vk))\n", 464 | " f = plt.figure(1)\n", 465 | " plt.imshow(image, cmap='gray')\n", 466 | " plt.title('Reconstruction after iteration {}'.format(iterations))\n", 467 | " display.display(f)\n", 468 | " display.clear_output(wait=True)\n", 469 | " \n", 470 | " \n", 471 | " return proj(crop(vk)) \n", 472 | " " 473 | ] 474 | }, 475 | { 476 | "cell_type": "markdown", 477 | 
"metadata": {}, 478 | "source": [ 479 | "#### Gradient descent algorithms" 480 | ] 481 | }, 482 | { 483 | "cell_type": "markdown", 484 | "metadata": {}, 485 | "source": [ 486 | "##### Regular Gradient Descent\n", 487 | "Regular gradient descent is simply following the negative of the gradient until we reach the minimum:\n", 488 | "\n", 489 | "\\begin{align*}\n", 490 | " & \\texttt{gradient_update}(\\imagevec_k): \\\\\n", 491 | " &\\quad \\quad\\imagevec'_{k+1} \\leftarrow \\imagevec_k - \\alpha_k \\full^H(\\full\\imagevec_k - \\measurementvec) \\\\\n", 492 | " &\\quad \\quad \\imagevec_{k+1} \\leftarrow \\operatorname{proj}_{\\imagevec \\geq 0} (\\imagevec_{k+1}')\n", 493 | "\\end{align*} " 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 10, 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "def gd_update(vk, parent_var):\n", 503 | " H, Hadj, b, crop, pad, alpha, proj = parent_var\n", 504 | " \n", 505 | " gradient = grad(Hadj, H, vk, b, crop, pad)\n", 506 | " vk -= alpha*gradient\n", 507 | " vk = proj(vk)\n", 508 | " \n", 509 | " return xk " 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "metadata": {}, 515 | "source": [ 516 | "##### Nesterov Momentum\n", 517 | "GD works but it's slow. One way to speed it up is to consider a velocity term, $\\mathbf{p}$. Each update becomes:\n", 518 | "\n", 519 | "\\begin{align*}\n", 520 | " &\\texttt{gradient_update}(\\imagevec_k): \\\\\n", 521 | " &\\qquad \\mathbf{p}_{k+1} \\leftarrow \\mu \\mathbf{p}_k - \\alpha_k \\texttt{grad}(\\imagevec_k) \\\\\n", 522 | " &\\qquad \\imagevec_{k+1}' \\leftarrow \\imagevec_k - \\mu \\mathbf{p}_{k} + (1+\\mu)\\mathbf{p}_{k+1} \\\\\n", 523 | " &\\qquad \\imagevec_{k+1} \\leftarrow \\operatorname{proj}_{\\imagevec \\geq 0} (\\imagevec_{k+1}')\n", 524 | "\\end{align*}\n", 525 | "\n", 526 | "The parameter $\\mu$ is called the momentum and is strictly between 0 and 1." 
527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": 11, 532 | "metadata": {}, 533 | "outputs": [], 534 | "source": [ 535 | "def nesterov_update(vk, p, mu, parent_var):\n", 536 | " H, Hadj, b, crop, pad, alpha, proj = parent_var\n", 537 | " \n", 538 | " p_prev = p\n", 539 | " gradient = grad(Hadj, H, vk, b, crop, pad)\n", 540 | " p = mu*p - alpha*gradient\n", 541 | " vk += -mu*p_prev + (1+mu)*p\n", 542 | " vk = proj(vk)\n", 543 | " \n", 544 | " return vk, p" 545 | ] 546 | }, 547 | { 548 | "cell_type": "markdown", 549 | "metadata": {}, 550 | "source": [ 551 | "##### FISTA \n", 552 | "Instead of Nesterov momentum, we can use FISTA, which speeds up the iterative process. Each iteration of the algorithm is as follows:\n", 553 | "\n", 554 | "\\begin{align*}\n", 555 | " &\\texttt{gradient_update}(\\imagevec_k):\\\\\n", 556 | " &\\qquad \\imagevec_k \\leftarrow \\imagevec_k - \\alpha_k \\texttt{grad}(\\imagevec_k) \\\\ \n", 557 | " &\\qquad x_{k} \\leftarrow \\texttt{proj}(\\imagevec_k) \\\\\n", 558 | " &\\qquad t_{k+1} \\leftarrow \\frac{1+\\sqrt{1+4t_k^2}}{2} \\\\\n", 559 | " &\\qquad \\imagevec_{k+1} \\leftarrow x_{k} + \\frac{t_k-1}{t_{k+1}} (x_{k}-x_{k-1}) \\\\\n", 560 | "\\end{align*}\n" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": 12, 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [ 569 | "def fista_update(vk, tk, xk, parent_var):\n", 570 | " H, Hadj, b, crop, pad, alpha, proj = parent_var\n", 571 | " \n", 572 | " x_k1 = xk\n", 573 | " gradient = grad(Hadj, H, vk, b, crop, pad)\n", 574 | " vk -= alpha*gradient\n", 575 | " xk = proj(vk)\n", 576 | " t_k1 = (1+np.sqrt(1+4*tk**2))/2\n", 577 | " vk = xk+(tk-1)/t_k1*(xk - x_k1)\n", 578 | " tk = t_k1\n", 579 | " \n", 580 | " return vk, tk, xk" 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": {}, 586 | "source": [ 587 | "#### Running the algorithm" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": 13, 593 
| "metadata": {}, 594 | "outputs": [ 595 | { 596 | "data": { 597 | "image/png": "\n", 598 | "text/plain": [ 599 | "
" 600 | ] 601 | }, 602 | "metadata": {}, 603 | "output_type": "display_data" 604 | }, 605 | { 606 | "data": { 607 | "image/png": "\n", 608 | "text/plain": [ 609 | "
" 610 | ] 611 | }, 612 | "metadata": {}, 613 | "output_type": "display_data" 614 | } 615 | ], 616 | "source": [ 617 | "psf, data = loaddata()\n", 618 | "final_im = grad_descent(psf, data)\n", 619 | "plt.imshow(final_im, cmap='gray')\n", 620 | "plt.title('Final reconstruction after {} iterations'.format(iters))\n", 621 | "display.display()\n" 622 | ] 623 | } 624 | ], 625 | "metadata": { 626 | "kernelspec": { 627 | "display_name": "Python [default]", 628 | "language": "python", 629 | "name": "python3" 630 | }, 631 | "language_info": { 632 | "codemirror_mode": { 633 | "name": "ipython", 634 | "version": 3 635 | }, 636 | "file_extension": ".py", 637 | "mimetype": "text/x-python", 638 | "name": "python", 639 | "nbconvert_exporter": "python", 640 | "pygments_lexer": "ipython3", 641 | "version": "3.6.5" 642 | } 643 | }, 644 | "nbformat": 4, 645 | "nbformat_minor": 2 646 | } 647 | -------------------------------------------------------------------------------- /tutorial/psf_sample.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/tutorial/psf_sample.tif -------------------------------------------------------------------------------- /tutorial/rawdata_hand_sample.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/tutorial/rawdata_hand_sample.tif --------------------------------------------------------------------------------