├── ADMM.py
├── GD.py
├── LICENSE
├── README.md
├── admm_config.yml
├── environment.yml
├── gd_config.yml
├── rpi
│   └── preview.py
├── test_images
│   ├── cal_logo_rgb.png
│   ├── dog_rgb.jpg
│   ├── google_chrome_logo_rgb.png
│   └── spiral_bw.gif
└── tutorial
    ├── ADMM.ipynb
    ├── GD.ipynb
    ├── psf_sample.tif
    └── rawdata_hand_sample.tif
/ADMM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.fft as fft
3 | from PIL import Image
4 | import matplotlib.pyplot as plt
5 | import yaml
6 |
def loadData(show_im=True):
    """Load, background-subtract, downsample and normalize the PSF and raw data.

    Reads the module-level config values ``psfname``, ``imgname`` (image
    paths) and ``f`` (downsampling factor of the form 1/2**k), which
    ``__main__`` injects from ``admm_config.yml``.

    Parameters
    ----------
    show_im : bool, optional
        If True, display the processed PSF and raw data.

    Returns
    -------
    psf, data : ndarray
        float32 arrays, each scaled to unit L2 norm.
    """
    # Context managers close the image files once the pixels have been
    # copied into NumPy arrays (np.array forces the pixel data to load).
    with Image.open(psfname) as psf_img:
        psf = np.array(psf_img, dtype='float32')
    with Image.open(imgname) as data_img:
        data = np.array(data_img, dtype='float32')

    # In the picamera, there is a non-trivial background (even in the dark)
    # that must be subtracted; it is estimated from a small patch of the PSF.
    bg = np.mean(psf[5:15, 5:15])
    psf -= bg
    data -= bg

    def resize(img, factor):
        """Downsample by ``factor`` (1/2**k) with k rounds of 2x2 box averaging.

        Downsampling is subject to aliasing; averaging each 2x2 neighbourhood
        (a box filter) is the simplest way to reduce it. Each round needs even
        height/width, otherwise the four strided slices do not broadcast.
        """
        num = int(-np.log2(factor))
        for _ in range(num):
            img = 0.25*(img[::2, ::2, ...] + img[1::2, ::2, ...]
                        + img[::2, 1::2, ...] + img[1::2, 1::2, ...])
        return img

    psf = resize(psf, f)
    data = resize(data, f)

    # Normalize both images to unit L2 norm (same total power). Technically not
    # necessary, but the optimal hyperparameters are a function of the total
    # power in the PSF (among other things), so it makes sense to standardize it.
    psf /= np.linalg.norm(psf.ravel())
    data /= np.linalg.norm(data.ravel())

    if show_im:
        plt.figure()
        plt.imshow(psf, cmap='gray')
        plt.title('PSF')
        plt.figure()
        plt.imshow(data, cmap='gray')
        plt.title('Raw data')
        plt.show()
    return psf, data
51 |
def U_update(eta, image_est, tau):
    """Update the TV splitting variable U.

    Soft-thresholds Psi(image_est) + eta/mu2 at level tau/mu2, where mu2 is
    the module-level TV penalty parameter.
    """
    target = Psi(image_est) + eta/mu2
    threshold = tau/mu2
    return SoftThresh(target, threshold)
54 |
55 |
def SoftThresh(x, tau):
    """Elementwise soft-thresholding: sign(x) * max(|x| - tau, 0).

    Values with magnitude at most tau become 0; all others shrink toward 0 by tau.
    """
    shrunk = np.maximum(0, np.abs(x) - tau)
    return np.sign(x) * shrunk
59 |
60 |
def Psi(v):
    """Circular first differences of a 2D image, stacked on a new last axis.

    Channel 0 holds v[i-1, j] - v[i, j] (along rows) and channel 1 holds
    v[i, j-1] - v[i, j] (along columns), with periodic wrap-around.
    """
    row_diff = np.roll(v, 1, axis=0) - v
    col_diff = np.roll(v, 1, axis=1) - v
    return np.stack((row_diff, col_diff), axis=2)
63 |
64 |
def X_update(xi, image_est, H_fft, sensor_reading, X_divmat):
    """Closed-form update of the convolution splitting variable X.

    Computes X_divmat * (xi + mu1*M(image_est) + C^T(sensor_reading)), where
    X_divmat is the precomputed elementwise 1/(C^T C + mu1).
    """
    numerator = xi + mu1*M(image_est, H_fft) + CT(sensor_reading)
    return X_divmat * numerator
67 |
68 |
def M(vk, H_fft):
    """Circularly convolve the centered image vk with the PSF whose FFT is H_fft.

    vk is moved to a top-left origin before the FFT and the result is moved
    back to a centered origin; only the real part is returned.
    """
    spectrum = fft.fft2(fft.ifftshift(vk)) * H_fft
    convolved = fft.ifft2(spectrum)
    return np.real(fft.fftshift(convolved))
71 |
72 |
def C(M):
    """Crop a full_size array to the centered sensor_size window.

    Arrays are indexed (row, column) rather than (x, y). The result is a
    slicing view into M, not a copy.
    """
    row_start = (full_size[0] - sensor_size[0])//2
    col_start = (full_size[1] - sensor_size[1])//2
    row_stop = (full_size[0] + sensor_size[0])//2
    col_stop = (full_size[1] + sensor_size[1])//2
    return M[row_start:row_stop, col_start:col_stop]
80 |
def CT(b):
    """Adjoint of the crop C: zero-pad a sensor-sized image back to full_size.

    The image is placed in exactly the window that C crops out, so
    C(CT(b)) == b and the output shape always equals full_size.

    Reads the module-level ``full_size`` and ``sensor_size``.
    """
    top = (full_size[0] - sensor_size[0])//2
    left = (full_size[1] - sensor_size[1])//2
    # Put any odd leftover pixel on the bottom/right. Splitting the padding
    # symmetrically as (diff//2, diff//2) would lose a pixel when diff is odd.
    bottom = full_size[0] - sensor_size[0] - top
    right = full_size[1] - sensor_size[1] - left
    return np.pad(b, ((top, bottom), (left, right)), 'constant', constant_values=(0, 0))
85 |
86 |
def precompute_X_divmat():
    """Return the elementwise divisor 1/(C^T C + mu1) used by X_update.

    C^T C is diagonal: ones inside the sensor window, zeros in the padding.
    Only call this function once! Store the result in a variable and reuse
    that variable during every update step.
    """
    sensor_mask = CT(np.ones(sensor_size))
    denominator = sensor_mask + mu1
    return 1./denominator
92 |
def W_update(rho, image_est):
    """Update the non-negativity splitting variable W: clamp image_est + rho/mu3 at 0."""
    shifted = image_est + rho/mu3
    return np.maximum(shifted, 0)
95 |
def r_calc(w, rho, u, eta, x, xi, H_fft):
    """Right-hand side of the linear system solved in V_update.

    Adds one term per splitting: (mu3*w - rho) for non-negativity,
    Psi^T(mu2*u - eta) for total variation, M^T(mu1*x - xi) for the convolution.
    """
    nonneg_term = mu3*w - rho
    tv_term = PsiT(mu2*u - eta)
    conv_term = MT(mu1*x - xi, H_fft)
    return nonneg_term + tv_term + conv_term
98 |
def V_update(w, rho, u, eta, x, xi, H_fft, R_divmat):
    """Solve for V in Fourier space: V = fftshift(ifft2(R_divmat * fft2(ifftshift(r)))).

    r is the right-hand side from r_calc; the real part of the result is returned.
    """
    rhs = r_calc(w, rho, u, eta, x, xi, H_fft)
    rhs_fft = fft.fft2(fft.ifftshift(rhs))
    solution = fft.ifft2(R_divmat*rhs_fft)
    return np.real(fft.fftshift(solution))
102 |
def PsiT(U):
    """Adjoint of Psi: fold a stacked (rows, cols, 2) difference array back into an image.

    Uses circular shifts in the opposite direction to Psi's.
    """
    rows_part = U[..., 0]
    cols_part = U[..., 1]
    from_rows = np.roll(rows_part, -1, axis=0) - rows_part
    from_cols = np.roll(cols_part, -1, axis=1) - cols_part
    return from_rows + from_cols
107 |
def MT(x, H_fft):
    """Adjoint of M: circular correlation with the PSF (multiply by conj(H_fft) in Fourier space)."""
    spectrum = fft.fft2(fft.ifftshift(x)) * np.conj(H_fft)
    return np.real(fft.fftshift(fft.ifft2(spectrum)))
111 |
def precompute_PsiTPsi():
    """Return fft2 of the periodic 5-point stencil that implements Psi^T Psi at full_size.

    The stencil has 4 at the origin and -1 at its four circular neighbours.
    """
    stencil = np.zeros(full_size)
    stencil[0, 0] = 4
    for offset in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        stencil[offset] = -1
    return fft.fft2(stencil)
118 |
119 |
def precompute_R_divmat(H_fft, PsiTPsi):
    """Return the Fourier-space divisor 1/(mu1*|H|^2 + mu2*|PsiTPsi| + mu3) for V_update.

    Only call this function once! Store the result in a variable and reuse
    that variable during every update step. The result is a mask in frequency
    space, so apply it only to arrays that have already been transformed with fft2.
    """
    conv_part = mu1*(np.abs(np.conj(H_fft)*H_fft))
    tv_part = mu2*np.abs(PsiTPsi)
    return 1./(conv_part + tv_part + mu3)
130 |
def xi_update(xi, V, H_fft, X):
    """Dual ascent for the convolution constraint M(V) = X."""
    residual = M(V, H_fft) - X
    return xi + mu1*residual
133 |
def eta_update(eta, V, U):
    """Dual ascent for the total-variation constraint Psi(V) = U."""
    residual = Psi(V) - U
    return eta + mu2*residual
136 |
def rho_update(rho, V, W):
    """Dual ascent for the non-negativity constraint V = W."""
    residual = V - W
    return rho + mu3*residual
139 |
140 |
def init_Matrices(H_fft):
    """Return zero-initialized ADMM variables in the order X, U, V, W, xi, eta, rho.

    Primal arrays are full_size. U has a trailing axis of length 2, one slot
    per difference direction. Each dual variable is shaped like the quantity
    its constraint compares: xi like M(V), eta like Psi(V), rho like W.
    """
    V = np.zeros(full_size)
    X = np.zeros(full_size)
    W = np.zeros(full_size)
    U = np.zeros((full_size[0], full_size[1], 2))

    xi = np.zeros_like(M(V, H_fft))
    eta = np.zeros_like(Psi(V))
    rho = np.zeros_like(W)
    return X, U, V, W, xi, eta, rho
151 |
152 |
def precompute_H_fft(psf):
    """Return fft2 of the PSF after zero-padding to full_size and moving its origin to the top-left."""
    padded_psf = CT(psf)
    return fft.fft2(fft.ifftshift(padded_psf))
155 |
def ADMM_Step(X, U, V, W, xi, eta, rho, precomputed):
    """Run one ADMM sweep and return the updated (X, U, V, W, xi, eta, rho).

    ``precomputed`` is the sequence [H_fft, data, X_divmat, R_divmat].
    The primal variables are updated in the order U, X, V, W, each using
    the freshest values of the others. The duals xi, eta, rho are then
    updated from the new primals. The order matters.
    """
    H_fft, data, X_divmat, R_divmat = precomputed

    # Primal updates.
    U = U_update(eta, V, tau)
    X = X_update(xi, V, H_fft, data, X_divmat)
    V = V_update(W, rho, U, eta, X, xi, H_fft, R_divmat)
    W = W_update(rho, V)

    # Dual updates, using the residual of each splitting constraint.
    xi = xi_update(xi, V, H_fft, X)
    eta = eta_update(eta, V, U)
    rho = rho_update(rho, V, W)

    return X, U, V, W, xi, eta, rho
167 |
168 |
def runADMM(psf, data):
    """Reconstruct an image from ``data`` with ADMM and return the final estimate.

    Runs ``iters`` iterations (module-level config) and displays the current
    estimate every ``disp_pic`` iterations.

    Parameters
    ----------
    psf : ndarray
        Sensor-sized point spread function.
    data : ndarray
        Sensor-sized raw measurement.

    Returns
    -------
    ndarray
        The sensor-sized crop of the final V, with negative values set to 0.
    """
    H_fft = precompute_H_fft(psf)
    X, U, V, W, xi, eta, rho = init_Matrices(H_fft)
    X_divmat = precompute_X_divmat()
    PsiTPsi = precompute_PsiTPsi()
    R_divmat = precompute_R_divmat(H_fft, PsiTPsi)
    precomputed = [H_fft, data, X_divmat, R_divmat]

    for i in range(iters):
        X, U, V, W, xi, eta, rho = ADMM_Step(X, U, V, W, xi, eta, rho, precomputed)
        if i % disp_pic == 0:
            print(i)
            # C(V) is a slicing view into V. Clamp a copy so that displaying
            # the estimate never modifies the ADMM iterate itself.
            image = np.maximum(C(V), 0)
            plt.figure(1)
            plt.imshow(image, cmap='gray')
            plt.title('Reconstruction after iteration {}'.format(i))
            plt.show()

    # Return the estimate after the last iteration, not merely the last one
    # that happened to be displayed. This also works when iters == 0.
    return np.maximum(C(V), 0)
187 |
188 |
189 |
if __name__ == "__main__":
    ### Reading in params from config file (don't mess with parameter names!)
    # yaml.safe_load: a bare yaml.load(stream) is unsafe and raises TypeError on
    # PyYAML >= 6, where the Loader argument became required. The context
    # manager also closes the config file.
    with open("admm_config.yml") as config_file:
        params = yaml.safe_load(config_file)
    # The functions above read these settings (psfname, imgname, f, disp_pic,
    # mu1, mu2, mu3, tau, iters) as module-level globals.
    globals().update(params)

    ### Loading images and initializing the required arrays
    psf, data = loadData(True)
    sensor_size = np.array(psf.shape)
    full_size = 2*sensor_size

    ### Running the algorithm
    final_im = runADMM(psf, data)
    plt.imshow(final_im, cmap='gray')
    plt.title('Final reconstructed image after {} iterations'.format(iters))
    plt.show()
    saveim = input('Save final image? (y/n) ')
    if saveim == 'y':
        filename = input('Name of file: ')
        plt.imshow(final_im, cmap='gray')
        plt.axis('off')
        plt.savefig(filename+'.png', bbox_inches='tight')
212 |
213 |
--------------------------------------------------------------------------------
/GD.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import numpy.fft as fft
4 | import matplotlib.pyplot as plt
5 | from PIL import Image
6 | import yaml
7 |
8 |
def loaddata(show_im=True):
    """Load, background-subtract, downsample and normalize the PSF and raw data.

    Reads the module-level config values ``psfname``, ``imgname`` (image
    paths) and ``f`` (downsampling factor of the form 1/2**k), which
    ``__main__`` injects from ``gd_config.yml``.

    Parameters
    ----------
    show_im : bool, optional
        If True, display the processed PSF and then the raw data, each in
        its own blocking plt.show() call.

    Returns
    -------
    psf, data : ndarray
        float32 arrays, each scaled to unit L2 norm.
    """
    # Context managers close the image files once the pixels have been
    # copied into NumPy arrays (np.array forces the pixel data to load).
    with Image.open(psfname) as psf_img:
        psf = np.array(psf_img, dtype='float32')
    with Image.open(imgname) as data_img:
        data = np.array(data_img, dtype='float32')

    # In the picamera, there is a non-trivial background (even in the dark)
    # that must be subtracted; it is estimated from a small patch of the PSF.
    bg = np.mean(psf[5:15, 5:15])
    psf -= bg
    data -= bg

    def resize(img, factor):
        """Downsample by ``factor`` (1/2**k) with k rounds of 2x2 box averaging.

        Downsampling is subject to aliasing; averaging each 2x2 neighbourhood
        (a box filter) is the simplest way to reduce it. Each round needs even
        height/width, otherwise the four strided slices do not broadcast.
        """
        num = int(-np.log2(factor))
        for _ in range(num):
            img = 0.25*(img[::2, ::2, ...] + img[1::2, ::2, ...]
                        + img[::2, 1::2, ...] + img[1::2, 1::2, ...])
        return img

    psf = resize(psf, f)
    data = resize(data, f)

    # Normalize both images to unit L2 norm so they have the same total power.
    # Not strictly necessary, but it puts every dataset on a standard scale.
    psf /= np.linalg.norm(psf.ravel())
    data /= np.linalg.norm(data.ravel())

    if show_im:
        plt.figure()
        plt.imshow(psf, cmap='gray')
        plt.title('PSF')
        plt.show()
        plt.figure()
        plt.imshow(data, cmap='gray')
        plt.title('Raw data')
        plt.show()
    return psf, data
51 |
def initMatrices(h):
    """Prepare the FFT-domain PSF, pad/crop helpers and the initial estimate for GD.

    The PSF h is zero-padded to nextPow2(2*n - 1) along each axis so that the
    circular convolution, once cropped, matches the linear one.

    Returns
    -------
    H, Hadj, v, utils
        H is the orthonormal fft2 of the padded PSF and Hadj is conj(H).
        v is the padded initial estimate: constant (max(h) + min(h))/2 inside
        the window, zero outside. utils is [crop, pad].
    """
    start_value = (np.max(h) + np.min(h))/2
    x = np.ones(h.shape)*start_value

    init_shape = h.shape
    padded_shape = [nextPow2(2*n - 1) for n in init_shape]
    # NOTE(review): the row and column offsets use different formulas; they
    # differ by one pixel when the corresponding dimension of h is odd.
    row_lo = (padded_shape[0] - init_shape[0])//2
    col_lo = (padded_shape[1]//2) - (init_shape[1]//2)
    window = (slice(row_lo, row_lo + init_shape[0]),
              slice(col_lo, col_lo + init_shape[1]))

    hpad = np.zeros(padded_shape)
    hpad[window] = h

    H = fft.fft2(hpad, norm="ortho")
    Hadj = np.conj(H)

    def crop(X):
        # Extract the sensor-sized window from a padded array.
        return X[window]

    def pad(v):
        # Embed a sensor-sized array in a zero complex64 array of padded_shape.
        vpad = np.zeros(padded_shape, dtype=np.complex64)
        vpad[window] = v
        return vpad

    v = np.real(pad(x))
    return H, Hadj, v, [crop, pad]
80 |
def nextPow2(n):
    """Return the smallest power of two that is >= n, as an int (for n >= 1)."""
    exponent = np.ceil(np.log2(n))
    return int(2**exponent)
83 |
def grad(Hadj, H, vk, b, crop, pad):
    """Gradient of 0.5*||A v - b||^2 at vk, computed as real(A^H (A vk - b))."""
    residual = calcA(H, vk, crop) - b
    return np.real(calcAHerm(Hadj, residual, pad))
88 |
def calcA(H, vk, crop):
    """Apply the forward model: crop(ifftshift(ifft2(H * fft2(vk)))) using orthonormal FFTs."""
    spectrum = H*fft.fft2(vk, norm="ortho")
    blurred = fft.ifftshift(fft.ifft2(spectrum, norm="ortho"))
    return crop(blurred)
92 |
def calcAHerm(Hadj, diff, pad):
    """Apply the adjoint model: ifftshift(ifft2(Hadj * fft2(pad(diff)))); the result is complex."""
    spectrum = Hadj*fft.fft2(pad(diff), norm="ortho")
    return fft.ifftshift(fft.ifft2(spectrum, norm="ortho"))
97 |
98 |
def grad_descent(h, b):
    """Reconstruct an image from measurement b and PSF h by projected gradient methods.

    Runs ``iters`` iterations (module-level config) of the update rule
    selected in the loop (FISTA is active). It shows the projected, cropped
    estimate every ``disp_pic`` iterations and returns the final one.
    """
    H, Hadj, v, (crop, pad) = initMatrices(h)

    # Step size taken from the largest entry of Hadj*H.
    alpha = np.real(2/(np.max(Hadj * H)))

    def non_neg(xi):
        return np.maximum(xi, 0)

    # Projection applied at every step; use `lambda x: x` for no projection.
    proj = non_neg

    parent_var = [H, Hadj, b, crop, pad, alpha, proj]
    vk = v

    # Choosing the update rule (swap the active call in the loop):
    #   regular GD:  vk = gd_update(vk, parent_var)
    #   Nesterov:    initialize p = 0, mu = 0.9 before the loop, then
    #                vk, p = nesterov_update(vk, p, mu, parent_var)
    #   FISTA:       initialize tk = 1, xk = v before the loop (done below), then
    #                vk, tk, xk = fista_update(vk, tk, xk, parent_var)
    tk = 1
    xk = v

    for iterations in range(iters):
        vk, tk, xk = fista_update(vk, tk, xk, parent_var)

        if iterations % disp_pic == 0:
            print(iterations)
            image = proj(crop(vk))
            plt.figure(1)
            plt.imshow(image, cmap='gray')
            plt.title('Reconstruction after iteration {}'.format(iterations))
            plt.show()

    return proj(crop(vk))
154 |
def gd_update(vk, parent_var):
    """Take one projected gradient-descent step: vk <- proj(vk - alpha * grad(vk)).

    Parameters
    ----------
    vk : ndarray
        Current padded estimate. It is modified in place before projection.
    parent_var : list
        [H, Hadj, b, crop, pad, alpha, proj], as assembled in grad_descent.

    Returns
    -------
    ndarray
        The new projected estimate.
    """
    H, Hadj, b, crop, pad, alpha, proj = parent_var

    gradient = grad(Hadj, H, vk, b, crop, pad)
    vk -= alpha*gradient
    vk = proj(vk)

    # Return the updated estimate; `xk` does not exist in this function.
    return vk
163 |
def nesterov_update(vk, p, mu, parent_var):
    """Take one projected Nesterov-momentum step.

    p is the momentum (velocity) term and mu its coefficient. parent_var is
    [H, Hadj, b, crop, pad, alpha, proj]. vk is modified in place before the
    projection. Returns (new vk, new p).
    """
    H, Hadj, b, crop, pad, alpha, proj = parent_var

    velocity_prev = p
    velocity = mu*p - alpha*grad(Hadj, H, vk, b, crop, pad)
    vk += -mu*velocity_prev + (1+mu)*velocity
    return proj(vk), velocity
174 |
def fista_update(vk, tk, xk, parent_var):
    """Take one FISTA step.

    vk is the extrapolated point where the gradient is evaluated (modified in
    place), xk the previous projected iterate and tk the momentum sequence
    value. parent_var is [H, Hadj, b, crop, pad, alpha, proj].
    Returns (new vk, new tk, new xk).
    """
    H, Hadj, b, crop, pad, alpha, proj = parent_var

    x_prev = xk
    vk -= alpha*grad(Hadj, H, vk, b, crop, pad)
    x_new = proj(vk)
    t_next = (1 + np.sqrt(1 + 4*tk**2))/2
    # Extrapolate along the latest step with weight (t_k - 1)/t_{k+1}.
    v_next = x_new + (tk - 1)/t_next*(x_new - x_prev)
    return v_next, t_next, x_new
187 |
188 |
if __name__ == "__main__":
    ### Reading in params from config file (don't mess with parameter names!)
    # yaml.safe_load: a bare yaml.load(stream) is unsafe and raises TypeError on
    # PyYAML >= 6, where the Loader argument became required. The context
    # manager also closes the config file.
    with open("gd_config.yml") as config_file:
        params = yaml.safe_load(config_file)
    # The functions above read these settings (psfname, imgname, f, iters,
    # disp_pic) as module-level globals.
    globals().update(params)

    psf, data = loaddata()
    final_im = grad_descent(psf, data)
    print(iters)
    plt.imshow(final_im, cmap='gray')
    plt.title('Final reconstruction after {} iterations'.format(iters))
    plt.show()
    saveim = input('Save final image? (y/n) ')
    if saveim == 'y':
        filename = input('Name of file: ')
        plt.imshow(final_im, cmap='gray')
        plt.axis('off')
        plt.savefig(filename+'.png', bbox_inches='tight')
207 |
208 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2018, Waller Lab
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffuserCam-Tutorial
2 | #### See our [full tutorial](https://waller-lab.github.io/DiffuserCam/tutorial) for complete guides on setting up the DiffuserCam hardware and installing and running the software.
3 | Below is an overview of the organization of this repo.
4 |
5 |
6 | #### Home Directory
7 | The base directory contains Python code for processing DiffuserCam raw data with two algorithms, gradient descent (`GD.py`) and alternating direction method of multipliers (`ADMM.py`). The corresponding `.yml` files (`gd_config.yml`, `admm_config.yml`) should be modified to include the file paths of the PSF and of the raw data that is to be processed.
8 |
9 | #### Rpi Folder
10 | This folder contains python code for previewing and capturing raw images using a Raspberry Pi camera.
11 |
12 | #### Tutorial Folder
13 | This folder contains Jupyter (IPython) notebooks that walk the user step-by-step through the two algorithms, gradient descent (`GD.ipynb`) and alternating direction method of multipliers (`ADMM.ipynb`). Sample test data is included.
14 |
15 | #### Test_Images Folder
16 | This folder contains sample images that you can place on a phone or laptop screen for testing your Raspberry Pi DiffuserCam. We recommend you start with `spiral_bw.gif`.
17 |
18 |
19 |
--------------------------------------------------------------------------------
/admm_config.yml:
--------------------------------------------------------------------------------
1 | psfname: "./images/psf_box_exp8.tif" #path to psf image
2 | imgname: "./images/baffle_hand.tif" #path to raw data image file
3 | f: 0.25 #Downsampling factor (must be decimal, must be of form 1/2^k where k is positive integer)
4 | disp_pic: 4 #Number of iterations after which we display intermediate reconstruction
5 | mu1: 1.0e-6 #Decimal point is REQUIRED if using scientific notation
6 | mu2: 1.0e-5
7 | mu3: 4.0e-5
8 | tau: 0.0001
9 | iters: 1
10 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: diffuser_cam
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - python=3
6 | - numpy
7 | - matplotlib
8 | - pillow
9 | - pip
10 | - pyyaml
11 |
--------------------------------------------------------------------------------
/gd_config.yml:
--------------------------------------------------------------------------------
1 | psfname: "./images/psf_box_exp8.tif" #Path to PSF image
2 | imgname: "./images/baffle_hand.tif" #Path to raw data image
3 | f: 0.125 #Downsampling factor (must be decimal, must be 1/2^k where k is positive integer)
4 | iters: 100 #Number of iterations
5 | disp_pic: 20 #Number of iterations after which we display intermediate reconstruction
--------------------------------------------------------------------------------
/rpi/preview.py:
--------------------------------------------------------------------------------
1 | import picamera
2 | import picamera.array
3 | import numpy as np
4 | from PIL import Image
5 |
if __name__ == '__main__':
    # The context manager closes the camera (releasing the hardware) even if
    # a prompt below is interrupted.
    with picamera.PiCamera() as camera:
        camera.resolution = camera.MAX_RESOLUTION
        camera.start_preview(resolution=(410, 313), fullscreen=False, window=(20, 20, 820, 616))
        camera.exposure_mode = 'auto'

        customize = input('Change shutter speed? (y/[n])')
        if customize == 'y':
            speed = int(input('shutter speed (mus) : '))
            camera.shutter_speed = speed
        input('Press enter to take picture ')
        stream = picamera.array.PiBayerArray(camera)
        camera.capture(stream, 'jpeg', bayer=True)
        filename = input('Name of file: ')

        # Collapse the per-channel Bayer planes into one raw image.
        # NOTE(review): per the picamera docs, PiBayerArray holds 10-bit samples
        # (0..1023). A bare .astype(np.uint8) wraps values above 255 modulo 256,
        # so scale by >> 2 (10 -> 8 bits) and clip before converting.
        raw = np.sum(stream.array, axis=2)
        arr = np.clip(raw >> 2, 0, 255).astype(np.uint8)
        img = Image.fromarray(arr)
        img.save(filename)
24 |
--------------------------------------------------------------------------------
/test_images/cal_logo_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/cal_logo_rgb.png
--------------------------------------------------------------------------------
/test_images/dog_rgb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/dog_rgb.jpg
--------------------------------------------------------------------------------
/test_images/google_chrome_logo_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/google_chrome_logo_rgb.png
--------------------------------------------------------------------------------
/test_images/spiral_bw.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/test_images/spiral_bw.gif
--------------------------------------------------------------------------------
/tutorial/GD.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### Gradient Descent and FISTA"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "$\\newcommand\\measurementvec{\\mathbf{b}}\n",
15 | "\\newcommand\\measurementmtx{\\mathbf{A}}\n",
16 | "\\newcommand\\imagevec{\\mathbf{v}}\n",
17 | "\\newcommand\\psf{\\mathbf{h}}\n",
18 | "\\newcommand{\\crop}{\\mathbf{C}}\n",
19 | "\\newcommand\\full{\\mathbf{A}}\n",
20 | "\\newcommand{\\ftpsf}{\\mathbf{H}}$\n",
21 | "Gradient descent is an iterative algorithm that finds the minimum of a convex function by following the slope \"downhill\" until it reaches a minimum. To solve the minimization problem\n",
22 | "\\begin{equation*}\n",
23 | " \\operatorname{minimize} g(\\mathbf{x}),\n",
24 | "\\end{equation*}\n",
25 | "we find the gradient of $g$ wrt $\\mathbf{x}$, $\\nabla_\\mathbf{x} g$, and use the property that the gradient always points in the direction of steepest _ascent_. In order to minimize $g$, we go the other direction:\n",
26 | "$$\\begin{align*}\n",
27 | " \\mathbf{x}_0 &= \\text{ initial guess} \\\\\n",
28 | " \\mathbf{x}_{k+1} &\\leftarrow \\mathbf{x}_k - \\alpha_k \\nabla g(\\mathbf{x}_k),\n",
29 | "\\end{align*}$$\n",
30 | "where $\\alpha$ is a step size that determines how far in the descent direction we go at each iteration.\n",
31 | "\n",
32 | "Applied to our problem:\n",
33 | "$$\\begin{align*}\n",
34 | " g(\\imagevec) &= \\frac{1}{2} \\|\\full\\imagevec- \\measurementvec \\|_2^2 \\\\\n",
35 | " \\nabla_\\imagevec g(\\imagevec) &= \\full^H (\\full\\imagevec-\\measurementvec),\n",
36 | "\\end{align*}$$\n",
37 | "where $\\full^H$ is the adjoint of $\\full$, $\\measurementvec$ is the sensor measurement and $\\imagevec$ is the image of the scene.\n",
38 | "\n",
39 | "We use more efficient variants of this algorithm, like Nesterov Momentum and FISTA, both of which are shown below. \n",
40 | "\n"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "#### Loading and preparing our images"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 1,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "import numpy as np\n",
57 | "import numpy.fft as fft\n",
58 | "import matplotlib.pyplot as plt\n",
59 | "from IPython import display\n",
60 | "from PIL import Image\n",
61 | "\n",
62 | "%matplotlib inline"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "The code takes in two grayscale images: a point spread function (PSF) $\\texttt{psfname}$ and a sensor measurement $\\texttt{imgname}$. The images can be downsampled by a factor $f$, which must be of the form $1/2^k$ for some non-negative integer $k$ (typically between 1/2 and 1/8). "
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 2,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "psfname = \"./psf_sample.tif\"\n",
79 | "imgname = \"./rawdata_hand_sample.tif\"\n",
80 | "\n",
81 | "# Downsampling factor (used to shrink images)\n",
82 | "f = 1/8 \n",
83 | "\n",
84 | "# Number of iterations\n",
85 | "iters = 100"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 3,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "def loaddata(show_im=True):\n",
95 | " psf = Image.open(psfname)\n",
96 | " psf = np.array(psf, dtype='float32')\n",
97 | " data = Image.open(imgname)\n",
98 | " data = np.array(data, dtype='float32')\n",
99 | " \n",
100 | " \"\"\"In the picamera, there is a non-trivial background \n",
101 | " (even in the dark) that must be subtracted\"\"\"\n",
102 | " bg = np.mean(psf[5:15,5:15]) \n",
103 | " psf -= bg\n",
104 | " data -= bg\n",
105 | " \n",
106 | " \"\"\"Resize to a more manageable size to do reconstruction on. \n",
107 | " Because resizing is downsampling, it is subject to aliasing \n",
108 | " (artifacts produced by the periodic nature of sampling). Demosaicing is an attempt\n",
109 | " to account for/reduce the aliasing caused. In this application, we do the simplest\n",
110 | " possible demosaicing algorithm: smoothing/blurring the image with a box filter\"\"\"\n",
111 | " \n",
112 | " def resize(img, factor):\n",
113 | " num = int(-np.log2(factor))\n",
114 | " for i in range(num):\n",
115 | " img = 0.25*(img[::2,::2,...]+img[1::2,::2,...]+img[::2,1::2,...]+img[1::2,1::2,...])\n",
116 | " return img \n",
117 | " \n",
118 | " psf = resize(psf, f)\n",
119 | " data = resize(data, f)\n",
120 | " \n",
121 | " \n",
122 | " \"\"\"Now we normalize the images so they have the same total power. Technically not a\n",
123 | " necessary step, but the optimal hyperparameters are a function of the total power in \n",
124 | " the PSF (among other things), so it makes sense to standardize it\"\"\"\n",
125 | " \n",
126 | " psf /= np.linalg.norm(psf.ravel())\n",
127 | " data /= np.linalg.norm(data.ravel())\n",
128 | " \n",
129 | " if show_im:\n",
130 | " fig1 = plt.figure()\n",
131 | " plt.imshow(psf, cmap='gray')\n",
132 | " plt.title('PSF')\n",
133 | " display.display(fig1)\n",
134 | " fig2 = plt.figure()\n",
135 | " plt.imshow(data, cmap='gray')\n",
136 | " plt.title('Raw data')\n",
137 | " display.display(fig2)\n",
138 | " return psf, data\n",
139 | " "
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {},
145 | "source": [
146 | "### Calculating convolutions using $\\texttt{fft}$\n",
147 | "We want to calculate convolutions efficiently. To do this, we use the \"fast fourier transform\" $\\texttt{fft2}$ which computes the Discrete Fourier Transform (DFT). The convolution theorem for DFTs only holds for circular convolutions. We can still recover a linear convolution by first padding the input images then cropping the output of the inverse DFT:\n",
148 | "\\begin{equation}\n",
149 | "h*x=\\mathcal{F}^{-1}[\\mathcal{F}[h]\\cdot\\mathcal{F}[x]] = \\texttt{crop}\\left[\\ \\texttt{DFT}^{-1}\\left\\{\\ \\texttt{DFT} [\\ \\texttt{pad}[h]\\ ]\\cdot\\texttt{DFT}[\\ \\texttt{pad}[x]\\ ]\\ \\right\\} \\ \\right]\n",
150 | "\\end{equation}\n",
151 | "\n",
152 | "Recovering the linear convolution correctly requires that we double the dimensions of our images. To take full advantage of the speed of the $\\texttt{fft2}$ algorithm, we actually pad to $\\texttt{full_size}$, the smallest power of two that is at least that size.\n",
153 | "\n",
154 | "We have chosen $\\texttt{full_size}$ in such a way that it provides enough padding to make circular and linear convolutions look the same after being cropped back down to $\\texttt{sensor_size}$. That way, the \"sensor crop\" due to the sensor's finite size and the \"fft crop\" above are the same, and we just need one crop function."
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {},
160 | "source": [
161 | "Along with initialization, we compute $\\texttt{H} = \\texttt{fft2}(\\texttt{hpad})$ and $\\texttt{Hadj} = \\texttt{H}^*$, which are constant matrices that will be needed to calculate the action of $\\measurementmtx$ and $\\measurementmtx^H$ at every iteration. "
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | "Lastly, we must take into account one more practical difference. In imaging, we often treat the center of the image as the origin of the coordinate system. This is theoretically convenient, but fft algorithms assume the origin of the image is the top left pixel. The magnitude of the fft doesn't change because of this distinction, but the phase does, since it is sensitive to shifts in real space. An example with the simplest function, a delta function, is displayed below. In order to correct this problem, we use $\\texttt{ifftshift}$ to move the origin of an image to the top left corner and $\\texttt{fftshift}$ to move the origin from the top left corner to the center. "
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 4,
174 | "metadata": {},
175 | "outputs": [
176 | {
177 | "data": {
178 | "image/png": "\n",
179 | "text/plain": [
180 | ""
181 | ]
182 | },
183 | "metadata": {},
184 | "output_type": "display_data"
185 | },
186 | {
187 | "data": {
188 | "image/png": "\n",
189 | "text/plain": [
190 | ""
191 | ]
192 | },
193 | "metadata": {},
194 | "output_type": "display_data"
195 | }
196 | ],
197 | "source": [
198 | "def no_shift():\n",
199 | " delta = np.zeros((5,5))\n",
200 | " delta[2][2] = 1\n",
201 | " fft_mag = np.abs(fft.fft2(delta))\n",
202 | " fft_arg = np.angle(fft.fft2(delta))\n",
203 | " \n",
204 | " fig, ax = plt.subplots(nrows=1, ncols=3)\n",
205 | " fig.tight_layout()\n",
206 | " ax[0].imshow(delta, cmap='gray')\n",
207 | " ax[0].set_title('Delta function in \\n real space')\n",
208 | "\n",
209 | " ax[1].imshow(fft_mag,vmin=-3,vmax=3,cmap='gray')\n",
210 | " ax[1].set_title('Magnitude of FT of \\n a delta function')\n",
211 | " \n",
212 | " ax[2].imshow(fft_arg,vmin=-3,vmax=3,cmap='gray')\n",
213 | " ax[2].set_title('Phase of FT of \\n delta function')\n",
214 | " \n",
215 | "no_shift() \n",
216 | "\n",
217 | "def shift():\n",
218 | " delta = np.zeros((5,5))\n",
219 | " delta[2][2] = 1\n",
220 | " delta_shifted = fft.ifftshift(delta)\n",
221 | " fft_mag = np.abs(fft.fft2(delta_shifted))\n",
222 | " fft_arg = np.angle(fft.fft2(delta_shifted))\n",
223 | " \n",
224 | " fig2, ax2 = plt.subplots(nrows=1, ncols=3)\n",
225 | " fig2.tight_layout()\n",
226 | " ax2[0].imshow(delta_shifted, cmap='gray')\n",
227 | " ax2[0].set_title('Delta function shifted in \\n real space')\n",
228 | "\n",
229 | " ax2[1].imshow(fft_mag,vmin=-3,vmax=3,cmap='gray')\n",
230 | " ax2[1].set_title('Magnitude of FT of a \\n shifted delta function')\n",
231 | " \n",
232 | " ax2[2].imshow(fft_arg,vmin=-3,vmax=3,cmap='gray')\n",
233 | " ax2[2].set_title('Phase of FT of a \\n shifted delta function')\n",
234 | " \n",
235 | "shift()"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "metadata": {},
241 | "source": [
242 | "For this notebook and the ADMM notebook, we follow the following convention so we don't have to worry about this issue again:\n",
243 | "1. All images in _real_ space are stored with the origin in the center (so they can be displayed correctly)\n",
244 | "2. All images in _Fourier_ space are stored with the origin in the top left corner (so they can be used for processing correctly)\n",
245 | "3. The above rules mean that, to perform a convolution between two real-space images $h$ and $x$, we do $$\\texttt{fftshift}( \\texttt{ifft} [\\texttt{fft}[ \\texttt{ifftshift}(h) ] \\cdot \\texttt{fft}[ \\texttt{ifftshift}(x) ] ] )$$ instead of $$\\texttt{ifft}[\\texttt{fft}[h] \\cdot \\texttt{fft}[x]]$$\n",
246 | "The rules imply that if we store the fourier transform of $h$ for future use, instead of storing $\\texttt{fft}[h]$, we store $\\texttt{fft}[\\texttt{ifftshift}(h)]$."
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 5,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "def initMatrices(h):\n",
256 | " pixel_start = (np.max(h) + np.min(h))/2\n",
257 | " x = np.ones(h.shape)*pixel_start\n",
258 | " # Pad each axis to the next power of two >= 2n-1 so the FFT (circular) convolution of two n-wide windows does not wrap around.\n",
259 | " init_shape = h.shape\n",
260 | " padded_shape = [nextPow2(2*n - 1) for n in init_shape]\n",
261 | " starti = (padded_shape[0]- init_shape[0])//2\n",
262 | " endi = starti + init_shape[0]\n",
263 | " startj = (padded_shape[1]//2) - (init_shape[1]//2)\n",
264 | " endj = startj + init_shape[1]\n",
265 | " hpad = np.zeros(padded_shape)\n",
266 | " hpad[starti:endi, startj:endj] = h\n",
267 | " # ifftshift moves the centred origin of hpad to the top-left corner before taking the FFT.\n",
268 | " H = fft.fft2(fft.ifftshift(hpad), norm=\"ortho\")\n",
269 | " Hadj = np.conj(H)\n",
270 | " # crop: padded array -> window of h.shape; pad: h.shape array -> zero-padded complex64 array.\n",
271 | " def crop(X):\n",
272 | " return X[starti:endi, startj:endj]\n",
273 | "\n",
274 | " def pad(v):\n",
275 | " vpad = np.zeros(padded_shape, dtype=np.complex64)\n",
276 | " vpad[starti:endi, startj:endj] = v\n",
277 | " return vpad\n",
278 | "\n",
279 | " utils = [crop, pad]\n",
280 | " v = np.real(pad(x))\n",
281 | " # Returns H, its conjugate Hadj, the padded initial estimate v (pixel_start inside the window, 0 outside), and [crop, pad].\n",
282 | " return H, Hadj, v, utils\n",
283 | "\n",
284 | "def nextPow2(n):\n",
285 | " return int(2**np.ceil(np.log2(n)))"
286 | ]
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": [
292 | "#### Computing the gradient"
293 | ]
294 | },
295 | {
296 | "cell_type": "markdown",
297 | "metadata": {},
298 | "source": [
299 | "The most important step in Gradient Descent is calculating the gradient\n",
300 | "$$ \\nabla_\\imagevec \\ g(\\imagevec) = \\full^H (\\full\\imagevec-\\measurementvec)$$\n",
301 | "We do this in 2 steps:\n",
302 | "1. We compute the action of $\\full$ on $\\imagevec$, using $\\texttt{calcA}$\n",
303 | "2. We compute the action of $\\full^H$ on $\\texttt{diff} = \\texttt{Av-b}$ using $\\texttt{calcAHerm}$\n",
304 | "\n",
305 | "Here, $\\texttt{vk}$ is the current padded estimate of the scene and $\\texttt{b}$ is the sensor measurement.\n"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": 6,
311 | "metadata": {},
312 | "outputs": [],
313 | "source": [
314 | "def grad(Hadj, H, vk, b, crop, pad):\n",
315 | " residual = calcA(H, vk, crop) - b  # A v - b\n",
316 | " back_projected = calcAHerm(Hadj, residual, pad)  # A^H (A v - b)\n",
317 | " return np.real(back_projected)"
318 | ]
319 | },
320 | {
321 | "cell_type": "markdown",
322 | "metadata": {},
323 | "source": [
324 | "We write $\\full$ as:\n",
325 | "$$ \\full\\imagevec \\iff \\mathrm{crop} \\left[ \\mathcal{F}^{-1} \\left\\{\\mathcal{F}(h) \\cdot \\mathcal{F}(v)\\right\\} \\right]$$\n",
326 | "In code, this becomes\n",
327 | "\\begin{align*} \n",
328 | "\\texttt{calcA}(\\texttt{vk}) & = \\texttt{crop}\\ (\\texttt{ifft}\\ (\\texttt{fft}(\\texttt{hpad}) \\cdot \\texttt{fft}(\\texttt{vk})\\ )\\ )\\\\\n",
329 | "& = \\texttt{crop}\\ (\\texttt{ifft}\\ (\\texttt{H} \\cdot \\texttt{Vk}))\n",
330 | "\\end{align*}\n",
331 | "where $\\cdot$ represents point-wise multiplication"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 7,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "def calcA(H, vk, crop):\n",
341 | " product = H*fft.fft2(fft.ifftshift(vk))  # pointwise product in Fourier space = convolution with the PSF\n",
342 | " return crop(fft.fftshift(fft.ifft2(product)))"
343 | ]
344 | },
345 | {
346 | "cell_type": "markdown",
347 | "metadata": {},
348 | "source": [
349 | "We first pad $\\texttt{diff}$, giving us $\\texttt{xpad}$, then take the 2D Fourier transform, $\\texttt{X} = \\mathcal{F}(\\texttt{xpad})$. The action of the adjoint of $A$ is\n",
350 | "\n",
351 | "$$ A^H \\mathbf{x} \\iff \\mathcal{F}^{-1} \\left\\{ \\mathcal{F}(\\psf)^* \\cdot \\mathcal{F}( \\operatorname{pad}\\left[x\\right]) \\right\\}$$\n",
352 | "This becomes\n",
353 | "\\begin{align*}\n",
354 | "\\texttt{calcAHerm}(\\texttt{diff}) &= \\texttt{ifft}\\ (\\ (\\texttt{fft}(\\texttt{hpad}))^* \\cdot \\texttt{fft}\\ (\\texttt{pad}(\\texttt{diff}))\\ ) \\\\\n",
355 | "& = \\texttt{ifft}\\ (\\texttt{Hadj} \\cdot \\texttt{X})\n",
356 | "\\end{align*}"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": 8,
362 | "metadata": {},
363 | "outputs": [],
364 | "source": [
365 | "def calcAHerm(Hadj, diff, pad):\n",
366 | " X = fft.fft2(fft.ifftshift(pad(diff)))  # zero-pad the residual, then move to Fourier space\n",
367 | " back = fft.ifft2(Hadj*X)  # multiply by the conjugate PSF spectrum\n",
368 | " return fft.fftshift(back)"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "metadata": {},
374 | "source": [
375 | "#### Putting it all together"
376 | ]
377 | },
378 | {
379 | "cell_type": "markdown",
380 | "metadata": {},
381 | "source": [
382 | "This is the main function, which calculates the gradients and updates our estimation of the scene:\n",
383 | "\\begin{align*}\n",
384 | " &\\imagevec_0 = \\text{ anything} \\\\\n",
385 | " &\\text{for } k = 0 \\text{ to num_iters:}\\\\\n",
386 | " &\\quad \\quad \\imagevec_{k+1} \\leftarrow \\texttt{gradient_update}(\\imagevec_k) \\\\\n",
387 | "\\end{align*} \n",
388 | "\n",
389 | "There are different ways of doing the gradient update. The three we will show are regular GD, Nesterov momentum update, and FISTA.\n",
390 | "\n",
391 | "To guarantee convergence, we set the step size to be \n",
392 | "\\begin{align*} \n",
393 | "\\alpha_k &< \\frac{2}{\\|\\measurementmtx^H \\measurementmtx\\|_2} \\approx \\frac{2}{\\lambda_{max}(\\mathbf{M}^H\\mathbf{M})}\n",
394 | "\\end{align*}\n",
395 | "\n",
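396 | "To calculate this, we use the fact that the convolution $\\mathbf{M}$ is diagonalized by the Fourier transform (cropping can only decrease the norm, so a step size that is safe for $\\mathbf{M}$ is also safe for the full model $\\full = \\mathrm{crop}\\,\\mathbf{M}$):\n",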
397 | "$$\\begin{align*}\n",
398 | "\\mathbf{M}^H\\mathbf{M} &= \\left(\\mathbf{F}^{-1} \\mathrm{diag}(\\mathbf{Fh}) \\ \\mathbf{F}\\right)^H \\ \\mathbf{F}^{-1} \\mathrm{diag}(\\mathbf{Fh}) \\ \\mathbf{F} \\\\\n",
399 | "&= \\mathbf{F}^{-1} \\mathrm{diag}(\\mathbf{Fh})^* \\ \\mathrm{diag}(\\mathbf{Fh}) \\ \\mathbf{F} \\\\\n",
400 | "\\lambda_{max}(\\mathbf{M}^H\\mathbf{M}) &= \\max \\left(\\mathrm{diag}(\\mathbf{Fh})^* \\ \\mathrm{diag}(\\mathbf{Fh}) \\right)\n",
401 | "\\end{align*}$$\n",
402 | "In code, we have\n",
403 | "\\begin{align*}\n",
404 | "\\alpha = \\frac{1.8}{\\texttt{max} \\left(\\texttt{Hadj} \\cdot \\texttt{H}\\right)}\n",
405 | "\\end{align*}\n",
406 | "\n",
407 | "\n",
408 | "Since we are dealing with images, one constraint on the reconstructed image $\\imagevec_k$ is that all the entries have to be non-negative. We enforce this with projected gradient descent: the projection function $\\texttt{proj}$ projects $\\texttt{vk}$ onto the non-negative orthant, i.e. it sets every negative entry to zero. "
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": 9,
414 | "metadata": {},
415 | "outputs": [],
416 | "source": [
417 | "def grad_descent(h, b):\n",
418 | " H, Hadj, v, utils = initMatrices(h)\n",
419 | " crop = utils[0]\n",
420 | " pad = utils[1]\n",
421 | " # Projected gradient descent on 1/2||A v - b||^2 for `iters` steps; returns proj(crop(vk)). NOTE: iters and display are read from notebook globals.\n",
422 | " alpha = np.real(1.8/(np.max(Hadj * H)))\n",
423 | " # alpha = 1.8/max|H|^2, i.e. below the 2/lambda_max convergence bound.\n",
424 | " \n",
425 | " def non_neg(xi):\n",
426 | " xi = np.maximum(xi,0)\n",
427 | " return xi\n",
428 | " \n",
429 | " #proj = lambda x:x #Do no projection\n",
430 | " proj = non_neg #Enforce nonnegativity at every gradient step. Comment out as needed.\n",
431 | "\n",
432 | "\n",
433 | " parent_var = [H, Hadj, b, crop, pad, alpha, proj]\n",
434 | " \n",
435 | " vk = v\n",
436 | " \n",
437 | " \n",
438 | " \n",
439 | " #### uncomment for Nesterov momentum update #### \n",
440 | " #p = 0\n",
441 | " #mu = 0.9\n",
442 | " ################################################\n",
443 | " \n",
444 | " \n",
445 | " \n",
446 | " #### uncomment for FISTA update ################\n",
447 | " tk = 1\n",
448 | " xk = v.copy()  # separate buffer: the update functions modify vk in place\n",
449 | " ################################################\n",
450 | " \n",
451 | " for iterations in range(iters): \n",
452 | " \n",
453 | " # uncomment for regular GD update\n",
454 | " #vk = gd_update(vk, parent_var)\n",
455 | " \n",
456 | " # uncomment for Nesterov momentum update \n",
457 | " #vk, p = nesterov_update(vk, p, mu, parent_var)\n",
458 | " \n",
459 | " # uncomment for FISTA update\n",
460 | " vk, tk, xk = fista_update(vk, tk, xk, parent_var)\n",
461 | "\n",
462 | " if iterations % 10 == 0:\n",
463 | " image = proj(crop(vk))\n",
464 | " f = plt.figure(1)\n",
465 | " plt.imshow(image, cmap='gray')\n",
466 | " plt.title('Reconstruction after iteration {}'.format(iterations))\n",
467 | " display.display(f)\n",
468 | " display.clear_output(wait=True)\n",
469 | " \n",
470 | " \n",
471 | " return proj(crop(vk)) \n",
472 | " "
473 | ]
474 | },
475 | {
476 | "cell_type": "markdown",
477 | "metadata": {},
478 | "source": [
479 | "#### Gradient descent algorithms"
480 | ]
481 | },
482 | {
483 | "cell_type": "markdown",
484 | "metadata": {},
485 | "source": [
486 | "##### Regular Gradient Descent\n",
487 | "Regular gradient descent is simply following the negative of the gradient until we reach the minimum:\n",
488 | "\n",
489 | "\\begin{align*}\n",
490 | " & \\texttt{gradient_update}(\\imagevec_k): \\\\\n",
491 | " &\\quad \\quad\\imagevec'_{k+1} \\leftarrow \\imagevec_k - \\alpha_k \\full^H(\\full\\imagevec_k - \\measurementvec) \\\\\n",
492 | " &\\quad \\quad \\imagevec_{k+1} \\leftarrow \\operatorname{proj}_{\\imagevec \\geq 0} (\\imagevec_{k+1}')\n",
493 | "\\end{align*} "
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": 10,
499 | "metadata": {},
500 | "outputs": [],
501 | "source": [
502 | "def gd_update(vk, parent_var):\n",
503 | " H, Hadj, b, crop, pad, alpha, proj = parent_var\n",
504 | " # One projected gradient step: vk <- proj(vk - alpha*grad(vk)); vk is modified in place before projection.\n",
505 | " gradient = grad(Hadj, H, vk, b, crop, pad)\n",
506 | " vk -= alpha*gradient\n",
507 | " vk = proj(vk)\n",
508 | " \n",
509 | " return vk"
510 | ]
511 | },
512 | {
513 | "cell_type": "markdown",
514 | "metadata": {},
515 | "source": [
516 | "##### Nesterov Momentum\n",
517 | "GD works but it's slow. One way to speed it up is to consider a velocity term, $\\mathbf{p}$. Each update becomes:\n",
518 | "\n",
519 | "\\begin{align*}\n",
520 | " &\\texttt{gradient_update}(\\imagevec_k): \\\\\n",
521 | " &\\qquad \\mathbf{p}_{k+1} \\leftarrow \\mu \\mathbf{p}_k - \\alpha_k \\texttt{grad}(\\imagevec_k) \\\\\n",
522 | " &\\qquad \\imagevec_{k+1}' \\leftarrow \\imagevec_k - \\mu \\mathbf{p}_{k} + (1+\\mu)\\mathbf{p}_{k+1} \\\\\n",
523 | " &\\qquad \\imagevec_{k+1} \\leftarrow \\operatorname{proj}_{\\imagevec \\geq 0} (\\imagevec_{k+1}')\n",
524 | "\\end{align*}\n",
525 | "\n",
526 | "The parameter $\\mu$ is called the momentum and is strictly between 0 and 1."
527 | ]
528 | },
529 | {
530 | "cell_type": "code",
531 | "execution_count": 11,
532 | "metadata": {},
533 | "outputs": [],
534 | "source": [
535 | "def nesterov_update(vk, p, mu, parent_var):\n",
536 | " H, Hadj, b, crop, pad, alpha, proj = parent_var\n",
537 | " # Velocity update, then the look-ahead position step; vk is modified in place before projection.\n",
538 | " velocity_old = p\n",
539 | " step = alpha*grad(Hadj, H, vk, b, crop, pad)\n",
540 | " p = mu*velocity_old - step\n",
541 | " vk += -mu*velocity_old + (1+mu)*p\n",
542 | " vk = proj(vk)\n",
543 | " \n",
544 | " return vk, p"
545 | ]
546 | },
547 | {
548 | "cell_type": "markdown",
549 | "metadata": {},
550 | "source": [
551 | "##### FISTA \n",
552 | "Instead of Nesterov momentum, we can use FISTA, which speeds up the iterative process. Each iteration of the algorithm is as follows:\n",
553 | "\n",
554 | "\\begin{align*}\n",
555 | " &\\texttt{gradient_update}(\\imagevec_k):\\\\\n",
556 | " &\\qquad \\imagevec_k \\leftarrow \\imagevec_k - \\alpha_k \\texttt{grad}(\\imagevec_k) \\\\ \n",
557 | " &\\qquad x_{k} \\leftarrow \\texttt{proj}(\\imagevec_k) \\\\\n",
558 | " &\\qquad t_{k+1} \\leftarrow \\frac{1+\\sqrt{1+4t_k^2}}{2} \\\\\n",
559 | " &\\qquad \\imagevec_{k+1} \\leftarrow x_{k} + \\frac{t_k-1}{t_{k+1}} (x_{k}-x_{k-1}) \\\\\n",
560 | "\\end{align*}\n"
561 | ]
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": 12,
566 | "metadata": {},
567 | "outputs": [],
568 | "source": [
569 | "def fista_update(vk, tk, xk, parent_var):\n",
570 | " H, Hadj, b, crop, pad, alpha, proj = parent_var\n",
571 | " # Gradient step on vk (in place), projection, then momentum extrapolation with weight (tk-1)/t_next.\n",
572 | " x_prev = xk\n",
573 | " vk -= alpha*grad(Hadj, H, vk, b, crop, pad)\n",
574 | " x_new = proj(vk)\n",
575 | " t_next = (1+np.sqrt(1+4*tk**2))/2\n",
576 | " momentum = (tk-1)/t_next\n",
577 | " v_next = x_new+momentum*(x_new - x_prev)\n",
578 | " \n",
579 | " # x_new is also returned so the next call can form the (x_new - x_prev) difference.\n",
580 | " return v_next, t_next, x_new"
581 | ]
582 | },
583 | {
584 | "cell_type": "markdown",
585 | "metadata": {},
586 | "source": [
587 | "#### Running the algorithm"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": 13,
593 | "metadata": {},
594 | "outputs": [
595 | {
596 | "data": {
597 | "image/png": "\n",
598 | "text/plain": [
599 | ""
600 | ]
601 | },
602 | "metadata": {},
603 | "output_type": "display_data"
604 | },
605 | {
606 | "data": {
607 | "image/png": "\n",
608 | "text/plain": [
609 | ""
610 | ]
611 | },
612 | "metadata": {},
613 | "output_type": "display_data"
614 | }
615 | ],
616 | "source": [
617 | "psf, data = loaddata()\n",
618 | "final_im = grad_descent(psf, data)\n",
619 | "plt.imshow(final_im, cmap='gray')\n",
620 | "plt.title('Final reconstruction after {} iterations'.format(iters))\n",
621 | "plt.show()\n"
622 | ]
623 | }
624 | ],
625 | "metadata": {
626 | "kernelspec": {
627 | "display_name": "Python [default]",
628 | "language": "python",
629 | "name": "python3"
630 | },
631 | "language_info": {
632 | "codemirror_mode": {
633 | "name": "ipython",
634 | "version": 3
635 | },
636 | "file_extension": ".py",
637 | "mimetype": "text/x-python",
638 | "name": "python",
639 | "nbconvert_exporter": "python",
640 | "pygments_lexer": "ipython3",
641 | "version": "3.6.5"
642 | }
643 | },
644 | "nbformat": 4,
645 | "nbformat_minor": 2
646 | }
647 |
--------------------------------------------------------------------------------
/tutorial/psf_sample.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/tutorial/psf_sample.tif
--------------------------------------------------------------------------------
/tutorial/rawdata_hand_sample.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Waller-Lab/DiffuserCam-Tutorial/34674be3b063a266d8ce5279a31e8ac8779c1f1b/tutorial/rawdata_hand_sample.tif
--------------------------------------------------------------------------------