├── .gitignore ├── LICENSE ├── README.md ├── algorithms.py ├── data ├── calib │ └── 1.png └── test │ ├── depth │ └── frame_0043.dpt │ └── rgb │ ├── 0.png │ └── frame_0043.png ├── dataset_capture.py ├── env.yml ├── hw ├── calibration_module.py ├── camera_capture_module.py ├── detect_heds_module_path.py ├── lens_control.ino ├── lens_control.py └── slm_display_module.py ├── image_loader.py ├── img ├── citl-asm.png ├── sgd-asm.png ├── sgd-ours.png └── teaser.png ├── main.py ├── params.py ├── prop_ideal.py ├── prop_model.py ├── prop_physical.py ├── prop_submodules.py ├── prop_zernike.py ├── train.py ├── unet.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.DS_Store 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Stanford Computational Imaging Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays
SIGGRAPH Asia 2021 2 | ### [Project Page](http://www.computationalimaging.org/publications/neuralholography3d/) | [Video](https://www.youtube.com/watch?v=EsxGnUd8Efs) | [Paper](https://www.computationalimaging.org/wp-content/uploads/2021/08/NeuralHolography3D.pdf) 3 | PyTorch implementation of
4 | [Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays](http://www.computationalimaging.org/publications/neuralholography3d/)
5 | [Suyeon Choi](http://stanford.edu/~suyeon/)\*, 6 | [Manu Gopakumar](https://www.linkedin.com/in/manu-gopakumar-25032412b/)\*, 7 | [Yifan Peng](http://web.stanford.edu/~evanpeng/), 8 | [Jonghyun Kim](http://j-kim.kr/), 9 | [Gordon Wetzstein](https://computationalimaging.org)
10 | \*denotes equal contribution 11 | in SIGGRAPH Asia 2021 12 | 13 | 14 | 15 | ## Get started 16 | Our code uses [PyTorch Lightning](https://www.pytorchlightning.ai/) and requires PyTorch >=1.10.0, since it relies on complex-valued tensor operations. 17 | 18 | You can set up a conda environment with all dependencies like so: 19 | ``` 20 | conda env create -f env.yml 21 | conda activate neural3d 22 | ``` 23 | 24 | Also, download the [PyCapture2 SDK](https://www.flir.com/products/flycapture-sdk/) and the [HOLOEYE SDK](https://holoeye.com/spatial-light-modulators/slm-software/slm-display-sdk/) (or use [slmPy](https://github.com/wavefrontshaping/slmPy)) and place the SDKs in your environment folder. If your hardware setup (SLM, camera, laser) differs from ours, make sure its Python API works correctly with ```prop_physical.py```. 25 | 26 | ## High-level structure 27 | 28 | The code is organized as follows: 29 | 30 | 31 | `./` 32 | * ```main.py``` generates phase patterns from RGBD/RGB data using SGD. 33 | * ```train.py``` trains parameterized propagation models from captured data at multiple planes. 34 | * ```algorithms.py``` contains the gradient-descent-based algorithm for RGBD/RGB supervision. 35 | 36 | 37 | 38 | * ```prop_*.py``` contain the wave propagation operators (in simulation and on the physical setup). 39 | * ```utils.py``` contains utility functions. 40 | * ```params.py``` contains our default parameter settings. :heavy_exclamation_mark:**(Replace the values here with those of your setup.)**:heavy_exclamation_mark: 41 | * ```image_loader.py``` contains the data loader modules. 42 | 43 | `./hw/` contains modules for hardware control and homography calibration. 44 | 45 | ## Step-by-step instructions 46 | ### 0. Set up parameters 47 | First off, update the values in ```params.py``` so that all of them match your configuration (wavelength, propagation distances, SLM resolution, SLM encoding, etc.). Then you're ready to use our codebase for your setup! Let's quickly test it by running the naive SGD algorithm with the following command: 48 | ``` 49 | # Run naive SGD 50 | c=1; i=0; your_out_path=./results_first; 51 | 52 | # RGB supervision 53 | python main.py --data_path=./data/test/rgb --out_path=${your_out_path} --channel=$c --target=rgb --loss_func=l2 --lr=0.01 --num_iters=1000 --eval_plane_idx=$i 54 | 55 | # RGBD supervision 56 | python main.py --data_path=./data/test --out_path=${your_out_path} --channel=$c --target=rgbd --loss_func=l2 --lr=0.01 --num_iters=1000 57 | 58 | ``` 59 | This will generate an SLM phase pattern in ```${your_out_path}```. Display it on your SLM and check out the holographic image that forms. While you will get an almost flawless image in simulation, you will unfortunately probably get an image like the one below in your experimental setup. Here we show a captured image from our setup: 60 | 61 | 62 | This image degradation is primarily due to the mismatch between the simulated propagation model you just used (ASM) and the actual propagation in the physical setup. To reduce this gap, we proposed the "camera-in-the-loop" (CITL) optimization technique, which we will try next! 63 | 64 | ### 1. Camera-in-the-loop calibration 65 | To run the CITL optimization, we first need to align the target plane captured through the sensor with the rectified simulated coordinates. To this end, we calculate a homography between the two planes after displaying a dot pattern at the target plane (note that you can use your own image for calibration!). Note that you need to repeat this procedure for each plane.
You can generate the SLM patterns for this with the following command (pass the index of the plane with ```--eval_plane_idx=$i``` when calibrating multiple planes): 66 | 67 | ``` 68 | # Generate homography calibration patterns 69 | python main.py --data_path=./data/calib --out_path=./calibration --channel=$c --target=2d --loss_func=l2 --eval_plane_idx=$i --full_roi=True 70 | 71 | ``` 72 | Now you're ready to run the camera-in-the-loop optimization! Before that, please make sure that all of your hardware components are ready and that the parameters in ```hw_params()``` in ```params.py``` are correctly set. For example, you need to set up the Python APIs of your SLM/sensor so they can run "in the loop". For more information, check out the supplements of our papers: [[link1]](https://drive.google.com/file/d/1vay4xeg5iC7y8CLWR6nQEWe3mjBuqCWB/view) [[link2]](https://opg.optica.org/optica/viewmedia.cfm?uri=optica-8-2-143&seq=s001) [[link3]](https://drive.google.com/file/d/1FNSXBYivIN9hqDUxzUIMgSFVKJKtFEOr/view) 73 | ``` 74 | # Camera-in-the-loop optimization 75 | python main.py --citl=True --data_path=./data/test --out_path=${your_out_path} --channel=$c --target=2d --loss_func=l2 --eval_plane_idx=$i 76 | 77 | ``` 78 | With the phase pattern generated by this code, you will get experimental results like the ones below: 79 | 80 | 81 | 82 | ### 2. Forward model training 83 | Although camera-in-the-loop optimization is a significant improvement over the naive approach, it still has limitations: it requires a camera capture for every iteration and every target image, which may not be practical, and it is hard to extend to 3D, since the focus of the lens would have to change every iteration. To overcome these limitations, we proposed to train a parameterized wave propagation model instead of optimizing a single phase pattern. After training it once, you can use this model either to optimize phase patterns for 3D targets or as a loss function for training an inverse network. 84 | 85 | #### 2-1. Dataset generation 86 | To train the model, we first create a set of thousands of phase patterns that vary in randomness, target image, number of iterations, method, etc. 87 | For the dataset images, we used the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) for our training - you can download it and put the images in your ```${data_path}```. 88 | 89 | ``` 90 | python main.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=$c --target=2d --loss_func=l2 --prop_model_path=${model_path} --random_gen=True 91 | 92 | ``` 93 | 94 | #### 2-2. Dataset capture 95 | Then, display all of these phase patterns and capture the resulting images for every plane: 96 | ``` 97 | for i in 0 1 ... 7 98 | do 99 | python dataset_capture.py --channel=$c --plane_idx=$i 100 | done 101 | ``` 102 | We also release a subset of our dataset, which you can download from here: [[Big Dataset (~220GB)]](https://drive.google.com/file/d/1E9ppFPwueGwRTG9yRk9wbB3Xy7eOOGOI/view?usp=sharing), [[Small Dataset (~60GB)]](https://drive.google.com/file/d/1EC2pzHlsB_P_braGc1r71oKt9vlmwzlq/view?usp=sharing). 103 | #### 2-3. Forward model training 104 | With thousands of `(phase pattern, captured intensities)` pairs, you can now train any parameterized model to predict the captured intensities at multiple planes.
We implemented several architectures for the models described in our paper: the original [Neural Holography](https://www.computationalimaging.org/publications/neuralholography/) model, the Hardware-in-the-loop model (Chakravarthula et al., 2020), and three variants of our [NH3D](https://www.computationalimaging.org/publications/neuralholography3d/) model (CNNprop, propCNN, CNNpropCNN). 105 | You can train the Neural Holography models (SIGGRAPH Asia 2020, 2021) with the same codebase! 106 | ``` 107 | # try nh, hil, cnnprop, propcnn, cnnpropcnn for ${model}. 108 | python train.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=${channel} --lr=4e-4 --num_train_planes=${num_train_planes} 109 | ``` 110 | 111 | Repeat the procedures from 2-1 to 2-3 using your trained model. 112 | 113 | #### 2-4. Run SGD with the trained forward model 114 | After training, simply run the phase generation script you ran at the [beginning](#step-by-step-instructions), pointing to your trained model with `${model_path}`. 115 | ``` 116 | # RGB 117 | python main.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=$c --target=2d --loss_func=l2 --prop_model_path=${model_path} --eval_plane_idx=$i 118 | 119 | # RGBD supervision 120 | python main.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=$c --target=rgbd --loss_func=l2 --prop_model_path=${model_path} 121 | ``` 122 | 123 | Using the command above, we can now achieve a holographic image like the one below in the same setup. Please try RGBD supervision and the other methods we discuss in our paper! 124 | 125 | 126 | ### Citation 127 | If you find our work useful in your research, please cite: 128 | ``` 129 | @article{choi2021neural, 130 | title={Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays}, 131 | author={Choi, Suyeon and Gopakumar, Manu and Peng, Yifan and Kim, Jonghyun and Wetzstein, Gordon}, 132 | journal={ACM Transactions on Graphics (TOG)}, 133 | volume={40}, 134 | number={6}, 135 | pages={1--12}, 136 | year={2021}, 137 | publisher={ACM New York, NY, USA} 138 | } 139 | ``` 140 | 141 | ### Contact 142 | If you have any questions, please email Suyeon Choi at suyeon@stanford.edu. 143 | -------------------------------------------------------------------------------- /algorithms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various algorithms for RGBD/RGB supervision. 3 | """ 4 | 5 | import imageio 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from tqdm import tqdm 10 | import utils 11 | 12 | 13 | def load_alg(alg_type): 14 | if alg_type.lower() in ('sgd', 'admm', 'gd', 'gradient-descent'): 15 | algorithm = gradient_descent 16 | 17 | return algorithm 18 | 19 | 20 | def gradient_descent(init_phase, target_amp, target_mask=None, forward_prop=None, num_iters=1000, roi_res=None, 21 | loss_fn=nn.MSELoss(), lr=0.01, out_path_idx='./results', 22 | citl=False, camera_prop=None, writer=None, admm_opt=None, 23 | *args, **kwargs): 24 | """ 25 | Gradient-descent-based method for phase optimization.
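    The SLM phase is the optimization variable: it is updated with Adam by backpropagating an amplitude-matching loss through the differentiable propagation operator `forward_prop`, and with `citl=True` the loss is evaluated on amplitudes captured through `camera_prop` (surrogate gradients). A minimal usage sketch (here `prop` is assumed to be any callable that maps an SLM phase tensor to a complex field, e.g. a propagation model built elsewhere in this codebase):

        results = gradient_descent(init_phase, target_amp, forward_prop=prop,
                                    roi_res=(700, 1190), num_iters=500, lr=0.01)
        final_phase = results['final_phase']  # optimized SLM phase pattern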
26 | 27 | :param init_phase: 28 | :param target_amp: 29 | :param target_mask: 30 | :param forward_prop: 31 | :param num_iters: 32 | :param roi_res: 33 | :param loss_fn: 34 | :param lr: 35 | :param out_path_idx: 36 | :param citl: 37 | :param camera_prop: 38 | :param writer: 39 | :param args: 40 | :param kwargs: 41 | :return: 42 | """ 43 | 44 | assert forward_prop is not None 45 | dev = init_phase.device 46 | num_iters_admm_inner = 1 if admm_opt is None else admm_opt['num_iters_inner'] 47 | 48 | slm_phase = init_phase.requires_grad_(True) # phase at the slm plane 49 | optvars = [{'params': slm_phase}] 50 | optimizer = optim.Adam(optvars, lr=lr) 51 | 52 | loss_vals = [] 53 | loss_vals_quantized = [] 54 | best_loss = 10. 55 | best_iter = 0 56 | 57 | if target_mask is not None: 58 | target_amp = target_amp * target_mask 59 | nonzeros = target_mask > 0 60 | if roi_res is not None: 61 | target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False) 62 | if target_mask is not None: 63 | target_mask = utils.crop_image(target_mask, roi_res, stacked_complex=False) 64 | nonzeros = target_mask > 0 65 | 66 | if admm_opt is not None: 67 | u = torch.zeros(1, 1, *roi_res).to(dev) 68 | z = torch.zeros(1, 1, *roi_res).to(dev) 69 | 70 | for t in range(num_iters): 71 | for t_inner in range(num_iters_admm_inner): 72 | optimizer.zero_grad() 73 | 74 | recon_field = forward_prop(slm_phase) 75 | recon_field = utils.crop_image(recon_field, roi_res, stacked_complex=False) 76 | recon_amp = recon_field.abs() 77 | 78 | if citl: # surrogate gradients for CITL 79 | captured_amp = camera_prop(slm_phase) 80 | captured_amp = utils.crop_image(captured_amp, roi_res, 81 | stacked_complex=False) 82 | recon_amp = recon_amp + captured_amp - recon_amp.detach() 83 | 84 | if target_mask is not None: 85 | final_amp = torch.zeros_like(recon_amp) 86 | final_amp[nonzeros] += (recon_amp[nonzeros] * target_mask[nonzeros]) 87 | else: 88 | final_amp = recon_amp 89 | 90 | with torch.no_grad(): 91 | s = (final_amp * target_amp).mean() / \ 92 | (final_amp ** 2).mean() # scale minimizing MSE btw recon and 93 | 94 | loss_val = loss_fn(s * final_amp, target_amp) 95 | 96 | # second loss term if ADMM 97 | if admm_opt is not None: 98 | # augmented lagrangian 99 | recon_phase = recon_field.angle() 100 | loss_prior = loss_fn(utils.laplacian(recon_phase) * target_mask, (z - u) * target_mask) 101 | loss_val = loss_val + admm_opt['rho'] * loss_prior 102 | 103 | loss_val.backward() 104 | optimizer.step() 105 | 106 | ## ADMM steps 107 | if admm_opt is not None: 108 | with torch.no_grad(): 109 | reg_norm = utils.laplacian(recon_phase).detach() * target_mask 110 | Ax = admm_opt['alpha'] * reg_norm + (1 - admm_opt['alpha']) * z # over-relaxation 111 | z = utils.soft_thresholding(u + Ax, admm_opt['gamma'] / (rho + 1e-10)) 112 | u = u + Ax - z 113 | 114 | # varying penalty (rho) 115 | if admm_opt['varying-penalty']: 116 | if t == 0: 117 | z_prev = z 118 | 119 | r_k = ((reg_norm - z).detach() ** 2).mean() # primal residual 120 | s_k = ((rho * utils.laplacian(z_prev - z).detach()) ** 2).mean() # dual residual 121 | 122 | if r_k > admm_opt['mu'] * s_k: 123 | rho = admm_opt['tau_incr'] * rho 124 | u /= admm_opt['tau_incr'] 125 | elif s_k > admm_opt['mu'] * r_k: 126 | rho /= admm_opt['tau_decr'] 127 | u *= admm_opt['tau_decr'] 128 | z_prev = z 129 | 130 | with torch.no_grad(): 131 | if loss_val < best_loss: 132 | best_phase = slm_phase 133 | best_loss = loss_val 134 | best_amp = s * recon_amp 135 | best_iter = t + 1 136 | print(f' -- optimization is done, 
best loss: {best_loss}') 137 | 138 | return {'loss_vals': loss_vals, 139 | 'loss_vals_q': loss_vals_quantized, 140 | 'best_iter': best_iter, 141 | 'best_loss': best_loss, 142 | 'recon_amp': best_amp, 143 | 'final_phase': best_phase} -------------------------------------------------------------------------------- /data/calib/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/calib/1.png -------------------------------------------------------------------------------- /data/test/depth/frame_0043.dpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/test/depth/frame_0043.dpt -------------------------------------------------------------------------------- /data/test/rgb/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/test/rgb/0.png -------------------------------------------------------------------------------- /data/test/rgb/frame_0043.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/test/rgb/frame_0043.png -------------------------------------------------------------------------------- /dataset_capture.py: -------------------------------------------------------------------------------- 1 | """ 2 | capture all the phases in a folder 3 | 201203 4 | 5 | """ 6 | 7 | import os 8 | import time 9 | import concurrent 10 | import cv2 11 | import skimage.io 12 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 13 | import numpy as np 14 | import torch 15 | import imageio 16 | import configargparse 17 | 18 | import utils 19 | import props 20 | import params 21 | 22 | def save_image(raw_data, file_path, demosaickrule, warper, dev): 23 | raw_data = raw_data - 64 24 | color_cv_image = cv2.cvtColor(raw_data, demosaickrule) # it gives float64 from uint16 25 | captured_intensity = utils.im2float(color_cv_image) # float64 to float32 26 | 27 | # Numpy to tensor 28 | captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32, 29 | device=dev).permute(2, 0, 1).unsqueeze(0) 30 | captured_intensity = torch.sum(captured_intensity, dim=1, keepdim=True) 31 | 32 | warped_intensity = warper(captured_intensity) 33 | imageio.imwrite(file_path, (np.clip(warped_intensity.squeeze().cpu().detach().numpy(), 0.0, 1.0) 34 | * np.iinfo(np.uint16).max).round().astype(np.uint16)) 35 | 36 | if __name__ == '__main__': 37 | # parse arguments 38 | # Command line argument processing / Parameters 39 | p = configargparse.ArgumentParser() 40 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 41 | params.add_parameters(p, 'eval') 42 | opt = p.parse_args() 43 | params.set_configs(opt) 44 | dev = torch.device('cuda') 45 | 46 | # hardware setup 47 | params_slm, params_camera, params_calib = params.hw_params(opt) 48 | camera_prop = props.PhysicalProp(params_slm, params_camera, params_calib).to(dev) 49 | 50 | data_path = opt.data_path # pass a path of your dataset folder like f'F:/dataset/green' 51 | 52 | if not opt.chan_str in data_path: 53 | raise 
ValueError('Double check the color!') 54 | 55 | ds = ['test', 'val', 'train'] 56 | data_paths = [os.path.join(data_path, d) for d in ds] 57 | for root_path in data_paths: 58 | 59 | # load phases 60 | phase_path = f'{root_path}/phase' 61 | captured_path = f'{root_path}/captured' 62 | utils.cond_mkdir(captured_path) 63 | names = os.listdir(f'{phase_path}') 64 | 65 | # run multiple thread for fast capture 66 | with concurrent.futures.ThreadPoolExecutor() as executor: 67 | for ii, full_name in enumerate(names): 68 | t0 = time.perf_counter() 69 | 70 | filename = f'{phase_path}/{full_name}' 71 | phase_img = skimage.io.imread(filename) 72 | raw_uint16_data = camera_prop.capture_uint16(phase_img) # display & retrieve buffer 73 | 74 | out_full_path = os.path.join(captured_path, f'{full_name[:-4]}_{opt.plane_idx}.png') 75 | executor.submit(save_image, 76 | raw_uint16_data, 77 | out_full_path, 78 | camera_prop.camera.demosaick_rule, 79 | camera_prop.warper, dev) 80 | 81 | print(f'{out_full_path}: ' 82 | f'{1 / (time.perf_counter() - t0):.2f}Hz') 83 | 84 | if ii % 500 == 0 and ii > 0: 85 | time.sleep(10.0) 86 | 87 | camera_prop.disconnect() -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | ��name: neural3d 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.6 6 | - numpy 7 | - torchvision 8 | - torchaudio 9 | - cudatoolkit=11.3 10 | - opencv 11 | prefix: C:\Users\suyeon\.conda\envs\flex 12 | 13 | -------------------------------------------------------------------------------- /hw/calibration_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing the calibration module, basically calculating homography matrix. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | S. Choi, M. Gopakumar, Y. Peng, J. Kim, G. Wetzstein. Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays. ACM TOG (SIGGRAPH Asia), 2021. 11 | """ 12 | 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import torchvision 16 | import cv2 17 | import skimage.transform as transform 18 | import time 19 | import datetime 20 | from scipy.io import savemat 21 | from scipy.ndimage import map_coordinates 22 | import torch 23 | import torch.nn.functional as F 24 | import torch.nn as nn 25 | 26 | def id(x): 27 | return x 28 | 29 | def circle_detect(captured_img, num_circles, spacing, pad_pixels=(0., 0.), show_preview=True, quadratic=False): 30 | """ 31 | Detects the circle of a circle board pattern 32 | 33 | :param captured_img: captured image 34 | :param num_circles: a tuple of integers, (num_circle_x, num_circle_y) 35 | :param spacing: a tuple of integers, in pixels, (space between circles in x, space btw circs in y direction) 36 | :param show_preview: boolean, default True 37 | :param pad_pixels: coordinate of the left top corner of warped image. 38 | Assuming pad this amount of pixels on the other side. 
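    :param quadratic: boolean, default False. If True, a polynomial warp is fit with skimage.transform instead of a projective homography.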
39 | :return: a tuple, (found_dots, H) 40 | found_dots: boolean, indicating success of calibration 41 | H: a 3x3 homography matrix (numpy) 42 | """ 43 | 44 | # Binarization 45 | # org_copy = org.copy() # Otherwise, we write on the original image! 46 | img = (captured_img.copy() * 255).astype(np.uint8) 47 | if len(img.shape) > 2: 48 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 49 | # print(img[...,0].mean()) 50 | # print(img[...,1].mean()) 51 | # print(img[...,2].mean()) 52 | 53 | #img = cv2.medianBlur(img, 31) 54 | img = cv2.medianBlur(img, 55) # Red 71 55 | # img = cv2.medianBlur(img, 5) #210104 56 | img_gray = img.copy() 57 | 58 | # img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 121, 0) 59 | img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 127, 0) 60 | 61 | # img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 117, 0) # Red 127 62 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) 63 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (31, 31)) 64 | img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel) 65 | img = 255 - img 66 | 67 | # Blob detection 68 | params = cv2.SimpleBlobDetector_Params() 69 | 70 | # Change thresholds 71 | params.filterByColor = True 72 | params.minThreshold = 150 73 | params.minThreshold = 121 74 | # params.minThreshold = 121 75 | 76 | # Filter by Area. 77 | params.filterByArea = True 78 | params.minArea = 150 79 | # params.minArea = 80 # Red 120 80 | # params.minArea = 30 # 210104 81 | # params.maxArea = 100 # 210104 82 | 83 | # Filter by Circularity 84 | params.filterByCircularity = True 85 | params.minCircularity = 0.85 86 | params.minCircularity = 0.60 87 | 88 | # Filter by Convexity 89 | params.filterByConvexity = True 90 | params.minConvexity = 0.87 91 | # params.minConvexity = 0.80 92 | 93 | # Filter by Inertia 94 | params.filterByInertia = False 95 | params.minInertiaRatio = 0.01 96 | 97 | detector = cv2.SimpleBlobDetector_create(params) 98 | 99 | # Detecting keypoints 100 | # this is redundant for what comes next, but gives us access to the detected dots for debug 101 | keypoints = detector.detect(img) 102 | found_dots, centers = cv2.findCirclesGrid(img, (num_circles[1], num_circles[0]), 103 | blobDetector=detector, flags=cv2.CALIB_CB_SYMMETRIC_GRID) 104 | 105 | # Drawing the keypoints 106 | cv2.drawChessboardCorners(captured_img, num_circles, centers, found_dots) 107 | img_gray = cv2.drawKeypoints(img_gray, keypoints, np.array([]), (0, 255, 0), 108 | cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS) 109 | 110 | # Find transformation 111 | H = np.array([[1., 0., 0.], 112 | [0., 1., 0.], 113 | [0., 0., 1.]], dtype=np.float32) 114 | if found_dots: 115 | # Generate reference points to compute the homography 116 | ref_pts = np.zeros((num_circles[0] * num_circles[1], 1, 2), np.float32) 117 | pos = 0 118 | for j in range(0, num_circles[0]): 119 | for i in range(0, num_circles[1]): 120 | ref_pts[pos, 0, :] = spacing * np.array([i, j]) + np.array([pad_pixels[1], pad_pixels[0]]) 121 | 122 | pos += 1 123 | 124 | 125 | H, mask = cv2.findHomography(centers, ref_pts, cv2.RANSAC, 1) 126 | 127 | ref_pts = ref_pts.reshape(num_circles[0] * num_circles[1], 2) 128 | centers = np.flip(centers.reshape(num_circles[0] * num_circles[1], 2), 1) 129 | 130 | 131 | now = datetime.datetime.now() 132 | mdic = {"centers": centers, 'H': H} 133 | # savemat(f'F:/2021/centers_coords/centers_{now.strftime("%m%d_%H%M%S")}.mat', mdic) 134 | dsize = 
[int((num_circs - 1) * space + 2 * pad_pixs) 135 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels) ] 136 | if quadratic: 137 | H = transform.estimate_transform('polynomial', ref_pts, centers) 138 | coords = transform.warp_coords(H, dsize, dtype=np.float32) # for pytorch 139 | else: 140 | tf = transform.estimate_transform('projective', ref_pts, centers) 141 | coords = transform.warp_coords(tf, (800, 1280), dtype=np.float32) # for pytorch 142 | 143 | if show_preview: 144 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs) 145 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels)] 146 | if quadratic: 147 | captured_img_warp = transform.warp(captured_img, H, output_shape=(dsize[0], dsize[1])) 148 | else: 149 | captured_img_warp = cv2.warpPerspective(captured_img, H, (dsize[1], dsize[0])) 150 | 151 | 152 | if show_preview: 153 | fig = plt.figure() 154 | 155 | ax = fig.add_subplot(223) 156 | ax.imshow(img_gray, cmap='gray') 157 | 158 | ax2 = fig.add_subplot(221) 159 | ax2.imshow(img, cmap='gray') 160 | 161 | ax3 = fig.add_subplot(222) 162 | ax3.imshow(captured_img, cmap='gray') 163 | 164 | if found_dots: 165 | ax4 = fig.add_subplot(224) 166 | ax4.imshow(captured_img_warp, cmap='gray') 167 | 168 | plt.show() 169 | 170 | return found_dots, H, coords 171 | 172 | 173 | class Warper(nn.Module): 174 | def __init__(self, params_calib): 175 | super(Warper, self).__init__() 176 | self.num_circles = params_calib.num_circles 177 | self.spacing_size = params_calib.spacing_size 178 | self.pad_pixels = params_calib.pad_pixels 179 | self.quadratic = params_calib.quadratic 180 | self.img_size_native = params_calib.img_size_native # get this from image 181 | self.h_transform = np.array([[1., 0., 0.], 182 | [0., 1., 0.], 183 | [0., 0., 1.]]) 184 | self.range_x = params_calib.range_x # slice 185 | self.range_y = params_calib.range_y # slice 186 | 187 | 188 | def calibrate(self, img, show_preview=True): 189 | img_masked = np.zeros_like(img) 190 | img_masked[self.range_y, self.range_x, ...] = img[self.range_y, self.range_x, ...] 191 | 192 | found_corners, self.h_transform, self.coords = circle_detect(img_masked, self.num_circles, 193 | self.spacing_size, self.pad_pixels, show_preview, 194 | quadratic=self.quadratic) 195 | if not self.coords is None: 196 | self.coords_tensor = torch.tensor(np.transpose(self.coords, (1, 2, 0)), 197 | dtype=torch.float32).unsqueeze(0) 198 | 199 | # normalize it into [-1, 1] 200 | # self.coords_tensor[..., 0] = self.coords_tensor[..., 0] / (self.img_size_native[1]//2) - 1 201 | # self.coords_tensor[..., 1] = self.coords_tensor[..., 1] / (self.img_size_native[0]//2) - 1 202 | self.coords_tensor[..., 0] = 2*self.coords_tensor[..., 0] / (self.img_size_native[1]-1) - 1 203 | self.coords_tensor[..., 1] = 2*self.coords_tensor[..., 1] / (self.img_size_native[0]-1) - 1 204 | 205 | return found_corners 206 | 207 | def __call__(self, input_img, img_size=None): 208 | """ 209 | This forward pass returns the warped image. 210 | 211 | :param input_img: A numpy grayscale image shape of [H, W]. 212 | :param img_size: output size, default None. 213 | :return: output_img: warped image with pre-calculated homography and destination size. 
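    For torch tensor inputs, the warp is applied with F.grid_sample using the normalized sampling grid precomputed in calibrate(); numpy inputs fall back to skimage (polynomial) or OpenCV (projective) warping with the stored transform.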
214 | """ 215 | 216 | if img_size is None: 217 | img_size = [int((num_circs - 1) * space + 2 * pad_pixs) 218 | for num_circs, space, pad_pixs in zip(self.num_circles, self.spacing_size, self.pad_pixels)] 219 | 220 | if torch.is_tensor(input_img): 221 | # output_img = F.grid_sample(input_img, self.coords_tensor, align_corners=False) 222 | output_img = F.grid_sample(input_img, self.coords_tensor, align_corners=True) 223 | else: 224 | if self.quadratic: 225 | output_img = transform.warp(input_img, self.h_transform, output_shape=(img_size[0], img_size[1])) 226 | else: 227 | output_img = cv2.warpPerspective(input_img, self.h_transform, (img_size[0], img_size[1])) 228 | 229 | return output_img 230 | 231 | @property 232 | def h_transform(self): 233 | return self._h_transform 234 | 235 | @h_transform.setter 236 | def h_transform(self, new_h): 237 | self._h_transform = new_h 238 | 239 | def to(self, *args, **kwargs): 240 | slf = super().to(*args, **kwargs) 241 | if slf.coords_tensor is not None: 242 | slf.coords_tensor = slf.coords_tensor.to(*args, **kwargs) 243 | try: 244 | slf.dev = next(slf.parameters()).device 245 | except StopIteration: # no parameters 246 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0] 247 | if device_arg is not None: 248 | slf.dev = device_arg 249 | return slf -------------------------------------------------------------------------------- /hw/camera_capture_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing the calibration module, basically calculating homography matrix. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | S. Choi, M. Gopakumar, Y. Peng, J. Kim, G. Wetzstein. Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays. ACM TOG (SIGGRAPH Asia), 2021. 
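Concretely, this module wraps the PyCapture2 API (FLIR/Point Grey cameras): it connects to a camera, optionally enables software triggering, retrieves RAW16 frames from the camera buffer, and demosaicks them into RGB floating-point images.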
11 | """ 12 | 13 | import PyCapture2 14 | import cv2 15 | import numpy as np 16 | import time 17 | import utils 18 | 19 | 20 | def callback_captured(image): 21 | print(image.getData()) 22 | 23 | 24 | class CameraCapture: 25 | def __init__(self, params): 26 | self.bus = PyCapture2.BusManager() 27 | num_cams = self.bus.getNumOfCameras() 28 | if not num_cams: 29 | exit() 30 | # self.demosaick_rule = cv2.COLOR_BAYER_RG2BGR 31 | self.demosaick_rule = cv2.COLOR_BAYER_GR2RGB # GBRG to RGB 32 | self.params = params 33 | 34 | def connect(self, i, trigger=False): 35 | uid = self.bus.getCameraFromIndex(i) 36 | self.camera_device = PyCapture2.Camera() 37 | self.camera_device.connect(uid) 38 | self.camera_device.setConfiguration(highPerformanceRetrieveBuffer=True) 39 | self.camera_device.setConfiguration(numBuffers=1000) 40 | config = self.camera_device.getConfiguration() 41 | self.toggle_embedded_timestamp(True) 42 | 43 | if trigger: 44 | trigger_mode = self.camera_device.getTriggerMode() 45 | trigger_mode.onOff = True 46 | trigger_mode.mode = 0 47 | trigger_mode.parameter = 0 48 | trigger_mode.source = 3 # Using software trigger 49 | self.camera_device.setTriggerMode(trigger_mode) 50 | else: 51 | trigger_mode = self.camera_device.getTriggerMode() 52 | trigger_mode.onOff = False 53 | trigger_mode.mode = 0 54 | trigger_mode.parameter = 0 55 | trigger_mode.source = 3 # Using software trigger 56 | self.camera_device.setTriggerMode(trigger_mode) 57 | 58 | trigger_mode = self.camera_device.getTriggerMode() 59 | if trigger_mode.onOff is True: 60 | print(' - setting trigger mode on') 61 | 62 | 63 | def disconnect(self): 64 | self.toggle_embedded_timestamp(False) 65 | self.camera_device.disconnect() 66 | 67 | def toggle_embedded_timestamp(self, enable_timestamp): 68 | embedded_info = self.camera_device.getEmbeddedImageInfo() 69 | if embedded_info.available.timestamp: 70 | self.camera_device.setEmbeddedImageInfo(timestamp=enable_timestamp) 71 | 72 | def grab_images(self, num_images_to_grab=1): 73 | """ 74 | Retrieve the camera buffer and returns a list of grabbed images. 75 | 76 | :param num_images_to_grab: integer, default 1 77 | :return: a list of numpy 2d color images from the camera buffer. 78 | """ 79 | self.camera_device.startCapture() 80 | img_list = [] 81 | for i in range(num_images_to_grab): 82 | imgData = self.retrieve_buffer() 83 | offset = 64 # offset that inherently exist.retrieve_buffer 84 | imgData = imgData - offset 85 | 86 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule) 87 | color_cv_image = utils.im2float(color_cv_image) 88 | img_list.append(color_cv_image.copy()) 89 | 90 | self.camera_device.stopCapture() 91 | return img_list 92 | 93 | def grab_images_fast(self, num_images_to_grab=1): 94 | """ 95 | Retrieve the camera buffer and returns a grabbed image 96 | 97 | :param num_images_to_grab: integer, default 1 98 | :return: a list of numpy 2d color images from the camera buffer. 99 | """ 100 | imgData = self.retrieve_buffer() 101 | offset = 64 # offset that inherently exist. 
102 | imgData = imgData - offset 103 | 104 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule) 105 | color_cv_image = utils.im2float(color_cv_image) 106 | color_img = color_cv_image 107 | return color_img 108 | 109 | def retrieve_buffer(self): 110 | try: 111 | img = self.camera_device.retrieveBuffer() 112 | except PyCapture2.Fc2error as fc2Err: 113 | raise fc2Err 114 | 115 | imgData = img.getData() 116 | 117 | # when using raw8 from the PG sensor 118 | # cv_image = np.array(img.getData(), dtype="uint8").reshape((img.getRows(), img.getCols())) 119 | 120 | # when using raw16 from the PG sensor - concat 2 8bits in a row 121 | imgData.dtype = np.uint16 122 | imgData = imgData.reshape(img.getRows(), img.getCols()) 123 | return imgData.copy() 124 | 125 | def start_capture(self): 126 | # these two were previously inside the grab_images func, and can be clarified outside the loop 127 | self.camera_device.startCapture() 128 | 129 | def stop_capture(self): 130 | self.camera_device.stopCapture() 131 | 132 | @property 133 | def params(self): 134 | return self._params 135 | 136 | @params.setter 137 | def params(self, p): 138 | self._params = p -------------------------------------------------------------------------------- /hw/detect_heds_module_path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #--------------------------------------------------------------------# 4 | # # 5 | # Copyright (C) 2020 HOLOEYE Photonics AG. All rights reserved. # 6 | # Contact: https://holoeye.com/contact/ # 7 | # # 8 | # This file is part of HOLOEYE SLM Display SDK. # 9 | # # 10 | # You may use this file under the terms and conditions of the # 11 | # 'HOLOEYE SLM Display SDK Standard License v1.0' license agreement. # 12 | # # 13 | #--------------------------------------------------------------------# 14 | 15 | 16 | # Please import this file in your scripts before actually importing the HOLOEYE SLM Display SDK, 17 | # i. e. copy this file to your project and use this code in your scripts: 18 | # 19 | # import detect_heds_module_path 20 | # import holoeye 21 | # 22 | # 23 | # Another option is to copy the holoeye module directory into your project and import by only using 24 | # import holoeye 25 | # This way, code completion etc. might work better. 26 | 27 | 28 | import os, sys 29 | from platform import system 30 | 31 | # Import the SLM Display SDK: 32 | HEDSModulePath = os.getenv('HEDS_2_PYTHON_MODULES', '') 33 | 34 | if HEDSModulePath == '': 35 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), 36 | 'holoeye', 'slmdisplaysdk', '__init__.py')) 37 | if os.path.isfile(sdklocal): 38 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal))) 39 | else: 40 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 41 | 'sdk', 'holoeye', 'slmdisplaysdk', '__init__.py')) 42 | if os.path.isfile(sdklocal): 43 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal))) 44 | 45 | if HEDSModulePath == '': 46 | if system() == 'Windows': 47 | print('\033[91m' 48 | '\nError: Could not find HOLOEYE SLM Display SDK installation path from environment variable. ' 49 | '\n\nPlease relogin your Windows user account and try again. ' 50 | '\nIf that does not help, please reinstall the SDK and then relogin your user account and try again. ' 51 | '\nA simple restart of the computer might fix the problem, too.' 
52 | '\033[0m') 53 | else: 54 | print('\033[91m' 55 | '\nError: Could not detect HOLOEYE SLM Display SDK installation path. ' 56 | '\n\nPlease make sure it is present within the same folder or in "../../sdk".' 57 | '\033[0m') 58 | 59 | sys.exit(1) 60 | 61 | sys.path.append(HEDSModulePath) 62 | -------------------------------------------------------------------------------- /hw/lens_control.ino: -------------------------------------------------------------------------------- 1 | // Autofocus control code of Canon EF-S18-55mm f/3.5-5.6 IS II 2 | // helpful instructions: 3 | // https://docplayer.net/20752779-How-to-move-canon-ef-lenses-yosuke-bando.html 4 | // https://github.com/crescentvenus/EF-Lens-CONTROL/blob/master/EF-Lens-control.ino 5 | // https://gist.github.com/marcan/858c242db2fc595da1e0bb70a05192fc 6 | // Contact: Suyeon Choi (suyeon@stanford.edu) 7 | 8 | 9 | #include 10 | #include 11 | 12 | const int HotShoe_Pin = 8; 13 | const int HotShoe_Gnd = 9; 14 | const int LogicVDD_Pin = 10; 15 | const int Cam2Lens_Pin = 11; 16 | const int Clock_Pin = 13; 17 | 18 | // manually set for 11 planes (the last one is dummy) 19 | // *** Calibrate these focus values for your own lens *** 20 | // Inf to 3D (0D, 0.3D, ... 3D) gotta correspond to some range in your physical setup 21 | const int focus[12] = {2047, 1975, 1896, 1830, 1795, 1750, 1655, 1580, 1560, 1510, 1460, 0}; 22 | 23 | const int current_focus = 0; 24 | #define INPUT_SIZE 30 25 | 26 | void init_lens() { 27 | // Before sending signal to lens you should do this 28 | 29 | SPI.transfer(0x0A); 30 | delay(30); 31 | SPI.transfer(0x00); 32 | delay(30); 33 | SPI.transfer(0x0A); 34 | delay(30); 35 | SPI.transfer(0x00); 36 | delay(30); 37 | } 38 | 39 | void setup() // initialization 40 | { 41 | Serial.begin(9600); 42 | 43 | pinMode(LogicVDD_Pin, OUTPUT); 44 | digitalWrite(LogicVDD_Pin, HIGH); 45 | pinMode(Cam2Lens_Pin, OUTPUT); 46 | pinMode(Clock_Pin, OUTPUT); 47 | digitalWrite(Clock_Pin, HIGH); 48 | SPI.beginTransaction(SPISettings(9600, MSBFIRST, SPI_MODE3)); 49 | move_focus_infinity(); 50 | } 51 | 52 | void loop() { 53 | char input[INPUT_SIZE + 1]; 54 | byte size = Serial.readBytes(input, INPUT_SIZE); 55 | // Add the final 0 to end the C string 56 | input[size] = 0; 57 | 58 | // Read each command 59 | char* command = strtok(input, ","); 60 | while (command != 0) 61 | { 62 | // input command is assumed to be an integer. 
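    // e.g. sending "3" over serial moves the lens focus to the calibrated plane with index 3 (valid indices: 0-10).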
63 | int idx_plane = atoi(command); 64 | move_focus(idx_plane); 65 | 66 | // Find the next command in input string 67 | command = strtok(0, ","); 68 | 69 | } 70 | } 71 | 72 | void move_focus(int idx_plane) { 73 | // Move focus state of lens with index of plane (values for each plane are predefined) 74 | 75 | if (idx_plane > 10) { 76 | Serial.print(" - wrong idx"); 77 | return; 78 | } 79 | //else if (idx_plane == 0){ // commenting this out is and using relative is indeed more stable 80 | // Serial.print(" - move to infinity idx"); 81 | // move_focus_infinity(); 82 | //} 83 | else { 84 | // Below print cmds are for python 85 | 86 | Serial.print(" - from arduino: moving to the "); 87 | Serial.print(idx_plane); 88 | Serial.print("th plane"); 89 | int offset = focus[idx_plane] - read_int_EEPROM(current_focus); 90 | Serial.print(offset); 91 | if (offset != 0){ 92 | /////////////////////////////////// 93 | // This is what you send to lens // 94 | /////////////////////////////////// 95 | byte HH = highByte(offset); 96 | byte LL = lowByte(offset); 97 | init_lens(); 98 | SPI.transfer(0x44); delay(10); 99 | SPI.transfer(HH); delay(10); 100 | SPI.transfer(LL); delay(10); 101 | write_int_EEPROM(current_focus, focus[idx_plane]); 102 | } 103 | } 104 | delay(100); 105 | 106 | } 107 | 108 | 109 | void move_focus_value(int value) { 110 | // Move focus state of lens with exact value 111 | 112 | int offset = value - read_int_EEPROM(current_focus); 113 | byte HH = highByte(offset); 114 | byte LL = lowByte(offset); 115 | init_lens(); 116 | SPI.transfer(0x44); delay(10); 117 | SPI.transfer(HH); delay(10); 118 | SPI.transfer(LL); delay(10); 119 | 120 | write_int_EEPROM(current_focus, value); 121 | } 122 | 123 | void move_focus_infinity(){ 124 | init_lens(); 125 | SPI.transfer(0x05); delay(10); 126 | 127 | write_int_EEPROM(current_focus, focus[0]); 128 | } 129 | 130 | void write_int_EEPROM(int address, int number) 131 | { 132 | EEPROM.write(address, number >> 8); 133 | EEPROM.write(address + 1, number & 0xFF); 134 | } 135 | 136 | int read_int_EEPROM(int address) 137 | { 138 | return (EEPROM.read(address) << 8) + EEPROM.read(address + 1); 139 | } 140 | -------------------------------------------------------------------------------- /hw/lens_control.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test code for Canon EF lens 3 | 4 | Contact: Suyeon Choi (suyeon@stanford.edu) 5 | ----- 6 | 7 | $ python lens_control.py --index=$i 8 | """ 9 | 10 | import serial 11 | import time 12 | import serial, time 13 | import random 14 | import configargparse 15 | 16 | # Command line argument processing 17 | p = configargparse.ArgumentParser() 18 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 19 | p.add_argument('--index', type=int, default=None, help='index of plane') 20 | 21 | opt = p.parse_args() 22 | 23 | 24 | t0 = time.perf_counter() 25 | arduino = serial.Serial('COM7', 9600, timeout=1.) 26 | time.sleep(0.1) # give the connection a second to settle 27 | print(f' -- connection is established.. 
{time.perf_counter() - t0}') 28 | 29 | 30 | if opt.index is not None: 31 | time.sleep(5) 32 | print(f' -- writing...{opt.index}') 33 | arduino.write(f'{opt.index}'.encode()) 34 | 35 | time.sleep(5) 36 | 37 | data = arduino.readline().decode('UTF-8') # hear back from your arduino 38 | if data: 39 | print(data) 40 | else: 41 | while True: 42 | my_input = input() 43 | arduino.write(f'{str(my_input)}'.encode()) 44 | time.sleep(1.0) 45 | 46 | data = arduino.readline().decode('UTF-8') 47 | if data: 48 | print(data) 49 | -------------------------------------------------------------------------------- /hw/slm_display_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing the calibration module, basically calculating homography matrix. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | S. Choi, M. Gopakumar, Y. Peng, J. Kim, G. Wetzstein. Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays. ACM TOG (SIGGRAPH Asia), 2021. 11 | """ 12 | 13 | import hw.detect_heds_module_path 14 | import holoeye 15 | from holoeye import slmdisplaysdk 16 | 17 | 18 | class SLMDisplay: 19 | ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode 20 | ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags 21 | State = slmdisplaysdk.SLMDisplay.State 22 | ApplyDataHandleValue = slmdisplaysdk.SLMDisplay.ApplyDataHandleValue 23 | 24 | def __init__(self): 25 | self.ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode 26 | self.ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags 27 | 28 | self.displayOptions = self.ShowFlags.PresentAutomatic # PresentAutomatic == 0 (default) 29 | self.displayOptions |= self.ShowFlags.PresentFitWithBars 30 | 31 | def connect(self): 32 | self.slm_device = slmdisplaysdk.SLMDisplay() 33 | self.slm_device.open() # For version 2.0.1 34 | 35 | def disconnect(self): 36 | self.slm_device.release() 37 | 38 | def show_data_from_file(self, filepath): 39 | error = self.slm_device.showDataFromFile(filepath, self.displayOptions) 40 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error) 41 | 42 | def show_data_from_array(self, numpy_array): 43 | error = self.slm_device.showData(numpy_array) 44 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error) 45 | -------------------------------------------------------------------------------- /image_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import skimage.io 4 | from imageio import imread 5 | from skimage.transform import resize 6 | from torchvision.transforms.functional import resize as resize_tensor 7 | import cv2 8 | import random 9 | import json 10 | import numpy as np 11 | import h5py 12 | import torch 13 | import utils 14 | 15 | # Check for endianness, based on Daniel Scharstein's optical flow code. 16 | # Using little-endian architecture, these two should be equal. 17 | TAG_FLOAT = 202021.25 18 | TAG_CHAR = 'PIEH' 19 | 20 | def depth_read(filename): 21 | """ Read depth data from file, return as numpy array. 
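    The file is expected to be in the Sintel .dpt format: a float32 sanity-check tag (202021.25), int32 width and height, followed by width*height float32 depth values in row-major order.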
""" 22 | f = open(filename, 'rb') 23 | check = np.fromfile(f, dtype=np.float32, count=1)[0] 24 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format( 25 | TAG_FLOAT, check) 26 | width = np.fromfile(f, dtype=np.int32, count=1)[0] 27 | height = np.fromfile(f, dtype=np.int32, count=1)[0] 28 | size = width * height 29 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format( 30 | width, height) 31 | depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width)) 32 | return depth 33 | 34 | 35 | def get_matlab_filenames(dir, focuses=None): 36 | """Returns all files in the input directory dir that are images""" 37 | image_types = ('mat') 38 | if isinstance(dir, str): 39 | files = os.listdir(dir) 40 | exts = (os.path.splitext(f)[1] for f in files) 41 | if focuses is not None: 42 | images = [os.path.join(dir, f) 43 | for e, f in zip(exts, files) 44 | if e[1:] in image_types and int(os.path.splitext(f)[0].split('_')[-1]) in focuses] 45 | else: 46 | images = [os.path.join(dir, f) 47 | for e, f in zip(exts, files) 48 | if e[1:] in image_types] 49 | return images 50 | elif isinstance(dir, list): 51 | # Suppport multiple directories (randomly shuffle all) 52 | images = [] 53 | for folder in dir: 54 | files = os.listdir(folder) 55 | exts = (os.path.splitext(f)[1] for f in files) 56 | images_in_folder = [os.path.join(folder, f) 57 | for e, f in zip(exts, files) 58 | if e[1:] in image_types] 59 | images = [*images, *images_in_folder] 60 | 61 | return images 62 | 63 | 64 | def get_image_filenames(dir, focuses=None): 65 | """Returns all files in the input directory dir that are images""" 66 | image_types = ('jpg', 'jpeg', 'tiff', 'tif', 'png', 'bmp', 'gif', 'exr', 'dpt', 'hdf5') 67 | if isinstance(dir, str): 68 | files = os.listdir(dir) 69 | exts = (os.path.splitext(f)[1] for f in files) 70 | if focuses is not None: 71 | images = [os.path.join(dir, f) 72 | for e, f in zip(exts, files) 73 | if e[1:] in image_types and int(os.path.splitext(f)[0].split('_')[-1]) in focuses] 74 | else: 75 | images = [os.path.join(dir, f) 76 | for e, f in zip(exts, files) 77 | if e[1:] in image_types] 78 | return images 79 | elif isinstance(dir, list): 80 | # Suppport multiple directories (randomly shuffle all) 81 | images = [] 82 | for folder in dir: 83 | files = os.listdir(folder) 84 | exts = (os.path.splitext(f)[1] for f in files) 85 | images_in_folder = [os.path.join(folder, f) 86 | for e, f in zip(exts, files) 87 | if e[1:] in image_types] 88 | images = [*images, *images_in_folder] 89 | 90 | return images 91 | 92 | 93 | def resize_keep_aspect(image, target_res, pad=False, lf=False, pytorch=False): 94 | """Resizes image to the target_res while keeping aspect ratio by cropping 95 | 96 | image: an 3d array with dims [channel, height, width] 97 | target_res: [height, width] 98 | pad: if True, will pad zeros instead of cropping to preserve aspect ratio 99 | """ 100 | im_res = image.shape[-2:] 101 | 102 | # finds the resolution needed for either dimension to have the target aspect 103 | # ratio, when the other is kept constant. 
If the image doesn't have the 104 | # target ratio, then one of these two will be larger, and the other smaller, 105 | # than the current image dimensions 106 | resized_res = (int(np.ceil(im_res[1] * target_res[0] / target_res[1])), 107 | int(np.ceil(im_res[0] * target_res[1] / target_res[0]))) 108 | 109 | # only pads smaller or crops larger dims, meaning that the resulting image 110 | # size will be the target aspect ratio after a single pad/crop to the 111 | # resized_res dimensions 112 | if pad: 113 | image = utils.pad_image(image, resized_res, pytorch=False) 114 | else: 115 | image = utils.crop_image(image, resized_res, pytorch=False, lf=lf) 116 | 117 | # switch to numpy channel dim convention, resize, switch back 118 | if lf or pytorch: 119 | image = resize_tensor(image, target_res) 120 | return image 121 | else: 122 | image = np.transpose(image, axes=(1, 2, 0)) 123 | image = resize(image, target_res, mode='reflect') 124 | return np.transpose(image, axes=(2, 0, 1)) 125 | 126 | 127 | def pad_crop_to_res(image, target_res, pytorch=False): 128 | """Pads with 0 and crops as needed to force image to be target_res 129 | 130 | image: an array with dims [..., channel, height, width] 131 | target_res: [height, width] 132 | """ 133 | return utils.crop_image(utils.pad_image(image, 134 | target_res, pytorch=pytorch, stacked_complex=False), 135 | target_res, pytorch=pytorch, stacked_complex=False) 136 | 137 | 138 | def get_folder_names(folder): 139 | """Returns all files in the input directory dir that are images""" 140 | return [d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d))] 141 | 142 | 143 | class PairsLoader(torch.utils.data.IterableDataset): 144 | """Loads (phase, captured) tuples for forward model training 145 | 146 | Class initialization parameters 147 | ------------------------------- 148 | 149 | :param data_path: 150 | :param plane_idxs: 151 | :param batch_size: 152 | :param image_res: 153 | :param shuffle: 154 | :param avg_energy_ratio: 155 | :param slm_type: 156 | 157 | 158 | """ 159 | 160 | def __init__(self, data_path, plane_idxs=None, batch_size=1, 161 | image_res=(800, 1280), shuffle=True, 162 | avg_energy_ratio=None, slm_type='leto'): 163 | """ 164 | 165 | """ 166 | print(data_path) 167 | if isinstance(data_path, str): 168 | if not os.path.isdir(data_path): 169 | raise NotADirectoryError(f'Data folder: {data_path}') 170 | self.phase_path = os.path.join(data_path, 'phase') 171 | self.captured_path = os.path.join(data_path, 'captured') 172 | elif isinstance(data_path, list): 173 | self.phase_path = [os.path.join(path, 'phase') for path in data_path] 174 | self.captured_path = [os.path.join(path, 'captured') for path in data_path] 175 | 176 | self.all_plane_idxs = plane_idxs 177 | self.avg_energy_ratio = avg_energy_ratio 178 | self.batch_size = batch_size 179 | self.shuffle = shuffle 180 | self.image_res = image_res 181 | self.slm_type = slm_type.lower() 182 | self.im_names = get_image_filenames(self.phase_path) 183 | self.im_names.sort() 184 | 185 | # create list of image IDs with augmentation state 186 | self.order = ((i,) for i in range(len(self.im_names))) 187 | self.order = list(self.order) 188 | 189 | def __iter__(self): 190 | self.ind = 0 191 | if self.shuffle: 192 | random.shuffle(self.order) 193 | return self 194 | 195 | def __len__(self): 196 | return len(self.im_names) 197 | 198 | def __next__(self): 199 | if self.ind < len(self.order): 200 | phase_idx = self.order[self.ind] 201 | 202 | self.ind += 1 203 | return self.load_pair(phase_idx[0]) 204 | 
else: 205 | raise StopIteration 206 | 207 | def load_pair(self, filenum): 208 | phase_path = self.im_names[filenum] 209 | captured_path = os.path.splitext(os.path.dirname(phase_path))[0] 210 | captured_path = os.path.splitext(os.path.dirname(captured_path))[0] 211 | captured_path = os.path.join(captured_path, 'captured') 212 | 213 | # load phase 214 | phase_im_enc = imread(phase_path) 215 | im = (1 - phase_im_enc / np.iinfo(np.uint8).max) * 2 * np.pi - np.pi 216 | phase_im = torch.tensor(im, dtype=torch.float32).unsqueeze(0) 217 | 218 | _, captured_filename = os.path.split(os.path.splitext(self.im_names[filenum])[0]) 219 | idx = captured_filename.split('/')[-1] 220 | 221 | # load focal stack 222 | captured_amps = [] 223 | for plane_idx in self.all_plane_idxs: 224 | captured_filename = os.path.join(captured_path, f'{idx}_{plane_idx}.png') 225 | captured_intensity = utils.im2float(skimage.io.imread(captured_filename)) 226 | captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32) 227 | if self.avg_energy_ratio is not None: 228 | captured_intensity /= self.avg_energy_ratio[plane_idx] # energy compensation; 229 | captured_amp = torch.sqrt(captured_intensity) 230 | captured_amps.append(captured_amp) 231 | captured_amps = torch.stack(captured_amps, 0) 232 | 233 | return phase_im, captured_amps 234 | 235 | 236 | class TargetLoader(torch.utils.data.IterableDataset): 237 | """Loads target amp/mask tuples for phase optimization 238 | 239 | Class initialization parameters 240 | ------------------------------- 241 | :param data_path: 242 | :param target_type: 243 | :param channel: 244 | :param image_res: 245 | :param roi_res: 246 | :param crop_to_roi: 247 | :param shuffle: 248 | :param vertical_flips: 249 | :param horizontal_flips: 250 | :param virtual_depth_planes: 251 | :param scale_vd_range: 252 | 253 | """ 254 | 255 | def __init__(self, data_path, target_type, channel=None, 256 | image_res=(800, 1280), roi_res=(700, 1190), 257 | crop_to_roi=False, shuffle=False, 258 | vertical_flips=False, horizontal_flips=False, 259 | physical_depth_planes=None, 260 | virtual_depth_planes=None, scale_vd_range=True, 261 | mod_i=None, mod=None, options=None): 262 | """ initialization """ 263 | if isinstance(data_path, str) and not os.path.isdir(data_path): 264 | raise NotADirectoryError(f'Data folder: {data_path}') 265 | 266 | self.data_path = data_path 267 | self.target_type = target_type.lower() 268 | self.channel = channel 269 | self.roi_res = roi_res 270 | self.crop_to_roi = crop_to_roi 271 | self.image_res = image_res 272 | self.shuffle = shuffle 273 | self.physical_depth_planes = physical_depth_planes 274 | self.virtual_depth_planes = virtual_depth_planes 275 | self.vd_min = 0.01 276 | self.vd_max = max(self.virtual_depth_planes) 277 | self.scale_vd_range = scale_vd_range 278 | self.options = options 279 | 280 | self.augmentations = [] 281 | if vertical_flips: 282 | self.augmentations.append(self.augment_vert) 283 | if horizontal_flips: 284 | self.augmentations.append(self.augment_horz) 285 | 286 | # store the possible states for enumerating augmentations 287 | self.augmentation_states = [fn() for fn in self.augmentations] 288 | 289 | if target_type in ('2d', 'rgb'): 290 | self.im_names = get_image_filenames(self.data_path) 291 | self.im_names.sort() 292 | elif target_type in ('2.5d', 'rgbd'): 293 | self.im_names = get_image_filenames(os.path.join(self.data_path, 'rgb')) 294 | self.depth_names = get_image_filenames(os.path.join(self.data_path, 'depth')) 295 | 296 | self.im_names.sort() 297 | 
self.depth_names.sort() 298 | 299 | # create list of image IDs with augmentation state 300 | self.order = ((i,) for i in range(len(self.im_names))) 301 | for aug_type in self.augmentations: 302 | states = aug_type() # empty call gets possible states 303 | # augment existing list with new entry to states tuple 304 | self.order = ((*prev_states, s) 305 | for prev_states in self.order 306 | for s in states) 307 | self.order = list(self.order) 308 | 309 | if mod_i is not None: 310 | new_order = [] 311 | for m, o in enumerate(self.order): 312 | if m % mod == mod_i: 313 | new_order.append(o) 314 | self.order = new_order 315 | 316 | def __iter__(self): 317 | self.ind = 0 318 | if self.shuffle: 319 | random.shuffle(self.order) 320 | return self 321 | 322 | def __len__(self): 323 | return len(self.order) 324 | 325 | def __next__(self): 326 | if self.ind < len(self.order): 327 | img_idx = self.order[self.ind] 328 | 329 | self.ind += 1 330 | if self.target_type in ('2d', 'rgb'): 331 | return self.load_image(*img_idx) 332 | if self.target_type in ('2.5d', 'rgbd'): 333 | return self.load_image_mask(*img_idx) 334 | else: 335 | raise StopIteration 336 | 337 | def load_image(self, filenum, *augmentation_states): 338 | im = imread(self.im_names[filenum]) 339 | 340 | if len(im.shape) < 3: 341 | im = np.repeat(im[:, :, np.newaxis], 3, axis=2) # augment channels for gray images 342 | 343 | if self.channel is None: 344 | im = im[..., :3] # remove alpha channel, if any 345 | else: 346 | # select channel while keeping dims 347 | im = im[..., self.channel, np.newaxis] 348 | 349 | im = utils.im2float(im, dtype=np.float64) # convert to double, max 1 350 | 351 | # linearize intensity and convert to amplitude 352 | im = utils.srgb_gamma2lin(im) 353 | im = np.sqrt(im) # to amplitude 354 | 355 | # move channel dim to torch convention 356 | im = np.transpose(im, axes=(2, 0, 1)) 357 | 358 | # apply data augmentation 359 | for fn, state in zip(self.augmentations, augmentation_states): 360 | im = fn(im, state) 361 | 362 | # normalize resolution 363 | if self.crop_to_roi: 364 | im = pad_crop_to_res(im, self.roi_res) 365 | else: 366 | im = resize_keep_aspect(im, self.roi_res) 367 | im = pad_crop_to_res(im, self.image_res) 368 | 369 | path = os.path.splitext(self.im_names[filenum])[0] 370 | 371 | return (torch.from_numpy(im).float(), 372 | None, 373 | os.path.split(path)[1].split('_')[-1]) 374 | 375 | def load_depth(self, filenum, *augmentation_states): 376 | depth_path = self.depth_names[filenum] 377 | if 'exr' in depth_path: 378 | depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 379 | elif 'dpt' in depth_path: 380 | dist = depth_read(depth_path) 381 | depth = np.nan_to_num(dist, 100) # NaN to inf 382 | elif 'hdf5' in depth_path: 383 | # Depth (in m) 384 | with h5py.File(depth_path, 'r') as f: 385 | dist = np.array(f['dataset'][:], dtype=np.float32) 386 | depth = np.nan_to_num(dist, 100) # NaN to inf 387 | else: 388 | depth = imread(depth_path) 389 | 390 | depth = utils.im2float(depth, dtype=np.float64) # convert to double, max 1 391 | 392 | if 'bbb' in depth_path: 393 | depth *= 6 # this gives us decent depth distribution with 120mm eyepiece setting. 394 | elif 'sintel' in depth_path or 'dpt' in depth_path: 395 | depth /= 2.5 # this gives us decent depth distribution with 120mm eyepiece setting. 
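        # Note on the diopter conversion a few lines below: virtual depth is expressed in diopters,
        # D = 1 / distance_in_meters, so a point at 0.5 m maps to 2.0 D, one at 2 m to 0.5 D, and
        # far-away content approaches 0 D; the 1e-20 offset only guards against division by zero.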
396 |         if len(depth.shape) > 2 and depth.shape[-1] > 1:
397 |             depth = depth[..., 1]
398 |
399 |         if 'eth' not in depth_path.lower():
400 |             depth = 1 / (depth + 1e-20)  # meter to diopter conversion
401 |
402 |         # apply data augmentation
403 |         for fn, state in zip(self.augmentations, augmentation_states):
404 |             depth = fn(depth, state)
405 |
406 |         depth = torch.from_numpy(depth.copy()).float().unsqueeze(0)
407 |         # normalize resolution
408 |         depth.unsqueeze_(0)
409 |         if self.crop_to_roi:
410 |             depth = pad_crop_to_res(depth, self.roi_res, pytorch=True)
411 |         else:
412 |             depth = resize_keep_aspect(depth, self.roi_res, pytorch=True)
413 |             depth = pad_crop_to_res(depth, self.image_res, pytorch=True)
414 |
415 |         # scale to the virtual depth (diopter) range
416 |         if self.scale_vd_range:
417 |             depth = depth - depth.min()
418 |             depth = (depth / depth.max()) * (self.vd_max - self.vd_min)
419 |             depth = depth + self.vd_min
420 |
421 |         # check nans
422 |         if (depth.isnan().any()):
423 |             print("Found Nans in target depth!")
424 |             min_substitute = self.vd_min * torch.ones_like(depth)
425 |             depth = torch.where(depth.isnan(), min_substitute, depth)
426 |
427 |         path = os.path.splitext(self.depth_names[filenum])[0]
428 |
429 |         return (depth.float(),
430 |                 None,
431 |                 os.path.split(path)[1].split('_')[-1])
432 |
433 |     def load_image_mask(self, filenum, *augmentation_states):
434 |         img_none_idx = self.load_image(filenum, *augmentation_states)
435 |         depth_none_idx = self.load_depth(filenum, *augmentation_states)
436 |         mask = utils.decompose_depthmap(depth_none_idx[0], self.virtual_depth_planes)
437 |         return (img_none_idx[0].unsqueeze(0), mask, img_none_idx[-1])
438 |
439 |     def augment_vert(self, image=None, flip=False):
440 |         """ augment data with vertical flip """
441 |         if image is None:
442 |             return (True, False)  # return possible augmentation values
443 |
444 |         if flip:
445 |             return image[..., ::-1, :]
446 |         return image
447 |
448 |     def augment_horz(self, image=None, flip=False):
449 |         """ augment data with horizontal flip """
450 |         if image is None:
451 |             return (True, False)  # return possible augmentation values
452 |
453 |         if flip:
454 |             return image[..., ::-1]
455 |         return image
456 |
457 |
--------------------------------------------------------------------------------
/img/citl-asm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/citl-asm.png
--------------------------------------------------------------------------------
/img/sgd-asm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/sgd-asm.png
--------------------------------------------------------------------------------
/img/sgd-ours.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/sgd-ours.png
--------------------------------------------------------------------------------
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/teaser.png
--------------------------------------------------------------------------------
/main.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays 3 | 4 | Suyeon Choi*, Manu Gopakumar*, Yifan Peng, Jonghyun Kim, Gordon Wetzstein 5 | 6 | This is the main executive script used for the phase generation using SGD. 7 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 8 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 9 | # The material is provided as-is, with no warranties whatsoever. 10 | # If you publish any code, data, or scientific work based on this, please cite our work. 11 | ----- 12 | 13 | $ python main.py --lr=0.01 --num_iters=10000 14 | 15 | """ 16 | import algorithms as algs 17 | import image_loader as loaders 18 | import numpy as np 19 | 20 | import os 21 | import torch 22 | from torch.utils.tensorboard import SummaryWriter 23 | import imageio 24 | import configargparse 25 | import prop_physical 26 | import prop_model 27 | import utils 28 | import params 29 | 30 | 31 | def main(): 32 | # Command line argument processing / Parameters 33 | torch.set_default_dtype(torch.float32) 34 | p = configargparse.ArgumentParser() 35 | p.add('-c', '--config_filepath', required=False, 36 | is_config_file=True, help='Path to config file.') 37 | params.add_parameters(p, 'eval') 38 | opt = p.parse_args() 39 | params.set_configs(opt) 40 | dev = torch.device('cuda') 41 | 42 | run_id = params.run_id(opt) 43 | # path to save out optimized phases 44 | out_path = os.path.join(opt.out_path, run_id) 45 | print(f' - out_path: {out_path}') 46 | 47 | # Tensorboard 48 | summaries_dir = os.path.join(out_path, 'summaries') 49 | utils.cond_mkdir(summaries_dir) 50 | writer = SummaryWriter(summaries_dir) 51 | 52 | # Propagations 53 | camera_prop = None 54 | if opt.citl: 55 | camera_prop = prop_physical.PhysicalProp(*(params.hw_params(opt))).to(dev) 56 | sim_prop = prop_model.model(opt) 57 | 58 | # Algorithm 59 | algorithm = algs.load_alg(opt.method) 60 | 61 | # Loader 62 | if ',' in opt.data_path: 63 | opt.data_path = opt.data_path.split(',') 64 | img_loader = loaders.TargetLoader(opt.data_path, opt.target, channel=opt.channel, 65 | image_res=opt.image_res, roi_res=opt.roi_res, 66 | crop_to_roi=False, shuffle=opt.random_gen, 67 | vertical_flips=opt.random_gen, horizontal_flips=opt.random_gen, 68 | physical_depth_planes=opt.physical_depth_planes, 69 | virtual_depth_planes=opt.virtual_depth_planes, 70 | scale_vd_range=False, 71 | mod_i=opt.mod_i, mod=opt.mod, options=opt) 72 | 73 | for i, target in enumerate(img_loader): 74 | target_amp, target_mask, target_idx = target 75 | target_amp = target_amp.to(dev).detach() 76 | if target_mask is not None: 77 | target_mask = target_mask.to(dev).detach() 78 | if len(target_amp.shape) < 4: 79 | target_amp = target_amp.unsqueeze(0) 80 | 81 | print(f' - run phase optimization for {target_idx}th image ...') 82 | if opt.random_gen: # random parameters for dataset generation 83 | opt.num_iters, opt.init_phase_range, \ 84 | target_range, opt.lr, opt.eval_plane_idx = utils.random_gen(num_planes=opt.num_planes, 85 | slm_type=opt.slm_type) 86 | sim_prop = prop_model.model(opt) 87 | target_amp *= target_range 88 | 89 | # initial slm phase 90 | init_phase = (opt.init_phase_range * (-0.5 + 1.0 * torch.rand(1, 1, *opt.slm_res))).to(dev) 91 | 92 | # run algorithm 93 | results = 
algorithm(init_phase, target_amp, target_mask, 94 | forward_prop=sim_prop, num_iters=opt.num_iters, roi_res=opt.roi_res, 95 | loss_fn=opt.loss_fn, lr=opt.lr, 96 | out_path_idx=f'{opt.out_path}_{target_idx}', 97 | citl=opt.citl, camera_prop=camera_prop, 98 | writer=writer, 99 | ) 100 | 101 | # optimized slm phase 102 | final_phase = results['final_phase'] 103 | 104 | # encoding for SLM & save it out 105 | phase_out = utils.phasemap_8bit(final_phase) 106 | if opt.random_gen: 107 | phase_out_path = os.path.join(out_path, f'{target_idx}_{opt.num_iters}.png') 108 | else: 109 | phase_out_path = os.path.join(out_path, f'{target_idx}_{opt.eval_plane_idx}.png') 110 | imageio.imwrite(phase_out_path, phase_out) 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Default parameter settings for SLMs as well as laser/sensors 3 | 4 | """ 5 | import datetime 6 | import math 7 | import sys 8 | import numpy as np 9 | import utils 10 | import torch.nn as nn 11 | if sys.platform == 'win32': 12 | import serial 13 | 14 | cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9 15 | 16 | def str2bool(v): 17 | """ Simple query parser for configArgParse (which doesn't support native bool from cmd) 18 | Ref: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 19 | 20 | """ 21 | if isinstance(v, bool): 22 | return v 23 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 24 | return True 25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 26 | return False 27 | else: 28 | raise ValueError('Boolean value expected.') 29 | 30 | 31 | class PMap(dict): 32 | # use it for parameters 33 | __getattr__ = dict.get 34 | __setattr__ = dict.__setitem__ 35 | __delattr__ = dict.__delitem__ 36 | 37 | 38 | def add_parameters(p, mode='train'): 39 | p.add_argument('--channel', type=int, default=None, help='Red:0, green:1, blue:2') 40 | p.add_argument('--method', type=str, default='SGD', help='Type of algorithm, GS/SGD/DPAC/HOLONET/UNET') 41 | p.add_argument('--slm_type', type=str, default='leto', help='leto/pluto ...') 42 | p.add_argument('--sensor_type', type=str, default='2k', help='sensor type') 43 | p.add_argument('--laser_type', type=str, default='new', help='old, new_laser, sLED, ...') 44 | p.add_argument('--setup_type', type=str, default='sigasia2021_vr', help='VR or AR') 45 | p.add_argument('--prop_model', type=str, default='ASM', help='Type of propagation model, ASM/NH/NH3D') 46 | p.add_argument('--out_path', type=str, default='./results', 47 | help='Directory for output') 48 | p.add_argument('--citl', type=str2bool, default=False, 49 | help='If True, run camera-in-the-loop') 50 | p.add_argument('--mod_i', type=int, default=None, 51 | help='If not None, say K, pick every K target images from the target loader') 52 | p.add_argument('--mod', type=int, default=None, 53 | help='If not None, say K, pick every K target images from the target loader') 54 | p.add_argument('--data_path', type=str, default='/mount/workspace/data/NH3D/', 55 | help='Directory for input') 56 | p.add_argument('--exp', type=str, default='', help='Name of experiment') 57 | p.add_argument('--lr', type=float, default=0.02, help='Learning rate') 58 | p.add_argument('--num_iters', type=int, default=1000, help='Number of iterations (GS, SGD)') 59 | p.add_argument('--prop_dist', type=float, default=None, help='propagation distance from SLM to midplane') 60 | 
p.add_argument('--F_aperture', type=float, default=1.0, help='Fourier filter size') 61 | p.add_argument('--eyepiece', type=float, default=0.12, help='eyepiece focal length') 62 | p.add_argument('--full_roi', type=str2bool, default=False, 63 | help='If True, force ROI to SLM resolution') 64 | p.add_argument('--target', type=str, default='rgbd', 65 | help='Type of target:' '{2d, rgb} or {2.5d, rgbd}') 66 | p.add_argument('--show_preview', type=str2bool, default=False, 67 | help='If true, show the preview for homography calibration') 68 | p.add_argument('--random_gen', type=str2bool, default=False, 69 | help='If true, randomize a few parameters for phase dataset generation') 70 | p.add_argument('--eval_plane_idx', type=int, default=None, 71 | help='When evaluating 2d, choose the index of the plane') 72 | p.add_argument('--init_phase_range', type=float, default=1.0, 73 | help='Phase sampling range for initializaation') 74 | 75 | if mode in ('train', 'eval'): 76 | p.add_argument('--num_epochs', type=int, default=350, help='') 77 | p.add_argument('--batch_size', type=int, default=1, help='') 78 | p.add_argument('--prop_model_path', type=str, default=None, help='Path to checkpoints') 79 | p.add_argument('--predefined_model', type=str, default=None, help='string for predefined model' 80 | 'nh, nh3d') 81 | p.add_argument('--num_downs_slm', type=int, default=5, help='') 82 | p.add_argument('--num_feats_slm_min', type=int, default=32, help='') 83 | p.add_argument('--num_feats_slm_max', type=int, default=128, help='') 84 | p.add_argument('--num_downs_target', type=int, default=5, help='') 85 | p.add_argument('--num_feats_target_min', type=int, default=32, help='') 86 | p.add_argument('--num_feats_target_max', type=int, default=128, help='') 87 | p.add_argument('--slm_coord', type=str, default='rect', help='coordinates to represent a complex-valued field.' 88 | 'rect(real+imag) or polar(amp+phase)') 89 | p.add_argument('--target_coord', type=str, default='rect', help='coordinates to represent a complex-valued field.' 90 | 'rect(real+imag) or polar(amp+phase)') 91 | p.add_argument('--norm', type=str, default='instance', help='normalization layer') 92 | p.add_argument('--loss_func', type=str, default='l1', help='l1 or l2') 93 | p.add_argument('--energy_compensation', type=str2bool, default=True, help='adjust intensities ' 94 | 'with avg intensity of training set') 95 | p.add_argument('--num_train_planes', type=int, default=7, help='number of planes fed to models') 96 | 97 | return p 98 | 99 | 100 | def set_configs(opt): 101 | """ 102 | set or replace parameters with pre-defined parameters with string inputs 103 | """ 104 | 105 | # hardware setup 106 | optics_config(opt.setup_type, opt) # prop_dist, etc ... 107 | laser_config(opt.laser_type, opt) # Our Old FISBA Laser, New, SLED, LED, ... 
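    # Worked example of the geometry derived further down in this function, using the
    # 'sigasia2021_vr' preset defined below with --channel=1 and --prop_dist left as None:
    #   prop_dist           = prop_dists_rgb[1][mid_idx=4] = 4.4 mm   (SLM -> mid/WRP plane)
    #   prop_dists_from_wrp = [p - 4.4 mm for p in prop_dists_rgb[1]]
    #                       = [-4.4, -3.2, -2.4, -1.0, 0.0, 1.3, 2.8, 3.8] mm   (WRP -> target planes)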
108 | slm_config(opt.slm_type, opt) # Holoeye 109 | sensor_config(opt.sensor_type, opt) # our sensor 110 | 111 | # set predefined model parameters 112 | forward_model_config(opt.prop_model, opt) 113 | 114 | # wavelength, propagation distance (from SLM to midplane) 115 | if opt.channel is None: 116 | opt.chan_str = 'rgb' 117 | opt.prop_dist = opt.prop_dists_rgb 118 | opt.wavelength = opt.wavelengths 119 | else: 120 | opt.chan_str = ('red', 'green', 'blue')[opt.channel] 121 | if opt.prop_dist is None: 122 | opt.prop_dist = opt.prop_dists_rgb[opt.channel][opt.mid_idx] # prop dist from SLM plane to target plane 123 | opt.prop_dist_green = opt.prop_dists_rgb[opt.channel][1] 124 | else: 125 | opt.prop_dist_green = opt.prop_dist 126 | opt.wavelength = opt.wavelengths[opt.channel] # wavelength of each color 127 | 128 | # propagation distances from the wavefront recording plane 129 | opt.prop_dists_from_wrp = [p - opt.prop_dist for p in opt.prop_dists_rgb[opt.channel]] 130 | opt.physical_depth_planes = [p - opt.prop_dist_green for p in opt.prop_dists_physical] 131 | opt.virtual_depth_planes = utils.prop_dist_to_diopter(opt.physical_depth_planes, 132 | opt.eyepiece, 133 | opt.physical_depth_planes[0]) 134 | opt.num_planes = len(opt.prop_dists_from_wrp) 135 | opt.all_plane_idxs = range(opt.num_planes) 136 | 137 | # force ROI to that of SLM 138 | if opt.full_roi: 139 | opt.roi_res = opt.slm_res 140 | 141 | ################ 142 | # Model Training 143 | # compensate the brightness difference per plane (for model training) 144 | if opt.energy_compensation: 145 | opt.avg_energy_ratio = opt.avg_energy_ratio_rgb[opt.channel] 146 | else: 147 | opt.avg_energy_ratio = None 148 | 149 | # loss functions (for model training) 150 | opt.loss_train = None 151 | opt.loss_fn = None 152 | if opt.loss_func.lower() in ('l2', 'mse'): 153 | opt.loss_train = nn.functional.mse_loss 154 | opt.loss_fn = nn.functional.mse_loss 155 | elif opt.loss_func.lower() == 'l1': 156 | opt.loss_train = nn.functional.l1_loss 157 | opt.loss_fn = nn.functional.l1_loss 158 | 159 | # plane idxs (for model training) 160 | opt.plane_idxs = {} 161 | opt.plane_idxs['all'] = opt.all_plane_idxs 162 | opt.plane_idxs['train'] = opt.training_plane_idxs 163 | opt.plane_idxs['validation'] = opt.training_plane_idxs 164 | opt.plane_idxs['test'] = opt.training_plane_idxs 165 | opt.plane_idxs['heldout'] = opt.heldout_plane_idxs 166 | 167 | admm_opt = None 168 | if 'admm' in opt.method: 169 | admm_opt = {'num_iters_inner': 50, 170 | 'rho': 0.01, 171 | 'alpha': 1.0, 172 | 'gamma': 0.1, 173 | 'varying-penalty': True, 174 | 'mu': 10.0, 175 | 'tau_incr': 2.0, 176 | 'tau_decr': 2.0} 177 | 178 | return opt 179 | 180 | 181 | def run_id(opt): 182 | id_str = f'{opt.exp}_{opt.chan_str}_{opt.prop_model}_{opt.lr}_{opt.num_iters}' 183 | id_str = f'{opt.chan_str}' 184 | return id_str 185 | 186 | def run_id_training(opt): 187 | id_str = f'{opt.exp}_{opt.chan_str}-' \ 188 | f'slm{opt.num_downs_slm}-{opt.num_feats_slm_min}-{opt.num_feats_slm_max}_' \ 189 | f'tg{opt.num_downs_target}-{opt.num_feats_target_min}-{opt.num_feats_target_max}_' \ 190 | f'{opt.slm_coord}{opt.target_coord}_{opt.loss_func}_{opt.num_train_planes}pls_' \ 191 | f'bs{opt.batch_size}' 192 | cur_time = datetime.datetime.now().strftime("%d-%H%M") 193 | id_str = f'{cur_time}_{id_str}' 194 | 195 | return id_str 196 | 197 | 198 | def hw_params(opt): 199 | """ Default setting for hardware. Please replace and adjust parameters for your own setup. 
""" 200 | params_slm = PMap() 201 | params_slm.settle_time = 0.3 202 | params_slm.monitor_num = 2 203 | params_slm.slm_type = opt.slm_type 204 | 205 | params_camera = PMap() 206 | params_camera.img_size_native = (3000, 4096) # 4k sensor native 207 | params_camera.ser = None #serial.Serial('COM5', 9600, timeout=0.5) 208 | 209 | params_calib = PMap() 210 | params_calib.show_preview = opt.show_preview 211 | params_calib.range_y = slice(0, params_camera.img_size_native[0]) 212 | params_calib.range_x = slice(0, params_camera.img_size_native[1]) 213 | params_calib.num_circles = (13, 22) 214 | params_calib.spacing_size = [int(roi / (num_circs - 1)) 215 | for roi, num_circs in zip(opt.roi_res, params_calib.num_circles)] 216 | params_calib.pad_pixels = [int(slm - roi) // 2 for slm, roi in zip(opt.slm_res, opt.roi_res)] 217 | params_calib.quadratic = True 218 | params_calib.phase_path = f'./calibration/{opt.chan_str}/1_{opt.eval_plane_idx}.png' 219 | params_calib.img_size_native = params_camera.img_size_native 220 | 221 | return params_slm, params_camera, params_calib 222 | 223 | 224 | def slm_config(slm_type, opt): 225 | """ Setting for specific SLM. """ 226 | if slm_type.lower() in ('leto'): 227 | opt.feature_size = (6.4 * um, 6.4 * um) # SLM pitch 228 | opt.slm_res = (1080, 1920) # resolution of SLM 229 | opt.image_res = opt.slm_res 230 | elif slm_type.lower() in ('pluto'): 231 | opt.feature_size = (8.0 * um, 8.0 * um) # SLM pitch 232 | opt.slm_res = (1080, 1920) # resolution of SLM 233 | opt.image_res = opt.slm_res 234 | 235 | 236 | def laser_config(laser_type, opt): 237 | """ Setting for specific laser. """ 238 | if 'new' in laser_type.lower(): 239 | opt.wavelengths = (636.17 * nm, 518.48 * nm, 442.03 * nm) # wavelength of each color 240 | elif 'ar' in laser_type.lower(): 241 | opt.wavelengths = (532 * nm, 532 * nm, 532 * nm) 242 | else: 243 | opt.wavelengths = (636.4 * nm, 517.7 * nm, 440.8 * nm) 244 | 245 | 246 | def sensor_config(sensor_type, opt): 247 | return opt 248 | 249 | 250 | def optics_config(setup_type, opt): 251 | """ Setting for specific setup (prop dists, filter, training plane index ...) 
""" 252 | if setup_type in ('sigasia2021_vr'): 253 | opt.laser_type = 'old' 254 | opt.slm_type = 'leto' 255 | opt.prop_dists_rgb = [[0.0, 1*mm, 2*mm, 3*mm, 4.4*mm, 5.7*mm, 7.0*mm, 8.1*mm], 256 | [0.0, 1.2*mm, 2.0*mm, 3.4*mm, 4.4*mm, 5.7*mm, 7.2*mm, 8.2*mm], 257 | [0.0, 1*mm, 2.1*mm, 3.2*mm, 4.3*mm, 5.5*mm, 7.2*mm, 8.2*mm]] 258 | 259 | opt.avg_energy_ratio_rgb = [[1.0000, 1.0407, 1.0870, 1.1216, 1.1568, 1.2091, 1.2589, 1.2924], 260 | [1.0000, 1.0409, 1.0869, 1.1226, 1.1540, 1.2107, 1.2602, 1.2958], 261 | [1.0000, 1.0409, 1.0869, 1.1226, 1.1540, 1.2107, 1.2602, 1.2958]] 262 | opt.prop_dists_physical = opt.prop_dists_rgb[1] 263 | opt.F_aperture = 0.5 264 | opt.roi_res = (960, 1680) # regions of interest (to penalize for SGD) 265 | opt.training_plane_idxs = [0, 1, 3, 4, 5, 6, 7] 266 | opt.heldout_plane_idxs = [2] 267 | opt.mid_idx = 4 # intermediate plane as 1.5D 268 | elif setup_type in ('sigasia2021_ar'): 269 | opt.laser_type = 'ar_green_only' 270 | opt.slm_type = 'pluto' 271 | opt.prop_dists_rgb = [[9.9*mm, 10.3*mm, 11.8*mm, 13.3*mm], 272 | [9.9*mm, 10.3*mm, 11.8*mm, 13.3*mm], 273 | [9.9*mm, 10.3*mm, 11.8*mm, 13.3*mm]] 274 | opt.prop_dists_physical = opt.prop_dists_rgb[1] 275 | opt.F_aperture = 0.5 276 | opt.roi_res = (768, 1536) # regions of interest (to penalize for SGD) 277 | opt.training_plane_idxs = [0, 1, 2] 278 | opt.heldout_plane_idxs = [] 279 | 280 | 281 | def forward_model_config(model_type, opt): 282 | # setting for specific model that is predefined. 283 | if model_type is not None: 284 | print(f' - changing model parameters for {model_type}') 285 | if model_type.lower() in ('cnnpropcnn', 'nh3d'): 286 | opt.num_downs_slm = 8 287 | opt.num_feats_slm_min = 32 288 | opt.num_feats_slm_max = 512 289 | opt.num_downs_target = 5 290 | opt.num_feats_target_min = 8 291 | opt.num_feats_target_max = 128 292 | elif model_type.lower() == 'hil': 293 | opt.num_downs_slm = 0 294 | opt.num_feats_slm_min = 0 295 | opt.num_feats_slm_max = 0 296 | opt.num_downs_target = 8 297 | opt.num_feats_target_min = 32 298 | opt.num_feats_target_max = 512 299 | opt.target_coord = 'amp' 300 | elif model_type.lower() == 'cnnprop': 301 | opt.num_downs_slm = 8 302 | opt.num_feats_slm_min = 32 303 | opt.num_feats_slm_max = 512 304 | opt.num_downs_target = 0 305 | opt.num_feats_target_min = 0 306 | opt.num_feats_target_max = 0 307 | elif model_type.lower() == 'propcnn': 308 | opt.num_downs_slm = 0 309 | opt.num_feats_slm_min = 0 310 | opt.num_feats_slm_max = 0 311 | opt.num_downs_target = 8 312 | opt.num_feats_target_min = 32 313 | opt.num_feats_target_max = 512 -------------------------------------------------------------------------------- /prop_ideal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ideal propagation 3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import utils 9 | import torch.fft as tfft 10 | import math 11 | 12 | class Propagation(nn.Module): 13 | """ 14 | The ideal, convolution-based propagation implementation 15 | 16 | Class initialization parameters 17 | ------------------------------- 18 | :param prop_dist: propagation distance(s) 19 | :param wavelength: wavelength 20 | :param feature_size: pixel pitch 21 | :param prop_type: type of propagation (ASM or fresnel), by default the angular spectrum method 22 | :param F_aperture: filter size at fourier plane, by default 1.0 23 | :param dim: for propagation to multiple planes, dimension to stack the output, by default 1 (second dimension) 24 | :param linear_conv: If true, pad zeros to 
ensure the linear convolution, by default True 25 | :param learned_amp: Learned amplitude at Fourier plane, by default None 26 | :param learned_phase: Learned phase at Fourier plane, by default None 27 | """ 28 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0, 29 | dim=1, linear_conv=True, learned_amp=None, learned_phase=None): 30 | super(Propagation, self).__init__() 31 | 32 | self.H = None # kernel at Fourier plane 33 | self.prop_type = prop_type 34 | if not isinstance(prop_dist, list): 35 | prop_dist = [prop_dist] 36 | self.prop_dist = prop_dist 37 | self.feature_size = feature_size 38 | if not isinstance(wavelength, list): 39 | wavelength = [wavelength] 40 | self.wvl = wavelength 41 | self.linear_conv = linear_conv # ensure linear convolution by padding 42 | self.bl_asm = min(prop_dist) > 0.3 43 | self.F_aperture = F_aperture 44 | self.dim = dim # The dimension to stack the kernels as well as the resulting fields (if multi-channel) 45 | 46 | self.preload_params = False 47 | self.preloaded_H_amp = False # preload H_mask once trained 48 | self.preloaded_H_phase = False # preload H_phase once trained 49 | 50 | self.fourier_amp = learned_amp 51 | self.fourier_phase = learned_phase 52 | 53 | def forward(self, u_in): 54 | if u_in.dtype == torch.float32: 55 | u_in = torch.exp(1j * u_in) 56 | 57 | if self.H is None: 58 | Hs = [] 59 | if len(self.wvl) > 1: # If multi-channel, rearrange kernels 60 | for wv, prop_dist in zip(self.wvl, self.prop_dist): 61 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..') 62 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size, 63 | self.prop_type, self.linear_conv, 64 | F_aperture=self.F_aperture, bl_asm=self.bl_asm) 65 | Hs.append(h) 66 | self.H = torch.cat(Hs, dim=self.dim) 67 | else: 68 | for wv in self.wvl: 69 | for prop_dist in self.prop_dist: 70 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..') 71 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size, 72 | self.prop_type, self.linear_conv, 73 | F_aperture=self.F_aperture, bl_asm=self.bl_asm) 74 | Hs.append(h) 75 | self.H = torch.cat(Hs, dim=1) 76 | 77 | if self.preload_params: 78 | self.premultiply() 79 | 80 | if self.fourier_amp is not None and not self.preloaded_H_amp: 81 | H = self.fourier_amp.clamp(min=0.) 
* self.H 82 | else: 83 | H = self.H 84 | 85 | if self.fourier_phase is not None and not self.preloaded_H_phase: 86 | H = H * torch.exp(1j * self.fourier_phase) 87 | 88 | return self.prop(u_in, H, self.linear_conv) 89 | 90 | def compute_H(self, input_field, prop_dist, wvl, feature_size, prop_type, lin_conv=True, 91 | return_exp=False, F_aperture=1.0, bl_asm=False, return_filter=False): 92 | dev = input_field.device 93 | res_mul = 2 if lin_conv else 1 94 | num_y, num_x = res_mul*input_field.shape[-2], res_mul*input_field.shape[-1] # number of pixels 95 | dy, dx = feature_size # sampling inteval size 96 | 97 | # frequency coordinates sampling 98 | fy = torch.linspace(-1 / (2 * dy), 1 / (2 * dy), num_y) 99 | fx = torch.linspace(-1 / (2 * dx), 1 / (2 * dx), num_x) 100 | 101 | # momentum/reciprocal space 102 | # FY, FX = torch.meshgrid(fy, fx) 103 | FX, FY = torch.meshgrid(fx, fy) 104 | FX = torch.transpose(FX, 0, 1) 105 | FY = torch.transpose(FY, 0, 1) 106 | 107 | if prop_type.lower() == 'asm': 108 | G = 2 * math.pi * (1 / wvl**2 - (FX ** 2 + FY ** 2)).sqrt() 109 | elif prop_type.lower() == 'fresnel': 110 | G = math.pi * wvl * (FX ** 2 + FY ** 2) 111 | 112 | H_exp = G.reshape((1, 1, *G.shape)).to(dev) 113 | 114 | if return_exp: 115 | return H_exp 116 | 117 | if bl_asm: 118 | fy_max = 1 / math.sqrt((2 * prop_dist * (1 / (dy * float(num_y))))**2 + 1) / wvl 119 | fx_max = 1 / math.sqrt((2 * prop_dist * (1 / (dx * float(num_x))))**2 + 1) / wvl 120 | 121 | H_filter = ((torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max()) 122 | & (torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).type(torch.FloatTensor) 123 | else: 124 | H_filter = (torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max()).type(torch.FloatTensor) 125 | 126 | if prop_dist == 0.: 127 | H = torch.ones_like(H_exp) 128 | else: 129 | H = H_filter.to(input_field.device) * torch.exp(1j * H_exp * prop_dist) 130 | 131 | if return_filter: 132 | return H_filter 133 | else: 134 | return H 135 | 136 | def prop(self, u_in, H, linear_conv=True, padtype='zero'): 137 | if linear_conv: 138 | # preprocess with padding for linear conv. 139 | input_resolution = u_in.size()[-2:] 140 | conv_size = [i * 2 for i in input_resolution] 141 | if padtype == 'zero': 142 | padval = 0 143 | elif padtype == 'median': 144 | padval = torch.median(torch.pow((u_in ** 2).sum(-1), 0.5)) 145 | u_in = utils.pad_image(u_in, conv_size, padval=padval, stacked_complex=False) 146 | 147 | U1 = tfft.fftshift(tfft.fftn(u_in, dim=(-2, -1), norm='ortho'), (-2, -1)) 148 | U2 = U1 * H 149 | u_out = tfft.ifftn(tfft.ifftshift(U2, (-2, -1)), dim=(-2, -1), norm='ortho') 150 | 151 | if linear_conv: 152 | u_out = utils.crop_image(u_out, input_resolution, pytorch=True, stacked_complex=False) 153 | 154 | return u_out 155 | 156 | def __len__(self): 157 | return len(self.prop_dist) 158 | 159 | def preload_H(self): 160 | self.preload_params = True 161 | 162 | def premultiply(self): 163 | self.preload_params = False 164 | 165 | if self.fourier_amp is not None and not self.preloaded_H_amp: 166 | self.H = self.fourier_amp.clamp(min=0.) 
* self.H 167 | if self.fourier_phase is not None and not self.preloaded_H_phase: 168 | self.H = self.H * torch.exp(1j * self.fourier_phase) 169 | 170 | self.H.detach_() 171 | self.preloaded_H_amp = True 172 | self.preloaded_H_phase = True 173 | 174 | @property 175 | def plane_idx(self): 176 | return self._plane_idx 177 | 178 | @plane_idx.setter 179 | def plane_idx(self, idx): 180 | if idx is None: 181 | return 182 | 183 | self._plane_idx = idx 184 | if len(self.prop_dist) > 1: 185 | self.prop_dist = [self.prop_dist[idx]] 186 | 187 | if self.fourier_amp is not None and self.fourier_amp.shape[1] > 1: 188 | self.fourier_amp = nn.Parameter(self.fourier_amp[:, idx:idx+1, ...], requires_grad=False) 189 | if self.fourier_phase is not None and self.fourier_phase.shape[1] > 1: 190 | self.fourier_phase = nn.Parameter(self.fourier_phase[:, idx:idx+1, ...], requires_grad=False) 191 | 192 | 193 | 194 | class SerialProp(nn.Module): 195 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0, 196 | prop_dists_from_wrp=None, linear_conv=True, dim=1): 197 | super(SerialProp, self).__init__() 198 | 199 | first_prop = Propagation(prop_dist, wavelength, feature_size, 200 | prop_type=prop_type, linear_conv=linear_conv, F_aperture=F_aperture, dim=dim) 201 | props = [first_prop] 202 | if prop_dists_from_wrp is not None: 203 | second_prop = Propagation(prop_dists_from_wrp, wavelength, feature_size, 204 | prop_type=prop_type, linear_conv=linear_conv, F_aperture=1.0, dim=dim) 205 | props += [second_prop] 206 | self.props = nn.Sequential(*props) 207 | 208 | def forward(self, u_in): 209 | 210 | u_out = self.props(u_in) 211 | 212 | return u_out 213 | 214 | def preload_H(self): 215 | for prop in self.props: 216 | prop.preload_H() 217 | 218 | @property 219 | def plane_idx(self): 220 | return self._plane_idx 221 | 222 | @plane_idx.setter 223 | def plane_idx(self, idx): 224 | if idx is None: 225 | return 226 | 227 | self._plane_idx = idx 228 | for prop in self.props: 229 | prop.plane_idx = idx 230 | 231 | -------------------------------------------------------------------------------- /prop_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parameterized propagations 3 | 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import matplotlib.pyplot as plt 11 | import pytorch_lightning as pl 12 | import utils 13 | from unet import UnetGenerator, init_weights, norm_layer 14 | import prop_ideal 15 | from prop_submodules import Field2Input, Output2Field, Conv2dField,\ 16 | ContentDependentField, LatentCodedMLP, SourceAmplitude 17 | from prop_zernike import compute_zernike_basis, combine_zernike_basis 18 | 19 | 20 | def model(opt, dev=torch.device('cuda'), preload_H=False): 21 | """ 22 | 23 | :param opt: 24 | :param dev: 25 | :param preload_H: 26 | :return: 27 | """ 28 | if opt.prop_model.lower() == 'asm': 29 | sim_prop = prop_ideal.SerialProp(opt.prop_dist, opt.wavelength, opt.feature_size, 30 | 'ASM', opt.F_aperture, opt.prop_dists_from_wrp, 31 | dim=1) 32 | elif opt.prop_model.lower() == 'nh': 33 | sim_prop = PropNH(prop_dist=opt.prop_dist, # Parameterized wave propagation model 34 | feature_size=opt.feature_size, 35 | wavelength=opt.wavelength, 36 | slm_res=opt.slm_res, 37 | roi_res=opt.roi_res, 38 | F_aperture=opt.F_aperture, 39 | prop_dists_from_wrp=opt.prop_dists_from_wrp, 40 | loss_func=opt.loss_train, 41 | lr=opt.lr, 42 | plane_idxs=opt.plane_idxs 43 | ).to(dev) 44 | elif 
opt.prop_model.lower() in ('cnnprop', 'propcnn', 'hil', 'nh3d', 'cnnpropcnn'): 45 | sim_prop = CNNpropCNN(opt.prop_dist, opt.wavelength, opt.feature_size, 46 | prop_type='ASM', 47 | F_aperture=opt.F_aperture, 48 | prop_dists_from_wrp=opt.prop_dists_from_wrp, 49 | linear_conv=True, 50 | slm_res=opt.slm_res, 51 | roi_res=opt.roi_res, 52 | num_downs_slm=opt.num_downs_slm, 53 | num_feats_slm_min=opt.num_feats_slm_min, 54 | num_feats_slm_max=opt.num_feats_slm_max, 55 | num_downs_target=opt.num_downs_target, 56 | num_feats_target_min=opt.num_feats_target_min, 57 | num_feats_target_max=opt.num_feats_target_max, 58 | norm=norm_layer(opt.norm), 59 | slm_coord=opt.slm_coord, 60 | target_coord=opt.target_coord, 61 | loss_func=opt.loss_train, 62 | lr=opt.lr, 63 | plane_idxs=opt.plane_idxs 64 | ).to(dev) 65 | 66 | if opt.prop_model_path is not None: 67 | checkpoint = torch.load(opt.prop_model_path) 68 | sim_prop.load_state_dict(checkpoint["state_dict"]) 69 | sim_prop.eval() 70 | print(f' - Model loaded from {opt.prop_model_path}') 71 | 72 | if preload_H: 73 | sim_prop.preload_H() 74 | 75 | if opt.eval_plane_idx is not None: 76 | sim_prop.plane_idx = opt.eval_plane_idx 77 | 78 | return sim_prop 79 | 80 | 81 | class PropModel(pl.LightningModule): 82 | """ 83 | A parameterized model trained on captured images at multiple planes 84 | 85 | Class initialization parameters 86 | ------------------------------- 87 | :param roi_res: 88 | :param plane_idxs: 89 | :param loss_func: 90 | :param lr: 91 | """ 92 | def __init__(self, roi_res=(1080, 1920), plane_idxs=None, loss_func=F.l1_loss, lr=4e-4): 93 | super(PropModel, self).__init__() 94 | self.roi_res = roi_res 95 | self.plane_idxs = plane_idxs 96 | self.loss_func = loss_func 97 | self.lr = lr 98 | self.recon_amp = {} 99 | self.target_amp = {} 100 | 101 | def perform_step(self, batch, batch_idx, prefix): 102 | slm_phase, target_amps = batch 103 | 104 | recon_fields = self.forward(slm_phase) # output field at target planes. 105 | 106 | # calculate losses 107 | loss, loss_mse, loss_mse_all = self.mp_loss(recon_fields, target_amps, 108 | self.plane_idxs, self.loss_func, prefix) 109 | 110 | with torch.no_grad(): 111 | self.log(f'loss_{prefix}', loss, on_step=True, on_epoch=True) 112 | self.log(f'PSNR_{prefix}', 10*math.log10(1/loss_mse), on_step=True, on_epoch=True) 113 | self.log(f'PSNR_except_held_out_{prefix}', 114 | 10*math.log10(1/loss_mse_all), on_step=True, on_epoch=True) 115 | 116 | return loss 117 | 118 | def mp_loss(self, recon_fields, target_amps, plane_idxs, loss_func, prefix): 119 | """ Loss function on multiplane amplitudes""" 120 | 121 | # take the amplitudes. 122 | target_amps = utils.crop_image(target_amps, self.roi_res, stacked_complex=False) 123 | recon_amps = utils.crop_image(recon_fields, self.roi_res, stacked_complex=False).abs() 124 | 125 | with torch.no_grad(): 126 | self.recon_amp[prefix] = recon_amps 127 | self.target_amp[prefix] = target_amps 128 | 129 | # calculate loss values. 130 | loss = 0. # the loss you penalize. 131 | mse_loss = 0. # PSNR on planes you penalize. 132 | mse_loss_all = 0. # PSNR on all planes except the held-out plane. 
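        # Note on the PSNR bookkeeping below: before taking the MSE, the reconstruction is rescaled by
        # s = mean(recon * target) / mean(recon ** 2), i.e. the least-squares-optimal global scale
        # (argmin over s of E[(s * recon - target)^2]), so PSNR = 10 * log10(1 / MSE) reflects structural
        # error rather than an overall brightness mismatch (amplitudes are assumed to lie in [0, 1]).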
133 | 134 | if plane_idxs is None: 135 | # penalize all planes 136 | loss += loss_func(recon_amps, target_amps) 137 | else: 138 | # penalize only selected planes 139 | for i in range(target_amps.shape[1]): 140 | if i in plane_idxs[prefix]: 141 | # selected planes 142 | loss_i = loss_func(recon_amps[:, i, ...], target_amps[:, i, ...]) 143 | loss += loss_i / len(plane_idxs[prefix]) 144 | else: 145 | # these are not penalized, just for tensorboard 146 | with torch.no_grad(): 147 | loss_i = loss_func(recon_amps[:, i, ...], target_amps[:, i, ...]) 148 | 149 | # report PSNR 150 | with torch.no_grad(): 151 | # this is for evaluation so just use the min-mse scaling. 152 | s = (recon_amps[:, i, ...] * target_amps[:, i, ...]).mean() \ 153 | / (recon_amps[:, i, ...]**2).mean() 154 | mse_loss_i = F.mse_loss(s*recon_amps[:, i, ...], target_amps[:, i, ...]) 155 | 156 | if i in self.plane_idxs[prefix]: 157 | mse_loss += mse_loss_i / len(plane_idxs[prefix]) 158 | 159 | if not i in self.plane_idxs['heldout']: 160 | # exclude held-out plane 161 | mse_loss_all += mse_loss_i / len(plane_idxs['train']) 162 | 163 | self.log(f'loss_{prefix}/plane_{i}', 164 | loss_i, on_step=True, on_epoch=True) 165 | self.log(f'PSNR_{prefix}/plane_{i}', 166 | 10*math.log10(1./mse_loss_i), on_step=True, on_epoch=True) 167 | 168 | return loss, mse_loss, mse_loss_all 169 | 170 | def training_step(self, batch, batch_idx): 171 | return self.perform_step(batch, batch_idx, 'train') 172 | 173 | def validation_step(self, batch, batch_idx): 174 | return self.perform_step(batch, batch_idx, 'validation') 175 | 176 | def test_step(self, batch, batch_idx): 177 | return self.perform_step(batch, batch_idx, 'test') 178 | 179 | def test_epoch_end(self, outputs) -> None: 180 | self.epoch_end_images('test') 181 | 182 | def training_epoch_end(self, outputs) -> None: 183 | self.epoch_end_images('train') 184 | 185 | def validation_epoch_end(self, outputs) -> None: 186 | self.epoch_end_images('validation') 187 | 188 | def configure_optimizers(self): 189 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 190 | return optimizer 191 | 192 | 193 | class CNNpropCNN(PropModel): 194 | """ 195 | A parameterized model with CNNs 196 | 197 | Class initialization parameters 198 | ------------------------------- 199 | :param prop_dist: Propagation distance from SLM to Intermediate plane, float. 200 | :param wavelength: wavelength, float. 201 | :param feature_size: Pixel pitch of SLM, float. 202 | :param prop_type: Type of propagation, string, default 'ASM'. 203 | :param F_aperture: Level of filtering at Fourier plane, float, default 1.0 (and circular). 204 | :param prop_dists_from_wrp: An array of propagation distances from Intermediate plane to Target planes. 205 | :param linear_conv: If true, pad before taking Fourier transform to ensure the linear convolution. 206 | :param slm_res: Resolution of SLM. 207 | :param roi_res: Resolution of Region of Interest. 208 | :param num_downs_slm: Number of layers of U-net at SLM network. 209 | :param num_feats_slm_min: Number of features at the top layer of SLM network. 210 | :param num_feats_slm_max: Number of features at the very bottom layer of SLM network. 211 | :param num_downs_target: Number of layers of U-net at target network. 212 | :param num_feats_target_min: Number of features at the top layer of target network. 213 | :param num_feats_target_max: Number of features at the very bottom layer of target network. 214 | :param norm: normalization layers. 
215 | :param slm_coord: input/output format of SLM network. 216 | :param target_coord: input/output format of target network.). 217 | :param loss_func: Loss function to train the model. 218 | :param lr: a learning rate. 219 | :param plane_idxs: a dictionary that has plane idxs for 'train', 'val', 'test', 'heldout'. 220 | """ 221 | 222 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0, 223 | prop_dists_from_wrp=None, linear_conv=True, slm_res=(1080, 1920), roi_res=(960, 1680), 224 | num_downs_slm=0, num_feats_slm_min=0, num_feats_slm_max=0, 225 | num_downs_target=0, num_feats_target_min=0, num_feats_target_max=0, 226 | norm=nn.InstanceNorm2d, slm_coord='rect', target_coord='rect', 227 | loss_func=F.l1_loss, lr=4e-4, 228 | plane_idxs=None): 229 | super(CNNpropCNN, self).__init__(roi_res=roi_res, 230 | plane_idxs=plane_idxs, loss_func=loss_func, lr=lr) 231 | 232 | ################## 233 | # Model pipeline # 234 | ################## 235 | # SLM Network 236 | if num_downs_slm > 0: 237 | slm_cnns = [] 238 | slm_cnn_res = tuple(res if res % (2 ** num_downs_slm) == 0 else 239 | res + (2 ** num_downs_slm - res % (2 ** num_downs_slm)) 240 | for res in slm_res) 241 | print(slm_cnn_res,' res') 242 | 243 | slm_input = Field2Input(slm_cnn_res, coord=slm_coord) 244 | slm_cnns += [slm_input] 245 | 246 | if num_downs_slm > 0: 247 | slm_cnn = UnetGenerator(input_nc=4 if 'both' in slm_coord else 2, output_nc=2, 248 | num_downs=num_downs_slm, nf0=num_feats_slm_min, 249 | max_channels=num_feats_slm_max, norm_layer=norm, outer_skip=True) 250 | init_weights(slm_cnn, init_type='normal') 251 | slm_cnns += [slm_cnn] 252 | 253 | slm_output = Output2Field(slm_res, slm_coord) 254 | slm_cnns += [slm_output] 255 | self.slm_cnn = nn.Sequential(*slm_cnns) 256 | else: 257 | self.slm_cnn = None 258 | 259 | # Propagation from the SLM plane to the WRP. 260 | if prop_dist != 0.: 261 | self.prop_slm_wrp = prop_ideal.Propagation(prop_dist, wavelength, feature_size, 262 | prop_type=prop_type, linear_conv=linear_conv, 263 | F_aperture=F_aperture) 264 | else: 265 | self.prop_slm_wrp = None 266 | 267 | # Propagation from the WRP to other planes. 268 | if prop_dists_from_wrp is not None: 269 | self.prop_wrp_target = prop_ideal.Propagation(prop_dists_from_wrp, wavelength, feature_size, 270 | prop_type=prop_type, linear_conv=1.0, 271 | F_aperture=F_aperture) 272 | else: 273 | self.prop_wrp_target = None 274 | 275 | # Target network (This is either included (prop later) or not (prop before, which is then basically NH3D). 276 | if num_downs_target > 0: 277 | target_cnn_res = tuple(res if res % (2 ** num_downs_target) == 0 else 278 | res + (2 ** num_downs_target - res % (2 ** num_downs_target)) for res in slm_res) 279 | target_input = Field2Input(target_cnn_res, coord=target_coord, shared_cnn=True) 280 | input_nc_target = 4 if 'both' in target_coord else 2 if target_coord != 'amp' else 1 281 | output_nc_target = 2 if target_coord != 'amp' and ('1ch_output' not in target_coord) else 1 282 | target_cnn = UnetGenerator(input_nc=input_nc_target, output_nc=output_nc_target, 283 | num_downs=num_downs_target, nf0=num_feats_target_min, 284 | max_channels=num_feats_target_max, norm_layer=norm, outer_skip=True) 285 | init_weights(target_cnn, init_type='normal') 286 | 287 | # shared target cnn requires permutation in channels here. 
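            # (Field2Input(shared_cnn=True) above folds the target planes into the batch dimension so a
            # single U-Net is shared across planes; Output2Field with num_ch_output presumably unfolds
            # them back into the channel dimension, one channel per distance in prop_dists_from_wrp.)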
288 | num_ch_output = 1 if not prop_dists_from_wrp else len(self.prop_wrp_target) 289 | target_output = Output2Field(slm_res, target_coord, num_ch_output=num_ch_output) 290 | target_cnns = [target_input, target_cnn, target_output] 291 | self.target_cnn = nn.Sequential(*target_cnns) 292 | else: 293 | self.target_cnn = None 294 | 295 | def forward(self, field): 296 | if self.slm_cnn is not None: 297 | slm_field = self.slm_cnn(field) # Applying CNN at SLM plane. 298 | else: 299 | slm_field = field 300 | if self.prop_slm_wrp is not None: 301 | wrp_field = self.prop_slm_wrp(slm_field) # Propagation from SLM to Intermediate plane. 302 | if self.prop_wrp_target is not None: 303 | target_field = self.prop_wrp_target(wrp_field) # Propagation from Intermediate plane to Target planes. 304 | if self.target_cnn is not None: 305 | amp = self.target_cnn(target_field).abs() # Applying CNN at Target planes. 306 | phase = target_field.angle() 307 | output_field = amp * torch.exp(1j * phase) 308 | else: 309 | output_field = target_field 310 | 311 | return output_field 312 | 313 | def epoch_end_images(self, prefix): 314 | """ 315 | execute at the end of epochs 316 | 317 | :param prefix: 318 | :return: 319 | """ 320 | ################# 321 | # Reconstructions 322 | logger = self.logger.experiment 323 | recon_amp = self.recon_amp[prefix][0] 324 | target_amp = self.target_amp[prefix][0] 325 | for i in range(recon_amp.shape[0]): 326 | logger.add_image(f'amp_recon/{prefix}_{i}', recon_amp[i:i+1, ...].clip(0, 1), self.global_step) 327 | logger.add_image(f'amp_target/{prefix}_{i}', target_amp[i:i+1, ...].clip(0, 1), self.global_step) 328 | 329 | @property 330 | def plane_idx(self): 331 | return self._plane_idx 332 | 333 | @plane_idx.setter 334 | def plane_idx(self, idx): 335 | """ 336 | 337 | """ 338 | if idx is None: 339 | return 340 | self._plane_idx = idx 341 | if self.prop_wrp_target is not None and len(self.prop_wrp_target) > 1: 342 | self.prop_wrp_target.plane_idx = idx 343 | if self.target_cnn is not None and self.target_cnn[-1].num_ch_output > 1: 344 | self.target_cnn[-1].num_ch_output = 1 345 | 346 | 347 | class PropNH(PropModel): 348 | """ 349 | A parameterized model proposed in the original Neural Holography paper (Peng et al. 2020) 350 | 351 | Class initialization parameters 352 | ------------------------------- 353 | """ 354 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0, 355 | prop_dists_from_wrp=None, linear_conv=True, slm_res=(1080, 1920), 356 | roi_res=(1080, 1920), plane_idxs=None, loss_func=F.l1_loss, lr=4e-4, 357 | num_gaussians=3, init_sigma=(1300.0, 1500.0, 1700.0), init_amp=0.9, 358 | num_zernike_slm=0, num_zernike_f=5, init_coeffs=0.0, use_conv1d_mlp=True, 359 | num_layers=3, num_features=16, num_latent_codes=None, norm=nn.GroupNorm, 360 | target_field=True, content_field=True, num_layers_cnn=5, num_feats_cnn=8): 361 | super(PropNH, self).__init__(roi_res=roi_res, 362 | plane_idxs=plane_idxs, loss_func=loss_func, lr=lr) 363 | 364 | self.prop_slm_wrp = None 365 | self.prop_wrp_target = None 366 | self.lut_slm = None # Section 5.1.3. Phase nonlinearity 367 | self.code_slm = None 368 | self.a_src = None # Section 5.1.1. 
Content-independent Source 369 | self.phi_znk_slm = None # Section 5.1.2 Modeling Optical Propagation with Aberrations 370 | self.znk_basis_slm = None 371 | self.preloaded_znk_slm = False 372 | self.phi_znk_f = None 373 | self.znk_basis_f = None 374 | self.preloaded_znk_f = False 375 | self.coeffs_slm = None 376 | self.coeffs_f = None 377 | self.preload_zernike = False 378 | self.a_t = None # Section 5.1.1. Target Field variation 379 | self.phi_t = None 380 | self.cnn = None # Section 5.1.4. Content-dependent Undiffracted Light 381 | 382 | # Section 3. Propagation from the SLM plane to the WRP. 383 | if prop_dist != 0.: 384 | self.prop_slm_wrp = prop_ideal.Propagation(prop_dist, wavelength, feature_size, 385 | prop_type=prop_type, linear_conv=linear_conv, 386 | F_aperture=F_aperture) 387 | # Propagation from the WRP to other planes. 388 | if prop_dists_from_wrp is not None: 389 | self.prop_wrp_target = prop_ideal.Propagation(prop_dists_from_wrp, wavelength, feature_size, 390 | prop_type=prop_type, linear_conv=linear_conv, 391 | F_aperture=F_aperture) 392 | 393 | # Section 5.1.1. Content-independent Source 394 | if num_gaussians: 395 | self.a_src = SourceAmplitude(num_gaussians, init_sigma, init_amp=init_amp, x_s0=0.0, y_s0=0.0) 396 | 397 | # Section 5.1.1. Target Field variation 398 | if target_field: 399 | self.a_t = nn.Parameter(0.07 * torch.ones(1, 1, *slm_res)) 400 | self.phi_t = nn.Parameter(torch.zeros((1, 1, *slm_res))) 401 | 402 | # Section 5.1.2 Modeling Optical Propagation with Aberrations 403 | if num_zernike_slm: 404 | self.coeffs_slm = nn.Parameter(torch.ones(num_zernike_slm) * init_coeffs) 405 | if num_zernike_f: 406 | self.coeffs_f = nn.Parameter(torch.ones(num_zernike_f) * init_coeffs) 407 | 408 | # Section 5.1.3. Phase nonlinearity 409 | if num_latent_codes is None: 410 | num_latent_codes = [2, 0, 0] 411 | if sum(num_latent_codes) > 0: 412 | self.code_slm = nn.Parameter(torch.zeros(1, sum(num_latent_codes), *slm_res)) 413 | if use_conv1d_mlp: 414 | self.lut_slm = LatentCodedMLP(num_layers, num_features, norm=norm, num_latent_codes=num_latent_codes) 415 | 416 | # Section 5.1.4. Content-dependent Undiffracted Light 417 | if content_field: 418 | target_cnn = [ContentDependentField(num_layers=num_layers_cnn, 419 | num_features=num_feats_cnn, 420 | norm=nn.GroupNorm)] 421 | target_cnn += [Output2Field(slm_res, 'rect')] 422 | self.cnn = nn.Sequential(*target_cnn) 423 | 424 | self.conv = Conv2dField() 425 | 426 | def forward(self, u_in): 427 | """ Implementation of Equation (6) in the Neural Holography paper """ 428 | if torch.is_complex(u_in): 429 | phi_in = u_in.angle() 430 | else: 431 | phi_in = u_in 432 | 433 | # Section 5.1.3. Phase nonlinearity 434 | if self.lut_slm is not None: 435 | slm_phase = self.lut_slm(phi_in, self.code_slm) 436 | else: 437 | slm_phase = phi_in 438 | 439 | # Section 5.1.2. 
precompute the zernike basis only once 440 | if self.znk_basis_slm is None and self.coeffs_slm is not None: 441 | self.znk_basis_slm = compute_zernike_basis(self.coeffs_slm.size()[0], 442 | slm_phase.size()[-2:], wo_piston=True) 443 | self.znk_basis_slm = self.znk_basis_slm.to(u_in.device).detach().requires_grad_(False) 444 | 445 | if not self.preloaded_znk_slm and self.phi_znk_slm is None and self.coeffs_slm is not None: 446 | self.phi_znk_slm = combine_zernike_basis(self.coeffs_slm, self.znk_basis_slm) 447 | self.phi_znk_slm = self.phi_znk_slm.to(u_in.device) 448 | 449 | if self.preload_zernike: 450 | self.preload_zernike_slm() 451 | 452 | if self.phi_znk_slm is not None: 453 | slm_phase = slm_phase + self.phi_znk_slm 454 | 455 | # Section 5.1.1. Create Source Amplitude (DC + gaussians) 456 | if self.a_src is not None: 457 | slm_field = self.a_src(slm_phase) * torch.exp(1j * slm_phase) 458 | else: 459 | slm_field = torch.exp(1j * slm_phase) 460 | 461 | # Section 5.1.2. precompute the zernike basis only once 462 | if self.znk_basis_f is None and self.coeffs_f is not None: 463 | self.znk_basis_f = compute_zernike_basis(self.coeffs_f.size()[0], 464 | [i * 2 for i in slm_phase.size()[-2:]], 465 | wo_piston=True) 466 | self.znk_basis_f = self.znk_basis_f.to(u_in.device).detach().requires_grad_(False) 467 | 468 | if not self.preloaded_znk_f and self.phi_znk_f is None and self.coeffs_f is not None: 469 | self.phi_znk_f = combine_zernike_basis(self.coeffs_f, self.znk_basis_f).to(u_in.device) 470 | 471 | if self.preload_zernike: 472 | self.preload_zernike_f() 473 | 474 | self.prop_slm_wrp.fourier_phase = self.phi_znk_f 475 | 476 | # Propagation from SLM to Intermediate plane 477 | recon_field = self.prop_slm_wrp(slm_field) 478 | 479 | # Section 5.1.1. Content-independent field at target plane 480 | if self.a_t is not None and self.phi_t is not None: 481 | recon_field = recon_field + self.a_t * torch.exp(1j * self.phi_t) 482 | 483 | # Section 5.1.4. Content-dependent Undiffracted light 484 | if self.cnn is not None: 485 | recon_field = recon_field + self.cnn(phi_in) 486 | 487 | # Propagation from Intermediate plane to Target planes. 
488 | if self.prop_wrp_target is not None: 489 | target_field = self.prop_wrp_target(recon_field) 490 | 491 | if self.conv is not None: 492 | return self.conv(target_field) * torch.exp(1j * target_field.angle()) 493 | else: 494 | return target_field 495 | 496 | def epoch_end_images(self, prefix): 497 | """ execute at the end of epochs """ 498 | 499 | # Reconstructions 500 | logger = self.logger.experiment 501 | recon_amp = self.recon_amp[prefix][0] 502 | target_amp = self.target_amp[prefix][0] 503 | for i in range(recon_amp.shape[0]): 504 | logger.add_image(f'amp_recon/{prefix}_{i}', recon_amp[i:i+1, ...].clip(0, 1), self.global_step) 505 | logger.add_image(f'amp_target/{prefix}_{i}', target_amp[i:i+1, ...].clip(0, 1), self.global_step) 506 | 507 | def preload_zernike_slm(self): 508 | self.preload_zernike_slm = True 509 | if self.phi_znk_slm is not None: 510 | self.phi_znk_slm = self.phi_znk_slm.detach().requires_grad_(False) 511 | 512 | def preload_zernike_f(self): 513 | self.preload_zernike_f = True 514 | if self.phi_znk_f is not None: 515 | self.phi_znk_f = self.phi_znk_f.detach().requires_grad_(False) 516 | self.preload_zernike = False 517 | 518 | def preload_H(self): 519 | self.preload_zernike = True 520 | 521 | @property 522 | def plane_idx(self): 523 | return self._plane_idx 524 | 525 | @plane_idx.setter 526 | def plane_idx(self, idx): 527 | """ """ 528 | if idx is None: 529 | return 530 | self._plane_idx = idx 531 | if self.prop_wrp_target is not None and len(self.prop_wrp_target) > 1: 532 | self.prop_wrp_target.plane_idx = idx 533 | -------------------------------------------------------------------------------- /prop_physical.py: -------------------------------------------------------------------------------- 1 | """ 2 | Propagation happening on the setup 3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import utils 9 | import time 10 | import cv2 11 | import numpy as np 12 | import imageio 13 | 14 | import sys 15 | if sys.platform == 'win32': 16 | import slmpy 17 | import hw.camera_capture_module as cam 18 | import hw.calibration_module as calibration 19 | 20 | 21 | class PhysicalProp(nn.Module): 22 | """ A module for physical propagation, 23 | forward pass displays gets SLM pattern as an input and display the pattern on the physical setup, 24 | and capture the diffraction image at the target plane, 25 | and then return warped image using pre-calibrated homography from instantiation. 26 | 27 | Class initialization parameters 28 | ------------------------------- 29 | :param params_slm: a set of parameters for the SLM. 30 | :param params_camera: a set of parameters for the camera sensor. 31 | :param params_calib: a set of parameters for homography calibration. 32 | 33 | Usage 34 | ----- 35 | Functions as a pytorch module: 36 | 37 | >>> camera_prop = PhysicalProp(...) 38 | >>> captured_amp = camera_prop(slm_phase) 39 | 40 | slm_phase: phase at the SLM plane, with dimensions [batch, 1, height, width] 41 | captured_amp: amplitude at the target plane, with dimensions [batch, 1, height, width] 42 | 43 | """ 44 | def __init__(self, params_slm, params_camera, params_calib=None): 45 | super(PhysicalProp, self).__init__() 46 | 47 | # 1. Connect Camera 48 | self.camera = cam.CameraCapture(params_camera) 49 | self.camera.connect(0) # specify the camera to use, 0 for main cam, 1 for the second cam 50 | self.camera.start_capture() 51 | 52 | # 2. 
Connect SLM 53 | self.slm = slmpy.SLMdisplay(isImageLock=True, monitor=params_slm.monitor_num) 54 | self.params_slm = params_slm 55 | 56 | # 3. Calibrate hardware using homography 57 | if params_calib is not None: 58 | self.warper = calibration.Warper(params_calib) 59 | self.calibrate(params_calib.phase_path, params_calib.show_preview) 60 | else: 61 | self.warper = None 62 | 63 | def calibrate(self, phase_path, show_preview=False): 64 | """ 65 | 66 | :param phase_path: 67 | :param show_preview: 68 | :return: 69 | """ 70 | print(' -- Calibrating ...') 71 | phase_img = imageio.imread(phase_path) 72 | self.slm.updateArray(phase_img) 73 | time.sleep(self.params_slm.settle_time) 74 | captured_img = self.camera.grab_images_fast(5) # capture 5-10 images for averaging 75 | calib_success = self.warper.calibrate(captured_img, show_preview) 76 | if calib_success: 77 | print(' -- Calibration succeeded!...') 78 | else: 79 | raise ValueError(' -- Calibration failed') 80 | 81 | def forward(self, slm_phase): 82 | """ 83 | 84 | :param slm_phase: 85 | :return: 86 | """ 87 | raw_intensity = self.capture_linear_intensity(slm_phase) # grayscale raw16 intensity image 88 | warped_intensity = self.warper(raw_intensity) # apply homography 89 | return warped_intensity.sqrt() # return amplitude 90 | 91 | def capture_linear_intensity(self, slm_phase): 92 | """ 93 | display a phase pattern on the SLM and capture a generated holographic image with the sensor. 94 | 95 | :param slm_phase: 96 | :return: 97 | """ 98 | raw_uint16_data = self.capture_uint16(slm_phase) # display & retrieve buffer 99 | captured_intensity = self.process_raw_data(raw_uint16_data) # demosaick & sum up 100 | return captured_intensity 101 | 102 | def capture_uint16(self, slm_phase): 103 | """ 104 | gets phase pattern(s) and display it on the SLM, and then send a signal to board (wait next clock from SLM). 105 | Right after hearing back from the SLM, it sends another signal to PC so that PC retreives the camera buffer. 106 | 107 | :param slm_phase: 108 | :return: 109 | """ 110 | if torch.is_tensor(slm_phase): 111 | slm_phase_encoded = utils.phasemap_8bit(slm_phase) 112 | else: 113 | slm_phase_encoded = slm_phase 114 | self.slm.updateArray(slm_phase_encoded) 115 | 116 | if self.camera.params.ser is not None: 117 | self.camera.params.ser.write(f'D'.encode()) 118 | 119 | # TODO: make the following in a separate function. 
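            # Handshake sketch for the loop below (used only when a serial port is configured in
            # params_camera.ser): 'D' was written above to announce that a new phase is being displayed;
            # the PC then polls the port until the Arduino answers 'C' or a 2-second timeout elapses,
            # and only afterwards retrieves the camera buffer. Without a serial port, the code instead
            # just sleeps for params_slm.settle_time.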
120 | # Wait until receiving signal from arduino 121 | incoming_byte = self.camera.params.ser.inWaiting() 122 | t0 = time.perf_counter() 123 | while True: 124 | received = self.camera.params.ser.read(incoming_byte).decode('UTF-8') 125 | if received != 'C': 126 | incoming_byte = self.camera.params.ser.inWaiting() 127 | if time.perf_counter() - t0 > 2.0: 128 | break 129 | else: 130 | break 131 | else: 132 | time.sleep(self.params_slm.settle_time) 133 | 134 | raw_data_from_buffer = self.camera.retrieve_buffer() 135 | return raw_data_from_buffer 136 | 137 | def process_raw_data(self, raw_data): 138 | """ 139 | gets raw data from the camera buffer, and demosaick it 140 | 141 | :param raw_data: 142 | :return: 143 | """ 144 | raw_data = raw_data - 64 145 | color_cv_image = cv2.cvtColor(raw_data, self.camera.demosaick_rule) # it gives float64 from uint16 -- double check it 146 | captured_intensity = utils.im2float(color_cv_image) # float64 to float32 147 | 148 | # Numpy to tensor 149 | captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32, 150 | device=self.dev).permute(2, 0, 1).unsqueeze(0) 151 | captured_intensity = torch.sum(captured_intensity, dim=1, keepdim=True) 152 | return captured_intensity 153 | 154 | def disconnect(self): 155 | self.camera.stop_capture() 156 | self.camera.disconnect() 157 | self.slm.close() 158 | 159 | def to(self, *args, **kwargs): 160 | slf = super().to(*args, **kwargs) 161 | if slf.warper is not None: 162 | slf.warper = slf.warper.to(*args, **kwargs) 163 | try: 164 | slf.dev = next(slf.parameters()).device 165 | except StopIteration: # no parameters 166 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0] 167 | if device_arg is not None: 168 | slf.dev = device_arg 169 | return slf -------------------------------------------------------------------------------- /prop_submodules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modules for propagation 3 | 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import utils 10 | from unet import Conv2dSame 11 | 12 | 13 | class Field2Input(nn.Module): 14 | """Gets complex-valued field and turns it into multi-channel images""" 15 | 16 | def __init__(self, input_res=(800, 1280), coord='rect', shared_cnn=False): 17 | super(Field2Input, self).__init__() 18 | self.input_res = input_res 19 | self.coord = coord.lower() 20 | self.shared_cnn = shared_cnn 21 | 22 | def forward(self, input_field): 23 | # If input field is slm phase 24 | if input_field.dtype == torch.float32: 25 | input_field = torch.exp(1j * input_field) 26 | 27 | input_field = utils.pad_image(input_field, self.input_res, pytorch=True, stacked_complex=False) 28 | input_field = utils.crop_image(input_field, self.input_res, pytorch=True, stacked_complex=False) 29 | 30 | # To use shared CNN, put everything into batch dimension; 31 | if self.shared_cnn: 32 | num_mb, num_dists = input_field.shape[0], input_field.shape[1] 33 | input_field = input_field.reshape(num_mb*num_dists, 1, *input_field.shape[2:]) 34 | 35 | # Input format 36 | if self.coord == 'rect': 37 | stacked_input = torch.cat((input_field.real, input_field.imag), 1) 38 | elif self.coord == 'polar': 39 | stacked_input = torch.cat((input_field.abs(), input_field.angle()), 1) 40 | elif self.coord == 'amp': 41 | stacked_input = input_field.abs() 42 | elif 'both' in self.coord: 43 | stacked_input = torch.cat((input_field.abs(), input_field.angle(), input_field.real, input_field.imag), 1) 44 | 45 | return stacked_input 46 | 
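# Usage sketch (illustrative only -- `some_cnn` stands in for any 2-channel-in / 2-channel-out network):
#   >>> field = torch.randn(1, 1, 800, 1280, dtype=torch.complex64)   # complex field at a plane
#   >>> stacked = Field2Input(coord='rect')(field)                    # (1, 2, 800, 1280) real/imag channels
#   >>> out_field = Output2Field(coord='rect')(some_cnn(stacked))     # back to a (1, 1, 800, 1280) complex field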
47 | 48 | class Output2Field(nn.Module): 49 | """Gets complex-valued field and turns it into multi-channel images""" 50 | 51 | def __init__(self, output_res=(800, 1280), coord='rect', num_ch_output=1): 52 | super(Output2Field, self).__init__() 53 | self.output_res = output_res 54 | self.coord = coord.lower() 55 | self.num_ch_output = num_ch_output # number of channels in output 56 | 57 | def forward(self, stacked_output): 58 | 59 | if self.coord in ('rect', 'both'): 60 | complex_valued_field = torch.view_as_complex(stacked_output.unsqueeze(4). 61 | permute(0, 4, 2, 3, 1).contiguous()) 62 | elif self.coord == 'polar': 63 | amp = stacked_output[:, 0:1, ...] 64 | phi = stacked_output[:, 1:2, ...] 65 | complex_valued_field = amp * torch.exp(1j * phi) 66 | elif self.coord == 'amp' or '1ch_output' in self.coord: 67 | complex_valued_field = stacked_output * torch.exp(1j * torch.zeros_like(stacked_output)) 68 | 69 | output_field = utils.pad_image(complex_valued_field, self.output_res, pytorch=True, stacked_complex=False) 70 | output_field = utils.crop_image(output_field, self.output_res, pytorch=True, stacked_complex=False) 71 | 72 | if self.num_ch_output > 1: 73 | # reshape to original tensor shape 74 | output_field = output_field.reshape(output_field.shape[0] // self.num_ch_output, self.num_ch_output, 75 | *output_field.shape[2:]) 76 | 77 | return output_field 78 | 79 | 80 | class Conv2dField(nn.Module): 81 | """Apply 2d conv on amp or field""" 82 | 83 | def __init__(self, comp=False, conv_size=3): 84 | super(Conv2dField, self).__init__() 85 | self.comp = comp # apply convolution on field 86 | self.conv_size = (conv_size, conv_size) 87 | if self.comp: 88 | self.conv_real = Conv2dSame(1, 1, conv_size) 89 | self.conv_imag = Conv2dSame(1, 1, conv_size) 90 | init_weight = torch.zeros(1, 1, *self.conv_size) 91 | init_weight[..., conv_size//2, conv_size//2] = 1. 92 | self.conv_real.net[1].weight = nn.Parameter(init_weight.detach().requires_grad_(True)) 93 | self.conv_imag.net[1].weight = nn.Parameter(init_weight.detach().requires_grad_(True)) 94 | else: 95 | self.conv = Conv2dSame(1, 1, conv_size, bias=False) 96 | init_weight = torch.zeros(1, 1, *self.conv_size) 97 | init_weight[..., conv_size//2, conv_size//2] = 1. 98 | self.conv.net[1].weight = nn.Parameter(init_weight.requires_grad_(True)) 99 | 100 | def forward(self, input_field): 101 | 102 | # reshape tensor if number of channels > 1 103 | num_ch = input_field.shape[1] 104 | if num_ch > 1: 105 | batch_size = input_field.shape[0] 106 | input_field = input_field.reshape(batch_size * num_ch, 1, *input_field.shape[2:]) 107 | 108 | if self.comp: 109 | # apply conv on complex fields 110 | real = self.conv_real(input_field.real) - self.conv_imag(input_field.imag) 111 | imag = self.conv_real(input_field.imag) + self.conv_imag(input_field.real) 112 | output_field = torch.view_as_complex(torch.stack((real, imag), -1)) 113 | else: 114 | # apply conv on intensity 115 | output_amp = self.conv(input_field.abs()**2).abs().mean(dim=1, keepdims=True).sqrt() 116 | output_field = output_amp * torch.exp(1j * input_field.angle()) 117 | 118 | # reshape to original tensor shape 119 | if num_ch > 1: 120 | output_field = output_field.reshape(batch_size, num_ch, *output_field.shape[2:]) 121 | 122 | return output_field 123 | 124 | 125 | class LatentCodedMLP(nn.Module): 126 | """ 127 | concatenate latent codes in the middle of forward pass as well. 128 | put latent codes shape of (1, L, H, W) as a parameter for the forward pass. 
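    (e.g., num_latent_codes=[2, 0, 2, 0, 0] splits a (1, 4, H, W) latent tensor so that channels 0:2 are
    concatenated before the 1st layer and channels 2:4 before the 3rd layer)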
129 | num_latent_codes: list of numbers of slices for each layer 130 | * so the sum of num_latent_codes should be total number of the latent codes channels 131 | """ 132 | def __init__(self, num_layers=5, num_features=32, norm=None, num_latent_codes=None): 133 | super(LatentCodedMLP, self).__init__() 134 | 135 | if num_latent_codes is None: 136 | num_latent_codes = [0] * num_layers 137 | 138 | assert len(num_latent_codes) == num_layers 139 | 140 | self.num_latent_codes = num_latent_codes 141 | self.idxs = [sum(num_latent_codes[:y]) for y in range(num_layers + 1)] 142 | self.nets = nn.ModuleList([]) 143 | num_features = [num_features] * num_layers 144 | num_features[0] = 1 145 | 146 | # define each layer 147 | for i in range(num_layers - 1): 148 | net = [nn.Conv2d(num_features[i] + num_latent_codes[i], num_features[i + 1], kernel_size=1)] 149 | if norm is not None: 150 | net += [norm(num_groups=4, num_channels=num_features[i + 1], affine=True)] 151 | net += [nn.LeakyReLU(0.2, True)] 152 | self.nets.append(nn.Sequential(*net)) 153 | 154 | self.nets.append(nn.Conv2d(num_features[-1] + num_latent_codes[-1], 1, kernel_size=1)) 155 | 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.normal_(m.weight, std=0.05) 159 | 160 | def forward(self, phases, latent_codes=None): 161 | 162 | after_relu = phases 163 | # concatenate latent codes at each layer and send through the convolutional layers 164 | for i in range(len(self.num_latent_codes)): 165 | if latent_codes is not None: 166 | after_relu = torch.cat((after_relu, latent_codes[:, self.idxs[i]:self.idxs[i + 1], ...]), 1) 167 | after_relu = self.nets[i](after_relu) 168 | 169 | # residual connection 170 | return phases - after_relu 171 | 172 | 173 | class ContentDependentField(nn.Module): 174 | def __init__(self, num_layers=5, num_features=32, norm=nn.GroupNorm, latent_coords=False): 175 | """ Simple 5layers CNN modeling content dependent undiffracted light """ 176 | 177 | super(ContentDependentField, self).__init__() 178 | 179 | if not latent_coords: 180 | first_ch = 1 181 | else: 182 | first_ch = 3 183 | 184 | net = [Conv2dSame(first_ch, num_features, kernel_size=3)] 185 | 186 | for i in range(num_layers - 2): 187 | if norm is not None: 188 | net += [norm(num_groups=2, num_channels=num_features, affine=True)] 189 | net += [nn.LeakyReLU(0.2, True), 190 | Conv2dSame(num_features, num_features, kernel_size=3)] 191 | 192 | if norm is not None: 193 | net += [norm(num_groups=4, num_channels=num_features, affine=True)] 194 | 195 | net += [nn.LeakyReLU(0.2, True), 196 | Conv2dSame(num_features, 2, kernel_size=3)] 197 | 198 | self.net = nn.Sequential(*net) 199 | 200 | def forward(self, phases, latent_coords=None): 201 | if latent_coords is not None: 202 | input_cnn = torch.cat((phases, latent_coords), dim=1) 203 | else: 204 | input_cnn = phases 205 | 206 | return self.net(input_cnn) 207 | 208 | 209 | class ProcessPhase(nn.Module): 210 | def __init__(self, num_layers=5, num_features=32, num_output_feat=0, norm=nn.BatchNorm2d, num_latent_codes=0): 211 | super(ProcessPhase, self).__init__() 212 | 213 | # avoid zero 214 | self.num_output_feat = max(num_output_feat, 1) 215 | self.num_latent_codes = num_latent_codes 216 | 217 | # a bunch of 1x1 conv layers, set by num_layers 218 | net = [nn.Conv2d(1 + num_latent_codes, num_features, kernel_size=1)] 219 | 220 | for i in range(num_layers - 2): 221 | if norm is not None: 222 | net += [norm(num_groups=2, num_channels=num_features, affine=True)] 223 | net += [nn.LeakyReLU(0.2, True), 224 | 
nn.Conv2d(num_features, num_features, kernel_size=1)] 225 | 226 | if norm is not None: 227 | net += [norm(num_groups=2, num_channels=num_features, affine=True)] 228 | 229 | net += [nn.ReLU(True), 230 | nn.Conv2d(num_features, self.num_output_feat, kernel_size=1)] 231 | 232 | self.net = nn.Sequential(*net) 233 | 234 | def forward(self, phases): 235 | return phases - self.net(phases) 236 | 237 | 238 | class SourceAmplitude(nn.Module): 239 | def __init__(self, num_gaussians=3, init_sigma=None, init_amp=0.7, x_s0=0.0, y_s0=0.0): 240 | super(SourceAmplitude, self).__init__() 241 | 242 | self.num_gaussians = num_gaussians 243 | 244 | if init_sigma is None: 245 | init_sigma = [100.] * self.num_gaussians # default to 100 for all 246 | 247 | # create parameters for source amplitudes 248 | self.sigmas = nn.Parameter(torch.tensor(init_sigma)) 249 | self.x_s = nn.Parameter(torch.ones(num_gaussians) * x_s0) 250 | self.y_s = nn.Parameter(torch.ones(num_gaussians) * y_s0) 251 | self.amplitudes = nn.Parameter(torch.ones(num_gaussians) / (num_gaussians) * init_amp) 252 | self.dc_term = nn.Parameter(torch.zeros(1)) 253 | 254 | self.x_dim = None 255 | self.y_dim = None 256 | 257 | def forward(self, phases): 258 | # create DC term, then add the gaussians 259 | source_amp = torch.ones_like(phases) * self.dc_term 260 | for i in range(self.num_gaussians): 261 | source_amp += self.create_gaussian(phases.shape, i) 262 | 263 | return source_amp 264 | 265 | def create_gaussian(self, shape, idx): 266 | # create sampling grid if needed 267 | if self.x_dim is None or self.y_dim is None: 268 | self.x_dim = torch.linspace(-(shape[-1] - 1) / 2, 269 | (shape[-1] - 1) / 2, 270 | shape[-1], device=self.dc_term.device) 271 | self.y_dim = torch.linspace(-(shape[-2] - 1) / 2, 272 | (shape[-2] - 1) / 2, 273 | shape[-2], device=self.dc_term.device) 274 | 275 | if self.x_dim.device != self.sigmas.device: 276 | self.x_dim.to(self.sigmas.device).detach() 277 | self.x_dim.requires_grad = False 278 | if self.y_dim.device != self.sigmas.device: 279 | self.y_dim.to(self.sigmas.device).detach() 280 | self.y_dim.requires_grad = False 281 | 282 | # offset grid by coordinate and compute x and y gaussian components 283 | x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.x_dim - self.x_s[idx], self.sigmas[idx]), 2)) 284 | y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.y_dim - self.y_s[idx], self.sigmas[idx]), 2)) 285 | 286 | # outer product with amplitude scaling 287 | gaussian = torch.ger(self.amplitudes[idx] * y_gaussian, x_gaussian) 288 | 289 | return gaussian 290 | 291 | 292 | def make_kernel_gaussian(sigma, kernel_size): 293 | 294 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 295 | x_cord = torch.arange(kernel_size) 296 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) 297 | y_grid = x_grid.t() 298 | xy_grid = torch.stack([x_grid, y_grid], dim=-1) 299 | 300 | mean = (kernel_size - 1) / 2 301 | variance = sigma**2 302 | 303 | # Calculate the 2-dimensional gaussian kernel which is 304 | # the product of two gaussian distributions for two different 305 | # variables (in this case called x and y) 306 | gaussian_kernel = ((1 / (2 * math.pi * variance)) 307 | * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) 308 | / (2 * variance))) 309 | # Make sure sum of values in gaussian kernel equals 1. 
310 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 311 | 312 | # Reshape to 2d depthwise convolutional weight 313 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 314 | 315 | return gaussian_kernel 316 | 317 | 318 | def create_gaussian(shape, sigma=800, dev=torch.device('cuda')): 319 | # create sampling grid if needed 320 | shape_min = min(shape[-1], shape[-2]) 321 | x_dim = torch.linspace(-(shape_min - 1) / 2, 322 | (shape_min - 1) / 2, 323 | shape[-1], device=dev) 324 | y_dim = torch.linspace(-(shape_min - 1) / 2, 325 | (shape_min - 1) / 2, 326 | shape[-2], device=dev) 327 | 328 | # offset grid by coordinate and compute x and y gaussian components 329 | x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(x_dim, sigma), 2)) 330 | y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(y_dim, sigma), 2)) 331 | 332 | # outer product with amplitude scaling 333 | gaussian = torch.ger(y_gaussian, x_gaussian) 334 | 335 | return gaussian 336 | 337 | 338 | 339 | 340 | -------------------------------------------------------------------------------- /prop_zernike.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for zernike basis 3 | 4 | """ 5 | 6 | import math 7 | import torch 8 | import numpy as np 9 | import utils 10 | import torch.fft 11 | from aotools.functions import zernikeArray 12 | 13 | 14 | def combine_zernike_basis(coeffs, basis, return_phase=False): 15 | """ 16 | Multiplies the Zernike coefficients and basis functions while preserving 17 | dimensions 18 | 19 | :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike 20 | :param basis: the output of compute_zernike_basis, must be same length as coeffs 21 | :param return_phase: 22 | :return: A float32 tensor that combines coeffs and basis. 
23 | """ 24 | 25 | if len(coeffs.shape) < 3: 26 | coeffs = torch.reshape(coeffs, (coeffs.shape[0], 1, 1)) 27 | 28 | # combine zernike basis and coefficients 29 | zernike = (coeffs * basis).sum(0, keepdim=True) 30 | 31 | # shape to [1, len(coeffs), H, W] 32 | zernike = zernike.unsqueeze(0) 33 | 34 | return zernike 35 | 36 | 37 | def compute_zernike_basis(num_polynomials, field_res, dtype=torch.float32, wo_piston=False): 38 | """Computes a set of Zernike basis function with resolution field_res 39 | 40 | num_polynomials: number of Zernike polynomials in this basis 41 | field_res: [height, width] in px, any list-like object 42 | dtype: torch dtype for computation at different precision 43 | """ 44 | 45 | # size the zernike basis to avoid circular masking 46 | zernike_diam = int(np.ceil(np.sqrt(field_res[0]**2 + field_res[1]**2))) 47 | 48 | # create zernike functions 49 | 50 | if not wo_piston: 51 | zernike = zernikeArray(num_polynomials, zernike_diam) 52 | else: # 200427 - exclude pistorn term 53 | idxs = range(2, 2 + num_polynomials) 54 | zernike = zernikeArray(idxs, zernike_diam) 55 | 56 | zernike = utils.crop_image(zernike, field_res, pytorch=False) 57 | 58 | # convert to tensor and create phase 59 | zernike = torch.tensor(zernike, dtype=dtype, requires_grad=False) 60 | 61 | return zernike 62 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script for model training 3 | 4 | """ 5 | import os 6 | import configargparse 7 | import utils 8 | import prop_model 9 | import params 10 | import image_loader as loaders 11 | 12 | import pytorch_lightning as pl 13 | from pytorch_lightning import Trainer 14 | from torch.utils.data import DataLoader 15 | 16 | # Command line argument processing 17 | p = configargparse.ArgumentParser() 18 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 19 | params.add_parameters(p, 'train') 20 | opt = p.parse_args() 21 | params.set_configs(opt) 22 | run_id = params.run_id_training(opt) 23 | 24 | def main(): 25 | if ',' in opt.data_path: 26 | opt.data_path = opt.data_path.split(',') 27 | else: 28 | opt.data_path = opt.data_path 29 | print(f' - training a model ... 
Dataset path:{opt.data_path}') 30 | # Setup up dataloaders 31 | num_workers = 8 32 | train_loader = DataLoader(loaders.PairsLoader(os.path.join(opt.data_path, 'train'), 33 | plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res, 34 | avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type), 35 | num_workers=num_workers, batch_size=opt.batch_size) 36 | val_loader = DataLoader(loaders.PairsLoader(os.path.join(opt.data_path, 'val'), 37 | plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res, 38 | shuffle=False, avg_energy_ratio=opt.avg_energy_ratio, 39 | slm_type=opt.slm_type), 40 | num_workers=num_workers, batch_size=opt.batch_size, shuffle=False) 41 | test_loader = DataLoader(loaders.PairsLoader(os.path.join(opt.data_path, 'test'), 42 | plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res, 43 | shuffle=False, avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type), 44 | num_workers=num_workers, batch_size=opt.batch_size, shuffle=False) 45 | 46 | # Init model 47 | model = prop_model.model(opt) 48 | model.train() 49 | 50 | # Init root path 51 | root_dir = os.path.join(opt.out_path, run_id) 52 | utils.cond_mkdir(root_dir) 53 | p.write_config_file(opt, [os.path.join(root_dir, 'config.txt')]) 54 | 55 | checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="PSNR_validation_epoch", dirpath=root_dir, 56 | filename="model-{epoch:02d}-{PSNR_validation_epoch:.2f}", 57 | save_top_k=1, mode="max", ) 58 | 59 | # Init trainer 60 | trainer = Trainer(default_root_dir=root_dir, accelerator='gpu', 61 | log_every_n_steps=400, gpus=1, max_epochs=opt.num_epochs, callbacks=[checkpoint_callback]) 62 | 63 | # Fit Model 64 | trainer.fit(model, train_loader, val_loader) 65 | # Test Model 66 | trainer.test(model, dataloaders=test_loader) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | 6 | 7 | def norm_layer(norm_str): 8 | if norm_str.lower() == 'instance': 9 | return nn.InstanceNorm2d 10 | elif norm_str.lower() == 'group': 11 | return nn.GroupNorm 12 | elif norm_str.lower() == 'batch': 13 | return nn.BatchNorm2d 14 | 15 | 16 | class UnetSkipConnectionBlock(nn.Module): 17 | """Defines the Unet submodule with skip connection. 18 | X -------------------identity---------------------- 19 | |-- downsampling -- |submodule| -- upsampling --| 20 | """ 21 | 22 | def __init__(self, outer_nc, inner_nc, input_nc=None, 23 | submodule=None, outermost=False, innermost=False, 24 | norm_layer=nn.InstanceNorm2d, use_dropout=False, 25 | outer_skip=False): 26 | """Construct a Unet submodule with skip connections. 27 | Parameters: 28 | outer_nc (int) -- the number of filters in the outer conv layer 29 | inner_nc (int) -- the number of filters in the inner conv layer 30 | input_nc (int) -- the number of channels in input images/features 31 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 32 | outermost (bool) -- if this module is the outermost module 33 | innermost (bool) -- if this module is the innermost module 34 | norm_layer -- normalization layer 35 | use_dropout (bool) -- if use dropout layers. 
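            outer_skip (bool)   -- if True, the outermost block also concatenates its input to its output (see forward)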
36 | """ 37 | super(UnetSkipConnectionBlock, self).__init__() 38 | self.outermost = outermost 39 | self.outer_skip = outer_skip 40 | if norm_layer == None: 41 | use_bias = True 42 | elif type(norm_layer) == functools.partial: 43 | use_bias = norm_layer.func == nn.InstanceNorm2d 44 | else: 45 | use_bias = norm_layer == nn.InstanceNorm2d 46 | if input_nc is None: 47 | input_nc = outer_nc 48 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=5, 49 | # Change kernel size changed to 5 from 4 and padding size from 1 to 2 50 | stride=2, padding=2, bias=use_bias) 51 | downrelu = nn.LeakyReLU(0.2, True) 52 | if norm_layer is not None: 53 | if norm_layer == nn.GroupNorm: 54 | downnorm = norm_layer(8, inner_nc) 55 | else: 56 | downnorm = norm_layer(inner_nc) 57 | else: 58 | downnorm = None 59 | uprelu = nn.ReLU(True) 60 | if norm_layer is not None: 61 | if norm_layer == nn.GroupNorm: 62 | upnorm = norm_layer(8, outer_nc) 63 | else: 64 | upnorm = norm_layer(outer_nc) 65 | else: 66 | upnorm = None 67 | 68 | if outermost: 69 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 70 | kernel_size=4, stride=2, 71 | padding=1) 72 | down = [downconv, downrelu] 73 | up = [upconv] # Removed tanh and uprelu 74 | model = down + [submodule] + up 75 | elif innermost: 76 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 77 | kernel_size=4, stride=2, 78 | padding=1, bias=use_bias) 79 | if norm_layer is not None: 80 | down = [downconv, downnorm, downrelu] 81 | up = [upconv, upnorm, uprelu] 82 | else: 83 | down = [downconv, downrelu] 84 | up = [upconv, uprelu] 85 | 86 | model = down + up 87 | else: 88 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 89 | kernel_size=4, stride=2, 90 | padding=1, bias=use_bias) 91 | if norm_layer is not None: 92 | down = [downconv, downnorm, downrelu] 93 | up = [upconv, upnorm, uprelu] 94 | else: 95 | down = [downconv, downrelu] 96 | up = [upconv, uprelu] 97 | 98 | if use_dropout: 99 | model = down + [submodule] + up + [nn.Dropout(0.5)] 100 | else: 101 | model = down + [submodule] + up 102 | 103 | self.model = nn.Sequential(*model) 104 | 105 | def forward(self, x): 106 | if self.outermost and not self.outer_skip: 107 | return self.model(x) 108 | else: # add skip connections 109 | return torch.cat([x, self.model(x)], 1) 110 | 111 | 112 | def init_latent(latent_num, wavefront_res, ones=False): 113 | if latent_num > 0: 114 | if ones: 115 | latent = nn.Parameter(torch.ones(1, latent_num, *wavefront_res, 116 | requires_grad=True)) 117 | else: 118 | latent = nn.Parameter(torch.zeros(1, latent_num, *wavefront_res, 119 | requires_grad=True)) 120 | else: 121 | latent = None 122 | return latent 123 | 124 | 125 | def apply_net(net, input, latent_code, complex=False): 126 | if net is None: 127 | return input 128 | if complex: # Only valid for single batch or single channel complex inputs and outputs 129 | multi_channel = (input.shape[1] > 1) 130 | if multi_channel: 131 | input = torch.view_as_real(input[0,...]) 132 | else: 133 | input = torch.view_as_real(input[:,0,...]) 134 | input = input.permute(0,3,1,2) 135 | if latent_code is not None: 136 | input = torch.cat((input, latent_code), dim=1) 137 | output = net(input) 138 | if complex: 139 | if multi_channel: 140 | output = output.permute(0,2,3,1).unsqueeze(0) 141 | else: 142 | output = output.permute(0,2,3,1).unsqueeze(1) 143 | output = torch.complex(output[...,0], output[...,1]) 144 | return output 145 | 146 | 147 | def init_weights(net, init_type='normal', init_gain=0.02, outer_skip=False): 148 | """Initialize network weights. 
149 |     Parameters:
150 |         net (network) -- network to be initialized
151 |         init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
152 |         init_gain (float) -- scaling factor for normal, xavier and orthogonal.
153 |     We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
154 |     work better for some applications. Feel free to try yourself.
155 |     """
156 | 
157 |     def init_func(m):  # define the initialization function
158 |         classname = m.__class__.__name__
159 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
160 |             if init_type == 'normal':
161 |                 init.normal_(m.weight.data, 0.0, init_gain)
162 |             elif init_type == 'xavier':
163 |                 init.xavier_normal_(m.weight.data, gain=init_gain)
164 |             elif init_type == 'kaiming':
165 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
166 |             elif init_type == 'orthogonal':
167 |                 init.orthogonal_(m.weight.data, gain=init_gain)
168 |             else:
169 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
170 |             if hasattr(m, 'bias') and m.bias is not None:
171 |                 init.constant_(m.bias.data, 0.0)
172 |         elif classname.find(
173 |                 'BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
174 |             init.normal_(m.weight.data, 1.0, init_gain)
175 |             init.constant_(m.bias.data, 0.0)
176 | 
177 |     print('initialize network with %s' % init_type)
178 |     net.apply(init_func)  # apply the initialization function
179 | 
180 | 
181 | class UnetGenerator(nn.Module):
182 |     """Create a Unet-based generator"""
183 | 
184 |     def __init__(self, input_nc=1, output_nc=1, num_downs=8, nf0=32, max_channels=512,
185 |                  norm_layer=nn.InstanceNorm2d, use_dropout=False, outer_skip=True,
186 |                  half_channels=False, eighth_channels=False):
187 |         """Construct a Unet generator
188 |         Parameters:
189 |             input_nc (int)  -- the number of channels in input images
190 |             output_nc (int) -- the number of channels in output images
191 |             num_downs (int) -- the number of downsamplings in the UNet. For example, if num_downs == 7,
192 |                                an image of size 128x128 will become of size 1x1 at the bottleneck
193 |             nf0 (int)       -- the number of filters in the last conv layer
194 |             norm_layer      -- normalization layer
195 |         We construct the U-Net from the innermost layer to the outermost layer.
196 |         It is a recursive process.
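        max_channels caps the number of channels in any layer; outer_skip concatenates the network input to the
        last feature map before the final output convolution; half_channels / eighth_channels divide all channel
        counts by 2 or 8, respectively.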
197 | """ 198 | super(UnetGenerator, self).__init__() 199 | self.outer_skip = outer_skip 200 | self.input_nc = input_nc 201 | 202 | if eighth_channels: 203 | divisor = 8 204 | elif half_channels: 205 | divisor = 2 206 | else: 207 | divisor = 1 208 | # construct unet structure 209 | 210 | assert num_downs >= 2 211 | 212 | # Add the innermost layer 213 | unet_block = UnetSkipConnectionBlock(min(2 ** (num_downs - 1) * nf0, max_channels) // divisor, 214 | min(2 ** (num_downs - 1) * nf0, max_channels) // divisor, 215 | input_nc=None, submodule=None, norm_layer=norm_layer, 216 | innermost=True) 217 | 218 | for i in list(range(1, num_downs - 1))[::-1]: 219 | if i == 1: 220 | norm = None # Praneeth's modification 221 | else: 222 | norm = norm_layer 223 | 224 | unet_block = UnetSkipConnectionBlock(min(2 ** i * nf0, max_channels) // divisor, 225 | min(2 ** (i + 1) * nf0, max_channels) // divisor, 226 | input_nc=None, submodule=unet_block, 227 | norm_layer=norm, 228 | use_dropout=use_dropout) 229 | 230 | # Add the outermost layer 231 | self.model = UnetSkipConnectionBlock(min(nf0, max_channels) // divisor, 232 | min(2 * nf0, max_channels) // divisor, 233 | input_nc=input_nc, submodule=unet_block, outermost=True, 234 | norm_layer=None, outer_skip=self.outer_skip) 235 | if self.outer_skip: 236 | self.additional_conv = nn.Conv2d(input_nc + min(nf0, max_channels) // divisor, output_nc, 237 | kernel_size=4, stride=1, padding=2, bias=True) 238 | else: 239 | self.additional_conv = nn.Conv2d(min(nf0, max_channels) // divisor, output_nc, 240 | kernel_size=4, stride=1, padding=2, bias=True) 241 | 242 | def forward(self, cnn_input): 243 | """Standard forward""" 244 | output = self.model(cnn_input) 245 | output = self.additional_conv(output) 246 | output = output[:,:,:-1,:-1] 247 | return output 248 | 249 | 250 | class Conv2dSame(torch.nn.Module): 251 | '''2D convolution that pads to keep spatial dimensions equal. 252 | Cannot deal with stride. Only quadratic kernels (=scalar kernel_size). 253 | ''' 254 | 255 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReflectionPad2d): 256 | ''' 257 | :param in_channels: Number of input channels 258 | :param out_channels: Number of output channels 259 | :param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported). 260 | :param bias: Whether or not to use bias. 261 | :param padding_layer: Which padding to use. Default is reflection padding. 
262 | ''' 263 | super().__init__() 264 | ka = kernel_size // 2 265 | kb = ka - 1 if kernel_size % 2 == 0 else ka 266 | self.net = nn.Sequential( 267 | padding_layer((ka, kb, ka, kb)), 268 | nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1) 269 | ) 270 | 271 | self.weight = self.net[1].weight 272 | self.bias = self.net[1].bias 273 | 274 | def forward(self, x): 275 | return self.net(x) 276 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils 3 | """ 4 | 5 | import math 6 | import random 7 | import numpy as np 8 | 9 | import os 10 | import torch 11 | import torch.nn as nn 12 | 13 | from skimage.metrics import peak_signal_noise_ratio as psnr 14 | from skimage.metrics import structural_similarity as ssim 15 | from torchvision.transforms import GaussianBlur 16 | 17 | from skimage.restoration import inpaint 18 | import kornia 19 | import torch.nn.functional as F 20 | import torch.fft as tfft 21 | 22 | def roll_torch(tensor, shift: int, axis: int): 23 | if shift == 0: 24 | return tensor 25 | 26 | if axis < 0: 27 | axis += tensor.dim() 28 | 29 | dim_size = tensor.size(axis) 30 | after_start = dim_size - shift 31 | if shift < 0: 32 | after_start = -shift 33 | shift = dim_size - abs(shift) 34 | 35 | before = tensor.narrow(axis, 0, dim_size - shift) 36 | after = tensor.narrow(axis, after_start, shift) 37 | return torch.cat([after, before], axis) 38 | 39 | 40 | def ifftshift(tensor): 41 | """ifftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2] 42 | 43 | shifts the width and heights 44 | """ 45 | size = tensor.size() 46 | tensor_shifted = roll_torch(tensor, -math.floor(size[2] / 2.0), 2) 47 | tensor_shifted = roll_torch(tensor_shifted, -math.floor(size[3] / 2.0), 3) 48 | return tensor_shifted 49 | 50 | 51 | def fftshift(tensor): 52 | """fftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2] 53 | 54 | shifts the width and heights 55 | """ 56 | size = tensor.size() 57 | tensor_shifted = roll_torch(tensor, math.floor(size[2] / 2.0), 2) 58 | tensor_shifted = roll_torch(tensor_shifted, math.floor(size[3] / 2.0), 3) 59 | return tensor_shifted 60 | 61 | 62 | def pad_image(field, target_shape, pytorch=True, stacked_complex=True, padval=0, mode='constant'): 63 | """Pads a 2D complex field up to target_shape in size 64 | 65 | Padding is done such that when used with crop_image(), odd and even dimensions are 66 | handled correctly to properly undo the padding. 67 | 68 | field: the field to be padded. May have as many leading dimensions as necessary 69 | (e.g., batch or channel dimensions) 70 | target_shape: the 2D target output dimensions. 
If any dimensions are smaller 71 | than field, no padding is applied 72 | pytorch: if True, uses torch functions, if False, uses numpy 73 | stacked_complex: for pytorch=True, indicates that field has a final dimension 74 | representing real and imag 75 | padval: the real number value to pad by 76 | mode: padding mode for numpy or torch 77 | """ 78 | if pytorch: 79 | if stacked_complex: 80 | size_diff = np.array(target_shape) - np.array(field.shape[-3:-1]) 81 | odd_dim = np.array(field.shape[-3:-1]) % 2 82 | else: 83 | size_diff = np.array(target_shape) - np.array(field.shape[-2:]) 84 | odd_dim = np.array(field.shape[-2:]) % 2 85 | else: 86 | size_diff = np.array(target_shape) - np.array(field.shape[-2:]) 87 | odd_dim = np.array(field.shape[-2:]) % 2 88 | 89 | # pad the dimensions that need to increase in size 90 | if (size_diff > 0).any(): 91 | pad_total = np.maximum(size_diff, 0) 92 | pad_front = (pad_total + odd_dim) // 2 93 | pad_end = (pad_total + 1 - odd_dim) // 2 94 | 95 | if pytorch: 96 | pad_axes = [int(p) # convert from np.int64 97 | for tple in zip(pad_front[::-1], pad_end[::-1]) 98 | for p in tple] 99 | if stacked_complex: 100 | return pad_stacked_complex(field, pad_axes, mode=mode, padval=padval) 101 | else: 102 | return nn.functional.pad(field, pad_axes, mode=mode, value=padval) 103 | else: 104 | leading_dims = field.ndim - 2 # only pad the last two dims 105 | if leading_dims > 0: 106 | pad_front = np.concatenate(([0] * leading_dims, pad_front)) 107 | pad_end = np.concatenate(([0] * leading_dims, pad_end)) 108 | return np.pad(field, tuple(zip(pad_front, pad_end)), mode, 109 | constant_values=padval) 110 | else: 111 | return field 112 | 113 | 114 | def crop_image(field, target_shape, pytorch=True, stacked_complex=True, lf=False): 115 | """Crops a 2D field, see pad_image() for details 116 | 117 | No cropping is done if target_shape is already smaller than field 118 | """ 119 | if target_shape is None: 120 | return field 121 | 122 | if lf: 123 | size_diff = np.array(field.shape[-4:-2]) - np.array(target_shape) 124 | odd_dim = np.array(field.shape[-4:-2]) % 2 125 | else: 126 | if pytorch: 127 | if stacked_complex: 128 | size_diff = np.array(field.shape[-3:-1]) - np.array(target_shape) 129 | odd_dim = np.array(field.shape[-3:-1]) % 2 130 | else: 131 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape) 132 | odd_dim = np.array(field.shape[-2:]) % 2 133 | else: 134 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape) 135 | odd_dim = np.array(field.shape[-2:]) % 2 136 | 137 | # crop dimensions that need to decrease in size 138 | if (size_diff > 0).any(): 139 | crop_total = np.maximum(size_diff, 0) 140 | crop_front = (crop_total + 1 - odd_dim) // 2 141 | crop_end = (crop_total + odd_dim) // 2 142 | 143 | crop_slices = [slice(int(f), int(-e) if e else None) 144 | for f, e in zip(crop_front, crop_end)] 145 | if lf: 146 | return field[(..., *crop_slices, slice(None), slice(None))] 147 | else: 148 | if pytorch and stacked_complex: 149 | return field[(..., *crop_slices, slice(None))] 150 | else: 151 | return field[(..., *crop_slices)] 152 | else: 153 | return field 154 | 155 | 156 | def cond_mkdir(path): 157 | if not os.path.exists(path): 158 | os.makedirs(path) 159 | 160 | 161 | def phasemap_8bit(phasemap, inverted=True): 162 | """convert a phasemap tensor into a numpy 8bit phasemap that can be directly displayed 163 | 164 | Input 165 | ----- 166 | :param phasemap: input phasemap tensor, which is supposed to be in the range of [-pi, pi]. 
167 | :param inverted: a boolean value that indicates whether the phasemap is inverted. 168 | 169 | Output 170 | ------ 171 | :return: output phasemap, with uint8 dtype (in [0, 255]) 172 | """ 173 | 174 | output_phase = ((phasemap + np.pi) % (2 * np.pi)) / (2 * np.pi) 175 | if inverted: 176 | phase_out_8bit = ((1 - output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits 177 | else: 178 | phase_out_8bit = ((output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits 179 | return phase_out_8bit 180 | 181 | 182 | def burst_img_processor(img_burst_list): 183 | img_tensor = np.stack(img_burst_list, axis=0) 184 | img_avg = np.mean(img_tensor, axis=0) 185 | return im2float(img_avg) # changed from int8 to float32 186 | 187 | 188 | def im2float(im, dtype=np.float32): 189 | """convert uint16 or uint8 image to float32, with range scaled to 0-1 190 | 191 | :param im: image 192 | :param dtype: default np.float32 193 | :return: 194 | """ 195 | if issubclass(im.dtype.type, np.floating): 196 | return im.astype(dtype) 197 | elif issubclass(im.dtype.type, np.integer): 198 | return im / dtype(np.iinfo(im.dtype).max) 199 | else: 200 | raise ValueError(f'Unsupported data type {im.dtype}') 201 | 202 | 203 | def get_psnr_ssim(recon_amp, target_amp, multichannel=False): 204 | """get PSNR and SSIM metrics""" 205 | psnrs, ssims = {}, {} 206 | 207 | # amplitude 208 | psnrs['amp'] = psnr(target_amp, recon_amp) 209 | ssims['amp'] = ssim(target_amp, recon_amp, multichannel=multichannel) 210 | 211 | # linear 212 | target_linear = target_amp**2 213 | recon_linear = recon_amp**2 214 | psnrs['lin'] = psnr(target_linear, recon_linear) 215 | ssims['lin'] = ssim(target_linear, recon_linear, multichannel=multichannel) 216 | 217 | # srgb 218 | target_srgb = srgb_lin2gamma(np.clip(target_linear, 0.0, 1.0)) 219 | recon_srgb = srgb_lin2gamma(np.clip(recon_linear, 0.0, 1.0)) 220 | psnrs['srgb'] = psnr(target_srgb, recon_srgb) 221 | ssims['srgb'] = ssim(target_srgb, recon_srgb, multichannel=multichannel) 222 | 223 | return psnrs, ssims 224 | 225 | 226 | def make_kernel_gaussian(sigma, kernel_size): 227 | 228 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 229 | x_cord = torch.arange(kernel_size) 230 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) 231 | y_grid = x_grid.t() 232 | xy_grid = torch.stack([x_grid, y_grid], dim=-1) 233 | 234 | mean = (kernel_size - 1) / 2 235 | variance = sigma**2 236 | 237 | # Calculate the 2-dimensional gaussian kernel which is 238 | # the product of two gaussian distributions for two different 239 | # variables (in this case called x and y) 240 | gaussian_kernel = ((1 / (2 * math.pi * variance)) 241 | * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) 242 | / (2 * variance))) 243 | # Make sure sum of values in gaussian kernel equals 1. 
244 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 245 | 246 | # Reshape to 2d depthwise convolutional weight 247 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 248 | 249 | return gaussian_kernel 250 | 251 | 252 | def pad_stacked_complex(field, pad_width, padval=0): 253 | if padval == 0: 254 | pad_width = (0, 0, *pad_width) # add 0 padding for stacked_complex dimension 255 | return nn.functional.pad(field, pad_width) 256 | else: 257 | if isinstance(padval, torch.Tensor): 258 | padval = padval.item() 259 | 260 | real, imag = field[..., 0], field[..., 1] 261 | real = nn.functional.pad(real, pad_width, value=padval) 262 | imag = nn.functional.pad(imag, pad_width, value=0) 263 | return torch.stack((real, imag), -1) 264 | 265 | 266 | def srgb_gamma2lin(im_in): 267 | """ converts from sRGB to linear color space """ 268 | thresh = 0.04045 269 | if torch.is_tensor(im_in): 270 | low_val = im_in <= thresh 271 | im_out = torch.zeros_like(im_in) 272 | im_out[low_val] = 25 / 323 * im_in[low_val] 273 | im_out[torch.logical_not(low_val)] = ((200 * im_in[torch.logical_not(low_val)] + 11) 274 | / 211) ** (12 / 5) 275 | else: 276 | im_out = np.where(im_in <= thresh, im_in / 12.92, ((im_in + 0.055) / 1.055) ** (12/5)) 277 | 278 | return im_out 279 | 280 | 281 | def srgb_lin2gamma(im_in): 282 | """ converts from linear to sRGB color space """ 283 | thresh = 0.0031308 284 | im_out = np.where(im_in <= thresh, 12.92 * im_in, 1.055 * (im_in**(1 / 2.4)) - 0.055) 285 | return im_out 286 | 287 | 288 | def decompose_depthmap(depthmap_virtual_D, depth_planes_D): 289 | """ decompose a depthmap image into a set of masks with depth positions (in Diopter) """ 290 | 291 | num_planes = len(depth_planes_D) 292 | 293 | masks = torch.zeros(depthmap_virtual_D.shape[0], len(depth_planes_D), *depthmap_virtual_D.shape[-2:], 294 | dtype=torch.float32).to(depthmap_virtual_D.device) 295 | for k in range(len(depth_planes_D) - 1): 296 | depth_l = depth_planes_D[k] 297 | depth_h = depth_planes_D[k + 1] 298 | idxs = (depthmap_virtual_D >= depth_l) & (depthmap_virtual_D < depth_h) 299 | close_idxs = (depth_h - depthmap_virtual_D) > (depthmap_virtual_D - depth_l) 300 | 301 | # closer one 302 | mask = torch.zeros_like(depthmap_virtual_D) 303 | mask += idxs * close_idxs * 1 304 | masks[:, k, ...] += mask.squeeze(1) 305 | 306 | # farther one 307 | mask = torch.zeros_like(depthmap_virtual_D) 308 | mask += idxs * (~close_idxs) * 1 309 | masks[:, k + 1, ...] += mask.squeeze(1) 310 | 311 | # even closer ones 312 | idxs = depthmap_virtual_D >= max(depth_planes_D) 313 | mask = torch.zeros_like(depthmap_virtual_D) 314 | mask += idxs * 1 315 | masks[:, len(depth_planes_D) - 1, ...] += mask.clone().squeeze(1) 316 | 317 | # even farther ones 318 | idxs = depthmap_virtual_D < min(depth_planes_D) 319 | mask = torch.zeros_like(depthmap_virtual_D) 320 | mask += idxs * 1 321 | masks[:, 0, ...] += mask.clone().squeeze(1) 322 | 323 | # sanity check 324 | assert torch.sum(masks).item() == torch.numel(masks) / num_planes 325 | 326 | return masks 327 | 328 | 329 | def prop_dist_to_diopter(prop_dists, focal_distance, prop_dist_inf, from_lens=True): 330 | """ 331 | Calculates distance from the user in diopter unit given the propagation distance from the SLM. 
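    With from_lens=True, this evaluates D = 1 / (x0 + f - x) - 1 / f for each propagation distance x, where
    f is the eyepiece focal length and x0 is the propagation distance that corresponds to optical infinity.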
332 | :param prop_dists: 333 | :param focal_distance: 334 | :param prop_dist_inf: 335 | :param from_lens: 336 | :return: 337 | """ 338 | x0 = prop_dist_inf # prop distance from SLM that correcponds to optical infinity from the user 339 | f = focal_distance # focal distance of eyepiece 340 | 341 | if from_lens: # distance is from the lens 342 | diopters = [1 / (x0 + f - x) - 1 / f for x in prop_dists] # diopters from the user side 343 | else: # distance is from the user (basically adding focal length) 344 | diopters = [(x - x0) / f**2 for x in prop_dists] 345 | 346 | return diopters 347 | 348 | 349 | class PSNR: 350 | """Peak Signal to Noise Ratio 351 | img1 and img2 have range [0, 255]""" 352 | 353 | def __init__(self): 354 | self.name = "PSNR" 355 | 356 | @staticmethod 357 | def __call__(img1, img2): 358 | mse = torch.mean((img1 - img2) ** 2) 359 | return 20 * torch.log10(255.0 / torch.sqrt(mse)) 360 | 361 | 362 | def laplacian(img): 363 | 364 | # signed angular difference 365 | grad_x1, grad_y1 = grad(img, next_pixel=True) # x_{n+1} - x_{n} 366 | grad_x0, grad_y0 = grad(img, next_pixel=False) # x_{n} - x_{n-1} 367 | 368 | laplacian_x = grad_x1 - grad_x0 # (x_{n+1} - x_{n}) - (x_{n} - x_{n-1}) 369 | laplacian_y = grad_y1 - grad_y0 370 | 371 | return laplacian_x + laplacian_y 372 | 373 | 374 | def grad(img, next_pixel=False, sovel=False): 375 | 376 | if img.shape[1] > 1: 377 | permuted = True 378 | img = img.permute(1, 0, 2, 3) 379 | else: 380 | permuted = False 381 | 382 | # set diff kernel 383 | if sovel: # use sovel filter for gradient calculation 384 | k_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32) / 8 385 | k_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32) / 8 386 | else: 387 | if next_pixel: # x_{n+1} - x_n 388 | k_x = torch.tensor([[0, -1, 1]], dtype=torch.float32) 389 | k_y = torch.tensor([[1], [-1], [0]], dtype=torch.float32) 390 | else: # x_{n} - x_{n-1} 391 | k_x = torch.tensor([[-1, 1, 0]], dtype=torch.float32) 392 | k_y = torch.tensor([[0], [1], [-1]], dtype=torch.float32) 393 | 394 | # upload to gpu 395 | k_x = k_x.to(img.device).unsqueeze(0).unsqueeze(0) 396 | k_y = k_y.to(img.device).unsqueeze(0).unsqueeze(0) 397 | 398 | # boundary handling (replicate elements at boundary) 399 | img_x = F.pad(img, (1, 1, 0, 0), 'replicate') 400 | img_y = F.pad(img, (0, 0, 1, 1), 'replicate') 401 | 402 | # take sign angular difference 403 | grad_x = signed_ang(F.conv2d(img_x, k_x)) 404 | grad_y = signed_ang(F.conv2d(img_y, k_y)) 405 | 406 | if permuted: 407 | grad_x = grad_x.permute(1, 0, 2, 3) 408 | grad_y = grad_y.permute(1, 0, 2, 3) 409 | 410 | return grad_x, grad_y 411 | 412 | 413 | def signed_ang(angle): 414 | """ 415 | cast all angles into [-pi, pi] 416 | """ 417 | return (angle + math.pi) % (2*math.pi) - math.pi 418 | 419 | 420 | # Adapted from https://github.com/svaiter/pyprox/blob/master/pyprox/operators.py 421 | def soft_thresholding(x, gamma): 422 | """ 423 | return element-wise shrinkage function with threshold kappa 424 | """ 425 | return torch.maximum(torch.zeros_like(x), 426 | 1 - gamma / torch.maximum(torch.abs(x), 1e-10*torch.ones_like(x))) * x 427 | 428 | 429 | def random_gen(num_planes=7, slm_type='ti'): 430 | """ 431 | random hyperparameters for the dataset 432 | """ 433 | frame_choices = [1, 2, 3, 3, 4, 4, 8, 8, 8, 8] if slm_type.lower() == 'ti' else [1] 434 | 435 | num_iters = random.choice(range(3000)) 436 | phase_range = random.uniform(1.0, 6.28) 437 | target_range = random.uniform(0.5, 1.5) 438 | 
learning_rate = random.uniform(0.01, 0.03) 439 | plane_idx = random.choice(range(num_planes)) 440 | 441 | return num_iters, phase_range, target_range, learning_rate, plane_idx 442 | --------------------------------------------------------------------------------
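As a quick reference for the depth utilities in ```utils.py``` above, here is a minimal sketch (with made-up optics values) of how propagation distances are mapped to diopters and how a depth map is decomposed into per-plane masks:
```
import torch
import utils

prop_dists = [x * 1e-3 for x in range(7)]                             # 7 planes, 0-6 mm from the SLM (assumed)
depth_planes_D = utils.prop_dist_to_diopter(prop_dists,
                                            focal_distance=50e-3,     # eyepiece focal length (assumed)
                                            prop_dist_inf=6e-3)       # plane mapping to optical infinity (assumed)

depthmap_D = torch.rand(1, 1, 800, 1280) * (max(depth_planes_D) - min(depth_planes_D)) + min(depth_planes_D)
masks = utils.decompose_depthmap(depthmap_D, depth_planes_D)          # (1, 7, 800, 1280) one-hot masks over planes
```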