├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── algorithms.py
├── data
│   └── 1.png
├── environment.yml
├── environment_windows.yml
├── eval.py
├── holonet.py
├── main.py
├── main_eval.sh
├── pretrained_networks
│   └── README.md
├── propagation_ASM.py
├── propagation_model.py
├── setup_env.sh
├── train_holonet.py
├── train_model.py
└── utils
    ├── __init__.py
    ├── arduino_laser_control_module.py
    ├── augmented_image_loader.py
    ├── calibration_module.py
    ├── camera_capture_module.py
    ├── detect_heds_module_path.py
    ├── modules.py
    ├── perceptualloss.py
    ├── slm_display_module.py
    ├── utils.py
    └── utils_tensorboard.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .DS_Store?
3 | __pycache__/
4 | .idea/*
5 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "utils/pytorch_prototyping"]
2 | path = utils/pytorch_prototyping
3 | url = https://github.com/vsitzmann/pytorch_prototyping
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # License
2 | This project is licensed under the following license, with exception of the file "data/1.png", which is licensed under the [CC-BY](https://creativecommons.org/licenses/by/3.0/) license.
3 |
4 |
5 | Copyright (c) 2020, Stanford University
6 |
7 | All rights reserved.
8 |
9 | Redistribution and use in source and binary forms for academic and other non-commercial purposes with or without modification, are permitted provided that the following conditions are met:
10 |
11 | * Redistributions of source code, including modified source code, must retain the above copyright notice, this list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form or a modified form of the source code must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
14 |
15 | * Neither the name of The Leland Stanford Junior University, any of its trademarks, the names of its employees, nor contributors to the source code may be used to endorse or promote products derived from this software without specific prior written permission.
16 |
17 | * Where a modified version of the source code is redistributed publicly in source or binary forms, the modified source code must be published in a freely accessible manner, or otherwise redistributed at no charge to anyone requesting a copy of the modified source code, subject to the same terms as this agreement.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE TRUSTEES OF THE LELAND STANFORD JUNIOR UNIVERSITY "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE LELAND STANFORD JUNIOR UNIVERSITY OR ITS TRUSTEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Holography with Camera-in-the-loop Training
2 | ### [Project Page](http://www.computationalimaging.org/publications/neuralholography/) | [Paper](http://www.computationalimaging.org/wp-content/uploads/2020/08/NeuralHolography_SIGAsia2020.pdf)
3 |
4 | [Yifan Peng](http://stanford.edu/~evanpeng/), [Suyeon Choi](https://choisuyeon.github.io/), [Nitish Padmanaban](https://nitish.me/), [Gordon Wetzstein](http://stanford.edu/~gordonwz/)
5 |
6 | This repository contains the scripts associated with the SIGGRAPH Asia 2020 paper "Neural Holography with Camera-in-the-loop Training".
7 |
8 | ## Update 20201203:
9 | We just released the second part of our scripts.
10 |
11 | It contains all the code needed to reproduce our work, including the camera-in-the-loop optimization and the training code for the parameterized wave propagation model and Holonet.
12 | We are also publishing the code that incorporates our hardware SDKs for automated pipelines. Please have a look, and feel free to modify it and use it for your work!
13 |
14 | The specific updates can be found in the following sections:
15 | - [2.2) CITL-calibrated model simulation](#22-citl-calibrated-model-simulation),
16 | - [2.3) Evaluation on the physical setup](#23-evaluation-on-the-physical-setup),
17 | - [3) Training](#3-training),
18 | - [4) Hardware](#4-hardware-camera-slm-laser-automation-and-calibration)
19 |
20 | ## Getting Started
21 |
22 | **Our code requires PyTorch 1.7.0 or later, as it uses complex64-type tensors.**
23 | You can adapt it to earlier versions of PyTorch with the complex number operations implemented in ```utils/utils.py```.
24 |
25 | You can set up a conda environment with all dependencies like so:
26 |
27 | For Windows:
28 | ```
29 | conda env create -f environment_windows.yml
30 | conda activate neural-holography
31 | ```
32 |
33 | For Linux: (Hardware SDKs may not be compatible)
34 | ```
35 | conda env create -f environment.yml
36 | conda activate neural-holography
37 | ```
38 | or you can manually set up a conda environment with ```setup_env.sh``` (if you use Windows, just execute ```setup_env.sh``` without the ```chmod``` step):
39 | ```
40 | chmod u+x setup_env.sh
41 | ./setup_env.sh
42 | conda activate neural-holography
43 | ```
44 |
45 | You can load the [submodule](https://github.com/vsitzmann/pytorch_prototyping) into the ```utils/pytorch_prototyping``` folder with
46 | ```
47 | git submodule init
48 | git submodule update
49 | ```
50 | To run phase generation with Holonet/U-net, download the pretrained model weights from [here](https://drive.google.com/file/d/1Xr353I3ycRFBXLoIjYTzUbdWzurW0N_H/view?usp=sharing) and place the contents in the ```pretrained_networks/``` folder.
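To quickly confirm that the environment and the complex-tensor requirement mentioned above are satisfied, you can run a minimal sanity check. This snippet is hypothetical (not part of the repository); it simply mirrors what ```main.py``` and ```eval.py``` do internally with the repository's own ASM operator:

```
import math
import torch

import utils.utils as utils
from propagation_ASM import propagation_ASM

# random SLM phase with unit amplitude -> complex64 field (requires PyTorch >= 1.7)
slm_phase = (torch.rand(1, 1, 1080, 1920) - 0.5) * 2 * math.pi
real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
slm_field = torch.complex(real, imag)

# propagate 20 cm at 520 nm with a 6.4 um SLM pitch, as in the provided scripts
recon_field = propagation_ASM(slm_field, (6.4e-6, 6.4e-6), 520e-9, 0.2)
print(recon_field.abs().shape)  # expected: torch.Size([1, 1, 1080, 1920])
```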
51 |
52 | To run camera-in-the-loop optimization or training, download the [PyCapture2 SDK](https://www.flir.com/products/flycapture-sdk/) and the [HOLOEYE SDK](https://holoeye.com/spatial-light-modulators/slm-software/slm-display-sdk/) and place the SDKs in your environment folder. If your hardware setup (SLM, camera, laser) is different from ours, please modify the related files in the ```utils/``` folder according to your SDKs before running the camera-in-the-loop optimization or training.
53 | Our hardware specifications can be found in the [paper](http://www.computationalimaging.org/wp-content/uploads/2020/08/NeuralHolography_SIGAsia2020.pdf) (Appendix B).
54 |
55 |
56 | ## High-level structure
57 |
58 | The code is organized as follows:
59 |
60 | * ```main.py``` generates phase patterns via SGD/GS/DPAC/Holonet/U-net.
61 | * ```eval.py``` reconstructs and evaluates with optimized phase patterns.
62 | * ```main_eval.sh``` first executes ```main.py``` for the RGB channels and then executes ```eval.py```.
63 |
64 |
65 |
66 | * ```propagation_ASM.py``` contains the wave propagation operator (angular spectrum method).
67 | * ```propagation_model.py``` contains our parameterized wave propagation model.
68 | * ```holonet.py``` contains modules of the HoloNet/U-net implementations.
69 | * ```algorithms.py``` contains the GS/SGD/DPAC algorithm implementations.
70 |
71 |
72 |
73 | * ```train_holonet.py``` trains Holonet with ASM or the CITL-calibrated model.
74 | * ```train_model.py``` trains our wave propagation model with camera-in-the-loop training.
75 |
76 |
77 | ./utils/
78 | * ```utils.py``` contains utility functions.
79 | * ```modules.py``` contains PyTorch wrapper modules for easy use of ```algorithms.py``` and our hardware controllers.
80 | * ```pytorch_prototyping/``` submodule contains custom pytorch modules with sane default parameters. (adapted from [here](https://github.com/vsitzmann/pytorch_prototyping))
81 | * ```augmented_image_loader.py``` contains modules for loading a set of images.
82 |
83 |
84 |
85 | * ```utils_tensorboard.py``` contains utility functions used for visualization on tensorboard.
86 | * ```slm_display_module.py``` contains the SLM display controller module. ([HOLOEYE SDK](https://holoeye.com/spatial-light-modulators/slm-software/slm-display-sdk/))
87 | * ```detect_heds_module_path.py``` sets the SLM SDK path. Alternatively, you can copy the holoeye module directory into your project and import it with ```import holoeye```.
88 | * ```camera_capture_module.py``` contains the FLIR camera capture controller module. ([PyCapture2 SDK](https://www.flir.com/products/flycapture-sdk/))
89 | * ```calibration_module.py``` contains the homography calibration module.
90 |
91 | ## Running the test
92 |
93 | You can simply execute the following bash script with a method parameter (replace ```SGD``` with ```GS/DPAC/HOLONET/UNET``` as desired):
94 | ```
95 | chmod u+x main_eval.sh
96 | ./main_eval.sh SGD
97 | ```
98 |
99 |
100 | This bash script 1) executes the phase optimization with ```main.py``` for each R/G/B channel, and then 2) executes ```eval.py```, which simulates the holographic image reconstruction for the optimized patterns with the angular spectrum method.
101 | Check the ```./phases``` and ```./recon``` folders after the execution.
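If you prefer to drive the same pipeline from Python, a hypothetical wrapper (equivalent to what ```main_eval.sh``` does for one method) could look like:

```
import subprocess

# optimize phases per color channel, then evaluate the full-color reconstruction
for channel in range(3):
    subprocess.run(['python', 'main.py', f'--channel={channel}',
                    '--method=SGD', '--root_path=./phases'], check=True)
subprocess.run(['python', 'eval.py', '--channel=3',
                '--root_path=./phases/SGD_ASM'], check=True)
```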
102 |
103 | ### 1) Phase optimization
104 | The SLM phase patterns can be reproduced with
105 |
106 | SGD (Gradient Descent):
107 | ```
108 | python main.py --channel=0 --method=SGD --root_path=./phases
109 | ```
110 |
111 | SGD with camera-in-the-loop optimization:
112 | ```
113 | python main.py --channel=0 --method=SGD --citl=True --root_path=./phases
114 | ```
115 |
116 | SGD with CITL-calibrated models:
117 | ```
118 | python main.py --channel=0 --method=SGD --prop_model='MODEL' --prop_model_dir=YOUR_MODEL_PATH --root_path=./phases
119 | ```
120 | HoloNet:
121 | ```
122 | python main.py --channel=0 --method=HOLONET --root_path=./phases --generator_dir=./pretrained_networks
123 | ```
124 |
125 | GS (Gerchberg-Saxton):
126 | ```
127 | python main.py --channel=0 --method=GS --root_path=./phases
128 | ```
129 |
130 | DPAC (Double Phase Amplitude Coding):
131 | ```
132 | python main.py --channel=0 --method=DPAC --root_path=./phases
133 | ```
134 | U-net:
135 | ```
136 | python main.py --channel=0 --method=UNET --root_path=./phases --generator_dir=./pretrained_networks
137 | ```
138 | You can set ```--channel=1/2``` for the other (green/blue) channels.
139 |
140 | To monitor progress, the optimization code writes tensorboard summaries into a "summaries" subdirectory in the ```root_path```.
141 |
142 | ### 2) Simulation/Evaluation
143 | #### 2.1) Ideal model simulation:
144 |
145 | With optimized phase patterns, you can simulate the holographic image reconstruction with
146 |
147 | ```
148 | python eval.py --channel=0 --root_path=./phases/SGD_ASM --prop_model=ASM
149 | ```
150 |
151 | For full-color simulation, set ```--channel=3```; ```0/1/2``` correspond to `R/G/B`, respectively.
152 |
153 | By default, this simulation code writes the reconstructed images into the ```./recon``` folder.
154 |
155 | Feel free to test other images after putting them in the ```./data``` folder!
156 |
157 | #### 2.2) CITL-calibrated model simulation
158 |
159 | You can simulate those patterns with the CITL-calibrated model:
160 |
161 | ```
162 | python eval.py --channel=0 --root_path=./phases/SGD_ASM --prop_model=MODEL
163 | ```
164 |
165 | #### 2.3) Evaluation on the physical setup
166 |
167 | You can capture the image on the physical setup with
168 |
169 | ```
170 | python eval.py --channel=0 --root_path=./phases/SGD_ASM --prop_model=CAMERA
171 | ```
172 |
173 | ### 3) Training
174 | There are two types of training in our work:
175 | 1) Parameterized wave propagation model (camera-in-the-loop training)
176 | 2) Holonet.
177 |
178 | Note that the camera-in-the-loop is required for training the wave propagation model, while Holonet can be trained offline once the CITL-calibrated model is available.
179 | (Alternatively, you can pre-capture a set of phase-pattern/captured-image pairs and train the wave propagation models offline as well.)
180 |
181 | You can train our wave propagation models with
182 |
183 | ```
184 | python train_model.py --channel=0
185 | ```
186 |
187 | You can train Holonet with
188 |
189 | ```
190 | python train_holonet.py --perfect_prop_model=True --run_id=my_first_holonet --batch_size=4 --channel=0
191 | ```
192 |
193 | If you want to train it with CITL-calibrated models, set the ```perfect_prop_model``` option to ```False``` and set ```model_path``` to your calibrated models.
194 |
195 | ### 4) Hardware (Camera, SLM, laser) Automation and Calibration
196 | We incorporated the hardware SDKs into a PyTorch module so that we can easily capture experimental results and put them **IN-THE-LOOP**.
197 |
198 | You can call the module with
199 | ```
200 | camera_prop = PhysicalProp(channel, roi_res=YOUR_ROI,
201 |                            range_row=(220, 1000), range_col=(300, 1630),
202 |                            patterns_path=opt.calibration_path,  # path of 21 x 12 calibration patterns, see Supplement.
203 |                            show_preview=True)
204 | ```
205 | Here, you may want to coarsely crop around the calibration patterns by setting ```range_row/col``` manually, using the preview, so that the homography matrix can be computed without problems. (See Section S5 of the [Supplement](https://drive.google.com/file/d/1vay4xeg5iC7y8CLWR6nQEWe3mjBuqCWB/view).)
206 |
207 | Then, you can get camera-captured images by simply sending SLM phase patterns through the forward pass of the module:
208 |
209 | ```
210 | captured_amp = camera_prop(slm_phase)
211 | ```
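Inside the optimization loop, the captured amplitude is substituted for the simulated one while the gradients of the simulation are kept, exactly as in ```stochastic_gradient_descent()``` in ```algorithms.py```. Below is a minimal sketch of one camera-in-the-loop update step (it assumes ```recon_amp```, ```target_amp```, the scale ```s```, the ```loss```, and the ```optimizer``` are set up as in that function):

```
# one CITL-style update step (sketch mirroring algorithms.py)
captured_amp = camera_prop(slm_phase)                      # physical forward pass
out_amp = recon_amp + (captured_amp - recon_amp).detach()  # captured value, simulated gradient
loss_val = loss(s * out_amp, target_amp)
loss_val.backward()                                        # gradients flow back into slm_phase
optimizer.step()
```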
212 |
213 |
214 | ## Citation
215 | If you find our work useful in your research, please cite:
216 |
217 | ```
218 | @article{Peng:2020:NeuralHolography,
219 | author = {Y. Peng and S. Choi and N. Padmanaban and G. Wetzstein},
220 | title = {Neural Holography with Camera-in-the-loop Training},
221 | journal = {ACM Trans. Graph. (SIGGRAPH Asia)},
222 | volume = {39},
223 | number = {6},
224 | year = {2020},
225 | }
226 | ```
227 |
228 | ## License
229 | This project is licensed under the following license, with exception of the file "data/1.png", which is licensed under the [CC-BY](https://creativecommons.org/licenses/by/3.0/) license.
230 |
231 |
232 | Copyright (c) 2020, Stanford University
233 |
234 | All rights reserved.
235 |
236 | Redistribution and use in source and binary forms for academic and other non-commercial purposes with or without modification, are permitted provided that the following conditions are met:
237 |
238 | * Redistributions of source code, including modified source code, must retain the above copyright notice, this list of conditions and the following disclaimer.
239 |
240 | * Redistributions in binary form or a modified form of the source code must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
241 |
242 | * Neither the name of The Leland Stanford Junior University, any of its trademarks, the names of its employees, nor contributors to the source code may be used to endorse or promote products derived from this software without specific prior written permission.
243 |
244 | * Where a modified version of the source code is redistributed publicly in source or binary forms, the modified source code must be published in a freely accessible manner, or otherwise redistributed at no charge to anyone requesting a copy of the modified source code, subject to the same terms as this agreement.
245 |
246 | THIS SOFTWARE IS PROVIDED BY THE TRUSTEES OF THE LELAND STANFORD JUNIOR UNIVERSITY "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE LELAND STANFORD JUNIOR UNIVERSITY OR ITS TRUSTEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
247 |
248 | ## Contact
249 | If you have any questions, please contact:
250 |
251 | * Yifan (Evan) Peng, evanpeng@stanford.edu
252 | * Suyeon Choi, suyeon@stanford.edu
253 | * Gordon Wetzstein, gordon.wetzstein@stanford.edu
--------------------------------------------------------------------------------
/algorithms.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | This is the algorithm script used for the representative iterative CGH implementations, i.e., GS and SGD.
4 |
5 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
6 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
7 | # The material is provided as-is, with no warranties whatsoever.
8 | # If you publish any code, data, or scientific work based on this, please cite our work.
9 |
10 | Technical Paper:
11 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020.
12 | """
13 | import math  # used by double_phase_amp_phase() below
14 | import time
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 |
19 | import utils.utils as utils
20 | from propagation_ASM import *
21 |
22 |
23 | # 1. GS
24 | def gerchberg_saxton(init_phase, target_amp, num_iters, prop_dist, wavelength, feature_size=6.4e-6,
25 |                      phase_path=None, prop_model='ASM', propagator=None,
26 |                      writer=None, dtype=torch.float32, precomputed_H_f=None, precomputed_H_b=None):
27 |     """
28 |     Given the initial guess, run the Gerchberg-Saxton (GS) algorithm to calculate the optimal phase pattern of the spatial light modulator.
29 |
30 |     :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
31 |     :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
32 |     :param num_iters: the number of iterations to run the GS.
33 |     :param prop_dist: propagation distance in m.
34 |     :param wavelength: wavelength in m.
35 |     :param feature_size: the SLM pixel pitch, in meters, default 6.4e-6
36 |     :param phase_path: path to save the results.
37 |     :param prop_model: string indicating the light transport model, default 'ASM'. ex) 'ASM', 'fresnel', 'model'
38 |     :param propagator: predefined function or model instance for the propagation.
39 |     :param writer: tensorboard writer
40 |     :param dtype: torch datatype for computation at different precision, default torch.float32.
41 |     :param precomputed_H_f: A Pytorch complex64 tensor, pre-computed kernel for forward prop (SLM to image)
42 |     :param precomputed_H_b: A Pytorch complex64 tensor, pre-computed kernel for backward propagation (image to SLM)
43 |
44 |     Output
45 |     ------
46 |     :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
47 |     """
48 |
49 |     # initial guess; random phase
50 |     real, imag = utils.polar_to_rect(torch.ones_like(init_phase), init_phase)
51 |     slm_field = torch.complex(real, imag)
52 |
53 |     # run the GS algorithm
54 |     for k in range(num_iters):
55 |         # SLM plane to image plane
56 |         recon_field = utils.propagate_field(slm_field, propagator, prop_dist, wavelength, feature_size,
57 |                                             prop_model, dtype, precomputed_H_f)
58 |
59 |         # write to tensorboard / write phase image
60 |         # Note that writing to tensorboard takes a nonnegligible fraction of a second per call
61 |         if False:  # set to (k > 0 and k % 10 == 0) to periodically log GS summaries
62 |             utils.write_gs_summary(slm_field, recon_field, target_amp, k, writer, prefix='test')
63 |
64 |         # replace amplitude at the image plane
65 |         recon_field = utils.replace_amplitude(recon_field, target_amp)
66 |
67 |         # image plane to SLM plane
68 |         slm_field = utils.propagate_field(recon_field, propagator, -prop_dist, wavelength, feature_size,
69 |                                           prop_model, dtype, precomputed_H_b)
70 |
71 |         # amplitude constraint at the SLM plane
72 |         slm_field = utils.replace_amplitude(slm_field, torch.ones_like(target_amp))
73 |
74 |     # return phases
75 |     return slm_field.angle()
76 |
77 |
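# Example (hypothetical) of calling gerchberg_saxton() directly; in practice, main.py
# uses the GS wrapper in utils/modules.py, which performs this setup:
#
#   import torch
#   from propagation_ASM import propagation_ASM
#   target_amp = ...  # a (1, 1, 1080, 1920) amplitude tensor on the GPU, e.g. from ImageLoader
#   init_phase = (-0.5 + 1.0 * torch.rand(1, 1, 1080, 1920)).cuda()
#   final_phase = gerchberg_saxton(init_phase, target_amp, num_iters=500,
#                                  prop_dist=0.2, wavelength=520e-9,
#                                  propagator=propagation_ASM)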
78 | # 2. SGD
79 | def stochastic_gradient_descent(init_phase, target_amp, num_iters, prop_dist, wavelength, feature_size,
80 |                                 roi_res=None, phase_path=None, prop_model='ASM', propagator=None,
81 |                                 loss=nn.MSELoss(), lr=0.01, lr_s=0.003, s0=1.0, citl=False, camera_prop=None,
82 |                                 writer=None, dtype=torch.float32, precomputed_H=None):
83 |
84 |     """
85 |     Given the initial guess, run the SGD algorithm to calculate the optimal phase pattern of the spatial light modulator.
86 |
87 |     Input
88 |     ------
89 |     :param init_phase: a tensor, in the shape of (1,1,H,W), initial guess for the phase.
90 |     :param target_amp: a tensor, in the shape of (1,1,H,W), the amplitude of the target image.
91 |     :param num_iters: the number of iterations to run the SGD.
92 |     :param prop_dist: propagation distance in m.
93 |     :param wavelength: wavelength in m.
94 |     :param feature_size: the SLM pixel pitch, in meters, default 6.4e-6
95 |     :param roi_res: a tuple of integer, region of interest, like (880, 1600)
96 |     :param phase_path: a string, for saving intermediate phases
97 |     :param prop_model: a string, that indicates the propagation model. ('ASM' or 'MODEL')
98 |     :param propagator: predefined function or model instance for the propagation.
99 |     :param loss: loss function, default L2
100 |     :param lr: learning rate for optimization variables
101 |     :param lr_s: learning rate for learnable scale
102 |     :param s0: initial scale
    :param citl: a bool; if True, use camera captures from camera_prop in the loop
    :param camera_prop: a PhysicalProp module returning captured amplitudes for a given SLM phase (used when citl is True)
103 |     :param writer: Tensorboard writer instance
104 |     :param dtype: default torch.float32
105 |     :param precomputed_H: A Pytorch complex64 tensor, pre-computed kernel shape of (1,1,2H,2W) for fast computation.
106 |
107 |     Output
108 |     ------
109 |     :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W)
110 |     """
111 |
112 |     device = init_phase.device
113 |     s = torch.tensor(s0, requires_grad=True, device=device)
114 |
115 |     # phase at the slm plane
116 |     slm_phase = init_phase.requires_grad_(True)
117 |
118 |     # optimization variables and adam optimizer
119 |     optvars = [{'params': slm_phase}]
120 |     if lr_s > 0:
121 |         optvars += [{'params': s, 'lr': lr_s}]
122 |     optimizer = optim.Adam(optvars, lr=lr)
123 |
124 |     # crop target roi
125 |     target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False)
126 |
127 |     # run the iterative algorithm
128 |     for k in range(num_iters):
129 |         print(k)  # iteration counter
130 |         optimizer.zero_grad()
131 |         # forward propagation from the SLM plane to the target plane
132 |         real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
133 |         slm_field = torch.complex(real, imag)
134 |
135 |         recon_field = utils.propagate_field(slm_field, propagator, prop_dist, wavelength, feature_size,
136 |                                             prop_model, dtype, precomputed_H)
137 |
138 |         # get amplitude
139 |         recon_amp = recon_field.abs()
140 |
141 |         # crop roi
142 |         recon_amp = utils.crop_image(recon_amp, target_shape=roi_res, stacked_complex=False)
143 |
144 |         # camera-in-the-loop technique
145 |         if citl:
146 |             captured_amp = camera_prop(slm_phase)
147 |
148 |             # use the captured amplitude as a proxy while keeping the simulated gradient
149 |             # captured_amp is assumed to already match the size of recon_amp
150 |             out_amp = recon_amp + (captured_amp - recon_amp).detach()
151 |         else:
152 |             out_amp = recon_amp
153 |
154 |         # calculate loss and backprop
155 |         lossValue = loss(s * out_amp, target_amp)
156 |         lossValue.backward()
157 |         optimizer.step()
158 |
159 |         # write to tensorboard / write phase image
160 |         # Note that writing to tensorboard takes a nonnegligible fraction of a second per call
161 |         with torch.no_grad():
162 |             if k % 50 == 0:
163 |                 utils.write_sgd_summary(slm_phase, out_amp, target_amp, k,
164 |                                         writer=writer, path=phase_path, s=s, prefix='test')
165 |
166 |     return slm_phase
167 |
168 |
169 | # 3. DPAC
170 | def double_phase_amplitude_coding(target_phase, target_amp, prop_dist, wavelength, feature_size,
171 |                                   prop_model='ASM', propagator=None,
172 |                                   dtype=torch.float32, precomputed_H=None):
173 |     """
174 |     Uses a single propagation pass and converts the resulting amplitude and phase to double phase coding.
175 |
176 |     Input
177 |     -----
178 |     :param target_phase: The phase at the target image plane
179 |     :param target_amp: A tensor, (B,C,H,W), the amplitude at the target image plane.
180 |     :param prop_dist: propagation distance, in m.
181 |     :param wavelength: wavelength, in m.
182 |     :param feature_size: The SLM pixel pitch, in meters.
183 |     :param prop_model: The light propagation model to use for prop from target plane to slm plane
184 |     :param propagator: propagation_ASM
185 |     :param dtype: torch datatype for computation at different precision.
186 | :param precomputed_H: pre-computed kernel - to make it faster over multiple iteration/images - calculate it once 187 | 188 | Output 189 | ------ 190 | :return: a tensor, the optimized phase pattern at the SLM plane, in the shape of (1,1,H,W) 191 | """ 192 | real, imag = utils.polar_to_rect(target_amp, target_phase) 193 | target_field = torch.complex(real, imag) 194 | 195 | slm_field = utils.propagate_field(target_field, propagator, prop_dist, wavelength, feature_size, 196 | prop_model, dtype, precomputed_H) 197 | 198 | slm_phase = double_phase(slm_field, three_pi=False, mean_adjust=True) 199 | 200 | return slm_phase 201 | 202 | 203 | def double_phase(field, three_pi=True, mean_adjust=True): 204 | """Converts a complex field to double phase coding 205 | 206 | field: A complex64 tensor with dims [..., height, width] 207 | three_pi, mean_adjust: see double_phase_amp_phase 208 | """ 209 | return double_phase_amp_phase(field.abs(), field.angle(), three_pi, mean_adjust) 210 | 211 | 212 | def double_phase_amp_phase(amplitudes, phases, three_pi=True, mean_adjust=True): 213 | """converts amplitude and phase to double phase coding 214 | 215 | amplitudes: per-pixel amplitudes of the complex field 216 | phases: per-pixel phases of the complex field 217 | three_pi: if True, outputs values in a 3pi range, instead of 2pi 218 | mean_adjust: if True, centers the phases in the range of interest 219 | """ 220 | # normalize 221 | amplitudes = amplitudes / amplitudes.max() 222 | 223 | phases_a = phases - torch.acos(amplitudes) 224 | phases_b = phases + torch.acos(amplitudes) 225 | 226 | phases_out = phases_a 227 | phases_out[..., ::2, 1::2] = phases_b[..., ::2, 1::2] 228 | phases_out[..., 1::2, ::2] = phases_b[..., 1::2, ::2] 229 | 230 | if three_pi: 231 | max_phase = 3 * math.pi 232 | else: 233 | max_phase = 2 * math.pi 234 | 235 | if mean_adjust: 236 | phases_out -= phases_out.mean() 237 | 238 | return (phases_out + max_phase / 2) % max_phase - max_phase / 2 239 | -------------------------------------------------------------------------------- /data/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/neural-holography/d2e399014aa80844edffd98bca34d2df80a69c84/data/1.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: neural-holography 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - aotools 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - aotools=1.0.4=py37h39e3cac_0 10 | - blas=1.0=mkl 11 | - ca-certificates=2020.10.14=0 12 | - certifi=2020.11.8=py37h06a4308_0 13 | - cffi=1.12.3=py37h2e261b9_0 14 | - cloudpickle=1.2.2=py_0 15 | - cudatoolkit=11.0.221=h6bb024c_0 16 | - cycler=0.10.0=py37_0 17 | - cytoolz=0.10.1=py37h7b6447c_0 18 | - dask-core=2.9.0=py_0 19 | - dbus=1.13.12=h746ee38_0 20 | - decorator=4.4.1=py_0 21 | - expat=2.2.6=he6710b0_0 22 | - fontconfig=2.13.0=h9420a91_0 23 | - freetype=2.9.1=h8a8886c_1 24 | - glib=2.63.1=h5a9c865_0 25 | - gst-plugins-base=1.14.0=hbbd80ab_1 26 | - gstreamer=1.14.0=hb453b48_1 27 | - icu=58.2=h9c2bf20_1 28 | - imageio=2.6.1=py37_0 29 | - intel-openmp=2019.4=243 30 | - jpeg=9b=h024ee3a_2 31 | - kiwisolver=1.1.0=py37he6710b0_0 32 | - libedit=3.1.20181209=hc058e9b_0 33 | - libffi=3.2.1=hd88cf55_4 34 | - libgcc-ng=9.1.0=hdf63c60_0 35 | - libgfortran-ng=7.3.0=hdf63c60_0 36 | - libllvm10=10.0.1=hbcb73fb_5 37 | - 
libpng=1.6.37=hbc83047_0 38 | - libprotobuf=3.11.1=h8b12597_0 39 | - libstdcxx-ng=9.1.0=hdf63c60_0 40 | - libtiff=4.0.10=h2733197_2 41 | - libuuid=1.0.3=h1bed415_2 42 | - libuv=1.40.0=h7b6447c_0 43 | - libxcb=1.13=h1bed415_1 44 | - libxml2=2.9.9=hea5a465_1 45 | - llvmlite=0.34.0=py37h269e1b5_4 46 | - matplotlib=3.1.1=py37h5429711_0 47 | - mkl=2019.4=243 48 | - mkl_fft=1.0.12=py37ha843d7b_0 49 | - mkl_random=1.0.2=py37hd81dba3_0 50 | - ncurses=6.1=he6710b0_1 51 | - networkx=2.4=py_0 52 | - ninja=1.9.0=py37hfd86e86_0 53 | - numba=0.51.2=py37h04863e7_1 54 | - numpy=1.16.4=py37h7e9f1db_0 55 | - numpy-base=1.16.4=py37hde5b4d6_0 56 | - olefile=0.46=py37_0 57 | - openssl=1.1.1h=h7b6447c_0 58 | - pcre=8.43=he6710b0_0 59 | - pillow=6.0.0=py37h34e0f95_0 60 | - pip=19.1.1=py37_0 61 | - protobuf=3.11.1=py37he1b5a44_0 62 | - pycparser=2.19=py37_0 63 | - pyparsing=2.4.5=py_0 64 | - pyqt=5.9.2=py37h05f1152_2 65 | - python=3.7.3=h0371630_0 66 | - python-dateutil=2.8.1=py_0 67 | - pytorch=1.7.0=py3.7_cuda11.0.221_cudnn8.0.3_0 68 | - pytz=2019.3=py_0 69 | - pywavelets=1.1.1=py37h7b6447c_0 70 | - qt=5.9.7=h5867ecd_1 71 | - readline=7.0=h7b6447c_5 72 | - scikit-image=0.16.2=py37h0573a6f_0 73 | - scipy=1.2.1=py37h7c811a0_0 74 | - setuptools=41.0.1=py37_0 75 | - sip=4.19.8=py37hf484d3e_0 76 | - six=1.12.0=py37_0 77 | - sqlite=3.28.0=h7b6447c_0 78 | - tbb=2020.3=hfd86e86_0 79 | - tensorboardx=1.9=py_0 80 | - tk=8.6.8=hbc83047_0 81 | - toolz=0.10.0=py_0 82 | - torchaudio=0.7.0=py37 83 | - torchvision=0.8.1=py37_cu110 84 | - tornado=6.0.3=py37h7b6447c_0 85 | - typing_extensions=3.7.4.3=py_0 86 | - wheel=0.33.4=py37_0 87 | - xz=5.2.4=h14c3975_4 88 | - zlib=1.2.11=h7b6447c_3 89 | - zstd=1.3.7=h0b5b093_0 90 | - pip: 91 | - absl-py==0.9.0 92 | - cachetools==4.1.0 93 | - chardet==3.0.4 94 | - configargparse==1.0 95 | - densenet==0.1.1 96 | - google-auth==1.14.3 97 | - google-auth-oauthlib==0.4.1 98 | - grpcio==1.28.1 99 | - h5py==2.10.0 100 | - idna==2.9 101 | - importlib-metadata==1.6.0 102 | - markdown==3.2.2 103 | - oauthlib==3.1.0 104 | - opencv-python==4.4.0.42 105 | - pyasn1==0.4.8 106 | - pyasn1-modules==0.2.8 107 | - pyfirmata==1.1.0 108 | - pyserial==3.5 109 | - requests==2.23.0 110 | - requests-oauthlib==1.3.0 111 | - rsa==4.0 112 | - tensorboard==2.2.1 113 | - tensorboard-plugin-wit==1.6.0.post3 114 | - urllib3==1.25.9 115 | - werkzeug==1.0.1 116 | - zipp==3.1.0 117 | prefix: /home/suyeon/anaconda3/envs/neural-holography 118 | -------------------------------------------------------------------------------- /environment_windows.yml: -------------------------------------------------------------------------------- 1 | name: neural-holography 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - absl-py=0.11.0=py36ha15d459_0 9 | - aiohttp=3.7.3=py36h68aa20f_0 10 | - async-timeout=3.0.1=py_1000 11 | - attrs=20.3.0=pyhd3deb0d_0 12 | - blas=1.0=mkl 13 | - blinker=1.4=py_1 14 | - brotlipy=0.7.0=py36hc753bc4_1001 15 | - ca-certificates=2020.10.14=0 16 | - cachetools=4.1.1=py_0 17 | - certifi=2020.11.8=py36haa95532_0 18 | - cffi=1.14.4=py36he58ceb7_1 19 | - chardet=3.0.4=py36hd36e781_1008 20 | - click=7.1.2=pyh9f0ad1d_0 21 | - cloudpickle=1.6.0=py_0 22 | - cryptography=3.2.1=py36he58ceb7_0 23 | - cudatoolkit=11.0.221=h74a9793_0 24 | - cycler=0.10.0=py36h009560c_0 25 | - cytoolz=0.11.0=py36he774522_0 26 | - dask-core=1.1.4=py36_1 27 | - dataclasses=0.7=py36_0 28 | - decorator=4.4.2=py_0 29 | - freeglut=3.0.0=h6538335_1005 30 | - freetype=2.10.4=h546665d_0 31 | - 
google-auth=1.23.0=pyhd8ed1ab_0 32 | - google-auth-oauthlib=0.4.1=py_2 33 | - grpcio=1.33.2=py36h4374274_2 34 | - icc_rt=2019.0.0=h0cc432a_1 35 | - icu=67.1=h33f27b4_0 36 | - idna=2.10=pyh9f0ad1d_0 37 | - idna_ssl=1.1.0=py36h9f0ad1d_1001 38 | - imageio=2.9.0=py_0 39 | - importlib-metadata=3.1.0=pyhd8ed1ab_0 40 | - intel-openmp=2020.3=h57928b3_311 41 | - jasper=2.0.14=hdc05fd1_1 42 | - jpeg=9d=h8ffe710_0 43 | - kiwisolver=1.2.0=py36h74a9793_0 44 | - libblas=3.8.0=21_mkl 45 | - libcblas=3.8.0=21_mkl 46 | - libclang=10.0.1=default_hf44288c_1 47 | - liblapack=3.8.0=21_mkl 48 | - liblapacke=3.8.0=21_mkl 49 | - libopencv=4.5.0=py36_3 50 | - libpng=1.6.37=h1d00b33_2 51 | - libprotobuf=3.13.0.1=h200bbdf_0 52 | - libtiff=4.1.0=hc10be44_6 53 | - libuv=1.40.0=he774522_0 54 | - libwebp-base=1.1.0=h8ffe710_3 55 | - lz4-c=1.9.2=h62dcd97_2 56 | - markdown=3.3.3=pyh9f0ad1d_0 57 | - matplotlib-base=3.3.1=py36hba9282a_0 58 | - mkl=2020.4=hb70f87d_311 59 | - mkl-service=2.3.0=py36h2bbff1b_0 60 | - multidict=4.7.5=py36h779f372_2 61 | - networkx=2.5=py_0 62 | - ninja=1.10.1=py36h7ef1ec2_0 63 | - numpy=1.19.4=py36hd1b969e_1 64 | - oauthlib=3.0.1=py_0 65 | - olefile=0.46=pyh9f0ad1d_1 66 | - opencv=4.5.0=py36_3 67 | - openssl=1.1.1h=he774522_0 68 | - pillow=8.0.1=py36ha0524ae_0 69 | - pip=20.2.4=py36haa95532_0 70 | - protobuf=3.13.0.1=py36he2d232f_1 71 | - py-opencv=4.5.0=py36h7b2dad6_3 72 | - pyasn1=0.4.8=py_0 73 | - pyasn1-modules=0.2.7=py_0 74 | - pycparser=2.20=pyh9f0ad1d_2 75 | - pyjwt=1.7.1=py_0 76 | - pyopenssl=20.0.0=pyhd8ed1ab_0 77 | - pyparsing=2.4.7=py_0 78 | - pysocks=1.7.1=py36hd36e781_2 79 | - python=3.6.12=h5500b2f_2 80 | - python-dateutil=2.8.1=py_0 81 | - python_abi=3.6=1_cp36m 82 | - pytorch=1.7.0=py3.6_cuda110_cudnn8_0 83 | - pywavelets=1.1.1=py36he774522_2 84 | - qt=5.12.9=hb2cf2c5_0 85 | - requests=2.25.0=pyhd3deb0d_0 86 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 87 | - rsa=4.6=pyh9f0ad1d_0 88 | - scikit-image=0.17.2=py36h1e1f486_0 89 | - scipy=1.5.2=py36h9439919_0 90 | - setuptools=50.3.1=py36haa95532_1 91 | - six=1.15.0=py_0 92 | - sqlite=3.33.0=h2a8f88b_0 93 | - tensorboard=2.4.0=pyhd8ed1ab_0 94 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0 95 | - tifffile=2020.10.1=py36h8c2d366_2 96 | - tk=8.6.10=he774522_1 97 | - toolz=0.11.1=py_0 98 | - torchaudio=0.7.0=py36 99 | - torchvision=0.8.1=py36_cu110 100 | - tornado=6.0.4=py36he774522_1 101 | - typing-extensions=3.7.4.3=0 102 | - typing_extensions=3.7.4.3=py_0 103 | - urllib3=1.25.11=py_0 104 | - vc=14.1=h0510ff6_4 105 | - vs2015_runtime=14.16.27012=hf0eaf9b_3 106 | - werkzeug=1.0.1=pyh9f0ad1d_0 107 | - wheel=0.35.1=pyhd3eb1b0_0 108 | - win_inet_pton=1.1.0=py36h9f0ad1d_1 109 | - wincertstore=0.2=py36h7fe50ca_0 110 | - xz=5.2.5=h62dcd97_1 111 | - yarl=1.6.3=py36h68aa20f_0 112 | - zipp=3.4.0=py_0 113 | - zlib=1.2.11=h62dcd97_4 114 | - zstd=1.4.5=h1f3a1b7_2 115 | - pip: 116 | - configargparse==1.2.3 117 | - future==0.18.2 118 | - pycapture2==2.13.61 119 | - pyfirmata==1.1.0 120 | - pyserial==3.5 121 | - tensorboardx==2.1 122 | - torchgeometry==0.1.2 123 | - wxpython==4.1.1 124 | prefix: C:\Users\Suyeon\anaconda3\envs\neural-holography 125 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script that is used for evaluating phases for physical or simulation forward model 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) 
In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020. 11 | 12 | ----- 13 | 14 | $ python eval.py --channel=[0 or 1 or 2 or 3] --root_path=[some path] 15 | 16 | """ 17 | 18 | import imageio 19 | import os 20 | import skimage.io 21 | import scipy.io as sio 22 | import sys 23 | import torch 24 | import numpy as np 25 | import configargparse 26 | 27 | from propagation_ASM import propagation_ASM 28 | from utils.augmented_image_loader import ImageLoader 29 | import utils.utils as utils 30 | from utils.modules import PhysicalProp 31 | from propagation_model import ModelPropagate 32 | 33 | # Command line argument processing 34 | p = configargparse.ArgumentParser() 35 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 36 | 37 | p.add_argument('--channel', type=int, default=1, help='red:0, green:1, blue:2, rgb:3') 38 | p.add_argument('--prop_model', type=str, default='ASM', 39 | help='Type of propagation model for reconstruction: ASM / MODEL / CAMERA') 40 | p.add_argument('--root_path', type=str, default='./phases', help='Directory where test phases are being stored.') 41 | p.add_argument('--prop_model_dir', type=str, default='./calibrated_models/', 42 | help='Directory for the CITL-calibrated wave propagation models') 43 | p.add_argument('--calibration_path', type=str, default=f'./calibration', 44 | help='Directory where calibration phases are being stored.') 45 | 46 | # Parse 47 | opt = p.parse_args() 48 | channel = opt.channel 49 | chs = range(channel) if channel == 3 else [channel] # retrieve all channels if channel is 3 50 | run_id = f'{opt.root_path.split("/")[-1]}_{opt.prop_model}' # {algorithm}_{prop_model} 51 | 52 | # Hyperparameters setting 53 | cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9 54 | chan_strs = ('red', 'green', 'blue', 'rgb') 55 | prop_dists = (20*cm, 20*cm, 20*cm) 56 | wavelengths = (638*nm, 520*nm, 450*nm) # wavelength of each color 57 | feature_size = (6.4*um, 6.4*um) # SLM pitch 58 | 59 | # Resolutions 60 | slm_res = (1080, 1920) # resolution of SLM 61 | if 'HOLONET' in run_id.upper(): 62 | slm_res = (1072, 1920) 63 | elif 'UNET' in run_id.upper(): 64 | slm_res = (1024, 2048) 65 | 66 | image_res = (1080, 1920) 67 | roi_res = (880, 1600) # regions of interest (to penalize) 68 | dtype = torch.float32 # default datatype (Note: the result may be slightly different if you use float64, etc.) 69 | device = torch.device('cuda') # The gpu you are using 70 | 71 | # You can pre-compute kernels for fast-computation 72 | precomputed_H = [None] * 3 73 | if opt.prop_model == 'ASM': 74 | propagator = propagation_ASM 75 | for c in chs: 76 | precomputed_H[c] = propagator(torch.empty(1, 1, *slm_res, 2), feature_size, 77 | wavelengths[c], prop_dists[c], return_H=True).to(device) 78 | 79 | elif opt.prop_model.upper() == 'CAMERA': 80 | propagator = PhysicalProp(channel, laser_arduino=True, roi_res=(roi_res[1], roi_res[0]), slm_settle_time=0.15, 81 | range_row=(220, 1000), range_col=(300, 1630), 82 | patterns_path=opt.calibration_path, # path of 21 x 12 calibration patterns, see Supplement. 
83 |                            show_preview=True)
84 | elif opt.prop_model.upper() == 'MODEL':
85 |     blur = utils.make_kernel_gaussian(0.85, 3)
86 |     propagators = {}
87 |     for c in chs:
88 |         propagator = ModelPropagate(distance=prop_dists[c],
89 |                                     feature_size=feature_size,
90 |                                     wavelength=wavelengths[c],
91 |                                     blur=blur).to(device)
92 |
93 |         propagator.load_state_dict(torch.load(os.path.join(opt.prop_model_dir, f'{chan_strs[c]}.pth'), map_location=device))
94 |         propagator.eval()
95 |         propagators[c] = propagator
96 |
97 | print(f' - reconstruction with {opt.prop_model}... ')
98 |
99 | # Data path
100 | data_path = './data'
101 | recon_path = './recon'
102 |
103 | # Augmented image loader (if you want to shuffle or augment the dataset, set the options accordingly.)
 104 | image_loader = ImageLoader(data_path, channel=channel if channel < 3 else None, 105 | image_res=image_res, homography_res=roi_res, 106 | crop_to_homography=True, 107 | shuffle=False, vertical_flips=False, horizontal_flips=False) 108 | 109 | # Placeholders for metrics 110 | psnrs = {'amp': [], 'lin': [], 'srgb': []} 111 | ssims = {'amp': [], 'lin': [], 'srgb': []} 112 | idxs = [] 113 | 114 | # Loop over the dataset 115 | for k, target in enumerate(image_loader): 116 | # get target image 117 | target_amp, target_res, target_filename = target 118 | target_path, target_filename = os.path.split(target_filename[0]) 119 | target_idx = target_filename.split('_')[-1] 120 | target_amp = target_amp.to(device) 121 | 122 | print(f' - running for img_{target_idx}...') 123 | 124 | # crop to ROI 125 | target_amp = utils.crop_image(target_amp, target_shape=roi_res, stacked_complex=False).to(device) 126 | 127 | recon_amp = [] 128 | 129 | # for each channel, propagate wave from the SLM plane to the image plane and get the reconstructed image. 130 | for c in chs: 131 | # load and invert phase (our SLM setup) 132 | phase_filename = os.path.join(opt.root_path, chan_strs[c], f'{target_idx}.png') 133 | slm_phase = skimage.io.imread(phase_filename) / 255. 134 | slm_phase = torch.tensor((1 - slm_phase) * 2 * np.pi - np.pi, dtype=dtype).reshape(1, 1, *slm_res).to(device) 135 | 136 | # propagate field 137 | real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase) 138 | slm_field = torch.complex(real, imag) 139 | 140 | if opt.prop_model.upper() == 'MODEL': 141 | propagator = propagators[c] # Select CITL-calibrated models for each channel 142 | recon_field = utils.propagate_field(slm_field, propagator, prop_dists[c], wavelengths[c], feature_size, 143 | opt.prop_model, dtype) 144 | 145 | # cartesian to polar coordinate 146 | recon_amp_c = recon_field.abs() 147 | 148 | # crop to ROI 149 | recon_amp_c = utils.crop_image(recon_amp_c, target_shape=roi_res, stacked_complex=False) 150 | 151 | # append to list 152 | recon_amp.append(recon_amp_c) 153 | 154 | # list to tensor, scaling 155 | recon_amp = torch.cat(recon_amp, dim=1) 156 | recon_amp *= (torch.sum(recon_amp * target_amp, (-2, -1), keepdim=True) 157 | / torch.sum(recon_amp * recon_amp, (-2, -1), keepdim=True)) 158 | 159 | # tensor to numpy 160 | recon_amp = recon_amp.squeeze().cpu().detach().numpy() 161 | target_amp = target_amp.squeeze().cpu().detach().numpy() 162 | 163 | if channel == 3: 164 | recon_amp = recon_amp.transpose(1, 2, 0) 165 | target_amp = target_amp.transpose(1, 2, 0) 166 | 167 | # calculate metrics 168 | psnr_val, ssim_val = utils.get_psnr_ssim(recon_amp, target_amp, multichannel=(channel == 3)) 169 | 170 | idxs.append(target_idx) 171 | 172 | for domain in ['amp', 'lin', 'srgb']: 173 | psnrs[domain].append(psnr_val[domain]) 174 | ssims[domain].append(ssim_val[domain]) 175 | print(f'PSNR({domain}): {psnr_val[domain]}, SSIM({domain}): {ssim_val[domain]:.4f}, ') 176 | 177 | # save reconstructed image in srgb domain 178 | recon_srgb = utils.srgb_lin2gamma(np.clip(recon_amp**2, 0.0, 1.0)) 179 | utils.cond_mkdir(recon_path) 180 | imageio.imwrite(os.path.join(recon_path, f'{target_idx}_{run_id}_{chan_strs[channel]}.png'), (recon_srgb * np.iinfo(np.uint8).max).round().astype(np.uint8)) 181 | 182 | # save it as a .mat file 183 | data_dict = {} 184 | data_dict['img_idx'] = idxs 185 | for domain in ['amp', 'lin', 'srgb']: 186 | data_dict[f'ssims_{domain}'] = ssims[domain] 187 | data_dict[f'psnrs_{domain}'] = psnrs[domain] 188 | 189 | 
sio.savemat(os.path.join(recon_path, f'metrics_{run_id}_{chan_strs[channel]}.mat'), data_dict)
190 |
--------------------------------------------------------------------------------
/holonet.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the script that is used for the implementation of HoloNet. The class HoloNet(nn.Module) is
3 | described below.
4 |
5 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
6 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
7 | # The material is provided as-is, with no warranties whatsoever.
8 | # If you publish any code, data, or scientific work based on this, please cite our work.
9 |
10 | Technical Paper:
11 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020.
12 | """
13 |
14 | import math
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 |
19 | import utils.utils as utils
20 | from propagation_ASM import propagation_ASM, compute_zernike_basis, combine_zernike_basis
21 | from utils.pytorch_prototyping.pytorch_prototyping import Conv2dSame, Unet
22 | from algorithms import double_phase_amp_phase as double_phase  # (amp, phase) double phase encoding, used in forward()
23 |
24 | class HoloNet(nn.Module):
25 |     """Generates phase for the final non-iterative model
26 |
27 |     Class initialization parameters
28 |     -------------------------------
29 |     distance: propagation dist between SLM and target, in meters, default 0.1.
30 |         Note: distance is negated internally, so the PhaseGenerator and
31 |         ProcessAndPropagate get the same input
32 |     wavelength: the wavelength of interest, in meters, default 520e-9
33 |     feature_size: the SLM pixel pitch, in meters, default 6.4e-6
34 |     zernike_coeffs: a torch tensor that corresponds to process_phase.py,
35 |         ProcessAndPropagate.coeffs, after training is completed. Default None,
36 |         which disables passing zernike coeffs to the final network
37 |     source_amplitude: a process_phase.SourceAmplitude module, after training.
38 |         Default None, which disables passing source amp to the final network
39 |     target_field: a torch tensor that corresponds to propagation_model.py,
40 |         citl_calibrated_model.target_field, after training is completed. Default None,
41 |         which disables passing target_field to the final network
42 |     latent_codes: a citl_calibrated_model.latent_codes parameter, after training.
43 |         Default None, which disables passing latent_codes to the final network
44 |     initial_phase: a module that returns an initial phase given the target amp.
45 |         Default None, which assumes all zeros initial phase
46 |     final_phase_only: a module that processes the post-propagation amp+phase to
47 |         a phase-only output that works as well as iterative results. Default
48 |         None, which switches to double phase coding
49 |     proptype: chooses the propagation operator; currently only 'ASM'
50 |         (propagation_ASM) is supported. Default ASM.
51 |     linear_conv: if True, pads for linear conv for propagation. Default True
52 |
53 |     Usage
54 |     -----
55 |     Functions as a pytorch module:
56 |
57 |     >>> phase_generator = HoloNet(...)
58 |     >>> slm_amp, slm_phase = phase_generator(target_amp)
59 |
60 |     target_amp: amplitude at the target plane, with dimensions [batch, 1,
61 |         height, width]
62 |     slm_amp: amplitude to be encoded in the phase pattern at the SLM plane. Used
63 |         to enforce uniformity, if desired. Same as target dimensions
64 |     slm_phase: encoded phase-only representation at SLM plane, same dimensions
65 |     """
66 |     def __init__(self, distance=0.1, wavelength=520e-9, feature_size=6.4e-6,
67 |                  zernike_coeffs=None, source_amplitude=None, target_field=None, latent_codes=None,
68 |                  initial_phase=None, final_phase_only=None, proptype='ASM', linear_conv=True,
69 |                  manual_aberr_corr=False):
70 |         super(HoloNet, self).__init__()
71 |
72 |         # submodules
73 |         self.source_amplitude = source_amplitude
74 |         self.initial_phase = initial_phase
75 |         self.final_phase_only = final_phase_only
76 |         if target_field is not None:
77 |             self.target_field = target_field.detach()
78 |         else:
79 |             self.target_field = None
80 |
81 |         if latent_codes is not None:
82 |             self.latent_codes = latent_codes.detach()
83 |         else:
84 |             self.latent_codes = None
85 |
86 |         # propagation parameters
87 |         self.wavelength = wavelength
88 |         self.feature_size = (feature_size
89 |                              if hasattr(feature_size, '__len__')
90 |                              else [feature_size] * 2)
91 |         self.distance = -distance
92 |
93 |         self.zernike_coeffs = (None if zernike_coeffs is None
94 |                                else -zernike_coeffs.clone().detach())
95 |
96 |         # objects to precompute
97 |         self.zernike = None
98 |         self.precomped_H = None
99 |         self.precomped_H_zernike = None
100 |         self.source_amp = None
101 |
102 |         # whether to pass zernike/source amp as layers or divide out manually
103 |         self.manual_aberr_corr = manual_aberr_corr
104 |
105 |         # make sure parameters from the model training phase don't update
106 |         if self.zernike_coeffs is not None:
107 |             self.zernike_coeffs.requires_grad = False
108 |         if self.source_amplitude is not None:
109 |             for p in self.source_amplitude.parameters():
110 |                 p.requires_grad = False
111 |
112 |         # change out the propagation operator
113 |         if proptype == 'ASM':
114 |             self.prop = propagation_ASM
115 |         else:
116 |             raise ValueError(f'Unsupported prop type {proptype}')
117 |
118 |         self.linear_conv = linear_conv
119 |
120 |         # set a device for initializing the precomputed objects
121 |         try:
122 |             self.dev = next(self.parameters()).device
123 |         except StopIteration:  # no parameters
124 |             self.dev = torch.device('cpu')
125 |
126 |     def forward(self, target_amp):
127 |         # compute some initial phase, convert to real+imag representation
128 |         if self.initial_phase is not None:
129 |             init_phase = self.initial_phase(target_amp)
130 |             real, imag = utils.polar_to_rect(target_amp, init_phase)
131 |             target_complex = torch.complex(real, imag)
132 |         else:
133 |             init_phase = torch.zeros_like(target_amp)
134 |             # no need to convert, zero phase implies amplitude = real part
135 |             target_complex = torch.complex(target_amp, init_phase)
136 |
137 |         # subtract the additional target field
138 |         if self.target_field is not None:
139 |             target_complex_diff = target_complex - self.target_field
140 |         else:
141 |             target_complex_diff = target_complex
142 |
143 |         # precompute the propagation kernel only once
144 |         if self.precomped_H is None:
145 |             self.precomped_H = self.prop(target_complex_diff,
146 |                                          self.feature_size,
147 |                                          self.wavelength,
148 |                                          self.distance,
149 |                                          return_H=True,
150 |                                          linear_conv=self.linear_conv)
151 |             self.precomped_H = self.precomped_H.to(self.dev).detach()
152 |             self.precomped_H.requires_grad = False
153 |
154 |         if self.precomped_H_zernike is None:
155 |             if self.zernike is None and self.zernike_coeffs is not None:
156 |                 self.zernike_basis = compute_zernike_basis(self.zernike_coeffs.size()[0],
157 |                                                            [i * 2 for i in target_amp.size()[-2:]], wo_piston=True)
158 |                 self.zernike_basis = 
self.zernike_basis.to(self.dev).detach() 159 | self.zernike = combine_zernike_basis(self.zernike_coeffs, self.zernike_basis) 160 | self.zernike = utils.ifftshift(self.zernike) 161 | self.zernike = self.zernike.to(self.dev).detach() 162 | self.zernike.requires_grad = False 163 | self.precomped_H_zernike = self.zernike * self.precomped_H 164 | self.precomped_H_zernike = self.precomped_H_zernike.to(self.dev).detach() 165 | self.precomped_H_zernike.requires_grad = False 166 | else: 167 | self.precomped_H_zernike = self.precomped_H 168 | 169 | # precompute the source amplitude, only once 170 | if self.source_amp is None and self.source_amplitude is not None: 171 | self.source_amp = self.source_amplitude(target_amp) 172 | self.source_amp = self.source_amp.to(self.dev).detach() 173 | self.source_amp.requires_grad = False 174 | 175 | # implement the basic propagation to the SLM plane 176 | slm_naive = self.prop(target_complex_diff, self.feature_size, 177 | self.wavelength, self.distance, 178 | precomped_H=self.precomped_H_zernike, 179 | linear_conv=self.linear_conv) 180 | 181 | # switch to amplitude+phase and apply source amplitude adjustment 182 | amp, ang = utils.rect_to_polar(slm_naive.real, slm_naive.imag) 183 | # amp, ang = slm_naive.abs(), slm_naive.angle() # PyTorch 1.7.0 Complex tensor doesn't support 184 | # the gradient of angle() currently. 185 | 186 | if self.source_amp is not None and self.manual_aberr_corr: 187 | amp = amp / self.source_amp 188 | 189 | if self.final_phase_only is None: 190 | return amp, double_phase(amp, ang, three_pi=False) 191 | else: 192 | # note the change to usual complex number stacking! 193 | # We're making this the channel dim via cat instead of stack 194 | if (self.zernike is None and self.source_amp is None 195 | or self.manual_aberr_corr): 196 | if self.latent_codes is not None: 197 | slm_amp_phase = torch.cat((amp, ang, self.latent_codes.repeat(amp.shape[0], 1, 1, 1)), -3) 198 | else: 199 | slm_amp_phase = torch.cat((amp, ang), -3) 200 | elif self.zernike is None: 201 | slm_amp_phase = torch.cat((amp, ang, self.source_amp), -3) 202 | elif self.source_amp is None: 203 | slm_amp_phase = torch.cat((amp, ang, self.zernike), -3) 204 | else: 205 | slm_amp_phase = torch.cat((amp, ang, self.zernike, 206 | self.source_amp), -3) 207 | return amp, self.final_phase_only(slm_amp_phase) 208 | 209 | def to(self, *args, **kwargs): 210 | slf = super().to(*args, **kwargs) 211 | if slf.zernike is not None: 212 | slf.zernike = slf.zernike.to(*args, **kwargs) 213 | if slf.precomped_H is not None: 214 | slf.precomped_H = slf.precomped_H.to(*args, **kwargs) 215 | if slf.source_amp is not None: 216 | slf.source_amp = slf.source_amp.to(*args, **kwargs) 217 | if slf.target_field is not None: 218 | slf.target_field = slf.target_field.to(*args, **kwargs) 219 | if slf.latent_codes is not None: 220 | slf.latent_codes = slf.latent_codes.to(*args, **kwargs) 221 | 222 | # try setting dev based on some parameter, default to cpu 223 | try: 224 | slf.dev = next(slf.parameters()).device 225 | except StopIteration: # no parameters 226 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0] 227 | if device_arg is not None: 228 | slf.dev = device_arg 229 | return slf 230 | 231 | 232 | class InitialPhaseUnet(nn.Module): 233 | """computes the initial input phase given a target amplitude""" 234 | def __init__(self, num_down=8, num_features_init=32, max_features=256, 235 | norm=nn.BatchNorm2d): 236 | super(InitialPhaseUnet, self).__init__() 237 | 238 | net = [Unet(1, 1, num_features_init, 
num_down, max_features,
239 |                     use_dropout=False,
240 |                     upsampling_mode='transpose',
241 |                     norm=norm,
242 |                     outermost_linear=True),
243 |                nn.Hardtanh(-math.pi, math.pi)]
244 |
245 |         self.net = nn.Sequential(*net)
246 |
247 |     def forward(self, amp):
248 |         out_phase = self.net(amp)
249 |         return out_phase
250 |
251 |
252 | class FinalPhaseOnlyUnet(nn.Module):
253 |     """computes the final SLM phase given a naive SLM amplitude and phase"""
254 |     def __init__(self, num_down=8, num_features_init=32, max_features=256,
255 |                  norm=nn.BatchNorm2d, num_in=4):
256 |         super(FinalPhaseOnlyUnet, self).__init__()
257 |
258 |         net = [Unet(num_in, 1, num_features_init, num_down, max_features,
259 |                     use_dropout=False,
260 |                     upsampling_mode='transpose',
261 |                     norm=norm,
262 |                     outermost_linear=True),
263 |                nn.Hardtanh(-math.pi, math.pi)]
264 |
265 |         self.net = nn.Sequential(*net)
266 |
267 |     def forward(self, amp_phase):
268 |         out_phase = self.net(amp_phase)
269 |         return out_phase
270 |
271 |
272 | class PhaseOnlyUnet(nn.Module):
273 |     """computes the final SLM phase given a target amplitude"""
274 |     def __init__(self, num_down=10, num_features_init=16, norm=nn.BatchNorm2d):
275 |         super(PhaseOnlyUnet, self).__init__()
276 |
277 |         net = [Unet(1, 1, num_features_init, num_down, 1024,
278 |                     use_dropout=False,
279 |                     upsampling_mode='transpose',
280 |                     norm=norm,
281 |                     outermost_linear=True),
282 |                nn.Hardtanh(-math.pi, math.pi)]
283 |
284 |         self.net = nn.Sequential(*net)
285 |
286 |     def forward(self, target_amp):
287 |         out_phase = self.net(target_amp)
288 |         return (torch.ones(1), out_phase)
289 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Neural holography:
3 |
4 | This is the main executive script used for phase generation using Holonet/U-net, or for
5 | optimization using GS/DPAC/SGD (optionally with camera-in-the-loop (CITL) for SGD).
6 |
7 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
8 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
9 | # The material is provided as-is, with no warranties whatsoever.
10 | # If you publish any code, data, or scientific work based on this, please cite our work.
11 |
12 | @article{Peng:2020:NeuralHolography,
13 | author = {Y. Peng and S. Choi and N. Padmanaban and G. Wetzstein},
14 | title = {{Neural Holography with Camera-in-the-loop Training}},
15 | journal = {ACM Trans. Graph. (SIGGRAPH Asia)},
16 | year = {2020},
17 | }
18 |
19 | -----
20 |
21 | $ python main.py --channel=0 --method=HOLONET --root_path=./phases --generator_dir=./pretrained_networks
22 | """
23 |
24 | import os
25 | import sys
26 | import cv2
27 | import torch
28 | import torch.nn as nn
29 | import configargparse
30 | from torch.utils.tensorboard import SummaryWriter
31 |
32 | import utils.utils as utils
33 | from utils.augmented_image_loader import ImageLoader
34 | from propagation_model import ModelPropagate
35 | from utils.modules import SGD, GS, DPAC, PhysicalProp
36 | from holonet import HoloNet, InitialPhaseUnet, FinalPhaseOnlyUnet, PhaseOnlyUnet
37 | from propagation_ASM import propagation_ASM
38 |
39 | # Command line argument processing
40 | p = configargparse.ArgumentParser()
41 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
42 |
43 | p.add_argument('--channel', type=int, default=1, help='Red:0, green:1, blue:2')
44 | p.add_argument('--method', type=str, default='SGD', help='Type of algorithm, GS/SGD/DPAC/HOLONET/UNET')
45 | p.add_argument('--prop_model', type=str, default='ASM', help='Type of propagation model, ASM or model')
46 | p.add_argument('--root_path', type=str, default='./phases', help='Directory where optimized phases will be saved.')
47 | p.add_argument('--data_path', type=str, default='./data', help='Directory for the dataset')
48 | p.add_argument('--generator_dir', type=str, default='./pretrained_networks',
49 |                help='Directory for the pretrained holonet/unet network')
50 | p.add_argument('--prop_model_dir', type=str, default='./calibrated_models',
51 |                help='Directory for the CITL-calibrated wave propagation models')
52 | p.add_argument('--citl', type=utils.str2bool, default=False, help='Use of camera-in-the-loop optimization with SGD')
53 | p.add_argument('--experiment', type=str, default='', help='Name of experiment')
54 | p.add_argument('--lr', type=float, default=8e-3, help='Learning rate for phase variables (for SGD)')
55 | p.add_argument('--lr_s', type=float, default=2e-3, help='Learning rate for learnable scale (for SGD)')
56 | p.add_argument('--num_iters', type=int, default=500, help='Number of iterations (GS, SGD)')
57 |
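# Tip (hypothetical usage, not part of the original script): since configargparse is
# used, all of the options above can also be supplied via a config file, e.g.
#   $ python main.py -c my_config.txt
# where my_config.txt contains "key = value" lines such as:
#   channel = 0
#   method = SGD
#   num_iters = 1000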
84 | s0 = 1.0 # initial scale 85 | 86 | root_path = os.path.join(opt.root_path, run_id, chan_str) # path for saving out optimized phases 87 | 88 | # Tensorboard writer 89 | summaries_dir = os.path.join(root_path, 'summaries') 90 | utils.cond_mkdir(summaries_dir) 91 | writer = SummaryWriter(summaries_dir) 92 | 93 | # Hardware setup for CITL 94 | if opt.citl: 95 | camera_prop = PhysicalProp(channel, laser_arduino=True, roi_res=(roi_res[1], roi_res[0]), slm_settle_time=0.12, 96 | range_row=(220, 1000), range_col=(300, 1630), 97 | patterns_path=f'F:/citl/calibration', 98 | show_preview=True) 99 | else: 100 | camera_prop = None 101 | 102 | # Simulation model 103 | if opt.prop_model == 'ASM': 104 | propagator = propagation_ASM # Ideal model 105 | 106 | elif opt.prop_model.upper() == 'MODEL': 107 | blur = utils.make_kernel_gaussian(0.85, 3) 108 | propagator = ModelPropagate(distance=prop_dist, # Parameterized wave propagation model 109 | feature_size=feature_size, 110 | wavelength=wavelength, 111 | blur=blur).to(device) 112 | 113 | # load CITL-calibrated model 114 | propagator.load_state_dict(torch.load(f'{opt.prop_model_dir}/{chan_str}.pth', map_location=device)) 115 | propagator.eval() 116 | 117 | 118 | # Select Phase generation method, algorithm 119 | if opt.method == 'SGD': 120 | phase_only_algorithm = SGD(prop_dist, wavelength, feature_size, opt.num_iters, roi_res, root_path, 121 | opt.prop_model, propagator, loss, opt.lr, opt.lr_s, s0, opt.citl, camera_prop, writer, device) 122 | elif opt.method == 'GS': 123 | phase_only_algorithm = GS(prop_dist, wavelength, feature_size, opt.num_iters, root_path, 124 | opt.prop_model, propagator, writer, device) 125 | elif opt.method == 'DPAC': 126 | phase_only_algorithm = DPAC(prop_dist, wavelength, feature_size, opt.prop_model, propagator, device) 127 | elif opt.method == 'HOLONET': 128 | phase_only_algorithm = HoloNet(prop_dist, wavelength, feature_size, initial_phase=InitialPhaseUnet(4, 16), 129 | final_phase_only=FinalPhaseOnlyUnet(4, 16, num_in=2)).to(device) 130 | model_path = os.path.join(opt.generator_dir, f'holonet20_{chan_str}.pth') 131 | image_res = (1072, 1920) 132 | elif opt.method == 'UNET': 133 | phase_only_algorithm = PhaseOnlyUnet(num_features_init=32).to(device) 134 | model_path = os.path.join(opt.generator_dir, f'unet20_{chan_str}.pth') 135 | image_res = (1024, 2048) 136 | 137 | if 'NET' in opt.method: 138 | checkpoint = torch.load(model_path) 139 | phase_only_algorithm.load_state_dict(checkpoint) 140 | phase_only_algorithm.eval() 141 | 142 | 143 | # Augmented image loader (if you want to shuffle, augment dataset, put options accordingly.) 144 | image_loader = ImageLoader(opt.data_path, channel=channel, 145 | image_res=image_res, homography_res=roi_res, 146 | crop_to_homography=True, 147 | shuffle=False, vertical_flips=False, horizontal_flips=False) 148 | 149 | # Loop over the dataset 150 | for k, target in enumerate(image_loader): 151 | # get target image 152 | target_amp, target_res, target_filename = target 153 | target_path, target_filename = os.path.split(target_filename[0]) 154 | target_idx = target_filename.split('_')[-1] 155 | target_amp = target_amp.to(device) 156 | print(target_idx) 157 | 158 | # if you want to separate folders by target_idx or whatever, you can do so here. 
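    # A minimal sketch of such a per-target layout (kept commented out so the
    # default flat per-channel layout expected by eval.py is preserved;
    # `target_idx` is the index parsed from the filename above):
    #
    #   root_path = os.path.join(opt.root_path, run_id, chan_str, target_idx)
    #   utils.cond_mkdir(root_path)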
159 | phase_only_algorithm.init_scale = s0 * utils.crop_image(target_amp, roi_res, stacked_complex=False).mean() 160 | phase_only_algorithm.phase_path = os.path.join(root_path) 161 | 162 | # run algorithm (See algorithm_modules.py and algorithms.py) 163 | if opt.method in ['DPAC', 'HOLONET', 'UNET']: 164 | # direct methods 165 | _, final_phase = phase_only_algorithm(target_amp) 166 | else: 167 | # iterative methods, initial phase: random guess 168 | init_phase = (-0.5 + 1.0 * torch.rand(1, 1, *slm_res)).to(device) 169 | final_phase = phase_only_algorithm(target_amp, init_phase) 170 | 171 | print(final_phase.shape) 172 | 173 | # save the final result somewhere. 174 | phase_out_8bit = utils.phasemap_8bit(final_phase.cpu().detach(), inverted=True) 175 | 176 | utils.cond_mkdir(root_path) 177 | cv2.imwrite(os.path.join(root_path, f'{target_idx}.png'), phase_out_8bit) 178 | 179 | print(f' - Done, result: --root_path={root_path}') 180 | -------------------------------------------------------------------------------- /main_eval.sh: -------------------------------------------------------------------------------- 1 | python main.py --channel=0 --method="$1" --root_path=./phases 2 | python main.py --channel=1 --method="$1" --root_path=./phases 3 | python main.py --channel=2 --method="$1" --root_path=./phases 4 | python eval.py --channel=3 --root_path=./phases/"$1"_ASM 5 | -------------------------------------------------------------------------------- /pretrained_networks/README.md: -------------------------------------------------------------------------------- 1 | Pre-trained networks should be placed in this directory for the main.py 2 | script. -------------------------------------------------------------------------------- /propagation_ASM.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script that is used for the wave propagation using the angular spectrum method (ASM). Refer to 3 | Goodman, Joseph W. Introduction to Fourier optics. Roberts and Company Publishers, 2005, for principle details. 4 | 5 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 6 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 7 | # The material is provided as-is, with no warranties whatsoever. 8 | # If you publish any code, data, or scientific work based on this, please cite our work. 9 | 10 | Technical Paper: 11 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020. 
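
Usage (a minimal sketch -- the pitch, wavelength, and 20 cm distance below simply
mirror the green-channel settings in main.py and are illustrative, not required):

>>> slm_phase = torch.zeros(1, 1, 1080, 1920)  # phase-only SLM pattern
>>> real, imag = utils.polar_to_rect(torch.ones_like(slm_phase), slm_phase)
>>> u_in = torch.complex(real, imag)  # unit-amplitude complex field at the SLM
>>> u_out = propagation_ASM(u_in, (6.4e-6, 6.4e-6), 520e-9, 0.2)
>>> recon_amp = u_out.abs()  # amplitude at the target plane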
12 | """ 13 | 14 | import math 15 | import torch 16 | import numpy as np 17 | import utils.utils as utils 18 | import torch.fft 19 | from aotools.functions import zernikeArray 20 | 21 | 22 | def propagation_ASM(u_in, feature_size, wavelength, z, linear_conv=True, 23 | padtype='zero', return_H=False, precomped_H=None, 24 | return_H_exp=False, precomped_H_exp=None, 25 | dtype=torch.float32): 26 | """Propagates the input field using the angular spectrum method 27 | 28 | Inputs 29 | ------ 30 | u_in: PyTorch Complex tensor (torch.cfloat) of size (num_images, 1, height, width) -- updated with PyTorch 1.7.0 31 | feature_size: (height, width) of individual holographic features in m 32 | wavelength: wavelength in m 33 | z: propagation distance 34 | linear_conv: if True, pad the input to obtain a linear convolution 35 | padtype: 'zero' to pad with zeros, 'median' to pad with median of u_in's 36 | amplitude 37 | return_H[_exp]: used for precomputing H or H_exp, ends the computation early 38 | and returns the desired variable 39 | precomped_H[_exp]: the precomputed value for H or H_exp 40 | dtype: torch dtype for computation at different precision 41 | 42 | Output 43 | ------ 44 | tensor of size (num_images, 1, height, width, 2) 45 | """ 46 | 47 | if linear_conv: 48 | # preprocess with padding for linear conv. 49 | input_resolution = u_in.size()[-2:] 50 | conv_size = [i * 2 for i in input_resolution] 51 | if padtype == 'zero': 52 | padval = 0 53 | elif padtype == 'median': 54 | padval = torch.median(torch.pow((u_in**2).sum(-1), 0.5)) 55 | u_in = utils.pad_image(u_in, conv_size, padval=padval, stacked_complex=False) 56 | 57 | if precomped_H is None and precomped_H_exp is None: 58 | # resolution of input field, should be: (num_images, num_channels, height, width, 2) 59 | field_resolution = u_in.size() 60 | 61 | # number of pixels 62 | num_y, num_x = field_resolution[2], field_resolution[3] 63 | 64 | # sampling inteval size 65 | dy, dx = feature_size 66 | 67 | # size of the field 68 | y, x = (dy * float(num_y), dx * float(num_x)) 69 | 70 | # frequency coordinates sampling 71 | fy = np.linspace(-1 / (2 * dy) + 0.5 / (2 * y), 1 / (2 * dy) - 0.5 / (2 * y), num_y) 72 | fx = np.linspace(-1 / (2 * dx) + 0.5 / (2 * x), 1 / (2 * dx) - 0.5 / (2 * x), num_x) 73 | 74 | # momentum/reciprocal space 75 | FX, FY = np.meshgrid(fx, fy) 76 | 77 | # transfer function in numpy (omit distance) 78 | HH = 2 * math.pi * np.sqrt(1 / wavelength**2 - (FX**2 + FY**2)) 79 | 80 | # create tensor & upload to device (GPU) 81 | H_exp = torch.tensor(HH, dtype=dtype).to(u_in.device) 82 | 83 | ### 84 | # here one may iterate over multiple distances, once H_exp is uploaded on GPU 85 | 86 | # reshape tensor and multiply 87 | H_exp = torch.reshape(H_exp, (1, 1, *H_exp.size())) 88 | 89 | # handle loading the precomputed H_exp value, or saving it for later runs 90 | elif precomped_H_exp is not None: 91 | H_exp = precomped_H_exp 92 | 93 | if precomped_H is None: 94 | # multiply by distance 95 | H_exp = torch.mul(H_exp, z) 96 | 97 | # band-limited ASM - Matsushima et al. 
(2009)
98 |         fy_max = 1 / np.sqrt((2 * z * (1 / y))**2 + 1) / wavelength
99 |         fx_max = 1 / np.sqrt((2 * z * (1 / x))**2 + 1) / wavelength
100 |         H_filter = torch.tensor(((np.abs(FX) < fx_max) & (np.abs(FY) < fy_max)).astype(np.uint8), dtype=dtype)
101 | 
102 |         # get real/imag components
103 |         H_real, H_imag = utils.polar_to_rect(H_filter.to(u_in.device), H_exp)
104 | 
105 |         H = torch.stack((H_real, H_imag), 4)
106 |         H = utils.ifftshift(H)
107 |         H = torch.view_as_complex(H)
108 |     else:
109 |         H = precomped_H
110 | 
111 |     # return for use later as precomputed inputs
112 |     if return_H_exp:
113 |         return H_exp
114 |     if return_H:
115 |         return H
116 | 
117 |     # For those who cannot use PyTorch 1.7.0 and its complex tensor support:
118 |     # # angular spectrum
119 |     # U1 = torch.fft(utils.ifftshift(u_in), 2, True)
120 |     #
121 |     # # convolution of the system
122 |     # U2 = utils.mul_complex(H, U1)
123 |     #
124 |     # # Fourier transform of the convolution to the observation plane
125 |     # u_out = utils.fftshift(torch.ifft(U2, 2, True))
126 | 
127 |     U1 = torch.fft.fftn(utils.ifftshift(u_in), dim=(-2, -1), norm='ortho')
128 | 
129 |     U2 = H * U1
130 | 
131 |     u_out = utils.fftshift(torch.fft.ifftn(U2, dim=(-2, -1), norm='ortho'))
132 | 
133 |     if linear_conv:
134 |         # return utils.crop_image(u_out, input_resolution)  # using stacked version
135 |         return utils.crop_image(u_out, input_resolution, pytorch=True, stacked_complex=False)  # using complex tensor
136 |     else:
137 |         return u_out
138 | 
139 | 
140 | def propagation_ASM_zernike(u_in, feature_size, wavelength, z, linear_conv=True,
141 |                             padtype='zero', coeffs=None, zernike=None, adjoint=False,
142 |                             return_H=False, precomped_H=None, return_H_exp=False,
143 |                             precomped_H_exp=None, dtype=torch.float32):
144 |     """A wrapper around propagation_ASM that applies a Zernike phase correction
145 | 
146 |     Inputs
147 |     ------
148 |     coeffs: a 1d tensor of Zernike coefficients
149 |     zernike: a precomputed Zernike function basis. Should contain the same
150 |         number of basis functions as number of coeffs, and be the same height
151 |         and width as u_in. Use compute_zernike_basis to compute
152 |     adjoint: if True, propagate then apply zernike.
If False, apply then prop 153 | 154 | See propagation_ASM for u_in, feature_size, wavelength, z, linear_conv, 155 | padtype, return_H, precomped_H, return_H_exp, precomped_H_exp, dtype 156 | """ 157 | 158 | if return_H or return_H_exp or coeffs is None: 159 | return propagation_ASM(u_in, feature_size, wavelength, z, linear_conv, 160 | padtype, return_H, precomped_H, return_H_exp, 161 | precomped_H_exp, dtype) 162 | 163 | coeffs = coeffs.reshape(-1, 1, 1) 164 | 165 | if zernike is None: 166 | # resolution of input field, should be: (num_images, num_channels, height, width, 2) 167 | zernike = compute_zernike_basis(coeffs.size()[0], u_in.size()[-3:-1]) 168 | 169 | if z < 0: 170 | coeffs = -coeffs 171 | 172 | # create phase 173 | zernike = combine_zernike_basis(coeffs, zernike) 174 | 175 | # operator order 176 | if adjoint: 177 | u_out = propagation_ASM(u_in, feature_size, wavelength, z, linear_conv, 178 | padtype, return_H, precomped_H, return_H_exp, 179 | precomped_H_exp) 180 | u_out = utils.mul_complex(zernike, u_out) 181 | else: 182 | u_out = utils.mul_complex(zernike, u_in) 183 | u_out = propagation_ASM(u_out, feature_size, wavelength, z, linear_conv, 184 | padtype, return_H, precomped_H, return_H_exp, 185 | precomped_H_exp) 186 | 187 | return u_out 188 | 189 | 190 | def propagation_ASM_zernike_fourier(u_in, feature_size, wavelength, z, linear_conv=True, 191 | padtype='zero', coeffs=None, zernike=None, 192 | return_H=False, precomped_H=None, return_H_exp=False, 193 | precomped_H_exp=None, dtype=torch.float32): 194 | """A wrapper around propagation_ASM that applies a Zernike phase correction at Fourier plane 195 | 196 | Inputs 197 | ------ 198 | coeffs: a 1d tensor of Zernike coefficients 199 | zernike: a precomputed Zernike function basis. Should contain the same 200 | number of basis functions as number of coeffs, and be the same height 201 | and width as u_in. Use compute_zernike_basis to compute 202 | 203 | See propagation_ASM for u_in, feature_size, wavelength, z, linear_conv, 204 | padtype, return_H, precomped_H, return_H_exp, precomped_H_exp, dtype 205 | """ 206 | 207 | if return_H or return_H_exp or coeffs is None: 208 | return propagation_ASM(u_in, feature_size, wavelength, z, linear_conv, 209 | padtype, return_H, precomped_H, return_H_exp, 210 | precomped_H_exp, dtype) 211 | 212 | coeffs = coeffs.reshape(-1, 1, 1) 213 | 214 | if zernike is None: 215 | # resolution of input field, should be: (num_images, num_channels, height, width, 2) 216 | zernike = compute_zernike_basis(coeffs.size()[0], u_in.size()[-3:-1]) 217 | 218 | if z < 0: 219 | coeffs = -coeffs 220 | 221 | # create phase 222 | zernike = combine_zernike_basis(coeffs, zernike) 223 | zernike = utils.ifftshift(zernike) 224 | 225 | precomped_H_new = zernike * precomped_H 226 | 227 | u_out = propagation_ASM(u_in, feature_size, wavelength, z, linear_conv, 228 | padtype, return_H, precomped_H_new, return_H_exp, 229 | precomped_H_exp) 230 | 231 | return u_out 232 | 233 | 234 | def combine_zernike_basis(coeffs, basis, return_phase=False): 235 | """ 236 | Multiplies the Zernike coefficients and basis functions while preserving 237 | dimensions 238 | 239 | :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike 240 | :param basis: the output of compute_zernike_basis, must be same length as coeffs 241 | :param return_phase: 242 | :return: A Complex64 tensor that combines coeffs and basis. 
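
    Example (a minimal sketch; the 5-term basis and 1080p resolution are illustrative):

    >>> coeffs = torch.zeros(5)
    >>> basis = compute_zernike_basis(5, (1080, 1920))
    >>> zernike = combine_zernike_basis(coeffs, basis)  # complex, shape (1, 1, 1080, 1920)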
243 | """ 244 | 245 | if len(coeffs.shape) < 3: 246 | coeffs = torch.reshape(coeffs, (coeffs.shape[0], 1, 1)) 247 | 248 | # combine zernike basis and coefficients 249 | zernike = (coeffs * basis).sum(0, keepdim=True) 250 | 251 | # shape to [1, len(coeffs), H, W] 252 | zernike = zernike.unsqueeze(0) 253 | 254 | # convert to Pytorch Complex tensor 255 | real, imag = utils.polar_to_rect(torch.ones_like(zernike), zernike) 256 | return torch.complex(real, imag) 257 | 258 | 259 | def compute_zernike_basis(num_polynomials, field_res, dtype=torch.float32, wo_piston=False): 260 | """Computes a set of Zernike basis function with resolution field_res 261 | 262 | num_polynomials: number of Zernike polynomials in this basis 263 | field_res: [height, width] in px, any list-like object 264 | dtype: torch dtype for computation at different precision 265 | """ 266 | 267 | # size the zernike basis to avoid circular masking 268 | zernike_diam = int(np.ceil(np.sqrt(field_res[0]**2 + field_res[1]**2))) 269 | 270 | # create zernike functions 271 | 272 | if not wo_piston: 273 | zernike = zernikeArray(num_polynomials, zernike_diam) 274 | else: # 200427 - exclude pistorn term 275 | idxs = range(2, 2 + num_polynomials) 276 | zernike = zernikeArray(idxs, zernike_diam) 277 | 278 | zernike = utils.crop_image(zernike, field_res, pytorch=False) 279 | 280 | # convert to tensor and create phase 281 | zernike = torch.tensor(zernike, dtype=dtype, requires_grad=False) 282 | 283 | return zernike 284 | -------------------------------------------------------------------------------- /propagation_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script that is used for the parameterized wave propagation described in the paper. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020. 11 | """ 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import numpy as np 17 | 18 | import utils.utils as utils 19 | from propagation_ASM import compute_zernike_basis, combine_zernike_basis, \ 20 | propagation_ASM, propagation_ASM_zernike, propagation_ASM_zernike_fourier 21 | 22 | from utils.pytorch_prototyping.pytorch_prototyping import Conv2dSame 23 | 24 | 25 | class LatentCodedMLP(nn.Module): 26 | """ 27 | concatenate latent codes in the middle of forward pass as well. 28 | put latent codes shape of (1, L, H, W) as a parameter for the forward pass. 
29 | 
30 |     num_latent_codes: list of the number of latent-code slices for each layer
31 |     * so the sum of num_latent_codes should equal the total number of latent-code channels
32 |     """
33 |     def __init__(self, num_layers=5, num_features=32, norm=None, num_latent_codes=None):
34 |         super(LatentCodedMLP, self).__init__()
35 | 
36 |         if num_latent_codes is None:
37 |             num_latent_codes = [0] * num_layers
38 | 
39 |         assert len(num_latent_codes) == num_layers
40 | 
41 |         self.num_latent_codes = num_latent_codes
42 |         self.idxs = [sum(num_latent_codes[:y]) for y in range(num_layers + 1)]
43 |         self.nets = nn.ModuleList([])
44 |         num_features = [num_features] * num_layers
45 |         num_features[0] = 1
46 | 
47 |         # define each layer
48 |         for i in range(num_layers - 1):
49 |             net = [nn.Conv2d(num_features[i] + num_latent_codes[i], num_features[i + 1], kernel_size=1)]
50 |             if norm is not None:
51 |                 net += [norm(num_groups=4, num_channels=num_features[i + 1], affine=True)]
52 |             net += [nn.LeakyReLU(0.2, True)]
53 |             self.nets.append(nn.Sequential(*net))
54 | 
55 |         self.nets.append(nn.Conv2d(num_features[-1] + num_latent_codes[-1], 1, kernel_size=1))
56 | 
57 |         for m in self.modules():
58 |             if isinstance(m, nn.Conv2d):
59 |                 nn.init.normal_(m.weight, std=0.05)
60 | 
61 |     def forward(self, phases, latent_codes=None):
62 | 
63 |         after_relu = phases
64 |         # concatenate latent codes at each layer and send through the convolutional layers
65 |         for i in range(len(self.num_latent_codes)):
66 |             if latent_codes is not None:
67 |                 after_relu = torch.cat((after_relu, latent_codes[:, self.idxs[i]:self.idxs[i + 1], ...]), 1)
68 |             after_relu = self.nets[i](after_relu)
69 | 
70 |         # residual connection
71 |         return phases - after_relu
72 | 
73 | 
74 | class ContentDependentField(nn.Module):
75 |     def __init__(self, num_layers=5, num_features=32, norm=nn.GroupNorm, latent_coords=False):
76 |         """ Simple 5-layer CNN modeling content-dependent undiffracted light """
77 | 
78 |         super(ContentDependentField, self).__init__()
79 | 
80 |         if not latent_coords:
81 |             first_ch = 1
82 |         else:
83 |             first_ch = 3
84 | 
85 |         net = [Conv2dSame(first_ch, num_features, kernel_size=3)]
86 | 
87 |         for i in range(num_layers - 2):
88 |             if norm is not None:
89 |                 net += [norm(num_groups=2, num_channels=num_features, affine=True)]
90 |             net += [nn.LeakyReLU(0.2, True),
91 |                     Conv2dSame(num_features, num_features, kernel_size=3)]
92 | 
93 |         if norm is not None:
94 |             net += [norm(num_groups=4, num_channels=num_features, affine=True)]
95 | 
96 |         net += [nn.LeakyReLU(0.2, True),
97 |                 Conv2dSame(num_features, 2, kernel_size=3)]
98 | 
99 |         self.net = nn.Sequential(*net)
100 | 
101 |     def forward(self, phases, latent_coords=None):
102 |         if latent_coords is not None:
103 |             input_cnn = torch.cat((phases, latent_coords), dim=1)
104 |         else:
105 |             input_cnn = phases
106 | 
107 |         return self.net(input_cnn).unsqueeze(4).permute(0, 4, 2, 3, 1)
108 | 
109 | 
110 | class ProcessPhase(nn.Module):
111 |     def __init__(self, num_layers=5, num_features=32, num_output_feat=0, norm=nn.BatchNorm2d, num_latent_codes=0):
112 |         super(ProcessPhase, self).__init__()
113 | 
114 |         # avoid zero
115 |         self.num_output_feat = max(num_output_feat, 1)
116 |         self.num_latent_codes = num_latent_codes
117 | 
118 |         # a bunch of 1x1 conv layers, set by num_layers
119 |         net = [nn.Conv2d(1 + num_latent_codes, num_features, kernel_size=1)]
120 | 
121 |         for i in range(num_layers - 2):
122 |             if norm is not None:
123 |                 net += [norm(num_groups=2, num_channels=num_features, affine=True)]
124 |             net += [nn.LeakyReLU(0.2, True),
125 |                     nn.Conv2d(num_features, num_features, kernel_size=1)]
126 | 
127 |         if norm is not None:
128 |             net += [norm(num_groups=2, num_channels=num_features, affine=True)]
129 | 
130 |         net += [nn.ReLU(True),
131 |                 nn.Conv2d(num_features, self.num_output_feat, kernel_size=1)]
132 | 
133 |         self.net = nn.Sequential(*net)
134 | 
135 |     def forward(self, phases):
136 |         return phases - self.net(phases)
137 | 
138 | 
139 | class SourceAmplitude(nn.Module):
140 |     def __init__(self, num_gaussians=3, init_sigma=None, init_amp=0.7, x_s0=0.0, y_s0=0.0):
141 |         super(SourceAmplitude, self).__init__()
142 | 
143 |         self.num_gaussians = num_gaussians
144 | 
145 |         if init_sigma is None:
146 |             init_sigma = [100.] * self.num_gaussians  # default to 100 for all
147 | 
148 |         # create parameters for source amplitudes
149 |         self.sigmas = nn.Parameter(torch.tensor(init_sigma),
150 |                                    requires_grad=True)
151 |         self.x_s = nn.Parameter(torch.ones(num_gaussians) * x_s0,
152 |                                 requires_grad=True)
153 |         self.y_s = nn.Parameter(torch.ones(num_gaussians) * y_s0,
154 |                                 requires_grad=True)
155 |         self.amplitudes = nn.Parameter(torch.ones(num_gaussians) / (num_gaussians) * init_amp,
156 |                                        requires_grad=True)
157 | 
158 |         self.dc_term = nn.Parameter(torch.zeros(1),
159 |                                     requires_grad=True)
160 | 
161 |         self.x_dim = None
162 |         self.y_dim = None
163 | 
164 |     def forward(self, phases):
165 |         # create DC term, then add the gaussians
166 |         source_amp = torch.ones_like(phases) * self.dc_term
167 |         for i in range(self.num_gaussians):
168 |             source_amp += self.create_gaussian(phases.shape, i)
169 | 
170 |         return source_amp
171 | 
172 |     def create_gaussian(self, shape, idx):
173 |         # create sampling grid if needed
174 |         if self.x_dim is None or self.y_dim is None:
175 |             self.x_dim = torch.linspace(-(shape[-1] - 1) / 2,
176 |                                         (shape[-1] - 1) / 2,
177 |                                         shape[-1], device=self.dc_term.device)
178 |             self.y_dim = torch.linspace(-(shape[-2] - 1) / 2,
179 |                                         (shape[-2] - 1) / 2,
180 |                                         shape[-2], device=self.dc_term.device)
181 | 
182 |         if self.x_dim.device != self.sigmas.device:
183 |             self.x_dim = self.x_dim.to(self.sigmas.device).detach()  # reassign: Tensor.to() is not in-place
184 |             self.x_dim.requires_grad = False
185 |         if self.y_dim.device != self.sigmas.device:
186 |             self.y_dim = self.y_dim.to(self.sigmas.device).detach()  # reassign: Tensor.to() is not in-place
187 |             self.y_dim.requires_grad = False
188 | 
189 |         # offset grid by coordinate and compute x and y gaussian components
190 |         x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.x_dim - self.x_s[idx], self.sigmas[idx]), 2))
191 |         y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.y_dim - self.y_s[idx], self.sigmas[idx]), 2))
192 | 
193 |         # outer product with amplitude scaling
194 |         gaussian = torch.ger(self.amplitudes[idx] * y_gaussian, x_gaussian)
195 | 
196 |         return gaussian
197 | 
198 | 
199 | class ModelPropagate(nn.Module):
200 |     """Parameterized light transport model, propagates an SLM phase with multipart propagation, including
201 |     learnable Zernike phase, source amplitude, and phase LUT corrections, etc.
202 | 
203 |     Class initialization parameters
204 |     -------------------------------
205 |     distance: propagation dist between SLM and target, in meters, default 0.1
206 |     wavelength: the wavelength of interest, in meters, default 520e-9
207 |     feature_size: the SLM pixel pitch, in meters, default 6.4e-6
208 |     num_coeffs: number of Zernike basis function coeffs to learn, default 15
209 |     num_layers: number of layers in phase LUT correction convnet, default 5
210 |     num_features: number of features per layer of LUT convnet, default 32
211 |     num_output_feat: number of "attention" layers, per-pixel parameters, set 0 if not using.
default 0 212 | num_gaussians: number of Gaussians to use in source amp model, default 3 213 | init_sigma: initial spread of Gaussians, in pixels, default 100 214 | learn_dist: if True, makes distance a learnable parameter, default False 215 | init_coeffs: initial value for Zernike coefficients 216 | use_conv1d_mlp: if False, disable phase LUT correction, default True 217 | norm: norm (e.g., nn.BatchNorm2d) to use in LUT convnet, default None 218 | proptype: chooses the propagation operator ('ASM': propagation_ASM, 219 | 'fresnel': propagation_fresnel). Default ASM. 220 | linear_conv: if True, pads for linear conv for propagation, default True 221 | 222 | Usage 223 | ----- 224 | Functions as a pytorch module: 225 | 226 | >>> propagate_model = ModelPropagate(...) 227 | >>> output_complex = propagate_model(slm_phase) 228 | 229 | slm_phase: encoded phase-only representation at SLM plane , with dimensions 230 | [batch, 1, height, width] 231 | output_complex: complex field at the target plane, with dimensions [batch, 232 | 1, height, width, 2], where the final dimension is stacked real and 233 | imaginary values 234 | """ 235 | 236 | def __init__(self, distance=0.1, wavelength=520e-9, feature_size=6.4e-6, image_res=(1080, 1920), learn_dist=False, 237 | target_field=True, num_gaussians=3, init_sigma=(1300.0, 1500.0, 1700.0), init_amp=0.9, 238 | num_coeffs=0, num_coeffs_fourier=5, init_coeffs=0.0, 239 | use_conv1d_mlp=True, num_layers=3, num_features=16, num_latent_codes=None, norm=nn.GroupNorm, 240 | blur=None, 241 | content_field=True, num_layers_cdp=5, num_feats_cdp=8, latent_coords=False, 242 | proptype='ASM', linear_conv=True): 243 | super(ModelPropagate, self).__init__() 244 | 245 | # Section 5.1.1. Content-independent Source & Target Field variation 246 | if num_gaussians: 247 | self.source_amp = SourceAmplitude(num_gaussians, init_sigma, init_amp=init_amp, x_s0=0.0, y_s0=0.0) 248 | else: 249 | self.source_amp = None 250 | if target_field: 251 | self.target_constant_amp = nn.Parameter(0.07 * torch.ones(1, 1, *image_res), requires_grad=True) 252 | self.target_constant_phase = nn.Parameter(torch.zeros((1, 1, *image_res)), requires_grad=True) 253 | else: 254 | self.target_constant_amp, self.target_constant_phase = None, None 255 | 256 | # Section 5.1.2 Modeling Optical Propagation with Aberrations 257 | if num_coeffs: 258 | self.coeffs = nn.Parameter(torch.ones(num_coeffs) * init_coeffs, 259 | requires_grad=True) 260 | else: 261 | self.coeffs = None 262 | if num_coeffs_fourier: 263 | self.coeffs_fourier = nn.Parameter(torch.ones(num_coeffs_fourier) * init_coeffs, 264 | requires_grad=True) 265 | else: 266 | self.coeffs_fourier = None 267 | 268 | # Section 5.1.3. Phase nonlinearity 269 | if num_latent_codes is None: 270 | num_latent_codes = [2, 0, 0] 271 | 272 | if use_conv1d_mlp: 273 | self.process_phase = LatentCodedMLP(num_layers, num_features, norm=norm, num_latent_codes=num_latent_codes) 274 | else: 275 | self.process_phase = None 276 | 277 | if sum(num_latent_codes) > 0: 278 | self.latent_code = nn.Parameter(torch.zeros(1, sum(num_latent_codes), *image_res), requires_grad=True) 279 | else: 280 | self.latent_code = None 281 | 282 | # Section 5.1.4. 
Content-dependent Undiffracted Light 283 | if content_field: 284 | self.content_dependent_field = ContentDependentField(num_layers=num_layers_cdp, num_features=num_feats_cdp, norm=nn.GroupNorm, latent_coords=latent_coords) 285 | else: 286 | self.content_dependent_field = None 287 | 288 | if latent_coords: 289 | latent_x = np.linspace(-1.0, 1.0, image_res[1]) 290 | latent_y = np.linspace(-1.0 * image_res[0] / image_res[1], 291 | 1.0 * image_res[0] / image_res[1], image_res[0]) 292 | lx, ly = np.meshgrid(latent_x, latent_y) 293 | self.latent_coords = nn.Parameter(torch.from_numpy(np.stack((lx, ly), 0)).type(torch.float32).reshape(1, 2, *image_res), requires_grad=False) 294 | else: 295 | self.latent_coords = None 296 | 297 | self.learn_dist = learn_dist 298 | if learn_dist: 299 | self.distance = nn.Parameter(torch.tensor(distance, dtype=torch.float), 300 | requires_grad=True) 301 | else: 302 | self.distance = distance 303 | 304 | if blur is not None: 305 | self.blur = blur 306 | self.blur = Conv2dSame(1, 1, kernel_size=3, bias=False) 307 | self.blur.net[1].weight = nn.Parameter(blur, requires_grad=False) 308 | else: 309 | self.blur = None 310 | 311 | # propagation parameters 312 | self.wavelength = wavelength 313 | self.feature_size = (feature_size 314 | if hasattr(feature_size, '__len__') 315 | else [feature_size] * 2) 316 | 317 | self.zernike = None 318 | self.zernike_fourier = None 319 | self.zernike_eval = None 320 | self.zernike_eval_fourier = None 321 | self.precomped_H = None 322 | self.precomped_H_exp = None 323 | 324 | # change out the propagation operator 325 | if proptype == 'ASM': 326 | self.prop = propagation_ASM 327 | self.prop_zernike = propagation_ASM_zernike 328 | self.prop_zernike_fourier = propagation_ASM_zernike_fourier 329 | 330 | self.linear_conv = linear_conv 331 | 332 | # set a device for initializing the precomputed objects 333 | try: 334 | self.dev = next(self.parameters()).device 335 | except StopIteration: # no parameters 336 | self.dev = torch.device('cpu') 337 | 338 | def forward(self, phases, skip_lut=False, skip_tm=False): 339 | 340 | # Section 5.1.3. Modeling Phase Nonlinearity 341 | if self.process_phase is not None and not skip_lut: 342 | if self.latent_code is not None: 343 | # support mini-batch 344 | processed_phase = self.process_phase(phases, self.latent_code.repeat(phases.shape[0], 1, 1, 1)) 345 | else: 346 | processed_phase = self.process_phase(phases) 347 | else: 348 | processed_phase = phases 349 | 350 | # Section 5.1.1. Create Source Amplitude (DC + gaussians) 351 | if self.source_amp is not None: 352 | source_amp = self.source_amp(processed_phase) 353 | else: 354 | source_amp = torch.ones_like(processed_phase) 355 | 356 | # convert phase to real and imaginary 357 | real, imag = utils.polar_to_rect(source_amp, processed_phase) 358 | processed_complex = torch.complex(real, imag) 359 | 360 | # Section 5.1.2. 
precompute the zernike basis only once 361 | if self.zernike is None and self.coeffs is not None: 362 | self.zernike = compute_zernike_basis(self.coeffs.size()[0], 363 | phases.size()[-2:], wo_piston=True) 364 | self.zernike = self.zernike.to(self.dev).detach() 365 | self.zernike.requires_grad = False 366 | 367 | if self.zernike_fourier is None and self.coeffs_fourier is not None: 368 | self.zernike_fourier = compute_zernike_basis(self.coeffs_fourier.size()[0], 369 | [i * 2 for i in phases.size()[-2:]], 370 | wo_piston=True) 371 | self.zernike_fourier = self.zernike_fourier.to(self.dev).detach() 372 | self.zernike_fourier.requires_grad = False 373 | 374 | if not self.training and self.zernike_eval is None and self.coeffs is not None: 375 | # sum the phases 376 | self.zernike_eval = combine_zernike_basis(self.coeffs, self.zernike) 377 | self.zernike_eval = self.zernike_eval.to(self.coeffs.device).detach() 378 | self.zernike_eval.requires_grad = False 379 | 380 | if not self.training and self.zernike_eval_fourier is None and self.coeffs_fourier is not None: 381 | # sum the phases 382 | self.zernike_eval_fourier = combine_zernike_basis(self.coeffs_fourier, self.zernike_fourier) 383 | self.zernike_eval_fourier = utils.ifftshift(self.zernike_eval_fourier) 384 | self.zernike_eval_fourier = self.zernike_eval_fourier.to(self.coeffs_fourier.device).detach() 385 | self.zernike_eval_fourier.requires_grad = False 386 | 387 | # precompute the kernel only once 388 | if self.learn_dist and self.training: 389 | self.precompute_H_exp(processed_complex) 390 | else: 391 | self.precompute_H(processed_complex) 392 | 393 | # Section 5.1.2. apply zernike and propagate 394 | if self.training: 395 | if self.coeffs_fourier is None: 396 | output_complex = self.prop_zernike(processed_complex, 397 | self.feature_size, 398 | self.wavelength, 399 | self.distance, 400 | coeffs=self.coeffs, 401 | zernike=self.zernike, 402 | precomped_H=self.precomped_H, 403 | precomped_H_exp=self.precomped_H_exp, 404 | linear_conv=self.linear_conv) 405 | else: 406 | output_complex = self.prop_zernike_fourier(processed_complex, 407 | self.feature_size, 408 | self.wavelength, 409 | self.distance, 410 | coeffs=self.coeffs_fourier, 411 | zernike=self.zernike_fourier, 412 | precomped_H=self.precomped_H, 413 | precomped_H_exp=self.precomped_H_exp, 414 | linear_conv=self.linear_conv) 415 | 416 | else: 417 | if self.coeffs is not None: 418 | # in primal domain 419 | processed_zernike = self.zernike_eval * processed_complex 420 | else: 421 | processed_zernike = processed_complex 422 | 423 | if self.coeffs_fourier is not None: 424 | # in fourier domain 425 | precomped_H = self.zernike_eval_fourier * self.precomped_H 426 | else: 427 | precomped_H = self.precomped_H 428 | 429 | output_complex = self.prop(processed_zernike, 430 | self.feature_size, 431 | self.wavelength, 432 | self.distance, 433 | precomped_H=precomped_H, 434 | linear_conv=self.linear_conv) 435 | 436 | # Section 5.1.1. Content-independent field at target plane 437 | if self.target_constant_amp is not None: 438 | real, imag = utils.polar_to_rect(self.target_constant_amp, self.target_constant_phase) 439 | target_field = torch.complex(real, imag) 440 | output_complex = output_complex + target_field 441 | 442 | # Section 5.1.4. 
Content-dependent Undiffracted light 443 | if self.content_dependent_field is not None: 444 | if self.latent_coords is not None: 445 | cdf = self.content_dependent_field(phases, self.latent_coords.repeat(phases.shape[0], 1, 1, 1)) 446 | else: 447 | cdf = self.content_dependent_field(phases) 448 | real, imag = utils.polar_to_rect(cdf[..., 0], cdf[..., 1]) 449 | cdf_rect = torch.complex(real, imag) 450 | output_complex = output_complex + cdf_rect 451 | 452 | amp = output_complex.abs() 453 | _, phase = utils.rect_to_polar(output_complex.real, output_complex.imag) 454 | 455 | if self.blur is not None: 456 | amp = self.blur(amp) 457 | 458 | real, imag = utils.polar_to_rect(amp, phase) 459 | 460 | return torch.complex(real, imag) 461 | 462 | def precompute_H(self, processed_complex): 463 | if self.precomped_H is None: 464 | self.precomped_H = self.prop( 465 | processed_complex, 466 | self.feature_size, 467 | self.wavelength, 468 | self.distance, 469 | return_H=True, 470 | linear_conv=self.linear_conv) 471 | self.precomped_H = self.precomped_H.to(self.dev).detach() 472 | self.precomped_H.requires_grad = False 473 | 474 | def precompute_H_exp(self, processed_complex): 475 | if self.precomped_H_exp is None: 476 | self.precomped_H_exp = self.prop( 477 | processed_complex, 478 | self.feature_size, 479 | self.wavelength, 480 | self.distance, 481 | return_H_exp=True, 482 | linear_conv=self.linear_conv) 483 | self.precomped_H_exp = self.precomped_H_exp.to(self.dev).detach() 484 | self.precomped_H_exp.requires_grad = False 485 | 486 | def to(self, *args, **kwargs): 487 | slf = super().to(*args, **kwargs) 488 | if slf.zernike is not None: 489 | slf.zernike = slf.zernike.to(*args, **kwargs) 490 | if slf.zernike_eval is not None: 491 | slf.zernike_eval = slf.zernike_eval.to(*args, **kwargs) 492 | if slf.precomped_H is not None: 493 | slf.precomped_H = slf.precomped_H.to(*args, **kwargs) 494 | if slf.precomped_H_exp is not None: 495 | slf.precomped_H_exp = slf.precomped_H_exp.to(*args, **kwargs) 496 | # try setting dev based on some parameter, default to cpu 497 | try: 498 | slf.dev = next(slf.parameters()).device 499 | except StopIteration: # no parameters 500 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0] 501 | if device_arg is not None: 502 | slf.dev = device_arg 503 | return slf 504 | 505 | # override default training bool so we can detect eval/train switch 506 | @property 507 | def training(self): 508 | return self._training 509 | 510 | @training.setter 511 | def training(self, mode): 512 | if mode: 513 | self.zernike_eval = None # reset when switching to training 514 | self._training = mode 515 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | source ~/anaconda3/etc/profile.d/conda.sh 2 | conda create -n neural-holography python=3.6 scipy opencv 3 | conda activate neural-holography 4 | conda install -c conda-forge opencv 5 | conda install pytorch torchvision -c pytorch 6 | conda install -c conda-forge tensorboard 7 | conda install -c anaconda scikit-image 8 | pip install ConfigArgParse 9 | conda install -c conda-forge opencv 10 | pip install aotools 11 | 12 | -------------------------------------------------------------------------------- /train_holonet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Neural holography: 3 | 4 | This is the main script used for training the Holonet 5 | 6 | This code and data is 
released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 | 
11 | @article{Peng:2020:NeuralHolography,
12 | author = {Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein},
13 | title = {{Neural Holography with Camera-in-the-loop Training}},
14 | journal = {ACM Trans. Graph. (SIGGRAPH Asia)},
15 | year = {2020},
16 | }
17 | 
18 | Usage
19 | -----
20 | 
21 | $ python train_holonet.py --channel=1 --run_id=experiment_1
22 | 
23 | 
24 | """
25 | import os
26 | import sys
27 | import math
28 | import torch
29 | import torch.nn as nn
30 | import torch.optim as optim
31 | from datetime import datetime
32 | import configargparse
33 | from tensorboardX import SummaryWriter
34 | 
35 | import utils.utils as utils
36 | import utils.perceptualloss as perceptualloss
37 | 
38 | from propagation_model import ModelPropagate
39 | from holonet import *
40 | from utils.augmented_image_loader import ImageLoader
41 | 
42 | 
43 | # Command line argument processing
44 | p = configargparse.ArgumentParser()
45 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
46 | p.add_argument('--channel', type=int, default=1, help='red:0, green:1, blue:2, rgb:3')
47 | p.add_argument('--run_id', type=str, default='', help='Experiment name', required=True)
48 | p.add_argument('--proptype', type=str, default='ASM', help='Ideal propagation model')
49 | p.add_argument('--generator_path', type=str, default='', help='Torch save of Holonet, start from pre-trained gen.')
50 | p.add_argument('--model_path', type=str, default='./models', help='Torch save CITL-calibrated model')
51 | p.add_argument('--num_epochs', type=int, default=5, help='Number of epochs')
52 | p.add_argument('--batch_size', type=int, default=1, help='Size of minibatch')
53 | p.add_argument('--lr', type=float, default=1e-3, help='learning rate of Holonet weights')
54 | p.add_argument('--scale_output', type=float, default=0.95,
55 |                help='Scale of output applied to reconstructed intensity from SLM')
56 | p.add_argument('--loss_fun', type=str, default='vgg-low', help='Options: mse, l1, si_mse, vgg, vgg-low')
57 | p.add_argument('--purely_unet', type=utils.str2bool, default=False, help='Use U-Net for entire estimation if True')
58 | p.add_argument('--model_lut', type=utils.str2bool, default=True, help='Activate the LUT of the model')
59 | p.add_argument('--disable_loss_amp', type=utils.str2bool, default=True, help='Disable manual amplitude loss')
60 | p.add_argument('--num_latent_codes', type=int, default=2, help='Number of latent codes of trained prop model.')
61 | p.add_argument('--step_lr', type=utils.str2bool, default=False, help='Use of lr scheduler')
62 | p.add_argument('--perfect_prop_model', type=utils.str2bool, default=False,
63 |                help='Use ideal ASM as the loss function')
64 | p.add_argument('--manual_aberr_corr', type=utils.str2bool, default=True,
65 |                help='Divide source amplitude manually (and possibly apply the inverse Zernike in the primal domain)')
66 | 
67 | # parse arguments
68 | opt = p.parse_args()
69 | channel = opt.channel
70 | run_id = opt.run_id
71 | loss_fun = opt.loss_fun
72 | warm_start = opt.generator_path != ''
73 | chan_str = ('red', 'green', 'blue')[channel]
74 | 
75 | # tensorboard setup and file naming
76 | time_str = str(datetime.now()).replace(' ', '-').replace(':', '-')
77 | writer = SummaryWriter(f'runs/{run_id}_{chan_str}_{time_str}')
78 | 
79 | 
80 | ##############
81 | # Parameters #
82 | ##############
83 | 
84 | # units
85 | cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9
86 | 
87 | # Propagation parameters
88 | prop_dist = (20 * cm, 20 * cm, 20 * cm)[channel]
89 | wavelength = (638 * nm, 520 * nm, 450 * nm)[channel]
90 | feature_size = (6.4 * um, 6.4 * um)  # SLM pitch
91 | homography_res = (880, 1600)  # for CITL crop, see ImageLoader
92 | 
93 | # Training parameters
94 | device = torch.device('cuda')
95 | use_mse_init = False  # first 500 iters will be MSE regardless of loss_fun
96 | 
97 | # Image data for training
98 | data_path = '/data/train1080'  # path for training data
99 | 
100 | if opt.model_path == '':
101 |     opt.model_path = f'./models/{chan_str}.pth'
102 | 
103 | # resolutions need to be divisible by powers of 2 for unet
104 | if opt.purely_unet:
105 |     image_res = (1024, 2048)  # 10 down layers
106 | else:
107 |     image_res = (1072, 1920)  # 4 down layers
108 | 
109 | 
110 | ###############
111 | # Load models #
112 | ###############
113 | 
114 | # re-use parameters from the CITL-calibrated model for our Holonet.
115 | if opt.perfect_prop_model:
116 |     final_phase_num_in = 2
117 | 
118 |     # set model instance as naive ASM
119 |     model_prop = ModelPropagate(distance=prop_dist, feature_size=feature_size, wavelength=wavelength,
120 |                                 target_field=False, num_gaussians=0, num_coeffs_fourier=0,
121 |                                 use_conv1d_mlp=False, num_latent_codes=[0],
122 |                                 norm=None, blur=None, content_field=False, proptype=opt.proptype).to(device)
123 | 
124 |     zernike_coeffs = None
125 |     source_amplitude = None
126 |     latent_codes = None
127 |     u_t = None
128 | else:
129 |     if opt.manual_aberr_corr:
130 |         final_phase_num_in = 2 + opt.num_latent_codes
131 |     else:
132 |         final_phase_num_in = 4
133 |     blur = utils.make_kernel_gaussian(0.849, 3)
134 | 
135 |     # load camera model and set it into eval mode
136 |     model_prop = ModelPropagate(distance=prop_dist,
137 |                                 feature_size=feature_size,
138 |                                 wavelength=wavelength,
139 |                                 blur=blur).to(device)
140 |     model_prop.load_state_dict(torch.load(opt.model_path, map_location=device))
141 | 
142 |     # Here, we crop model parameters to match the Holonet resolution, which is slightly different from 1080p.
143 |     # parameters for CITL model
144 |     zernike_coeffs = model_prop.coeffs_fourier
145 |     source_amplitude = model_prop.source_amp
146 |     latent_codes = model_prop.latent_code
147 |     latent_codes = utils.crop_image(latent_codes, target_shape=image_res, pytorch=True, stacked_complex=False)
148 | 
149 |     # content independent target field (See Section 5.1.1.)
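    # The next lines rebuild the model's learned content-independent field at
    # the target plane, u_t = u_t_amp * exp(i * u_t_phase), from the calibrated
    # constant amplitude/phase maps, cropped from 1080p to the Holonet
    # resolution; it is passed to HoloNet below as target_field so this
    # additive field is accounted for when predicting the SLM phase.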
150 | u_t_amp = utils.crop_image(model_prop.target_constant_amp, target_shape=image_res, stacked_complex=False) 151 | u_t_phase = utils.crop_image(model_prop.target_constant_phase, target_shape=image_res, stacked_complex=False) 152 | real, imag = utils.polar_to_rect(u_t_amp, u_t_phase) 153 | u_t = torch.complex(real, imag) 154 | 155 | # match the shape of forward model parameters (1072, 1920) 156 | 157 | # If you make it nn.Parameter, the Holonet will include these parameters in the torch.save files 158 | model_prop.latent_code = nn.Parameter(latent_codes.detach(), requires_grad=False) 159 | model_prop.target_constant_amp = nn.Parameter(u_t_amp.detach(), requires_grad=False) 160 | model_prop.target_constant_phase = nn.Parameter(u_t_phase.detach(), requires_grad=False) 161 | 162 | # But as these parameters are already in the CITL-calibrated models, 163 | # you can load these from those models avoiding duplicated saves. 164 | 165 | model_prop.eval() # ensure freezing propagation model 166 | 167 | # create new phase generator object 168 | if opt.purely_unet: 169 | phase_generator = PhaseOnlyUnet(num_features_init=32).to(device) 170 | else: 171 | phase_generator = HoloNet( 172 | distance=prop_dist, 173 | wavelength=wavelength, 174 | zernike_coeffs=zernike_coeffs, 175 | source_amplitude=source_amplitude, 176 | initial_phase=InitialPhaseUnet(4, 16), 177 | final_phase_only=FinalPhaseOnlyUnet(4, 16, num_in=final_phase_num_in), 178 | manual_aberr_corr=opt.manual_aberr_corr, 179 | target_field=u_t, 180 | latent_codes=latent_codes, 181 | proptype=opt.proptype).to(device) 182 | 183 | if warm_start: 184 | phase_generator.load_state_dict(torch.load(opt.generator_path, map_location=device)) 185 | 186 | phase_generator.train() # generator to be trained 187 | 188 | 189 | ################### 190 | # Set up training # 191 | ################### 192 | 193 | # loss function 194 | loss_fun = ['mse', 'l1', 'si_mse', 'vgg', 'ssim', 'vgg-low', 'l1-vgg'].index(loss_fun.lower()) 195 | 196 | if loss_fun == 0: # MSE loss 197 | loss = nn.MSELoss() 198 | elif loss_fun == 1: # L1 loss 199 | loss = nn.L1Loss() 200 | elif loss_fun == 2: # scale invariant MSE loss 201 | loss = nn.MSELoss() 202 | elif loss_fun == 3: # vgg perceptual loss 203 | loss = perceptualloss.PerceptualLoss() 204 | elif loss_fun == 5: 205 | loss = perceptualloss.PerceptualLoss(lambda_feat=0.025) 206 | loss_fun = 3 207 | 208 | mse_loss = nn.MSELoss() 209 | 210 | # upload to GPU 211 | loss = loss.to(device) 212 | mse_loss = mse_loss.to(device) 213 | 214 | if loss_fun == 2: 215 | scaleLoss = torch.ones(1) 216 | scaleLoss = scaleLoss.to(device) 217 | scaleLoss.requires_grad = True 218 | 219 | optvars = [scaleLoss, *phase_generator.parameters()] 220 | else: 221 | optvars = phase_generator.parameters() 222 | 223 | # set aside the VGG loss for the first num_mse_epochs 224 | if loss_fun == 3: 225 | vgg_loss = loss 226 | loss = mse_loss 227 | 228 | # create optimizer 229 | if warm_start: 230 | opt.lr /= 10 231 | optimizer = optim.Adam(optvars, lr=opt.lr) 232 | 233 | # loads images from disk, set to augment with flipping 234 | image_loader = ImageLoader(data_path, 235 | channel=channel, 236 | batch_size=opt.batch_size, 237 | image_res=image_res, 238 | homography_res=homography_res, 239 | shuffle=True, 240 | vertical_flips=True, 241 | horizontal_flips=True) 242 | 243 | num_mse_iters = 500 244 | num_mse_epochs = 1 if use_mse_init else 0 245 | opt.num_epochs += 1 if use_mse_init else 0 246 | 247 | # learning rate scheduler 248 | if opt.step_lr: 249 | scheduler = 
optim.lr_scheduler.StepLR(optimizer, 500, 0.5) 250 | 251 | 252 | ################# 253 | # Training loop # 254 | ################# 255 | 256 | for i in range(opt.num_epochs): 257 | # switch to actual loss function from mse when desired 258 | if i == num_mse_epochs: 259 | if loss_fun == 3: 260 | loss = vgg_loss 261 | 262 | for k, target in enumerate(image_loader): 263 | # cap iters on the mse epoch(s) 264 | if i < num_mse_epochs and k == num_mse_iters: 265 | break 266 | 267 | # get target image 268 | target_amp, target_res, target_filename = target 269 | target_amp = target_amp.to(device) 270 | 271 | # cropping to desired region for loss 272 | if homography_res is not None: 273 | target_res = homography_res 274 | else: 275 | target_res = target_res[0] # use resolution of first image in batch 276 | 277 | optimizer.zero_grad() 278 | 279 | # forward model 280 | slm_amp, slm_phase = phase_generator(target_amp) 281 | output_complex = model_prop(slm_phase) 282 | 283 | if loss_fun == 2: 284 | output_complex = output_complex * scaleLoss 285 | 286 | output_lin_intensity = torch.sum(output_complex.abs()**2 * opt.scale_output, dim=1, keepdim=True) 287 | 288 | output_amp = torch.pow(output_lin_intensity, 0.5) 289 | 290 | # VGG assumes RGB input, we just replicate 291 | if loss_fun == 3: 292 | target_amp = target_amp.repeat(1, 3, 1, 1) 293 | output_amp = output_amp.repeat(1, 3, 1, 1) 294 | 295 | # ignore the cropping and do full-image loss 296 | if 'nocrop' in run_id: 297 | loss_nocrop = loss(output_amp, target_amp) 298 | 299 | # crop outputs to the region we care about 300 | output_amp = utils.crop_image(output_amp, target_res, stacked_complex=False) 301 | target_amp = utils.crop_image(target_amp, target_res, stacked_complex=False) 302 | 303 | # force equal mean amplitude when checking loss 304 | if 'force_scale' in run_id: 305 | print('scale forced equal', end=' ') 306 | with torch.no_grad(): 307 | scaled_out = output_amp * target_amp.mean() / output_amp.mean() 308 | output_amp = output_amp + (scaled_out - output_amp).detach() 309 | 310 | # loss and optimize 311 | loss_main = loss(output_amp, target_amp) 312 | if warm_start or opt.disable_loss_amp: 313 | loss_amp = 0 314 | else: 315 | # extra loss term to encourage uniform amplitude at SLM plane 316 | # helps some networks converge properly initially 317 | loss_amp = mse_loss(slm_amp.mean(), slm_amp) 318 | 319 | loss_val = ((loss_nocrop if 'nocrop' in run_id else loss_main) 320 | + 0.1 * loss_amp) 321 | loss_val.backward() 322 | optimizer.step() 323 | 324 | if opt.step_lr: 325 | scheduler.step() 326 | 327 | # print and output to tensorboard 328 | ik = k + i * len(image_loader) 329 | if use_mse_init and i >= num_mse_epochs: 330 | ik += num_mse_iters - len(image_loader) 331 | print(f'iteration {ik}: {loss_val.item()}') 332 | 333 | with torch.no_grad(): 334 | writer.add_scalar('Loss', loss_val, ik) 335 | 336 | if ik % 50 == 0: 337 | # write images and loss to tensorboard 338 | writer.add_image('Target Amplitude', target_amp[0, ...], ik) 339 | 340 | # normalize reconstructed amplitude 341 | output_amp0 = output_amp[0, ...] 342 | maxVal = torch.max(output_amp0) 343 | minVal = torch.min(output_amp0) 344 | tmp = (output_amp0 - minVal) / (maxVal - minVal) 345 | writer.add_image('Reconstruction Amplitude', tmp, ik) 346 | 347 | # normalize SLM phase 348 | writer.add_image('SLM Phase', (slm_phase[0, ...] 
+ math.pi) / (2 * math.pi), ik) 349 | 350 | if loss_fun == 2: 351 | writer.add_scalar('Scale factor', scaleLoss, ik) 352 | 353 | # save trained model 354 | if not os.path.isdir('checkpoints'): 355 | os.mkdir('checkpoints') 356 | torch.save(phase_generator.state_dict(), 357 | f'checkpoints/{run_id}_{chan_str}_{time_str}_{i+1}.pth') 358 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Neural holography: 3 | 4 | This is the main executive script used for training our parameterized wave propagation model 5 | 6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 8 | # The material is provided as-is, with no warranties whatsoever. 9 | # If you publish any code, data, or scientific work based on this, please cite our work. 10 | 11 | @article{Peng:2020:NeuralHolography, 12 | author = {Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein}, 13 | title = {{Neural Holography with Camera-in-the-loop Training}}, 14 | journal = {ACM Trans. Graph. (SIGGRAPH Asia)}, 15 | year = {2020}, 16 | } 17 | 18 | ----- 19 | 20 | $ python train_model.py --channel=1 --experiment=test 21 | 22 | """ 23 | 24 | 25 | import os 26 | import cv2 27 | import sys 28 | import time 29 | import torch 30 | import numpy as np 31 | import configargparse 32 | import skimage.util 33 | import torch.nn as nn 34 | import torch.optim as optim 35 | 36 | import utils.utils as utils 37 | from utils.modules import PhysicalProp 38 | from propagation_model import ModelPropagate 39 | from utils.augmented_image_loader import ImageLoader 40 | from utils.utils_tensorboard import SummaryModelWriter 41 | 42 | 43 | # Command line argument processing 44 | p = configargparse.ArgumentParser() 45 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.') 46 | 47 | p.add_argument('--channel', type=int, default=1, help='red:0, green:1, blue:2, rgb:3') 48 | p.add_argument('--pretrained_path', type=str, default='', help='Path of pretrained checkpoints as a starting point.') 49 | p.add_argument('--model_path', type=str, default='./models', help='Directory for saving out checkpoints') 50 | p.add_argument('--phase_path', type=str, default='./precomputed_phases', help='Directory for precalculated phases') 51 | p.add_argument('--calibration_path', type=str, default=f'./calibration', 52 | help='Directory where calibration phases are being stored.') 53 | p.add_argument('--lr_model', type=float, default=3e-3, help='Learning rate for model parameters') 54 | p.add_argument('--lr_phase', type=float, default=5e-3, help='Learning rate for phase') 55 | p.add_argument('--num_epochs', type=int, default=15, help='Number of epochs') 56 | p.add_argument('--batch_size', type=int, default=2, help='Size of minibatch') 57 | p.add_argument('--step_lr', type=utils.str2bool, default=True, help='Use of lr scheduler') 58 | p.add_argument('--experiment', type=str, default='', help='Name of the experiment') 59 | 60 | 61 | # parse arguments 62 | opt = p.parse_args() 63 | 64 | channel = opt.channel # Red:0 / Green:1 / Blue:2 65 | chan_str = ('red', 'green', 'blue')[channel] 66 | run_id = f'{chan_str}_{opt.experiment}_lr{opt.lr_model}_batchsize{opt.batch_size}' # {algorithm}_{prop_model} format 67 | 68 | print(f' - training 
parameterized wave propagation model...')
69 | 
70 | # units
71 | cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9
72 | 
73 | prop_dist = (20 * cm, 20 * cm, 20 * cm)[channel]  # propagation distance from SLM plane to target plane
74 | wavelength = (638 * nm, 520 * nm, 450 * nm)[channel]  # wavelength of each color
75 | feature_size = (6.4 * um, 6.4 * um)  # SLM pitch
76 | slm_res = (1080, 1920)  # resolution of SLM
77 | image_res = (1080, 1920)  # 1080p dataset
78 | roi_res = (880, 1600)  # regions of interest (to penalize)
79 | dtype = torch.float32  # default datatype (results may differ if using, e.g., float64)
80 | device = torch.device('cuda')  # the GPU you are using
81 | 
82 | # Options for the algorithm
83 | lr_s_phase = opt.lr_phase / 200
84 | loss_model = nn.MSELoss().to(device)  # loss function for SGD (or perceptualloss.PerceptualLoss())
85 | loss_phase = nn.MSELoss().to(device)
86 | loss_mse = nn.MSELoss().to(device)
87 | s0_phase = 1.0  # initial scale for phase optimization
88 | s0_model = 1.0  # initial scale for model training
89 | sa = torch.tensor(s0_phase, device=device, requires_grad=True)
90 | sb = torch.tensor(0.3, device=device, requires_grad=True)
91 | 
92 | num_iters_model_update = 1  # number of iterations for model-training subloops
93 | num_iters_phase_update = 1  # number of iterations for phase optimization
94 | 
95 | # Path for data
96 | result_path = './models'
97 | model_path = opt.model_path  # path for new model checkpoints
98 | utils.cond_mkdir(model_path)
99 | phase_path = opt.phase_path  # path of precomputed phase pool
100 | data_path = './data/train1080'  # path of targets
101 | 
102 | 
103 | # Hardware setup
104 | camera_prop = PhysicalProp(channel, laser_arduino=True, roi_res=(roi_res[1], roi_res[0]), slm_settle_time=0.15,
105 |                            range_row=(220, 1000), range_col=(300, 1630),
106 |                            patterns_path=opt.calibration_path,  # path of 21 x 12 calibration patterns, see Supplement.
107 |                            show_preview=True)
108 | 
109 | # Model instance to train
110 | # Check propagation_model.py for the default parameter settings!
111 | blur = utils.make_kernel_gaussian(0.85, 3)  # Optional, just be consistent with inference.
112 | model = ModelPropagate(distance=prop_dist,
113 |                        feature_size=feature_size,
114 |                        wavelength=wavelength,
115 |                        blur=blur).to(device)
116 | 
117 | if opt.pretrained_path != '':
118 |     print(f' - Start from pre-trained model: {opt.pretrained_path}')
119 |     checkpoint = torch.load(opt.pretrained_path)
120 |     model.load_state_dict(checkpoint)
121 | model = model.train()
122 | 
123 | # Augmented image loader (If you want to shuffle, augment dataset, put options accordingly)
124 | image_loader = ImageLoader(data_path,
125 |                            channel=channel,
126 |                            batch_size=opt.batch_size,
127 |                            image_res=image_res,
128 |                            homography_res=roi_res,
129 |                            crop_to_homography=False,
130 |                            shuffle=True,
131 |                            vertical_flips=False,
132 |                            horizontal_flips=False)
133 | 
134 | # optimizer for model training
135 | # Note that you can indeed set a different lr for each parameter group (especially for the Source Amplitude params),
136 | # but it works well with the same lr.
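# For instance, to train the source-amplitude Gaussians at their own rate, the
# corresponding parameter-group multiplier below could be raised, e.g. (an
# illustrative value, not a tuned recommendation):
#
#   {'params': model.source_amp.parameters(), 'lr': opt.lr_model * 10},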
137 | optimizer_model = optim.Adam([{'params': [param for name, param in model.named_parameters()
138 |                                           if 'source_amp' not in name and 'process_phase' not in name]},
139 |                               {'params': model.source_amp.parameters(), 'lr': opt.lr_model * 1},
140 |                               {'params': model.process_phase.parameters(), 'lr': opt.lr_model * 1}],
141 |                              lr=opt.lr_model)
142 | 
143 | optimizer_phase_scale = optim.Adam([sa, sb], lr=lr_s_phase)
144 | if opt.step_lr:
145 |     lr_scheduler = optim.lr_scheduler.StepLR(optimizer_model, step_size=5, gamma=0.2)  # decay lr by 1/5 every 5 epochs
146 | 
147 | # tensorboard writer
148 | summaries_dir = os.path.join('runs', run_id)
149 | utils.cond_mkdir(summaries_dir)
150 | writer = SummaryModelWriter(model, f'{summaries_dir}', ch=channel)
151 | 
152 | i_acc = 0
153 | for e in range(opt.num_epochs):
154 | 
155 |     print(f' - Epoch {e+1} ...')
156 |     # visualize all the modules in the model on tensorboard
157 |     with torch.no_grad():
158 |         writer.visualize_model(e)
159 | 
160 |     for i, target in enumerate(image_loader):
161 |         target_amp, _, target_filenames = target
162 | 
163 |         # extract indices of images
164 |         idxs = []
165 |         for name in target_filenames:
166 |             _, target_filename = os.path.split(name)
167 |             idxs.append(target_filename.split('_')[-1])
168 |         target_amp = utils.crop_image(target_amp, target_shape=roi_res, stacked_complex=False).to(device)
169 | 
170 |         # load phases
171 |         slm_phases = []
172 |         for k, idx in enumerate(idxs):
173 |             # Load pre-computed phases.
174 |             # Alternatively, you can optimize the phases from scratch with a few extra iterations here.
175 |             if e > 0:
176 |                 phase_filename = os.path.join(phase_path, f'{chan_str}', f'{idx}.png')
177 |             else:
178 |                 phase_filename = os.path.join(phase_path, f'{chan_str}', f'{idx}_{channel}', f'phasemaps_1000.png')
179 |             slm_phase = skimage.io.imread(phase_filename) / np.iinfo(np.uint8).max
180 | 
181 |             # invert phase (our SLM setup)
182 |             slm_phase = torch.tensor((1 - slm_phase) * 2 * np.pi - np.pi,
183 |                                      dtype=dtype).reshape(1, 1, *slm_res).to(device)
184 |             slm_phases.append(slm_phase)
185 |         slm_phases = torch.cat(slm_phases, 0).detach().requires_grad_(True)
186 | 
187 |         # optimizer for phase
188 |         optimizer_phase = optim.Adam([slm_phases], lr=opt.lr_phase)
189 | 
190 |         # 1) phase update loop
191 |         model = model.eval()
192 |         for j in range(max(e * num_iters_phase_update, 1)):  # at least one update; more in later epochs
193 |             optimizer_phase.zero_grad()
194 | 
195 |             # propagate forward through the model
196 |             recon_field = model(slm_phases)
197 |             recon_amp = recon_field.abs()
198 |             model_amp = utils.crop_image(recon_amp, target_shape=roi_res, pytorch=True, stacked_complex=False)
199 | 
200 |             # calculate loss and backpropagate to phase (the closed-form scale s = <m,t>/<m,m> minimizes ||s*m - t||^2 per image)
201 |             with torch.no_grad():
202 |                 scale_phase = (model_amp * target_amp).mean(dim=[-2, -1], keepdims=True) / \
203 |                               (model_amp**2).mean(dim=[-2, -1], keepdims=True)
204 | 
205 |                 # or we can optimize the scale with regression on statistics of the image:
206 |                 # scale_phase = target_amp.mean(dim=[-2, -1], keepdims=True).detach() * sa + sb
207 | 
208 |             loss_value_phase = loss_phase(scale_phase * model_amp, target_amp)
209 |             loss_value_phase.backward()
210 |             optimizer_phase.step()
211 |             optimizer_phase_scale.step()  # a no-op unless the regression-based scale above is used (sa/sb get no gradients otherwise)
212 | 
213 |         # write phase (update the per-channel phase pool that is read back at the next epoch)
214 |         with torch.no_grad():
215 |             for k, idx in enumerate(idxs):
216 |                 phase_out_8bit = utils.phasemap_8bit(slm_phases[k, np.newaxis, ...].cpu().detach(), inverted=True)
217 |                 cv2.imwrite(os.path.join(phase_path, f'{chan_str}', f'{idx}.png'), phase_out_8bit)
218 | 
219 |         # quantize the slm phases to the 8-bit values actually displayed
220 |         slm_phases = 
utils.quantized_phase(slm_phases) 221 | 222 | # 2) display and capture 223 | camera_amp = [] 224 | with torch.no_grad(): 225 | # forward physical pass (display), capture and stack them in batch dimension 226 | for k, idx in enumerate(idxs): 227 | slm_phase = slm_phases[k, np.newaxis, ...] 228 | camera_amp.append(camera_prop(slm_phase)) 229 | camera_amp = torch.cat(camera_amp, 0) 230 | 231 | # 3) model update loop 232 | model = model.train() 233 | for j in range(num_iters_model_update): 234 | 235 | # zero grad 236 | optimizer_model.zero_grad() 237 | 238 | # propagate forward through the model 239 | recon_field = model(slm_phases) 240 | recon_amp = recon_field.abs() 241 | model_amp = utils.crop_image(recon_amp, target_shape=roi_res, pytorch=True, stacked_complex=False) 242 | 243 | # calculate loss and backpropagate to model parameters 244 | loss_value_model = loss_model(model_amp, camera_amp) 245 | loss_value_model.backward() 246 | optimizer_model.step() 247 | 248 | # write to tensorboard 249 | with torch.no_grad(): 250 | if i % 50 == 0: 251 | writer.add_scalar('Scale/sa', sa, i_acc) 252 | writer.add_scalar('Scale/sb', sb, i_acc) 253 | for idx_s in range(opt.batch_size): 254 | writer.add_scalar(f'Scale/model_vs_target_{idx_s}', scale_phase[idx_s], i_acc) 255 | writer.add_scalar('Loss/model_vs_target', loss_value_phase, i_acc) 256 | writer.add_scalar('Loss/model_vs_camera', loss_value_model, i_acc) 257 | writer.add_scalar('Loss/camera_vs_target', loss_mse(camera_amp * target_amp.mean() / camera_amp.mean(), 258 | target_amp), i_acc) 259 | if i % 50 == 0: 260 | recon = model_amp[0, ...] 261 | captured = camera_amp[0, ...] 262 | gt = target_amp[0, ...] / scale_phase[0, ...] 263 | max_amp = max(recon.max(), captured.max(), gt.max()) 264 | writer.add_image('Amp/recon', recon / max_amp, i_acc) 265 | writer.add_image('Amp/captured', captured / max_amp, i_acc) 266 | writer.add_image('Amp/target', gt / max_amp, i_acc) 267 | 268 | i_acc += 1 269 | 270 | # save model, every epoch 271 | torch.save(model.state_dict(), os.path.join(model_path, f'{run_id}_{e}epoch.pth')) 272 | if opt.step_lr: 273 | lr_scheduler.step() 274 | 275 | # disconnect everything 276 | if camera_prop is not None: 277 | camera_prop.disconnect() 278 | camera_prop.alc.disconnect() 279 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/neural-holography/d2e399014aa80844edffd98bca34d2df80a69c84/utils/__init__.py -------------------------------------------------------------------------------- /utils/arduino_laser_control_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing Arduino controller module using pyfirmata. 3 | If you don't want to automate the laser control using python script, just turn off laser_arduino option of PhysicalProp. 4 | 5 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 6 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 7 | # The material is provided as-is, with no warranties whatsoever. 8 | # If you publish any code, data, or scientific work based on this, please cite our work. 9 | 10 | Technical Paper: 11 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. 
ACM TOG (SIGGRAPH Asia), 2020.
12 | """
13 | 
14 | 
15 | try:
16 |     from pyfirmata import Arduino, util
17 | except (ImportError, ModuleNotFoundError):
18 |     import pip
19 |     pip.main(['install', 'pyfirmata'])
20 |     from pyfirmata import Arduino, util
21 | 
22 | import time
23 | 
24 | 
25 | class ArduinoLaserControl:
26 | 
27 |     color2index = {'r': 0, 'g': 1, 'b': 2}
28 |     ind2color = ['r', 'g', 'b']
29 | 
30 |     def __init__(self, port='/dev/cu.usbmodem14301', pins=None):
31 |         """
32 | 
33 |         :param port: the port of your Arduino.
34 |                      You may find the port at
35 |                      your Arduino program -> Tool tab -> port (Mac)
36 |                      Device manager -> Ports (COM & LPT) (Windows)
37 | 
38 |         :param pins: an array of PWM-capable Arduino digital pins
39 |         """
40 |         self.board = Arduino(port)
41 |         self.pinNums = [6, 10, 11] if pins is None else pins
42 |         self.pins = {}
43 |         self.default_pin = self.board.get_pin(f'd:3:p')
44 | 
45 |         for c in self.color2index:
46 |             self.pins[c] = self.board.get_pin(f'd:{self.pinNums[self.color2index[c]]}:p')
47 | 
48 |     def setPins(self, pins):
49 |         """
50 |         sets the Arduino output pins used for control
51 | 
52 |         :param pins: an array of RGB pin numbers - PWM capable - on your Arduino Uno (e.g. [9, 10, 11])
53 |         """
54 |         self.pinNums = pins
55 |         for c in self.color2index:
56 |             self.pins[c] = self.board.get_pin(f'd:{self.pinNums[self.color2index[c]]}:p')
57 | 
58 |     def setValue(self, colors, values):
59 |         """
60 | 
61 |         :param colors: an array of chars ('r', 'g', 'b'); a single char (e.g. 'r') is acceptable
62 |         :param values: an array of normalized values (corresponding to the percentages in the control box),
63 |                        one for each color,
64 |                        e.g. [0.4, 0.1, 1].
65 |                        If you want identical values for all colors, just put a scalar.
66 |         """
67 | 
68 |         # check whether the parameter is a scalar or an array
69 |         if not isinstance(colors, list):
70 |             if len(colors) > 1:
71 |                 colors = colors[0]
72 |             colors = [colors]
73 | 
74 |         numColors = len(colors)
75 | 
76 |         if not isinstance(values, list):
77 |             values = [values] * numColors
78 | 
79 |         if len(values) != len(colors):
80 |             print(" - LASER CONTROL : Please put the same number of values to 'colors' and 'values' ")
81 |             return
82 | 
83 |         # turn on each color
84 |         for i in range(numColors):
85 | 
86 |             # Detect color
87 |             if colors[i] in self.color2index:
88 |                 pin = self.pins[colors[i]]
89 |             else:
90 |                 # colors must be 'r' or(and) 'g' or(and) 'b'
91 |                 print(" - LASER CONTROL: Wrong colors for 'setValue' method, it must be 'r' or(and) 'g' or(and) 'b'")
92 |                 return
93 | 
94 |             print(f' - V[{colors[i]}] from Arduino : {values[i]:.3f}V\n')
95 |             pin.write(values[i])
96 | 
97 |     def switch_control_box(self, channel):
98 |         """
99 |         switches the laser color through the control box
100 |         via its D-Sub 9-pin connector; note that it uses only a 4-bit encoding.
101 | 
102 |         R: 1100
103 |         G: 1010
104 |         B: 1001
105 | 
106 |         :param channel: integer, channel to switch to (Red:0, Green:1, Blue:2)
107 |         """
108 |         self.default_pin.write(1.0)
109 |         time.sleep(10.0)
110 | 
111 |         if channel in [0, 1, 2]:
112 |             self.pins[self.ind2color[channel]].write(1.0)
113 |             for c in [0, 1, 2]:
114 |                 if c != channel:
115 |                     self.pins[self.ind2color[c]].write(0.0)
116 |         else:
117 |             print('turning off')
118 |             for c in [0, 1, 2]:
119 |                 self.pins[self.ind2color[c]].write(0.0)
120 |             time.sleep(10.0)
121 |             self.default_pin.write(1.0)
122 | 
123 |     def turnOffAll(self):
124 |         """
125 |         Feeds 0000 to the control box
126 | 
127 |         :return:
128 |         """
129 |         for c in self.color2index:
130 |             pin = self.pins[c]
131 |             pin.write(0)
132 |         self.switch_control_box(3)
133 | 
134 |         print(' - Turned off')
135 | 
136 |     def disconnect(self):
137 |         self.turnOffAll()
138 | 
--------------------------------------------------------------------------------
/utils/augmented_image_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | from imageio import imread
6 | from skimage.transform import resize
7 | 
8 | import utils.utils as utils
9 | 
10 | 
11 | class ImageLoader:
12 |     """Loads images from a folder, with augmentation, for generator training
13 | 
14 |     Class initialization parameters
15 |     -------------------------------
16 |     data_path: folder containing images
17 |     channel: color channel to load (0, 1, 2 for R, G, B, None for all 3),
18 |         default None
19 |     batch_size: number of images to pass each iteration, default 1
20 |     image_res: 2d dimensions to pad/crop the image to for final output, default
21 |         (1080, 1920)
22 |     homography_res: 2d dims to scale the image to before final crop to image_res
23 |         for consistent resolutions (crops to preserve input aspect ratio),
24 |         default (880, 1600)
25 |     shuffle: True to randomize image order across batches, default True
26 |     vertical_flips: True to augment with vertical flipping, default True
27 |     horizontal_flips: True to augment with horizontal flipping, default True
28 |     idx_subset: for the iterator, skip all but these images. Given as a list of
29 |         indices corresponding to sorted filename order. Forces shuffle=False and
30 |         batch_size=1. Defaults to None to not subset at all.
31 |     crop_to_homography: if True, only crops the image instead of scaling to get
32 |         to target homography resolution, default False
33 | 
34 |     Usage
35 |     -----
36 |     To be used as an iterator:
37 | 
38 |     >>> image_loader = ImageLoader(...)
39 |     >>> for ims, input_resolutions, filenames in image_loader:
40 |     >>>     ...
41 | 
42 |     ims: images in the batch after transformation and conversion to linear
43 |         amplitude, with dimensions [batch, channel, height, width]
44 |     input_resolutions: list of length batch_size containing tuples of the
45 |         original image height/width before scaling/cropping
46 |     filenames: list of input image filenames, without extension
47 | 
48 |     Alternatively, it can be used to manually load a single image:
49 | 
50 |     >>> ims, input_resolutions, filenames = image_loader.load_image(idx)
51 | 
52 |     idx: the index for the image to load, indices are alphabetical based on the
53 |         file path.
54 | """ 55 | 56 | def __init__(self, data_path, channel=None, batch_size=1, 57 | image_res=(1080, 1920), homography_res=(880, 1600), 58 | shuffle=True, vertical_flips=True, horizontal_flips=True, 59 | idx_subset=None, crop_to_homography=False): 60 | if not os.path.isdir(data_path): 61 | raise NotADirectoryError(f'Data folder: {data_path}') 62 | self.data_path = data_path 63 | self.channel = channel 64 | self.batch_size = batch_size 65 | self.shuffle = shuffle 66 | self.image_res = image_res 67 | self.homography_res = homography_res 68 | self.subset = idx_subset 69 | self.crop_to_homography = crop_to_homography 70 | 71 | self.augmentations = [] 72 | if vertical_flips: 73 | self.augmentations.append(self.augment_vert) 74 | if horizontal_flips: 75 | self.augmentations.append(self.augment_horz) 76 | # store the possible states for enumerating augmentations 77 | self.augmentation_states = [fn() for fn in self.augmentations] 78 | 79 | self.im_names = get_image_filenames(data_path) 80 | self.im_names.sort() 81 | 82 | # if subsetting indices, force no randomization and batch size 1 83 | if self.subset is not None: 84 | self.shuffle = False 85 | self.batch_size = 1 86 | 87 | # create list of image IDs with augmentation state 88 | self.order = ((i,) for i in range(len(self.im_names))) 89 | for aug_type in self.augmentations: 90 | states = aug_type() # empty call gets possible states 91 | # augment existing list with new entry to states tuple 92 | self.order = ((*prev_states, s) 93 | for prev_states in self.order 94 | for s in states) 95 | self.order = list(self.order) 96 | 97 | def __iter__(self): 98 | self.ind = 0 99 | if self.shuffle: 100 | random.shuffle(self.order) 101 | return self 102 | 103 | def __next__(self): 104 | if self.subset is not None: 105 | while self.ind not in self.subset and self.ind < len(self.order): 106 | self.ind += 1 107 | 108 | if self.ind < len(self.order): 109 | batch_ims = self.order[self.ind:self.ind+self.batch_size] 110 | self.ind += self.batch_size 111 | return self.load_batch(batch_ims) 112 | else: 113 | raise StopIteration 114 | 115 | def __len__(self): 116 | if self.subset is None: 117 | return len(self.order) 118 | else: 119 | return len(self.subset) 120 | 121 | def load_batch(self, images): 122 | im_res_name = [self.load_image(*im_data) for im_data in images] 123 | ims = torch.stack([im for im, _, _ in im_res_name], 0) 124 | return (ims, 125 | [res for _, res, _ in im_res_name], 126 | [name for _, _, name in im_res_name]) 127 | 128 | def load_image(self, filenum, *augmentation_states): 129 | im = imread(self.im_names[filenum]) 130 | 131 | if len(im.shape) < 3: 132 | im = np.repeat(im[:, :, np.newaxis], 3, axis=2) # augment channels for gray images 133 | 134 | if self.channel is None: 135 | im = im[..., :3] # remove alpha channel, if any 136 | else: 137 | # select channel while keeping dims 138 | im = im[..., self.channel, np.newaxis] 139 | 140 | im = utils.im2float(im, dtype=np.float64) # convert to double, max 1 141 | 142 | # linearize intensity and convert to amplitude 143 | low_val = im <= 0.04045 144 | im[low_val] = 25 / 323 * im[low_val] 145 | im[np.logical_not(low_val)] = ((200 * im[np.logical_not(low_val)] + 11) 146 | / 211) ** (12 / 5) 147 | im = np.sqrt(im) # to amplitude 148 | 149 | # move channel dim to torch convention 150 | im = np.transpose(im, axes=(2, 0, 1)) 151 | 152 | # apply data augmentation 153 | for fn, state in zip(self.augmentations, augmentation_states): 154 | im = fn(im, state) 155 | 156 | # normalize resolution 157 | input_res = 
im.shape[-2:]
158 |         if self.crop_to_homography:
159 |             im = pad_crop_to_res(im, self.homography_res)
160 |         else:
161 |             im = resize_keep_aspect(im, self.homography_res)
162 |         im = pad_crop_to_res(im, self.image_res)
163 | 
164 |         return (torch.from_numpy(im).float(),
165 |                 input_res,
166 |                 os.path.splitext(self.im_names[filenum])[0])
167 | 
168 |     def augment_vert(self, image=None, flip=False):
169 |         if image is None:
170 |             return (True, False)  # return possible augmentation values
171 | 
172 |         if flip:
173 |             return image[..., ::-1, :]
174 |         return image
175 | 
176 |     def augment_horz(self, image=None, flip=False):
177 |         if image is None:
178 |             return (True, False)  # return possible augmentation values
179 | 
180 |         if flip:
181 |             return image[..., ::-1]
182 |         return image
183 | 
184 | 
185 | def get_image_filenames(dir):
186 |     """Returns all files in the input directory dir that are images"""
187 |     image_types = ('jpg', 'jpeg', 'tiff', 'tif', 'png', 'bmp', 'gif')
188 |     files = os.listdir(dir)
189 |     exts = (os.path.splitext(f)[1] for f in files)
190 |     images = [os.path.join(dir, f)
191 |               for e, f in zip(exts, files)
192 |               if e[1:] in image_types]
193 |     return images
194 | 
195 | 
196 | def resize_keep_aspect(image, target_res, pad=False):
197 |     """Resizes image to the target_res while keeping aspect ratio by cropping
198 | 
199 |     image: a 3d array with dims [channel, height, width]
200 |     target_res: [height, width]
201 |     pad: if True, will pad zeros instead of cropping to preserve aspect ratio
202 |     """
203 |     im_res = image.shape[-2:]
204 | 
205 |     # finds the resolution needed for either dimension to have the target aspect
206 |     # ratio, when the other is kept constant. If the image doesn't have the
207 |     # target ratio, then one of these two will be larger, and the other smaller,
208 |     # than the current image dimensions
209 |     resized_res = (int(np.ceil(im_res[1] * target_res[0] / target_res[1])),
210 |                    int(np.ceil(im_res[0] * target_res[1] / target_res[0])))
211 | 
212 |     # only pads smaller or crops larger dims, meaning that the resulting image
213 |     # size will be the target aspect ratio after a single pad/crop to the
214 |     # resized_res dimensions
215 |     if pad:
216 |         image = utils.pad_image(image, resized_res, pytorch=False)
217 |     else:
218 |         image = utils.crop_image(image, resized_res, pytorch=False)
219 | 
220 |     # switch to numpy channel dim convention, resize, switch back
221 |     image = np.transpose(image, axes=(1, 2, 0))
222 |     image = resize(image, target_res, mode='reflect')
223 |     return np.transpose(image, axes=(2, 0, 1))
224 | 
225 | 
226 | def pad_crop_to_res(image, target_res):
227 |     """Pads with 0 and crops as needed to force image to be target_res
228 | 
229 |     image: an array with dims [..., channel, height, width]
230 |     target_res: [height, width]
231 |     """
232 |     return utils.crop_image(utils.pad_image(image,
233 |                                             target_res, pytorch=False),
234 |                             target_res, pytorch=False)
235 | 
--------------------------------------------------------------------------------
/utils/calibration_module.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the script containing the calibration module, which essentially computes the homography matrix.
3 | 
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 | 
9 | Technical Paper:
10 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020.
11 | """
12 | 
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | import cv2
16 | 
17 | 
18 | def circle_detect(captured_img, num_circles, spacing, pad_pixels=(0., 0.), show_preview=True):
19 |     """
20 |     Detects the circles of a circle board pattern
21 | 
22 |     :param captured_img: captured image
23 |     :param num_circles: a tuple of integers, (num_circle_x, num_circle_y)
24 |     :param spacing: a tuple of integers, in pixels, (space between circles in x, space btw circs in y direction)
25 |     :param pad_pixels: a tuple, the top-left coordinate of the reference circle grid in the warped image;
26 |                        the same amount of padding is assumed on the opposite side.
27 |     :param show_preview: boolean, default True
28 |     :return: a tuple, (found_dots, H)
29 |              found_dots: boolean, indicating success of calibration
30 |              H: a 3x3 homography matrix (numpy)
31 |     """
32 | 
33 |     # Binarization
34 |     # org_copy = org.copy()  # Otherwise, we write on the original image!
35 |     img = (captured_img.copy() * 255).astype(np.uint8)
36 |     if len(img.shape) > 2:
37 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
38 | 
39 |     img = cv2.medianBlur(img, 15)
40 |     img_gray = img.copy()
41 | 
42 |     img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 121, 0)
43 |     kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
44 |     img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
45 |     img = 255 - img
46 | 
47 |     # Blob detection
48 |     params = cv2.SimpleBlobDetector_Params()
49 | 
50 |     # Change thresholds
51 |     params.filterByColor = True
52 |     params.minThreshold = 128
53 | 
54 |     # Filter by Area.
55 | params.filterByArea = True 56 | params.minArea = 50 57 | 58 | # Filter by Circularity 59 | params.filterByCircularity = True 60 | params.minCircularity = 0.785 61 | 62 | # Filter by Convexity 63 | params.filterByConvexity = True 64 | params.minConvexity = 0.87 65 | 66 | # Filter by Inertia 67 | params.filterByInertia = False 68 | params.minInertiaRatio = 0.01 69 | 70 | detector = cv2.SimpleBlobDetector_create(params) 71 | 72 | # Detecting keypoints 73 | # this is redundant for what comes next, but gives us access to the detected dots for debug 74 | keypoints = detector.detect(img) 75 | found_dots, centers = cv2.findCirclesGrid(img, num_circles, 76 | blobDetector=detector, flags=cv2.CALIB_CB_SYMMETRIC_GRID) 77 | 78 | # Drawing the keypoints 79 | cv2.drawChessboardCorners(captured_img, num_circles, centers, found_dots) 80 | img_gray = cv2.drawKeypoints(img_gray, keypoints, np.array([]), (0, 255, 0), 81 | cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS) 82 | 83 | # Find transformation 84 | H = np.array([[1., 0., 0.], 85 | [0., 1., 0.], 86 | [0., 0., 1.]], dtype=np.float32) 87 | if found_dots: 88 | # Generate reference points to compute the homography 89 | ref_pts = np.zeros((num_circles[0] * num_circles[1], 1, 2), np.float32) 90 | pos = 0 91 | for i in range(0, num_circles[1]): 92 | for j in range(0, num_circles[0]): 93 | ref_pts[pos, 0, :] = spacing * np.array([j, i]) + np.array(pad_pixels) 94 | pos += 1 95 | 96 | H, mask = cv2.findHomography(centers, ref_pts, cv2.RANSAC, 1) 97 | if show_preview: 98 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs) 99 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels)] 100 | captured_img_warp = cv2.warpPerspective(captured_img, H, tuple(dsize)) 101 | 102 | if show_preview: 103 | fig = plt.figure() 104 | 105 | ax = fig.add_subplot(223) 106 | ax.imshow(img_gray, cmap='gray') 107 | 108 | ax2 = fig.add_subplot(221) 109 | ax2.imshow(img, cmap='gray') 110 | 111 | ax3 = fig.add_subplot(222) 112 | ax3.imshow(captured_img, cmap='gray') 113 | 114 | if found_dots: 115 | ax4 = fig.add_subplot(224) 116 | ax4.imshow(captured_img_warp, cmap='gray') 117 | 118 | plt.show() 119 | 120 | return found_dots, H 121 | 122 | 123 | class Calibration: 124 | def __init__(self, num_circles=(21, 12), spacing_size=(80, 80), pad_pixels=(0, 0)): 125 | self.num_circles = num_circles 126 | self.spacing_size = spacing_size 127 | self.pad_pixels = pad_pixels 128 | self.h_transform = np.array([[1., 0., 0.], 129 | [0., 1., 0.], 130 | [0., 0., 1.]]) 131 | 132 | def calibrate(self, img, show_preview=True): 133 | found_corners, self.h_transform = circle_detect(img, self.num_circles, 134 | self.spacing_size, self.pad_pixels, show_preview) 135 | return found_corners 136 | 137 | def get_transform(self): 138 | return self.h_transform 139 | 140 | def __call__(self, input_img, img_size=None): 141 | """ 142 | This forward pass returns the warped image. 143 | 144 | :param input_img: A numpy grayscale image shape of [H, W]. 145 | :param img_size: output size, default None. 146 | :return: output_img: warped image with pre-calculated homography and destination size. 
147 | """ 148 | 149 | if img_size is None: 150 | img_size = [int((num_circs - 1) * space + 2 * pad_pixs) 151 | for num_circs, space, pad_pixs in zip(self.num_circles, self.spacing_size, self.pad_pixels)] 152 | output_img = cv2.warpPerspective(input_img, self.h_transform, tuple(img_size)) 153 | 154 | return output_img 155 | -------------------------------------------------------------------------------- /utils/camera_capture_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing Camera module (PyCapture). 3 | Refer to this interface and modify it to match your camera SDK 4 | 5 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 6 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 7 | # The material is provided as-is, with no warranties whatsoever. 8 | # If you publish any code, data, or scientific work based on this, please cite our work. 9 | 10 | Technical Paper: 11 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020. 12 | """ 13 | 14 | import PyCapture2 15 | import cv2 16 | import numpy as np 17 | import time 18 | import utils.utils as utils 19 | 20 | 21 | def callback_captured(image): 22 | print(image.getData()) 23 | 24 | 25 | class CameraCapture: 26 | def __init__(self): 27 | self.bus = PyCapture2.BusManager() 28 | num_cams = self.bus.getNumOfCameras() 29 | if not num_cams: 30 | exit() 31 | 32 | def connect(self, i): 33 | uid = self.bus.getCameraFromIndex(i) 34 | self.camera_device = PyCapture2.Camera() 35 | self.camera_device.connect(uid) 36 | self.toggle_embedded_timestamp(True) 37 | 38 | def disconnect(self): 39 | self.toggle_embedded_timestamp(False) 40 | self.camera_device.disconnect() 41 | 42 | def toggle_embedded_timestamp(self, enable_timestamp): 43 | embedded_info = self.camera_device.getEmbeddedImageInfo() 44 | if embedded_info.available.timestamp: 45 | self.camera_device.setEmbeddedImageInfo(timestamp=enable_timestamp) 46 | 47 | def grab_images(self, num_images_to_grab=1): 48 | """ 49 | Retrieve the camera buffer and returns a list of grabbed images. 50 | 51 | :param num_images_to_grab: integer, default 1 52 | :return: a list of numpy 2d color images from the camera buffer. 53 | """ 54 | self.camera_device.startCapture() 55 | 56 | img_list = [] 57 | for i in range(num_images_to_grab): 58 | try: 59 | img = self.camera_device.retrieveBuffer() 60 | except PyCapture2.Fc2error as fc2Err: 61 | continue 62 | 63 | imgData = img.getData() 64 | 65 | # when using raw8 from the PG sensor 66 | # cv_image = np.array(img.getData(), dtype="uint8").reshape((img.getRows(), img.getCols())) 67 | 68 | # when using raw16 from the PG sensor - concat 2 8bits in a row 69 | imgData.dtype = np.uint16 70 | imgData = imgData.reshape(img.getRows(), img.getCols()) 71 | offset = 64 # offset that inherently exist. 
72 | imgData = imgData - offset 73 | 74 | color_cv_image = cv2.cvtColor(imgData, cv2.COLOR_BAYER_RG2BGR) 75 | color_cv_image = utils.im2float(color_cv_image) 76 | img_list.append(color_cv_image.copy()) 77 | 78 | self.camera_device.stopCapture() 79 | return img_list 80 | 81 | def start_capture(self): 82 | # these two were previously inside the grab_images func, and can be clarified outside the loop 83 | self.camera_device.startCapture() 84 | 85 | def stop_capture(self): 86 | self.camera_device.stopCapture() 87 | -------------------------------------------------------------------------------- /utils/detect_heds_module_path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #--------------------------------------------------------------------# 4 | # # 5 | # Copyright (C) 2020 HOLOEYE Photonics AG. All rights reserved. # 6 | # Contact: https://holoeye.com/contact/ # 7 | # # 8 | # This file is part of HOLOEYE SLM Display SDK. # 9 | # # 10 | # You may use this file under the terms and conditions of the # 11 | # 'HOLOEYE SLM Display SDK Standard License v1.0' license agreement. # 12 | # # 13 | #--------------------------------------------------------------------# 14 | 15 | 16 | # Please import this file in your scripts before actually importing the HOLOEYE SLM Display SDK, 17 | # i. e. copy this file to your project and use this code in your scripts: 18 | # 19 | # import detect_heds_module_path 20 | # import holoeye 21 | # 22 | # 23 | # Another option is to copy the holoeye module directory into your project and import by only using 24 | # import holoeye 25 | # This way, code completion etc. might work better. 26 | 27 | 28 | import os, sys 29 | from platform import system 30 | 31 | # Import the SLM Display SDK: 32 | HEDSModulePath = os.getenv('HEDS_2_PYTHON_MODULES', '') 33 | 34 | if HEDSModulePath == '': 35 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), 36 | 'holoeye', 'slmdisplaysdk', '__init__.py')) 37 | if os.path.isfile(sdklocal): 38 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal))) 39 | else: 40 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 41 | 'sdk', 'holoeye', 'slmdisplaysdk', '__init__.py')) 42 | if os.path.isfile(sdklocal): 43 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal))) 44 | 45 | if HEDSModulePath == '': 46 | if system() == 'Windows': 47 | print('\033[91m' 48 | '\nError: Could not find HOLOEYE SLM Display SDK installation path from environment variable. ' 49 | '\n\nPlease relogin your Windows user account and try again. ' 50 | '\nIf that does not help, please reinstall the SDK and then relogin your user account and try again. ' 51 | '\nA simple restart of the computer might fix the problem, too.' 52 | '\033[0m') 53 | else: 54 | print('\033[91m' 55 | '\nError: Could not detect HOLOEYE SLM Display SDK installation path. ' 56 | '\n\nPlease make sure it is present within the same folder or in "../../sdk".' 57 | '\033[0m') 58 | 59 | sys.exit(1) 60 | 61 | sys.path.append(HEDSModulePath) 62 | -------------------------------------------------------------------------------- /utils/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some modules for easy use. 
(No need to calculate kernels explicitly.)
3 | """
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from algorithms import gerchberg_saxton, stochastic_gradient_descent, double_phase_amplitude_coding
8 | 
9 | import os
10 | import time
11 | import skimage.io
12 | import utils.utils as utils
13 | import platform
14 | my_os = platform.system()
15 | if my_os == 'Windows':
16 |     from utils.arduino_laser_control_module import ArduinoLaserControl
17 |     from utils.camera_capture_module import CameraCapture
18 |     from utils.calibration_module import Calibration
19 |     from utils.slm_display_module import SLMDisplay
20 | 
21 | 
22 | class GS(nn.Module):
23 |     """Classical Gerchberg-Saxton algorithm
24 | 
25 |     Class initialization parameters
26 |     -------------------------------
27 |     :param prop_dist: propagation dist between SLM and target, in meters
28 |     :param wavelength: the wavelength of interest, in meters
29 |     :param feature_size: the SLM pixel pitch, in meters
30 |     :param num_iters: the number of iterations, default 500
31 |     :param phase_path: path to write intermediate results
32 |     :param writer: SummaryWriter instance for tensorboard
33 |     :param prop_model: chooses the propagation operator ('ASM': propagation_ASM,
34 |         'model': calibrated model). Default 'ASM'.
35 |     :param propagator: propagator instance (function / pytorch module)
36 |     :param device: torch.device
37 | 
38 |     Usage
39 |     -----
40 |     Functions as a pytorch module:
41 | 
42 |     >>> gs = GS(...)
43 |     >>> final_phase = gs(target_amp, init_phase)
44 | 
45 |     target_amp: amplitude at the target plane, with dimensions [batch, 1, height, width]
46 |     init_phase: initial guess of the phase for the phase-only SLM
47 |     final_phase: optimized phase-only representation at SLM plane, same dimensions
48 |     """
49 |     def __init__(self, prop_dist, wavelength, feature_size, num_iters, phase_path=None,
50 |                  prop_model='ASM', propagator=None, writer=None, device=torch.device('cuda')):
51 |         super(GS, self).__init__()
52 | 
53 |         # Setting parameters
54 |         self.prop_dist = prop_dist
55 |         self.wavelength = wavelength
56 |         self.feature_size = feature_size
57 |         self.phase_path = phase_path
58 |         self.precomputed_H_f = None
59 |         self.precomputed_H_b = None
60 |         self.prop_model = prop_model
61 |         self.prop = propagator
62 |         self.num_iters = num_iters
63 |         self.writer = writer
64 |         self.dev = device
65 | 
66 |     def forward(self, target_amp, init_phase=None):
67 |         # Pre-compute the propagation kernel only once
68 |         if self.precomputed_H_f is None and self.prop_model == 'ASM':
69 |             self.precomputed_H_f = self.prop(torch.empty(*init_phase.shape, dtype=torch.complex64), self.feature_size,
70 |                                              self.wavelength, self.prop_dist, return_H=True)
71 |             self.precomputed_H_f = self.precomputed_H_f.to(self.dev).detach()
72 |             self.precomputed_H_f.requires_grad = False
73 | 
74 |         if self.precomputed_H_b is None and self.prop_model == 'ASM':
75 |             self.precomputed_H_b = self.prop(torch.empty(*init_phase.shape, dtype=torch.complex64), self.feature_size,
76 |                                              self.wavelength, -self.prop_dist, return_H=True)
77 |             self.precomputed_H_b = self.precomputed_H_b.to(self.dev).detach()
78 |             self.precomputed_H_b.requires_grad = False
79 | 
80 |         # Run algorithm
81 |         final_phase = gerchberg_saxton(init_phase, target_amp, self.num_iters, self.prop_dist,
82 |                                        self.wavelength, self.feature_size,
83 |                                        phase_path=self.phase_path,
84 |                                        prop_model=self.prop_model, propagator=self.prop,
85 |                                        precomputed_H_f=self.precomputed_H_f, precomputed_H_b=self.precomputed_H_b,
86 |                                        writer=self.writer)
87 |         return final_phase
88 | 
89 |     @property
90 |     def phase_path(self):
91 |         return self._phase_path
92 | 
93 |     @phase_path.setter
94 |     def phase_path(self, phase_path):
95 |         self._phase_path = phase_path
96 | 
97 | 
98 | class SGD(nn.Module):
99 |     """Proposed Stochastic Gradient Descent Algorithm using the Auto-diff Function of PyTorch
100 | 
101 |     Class initialization parameters
102 |     -------------------------------
103 |     :param prop_dist: propagation dist between SLM and target, in meters
104 |     :param wavelength: the wavelength of interest, in meters
105 |     :param feature_size: the SLM pixel pitch, in meters
106 |     :param num_iters: the number of iterations, default 500
107 |     :param roi_res: region of interest to penalize the loss
108 |     :param phase_path: path to write intermediate results
109 |     :param loss: loss function, default L2
110 |     :param prop_model: chooses the propagation operator ('ASM': propagation_ASM,
111 |         'model': calibrated model). Default 'ASM'.
112 |     :param propagator: propagator instance (function / pytorch module)
113 |     :param lr: learning rate for phase variables
114 |     :param lr_s: learning rate for the learnable scale
115 |     :param s0: initial scale
116 |     :param writer: SummaryWriter instance for tensorboard
117 |     :param device: torch.device
118 | 
119 |     Usage
120 |     -----
121 |     Functions as a pytorch module:
122 | 
123 |     >>> sgd = SGD(...)
124 |     >>> final_phase = sgd(target_amp, init_phase)
125 | 
126 |     target_amp: amplitude at the target plane, with dimensions [batch, 1, height, width]
127 |     init_phase: initial guess of the phase for the phase-only SLM
128 |     final_phase: optimized phase-only representation at SLM plane, same dimensions
129 |     """
130 |     def __init__(self, prop_dist, wavelength, feature_size, num_iters, roi_res, phase_path=None, prop_model='ASM',
131 |                  propagator=None, loss=nn.MSELoss(), lr=0.01, lr_s=0.003, s0=1.0, citl=False, camera_prop=None,
132 |                  writer=None, device=torch.device('cuda')):
133 |         super(SGD, self).__init__()
134 | 
135 |         # Setting parameters
136 |         self.prop_dist = prop_dist
137 |         self.wavelength = wavelength
138 |         self.feature_size = feature_size
139 |         self.roi_res = roi_res
140 |         self.phase_path = phase_path
141 |         self.precomputed_H = None
142 |         self.prop_model = prop_model
143 |         self.prop = propagator
144 | 
145 |         self.num_iters = num_iters
146 |         self.lr = lr
147 |         self.lr_s = lr_s
148 |         self.init_scale = s0
149 | 
150 |         self.citl = citl
151 |         self.camera_prop = camera_prop
152 | 
153 |         self.writer = writer
154 |         self.dev = device
155 |         self.loss = loss.to(device)
156 | 
157 |     def forward(self, target_amp, init_phase=None):
158 |         # Pre-compute the propagation kernel only once
159 |         if self.precomputed_H is None and self.prop_model == 'ASM':
160 |             self.precomputed_H = self.prop(torch.empty(*init_phase.shape, dtype=torch.complex64), self.feature_size,
161 |                                            self.wavelength, self.prop_dist, return_H=True)
162 |             self.precomputed_H = self.precomputed_H.to(self.dev).detach()
163 |             self.precomputed_H.requires_grad = False
164 | 
165 |         # Run algorithm
166 |         final_phase = stochastic_gradient_descent(init_phase, target_amp, self.num_iters, self.prop_dist,
167 |                                                   self.wavelength, self.feature_size,
168 |                                                   roi_res=self.roi_res, phase_path=self.phase_path,
169 |                                                   prop_model=self.prop_model, propagator=self.prop,
170 |                                                   loss=self.loss, lr=self.lr, lr_s=self.lr_s, s0=self.init_scale,
171 |                                                   citl=self.citl, camera_prop=self.camera_prop,
172 |                                                   writer=self.writer,
173 |                                                   precomputed_H=self.precomputed_H)
174 |         return final_phase
175 | 
176 |     @property
177 |     def init_scale(self):
178 |         return self._init_scale
179 | 
180 |     @init_scale.setter
181 |     def init_scale(self, s):
182 |         self._init_scale = s
183 | 
184 |     @property
185 |     def citl_hardwares(self):
186 |         return self._citl_hardwares
187 | 
188 |     @citl_hardwares.setter
189 |     def citl_hardwares(self, citl_hardwares):
190 |         self._citl_hardwares = citl_hardwares
191 | 
192 |     @property
193 |     def phase_path(self):
194 |         return self._phase_path
195 | 
196 |     @phase_path.setter
197 |     def phase_path(self, phase_path):
198 |         self._phase_path = phase_path
199 | 
200 |     @property
201 |     def prop(self):
202 |         return self._prop
203 | 
204 |     @prop.setter
205 |     def prop(self, prop):
206 |         self._prop = prop
207 | 
208 | 
209 | class DPAC(nn.Module):
210 |     """Double-phase Amplitude Coding
211 | 
212 |     Class initialization parameters
213 |     -------------------------------
214 |     :param prop_dist: propagation dist between SLM and target, in meters
215 |     :param wavelength: the wavelength of interest, in meters
216 |     :param feature_size: the SLM pixel pitch, in meters
217 |     :param prop_model: chooses the propagation operator ('ASM': propagation_ASM,
218 |         'model': calibrated model). Default 'ASM'.
219 |     :param propagator: propagator instance (function / pytorch module)
220 |     :param device: torch.device
221 | 
222 |     Usage
223 |     -----
224 |     Functions as a pytorch module:
225 | 
226 |     >>> dpac = DPAC(...)
227 |     >>> _, final_phase = dpac(target_amp, target_phase)
228 | 
229 |     target_amp: amplitude at the target plane, with dimensions [batch, 1, height, width]
230 |     target_phase (optional): phase at the target plane, with dimensions [batch, 1, height, width]
231 |     final_phase: optimized phase-only representation at SLM plane, same dimensions
232 | 
233 |     """
234 |     def __init__(self, prop_dist, wavelength, feature_size, prop_model='ASM', propagator=None,
235 |                  device=torch.device('cuda')):
236 |         """
237 | 
238 |         """
239 |         super(DPAC, self).__init__()
240 | 
241 |         # propagation is from target to SLM plane (one step)
242 |         self.prop_dist = -prop_dist
243 |         self.wavelength = wavelength
244 |         self.feature_size = feature_size
245 |         self.precomputed_H = None
246 |         self.prop_model = prop_model
247 |         self.prop = propagator
248 |         self.dev = device
249 | 
250 |     def forward(self, target_amp, target_phase=None):
251 |         if target_phase is None:
252 |             target_phase = torch.zeros_like(target_amp)
253 | 
254 |         if self.precomputed_H is None and self.prop_model == 'ASM':
255 |             self.precomputed_H = self.prop(torch.empty(*target_amp.shape, dtype=torch.complex64), self.feature_size,
256 |                                            self.wavelength, self.prop_dist, return_H=True)
257 |             self.precomputed_H = self.precomputed_H.to(self.dev).detach()
258 |             self.precomputed_H.requires_grad = False
259 | 
260 |         final_phase = double_phase_amplitude_coding(target_phase, target_amp, self.prop_dist,
261 |                                                     self.wavelength, self.feature_size,
262 |                                                     prop_model=self.prop_model, propagator=self.prop,
263 |                                                     precomputed_H=self.precomputed_H)
264 |         return None, final_phase
265 | 
266 |     @property
267 |     def phase_path(self):
268 |         return self._phase_path
269 | 
270 |     @phase_path.setter
271 |     def phase_path(self, phase_path):
272 |         self._phase_path = phase_path
273 | 
274 | 
275 | 
276 | class PhysicalProp(nn.Module):
277 |     """ A module for physical propagation:
278 |     the forward pass takes an SLM phase pattern as input, displays it on the physical setup,
279 |     captures the diffracted image at the target plane,
280 |     and returns the captured image warped by the homography pre-calibrated at instantiation.
281 | 282 | Class initialization parameters 283 | ------------------------------- 284 | :param channel: 285 | :param slm_settle_time: 286 | :param roi_res: *** Note that the order of x / y is reversed here *** 287 | :param num_circles: 288 | :param laser_arduino: 289 | :param com_port: 290 | :param arduino_port_num: 291 | :param range_row: 292 | :param range_col: 293 | :param patterns_path: 294 | :param calibration_preview: 295 | 296 | Usage 297 | ----- 298 | Functions as a pytorch module: 299 | 300 | >>> camera_prop = PhysicalProp(...) 301 | >>> captured_amp = camera_prop(slm_phase) 302 | 303 | slm_phase: phase at the SLM plane, with dimensions [batch, 1, height, width] 304 | captured_amp: amplitude at the target plane, with dimensions [batch, 1, height, width] 305 | 306 | """ 307 | def __init__(self, channel=1, slm_settle_time=0.1, roi_res=(1600, 880), num_circles=(21, 12), 308 | laser_arduino=False, com_port='COM3', arduino_port_num=(6, 10, 11), 309 | range_row=(200, 1000), range_col=(300, 1700), 310 | patterns_path=f'F:/citl/calibration', show_preview=False): 311 | super(PhysicalProp, self).__init__() 312 | 313 | # 1. Connect Camera 314 | self.camera = CameraCapture() 315 | self.camera.connect(0) # specify the camera to use, 0 for main cam, 1 for the second cam 316 | 317 | # 2. Connect SLM 318 | self.slm = SLMDisplay() 319 | self.slm.connect() 320 | self.slm_settle_time = slm_settle_time 321 | 322 | # 3. Connect to the Arduino that switches rgb color through the laser control box. 323 | if laser_arduino: 324 | self.alc = ArduinoLaserControl(com_port, arduino_port_num) 325 | self.alc.switch_control_box(channel) 326 | else: 327 | self.alc = None 328 | 329 | # 4. Calibrate hardwares using homography 330 | calib_ptrn_path = os.path.join(patterns_path, f'{("red", "green", "blue")[channel]}.png') 331 | space_btw_circs = [int(roi / (num_circs - 1)) for roi, num_circs in zip(roi_res, num_circles)] 332 | 333 | self.calibrate(calib_ptrn_path, num_circles, space_btw_circs, 334 | range_row=range_row, range_col=range_col, show_preview=show_preview) 335 | 336 | def calibrate(self, calibration_pattern_path, num_circles, space_btw_circs, 337 | range_row, range_col, show_preview=False, num_grab_images=10): 338 | """ 339 | pre-calculate the homography between target plane and the camera captured plane 340 | 341 | :param calibration_pattern_path: 342 | :param num_circles: 343 | :param space_btw_circs: number of pixels between circles 344 | :param slm_settle_time: 345 | :param range_row: 346 | :param range_col: 347 | :param show_preview: 348 | :param num_grab_images: 349 | :return: 350 | """ 351 | 352 | self.calibrator = Calibration(num_circles, space_btw_circs) 353 | 354 | # supposed to be a grid pattern image (21 x 12) for calibration 355 | calib_phase_img = skimage.io.imread(calibration_pattern_path) 356 | self.slm.show_data_from_array(calib_phase_img) 357 | 358 | # sleep for 0.1s 359 | time.sleep(self.slm_settle_time) 360 | 361 | # capture displayed grid pattern image 362 | captured_intensities = self.camera.grab_images(num_grab_images) # capture 5-10 images for averaging 363 | captured_img = utils.burst_img_processor(captured_intensities) 364 | 365 | # masking out dot pattern region for homography 366 | captured_img_masked = captured_img[range_row[0]:range_row[1], range_col[0]:range_col[1], ...] 
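        # Note: start_row/end_row and start_col/end_col are stored on the calibrator below, and
        # the same crop is re-applied to every capture in capture_linear_intensity(), so the
        # homography is estimated in, and maps from, this cropped coordinate frame.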
367 | calib_success = self.calibrator.calibrate(captured_img_masked, show_preview=show_preview) 368 | 369 | self.calibrator.start_row, self.calibrator.end_row = range_row 370 | self.calibrator.start_col, self.calibrator.end_col = range_col 371 | 372 | if calib_success: 373 | print(' - calibration success') 374 | else: 375 | raise ValueError(' - Calibration failed') 376 | 377 | def forward(self, slm_phase, num_grab_images=1): 378 | """ 379 | this forward pass gets slm_phase to display and returns the amplitude image at the target plane. 380 | 381 | :param slm_phase: 382 | :param num_grab_images: 383 | :return: A pytorch tensor shape of (1, 1, H, W) 384 | """ 385 | 386 | slm_phase_8bit = utils.phasemap_8bit(slm_phase, True) 387 | 388 | # display the pattern and capture linear intensity, after perspective transform 389 | captured_linear_np = self.capture_linear_intensity(slm_phase_8bit, num_grab_images=num_grab_images) 390 | 391 | # convert raw-16 linear intensity image into an amplitude tensor 392 | if len(captured_linear_np.shape) > 2: 393 | captured_linear = torch.tensor(captured_linear_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) 394 | captured_linear = captured_linear.to(slm_phase.device) 395 | captured_linear = torch.sum(captured_linear, dim=1, keepdim=True) 396 | else: 397 | captured_linear = torch.tensor(captured_linear_np, dtype=torch.float32).unsqueeze(0).unsqueeze(0) 398 | captured_linear = captured_linear.to(slm_phase.device) 399 | 400 | # return amplitude 401 | return torch.sqrt(captured_linear) 402 | 403 | def capture_linear_intensity(self, slm_phase, num_grab_images): 404 | """ 405 | 406 | :param slm_phase: 407 | :param num_grab_images: 408 | :return: 409 | """ 410 | 411 | # display on SLM and sleep for 0.1s 412 | self.slm.show_data_from_array(slm_phase) 413 | time.sleep(self.slm_settle_time) 414 | 415 | # capture and take average 416 | grabbed_images = self.camera.grab_images(num_grab_images) 417 | captured_intensity_raw_avg = utils.burst_img_processor(grabbed_images) # averaging 418 | 419 | # crop ROI as calibrated 420 | captured_intensity_raw_cropped = captured_intensity_raw_avg[ 421 | self.calibrator.start_row:self.calibrator.end_row, 422 | self.calibrator.start_col:self.calibrator.end_col, ...] 
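        # (The warped image returned below is still linear intensity; forward() converts it
        # to amplitude by summing the color channels, for color captures, and taking the
        # square root.)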
423 | # apply homography 424 | return self.calibrator(captured_intensity_raw_cropped) 425 | 426 | def disconnect(self): 427 | self.camera.disconnect() 428 | self.slm.disconnect() 429 | if self.alc is not None: 430 | self.alc.turnOffAll() 431 | -------------------------------------------------------------------------------- /utils/perceptualloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models 3 | 4 | 5 | class PerceptualLoss(torch.nn.modules.loss._Loss): 6 | 7 | def __init__(self, pixel_loss=1.0, l1_loss=False, style_loss=0.0, lambda_feat=1, include_vgg_layers=('1', '2', '3', '4', '5')): 8 | super(PerceptualLoss, self).__init__(True, True) 9 | 10 | # download pretrained vgg19 if necessary and instantiate it 11 | vgg19 = torchvision.models.vgg.vgg19(pretrained=True) 12 | self.vgg_layers = vgg19.features 13 | 14 | # the vgg feature layers we want to use for the perceptual loss 15 | self.layer_name_mapping = { 16 | } 17 | if '1' in include_vgg_layers: 18 | self.layer_name_mapping['3'] = "conv1_2" 19 | if '2' in include_vgg_layers: 20 | self.layer_name_mapping['8'] = "conv2_2" 21 | if '3' in include_vgg_layers: 22 | self.layer_name_mapping['13'] = "conv3_2" 23 | if '4' in include_vgg_layers: 24 | self.layer_name_mapping['22'] = "conv4_2" 25 | if '5' in include_vgg_layers: 26 | self.layer_name_mapping['31'] = "conv5_2" 27 | 28 | # weights for pixel loss and style loss (feature loss assumed 1.0) 29 | self.pixel_loss = pixel_loss 30 | self.l1_loss = l1_loss 31 | self.lambda_feat = lambda_feat 32 | self.style_loss = style_loss 33 | 34 | def forward(self, input, target): 35 | 36 | lossValue = torch.tensor(0.0).to(input.device) 37 | l2_loss_func = lambda ipt, tgt: torch.sum(torch.pow(ipt - tgt, 2)) # amplitude to intensity 38 | l1_loss_func = lambda ipt, tgt: torch.sum(torch.abs(ipt - tgt)) # amplitude to intensity 39 | 40 | # get size 41 | s = input.size() 42 | 43 | # number of tensors in this mini batch 44 | num_images = s[0] 45 | 46 | # L2 loss (L1 originally) 47 | if self.l1_loss: 48 | scale = s[1] * s[2] * s[3] 49 | lossValue += l1_loss_func(input, target) * (2 * self.pixel_loss / scale) 50 | loss_func = l2_loss_func 51 | elif self.pixel_loss: 52 | scale = s[1] * s[2] * s[3] 53 | lossValue += l2_loss_func(input, target) * (2 * self.pixel_loss / scale) 54 | loss_func = l2_loss_func 55 | 56 | # stack input and output so we can feed-forward it through vgg19 57 | x = torch.cat((input, target), 0) 58 | 59 | for name, module in self.vgg_layers._modules.items(): 60 | 61 | # run x through current module 62 | x = module(x) 63 | s = x.size() 64 | 65 | # scale factor 66 | scale = s[1] * s[2] * s[3] 67 | 68 | if name in self.layer_name_mapping: 69 | a, b = torch.split(x, num_images, 0) 70 | lossValue += self.lambda_feat * loss_func(a, b) / scale 71 | 72 | # Gram matrix for style loss 73 | if self.style_loss: 74 | A = a.reshape(num_images, s[1], -1) 75 | B = b.reshape(num_images, s[1], -1).detach() 76 | 77 | G_A = A @ torch.transpose(A, 1, 2) 78 | del A 79 | G_B = B @ torch.transpose(B, 1, 2) 80 | del B 81 | 82 | lossValue += loss_func(G_A, G_B) * (self.style_loss / scale) 83 | 84 | return lossValue 85 | -------------------------------------------------------------------------------- /utils/slm_display_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing SLM display module (HOLOEYE). 
3 | Refer to this interface and modify it to match your SLM SDK.
4 | 
5 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
6 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
7 | # The material is provided as-is, with no warranties whatsoever.
8 | # If you publish any code, data, or scientific work based on this, please cite our work.
9 | 
10 | Technical Paper:
11 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020.
12 | """
13 | 
14 | import utils.detect_heds_module_path
15 | import holoeye
16 | from holoeye import slmdisplaysdk
17 | 
18 | 
19 | class SLMDisplay:
20 |     def __init__(self):
21 |         self.ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode
22 |         self.ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags
23 | 
24 |         self.displayOptions = self.ShowFlags.PresentAutomatic  # PresentAutomatic == 0 (default)
25 |         self.displayOptions |= self.ShowFlags.PresentFitWithBars
26 | 
27 |     def connect(self):
28 |         self.slm_device = slmdisplaysdk.SLMDisplay()
29 |         self.slm_device.open()  # For version 2.0.1
30 | 
31 |     def disconnect(self):
32 |         self.slm_device.release()
33 | 
34 |     def show_data_from_file(self, filepath):
35 |         error = self.slm_device.showDataFromFile(filepath, self.displayOptions)
36 |         assert error == self.ErrorCode.NoError, self.slm_device.errorString(error)
37 | 
38 |     def show_data_from_array(self, numpy_array):
39 |         error = self.slm_device.showData(numpy_array)
40 |         assert error == self.ErrorCode.NoError, self.slm_device.errorString(error)
41 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the script containing all utility functions used for the implementation.
3 | 
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 | 
9 | Technical Paper:
10 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020.
11 | """
12 | import math
13 | import numpy as np
14 | 
15 | import os
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as func
19 | import torch.nn.modules.loss as ll
20 | 
21 | from skimage.metrics import peak_signal_noise_ratio as psnr
22 | from skimage.metrics import structural_similarity as ssim
23 | 
24 | 
25 | def mul_complex(t1, t2):
26 |     """multiply two complex valued tensors element-wise. the last dimension is
27 |     assumed to hold the real and imaginary parts
28 | 
29 |     complex multiplication: (a+bi)(c+di) = (ac-bd) + (bc+ad)i
30 |     """
31 |     # real and imaginary parts of first tensor
32 |     a, b = t1.split(1, 4)
33 |     # real and imaginary parts of second tensor
34 |     c, d = t2.split(1, 4)
35 | 
36 |     # multiply out
37 |     return torch.cat((a * c - b * d, b * c + a * d), 4)
38 | 
39 | 
40 | def div_complex(t1, t2):
41 |     """divide two complex valued tensors element-wise. 
the two last dimensions are 42 | assumed to be the real and imaginary part 43 | 44 | complex division: (a+bi) / (c+di) = (ac+bd)/(c^2+d^2) + (bc-ad)/(c^2+d^2) i 45 | """ 46 | # real and imaginary parts of first tensor 47 | (a, b) = t1.split(1, 4) 48 | # real and imaginary parts of second tensor 49 | (c, d) = t2.split(1, 4) 50 | 51 | # get magnitude 52 | mag = torch.mul(c, c) + torch.mul(d, d) 53 | 54 | # multiply out 55 | return torch.cat(((a * c + b * d) / mag, (b * c - a * d) / mag), 4) 56 | 57 | 58 | def reciprocal_complex(t): 59 | """element-wise inverse of complex-valued tensor 60 | 61 | reciprocal of complex number z=a+bi: 62 | 1/z = a / (a^2 + b^2) - ( b / (a^2 + b^2) ) i 63 | """ 64 | # real and imaginary parts of first tensor 65 | (a, b) = t.split(1, 4) 66 | 67 | # get magnitude 68 | mag = torch.mul(a, a) + torch.mul(b, b) 69 | 70 | # multiply out 71 | return torch.cat((a / mag, -(b / mag)), 4) 72 | 73 | 74 | def rect_to_polar(real, imag): 75 | """Converts the rectangular complex representation to polar""" 76 | mag = torch.pow(real**2 + imag**2, 0.5) 77 | ang = torch.atan2(imag, real) 78 | return mag, ang 79 | 80 | 81 | def polar_to_rect(mag, ang): 82 | """Converts the polar complex representation to rectangular""" 83 | real = mag * torch.cos(ang) 84 | imag = mag * torch.sin(ang) 85 | return real, imag 86 | 87 | 88 | def replace_amplitude(field, amplitude): 89 | """takes a Complex tensor with real/imag channels, converts to 90 | amplitude/phase, replaces amplitude, then converts back to real/imag 91 | 92 | resolution of both Complex64 tensors should be (M, N, height, width) 93 | """ 94 | # replace amplitude with target amplitude and convert back to real/imag 95 | real, imag = polar_to_rect(amplitude, field.angle()) 96 | 97 | # concatenate 98 | return torch.complex(real, imag) 99 | 100 | 101 | def ifftshift(tensor): 102 | """ifftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2] 103 | 104 | shifts the width and heights 105 | """ 106 | size = tensor.size() 107 | tensor_shifted = roll_torch(tensor, -math.floor(size[2] / 2.0), 2) 108 | tensor_shifted = roll_torch(tensor_shifted, -math.floor(size[3] / 2.0), 3) 109 | return tensor_shifted 110 | 111 | 112 | def fftshift(tensor): 113 | """fftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2] 114 | 115 | shifts the width and heights 116 | """ 117 | size = tensor.size() 118 | tensor_shifted = roll_torch(tensor, math.floor(size[2] / 2.0), 2) 119 | tensor_shifted = roll_torch(tensor_shifted, math.floor(size[3] / 2.0), 3) 120 | return tensor_shifted 121 | 122 | 123 | def ifft2(tensor_re, tensor_im, shift=False): 124 | """Applies a 2D ifft to the complex tensor represented by tensor_re and _im""" 125 | tensor_out = torch.stack((tensor_re, tensor_im), 4) 126 | 127 | if shift: 128 | tensor_out = ifftshift(tensor_out) 129 | (tensor_out_re, tensor_out_im) = torch.ifft(tensor_out, 2, True).split(1, 4) 130 | 131 | tensor_out_re = tensor_out_re.squeeze(4) 132 | tensor_out_im = tensor_out_im.squeeze(4) 133 | 134 | return tensor_out_re, tensor_out_im 135 | 136 | 137 | def fft2(tensor_re, tensor_im, shift=False): 138 | """Applies a 2D fft to the complex tensor represented by tensor_re and _im""" 139 | # fft2 140 | (tensor_out_re, tensor_out_im) = torch.fft(torch.stack((tensor_re, tensor_im), 4), 2, True).split(1, 4) 141 | 142 | tensor_out_re = tensor_out_re.squeeze(4) 143 | tensor_out_im = tensor_out_im.squeeze(4) 144 | 145 | # apply fftshift 146 | if shift: 147 | tensor_out_re = 
151 | 
152 | 
153 | def roll_torch(tensor, shift, axis):
154 |     """implements numpy roll() or Matlab circshift() functions for tensors"""
155 |     if shift == 0:
156 |         return tensor
157 | 
158 |     if axis < 0:
159 |         axis += tensor.dim()
160 | 
161 |     dim_size = tensor.size(axis)
162 |     after_start = dim_size - shift
163 |     if shift < 0:
164 |         after_start = -shift
165 |         shift = dim_size - abs(shift)
166 | 
167 |     before = tensor.narrow(axis, 0, dim_size - shift)
168 |     after = tensor.narrow(axis, after_start, shift)
169 |     return torch.cat([after, before], axis)
170 | 
171 | 
172 | def pad_stacked_complex(field, pad_width, padval=0, mode='constant'):
173 |     """Helper for pad_image() that pads with a real padval in a complex-aware manner"""
174 |     if padval == 0:
175 |         pad_width = (0, 0, *pad_width)  # add 0 padding for stacked_complex dimension
176 |         return nn.functional.pad(field, pad_width, mode=mode)
177 |     else:
178 |         if isinstance(padval, torch.Tensor):
179 |             padval = padval.item()
180 | 
181 |         real, imag = field[..., 0], field[..., 1]
182 |         real = nn.functional.pad(real, pad_width, mode=mode, value=padval)
183 |         imag = nn.functional.pad(imag, pad_width, mode=mode, value=0)
184 |         return torch.stack((real, imag), -1)
185 | 
186 | 
187 | def pad_image(field, target_shape, pytorch=True, stacked_complex=True, padval=0, mode='constant'):
188 |     """Pads a 2D complex field up to target_shape in size
189 | 
190 |     Padding is done such that when used with crop_image(), odd and even dimensions are
191 |     handled correctly to properly undo the padding.
192 | 
193 |     field: the field to be padded. May have as many leading dimensions as necessary
194 |         (e.g., batch or channel dimensions)
195 |     target_shape: the 2D target output dimensions. If any dimensions are smaller
196 |         than field, no padding is applied
197 |     pytorch: if True, uses torch functions, if False, uses numpy
198 |     stacked_complex: for pytorch=True, indicates that field has a final dimension
199 |         representing real and imag
200 |     padval: the real number value to pad by
201 |     mode: padding mode for numpy or torch
202 |     """
203 |     if pytorch:
204 |         if stacked_complex:
205 |             size_diff = np.array(target_shape) - np.array(field.shape[-3:-1])
206 |             odd_dim = np.array(field.shape[-3:-1]) % 2
207 |         else:
208 |             size_diff = np.array(target_shape) - np.array(field.shape[-2:])
209 |             odd_dim = np.array(field.shape[-2:]) % 2
210 |     else:
211 |         size_diff = np.array(target_shape) - np.array(field.shape[-2:])
212 |         odd_dim = np.array(field.shape[-2:]) % 2
213 | 
214 |     # pad the dimensions that need to increase in size
215 |     if (size_diff > 0).any():
216 |         pad_total = np.maximum(size_diff, 0)
217 |         pad_front = (pad_total + odd_dim) // 2
218 |         pad_end = (pad_total + 1 - odd_dim) // 2
219 | 
220 |         if pytorch:
221 |             pad_axes = [int(p)  # convert from np.int64
222 |                         for tple in zip(pad_front[::-1], pad_end[::-1])
223 |                         for p in tple]
224 |             if stacked_complex:
225 |                 return pad_stacked_complex(field, pad_axes, mode=mode, padval=padval)
226 |             else:
227 |                 return nn.functional.pad(field, pad_axes, mode=mode, value=padval)
228 |         else:
229 |             leading_dims = field.ndim - 2  # only pad the last two dims
230 |             if leading_dims > 0:
231 |                 pad_front = np.concatenate(([0] * leading_dims, pad_front))
232 |                 pad_end = np.concatenate(([0] * leading_dims, pad_end))
233 |             return np.pad(field, tuple(zip(pad_front, pad_end)), mode,
234 |                           constant_values=padval)
235 |     else:
236 |         return field
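Because `pad_image()` and `crop_image()` (defined next) use mirrored front/back split rules, they undo each other exactly, including for odd sizes. A minimal sketch, not part of the original file:

```
import torch
from utils.utils import pad_image, crop_image

field = torch.randn(1, 1, 5, 7)  # odd dimensions on purpose
padded = pad_image(field, (8, 8), stacked_complex=False)
restored = crop_image(padded, (5, 7), stacked_complex=False)

assert padded.shape[-2:] == (8, 8)
assert torch.equal(field, restored)  # round trip is lossless
```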
If any dimensions are smaller 196 | than field, no padding is applied 197 | pytorch: if True, uses torch functions, if False, uses numpy 198 | stacked_complex: for pytorch=True, indicates that field has a final dimension 199 | representing real and imag 200 | padval: the real number value to pad by 201 | mode: padding mode for numpy or torch 202 | """ 203 | if pytorch: 204 | if stacked_complex: 205 | size_diff = np.array(target_shape) - np.array(field.shape[-3:-1]) 206 | odd_dim = np.array(field.shape[-3:-1]) % 2 207 | else: 208 | size_diff = np.array(target_shape) - np.array(field.shape[-2:]) 209 | odd_dim = np.array(field.shape[-2:]) % 2 210 | else: 211 | size_diff = np.array(target_shape) - np.array(field.shape[-2:]) 212 | odd_dim = np.array(field.shape[-2:]) % 2 213 | 214 | # pad the dimensions that need to increase in size 215 | if (size_diff > 0).any(): 216 | pad_total = np.maximum(size_diff, 0) 217 | pad_front = (pad_total + odd_dim) // 2 218 | pad_end = (pad_total + 1 - odd_dim) // 2 219 | 220 | if pytorch: 221 | pad_axes = [int(p) # convert from np.int64 222 | for tple in zip(pad_front[::-1], pad_end[::-1]) 223 | for p in tple] 224 | if stacked_complex: 225 | return pad_stacked_complex(field, pad_axes, mode=mode, padval=padval) 226 | else: 227 | return nn.functional.pad(field, pad_axes, mode=mode, value=padval) 228 | else: 229 | leading_dims = field.ndim - 2 # only pad the last two dims 230 | if leading_dims > 0: 231 | pad_front = np.concatenate(([0] * leading_dims, pad_front)) 232 | pad_end = np.concatenate(([0] * leading_dims, pad_end)) 233 | return np.pad(field, tuple(zip(pad_front, pad_end)), mode, 234 | constant_values=padval) 235 | else: 236 | return field 237 | 238 | 239 | def crop_image(field, target_shape, pytorch=True, stacked_complex=True): 240 | """Crops a 2D field, see pad_image() for details 241 | 242 | No cropping is done if target_shape is already smaller than field 243 | """ 244 | if target_shape is None: 245 | return field 246 | 247 | if pytorch: 248 | if stacked_complex: 249 | size_diff = np.array(field.shape[-3:-1]) - np.array(target_shape) 250 | odd_dim = np.array(field.shape[-3:-1]) % 2 251 | else: 252 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape) 253 | odd_dim = np.array(field.shape[-2:]) % 2 254 | else: 255 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape) 256 | odd_dim = np.array(field.shape[-2:]) % 2 257 | 258 | # crop dimensions that need to decrease in size 259 | if (size_diff > 0).any(): 260 | crop_total = np.maximum(size_diff, 0) 261 | crop_front = (crop_total + 1 - odd_dim) // 2 262 | crop_end = (crop_total + odd_dim) // 2 263 | 264 | crop_slices = [slice(int(f), int(-e) if e else None) 265 | for f, e in zip(crop_front, crop_end)] 266 | if pytorch and stacked_complex: 267 | return field[(..., *crop_slices, slice(None))] 268 | else: 269 | return field[(..., *crop_slices)] 270 | else: 271 | return field 272 | 273 | 274 | def srgb_gamma2lin(im_in): 275 | """converts from sRGB to linear color space""" 276 | thresh = 0.04045 277 | im_out = np.where(im_in <= thresh, im_in / 12.92, ((im_in + 0.055) / 1.055)**(2.4)) 278 | return im_out 279 | 280 | 281 | def srgb_lin2gamma(im_in): 282 | """converts from linear to sRGB color space""" 283 | thresh = 0.0031308 284 | im_out = np.where(im_in <= thresh, 12.92 * im_in, 1.055 * (im_in**(1 / 2.4)) - 0.055) 285 | return im_out 286 | 287 | 288 | def cond_mkdir(path): 289 | if not os.path.exists(path): 290 | os.makedirs(path) 291 | 292 | 293 | def phasemap_8bit(phasemap, 
293 | def phasemap_8bit(phasemap, inverted=True):
294 |     """convert a phasemap tensor into a numpy 8bit phasemap that can be directly displayed
295 | 
296 |     Input
297 |     -----
298 |     :param phasemap: input phasemap tensor, expected to be in the range [-pi, pi].
299 |     :param inverted: a boolean value that indicates whether the phasemap is inverted.
300 | 
301 |     Output
302 |     ------
303 |     :return: output phasemap, with uint8 dtype (in [0, 255])
304 |     """
305 | 
306 |     output_phase = ((phasemap + np.pi) % (2 * np.pi)) / (2 * np.pi)
307 |     if inverted:
308 |         phase_out_8bit = ((1 - output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8)  # quantized to 8 bits
309 |     else:
310 |         phase_out_8bit = ((output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8)  # quantized to 8 bits
311 |     return phase_out_8bit
312 | 
313 | 
314 | def burst_img_processor(img_burst_list):
315 |     img_tensor = np.stack(img_burst_list, axis=0)
316 |     img_avg = np.mean(img_tensor, axis=0)
317 |     return im2float(img_avg)  # changed from int8 to float32
318 | 
319 | 
320 | def im2float(im, dtype=np.float32):
321 |     """convert uint16 or uint8 image to float32, with range scaled to 0-1
322 | 
323 |     :param im: image
324 |     :param dtype: default np.float32
325 |     :return:
326 |     """
327 |     if issubclass(im.dtype.type, np.floating):
328 |         return im.astype(dtype)
329 |     elif issubclass(im.dtype.type, np.integer):
330 |         return im / dtype(np.iinfo(im.dtype).max)
331 |     else:
332 |         raise ValueError(f'Unsupported data type {im.dtype}')
333 | 
334 | 
335 | def propagate_field(input_field, propagator, prop_dist=0.2, wavelength=520e-9, feature_size=(6.4e-6, 6.4e-6),
336 |                     prop_model='ASM', dtype=torch.float32, precomputed_H=None):
337 |     """
338 |     A wrapper for various propagation methods, including the parameterized model.
339 |     Note that input_field is supposed to be in Cartesian coordinates, not polar!
340 | 
341 |     Input
342 |     -----
343 |     :param input_field: pytorch complex tensor shape of (1, C, H, W), the field before propagation, in X, Y coordinates
344 |     :param prop_dist: propagation distance in m.
345 |     :param wavelength: wavelength of the wave in m.
346 |     :param feature_size: pixel pitch
347 |     :param prop_model: propagation model ('ASM', 'MODEL', 'fresnel', ...)
348 |     :param propagator: function or model instance used for propagation
349 |     :param dtype: torch.float32 by default
350 |     :param precomputed_H: propagation kernel in the Fourier domain (can be computed once and reused)
351 | 
352 |     Output
353 |     -----
354 |     :return: output_field: pytorch complex tensor shape of (1, C, H, W), the field after propagation, in X, Y coordinates
355 |     """
356 | 
357 |     if prop_model == 'ASM':
358 |         output_field = propagator(u_in=input_field, z=prop_dist, feature_size=feature_size, wavelength=wavelength,
359 |                                   dtype=dtype, precomped_H=precomputed_H)
360 |     elif 'MODEL' in prop_model.upper():
361 |         # forward propagate through our CITL-calibrated model.
362 |         # You can also call the model directly with a phase map, without this wrapper.
363 |         _, input_phase = rect_to_polar(input_field.real, input_field.imag)
364 |         output_field = propagator(input_phase)
365 |     elif prop_model == 'CAMERA':
366 |         _, input_phase = rect_to_polar(input_field.real, input_field.imag)
367 |         output_field = propagator(input_phase)
368 |     else:
369 |         raise ValueError('Unexpected prop_model value')
370 | 
371 |     return output_field
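A minimal usage sketch for the wrapper, assuming the `propagation_ASM` function from `propagation_ASM.py` as the propagator; the flat-phase input field is illustrative only:

```
import torch
from propagation_ASM import propagation_ASM
from utils.utils import polar_to_rect, propagate_field

slm_phase = torch.zeros(1, 1, 1080, 1920)   # illustrative flat phase pattern
real, imag = polar_to_rect(torch.ones_like(slm_phase), slm_phase)
input_field = torch.complex(real, imag)     # unit-amplitude complex field

output_field = propagate_field(input_field, propagation_ASM,
                               prop_dist=0.2, wavelength=520e-9,
                               feature_size=(6.4e-6, 6.4e-6), prop_model='ASM')
recon_amp = output_field.abs()              # amplitude at the target plane
```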
372 | 
373 | 
374 | def write_sgd_summary(slm_phase, out_amp, target_amp, k,
375 |                       writer=None, path=None, s=0., prefix='test'):
376 |     """tensorboard summary for SGD
377 | 
378 |     :param slm_phase: Use it if you want to save intermediate phases during optimization.
379 |     :param out_amp: PyTorch Tensor, Field amplitude at the image plane.
380 |     :param target_amp: PyTorch Tensor, Ground Truth target Amplitude.
381 |     :param k: iteration number.
382 |     :param writer: SummaryWriter instance.
383 |     :param path: path to save image files.
384 |     :param s: scale for SGD algorithm.
385 |     :param prefix: prefix for the summary tags.
386 |     :return:
387 |     """
388 |     loss = nn.MSELoss().to(out_amp.device)
389 |     loss_value = loss(s * out_amp, target_amp)
390 |     psnr_value = psnr(target_amp.squeeze().cpu().detach().numpy(), (s * out_amp).squeeze().cpu().detach().numpy())
391 |     ssim_value = ssim(target_amp.squeeze().cpu().detach().numpy(), (s * out_amp).squeeze().cpu().detach().numpy())
392 | 
393 |     s_min = (target_amp * out_amp).mean() / (out_amp**2).mean()  # least-squares optimal scale
394 |     psnr_value_min = psnr(target_amp.squeeze().cpu().detach().numpy(), (s_min * out_amp).squeeze().cpu().detach().numpy())
395 |     ssim_value_min = ssim(target_amp.squeeze().cpu().detach().numpy(), (s_min * out_amp).squeeze().cpu().detach().numpy())
396 | 
397 |     if writer is not None:
398 |         writer.add_image(f'{prefix}_Recon/amp', (s * out_amp).squeeze(0), k)
399 |         writer.add_scalar(f'{prefix}_loss', loss_value, k)
400 |         writer.add_scalar(f'{prefix}_psnr', psnr_value, k)
401 |         writer.add_scalar(f'{prefix}_ssim', ssim_value, k)
402 | 
403 |         writer.add_scalar(f'{prefix}_psnr/scaled', psnr_value_min, k)
404 |         writer.add_scalar(f'{prefix}_ssim/scaled', ssim_value_min, k)
405 | 
406 |         writer.add_scalar(f'{prefix}_scalar', s, k)
407 | 
408 | 
409 | def write_gs_summary(slm_field, recon_field, target_amp, k, writer, roi=(880, 1600), prefix='test'):
410 |     """tensorboard summary for GS"""
411 |     slm_phase = slm_field.angle()
412 |     recon_amp, recon_phase = recon_field.abs(), recon_field.angle()
413 |     loss = nn.MSELoss().to(recon_amp.device)
414 | 
415 |     recon_amp = crop_image(recon_amp, target_shape=roi, stacked_complex=False)
416 |     target_amp = crop_image(target_amp, target_shape=roi, stacked_complex=False)
417 | 
418 |     recon_amp *= (torch.sum(recon_amp * target_amp, (-2, -1), keepdim=True)
419 |                   / torch.sum(recon_amp * recon_amp, (-2, -1), keepdim=True))
420 | 
421 |     loss_value = loss(recon_amp, target_amp)
422 |     psnr_value = psnr(target_amp.squeeze().cpu().detach().numpy(), recon_amp.squeeze().cpu().detach().numpy())
423 |     ssim_value = ssim(target_amp.squeeze().cpu().detach().numpy(), recon_amp.squeeze().cpu().detach().numpy())
424 | 
425 |     if writer is not None:
426 |         writer.add_image(f'{prefix}_Recon/amp', recon_amp.squeeze(0), k)
427 |         writer.add_scalar(f'{prefix}_loss', loss_value, k)
428 |         writer.add_scalar(f'{prefix}_psnr', psnr_value, k)
429 |         writer.add_scalar(f'{prefix}_ssim', ssim_value, k)
430 | 
431 | 
432 | def get_psnr_ssim(recon_amp, target_amp, multichannel=False):
433 |     """get PSNR and SSIM metrics"""
434 |     psnrs, ssims = {}, {}
435 | 
436 |     # amplitude
437 |     psnrs['amp'] = psnr(target_amp, recon_amp)
438 |     ssims['amp'] = ssim(target_amp, recon_amp, multichannel=multichannel)
439 | 
440 |     # linear
441 |     target_linear = target_amp**2
442 |     recon_linear = recon_amp**2
443 |     psnrs['lin'] = psnr(target_linear, recon_linear)
444 |     ssims['lin'] = ssim(target_linear, recon_linear, multichannel=multichannel)
445 | 
446 |     # srgb
447 |     target_srgb = srgb_lin2gamma(np.clip(target_linear, 0.0, 1.0))
448 |     recon_srgb = srgb_lin2gamma(np.clip(recon_linear, 0.0, 1.0))
449 |     psnrs['srgb'] = psnr(target_srgb, recon_srgb)
450 |     ssims['srgb'] = ssim(target_srgb, recon_srgb, multichannel=multichannel)
451 | 
452 |     return psnrs, ssims
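A minimal sketch of consuming these metrics, with random arrays standing in for a real reconstruction and target (assumes the skimage version pinned by the environment files):

```
import numpy as np
from utils.utils import get_psnr_ssim

recon = np.random.rand(880, 1600).astype(np.float32)   # stand-in reconstruction
target = np.random.rand(880, 1600).astype(np.float32)  # stand-in target

psnrs, ssims = get_psnr_ssim(recon, target)
print({k: round(float(v), 2) for k, v in psnrs.items()})  # 'amp', 'lin', 'srgb'
```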
453 | 
454 | 
455 | def str2bool(v):
456 |     """ Simple query parser for configArgParse (which doesn't support native bool from cmd)
457 |     Ref: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
458 | 
459 |     """
460 |     if isinstance(v, bool):
461 |         return v
462 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
463 |         return True
464 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
465 |         return False
466 |     else:
467 |         raise ValueError('Boolean value expected.')
468 | 
469 | 
470 | def make_kernel_gaussian(sigma, kernel_size):
471 | 
472 |     # Create an x, y coordinate grid of shape (kernel_size, kernel_size, 2)
473 |     x_cord = torch.arange(kernel_size)
474 |     x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
475 |     y_grid = x_grid.t()
476 |     xy_grid = torch.stack([x_grid, y_grid], dim=-1)
477 | 
478 |     mean = (kernel_size - 1) / 2
479 |     variance = sigma**2
480 | 
481 |     # Calculate the 2-dimensional gaussian kernel, which is
482 |     # the product of two gaussian distributions for two different
483 |     # variables (in this case called x and y)
484 |     gaussian_kernel = ((1 / (2 * math.pi * variance))
485 |                        * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1)
486 |                                    / (2 * variance)))
487 |     # Make sure the values in the gaussian kernel sum to 1
488 |     gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
489 | 
490 |     # Reshape to 2d depthwise convolutional weight
491 |     gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
492 | 
493 |     return gaussian_kernel
494 | 
495 | 
496 | def quantized_phase(phasemap):
497 |     """
498 |     quantize phase into 8 bits and return a tensor with the same dtype
499 |     :param phasemap:
500 |     :return:
501 |     """
502 | 
503 |     # Shift to [0, 1]
504 |     phasemap = (phasemap + np.pi) / (2 * np.pi)
505 | 
506 |     # Scale to 8-bit levels and round
507 |     phasemap = torch.round(255 * phasemap)
508 | 
509 |     # Shift back to the original range
510 |     phasemap = phasemap / 255 * 2 * np.pi - np.pi
511 |     return phasemap
512 | 
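Typical use of `str2bool` is as an argparse/configargparse `type`; a minimal sketch with a hypothetical flag name:

```
import argparse
from utils.utils import str2bool

p = argparse.ArgumentParser()
p.add_argument('--citl', type=str2bool, default=False,
               help='run camera-in-the-loop optimization')  # hypothetical flag

opt = p.parse_args(['--citl', 'true'])
assert opt.citl is True
```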
11 | """ 12 | 13 | import time 14 | import copy 15 | import torch 16 | import numpy as np 17 | import tensorboardX 18 | import matplotlib.pyplot as plt 19 | 20 | import utils.utils as utils 21 | from mpl_toolkits.axes_grid1 import make_axes_locatable 22 | 23 | from propagation_ASM import compute_zernike_basis, combine_zernike_basis 24 | 25 | 26 | class SummaryModelWriter(tensorboardX.SummaryWriter): 27 | """ 28 | Inherited class of tensorboard summarywriter for visualization of model parameters. 29 | 30 | :param model: ModelPropagate instance that is being trained. 31 | :param writer_dir: directory of this summary 32 | :param slm_res: resolution of the SLM (1080, 1920) 33 | :param roi_res: resolution of the region of interest, default (880, 1600) 34 | :param ch: an integer indicating training channel (red:0, green:1, blue:2) 35 | :param kw: 36 | """ 37 | def __init__(self, model, writer_dir, slm_res=(1080, 1920), roi_res=(880, 1600), ch=1, **kw): 38 | super(SummaryModelWriter, self).__init__(writer_dir, **kw) 39 | self.model = model 40 | self.zernike_basis = None 41 | self.zernike_basis_fourier = None 42 | self.slm_res = slm_res 43 | self.roi_res = roi_res 44 | self.cmap_rgb = (plt.cm.Reds, plt.cm.Greens, plt.cm.Blues)[ch] 45 | 46 | def visualize_model(self, idx=0): 47 | """ 48 | Visualize the model parameters on Tensorboard 49 | 50 | :param idx: Global step value to record 51 | """ 52 | 53 | self.add_lut_mean_var(idx) 54 | self.add_source_amplitude_image(idx) 55 | self.add_source_amplitude_parameters(idx) 56 | self.add_zernike(idx) 57 | self.add_target_field(idx) 58 | self.add_zeroth_order(idx) 59 | self.add_latent_codes(idx) 60 | 61 | def add_lut_mean_var(self, idx=0, show_identity=True): 62 | """ 63 | Add phase-non-linearity lookuptable to Tensorboard 64 | 65 | :param idx: Global step value to record 66 | :param show_identity: show y=x on the graph 67 | """ 68 | with torch.no_grad(): 69 | num_x = 64 70 | if self.model.process_phase is not None: 71 | lut = copy.deepcopy(self.model.process_phase) 72 | lut.eval() 73 | 74 | test_phase = torch.linspace(-np.pi, np.pi, num_x, dtype=torch.float) 75 | input_phase = test_phase.numpy() 76 | test_phase = test_phase.reshape(num_x, 1, 1, 1) 77 | output_mean, output_std = np.empty(num_x), np.empty(num_x) 78 | for v in range(num_x): 79 | x = test_phase[v, ...].repeat(1, 1, *self.slm_res).to(self.model.dev) 80 | output_phase = lut(x, self.model.latent_code).detach().cpu().numpy().squeeze() 81 | output_mean[v] = np.mean(output_phase) 82 | output_std[v] = np.std(output_phase) 83 | 84 | fig = plt.figure() 85 | if show_identity: 86 | plt.plot(input_phase, output_mean, 'b', 87 | input_phase, input_phase - input_phase[int(num_x / 2)] + output_mean[int(num_x / 2)], 'k--') 88 | plt.fill_between(input_phase, output_mean - output_std, output_mean + output_std, 89 | alpha=0.5) 90 | else: 91 | plt.plot(input_phase, output_phase, 'b') 92 | 93 | self.add_figure(f'parameters/voltage-to-phase', fig, idx) 94 | del lut 95 | 96 | def add_source_amplitude_parameters(self, idx=0): 97 | """ 98 | Add parameters of gaussian source amplitudes to Tensorboard 99 | 100 | :param idx: Global step value to record 101 | """ 102 | if self.model.source_amp.num_gaussians > 0: 103 | sa = self.model.source_amp 104 | self.add_scalar('parameters_SA/DC', sa.dc_term.cpu().numpy(), idx) 105 | for i in range(self.model.source_amp.num_gaussians): 106 | self.add_scalar(f'parameters_SA/Amps_{i}', sa.amplitudes.cpu().numpy()[i], idx) 107 | self.add_scalar(f'parameters_SA/sigmas_{i}', 
96 |     def add_source_amplitude_parameters(self, idx=0):
97 |         """
98 |         Add parameters of gaussian source amplitudes to Tensorboard
99 | 
100 |         :param idx: Global step value to record
101 |         """
102 |         if self.model.source_amp.num_gaussians > 0:
103 |             sa = self.model.source_amp
104 |             self.add_scalar('parameters_SA/DC', sa.dc_term.cpu().numpy(), idx)
105 |             for i in range(self.model.source_amp.num_gaussians):
106 |                 self.add_scalar(f'parameters_SA/Amps_{i}', sa.amplitudes.cpu().numpy()[i], idx)
107 |                 self.add_scalar(f'parameters_SA/sigmas_{i}', sa.sigmas.cpu().numpy()[i], idx)
108 |                 self.add_scalar(f'parameters_SA/x_{i}', sa.x_s.cpu().numpy()[i], idx)
109 |                 self.add_scalar(f'parameters_SA/y_{i}', sa.y_s.cpu().numpy()[i], idx)
110 | 
111 |     def add_source_amplitude_image(self, idx=0):
112 |         """
113 |         Add visualization of gaussian source amplitudes to Tensorboard
114 | 
115 |         :param idx: Global step value to record
116 |         """
117 |         if self.model.source_amp is not None:
118 |             img = self.model.source_amp(torch.empty(1, 1, *self.slm_res).
119 |                                         to(self.model.dev)).squeeze().cpu().detach().numpy()
120 |             self.add_figure_cmap(f'parameters/source_amp', img, idx, self.cmap_rgb)
121 | 
122 |     def add_zernike(self, idx=0, domain='fourier', cm=plt.cm.plasma):
123 |         """
124 |         plot Zernike coeffs as a bar plot and
125 |         plot Zernike map visualization
126 | 
127 |         :param domain: 'fourier' or 'primal'
128 |         :param idx: Global step value to record
129 |         :param cm: colormap for the Zernike map, default plasma
130 |         """
131 | 
132 |         if domain == 'fourier':
133 |             zernike_coeffs = self.model.coeffs_fourier
134 |             map_size = [2160, 3840]
135 |         elif domain == 'primal':
136 |             zernike_coeffs = self.model.coeffs
137 |             map_size = [1080, 1920]
138 | 
139 |         if zernike_coeffs is not None:
140 |             num_coeffs = len(zernike_coeffs)
141 | 
142 |             # Zernike coeffs visualization
143 |             x = torch.linspace(0, num_coeffs - 1, num_coeffs)
144 |             fig_zernike_coeffs = plt.figure()
145 |             plt.bar(x.numpy(), zernike_coeffs.cpu().numpy(), width=0.5, align='center')
146 |             self.add_figure(f'parameters/Zernike_coeffs_{domain}', fig_zernike_coeffs, idx)
147 | 
148 |             # Zernike map visualization
149 |             if domain == 'fourier':
150 |                 if self.model.zernike_fourier is None:
151 |                     self.model.zernike_fourier = compute_zernike_basis(self.model.coeffs_fourier.size()[0],
152 |                                                                        map_size, wo_piston=True)
153 |                     self.model.zernike_fourier = self.model.zernike_fourier.to(self.model.dev).detach()
154 |                     self.model.zernike_fourier.requires_grad = False
155 |                 zernike_basis = self.model.zernike_fourier
156 |             if domain == 'primal':
157 |                 if self.model.zernike is None:
158 |                     self.model.zernike = compute_zernike_basis(self.model.coeffs.size()[0],
159 |                                                                map_size, wo_piston=True)
160 |                     self.model.zernike = self.model.zernike.to(self.model.dev).detach()
161 |                     self.model.zernike.requires_grad = False
162 |                 zernike_basis = self.model.zernike
163 | 
164 |             basis_rect = combine_zernike_basis(zernike_coeffs, zernike_basis)
165 |             zernike_phase = basis_rect.angle()
166 |             img_phase = zernike_phase.squeeze().cpu().detach().numpy()
167 |             self.add_figure_cmap(f'parameters/Zernike_map_{domain}', img_phase, idx, cm)
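For reference, the Zernike helpers imported from `propagation_ASM.py` can also be used standalone; a minimal sketch with signatures inferred from the calls above (the coefficient count and map size are illustrative):

```
import torch
from propagation_ASM import compute_zernike_basis, combine_zernike_basis

coeffs = torch.zeros(10)  # 10 illustrative Zernike coefficients, piston excluded
basis = compute_zernike_basis(coeffs.size()[0], [1080, 1920], wo_piston=True)
zernike_phase = combine_zernike_basis(coeffs, basis).angle()  # phase map; squeeze before plotting
```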
168 | 
169 |     def add_target_field(self, idx=0):
170 |         """
171 |         Plot u_t, the content-independent undiffracted field at the target plane
172 | 
173 |         :param idx: Global step value to record
174 |         """
175 | 
176 |         if self.model.target_constant_amp is not None:
177 |             amp = self.model.target_constant_amp
178 |             amp = amp.squeeze().unsqueeze(0).cpu().detach().numpy()
179 |             self.add_figure_cmap(f'parameters/Content-independent_target_amp', amp.squeeze(), idx, self.cmap_rgb)
180 |             self.add_image(f'parameters/Content-independent_target_amp_1080p',
181 |                            ((amp - amp.min())
182 |                             / (amp.max() - amp.min() + 1e-6)), idx)
183 |         if self.model.target_constant_phase is not None:
184 |             phase = self.model.target_constant_phase
185 |             phase = phase.squeeze().unsqueeze(0).cpu().detach().numpy()
186 |             self.add_figure_cmap(f'parameters/Content-independent_target_phase', phase.squeeze(), idx, plt.cm.plasma)
187 |             self.add_image(f'parameters/Content-independent_target_phase_1080p',
188 |                            ((phase - phase.min())
189 |                             / (phase.max() - phase.min() + 1e-6)), idx)
190 | 
191 |     def add_zeroth_order(self, idx=0):
192 |         """
193 |         Plot the output of the model with zero-phase input.
194 | 
195 |         :param idx: Global step value to record
196 |         """
197 | 
198 |         zero_phase = torch.zeros((1, 1, *self.slm_res)).to(self.model.dev)
199 |         recon_field = self.model(zero_phase)
200 |         recon_amp = recon_field.abs()
201 |         recon_amp = utils.crop_image(recon_amp, self.slm_res,
202 |                                      stacked_complex=False).cpu().detach().squeeze().unsqueeze(0)
203 |         self.add_image(f'parameters/zero_input_1080p', (recon_amp - recon_amp.min())
204 |                        / (recon_amp.max() - recon_amp.min()), idx)
205 |         self.add_figure_cmap(f'parameters/zero_input_figure', recon_amp.squeeze(), idx, self.cmap_rgb)
206 | 
207 |     def add_latent_codes(self, idx=0, chs=(0, 1)):
208 |         """
209 |         plot latent codes (if they exist)
210 | 
211 |         :param idx: Global step value to record
212 |         :param chs: a list of channel indices to visualize
213 |         """
214 |         if self.model.latent_code is not None:
215 |             for ch in chs:
216 |                 lc = self.model.latent_code[0, ch, ...]
217 |                 self.add_figure_cmap(f'parameters/latent_code/{ch}', lc.cpu().detach().squeeze().numpy(),
218 |                                      idx, plt.cm.plasma)
219 | 
220 |     def add_content_dependent_field(self, phase, idx=0):
221 |         if self.model.content_dependent_field is not None:
222 |             cdf = self.model.content_dependent_field(phase, self.model.latent_coords)
223 |             cdf_amp, cdf_phase = cdf[..., 0], cdf[..., 1]
224 |             cdf_amp = cdf_amp.cpu().detach().squeeze().unsqueeze(0)
225 |             cdf_phase = cdf_phase.cpu().detach().squeeze().unsqueeze(0)
226 | 
227 |             self.add_figure_cmap(f'parameters/content_dependent_amp', cdf_amp, idx, self.cmap_rgb)
228 |             self.add_figure_cmap(f'parameters/content_dependent_phase', cdf_phase, idx, plt.cm.plasma)
229 | 
230 |     def add_figure_cmap(self, title, img, idx, cmap=plt.cm.plasma):
231 |         figure = plt.figure()
232 |         p = plt.imshow(img.squeeze())
233 |         p.set_cmap(cmap)
234 |         colorbar(p)
235 |         self.add_figure(title, figure, idx)
236 | 
237 | 
238 | def colorbar(mappable):
239 |     last_axes = plt.gca()
240 |     ax = mappable.axes
241 |     fig = ax.figure
242 |     divider = make_axes_locatable(ax)
243 |     cax = divider.append_axes("right", size="5%", pad=0.05)
244 |     cbar = fig.colorbar(mappable, cax=cax)
245 |     plt.sca(last_axes)
246 |     return cbar
247 | 
--------------------------------------------------------------------------------