├── LICENSE ├── README.md ├── na_list_test.mat ├── opticaltomography ├── opticsalg.py ├── opticsmodel.py ├── opticsutil.py ├── regularizers.py └── settings.py └── sample_script.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Waller Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-layer Born scattering model for 3D phase contrast optical diffraction tomography 2 | Please cite the following paper when using this code: 3 | 4 | [Multi-layer Born multiple-scattering model for 3D phase microscopy](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-7-5-394) 5 | 6 | Michael Chen, David Ren, Hsiou-Yuan Liu, Shwetadwip Chowdhury, and Laura Waller 7 | 8 | ## Contents 9 | 1. [Usage](#usage) 10 | 2. [Data](#data) 11 | 3. [Resources](#resources) 12 | 4. [Updates](#updates) 13 | 14 | 15 | 16 | ## Usage 17 | 0. Install the required dependencies [Arrayfire](https://github.com/arrayfire/arrayfire-python) and [contexttimer](https://pypi.org/project/contexttimer/), along with NumPy, SciPy, scikit-image, and IPython, which the code also imports. 18 | 1. Clone this repo. 19 | 2. Run the sample script (`sample_script.ipynb`) to simulate measurements and reconstruct the object from them. 20 | 21 | ## Data 22 | Sample data is available [here](https://drive.google.com/drive/folders/19eQCMjTtiK8N1f1nGtXlfXkEa8qL6kDl?usp=sharing). 23 | 24 | ## Resources 25 | 0. [MATLAB implementation](https://github.com/Waller-Lab/multi-slice) of the [multi-slice method](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-6-9-1211). 26 | 27 | ## Updates 28 | 08/27/2020: 29 | 1. Added the first version of the code to the repo; it is the same code used in the paper. 
30 | 31 | -------------------------------------------------------------------------------- /na_list_test.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Waller-Lab/multi-layer-born/b90b8c6a73c99bdb16a2ec60caa48893a40895c1/na_list_test.mat -------------------------------------------------------------------------------- /opticaltomography/opticsalg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implement optics algorithms for optical phase tomography using GPU 3 | 4 | Michael Chen mchen0405@berkeley.edu 5 | David Ren david.ren@berkeley.edu 6 | November 22, 2017 7 | """ 8 | import scipy.io as sio 9 | import numpy as np 10 | import arrayfire as af 11 | import skimage.feature 12 | import contexttimer 13 | from opticaltomography import settings 14 | from opticaltomography.opticsmodel import MultiTransmittance, MultiPhaseContrast, MultiBorn 15 | from opticaltomography.opticsmodel import Defocus, Aberration 16 | from opticaltomography.opticsutil import ImageCrop, calculateNumericalGradient, EmptyClass 17 | from opticaltomography.regularizers import Regularizer 18 | from IPython.core.debugger import set_trace 19 | 20 | np_complex_datatype = settings.np_complex_datatype 21 | np_float_datatype = settings.np_float_datatype 22 | af_float_datatype = settings.af_float_datatype 23 | af_complex_datatype = settings.af_complex_datatype 24 | 25 | class AlgorithmConfigs: 26 | """ 27 | Class created for all parameters for tomography solver 28 | """ 29 | def __init__(self): 30 | self.method = "FISTA" #or GradientDescent" 31 | self.stepsize = 1e-2 32 | self.max_iter = 20 33 | self.error = [] 34 | self.reg_term = 0.0 #L2 norm 35 | 36 | #FISTA 37 | self.fista_global_update = False 38 | self.restart = False 39 | 40 | #total variation regularization 41 | self.total_variation = False 42 | self.reg_tv = 1.0 #lambda 43 | self.max_iter_tv = 15 44 | self.order_tv = 1 45 | 
self.total_variation_gpu = False 46 | self.total_variation_anisotropic = False 47 | 48 | #lasso 49 | self.lasso = False 50 | self.reg_lasso = 1.0 51 | 52 | #positivity constraint 53 | self.positivity_real = (False, "larger") #or smaller 54 | self.positivity_imag = (False, "larger") #or smaller 55 | self.pure_real = False 56 | self.pure_imag = False 57 | 58 | #purely amplitude/phase 59 | self.pure_amplitude = False 60 | self.pure_phase = False 61 | 62 | #batch gradient update 63 | self.batch_size = 1 64 | 65 | #random order update 66 | self.random_order = False 67 | 68 | # Reconstruct from field or amplitude measurement 69 | self.recon_from_field = False 70 | self.cost_criterion = "amplitude" 71 | 72 | class PhaseObject3D: 73 | """ 74 | Class created for 3D objects. 75 | Depending on the scattering model, one of the following quantities will be used: 76 | - Refractive index (RI) 77 | - Transmittance function (Trans) 78 | - PhaseContrast 79 | - Scattering potential (V) 80 | 81 | shape: shape of object to be reconstructed in (x,y,z), tuple 82 | voxel_size: size of each voxel in (x,y,z), tuple 83 | RI_obj: refractive index of object(Optional) 84 | RI: background refractive index (Optional) 85 | slice_separation: For multislice algorithms, how far apart are slices separated, array (Optional) 86 | """ 87 | def __init__(self, shape, voxel_size, RI_obj = None, RI = 1.0, slice_separation = None): 88 | assert len(shape) == 3, "shape should be 3 dimensional!" 89 | self.shape = shape 90 | self.RI_obj = RI * np.ones(shape, dtype = np_complex_datatype) if RI_obj is None else RI_obj.astype(np_complex_datatype) 91 | print("RI max:", np.max(np.abs(self.RI_obj))) 92 | self.RI = RI 93 | self.pixel_size = voxel_size[0] 94 | self.pixel_size_z = voxel_size[2] 95 | 96 | if slice_separation is not None: 97 | #for discontinuous slices 98 | assert len(slice_separation) == shape[2]-1, "number of separations should match with number of layers!" 
99 | self.slice_separation = np.asarray(slice_separation).astype(np_float_datatype) 100 | else: 101 | #for continuous slices 102 | self.slice_separation = self.pixel_size_z * np.ones((shape[2]-1,), dtype = np_float_datatype) 103 | 104 | def convertRItoTrans(self, wavelength): 105 | k0 = 2.0 * np.pi / wavelength 106 | self.trans_obj = np.exp(1.0j*k0*(self.RI_obj - self.RI)*self.pixel_size_z) 107 | self.RI_obj = None 108 | 109 | def convertRItoPhaseContrast(self): 110 | self.contrast_obj = self.RI_obj - self.RI 111 | self.RI_obj = None 112 | 113 | def convertPhaseContrasttoRI(self): 114 | self.RI_obj = self.contrast_obj + self.RI 115 | 116 | def convertRItoV(self, wavelength): 117 | k0 = 2.0 * np.pi / wavelength 118 | self.V_obj = k0**2 * (self.RI**2 - self.RI_obj**2) 119 | self.RI_obj = None 120 | 121 | def convertVtoRI(self, wavelength): 122 | k0 = 2.0 * np.pi / wavelength 123 | B = -1.0 * (self.RI**2 - self.V_obj.real/k0**2) 124 | C = -1.0 * (-1.0 * self.V_obj.imag/k0**2/2.0)**2 125 | RI_obj_real = ((-1.0 * B + (B**2-4.0*C)**0.5)/2.0)**0.5 126 | RI_obj_imag = -0.5 * self.V_obj.imag/k0**2/RI_obj_real 127 | self.RI_obj = RI_obj_real + 1.0j * RI_obj_imag 128 | 129 | class TomographySolver: 130 | """ 131 | Highest level solver object for tomography problem 132 | 133 | phase_obj_3d: phase_obj_3d object defined from class PhaseObject3D 134 | fx_illu_list: illumination angles in x, default = [0] (on axis) 135 | fy_illu_list: illumination angles in y 136 | propagation_distance_list: defocus distances for each illumination 137 | """ 138 | def __init__(self, phase_obj_3d, fx_illu_list = [0], fy_illu_list = [0], \ 139 | propagation_distance_list = [0], **kwargs): 140 | 141 | self.phase_obj_3d = phase_obj_3d 142 | self.wavelength = kwargs["wavelength"] 143 | self.na = kwargs["na"] 144 | 145 | #Illumination angles 146 | assert len(fx_illu_list) == len(fy_illu_list) 147 | self.fx_illu_list = fx_illu_list 148 | self.fy_illu_list = fy_illu_list 149 | self.number_illum = 
len(self.fx_illu_list) 150 | 151 | #Defocus distances and object 152 | self.prop_distances = propagation_distance_list 153 | self._aberration_obj = Aberration(self.phase_obj_3d.shape[:2], self.phase_obj_3d.pixel_size, self.wavelength, self.na) 154 | self._defocus_obj = Defocus(self.phase_obj_3d.shape[:2], self.phase_obj_3d.pixel_size, **kwargs) 155 | self._crop_obj = ImageCrop(self.phase_obj_3d.shape[:2], **kwargs) 156 | self.number_defocus = len(self.prop_distances) 157 | 158 | #Scattering models and algorithms 159 | self._opticsmodel = {"MultiTrans": MultiTransmittance, 160 | "MultiPhaseContrast": MultiPhaseContrast, 161 | "MultiBorn": MultiBorn, 162 | } 163 | self._algorithms = {"GradientDescent": self._solveFirstOrderGradient, 164 | "FISTA": self._solveFirstOrderGradient, 165 | } 166 | self.scat_model_args = kwargs 167 | 168 | def setScatteringMethod(self, model = "MultiTrans"): 169 | """ 170 | Define scattering method for tomography 171 | 172 | model: scattering models, it can be one of the followings: 173 | "MultiTrans", "MultiPhaseContrast", "Rytov", "MultiRytov", "Born", "Born" 174 | """ 175 | 176 | if hasattr(self, '_scattering_obj'): 177 | del self._scattering_obj 178 | if model == "MultiTrans": 179 | try: 180 | self.phase_obj_3d.convertRItoTrans(self.wavelength) 181 | except: 182 | self.phase_obj_3d.trans_obj = self._x 183 | self._x = self.phase_obj_3d.trans_obj 184 | 185 | elif model == "MultiPhaseContrast": 186 | if not hasattr(self.phase_obj_3d, 'contrast_obj'): 187 | try: 188 | self.phase_obj_3d.convertRItoPhaseContrast() 189 | except: 190 | self.phase_obj_3d.contrast_obj = self._x 191 | self._x = self.phase_obj_3d.contrast_obj 192 | 193 | else: 194 | if not hasattr(self.phase_obj_3d, 'V_obj'): 195 | try: 196 | self.phase_obj_3d.convertRItoV(self.wavelength) 197 | except: 198 | self.phase_obj_3d.V_obj = self._x 199 | self._x = self.phase_obj_3d.V_obj 200 | 201 | self._scattering_obj = self._opticsmodel[model](self.phase_obj_3d, 
**self.scat_model_args) 202 | self._initialization(x_init = self._x) 203 | self.scat_model = model 204 | 205 | def forwardPredict(self, field = False): 206 | """ 207 | Uses current object in the phase_obj_3d to predict the amplitude of the exit wave 208 | Before calling, make sure correct object is contained 209 | """ 210 | obj_gpu = af.to_array(self._x) #self._x created in self.setScatteringMethod() 211 | with contexttimer.Timer() as timer: 212 | forward_scattered_predict= [] 213 | for illu_idx in range(self.number_illum): 214 | fx_illu = self.fx_illu_list[illu_idx] 215 | fy_illu = self.fy_illu_list[illu_idx] 216 | 217 | fields = self._forwardMeasure(fx_illu, fy_illu, obj = obj_gpu) 218 | if field: 219 | forward_scattered_predict.append(np.array(fields["forward_scattered_field"])) 220 | else: 221 | forward_scattered_predict.append(np.abs(fields["forward_scattered_field"])) 222 | if self.number_illum > 1: 223 | print("illumination {:03d}/{:03d}.".format(illu_idx, self.number_illum), end="\r") 224 | if len(forward_scattered_predict[0][0].shape)==2: 225 | forward_scattered_predict = np.array(forward_scattered_predict).transpose(2, 3, 1, 0) 226 | elif len(forward_scattered_predict[0][0].shape)==1: 227 | forward_scattered_predict = np.array(forward_scattered_predict).transpose(1, 2, 0) 228 | return forward_scattered_predict 229 | 230 | def checkGradient(self, delta = 1e-4): 231 | """ 232 | check if the numerical gradient is similar to the analytical gradient. Only works for 64 bit data type. 233 | """ 234 | 235 | #assert af_float_datatype == af.Dtype.f64, "This will only be accurate if 64 bit datatype is used!" 
236 | shape = self.phase_obj_3d.shape 237 | point = (np.random.randint(shape[0]), np.random.randint(shape[1]), np.random.randint(shape[2])) 238 | illu_idx = np.random.randint(len(self.fx_illu_list)) 239 | fx_illu = self.fx_illu_list[illu_idx] 240 | fy_illu = self.fy_illu_list[illu_idx] 241 | x = np.ones(shape, dtype = np_complex_datatype) 242 | if self._crop_obj.pad: 243 | amplitude = af.randu(shape[0]//2, shape[1]//2, dtype = af_float_datatype) 244 | else: 245 | amplitude = af.randu(shape[0], shape[1], dtype = af_float_datatype) 246 | print("Computing the gradient at point: ", point) 247 | print("fx : %5.2f, fy : %5.2f " %(fx_illu, fy_illu)) 248 | 249 | def func(x0): 250 | fields = self._scattering_obj.forward(x0, fx_illu, fy_illu) 251 | field_scattered = self._defocus_obj.forward(field_scattered, self.prop_distances) 252 | field_measure = self._crop_obj.forward(field_scattered) 253 | residual = af.abs(field_measure) - amplitude 254 | function_value = af.sum(residual*af.conjg(residual)).real 255 | return function_value 256 | 257 | numerical_gradient = calculateNumericalGradient(func, x, point, delta = delta) 258 | 259 | fields = self._scattering_obj.forward(x, fx_illu, fy_illu) 260 | forward_scattered_field = fields["forward_scattered_field"] 261 | forward_scattered_field = self._defocus_obj.forward(forward_scattered_field, self.prop_distances) 262 | field_measure = self._crop_obj.forward(forward_scattered_field) 263 | fields["forward_scattered_field"] = field_measure 264 | analytical_gradient = self._computeGradient(fields, amplitude)[0][point] 265 | 266 | print("numerical gradient: %5.5e + %5.5e j" %(numerical_gradient.real, numerical_gradient.imag)) 267 | print("analytical gradient: %5.5e + %5.5e j" %(analytical_gradient.real, analytical_gradient.imag)) 268 | 269 | def _forwardMeasure(self, fx_illu, fy_illu, obj = None): 270 | 271 | """ 272 | From an illumination angle, this function computes the exit wave. 
273 | fx_illu, fy_illu: illumination angle in x and y (scalars) 274 | obj: object to be passed through (Optional, default pick from phase_obj_3d) 275 | """ 276 | #Forward scattering of light through object 277 | #_scattering_obj set in setScatteringModel() 278 | if obj is None: 279 | fields = self._scattering_obj.forward(self._x, fx_illu, fy_illu) 280 | else: 281 | fields = self._scattering_obj.forward(obj, fx_illu, fy_illu) 282 | field_scattered = self._defocus_obj.forward(fields["forward_scattered_field"], self.prop_distances) 283 | field_scattered = self._aberration_obj.forward(field_scattered) 284 | field_scattered = self._crop_obj.forward(field_scattered) 285 | fields["forward_scattered_field"] = field_scattered 286 | return fields 287 | 288 | def _computeGradient(self, fields, measurement): 289 | """ 290 | Error backpropagation to return a gradient 291 | fields: contains all necessary predicted information (predicted field, intermediate cache) 292 | to compute gradient 293 | measurement: amplitude measured 294 | """ 295 | cache = fields["cache"] 296 | field_predict = fields["forward_scattered_field"] 297 | if self.configs.recon_from_field: 298 | field_bp = field_predict - measurement 299 | else: 300 | if self.configs.cost_criterion == "intensity": 301 | field_bp = 2 * field_predict * (af.abs(field_predict) ** 2 - af.abs(measurement) ** 2) 302 | elif self.configs.cost_criterion == "amplitude": 303 | field_bp = field_predict - measurement*af.exp(1.0j*af.arg(field_predict)) 304 | field_bp = self._crop_obj.adjoint(field_bp) 305 | field_bp = self._aberration_obj.adjoint(field_bp) 306 | field_bp = self._defocus_obj.adjoint(field_bp, self.prop_distances) 307 | gradient = self._scattering_obj.adjoint(field_bp, cache) 308 | 309 | return gradient["gradient"] 310 | 311 | def _initialization(self,configs=AlgorithmConfigs(), \ 312 | x_init = None, obj_support = None,\ 313 | pupil = None, pupil_support = None, \ 314 | verbose = True): 315 | """ 316 | Initialize algorithm 
317 | configs: configs object from class AlgorithmConfigs 318 | x_init: initial guess of object 319 | """ 320 | self.configs = configs 321 | if pupil is not None: 322 | self.configs.pupil = pupil 323 | if pupil_support is not None: 324 | self.configs.pupil_support = pupil_support 325 | 326 | if x_init is None: 327 | if self.scat_model is "MultiTrans": 328 | self._x[:, :, :] = 1.0 329 | else: 330 | self._x[:, :, :] = 0.0 331 | else: 332 | self._x[:, :, :] = x_init 333 | 334 | self.obj_support = obj_support 335 | 336 | self._regularizer_obj = Regularizer(self.configs, verbose) 337 | 338 | def _solveFirstOrderGradient(self, measurements, verbose, callback=None): 339 | """ 340 | MAIN part of the solver, runs the FISTA algorithm 341 | configs: configs object from class AlgorithmConfigs 342 | measurements: all measurements 343 | self.configs.recon_from_field == True: field 344 | self.configs.recon_from_field == False: amplitude measurement 345 | verbose: boolean variable to print verbosely 346 | """ 347 | flag_FISTA = False 348 | if self.configs.method == "FISTA": 349 | flag_FISTA = True 350 | 351 | # update multiple angles at a time 352 | batch_update = False 353 | if self.configs.fista_global_update or self.configs.batch_size != 1: 354 | gradient_batch = af.constant(0.0, self.phase_obj_3d.shape[0],\ 355 | self.phase_obj_3d.shape[1],\ 356 | self.phase_obj_3d.shape[2], dtype = af_complex_datatype) 357 | batch_update = True 358 | if self.configs.fista_global_update: 359 | self.configs.batch_size = 0 360 | 361 | #TODO: what if num_batch is not an integer 362 | if self.configs.batch_size == 0: 363 | num_batch = 1 364 | else: 365 | num_batch = self.number_illum // self.configs.batch_size 366 | stepsize = self.configs.stepsize 367 | max_iter = self.configs.max_iter 368 | reg_term = self.configs.reg_term 369 | self.configs.error = [] 370 | obj_gpu = af.constant(0.0, self.phase_obj_3d.shape[0],\ 371 | self.phase_obj_3d.shape[1],\ 372 | self.phase_obj_3d.shape[2], dtype = 
af_complex_datatype) 373 | 374 | #Initialization for FISTA update 375 | if flag_FISTA: 376 | restart = self.configs.restart 377 | y_k = self._x.copy() 378 | t_k = 1.0 379 | 380 | #Set Callback flag 381 | if callback is None: 382 | run_callback = False 383 | else: 384 | run_callback = True 385 | 386 | #Start of iterative algorithm 387 | with contexttimer.Timer() as timer: 388 | if verbose: 389 | print("---- Start of the %5s algorithm ----" %(self.scat_model)) 390 | for iteration in range(max_iter): 391 | illu_counter = 0 392 | cost = 0.0 393 | obj_gpu[:] = af.to_array(self._x) 394 | if self.configs.random_order: 395 | illu_order = np.random.permutation(range(self.number_illum)) 396 | else: 397 | illu_order = range(self.number_illum) 398 | 399 | for batch_idx in range(num_batch): 400 | if batch_update: 401 | gradient_batch[:,:,:] = 0.0 402 | 403 | if self.configs.batch_size == 0: 404 | illu_indices = illu_order 405 | else: 406 | illu_indices = illu_order[batch_idx * self.configs.batch_size : (batch_idx+1) * self.configs.batch_size] 407 | for illu_idx in illu_indices: 408 | #forward scattering 409 | fx_illu = self.fx_illu_list[illu_idx] 410 | fy_illu = self.fy_illu_list[illu_idx] 411 | fields = self._forwardMeasure(fx_illu, fy_illu, obj = obj_gpu) 412 | #calculate error 413 | measurement = af.to_array(measurements[:,:,:,illu_idx].astype(np_complex_datatype)) 414 | 415 | if self.configs.recon_from_field: 416 | residual = fields["forward_scattered_field"] - measurement 417 | else: 418 | if self.configs.cost_criterion == "intensity": 419 | residual = af.abs(fields["forward_scattered_field"])**2 - measurement**2 420 | elif self.configs.cost_criterion == "amplitude": 421 | residual = af.abs(fields["forward_scattered_field"]) - measurement 422 | cost += af.sum(residual*af.conjg(residual)).real 423 | #calculate gradient 424 | if batch_update: 425 | gradient_batch[:, :, :] += self._computeGradient(fields, measurement)[0] 426 | else: 427 | gradient = 
self._computeGradient(fields, measurement) 428 | obj_gpu[:, :, :] -= stepsize * gradient 429 | 430 | if verbose: 431 | if self.number_illum > 1: 432 | print("gradient update of illumination {:03d}/{:03d}.".format(illu_counter, self.number_illum), end="\r") 433 | illu_counter += 1 434 | fields = None 435 | residual = None 436 | gradient = None 437 | measurement = None 438 | pupil = None 439 | af.device_gc() 440 | 441 | if batch_update: 442 | obj_gpu[:, :, :] -= stepsize * gradient_batch 443 | 444 | if np.isnan(obj_gpu).sum() > 0: 445 | stepsize *= 0.1 446 | self.configs.time_elapsed = timer.elapsed 447 | print("WARNING: Gradient update diverges! Resetting stepsize to %3.2f" %(stepsize)) 448 | t_k = 1.0 449 | continue 450 | 451 | # L2 regularizer 452 | obj_gpu[:, :, :] -= stepsize * reg_term * obj_gpu 453 | 454 | #record total error 455 | self.configs.error.append(cost + reg_term * af.sum(obj_gpu*af.conjg(obj_gpu)).real) 456 | 457 | #Prox operators 458 | af.device_gc() 459 | obj_gpu = self._regularizer_obj.applyRegularizer(obj_gpu) 460 | 461 | if flag_FISTA: 462 | #check convergence 463 | if iteration > 0: 464 | if self.configs.error[-1] > self.configs.error[-2]: 465 | if restart: 466 | t_k = 1.0 467 | 468 | self._x[:, :, :] = y_k 469 | # stepsize *= 0.8 470 | 471 | print("WARNING: FISTA Restart! Error: %5.5f" %(np.log10(self.configs.error[-1]))) 472 | if run_callback: 473 | callback(self._x, self.configs) 474 | continue 475 | else: 476 | print("WARNING: Error increased! 
Error: %5.5f" %(np.log10(self.configs.error[-1]))) 477 | 478 | #FISTA auxiliary variable 479 | y_k1 = np.array(obj_gpu) 480 | if len(y_k1.shape) < 3: 481 | y_k1 = y_k1[:,:,np.newaxis] 482 | 483 | #FISTA update 484 | t_k1 = 0.5*(1.0 + (1.0 + 4.0*t_k**2)**0.5) 485 | beta = (t_k - 1.0) / t_k1 486 | self._x[:, :, :] = y_k1 + beta * (y_k1 - y_k) 487 | t_k = t_k1 488 | y_k = y_k1.copy() 489 | else: 490 | #check convergence 491 | temp = np.array(obj_gpu) 492 | if len(temp.shape) < 3: 493 | temp = temp[:,:,np.newaxis] 494 | self._x[:, :, :] = temp 495 | if iteration > 0: 496 | if self.configs.error[-1] > self.configs.error[-2]: 497 | print("WARNING: Error increased! Error: %5.5f" %(np.log10(self.configs.error[-1]))) 498 | stepsize *= 0.8 499 | 500 | if verbose: 501 | print("iteration: %d/%d, error: %5.5f, elapsed time: %5.2f seconds" %(iteration+1, max_iter, np.log10(self.configs.error[-1]), timer.elapsed)) 502 | 503 | if run_callback: 504 | callback(self._x, self.configs) 505 | 506 | self.configs.time_elapsed = timer.elapsed 507 | return self._x 508 | 509 | def solve(self, configs, measurements, \ 510 | x_init = None, obj_support = None,\ 511 | pupil = None, pupil_support = None,\ 512 | verbose = True, callback = None): 513 | """ 514 | function to solve for the tomography problem 515 | 516 | configs: configs object from class AlgorithmConfigs 517 | measurements: measurements in amplitude not INTENSITY, ordered by (x,y,defocus, illumination) (if configs.recon_from_field is False) 518 | measurements of fields (if configs.recon_from_field is False) 519 | x_init: initial guess for object 520 | verbose: boolean variable to print verbosely 521 | """ 522 | 523 | # Ensure all data is of appropriate dimension 524 | # Namely, measurements are [image_size_1,image_size_2,number_defocus,number_illum] 525 | if self.number_defocus < 2: 526 | measurements = measurements[:,:, np.newaxis] 527 | 528 | if self.number_illum < 2: 529 | measurements = measurements[:, :, :, np.newaxis] 530 | 531 
| self._initialization(configs = configs, x_init = x_init, obj_support = obj_support, pupil = pupil, pupil_support = pupil_support,\ 532 | verbose = verbose) 533 | return self._algorithms[configs.method](measurements, verbose, callback) -------------------------------------------------------------------------------- /opticaltomography/opticsmodel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implement optics algorithms for optical phase tomography using GPU 3 | 4 | Michael Chen mchen0405@berkeley.edu 5 | David Ren david.ren@berkeley.edu 6 | November 22, 2017 7 | """ 8 | 9 | import contexttimer 10 | 11 | import numpy as np 12 | import arrayfire as af 13 | from opticaltomography.opticsutil import genGrid, genPupil, propKernel, genZernikeAberration 14 | from opticaltomography import settings 15 | from IPython.core.debugger import set_trace 16 | 17 | ##TEMP 18 | # import scipy.io as sio 19 | 20 | np_complex_datatype = settings.np_complex_datatype 21 | af_complex_datatype = settings.af_complex_datatype 22 | 23 | class Aperture: 24 | """ 25 | Class for optical aperture (general) 26 | """ 27 | def __init__(self, shape, pixel_size, na, **kwargs): 28 | """ 29 | shape: shape of object (y,x,z) 30 | pixel_size: pixel size of the system 31 | na: NA of the system 32 | """ 33 | self.shape = shape 34 | self.pixel_size = pixel_size 35 | self.na = na 36 | 37 | def forward(self): 38 | pass 39 | 40 | def adjoint(self): 41 | pass 42 | 43 | class Aberration(Aperture): 44 | """ 45 | Aberration class used for pupil recovery 46 | """ 47 | def __init__(self, shape, pixel_size, wavelength, na, flag_update = False, pupil_step_size = 1.0, update_method = "gradient", **kwargs): 48 | """ 49 | Initialization of the class 50 | 51 | wavelength: wavelength of light 52 | flag_update: boolean variable to update pupil 53 | pupil_step_size: if update the pupil, what is a step size for gradient method 54 | update_method: can be "gradient" or 
"GaussNewton" 55 | """ 56 | super().__init__(shape, pixel_size, na, **kwargs) 57 | self.pupil = genPupil(self.shape, self.pixel_size, self.na, wavelength) 58 | self.wavelength = wavelength 59 | self.pupil_support = self.pupil.copy() 60 | self.pupil_step_size = pupil_step_size 61 | self.flag_update = flag_update 62 | self.update_method = update_method 63 | 64 | def forward(self, field, **kwargs): 65 | """Apply pupil""" 66 | self.field_f = af.fft2(field) 67 | if self.flag_update: 68 | if self.update_method == "GaussNewton": 69 | self.approx_hessian[:, :] += self.field_f*af.conjg(self.field_f) 70 | field = af.ifft2(self.getPupil() * self.field_f) 71 | return field 72 | 73 | def adjoint(self, field): 74 | """Adjoint operator for pupil (and estimate pupil if selected)""" 75 | field_adj_f = af.fft2(field) 76 | field_pupil_adj = af.ifft2(af.conjg(self.getPupil()) * field_adj_f) 77 | #Pupil recovery 78 | if self.flag_update: 79 | self._update(field_adj_f) 80 | return field_pupil_adj 81 | 82 | def _update(self, field_adj_f): 83 | """function to recover pupil""" 84 | self.pupil_gradient_phase[:, :] += af.conjg(self.field_f) * field_adj_f 85 | self.measure_count += 1 86 | if self.measure_count == self.measurement_num: 87 | if self.update_method == "gradient": 88 | self.pupil[:, :] -= self.pupil_step_size * self.pupil_gradient_phase * self.pupil_support 89 | elif self.update_method == "GaussNewton": 90 | self.pupil[:, :] -= self.pupil_step_size * af.abs(self.field_f) * self.pupil_gradient_phase * self.pupil_support / ((self.approx_hessian + 1e-3) * af.max(af.abs(self.field_f[:]))) 91 | self.approx_hessian[:, :] = 0.0 92 | else: 93 | print("there is no update_method \"%s\"!" %(self.update_method)) 94 | raise 95 | self.pupil_gradient_phase[:, :] = 0.0 96 | self.measure_count = 0 97 | 98 | def constrainPupil(self, pupil, pupil_mask=None, flag_smooth_phase = True): 99 | #makes sure pupil amplitude is between 0 and 1. 
100 | #if pupil mask is passed in, amplitude will become the mask 101 | pupil_abs = np.abs(pupil) 102 | pupil_phase = af.shift(af.arg(pupil), pupil.shape[0]//2, pupil.shape[1]//2) 103 | #smoothen phase 104 | pupil_phase = af.convolve2(pupil_phase, self.gaussian_kernel) 105 | if pupil_mask is None: 106 | pupil_abs[pupil_abs>1.0] = 1.0 107 | pupil_abs[pupil_abs<0] = 0.0 108 | else: 109 | pupil_abs[:] = np.abs(pupil_mask) 110 | pupil = pupil_abs * np.exp(1.0j * af.shift(pupil_phase, -1*int(pupil.shape[0]//2), -1*int(pupil.shape[1]//2))) 111 | return pupil 112 | def constrainPupilAmpOnly(self, pupil, pupil_mask=None, flag_smooth_phase = True): 113 | #makes sure pupil amplitude is between 0 and 1. 114 | #if pupil mask is passed in, amplitude will become the mask 115 | pupil_abs = np.abs(pupil) 116 | pupil_phase = af.shift(af.arg(pupil), pupil.shape[0]//2, pupil.shape[1]//2) 117 | #smoothen phase 118 | if pupil_mask is None: 119 | pupil_abs[pupil_abs>1.0] = 1.0 120 | pupil_abs[pupil_abs<0] = 0.0 121 | else: 122 | pupil_abs[:] = np.abs(pupil_mask) 123 | pupil = pupil_abs * np.exp(1.0j * af.shift(pupil_phase, -1*int(pupil.shape[0]//2), -1*int(pupil.shape[1]//2))) 124 | return pupil 125 | 126 | def getZernikeCoefficients(self): 127 | return 0.0 128 | 129 | def getPupil(self): 130 | """function to retrieve pupil""" 131 | return self.pupil 132 | 133 | def setPupilSupport(self, pupil_support): 134 | """Input arbitrary pupil support""" 135 | self.pupil_support[:, :] = af.to_array(pupil_support) 136 | # self.pupil *= self.pupil_support 137 | 138 | def setPupil(self, pupil): 139 | """Input arbitratyt pupil""" 140 | self.pupil[:, :] = af.to_array(pupil) 141 | 142 | def setZernikePupil(self, z_coeff = [1], z_index_list = [0], pupil_support = None): 143 | """Set pupil according to Zernike coefficients""" 144 | self.pupil[:, :] = af.to_array(np.array(af.exp(1.0j*genZernikeAberration(self.shape, self.pixel_size, self.na, self.wavelength, z_coeff = z_coeff.copy(), z_index_list = 
z_index_list)))) 145 | if pupil_support is not None: 146 | self.pupil[:, :] *= pupil_support 147 | 148 | def setUpdateParams(self, flag_update = None, pupil_step_size = None, update_method = None, global_update = False, \ 149 | measurement_num = None, \ 150 | pupil = None, pupil_support = None): 151 | """Modify update parameters for pupil""" 152 | if pupil is not None: self.setPupil(pupil=pupil) 153 | if pupil_support is not None: self.setPupilSupport(pupil_support) 154 | self.flag_update = flag_update if flag_update is not None else self.flag_update 155 | if self.flag_update: 156 | self.update_method = update_method if update_method is not None else self.update_method 157 | self.pupil_step_size = pupil_step_size if pupil_step_size is not None else self.pupil_step_size 158 | self.pupil_gradient_amp = af.constant(0.0, self.shape[0], self.shape[1], dtype = af_complex_datatype) 159 | self.pupil_gradient_phase = af.constant(0.0, self.shape[0], self.shape[1], dtype = af_complex_datatype) 160 | self.measurement_num = measurement_num if global_update else 1 161 | self.measure_count = 0 162 | if self.update_method == "GaussNewton": 163 | self.approx_hessian = af.constant(0.0, self.shape[0], self.shape[1], dtype = af_complex_datatype) 164 | 165 | class Defocus(Aperture): 166 | """Defocus subclass for tomography""" 167 | def __init__(self, shape, pixel_size, wavelength, na, RI_measure = 1.0, **kwargs): 168 | """ 169 | Initialization of the class 170 | 171 | RI_measure: refractive index on the detection side (example: oil immersion objectives) 172 | """ 173 | super().__init__(shape, pixel_size, na, **kwargs) 174 | 175 | fxlin = genGrid(self.shape[1], 1.0/self.pixel_size/self.shape[1], flag_shift = True) 176 | fylin = genGrid(self.shape[0], 1.0/self.pixel_size/self.shape[0], flag_shift = True) 177 | fxlin = af.tile(fxlin.T, self.shape[0], 1) 178 | fylin = af.tile(fylin, 1, self.shape[1]) 179 | self.pupil_support = genPupil(self.shape, self.pixel_size, self.na, wavelength) 180 
| self.prop_kernel_phase = 1.0j*2.0*np.pi*self.pupil_support*((RI_measure/wavelength)**2 - fxlin*af.conjg(fxlin) - fylin*af.conjg(fylin))**0.5 181 | 182 | def forward(self, field, propagation_distances): 183 | """defocus with angular spectrum""" 184 | field_defocus = self.pupil_support * af.fft2(field) 185 | field_defocus = af.tile(field_defocus, 1, 1, len(propagation_distances)) 186 | for z_idx, propagation_distance in enumerate(propagation_distances): 187 | propagation_kernel = af.exp(self.prop_kernel_phase*propagation_distance) 188 | field_defocus[:, :, z_idx] *= propagation_kernel 189 | af.ifft2_inplace(field_defocus) 190 | return field_defocus 191 | 192 | def adjoint(self, residual, propagation_distances): 193 | """adjoint operator for defocus with angular spectrum""" 194 | field_focus = af.constant(0.0, self.shape[0], self.shape[1], dtype = af_complex_datatype) 195 | for z_idx, propagation_distance in enumerate(propagation_distances): 196 | propagation_kernel_conj = af.exp(-1.0*self.prop_kernel_phase*propagation_distances[z_idx]) 197 | field_focus[:, :] += propagation_kernel_conj * af.fft2(residual[:, :, z_idx]) 198 | field_focus *= self.pupil_support 199 | af.ifft2_inplace(field_focus) 200 | return field_focus 201 | 202 | class ScatteringModels: 203 | """ 204 | Core of the scattering models 205 | """ 206 | def __init__(self, phase_obj_3d, wavelength, slice_binning_factor = 1, **kwargs): 207 | """ 208 | Initialization of the class 209 | 210 | phase_obj_3d: object of the class PhaseObject3D 211 | wavelength: wavelength of the light 212 | slice_binning_factor: the object is compress in z-direction by this factor 213 | """ 214 | self.slice_binning_factor = slice_binning_factor 215 | self._shape_full = phase_obj_3d.shape 216 | self.shape = phase_obj_3d.shape[0:2] + (int(np.ceil(phase_obj_3d.shape[2]/self.slice_binning_factor)),) 217 | self.RI = phase_obj_3d.RI 218 | self.wavelength = wavelength 219 | self.pixel_size = phase_obj_3d.pixel_size 220 | 
self.pixel_size_z = phase_obj_3d.pixel_size_z * self.slice_binning_factor 221 | self.back_scatter = False 222 | #Broadcasts b to a, size(a) > size(b) 223 | self.assign_broadcast = lambda a, b : a - a + b 224 | 225 | fxlin, fylin = self._genFrequencyGrid() 226 | self.fzlin = ((self.RI/self.wavelength)**2 - fxlin*af.conjg(fxlin) - fylin*af.conjg(fylin))**0.5 227 | self.prop_kernel_phase = 1.0j * 2.0 * np.pi * self.fzlin 228 | 229 | def _genRealGrid(self, fftshift = False): 230 | xlin = genGrid(self.shape[1], self.pixel_size, flag_shift = fftshift) 231 | ylin = genGrid(self.shape[0], self.pixel_size, flag_shift = fftshift) 232 | xlin = af.tile(xlin.T, self.shape[0], 1) 233 | ylin = af.tile(ylin, 1, self.shape[1]) 234 | return xlin, ylin 235 | 236 | def _genFrequencyGrid(self): 237 | fxlin = genGrid(self.shape[1], 1.0/self.pixel_size/self.shape[1], flag_shift = True) 238 | fylin = genGrid(self.shape[0], 1.0/self.pixel_size/self.shape[0], flag_shift = True) 239 | fxlin = af.tile(fxlin.T, self.shape[0], 1) 240 | fylin = af.tile(fylin, 1, self.shape[1]) 241 | return fxlin, fylin 242 | 243 | def _genIllumination(self, fx_illu, fy_illu): 244 | fx_illu, fy_illu = self._setIlluminationOnGrid(fx_illu, fy_illu) 245 | xlin, ylin = self._genRealGrid() 246 | fz_illu = ((self.RI/self.wavelength)**2 - fx_illu**2 - fy_illu**2)**0.5 247 | illumination_xy = af.exp(1.0j*2.0*np.pi*(fx_illu*xlin + fy_illu*ylin)) 248 | return illumination_xy, fx_illu, fy_illu, fz_illu 249 | 250 | def _setIlluminationOnGrid(self, fx_illu, fy_illu): 251 | dfx = 1.0/self.pixel_size/self.shape[1] 252 | dfy = 1.0/self.pixel_size/self.shape[0] 253 | fx_illu_on_grid = np.round(fx_illu/dfx)*dfx 254 | fy_illu_on_grid = np.round(fy_illu/dfy)*dfy 255 | return fx_illu_on_grid, fy_illu_on_grid 256 | 257 | 258 | def _binObject(self, obj, adjoint = False): 259 | """ 260 | function to bin the object by factor of slice_binning_factor 261 | """ 262 | if self.slice_binning_factor == 1: 263 | return obj 264 | assert 
self.shape[2] >= 1 265 | if adjoint: 266 | obj_out = af.constant(0.0, self._shape_full[0], self._shape_full[1], self._shape_full[2], dtype = af_complex_datatype) 267 | for idx in range((self.shape[2]-1)*self.slice_binning_factor, -1, -self.slice_binning_factor): 268 | idx_slice = slice(idx, np.min([obj_out.shape[2],idx+self.slice_binning_factor])) 269 | obj_out[:,:,idx_slice] = af.broadcast(self.assign_broadcast, obj_out[:,:,idx_slice], obj[:,:,idx//self.slice_binning_factor]) 270 | else: 271 | obj_out = af.constant(0.0, self.shape[0], self.shape[1], self.shape[2], dtype = af_complex_datatype) 272 | for idx in range(0, obj.shape[2], self.slice_binning_factor): 273 | idx_slice = slice(idx, np.min([obj.shape[2], idx+self.slice_binning_factor])) 274 | obj_out[:,:,idx//self.slice_binning_factor] = af.sum(obj[:,:,idx_slice], 2) 275 | return obj_out 276 | 277 | def _meanObject(self, obj, adjoint = False): 278 | """ 279 | function to bin the object by factor of slice_binning_factor 280 | """ 281 | if self.slice_binning_factor == 1: 282 | return obj 283 | assert self.shape[2] >= 1 284 | if adjoint: 285 | obj_out = af.constant(0.0, self._shape_full[0], self._shape_full[1], self._shape_full[2], dtype = af_complex_datatype) 286 | for idx in range((self.shape[2]-1)*self.slice_binning_factor, -1, -self.slice_binning_factor): 287 | idx_slice = slice(idx, np.min([obj_out.shape[2],idx+self.slice_binning_factor])) 288 | obj_out[:,:,idx_slice] = af.broadcast(self.assign_broadcast, obj_out[:,:,idx_slice], obj[:,:,idx//self.slice_binning_factor]) 289 | else: 290 | obj_out = af.constant(0.0, self.shape[0], self.shape[1], self.shape[2], dtype = af_complex_datatype) 291 | for idx in range(0, obj.shape[2], self.slice_binning_factor): 292 | idx_slice = slice(idx, np.min([obj.shape[2], idx+self.slice_binning_factor])) 293 | obj_out[:,:,idx//self.slice_binning_factor] = af.mean(obj[:,:,idx_slice], dim=2) 294 | return obj_out 295 | 296 | def _propagationInplace(self, field, 
propagation_distance, adjoint = False, in_real = True): 297 | """ 298 | propagation operator that uses angular spectrum to propagate the wave 299 | 300 | field: input field 301 | propagation_distance: distance to propagate the wave 302 | adjoint: boolean variable to perform adjoint operation (i.e. opposite direction) 303 | """ 304 | if in_real: 305 | af.fft2_inplace(field) 306 | if adjoint: 307 | field[:, :] *= af.conjg(af.exp(self.prop_kernel_phase * propagation_distance)) 308 | else: 309 | field[:, :] *= af.exp(self.prop_kernel_phase * propagation_distance) 310 | if in_real: 311 | af.ifft2_inplace(field) 312 | return field 313 | 314 | def forward(self, x_obj, fx_illu, fy_illu): 315 | pass 316 | 317 | def adjoint(self, residual, cache): 318 | pass 319 | 320 | 321 | class MultiTransmittance(ScatteringModels): 322 | """ 323 | MultiTransmittance scattering model. This class also serves as a parent class for all multi-slice scattering methods 324 | """ 325 | def __init__(self, phase_obj_3d, wavelength, **kwargs): 326 | super().__init__(phase_obj_3d, wavelength, **kwargs) 327 | self.slice_separation = [sum(phase_obj_3d.slice_separation[x:x+self.slice_binning_factor]) \ 328 | for x in range(0,len(phase_obj_3d.slice_separation),self.slice_binning_factor)] 329 | sample_thickness = np.sum(self.slice_separation) 330 | total_prop_distance = np.sum(self.slice_separation[0:self.shape[2]-1]) 331 | self.distance_end_to_center = total_prop_distance/2.# - sample_thickness/2. 332 | if np.abs(np.mean(self.slice_separation) - self.pixel_size_z) < 1e-6 or self.slice_binning_factor > 1: 333 | self.focus_at_center = True 334 | self.initial_z_position = -1 * ((self.shape[2]-1)/2.) 
* self.pixel_size_z 335 | else: 336 | self.focus_at_center = False 337 | self.initial_z_position = 0.0 338 | 339 | def forward(self, trans_obj, fx_illu, fy_illu): 340 | if self.slice_binning_factor > 1: 341 | print("Slicing is not implemented for MultiTransmittance algorithm!") 342 | raise 343 | 344 | #compute illumination 345 | field, fx_illu, fy_illu, fz_illu = self._genIllumination(fx_illu, fy_illu) 346 | field[:, :] *= np.exp(1.0j * 2.0 * np.pi * fz_illu * self.initial_z_position) 347 | 348 | #multi-slice transmittance forward propagation 349 | field_layer_conj = af.constant(0.0, trans_obj.shape[0], trans_obj.shape[1], trans_obj.shape[2], dtype = af_complex_datatype) 350 | if (type(trans_obj).__module__ == np.__name__): 351 | trans_obj_af = af.to_array(trans_obj) 352 | flag_gpu_inout = False 353 | else: 354 | trans_obj_af = trans_obj 355 | flag_gpu_inout = True 356 | 357 | for layer in range(self.shape[2]): 358 | field_layer_conj[:, :, layer] = af.conjg(field) 359 | field[:, :] *= trans_obj_af[:, :, layer] 360 | if layer < self.shape[2]-1: 361 | field = self._propagationInplace(field, self.slice_separation[layer]) 362 | 363 | #store intermediate variables for adjoint operation 364 | cache = (trans_obj_af, field_layer_conj, flag_gpu_inout) 365 | 366 | if self.focus_at_center: 367 | #propagate to volume center 368 | field = self._propagationInplace(field, self.distance_end_to_center, adjoint = True) 369 | return {'forward_scattered_field': field, 'cache': cache} 370 | 371 | def adjoint(self, residual, cache): 372 | trans_obj_af, field_layer_conj_or_grad, flag_gpu_inout = cache 373 | 374 | #back-propagte to volume center 375 | field_bp = residual 376 | if self.focus_at_center: 377 | #propagate to the last layer 378 | field_bp = self._propagationInplace(field_bp, self.distance_end_to_center) 379 | 380 | #multi-slice transmittance backward 381 | for layer in range(self.shape[2]-1, -1, -1): 382 | field_layer_conj_or_grad[:, :, layer] = field_bp * 
field_layer_conj_or_grad[:, :, layer] 383 | if layer > 0: 384 | field_bp[:, :] *= af.conjg(trans_obj_af[:, :, layer]) 385 | field_bp = self._propagationInplace(field_bp, self.slice_separation[layer-1], adjoint = True) 386 | if flag_gpu_inout: 387 | return {'gradient':field_layer_conj_or_grad} 388 | else: 389 | return {'gradient':np.array(field_layer_conj_or_grad)} 390 | 391 | class MultiPhaseContrast(MultiTransmittance): 392 | """ MultiPhaseContrast, solves directly for the phase contrast {i.e. Transmittance = exp(sigma * PhaseContrast)} """ 393 | def __init__(self, phase_obj_3d, wavelength, sigma = 1, **kwargs): 394 | super().__init__(phase_obj_3d, wavelength, **kwargs) 395 | self.sigma = sigma 396 | 397 | def forward(self, contrast_obj, fx_illu, fy_illu): 398 | #compute illumination 399 | with contexttimer.Timer() as timer: 400 | field, fx_illu, fy_illu, fz_illu = self._genIllumination(fx_illu, fy_illu) 401 | field[:, :] *= np.exp(1.0j * 2.0 * np.pi * fz_illu * self.initial_z_position) 402 | field_layer_conj = af.constant(0.0, self.shape[0], self.shape[1], self.shape[2], dtype = af_complex_datatype) 403 | 404 | if (type(contrast_obj).__module__ == np.__name__): 405 | phasecontrast_obj_af = af.to_array(contrast_obj) 406 | flag_gpu_inout = False 407 | else: 408 | phasecontrast_obj_af = contrast_obj 409 | flag_gpu_inout = True 410 | 411 | #Binning 412 | obj_af = self._binObject(phasecontrast_obj_af) 413 | 414 | #Potentials to Transmittance 415 | obj_af = af.exp(1.0j * self.sigma * obj_af) 416 | 417 | for layer in range(self.shape[2]): 418 | field_layer_conj[:, :, layer] = af.conjg(field) 419 | field[:, :] *= obj_af[:, :, layer] 420 | if layer < self.shape[2]-1: 421 | field = self._propagationInplace(field, self.slice_separation[layer]) 422 | 423 | #propagate to volume center 424 | cache = (obj_af, field_layer_conj, flag_gpu_inout) 425 | 426 | if self.focus_at_center: 427 | #propagate to volume center 428 | field = self._propagationInplace(field, 
self.distance_end_to_center, adjoint = True) 429 | return {'forward_scattered_field': field, 'cache': cache} 430 | 431 | def adjoint(self, residual, cache): 432 | phasecontrast_obj_af, field_layer_conj_or_grad, flag_gpu_inout = cache 433 | trans_obj_af_conj = af.conjg(phasecontrast_obj_af) 434 | #back-propagte to volume center 435 | field_bp = residual 436 | #propagate to the last layer 437 | if self.focus_at_center: 438 | field_bp = self._propagationInplace(field_bp, self.distance_end_to_center) 439 | #multi-slice transmittance backward 440 | for layer in range(self.shape[2]-1, -1, -1): 441 | field_layer_conj_or_grad[:, :, layer] = field_bp * field_layer_conj_or_grad[:, :, layer] * (-1.0j) * self.sigma * trans_obj_af_conj[:,:,layer] 442 | if layer > 0: 443 | field_bp[:, :] *= trans_obj_af_conj[:, :, layer] 444 | field_bp = self._propagationInplace(field_bp, self.slice_separation[layer-1], adjoint = True) 445 | 446 | #Unbinning 447 | grad = self._binObject(field_layer_conj_or_grad, adjoint = True) 448 | 449 | phasecontrast_obj_af = None 450 | if flag_gpu_inout: 451 | return {'gradient':grad} 452 | else: 453 | return {'gradient':np.array(grad)} 454 | 455 | class MultiBorn(MultiTransmittance): 456 | """Multislice algorithm that computes scattered field with first Born approximation at every slice""" 457 | def __init__(self, phase_obj_3d, wavelength, **kwargs): 458 | super().__init__(phase_obj_3d, wavelength, **kwargs) 459 | self.distance_end_to_center = np.sum(self.slice_separation)/2. + phase_obj_3d.slice_separation[-1] 460 | self.distance_beginning_to_center = np.sum(self.slice_separation)/2. 
461 | self.slice_separation = np.append(phase_obj_3d.slice_separation, [phase_obj_3d.slice_separation[-1]]) 462 | self.slice_separation = np.asarray([sum(self.slice_separation[x:x+self.slice_binning_factor]) \ 463 | for x in range(0,len(self.slice_separation),self.slice_binning_factor)]) 464 | if self.slice_separation.shape[0] != self.shape[2]: 465 | print("Number of slices does not match with number of separations!") 466 | 467 | # generate green's function convolution kernel 468 | fxlin, fylin = self._genFrequencyGrid() 469 | kernel_mask = (self.RI/self.wavelength)**2 > 1.01*(fxlin*af.conjg(fxlin) + fylin*af.conjg(fylin)) 470 | self.green_kernel_2d = -0.25j * af.exp(2.0j*np.pi*self.fzlin*self.pixel_size_z) / np.pi / self.fzlin 471 | self.green_kernel_2d*= kernel_mask 472 | self.green_kernel_2d[af.isnan(self.green_kernel_2d)==1] = 0.0 473 | 474 | def forward(self, V_obj, fx_illu, fy_illu): 475 | if (type(V_obj).__module__ == np.__name__): 476 | V_obj_af = af.to_array(V_obj) 477 | flag_gpu_inout = False 478 | else: 479 | V_obj_af = V_obj 480 | flag_gpu_inout = True 481 | 482 | #Binning 483 | obj_af = self._meanObject(V_obj_af) 484 | 485 | #compute illumination 486 | field, fx_illu, fy_illu, fz_illu = self._genIllumination(fx_illu, fy_illu) 487 | field[:, :] *= np.exp(1.0j * 2.0 * np.pi * fz_illu * self.initial_z_position) 488 | field_layer_in = af.constant(0.0, self.shape[0], self.shape[1], self.shape[2], dtype = af_complex_datatype) 489 | for layer in range(self.shape[2]): 490 | field_layer_in[:, :, layer] = field 491 | field = self._propagationInplace(field, self.slice_separation[layer]) 492 | field_scat = af.ifft2(af.fft2(field_layer_in[:, :, layer] * obj_af[:, :, layer])*\ 493 | self.green_kernel_2d) * self.pixel_size_z 494 | field[:, :] += field_scat 495 | 496 | #store intermediate variables for adjoint operation 497 | cache = (obj_af, field_layer_in, flag_gpu_inout) 498 | 499 | if self.focus_at_center: 500 | #propagate to volume center 501 | field = 
self._propagationInplace(field, self.distance_end_to_center, adjoint = True) 502 | return {'forward_scattered_field': field, 'cache': cache} 503 | 504 | 505 | def adjoint(self, residual, cache): 506 | V_obj_af, field_layer_in_or_grad,\ 507 | flag_gpu_inout = cache 508 | 509 | #back-propagte to volume center 510 | field_bp = residual 511 | 512 | if self.focus_at_center: 513 | #propagate to the last layer 514 | field_bp = self._propagationInplace(field_bp, self.distance_end_to_center) 515 | 516 | for layer in range(self.shape[2]-1, -1, -1): 517 | field_bp_1 = af.ifft2(af.fft2(field_bp) * af.conjg(self.green_kernel_2d)) * self.pixel_size_z 518 | 519 | if layer > 0: 520 | field_bp = self._propagationInplace(field_bp, self.slice_separation[layer], adjoint = True) 521 | field_bp[:, :] += field_bp_1 * af.conjg(V_obj_af[:, :, layer]) 522 | 523 | field_layer_in_or_grad[:, :, layer] = field_bp_1 * af.conjg(field_layer_in_or_grad[:, :, layer]) 524 | 525 | #Unbinning 526 | grad = self._meanObject(field_layer_in_or_grad, adjoint = True) 527 | 528 | if flag_gpu_inout: 529 | return {'gradient':grad} 530 | else: 531 | return {'gradient':np.array(grad)} -------------------------------------------------------------------------------- /opticaltomography/opticsutil.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implement utilities using GPU 3 | 4 | Michael Chen mchen0405@berkeley.edu 5 | David Ren david.ren@berkeley.edu 6 | November 22, 2017 7 | """ 8 | 9 | import numpy as np 10 | import arrayfire as af 11 | import skimage.io as skio 12 | import scipy.io as sio 13 | import contexttimer 14 | import tkinter 15 | import matplotlib.pyplot as plt 16 | from math import factorial 17 | from scipy.ndimage.filters import uniform_filter 18 | from tkinter.filedialog import askdirectory 19 | from matplotlib.widgets import Slider 20 | from os import listdir, path 21 | import sys 22 | from opticaltomography import settings 23 | 24 | np_float_datatype 
= settings.np_float_datatype 25 | af_complex_datatype = settings.af_complex_datatype 26 | MAX_DIM = 512*512*512 if settings.bit == 32 else 512*512*256 27 | 28 | 29 | 30 | def show3DStack(image_3d, axis = 2, cmap = "gray", clim = (0, 1), extent = (0, 1, 0, 1)): 31 | if axis == 0: 32 | image = lambda index: image_3d[index, :, :] 33 | elif axis == 1: 34 | image = lambda index: image_3d[:, index, :] 35 | else: 36 | image = lambda index: image_3d[:, :, index] 37 | 38 | current_idx= 0 39 | _, ax = plt.subplots(1, 1, figsize=(6.5, 5)) 40 | plt.subplots_adjust(left=0.15, bottom=0.15) 41 | fig = ax.imshow(image(current_idx), cmap = cmap, clim = clim, extent = extent) 42 | ax.set_title("layer: " + str(current_idx)) 43 | plt.colorbar(fig, ax=ax) 44 | plt.axis('off') 45 | ax_slider = plt.axes([0.15, 0.1, 0.65, 0.03]) 46 | slider_obj = Slider(ax_slider, "layer", 0, image_3d.shape[axis]-1, valinit=current_idx, valfmt='%d') 47 | def update_image(index): 48 | global current_idx 49 | index = int(index) 50 | current_idx = index 51 | ax.set_title("layer: " + str(index)) 52 | fig.set_data(image(index)) 53 | def arrow_key(event): 54 | global current_idx 55 | if event.key == "left": 56 | if current_idx-1 >=0: 57 | current_idx -= 1 58 | elif event.key == "right": 59 | if current_idx+1 < image_3d.shape[axis]: 60 | current_idx += 1 61 | slider_obj.set_val(current_idx) 62 | slider_obj.on_changed(update_image) 63 | plt.gcf().canvas.mpl_connect("key_release_event", arrow_key) 64 | plt.show() 65 | return slider_obj 66 | 67 | def compare3DStack(stack_1, stack_2, axis = 2, cmap = "gray", clim = (0, 1), extent = (0, 1, 0, 1) , colorbar = True): 68 | assert stack_1.shape == stack_2.shape, "shape of two input stacks should be the same!" 
69 | 70 | if axis == 0: 71 | image_1 = lambda index: stack_1[index, :, :] 72 | image_2 = lambda index: stack_2[index, :, :] 73 | elif axis == 1: 74 | image_1 = lambda index: stack_1[:, index, :] 75 | image_2 = lambda index: stack_2[:, index, :] 76 | else: 77 | image_1 = lambda index: stack_1[:, :, index] 78 | image_2 = lambda index: stack_2[:, :, index] 79 | 80 | current_idx = 0 81 | _, ax = plt.subplots(1, 2, figsize=(9, 5), sharex = 'all', sharey = 'all') 82 | plt.subplots_adjust(left=0.15, bottom=0.15) 83 | fig_1 = ax[0].imshow(image_1(current_idx), cmap = cmap, clim = clim, extent = extent) 84 | ax[0].axis("off") 85 | ax[0].set_title("stack 1, layer: " + str(current_idx)) 86 | plt.colorbar(fig_1, ax = ax[0]) 87 | fig_2 = ax[1].imshow(image_2(current_idx), cmap = cmap, clim = clim, extent = extent) 88 | ax[1].axis("off") 89 | ax[1].set_title("stack 2, layer: " + str(current_idx)) 90 | plt.colorbar(fig_2, ax = ax[1]) 91 | ax_slider = plt.axes([0.10, 0.05, 0.65, 0.03]) 92 | slider_obj = Slider(ax_slider, 'layer', 0, stack_1.shape[axis]-1, valinit=current_idx, valfmt='%d') 93 | def update_image(index): 94 | global current_idx 95 | index = int(index) 96 | current_idx = index 97 | ax[0].set_title("stack 1, layer: " + str(index)) 98 | fig_1.set_data(image_1(index)) 99 | ax[1].set_title("stack 2, layer: " + str(index)) 100 | fig_2.set_data(image_2(index)) 101 | def arrow_key(event): 102 | global current_idx 103 | if event.key == "left": 104 | if current_idx-1 >=0: 105 | current_idx -= 1 106 | elif event.key == "right": 107 | if current_idx+1 < stack_1.shape[axis]: 108 | current_idx += 1 109 | slider_obj.set_val(current_idx) 110 | slider_obj.on_changed(update_image) 111 | plt.gcf().canvas.mpl_connect("key_release_event", arrow_key) 112 | plt.show() 113 | return slider_obj 114 | 115 | def calculateNumericalGradient(func, x, point, delta = 1e-4): 116 | function_value_0 = func(x) 117 | x_shift = x.copy() 118 | x_shift[point] += delta 119 | function_value_re = func(x_shift) 
120 | grad_re = (function_value_re - function_value_0)/delta 121 | 122 | x_shift = x.copy() 123 | x_shift[point] += 1.0j* delta 124 | function_value_im = func(x_shift) 125 | grad_im = (function_value_im - function_value_0)/delta 126 | gradient = 0.5*(grad_re + 1.0j*grad_im) 127 | 128 | return gradient 129 | 130 | def cart2Pol(x, y): 131 | rho = (x * af.conjg(x) + y * af.conjg(y))**0.5 132 | theta = af.atan2(af.real(y), af.real(x)).as_type(af_complex_datatype) 133 | return rho, theta 134 | 135 | def genZernikeAberration(shape, pixel_size, NA, wavelength, z_coeff = [1], z_index_list = [0]): 136 | assert len(z_coeff) == len(z_index_list), "number of coefficients does not match with number of zernike indices!" 137 | 138 | pupil = genPupil(shape, pixel_size, NA, wavelength) 139 | fxlin = genGrid(shape[1], 1/pixel_size/shape[1], flag_shift = True) 140 | fylin = genGrid(shape[0], 1/pixel_size/shape[0], flag_shift = True) 141 | fxlin = af.tile(fxlin.T, shape[0], 1) 142 | fylin = af.tile(fylin, 1, shape[1]) 143 | rho, theta = cart2Pol(fxlin, fylin) 144 | rho[:, :] /= NA/wavelength 145 | 146 | def zernikePolynomial(z_index): 147 | n = int(np.ceil((-3.0 + np.sqrt(9+8*z_index))/2.0)) 148 | m = 2*z_index - n*(n+2) 149 | normalization_coeff = np.sqrt(2 * (n+1)) if abs(m) > 0 else np.sqrt(n+1) 150 | azimuthal_function = af.sin(abs(m)*theta) if m < 0 else af.cos(abs(m)*theta) 151 | zernike_poly = af.constant(0.0, shape[0], shape[1], dtype = af_complex_datatype) 152 | for k in range((n-abs(m))//2+1): 153 | zernike_poly[:, :] += ((-1)**k * factorial(n-k))/ \ 154 | (factorial(k)*factorial(0.5*(n+m)-k)*factorial(0.5*(n-m)-k))\ 155 | * rho**(n-2*k) 156 | 157 | return normalization_coeff * zernike_poly * azimuthal_function 158 | 159 | for z_coeff_index, z_index in enumerate(z_index_list): 160 | zernike_poly = zernikePolynomial(z_index) 161 | 162 | if z_coeff_index == 0: 163 | zernike_aberration = z_coeff.ravel()[z_coeff_index] * zernike_poly 164 | else: 165 | zernike_aberration[:, :] += 
z_coeff.ravel()[z_coeff_index] * zernike_poly 166 | 167 | return zernike_aberration * pupil 168 | 169 | def genPupil(shape, pixel_size, NA, wavelength): 170 | assert len(shape) == 2, "pupil should be two dimensional!" 171 | fxlin = genGrid(shape[1], 1/pixel_size/shape[1], flag_shift = True) 172 | fylin = genGrid(shape[0], 1/pixel_size/shape[0], flag_shift = True) 173 | fxlin = af.tile(fxlin.T, shape[0], 1) 174 | fylin = af.tile(fylin, 1, shape[1]) 175 | 176 | pupil_radius = NA/wavelength 177 | pupil = (fxlin**2 + fylin**2 <= pupil_radius**2).as_type(af_complex_datatype) 178 | return pupil 179 | 180 | def propKernel(shape, pixel_size, wavelength, prop_distance, NA = None, RI = 1.0, band_limited=True): 181 | assert len(shape) == 2, "pupil should be two dimensional!" 182 | fxlin = genGrid(shape[1], 1/pixel_size/shape[1], flag_shift = True) 183 | fylin = genGrid(shape[0], 1/pixel_size/shape[0], flag_shift = True) 184 | fxlin = af.tile(fxlin.T, shape[0], 1) 185 | fylin = af.tile(fylin, 1, shape[1]) 186 | 187 | if band_limited: 188 | assert NA is not None, "need to provide numerical aperture of the system!" 
189 | Pcrop = genPupil(shape, pixel_size, NA, wavelength) 190 | else: 191 | Pcrop = 1.0 192 | 193 | prop_kernel = Pcrop * af.exp(1.0j * 2.0 * np.pi * abs(prop_distance) * Pcrop *\ 194 | ((RI/wavelength)**2 - fxlin**2 - fylin**2)**0.5) 195 | prop_kernel = af.conjg(prop_kernel) if prop_distance < 0 else prop_kernel 196 | return prop_kernel 197 | 198 | def genGrid(size, dx, flag_shift = False): 199 | """ 200 | This function generates 1D Fourier grid, and is centered at the middle of the array 201 | Inputs: 202 | size - length of the array 203 | dx - pixel size 204 | Optional parameters: 205 | flag_shift - flag indicating whether the final array is circularly shifted 206 | should be false when computing real space coordinates 207 | should be true when computing Fourier coordinates 208 | Outputs: 209 | xlin - 1D Fourier grid 210 | 211 | """ 212 | xlin = (af.range(size) - size//2) * dx 213 | if flag_shift: 214 | xlin = af.shift(xlin, -1 * size//2) 215 | return xlin.as_type(af_complex_datatype) 216 | 217 | class EmptyClass: 218 | def __init__(self): 219 | pass 220 | def forward(self, field, zernike_coeffs=None, x_shift=None, y_shift=None, **kwargs): 221 | return field 222 | def adjoint(self, field, zernike_coeffs=None, x_shift=None, y_shift=None, **kwargs): 223 | return field 224 | 225 | class ImageCrop: 226 | """ 227 | Class for image cropping 228 | """ 229 | def __init__(self, shape, pad = True, pad_size = None, **kwargs): 230 | """ 231 | shape: shape of object (y,x,z) 232 | pixel_size: pixel size of the system 233 | pad: boolean variable to pad the reconstruction 234 | pad_size: if pad is true, default pad_size is shape//2. 
Takes a tuple, pad size in dimensions (y, x) 235 | """ 236 | self.shape = shape 237 | self.pad = pad 238 | if self.pad: 239 | self.pad_size = pad_size 240 | if self.pad_size == None: 241 | self.pad_size = (self.shape[0]//4, self.shape[1]//4) 242 | self.row_crop = slice(self.pad_size[0], self.shape[0] - self.pad_size[0]) 243 | self.col_crop = slice(self.pad_size[1], self.shape[1] - self.pad_size[1]) 244 | else: 245 | self.row_crop = slice(0, self.shape[0]) 246 | self.col_crop = slice(0, self.shape[1]) 247 | 248 | def forward(self, field): 249 | return field[self.row_crop, self.col_crop] 250 | 251 | def adjoint(self, residual): 252 | if len(residual.shape) > 2: 253 | field_pad = af.constant(0.0, self.shape[0], self.shape[1], residual.shape[2], dtype = af_complex_datatype) 254 | for z_idx in range(field_pad.shape[2]): 255 | field_pad[self.row_crop, self.col_crop, z_idx] = residual[:, :, z_idx] + 0.0j 256 | else: 257 | field_pad = af.constant(0.0, self.shape[0], self.shape[1], dtype = af_complex_datatype) 258 | field_pad[self.row_crop, self.col_crop] = residual + 0.0j 259 | return field_pad 260 | 261 | -------------------------------------------------------------------------------- /opticaltomography/regularizers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regularizer class for that also supports GPU code 3 | 4 | Michael Chen mchen0405@berkeley.edu 5 | David Ren david.ren@berkeley.edu 6 | March 04, 2018 7 | """ 8 | 9 | import arrayfire as af 10 | import numpy as np 11 | from opticaltomography import settings 12 | 13 | np_complex_datatype = settings.np_complex_datatype 14 | np_float_datatype = settings.np_float_datatype 15 | af_float_datatype = settings.af_float_datatype 16 | af_complex_datatype = settings.af_complex_datatype 17 | 18 | class Regularizer: 19 | """ 20 | Highest-level Regularizer class that is responsible for parsing user arguments to create proximal operators 21 | All proximal operators operate on 
complex variables (real & imaginary part separately) 22 | Pure Amplitude: 23 | pure_amplitude: boolean, whether or not to enforce object to be purely amplitude 24 | Pure phase: 25 | pure_phase: boolean, whether or not to enforce object to be purely phase object 26 | Pure Real: 27 | pure_real: boolean, whether or not to enforce object to be purely real 28 | 29 | Pure imaginary: 30 | pure_imag: boolean, whether or not to enforce object to be purely imaginary 31 | 32 | Positivity: 33 | positivity_real(positivity_imag): boolean, whether or not to enforce positivity for real(imaginary) part 34 | 35 | Negativity: 36 | negativity_real(negativity_imag): boolean, whether or not to enforce negativity for real(imaginary) part 37 | 38 | LASSO (L1 regularizer): 39 | lasso: boolean, whether or not to use LASSO proximal operator 40 | lasso_parameter: threshold for LASSO 41 | 42 | Total variation (3D only): 43 | total_variation: boolean, whether or not to use total variation regularization 44 | total_variation_gpu: boolean, whether or not to use GPU implementation 45 | total_variation_parameter: scalar, regularization parameter (lambda) 46 | total_variation_maxitr: integer, number of each iteration for total variation 47 | """ 48 | def __init__(self, configs = None, verbose = True, **kwargs): 49 | #Given all parameters, construct all proximal operators 50 | self.prox_list = [] 51 | reg_params = kwargs 52 | if configs != None: 53 | reg_params = self._parseConfigs(configs) 54 | 55 | #Purely real 56 | if reg_params.get("pure_real", False): 57 | self.prox_list.append(PureReal()) 58 | 59 | #Purely imaginary 60 | if reg_params.get("pure_imag", False): 61 | self.prox_list.append(Pureimag()) 62 | 63 | #Purely amplitude object 64 | if reg_params.get("pure_amplitude", False): 65 | self.prox_list.append(PureAmplitude()) 66 | 67 | #Purely phase object 68 | if reg_params.get(("pure_phase"), False): 69 | self.prox_list.append(PurePhase()) 70 | 71 | #Total Variation 72 | if 
reg_params.get("total_variation", False): 73 | if reg_params.get("total_variation_gpu", False): 74 | self.prox_list.append(TotalVariationGPU(**reg_params)) 75 | else: 76 | self.prox_list.append(TotalVariationCPU(**reg_params)) 77 | 78 | #L1 Regularizer (LASSO) 79 | elif reg_params.get("lasso", False): 80 | self.prox_list.append(Lasso(reg_params.get("lasso_parameter", 1.0))) 81 | 82 | #Others 83 | else: 84 | #Positivity 85 | positivity_real = reg_params.get("positivity_real", False) 86 | positivity_imag = reg_params.get("positivity_imag", False) 87 | if positivity_real or positivity_imag: 88 | self.prox_list.append(Positivity(positivity_real, positivity_imag)) 89 | 90 | #Negativity 91 | negativity_real = reg_params.get("negativity_real", False) 92 | negativity_imag = reg_params.get("negativity_imag", False) 93 | if negativity_real or negativity_imag: 94 | self.prox_list.append(Negativity(negativity_real, negativity_imag)) 95 | 96 | if verbose: 97 | for prox_op in self.prox_list: 98 | print("Regularizer -", prox_op.proximal_name) 99 | 100 | def _parseConfigs(self, configs): 101 | params = {} 102 | params["pure_real"] = configs.pure_real 103 | params["pure_imag"] = configs.pure_imag 104 | #Total variation 105 | params["total_variation"] = configs.total_variation 106 | params["total_variation_gpu"] = configs.total_variation_gpu 107 | params["total_variation_maxitr"] = configs.max_iter_tv 108 | params["total_variation_order"] = configs.order_tv 109 | params["total_variation_parameter"] = configs.reg_tv 110 | #LASSO 111 | params["lasso"] = configs.lasso 112 | params["lasso_parameter"] = configs.reg_lasso 113 | #Pure amplitude/Pure phase 114 | params["pure_amplitude"] = configs.pure_amplitude 115 | params["pure_phase"] = configs.pure_phase 116 | #Positivity/Negativity 117 | if configs.positivity_real[0]: 118 | if configs.positivity_real[1] == "larger": 119 | params["positivity_real"] = True 120 | elif configs.positivity_real[1] == "smaller": 121 | 
params["negativity_real"] = True 122 | else: 123 | raise ValueError("positivity/negativity parameter not recognized") 124 | 125 | if configs.positivity_imag[0]: 126 | if configs.positivity_imag[1] == "larger": 127 | params["positivity_imag"] = True 128 | elif configs.positivity_imag[1] == "smaller": 129 | params["negativity_imag"] = True 130 | else: 131 | raise ValueError("positivity/negativity parameter not recognized") 132 | return params 133 | 134 | def computeCost(self, x): 135 | cost = 0.0 136 | for prox_op in self.prox_list: 137 | cost_temp = prox_op.computeCost(x) 138 | if cost_temp != None: 139 | cost += cost_temp 140 | return cost 141 | 142 | def applyRegularizer(self, x): 143 | for prox_op in self.prox_list: 144 | x = prox_op.computeProx(x) 145 | return x 146 | 147 | class ProximalOperator(): 148 | def __init__(self, proximal_name): 149 | self.proximal_name = proximal_name 150 | def computeCost(self): 151 | pass 152 | def computeProx(self): 153 | pass 154 | def setParameter(self): 155 | pass 156 | def _boundRealValue(self, x, value = 0, flag_project = True): 157 | """If flag is true, only values that are greater than 'value' are preserved""" 158 | if flag_project: 159 | x[x < value] = 0 160 | return x 161 | 162 | class TotalVariationGPU(ProximalOperator): 163 | def __init__(self, **kwargs): 164 | proximal_name = "Total Variation" 165 | parameter = kwargs.get("total_variation_parameter", 1.0) 166 | maxitr = kwargs.get("total_variation_maxitr", 15) 167 | self.order = kwargs.get("total_variation_order", 1) 168 | self.pure_real = kwargs.get("pure_real", False) 169 | self.pure_imag = kwargs.get("pure_imag", False) 170 | self.pure_amplitude = kwargs.get("pure_amplitude", False) 171 | self.pure_phase = kwargs.get("pure_phase", False) 172 | 173 | #real part 174 | if kwargs.get("positivity_real", False): 175 | self.realProjector = lambda x: self._boundRealValue(x, 0, True) 176 | proximal_name = "%s+%s" % (proximal_name, "positivity_real") 177 | elif 
kwargs.get("negativity_real", False): 178 | self.realProjector = lambda x: -1.0 * self._boundRealValue(-1.0 * x, 0, True) 179 | proximal_name = "%s+%s" % (proximal_name, "negativity_real") 180 | else: 181 | self.realProjector = lambda x: x 182 | 183 | #imaginary part 184 | if kwargs.get("positivity_imag", False): 185 | self.imagProjector = lambda x: self._boundRealValue(x, 0, True) 186 | proximal_name = "%s+%s" % (proximal_name, "positivity_imag") 187 | elif kwargs.get("negativity_imag", False): 188 | self.imagProjector = lambda x: -1.0 * self._boundRealValue(-1.0 * x, 0, True) 189 | proximal_name = "%s+%s" % (proximal_name, "negativity_imag") 190 | else: 191 | self.imagProjector = lambda x: x 192 | self.setParameter(parameter, maxitr) 193 | super().__init__(proximal_name) 194 | 195 | 196 | 197 | def setParameter(self, parameter, maxitr): 198 | self.parameter = parameter 199 | self.maxitr = maxitr 200 | 201 | def computeCost(self, x): 202 | return None 203 | 204 | def computeProx(self, x): 205 | if self.pure_real: 206 | x = self._computeProxReal(af.real(x), self.realProjector) + 1.0j * 0.0 207 | elif self.pure_imag: 208 | x = 1.0j *self._computeProxReal(af.imag(x), self.imagProjector) 209 | elif self.pure_amplitude: 210 | x = self._computeProxReal(af.abs(x), self.realProjector) * af.exp(1.0j * 0 * x) 211 | elif self.pure_phase: 212 | x = af.exp(1.0j * self._computeProxReal(af.arg(x), self.realProjector)) 213 | else: 214 | # x = self._computeProxReal(af.real(x), self.realProjector) + \ 215 | # 1.0j * self._computeProxReal(af.imag(x), self.imagProjector) 216 | temp = self._computeProxReal(af.real(x), self.realProjector) + 1.0j * 0.0 217 | self.setParameter(self.parameter / 1.0, self.maxitr) 218 | x = temp + 1.0j * self._computeProxReal(af.imag(x), self.imagProjector) 219 | self.setParameter(self.parameter * 1.0, self.maxitr) 220 | return x 221 | 222 | def _indexLastAxis(self, x, index = 0): 223 | if len(x.shape) == 3: 224 | return x[:,:,index] 225 | elif len(x.shape) 
== 4: 226 | return x[:,:,:,index] 227 | return 228 | 229 | def _computeTVNorm(self, x): 230 | x_norm = x**2 231 | x_norm = af.sum(x_norm, dim = 3 if len(x.shape) == 4 else 2)**0.5 232 | x_norm[x_norm<1.0] = 1.0 233 | return x_norm 234 | 235 | def _filterD(self, x, axis): 236 | assert axis<3, "This function only supports matrix up to 3 dimension!" 237 | if len(x.shape) == 2: 238 | if self.order == 1: 239 | if axis == 0: 240 | Dx = x - af.shift(x, 1, 0) 241 | elif axis == 1: 242 | Dx = x - af.shift(x, 0, 1) 243 | elif self.order == 2: 244 | if axis == 0: 245 | Dx = x - 2*af.shift(x, 1, 0) + af.shift(x, 2, 0) 246 | elif axis == 1: 247 | Dx = x - 2*af.shift(x, 0, 1) + af.shift(x, 0, 2) 248 | elif self.order == 3: 249 | if axis == 0: 250 | Dx = x - 3*af.shift(x, 1, 0) + 3*af.shift(x, 2, 0) - af.shift(x, 3, 0) 251 | elif axis == 1: 252 | Dx = x - 3*af.shift(x, 0, 1) + 3*af.shift(x, 0, 2) - af.shift(x, 0, 3) 253 | else: 254 | raise NotImplementedError("filter orders larger than 3 are not implemented!") 255 | elif len(x.shape) == 3: 256 | if self.order == 1: 257 | if axis == 0: 258 | Dx = x - af.shift(x, 1, 0, 0) 259 | elif axis == 1: 260 | Dx = x - af.shift(x, 0, 1, 0) 261 | else: 262 | Dx = x - af.shift(x, 0, 0, 1) 263 | elif self.order == 2: 264 | if axis == 0: 265 | Dx = x - 2*af.shift(x, 1, 0, 0) + af.shift(x, 2, 0, 0) 266 | elif axis == 1: 267 | Dx = x - 2*af.shift(x, 0, 1, 0) + af.shift(x, 0, 2, 0) 268 | else: 269 | Dx = x - 2*af.shift(x, 0, 0, 1) + af.shift(x, 0, 0, 2) 270 | elif self.order == 3: 271 | if axis == 0: 272 | Dx = x - 3*af.shift(x, 1, 0, 0) + 3*af.shift(x, 2, 0, 0) - af.shift(x, 3, 0, 0) 273 | elif axis == 1: 274 | Dx = x - 3*af.shift(x, 0, 1, 0) + 3*af.shift(x, 0, 2, 0) - af.shift(x, 0, 3, 0) 275 | else: 276 | Dx = x - 3*af.shift(x, 0, 0, 1) + 3*af.shift(x, 0, 0, 2) - af.shift(x, 0, 0, 3) 277 | else: 278 | raise NotImplementedError("filter orders larger than 3 are not implemented!") 279 | return Dx 280 | 281 | def _filterDT(self, x): 282 | if 
self.order == 1: 283 | if len(x.shape) == 3: 284 | DTx = self._indexLastAxis(x, 0) - af.shift(self._indexLastAxis(x, 0), -1, 0) + \ 285 | self._indexLastAxis(x, 1) - af.shift(self._indexLastAxis(x, 1), 0, -1) 286 | elif len(x.shape) == 4: 287 | DTx = self._indexLastAxis(x, 0) - af.shift(self._indexLastAxis(x, 0), -1, 0, 0) + \ 288 | self._indexLastAxis(x, 1) - af.shift(self._indexLastAxis(x, 1), 0, -1, 0) + \ 289 | self._indexLastAxis(x, 2) - af.shift(self._indexLastAxis(x, 2), 0, 0, -1) 290 | elif self.order == 2: 291 | if len(x.shape) == 3: 292 | DTx = self._indexLastAxis(x, 0) - 2*af.shift(self._indexLastAxis(x, 0), -1, 0) + af.shift(self._indexLastAxis(x, 0), -2, 0) + \ 293 | self._indexLastAxis(x, 1) - 2*af.shift(self._indexLastAxis(x, 1), 0, -1) + af.shift(self._indexLastAxis(x, 1), 0, -2) 294 | elif len(x.shape) == 4: 295 | DTx = self._indexLastAxis(x, 0) - 2*af.shift(self._indexLastAxis(x, 0), -1, 0, 0) + af.shift(self._indexLastAxis(x, 0), -2, 0, 0) + \ 296 | self._indexLastAxis(x, 1) - 2*af.shift(self._indexLastAxis(x, 1), 0, -1, 0) + af.shift(self._indexLastAxis(x, 1), 0, -2, 0) + \ 297 | self._indexLastAxis(x, 2) - 2*af.shift(self._indexLastAxis(x, 2), 0, 0, -1) + af.shift(self._indexLastAxis(x, 2), 0, 0, -2) 298 | elif self.order == 3: 299 | if len(x.shape) == 3: 300 | DTx = self._indexLastAxis(x, 0) - 3*af.shift(self._indexLastAxis(x, 0), -1, 0) + 3*af.shift(self._indexLastAxis(x, 0), -2, 0) - af.shift(self._indexLastAxis(x, 0), -3, 0) + \ 301 | self._indexLastAxis(x, 1) - 3*af.shift(self._indexLastAxis(x, 1), 0, -1) + 3*af.shift(self._indexLastAxis(x, 1), 0, -2) - af.shift(self._indexLastAxis(x, 1), 0, -3) 302 | elif len(x.shape) == 4: 303 | DTx = self._indexLastAxis(x, 0) - 3*af.shift(self._indexLastAxis(x, 0), -1, 0, 0) + 3*af.shift(self._indexLastAxis(x, 0), -2, 0, 0) - af.shift(self._indexLastAxis(x, 0), -3, 0, 0) + \ 304 | self._indexLastAxis(x, 1) - 3*af.shift(self._indexLastAxis(x, 1), 0, -1, 0) + 3*af.shift(self._indexLastAxis(x, 1), 0, -2, 
0) - af.shift(self._indexLastAxis(x, 1), 0, -3, 0) + \ 305 | self._indexLastAxis(x, 2) - 3*af.shift(self._indexLastAxis(x, 2), 0, 0, -1) + 3*af.shift(self._indexLastAxis(x, 2), 0, 0, -2) - af.shift(self._indexLastAxis(x, 2), 0, 0, -3) 306 | else: 307 | raise NotImplementedError("filter orders larger than 3 are not implemented!") 308 | return DTx 309 | 310 | def _computeProxReal(self, x, projector): 311 | t_k = 1.0 312 | 313 | def _gradUpdate(): 314 | grad_u_hat = x - self.parameter * self._filterDT(u_k1) 315 | return grad_u_hat 316 | 317 | if len(x.shape) == 2: 318 | u_k = af.constant(0.0, x.shape[0], x.shape[1], 2, dtype = af_float_datatype) 319 | u_k1 = af.constant(0.0, x.shape[0], x.shape[1], 2, dtype = af_float_datatype) 320 | # grad_u_hat = af.constant(0.0, x.shape[0], x.shape[1], dtype = af_float_datatype) 321 | elif len(x.shape) == 3: 322 | u_k = af.constant(0.0, x.shape[0], x.shape[1], x.shape[2], 3, dtype = af_float_datatype) 323 | u_k1 = af.constant(0.0, x.shape[0], x.shape[1], x.shape[2], 3, dtype = af_float_datatype) 324 | # grad_u_hat = af.constant(0.0, x.shape[0], x.shape[1], x.shape[2], dtype = af_float_datatype) 325 | 326 | for iteration in range(self.maxitr): 327 | if iteration > 0: 328 | grad_u_hat = _gradUpdate() 329 | else: 330 | grad_u_hat = x.copy() 331 | 332 | grad_u_hat = projector(grad_u_hat) 333 | if len(x.shape) == 2: #2D case 334 | u_k1[:,:, 0] = self._indexLastAxis(u_k1, 0) + (1.0/(8.0)**self.order/self.parameter) * self._filterD(grad_u_hat, axis=0) 335 | u_k1[:,:, 1] = self._indexLastAxis(u_k1, 1) + (1.0/(8.0)**self.order/self.parameter) * self._filterD(grad_u_hat, axis=1) 336 | elif len(x.shape) == 3: #3D case 337 | u_k1[:,:,:, 0] = self._indexLastAxis(u_k1, 0) + (1.0/(12.0)**self.order/self.parameter) * self._filterD(grad_u_hat, axis=0) 338 | u_k1[:,:,:, 1] = self._indexLastAxis(u_k1, 1) + (1.0/(12.0)**self.order/self.parameter) * self._filterD(grad_u_hat, axis=1) 339 | u_k1[:,:,:, 2] = self._indexLastAxis(u_k1, 2) + 
(1.0/(12.0)**self.order/self.parameter) * self._filterD(grad_u_hat, axis=2) 340 | grad_u_hat = None 341 | u_k1_norm = self._computeTVNorm(u_k1) 342 | if len(x.shape) == 2: #2D case 343 | u_k1[:,:, 0] /= u_k1_norm 344 | u_k1[:,:, 1] /= u_k1_norm 345 | if len(x.shape) == 3: #3D case 346 | u_k1[:,:,:, 0]/= u_k1_norm 347 | u_k1[:,:,:, 1]/= u_k1_norm 348 | u_k1[:,:,:, 2]/= u_k1_norm 349 | u_k1_norm = None 350 | t_k1 = 0.5 * (1.0 + (1.0 + 4.0*t_k**2)**0.5) 351 | beta = (t_k - 1.0)/t_k1 352 | if len(x.shape) == 2: #2D case 353 | temp = u_k[:,:,0].copy() 354 | if iteration < self.maxitr - 1: 355 | u_k[:,:,0] = u_k1[:,:,0] 356 | u_k1[:,:,0] = (1.0 + beta)*u_k1[:,:,0] - beta*temp #now u_hat 357 | temp = u_k[:,:,1].copy() 358 | if iteration < self.maxitr - 1: 359 | u_k[:,:,1] = u_k1[:,:,1] 360 | u_k1[:,:,1] = (1.0 + beta)*u_k1[:,:,1] - beta*temp 361 | elif len(x.shape) == 3: #3D case 362 | temp = u_k[:,:,:,0].copy() 363 | if iteration < self.maxitr - 1: 364 | u_k[:,:,:,0] = u_k1[:,:,:,0] 365 | u_k1[:,:,:,0] = (1.0 + beta)*u_k1[:,:,:,0] - beta*temp #now u_hat 366 | temp = u_k[:,:,:,1].copy() 367 | if iteration < self.maxitr - 1: 368 | u_k[:,:,:,1] = u_k1[:,:,:,1] 369 | u_k1[:,:,:,1] = (1.0 + beta)*u_k1[:,:,:,1] - beta*temp 370 | temp = u_k[:,:,:,2].copy() 371 | if iteration < self.maxitr - 1: 372 | u_k[:,:,:,2] = u_k1[:,:,:,2] 373 | u_k1[:,:,:,2] = (1.0 + beta)*u_k1[:,:,:,2] - beta*temp 374 | temp = None 375 | 376 | grad_u_hat = projector(_gradUpdate()) 377 | u_k = None 378 | u_k1 = None 379 | return grad_u_hat 380 | 381 | class TotalVariationCPU(TotalVariationGPU): 382 | def _computeTVNorm(self, x): 383 | u_k1_norm = af.to_array(x) 384 | u_k1_norm[:] *= u_k1_norm 385 | u_k1_norm = af.sum(u_k1_norm, dim = 3 if len(x.shape) == 4 else 2)**0.5 386 | u_k1_norm[u_k1_norm<1.0] = 1.0 387 | return np.array(u_k1_norm) 388 | 389 | def computeProx(self, x): 390 | if self.pure_real: 391 | x = self._computeProxReal(np.real(x), self.realProjector) + 1.0j * 0.0 392 | elif self.pure_imag: 393 
| x = 1.0j *self._computeProxReal(np.imag(x), self.imagProjector) 394 | else: 395 | x = self._computeProxReal(np.real(x), self.realProjector) \ 396 | + 1.0j * self._computeProxReal(np.imag(x), self.imagProjector) 397 | return af.to_array(x) 398 | 399 | def _computeProxReal(self, x, projector): 400 | t_k = 1.0 401 | u_k = np.zeros(x.shape + (3 if len(x.shape) == 3 else 2,), dtype = np_float_datatype); 402 | u_k1 = u_k.copy() 403 | u_hat = u_k.copy() 404 | 405 | def _gradUpdate(): 406 | u_hat_af = af.to_array(u_hat) 407 | if len(x.shape) == 2: 408 | DTu_hat = self._indexLastAxis(u_hat_af, 0) - af.shift(self._indexLastAxis(u_hat_af, 0), -1, 0) + \ 409 | self._indexLastAxis(u_hat_af, 1) - af.shift(self._indexLastAxis(u_hat_af, 1), 0, -1) 410 | elif len(x.shape) == 3: 411 | DTu_hat = self._indexLastAxis(u_hat_af, 0) - af.shift(self._indexLastAxis(u_hat_af, 0), -1, 0, 0) + \ 412 | self._indexLastAxis(u_hat_af, 1) - af.shift(self._indexLastAxis(u_hat_af, 1), 0, -1, 0) + \ 413 | self._indexLastAxis(u_hat_af, 2) - af.shift(self._indexLastAxis(u_hat_af, 2), 0, 0, -1) 414 | grad_u_hat = x - np.array(self.parameter * DTu_hat) 415 | return grad_u_hat 416 | 417 | for iteration in range(self.maxitr): 418 | if iteration > 0: 419 | grad_u_hat = _gradUpdate() 420 | else: 421 | grad_u_hat = x.copy() 422 | 423 | grad_u_hat = projector(grad_u_hat) 424 | u_k1[..., 0] = u_hat[..., 0] + (1.0/12.0/self.parameter) * (grad_u_hat-np.roll(grad_u_hat, 1, axis = 0)) 425 | u_k1[..., 1] = u_hat[..., 1] + (1.0/12.0/self.parameter) * (grad_u_hat-np.roll(grad_u_hat, 1, axis = 1)) 426 | if len(x.shape) == 3: 427 | u_k1[..., 2] = u_hat[..., 2] + (1.0/12.0/self.parameter) * (grad_u_hat-np.roll(grad_u_hat, 1, axis = 2)) 428 | u_k1_norm = self._computeTVNorm(u_k1) 429 | u_k1[:] /= u_k1_norm[..., np.newaxis] 430 | t_k1 = 0.5 * (1.0 + (1.0 + 4.0*t_k**2)**0.5) 431 | beta = (t_k - 1.0)/t_k1 432 | u_hat = (1.0 + beta)*u_k1 - beta*u_k 433 | if iteration < self.maxitr - 1: 434 | u_k = u_k1.copy() 435 | return 
projector(_gradUpdate()) 436 | 437 | 438 | 439 | class Positivity(ProximalOperator): 440 | """Enforce positivity constraint on a complex variable's real & imaginary part.""" 441 | def __init__(self, positivity_real, positivity_imag, proximal_name = "Positivity"): 442 | super().__init__(proximal_name) 443 | self.real = positivity_real 444 | self.imag = positivity_imag 445 | 446 | def computeCost(self, x): 447 | return None 448 | 449 | def computeProx(self, x): 450 | if type(x).__module__ == "arrayfire.array": 451 | x = self._boundRealValue(af.real(x), 0, self.real) +\ 452 | 1.0j * self._boundRealValue(af.imag(x), 0, self.imag) 453 | else: 454 | x = self._boundRealValue(x.real, 0, self.real) +\ 455 | 1.0j * self._boundRealValue(x.imag, 0, self.imag) 456 | return x 457 | 458 | class Negativity(Positivity): 459 | """Enforce positivity constraint on a complex variable's real & imaginary part.""" 460 | def __init__(self, negativity_real, negativity_imag): 461 | super().__init__(negativity_real, negativity_imag, "Negativity") 462 | 463 | def computeProx(self, x): 464 | return (-1.) * super().computeProx((-1.) 
* x) 465 | 466 | class PureReal(ProximalOperator): 467 | """Enforce real constraint on a complex, imaginary part will be cleared""" 468 | def __init__(self): 469 | super().__init__("Pure real") 470 | 471 | def computeCost(self, x): 472 | return None 473 | 474 | def computeProx(self, x): 475 | if type(x).__module__ == "arrayfire.array": 476 | x = af.real(x) + 1j*0.0 477 | else: 478 | x = x.real + 1j*0.0 479 | return x 480 | 481 | class Pureimag(ProximalOperator): 482 | """Enforce imaginary constraint on a complex, real part will be cleared""" 483 | def __init__(self): 484 | super().__init__("Pure imaginary") 485 | 486 | def computeCost(self, x): 487 | return None 488 | 489 | def computeProx(self, x): 490 | if type(x).__module__ == "arrayfire.array": 491 | x = 1j*af.imag(x) 492 | else: 493 | x = 1j*x.imag 494 | return x 495 | 496 | class Lasso(ProximalOperator): 497 | """||x||_1 regularizer, soft thresholding with certain parameter""" 498 | def __init__(self, parameter): 499 | super().__init__("LASSO") 500 | self.setParameter(parameter) 501 | 502 | def _softThreshold(self, x): 503 | if type(x).__module__ == "arrayfire.array": 504 | #POTENTIAL BUG: af.sign implementation does not agree with documentation 505 | x = (af.sign(x)-0.5)*(-2.0) * (af.abs(x) - self.parameter) * (af.abs(x) > self.parameter) 506 | else: 507 | x = np.sign(x) * (np.abs(x) - self.parameter) * (np.abs(x) > self.parameter) 508 | return x 509 | 510 | def setParameter(self, parameter): 511 | self.parameter = parameter 512 | 513 | def computeCost(self, x): 514 | return af.norm(af.moddims(x, np.prod(x.shape)), norm_type = af.NORM.VECTOR_1) 515 | 516 | def computeProx(self, x): 517 | if type(x).__module__ == "arrayfire.array": 518 | x = self._softThreshold(af.real(x)) + 1.0j * self._softThreshold(af.imag(x)) 519 | else: 520 | x = self._softThreshold(x.real) + 1.0j * self._softThreshold(x.imag) 521 | return x 522 | 523 | #TODO: implement Tikhonov 524 | class Tikhonov(ProximalOperator): 525 | def 
__init__(self): 526 | pass 527 | def setParameter(self, parameter): 528 | self.parameter = parameter 529 | def computeCost(self, x): 530 | pass 531 | def computeProx(self, x): 532 | return x 533 | 534 | class PureAmplitude(ProximalOperator): 535 | def __init__(self): 536 | super().__init__("Purely Amplitude") 537 | def computeCost(self, x): 538 | return None 539 | def computeProx(self, x): 540 | return af.abs(x) * af.exp(1.0j * 0 * x) 541 | 542 | class PurePhase(ProximalOperator): 543 | def __init__(self): 544 | super().__init__("Purely Phase") 545 | def computeCost(self, x): 546 | return None 547 | def computeProx(self, x): 548 | return af.exp(1.0j*af.arg(x)) 549 | 550 | -------------------------------------------------------------------------------- /opticaltomography/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | variables setting 3 | """ 4 | import arrayfire as af 5 | bit = 32 6 | 7 | np_float_datatype = "float32" if bit == 32 else "float64" 8 | np_complex_datatype = "complex64" if bit == 32 else "complex128" 9 | af_float_datatype = af.Dtype.f32 if bit == 32 else af.Dtype.f64 10 | af_complex_datatype = af.Dtype.c32 if bit == 32 else af.Dtype.c64 -------------------------------------------------------------------------------- /sample_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys\n", 17 | "import numpy as np\n", 18 | "import scipy.io as sio\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import os\n", 21 | "import contexttimer\n", 22 | "%matplotlib notebook\n", 23 | "%load_ext autoreload\n", 24 | "%autoreload 2\n", 25 | "import arrayfire as af\n", 26 | "af.set_device(0)" 27 | ] 28 | }, 29 | { 30 | 
"cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Import project modules (run from the repo root, or add it to sys.path)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from opticaltomography.opticsutil import compare3DStack, show3DStack\n", 43 | "from opticaltomography.opticsalg import PhaseObject3D, TomographySolver, AlgorithmConfigs" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "# Specify parameters & load data" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# Units in microns\n", 60 | "wavelength = 0.514\n", 61 | "n_measure = 1.0\n", 62 | "n_b = 1.0\n", 63 | "maginification = 80.\n", 64 | "dx = 6.5 / maginification\n", 65 | "dy = 6.5 / maginification\n", 66 | "dz = 3 * dx\n", 67 | "na = 0.65" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "#Make sure the path is correct\n", 77 | "#Illumination angle, change to [0.0] if only on-axis is needed:\n", 78 | "na_list = sio.loadmat(\"na_list_test.mat\")\n", 79 | "fx_illu_list = na_list[\"na_list\"][150:,0] / wavelength\n", 80 | "fy_illu_list = na_list[\"na_list\"][150:,1] / wavelength" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "# Plot object in z (y,x,z)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "scrolled": false 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "phantom = np.ones((400,400,10),dtype=\"complex64\") * n_b\n", 99 | "show3DStack(np.real(phantom), axis=2, clim=(np.min(np.real(phantom)), np.max(np.real(phantom))))" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## Fill in phantom" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code",
111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "x, y = np.meshgrid(np.linspace(-1,1,phantom.shape[0]), np.linspace(-1,1,phantom.shape[1]))\n", 116 | "r2 = x ** 2 + y ** 2\n", 117 | "phantom[...,4] += (r2 < 0.25 ** 2) * 0.1 / (2 * np.pi * dz / wavelength)\n", 118 | "show3DStack(np.real(phantom), axis=2, clim=(np.min(np.real(phantom)), np.max(np.real(phantom))))" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "# Setup solver objects" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "scrolled": true 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "solver_params = dict(wavelength = wavelength, na = na, \\\n", 137 | " RI_measure = n_measure, sigma = 2 * np.pi * dz / wavelength,\\\n", 138 | " fx_illu_list = fx_illu_list, fy_illu_list = fy_illu_list,\\\n", 139 | " pad = True, pad_size = (25,25))\n", 140 | "phase_obj_3d = PhaseObject3D(shape=phantom.shape, voxel_size=(dy,dx,dz), RI=n_b, RI_obj=phantom)\n", 141 | "solver_obj = TomographySolver(phase_obj_3d, **solver_params)\n", 142 | "# Forward simulation method\n", 143 | "# solver_obj.setScatteringMethod(model = \"MultiPhaseContrast\")\n", 144 | "solver_obj.setScatteringMethod(model = \"MultiBorn\")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "# Generate forward prediction" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "with contexttimer.Timer() as timer:\n", 161 | " forward_field_mb = solver_obj.forwardPredict(field=False)\n", 162 | " print(timer.elapsed) \n", 163 | "forward_field_mb = np.squeeze(forward_field_mb) " 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "#plot\n", 173 | "%matplotlib 
notebook\n", 174 | "show3DStack(np.real(forward_field_mb), axis=2, clim=(np.min(np.real(forward_field_mb)), np.max(np.real(forward_field_mb))))" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "# Solving an inverse problem" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "#Create a class for all inverse problem parameters\n", 191 | "configs = AlgorithmConfigs()\n", 192 | "configs.batch_size = 1\n", 193 | "configs.method = \"FISTA\"\n", 194 | "configs.restart = True\n", 195 | "configs.max_iter = 5\n", 196 | "# multislice stepsize\n", 197 | "# configs.stepsize = 2e-4\n", 198 | "# multiborn stepsize\n", 199 | "configs.stepsize = 10\n", 200 | "configs.error = []\n", 201 | "configs.pure_real = True\n", 202 | "#total variation regularization\n", 203 | "configs.total_variation = False\n", 204 | "configs.reg_tv = 1.0 #lambda\n", 205 | "configs.max_iter_tv = 15\n", 206 | "configs.order_tv = 1\n", 207 | "configs.total_variation_gpu = True\n", 208 | "configs.total_variation_anisotropic = False\n", 209 | "\n", 210 | "# reconstruction method\n", 211 | "# solver_obj.setScatteringMethod(model = \"MultiPhaseContrast\")\n", 212 | "solver_obj.setScatteringMethod(model = \"MultiBorn\")" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "recon_obj_3d = solver_obj.solve(configs, forward_field_mb)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "## Plotting results" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "current_rec = recon_obj_3d\n", 238 | "cost = solver_obj.configs.error\n", 239 | "show3DStack(np.real(current_rec), axis=2, clim=(np.min(np.real(current_rec)), 
np.max(np.real(current_rec))))" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "plt.figure()\n", 249 | "plt.plot(np.log10(cost))" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [] 258 | } 259 | ], 260 | "metadata": { 261 | "kernelspec": { 262 | "display_name": "Python 3", 263 | "language": "python", 264 | "name": "python3" 265 | }, 266 | "language_info": { 267 | "codemirror_mode": { 268 | "name": "ipython", 269 | "version": 3 270 | }, 271 | "file_extension": ".py", 272 | "mimetype": "text/x-python", 273 | "name": "python", 274 | "nbconvert_exporter": "python", 275 | "pygments_lexer": "ipython3", 276 | "version": "3.6.5" 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 2 281 | } 282 | --------------------------------------------------------------------------------