├── 20210403
│   ├── README.md
│   ├── calibration
│   │   ├── calibration.mat
│   │   ├── cam1.mat
│   │   ├── cam2.mat
│   │   ├── cams.mat
│   │   └── checkerboard.png
│   ├── gamma_calibration
│   │   └── gammas.mat
│   └── rotation_calibration
│       └── rotation.mat
├── README.md
├── camera_acquisitions
│   └── images
│       ├── calibration.png
│       └── sinusoids
│           ├── T=100
│           │   ├── 0.png
│           │   ├── 1.png
│           │   ├── 2.png
│           │   ├── 3.png
│           │   ├── 4.png
│           │   ├── 5.png
│           │   ├── 6.png
│           │   └── 7.png
│           ├── T=110
│           │   ├── 0.png
│           │   ├── 1.png
│           │   ├── 2.png
│           │   ├── 3.png
│           │   ├── 4.png
│           │   ├── 5.png
│           │   ├── 6.png
│           │   └── 7.png
│           └── T=70
│               ├── 0.png
│               ├── 1.png
│               ├── 2.png
│               ├── 3.png
│               ├── 4.png
│               ├── 5.png
│               ├── 6.png
│               └── 7.png
├── demo_experiments.py
├── diffmetrology
│   ├── __init__.py
│   ├── basics.py
│   ├── optics.py
│   ├── scene.py
│   ├── shapes.py
│   ├── solvers.py
│   └── utils.py
├── imgs
│   ├── checkerboard.png
│   ├── loss.png
│   ├── results_initial.png
│   ├── results_optimized.png
│   ├── setup.png
│   ├── spot_diagram_initial.png
│   ├── spot_diagram_optimized.png
│   └── teaser.jpg
├── lenses
│   └── ThorLabs
│       └── LE1234-A.txt
└── metrology_calibrate.py
/20210403/README.md:
--------------------------------------------------------------------------------
1 | Camera 1 (GS3-U3-50S5C):
2 |
3 | - f# = 16
4 | - mode 0 (2048x2048)
5 | - mono16
6 |
7 | Camera 2 (GS3-U3-50S5M):
8 |
9 | - f# = 16
10 | - mode 0 (2048x2048)
11 | - mono16
12 |
--------------------------------------------------------------------------------
/20210403/calibration/calibration.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/calibration.mat
--------------------------------------------------------------------------------
/20210403/calibration/cam1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/cam1.mat
--------------------------------------------------------------------------------
/20210403/calibration/cam2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/cam2.mat
--------------------------------------------------------------------------------
/20210403/calibration/cams.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/cams.mat
--------------------------------------------------------------------------------
/20210403/calibration/checkerboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/checkerboard.png
--------------------------------------------------------------------------------
/20210403/gamma_calibration/gammas.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/gamma_calibration/gammas.mat
--------------------------------------------------------------------------------
/20210403/rotation_calibration/rotation.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/rotation_calibration/rotation.mat
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards self-calibrated lens metrology by differentiable refractive deflectometry
2 | This is the PyTorch implementation for our paper "Towards self-calibrated lens metrology by differentiable refractive deflectometry".
3 | ### [Project Page](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/) | [Paper](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/Wang2021DiffDeflectometry.pdf)
4 |
5 | [Towards self-calibrated lens metrology by differentiable refractive deflectometry](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/Wang2021DiffDeflectometry.pdf)
6 | [Congli Wang](https://congliwang.github.io),
7 | [Ni Chen](https://ni-chen.github.io), and
8 | [Wolfgang Heidrich](https://vccimaging.org/People/heidriw)
9 | King Abdullah University of Science and Technology (KAUST)
10 | OSA Optics Express 2021
11 |
12 |
13 | Figure: Dual-camera refractive deflectometry for lens metrology. (a) Hardware setup. (b) Captured phase-shifted images, from which on-screen intersections are obtained. (c) A ray tracer models the setup by ray tracing each parameterized refractive surface, obtaining the modeled intersections. (d) Unknown parameters and pose are jointly optimized by minimizing the error between the measured and the modeled intersections.
14 |
15 | ## Features
16 |
17 | This repository implements a PyTorch differentiable ray tracer for deflectometry. The solver enables:
18 | - Fast (a few seconds on a GPU), simultaneous estimation of lens parameters and pose, with gradients computed by the differentiable ray tracer.
19 | - Rendering of photo-realistic images from a CMM-free, computationally calibrated metrology setup.
20 | - A fringe analyzer to solve for displacements from phase-shifting patterns.
21 |
22 | ## Quick Start
23 |
24 | ### Real Experiment Example
25 |
26 | [`demo_experiments.py`](./demo_experiments.py) reproduces one of the paper's experimental results. Prerequisites:
27 |
28 | - Download the raw image data (`*.npz`) from Google Drive [here](https://drive.google.com/file/d/15a3T0wL7sWDaEXsAeZM0S7YXRS-fXRNO/view?usp=sharing).
29 | - Put the downloaded `*.npz` into the directory `./20210403/measurement/LE1234-A`.
30 |
31 | Then run [`demo_experiments.py`](./demo_experiments.py). The script should output the following figures:
32 |
33 | |  |  |
34 | | :---------------------------------: | :-------------------------------------------: |
35 | | The physical setup for experiments. | Optimization loss with respect to iterations. |
36 |
37 | |  |
38 | | :---------------------------------------: |
39 | | Spot diagrams on the display (initial). |
40 | |  |
41 | | Spot diagrams on the display (optimized). |
42 |
43 | |  |
44 | | :-------------------------------------------------------: |
45 | | Measurement images / modeled images / error. (initial) |
46 | |  |
47 | | Measurement images / modeled images / error. (optimized) |
48 |
49 | ## Citation
50 | ```bibtex
51 | @article{wang2021towards,
52 | title={Towards self-calibrated lens metrology by differentiable refractive deflectometry},
53 | author={Wang, Congli and Chen, Ni and Heidrich, Wolfgang},
54 | journal={Optics Express},
55 | volume={29},
56 | number={19},
57 | pages={30284--30295},
58 | year={2021},
59 | publisher={Optical Society of America}
60 | }
61 | ```
62 |
63 | ## Contact
64 | For questions, please open an issue or contact Congli Wang.
65 |
66 |
--------------------------------------------------------------------------------
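A minimal sketch of the data check implied by the Quick Start section of the README above: once the Google Drive archive has been placed in `./20210403/measurement/LE1234-A/`, the arrays can be inspected before running the demo. The file name `data_new.npz` and the keys `imgs`/`refs` are the ones used by `demo_experiments.py`; the printed shapes simply reflect whatever the download contains.

```python
import numpy as np

# Path and file name as used in demo_experiments.py
data = np.load('./20210403/measurement/LE1234-A/data_new.npz')

# Phase-shifted measurement captures and the corresponding reference captures
imgs, refs = data['imgs'], data['refs']
print('imgs:', imgs.shape, imgs.dtype)
print('refs:', refs.shape, refs.dtype)
```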
/camera_acquisitions/images/calibration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/calibration.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/0.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/1.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/2.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/3.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/4.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/5.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/6.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/7.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/0.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/1.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/2.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/3.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/4.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/5.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/6.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=110/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/7.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/0.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/1.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/2.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/3.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/4.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/5.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/6.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=70/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/7.png
--------------------------------------------------------------------------------
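The three `sinusoids/T=...` folders above each hold eight images `0.png` through `7.png`; `demo_experiments.py` below treats `T ∈ {70, 100, 110}` as the pixel period of the displayed fringes and averages the color channels to grayscale. Below is a hypothetical generator for such patterns; the display resolution, fringe orientation, and phase-step convention (e.g. eight equally spaced steps versus four steps in each of two orientations) are assumptions, not something the file listing determines.

```python
import numpy as np
from matplotlib.image import imsave

def sinusoid_pattern(shape, T, phase, axis=0):
    """Unit-range fringe pattern with pixel period T along `axis`."""
    u = np.arange(shape[axis])
    fringe = 0.5 + 0.5 * np.cos(2 * np.pi * u / T + phase)
    fringe = fringe[:, None] if axis == 0 else fringe[None, :]
    return np.broadcast_to(fringe, shape)

# e.g. eight equally spaced phase steps at period T = 100 px (resolution is a placeholder)
T, N = 100, 8
for k in range(N):
    img = sinusoid_pattern((1080, 1920), T, 2 * np.pi * k / N)
    imsave(f'{k}.png', img, cmap='gray', vmin=0.0, vmax=1.0)
```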
/demo_experiments.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | import diffmetrology as dm
5 | from matplotlib.image import imread
6 |
7 | # load setup information
8 | data_path = './20210403'
9 | device = dm.init()
10 | # device = torch.device('cpu')
11 |
12 | print("Initialize a DiffMetrology object.")
13 | origin_shift = np.array([0.0, 0.0, 0.0])
14 | DM = dm.DiffMetrology(
15 | calibration_path = data_path + '/calibration/',
16 | rotation_path = data_path + '/rotation_calibration/rotation.mat',
17 | lut_path = data_path + '/gamma_calibration/gammas.mat',
18 | origin_shift = origin_shift,
19 | scale=1.0,
20 | device=device
21 | )
22 |
23 | print("Crop the region of interst in the original images.")
24 | filmsize = np.array([768, 768])
25 | # filmsize = np.array([2048, 2048])
26 | crop_offset = ((2048 - filmsize)/2).astype(int)
27 | for cam in DM.scene.cameras:
28 | cam.filmsize = filmsize
29 | cam.crop_offset = torch.Tensor(crop_offset).to(device)
30 | def crop(x):
31 | return x[..., crop_offset[0]:crop_offset[0]+filmsize[0], crop_offset[1]:crop_offset[1]+filmsize[1]]
32 |
33 | DM.test_setup()
34 |
35 | # ==== Read measurements
36 | lens_name = 'LE1234-A'
37 |
38 | DM.scene.lensgroup.load_file('Thorlabs/' + lens_name + '.txt')
39 |
40 | def show_parameters():
41 | for i in range(len(DM.scene.lensgroup.surfaces)):
42 | print(f"Lens radius of curvature at surface[{i}]: {1.0/DM.scene.lensgroup.surfaces[i].c.item()}")
43 | print(DM.scene.lensgroup.surfaces[1].d)
44 |
45 |
46 | print("Ground Truth Lens Parameters:")
47 | show_parameters()
48 |
49 |
50 | angle = 0.0
51 | Ts = np.array([70, 100, 110]) # period of the sinusoids
52 | t = 0
53 |
54 | # load data
55 | option = 'experiment'
56 | if option == 'experiment':
57 | data = np.load(data_path + '/measurement/' + lens_name + '/data_new.npz')
58 | imgs = data['imgs']
59 | refs = data['refs']
60 | imgs = crop(imgs)
61 | refs = crop(refs)
62 | del data
63 |
64 |
65 | # solve for ps and valid map
66 | ps_cap, valid_cap, C = DM.solve_for_intersections(imgs, refs, Ts[t:])
67 |
68 | # set display pattern
69 | # xs = [0, 4]
70 | xs = [0]
71 | sinusoid_path = './camera_acquisitions/images/sinusoids/T=' + str(Ts[t])
72 | ims = [ np.mean(imread(sinusoid_path + '/' + str(x) + '.png'), axis=-1) for x in xs ] # use grayscale
73 | ims = np.array([ im/im.max() for im in ims ])
74 | ims = np.sum(ims, axis=0)
75 | DM.set_texture(ims)
76 | del ims
77 | if option == 'experiment':
78 | # Obtained from running `metrology_calibrate.py`
79 | # DM.scene.screen.texture_shift = torch.Tensor([1.7445182, 1.1107264]).to(device) # LE1234-A
80 | DM.scene.screen.texture_shift = torch.Tensor([0. , 1.1106231]).to(device) # LE1234-A
81 |
82 |
83 |
84 | print("Shift `origin` by an estimated value")
85 | origin = DM._compute_mount_geometry(C, verbose=True)
86 | DM.scene.lensgroup.origin = torch.Tensor(origin).to(device)
87 | DM.scene.lensgroup.update()
88 | print(origin)
89 |
90 |
91 | print("Load real images")
92 | FR = dm.Fringe()
93 | a_cap, b_cap, psi_cap = FR.solve(imgs)
94 | imgs_sub = np.array([imgs[0,x,...] for x in xs])
95 | imgs_sub = imgs_sub - a_cap[:,0,...]
96 | imgs_sub = np.sum(imgs_sub, axis=0)
97 | imgs_sub = valid_cap * torch.Tensor(imgs_sub).to(device)
98 | I0 = valid_cap * len(xs) * (imgs_sub - imgs_sub.min().item()) / (imgs_sub.max().item() - imgs_sub.min().item())
99 |
100 |
101 | # Utility functions
102 | def forward():
103 | ps = torch.stack(DM.trace(with_element=True, mask=valid_cap, angles=angle)[0])[..., 0:2]
104 | return ps
105 |
106 | def render():
107 | I = valid_cap*torch.stack(DM.render(with_element=True, angles=angle))
108 | I[torch.isnan(I)] = 0.0
109 | return I
110 |
111 | def visualize(ps_current, save_string):
112 | print("Showing spot diagrams at display.")
113 | DM.spot_diagram(ps_cap, ps_current, valid=valid_cap, angle=angle, with_grid=False)
114 | plt.show()
115 |
116 | print("Showing images (measurement & modeled & |measurement - modeled|).")
117 |
118 | # Render images from parameters
119 | I = render()
120 |
121 | fig, axes = plt.subplots(2, 3)
122 | for i in range(2):
123 | im = axes[i,0].imshow(I0[i].cpu(), vmin=0, vmax=1, cmap='gray')
124 | axes[i,0].set_title(f"Camera {i+1}\nMeasurement")
125 | axes[i,0].set_xlabel('[pixel]')
126 | axes[i,0].set_ylabel('[pixel]')
127 | plt.colorbar(im, ax=axes[i,0])
128 |
129 | im = axes[i,1].imshow(I[i].cpu().detach(), vmin=0, vmax=1, cmap='gray')
130 | plt.colorbar(im, ax=axes[i,1])
131 | axes[i,1].set_title(f"Camera {i+1}\nModeled")
132 | axes[i,1].set_xlabel('[pixel]')
133 | axes[i,1].set_ylabel('[pixel]')
134 |
135 | im = axes[i,2].imshow(I0[i].cpu() - I[i].cpu().detach(), vmin=-1, vmax=1, cmap='coolwarm')
136 | plt.colorbar(im, ax=axes[i,2])
137 | axes[i,2].set_title(f"Camera {i+1}\nError")
138 | axes[i,2].set_xlabel('[pixel]')
139 | axes[i,2].set_ylabel('[pixel]')
140 |
141 | fig.suptitle(save_string)
142 | fig.savefig(save_string + str(i) + ".jpg", bbox_inches='tight')
143 | plt.show()
144 |
145 |
146 | print("Initialize lens parameters.")
147 | DM.scene.lensgroup.surfaces[0].c = torch.Tensor([0.00]).to(device) # 1st surface curvature
148 | DM.scene.lensgroup.surfaces[1].c = torch.Tensor([0.00]).to(device) # 2nd surface curvature
149 | DM.scene.lensgroup.surfaces[1].d = torch.Tensor([3.00]).to(device) # lens thickness
150 | DM.scene.lensgroup.theta_x = torch.Tensor([0.00]).to(device) # lens X-tilt angle
151 | DM.scene.lensgroup.theta_y = torch.Tensor([0.00]).to(device) # lens Y-tilt angle
152 | DM.scene.lensgroup.update()
153 |
154 | print("Visualize initial status.")
155 | ps_current = forward()
156 | visualize(ps_current, save_string="initial")
157 |
158 |
159 | print("Set optimization parameters.")
160 | diff_names = [
161 | 'lensgroup.surfaces[0].c',
162 | 'lensgroup.surfaces[1].c',
163 | 'lensgroup.surfaces[1].d',
164 | 'lensgroup.origin',
165 | 'lensgroup.theta_x',
166 | 'lensgroup.theta_y'
167 | ]
168 | def loss(ps):
169 | return torch.sum((ps[valid_cap,...] - ps_cap[valid_cap,...])**2, axis=-1).mean()
170 |
171 | def func_yref_y(ps):
172 | b = valid_cap[...,None] * (ps_cap - ps)
173 | b[torch.isnan(b)] = 0.0 # handle NaN ... otherwise LM won't work!
174 | return b
175 |
176 | # Optimize
177 | ls = DM.solve(diff_names, forward, loss, func_yref_y, option='LM', R='I')
178 | print("Done. Show results (Spot RMS loss):")
179 | show_parameters()
180 |
181 | plt.figure()
182 | plt.semilogy(ls, '-o', color='k')
183 | plt.xlabel('LM iteration')
184 | plt.ylabel('Loss')
185 | plt.title("Opitmization Loss")
186 |
187 | print("Visualize optimized status.")
188 | ps_current = forward()
189 | visualize(ps_current, save_string="optimized")
190 |
191 | # Print mean displacement error
192 | T = ps_current - ps_cap
193 | E = torch.sqrt(torch.sum(T[valid_cap, ...]**2, axis=-1)).mean()
194 | print("error = {} [um]".format(E*1e3))
195 |
196 |
--------------------------------------------------------------------------------
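A small worked example of the ROI crop set up near the top of `demo_experiments.py` above: with `filmsize = [768, 768]` and the cameras' native 2048x2048 frames (see `20210403/README.md`), the offset evaluates to 640 px on each side, so the crop keeps rows and columns 640 through 1407. The dummy array shape used here is only an illustration.

```python
import numpy as np

filmsize = np.array([768, 768])
crop_offset = ((2048 - filmsize) / 2).astype(int)   # -> array([640, 640])

def crop(x):
    # same slicing as in demo_experiments.py
    return x[..., crop_offset[0]:crop_offset[0] + filmsize[0],
                  crop_offset[1]:crop_offset[1] + filmsize[1]]

dummy = np.zeros((2, 8, 2048, 2048))   # (camera, pattern, H, W): an assumed layout
print(crop_offset, crop(dummy).shape)  # [640 640] (2, 8, 768, 768)
```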
/diffmetrology/__init__.py:
--------------------------------------------------------------------------------
1 | # image formation model
2 | from .basics import *
3 | from .shapes import *
4 | from .optics import *
5 | from .scene import *
6 |
7 | # algorithmic solvers
8 | from .solvers import *
9 |
10 | # utilities
11 | from .utils import *
12 |
13 |
--------------------------------------------------------------------------------
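Since `__init__.py` re-exports every submodule at the package level, a session only needs the top-level import, exactly as `demo_experiments.py` does. A tiny sketch (the values follow directly from the material table in `basics.py` below):

```python
import torch
import diffmetrology as dm

device = dm.init()                         # "cuda:0" if available, otherwise CPU
print(isinstance(device, torch.device))    # True

# symbols from basics.py are available at package level
print(dm.Material('sf6').ior(589.3))       # ~1.80518 (nd reproduced by construction)
```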
/diffmetrology/basics.py:
--------------------------------------------------------------------------------
1 | import math
2 | from enum import Enum
3 | import torch
4 | import numpy as np
5 |
6 | # ----------------------------------------------------------------------------------------
7 |
8 | class PrettyPrinter():
9 | def __str__(self):
10 | lines = [self.__class__.__name__ + ':']
11 | for key, val in vars(self).items():
12 | if val.__class__.__name__ in ('list', 'tuple'):
13 | for i, v in enumerate(val):
14 | lines += '{}[{}]: {}'.format(key, i, v).split('\n')
15 |
16 |             elif val.__class__.__name__ == 'dict':
17 | pass # ignore outputs for Hash tables
18 | elif key == key.upper() and len(key) > 5:
19 | pass # ignore all upper-case variables > 5 chars (constants)
20 | else:
21 | lines += '{}: {}'.format(key, val).split('\n')
22 | return '\n '.join(lines)
23 |
24 | def to(self, device=torch.device('cpu')):
25 | for key, val in vars(self).items():
26 | if torch.is_tensor(val):
27 | exec('self.{x} = self.{x}.to(device)'.format(x=key))
28 | elif issubclass(type(val), PrettyPrinter):
29 | exec(f'self.{key}.to(device)')
30 | elif val.__class__.__name__ in ('list', 'tuple'):
31 | for i, v in enumerate(val):
32 | if torch.is_tensor(v):
33 | exec('self.{x}[{i}] = self.{x}[{i}].to(device)'.format(x=key, i=i))
34 | elif issubclass(type(v), PrettyPrinter):
35 | exec('self.{}[{}].to(device)'.format(key, i))
36 |
37 |
38 | # ----------------------------------------------------------------------------------------
39 |
40 | class Ray(PrettyPrinter):
41 | def __init__(self, o, d, wavelength, device=torch.device('cpu')):
42 | self.o = o # ray origin
43 | self.d = d # ray direction (normalized)
44 |
45 | # scalar-version
46 | self.wavelength = wavelength # [nm]
47 | self.mint = 1e-5 # [mm]
48 | self.maxt = 1e5 # [mm]
49 | self.to(device)
50 |
51 | def __call__(self, t):
52 | return self.o + t[..., None] * self.d
53 |
54 | class Transformation(PrettyPrinter):
55 | def __init__(self, R, t):
56 | if torch.is_tensor(R):
57 | self.R = R
58 | else:
59 | self.R = torch.Tensor(R)
60 | if torch.is_tensor(t):
61 | self.t = t
62 | else:
63 | self.t = torch.Tensor(t)
64 |
65 | def transform_point(self, o):
66 | return torch.squeeze(self.R @ o[..., None]) + self.t
67 |
68 | def transform_vector(self, d):
69 | return torch.squeeze(self.R @ d[..., None])
70 |
71 | def transform_ray(self, ray):
72 | o = self.transform_point(ray.o)
73 | d = self.transform_vector(ray.d)
74 | if o.is_cuda:
75 | return Ray(o, d, ray.wavelength, device=torch.device('cuda'))
76 | else:
77 | return Ray(o, d, ray.wavelength)
78 |
79 | def inverse(self):
80 | RT = self.R.T
81 | t = self.t
82 | return Transformation(RT, -RT @ t)
83 |
84 | class Sampler(PrettyPrinter):
85 | def __init__(self):
86 | self.to()
87 | self.pi_over_2 = np.pi / 2
88 | self.pi_over_4 = np.pi / 4
89 |
90 | def concentric_sample_disk(self, x, y):
91 | # https://pbr-book.org/3ed-2018/Monte_Carlo_Integration/2D_Sampling_with_Multidimensional_Transformations
92 |
93 | # map uniform random numbers to [-1,1]^2
94 | x = 2 * x - 1
95 | y = 2 * y - 1
96 |
97 | # handle degeneracy at the origin when xy == [0,0]
98 |
99 | # apply concentric mapping to point
100 | cond = np.abs(x) > np.abs(y)
101 | r = np.where(cond, x, y)
102 | theta = np.where(cond,
103 | self.pi_over_4 * (y / (x + np.finfo(float).eps)),
104 | self.pi_over_2 - self.pi_over_4 * (x / (y + np.finfo(float).eps))
105 | )
106 |
107 | return r * np.cos(theta), r * np.sin(theta)
108 |
109 | # ----------------------------------------------------------------------------------------
110 | class Filter(PrettyPrinter):
111 | def __init__(self, radius):
112 | self.radius = radius
113 | def eval(self, p):
114 | raise NotImplementedError()
115 |
116 | class Box(Filter):
117 | def __init__(self, radius=None):
118 | if radius is None:
119 | radius = [0.5, 0.5]
120 | Filter.__init__(self, radius)
121 | def eval(self, x):
122 | return torch.ones_like(x)
123 |
124 | class Triangle(Filter):
125 | def __init__(self, radius):
126 | if radius is None:
127 | radius = [2.0, 2.0]
128 | Filter.__init__(self, radius)
129 | def eval(self, p):
130 | x, y = p[...,0], p[...,1]
131 | return (torch.maximum(torch.zeros_like(x), self.radius[0] - x) *
132 | torch.maximum(torch.zeros_like(y), self.radius[1] - y))
133 |
134 | # ----------------------------------------------------------------------------------------
135 |
136 | class Material(PrettyPrinter):
137 | def __init__(self, name=None):
138 | self.name = 'vacuum' if name is None else name.lower()
139 | self.MATERIAL_TABLE = { # [nD, Abbe number]
140 | "vacuum": [1., math.inf],
141 | "air": [1.000293, math.inf],
142 | "occulder": [1., math.inf],
143 | "f2": [1.620, 36.37],
144 | "f15": [1.60570, 37.831],
145 | "uvfs": [1.458, 67.82],
146 |
147 | # https://shop.schott.com/advanced_optics/
148 | "bk10": [1.49780, 66.954],
149 | "n-baf10": [1.67003, 47.11],
150 | "n-bk7": [1.51680, 64.17],
151 | "n-sf1": [1.71736, 29.62],
152 | "n-sf2": [1.64769, 33.82],
153 | "n-sf4": [1.75513, 27.38],
154 | "n-sf5": [1.67271, 32.25],
155 | "n-sf6": [1.80518, 25.36],
156 | "n-sf6ht": [1.80518, 25.36],
157 | "n-sf8": [1.68894, 31.31],
158 | "n-sf10": [1.72828, 28.53],
159 | "n-sf11": [1.78472, 25.68],
160 | "sf1": [1.71736, 29.51],
161 | "sf2": [1.64769, 33.85],
162 | "sf4": [1.75520, 27.58],
163 | "sf5": [1.67270, 32.21],
164 | "sf6": [1.80518, 25.43],
165 | "sf18": [1.72150, 29.245],
166 |
167 | # HIKARI.AGF
168 | "baf10": [1.67, 47.05],
169 |
170 | # SUMITA.AGF
171 | "sk16": [1.62040, 60.306],
172 | "sk1": [1.61030, 56.712],
173 | "ssk4": [1.61770, 55.116],
174 |
175 | # https://www.pgo-online.com/intl/B270.html
176 | "b270": [1.52290, 58.50],
177 |
178 | # https://refractiveindex.info, nd at 589.3 [nm]
179 | "s-nph1": [1.8078, 22.76],
180 | "d-k59": [1.5175, 63.50],
181 |
182 | "flint": [1.6200, 36.37],
183 | "pmma": [1.491756, 58.00],
184 | "polycarb": [1.585470, 30.00]
185 | }
186 | self.A, self.B = self._lookup_material()
187 |
188 | def ior(self, wavelength):
189 | """Computes index of refraction of a given wavelength (in [nm])"""
190 | return self.A + self.B / wavelength**2
191 |
192 | @staticmethod
193 | def nV_to_AB(n, V):
194 | def ivs(a): return 1./a**2
195 | lambdas = [656.3, 589.3, 486.1]
196 | B = (n - 1) / V / ( ivs(lambdas[2]) - ivs(lambdas[0]) )
197 | A = n - B * ivs(lambdas[1])
198 | return A, B
199 |
200 | def _lookup_material(self):
201 | out = self.MATERIAL_TABLE.get(self.name)
202 | if isinstance(out, list):
203 | n, V = out
204 | elif out is None:
205 | # try parsing input as a n/V pair
206 | tmp = self.name.split('/')
207 | n, V = float(tmp[0]), float(tmp[1])
208 | return self.nV_to_AB(n, V)
209 |
210 |
211 | class InterpolationMode(Enum):
212 | nearest = 1
213 | linear = 2
214 |
215 | class BoundaryMode(Enum):
216 | zero = 1
217 | replicate = 2
218 | symmetric = 3
219 | periodic = 4
220 |
221 | class SimulationMode(Enum):
222 | render = 1
223 | trace = 2
224 |
225 | # ----------------------------------------------------------------------------------------
226 |
227 | def init():
228 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
229 | print("DiffMetrology is using: {}".format(device))
230 | torch.set_default_tensor_type('torch.FloatTensor')
231 | return device
232 |
233 | def normalize(d):
234 | return d / torch.sqrt(torch.sum(d**2, axis=-1))[..., None]
235 |
236 | def set_zeros(x, valid=None):
237 |     if valid is None:
238 | return torch.where(torch.isnan(x), torch.zeros_like(x), x)
239 | else:
240 | mask = valid[...,None] if len(x.shape) > len(valid.shape) else valid
241 | return torch.where(~mask, torch.zeros_like(x), x)
242 |
243 | def rodrigues_rotation_matrix(k, theta): # theta: [rad]
244 | # cross-product matrix
245 | kx, ky, kz = k[0], k[1], k[2]
246 | K = torch.Tensor([
247 | [ 0, -kz, ky],
248 | [ kz, 0, -kx],
249 | [-ky, kx, 0]
250 | ]).to(k.device)
251 | if not torch.is_tensor(theta):
252 | theta = torch.Tensor(np.asarray(theta)).to(k.device)
253 | return torch.eye(3, device=k.device) + torch.sin(theta) * K + (1 - torch.cos(theta)) * K @ K
254 |
255 | def set_axes_equal(ax, scale=np.ones(3)):
256 | """
257 | Make axes of 3D plot have equal scale (or scaled by `scale`).
258 | """
259 | limits = np.array([
260 | ax.get_xlim3d(),
261 | ax.get_ylim3d(),
262 | ax.get_zlim3d()
263 | ])
264 | tmp = np.abs(limits[:,1]-limits[:,0])
265 | ax.set_box_aspect(scale * tmp/np.min(tmp))
266 |
267 | # ----------------------------------------------------------------------------------------
268 |
269 | def generate_test_rays():
270 | filmsize = np.array([4, 2])
271 |
272 | o = np.array([3,4,-200])
273 | o = np.tile(o[None, None, ...], [*filmsize, 1])
274 | o = torch.Tensor(o)
275 |
276 | # d = np.array([0.02, -0.03, 1])
277 | # d = np.tile(d[None, None, ...], [*filmsize, 1])
278 | dx = 0.1 * torch.rand(*filmsize)
279 | dy = 0.1 * torch.rand(*filmsize)
280 | d = normalize(torch.stack((dx, dy, torch.ones_like(dx)), axis=-1))
281 |
282 | wavelength = 500 # [nm]
283 | return Ray(o, d, wavelength)
284 |
285 | def generate_test_transformation():
286 | k = np.random.rand(3)
287 | k = k / np.sqrt(np.sum(k**2))
288 | theta = 1 # [rad]
289 | R = rodrigues_rotation_matrix(k, theta)
290 | t = np.random.rand(3)
291 | return Transformation(R, t)
292 |
293 | def generate_test_material():
294 | return Material('N-BK7')
295 |
296 | # ----------------------------------------------------------------------------------------
297 |
298 |
299 | if __name__ == "__main__":
300 | init()
301 |
302 | rays = generate_test_rays()
303 | to_world = generate_test_transformation()
304 | # print(to_world)
305 |
306 | rays_new = to_world.transform_ray(rays)
307 | o_old = rays.o[2,1,...].numpy()
308 | o_new = rays_new.o[2,1,...].numpy()
309 | assert np.sum(np.abs(to_world.R.numpy() @ o_old + to_world.t.numpy() - o_new)) < 1e-15
310 |
311 | material = generate_test_material()
312 | # print(material)
313 |
--------------------------------------------------------------------------------
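Two quick checks of the building blocks defined in `basics.py` above, written as a hedged sketch rather than part of the original test code: the dispersion model reproduces the catalog nd at 589.3 nm by construction, and a rigid `Transformation` composed with its `inverse()` is the identity up to floating-point error.

```python
import numpy as np
import torch
from diffmetrology.basics import Material, Transformation, rodrigues_rotation_matrix

# Cauchy-like model n(lambda) = A + B / lambda^2, fitted from (nd, Vd)
bk7 = Material('n-bk7')
print(float(bk7.ior(589.3)))                          # 1.5168 (the catalog nd)
print(float(bk7.ior(486.1)) > float(bk7.ior(656.3)))  # normal dispersion: True

# Rotation about z by 10 degrees plus a translation; inverse() undoes it
R = rodrigues_rotation_matrix(torch.Tensor([0., 0., 1.]), np.deg2rad(10.0))
to_world = Transformation(R, torch.Tensor([1., 2., 3.]))
p = torch.rand(3)
p_back = to_world.inverse().transform_point(to_world.transform_point(p))
print(torch.allclose(p, p_back, atol=1e-6))           # True
```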
/diffmetrology/optics.py:
--------------------------------------------------------------------------------
1 | from .basics import *
2 | from .shapes import *
3 | from scipy.interpolate import LSQBivariateSpline
4 | from datetime import datetime
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | class Step(torch.autograd.Function):
9 | @staticmethod
10 | def forward(ctx, input, eps):
11 | ctx.constant = eps
12 | ctx.save_for_backward(input)
13 | return (input > 0).float()
14 |
15 | @staticmethod
16 | def backward(ctx, grad_output):
17 | input, = ctx.saved_tensors
18 | return grad_output * torch.exp(-(ctx.constant*input)**2), None
19 |
20 | def ind(x, eps=0.5):
21 | return Step.apply(x, eps)
22 |
23 |
24 | class Lensgroup(Endpoint):
25 | """
26 | The Lensgroup (consisted of multiple optical surfaces) is mounted on a rod, whose
27 | origin is `origin`. The Lensgroup has full degree-of-freedom to rotate around the
28 | x/y axes, with the rotation angles defined as `theta_x`, `theta_y`, and `theta_z` (in degree).
29 |
30 | In Lensgroup's coordinate (i.e. object frame coordinate), surfaces are allocated
31 | starting from `z = 0`. There is an additional, comparatively small 3D origin shift
32 | (`shift`) between the surface center (0,0,0) and the origin of the mount, i.e.
33 | shift + origin = lensgroup_origin.
34 |
35 | There are two configurations of ray tracing: forward and backward. In forward mode,
36 | rays start from `d = 0` surface and propagate along the +z axis; In backward mode,
37 | rays start from `d = d_max` surface and propagate along the -z axis.
38 | """
39 | def __init__(self, origin, shift, theta_x=0., theta_y=0., theta_z=0., device=torch.device('cpu')):
40 | self.origin = torch.Tensor(origin).to(device)
41 | self.shift = torch.Tensor(shift).to(device)
42 | self.theta_x = torch.Tensor(np.asarray(theta_x)).to(device)
43 | self.theta_y = torch.Tensor(np.asarray(theta_y)).to(device)
44 | self.theta_z = torch.Tensor(np.asarray(theta_z)).to(device)
45 | self.device = device
46 |
47 | Endpoint.__init__(self, self._compute_transformation(), device)
48 |
49 | # TODO: in case you would like to render something ...
50 | self.mts_prepared = False
51 |
52 | def load_file(self, filename):
53 | LENSPATH = './lenses/'
54 | filename = filename if filename[0] == '.' else LENSPATH + filename
55 | self.surfaces, self.materials, self.r_last, d_last = self.read_lensfile(filename)
56 | self.d_sensor = d_last + self.surfaces[-1].d
57 | self._sync()
58 |
59 | def load(self, surfaces, materials):
60 | self.surfaces = surfaces
61 | self.materials = materials
62 | self._sync()
63 |
64 | def _sync(self):
65 | for i in range(len(self.surfaces)):
66 | self.surfaces[i].to(self.device)
67 |
68 | def update(self, _x=0.0, _y=0.0):
69 | self.to_world = self._compute_transformation(_x, _y)
70 | self.to_object = self.to_world.inverse()
71 |
72 | def _compute_transformation(self, _x=0.0, _y=0.0, _z=0.0):
73 | # we compute to_world transformation given the input positional parameters (angles)
74 | R = ( rodrigues_rotation_matrix(torch.Tensor([1, 0, 0]).to(self.device), torch.deg2rad(self.theta_x+_x)) @
75 | rodrigues_rotation_matrix(torch.Tensor([0, 1, 0]).to(self.device), torch.deg2rad(self.theta_y+_y)) @
76 | rodrigues_rotation_matrix(torch.Tensor([0, 0, 1]).to(self.device), torch.deg2rad(self.theta_z+_z)) )
77 | t = self.origin + R @ self.shift
78 | return Transformation(R, t)
79 |
80 | @staticmethod
81 | def read_lensfile(filename):
82 | surfaces = []
83 | materials = []
84 | ds = [] # no use for now
85 | with open(filename) as file:
86 | line_no = 0
87 | d_total = 0.
88 | for line in file:
89 | if line_no < 2: # first two lines are comments; ignore them
90 | line_no += 1
91 | else:
92 | ls = line.split()
93 | surface_type, d, r = ls[0], float(ls[1]), float(ls[3])/2
94 | roc = float(ls[2])
95 | if roc != 0: roc = 1/roc
96 | materials.append(Material(ls[4]))
97 |
98 | d_total += d
99 | ds.append(d)
100 |
101 | if surface_type == 'O': # object
102 | d_total = 0.
103 | ds.pop()
104 | elif surface_type == 'X': # XY-polynomial
105 | del roc
106 | ai = []
107 | for ac in range(5, len(ls)):
108 | if ac == 5:
109 | b = float(ls[5])
110 | else:
111 | ai.append(float(ls[ac]))
112 | surfaces.append(XYPolynomial(r, d_total, J=2, ai=ai, b=b))
113 | elif surface_type == 'B': # B-spline
114 | del roc
115 | ai = []
116 | for ac in range(5, len(ls)):
117 | if ac == 5:
118 | nx = int(ls[5])
119 | elif ac == 6:
120 | ny = int(ls[6])
121 | else:
122 | ai.append(float(ls[ac]))
123 | tx = ai[:nx+8]
124 | ai = ai[nx+8:]
125 | ty = ai[:ny+8]
126 | ai = ai[ny+8:]
127 | c = ai
128 | surfaces.append(BSpline(r, d, size=[nx, ny], tx=tx, ty=ty, c=c))
129 | elif surface_type == 'M': # mixed-type of X and B
130 | raise NotImplementedError()
131 | elif surface_type == 'S': # aspheric surface
132 | if len(ls) <= 5:
133 | surfaces.append(Aspheric(r, d_total, roc))
134 | else:
135 | ai = []
136 | for ac in range(5, len(ls)):
137 | if ac == 5:
138 | conic = float(ls[5])
139 | else:
140 | ai.append(float(ls[ac]))
141 | surfaces.append(Aspheric(r, d_total, roc, conic, ai))
142 | elif surface_type == 'A': # aperture
143 | surfaces.append(Aspheric(r, d_total, roc))
144 | elif surface_type == 'I': # sensor
145 | d_total -= d
146 | ds.pop()
147 | materials.pop()
148 | r_last = r
149 | d_last = d
150 | return surfaces, materials, r_last, d_last
151 |
152 | def reverse(self):
153 | # reverse surfaces
154 | d_total = self.surfaces[-1].d
155 | for i in range(len(self.surfaces)):
156 | self.surfaces[i].d = d_total - self.surfaces[i].d
157 | self.surfaces[i].reverse()
158 | self.surfaces.reverse()
159 |
160 | # reverse materials
161 | self.materials.reverse()
162 |
163 | # ------------------------------------------------------------------------------------
164 | # Analysis
165 | # ------------------------------------------------------------------------------------
166 | def rms(self, ps, units=1e3, option='centroid'):
167 | ps = ps[...,:2] * units
168 | if option == 'centroid':
169 | ps_mean = torch.mean(ps, axis=0) # centroid
170 | ps = ps - ps_mean[None,...] # we now use normalized ps
171 | spot_rms = torch.sqrt(torch.mean(torch.sum(ps**2, axis=-1)))
172 | return spot_rms
173 |
174 | def spot_diagram(self, ps, show=True, xlims=None, ylims=None, color='b.'):
175 | """
176 | Plot spot diagram.
177 | """
178 | units = 1e3
179 | units_str = '[um]'
180 | # units = 1
181 | # units_str = '[mm]'
182 | spot_rms = float(self.rms(ps, units))
183 | ps = ps.cpu().detach().numpy()[...,:2] * units
184 | ps_mean = np.mean(ps, axis=0) # centroid
185 | ps = ps - ps_mean[None,...] # we now use normalized ps
186 |
187 | fig = plt.figure()
188 | ax = plt.axes()
189 |         ax.plot(ps[...,1], ps[...,0], color) # permute axes 0 and 1
190 | plt.gca().set_aspect('equal', adjustable='box')
191 | plt.xlabel('x ' + units_str)
192 | plt.ylabel('y ' + units_str)
193 | plt.title("Spot diagram, RMS = " + str(round(spot_rms,3)) + ' ' + units_str)
194 | if xlims is not None:
195 | plt.xlim(*xlims)
196 | if ylims is not None:
197 | plt.ylim(*ylims)
198 | ax.set_aspect(1./ax.get_data_ratio())
199 |
200 | fig.savefig("spotdiagram_" + datetime.now().strftime('%Y%m%d-%H%M%S-%f') + ".pdf", bbox_inches='tight')
201 | if show: plt.show()
202 | else: plt.close()
203 | return spot_rms
204 |
205 | # ------------------------------------------------------------------------------------
206 |
207 | # ------------------------------------------------------------------------------------
208 | # IO and visualizations
209 | # ------------------------------------------------------------------------------------
210 | def draw_points(self, ax, options, seq=range(3)):
211 | for surface in self.surfaces:
212 | points_world = self._generate_points(surface)
213 | ax.plot(points_world[seq[0]], points_world[seq[1]], points_world[seq[2]], options)
214 |
215 | # ------------------------------------------------------------------------------------
216 |
217 | # ------------------------------------------------------------------------------------
218 |
219 | def trace(self, ray, stop_ind=None):
220 | # update transformation when doing pose estimation
221 | if (
222 | self.origin.requires_grad or self.shift.requires_grad
223 | or
224 | self.theta_x.requires_grad or self.theta_y.requires_grad or self.theta_z.requires_grad
225 | ):
226 | self.update()
227 |
228 | # in local
229 | ray_in = self.to_object.transform_ray(ray)
230 | valid, mask_g, ray_out = self._trace(ray_in, stop_ind=stop_ind, record=False)
231 |
232 | # in world
233 | ray_final = self.to_world.transform_ray(ray_out)
234 |
235 | return ray_final, valid, mask_g
236 |
237 | # ------------------------------------------------------------------------------------
238 |
239 | def _refract(self, wi, n, eta, approx=False):
240 | """
241 | Snell's law (surface normal n defined along the positive z axis).
242 | """
243 | if np.prod(eta.shape) > 1:
244 | eta_ = eta[..., None]
245 | else:
246 | eta_ = eta
247 |
248 | cosi = torch.sum(wi * n, axis=-1)
249 |
250 | if approx:
251 | tmp = 1. - eta**2 * (1. - cosi)
252 | g = tmp
253 | valid = tmp > 0.
254 | wt = tmp[..., None] * n + eta_ * (wi - cosi[..., None] * n)
255 | else:
256 | cost2 = 1. - (1. - cosi**2) * eta**2
257 |
258 | # 1. get valid map; 2. zero out invalid points; 3. add eps to avoid NaN grad at cost2==0.
259 | g = cost2
260 | valid = cost2 > 0.
261 | cost2 = torch.clamp(cost2, min=1e-8)
262 | tmp = torch.sqrt(cost2)
263 |
264 | wt = tmp[..., None] * n + eta_ * (wi - cosi[..., None] * n)
265 | return valid, wt, g
266 |
267 | def _trace(self, ray, stop_ind=None, record=False):
268 | if stop_ind is None:
269 | stop_ind = len(self.surfaces)-1 # last index to stop
270 | is_forward = (ray.d[..., 2] > 0).all()
271 |
272 | if is_forward:
273 | return self._forward_tracing(ray, stop_ind, record)
274 | else:
275 | return self._backward_tracing(ray, stop_ind, record)
276 |
277 | def _forward_tracing(self, ray, stop_ind, record):
278 | wavelength = ray.wavelength
279 | dim = ray.o[..., 2].shape
280 |
281 | if record:
282 | oss = []
283 | for i in range(dim[0]):
284 | oss.append([ray.o[i,:].cpu().detach().numpy()])
285 |
286 | valid = torch.ones(dim, device=self.device).bool()
287 | mask = torch.ones(dim, device=self.device)
288 | for i in range(stop_ind+1):
289 | eta = self.materials[i].ior(wavelength) / self.materials[i+1].ior(wavelength)
290 |
291 | # ray intersecting surface
292 | valid_o, p, g_o = self.surfaces[i].ray_surface_intersection(ray, valid)
293 |
294 | # get surface normal and refract
295 | n = self.surfaces[i].normal(p[..., 0], p[..., 1])
296 | valid_d, d, g_d = self._refract(ray.d, -n, eta)
297 |
298 | # check validity
299 | mask = mask * ind(g_o) * ind(g_d)
300 | valid = valid & valid_o & valid_d
301 | if not valid.any():
302 | break
303 |
304 | # update ray {o,d}
305 | if record:
306 | for os, v, pp in zip(oss, valid.cpu().detach().numpy(), p.cpu().detach().numpy()):
307 | if v: os.append(pp)
308 | ray.o = p
309 | ray.d = d
310 |
311 | if record:
312 | return valid, mask, ray, oss
313 | else:
314 | return valid, mask, ray
315 |
316 | def _backward_tracing(self, ray, stop_ind, record):
317 | wavelength = ray.wavelength
318 | dim = ray.o[..., 2].shape
319 |
320 | if record:
321 | oss = []
322 | for i in range(dim[0]):
323 | oss.append([ray.o[i,:].cpu().detach().numpy()])
324 |
325 | valid = torch.ones(dim, device=ray.o.device).bool()
326 | mask = torch.ones(dim, device=ray.o.device)
327 | for i in np.flip(range(stop_ind+1)):
328 | surface = self.surfaces[i]
329 | eta = self.materials[i+1].ior(wavelength) / self.materials[i].ior(wavelength)
330 |
331 | # ray intersecting surface
332 | valid_o, p, g_o = surface.ray_surface_intersection(ray, valid)
333 |
334 | # get surface normal and refract
335 | n = surface.normal(p[..., 0], p[..., 1])
336 | valid_d, d, g_d = self._refract(ray.d, n, eta) # backward: no need to revert the normal
337 |
338 | # check validity
339 | mask = mask * ind(g_o) * ind(g_d)
340 | valid = valid & valid_o & valid_d
341 | if not valid.any():
342 | break
343 |
344 | # update ray {o,d}
345 | if record:
346 | for os, v, pp in zip(oss, valid.numpy(), p.cpu().detach().numpy()):
347 | if v: os.append(pp)
348 | ray.o = p
349 | ray.d = d
350 |
351 | if record:
352 | return valid, mask, ray, oss
353 | else:
354 | return valid, mask, ray
355 |
356 | def _generate_points(self, surface, with_boundary=False):
357 | R = surface.r
358 | x = y = torch.linspace(-R, R, surface.APERTURE_SAMPLING, device=self.device)
359 | X, Y = torch.meshgrid(x, y)
360 | Z = surface.surface_with_offset(X, Y)
361 | valid = X**2 + Y**2 <= R**2
362 | if with_boundary:
363 | from scipy import ndimage
364 | tmp = ndimage.convolve(valid.cpu().numpy().astype('float'), np.array([[0,1,0],[1,0,1],[0,1,0]]))
365 | boundary = valid.cpu().numpy() & (tmp != 4)
366 | boundary = boundary[valid.cpu().numpy()].flatten()
367 | points_local = torch.stack(tuple(v[valid].flatten() for v in [X, Y, Z]), axis=-1)
368 | points_world = self.to_world.transform_point(points_local).T.cpu().detach().numpy()
369 | if with_boundary:
370 | return points_world, boundary
371 | else:
372 | return points_world
373 |
374 | class Surface(PrettyPrinter):
375 | def __init__(self, r, d, device=torch.device('cpu')):
376 | # self.r = torch.Tensor(np.array(r))
377 | if torch.is_tensor(d):
378 | self.d = d
379 | else:
380 | self.d = torch.Tensor(np.asarray(float(d))).to(device)
381 | self.r = float(r)
382 | self.device = device
383 | self.NEWTONS_MAXITER = 10
384 | self.NEWTONS_TOLERANCE_TIGHT = 50e-6 # in [mm], i.e. 50 [nm] here (up to <10 [nm])
385 | self.NEWTONS_TOLERANCE_LOOSE = 300e-6 # in [mm], i.e. 300 [nm] here (up to <10 [nm])
386 | self.APERTURE_SAMPLING = 11
387 |
388 | # === Common methods (must not be overridden)
389 | def surface_with_offset(self, x, y):
390 | return self.surface(x, y) + self.d
391 |
392 | def normal(self, x, y):
393 | ds_dxyz = self.surface_derivatives(x, y)
394 | return normalize(torch.stack(ds_dxyz, axis=-1))
395 |
396 | def surface_area(self):
397 | return math.pi * self.r**2
398 |
399 | def ray_surface_intersection(self, ray, active=None):
400 | """
401 | Returns:
402 |         - valid_o: valid intersection or not (requires g >= 0)
403 |         - p: intersection point
404 |         - g: explicit function (here the aperture test r^2 - (x^2 + y^2))
405 | """
406 | solution_found, local = self.newtons_method(ray.maxt, ray.o, ray.d)
407 | r2 = local[..., 0]**2 + local[..., 1]**2
408 | g = self.r**2 - r2
409 | if active is None:
410 | valid_o = solution_found & ind(g > 0.).bool()
411 | else:
412 | valid_o = active & solution_found & ind(g > 0.).bool()
413 | return valid_o, local, g
414 |
415 | def newtons_method_impl(self, maxt, t0, dx, dy, dz, ox, oy, oz, A, B, C):
416 | t_delta = torch.zeros_like(oz)
417 |
418 | # Iterate until the intersection error is small
419 | t = maxt * torch.ones_like(oz)
420 | residual = maxt * torch.ones_like(oz)
421 | it = 0
422 | while (torch.abs(residual) > self.NEWTONS_TOLERANCE_TIGHT).any() and (it < self.NEWTONS_MAXITER):
423 | it += 1
424 | t = t0 + t_delta
425 | residual, s_derivatives_dot_D = self.surface_and_derivatives_dot_D(
426 | t, dx, dy, dz, ox, oy, t_delta * dz, A, B, C # here z = t_delta * dz
427 | )
428 | t_delta -= residual / s_derivatives_dot_D
429 | t = t0 + t_delta
430 | valid = (torch.abs(residual) < self.NEWTONS_TOLERANCE_LOOSE) & (t <= maxt)
431 | return t, t_delta, valid
432 |
433 | def newtons_method(self, maxt, o, D, option='implicit'):
434 | # Newton's method to find the root of the ray-surface intersection point.
435 | # Two modes are supported here:
436 | #
437 |         # 1. 'explicit': This implements the loop using autodiff; gradients will be
438 |         # accurate for o, D, and self.parameters. Slow and memory-consuming.
439 |         #
440 |         # 2. 'implicit': This implements the loop as proposed in the paper: it finds the
441 |         # solution without autodiff, then hooks up the gradient. Less memory consumption.
442 |
443 | # pre-compute constants
444 | ox, oy, oz = (o[..., i].clone() for i in range(3))
445 | dx, dy, dz = (D[..., i].clone() for i in range(3))
446 | A = dx**2 + dy**2
447 | B = 2 * (dx * ox + dy * oy)
448 | C = ox**2 + oy**2
449 |
450 | # initial guess of t
451 | t0 = (self.d - oz) / dz
452 |
453 | if option == 'explicit':
454 | t, t_delta, valid = self.newtons_method_impl(
455 | maxt, t0, dx, dy, dz, ox, oy, oz, A, B, C
456 | )
457 | elif option == 'implicit':
458 | with torch.no_grad():
459 | t, t_delta, valid = self.newtons_method_impl(
460 | maxt, t0, dx, dy, dz, ox, oy, oz, A, B, C
461 | )
462 | s_derivatives_dot_D = self.surface_and_derivatives_dot_D(
463 | t, dx, dy, dz, ox, oy, t_delta * dz, A, B, C
464 | )[1]
465 | t = t0 + t_delta # re-engage autodiff
466 |
467 | t = t - (self.g(ox + t * dx, oy + t * dy) + self.h(oz + t * dz) + self.d)/s_derivatives_dot_D
468 | else:
469 | raise Exception('option={} is not available!'.format(option))
470 |
471 | p = o + t[..., None] * D
472 | return valid, p
473 |
474 | # === Virtual methods (must be overridden)
475 | def g(self, x, y):
476 | raise NotImplementedError()
477 |
478 | def dgd(self, x, y):
479 | """
480 | Derivatives of g: (g'x, g'y).
481 | """
482 | raise NotImplementedError()
483 |
484 | def h(self, z):
485 | raise NotImplementedError()
486 |
487 | def dhd(self, z):
488 | """
489 | Derivative of h.
490 | """
491 | raise NotImplementedError()
492 |
493 | def surface(self, x, y):
494 | """
495 | Solve z from h(z) = -g(x,y).
496 | """
497 | raise NotImplementedError()
498 |
499 | def reverse(self):
500 | raise NotImplementedError()
501 |
502 | # === Default methods (better be overridden)
503 | def surface_derivatives(self, x, y):
504 | """
505 | Returns \nabla f = \nabla (g(x,y) + h(z)) = (dg/dx, dg/dy, dh/dz).
506 | (Note: this default implementation is not efficient)
507 | """
508 | gx, gy = self.dgd(x, y)
509 | z = self.surface(x, y)
510 | return gx, gy, self.dhd(z)
511 |
512 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C):
513 | """
514 | Returns g(x,y)+h(z) and dot((g'x,g'y,h'), (dx,dy,dz)).
515 | (Note: this default implementation is not efficient)
516 | """
517 | x = ox + t * dx
518 | y = oy + t * dy
519 | s = self.g(x,y) + self.h(z)
520 | sx, sy = self.dgd(x, y)
521 | sz = self.dhd(z)
522 | return s, sx*dx + sy*dy + sz*dz
523 |
524 |
525 | class Aspheric(Surface):
526 | """
527 | Aspheric surface: https://en.wikipedia.org/wiki/Aspheric_lens.
528 | """
529 | def __init__(self, r, d, c=0., k=0., ai=None, device=torch.device('cpu')):
530 | Surface.__init__(self, r, d, device)
531 | self.c, self.k = (torch.Tensor(np.array(v)) for v in [c, k])
532 | self.ai = None
533 | if ai is not None:
534 | self.ai = torch.Tensor(np.array(ai))
535 |
536 | # === Common methods
537 | def g(self, x, y):
538 | return self._g(x**2 + y**2)
539 |
540 | def dgd(self, x, y):
541 | dsdr2 = 2 * self._dgd(x**2 + y**2)
542 | return dsdr2*x, dsdr2*y
543 |
544 | def h(self, z):
545 | return -z
546 |
547 | def dhd(self, z):
548 | return -torch.ones_like(z)
549 |
550 | def surface(self, x, y):
551 | return self._g(x**2 + y**2)
552 |
553 | def reverse(self):
554 | self.c = -self.c
555 | if self.ai is not None:
556 | self.ai = -self.ai
557 |
558 | def surface_derivatives(self, x, y):
559 | dsdr2 = 2 * self._dgd(x**2 + y**2)
560 | return dsdr2*x, dsdr2*y, -torch.ones_like(x)
561 |
562 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C):
563 | r2 = A * t**2 + B * t + C
564 | return self._g(r2) - z, self._dgd(r2) * (2*A*t + B) - dz
565 |
566 | # === Private methods
567 | def _g(self, r2):
568 | tmp = r2*self.c
569 | total_surface = tmp / (1 + torch.sqrt(1 - (1+self.k) * tmp*self.c))
570 | higher_surface = 0
571 | if self.ai is not None:
572 | for i in np.flip(range(len(self.ai))):
573 | higher_surface = r2 * higher_surface + self.ai[i]
574 | higher_surface = higher_surface * r2**2
575 | return total_surface + higher_surface
576 |
577 | def _dgd(self, r2):
578 | alpha_r2 = (1 + self.k) * self.c**2 * r2
579 | tmp = torch.sqrt(1 - alpha_r2) # TODO: potential NaN grad
580 | total_derivative = self.c * (1 + tmp - 0.5*alpha_r2) / (tmp * (1 + tmp)**2)
581 |
582 | higher_derivative = 0
583 | if self.ai is not None:
584 | for i in np.flip(range(len(self.ai))):
585 | higher_derivative = r2 * higher_derivative + (i+2) * self.ai[i]
586 | return total_derivative + higher_derivative * r2
587 |
588 |
589 | # ----------------------------------------------------------------------------------------
590 |
591 | class BSpline(Surface):
592 | """
593 | Implemented according to Wikipedia.
594 | """
595 | def __init__(self, r, d, size, px=3, py=3, tx=None, ty=None, c=None, device=torch.device('cpu')): # input c is 1D
596 | Surface.__init__(self, r, d, device)
597 | self.px = px
598 | self.py = py
599 | self.size = np.asarray(size)
600 |
601 | # knots
602 | if tx is None:
603 | self.tx = None
604 | else:
605 | if len(tx) != size[0] + 2*(self.px + 1):
606 | raise Exception('len(tx) is not correct!')
607 | self.tx = torch.Tensor(np.asarray(tx)).to(self.device)
608 | if ty is None:
609 | self.ty = None
610 | else:
611 | if len(ty) != size[1] + 2*(self.py + 1):
612 | raise Exception('len(ty) is not correct!')
613 | self.ty = torch.Tensor(np.asarray(ty)).to(self.device)
614 |
615 | # c is the only differentiable parameter
616 | c_shape = size + np.array([self.px, self.py]) + 1
617 | if c is None:
618 | self.c = None
619 | else:
620 | c = np.asarray(c)
621 | if c.size != np.prod(c_shape):
622 | raise Exception('len(c) is not correct!')
623 | self.c = torch.Tensor(c.reshape(*c_shape)).to(self.device)
624 |
625 | if (self.tx is None) or (self.ty is None) or (self.c is None):
626 | self.tx = self._generate_knots(self.r, size[0], p=px, device=device)
627 | self.ty = self._generate_knots(self.r, size[1], p=py, device=device)
628 | self.c = torch.zeros(*c_shape, device=device)
629 | else:
630 | self.to(self.device)
631 |
632 | @staticmethod
633 | def _generate_knots(R, n, p=3, device=torch.device('cpu')):
634 | t = np.linspace(-R, R, n)
635 | step = t[1] - t[0]
636 | T = t[0] - 0.9 * step
637 | np.pad(t, p+1, 'constant', constant_values=step)
638 | t = np.concatenate((np.ones(p+1)*T, t, -np.ones(p+1)*T), axis=0)
639 | return torch.Tensor(t).to(device)
640 |
641 | def fit(self, x, y, z, eps=1e-3):
642 | x, y, z = (v.flatten() for v in [x, y, z])
643 |
644 | # knot positions within [-r, r]^2
645 | X = np.linspace(-self.r, self.r, self.size[0])
646 | Y = np.linspace(-self.r, self.r, self.size[1])
647 | bs = LSQBivariateSpline(x, y, z, X, Y, kx=self.px, ky=self.py, eps=eps)
648 | tx, ty = bs.get_knots()
649 | c = bs.get_coeffs().reshape(len(tx)-self.px-1, len(ty)-self.py-1)
650 |
651 | # convert to torch.Tensor
652 | self.tx, self.ty, self.c = (torch.Tensor(v).to(self.device) for v in [tx, ty, c])
653 |
654 | # === Common methods
655 | def g(self, x, y):
656 | return self._deBoor2(x, y)
657 |
658 | def dgd(self, x, y):
659 | return self._deBoor2(x, y, dx=1), self._deBoor2(x, y, dy=1)
660 |
661 | def h(self, z):
662 | return -z
663 |
664 | def dhd(self, z):
665 | return -torch.ones_like(z)
666 |
667 | def surface(self, x, y):
668 | return self._deBoor2(x, y)
669 |
670 | def surface_derivatives(self, x, y):
671 | return self._deBoor2(x, y, dx=1), self._deBoor2(x, y, dy=1), -torch.ones_like(x)
672 |
673 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C):
674 | x = ox + t * dx
675 | y = oy + t * dy
676 | s, sx, sy = self._deBoor2(x, y, dx=-1, dy=-1)
677 | return s - z, sx*dx + sy*dy - dz
678 |
679 | def reverse(self):
680 | self.c = -self.c
681 |
682 | # === Private methods
683 | def _deBoor(self, x, t, c, p=3, is2Dfinal=False, dx=0):
684 | """
685 | Arguments
686 | ---------
687 | x: Position.
688 |         t: Array of knot positions, padded as done in _generate_knots().
689 | c: Array of control points.
690 | p: Degree of B-spline.
691 | dx:
692 | - 0: surface only
693 | - 1: surface 1st derivative only
694 | - -1: surface and its 1st derivative
695 | """
696 | k = torch.sum((x[None,...] > t[...,None]).int(), axis=0) - (p+1)
697 |
698 | if is2Dfinal:
699 | inds = np.indices(k.shape)[0]
700 | def _c(jk): return c[jk, inds]
701 | else:
702 | def _c(jk): return c[jk, ...]
703 |
704 | need_newdim = (len(c.shape) > 1) & (not is2Dfinal)
705 |
706 | def f(a, b, alpha):
707 | if need_newdim:
708 | alpha = alpha[...,None]
709 | return (1.0 - alpha) * a + alpha * b
710 |
711 | # surface only
712 | if dx == 0:
713 | d = [_c(j+k) for j in range(0, p+1)]
714 |
715 | for r in range(-p, 0):
716 | for j in range(p, p+r, -1):
717 | left = j+k
718 | t_left = t[left]
719 | t_right = t[left-r]
720 | alpha = (x - t_left) / (t_right - t_left)
721 | d[j] = f(d[j-1], d[j], alpha)
722 | return d[p]
723 |
724 | # surface 1st derivative only
725 | if dx == 1:
726 | q = []
727 | for j in range(1, p+1):
728 | jk = j+k
729 | tmp = t[jk+p] - t[jk]
730 | if need_newdim:
731 | tmp = tmp[..., None]
732 | q.append(p * (_c(jk) - _c(jk-1)) / tmp)
733 |
734 | for r in range(-p, -1):
735 | for j in range(p-1, p+r, -1):
736 | left = j+k
737 | t_right = t[left-r]
738 | t_left_ = t[left+1]
739 | alpha = (x - t_left_) / (t_right - t_left_)
740 | q[j] = f(q[j-1], q[j], alpha)
741 | return q[p-1]
742 |
743 | # surface and its derivative (all)
744 | if dx < 0:
745 | d, q = [], []
746 | for j in range(0, p+1):
747 | jk = j+k
748 | c_jk = _c(jk)
749 | d.append(c_jk)
750 | if j > 0:
751 | tmp = t[jk+p] - t[jk]
752 | if need_newdim:
753 | tmp = tmp[..., None]
754 | q.append(p * (c_jk - _c(jk-1)) / tmp)
755 |
756 | for r in range(-p, 0):
757 | for j in range(p, p+r, -1):
758 | left = j+k
759 | t_left = t[left]
760 | t_right = t[left-r]
761 | alpha = (x - t_left) / (t_right - t_left)
762 | d[j] = f(d[j-1], d[j], alpha)
763 |
764 | if (r < -1) & (j < p):
765 | t_left_ = t[left+1]
766 | alpha = (x - t_left_) / (t_right - t_left_)
767 | q[j] = f(q[j-1], q[j], alpha)
768 | return d[p], q[p-1]
769 |
770 | def _deBoor2(self, x, y, dx=0, dy=0):
771 | """
772 | Arguments
773 | ---------
774 | x, y : Position.
775 |         dx, dy: derivative flags: (0,0) surface; (1,0) d/dx; (0,1) d/dy; otherwise surface and both derivatives.
776 | """
777 | if not torch.is_tensor(x):
778 | x = torch.Tensor(np.asarray(x)).to(self.device)
779 | if not torch.is_tensor(y):
780 | y = torch.Tensor(np.asarray(y)).to(self.device)
781 | dim = x.shape
782 |
783 | x = x.flatten()
784 | y = y.flatten()
785 |
786 | # handle boundary issue
787 | x = torch.clamp(x, min=-self.r, max=self.r)
788 | y = torch.clamp(y, min=-self.r, max=self.r)
789 |
790 | if (dx == 0) & (dy == 0): # spline
791 | s_tmp = self._deBoor(x, self.tx, self.c, self.px)
792 | s = self._deBoor(y, self.ty, s_tmp.T, self.py, True)
793 | return s.reshape(dim)
794 | elif (dx == 1) & (dy == 0): # x-derivative
795 | s_tmp = self._deBoor(y, self.ty, self.c.T, self.py)
796 | s_x = self._deBoor(x, self.tx, s_tmp.T, self.px, True, dx)
797 | return s_x.reshape(dim)
798 | elif (dy == 1) & (dx == 0): # y-derivative
799 | s_tmp = self._deBoor(x, self.tx, self.c, self.px)
800 | s_y = self._deBoor(y, self.ty, s_tmp.T, self.py, True, dy)
801 | return s_y.reshape(dim)
802 | else: # return all
803 | s_tmpx = self._deBoor(x, self.tx, self.c, self.px)
804 | s_tmpy = self._deBoor(y, self.ty, self.c.T, self.py)
805 | s, s_x = self._deBoor(x, self.tx, s_tmpy.T, self.px, True, -abs(dx))
806 | s_y = self._deBoor(y, self.ty, s_tmpx.T, self.py, True, abs(dy))
807 | return s.reshape(dim), s_x.reshape(dim), s_y.reshape(dim)
808 |
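# Hedged cross-check (standalone sketch, illustrative knots/coefficients): _deBoor above is a
# batched torch variant of the classic de Boor recursion; the scalar textbook version can be
# verified against scipy.interpolate.BSpline as follows (indexing differs from the code above).
def _de_boor_reference_check():
    import numpy as np
    from scipy.interpolate import BSpline

    def de_boor(kk, x, t, c, p):
        # evaluate the spline at x, where t[kk] <= x < t[kk+1]
        d = [c[j + kk - p] for j in range(p + 1)]
        for r in range(1, p + 1):
            for j in range(p, r - 1, -1):
                alpha = (x - t[j + kk - p]) / (t[j + 1 + kk - r] - t[j + kk - p])
                d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
        return d[p]

    p = 3
    t = np.concatenate(([0.0]*(p+1), [0.25, 0.5, 0.75], [1.0]*(p+1)))  # clamped cubic knots
    c = np.linspace(-1.0, 1.0, len(t) - p - 1)                         # 7 control coefficients
    x = 0.37
    kk = np.searchsorted(t, x, side='right') - 1
    assert np.isclose(de_boor(kk, x, t, c, p), BSpline(t, c, p)(x))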
809 |
810 | class XYPolynomial(Surface):
811 | """
812 |     General XY polynomial surface, parameterized by b and the coefficients a_ij:
813 |
814 |     implicit form:  b z^2 - z + \sum_{i,j} a_ij x^i y^{j-i} = 0
815 |     explicit form:  (denote c = \sum_{i,j} a_ij x^i y^{j-i})
816 |     z = (1 - \sqrt{1 - 4 b c}) / (2 b)
817 |
818 |     total differential of the implicit form:
819 |     (2 b z - 1) dz + \sum_{i,j} a_ij x^{i-1} y^{j-i-1} ( i y dx + (j-i) x dy ) = 0
820 |
821 |     i.e. dg/dx = \sum_{i,j} a_ij i x^{i-1} y^{j-i}
822 |     dg/dy = \sum_{i,j} a_ij (j-i) x^{i} y^{j-i-1}
823 |     dh/dz = 2 b z - 1
824 | """
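    # Note on coefficient ordering (added for clarity): ai is stored in the loop order used by
    # g() below, i.e. for j = 0..J and i = 0..j the term is ai[count] * x**i * y**(j-i).
    # For J = 2 that ordering is [1, y, x, y**2, x*y, x**2], hence len(ai) = (J+1)(J+2)/2 = 6.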
825 | def __init__(self, r, d, J=0, ai=None, b=None, device=torch.device('cpu')):
826 | Surface.__init__(self, r, d, device)
827 | self.J = J
828 | # differentiable parameters (default: all ai's and b are zeros)
829 | if ai is None:
830 |         self.ai = torch.zeros(self.J2aisize(J)) if J > 0 else torch.zeros(1)
831 | else:
832 | if len(ai) != self.J2aisize(J):
833 | raise Exception("len(ai) != (J+1)*(J+2)/2 !")
834 | self.ai = torch.Tensor(ai).to(device)
835 | if b is None:
836 | b = 0.
837 | self.b = torch.Tensor(np.asarray(b)).to(device)
838 | print('ai.size = {}'.format(self.ai.shape[0]))
839 | self.to(self.device)
840 |
841 | @staticmethod
842 | def J2aisize(J):
843 | return int((J+1)*(J+2)/2)
844 |
845 | def center(self):
846 | x0 = -self.ai[2]/self.ai[5]
847 | y0 = -self.ai[1]/self.ai[3]
848 | return x0, y0
849 |
850 | def fit(self, x, y, z):
851 | x, y, z = (torch.Tensor(v.flatten()) for v in [x, y, z])
852 | A, AT = self._construct_A(x, y, z**2)
853 |         coeffs = torch.linalg.solve(AT @ A, AT @ z[...,None])
854 | self.b = coeffs[0][0]
855 | self.ai = coeffs[1:].flatten()
856 |
857 | # === Common methods
858 | def g(self, x, y):
859 | c = torch.zeros_like(x)
860 | count = 0
861 | for j in range(self.J+1):
862 | for i in range(j+1):
863 | c = c + self.ai[count] * torch.pow(x, i) * torch.pow(y, j-i)
864 | count += 1
865 | return c
866 |
867 | def dgd(self, x, y):
868 | sx = torch.zeros_like(x)
869 | sy = torch.zeros_like(x)
870 | count = 0
871 | for j in range(self.J+1):
872 | for i in range(j+1):
873 | if j > 0:
874 | sx = sx + self.ai[count] * i * torch.pow(x, max(i-1,0)) * torch.pow(y, j-i)
875 | sy = sy + self.ai[count] * (j-i) * torch.pow(x, i) * torch.pow(y, max(j-i-1,0))
876 | count += 1
877 | return sx, sy
878 |
879 | def h(self, z):
880 | return self.b * z**2 - z
881 |
882 | def dhd(self, z):
883 | return 2 * self.b * z - torch.ones_like(z)
884 |
885 | def surface(self, x, y):
886 |         x, y = (v if torch.is_tensor(v) else torch.Tensor(v) for v in [x, y])
887 | c = self.g(x, y)
888 | return self._solve_for_z(c)
889 |
890 | def reverse(self):
891 | self.b = -self.b
892 | self.ai = -self.ai
893 |
894 | def surface_derivatives(self, x, y):
895 |         x, y = (v if torch.is_tensor(v) else torch.Tensor(v) for v in [x, y])
896 | sx = torch.zeros_like(x)
897 | sy = torch.zeros_like(x)
898 | c = torch.zeros_like(x)
899 | count = 0
900 | for j in range(self.J+1):
901 | for i in range(j+1):
902 | c = c + self.ai[count] * torch.pow(x, i) * torch.pow(y, j-i)
903 | if j > 0:
904 | sx = sx + self.ai[count] * i * torch.pow(x, max(i-1,0)) * torch.pow(y, j-i)
905 | sy = sy + self.ai[count] * (j-i) * torch.pow(x, i) * torch.pow(y, max(j-i-1,0))
906 | count += 1
907 | z = self._solve_for_z(c)
908 | return sx, sy, self.dhd(z)
909 |
910 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C):
911 | x = ox + t * dx
912 | y = oy + t * dy
913 | sx = torch.zeros_like(x)
914 | sy = torch.zeros_like(x)
915 | c = torch.zeros_like(x)
916 | count = 0
917 | for j in range(self.J+1):
918 | for i in range(j+1):
919 | c = c + self.ai[count] * torch.pow(x, i) * torch.pow(y, j-i)
920 | if j > 0:
921 | sx = sx + self.ai[count] * i * torch.pow(x, max(i-1,0)) * torch.pow(y, j-i)
922 | sy = sy + self.ai[count] * (j-i) * torch.pow(x, i) * torch.pow(y, max(j-i-1,0))
923 | count += 1
924 | s = c + self.h(z)
925 | return s, sx*dx + sy*dy + self.dhd(z)*dz
926 |
927 | # === Private methods
928 | def _construct_A(self, x, y, A_init=None):
929 |         A = torch.zeros_like(x) if A_init is None else A_init
930 | for j in range(self.J+1):
931 | for i in range(j+1):
932 | A = torch.vstack((A, torch.pow(x, i) * torch.pow(y, j-i)))
933 |         AT = A[1:,:] if A_init is None else A
934 | return AT.T, AT
935 |
936 | def _solve_for_z(self, c):
937 | if self.b == 0:
938 | return c
939 | else:
940 | return (1. - torch.sqrt(1. - 4*self.b*c)) / (2*self.b)
941 |
942 |
943 | # ----------------------------------------------------------------------------------------
944 |
945 | def generate_test_lensgroup():
946 | origin_mount = np.array([0, 0, -70])
947 | origin_shift = np.array([0.1, 0.2, 0.3])
948 | theta_x = 180
949 | theta_y = 0
950 |
951 | lensname = 'Thorlabs/AL50100-A.txt'
952 | lensgroup = Lensgroup(lensname, origin_mount, origin_shift, theta_x, theta_y)
953 | return lensgroup
954 |
955 | # ----------------------------------------------------------------------------------------
956 |
957 |
958 | if __name__ == "__main__":
959 | init()
960 |
961 | lensgroup = generate_test_lensgroup()
962 | # print(lensgroup)
963 |
964 | ray = generate_test_rays()
965 | ray_out, valid, mask = lensgroup.trace(ray)
966 |
967 | ray_out.d = -ray_out.d
968 | ray_out.update()
969 |
970 | ray_new, valid_, mask_ = lensgroup.trace(ray_out)
971 | # assert np.sum(np.abs(ray.d.numpy() + ray_new.d.numpy())) < 1e-5
972 |
--------------------------------------------------------------------------------
/diffmetrology/scene.py:
--------------------------------------------------------------------------------
1 | from .basics import *
2 | from .shapes import *
3 | from .optics import *
4 |
5 | class Scene(PrettyPrinter):
6 | def __init__(self, cameras, screen, lensgroup=None, device=torch.device('cpu')):
7 | self.cameras = cameras
8 | self.screen = screen
9 | self.lensgroup = lensgroup
10 | self.device = device
11 |
12 | self.camera_count = len(self.cameras)
13 |
14 | # rendering options
15 | self.wavelength = 500 # [nm]
16 |
17 | def render(self, i=None, with_element=True, mask=None, to_numpy=False):
18 | im = self._simulate(i, with_element, mask, SimulationMode.render)
19 | if to_numpy:
20 | im = [x.cpu().detach().numpy() for x in im]
21 | return im
22 |
23 | def trace(self, i=None, with_element=True, mask=None, to_numpy=False):
24 | results = self._simulate(i, with_element, mask, SimulationMode.trace)
25 | p = [x[0].cpu().detach().numpy() if to_numpy else x[0] for x in results]
26 | valid = [x[1].cpu().detach().numpy() if to_numpy else x[1] for x in results]
27 | mask_g = [x[2].cpu().detach().numpy() if to_numpy else x[2] for x in results]
28 | return p, valid, mask_g
29 |
30 | def plot_setup(self):
31 | fig = plt.figure()
32 | ax = fig.add_subplot(111, projection='3d')
33 |
34 | # generate color-ring
35 | if self.camera_count <= 6:
36 | colors = ['b','r','g','c','m','y']
37 | else:
38 | colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
39 | colors = colors * np.ceil(self.camera_count/len(colors)).astype(int)
40 | colors = colors[:self.camera_count]
41 |
42 | # draw screen
43 | seq = [0,2,1] # draw scene in [x,z,y] order
44 | self.screen.draw_points(ax, 'k', seq)
45 |
46 | # draw metrology part
47 | if self.lensgroup is not None:
48 | self.lensgroup.draw_points(ax, 'y.', seq)
49 |
50 | # draw cameras
51 | for i, camera in enumerate(self.cameras):
52 | camera.draw_points(ax, colors[i], seq)
53 |
54 | # pretty plot
55 | labels = 'xyz'
56 | scales = np.array([2,2,1])
57 | plt.locator_params(nbins=5)
58 | ax.set_xlabel(labels[seq[0]] + ' [mm]')
59 | ax.set_ylabel(labels[seq[1]] + ' [mm]')
60 | ax.set_zlabel(labels[seq[2]] + ' [mm]')
61 |         ax.legend(['Display'] + ['Camera {}'.format(i+1) for i in range(self.camera_count)])
62 | ax.get_legend().legendHandles[0].set_color('k')
63 | for i in range(self.camera_count):
64 | ax.get_legend().legendHandles[i+1].set_color(colors[i])
65 | ax.set_title('setup')
66 | set_axes_equal(ax, np.array([scales[seq[i]] for i in range(3)]))
67 | plt.show()
68 |
69 | def to(self, device=torch.device('cpu')):
70 | super().to(device)
71 |
72 | # set device name (TODO: make it more elegant)
73 | self.device = device
74 |         if self.lensgroup is not None: self.lensgroup.device = device
75 | self.screen.device = device
76 | for i in range(self.camera_count):
77 | self.cameras[i].device = device
78 |         for s in (self.lensgroup.surfaces if self.lensgroup is not None else []):
79 |             s.device = device
80 |
81 | def _simulate(self, i=None, with_element=True, mask=None, smode=SimulationMode.render):
82 | def simulate(i):
83 | if mask is None: # default: full sensor rendering
84 | ray = self.cameras[i].sample_ray()
85 |
86 | # interaction with metrology part
87 | if with_element and self.lensgroup is not None:
88 | ray, valid_ray, mask_g = self.lensgroup.trace(ray)
89 | else:
90 | mask_g = torch.ones(ray.o.shape[0:2], device=self.device)
91 | valid_ray = mask_g.clone().bool()
92 |
93 | # interaction with screen
94 |                 # We permute the axes to be compatible with the image viewer (and the data).
95 | p, uv, valid_screen = self.screen.intersect(ray)
96 | valid = valid_screen & valid_ray
97 |
98 | if smode is SimulationMode.render:
99 | del p
100 | return self.screen.shading(uv, valid).permute(1,0)
101 |
102 | elif smode is SimulationMode.trace:
103 | del uv
104 | return p.permute(1,0,2), valid.permute(1,0), mask_g.permute(1,0)
105 |
106 | else: # get corresponding indices from the mask
107 | mask_ = mask[i].permute(1,0) # we transpose mask to align with our code
108 | ix, iy = torch.where(mask_)
109 | p2 = self.cameras[i].generate_position_sample(mask_)
110 | ray = self.cameras[i].sample_ray(p2)
111 |
112 | # interaction with metrology part
113 | if with_element and self.lensgroup is not None:
114 | ray, valid_ray, mask_g_ = self.lensgroup.trace(ray)
115 | else:
116 |                     mask_g_ = torch.ones(ray.o.shape[:-1], device=self.device)
117 | valid_ray = mask_g_.clone().bool()
118 |
119 | # interaction with screen
120 |                 # We permute the axes to be compatible with the image viewer (and the data).
121 | p_, uv, valid_screen = self.screen.intersect(ray)
122 |
123 | if smode is SimulationMode.render:
124 | del p_
125 | raise NotImplementedError()
126 |
127 | elif smode is SimulationMode.trace:
128 | del uv
129 | p = torch.zeros(*self.cameras[i].filmsize, 3, device=self.device)
130 | p[ix, iy, ...] = p_
131 | valid = torch.zeros(*self.cameras[i].filmsize, device=self.device).bool()
132 | valid[ix, iy] = mask_[ix, iy]
133 | mask_g = torch.zeros(*self.cameras[i].filmsize, device=self.device)
134 | mask_g[ix, iy] = mask_g_
135 | return p.permute(1,0,2), valid.permute(1,0), mask_g.permute(1,0)
136 |
137 | return [simulate(j) for j in range(self.camera_count)] if i is None else simulate(i)
138 |
139 | # ----------------------------------------------------------------------------------------
140 |
141 | def generate_test_scene():
142 | R = np.eye(3)
143 | ts = [np.array([0, 0, -300]), np.array([0, 0, -240])]
144 | cameras = [generate_test_camera(R, t) for t in ts]
145 | screen = generate_test_screen()
146 | lensgroup = generate_test_lensgroup()
147 | return Scene(cameras, screen, lensgroup)
148 |
149 | # ----------------------------------------------------------------------------------------
150 |
151 |
152 | if __name__ == "__main__":
153 | init()
154 |
155 | scene = generate_test_scene()
156 | scene.plot_setup()
157 |
158 | # render images
159 | imgs = scene.render()
160 | fig, ax = plt.subplots(1,2)
161 | for i, img in enumerate(imgs):
162 | ax[i].imshow(img, cmap='gray')
163 | ax[i].set_title('camera ' + str(i+1))
164 | plt.show()
165 |
--------------------------------------------------------------------------------
/diffmetrology/shapes.py:
--------------------------------------------------------------------------------
1 | from .basics import *
2 | from matplotlib.image import imread
3 |
4 |
5 | class Endpoint(PrettyPrinter):
6 | def __init__(self, transformation, device=torch.device('cpu')):
7 | self.to_world = transformation
8 | self.to_object = transformation.inverse()
9 | self.to_world.to(device)
10 | self.to_object.to(device)
11 | self.device = device
12 |
13 | def intersect(self, ray):
14 | raise NotImplementedError()
15 |
16 | def sample_ray(self, position_sample=None):
17 | raise NotImplementedError()
18 |
19 | def draw_points(self, ax, options, seq=range(3)):
20 | raise NotImplementedError()
21 |
22 |
23 | class Screen(Endpoint):
24 | """
25 |     Planar screen in its local frame, spanning [-w/2, w/2] x [-h/2, h/2] (centered at the origin).
26 | """
27 | def __init__(self, transformation, size, pixelsize, texture, device=torch.device('cpu')):
28 | self.size = torch.Tensor(np.float32(size)) # screen dimension [mm]
29 | self.halfsize = self.size/2 # screen half-dimension [mm]
30 | self.pixelsize = torch.Tensor([pixelsize]) # screen pixel size [mm]
31 | self.texture = torch.Tensor(texture) # screen image
32 | self.texturesize = torch.Tensor(np.array(texture.shape[0:2])) # screen image dimension [pixel]
33 | self.texturesize_np = self.texturesize.cpu().detach().numpy() # screen image dimension [pixel]
34 | self.texture_shift = torch.zeros(2) # screen image shift [mm]
35 | Endpoint.__init__(self, transformation, device)
36 | self.to(device)
37 |
38 | def intersect(self, ray):
39 | ray_in = self.to_object.transform_ray(ray)
40 | t = - ray_in.o[..., 2] / ray_in.d[..., 2] # well-posed for dz (TODO: potential NaN grad)
41 | local = ray_in(t)
42 |
43 | # Is intersection within ray segment and rectangle?
44 | valid = (
45 | (t >= ray_in.mint) &
46 | (t <= ray_in.maxt) &
47 | (torch.abs(local[..., 0] - self.texture_shift[0]) <= self.halfsize[0]) &
48 | (torch.abs(local[..., 1] - self.texture_shift[1]) <= self.halfsize[1])
49 | )
50 |
51 | # uv map
52 | uv = (local[..., 0:2] + self.halfsize - self.texture_shift) / self.size
53 |
54 |         # clamp uv into [0,1]^2 (sanity check: valid intersections already satisfy this)
55 | uv = torch.clamp(uv, min=0.0, max=1.0)
56 |
57 | return local, uv, valid
58 |
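    # Example of the uv mapping above (assuming texture_shift == 0): the local corner
    # (-w/2, -h/2) maps to uv = (0, 0), the center maps to (0.5, 0.5), and (+w/2, +h/2)
    # maps to (1, 1); shading() then scales uv by (texturesize - 1) to index the texture.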
59 | def shading(self, uv, valid, bmode=BoundaryMode.replicate, lmode=InterpolationMode.linear):
60 | p = uv * (self.texturesize-1)
61 | p_floor = torch.floor(p).long()
62 |
63 | def tex(x, y): # texture indexing function
64 | if bmode is BoundaryMode.zero:
65 | raise NotImplementedError()
66 | elif bmode is BoundaryMode.replicate:
67 | x = torch.clamp(x, min=0, max=self.texturesize_np[0]-1)
68 | y = torch.clamp(y, min=0, max=self.texturesize_np[1]-1)
69 | elif bmode is BoundaryMode.symmetric:
70 | raise NotImplementedError()
71 | elif bmode is BoundaryMode.periodic:
72 | raise NotImplementedError()
73 | img = self.texture[x.flatten(), y.flatten()]
74 | return img.reshape(x.shape)
75 |
76 | if lmode is InterpolationMode.nearest:
77 | val = tex(p_floor[...,0], p_floor[...,1])
78 | elif lmode is InterpolationMode.linear:
79 | x0, y0 = p_floor[...,0], p_floor[...,1]
80 | s00 = tex( x0, y0)
81 | s01 = tex( x0, 1+y0)
82 | s10 = tex(1+x0, y0)
83 | s11 = tex(1+x0, 1+y0)
84 | w1 = p - p_floor
85 | w0 = 1. - w1
86 | val = (
87 | w0[...,0] * (w0[...,1] * s00 + w1[...,1] * s01) +
88 | w1[...,0] * (w0[...,1] * s10 + w1[...,1] * s11)
89 | )
90 |
91 | val[~valid] = 0.0
92 | return val
93 |
94 | def draw_points(self, ax, options, seq=range(3)):
95 | coeffs = np.array([
96 | [ 1, 1, 1],
97 | [-1, 1, 1],
98 | [-1,-1, 1],
99 | [ 1,-1, 1],
100 | [ 1, 1, 1]
101 | ])
102 | points_local = torch.Tensor(coeffs * np.append(self.halfsize.cpu().detach().numpy(), 0)).to(self.device)
103 | points_world = self.to_world.transform_point(points_local).T.cpu().detach().numpy()
104 | ax.plot(points_world[seq[0]], points_world[seq[1]], points_world[seq[2]], options)
105 |
106 |
107 | class Camera(Endpoint):
108 | def __init__(self, transformation,
109 | filmsize, f=np.zeros(2), c=np.zeros(2), k=np.zeros(3), p=np.zeros(2), device=torch.device('cpu')):
110 | self.filmsize = filmsize
111 | self.f = torch.Tensor(np.float32(f)) # focal lengths [pixel]
112 | self.c = torch.Tensor(np.float32(c)) # centers [pixel]
113 | Endpoint.__init__(self, transformation, device)
114 |
115 | # un-initialized for now:
116 | self.crop_offset = torch.zeros(2, device=device) # [pixel]
117 |
118 | # no use for now:
119 | self.k = torch.Tensor(np.float32(k))
120 |         if len(self.k) < 3: self.k = torch.cat((self.k, torch.zeros(3 - len(self.k))))
121 | self.p = torch.Tensor(np.float32(p))
122 |
123 | # configurations
124 | self.NEWTONS_MAXITER = 5
125 | self.NEWTONS_TOLERANCE = 50e-6 # in [mm], i.e. 50 [nm] here
126 | self.use_approximation = False
127 | self.to(device)
128 |
129 | def generate_position_sample(self, mask=None):
130 | """
131 |         Generate pixel-center position samples (deterministic, not a random sampler), optionally restricted by a 2D mask.
132 | """
133 | dim = self.filmsize
134 | X, Y = torch.meshgrid(
135 | 0.5 + dim[0] * torch.linspace(0, 1, 1+dim[0], device=self.device)[:-1],
136 | 0.5 + dim[1] * torch.linspace(0, 1, 1+dim[1], device=self.device)[:-1],
137 | )
138 | if mask is not None:
139 | X, Y = X[mask], Y[mask]
140 |         # X.shape could be 1D (masked) or 2D (no mask)
141 | return torch.stack((X, Y), axis=len(X.shape))
142 |
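    # Example (illustrative): for filmsize = [4, 3] the samples are the pixel centers
    # X in {0.5, 1.5, 2.5, 3.5} and Y in {0.5, 1.5, 2.5}, stacked into a (4, 3, 2) tensor;
    # with a boolean mask the result collapses to an (n_valid, 2) tensor of selected centers.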
143 | def sample_ray(self, position_sample=None, is_sampler=False):
144 | """
145 | Sample ray(s) from sensor pixels.
146 | """
147 | wavelength = torch.Tensor(np.asarray(562.0)).to(self.device) # 562 [nm]
148 |
149 | if position_sample is None: # default: full-sensor deterministic rendering
150 | dim = self.filmsize
151 | position_sample = self.generate_position_sample()
152 | is_sampler = False
153 | else:
154 | dim = position_sample.shape[:-1]
155 |
156 | if is_sampler:
157 | uv = position_sample * np.float32(dim)
158 | else:
159 | uv = position_sample
160 |
161 | # in local
162 | xy = self._uv2xy(uv)
163 | dz = torch.ones((*dim, 1), device=self.device)
164 | d = torch.cat((xy, dz), axis=-1)
165 | d = normalize(d)
166 | o = torch.zeros((*dim, 3), device=self.device)
167 |
168 | # in world
169 | o = self.to_world.transform_point(o)
170 | d = self.to_world.transform_vector(d)
171 | ray = Ray(o, d, wavelength, self.device)
172 | return ray
173 |
174 | def draw_points(self, ax, options, seq=range(3)):
175 | origin = np.zeros(3)
176 | scales = np.append(self.filmsize/100, 20)
177 | coeffs = np.array([
178 | [ 1, 1, 1],
179 | [-1, 1, 1],
180 | [-1,-1, 1],
181 | [ 1,-1, 1],
182 | [ 1, 1, 1]
183 | ])
184 | sensor_corners = torch.Tensor(coeffs * scales).to(self.device)
185 | ps = self.to_world.transform_point(sensor_corners).T.cpu().detach().numpy()
186 | ax.plot(ps[seq[0]], ps[seq[1]], ps[seq[2]], options)
187 |
188 | for i in range(4):
189 | coeff = coeffs[i] * scales
190 | line = torch.Tensor(np.array([
191 | origin,
192 | coeff
193 | ])).to(self.device)
194 | ps = self.to_world.transform_point(line).T.cpu().detach().numpy()
195 | ax.plot(ps[seq[0]], ps[seq[1]], ps[seq[2]], options)
196 |
197 | def _uv2xy(self, uv):
198 | xy_distorted = (uv + self.crop_offset - self.c) / self.f
199 | xy = xy_distorted
200 | return xy
201 |
202 | # ----------------------------------------------------------------------------------------
203 |
204 | def generate_test_camera(R=np.eye(3), t=np.array([0, 0, -400])):
205 | to_world = Transformation(R, t)
206 |
207 | filmsize = np.array([360, 480])
208 | f = np.array([1000,1010]) # [pixel]
209 | c = np.array([150,160]) # [pixel]
210 | return Camera(to_world, filmsize, f, c)
211 |
212 | def generate_test_screen():
213 | R = np.eye(3)
214 | t = np.zeros(3)
215 | to_world = Transformation(R, t)
216 |
217 | screensize = np.array([81., 80.]) # [mm]
218 | pixelsize = 0.115 # [mm]
219 |
220 | # read texture image
221 | im = imread('./images/checkerboard.png')
222 | im = np.mean(im, axis=-1) # for now we use grayscale
223 | im = im[200:500,200:600]
224 | # plt.figure()
225 | # plt.imshow(im, cmap='gray')
226 | # plt.show()
227 |
228 | return Screen(to_world, screensize, pixelsize, im)
229 |
230 | def test_camera():
231 | camera = generate_test_camera()
232 | # print(camera)
233 |
234 | position_sample = torch.rand((*camera.filmsize, 2))
235 | ray = camera.sample_ray(position_sample)
236 | print(ray)
237 |
238 | def test_screen():
239 | screen = generate_test_screen()
240 | print(screen)
241 |
242 | # ----------------------------------------------------------------------------------------
243 |
244 |
245 | if __name__ == "__main__":
246 | init()
247 |
248 | test_camera()
249 | test_screen()
250 |
--------------------------------------------------------------------------------
/diffmetrology/solvers.py:
--------------------------------------------------------------------------------
1 | from .basics import *
2 | import numpy as np
3 | import torch.autograd.functional as F
4 | from skimage.restoration import unwrap_phase
5 |
6 |
7 | class Fringe(PrettyPrinter):
8 | """
9 | Fringe image analysis to resolve displacements.
10 | """
11 | def __init__(self):
12 | self.PHASE_SHIFT_COUNT = 4
13 | self.XY_COUNT = 2
14 |
15 | def solve(self, fs):
16 | """
17 |         Inputs:
18 |             fs: stack of phase-shifted fringe images; one of its axes has length
19 |             self.XY_COUNT * self.PHASE_SHIFT_COUNT = 8 (x and y patterns, 4 shifts each).
20 |
21 |         Outputs:
22 |             a.shape = b.shape = p.shape =
23 |             [2, original size]
24 |
25 |         where 2 denotes x and y.
26 | """
27 | def single(fs):
28 | ax, bx, psix = self._solve(fs[0:self.PHASE_SHIFT_COUNT])
29 | ay, by, psiy = self._solve(fs[self.PHASE_SHIFT_COUNT:self.XY_COUNT*self.PHASE_SHIFT_COUNT])
30 | return np.array([ax, ay]), np.array([bx, by]), np.array([psix, psiy])
31 |
32 |         # TODO: the axis lookup below breaks when len(Ts) == self.XY_COUNT*self.PHASE_SHIFT_COUNT, since fsize.index() then returns the Ts axis instead of the phase axis
33 | fsize = list(fs.shape)
34 | xy_index = fsize.index(self.XY_COUNT*self.PHASE_SHIFT_COUNT)
35 | inds = [i for i in range(len(fsize))]
36 | inds.remove(xy_index)
37 | inds = [xy_index] + inds
38 |
39 | # run the algorithm
40 | a, b, p = single(fs.transpose(inds))
41 |
42 | return a, b, p
43 |
44 | def unwrap(self, fs, Ts, valid=None):
45 |         print('unwrapping ...')
46 | if valid is None:
47 | valid = 1.0
48 | F = valid * fs
49 | fs_unwrapped = np.zeros(fs.shape)
50 | for xy in range(fs.shape[0]):
51 | print(f'xy = {xy} ...')
52 | for T in range(fs.shape[1]):
53 | print(f't = {T} ...')
54 | t = Ts[T]
55 | for i in range(fs.shape[2]):
56 |                     fs_unwrapped[xy,T,i] = unwrap_phase(F[xy,T,i,...]) * t / (2*np.pi) # phase [rad] -> displacement [display pixels] via period t
57 | return fs_unwrapped
58 |
59 | @staticmethod
60 | def _solve(fs):
61 | """Solver for four-step phase shifting: f(x) = a + b cos(\phi + \psi).
62 | b cos(\psi) = fs[0] - a
63 | - b sin(\psi) = fs[1] - a
64 | - b cos(\psi) = fs[2] - a
65 | b sin(\psi) = fs[3] - a
66 | """
67 | a = np.mean(fs, axis=0)
68 | b = 0.0
69 | for f in fs:
70 | b += (f - a)**2
71 |         b = b/2.0 # NOTE: this equals b**2 (squared modulation amplitude); used only for validity thresholding downstream
72 | psi = np.arctan2(fs[3] - fs[1], fs[0] - fs[2])
73 | return a, b, psi
74 |
75 |
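# Hedged numeric check of the four-step relations in Fringe._solve (standalone sketch,
# made-up values): with samples f_k = a + b*cos(psi + k*pi/2), k = 0..3, one gets
# fs[0] - fs[2] = 2 b cos(psi) and fs[3] - fs[1] = 2 b sin(psi), so arctan2 recovers psi.
def _four_step_phase_check(a=0.5, b=0.2, psi=1.234):
    fs = np.array([a + b*np.cos(psi + k*np.pi/2) for k in range(4)])
    psi_rec = np.arctan2(fs[3] - fs[1], fs[0] - fs[2])
    assert np.isclose(psi_rec, psi)     # exact up to a 2*pi wrap for psi in (-pi, pi]
    return psi_rec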
76 | class Optimization(PrettyPrinter):
77 | """
78 | General class for design optimization.
79 | """
80 | def __init__(self, diff_variables):
81 | self.diff_variables = diff_variables
82 | for v in self.diff_variables:
83 | v.requires_grad = True
84 | self.optimizer = None # to be initialized from set_parameters
85 |
86 | class Adam(Optimization):
87 | def __init__(self, diff_variables, lr, lrs=None, beta=0.99):
88 | Optimization.__init__(self, diff_variables)
89 | if lrs is None:
90 | lrs = [1] * len(self.diff_variables)
91 | self.optimizer = torch.optim.Adam(
92 | [{"params": v, "lr": lr*l} for v, l in zip(self.diff_variables, lrs)],
93 | betas=(beta,0.999), amsgrad=True
94 | )
95 |
96 | def optimize(self, func, loss, maxit=300, record=True):
97 | print('optimizing ...')
98 | ls = []
99 |         with torch.autograd.set_detect_anomaly(False): # set True to debug NaN gradients
100 | for it in range(maxit):
101 | L = loss(func())
102 | self.optimizer.zero_grad()
103 | L.backward(retain_graph=True)
104 |
105 | if record:
106 | grads = torch.Tensor([torch.mean(torch.abs(v.grad)) for v in self.diff_variables])
107 | print('iter = {}: loss = {:.4e}, grad_bar = {:.4e}'.format(
108 | it, L.item(), torch.mean(grads)
109 | ))
110 | ls.append(L.cpu().detach().numpy())
111 |
112 | self.optimizer.step()
113 |
114 | return np.array(ls)
115 |
116 |
117 | class LM(Optimization):
118 | """
119 | The Levenberg–Marquardt algorithm.
120 | """
121 | def __init__(self, diff_variables, lamb, mu=None, option='diag'):
122 | Optimization.__init__(self, diff_variables)
123 | self.lamb = lamb # damping factor
124 |         self.mu = 2.0 if mu is None else mu # damping rate (>1)
125 | self.option = option
126 |
127 | def jacobian(self, func, inputs, create_graph=False, strict=False):
128 | """Constructs a M-by-N Jacobian matrix where M >> N.
129 |
130 |         Computing the full Jacobian only pays off here because it is tall: with M residuals and
131 |         N << M parameters, building it column by column via forward-mode products (jvp) is cheaper.
132 |
133 | This function is modified from torch.autograd.functional.jvp().
134 | """
135 |
136 | Js = []
137 | outputs = func()
138 | M = outputs.shape
139 |
140 | grad_outputs = (torch.zeros_like(outputs, requires_grad=True),)
141 | for x in inputs:
142 | grad_inputs = F._autograd_grad(
143 | (outputs,), x, grad_outputs, create_graph=True
144 | )
145 |
146 | F._check_requires_grad(grad_inputs, "grad_inputs", strict=strict)
147 |
148 | # Construct Jacobian matrix
149 | N = torch.numel(x)
150 | if N == 1:
151 | J = F._autograd_grad(
152 | grad_inputs, grad_outputs, (torch.ones_like(x),),
153 | create_graph=create_graph,
154 | retain_graph=True
155 | )[0][...,None]
156 | else:
157 | J = torch.zeros((*M, N), device=x.device)
158 | v = torch.zeros(N, device=x.device)
159 | for i in range(N):
160 | v[i] = 1.0
161 | J[...,i] = F._autograd_grad(
162 | grad_inputs, grad_outputs, (v.view(x.shape),),
163 | create_graph=create_graph,
164 | retain_graph=True
165 | )[0]
166 |
167 | v[i] = 0.0
168 | Js.append(J)
169 | return torch.cat(Js, axis=-1)
170 |
171 | def optimize(self, func, change_parameters, func_yref_y, maxit=300, record=True):
172 | """
173 | Inputs:
174 | - func: Evaluate `y = f(x)` where `x` is the implicit parameters by `self.diff_variables` (out of the class)
175 | - change_parameters: Change of `self.diff_variables` (out of the class)
176 | - func_yref_y: Compute `y_ref - y`
177 |
178 | Outputs:
179 | - ls: Loss function.
180 | """
181 | print('optimizing ...')
182 | Ns = [x.numel() for x in self.diff_variables]
183 | NS = [[*x.shape] for x in self.diff_variables]
184 |
185 | ls = []
186 | lamb = self.lamb
187 | with torch.autograd.set_detect_anomaly(False):
188 | for it in range(maxit):
189 | y = func()
190 | with torch.no_grad():
191 | L = torch.mean(func_yref_y(y)**2).item()
192 | if L < 1e-16:
193 |                         print('L too small; terminate.')
194 | break
195 |
196 | # Obtain Jacobian
197 | J = self.jacobian(func, self.diff_variables, create_graph=False)
198 | J = J.view(-1, J.shape[-1])
199 | JtJ = J.T @ J
200 | N = JtJ.shape[0]
201 |
202 | # Regularization matrix
203 | if self.option == 'I':
204 | R = torch.eye(N, device=JtJ.device)
205 | elif self.option == 'diag':
206 | R = torch.diag(torch.diag(JtJ).abs())
207 | else:
208 | R = torch.diag(self.option)
209 |
210 | # Compute b = J.T @ (y_ref - y)
211 | bb = [
212 | torch.autograd.grad(outputs=y, inputs=x, grad_outputs=func_yref_y(y), retain_graph=True)[0]
213 | for x in self.diff_variables
214 | ]
215 | for i, bx in enumerate(bb):
216 | if len(bx.shape) == 0: # single scalar
217 | bb[i] = torch.Tensor([bx.item()]).to(y.device)
218 | if len(bx.shape) > 1: # multi-dimension
219 | bb[i] = torch.Tensor(bx.cpu().detach().numpy().flatten()).to(y.device)
220 | b = torch.cat(bb, axis=-1)
221 | del J, bb, y
222 |
223 | # Damping loop
224 | L_current = L + 1.0
225 | it_inner = 0
226 | while L_current >= L:
227 | it_inner += 1
228 | if it_inner > 20:
229 |                         print('Too many damping iterations; exiting damping loop.')
230 | break
231 |
232 | A = JtJ + lamb * R
233 | x_delta = torch.linalg.solve(A, b)
234 | if torch.isnan(x_delta).sum():
235 | print('x_delta NaN; Exiting damping loop')
236 | break
237 | x_delta_s = torch.split(x_delta, Ns)
238 |
239 | # Reshape if x is not a 1D array
240 | x_delta_s = [*x_delta_s]
241 | for xi in range(len(x_delta_s)):
242 | x_delta_s[xi] = torch.reshape(x_delta_s[xi], NS[xi])
243 |
244 | # Update `x += x_delta` (this is done in external function `change_parameters`)
245 | self.diff_variables = change_parameters(x_delta_s, sign=True)
246 |
247 | # Calculate new error
248 | with torch.no_grad():
249 | L_current = torch.mean(func_yref_y(func())**2).item()
250 |
251 | del A
252 |
253 | # Terminate
254 | if L_current < L:
255 | lamb /= self.mu
256 | del x_delta_s
257 | break
258 |
259 | # Else, increase damping and undo the update
260 | lamb *= 2.0*self.mu
261 | self.diff_variables = change_parameters(x_delta_s, sign=False)
262 |
263 | if lamb > 1e16:
264 | print('lambda too big; Exiting damping loop.')
265 | del x_delta_s
266 | break
267 |
268 | del JtJ, R, b
269 |
270 | if record:
271 | x_increment = torch.mean(torch.abs(x_delta)).item()
272 | print('iter = {}: loss = {:.4e}, |x_delta| = {:.4e}'.format(
273 | it, L, x_increment
274 | ))
275 | ls.append(L)
276 | if it > 0:
277 | dls = np.abs(ls[-2] - L)
278 | if dls < 1e-8:
279 | print("|\Delta loss| = {:.4e} < 1e-8; Exiting LM loop.".format(dls))
280 | break
281 |
282 | if x_increment < 1e-8:
283 | print("|x_delta| = {:.4e} < 1e-8; Exiting LM loop.".format(x_increment))
284 | break
285 | return ls
286 |
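# Hedged standalone sketch (toy linear problem, made-up sizes; not used by the classes above):
# one damped LM update, with the tall Jacobian assembled column by column via the same
# double-backward (jvp) trick that LM.jacobian() performs through F._autograd_grad.
def _lm_step_sketch(lamb=1e-1):
    import torch
    x = torch.tensor([0.5, -0.3, 1.0], requires_grad=True)      # N = 3 parameters
    A = torch.randn(100, 3)                                       # M = 100 >> N residual rows
    y_ref = torch.randn(100)

    y = A @ x                                                     # y = f(x); here f is linear, so J = A
    u = torch.zeros_like(y, requires_grad=True)                   # dummy cotangent
    g, = torch.autograd.grad(y, x, u, create_graph=True)          # g = J^T u, linear in u
    J = torch.stack([
        torch.autograd.grad(g, u, v, retain_graph=True)[0]        # J @ e_i = i-th Jacobian column
        for v in torch.eye(3)
    ], dim=-1)
    assert torch.allclose(J, A)

    JtJ = J.T @ J
    R = torch.diag(torch.diag(JtJ).abs())                         # the option == 'diag' regularizer
    b = J.T @ (y_ref - y).detach()                                # b = J^T (y_ref - y)
    return torch.linalg.solve(JtJ + lamb * R, b)                  # damped normal equations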
--------------------------------------------------------------------------------
/diffmetrology/utils.py:
--------------------------------------------------------------------------------
1 | from .scene import *
2 | from .solvers import *
3 | import scipy
4 | import scipy.optimize
5 | import scipy.io
6 |
7 | from matplotlib.image import imread
8 | import matplotlib.ticker as plticker
9 | import time
10 |
11 |
12 | def var2string(variable):
13 | for name in globals():
14 | if eval(name) == variable:
15 | return name
16 | return ''
17 |
18 | class DiffMetrology(PrettyPrinter):
19 | """
20 |     Top-level class that ties together the calibrated scene, fringe analysis, and optimization.
21 | """
22 | def __init__(self,
23 | calibration_path, rotation_path,
24 | origin_shift, lut_path=None, thetas=None, angles=0.,
25 | scale=1, device=torch.device('cpu')):
26 |
27 | self.device = device
28 | self.MAX_VAL = 2**16 - 1.0 # 16-bit image
29 |
30 | # geometry setup
31 | mat_g = scipy.io.loadmat(calibration_path + 'cams.mat')
32 | cameras = self._init_camera(mat_g, scale)
33 | screen = self._init_screen(mat_g)
34 | self.scene = Scene(cameras, screen, device=device)
35 |
36 | # cache calibration checkerboard image for testing
37 | self._checkerboard = imread(calibration_path + 'checkerboard.png')
38 | self._checkerboard = np.mean(self._checkerboard, axis=-1) # for now we use grayscale
39 | self._checkerboard = np.flip(self._checkerboard, axis=1).copy()
40 | self._checkerboard_wh = np.array([mat_g['w'], mat_g['h']])
41 |
42 | # lensgroup metrology part setup
43 | mat_r = scipy.io.loadmat(rotation_path)
44 | p_rotation = torch.Tensor(np.stack((mat_r['p1'][0], mat_r['p2'][0]), axis=-1)).T
45 | origin_mount = self._compute_mount_geometry(p_rotation*scale, verbose=True)
46 | if thetas is None:
47 | self.scene.lensgroup = Lensgroup(origin_mount, origin_shift, 0.0, 0.0, 0.0, device)
48 | else:
49 | self.scene.lensgroup = Lensgroup(origin_mount, origin_shift, thetas[0], thetas[1], 0.0, device)
50 |
51 | if type(angles) is not list:
52 | angles = [angles]
53 | self.angles = angles
54 |
55 | # load sensor LUT
56 | if lut_path is not None:
57 | self.lut = scipy.io.loadmat(lut_path)['Js'][:,:self.scene.camera_count]
58 | self.bbd = scipy.io.loadmat(lut_path)['bs'].reshape((2,self.scene.camera_count))
59 |
60 | self.ROTATION_ANGLE = -25.0
61 |
62 | # === Utility methods ===
63 | def solve_for_intersections(self, fs_cap, fs_ref, Ts):
64 | """
65 | Obtain the intersection points from phase-shifting images.
66 | """
67 | FR = Fringe()
68 | a_ref, b_ref, psi_ref = FR.solve(fs_ref)
69 | a_cap, b_cap, psi_cap = FR.solve(fs_cap)
70 |
71 | VERBOSE = False
72 | # VERBOSE = True
73 |
74 | def find_center(valid_cap):
75 | x, y = np.argwhere(valid_cap==1).sum(0)/ valid_cap.sum()
76 | return np.array([x, y])
77 |
78 | # get valid map for two cameras
79 | valid_map = []
80 | A = np.mean(a_cap, axis=(0,1))
81 | B = np.mean(b_cap, axis=(0,1))
82 | I = np.abs(np.mean(fs_ref - fs_cap, axis=(0,1)))
83 | for j in range(self.scene.camera_count):
84 | # thres = 0.005
85 | # thres = 0.1
86 | thres = 0.07
87 | valid_ab = (A[j] > thres) & (B[j] > thres) & (I[j] < 2.0*thres)
88 | if VERBOSE:
89 | plt.imshow(valid_ab); plt.show()
90 | label, num_features = scipy.ndimage.label(valid_ab)
91 | if num_features < 2:
92 | label_target = 1
93 | else: # count labels and get the area as the target measurement area
94 |                 counts = np.array([np.count_nonzero(label == i) for i in range(1, num_features+1)]) # labels run 1..num_features (0 is background)
95 |                 label_targets = 1 + np.where((200**2 < counts) & (counts < 500**2))[0] # convert count index -> label value
96 | if len(label_targets) > 1: # find which label target is our lens
97 | Dm = np.inf
98 | for l in label_targets:
99 | c = find_center(label == l)
100 | D = np.abs(self.scene.cameras[0].filmsize/2 - c).sum()
101 | if D < Dm:
102 | Dm = D
103 | label_target = l
104 | else:
105 | label_target = label_targets[0]
106 | V = label == label_target
107 | valid_map.append(V)
108 | valid_map = np.array(valid_map)
109 |
110 | psi_unwrap = FR.unwrap(psi_cap - psi_ref, Ts, valid=valid_map[None,None,...])
111 |
112 | # Given the unwrapped phase, we try to remove the DC term so that |psi_unwrap|^2 is minimized
113 | # NOTE: This method is not yet robust
114 | # remove_DC = False
115 | remove_DC = True
116 | if remove_DC:
117 | k_DC = np.arange(-10,11,1)
118 | for t in range(len(Ts)):
119 | DCs = k_DC * Ts[t]
120 | psi_current = psi_unwrap[:,t,:,...]
121 | psi_with_dc = valid_map[:,None,:,:,None] * (psi_current[...,None] + DCs[None,None,None,None,...])
122 | DC_target = DCs[np.argmin(np.sum(psi_with_dc**2, axis=(2,3)), axis=-1)]
123 | print("t = {}, DC_target =\n{}".format(Ts[t], DC_target))
124 | psi_unwrap[:,t,:,...] += valid_map[None,:,...] * np.transpose(DC_target,(0,1))[...,None,None]
125 |
126 | if VERBOSE:
127 | for t in range(len(Ts)):
128 | for ii in range(2):
129 | plt.figure()
130 | plt.imshow(psi_unwrap[0,t,ii,...], cmap='coolwarm')
131 | plt.show()
132 |
133 | # Convert unit from [pixel number] to [mm]
134 | psi_x = psi_unwrap[0, ...] * self.scene.screen.pixelsize.item()
135 | psi_y = psi_unwrap[1, ...] * self.scene.screen.pixelsize.item()
136 |
137 |         # Average the values across the different Ts
138 | psi_x = np.mean(psi_x, axis=0)
139 | psi_y = np.mean(psi_y, axis=0)
140 |
141 | # Compute intersection points when there are no elements
142 | ps_ref = torch.stack(self.trace(with_element=False, angles=0.0)[0])
143 | ps_ref = valid_map[...,None] * ps_ref[...,0:2].cpu().detach().numpy()
144 |
145 | # Compute final shift (here, the valid map is the valid map of x of the first T)
146 |         # NOTE: here we flip the sequence of (x,y) to match the format used in later processing ...
147 | p = ps_ref - np.stack((psi_y, psi_x), axis=-1)
148 |
149 | # Find valid map centers
150 | xs = []
151 | ys = []
152 | for i in range(len(valid_map)):
153 | x, y = np.argwhere(valid_map[i]==1).sum(0)/ valid_map[i].sum()
154 | xs.append(x)
155 | ys.append(y)
156 | centers = np.stack((np.array(xs), np.array(ys)), axis=-1)
157 | centers = np.fliplr(centers).copy() # flip to fit our style
158 |
159 | return torch.Tensor(p).to(self.device), torch.Tensor(valid_map).bool().to(self.device), torch.Tensor(centers).to(self.device)
160 |
161 |
162 | def simulation(self, sinusoid_path, Ts, i=None, angles=None, to_numpy=False):
163 | """
164 | Render fringe images.
165 |
166 | img.shape = [ len(Ts), 8, self.camera_count, img_size ]
167 | """
168 | print('Simulation ...')
169 |
170 | # cache current screen
171 | screen_org = self.scene.screen
172 | pixelsize = screen_org.pixelsize.item()
173 |
174 | def single_impl(with_element):
175 | """
176 | imgs_all.shape = len(Ts) * [0-8] * camera_count * img_size.
177 | """
178 | imgs_all = []
179 | for T in Ts:
180 | print(f'Now at T = {T}')
181 | img_path = sinusoid_path + 'T=' + str(T) + '/'
182 |
183 | # with elements
184 | imgs = []
185 | for i in range(8):
186 | # read sinusoid images
187 | im = imread(img_path + str(i) + '.png')
188 | im = np.mean(im, axis=-1) # for now we use grayscale
189 |                     im = np.flip(im).copy() # NOTE: flipped because (i) our display is physically rotated by 90 deg,
190 |                                             # and (ii) we keep the XY convention used elsewhere in this code.
191 |
192 | # set screen to be sinusoid patterns
193 | sizenew = pixelsize * np.array(im.shape)
194 | t = np.array([sizenew[0]/2-50.0, sizenew[1]/2-80.0, 0])
195 | self.scene.screen = Screen(Transformation(np.eye(3), t), sizenew, pixelsize, im, self.device)
196 |
197 | # render
198 | tmp = self.scene.render(with_element=with_element, to_numpy=True)
199 | imgs.append(np.array(tmp))
200 | imgs_all.append(np.array(imgs))
201 |
202 | return np.array(imgs_all)
203 |
204 | def single(angle):
205 | self.scene.lensgroup.update(_y=angle)
206 | imgs = single_impl(True)
207 | self.scene.lensgroup.update(_y=-angle)
208 | return imgs
209 |
210 | # we only capture reference once
211 | print('Simulating reference ...')
212 | refs = single_impl(False)
213 |
214 | # here, we measure testing part in each angle
215 | print('Simulating measurements ...')
216 | if angles is None:
217 | """
218 | imgs_rendered.shape = len(angles) * len(Ts) * [0-8] * camera_count * img_size.
219 | """
220 | ims = []
221 | for angle in self.angles:
222 | ims.append(single(angle))
223 | ims = np.array(ims)
224 | else:
225 | print(f'Now at angle = {angles}')
226 | ims = single(angles)
227 |
228 | # revert back to the original screen
229 | self.scene.screen = screen_org
230 |
231 | print('Done ...')
232 | return ims, refs
233 |
234 | def render(self, i=None, with_element=True, mask=None, angles=None, to_numpy=False):
235 | """
236 | Rendering.
237 | """
238 | def single(angle):
239 | self.scene.lensgroup.update(_y=angle)
240 | im = self.scene.render(i, with_element, mask, to_numpy)
241 | self.scene.lensgroup.update(_y=-angle)
242 | return im
243 |
244 | if angles is None:
245 | ims = []
246 | for angle in self.angles:
247 | ims += single(angle)
248 | else:
249 | ims = single(angles)
250 | return ims
251 |
252 | def trace(self, i=None, with_element=True, mask=None, angles=None, to_numpy=False):
253 | """
254 | Perform ray tracing.
255 | """
256 | def single(angle):
257 | self.scene.lensgroup.update(_y=angle)
258 | ps, valid, mask_g = self.scene.trace(i, with_element, mask, to_numpy)
259 | self.scene.lensgroup.update(_y=-angle)
260 | return ps, valid, mask_g
261 |
262 | if angles is None:
263 | ps = []
264 | valid = []
265 | mask_g = []
266 | for angle in self.angles:
267 | pss, valids, mask_gs = single(angle)
268 | ps += pss
269 | valid += valids
270 | mask_g += mask_gs
271 | else:
272 | ps, valid, mask_g = single(angles)
273 | return ps, valid, mask_g
274 |
275 | # =====================
276 |
277 | def to(self, device=torch.device('cpu')):
278 | super().to(device)
279 | self.device = device
280 | self.scene.to(device)
281 |
282 | # === Visualizations ===
283 | def imshow(self, imgs):
284 | N = self.scene.camera_count
285 | self._imshow(imgs[0:N], title='front')
286 | if len(imgs) > N:
287 | self._imshow(imgs[N:2*N], title='back')
288 | plt.show()
289 |
290 | def _imshow(self, imgs, title=''):
291 | ax = plt.subplots(1,len(imgs))[1]
292 | for i, img in enumerate(imgs):
293 | ax[i].imshow(img, cmap='gray')
294 | ax[i].set_title(title + ': camera ' + str(i+1))
295 |
296 | def spot_diagram(self, ps_ref, ps_cap, valid=True, angle=None, with_grid=True):
297 | """
298 | Plot spot diagram.
299 | """
300 | N = self.scene.camera_count
301 | for j, a in enumerate(self.angles):
302 | if angle == a:
303 | i = j
304 |         try:
305 |             i # `i` was set above only if `angle` matched one of self.angles
306 |         except NameError:
307 |             i = 0 # otherwise fall back to the first angle
308 | figure = self._spot_diagram(ps_ref[N*i:N*(i+1)], ps_cap[N*i:N*(i+1)], valid[N*i:N*(i+1)], title=f'angle={angle}', with_grid=with_grid)
309 | figure.suptitle('Spot Diagram')
310 | return figure
311 |
312 | def _spot_diagram(self, ps_ref, ps_cap, valid=True, title='', with_grid=False):
313 | """
314 | Plot spot diagram.
315 | """
316 | figure, ax = plt.subplots(1, self.scene.camera_count)
317 |
318 | def sub_sampling(x):
319 | Ns = [8,8]
320 | return x[::Ns[0],::Ns[1],...]
321 | for i in range(len(ax)):
322 | mask = sub_sampling(valid[i])
323 | ref = sub_sampling(ps_ref[i])[mask].cpu().detach().numpy()
324 | cap = sub_sampling(ps_cap[i])[mask].cpu().detach().numpy()
325 | ax[i].plot(ref[...,0], ref[...,1], 'b.', label='Measurement')
326 | ax[i].plot(cap[...,0], cap[...,1], 'r.', label='Modeled (reprojection)')
327 | ax[i].legend()
328 | ax[i].set_xlabel('[mm]')
329 | ax[i].set_ylabel('[mm]')
330 | ax[i].set_aspect(1)
331 | ax[i].set_title(title + ': camera ' + str(i+1))
332 |
333 | # Add the grid
334 | if with_grid:
335 | loc = plticker.MultipleLocator(base=4*self.scene.screen.pixelsize.item())
336 | ax[i].xaxis.set_major_locator(loc)
337 | ax[i].yaxis.set_major_locator(loc)
338 | ax[i].grid(which='major', axis='both', linestyle='-')
339 | ax[i].tick_params(axis='both', which='minor', width=0)
340 | return figure
341 |
342 | def generate_grid(self, R):
343 | N = 513
344 | x = y = torch.linspace(-R, R, N, device=self.device)
345 | X, Y = torch.meshgrid(x, y)
346 | valid = X**2 + Y**2 <= R**2
347 | return X, Y, valid
348 |
349 | def show_surfaces(self, verbose=True):
350 | if verbose:
351 | ax = plt.subplots(1, len(self.scene.lensgroup.surfaces))[1]
352 | Zs = []
353 | valids = []
354 | for i, surface in enumerate(self.scene.lensgroup.surfaces):
355 | X, Y, valid = self.generate_grid(surface.r)
356 | Z = surface.surface(X, Y)
357 | Z_mean = Z[valid].mean().item()
358 | Z = torch.where(valid, Z - Z_mean, torch.zeros_like(Z)).cpu().detach().numpy()
359 | valids.append(valid)
360 | Zs.append(Z)
361 | if verbose:
362 | im = ax[i].imshow(Z, cmap='jet')
363 | ax[i].set_title('surface ' + str(i))
364 | plt.colorbar(im, ax=ax[i])
365 | if verbose:
366 | plt.show()
367 | return Zs, valids
368 |
369 | def print_surfaces(self):
370 |         if self.scene.lensgroup is None:
371 | print('No surfaces found; Please initialize lensgroup!')
372 | else:
373 | for i, s in enumerate(self.scene.lensgroup.surfaces):
374 | print("surface[{}] = {}".format(i, s))
375 |
376 | # =====================
377 |
378 | # === Optimizations ===
379 | def change_parameters(self, diff_parameters_names, xs, sign=True):
380 | diff_parameters = []
381 | for i, name in enumerate(diff_parameters_names):
382 | if sign:
383 | exec('self.scene.{name} = self.scene.{name} + xs[{i}]'.format(name=name,i=i))
384 | else:
385 | exec('self.scene.{name} = self.scene.{name} - xs[{i}]'.format(name=name,i=i))
386 | exec('diff_parameters.append(self.scene.{})'.format(name))
387 | return diff_parameters
388 |
389 | def solve(self, diff_parameters_names, forward, loss, func_yref_y=None, option='LM', R=None):
390 | """
391 | Solve for unknown parameters.
392 | """
393 | # def loss(I):
394 | # return (I - I0).mean()
395 |
396 | # def func_yref_y(I):
397 | # return I0 - I
398 |
399 | time_start = time.time()
400 |
401 | diff_parameters = []
402 | for name in diff_parameters_names:
403 | try:
404 | exec('self.scene.{}.requires_grad = True'.format(name))
405 | except:
406 | exec('self.scene.{name} = self.scene.{name}.detach()'.format(name=name))
407 | exec('self.scene.{}.requires_grad = True'.format(name))
408 | exec('diff_parameters.append(self.scene.{})'.format(name))
409 |
410 | if option == 'Adam':
411 | O = Adam(
412 | diff_variables=diff_parameters,
413 | lr=1e-1,
414 | beta=0.99
415 | )
416 | ls = O.optimize(
417 | forward,
418 | loss,
419 | maxit=200
420 | )
421 |
422 | elif option == 'LM':
423 | if func_yref_y is None:
424 | raise Exception("func_yref_y is not given!")
425 |
426 | if R is None:
427 | Ropt = 'I'
428 | else:
429 | Ropt = R
430 | O = LM(
431 | diff_variables=diff_parameters,
432 | lamb=1e-1, # 1e-4
433 | option=Ropt
434 | )
435 | ls = O.optimize(
436 | forward,
437 | lambda xs, sign: self.change_parameters(diff_parameters_names, xs, sign),
438 | func_yref_y=func_yref_y,
439 | maxit=100
440 | )
441 |
442 | else:
443 | raise NotImplementedError()
444 |
445 | for name in diff_parameters_names:
446 | print('self.scene.{} = '.format(name), end='')
447 | exec('print(self.scene.{}.cpu().detach().numpy())'.format(name))
448 |
449 | torch.cuda.synchronize()
450 | time_end = time.time()
451 |
452 | print('Elapsed time = {:e} seconds'.format(time_end - time_start))
453 |
454 | return ls
455 | # =====================
456 |
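    # Hedged usage sketch for solve() above (illustrative names and lambdas; y_ref and the chosen
    # attribute path are placeholders defined by the caller, not taken from the demo scripts):
    #
    #   ls = DM.solve(
    #       diff_parameters_names=['lensgroup.surfaces[0].c'],  # attribute paths under self.scene
    #       forward=lambda: DM.trace(angles=0.0)[0][0],         # y = f(x)
    #       loss=lambda y: ((y - y_ref)**2).mean(),             # used by option='Adam'
    #       func_yref_y=lambda y: y_ref - y,                    # used by option='LM'
    #       option='LM',
    #   )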
457 |
458 |     # === Differentiable parameters and evaluation ===
459 | def init_diff_parameters(self, dicts=None, pose_dict=None):
460 | self.diff_parameters = {}
461 | print('Initializing differentiable parameters:')
462 | if dicts is not None:
463 | for i, dictionary in enumerate(dicts):
464 | if dictionary is None:
465 | continue # skip if there is none
466 | if type(dictionary) is dict:
467 | for key, value in dictionary.items():
468 | keystr = 'surfaces[{}].{}'.format(i, key)
469 | full_keystr = 'self.scene.lensgroup.' + keystr
470 | print('--- ' + full_keystr)
471 | self.diff_parameters[keystr] = torch.Tensor(np.asarray(value)).to(self.device)
472 | exec('{} = self.diff_parameters[keystr].clone()'.format(full_keystr))
473 | exec('{}.requires_grad = True'.format(full_keystr))
474 | elif type(dictionary) is set:
475 | for key in dictionary:
476 | keystr = 'surfaces[{}].{}'.format(i, key)
477 | full_keystr = 'self.scene.lensgroup.' + keystr
478 | print('--- ' + full_keystr)
479 | exec('self.diff_parameters[keystr] = {}.clone()'.format(full_keystr))
480 | exec('{}.requires_grad = True'.format(full_keystr))
481 | else:
482 | raise Exception("wrong type dicts!")
483 | if pose_dict is not None:
484 | if type(pose_dict) is dict:
485 | for keystr, value in pose_dict.items():
486 | full_keystr = 'self.scene.lensgroup.' + keystr
487 | print('--- ' + full_keystr)
488 | self.diff_parameters[keystr] = torch.Tensor(np.asarray(value)).to(self.device)
489 | exec('{} = self.diff_parameters[keystr].clone()'.format(full_keystr))
490 | exec('{}.requires_grad = True'.format(full_keystr))
491 | elif type(pose_dict) is set:
492 | for keystr in pose_dict:
493 | full_keystr = 'self.scene.lensgroup.' + keystr
494 | print('--- ' + full_keystr)
495 | exec('self.diff_parameters[keystr] = {}.clone()'.format(full_keystr))
496 | exec('{}.requires_grad = True'.format(full_keystr))
497 | else:
498 | raise Exception("wrong type dicts!")
499 | print('... Done.')
500 |
501 | def print_diff_parameters(self):
502 | for key in self.diff_parameters.keys():
503 | full_keystr = 'self.scene.lensgroup.' + key
504 | exec("print('-- {{}} = {{}}'.format(key, {}.cpu().detach().numpy()))".format(full_keystr))
505 |
506 | @staticmethod
507 | def compRMS(Zs, Zs_gt, valid):
508 | def RMS(x, y, valid):
509 | x = x[valid]
510 | y = y[valid]
511 | x -= x.mean()
512 | y -= y.mean()
513 | tmp = (x - y)**2
514 | return np.sqrt(np.mean(tmp))
515 |
516 | rmss = []
517 | for i in range(len(Zs_gt)):
518 | rms = RMS(Zs[i], Zs_gt[i], valid[i].cpu().detach().numpy())
519 | print('RMS = {} [mm]'.format(rms))
520 | rmss.append(rms)
521 | return rmss
522 |
523 | def ploterror(self, Zs, Zs_gt, verbose=True):
524 | figure, ax = plt.subplots(1, len(self.scene.lensgroup.surfaces), figsize=(9,3.5))
525 | for i, s in enumerate(self.scene.lensgroup.surfaces):
526 | tmp = Zs[i] - Zs_gt[i]
527 | im = ax[i].imshow(tmp, cmap='jet')
528 | ax[i].set_title('surface ' + str(i))
529 | plt.colorbar(im, ax=ax[i])
530 | if verbose:
531 | plt.show()
532 | return figure
533 |
534 | # =====================
535 |
536 | def set_texture(self, textures):
537 | if len(textures.shape) > 2:
538 | texture = textures[0]
539 | else:
540 | texture = textures
541 | pixelsize = self.scene.screen.pixelsize.item()
542 | sizenew = pixelsize * np.array(texture.shape)
543 | t = np.zeros(3)
544 | self.scene.screen = Screen(Transformation(np.eye(3), t), sizenew, pixelsize, np.flip(texture).copy(), self.device)
545 |
546 | # === Tests ===
547 | def test_setup(self, verbose=True):
548 | """
549 | Test if the setup is correct: check render and measurement images for consistency
550 | """
551 | # cache current screen
552 | screen_org = self.scene.screen
553 | lensgroup = self.scene.lensgroup
554 | self.scene.lensgroup = None
555 |
556 | # set screen to be checkerboard
557 | pixelsize = screen_org.pixelsize.item()
558 | sizenew = pixelsize * np.array(self._checkerboard.shape)
559 | t = np.array([sizenew[0]/2, sizenew[1]/2, 0])
560 | t[0] -= sizenew[0]/self._checkerboard_wh[0]
561 | t[1] -= sizenew[1]/self._checkerboard_wh[1]
562 | self.scene.screen = Screen(Transformation(np.eye(3), t), sizenew, pixelsize, self._checkerboard, self.device)
563 |
564 | # render
565 | imgs_rendered = self.scene.render()
566 | if verbose:
567 | self.scene.plot_setup()
568 | plt.show()
569 |
570 | # revert back to the original screen
571 | self.scene.screen = screen_org
572 | self.scene.lensgroup = lensgroup
573 | return imgs_rendered
574 | # =====================
575 |
576 | # === Internal methods ===
577 | # parse cameras and screen parameters
578 | def _init_camera(self, mat, scale=1):
579 | filmsize = (mat['filmsize'] * scale).astype(int)
580 | f = mat['f'] * scale
581 | c = mat['c'] * scale
582 | k = mat['k'] * scale
583 | p = mat['p'] * scale
584 | R = mat['R']
585 | t = mat['t']
586 | def matlab2ours(R, t):
587 | """Convert MATLAB [R | t] convention to ours:
588 | In MATLAB (https://www.mathworks.com/help/vision/ug/camera-calibration.html):
589 | w_scale [x_obj y_obj 1] = [x_world y_world z_world] R_matlab + t_matlab
590 | [ 1x3 ] [ 1x3 ] [3x3] [1x3]
591 |
592 | Ours:
593 | [x_obj y_obj 1]' = R_object [x_world y_world z_world]' + t_object
594 | [ 3x1 ] [ 3x3 ] [ 3x1 ] [3x1]
595 |
596 | We would like to get the world-coordinate [R | t]. This requires:
597 | R_world = R_object.T
598 | t_world = -R_object.T @ t_object
599 |
600 | where we conclude R_world = R_matlab.
601 | """
602 | return R, -R @ t
603 | return [Camera(
604 | Transformation(*matlab2ours(R[...,i], t[...,i])),
605 | filmsize[i], f[i], c[i], k[i], p[i], self.device
606 | ) for i in range(len(f))]
607 |
608 | def _init_screen(self, mat):
609 | pixelsize = 1e-3 * mat['display_pixel_size'][0][0] # in [mm]
610 | size = pixelsize * np.array([1600, 2560]) # in [mm] (MacBook Pro 13.3")
611 | im = imread('./imgs/checkerboard.png')
612 | im = np.mean(im, axis=-1) # for now we use grayscale
613 |
614 | return Screen(Transformation(np.eye(3), np.zeros(3)), size, pixelsize, im, self.device)
615 |
616 | def _compute_mount_geometry(self, p_rotation, verbose=True):
617 | """
618 | We would like to estimate the intersection between two lines:
619 |
620 | L1 = o1 + t1 d1
621 | L2 = o2 + t2 d2
622 |
623 | where we need to solve for t1 and t2, by an over-determined least-squares
624 | (know R^3, solve for R^2):
625 |
626 | min || L1 - L2 ||^2 = || (o1-o2) + [d1,-d2] [t1;t2] ||^2
627 | t1,t2 [3x1] [3x2] [2x1]
628 |
629 | =>
630 |
631 | min || o + d t ||^2
632 | t [3x1] [3x2] [2x1]
633 |
634 | whose solution is t = (d.T d)^{-1} d.T (-o).
635 | """
636 | N = self.scene.camera_count
637 | rays = [self.scene.cameras[i].sample_ray(
638 | p_rotation[i][None,None,...].to(self.device), is_sampler=False) for i in range(N)]
639 |
640 | t, r = np.linalg.lstsq(
641 | torch.stack((rays[0].d, -rays[1].d), axis=-1).cpu().detach().numpy(),
642 | -(rays[0].o - rays[1].o).cpu().detach().numpy(), rcond=None
643 | )[0:2]
644 |
645 | t_pt = torch.Tensor(t).to(self.device)
646 | os = [rays[i](t_pt[i]) for i in range(N)]
647 | if verbose:
648 | for i, o in enumerate(os):
649 | print('intersection point {}: {}'.format(i, o))
650 | print('|intersection points distance error| = {} mm'.format(np.sqrt(r[0])))
651 | return torch.mean(torch.stack(os), axis=0).cpu().detach().numpy()
652 | # =====================
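# Hedged standalone sketch of the least-squares two-line "intersection" used in
# _compute_mount_geometry above (illustrative numbers, plain NumPy):
def _line_intersection_sketch():
    import numpy as np
    o1, d1 = np.array([0., 0., 0.]), np.array([0., 0., 1.])   # line 1: the z-axis
    o2, d2 = np.array([1., 0., 5.]), np.array([-1., 0., 0.])  # line 2: y = 0, z = 5, along -x
    d = np.stack((d1, -d2), axis=-1)                           # 3x2 system [d1, -d2] t = o2 - o1
    t = np.linalg.lstsq(d, -(o1 - o2), rcond=None)[0]
    p1, p2 = o1 + t[0]*d1, o2 + t[1]*d2                        # closest points on each line
    return 0.5*(p1 + p2)                                       # -> [0, 0, 5], the intersection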
--------------------------------------------------------------------------------
/imgs/checkerboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/checkerboard.png
--------------------------------------------------------------------------------
/imgs/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/loss.png
--------------------------------------------------------------------------------
/imgs/results_initial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/results_initial.png
--------------------------------------------------------------------------------
/imgs/results_optimized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/results_optimized.png
--------------------------------------------------------------------------------
/imgs/setup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/setup.png
--------------------------------------------------------------------------------
/imgs/spot_diagram_initial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/spot_diagram_initial.png
--------------------------------------------------------------------------------
/imgs/spot_diagram_optimized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/spot_diagram_optimized.png
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/teaser.jpg
--------------------------------------------------------------------------------
/lenses/ThorLabs/LE1234-A.txt:
--------------------------------------------------------------------------------
1 | Thorlabs-LE1234
2 | type distance roc diameter material
3 | O 0 0 0 AIR
4 | S 0 -82.23 25.4 N-BK7
5 | S 3.59 -32.14 25.4 AIR
6 | I 95.9 0 25.4 AIR
7 |
--------------------------------------------------------------------------------
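The lens file above is a whitespace-separated prescription table with the header `type distance roc diameter material`. The repository has its own reader for these files (not restated here); the sketch below is only an illustrative parse, assuming `O`/`S`/`I` mark the object plane, refractive surfaces, and image plane, with distances and radii of curvature (roc, 0 meaning planar) given in mm.

    from dataclasses import dataclass

    @dataclass
    class Surface:
        kind: str        # 'O', 'S' or 'I' (assumed: object plane, surface, image plane)
        distance: float  # assumed: spacing from the previous surface [mm]
        roc: float       # assumed: radius of curvature [mm], 0 for planar
        diameter: float  # clear diameter [mm]
        material: str    # material following the surface, e.g. 'N-BK7' or 'AIR'

    def read_lensfile(path):
        with open(path) as f:
            lines = [ln.strip() for ln in f if ln.strip()]
        surfaces = []
        for ln in lines[2:]:  # skip the name line and the column-header line
            kind, distance, roc, diameter, material = ln.split()
            surfaces.append(Surface(kind, float(distance), float(roc), float(diameter), material))
        return surfaces

    # e.g. read_lensfile('./lenses/ThorLabs/LE1234-A.txt')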
/metrology_calibrate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | import diffmetrology as dm
5 | from matplotlib.image import imread
6 |
7 | # load setup information
8 | data_path = './20210403'
9 | device = dm.init()
10 | # device = torch.device('cpu')
11 |
12 | print("Initialize a DiffMetrology object.")
13 | origin_shift = np.array([0.0, 0.0, 0.0])
14 | DM = dm.DiffMetrology(
15 | calibration_path = data_path + '/calibration/',
16 | rotation_path = data_path + '/rotation_calibration/rotation.mat',
17 | lut_path = data_path + '/gamma_calibration/gammas.mat',
18 | origin_shift = origin_shift,
19 | scale=1.0,
20 | device=device
21 | )
22 |
23 | print("Crop the region of interst in the original images.")
24 | filmsize = np.array([1024, 1024])
25 | crop_offset = ((2048 - filmsize)/2).astype(int)
26 | for cam in DM.scene.cameras:
27 | cam.filmsize = filmsize
28 | cam.crop_offset = torch.Tensor(crop_offset).to(device)
29 | def crop(x):
30 | return x[..., crop_offset[0]:crop_offset[0]+filmsize[0], crop_offset[1]:crop_offset[1]+filmsize[1]]
31 |
32 |
33 | # ==== Read measurements
34 | lens_name = 'LE1234-A'
35 |
36 | # load data
37 | data = np.load(data_path + '/measurement/' + lens_name + '/data_new.npz')
38 | refs = data['refs']
39 | refs = crop(refs)
40 | del data
41 |
42 | Ts = np.array([70, 100, 110]) # periods of the displayed sinusoid patterns (cf. images/sinusoids/T=*)
43 | t = 0 # index into Ts: which period to use
44 |
45 | # change display pattern
46 | xs = [0]
47 | sinusoid_path = './camera_acquisitions/images/sinusoids/T=' + str(Ts[t])
48 | ims = [ np.mean(imread(sinusoid_path + '/' + str(x) + '.png'), axis=-1) for x in xs ] # for now we use grayscale
49 | ims = np.array([ im/im.max() for im in ims ])
50 | ims = np.sum(ims, axis=0)
51 | DM.set_texture(ims)
52 | ims = torch.Tensor(ims).to(device)
53 |
54 | # reference image
55 | I0 = torch.Tensor(np.array([refs[t,x,...] for x in xs])).to(device)
56 | I0 = torch.sum(I0, axis=0)
57 |
58 | # define the forward model, the loss, and the residual y_ref - y used by DM.solve
59 | def forward():
60 | I = torch.stack(DM.render(with_element=False, angles=0.0))
61 | return I #/ I.max() * I0.max()
62 | def loss(I):
63 | return (I - I0).mean()
64 | def func_yref_y(I):
65 | return I0 - I
66 |
67 | def show_img(I, string):
68 | fig = plt.figure()
69 | plt.imshow(I[0].cpu().detach(), vmin=0, vmax=1, cmap='gray')
70 | plt.colorbar()
71 | plt.title(string)
72 | plt.axis('off')
73 | fig.savefig("img_" + string + ".jpg", bbox_inches='tight')
74 |
75 | def show_error(I, string):
76 | fig = plt.figure()
77 | plt.imshow(I[0].cpu().detach() - I0[0].cpu(), vmin=-1, vmax=1, cmap='coolwarm')
78 | plt.colorbar()
79 | plt.title(string)
80 | plt.axis('off')
81 | fig.savefig("photo_" + string + ".jpg", bbox_inches='tight')
82 |
83 |
84 | # initialize parameters
85 | DM.scene.screen.texture_shift = torch.Tensor([0.0, 0.0]).to(device)
86 |
87 | # parameters
88 | diff_names = ['screen.texture_shift']
89 |
90 | # initial
91 | I = forward()
92 | show_img(I0, 'Measurement')
93 | show_img(I, 'Modeled')
94 | show_error(I, 'Initial')
95 |
96 | # optimize
97 | ls = DM.solve(diff_names, forward, loss, func_yref_y, option='LM')
98 |
99 | # plot loss
100 | plt.figure()
101 | plt.semilogy(ls, '-o', color='k')
102 | plt.xlabel('LM iteration')
103 | plt.ylabel('Loss')
104 | plt.title("Opitmization Loss")
105 |
106 | I = forward()
107 | show_img(I0, 'Measurement')
108 | show_img(I, 'Modeled')
109 | show_error(I, 'Optimized')
110 |
111 | plt.show()
112 |
--------------------------------------------------------------------------------
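In `metrology_calibrate.py`, `DM.solve(diff_names, forward, loss, func_yref_y, option='LM')` optimizes `screen.texture_shift`, with `func_yref_y` supplying the residual y_ref - y. The actual solver lives in `diffmetrology/solvers.py` and is not restated here; assuming `'LM'` selects a Levenberg-Marquardt style solver, the toy sketch below shows the damped Gauss-Newton update on a 1-D stand-in problem (recovering the shift of a sinusoid). All names and values are illustrative only, not the repository's implementation.

    import numpy as np

    x = np.linspace(0.0, 1.0, 200)
    T = 0.1
    shift_true = np.array([0.0123])
    y_ref = 0.5 + 0.5 * np.cos(2 * np.pi * (x - shift_true) / T)   # "measurement"

    def render(shift):                 # stand-in forward model
        return 0.5 + 0.5 * np.cos(2 * np.pi * (x - shift) / T)

    def residual(shift):               # plays the role of func_yref_y: y_ref - y
        return y_ref - render(shift)

    shift = np.array([0.0])            # initial guess
    lam = 1e-3                         # LM damping factor
    for _ in range(20):
        r = residual(shift)
        eps = 1e-6                     # finite-difference Jacobian of the residual
        J = np.stack([(residual(shift + eps * e) - r) / eps for e in np.eye(len(shift))], axis=-1)
        A = J.T @ J + lam * np.eye(len(shift))    # damped normal equations
        dx = np.linalg.solve(A, -J.T @ r)
        if np.linalg.norm(residual(shift + dx)) < np.linalg.norm(r):
            shift, lam = shift + dx, lam * 0.1    # accept step, reduce damping
        else:
            lam *= 10.0                           # reject step, increase damping

    print(shift, shift_true)           # recovered shift should approach shift_true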