├── LICENSE
├── README.md
├── algorithms.py
├── configs_2d.txt
├── configs_3d.txt
├── configs_4d.txt
├── configs_model_training.txt
├── env.yml
├── holo2lf.py
├── hw
│   ├── __init__.py
│   ├── calibration_module.py
│   ├── camera_capture_module.py
│   ├── detect_heds_module_path.py
│   ├── discrete_slm.py
│   ├── phase_encodings.py
│   ├── slm_display_module.py
│   ├── ti.py
│   └── ti_encodings.py
├── image_loader.py
├── img
│   └── teaser.png
├── main.py
├── params.py
├── props
│   ├── __init__.py
│   ├── prop_ideal.py
│   ├── prop_model.py
│   ├── prop_physical.py
│   ├── prop_submodules.py
│   └── prop_zernike.py
├── quantization.py
├── train.py
├── unet.py
└── utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stanford Computational Imaging Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Time-multiplexed Neural Holography: A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
SIGGRAPH 2022 2 | ### [Project Page](http://www.computationalimaging.org/publications/time-multiplexed-neural-holography/) | [Video](https://youtu.be/k2dg-Ckhk5Q) | [Paper](https://drive.google.com/file/d/1n8xSdHgW0D5G5HhwSKrqCy1iztAcDHgX/view?usp=sharing) 3 | PyTorch implementation of
4 | [Time-multiplexed Neural Holography: A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators](http://www.computationalimaging.org/publications/time-multiplexed-neural-holography/)
5 | [Suyeon Choi](http://stanford.edu/~suyeon/)\*, 6 | [Manu Gopakumar](https://www.linkedin.com/in/manu-gopakumar-25032412b/)\*, 7 | [Yifan Peng](http://web.stanford.edu/~evanpeng/), 8 | [Jonghyun Kim](http://j-kim.kr/), 9 | [Matthew O'Toole](https://www.cs.cmu.edu/~motoole2/), 10 | [Gordon Wetzstein](https://computationalimaging.org)
11 | \*denotes equal contribution 12 | in SIGGRAPH 2022 13 | 14 | 15 | 16 | ## Get started 17 | Our code uses [PyTorch Lightning](https://www.pytorchlightning.ai/) and PyTorch >=1.10.0. 18 | 19 | You can set up a conda environment with all dependencies like so: 20 | ``` 21 | conda env create -f env.yml 22 | conda activate tmnh 23 | ``` 24 | 25 | ## High-level structure 26 | The code is organized as follows: 27 | 28 | 29 | `./` 30 | * ```main.py``` generates phase patterns from LF/RGBD/RGB data using SGD. 31 | * ```holo2lf.py``` contains the Light-field ↔ Hologram conversion implementations. 32 | * ```algorithms.py``` contains the gradient-descent-based algorithms for LF/RGBD/RGB supervision. 33 | 34 | * ```params.py``` contains our default parameter settings. :heavy_exclamation_mark:**(Replace values here with those in your setup.)**:heavy_exclamation_mark: 35 | 36 | * ```quantization.py``` contains modules for quantization (projected gradient, sigmoid, Gumbel-Softmax). 37 | * ```image_loader.py``` contains data loader modules. 38 | * ```utils.py``` contains other utilities. 39 | 40 | 41 | 42 | 43 | `./props/` contains the wave propagation operators (in simulation and physics). 44 | 45 | `./hw/` contains modules for hardware control and homography calibration. 46 | * ```ti.py``` contains data provided by Texas Instruments. 47 | * ```ti_encodings.py``` contains phase encoding and decoding functionality for the TI SLM. 48 | 49 | 50 | ## Run 51 | To run, download the sample images from [here](https://drive.google.com/file/d/1aooTbzsmGw-Rfel7ntb1HJY1kILLSuEk/view?usp=sharing) and place the contents in the `data/` folder. 52 | 53 | ### Dataset generation / Model training 54 | Please see the [supplement](https://drive.google.com/file/d/1n9hdLq1xvur4I_OkGNyFgoKHGDZcMxcE/view) and the [Neural 3D Holography repo](https://github.com/computational-imaging/neural-3d-holography) for more details on dataset generation and model training. 
55 | ``` 56 | # Train TMNH models (set dataset_path to your dataset directory first) 57 | for c in 0 1 2 58 | do 59 | python train.py -c=configs_model_training.txt --channel=$c --data_path=${dataset_path} 60 | done 61 | 62 | ``` 63 | 64 | 65 | ### Run SGD with various target distributions (RGB images, focal stacks, and light fields) 66 | ``` 67 | for c in 0 1 2 68 | do 69 | # 2D RGB images 70 | python main.py -c=configs_2d.txt --channel=$c 71 | # 3D focal stacks 72 | python main.py -c=configs_3d.txt --channel=$c 73 | # 4D light fields 74 | python main.py -c=configs_4d.txt --channel=$c 75 | done 76 | ``` 77 | 78 | ### Run SGD with advanced quantizations 79 | ``` 80 | q=gumbel-softmax; # try none, nn, nn_sigmoid as well. 81 | for c in 0 1 2; do python main.py -c=configs_2d.txt --channel=$c --quan_method=$q; done 82 | 83 | ``` 84 | 85 | ## Citation 86 | If you find our work useful in your research, please cite: 87 | ``` 88 | @inproceedings{choi2022time, 89 | author = {Choi, Suyeon 90 | and Gopakumar, Manu 91 | and Peng, Yifan 92 | and Kim, Jonghyun 93 | and O'Toole, Matthew 94 | and Wetzstein, Gordon}, 95 | title={Time-multiplexed neural holography: a flexible framework for holographic near-eye displays with fast heavily-quantized spatial light modulators}, 96 | booktitle={ACM SIGGRAPH 2022 Conference Proceedings}, 97 | pages={1--9}, 98 | year={2022} 99 | } 100 | ``` 101 | 102 | ## Acknowledgments 103 | Thanks to [Brian Chao](https://bchao1.github.io/) for help with code updates and [Cindy Nguyen](https://ccnguyen.github.io) for helpful discussions. This project was supported in part by a Kwanjeong Scholarship, a Stanford SGF, Intel, NSF (award 1839974), a PECASE by the ARO (W911NF-19-1-0120), and Sony. 104 | 105 | ## Contact 106 | If you have any questions, please feel free to email the authors. -------------------------------------------------------------------------------- /algorithms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various algorithms for LF/RGBD/RGB supervision.
3 | 4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu) 5 | 6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 8 | # The material is provided as-is, with no warranties whatsoever. 9 | # If you publish any code, data, or scientific work based on this, please cite our work. 10 | 11 | Technical Paper: 12 | Time-multiplexed Neural Holography: 13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein. 15 | SIGGRAPH 2022 16 | """ 17 | 18 | import imageio 19 | from PIL import Image, ImageDraw 20 | import torch 21 | import torch.nn as nn 22 | import torch.optim as optim 23 | import torchvision.transforms.functional as TF 24 | import numpy as np 25 | from tqdm import tqdm 26 | from copy import deepcopy 27 | 28 | import utils 29 | from holo2lf import holo2lf 30 | 31 | def load_alg(alg_type, mem_eff=False): 32 | if 'sgd' in alg_type.lower(): 33 | if mem_eff: 34 | algorithm = efficient_gradient_descent 35 | else: 36 | algorithm = gradient_descent 37 | else: 38 | raise ValueError(f"Algorithm {alg_type} is not supported!") 39 | 40 | return algorithm 41 | 42 | def gradient_descent(init_phase, target_amp, target_mask=None, target_idx=None, forward_prop=None, num_iters=1000, roi_res=None, 43 | border_margin=None, loss_fn=nn.MSELoss(), lr=0.01, out_path_idx='./results', 44 | citl=False, camera_prop=None, writer=None, quantization=None, 45 | time_joint=True, flipud=False, reg_lf_var=0.0, *args, **kwargs): 46 | """ 47 | Gradient-descent based method for phase optimization. 
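48 | Also accepts border_margin (pixels zeroed at the target borders) and reg_lf_var (weight of the STFT-based light-field variance regularizer). 49 | :param init_phase: initial SLM phase; dim 0 indexes the time-multiplexed frames. 50 | :param target_amp: target amplitude; more than 4 dims selects light-field supervision. 51 | :param target_mask: optional mask applied to the target and the reconstruction. 52 | :param forward_prop: propagation module mapping the complex SLM field to the target plane(s). 53 | :param num_iters: number of optimization iterations. 54 | :param roi_res: resolution of the region of interest used to crop target and reconstruction. 55 | :param loss_fn: loss between the scaled reconstruction and the target. 56 | :param lr: Adam learning rate. 57 | :param out_path_idx: unused here. 58 | :param citl: if True, use camera-in-the-loop surrogate gradients. 59 | :param camera_prop: callable returning the captured amplitude for a given SLM phase (CITL only). 60 | :param writer: unused here. 61 | :param quantization: optional callable, called as quantization(phase, t / num_iters). 62 | :param time_joint: if True, average intensity over frames (time-multiplexed forward model). 63 | :param flipud: if True, flip the phase along dim 2 before propagation. 64 | :param args: unused. 65 | :param kwargs: n_fft, hop_len, win_len (passed to holo2lf) and optimize_amp. 66 | :return: dict with 'loss_vals', 'psnr_vals', 'loss_vals_q', 'best_iter', 'best_loss', 'recon_amp', 'target_amp', 'final_phase'. 67 | """ 68 | print("Naive gradient descent") 69 | assert forward_prop is not None 70 | dev = init_phase.device 71 | 72 | 73 | h, w = init_phase.shape[-2], init_phase.shape[-1] # total energy = h*w 74 | 75 | init_amp = torch.ones_like(init_phase) * 0.5 76 | init_amp_logits = torch.log(init_amp / (1 - init_amp)) # convert to inverse sigmoid 77 | 78 | slm_phase = init_phase.requires_grad_(True) # phase at the slm plane 79 | slm_amp_logits = init_amp_logits.requires_grad_(True) # amplitude at the slm plane 80 | 81 | optvars = [{'params': slm_phase}] 82 | if kwargs.get("optimize_amp", False): # note: slm_amp_logits is not used by the forward model below 83 | optvars.append({'params': slm_amp_logits}) 84 | 85 | #if "opt_s" in reg_loss_fn_type: 86 | # s = torch.tensor(1.0).requires_grad_(True) # initial s value 87 | # optvars.append({'params': s}) 88 | #else: 89 | # s = None 90 | s = torch.tensor(1.0) 91 | optimizer = optim.Adam(optvars, lr=lr) 92 | 93 | loss_vals = [] 94 | psnr_vals = [] 95 | loss_vals_quantized = [] 96 | best_loss = 1e10 97 | best_iter = 0 98 | best_amp = None 99 | lf_supervision = len(target_amp.shape) > 4 100 | 101 | print("target amp shape", target_amp.shape) 102 | 103 | if target_mask is not None: 104 | target_amp = target_amp * target_mask 105 | nonzeros = target_mask > 0 106 | if roi_res is not None: 107 | target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False, lf=lf_supervision) 108 | if target_mask is not None: 109 | target_mask = utils.crop_image(target_mask, roi_res, stacked_complex=False, lf=lf_supervision) 110 | nonzeros = target_mask > 0 111 | 112 | if border_margin is not None: 113 | # make borders of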
target black 114 | mask = torch.zeros_like(target_amp) 115 | mask[:, :, border_margin:-border_margin, border_margin:-border_margin] = 1 116 | target_amp = target_amp * mask 117 | 118 | for t in tqdm(range(num_iters)): 119 | optimizer.zero_grad() 120 | if quantization is not None: 121 | quantized_phase = quantization(slm_phase, t/num_iters) 122 | else: 123 | quantized_phase = slm_phase 124 | 125 | if flipud: 126 | quantized_phase_f = quantized_phase.flip(dims=[2]) 127 | else: 128 | quantized_phase_f = quantized_phase 129 | 130 | field_input = torch.exp(1j * quantized_phase_f) 131 | 132 | recon_field = forward_prop(field_input) 133 | recon_field = utils.crop_image(recon_field, roi_res, pytorch=True, stacked_complex=False) # crop to the region of interest 134 | 135 | if lf_supervision: 136 | recon_amp_t = holo2lf(recon_field, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'], 137 | win_length=kwargs['win_len'], device=dev, impl='torch').sqrt() 138 | else: 139 | recon_amp_t = recon_field.abs() 140 | 141 | if time_joint: # time-multiplexed forward model 142 | recon_amp = (recon_amp_t**2).mean(dim=0, keepdims=True).sqrt() 143 | else: 144 | recon_amp = recon_amp_t 145 | 146 | if citl: # surrogate gradients for CITL 147 | captured_amp = camera_prop(slm_phase, 1) 148 | captured_amp = utils.crop_image(captured_amp, roi_res, 149 | stacked_complex=False) 150 | recon_amp_sim = recon_amp.clone() # simulated reconstructed image 151 | recon_amp = recon_amp + captured_amp - recon_amp.detach() # reconstructed image with surrogate gradients 152 | 153 | # apply the target mask 154 | if target_mask is not None: 155 | final_amp = torch.zeros_like(recon_amp) 156 | final_amp[nonzeros] += (recon_amp[nonzeros] * target_mask[nonzeros]) 157 | else: 158 | final_amp = recon_amp 159 | 160 | # closed-form scale s, computed without gradient tracking 161 | with torch.no_grad(): 162 | s = (final_amp * target_amp).mean(dim=(-1, -2), keepdims=True) / (final_amp ** 2).mean(dim=(-1, -2), keepdims=True) # scale minimizing MSE btw
recon and target 163 | 164 | loss_val = loss_fn(s * final_amp, target_amp) 165 | 166 | mse_loss = ((s * final_amp - target_amp)**2).mean().item() 167 | psnr_val = 20 * np.log10(1 / np.sqrt(mse_loss)) 168 | 169 | # loss term for having even emission at in-focus points (STFT-based regularization described in the supplement) 170 | if reg_lf_var > 0.0: 171 | recon_amp_lf = holo2lf(recon_field, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'], 172 | win_length=kwargs['win_len'], device=dev, impl='torch') 173 | recon_amp_lf = s * recon_amp_lf.mean(dim=0, keepdims=True).sqrt() 174 | loss_lf_var = torch.mean(torch.var(recon_amp_lf, (-2, -1))) 175 | loss_val += reg_lf_var * loss_lf_var 176 | 177 | loss_val.backward() 178 | optimizer.step() 179 | loss_vals.append(loss_val.item()) 180 | with torch.no_grad(): 181 | if loss_val.item() < best_loss: 182 | best_phase = slm_phase.detach().clone() # copy, otherwise later updates overwrite the best phase 183 | best_loss = loss_val.item() 184 | best_amp = s * final_amp # fits target image. 185 | best_iter = t + 1 186 | 187 | psnr = 20 * torch.log10(1 / torch.sqrt(((s * final_amp - target_amp)**2).mean())) 188 | psnr_vals.append(psnr.item()) 189 | 190 | return {'loss_vals': loss_vals, 191 | 'psnr_vals': psnr_vals, 192 | 'loss_vals_q': loss_vals_quantized, 193 | 'best_iter': best_iter, 194 | 'best_loss': best_loss, 195 | 'recon_amp': best_amp, 196 | 'target_amp': target_amp, 197 | 'final_phase': best_phase 198 | } 199 | 200 | 201 | def efficient_gradient_descent(init_phase, target_amp, target_mask=None, target_idx=None, forward_prop=None, num_iters=1000, roi_res=None, 202 | loss_fn=nn.MSELoss(), lr=0.01, out_path_idx='./results', 203 | citl=False, camera_prop=None, writer=None, quantization=None, 204 | time_joint=True, flipud=False, *args, **kwargs): 205 | """ 206 | Gradient-descent based method for phase optimization.
207 | Memory-efficient variant of gradient_descent: the autograd graph is built for one frame at a time and gradients are accumulated over frames. Parameters have the same meaning as in gradient_descent; border_margin and reg_lf_var are not supported. 208 | :param init_phase: 209 | :param target_amp: 210 | :param target_mask: 211 | :param forward_prop: 212 | :param num_iters: 213 | :param roi_res: 214 | :param loss_fn: 215 | :param lr: 216 | :param out_path_idx: 217 | :param citl: 218 | :param camera_prop: 219 | :param writer: 220 | :param quantization: 221 | :param time_joint: 222 | :param flipud: 223 | :param args: 224 | :param kwargs: 225 | :return: 226 | """ 227 | print("Memory efficient gradient descent") 228 | 229 | assert forward_prop is not None 230 | dev = init_phase.device 231 | num_frames = init_phase.shape[0] 232 | 233 | slm_phase = init_phase.requires_grad_(True) # phase at the slm plane 234 | optvars = [{'params': slm_phase}] 235 | optimizer = optim.Adam(optvars, lr=lr) 236 | 237 | loss_vals = [] 238 | loss_vals_quantized = [] 239 | best_loss = float('inf') # guarantees best_phase/best_amp/best_iter are set in the first iteration 240 | lf_supervision = len(target_amp.shape) > 4 241 | 242 | if target_mask is not None: 243 | target_amp = target_amp * target_mask 244 | nonzeros = target_mask > 0 245 | if roi_res is not None: 246 | target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False, lf=lf_supervision) 247 | if target_mask is not None: 248 | target_mask = utils.crop_image(target_mask, roi_res, stacked_complex=False, lf=lf_supervision) 249 | nonzeros = target_mask > 0 250 | 251 | for t in tqdm(range(num_iters)): 252 | optimizer.zero_grad() # zero grad 253 | 254 | # amplitude reconstruction without graph 255 | with torch.no_grad(): 256 | if quantization is not None: 257 | quantized_phase = quantization(slm_phase, t/num_iters) 258 | else: 259 | quantized_phase = slm_phase 260 | 261 | if flipud: 262 | quantized_phase_f = quantized_phase.flip(dims=[2]) 263 | else: 264 | quantized_phase_f = quantized_phase 265 | 266 | recon_field = forward_prop(torch.exp(1j * quantized_phase_f)) # complex SLM field, as in gradient_descent 267 | recon_field = utils.crop_image(recon_field, roi_res, stacked_complex=False) 268 | 269 | if lf_supervision: 270 | recon_amp_t = holo2lf(recon_field,
n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'], 271 | win_length=kwargs['win_len'], device=dev, impl='torch').sqrt() 272 | else: 273 | recon_amp_t = recon_field.abs() 274 | 275 | if citl: # surrogate gradients for CITL 276 | captured_amp = camera_prop(slm_phase) 277 | captured_amp = utils.crop_image(captured_amp, roi_res, 278 | stacked_complex=False) 279 | 280 | total_loss_val = 0 281 | # insert single frame's graph and accumulate gradient 282 | for f in range(num_frames): 283 | slm_phase_sf = slm_phase[f:f+1, ...] 284 | if quantization is not None: 285 | quantized_phase_sf = quantization(slm_phase_sf, t/num_iters) 286 | else: 287 | quantized_phase_sf = slm_phase_sf 288 | 289 | if flipud: 290 | quantized_phase_f_sf = quantized_phase_sf.flip(dims=[2]) 291 | else: 292 | quantized_phase_f_sf = quantized_phase_sf 293 | 294 | recon_field_sf = forward_prop(torch.exp(1j * quantized_phase_f_sf)) 295 | recon_field_sf = utils.crop_image(recon_field_sf, roi_res, stacked_complex=False) 296 | 297 | if lf_supervision: 298 | recon_amp_t_sf = holo2lf(recon_field_sf, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'], 299 | win_length=kwargs['win_len'], device=dev, impl='torch').sqrt() 300 | else: 301 | recon_amp_t_sf = recon_field_sf.abs() 302 | 303 | ### insert graph from single frame ### 304 | recon_amp_t_with_grad = recon_amp_t.clone().detach() 305 | recon_amp_t_with_grad[f:f+1,...]
= recon_amp_t_sf 306 | 307 | if time_joint: # time-multiplexed forward model 308 | recon_amp = (recon_amp_t_with_grad**2).mean(dim=0, keepdims=True).sqrt() 309 | else: 310 | recon_amp = recon_amp_t_with_grad 311 | 312 | if citl: 313 | recon_amp = recon_amp + captured_amp / (num_frames) - recon_amp.detach() 314 | 315 | if target_mask is not None: 316 | final_amp = torch.zeros_like(recon_amp) 317 | final_amp[nonzeros] += recon_amp[nonzeros] * target_mask[nonzeros] 318 | else: 319 | final_amp = recon_amp 320 | 321 | 322 | with torch.no_grad(): 323 | s = (final_amp * target_amp).mean() / \ 324 | (final_amp ** 2).mean() # scale minimizing MSE btw recon and target 325 | 326 | 327 | 328 | loss_val = loss_fn(s * final_amp, target_amp) 329 | loss_val.backward(retain_graph=False) 330 | 331 | total_loss_val += loss_val.item() 332 | 333 | if t % 10 == 0: 334 | pass 335 | #writer.add_scalar("loss", total_loss_val, t) 336 | #writer.add_scalar("recon loss", recon_loss.item(), t) 337 | #writer.add_scalar("light eff loss", reg_loss.item(), t) 338 | #writer.add_scalar("s", s.item(), t) 339 | #writer.add_image("recon", torch.clamp(s*final_amp[0], 0, 1), t) 340 | 341 | # update phase variables 342 | optimizer.step() 343 | 344 | with torch.no_grad(): 345 | if total_loss_val < best_loss: 346 | best_phase = slm_phase.detach().clone() # copy, otherwise later updates overwrite the best phase 347 | best_loss = total_loss_val 348 | best_amp = s * recon_amp 349 | best_iter = t + 1 350 | loss_vals.append(total_loss_val) 351 | 352 | return {'loss_vals': loss_vals, 353 | 'loss_vals_q': loss_vals_quantized, 354 | 'best_iter': best_iter, 355 | 'best_loss': best_loss, 356 | 'recon_amp': best_amp, 357 | 'target_amp': target_amp, 358 | 'final_phase': best_phase, 359 | 's': s.item()} -------------------------------------------------------------------------------- /configs_2d.txt: -------------------------------------------------------------------------------- 1 | data_path=data/2d 2 | out_path=results 3 | target=2d 4 | loss_func=l2 5 | uniform_nbits=4 6 | eval_plane_idx=3 
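Two steps carry most of the logic in `gradient_descent` (algorithms.py). The first is the time-multiplexed forward model, which averages intensity over the frames stacked along dim 0. The second is a closed-form scale `s`, computed without gradients, that best fits the reconstruction to the target before the loss is evaluated. Below is a minimal standalone PyTorch sketch of these two steps and of the PSNR the function reports. It assumes amplitudes in [0, 1], and the helper names `time_multiplexed_amp`, `optimal_scale`, and `psnr` are hypothetical; they are not part of the repository.

```python
import torch


def time_multiplexed_amp(recon_amp_t: torch.Tensor) -> torch.Tensor:
    """Perceived amplitude of a time-multiplexed frame sequence (hypothetical helper).

    Frames are stacked along dim 0. Their intensities (amplitude squared) are
    averaged, and the square root converts the result back to an amplitude.
    """
    return (recon_amp_t ** 2).mean(dim=0, keepdim=True).sqrt()


def optimal_scale(recon_amp: torch.Tensor, target_amp: torch.Tensor) -> torch.Tensor:
    """Closed-form s minimizing mean((s * recon - target) ** 2) (hypothetical helper).

    Setting d/ds = 0 gives s = sum(recon * target) / sum(recon ** 2). As in
    gradient_descent, the sums run over the last two dims, so each plane of a
    2D/3D target gets its own scale.
    """
    num = (recon_amp * target_amp).mean(dim=(-2, -1), keepdim=True)
    den = (recon_amp ** 2).mean(dim=(-2, -1), keepdim=True)
    return num / den


def psnr(recon: torch.Tensor, target: torch.Tensor, peak: float = 1.0) -> torch.Tensor:
    """PSNR in dB, assuming signals whose peak value is `peak`."""
    mse = ((recon - target) ** 2).mean()
    return 20 * torch.log10(peak / torch.sqrt(mse))


if __name__ == "__main__":
    # Two binary frames: frame 0 fully on, frame 1 fully off.
    frames = torch.zeros(2, 1, 4, 4)
    frames[0] = 1.0
    amp = time_multiplexed_amp(frames)  # sqrt(0.5) at every pixel
    target = torch.full((1, 1, 4, 4), 0.5)
    s = optimal_scale(amp, target)      # rescales sqrt(0.5) to 0.5
    print(torch.allclose(s * amp, target))  # True
```

This example shows why time multiplexing helps with heavily quantized SLMs: two binary frames already produce an intermediate perceived level. Because `s` is solved in closed form rather than learned, the loss stays insensitive to global brightness without adding an optimization variable.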
-------------------------------------------------------------------------------- /configs_3d.txt: -------------------------------------------------------------------------------- 1 | data_path=data/3d_bamboo 2 | out_path=results 3 | target=3d 4 | loss_func=l2 5 | uniform_nbits=4 6 | eyepiece=0.035 7 | -------------------------------------------------------------------------------- /configs_4d.txt: -------------------------------------------------------------------------------- 1 | data_path=data/4d_olas 2 | out_path=results 3 | target=4d 4 | loss_func=l2 5 | uniform_nbits=4 6 | eval_plane_idx=3 7 | eyepiece=0.035 8 | -------------------------------------------------------------------------------- /configs_model_training.txt: -------------------------------------------------------------------------------- 1 | lr=3e-4 2 | batch_size=4 3 | prop_model=nh4d -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: tmnh 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.10 6 | - setuptools 7 | - pip 8 | - wheel 9 | - anaconda-client 10 | - anaconda-project 11 | - anaconda-navigator 12 | - conda 13 | - conda-build 14 | - conda-content-trust 15 | - conda-pack 16 | - conda-package-handling 17 | - conda-package-streaming 18 | - conda-token 19 | - conda-verify 20 | - pip: 21 | - aotools 22 | - kornia 23 | - lightning-utilities 24 | - opencv-python==4.7.0.72 25 | - pytorch-lightning==2.0.4 26 | - serial 27 | - tensorboard==2.13.0 28 | - tensorboard-data-server==0.7.1 29 | - torch 30 | - torchaudio 31 | - torchmetrics 32 | - torchvision 33 | - h5py 34 | - configargparse 35 | - imageio 36 | - scikit-image 37 | - tqdm 38 | -------------------------------------------------------------------------------- /holo2lf.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Implementations of the Light-field ↔ Hologram conversion. Note that the lf2holo method is essentially the OLAS method. 3 | 4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu) 5 | 6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 8 | # The material is provided as-is, with no warranties whatsoever. 9 | # If you publish any code, data, or scientific work based on this, please cite our work. 10 | 11 | Technical Paper: 12 | Time-multiplexed Neural Holography: 13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein. 15 | SIGGRAPH 2022 16 | """ 17 | import math 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | 23 | def holo2lf(input_field, n_fft=(9, 9), hop_length=(1, 1), win_func=None, 24 | win_length=None, device=torch.device('cuda'), impl='torch', predefined_h=None, 25 | return_h=False, h_size=(1, 1)): 26 | """ 27 | Hologram to light field transformation. 28 | 29 | :param input_field: input field of shape (N, 1, H, W); for 1D input, set H=1. 30 | :param n_fft: a tuple (y, x) of the number of Fourier basis functions along each dimension. 31 | :param hop_length: a tuple (y, x) of hop lengths, i.e., the spatial sampling stride of the output. 32 | :param win_func: window function applied to each segment (default: Hann window). 33 | :param win_length: a tuple (y, x) of window lengths. If win_length is smaller than n_fft, the windows are zero-padded. 34 | :param device: torch device used to build the default window (default: CUDA).
35 | :param impl: implementation, either 'torch' or 'olas' 36 | :return: A 4D representation of light field, shape of (N, 1, H, W, U, V) 37 | """ 38 | input_length = input_field.shape[-2:] 39 | batch_size, _, Ny, Nx = input_field.shape 40 | 41 | # for 1D input (n_fft = 1), don't take Fourier transform toward that direction. 42 | n_fft_y = min(n_fft[0], input_length[0]) 43 | n_fft_x = min(n_fft[1], input_length[1]) 44 | 45 | if win_length is None: 46 | win_length = n_fft 47 | 48 | win_length_y = min(win_length[0], input_length[0]) 49 | win_length_x = min(win_length[1], input_length[1]) 50 | 51 | if win_func is None: 52 | w_func = lambda length: torch.hann_window(length + 1, device=device)[1:] 53 | # w_func = lambda length: torch.ones(length) 54 | win_func = torch.outer(w_func(win_length_y), w_func(win_length_x)) 55 | 56 | win_func = win_func.to(input_field.device) 57 | win_func = win_func / win_func.sum() # out-of-place, so a caller-supplied window is not modified 58 | 59 | if impl == 'torch': 60 | # 1) use STFT implementation of PyTorch 61 | if len(input_field.squeeze().shape) > 1: # with 2D input 62 | # input_field = input_field.view(-1, input_field.shape[-1]) # merge batch & y dimension 63 | input_field = input_field.reshape(np.prod(input_field.size()[:-1]), input_field.shape[-1]) # merge batch & y dimension 64 | 65 | # take 1D stft along x dimension 66 | stft_x = torch.stft(input_field, n_fft=n_fft_x, hop_length=hop_length[1], win_length=win_length_x, 67 | onesided=False, window=win_func[win_length_y//2, :], pad_mode='constant', 68 | normalized=False, return_complex=True) 69 | 70 | if n_fft_y > 1: # 4D light field output 71 | stft_x = stft_x.reshape(batch_size, Ny, n_fft_x, Nx//hop_length[1]).permute(0, 3, 2, 1) 72 | stft_x = stft_x.contiguous().view(-1, Ny) 73 | 74 | # take one more 1D stft along y dimension 75 | stft_xy = torch.stft(stft_x, n_fft=n_fft_y, hop_length=hop_length[0], win_length=win_length_y, 76 | onesided=False, window=win_func[:, win_length_x//2], pad_mode='constant', 77 | normalized=False, return_complex=True) 78 |
79 | # reshape tensor to (N, 1, Y, X, fy, fx) 80 | stft_xy = stft_xy.reshape(batch_size, Nx//hop_length[1], n_fft_x, n_fft_y, Ny//hop_length[0]) 81 | stft_xy = stft_xy.unsqueeze(1).permute(0, 1, 5, 2, 4, 3) 82 | freq_space_rep = torch.fft.fftshift(stft_xy, (-2, -1)) 83 | 84 | else: # 3D light field output 85 | stft_xy = stft_x.reshape(batch_size, Ny, n_fft_x, Nx//hop_length[1]).permute(0, 1, 3, 2) 86 | stft_xy = stft_xy.unsqueeze(1).unsqueeze(4) 87 | freq_space_rep = torch.fft.fftshift(stft_xy, -1) 88 | 89 | else: # with 1D input -- to be deprecated 90 | freq_space_rep = torch.stft(input_field.squeeze(), 91 | n_fft=n_fft, hop_length=hop_length, onesided=False, window=win_func, 92 | win_length=win_length, normalized=False, return_complex=True) 93 | elif impl == 'olas': 94 | # 2) Our own implementation: 95 | # shift the field by up to the window length along y and x, and stack the shifted copies along two new dimensions 96 | overlap_field = torch.zeros(*input_field.shape[:2], 97 | (win_func.shape[0] - 1) + input_length[0], 98 | (win_func.shape[1] - 1) + input_length[1], 99 | win_func.shape[0], win_func.shape[1], 100 | dtype=input_field.dtype).to(input_field.device) 101 | 102 | # slide the input field 103 | for i in range(win_length_y): 104 | for j in range(win_length_x): 105 | overlap_field[..., i:i+input_length[0], j:j+input_length[1], i, j] = input_field 106 | 107 | # along the new dimensions, apply the window function and take the Fourier transform.
108 | win_func = win_func.reshape(1, 1, 1, 1, *win_func.shape) 109 | win_func = win_func.repeat(*input_field.shape[:2], *overlap_field.shape[2:4], 1, 1) 110 | overlap_field *= win_func # apply window 111 | 112 | # take Fourier transform (it will pad zeros when n_fft > win_length) 113 | # apply no normalization since window is already normalized 114 | if n_fft_y > 1: 115 | overlap_field = torch.fft.fftshift(torch.fft.ifft(overlap_field, n=n_fft_y, norm='forward', dim=-2), -2) 116 | freq_space_rep = torch.fft.fftshift(torch.fft.ifft(overlap_field, n=n_fft_x, norm='forward', dim=-1), -1) 117 | 118 | # keep every hop_length-th sample; when hop_length == win_length, this corresponds to a holographic stereogram (HS). 119 | freq_space_rep = freq_space_rep[:,:, win_length_y//2:win_length_y//2+input_length[0]:hop_length[0], 120 | win_length_x//2:win_length_x//2+input_length[1]:hop_length[1], ...] 121 | 122 | return freq_space_rep.abs()**2 # LF = |U|^2 123 | 124 | 125 | def lf2holo(light_field, light_field_depth, wavelength, pixel_pitch, win=None, target_phase=None): 126 | """ 127 | PyTorch implementation of OLAS (Padmanaban et al., 2019). 128 | 129 | :param light_field: light-field intensity of shape (N, C, H, W, U, V); the hologram is accumulated with .squeeze(), so N = C = 1 is assumed. 130 | :param light_field_depth: per-ray depth, same shape as light_field. 131 | :param wavelength: wavelength, in the same units as pixel_pitch. 132 | :param pixel_pitch: (y, x) pixel pitch of the hologram. 133 | :param win: optional (U, V) synthesis window; defaults to a Hann window without the leading zero. 134 | :param target_phase: optional phase map, resampled with grid_sample and added to the phase of each ray. 135 | :return: complex hologram of shape (1, 1, H, W). 136 | """ 137 | 138 | # hogel size is the same as the angular resolution 139 | res_hogel = light_field.shape[-2:] 140 | 141 | # hologram resolution is the same as the spatial resolution of the light field 142 | res_hologram = light_field.shape[2:4] 143 | 144 | # initialize hologram with zeros, padded to avoid edges/for centering 145 | radius_hogel = torch.tensor(res_hogel) // 2 146 | apas_ola = torch.zeros(*(torch.tensor(res_hologram) + radius_hogel * 2), 147 | dtype=torch.complex64, device=light_field.device) 148 | 149 | ####################################################################### 150 | # compute synthesis window 151 | # custom version of hann without zeros at start 152 | if win is None: 153 | w_func = lambda
length: torch.hann_window(length + 1, device=light_field.device)[1:] 154 | # w_func = lambda length: torch.ones(length) 155 | win = torch.outer(w_func(res_hogel[0]), w_func(res_hogel[1])) 156 | win = win / win.sum() # out-of-place, so a caller-supplied window is not modified 157 | 158 | ####################################################################### 159 | 160 | # compute complex field 161 | comp_depth = torch.zeros(light_field_depth.shape, device=light_field.device) 162 | 163 | # apply depth compensation 164 | fx = torch.linspace(-1 + 1 / res_hogel[1], 1 - 1 / res_hogel[1], 165 | res_hogel[1], device=light_field.device) / (2 * pixel_pitch[1]) 166 | fy = torch.linspace(-1 + 1 / res_hogel[0], 1 - 1 / res_hogel[0], 167 | res_hogel[0], device=light_field.device) / (2 * pixel_pitch[0]) 168 | 169 | y = torch.linspace(-pixel_pitch[0] * res_hologram[0] / 2, 170 | pixel_pitch[0] * res_hologram[0] / 2, 171 | res_hologram[0], device=light_field.device) 172 | x = torch.linspace(-pixel_pitch[1] * res_hologram[1] / 2, 173 | pixel_pitch[1] * res_hologram[1] / 2, 174 | res_hologram[1], device=light_field.device) 175 | y, x = torch.meshgrid(y, x, indexing='ij') 176 | 177 | for ky in range(res_hogel[0]): 178 | for kx in range(res_hogel[1]): 179 | theta = torch.asin(torch.sqrt(fx[kx] ** 2 + fy[ky] ** 2) * wavelength) 180 | comp_depth[..., ky, kx] = (light_field_depth[..., ky, kx] * (1 - torch.cos(theta))) 181 | 182 | # comp_depth[..., ky, kx] = (fx[kx] * x + fy[ky] * y) * wavelength 183 | print(comp_depth.max(), comp_depth.min()) 184 | 185 | 
fx[kx] * wavelength 196 | + x.unsqueeze(-1).unsqueeze(0).unsqueeze(0)) * 2/(pixel_pitch[1] * target_phase.shape[-1]) 197 | for ky in range(res_hogel[0]): 198 | for kx in range(res_hogel[1]): 199 | sample_grid = torch.stack((x_pos[:, 0, :, :, ky, kx], y_pos[:, 0, :, :, ky, kx]), -1) 200 | comp_phase[..., ky, kx] += F.grid_sample(target_phase, sample_grid, 201 | padding_mode='reflection') 202 | 203 | complex_lf = comp_amp * torch.exp(1j * comp_phase) 204 | 205 | # fft over the hogel dimension 206 | complex_lf = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(complex_lf, dim=(-2, -1)), 207 | dim=(-2, -1), norm='forward'), dim=(-2, -1)) 208 | 209 | # apply window; the four leading singleton dims broadcast over batch, color, and the two spatial dims 210 | complex_lf = complex_lf * win[None, None, None, None, ...] 211 | 212 | # overlap and add the hogels 213 | for ky in range(res_hogel[0]): 214 | for kx in range(res_hogel[1]): 215 | apas_ola[..., 216 | ky:ky + res_hologram[0], 217 | kx:kx + res_hologram[1]] += complex_lf[..., ky, kx].squeeze() 218 | 219 | # crop back to light field size 220 | return apas_ola[..., radius_hogel[0]:-radius_hogel[0], radius_hogel[1]:-radius_hogel[1]].unsqueeze(0).unsqueeze(0) 221 | 222 | -------------------------------------------------------------------------------- /hw/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/time-multiplexed-neural-holography/5cf6c275c459652abb3ddddd2e167f9584072aeb/hw/__init__.py -------------------------------------------------------------------------------- /hw/calibration_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script contains the calibration module, which computes the homography matrix. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.)
In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020. 11 | """ 12 | 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import torchvision 16 | import cv2 17 | import skimage.transform as transform 18 | import time 19 | import datetime 20 | from scipy.io import savemat, loadmat 21 | from scipy.ndimage import map_coordinates 22 | import torch 23 | import torch.nn.functional as F 24 | import torch.nn as nn 25 | 26 | def id(x): 27 | return x 28 | 29 | def circle_detect(captured_img, num_circles, spacing, pad_pixels=(0., 0.), show_preview=True, quadratic=False): 30 | """ 31 | Detects the circles of a circle-board pattern and estimates the warp to a reference grid. 32 | 33 | :param captured_img: captured image 34 | :param num_circles: a tuple of integers, (num_circle_x, num_circle_y) 35 | :param spacing: a tuple of integers, in pixels, (space between circles in x, space between circles in y) 36 | :param show_preview: boolean, default True 37 | :param pad_pixels: coordinate of the left top corner of warped image. 38 | The same amount of padding is assumed on the opposite side. 39 | :return: a tuple, (found_dots, H, coords) 40 | found_dots: boolean, indicating success of calibration 41 | H: a 3x3 homography matrix (numpy); coords: warp coordinates from skimage.transform.warp_coords 42 | """ 43 | 44 | # Binarization 45 | # org_copy = org.copy() # Otherwise, we write on the original image!
46 | img = (np.clip(captured_img.copy(), 0, 1) * 255).astype(np.uint8) 47 | print(img[...,0].mean()) 48 | print(img[...,1].mean()) 49 | print(img[...,2].mean()) 50 | if len(img.shape) > 2: 51 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 52 | cv2.imwrite("temp/img_gray.png", img) 53 | print(img[...,0].mean()) 54 | print(img[...,1].mean()) 55 | print(img[...,2].mean()) 56 | 57 | 58 | img = cv2.medianBlur(img, 5) # Red 71 59 | # cv2.imwrite("temp/img_blur.png", img) 60 | img_gray = img.copy() 61 | 62 | img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 201, 0) 63 | cv2.imwrite("temp/img_adapt_thres.png", img) 64 | 65 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) 66 | img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel) 67 | cv2.imwrite("temp/img_open.png", img) 68 | img = 255 - img 69 | 70 | # Blob detection 71 | params = cv2.SimpleBlobDetector_Params() 72 | 73 | # Change thresholds 74 | params.filterByColor = True 75 | params.minThreshold = 121 76 | 77 | # Filter by Area. 
78 | params.filterByArea = True 79 | params.minArea = 150 80 | 81 | # Filter by Circularity 82 | params.filterByCircularity = True 83 | params.minCircularity = 0.5 # change here, easier to detect blob 84 | 85 | # Filter by Convexity 86 | params.filterByConvexity = True 87 | params.minConvexity = 0.3 88 | 89 | # Filter by Inertia 90 | params.filterByInertia = False 91 | params.minInertiaRatio = 0.01 92 | 93 | detector = cv2.SimpleBlobDetector_create(params) 94 | 95 | # Detecting keypoints 96 | # this is redundant for what comes next, but gives us access to the detected dots for debug 97 | keypoints = detector.detect(img) 98 | found_dots, centers = cv2.findCirclesGrid(img, (num_circles[1], num_circles[0]), 99 | blobDetector=detector, flags=cv2.CALIB_CB_SYMMETRIC_GRID) 100 | 101 | # Drawing the keypoints 102 | cv2.drawChessboardCorners(captured_img, num_circles, centers, found_dots) 103 | img_gray = cv2.drawKeypoints(img_gray, keypoints, np.array([]), (0, 255, 0), 104 | cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS) 105 | 106 | # Find transformation 107 | H = np.array([[1., 0., 0.], 108 | [0., 1., 0.], 109 | [0., 0., 1.]], dtype=np.float32) 110 | ref_pts = np.zeros((num_circles[0] * num_circles[1], 1, 2), np.float32) 111 | pos = 0 112 | for j in range(0, num_circles[0]): 113 | for i in range(0, num_circles[1]): 114 | ref_pts[pos, 0, :] = spacing * np.array([i, j]) + np.array([pad_pixels[1], pad_pixels[0]]) 115 | 116 | pos += 1 117 | ref_pts = ref_pts.reshape(num_circles[0] * num_circles[1], 2) 118 | if found_dots: 119 | # Generate reference points to compute the homography 120 | print("Found dots") 121 | H, mask = cv2.findHomography(centers, ref_pts, cv2.RANSAC, 1) 122 | 123 | centers = np.flip(centers.reshape(num_circles[0] * num_circles[1], 2), 1) 124 | homography_cache = {'H':H, 'centers':centers} 125 | savemat(f'./cache_h.mat', homography_cache) 126 | else: 127 | print("No dots") 128 | homography_cache = loadmat(f'./cache_h.mat') 129 | H = homography_cache['H'] 130 | 
centers = homography_cache['centers'] 131 | 132 | 133 | now = datetime.datetime.now() 134 | mdic = {"centers": centers, 'H': H} 135 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs) 136 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels) ] 137 | if quadratic: 138 | H = transform.estimate_transform('polynomial', ref_pts, centers) 139 | coords = transform.warp_coords(H, dsize, dtype=np.float32) # for pytorch 140 | else: 141 | tf = transform.estimate_transform('projective', ref_pts, centers) 142 | coords = transform.warp_coords(tf, (800, 1280), dtype=np.float32) # for pytorch 143 | 144 | if show_preview: 145 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs) 146 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels)] 147 | if quadratic: 148 | captured_img_warp = transform.warp(captured_img, H, output_shape=(dsize[0], dsize[1])) 149 | else: 150 | captured_img_warp = cv2.warpPerspective(captured_img, H, (dsize[1], dsize[0])) 151 | 152 | 153 | if show_preview: 154 | fig = plt.figure() 155 | 156 | ax = fig.add_subplot(223) # grayscale 157 | ax.imshow(img_gray, cmap='gray') 158 | 159 | ax2 = fig.add_subplot(221) # binarized image 160 | ax2.imshow(img, cmap='gray') 161 | 162 | ax3 = fig.add_subplot(222) # captured image 163 | ax3.imshow(captured_img, cmap='gray') 164 | 165 | if found_dots: 166 | ax4 = fig.add_subplot(224) 167 | ax4.imshow(captured_img_warp, cmap='gray') 168 | 169 | plt.show() 170 | 171 | return found_dots, H, coords 172 | 173 | 174 | class Warper(nn.Module): 175 | def __init__(self, params_calib): 176 | super(Warper, self).__init__() 177 | self.num_circles = params_calib.num_circles 178 | self.spacing_size = params_calib.spacing_size 179 | self.pad_pixels = params_calib.pad_pixels 180 | self.quadratic = params_calib.quadratic 181 | self.img_size_native = params_calib.img_size_native # get this from image 182 | self.h_transform = np.array([[1., 0., 0.], 183 | [0., 1., 0.], 184 | [0., 0., 1.]]) 185 | 
self.range_x = params_calib.range_x # slice 186 | self.range_y = params_calib.range_y # slice 187 | 188 | 189 | def calibrate(self, img, show_preview=True): 190 | img_masked = np.zeros_like(img) 191 | img_masked[self.range_y, self.range_x, ...] = img[self.range_y, self.range_x, ...] 192 | 193 | found_corners, self.h_transform, self.coords = circle_detect(img_masked, self.num_circles, 194 | self.spacing_size, self.pad_pixels, show_preview, 195 | quadratic=self.quadratic) 196 | 197 | if self.coords is not None: 198 | self.coords_tensor = torch.tensor(np.transpose(self.coords, (1, 2, 0)), 199 | dtype=torch.float32).unsqueeze(0) 200 | 201 | # normalize it into [-1, 1] 202 | self.coords_tensor[..., 0] = 2*self.coords_tensor[..., 0] / (self.img_size_native[1]-1) - 1 203 | self.coords_tensor[..., 1] = 2*self.coords_tensor[..., 1] / (self.img_size_native[0]-1) - 1 204 | 205 | return found_corners 206 | 207 | def __call__(self, input_img, img_size=None): 208 | """ 209 | This forward pass returns the warped image. 210 | 211 | :param input_img: a numpy grayscale image of shape [H, W], or a torch tensor (warped with grid_sample). 212 | :param img_size: output size, default None (derived from the calibration pattern). 213 | :return: output_img: warped image with pre-calculated homography and destination size.
214 | """ 215 | 216 | if img_size is None: 217 | img_size = [int((num_circs - 1) * space + 2 * pad_pixs) 218 | for num_circs, space, pad_pixs in zip(self.num_circles, self.spacing_size, self.pad_pixels)] 219 | 220 | if torch.is_tensor(input_img): 221 | output_img = F.grid_sample(input_img, self.coords_tensor, align_corners=True) 222 | else: 223 | if self.quadratic: 224 | output_img = transform.warp(input_img, self.h_transform, output_shape=(img_size[0], img_size[1])) 225 | else: 226 | output_img = cv2.warpPerspective(input_img, self.h_transform, (img_size[0], img_size[1])) 227 | 228 | return output_img 229 | 230 | @property 231 | def h_transform(self): 232 | return self._h_transform 233 | 234 | @h_transform.setter 235 | def h_transform(self, new_h): 236 | self._h_transform = new_h 237 | 238 | def to(self, *args, **kwargs): 239 | slf = super().to(*args, **kwargs) 240 | if getattr(slf, 'coords_tensor', None) is not None: # coords_tensor only exists after calibrate() 241 | slf.coords_tensor = slf.coords_tensor.to(*args, **kwargs) 242 | try: 243 | slf.dev = next(slf.parameters()).device 244 | except StopIteration: # no parameters 245 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0] 246 | if device_arg is not None: 247 | slf.dev = device_arg 248 | return slf -------------------------------------------------------------------------------- /hw/camera_capture_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing the camera capture module (a PyCapture2 wrapper). 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 | 9 | Technical Paper: 10 | Time-multiplexed Neural Holography: 11 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 12 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein. 13 | SIGGRAPH 2022 14 | """ 15 | 16 | import PyCapture2 17 | import cv2 18 | import numpy as np 19 | import utils 20 | 21 | 22 | def callback_captured(image): 23 | print(image.getData()) 24 | 25 | 26 | class CameraCapture: 27 | def __init__(self, params): 28 | self.bus = PyCapture2.BusManager() 29 | num_cams = self.bus.getNumOfCameras() 30 | if not num_cams: 31 | exit() 32 | # self.demosaick_rule = cv2.COLOR_BAYER_RG2BGR 33 | #self.demosaick_rule = cv2.COLOR_BAYER_GR2RGB # GBRG to RGB 34 | self.demosaick_rule = cv2.COLOR_BAYER_BG2RGB # RGGB to RGB, Grasshopper3, U3, projector 35 | self.params = params 36 | 37 | def connect(self, i, trigger=False): 38 | uid = self.bus.getCameraFromIndex(i) 39 | self.camera_device = PyCapture2.Camera() 40 | self.camera_device.connect(uid) 41 | self.camera_device.setConfiguration(highPerformanceRetrieveBuffer=True) 42 | self.camera_device.setConfiguration(numBuffers=1000) 43 | config = self.camera_device.getConfiguration() 44 | self.toggle_embedded_timestamp(True) 45 | 46 | if trigger: 47 | trigger_mode = self.camera_device.getTriggerMode() 48 | trigger_mode.onOff = True 49 | trigger_mode.mode = 0 50 | trigger_mode.parameter = 0 51 | trigger_mode.source = 3 # Using software trigger 52 | self.camera_device.setTriggerMode(trigger_mode) 53 | else: 54 | trigger_mode = self.camera_device.getTriggerMode() 55 | trigger_mode.onOff = False 56 | trigger_mode.mode = 0 57 | trigger_mode.parameter = 0 58 | trigger_mode.source = 3 # Using software trigger 59 | self.camera_device.setTriggerMode(trigger_mode) 60 | 61 | trigger_mode = self.camera_device.getTriggerMode() 62 | if trigger_mode.onOff is True: 63 | print(' - setting trigger mode on') 64 | 65 | def set_shutter_speed(self, val): 66 | 
self.camera_device.setProperty(type = PyCapture2.PROPERTY_TYPE.SHUTTER, autoManualMode = False, absValue = val) 67 | shutter_speed = self.camera_device.getProperty(PyCapture2.PROPERTY_TYPE.SHUTTER ).absValue 68 | print(f"Shutter speed set to {shutter_speed}ms") 69 | 70 | def set_gain(self, val): 71 | self.camera_device.setProperty(type = PyCapture2.PROPERTY_TYPE.GAIN, autoManualMode = False, absValue = val) 72 | gain = self.camera_device.getProperty(PyCapture2.PROPERTY_TYPE.GAIN).absValue 73 | print(f"Gain set to {gain}dB") 74 | 75 | def disconnect(self): 76 | self.toggle_embedded_timestamp(False) 77 | self.camera_device.disconnect() 78 | 79 | def toggle_embedded_timestamp(self, enable_timestamp): 80 | embedded_info = self.camera_device.getEmbeddedImageInfo() 81 | if embedded_info.available.timestamp: 82 | self.camera_device.setEmbeddedImageInfo(timestamp=enable_timestamp) 83 | 84 | def grab_images(self, num_images_to_grab=1): 85 | """ 86 | Retrieves the camera buffer and returns a list of grabbed images. 87 | 88 | :param num_images_to_grab: integer, default 1 89 | :return: a list of numpy 2d color images from the camera buffer. 90 | """ 91 | self.camera_device.startCapture() 92 | img_list = [] 93 | for i in range(num_images_to_grab): 94 | imgData = self.retrieve_buffer() 95 | offset = 64 # offset that inherently exists in the raw data 96 | imgData = imgData - offset 97 | 98 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule) 99 | color_cv_image = utils.im2float(color_cv_image) 100 | img_list.append(color_cv_image.copy()) 101 | 102 | self.camera_device.stopCapture() 103 | return img_list 104 | 105 | def grab_images_fast(self, num_images_to_grab=1): 106 | """ 107 | Retrieves a single image from the camera buffer (capture must already be started). 108 | 109 | :param num_images_to_grab: integer, default 1 (unused) 110 | :return: a numpy 2d color image from the camera buffer. 
111 | """ 112 | imgData = self.retrieve_buffer() 113 | offset = 64 # offset that inherently exists in the raw data
114 | imgData = imgData - offset 115 | 116 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule) 117 | color_cv_image = utils.im2float(color_cv_image) 118 | color_img = color_cv_image 119 | return color_img 120 | 121 | def retrieve_buffer(self): 122 | try: 123 | img = self.camera_device.retrieveBuffer() 124 | except PyCapture2.Fc2error as fc2Err: 125 | raise fc2Err 126 | 127 | imgData = img.getData() 128 | 129 | # when using raw8 from the PG sensor 130 | # cv_image = np.array(img.getData(), dtype="uint8").reshape((img.getRows(), img.getCols())) 131 | 132 | # when using raw16 from the PG sensor - reinterpret each pair of consecutive bytes as one 16-bit pixel 133 | imgData.dtype = np.uint16 134 | imgData = imgData.reshape(img.getRows(), img.getCols()) 135 | return imgData.copy() 136 | 137 | def start_capture(self): 138 | # startCapture/stopCapture were previously called inside grab_images; they can now be called once outside the capture loop 139 | self.camera_device.startCapture() 140 | 141 | def stop_capture(self): 142 | self.camera_device.stopCapture() 143 | 144 | @property 145 | def params(self): 146 | return self._params 147 | 148 | @params.setter 149 | def params(self, p): 150 | self._params = p -------------------------------------------------------------------------------- /hw/detect_heds_module_path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #--------------------------------------------------------------------# 4 | # # 5 | # Copyright (C) 2020 HOLOEYE Photonics AG. All rights reserved. # 6 | # Contact: https://holoeye.com/contact/ # 7 | # # 8 | # This file is part of HOLOEYE SLM Display SDK. # 9 | # # 10 | # You may use this file under the terms and conditions of the # 11 | # 'HOLOEYE SLM Display SDK Standard License v1.0' license agreement.
# 12 | # # 13 | #--------------------------------------------------------------------# 14 | 15 | 16 | # Please import this file in your scripts before actually importing the HOLOEYE SLM Display SDK, 17 | # i. e. copy this file to your project and use this code in your scripts: 18 | # 19 | # import detect_heds_module_path 20 | # import holoeye 21 | # 22 | # 23 | # Another option is to copy the holoeye module directory into your project and import by only using 24 | # import holoeye 25 | # This way, code completion etc. might work better. 26 | 27 | 28 | import os, sys 29 | from platform import system 30 | 31 | # Import the SLM Display SDK: 32 | HEDSModulePath = os.getenv('HEDS_2_PYTHON_MODULES', '') 33 | 34 | if HEDSModulePath == '': 35 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), 36 | 'holoeye', 'slmdisplaysdk', '__init__.py')) 37 | if os.path.isfile(sdklocal): 38 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal))) 39 | else: 40 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 41 | 'sdk', 'holoeye', 'slmdisplaysdk', '__init__.py')) 42 | if os.path.isfile(sdklocal): 43 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal))) 44 | 45 | if HEDSModulePath == '': 46 | if system() == 'Windows': 47 | print('\033[91m' 48 | '\nError: Could not find HOLOEYE SLM Display SDK installation path from environment variable. ' 49 | '\n\nPlease relogin your Windows user account and try again. ' 50 | '\nIf that does not help, please reinstall the SDK and then relogin your user account and try again. ' 51 | '\nA simple restart of the computer might fix the problem, too.' 52 | '\033[0m') 53 | else: 54 | print('\033[91m' 55 | '\nError: Could not detect HOLOEYE SLM Display SDK installation path. ' 56 | '\n\nPlease make sure it is present within the same folder or in "../../sdk".' 
57 | '\033[0m') 58 | 59 | sys.exit(1) 60 | 61 | sys.path.append(HEDSModulePath) 62 | -------------------------------------------------------------------------------- /hw/discrete_slm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Any info about discrete SLM 3 | 4 | Technical Paper: 5 | Time-multiplexed Neural Holography: 6 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 7 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein. 8 | SIGGRAPH 2022 9 | """ 10 | 11 | import torch 12 | import hw.ti as ti 13 | import utils 14 | 15 | 16 | class DiscreteSLM: 17 | """ 18 | Class for Discrete SLM that supports discrete LUT 19 | """ 20 | _lut_midvals = None 21 | _lut = None 22 | prev_idx = 0. 23 | 24 | @property 25 | def lut_midvals(self): 26 | return self._lut_midvals 27 | 28 | @lut_midvals.setter 29 | def lut_midvals(self, new_midvals): 30 | self._lut_midvals = torch.tensor(new_midvals)#, device=torch.device('cuda')) 31 | 32 | @property 33 | def lut(self): 34 | return self._lut 35 | 36 | @lut.setter 37 | def lut(self, new_lut): 38 | if new_lut is None: 39 | self._lut = None 40 | else: 41 | self.lut_midvals = utils.lut_mid(new_lut) 42 | if torch.is_tensor(new_lut): 43 | self._lut = new_lut.clone().detach() 44 | else: 45 | self._lut = torch.tensor(new_lut)#, device=torch.device('cuda')) 46 | 47 | 48 | DiscreteSLM = DiscreteSLM() # class singleton 49 | DiscreteSLM.lut = ti.given_lut 50 | 51 | #num_bits = 4 52 | #DiscreteSLM.lut = np.linspace(-math.pi, math.pi, 2**num_bits + 1) # test for ideal lut 53 | 54 | -------------------------------------------------------------------------------- /hw/phase_encodings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoding and decoding functions for our TI SLM. 
3 | 4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu) 5 | 6 | Technical Paper: 7 | Time-multiplexed Neural Holography: 8 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 9 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein. 10 | SIGGRAPH 2022 11 | """ 12 | 13 | import numpy as np 14 | import hw.ti_encodings as ti_encodings 15 | 16 | 17 | def phasemap_8bit(phasemap, inverted=True): 18 | """Convert a phasemap tensor into a numpy 8-bit phasemap that can be displayed directly. 19 | 20 | :param phasemap: input phasemap tensor, which is supposed to be in the range of [-pi, pi]. 21 | :param inverted: a boolean value that indicates whether the phasemap is inverted. 22 | 23 | :return: output phasemap, with uint8 dtype (in [0, 255]) 24 | """ 25 | 26 | output_phase = ((phasemap + np.pi) % (2 * np.pi)) / (2 * np.pi) 27 | if inverted: 28 | phase_out_8bit = ((1 - output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits 29 | else: 30 | phase_out_8bit = ((output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits 31 | return phase_out_8bit 32 | 33 | 34 | def phase_encoding(phase, slm_type): 35 | """ Phase encoding for the given SLM type; returns None for unsupported types. """ 36 | assert len(phase.shape) == 4 37 | if slm_type.lower() in ('holoeye', 'leto', 'pluto'): 38 | return phasemap_8bit(phase) 39 | elif slm_type.lower() in ('ti', "ee236a"): 40 | return np.fliplr(ti_encodings.rgb_encoding(phase.cpu())) 41 | else: 42 | return None -------------------------------------------------------------------------------- /hw/slm_display_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the script containing the SLM display module, a wrapper around the HOLOEYE SLM Display SDK.
3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | Time-multiplexed Neural Holography: 11 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 12 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein. 13 | SIGGRAPH 2022 14 | """ 15 | 16 | import hw.detect_heds_module_path 17 | import holoeye 18 | from holoeye import slmdisplaysdk 19 | 20 | 21 | class SLMDisplay: 22 | ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode 23 | ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags 24 | State = slmdisplaysdk.SLMDisplay.State 25 | ApplyDataHandleValue = slmdisplaysdk.SLMDisplay.ApplyDataHandleValue 26 | 27 | def __init__(self): 28 | self.ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode 29 | self.ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags 30 | 31 | self.displayOptions = self.ShowFlags.PresentAutomatic # PresentAutomatic == 0 (default) 32 | self.displayOptions |= self.ShowFlags.PresentFitWithBars 33 | 34 | def connect(self): 35 | self.slm_device = slmdisplaysdk.SLMDisplay() 36 | self.slm_device.open() # For version 2.0.1 37 | 38 | def disconnect(self): 39 | self.slm_device.release() 40 | 41 | def show_data_from_file(self, filepath): 42 | error = self.slm_device.showDataFromFile(filepath, self.displayOptions) 43 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error) 44 | 45 | def show_data_from_array(self, numpy_array): 46 | error = self.slm_device.showData(numpy_array) 47 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error) 48 | 
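For Holoeye-type SLMs, `phase_encoding()` above hands the phase to `phasemap_8bit()`, and that uint8 array is what `SLMDisplay.show_data_from_array()` displays. As a hedged illustration of the mapping, the sketch below repeats the same arithmetic on NumPy arrays instead of torch tensors; the name `phasemap_8bit_np` is hypothetical and is not part of the repository.

```python
import numpy as np


def phasemap_8bit_np(phase, inverted=True):
    """Hypothetical NumPy-only sketch of the 8-bit phase encoding.

    Mirrors the arithmetic of phasemap_8bit(), assuming `phase` is a NumPy
    array in radians (nominally within [-pi, pi]).
    """
    # shift so that -pi maps to 0, wrap into [0, 2*pi), then normalize to [0, 1)
    normalized = np.mod(np.asarray(phase, dtype=np.float64) + np.pi, 2 * np.pi) / (2 * np.pi)
    if inverted:
        # inversion flips the gray-level order, so -pi maps to 255
        normalized = 1.0 - normalized
    # scale to the 8-bit range; np.round rounds exact halves to the nearest even value
    return np.round(normalized * 255).astype(np.uint8)


if __name__ == "__main__":
    phases = np.array([-np.pi, 0.0, np.pi])
    print(phasemap_8bit_np(phases, inverted=False))  # +pi wraps onto the same level as -pi
    print(phasemap_8bit_np(phases, inverted=True))
```

Both `np.round` and `torch.round` round exact halves to even, which is why a phase of 0 lands on level 128 rather than 127.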
-------------------------------------------------------------------------------- /hw/ti.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data from the TI SLM manual 3 | 4 | Technical Paper: 5 | Time-multiplexed Neural Holography: 6 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 7 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein. 8 | SIGGRAPH 2022 9 | """ 10 | 11 | import math 12 | import torch 13 | 14 | given_chart = (0., 15 | 1.07, 16 | 2.19, 17 | 4.50, 18 | 5.98, 19 | 7.75, 20 | 12.06, 21 | 18.5, 22 | 36.55, 23 | 39.55, 24 | 45.1, 25 | 52.44, 26 | 63.93, 27 | 71.16, 28 | 85.02, 29 | 100.) 30 | adjusted = [p / 100 * 15 / 16 * 2 * math.pi for p in given_chart] 31 | adjusted.append(adjusted[0] + 2*math.pi) 32 | given_lut = [p - math.pi for p in adjusted] # [-pi, pi] 33 | 34 | idx_order = [4, 2, 1, 0, 7, 6, 5, 3, 11, 10, 9, 8, 15, 14, 13, 12] # see manual 35 | idx2phase = torch.tensor([given_lut[idx_order[i]] for i in range(len(idx_order))]) 36 | -------------------------------------------------------------------------------- /hw/ti_encodings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoding and decoding functions for our TI SLM. 3 | 4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu) 5 | 6 | Technical Paper: 7 | Time-multiplexed Neural Holography: 8 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 9 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein. 
10 | SIGGRAPH 2022 11 | """ 12 | 13 | import numpy as np 14 | import torch 15 | import utils 16 | import hw.ti as ti 17 | from hw.discrete_slm import DiscreteSLM 18 | 19 | 20 | def binary_encoding_ti_slm(phase): 21 | """ Gets phase in [-pi, pi] and returns the binary-encoded phase of the DMD """ 22 | #print("binary", phase.shape) 23 | idx = utils.nearest_idx(phase, DiscreteSLM.lut_midvals) 24 | height = phase.shape[2] * 2 25 | width = phase.shape[3] * 2 26 | 27 | encoded_phase = torch.zeros(1, 1, height, width).to(phase.device) 28 | encoded_phase[:, :, ::2, 1::2] = torch.div(idx, 8, rounding_mode='floor') # M3, ur 29 | encoded_phase[:, :, 1::2, 1::2] = torch.where( 30 | torch.logical_or(idx == 3, 31 | torch.logical_and(idx != 4, idx % 8 >= 4)), 1, 0) # M2, dr 32 | encoded_phase[:, :, ::2, ::2] = torch.where( 33 | torch.logical_or(idx == 3, 34 | torch.logical_and(idx != 4, (idx % 4) < 2)), 1, 0) # M1, ul 35 | encoded_phase[:, :, 1::2, ::2] = torch.where( 36 | torch.logical_or(idx == 3, 37 | torch.logical_and(idx != 4, (idx % 2 == 0))), 1, 0) # M0, dl 38 | 39 | return encoded_phase 40 | 41 | 42 | def bit_encoding(phase, bits): 43 | """ Gets phase of shape (N, 1, H, W) and returns its binary encoding weighted by the bit positions in `bits` """ 44 | 45 | power = sum(2**b for b in bits) 46 | return binary_encoding_ti_slm(phase) * power 47 | 48 | 49 | def rgb_encoding(phase, ch=None): 50 | """ Gets a batch of phase tensors and returns the RGB-encoded phase (for the specific TI SLM) """ 51 | phase = (phase + np.pi) % (2*np.pi) - np.pi 52 | num_phases = len(phase) 53 | #print("rgb", phase.shape) 54 | if num_phases % 3 == 0: 55 | num_bits_per_ch = num_phases // 3 56 | 57 | # placeholder with doubled resolution 58 | res = np.zeros((*(2*p for p in phase.shape[2:]), 3), dtype=np.uint8) 59 | for c in range(3): 60 | res[..., c] = rgb_encoding(phase[c*num_bits_per_ch:(c+1)*num_bits_per_ch, ...]) 61 | return res 62 | else: 63 | phase = sum([bit_encoding(phase[j:j+1, ...], range(j*(8//num_phases), (j+1)*(8//num_phases))) 64 | for j in range(num_phases)]) 65 |
66 | if ch is None: 67 | res = phase.squeeze().cpu().detach().numpy().astype(np.uint8) 68 | else: 69 | res = np.zeros((*phase.shape[2:], 3)) 70 | res[..., ch] = phase.squeeze().cpu().detach().numpy().astype(np.uint8) 71 | 72 | return res 73 | 74 | 75 | def rgb_decoding(phase_img, num_frames=None, one_hot=False): 76 | """ Decodes phase values in [-pi, pi] from a displayed, encoded phase image 77 | 78 | :param phase_img: numpy image of shape [M, N, 3] 79 | :param num_frames: number of encoded frames, if known; when not None, automatic detection is skipped 80 | :param one_hot: if True, return a one-hot decoded image (16 channels) 81 | :return: a tensor of decoded phase, either one-hot or exact values 82 | """ 83 | phase_img_flipped = torch.tensor(phase_img, dtype=torch.float32).flip(dims=[1]) # flip LR here 84 | if len(phase_img_flipped.shape) < 3: 85 | phase_img_flipped = phase_img_flipped.unsqueeze(2) 86 | 87 | # determine the number of frames 88 | if num_frames is None: 89 | num_frames = num_frames_ti_phase(phase_img_flipped) 90 | num_ch = 3 if num_frames % 3 == 0 else 1 91 | # num_bit_per_ch = 8 // (num_frames // num_ch) 92 | num_frames_per_ch = num_frames // num_ch 93 | num_bit_per_ch = 8 // num_frames_per_ch 94 | slm_phase_2x = torch.zeros(num_frames, *phase_img_flipped.shape[:-1]) 95 | 96 | # assign each uniquely encoded binary image to its own frame (stacked in the batch dimension) 97 | for c in range(num_ch): 98 | for i in range(num_frames_per_ch): 99 | f = c * num_frames_per_ch + i 100 | slm_phase_2x[f, ...]
= phase_img_flipped[..., c:c+1].squeeze().clone().detach() % 2 101 | phase_img_flipped[..., c:c+1].div_((2**num_bit_per_ch), rounding_mode='trunc') 102 | 103 | if one_hot: 104 | # return one-hot vector agnostic of the discrete phase values the SLM supports 105 | indices = decode_binary_phase(slm_phase_2x, return_index=True) 106 | output = torch.zeros((len(DiscreteSLM.lut_midvals), *indices.shape[-2:])).scatter_(0, indices, 1.0) 107 | else: 108 | # binary to 4bit, and apply LUT 109 | slm_phase = decode_binary_phase(slm_phase_2x) 110 | output = slm_phase.unsqueeze(1) # return a tensor shape of (N, 1, H, W) 111 | 112 | return output 113 | 114 | 115 | def num_frames_ti_phase(phase_img): 116 | """ 117 | return the number of frames encoded in this numpy image. 118 | 119 | :param phase_img: phase pattern input 120 | :return: An integer, number of frames 121 | """ 122 | if len(phase_img.shape) < 3 or phase_img.shape[2] == 1: 123 | num_frames = 1 124 | one_bit_imgs = torch.zeros((8, *phase_img.shape), device=phase_img.device) 125 | r = phase_img.clone().detach() 126 | else: 127 | r = phase_img[..., 0].clone().detach() 128 | g = phase_img[..., 1] 129 | b = phase_img[..., 2] 130 | 131 | img_size = r.shape 132 | one_bit_imgs = torch.zeros((8, *img_size)) 133 | 134 | if ((r-g)**2).mean() < 1e-3 and ((g-b)**2).mean() < 1e-3: 135 | # monochromatic 136 | num_frames = 1 137 | else: 138 | num_frames = 3 139 | 140 | # check this is unique or not 141 | cnt = 0 142 | for i in range(8): 143 | one_bit_imgs[i, ...] = r % 2 144 | r //= 2 # shift 1 bit 145 | if ((one_bit_imgs[i, ...] 
- one_bit_imgs[0, ...])**2).mean() < 1e-3: 146 | cnt += 1 147 | return num_frames * (8 // cnt) 148 | 149 | 150 | def decode_binary_phase(binary_img, return_index=False): 151 | """ Decodes 2x2 binary memory-cell patterns into per-pixel values. 152 | 153 | :param binary_img: a tensor of shape (N, H, W); return_index: if True, return LUT indices instead of phase 154 | :return: a tensor of shape (N, H/2, W/2) with decoded phase values (or indices in 0~15) 155 | """ 156 | top_left = binary_img[..., ::2, ::2] # M1 157 | top_right = binary_img[..., ::2, 1::2] # M3 158 | bottom_left = binary_img[..., 1::2, ::2] # M0 159 | bottom_right = binary_img[..., 1::2, 1::2] # M2 160 | 161 | indices = 8 * top_right + 4 * bottom_right + 2 * top_left + bottom_left 162 | img_shape = indices.shape 163 | indices = indices.type(torch.int32) 164 | indices = indices.reshape(indices.numel()) 165 | 166 | if return_index: 167 | # return index (0~15) per pixel 168 | memory_cell_lut = torch.tensor(ti.idx_order).to(binary_img.device) 169 | output = torch.index_select(memory_cell_lut, 0, indices).reshape(*img_shape) 170 | else: 171 | # return phase values 172 | decoded_phase = torch.index_select(ti.idx2phase.to(binary_img.device), 0, indices) 173 | output = decoded_phase.reshape(*img_shape) 174 | 175 | return output 176 | 177 | 178 | def merge_binary_phases(phases): 179 | """ Decodes several RGB-encoded phase images and re-encodes them into a single image. 180 | 181 | :param phases: a list of RGB-encoded phase images 182 | :return: one RGB-encoded phase image (frames are repeated to pad to 24 if fewer) 183 | """ 184 | rgb_phases = [] 185 | for phase in phases: 186 | decoded_phase = rgb_decoding(phase) 187 | # print(decoded_phase) 188 | rgb_phases.append(decoded_phase) 189 | rgb_phases = torch.cat(rgb_phases, 0) 190 | num_phases = rgb_phases.shape[0] 191 | if num_phases < 24: 192 | rgb_phases = torch.cat((rgb_phases, rgb_phases[:24-num_phases, ...]), 0) 193 | encoded_phase = 
rgb_encoding(rgb_phases.to(torch.float32)) 194 | 195 | return encoded_phase 196 | -------------------------------------------------------------------------------- /img/teaser.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/time-multiplexed-neural-holography/5cf6c275c459652abb3ddddd2e167f9584072aeb/img/teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu) 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | Technical Paper: 10 | Time-multiplexed Neural Holography: 11 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 12 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein. 
13 | SIGGRAPH 2022 14 | ----- 15 | 16 | $ python main.py --lr=0.01 --num_iters=10000 --num_frames=8 --quan_method=gumbel-softmax 17 | 18 | """ 19 | import os 20 | import json 21 | import torch 22 | import imageio 23 | import configargparse 24 | from torch.utils.tensorboard import SummaryWriter 25 | from collections import defaultdict 26 | 27 | import utils 28 | import params 29 | import algorithms as algs 30 | import quantization as q 31 | import numpy as np 32 | import image_loader as loaders 33 | from torch.utils.data import DataLoader 34 | import props.prop_model as prop_model 35 | import props.prop_physical as prop_physical 36 | from hw.phase_encodings import phase_encoding 37 | from torchvision.utils import save_image 38 | 39 | from pprint import pprint 40 | 41 | #import wx 42 | #wx.DisableAsserts() 43 | 44 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 45 | 46 | 47 | def main(): 48 | # Command line argument processing / Parameters 49 | torch.set_default_dtype(torch.float32) 50 | p = configargparse.ArgumentParser() 51 | p.add('-c', '--config_filepath', required=False, 52 | is_config_file=True, help='Path to config file.') 53 | params.add_parameters(p, 'eval') 54 | opt = params.set_configs(p.parse_args()) 55 | params.add_lf_params(opt) 56 | dev = torch.device('cuda') 57 | 58 | run_id = params.run_id(opt) 59 | # path to save out optimized phases 60 | out_path = os.path.join(opt.out_path, run_id) 61 | print(f' - out_path: {out_path}') 62 | 63 | # Tensorboard 64 | summaries_dir = os.path.join(out_path, 'summaries') 65 | utils.cond_mkdir(summaries_dir) 66 | writer = SummaryWriter(summaries_dir) 67 | 68 | # Write opt to experiment folder 69 | utils.write_opt(vars(p.parse_args()), out_path) 70 | 71 | # Propagations 72 | camera_prop = None 73 | if opt.citl: 74 | camera_prop = prop_physical.PhysicalProp(*(params.hw_params(opt)), shutter_speed=opt.shutter_speed).to(dev) 75 | camera_prop.calibrate_total_laser_energy() # important! 
76 | sim_prop = prop_model.model(opt) 77 | sim_prop.eval() 78 | 79 | # Look-up table of SLM 80 | if opt.use_lut: 81 | lut = q.load_lut(sim_prop, opt) 82 | else: 83 | lut = None 84 | quantization = q.quantization(opt, lut) 85 | 86 | # Algorithm 87 | algorithm = algs.load_alg(opt.method, mem_eff=opt.mem_eff) 88 | 89 | # Loader 90 | if ',' in opt.data_path: 91 | opt.data_path = opt.data_path.split(',') 92 | img_loader = loaders.TargetLoader(shuffle=opt.random_gen, 93 | vertical_flips=opt.random_gen, 94 | horizontal_flips=opt.random_gen, 95 | scale_vd_range=False, **opt) 96 | 97 | for i, target in enumerate(img_loader): 98 | target_amp, target_mask, target_idx = target 99 | target_amp = target_amp.to(dev).detach() 100 | 101 | if target_mask is not None: 102 | target_mask = target_mask.to(dev).detach() 103 | if len(target_amp.shape) < 4: 104 | target_amp = target_amp.unsqueeze(0) 105 | 106 | print(f' - run phase optimization for {target_idx}th image ...') 107 | 108 | if opt.random_gen: # random parameters for dataset generation 109 | img_files = os.listdir(out_path) 110 | img_files = [f for f in img_files if f.endswith('.png')] 111 | if len(img_files) > opt.num_data: # stop once enough data has been generated 112 | break 113 | print(f'Num images: {len(img_files)} (max: {opt.num_data})') 114 | opt.num_frames, opt.num_iters, opt.init_phase_range, \ 115 | target_range, opt.lr, opt.eval_plane_idx, \ 116 | opt.quan_method, opt.reg_lf_var = utils.random_gen(**opt) 117 | sim_prop = prop_model.model(opt) 118 | quantization = q.quantization(opt, lut) 119 | target_amp *= target_range 120 | if opt.reg_lf_var > 0.0 and isinstance(sim_prop, prop_model.CNNpropCNN): 121 | opt.num_frames = min(opt.num_frames, 4) 122 | 123 | out_path_idx = f'{opt.out_path}_{target_idx}' 124 | 125 | # initial slm phase 126 | init_phase = utils.init_phase(opt.init_phase_type, target_amp, dev, opt) 127 | 128 | # run algorithm 129 | results = algorithm(init_phase, target_amp, target_mask, target_idx, 130 |
forward_prop=sim_prop, camera_prop=camera_prop, 131 | writer=writer, quantization=quantization, 132 | out_path_idx=out_path_idx, **opt) 133 | 134 | # optimized slm phase 135 | final_phase = results['final_phase'] 136 | recon_amp = results['recon_amp'] 137 | target_amp = results['target_amp'] 138 | 139 | # encoding for SLM & save it out 140 | if opt.random_gen: 141 | # decompose it into several 1-bit phases 142 | for k, final_phase_1bit in enumerate(final_phase): 143 | phase_out = phase_encoding(final_phase_1bit.unsqueeze(0), opt.slm_type) 144 | phase_out_path = os.path.join(out_path, f'{target_idx}_{opt.num_iters}{k}.png') 145 | imageio.imwrite(phase_out_path, phase_out) 146 | else: 147 | phase_out = phase_encoding(final_phase, opt.slm_type) 148 | recon_amp, target_amp = recon_amp.squeeze().detach().cpu().numpy(), target_amp.squeeze().detach().cpu().numpy() 149 | 150 | # save the final phase 151 | if phase_out is not None: 152 | phase_out_path = os.path.join(out_path, f'{target_idx}_phase.png') 153 | imageio.imwrite(phase_out_path, phase_out) 154 | 155 | if opt.save_images: 156 | recon_out_path = os.path.join(out_path, f'{target_idx}_recon.png') 157 | target_out_path = os.path.join(out_path, f'{target_idx}_target.png') 158 | 159 | if opt.channel is None: 160 | recon_amp = recon_amp.transpose(1, 2, 0) 161 | target_amp = target_amp.transpose(1, 2, 0) 162 | 163 | recon_out = utils.srgb_lin2gamma(np.clip(recon_amp**2, 0, 1)) # amplitude -> linear intensity, then sRGB gamma 164 | target_out = utils.srgb_lin2gamma(np.clip(target_amp**2, 0, 1)) # amplitude -> linear intensity, then sRGB gamma 165 | 166 | imageio.imwrite(recon_out_path, (recon_out * 255).astype(np.uint8)) 167 | imageio.imwrite(target_out_path, (target_out * 255).astype(np.uint8)) 168 | 169 | if camera_prop is not None: 170 | camera_prop.disconnect() 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /params.py:
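`main.py` writes reconstructions as `utils.srgb_lin2gamma(np.clip(amp**2, 0, 1))`: squaring the amplitude gives linear intensity, which is then gamma-encoded for an 8-bit PNG. `utils.py` is not part of this excerpt, so the function below is only a sketch that assumes `srgb_lin2gamma` follows the standard sRGB transfer curve (IEC 61966-2-1). The name `srgb_lin2gamma_sketch` is hypothetical.

```python
import numpy as np


def srgb_lin2gamma_sketch(linear):
    """Standard sRGB encoding of linear values in [0, 1] (assumed, not taken from utils.py)."""
    linear = np.asarray(linear, dtype=np.float64)
    threshold = 0.0031308
    # Linear segment near black, power-law segment elsewhere. np.maximum keeps
    # np.power away from tiny inputs in the branch that np.where discards.
    return np.where(linear <= threshold,
                    12.92 * linear,
                    1.055 * np.power(np.maximum(linear, threshold), 1 / 2.4) - 0.055)


amp = np.array([0.0, 0.5, 1.0])        # reconstructed amplitudes
intensity = np.clip(amp ** 2, 0, 1)    # amplitude -> linear intensity
encoded = srgb_lin2gamma_sketch(intensity)
img8 = np.round(encoded * 255).astype(np.uint8)
print(img8.tolist())  # [0, 137, 255]
```

Rounding before the `uint8` cast, rather than truncating as the `astype` calls in `main.py` do, keeps a value that lands a hair below 1.0 after floating-point arithmetic (as `1.055 - 0.055` does in float64) from being written as 254.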
-------------------------------------------------------------------------------- 1 | """ 2 | Default parameter settings for SLMs as well as laser/sensors 3 | 4 | """ 5 | import sys 6 | import utils 7 | import datetime 8 | import torch.nn as nn 9 | from hw.discrete_slm import DiscreteSLM 10 | if sys.platform == 'win32': 11 | import serial 12 | 13 | cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9 14 | 15 | 16 | def str2bool(v): 17 | """ Simple query parser for configArgParse (which doesn't support native bool from cmd) 18 | Ref: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 19 | 20 | """ 21 | if isinstance(v, bool): 22 | return v 23 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 24 | return True 25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 26 | return False 27 | else: 28 | raise ValueError('Boolean value expected.') 29 | 30 | 31 | class PMap(dict): 32 | # use it for parameters 33 | __getattr__ = dict.get 34 | __setattr__ = dict.__setitem__ 35 | __delattr__ = dict.__delitem__ 36 | 37 | 38 | def clone_params(opt): 39 | """ 40 | opt: PMap object 41 | """ 42 | cloned = PMap() 43 | for k in opt.keys(): 44 | cloned[k] = opt[k] 45 | return cloned 46 | 47 | def add_parameters(p, mode='train'): 48 | p.add_argument('--channel', type=int, default=None, help='Red:0, green:1, blue:2') 49 | p.add_argument('--method', type=str, default='SGD', help='Type of algorithm, GS/SGD/DPAC/HOLONET/UNET') 50 | p.add_argument('--slm_type', type=str, default='holoeye', help='holoeye(leto) or ti') 51 | p.add_argument('--sensor_type', type=str, default='4k', help='4k or 2k') 52 | p.add_argument('--laser_type', type=str, default='new', help='laser, new_laser, sLED, ...') 53 | p.add_argument('--setup_type', type=str, default='siggraph2022', help='siggraph2022, ...') 54 | p.add_argument('--prop_model', type=str, default='ASM', help='Type of propagation model, ASM or model') 55 | p.add_argument('--out_path', type=str, default='./results', 56 | help='Directory 
for output') 57 | p.add_argument('--citl', type=str2bool, default=False, 58 | help='If True, run camera-in-the-loop') 59 | p.add_argument('--mod_i', type=int, default=None, 60 | help='If not None, say K, pick every K-th target image from the target loader') 61 | p.add_argument('--mod', type=int, default=None, 62 | help='If not None, say K, pick every K-th target image from the target loader') 63 | p.add_argument('--data_path', type=str, default='data/2d', 64 | help='Directory for input') 65 | p.add_argument('--exp', type=str, default='', help='Name of experiment') 66 | p.add_argument('--lr', type=float, default=0.02, help='Learning rate') 67 | p.add_argument('--num_iters', type=int, default=5000, help='Number of iterations (GS, SGD)') 68 | p.add_argument('--prop_dist', type=float, default=None, help='propagation distance from SLM to midplane') 69 | p.add_argument('--num_frames', type=int, default=1, help='Number of time-multiplexed frames to average') # jointly optimized when --time_joint is True 70 | p.add_argument('--F_aperture', type=float, default=1.0, help='Fourier filter size (fraction of the maximum radial spatial frequency)') 71 | p.add_argument('--eyepiece', type=float, default=0.12, help='eyepiece focal length') 72 | p.add_argument('--full_roi', type=str2bool, default=False, 73 | help='If True, force ROI to SLM resolution') 74 | p.add_argument('--flipud', type=str2bool, default=False, 75 | help='flip slm vertically before propagation') 76 | p.add_argument('--target', type=str, default='2d', 77 | help='Type of target: ' 78 | '{2d, rgb} or ' 79 | '{2.5d, rgbd} or ' 80 | '{3d, fs, focal-stack, focal_stack} or ' 81 | '{4d, lf, light-field, light_field}') 82 | p.add_argument('--show_preview', type=str2bool, default=False, 83 | help='If true, show the preview for homography calibration') 84 | p.add_argument('--random_gen', type=str2bool, default=False, 85 | help='If true, randomize a few parameters for phase dataset generation') 86 | p.add_argument('--test_set_3d', type=str2bool, default=False, 87 | help='If true, load a set of 3D scenes for
phase inference') 88 | p.add_argument('--mem_eff', type=str2bool, default=False, 89 | help='If true, run a memory-efficient (but slower) version of the algorithms') 90 | p.add_argument("--roi_h", type=int, default=None) # height of ROI 91 | p.add_argument("--optimize_amp", type=str2bool, default=False) # optimize amplitude 92 | 93 | # Hardware 94 | p.add_argument("--slm_settle_time", type=float, default=1.0) 95 | 96 | # Regularization 97 | p.add_argument('--reg_loss_fn_type', type=str, default=None) 98 | p.add_argument('--reg_loss_w', type=float, default=0.0) 99 | p.add_argument('--recon_loss_w', type=float, default=1.0) 100 | p.add_argument('--adaptive_roi_scale', type=float, default=1.0) 101 | 102 | p.add_argument("--save_images", action="store_true") 103 | p.add_argument("--save_npy", action="store_true") 104 | p.add_argument("--serial_two_prop_off", action="store_true", help="Propagate directly by prop_dist instead of using prop_dists_from_wrp.") 105 | 106 | # Initialization schemes 107 | p.add_argument('--init_phase_type', type=str, default="random", choices=["random"]) 108 | 109 | 110 | # Quantization 111 | p.add_argument('--quan_method', type=str, default='None', 112 | help='Quantization method, None, nn, nn_sigmoid, gumbel-softmax, ...') 113 | p.add_argument('--c_s', type=float, default=300, 114 | help='Coefficient multiplied with the score values (sets their scale relative to the Gumbel noise)') 115 | p.add_argument('--uniform_nbits', type=int, default=None, 116 | help='If not None, use uniformly-distributed discrete SLM levels for quantization') 117 | p.add_argument('--tau_max', type=float, default=5.5, 118 | help='tau value used for quantization at the beginning - increase for more constrained cases') 119 | p.add_argument('--tau_min', type=float, default=2.0, 120 | help='minimum tau value used for quantization') 121 | p.add_argument('--r', type=float, default=None, 122 | help='coefficient on the exponent (speed of decrease)') 123 | p.add_argument('--phase_offset', type=float, default=0.0,
124 | help='Constant offset added to the whole phase (not used in the paper)') 125 | p.add_argument('--time_joint', type=str2bool, default=True, 126 | help='If True, jointly optimize multiple frames with time-multiplexed forward model') 127 | p.add_argument('--init_phase_range', type=float, default=1.0, 128 | help='initial phase range') 129 | p.add_argument('--eval_plane_idx', type=int, default=None, 130 | help='depth plane to evaluate hologram reconstruction') 131 | p.add_argument('--use_lut', action="store_true", help="Use SLM discrete phase lookup table.") 132 | 133 | p.add_argument('--gpu_id', type=int, default=0, help="GPU id") 134 | 135 | # Dataset 136 | p.add_argument("--dataset_subset_size", type=int, default=None) 137 | p.add_argument("--img_paths", type=str, nargs="+", default=None) 138 | p.add("--shutter_speed", type=float, nargs='+', default=[100], help="Shutter speed of camera in ms.") 139 | p.add("--num_data", type=int, default=100, help="Number of samples to generate.") 140 | 141 | # Light field 142 | p.add_argument('--hop_len', type=int, default=0, 143 | help='STFT hop length; hopping by the full window size gives a holographic stereogram (HS)') 144 | p.add_argument('--n_fft', type=int, default=1, 145 | help='number of Fourier samples per patch') 146 | p.add_argument('--win_len', type=int, default=1, 147 | help='STFT window size') 148 | p.add_argument('--central_views', type=str2bool, default=False, 149 | help='If True, penalize only central views') 150 | p.add_argument('--reg_lf_var', type=float, default=0.0, 151 | help='lf regularization') 152 | 153 | if mode in ('train', 'eval'): 154 | p.add_argument('--num_epochs', type=int, default=350, help='') 155 | p.add_argument('--batch_size', type=int, default=1, help='') 156 | p.add_argument('--prop_model_path', type=str, default=None, help='Path to checkpoints') 157 | p.add_argument('--predefined_model', type=str, default=None, help='string for predefined model: ' 158 | 'nh, nh3d, nh4d') 159 | p.add_argument('--num_downs_slm', type=int,
default=5, help='') 160 | p.add_argument('--num_feats_slm_min', type=int, default=32, help='') 161 | p.add_argument('--num_feats_slm_max', type=int, default=128, help='') 162 | p.add_argument('--num_downs_target', type=int, default=5, help='') 163 | p.add_argument('--num_feats_target_min', type=int, default=32, help='') 164 | p.add_argument('--num_feats_target_max', type=int, default=128, help='') 165 | p.add_argument('--slm_coord', type=str, default='rect', help='coordinates to represent a complex-valued field: ' 166 | 'rect(real+imag) or polar(amp+phase)') 167 | p.add_argument('--target_coord', type=str, default='rect', help='coordinates to represent a complex-valued field: ' 168 | 'rect(real+imag) or polar(amp+phase)') 169 | p.add_argument('--param_lut', type=str2bool, default=False, help='') 170 | p.add_argument('--norm', type=str, default='instance', help='normalization layer') 171 | p.add_argument('--slm_latent_amp', type=str2bool, default=False, help='If True, ' 172 | 'parameterize amplitudes multiplied at SLM') 173 | p.add_argument('--slm_latent_phase', type=str2bool, default=False, help='If True, ' 174 | 'parameterize phase added at SLM') 175 | p.add_argument('--f_latent_amp', type=str2bool, default=False, help='If True, ' 176 | 'parameterize amplitudes multiplied at F') 177 | p.add_argument('--f_latent_phase', type=str2bool, default=False, help='If True, ' 178 | 'parameterize phase added at F') 179 | p.add_argument('--share_f_amp', type=str2bool, default=False, help='If True, use the same f_latent_amp params ' 180 | 'for propagating fields from WRP to ' 181 | 'target planes') 182 | p.add_argument('--share_f_phase', type=str2bool, default=False, help='If True, use the same f_latent_phase ' 183 | 'params for propagating fields from WRP to ' 184 | 'target planes') 185 | p.add_argument('--loss_func', type=str, default='l1', help='l1 or l2') 186 | p.add_argument('--energy_compensation', type=str2bool, default=True, help='adjust intensities ' 187 | 'with avg
intensity of training set') 188 | p.add_argument('--num_train_planes', type=int, default=6, help='number of planes fed to models') 189 | p.add_argument('--learn_f_amp_wrp', type=str2bool, default=False) 190 | p.add_argument('--learn_f_phase_wrp', type=str2bool, default=False) 191 | 192 | # cnn residuals 193 | p.add_argument("--slm_cnn_residual", type=str2bool, default=False) 194 | p.add_argument("--target_cnn_residual", type=str2bool, default=False) 195 | p.add_argument("--min_mse_scaling", type=str2bool, default=False) 196 | p.add_argument("--dataset_subset", type=int, default=None) 197 | 198 | return p 199 | 200 | 201 | def set_configs(opt_p): 202 | """ 203 | set or replace parameters with pre-defined parameters with string inputs 204 | """ 205 | opt = PMap() 206 | for k, v in vars(opt_p).items(): 207 | opt[k] = v 208 | 209 | # hardware setup 210 | optics_config(opt.setup_type, opt) # prop_dist, etc ... 211 | laser_config(opt.laser_type, opt) # Our Old FISBA Laser, New, SLED, LED 212 | slm_config(opt.slm_type, opt) # Holoeye or TI 213 | sensor_config(opt.sensor_type, opt) # old or new 4k 214 | 215 | # set predefined model parameters 216 | forward_model_config(opt.prop_model, opt) 217 | 218 | # wavelength, propagation distance (from SLM to midplane) 219 | if opt.channel is None: 220 | opt.chan_str = 'rgb' 221 | #opt.prop_dist = opt.prop_dists_rgb 222 | opt.prop_dist_green = opt.prop_dist 223 | opt.wavelength = opt.wavelengths 224 | else: 225 | opt.chan_str = ('red', 'green', 'blue')[opt.channel] 226 | if opt.prop_dist is None: 227 | opt.prop_dist = opt.prop_dists_rgb[opt.channel][opt.mid_idx] # prop dist from SLM plane to target plane 228 | if len(opt.prop_dists_rgb[opt.channel]) <= 1: 229 | opt.prop_dist_green = opt.prop_dists_rgb[opt.channel][0] 230 | else: 231 | opt.prop_dist_green = opt.prop_dists_rgb[opt.channel][1] 232 | else: 233 | opt.prop_dist_green = opt.prop_dist 234 | opt.wavelength = opt.wavelengths[opt.channel] # wavelength of each color 235 | 236 | 
# propagation distances from the wavefront recording plane 237 | if opt.channel is not None: 238 | opt.prop_dists_from_wrp = [p - opt.prop_dist for p in opt.prop_dists_rgb[opt.channel]] 239 | else: 240 | opt.prop_dists_from_wrp = [p - opt.prop_dist for p in opt.prop_dists_rgb[1]] 241 | opt.physical_depth_planes = [p - opt.prop_dist_green for p in opt.prop_dists_physical] 242 | opt.virtual_depth_planes = utils.prop_dist_to_diopter(opt.physical_depth_planes, 243 | opt.eyepiece, 244 | opt.physical_depth_planes[0]) 245 | if opt.serial_two_prop_off: 246 | opt.prop_dists_from_wrp = None 247 | opt.num_planes = 1 # use prop_dist 248 | assert opt.prop_dist is not None 249 | else: 250 | opt.num_planes = len(opt.prop_dists_from_wrp) 251 | opt.all_plane_idxs = range(opt.num_planes) 252 | 253 | # force ROI to that of SLM 254 | if opt.full_roi: 255 | opt.roi_res = opt.slm_res 256 | 257 | ################ 258 | # Model Training 259 | # compensate the brightness difference per plane (for model training) 260 | if opt.energy_compensation: 261 | if opt.channel is not None: 262 | opt.avg_energy_ratio = opt.avg_energy_ratio_rgb[opt.channel] 263 | else: 264 | opt.avg_energy_ratio = None 265 | else: 266 | opt.avg_energy_ratio = None 267 | 268 | # loss functions (for model training) 269 | opt.loss_train = None 270 | opt.loss_fn = None 271 | if opt.loss_func.lower() in ('l2', 'mse'): 272 | opt.loss_train = nn.functional.mse_loss 273 | opt.loss_fn = nn.functional.mse_loss 274 | elif opt.loss_func.lower() == 'l1': 275 | opt.loss_train = nn.functional.l1_loss 276 | opt.loss_fn = nn.functional.l1_loss 277 | 278 | # plane idxs (for model training) 279 | opt.plane_idxs = {} 280 | opt.plane_idxs['all'] = opt.all_plane_idxs 281 | opt.plane_idxs['train'] = opt.training_plane_idxs 282 | opt.plane_idxs['validation'] = opt.training_plane_idxs 283 | opt.plane_idxs['test'] = opt.training_plane_idxs 284 | opt.plane_idxs['heldout'] = opt.heldout_plane_idxs 285 | 286 | return opt 287 | 288 | 289 | def 
run_id(opt): 290 | id_str = f'{opt.exp}_{opt.method}_{opt.chan_str}_{opt.prop_model}_{opt.num_iters}_recon_{opt.recon_loss_w}_{opt.reg_loss_fn_type}_{opt.reg_loss_w}_{opt.init_phase_type}' 291 | if opt.citl: 292 | id_str = f'{id_str}_citl' 293 | if opt.mem_eff: 294 | id_str = f'{id_str}_memeff' 295 | id_str = f'{id_str}_tm_{opt.num_frames}' # time multiplexing 296 | if opt.citl: 297 | id_str = f'{id_str}_sht_{opt.shutter_speed[0]}' # shutter speed 298 | if opt.optimize_amp: 299 | id_str = f'{id_str}_opt_amp' 300 | return id_str 301 | 302 | def run_id_training(opt): 303 | id_str = f'{opt.exp}_{opt.chan_str}-' \ 304 | f'data_{opt.capture_subset}-' \ 305 | f'slm{opt.num_downs_slm}-{opt.num_feats_slm_min}-{opt.num_feats_slm_max}_' \ 306 | f'{str(opt.slm_latent_amp)[0]}{str(opt.slm_latent_phase)[0]}_' \ 307 | f'tg{opt.num_downs_target}-{opt.num_feats_target_min}-{opt.num_feats_target_max}_' \ 308 | f'lut{str(opt.param_lut)[0]}_' \ 309 | f'lH{str(opt.f_latent_amp)[0]}{str(opt.f_latent_phase)[0]}_' \ 310 | f'sH{str(opt.share_f_amp)[0]}{str(opt.share_f_phase)[0]}_' \ 311 | f'eH{str(opt.learn_f_amp_wrp)[0]}{str(opt.learn_f_phase_wrp)[0]}_' \ 312 | f'{opt.slm_coord}{opt.target_coord}_{opt.loss_func}_{opt.num_train_planes}pls_' \ 313 | f'bs{opt.batch_size}_' \ 314 | f'res-{opt.slm_cnn_residual}-{opt.target_cnn_residual}_' \ 315 | f'mse-s{opt.min_mse_scaling}' 316 | 317 | cur_time = datetime.datetime.now().strftime("%d-%H%M") 318 | id_str = f'{cur_time}_{id_str}' 319 | 320 | return id_str 321 | 322 | 323 | def hw_params(opt): 324 | params_slm = PMap() 325 | params_slm.settle_time = max(opt.shutter_speed) * 2.5 / 1000 # shutter speed is in ms 326 | params_slm.monitor_num = 1 # change here 327 | params_slm.slm_type = opt.slm_type 328 | 329 | params_camera = PMap() 330 | #params_camera.img_size_native = (3000, 4096) # 4k sensor native 331 | params_camera.img_size_native = (1700, 2736) # Used for SIGGRAPH 2022 332 | params_camera.ser = None #serial.Serial('COM5', 9600, 
timeout=0.5) 333 | 334 | params_calib = PMap() 335 | params_calib.show_preview = opt.show_preview 336 | params_calib.range_y = slice(0, params_camera.img_size_native[0]) 337 | params_calib.range_x = slice(0, params_camera.img_size_native[1]) 338 | params_calib.num_circles = (11, 18) 339 | 340 | params_calib.spacing_size = [int(roi / (num_circs - 1)) 341 | for roi, num_circs in zip(opt.roi_res, params_calib.num_circles)] 342 | params_calib.pad_pixels = [int(slm - roi) // 2 for slm, roi in zip(opt.slm_res, opt.roi_res)] 343 | params_calib.quadratic = True 344 | 345 | colors = ['red', 'green', 'blue'] 346 | params_calib.phase_path = f"data/calib/{colors[opt.channel]}/11x18_r19_ti_slm_dots_phase.png" # optimize homography pattern for every plane 347 | params_calib.blank_phase_path = "data/calib/2560x1600_blank.png" 348 | params_calib.img_size_native = params_camera.img_size_native 349 | 350 | return params_slm, params_camera, params_calib 351 | 352 | 353 | def slm_config(slm_type, opt): 354 | # setting for specific SLM. 355 | if slm_type.lower() == 'ti': 356 | opt.feature_size = (10.8 * um, 10.8 * um) # SLM pitch 357 | opt.slm_res = (800, 1280) # resolution of SLM 358 | opt.image_res = (800, 1280) 359 | #opt.image_res = (1600, 2560) 360 | if opt.channel is not None: 361 | opt.lut0 = DiscreteSLM.lut[:-1] * 636.4 * nm / opt.wavelengths[opt.channel] # LUT rescaled from 636.4 nm to this channel's wavelength 362 | else: 363 | opt.lut0 = DiscreteSLM.lut[:-1] 364 | opt.flipud = True 365 | elif slm_type.lower() in ('leto', 'holoeye'): 366 | opt.feature_size = (6.4 * um, 6.4 * um) # SLM pitch 367 | opt.slm_res = (1080, 1920) # resolution of SLM 368 | opt.image_res = opt.slm_res 369 | opt.lut0 = None 370 | if opt.projector: 371 | opt.flipud = not opt.flipud 372 | 373 | def laser_config(laser_type, opt): 374 | # setting for specific laser.
375 | if 'new' in laser_type.lower(): 376 | opt.wavelengths = [636.17 * nm, 518.48 * nm, 442.03 * nm] # wavelength of each color 377 | elif "readybeam" in laser_type.lower(): 378 | # using this for etech 379 | opt.wavelengths = (638.35 * nm, 521.16 * nm, 443.50 * nm) 380 | else: 381 | opt.wavelengths = [636.4 * nm, 517.7 * nm, 440.8 * nm] 382 | 383 | 384 | def sensor_config(sensor_type, opt): 385 | return opt 386 | 387 | 388 | def optics_config(setup_type, opt): 389 | if setup_type == 'siggraph2022': 390 | opt.laser_type = 'old' 391 | opt.slm_type = 'ti' 392 | opt.avg_energy_ratio_rgb = [[1.0000, 1.0595, 1.1067, 1.1527, 1.1943, 1.2504, 1.3122], 393 | [1.0000, 1.0581, 1.1051, 1.1490, 1.1994, 1.2505, 1.3172], 394 | [1.0000, 1.0560, 1.1035, 1.1487, 1.2008, 1.2541, 1.3183]] # averaged over training set 395 | opt.prop_dists_rgb = [[7.76*cm, 7.96*cm, 8.13*cm, 8.31*cm, 8.48*cm, 8.72*cm, 9.04*cm], 396 | [7.77*cm, 7.97*cm, 8.13*cm, 8.31*cm, 8.48*cm, 8.72*cm, 9.04*cm], 397 | [7.76*cm, 7.96*cm, 8.13*cm, 8.31*cm, 8.48*cm, 8.72*cm, 9.04*cm]] 398 | opt.prop_dists_physical = opt.prop_dists_rgb[1] 399 | opt.roi_res = (700, 1190) # region of interest (penalized during SGD) 400 | 401 | if opt.method.lower() not in ['olas', 'dpac']: 402 | opt.F_aperture = (0.7, 0.78, 0.9)[opt.channel] 403 | else: 404 | opt.F_aperture = 0.49 405 | 406 | # indices of training planes (idx 4 is the held-out plane) 407 | if opt.num_train_planes == 1: 408 | opt.training_plane_idxs = [3] 409 | elif opt.num_train_planes == 3: 410 | opt.training_plane_idxs = [0, 3, 6] 411 | elif opt.num_train_planes == 5: 412 | opt.training_plane_idxs = [0, 2, 3, 5, 6] 413 | elif opt.num_train_planes == 6: 414 | opt.training_plane_idxs = [0, 1, 2, 3, 5, 6] 415 | else: 416 | opt.training_plane_idxs = None 417 | opt.heldout_plane_idxs = [4] 418 | opt.mid_idx = 3 # index of the intermediate (middle) plane 419 | 420 | 421 | def forward_model_config(model_type, opt): 422 | # setting for specific model that is predefined.
423 | if model_type is not None: 424 | print(f' - changing model parameters for {model_type}') 425 | if model_type.lower() == 'nh3d': 426 | opt.num_downs_slm = 8 427 | opt.num_feats_slm_min = 32 428 | opt.num_feats_slm_max = 512 429 | opt.num_downs_target = 5 430 | opt.num_feats_target_min = 8 431 | opt.num_feats_target_max = 128 432 | opt.param_lut = False 433 | 434 | elif model_type.lower() == 'hil': 435 | opt.num_downs_slm = 0 436 | opt.num_feats_slm_min = 0 437 | opt.num_feats_slm_max = 0 438 | opt.num_downs_target = 8 439 | opt.num_feats_target_min = 32 440 | opt.num_feats_target_max = 512 441 | opt.target_coord = 'amp' 442 | opt.param_lut = False 443 | 444 | elif model_type.lower() == 'cnnprop': 445 | opt.num_downs_slm = 8 446 | opt.num_feats_slm_min = 32 447 | opt.num_feats_slm_max = 512 448 | opt.num_downs_target = 0 449 | opt.num_feats_target_min = 0 450 | opt.num_feats_target_max = 0 451 | opt.param_lut = False 452 | 453 | elif model_type.lower() == 'propcnn': 454 | opt.num_downs_slm = 0 455 | opt.num_feats_slm_min = 0 456 | opt.num_feats_slm_max = 0 457 | opt.num_downs_target = 8 458 | opt.num_feats_target_min = 32 459 | opt.num_feats_target_max = 512 460 | opt.param_lut = False 461 | 462 | elif model_type.lower() == 'nh4d': 463 | opt.num_downs_slm = 5 464 | opt.num_feats_slm_min = 32 465 | opt.num_feats_slm_max = 128 466 | opt.num_downs_target = 5 467 | opt.num_feats_target_min = 32 468 | opt.num_feats_target_max = 128 469 | opt.num_target_latent = 0 470 | opt.norm = 'instance' 471 | opt.slm_coord = 'both' 472 | opt.target_coord = 'both_1ch_output' 473 | opt.param_lut = True 474 | opt.slm_latent_amp = True 475 | opt.slm_latent_phase = True 476 | opt.f_latent_amp = True 477 | opt.f_latent_phase = True 478 | opt.share_f_amp = True 479 | 480 | 481 | def add_lf_params(opt, dataset='olas'): 482 | """ Add Light-Field parameters """ 483 | if opt.target.lower() == 'rgbd': 484 | if opt.reg_lf_var > 0.0: 485 | opt.ang_res = (7, 7) 486 |
opt.load_only_central_view = True 487 | opt.hop_len = (1, 1) 488 | opt.n_fft = opt.ang_res 489 | opt.win_len = opt.ang_res 490 | if opt.central_views: 491 | opt.selected_views = (slice(1, 6, 1), slice(1, 6, 1)) 492 | else: 493 | opt.selected_views = None 494 | return opt 495 | else: 496 | return opt 497 | else: 498 | if dataset == 'olas': 499 | opt.ang_res = (9, 9) 500 | opt.load_only_central_view = opt.target.lower() == 'rgbd' 501 | opt.hop_len = (1, 1) 502 | opt.n_fft = opt.ang_res 503 | opt.win_len = opt.ang_res 504 | 505 | if dataset == 'parallax': 506 | opt.ang_res = (7, 7) 507 | opt.load_only_central_view = opt.target.lower() == 'rgbd' 508 | opt.hop_len = (1, 1) 509 | opt.n_fft = opt.ang_res 510 | opt.win_len = opt.ang_res 511 | 512 | if 'lf' in opt.target.lower(): 513 | opt.prop_dist_from_wrp = [0.] 514 | opt.c_s = 700 515 | if opt.central_views: 516 | opt.selected_views = (slice(1, 6, 1), slice(1, 6, 1)) 517 | else: 518 | opt.selected_views = None 519 | 520 | return opt -------------------------------------------------------------------------------- /props/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-imaging/time-multiplexed-neural-holography/5cf6c275c459652abb3ddddd2e167f9584072aeb/props/__init__.py -------------------------------------------------------------------------------- /props/prop_ideal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ideal propagation 3 | 4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu) 5 | 6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford). 8 | # The material is provided as-is, with no warranties whatsoever. 
9 | # If you publish any code, data, or scientific work based on this, please cite our work. 10 | 11 | Technical Paper: 12 | Time-multiplexed Neural Holography: 13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators 14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein. 15 | SIGGRAPH 2022 16 | """ 17 | 18 | import torch 19 | import torch.nn as nn 20 | import utils 21 | import torch.fft as tfft 22 | import math 23 | from copy import deepcopy 24 | 25 | class Propagation(nn.Module): 26 | """ 27 | The ideal, convolution-based propagation implementation 28 | 29 | Class initialization parameters 30 | ------------------------------- 31 | :param prop_dist: propagation distance(s) 32 | :param wavelength: wavelength 33 | :param feature_size: pixel pitch 34 | :param prop_type: type of propagation (ASM or fresnel), by default the angular spectrum method 35 | :param F_aperture: filter size at fourier plane, by default 1.0 36 | :param dim: for propagation to multiple planes, dimension to stack the output, by default 1 (second dimension) 37 | :param linear_conv: If true, pad zeros to ensure the linear convolution, by default True 38 | :param learned_amp: Learned amplitude at Fourier plane, by default None 39 | :param learned_phase: Learned phase at Fourier plane, by default None 40 | """ 41 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0, 42 | dim=1, linear_conv=True, learned_amp=None, learned_phase=None, learned_field=None): 43 | super(Propagation, self).__init__() 44 | 45 | self.H = None # kernel at Fourier plane 46 | self.prop_type = prop_type 47 | if not isinstance(prop_dist, list): 48 | prop_dist = [prop_dist] 49 | self.prop_dist = prop_dist 50 | self.feature_size = feature_size 51 | if not isinstance(wavelength, list): 52 | wavelength = [wavelength] 53 | self.wvl = wavelength 54 | self.linear_conv = linear_conv # ensure linear convolution by padding 55 
| self.bl_asm = min(prop_dist) > 0.3 56 | self.F_aperture = F_aperture 57 | self.dim = dim # The dimension to stack the kernels as well as the resulting fields (if multi-channel) 58 | 59 | self.preload_params = False 60 | self.preloaded_H_amp = False # preload H_mask once trained 61 | self.preloaded_H_phase = False # preload H_phase once trained 62 | 63 | self.fourier_amp = learned_amp 64 | self.fourier_phase = learned_phase 65 | self.fourier_field = learned_field 66 | 67 | #self.bl_asm = True 68 | if self.bl_asm: 69 | print("Using band-limited ASM") 70 | else: 71 | print("Using naive ASM") 72 | 73 | def forward(self, u_in): 74 | if u_in.dtype == torch.float32: # check if this is phase or already a wavefield 75 | u_in = torch.exp(1j * u_in) # convert phase to wavefront 76 | 77 | if self.H is None: 78 | Hs = [] 79 | if len(self.wvl) > 1: # If multi-channel, rearrange kernels 80 | for i, wv in enumerate(self.wvl): 81 | H_wvl = [] 82 | for prop_dist in self.prop_dist: 83 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..') 84 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size, 85 | self.prop_type, self.linear_conv, 86 | F_aperture=self.F_aperture, bl_asm=self.bl_asm) 87 | H_wvl.append(h) 88 | H_wvl = torch.cat(H_wvl, dim=1) 89 | Hs.append(H_wvl) 90 | self.H = torch.cat(Hs, dim=1) 91 | else: 92 | for wv in self.wvl: 93 | for prop_dist in self.prop_dist: 94 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..') 95 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size, 96 | self.prop_type, self.linear_conv, 97 | F_aperture=self.F_aperture, bl_asm=self.bl_asm) 98 | Hs.append(h) 99 | self.H = torch.cat(Hs, dim=1) 100 | 101 | if self.preload_params: 102 | self.premultiply() 103 | 104 | if self.fourier_field is not None: 105 | # for neural wavefront model 106 | fourier_field, fourier_dc_field = self.fourier_field() # neural wavefield 107 | H = self.H * fourier_field 108 | else: 
109 |             if self.fourier_amp is not None and not self.preloaded_H_amp:
110 |                 H = self.fourier_amp.clamp(min=0.) * self.H
111 |             else:
112 |                 H = self.H
113 |
114 |             if self.fourier_phase is not None and not self.preloaded_H_phase:
115 |                 H = H * torch.exp(1j * self.fourier_phase)
116 |
117 |         return self.prop(u_in, H, self.linear_conv)
118 |
119 |     def compute_H(self, input_field, prop_dist, wvl, feature_size, prop_type, lin_conv=True,
120 |                   return_exp=False, F_aperture=1.0, bl_asm=False, return_filter=False):
121 |         dev = input_field.device
122 |         res_mul = 2 if lin_conv else 1
123 |         num_y, num_x = res_mul*input_field.shape[-2], res_mul*input_field.shape[-1]  # number of pixels
124 |         dy, dx = feature_size  # sampling interval size, pixel pitch of the SLM
125 |         # does this mean the holographic display can display only one pixel (focus light to one pixel (smallest feature size))?
126 |
127 |         # frequency coordinates sampling
128 |         fy = torch.linspace(-1 / (2 * dy), 1 / (2 * dy), num_y)
129 |         fx = torch.linspace(-1 / (2 * dx), 1 / (2 * dx), num_x)
130 |
131 |         # momentum/reciprocal space
132 |         # FY, FX = torch.meshgrid(fy, fx)
133 |         FX, FY = torch.meshgrid(fx, fy)
134 |         FX = torch.transpose(FX, 0, 1)
135 |         FY = torch.transpose(FY, 0, 1)
136 |
137 |         if prop_type.lower() == 'asm':
138 |             G = 2 * math.pi * (1 / wvl**2 - (FX ** 2 + FY ** 2)).sqrt()
139 |         elif prop_type.lower() == 'fresnel':
140 |             G = math.pi * wvl * (FX ** 2 + FY ** 2)
141 |
142 |         H_exp = G.reshape((1, 1, *G.shape)).to(dev)
143 |
144 |         if return_exp:
145 |             return H_exp
146 |
147 |         if bl_asm:
148 |             fy_max = 1 / math.sqrt((2 * prop_dist * (1 / (dy * float(num_y))))**2 + 1) / wvl
149 |             fx_max = 1 / math.sqrt((2 * prop_dist * (1 / (dx * float(num_x))))**2 + 1) / wvl
150 |
151 |             H_filter = ((torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max())
152 |                         & (torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).type(torch.FloatTensor)
153 |         else:
154 |             H_filter = (torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max()).type(torch.FloatTensor)
155 |
156 |         if prop_dist == 0.:
157 |             H = torch.ones_like(H_exp)
158 |         else:
159 |             H = H_filter.to(input_field.device) * torch.exp(1j * H_exp * prop_dist)
160 |             self.H_without_filter = torch.exp(1j * H_exp * prop_dist)
161 |             self.H_filter = H_filter
162 |
163 |         if return_filter:
164 |             return H_filter
165 |         else:
166 |             return H
167 |
168 |     def prop(self, u_in, H, linear_conv=True, padtype='zero'):
169 |         if linear_conv:
170 |             # preprocess with padding for linear conv.
171 |             input_resolution = u_in.size()[-2:]
172 |             conv_size = [i * 2 for i in input_resolution]
173 |             if padtype == 'zero':
174 |                 padval = 0
175 |             elif padtype == 'median':
176 |                 padval = torch.median(torch.pow((u_in ** 2).sum(-1), 0.5))
177 |             u_in = utils.pad_image(u_in, conv_size, padval=padval, stacked_complex=False)
178 |
179 |         U1 = tfft.fftshift(tfft.fftn(u_in, dim=(-2, -1), norm='ortho'), (-2, -1))  # fourier transform
180 |         #U2_without_filter = U1 * self.H_without_filter
181 |         U2 = U1 * H
182 |         u_out = tfft.ifftn(tfft.ifftshift(U2, (-2, -1)), dim=(-2, -1), norm='ortho')
183 |
184 |         if linear_conv:  # also record uncropped image
185 |             self.uncropped_u_out = u_out.clone()
186 |             u_out = utils.crop_image(u_out, input_resolution, pytorch=True, stacked_complex=False)
187 |
188 |         """
189 |         U2_amp = torch.abs(U2)
190 |         U2_without_filter_amp = torch.abs(U2_without_filter)
191 |         filtered_intensity_sum = (U2_without_filter_amp**2 - U2_amp**2).mean()
192 |         #print('total amp:', (U2_without_filter_amp**2).mean())
193 |         #print('filtered_intensity_sum: ', filtered_intensity_sum)
194 |
195 |         # normalize to 0, 1
196 |         U2_amp = (U2_amp - U2_amp.min()) / (U2_amp.max() - U2_amp.min())
197 |         U2_without_filter_amp = (U2_without_filter_amp - U2_without_filter_amp.min()) / (U2_without_filter_amp.max() - U2_without_filter_amp.min())
198 |         U2_amp = U2_amp.mean(axis=0).squeeze()
199 |         U2_without_filter_amp = U2_without_filter_amp.mean(axis=0).squeeze()
200 |         U2_amp = torch.log(U2_amp + 1e-10)
201 |         U2_without_filter_amp = torch.log(U2_without_filter_amp + 1e-10)
202 |
203 |         H_amp = torch.abs(self.H_filter)
204 |         H_amp = (H_amp - H_amp.min()) / (H_amp.max() - H_amp.min())
205 |         H_amp = H_amp.squeeze()
206 |         """
207 |
208 |
209 |         return u_out
210 |
211 |     def __len__(self):
212 |         return len(self.prop_dist)
213 |
214 |     def preload_H(self):
215 |         self.preload_params = True
216 |
217 |     def premultiply(self):
218 |         self.preload_params = False
219 |
220 |         if self.fourier_amp is not None and not self.preloaded_H_amp:
221 |             self.H = self.fourier_amp.clamp(min=0.) * self.H
222 |         if self.fourier_phase is not None and not self.preloaded_H_phase:
223 |             self.H = self.H * torch.exp(1j * self.fourier_phase)
224 |
225 |         self.H.detach_()
226 |         self.preloaded_H_amp = True
227 |         self.preloaded_H_phase = True
228 |
229 |     @property
230 |     def plane_idx(self):
231 |         return self._plane_idx
232 |
233 |     @plane_idx.setter
234 |     def plane_idx(self, idx):
235 |         if idx is None:
236 |             return
237 |
238 |         self._plane_idx = idx
239 |         if len(self.prop_dist) > 1:
240 |             self.prop_dist = [self.prop_dist[idx]]
241 |
242 |         if self.fourier_amp is not None and self.fourier_amp.shape[1] > 1:
243 |             self.fourier_amp = nn.Parameter(self.fourier_amp[:, idx:idx+1, ...], requires_grad=False)
244 |         if self.fourier_phase is not None and self.fourier_phase.shape[1] > 1:
245 |             self.fourier_phase = nn.Parameter(self.fourier_phase[:, idx:idx+1, ...], requires_grad=False)
246 |
247 |
248 |
249 | class SerialProp(nn.Module):
250 |     def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0,
251 |                  prop_dists_from_wrp=None, linear_conv=True, dim=1, opt=None):
252 |         super(SerialProp, self).__init__()
253 |         first_prop = Propagation(prop_dist, wavelength, feature_size,
254 |                                  prop_type=prop_type, linear_conv=linear_conv, F_aperture=F_aperture, dim=dim)
255 |         props = [first_prop]
256 |
257 |         if prop_dists_from_wrp is not None:
258 |             second_prop = Propagation(prop_dists_from_wrp, wavelength, feature_size,
259 |                                       prop_type=prop_type, linear_conv=linear_conv, F_aperture=1.0, dim=dim)
260 |             props += [second_prop]
261 |         self.props = nn.Sequential(*props)
262 |
263 |         # copy the opt parameters for initializing prop in other modules
264 |         self.opt = opt
265 |
266 |     def forward(self, u_in):
267 |
268 |         u_out = self.props(u_in)
269 |         self.uncropped_u_out = self.props[-1].uncropped_u_out  # dirty way to access final layer uncropped output
270 |
271 |         return u_out
272 |
273 |     def preload_H(self):
274 |         for prop in self.props:
275 |             prop.preload_H()
276 |
277 |     @property
278 |     def plane_idx(self):
279 |         return self._plane_idx
280 |
281 |     @plane_idx.setter
282 |     def plane_idx(self, idx):
283 |         if idx is None:
284 |             return
285 |
286 |         self._plane_idx = idx
287 |         for prop in self.props:
288 |             prop.plane_idx = idx
--------------------------------------------------------------------------------
/props/prop_physical.py:
--------------------------------------------------------------------------------
1 | """
2 | Propagation happening on the setup
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 |
11 | Technical Paper:
12 | Time-multiplexed Neural Holography:
13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
15 | SIGGRAPH 2022
16 | """
17 |
18 | import torch
19 | import torch.nn as nn
20 | import utils
21 | import time
22 | import cv2
23 | import imageio
24 |
25 | from hw.phase_encodings import phase_encoding
26 | import sys
27 | if sys.platform == 'win32':
28 |     import slmpy
29 | import hw.camera_capture_module as cam
30 | import hw.calibration_module as calibration
31 |
32 |
33 | class PhysicalProp(nn.Module):
34 |     """ A module for physical propagation.
35 |     The forward pass takes an SLM pattern as input, displays the pattern on the physical setup,
36 |     captures the diffraction image at the target plane,
37 |     and then returns the image warped with the homography calibrated at instantiation.
38 |
39 |     Class initialization parameters
40 |     -------------------------------
41 |     :param params_slm: a set of parameters for the SLM.
42 |     :param params_camera: a set of parameters for the camera sensor.
43 |     :param params_calib: a set of parameters for homography calibration.
44 |     :param q_fn: quantization function module
45 |
46 |     Usage
47 |     -----
48 |     Functions as a pytorch module:
49 |
50 |     >>> camera_prop = PhysicalProp(...)
51 |     >>> captured_amp = camera_prop(slm_phase)
52 |
53 |     slm_phase: phase at the SLM plane, with dimensions [batch, 1, height, width]
54 |     captured_amp: amplitude at the target plane, with dimensions [batch, 1, height, width]
55 |
56 |     """
57 |     def __init__(self, params_slm, params_camera, params_calib=None, q_fn=None, shutter_speed=100, hdr=False):
58 |         super(PhysicalProp, self).__init__()
59 |
60 |
61 |         self.shutter_speed = shutter_speed if isinstance(shutter_speed, (list, tuple)) else [shutter_speed]  # always a list
62 |         self.hdr = hdr
63 |         self.q_fn = q_fn
64 |         self.params_calib = params_calib
65 |
66 |         if self.hdr:
67 |             assert len(self.shutter_speed) > 1  # more than 1 shutter speed for HDR capture
68 |         else:
69 |             assert len(self.shutter_speed) == 1  # non-hdr mode supports only one shutter speed
70 |
71 |         # 1. Connect Camera
72 |         self.camera = cam.CameraCapture(params_camera)
73 |         self.camera.connect(0)  # specify the camera to use, 0 for main cam, 1 for the second cam
74 |         #self.camera.start_capture()
75 |         self.camera.start_capture()
76 |
77 |         # 2. Connect SLM
78 |         self.slm = slmpy.SLMdisplay(isImageLock=True, monitor=params_slm.monitor_num)
79 |         self.params_slm = params_slm
80 |
81 |         # 3. Calibrate hardware using homography
82 |         if params_calib is not None:
83 |             self.warper = calibration.Warper(params_calib)
84 |             self.calibrate(params_calib.phase_path, params_calib.show_preview)
85 |         else:
86 |             self.warper = None
87 |
88 |     def calibrate_total_laser_energy(self):
89 |         print("Calibrating total laser energy...")
90 |         phase_img = imageio.imread(self.params_calib.blank_phase_path)
91 |         self.slm.updateArray(phase_img)
92 |         time.sleep(5)
93 |         captured_plane_wave = self.forward(phase_img)
94 |         h, w = captured_plane_wave.shape[-2], captured_plane_wave.shape[-1]  # full SLM size
95 |         cropped_energy = utils.crop_image(captured_plane_wave**2, (500, 500), stacked_complex=False)
96 |         self.total_laser_energy = cropped_energy.sum() * (h * w) / (500 * 500)
97 |
98 |     def calibrate(self, phase_path, show_preview=False):
99 |         """
100 |         Display a calibration pattern on the SLM, capture it, and fit the homography used to warp captures.
101 |         :param phase_path: path to the phase pattern image displayed for homography calibration
102 |         :param show_preview: whether to show a preview of the detected calibration pattern
103 |         :return: None; raises ValueError if calibration fails
104 |         """
105 |         print(' -- Calibrating ...')
106 |         self.camera.set_shutter_speed(2000)  # for homography pattern. remember to reset it!
107 |         self.camera.set_gain(10)  # for homography pattern. remember to reset it!
108 |         phase_img = imageio.imread(phase_path)
109 |         #print(phase_img)
110 |         self.slm.updateArray(phase_img)
111 |         time.sleep(5)
112 |         captured_img = self.camera.grab_images_fast(5)  # capture 5 images for averaging
113 |         calib_success = self.warper.calibrate(captured_img, show_preview)
114 |         self.camera.set_gain(0)
115 |         if calib_success:
116 |             print(' -- Calibration succeeded!...')
117 |             if not self.hdr:
118 |                 print("Using a single shutter speed for non-HDR capture...")
119 |                 self.camera.set_shutter_speed(self.shutter_speed[0])  # reset for capture
120 |         else:
121 |             raise ValueError(' -- Calibration failed')
122 |
123 |     def forward(self, slm_phase, time_avg=1):
124 |         """
125 |         :param slm_phase: phase pattern to display on the SLM
126 |         :param time_avg: number of captures to average
127 |         :return: warped amplitude (not intensity) at the target plane
128 |         """
129 |         input_phase = slm_phase
130 |         if self.q_fn is not None:
131 |             dp_phase = self.q_fn(input_phase)
132 |         else:
133 |             dp_phase = input_phase
134 |
135 |         self.display_slm_phase(dp_phase)
136 |
137 |         raw_intensity_sum = 0
138 |         for t in range(time_avg):
139 |             raw_intensity = self.capture_linear_intensity(dp_phase)  # grayscale raw16 intensity image
140 |             raw_intensity_sum += raw_intensity
141 |         raw_intensity = raw_intensity_sum / time_avg
142 |
143 |         # amplitude is returned! not intensity!
144 |         warped_intensity = self.warper(raw_intensity) if self.warper is not None else raw_intensity  # apply homography if calibrated
145 |         return warped_intensity.sqrt()  # return amplitude
146 |
147 |     def capture_linear_intensity(self, slm_phase):
148 |         """
149 |         capture the holographic image of the pattern already displayed on the SLM (see display_slm_phase) with the sensor.
150 |
151 |         :param slm_phase: unused; kept for interface compatibility
152 |         :return: captured intensity tensor of shape [1, 1, H, W]
153 |         """
154 |         raw_uint16_data = self.capture_uint16()  # sync with SLM & retrieve buffer
155 |         captured_intensity = self.process_raw_data(raw_uint16_data)  # demosaick & sum up
156 |         return captured_intensity
157 |
158 |     def forward_hdr(self, slm_phase):
159 |         """
160 |         Capture an exposure stack of the displayed pattern and merge it into an HDR image.
161 |         :param slm_phase: phase pattern to display on the SLM
162 |         :return: (merged HDR amplitude, list of per-exposure amplitudes), both warped
163 |         """
164 |         input_phase = slm_phase
165 |         if self.q_fn is not None:
166 |             dp_phase = self.q_fn(input_phase)
167 |         else:
168 |             dp_phase = input_phase
169 |         self.display_slm_phase(dp_phase)  # display once; every exposure below captures this pattern
170 |         raw_intensity_hdr, raw_intensity_stack = self.capture_linear_intensity_hdr(dp_phase)  # grayscale raw16 intensity image
171 |
172 |         # amplitude is returned! not intensity!
173 |         warped_intensity_hdr = self.warper(raw_intensity_hdr)  # apply homography
174 |         warped_intensity_stack = [self.warper(intensity) for intensity in raw_intensity_stack]
175 |         warped_amplitude_hdr = warped_intensity_hdr.sqrt()
176 |         warped_amplitude_stack = [intensity.sqrt() for intensity in warped_intensity_stack]
177 |         return warped_amplitude_hdr, warped_amplitude_stack
178 |
179 |     def capture_linear_intensity_hdr(self, slm_phase):
180 |         raw_uint16_data_list = []
181 |         for s in self.shutter_speed:
182 |             self.camera.set_shutter_speed(s)  # one exposure
183 |             raw_uint16_data = self.capture_uint16()  # capture_uint16 takes no arguments
184 |             raw_uint16_data_list.append(raw_uint16_data)
185 |         #captured_intensity_hdr = self.process_raw_data(raw_uint16_data_list[0])  # convert to hdr and demosaick?
186 |         captured_intensity_exposure_stack = [torch.clip(self.process_raw_data(raw_data), 0, 1) for raw_data in raw_uint16_data_list]  # overexposed images, clip to range
187 |         captured_intensity_hdr = self.merge_hdr(captured_intensity_exposure_stack)
188 |         return captured_intensity_hdr, captured_intensity_exposure_stack
189 |
190 |     def merge_hdr(self, exposure_stack):
191 |         weight_sum = 0
192 |         weighted_img_sum = 0
193 |         for s, img in zip(self.shutter_speed, exposure_stack):
194 |             weight = torch.exp(-4 * (img - 0.5)**2 / 0.5**2)
195 |             weighted_img = weight * (torch.log(img.clamp(min=1e-10)) - torch.log(torch.tensor(s)))  # clamp avoids log(0) = -inf
196 |             weight_sum = weight_sum + weight
197 |             weighted_img_sum = weighted_img_sum + weighted_img
198 |         merged_img = torch.exp(weighted_img_sum / (weight_sum + 1e-10))  # epsilon avoids division by zero
199 |         return merged_img
200 |
201 |     def display_slm_phase(self, slm_phase):
202 |         if slm_phase is not None:  # just for simple camera capture
203 |             if torch.is_tensor(slm_phase):  # raw phase is always tensor.
204 |                 slm_phase_encoded = phase_encoding(slm_phase, self.params_slm.slm_type)
205 |             else:  # uint8 encoded phase (should be np.array)
206 |                 slm_phase_encoded = slm_phase
207 |             self.slm.updateArray(slm_phase_encoded)
208 |
209 |     def capture_uint16(self):
210 |         """
211 |         Sends a trigger signal to the board, which waits for the next clock from the SLM; the pattern must already be displayed.
212 |         Right after hearing back from the SLM, the board signals the PC, which then retrieves the camera buffer.
213 |         Without a serial board, it instead waits params_slm.settle_time seconds.
214 |
215 |         :return: raw uint16 data retrieved from the camera buffer
216 |         """
217 |
218 |         if self.camera.params.ser is not None:
219 |             self.camera.params.ser.write('D'.encode())
220 |
221 |             # TODO: make the following in a separate function.
222 |             # Wait until receiving signal from arduino
223 |             incoming_byte = self.camera.params.ser.inWaiting()
224 |             t0 = time.perf_counter()
225 |             while True:
226 |                 received = self.camera.params.ser.read(incoming_byte).decode('UTF-8')
227 |                 if received != 'C':
228 |                     incoming_byte = self.camera.params.ser.inWaiting()
229 |                     if time.perf_counter() - t0 > 2.0:
230 |                         break
231 |                 else:
232 |                     break
233 |         else:
234 |             #print("settling...")
235 |             time.sleep(self.params_slm.settle_time)
236 |         raw_data_from_buffer = self.camera.retrieve_buffer()
237 |
238 |         return raw_data_from_buffer
239 |
240 |     def process_raw_data(self, raw_data):
241 |         """
242 |         Gets raw data from the camera buffer, subtracts the black level, demosaicks it, and sums the color channels.
243 |
244 |         :param raw_data: raw uint16 Bayer data from the camera buffer
245 |         :return: intensity tensor of shape [1, 1, H, W]
246 |         """
247 |         raw_data = raw_data.clip(64) - 64  # subtract black level; clip first to avoid uint16 wrap-around
248 |         color_cv_image = cv2.cvtColor(raw_data, self.camera.demosaick_rule)  # it gives float64 from uint16 -- double check it
249 |         captured_intensity = utils.im2float(color_cv_image)  # float64 to float32
250 |
251 |         # Numpy to tensor
252 |         captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32,
253 |                                           device=self.dev).permute(2, 0, 1).unsqueeze(0)
254 |         captured_intensity = torch.sum(captured_intensity, dim=1, keepdim=True)
255 |         return captured_intensity
256 |
257 |     def disconnect(self):
258 |         #self.camera.stop_capture()
259 |         self.camera.stop_capture()
260 |         self.camera.disconnect()
261 |         self.slm.close()
262 |
263 |     def to(self, *args, **kwargs):
264 |         slf = super().to(*args, **kwargs)
265 |         if slf.warper is not None:
266 |             slf.warper = slf.warper.to(*args, **kwargs)
267 |         try:
268 |             slf.dev = next(slf.parameters()).device
269 |         except StopIteration:  # no parameters
270 |             device_arg = torch._C._nn._parse_to(*args, **kwargs)[0]
271 |             if device_arg is not None:
272 |                 slf.dev = device_arg
273 |         return slf
--------------------------------------------------------------------------------
/props/prop_submodules.py:
--------------------------------------------------------------------------------
1 | """
2 | Modules for propagation
3 |
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | import utils
10 | from unet import Conv2dSame
11 |
12 |
13 | class Field2Input(nn.Module):
14 |     """Gets complex-valued field and turns it into multi-channel images"""
15 |
16 |     def __init__(self, input_res=(800, 1280), coord='rect', latent_amp=None, latent_phase=None, shared_cnn=False):
17 |         super(Field2Input, self).__init__()
18 |         self.input_res = input_res
19 |         self.coord = coord.lower()
20 |         self.latent_amp = latent_amp
21 |         self.latent_phase = latent_phase
22 |         self.shared_cnn = shared_cnn
23 |
24 |     def forward(self, input_field):
25 |         # If input field is slm phase
26 |         if input_field.dtype == torch.float32:
27 |             input_field = torch.exp(1j * input_field)
28 |
29 |         # 1) Learned phase offset
30 |         if self.latent_phase is not None:
31 |             input_field = input_field * torch.exp(1j * self.latent_phase)
32 |
33 |         # 2) Learned amplitude
34 |         if self.latent_amp is not None:
35 |             input_field = self.latent_amp * input_field
36 |
37 |         input_field = utils.pad_image(input_field, self.input_res, pytorch=True, stacked_complex=False)
38 |         input_field = utils.crop_image(input_field, self.input_res, pytorch=True, stacked_complex=False)
39 |
40 |         # To use shared CNN, put everything into batch dimension.
41 |         if self.shared_cnn:
42 |             num_mb, num_dists = input_field.shape[0], input_field.shape[1]
43 |             input_field = input_field.reshape(num_mb*num_dists, 1, *input_field.shape[2:])
44 |
45 |         # Input format
46 |         if self.coord == 'rect':
47 |             stacked_input = torch.cat((input_field.real, input_field.imag), 1)
48 |         elif self.coord == 'polar':
49 |             stacked_input = torch.cat((input_field.abs(), input_field.angle()), 1)
50 |         elif self.coord == 'amp':
51 |             stacked_input = input_field.abs()
52 |         elif 'both' in self.coord:
53 |             stacked_input = torch.cat((input_field.abs(), input_field.angle(), input_field.real, input_field.imag), 1)
54 |
55 |         return stacked_input
56 |
57 |
58 | class Output2Field(nn.Module):
59 |     """Gets multi-channel network output and turns it into a complex-valued field"""
60 |
61 |     def __init__(self, output_res=(800, 1280), coord='rect', num_ch_output=1):
62 |         super(Output2Field, self).__init__()
63 |         self.output_res = output_res
64 |         self.coord = coord.lower()
65 |         self.num_ch_output = num_ch_output  # number of channels in output
66 |
67 |     def forward(self, stacked_output):
68 |
69 |         if self.coord in ('rect', 'both'):
70 |             complex_valued_field = torch.view_as_complex(stacked_output.unsqueeze(4).
71 |                                                          permute(0, 4, 2, 3, 1).contiguous())
72 |         elif self.coord == 'polar':
73 |             amp = stacked_output[:, 0:1, ...]
74 |             phi = stacked_output[:, 1:2, ...]
75 |             complex_valued_field = amp * torch.exp(1j * phi)
76 |         elif self.coord == 'amp' or '1ch_output' in self.coord:
77 |             complex_valued_field = stacked_output * torch.exp(1j * torch.zeros_like(stacked_output))
78 |
79 |         output_field = utils.pad_image(complex_valued_field, self.output_res, pytorch=True, stacked_complex=False)
80 |         output_field = utils.crop_image(output_field, self.output_res, pytorch=True, stacked_complex=False)
81 |
82 |         if self.num_ch_output > 1:
83 |             # reshape to original tensor shape
84 |             output_field = output_field.reshape(output_field.shape[0] // self.num_ch_output, self.num_ch_output,
85 |                                                 *output_field.shape[2:])
86 |
87 |         return output_field
88 |
89 |
90 | class Conv2dField(nn.Module):
91 |     """Apply a 2D convolution to the complex field or to its intensity"""
92 |
93 |     def __init__(self, complex=False, conv_size=3):
94 |         super(Conv2dField, self).__init__()
95 |         self.complex = complex  # apply convolution on field
96 |         self.conv_size = (conv_size, conv_size)
97 |         if self.complex:
98 |             self.conv_real = Conv2dSame(1, 1, conv_size)
99 |             self.conv_imag = Conv2dSame(1, 1, conv_size)
100 |             init_weight = torch.zeros(1, 1, *self.conv_size)
101 |             init_weight[..., conv_size//2, conv_size//2] = 1.
102 |             self.conv_real.net[1].weight = nn.Parameter(init_weight.clone().requires_grad_(True))  # clone: do not share storage
103 |             self.conv_imag.net[1].weight = nn.Parameter(init_weight.clone().requires_grad_(True))
104 |         else:
105 |             self.conv = Conv2dSame(1, 1, conv_size, bias=False)
106 |             init_weight = torch.zeros(1, 1, *self.conv_size)
107 |             init_weight[..., conv_size//2, conv_size//2] = 1.
108 |             self.conv.net[1].weight = nn.Parameter(init_weight.requires_grad_(True))
109 |
110 |     def forward(self, input_field):
111 |         # check if input is light field
112 |         if len(input_field.shape) > 4:
113 |             lf_batch_size = input_field.shape[0]
114 |             num_ch = input_field.shape[1]
115 |             num_y = input_field.shape[4]
116 |             num_x = input_field.shape[5]
117 |             input_field = input_field.permute(0, 4, 5, 1, 2, 3)
118 |             input_field = input_field.reshape(lf_batch_size * num_y * num_x, num_ch, *input_field.shape[-2:])
119 |             lf = True
120 |         else:
121 |             lf = False
122 |
123 |         # reshape tensor if number of channels > 1
124 |         num_ch = input_field.shape[1]
125 |         if num_ch > 1:
126 |             batch_size = input_field.shape[0]
127 |             input_field = input_field.reshape(batch_size * num_ch, 1, *input_field.shape[2:])
128 |
129 |         if self.complex:
130 |             # apply conv on complex fields
131 |             real = self.conv_real(input_field.real) - self.conv_imag(input_field.imag)
132 |             imag = self.conv_real(input_field.imag) + self.conv_imag(input_field.real)
133 |             output_field = torch.view_as_complex(torch.stack((real, imag), -1))
134 |         else:
135 |             # apply conv on intensity
136 |             output_amp = self.conv(input_field.abs()**2).abs().mean(dim=1, keepdims=True).sqrt()
137 |             output_field = output_amp * torch.exp(1j * input_field.angle())
138 |
139 |         # reshape to original tensor shape
140 |         if num_ch > 1:
141 |             output_field = output_field.reshape(batch_size, num_ch, *output_field.shape[2:])
142 |
143 |         if lf:
144 |             output_field = output_field.reshape(lf_batch_size, num_y, num_x, num_ch, *output_field.shape[-2:])
145 |             output_field = output_field.permute(0, 3, 4, 5, 1, 2)
146 |
147 |         return output_field
148 |
149 |
150 | class LatentCodedMLP(nn.Module):
151 |     """
152 |     Per-pixel MLP (1x1 convs) that concatenates latent codes to the input of every layer, not just the first.
153 |     Pass latent codes of shape (1, L, H, W) as an argument to the forward pass.
154 |     num_latent_codes: list with the number of latent-code channels fed to each layer,
155 |     * so sum(num_latent_codes) must equal the total number of latent-code channels L.
156 |     """
157 |     def __init__(self, num_layers=5, num_features=32, norm=None, num_latent_codes=None):
158 |         super(LatentCodedMLP, self).__init__()
159 |
160 |         if num_latent_codes is None:
161 |             num_latent_codes = [0] * num_layers
162 |
163 |         assert len(num_latent_codes) == num_layers
164 |
165 |         self.num_latent_codes = num_latent_codes
166 |         self.idxs = [sum(num_latent_codes[:y]) for y in range(num_layers + 1)]
167 |         self.nets = nn.ModuleList([])
168 |         num_features = [num_features] * num_layers
169 |         num_features[0] = 1
170 |
171 |         # define each layer
172 |         for i in range(num_layers - 1):
173 |             net = [nn.Conv2d(num_features[i] + num_latent_codes[i], num_features[i + 1], kernel_size=1)]
174 |             if norm is not None:
175 |                 net += [norm(num_groups=4, num_channels=num_features[i + 1], affine=True)]
176 |             net += [nn.LeakyReLU(0.2, True)]
177 |             self.nets.append(nn.Sequential(*net))
178 |
179 |         self.nets.append(nn.Conv2d(num_features[-1] + num_latent_codes[-1], 1, kernel_size=1))
180 |
181 |         for m in self.modules():
182 |             if isinstance(m, nn.Conv2d):
183 |                 nn.init.normal_(m.weight, std=0.05)
184 |
185 |     def forward(self, phases, latent_codes=None):
186 |
187 |         after_relu = phases
188 |         # concatenate latent codes at each layer and send through the convolutional layers
189 |         for i in range(len(self.num_latent_codes)):
190 |             if latent_codes is not None:
191 |                 latent_codes_b = latent_codes.repeat(phases.shape[0], 1, 1, 1)
192 |                 after_relu = torch.cat((after_relu, latent_codes_b[:, self.idxs[i]:self.idxs[i + 1], ...]), 1)
193 |             after_relu = self.nets[i](after_relu)
194 |
195 |         # residual connection
196 |         return phases - after_relu
197 |
198 |
199 | class ContentDependentField(nn.Module):
200 |     def __init__(self, num_layers=5, num_features=32, norm=nn.GroupNorm, latent_coords=False):
201 |         """ Simple CNN (num_layers conv layers, 5 by default) modeling content-dependent undiffracted light """
202 |
203 |         super(ContentDependentField, self).__init__()
204 |
205 |         if not latent_coords:
206 |             first_ch = 1
207 |         else:
208 |             first_ch = 3
209 |
210 |         net = [Conv2dSame(first_ch, num_features, kernel_size=3)]
211 |
212 |         for i in range(num_layers - 2):
213 |             if norm is not None:
214 |                 net += [norm(num_groups=2, num_channels=num_features, affine=True)]
215 |             net += [nn.LeakyReLU(0.2, True),
216 |                     Conv2dSame(num_features, num_features, kernel_size=3)]
217 |
218 |         if norm is not None:
219 |             net += [norm(num_groups=4, num_channels=num_features, affine=True)]
220 |
221 |         net += [nn.LeakyReLU(0.2, True),
222 |                 Conv2dSame(num_features, 2, kernel_size=3)]
223 |
224 |         self.net = nn.Sequential(*net)
225 |
226 |     def forward(self, phases, latent_coords=None):
227 |         if latent_coords is not None:
228 |             input_cnn = torch.cat((phases, latent_coords), dim=1)
229 |         else:
230 |             input_cnn = phases
231 |
232 |         return self.net(input_cnn)
233 |
234 |
235 | class ProcessPhase(nn.Module):
236 |     def __init__(self, num_layers=5, num_features=32, num_output_feat=0, norm=nn.GroupNorm, num_latent_codes=0):
237 |         super(ProcessPhase, self).__init__()
238 |
239 |         # avoid zero
240 |         self.num_output_feat = max(num_output_feat, 1)
241 |         self.num_latent_codes = num_latent_codes
242 |
243 |         # a bunch of 1x1 conv layers, set by num_layers
244 |         net = [nn.Conv2d(1 + num_latent_codes, num_features, kernel_size=1)]
245 |
246 |         for i in range(num_layers - 2):
247 |             if norm is not None:
248 |                 net += [norm(num_groups=2, num_channels=num_features, affine=True)]
249 |             net += [nn.LeakyReLU(0.2, True),
250 |                     nn.Conv2d(num_features, num_features, kernel_size=1)]
251 |
252 |         if norm is not None:
253 |             net += [norm(num_groups=2, num_channels=num_features, affine=True)]
254 |
255 |         net += [nn.ReLU(True),
256 |                 nn.Conv2d(num_features, self.num_output_feat, kernel_size=1)]
257 |
258 |         self.net = nn.Sequential(*net)
259 |
260 |     def forward(self, phases):
261 |         return phases - self.net(phases)
262 |
263 |
264 | class SourceAmplitude(nn.Module):
265 |     def __init__(self, num_gaussians=3, init_sigma=None, init_amp=0.7, x_s0=0.0, y_s0=0.0):
266 |         super(SourceAmplitude, self).__init__()
267 |
268 |         self.num_gaussians = num_gaussians
269 |
270 |         if init_sigma is None:
271 |             init_sigma = [100.] * self.num_gaussians  # default to 100 for all
272 |
273 |         # create parameters for source amplitudes
274 |         self.sigmas = nn.Parameter(torch.tensor(init_sigma))
275 |         self.x_s = nn.Parameter(torch.ones(num_gaussians) * x_s0)
276 |         self.y_s = nn.Parameter(torch.ones(num_gaussians) * y_s0)
277 |         self.amplitudes = nn.Parameter(torch.ones(num_gaussians) / (num_gaussians) * init_amp)
278 |         self.dc_term = nn.Parameter(torch.zeros(1))
279 |
280 |         self.x_dim = None
281 |         self.y_dim = None
282 |
283 |     def forward(self, phases):
284 |         # create DC term, then add the gaussians
285 |         source_amp = torch.ones_like(phases) * self.dc_term
286 |         for i in range(self.num_gaussians):
287 |             source_amp += self.create_gaussian(phases.shape, i)
288 |
289 |         return source_amp
290 |
291 |     def create_gaussian(self, shape, idx):
292 |         # create sampling grid if needed
293 |         if self.x_dim is None or self.y_dim is None:
294 |             self.x_dim = torch.linspace(-(shape[-1] - 1) / 2,
295 |                                         (shape[-1] - 1) / 2,
296 |                                         shape[-1], device=self.dc_term.device)
297 |             self.y_dim = torch.linspace(-(shape[-2] - 1) / 2,
298 |                                         (shape[-2] - 1) / 2,
299 |                                         shape[-2], device=self.dc_term.device)
300 |
301 |         if self.x_dim.device != self.sigmas.device:
302 |             self.x_dim = self.x_dim.to(self.sigmas.device).detach()  # .to() is not in-place; reassign
303 |             self.x_dim.requires_grad = False
304 |         if self.y_dim.device != self.sigmas.device:
305 |             self.y_dim = self.y_dim.to(self.sigmas.device).detach()
306 |             self.y_dim.requires_grad = False
307 |
308 |         # offset grid by coordinate and compute x and y gaussian components
309 |         x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.x_dim - self.x_s[idx], self.sigmas[idx]), 2))
310 |         y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.y_dim - self.y_s[idx], self.sigmas[idx]), 2))
311 |
312 |         # outer product with amplitude scaling
313 |         gaussian = torch.ger(self.amplitudes[idx] * y_gaussian, x_gaussian)
314 |
315 |         return gaussian
316 |
317 |
318 | class FiniteDiffField(nn.Module):
319 |     def __init__(self):
320 |         super(FiniteDiffField, self).__init__()
321 |         pass
322 |
323 |     def forward(self, model, slm_phase, delta_phase):
324 |         # delta phase is the phase difference to be added to the input phase.
325 |         # Can sample some SLM locations each iteration
326 |
327 |         field_1 = model(slm_phase)
328 |         field_2 = model(slm_phase + delta_phase)
329 |         # size H*W, which is the ith column of Jacobian df/d(phi)
330 |         delta_field = (field_2 - field_1) / delta_phase
331 |         return delta_field
332 |
333 |
334 | def make_kernel_gaussian(sigma, kernel_size):
335 |
336 |     # Create an x, y coordinate grid of shape (kernel_size, kernel_size, 2)
337 |     x_cord = torch.arange(kernel_size)
338 |     x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
339 |     y_grid = x_grid.t()
340 |     xy_grid = torch.stack([x_grid, y_grid], dim=-1)
341 |
342 |     mean = (kernel_size - 1) / 2
343 |     variance = sigma**2
344 |
345 |     # Calculate the 2-dimensional gaussian kernel which is
346 |     # the product of two gaussian distributions for two different
347 |     # variables (in this case called x and y)
348 |     gaussian_kernel = ((1 / (2 * math.pi * variance))
349 |                        * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1)
350 |                                    / (2 * variance)))
351 |     # Make sure sum of values in gaussian kernel equals 1.
352 |     gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
353 |
354 |     # Reshape to 2d depthwise convolutional weight
355 |     gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
356 |
357 |     return gaussian_kernel
358 |
359 |
360 | def create_gaussian(shape, sigma=800, dev=torch.device('cuda')):
361 |     # create a zero-centered sampling grid; both axes span the extent of the smaller dimension
362 |     shape_min = min(shape[-1], shape[-2])
363 |     x_dim = torch.linspace(-(shape_min - 1) / 2,
364 |                            (shape_min - 1) / 2,
365 |                            shape[-1], device=dev)
366 |     y_dim = torch.linspace(-(shape_min - 1) / 2,
367 |                            (shape_min - 1) / 2,
368 |                            shape[-2], device=dev)
369 |
370 |     # compute x and y gaussian components (centered at zero)
371 |     x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(x_dim, sigma), 2))
372 |     y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(y_dim, sigma), 2))
373 |
374 |     # outer product of the 1D components
375 |     gaussian = torch.ger(y_gaussian, x_gaussian)
376 |
377 |     return gaussian
378 |
379 |
380 |
381 |
382 |
--------------------------------------------------------------------------------
/props/prop_zernike.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for zernike basis
3 |
4 | """
5 |
6 | import torch
7 | import numpy as np
8 | import utils
9 | import torch.fft
10 | from aotools.functions import zernikeArray
11 |
12 |
13 | def combine_zernike_basis(coeffs, basis, return_phase=False):
14 |     """
15 |     Multiplies the Zernike coefficients and basis functions while preserving
16 |     dimensions
17 |
18 |     :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike
19 |     :param basis: the output of compute_zernike_basis, must be same length as coeffs
20 |     :param return_phase:
21 |     :return: A float32 tensor that combines coeffs and basis.
    """

    if len(coeffs.shape) < 3:
        coeffs = torch.reshape(coeffs, (coeffs.shape[0], 1, 1))

    # combine zernike basis and coefficients
    zernike = (coeffs * basis).sum(0, keepdim=True)

    # shape to [1, 1, H, W]
    zernike = zernike.unsqueeze(0)

    return zernike


def compute_zernike_basis(num_polynomials, field_res, dtype=torch.float32, wo_piston=False):
    """Computes a set of Zernike basis functions with resolution field_res

    num_polynomials: number of Zernike polynomials in this basis
    field_res: [height, width] in px, any list-like object
    dtype: torch dtype for computation at different precision
    wo_piston: if True, skip the piston term (start from Noll index 2)
    """

    # size the zernike basis to avoid circular masking
    zernike_diam = int(np.ceil(np.sqrt(field_res[0]**2 + field_res[1]**2)))

    # create zernike functions

    if not wo_piston:
        zernike = zernikeArray(num_polynomials, zernike_diam)
    else:  # 200427 - exclude piston term
        idxs = range(2, 2 + num_polynomials)
        zernike = zernikeArray(idxs, zernike_diam)

    zernike = utils.crop_image(zernike, field_res, pytorch=False)

    # convert to tensor and create phase
    zernike = torch.tensor(zernike, dtype=dtype, requires_grad=False)

    return zernike
--------------------------------------------------------------------------------
/quantization.py:
--------------------------------------------------------------------------------
"""
Quantization modules using projected gradient-descent, surrogate gradients, and Gumbel-Softmax.

Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image

import utils
import hw.ti as ti
from hw.discrete_slm import DiscreteSLM


def load_lut(sim_prop, opt):
    lut = None
    if hasattr(sim_prop, 'lut'):
        if sim_prop.lut is not None:
            lut = sim_prop.lut.squeeze().cpu().detach().numpy().tolist()
    else:
        # here directly sets lut to given 17 level lut,
        # no matter what, if quan_method = True, just set it to TI SLM levels
        lut = ti.given_lut
        if opt.channel is not None:
            # convert back to a list so that lut.append() below still works
            lut = (np.array(lut) * opt.wavelengths[1] / opt.wavelengths[opt.channel]).tolist()
        print("given lut...")

    # TODO: work to remove this line
    if lut is not None and len(lut) % 2 == 0:
        lut.append(lut[0] + 2 * math.pi)  # for lut_mid

    print(f'LUT: {lut}')
    return lut


def tau_iter(quan_fn, iter_frac, tau_min, tau_max, r=None):
    if 'softmax' in quan_fn:
        if r is None:
            r = math.log(tau_max / tau_min)
        tau = max(tau_min, tau_max * math.exp(-r * iter_frac))
    elif 'sigmoid' in quan_fn or 'poly' in quan_fn:
        tau = 1 + 10 * iter_frac
    else:
        tau = None
    return tau


def quantization(opt, lut):
    if opt.quan_method == 'None':
        qtz = None
    else:
        qtz = Quantization(opt.quan_method, lut=lut, c=opt.c_s, num_bits=opt.uniform_nbits if lut is None else 4,
                           tau_max=opt.tau_max, tau_min=opt.tau_min, r=opt.r, offset=opt.phase_offset)

    return qtz


def score_phase(phase, lut, s=5., func='sigmoid'):
    # s controls the steepness of the score function

    wrapped_phase = (phase + math.pi) % (2 * math.pi) - math.pi

    diff = wrapped_phase - lut
    diff = (diff + math.pi) % (2*math.pi) - math.pi  # signed angular difference
    diff /= math.pi  # normalize

    if func == 'sigmoid':
        z = s * diff
        scores = torch.sigmoid(z) * (1 - torch.sigmoid(z)) * 4
    elif func == 'log':
        scores = -torch.log(diff.abs() + 1e-20) * s
    elif func == 'poly':
        scores = (1-torch.abs(diff)**s)
    elif func == 'sine':
        scores = torch.cos(math.pi * (s * diff).clamp(-1., 1.))
    elif func == 'chirp':
        scores = 1 - torch.cos(math.pi * (1-diff.abs())**s)
    else:
        raise ValueError(f'Unknown score function: {func}')

    return scores


# Basic function for NN-based quantization, customize it with various surrogate gradients!
class NearestNeighborSearch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, phase, s=torch.tensor(1.0)):
        phase_raw = phase.detach()
        idx = utils.nearest_idx(phase_raw, DiscreteSLM.lut_midvals)
        phase_q = DiscreteSLM.lut[idx]
        ctx.mark_non_differentiable(idx)
        ctx.save_for_backward(phase_raw, s, phase_q, idx)
        return phase_q

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass the incoming gradient unchanged
        return grad_output, None


class NearestNeighborPolyGrad(NearestNeighborSearch):

    @staticmethod
    def forward(ctx, phase, s=torch.tensor(1.0)):
        return NearestNeighborSearch.forward(ctx, phase, s)

    @staticmethod
    def backward(ctx, grad_output):
        input, s, output, idx = ctx.saved_tensors
        grad_input = grad_output.clone()

        dx = input - output
        # step (+1/-1) toward the neighboring LUT level; 0 when the input sits exactly on a level
        # (dx / |dx| would be NaN there, and casting NaN to int gives an invalid index)
        d_idx = torch.sign(dx).int()
        other_end = DiscreteSLM.lut[(idx + d_idx)].to(input.device)  # far end not selected for quantization

        # normalization
        mid_point = (other_end + output) / 2
        gap = torch.abs(other_end - output) + 1e-20
        z = (input - mid_point) / gap * 2  # normalize to [-1, 1]

        dout_din = (0.5 * s * (1 - abs(z)) ** (s - 1)).nan_to_num()
        scale = 2.  #* dout_din.mean() / ((dout_din**2).mean() + 1e-20)
        grad_input *= (dout_din * scale)  # scale according to distance

        return grad_input, None


class NearestNeighborSigmoidGrad(NearestNeighborSearch):

    @staticmethod
    def forward(ctx, phase, s=torch.tensor(1.0)):
        return NearestNeighborSearch.forward(ctx, phase, s)

    @staticmethod
    def backward(ctx, grad_output):
        x, s, output, idx = ctx.saved_tensors
        grad_input = grad_output.clone()

        dx = x - output
        # step (+1/-1) toward the neighboring LUT level; 0 when the input sits exactly on a level
        d_idx = torch.sign(dx).int()
        other_end = DiscreteSLM.lut[(idx + d_idx)].to(x.device)  # far end not selected for quantization

        # normalization
        mid_point = (other_end + output) / 2
        gap = torch.abs(other_end - output) + 1e-20
        z = (x - mid_point) / gap * 2  # normalize to [-1, 1]
        z *= s

        dout_din = (torch.sigmoid(z) * (1 - torch.sigmoid(z)))
        scale = 4. * s  #1 / 0.462 * gap * s#dout_din.mean() / ((dout_din**2).mean() + 1e-20) # =100
        grad_input *= (dout_din * scale)

        return grad_input, None


nns = NearestNeighborSearch.apply
nns_poly = NearestNeighborPolyGrad.apply
nns_sigmoid = NearestNeighborSigmoidGrad.apply


class SoftmaxBasedQuantization(nn.Module):
    def __init__(self, lut, gumbel=True, tau_max=3.0, c=300.):
        super(SoftmaxBasedQuantization, self).__init__()

        if not torch.is_tensor(lut):
            self.lut = torch.tensor(lut, dtype=torch.float32)
        else:
            self.lut = lut
        self.lut = self.lut.reshape(1, len(lut), 1, 1)
        self.c = c  # boost the score
        self.gumbel = gumbel
        self.tau_max = tau_max

    def forward(self, phase, tau=1.0, hard=False):
        phase_wrapped = (phase + math.pi) % (2*math.pi) - math.pi

        # phase to score
        scores = score_phase(phase_wrapped, self.lut.to(phase_wrapped.device), (self.tau_max / tau)**1) * self.c * (self.tau_max / tau)**1.0

        # score to one-hot encoding
        if self.gumbel:  # (N, 1, H, W) -> (N, C, H, W)
            one_hot = F.gumbel_softmax(scores, tau=tau, hard=hard, dim=1)
        else:
            y_soft = F.softmax(scores/tau, dim=1)
            index = y_soft.max(1, keepdim=True)[1]
            one_hot_hard = torch.zeros_like(scores,
                                            memory_format=torch.legacy_contiguous_format).scatter_(1, index, 1.0)
            if hard:
                one_hot = one_hot_hard + y_soft - y_soft.detach()
            else:
                one_hot = y_soft

        # one-hot encoding to phase value
        q_phase = (one_hot * self.lut.to(one_hot.device))
        q_phase = q_phase.sum(1, keepdim=True)
        return q_phase


class Quantization(nn.Module):
    def __init__(self, method=None, num_bits=4, lut=None, dev=torch.device('cuda'),
                 tau_min=0.5, tau_max=3.0, r=None, c=300., offset=0.0):
        super(Quantization, self).__init__()
        if lut is None:
            # linear look-up table
            DiscreteSLM.lut = torch.linspace(-math.pi, math.pi, 2**num_bits + 1).to(dev)
        else:
            # non-linear look-up table
            assert len(lut) == (2**num_bits) + 1
            DiscreteSLM.lut = torch.tensor(lut, dtype=torch.float32).to(dev)

        self.quan_fn = None
        self.gumbel = 'gumbel' in method.lower()
        if method.lower() == 'nn':
            self.quan_fn = nns
        elif method.lower() == 'nn_sigmoid':
            self.quan_fn = nns_sigmoid
        elif method.lower() == 'nn_poly':
            self.quan_fn = nns_poly
        elif 'softmax' in method.lower():
            self.quan_fn = SoftmaxBasedQuantization(DiscreteSLM.lut[:-1], self.gumbel, tau_max=tau_max, c=c)

        self.method = method
        self.tau_min = tau_min
        self.tau_max = tau_max
        self.r = r
        self.offset = offset

    def forward(self, input_phase, iter_frac=None, hard=True):
        # default temperature when no annealing schedule (iter_frac) is given;
        # without it, calling forward() with iter_frac=None raised UnboundLocalError
        tau = 1.0
        if iter_frac is not None:
            tau = tau_iter(self.method, iter_frac, self.tau_min, self.tau_max, self.r)
        wrapped_phase = (input_phase + self.offset + math.pi) % (2 * math.pi) - math.pi
        if self.quan_fn is None:
            return wrapped_phase
        else:
            if isinstance(tau, float):
                tau = torch.tensor(tau, dtype=torch.float32).to(input_phase.device)
            if 'nn' in self.method.lower():
                s = tau
                return self.quan_fn(wrapped_phase, s)
            else:
                return self.quan_fn(wrapped_phase, tau, hard)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
A script for model training

Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)

This code and data are released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC). In a nutshell:
# The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
# The material is provided as-is, with no warranties whatsoever.
# If you publish any code, data, or scientific work based on this, please cite our work.

Technical Paper:
Time-multiplexed Neural Holography:
A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
SIGGRAPH 2022
"""
import os
import configargparse
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

import utils
import params
import props.prop_model as prop_model
import image_loader as loaders
import torch


# Command line argument processing
p = configargparse.ArgumentParser()
p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
p.add('--capture_subset', type=str, default=None)

params.add_parameters(p, 'train')
opt = params.set_configs(p.parse_args())
run_id = params.run_id_training(opt)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_id)

if opt.gpu_id > 0:
    # torch.cuda.set_device(opt.gpu_id)
    print(f"Using gpu {opt.gpu_id} ...")


def main():
    if ',' in opt.data_path:
        opt.data_path = opt.data_path.split(',')
    else:
        opt.data_path = [opt.data_path]
    print(f' - training a model ... Dataset path:{opt.data_path}')
    # Set up dataloaders
    num_workers = 4
    # modify plane idxs as needed!
    train_loader = DataLoader(loaders.PairsLoader([os.path.join(path, 'train') for path in opt.data_path],
                                                  plane_idxs=opt.plane_idxs['train'], image_res=opt.image_res,
                                                  avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type,
                                                  capture_subset=opt.capture_subset, dataset_subset=opt.dataset_subset),
                              num_workers=num_workers, batch_size=opt.batch_size, pin_memory=True)
    val_loader = DataLoader(loaders.PairsLoader([os.path.join(path, 'val') for path in opt.data_path],
                                                plane_idxs=opt.plane_idxs['train'], image_res=opt.image_res,
                                                shuffle=False, avg_energy_ratio=opt.avg_energy_ratio,
                                                slm_type=opt.slm_type, capture_subset=opt.capture_subset),
                            num_workers=num_workers, batch_size=opt.batch_size, shuffle=False, pin_memory=True)
    test_loader = DataLoader(loaders.PairsLoader([os.path.join(path, 'test') for path in opt.data_path],
                                                 plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res,
                                                 shuffle=False, avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type),
                             num_workers=num_workers, batch_size=opt.batch_size, shuffle=False, pin_memory=True)

    # Init model
    if opt.slm_type == 'ti':
        opt.roi_res = (760, 1240)  # modify here! should this be (700, 1190)?
    else:
        opt.roi_res = (840, 1200)
    model = prop_model.model(opt)
    model.train()

    # Init root path
    root_dir = os.path.join(opt.out_path, run_id)
    utils.cond_mkdir(root_dir)
    p.write_config_file(opt, [os.path.join(root_dir, 'config.txt')])

    psnr_checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="PSNR_validation_epoch", dirpath=root_dir,
                                                            filename="model-{epoch:02d}-{PSNR_validation_epoch:.2f}",
                                                            save_top_k=1, mode="max", )
    latest_checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="PSNR_validation_epoch", dirpath=root_dir,
                                                              filename="model-latest-{PSNR_validation_epoch:.2f}",
                                                              every_n_epochs=1, save_last=True)

    # Init trainer
    trainer = Trainer(default_root_dir=root_dir, accelerator='gpu',
                      log_every_n_steps=400, gpus=1, max_epochs=opt.num_epochs,
                      callbacks=[psnr_checkpoint_callback, latest_checkpoint_callback])

    # Fit Model
    trainer.fit(model, train_loader, val_loader)

    # Test Model
    trainer.test(model, dataloaders=test_loader)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
"""
U-net implementations

"""

import torch
import torch.nn as nn
from torch.nn import init
import functools


def norm_layer(norm_str):
    if norm_str.lower() == 'instance':
        return nn.InstanceNorm2d
    elif norm_str.lower() == 'group':
        return nn.GroupNorm
    elif norm_str.lower() == 'batch':
        return nn.BatchNorm2d


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False,
                 outer_skip=False):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool) -- if this module is the outermost module
            innermost (bool) -- if this module is the innermost module
            norm_layer -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
            outer_skip (bool) -- if True, the outermost block also concatenates its input to its output
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.outer_skip = outer_skip
        if norm_layer is None:
            use_bias = True
        elif type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=5,
                             # kernel size changed from 4 to 5 and padding from 1 to 2
                             stride=2, padding=2, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        if norm_layer is not None:
            if norm_layer == nn.GroupNorm:
                downnorm = norm_layer(8, inner_nc)
            else:
                downnorm = norm_layer(inner_nc)
        else:
            downnorm = None
        uprelu = nn.ReLU(True)
        if norm_layer is not None:
            if norm_layer == nn.GroupNorm:
                upnorm = norm_layer(8, outer_nc)
            else:
                upnorm = norm_layer(outer_nc)
        else:
            upnorm = None

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv, downrelu]
            up = [upconv]  # Removed tanh and uprelu
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            if norm_layer is not None:
                down = [downconv, downnorm, downrelu]
                up = [upconv, upnorm, uprelu]
            else:
                down = [downconv, downrelu]
                up = [upconv, uprelu]

            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            if norm_layer is not None:
                down = [downconv, downnorm, downrelu]
                up = [upconv, upnorm, uprelu]
            else:
                down = [downconv, downrelu]
                up = [upconv, uprelu]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost and not self.outer_skip:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)


def init_latent(latent_num, wavefront_res, ones=False):
    if latent_num > 0:
        if ones:
            latent = nn.Parameter(torch.ones(1, latent_num, *wavefront_res,
                                             requires_grad=True))
        else:
            latent = nn.Parameter(torch.zeros(1, latent_num, *wavefront_res,
                                              requires_grad=True))
    else:
        latent = None
    return latent


def apply_net(net, input, latent_code, complex=False):
    if net is None:
        return input
    if complex:  # Only valid for single batch or single channel complex inputs and outputs
        multi_channel = (input.shape[1] > 1)
        if multi_channel:
            input = torch.view_as_real(input[0, ...])
        else:
            input = torch.view_as_real(input[:, 0, ...])
        input = input.permute(0, 3, 1, 2)
    if latent_code is not None:
        input = torch.cat((input, latent_code), dim=1)
    output = net(input)
    if complex:
        if multi_channel:
            output = output.permute(0, 2, 3, 1).unsqueeze(0)
        else:
            output = output.permute(0, 2, 3, 1).unsqueeze(1)
        output = torch.complex(output[..., 0], output[..., 1])
    return output


def init_weights(net, init_type='normal', init_gain=0.02, outer_skip=False):
    """Initialize network weights.
    Parameters:
        net (network) -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find(
                'BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc=1, output_nc=1, num_downs=8, nf0=32, max_channels=512,
                 norm_layer=nn.InstanceNorm2d, use_dropout=False, outer_skip=True,
                 half_channels=False, eighth_channels=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            nf0 (int) -- the number of filters in the outermost conv layer
            norm_layer -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        self.outer_skip = outer_skip
        self.input_nc = input_nc

        if eighth_channels:
            divisor = 8
        elif half_channels:
            divisor = 2
        else:
            divisor = 1
        # construct unet structure

        assert num_downs >= 2

        # Add the innermost layer
        unet_block = UnetSkipConnectionBlock(min(2 ** (num_downs - 1) * nf0, max_channels) // divisor,
                                             min(2 ** (num_downs - 1) * nf0, max_channels) // divisor,
                                             input_nc=None, submodule=None, norm_layer=norm_layer,
                                             innermost=True)

        for i in list(range(1, num_downs - 1))[::-1]:
            if i == 1:
                norm = None  # Praneeth's modification
            else:
                norm = norm_layer

            unet_block = UnetSkipConnectionBlock(min(2 ** i * nf0, max_channels) // divisor,
                                                 min(2 ** (i + 1) * nf0, max_channels) // divisor,
                                                 input_nc=None, submodule=unet_block,
                                                 norm_layer=norm,
                                                 use_dropout=use_dropout)

        # Add the outermost layer
        self.model = UnetSkipConnectionBlock(min(nf0, max_channels) // divisor,
                                             min(2 * nf0, max_channels) // divisor,
                                             input_nc=input_nc, submodule=unet_block, outermost=True,
                                             norm_layer=None, outer_skip=self.outer_skip)
        if self.outer_skip:
            self.additional_conv = nn.Conv2d(input_nc + min(nf0, max_channels) // divisor, output_nc,
                                             kernel_size=4, stride=1, padding=2, bias=True)
        else:
            self.additional_conv = nn.Conv2d(min(nf0, max_channels) // divisor, output_nc,
                                             kernel_size=4, stride=1, padding=2, bias=True)

    def forward(self, cnn_input):
        """Standard forward"""
        output = self.model(cnn_input)
        output = self.additional_conv(output)
        output = output[:, :, :-1, :-1]
        return output


class Conv2dSame(torch.nn.Module):
    '''2D convolution that pads to keep spatial dimensions equal.
    Does not support stride. Only square kernels (scalar kernel_size) are supported.
    '''

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReflectionPad2d):
        '''
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Scalar. Spatial dimensions of kernel (only square kernels supported).
        :param bias: Whether or not to use bias.
        :param padding_layer: Which padding to use. Default is reflection padding.
        '''
        super().__init__()
        ka = kernel_size // 2
        kb = ka - 1 if kernel_size % 2 == 0 else ka
        self.net = nn.Sequential(
            padding_layer((ka, kb, ka, kb)),
            nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1)
        )

        self.weight = self.net[1].weight
        self.bias = self.net[1].bias

    def forward(self, x):
        return self.net(x)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
"""
Utils

"""

import math
import random
import numpy as np

import os
import torch
import torch.nn as nn

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import torch.nn.functional as F
from torchvision.utils import save_image

import props.prop_model as prop_model


class AverageMeter:
    def __init__(self):
        self._sum = 0
        self._avg = 0
        self._cnt = 0

    def update(self, val):
        self._sum += val
        self._cnt += 1
        self._avg = self._sum / self._cnt

    @property
    def avg(self):
        return self._avg


def apply_func_list(func, data_list):
    return [func(data) for data in data_list]


def post_process_amp(amp, scale=1.0):
    # amp is an image tensor in range [0, 1]
    amp = amp * scale
    amp = torch.clip(amp, 0, 1)
    amp = amp.detach().squeeze().cpu().numpy()
    return amp


def roll_torch(tensor, shift: int, axis: int):
    if shift == 0:
        return tensor

    if axis < 0:
        axis += tensor.dim()

    dim_size = tensor.size(axis)
    after_start = dim_size - shift
    if shift < 0:
        after_start = -shift
        shift = dim_size - abs(shift)

    before = tensor.narrow(axis, 0, dim_size - shift)
    after = tensor.narrow(axis, after_start, shift)
    return torch.cat([after, before], axis)


def ifftshift(tensor):
    """ifftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2]

    shifts the width and heights
    """
    size = tensor.size()
    tensor_shifted = roll_torch(tensor, -math.floor(size[2] / 2.0), 2)
    tensor_shifted = roll_torch(tensor_shifted, -math.floor(size[3] / 2.0), 3)
    return tensor_shifted


def fftshift(tensor):
    """fftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2]

    shifts the width and heights
    """
    size = tensor.size()
    tensor_shifted = roll_torch(tensor, math.floor(size[2] / 2.0), 2)
    tensor_shifted = roll_torch(tensor_shifted, math.floor(size[3] / 2.0), 3)
    return tensor_shifted


def pad_image(field, target_shape, pytorch=True, stacked_complex=True, padval=0, mode='constant'):
    """Pads a 2D complex field up to target_shape in size

    Padding is done such that when used with crop_image(), odd and even dimensions are
    handled correctly to properly undo the padding.

    field: the field to be padded. May have as many leading dimensions as necessary
        (e.g., batch or channel dimensions)
    target_shape: the 2D target output dimensions. Dimensions of field that already
        meet or exceed target_shape are not padded
    pytorch: if True, uses torch functions, if False, uses numpy
    stacked_complex: for pytorch=True, indicates that field has a final dimension
        representing real and imag
    padval: the real number value to pad by
    mode: padding mode for numpy or torch (stacked-complex fields only support constant padding)
    """
    if pytorch:
        if stacked_complex:
            size_diff = np.array(target_shape) - np.array(field.shape[-3:-1])
            odd_dim = np.array(field.shape[-3:-1]) % 2
        else:
            size_diff = np.array(target_shape) - np.array(field.shape[-2:])
            odd_dim = np.array(field.shape[-2:]) % 2
    else:
        size_diff = np.array(target_shape) - np.array(field.shape[-2:])
        odd_dim = np.array(field.shape[-2:]) % 2

    # pad the dimensions that need to increase in size
    if (size_diff > 0).any():
        pad_total = np.maximum(size_diff, 0)
        pad_front = (pad_total + odd_dim) // 2
        pad_end = (pad_total + 1 - odd_dim) // 2

        if pytorch:
            pad_axes = [int(p)  # convert from np.int64
                        for tple in zip(pad_front[::-1], pad_end[::-1])
                        for p in tple]
            if stacked_complex:
                # pad_stacked_complex() takes no `mode` argument; passing one raised a TypeError
                return pad_stacked_complex(field, pad_axes, padval=padval)
            else:
                return nn.functional.pad(field, pad_axes, mode=mode, value=padval)
        else:
            leading_dims = field.ndim - 2  # only pad the last two dims
            if leading_dims > 0:
                pad_front = np.concatenate(([0] * leading_dims, pad_front))
                pad_end = np.concatenate(([0] * leading_dims, pad_end))
            return np.pad(field, tuple(zip(pad_front, pad_end)), mode,
                          constant_values=padval)
    else:
        return field


def crop_image(field, target_shape, pytorch=True, stacked_complex=True, lf=False):
    """Crops a 2D field, see pad_image() for details

    No cropping is done along dimensions where the field is already no larger than target_shape
    """
    if target_shape is None:
        return field

    if lf:
        size_diff = np.array(field.shape[-4:-2]) - np.array(target_shape)
        odd_dim = np.array(field.shape[-4:-2]) % 2
    else:
        if pytorch:
            if stacked_complex:
                size_diff = np.array(field.shape[-3:-1]) - np.array(target_shape)
                odd_dim = np.array(field.shape[-3:-1]) % 2
            else:
                size_diff = np.array(field.shape[-2:]) - np.array(target_shape)
                odd_dim = np.array(field.shape[-2:]) % 2
        else:
            size_diff = np.array(field.shape[-2:]) - np.array(target_shape)
            odd_dim = np.array(field.shape[-2:]) % 2

    # crop dimensions that need to decrease in size
    if (size_diff > 0).any():
        crop_total = np.maximum(size_diff, 0)
        crop_front = (crop_total + 1 - odd_dim) // 2
        crop_end = (crop_total + odd_dim) // 2

        crop_slices = [slice(int(f), int(-e) if e else None)
                       for f, e in zip(crop_front, crop_end)]
        if lf:
            return field[(..., *crop_slices, slice(None), slice(None))]
        else:
            if pytorch and stacked_complex:
                return field[(..., *crop_slices, slice(None))]
            else:
                return field[(..., *crop_slices)]
    else:
        return field


def srgb_gamma2lin(im_in):
    """converts from sRGB to linear color space"""
    thresh = 0.04045
    im_out = np.where(im_in <= thresh, im_in / 12.92, ((im_in + 0.055) / 1.055)**(2.4))
    return im_out


def srgb_lin2gamma(im_in):
    """converts from linear to sRGB color space"""
    thresh = 0.0031308
    im_out = np.where(im_in <= thresh, 12.92 * im_in, 1.055 * (im_in**(1 / 2.4)) - 0.055)
    return im_out


def cond_mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def burst_img_processor(img_burst_list):
    img_tensor = np.stack(img_burst_list, axis=0)
    img_avg = np.mean(img_tensor, axis=0)
    return im2float(img_avg)  # changed from int8 to float32


def im2float(im, dtype=np.float32):
    """convert uint16 or uint8 image to float32, with range scaled to 0-1

    :param im: image
    :param dtype: default np.float32
    :return:
    """
    if issubclass(im.dtype.type, np.floating):
        return im.astype(dtype)
    elif issubclass(im.dtype.type, np.integer):
        return im / dtype(np.iinfo(im.dtype).max)
    else:
        raise ValueError(f'Unsupported data type {im.dtype}')


def get_psnr_ssim(recon_amp, target_amp, multichannel=False):
    """get PSNR and SSIM metrics"""
    psnrs, ssims = {}, {}

    # amplitude
    psnrs['amp'] = psnr(target_amp, recon_amp)
    ssims['amp'] = ssim(target_amp, recon_amp, multichannel=multichannel)

    # linear
    target_linear = target_amp**2
    recon_linear = recon_amp**2
    psnrs['lin'] = psnr(target_linear, recon_linear)
    ssims['lin'] = ssim(target_linear, recon_linear, multichannel=multichannel)

    # srgb
    target_srgb = srgb_lin2gamma(np.clip(target_linear, 0.0, 1.0))
    recon_srgb = srgb_lin2gamma(np.clip(recon_linear, 0.0, 1.0))
    psnrs['srgb'] = psnr(target_srgb, recon_srgb)
    ssims['srgb'] = ssim(target_srgb, recon_srgb, multichannel=multichannel)

    return psnrs, ssims


def make_kernel_gaussian(sigma, kernel_size):

    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
    x_cord = torch.arange(kernel_size)
    x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1)

    mean = (kernel_size - 1) / 2
    variance = sigma**2

    # Calculate the 2-dimensional gaussian kernel which is
    # the product of two gaussian distributions for two different
    # variables (in this case called x and y)
    gaussian_kernel = ((1 / (2 * math.pi * variance))
                       * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1)
                                   / (2 * variance)))
    # Make sure sum of values in gaussian kernel equals 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to 2d depthwise convolutional weight
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)

    return gaussian_kernel


def pad_stacked_complex(field, pad_width, padval=0):
    if padval == 0:
        pad_width = (0, 0, *pad_width)  # add 0 padding for stacked_complex dimension
        return nn.functional.pad(field, pad_width)
    else:
        if isinstance(padval, torch.Tensor):
            padval = padval.item()

        real, imag = field[..., 0], field[..., 1]
        real = nn.functional.pad(real, pad_width, value=padval)
        imag = nn.functional.pad(imag, pad_width, value=0)
        return torch.stack((real, imag), -1)


def lut_mid(lut):
    return [(a + b) / 2 for a, b in zip(lut[:-1], lut[1:])]


def nearest_neighbor_search(input_val, lut, lut_midvals=None):
    """
    Quantize to nearest neighbor values in lut
    :param input_val: input tensor
    :param lut: tensor of supported discrete values
    :param lut_midvals: thresholds (midpoints between adjacent lut values) passed to torch.bucketize; required
    :return: tuple of (quantized values, indices into lut)
    """
    if lut_midvals is None:
        # derive the thresholds from the lut when they are not precomputed (lut must be a tensor)
        lut_midvals = (lut[:-1] + lut[1:]) / 2
    idx = nearest_idx(input_val, lut_midvals)
    assert not torch.isnan(idx).any()
    return lut[idx], idx


def nearest_idx(input_val, lut_midvals):
    """ Return nearest idx of lut per pixel """
    input_array = input_val.detach()
    len_lut = len(lut_midvals)
    # idx = torch.searchsorted(lut_midvals.to(input_val.device), input_array, right=True)
    idx = torch.bucketize(input_array, lut_midvals.to(input_val.device), right=True)

    # indices beyond the last threshold wrap around to 0
    return idx % len_lut


def srgb_gamma2lin(im_in):
    """ converts from sRGB to linear color space (numpy arrays or torch tensors) """
    thresh = 0.04045
    if torch.is_tensor(im_in):
        low_val = im_in <= thresh
        im_out = torch.zeros_like(im_in)
        im_out[low_val] = 25 / 323 * im_in[low_val]
        im_out[torch.logical_not(low_val)] = ((200 * im_in[torch.logical_not(low_val)] + 11)
                                              / 211) ** (12 / 5)
    else:
        im_out = np.where(im_in <= thresh, im_in / 12.92, ((im_in + 0.055) / 1.055) ** (12 / 5))

    return im_out


def decompose_depthmap(depthmap_virtual_D, depth_planes_D):
    """ decompose a depthmap image into a set of masks with depth positions (in Diopter) """

    num_planes = len(depth_planes_D)

    masks = torch.zeros(depthmap_virtual_D.shape[0], len(depth_planes_D), *depthmap_virtual_D.shape[-2:],
                        dtype=torch.float32).to(depthmap_virtual_D.device)
    for k in range(len(depth_planes_D) - 1):
        depth_l = depth_planes_D[k]
        depth_h = depth_planes_D[k + 1]
        idxs = (depthmap_virtual_D >= depth_l) & (depthmap_virtual_D < depth_h)
        close_idxs = (depth_h - depthmap_virtual_D) > (depthmap_virtual_D - depth_l)

        # pixels in [depth_l, depth_h) that are nearer to depth_l go to plane k
        mask = torch.zeros_like(depthmap_virtual_D)
        mask += idxs * close_idxs * 1
        masks[:, k, ...] += mask.squeeze(1)

        # the remaining pixels in [depth_l, depth_h) go to plane k + 1
        mask = torch.zeros_like(depthmap_virtual_D)
        mask += idxs * (~close_idxs) * 1
        masks[:, k + 1, ...] += mask.squeeze(1)

    # pixels at or beyond the largest diopter value (closest to the viewer) go to the last plane
    idxs = depthmap_virtual_D >= max(depth_planes_D)
    mask = torch.zeros_like(depthmap_virtual_D)
    mask += idxs * 1
    masks[:, len(depth_planes_D) - 1, ...] += mask.clone().squeeze(1)

    # pixels below the smallest diopter value (farthest from the viewer) go to the first plane
    idxs = depthmap_virtual_D < min(depth_planes_D)
    mask = torch.zeros_like(depthmap_virtual_D)
    mask += idxs * 1
    masks[:, 0, ...] += mask.clone().squeeze(1)

    # sanity check: the masks should partition the pixels, so they sum to N * H * W
    assert torch.sum(masks).item() == torch.numel(masks) / num_planes

    return masks


def decompose_depthmap_v2(depth_batch, num_depth_planes, roi_res):
    """
    Depth (N, 1, H, W) -> Masks (N, num_depth_planes, H, W)
    Decompose each depth map in the batch into num_depth_planes binary masks whose bin edges
    are quantiles of the depth values inside the central roi_res crop.
    """
    def _decompose_depthmap(depth, num_depth_planes):
        depth = depth * 1000
        depth_vals = crop_image(depth, roi_res, stacked_complex=False).ravel()
        npt = len(depth_vals)
        depth_bins = np.interp(np.linspace(0, npt, num_depth_planes),
                               np.arange(npt),
                               np.sort(depth_vals)).round(decimals=2)

        masks = []
        for i in range(num_depth_planes):
            if i < num_depth_planes - 1:
                min_d = depth_bins[i]
                max_d = depth_bins[i + 1]
                mask = torch.where(depth >= min_d, 1, 0) * torch.where(depth < max_d, 1, 0)
            else:
                mask = torch.where(depth >= depth_bins[-1], 1, 0)
            masks.append(mask)
        masks = torch.stack(masks)
        masks = torch.where(masks > 0, 1, 0).float()
        for i in range(num_depth_planes - 1):
            mask_diff = torch.logical_and(masks[i], masks[i + 1]).float()
            masks[i] -= mask_diff
        # reverse depth order
        masks = masks.flip(0)
        return masks.unsqueeze(0)

    masks = [_decompose_depthmap(depth.squeeze(), num_depth_planes) for depth in depth_batch]
    masks = torch.cat(masks, dim=0)
    return masks


def prop_dist_to_diopter(prop_dists, focal_distance, prop_dist_inf, from_lens=True):
    """
    Calculates distance from the user in diopters given the propagation distances from the SLM.
    :param prop_dists: iterable of propagation distances from the SLM (same length unit as focal_distance)
    :param focal_distance: focal length of the eyepiece
    :param prop_dist_inf: propagation distance from the SLM whose plane maps to optical infinity
    :param from_lens: if True, diopters are measured from the eyepiece; otherwise from the user,
                      assumed to sit at the eyepiece's focal point
    :return: list of diopter values (inverse of the length unit, i.e. D when distances are in meters)
    """
    x0 = prop_dist_inf  # prop distance from SLM that corresponds to optical infinity from the user
    f = focal_distance  # focal distance of eyepiece

    if from_lens:  # distance is from the lens
        diopters = [1 / (x0 + f - x) - 1 / f for x in prop_dists]  # diopters from the user side
    else:  # distance is from the user (basically adding focal length)
        diopters = [(x - x0) / f**2 for x in prop_dists]

    return diopters


def switch_lf(input, mode='elemental'):
    """
    Rearrange a light field of shape (1, 1, H, W, U, V) into a 2D mosaic of shape (1, 1, H*U, W*V).
    mode='elemental' tiles the U x V angular samples of each spatial pixel (elemental images);
    mode='whole' tiles the full H x W sub-aperture view of each angular sample.
    """
    spatial_res = input.shape[2:4]
    angular_res = input.shape[-2:]
    if mode == 'elemental':
        lf = input.permute(0, 1, 2, 4, 3, 5)
    elif mode == 'whole':
        lf = input.permute(0, 1, 4, 2, 5, 3)  # show each view
    else:
        raise ValueError(f'Unsupported mode: {mode}')
    return lf.reshape(1, 1, *(s*a for s, a in zip(spatial_res, angular_res)))


def nonnegative_mean_dilate(im):
    """
    One step of hole-aware mean dilation used for inpainting; im has shape (N, 1, H, W).

    Pixel values encode: >= 0 valid content, -1 pixels still to be inpainted, -2 holes.
    Each output pixel is the mean of the valid (non-negative) values in its 3x3 neighborhood;
    it becomes -1 if no neighbor is valid, and -2 if more than 70% of its neighbors
    that are not -1 are holes.
    """

    # mean filter over the valid (non-negative) pixels of each 3x3 window
    im = F.pad(im, (1, 1, 1, 1), mode='reflect')
    im = im.unfold(2, 3, 1).unfold(3, 3, 1)
    im = im.contiguous().view(im.size()[:4] + (-1, ))
    percent_surrounded_by_holes = ((im != -1) * (im < 0)).sum(dim=-1) / (1e-12 + (im != -1).sum(dim=-1))
    holes = (0.7 < percent_surrounded_by_holes)
    mean_im = ((im > -1) * im).sum(dim=-1) / (1e-12 + (im > -1).sum(dim=-1))
    im = (mean_im * torch.logical_not(holes)
          - 1 * (0 == (im > -1).sum(dim=-1)) * torch.logical_not(holes)
          - 2 * holes)

    return im


def generate_incoherent_stack(target_amp, depth_masks, depth_planes_depth,
                              wavelength, pitch, focal_stack_blur_radius=1.0):
    """
    Render an incoherent focal stack, approximating defocus blur with diffraction-cone kernels
    and handling occlusions by inpainting regions hidden behind nearer planes.

    :param target_amp: target amplitude, shape (N, 1, H, W)
    :param depth_masks: binary mask per depth plane, shape (N, num_planes, H, W);
                        a higher plane index is treated as closer to the viewer (front)
    :param depth_planes_depth: propagation distance of each plane, in the same length unit as pitch
    :param wavelength: wavelength of light
    :param pitch: SLM pixel pitch
    :param focal_stack_blur_radius: scale factor on the inter-plane distance used to size the blur kernel
    :return: focal stack of shape (N, num_planes, H, W)
    """
    with torch.no_grad():
        # Create inpainted images for better approximation of occluded regions
        # (start with -1 for occluded regions to be inpainted, and -2 for holes)
        inpainted_images = depth_masks * target_amp - 2 * (1 - depth_masks)
        occluded_regions = torch.zeros_like(depth_masks)
        for j in range(depth_masks.shape[1]):
            for k in range(depth_masks.shape[1]):
                if k > j:
                    occluded_regions[:, j, ...] = torch.logical_or(depth_masks[:, k, ...] > 0,
                                                                   occluded_regions[:, j, ...])
        inpainted_images += 1 * occluded_regions

        inpainting_ordering = depth_masks.clone()
        for j in range(depth_masks.shape[1]):
            buffer = 50 * math.ceil(((depth_planes_depth[-1] - depth_planes_depth[0] / pitch) *
                                     math.sqrt(1 / ((2 * pitch / wavelength)**2 - 1))))
            for i in range(buffer):
                blurred_im = nonnegative_mean_dilate(inpainted_images[:, j, ...].unsqueeze(1))[:, 0, ...]
                inpainting_ordering[:, j, ...][torch.logical_and((inpainted_images[:, j, ...] == -1),
                                                                 (blurred_im >= 0))] = i + 2
                inpainted_images[:, j, ...][(inpainted_images[:, j, ...] == -1)] = \
                    blurred_im[(inpainted_images[:, j, ...] == -1)]
        closest_inpainting = torch.zeros_like(depth_masks)  # tracks if depth is closest inpainting depth of the remaining planes
        for j in range(inpainting_ordering.shape[1]):
            closest_inpainting[:, j, ...] = inpainting_ordering[:, j, ...] > 0
            for k in range(inpainting_ordering.shape[1]):
                if k < j:
                    closest_inpainting[:, j, ...] *= torch.logical_or(inpainting_ordering[:, k, ...] < 1,
                                                                      inpainting_ordering[:, j, ...] <= inpainting_ordering[:, k, ...])

        # Propagation starting with front planes to handle occlusion
        focal_stack = torch.zeros_like(depth_masks)
        unblocked_weighting = torch.ones_like(depth_masks)
        for j in range(focal_stack.shape[1] - 1, -1, -1):
            for k in range(focal_stack.shape[1] - 1, -1, -1):
                if k == j:
                    focal_stack[:, k, ...] += unblocked_weighting[:, k, ...] * (target_amp[:, 0, ...] * depth_masks[:, j, ...])
                    unblocked_weighting[:, k, ...] -= unblocked_weighting[:, k, ...] * depth_masks[:, j, ...]
                else:
                    incoherent_propagator = create_diffraction_cone_propagator(
                        focal_stack_blur_radius * abs(depth_planes_depth[j] - depth_planes_depth[k]),
                        wavelength, pitch, depth_masks.device)
                    focal_stack[:, k, ...] += unblocked_weighting[:, k, ...] * \
                        (incoherent_propagator((target_amp[:, 0, ...] * depth_masks[:, j, ...]).unsqueeze(1))[:, 0, ...])
                    unblocked_weighting[:, k, ...] -= unblocked_weighting[:, k, ...] * \
                        (incoherent_propagator((depth_masks[:, j, ...]).unsqueeze(1))[:, 0, ...])

        # Propagate inpainted content where necessary
        for j in range(focal_stack.shape[1] - 1, -1, -1):
            for k in range(focal_stack.shape[1] - 1, -1, -1):
                if k == j:
                    focal_stack[:, k, ...] += unblocked_weighting[:, k, ...] * inpainted_images[:, j, ...] * \
                        (inpainted_images[:, j, ...] >= 0) * closest_inpainting[:, j, ...]
                    unblocked_weighting[:, k, ...] -= unblocked_weighting[:, k, ...] * closest_inpainting[:, j, ...] * \
                        (inpainted_images[:, j, ...] >= 0)
                else:
                    incoherent_propagator = create_diffraction_cone_propagator(
                        focal_stack_blur_radius * abs(depth_planes_depth[j] - depth_planes_depth[k]),
                        wavelength, pitch, depth_masks.device)
                    focal_stack[:, k, ...] += unblocked_weighting[:, k, ...] * \
                        (incoherent_propagator((inpainted_images[:, j, ...] *
                                                (inpainted_images[:, j, ...] >= 0)).unsqueeze(1))[:, 0, ...]) \
                        * closest_inpainting[:, j, ...]
                    unblocked_weighting[:, k, ...] -= unblocked_weighting[:, k, ...] * closest_inpainting[:, j, ...] * \
                        (incoherent_propagator(1.0 * (inpainted_images[:, j, ...] >= 0).unsqueeze(1))[:, 0, ...])

    return focal_stack


def create_diffraction_cone_propagator(distance, wavelength, pitch, device):
    """ Create a normalized disk-shaped blur layer (nn.Conv2d) for incoherent propagation over distance """
    with torch.no_grad():
        # radius (in pixels) of the diffraction cone: distance * tan(theta) / pitch,
        # where sin(theta) = wavelength / (2 * pitch) is the maximum diffraction angle of the SLM
        subhologram_halfsize = (distance / pitch) * math.sqrt(1 / ((2 * pitch / wavelength)**2 - 1))
        r = math.ceil(subhologram_halfsize)
        size = 2 * r + 5  # kernel size with a 2-pixel margin around the disk
        kernel = np.zeros((size, size))
        y, x = np.ogrid[-r - 2:r + 3, -r - 2:r + 3]
        mask = x**2 + y**2 <= subhologram_halfsize**2
        kernel[mask] = 1
        kernel = torch.Tensor(kernel).unsqueeze(0).unsqueeze(0).to(device)
        kernel = kernel / kernel.sum()  # normalize so the kernel sums to 1
        incoherent_propagator = nn.Conv2d(1, 1, kernel_size=size, stride=1, padding=r + 2,
                                          padding_mode='replicate', bias=False)
        incoherent_propagator.weight = nn.Parameter(kernel, requires_grad=False)

    return incoherent_propagator


def laplacian(img):
    """ discrete Laplacian of a phase map built from wrapped forward and backward differences """

    # signed angular difference
    grad_x1, grad_y1 = grad(img, next_pixel=True)  # x_{n+1} - x_{n}
    grad_x0, grad_y0 = grad(img, next_pixel=False)  # x_{n} - x_{n-1}

    laplacian_x = grad_x1 - grad_x0  # (x_{n+1} - x_{n}) - (x_{n} - x_{n-1})
    laplacian_y = grad_y1 - grad_y0

    return laplacian_x + laplacian_y


def grad(img, next_pixel=False, sovel=False):

    if img.shape[1] > 1:
        permuted = True
        img = img.permute(1, 0, 2, 3)
    else:
        permuted = False

    # set diff kernel
    if sovel:  # use Sobel filter for gradient calculation
        k_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32) / 8
        k_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32) / 8
    else:
        if next_pixel:  # x_{n+1} - x_n
            k_x = torch.tensor([[0, -1, 1]], dtype=torch.float32)
            k_y = torch.tensor([[1], [-1], [0]], dtype=torch.float32)
        else:  # x_{n} - x_{n-1}
            k_x = torch.tensor([[-1, 1, 0]], dtype=torch.float32)
            k_y = torch.tensor([[0], [1], [-1]], dtype=torch.float32)

    # move kernels to the image's device
    k_x = k_x.to(img.device).unsqueeze(0).unsqueeze(0)
    k_y = k_y.to(img.device).unsqueeze(0).unsqueeze(0)

    # boundary handling (replicate elements at boundary)
    if sovel:  # 3x3 kernels need padding in both directions to preserve the image size
        img_x = img_y = F.pad(img, (1, 1, 1, 1), 'replicate')
    else:
        img_x = F.pad(img, (1, 1, 0, 0), 'replicate')
        img_y = F.pad(img, (0, 0, 1, 1), 'replicate')

    # take signed angular difference
    grad_x = signed_ang(F.conv2d(img_x, k_x))
    grad_y = signed_ang(F.conv2d(img_y, k_y))

    if permuted:
        grad_x = grad_x.permute(1, 0, 2, 3)
        grad_y = grad_y.permute(1, 0, 2, 3)

    return grad_x, grad_y


def signed_ang(angle):
    """
    cast all angles into [-pi, pi)
    """
    return (angle + math.pi) % (2 * math.pi) - math.pi


# Adapted from https://github.com/svaiter/pyprox/blob/master/pyprox/operators.py
def soft_thresholding(x, gamma):
    """
    return element-wise shrinkage function with threshold gamma
    """
    return torch.maximum(torch.zeros_like(x),
                         1 - gamma / torch.maximum(torch.abs(x), 1e-10 * torch.ones_like(x))) * x


def random_gen(num_planes=7, slm_type='ti', **kwargs):
    """
    random hyperparameters for the dataset
    (frame counts and quantization methods are only varied for the 'ti' SLM)
    """
    frame_choices = [1, 1, 2, 2, 4, 4, 4, 8, 8, 8] if slm_type.lower() == 'ti' else [1]
    q_choices = ['None', 'nn', 'nn_sigmoid', 'gumbel_softmax'] if slm_type.lower() == 'ti' else ['None']

    num_frames = random.choice(frame_choices)
    quan_method = random.choice(q_choices)
    num_iters = random.choice(range(2000)) + 1
    phase_range = random.uniform(1.0, 6.28)
    target_range = random.uniform(0.5, 1.5)
    learning_rate = random.uniform(0.01, 0.03)
    plane_idx = random.choice(range(num_planes))
    # reg_lf_var = random.choice([0., 0., 1.0, 10.0, 100.0])
    reg_lf_var = -1

    # for profiling
    # num_frames = 1
    # quan_method = "None"
    # num_iters = 10
    # phase_range = 3
    # target_range = 1
    # learning_rate = 0.02
    # plane_idx = 4
    # reg_lf_var = -1

    return num_frames, num_iters, phase_range, target_range, learning_rate, plane_idx, quan_method, reg_lf_var


def write_opt(opt, out_path):
    """ save the options as JSON to <out_path>/opt.json """
    import json
    with open(os.path.join(out_path, 'opt.json'), "w") as opt_file:
        json.dump(dict(opt), opt_file, indent=4)


def init_phase(init_phase_type, target_amp, dev, opt):
    if init_phase_type == "random":
        # uniform phase in [-0.5, 0.5), scaled by opt.init_phase_range below
        init_phase = -0.5 + 1.0 * torch.rand(opt.num_frames, 1, *opt.slm_res)
    else:
        raise ValueError(f'Unsupported init_phase_type: {init_phase_type}')
    return opt.init_phase_range * init_phase.to(dev)


def create_backprop_instance(forward_prop):
    """ create a model that propagates backward (negated prop_dist) from a single-stage forward model """
    from params import clone_params
    # TODO: find a cleaner way to create a backprop instance

    # forward prop only covers front propagation;
    # a backward propagation is needed, so only a single propagation stage is supported
    assert forward_prop.opt.serial_two_prop_off

    # TODO: also update the prop_dist and wrp stuff

    backprop_opt = clone_params(forward_prop.opt)
    backprop_opt.prop_dist = -forward_prop.opt.prop_dist  # propagate back
    backward_prop = prop_model.model(backprop_opt)

    return backward_prop


def normalize_range(data, data_min, data_max, low, high):
    """ linearly map data from [data_min, data_max] to [low, high] """
    data = (data - data_min) / (data_max - data_min)  # 0 - 1
    data = (high - low) * data + low
    return data

--------------------------------------------------------------------------------
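The quantization helpers in `utils.py` (`lut_mid`, `nearest_neighbor_search`, `nearest_idx`) snap each value to its nearest LUT level by bucketing it against the midpoints between adjacent levels, then take the index modulo the number of thresholds. Below is a minimal pure-Python sketch of the same idea, assuming a sorted LUT. It swaps `torch.bucketize(..., right=True)` for `bisect.bisect_right`, which uses the same boundary convention (a value exactly on a threshold goes to the upper bucket). The names `lut_midpoints`, `quantize_to_lut`, and `PHASE_LUT` are hypothetical, and reading the `% len_lut` wrap as phase wrapping (last LUT entry equal to the first plus 2π) is an assumption, not something the source states.

```python
import bisect
import math


def lut_midpoints(lut):
    """Decision thresholds halfway between consecutive LUT entries (same idea as lut_mid)."""
    return [(a + b) / 2 for a, b in zip(lut[:-1], lut[1:])]


def quantize_to_lut(values, lut, wrap=True):
    """Map each value to its nearest LUT level; returns a list of (level, index) pairs.

    bisect_right(mids, v) returns the index i with mids[i-1] <= v < mids[i], the same
    convention as torch.bucketize(..., right=True). With wrap=True the index is taken
    modulo len(mids), mirroring nearest_idx, so values past the last threshold map to 0.
    """
    mids = lut_midpoints(lut)
    out = []
    for v in values:
        idx = bisect.bisect_right(mids, v)
        if wrap:
            idx %= len(mids)
        out.append((lut[idx], idx))
    return out


# Hypothetical 4-level phase LUT whose last entry (2*pi) duplicates the first (0)
# modulo 2*pi; this layout is an assumption used only for illustration.
PHASE_LUT = [0.0, math.pi / 2, math.pi, 3 * math.pi / 2, 2 * math.pi]

if __name__ == "__main__":
    for level, idx in quantize_to_lut([0.1, 1.7, 3.0, 6.2], PHASE_LUT):
        print(f"index {idx}: level {level:.4f}")
```

With `wrap=True`, a phase of 6.2 rad (past the last threshold at 7π/4) is assigned index 0, i.e. phase 0, which is equivalent to 2π; with `wrap=False` it keeps index 4.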