├── .gitignore ├── 10x10.png ├── 6x6.png ├── CompareNN_MatlabBilinearInterp.py ├── GetMeasuredDiffractionPattern.py ├── LICENSE ├── PropagateSphericalAperture.py ├── README.md ├── SquareWFtest └── CameraNoise │ └── 1_1000 │ └── Bild_1.png ├── TestPropagate.py ├── _main.py ├── addnoise.py ├── buildui.sh ├── datagen.dat ├── datagen.py ├── diffraction_functions.py ├── diffraction_net.py ├── get_single_sample.py ├── live_capture ├── TIS.py ├── live_stream_test.py ├── setup.txt └── tiscamera_0.11.1_amd64.deb ├── main.py ├── main.ui ├── matlab_cdi ├── cdi_interpolation.m ├── seeded_reconst_func.m └── seeded_run_CDI_noprocessing.m ├── mp_plotdata └── plot.py ├── nice_plot_nn_iterative_compared_noise.py ├── noise_resistant_net.py ├── noise_test_compareCDI.sh ├── params.py ├── requirements.txt ├── run_tests.sh ├── rungdb.sh ├── sample.p ├── size_6um_pitch_600nm_diameter_300nm_psize_5nm.png └── submit_gpu_job.slurm /.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | *.png 3 | *.jpg 4 | *.vim 5 | *.gif 6 | *.p 7 | coco_dataset/ 8 | tags 9 | tensorboard_graph/ 10 | *.hdf5 11 | .grep* 12 | models/ 13 | matlab_cdi/*.mat 14 | mp_plotdata/ 15 | __pycache__/ 16 | venv 17 | zernike3/build/*.dat 18 | matlab_cdi/loaddata.m 19 | 20 | -------------------------------------------------------------------------------- /10x10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwhite510/wavefront-sensor-neural-network/3a521fea276119438fe07b6e2b777e677c1f83c4/10x10.png -------------------------------------------------------------------------------- /6x6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwhite510/wavefront-sensor-neural-network/3a521fea276119438fe07b6e2b777e677c1f83c4/6x6.png -------------------------------------------------------------------------------- 
/CompareNN_MatlabBilinearInterp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from GetMeasuredDiffractionPattern import GetMeasuredDiffractionPattern 3 | from numpy import unravel_index 4 | import scipy 5 | import diffraction_functions 6 | from PIL import Image, ImageDraw 7 | import matplotlib.pyplot as plt 8 | import diffraction_net 9 | import tables 10 | import pickle 11 | import os 12 | from scipy import interpolate 13 | import argparse 14 | import params 15 | import imageio 16 | 17 | def get_interpolation_points(amplitude_mask): 18 | """ 19 | get the points for bilinear interp 20 | """ 21 | x=[] 22 | y=[] 23 | 24 | # plt.figure() 25 | # plt.title("Left Points") 26 | # plt.pcolormesh(amplitude_mask) 27 | for col in [35,41,48,54,61]: 28 | for row in [93,86,80,74,67,61,54,48,41,35]: 29 | x.append(col) 30 | y.append(row) 31 | # plt.axvline(x=col,color="red") 32 | # plt.axhline(y=row,color="blue") 33 | 34 | # plt.figure() 35 | # plt.title("Right Upper Points") 36 | # plt.pcolormesh(amplitude_mask) 37 | for col in [67,73,80,86,93]: 38 | for row in [93,86,80,74,67]: 39 | x.append(col) 40 | y.append(row) 41 | # plt.axvline(x=col,color="red") 42 | # plt.axhline(y=row,color="blue") 43 | 44 | # plt.figure() 45 | # plt.title("Right Lower Points") 46 | # plt.pcolormesh(amplitude_mask) 47 | for col in [66,72,79,85,91]: 48 | for row in [62,55,49,42,36]: 49 | x.append(col) 50 | y.append(row) 51 | # plt.axvline(x=col,color="red") 52 | # plt.axhline(y=row,color="blue") 53 | 54 | return x,y 55 | 56 | class CompareNetworkIterative(): 57 | def __init__(self, args): 58 | # retrieve image with neural network 59 | self.network=diffraction_net.DiffractionNet(args.network,args.net_type) # load a pre trained network 60 | self.args=args 61 | self.net_type=args.net_type 62 | 63 | def test(self,index,folder): 64 | m_index=(128,128) 65 | # load diffraction pattern 66 | # index=11 67 | # index=9 # best 68 | N=None 69 | with 
tables.open_file(self.args.network+"_test_noise.hdf5",mode="r") as file: 70 | N = file.root.N[0,0] 71 | object_real = file.root.object_real[index, :].reshape(N,N) 72 | object_imag = file.root.object_imag[index, :].reshape(N,N) 73 | diffraction = file.root.diffraction_noise[index, :].reshape(N,N) 74 | diffraction_noisefree = file.root.diffraction_noisefree[index, :].reshape(N,N) 75 | 76 | actual_object = {} 77 | actual_object["measured_pattern"] = diffraction 78 | actual_object["tf_reconstructed_diff"] = diffraction_noisefree 79 | actual_object["real_output"] = object_real 80 | actual_object["imag_output"] = object_imag 81 | 82 | # fig=diffraction_functions.plot_amplitude_phase_meas_retreival(actual_object,"actual_object",ACTUAL=True,m_index=m_index,mask=False) 83 | 84 | # # get the reconstructed diffraction pattern and the real / imaginary object 85 | nn_retrieved = {} 86 | nn_retrieved["measured_pattern"] = diffraction 87 | nn_retrieved["tf_reconstructed_diff"] = self.network.sess.run( 88 | self.network.nn_nodes["recons_diffraction_pattern"], feed_dict={self.network.x:diffraction.reshape(1,N,N,1)}) 89 | nn_retrieved["real_output"] = self.network.sess.run( 90 | self.network.nn_nodes["real_out"], feed_dict={self.network.x:diffraction.reshape(1,N,N,1)}) 91 | nn_retrieved["imag_output"] = self.network.sess.run( 92 | self.network.nn_nodes["imag_out"], feed_dict={self.network.x:diffraction.reshape(1,N,N,1)}) 93 | 94 | # with open("nn_retrieved.p","wb") as file: 95 | # pickle.dump(nn_retrieved,file) 96 | # with open("nn_retrieved.p","rb") as file: 97 | # nn_retrieved=pickle.load(file) 98 | 99 | # plot retrieval with neural network 100 | # fig=diffraction_functions.plot_amplitude_phase_meas_retreival(nn_retrieved,"nn_retrieved",m_index=m_index) 101 | 102 | # get amplitude mask 103 | N = np.shape(nn_retrieved["measured_pattern"])[1] 104 | _, amplitude_mask = diffraction_functions.get_amplitude_mask_and_imagesize(N, int(params.params.wf_ratio*N)) 105 | # get interpolation 
points 106 | 107 | # run matlab retrieval with and without interpolation 108 | matlabcdi_retrieved_interp=diffraction_functions.matlab_cdi_retrieval(np.squeeze(nn_retrieved['measured_pattern']),amplitude_mask,interpolate=True) 109 | # with open("matlab_cdi_retrieval.p","wb") as file: 110 | # pickle.dump(matlabcdi_retrieved_interp,file) 111 | # with open("matlab_cdi_retrieval.p","rb") as file: 112 | # matlabcdi_retrieved_interp=pickle.load(file) 113 | 114 | # fig=diffraction_functions.plot_amplitude_phase_meas_retreival(matlabcdi_retrieved_interp,"matlabcdi_retrieved_interp",m_index=m_index) 115 | 116 | # compare and calculate phase + intensity error 117 | network={} 118 | iterative={} 119 | 120 | phase_rmse,intensity_rmse=intensity_phase_error(actual_object,matlabcdi_retrieved_interp,"matlabcdi_retrieved_interp_"+str(index),folder) 121 | print("phase_rmse: ",phase_rmse," intensity_rmse: ",intensity_rmse) 122 | iterative['phase_rmse']=phase_rmse 123 | iterative['intensity_rmse']=intensity_rmse 124 | 125 | phase_rmse,intensity_rmse=intensity_phase_error(actual_object,nn_retrieved,"nn_retrieved_"+str(index),folder) 126 | print("phase_rmse: ",phase_rmse," intensity_rmse: ",intensity_rmse) 127 | network['phase_rmse']=phase_rmse 128 | network['intensity_rmse']=intensity_rmse 129 | 130 | # rmse=intensity_phase_error(actual_object,nn_retrieved) 131 | # actual_object 132 | # matlabcdi_retrieved_interp 133 | # nn_retrieved 134 | # plt.close('all') 135 | # plt.show() 136 | return network,iterative 137 | 138 | def simulated_test(self,N): 139 | # measured sample test 140 | 141 | network_error_phase=[] 142 | network_error_intensity=[] 143 | iterative_error_phase=[] 144 | iterative_error_intensity=[] 145 | 146 | network_error=PhaseIntensityError() 147 | iterative_error=PhaseIntensityError() 148 | for i in range(0,N): 149 | network,iterative=self.test(i,'test_pc_'+args.pc) 150 | network_error.phase_error.values.append(network['phase_rmse']) 151 | 
network_error.intensity_error.values.append(network['intensity_rmse']) 152 | 153 | iterative_error.phase_error.values.append(iterative['phase_rmse']) 154 | iterative_error.intensity_error.values.append(iterative['intensity_rmse']) 155 | 156 | # calculate statistics 157 | network_error.calculate_statistics() 158 | iterative_error.calculate_statistics() 159 | 160 | 161 | # save these to a pickle 162 | errorvals={} 163 | errorvals["network_error"]=network_error 164 | errorvals["iterative_error"]=iterative_error 165 | with open('error_'+args.pc+'.p','wb') as file: 166 | pickle.dump(errorvals,file) 167 | 168 | def retrieve_measured(self,measured,figtitle,mask=False): 169 | # retrieve with network 170 | # plot 171 | N=256 172 | retrieved = {} 173 | retrieved["measured_pattern"] = measured 174 | retrieved["tf_reconstructed_diff"] = self.network.sess.run( 175 | self.network.nn_nodes["recons_diffraction_pattern"], 176 | feed_dict={self.network.x:measured.reshape(1,N,N,1)} 177 | ) 178 | retrieved["real_output"] = self.network.sess.run( 179 | self.network.nn_nodes["real_out"], 180 | feed_dict={self.network.x:measured.reshape(1,N,N,1)} 181 | ) 182 | retrieved["imag_output"] = self.network.sess.run( 183 | self.network.nn_nodes["imag_out"], 184 | feed_dict={self.network.x:measured.reshape(1,N,N,1)} 185 | ) 186 | if self.net_type=='nr': 187 | retrieved["coefficients"]=self.network.sess.run( 188 | self.network.nn_nodes["out_zcoefs"], 189 | feed_dict={self.network.x:measured.reshape(1,N,N,1)} 190 | ) 191 | retrieved["scale"]=self.network.sess.run( 192 | self.network.nn_nodes["out_scale"], 193 | feed_dict={self.network.x:measured.reshape(1,N,N,1)} 194 | ) 195 | fig=diffraction_functions.plot_amplitude_phase_meas_retreival(retrieved,figtitle,mask=mask) 196 | return retrieved,fig 197 | 198 | def matlab_cdi_retrieval(self,measured,figtitle,mask=False): 199 | measured=np.squeeze(measured) 200 | N = np.shape(measured)[1] 201 | _, amplitude_mask = 
diffraction_functions.get_amplitude_mask_and_imagesize(N, int(params.params.wf_ratio*N)) 202 | retrieved=diffraction_functions.matlab_cdi_retrieval(measured,amplitude_mask,interpolate=True) 203 | 204 | fig=diffraction_functions.plot_amplitude_phase_meas_retreival(retrieved,figtitle,mask=mask) 205 | return retrieved,fig 206 | 207 | 208 | def get_train_sample(self,index): 209 | with tables.open_file(self.args.network+"_train_noise.hdf5",mode="r") as file: 210 | N = file.root.N[0,0] 211 | object_real = file.root.object_real[index, :].reshape(N,N) 212 | object_imag = file.root.object_imag[index, :].reshape(N,N) 213 | diffraction = file.root.diffraction_noise[index, :].reshape(N,N) 214 | diffraction_noisefree = file.root.diffraction_noisefree[index, :].reshape(N,N) 215 | 216 | obj={} 217 | obj["measured_pattern"]=diffraction 218 | obj["tf_reconstructed_diff"]=diffraction_noisefree 219 | obj["imag_output"]=object_imag 220 | obj["real_output"]=object_real 221 | return obj 222 | 223 | 224 | def intensity_phase_error(actual,predicted,title,folder): 225 | """ 226 | actual, predicted 227 | : dictionaries with keys: 228 | 229 | measured_pattern 230 | tf_reconstructed_diff 231 | real_output 232 | imag_output 233 | 234 | """ 235 | actual_c = actual["real_output"]+1j*actual["imag_output"] 236 | predicted_c = predicted["real_output"]+1j*predicted["imag_output"] 237 | actual_c=np.squeeze(actual_c) 238 | predicted_c=np.squeeze(predicted_c) 239 | # both normalized 240 | actual_c=actual_c/np.max(np.abs(actual_c)) 241 | predicted_c=predicted_c/np.max(np.abs(predicted_c)) 242 | 243 | print("title: ",title) 244 | print("np.max(np.abs(actual_c)) =>", np.max(np.abs(actual_c))) 245 | print("np.max(np.abs(predicted_c)) =>", np.max(np.abs(predicted_c))) 246 | 247 | # get wavefront sensor 248 | N=np.shape(actual_c)[0] 249 | _,amplitude_mask=diffraction_functions.get_amplitude_mask_and_imagesize(N,int(params.params.wf_ratio*N)) 250 | # get wavefront sensor boundary 251 | w_l=0 252 | w_r=N-1 253 
| w_t=0 254 | w_b=N-1 255 | while np.sum(amplitude_mask,axis=1)[w_t]==0: 256 | w_t+=1 257 | while np.sum(amplitude_mask,axis=1)[w_b]==0: 258 | w_b-=1 259 | while np.sum(amplitude_mask,axis=0)[w_l]==0: 260 | w_l+=1 261 | while np.sum(amplitude_mask,axis=0)[w_r]==0: 262 | w_r-=1 263 | 264 | w_t+=8 265 | w_b-=8 266 | w_l+=8 267 | w_r-=8 268 | 269 | # set both to 0 outside wavefront sensor area 270 | predicted_c[0:w_t,:]=0 271 | predicted_c[w_b:,:]=0 272 | predicted_c[:,0:w_l]=0 273 | predicted_c[:,w_r:]=0 274 | actual_c[0:w_t,:]=0 275 | actual_c[w_b:,:]=0 276 | actual_c[:,0:w_l]=0 277 | actual_c[:,w_r:]=0 278 | 279 | # # set both to 0 where the actual intensity is below 5% of its peak 280 | actual_c[np.abs(actual_c)**2 < 0.05 * np.max(np.abs(actual_c)**2)] = 0.0 281 | predicted_c[np.abs(actual_c)**2 < 0.05 * np.max(np.abs(actual_c)**2)] = 0.0 282 | 283 | actual_I = np.abs(actual_c)**2 284 | predicted_I = np.abs(predicted_c)**2 285 | 286 | # find intensity peak of actual 287 | m_index = unravel_index(actual_I.argmax(), actual_I.shape) 288 | # m_index = [int(N/2),int(N/2)] 289 | # predicted_phase_Imax = np.angle(predicted_c[m_index[0], m_index[1]]) 290 | # actual_phase_Imax = np.angle(actual_c[m_index[0], m_index[1]]) 291 | 292 | # phase angle scan to find smallest phase error 293 | min_phase_angle=None 294 | min_phase_mse=999.0 295 | for d_phi in np.linspace(-2*np.pi,2*np.pi,1000): 296 | _predicted_c=np.array(predicted_c) 297 | _predicted_c*=np.exp(-1j * d_phi) 298 | # phase rmse 299 | A = np.angle(actual_c).reshape(-1) 300 | B = np.angle(_predicted_c).reshape(-1) 301 | phase_mse = (np.square(A-B)).mean() 302 | if phase_mse < min_phase_mse: 303 | min_phase_mse=phase_mse 304 | min_phase_angle=d_phi 305 | 306 | # remove the global phase offset found above 307 | predicted_c *= np.exp(-1j * min_phase_angle) 308 | # actual_c *= np.exp(-1j * actual_phase_Imax) 309 | 310 | 311 | # phase rmse 312 | A = np.angle(actual_c).reshape(-1) 313 | B = np.angle(predicted_c).reshape(-1) 314 | phase_rmse = 
np.sqrt((np.square(A-B)).mean()) 315 | 316 | # intensity mse 317 | A = actual_I.reshape(-1) 318 | B = predicted_I.reshape(-1) 319 | intensity_rmse = np.sqrt((np.square(A-B)).mean()) 320 | 321 | 322 | fig = plt.figure(figsize=(10,10)) 323 | fig.suptitle(title) 324 | gs = fig.add_gridspec(4,2) 325 | 326 | # plot rmse 327 | fig.text(0.2, 0.95, "intensity_rmse:"+str(intensity_rmse)+"\n"+" phase_rmse"+str(phase_rmse) 328 | , ha="center", size=12, backgroundcolor="cyan") 329 | 330 | # measured diffraction pattern 331 | ax=fig.add_subplot(gs[0,0]) 332 | ax.set_title("Measured Diffraction Pattern") 333 | im=ax.pcolormesh(np.squeeze(predicted['measured_pattern'])) 334 | fig.colorbar(im,ax=ax) 335 | 336 | ax=fig.add_subplot(gs[0,1]) 337 | ax.set_title("Reconstructed Diffraction Pattern") 338 | im=ax.pcolormesh(np.squeeze(predicted['tf_reconstructed_diff'])) 339 | fig.colorbar(im,ax=ax) 340 | 341 | # intensity 342 | ax=fig.add_subplot(gs[1,0]) 343 | ax.set_title("actual_I") 344 | im=ax.pcolormesh(actual_I) 345 | ax.axvline(x=m_index[1],color="red",alpha=0.8) 346 | ax.axhline(y=m_index[0],color="blue",alpha=0.8) 347 | fig.colorbar(im,ax=ax) 348 | 349 | ax=fig.add_subplot(gs[1,1]) 350 | ax.set_title("predicted_I") 351 | im=ax.pcolormesh(predicted_I) 352 | ax.axvline(x=m_index[1],color="red",alpha=0.8) 353 | ax.axhline(y=m_index[0],color="blue",alpha=0.8) 354 | fig.colorbar(im,ax=ax) 355 | 356 | ax=fig.add_subplot(gs[2,0]) 357 | ax.set_title("actual_c angle") 358 | im=ax.pcolormesh(np.angle(actual_c)) 359 | ax.axvline(x=m_index[1],color="red",alpha=0.8) 360 | ax.axhline(y=m_index[0],color="blue",alpha=0.8) 361 | fig.colorbar(im,ax=ax) 362 | 363 | ax=fig.add_subplot(gs[3,0]) 364 | ax.set_title("actual_c angle") 365 | ax.plot(np.angle(actual_c)[m_index[0],:]) 366 | ax.axvline(x=m_index[1],color="red") 367 | 368 | ax=fig.add_subplot(gs[2,1]) 369 | ax.set_title("predicted_c angle") 370 | im=ax.pcolormesh(np.angle(predicted_c)) 371 | ax.axvline(x=m_index[1],color="red",alpha=0.8) 372 
| ax.axhline(y=m_index[0],color="blue",alpha=0.8) 373 | fig.colorbar(im,ax=ax) 374 | 375 | ax=fig.add_subplot(gs[3,1]) 376 | ax.set_title("predicted_c angle") 377 | ax.plot(np.angle(predicted_c)[m_index[0],:]) 378 | ax.axvline(x=m_index[1],color="red") 379 | 380 | if not os.path.isdir(folder): 381 | os.mkdir(folder) 382 | fig.savefig(os.path.join(folder,title)) 383 | 384 | # plt.figure(105) 385 | # plt.pcolormesh(np.angle(predicted_c) - np.angle(actual_c)) 386 | # plt.gca().axvline(x=m_index[1],color="red",alpha=0.8) 387 | # plt.gca().axhline(y=m_index[0],color="blue",alpha=0.8) 388 | # plt.colorbar() 389 | 390 | return phase_rmse,intensity_rmse 391 | 392 | def plot_show_cm(mat,title,same_colorbar=True): 393 | mat=np.squeeze(mat) 394 | fig,ax=plt.subplots(1,2,figsize=(10,5)) 395 | fig.suptitle(title) 396 | 397 | ax[0].set_title('linear scale, center and center of mass') 398 | if same_colorbar: 399 | im=ax[0].imshow(np.squeeze(mat),cmap='jet',vmin=0.0,vmax=1.0) 400 | else: 401 | im=ax[0].imshow(np.squeeze(mat),cmap='jet') 402 | fig.colorbar(im,ax=ax[0]) 403 | # cy=diffraction_functions.calc_centroid(mat,0)# summation along columns 404 | # cx=diffraction_functions.calc_centroid(mat,1)# summation along rows 405 | m_index = unravel_index(mat.argmax(), mat.shape) 406 | cy=m_index[0] 407 | cx=m_index[1] 408 | 409 | ax[0].axvline(x=63, color="red",alpha=0.5) 410 | ax[0].axhline(y=63, color="red",alpha=0.5) 411 | ax[0].axvline(x=cx, color="yellow",alpha=0.5) 412 | ax[0].axhline(y=cy, color="yellow",alpha=0.5) 413 | 414 | # plot distance from cm 415 | ax[0].text(0.1, 0.6,"cx:"+str(cx)+"\ncy:"+str(cy), fontsize=10, ha='center', transform=ax[0].transAxes, backgroundcolor="red") 416 | 417 | ax[0].text(0.2, 0.9,"center of image", fontsize=10, ha='center', transform=ax[0].transAxes, backgroundcolor="red") 418 | ax[0].text(0.2, 0.8,"peak of intensity", fontsize=10, ha='center', transform=ax[0].transAxes, backgroundcolor="yellow") 419 | 420 | ax[1].set_title('log(image)') 421 | if 
same_colorbar: 422 | im=ax[1].imshow(np.squeeze(np.log(mat)),cmap='jet',vmin=-20,vmax=0.0) 423 | else: 424 | im=ax[1].imshow(np.squeeze(np.log(mat)),cmap='jet') 425 | fig.colorbar(im,ax=ax[1]) 426 | 427 | return fig 428 | 429 | 430 | 431 | # class contains list of error values, 432 | class ErrorDistribution(): 433 | def __init__(self): 434 | self.values=[] 435 | self.standard_deviation=None 436 | self.average=None 437 | def calculate_statistics(self): 438 | self.standard_deviation=np.std(self.values) 439 | self.average=np.average(self.values) 440 | 441 | 442 | class PhaseIntensityError(): 443 | def __init__(self): 444 | self.phase_error=ErrorDistribution() 445 | self.intensity_error=ErrorDistribution() 446 | 447 | def calculate_statistics(self): 448 | self.phase_error.calculate_statistics() 449 | self.intensity_error.calculate_statistics() 450 | 451 | def compare_channels(arr): 452 | _r1,_r2=1000+50,1120-50 453 | _c1,_c2=1200+40,1300-40 454 | 455 | plt.figure() 456 | plt.imshow(arr[:,:,0][_r1:_r2,_c1:_c2],cmap='jet') 457 | plt.title("channel: 0") 458 | 459 | plt.figure() 460 | plt.imshow(arr[:,:,1][_r1:_r2,_c1:_c2],cmap='jet') 461 | plt.title("channel: 1") 462 | 463 | plt.figure() 464 | plt.imshow(arr[:,:,2][_r1:_r2,_c1:_c2],cmap='jet') 465 | plt.title("channel: 2") 466 | 467 | plt.figure() 468 | plt.imshow(arr[:,:,3][_r1:_r2,_c1:_c2],cmap='jet') 469 | plt.title("channel: 3") 470 | 471 | plt.figure() 472 | plt.imshow(np.abs(arr[:,:,0][_r1:_r2,_c1:_c2]-arr[:,:,1][_r1:_r2,_c1:_c2]),cmap='jet') 473 | plt.title("channel: abs(0 - 1)") 474 | 475 | plt.figure() 476 | plt.imshow(np.abs(arr[:,:,1][_r1:_r2,_c1:_c2]-arr[:,:,2][_r1:_r2,_c1:_c2]),cmap='jet') 477 | plt.title("channel: abs(1 - 2)") 478 | 479 | plt.figure() 480 | plt.imshow(np.abs(arr[:,:,0][_r1:_r2,_c1:_c2]-arr[:,:,2][_r1:_r2,_c1:_c2]),cmap='jet') 481 | plt.title("channel: abs(0 - 2)") 482 | 483 | plt.show() 484 | 485 | 486 | 487 | if __name__ == "__main__": 488 | 489 | # TODO : evaluate rmse at high intensity 
areas 490 | # + phase, set constant phase shift 491 | 492 | # evaluate at different noise levels 493 | 494 | # run a variational network, run an RNN network 495 | 496 | parser=argparse.ArgumentParser() 497 | parser.add_argument('--network',type=str) 498 | parser.add_argument('--net_type',type=str) 499 | args,_=parser.parse_known_args() 500 | comparenetworkiterative = CompareNetworkIterative(args) 501 | # run test on simulated validation data 502 | # comparenetworkiterative.simulated_test(100) 503 | 504 | # list of measured images 505 | measured_images={} 506 | 507 | # # HDR image 508 | # filename='8_6_data/06082020.npy' 509 | # # filename='2307.npy' 510 | # a=np.load(filename) 511 | # a[a<0]=0 512 | # measured_images['HDR_image']=a 513 | 514 | # # filename='8_6_data/1_837/signal/Bild_1.png' 515 | 516 | # # image from one capture 517 | # filename='8_6_data/1_1541/signal/Bild_2.png' 518 | # a=Image.open(filename) 519 | # a=np.array(a) 520 | # a[a<0]=0 521 | # measured_images['one_capture']=a[:,:,0] 522 | 523 | # 2020 08 12 data, HDR images 524 | for _fn in [ 525 | # HDR images 526 | # '12_18_20_data/left/new_folder/120_1216_HDR.npy', 527 | # '12_18_20_data/left/new_folder/128_1216_HDR.npy', 528 | # '12_18_20_data/left/new_folder/134_1216_HDR.npy', 529 | # '12_18_20_data/left/new_folder/138_1216_HDR.npy', 530 | # # '12_18_20_data/left/new_folder/140_1216_HDR.npy', 531 | # # '12_18_20_data/left2/142_1216_HDR.npy', 532 | # '12_18_20_data/left2/144_1216_HDR.npy', 533 | # '12_18_20_data/left2/146_1216_HDR.npy', 534 | # '12_18_20_data/left2/148_1216_HDR.npy', 535 | # '12_18_20_data/left2/150_1216_HDR.npy', 536 | # '12_18_20_data/right/152_1216_HDR.npy', 537 | # '12_18_20_data/right/154_1216_HDR.npy', 538 | # '12_18_20_data/right/156_1216_HDR.npy', 539 | # '12_18_20_data/right/158_1216_HDR.npy', 540 | # '12_18_20_data/right/160_1216_HDR.npy', 541 | # '12_18_20_data/right/162_1216_HDR.npy', 542 | # '12_18_20_data/right/166_1216_HDR.npy', 543 | 544 | 545 | # # non HDR images 
546 | # '12_18_20_data/left/new_folder/120_1216.npy', 547 | # '12_18_20_data/left/new_folder/128_1216.npy', 548 | # '12_18_20_data/left/new_folder/134_1216.npy', 549 | # '12_18_20_data/left/new_folder/138_1216.npy', 550 | # # '12_18_20_data/left/new_folder/140_1216.npy', 551 | # # '12_18_20_data/left2/142_1216.npy', 552 | # '12_18_20_data/left2/144_1216.npy', 553 | # '12_18_20_data/left2/146_1216.npy', 554 | # '12_18_20_data/left2/148_1216.npy', 555 | # '12_18_20_data/left2/150_1216.npy', 556 | # '12_18_20_data/right/152_1216.npy', 557 | # '12_18_20_data/right/154_1216.npy', 558 | # '12_18_20_data/right/156_1216.npy', 559 | # '12_18_20_data/right/158_1216.npy', 560 | # '12_18_20_data/right/160_1216.npy', 561 | # '12_18_20_data/right/162_1216.npy', 562 | # '12_18_20_data/right/166_1216.npy', 563 | 564 | '2021_01_18_processed/140_0118.npy', 565 | '2021_01_18_processed/142_0118.npy', 566 | '2021_01_18_processed/144_0118.npy', 567 | '2021_01_18_processed/146_0118.npy', 568 | '2021_01_18_processed/148_0118.npy', 569 | '2021_01_18_processed/150_0118.npy', 570 | '2021_01_18_processed/150_1216.npy', 571 | '2021_01_18_processed/152_0118.npy', 572 | '2021_01_18_processed/154_0118.npy', 573 | '2021_01_18_processed/156_0118.npy', 574 | '2021_01_18_processed/158_0118.npy', 575 | '2021_01_18_processed/160_0118.npy', 576 | ]: 577 | a=np.load(_fn) 578 | # print(_fn); plt.figure(1);plt.title(_fn);plt.imshow(a);plt.savefig('fig1.png'); 579 | a[a<0]=0 580 | measured_images[os.path.split(_fn)[-1].replace('.','_')]=a 581 | 582 | # simulated data 583 | # a=np.load('simulation_2021_01_18/2021_01_16.npy') 584 | # for i,_name in enumerate(['-7mm','-3mm','0mm','3mm','7mm']): 585 | # measured_images['simulation_2021_01_18_'+_name]=a[:,:,i] 586 | 587 | # a=Image.open(filename).convert("L") 588 | # a=np.array(a) 589 | # a[a<0]=0 590 | # measured_images['greyscale']=a 591 | 592 | experimental_params = {} 593 | experimental_params['pixel_size'] = 6.9e-6 # [meters] with 2x2 binning 594 | 
experimental_params['z_distance'] = 16.6e-3 # [meters] distance from camera 595 | experimental_params['wavelength'] = 612e-9 #[meters] wavelength 596 | getMeasuredDiffractionPattern = GetMeasuredDiffractionPattern(N_sim=256, 597 | N_meas=np.shape(a)[0], # for calculating the measured frequency axis (not really needed) 598 | experimental_params=experimental_params) 599 | 600 | PREFIX='A13' 601 | for _name in measured_images.keys(): 602 | transform={};transform['rotation_angle']=0;transform['scale']=1.0;transform['flip']='lr' 603 | m = getMeasuredDiffractionPattern.format_measured_diffraction_pattern(measured_images[_name], transform) 604 | # m[m<0.003*np.max(m)]=0 605 | m=np.squeeze(m) 606 | for _f,_ret_type in zip( 607 | [comparenetworkiterative.retrieve_measured], 608 | ['retrieved_nn'] 609 | ): 610 | # compare to training data set 611 | retrieved,fig=_f(m,"measured: "+_name+"\n"+_ret_type+"predicted",mask=True) 612 | fig.savefig(PREFIX+'retrieved'+_name+'_'+_ret_type) 613 | 614 | 615 | os.system('mkdir '+PREFIX) 616 | os.system('mkdir '+PREFIX+'/HDR');os.system('mkdir '+PREFIX+'/SINGLE');os.system('mkdir '+PREFIX+'/SIM') 617 | # os.system('mv ./'+PREFIX+'*HDR*png '+PREFIX+'/HDR') 618 | os.system('mv ./'+PREFIX+'*simulated*png '+PREFIX+'/SIM') 619 | os.system('mv ./'+PREFIX+'*png '+PREFIX+'/SINGLE') 620 | 621 | exit() 622 | 623 | 624 | 625 | # # retrieve samples from specific data set 626 | 627 | # sim=comparenetworkiterative.get_train_sample(_i_sim) 628 | # # fig=plot_show_cm(sim['measured_pattern'],_name+" training ("+str(_i_sim)+")") 629 | # # fig.savefig(os.path.join(DIR,_name+"_similar_"+"training_measured_center")) 630 | 631 | # # compare to training data set 632 | # fig=comparenetworkiterative.retrieve_measured(sim['measured_pattern'],_name+" training, Predicted") 633 | # fig.savefig(os.path.join(DIR,_name+"_similar_"+"training_predicted")) 634 | # # plot simulated retrieved and actual 635 | 636 | # 
fig=diffraction_functions.plot_amplitude_phase_meas_retreival(sim,_name+" training, Actual",ACTUAL=True) 637 | # fig.savefig(os.path.join(DIR,_name+"_similar_"+"training_actual")) 638 | 639 | # # plot simulated sample 640 | # sim=comparenetworkiterative.get_train_sample(0) 641 | # fig=plot_show_cm(sim['measured_pattern'],"validation (0)") 642 | # fig.savefig(os.path.join(DIR,"validation_measured_center")) 643 | 644 | # # compare to training data set 645 | # fig=comparenetworkiterative.retrieve_measured(sim['measured_pattern'],"Validation, Predicted") 646 | # fig.savefig(os.path.join(DIR,"validation_predicted")) 647 | # # plot simulated retrieved and actual 648 | 649 | # fig=diffraction_functions.plot_amplitude_phase_meas_retreival(sim,"Validation, Actual",ACTUAL=True) 650 | # fig.savefig(os.path.join(DIR,"validation_actual")) 651 | 652 | # fig=diffraction_functions.plot_amplitude_phase_meas_retreival(sim,"Validation, Actual",ACTUAL=True,mask=True) 653 | # fig.savefig(os.path.join(DIR,"validation_actual_WF")) 654 | 655 | 656 | plt.show() 657 | exit() 658 | 659 | 660 | import ipdb; ipdb.set_trace() # BREAKPOINT 661 | print("BREAKPOINT") 662 | plt.figure() 663 | 664 | 665 | 666 | 667 | 668 | -------------------------------------------------------------------------------- /GetMeasuredDiffractionPattern.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import diffraction_functions 4 | from astropy.io import fits 5 | import params 6 | 7 | class GetMeasuredDiffractionPattern(): 8 | """ 9 | for getting a measured diffraction pattern and formatting it to match the 10 | frequency axis of the simulation grid 11 | """ 12 | def __init__(self, N_sim, N_meas, experimental_params): 13 | """ 14 | experimental_params: dict{} 15 | experimental_params['pixel_size'] 16 | experimental_params['z_distance'] 17 | experimental_params['wavelength'] 18 | """ 19 | 20 | 21 | # get the axes of the 
simulation 22 | self.N_sim = N_sim # size of the measured simulated pattern 23 | self.experimental_params = experimental_params 24 | self.N_meas = N_meas 25 | self.simulation_axes, _ = diffraction_functions.get_amplitude_mask_and_imagesize(self.N_sim, int(params.params.wf_ratio*N_sim)) 26 | 27 | # parameters for the input measured diffraction pattern 28 | self.measured_axes = {} 29 | self.measured_axes["diffraction_plane"]={} 30 | self.calculate_measured_axes() 31 | 32 | # calculate ratio of df 33 | self.df_ratio = self.measured_axes['diffraction_plane']['df'] / self.simulation_axes['diffraction_plane']['df'] 34 | # multiply by scale adjustment 35 | 36 | # multiply by scale close to 1 (to account for errors) 37 | # self.df_ratio *= # fine adjustment 38 | 39 | def calculate_measured_axes(self): 40 | # calculate delta frequency 41 | self.measured_axes["diffraction_plane"]["df"] = self.experimental_params['pixel_size'] / (self.experimental_params['wavelength'] * self.experimental_params['z_distance']) 42 | self.measured_axes["diffraction_plane"]["xmax"] = self.N_meas * (self.experimental_params['pixel_size'] / 2) 43 | self.measured_axes["diffraction_plane"]["x"] = np.arange(-(self.measured_axes["diffraction_plane"]["xmax"]), (self.measured_axes["diffraction_plane"]["xmax"]), self.experimental_params['pixel_size']) 44 | self.measured_axes["diffraction_plane"]["f"] = self.measured_axes["diffraction_plane"]["x"] / (self.experimental_params['wavelength'] * self.experimental_params['z_distance']) 45 | # self.measured_axes["diffraction_plane"]["df"] = self.measured_axes["diffraction_plane"]["f"][-1] - self.measured_axes["diffraction_plane"]["f"][-2] 46 | 47 | 48 | def format_measured_diffraction_pattern(self, measured_diffraction_pattern, transform): 49 | """ 50 | measured_diffraction_pattern (numpy array) 51 | transform["rotation_angle"] = float/int (degrees) 52 | transform["scale"] = integer 53 | transform["flip"] = "lr" "ud" "lrud" None 54 | """ 55 | 56 | 
measured_diffraction_pattern = diffraction_functions.format_experimental_trace( 57 | N=self.N_sim, 58 | df_ratio=self.df_ratio*transform["scale"], 59 | measured_diffraction_pattern=measured_diffraction_pattern, 60 | rotation_angle=transform["rotation_angle"], 61 | trim=1) # if transposed (measured_pattern.T) , flip the rotation 62 | 63 | 64 | if transform["flip"] is not None: 65 | if transform["flip"] == "lr": 66 | measured_diffraction_pattern = np.flip(measured_diffraction_pattern, axis=1) 67 | 68 | elif transform["flip"] == "ud": 69 | measured_diffraction_pattern = np.flip(measured_diffraction_pattern, axis=0) 70 | 71 | elif transform["flip"] == "lrud": 72 | measured_diffraction_pattern = np.flip(measured_diffraction_pattern, axis=0) 73 | measured_diffraction_pattern = np.flip(measured_diffraction_pattern, axis=1) 74 | else: 75 | raise ValueError("invalid flip specified") 76 | # center it again 77 | 78 | measured_diffraction_pattern = diffraction_functions.center_image_at_centroid(measured_diffraction_pattern) 79 | measured_diffraction_pattern = np.expand_dims(measured_diffraction_pattern, axis=0) 80 | measured_diffraction_pattern = np.expand_dims(measured_diffraction_pattern, axis=-1) 81 | measured_diffraction_pattern *= (1/np.max(measured_diffraction_pattern)) 82 | return measured_diffraction_pattern 83 | 84 | 85 | if __name__ == "__main__": 86 | 87 | fits_file_name = "m3_scan_0000.fits" 88 | thing = fits.open(fits_file_name) 89 | measured_pattern = thing[0].data[0,:,:] 90 | N_meas = np.shape(measured_pattern)[0] 91 | measured_pattern = measured_pattern.astype(np.float64) 92 | 93 | experimental_params = {} 94 | experimental_params['pixel_size'] = 27e-6 # [meters] with 2x2 binning 95 | experimental_params['z_distance'] = 33e-3 # [meters] distance from camera 96 | experimental_params['wavelength'] = 13.5e-9 #[meters] wavelength 97 | 98 | getMeasuredDiffractionPattern = GetMeasuredDiffractionPattern(N_sim=256, N_meas=N_meas, 
experimental_params=experimental_params) 99 | 100 | transform={} 101 | transform["rotation_angle"]=3 102 | transform["scale"]=1 103 | # transform["flip"]="lr" 104 | transform["flip"]=None 105 | m = getMeasuredDiffractionPattern.format_measured_diffraction_pattern(measured_pattern, transform) 106 | m2 = diffraction_functions.get_and_format_experimental_trace(256, transform) 107 | 108 | print("np.shape(m) => ",np.shape(m)) 109 | print("np.shape(m2) => ",np.shape(m2)) 110 | 111 | plt.figure() 112 | plt.imshow(np.squeeze(m)) 113 | plt.colorbar() 114 | 115 | plt.figure() 116 | plt.imshow(np.squeeze(m2)) 117 | plt.colorbar() 118 | 119 | plt.figure() 120 | plt.imshow(np.squeeze(m)-np.squeeze(m2)) 121 | plt.colorbar() 122 | 123 | plt.show() 124 | 125 | 126 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jonathon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
def create_aperture_material(N,x,p):
    """Build an N x N complex transmission slice for a circular aperture.

    Inside the aperture radius the wave passes unchanged (value 1); outside,
    it is attenuated and phase-shifted by one propagation step dz through the
    Ta material described by p.

    Parameters
    ----------
    N : int
        Grid size (output is N x N).
    x : np.ndarray
        1-D position axis [meters]; used for both grid axes.
    p : object with attributes k, beta_Ta, delta_Ta, dz
        Material/propagation parameters (e.g. PropagateTF.Params).

    Returns
    -------
    np.ndarray, shape (N, N), dtype complex64
    """
    aperture_radius=2.7e-6
    material=np.zeros((N,N),dtype=np.complex64)
    x=x.reshape(1,-1)
    y=x.reshape(-1,1)
    r = np.sqrt(x**2 + y**2)

    # FIX: the two masked assignments were garbled into the invalid
    # "material[r=aperture_radius]=..." (the "<...>" span was stripped as a
    # fake HTML tag); reconstructed to match the analogous create_slice()
    # in datagen.py: open aperture transmits fully, material attenuates
    # (beta) and phase-delays (delta) over one step dz.
    material[r<aperture_radius]=1.0
    material[r>=aperture_radius]=np.exp(-p.k*p.beta_Ta*p.dz)*np.exp(-1j*p.k*p.delta_Ta*p.dz)
    return material
# with open(os.path.join("zernike3/build/steps_Si.dat"),"r") as file: 50 | # steps_Si=int(file.readlines()[0]) 51 | # steps_cu=None 52 | # with open(os.path.join("zernike3/build/steps_cu.dat"),"r") as file: 53 | # steps_cu=int(file.readlines()[0]) 54 | 55 | # create the material 56 | slice_Si=create_aperture_material(N,x,params_Si) 57 | slice_cu=create_aperture_material(N,x,params_cu) 58 | 59 | # test the wavefront sensor 60 | # slice_Si=PropagateTF.read_complex_array("zernike3/build/slice_Si") 61 | # slice_cu=PropagateTF.read_complex_array("zernike3/build/slice_cu") 62 | 63 | 64 | # increasing the wavelength will show the effect of propagation 65 | # params_Si.lam*=20 66 | # params_cu.lam*=20 67 | 68 | propagated=wavefront 69 | for _ in range(steps_Si): 70 | propagated=PropagateTF.forward_propagate(propagated,slice_Si,wf_f,params_Si) 71 | for _ in range(steps_cu): 72 | propagated=PropagateTF.forward_propagate(propagated,slice_cu,wf_f,params_cu) 73 | 74 | 75 | with tf.Session() as sess: 76 | out=sess.run(propagated,feed_dict={wavefront:np.ones((N,N)).reshape(1,N,N,1)}) 77 | 78 | # plt.figure() 79 | # plt.imshow(np.squeeze(np.real(out))) 80 | plt.figure() 81 | plt.imshow(np.squeeze(np.abs(out))) 82 | plt.show() 83 | 84 | print("ran!") 85 | print(out) 86 | 87 | 88 | 89 | # 300 nm Cu 90 | 91 | # 50 nm Si 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wavefront sensor neural network 2 | neural network trained to retrieve the amplitude and phase of objects by measurement of their diffraction pattern 3 | 4 | ## publication: 5 | Real-time phase-retrieval and wavefront sensing enabled by an artificial neural network 6 | Jonathon White, Sici Wang, Wilhelm Eschen, and Jan Rothhardt 7 | https://www.osapublishing.org/oe/fulltext.cfm?uri=oe-29-6-9283&id=449096 8 | 9 | 10 | # installation 11 | ``` 12 | pip install -r 
requirements.txt 13 | ``` 14 | # Overview / Files Explained in order of importance 15 | ### run\_tests.sh 16 | To train the network, run the shell script 'run\_tests.sh'. This will call several .py files to create a new data set for training, add artificial noise to the dataset, and then train a neural network. Parameters for training the network are specified in this file; these parameters are the number of training samples, the peak counts of the noise levels for the training data, and the wavefront sensor. 17 | 18 | ### noise\_test\_compareCDI.sh 19 | Load the trained network and perform retrieval on several test samples. Also perform retrieval on the same samples using iterative phase retrieval for comparison. 20 | 21 | 22 | ### diffraction\_net.py 23 | Contains the neural network, and classes for training the neural network, run from file 'run\_tests.sh' 24 | 25 | ### datagen.py 26 | Contains functions/classes for generating wavefronts and propagating them through the wavefront sensor, used to create the training data set 27 | 28 | ### addnoise.py 29 | Adds Poisson noise to simulate low photon count / additive noise from camera. 30 | 31 | ### buildui.sh / main.py / \_main.py / main.ui 32 | Files for building a user interface in Qt, for use in real-time phase retrieval when connected to a camera.
33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /SquareWFtest/CameraNoise/1_1000/Bild_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwhite510/wavefront-sensor-neural-network/3a521fea276119438fe07b6e2b777e677c1f83c4/SquareWFtest/CameraNoise/1_1000/Bild_1.png -------------------------------------------------------------------------------- /TestPropagate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import diffraction_functions 3 | import matplotlib.pyplot as plt 4 | import pickle 5 | from GetMeasuredDiffractionPattern import GetMeasuredDiffractionPattern 6 | import os 7 | 8 | 9 | # def plot_beam(complex_b,title): 10 | # fig, ax = plt.subplots(2,1, figsize=(10,10)) 11 | # fig.suptitle(title) 12 | # ax[0].pcolormesh(np.abs(complex_b)) 13 | # ax[0].set_title("ABS") 14 | 15 | 16 | if __name__=="__main__": 17 | folder_dir="nn_pictures/teslatest5_doubleksize_doublefilters_reconscostfunction_pictures/46/measured/" 18 | # run_name="Data_for_Jonathon_z0_1-fits_ud_1-0_reconstructed.p" 19 | sample_name="Data_for_Jonathon_z-500_1-fits_ud_1-0_reconstructed.p" 20 | filename=os.path.join(folder_dir,sample_name) 21 | print("filename =>", filename) 22 | 23 | with open(filename,"rb") as file: 24 | obj=pickle.load(file) 25 | complex_beam=np.squeeze(obj["real_output"]+1j*obj["imag_output"]) 26 | 27 | 28 | experimental_params = {} 29 | experimental_params['pixel_size'] = 27e-6 # [meters] with 2x2 binning 30 | experimental_params['z_distance'] = 16e-3 # [meters] distance from camera 31 | experimental_params['wavelength'] = 633e-9 #[meters] wavelength 32 | getMeasuredDiffractionPattern = GetMeasuredDiffractionPattern(N_sim=np.shape(complex_beam)[0], 33 | N_meas=2,#/// 34 | experimental_params=experimental_params) 35 | 36 | 
f_object=getMeasuredDiffractionPattern.simulation_axes['diffraction_plane']['f'] 37 | 38 | # x_pos = f_object * (experimental_params['wavelength'] * experimental_params['z_distance']) 39 | # # frequency axis for propagation 40 | # x_pos_dx=x_pos[-1]-x_pos[-2] 41 | # N=f_object.shape[0] 42 | # x_d_alpha=1 / (N * x_pos_dx) 43 | # x_alpha_max=(x_d_alpha * N)/2 44 | # x_alpha=np.arange(-x_alpha_max,x_alpha_max,x_d_alpha) 45 | 46 | # import ipdb; ipdb.set_trace() # BREAKPOINT 47 | # print("BREAKPOINT") 48 | 49 | fig = diffraction_functions.plot_amplitude_phase_meas_retreival( 50 | {"measured_pattern":np.zeros_like(np.abs(complex_beam)), 51 | "tf_reconstructed_diff":np.zeros_like(np.abs(complex_beam)), 52 | "real_output":np.real(complex_beam), 53 | "imag_output":np.imag(complex_beam)}, 54 | "complex_beam") 55 | 56 | 57 | 58 | # z = 0.5 # meters distance traveled 59 | for _i,z in enumerate(np.arange(-500e-6,2000e-6,100e-6)): 60 | gamma=np.sqrt( 61 | 1- 62 | (experimental_params['wavelength']*f_object.reshape(-1,1))**2- 63 | (experimental_params['wavelength']*f_object.reshape(1,-1))**2+ 64 | 0j # complex so it can be square rooted and imaginry 65 | ) 66 | k_sq = 2 * np.pi * z / experimental_params['wavelength'] 67 | # transfer function 68 | H = np.exp(1j * np.real(gamma) * k_sq) * np.exp(-1 * np.imag(gamma) * k_sq) 69 | 70 | complex_beam_f=np.fft.fftshift(np.fft.fft2(np.fft.fftshift(complex_beam))) 71 | complex_beam_f*=H 72 | complex_beam_prop=np.fft.fftshift(np.fft.ifft2(np.fft.fftshift(complex_beam_f))) 73 | 74 | fig = diffraction_functions.plot_amplitude_phase_meas_retreival( 75 | {"measured_pattern":np.zeros_like(np.abs(complex_beam_prop)), 76 | "tf_reconstructed_diff":np.zeros_like(np.abs(complex_beam_prop)), 77 | "real_output":np.real(complex_beam_prop), 78 | "imag_output":np.imag(complex_beam_prop)}, 79 | "z:-500 um -> + "+str(round(z*1e6,1))+" um [Simulated]",plot_spherical_aperture=True) 80 | 
def nparray_to_axislabel(arr:np.array,ticknumber:int)->list:
    """Build a pyqtgraph tick list for *arr*: evenly spaced index positions
    paired with the array value at each position, formatted to 2 decimals."""
    positions = np.linspace(0, len(arr), ticknumber)
    # the last tick lands at len(arr), one past the end; pull it back
    # onto a valid index
    if positions[-1] >= len(arr):
        positions[-1] -= 1
    ticks = []
    for pos in positions:
        idx = int(pos)
        ticks.append((idx, "%.2f" % (arr[idx])))
    return [ticks]
def addimageitemplot(qtgraphics,title:str,color:str,lut,ticks:list=None):
    """Add an ImageItem-based panel to *qtgraphics* (a GraphicsLayoutWidget).

    Returns a dict with keys 'data' (pg.ImageItem) and 'plot' (the PlotItem).
    """
    panel = {}
    panel["data"] = pg.ImageItem()
    panel["data"].setLookupTable(lut)
    panel["plot"] = qtgraphics.addPlot()
    panel["plot"].addItem(panel["data"])
    left_axis = panel["plot"].getAxis('left')
    bottom_axis = panel["plot"].getAxis('bottom')
    left_axis.setLabel('Position', color=color)
    if ticks:
        left_axis.setTicks(ticks)
        bottom_axis.setTicks(ticks)
    bottom_axis.setLabel('Position', color=color)
    panel["plot"].setTitle(title, color=color)
    return panel


def addimageviewplot(title:str,color:str,ticks:list=None,
        axislabel:str=None,
        ):
    """Create a pg.ImageView with labeled/ticked axes and a fixed colormap.

    Returns the ImageView widget; caller adds it to a layout and calls
    setImage() on it.
    """
    axis_items = {
        'left': pg.AxisItem(orientation='left'),
        'bottom': pg.AxisItem(orientation='bottom'),
    }
    if ticks:
        axis_items['left'].setTicks(ticks)
        axis_items['bottom'].setTicks(ticks)
    plot = pg.PlotItem(axisItems=axis_items)

    if axislabel:
        plot.setLabel(axis='left', text=axislabel, color=color)
        plot.setLabel(axis='bottom', text=axislabel, color=color)
    plot.setTitle(title, color=color)
    widget = pg.ImageView(view=plot)
    # black -> warm tones -> white colormap, 6 evenly spaced stops
    stops = [
        (0, 0, 0),
        (45, 5, 61),
        (84, 42, 55),
        (150, 87, 60),
        (208, 171, 141),
        (255, 255, 255)
    ]
    widget.setColorMap(pg.ColorMap(pos=np.linspace(0.0, 1.0, 6), color=stops))
    return widget
def print_usage():
    """Print command-line usage for the synchronous grab example."""
    print('Usage:')
    print('    python synchronous_grab.py [camera_id]')
    print('    python synchronous_grab.py [/h] [-h]')
    print()
    print('Parameters:')
    print('    camera_id   ID of the camera to use (using first camera if not specified)')
    print()


def abort(reason: str, return_code: int = 1, usage: bool = False):
    """Print *reason* (optionally followed by the usage text) and exit."""
    print(reason + '\n')
    if usage:
        print_usage()
    sys.exit(return_code)


def parse_args() -> Optional[str]:
    """Return the camera id from sys.argv, or None if none was given.

    Exits the process for -h//h or for more than one positional argument.
    """
    cli_args = sys.argv[1:]
    for cli_arg in cli_args:
        if cli_arg in ('/h', '-h'):
            print_usage()
            sys.exit(0)
    if len(cli_args) > 1:
        abort(reason="Invalid number of arguments. Abort.", return_code=2, usage=True)
    return cli_args[0] if cli_args else None


def get_camera(camera_id: Optional[str]) -> Camera:
    """Open the camera with *camera_id*, or the first accessible camera."""
    with Vimba.get_instance() as vimba:
        if camera_id:
            try:
                return vimba.get_camera_by_id(camera_id)
            except VimbaCameraError:
                abort('Failed to access Camera \'{}\'. Abort.'.format(camera_id))
        else:
            cams = vimba.get_all_cameras()
            if not cams:
                abort('No Cameras accessible. Abort.')
            return cams[0]
class Processing():
    """User-adjustable settings applied to the raw camera image before retrieval."""
    def __init__(self):
        # flip mode selected in the UI (string); None means no flip
        self.orientation = None
        # rotation angle (float)
        self.rotation = 0.0
        # scale factor (float); the UI rejects values <= 0
        self.scale = 1.0
self.display_raw_draw.setImage(np.random.rand(128*128).reshape(128,128)) 201 | 202 | # processed image 203 | self.display_proc_draw=addimageviewplot('processed image',color=self.COLORGREEN, 204 | ticks=nparray_to_axislabel(self.f,3), 205 | axislabel='frequency [1/m] 10^6' 206 | ); 207 | self.verticalLayout.addWidget(self.display_proc_draw) 208 | self.display_proc_draw.setImage(np.random.rand(128*128).reshape(128,128)) 209 | 210 | # intensity / real 211 | self.display_intens_real_draw=addimageviewplot('intensity',color=self.COLORGREEN, 212 | ticks=nparray_to_axislabel(self.x,3), 213 | axislabel='position [um]' 214 | ) 215 | self.horizontalLayout_4.addWidget(self.display_intens_real_draw) 216 | self.display_intens_real_draw.setImage(np.random.rand(128*128).reshape(128,128)) 217 | 218 | # phase / imag 219 | self.display_phase_imag_draw=addimageviewplot('phase[rad]',color=self.COLORGREEN, 220 | ticks=nparray_to_axislabel(self.x,3), 221 | axislabel='position [um]' 222 | ) 223 | self.horizontalLayout_4.addWidget(self.display_phase_imag_draw) 224 | self.display_phase_imag_draw.setImage(np.random.rand(128*128).reshape(128,128)) 225 | 226 | # initialize processing parameters 227 | self.processing=Processing() 228 | # set the buttons to these values 229 | self.rotation_edit.setText(str(self.processing.rotation)) 230 | self.scale_edit.setText(str(self.processing.scale)) 231 | self.orientation_edit.addItems(['None','Left->Right','Up->Down','Left->Right & Up->Down']) 232 | self.orientation_edit.setCurrentIndex(1) # default to Left->Right 233 | 234 | # initialize camera 235 | # self.Tis=params['Tis'] 236 | # self.Tis.Start_pipeline() # Start the pipeline so the camera streams 237 | 238 | # plt.ion() 239 | # while True: 240 | # if self.Tis.Snap_image(1) is True: # Snap an image with one second timeout 241 | # image = self.Tis.Get_image() # Get the image. 
It is a numpy array 242 | # plt.figure(1) 243 | # plt.imshow(np.squeeze(image)) 244 | # plt.pause(0.1) 245 | # print("hello?") 246 | # self.Tis.Stop_pipeline() 247 | # exit() 248 | # im=None 249 | im=None 250 | with Vimba.get_instance(): 251 | with get_camera(None) as cam: 252 | setup_camera(cam) 253 | im=np.squeeze( cam.get_frame().as_numpy_ndarray() ) 254 | 255 | # im=self.retrieve_raw_img() 256 | # self.Tis.Stop_pipeline() 257 | 258 | experimental_params = {} 259 | experimental_params['pixel_size'] = params['pixel_size'] # [meters] with 2x2 binning 260 | experimental_params['z_distance'] = params['z_distance'] # [meters] distance from camera 261 | experimental_params['wavelength'] = params['wavelength'] #[meters] wavelength 262 | self.getMeasuredDiffractionPattern = GetMeasuredDiffractionPattern(N_sim=256, 263 | N_meas=np.shape(im)[0], # for calculating the measured frequency axis (not really needed) 264 | experimental_params=experimental_params) 265 | 266 | # state of UI 267 | self.running=False 268 | self.plot_RE_IM=False 269 | 270 | 271 | 272 | 273 | 274 | self.show() 275 | sys.exit(app.exec_()) 276 | 277 | def __del__(self): 278 | pass 279 | # self.inst.close(); 280 | # self.cam.close(); 281 | # cleanup camera 282 | # self.Tis.Stop_pipeline() 283 | 284 | 285 | def textchanged(self): 286 | print("the text was changed") 287 | 288 | def ProcessingUpdated(self): 289 | 290 | try: 291 | new_rotation = float(self.rotation_edit.text()) 292 | new_scale = float(self.scale_edit.text()) 293 | new_orientation = self.orientation_edit.currentText() 294 | if new_scale <= 0: 295 | raise ValueError("scale must be greater than 0") 296 | 297 | self.update_processing_values(new_rotation,new_scale,new_orientation) 298 | 299 | except Exception as e: 300 | print(e) 301 | pass 302 | 303 | def update_processing_values(self,new_rotation,new_scale,new_orientation): 304 | self.processing.orientation=new_orientation 305 | self.processing.rotation=new_rotation 306 | 
self.processing.scale=new_scale 307 | 308 | def Start_Stop_Clicked(self): 309 | if not self.running: 310 | self.running=True 311 | self.pushButton.setText("Stop") 312 | self.run_retrieval() 313 | 314 | if self.running: 315 | self.running=False 316 | self.pushButton.setText("Start") 317 | 318 | def TogglePlotRE_IM(self): 319 | 320 | if self.plot_RE_IM == True: 321 | self.plot_RE_IM=False 322 | self.display_phase_imag_draw["plot"].setTitle('Phase',color=self.COLORGREEN) 323 | self.display_intens_real_draw["plot"].setTitle('Intensity',color=self.COLORGREEN) 324 | self.view_toggle.setText("Real/Imag") 325 | 326 | elif self.plot_RE_IM == False: 327 | self.plot_RE_IM=True 328 | self.display_phase_imag_draw["plot"].setTitle('Imaginary',color=self.COLORGREEN) 329 | self.display_intens_real_draw["plot"].setTitle('Real',color=self.COLORGREEN) 330 | self.view_toggle.setText("Phase/\nIntensity") 331 | 332 | def run_retrieval(self): 333 | 334 | first_image=True 335 | with Vimba.get_instance(): 336 | with get_camera(None) as cam: 337 | setup_camera(cam) 338 | while self.running: 339 | time1=time.time() 340 | QtCore.QCoreApplication.processEvents() 341 | 342 | # grab raw image 343 | # im = self.retrieve_raw_img() 344 | # process image 345 | im=np.squeeze( cam.get_frame().as_numpy_ndarray() ) 346 | # im=np.random.rand(500*500).reshape(500,500) 347 | 348 | transform={} 349 | transform["rotation_angle"]=self.processing.rotation 350 | transform["scale"]=self.processing.scale 351 | if self.processing.orientation == "None": 352 | transform["flip"]=None 353 | 354 | elif self.processing.orientation == "Left->Right": 355 | transform["flip"]="lr" 356 | 357 | elif self.processing.orientation == "Up->Down": 358 | transform["flip"]="ud" 359 | 360 | elif self.processing.orientation == "Left->Right & Up->Down": 361 | transform["flip"]="lrud" 362 | im_p = self.getMeasuredDiffractionPattern.format_measured_diffraction_pattern(im, transform) 363 | 364 | # input through neural network 365 | 
print("input through net:") 366 | time_a=time.time() 367 | out_recons = self.network.sess.run( self.network.nn_nodes["recons_diffraction_pattern"], feed_dict={self.network.x:im_p}) 368 | out_real = self.network.sess.run( self.network.nn_nodes["real_out"], feed_dict={self.network.x:im_p}) 369 | out_imag = self.network.sess.run( self.network.nn_nodes["imag_out"], feed_dict={self.network.x:im_p}) 370 | time_b=time.time() 371 | print(time_b-time_a) 372 | 373 | out_real=np.squeeze(out_real) 374 | out_imag=np.squeeze(out_imag) 375 | out_recons=np.squeeze(out_recons) 376 | 377 | # calculate the intensity 378 | complex_obj = out_real + 1j * out_imag 379 | I = np.abs(complex_obj)**2 380 | m_index = unravel_index(I.argmax(), I.shape) 381 | phase_Imax = np.angle(complex_obj[m_index[0], m_index[1]]) 382 | complex_obj *= np.exp(-1j * phase_Imax) 383 | obj_phase = np.angle(complex_obj) 384 | 385 | # not using the amplitude_mask, use the absolute value of the intensity 386 | nonzero_intensity = np.array(np.abs(complex_obj)) 387 | nonzero_intensity[nonzero_intensity < 0.01*np.max(nonzero_intensity)] = 0 388 | nonzero_intensity[nonzero_intensity >= 0.01*np.max(nonzero_intensity)] = 1 389 | obj_phase *= nonzero_intensity 390 | 391 | im_p=np.squeeze(im_p) 392 | 393 | # grab image with orientation, rotation, scale settings 394 | print("self.processing.orientation =>", self.processing.orientation) 395 | print("self.processing.rotation =>", self.processing.rotation) 396 | print("self.processing.scale =>", self.processing.scale) 397 | 398 | self.display_raw_draw.setImage(im, 399 | autoRange=(True if first_image else False),autoLevels=(True if first_image else False), 400 | ) 401 | 402 | self.display_proc_draw.setImage(im_p, 403 | autoRange=(True if first_image else False),autoLevels=(True if first_image else False), 404 | ) 405 | self.display_intens_real_draw.setImage(I) 406 | self.display_phase_imag_draw.setImage(obj_phase) 407 | 408 | self.display_recons_draw.setImage(out_recons, 409 | 
autoRange=(True if first_image else False),autoLevels=(True if first_image else False), 410 | ) 411 | first_image=False 412 | 413 | time2=time.time() 414 | print("total time:") 415 | print(time2-time1) 416 | 417 | 418 | def retrieve_raw_img(self)->np.array: 419 | 420 | # with open("sample.p", "rb") as file: 421 | # obj = pickle.load(file) 422 | # obj=np.pad(obj,pad_width=300,mode="constant",constant_values=0) 423 | # x=np.linspace(-1,1,500).reshape(1,-1) 424 | # y=np.linspace(-1,1,500).reshape(-1,1) 425 | 426 | # z = np.exp(-x**2 / 0.5) * np.exp(-y**2 / 0.5) 427 | # return obj 428 | # return np.random.rand(500,600) 429 | 430 | # if self.Tis.Snap_image(1) is True: # Snap an image with one second timeout 431 | # im=self.Tis.Get_image() 432 | # im=np.sum(im,axis=2) 433 | # return im 434 | # else: 435 | # return None 436 | # frame=self.cam.get_frame() 437 | # return np.squeeze(frame.to_np_ndarray()) 438 | return None 439 | 440 | 441 | 442 | 443 | 444 | if __name__ == "__main__": 445 | params={} 446 | params['pixel_size']=3.45e-6 # meters 447 | params['z_distance']=16.5e-3 # meter 448 | params['wavelength']=612e-9 449 | params['network']="256_6x6_square_wfs_test" 450 | 451 | # for camera 452 | # self.Tis = TIS.TIS("48710182", 640, 480, 30, False) 453 | # self.Tis = TIS.TIS("48710182", 2592, 2048, 60, False) 454 | # params['Tis']=TIS.TIS("48710182", 2592, 2048, 60, False) 455 | # https://github.com/TheImagingSource/tiscamera 456 | 457 | 458 | mainw = MainWindow(params) 459 | 460 | 461 | 462 | -------------------------------------------------------------------------------- /addnoise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from PIL import Image, ImageDraw 4 | import tables 5 | import os 6 | import diffraction_functions 7 | import argparse 8 | import datagen 9 | 10 | class CameraNoise(): 11 | def __init__(self,imagefile): 12 | print("init camera noise") 13 | 
print("imagefile =>", imagefile) 14 | self.imagefile=imagefile 15 | self.im = Image.open(self.imagefile) 16 | self.im=self.im.convert("L") 17 | self.im=np.array(self.im) 18 | 19 | # flatten the array 20 | self.distribution=self.im.reshape(-1) 21 | avg=np.average(self.distribution) 22 | print("distribution avergae:"+str(avg)) 23 | 24 | if __name__ == "__main__": 25 | parser=argparse.ArgumentParser() 26 | parser.add_argument('--infile',type=str) 27 | parser.add_argument('--outfile',type=str) 28 | parser.add_argument('--peakcount',type=str) 29 | parser.add_argument('--cameraimage',type=str) 30 | args,_=parser.parse_known_args() 31 | 32 | ocameraNoise=CameraNoise(args.cameraimage) 33 | 34 | # print("parser.INFILE =>", args.infile) 35 | print("args.infile =>", args.infile) 36 | print("args.outfile =>", args.outfile) 37 | print("args.peakcount =>", args.peakcount) 38 | print("args.cameraimage =>", args.cameraimage) 39 | 40 | # create new hdf5 file 41 | with tables.open_file(args.infile,mode="r") as hd5file: 42 | # get number of coefficients 43 | n_z_coefs = len(hd5file.root.coefficients[0]) 44 | datagen.create_dataset(args.outfile,coefficients=n_z_coefs) 45 | with tables.open_file(args.outfile,mode="a") as newhd5file: 46 | 47 | N = hd5file.root.N[0,0] 48 | samples = hd5file.root.object_real.shape[0] 49 | print("samples =>", samples) 50 | for _i in range(samples): 51 | 52 | if _i % 100 == 0: 53 | print("add noise to sample:"+str(_i)) 54 | 55 | object_real=hd5file.root.object_real[_i, :].reshape(N,N) 56 | object_imag=hd5file.root.object_imag[_i, :].reshape(N,N) 57 | diffraction=hd5file.root.diffraction_noisefree[_i, :].reshape(N,N) 58 | _z_coefs = hd5file.root.coefficients[_i, :] 59 | _scales = hd5file.root.scale[_i,:] 60 | 61 | 62 | peakcounts = [int(_s) for _s in args.peakcount.split(",")] 63 | if peakcounts[0]==0 and len(peakcounts)==1: 64 | newhd5file.root.object_real.append(object_real.reshape(1,-1)) 65 | newhd5file.root.object_imag.append(object_imag.reshape(1,-1)) 
66 | newhd5file.root.diffraction_noise.append(diffraction.reshape(1,-1)) 67 | newhd5file.root.diffraction_noisefree.append(diffraction.reshape(1,-1)) 68 | newhd5file.root.coefficients.append(_z_coefs.reshape(1,-1)) 69 | newhd5file.root.scale.append(_scales.reshape(1,-1)) 70 | else: 71 | for _pc in peakcounts: 72 | # apply poisson noise 73 | scalar=_pc/np.max(diffraction) 74 | diffraction_pattern_with_noise_poisson=diffraction*scalar 75 | diffraction_pattern_with_noise_poisson=np.random.poisson(diffraction_pattern_with_noise_poisson) 76 | 77 | # draw from random sample 78 | # apply camera noise 79 | total_sim_size=diffraction.shape[0]*diffraction.shape[1] 80 | camera_noise=np.random.choice(ocameraNoise.distribution,size=total_sim_size) 81 | camera_noise=camera_noise.reshape(diffraction.shape) 82 | 83 | diffraction_pattern_with_noise_poisson_and_camera=diffraction_pattern_with_noise_poisson+camera_noise 84 | # normalize 85 | diffraction_pattern_with_noise_poisson_and_camera=diffraction_pattern_with_noise_poisson_and_camera/np.max(diffraction_pattern_with_noise_poisson_and_camera) 86 | 87 | # center it again 88 | diffraction_pattern_with_noise_poisson_and_camera = diffraction_functions.center_image_at_centroid(diffraction_pattern_with_noise_poisson_and_camera) 89 | diffraction_pattern_with_noise_poisson_and_camera[diffraction_pattern_with_noise_poisson_and_camera<0]=0 90 | 91 | newhd5file.root.object_real.append(object_real.reshape(1,-1)) 92 | newhd5file.root.object_imag.append(object_imag.reshape(1,-1)) 93 | newhd5file.root.diffraction_noise.append(diffraction_pattern_with_noise_poisson_and_camera.reshape(1,-1)) 94 | newhd5file.root.diffraction_noisefree.append(diffraction.reshape(1,-1)) 95 | newhd5file.root.coefficients.append(_z_coefs.reshape(1,-1)) 96 | newhd5file.root.scale.append(_scales.reshape(1,-1)) 97 | 98 | # diffraction=diffraction_pattern_with_noise_poisson_and_camera 99 | 100 | # plt.figure() 101 | # 
plt.pcolormesh(diffraction_pattern_with_noise_poisson_and_camera) 102 | # plt.title("diffraction_pattern_with_noise_poisson_and_camera") 103 | 104 | # plt.figure() 105 | # plt.pcolormesh(diffraction) 106 | # plt.title("diffraction") 107 | 108 | # plt.show() 109 | # # exit() 110 | 111 | 112 | -------------------------------------------------------------------------------- /buildui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # venv/bin/pyuic5 -x main.ui -o main.py 3 | ~/mypython/bin/pyuic5 -x main.ui -o main.py 4 | -------------------------------------------------------------------------------- /datagen.dat: -------------------------------------------------------------------------------- 1 | 1.5 0.0 0.0 0.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 | -------------------------------------------------------------------------------- /datagen.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from params import MaterialParams 3 | import params 4 | import argparse 5 | import tables 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import math 9 | import diffraction_functions 10 | 11 | 12 | # def gaussian_propagate(zernike_polynom:tf.Tensor,scale:tf.Tensor)->tf.Tensor: 13 | 14 | 15 | def tf_make_zernike(m:int, n:int, N_computational:int, scale:tf.Tensor, amp:tf.Tensor)->tf.Tensor: 16 | _scale = tf.expand_dims(scale,axis=-1) 17 | x = np.linspace(-7,7,N_computational).reshape(1,-1,1) 18 | y = np.linspace(-7,7,N_computational).reshape(1,1,-1) 19 | phi=tf.constant(np.arctan2(x,y),dtype=tf.float32) 20 | x = (1/_scale)*tf.constant(x,dtype=tf.float32) 21 | y = (1/_scale)*tf.constant(y,dtype=tf.float32) 22 | # x = scale * x 23 | # y = scale * y 24 | rho = tf.sqrt(x**2 + y**2) 25 | 26 | positive_m = False 27 | if m >= 0: 28 | positive_m=True 29 | m = abs(m) 30 | # summation over k 31 | rho = tf.expand_dims(rho,axis=1) 32 | # for each k 
value 33 | # k = np.linspace(0,(n-m)//2) 34 | k = np.arange(start=0,stop=1+((n-m)//2)) 35 | numerator = (-1)**k 36 | for i in range(len(k)): 37 | numerator[i]*= math.factorial(n-k[i]) 38 | 39 | denominator = np.ones_like(k) 40 | for i in range(len(k)): 41 | denominator[i]*= math.factorial(k[i]) 42 | denominator[i]*= math.factorial(((n+m)/2)-k[i]) 43 | denominator[i]*= math.factorial(((n-m)/2)-k[i]) 44 | 45 | scalar = numerator/denominator 46 | zernike_polynom = scalar.reshape(1,-1,1,1) * rho**(n-2*k.reshape(1,-1,1,1)) 47 | zernike_polynom = tf.reduce_sum(zernike_polynom,axis=1) 48 | 49 | if positive_m: 50 | zernike_polynom *= tf.cos(m * phi) 51 | else: 52 | zernike_polynom *= tf.sin(m * phi) 53 | 54 | # set values outside unit circle to 0 55 | zernike_polynom*=tf.cast(tf.less_equal(tf.squeeze(rho,axis=1),1),dtype=tf.float32) 56 | _amp = tf.expand_dims(amp,axis=-1) 57 | _amp = tf.expand_dims(_amp,axis=-1) 58 | zernike_polynom*=_amp 59 | return tf.expand_dims(zernike_polynom,axis=1) 60 | 61 | def makezernike(m: int,n: int, N_computational: int)->np.array: 62 | x = np.linspace(-7/2,7/2,N_computational).reshape(-1,1) 63 | y = np.linspace(-7/2,7/2,N_computational).reshape(1,-1) 64 | rho = np.sqrt(x**2 + y**2) 65 | phi = np.arctan2(x,y) 66 | 67 | positive_m = False 68 | if m >= 0: 69 | positive_m=True 70 | m = abs(m) 71 | zernike_polynom = np.zeros((N_computational,N_computational)) 72 | k=0 73 | while k <= (n-m)/2: 74 | 75 | numerator = -1 ** k 76 | numerator *= math.factorial(n-k) 77 | 78 | denominator = math.factorial(k) 79 | denominator *= math.factorial(((n+m)/2)-k) 80 | denominator *= math.factorial(((n-m)/2)-k) 81 | 82 | scalar = numerator/denominator 83 | 84 | zernike_polynom+=scalar*rho**(n-2*k) 85 | k+=1 86 | 87 | if positive_m: 88 | zernike_polynom *= np.cos(m * phi) 89 | else: 90 | zernike_polynom *= np.sin(m * phi) 91 | 92 | # set values outside unit circle to 0 93 | zernike_polynom[rho>1]=0 94 | return zernike_polynom 95 | 96 | class Zernike_C(): 97 | def 
__init__(self,m:int,n:int): 98 | self.m=m 99 | self.n=n 100 | 101 | def create_slice(p: MaterialParams, N_interp: int)->np.array: 102 | _, wfs = diffraction_functions.get_amplitude_mask_and_imagesize(N_interp, int(params.params.wf_ratio*N_interp)) 103 | slice = np.zeros((N_interp,N_interp),dtype=np.complex128) 104 | slice[wfs<0.5]=np.exp(-1*p.k * p.beta_Ta * p.dz)*\ 105 | np.exp(-1j*p.k*p.delta_Ta*p.dz) 106 | slice[wfs>=0.5]=1.0 107 | return slice 108 | 109 | class Material(): 110 | mparams=None 111 | steps=None 112 | slice=None 113 | distance=None 114 | def __init__(self,mparams:MaterialParams,N:int): 115 | self.mparams=mparams 116 | self.distance=mparams.distance 117 | self.steps=round(self.distance/self.mparams.dz) 118 | self.slice = create_slice(self.mparams,N) 119 | 120 | 121 | class DataGenerator(): 122 | def __init__(self,N_computational:int,N_interp:int): 123 | self.N_interp=N_interp 124 | self.N_computational=N_computational 125 | # generate zernike coefficients 126 | self.batch_size=4 127 | start_n=1 128 | max_n=4 129 | self.zernike_cvector = [] 130 | 131 | for n in range(start_n,max_n+1): 132 | for m in range(n,-n-2,-2): 133 | self.zernike_cvector.append(Zernike_C(m,n)) 134 | 135 | materials = [Material(mparams=_p,N=N_interp) for _p in params.params.material_params] 136 | self.propagate_tf=PropagateTF(N_interp,materials) 137 | 138 | def buildgraph(self,x:tf.Tensor,scale:tf.Tensor)->(tf.Tensor,tf.Tensor): 139 | 140 | # generate polynomials 141 | zernikes=[] 142 | for i in range(len(self.zernike_cvector)): 143 | _z = self.zernike_cvector[i] 144 | zernikes.append(tf_make_zernike(_z.m,_z.n,self.N_computational,scale,x[:,i])) 145 | 146 | zernikes=tf.concat(zernikes,axis=1) 147 | zernike_polynom = tf.reduce_sum(zernikes,axis=1) 148 | 149 | # propagate through gaussian 150 | x = np.linspace(1,-1,self.N_computational).reshape(1,-1,1) 151 | y = np.linspace(1,-1,self.N_computational).reshape(1,1,-1) 152 | _scale = tf.expand_dims(scale,axis=-1) 153 | x = 
(1/_scale)*tf.constant(x,dtype=tf.float32) 154 | y = (1/_scale)*tf.constant(y,dtype=tf.float32) 155 | width = 0.05 156 | gaussian_amp = tf.exp(-(x**2)/(width**2))*tf.exp(-(y**2)/(width**2)) 157 | 158 | field = tf.complex(real=gaussian_amp,imag=tf.zeros_like(gaussian_amp)) * tf.exp(tf.complex(real=tf.zeros_like(zernike_polynom),imag=zernike_polynom)) 159 | 160 | # fft 161 | field = tf.expand_dims(field,axis=-1) 162 | field_ft=diffraction_functions.tf_fft2(field,dimmensions=[1,2]) 163 | 164 | # crop and interpolate 165 | field_cropped = field_ft[:,(self.N_computational//2)-(self.N_interp//2):(self.N_computational//2)+(self.N_interp//2),(self.N_computational//2)-(self.N_interp//2):(self.N_computational//2)+(self.N_interp//2),:] 166 | 167 | z_center = field_cropped[:,self.N_interp//2,self.N_interp//2,0] 168 | z_center = tf.expand_dims(z_center,-1) 169 | z_center = tf.expand_dims(z_center,-1) 170 | z_center = tf.expand_dims(z_center,-1) 171 | field_cropped = field_cropped * tf.exp(tf.complex(real=0.0,imag=-1.0*tf.angle(z_center))) 172 | 173 | # normalize within wavefront sensor 174 | _, wfs = diffraction_functions.get_amplitude_mask_and_imagesize(self.N_interp, int(params.params.wf_ratio*self.N_interp)) 175 | wfs = np.expand_dims(wfs,0) 176 | wfs = np.expand_dims(wfs,-1) 177 | wfs = tf.constant(wfs,dtype=tf.float32) 178 | norm_factor = tf.reduce_max(wfs*tf.abs(field_cropped),keepdims=True,axis=[1,2]) 179 | field_cropped = field_cropped / tf.complex(real=norm_factor,imag=tf.zeros_like(norm_factor)) 180 | # return this as the field before wfs 181 | 182 | # propagator through wavefront sensor 183 | return field_cropped 184 | 185 | def propagate_through_wfs(self,field:tf.Tensor): 186 | return self.propagate_tf.setup_graph_through_wfs(field) 187 | 188 | class PropagateTF(): 189 | def __init__(self, N_interp:int, materials:list): 190 | 191 | self.materials=materials 192 | 193 | measured_axes, _ = diffraction_functions.get_amplitude_mask_and_imagesize(N_interp, 
int(params.params.wf_ratio*N_interp)) 194 | self.wf_f = measured_axes["diffraction_plane"]["f"] 195 | 196 | def setup_graph_through_wfs(self, wavefront): 197 | 198 | wavefront_ref=wavefront 199 | for _material in self.materials: 200 | for _ in range(_material.steps): 201 | wavefront_ref=forward_propagate(wavefront_ref,_material.slice,self.wf_f,_material.mparams) 202 | return wavefront_ref 203 | 204 | def forward_propagate(E,slice,f,p): 205 | slice=np.expand_dims(slice,0) 206 | slice=np.expand_dims(slice,3) 207 | 208 | E*=slice 209 | E=diffraction_functions.tf_fft2(E,dimmensions=[1,2]) 210 | gamma1=tf.constant( 211 | 1- 212 | (p.lam*f.reshape(-1,1))**2- 213 | (p.lam*f.reshape(1,-1))**2, 214 | dtype=tf.complex64 215 | ) 216 | gamma=tf.sqrt( 217 | gamma1 218 | ) 219 | k_sq = 2 * np.pi * p.dz / p.lam 220 | H = tf.exp( 221 | tf.complex( 222 | real=tf.zeros_like(tf.real(gamma)*k_sq), 223 | imag=tf.real(gamma)*k_sq 224 | ) 225 | )*tf.exp( 226 | tf.complex( 227 | real=-1*tf.imag(gamma)*k_sq, 228 | imag=tf.zeros_like(tf.imag(gamma)*k_sq) 229 | ) 230 | ) 231 | H = tf.expand_dims(H,0) 232 | H = tf.expand_dims(H,-1) 233 | E*=H 234 | E=diffraction_functions.tf_ifft2(E,dimmensions=[1,2]) 235 | return E 236 | 237 | def create_dataset(filename:str, coefficients:int): 238 | 239 | print("called create_dataset") 240 | print(filename) 241 | N = 256 242 | with tables.open_file(filename, "w") as hdf5file: 243 | 244 | # create array for the object 245 | hdf5file.create_earray(hdf5file.root, "object_real", tables.Float32Atom(), shape=(0,N*N)) 246 | 247 | # create array for the object phase 248 | hdf5file.create_earray(hdf5file.root, "object_imag", tables.Float32Atom(), shape=(0,N*N)) 249 | 250 | # create array for the image 251 | hdf5file.create_earray(hdf5file.root, "diffraction_noise", tables.Float32Atom(), shape=(0,N*N)) 252 | 253 | # create array for the image 254 | hdf5file.create_earray(hdf5file.root, "diffraction_noisefree", tables.Float32Atom(), shape=(0,N*N)) 255 | 256 | # scale 
257 | hdf5file.create_earray(hdf5file.root, "scale", tables.Float32Atom(), shape=(0,1)) 258 | 259 | # zernike coefficients 260 | hdf5file.create_earray(hdf5file.root, "coefficients", tables.Float32Atom(), shape=(0,coefficients)) 261 | 262 | hdf5file.create_earray(hdf5file.root, "N", tables.Int32Atom(), shape=(0,1)) 263 | 264 | hdf5file.close() 265 | 266 | with tables.open_file(filename, mode='a') as hd5file: 267 | # save the dimmensions of the data 268 | hd5file.root.N.append(np.array([[N]])) 269 | 270 | def save_to_hdf5(filename:str, afterwf:np.array, beforewf:np.array, z_coefs:np.array, scales:np.array): 271 | with tables.open_file(filename, mode='a') as hd5file: 272 | for i in range(np.shape(beforewf)[0]): 273 | object_real = np.real(beforewf[i,:,:]) 274 | object_imag = np.imag(beforewf[i,:,:]) 275 | diffraction_pattern_noisefree = np.abs(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(afterwf[i,:,:]))))**2 276 | _z_coefs = z_coefs[i,:] 277 | _scales = scales[i] 278 | 279 | 280 | # normalize 281 | diffraction_pattern_noisefree = diffraction_pattern_noisefree / np.max(diffraction_pattern_noisefree) 282 | diffraction_pattern_noisefree = diffraction_functions.center_image_at_centroid(diffraction_pattern_noisefree) 283 | diffraction_pattern_noisefree[diffraction_pattern_noisefree<0]=0 284 | hd5file.root.object_real.append(object_real.reshape(1,-1)) 285 | hd5file.root.object_imag.append(object_imag.reshape(1,-1)) 286 | hd5file.root.diffraction_noisefree.append(diffraction_pattern_noisefree.reshape(1,-1)) 287 | hd5file.root.coefficients.append(_z_coefs.reshape(1,-1)) 288 | hd5file.root.scale.append(_scales.reshape(1,-1)) 289 | 290 | print("calling flush") 291 | hd5file.flush() 292 | 293 | if __name__ == "__main__": 294 | 295 | parser=argparse.ArgumentParser() 296 | parser.add_argument('--count',type=int) 297 | parser.add_argument('--seed',type=int) 298 | parser.add_argument('--name',type=str) 299 | parser.add_argument('--batch_size',type=int) 300 | 
parser.add_argument('--samplesf',type=str) 301 | args,_=parser.parse_known_args() 302 | 303 | datagenerator = DataGenerator(1024,256) 304 | 305 | x = tf.placeholder(tf.float32, shape=[None, len(datagenerator.zernike_cvector)]) 306 | scale = tf.placeholder(tf.float32, shape=[None,1]) 307 | beforewf=datagenerator.buildgraph(x,scale) 308 | afterwf=datagenerator.propagate_through_wfs(beforewf) 309 | 310 | create_dataset(filename=args.name,coefficients=len(datagenerator.zernike_cvector)) 311 | if args.samplesf: 312 | print("make specific samples") 313 | with tf.Session() as sess: 314 | print("args.samplesf =>", args.samplesf) 315 | samplesf=np.loadtxt(args.samplesf) 316 | if len(np.shape(samplesf))==1: samplesf=samplesf.reshape(1,-1) 317 | if(np.shape(samplesf)[1]!=len(datagenerator.zernike_cvector)+1): 318 | raise ValueError('incorrect dimmensions in samples file: z coefs:'+str(len(datagenerator.zernike_cvector)) + " + 1 (scale)") 319 | with tf.Session() as sess: 320 | for _s in samplesf: 321 | print(" generating sample: _s =>", _s) 322 | z_coefs=_s[1:].reshape(1,-1) 323 | scales=_s[0].reshape(1,-1) 324 | f={x:z_coefs,scale:scales} 325 | _afterwf=sess.run(afterwf,feed_dict=f) 326 | _beforewf=sess.run(beforewf,feed_dict=f) 327 | save_to_hdf5( 328 | args.name, 329 | np.expand_dims(np.squeeze(_afterwf),0), 330 | np.expand_dims(np.squeeze(_beforewf),0), 331 | np.expand_dims(np.squeeze(z_coefs),0), 332 | np.expand_dims(np.squeeze(scales),0) 333 | ) 334 | else: 335 | if args.count % args.batch_size != 0: 336 | raise ValueError('batch size and count divide with remainder') 337 | with tf.Session() as sess: 338 | np.random.seed(args.seed) 339 | _count = 0 340 | while _count", _count) 342 | # make random numbers 343 | 344 | # for the zernike coefs 345 | n_z_coefs=len(datagenerator.zernike_cvector)* args.batch_size 346 | # for the scales 347 | n_scales=args.batch_size 348 | 349 | z_coefs = 12*(np.random.rand(n_z_coefs)-0.5) 350 | z_coefs=z_coefs.reshape(args.batch_size,-1) 351 | 
scales = 1+1*(np.random.rand(n_scales)-0.5) 352 | scales = scales.reshape(args.batch_size,1) 353 | # z_coefs[:,0:3]=0 354 | # z_coefs[:,9:]=0 # doesnt work 355 | # z_coefs[:,8:]=0 # works 356 | f={x: z_coefs, 357 | scale:scales 358 | } 359 | _afterwf=sess.run(afterwf,feed_dict=f) 360 | _beforewf=sess.run(beforewf,feed_dict=f) 361 | save_to_hdf5( 362 | args.name, 363 | np.squeeze(_afterwf), 364 | np.squeeze(_beforewf), 365 | np.squeeze(z_coefs), 366 | np.squeeze(scales) 367 | ) 368 | # plot data 369 | # for i in range(2): 370 | # fig,ax=plt.subplots(1,2,figsize=(10,5)) 371 | # im=ax[0].imshow(np.abs(_beforewf[i,:,:,0])**2,cmap='jet') 372 | # ax[0].set_title("intensity") 373 | # fig.colorbar(im,ax=ax[0]) 374 | # im=ax[1].imshow(np.angle(_beforewf[i,:,:,0]),cmap='jet') 375 | # ax[1].set_title("angle") 376 | # fig.colorbar(im,ax=ax[1]) 377 | # plt.show() 378 | _count += args.batch_size 379 | -------------------------------------------------------------------------------- /diffraction_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import params 3 | from numpy import unravel_index 4 | import os 5 | import math 6 | from PIL import Image, ImageDraw 7 | from PIL import ImagePath 8 | import tensorflow as tf 9 | import PIL.ImageOps 10 | import matplotlib.pyplot as plt 11 | from scipy.ndimage.interpolation import rotate 12 | from scipy.ndimage.filters import gaussian_filter 13 | from scipy.ndimage.interpolation import shift as sc_shift 14 | from astropy.io import fits 15 | from scipy.ndimage import shift as sc_im_shift 16 | from scipy.misc import factorial 17 | from skimage.transform import resize 18 | # import cv2 19 | from scipy import ndimage 20 | import scipy.io 21 | import params 22 | import datagen 23 | 24 | def fits_to_numpy(fits_file_name): 25 | thing = fits.open(fits_file_name) 26 | nparr = thing[0].data[0,:,:] 27 | nparr = nparr.astype(np.float64) 28 | return nparr 29 | 30 | 31 | def 
plot_amplitude_phase_meas_retreival(retrieved_obj, title, plot_spherical_aperture=False,ACTUAL=False,m_index=None,mask=False): 32 | 33 | if ACTUAL: 34 | RETRIEVED="ACTUAL" 35 | RECONSTRUCTED="ACTUAL" 36 | else: 37 | RETRIEVED="retrieved" 38 | RECONSTRUCTED="reconstructed" 39 | 40 | # coefficients and scale 41 | if 'coefficients' in retrieved_obj.keys(): 42 | retrieved_obj['coefficients'] = np.squeeze(retrieved_obj['coefficients']) 43 | assert len(np.shape(retrieved_obj['coefficients']))==1 44 | retrieved_obj['scale'] = np.squeeze(retrieved_obj['scale']) 45 | assert len(np.shape(retrieved_obj['scale']))==0 46 | 47 | # get axes for retrieved object and diffraction pattern 48 | N=np.shape(np.squeeze(retrieved_obj['measured_pattern']))[0] 49 | simulation_axes, amplitude_mask = get_amplitude_mask_and_imagesize(N, int(params.params.wf_ratio*N)) 50 | 51 | # normalize complex retrieved 52 | retrieved_complex_normalized = retrieved_obj['real_output']+1j*retrieved_obj['imag_output'] 53 | retrieved_complex_normalized*=1/(np.max(np.abs(retrieved_complex_normalized))) 54 | retrieved_obj['real_output']=np.real(retrieved_complex_normalized) 55 | retrieved_obj['imag_output']=np.imag(retrieved_complex_normalized) 56 | 57 | # object 58 | x=simulation_axes['object']['x'] # meters 59 | x*=1e6 60 | f=simulation_axes['diffraction_plane']['f'] # 1/meters 61 | f*=1e-6 62 | 63 | fig = plt.figure(figsize=(10,10)) 64 | # fig.subplots_adjust(wspace=0.5, left=0.5, top=0.95, bottom=0.10) 65 | fig.subplots_adjust(wspace=0.1, left=0.3) 66 | gs = fig.add_gridspec(4,2) 67 | fig.text(0.5, 0.95, title, ha="center", size=15) 68 | 69 | # run the constructor to get z n, m vector 70 | datagenerator = datagen.DataGenerator(1024,256) 71 | if 'coefficients' in retrieved_obj.keys(): 72 | fig.text(0.05,0.9,'Zernike Coefficients:',size=20,color='red') 73 | c_str="" 74 | for _c, _z in zip(retrieved_obj['coefficients'],datagenerator.zernike_cvector): 75 | c_str += r"$Z^{"+str(_z.m)+"}_{"+str(_z.n)+"}$" 76 | 
c_str+=" " 77 | c_str+="%.2f"%_c+'\n' 78 | fig.text(0.03,0.85,c_str,ha='left',va='top',size=20) 79 | fig.text(0.05,0.15,'Scale:',size=20,color='red') 80 | fig.text(0.03,0.10,'S:'+"%.2f"%retrieved_obj['scale'],ha='left',va='top',size=20) 81 | 82 | 83 | axes = {} 84 | axes["measured"] = fig.add_subplot(gs[0,0]) 85 | axes["reconstructed"] = fig.add_subplot(gs[0,1]) 86 | 87 | axes["real"] = fig.add_subplot(gs[1,0]) 88 | axes["imag"] = fig.add_subplot(gs[1,1]) 89 | 90 | axes["intensity"] = fig.add_subplot(gs[2,0]) 91 | axes["phase"] = fig.add_subplot(gs[2,1]) 92 | 93 | axes["phase_vertical"] = fig.add_subplot(gs[3,0]) 94 | axes["phase_horizontal"] = fig.add_subplot(gs[3,1]) 95 | 96 | # calculate the intensity 97 | complex_obj = np.squeeze(retrieved_obj["real_output"]) + 1j * np.squeeze(retrieved_obj["imag_output"]) 98 | 99 | I = np.abs(complex_obj)**2 100 | 101 | # calculate the phase 102 | # subtract phase at intensity peak 103 | if not m_index: 104 | m_index = unravel_index(I.argmax(), I.shape) 105 | 106 | phase_Imax = np.angle(complex_obj[m_index[0], m_index[1]]) 107 | complex_obj *= np.exp(-1j * phase_Imax) 108 | 109 | obj_phase = np.angle(complex_obj) 110 | 111 | # not using the amplitude_mask, use the absolute value of the intensity 112 | nonzero_intensity = np.array(np.abs(complex_obj)) 113 | nonzero_intensity[nonzero_intensity < 0.01*np.max(nonzero_intensity)] = 0 114 | nonzero_intensity[nonzero_intensity >= 0.01*np.max(nonzero_intensity)] = 1 115 | obj_phase *= nonzero_intensity 116 | 117 | 118 | 119 | # for testing 120 | # obj_phase[10:20, :] = np.max(obj_phase) 121 | # obj_phase[:, 10:20] = np.max(obj_phase) 122 | # obj_phase[:, -30:-20] = np.max(obj_phase) 123 | 124 | if mask: 125 | obj_phase_softmask=np.array(obj_phase) 126 | obj_phase_softmask[amplitude_mask>0]*=0.7 127 | im = axes["phase"].pcolormesh(x,x,obj_phase_softmask,vmin=-np.pi,vmax=np.pi,cmap='jet') 128 | else: 129 | im = axes["phase"].pcolormesh(x,x,obj_phase,vmin=-np.pi,vmax=np.pi,cmap='jet') 
130 | 131 | axes["phase"].text(0.2, 0.9,"phase("+RETRIEVED+")", fontsize=10, ha='center', transform=axes["phase"].transAxes, backgroundcolor="cyan") 132 | fig.colorbar(im, ax=axes["phase"]) 133 | axes["phase"].axvline(x=x[m_index[1]], color="red", alpha=0.8) 134 | axes["phase"].axhline(y=x[m_index[0]], color="blue", alpha=0.8) 135 | 136 | axes["phase_horizontal"].plot(x,obj_phase[m_index[0], :], color="blue") 137 | axes["phase_horizontal"].set_ylim(-np.pi,np.pi) 138 | axes["phase_horizontal"].text(0.2, -0.25,"phase(horizontal)", fontsize=10, ha='center', transform=axes["phase_horizontal"].transAxes, backgroundcolor="blue") 139 | 140 | axes["phase_vertical"].plot(x,obj_phase[:, m_index[1]], color="red") 141 | axes["phase_vertical"].set_ylim(-np.pi,np.pi) 142 | axes["phase_vertical"].text(0.2, -0.25,"phase(vertical)", fontsize=10, ha='center', transform=axes["phase_vertical"].transAxes, backgroundcolor="red") 143 | 144 | 145 | if mask: 146 | I_softmask = np.array(I) 147 | I_softmask[amplitude_mask>0]*=0.7 148 | im = axes["intensity"].pcolormesh(x,x,I_softmask,vmin=0.0,vmax=1.0,cmap='jet') 149 | else: 150 | im = axes["intensity"].pcolormesh(x,x,I,vmin=0,vmax=1.0,cmap='jet') 151 | 152 | # plot the spherical aperture 153 | if plot_spherical_aperture: 154 | # circle to show where the wavefront originates 155 | circle=plt.Circle((0,0),2.7,color='r',fill=False,linewidth=2.0) 156 | axes["intensity"].add_artist(circle) 157 | axes["intensity"].text(0.8, 0.7,"Spherical\nAperture\n2.7 um", fontsize=10, ha='center', transform=axes["intensity"].transAxes,color="red") 158 | 159 | axes["intensity"].text(0.2, 0.9,"intensity("+RETRIEVED+")", fontsize=10, ha='center', transform=axes["intensity"].transAxes, backgroundcolor="cyan") 160 | axes["intensity"].set_ylabel("position [um]") 161 | fig.colorbar(im, ax=axes["intensity"]) 162 | 163 | 164 | # plt.figure(13); plt.pcolormesh(np.squeeze(retrieved_obj["measured_pattern"]));plt.savefig('test.png') 165 | # import ipdb; ipdb.set_trace() # 
BREAKPOINT 166 | # print("BREAKPOINT") 167 | 168 | im = axes["measured"].pcolormesh(f,f,np.log(np.squeeze(retrieved_obj["measured_pattern"])),vmin=-10,vmax=0,cmap='jet') 169 | # axes["measured"].set_ylim(-0.25,0.25);axes["measured"].set_xlim(-0.25,0.25) 170 | axes["measured"].set_ylabel(r"frequency [1/m]$\cdot 10^{6}$") 171 | axes["measured"].text(0.2, 0.9,"measured", fontsize=10, ha='center', transform=axes["measured"].transAxes, backgroundcolor="cyan") 172 | fig.colorbar(im, ax=axes["measured"]) 173 | 174 | im = axes["reconstructed"].pcolormesh(f,f,np.log(np.squeeze(retrieved_obj["tf_reconstructed_diff"])),vmin=-10,vmax=0,cmap='jet') 175 | # axes["reconstructed"].set_ylim(-0.25,0.25);axes["reconstructed"].set_xlim(-0.25,0.25) 176 | axes["reconstructed"].text(0.2, 0.9,RECONSTRUCTED, fontsize=10, ha='center', transform=axes["reconstructed"].transAxes, backgroundcolor="cyan") 177 | 178 | # calc mse 179 | A = retrieved_obj["measured_pattern"].reshape(-1) 180 | B = retrieved_obj["tf_reconstructed_diff"].reshape(-1) 181 | mse = (np.square(A-B)).mean() 182 | mse = str(mse) 183 | axes["reconstructed"].text(0.2, 1.1,"mse("+RECONSTRUCTED+", measured): "+mse, fontsize=10, ha='center', transform=axes["reconstructed"].transAxes, backgroundcolor="cyan") 184 | 185 | fig.colorbar(im, ax=axes["reconstructed"]) 186 | 187 | if mask: 188 | real_output_softmask=np.array(np.squeeze(retrieved_obj["real_output"])) 189 | real_output_softmask[amplitude_mask>0]*=0.7 190 | im = axes["real"].pcolormesh(x,x,real_output_softmask,vmin=-1.0,vmax=1.0,cmap='jet') 191 | else: 192 | im = axes["real"].pcolormesh(x,x,np.squeeze(retrieved_obj["real_output"]),vmin=-1.0,vmax=1.0,cmap='jet') 193 | axes["real"].text(0.2, 0.9,"real("+RETRIEVED+")", fontsize=10, ha='center', transform=axes["real"].transAxes, backgroundcolor="cyan") 194 | axes["real"].set_ylabel("position [um]") 195 | fig.colorbar(im, ax=axes["real"]) 196 | 197 | if mask: 198 | 
imag_output_softmask=np.array(np.squeeze(retrieved_obj["imag_output"])) 199 | imag_output_softmask[amplitude_mask>0]*=0.7 200 | im = axes["imag"].pcolormesh(x,x,imag_output_softmask,vmin=-1.0,vmax=1.0,cmap='jet') 201 | else: 202 | im = axes["imag"].pcolormesh(x,x,np.squeeze(retrieved_obj["imag_output"]),vmin=-1.0,vmax=1.0,cmap='jet') 203 | axes["imag"].text(0.2, 0.9,"imag("+RETRIEVED+")", fontsize=10, ha='center', transform=axes["imag"].transAxes, backgroundcolor="cyan") 204 | fig.colorbar(im, ax=axes["imag"]) 205 | 206 | return fig 207 | 208 | def plot_image_show_centroid_distance(mat, title, figurenum): 209 | """ 210 | plots an image and shows the distance from the centroid to the image center 211 | """ 212 | # calculate the current centroid 213 | c_y = calc_centroid(mat, axis=0) 214 | c_x = calc_centroid(mat, axis=1) 215 | plt.figure(figurenum) 216 | plt.imshow(mat) 217 | plt.axvline(x=c_x, color="yellow") 218 | plt.axhline(y=c_y, color="yellow") 219 | plt.title(title) 220 | plt.gca().text(0.0, 0.9,"c_x: {}".format( c_x ), fontsize=10, ha='center', transform=plt.gca().transAxes, backgroundcolor="yellow") 221 | plt.gca().text(0.0, 0.8,"c_y: {}".format( c_y ), fontsize=10, ha='center', transform=plt.gca().transAxes, backgroundcolor="yellow") 222 | 223 | plt.gca().text(0.0, 0.7,"image center x: {}".format( np.shape(mat)[0] / 2 ), fontsize=10, ha='center', transform=plt.gca().transAxes, backgroundcolor="green") 224 | plt.gca().text(0.0, 0.6,"image center y: {}".format( np.shape(mat)[0] / 2 ), fontsize=10, ha='center', transform=plt.gca().transAxes, backgroundcolor="green") 225 | 226 | # center of the image 227 | plt.axvline(x=int(np.shape(mat)[0] / 2), color="green") 228 | plt.axhline(y=int(np.shape(mat)[0] / 2), color="green") 229 | 230 | def format_experimental_trace(N, df_ratio, measured_diffraction_pattern, rotation_angle, trim): 231 | """ 232 | N: the desired size of the formatted experimental image 233 | df_ratio: (df calculated from diffraction plane) / (df 
calculated from object plane) 234 | measured_diffraction_pattern: the measured diffraction pattern to format 235 | rotation_angle: angle which to rotate the measured diffraction pattern (currently must be done by eye) 236 | """ 237 | # divide scale the measured trace by this amount 238 | new_size = np.shape(measured_diffraction_pattern)[0] * df_ratio 239 | new_size = int(new_size) 240 | 241 | # if the image is not square, make it square 242 | # print("shape before ",np.shape(measured_diffraction_pattern)) 243 | s=np.shape(measured_diffraction_pattern) 244 | if s[0]>s[1]: 245 | # less rows than cols 246 | _s=s[1] # number of cols 247 | measured_diffraction_pattern=measured_diffraction_pattern[(s[0]//2)-(_s//2):(s[0]//2)+(_s//2),:] 248 | elif s[0]N: 269 | measured_diffraction_pattern = measured_diffraction_pattern[int((new_size/2) - (N/2)):int((new_size/2) + (N/2)) , 270 | int((new_size/2) - (N/2)):int((new_size/2) + (N/2))] 271 | elif new_size ",np.shape(measured_pattern)) 343 | N = np.shape(measured_pattern)[0] 344 | 345 | # calculate delta frequency 346 | measured_axes = {} 347 | measured_axes["diffraction_plane"] = {} 348 | 349 | measured_axes["diffraction_plane"]["xmax"] = N * (experimental_params['pixel_size'] / 2) 350 | measured_axes["diffraction_plane"]["x"] = np.arange(-(measured_axes["diffraction_plane"]["xmax"]), (measured_axes["diffraction_plane"]["xmax"]), experimental_params['pixel_size']) 351 | 352 | measured_axes["diffraction_plane"]["f"] = measured_axes["diffraction_plane"]["x"] / (experimental_params['wavelength'] * experimental_params['z_distance']) 353 | measured_axes["diffraction_plane"]["df"] = measured_axes["diffraction_plane"]["f"][-1] - measured_axes["diffraction_plane"]["f"][-2] 354 | 355 | return measured_axes, measured_pattern 356 | 357 | 358 | def halve_imag(im): 359 | dim = im.size[0] 360 | left=int(dim/2)-int(dim/4) 361 | right=int(dim/2)+int(dim/4) 362 | top=int(dim/2)-int(dim/4) 363 | bottom=int(dim/2)+int(dim/4) 364 | 
im=im.crop((left,top,right,bottom)) 365 | return im 366 | 367 | def get_amplitude_mask_and_imagesize2(image_dimmension, desired_mask_width): 368 | 369 | # image_dimmension must be divisible by 4 370 | assert image_dimmension/4 == int(image_dimmension/4) 371 | # get the png image for amplitude 372 | im = Image.open("siemens_star_2048_psize_7nm.png") 373 | size1 = im.size[0] 374 | im_size_nm = 7*im.size[0] * 1e-9 # meters 375 | 376 | # zoom into image 377 | im = halve_imag(im) 378 | im_size_nm*=0.5 379 | im = halve_imag(im) 380 | im_size_nm*=0.5 381 | im = halve_imag(im) 382 | im_size_nm*=0.5 383 | # print(im.size) 384 | # exit() 385 | # im = PIL.ImageOps.invert(im) 386 | #TODO crop the image and adjust the image size 387 | 388 | # print("im_size_nm =>", im_size_nm) 389 | 390 | # scale down the image 391 | im = im.resize((desired_mask_width,desired_mask_width)) 392 | size2 = im.size[0] 393 | im = np.array(im)[:,:,1] 394 | 395 | # im=np.invert(im)# this is not physical 396 | 397 | 398 | # # # # # # # # # # # # # 399 | # # # pad the image # # # 400 | # # # # # # # # # # # # # 401 | 402 | # determine width of mask 403 | pad_amount = int((image_dimmension - desired_mask_width)/2) 404 | amplitude_mask = np.pad(im, pad_width=pad_amount, mode="constant", constant_values=0) 405 | amplitude_mask = amplitude_mask.astype(np.float64) 406 | amplitude_mask *= 1/np.max(amplitude_mask) # normalize 407 | assert amplitude_mask.shape[0] == image_dimmension 408 | size3 = np.shape(amplitude_mask)[0] 409 | ratio = size3 / size2 410 | # print("ratio =>", ratio) 411 | im_size_nm *= ratio # object image size [m] 412 | 413 | measured_axes = {} 414 | measured_axes["object"] = {} 415 | measured_axes["object"]["dx"] = im_size_nm / image_dimmension 416 | measured_axes["object"]["xmax"] = im_size_nm/2 417 | measured_axes["object"]["x"] = np.arange(-(measured_axes["object"]["xmax"]), (measured_axes["object"]["xmax"]), measured_axes["object"]["dx"]) 418 | measured_axes["diffraction_plane"] = {} 419 
| measured_axes["diffraction_plane"]["df"] = 1 / (image_dimmension * measured_axes["object"]["dx"]) # frequency axis in diffraction plane 420 | measured_axes["diffraction_plane"]["fmax"] = (measured_axes["diffraction_plane"]["df"] * image_dimmension) / 2 421 | measured_axes["diffraction_plane"]["f"] = np.arange(-measured_axes["diffraction_plane"]["fmax"], measured_axes["diffraction_plane"]["fmax"], measured_axes["diffraction_plane"]["df"]) 422 | 423 | return measured_axes, amplitude_mask 424 | 425 | def get_amplitude_mask_and_imagesize(image_dimmension, desired_mask_width): 426 | desired_mask_width+=1 427 | 428 | # image_dimmension must be divisible by 4 429 | assert image_dimmension/4 == int(image_dimmension/4) 430 | # get the png image for amplitude 431 | im = params.params.wavefront_sensor 432 | im = PIL.ImageOps.invert(im) 433 | size1 = im.size[0] 434 | im_size_nm = params.params.wavefron_sensor_size_nm 435 | 436 | # scale down the image 437 | im = im.resize((desired_mask_width,desired_mask_width)).convert("L") 438 | size2 = im.size[0] 439 | im = np.array(im) 440 | 441 | # # # # # # # # # # # # # 442 | # # # pad the image # # # 443 | # # # # # # # # # # # # # 444 | 445 | # determine width of mask 446 | pad_amount = int((image_dimmension - desired_mask_width)/2) 447 | amplitude_mask = np.pad(im, pad_width=pad_amount, mode="constant", constant_values=0) 448 | amplitude_mask = amplitude_mask.astype(np.float64) 449 | amplitude_mask *= 1/np.max(amplitude_mask) # normalize 450 | assert amplitude_mask.shape[0] == image_dimmension 451 | size3 = np.shape(amplitude_mask)[0] 452 | ratio = size3 / size2 453 | # print("ratio =>", ratio) 454 | im_size_nm *= ratio # object image size [m] 455 | 456 | measured_axes = {} 457 | measured_axes["object"] = {} 458 | measured_axes["object"]["dx"] = im_size_nm / image_dimmension 459 | measured_axes["object"]["xmax"] = im_size_nm/2 460 | measured_axes["object"]["x"] = np.arange(-(measured_axes["object"]["xmax"]), 
(measured_axes["object"]["xmax"]), measured_axes["object"]["dx"]) 461 | measured_axes["diffraction_plane"] = {} 462 | measured_axes["diffraction_plane"]["df"] = 1 / (image_dimmension * measured_axes["object"]["dx"]) # frequency axis in diffraction plane 463 | measured_axes["diffraction_plane"]["fmax"] = (measured_axes["diffraction_plane"]["df"] * image_dimmension) / 2 464 | measured_axes["diffraction_plane"]["f"] = np.arange(-measured_axes["diffraction_plane"]["fmax"], measured_axes["diffraction_plane"]["fmax"], measured_axes["diffraction_plane"]["df"]) 465 | 466 | return measured_axes, amplitude_mask 467 | 468 | def make_image_square(image): 469 | 470 | # make the image square 471 | rows = np.shape(image)[0] 472 | cols = np.shape(image)[1] 473 | if rows > cols: 474 | half_new_rows = int(cols / 2) 475 | center_row = int(rows / 2) 476 | im_new = image[center_row-half_new_rows:center_row+half_new_rows,:] 477 | if cols > rows: 478 | half_new_cols = int(rows / 2) 479 | center_col = int(cols / 2) 480 | im_new = image[:,center_col-half_new_cols:center_col+half_new_cols] 481 | 482 | return im_new 483 | 484 | def bin_image(image, bin_shape): 485 | im_shape = image.shape 486 | image = image.reshape( 487 | int(im_shape[0]/bin_shape[0]), 488 | bin_shape[0], 489 | int(im_shape[1]/bin_shape[1]), 490 | bin_shape[1] 491 | ).sum(3).sum(1) 492 | return image 493 | 494 | 495 | def circular_crop(image, radius): 496 | 497 | # multpily by a circular beam amplitude 498 | y = np.linspace(-1, 1, np.shape(image)[0]).reshape(-1,1) 499 | x = np.linspace(-1, 1, np.shape(image)[1]).reshape(1,-1) 500 | r = np.sqrt(x**2 + y**2) 501 | # multiply the object by the beam 502 | image[r>radius] = 0 503 | 504 | 505 | def rescale_image(image, scale): 506 | new_img_size = (np.array(image.size) * scale).astype(int) # (width, height) 507 | c_left = (image.size[0] / 2) - (new_img_size[0]/2) 508 | c_upper = (image.size[1] / 2) - (new_img_size[1]/2) 509 | c_right = (image.size[0] / 2) + (new_img_size[0]/2) 
510 | c_lower = (image.size[1] / 2) + (new_img_size[1]/2) 511 | image = image.crop((c_left,c_upper,c_right,c_lower))# left , upper, right, lower 512 | return image 513 | 514 | 515 | def zernike_polynomial(N, m, n, scalef): 516 | 517 | if m >= 0: 518 | even = True 519 | else: 520 | even = False 521 | n = np.abs(n) 522 | m = np.abs(m) 523 | 524 | # assert (n-m)/2 is an integer 525 | assert float((n-m)/2) - int((n-m)/2) == 0 526 | 527 | # make axes of rho and phi 528 | scale = 10.0/scalef 529 | x = np.linspace(-1*scale,1*scale,N).reshape(1,-1) 530 | y = np.linspace(-1*scale,1*scale,N).reshape(-1,1) 531 | 532 | rho = np.sqrt(x**2 + y**2) 533 | rho = np.expand_dims(rho, axis=2) 534 | 535 | phi = np.arctan2(y,x) 536 | 537 | # define axes 538 | # rho = np.linspace(0, 1, 100).reshape(1,-1) 539 | k = np.arange(0, ((n-m)/2)+1 ).reshape(1,1,-1) 540 | 541 | numerator = (-1)**k 542 | numerator *= factorial(n-k) 543 | 544 | denominator = factorial(k) 545 | denominator *= factorial(((n+m)/2)-k) 546 | denominator *= factorial(((n-m)/2)-k) 547 | 548 | R = (numerator / denominator)*rho**(n-2*k) 549 | R = np.sum(R, axis=2) 550 | 551 | if even: 552 | Z = R*np.cos(m*phi) 553 | else: 554 | Z = R*np.sin(m*phi) 555 | 556 | # for checking the sampling 557 | # the -1 <-> 1 range of the zernike polynomial should be approximately the width of the 558 | # not propagated pulse 559 | # use this to set scale 560 | r = np.sqrt(x**2 + y**2) 561 | # Z[r>1] = 0 562 | 563 | 564 | return Z, r 565 | 566 | 567 | 568 | def tf_reconstruct_diffraction_pattern(real_norm, imag_norm, propagateTF): 569 | 570 | # real_norm *= 2 # between 0 and 2 571 | # imag_norm *= 2 # between 0 and 2 572 | 573 | # real_norm -= 1 # between -1 and 1 574 | # imag_norm -= 1 # between -1 and 1 575 | 576 | # propagate through wavefront 577 | wavefront=tf.complex(real=real_norm,imag=imag_norm) 578 | through_wf=propagateTF.setup_graph_through_wfs(wavefront) 579 | 580 | complex_object_retrieved = tf.complex(real=tf.real(through_wf), 
def construct_diffraction_pattern(normalized_amplitude, normalized_phase, scalar):
    """
    Build a normalized far-field diffraction pattern from a retrieved object.

    normalized_amplitude: object amplitude (array-like)
    normalized_phase: object phase (array-like); multiplied by `scalar`
    scalar: scale factor applied to the phase before forming the field

    returns: |FFT| of the complex object, normalized to a peak of 1
    """
    # copy inputs so the caller's arrays are never mutated
    amp = np.array(normalized_amplitude)
    phs = np.array(normalized_phase)
    phs *= scalar

    # complex object field, then centered 2-D Fourier transform
    field = amp * np.exp(1j * phs)
    spectrum = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field)))

    # magnitude only, scaled so the brightest pixel is 1
    pattern = np.abs(spectrum)
    pattern = pattern / np.max(pattern)
    return pattern
def sum_over_all_except(mat, axis):
    """
    Collapse every axis of `mat` by summation except the one given.

    mat: numpy array with 2 or more dimensions
    axis: the (non-negative) axis index to keep

    returns: 1-D array — the profile of `mat` along `axis`
    """
    # enumerate all axes, then drop the one we want to keep;
    # .remove() raises ValueError for an out-of-range axis, as before
    dims = list(range(len(np.shape(mat))))
    dims.remove(axis)
    return np.sum(mat, axis=tuple(dims))
def calc_centroid(mat, axis):
    """
    Intensity-weighted centroid of `mat` along one axis.

    mat: numpy array with 2 or more dimensions
    axis: the axis along which to locate the centroid

    returns: centroid position (float, in index units) along `axis`

    Relies on `sum_over_all_except`, defined earlier in this module, to
    reduce `mat` to a 1-D profile first.
    """
    profile = sum_over_all_except(mat, axis)
    positions = np.arange(0, len(profile))
    # first moment of the profile divided by its total weight
    return np.sum(profile * positions) / np.sum(profile)
def make_object_phase(object, phase):
    """
    Attach a phase profile to an object amplitude.

    object: amplitude, values between 0 and 1
    phase: phase, values between 0 and 1

    returns: complex field object * exp(-i * 2*pi * masked_phase), where the
             phase is zeroed wherever the amplitude is <= 0.2 (outside the
             object support).
    """
    # restrict the phase to the object support (amplitude > 0.2)
    support = object > 0.2
    masked_phase = phase * support

    # map the 0..1 phase onto a full -2*pi turn (negative convention)
    return object * np.exp(-1j * masked_phase * (2 * np.pi))
np.random.rand(1)*(max_x-min_x) 835 | y_val = min_y + np.random.rand(1)*(max_y-min_y) 836 | 837 | x.append(int(x_val)) 838 | y.append(int(y_val)) 839 | 840 | x.append(x[0]) 841 | y.append(y[0]) 842 | 843 | xy = [(x_, y_) for x_, y_ in zip(x,y)] 844 | image = ImagePath.Path(xy).getbbox() 845 | size = list(map(int, map(math.ceil, image[2:]))) 846 | img = Image.new("RGB", [N,N], "#000000") 847 | img1 = ImageDraw.Draw(img) 848 | img1.polygon(xy, fill ="#ffffff") 849 | # convert to numpy array 850 | amplitude = np.array(img.getdata(), dtype=np.uint8).reshape(N, N, -1) 851 | amplitude = np.sum(amplitude, axis=2) 852 | amplitude = amplitude/np.max(amplitude) 853 | 854 | # apply gaussian filter 855 | amplitude = gaussian_filter(amplitude, sigma=0.8, order=0) 856 | return amplitude 857 | 858 | def plot_fft(object_in): 859 | 860 | # diffraction pattern 861 | diffraction_pattern = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(object_in))) 862 | 863 | # plt.figure() 864 | fig, ax = plt.subplots(2,2, figsize=(10,10)) 865 | fig.subplots_adjust(wspace=0.5, top=0.95, bottom=0.10) 866 | # object plane 867 | ax[0][0].pcolormesh(object_plane_x, object_plane_x, np.abs(object_in)) 868 | ax[0][0].set_xlabel("object plane distance [m]") 869 | ax[0][0].set_ylabel("object plane distance [m]") 870 | ax[0][0].set_title("object") 871 | 872 | # object phase 873 | ax[1][0].pcolormesh(object_plane_x, object_plane_x, np.angle(object_in)) 874 | ax[1][0].set_xlabel("object plane distance [m]") 875 | ax[1][0].set_ylabel("object plane distance [m]") 876 | ax[1][0].set_title("object phase") 877 | 878 | 879 | # diffraction plane 880 | ax[0][1].pcolormesh(diffraction_plane_x, diffraction_plane_x, np.abs(diffraction_pattern)) 881 | ax[0][1].set_title("diffraction pattern at %i [m]" % diffraction_plane_z) 882 | ax[0][1].set_xlabel("diffraction plane distance [m]") 883 | ax[0][1].set_ylabel("diffraction plane distance [m]") 884 | 885 | # diffraction plane 886 | ax[1][1].pcolormesh(diffraction_plane_x, 
def create_phase(N):
    """
    Generate a random tilted-fringe phase map.

    N: dimensions of the (N x N) image

    returns: real-valued phase in [-pi, +pi]

    Draws exactly three values from np.random, in this order: rotation
    angle, spatial frequency, global phase offset — so seeded runs are
    reproducible.
    """
    # pixel-centered coordinate axes spanning -N/2 .. N/2
    col = np.linspace(-N / 2, N / 2, N).reshape(1, -1)
    row = np.linspace(-N / 2, N / 2, N).reshape(-1, 1)

    # random fringe orientation (degrees -> radians)
    theta = (np.random.rand() * 360.0) * (np.pi / 180.0)

    # random spatial frequency drawn from [0.4, 0.8)
    freq_lo, freq_hi = 0.4, 0.8
    freq = freq_lo + np.random.rand() * (freq_hi - freq_lo)

    # rotate the coordinate frame by theta
    rot_x = col * np.cos(theta) + row * np.sin(theta)

    # complex carrier with a random global offset, scaled to +/- pi
    carrier = np.exp(1j * freq * rot_x) * np.exp(1j * 10 * np.random.rand())
    return np.real(np.pi * carrier)
randomid = "" 940 | for r in randomid_num: 941 | randomid += str(r) 942 | 943 | diffraction_pattern_file = randomid + "_diffraction.mat" 944 | support_file = randomid + "_support.mat" 945 | 946 | retrieved_obj_file = randomid + "_retrieved_obj.mat" 947 | reconstructed_file = randomid + "_reconstructed.mat" 948 | real_interp_file = randomid + "_real_interp.mat" 949 | imag_interp_file = randomid + "_imag_interp.mat" 950 | 951 | scipy.io.savemat(support_file, {'support':support}) 952 | scipy.io.savemat(diffraction_pattern_file, {'diffraction':diffraction_pattern}) 953 | 954 | # matlab load file 955 | with open("loaddata.m", "w") as file: 956 | file.write("function [diffraction_pattern_file, support_file, retrieved_obj_file, reconstructed_file, real_interp_file, imag_interp_file] = loaddata()\n") 957 | file.write("diffraction_pattern_file = '{}';\n".format(diffraction_pattern_file)) 958 | file.write("support_file = '{}';\n".format(support_file)) 959 | file.write("retrieved_obj_file = '{}';\n".format(retrieved_obj_file)) 960 | file.write("reconstructed_file = '{}';\n".format(reconstructed_file)) 961 | file.write("real_interp_file = '{}';\n".format(real_interp_file)) 962 | file.write("imag_interp_file = '{}';\n".format(imag_interp_file)) 963 | file.flush() 964 | os.system('/usr/local/R2020a/bin/matlab -nodesktop -r seeded_run_CDI_noprocessing') 965 | print("matlab ran") 966 | 967 | # load the results from matlab run 968 | rec_object = scipy.io.loadmat(retrieved_obj_file)['rec_object'] 969 | recon_diffracted = scipy.io.loadmat(reconstructed_file)['recon_diffracted'] 970 | obj_real_interp1 = scipy.io.loadmat(real_interp_file)['obj_real_interp1'] 971 | obj_imag_interp1 = scipy.io.loadmat(imag_interp_file)['obj_imag_interp1'] 972 | 973 | obj_imag_interp1[np.isnan(obj_imag_interp1)]=0.0 974 | obj_real_interp1[np.isnan(obj_real_interp1)]=0.0 975 | 976 | # plt.figure() 977 | # plt.pcolormesh(obj_real_interp1) 978 | # plt.title("obj_real_interp1") 979 | 980 | # plt.figure() 981 
| # plt.pcolormesh(obj_imag_interp1) 982 | # plt.title("obj_imag_interp1") 983 | 984 | # plt.figure() 985 | # plt.pcolormesh(np.real(rec_object)) 986 | # plt.title("np.real(rec_object)") 987 | 988 | # plt.figure() 989 | # plt.pcolormesh(np.imag(rec_object)) 990 | # plt.title("np.imag(rec_object)") 991 | 992 | # plt.show() 993 | 994 | # go back to starting dir 995 | os.chdir(start_dir) 996 | 997 | retrieved_obj = {} 998 | retrieved_obj["measured_pattern"] = diffraction_pattern 999 | retrieved_obj["tf_reconstructed_diff"] = recon_diffracted 1000 | 1001 | # without the interpolation 1002 | if interpolate: 1003 | retrieved_obj["real_output"] = obj_real_interp1 1004 | retrieved_obj["imag_output"] = obj_imag_interp1 1005 | else: 1006 | retrieved_obj["real_output"] = np.real(rec_object) 1007 | retrieved_obj["imag_output"] = np.imag(rec_object) 1008 | 1009 | return retrieved_obj 1010 | 1011 | 1012 | 1013 | # grid space of the diffraction pattern 1014 | N = 40 # measurement points 1015 | diffraction_plane_x_max = 1 # meters 1016 | diffraction_plane_z = 10 # meters 1017 | wavelength = 400e-9 1018 | 1019 | # measured diffraction plane 1020 | diffraction_plane_dx = 2*diffraction_plane_x_max/N 1021 | diffraction_plane_x = diffraction_plane_dx * np.arange(-N/2, N/2, 1) 1022 | 1023 | # convert distance to frequency domain 1024 | diffraction_plane_fx = diffraction_plane_x / (wavelength * diffraction_plane_z) 1025 | diffraction_plane_dfx = diffraction_plane_dx / (wavelength * diffraction_plane_z) 1026 | 1027 | # x coordinates at object plane 1028 | object_plane_dx = 1 / ( diffraction_plane_dfx * N) 1029 | object_plane_x = object_plane_dx * np.arange(-N/2, N/2, 1) 1030 | 1031 | 1032 | if __name__ == "__main__": 1033 | 1034 | # construct object in the object plane 1035 | object, object_phase = make_object(N, min_indexes=4, max_indexes=8) 1036 | # object_phase = np.ones_like(object_phase) 1037 | # construct phase 1038 | object_with_phase = make_object_phase(object, object_phase) 1039 
if __name__ == "__main__":
    # Extract the first validation diffraction pattern from the HDF5
    # dataset and cache it as a pickle for quick single-sample testing.
    # BUGFIX: the original never closed the HDF5 handle; close it in a
    # finally block so the file is released even if the read fails.
    hdf5_file_validation = tables.open_file("zernike3/build/test.hdf5", mode="r")
    try:
        # first row is stored flat -> reshape to the 256 x 256 pattern
        sample = hdf5_file_validation.root.diffraction[0, :].reshape(256, 256)
    finally:
        hdf5_file_validation.close()
    with open("sample.p", "wb") as file:
        pickle.dump(sample, file)
CameraProperty = namedtuple("CameraProperty", "status value min max default step type flags category group")


class TIS:
    'The Imaging Source Camera'

    def __init__(self, serial, width, height, framerate, color):
        ''' Constructor
        :param serial: Serial number of the camera to be used.
        :param width: Width of the wanted video format
        :param height: Height of the wanted video format
        :param framerate: Numerator of the frame rate. /1 is added automatically
        :param color: True = 8 bit color, False = 8 bit mono. ToDo: Y16
        :return: none
        '''
        Gst.init([])
        self.height = height
        self.width = width
        self.sample = None          # last GStreamer sample received
        self.samplelocked = False   # guards sample while it is being converted
        self.newsample = False      # set by the appsink callback
        self.img_mat = None         # last frame as a numpy array
        self.ImageCallback = None   # optional per-frame user callback

        pixelformat = "BGRx"
        if color is False:
            pixelformat = "GRAY8"

        # build the GStreamer pipeline description: camera source -> appsink
        p = 'tcambin serial="%s" name=source ! video/x-raw,format=%s,width=%d,height=%d,framerate=%d/1' % (serial, pixelformat, width, height, framerate,)
        p += ' ! appsink name=sink'

        print(p)
        try:
            self.pipeline = Gst.parse_launch(p)
        except GLib.Error as error:
            # BUGFIX: the original referenced the undefined name `err` here,
            # masking the real GLib error with a NameError
            print("Error creating pipeline: {0}".format(error))
            raise

        self.pipeline.set_state(Gst.State.READY)
        self.pipeline.get_state(Gst.CLOCK_TIME_NONE)
        # Query a pointer to our source, so we can set properties.
        self.source = self.pipeline.get_by_name("source")

        # Query a pointer to the appsink, so we can assign the callback function.
        self.appsink = self.pipeline.get_by_name("sink")
        self.appsink.set_property("max-buffers", 5)
        self.appsink.set_property("drop", 1)
        self.appsink.set_property("emit-signals", 1)
        self.appsink.connect('new-sample', self.on_new_buffer)

    def on_new_buffer(self, appsink):
        '''appsink "new-sample" handler: stash the sample and, if a user
        callback is registered, convert the frame and invoke it.'''
        self.newsample = True
        if self.samplelocked is False:
            try:
                self.sample = appsink.get_property('last-sample')
                if self.ImageCallback is not None:
                    self.__convert_sample_to_numpy()
                    self.ImageCallback(self, *self.ImageCallbackData)
            except GLib.Error as error:
                # BUGFIX: was format(err) with undefined `err`
                print("Error on_new_buffer pipeline: {0}".format(error))
                raise
        return False

    def Start_pipeline(self):
        '''Set the pipeline to PLAYING so the camera starts streaming.'''
        try:
            self.pipeline.set_state(Gst.State.PLAYING)
            self.pipeline.get_state(Gst.CLOCK_TIME_NONE)
        except GLib.Error as error:
            # BUGFIX: was format(err) with undefined `err`
            print("Error starting pipeline: {0}".format(error))
            raise

    def __convert_sample_to_numpy(self):
        ''' Convert a GStreamer sample to a numpy array
        Sample code from https://gist.github.com/cbenhagen/76b24573fa63e7492fb6#file-gst-appsink-opencv-py-L34

        The result is in self.img_mat.
        :return:
        '''
        self.samplelocked = True
        buf = self.sample.get_buffer()
        caps = self.sample.get_caps()

        # bytes-per-element and dtype depend on the negotiated pixel format
        bpp = 4
        dtype = numpy.uint8
        if caps.get_structure(0).get_value('format') == "BGRx":
            bpp = 4
        if caps.get_structure(0).get_value('format') == "GRAY8":
            bpp = 1
        if caps.get_structure(0).get_value('format') == "GRAY16_LE":
            bpp = 1
            dtype = numpy.uint16

        self.img_mat = numpy.ndarray(
            (caps.get_structure(0).get_value('height'),
             caps.get_structure(0).get_value('width'),
             bpp),
            buffer=buf.extract_dup(0, buf.get_size()),
            dtype=dtype)
        self.newsample = False
        self.samplelocked = False

    def wait_for_image(self, timeout):
        ''' Wait for a new image with timeout
        :param timeout: wait time in second, should be a float number
        :return:
        '''
        # poll in 10 slices so the total wait is approximately `timeout`
        tries = 10
        while tries > 0 and not self.newsample:
            tries -= 1
            time.sleep(float(timeout) / 10.0)

    def Snap_image(self, timeout):
        '''
        Snap an image from stream using a timeout.
        :param timeout: wait time in second, should be a float number. Not used
        :return: bool: True, if we got a new image, otherwise false.
        '''
        if self.ImageCallback is not None:
            # incompatible with callback mode: frames already flow to the callback
            print("Snap_image can not be called, if a callback is set.")
            return False

        self.wait_for_image(timeout)
        if self.sample is not None and self.newsample is True:
            self.__convert_sample_to_numpy()
            return True

        return False

    def Get_image(self):
        '''Return the last converted frame (numpy array) or None.'''
        return self.img_mat

    def Stop_pipeline(self):
        '''Step the pipeline down to NULL, releasing the camera.'''
        self.pipeline.set_state(Gst.State.PAUSED)
        self.pipeline.set_state(Gst.State.READY)
        self.pipeline.set_state(Gst.State.NULL)

    def List_Properties(self):
        '''Print the names of all camera properties.'''
        for name in self.source.get_tcam_property_names():
            print(name)

    def Get_Property(self, PropertyName):
        '''Return a CameraProperty for the named tcam property.'''
        try:
            return CameraProperty(*self.source.get_tcam_property(PropertyName))
        except GLib.Error as error:
            # BUGFIX: original passed the template and args straight to print
            # (and referenced undefined `err`); format the message properly
            print("Error get Property {0}: {1}".format(PropertyName, error))
            raise

    def Set_Property(self, PropertyName, value):
        '''Set the named tcam property to `value` (type inferred from value).'''
        try:
            self.source.set_tcam_property(PropertyName, GObject.Value(type(value), value))
        except GLib.Error as error:
            # BUGFIX: same print/format and `err` defect as Get_Property
            print("Error set Property {0}: {1}".format(PropertyName, error))
            raise

    def Set_Image_Callback(self, function, *data):
        '''Register a callback invoked as function(self, *data) per frame.'''
        self.ImageCallback = function
        self.ImageCallbackData = data
If color is True, then RGB32 12 | # colorformat is in memory 13 | Tis = TIS.TIS("48710182", 640, 480, 30, False) 14 | Tis.Start_pipeline() # Start the pipeline so the camera streams 15 | 16 | print('Press Esc to stop') 17 | lastkey = 0 18 | 19 | cv2.namedWindow('Window') # Create an OpenCV output window 20 | 21 | kernel = np.ones((5, 5), np.uint8) # Create a Kernel for OpenCV erode function 22 | 23 | fig, ax = plt.subplots(1,1, figsize=(5,5)) 24 | plt.ion() 25 | first_image = False 26 | 27 | while lastkey != 27: 28 | if Tis.Snap_image(1) is True: # Snap an image with one second timeout 29 | image = Tis.Get_image() # Get the image. It is a numpy array 30 | 31 | if not first_image: 32 | im = ax.imshow(image[:,:,0], vmin=0.0, vmax=20) 33 | colorbar = fig.colorbar(im, ax=ax) 34 | first_image = True 35 | else: 36 | im.set_data(image[:,:,0]) 37 | print("calling set_clim") 38 | print("np.min(image) =>", np.min(image)) 39 | print("np.max(image) =>", np.max(image)) 40 | im.set_clim(vmin=np.min(image), vmax=np.max(image)) 41 | 42 | # print("np.min(image) =>", np.min(image)) 43 | # print("np.max(image) =>", np.max(image)) 44 | 45 | plt.pause(0.001) 46 | 47 | image = cv2.erode(image, kernel, iterations=5) # Example OpenCV image processing 48 | cv2.imshow('Window', image) # Display the result 49 | 50 | lastkey = cv2.waitKey(10) 51 | 52 | # Stop the pipeline and clean up 53 | Tis.Stop_pipeline() 54 | cv2.destroyAllWindows() 55 | print('Program ends') 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /live_capture/setup.txt: -------------------------------------------------------------------------------- 1 | installation for camera DMK 33up5000 / 48710182 2 | on ubuntu 18.04 3 | 4 | install this deb package, then install gstreamer from ubuntu repo 5 | use this for reference: https://github.com/TheImagingSource/Linux-tiscamera-Programming-Samples 6 | 7 | on ubuntu 18.04 - tcam_capture will install to the python3.5 dist-packages, 8 
| but it will not be found by the python3 interpreter, 9 | copy the tcam_capture folder in /usr/lib/python3.5/site-packages to /usr/lib/python3/dist-packages 10 | 11 | -------------------------------------------------------------------------------- /live_capture/tiscamera_0.11.1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwhite510/wavefront-sensor-neural-network/3a521fea276119438fe07b6e2b777e677c1f83c4/live_capture/tiscamera_0.11.1_amd64.deb -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'main.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.1 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing.
class Ui_MainWindow(object):
    """Auto-generated (pyuic5) UI class for the wavefront-sensor main window.

    NOTE(review): this file is generated from main.ui — edit the .ui file and
    regenerate rather than changing this class by hand (see header warning).
    """

    def setupUi(self, MainWindow):
        """Build all widgets/layouts on MainWindow and wire signal connections."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(976, 856)
        # central widget with a grid layout holding the nested box layouts
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.gridLayout = QtWidgets.QGridLayout(self.centralwidget)
        self.gridLayout.setObjectName("gridLayout")
        self.verticalLayout_10 = QtWidgets.QVBoxLayout()
        self.verticalLayout_10.setObjectName("verticalLayout_10")
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.verticalLayout = QtWidgets.QVBoxLayout()
        self.verticalLayout.setObjectName("verticalLayout")
        self.horizontalLayout.addLayout(self.verticalLayout)
        self.verticalLayout_3 = QtWidgets.QVBoxLayout()
        self.verticalLayout_3.setObjectName("verticalLayout_3")
        # control panel widget (expands vertically)
        self.widget = QtWidgets.QWidget(self.centralwidget)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Expanding)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.widget.sizePolicy().hasHeightForWidth())
        self.widget.setSizePolicy(sizePolicy)
        self.widget.setObjectName("widget")
        # Start/Stop button
        self.pushButton = QtWidgets.QPushButton(self.widget)
        self.pushButton.setGeometry(QtCore.QRect(10, 10, 80, 23))
        self.pushButton.setObjectName("pushButton")
        # processing parameter inputs: rotation, orientation, scale
        self.rotation_edit = QtWidgets.QLineEdit(self.widget)
        self.rotation_edit.setGeometry(QtCore.QRect(210, 100, 113, 23))
        self.rotation_edit.setObjectName("rotation_edit")
        self.orientation_edit = QtWidgets.QComboBox(self.widget)
        self.orientation_edit.setGeometry(QtCore.QRect(158, 60, 171, 23))
        self.orientation_edit.setObjectName("orientation_edit")
        self.scale_edit = QtWidgets.QLineEdit(self.widget)
        self.scale_edit.setGeometry(QtCore.QRect(210, 140, 113, 23))
        self.scale_edit.setObjectName("scale_edit")
        # static labels for the three inputs
        self.label = QtWidgets.QLabel(self.widget)
        self.label.setGeometry(QtCore.QRect(80, 100, 121, 20))
        self.label.setObjectName("label")
        self.label_2 = QtWidgets.QLabel(self.widget)
        self.label_2.setGeometry(QtCore.QRect(80, 140, 121, 20))
        self.label_2.setObjectName("label_2")
        self.label_3 = QtWidgets.QLabel(self.widget)
        self.label_3.setGeometry(QtCore.QRect(80, 60, 121, 20))
        self.label_3.setObjectName("label_3")
        self.verticalLayout_3.addWidget(self.widget)
        self.horizontalLayout.addLayout(self.verticalLayout_3)
        self.verticalLayout_10.addLayout(self.horizontalLayout)
        self.horizontalLayout_4 = QtWidgets.QHBoxLayout()
        self.horizontalLayout_4.setObjectName("horizontalLayout_4")
        self.verticalLayout_10.addLayout(self.horizontalLayout_4)
        self.gridLayout.addLayout(self.verticalLayout_10, 0, 0, 1, 1)
        # toggles the plot between real and imaginary parts
        self.view_toggle = QtWidgets.QPushButton(self.centralwidget)
        self.view_toggle.setObjectName("view_toggle")
        self.gridLayout.addWidget(self.view_toggle, 1, 0, 1, 1)
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 976, 20))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        # signal wiring: MainWindow must implement TogglePlotRE_IM,
        # ProcessingUpdated and Start_Stop_Clicked slots
        self.view_toggle.clicked.connect(MainWindow.TogglePlotRE_IM)
        self.rotation_edit.textChanged['QString'].connect(MainWindow.ProcessingUpdated)
        self.pushButton.clicked.connect(MainWindow.Start_Stop_Clicked)
        self.scale_edit.textChanged['QString'].connect(MainWindow.ProcessingUpdated)
        self.orientation_edit.currentIndexChanged['QString'].connect(MainWindow.ProcessingUpdated)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the translatable text of all widgets."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.pushButton.setText(_translate("MainWindow", "Start"))
        self.label.setText(_translate("MainWindow", "rotation (degrees)"))
        self.label_2.setText(_translate("MainWindow", "scale"))
        self.label_3.setText(_translate("MainWindow", "orientation"))
        self.view_toggle.setText(_translate("MainWindow", "Real/Imag"))
153 | clicked() 154 | MainWindow 155 | TogglePlotRE_IM() 156 | 157 | 158 | 899 159 | 677 160 | 161 | 162 | 898 163 | 570 164 | 165 | 166 | 167 | 168 | rotation_edit 169 | textChanged(QString) 170 | MainWindow 171 | ProcessingUpdated() 172 | 173 | 174 | 720 175 | 188 176 | 177 | 178 | 846 179 | 240 180 | 181 | 182 | 183 | 184 | pushButton 185 | clicked() 186 | MainWindow 187 | Start_Stop_Clicked() 188 | 189 | 190 | 499 191 | 89 192 | 193 | 194 | 884 195 | 52 196 | 197 | 198 | 199 | 200 | scale_edit 201 | textChanged(QString) 202 | MainWindow 203 | ProcessingUpdated() 204 | 205 | 206 | 711 207 | 227 208 | 209 | 210 | 840 211 | 282 212 | 213 | 214 | 215 | 216 | orientation_edit 217 | currentIndexChanged(QString) 218 | MainWindow 219 | ProcessingUpdated() 220 | 221 | 222 | 662 223 | 135 224 | 225 | 226 | 851 227 | 199 228 | 229 | 230 | 231 | 232 | 233 | signal1() 234 | textchanged() 235 | Start_Stop_Clicked() 236 | TogglePlotRE_IM() 237 | ProcessingUpdated() 238 | 239 | 240 | -------------------------------------------------------------------------------- /matlab_cdi/cdi_interpolation.m: -------------------------------------------------------------------------------- 1 | % This is a script for retrieving the phase of a diffraction pattern thus enabling the 2 | % reconstruction of the diffracting object in real space. 3 | % Date : 23.11.2018 4 | % important parameters : samp_cam_dist, sub_image, on_chip_binning, crop, sam_size 5 | % important points : stray light, threshold for background, scaling of 6 | % support, flip/not of the support, correct sample size 7 | % Mask the center... 
8 | 9 | close all 10 | clear all 11 | clc 12 | 13 | addpath('Y:\Software\Matlab_classFig'); %path to class_fig 14 | 15 | % If a simulation should be run 16 | sim = false; 17 | 18 | %% loading the data 19 | if ~sim 20 | % Loading measurement data 21 | % [name,pathstr,~] = uigetfile('Y:\AG_HHG\experiments\2018_11_Wavefront_sensing\20200127\4um_pinhole\z-360\feld_2\fits\*.fits'); 22 | % [name,pathstr,~] = uigetfile('Y:\AG_HHG\experiments\2018_11_Wavefront_sensing\20200127\4um_pinhole\z0\feld_2\fits\*.fits'); 23 | % [name,pathstr,~] = uigetfile('Y:\AG_HHG\experiments\2018_11_Wavefront_sensing\20200221\wfs_3um_pinhole\fits\*.fits'); 24 | [name,pathstr,~] = uigetfile('Y:\AG_HHG\experiments\2018_11_Wavefront_sensing\20200224\feld_2\z-1000\fits\*.fits'); 25 | file=fullfile(pathstr,name); 26 | data = fitsread(file); 27 | else 28 | % % loading simulation data 29 | load('\20200108\fits\FELD3_0p5s_2x2b.fits'); 30 | data = DP; 31 | end 32 | 33 | % support_filename = 'samples\\size_3p2um_pitch_400nm_diameter_200nm_psize_5nm.png'; 34 | support_filename = 'samples\\size_6um_pitch_600nm_diameter_300nm_psize_5nm.png'; 35 | supp_pxl_size = 0.005; % psize in the png image [um/px] 36 | 37 | meas_diff = double(data); 38 | 39 | %% Output settings 40 | reconst_interp = 1; % 1 to do interpolation of the reconstruction 41 | reconst_plot = 1; % 1 to plot the reconstructions 42 | dump_recon = true; % If the reconstruction should be saved in a mat file 43 | dump_folder = 'test'; 44 | phase_threshold = 0.1; 45 | % Create dump folder if it does not exists 46 | if ~exist(dump_folder, 'dir') 47 | mkdir(dump_folder) 48 | end 49 | 50 | %% settings of the experiment 51 | % do some rework here and only use necessary parameters 52 | lam = 0.0135; % in microns 53 | % samp_cam_dist_l = [32500, 32501, 32502, 32750, 32751, 33000, 33001, 33002]; 54 | samp_cam_dist_l = [32000, 32001, 32250, 32251]; 55 | cam_pxl_size = 13.5; 56 | on_chip_binning = 2; % during measurement. 
camera binning to get the loaded data 57 | binning = 1; % additional numerical binning to be applied now 58 | crop = 0; % 0 for no additional cropping, 1 for half the camera and 2 for 1/4 of the camera 59 | pad = false; 60 | total_bin = binning * on_chip_binning; 61 | num_pxls_dp = 512; % number of pixels of the detector 62 | window = true; 63 | 64 | if pad > num_pxls_dp 65 | num_pxls_dp = pad; 66 | end 67 | 68 | %% Settings of the reconsruction algorithm 69 | % ---------------------------------------------------------------------- 70 | stiching = false; 71 | in_BS = NaN; 72 | BS_used = 0 ; 73 | update_me = 0 ; % get an GUI update or not, slow if updated 74 | use_RAAR = 1; % 1 for RAAR and 0 for HIO 75 | conditionals = [BS_used update_me use_RAAR]; 76 | 77 | beta = 0.8; %between 0.8 and 0.98; if higher, feedback is stronger 78 | iter_total = 49; % total number of iterations 79 | iter_cycle = 50; % cycle of HIO/RAAR and ER 80 | HIO_num = 40; % HIO/RAAR per cycle 81 | parameters = [beta iter_total iter_cycle HIO_num]; 82 | 83 | %% Preprocessing parameters 84 | %------------------------------------------------------------------------- 85 | cc = true; 86 | threshold = 20; 87 | offset = 0; 88 | % rot_angle = -1.25; % First data set 89 | rot_angle = 2.7; 90 | long_exp_t = 10; 91 | short_exp_t = 1; 92 | sigma_conv = 0.1; % 0.1 93 | 94 | %% Preprocessing 95 | % Stiching 96 | if stiching 97 | factor = long_exp_t / short_exp_t; 98 | 99 | % Load short exposure data 100 | [name,pathstr,~] = uigetfile('Y:\AG_HHG\experiments\2018_11_Wavefront_sensing\20181212\fits\*.fits'); 101 | file = fullfile(pathstr,name); 102 | short_exp = fitsread(file); 103 | short_exp = double(short_exp); 104 | 105 | short_exp(short_exp < 5) = 0; 106 | 107 | figure 108 | imagesc(meas_diff) 109 | caxis([30000 50000]) 110 | colorbar 111 | title('Measured diffraction pattern') 112 | 113 | load('mask.mat'); 114 | 115 | figure 116 | imagesc(mask) 117 | title('loaded mask') 118 | 119 | figure 120 | 
imagesc(meas_diff .* mask, [0, 1000]) 121 | title('part one') 122 | 123 | figure 124 | imagesc((mask-1)*(-1) .* short_exp * factor, [0, 1000]) 125 | title('part 2') 126 | 127 | meas_diff = meas_diff .* mask + (mask-1)*(-1) .* short_exp * factor; 128 | 129 | end 130 | 131 | % Centering 132 | [diff_max, diff_max_ind] = max(meas_diff); 133 | [~, diff_ind_c] = max(diff_max); 134 | diff_ind_r = diff_max_ind(diff_ind_c); 135 | 136 | raw_diff_int = imtranslate(meas_diff, [(0.5*size(meas_diff,1)-diff_ind_c)+1, (0.5*size(meas_diff,2)-diff_ind_r)+1]); 137 | 138 | if crop 139 | num_pxl = size(raw_diff_int, 1); 140 | raw_diff_int = raw_diff_int(num_pxl / 2 - num_pxl /4 + 1: num_pxl / 2 + num_pxl /4, num_pxl / 2 - num_pxl /4 + 1: num_pxl / 2 + num_pxl /4); 141 | num_pxls_dp = size(raw_diff_int, 1); 142 | end 143 | 144 | diff_int = databin(raw_diff_int, binning); % .*1e3; 145 | 146 | % rotate the diffraction pattern 147 | if ~sim 148 | diff_int = imrotate(diff_int,rot_angle,'crop'); 149 | end 150 | temp_2 = diff_int; 151 | 152 | for i = 1:size(samp_cam_dist_l, 2) 153 | diff_int = temp_2; 154 | samp_cam_dist = samp_cam_dist_l(i); 155 | obj_pxl_size = lam / (num_pxls_dp * cam_pxl_size * total_bin / samp_cam_dist); % resolution in the object plane 156 | 157 | if ~sim && cc 158 | % curvature correction 159 | diff_int = curvature_correction(temp_2, cam_pxl_size*total_bin*1e-6, samp_cam_dist*1e-6, lam*1e-6); 160 | diff_int(isnan(diff_int)) = 0.; 161 | end 162 | 163 | if isa(pad, 'double') 164 | add_y = (pad - size(diff_int, 1)) / 2.; 165 | add_x = (pad - size(diff_int, 2)) / 2.; 166 | diff_int = padarray(diff_int, [add_y, add_x], 'both'); 167 | end 168 | 169 | figure 170 | imagesc(diff_int) 171 | 172 | % Setting the offset 173 | diff_int = diff_int - offset; 174 | 175 | % Thresholding of the noise 176 | diff_int(diff_int0.03); % to make it binary 200 | pad_size = floor((num_pxls_dp - size(seed_obj_downsamp, 1))/2.); 201 | seed_obj = padarray(seed_obj_downsamp, [pad_size, pad_size], 0, 
'both'); 202 | 203 | if size(seed_obj) ~= size(diff_int) 204 | temp = zeros(size(diff_int)); 205 | temp(1: end - 1, 1: end - 1) = seed_obj; 206 | seed_obj = temp; 207 | size(seed_obj) 208 | end 209 | 210 | seed_obj = fliplr(seed_obj); 211 | % seed_obj = flipud(seed_obj); % correct?? 212 | 213 | figure 214 | fig = imagesc(seed_obj); 215 | title('seeded object adjusted') 216 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_seeded support.png')) 217 | 218 | %% calling the reconstruction function 219 | % ------------------------------------------------------------------ 220 | [rec_object, ref_support, err_fourier_space, err_obj_space, recon_diffracted] = seeded_reconst_func(diff_int, parameters, conditionals, in_BS, seed_obj); 221 | 222 | %% ploting the results 223 | % ------------------------------------------------------------------- 224 | if reconst_plot == 1 225 | % Generate grid of spatial frequencies 226 | theta_max_edge = atand((0.5*num_pxls_dp*cam_pxl_size)/samp_cam_dist); 227 | q_max_measured = 2*sind(theta_max_edge/2)/lam; 228 | 229 | q_x_cam_plus = linspace(0,q_max_measured,0.5*size(diff_int,1)); 230 | q_x_cam_minus = -fliplr(q_x_cam_plus(2:end)); 231 | q_x_cam_minus = [2*q_x_cam_minus(1)-q_x_cam_minus(2) q_x_cam_minus]; 232 | 233 | q_x_cam = [q_x_cam_minus q_x_cam_plus]; 234 | 235 | q_x_cam_cr = q_x_cam; 236 | 237 | num_pxls = (num_pxls_dp)/on_chip_binning; 238 | 239 | q_x_cam_arr = downsample(q_x_cam_cr, 1); 240 | q_y_cam_arr = q_x_cam_arr; 241 | 242 | [q_x_cam,q_y_cam] = meshgrid(q_x_cam_arr ,q_y_cam_arr); 243 | [AC, x_sam, y_sam] = fourier2Dplus(diff_int, q_x_cam, q_y_cam); % just to get the axes 244 | 245 | fig1 = classFig('PPT'); 246 | fig = imagesc(x_sam(1,:), y_sam(:,1),(abs(rec_object))); 247 | title('Reconstructed amplitude') 248 | title_val2 = ['Reconstructed amp, beta = ',num2str(beta)]; 249 | title(title_val2) 250 | colormap hot 251 | colorbar 252 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), 
'_Reconst_Amp.png'), 'png') 253 | 254 | 255 | rec_object3 = rec_object; 256 | [m, ind] = max(abs(rec_object3(:))); 257 | [x, y] = ind2sub(size(rec_object3), ind); 258 | abs(rec_object3(x, y)) 259 | rec_object3(abs(rec_object3)<0.05 * max(max(abs(rec_object3)))) = NaN; 260 | rec_phs = angle(rec_object3 * exp(-1j*angle(rec_object3(x,y)))); 261 | 262 | norm_phase = angle(rec_object3); 263 | norm_phase(isnan(norm_phase)) = 0; 264 | range = max(max(norm_phase)) - min(min(norm_phase)); 265 | 266 | fig2 = classFig('PPT'); 267 | fig = imagesc(x_sam(1,:), y_sam(:,1), rec_phs); 268 | colormap hot 269 | colorbar 270 | title_val = ['Reconstructed phase, beta = ',num2str(beta)]; 271 | title(title_val) 272 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Reconst_Phase.png')) 273 | 274 | % 275 | fig3 = classFig('PPT'); 276 | fig = imagesc(log10(diff_int)); 277 | caxis([0 5]) 278 | colorbar 279 | title('Measured diffraction pattern') 280 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Measure_Diff.png')) 281 | 282 | % 283 | fig4 = classFig('PPT'); 284 | fig = imagesc(log10(abs(recon_diffracted))); 285 | caxis([0 5]) 286 | colorbar 287 | title('Reconstructed diffraction pattern') 288 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Reconst_Diff.png')) 289 | 290 | figure 291 | fig = imagesc(ref_support); 292 | title('Final support') 293 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_final support.png')) 294 | 295 | % 296 | fig7 = classFig('PPT'); 297 | fig = plot(err_obj_space); 298 | hold on 299 | plot(err_fourier_space, 'Color','r') 300 | legend('Error obj space', 'Error Fourier space') 301 | title('Error metrics - real and Fourier space') 302 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Error_metrics.png')) 303 | 304 | fig8 = classFig('PPT'); 305 | fig = imagesc(x_sam(1,:), y_sam(:,1),(imag(rec_object))); 306 | title('Imagenary part') 307 | title_val2 = ['Imaginary part obj., beta = 
',num2str(beta)]; 308 | title(title_val2) 309 | colormap hot 310 | colorbar 311 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Imag_part_recon.png')) 312 | 313 | 314 | 315 | end 316 | 317 | %% interpolation 318 | seed_obj = ref_support; 319 | if reconst_interp == 1 320 | 321 | rec_object(isnan(rec_object)) = 0; 322 | rec_phs(isnan(rec_phs)) = 0; 323 | rec_object(seed_obj==0) = 0; 324 | 325 | CC = bwconncomp(abs(rec_object)); 326 | 327 | regions = CC.PixelIdxList; 328 | obj_ave_amp = NaN(size(diff_int, 2)); 329 | obj_ave_ph = NaN(size(diff_int, 2)); 330 | 331 | obj_amp = abs(rec_object3); 332 | obj_amp = obj_amp/max(max(obj_amp)); 333 | obj_amp(isnan(obj_amp)) = 0; 334 | obj_amp = imgaussfilt(obj_amp, 5); 335 | obj_amp(seed_obj==0) = 0; 336 | 337 | figure 338 | imagesc(obj_amp) 339 | 340 | norm_phase(obj_amp < phase_threshold) = 0; 341 | norm_phase = rec_phs; 342 | 343 | for i=1:length(regions) 344 | IND = cell2mat(regions(i)); 345 | s = [size(diff_int,1),size(diff_int,2)]; 346 | [Ind_row,Ind_col] = ind2sub(s,IND); 347 | amp_ave = 0; 348 | ph_ave = 0; 349 | 350 | temp_amps = zeros(size(IND)); 351 | amps = zeros(size(IND)); 352 | phs = zeros(size(IND)); 353 | 354 | for j=1:length(IND) 355 | temp_amps(j) = abs(rec_object(Ind_row(j),Ind_col(j))); 356 | % amp_ave = amp_ave + abs(rec_object(Ind_row(j),Ind_col(j))); 357 | % ph_ave = ph_ave + norm_phase(Ind_row(j),Ind_col(j)); 358 | end 359 | 360 | maximum = max(temp_amps); 361 | for j=1:length(IND) 362 | if temp_amps(j) > 0.5*maximum 363 | phs(j) = norm_phase(Ind_row(j),Ind_col(j)); 364 | end 365 | if temp_amps(j) > 0.3*maximum 366 | % amps(j) = abs(rec_object(Ind_row(j),Ind_col(j))); 367 | amps(j) = obj_amp(Ind_row(j),Ind_col(j)); 368 | end 369 | end 370 | 371 | phs = phs(phs~=0); 372 | amps = amps(amps~=0); 373 | 374 | Ind_row_ave = ceil(sum(Ind_row)/length(Ind_row)); 375 | Ind_col_ave = ceil(sum(Ind_col)/length(Ind_col)); 376 | 377 | % obj_ave_amp(Ind_row_ave,Ind_col_ave) = amp_ave/length(IND); 378 | 
% obj_ave_ph(Ind_row_ave,Ind_col_ave) = ph_ave/length(IND); 379 | 380 | obj_ave_amp(Ind_row_ave,Ind_col_ave) = mean(amps); 381 | obj_ave_ph(Ind_row_ave,Ind_col_ave) = mean(phs); 382 | 383 | end 384 | 385 | % 386 | % fig8 =classFig('PPT2'); 387 | % imagesc(x_sam(1,:), y_sam(:,1),obj_ave_amp); 388 | % colormap hot 389 | % fig4 =classFig('PPT2'); 390 | % imagesc(x_sam(1,:), y_sam(:,1),obj_ave_ph); 391 | 392 | valid = ~isnan(obj_ave_amp); 393 | obj_amp_interp1 = griddata(x_sam(valid),y_sam(valid),obj_ave_amp(valid),x_sam,y_sam,'cubic'); 394 | 395 | fig9 =classFig('PPT'); 396 | fig = imagesc(x_sam(1,:), y_sam(:,1),obj_amp_interp1); 397 | title('amp interpolation') 398 | colormap hot 399 | colorbar 400 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_amp_interpol.png')) 401 | 402 | valid = ~isnan(obj_ave_ph); 403 | obj_ph_interp1 = griddata(x_sam(valid),y_sam(valid),obj_ave_ph(valid),x_sam,y_sam,'cubic'); 404 | 405 | fig5 = classFig('PPT'); 406 | fig = imagesc(x_sam(1,:), y_sam(:,1),obj_ph_interp1); 407 | title('phs interpolation') 408 | colorbar 409 | saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_phs_interpol.png')) 410 | end 411 | if dump_recon 412 | close all 413 | % Save workspace 414 | save(strcat(dump_folder, '\\', int2str(samp_cam_dist),'_recon_results')); 415 | end 416 | end 417 | -------------------------------------------------------------------------------- /matlab_cdi/seeded_reconst_func.m: -------------------------------------------------------------------------------- 1 | function [final_object, support, err_fourier_space, err_obj_space, recon_diffracted] = seeded_reconst_func(diff_int, parameters, conditionals, in_BS, seed_obj) 2 | 3 | BS_used = conditionals(1) ; 4 | update_me = conditionals(2) ; 5 | use_RAAR = conditionals(3) ; 6 | positivity = true; 7 | 8 | beta = parameters(1); 9 | iter_total = parameters(2); 10 | iter_cycle = parameters(3); 11 | HIO_num = parameters(4); 12 | 13 | pixel_nr_y = size(diff_int,1); 14 | 
pixel_nr_x = size(diff_int,2); 15 | 16 | phase_rand = -pi + 2*pi*rand(pixel_nr_y,pixel_nr_x); 17 | full_diff_field = diff_int.^0.5 .* exp(1i*phase_rand); 18 | diff_amp = abs(full_diff_field); 19 | 20 | in_support = double(seed_obj); 21 | out_of_support = double(not(seed_obj)); 22 | 23 | Err = zeros(1,iter_total); 24 | Err2 = zeros(1,iter_total); 25 | fourier_error = zeros(1,iter_total); 26 | 27 | P_filtered = full_diff_field; 28 | if update_me == 1 29 | figure; 30 | set(gcf,'Units', 'Normalized', 'OuterPosition', [0.05 0.3 0.9 0.7]); 31 | end 32 | 33 | reduce_supp = zeros(length(diff_int)); 34 | % iteration starts 35 | for i=1:iter_total 36 | % i 37 | if mod(i,100) == 0 38 | i 39 | end 40 | S = ifftshift(ifft2(ifftshift(P_filtered))); % to real space 41 | 42 | sum_total = sum(sum(abs(S).^2)); 43 | sum_out = sum(sum(abs(S.*out_of_support).^2)); 44 | sum_in = sum(sum(abs(S.*in_support).^2)); 45 | Err(1,i) = sqrt(sum_out/sum_total); % error in real space 46 | Err2(1,i) = sqrt(sum_out/sum_in); % alternative formulation 47 | 48 | count = rem(i,iter_cycle); 49 | in_support_prev = in_support; 50 | 51 | % % reduce_supp(:,1:0.5*length(reduce_supp)) = 1; 52 | % reduce_supp(:,0.5*length(reduce_supp):end) = 1; 53 | % 54 | % if i > 500 && i < 550 55 | % in_support = reduce_supp.*in_support; % support reduction for escaping from twin-image problem 56 | % out_of_support = double(not(logical(in_support))); 57 | % else 58 | % in_support = double(seed_obj); 59 | % out_of_support = double(not(seed_obj)); 60 | % end 61 | 62 | %% object space constraint 63 | % --------------------------------------------------------------------------------------- 64 | if count0.75*iter_total 65 | % HIO or RAAR 66 | 67 | % Hybrid input-output algorithm or RAAR 68 | if i==1 69 | S_filtered = S; 70 | else if i~=1 && use_RAAR==0 71 | obj_in_supp = S.*in_support; % common version: S, original version: S_prev 72 | obj_out_supp = (S_prev - beta*S).*out_of_support; 73 | S_filtered = obj_in_supp + obj_out_supp; 
74 | else 75 | beta_n = beta; % + (1 - beta)*(1 - exp(-(0.3*iter_total/7)^3)); 76 | obj_in_supp = S.*in_support; 77 | obj_out_supp = (beta_n*S_prev + (1 - 2*beta_n)*S).*out_of_support; 78 | S_filtered = obj_in_supp + obj_out_supp; 79 | end 80 | end 81 | 82 | S_prev = S_filtered; 83 | 84 | else 85 | % Error reduction 86 | in_support = double(seed_obj); 87 | S_filtered = S.*in_support; 88 | S_prev = S_filtered; 89 | end 90 | 91 | 92 | %% Fourier Space operations 93 | % ------------------------------------------------------------------------------------- 94 | P = fftshift(fft2(fftshift(S_filtered))); % to Fourier space 95 | 96 | fourier_error(1,i) = sqrt(sum(sum((abs(P) - diff_amp).^2))/sum(sum(diff_int))); 97 | 98 | if BS_used == 1 99 | % diff_amp_BS = diff_amp + abs(P.*in_BS); 100 | diff_amp_BS = diff_amp.*double(not(logical(in_BS))) + abs(P.*in_BS); 101 | % diff_amp_BS = diff_amp.*double(not(logical(in_BS))) + max(abs(P),diff_amp).*in_BS; % value is replaced only if it is greater than the measured one 102 | P_filtered = diff_amp_BS .*exp(1i*angle(P)); % Fourier constraint I 103 | else 104 | P_filtered = diff_amp .*exp(1i*angle(P)); % Fourier constraint I 105 | end 106 | 107 | % Apply positivity constraint 108 | % if positivity && i > 200 109 | % S(imag(S)<0) = 0; 110 | % end 111 | % 112 | % if mod(i, 1000) == 0 113 | % figure 114 | % imagesc(abs(S)) 115 | % title('Amplitude after pos constraint.') 116 | % colormap hot 117 | % colorbar 118 | % 119 | % h = figure 120 | % imagesc(imag(S)) 121 | % title('Imaginary part after pos. 
constraint.') 122 | % colormap hot 123 | % colorbar 124 | % drawnow; 125 | % 126 | % figure 127 | % temp = S; 128 | % temp(abs(temp)<0.03*max(max(abs(temp)))) = NaN; 129 | % imagesc(angle(temp)); 130 | % title('phase') 131 | % colormap hot 132 | % colorbar 133 | % 134 | % waitfor(h) 135 | % end 136 | 137 | %% for updating the figures 138 | % ------------------------------------------------------------------------------------- 139 | 140 | if rem(i,50) == 0 && update_me ==1 141 | subplot(1,2,1); 142 | % pcolor(x_vec,y_vec,rot90(abs(S),2)); shading interp 143 | pcolor(abs(S)); shading interp 144 | set(gca,'FontName','Calibri','FontSize',18); 145 | axis square 146 | hold on 147 | drawnow 148 | 149 | % to save animation into a movie 150 | % frame = getframe; % uncomment for animation purposes 151 | % writeVideo(v,frame); % uncomment for animation purposes 152 | 153 | subplot(1,2,2); 154 | scatter(i,fourier_error(i)); 155 | axis square 156 | hold on 157 | drawnow 158 | 159 | end 160 | end 161 | 162 | %close(v); % uncomment for animation purposes 163 | final_object = S; 164 | support = in_support_prev; 165 | err_fourier_space = fourier_error; 166 | err_obj_space = Err; 167 | recon_diffracted = abs(P).^2; 168 | -------------------------------------------------------------------------------- /matlab_cdi/seeded_run_CDI_noprocessing.m: -------------------------------------------------------------------------------- 1 | clear all 2 | clc 3 | 4 | [diffraction_pattern_file, support_file, retrieved_obj_file, reconstructed_file, real_interp_file, imag_interp_file] = loaddata(); 5 | 6 | diffraction = load(diffraction_pattern_file, 'diffraction'); 7 | diffraction = getfield(diffraction, 'diffraction'); 8 | support = load(support_file, 'support'); 9 | support = getfield(support, 'support'); 10 | 11 | % diffraction = load('diffraction.mat', 'diffraction'); 12 | % diffraction = getfield(diffraction, 'diffraction'); 13 | % support = load('support.mat', 'support'); 14 | % support = 
getfield(support, 'support'); 15 | 16 | meas_diff = double(diffraction); 17 | seed_obj = double(support); 18 | 19 | % the parameters 20 | beta = 0.9; %between 0.8 and 0.98; if higher, feedback is stronger 21 | iter_total = 400; % total number of iterations 22 | iter_cycle = 401; % cycle of HIO/RAAR and ER 23 | HIO_num = 200; % HIO/RAAR per cycle 24 | parameters = [beta iter_total iter_cycle HIO_num]; 25 | 26 | in_BS = NaN; 27 | BS_used = 0 ; 28 | update_me = 0 ; % get an GUI update or not, slow if updated 29 | use_RAAR = 1; % 1 for RAAR and 0 for HIO 30 | conditionals = [BS_used update_me use_RAAR]; 31 | 32 | % run this for the same diffraction patterns used in the neural network retrieval 33 | % tic 34 | [rec_object, ref_support, err_fourier_space, err_obj_space, recon_diffracted] = seeded_reconst_func(meas_diff, parameters, conditionals, in_BS, seed_obj); 35 | % rec_object(1:64,:)=0.1*max(max(abs(rec_object)))+0.1*max(max(abs(rec_object)))*1i; 36 | % toc 37 | reconst_plot=1; 38 | reconst_interp=1; 39 | num_pxls_dp = 512; % number of pixels of the detector 40 | cam_pxl_size = 13.5; 41 | samp_cam_dist=32000; 42 | lam = 0.0135; % in microns 43 | on_chip_binning=2; 44 | phase_threshold = 0.1; 45 | %% ploting the results 46 | % ------------------------------------------------------------------- 47 | if reconst_plot == 1 48 | % Generate grid of spatial frequencies 49 | theta_max_edge = atand((0.5*num_pxls_dp*cam_pxl_size)/samp_cam_dist); 50 | q_max_measured = 2*sind(theta_max_edge/2)/lam; 51 | 52 | q_x_cam_plus = linspace(0,q_max_measured,0.5*size(meas_diff,1)); 53 | q_x_cam_minus = -fliplr(q_x_cam_plus(2:end)); 54 | q_x_cam_minus = [2*q_x_cam_minus(1)-q_x_cam_minus(2) q_x_cam_minus]; 55 | 56 | q_x_cam = [q_x_cam_minus q_x_cam_plus]; 57 | 58 | q_x_cam_cr = q_x_cam; 59 | 60 | num_pxls = (num_pxls_dp)/on_chip_binning; 61 | 62 | q_x_cam_arr = downsample(q_x_cam_cr, 1); 63 | q_y_cam_arr = q_x_cam_arr; 64 | [q_x_cam,q_y_cam] = meshgrid(q_x_cam_arr ,q_y_cam_arr); 65 | 66 | % 
df = (1 / N * dx) 67 | % dx = (1 / N * df) 68 | % calcute frequency axes from position axes 69 | dfgh=size(q_x_cam); 70 | N=dfgh(1); 71 | d_alpha=q_x_cam_arr(2)-q_x_cam_arr(1); 72 | d_x = (1 / N * d_alpha); 73 | % make linspace 74 | x_sam1=-N/2:1:N/2-1; 75 | x_sam1=x_sam1*d_x; 76 | [x_sam,y_sam] = meshgrid(x_sam1,x_sam1); 77 | % [AC, x_sam, y_sam] = fourier2Dplus(meas_diff, q_x_cam, q_y_cam); % just to get the axes 78 | 79 | 80 | 81 | % x_sam=() 82 | 83 | % fig1 = classFig('PPT'); 84 | % fig = imagesc(x_sam(1,:), y_sam(:,1),(abs(rec_object))); 85 | title('Reconstructed amplitude') 86 | title_val2 = ['Reconstructed amp, beta = ',num2str(beta)]; 87 | title(title_val2) 88 | colormap hot 89 | colorbar 90 | % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Reconst_Amp.png'), 'png') 91 | 92 | 93 | rec_object3 = rec_object; 94 | [m, ind] = max(abs(rec_object3(:))); 95 | [x, y] = ind2sub(size(rec_object3), ind); 96 | abs(rec_object3(x, y)) 97 | rec_object3(abs(rec_object3)<0.05 * max(max(abs(rec_object3)))) = NaN; 98 | rec_phs = angle(rec_object3 * exp(-1j*angle(rec_object3(x,y)))); 99 | 100 | norm_phase = angle(rec_object3); 101 | norm_phase(isnan(norm_phase)) = 0; 102 | range = max(max(norm_phase)) - min(min(norm_phase)); 103 | 104 | % fig2 = classFig('PPT'); 105 | % fig = imagesc(x_sam(1,:), y_sam(:,1), rec_phs); 106 | colormap hot 107 | colorbar 108 | title_val = ['Reconstructed phase, beta = ',num2str(beta)]; 109 | title(title_val) 110 | % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Reconst_Phase.png')) 111 | 112 | % 113 | % fig3 = classFig('PPT'); 114 | % fig = imagesc(log10(meas_diff)); 115 | caxis([0 5]) 116 | colorbar 117 | title('Measured diffraction pattern') 118 | % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Measure_Diff.png')) 119 | 120 | % 121 | % % fig4 = classFig('PPT'); 122 | % fig = imagesc(log10(abs(recon_diffracted))); 123 | caxis([0 5]) 124 | colorbar 125 | title('Reconstructed diffraction 
pattern') 126 | % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Reconst_Diff.png')) 127 | 128 | figure 129 | fig = imagesc(ref_support); 130 | title('Final support') 131 | % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_final support.png')) 132 | 133 | % 134 | % fig7 = classFig('PPT'); 135 | fig = plot(err_obj_space); 136 | hold on 137 | plot(err_fourier_space, 'Color','r') 138 | legend('Error obj space', 'Error Fourier space') 139 | title('Error metrics - real and Fourier space') 140 | % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Error_metrics.png')) 141 | 142 | % fig8 = classFig('PPT'); 143 | % fig = imagesc(x_sam(1,:), y_sam(:,1),(imag(rec_object))); 144 | title('Imagenary part') 145 | title_val2 = ['Imaginary part obj., beta = ',num2str(beta)]; 146 | title(title_val2) 147 | colormap hot 148 | colorbar 149 | % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_Imag_part_recon.png')) 150 | 151 | 152 | 153 | end 154 | 155 | %% interpolation 156 | seed_obj = ref_support; 157 | if reconst_interp == 1 158 | 159 | rec_object(isnan(rec_object)) = 0; 160 | rec_phs(isnan(rec_phs)) = 0; 161 | rec_object(seed_obj==0) = 0; 162 | 163 | CC = bwconncomp(abs(rec_object)); 164 | 165 | regions = CC.PixelIdxList; 166 | obj_ave_real = NaN(size(meas_diff, 2)); 167 | obj_ave_imag = NaN(size(meas_diff, 2)); 168 | 169 | obj_amp = abs(rec_object3); 170 | obj_amp = obj_amp/max(max(obj_amp)); 171 | obj_amp(isnan(obj_amp)) = 0; 172 | obj_amp = imgaussfilt(obj_amp, 5); 173 | obj_amp(seed_obj==0) = 0; 174 | 175 | % figure 176 | % imagesc(obj_amp) 177 | 178 | norm_phase(obj_amp < phase_threshold) = 0; 179 | norm_phase = rec_phs; 180 | 181 | % use the real and imaginary part to do interpolation 182 | obj_real = real(rec_object3); 183 | obj_real = obj_real/max(max(obj_real)); 184 | obj_real(isnan(obj_real)) = 0; 185 | obj_real = imgaussfilt(obj_real, 5); 186 | obj_real(seed_obj==0) = 0; 187 | 188 | obj_imag = 
imag(rec_object3); 189 | obj_imag = obj_imag/max(max(obj_imag)); 190 | obj_imag(isnan(obj_imag)) = 0; 191 | obj_imag = imgaussfilt(obj_imag, 5); 192 | obj_imag(seed_obj==0) = 0; 193 | 194 | for i=1:length(regions) 195 | IND = cell2mat(regions(i)); 196 | s = [size(meas_diff,1),size(meas_diff,2)]; 197 | [Ind_row,Ind_col] = ind2sub(s,IND); 198 | amp_ave = 0; 199 | ph_ave = 0; 200 | 201 | temp_amps = zeros(size(IND)); 202 | reals = zeros(size(IND)); 203 | imags = zeros(size(IND)); 204 | 205 | for j=1:length(IND) 206 | temp_amps(j) = abs(rec_object(Ind_row(j),Ind_col(j))); 207 | %amp_ave = amp_ave + abs(rec_object(Ind_row(j),Ind_col(j))); 208 | %ph_ave = ph_ave + norm_phase(Ind_row(j),Ind_col(j)); 209 | end 210 | 211 | maximum = max(temp_amps); 212 | for j=1:length(IND) 213 | if temp_amps(j) > 0.5*maximum 214 | imags(j) = obj_imag(Ind_row(j),Ind_col(j)); 215 | end 216 | if temp_amps(j) > 0.3*maximum 217 | %reals(j) = abs(rec_object(Ind_row(j),Ind_col(j))); 218 | reals(j) = obj_real(Ind_row(j),Ind_col(j)); 219 | end 220 | end 221 | 222 | imags = imags(imags~=0); 223 | reals = reals(reals~=0); 224 | 225 | Ind_row_ave = ceil(sum(Ind_row)/length(Ind_row)); 226 | Ind_col_ave = ceil(sum(Ind_col)/length(Ind_col)); 227 | 228 | %obj_ave_real(Ind_row_ave,Ind_col_ave) = amp_ave/length(IND); 229 | %obj_ave_imag(Ind_row_ave,Ind_col_ave) = ph_ave/length(IND); 230 | 231 | obj_ave_real(Ind_row_ave,Ind_col_ave) = mean(reals); 232 | obj_ave_imag(Ind_row_ave,Ind_col_ave) = mean(imags); 233 | 234 | end 235 | 236 | %fig8 =classFig('PPT2'); 237 | %imagesc(x_sam(1,:), y_sam(:,1),obj_ave_real); 238 | %colormap hot 239 | %fig4 =classFig('PPT2'); 240 | %imagesc(x_sam(1,:), y_sam(:,1),obj_ave_imag); 241 | 242 | % plot the real and imaginary retrieved 243 | close all; 244 | figure(); 245 | h=pcolor(real(rec_object)); 246 | title("real(rec object)"); 247 | set(h,'EdgeColor','none'); 248 | figure(); 249 | h=pcolor(imag(rec_object)); 250 | title("imag(rec object)"); 251 | set(h,'EdgeColor','none'); 
252 | 253 | valid = ~isnan(obj_ave_real); 254 | obj_real_interp1 = griddata(x_sam(valid),y_sam(valid),obj_ave_real(valid),x_sam,y_sam,'cubic'); 255 | 256 | % fig9 =classFig('PPT'); 257 | % figure(); 258 | % fig = imagesc(x_sam(1,:), y_sam(:,1),obj_real_interp1); 259 | % title('real interpolation') 260 | % colormap hot 261 | % colorbar 262 | % % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_amp_interpol.png')) 263 | 264 | figure(); 265 | h=pcolor(obj_real_interp1); 266 | title("obj real interp1"); 267 | set(h,'EdgeColor','none'); 268 | 269 | valid = ~isnan(obj_ave_imag); 270 | obj_imag_interp1 = griddata(x_sam(valid),y_sam(valid),obj_ave_imag(valid),x_sam,y_sam,'cubic'); 271 | 272 | % % fig5 = classFig('PPT'); 273 | % figure(); 274 | % fig = imagesc(x_sam(1,:), y_sam(:,1),obj_imag_interp1); 275 | % title('imag interpolation') 276 | % colorbar 277 | % % saveas(fig, strcat(dump_folder, '\\', int2str(samp_cam_dist), '_phs_interpol.png')) 278 | 279 | figure(); 280 | h=pcolor(obj_imag_interp1); 281 | title("obj imag interp1"); 282 | set(h,'EdgeColor','none'); 283 | 284 | end 285 | 286 | % figure 287 | % imagesc(abs(rec_object)) 288 | % title('abs(rec_object)') 289 | % saveas(gcf, 'abs_rec_object.png') 290 | 291 | % figure 292 | % imagesc(real(rec_object)) 293 | % title('real(rec_object)') 294 | % saveas(gcf, 'real_rec_object.png') 295 | 296 | % figure 297 | % imagesc(imag(rec_object)) 298 | % title('imag(rec_object)') 299 | % saveas(gcf, 'imag_rec_object.png') 300 | 301 | % save .mat file 302 | save(retrieved_obj_file, 'rec_object'); 303 | save(reconstructed_file, 'recon_diffracted'); 304 | save(real_interp_file, 'obj_real_interp1'); 305 | save(imag_interp_file, 'obj_imag_interp1'); 306 | exit(); 307 | 308 | 309 | 310 | -------------------------------------------------------------------------------- /mp_plotdata/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import glob 3 
| import numpy as np 4 | import os 5 | 6 | 7 | def makeplots(run_name): 8 | train_log = np.loadtxt(os.path.join(run_name, "train_log.dat")) 9 | validation_log = np.loadtxt(os.path.join(run_name, "validation_log.dat")) 10 | 11 | # training and validation error for real, imaginary, reconstruction 12 | # plt.figure() 13 | fig, ax = plt.subplots(2,1, figsize=(5,7)) 14 | fig.subplots_adjust(hspace=0.0, left=0.15) 15 | ax[0].plot(train_log[:,1], train_log[:,2]*1e3, 16 | label="real loss (training)", 17 | color="b") # real_loss 18 | 19 | ax[0].plot(validation_log[:,1], validation_log[:,2]*1e3, 20 | label="real loss(validation)", 21 | color="b", 22 | linestyle="dashed") # real_loss 23 | 24 | ax[0].plot(train_log[:,1], train_log[:,3]*1e3, 25 | label="imaginary loss (training)", 26 | color="r") # imag_loss 27 | 28 | 29 | ax[0].plot(validation_log[:,1], validation_log[:,3]*1e3, 30 | label="imaginary loss (validation)", 31 | color="r", 32 | linestyle="dashed") # imag_loss 33 | 34 | ax[0].legend() 35 | ax[0].set_ylabel(r"error [mse] $\cdot 10^3$") 36 | ax[0].set_title("Error plotted during training") 37 | 38 | ax[1].plot(train_log[:,1], train_log[:,4]*1e3, 39 | label="reconstruction loss (training)", 40 | color="g") # reconstruction_loss 41 | 42 | ax[1].plot(validation_log[:,1], validation_log[:,4]*1e3, 43 | label="reconstruction loss (validation)", 44 | color="g", linestyle="dashed") # reconstruction_loss 45 | ax[1].set_ylabel(r"error [mse] $\cdot 10^3$") 46 | ax[1].legend() 47 | ax[1].set_xlabel("epoch") 48 | fig.savefig("./training_validation_loss_{}.png".format(run_name)) 49 | 50 | 51 | measured_trace_data = {} 52 | for file in glob.glob(r'./{}/*1-0*'.format(run_name)): 53 | data = np.loadtxt(file) 54 | measured_trace_data[os.path.split(file)[-1].split(".")[0]] = data 55 | 56 | # plt.show() 57 | # plt.figure() 58 | added_labels={} 59 | for m in measured_trace_data.keys(): 60 | _label="" 61 | 62 | if "multiple_measurements" in m: 63 | fignum=3 64 | 
figtitle="multiple_measurements" 65 | if "z0" in m: 66 | fignum=4 67 | figtitle="z0" 68 | if "z-500" in m: 69 | fignum=5 70 | figtitle="z-500" 71 | if "z-1000" in m: 72 | fignum=6 73 | figtitle="z-1000" 74 | 75 | if "_None_" in m: 76 | color="red" 77 | _label="NONE" 78 | 79 | if "_lr_" in m: 80 | color="blue" 81 | _label="LR" 82 | 83 | if "_lrud_" in m: 84 | color="green" 85 | _label="LRUD" 86 | 87 | if "_ud_" in m: 88 | color="orange" 89 | _label="UD" 90 | 91 | plt.figure(fignum) 92 | plt.title(figtitle) 93 | if str(fignum)+_label in added_labels.keys(): 94 | plt.plot(measured_trace_data[m][:,1] # epoch 95 | , measured_trace_data[m][:,2]*1e3, # reconstruction_loss 96 | color=color 97 | ) 98 | else: 99 | plt.plot(measured_trace_data[m][:,1] # epoch 100 | , measured_trace_data[m][:,2]*1e3, # reconstruction_loss 101 | label=_label, 102 | color=color 103 | ) 104 | added_labels[str(fignum)+_label]=True 105 | # plt.title("Measured Trace Reconstruction Error [mse]") 106 | plt.xlabel("epoch") 107 | plt.ylabel(r"error [mse] $\cdot 10^3$") 108 | plt.legend() 109 | 110 | plt.savefig("./measured_trace_error_{}.png".format(run_name)) 111 | 112 | 113 | if __name__ == "__main__": 114 | makeplots("multiplefiles_5_test_A-6_0--S-0_5") 115 | plt.show() 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /nice_plot_nn_iterative_compared_noise.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | import datagen 4 | import numpy as np 5 | import pickle 6 | import diffraction_functions 7 | import params 8 | 9 | def place_colorbar(im,ax,offsetx,offsety,ticks:list,color:str,labels:list=None): 10 | caxis=fig.add_axes([ax.get_position().bounds[0]+offsetx,ax.get_position().bounds[1]+offsety,0.07,0.01]) 11 | fig.colorbar(im,cax=caxis,orientation='horizontal') 12 | caxis.xaxis.set_ticks_position('top') 13 | caxis.xaxis.set_ticks(ticks) 14 | if 
labels: caxis.xaxis.set_ticklabels(labels) 15 | caxis.tick_params(axis='x',colors=color) 16 | 17 | if __name__ == "__main__": 18 | parser=argparse.ArgumentParser() 19 | parser.add_argument('--wfsensor',type=int) 20 | args,_=parser.parse_known_args() 21 | files={} 22 | cts = ['0','50','40','30'] 23 | for filename in ['out_'+_ct+'.p' for _ct in cts]: 24 | with open(filename,'rb') as file: 25 | files[filename.split('.')[0]]=pickle.load(file) 26 | 27 | N=np.shape(np.squeeze(files['out_0']['retrieved_nn']['measured_pattern']))[0] 28 | simulation_axes, _ = diffraction_functions.get_amplitude_mask_and_imagesize(N, int(params.params.wf_ratio*N)) 29 | x=simulation_axes['object']['x'] # meters 30 | x*=1e6 31 | f=simulation_axes['diffraction_plane']['f'] # 1/meters 32 | f*=1e-6 33 | 34 | for _ct in cts: 35 | 36 | fig = plt.figure(figsize=(10,10)) 37 | fig.subplots_adjust(wspace=0.05,hspace=0.05) 38 | fig.suptitle('nn iterative compared') 39 | gs = fig.add_gridspec(3,4) 40 | ax=fig.add_subplot(gs[0:2,0:2]) 41 | ax.pcolormesh(f,f,files['out_'+_ct]['retrieved_nn']['measured_pattern'],cmap='jet') 42 | ax.xaxis.set_ticks([]) 43 | ax.set_ylabel(r"frequency [1/m]$\cdot 10^{6}$") 44 | for i,retrieval,name in zip( 45 | [0,1,2],['retrieved_nn','retrieved_iterative','actual'], 46 | ['Neural Network\nRetrieved ','Iterative\nRetrieved ','Actual '] 47 | ): 48 | ax=fig.add_subplot(gs[i,2]) 49 | complex_obj=np.squeeze(files['out_'+_ct][retrieval]['real_output'] + 1j*files['out_'+_ct][retrieval]['imag_output']) 50 | complex_obj*=(1/np.max(np.abs(complex_obj))) 51 | # phase at center set to 0 52 | complex_obj*=np.exp(-1j * np.angle(complex_obj[N//2,N//2])) 53 | 54 | phase = np.angle(complex_obj) 55 | phase[np.abs(complex_obj)**2 < 0.001]=0 56 | im=ax.pcolormesh(x,x,np.abs(complex_obj)**2,vmin=0,vmax=1,cmap='jet') 57 | ax.text(0.05,0.95,name+'Intensity',transform=ax.transAxes,color='white',weight='bold',va='top') 58 | ax.xaxis.set_ticks_position('top'); ax.xaxis.set_label_position('top') 59 | 
ax.yaxis.set_ticks([]) 60 | if not i==0: ax.xaxis.set_ticks([]) 61 | else:ax.set_xlabel(("position [um]")) 62 | place_colorbar(im,ax,offsetx=0.015,offsety=0.005,ticks=[0,0.5,1],color='white') 63 | 64 | ax=fig.add_subplot(gs[i,3]) 65 | im=ax.pcolormesh(x,x,phase,vmin=-np.pi,vmax=np.pi,cmap='jet') 66 | ax.yaxis.set_ticks_position('right'); ax.yaxis.set_label_position('right') 67 | if not i==1: ax.yaxis.set_ticks([]) 68 | else:ax.set_ylabel(("position [um]")) 69 | ax.xaxis.set_ticks([]) 70 | ax.text(0.05,0.95,name+'Phase',transform=ax.transAxes,color='black',weight='bold',va='top') 71 | place_colorbar(im,ax,offsetx=0.015,offsety=0.005,ticks=[-3.14,0,3.14],color='black',labels=['-pi','0','+pi']) 72 | 73 | # draw zernike coefficients 74 | coefficients=np.squeeze(files['out_'+_ct]['actual']['coefficients']); assert len(np.shape(coefficients))==1 75 | scale=np.squeeze(files['out_'+_ct]['actual']['scale']); assert len(np.shape(scale))==0 76 | fig.text(0.05,0.34,'Zernike Coefficients:',size=20,color='red') 77 | c_str="" 78 | for _i, _c, _z in zip(range(len(coefficients)), 79 | coefficients, 80 | datagen.DataGenerator(1024,256).zernike_cvector): 81 | c_str += '\n' if (_i%3==0) else ' ' 82 | c_str += r"$Z^{"+str(_z.m)+"}_{"+str(_z.n)+"}$" 83 | c_str+=" " 84 | c_str += "%.2f"%_c 85 | 86 | # c_str+="%.2f"%_c+'\n' 87 | fig.text(0.03,0.34,c_str,ha='left',va='top',size=20) 88 | fig.text(0.05,0.07,'Scale:',size=20,color='red') 89 | fig.text(0.03,0.05,'S:'+"%.2f"%scale,ha='left',va='top',size=20) 90 | 91 | fig.savefig('./wfs_'+str(args.wfsensor)+'_iterative_nn_compared_noise_'+_ct+'.png') 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /noise_resistant_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def convolutional_layer_nopadding(input_x, shape, activate, stride): 4 | W = init_weights(shape) 5 | b = init_bias([shape[3]]) 6 | 7 | if activate == 'relu': 
8 | return tf.nn.relu(conv2d_nopad(input_x, W, stride) + b) 9 | 10 | if activate == 'leaky': 11 | return tf.nn.leaky_relu(conv2d_nopad(input_x, W, stride) + b) 12 | 13 | elif activate == 'none': 14 | return conv2d_nopad(input_x, W, stride) + b 15 | 16 | def max_pooling_layer(input_x, pool_size_val, stride_val, pad=False): 17 | if pad: 18 | return tf.layers.max_pooling2d(input_x, pool_size=[pool_size_val[0], pool_size_val[1]], strides=[stride_val[0], stride_val[1]], padding="SAME") 19 | else: 20 | return tf.layers.max_pooling2d(input_x, pool_size=[pool_size_val[0], pool_size_val[1]], strides=[stride_val[0], stride_val[1]], padding="VALID") 21 | 22 | def part1_purple(input): 23 | conv1 = convolutional_layer_nopadding(input, shape=[10, 10, 1, 20], activate='relu', stride=[1, 1]) 24 | print("conv1", conv1) 25 | 26 | pool1 = max_pooling_layer(conv1, pool_size_val=[6, 6], stride_val=[3, 3]) 27 | print("pool1", pool1) 28 | 29 | conv2 = convolutional_layer_nopadding(pool1, shape=[6, 6, 20, 40], activate='relu', stride=[1, 1]) 30 | print("conv2", conv2) 31 | 32 | pool2 = max_pooling_layer(conv2, pool_size_val=[4, 4], stride_val=[2, 2]) 33 | print("pool2 =>", pool2) 34 | 35 | return pool2 36 | 37 | def part2_grey(input): 38 | # center 39 | conv31 = convolutional_layer_nopadding(input, shape=[6, 6, 40, 40], activate='relu', stride=[1, 1]) 40 | conv322 = convolutional_layer_nopadding(conv31, shape=[6, 6, 40, 20], activate='relu', stride=[1, 1]) 41 | 42 | #left 43 | pool3 = max_pooling_layer(input, pool_size_val=[7, 7], stride_val=[2, 2]) 44 | conv321 = convolutional_layer_nopadding(pool3, shape=[1, 1, 40, 20], activate='relu', stride=[1, 1]) 45 | 46 | #right 47 | conv323 = convolutional_layer_nopadding(input, shape=[11, 11, 40, 20], activate='relu', stride=[1, 1]) 48 | 49 | conc1 = tf.concat([conv321, conv322, conv323], axis=3) 50 | 51 | return conc1 52 | 53 | def part3_green(input): 54 | # left side 55 | conv4 = convolutional_layer_nopadding(input, shape=[3, 3, 60, 40], 
activate='relu', stride=[1, 1]) 56 | 57 | # right side 58 | pool4 = max_pooling_layer(input, pool_size_val=[3, 3], stride_val=[1, 1]) 59 | 60 | # concatinate 61 | conc2 = tf.concat([conv4, pool4], axis=3) 62 | 63 | return conc2 64 | 65 | def avg_pooling_layer(input_x, pool_size_val, stride_val, pad=False): 66 | if pad: 67 | return tf.layers.average_pooling2d(input_x, pool_size=[pool_size_val[0], pool_size_val[1]], strides=[stride_val[0], stride_val[1]], padding="SAME") 68 | else: 69 | return tf.layers.average_pooling2d(input_x, pool_size=[pool_size_val[0], pool_size_val[1]], strides=[stride_val[0], stride_val[1]], padding="VALID") 70 | 71 | def conv2d(x, W, stride): 72 | return tf.nn.conv2d(x, W, strides=[1, stride[0], stride[1], 1], padding='SAME') 73 | 74 | def conv2d_nopad(x, W, stride): 75 | return tf.nn.conv2d(x, W, strides=[1, stride[0], stride[1], 1], padding="VALID") 76 | 77 | def init_weights(shape): 78 | init_random_dist = tf.truncated_normal(shape, stddev=0.1, dtype=tf.float32) 79 | return tf.Variable(init_random_dist) 80 | 81 | def init_bias(shape): 82 | init_bias_vals = tf.constant(0.1, shape=shape, dtype=tf.float32) 83 | return tf.Variable(init_bias_vals) 84 | 85 | def normal_full_layer(input_layer, size): 86 | input_size = int(input_layer.get_shape()[1]) 87 | W = init_weights([input_size, size]) 88 | b = init_bias([size]) 89 | return tf.matmul(input_layer, W) + b 90 | 91 | def noise_resistant_phase_retrieval_net(input_image:tf.Tensor,zernike_coefs:int)->(tf.Tensor,tf.Tensor,tf.Tensor): 92 | 93 | with tf.variable_scope("phase"): 94 | pool2 = part1_purple(input_image) 95 | 96 | conc1 = part2_grey(pool2) 97 | 98 | conc2 = part3_green(conc1) 99 | 100 | pool51 = avg_pooling_layer(conc2, pool_size_val=[3, 3], stride_val=[1, 1]) 101 | # print("pool51", pool51) 102 | 103 | pool52 = avg_pooling_layer(pool2, pool_size_val=[5, 5], stride_val=[5, 5], pad=True) 104 | # print("pool52", pool52) 105 | 106 | pool53 = avg_pooling_layer(conc1, pool_size_val=[3, 3], 
stride_val=[2, 2]) 107 | # print("pool53", pool53) 108 | 109 | pool51_flat = tf.contrib.layers.flatten(pool51) 110 | pool52_flat = tf.contrib.layers.flatten(pool52) 111 | pool53_flat = tf.contrib.layers.flatten(pool53) 112 | 113 | conc3 = tf.concat([pool51_flat, pool52_flat, pool53_flat], axis=1) 114 | 115 | fc5 = tf.layers.dense(inputs=conc3, units=256) 116 | 117 | # dropout 118 | hold_prob = tf.placeholder_with_default(1.0, shape=()) 119 | dropout_layer = tf.nn.dropout(fc5, keep_prob=hold_prob) 120 | 121 | # output layer 122 | predicted_zernike = normal_full_layer(dropout_layer, zernike_coefs) 123 | predicted_scale = normal_full_layer(dropout_layer, 1) 124 | 125 | return predicted_zernike, predicted_scale, hold_prob 126 | -------------------------------------------------------------------------------- /noise_test_compareCDI.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # read amples in .dat file 4 | wfsensor=2 5 | network="256_6x6_square_wfs_test" 6 | 7 | python CompareNN_MatlabBilinearInterp.py --network $network --net_type original --wfsensor $wfsensor --outfile 8 | 9 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import PIL.ImageOps 3 | import PIL 4 | import numpy as np 5 | import argparse 6 | 7 | def custom_round(coordinates:list,radius:float,size:tuple)->np.array: 8 | image = 255*np.ones((size[0],size[1],3),dtype=np.uint8) 9 | row=np.arange(0,size[0]).reshape(-1,1) 10 | col=np.arange(0,size[1]).reshape(1,-1) 11 | for _c in coordinates: 12 | # radius 13 | # draw distance 14 | dist = np.sqrt((row-_c[0])**2 + (col-_c[1])**2) 15 | image[dist size[1]//2: 88 | coordinates.append((_r+25+random_r_offset,_c-25+random_c_offset)) 89 | else: coordinates.append((_r+random_r_offset,_c+random_c_offset)) 90 | 
params.wavefront_sensor=custom_round( 91 | coordinates, 92 | radius=30, 93 | size=size 94 | ) 95 | 96 | elif args.wfsensor==2: 97 | # used in visible light experiment 98 | params.wavefront_sensor=wfs_square_6x6_upright() 99 | 100 | elif args.wfsensor==3: 101 | # the modified wavefront sensor that matches the up / right part of the square ones 102 | params.wavefront_sensor=wfs_round_10x10_upright() 103 | 104 | elif args.wfsensor==4: 105 | params.wavefront_sensor=wfs_square_10x10_upright() 106 | 107 | # size used in XUV 108 | # params.wavefron_sensor_size_nm=5*im.size[0]*1e-9 109 | 110 | # size used in visible 111 | params.wavefron_sensor_size_nm=60e-6 112 | 113 | # define materials 114 | # https://refractiveindex.info/?shelf=main&book=Cu&page=Johnson 115 | # https://refractiveindex.info/?shelf=main&book=Si3N4&page=Luke 116 | # wavefront sensor material properties 117 | params_cu = MaterialParams( 118 | lam=633e-9, 119 | dz=10e-9, 120 | delta_Ta = 0.26965-1, # double check this 121 | beta_Ta = 3.4106, 122 | distance = 150e-9 123 | ) 124 | params.material_params.append(params_cu) 125 | 126 | params_Si = MaterialParams( 127 | delta_Ta = 2.0394-1, 128 | beta_Ta = 0.0, 129 | dz=10e-9, 130 | lam=633e-9, 131 | distance=50e-9 132 | ) 133 | params.material_params.append(params_Si) 134 | 135 | # # XUV 136 | # params_Si = MaterialParams( 137 | # beta_Ta = 0.00926, 138 | # delta_Ta = 0.02661, 139 | # dz=10e-9, 140 | # lam=18.5e-9, 141 | # distance=50e-9) 142 | # 143 | # params.material_params.append(params_Si) 144 | # params_cu = MaterialParams( 145 | # beta_Ta = 0.0612, 146 | # delta_Ta = 0.03748, 147 | # lam=18.5e-9, 148 | # dz=10e-9, 149 | # distance = 150e-9) 150 | # params.material_params.append(params_cu) 151 | 152 | 153 | params.wf_ratio=1/3 154 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.13.3 2 | tensorflow-gpu==1.12.0 3 
| matplotlib==3.0.2 4 | tables==3.4.4 5 | scipy==1.1.0 6 | scikit-image==0.16.2 7 | Pillow==7.0.0 8 | imageio==2.6.1 9 | astropy==1.3 10 | pyqt5-tools==5.15.0.1.7 11 | PyQt5==5.15.0 12 | pyqtgraph==0.11.0 13 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set -e 4 | 5 | # source zernike3/loadmodules.sh 6 | 7 | declare -a runs=( 8 | # "varnoise_wfstest_0" # name 9 | # "36000" # train samples 10 | # "50,40,30,20,10,5" # peak count 11 | # "original" # nr network 12 | # "0" # wavefront sensor 13 | 14 | "256_6x6_square_wfs_test" # name 15 | "36000" # train samples 16 | "50,40,30,20,10,5" # peak count 17 | "original" # nr network 18 | "2" # wavefront sensor 6x6 wfs (2) 19 | 20 | ) 21 | i=0 22 | while (( $i < ${#runs[@]})) 23 | do 24 | network=${runs[$i]} 25 | training_samples=${runs[$i+1]} 26 | pc=${runs[$i+2]} 27 | net_type=${runs[$i+3]} 28 | wfsensor=${runs[$i+4]} 29 | 30 | echo $network 31 | echo $training_samples 32 | 33 | # generate dataset 34 | rm ./${network}*.hdf5 35 | # create samples 36 | python datagen.py --count ${training_samples} --name ${network}_train.hdf5 --batch_size 100 --seed 345678 --wfsensor $wfsensor 37 | python datagen.py --count 200 --name ${network}_test.hdf5 --batch_size 100 --seed 8977 --wfsensor $wfsensor 38 | 39 | camera_noise="SquareWFtest/CameraNoise/1_1000/Bild_1.png" 40 | # add noise to samples 41 | python addnoise.py --infile ${network}_train.hdf5 --outfile ${network}_train_noise.hdf5 --peakcount $pc --cameraimage $camera_noise --wfsensor $wfsensor 42 | python addnoise.py --infile ${network}_test.hdf5 --outfile ${network}_test_noise.hdf5 --peakcount $pc --cameraimage $camera_noise --wfsensor $wfsensor 43 | 44 | echo ${network} 45 | python diffraction_net.py --name ${network} --net_type ${net_type} --wfsensor $wfsensor 46 | 47 | i=$i+5 48 | done 49 | 50 | 
-------------------------------------------------------------------------------- /rungdb.sh: -------------------------------------------------------------------------------- 1 | gdb -q -x debug.gdb -ex 'break zernikedatagen.h:296' -ex 'r' 2 | -------------------------------------------------------------------------------- /sample.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwhite510/wavefront-sensor-neural-network/3a521fea276119438fe07b6e2b777e677c1f83c4/sample.p -------------------------------------------------------------------------------- /size_6um_pitch_600nm_diameter_300nm_psize_5nm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jwhite510/wavefront-sensor-neural-network/3a521fea276119438fe07b6e2b777e677c1f83c4/size_6um_pitch_600nm_diameter_300nm_psize_5nm.png -------------------------------------------------------------------------------- /submit_gpu_job.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=test 4 | #SBATCH --output=slurm-%j.out 5 | # 6 | #SBATCH --ntasks=1 7 | #SBATCH --time=1:00:00 8 | #SBATCH --gres=gpu:1 9 | 10 | ##SBATCH --partition=gpu_p100 11 | ##SBATCH --partition=gpu_v100 12 | #SBATCH --partition=gpu_test 13 | # env 14 | 15 | # # ~/python_compiled/bin/python3 --version 16 | ~/python_compiled/bin/python3 diffraction_net.py $batch_run_name 17 | 18 | 19 | --------------------------------------------------------------------------------