├── .gitignore
├── LICENSE
├── README.md
├── algorithms.py
├── data
│   ├── calib
│   │   └── 1.png
│   └── test
│       ├── depth
│       │   └── frame_0043.dpt
│       └── rgb
│           ├── 0.png
│           └── frame_0043.png
├── dataset_capture.py
├── env.yml
├── hw
│   ├── calibration_module.py
│   ├── camera_capture_module.py
│   ├── detect_heds_module_path.py
│   ├── lens_control.ino
│   ├── lens_control.py
│   └── slm_display_module.py
├── image_loader.py
├── img
│   ├── citl-asm.png
│   ├── sgd-asm.png
│   ├── sgd-ours.png
│   └── teaser.png
├── main.py
├── params.py
├── prop_ideal.py
├── prop_model.py
├── prop_physical.py
├── prop_submodules.py
├── prop_zernike.py
├── train.py
├── unet.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.DS_Store
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Stanford Computational Imaging Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays (SIGGRAPH Asia 2021)
2 | ### [Project Page](http://www.computationalimaging.org/publications/neuralholography3d/) | [Video](https://www.youtube.com/watch?v=EsxGnUd8Efs) | [Paper](https://www.computationalimaging.org/wp-content/uploads/2021/08/NeuralHolography3D.pdf)
3 | PyTorch implementation of
4 | [Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays](http://www.computationalimaging.org/publications/neuralholography3d/)
5 | [Suyeon Choi](http://stanford.edu/~suyeon/)\*,
6 | [Manu Gopakumar](https://www.linkedin.com/in/manu-gopakumar-25032412b/)\*,
7 | [Yifan Peng](http://web.stanford.edu/~evanpeng/),
8 | [Jonghyun Kim](http://j-kim.kr/),
9 | [Gordon Wetzstein](https://computationalimaging.org)
10 | \*denotes equal contribution
11 | in SIGGRAPH Asia 2021
12 |
13 |
14 |
15 | ## Get started
16 | Our code uses [PyTorch Lightning](https://www.pytorchlightning.ai/) and requires PyTorch >= 1.10.0, since it relies on complex-valued tensor operations.
17 |
18 | You can set up a conda environment with all dependencies like so:
19 | ```
20 | conda env create -f env.yml
21 | conda activate neural3d
22 | ```
23 |
24 | Also, download the [PyCapture2 SDK](https://www.flir.com/products/flycapture-sdk/) and the [HOLOEYE SDK](https://holoeye.com/spatial-light-modulators/slm-software/slm-display-sdk/) (or use [slmPy](https://github.com/wavefrontshaping/slmPy)) and place the SDKs in your environment folder. If your hardware setup (SLM, camera, laser) differs from ours, make sure its Python API is hooked up correctly in ```prop_physical.py```.
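
If you are bringing your own hardware, the capture pipeline only needs three operations from your devices: display a phase pattern on the SLM, grab a camera frame, and warp that frame into simulation coordinates. The outline below is only a rough sketch of such a wrapper - the class and its ```encode``` mapping are hypothetical stand-ins, while ```show_data_from_array```, ```grab_images_fast```, and the ```Warper``` come from the modules in ```./hw/```; the actual interface used by this codebase lives in ```prop_physical.py```.
```
# Rough sketch only - adapt to your own SDKs; see prop_physical.py for the real implementation.
import math
import torch

class PhysicalPropSketch:
    """Display a phase pattern, capture the result, and warp it to simulation coordinates."""
    def __init__(self, slm, camera, warper):
        self.slm, self.camera, self.warper = slm, camera, warper

    def __call__(self, slm_phase):
        phase_img = self.encode(slm_phase)          # map [-pi, pi) phase to your SLM's 8-bit encoding
        self.slm.show_data_from_array(phase_img)    # hw/slm_display_module.py
        raw = self.camera.grab_images_fast()        # hw/camera_capture_module.py (float color image)
        intensity = torch.as_tensor(raw, dtype=torch.float32).sum(-1)[None, None]  # sum color channels
        warped = self.warper(intensity)             # hw/calibration_module.py Warper
        return warped.clamp(min=0).sqrt()           # return amplitude, as expected by the optimizer

    def encode(self, slm_phase):
        # placeholder linear encoding; a real SLM needs its own phase-to-gray lookup table
        return ((slm_phase + math.pi) / (2 * math.pi) * 255).to(torch.uint8).squeeze().cpu().numpy()
```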
25 |
26 | ## High-Level structure
27 |
28 | The code is organized as follows:
29 |
30 |
31 | `./`
32 | * ```main.py``` generates phase patterns from RGBD/RGB data using SGD.
33 | * ```train.py``` trains parameterized propagation models from captured data at multiple planes.
34 | * ```algorithms.py``` contains the gradient-descent-based algorithm for RGBD/RGB supervision.
35 |
36 |
37 |
38 | * ```prop_*.py``` contain the wave propagation operators (simulated and physical).
39 | * ```utils.py``` contains utility functions.
40 | * ```params.py``` contains our default parameter settings. :heavy_exclamation_mark:**(Replace values here with those in your setup.)**:heavy_exclamation_mark:
41 | * ```image_loader.py``` contains data loader modules.
42 |
43 | `./hw/` contains modules for hardware control and homography calibration.
44 |
45 | ## Step-by-step instructions
46 | ### 0. Set up parameters
47 | First off, update the values in ```params.py``` so that all of them match your configuration (wavelength, propagation distances, SLM resolution, SLM encoding, etc.); an illustrative example of such values follows the commands below. Then you're ready to use our codebase for your setup! Let's quickly test it by running the naive SGD algorithm with the following command:
48 | ```
49 | # Run naive SGD
50 | c=1; i=0; your_out_path=./results_first;
51 |
52 | # RGB supervision
53 | python main.py --data_path=./data/test/rgb --out_path=${your_out_path} --channel=$c --target=rgb --loss_func=l2 --lr=0.01 --num_iters=1000 --eval_plane_idx=$i
54 |
55 | # RGBD supervision
56 | python main.py --data_path=./data/test --out_path=${your_out_path} --channel=$c --target=rgbd --loss_func=l2 --lr=0.01 --num_iters=1000
57 |
58 | ```
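For reference, the values you set in ```params.py``` roughly describe your optical configuration. The snippet below is purely illustrative - the variable names and the wavelength/pitch numbers are placeholders, not the repo's actual defaults (only ```image_res```/```roi_res``` mirror the defaults used elsewhere in this code); replace everything with the values of your own SLM, laser, and target planes.
```
# Illustrative placeholders only - the real parameter names and defaults live in params.py.
cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9

example_setup = dict(
    wavelength=520 * nm,                    # e.g. green laser wavelength
    feature_size=(8.0 * um, 8.0 * um),      # SLM pixel pitch (y, x)
    slm_res=(1080, 1920),                   # SLM resolution (height, width)
    image_res=(800, 1280),                  # captured/simulated image resolution used in this repo
    roi_res=(700, 1190),                    # region of interest evaluated by the loss
    prop_dists=[d * mm for d in range(8)],  # e.g. 8 target planes spaced 1 mm apart
)
```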
59 | This will generate an SLM phase pattern in ```${your_out_path}```. Display it on your SLM and check out the holographic image that forms. While you will get an almost flawless image in simulation, you will unfortunately probably get imagery like the example below in your experimental setup. Here we show a captured image from our setup:
60 |
61 |
62 | This image degradation is primarily due to the mismatch between the simulated model you just used (ASM) and the actual propagation in the physical setup. To reduce this gap, we proposed the "camera-in-the-loop" (CITL) optimization technique, which we will try next!
63 |
64 | ### 1. Camera-in-the-loop calibration
65 | To run the CITL optimization, we first need to align the target plane captured through the sensor with the rectified simulated coordinates. To this end, we compute a homography between the two planes after displaying a dot pattern at the target plane (note that you can use your own calibration image!). You need to repeat this procedure for every plane. You can generate the SLM patterns for it with the following command (pass the index of the plane with ```--eval_plane_idx=$i``` for multiple planes); a condensed sketch of what the calibration computes appears after the command:
66 |
67 | ```
68 | # Generate homography calibration patterns
69 | python main.py --data_path=./data/calib --out_path=./calibration --channel=$c --target=2d --loss_func=l2 --eval_plane_idx=$i --full_roi=True
70 |
71 | ```
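In essence, the calibration fits a homography mapping the captured dot grid onto a reference grid in the rectified simulation frame. Below is a condensed, simplified sketch of that step from ```hw/calibration_module.py``` (the grid size and spacing defaults here are placeholders - pass the values of your own calibration pattern; the full module additionally blurs, thresholds, and blob-detects the image before the grid search):
```
# Condensed sketch of the homography fit in hw/calibration_module.py.
import cv2
import numpy as np

def fit_homography(captured_gray_uint8, num_circles=(11, 21), spacing_px=(80, 80)):
    # detect the displayed dot pattern as a symmetric circle grid
    found, centers = cv2.findCirclesGrid(captured_gray_uint8,
                                         (num_circles[1], num_circles[0]),
                                         flags=cv2.CALIB_CB_SYMMETRIC_GRID)
    if not found:
        return None

    # reference positions of the same dots in rectified (simulation) coordinates
    ref_pts = np.zeros((num_circles[0] * num_circles[1], 1, 2), np.float32)
    pos = 0
    for j in range(num_circles[0]):
        for i in range(num_circles[1]):
            ref_pts[pos, 0, :] = (spacing_px[1] * i, spacing_px[0] * j)
            pos += 1

    # robust homography from captured pixels to rectified coordinates
    H, _ = cv2.findHomography(centers, ref_pts, cv2.RANSAC, 1)
    return H
```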
72 | Now you're ready to run the camera-in-the-loop optimization! Before that, please make sure that all of your hardware components are ready and that the parameters in ```hw_params()``` in ```params.py``` are set correctly. For example, you need to set up the Python APIs of your SLM/sensor so that they can run "in the loop"; the differentiable trick behind CITL is sketched after the command below. For more information, check out the supplements of our papers: [[link1]](https://drive.google.com/file/d/1vay4xeg5iC7y8CLWR6nQEWe3mjBuqCWB/view) [[link2]](https://opg.optica.org/optica/viewmedia.cfm?uri=optica-8-2-143&seq=s001) [[link3]](https://drive.google.com/file/d/1FNSXBYivIN9hqDUxzUIMgSFVKJKtFEOr/view)
73 | ```
74 | # Camera-in-the-loop optimization
75 | python main.py --citl=True --data_path=./data/test --out_path=${your_out_path} --channel=$c --target=2d --loss_func=l2 --eval_plane_idx=$i
76 |
77 | ```
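The reason a non-differentiable camera can sit inside a gradient-descent loop is a straight-through estimator used in ```algorithms.py```: the captured amplitude provides the loss value, while the simulated model provides the gradients. A minimal self-contained sketch of one such iteration (```forward_prop``` is the simulated propagation, ```camera_prop``` displays the phase and returns the captured amplitude):
```
import torch

def citl_step(slm_phase, target_amp, forward_prop, camera_prop, loss_fn=torch.nn.MSELoss()):
    """One CITL iteration: the loss is evaluated on captured amplitudes,
    but gradients w.r.t. the SLM phase flow through the simulated model."""
    recon_amp = forward_prop(slm_phase).abs()       # differentiable simulation (e.g. ASM)
    captured_amp = camera_prop(slm_phase)           # display + capture, not differentiable
    surrogate = recon_amp + captured_amp - recon_amp.detach()  # straight-through estimator
    with torch.no_grad():
        s = (surrogate * target_amp).mean() / (surrogate ** 2).mean()  # scale minimizing MSE
    loss = loss_fn(s * surrogate, target_amp)
    loss.backward()
    return loss.item()
```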
78 | With the phase pattern generated by this code, you will get experimental results like below:
79 |
80 |
81 |
82 | ### 2. Forward model training
83 | Although camera-in-the-loop optimization shows a significant improvement over the naive approach, it still has limitations: it requires a camera in the loop for every iteration and every target image, which may not be practical, and it is hard to extend to 3D since the lens focus would have to change at every iteration. To overcome these limitations, we proposed training a parameterized wave propagation model instead of optimizing a single phase pattern. After training it once, you can use the model either to optimize phase patterns for 3D targets or as a loss function for training an inverse network.
84 |
85 | #### 2-1. Dataset generation
86 | To train the model, we first create a set of thousands of phase patterns that vary in initialization randomness, target image, number of iterations, generation method, etc.
87 | For the dataset images, we used the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) for our training - download it and put the images in your ```${data_path}```.
88 |
89 | ```
90 | python main.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=$c --target=2d --loss_func=l2 --prop_model_path=${model_path} --random_gen=True
91 |
92 | ```
93 |
94 | #### 2-2. Dataset capture
95 | Then, display all of the generated phase patterns and capture the resulting images for every plane.
96 | ```
97 | for i in 0 1 ... 7
98 | do
99 | python dataset_capture.py --channel=$c --plane_idx=$i
100 | done
101 | ```
102 | We also release a subset of our dataset, which you can download from here: [[Big Dataset (~220GB)]](https://drive.google.com/file/d/1E9ppFPwueGwRTG9yRk9wbB3Xy7eOOGOI/view?usp=sharing), [[Small Dataset (~60GB)]](https://drive.google.com/file/d/1EC2pzHlsB_P_braGc1r71oKt9vlmwzlq/view?usp=sharing).
103 | #### 2-3. Forward model training
104 | With thousands of pairs of `(phase pattern, captured intensities)`, you can now train any parameterized model by letting it predict the intensities at multiple planes; a minimal sketch of this training objective follows the command below. We implemented several architectures for the models described in our paper: the original [Neural Holography](https://www.computationalimaging.org/publications/neuralholography/) model, the hardware-in-the-loop model (Chakravarthula et al., 2020), and three variants of our [NH3D](https://www.computationalimaging.org/publications/neuralholography3d/) model (CNNprop, propCNN, CNNpropCNN).
105 | Train the Neural Holography models (SIGGRAPH Asia 2020, 2021) with the same codebase!
106 | ```
107 | # try nh, hil, cnnprop, propcnn, cnnpropcnn for ${model}.
108 | python train.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=${channel} --lr=4e-4 --num_train_planes=${num_train_planes}
109 | ```
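In essence, the training regresses the model's per-plane output onto the captured focal stacks loaded by ```PairsLoader``` in ```image_loader.py```. The loop below is only a schematic under assumptions - the real training script is ```train.py``` with PyTorch Lightning, and ```model``` here stands for any network that maps an SLM phase tensor to amplitudes at the training planes (at the same resolution as the captures):
```
# Schematic forward-model training loop; `model` is assumed to map a [1, 1, H, W] phase tensor
# to [1, num_planes, H, W] predicted amplitudes (an assumption, not the repo's exact API).
import torch
import image_loader

def train_forward_model(model, data_path, plane_idxs=range(8), num_epochs=1, lr=4e-4, dev='cuda'):
    loader = image_loader.PairsLoader(data_path, plane_idxs=list(plane_idxs))
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(num_epochs):
        for phase, captured_amps in loader:          # (phase pattern, captured focal stack)
            phase = phase.unsqueeze(0).to(dev)
            captured_amps = captured_amps.unsqueeze(0).to(dev)
            pred_amps = model(phase)                 # predicted amplitudes at all planes
            loss = torch.nn.functional.mse_loss(pred_amps, captured_amps)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model
```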
110 |
111 | Repeat the procedures from 2-1 to 2-3 using your trained model.
112 |
113 | #### 2-4. Run SGD-Model (with the trained forward model)
114 | After training, simply run the phase generation script you ran at the [beginning](#step-by-step-instructions), adding a pointer to your trained model with `--prop_model_path=${model_path}`.
115 | ```
116 | # RGB
117 | python main.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=$c --target=2d --loss_func=l2 --prop_model_path=${model_path} --eval_plane_idx=$i
118 |
119 | # RGBD supervision
120 | python main.py --prop_model=${model} --data_path=${data_path} --out_path=${out_path} --channel=$c --target=rgbd --loss_func=l2 --prop_model_path=${model_path}
121 | ```
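Under the hood this is the same gradient-descent loop from ```algorithms.py```; the only difference is that the trained network replaces the ideal ASM operator as ```forward_prop```. A schematic of that swap, assuming you have already loaded your trained model (the SLM resolution below is a placeholder; ```roi_res``` mirrors the default in ```image_loader.py```):
```
# Schematic only: plug the trained forward model into the phase optimizer from algorithms.py.
import math
import torch
import algorithms

def optimize_phase(trained_model, target_amp, slm_res=(1080, 1920), num_iters=1000, dev='cuda'):
    init_phase = (-0.5 + torch.rand(1, 1, *slm_res, device=dev)) * 2 * math.pi  # random initial phase
    results = algorithms.gradient_descent(init_phase, target_amp.to(dev),
                                          forward_prop=trained_model,  # learned model instead of ASM
                                          num_iters=num_iters, lr=0.01,
                                          roi_res=(700, 1190))         # ROI used by the loss
    return results['final_phase']
```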
122 |
123 | Using the commands above, we can now achieve a holographic image like the one below in the same setup. Please also try RGBD supervision and the other methods we discuss in our paper!
124 |
125 |
126 | ### Citation
127 | If you find our work useful in your research, please cite:
128 | ```
129 | @article{choi2021neural,
130 | title={Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays},
131 | author={Choi, Suyeon and Gopakumar, Manu and Peng, Yifan and Kim, Jonghyun and Wetzstein, Gordon},
132 | journal={ACM Transactions on Graphics (TOG)},
133 | volume={40},
134 | number={6},
135 | pages={1--12},
136 | year={2021},
137 | publisher={ACM New York, NY, USA}
138 | }
139 | ```
140 |
141 | ### Contact
142 | If you have any questions, please email Suyeon Choi at suyeon@stanford.edu.
143 |
--------------------------------------------------------------------------------
/algorithms.py:
--------------------------------------------------------------------------------
1 | """
2 | Various algorithms for RGBD/RGB supervision.
3 | """
4 |
5 | import imageio
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from tqdm import tqdm
10 | import utils
11 |
12 |
13 | def load_alg(alg_type):
14 | if alg_type.lower() in ('sgd', 'admm', 'gd', 'gradient-descent'):
15 | algorithm = gradient_descent
16 |
17 | return algorithm
18 |
19 |
20 | def gradient_descent(init_phase, target_amp, target_mask=None, forward_prop=None, num_iters=1000, roi_res=None,
21 | loss_fn=nn.MSELoss(), lr=0.01, out_path_idx='./results',
22 | citl=False, camera_prop=None, writer=None, admm_opt=None,
23 | *args, **kwargs):
24 | """
25 | Gradient-descent based method for phase optimization.
26 |
27 | :param init_phase:
28 | :param target_amp:
29 | :param target_mask:
30 | :param forward_prop:
31 | :param num_iters:
32 | :param roi_res:
33 | :param loss_fn:
34 | :param lr:
35 | :param out_path_idx:
36 | :param citl:
37 | :param camera_prop:
38 | :param writer:
39 | :param args:
40 | :param kwargs:
41 | :return:
42 | """
43 |
44 | assert forward_prop is not None
45 | dev = init_phase.device
46 | num_iters_admm_inner = 1 if admm_opt is None else admm_opt['num_iters_inner']
47 |
48 | slm_phase = init_phase.requires_grad_(True) # phase at the slm plane
49 | optvars = [{'params': slm_phase}]
50 | optimizer = optim.Adam(optvars, lr=lr)
51 |
52 | loss_vals = []
53 | loss_vals_quantized = []
54 | best_loss = 10.
55 | best_iter = 0
56 |
57 | if target_mask is not None:
58 | target_amp = target_amp * target_mask
59 | nonzeros = target_mask > 0
60 | if roi_res is not None:
61 | target_amp = utils.crop_image(target_amp, roi_res, stacked_complex=False)
62 | if target_mask is not None:
63 | target_mask = utils.crop_image(target_mask, roi_res, stacked_complex=False)
64 | nonzeros = target_mask > 0
65 |
66 | if admm_opt is not None:
67 | u = torch.zeros(1, 1, *roi_res).to(dev)
68 |         z = torch.zeros(1, 1, *roi_res).to(dev)
69 |         rho = admm_opt['rho']  # initialize the ADMM penalty parameter (updated below if varying-penalty)
70 | for t in range(num_iters):
71 | for t_inner in range(num_iters_admm_inner):
72 | optimizer.zero_grad()
73 |
74 | recon_field = forward_prop(slm_phase)
75 | recon_field = utils.crop_image(recon_field, roi_res, stacked_complex=False)
76 | recon_amp = recon_field.abs()
77 |
78 | if citl: # surrogate gradients for CITL
79 | captured_amp = camera_prop(slm_phase)
80 | captured_amp = utils.crop_image(captured_amp, roi_res,
81 | stacked_complex=False)
82 | recon_amp = recon_amp + captured_amp - recon_amp.detach()
83 |
84 | if target_mask is not None:
85 | final_amp = torch.zeros_like(recon_amp)
86 | final_amp[nonzeros] += (recon_amp[nonzeros] * target_mask[nonzeros])
87 | else:
88 | final_amp = recon_amp
89 |
90 | with torch.no_grad():
91 | s = (final_amp * target_amp).mean() / \
92 |                 (final_amp ** 2).mean()  # scale minimizing MSE btw recon and target
93 |
94 | loss_val = loss_fn(s * final_amp, target_amp)
95 |
96 | # second loss term if ADMM
97 | if admm_opt is not None:
98 | # augmented lagrangian
99 | recon_phase = recon_field.angle()
100 | loss_prior = loss_fn(utils.laplacian(recon_phase) * target_mask, (z - u) * target_mask)
101 | loss_val = loss_val + admm_opt['rho'] * loss_prior
102 |
103 | loss_val.backward()
104 | optimizer.step()
105 |
106 | ## ADMM steps
107 | if admm_opt is not None:
108 | with torch.no_grad():
109 | reg_norm = utils.laplacian(recon_phase).detach() * target_mask
110 | Ax = admm_opt['alpha'] * reg_norm + (1 - admm_opt['alpha']) * z # over-relaxation
111 | z = utils.soft_thresholding(u + Ax, admm_opt['gamma'] / (rho + 1e-10))
112 | u = u + Ax - z
113 |
114 | # varying penalty (rho)
115 | if admm_opt['varying-penalty']:
116 | if t == 0:
117 | z_prev = z
118 |
119 | r_k = ((reg_norm - z).detach() ** 2).mean() # primal residual
120 | s_k = ((rho * utils.laplacian(z_prev - z).detach()) ** 2).mean() # dual residual
121 |
122 | if r_k > admm_opt['mu'] * s_k:
123 | rho = admm_opt['tau_incr'] * rho
124 | u /= admm_opt['tau_incr']
125 | elif s_k > admm_opt['mu'] * r_k:
126 | rho /= admm_opt['tau_decr']
127 | u *= admm_opt['tau_decr']
128 | z_prev = z
129 |
130 | with torch.no_grad():
131 | if loss_val < best_loss:
132 | best_phase = slm_phase
133 | best_loss = loss_val
134 | best_amp = s * recon_amp
135 | best_iter = t + 1
136 | print(f' -- optimization is done, best loss: {best_loss}')
137 |
138 | return {'loss_vals': loss_vals,
139 | 'loss_vals_q': loss_vals_quantized,
140 | 'best_iter': best_iter,
141 | 'best_loss': best_loss,
142 | 'recon_amp': best_amp,
143 | 'final_phase': best_phase}
--------------------------------------------------------------------------------
/data/calib/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/calib/1.png
--------------------------------------------------------------------------------
/data/test/depth/frame_0043.dpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/test/depth/frame_0043.dpt
--------------------------------------------------------------------------------
/data/test/rgb/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/test/rgb/0.png
--------------------------------------------------------------------------------
/data/test/rgb/frame_0043.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/data/test/rgb/frame_0043.png
--------------------------------------------------------------------------------
/dataset_capture.py:
--------------------------------------------------------------------------------
1 | """
2 | Display and capture all the phase patterns in a folder.
3 | 201203
4 |
5 | """
6 |
7 | import os
8 | import time
9 | import concurrent
10 | import cv2
11 | import skimage.io
12 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
13 | import numpy as np
14 | import torch
15 | import imageio
16 | import configargparse
17 |
18 | import utils
19 | import props
20 | import params
21 |
22 | def save_image(raw_data, file_path, demosaickrule, warper, dev):
23 | raw_data = raw_data - 64
24 | color_cv_image = cv2.cvtColor(raw_data, demosaickrule) # it gives float64 from uint16
25 | captured_intensity = utils.im2float(color_cv_image) # float64 to float32
26 |
27 | # Numpy to tensor
28 | captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32,
29 | device=dev).permute(2, 0, 1).unsqueeze(0)
30 | captured_intensity = torch.sum(captured_intensity, dim=1, keepdim=True)
31 |
32 | warped_intensity = warper(captured_intensity)
33 | imageio.imwrite(file_path, (np.clip(warped_intensity.squeeze().cpu().detach().numpy(), 0.0, 1.0)
34 | * np.iinfo(np.uint16).max).round().astype(np.uint16))
35 |
36 | if __name__ == '__main__':
37 | # parse arguments
38 | # Command line argument processing / Parameters
39 | p = configargparse.ArgumentParser()
40 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
41 | params.add_parameters(p, 'eval')
42 | opt = p.parse_args()
43 | params.set_configs(opt)
44 | dev = torch.device('cuda')
45 |
46 | # hardware setup
47 | params_slm, params_camera, params_calib = params.hw_params(opt)
48 | camera_prop = props.PhysicalProp(params_slm, params_camera, params_calib).to(dev)
49 |
50 | data_path = opt.data_path # pass a path of your dataset folder like f'F:/dataset/green'
51 |
52 |     if opt.chan_str not in data_path:
53 | raise ValueError('Double check the color!')
54 |
55 | ds = ['test', 'val', 'train']
56 | data_paths = [os.path.join(data_path, d) for d in ds]
57 | for root_path in data_paths:
58 |
59 | # load phases
60 | phase_path = f'{root_path}/phase'
61 | captured_path = f'{root_path}/captured'
62 | utils.cond_mkdir(captured_path)
63 | names = os.listdir(f'{phase_path}')
64 |
65 | # run multiple thread for fast capture
66 | with concurrent.futures.ThreadPoolExecutor() as executor:
67 | for ii, full_name in enumerate(names):
68 | t0 = time.perf_counter()
69 |
70 | filename = f'{phase_path}/{full_name}'
71 | phase_img = skimage.io.imread(filename)
72 | raw_uint16_data = camera_prop.capture_uint16(phase_img) # display & retrieve buffer
73 |
74 | out_full_path = os.path.join(captured_path, f'{full_name[:-4]}_{opt.plane_idx}.png')
75 | executor.submit(save_image,
76 | raw_uint16_data,
77 | out_full_path,
78 | camera_prop.camera.demosaick_rule,
79 | camera_prop.warper, dev)
80 |
81 | print(f'{out_full_path}: '
82 | f'{1 / (time.perf_counter() - t0):.2f}Hz')
83 |
84 | if ii % 500 == 0 and ii > 0:
85 | time.sleep(10.0)
86 |
87 | camera_prop.disconnect()
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: neural3d
2 | channels:
3 |   - defaults
4 | dependencies:
5 |   - python=3.6
6 |   - numpy
7 |   - torchvision
8 |   - torchaudio
9 |   - cudatoolkit=11.3
10 |   - opencv
11 | prefix: C:\Users\suyeon\.conda\envs\flex
12 |
13 |
--------------------------------------------------------------------------------
/hw/calibration_module.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the script containing the calibration module, basically calculating homography matrix.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 | Technical Paper:
10 | S. Choi, M. Gopakumar, Y. Peng, J. Kim, G. Wetzstein. Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays. ACM TOG (SIGGRAPH Asia), 2021.
11 | """
12 |
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | import torchvision
16 | import cv2
17 | import skimage.transform as transform
18 | import time
19 | import datetime
20 | from scipy.io import savemat
21 | from scipy.ndimage import map_coordinates
22 | import torch
23 | import torch.nn.functional as F
24 | import torch.nn as nn
25 |
26 | def id(x):
27 | return x
28 |
29 | def circle_detect(captured_img, num_circles, spacing, pad_pixels=(0., 0.), show_preview=True, quadratic=False):
30 | """
31 | Detects the circle of a circle board pattern
32 |
33 | :param captured_img: captured image
34 | :param num_circles: a tuple of integers, (num_circle_x, num_circle_y)
35 | :param spacing: a tuple of integers, in pixels, (space between circles in x, space btw circs in y direction)
36 | :param show_preview: boolean, default True
37 | :param pad_pixels: coordinate of the left top corner of warped image.
38 | Assuming pad this amount of pixels on the other side.
39 |     :return: a tuple, (found_dots, H, coords)
40 |         found_dots: boolean, indicating success of calibration
41 |         H: a 3x3 homography matrix (numpy); coords: warp coordinates for grid_sample
42 | """
43 |
44 | # Binarization
45 | # org_copy = org.copy() # Otherwise, we write on the original image!
46 | img = (captured_img.copy() * 255).astype(np.uint8)
47 | if len(img.shape) > 2:
48 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
49 | # print(img[...,0].mean())
50 | # print(img[...,1].mean())
51 | # print(img[...,2].mean())
52 |
53 | #img = cv2.medianBlur(img, 31)
54 | img = cv2.medianBlur(img, 55) # Red 71
55 | # img = cv2.medianBlur(img, 5) #210104
56 | img_gray = img.copy()
57 |
58 | # img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 121, 0)
59 | img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 127, 0)
60 |
61 | # img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 117, 0) # Red 127
62 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
63 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (31, 31))
64 | img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
65 | img = 255 - img
66 |
67 | # Blob detection
68 | params = cv2.SimpleBlobDetector_Params()
69 |
70 | # Change thresholds
71 | params.filterByColor = True
72 | params.minThreshold = 150
73 | params.minThreshold = 121
74 | # params.minThreshold = 121
75 |
76 | # Filter by Area.
77 | params.filterByArea = True
78 | params.minArea = 150
79 | # params.minArea = 80 # Red 120
80 | # params.minArea = 30 # 210104
81 | # params.maxArea = 100 # 210104
82 |
83 | # Filter by Circularity
84 | params.filterByCircularity = True
85 | params.minCircularity = 0.85
86 | params.minCircularity = 0.60
87 |
88 | # Filter by Convexity
89 | params.filterByConvexity = True
90 | params.minConvexity = 0.87
91 | # params.minConvexity = 0.80
92 |
93 | # Filter by Inertia
94 | params.filterByInertia = False
95 | params.minInertiaRatio = 0.01
96 |
97 | detector = cv2.SimpleBlobDetector_create(params)
98 |
99 | # Detecting keypoints
100 | # this is redundant for what comes next, but gives us access to the detected dots for debug
101 | keypoints = detector.detect(img)
102 | found_dots, centers = cv2.findCirclesGrid(img, (num_circles[1], num_circles[0]),
103 | blobDetector=detector, flags=cv2.CALIB_CB_SYMMETRIC_GRID)
104 |
105 | # Drawing the keypoints
106 | cv2.drawChessboardCorners(captured_img, num_circles, centers, found_dots)
107 | img_gray = cv2.drawKeypoints(img_gray, keypoints, np.array([]), (0, 255, 0),
108 | cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
109 |
110 |     # Find transformation (coords stays None if the circle grid is not found)
111 |     coords = None
112 |     H = np.array([[1., 0., 0.],
113 |                   [0., 1., 0.], [0., 0., 1.]], dtype=np.float32)
114 | if found_dots:
115 | # Generate reference points to compute the homography
116 | ref_pts = np.zeros((num_circles[0] * num_circles[1], 1, 2), np.float32)
117 | pos = 0
118 | for j in range(0, num_circles[0]):
119 | for i in range(0, num_circles[1]):
120 | ref_pts[pos, 0, :] = spacing * np.array([i, j]) + np.array([pad_pixels[1], pad_pixels[0]])
121 |
122 | pos += 1
123 |
124 |
125 | H, mask = cv2.findHomography(centers, ref_pts, cv2.RANSAC, 1)
126 |
127 | ref_pts = ref_pts.reshape(num_circles[0] * num_circles[1], 2)
128 | centers = np.flip(centers.reshape(num_circles[0] * num_circles[1], 2), 1)
129 |
130 |
131 | now = datetime.datetime.now()
132 | mdic = {"centers": centers, 'H': H}
133 | # savemat(f'F:/2021/centers_coords/centers_{now.strftime("%m%d_%H%M%S")}.mat', mdic)
134 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs)
135 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels) ]
136 | if quadratic:
137 | H = transform.estimate_transform('polynomial', ref_pts, centers)
138 | coords = transform.warp_coords(H, dsize, dtype=np.float32) # for pytorch
139 | else:
140 | tf = transform.estimate_transform('projective', ref_pts, centers)
141 | coords = transform.warp_coords(tf, (800, 1280), dtype=np.float32) # for pytorch
142 |
143 | if show_preview:
144 | dsize = [int((num_circs - 1) * space + 2 * pad_pixs)
145 | for num_circs, space, pad_pixs in zip(num_circles, spacing, pad_pixels)]
146 | if quadratic:
147 | captured_img_warp = transform.warp(captured_img, H, output_shape=(dsize[0], dsize[1]))
148 | else:
149 | captured_img_warp = cv2.warpPerspective(captured_img, H, (dsize[1], dsize[0]))
150 |
151 |
152 | if show_preview:
153 | fig = plt.figure()
154 |
155 | ax = fig.add_subplot(223)
156 | ax.imshow(img_gray, cmap='gray')
157 |
158 | ax2 = fig.add_subplot(221)
159 | ax2.imshow(img, cmap='gray')
160 |
161 | ax3 = fig.add_subplot(222)
162 | ax3.imshow(captured_img, cmap='gray')
163 |
164 | if found_dots:
165 | ax4 = fig.add_subplot(224)
166 | ax4.imshow(captured_img_warp, cmap='gray')
167 |
168 | plt.show()
169 |
170 | return found_dots, H, coords
171 |
172 |
173 | class Warper(nn.Module):
174 | def __init__(self, params_calib):
175 | super(Warper, self).__init__()
176 | self.num_circles = params_calib.num_circles
177 | self.spacing_size = params_calib.spacing_size
178 | self.pad_pixels = params_calib.pad_pixels
179 | self.quadratic = params_calib.quadratic
180 | self.img_size_native = params_calib.img_size_native # get this from image
181 | self.h_transform = np.array([[1., 0., 0.],
182 | [0., 1., 0.],
183 | [0., 0., 1.]])
184 | self.range_x = params_calib.range_x # slice
185 | self.range_y = params_calib.range_y # slice
186 |
187 |
188 | def calibrate(self, img, show_preview=True):
189 | img_masked = np.zeros_like(img)
190 | img_masked[self.range_y, self.range_x, ...] = img[self.range_y, self.range_x, ...]
191 |
192 | found_corners, self.h_transform, self.coords = circle_detect(img_masked, self.num_circles,
193 | self.spacing_size, self.pad_pixels, show_preview,
194 | quadratic=self.quadratic)
195 |         if self.coords is not None:
196 | self.coords_tensor = torch.tensor(np.transpose(self.coords, (1, 2, 0)),
197 | dtype=torch.float32).unsqueeze(0)
198 |
199 | # normalize it into [-1, 1]
200 | # self.coords_tensor[..., 0] = self.coords_tensor[..., 0] / (self.img_size_native[1]//2) - 1
201 | # self.coords_tensor[..., 1] = self.coords_tensor[..., 1] / (self.img_size_native[0]//2) - 1
202 | self.coords_tensor[..., 0] = 2*self.coords_tensor[..., 0] / (self.img_size_native[1]-1) - 1
203 | self.coords_tensor[..., 1] = 2*self.coords_tensor[..., 1] / (self.img_size_native[0]-1) - 1
204 |
205 | return found_corners
206 |
207 | def __call__(self, input_img, img_size=None):
208 | """
209 | This forward pass returns the warped image.
210 |
211 | :param input_img: A numpy grayscale image shape of [H, W].
212 | :param img_size: output size, default None.
213 | :return: output_img: warped image with pre-calculated homography and destination size.
214 | """
215 |
216 | if img_size is None:
217 | img_size = [int((num_circs - 1) * space + 2 * pad_pixs)
218 | for num_circs, space, pad_pixs in zip(self.num_circles, self.spacing_size, self.pad_pixels)]
219 |
220 | if torch.is_tensor(input_img):
221 | # output_img = F.grid_sample(input_img, self.coords_tensor, align_corners=False)
222 | output_img = F.grid_sample(input_img, self.coords_tensor, align_corners=True)
223 | else:
224 | if self.quadratic:
225 | output_img = transform.warp(input_img, self.h_transform, output_shape=(img_size[0], img_size[1]))
226 | else:
227 | output_img = cv2.warpPerspective(input_img, self.h_transform, (img_size[0], img_size[1]))
228 |
229 | return output_img
230 |
231 | @property
232 | def h_transform(self):
233 | return self._h_transform
234 |
235 | @h_transform.setter
236 | def h_transform(self, new_h):
237 | self._h_transform = new_h
238 |
239 | def to(self, *args, **kwargs):
240 | slf = super().to(*args, **kwargs)
241 | if slf.coords_tensor is not None:
242 | slf.coords_tensor = slf.coords_tensor.to(*args, **kwargs)
243 | try:
244 | slf.dev = next(slf.parameters()).device
245 | except StopIteration: # no parameters
246 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0]
247 | if device_arg is not None:
248 | slf.dev = device_arg
249 | return slf
--------------------------------------------------------------------------------
/hw/camera_capture_module.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the script containing the camera capture module, which grabs and demosaicks frames from the camera via PyCapture2.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 | Technical Paper:
10 | S. Choi, M. Gopakumar, Y. Peng, J. Kim, G. Wetzstein. Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays. ACM TOG (SIGGRAPH Asia), 2021.
11 | """
12 |
13 | import PyCapture2
14 | import cv2
15 | import numpy as np
16 | import time
17 | import utils
18 |
19 |
20 | def callback_captured(image):
21 | print(image.getData())
22 |
23 |
24 | class CameraCapture:
25 | def __init__(self, params):
26 | self.bus = PyCapture2.BusManager()
27 | num_cams = self.bus.getNumOfCameras()
28 | if not num_cams:
29 | exit()
30 | # self.demosaick_rule = cv2.COLOR_BAYER_RG2BGR
31 | self.demosaick_rule = cv2.COLOR_BAYER_GR2RGB # GBRG to RGB
32 | self.params = params
33 |
34 | def connect(self, i, trigger=False):
35 | uid = self.bus.getCameraFromIndex(i)
36 | self.camera_device = PyCapture2.Camera()
37 | self.camera_device.connect(uid)
38 | self.camera_device.setConfiguration(highPerformanceRetrieveBuffer=True)
39 | self.camera_device.setConfiguration(numBuffers=1000)
40 | config = self.camera_device.getConfiguration()
41 | self.toggle_embedded_timestamp(True)
42 |
43 | if trigger:
44 | trigger_mode = self.camera_device.getTriggerMode()
45 | trigger_mode.onOff = True
46 | trigger_mode.mode = 0
47 | trigger_mode.parameter = 0
48 | trigger_mode.source = 3 # Using software trigger
49 | self.camera_device.setTriggerMode(trigger_mode)
50 | else:
51 | trigger_mode = self.camera_device.getTriggerMode()
52 | trigger_mode.onOff = False
53 | trigger_mode.mode = 0
54 | trigger_mode.parameter = 0
55 | trigger_mode.source = 3 # Using software trigger
56 | self.camera_device.setTriggerMode(trigger_mode)
57 |
58 | trigger_mode = self.camera_device.getTriggerMode()
59 | if trigger_mode.onOff is True:
60 | print(' - setting trigger mode on')
61 |
62 |
63 | def disconnect(self):
64 | self.toggle_embedded_timestamp(False)
65 | self.camera_device.disconnect()
66 |
67 | def toggle_embedded_timestamp(self, enable_timestamp):
68 | embedded_info = self.camera_device.getEmbeddedImageInfo()
69 | if embedded_info.available.timestamp:
70 | self.camera_device.setEmbeddedImageInfo(timestamp=enable_timestamp)
71 |
72 | def grab_images(self, num_images_to_grab=1):
73 | """
74 | Retrieve the camera buffer and returns a list of grabbed images.
75 |
76 | :param num_images_to_grab: integer, default 1
77 | :return: a list of numpy 2d color images from the camera buffer.
78 | """
79 | self.camera_device.startCapture()
80 | img_list = []
81 | for i in range(num_images_to_grab):
82 | imgData = self.retrieve_buffer()
83 | offset = 64 # offset that inherently exist.retrieve_buffer
84 | imgData = imgData - offset
85 |
86 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule)
87 | color_cv_image = utils.im2float(color_cv_image)
88 | img_list.append(color_cv_image.copy())
89 |
90 | self.camera_device.stopCapture()
91 | return img_list
92 |
93 | def grab_images_fast(self, num_images_to_grab=1):
94 | """
95 | Retrieve the camera buffer and returns a grabbed image
96 |
97 | :param num_images_to_grab: integer, default 1
98 | :return: a list of numpy 2d color images from the camera buffer.
99 | """
100 | imgData = self.retrieve_buffer()
101 | offset = 64 # offset that inherently exist.
102 | imgData = imgData - offset
103 |
104 | color_cv_image = cv2.cvtColor(imgData, self.demosaick_rule)
105 | color_cv_image = utils.im2float(color_cv_image)
106 | color_img = color_cv_image
107 | return color_img
108 |
109 | def retrieve_buffer(self):
110 | try:
111 | img = self.camera_device.retrieveBuffer()
112 | except PyCapture2.Fc2error as fc2Err:
113 | raise fc2Err
114 |
115 | imgData = img.getData()
116 |
117 | # when using raw8 from the PG sensor
118 | # cv_image = np.array(img.getData(), dtype="uint8").reshape((img.getRows(), img.getCols()))
119 |
120 | # when using raw16 from the PG sensor - concat 2 8bits in a row
121 | imgData.dtype = np.uint16
122 | imgData = imgData.reshape(img.getRows(), img.getCols())
123 | return imgData.copy()
124 |
125 | def start_capture(self):
126 | # these two were previously inside the grab_images func, and can be clarified outside the loop
127 | self.camera_device.startCapture()
128 |
129 | def stop_capture(self):
130 | self.camera_device.stopCapture()
131 |
132 | @property
133 | def params(self):
134 | return self._params
135 |
136 | @params.setter
137 | def params(self, p):
138 | self._params = p
--------------------------------------------------------------------------------
/hw/detect_heds_module_path.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | #--------------------------------------------------------------------#
4 | # #
5 | # Copyright (C) 2020 HOLOEYE Photonics AG. All rights reserved. #
6 | # Contact: https://holoeye.com/contact/ #
7 | # #
8 | # This file is part of HOLOEYE SLM Display SDK. #
9 | # #
10 | # You may use this file under the terms and conditions of the #
11 | # 'HOLOEYE SLM Display SDK Standard License v1.0' license agreement. #
12 | # #
13 | #--------------------------------------------------------------------#
14 |
15 |
16 | # Please import this file in your scripts before actually importing the HOLOEYE SLM Display SDK,
17 | # i. e. copy this file to your project and use this code in your scripts:
18 | #
19 | # import detect_heds_module_path
20 | # import holoeye
21 | #
22 | #
23 | # Another option is to copy the holoeye module directory into your project and import by only using
24 | # import holoeye
25 | # This way, code completion etc. might work better.
26 |
27 |
28 | import os, sys
29 | from platform import system
30 |
31 | # Import the SLM Display SDK:
32 | HEDSModulePath = os.getenv('HEDS_2_PYTHON_MODULES', '')
33 |
34 | if HEDSModulePath == '':
35 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__),
36 | 'holoeye', 'slmdisplaysdk', '__init__.py'))
37 | if os.path.isfile(sdklocal):
38 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal)))
39 | else:
40 | sdklocal = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',
41 | 'sdk', 'holoeye', 'slmdisplaysdk', '__init__.py'))
42 | if os.path.isfile(sdklocal):
43 | HEDSModulePath = os.path.dirname(os.path.dirname(os.path.dirname(sdklocal)))
44 |
45 | if HEDSModulePath == '':
46 | if system() == 'Windows':
47 | print('\033[91m'
48 | '\nError: Could not find HOLOEYE SLM Display SDK installation path from environment variable. '
49 | '\n\nPlease relogin your Windows user account and try again. '
50 | '\nIf that does not help, please reinstall the SDK and then relogin your user account and try again. '
51 | '\nA simple restart of the computer might fix the problem, too.'
52 | '\033[0m')
53 | else:
54 | print('\033[91m'
55 | '\nError: Could not detect HOLOEYE SLM Display SDK installation path. '
56 | '\n\nPlease make sure it is present within the same folder or in "../../sdk".'
57 | '\033[0m')
58 |
59 | sys.exit(1)
60 |
61 | sys.path.append(HEDSModulePath)
62 |
--------------------------------------------------------------------------------
/hw/lens_control.ino:
--------------------------------------------------------------------------------
1 | // Autofocus control code of Canon EF-S18-55mm f/3.5-5.6 IS II
2 | // helpful instructions:
3 | // https://docplayer.net/20752779-How-to-move-canon-ef-lenses-yosuke-bando.html
4 | // https://github.com/crescentvenus/EF-Lens-CONTROL/blob/master/EF-Lens-control.ino
5 | // https://gist.github.com/marcan/858c242db2fc595da1e0bb70a05192fc
6 | // Contact: Suyeon Choi (suyeon@stanford.edu)
7 |
8 |
9 | #include <SPI.h>
10 | #include <EEPROM.h>
11 |
12 | const int HotShoe_Pin = 8;
13 | const int HotShoe_Gnd = 9;
14 | const int LogicVDD_Pin = 10;
15 | const int Cam2Lens_Pin = 11;
16 | const int Clock_Pin = 13;
17 |
18 | // manually set for 11 planes (the last one is dummy)
19 | // *** Calibrate these focus values for your own lens ***
20 | // Inf to 3D (0D, 0.3D, ... 3D) gotta correspond to some range in your physical setup
21 | const int focus[12] = {2047, 1975, 1896, 1830, 1795, 1750, 1655, 1580, 1560, 1510, 1460, 0};
22 |
23 | const int current_focus = 0;
24 | #define INPUT_SIZE 30
25 |
26 | void init_lens() {
27 | // Before sending signal to lens you should do this
28 |
29 | SPI.transfer(0x0A);
30 | delay(30);
31 | SPI.transfer(0x00);
32 | delay(30);
33 | SPI.transfer(0x0A);
34 | delay(30);
35 | SPI.transfer(0x00);
36 | delay(30);
37 | }
38 |
39 | void setup() // initialization
40 | {
41 | Serial.begin(9600);
42 |
43 | pinMode(LogicVDD_Pin, OUTPUT);
44 | digitalWrite(LogicVDD_Pin, HIGH);
45 | pinMode(Cam2Lens_Pin, OUTPUT);
46 | pinMode(Clock_Pin, OUTPUT);
47 | digitalWrite(Clock_Pin, HIGH);
48 | SPI.beginTransaction(SPISettings(9600, MSBFIRST, SPI_MODE3));
49 | move_focus_infinity();
50 | }
51 |
52 | void loop() {
53 | char input[INPUT_SIZE + 1];
54 | byte size = Serial.readBytes(input, INPUT_SIZE);
55 | // Add the final 0 to end the C string
56 | input[size] = 0;
57 |
58 | // Read each command
59 | char* command = strtok(input, ",");
60 | while (command != 0)
61 | {
62 | // input command is assumed to be an integer.
63 | int idx_plane = atoi(command);
64 | move_focus(idx_plane);
65 |
66 | // Find the next command in input string
67 | command = strtok(0, ",");
68 |
69 | }
70 | }
71 |
72 | void move_focus(int idx_plane) {
73 | // Move focus state of lens with index of plane (values for each plane are predefined)
74 |
75 | if (idx_plane > 10) {
76 | Serial.print(" - wrong idx");
77 | return;
78 | }
79 |   //else if (idx_plane == 0){ // commenting this out and using relative moves is indeed more stable
80 | // Serial.print(" - move to infinity idx");
81 | // move_focus_infinity();
82 | //}
83 | else {
84 | // Below print cmds are for python
85 |
86 | Serial.print(" - from arduino: moving to the ");
87 | Serial.print(idx_plane);
88 | Serial.print("th plane");
89 | int offset = focus[idx_plane] - read_int_EEPROM(current_focus);
90 | Serial.print(offset);
91 | if (offset != 0){
92 | ///////////////////////////////////
93 | // This is what you send to lens //
94 | ///////////////////////////////////
95 | byte HH = highByte(offset);
96 | byte LL = lowByte(offset);
97 | init_lens();
98 | SPI.transfer(0x44); delay(10);
99 | SPI.transfer(HH); delay(10);
100 | SPI.transfer(LL); delay(10);
101 | write_int_EEPROM(current_focus, focus[idx_plane]);
102 | }
103 | }
104 | delay(100);
105 |
106 | }
107 |
108 |
109 | void move_focus_value(int value) {
110 | // Move focus state of lens with exact value
111 |
112 | int offset = value - read_int_EEPROM(current_focus);
113 | byte HH = highByte(offset);
114 | byte LL = lowByte(offset);
115 | init_lens();
116 | SPI.transfer(0x44); delay(10);
117 | SPI.transfer(HH); delay(10);
118 | SPI.transfer(LL); delay(10);
119 |
120 | write_int_EEPROM(current_focus, value);
121 | }
122 |
123 | void move_focus_infinity(){
124 | init_lens();
125 | SPI.transfer(0x05); delay(10);
126 |
127 | write_int_EEPROM(current_focus, focus[0]);
128 | }
129 |
130 | void write_int_EEPROM(int address, int number)
131 | {
132 | EEPROM.write(address, number >> 8);
133 | EEPROM.write(address + 1, number & 0xFF);
134 | }
135 |
136 | int read_int_EEPROM(int address)
137 | {
138 | return (EEPROM.read(address) << 8) + EEPROM.read(address + 1);
139 | }
140 |
--------------------------------------------------------------------------------
/hw/lens_control.py:
--------------------------------------------------------------------------------
1 | """
2 | Test code for Canon EF lens
3 |
4 | Contact: Suyeon Choi (suyeon@stanford.edu)
5 | -----
6 |
7 | $ python lens_control.py --index=$i
8 | """
9 |
10 | import serial
11 | import time
12 | 
13 | import random
14 | import configargparse
15 |
16 | # Command line argument processing
17 | p = configargparse.ArgumentParser()
18 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
19 | p.add_argument('--index', type=int, default=None, help='index of plane')
20 |
21 | opt = p.parse_args()
22 |
23 |
24 | t0 = time.perf_counter()
25 | arduino = serial.Serial('COM7', 9600, timeout=1.)
26 | time.sleep(0.1)  # give the connection a moment to settle
27 | print(f' -- connection is established.. {time.perf_counter() - t0}')
28 |
29 |
30 | if opt.index is not None:
31 | time.sleep(5)
32 | print(f' -- writing...{opt.index}')
33 | arduino.write(f'{opt.index}'.encode())
34 |
35 | time.sleep(5)
36 |
37 | data = arduino.readline().decode('UTF-8') # hear back from your arduino
38 | if data:
39 | print(data)
40 | else:
41 | while True:
42 | my_input = input()
43 | arduino.write(f'{str(my_input)}'.encode())
44 | time.sleep(1.0)
45 |
46 | data = arduino.readline().decode('UTF-8')
47 | if data:
48 | print(data)
49 |
--------------------------------------------------------------------------------
/hw/slm_display_module.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the script containing the SLM display module, which shows phase patterns on the SLM via the HOLOEYE SLM Display SDK.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 | Technical Paper:
10 | S. Choi, M. Gopakumar, Y. Peng, J. Kim, G. Wetzstein. Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays. ACM TOG (SIGGRAPH Asia), 2021.
11 | """
12 |
13 | import hw.detect_heds_module_path
14 | import holoeye
15 | from holoeye import slmdisplaysdk
16 |
17 |
18 | class SLMDisplay:
19 | ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode
20 | ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags
21 | State = slmdisplaysdk.SLMDisplay.State
22 | ApplyDataHandleValue = slmdisplaysdk.SLMDisplay.ApplyDataHandleValue
23 |
24 | def __init__(self):
25 | self.ErrorCode = slmdisplaysdk.SLMDisplay.ErrorCode
26 | self.ShowFlags = slmdisplaysdk.SLMDisplay.ShowFlags
27 |
28 | self.displayOptions = self.ShowFlags.PresentAutomatic # PresentAutomatic == 0 (default)
29 | self.displayOptions |= self.ShowFlags.PresentFitWithBars
30 |
31 | def connect(self):
32 | self.slm_device = slmdisplaysdk.SLMDisplay()
33 | self.slm_device.open() # For version 2.0.1
34 |
35 | def disconnect(self):
36 | self.slm_device.release()
37 |
38 | def show_data_from_file(self, filepath):
39 | error = self.slm_device.showDataFromFile(filepath, self.displayOptions)
40 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error)
41 |
42 | def show_data_from_array(self, numpy_array):
43 | error = self.slm_device.showData(numpy_array)
44 | assert error == self.ErrorCode.NoError, self.slm_device.errorString(error)
45 |
--------------------------------------------------------------------------------
/image_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import skimage.io
4 | from imageio import imread
5 | from skimage.transform import resize
6 | from torchvision.transforms.functional import resize as resize_tensor
7 | import cv2
8 | import random
9 | import json
10 | import numpy as np
11 | import h5py
12 | import torch
13 | import utils
14 |
15 | # Check for endianness, based on Daniel Scharstein's optical flow code.
16 | # Using little-endian architecture, these two should be equal.
17 | TAG_FLOAT = 202021.25
18 | TAG_CHAR = 'PIEH'
19 |
20 | def depth_read(filename):
21 | """ Read depth data from file, return as numpy array. """
22 | f = open(filename, 'rb')
23 | check = np.fromfile(f, dtype=np.float32, count=1)[0]
24 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(
25 | TAG_FLOAT, check)
26 | width = np.fromfile(f, dtype=np.int32, count=1)[0]
27 | height = np.fromfile(f, dtype=np.int32, count=1)[0]
28 | size = width * height
29 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(
30 | width, height)
31 | depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
32 | return depth
33 |
34 |
35 | def get_matlab_filenames(dir, focuses=None):
36 |     """Returns all files in the input directory dir that are MATLAB .mat files"""
37 |     image_types = ('mat',)
38 | if isinstance(dir, str):
39 | files = os.listdir(dir)
40 | exts = (os.path.splitext(f)[1] for f in files)
41 | if focuses is not None:
42 | images = [os.path.join(dir, f)
43 | for e, f in zip(exts, files)
44 | if e[1:] in image_types and int(os.path.splitext(f)[0].split('_')[-1]) in focuses]
45 | else:
46 | images = [os.path.join(dir, f)
47 | for e, f in zip(exts, files)
48 | if e[1:] in image_types]
49 | return images
50 | elif isinstance(dir, list):
51 |         # Support multiple directories (randomly shuffle all)
52 | images = []
53 | for folder in dir:
54 | files = os.listdir(folder)
55 | exts = (os.path.splitext(f)[1] for f in files)
56 | images_in_folder = [os.path.join(folder, f)
57 | for e, f in zip(exts, files)
58 | if e[1:] in image_types]
59 | images = [*images, *images_in_folder]
60 |
61 | return images
62 |
63 |
64 | def get_image_filenames(dir, focuses=None):
65 | """Returns all files in the input directory dir that are images"""
66 | image_types = ('jpg', 'jpeg', 'tiff', 'tif', 'png', 'bmp', 'gif', 'exr', 'dpt', 'hdf5')
67 | if isinstance(dir, str):
68 | files = os.listdir(dir)
69 | exts = (os.path.splitext(f)[1] for f in files)
70 | if focuses is not None:
71 | images = [os.path.join(dir, f)
72 | for e, f in zip(exts, files)
73 | if e[1:] in image_types and int(os.path.splitext(f)[0].split('_')[-1]) in focuses]
74 | else:
75 | images = [os.path.join(dir, f)
76 | for e, f in zip(exts, files)
77 | if e[1:] in image_types]
78 | return images
79 | elif isinstance(dir, list):
80 |         # Support multiple directories (randomly shuffle all)
81 | images = []
82 | for folder in dir:
83 | files = os.listdir(folder)
84 | exts = (os.path.splitext(f)[1] for f in files)
85 | images_in_folder = [os.path.join(folder, f)
86 | for e, f in zip(exts, files)
87 | if e[1:] in image_types]
88 | images = [*images, *images_in_folder]
89 |
90 | return images
91 |
92 |
93 | def resize_keep_aspect(image, target_res, pad=False, lf=False, pytorch=False):
94 | """Resizes image to the target_res while keeping aspect ratio by cropping
95 |
96 |     image: a 3d array with dims [channel, height, width]
97 | target_res: [height, width]
98 | pad: if True, will pad zeros instead of cropping to preserve aspect ratio
99 | """
100 | im_res = image.shape[-2:]
101 |
102 | # finds the resolution needed for either dimension to have the target aspect
103 | # ratio, when the other is kept constant. If the image doesn't have the
104 | # target ratio, then one of these two will be larger, and the other smaller,
105 | # than the current image dimensions
106 | resized_res = (int(np.ceil(im_res[1] * target_res[0] / target_res[1])),
107 | int(np.ceil(im_res[0] * target_res[1] / target_res[0])))
108 |
109 | # only pads smaller or crops larger dims, meaning that the resulting image
110 | # size will be the target aspect ratio after a single pad/crop to the
111 | # resized_res dimensions
112 | if pad:
113 | image = utils.pad_image(image, resized_res, pytorch=False)
114 | else:
115 | image = utils.crop_image(image, resized_res, pytorch=False, lf=lf)
116 |
117 | # switch to numpy channel dim convention, resize, switch back
118 | if lf or pytorch:
119 | image = resize_tensor(image, target_res)
120 | return image
121 | else:
122 | image = np.transpose(image, axes=(1, 2, 0))
123 | image = resize(image, target_res, mode='reflect')
124 | return np.transpose(image, axes=(2, 0, 1))
125 |
126 |
127 | def pad_crop_to_res(image, target_res, pytorch=False):
128 | """Pads with 0 and crops as needed to force image to be target_res
129 |
130 | image: an array with dims [..., channel, height, width]
131 | target_res: [height, width]
132 | """
133 | return utils.crop_image(utils.pad_image(image,
134 | target_res, pytorch=pytorch, stacked_complex=False),
135 | target_res, pytorch=pytorch, stacked_complex=False)
136 |
137 |
138 | def get_folder_names(folder):
139 |     """Returns the names of all subdirectories of the input folder"""
140 | return [d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d))]
141 |
142 |
143 | class PairsLoader(torch.utils.data.IterableDataset):
144 | """Loads (phase, captured) tuples for forward model training
145 |
146 | Class initialization parameters
147 | -------------------------------
148 |
149 | :param data_path:
150 | :param plane_idxs:
151 | :param batch_size:
152 | :param image_res:
153 | :param shuffle:
154 | :param avg_energy_ratio:
155 | :param slm_type:
156 |
157 |
158 | """
159 |
160 | def __init__(self, data_path, plane_idxs=None, batch_size=1,
161 | image_res=(800, 1280), shuffle=True,
162 | avg_energy_ratio=None, slm_type='leto'):
163 | """
164 |
165 | """
166 | print(data_path)
167 | if isinstance(data_path, str):
168 | if not os.path.isdir(data_path):
169 | raise NotADirectoryError(f'Data folder: {data_path}')
170 | self.phase_path = os.path.join(data_path, 'phase')
171 | self.captured_path = os.path.join(data_path, 'captured')
172 | elif isinstance(data_path, list):
173 | self.phase_path = [os.path.join(path, 'phase') for path in data_path]
174 | self.captured_path = [os.path.join(path, 'captured') for path in data_path]
175 |
176 | self.all_plane_idxs = plane_idxs
177 | self.avg_energy_ratio = avg_energy_ratio
178 | self.batch_size = batch_size
179 | self.shuffle = shuffle
180 | self.image_res = image_res
181 | self.slm_type = slm_type.lower()
182 | self.im_names = get_image_filenames(self.phase_path)
183 | self.im_names.sort()
184 |
185 | # create list of image IDs with augmentation state
186 | self.order = ((i,) for i in range(len(self.im_names)))
187 | self.order = list(self.order)
188 |
189 | def __iter__(self):
190 | self.ind = 0
191 | if self.shuffle:
192 | random.shuffle(self.order)
193 | return self
194 |
195 | def __len__(self):
196 | return len(self.im_names)
197 |
198 | def __next__(self):
199 | if self.ind < len(self.order):
200 | phase_idx = self.order[self.ind]
201 |
202 | self.ind += 1
203 | return self.load_pair(phase_idx[0])
204 | else:
205 | raise StopIteration
206 |
207 | def load_pair(self, filenum):
208 | phase_path = self.im_names[filenum]
209 | captured_path = os.path.splitext(os.path.dirname(phase_path))[0]
210 | captured_path = os.path.splitext(os.path.dirname(captured_path))[0]
211 | captured_path = os.path.join(captured_path, 'captured')
212 |
213 | # load phase
214 | phase_im_enc = imread(phase_path)
215 | im = (1 - phase_im_enc / np.iinfo(np.uint8).max) * 2 * np.pi - np.pi
216 | phase_im = torch.tensor(im, dtype=torch.float32).unsqueeze(0)
217 |
218 | _, captured_filename = os.path.split(os.path.splitext(self.im_names[filenum])[0])
219 | idx = captured_filename.split('/')[-1]
220 |
221 | # load focal stack
222 | captured_amps = []
223 | for plane_idx in self.all_plane_idxs:
224 | captured_filename = os.path.join(captured_path, f'{idx}_{plane_idx}.png')
225 | captured_intensity = utils.im2float(skimage.io.imread(captured_filename))
226 | captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32)
227 | if self.avg_energy_ratio is not None:
228 | captured_intensity /= self.avg_energy_ratio[plane_idx] # energy compensation;
229 | captured_amp = torch.sqrt(captured_intensity)
230 | captured_amps.append(captured_amp)
231 | captured_amps = torch.stack(captured_amps, 0)
232 |
233 | return phase_im, captured_amps
234 |
235 |
236 | class TargetLoader(torch.utils.data.IterableDataset):
237 | """Loads target amp/mask tuples for phase optimization
238 |
239 | Class initialization parameters
240 | -------------------------------
241 |     :param data_path: directory (or list of directories) with target images ('rgb'/'depth' subfolders for rgbd)
242 |     :param target_type: '2d'/'rgb' for image targets, '2.5d'/'rgbd' for image + depth targets
243 |     :param channel: color channel to load (None for all three, 0/1/2 for red/green/blue)
244 |     :param image_res: resolution the targets are padded to (SLM resolution)
245 |     :param roi_res: resolution of the region of interest
246 |     :param crop_to_roi: if True, pad/crop to roi_res instead of resizing while keeping the aspect ratio
247 |     :param shuffle: if True, shuffle the targets every epoch
248 |     :param vertical_flips: if True, augment the data with vertical flips
249 |     :param horizontal_flips: if True, augment the data with horizontal flips
250 |     :param virtual_depth_planes: virtual depths (in diopters) used to decompose the depth map
251 |     :param scale_vd_range: if True, rescale depth maps to the [vd_min, vd_max] range
252 |
253 | """
254 |
255 | def __init__(self, data_path, target_type, channel=None,
256 | image_res=(800, 1280), roi_res=(700, 1190),
257 | crop_to_roi=False, shuffle=False,
258 | vertical_flips=False, horizontal_flips=False,
259 | physical_depth_planes=None,
260 | virtual_depth_planes=None, scale_vd_range=True,
261 | mod_i=None, mod=None, options=None):
262 | """ initialization """
263 | if isinstance(data_path, str) and not os.path.isdir(data_path):
264 | raise NotADirectoryError(f'Data folder: {data_path}')
265 |
266 | self.data_path = data_path
267 | self.target_type = target_type.lower()
268 | self.channel = channel
269 | self.roi_res = roi_res
270 | self.crop_to_roi = crop_to_roi
271 | self.image_res = image_res
272 | self.shuffle = shuffle
273 | self.physical_depth_planes = physical_depth_planes
274 | self.virtual_depth_planes = virtual_depth_planes
275 | self.vd_min = 0.01
276 | self.vd_max = max(self.virtual_depth_planes)
277 | self.scale_vd_range = scale_vd_range
278 | self.options = options
279 |
280 | self.augmentations = []
281 | if vertical_flips:
282 | self.augmentations.append(self.augment_vert)
283 | if horizontal_flips:
284 | self.augmentations.append(self.augment_horz)
285 |
286 | # store the possible states for enumerating augmentations
287 | self.augmentation_states = [fn() for fn in self.augmentations]
288 |
289 | if target_type in ('2d', 'rgb'):
290 | self.im_names = get_image_filenames(self.data_path)
291 | self.im_names.sort()
292 | elif target_type in ('2.5d', 'rgbd'):
293 | self.im_names = get_image_filenames(os.path.join(self.data_path, 'rgb'))
294 | self.depth_names = get_image_filenames(os.path.join(self.data_path, 'depth'))
295 |
296 | self.im_names.sort()
297 | self.depth_names.sort()
298 |
299 | # create list of image IDs with augmentation state
300 | self.order = ((i,) for i in range(len(self.im_names)))
301 | for aug_type in self.augmentations:
302 | states = aug_type() # empty call gets possible states
303 | # augment existing list with new entry to states tuple
304 | self.order = ((*prev_states, s)
305 | for prev_states in self.order
306 | for s in states)
307 | self.order = list(self.order)
308 |
309 | if mod_i is not None:
310 | new_order = []
311 | for m, o in enumerate(self.order):
312 | if m % mod == mod_i:
313 | new_order.append(o)
314 | self.order = new_order
315 |
316 | def __iter__(self):
317 | self.ind = 0
318 | if self.shuffle:
319 | random.shuffle(self.order)
320 | return self
321 |
322 | def __len__(self):
323 | return len(self.order)
324 |
325 | def __next__(self):
326 | if self.ind < len(self.order):
327 | img_idx = self.order[self.ind]
328 |
329 | self.ind += 1
330 | if self.target_type in ('2d', 'rgb'):
331 | return self.load_image(*img_idx)
332 | if self.target_type in ('2.5d', 'rgbd'):
333 | return self.load_image_mask(*img_idx)
334 | else:
335 | raise StopIteration
336 |
337 | def load_image(self, filenum, *augmentation_states):
338 | im = imread(self.im_names[filenum])
339 |
340 | if len(im.shape) < 3:
341 | im = np.repeat(im[:, :, np.newaxis], 3, axis=2) # augment channels for gray images
342 |
343 | if self.channel is None:
344 | im = im[..., :3] # remove alpha channel, if any
345 | else:
346 | # select channel while keeping dims
347 | im = im[..., self.channel, np.newaxis]
348 |
349 | im = utils.im2float(im, dtype=np.float64) # convert to double, max 1
350 |
351 | # linearize intensity and convert to amplitude
352 | im = utils.srgb_gamma2lin(im)
353 | im = np.sqrt(im) # to amplitude
354 |
355 | # move channel dim to torch convention
356 | im = np.transpose(im, axes=(2, 0, 1))
357 |
358 | # apply data augmentation
359 | for fn, state in zip(self.augmentations, augmentation_states):
360 | im = fn(im, state)
361 |
362 | # normalize resolution
363 | if self.crop_to_roi:
364 | im = pad_crop_to_res(im, self.roi_res)
365 | else:
366 | im = resize_keep_aspect(im, self.roi_res)
367 | im = pad_crop_to_res(im, self.image_res)
368 |
369 | path = os.path.splitext(self.im_names[filenum])[0]
370 |
371 | return (torch.from_numpy(im).float(),
372 | None,
373 | os.path.split(path)[1].split('_')[-1])
374 |
375 | def load_depth(self, filenum, *augmentation_states):
376 | depth_path = self.depth_names[filenum]
377 | if 'exr' in depth_path:
378 | depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
379 | elif 'dpt' in depth_path:
380 | dist = depth_read(depth_path)
381 |             depth = np.nan_to_num(dist, nan=100)  # replace NaNs with a large distance (100 m)
382 | elif 'hdf5' in depth_path:
383 | # Depth (in m)
384 | with h5py.File(depth_path, 'r') as f:
385 | dist = np.array(f['dataset'][:], dtype=np.float32)
386 |                 depth = np.nan_to_num(dist, nan=100)  # replace NaNs with a large distance (100 m)
387 | else:
388 | depth = imread(depth_path)
389 |
390 | depth = utils.im2float(depth, dtype=np.float64) # convert to double, max 1
391 |
392 | if 'bbb' in depth_path:
393 | depth *= 6 # this gives us decent depth distribution with 120mm eyepiece setting.
394 | elif 'sintel' in depth_path or 'dpt' in depth_path:
395 | depth /= 2.5 # this gives us decent depth distribution with 120mm eyepiece setting.
396 | if len(depth.shape) > 2 and depth.shape[-1] > 1:
397 | depth = depth[..., 1]
398 |
399 |         if 'eth' not in depth_path.lower():
400 | depth = 1 / (depth + 1e-20) # meter to diopter conversion
401 |
402 | # apply data augmentation
403 | for fn, state in zip(self.augmentations, augmentation_states):
404 | depth = fn(depth, *state)
405 |
406 | depth = torch.from_numpy(depth.copy()).float().unsqueeze(0)
407 | # normalize resolution
408 | depth.unsqueeze_(0)
409 | if self.crop_to_roi:
410 | depth = pad_crop_to_res(depth, self.roi_res, pytorch=True)
411 | else:
412 | depth = resize_keep_aspect(depth, self.roi_res, pytorch=True)
413 | depth = pad_crop_to_res(depth, self.image_res, pytorch=True)
414 |
415 |         # rescale to the virtual depth (diopter) range [vd_min, vd_max]
416 | if self.scale_vd_range:
417 | depth = depth - depth.min()
418 | depth = (depth / depth.max()) * (self.vd_max - self.vd_min)
419 | depth = depth + self.vd_min
420 |
421 | # check nans
422 | if (depth.isnan().any()):
423 | print("Found Nans in target depth!")
424 | min_substitute = self.vd_min * torch.ones_like(depth)
425 | depth = torch.where(depth.isnan(), min_substitute, depth)
426 |
427 | path = os.path.splitext(self.depth_names[filenum])[0]
428 |
429 | return (depth.float(),
430 | None,
431 | os.path.split(path)[1].split('_')[-1])
432 |
433 | def load_image_mask(self, filenum, *augmentation_states):
434 | img_none_idx = self.load_image(filenum, *augmentation_states)
435 | depth_none_idx = self.load_depth(filenum, *augmentation_states)
436 | mask = utils.decompose_depthmap(depth_none_idx[0], self.virtual_depth_planes)
437 | return (img_none_idx[0].unsqueeze(0), mask, img_none_idx[-1])
438 |
439 | def augment_vert(self, image=None, flip=False):
440 | """ augment data with vertical flip """
441 | if image is None:
442 | return (True, False) # return possible augmentation values
443 |
444 | if flip:
445 | return image[..., ::-1, :]
446 | return image
447 |
448 | def augment_horz(self, image=None, flip=False):
449 | """ augment data with horizontal flip """
450 | if image is None:
451 | return (True, False) # return possible augmentation values
452 |
453 | if flip:
454 | return image[..., ::-1]
455 | return image
456 |
457 |
--------------------------------------------------------------------------------
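
A minimal usage sketch for the two loaders in image_loader.py above. The folder names, plane indices, and depth values are illustrative placeholders rather than shipped defaults: PairsLoader expects 'phase' and 'captured' subfolders, and TargetLoader expects 'rgb' and 'depth' subfolders for rgbd targets.

```
# Sketch only: paths, plane indices, and depths below are hypothetical examples.
import image_loader as loaders

# (SLM phase, captured focal stack) pairs for forward-model training
pairs = loaders.PairsLoader('./data/train_green',      # folder with 'phase' and 'captured'
                            plane_idxs=range(8),       # captured images named '{id}_{plane}.png'
                            image_res=(800, 1280), shuffle=True)
slm_phase, captured_amps = next(iter(pairs))           # [1, H, W] phase, [num_planes, H, W] amplitudes

# RGBD targets for phase optimization: (amplitude, per-plane mask, image id)
targets = loaders.TargetLoader('./data/rgbd_scenes', 'rgbd', channel=1,
                               image_res=(1080, 1920), roi_res=(960, 1680),
                               virtual_depth_planes=[0.0, 0.5, 1.0, 1.5])
target_amp, target_mask, target_idx = next(iter(targets))
```
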
/img/citl-asm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/citl-asm.png
--------------------------------------------------------------------------------
/img/sgd-asm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/sgd-asm.png
--------------------------------------------------------------------------------
/img/sgd-ours.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/sgd-ours.png
--------------------------------------------------------------------------------
/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-imaging/neural-3d-holography/db52af35f83823693805a08c33c06864f40b2e02/img/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Neural 3D holography: Learning accurate wave propagation models for 3D holographic virtual and augmented reality displays
3 |
4 | Suyeon Choi*, Manu Gopakumar*, Yifan Peng, Jonghyun Kim, Gordon Wetzstein
5 |
6 | This is the main executive script used for the phase generation using SGD.
7 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
8 | # The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
9 | # The material is provided as-is, with no warranties whatsoever.
10 | # If you publish any code, data, or scientific work based on this, please cite our work.
11 | -----
12 |
13 | $ python main.py --lr=0.01 --num_iters=10000
14 |
15 | """
16 | import algorithms as algs
17 | import image_loader as loaders
18 | import numpy as np
19 |
20 | import os
21 | import torch
22 | from torch.utils.tensorboard import SummaryWriter
23 | import imageio
24 | import configargparse
25 | import prop_physical
26 | import prop_model
27 | import utils
28 | import params
29 |
30 |
31 | def main():
32 | # Command line argument processing / Parameters
33 | torch.set_default_dtype(torch.float32)
34 | p = configargparse.ArgumentParser()
35 | p.add('-c', '--config_filepath', required=False,
36 | is_config_file=True, help='Path to config file.')
37 | params.add_parameters(p, 'eval')
38 | opt = p.parse_args()
39 | params.set_configs(opt)
40 | dev = torch.device('cuda')
41 |
42 | run_id = params.run_id(opt)
43 | # path to save out optimized phases
44 | out_path = os.path.join(opt.out_path, run_id)
45 | print(f' - out_path: {out_path}')
46 |
47 | # Tensorboard
48 | summaries_dir = os.path.join(out_path, 'summaries')
49 | utils.cond_mkdir(summaries_dir)
50 | writer = SummaryWriter(summaries_dir)
51 |
52 | # Propagations
53 | camera_prop = None
54 | if opt.citl:
55 | camera_prop = prop_physical.PhysicalProp(*(params.hw_params(opt))).to(dev)
56 | sim_prop = prop_model.model(opt)
57 |
58 | # Algorithm
59 | algorithm = algs.load_alg(opt.method)
60 |
61 | # Loader
62 | if ',' in opt.data_path:
63 | opt.data_path = opt.data_path.split(',')
64 | img_loader = loaders.TargetLoader(opt.data_path, opt.target, channel=opt.channel,
65 | image_res=opt.image_res, roi_res=opt.roi_res,
66 | crop_to_roi=False, shuffle=opt.random_gen,
67 | vertical_flips=opt.random_gen, horizontal_flips=opt.random_gen,
68 | physical_depth_planes=opt.physical_depth_planes,
69 | virtual_depth_planes=opt.virtual_depth_planes,
70 | scale_vd_range=False,
71 | mod_i=opt.mod_i, mod=opt.mod, options=opt)
72 |
73 | for i, target in enumerate(img_loader):
74 | target_amp, target_mask, target_idx = target
75 | target_amp = target_amp.to(dev).detach()
76 | if target_mask is not None:
77 | target_mask = target_mask.to(dev).detach()
78 | if len(target_amp.shape) < 4:
79 | target_amp = target_amp.unsqueeze(0)
80 |
81 | print(f' - run phase optimization for {target_idx}th image ...')
82 | if opt.random_gen: # random parameters for dataset generation
83 | opt.num_iters, opt.init_phase_range, \
84 | target_range, opt.lr, opt.eval_plane_idx = utils.random_gen(num_planes=opt.num_planes,
85 | slm_type=opt.slm_type)
86 | sim_prop = prop_model.model(opt)
87 | target_amp *= target_range
88 |
89 | # initial slm phase
90 | init_phase = (opt.init_phase_range * (-0.5 + 1.0 * torch.rand(1, 1, *opt.slm_res))).to(dev)
91 |
92 | # run algorithm
93 | results = algorithm(init_phase, target_amp, target_mask,
94 | forward_prop=sim_prop, num_iters=opt.num_iters, roi_res=opt.roi_res,
95 | loss_fn=opt.loss_fn, lr=opt.lr,
96 | out_path_idx=f'{opt.out_path}_{target_idx}',
97 | citl=opt.citl, camera_prop=camera_prop,
98 | writer=writer,
99 | )
100 |
101 | # optimized slm phase
102 | final_phase = results['final_phase']
103 |
104 | # encoding for SLM & save it out
105 | phase_out = utils.phasemap_8bit(final_phase)
106 | if opt.random_gen:
107 | phase_out_path = os.path.join(out_path, f'{target_idx}_{opt.num_iters}.png')
108 | else:
109 | phase_out_path = os.path.join(out_path, f'{target_idx}_{opt.eval_plane_idx}.png')
110 | imageio.imwrite(phase_out_path, phase_out)
111 |
112 | if __name__ == "__main__":
113 | main()
114 |
--------------------------------------------------------------------------------
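
For reference, two example invocations of main.py using flags defined in params.py. The data paths follow the repository layout; the checkpoint path is a placeholder.

```
# 2D green-channel target, ideal ASM propagation, evaluated at a single plane
python main.py --channel=1 --method=SGD --prop_model=ASM --target=rgb \
    --data_path=./data/test/rgb --eval_plane_idx=4 --num_iters=1000

# RGBD target with a trained CNNpropCNN model (checkpoint path is hypothetical)
python main.py --channel=1 --prop_model=NH3D --prop_model_path=./checkpoints/green.ckpt \
    --target=rgbd --data_path=./data/test
```
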
/params.py:
--------------------------------------------------------------------------------
1 | """
2 | Default parameter settings for SLMs as well as laser/sensors
3 |
4 | """
5 | import datetime
6 | import math
7 | import sys
8 | import numpy as np
9 | import utils
10 | import torch.nn as nn
11 | if sys.platform == 'win32':
12 | import serial
13 |
14 | cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9
15 |
16 | def str2bool(v):
17 | """ Simple query parser for configArgParse (which doesn't support native bool from cmd)
18 | Ref: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
19 |
20 | """
21 | if isinstance(v, bool):
22 | return v
23 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
24 | return True
25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
26 | return False
27 | else:
28 | raise ValueError('Boolean value expected.')
29 |
30 |
31 | class PMap(dict):
32 | # use it for parameters
33 | __getattr__ = dict.get
34 | __setattr__ = dict.__setitem__
35 | __delattr__ = dict.__delitem__
36 |
37 |
38 | def add_parameters(p, mode='train'):
39 | p.add_argument('--channel', type=int, default=None, help='Red:0, green:1, blue:2')
40 | p.add_argument('--method', type=str, default='SGD', help='Type of algorithm, GS/SGD/DPAC/HOLONET/UNET')
41 | p.add_argument('--slm_type', type=str, default='leto', help='leto/pluto ...')
42 | p.add_argument('--sensor_type', type=str, default='2k', help='sensor type')
43 | p.add_argument('--laser_type', type=str, default='new', help='old, new_laser, sLED, ...')
44 | p.add_argument('--setup_type', type=str, default='sigasia2021_vr', help='VR or AR')
45 | p.add_argument('--prop_model', type=str, default='ASM', help='Type of propagation model, ASM/NH/NH3D')
46 | p.add_argument('--out_path', type=str, default='./results',
47 | help='Directory for output')
48 | p.add_argument('--citl', type=str2bool, default=False,
49 | help='If True, run camera-in-the-loop')
50 |     p.add_argument('--mod_i', type=int, default=None,
51 |                    help='If not None, keep only target images whose index i satisfies i mod K == mod_i, with K given by --mod')
52 |     p.add_argument('--mod', type=int, default=None,
53 |                    help='Modulus K used together with --mod_i to subsample the target loader')
54 | p.add_argument('--data_path', type=str, default='/mount/workspace/data/NH3D/',
55 | help='Directory for input')
56 | p.add_argument('--exp', type=str, default='', help='Name of experiment')
57 | p.add_argument('--lr', type=float, default=0.02, help='Learning rate')
58 | p.add_argument('--num_iters', type=int, default=1000, help='Number of iterations (GS, SGD)')
59 | p.add_argument('--prop_dist', type=float, default=None, help='propagation distance from SLM to midplane')
60 | p.add_argument('--F_aperture', type=float, default=1.0, help='Fourier filter size')
61 | p.add_argument('--eyepiece', type=float, default=0.12, help='eyepiece focal length')
62 | p.add_argument('--full_roi', type=str2bool, default=False,
63 | help='If True, force ROI to SLM resolution')
64 | p.add_argument('--target', type=str, default='rgbd',
65 |                    help='Type of target: {2d, rgb} or {2.5d, rgbd}')
66 | p.add_argument('--show_preview', type=str2bool, default=False,
67 | help='If true, show the preview for homography calibration')
68 | p.add_argument('--random_gen', type=str2bool, default=False,
69 | help='If true, randomize a few parameters for phase dataset generation')
70 | p.add_argument('--eval_plane_idx', type=int, default=None,
71 | help='When evaluating 2d, choose the index of the plane')
72 | p.add_argument('--init_phase_range', type=float, default=1.0,
73 |                    help='Phase sampling range for initialization')
74 |
75 | if mode in ('train', 'eval'):
76 | p.add_argument('--num_epochs', type=int, default=350, help='')
77 | p.add_argument('--batch_size', type=int, default=1, help='')
78 | p.add_argument('--prop_model_path', type=str, default=None, help='Path to checkpoints')
79 |         p.add_argument('--predefined_model', type=str, default=None, help='string for predefined model: '
80 | 'nh, nh3d')
81 | p.add_argument('--num_downs_slm', type=int, default=5, help='')
82 | p.add_argument('--num_feats_slm_min', type=int, default=32, help='')
83 | p.add_argument('--num_feats_slm_max', type=int, default=128, help='')
84 | p.add_argument('--num_downs_target', type=int, default=5, help='')
85 | p.add_argument('--num_feats_target_min', type=int, default=32, help='')
86 | p.add_argument('--num_feats_target_max', type=int, default=128, help='')
87 |         p.add_argument('--slm_coord', type=str, default='rect', help='coordinates to represent a complex-valued field: '
88 |                                                                   'rect(real+imag) or polar(amp+phase)')
89 |         p.add_argument('--target_coord', type=str, default='rect', help='coordinates to represent a complex-valued field: '
90 |                                                                     'rect(real+imag) or polar(amp+phase)')
91 | p.add_argument('--norm', type=str, default='instance', help='normalization layer')
92 | p.add_argument('--loss_func', type=str, default='l1', help='l1 or l2')
93 | p.add_argument('--energy_compensation', type=str2bool, default=True, help='adjust intensities '
94 | 'with avg intensity of training set')
95 | p.add_argument('--num_train_planes', type=int, default=7, help='number of planes fed to models')
96 |
97 | return p
98 |
99 |
100 | def set_configs(opt):
101 | """
102 | set or replace parameters with pre-defined parameters with string inputs
103 | """
104 |
105 | # hardware setup
106 | optics_config(opt.setup_type, opt) # prop_dist, etc ...
107 | laser_config(opt.laser_type, opt) # Our Old FISBA Laser, New, SLED, LED, ...
108 | slm_config(opt.slm_type, opt) # Holoeye
109 | sensor_config(opt.sensor_type, opt) # our sensor
110 |
111 | # set predefined model parameters
112 | forward_model_config(opt.prop_model, opt)
113 |
114 | # wavelength, propagation distance (from SLM to midplane)
115 | if opt.channel is None:
116 | opt.chan_str = 'rgb'
117 | opt.prop_dist = opt.prop_dists_rgb
118 | opt.wavelength = opt.wavelengths
119 | else:
120 | opt.chan_str = ('red', 'green', 'blue')[opt.channel]
121 | if opt.prop_dist is None:
122 | opt.prop_dist = opt.prop_dists_rgb[opt.channel][opt.mid_idx] # prop dist from SLM plane to target plane
123 | opt.prop_dist_green = opt.prop_dists_rgb[opt.channel][1]
124 | else:
125 | opt.prop_dist_green = opt.prop_dist
126 | opt.wavelength = opt.wavelengths[opt.channel] # wavelength of each color
127 |
128 | # propagation distances from the wavefront recording plane
129 | opt.prop_dists_from_wrp = [p - opt.prop_dist for p in opt.prop_dists_rgb[opt.channel]]
130 | opt.physical_depth_planes = [p - opt.prop_dist_green for p in opt.prop_dists_physical]
131 | opt.virtual_depth_planes = utils.prop_dist_to_diopter(opt.physical_depth_planes,
132 | opt.eyepiece,
133 | opt.physical_depth_planes[0])
134 | opt.num_planes = len(opt.prop_dists_from_wrp)
135 | opt.all_plane_idxs = range(opt.num_planes)
136 |
137 | # force ROI to that of SLM
138 | if opt.full_roi:
139 | opt.roi_res = opt.slm_res
140 |
141 | ################
142 | # Model Training
143 | # compensate the brightness difference per plane (for model training)
144 | if opt.energy_compensation:
145 | opt.avg_energy_ratio = opt.avg_energy_ratio_rgb[opt.channel]
146 | else:
147 | opt.avg_energy_ratio = None
148 |
149 | # loss functions (for model training)
150 | opt.loss_train = None
151 | opt.loss_fn = None
152 | if opt.loss_func.lower() in ('l2', 'mse'):
153 | opt.loss_train = nn.functional.mse_loss
154 | opt.loss_fn = nn.functional.mse_loss
155 | elif opt.loss_func.lower() == 'l1':
156 | opt.loss_train = nn.functional.l1_loss
157 | opt.loss_fn = nn.functional.l1_loss
158 |
159 | # plane idxs (for model training)
160 | opt.plane_idxs = {}
161 | opt.plane_idxs['all'] = opt.all_plane_idxs
162 | opt.plane_idxs['train'] = opt.training_plane_idxs
163 | opt.plane_idxs['validation'] = opt.training_plane_idxs
164 | opt.plane_idxs['test'] = opt.training_plane_idxs
165 | opt.plane_idxs['heldout'] = opt.heldout_plane_idxs
166 |
167 | admm_opt = None
168 | if 'admm' in opt.method:
169 | admm_opt = {'num_iters_inner': 50,
170 | 'rho': 0.01,
171 | 'alpha': 1.0,
172 | 'gamma': 0.1,
173 | 'varying-penalty': True,
174 | 'mu': 10.0,
175 | 'tau_incr': 2.0,
176 | 'tau_decr': 2.0}
177 |
178 | return opt
179 |
180 |
181 | def run_id(opt):
182 | id_str = f'{opt.exp}_{opt.chan_str}_{opt.prop_model}_{opt.lr}_{opt.num_iters}'
183 | id_str = f'{opt.chan_str}'
184 | return id_str
185 |
186 | def run_id_training(opt):
187 | id_str = f'{opt.exp}_{opt.chan_str}-' \
188 | f'slm{opt.num_downs_slm}-{opt.num_feats_slm_min}-{opt.num_feats_slm_max}_' \
189 | f'tg{opt.num_downs_target}-{opt.num_feats_target_min}-{opt.num_feats_target_max}_' \
190 | f'{opt.slm_coord}{opt.target_coord}_{opt.loss_func}_{opt.num_train_planes}pls_' \
191 | f'bs{opt.batch_size}'
192 | cur_time = datetime.datetime.now().strftime("%d-%H%M")
193 | id_str = f'{cur_time}_{id_str}'
194 |
195 | return id_str
196 |
197 |
198 | def hw_params(opt):
199 | """ Default setting for hardware. Please replace and adjust parameters for your own setup. """
200 | params_slm = PMap()
201 | params_slm.settle_time = 0.3
202 | params_slm.monitor_num = 2
203 | params_slm.slm_type = opt.slm_type
204 |
205 | params_camera = PMap()
206 | params_camera.img_size_native = (3000, 4096) # 4k sensor native
207 | params_camera.ser = None #serial.Serial('COM5', 9600, timeout=0.5)
208 |
209 | params_calib = PMap()
210 | params_calib.show_preview = opt.show_preview
211 | params_calib.range_y = slice(0, params_camera.img_size_native[0])
212 | params_calib.range_x = slice(0, params_camera.img_size_native[1])
213 | params_calib.num_circles = (13, 22)
214 | params_calib.spacing_size = [int(roi / (num_circs - 1))
215 | for roi, num_circs in zip(opt.roi_res, params_calib.num_circles)]
216 | params_calib.pad_pixels = [int(slm - roi) // 2 for slm, roi in zip(opt.slm_res, opt.roi_res)]
217 | params_calib.quadratic = True
218 | params_calib.phase_path = f'./calibration/{opt.chan_str}/1_{opt.eval_plane_idx}.png'
219 | params_calib.img_size_native = params_camera.img_size_native
220 |
221 | return params_slm, params_camera, params_calib
222 |
223 |
224 | def slm_config(slm_type, opt):
225 | """ Setting for specific SLM. """
226 |     if slm_type.lower() in ('leto',):
227 | opt.feature_size = (6.4 * um, 6.4 * um) # SLM pitch
228 | opt.slm_res = (1080, 1920) # resolution of SLM
229 | opt.image_res = opt.slm_res
230 |     elif slm_type.lower() in ('pluto',):
231 | opt.feature_size = (8.0 * um, 8.0 * um) # SLM pitch
232 | opt.slm_res = (1080, 1920) # resolution of SLM
233 | opt.image_res = opt.slm_res
234 |
235 |
236 | def laser_config(laser_type, opt):
237 | """ Setting for specific laser. """
238 | if 'new' in laser_type.lower():
239 | opt.wavelengths = (636.17 * nm, 518.48 * nm, 442.03 * nm) # wavelength of each color
240 | elif 'ar' in laser_type.lower():
241 | opt.wavelengths = (532 * nm, 532 * nm, 532 * nm)
242 | else:
243 | opt.wavelengths = (636.4 * nm, 517.7 * nm, 440.8 * nm)
244 |
245 |
246 | def sensor_config(sensor_type, opt):
247 | return opt
248 |
249 |
250 | def optics_config(setup_type, opt):
251 | """ Setting for specific setup (prop dists, filter, training plane index ...) """
252 |     if setup_type in ('sigasia2021_vr',):
253 | opt.laser_type = 'old'
254 | opt.slm_type = 'leto'
255 | opt.prop_dists_rgb = [[0.0, 1*mm, 2*mm, 3*mm, 4.4*mm, 5.7*mm, 7.0*mm, 8.1*mm],
256 | [0.0, 1.2*mm, 2.0*mm, 3.4*mm, 4.4*mm, 5.7*mm, 7.2*mm, 8.2*mm],
257 | [0.0, 1*mm, 2.1*mm, 3.2*mm, 4.3*mm, 5.5*mm, 7.2*mm, 8.2*mm]]
258 |
259 | opt.avg_energy_ratio_rgb = [[1.0000, 1.0407, 1.0870, 1.1216, 1.1568, 1.2091, 1.2589, 1.2924],
260 | [1.0000, 1.0409, 1.0869, 1.1226, 1.1540, 1.2107, 1.2602, 1.2958],
261 | [1.0000, 1.0409, 1.0869, 1.1226, 1.1540, 1.2107, 1.2602, 1.2958]]
262 | opt.prop_dists_physical = opt.prop_dists_rgb[1]
263 | opt.F_aperture = 0.5
264 | opt.roi_res = (960, 1680) # regions of interest (to penalize for SGD)
265 | opt.training_plane_idxs = [0, 1, 3, 4, 5, 6, 7]
266 | opt.heldout_plane_idxs = [2]
267 | opt.mid_idx = 4 # intermediate plane as 1.5D
268 |     elif setup_type in ('sigasia2021_ar',):
269 | opt.laser_type = 'ar_green_only'
270 | opt.slm_type = 'pluto'
271 | opt.prop_dists_rgb = [[9.9*mm, 10.3*mm, 11.8*mm, 13.3*mm],
272 | [9.9*mm, 10.3*mm, 11.8*mm, 13.3*mm],
273 | [9.9*mm, 10.3*mm, 11.8*mm, 13.3*mm]]
274 | opt.prop_dists_physical = opt.prop_dists_rgb[1]
275 | opt.F_aperture = 0.5
276 | opt.roi_res = (768, 1536) # regions of interest (to penalize for SGD)
277 | opt.training_plane_idxs = [0, 1, 2]
278 | opt.heldout_plane_idxs = []
279 |
280 |
281 | def forward_model_config(model_type, opt):
282 | # setting for specific model that is predefined.
283 | if model_type is not None:
284 | print(f' - changing model parameters for {model_type}')
285 | if model_type.lower() in ('cnnpropcnn', 'nh3d'):
286 | opt.num_downs_slm = 8
287 | opt.num_feats_slm_min = 32
288 | opt.num_feats_slm_max = 512
289 | opt.num_downs_target = 5
290 | opt.num_feats_target_min = 8
291 | opt.num_feats_target_max = 128
292 | elif model_type.lower() == 'hil':
293 | opt.num_downs_slm = 0
294 | opt.num_feats_slm_min = 0
295 | opt.num_feats_slm_max = 0
296 | opt.num_downs_target = 8
297 | opt.num_feats_target_min = 32
298 | opt.num_feats_target_max = 512
299 | opt.target_coord = 'amp'
300 | elif model_type.lower() == 'cnnprop':
301 | opt.num_downs_slm = 8
302 | opt.num_feats_slm_min = 32
303 | opt.num_feats_slm_max = 512
304 | opt.num_downs_target = 0
305 | opt.num_feats_target_min = 0
306 | opt.num_feats_target_max = 0
307 | elif model_type.lower() == 'propcnn':
308 | opt.num_downs_slm = 0
309 | opt.num_feats_slm_min = 0
310 | opt.num_feats_slm_max = 0
311 | opt.num_downs_target = 8
312 | opt.num_feats_target_min = 32
313 | opt.num_feats_target_max = 512
--------------------------------------------------------------------------------
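
A short sketch of how the parameter plumbing above is wired together, mirroring what main.py does; the derived values depend on the chosen setup and channel.

```
# Sketch only: builds the option namespace in-process instead of from the command line.
import configargparse
import params

p = configargparse.ArgumentParser()
params.add_parameters(p, 'eval')
opt = p.parse_args(['--channel=1', '--setup_type=sigasia2021_vr'])
params.set_configs(opt)

print(opt.chan_str)                # 'green'
print(opt.wavelength)              # wavelength picked for that channel by laser_config()
print(opt.prop_dists_from_wrp)     # per-plane distances measured from the intermediate (WRP) plane
print(opt.plane_idxs['train'],
      opt.plane_idxs['heldout'])   # training vs. held-out plane indices
```
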
/prop_ideal.py:
--------------------------------------------------------------------------------
1 | """
2 | Ideal propagation
3 |
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import utils
9 | import torch.fft as tfft
10 | import math
11 |
12 | class Propagation(nn.Module):
13 | """
14 | The ideal, convolution-based propagation implementation
15 |
16 | Class initialization parameters
17 | -------------------------------
18 | :param prop_dist: propagation distance(s)
19 | :param wavelength: wavelength
20 | :param feature_size: pixel pitch
21 | :param prop_type: type of propagation (ASM or fresnel), by default the angular spectrum method
22 |     :param F_aperture: filter size at the Fourier plane, by default 1.0
23 | :param dim: for propagation to multiple planes, dimension to stack the output, by default 1 (second dimension)
24 | :param linear_conv: If true, pad zeros to ensure the linear convolution, by default True
25 | :param learned_amp: Learned amplitude at Fourier plane, by default None
26 | :param learned_phase: Learned phase at Fourier plane, by default None
27 | """
28 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0,
29 | dim=1, linear_conv=True, learned_amp=None, learned_phase=None):
30 | super(Propagation, self).__init__()
31 |
32 | self.H = None # kernel at Fourier plane
33 | self.prop_type = prop_type
34 | if not isinstance(prop_dist, list):
35 | prop_dist = [prop_dist]
36 | self.prop_dist = prop_dist
37 | self.feature_size = feature_size
38 | if not isinstance(wavelength, list):
39 | wavelength = [wavelength]
40 | self.wvl = wavelength
41 | self.linear_conv = linear_conv # ensure linear convolution by padding
42 | self.bl_asm = min(prop_dist) > 0.3
43 | self.F_aperture = F_aperture
44 | self.dim = dim # The dimension to stack the kernels as well as the resulting fields (if multi-channel)
45 |
46 | self.preload_params = False
47 | self.preloaded_H_amp = False # preload H_mask once trained
48 | self.preloaded_H_phase = False # preload H_phase once trained
49 |
50 | self.fourier_amp = learned_amp
51 | self.fourier_phase = learned_phase
52 |
53 | def forward(self, u_in):
54 | if u_in.dtype == torch.float32:
55 | u_in = torch.exp(1j * u_in)
56 |
57 | if self.H is None:
58 | Hs = []
59 | if len(self.wvl) > 1: # If multi-channel, rearrange kernels
60 | for wv, prop_dist in zip(self.wvl, self.prop_dist):
61 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..')
62 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size,
63 | self.prop_type, self.linear_conv,
64 | F_aperture=self.F_aperture, bl_asm=self.bl_asm)
65 | Hs.append(h)
66 | self.H = torch.cat(Hs, dim=self.dim)
67 | else:
68 | for wv in self.wvl:
69 | for prop_dist in self.prop_dist:
70 | print(f' -- generating kernel for {wv*1e9:.1f}nm, {prop_dist*100:.2f}cm..')
71 | h = self.compute_H(torch.empty_like(u_in), prop_dist, wv, self.feature_size,
72 | self.prop_type, self.linear_conv,
73 | F_aperture=self.F_aperture, bl_asm=self.bl_asm)
74 | Hs.append(h)
75 | self.H = torch.cat(Hs, dim=1)
76 |
77 | if self.preload_params:
78 | self.premultiply()
79 |
80 | if self.fourier_amp is not None and not self.preloaded_H_amp:
81 | H = self.fourier_amp.clamp(min=0.) * self.H
82 | else:
83 | H = self.H
84 |
85 | if self.fourier_phase is not None and not self.preloaded_H_phase:
86 | H = H * torch.exp(1j * self.fourier_phase)
87 |
88 | return self.prop(u_in, H, self.linear_conv)
89 |
90 | def compute_H(self, input_field, prop_dist, wvl, feature_size, prop_type, lin_conv=True,
91 | return_exp=False, F_aperture=1.0, bl_asm=False, return_filter=False):
92 | dev = input_field.device
93 | res_mul = 2 if lin_conv else 1
94 | num_y, num_x = res_mul*input_field.shape[-2], res_mul*input_field.shape[-1] # number of pixels
95 |         dy, dx = feature_size  # sampling interval size
96 |
97 | # frequency coordinates sampling
98 | fy = torch.linspace(-1 / (2 * dy), 1 / (2 * dy), num_y)
99 | fx = torch.linspace(-1 / (2 * dx), 1 / (2 * dx), num_x)
100 |
101 | # momentum/reciprocal space
102 | # FY, FX = torch.meshgrid(fy, fx)
103 | FX, FY = torch.meshgrid(fx, fy)
104 | FX = torch.transpose(FX, 0, 1)
105 | FY = torch.transpose(FY, 0, 1)
106 |
107 | if prop_type.lower() == 'asm':
108 | G = 2 * math.pi * (1 / wvl**2 - (FX ** 2 + FY ** 2)).sqrt()
109 | elif prop_type.lower() == 'fresnel':
110 | G = math.pi * wvl * (FX ** 2 + FY ** 2)
111 |
112 | H_exp = G.reshape((1, 1, *G.shape)).to(dev)
113 |
114 | if return_exp:
115 | return H_exp
116 |
117 | if bl_asm:
118 | fy_max = 1 / math.sqrt((2 * prop_dist * (1 / (dy * float(num_y))))**2 + 1) / wvl
119 | fx_max = 1 / math.sqrt((2 * prop_dist * (1 / (dx * float(num_x))))**2 + 1) / wvl
120 |
121 | H_filter = ((torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max())
122 | & (torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).type(torch.FloatTensor)
123 | else:
124 | H_filter = (torch.abs(FX**2 + FY**2) <= (F_aperture**2) * torch.abs(FX**2 + FY**2).max()).type(torch.FloatTensor)
125 |
126 | if prop_dist == 0.:
127 | H = torch.ones_like(H_exp)
128 | else:
129 | H = H_filter.to(input_field.device) * torch.exp(1j * H_exp * prop_dist)
130 |
131 | if return_filter:
132 | return H_filter
133 | else:
134 | return H
135 |
136 | def prop(self, u_in, H, linear_conv=True, padtype='zero'):
137 | if linear_conv:
138 | # preprocess with padding for linear conv.
139 | input_resolution = u_in.size()[-2:]
140 | conv_size = [i * 2 for i in input_resolution]
141 | if padtype == 'zero':
142 | padval = 0
143 | elif padtype == 'median':
144 | padval = torch.median(torch.pow((u_in ** 2).sum(-1), 0.5))
145 | u_in = utils.pad_image(u_in, conv_size, padval=padval, stacked_complex=False)
146 |
147 | U1 = tfft.fftshift(tfft.fftn(u_in, dim=(-2, -1), norm='ortho'), (-2, -1))
148 | U2 = U1 * H
149 | u_out = tfft.ifftn(tfft.ifftshift(U2, (-2, -1)), dim=(-2, -1), norm='ortho')
150 |
151 | if linear_conv:
152 | u_out = utils.crop_image(u_out, input_resolution, pytorch=True, stacked_complex=False)
153 |
154 | return u_out
155 |
156 | def __len__(self):
157 | return len(self.prop_dist)
158 |
159 | def preload_H(self):
160 | self.preload_params = True
161 |
162 | def premultiply(self):
163 | self.preload_params = False
164 |
165 | if self.fourier_amp is not None and not self.preloaded_H_amp:
166 | self.H = self.fourier_amp.clamp(min=0.) * self.H
167 | if self.fourier_phase is not None and not self.preloaded_H_phase:
168 | self.H = self.H * torch.exp(1j * self.fourier_phase)
169 |
170 | self.H.detach_()
171 | self.preloaded_H_amp = True
172 | self.preloaded_H_phase = True
173 |
174 | @property
175 | def plane_idx(self):
176 | return self._plane_idx
177 |
178 | @plane_idx.setter
179 | def plane_idx(self, idx):
180 | if idx is None:
181 | return
182 |
183 | self._plane_idx = idx
184 | if len(self.prop_dist) > 1:
185 | self.prop_dist = [self.prop_dist[idx]]
186 |
187 | if self.fourier_amp is not None and self.fourier_amp.shape[1] > 1:
188 | self.fourier_amp = nn.Parameter(self.fourier_amp[:, idx:idx+1, ...], requires_grad=False)
189 | if self.fourier_phase is not None and self.fourier_phase.shape[1] > 1:
190 | self.fourier_phase = nn.Parameter(self.fourier_phase[:, idx:idx+1, ...], requires_grad=False)
191 |
192 |
193 |
194 | class SerialProp(nn.Module):
195 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0,
196 | prop_dists_from_wrp=None, linear_conv=True, dim=1):
197 | super(SerialProp, self).__init__()
198 |
199 | first_prop = Propagation(prop_dist, wavelength, feature_size,
200 | prop_type=prop_type, linear_conv=linear_conv, F_aperture=F_aperture, dim=dim)
201 | props = [first_prop]
202 | if prop_dists_from_wrp is not None:
203 | second_prop = Propagation(prop_dists_from_wrp, wavelength, feature_size,
204 | prop_type=prop_type, linear_conv=linear_conv, F_aperture=1.0, dim=dim)
205 | props += [second_prop]
206 | self.props = nn.Sequential(*props)
207 |
208 | def forward(self, u_in):
209 |
210 | u_out = self.props(u_in)
211 |
212 | return u_out
213 |
214 | def preload_H(self):
215 | for prop in self.props:
216 | prop.preload_H()
217 |
218 | @property
219 | def plane_idx(self):
220 | return self._plane_idx
221 |
222 | @plane_idx.setter
223 | def plane_idx(self, idx):
224 | if idx is None:
225 | return
226 |
227 | self._plane_idx = idx
228 | for prop in self.props:
229 | prop.plane_idx = idx
230 |
231 |
--------------------------------------------------------------------------------
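
A minimal, self-contained sketch of the ideal propagator above; the pixel pitch, wavelength, and distances are illustrative values, not calibrated ones.

```
# Sketch only: pitch, wavelength, and distances below are illustrative.
import math
import torch
import prop_ideal

um, mm, nm = 1e-6, 1e-3, 1e-9

# Angular-spectrum propagation of one SLM phase pattern to three planes.
prop = prop_ideal.Propagation(prop_dist=[1 * mm, 5 * mm, 10 * mm],
                              wavelength=520 * nm,
                              feature_size=(6.4 * um, 6.4 * um),
                              prop_type='ASM', F_aperture=0.5)

slm_phase = (2 * torch.rand(1, 1, 1080, 1920) - 1) * math.pi
fields = prop(slm_phase)   # a real-valued phase is wrapped as exp(1j * phase) inside forward()
amps = fields.abs()        # complex fields at the three planes, stacked along dim=1
print(amps.shape)          # torch.Size([1, 3, 1080, 1920])
```
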
/prop_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Parameterized propagations
3 |
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import matplotlib.pyplot as plt
11 | import pytorch_lightning as pl
12 | import utils
13 | from unet import UnetGenerator, init_weights, norm_layer
14 | import prop_ideal
15 | from prop_submodules import Field2Input, Output2Field, Conv2dField,\
16 | ContentDependentField, LatentCodedMLP, SourceAmplitude
17 | from prop_zernike import compute_zernike_basis, combine_zernike_basis
18 |
19 |
20 | def model(opt, dev=torch.device('cuda'), preload_H=False):
21 | """
22 |
23 |     :param opt: parsed options (see params.py)
24 |     :param dev: torch device to place the model on
25 |     :param preload_H: if True, precompute and detach the propagation kernels
26 |     :return: the propagation model (SerialProp, PropNH, or CNNpropCNN) selected by opt.prop_model
27 | """
28 | if opt.prop_model.lower() == 'asm':
29 | sim_prop = prop_ideal.SerialProp(opt.prop_dist, opt.wavelength, opt.feature_size,
30 | 'ASM', opt.F_aperture, opt.prop_dists_from_wrp,
31 | dim=1)
32 | elif opt.prop_model.lower() == 'nh':
33 | sim_prop = PropNH(prop_dist=opt.prop_dist, # Parameterized wave propagation model
34 | feature_size=opt.feature_size,
35 | wavelength=opt.wavelength,
36 | slm_res=opt.slm_res,
37 | roi_res=opt.roi_res,
38 | F_aperture=opt.F_aperture,
39 | prop_dists_from_wrp=opt.prop_dists_from_wrp,
40 | loss_func=opt.loss_train,
41 | lr=opt.lr,
42 | plane_idxs=opt.plane_idxs
43 | ).to(dev)
44 | elif opt.prop_model.lower() in ('cnnprop', 'propcnn', 'hil', 'nh3d', 'cnnpropcnn'):
45 | sim_prop = CNNpropCNN(opt.prop_dist, opt.wavelength, opt.feature_size,
46 | prop_type='ASM',
47 | F_aperture=opt.F_aperture,
48 | prop_dists_from_wrp=opt.prop_dists_from_wrp,
49 | linear_conv=True,
50 | slm_res=opt.slm_res,
51 | roi_res=opt.roi_res,
52 | num_downs_slm=opt.num_downs_slm,
53 | num_feats_slm_min=opt.num_feats_slm_min,
54 | num_feats_slm_max=opt.num_feats_slm_max,
55 | num_downs_target=opt.num_downs_target,
56 | num_feats_target_min=opt.num_feats_target_min,
57 | num_feats_target_max=opt.num_feats_target_max,
58 | norm=norm_layer(opt.norm),
59 | slm_coord=opt.slm_coord,
60 | target_coord=opt.target_coord,
61 | loss_func=opt.loss_train,
62 | lr=opt.lr,
63 | plane_idxs=opt.plane_idxs
64 | ).to(dev)
65 |
66 | if opt.prop_model_path is not None:
67 | checkpoint = torch.load(opt.prop_model_path)
68 | sim_prop.load_state_dict(checkpoint["state_dict"])
69 | sim_prop.eval()
70 | print(f' - Model loaded from {opt.prop_model_path}')
71 |
72 | if preload_H:
73 | sim_prop.preload_H()
74 |
75 | if opt.eval_plane_idx is not None:
76 | sim_prop.plane_idx = opt.eval_plane_idx
77 |
78 | return sim_prop
79 |
80 |
81 | class PropModel(pl.LightningModule):
82 | """
83 | A parameterized model trained on captured images at multiple planes
84 |
85 | Class initialization parameters
86 | -------------------------------
87 | :param roi_res:
88 | :param plane_idxs:
89 | :param loss_func:
90 | :param lr:
91 | """
92 | def __init__(self, roi_res=(1080, 1920), plane_idxs=None, loss_func=F.l1_loss, lr=4e-4):
93 | super(PropModel, self).__init__()
94 | self.roi_res = roi_res
95 | self.plane_idxs = plane_idxs
96 | self.loss_func = loss_func
97 | self.lr = lr
98 | self.recon_amp = {}
99 | self.target_amp = {}
100 |
101 | def perform_step(self, batch, batch_idx, prefix):
102 | slm_phase, target_amps = batch
103 |
104 | recon_fields = self.forward(slm_phase) # output field at target planes.
105 |
106 | # calculate losses
107 | loss, loss_mse, loss_mse_all = self.mp_loss(recon_fields, target_amps,
108 | self.plane_idxs, self.loss_func, prefix)
109 |
110 | with torch.no_grad():
111 | self.log(f'loss_{prefix}', loss, on_step=True, on_epoch=True)
112 | self.log(f'PSNR_{prefix}', 10*math.log10(1/loss_mse), on_step=True, on_epoch=True)
113 | self.log(f'PSNR_except_held_out_{prefix}',
114 | 10*math.log10(1/loss_mse_all), on_step=True, on_epoch=True)
115 |
116 | return loss
117 |
118 | def mp_loss(self, recon_fields, target_amps, plane_idxs, loss_func, prefix):
119 | """ Loss function on multiplane amplitudes"""
120 |
121 | # take the amplitudes.
122 | target_amps = utils.crop_image(target_amps, self.roi_res, stacked_complex=False)
123 | recon_amps = utils.crop_image(recon_fields, self.roi_res, stacked_complex=False).abs()
124 |
125 | with torch.no_grad():
126 | self.recon_amp[prefix] = recon_amps
127 | self.target_amp[prefix] = target_amps
128 |
129 | # calculate loss values.
130 | loss = 0. # the loss you penalize.
131 | mse_loss = 0. # PSNR on planes you penalize.
132 | mse_loss_all = 0. # PSNR on all planes except the held-out plane.
133 |
134 | if plane_idxs is None:
135 | # penalize all planes
136 | loss += loss_func(recon_amps, target_amps)
137 | else:
138 | # penalize only selected planes
139 | for i in range(target_amps.shape[1]):
140 | if i in plane_idxs[prefix]:
141 | # selected planes
142 | loss_i = loss_func(recon_amps[:, i, ...], target_amps[:, i, ...])
143 | loss += loss_i / len(plane_idxs[prefix])
144 | else:
145 | # these are not penalized, just for tensorboard
146 | with torch.no_grad():
147 | loss_i = loss_func(recon_amps[:, i, ...], target_amps[:, i, ...])
148 |
149 | # report PSNR
150 | with torch.no_grad():
151 | # this is for evaluation so just use the min-mse scaling.
152 | s = (recon_amps[:, i, ...] * target_amps[:, i, ...]).mean() \
153 | / (recon_amps[:, i, ...]**2).mean()
154 | mse_loss_i = F.mse_loss(s*recon_amps[:, i, ...], target_amps[:, i, ...])
155 |
156 | if i in self.plane_idxs[prefix]:
157 | mse_loss += mse_loss_i / len(plane_idxs[prefix])
158 |
159 | if not i in self.plane_idxs['heldout']:
160 | # exclude held-out plane
161 | mse_loss_all += mse_loss_i / len(plane_idxs['train'])
162 |
163 | self.log(f'loss_{prefix}/plane_{i}',
164 | loss_i, on_step=True, on_epoch=True)
165 | self.log(f'PSNR_{prefix}/plane_{i}',
166 | 10*math.log10(1./mse_loss_i), on_step=True, on_epoch=True)
167 |
168 | return loss, mse_loss, mse_loss_all
169 |
170 | def training_step(self, batch, batch_idx):
171 | return self.perform_step(batch, batch_idx, 'train')
172 |
173 | def validation_step(self, batch, batch_idx):
174 | return self.perform_step(batch, batch_idx, 'validation')
175 |
176 | def test_step(self, batch, batch_idx):
177 | return self.perform_step(batch, batch_idx, 'test')
178 |
179 | def test_epoch_end(self, outputs) -> None:
180 | self.epoch_end_images('test')
181 |
182 | def training_epoch_end(self, outputs) -> None:
183 | self.epoch_end_images('train')
184 |
185 | def validation_epoch_end(self, outputs) -> None:
186 | self.epoch_end_images('validation')
187 |
188 | def configure_optimizers(self):
189 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
190 | return optimizer
191 |
192 |
193 | class CNNpropCNN(PropModel):
194 | """
195 | A parameterized model with CNNs
196 |
197 | Class initialization parameters
198 | -------------------------------
199 | :param prop_dist: Propagation distance from SLM to Intermediate plane, float.
200 | :param wavelength: wavelength, float.
201 | :param feature_size: Pixel pitch of SLM, float.
202 | :param prop_type: Type of propagation, string, default 'ASM'.
203 | :param F_aperture: Level of filtering at Fourier plane, float, default 1.0 (and circular).
204 | :param prop_dists_from_wrp: An array of propagation distances from Intermediate plane to Target planes.
205 | :param linear_conv: If true, pad before taking Fourier transform to ensure the linear convolution.
206 | :param slm_res: Resolution of SLM.
207 | :param roi_res: Resolution of Region of Interest.
208 | :param num_downs_slm: Number of layers of U-net at SLM network.
209 | :param num_feats_slm_min: Number of features at the top layer of SLM network.
210 | :param num_feats_slm_max: Number of features at the very bottom layer of SLM network.
211 | :param num_downs_target: Number of layers of U-net at target network.
212 | :param num_feats_target_min: Number of features at the top layer of target network.
213 | :param num_feats_target_max: Number of features at the very bottom layer of target network.
214 | :param norm: normalization layers.
215 | :param slm_coord: input/output format of SLM network.
216 |     :param target_coord: input/output format of target network.
217 | :param loss_func: Loss function to train the model.
218 | :param lr: a learning rate.
219 | :param plane_idxs: a dictionary that has plane idxs for 'train', 'val', 'test', 'heldout'.
220 | """
221 |
222 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0,
223 | prop_dists_from_wrp=None, linear_conv=True, slm_res=(1080, 1920), roi_res=(960, 1680),
224 | num_downs_slm=0, num_feats_slm_min=0, num_feats_slm_max=0,
225 | num_downs_target=0, num_feats_target_min=0, num_feats_target_max=0,
226 | norm=nn.InstanceNorm2d, slm_coord='rect', target_coord='rect',
227 | loss_func=F.l1_loss, lr=4e-4,
228 | plane_idxs=None):
229 | super(CNNpropCNN, self).__init__(roi_res=roi_res,
230 | plane_idxs=plane_idxs, loss_func=loss_func, lr=lr)
231 |
232 | ##################
233 | # Model pipeline #
234 | ##################
235 | # SLM Network
236 | if num_downs_slm > 0:
237 | slm_cnns = []
238 | slm_cnn_res = tuple(res if res % (2 ** num_downs_slm) == 0 else
239 | res + (2 ** num_downs_slm - res % (2 ** num_downs_slm))
240 | for res in slm_res)
241 |             print(f' - SLM CNN resolution: {slm_cnn_res}')
242 |
243 | slm_input = Field2Input(slm_cnn_res, coord=slm_coord)
244 | slm_cnns += [slm_input]
245 |
246 | if num_downs_slm > 0:
247 | slm_cnn = UnetGenerator(input_nc=4 if 'both' in slm_coord else 2, output_nc=2,
248 | num_downs=num_downs_slm, nf0=num_feats_slm_min,
249 | max_channels=num_feats_slm_max, norm_layer=norm, outer_skip=True)
250 | init_weights(slm_cnn, init_type='normal')
251 | slm_cnns += [slm_cnn]
252 |
253 | slm_output = Output2Field(slm_res, slm_coord)
254 | slm_cnns += [slm_output]
255 | self.slm_cnn = nn.Sequential(*slm_cnns)
256 | else:
257 | self.slm_cnn = None
258 |
259 | # Propagation from the SLM plane to the WRP.
260 | if prop_dist != 0.:
261 | self.prop_slm_wrp = prop_ideal.Propagation(prop_dist, wavelength, feature_size,
262 | prop_type=prop_type, linear_conv=linear_conv,
263 | F_aperture=F_aperture)
264 | else:
265 | self.prop_slm_wrp = None
266 |
267 | # Propagation from the WRP to other planes.
268 | if prop_dists_from_wrp is not None:
269 | self.prop_wrp_target = prop_ideal.Propagation(prop_dists_from_wrp, wavelength, feature_size,
270 | prop_type=prop_type, linear_conv=1.0,
271 | F_aperture=F_aperture)
272 | else:
273 | self.prop_wrp_target = None
274 |
275 |         # Target network: either included (a CNN applied after propagation to the target planes) or omitted (ideal propagation only).
276 | if num_downs_target > 0:
277 | target_cnn_res = tuple(res if res % (2 ** num_downs_target) == 0 else
278 | res + (2 ** num_downs_target - res % (2 ** num_downs_target)) for res in slm_res)
279 | target_input = Field2Input(target_cnn_res, coord=target_coord, shared_cnn=True)
280 | input_nc_target = 4 if 'both' in target_coord else 2 if target_coord != 'amp' else 1
281 | output_nc_target = 2 if target_coord != 'amp' and ('1ch_output' not in target_coord) else 1
282 | target_cnn = UnetGenerator(input_nc=input_nc_target, output_nc=output_nc_target,
283 | num_downs=num_downs_target, nf0=num_feats_target_min,
284 | max_channels=num_feats_target_max, norm_layer=norm, outer_skip=True)
285 | init_weights(target_cnn, init_type='normal')
286 |
287 | # shared target cnn requires permutation in channels here.
288 | num_ch_output = 1 if not prop_dists_from_wrp else len(self.prop_wrp_target)
289 | target_output = Output2Field(slm_res, target_coord, num_ch_output=num_ch_output)
290 | target_cnns = [target_input, target_cnn, target_output]
291 | self.target_cnn = nn.Sequential(*target_cnns)
292 | else:
293 | self.target_cnn = None
294 |
295 | def forward(self, field):
296 | if self.slm_cnn is not None:
297 | slm_field = self.slm_cnn(field) # Applying CNN at SLM plane.
298 | else:
299 | slm_field = field
300 | if self.prop_slm_wrp is not None:
301 | wrp_field = self.prop_slm_wrp(slm_field) # Propagation from SLM to Intermediate plane.
302 | if self.prop_wrp_target is not None:
303 | target_field = self.prop_wrp_target(wrp_field) # Propagation from Intermediate plane to Target planes.
304 | if self.target_cnn is not None:
305 | amp = self.target_cnn(target_field).abs() # Applying CNN at Target planes.
306 | phase = target_field.angle()
307 | output_field = amp * torch.exp(1j * phase)
308 | else:
309 | output_field = target_field
310 |
311 | return output_field
312 |
313 | def epoch_end_images(self, prefix):
314 | """
315 | execute at the end of epochs
316 |
317 | :param prefix:
318 | :return:
319 | """
320 | #################
321 | # Reconstructions
322 | logger = self.logger.experiment
323 | recon_amp = self.recon_amp[prefix][0]
324 | target_amp = self.target_amp[prefix][0]
325 | for i in range(recon_amp.shape[0]):
326 | logger.add_image(f'amp_recon/{prefix}_{i}', recon_amp[i:i+1, ...].clip(0, 1), self.global_step)
327 | logger.add_image(f'amp_target/{prefix}_{i}', target_amp[i:i+1, ...].clip(0, 1), self.global_step)
328 |
329 | @property
330 | def plane_idx(self):
331 | return self._plane_idx
332 |
333 | @plane_idx.setter
334 | def plane_idx(self, idx):
335 | """
336 |
337 | """
338 | if idx is None:
339 | return
340 | self._plane_idx = idx
341 | if self.prop_wrp_target is not None and len(self.prop_wrp_target) > 1:
342 | self.prop_wrp_target.plane_idx = idx
343 | if self.target_cnn is not None and self.target_cnn[-1].num_ch_output > 1:
344 | self.target_cnn[-1].num_ch_output = 1
345 |
346 |
347 | class PropNH(PropModel):
348 | """
349 | A parameterized model proposed in the original Neural Holography paper (Peng et al. 2020)
350 |
351 | Class initialization parameters
352 | -------------------------------
353 | """
354 | def __init__(self, prop_dist, wavelength, feature_size, prop_type='ASM', F_aperture=1.0,
355 | prop_dists_from_wrp=None, linear_conv=True, slm_res=(1080, 1920),
356 | roi_res=(1080, 1920), plane_idxs=None, loss_func=F.l1_loss, lr=4e-4,
357 | num_gaussians=3, init_sigma=(1300.0, 1500.0, 1700.0), init_amp=0.9,
358 | num_zernike_slm=0, num_zernike_f=5, init_coeffs=0.0, use_conv1d_mlp=True,
359 | num_layers=3, num_features=16, num_latent_codes=None, norm=nn.GroupNorm,
360 | target_field=True, content_field=True, num_layers_cnn=5, num_feats_cnn=8):
361 | super(PropNH, self).__init__(roi_res=roi_res,
362 | plane_idxs=plane_idxs, loss_func=loss_func, lr=lr)
363 |
364 | self.prop_slm_wrp = None
365 | self.prop_wrp_target = None
366 | self.lut_slm = None # Section 5.1.3. Phase nonlinearity
367 | self.code_slm = None
368 | self.a_src = None # Section 5.1.1. Content-independent Source
369 | self.phi_znk_slm = None # Section 5.1.2 Modeling Optical Propagation with Aberrations
370 | self.znk_basis_slm = None
371 | self.preloaded_znk_slm = False
372 | self.phi_znk_f = None
373 | self.znk_basis_f = None
374 | self.preloaded_znk_f = False
375 | self.coeffs_slm = None
376 | self.coeffs_f = None
377 | self.preload_zernike = False
378 | self.a_t = None # Section 5.1.1. Target Field variation
379 | self.phi_t = None
380 | self.cnn = None # Section 5.1.4. Content-dependent Undiffracted Light
381 |
382 | # Section 3. Propagation from the SLM plane to the WRP.
383 | if prop_dist != 0.:
384 | self.prop_slm_wrp = prop_ideal.Propagation(prop_dist, wavelength, feature_size,
385 | prop_type=prop_type, linear_conv=linear_conv,
386 | F_aperture=F_aperture)
387 | # Propagation from the WRP to other planes.
388 | if prop_dists_from_wrp is not None:
389 | self.prop_wrp_target = prop_ideal.Propagation(prop_dists_from_wrp, wavelength, feature_size,
390 | prop_type=prop_type, linear_conv=linear_conv,
391 | F_aperture=F_aperture)
392 |
393 | # Section 5.1.1. Content-independent Source
394 | if num_gaussians:
395 | self.a_src = SourceAmplitude(num_gaussians, init_sigma, init_amp=init_amp, x_s0=0.0, y_s0=0.0)
396 |
397 | # Section 5.1.1. Target Field variation
398 | if target_field:
399 | self.a_t = nn.Parameter(0.07 * torch.ones(1, 1, *slm_res))
400 | self.phi_t = nn.Parameter(torch.zeros((1, 1, *slm_res)))
401 |
402 | # Section 5.1.2 Modeling Optical Propagation with Aberrations
403 | if num_zernike_slm:
404 | self.coeffs_slm = nn.Parameter(torch.ones(num_zernike_slm) * init_coeffs)
405 | if num_zernike_f:
406 | self.coeffs_f = nn.Parameter(torch.ones(num_zernike_f) * init_coeffs)
407 |
408 | # Section 5.1.3. Phase nonlinearity
409 | if num_latent_codes is None:
410 | num_latent_codes = [2, 0, 0]
411 | if sum(num_latent_codes) > 0:
412 | self.code_slm = nn.Parameter(torch.zeros(1, sum(num_latent_codes), *slm_res))
413 | if use_conv1d_mlp:
414 | self.lut_slm = LatentCodedMLP(num_layers, num_features, norm=norm, num_latent_codes=num_latent_codes)
415 |
416 | # Section 5.1.4. Content-dependent Undiffracted Light
417 | if content_field:
418 | target_cnn = [ContentDependentField(num_layers=num_layers_cnn,
419 | num_features=num_feats_cnn,
420 | norm=nn.GroupNorm)]
421 | target_cnn += [Output2Field(slm_res, 'rect')]
422 | self.cnn = nn.Sequential(*target_cnn)
423 |
424 | self.conv = Conv2dField()
425 |
426 | def forward(self, u_in):
427 | """ Implementation of Equation (6) in the Neural Holography paper """
428 | if torch.is_complex(u_in):
429 | phi_in = u_in.angle()
430 | else:
431 | phi_in = u_in
432 |
433 | # Section 5.1.3. Phase nonlinearity
434 | if self.lut_slm is not None:
435 | slm_phase = self.lut_slm(phi_in, self.code_slm)
436 | else:
437 | slm_phase = phi_in
438 |
439 | # Section 5.1.2. precompute the zernike basis only once
440 | if self.znk_basis_slm is None and self.coeffs_slm is not None:
441 | self.znk_basis_slm = compute_zernike_basis(self.coeffs_slm.size()[0],
442 | slm_phase.size()[-2:], wo_piston=True)
443 | self.znk_basis_slm = self.znk_basis_slm.to(u_in.device).detach().requires_grad_(False)
444 |
445 | if not self.preloaded_znk_slm and self.phi_znk_slm is None and self.coeffs_slm is not None:
446 | self.phi_znk_slm = combine_zernike_basis(self.coeffs_slm, self.znk_basis_slm)
447 | self.phi_znk_slm = self.phi_znk_slm.to(u_in.device)
448 |
449 | if self.preload_zernike:
450 | self.preload_zernike_slm()
451 |
452 | if self.phi_znk_slm is not None:
453 | slm_phase = slm_phase + self.phi_znk_slm
454 |
455 | # Section 5.1.1. Create Source Amplitude (DC + gaussians)
456 | if self.a_src is not None:
457 | slm_field = self.a_src(slm_phase) * torch.exp(1j * slm_phase)
458 | else:
459 | slm_field = torch.exp(1j * slm_phase)
460 |
461 | # Section 5.1.2. precompute the zernike basis only once
462 | if self.znk_basis_f is None and self.coeffs_f is not None:
463 | self.znk_basis_f = compute_zernike_basis(self.coeffs_f.size()[0],
464 | [i * 2 for i in slm_phase.size()[-2:]],
465 | wo_piston=True)
466 | self.znk_basis_f = self.znk_basis_f.to(u_in.device).detach().requires_grad_(False)
467 |
468 | if not self.preloaded_znk_f and self.phi_znk_f is None and self.coeffs_f is not None:
469 | self.phi_znk_f = combine_zernike_basis(self.coeffs_f, self.znk_basis_f).to(u_in.device)
470 |
471 | if self.preload_zernike:
472 | self.preload_zernike_f()
473 |
474 | self.prop_slm_wrp.fourier_phase = self.phi_znk_f
475 |
476 | # Propagation from SLM to Intermediate plane
477 | recon_field = self.prop_slm_wrp(slm_field)
478 |
479 | # Section 5.1.1. Content-independent field at target plane
480 | if self.a_t is not None and self.phi_t is not None:
481 | recon_field = recon_field + self.a_t * torch.exp(1j * self.phi_t)
482 |
483 | # Section 5.1.4. Content-dependent Undiffracted light
484 | if self.cnn is not None:
485 | recon_field = recon_field + self.cnn(phi_in)
486 |
487 | # Propagation from Intermediate plane to Target planes.
488 | if self.prop_wrp_target is not None:
489 | target_field = self.prop_wrp_target(recon_field)
490 |
491 | if self.conv is not None:
492 | return self.conv(target_field) * torch.exp(1j * target_field.angle())
493 | else:
494 | return target_field
495 |
496 | def epoch_end_images(self, prefix):
497 | """ execute at the end of epochs """
498 |
499 | # Reconstructions
500 | logger = self.logger.experiment
501 | recon_amp = self.recon_amp[prefix][0]
502 | target_amp = self.target_amp[prefix][0]
503 | for i in range(recon_amp.shape[0]):
504 | logger.add_image(f'amp_recon/{prefix}_{i}', recon_amp[i:i+1, ...].clip(0, 1), self.global_step)
505 | logger.add_image(f'amp_target/{prefix}_{i}', target_amp[i:i+1, ...].clip(0, 1), self.global_step)
506 |
507 | def preload_zernike_slm(self):
508 |         self.preloaded_znk_slm = True  # freeze the SLM-plane Zernike phase (avoid recomputation)
509 | if self.phi_znk_slm is not None:
510 | self.phi_znk_slm = self.phi_znk_slm.detach().requires_grad_(False)
511 |
512 | def preload_zernike_f(self):
513 |         self.preloaded_znk_f = True  # freeze the Fourier-plane Zernike phase (avoid recomputation)
514 | if self.phi_znk_f is not None:
515 | self.phi_znk_f = self.phi_znk_f.detach().requires_grad_(False)
516 | self.preload_zernike = False
517 |
518 | def preload_H(self):
519 | self.preload_zernike = True
520 |
521 | @property
522 | def plane_idx(self):
523 | return self._plane_idx
524 |
525 | @plane_idx.setter
526 | def plane_idx(self, idx):
527 | """ """
528 | if idx is None:
529 | return
530 | self._plane_idx = idx
531 | if self.prop_wrp_target is not None and len(self.prop_wrp_target) > 1:
532 | self.prop_wrp_target.plane_idx = idx
533 |
--------------------------------------------------------------------------------
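
The per-plane PSNR reported by mp_loss above uses a closed-form scale correction: s = mean(recon * target) / mean(recon ** 2) is the least-squares scalar that minimizes mean((s * recon - target) ** 2). A small standalone check of that identity (random tensors, illustrative only):

```
# Sketch only: verifies the min-MSE scale factor used for PSNR reporting in mp_loss().
import torch

recon = torch.rand(1, 960, 1680)
target = 0.7 * recon + 0.05 * torch.rand_like(recon)   # arbitrary test amplitudes

# closed-form minimizer of mean((s * recon - target)**2) over the scalar s
s = (recon * target).mean() / (recon ** 2).mean()

mse = torch.mean((s * recon - target) ** 2)
psnr = 10 * torch.log10(1.0 / mse)

# nudging s in either direction can only increase the error
assert mse <= torch.mean(((s + 1e-3) * recon - target) ** 2)
assert mse <= torch.mean(((s - 1e-3) * recon - target) ** 2)
print(f'scale {s.item():.3f}, PSNR {psnr.item():.2f} dB')
```
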
/prop_physical.py:
--------------------------------------------------------------------------------
1 | """
2 | Propagation happening on the setup
3 |
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import utils
9 | import time
10 | import cv2
11 | import numpy as np
12 | import imageio
13 |
14 | import sys
15 | if sys.platform == 'win32':
16 | import slmpy
17 | import hw.camera_capture_module as cam
18 | import hw.calibration_module as calibration
19 |
20 |
21 | class PhysicalProp(nn.Module):
22 | """ A module for physical propagation,
23 | forward pass displays gets SLM pattern as an input and display the pattern on the physical setup,
24 | and capture the diffraction image at the target plane,
25 | and then return warped image using pre-calibrated homography from instantiation.
26 |
27 | Class initialization parameters
28 | -------------------------------
29 | :param params_slm: a set of parameters for the SLM.
30 | :param params_camera: a set of parameters for the camera sensor.
31 | :param params_calib: a set of parameters for homography calibration.
32 |
33 | Usage
34 | -----
35 | Functions as a pytorch module:
36 |
37 | >>> camera_prop = PhysicalProp(...)
38 | >>> captured_amp = camera_prop(slm_phase)
39 |
40 | slm_phase: phase at the SLM plane, with dimensions [batch, 1, height, width]
41 | captured_amp: amplitude at the target plane, with dimensions [batch, 1, height, width]
42 |
43 | """
44 | def __init__(self, params_slm, params_camera, params_calib=None):
45 | super(PhysicalProp, self).__init__()
46 |
47 | # 1. Connect Camera
48 | self.camera = cam.CameraCapture(params_camera)
49 | self.camera.connect(0) # specify the camera to use, 0 for main cam, 1 for the second cam
50 | self.camera.start_capture()
51 |
52 | # 2. Connect SLM
53 | self.slm = slmpy.SLMdisplay(isImageLock=True, monitor=params_slm.monitor_num)
54 | self.params_slm = params_slm
55 |
56 | # 3. Calibrate hardware using homography
57 | if params_calib is not None:
58 | self.warper = calibration.Warper(params_calib)
59 | self.calibrate(params_calib.phase_path, params_calib.show_preview)
60 | else:
61 | self.warper = None
62 |
63 | def calibrate(self, phase_path, show_preview=False):
64 | """
65 |         Displays the calibration phase pattern on the SLM and fits the homography used to warp captured images.
66 |         :param phase_path: path to the phase pattern image displayed for calibration
67 |         :param show_preview: if True, show a preview of the captured calibration image
68 |         :return:
69 | """
70 | print(' -- Calibrating ...')
71 | phase_img = imageio.imread(phase_path)
72 | self.slm.updateArray(phase_img)
73 | time.sleep(self.params_slm.settle_time)
74 | captured_img = self.camera.grab_images_fast(5) # capture 5-10 images for averaging
75 | calib_success = self.warper.calibrate(captured_img, show_preview)
76 | if calib_success:
77 | print(' -- Calibration succeeded!...')
78 | else:
79 | raise ValueError(' -- Calibration failed')
80 |
81 | def forward(self, slm_phase):
82 | """
83 |         Displays the phase pattern, captures the resulting image, and warps it to the calibrated coordinates.
84 |         :param slm_phase: phase at the SLM plane, with dimensions [batch, 1, height, width]
85 |         :return: captured amplitude at the target plane, with dimensions [batch, 1, height, width]
86 | """
87 | raw_intensity = self.capture_linear_intensity(slm_phase) # grayscale raw16 intensity image
88 | warped_intensity = self.warper(raw_intensity) # apply homography
89 | return warped_intensity.sqrt() # return amplitude
90 |
91 | def capture_linear_intensity(self, slm_phase):
92 | """
93 |         Displays a phase pattern on the SLM and captures the resulting holographic image with the sensor.
94 |
95 | :param slm_phase:
96 | :return:
97 | """
98 | raw_uint16_data = self.capture_uint16(slm_phase) # display & retrieve buffer
99 | captured_intensity = self.process_raw_data(raw_uint16_data) # demosaick & sum up
100 | return captured_intensity
101 |
102 | def capture_uint16(self, slm_phase):
103 | """
104 |         Displays the phase pattern(s) on the SLM, then sends a signal to the board (waits for the next clock from the SLM).
105 |         Right after hearing back from the SLM, it signals the PC so that the PC retrieves the camera buffer.
106 |
107 | :param slm_phase:
108 | :return:
109 | """
110 | if torch.is_tensor(slm_phase):
111 | slm_phase_encoded = utils.phasemap_8bit(slm_phase)
112 | else:
113 | slm_phase_encoded = slm_phase
114 | self.slm.updateArray(slm_phase_encoded)
115 |
116 | if self.camera.params.ser is not None:
117 | self.camera.params.ser.write(f'D'.encode())
118 |
119 | # TODO: make the following in a separate function.
120 | # Wait until receiving signal from arduino
121 | incoming_byte = self.camera.params.ser.inWaiting()
122 | t0 = time.perf_counter()
123 | while True:
124 | received = self.camera.params.ser.read(incoming_byte).decode('UTF-8')
125 | if received != 'C':
126 | incoming_byte = self.camera.params.ser.inWaiting()
127 | if time.perf_counter() - t0 > 2.0:
128 | break
129 | else:
130 | break
131 | else:
132 | time.sleep(self.params_slm.settle_time)
133 |
134 | raw_data_from_buffer = self.camera.retrieve_buffer()
135 | return raw_data_from_buffer
136 |
137 | def process_raw_data(self, raw_data):
138 | """
139 | gets raw data from the camera buffer, and demosaick it
140 |
141 | :param raw_data:
142 | :return:
143 | """
144 | raw_data = raw_data - 64
145 | color_cv_image = cv2.cvtColor(raw_data, self.camera.demosaick_rule) # it gives float64 from uint16 -- double check it
146 | captured_intensity = utils.im2float(color_cv_image) # float64 to float32
147 |
148 | # Numpy to tensor
149 | captured_intensity = torch.tensor(captured_intensity, dtype=torch.float32,
150 | device=self.dev).permute(2, 0, 1).unsqueeze(0)
151 | captured_intensity = torch.sum(captured_intensity, dim=1, keepdim=True)
152 | return captured_intensity
153 |
154 | def disconnect(self):
155 | self.camera.stop_capture()
156 | self.camera.disconnect()
157 | self.slm.close()
158 |
159 | def to(self, *args, **kwargs):
160 | slf = super().to(*args, **kwargs)
161 | if slf.warper is not None:
162 | slf.warper = slf.warper.to(*args, **kwargs)
163 | try:
164 | slf.dev = next(slf.parameters()).device
165 | except StopIteration: # no parameters
166 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0]
167 | if device_arg is not None:
168 | slf.dev = device_arg
169 | return slf
--------------------------------------------------------------------------------
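A minimal usage sketch for the hardware-in-the-loop propagation above (Windows only, with the SLM and camera attached). The parameter namespaces are assumed to be built by params.py, and the SLM resolution below is hypothetical; this cannot run without the physical setup.

```
import torch
import prop_physical

# params_slm, params_camera, params_calib are assumed to come from params.py
camera_prop = prop_physical.PhysicalProp(params_slm, params_camera, params_calib)
camera_prop = camera_prop.to(torch.device('cuda'))

slm_phase = torch.zeros(1, 1, 1080, 1920)   # hypothetical SLM resolution, phase in [-pi, pi]
captured_amp = camera_prop(slm_phase)       # warped amplitude at the target plane
camera_prop.disconnect()
```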
/prop_submodules.py:
--------------------------------------------------------------------------------
1 | """
2 | Modules for propagation
3 |
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | import utils
10 | from unet import Conv2dSame
11 |
12 |
13 | class Field2Input(nn.Module):
14 | """Gets complex-valued field and turns it into multi-channel images"""
15 |
16 | def __init__(self, input_res=(800, 1280), coord='rect', shared_cnn=False):
17 | super(Field2Input, self).__init__()
18 | self.input_res = input_res
19 | self.coord = coord.lower()
20 | self.shared_cnn = shared_cnn
21 |
22 | def forward(self, input_field):
23 | # If input field is slm phase
24 | if input_field.dtype == torch.float32:
25 | input_field = torch.exp(1j * input_field)
26 |
27 | input_field = utils.pad_image(input_field, self.input_res, pytorch=True, stacked_complex=False)
28 | input_field = utils.crop_image(input_field, self.input_res, pytorch=True, stacked_complex=False)
29 |
30 | # To use shared CNN, put everything into batch dimension;
31 | if self.shared_cnn:
32 | num_mb, num_dists = input_field.shape[0], input_field.shape[1]
33 | input_field = input_field.reshape(num_mb*num_dists, 1, *input_field.shape[2:])
34 |
35 | # Input format
36 | if self.coord == 'rect':
37 | stacked_input = torch.cat((input_field.real, input_field.imag), 1)
38 | elif self.coord == 'polar':
39 | stacked_input = torch.cat((input_field.abs(), input_field.angle()), 1)
40 | elif self.coord == 'amp':
41 | stacked_input = input_field.abs()
42 | elif 'both' in self.coord:
43 | stacked_input = torch.cat((input_field.abs(), input_field.angle(), input_field.real, input_field.imag), 1)
44 |
45 | return stacked_input
46 |
47 |
48 | class Output2Field(nn.Module):
49 |     """Gets multi-channel CNN output and turns it back into a complex-valued field"""
50 |
51 | def __init__(self, output_res=(800, 1280), coord='rect', num_ch_output=1):
52 | super(Output2Field, self).__init__()
53 | self.output_res = output_res
54 | self.coord = coord.lower()
55 | self.num_ch_output = num_ch_output # number of channels in output
56 |
57 | def forward(self, stacked_output):
58 |
59 | if self.coord in ('rect', 'both'):
60 | complex_valued_field = torch.view_as_complex(stacked_output.unsqueeze(4).
61 | permute(0, 4, 2, 3, 1).contiguous())
62 | elif self.coord == 'polar':
63 | amp = stacked_output[:, 0:1, ...]
64 | phi = stacked_output[:, 1:2, ...]
65 | complex_valued_field = amp * torch.exp(1j * phi)
66 | elif self.coord == 'amp' or '1ch_output' in self.coord:
67 | complex_valued_field = stacked_output * torch.exp(1j * torch.zeros_like(stacked_output))
68 |
69 | output_field = utils.pad_image(complex_valued_field, self.output_res, pytorch=True, stacked_complex=False)
70 | output_field = utils.crop_image(output_field, self.output_res, pytorch=True, stacked_complex=False)
71 |
72 | if self.num_ch_output > 1:
73 | # reshape to original tensor shape
74 | output_field = output_field.reshape(output_field.shape[0] // self.num_ch_output, self.num_ch_output,
75 | *output_field.shape[2:])
76 |
77 | return output_field
78 |
79 |
80 | class Conv2dField(nn.Module):
81 | """Apply 2d conv on amp or field"""
82 |
83 | def __init__(self, comp=False, conv_size=3):
84 | super(Conv2dField, self).__init__()
85 | self.comp = comp # apply convolution on field
86 | self.conv_size = (conv_size, conv_size)
87 | if self.comp:
88 | self.conv_real = Conv2dSame(1, 1, conv_size)
89 | self.conv_imag = Conv2dSame(1, 1, conv_size)
90 | init_weight = torch.zeros(1, 1, *self.conv_size)
91 | init_weight[..., conv_size//2, conv_size//2] = 1.
92 | self.conv_real.net[1].weight = nn.Parameter(init_weight.detach().requires_grad_(True))
93 | self.conv_imag.net[1].weight = nn.Parameter(init_weight.detach().requires_grad_(True))
94 | else:
95 | self.conv = Conv2dSame(1, 1, conv_size, bias=False)
96 | init_weight = torch.zeros(1, 1, *self.conv_size)
97 | init_weight[..., conv_size//2, conv_size//2] = 1.
98 | self.conv.net[1].weight = nn.Parameter(init_weight.requires_grad_(True))
99 |
100 | def forward(self, input_field):
101 |
102 | # reshape tensor if number of channels > 1
103 | num_ch = input_field.shape[1]
104 | if num_ch > 1:
105 | batch_size = input_field.shape[0]
106 | input_field = input_field.reshape(batch_size * num_ch, 1, *input_field.shape[2:])
107 |
108 | if self.comp:
109 | # apply conv on complex fields
110 | real = self.conv_real(input_field.real) - self.conv_imag(input_field.imag)
111 | imag = self.conv_real(input_field.imag) + self.conv_imag(input_field.real)
112 | output_field = torch.view_as_complex(torch.stack((real, imag), -1))
113 | else:
114 | # apply conv on intensity
115 | output_amp = self.conv(input_field.abs()**2).abs().mean(dim=1, keepdims=True).sqrt()
116 | output_field = output_amp * torch.exp(1j * input_field.angle())
117 |
118 | # reshape to original tensor shape
119 | if num_ch > 1:
120 | output_field = output_field.reshape(batch_size, num_ch, *output_field.shape[2:])
121 |
122 | return output_field
123 |
124 |
125 | class LatentCodedMLP(nn.Module):
126 | """
127 |     1x1-conv MLP that also concatenates latent codes in the middle of the forward pass.
128 |     Pass latent codes of shape (1, L, H, W) as a parameter to the forward pass.
129 |     num_latent_codes: list with the number of latent-code slices fed to each layer
130 |     * the sum of num_latent_codes should equal the total number of latent-code channels (L)
131 | """
132 | def __init__(self, num_layers=5, num_features=32, norm=None, num_latent_codes=None):
133 | super(LatentCodedMLP, self).__init__()
134 |
135 | if num_latent_codes is None:
136 | num_latent_codes = [0] * num_layers
137 |
138 | assert len(num_latent_codes) == num_layers
139 |
140 | self.num_latent_codes = num_latent_codes
141 | self.idxs = [sum(num_latent_codes[:y]) for y in range(num_layers + 1)]
142 | self.nets = nn.ModuleList([])
143 | num_features = [num_features] * num_layers
144 | num_features[0] = 1
145 |
146 | # define each layer
147 | for i in range(num_layers - 1):
148 | net = [nn.Conv2d(num_features[i] + num_latent_codes[i], num_features[i + 1], kernel_size=1)]
149 | if norm is not None:
150 | net += [norm(num_groups=4, num_channels=num_features[i + 1], affine=True)]
151 | net += [nn.LeakyReLU(0.2, True)]
152 | self.nets.append(nn.Sequential(*net))
153 |
154 | self.nets.append(nn.Conv2d(num_features[-1] + num_latent_codes[-1], 1, kernel_size=1))
155 |
156 | for m in self.modules():
157 | if isinstance(m, nn.Conv2d):
158 | nn.init.normal_(m.weight, std=0.05)
159 |
160 | def forward(self, phases, latent_codes=None):
161 |
162 | after_relu = phases
163 | # concatenate latent codes at each layer and send through the convolutional layers
164 | for i in range(len(self.num_latent_codes)):
165 | if latent_codes is not None:
166 | after_relu = torch.cat((after_relu, latent_codes[:, self.idxs[i]:self.idxs[i + 1], ...]), 1)
167 | after_relu = self.nets[i](after_relu)
168 |
169 | # residual connection
170 | return phases - after_relu
171 |
172 |
173 | class ContentDependentField(nn.Module):
174 | def __init__(self, num_layers=5, num_features=32, norm=nn.GroupNorm, latent_coords=False):
175 |         """ Simple 5-layer CNN modeling the content-dependent undiffracted light """
176 |
177 | super(ContentDependentField, self).__init__()
178 |
179 | if not latent_coords:
180 | first_ch = 1
181 | else:
182 | first_ch = 3
183 |
184 | net = [Conv2dSame(first_ch, num_features, kernel_size=3)]
185 |
186 | for i in range(num_layers - 2):
187 | if norm is not None:
188 | net += [norm(num_groups=2, num_channels=num_features, affine=True)]
189 | net += [nn.LeakyReLU(0.2, True),
190 | Conv2dSame(num_features, num_features, kernel_size=3)]
191 |
192 | if norm is not None:
193 | net += [norm(num_groups=4, num_channels=num_features, affine=True)]
194 |
195 | net += [nn.LeakyReLU(0.2, True),
196 | Conv2dSame(num_features, 2, kernel_size=3)]
197 |
198 | self.net = nn.Sequential(*net)
199 |
200 | def forward(self, phases, latent_coords=None):
201 | if latent_coords is not None:
202 | input_cnn = torch.cat((phases, latent_coords), dim=1)
203 | else:
204 | input_cnn = phases
205 |
206 | return self.net(input_cnn)
207 |
208 |
209 | class ProcessPhase(nn.Module):
210 | def __init__(self, num_layers=5, num_features=32, num_output_feat=0, norm=nn.BatchNorm2d, num_latent_codes=0):
211 | super(ProcessPhase, self).__init__()
212 |
213 | # avoid zero
214 | self.num_output_feat = max(num_output_feat, 1)
215 | self.num_latent_codes = num_latent_codes
216 |
217 | # a bunch of 1x1 conv layers, set by num_layers
218 | net = [nn.Conv2d(1 + num_latent_codes, num_features, kernel_size=1)]
219 |
220 | for i in range(num_layers - 2):
221 | if norm is not None:
222 | net += [norm(num_groups=2, num_channels=num_features, affine=True)]
223 | net += [nn.LeakyReLU(0.2, True),
224 | nn.Conv2d(num_features, num_features, kernel_size=1)]
225 |
226 | if norm is not None:
227 | net += [norm(num_groups=2, num_channels=num_features, affine=True)]
228 |
229 | net += [nn.ReLU(True),
230 | nn.Conv2d(num_features, self.num_output_feat, kernel_size=1)]
231 |
232 | self.net = nn.Sequential(*net)
233 |
234 | def forward(self, phases):
235 | return phases - self.net(phases)
236 |
237 |
238 | class SourceAmplitude(nn.Module):
239 | def __init__(self, num_gaussians=3, init_sigma=None, init_amp=0.7, x_s0=0.0, y_s0=0.0):
240 | super(SourceAmplitude, self).__init__()
241 |
242 | self.num_gaussians = num_gaussians
243 |
244 | if init_sigma is None:
245 | init_sigma = [100.] * self.num_gaussians # default to 100 for all
246 |
247 | # create parameters for source amplitudes
248 | self.sigmas = nn.Parameter(torch.tensor(init_sigma))
249 | self.x_s = nn.Parameter(torch.ones(num_gaussians) * x_s0)
250 | self.y_s = nn.Parameter(torch.ones(num_gaussians) * y_s0)
251 | self.amplitudes = nn.Parameter(torch.ones(num_gaussians) / (num_gaussians) * init_amp)
252 | self.dc_term = nn.Parameter(torch.zeros(1))
253 |
254 | self.x_dim = None
255 | self.y_dim = None
256 |
257 | def forward(self, phases):
258 | # create DC term, then add the gaussians
259 | source_amp = torch.ones_like(phases) * self.dc_term
260 | for i in range(self.num_gaussians):
261 | source_amp += self.create_gaussian(phases.shape, i)
262 |
263 | return source_amp
264 |
265 | def create_gaussian(self, shape, idx):
266 | # create sampling grid if needed
267 | if self.x_dim is None or self.y_dim is None:
268 | self.x_dim = torch.linspace(-(shape[-1] - 1) / 2,
269 | (shape[-1] - 1) / 2,
270 | shape[-1], device=self.dc_term.device)
271 | self.y_dim = torch.linspace(-(shape[-2] - 1) / 2,
272 | (shape[-2] - 1) / 2,
273 | shape[-2], device=self.dc_term.device)
274 |
275 | if self.x_dim.device != self.sigmas.device:
276 |             self.x_dim = self.x_dim.to(self.sigmas.device).detach()  # .to() is not in-place; keep the result
277 |             self.x_dim.requires_grad = False
278 |         if self.y_dim.device != self.sigmas.device:
279 |             self.y_dim = self.y_dim.to(self.sigmas.device).detach()
280 | self.y_dim.requires_grad = False
281 |
282 | # offset grid by coordinate and compute x and y gaussian components
283 | x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.x_dim - self.x_s[idx], self.sigmas[idx]), 2))
284 | y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(self.y_dim - self.y_s[idx], self.sigmas[idx]), 2))
285 |
286 | # outer product with amplitude scaling
287 | gaussian = torch.ger(self.amplitudes[idx] * y_gaussian, x_gaussian)
288 |
289 | return gaussian
290 |
291 |
292 | def make_kernel_gaussian(sigma, kernel_size):
293 |
294 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
295 | x_cord = torch.arange(kernel_size)
296 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
297 | y_grid = x_grid.t()
298 | xy_grid = torch.stack([x_grid, y_grid], dim=-1)
299 |
300 | mean = (kernel_size - 1) / 2
301 | variance = sigma**2
302 |
303 | # Calculate the 2-dimensional gaussian kernel which is
304 | # the product of two gaussian distributions for two different
305 | # variables (in this case called x and y)
306 | gaussian_kernel = ((1 / (2 * math.pi * variance))
307 | * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1)
308 | / (2 * variance)))
309 | # Make sure sum of values in gaussian kernel equals 1.
310 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
311 |
312 | # Reshape to 2d depthwise convolutional weight
313 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
314 |
315 | return gaussian_kernel
316 |
317 |
318 | def create_gaussian(shape, sigma=800, dev=torch.device('cuda')):
319 | # create sampling grid if needed
320 | shape_min = min(shape[-1], shape[-2])
321 | x_dim = torch.linspace(-(shape_min - 1) / 2,
322 | (shape_min - 1) / 2,
323 | shape[-1], device=dev)
324 | y_dim = torch.linspace(-(shape_min - 1) / 2,
325 | (shape_min - 1) / 2,
326 | shape[-2], device=dev)
327 |
328 | # offset grid by coordinate and compute x and y gaussian components
329 | x_gaussian = torch.exp(-0.5 * torch.pow(torch.div(x_dim, sigma), 2))
330 | y_gaussian = torch.exp(-0.5 * torch.pow(torch.div(y_dim, sigma), 2))
331 |
332 | # outer product with amplitude scaling
333 | gaussian = torch.ger(y_gaussian, x_gaussian)
334 |
335 | return gaussian
336 |
337 |
338 |
339 |
340 |
--------------------------------------------------------------------------------
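As a quick sanity check of the channel packing used around the CNNs above, here is a minimal round-trip sketch with Field2Input and Output2Field in 'rect' coordinates (default resolution, no CNN in between, so the field is reproduced exactly):

```
import torch
from prop_submodules import Field2Input, Output2Field

field = torch.exp(1j * torch.rand(1, 1, 800, 1280))          # hypothetical complex field at the SLM plane
to_input = Field2Input(input_res=(800, 1280), coord='rect')
to_field = Output2Field(output_res=(800, 1280), coord='rect')

stacked = to_input(field)    # (1, 2, 800, 1280): real and imaginary parts as channels
# ... a U-Net would map `stacked` to a 2-channel output of the same size here ...
recon = to_field(stacked)    # back to a complex field of shape (1, 1, 800, 1280)
assert torch.allclose(recon, field)
```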
/prop_zernike.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for zernike basis
3 |
4 | """
5 |
6 | import math
7 | import torch
8 | import numpy as np
9 | import utils
10 | import torch.fft
11 | from aotools.functions import zernikeArray
12 |
13 |
14 | def combine_zernike_basis(coeffs, basis, return_phase=False):
15 | """
16 | Multiplies the Zernike coefficients and basis functions while preserving
17 | dimensions
18 |
19 | :param coeffs: torch tensor with coeffs, see propagation_ASM_zernike
20 | :param basis: the output of compute_zernike_basis, must be same length as coeffs
21 | :param return_phase:
22 | :return: A float32 tensor that combines coeffs and basis.
23 | """
24 |
25 | if len(coeffs.shape) < 3:
26 | coeffs = torch.reshape(coeffs, (coeffs.shape[0], 1, 1))
27 |
28 | # combine zernike basis and coefficients
29 | zernike = (coeffs * basis).sum(0, keepdim=True)
30 |
31 |     # shape to [1, 1, H, W]
32 | zernike = zernike.unsqueeze(0)
33 |
34 | return zernike
35 |
36 |
37 | def compute_zernike_basis(num_polynomials, field_res, dtype=torch.float32, wo_piston=False):
38 |     """Computes a set of Zernike basis functions at resolution field_res
39 |
40 | num_polynomials: number of Zernike polynomials in this basis
41 | field_res: [height, width] in px, any list-like object
42 | dtype: torch dtype for computation at different precision
43 | """
44 |
45 | # size the zernike basis to avoid circular masking
46 | zernike_diam = int(np.ceil(np.sqrt(field_res[0]**2 + field_res[1]**2)))
47 |
48 | # create zernike functions
49 |
50 | if not wo_piston:
51 | zernike = zernikeArray(num_polynomials, zernike_diam)
52 |     else: # 200427 - exclude piston term
53 | idxs = range(2, 2 + num_polynomials)
54 | zernike = zernikeArray(idxs, zernike_diam)
55 |
56 | zernike = utils.crop_image(zernike, field_res, pytorch=False)
57 |
58 | # convert to tensor and create phase
59 | zernike = torch.tensor(zernike, dtype=dtype, requires_grad=False)
60 |
61 | return zernike
62 |
--------------------------------------------------------------------------------
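A small sketch of how the two functions above fit together (it relies on the aotools dependency imported at the top of the file; the resolution and number of coefficients are arbitrary example values):

```
import torch
from prop_zernike import compute_zernike_basis, combine_zernike_basis

num_coeffs = 10
basis = compute_zernike_basis(num_coeffs, [200, 320], wo_piston=True)   # (10, 200, 320), piston excluded
coeffs = torch.zeros(num_coeffs, requires_grad=True)                    # e.g., optimized jointly with the model
phi_znk = combine_zernike_basis(coeffs, basis)                          # (1, 1, 200, 320) phase map in radians
```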
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | A script for model training
3 |
4 | """
5 | import os
6 | import configargparse
7 | import utils
8 | import prop_model
9 | import params
10 | import image_loader as loaders
11 |
12 | import pytorch_lightning as pl
13 | from pytorch_lightning import Trainer
14 | from torch.utils.data import DataLoader
15 |
16 | # Command line argument processing
17 | p = configargparse.ArgumentParser()
18 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
19 | params.add_parameters(p, 'train')
20 | opt = p.parse_args()
21 | params.set_configs(opt)
22 | run_id = params.run_id_training(opt)
23 |
24 | def main():
25 |     # allow a comma-separated list of dataset paths
26 |     if ',' in opt.data_path:
27 |         opt.data_path = opt.data_path.split(',')
28 | 
29 |     print(f' - training a model ... Dataset path: {opt.data_path}')
30 |     # Set up dataloaders
31 | num_workers = 8
32 | train_loader = DataLoader(loaders.PairsLoader(os.path.join(opt.data_path, 'train'),
33 | plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res,
34 | avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type),
35 | num_workers=num_workers, batch_size=opt.batch_size)
36 | val_loader = DataLoader(loaders.PairsLoader(os.path.join(opt.data_path, 'val'),
37 | plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res,
38 | shuffle=False, avg_energy_ratio=opt.avg_energy_ratio,
39 | slm_type=opt.slm_type),
40 | num_workers=num_workers, batch_size=opt.batch_size, shuffle=False)
41 | test_loader = DataLoader(loaders.PairsLoader(os.path.join(opt.data_path, 'test'),
42 | plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res,
43 | shuffle=False, avg_energy_ratio=opt.avg_energy_ratio, slm_type=opt.slm_type),
44 | num_workers=num_workers, batch_size=opt.batch_size, shuffle=False)
45 |
46 | # Init model
47 | model = prop_model.model(opt)
48 | model.train()
49 |
50 | # Init root path
51 | root_dir = os.path.join(opt.out_path, run_id)
52 | utils.cond_mkdir(root_dir)
53 | p.write_config_file(opt, [os.path.join(root_dir, 'config.txt')])
54 |
55 | checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="PSNR_validation_epoch", dirpath=root_dir,
56 | filename="model-{epoch:02d}-{PSNR_validation_epoch:.2f}",
57 | save_top_k=1, mode="max", )
58 |
59 | # Init trainer
60 | trainer = Trainer(default_root_dir=root_dir, accelerator='gpu',
61 | log_every_n_steps=400, gpus=1, max_epochs=opt.num_epochs, callbacks=[checkpoint_callback])
62 |
63 | # Fit Model
64 | trainer.fit(model, train_loader, val_loader)
65 | # Test Model
66 | trainer.test(model, dataloaders=test_loader)
67 |
68 |
69 | if __name__ == "__main__":
70 | main()
--------------------------------------------------------------------------------
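After training, the checkpoint with the highest validation PSNR can be evaluated on the test split with the same building blocks. A minimal sketch, assuming the options are parsed exactly as in train.py and the checkpoint file written by the ModelCheckpoint callback is known (the path below is a placeholder):

```
import os
import torch
import configargparse
import params
import prop_model
import image_loader as loaders
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

# parse the same options as train.py (e.g., via the same config file)
p = configargparse.ArgumentParser()
p.add('-c', '--config_filepath', required=False, is_config_file=True)
params.add_parameters(p, 'train')
opt = p.parse_args()
params.set_configs(opt)

# restore the trained weights from the Lightning checkpoint
model = prop_model.model(opt)
ckpt = torch.load('path/to/model-epoch-psnr.ckpt', map_location='cpu')   # placeholder path
model.load_state_dict(ckpt['state_dict'])

test_loader = DataLoader(loaders.PairsLoader(os.path.join(opt.data_path, 'test'),
                                             plane_idxs=opt.plane_idxs['all'], image_res=opt.image_res,
                                             shuffle=False, avg_energy_ratio=opt.avg_energy_ratio,
                                             slm_type=opt.slm_type),
                         batch_size=opt.batch_size, shuffle=False)
Trainer(accelerator='gpu', gpus=1).test(model, dataloaders=test_loader)
```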
/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 |
6 |
7 | def norm_layer(norm_str):
8 | if norm_str.lower() == 'instance':
9 | return nn.InstanceNorm2d
10 | elif norm_str.lower() == 'group':
11 | return nn.GroupNorm
12 | elif norm_str.lower() == 'batch':
13 | return nn.BatchNorm2d
14 |
15 |
16 | class UnetSkipConnectionBlock(nn.Module):
17 | """Defines the Unet submodule with skip connection.
18 | X -------------------identity----------------------
19 | |-- downsampling -- |submodule| -- upsampling --|
20 | """
21 |
22 | def __init__(self, outer_nc, inner_nc, input_nc=None,
23 | submodule=None, outermost=False, innermost=False,
24 | norm_layer=nn.InstanceNorm2d, use_dropout=False,
25 | outer_skip=False):
26 | """Construct a Unet submodule with skip connections.
27 | Parameters:
28 | outer_nc (int) -- the number of filters in the outer conv layer
29 | inner_nc (int) -- the number of filters in the inner conv layer
30 | input_nc (int) -- the number of channels in input images/features
31 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
32 | outermost (bool) -- if this module is the outermost module
33 | innermost (bool) -- if this module is the innermost module
34 | norm_layer -- normalization layer
35 | use_dropout (bool) -- if use dropout layers.
36 | """
37 | super(UnetSkipConnectionBlock, self).__init__()
38 | self.outermost = outermost
39 | self.outer_skip = outer_skip
40 | if norm_layer == None:
41 | use_bias = True
42 | elif type(norm_layer) == functools.partial:
43 | use_bias = norm_layer.func == nn.InstanceNorm2d
44 | else:
45 | use_bias = norm_layer == nn.InstanceNorm2d
46 | if input_nc is None:
47 | input_nc = outer_nc
48 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=5,
49 |                              # kernel size changed from 4 to 5 and padding from 1 to 2
50 | stride=2, padding=2, bias=use_bias)
51 | downrelu = nn.LeakyReLU(0.2, True)
52 | if norm_layer is not None:
53 | if norm_layer == nn.GroupNorm:
54 | downnorm = norm_layer(8, inner_nc)
55 | else:
56 | downnorm = norm_layer(inner_nc)
57 | else:
58 | downnorm = None
59 | uprelu = nn.ReLU(True)
60 | if norm_layer is not None:
61 | if norm_layer == nn.GroupNorm:
62 | upnorm = norm_layer(8, outer_nc)
63 | else:
64 | upnorm = norm_layer(outer_nc)
65 | else:
66 | upnorm = None
67 |
68 | if outermost:
69 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
70 | kernel_size=4, stride=2,
71 | padding=1)
72 | down = [downconv, downrelu]
73 | up = [upconv] # Removed tanh and uprelu
74 | model = down + [submodule] + up
75 | elif innermost:
76 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
77 | kernel_size=4, stride=2,
78 | padding=1, bias=use_bias)
79 | if norm_layer is not None:
80 | down = [downconv, downnorm, downrelu]
81 | up = [upconv, upnorm, uprelu]
82 | else:
83 | down = [downconv, downrelu]
84 | up = [upconv, uprelu]
85 |
86 | model = down + up
87 | else:
88 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
89 | kernel_size=4, stride=2,
90 | padding=1, bias=use_bias)
91 | if norm_layer is not None:
92 | down = [downconv, downnorm, downrelu]
93 | up = [upconv, upnorm, uprelu]
94 | else:
95 | down = [downconv, downrelu]
96 | up = [upconv, uprelu]
97 |
98 | if use_dropout:
99 | model = down + [submodule] + up + [nn.Dropout(0.5)]
100 | else:
101 | model = down + [submodule] + up
102 |
103 | self.model = nn.Sequential(*model)
104 |
105 | def forward(self, x):
106 | if self.outermost and not self.outer_skip:
107 | return self.model(x)
108 | else: # add skip connections
109 | return torch.cat([x, self.model(x)], 1)
110 |
111 |
112 | def init_latent(latent_num, wavefront_res, ones=False):
113 | if latent_num > 0:
114 | if ones:
115 | latent = nn.Parameter(torch.ones(1, latent_num, *wavefront_res,
116 | requires_grad=True))
117 | else:
118 | latent = nn.Parameter(torch.zeros(1, latent_num, *wavefront_res,
119 | requires_grad=True))
120 | else:
121 | latent = None
122 | return latent
123 |
124 |
125 | def apply_net(net, input, latent_code, complex=False):
126 | if net is None:
127 | return input
128 | if complex: # Only valid for single batch or single channel complex inputs and outputs
129 | multi_channel = (input.shape[1] > 1)
130 | if multi_channel:
131 | input = torch.view_as_real(input[0,...])
132 | else:
133 | input = torch.view_as_real(input[:,0,...])
134 | input = input.permute(0,3,1,2)
135 | if latent_code is not None:
136 | input = torch.cat((input, latent_code), dim=1)
137 | output = net(input)
138 | if complex:
139 | if multi_channel:
140 | output = output.permute(0,2,3,1).unsqueeze(0)
141 | else:
142 | output = output.permute(0,2,3,1).unsqueeze(1)
143 | output = torch.complex(output[...,0], output[...,1])
144 | return output
145 |
146 |
147 | def init_weights(net, init_type='normal', init_gain=0.02, outer_skip=False):
148 | """Initialize network weights.
149 | Parameters:
150 | net (network) -- network to be initialized
151 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
152 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
153 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
154 | work better for some applications. Feel free to try yourself.
155 | """
156 |
157 | def init_func(m): # define the initialization function
158 | classname = m.__class__.__name__
159 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
160 | if init_type == 'normal':
161 | init.normal_(m.weight.data, 0.0, init_gain)
162 | elif init_type == 'xavier':
163 | init.xavier_normal_(m.weight.data, gain=init_gain)
164 | elif init_type == 'kaiming':
165 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
166 | elif init_type == 'orthogonal':
167 | init.orthogonal_(m.weight.data, gain=init_gain)
168 | else:
169 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
170 | if hasattr(m, 'bias') and m.bias is not None:
171 | init.constant_(m.bias.data, 0.0)
172 | elif classname.find(
173 | 'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
174 | init.normal_(m.weight.data, 1.0, init_gain)
175 | init.constant_(m.bias.data, 0.0)
176 |
177 | print('initialize network with %s' % init_type)
178 | net.apply(init_func) # apply the initialization function
179 |
180 |
181 | class UnetGenerator(nn.Module):
182 | """Create a Unet-based generator"""
183 |
184 | def __init__(self, input_nc=1, output_nc=1, num_downs=8, nf0=32, max_channels=512,
185 | norm_layer=nn.InstanceNorm2d, use_dropout=False, outer_skip=True,
186 | half_channels=False, eighth_channels=False):
187 | """Construct a Unet generator
188 | Parameters:
189 | input_nc (int) -- the number of channels in input images
190 | output_nc (int) -- the number of channels in output images
191 |             num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
192 |                                 an image of size 128x128 becomes 1x1 at the bottleneck
193 |             nf0 (int)       -- the number of filters in the last conv layer
194 | norm_layer -- normalization layer
195 | We construct the U-Net from the innermost layer to the outermost layer.
196 | It is a recursive process.
197 | """
198 | super(UnetGenerator, self).__init__()
199 | self.outer_skip = outer_skip
200 | self.input_nc = input_nc
201 |
202 | if eighth_channels:
203 | divisor = 8
204 | elif half_channels:
205 | divisor = 2
206 | else:
207 | divisor = 1
208 | # construct unet structure
209 |
210 | assert num_downs >= 2
211 |
212 | # Add the innermost layer
213 | unet_block = UnetSkipConnectionBlock(min(2 ** (num_downs - 1) * nf0, max_channels) // divisor,
214 | min(2 ** (num_downs - 1) * nf0, max_channels) // divisor,
215 | input_nc=None, submodule=None, norm_layer=norm_layer,
216 | innermost=True)
217 |
218 | for i in list(range(1, num_downs - 1))[::-1]:
219 | if i == 1:
220 | norm = None # Praneeth's modification
221 | else:
222 | norm = norm_layer
223 |
224 | unet_block = UnetSkipConnectionBlock(min(2 ** i * nf0, max_channels) // divisor,
225 | min(2 ** (i + 1) * nf0, max_channels) // divisor,
226 | input_nc=None, submodule=unet_block,
227 | norm_layer=norm,
228 | use_dropout=use_dropout)
229 |
230 | # Add the outermost layer
231 | self.model = UnetSkipConnectionBlock(min(nf0, max_channels) // divisor,
232 | min(2 * nf0, max_channels) // divisor,
233 | input_nc=input_nc, submodule=unet_block, outermost=True,
234 | norm_layer=None, outer_skip=self.outer_skip)
235 | if self.outer_skip:
236 | self.additional_conv = nn.Conv2d(input_nc + min(nf0, max_channels) // divisor, output_nc,
237 | kernel_size=4, stride=1, padding=2, bias=True)
238 | else:
239 | self.additional_conv = nn.Conv2d(min(nf0, max_channels) // divisor, output_nc,
240 | kernel_size=4, stride=1, padding=2, bias=True)
241 |
242 | def forward(self, cnn_input):
243 | """Standard forward"""
244 | output = self.model(cnn_input)
245 | output = self.additional_conv(output)
246 | output = output[:,:,:-1,:-1]
247 | return output
248 |
249 |
250 | class Conv2dSame(torch.nn.Module):
251 | '''2D convolution that pads to keep spatial dimensions equal.
252 |     Cannot deal with stride. Only square kernels (scalar kernel_size).
253 | '''
254 |
255 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReflectionPad2d):
256 | '''
257 | :param in_channels: Number of input channels
258 | :param out_channels: Number of output channels
259 |         :param kernel_size: Scalar. Spatial dimension of the kernel (only square kernels supported).
260 | :param bias: Whether or not to use bias.
261 | :param padding_layer: Which padding to use. Default is reflection padding.
262 | '''
263 | super().__init__()
264 | ka = kernel_size // 2
265 | kb = ka - 1 if kernel_size % 2 == 0 else ka
266 | self.net = nn.Sequential(
267 | padding_layer((ka, kb, ka, kb)),
268 | nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1)
269 | )
270 |
271 | self.weight = self.net[1].weight
272 | self.bias = self.net[1].bias
273 |
274 | def forward(self, x):
275 | return self.net(x)
276 |
--------------------------------------------------------------------------------
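A minimal shape check for the networks above, using a small hypothetical configuration (the input's spatial size should be divisible by 2**num_downs):

```
import torch
from unet import UnetGenerator, Conv2dSame

net = UnetGenerator(input_nc=2, output_nc=2, num_downs=4, nf0=16)
x = torch.randn(1, 2, 64, 64)
print(net(x).shape)   # torch.Size([1, 2, 64, 64]); the final crop undoes the extra pixel from additional_conv

conv = Conv2dSame(1, 8, kernel_size=3)
print(conv(torch.randn(1, 1, 32, 32)).shape)   # torch.Size([1, 8, 32, 32]): spatial size preserved
```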
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils
3 | """
4 |
5 | import math
6 | import random
7 | import numpy as np
8 |
9 | import os
10 | import torch
11 | import torch.nn as nn
12 |
13 | from skimage.metrics import peak_signal_noise_ratio as psnr
14 | from skimage.metrics import structural_similarity as ssim
15 | from torchvision.transforms import GaussianBlur
16 |
17 | from skimage.restoration import inpaint
18 | import kornia
19 | import torch.nn.functional as F
20 | import torch.fft as tfft
21 |
22 | def roll_torch(tensor, shift: int, axis: int):
23 | if shift == 0:
24 | return tensor
25 |
26 | if axis < 0:
27 | axis += tensor.dim()
28 |
29 | dim_size = tensor.size(axis)
30 | after_start = dim_size - shift
31 | if shift < 0:
32 | after_start = -shift
33 | shift = dim_size - abs(shift)
34 |
35 | before = tensor.narrow(axis, 0, dim_size - shift)
36 | after = tensor.narrow(axis, after_start, shift)
37 | return torch.cat([after, before], axis)
38 |
39 |
40 | def ifftshift(tensor):
41 | """ifftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2]
42 |
43 | shifts the width and heights
44 | """
45 | size = tensor.size()
46 | tensor_shifted = roll_torch(tensor, -math.floor(size[2] / 2.0), 2)
47 | tensor_shifted = roll_torch(tensor_shifted, -math.floor(size[3] / 2.0), 3)
48 | return tensor_shifted
49 |
50 |
51 | def fftshift(tensor):
52 | """fftshift for tensors of dimensions [minibatch_size, num_channels, height, width, 2]
53 |
54 | shifts the width and heights
55 | """
56 | size = tensor.size()
57 | tensor_shifted = roll_torch(tensor, math.floor(size[2] / 2.0), 2)
58 | tensor_shifted = roll_torch(tensor_shifted, math.floor(size[3] / 2.0), 3)
59 | return tensor_shifted
60 |
61 |
62 | def pad_image(field, target_shape, pytorch=True, stacked_complex=True, padval=0, mode='constant'):
63 | """Pads a 2D complex field up to target_shape in size
64 |
65 | Padding is done such that when used with crop_image(), odd and even dimensions are
66 | handled correctly to properly undo the padding.
67 |
68 | field: the field to be padded. May have as many leading dimensions as necessary
69 | (e.g., batch or channel dimensions)
70 | target_shape: the 2D target output dimensions. If any dimensions are smaller
71 | than field, no padding is applied
72 | pytorch: if True, uses torch functions, if False, uses numpy
73 | stacked_complex: for pytorch=True, indicates that field has a final dimension
74 | representing real and imag
75 | padval: the real number value to pad by
76 | mode: padding mode for numpy or torch
77 | """
78 | if pytorch:
79 | if stacked_complex:
80 | size_diff = np.array(target_shape) - np.array(field.shape[-3:-1])
81 | odd_dim = np.array(field.shape[-3:-1]) % 2
82 | else:
83 | size_diff = np.array(target_shape) - np.array(field.shape[-2:])
84 | odd_dim = np.array(field.shape[-2:]) % 2
85 | else:
86 | size_diff = np.array(target_shape) - np.array(field.shape[-2:])
87 | odd_dim = np.array(field.shape[-2:]) % 2
88 |
89 | # pad the dimensions that need to increase in size
90 | if (size_diff > 0).any():
91 | pad_total = np.maximum(size_diff, 0)
92 | pad_front = (pad_total + odd_dim) // 2
93 | pad_end = (pad_total + 1 - odd_dim) // 2
94 |
95 | if pytorch:
96 | pad_axes = [int(p) # convert from np.int64
97 | for tple in zip(pad_front[::-1], pad_end[::-1])
98 | for p in tple]
99 | if stacked_complex:
100 | return pad_stacked_complex(field, pad_axes, mode=mode, padval=padval)
101 | else:
102 | return nn.functional.pad(field, pad_axes, mode=mode, value=padval)
103 | else:
104 | leading_dims = field.ndim - 2 # only pad the last two dims
105 | if leading_dims > 0:
106 | pad_front = np.concatenate(([0] * leading_dims, pad_front))
107 | pad_end = np.concatenate(([0] * leading_dims, pad_end))
108 | return np.pad(field, tuple(zip(pad_front, pad_end)), mode,
109 | constant_values=padval)
110 | else:
111 | return field
112 |
113 |
114 | def crop_image(field, target_shape, pytorch=True, stacked_complex=True, lf=False):
115 | """Crops a 2D field, see pad_image() for details
116 |
117 | No cropping is done if target_shape is already smaller than field
118 | """
119 | if target_shape is None:
120 | return field
121 |
122 | if lf:
123 | size_diff = np.array(field.shape[-4:-2]) - np.array(target_shape)
124 | odd_dim = np.array(field.shape[-4:-2]) % 2
125 | else:
126 | if pytorch:
127 | if stacked_complex:
128 | size_diff = np.array(field.shape[-3:-1]) - np.array(target_shape)
129 | odd_dim = np.array(field.shape[-3:-1]) % 2
130 | else:
131 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape)
132 | odd_dim = np.array(field.shape[-2:]) % 2
133 | else:
134 | size_diff = np.array(field.shape[-2:]) - np.array(target_shape)
135 | odd_dim = np.array(field.shape[-2:]) % 2
136 |
137 | # crop dimensions that need to decrease in size
138 | if (size_diff > 0).any():
139 | crop_total = np.maximum(size_diff, 0)
140 | crop_front = (crop_total + 1 - odd_dim) // 2
141 | crop_end = (crop_total + odd_dim) // 2
142 |
143 | crop_slices = [slice(int(f), int(-e) if e else None)
144 | for f, e in zip(crop_front, crop_end)]
145 | if lf:
146 | return field[(..., *crop_slices, slice(None), slice(None))]
147 | else:
148 | if pytorch and stacked_complex:
149 | return field[(..., *crop_slices, slice(None))]
150 | else:
151 | return field[(..., *crop_slices)]
152 | else:
153 | return field
154 |
155 |
156 | def cond_mkdir(path):
157 | if not os.path.exists(path):
158 | os.makedirs(path)
159 |
160 |
161 | def phasemap_8bit(phasemap, inverted=True):
162 | """convert a phasemap tensor into a numpy 8bit phasemap that can be directly displayed
163 |
164 | Input
165 | -----
166 | :param phasemap: input phasemap tensor, which is supposed to be in the range of [-pi, pi].
167 | :param inverted: a boolean value that indicates whether the phasemap is inverted.
168 |
169 | Output
170 | ------
171 | :return: output phasemap, with uint8 dtype (in [0, 255])
172 | """
173 |
174 | output_phase = ((phasemap + np.pi) % (2 * np.pi)) / (2 * np.pi)
175 | if inverted:
176 | phase_out_8bit = ((1 - output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits
177 | else:
178 | phase_out_8bit = ((output_phase) * 255).round().cpu().detach().squeeze().numpy().astype(np.uint8) # quantized to 8 bits
179 | return phase_out_8bit
180 |
181 |
182 | def burst_img_processor(img_burst_list):
183 | img_tensor = np.stack(img_burst_list, axis=0)
184 | img_avg = np.mean(img_tensor, axis=0)
185 | return im2float(img_avg) # changed from int8 to float32
186 |
187 |
188 | def im2float(im, dtype=np.float32):
189 | """convert uint16 or uint8 image to float32, with range scaled to 0-1
190 |
191 | :param im: image
192 | :param dtype: default np.float32
193 | :return:
194 | """
195 | if issubclass(im.dtype.type, np.floating):
196 | return im.astype(dtype)
197 | elif issubclass(im.dtype.type, np.integer):
198 | return im / dtype(np.iinfo(im.dtype).max)
199 | else:
200 | raise ValueError(f'Unsupported data type {im.dtype}')
201 |
202 |
203 | def get_psnr_ssim(recon_amp, target_amp, multichannel=False):
204 | """get PSNR and SSIM metrics"""
205 | psnrs, ssims = {}, {}
206 |
207 | # amplitude
208 | psnrs['amp'] = psnr(target_amp, recon_amp)
209 | ssims['amp'] = ssim(target_amp, recon_amp, multichannel=multichannel)
210 |
211 | # linear
212 | target_linear = target_amp**2
213 | recon_linear = recon_amp**2
214 | psnrs['lin'] = psnr(target_linear, recon_linear)
215 | ssims['lin'] = ssim(target_linear, recon_linear, multichannel=multichannel)
216 |
217 | # srgb
218 | target_srgb = srgb_lin2gamma(np.clip(target_linear, 0.0, 1.0))
219 | recon_srgb = srgb_lin2gamma(np.clip(recon_linear, 0.0, 1.0))
220 | psnrs['srgb'] = psnr(target_srgb, recon_srgb)
221 | ssims['srgb'] = ssim(target_srgb, recon_srgb, multichannel=multichannel)
222 |
223 | return psnrs, ssims
224 |
225 |
226 | def make_kernel_gaussian(sigma, kernel_size):
227 |
228 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
229 | x_cord = torch.arange(kernel_size)
230 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size)
231 | y_grid = x_grid.t()
232 | xy_grid = torch.stack([x_grid, y_grid], dim=-1)
233 |
234 | mean = (kernel_size - 1) / 2
235 | variance = sigma**2
236 |
237 | # Calculate the 2-dimensional gaussian kernel which is
238 | # the product of two gaussian distributions for two different
239 | # variables (in this case called x and y)
240 | gaussian_kernel = ((1 / (2 * math.pi * variance))
241 | * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1)
242 | / (2 * variance)))
243 | # Make sure sum of values in gaussian kernel equals 1.
244 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
245 |
246 | # Reshape to 2d depthwise convolutional weight
247 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
248 |
249 | return gaussian_kernel
250 |
251 |
252 | def pad_stacked_complex(field, pad_width, padval=0, mode='constant'):
253 |     if padval == 0:
254 |         pad_width = (0, 0, *pad_width)  # add 0 padding for stacked_complex dimension
255 |         return nn.functional.pad(field, pad_width, mode=mode)
256 | else:
257 | if isinstance(padval, torch.Tensor):
258 | padval = padval.item()
259 |
260 | real, imag = field[..., 0], field[..., 1]
261 | real = nn.functional.pad(real, pad_width, value=padval)
262 | imag = nn.functional.pad(imag, pad_width, value=0)
263 | return torch.stack((real, imag), -1)
264 |
265 |
266 | def srgb_gamma2lin(im_in):
267 | """ converts from sRGB to linear color space """
268 | thresh = 0.04045
269 | if torch.is_tensor(im_in):
270 | low_val = im_in <= thresh
271 | im_out = torch.zeros_like(im_in)
272 | im_out[low_val] = 25 / 323 * im_in[low_val]
273 | im_out[torch.logical_not(low_val)] = ((200 * im_in[torch.logical_not(low_val)] + 11)
274 | / 211) ** (12 / 5)
275 | else:
276 | im_out = np.where(im_in <= thresh, im_in / 12.92, ((im_in + 0.055) / 1.055) ** (12/5))
277 |
278 | return im_out
279 |
280 |
281 | def srgb_lin2gamma(im_in):
282 | """ converts from linear to sRGB color space """
283 | thresh = 0.0031308
284 | im_out = np.where(im_in <= thresh, 12.92 * im_in, 1.055 * (im_in**(1 / 2.4)) - 0.055)
285 | return im_out
286 |
287 |
288 | def decompose_depthmap(depthmap_virtual_D, depth_planes_D):
289 | """ decompose a depthmap image into a set of masks with depth positions (in Diopter) """
290 |
291 | num_planes = len(depth_planes_D)
292 |
293 | masks = torch.zeros(depthmap_virtual_D.shape[0], len(depth_planes_D), *depthmap_virtual_D.shape[-2:],
294 | dtype=torch.float32).to(depthmap_virtual_D.device)
295 | for k in range(len(depth_planes_D) - 1):
296 | depth_l = depth_planes_D[k]
297 | depth_h = depth_planes_D[k + 1]
298 | idxs = (depthmap_virtual_D >= depth_l) & (depthmap_virtual_D < depth_h)
299 | close_idxs = (depth_h - depthmap_virtual_D) > (depthmap_virtual_D - depth_l)
300 |
301 | # closer one
302 | mask = torch.zeros_like(depthmap_virtual_D)
303 | mask += idxs * close_idxs * 1
304 | masks[:, k, ...] += mask.squeeze(1)
305 |
306 | # farther one
307 | mask = torch.zeros_like(depthmap_virtual_D)
308 | mask += idxs * (~close_idxs) * 1
309 | masks[:, k + 1, ...] += mask.squeeze(1)
310 |
311 | # even closer ones
312 | idxs = depthmap_virtual_D >= max(depth_planes_D)
313 | mask = torch.zeros_like(depthmap_virtual_D)
314 | mask += idxs * 1
315 | masks[:, len(depth_planes_D) - 1, ...] += mask.clone().squeeze(1)
316 |
317 | # even farther ones
318 | idxs = depthmap_virtual_D < min(depth_planes_D)
319 | mask = torch.zeros_like(depthmap_virtual_D)
320 | mask += idxs * 1
321 | masks[:, 0, ...] += mask.clone().squeeze(1)
322 |
323 | # sanity check
324 | assert torch.sum(masks).item() == torch.numel(masks) / num_planes
325 |
326 | return masks
327 |
328 |
329 | def prop_dist_to_diopter(prop_dists, focal_distance, prop_dist_inf, from_lens=True):
330 | """
331 | Calculates distance from the user in diopter unit given the propagation distance from the SLM.
332 | :param prop_dists:
333 | :param focal_distance:
334 | :param prop_dist_inf:
335 | :param from_lens:
336 | :return:
337 | """
338 |     x0 = prop_dist_inf # prop distance from SLM that corresponds to optical infinity from the user
339 | f = focal_distance # focal distance of eyepiece
340 |
341 | if from_lens: # distance is from the lens
342 | diopters = [1 / (x0 + f - x) - 1 / f for x in prop_dists] # diopters from the user side
343 | else: # distance is from the user (basically adding focal length)
344 | diopters = [(x - x0) / f**2 for x in prop_dists]
345 |
346 | return diopters
347 |
348 |
349 | class PSNR:
350 | """Peak Signal to Noise Ratio
351 | img1 and img2 have range [0, 255]"""
352 |
353 | def __init__(self):
354 | self.name = "PSNR"
355 |
356 | @staticmethod
357 | def __call__(img1, img2):
358 | mse = torch.mean((img1 - img2) ** 2)
359 | return 20 * torch.log10(255.0 / torch.sqrt(mse))
360 |
361 |
362 | def laplacian(img):
363 |
364 | # signed angular difference
365 | grad_x1, grad_y1 = grad(img, next_pixel=True) # x_{n+1} - x_{n}
366 | grad_x0, grad_y0 = grad(img, next_pixel=False) # x_{n} - x_{n-1}
367 |
368 | laplacian_x = grad_x1 - grad_x0 # (x_{n+1} - x_{n}) - (x_{n} - x_{n-1})
369 | laplacian_y = grad_y1 - grad_y0
370 |
371 | return laplacian_x + laplacian_y
372 |
373 |
374 | def grad(img, next_pixel=False, sovel=False):
375 |
376 | if img.shape[1] > 1:
377 | permuted = True
378 | img = img.permute(1, 0, 2, 3)
379 | else:
380 | permuted = False
381 |
382 | # set diff kernel
383 |     if sovel: # use Sobel filter for gradient calculation
384 | k_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32) / 8
385 | k_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32) / 8
386 | else:
387 | if next_pixel: # x_{n+1} - x_n
388 | k_x = torch.tensor([[0, -1, 1]], dtype=torch.float32)
389 | k_y = torch.tensor([[1], [-1], [0]], dtype=torch.float32)
390 | else: # x_{n} - x_{n-1}
391 | k_x = torch.tensor([[-1, 1, 0]], dtype=torch.float32)
392 | k_y = torch.tensor([[0], [1], [-1]], dtype=torch.float32)
393 |
394 | # upload to gpu
395 | k_x = k_x.to(img.device).unsqueeze(0).unsqueeze(0)
396 | k_y = k_y.to(img.device).unsqueeze(0).unsqueeze(0)
397 |
398 | # boundary handling (replicate elements at boundary)
399 | img_x = F.pad(img, (1, 1, 0, 0), 'replicate')
400 | img_y = F.pad(img, (0, 0, 1, 1), 'replicate')
401 |
402 | # take sign angular difference
403 | grad_x = signed_ang(F.conv2d(img_x, k_x))
404 | grad_y = signed_ang(F.conv2d(img_y, k_y))
405 |
406 | if permuted:
407 | grad_x = grad_x.permute(1, 0, 2, 3)
408 | grad_y = grad_y.permute(1, 0, 2, 3)
409 |
410 | return grad_x, grad_y
411 |
412 |
413 | def signed_ang(angle):
414 | """
415 | cast all angles into [-pi, pi]
416 | """
417 | return (angle + math.pi) % (2*math.pi) - math.pi
418 |
419 |
420 | # Adapted from https://github.com/svaiter/pyprox/blob/master/pyprox/operators.py
421 | def soft_thresholding(x, gamma):
422 | """
423 |     return element-wise shrinkage (soft-thresholding) of x with threshold gamma
424 | """
425 | return torch.maximum(torch.zeros_like(x),
426 | 1 - gamma / torch.maximum(torch.abs(x), 1e-10*torch.ones_like(x))) * x
427 |
428 |
429 | def random_gen(num_planes=7, slm_type='ti'):
430 | """
431 | random hyperparameters for the dataset
432 | """
433 | frame_choices = [1, 2, 3, 3, 4, 4, 8, 8, 8, 8] if slm_type.lower() == 'ti' else [1]
434 |
435 | num_iters = random.choice(range(3000))
436 | phase_range = random.uniform(1.0, 6.28)
437 | target_range = random.uniform(0.5, 1.5)
438 | learning_rate = random.uniform(0.01, 0.03)
439 | plane_idx = random.choice(range(num_planes))
440 |
441 | return num_iters, phase_range, target_range, learning_rate, plane_idx
442 |
--------------------------------------------------------------------------------
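A short sketch of the depth decomposition used to build per-plane target masks, with arbitrary example values (the diopter range and resolution below are hypothetical):

```
import torch
import utils

depth_planes_D = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]            # 7 target planes in diopters (example values)
depthmap_D = torch.rand(1, 1, 384, 384) * 3.0                    # hypothetical depth map, also in diopters
masks = utils.decompose_depthmap(depthmap_D, depth_planes_D)     # (1, 7, 384, 384): each pixel assigned to one plane
print(masks.sum().item() == depthmap_D.numel())                  # True: every pixel lands on exactly one plane
```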