├── LICENSE
├── README.md
├── algorithms.py
├── configs_2d.txt
├── configs_3d.txt
├── configs_4d.txt
├── configs_model_training.txt
├── env.yml
├── holo2lf.py
├── hw
│   ├── __init__.py
│   ├── calibration_module.py
│   ├── camera_capture_module.py
│   ├── detect_heds_module_path.py
│   ├── discrete_slm.py
│   ├── phase_encodings.py
│   ├── slm_display_module.py
│   ├── ti.py
│   └── ti_encodings.py
├── image_loader.py
├── img
│   └── teaser.png
├── main.py
├── params.py
├── props
│   ├── __init__.py
│   ├── prop_ideal.py
│   ├── prop_model.py
│   ├── prop_physical.py
│   ├── prop_submodules.py
│   └── prop_zernike.py
├── quantization.py
├── train.py
├── unet.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Stanford Computational Imaging Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Time-multiplexed Neural Holography: A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators (SIGGRAPH 2022)
2 | ### [Project Page](http://www.computationalimaging.org/publications/time-multiplexed-neural-holography/) | [Video](https://youtu.be/k2dg-Ckhk5Q) | [Paper](https://drive.google.com/file/d/1n8xSdHgW0D5G5HhwSKrqCy1iztAcDHgX/view?usp=sharing)
3 | PyTorch implementation of
4 | [Time-multiplexed Neural Holography: A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators](http://www.computationalimaging.org/publications/time-multiplexed-neural-holography/)
5 | [Suyeon Choi](http://stanford.edu/~suyeon/)\*,
6 | [Manu Gopakumar](https://www.linkedin.com/in/manu-gopakumar-25032412b/)\*,
7 | [Yifan Peng](http://web.stanford.edu/~evanpeng/),
8 | [Jonghyun Kim](http://j-kim.kr/),
9 | [Matthew O'Toole](https://www.cs.cmu.edu/~motoole2/),
10 | [Gordon Wetzstein](https://computationalimaging.org)
11 | \*denotes equal contribution
12 | in SIGGRAPH 2022
13 |
14 |
15 |
16 | ## Get started
17 | Our code uses [PyTorch Lightning](https://www.pytorchlightning.ai/) and PyTorch >=1.10.0.
18 |
19 | You can set up a conda environment with all dependencies like so:
20 | ```
21 | conda env create -f env.yml
22 | conda activate tmnh
23 | ```
24 |
25 | ## High-Level structure
26 | The code is organized as follows:
27 |
28 |
29 | `./`
30 | * ```main.py``` generates phase patterns from LF/RGBD/RGB data using SGD.
31 | * ```holo2lf.py``` contains the Light-field ↔ Hologram conversion implementations.
32 | * ```algorithms.py``` contains the gradient-descent based algorithm for LF/RGBD/RGB supervision (see the usage sketch at the end of this section).
33 |
34 | * ```params.py``` contains our default parameter settings. :heavy_exclamation_mark:**(Replace values here with those in your setup.)**:heavy_exclamation_mark:
35 |
36 | * ```quantization.py``` contains modules for quantization (projected gradient, sigmoid, Gumbel-Softmax).
37 | * ```image_loader.py``` contains data loader modules.
38 | * ```utils.py``` has some other utilities.
39 |
40 |
41 |
42 |
43 | `./props/` contains the wave propagation operators (simulated and physical).
44 |
45 | `./hw/` contains modules for hardware control and homography calibration.
46 | * ```ti.py``` contains the phase LUT data from the TI SLM manual.
47 | * ```ti_encodings.py``` contains phase encoding and decoding functionalities for the TI SLM.
48 |
49 |
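A minimal sketch of how these pieces fit together (illustrative only, not a script in this repository; the identity `forward_prop` stands in for a real propagator from `./props/`, and the resolutions are placeholders):

```
import math
import torch
from algorithms import load_alg

sgd = load_alg('sgd')  # returns the gradient_descent routine from algorithms.py

init_phase = (torch.rand(1, 1, 512, 512) - 0.5) * 2 * math.pi  # random initial SLM phase
target_amp = torch.rand(1, 1, 512, 512)                        # dummy 2D target amplitude

ret = sgd(init_phase, target_amp,
          forward_prop=lambda u: u,          # stand-in for a propagator from ./props/
          num_iters=10, roi_res=(448, 448),
          optimize_amp=False)
print(ret['best_loss'], ret['final_phase'].shape)
```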
50 | ## Run
51 | To run, download the sample images from [here](https://drive.google.com/file/d/1aooTbzsmGw-Rfel7ntb1HJY1kILLSuEk/view?usp=sharing) and place the contents in the `data/` folder.
52 |
53 | ### Dataset generation / Model training
54 | Please see the [supplement](https://drive.google.com/file/d/1n9hdLq1xvur4I_OkGNyFgoKHGDZcMxcE/view) and [Neural 3D Holography repo](https://github.com/computational-imaging/neural-3d-holography) for more details on dataset generation and model training.
55 | ```
56 | # Train TMNH models
57 | for c in 0 1 2
58 | do
59 | python train.py -c=configs_model_training.txt --channel=$c --data_path=${dataset_path}
60 | done
61 |
62 | ```
63 |
64 |
65 | ### Run SGD with various target distributions (RGB images, focal stacks, and light fields)
66 | ```
67 | for c in 0 1 2
68 | do
69 | # 2D rgb images
70 | python main.py -c=configs_2d.txt --channel=$c
71 | # 3D focal stacks
72 | python main.py -c=configs_3d.txt --channel=$c
73 | # 4D light fields
74 | python main.py -c=configs_4d.txt --channel=$c
75 | done
76 | ```
77 |
78 | ### Run SGD with advanced quantizations
79 | ```
80 | q=gumbel-softmax  # try none, nn, nn_sigmoid as well.
81 | python main.py -c=configs_2d.txt --channel=0 --quan_method=$q  # pick a channel in {0, 1, 2}
82 |
83 | ```
84 |
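For intuition, the projected-gradient option boils down to snapping the continuous phase to a discrete LUT in the forward pass while letting gradients flow straight through. Below is a minimal sketch of that idea (written for this README, not the API of `quantization.py`); the ideal 4-bit LUT mirrors the commented-out test in `hw/discrete_slm.py`.

```
import math
import torch

def project_to_lut(phase, lut):
    """Straight-through projection of a continuous phase map onto a discrete LUT."""
    idx = (phase.unsqueeze(-1) - lut).abs().argmin(dim=-1)   # nearest LUT entry per pixel
    quantized = lut[idx]
    # forward pass uses the quantized phase, backward pass sees the identity
    return phase + (quantized - phase).detach()

lut = torch.linspace(-math.pi, math.pi, 2**4 + 1)            # ideal 4-bit LUT
phase = (torch.rand(1, 1, 8, 8) - 0.5) * 2 * math.pi
print(project_to_lut(phase, lut).unique().numel())           # at most 17 distinct levels
```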
85 | ## Citation
86 | If you find our work useful in your research, please cite:
87 | ```
88 | @inproceedings{choi2022time,
89 | author = {Choi, Suyeon
90 | and Gopakumar, Manu
91 | and Peng, Yifan
92 | and Kim, Jonghyun
93 | and O'Toole, Matthew
94 | and Wetzstein, Gordon},
95 | title={Time-multiplexed neural holography: a flexible framework for holographic near-eye displays with fast heavily-quantized spatial light modulators},
96 | booktitle={ACM SIGGRAPH 2022 Conference Proceedings},
97 | pages={1--9},
98 | year={2022}
99 | }
100 | ```
101 |
102 | ## Acknowledgments
103 | Thanks to [Brian Chao](https://bchao1.github.io/) for the help with code updates and [Cindy Nguyen](https://ccnguyen.github.io) for helpful discussions. This project was in part supported by a Kwanjeong Scholarship, a Stanford SGF, Intel, NSF (award 1839974), a PECASE by the ARO (W911NF-19-1-0120), and Sony.
104 |
105 | ## Contact
106 | If you have any questions, please feel free to email the authors.
--------------------------------------------------------------------------------
/algorithms.py:
--------------------------------------------------------------------------------
1 | """
2 | Various algorithms for LF/RGBD/RGB supervision.
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 |
11 | Technical Paper:
12 | Time-multiplexed Neural Holography:
13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
15 | SIGGRAPH 2022
16 | """
17 |
18 | import imageio
19 | from PIL import Image, ImageDraw
20 | import torch
21 | import torch.nn as nn
22 | import torch.optim as optim
23 | import torchvision.transforms.functional as TF
24 | import numpy as np
25 | from tqdm import tqdm
26 | from copy import deepcopy
27 |
28 | import utils
29 | from holo2lf import holo2lf
30 |
31 | def load_alg(alg_type, mem_eff=False):
32 | if 'sgd' in alg_type.lower():
33 | if mem_eff:
34 | algorithm = efficient_gradient_descent
35 | else:
36 | algorithm = gradient_descent
37 | else:
38 | raise ValueError(f"Algorithm {alg_type} is not supported!")
39 |
40 | return algorithm
41 |
42 | def gradient_descent(init_phase, target_amp, target_mask=None, target_idx=None, forward_prop=None, num_iters=1000, roi_res=None,
43 | border_margin=None, loss_fn=nn.MSELoss(), lr=0.01, out_path_idx='./results',
44 | citl=False, camera_prop=None, writer=None, quantization=None,
45 | time_joint=True, flipud=False, reg_lf_var=0.0, *args, **kwargs):
46 | """
47 | Gradient-descent based method for phase optimization.
48 |
49 | :param init_phase:
50 | :param target_amp:
51 | :param target_mask:
52 | :param forward_prop:
53 | :param num_iters:
54 | :param roi_res:
55 | :param loss_fn:
56 | :param lr:
57 | :param out_path_idx:
58 | :param citl:
59 | :param camera_prop:
60 | :param writer:
61 | :param quantization:
62 | :param time_joint:
63 | :param flipud:
64 | :param args:
65 | :param kwargs:
66 | :return:
67 | """
68 | print("Naive gradient descent")
69 | assert forward_prop is not None
70 | dev = init_phase.device
71 |
72 |
73 | h, w = init_phase.shape[-2], init_phase.shape[-1] # total energy = h*w
74 |
75 | init_amp = torch.ones_like(init_phase) * 0.5
76 | init_amp_logits = torch.log(init_amp / (1 - init_amp)) # convert to inverse sigmoid
77 |
78 | slm_phase = init_phase.requires_grad_(True) # phase at the slm plane
79 | slm_amp_logits = init_amp_logits.requires_grad_(True) # amplitude at the slm plane
80 |
81 | optvars = [{'params': slm_phase}]
82 | if kwargs["optimize_amp"]:
83 | optvars.append({'params': slm_amp_logits})
84 |
85 | #if "opt_s" in reg_loss_fn_type:
86 | # s = torch.tensor(1.0).requires_grad_(True) # initial s value
87 | # optvars.append({'params': s})
88 | #else:
89 | # s = None
90 | s = torch.tensor(1.0)
91 | optimizer = optim.Adam(optvars, lr=lr)
92 |
93 | loss_vals = []
94 | psnr_vals = []
95 | loss_vals_quantized = []
96 | best_loss = 1e10
97 | best_iter = 0
98 | best_amp = None
99 | lf_supervision = len(target_amp.shape) > 4
100 |
101 | print("target amp shape", target_amp.shape)
102 |
103 | if target_mask is not None:
104 | target_amp = target_amp * target_mask
105 | nonzeros = target_mask > 0
106 | if roi_res is not None:
107 | target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False, lf=lf_supervision)
108 | if target_mask is not None:
109 | target_mask = utils.crop_image(target_mask, roi_res, stacked_complex=False, lf=lf_supervision)
110 | nonzeros = target_mask > 0
111 |
112 | if border_margin is not None:
113 | # make borders of target black
114 | mask = torch.zeros_like(target_amp)
115 | mask[:, :, border_margin:-border_margin, border_margin:-border_margin] = 1
116 | target_amp = target_amp * mask
117 |
118 | for t in tqdm(range(num_iters)):
119 | optimizer.zero_grad()
120 | if quantization is not None:
121 | quantized_phase = quantization(slm_phase, t/num_iters)
122 | else:
123 | quantized_phase = slm_phase
124 |
125 | if flipud:
126 | quantized_phase_f = quantized_phase.flip(dims=[2])
127 | else:
128 | quantized_phase_f = quantized_phase
129 |
130 | field_input = torch.exp(1j * quantized_phase_f)
131 |
132 | recon_field = forward_prop(field_input)
133 | recon_field = utils.crop_image(recon_field, roi_res, pytorch=True, stacked_complex=False) # here, also record an uncropped image
134 |
135 | if lf_supervision:
136 | recon_amp_t = holo2lf(recon_field, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'],
137 | win_length=kwargs['win_len'], device=dev, impl='torch').sqrt()
138 | else:
139 | recon_amp_t = recon_field.abs()
140 |
141 | if time_joint: # time-multiplexed forward model
142 | recon_amp = (recon_amp_t**2).mean(dim=0, keepdims=True).sqrt()
143 | else:
144 | recon_amp = recon_amp_t
145 |
146 | if citl: # surrogate gradients for CITL
147 | captured_amp = camera_prop(slm_phase, 1)
148 | captured_amp = utils.crop_image(captured_amp, roi_res,
149 | stacked_complex=False)
150 | recon_amp_sim = recon_amp.clone() # simulated reconstructed image
151 | recon_amp = recon_amp + captured_amp - recon_amp.detach() # reconstructed image with surrogate gradients
152 |
153 | # clip to range
154 | if target_mask is not None:
155 | final_amp = torch.zeros_like(recon_amp)
156 | final_amp[nonzeros] += (recon_amp[nonzeros] * target_mask[nonzeros])
157 | else:
158 | final_amp = recon_amp
159 |
160 | # also track gradient of s
161 | with torch.no_grad():
162 | s = (final_amp * target_amp).mean(dim=(-1, -2), keepdims=True) / (final_amp ** 2).mean(dim=(-1, -2), keepdims=True) # scale minimizing MSE btw recon and target
163 |
164 | loss_val = loss_fn(s * final_amp, target_amp)
165 |
166 | mse_loss = ((s * final_amp - target_amp)**2).mean().item()
167 | psnr_val = 20 * np.log10(1 / np.sqrt(mse_loss))
168 |
169 | # loss term for having even emission at in-focus points (STFT-based regularization described in Supplementary)
170 | if reg_lf_var > 0.0:
171 | recon_amp_lf = holo2lf(recon_field, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'],
172 | win_length=kwargs['win_len'], device=dev, impl='torch')
173 | recon_amp_lf = s * recon_amp_lf.mean(dim=0, keepdims=True).sqrt()
174 | loss_lf_var = torch.mean(torch.var(recon_amp_lf, (-2, -1)))
175 | loss_val += reg_lf_var * loss_lf_var
176 |
177 | loss_val.backward()
178 | optimizer.step()
179 |
180 | with torch.no_grad():
181 | if loss_val.item() < best_loss:
182 | best_phase = slm_phase
183 | best_loss = loss_val.item()
184 | best_amp = s * final_amp # fits target image.
185 | best_iter = t + 1
186 |
187 | psnr = 20 * torch.log10(1 / torch.sqrt(((s * final_amp - target_amp)**2).mean()))
188 | psnr_vals.append(psnr.item())
189 |
190 | return {'loss_vals': loss_vals,
191 | 'psnr_vals': psnr_vals,
192 | 'loss_vals_q': loss_vals_quantized,
193 | 'best_iter': best_iter,
194 | 'best_loss': best_loss,
195 | 'recon_amp': best_amp,
196 | 'target_amp': target_amp,
197 | 'final_phase': best_phase
198 | }
199 |
200 |
201 | def efficient_gradient_descent(init_phase, target_amp, target_mask=None, target_idx=None, forward_prop=None, num_iters=1000, roi_res=None,
202 | loss_fn=nn.MSELoss(), lr=0.01, out_path_idx='./results',
203 | citl=False, camera_prop=None, writer=None, quantization=None,
204 | time_joint=True, flipud=False, *args, **kwargs):
205 | """
206 | Gradient-descent based method for phase optimization.
207 |
208 | :param init_phase:
209 | :param target_amp:
210 | :param target_mask:
211 | :param forward_prop:
212 | :param num_iters:
213 | :param roi_res:
214 | :param loss_fn:
215 | :param lr:
216 | :param out_path_idx:
217 | :param citl:
218 | :param camera_prop:
219 | :param writer:
220 | :param quantization:
221 | :param time_joint:
222 | :param flipud:
223 | :param args:
224 | :param kwargs:
225 | :return:
226 | """
227 | print("Memory efficient gradient descent")
228 |
229 | assert forward_prop is not None
230 | dev = init_phase.device
231 | num_frames = init_phase.shape[0]
232 |
233 | slm_phase = init_phase.requires_grad_(True) # phase at the slm plane
234 | optvars = [{'params': slm_phase}]
235 | optimizer = optim.Adam(optvars, lr=lr)
236 |
237 | loss_vals = []
238 | loss_vals_quantized = []
239 | best_loss = 10.
240 | lf_supervision = len(target_amp.shape) > 4
241 |
242 | if target_mask is not None:
243 | target_amp = target_amp * target_mask
244 | nonzeros = target_mask > 0
245 | if roi_res is not None:
246 | target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False, lf=lf_supervision)
247 | if target_mask is not None:
248 | target_mask = utils.crop_image(target_mask, roi_res, stacked_complex=False, lf=lf_supervision)
249 | nonzeros = target_mask > 0
250 |
251 | for t in tqdm(range(num_iters)):
252 | optimizer.zero_grad() # zero grad
253 |
254 | # amplitude reconstruction without graph
255 | with torch.no_grad():
256 | if quantization is not None:
257 | quantized_phase = quantization(slm_phase, t/num_iters)
258 | else:
259 | quantized_phase = slm_phase
260 |
261 | if flipud:
262 | quantized_phase_f = quantized_phase.flip(dims=[2])
263 | else:
264 | quantized_phase_f = quantized_phase
265 |
266 | recon_field = forward_prop(quantized_phase_f) # just sample one depth plane
267 | recon_field = utils.crop_image(recon_field, roi_res, stacked_complex=False)
268 |
269 | if lf_supervision:
270 | recon_amp_t = holo2lf(recon_field, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'],
271 | win_length=kwargs['win_len'], device=dev, impl='torch').sqrt()
272 | else:
273 | recon_amp_t = recon_field.abs()
274 |
275 | if citl: # surrogate gradients for CITL
276 | captured_amp = camera_prop(slm_phase)
277 | captured_amp = utils.crop_image(captured_amp, roi_res,
278 | stacked_complex=False)
279 |
280 | total_loss_val = 0
281 | # insert single frame's graph and accumulate gradient
282 | for f in range(num_frames):
283 | slm_phase_sf = slm_phase[f:f+1, ...]
284 | if quantization is not None:
285 | quantized_phase_sf = quantization(slm_phase_sf, t/num_iters)
286 | else:
287 | quantized_phase_sf = slm_phase_sf
288 |
289 | if flipud:
290 | quantized_phase_f_sf = quantized_phase_sf.flip(dims=[2])
291 | else:
292 | quantized_phase_f_sf = quantized_phase_sf
293 |
294 | recon_field_sf = forward_prop(quantized_phase_f_sf)
295 | recon_field_sf = utils.crop_image(recon_field_sf, roi_res, stacked_complex=False)
296 |
297 | if lf_supervision:
298 | recon_amp_t_sf = holo2lf(recon_field_sf, n_fft=kwargs['n_fft'], hop_length=kwargs['hop_len'],
299 | win_length=kwargs['win_len'], device=dev, impl='torch').sqrt()
300 | else:
301 | recon_amp_t_sf = recon_field_sf.abs()
302 |
303 | ### insert graph from single frame ###
304 | recon_amp_t_with_grad = recon_amp_t.clone().detach()
305 | recon_amp_t_with_grad[f:f+1,...] = recon_amp_t_sf
306 |
307 | if time_joint: # time-multiplexed forward model
308 | recon_amp = (recon_amp_t_with_grad**2).mean(dim=0, keepdims=True).sqrt()
309 | else:
310 | recon_amp = recon_amp_t_with_grad
311 |
312 | if citl:
313 | recon_amp = recon_amp + captured_amp / (num_frames) - recon_amp.detach()
314 |
315 | if target_mask is not None:
316 | final_amp = torch.zeros_like(recon_amp)
317 | final_amp[nonzeros] += recon_amp[nonzeros] * target_mask[nonzeros]
318 | else:
319 | final_amp = recon_amp
320 |
321 |
322 | with torch.no_grad():
323 | s = (final_amp * target_amp).mean() / \
324 |                 (final_amp ** 2).mean() # scale minimizing MSE btw recon and target
325 |
326 |
327 |
328 | loss_val = loss_fn(s * final_amp, target_amp)
329 | loss_val.backward(retain_graph=False)
330 |
331 | total_loss_val += loss_val.item()
332 |
333 | if t % 10 == 0:
334 | pass
335 | #writer.add_scalar("loss", total_loss_val, t)
336 | #writer.add_scalar("recon loss", recon_loss.item(), t)
337 | #writer.add_scalar("light eff loss", reg_loss.item(), t)
338 | #writer.add_scalar("s", s.item(), t)
339 | #writer.add_image("recon", torch.clamp(s*final_amp[0], 0, 1), t)
340 |
341 | # update phase variables
342 | optimizer.step()
343 |
344 | with torch.no_grad():
345 | if total_loss_val < best_loss:
346 | best_phase = slm_phase
347 | best_loss = total_loss_val
348 | best_amp = s * recon_amp
349 | best_iter = t + 1
350 | print(total_loss_val)
351 |
352 | return {'loss_vals': loss_vals,
353 | 'loss_vals_q': loss_vals_quantized,
354 | 'best_iter': best_iter,
355 | 'best_loss': best_loss,
356 | 'recon_amp': best_amp,
357 | 'target_amp': target_amp,
358 | 'final_phase': best_phase,
359 | 's': s.item()}
--------------------------------------------------------------------------------
/configs_2d.txt:
--------------------------------------------------------------------------------
1 | data_path=data/2d
2 | out_path=results
3 | target=2d
4 | loss_func=l2
5 | uniform_nbits=4
6 | eval_plane_idx=3
--------------------------------------------------------------------------------
/configs_3d.txt:
--------------------------------------------------------------------------------
1 | data_path=data/3d_bamboo
2 | out_path=results
3 | target=3d
4 | loss_func=l2
5 | uniform_nbits=4
6 | eyepiece=0.035
7 |
--------------------------------------------------------------------------------
/configs_4d.txt:
--------------------------------------------------------------------------------
1 | data_path=data/4d_olas
2 | out_path=results
3 | target=4d
4 | loss_func=l2
5 | uniform_nbits=4
6 | eval_plane_idx=3
7 | eyepiece=0.035
8 |
--------------------------------------------------------------------------------
/configs_model_training.txt:
--------------------------------------------------------------------------------
1 | lr=3e-4
2 | batch_size=4
3 | prop_model=nh4d
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: tmnh
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.10
6 | - setuptools
7 | - pip
8 | - wheel
9 | - anaconda-client
10 | - anaconda-project
11 | - anaconda-navigator
12 | - conda
13 | - conda-build
14 | - conda-content-trust
15 | - conda-pack
16 | - conda-package-handling
17 | - conda-package-streaming
18 | - conda-token
19 | - conda-verify
20 | - setuptools
21 | - pip:
22 | - aotools
23 | - kornia
24 | - lightning-utilities
25 | - opencv-python==4.7.0.72
26 | - pytorch-lightning==2.0.4
27 | - serial
28 | - tensorboard==2.13.0
29 | - tensorboard-data-server==0.7.1
30 | - torch
31 | - torchaudio
32 | - torchmetrics
33 | - torchvision
34 | - h5py
35 | - tensorboard
36 | - configargparse
37 | - imageio
38 | - scikit-image
39 | - tqdm
40 | prefix: /home/suyeon/anaconda3
41 |
--------------------------------------------------------------------------------
/holo2lf.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementations of the Light-field ↔ Hologram conversion. Note that lf2holo method is basically the OLAS method.
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 |
11 | Technical Paper:
12 | Time-multiplexed Neural Holography:
13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
15 | SIGGRAPH 2022
16 | """
17 | import math
18 | import numpy as np
19 | import torch
20 | import torch.nn.functional as F
21 |
22 |
23 | def holo2lf(input_field, n_fft=(9, 9), hop_length=(1, 1), win_func=None,
24 | win_length=None, device=torch.device('cuda'), impl='torch', predefined_h=None,
25 | return_h=False, h_size=(1, 1)):
26 | """
27 | Hologram to Light field transformation.
28 |
29 | :param input_field: input field shape of (N, 1, H, W), if 1D, set H=1.
30 | :param n_fft: a tuple of numbers of fourier basis.
31 | :param hop_length: a tuple of hop lengths to sample at the end.
32 | :param win_func: window function applied to each segment, default hann window.
33 | :param win_length: a tuple of lengths of window function. if win_length is smaller than n_fft, pad zeros to the windows.
34 | :param device: torch cuda.
35 | :param impl: implementation ('conv', 'torch', 'olas')
36 | :return: A 4D representation of light field, shape of (N, 1, H, W, U, V)
37 | """
38 | input_length = input_field.shape[-2:]
39 | batch_size, _, Ny, Nx = input_field.shape
40 |
41 | # for 1D input (n_fft = 1), don't take fourier transform toward that direction.
42 | n_fft_y = min(n_fft[0], input_length[0])
43 | n_fft_x = min(n_fft[1], input_length[1])
44 |
45 | if win_length is None:
46 | win_length = n_fft
47 |
48 | win_length_y = min(win_length[0], input_length[0])
49 | win_length_x = min(win_length[1], input_length[1])
50 |
51 | if win_func is None:
52 | w_func = lambda length: torch.hann_window(length + 1, device=device)[1:]
53 | # w_func = lambda length: torch.ones(length)
54 | win_func = torch.ger(w_func(win_length_y), w_func(win_length_x))
55 |
56 | win_func = win_func.to(input_field.device)
57 | win_func /= win_func.sum()
58 |
59 | if impl == 'torch':
60 | # 1) use STFT implementation of PyTorch
61 | if len(input_field.squeeze().shape) > 1: # with 2D input
62 | # input_field = input_field.view(-1, input_field.shape[-1]) # merge batch & y dimension
63 | input_field = input_field.reshape(np.prod(input_field.size()[:-1]), input_field.shape[-1]) # merge batch & y dimension
64 |
65 | # take 1D stft along x dimension
66 | stft_x = torch.stft(input_field, n_fft=n_fft_x, hop_length=hop_length[1], win_length=win_length_x,
67 | onesided=False, window=win_func[win_length_y//2, :], pad_mode='constant',
68 | normalized=False, return_complex=True)
69 |
70 | if n_fft_y > 1: # 4D light field output
71 | stft_x = stft_x.reshape(batch_size, Ny, n_fft_x, Nx//hop_length[1]).permute(0, 3, 2, 1)
72 | stft_x = stft_x.contiguous().view(-1, Ny)
73 |
74 | # take one more 1D stft along y dimension
75 | stft_xy = torch.stft(stft_x, n_fft=n_fft_y, hop_length=hop_length[0], win_length=win_length_y,
76 | onesided=False, window=win_func[:, win_length_x//2], pad_mode='constant',
77 | normalized=False, return_complex=True)
78 |
79 | # reshape tensor to (N, 1, Y, X, fy, fx)
80 | stft_xy = stft_xy.reshape(batch_size, Nx//hop_length[1], n_fft[1], n_fft[0], Ny//hop_length[0])
81 | stft_xy = stft_xy.unsqueeze(1).permute(0, 1, 5, 2, 4, 3)
82 | freq_space_rep = torch.fft.fftshift(stft_xy, (-2, -1))
83 |
84 | else: # 3D light field output
85 | stft_xy = stft_x.reshape(batch_size, Ny, n_fft_x, Nx//hop_length[1]).permute(0, 1, 3, 2)
86 | stft_xy = stft_xy.unsqueeze(1).unsqueeze(4)
87 | freq_space_rep = torch.fft.fftshift(stft_xy, -1)
88 |
89 | else: # with 1D input -- to be deprecated
90 | freq_space_rep = torch.stft(input_field.squeeze(),
91 | n_fft=n_fft, hop_length=hop_length, onesided=False, window=win_func,
92 | win_length=win_length, normalized=False, return_complex=True)
93 | elif impl == 'olas':
94 | # 2) Our own implementation:
95 | # slide 1d representation to left and right (to amount of win_length/2) and stack in another dimension
96 | overlap_field = torch.zeros(*input_field.shape[:2],
97 | (win_func.shape[0] - 1) + input_length[0],
98 | (win_func.shape[1] - 1) + input_length[1],
99 | win_func.shape[0], win_func.shape[1],
100 | dtype=input_field.dtype).to(input_field.device)
101 |
102 | # slide the input field
103 | for i in range(win_length_y):
104 | for j in range(win_length_x):
105 | overlap_field[..., i:i+input_length[0], j:j+input_length[1], i, j] = input_field
106 |
107 | # toward the new dimensions, apply the window function and take fourier transform.
108 | win_func = win_func.reshape(1, 1, 1, 1, *win_func.shape)
109 | win_func = win_func.repeat(*input_field.shape[:2], *overlap_field.shape[2:4], 1, 1)
110 | overlap_field *= win_func # apply window
111 |
112 | # take Fourier transform (it will pad zeros when n_fft > win_length)
113 | # apply no normalization since window is already normalized
114 | if n_fft_y > 1:
115 | overlap_field = torch.fft.fftshift(torch.fft.ifft(overlap_field, n=n_fft_y, norm='forward', dim=-2), -2)
116 | freq_space_rep = torch.fft.fftshift(torch.fft.ifft(overlap_field, n=n_fft_x, norm='forward', dim=-1), -1)
117 |
118 | # take every hop_length columns, and when hop_length == win_length it should be HS.
119 | freq_space_rep = freq_space_rep[:,:, win_length_y//2:win_length_y//2+input_length[0]:hop_length[0],
120 | win_length_x//2:win_length_x//2+input_length[1]:hop_length[1], ...]
121 |
122 | return freq_space_rep.abs()**2 # LF = |U|^2
123 |
124 |
125 | def lf2holo(light_field, light_field_depth, wavelength, pixel_pitch, win=None, target_phase=None):
126 | """
127 |     PyTorch implementation of OLAS (Padmanaban et al., 2019)
128 |
129 | :param light_field:
130 | :param light_field_depth:
131 | :param wavelength:
132 | :param pixel_pitch:
133 | :param win:
134 | :param target_phase:
135 | :return:
136 | """
137 |
138 | # hogel size is same as angular resolution
139 | res_hogel = light_field.shape[-2:]
140 |
141 | # resolution of hologram is same spatial resolution of light field
142 | res_hologram = light_field.shape[2:4]
143 |
144 | # initialize hologram with zeros, padded to avoid edges/for centering
145 | radius_hogel = torch.tensor(res_hogel) // 2
146 | apas_ola = torch.zeros(*(torch.tensor(res_hologram) + radius_hogel * 2),
147 | dtype=torch.complex64, device=light_field.device)
148 |
149 | #######################################################################
150 | # compute synthesis window
151 | # custom version of hann without zeros at start
152 | if win is None:
153 | w_func = lambda length: torch.hann_window(length + 1, device=light_field.device)[1:]
154 | # w_func = lambda length: torch.ones(length)
155 | win = torch.ger(w_func(res_hogel[0]), w_func(res_hogel[1]))
156 | win /= win.sum()
157 |
158 | #######################################################################
159 |
160 | # compute complex field
161 | comp_depth = torch.zeros(light_field_depth.shape, device=light_field.device)
162 |
163 | # apply depth compensation
164 | fx = torch.linspace(-1 + 1 / res_hogel[1], 1 - 1 / res_hogel[1],
165 | res_hogel[1], device=light_field.device) / (2 * pixel_pitch[1])
166 | fy = torch.linspace(-1 + 1 / res_hogel[0], 1 - 1 / res_hogel[0],
167 | res_hogel[0], device=light_field.device) / (2 * pixel_pitch[0])
168 |
169 | y = torch.linspace(-pixel_pitch[0] * res_hologram[0] / 2,
170 | pixel_pitch[0] * res_hologram[0] / 2,
171 | res_hologram[0], device=light_field.device)
172 | x = torch.linspace(-pixel_pitch[1] * res_hologram[1] / 2,
173 | pixel_pitch[1] * res_hologram[1] / 2,
174 | res_hologram[1], device=light_field.device)
175 | y, x = torch.meshgrid(y, x)
176 |
177 | for ky in range(res_hogel[0]):
178 | for kx in range(res_hogel[1]):
179 | theta = torch.asin(torch.sqrt(fx[kx] ** 2 + fy[ky] ** 2) * wavelength)
180 | comp_depth[..., ky, kx] = (light_field_depth[..., ky, kx] * (1 - torch.cos(theta)))
181 |
182 | # comp_depth[..., ky, kx] = (fx[kx] * x + fy[ky] * y) * wavelength
183 | print(comp_depth.max(), comp_depth.min())
184 |
185 | comp_amp = torch.sqrt(light_field)
186 | comp_phase = 2 * math.pi / wavelength * comp_depth
187 |
188 | if target_phase is not None:
189 | x_pos = torch.zeros_like(comp_depth)
190 | y_pos = torch.zeros_like(comp_depth)
191 | for ky in range(res_hogel[0]):
192 | y_pos[..., ky, :] = (light_field_depth[..., ky, :] * fy[ky] * wavelength
193 | + y.unsqueeze(-1).unsqueeze(0).unsqueeze(0)) * 2/(pixel_pitch[0] * target_phase.shape[-2])
194 | for kx in range(res_hogel[1]):
195 | x_pos[..., kx] = (light_field_depth[..., kx] * fx[kx] * wavelength
196 | + x.unsqueeze(-1).unsqueeze(0).unsqueeze(0)) * 2/(pixel_pitch[1] * target_phase.shape[-1])
197 | for ky in range(res_hogel[0]):
198 | for kx in range(res_hogel[1]):
199 | sample_grid = torch.stack((x_pos[:, 0, :, :, ky, kx], y_pos[:, 0, :, :, ky, kx]), -1)
200 | comp_phase[..., ky, kx] += F.grid_sample(target_phase, sample_grid,
201 | padding_mode='reflection')
202 |
203 | complex_lf = comp_amp * torch.exp(1j * comp_phase)
204 |
205 | # fft over the hogel dimension
206 | complex_lf = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(complex_lf, dim=(-2, -1)),
207 | dim=(-2, -1), norm='forward'), dim=(-2, -1))
208 |
209 | # apply window, extra dims are for spatial dims, color, and complex dim
210 | complex_lf = complex_lf * win[None, None, None, None, ...]
211 |
212 | # overlap and add the hogels
213 | for ky in range(res_hogel[0]):
214 | for kx in range(res_hogel[1]):
215 | apas_ola[...,
216 | ky:ky + res_hologram[0],
217 | kx:kx + res_hologram[1]] += complex_lf[..., ky, kx].squeeze()
218 |
219 | # crop back to light field size
220 | return apas_ola[..., radius_hogel[0]:-radius_hogel[0], radius_hogel[1]:-radius_hogel[1]].unsqueeze(0).unsqueeze(0)
221 |
222 |
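# --- Illustrative usage sketch (not part of the original module) -------------
# Convert a flat (zero-depth) light field of shape (N, 1, H, W, U, V) into a
# hologram with OLAS. The wavelength and pixel pitch are placeholder values.
if __name__ == '__main__':
    lf = torch.rand(1, 1, 32, 32, 5, 5)
    depth = torch.zeros_like(lf)
    holo = lf2holo(lf, depth, wavelength=520e-9, pixel_pitch=(6.4e-6, 6.4e-6))
    print(holo.shape)  # hologram field at the light field's spatial resolution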
--------------------------------------------------------------------------------
/hw/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/time-multiplexed-neural-holography/5cf6c275c459652abb3ddddd2e167f9584072aeb/hw/__init__.py
--------------------------------------------------------------------------------
/hw/calibration_module.py:
--------------------------------------------------------------------------------
1 | """
2 | This script contains the calibration module, which mainly computes the homography matrix.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 | Technical Paper:
10 | Y. Peng, S. Choi, N. Padmanaban, G. Wetzstein. Neural Holography with Camera-in-the-loop Training. ACM TOG (SIGGRAPH Asia), 2020.
11 | """
12 |
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | import torchvision
16 | import cv2
17 | import skimage.transform as transform
18 | import time
19 | import datetime
20 | from scipy.io import savemat, loadmat
21 | from scipy.ndimage import map_coordinates
22 | import torch
23 | import torch.nn.functional as F
24 | import torch.nn as nn
25 |
26 | def id(x):
27 | return x
28 |
29 | def circle_detect(captured_img, num_circles, spacing, pad_pixels=(0., 0.), show_preview=True, quadratic=False):
30 | """
31 | Detects the circle of a circle board pattern
32 |
33 | :param captured_img: captured image
34 | :param num_circles: a tuple of integers, (num_circle_x, num_circle_y)
35 | :param spacing: a tuple of integers, in pixels, (space between circles in x, space btw circs in y direction)
36 | :param show_preview: boolean, default True
37 | :param pad_pixels: coordinate of the left top corner of warped image.
38 | Assuming pad this amount of pixels on the other side.
39 | :return: a tuple, (found_dots, H)
40 | found_dots: boolean, indicating success of calibration
41 | H: a 3x3 homography matrix (numpy)
42 | """
43 |
44 | # Binarization
45 | # org_copy = org.copy() # Otherwise, we write on the original image!
46 | img = (np.clip(captured_img.copy(), 0, 1) * 255).astype(np.uint8)
47 | print(img[...,0].mean())
48 | print(img[...,1].mean())
49 | print(img[...,2].mean())
50 | if len(img.shape) > 2:
51 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
52 | cv2.imwrite("temp/img_gray.png", img)
53 | print(img[...,0].mean())
54 | print(img[...,1].mean())
55 | print(img[...,2].mean())
56 |
57 |
58 | img = cv2.medianBlur(img, 5) # Red 71
59 | # cv2.imwrite("temp/img_blur.png", img)
60 | img_gray = img.copy()
61 |
62 | img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 201, 0)
63 | cv2.imwrite("temp/img_adapt_thres.png", img)
64 |
65 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
66 | img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
67 | cv2.imwrite("temp/img_open.png", img)
68 | img = 255 - img
69 |
70 | # Blob detection
71 | params = cv2.SimpleBlobDetector_Params()
72 |
73 | # Change thresholds
74 | params.filterByColor = True
75 | params.minThreshold = 121
76 |
77 | # Filter by Area.
78 | params.filterByArea = True
79 | params.minArea = 150
80 |
81 | # Filter by Circularity
82 | params.filterByCircularity = True
83 | params.minCircularity = 0.5 # change here, easier to detect blob
84 |
85 | # Filter by Convexity
86 | params.filterByConvexity = True
87 | params.minConvexity = 0.3
88 |
89 | # Filter by Inertia
90 | params.filterByInertia = False
91 | params.minInertiaRatio = 0.01
92 |
93 | detector = cv2.SimpleBlobDetector_create(params)
94 |
95 | # Detecting keypoints
96 | # this is redundant for what comes next, but gives us access to the detected dots for debug
97 | keypoints = detector.detect(img)
98 | found_dots, centers = cv2.findCirclesGrid(img, (num_circles[1], num_circles[0]),
99 | blobDetector=detector, flags=cv2.CALIB_CB_SYMMETRIC_GRID)
100 |
101 | # Drawing the keypoints
102 | cv2.drawChessboardCorners(captured_img, num_circles, centers, found_dots)
103 | img_gray = cv2.drawKeypoints(img_gray, keypoints, np.array([]), (0, 255, 0),
104 | cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
105 |
106 | # Find transformation
107 | H = np.array([[1., 0., 0.],
108 | [0., 1., 0.],
109 | [0., 0., 1.]], dtype=np.float32)
110 | ref_pts = np.zeros((num_circles[0] * num_circles[1], 1, 2), np.float32)
111 | pos = 0
112 | for j in range(0, num_circles[0]):
113 | for i in range(0, num_circles[1]):
114 | ref_pts[pos, 0, :] = spacing * np.array([i, j]) + np.array([pad_pixels[1], pad_pixels[0]])
115 |
116 | pos += 1
117 | ref_pts = ref_pts.reshape(num_circles[0] * num_circles[1], 2)
118 | if found_dots:
119 | # Generate reference points to compute the homography
120 | print("Found dots")
121 | H, mask = cv2.findHomography(centers, ref_pts, cv2.RANSAC, 1)
122 |
123 | centers = np.flip(centers.reshape(num_circles[0] * num_circles[1], 2), 1)
124 | homography_cache = {'H':H, 'centers':centers}
125 | savemat(f'./cache_h.mat', homography_cache)
126 | else:
127 | print("No dots")
128 | homography_cache = loadmat(f'./cache_h.mat')
129 | H = homography_cache['H']
130 | centers = homography_cache['centers']
131 |
132 |
133 | now = datetime.datetime.now()
134 | mdic = {"centers": centers, 'H': H}
135 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs)
136 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels) ]
137 | if quadratic:
138 | H = transform.estimate_transform('polynomial', ref_pts, centers)
139 | coords = transform.warp_coords(H, dsize, dtype=np.float32) # for pytorch
140 | else:
141 | tf = transform.estimate_transform('projective', ref_pts, centers)
142 | coords = transform.warp_coords(tf, (800, 1280), dtype=np.float32) # for pytorch
143 |
144 | if show_preview:
145 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs)
146 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels)]
147 | if quadratic:
148 | captured_img_warp = transform.warp(captured_img, H, output_shape=(dsize[0], dsize[1]))
149 | else:
150 | captured_img_warp = cv2.warpPerspective(captured_img, H, (dsize[1], dsize[0]))
151 |
152 |
153 | if show_preview:
154 | fig = plt.figure()
155 |
156 | ax = fig.add_subplot(223) # grayscale
157 | ax.imshow(img_gray, cmap='gray')
158 |
159 | ax2 = fig.add_subplot(221) # binarized image
160 | ax2.imshow(img, cmap='gray')
161 |
162 | ax3 = fig.add_subplot(222) # captured image
163 | ax3.imshow(captured_img, cmap='gray')
164 |
165 | if found_dots:
166 | ax4 = fig.add_subplot(224)
167 | ax4.imshow(captured_img_warp, cmap='gray')
168 |
169 | plt.show()
170 |
171 | return found_dots, H, coords
172 |
173 |
174 | class Warper(nn.Module):
175 | def __init__(self, params_calib):
176 | super(Warper, self).__init__()
177 | self.num_circles = params_calib.num_circles
178 | self.spacing_size = params_calib.spacing_size
179 | self.pad_pixels = params_calib.pad_pixels
180 | self.quadratic = params_calib.quadratic
181 | self.img_size_native = params_calib.img_size_native # get this from image
182 | self.h_transform = np.array([[1., 0., 0.],
183 | [0., 1., 0.],
184 | [0., 0., 1.]])
185 | self.range_x = params_calib.range_x # slice
186 | self.range_y = params_calib.range_y # slice
187 |
188 |
189 | def calibrate(self, img, show_preview=True):
190 | img_masked = np.zeros_like(img)
191 | img_masked[self.range_y, self.range_x, ...] = img[self.range_y, self.range_x, ...]
192 |
193 | found_corners, self.h_transform, self.coords = circle_detect(img_masked, self.num_circles,
194 | self.spacing_size, self.pad_pixels, show_preview,
195 | quadratic=self.quadratic)
196 |
197 | if not self.coords is None:
198 | self.coords_tensor = torch.tensor(np.transpose(self.coords, (1, 2, 0)),
199 | dtype=torch.float32).unsqueeze(0)
200 |
201 | # normalize it into [-1, 1]
202 | self.coords_tensor[..., 0] = 2*self.coords_tensor[..., 0] / (self.img_size_native[1]-1) - 1
203 | self.coords_tensor[..., 1] = 2*self.coords_tensor[..., 1] / (self.img_size_native[0]-1) - 1
204 |
205 | return found_corners
206 |
207 | def __call__(self, input_img, img_size=None):
208 | """
209 | This forward pass returns the warped image.
210 |
211 | :param input_img: A numpy grayscale image shape of [H, W].
212 | :param img_size: output size, default None.
213 | :return: output_img: warped image with pre-calculated homography and destination size.
214 | """
215 |
216 | if img_size is None:
217 | img_size = [int((num_circs - 1) * space + 2 * pad_pixs)
218 | for num_circs, space, pad_pixs in zip(self.num_circles, self.spacing_size, self.pad_pixels)]
219 |
220 | if torch.is_tensor(input_img):
221 | output_img = F.grid_sample(input_img, self.coords_tensor, align_corners=True)
222 | else:
223 | if self.quadratic:
224 | output_img = transform.warp(input_img, self.h_transform, output_shape=(img_size[0], img_size[1]))
225 | else:
226 | output_img = cv2.warpPerspective(input_img, self.h_transform, (img_size[0], img_size[1]))
227 |
228 | return output_img
229 |
230 | @property
231 | def h_transform(self):
232 | return self._h_transform
233 |
234 | @h_transform.setter
235 | def h_transform(self, new_h):
236 | self._h_transform = new_h
237 |
238 | def to(self, *args, **kwargs):
239 | slf = super().to(*args, **kwargs)
240 | if slf.coords_tensor is not None:
241 | slf.coords_tensor = slf.coords_tensor.to(*args, **kwargs)
242 | try:
243 | slf.dev = next(slf.parameters()).device
244 | except StopIteration: # no parameters
245 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0]
246 | if device_arg is not None:
247 | slf.dev = device_arg
248 | return slf
--------------------------------------------------------------------------------
/hw/camera_capture_module.py:
--------------------------------------------------------------------------------
1 | """
2 | This script contains the camera capture module (camera control and image capture via PyCapture2).
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 | Technical Paper:
10 | Time-multiplexed Neural Holography:
11 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
12 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein.
13 | SIGGRAPH 2022
14 | """
15 |
16 | import PyCapture2
17 | import cv2
18 | import numpy as np
19 | import utils
20 |
21 |
22 | def callback_captured(image):
23 | print(image.getData())
24 |
25 |
26 | class CameraCapture:
27 | def __init__(self, params):
28 | self.bus = PyCapture2.BusManager()
29 | num_cams = self.bus.getNumOfCameras()
30 | if not num_cams:
31 | exit()
32 | # self.demosaick_rule = cv2.COLOR_BAYER_RG2BGR
33 | #self.demosaick_rule = cv2.COLOR_BAYER_GR2RGB # GBRG to RGB
34 | self.demosaick_rule = cv2.COLOR_BAYER_BG2RGB # RGGB to RGB, Grasshopper3, U3, projector
35 | self.params = params
36 |
37 | def connect(self, i, trigger=False):
38 | uid = self.bus.getCameraFromIndex(i)
39 | self.camera_device = PyCapture2.Camera()
40 | self.camera_device.connect(uid)
41 | self.camera_device.setConfiguration(highPerformanceRetrieveBuffer=True)
42 | self.camera_device.setConfiguration(numBuffers=1000)
43 | config = self.camera_device.getConfiguration()
44 | self.toggle_embedded_timestamp(True)
45 |
46 | if trigger:
47 | trigger_mode = self.camera_device.getTriggerMode()
48 | trigger_mode.onOff = True
49 | trigger_mode.mode = 0
50 | trigger_mode.parameter = 0
51 | trigger_mode.source = 3 # Using software trigger
52 | self.camera_device.setTriggerMode(trigger_mode)
53 | else:
54 | trigger_mode = self.camera_device.getTriggerMode()
55 | trigger_mode.onOff = False
56 | trigger_mode.mode = 0
57 | trigger_mode.parameter = 0
58 | trigger_mode.source = 3 # Using software trigger
59 | self.camera_device.setTriggerMode(trigger_mode)
60 |
61 | trigger_mode = self.camera_device.getTriggerMode()
62 | if trigger_mode.onOff is True:
63 | print(' - setting trigger mode on')
64 |
65 | def set_shutter_speed(self, val):
66 | self.camera_device.setProperty(type = PyCapture2.PROPERTY_TYPE.SHUTTER, autoManualMode = False, absValue = val)
67 | shutter_speed = self.camera_device.getProperty(PyCapture2.PROPERTY_TYPE.SHUTTER ).absValue
68 | print(f"Shutter speed set to {shutter_speed}ms")
69 |
70 | def set_gain(self, val):
71 | self.camera_device.setProperty(type = PyCapture2.PROPERTY_TYPE.GAIN, autoManualMode = False, absValue = val)
72 | gain = self.camera_device.getProperty(PyCapture2.PROPERTY_TYPE.GAIN).absValue
73 | print(f"Gain set to {gain}dB")
74 |
75 | def disconnect(self):
76 | self.toggle_embedded_timestamp(False)
77 | self.camera_device.disconnect()
78 |
79 | def toggle_embedded_timestamp(self, enable_timestamp):
80 | embedded_info = self.camera_device.getEmbeddedImageInfo()
81 | if embedded_info.available.timestamp:
82 | self.camera_device.setEmbeddedImageInfo(timestamp=enable_timestamp)
83 |
84 | def grab_images(self, num_images_to_grab=1):
85 | """
86 | Retrieve the camera buffer and returns a list of grabbed images.
87 |
88 | :param num_images_to_grab: integer, default 1
89 | :return: a list of numpy 2d color images from the camera buffer.
90 | """
91 | self.camera_device.startCapture()
92 | img_list = []
93 | for i in range(num_images_to_grab):
94 | imgData = self.retrieve_buffer()
95 | offset = 64 # offset that inherently exist.retrieve_buffer
96 | imgData = imgData - offset
97 |
98 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule)
99 | color_cv_image = utils.im2float(color_cv_image)
100 | img_list.append(color_cv_image.copy())
101 |
102 | self.camera_device.stopCapture()
103 | return img_list
104 |
105 | def grab_images_fast(self, num_images_to_grab=1):
106 | """
107 | Retrieve the camera buffer and returns a grabbed image
108 |
109 | :param num_images_to_grab: integer, default 1
110 | :return: a list of numpy 2d color images from the camera buffer.
111 | """
112 | imgData = self.retrieve_buffer()
113 | offset = 64 # offset that inherently exist.
114 | imgData = imgData - offset
115 |
116 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule)
117 | color_cv_image = utils.im2float(color_cv_image)
118 | color_img = color_cv_image
119 | return color_img
120 |
121 | def retrieve_buffer(self):
122 | try:
123 | img = self.camera_device.retrieveBuffer()
124 | except PyCapture2.Fc2error as fc2Err:
125 | raise fc2Err
126 |
127 | imgData = img.getData()
128 |
129 | # when using raw8 from the PG sensor
130 | # cv_image = np.array(img.getData(), dtype="uint8").reshape((img.getRows(), img.getCols()))
131 |
132 | # when using raw16 from the PG sensor - concat 2 8bits in a row
133 | imgData.dtype = np.uint16
134 | imgData = imgData.reshape(img.getRows(), img.getCols())
135 | return imgData.copy()
136 |
137 | def start_capture(self):
138 | # these two were previously inside the grab_images func, and can be clarified outside the loop
139 | self.camera_device.startCapture()
140 |
141 | def stop_capture(self):
142 | self.camera_device.stopCapture()
143 |
144 | @property
145 | def params(self):
146 | return self._params
147 |
148 | @params.setter
149 | def params(self, p):
150 | self._params = p
--------------------------------------------------------------------------------
/hw/detect_heds_module_path.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | #--------------------------------------------------------------------#
4 | # #
5 | # Copyright (C) 2020 HOLOEYE Photonics AG. All rights reserved. #
6 | # Contact: https://holoeye.com/contact/ #
7 | # #
8 | # This file is part of HOLOEYE SLM Display SDK. #
9 | # #
10 | # You may use this file under the terms and conditions of the #
11 | # 'HOLOEYE SLM Display SDK Standard License v1.0' license agreement. #
12 | # #
13 | #--------------------------------------------------------------------#
14 |
15 |
16 | # Please import this file in your scripts before actually importing the HOLOEYE SLM Display SDK,
17 | # i. e. copy this file to your project and use this code in your scripts:
18 | #
19 | # import detect_heds_module_path
20 | # import holoeye
21 | #
22 | #
23 | # Another option is to copy the holoeye module directory into your project and import by only using
24 | # import holoeye
25 | # This way, code completion etc. might work better.
26 |
27 |
28 | import os, sys
29 | from platform import system
30 |
31 | # Import the SLM Display SDK:
32 | HEDSModulePath = os.getenv('HEDS_2_PYTHON_MODULES', '')
33 |
34 | if HEDSModulePath == '':
35 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__),
36 | 'holoeye', 'slmdisplaysdk', '__init__.py'))
37 | if os.path.isfile(sdklocal):
38 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal)))
39 | else:
40 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',
41 | 'sdk', 'holoeye', 'slmdisplaysdk', '__init__.py'))
42 | if os.path.isfile(sdklocal):
43 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal)))
44 |
45 | if HEDSModulePath == '':
46 | if system() == 'Windows':
47 | print('\033[91m'
48 | '\nError: Could not find HOLOEYE SLM Display SDK installation path from environment variable. '
49 | '\n\nPlease relogin your Windows user account and try again. '
50 | '\nIf that does not help, please reinstall the SDK and then relogin your user account and try again. '
51 | '\nA simple restart of the computer might fix the problem, too.'
52 | '\033[0m')
53 | else:
54 | print('\033[91m'
55 | '\nError: Could not detect HOLOEYE SLM Display SDK installation path. '
56 | '\n\nPlease make sure it is present within the same folder or in "../../sdk".'
57 | '\033[0m')
58 |
59 | sys.exit(1)
60 |
61 | sys.path.append(HEDSModulePath)
62 |
--------------------------------------------------------------------------------
/hw/discrete_slm.py:
--------------------------------------------------------------------------------
1 | """
2 | Information about the discrete SLM (lookup-table handling).
3 |
4 | Technical Paper:
5 | Time-multiplexed Neural Holography:
6 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
7 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein.
8 | SIGGRAPH 2022
9 | """
10 |
11 | import torch
12 | import hw.ti as ti
13 | import utils
14 |
15 |
16 | class DiscreteSLM:
17 | """
18 | Class for Discrete SLM that supports discrete LUT
19 | """
20 | _lut_midvals = None
21 | _lut = None
22 | prev_idx = 0.
23 |
24 | @property
25 | def lut_midvals(self):
26 | return self._lut_midvals
27 |
28 | @lut_midvals.setter
29 | def lut_midvals(self, new_midvals):
30 | self._lut_midvals = torch.tensor(new_midvals)#, device=torch.device('cuda'))
31 |
32 | @property
33 | def lut(self):
34 | return self._lut
35 |
36 | @lut.setter
37 | def lut(self, new_lut):
38 | if new_lut is None:
39 | self._lut = None
40 | else:
41 | self.lut_midvals = utils.lut_mid(new_lut)
42 | if torch.is_tensor(new_lut):
43 | self._lut = new_lut.clone().detach()
44 | else:
45 | self._lut = torch.tensor(new_lut)#, device=torch.device('cuda'))
46 |
47 |
48 | DiscreteSLM = DiscreteSLM() # class singleton
49 | DiscreteSLM.lut = ti.given_lut
50 |
51 | #num_bits = 4
52 | #DiscreteSLM.lut = np.linspace(-math.pi, math.pi, 2**num_bits + 1) # test for ideal lut
53 |
54 |
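# Illustrative check (not part of the original file): the singleton above is
# initialized with the 17-entry TI LUT from hw/ti.py, and lut_midvals holds the
# midpoints (via utils.lut_mid) used for nearest-level lookups.
if __name__ == '__main__':
    print(DiscreteSLM.lut.shape)    # expected: torch.Size([17])
    print(DiscreteSLM.lut_midvals)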
--------------------------------------------------------------------------------
/hw/phase_encodings.py:
--------------------------------------------------------------------------------
1 | """
2 | Phase encoding functions for SLMs (8-bit phase maps for HOLOEYE SLMs, RGB encoding for the TI SLM).
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | Technical Paper:
7 | Time-multiplexed Neural Holography:
8 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
9 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein.
10 | SIGGRAPH 2022
11 | """
12 |
13 | import numpy as np
14 | import hw.ti_encodings as ti_encodings
15 |
16 |
17 | def phasemap_8bit(phasemap, inverted=True):
18 | """convert a phasemap tensor into a numpy 8bit phasemap that can be directly displayed
19 |
20 | :param phasemap: input phasemap tensor, which is supposed to be in the range of [-pi, pi].
21 | :param inverted: a boolean value that indicates whether the phasemap is inverted.
22 |
23 | :return: output phasemap, with uint8 dtype (in [0, 255])
24 | """
25 |
26 | output_phase = ((phasemap + np.pi) % (2 * np.pi)) / (2 * np.pi)
27 | if inverted:
28 | phase_out_8bit = ((1 - output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits
29 | else:
30 | phase_out_8bit = ((output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits
31 | return phase_out_8bit
32 |
33 |
34 | def phase_encoding(phase, slm_type):
35 | assert len(phase.shape) == 4
36 | """ phase encoding for SLM """
37 | if slm_type.lower() in ('holoeye', 'leto', 'pluto'):
38 | return phasemap_8bit(phase)
39 | elif slm_type.lower() in ('ti', "ee236a"):
40 | return np.fliplr(ti_encodings.rgb_encoding(phase.cpu()))
41 | else:
42 | return None
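

# Illustrative example (not part of the original module): a zero phase map lies
# at the middle of [-pi, pi), so with inverted=True it maps to (1 - 0.5) * 255,
# which rounds to 128.
if __name__ == '__main__':
    import torch
    print(phasemap_8bit(torch.zeros(1, 1, 4, 4))[0, 0])  # -> 128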
--------------------------------------------------------------------------------
/hw/slm_display_module.py:
--------------------------------------------------------------------------------
1 | """
2 | This script contains the SLM display module, a thin wrapper around the HOLOEYE SLM Display SDK.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 | Technical Paper:
10 | Time-multiplexed Neural Holography:
11 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
12 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein.
13 | SIGGRAPH 2022
14 | """
15 |
16 | import hw.detect_heds_module_path
17 | import holoeye
18 | from holoeye import slmdisplaysdk
19 |
20 |
21 | class SLMDisplay:
22 | ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode
23 | ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags
24 | State = slmdisplaysdk.SLMDisplay.State
25 | ApplyDataHandleValue = slmdisplaysdk.SLMDisplay.ApplyDataHandleValue
26 |
27 | def __init__(self):
28 | self.ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode
29 | self.ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags
30 |
31 | self.displayOptions = self.ShowFlags.PresentAutomatic # PresentAutomatic == 0 (default)
32 | self.displayOptions |= self.ShowFlags.PresentFitWithBars
33 |
34 | def connect(self):
35 | self.slm_device = slmdisplaysdk.SLMDisplay()
36 | self.slm_device.open() # For version 2.0.1
37 |
38 | def disconnect(self):
39 | self.slm_device.release()
40 |
41 | def show_data_from_file(self, filepath):
42 | error = self.slm_device.showDataFromFile(filepath, self.displayOptions)
43 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error)
44 |
45 | def show_data_from_array(self, numpy_array):
46 | error = self.slm_device.showData(numpy_array)
47 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error)
48 |
--------------------------------------------------------------------------------
/hw/ti.py:
--------------------------------------------------------------------------------
1 | """
2 | Data from the TI SLM manual
3 |
4 | Technical Paper:
5 | Time-multiplexed Neural Holography:
6 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
7 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein.
8 | SIGGRAPH 2022
9 | """
10 |
11 | import math
12 | import torch
13 |
14 | given_chart = (0.,
15 | 1.07,
16 | 2.19,
17 | 4.50,
18 | 5.98,
19 | 7.75,
20 | 12.06,
21 | 18.5,
22 | 36.55,
23 | 39.55,
24 | 45.1,
25 | 52.44,
26 | 63.93,
27 | 71.16,
28 | 85.02,
29 | 100.)
30 | adjusted = [p / 100 * 15 / 16 * 2 * math.pi for p in given_chart]
31 | adjusted.append(adjusted[0] + 2*math.pi)
32 | given_lut = [p - math.pi for p in adjusted] # [-pi, pi]
33 |
34 | idx_order = [4, 2, 1, 0, 7, 6, 5, 3, 11, 10, 9, 8, 15, 14, 13, 12] # see manual
35 | idx2phase = torch.tensor([given_lut[idx_order[i]] for i in range(len(idx_order))])
36 |
--------------------------------------------------------------------------------
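A hedged sketch of how the lookup table above can be used to snap continuous phases to the 16 discrete TI SLM levels; the repository uses `utils.nearest_idx` for this, so the inline `nearest_idx_sketch` below is a stand-in and the tensor shapes are placeholders.

```
import math
import torch
from hw.ti import idx2phase  # the 16 discrete phase levels in [-pi, pi]

def nearest_idx_sketch(phase, levels):
    # index of the closest discrete level for every pixel
    return (phase.unsqueeze(-1) - levels.view(1, 1, 1, 1, -1)).abs().argmin(dim=-1)

phase = (torch.rand(1, 1, 4, 4) * 2 - 1) * math.pi   # continuous phase
idx = nearest_idx_sketch(phase, idx2phase)           # per-pixel level index (0..15)
quantized = idx2phase[idx]                           # phase snapped to the discrete levels
```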
/hw/ti_encodings.py:
--------------------------------------------------------------------------------
1 | """
2 | Encoding and decoding functions for our TI SLM.
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | Technical Paper:
7 | Time-multiplexed Neural Holography:
8 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
9 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, G. Wetzstein.
10 | SIGGRAPH 2022
11 | """
12 |
13 | import numpy as np
14 | import torch
15 | import utils
16 | import hw.ti as ti
17 | from hw.discrete_slm import DiscreteSLM
18 |
19 |
20 | def binary_encoding_ti_slm(phase):
21 | """ gets phase in [-pi, pi] and returns binary encoded phase of DMD """
22 | #print("binary", phase.shape)
23 | idx = utils.nearest_idx(phase, DiscreteSLM.lut_midvals)
24 | height = phase.shape[2] * 2
25 | width = phase.shape[3] * 2
26 |
27 | encoded_phase = torch.zeros(1, 1, height, width).to(phase.device)
28 | encoded_phase[:, :, ::2, 1::2] = torch.div(idx, 8, rounding_mode='floor') # M3, ur
29 | encoded_phase[:, :, 1::2, 1::2] = torch.where(
30 | torch.logical_or(idx == 3,
31 | torch.logical_and(idx != 4, idx % 8 >= 4)), 1, 0) # M2, dr
32 | encoded_phase[:, :, ::2, ::2] = torch.where(
33 | torch.logical_or(idx == 3,
34 | torch.logical_and(idx != 4, (idx % 4) < 2)), 1, 0) # M1, ul
35 | encoded_phase[:, :, 1::2, ::2] = torch.where(
36 | torch.logical_or(idx == 3,
37 | torch.logical_and(idx != 4, (idx % 2 == 0))), 1, 0) # M0, dl
38 |
39 | return encoded_phase
40 |
41 |
42 | def bit_encoding(phase, bits):
43 |     """ Gets phase of shape (N, 1, H, W) and returns its binary encoding scaled into the given bit positions. """
44 |
45 | power = sum(2**b for b in bits)
46 | return binary_encoding_ti_slm(phase) * power
47 |
48 |
49 | def rgb_encoding(phase, ch=None):
50 |     """ Gets a batch of phase tensors and returns the RGB-encoded phase (specific to our TI SLM). """
51 | phase = (phase + np.pi) % (2*np.pi) - np.pi
52 | num_phases = len(phase)
53 | #print("rgb", phase.shape)
54 | if num_phases % 3 == 0:
55 | num_bits_per_ch = num_phases // 3
56 |
57 | # placeholder with doubled resolution
58 | res = np.zeros((*(2*p for p in phase.shape[2:]), 3), dtype=np.uint8)
59 | for c in range(3):
60 | res[..., c] = rgb_encoding(phase[c*num_bits_per_ch:(c+1)*num_bits_per_ch, ...])
61 | return res
62 | else:
63 | phase = sum([bit_encoding(phase[j:j+1, ...], range(j*(8//num_phases), (j+1)*(8//num_phases)))
64 | for j in range(num_phases)])
65 |
66 | if ch is None:
67 | res = phase.squeeze().cpu().detach().numpy().astype(np.uint8)
68 | else:
69 | res = np.zeros((*phase.shape[2:], 3))
70 | res[..., ch] = phase.squeeze().cpu().detach().numpy().astype(np.uint8)
71 |
72 | return res
73 |
74 |
75 | def rgb_decoding(phase_img, num_frames=None, one_hot=False):
76 |     """ Recovers phase values in [-pi, pi] from an encoded phase image that was displayed.
77 | 
78 |     :param phase_img: numpy image with [M, N, 3] channels
79 |     :param num_frames: if not None, the known number of encoded frames (skips auto-detection and reduces computation)
80 |     :param one_hot: if True, return a one-hot decoded image (with 16 channels, one per discrete level)
81 |     :return: a tensor of decoded phase, either one-hot or exact values
82 | """
83 | phase_img_flipped = torch.tensor(phase_img, dtype=torch.float32).flip(dims=[1]) # flip LR here
84 | if len(phase_img_flipped.shape) < 3:
85 | phase_img_flipped = phase_img_flipped.unsqueeze(2)
86 |
87 | # figure out what's the number of frames
88 | if num_frames is None:
89 | num_frames = num_frames_ti_phase(phase_img_flipped)
90 | num_ch = 3 if num_frames % 3 == 0 else 1
91 | # num_bit_per_ch = 8 // (num_frames // num_ch)
92 | num_frames_per_ch = num_frames // num_ch
93 | num_bit_per_ch = 8 // num_frames_per_ch
94 | slm_phase_2x = torch.zeros(num_frames, *phase_img_flipped.shape[:-1])
95 |
96 |     # assign each uniquely encoded binary image to its own tensor slice (stacked in the batch dimension)
97 | for c in range(num_ch):
98 | for i in range(num_frames_per_ch):
99 | f = c * num_frames_per_ch + i
100 | slm_phase_2x[f, ...] = phase_img_flipped[..., c:c+1].squeeze().clone().detach() % 2
101 | phase_img_flipped[..., c:c+1].div_((2**num_bit_per_ch), rounding_mode='trunc')
102 |
103 | if one_hot:
104 | # return one-hot vector agnostic of the discrete phase values the SLM supports
105 | indices = decode_binary_phase(slm_phase_2x, return_index=True)
106 | output = torch.zeros((len(DiscreteSLM.lut_midvals), *indices.shape[-2:])).scatter_(0, indices, 1.0)
107 | else:
108 | # binary to 4bit, and apply LUT
109 | slm_phase = decode_binary_phase(slm_phase_2x)
110 | output = slm_phase.unsqueeze(1) # return a tensor shape of (N, 1, H, W)
111 |
112 | return output
113 |
114 |
115 | def num_frames_ti_phase(phase_img):
116 | """
117 |     Returns the number of frames encoded in this phase image tensor.
118 |
119 | :param phase_img: phase pattern input
120 | :return: An integer, number of frames
121 | """
122 | if len(phase_img.shape) < 3 or phase_img.shape[2] == 1:
123 | num_frames = 1
124 | one_bit_imgs = torch.zeros((8, *phase_img.shape), device=phase_img.device)
125 | r = phase_img.clone().detach()
126 | else:
127 | r = phase_img[..., 0].clone().detach()
128 | g = phase_img[..., 1]
129 | b = phase_img[..., 2]
130 |
131 | img_size = r.shape
132 | one_bit_imgs = torch.zeros((8, *img_size))
133 |
134 | if ((r-g)**2).mean() < 1e-3 and ((g-b)**2).mean() < 1e-3:
135 | # monochromatic
136 | num_frames = 1
137 | else:
138 | num_frames = 3
139 |
140 |     # count how many of the 8 bit planes equal the first; frames per channel = 8 // count
141 | cnt = 0
142 | for i in range(8):
143 | one_bit_imgs[i, ...] = r % 2
144 | r //= 2 # shift 1 bit
145 | if ((one_bit_imgs[i, ...] - one_bit_imgs[0, ...])**2).mean() < 1e-3:
146 | cnt += 1
147 | return num_frames * (8 // cnt)
148 |
149 |
150 | def decode_binary_phase(binary_img, return_index=False):
151 | """
152 |     Decodes a doubled-resolution binary image back to discrete phase values (or level indices).
153 |     :param binary_img: binary tensor of shape (N, 2H, 2W); set return_index=True to get level indices (0-15) instead
154 |     :return: a tensor of shape (N, H, W) with decoded phases (or indices)
155 | """
156 | top_left = binary_img[..., ::2, ::2] # M1
157 | top_right = binary_img[..., ::2, 1::2] # M3
158 | bottom_left = binary_img[..., 1::2, ::2] # M0
159 | bottom_right = binary_img[..., 1::2, 1::2] # M2
160 |
161 | indices = 8 * top_right + 4 * bottom_right + 2 * top_left + bottom_left
162 | img_shape = indices.shape
163 | indices = indices.type(torch.int32)
164 | indices = indices.reshape(indices.numel())
165 |
166 | if return_index:
167 |         # return level index (0~15) per pixel
168 | memory_cell_lut = torch.tensor(ti.idx_order).to(binary_img.device)
169 | output = torch.index_select(memory_cell_lut, 0, indices).reshape(*img_shape)
170 | else:
171 | # return phase values
172 | decoded_phase = torch.index_select(ti.idx2phase.to(binary_img.device), 0, indices)
173 | output = decoded_phase.reshape(*img_shape)
174 |
175 | return output
176 |
177 |
178 | def merge_binary_phases(phases):
179 | """
180 |     Decodes a list of encoded phase images and re-encodes them as a single RGB phase image (padded to 24 frames).
181 |     :param phases: input encoded phase images
182 |     :return: the merged RGB-encoded phase image
183 | """
184 | rgb_phases = []
185 | for phase in phases:
186 | decoded_phase = rgb_decoding(phase)
187 | print(decoded_phase)
188 | rgb_phases.append(decoded_phase)
189 | rgb_phases = torch.cat(rgb_phases, 0)
190 | num_phases = rgb_phases.shape[0]
191 | if num_phases < 24:
192 | rgb_phases = torch.cat((rgb_phases, rgb_phases[:24-num_phases, ...]), 0)
193 | encoded_phase = rgb_encoding(torch.tensor(rgb_phases, dtype=torch.float32))
194 |
195 | return encoded_phase
196 |
--------------------------------------------------------------------------------
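A hedged sketch of the 2x2 memory-cell arithmetic that `decode_binary_phase` above performs: the four binary sub-pixels (M0-M3) of each 2x2 block are recombined into a 4-bit level index and mapped through the TI lookup table. The binary image below is a random placeholder.

```
import torch
from hw.ti import idx2phase

# placeholder doubled-resolution binary image (values 0/1), shape (N, 2H, 2W)
binary_img = torch.randint(0, 2, (1, 8, 8)).float()

# the four memory cells of each 2x2 block, matching decode_binary_phase
m1 = binary_img[..., ::2, ::2]    # top-left  (M1)
m3 = binary_img[..., ::2, 1::2]   # top-right (M3)
m0 = binary_img[..., 1::2, ::2]   # bottom-left  (M0)
m2 = binary_img[..., 1::2, 1::2]  # bottom-right (M2)

idx = (8 * m3 + 4 * m2 + 2 * m1 + m0).long()  # 4-bit level index per SLM pixel
phase = idx2phase[idx]                        # discrete phase values in [-pi, pi]
```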
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/time-multiplexed-neural-holography/5cf6c275c459652abb3ddddd2e167f9584072aeb/img/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 | Technical Paper:
10 | Time-multiplexed Neural Holography:
11 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
12 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
13 | SIGGRAPH 2022
14 | -----
15 |
16 | $ python main.py --lr=0.01 --num_iters=10000 --num_frames=8 --quan_method=gumbel-softmax
17 |
18 | """
19 | import os
20 | import json
21 | import torch
22 | import imageio
23 | import configargparse
24 | from torch.utils.tensorboard import SummaryWriter
25 | from collections import defaultdict
26 |
27 | import utils
28 | import params
29 | import algorithms as algs
30 | import quantization as q
31 | import numpy as np
32 | import image_loader as loaders
33 | from torch.utils.data import DataLoader
34 | import props.prop_model as prop_model
35 | import props.prop_physical as prop_physical
36 | from hw.phase_encodings import phase_encoding
37 | from torchvision.utils import save_image
38 |
39 | from pprint import pprint
40 |
41 | #import wx
42 | #wx.DisableAsserts()
43 |
44 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
45 |
46 |
47 | def main():
48 | # Command line argument processing / Parameters
49 | torch.set_default_dtype(torch.float32)
50 | p = configargparse.ArgumentParser()
51 | p.add('-c', '--config_filepath', required=False,
52 | is_config_file=True, help='Path to config file.')
53 | params.add_parameters(p, 'eval')
54 | opt = params.set_configs(p.parse_args())
55 | params.add_lf_params(opt)
56 | dev = torch.device('cuda')
57 |
58 | run_id = params.run_id(opt)
59 | # path to save out optimized phases
60 | out_path = os.path.join(opt.out_path, run_id)
61 | print(f' - out_path: {out_path}')
62 |
63 | # Tensorboard
64 | summaries_dir = os.path.join(out_path, 'summaries')
65 | utils.cond_mkdir(summaries_dir)
66 | writer = SummaryWriter(summaries_dir)
67 |
68 | # Write opt to experiment folder
69 | utils.write_opt(vars(p.parse_args()), out_path)
70 |
71 | # Propagations
72 | camera_prop = None
73 | if opt.citl:
74 | camera_prop = prop_physical.PhysicalProp(*(params.hw_params(opt)), shutter_speed=opt.shutter_speed).to(dev)
75 | camera_prop.calibrate_total_laser_energy() # important!
76 | sim_prop = prop_model.model(opt)
77 | sim_prop.eval()
78 |
79 | # Look-up table of SLM
80 | if opt.use_lut:
81 | lut = q.load_lut(sim_prop, opt)
82 | else:
83 | lut = None
84 | quantization = q.quantization(opt, lut)
85 |
86 | # Algorithm
87 | algorithm = algs.load_alg(opt.method, mem_eff=opt.mem_eff)
88 |
89 | # Loader
90 | if ',' in opt.data_path:
91 | opt.data_path = opt.data_path.split(',')
92 | img_loader = loaders.TargetLoader(shuffle=opt.random_gen,
93 | vertical_flips=opt.random_gen,
94 | horizontal_flips=opt.random_gen,
95 | scale_vd_range=False, **opt)
96 |
97 | for i, target in enumerate(img_loader):
98 | target_amp, target_mask, target_idx = target
99 | target_amp = target_amp.to(dev).detach()
100 |
101 | if target_mask is not None:
102 | target_mask = target_mask.to(dev).detach()
103 | if len(target_amp.shape) < 4:
104 | target_amp = target_amp.unsqueeze(0)
105 |
106 | print(f' - run phase optimization for {target_idx}th image ...')
107 |
108 | if opt.random_gen: # random parameters for dataset generation
109 | img_files = os.listdir(out_path)
110 | img_files = [f for f in img_files if f.endswith('.png')]
111 |             if len(img_files) > opt.num_data:  # we have generated enough data
112 |                 break
113 |             print(f"Num images: {len(img_files)} (max: {opt.num_data})")
114 | opt.num_frames, opt.num_iters, opt.init_phase_range, \
115 | target_range, opt.lr, opt.eval_plane_idx, \
116 | opt.quan_method, opt.reg_lf_var = utils.random_gen(**opt)
117 | sim_prop = prop_model.model(opt)
118 | quantization = q.quantization(opt, lut)
119 | target_amp *= target_range
120 | if opt.reg_lf_var > 0.0 and isinstance(sim_prop, prop_model.CNNpropCNN):
121 | opt.num_frames = min(opt.num_frames, 4)
122 |
123 | out_path_idx = f'{opt.out_path}_{target_idx}'
124 |
125 | # initial slm phase
126 | init_phase = utils.init_phase(opt.init_phase_type, target_amp, dev, opt)
127 |
128 | # run algorithm
129 | results = algorithm(init_phase, target_amp, target_mask, target_idx,
130 | forward_prop=sim_prop, camera_prop=camera_prop,
131 | writer=writer, quantization=quantization,
132 | out_path_idx=out_path_idx, **opt)
133 |
134 | # optimized slm phase
135 | final_phase = results['final_phase']
136 | recon_amp = results['recon_amp']
137 | target_amp = results['target_amp']
138 |
139 | # encoding for SLM & save it out
140 | if opt.random_gen:
141 | # decompose it into several 1-bit phases
142 | for k, final_phase_1bit in enumerate(final_phase):
143 | phase_out = phase_encoding(final_phase_1bit.unsqueeze(0), opt.slm_type)
144 | phase_out_path = os.path.join(out_path, f'{target_idx}_{opt.num_iters}{k}.png')
145 | imageio.imwrite(phase_out_path, phase_out)
146 | else:
147 | phase_out = phase_encoding(final_phase, opt.slm_type)
148 | recon_amp, target_amp = recon_amp.squeeze().detach().cpu().numpy(), target_amp.squeeze().detach().cpu().numpy()
149 |
150 | # save final phase and intermediate phases
151 | if phase_out is not None:
152 | phase_out_path = os.path.join(out_path, f'{target_idx}_phase.png')
153 | imageio.imwrite(phase_out_path, phase_out)
154 |
155 | if opt.save_images:
156 | recon_out_path = os.path.join(out_path, f'{target_idx}_recon.png')
157 | target_out_path = os.path.join(out_path, f'{target_idx}_target.png')
158 |
159 | if opt.channel is None:
160 | recon_amp = recon_amp.transpose(1, 2, 0)
161 | target_amp = target_amp.transpose(1, 2, 0)
162 |
163 | recon_out = utils.srgb_lin2gamma(np.clip(recon_amp**2, 0, 1)) # linearize and gamma
164 | target_out = utils.srgb_lin2gamma(np.clip(target_amp**2, 0, 1)) # linearize and gamma
165 |
166 | imageio.imwrite(recon_out_path, (recon_out * 255).astype(np.uint8))
167 | imageio.imwrite(target_out_path, (target_out * 255).astype(np.uint8))
168 |
169 | if camera_prop is not None:
170 | camera_prop.disconnect()
171 |
172 | if __name__ == "__main__":
173 | main()
174 |
--------------------------------------------------------------------------------
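A hedged sketch of the per-frame save path that `main.py` takes when `--random_gen` is set: each optimized frame is encoded separately with `phase_encoding` and written as its own PNG. The output directory, frame count, and resolution are placeholders.

```
import os
import math
import imageio
import torch
from hw.phase_encodings import phase_encoding

out_path = './results/example'                 # placeholder output directory
os.makedirs(out_path, exist_ok=True)

# placeholder stack of 8 optimized frames in [-pi, pi]
final_phase = (torch.rand(8, 1, 800, 1280) * 2 - 1) * math.pi
for k, frame in enumerate(final_phase):
    encoded = phase_encoding(frame.unsqueeze(0), 'holoeye')  # uint8 phase image
    imageio.imwrite(os.path.join(out_path, f'frame_{k}.png'), encoded)
```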
/params.py:
--------------------------------------------------------------------------------
1 | """
2 | Default parameter settings for SLMs as well as laser/sensors
3 |
4 | """
5 | import sys
6 | import utils
7 | import datetime
8 | import torch.nn as nn
9 | from hw.discrete_slm import DiscreteSLM
10 | if sys.platform == 'win32':
11 | import serial
12 |
13 | cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9
14 |
15 |
16 | def str2bool(v):
17 | """ Simple query parser for configArgParse (which doesn't support native bool from cmd)
18 | Ref: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
19 |
20 | """
21 | if isinstance(v, bool):
22 | return v
23 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
24 | return True
25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
26 | return False
27 | else:
28 | raise ValueError('Boolean value expected.')
29 |
30 |
31 | class PMap(dict):
32 | # use it for parameters
33 | __getattr__ = dict.get
34 | __setattr__ = dict.__setitem__
35 | __delattr__ = dict.__delitem__
36 |
37 |
38 | def clone_params(opt):
39 | """
40 | opt: PMap object
41 | """
42 | cloned = PMap()
43 | for k in opt.keys():
44 | cloned[k] = opt[k]
45 | return cloned
46 |
47 | def add_parameters(p, mode='train'):
48 | p.add_argument('--channel', type=int, default=None, help='Red:0, green:1, blue:2')
49 | p.add_argument('--method', type=str, default='SGD', help='Type of algorithm, GS/SGD/DPAC/HOLONET/UNET')
50 | p.add_argument('--slm_type', type=str, default='holoeye', help='holoeye(leto) or ti')
51 | p.add_argument('--sensor_type', type=str, default='4k', help='4k or 2k')
52 | p.add_argument('--laser_type', type=str, default='new', help='laser, new_laser, sLED, ...')
53 | p.add_argument('--setup_type', type=str, default='siggraph2022', help='siggraph2022, ...')
54 | p.add_argument('--prop_model', type=str, default='ASM', help='Type of propagation model, ASM or model')
55 | p.add_argument('--out_path', type=str, default='./results',
56 | help='Directory for output')
57 | p.add_argument('--citl', type=str2bool, default=False,
58 | help='If True, run camera-in-the-loop')
59 | p.add_argument('--mod_i', type=int, default=None,
60 | help='If not None, say K, pick every K target images from the target loader')
61 | p.add_argument('--mod', type=int, default=None,
62 | help='If not None, say K, pick every K target images from the target loader')
63 | p.add_argument('--data_path', type=str, default='data/2d',
64 | help='Directory for input')
65 | p.add_argument('--exp', type=str, default='', help='Name of experiment')
66 | p.add_argument('--lr', type=float, default=0.02, help='Learning rate')
67 | p.add_argument('--num_iters', type=int, default=5000, help='Number of iterations (GS, SGD)')
68 | p.add_argument('--prop_dist', type=float, default=None, help='propagation distance from SLM to midplane')
69 | p.add_argument('--num_frames', type=int, default=1, help='Number of frames to average') # effect time joint
70 | p.add_argument('--F_aperture', type=float, default=1.0, help='Fourier filter size') # how this effects
71 | p.add_argument('--eyepiece', type=float, default=0.12, help='eyepiece focal length')
72 | p.add_argument('--full_roi', type=str2bool, default=False,
73 | help='If True, force ROI to SLM resolution')
74 | p.add_argument('--flipud', type=str2bool, default=False,
75 | help='flip slm vertically before propagation')
76 | p.add_argument('--target', type=str, default='2d',
77 | help='Type of target:'
78 | '{2d, rgb} or '
79 | '{2.5d, rgbd} or'
80 | '{3d, fs, focal-stack, focal_stack} or'
81 | '{4d, lf, light-field, light_field}')
82 | p.add_argument('--show_preview', type=str2bool, default=False,
83 | help='If true, show the preview for homography calibration')
84 | p.add_argument('--random_gen', type=str2bool, default=False,
85 | help='If true, randomize a few parameters for phase dataset generation')
86 | p.add_argument('--test_set_3d', type=str2bool, default=False,
87 | help='If true, load a set of 3D scenes for phase inference')
88 | p.add_argument('--mem_eff', type=str2bool, default=False,
89 |                    help='If true, run a memory-efficient version of the algorithms (slow)')
90 | p.add_argument("--roi_h", type=int, default=None) # height of ROI
91 | p.add_argument("--optimize_amp", type=str2bool, default=False) # optimize amplitude
92 |
93 | # Hardware
94 | p.add_argument("--slm_settle_time", type=float, default=1.0)
95 |
96 | # Regularization
97 | p.add_argument('--reg_loss_fn_type', type=str, default=None)
98 | p.add_argument('--reg_loss_w', type=float, default=0.0)
99 | p.add_argument('--recon_loss_w', type=float, default=1.0)
100 | p.add_argument('--adaptive_roi_scale', type=float, default=1.0)
101 |
102 | p.add_argument("--save_images", action="store_true")
103 | p.add_argument("--save_npy", action="store_true")
104 | p.add_argument("--serial_two_prop_off", action="store_true", help="Directly propagate prop_dist, and don't use prop_dist_from_wrp.")
105 |
106 | # Initialization schemes
107 | p.add_argument('--init_phase_type', type=str, default="random", choices=["random"])
108 |
109 |
110 | # Quantization
111 | p.add_argument('--quan_method', type=str, default='None',
112 | help='Quantization method, None, nn, nn_sigmoid, gumbel-softmax, ...')
113 | p.add_argument('--c_s', type=float, default=300,
114 |                    help='Coefficient multiplied to the score value - accounts for the Gumbel noise scale')
115 | p.add_argument('--uniform_nbits', type=int, default=None,
116 | help='If not None, use uniformly-distributed discrete SLM levels for quantization')
117 | p.add_argument('--tau_max', type=float, default=5.5,
118 | help='tau value used for quantization at the beginning - increase for more constrained cases')
119 | p.add_argument('--tau_min', type=float, default=2.0,
120 | help='minimum tau value used for quantization')
121 | p.add_argument('--r', type=float, default=None,
122 | help='coefficient on the exponent (speed of decrease)')
123 | p.add_argument('--phase_offset', type=float, default=0.0,
124 | help='You can shift the whole phase to some extent (Not used in the paper)')
125 | p.add_argument('--time_joint', type=str2bool, default=True,
126 | help='If True, jointly optimize multiple frames with time-multiplexed forward model')
127 | p.add_argument('--init_phase_range', type=float, default=1.0,
128 | help='initial phase range')
129 | p.add_argument('--eval_plane_idx', type=int, default=None,
130 | help='depth plane to evaluate hologram reconstruction')
131 | p.add_argument('--use_lut', action="store_true", help="Use SLM discrete phase lookup table.")
132 |
133 | p.add_argument('--gpu_id', type=int, default=0, help="GPU id")
134 |
135 | # Dataset
136 | p.add_argument("--dataset_subset_size", type=int, default=None)
137 | p.add_argument("--img_paths", type=str, nargs="+", default=None)
138 |     p.add("--shutter_speed", type=float, nargs='+', default=[100.0], help="Shutter speed(s) of the camera in ms.")
139 | p.add("--num_data", type=int, default=100, help="Number of data to generate.")
140 |
141 | # Light field
142 |     p.add_argument('--hop_len', type=int, default=0,
143 |                    help='hop every k views - hopping by the full window size corresponds to a holographic stereogram (HS)')
144 |     p.add_argument('--n_fft', type=int, default=1,
145 |                    help='number of Fourier samples per patch')
146 |     p.add_argument('--win_len', type=int, default=1,
147 |                    help='STFT window size')
148 | p.add_argument('--central_views', type=str2bool, default=False,
149 | help='If True, penalize only central views')
150 | p.add_argument('--reg_lf_var', type=float, default=0.0,
151 | help='lf regularization')
152 |
153 | if mode in ('train', 'eval'):
154 | p.add_argument('--num_epochs', type=int, default=350, help='')
155 | p.add_argument('--batch_size', type=int, default=1, help='')
156 | p.add_argument('--prop_model_path', type=str, default=None, help='Path to checkpoints')
157 | p.add_argument('--predefined_model', type=str, default=None, help='string for predefined model'
158 |                                                                           ': nh, nh3d, nh4d')
159 | p.add_argument('--num_downs_slm', type=int, default=5, help='')
160 | p.add_argument('--num_feats_slm_min', type=int, default=32, help='')
161 | p.add_argument('--num_feats_slm_max', type=int, default=128, help='')
162 | p.add_argument('--num_downs_target', type=int, default=5, help='')
163 | p.add_argument('--num_feats_target_min', type=int, default=32, help='')
164 | p.add_argument('--num_feats_target_max', type=int, default=128, help='')
165 | p.add_argument('--slm_coord', type=str, default='rect', help='coordinates to represent a complex-valued field.'
166 | 'rect(real+imag) or polar(amp+phase)')
167 | p.add_argument('--target_coord', type=str, default='rect', help='coordinates to represent a complex-valued field.'
168 | 'rect(real+imag) or polar(amp+phase)')
169 | p.add_argument('--param_lut', type=str2bool, default=False, help='')
170 | p.add_argument('--norm', type=str, default='instance', help='normalization layer')
171 | p.add_argument('--slm_latent_amp', type=str2bool, default=False, help='If True, '
172 |                                                                               'parameterize amplitudes multiplied at the SLM')
173 | p.add_argument('--slm_latent_phase', type=str2bool, default=False, help='If True, '
174 | 'parameterize phase added at SLM')
175 | p.add_argument('--f_latent_amp', type=str2bool, default=False, help='If True, '
176 |                                                                             'parameterize amplitudes multiplied at F')
177 | p.add_argument('--f_latent_phase', type=str2bool, default=False, help='If True, '
178 |                                                                               'parameterize phases added at F')
179 | p.add_argument('--share_f_amp', type=str2bool, default=False, help='If True, use the same f_latent_amp params '
180 | 'for propagating fields from WRP to'
181 | 'Target planes')
182 | p.add_argument('--share_f_phase', type=str2bool, default=False, help='If True, use the same f_latent_phase '
183 | 'params for propagating fields from WRP to'
184 | 'Target planes')
185 | p.add_argument('--loss_func', type=str, default='l1', help='l1 or l2')
186 | p.add_argument('--energy_compensation', type=str2bool, default=True, help='adjust intensities '
187 | 'with avg intensity of training set')
188 | p.add_argument('--num_train_planes', type=int, default=6, help='number of planes fed to models')
189 | p.add_argument('--learn_f_amp_wrp', type=str2bool, default=False)
190 | p.add_argument('--learn_f_phase_wrp', type=str2bool, default=False)
191 |
192 | # cnn residuals
193 | p.add_argument("--slm_cnn_residual", type=str2bool, default=False)
194 | p.add_argument("--target_cnn_residual", type=str2bool, default=False)
195 | p.add_argument("--min_mse_scaling", type=str2bool, default=False)
196 | p.add_argument("--dataset_subset", type=int, default=None)
197 |
198 | return p
199 |
200 |
201 | def set_configs(opt_p):
202 | """
203 | set or replace parameters with pre-defined parameters with string inputs
204 | """
205 | opt = PMap()
206 | for k, v in vars(opt_p).items():
207 | opt[k] = v
208 |
209 | # hardware setup
210 | optics_config(opt.setup_type, opt) # prop_dist, etc ...
211 | laser_config(opt.laser_type, opt) # Our Old FISBA Laser, New, SLED, LED
212 | slm_config(opt.slm_type, opt) # Holoeye or TI
213 | sensor_config(opt.sensor_type, opt) # old or new 4k
214 |
215 | # set predefined model parameters
216 | forward_model_config(opt.prop_model, opt)
217 |
218 | # wavelength, propagation distance (from SLM to midplane)
219 | if opt.channel is None:
220 | opt.chan_str = 'rgb'
221 | #opt.prop_dist = opt.prop_dists_rgb
222 | opt.prop_dist_green = opt.prop_dist
223 | opt.wavelength = opt.wavelengths
224 | else:
225 | opt.chan_str = ('red', 'green', 'blue')[opt.channel]
226 | if opt.prop_dist is None:
227 | opt.prop_dist = opt.prop_dists_rgb[opt.channel][opt.mid_idx] # prop dist from SLM plane to target plane
228 | if len(opt.prop_dists_rgb[opt.channel]) <= 1:
229 | opt.prop_dist_green = opt.prop_dists_rgb[opt.channel][0]
230 | else:
231 | opt.prop_dist_green = opt.prop_dists_rgb[opt.channel][1]
232 | else:
233 | opt.prop_dist_green = opt.prop_dist
234 | opt.wavelength = opt.wavelengths[opt.channel] # wavelength of each color
235 |
236 | # propagation distances from the wavefront recording plane
237 | if opt.channel is not None:
238 | opt.prop_dists_from_wrp = [p - opt.prop_dist for p in opt.prop_dists_rgb[opt.channel]]
239 | else:
240 | opt.prop_dists_from_wrp = [p - opt.prop_dist for p in opt.prop_dists_rgb[1]]
241 | opt.physical_depth_planes = [p - opt.prop_dist_green for p in opt.prop_dists_physical]
242 | opt.virtual_depth_planes = utils.prop_dist_to_diopter(opt.physical_depth_planes,
243 | opt.eyepiece,
244 | opt.physical_depth_planes[0])
245 | if opt.serial_two_prop_off:
246 | opt.prop_dists_from_wrp = None
247 | opt.num_planes = 1 # use prop_dist
248 | assert opt.prop_dist is not None
249 | else:
250 | opt.num_planes = len(opt.prop_dists_from_wrp)
251 | opt.all_plane_idxs = range(opt.num_planes)
252 |
253 | # force ROI to that of SLM
254 | if opt.full_roi:
255 | opt.roi_res = opt.slm_res
256 |
257 | ################
258 | # Model Training
259 | # compensate the brightness difference per plane (for model training)
260 | if opt.energy_compensation:
261 | if opt.channel is not None:
262 | opt.avg_energy_ratio = opt.avg_energy_ratio_rgb[opt.channel]
263 | else:
264 | opt.avg_energy_ratio = None
265 | else:
266 | opt.avg_energy_ratio = None
267 |
268 | # loss functions (for model training)
269 | opt.loss_train = None
270 | opt.loss_fn = None
271 | if opt.loss_func.lower() in ('l2', 'mse'):
272 | opt.loss_train = nn.functional.mse_loss
273 | opt.loss_fn = nn.functional.mse_loss
274 | elif opt.loss_func.lower() == 'l1':
275 | opt.loss_train = nn.functional.l1_loss
276 | opt.loss_fn = nn.functional.l1_loss
277 |
278 | # plane idxs (for model training)
279 | opt.plane_idxs = {}
280 | opt.plane_idxs['all'] = opt.all_plane_idxs
281 | opt.plane_idxs['train'] = opt.training_plane_idxs
282 | opt.plane_idxs['validation'] = opt.training_plane_idxs
283 | opt.plane_idxs['test'] = opt.training_plane_idxs
284 | opt.plane_idxs['heldout'] = opt.heldout_plane_idxs
285 |
286 | return opt
287 |
288 |
289 | def run_id(opt):
290 | id_str = f'{opt.exp}_{opt.method}_{opt.chan_str}_{opt.prop_model}_{opt.num_iters}_recon_{opt.recon_loss_w}_{opt.reg_loss_fn_type}_{opt.reg_loss_w}_{opt.init_phase_type}'
291 | if opt.citl:
292 | id_str = f'{id_str}_citl'
293 | if opt.mem_eff:
294 | id_str = f'{id_str}_memeff'
295 | id_str = f'{id_str}_tm_{opt.num_frames}' # time multiplexing
296 | if opt.citl:
297 | id_str = f'{id_str}_sht_{opt.shutter_speed[0]}' # shutter speed
298 | if opt.optimize_amp:
299 | id_str = f'{id_str}_opt_amp'
300 | return id_str
301 |
302 | def run_id_training(opt):
303 | id_str = f'{opt.exp}_{opt.chan_str}-' \
304 | f'data_{opt.capture_subset}-' \
305 | f'slm{opt.num_downs_slm}-{opt.num_feats_slm_min}-{opt.num_feats_slm_max}_' \
306 | f'{str(opt.slm_latent_amp)[0]}{str(opt.slm_latent_phase)[0]}_' \
307 | f'tg{opt.num_downs_target}-{opt.num_feats_target_min}-{opt.num_feats_target_max}_' \
308 | f'lut{str(opt.param_lut)[0]}_' \
309 | f'lH{str(opt.f_latent_amp)[0]}{str(opt.f_latent_phase)[0]}_' \
310 | f'sH{str(opt.share_f_amp)[0]}{str(opt.share_f_phase)[0]}_' \
311 | f'eH{str(opt.learn_f_amp_wrp)[0]}{str(opt.learn_f_phase_wrp)[0]}_' \
312 | f'{opt.slm_coord}{opt.target_coord}_{opt.loss_func}_{opt.num_train_planes}pls_' \
313 | f'bs{opt.batch_size}_' \
314 | f'res-{opt.slm_cnn_residual}-{opt.target_cnn_residual}_' \
315 | f'mse-s{opt.min_mse_scaling}'
316 |
317 | cur_time = datetime.datetime.now().strftime("%d-%H%M")
318 | id_str = f'{cur_time}_{id_str}'
319 |
320 | return id_str
321 |
322 |
323 | def hw_params(opt):
324 | params_slm = PMap()
325 | params_slm.settle_time = max(opt.shutter_speed) * 2.5 / 1000 # shutter speed is in ms
326 |     params_slm.monitor_num = 1 # set this to the monitor index of the SLM on your machine
327 | params_slm.slm_type = opt.slm_type
328 |
329 | params_camera = PMap()
330 | #params_camera.img_size_native = (3000, 4096) # 4k sensor native
331 | params_camera.img_size_native = (1700, 2736) # Used for SIGGRAPH 2022
332 | params_camera.ser = None #serial.Serial('COM5', 9600, timeout=0.5)
333 |
334 | params_calib = PMap()
335 | params_calib.show_preview = opt.show_preview
336 | params_calib.range_y = slice(0, params_camera.img_size_native[0])
337 | params_calib.range_x = slice(0, params_camera.img_size_native[1])
338 | params_calib.num_circles = (11, 18)
339 |
340 | params_calib.spacing_size = [int(roi / (num_circs - 1))
341 | for roi, num_circs in zip(opt.roi_res, params_calib.num_circles)]
342 | params_calib.pad_pixels = [int(slm - roi) // 2 for slm, roi in zip(opt.slm_res, opt.roi_res)]
343 | params_calib.quadratic = True
344 |
345 | colors = ['red', 'green', 'blue']
346 | params_calib.phase_path = f"data/calib/{colors[opt.channel]}/11x18_r19_ti_slm_dots_phase.png" # optimize homography pattern for every plane
347 | params_calib.blank_phase_path = "data/calib/2560x1600_blank.png"
348 | params_calib.img_size_native = params_camera.img_size_native
349 |
350 | return params_slm, params_camera, params_calib
351 |
352 |
353 | def slm_config(slm_type, opt):
354 | # setting for specific SLM.
355 |     if slm_type.lower() in ('ti',):
356 | opt.feature_size = (10.8 * um, 10.8 * um) # SLM pitch
357 | opt.slm_res = (800, 1280) # resolution of SLM
358 | opt.image_res = (800, 1280)
359 | #opt.image_res = (1600, 2560)
360 | if opt.channel is not None:
361 | opt.lut0 = DiscreteSLM.lut[:-1] * 636.4 * nm / opt.wavelengths[opt.channel] # scaled LUT
362 | else:
363 | opt.lut0 = DiscreteSLM.lut[:-1]
364 | opt.flipud = True
365 | elif slm_type.lower() in ('leto', 'holoeye'):
366 | opt.feature_size = (6.4 * um, 6.4 * um) # SLM pitch
367 | opt.slm_res = (1080, 1920) # resolution of SLM
368 | opt.image_res = opt.slm_res
369 | opt.lut0 = None
370 | if opt.projector:
371 | opt.flipud = not opt.flipud
372 |
373 | def laser_config(laser_type, opt):
374 | # setting for specific laser.
375 | if 'new' in laser_type.lower():
376 | opt.wavelengths = [636.17 * nm, 518.48 * nm, 442.03 * nm] # wavelength of each color
377 | elif "readybeam" in laser_type.lower():
378 | # using this for etech
379 | opt.wavelengths = (638.35 * nm, 521.16 * nm, 443.50 * nm)
380 | else:
381 | opt.wavelengths = [636.4 * nm, 517.7 * nm, 440.8 * nm]
382 |
383 |
384 | def sensor_config(sensor_type, opt):
385 | return opt
386 |
387 |
388 | def optics_config(setup_type, opt):
389 |     if setup_type in ('siggraph2022',):
390 | opt.laser_type = 'old'
391 | opt.slm_type = 'ti'
392 | opt.avg_energy_ratio_rgb = [[1.0000, 1.0595, 1.1067, 1.1527, 1.1943, 1.2504, 1.3122],
393 | [1.0000, 1.0581, 1.1051, 1.1490, 1.1994, 1.2505, 1.3172],
394 | [1.0000, 1.0560, 1.1035, 1.1487, 1.2008, 1.2541, 1.3183]] # averaged over training set
395 | opt.prop_dists_rgb = [[7.76*cm, 7.96*cm, 8.13*cm, 8.31*cm, 8.48*cm, 8.72*cm, 9.04*cm],
396 | [7.77*cm, 7.97*cm, 8.13*cm, 8.31*cm, 8.48*cm, 8.72*cm, 9.04*cm],
397 | [7.76*cm, 7.96*cm, 8.13*cm, 8.31*cm, 8.48*cm, 8.72*cm, 9.04*cm]]
398 | opt.prop_dists_physical = opt.prop_dists_rgb[1]
399 | opt.roi_res = (700, 1190) # regions of interest (to penalize for SGD)
400 |
401 | if not opt.method.lower() in ['olas', 'dpac']:
402 | opt.F_aperture = (0.7, 0.78, 0.9)[opt.channel]
403 | else:
404 | opt.F_aperture = 0.49
405 |
406 | # indices of training planes (idx 4 is the held-out plane)
407 | if opt.num_train_planes == 1:
408 | opt.training_plane_idxs = [3]
409 | elif opt.num_train_planes == 3:
410 | opt.training_plane_idxs = [0, 3, 6]
411 | elif opt.num_train_planes == 5:
412 | opt.training_plane_idxs = [0, 2, 3, 5, 6]
413 | elif opt.num_train_planes == 6:
414 | opt.training_plane_idxs = [0, 1, 2, 3, 5, 6]
415 | else:
416 | opt.training_plane_idxs = None
417 | opt.heldout_plane_idxs = [4]
418 | opt.mid_idx = 3 # intermediate plane as 1.5D
419 |
420 |
421 | def forward_model_config(model_type, opt):
422 | # setting for specific model that is predefined.
423 | if model_type is not None:
424 | print(f' - changing model parameters for {model_type}')
425 | if model_type.lower() == 'nh3d':
426 | opt.num_downs_slm = 8
427 | opt.num_feats_slm_min = 32
428 | opt.num_feats_slm_max = 512
429 | opt.num_downs_target = 5
430 | opt.num_feats_target_min = 8
431 | opt.num_feats_target_max = 128
432 | opt.param_lut = False
433 |
434 | elif model_type.lower() == 'hil':
435 | opt.num_downs_slm = 0
436 | opt.num_feats_slm_min = 0
437 | opt.num_feats_slm_max = 0
438 | opt.num_downs_target = 8
439 | opt.num_feats_target_min = 32
440 | opt.num_feats_target_max = 512
441 | opt.target_coord = 'amp'
442 | opt.param_lut = False
443 |
444 | elif model_type.lower() == 'cnnprop':
445 | opt.num_downs_slm = 8
446 | opt.num_feats_slm_min = 32
447 | opt.num_feats_slm_max = 512
448 | opt.num_downs_target = 0
449 | opt.num_feats_target_min = 0
450 | opt.num_feats_target_max = 0
451 | opt.param_lut = False
452 |
453 | elif model_type.lower() == 'propcnn':
454 | opt.num_downs_slm = 0
455 | opt.num_feats_slm_min = 0
456 | opt.num_feats_slm_max = 0
457 | opt.num_downs_target = 8
458 | opt.num_feats_target_min = 32
459 | opt.num_feats_target_max = 512
460 | opt.param_lut = False
461 |
462 | elif model_type.lower() == 'nh4d':
463 | opt.num_downs_slm = 5
464 | opt.num_feats_slm_min = 32
465 | opt.num_feats_slm_max = 128
466 | opt.num_downs_target = 5
467 | opt.num_feats_target_min = 32
468 | opt.num_feats_target_max = 128
469 | opt.num_target_latent = 0
470 | opt.norm = 'instance'
471 | opt.slm_coord = 'both'
472 | opt.target_coord = 'both_1ch_output'
473 | opt.param_lut = True
474 | opt.slm_latent_amp = True
475 | opt.slm_latent_phase = True
476 | opt.f_latent_amp = True
477 | opt.f_latent_phase = True
478 | opt.share_f_amp = True
479 |
480 |
481 | def add_lf_params(opt, dataset='olas'):
482 | """ Add Light-Field parameters """
483 |     if opt.target.lower() in ('rgbd',):
484 | if opt.reg_lf_var > 0.0:
485 | opt.ang_res = (7, 7)
486 | opt.load_only_central_view = True
487 | opt.hop_len = (1, 1)
488 | opt.n_fft = opt.ang_res
489 | opt.win_len = opt.ang_res
490 | if opt.central_views:
491 | opt.selected_views = (slice(1, 6, 1), slice(1, 6, 1))
492 | else:
493 | opt.selected_views = None
494 | return opt
495 | else:
496 | return opt
497 | else:
498 | if dataset == 'olas':
499 | opt.ang_res = (9, 9)
500 | opt.load_only_central_view = opt.target.lower() == 'rgbd'
501 | opt.hop_len = (1, 1)
502 | opt.n_fft = opt.ang_res
503 | opt.win_len = opt.ang_res
504 |
505 | if dataset == 'parallax':
506 | opt.ang_res = (7, 7)
507 | opt.load_only_central_view = opt.target.lower() == 'rgbd'
508 | opt.hop_len = (1, 1)
509 | opt.n_fft = opt.ang_res
510 | opt.win_len = opt.ang_res
511 |
512 | if 'lf' in opt.target.lower():
513 | opt.prop_dist_from_wrp = [0.]
514 | opt.c_s = 700
515 | if opt.central_views:
516 | opt.selected_views = (slice(1, 6, 1), slice(1, 6, 1))
517 | else:
518 | opt.selected_views = None
519 |
520 | return opt
--------------------------------------------------------------------------------
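A hedged sketch of how `PMap` behaves: it is a plain dict with attribute-style access whose missing keys fall back to `dict.get` and return None instead of raising, which is why checks such as `if opt.projector:` in `slm_config` work even when that flag was never registered. The keys below are placeholders.

```
from params import PMap

opt = PMap()
opt.slm_type = 'ti'      # attribute-style assignment stores a dict entry
opt['channel'] = 1       # regular dict access still works

print(opt.slm_type)      # 'ti'
print(opt.projector)     # None: missing keys return None via dict.get
```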
/props/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/time-multiplexed-neural-holography/5cf6c275c459652abb3ddddd2e167f9584072aeb/props/__init__.py
--------------------------------------------------------------------------------
/props/prop_ideal.py:
--------------------------------------------------------------------------------
1 | """
2 | Ideal propagation
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 |
11 | Technical Paper:
12 | Time-multiplexed Neural Holography:
13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
15 | SIGGRAPH 2022
16 | """
17 |
18 | import torch
19 | import torch.nn as nn
20 | import utils
21 | import torch.fft as tfft
22 | import math
23 | from copy import deepcopy
24 |
25 | class Propagation(nn.Module):
26 | """
27 | The ideal, convolution-based propagation implementation
28 |
29 | Class initialization parameters
30 | -------------------------------
31 | :param prop_dist: propagation distance(s)
32 | :param wavelength: wavelength
33 | :param feature_size: pixel pitch
34 | :param prop_type: type of propagation (ASM or fresnel), by default the angular spectrum method
35 | :param F_aperture: filter size at fourier plane, by default 1.0
36 | :param dim: for propagation to multiple planes, dimension to stack the output, by default 1 (second dimension)
37 | :param linear_conv: If true, pad zeros to ensure the linear convolution, by default True
38 | :param learned_amp: Learned amplitude at Fourier plane, by default None
39 | :param learned_phase: Learned phase at Fourier plane, by default None
40 | """
41 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0,
42 | dim=1, linear_conv=True, learned_amp=None, learned_phase=None, learned_field=None):
43 | super(Propagation, self).__init__()
44 |
45 | self.H = None # kernel at Fourier plane
46 | self.prop_type = prop_type
47 | if not isinstance(prop_dist, list):
48 | prop_dist = [prop_dist]
49 | self.prop_dist = prop_dist
50 | self.feature_size = feature_size
51 | if not isinstance(wavelength, list):
52 | wavelength = [wavelength]
53 | self.wvl = wavelength
54 | self.linear_conv = linear_conv # ensure linear convolution by padding
55 | self.bl_asm = min(prop_dist) > 0.3
56 | self.F_aperture = F_aperture
57 | self.dim = dim # The dimension to stack the kernels as well as the resulting fields (if multi-channel)
58 |
59 | self.preload_params = False
60 | self.preloaded_H_amp = False # preload H_mask once trained
61 | self.preloaded_H_phase = False # preload H_phase once trained
62 |
63 | self.fourier_amp = learned_amp
64 | self.fourier_phase = learned_phase
65 | self.fourier_field = learned_field
66 |
67 | #self.bl_asm = True
68 | if self.bl_asm:
69 | print("Using band-limited ASM")
70 | else:
71 | print("Using naive ASM")
72 |
73 | def forward(self, u_in):
74 | if u_in.dtype == torch.float32: # check if this is phase or already a wavefield
75 | u_in = torch.exp(1j * u_in) # convert phase to wavefront
76 |
77 | if self.H is None:
78 | Hs = []
79 | if len(self.wvl) > 1: # If multi-channel, rearrange kernels
80 | for i, wv in enumerate(self.wvl):
81 | H_wvl = []
82 | for prop_dist in self.prop_dist:
83 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..')
84 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size,
85 | self.prop_type, self.linear_conv,
86 | F_aperture=self.F_aperture, bl_asm=self.bl_asm)
87 | H_wvl.append(h)
88 | H_wvl = torch.cat(H_wvl, dim=1)
89 | Hs.append(H_wvl)
90 | self.H = torch.cat(Hs, dim=1)
91 | else:
92 | for wv in self.wvl:
93 | for prop_dist in self.prop_dist:
94 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..')
95 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size,
96 | self.prop_type, self.linear_conv,
97 | F_aperture=self.F_aperture, bl_asm=self.bl_asm)
98 | Hs.append(h)
99 | self.H = torch.cat(Hs, dim=1)
100 |
101 | if self.preload_params:
102 | self.premultiply()
103 |
104 | if self.fourier_field is not None:
105 | # for neural wavefront model
106 | fourier_field, fourier_dc_field = self.fourier_field() # neural wavefield
107 | H = self.H * fourier_field
108 | else:
109 | if self.fourier_amp is not None and not self.preloaded_H_amp:
110 | H = self.fourier_amp.clamp(min=0.) * self.H
111 | else:
112 | H = self.H
113 |
114 | if self.fourier_phase is not None and not self.preloaded_H_phase:
115 | H = H * torch.exp(1j * self.fourier_phase)
116 |
117 | return self.prop(u_in, H, self.linear_conv)
118 |
119 | def compute_H(self, input_field, prop_dist, wvl, feature_size, prop_type, lin_conv=True,
120 | return_exp=False, F_aperture=1.0, bl_asm=False, return_filter=False):
121 | dev = input_field.device
122 | res_mul = 2 if lin_conv else 1
123 | num_y, num_x = res_mul*input_field.shape[-2], res_mul*input_field.shape[-1] # number of pixels
124 |         dy, dx = feature_size  # sampling interval size, i.e. pixel pitch of the SLM
125 | # does this mean the holographic display can display only one pixel (focus light to one pixel (smallest feature size))?
126 |
127 | # frequency coordinates sampling
128 | fy = torch.linspace(-1 / (2 * dy), 1 / (2 * dy), num_y)
129 | fx = torch.linspace(-1 / (2 * dx), 1 / (2 * dx), num_x)
130 |
131 | # momentum/reciprocal space
132 | # FY, FX = torch.meshgrid(fy, fx)
133 | FX, FY = torch.meshgrid(fx, fy)
134 | FX = torch.transpose(FX, 0, 1)
135 | FY = torch.transpose(FY, 0, 1)
136 |
137 | if prop_type.lower() == 'asm':
138 | G = 2 * math.pi * (1 / wvl**2 - (FX ** 2 + FY ** 2)).sqrt()
139 | elif prop_type.lower() == 'fresnel':
140 | G = math.pi * wvl * (FX ** 2 + FY ** 2)
141 |
142 | H_exp = G.reshape((1, 1, *G.shape)).to(dev)
143 |
144 | if return_exp:
145 | return H_exp
146 |
147 | if bl_asm:
148 | fy_max = 1 / math.sqrt((2 * prop_dist * (1 / (dy * float(num_y))))**2 + 1) / wvl
149 | fx_max = 1 / math.sqrt((2 * prop_dist * (1 / (dx * float(num_x))))**2 + 1) / wvl
150 |
151 | H_filter = ((torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max())
152 | & (torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).type(torch.FloatTensor)
153 | else:
154 | H_filter = (torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max()).type(torch.FloatTensor)
155 |
156 | if prop_dist == 0.:
157 | H = torch.ones_like(H_exp)
158 | else:
159 | H = H_filter.to(input_field.device) * torch.exp(1j * H_exp * prop_dist)
160 | self.H_without_filter = torch.exp(1j * H_exp * prop_dist)
161 | self.H_filter = H_filter
162 |
163 | if return_filter:
164 | return H_filter
165 | else:
166 | return H
167 |
168 | def prop(self, u_in, H, linear_conv=True, padtype='zero'):
169 | if linear_conv:
170 | # preprocess with padding for linear conv.
171 | input_resolution = u_in.size()[-2:]
172 | conv_size = [i * 2 for i in input_resolution]
173 | if padtype == 'zero':
174 | padval = 0
175 | elif padtype == 'median':
176 | padval = torch.median(torch.pow((u_in ** 2).sum(-1), 0.5))
177 | u_in = utils.pad_image(u_in, conv_size, padval=padval, stacked_complex=False)
178 |
179 | U1 = tfft.fftshift(tfft.fftn(u_in, dim=(-2, -1), norm='ortho'), (-2, -1)) # fourier transform
180 | #U2_without_filter = U1 * self.H_without_filter
181 | U2 = U1 * H
182 | u_out = tfft.ifftn(tfft.ifftshift(U2, (-2, -1)), dim=(-2, -1), norm='ortho')
183 |
184 | if linear_conv: # also record uncropped image
185 | self.uncropped_u_out = u_out.clone()
186 | u_out = utils.crop_image(u_out, input_resolution, pytorch=True, stacked_complex=False)
187 |
188 | """
189 | U2_amp = torch.abs(U2)
190 | U2_without_filter_amp = torch.abs(U2_without_filter)
191 | filtered_intensity_sum = (U2_without_filter_amp**2 - U2_amp**2).mean()
192 | #print('total amp:', (U2_without_filter_amp**2).mean())
193 | #print('filtered_intensity_sum: ', filtered_intensity_sum)
194 |
195 | # normalize to 0, 1
196 | U2_amp = (U2_amp - U2_amp.min()) / (U2_amp.max() - U2_amp.min())
197 | U2_without_filter_amp = (U2_without_filter_amp - U2_without_filter_amp.min()) / (U2_without_filter_amp.max() - U2_without_filter_amp.min())
198 | U2_amp = U2_amp.mean(axis=0).squeeze()
199 | U2_without_filter_amp = U2_without_filter_amp.mean(axis=0).squeeze()
200 | U2_amp = torch.log(U2_amp + 1e-10)
201 | U2_without_filter_amp = torch.log(U2_without_filter_amp + 1e-10)
202 |
203 | H_amp = torch.abs(self.H_filter)
204 | H_amp = (H_amp - H_amp.min()) / (H_amp.max() - H_amp.min())
205 | H_amp = H_amp.squeeze()
206 | """
207 |
208 |
209 | return u_out
210 |
211 | def __len__(self):
212 | return len(self.prop_dist)
213 |
214 | def preload_H(self):
215 | self.preload_params = True
216 |
217 | def premultiply(self):
218 | self.preload_params = False
219 |
220 | if self.fourier_amp is not None and not self.preloaded_H_amp:
221 | self.H = self.fourier_amp.clamp(min=0.) * self.H
222 | if self.fourier_phase is not None and not self.preloaded_H_phase:
223 | self.H = self.H * torch.exp(1j * self.fourier_phase)
224 |
225 | self.H.detach_()
226 | self.preloaded_H_amp = True
227 | self.preloaded_H_phase = True
228 |
229 | @property
230 | def plane_idx(self):
231 | return self._plane_idx
232 |
233 | @plane_idx.setter
234 | def plane_idx(self, idx):
235 | if idx is None:
236 | return
237 |
238 | self._plane_idx = idx
239 | if len(self.prop_dist) > 1:
240 | self.prop_dist = [self.prop_dist[idx]]
241 |
242 | if self.fourier_amp is not None and self.fourier_amp.shape[1] > 1:
243 | self.fourier_amp = nn.Parameter(self.fourier_amp[:, idx:idx+1, ...], requires_grad=False)
244 | if self.fourier_phase is not None and self.fourier_phase.shape[1] > 1:
245 | self.fourier_phase = nn.Parameter(self.fourier_phase[:, idx:idx+1, ...], requires_grad=False)
246 |
247 |
248 |
249 | class SerialProp(nn.Module):
250 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0,
251 | prop_dists_from_wrp=None, linear_conv=True, dim=1, opt=None):
252 | super(SerialProp, self).__init__()
253 | first_prop = Propagation(prop_dist, wavelength, feature_size,
254 | prop_type=prop_type, linear_conv=linear_conv, F_aperture=F_aperture, dim=dim)
255 | props = [first_prop]
256 |
257 | if prop_dists_from_wrp is not None:
258 | second_prop = Propagation(prop_dists_from_wrp, wavelength, feature_size,
259 | prop_type=prop_type, linear_conv=linear_conv, F_aperture=1.0, dim=dim)
260 | props += [second_prop]
261 | self.props = nn.Sequential(*props)
262 |
263 | # copy the opt parameters for initializing prop in other modules
264 | self.opt = opt
265 |
266 | def forward(self, u_in):
267 |
268 | u_out = self.props(u_in)
269 | self.uncropped_u_out = self.props[-1].uncropped_u_out # dirty way to access final layer uncropped output
270 |
271 | return u_out
272 |
273 | def preload_H(self):
274 | for prop in self.props:
275 | prop.preload_H()
276 |
277 | @property
278 | def plane_idx(self):
279 | return self._plane_idx
280 |
281 | @plane_idx.setter
282 | def plane_idx(self, idx):
283 | if idx is None:
284 | return
285 |
286 | self._plane_idx = idx
287 | for prop in self.props:
288 | prop.plane_idx = idx
--------------------------------------------------------------------------------
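A hedged usage sketch of the `Propagation` module above: a real-valued phase tensor is lifted to a complex wavefront exp(1j * phase) inside forward() and convolved with the ASM kernel. The distance, wavelength, and pitch below follow values used elsewhere in the repository but are placeholders here.

```
import math
import torch
from props.prop_ideal import Propagation

cm, um, nm = 1e-2, 1e-6, 1e-9

prop = Propagation(prop_dist=8.13 * cm,             # placeholder mid-plane distance
                   wavelength=517.7 * nm,           # green channel wavelength
                   feature_size=(10.8 * um, 10.8 * um),
                   prop_type='ASM', F_aperture=1.0)

slm_phase = (torch.rand(1, 1, 800, 1280) * 2 - 1) * math.pi
u_out = prop(slm_phase)       # complex field at the target plane, shape (1, 1, 800, 1280)
recon_amp = u_out.abs()       # amplitude, e.g. for comparison against a target image
```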
/props/prop_physical.py:
--------------------------------------------------------------------------------
1 | """
2 | Propagation happening on the setup
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 |
11 | Technical Paper:
12 | Time-multiplexed Neural Holography:
13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
15 | SIGGRAPH 2022
16 | """
17 |
18 | import torch
19 | import torch.nn as nn
20 | import utils
21 | import time
22 | import cv2
23 | import imageio
24 |
25 | from hw.phase_encodings import phase_encoding
26 | import sys
27 | if sys.platform == 'win32':
28 | import slmpy
29 | import hw.camera_capture_module as cam
30 | import hw.calibration_module as calibration
31 |
32 |
33 | class PhysicalProp(nn.Module):
34 | """ A module for physical propagation,
35 |     The forward pass takes an SLM phase pattern as input, displays it on the physical setup,
36 |     captures the diffracted image at the target plane,
37 |     and returns the captured image warped with the homography pre-calibrated at instantiation.
38 |
39 | Class initialization parameters
40 | -------------------------------
41 | :param params_slm: a set of parameters for the SLM.
42 | :param params_camera: a set of parameters for the camera sensor.
43 | :param params_calib: a set of parameters for homography calibration.
44 | :param q_fn: quantization function module
45 |
46 | Usage
47 | -----
48 | Functions as a pytorch module:
49 |
50 | >>> camera_prop = PhysicalProp(...)
51 | >>> captured_amp = camera_prop(slm_phase)
52 |
53 | slm_phase: phase at the SLM plane, with dimensions [batch, 1, height, width]
54 | captured_amp: amplitude at the target plane, with dimensions [batch, 1, height, width]
55 |
56 | """
57 | def __init__(self, params_slm, params_camera, params_calib=None, q_fn=None, shutter_speed=100, hdr=False):
58 | super(PhysicalProp, self).__init__()
59 |
60 |
61 | self.shutter_speed = shutter_speed
62 | self.hdr = hdr
63 | self.q_fn = q_fn
64 | self.params_calib = params_calib
65 |
66 | if self.hdr:
67 | assert len(self.shutter_speed) > 1 # more than 1 shutter speed for HDR capture
68 | else:
69 | assert len(self.shutter_speed) == 1 # non-hdr mode supports only one shutter speed
70 |
71 | # 1. Connect Camera
72 | self.camera = cam.CameraCapture(params_camera)
73 | self.camera.connect(0) # specify the camera to use, 0 for main cam, 1 for the second cam
74 | #self.camera.start_capture()
75 | self.camera.start_capture()
76 |
77 | # 2. Connect SLM
78 | self.slm = slmpy.SLMdisplay(isImageLock=True, monitor=params_slm.monitor_num)
79 | self.params_slm = params_slm
80 |
81 | # 3. Calibrate hardware using homography
82 | if params_calib is not None:
83 | self.warper = calibration.Warper(params_calib)
84 | self.calibrate(params_calib.phase_path, params_calib.show_preview)
85 | else:
86 | self.warper = None
87 |
88 | def calibrate_total_laser_energy(self):
89 | print("Calibrating total laser energy...")
90 | phase_img = imageio.imread(self.params_calib.blank_phase_path)
91 | self.slm.updateArray(phase_img)
92 | time.sleep(5)
93 | captured_plane_wave = self.forward(phase_img)
94 | h, w = captured_plane_wave.shape[-2], captured_plane_wave.shape[-1] # full SLM size
95 | cropped_energy = utils.crop_image(captured_plane_wave**2, (500, 500), stacked_complex=False)
96 | self.total_laser_energy = cropped_energy.sum() * (h * w) / (500 * 500)
97 |
98 | def calibrate(self, phase_path, show_preview=False):
99 | """
100 |         Displays the homography dot pattern and estimates the camera-to-SLM homography.
101 |         :param phase_path: path to the homography calibration phase pattern
102 |         :param show_preview: if True, show a preview window during calibration
103 |         :return: None (raises ValueError if calibration fails)
104 | """
105 | print(' -- Calibrating ...')
106 | self.camera.set_shutter_speed(2000) # for homography pattern. remember to reset it!
107 | self.camera.set_gain(10) # for homography pattern. remember to reset it!
108 | phase_img = imageio.imread(phase_path)
109 | #print(phase_img)
110 | self.slm.updateArray(phase_img)
111 | time.sleep(5)
112 | captured_img = self.camera.grab_images_fast(5) # capture 5-10 images for averaging
113 | calib_success = self.warper.calibrate(captured_img, show_preview)
114 | self.camera.set_gain(0)
115 | if calib_success:
116 | print(' -- Calibration succeeded!...')
117 | if not self.hdr:
118 | print("One time step shutter speed for non-HDR capture...")
119 | self.camera.set_shutter_speed(self.shutter_speed[0]) # reset for capture
120 | else:
121 | raise ValueError(' -- Calibration failed')
122 |
123 | def forward(self, slm_phase, time_avg=1):
124 | """
125 |         Displays the phase pattern, captures intensity (averaged over time_avg captures), and warps it.
126 |         :param slm_phase: phase pattern at the SLM plane; :param time_avg: number of captures to average
127 |         :return: warped amplitude (not intensity) at the target plane
128 | """
129 | input_phase = slm_phase
130 | if self.q_fn is not None:
131 | dp_phase = self.q_fn(input_phase)
132 | else:
133 | dp_phase = input_phase
134 |
135 | self.display_slm_phase(dp_phase)
136 |
137 | raw_intensity_sum = 0
138 | for t in range(time_avg):
139 | raw_intensity = self.capture_linear_intensity(dp_phase) # grayscale raw16 intensity image
140 | raw_intensity_sum += raw_intensity
141 | raw_intensity = raw_intensity_sum / time_avg
142 |
143 | # amplitude is returned! not intensity!
144 | warped_intensity = self.warper(raw_intensity) # apply homography
145 | return warped_intensity.sqrt() # return amplitude
146 |
147 | def capture_linear_intensity(self, slm_phase):
148 | """
149 |         Captures the holographic image formed by the currently displayed phase pattern with the sensor.
150 | 
151 |         :param slm_phase: kept for interface consistency; the pattern is displayed by the caller (see forward())
152 |         :return: demosaicked linear intensity
153 |         """
154 |         raw_uint16_data = self.capture_uint16()  # wait for SLM settling & retrieve the camera buffer
155 | captured_intensity = self.process_raw_data(raw_uint16_data) # demosaick & sum up
156 | return captured_intensity
157 |
158 | def forward_hdr(self, slm_phase):
159 | """
160 |         HDR capture: merges an exposure stack into a single amplitude image.
161 |         :param slm_phase: phase pattern at the SLM plane
162 |         :return: (merged HDR amplitude, list of per-exposure amplitudes), both warped
163 | """
164 | input_phase = slm_phase
165 | if self.q_fn is not None:
166 | dp_phase = self.q_fn(input_phase)
167 | else:
168 | dp_phase = input_phase
169 |
170 | raw_intensity_hdr, raw_intensity_stack = self.capture_linear_intensity_hdr(dp_phase) # grayscale raw16 intensity image
171 |
172 | # amplitude is returned! not intensity!
173 | warped_intensity_hdr = self.warper(raw_intensity_hdr) # apply homography
174 | warped_intensity_stack = [self.warper(intensity) for intensity in raw_intensity_stack]
175 | warped_amplitude_hdr = warped_intensity_hdr.sqrt()
176 | warped_amplitude_stack = [intensity.sqrt() for intensity in warped_intensity_stack]
177 | return warped_amplitude_hdr, warped_amplitude_stack
178 |
179 | def capture_linear_intensity_hdr(self, slm_phase):
180 | raw_uint16_data_list = []
181 | for s in self.shutter_speed:
182 | self.camera.set_shutter_speed(s) # one exposure
183 | raw_uint16_data = self.capture_uint16(slm_phase)
184 | raw_uint16_data_list.append(raw_uint16_data)
185 | #captured_intensity_hdr = self.process_raw_data(raw_uint16_data_list[0]) # convert to hdr and demosaick?
186 | captured_intensity_exposure_stack = [torch.clip(self.process_raw_data(raw_data), 0, 1) for raw_data in raw_uint16_data_list] # overexposed images, clip to range
187 | captured_intensity_hdr = self.merge_hdr(captured_intensity_exposure_stack)
188 | return captured_intensity_hdr, captured_intensity_exposure_stack
189 |
190 | def merge_hdr(self, exposure_stack):
191 | weight_sum = 0
192 | weighted_img_sum = 0
193 | for s, img in zip(self.shutter_speed, exposure_stack):
194 | weight = torch.exp(-4 * (img - 0.5)**2 / 0.5**2 )
195 | weighted_img = weight * (torch.log(img) - torch.log(torch.tensor(s)))
196 | weight_sum = weight_sum + weight
197 | weighted_img_sum = weighted_img_sum + weighted_img
198 | merged_img = torch.exp(weighted_img_sum / (weight_sum + 1e-10)) # numerical issues
199 | return merged_img
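
The merge above is a standard log-domain exposure fusion: each capture is weighted by a Gaussian centered on mid-gray, the per-exposure radiance estimates log(I) − log(t) are averaged, and the result is exponentiated. A minimal standalone sketch of the same math on synthetic data (illustrative shutter speeds, hypothetical helper name):

```
import torch

def merge_hdr_sketch(exposure_stack, shutter_speeds):
    """Log-domain HDR merge: weight mid-gray pixels highest, divide out the exposure time."""
    weight_sum, weighted_sum = 0., 0.
    for t, img in zip(shutter_speeds, exposure_stack):
        w = torch.exp(-4 * (img - 0.5)**2 / 0.5**2)        # Gaussian weight, peaked at 0.5
        weighted_sum = weighted_sum + w * (torch.log(img + 1e-12) - torch.log(torch.tensor(float(t))))
        weight_sum = weight_sum + w
    return torch.exp(weighted_sum / (weight_sum + 1e-10))   # relative linear radiance

radiance = torch.rand(1, 1, 4, 4)                              # toy scene
stack = [torch.clamp(radiance * t, 0, 1) for t in (0.5, 2.0)]  # two simulated exposures
print(merge_hdr_sketch(stack, (0.5, 2.0)).shape)               # (1, 1, 4, 4), ~proportional to radiance up to clipping
```
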
200 |
201 | def display_slm_phase(self, slm_phase):
202 | if slm_phase is not None: # just for simple camera capture
203 | if torch.is_tensor(slm_phase): # raw phase is always tensor.
204 | slm_phase_encoded = phase_encoding(slm_phase, self.params_slm.slm_type)
205 | else: # uint8 encoded phase (should be np.array)
206 | slm_phase_encoded = slm_phase
207 | self.slm.updateArray(slm_phase_encoded)
208 |
209 |     def capture_uint16(self, slm_phase=None):
210 |         """
211 |         Displays the phase pattern (if given) on the SLM, then sends a signal to the board (waiting for the next SLM clock).
212 |         Right after hearing back from the SLM, it signals the PC to retrieve the camera buffer.
213 | 
214 |         :param slm_phase: optional phase pattern to display before capture (None keeps the current pattern)
215 |         :return: raw uint16 camera buffer
216 |         """
217 |         self.display_slm_phase(slm_phase)  # no-op if slm_phase is None
218 | if self.camera.params.ser is not None:
219 | self.camera.params.ser.write(f'D'.encode())
220 |
221 | # TODO: make the following in a separate function.
222 | # Wait until receiving signal from arduino
223 | incoming_byte = self.camera.params.ser.inWaiting()
224 | t0 = time.perf_counter()
225 | while True:
226 | received = self.camera.params.ser.read(incoming_byte).decode('UTF-8')
227 | if received != 'C':
228 | incoming_byte = self.camera.params.ser.inWaiting()
229 | if time.perf_counter() - t0 > 2.0:
230 | break
231 | else:
232 | break
233 | else:
234 | #print("settling...")
235 | time.sleep(self.params_slm.settle_time)
236 | raw_data_from_buffer = self.camera.retrieve_buffer()
237 |
238 | return raw_data_from_buffer
239 |
240 | def process_raw_data(self, raw_data):
241 | """
242 | gets raw data from the camera buffer, and demosaick it
243 |
244 | :param raw_data:
245 | :return:
246 | """
247 | raw_data = raw_data - 64
248 | color_cv_image = cv2.cvtColor(raw_data, self.camera.demosaick_rule) # it gives float64 from uint16 -- double check it
249 | captured_intensity = utils.im2float(color_cv_image) # float64 to float32
250 |
251 | # Numpy to tensor
252 | captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32,
253 | device=self.dev).permute(2, 0, 1).unsqueeze(0)
254 | captured_intensity = torch.sum(captured_intensity, dim=1, keepdim=True)
255 | return captured_intensity
256 |
257 | def disconnect(self):
258 | #self.camera.stop_capture()
259 | self.camera.stop_capture()
260 | self.camera.disconnect()
261 | self.slm.close()
262 |
263 | def to(self, *args, **kwargs):
264 | slf = super().to(*args, **kwargs)
265 | if slf.warper is not None:
266 | slf.warper = slf.warper.to(*args, **kwargs)
267 | try:
268 | slf.dev = next(slf.parameters()).device
269 | except StopIteration: # no parameters
270 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0]
271 | if device_arg is not None:
272 | slf.dev = device_arg
273 | return slf
--------------------------------------------------------------------------------
/props/prop_submodules.py:
--------------------------------------------------------------------------------
1 | """
2 | Modules for propagation
3 |
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | import utils
10 | from unet import Conv2dSame
11 |
12 |
13 | class Field2Input(nn.Module):
14 | """Gets complex-valued field and turns it into multi-channel images"""
15 |
16 | def __init__(self, input_res=(800, 1280), coord='rect', latent_amp=None, latent_phase=None, shared_cnn=False):
17 | super(Field2Input, self).__init__()
18 | self.input_res = input_res
19 | self.coord = coord.lower()
20 | self.latent_amp = latent_amp
21 | self.latent_phase = latent_phase
22 | self.shared_cnn = shared_cnn
23 |
24 | def forward(self, input_field):
25 | # If input field is slm phase
26 | if input_field.dtype == torch.float32:
27 | input_field = torch.exp(1j * input_field)
28 |
29 | # 1) Learned phase offset
30 | if self.latent_phase is not None:
31 | input_field = input_field * torch.exp(1j * self.latent_phase)
32 |
33 | # 2) Learned amplitude
34 | if self.latent_amp is not None:
35 | input_field = self.latent_amp * input_field
36 |
37 | input_field = utils.pad_image(input_field, self.input_res, pytorch=True, stacked_complex=False)
38 | input_field = utils.crop_image(input_field, self.input_res, pytorch=True, stacked_complex=False)
39 |
40 | # To use shared CNN, put everything into batch dimension;
41 | if self.shared_cnn:
42 | num_mb, num_dists = input_field.shape[0], input_field.shape[1]
43 | input_field = input_field.reshape(num_mb*num_dists, 1, *input_field.shape[2:])
44 |
45 | # Input format
46 | if self.coord == 'rect':
47 | stacked_input = torch.cat((input_field.real, input_field.imag), 1)
48 | elif self.coord == 'polar':
49 | stacked_input = torch.cat((input_field.abs(), input_field.angle()), 1)
50 | elif self.coord == 'amp':
51 | stacked_input = input_field.abs()
52 | elif 'both' in self.coord:
53 | stacked_input = torch.cat((input_field.abs(), input_field.angle(), input_field.real, input_field.imag), 1)
54 |
55 | return stacked_input
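
The stacking above turns a complex field into a real-valued CNN input. A standalone sketch of the 'rect' and 'polar' conversions on a toy field (no padding/cropping, names are illustrative):

```
import math
import torch

field = torch.exp(1j * torch.rand(1, 1, 8, 8) * 2 * math.pi)   # unit-amplitude toy field

rect = torch.cat((field.real, field.imag), dim=1)               # (1, 2, 8, 8)
polar = torch.cat((field.abs(), field.angle()), dim=1)          # (1, 2, 8, 8)

# both parametrizations carry the same information
reconstructed = polar[:, 0:1] * torch.exp(1j * polar[:, 1:2])
print(torch.allclose(reconstructed, field, atol=1e-5))          # True
```
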
56 |
57 |
58 | class Output2Field(nn.Module):
59 |     """Turns the stacked multi-channel CNN output back into a complex-valued field"""
60 |
61 | def __init__(self, output_res=(800, 1280), coord='rect', num_ch_output=1):
62 | super(Output2Field, self).__init__()
63 | self.output_res = output_res
64 | self.coord = coord.lower()
65 | self.num_ch_output = num_ch_output # number of channels in output
66 |
67 | def forward(self, stacked_output):
68 |
69 | if self.coord in ('rect', 'both'):
70 | complex_valued_field = torch.view_as_complex(stacked_output.unsqueeze(4).
71 | permute(0, 4, 2, 3, 1).contiguous())
72 | elif self.coord == 'polar':
73 | amp = stacked_output[:, 0:1, ...]
74 | phi = stacked_output[:, 1:2, ...]
75 | complex_valued_field = amp * torch.exp(1j * phi)
76 | elif self.coord == 'amp' or '1ch_output' in self.coord:
77 | complex_valued_field = stacked_output * torch.exp(1j * torch.zeros_like(stacked_output))
78 |
79 | output_field = utils.pad_image(complex_valued_field, self.output_res, pytorch=True, stacked_complex=False)
80 | output_field = utils.crop_image(output_field, self.output_res, pytorch=True, stacked_complex=False)
81 |
82 | if self.num_ch_output > 1:
83 | # reshape to original tensor shape
84 | output_field = output_field.reshape(output_field.shape[0] // self.num_ch_output, self.num_ch_output,
85 | *output_field.shape[2:])
86 |
87 | return output_field
88 |
89 |
90 | class Conv2dField(nn.Module):
91 | """Apply 2d conv on amp or field"""
92 |
93 | def __init__(self, complex=False, conv_size=3):
94 | super(Conv2dField, self).__init__()
95 | self.complex = complex # apply convolution on field
96 | self.conv_size = (conv_size, conv_size)
97 | if self.complex:
98 | self.conv_real = Conv2dSame(1, 1, conv_size)
99 | self.conv_imag = Conv2dSame(1, 1, conv_size)
100 | init_weight = torch.zeros(1, 1, *self.conv_size)
101 | init_weight[..., conv_size//2, conv_size//2] = 1.
102 | self.conv_real.net[1].weight = nn.Parameter(init_weight.detach().requires_grad_(True))
103 | self.conv_imag.net[1].weight = nn.Parameter(init_weight.detach().requires_grad_(True))
104 | else:
105 | self.conv = Conv2dSame(1, 1, conv_size, bias=False)
106 | init_weight = torch.zeros(1, 1, *self.conv_size)
107 | init_weight[..., conv_size//2, conv_size//2] = 1.
108 | self.conv.net[1].weight = nn.Parameter(init_weight.requires_grad_(True))
109 |
110 | def forward(self, input_field):
111 | # check if input is light field
112 | if len(input_field.shape) > 4:
113 | lf_batch_size = input_field.shape[0]
114 | num_ch = input_field.shape[1]
115 | num_y = input_field.shape[4]
116 | num_x = input_field.shape[5]
117 | input_field = input_field.permute(0, 4, 5, 1, 2, 3)
118 | input_field = input_field.reshape(lf_batch_size * num_y * num_x, num_ch, *input_field.shape[-2:])
119 | lf = True
120 | else:
121 | lf = False
122 |
123 | # reshape tensor if number of channels > 1
124 | num_ch = input_field.shape[1]
125 | if num_ch > 1:
126 | batch_size = input_field.shape[0]
127 | input_field = input_field.reshape(batch_size * num_ch, 1, *input_field.shape[2:])
128 |
129 | if self.complex:
130 | # apply conv on complex fields
131 | real = self.conv_real(input_field.real) - self.conv_imag(input_field.imag)
132 | imag = self.conv_real(input_field.imag) + self.conv_imag(input_field.real)
133 | output_field = torch.view_as_complex(torch.stack((real, imag), -1))
134 | else:
135 | # apply conv on intensity
136 | output_amp = self.conv(input_field.abs()**2).abs().mean(dim=1, keepdims=True).sqrt()
137 | output_field = output_amp * torch.exp(1j * input_field.angle())
138 |
139 | # reshape to original tensor shape
140 | if num_ch > 1:
141 | output_field = output_field.reshape(batch_size, num_ch, *output_field.shape[2:])
142 |
143 | if lf:
144 | output_field = output_field.reshape(lf_batch_size, num_y, num_x, num_ch, *output_field.shape[-2:])
145 | output_field = output_field.permute(0, 3, 4, 5, 1, 2)
146 |
147 | return output_field
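
The complex branch above applies a complex kernel Kr + i·Ki through two real convolutions: real = Kr∗xr − Ki∗xi and imag = Kr∗xi + Ki∗xr. A standalone sketch with hypothetical kernels (not the module's parameters) showing that a delta real kernel paired with a zero imaginary kernel acts as the identity:

```
import torch
import torch.nn.functional as F

kr = torch.zeros(1, 1, 3, 3); kr[..., 1, 1] = 1.    # delta kernel for the real part
ki = torch.zeros(1, 1, 3, 3)                        # zero kernel for the imaginary part

field = torch.randn(1, 1, 8, 8) + 1j * torch.randn(1, 1, 8, 8)

real = F.conv2d(field.real, kr, padding=1) - F.conv2d(field.imag, ki, padding=1)
imag = F.conv2d(field.real, ki, padding=1) + F.conv2d(field.imag, kr, padding=1)
print(torch.allclose(torch.complex(real, imag), field))   # True: this kernel pair is the identity
```
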
148 |
149 |
150 | class LatentCodedMLP(nn.Module):
151 |     """
152 |     1x1-conv MLP that also concatenates slices of a latent code in the middle of the forward pass.
153 |     Pass the latent codes, shaped (1, L, H, W), as an argument to forward().
154 |     num_latent_codes: list with the number of latent-code slices consumed by each layer
155 |     * the sum of num_latent_codes should equal the total number of latent-code channels L
156 |     """
157 | def __init__(self, num_layers=5, num_features=32, norm=None, num_latent_codes=None):
158 | super(LatentCodedMLP, self).__init__()
159 |
160 | if num_latent_codes is None:
161 | num_latent_codes = [0] * num_layers
162 |
163 | assert len(num_latent_codes) == num_layers
164 |
165 | self.num_latent_codes = num_latent_codes
166 | self.idxs = [sum(num_latent_codes[:y]) for y in range(num_layers + 1)]
167 | self.nets = nn.ModuleList([])
168 | num_features = [num_features] * num_layers
169 | num_features[0] = 1
170 |
171 | # define each layer
172 | for i in range(num_layers - 1):
173 | net = [nn.Conv2d(num_features[i] + num_latent_codes[i], num_features[i + 1], kernel_size=1)]
174 | if norm is not None:
175 | net += [norm(num_groups=4, num_channels=num_features[i + 1], affine=True)]
176 | net += [nn.LeakyReLU(0.2, True)]
177 | self.nets.append(nn.Sequential(*net))
178 |
179 | self.nets.append(nn.Conv2d(num_features[-1] + num_latent_codes[-1], 1, kernel_size=1))
180 |
181 | for m in self.modules():
182 | if isinstance(m, nn.Conv2d):
183 | nn.init.normal_(m.weight, std=0.05)
184 |
185 | def forward(self, phases, latent_codes=None):
186 |
187 | after_relu = phases
188 | # concatenate latent codes at each layer and send through the convolutional layers
189 | for i in range(len(self.num_latent_codes)):
190 | if latent_codes is not None:
191 | latent_codes_b = latent_codes.repeat(phases.shape[0], 1, 1, 1)
192 | after_relu = torch.cat((after_relu, latent_codes_b[:, self.idxs[i]:self.idxs[i + 1], ...]), 1)
193 | after_relu = self.nets[i](after_relu)
194 |
195 | # residual connection
196 | return phases - after_relu
197 |
198 |
199 | class ContentDependentField(nn.Module):
200 | def __init__(self, num_layers=5, num_features=32, norm=nn.GroupNorm, latent_coords=False):
201 |         """ Simple CNN (default 5 layers) modeling content-dependent undiffracted light """
202 |
203 | super(ContentDependentField, self).__init__()
204 |
205 | if not latent_coords:
206 | first_ch = 1
207 | else:
208 | first_ch = 3
209 |
210 | net = [Conv2dSame(first_ch, num_features, kernel_size=3)]
211 |
212 | for i in range(num_layers - 2):
213 | if norm is not None:
214 | net += [norm(num_groups=2, num_channels=num_features, affine=True)]
215 | net += [nn.LeakyReLU(0.2, True),
216 | Conv2dSame(num_features, num_features, kernel_size=3)]
217 |
218 | if norm is not None:
219 | net += [norm(num_groups=4, num_channels=num_features, affine=True)]
220 |
221 | net += [nn.LeakyReLU(0.2, True),
222 | Conv2dSame(num_features, 2, kernel_size=3)]
223 |
224 | self.net = nn.Sequential(*net)
225 |
226 | def forward(self, phases, latent_coords=None):
227 | if latent_coords is not None:
228 | input_cnn = torch.cat((phases, latent_coords), dim=1)
229 | else:
230 | input_cnn = phases
231 |
232 | return self.net(input_cnn)
233 |
234 |
235 | class ProcessPhase(nn.Module):
236 | def __init__(self, num_layers=5, num_features=32, num_output_feat=0, norm=nn.BatchNorm2d, num_latent_codes=0):
237 | super(ProcessPhase, self).__init__()
238 |
239 | # avoid zero
240 | self.num_output_feat = max(num_output_feat, 1)
241 | self.num_latent_codes = num_latent_codes
242 |
243 | # a bunch of 1x1 conv layers, set by num_layers
244 | net = [nn.Conv2d(1 + num_latent_codes, num_features, kernel_size=1)]
245 |
246 | for i in range(num_layers - 2):
247 | if norm is not None:
248 | net += [norm(num_groups=2, num_channels=num_features, affine=True)]
249 | net += [nn.LeakyReLU(0.2, True),
250 | nn.Conv2d(num_features, num_features, kernel_size=1)]
251 |
252 | if norm is not None:
253 | net += [norm(num_groups=2, num_channels=num_features, affine=True)]
254 |
255 | net += [nn.ReLU(True),
256 | nn.Conv2d(num_features, self.num_output_feat, kernel_size=1)]
257 |
258 | self.net = nn.Sequential(*net)
259 |
260 | def forward(self, phases):
261 | return phases - self.net(phases)
262 |
263 |
264 | class SourceAmplitude(nn.Module):
265 | def __init__(self, num_gaussians=3, init_sigma=None, init_amp=0.7, x_s0=0.0, y_s0=0.0):
266 | super(SourceAmplitude, self).__init__()
267 |
268 | self.num_gaussians = num_gaussians
269 |
270 | if init_sigma is None:
271 | init_sigma = [100.] * self.num_gaussians # default to 100 for all
272 |
273 | # create parameters for source amplitudes
274 | self.sigmas = nn.Parameter(torch.tensor(init_sigma))
275 | self.x_s = nn.Parameter(torch.ones(num_gaussians) * x_s0)
276 | self.y_s = nn.Parameter(torch.ones(num_gaussians) * y_s0)
277 | self.amplitudes = nn.Parameter(torch.ones(num_gaussians) / (num_gaussians) * init_amp)
278 | self.dc_term = nn.Parameter(torch.zeros(1))
279 |
280 | self.x_dim = None
281 | self.y_dim = None
282 |
283 | def forward(self, phases):
284 | # create DC term, then add the gaussians
285 | source_amp = torch.ones_like(phases) * self.dc_term
286 | for i in range(self.num_gaussians):
287 | source_amp += self.create_gaussian(phases.shape, i)
288 |
289 | return source_amp
290 |
291 | def create_gaussian(self, shape, idx):
292 | # create sampling grid if needed
293 | if self.x_dim is None or self.y_dim is None:
294 | self.x_dim = torch.linspace(-(shape[-1] - 1) / 2,
295 | (shape[-1] - 1) / 2,
296 | shape[-1], device=self.dc_term.device)
297 | self.y_dim = torch.linspace(-(shape[-2] - 1) / 2,
298 | (shape[-2] - 1) / 2,
299 | shape[-2], device=self.dc_term.device)
300 |
301 |         if self.x_dim.device != self.sigmas.device:
302 |             self.x_dim = self.x_dim.to(self.sigmas.device).detach()  # .to() is not in-place; reassign
303 |             self.x_dim.requires_grad = False
304 |         if self.y_dim.device != self.sigmas.device:
305 |             self.y_dim = self.y_dim.to(self.sigmas.device).detach()
306 |             self.y_dim.requires_grad = False
307 |
308 | # offset grid by coordinate and compute x and y gaussian components
309 | x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.x_dim - self.x_s[idx], self.sigmas[idx]), 2))
310 | y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.y_dim - self.y_s[idx], self.sigmas[idx]), 2))
311 |
312 | # outer product with amplitude scaling
313 | gaussian = torch.ger(self.amplitudes[idx] * y_gaussian, x_gaussian)
314 |
315 | return gaussian
316 |
317 |
318 | class FiniteDiffField(nn.Module):
319 | def __init__(self):
320 | super(FiniteDiffField, self).__init__()
321 | pass
322 |
323 | def forward(self, model, slm_phase, delta_phase):
324 | # delta phase is the phase difference to be added to the input phase.
325 | # Can sample some SLM locations each iteration
326 |
327 | field_1 = model(slm_phase)
328 | field_2 = model(slm_phase + delta_phase)
329 | # size H*W, which is the ith column of Jacobian df/d(phi)
330 | delta_field = (field_2 - field_1) / delta_phase
331 | return delta_field
332 |
333 |
334 | def make_kernel_gaussian(sigma, kernel_size):
335 |
336 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
337 | x_cord = torch.arange(kernel_size)
338 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
339 | y_grid = x_grid.t()
340 | xy_grid = torch.stack([x_grid, y_grid], dim=-1)
341 |
342 | mean = (kernel_size - 1) / 2
343 | variance = sigma**2
344 |
345 | # Calculate the 2-dimensional gaussian kernel which is
346 | # the product of two gaussian distributions for two different
347 | # variables (in this case called x and y)
348 | gaussian_kernel = ((1 / (2 * math.pi * variance))
349 | * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1)
350 | / (2 * variance)))
351 | # Make sure sum of values in gaussian kernel equals 1.
352 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
353 |
354 | # Reshape to 2d depthwise convolutional weight
355 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
356 |
357 | return gaussian_kernel
358 |
359 |
360 | def create_gaussian(shape, sigma=800, dev=torch.device('cuda')):
361 | # create sampling grid if needed
362 | shape_min = min(shape[-1], shape[-2])
363 | x_dim = torch.linspace(-(shape_min - 1) / 2,
364 | (shape_min - 1) / 2,
365 | shape[-1], device=dev)
366 | y_dim = torch.linspace(-(shape_min - 1) / 2,
367 | (shape_min - 1) / 2,
368 | shape[-2], device=dev)
369 |
370 | # offset grid by coordinate and compute x and y gaussian components
371 | x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(x_dim, sigma), 2))
372 | y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(y_dim, sigma), 2))
373 |
374 | # outer product with amplitude scaling
375 | gaussian = torch.ger(y_gaussian, x_gaussian)
376 |
377 | return gaussian
378 |
379 |
380 |
381 |
382 |
--------------------------------------------------------------------------------
/props/prop_zernike.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for zernike basis
3 |
4 | """
5 |
6 | import torch
7 | import numpy as np
8 | import utils
9 | import torch.fft
10 | from aotools.functions import zernikeArray
11 |
12 |
13 | def combine_zernike_basis(coeffs, basis, return_phase=False):
14 | """
15 | Multiplies the Zernike coefficients and basis functions while preserving
16 | dimensions
17 |
18 | :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike
19 | :param basis: the output of compute_zernike_basis, must be same length as coeffs
20 | :param return_phase:
21 | :return: A float32 tensor that combines coeffs and basis.
22 | """
23 |
24 | if len(coeffs.shape) < 3:
25 | coeffs = torch.reshape(coeffs, (coeffs.shape[0], 1, 1))
26 |
27 | # combine zernike basis and coefficients
28 | zernike = (coeffs * basis).sum(0, keepdim=True)
29 |
30 |     # shape to [1, 1, H, W]
31 | zernike = zernike.unsqueeze(0)
32 |
33 | return zernike
34 |
35 |
36 | def compute_zernike_basis(num_polynomials, field_res, dtype=torch.float32, wo_piston=False):
37 | """Computes a set of Zernike basis function with resolution field_res
38 |
39 | num_polynomials: number of Zernike polynomials in this basis
40 | field_res: [height, width] in px, any list-like object
41 | dtype: torch dtype for computation at different precision
42 | """
43 |
44 | # size the zernike basis to avoid circular masking
45 | zernike_diam = int(np.ceil(np.sqrt(field_res[0]**2 + field_res[1]**2)))
46 |
47 | # create zernike functions
48 |
49 | if not wo_piston:
50 | zernike = zernikeArray(num_polynomials, zernike_diam)
51 |     else:  # 200427 - exclude the piston term
52 | idxs = range(2, 2 + num_polynomials)
53 | zernike = zernikeArray(idxs, zernike_diam)
54 |
55 | zernike = utils.crop_image(zernike, field_res, pytorch=False)
56 |
57 | # convert to tensor and create phase
58 | zernike = torch.tensor(zernike, dtype=dtype, requires_grad=False)
59 |
60 | return zernike
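
A short usage sketch: build a Zernike basis at the field resolution, pick some coefficients, and turn them into a phase aberration to multiply onto a field. This assumes the repository root is on the Python path and the project environment (including aotools) is installed; the coefficient values are illustrative.

```
import torch
from props.prop_zernike import compute_zernike_basis, combine_zernike_basis

basis = compute_zernike_basis(num_polynomials=12, field_res=(200, 320))   # (12, 200, 320)
coeffs = torch.zeros(12)
coeffs[3] = 0.5                                        # e.g. some defocus (Noll index 4)
zernike_phase = combine_zernike_basis(coeffs, basis)   # (1, 1, 200, 320)
aberration = torch.exp(1j * zernike_phase)             # multiply onto a complex field
print(zernike_phase.shape)
```
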
61 |
--------------------------------------------------------------------------------
/quantization.py:
--------------------------------------------------------------------------------
1 | """
2 | Quantization modules using projected gradient-descent, surrogate gradients, and Gumbel-Softmax.
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 | """
6 |
7 | import math
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import numpy as np
12 | from PIL import Image
13 |
14 | import utils
15 | import hw.ti as ti
16 | from hw.discrete_slm import DiscreteSLM
17 |
18 |
19 | def load_lut(sim_prop, opt):
20 | lut = None
21 | if hasattr(sim_prop, 'lut'):
22 | if sim_prop.lut is not None:
23 | lut = sim_prop.lut.squeeze().cpu().detach().numpy().tolist()
24 | else:
25 |         # the model has no learned LUT, so fall back to the given 17-level
26 |         # TI SLM LUT whenever a quantization method is requested
27 | lut = ti.given_lut
28 | if opt.channel is not None:
29 | lut = np.array(lut) * opt.wavelengths[1] / opt.wavelengths[opt.channel]
30 | print("given lut...")
31 |
32 | # TODO: work to remove this line
33 | if lut is not None and len(lut) % 2 == 0:
34 | lut.append(lut[0] + 2 * math.pi) # for lut_mid
35 |
36 | print(f'LUT: {lut}')
37 | return lut
38 |
39 |
40 | def tau_iter(quan_fn, iter_frac, tau_min, tau_max, r=None):
41 | if 'softmax' in quan_fn:
42 | if r is None:
43 | r = math.log(tau_max / tau_min)
44 | tau = max(tau_min, tau_max * math.exp(-r * iter_frac))
45 | elif 'sigmoid' in quan_fn or 'poly' in quan_fn:
46 | tau = 1 + 10 * iter_frac
47 | else:
48 | tau = None
49 | return tau
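
For the softmax-based methods, tau_iter decays the temperature exponentially from tau_max to tau_min as optimization progresses (iter_frac going 0 → 1); for the sigmoid/poly surrogates it instead ramps the steepness linearly. A quick sketch of the default exponential schedule with illustrative values:

```
import math

tau_min, tau_max = 0.5, 3.0
r = math.log(tau_max / tau_min)                      # default decay rate
for iter_frac in (0.0, 0.25, 0.5, 0.75, 1.0):
    tau = max(tau_min, tau_max * math.exp(-r * iter_frac))
    print(f"{iter_frac:.2f} -> tau = {tau:.3f}")     # 3.000, 1.917, 1.225, 0.783, 0.500
```
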
50 |
51 |
52 | def quantization(opt, lut):
53 | if opt.quan_method == 'None':
54 | qtz = None
55 | else:
56 | qtz = Quantization(opt.quan_method, lut=lut, c=opt.c_s, num_bits=opt.uniform_nbits if lut is None else 4,
57 | tau_max=opt.tau_max, tau_min=opt.tau_min, r=opt.r, offset=opt.phase_offset)
58 |
59 | return qtz
60 |
61 |
62 | def score_phase(phase, lut, s=5., func='sigmoid'):
63 | # Here s is kinda representing the steepness
64 |
65 | wrapped_phase = (phase + math.pi) % (2 * math.pi) - math.pi
66 |
67 | diff = wrapped_phase - lut
68 | diff = (diff + math.pi) % (2*math.pi) - math.pi # signed angular difference
69 | diff /= math.pi # normalize
70 |
71 | if func == 'sigmoid':
72 | z = s * diff
73 | scores = torch.sigmoid(z) * (1 - torch.sigmoid(z)) * 4
74 | elif func == 'log':
75 | scores = -torch.log(diff.abs() + 1e-20) * s
76 | elif func == 'poly':
77 | scores = (1-torch.abs(diff)**s)
78 | elif func == 'sine':
79 | scores = torch.cos(math.pi * (s * diff).clamp(-1., 1.))
80 | elif func == 'chirp':
81 | scores = 1 - torch.cos(math.pi * (1-diff.abs())**s)
82 |
83 | return scores
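
score_phase converts the wrapped distance between each pixel's phase and every LUT level into a per-level score, with levels near the phase scoring highest. A standalone sketch of the 'sigmoid' scoring with a toy 4-level LUT (illustrative values):

```
import math
import torch

lut = torch.tensor([-math.pi, -math.pi / 2, 0., math.pi / 2]).reshape(1, 4, 1, 1)
phase = torch.tensor([[[[0.1]]]])                         # one pixel, close to the 0-rad level

diff = ((phase - lut) + math.pi) % (2 * math.pi) - math.pi
diff = diff / math.pi                                     # normalized signed distance
s = 5.0
scores = torch.sigmoid(s * diff) * (1 - torch.sigmoid(s * diff)) * 4
print(scores.squeeze())                                   # highest score for the 0-rad LUT level
```
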
84 |
85 |
86 | # Basic function for NN-based quantization, customize it with various surrogate gradients!
87 | class NearestNeighborSearch(torch.autograd.Function):
88 |
89 | @staticmethod
90 | def forward(ctx, phase, s=torch.tensor(1.0)):
91 | phase_raw = phase.detach()
92 | idx = utils.nearest_idx(phase_raw, DiscreteSLM.lut_midvals)
93 | phase_q = DiscreteSLM.lut[idx]
94 | ctx.mark_non_differentiable(idx)
95 | ctx.save_for_backward(phase_raw, s, phase_q, idx)
96 | return phase_q
97 |
98 | def backward(ctx, grad_output):
99 | return grad_output, None
100 |
101 |
102 | class NearestNeighborPolyGrad(NearestNeighborSearch):
103 |
104 | @staticmethod
105 | def forward(ctx, phase, s=torch.tensor(1.0)):
106 | return NearestNeighborSearch.forward(ctx, phase, s)
107 |
108 | def backward(ctx, grad_output):
109 | input, s, output, idx = ctx.saved_tensors
110 | grad_input = grad_output.clone()
111 |
112 | dx = input - output
113 | d_idx = (dx / torch.abs(dx)).int().nan_to_num()
114 | other_end = DiscreteSLM.lut[(idx + d_idx)].to(input.device) # far end not selected for quantization
115 |
116 | # normalization
117 | mid_point = (other_end + output) / 2
118 | gap = torch.abs(other_end - output) + 1e-20
119 |         z = (input - mid_point) / gap * 2  # normalize to [-1, 1]
120 |
121 | dout_din = (0.5 * s * (1 - abs(z)) ** (s - 1)).nan_to_num()
122 | scale = 2. #* dout_din.mean() / ((dout_din**2).mean() + 1e-20)
123 | grad_input *= (dout_din * scale) # scale according to distance
124 |
125 | return grad_input, None
126 |
127 |
128 | class NearestNeighborSigmoidGrad(NearestNeighborSearch):
129 |
130 | @staticmethod
131 | def forward(ctx, phase, s=torch.tensor(1.0)):
132 | return NearestNeighborSearch.forward(ctx, phase, s)
133 |
134 | def backward(ctx, grad_output):
135 | x, s, output, idx = ctx.saved_tensors
136 | grad_input = grad_output.clone()
137 |
138 | dx = x - output
139 | d_idx = (dx / torch.abs(dx)).int().nan_to_num()
140 | other_end = DiscreteSLM.lut[(idx + d_idx)].to(x.device) # far end not selected for quantization
141 |
142 | # normalization
143 | mid_point = (other_end + output) / 2
144 | gap = torch.abs(other_end - output) + 1e-20
145 | z = (x - mid_point) / gap * 2 # normalize to [-1, 1]
146 | z *= s
147 |
148 | dout_din = (torch.sigmoid(z) * (1 - torch.sigmoid(z)))
149 | scale = 4. * s#1 / 0.462 * gap * s#dout_din.mean() / ((dout_din**2).mean() + 1e-20) # =100
150 | grad_input *= (dout_din * scale)
151 |
152 | return grad_input, None
153 |
154 |
155 | nns = NearestNeighborSearch.apply
156 | nns_poly = NearestNeighborPolyGrad.apply
157 | nns_sigmoid = NearestNeighborSigmoidGrad.apply
158 |
159 |
160 | class SoftmaxBasedQuantization(nn.Module):
161 | def __init__(self, lut, gumbel=True, tau_max=3.0, c=300.):
162 | super(SoftmaxBasedQuantization, self).__init__()
163 |
164 | if not torch.is_tensor(lut):
165 | self.lut = torch.tensor(lut, dtype=torch.float32)
166 | else:
167 | self.lut = lut
168 | self.lut = self.lut.reshape(1, len(lut), 1, 1)
169 | self.c = c # boost the score
170 | self.gumbel = gumbel
171 | self.tau_max = tau_max
172 |
173 | def forward(self, phase, tau=1.0, hard=False):
174 | phase_wrapped = (phase + math.pi) % (2*math.pi) - math.pi
175 |
176 | # phase to score
177 | scores = score_phase(phase_wrapped, self.lut.to(phase_wrapped.device), (self.tau_max / tau)**1) * self.c * (self.tau_max / tau)**1.0
178 |
179 | # score to one-hot encoding
180 | if self.gumbel: # (N, 1, H, W) -> (N, C, H, W)
181 | one_hot = F.gumbel_softmax(scores, tau=tau, hard=hard, dim=1)
182 | else:
183 | y_soft = F.softmax(scores/tau, dim=1)
184 | index = y_soft.max(1, keepdim=True)[1]
185 | one_hot_hard = torch.zeros_like(scores,
186 | memory_format=torch.legacy_contiguous_format).scatter_(1, index, 1.0)
187 | if hard:
188 | one_hot = one_hot_hard + y_soft - y_soft.detach()
189 | else:
190 | one_hot = y_soft
191 |
192 | # one-hot encoding to phase value
193 | q_phase = (one_hot * self.lut.to(one_hot.device))
194 | q_phase = q_phase.sum(1, keepdims=True)
195 | return q_phase
196 |
197 |
198 | class Quantization(nn.Module):
199 | def __init__(self, method=None, num_bits=4, lut=None, dev=torch.device('cuda'),
200 | tau_min=0.5, tau_max=3.0, r=None, c=300., offset=0.0):
201 | super(Quantization, self).__init__()
202 | if lut is None:
203 | # linear look-up table
204 | DiscreteSLM.lut = torch.linspace(-math.pi, math.pi, 2**num_bits + 1).to(dev)
205 | else:
206 | # non-linear look-up table
207 | assert len(lut) == (2**num_bits) + 1
208 | DiscreteSLM.lut = torch.tensor(lut, dtype=torch.float32).to(dev)
209 |
210 | self.quan_fn = None
211 | self.gumbel = 'gumbel' in method.lower()
212 | if method.lower() == 'nn':
213 | self.quan_fn = nns
214 | elif method.lower() == 'nn_sigmoid':
215 | self.quan_fn = nns_sigmoid
216 | elif method.lower() == 'nn_poly':
217 | self.quan_fn = nns_poly
218 | elif 'softmax' in method.lower():
219 | self.quan_fn = SoftmaxBasedQuantization(DiscreteSLM.lut[:-1], self.gumbel, tau_max=tau_max, c=c)
220 |
221 | self.method = method
222 | self.tau_min = tau_min
223 | self.tau_max = tau_max
224 | self.r = r
225 | self.offset = offset
226 |
227 | def forward(self, input_phase, iter_frac=None, hard=True):
228 |         # tau stays None when no schedule progress (iter_frac) is given, so it is always defined below
229 |         tau = tau_iter(self.method, iter_frac, self.tau_min, self.tau_max, self.r) if iter_frac is not None else None
230 | wrapped_phase = (input_phase + self.offset + math.pi) % (2 * math.pi) - math.pi
231 | if self.quan_fn is None:
232 | return wrapped_phase
233 | else:
234 | if isinstance(tau, float):
235 | tau = torch.tensor(tau, dtype=torch.float32).to(input_phase.device)
236 | if 'nn' in self.method.lower():
237 | s = tau
238 | return self.quan_fn(wrapped_phase, s)
239 | else:
240 | return self.quan_fn(wrapped_phase, tau, hard)
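
A hedged usage sketch of the Quantization wrapper with a linear 4-bit LUT on CPU. The method string below is hypothetical (the constructor only checks for the substrings 'gumbel' and 'softmax'), and the import assumes you run from the repository root inside the project environment; the real pipeline builds this object via quantization(opt, lut).

```
import torch
from quantization import Quantization

q = Quantization(method='gumbel-softmax', num_bits=4, lut=None,
                 dev=torch.device('cpu'), tau_min=0.5, tau_max=3.0)

phase = torch.empty(1, 1, 8, 8).uniform_(-3.14, 3.14)    # toy SLM phase
q_phase = q(phase, iter_frac=0.5, hard=True)             # tau is annealed from iter_frac
print(q_phase.unique().numel(), "distinct levels")       # at most 16
```
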
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for model training
3 |
4 | Any questions about the code can be addressed to Suyeon Choi (suyeon@stanford.edu)
5 |
6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 |
11 | Technical Paper:
12 | Time-multiplexed Neural Holography:
13 | A Flexible Framework for Holographic Near-eye Displays with Fast Heavily-quantized Spatial Light Modulators
14 | S. Choi*, M. Gopakumar*, Y. Peng, J. Kim, Matthew O'Toole, G. Wetzstein.
15 | SIGGRAPH 2022
16 | """
17 | import os
18 | import configargparse
19 | import pytorch_lightning as pl
20 | from pytorch_lightning import Trainer
21 | from torch.utils.data import DataLoader
22 |
23 | import utils
24 | import params
25 | import props.prop_model as prop_model
26 | import image_loader as loaders
27 | import torch
28 | import os
29 |
30 |
31 | # Command line argument processing
32 | p = configargparse.ArgumentParser()
33 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
34 | p.add('--capture_subset', type=str, default=None)
35 |
36 | params.add_parameters(p, 'train')
37 | opt = params.set_configs(p.parse_args())
38 | run_id = params.run_id_training(opt)
39 |
40 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
41 | os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_id)
42 |
43 | if opt.gpu_id > 0:
44 | # torch.cuda.set_device(opt.gpu_id)
45 | print(f"Using gpu {opt.gpu_id} ...")
46 |
47 | def main():
48 | if ',' in opt.data_path:
49 | opt.data_path = opt.data_path.split(',')
50 | else:
51 | opt.data_path = [opt.data_path]
52 | print(f' - training a model ... Dataset path:{opt.data_path}')
53 |     # Set up dataloaders
54 | num_workers = 4
55 | # modify plane idxes!
56 | train_loader = DataLoader(loaders.PairsLoader([os.path.join(path, 'train') for path in opt.data_path],
57 | plane_idxs=opt.plane_idxs['train'], image_res=opt.image_res,
58 | avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type,
59 | capture_subset=opt.capture_subset, dataset_subset=opt.dataset_subset),
60 | num_workers=num_workers, batch_size=opt.batch_size, pin_memory=True)
61 | val_loader = DataLoader(loaders.PairsLoader([os.path.join(path, 'val') for path in opt.data_path],
62 | plane_idxs=opt.plane_idxs['train'], image_res=opt.image_res,
63 | shuffle=False, avg_energy_ratio=opt.avg_energy_ratio,
64 | slm_type=opt.slm_type, capture_subset=opt.capture_subset),
65 | num_workers=num_workers, batch_size=opt.batch_size, shuffle=False, pin_memory=True)
66 | test_loader = DataLoader(loaders.PairsLoader([os.path.join(path, 'test') for path in opt.data_path],
67 | plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res,
68 | shuffle=False, avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type),
69 | num_workers=num_workers, batch_size=opt.batch_size, shuffle=False, pin_memory=True)
70 |
71 | # Init model
72 | if opt.slm_type == 'ti':
73 |         opt.roi_res = (760, 1240)  # modify here! should it be (700, 1190)?
74 | else:
75 | opt.roi_res = (840, 1200)
76 | model = prop_model.model(opt)
77 | model.train()
78 |
79 | # Init root path
80 | root_dir = os.path.join(opt.out_path, run_id)
81 | utils.cond_mkdir(root_dir)
82 | p.write_config_file(opt, [os.path.join(root_dir, 'config.txt')])
83 |
84 | psnr_checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="PSNR_validation_epoch", dirpath=root_dir,
85 | filename="model-{epoch:02d}-{PSNR_validation_epoch:.2f}",
86 | save_top_k=1, mode="max", )
87 | latest_checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="PSNR_validation_epoch", dirpath=root_dir,
88 | filename="model-latest-{PSNR_validation_epoch:.2f}",
89 | every_n_epochs=1, save_last=True)
90 |
91 | # Init trainer
92 | trainer = Trainer(default_root_dir=root_dir, accelerator='gpu',
93 | log_every_n_steps=400, gpus=1, max_epochs=opt.num_epochs, callbacks=[psnr_checkpoint_callback, latest_checkpoint_callback])
94 |
95 | # Fit Model
96 | trainer.fit(model, train_loader, val_loader)
97 |
98 | # Test Model
99 | trainer.test(model, dataloaders=test_loader)
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
1 | """
2 | U-net implementations
3 |
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import init
9 | import functools
10 |
11 |
12 | def norm_layer(norm_str):
13 | if norm_str.lower() == 'instance':
14 | return nn.InstanceNorm2d
15 | elif norm_str.lower() == 'group':
16 | return nn.GroupNorm
17 | elif norm_str.lower() == 'batch':
18 | return nn.BatchNorm2d
19 |
20 |
21 | class UnetSkipConnectionBlock(nn.Module):
22 | """Defines the Unet submodule with skip connection.
23 | X -------------------identity----------------------
24 | |-- downsampling -- |submodule| -- upsampling --|
25 | """
26 |
27 | def __init__(self, outer_nc, inner_nc, input_nc=None,
28 | submodule=None, outermost=False, innermost=False,
29 | norm_layer=nn.InstanceNorm2d, use_dropout=False,
30 | outer_skip=False):
31 | """Construct a Unet submodule with skip connections.
32 | Parameters:
33 | outer_nc (int) -- the number of filters in the outer conv layer
34 | inner_nc (int) -- the number of filters in the inner conv layer
35 | input_nc (int) -- the number of channels in input images/features
36 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
37 | outermost (bool) -- if this module is the outermost module
38 | innermost (bool) -- if this module is the innermost module
39 | norm_layer -- normalization layer
40 | use_dropout (bool) -- if use dropout layers.
41 | """
42 | super(UnetSkipConnectionBlock, self).__init__()
43 | self.outermost = outermost
44 | self.outer_skip = outer_skip
45 |         if norm_layer is None:
46 | use_bias = True
47 | elif type(norm_layer) == functools.partial:
48 | use_bias = norm_layer.func == nn.InstanceNorm2d
49 | else:
50 | use_bias = norm_layer == nn.InstanceNorm2d
51 | if input_nc is None:
52 | input_nc = outer_nc
53 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=5,
54 |                              # kernel size changed from 4 to 5 and padding from 1 to 2
55 | stride=2, padding=2, bias=use_bias)
56 | downrelu = nn.LeakyReLU(0.2, True)
57 | if norm_layer is not None:
58 | if norm_layer == nn.GroupNorm:
59 | downnorm = norm_layer(8, inner_nc)
60 | else:
61 | downnorm = norm_layer(inner_nc)
62 | else:
63 | downnorm = None
64 | uprelu = nn.ReLU(True)
65 | if norm_layer is not None:
66 | if norm_layer == nn.GroupNorm:
67 | upnorm = norm_layer(8, outer_nc)
68 | else:
69 | upnorm = norm_layer(outer_nc)
70 | else:
71 | upnorm = None
72 |
73 | if outermost:
74 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
75 | kernel_size=4, stride=2,
76 | padding=1)
77 | down = [downconv, downrelu]
78 | up = [upconv] # Removed tanh and uprelu
79 | model = down + [submodule] + up
80 | elif innermost:
81 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
82 | kernel_size=4, stride=2,
83 | padding=1, bias=use_bias)
84 | if norm_layer is not None:
85 | down = [downconv, downnorm, downrelu]
86 | up = [upconv, upnorm, uprelu]
87 | else:
88 | down = [downconv, downrelu]
89 | up = [upconv, uprelu]
90 |
91 | model = down + up
92 | else:
93 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
94 | kernel_size=4, stride=2,
95 | padding=1, bias=use_bias)
96 | if norm_layer is not None:
97 | down = [downconv, downnorm, downrelu]
98 | up = [upconv, upnorm, uprelu]
99 | else:
100 | down = [downconv, downrelu]
101 | up = [upconv, uprelu]
102 |
103 | if use_dropout:
104 | model = down + [submodule] + up + [nn.Dropout(0.5)]
105 | else:
106 | model = down + [submodule] + up
107 |
108 | self.model = nn.Sequential(*model)
109 |
110 | def forward(self, x):
111 | if self.outermost and not self.outer_skip:
112 | return self.model(x)
113 | else: # add skip connections
114 | return torch.cat([x, self.model(x)], 1)
115 |
116 | def init_latent(latent_num, wavefront_res, ones=False):
117 | if latent_num > 0:
118 | if ones:
119 | latent = nn.Parameter(torch.ones(1, latent_num, *wavefront_res,
120 | requires_grad=True))
121 | else:
122 | latent = nn.Parameter(torch.zeros(1, latent_num, *wavefront_res,
123 | requires_grad=True))
124 | else:
125 | latent = None
126 | return latent
127 |
128 |
129 | def apply_net(net, input, latent_code, complex=False):
130 | if net is None:
131 | return input
132 | if complex: # Only valid for single batch or single channel complex inputs and outputs
133 | multi_channel = (input.shape[1] > 1)
134 | if multi_channel:
135 | input = torch.view_as_real(input[0,...])
136 | else:
137 | input = torch.view_as_real(input[:,0,...])
138 | input = input.permute(0,3,1,2)
139 | if latent_code is not None:
140 | input = torch.cat((input, latent_code), dim=1)
141 | output = net(input)
142 | if complex:
143 | if multi_channel:
144 | output = output.permute(0,2,3,1).unsqueeze(0)
145 | else:
146 | output = output.permute(0,2,3,1).unsqueeze(1)
147 | output = torch.complex(output[...,0], output[...,1])
148 | return output
149 |
150 | def init_weights(net, init_type='normal', init_gain=0.02, outer_skip=False):
151 | """Initialize network weights.
152 | Parameters:
153 | net (network) -- network to be initialized
154 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
155 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
156 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
157 | work better for some applications. Feel free to try yourself.
158 | """
159 |
160 | def init_func(m): # define the initialization function
161 | classname = m.__class__.__name__
162 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
163 | if init_type == 'normal':
164 | init.normal_(m.weight.data, 0.0, init_gain)
165 | elif init_type == 'xavier':
166 | init.xavier_normal_(m.weight.data, gain=init_gain)
167 | elif init_type == 'kaiming':
168 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
169 | elif init_type == 'orthogonal':
170 | init.orthogonal_(m.weight.data, gain=init_gain)
171 | else:
172 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
173 | if hasattr(m, 'bias') and m.bias is not None:
174 | init.constant_(m.bias.data, 0.0)
175 | elif classname.find(
176 | 'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
177 | init.normal_(m.weight.data, 1.0, init_gain)
178 | init.constant_(m.bias.data, 0.0)
179 |
180 | print('initialize network with %s' % init_type)
181 | net.apply(init_func) # apply the initialization function
182 |
183 |
184 | class UnetGenerator(nn.Module):
185 | """Create a Unet-based generator"""
186 |
187 | def __init__(self, input_nc=1, output_nc=1, num_downs=8, nf0=32, max_channels=512,
188 | norm_layer=nn.InstanceNorm2d, use_dropout=False, outer_skip=True,
189 | half_channels=False, eighth_channels=False):
190 | """Construct a Unet generator
191 | Parameters:
192 | input_nc (int) -- the number of channels in input images
193 | output_nc (int) -- the number of channels in output images
194 |             num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
195 |                                 an image of size 128x128 will become 1x1 at the bottleneck
196 |             nf0 (int)       -- the number of filters in the last conv layer
197 | norm_layer -- normalization layer
198 | We construct the U-Net from the innermost layer to the outermost layer.
199 | It is a recursive process.
200 | """
201 | super(UnetGenerator, self).__init__()
202 | self.outer_skip = outer_skip
203 | self.input_nc = input_nc
204 |
205 | if eighth_channels:
206 | divisor = 8
207 | elif half_channels:
208 | divisor = 2
209 | else:
210 | divisor = 1
211 | # construct unet structure
212 |
213 | assert num_downs >= 2
214 |
215 | # Add the innermost layer
216 | unet_block = UnetSkipConnectionBlock(min(2 ** (num_downs - 1) * nf0, max_channels) // divisor,
217 | min(2 ** (num_downs - 1) * nf0, max_channels) // divisor,
218 | input_nc=None, submodule=None, norm_layer=norm_layer,
219 | innermost=True)
220 |
221 | for i in list(range(1, num_downs - 1))[::-1]:
222 | if i == 1:
223 | norm = None # Praneeth's modification
224 | else:
225 | norm = norm_layer
226 |
227 | unet_block = UnetSkipConnectionBlock(min(2 ** i * nf0, max_channels) // divisor,
228 | min(2 ** (i + 1) * nf0, max_channels) // divisor,
229 | input_nc=None, submodule=unet_block,
230 | norm_layer=norm,
231 | use_dropout=use_dropout)
232 |
233 | # Add the outermost layer
234 | self.model = UnetSkipConnectionBlock(min(nf0, max_channels) // divisor,
235 | min(2 * nf0, max_channels) // divisor,
236 | input_nc=input_nc, submodule=unet_block, outermost=True,
237 | norm_layer=None, outer_skip=self.outer_skip)
238 | if self.outer_skip:
239 | self.additional_conv = nn.Conv2d(input_nc + min(nf0, max_channels) // divisor, output_nc,
240 | kernel_size=4, stride=1, padding=2, bias=True)
241 | else:
242 | self.additional_conv = nn.Conv2d(min(nf0, max_channels) // divisor, output_nc,
243 | kernel_size=4, stride=1, padding=2, bias=True)
244 |
245 | def forward(self, cnn_input):
246 | """Standard forward"""
247 | output = self.model(cnn_input)
248 | output = self.additional_conv(output)
249 |         output = output[:,:,:-1,:-1]  # the final conv (kernel 4, padding 2) adds one extra row/column; crop it off
250 | return output
251 |
252 |
253 | class Conv2dSame(torch.nn.Module):
254 | '''2D convolution that pads to keep spatial dimensions equal.
255 | Cannot deal with stride. Only quadratic kernels (=scalar kernel_size).
256 | '''
257 |
258 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReflectionPad2d):
259 | '''
260 | :param in_channels: Number of input channels
261 | :param out_channels: Number of output channels
262 | :param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
263 | :param bias: Whether or not to use bias.
264 | :param padding_layer: Which padding to use. Default is reflection padding.
265 | '''
266 | super().__init__()
267 | ka = kernel_size // 2
268 | kb = ka - 1 if kernel_size % 2 == 0 else ka
269 | self.net = nn.Sequential(
270 | padding_layer((ka, kb, ka, kb)),
271 | nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1)
272 | )
273 |
274 | self.weight = self.net[1].weight
275 | self.bias = self.net[1].bias
276 |
277 | def forward(self, x):
278 | return self.net(x)
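
Conv2dSame pads with reflection so that stride-1 convolutions preserve the spatial size for both odd and even kernels (even kernels get asymmetric padding: ka on one side, kb = ka - 1 on the other). A quick shape check, assuming the repository root is importable:

```
import torch
from unet import Conv2dSame

x = torch.randn(1, 3, 65, 97)              # deliberately odd spatial sizes
for k in (3, 4, 5):
    conv = Conv2dSame(3, 8, kernel_size=k)
    print(k, conv(x).shape)                # spatial dims stay (65, 97) for every k
```
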
279 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils
3 |
4 | """
5 |
6 | import math
7 | import random
8 | import numpy as np
9 |
10 | import os
11 | import torch
12 | import torch.nn as nn
13 |
14 | from skimage.metrics import peak_signal_noise_ratio as psnr
15 | from skimage.metrics import structural_similarity as ssim
16 |
17 | import torch.nn.functional as F
18 | from torchvision.utils import save_image
19 |
20 | import props.prop_model as prop_model
21 |
22 | class AverageMeter:
23 | def __init__(self):
24 | self._sum = 0
25 | self._avg = 0
26 | self._cnt = 0
27 | def update(self, val):
28 | self._sum += val
29 | self._cnt += 1
30 | self._avg = self._sum / self._cnt
31 | @property
32 | def avg(self):
33 | return self._avg
34 |
35 | def apply_func_list(func, data_list):
36 | return [func(data) for data in data_list]
37 |
38 | def post_process_amp(amp, scale=1.0):
39 | # amp is a image tensor in range [0, 1]
40 | amp = amp * scale
41 | amp = torch.clip(amp, 0, 1)
42 | amp = amp.detach().squeeze().cpu().numpy()
43 | return amp
44 |
45 | def roll_torch(tensor, shift: int, axis: int):
46 | if shift == 0:
47 | return tensor
48 |
49 | if axis < 0:
50 | axis += tensor.dim()
51 |
52 | dim_size = tensor.size(axis)
53 | after_start = dim_size - shift
54 | if shift < 0:
55 | after_start = -shift
56 | shift = dim_size - abs(shift)
57 |
58 | before = tensor.narrow(axis, 0, dim_size - shift)
59 | after = tensor.narrow(axis, after_start, shift)
60 | return torch.cat([after, before], axis)
61 |
62 |
63 | def ifftshift(tensor):
64 | """ifftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2]
65 |
66 | shifts the width and heights
67 | """
68 | size = tensor.size()
69 | tensor_shifted = roll_torch(tensor, -math.floor(size[2] / 2.0), 2)
70 | tensor_shifted = roll_torch(tensor_shifted, -math.floor(size[3] / 2.0), 3)
71 | return tensor_shifted
72 |
73 |
74 | def fftshift(tensor):
75 | """fftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2]
76 |
77 | shifts the width and heights
78 | """
79 | size = tensor.size()
80 | tensor_shifted = roll_torch(tensor, math.floor(size[2] / 2.0), 2)
81 | tensor_shifted = roll_torch(tensor_shifted, math.floor(size[3] / 2.0), 3)
82 | return tensor_shifted
83 |
84 |
85 | def pad_image(field, target_shape, pytorch=True, stacked_complex=True, padval=0, mode='constant'):
86 | """Pads a 2D complex field up to target_shape in size
87 |
88 | Padding is done such that when used with crop_image(), odd and even dimensions are
89 | handled correctly to properly undo the padding.
90 |
91 | field: the field to be padded. May have as many leading dimensions as necessary
92 | (e.g., batch or channel dimensions)
93 | target_shape: the 2D target output dimensions. If any dimensions are smaller
94 | than field, no padding is applied
95 | pytorch: if True, uses torch functions, if False, uses numpy
96 | stacked_complex: for pytorch=True, indicates that field has a final dimension
97 | representing real and imag
98 | padval: the real number value to pad by
99 | mode: padding mode for numpy or torch
100 | """
101 | if pytorch:
102 | if stacked_complex:
103 | size_diff = np.array(target_shape) - np.array(field.shape[-3:-1])
104 | odd_dim = np.array(field.shape[-3:-1]) % 2
105 | else:
106 | size_diff = np.array(target_shape) - np.array(field.shape[-2:])
107 | odd_dim = np.array(field.shape[-2:]) % 2
108 | else:
109 | size_diff = np.array(target_shape) - np.array(field.shape[-2:])
110 | odd_dim = np.array(field.shape[-2:]) % 2
111 |
112 | # pad the dimensions that need to increase in size
113 | if (size_diff > 0).any():
114 | pad_total = np.maximum(size_diff, 0)
115 | pad_front = (pad_total + odd_dim) // 2
116 | pad_end = (pad_total + 1 - odd_dim) // 2
117 |
118 | if pytorch:
119 | pad_axes = [int(p) # convert from np.int64
120 | for tple in zip(pad_front[::-1], pad_end[::-1])
121 | for p in tple]
122 | if stacked_complex:
123 | return pad_stacked_complex(field, pad_axes, mode=mode, padval=padval)
124 | else:
125 | return nn.functional.pad(field, pad_axes, mode=mode, value=padval)
126 | else:
127 | leading_dims = field.ndim - 2 # only pad the last two dims
128 | if leading_dims > 0:
129 | pad_front = np.concatenate(([0] * leading_dims, pad_front))
130 | pad_end = np.concatenate(([0] * leading_dims, pad_end))
131 | return np.pad(field, tuple(zip(pad_front, pad_end)), mode,
132 | constant_values=padval)
133 | else:
134 | return field
135 |
136 |
137 | def crop_image(field, target_shape, pytorch=True, stacked_complex=True, lf=False):
138 | """Crops a 2D field, see pad_image() for details
139 |
140 | No cropping is done if target_shape is already smaller than field
141 | """
142 | if target_shape is None:
143 | return field
144 |
145 | if lf:
146 | size_diff = np.array(field.shape[-4:-2]) - np.array(target_shape)
147 | odd_dim = np.array(field.shape[-4:-2]) % 2
148 | else:
149 | if pytorch:
150 | if stacked_complex:
151 | size_diff = np.array(field.shape[-3:-1]) - np.array(target_shape)
152 | odd_dim = np.array(field.shape[-3:-1]) % 2
153 | else:
154 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape)
155 | odd_dim = np.array(field.shape[-2:]) % 2
156 | else:
157 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape)
158 | odd_dim = np.array(field.shape[-2:]) % 2
159 |
160 | # crop dimensions that need to decrease in size
161 | if (size_diff > 0).any():
162 | crop_total = np.maximum(size_diff, 0)
163 | crop_front = (crop_total + 1 - odd_dim) // 2
164 | crop_end = (crop_total + odd_dim) // 2
165 |
166 | crop_slices = [slice(int(f), int(-e) if e else None)
167 | for f, e in zip(crop_front, crop_end)]
168 | if lf:
169 | return field[(..., *crop_slices, slice(None), slice(None))]
170 | else:
171 | if pytorch and stacked_complex:
172 | return field[(..., *crop_slices, slice(None))]
173 | else:
174 | return field[(..., *crop_slices)]
175 | else:
176 | return field
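
pad_image and crop_image are written as inverses: padding distributes the extra pixels so that cropping back to the original shape recovers the original array, including for odd sizes. A small round-trip sketch (run from the repository root; note that importing utils also pulls in props.prop_model, so the project environment must be installed):

```
import torch
import utils

field = torch.randn(1, 1, 501, 333)        # odd sizes on purpose
padded = utils.pad_image(field, (512, 512), pytorch=True, stacked_complex=False)
recovered = utils.crop_image(padded, (501, 333), pytorch=True, stacked_complex=False)
print(padded.shape, torch.equal(recovered, field))   # torch.Size([1, 1, 512, 512]) True
```
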
177 |
178 |
179 | def srgb_gamma2lin(im_in):
180 |     """converts from sRGB to linear color space (numpy-only; shadowed by the tensor-aware redefinition later in this file)"""
181 | thresh = 0.04045
182 | im_out = np.where(im_in <= thresh, im_in / 12.92, ((im_in + 0.055) / 1.055)**(2.4))
183 | return im_out
184 |
185 |
186 | def srgb_lin2gamma(im_in):
187 | """converts from linear to sRGB color space"""
188 | thresh = 0.0031308
189 | im_out = np.where(im_in <= thresh, 12.92 * im_in, 1.055 * (im_in**(1 / 2.4)) - 0.055)
190 | return im_out
191 |
192 |
193 | def cond_mkdir(path):
194 | if not os.path.exists(path):
195 | os.makedirs(path)
196 |
197 |
198 | def burst_img_processor(img_burst_list):
199 | img_tensor = np.stack(img_burst_list, axis=0)
200 | img_avg = np.mean(img_tensor, axis=0)
201 | return im2float(img_avg) # changed from int8 to float32
202 |
203 |
204 | def im2float(im, dtype=np.float32):
205 | """convert uint16 or uint8 image to float32, with range scaled to 0-1
206 |
207 | :param im: image
208 | :param dtype: default np.float32
209 | :return:
210 | """
211 | if issubclass(im.dtype.type, np.floating):
212 | return im.astype(dtype)
213 | elif issubclass(im.dtype.type, np.integer):
214 | return im / dtype(np.iinfo(im.dtype).max)
215 | else:
216 | raise ValueError(f'Unsupported data type {im.dtype}')
217 |
218 |
219 | def get_psnr_ssim(recon_amp, target_amp, multichannel=False):
220 | """get PSNR and SSIM metrics"""
221 | psnrs, ssims = {}, {}
222 |
223 |
224 | # amplitude
225 | psnrs['amp'] = psnr(target_amp, recon_amp)
226 | ssims['amp'] = ssim(target_amp, recon_amp, multichannel=multichannel)
227 |
228 | # linear
229 | target_linear = target_amp**2
230 | recon_linear = recon_amp**2
231 | psnrs['lin'] = psnr(target_linear, recon_linear)
232 | ssims['lin'] = ssim(target_linear, recon_linear, multichannel=multichannel)
233 |
234 | # srgb
235 | target_srgb = srgb_lin2gamma(np.clip(target_linear, 0.0, 1.0))
236 | recon_srgb = srgb_lin2gamma(np.clip(recon_linear, 0.0, 1.0))
237 | psnrs['srgb'] = psnr(target_srgb, recon_srgb)
238 | ssims['srgb'] = ssim(target_srgb, recon_srgb, multichannel=multichannel)
239 |
240 | return psnrs, ssims
241 |
242 |
243 | def make_kernel_gaussian(sigma, kernel_size):
244 |
245 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
246 | x_cord = torch.arange(kernel_size)
247 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
248 | y_grid = x_grid.t()
249 | xy_grid = torch.stack([x_grid, y_grid], dim=-1)
250 |
251 | mean = (kernel_size - 1) / 2
252 | variance = sigma**2
253 |
254 | # Calculate the 2-dimensional gaussian kernel which is
255 | # the product of two gaussian distributions for two different
256 | # variables (in this case called x and y)
257 | gaussian_kernel = ((1 / (2 * math.pi * variance))
258 | * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1)
259 | / (2 * variance)))
260 | # Make sure sum of values in gaussian kernel equals 1.
261 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
262 |
263 | # Reshape to 2d depthwise convolutional weight
264 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
265 |
266 | return gaussian_kernel
267 |
268 |
269 | def pad_stacked_complex(field, pad_width, padval=0):
270 | if padval == 0:
271 | pad_width = (0, 0, *pad_width) # add 0 padding for stacked_complex dimension
272 | return nn.functional.pad(field, pad_width)
273 | else:
274 | if isinstance(padval, torch.Tensor):
275 | padval = padval.item()
276 |
277 | real, imag = field[..., 0], field[..., 1]
278 | real = nn.functional.pad(real, pad_width, value=padval)
279 | imag = nn.functional.pad(imag, pad_width, value=0)
280 | return torch.stack((real, imag), -1)
281 |
282 |
283 | def lut_mid(lut):
284 | return [(a + b) / 2 for a, b in zip(lut[:-1], lut[1:])]
285 |
286 |
287 | def nearest_neighbor_search(input_val, lut, lut_midvals=None):
288 | """
289 | Quantize to nearest neighbor values in lut
290 | :param input_val: input tensor
291 | :param lut: list of discrete values supported
292 | :param lut_midvals: set threshold to put into torch.searchsorted function.
293 | :return:
294 | """
295 | # if lut_midvals is None:
296 | # lut_midvals = torch.tensor(lut_mid(lut), dtype=torch.float32).to(input_val.device)
297 | idx = nearest_idx(input_val, lut_midvals)
298 | assert not torch.isnan(idx).any()
299 | return lut[idx], idx
300 |
301 |
302 | def nearest_idx(input_val, lut_midvals):
303 | """ Return nearest idx of lut per pixel """
304 | input_array = input_val.detach()
305 | len_lut = len(lut_midvals)
306 | # print(lut_midvals.shape)
307 | # idx = torch.searchsorted(lut_midvals.to(input_val.device), input_array, right=True)
308 | idx = torch.bucketize(input_array, lut_midvals.to(input_val.device), right=True)
309 |
310 | return idx % len_lut
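
Together with lut_mid above, nearest_idx implements nearest-neighbor quantization against an arbitrary LUT: the midpoints act as decision thresholds for torch.bucketize, and the modulo wraps phases beyond the last midpoint back to the first level. A standalone sketch with a 2-bit linear LUT:

```
import math
import torch

lut = torch.linspace(-math.pi, math.pi, 2**2 + 1)    # [-pi, -pi/2, 0, pi/2, pi]
midvals = (lut[:-1] + lut[1:]) / 2                   # lut_mid() equivalent

phase = torch.tensor([-3.0, -1.5, 0.1, 1.4, 3.1])
idx = torch.bucketize(phase, midvals, right=True) % len(midvals)   # as in nearest_idx()
print(lut[idx])   # [-pi, -pi/2, 0, pi/2, -pi]; note how 3.1 wraps back to -pi
```
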
311 |
312 |
313 | def srgb_gamma2lin(im_in):
314 | """ converts from sRGB to linear color space """
315 | thresh = 0.04045
316 | if torch.is_tensor(im_in):
317 | low_val = im_in <= thresh
318 | im_out = torch.zeros_like(im_in)
319 | im_out[low_val] = 25 / 323 * im_in[low_val]
320 | im_out[torch.logical_not(low_val)] = ((200 * im_in[torch.logical_not(low_val)] + 11)
321 | / 211) ** (12 / 5)
322 | else:
323 | im_out = np.where(im_in <= thresh, im_in / 12.92, ((im_in + 0.055) / 1.055) ** (12/5))
324 |
325 | return im_out
326 |
327 |
328 | def srgb_lin2gamma(im_in):
329 | """ converts from linear to sRGB color space """
330 | thresh = 0.0031308
331 | im_out = np.where(im_in <= thresh, 12.92 * im_in, 1.055 * (im_in**(1 / 2.4)) - 0.055)
332 | return im_out
333 |
334 |
335 | def decompose_depthmap(depthmap_virtual_D, depth_planes_D):
336 | """ decompose a depthmap image into a set of masks with depth positions (in Diopter) """
337 |
338 | num_planes = len(depth_planes_D)
339 |
340 | masks = torch.zeros(depthmap_virtual_D.shape[0], len(depth_planes_D), *depthmap_virtual_D.shape[-2:],
341 | dtype=torch.float32).to(depthmap_virtual_D.device)
342 | for k in range(len(depth_planes_D) - 1):
343 | depth_l = depth_planes_D[k]
344 | depth_h = depth_planes_D[k + 1]
345 | idxs = (depthmap_virtual_D >= depth_l) & (depthmap_virtual_D < depth_h)
346 | close_idxs = (depth_h - depthmap_virtual_D) > (depthmap_virtual_D - depth_l)
347 |
348 | # closer one
349 | mask = torch.zeros_like(depthmap_virtual_D)
350 | mask += idxs * close_idxs * 1
351 | masks[:, k, ...] += mask.squeeze(1)
352 |
353 | # farther one
354 | mask = torch.zeros_like(depthmap_virtual_D)
355 | mask += idxs * (~close_idxs) * 1
356 | masks[:, k + 1, ...] += mask.squeeze(1)
357 |
358 | # even closer ones
359 | idxs = depthmap_virtual_D >= max(depth_planes_D)
360 | mask = torch.zeros_like(depthmap_virtual_D)
361 | mask += idxs * 1
362 | masks[:, len(depth_planes_D) - 1, ...] += mask.clone().squeeze(1)
363 |
364 | # even farther ones
365 | idxs = depthmap_virtual_D < min(depth_planes_D)
366 | mask = torch.zeros_like(depthmap_virtual_D)
367 | mask += idxs * 1
368 | masks[:, 0, ...] += mask.clone().squeeze(1)
369 |
370 | # sanity check
371 | assert torch.sum(masks).item() == torch.numel(masks) / num_planes
372 |
373 | return masks
374 |
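# Illustrative call of decompose_depthmap above (not in the original file); depth_D and the
# plane positions are hypothetical. Each pixel is assigned to exactly one plane: the nearer of
# the two planes bracketing its depth, or the first/last plane if it falls outside the range.
#   >>> planes_D = [0., 1., 2., 3.]                                # depth planes in diopters
#   >>> masks = decompose_depthmap(depth_D, planes_D)              # depth_D: (N, 1, H, W) in diopters
#   >>> masks.shape                                                # (N, 4, H, W), one-hot across planes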
375 | def decompose_depthmap_v2(depth_batch, num_depth_planes, roi_res):
376 | """
377 | Depth (N, 1, H, W) -> Masks (N, num_depth_planes, H, W)
378 | Decompose depth map in each batch
379 | """
380 | def _decompose_depthmap(depth, num_depth_planes):
381 |         depth = depth * 1000  # rescale depth (e.g. m -> mm) so the bin edges can be rounded below
383 | depth_vals = crop_image(depth, roi_res, stacked_complex=False).ravel()
384 | npt = len(depth_vals)
385 | depth_bins = np.interp(np.linspace(0, npt, num_depth_planes),
386 | np.arange(npt),
387 | np.sort(depth_vals)).round(decimals=2)
388 |
389 | masks = []
390 | for i in range(num_depth_planes):
391 | if i < num_depth_planes - 1:
392 | min_d = depth_bins[i]
393 | max_d = depth_bins[i + 1]
394 | mask = torch.where(depth >= min_d, 1, 0) * torch.where(depth < max_d, 1, 0)
395 | else:
396 | mask = torch.where(depth >= depth_bins[-1], 1, 0)
397 | masks.append(mask)
398 | masks = torch.stack(masks)
399 | masks = torch.where(masks > 0, 1, 0).float()
400 | for i in range(num_depth_planes - 1):
401 | mask_diff = torch.logical_and(masks[i], masks[i + 1]).float()
402 | masks[i] -= mask_diff
403 | # reverse depth order
404 | masks = masks.flip(0)
405 | return masks.unsqueeze(0)
406 |
407 | masks = [_decompose_depthmap(depth.squeeze(), num_depth_planes) for depth in depth_batch]
408 | masks = torch.cat(masks, dim=0)
409 | return masks
410 |
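# Illustrative call of decompose_depthmap_v2 above (hypothetical arguments): bins are placed at
# equally spaced quantiles of the cropped depth values, so each returned plane covers roughly the
# same number of pixels; the plane order is reversed at the end ("reverse depth order" above).
#   >>> masks = decompose_depthmap_v2(depth_batch, 8, (880, 1600))   # -> (N, 8, H, W)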
411 |
412 |
413 | def prop_dist_to_diopter(prop_dists, focal_distance, prop_dist_inf, from_lens=True):
414 | """
415 |     Calculates the distance from the user, in diopters, given the propagation distance from the SLM.
416 |     :param prop_dists: list of propagation distances from the SLM (in meters)
417 |     :param focal_distance: focal length of the eyepiece (in meters)
418 |     :param prop_dist_inf: propagation distance from the SLM that corresponds to optical infinity for the user
419 |     :param from_lens: if True, distances are measured from the eyepiece lens; otherwise from the user side
420 |     :return: list of distances in diopters
421 | """
422 |     x0 = prop_dist_inf  # prop distance from SLM that corresponds to optical infinity for the user
423 | f = focal_distance # focal distance of eyepiece
424 |
425 | if from_lens: # distance is from the lens
426 | diopters = [1 / (x0 + f - x) - 1 / f for x in prop_dists] # diopters from the user side
427 | else: # distance is from the user (basically adding focal length)
428 | diopters = [(x - x0) / f**2 for x in prop_dists]
429 |
430 | return diopters
431 |
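# Worked example with hypothetical numbers (from_lens=True): for a 50 mm eyepiece (f = 0.05 m)
# and prop_dist_inf = 0.01 m, a propagation distance of 0.01 m maps to
# 1/(0.01 + 0.05 - 0.01) - 1/0.05 = 0 D (optical infinity), while 0.02 m maps to 1/0.04 - 20 = 5 D.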
432 |
433 | def switch_lf(input, mode='elemental'):
434 | spatial_res = input.shape[2:4]
435 | angular_res = input.shape[-2:]
436 | if mode == 'elemental':
437 | lf = input.permute(0, 1, 2, 4, 3, 5)
438 | elif mode == 'whole':
439 | lf = input.permute(0, 1, 4, 2, 5, 3) # show each view
440 | return lf.reshape(1, 1, *(s*a for s, a in zip(spatial_res, angular_res)))
441 |
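# switch_lf flattens a light field of shape (N, 1, H, W, U, V) into a (1, 1, H*U, W*V) mosaic:
# 'elemental' interleaves the U x V views per spatial sample (elemental-image layout), while
# 'whole' tiles complete sub-views side by side. Shapes below are illustrative only.
#   >>> lf = torch.rand(1, 1, 4, 6, 3, 3)
#   >>> switch_lf(lf, 'whole').shape                               # torch.Size([1, 1, 12, 18])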
442 |
443 | def nonnegative_mean_dilate(im):
444 |     """ Mean-filter each pixel over its valid (> -1) 3x3 neighbors; pixels with no valid
445 |     neighbors stay at -1, and pixels mostly surrounded by holes (-2) are marked as holes. """
446 |
447 | # take the mean filter over all pixels not equal to -1
448 | im = F.pad(im, (1, 1, 1, 1), mode='reflect')
449 | im = im.unfold(2, 3, 1).unfold(3, 3, 1)
450 | im = im.contiguous().view(im.size()[:4] + (-1, ))
451 | percent_surrounded_by_holes = ((im != -1) * (im < 0)).sum(dim=-1)/(1e-12 + (im != -1).sum(dim=-1))
452 | holes = (0.7 < percent_surrounded_by_holes)
453 |     mean_im = ((im > -1) * im).sum(dim=-1) / (1e-12 + (im > -1).sum(dim=-1))
454 |     im = mean_im * torch.logical_not(holes) - 1 * (0 == (im > -1).sum(dim=-1)) * torch.logical_not(holes) - 2 * holes
455 |
456 | return im
457 |
458 |
459 | def generate_incoherent_stack(target_amp, depth_masks, depth_planes_depth,
460 | wavelength, pitch, focal_stack_blur_radius=1.0):
461 | """
462 |     Generate an incoherent focal stack from a target amplitude image and per-plane depth masks.
463 |     :param target_amp: target amplitude image, shape (N, 1, H, W)
464 |     :param depth_masks: binary depth masks, shape (N, num_planes, H, W)
465 |     :param depth_planes_depth: physical distances of the depth planes
466 |     :param wavelength: wavelength of the illumination
467 |     :param pitch: SLM pixel pitch
468 |     :param focal_stack_blur_radius: scale factor applied to the defocus blur radius
469 |     :return: focal stack with per-plane defocus and occlusion handling, shape (N, num_planes, H, W)
470 | """
471 | with torch.no_grad():
472 | # Create inpainted images for better approximation of occluded regions (start with -1 for occluded regions to be inpainted, and -2 for holes)
473 | inpainted_images = depth_masks*target_amp - 2 * (1 - depth_masks)
474 | occluded_regions = torch.zeros_like(depth_masks)
475 | for j in range(depth_masks.shape[1]):
476 | for k in range(depth_masks.shape[1]):
477 | if k > j:
478 | occluded_regions[:, j, ...] = torch.logical_or(depth_masks[:, k, ...] > 0, occluded_regions[:, j, ...])
479 | inpainted_images += 1 * occluded_regions
480 |
481 | inpainting_ordering = depth_masks.clone()
482 | for j in range(depth_masks.shape[1]):
483 |             buffer = 50 * math.ceil(((depth_planes_depth[-1] - depth_planes_depth[0]) / pitch) * \
484 |                                     math.sqrt(1 / ((2 * pitch / wavelength)**2 - 1)))
485 | for i in range(buffer):
486 | blurred_im = nonnegative_mean_dilate(inpainted_images[:, j, ...].unsqueeze(1))[:, 0, ...]
487 | inpainting_ordering[:, j, ...][torch.logical_and((inpainted_images[:, j, ...] == -1), (blurred_im >= 0))] = i + 2
488 | inpainted_images[:, j, ...][(inpainted_images[:, j, ...] == -1)] = blurred_im[(inpainted_images[:, j, ...] == -1)]
489 | closest_inpainting = torch.zeros_like(depth_masks) # tracks if depth is closest inpainting depth of the remaining planes
490 | for j in range(inpainting_ordering.shape[1]):
491 | closest_inpainting[:, j, ...] = inpainting_ordering[:, j, ...] > 0
492 | for k in range(inpainting_ordering.shape[1]):
493 | if k < j:
494 | closest_inpainting[:, j, ...] *= torch.logical_or(inpainting_ordering[:, k, ...] < 1,
495 | inpainting_ordering[:, j, ...] <= inpainting_ordering[:, k, ...])
496 |
497 | # Propagation starting with front planes to handle occlusion
498 | focal_stack = torch.zeros_like(depth_masks)
499 | unblocked_weighting = torch.ones_like(depth_masks)
500 | for j in range(focal_stack.shape[1] - 1, -1, -1):
501 | for k in range(focal_stack.shape[1] - 1, -1, -1):
502 | if k == j:
503 | focal_stack[:, k, ...] += unblocked_weighting[:, k, ...]*(target_amp[:, 0, ...]*depth_masks[:, j, ...])
504 | unblocked_weighting[:, k, ...] -= unblocked_weighting[:, k, ...]*depth_masks[:, j, ...]
505 | else:
506 | incoherent_propagator = create_diffraction_cone_propagator(focal_stack_blur_radius *
507 | abs(depth_planes_depth[j] - depth_planes_depth[k]), wavelength, pitch, depth_masks.device)
508 | focal_stack[:, k, ...] += unblocked_weighting[:,k,...] * \
509 | (incoherent_propagator((target_amp[:,0,...] * depth_masks[:,j,...]).unsqueeze(1))[:, 0, ...])
510 | unblocked_weighting[:, k, ...] -= unblocked_weighting[:, k, ...] * \
511 | (incoherent_propagator((depth_masks[:, j, ...]).unsqueeze(1))[:, 0, ...])
512 |
513 | # Propagate inpainted content where necessary
514 | for j in range(focal_stack.shape[1] - 1, -1, -1):
515 | for k in range(focal_stack.shape[1] - 1, -1, -1):
516 | if k == j:
517 | focal_stack[:, k, ...] += unblocked_weighting[:, k, ...] * inpainted_images[:, j, ...] *\
518 | (inpainted_images[:, j, ...] >= 0) * closest_inpainting[:, j, ...]
519 | unblocked_weighting[:, k, ...] -= unblocked_weighting[:,k,...]*closest_inpainting[:, j, ...] * (inpainted_images[:, j, ...] >= 0)
520 | else:
521 | incoherent_propagator = create_diffraction_cone_propagator(focal_stack_blur_radius * abs(depth_planes_depth[j] - depth_planes_depth[k]),
522 | wavelength, pitch, depth_masks.device)
523 | focal_stack[:, k, ...] += unblocked_weighting[:, k, ...] * \
524 | (incoherent_propagator((inpainted_images[:, j, ...] *
525 | (inpainted_images[:, j, ...] >= 0)).unsqueeze(1))[:, 0, ...]) \
526 | * closest_inpainting[:,j,...]
527 | unblocked_weighting[:, k, ...] -= unblocked_weighting[:, k, ...]*closest_inpainting[:, j, ...] * \
528 | (incoherent_propagator(1.0 * (inpainted_images[:, j, ...] >= 0).unsqueeze(1))[:, 0, ...])
529 |
530 | return focal_stack
531 |
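# Typical use (illustrative; variable names are hypothetical): build defocused per-plane targets
# for focal-stack supervision from an all-in-focus amplitude and its depth decomposition.
#   >>> fs = generate_incoherent_stack(target_amp, masks, plane_depths, 520e-9, 6.4e-6)
#   >>> fs.shape == masks.shape                                    # True: one defocused image per plane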
532 |
533 | def create_diffraction_cone_propagator(distance, wavelength, pitch, device):
534 | """ Create blur layer for incoherent propagation """
535 | with torch.no_grad():
536 | subhologram_halfsize = ((distance/pitch)* \
537 | math.sqrt(1/((2*pitch/wavelength)**2-1)))
538 | kernel = np.zeros((2*math.ceil(subhologram_halfsize)+5, 2*math.ceil(subhologram_halfsize)+5))
539 | y,x = np.ogrid[-math.ceil(subhologram_halfsize)-2:math.ceil(subhologram_halfsize)+3, -math.ceil(subhologram_halfsize)-2:math.ceil(subhologram_halfsize)+3]
540 | mask = x**2+y**2 <= subhologram_halfsize**2
541 | kernel[mask] = 1
542 | kernel = torch.Tensor(kernel).unsqueeze(0).unsqueeze(0).to(device)
543 | kernel = kernel/kernel.sum()
544 | incoherent_propagator = nn.Conv2d(1, 1, kernel_size=2*math.ceil(subhologram_halfsize)+5, stride=1, padding=math.ceil(subhologram_halfsize)+2, padding_mode='replicate', bias=False)
545 | incoherent_propagator.weight = nn.Parameter(kernel, requires_grad=False)
546 |
547 | return incoherent_propagator
548 |
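# The kernel above is a normalized disk whose radius follows the diffraction cone,
# r = (d / pitch) * sqrt(1 / ((2 * pitch / wavelength)**2 - 1)) pixels. Illustrative use
# (hypothetical numbers; input to the returned module must be (N, 1, H, W)):
#   >>> prop = create_diffraction_cone_propagator(1e-3, 520e-9, 6.4e-6, torch.device('cpu'))
#   >>> blurred = prop(plane_img.unsqueeze(1))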
549 |
550 | def laplacian(img):
551 |
552 | # signed angular difference
553 | grad_x1, grad_y1 = grad(img, next_pixel=True) # x_{n+1} - x_{n}
554 | grad_x0, grad_y0 = grad(img, next_pixel=False) # x_{n} - x_{n-1}
555 |
556 | laplacian_x = grad_x1 - grad_x0 # (x_{n+1} - x_{n}) - (x_{n} - x_{n-1})
557 | laplacian_y = grad_y1 - grad_y0
558 |
559 | return laplacian_x + laplacian_y
560 |
561 |
562 | def grad(img, next_pixel=False, sobel=False):
563 |
564 | if img.shape[1] > 1:
565 | permuted = True
566 | img = img.permute(1, 0, 2, 3)
567 | else:
568 | permuted = False
569 |
570 | # set diff kernel
571 |     if sobel:  # use Sobel filter for gradient calculation
572 | k_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32) / 8
573 | k_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32) / 8
574 | else:
575 | if next_pixel: # x_{n+1} - x_n
576 | k_x = torch.tensor([[0, -1, 1]], dtype=torch.float32)
577 | k_y = torch.tensor([[1], [-1], [0]], dtype=torch.float32)
578 | else: # x_{n} - x_{n-1}
579 | k_x = torch.tensor([[-1, 1, 0]], dtype=torch.float32)
580 | k_y = torch.tensor([[0], [1], [-1]], dtype=torch.float32)
581 |
582 | # upload to gpu
583 | k_x = k_x.to(img.device).unsqueeze(0).unsqueeze(0)
584 | k_y = k_y.to(img.device).unsqueeze(0).unsqueeze(0)
585 |
586 | # boundary handling (replicate elements at boundary)
587 | img_x = F.pad(img, (1, 1, 0, 0), 'replicate')
588 | img_y = F.pad(img, (0, 0, 1, 1), 'replicate')
589 |
590 | # take sign angular difference
591 | grad_x = signed_ang(F.conv2d(img_x, k_x))
592 | grad_y = signed_ang(F.conv2d(img_y, k_y))
593 |
594 | if permuted:
595 | grad_x = grad_x.permute(1, 0, 2, 3)
596 | grad_y = grad_y.permute(1, 0, 2, 3)
597 |
598 | return grad_x, grad_y
599 |
600 |
601 | def signed_ang(angle):
602 | """
603 | cast all angles into [-pi, pi]
604 | """
605 | return (angle + math.pi) % (2*math.pi) - math.pi
606 |
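# e.g. signed_ang(torch.tensor([3.5 * math.pi])) -> tensor([-1.5708]), i.e. wrapped into [-pi, pi)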
607 |
608 | # Adapted from https://github.com/svaiter/pyprox/blob/master/pyprox/operators.py
609 | def soft_thresholding(x, gamma):
610 | """
611 | return element-wise shrinkage function with threshold kappa
612 |     Element-wise soft-thresholding (shrinkage) operator with threshold gamma.
613 | return torch.maximum(torch.zeros_like(x),
614 | 1 - gamma / torch.maximum(torch.abs(x), 1e-10*torch.ones_like(x))) * x
615 |
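# Soft-thresholding shrinks each element toward zero by gamma and zeroes anything with
# magnitude below gamma (the proximal operator of gamma * ||x||_1). Quick check:
#   >>> soft_thresholding(torch.tensor([1.5, -0.2, 0.7]), 0.5)     # ~[1.0, 0.0, 0.2]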
616 |
617 | def random_gen(num_planes=7, slm_type='ti', **kwargs):
618 | """
619 | random hyperparameters for the dataset
620 | """
621 | frame_choices = [1, 1, 2, 2, 4, 4, 4, 8, 8, 8] if slm_type.lower() == 'ti' else [1]
622 | q_choices = ['None', 'nn', 'nn_sigmoid', 'gumbel_softmax'] if slm_type.lower() == 'ti' else ['None']
623 |
624 |
625 | num_frames = random.choice(frame_choices)
626 | quan_method = random.choice(q_choices)
627 | num_iters = random.choice(range(2000)) + 1
628 | phase_range = random.uniform(1.0, 6.28)
629 | target_range = random.uniform(0.5, 1.5)
630 | learning_rate = random.uniform(0.01, 0.03)
631 | plane_idx = random.choice(range(num_planes))
632 | # reg_lf_var = random.choice([0., 0., 1.0, 10.0, 100.0])
633 | reg_lf_var = -1
634 |
635 |
636 | # for profiling
637 | #num_frames = 1
638 | #quan_method = "None"
639 | #num_iters = 10
640 | #phase_range = 3
641 | #target_range = 1
642 | #learning_rate = 0.02
643 | #plane_idx = 4
644 | #reg_lf_var = -1
645 |
646 |
647 | return num_frames, num_iters, phase_range, target_range, learning_rate, plane_idx, quan_method, reg_lf_var
648 |
649 | def write_opt(opt, out_path):
650 | import json
651 |     with open(os.path.join(out_path, 'opt.json'), "w") as opt_file:
652 | json.dump(dict(opt), opt_file, indent=4)
653 |
654 | def init_phase(init_phase_type, target_amp, dev, opt):
655 | if init_phase_type == "random":
656 | init_phase = -0.5 + 1.0 * torch.rand(opt.num_frames, 1, *opt.slm_res)
657 | return opt.init_phase_range * init_phase.to(dev)
658 |
659 | def create_backprop_instance(forward_prop):
660 | from params import clone_params
661 | # find a cleaner way to create a backprop instance
662 |
663 | # forward prop only front propagation
664 | # need backwards propagation
665 | assert forward_prop.opt.serial_two_prop_off # assert 1 prop
666 |
667 | # also update the prop_dist and wrp stuff
668 |
669 | backprop_opt = clone_params(forward_prop.opt)
670 | backprop_opt.prop_dist = -forward_prop.opt.prop_dist # propagate back
671 | backward_prop = prop_model.model(backprop_opt)
672 |
673 | return backward_prop
674 |
675 | def normalize_range(data, data_min, data_max, low, high):
676 | data = (data - data_min) / (data_max - data_min) # 0 - 1
677 | data = (high - low) * data + low
678 | return data
679 |
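# e.g. normalize_range(torch.tensor([0., 5., 10.]), 0., 10., -1., 1.) -> tensor([-1., 0., 1.])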
--------------------------------------------------------------------------------