├── 20210403
│   ├── README.md
│   ├── calibration
│   │   ├── calibration.mat
│   │   ├── cam1.mat
│   │   ├── cam2.mat
│   │   ├── cams.mat
│   │   └── checkerboard.png
│   ├── gamma_calibration
│   │   └── gammas.mat
│   └── rotation_calibration
│       └── rotation.mat
├── README.md
├── camera_acquisitions
│   └── images
│       ├── calibration.png
│       └── sinusoids
│           ├── T=100
│           │   ├── 0.png
│           │   ├── 1.png
│           │   ├── 2.png
│           │   ├── 3.png
│           │   ├── 4.png
│           │   ├── 5.png
│           │   ├── 6.png
│           │   └── 7.png
│           ├── T=110
│           │   ├── 0.png
│           │   ├── 1.png
│           │   ├── 2.png
│           │   ├── 3.png
│           │   ├── 4.png
│           │   ├── 5.png
│           │   ├── 6.png
│           │   └── 7.png
│           └── T=70
│               ├── 0.png
│               ├── 1.png
│               ├── 2.png
│               ├── 3.png
│               ├── 4.png
│               ├── 5.png
│               ├── 6.png
│               └── 7.png
├── demo_experiments.py
├── diffmetrology
│   ├── __init__.py
│   ├── basics.py
│   ├── optics.py
│   ├── scene.py
│   ├── shapes.py
│   ├── solvers.py
│   └── utils.py
├── imgs
│   ├── checkerboard.png
│   ├── loss.png
│   ├── results_initial.png
│   ├── results_optimized.png
│   ├── setup.png
│   ├── spot_diagram_initial.png
│   ├── spot_diagram_optimized.png
│   └── teaser.jpg
├── lenses
│   └── ThorLabs
│       └── LE1234-A.txt
└── metrology_calibrate.py


/20210403/README.md:
--------------------------------------------------------------------------------
Camera 1 (GS3-U3-50S5C):

- f# = 16
- mode 0 (2048x2048)
- mono16

Camera 2 (GS3-U3-50S5M):

- f# = 16
- mode 0 (2048x2048)
- mono16

--------------------------------------------------------------------------------
/20210403/calibration/calibration.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/calibration.mat
--------------------------------------------------------------------------------
/20210403/calibration/cam1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/cam1.mat
--------------------------------------------------------------------------------
/20210403/calibration/cam2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/cam2.mat
--------------------------------------------------------------------------------
/20210403/calibration/cams.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/cams.mat
--------------------------------------------------------------------------------
/20210403/calibration/checkerboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/calibration/checkerboard.png
--------------------------------------------------------------------------------
/20210403/gamma_calibration/gammas.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/gamma_calibration/gammas.mat
--------------------------------------------------------------------------------
/20210403/rotation_calibration/rotation.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/20210403/rotation_calibration/rotation.mat
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Towards self-calibrated lens metrology by differentiable refractive deflectometry
This is the PyTorch implementation for our paper "Towards self-calibrated lens metrology by differentiable refractive deflectometry".
### [Project Page](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/) | [Paper](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/Wang2021DiffDeflectometry.pdf)

[Towards self-calibrated lens metrology by differentiable refractive deflectometry](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/Wang2021DiffDeflectometry.pdf)
[Congli Wang](https://congliwang.github.io),
[Ni Chen](https://ni-chen.github.io), and
[Wolfgang Heidrich](https://vccimaging.org/People/heidriw)
King Abdullah University of Science and Technology (KAUST)
OSA Optics Express 2021


Figure: Dual-camera refractive deflectometry for lens metrology. (a) Hardware setup. (b) Captured phase-shifted images, from which the on-screen intersections are obtained. (c) A ray tracer models the setup by tracing rays through each parameterized refractive surface, which gives the modeled intersections. (d) The unknown lens parameters and pose are jointly optimized by minimizing the error between the measured and modeled intersections.

## Features

This repository implements a differentiable ray tracer for deflectometry in PyTorch. The solver enables:
- Fast (a few seconds on a GPU), simultaneous estimation of lens parameters and pose, using gradients computed by the differentiable ray tracer.
- Rendering of photo-realistic images from a CMM-free, computationally calibrated metrology setup.
- A fringe analyzer that solves for displacements from phase-shifting patterns.

## Quick Start

### Real Experiment Example

[`demo_experiments.py`](./demo_experiments.py) reproduces one of the paper's experimental results. Prerequisites:

- Download the raw image data (`*.npz`) from Google Drive [here](https://drive.google.com/file/d/15a3T0wL7sWDaEXsAeZM0S7YXRS-fXRNO/view?usp=sharing).
- Put the downloaded `*.npz` file (the demo script loads `data_new.npz`) into the directory `./20210403/measurement/LE1234-A`.

Then run [`demo_experiments.py`](./demo_experiments.py). The script should output the following figures:

| ![](./imgs/setup.png) | ![](./imgs/loss.png) |
| :---------------------------------: | :-------------------------------------------: |
| The physical setup for experiments. | Optimization loss with respect to iterations. |

| ![](./imgs/spot_diagram_initial.png) |
| :---------------------------------------: |
| Spot diagrams on the display (initial). |
| ![](./imgs/spot_diagram_optimized.png) |
| Spot diagrams on the display (optimized). |

| ![](./imgs/results_initial.png) |
| :-------------------------------------------------------: |
| Measured images / modeled images / error (initial). |
| ![](./imgs/results_optimized.png) |
| Measured images / modeled images / error (optimized). |

## Citation
```bibtex
@article{wang2021towards,
  title={Towards self-calibrated lens metrology by differentiable refractive deflectometry},
  author={Wang, Congli and Chen, Ni and Heidrich, Wolfgang},
  journal={Optics Express},
  volume={29},
  number={19},
  pages={30284--30295},
  year={2021},
  publisher={Optical Society of America}
}
```

## Contact
For questions, please open an issue or contact Congli Wang.

--------------------------------------------------------------------------------
/camera_acquisitions/images/calibration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/calibration.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/0.png
--------------------------------------------------------------------------------
/camera_acquisitions/images/sinusoids/T=100/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/1.png
--------------------------------------------------------------------------------
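According to the README, the repository includes a fringe analyzer that solves for displacements from phase-shifting patterns. Each `camera_acquisitions/images/sinusoids/T=*` folder holds eight pattern images (`0.png` to `7.png`). The NumPy sketch below shows the standard N-step phase-shifting formulas, which recover the per-pixel offset `a`, the modulation `b`, and the wrapped phase `psi` from such an image stack. It assumes the frames follow `I_k = a + b*cos(psi + 2*pi*k/N)` with equally spaced shifts. That spacing and the helper name `n_step_phase_shift` are illustrative assumptions. This is not the repository's `dm.Fringe.solve`.

```python
import numpy as np


def n_step_phase_shift(imgs):
    """Hypothetical helper: per-pixel offset, modulation and wrapped phase.

    Assumes imgs has shape (N, ...) with N >= 3 and
    I_k = a + b * cos(psi + 2*pi*k/N) for k = 0, ..., N-1.
    """
    n = imgs.shape[0]
    deltas = 2 * np.pi * np.arange(n) / n            # assumed equally spaced phase shifts
    s = np.tensordot(np.sin(deltas), imgs, axes=1)   # equals -b * (N/2) * sin(psi)
    c = np.tensordot(np.cos(deltas), imgs, axes=1)   # equals  b * (N/2) * cos(psi)
    a = imgs.mean(axis=0)                            # the cosine terms cancel, leaving the offset
    b = 2.0 / n * np.sqrt(s**2 + c**2)
    psi = np.arctan2(-s, c)                          # wrapped to [-pi, pi]
    return a, b, psi


# Synthetic check: eight shifted 1D fringes with a period of T = 100 pixels.
T = 100
x = np.arange(256) + 0.25
frames = np.stack([0.5 + 0.4 * np.cos(2 * np.pi * x / T + 2 * np.pi * k / 8)
                   for k in range(8)])
a, b, psi = n_step_phase_shift(frames)
phase_err = np.angle(np.exp(1j * (psi - 2 * np.pi * x / T)))  # wrap the difference
print(np.allclose(a, 0.5), np.allclose(b, 0.4), np.allclose(phase_err, 0.0, atol=1e-9))
# prints: True True True
```

Phase unwrapping and the conversion from phase to on-screen displacement are not covered by this sketch.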
/camera_acquisitions/images/sinusoids/T=100/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/2.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=100/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/3.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=100/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/4.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=100/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/5.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=100/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/6.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=100/7.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=100/7.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/0.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/1.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/2.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/3.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/4.png 
-------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/5.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/6.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=110/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=110/7.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/0.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/1.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/2.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/3.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/4.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/5.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/6.png -------------------------------------------------------------------------------- /camera_acquisitions/images/sinusoids/T=70/7.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/camera_acquisitions/images/sinusoids/T=70/7.png
--------------------------------------------------------------------------------
/demo_experiments.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import matplotlib.pyplot as plt
import diffmetrology as dm
from matplotlib.image import imread

# load setup information
data_path = './20210403'
device = dm.init()
# device = torch.device('cpu')

print("Initialize a DiffMetrology object.")
origin_shift = np.array([0.0, 0.0, 0.0])
DM = dm.DiffMetrology(
    calibration_path = data_path + '/calibration/',
    rotation_path = data_path + '/rotation_calibration/rotation.mat',
    lut_path = data_path + '/gamma_calibration/gammas.mat',
    origin_shift = origin_shift,
    scale=1.0,
    device=device
)

print("Crop the region of interest in the original images.")
filmsize = np.array([768, 768])
# filmsize = np.array([2048, 2048])
crop_offset = ((2048 - filmsize)/2).astype(int)
for cam in DM.scene.cameras:
    cam.filmsize = filmsize
    cam.crop_offset = torch.Tensor(crop_offset).to(device)
def crop(x):
    return x[..., crop_offset[0]:crop_offset[0]+filmsize[0], crop_offset[1]:crop_offset[1]+filmsize[1]]

DM.test_setup()

# ==== Read measurements
lens_name = 'LE1234-A'

DM.scene.lensgroup.load_file('Thorlabs/' + lens_name + '.txt')

def show_parameters():
    for i in range(len(DM.scene.lensgroup.surfaces)):
        print(f"Lens radius of curvature at surface[{i}]: {1.0/DM.scene.lensgroup.surfaces[i].c.item()}")
    print(DM.scene.lensgroup.surfaces[1].d)


print("Ground Truth Lens Parameters:")
show_parameters()


angle = 0.0
Ts = np.array([70, 100, 110]) # period of the sinusoids
t = 0

# load data
option = 'experiment'
if option == 'experiment':
    data = np.load(data_path + '/measurement/' + lens_name + '/data_new.npz')
    imgs = data['imgs']
    refs = data['refs']
    imgs = crop(imgs)
    refs = crop(refs)
    del data


# solve for ps and valid map
ps_cap, valid_cap, C = DM.solve_for_intersections(imgs, refs, Ts[t:])

# set display pattern
# xs = [0, 4]
xs = [0]
sinusoid_path = './camera_acquisitions/images/sinusoids/T=' + str(Ts[t])
ims = [ np.mean(imread(sinusoid_path + '/' + str(x) + '.png'), axis=-1) for x in xs ] # use grayscale
ims = np.array([ im/im.max() for im in ims ])
ims = np.sum(ims, axis=0)
DM.set_texture(ims)
del ims
if option == 'experiment':
    # Obtained from running `metrology_calibrate.py`
    # DM.scene.screen.texture_shift = torch.Tensor([1.7445182, 1.1107264]).to(device) # LE1234-A
    DM.scene.screen.texture_shift = torch.Tensor([0. , 1.1106231]).to(device) # LE1234-A



print("Shift `origin` by an estimated value")
origin = DM._compute_mount_geometry(C, verbose=True)
DM.scene.lensgroup.origin = torch.Tensor(origin).to(device)
DM.scene.lensgroup.update()
print(origin)


print("Load real images")
FR = dm.Fringe()
a_cap, b_cap, psi_cap = FR.solve(imgs)
imgs_sub = np.array([imgs[0,x,...] for x in xs])
imgs_sub = imgs_sub - a_cap[:,0,...]
imgs_sub = np.sum(imgs_sub, axis=0)
imgs_sub = valid_cap * torch.Tensor(imgs_sub).to(device)
I0 = valid_cap * len(xs) * (imgs_sub - imgs_sub.min().item()) / (imgs_sub.max().item() - imgs_sub.min().item())


# Utility functions
def forward():
    ps = torch.stack(DM.trace(with_element=True, mask=valid_cap, angles=angle)[0])[..., 0:2]
    return ps

def render():
    I = valid_cap*torch.stack(DM.render(with_element=True, angles=angle))
    I[torch.isnan(I)] = 0.0
    return I

def visualize(ps_current, save_string):
    print("Showing spot diagrams at display.")
    DM.spot_diagram(ps_cap, ps_current, valid=valid_cap, angle=angle, with_grid=False)
    plt.show()

    print("Showing images (measurement & modeled & |measurement - modeled|).")

    # Render images from parameters
    I = render()

    fig, axes = plt.subplots(2, 3)
    for i in range(2):
        im = axes[i,0].imshow(I0[i].cpu(), vmin=0, vmax=1, cmap='gray')
        axes[i,0].set_title(f"Camera {i+1}\nMeasurement")
        axes[i,0].set_xlabel('[pixel]')
        axes[i,0].set_ylabel('[pixel]')
        plt.colorbar(im, ax=axes[i,0])

        im = axes[i,1].imshow(I[i].cpu().detach(), vmin=0, vmax=1, cmap='gray')
        plt.colorbar(im, ax=axes[i,1])
        axes[i,1].set_title(f"Camera {i+1}\nModeled")
        axes[i,1].set_xlabel('[pixel]')
        axes[i,1].set_ylabel('[pixel]')

        im = axes[i,2].imshow(I0[i].cpu() - I[i].cpu().detach(), vmin=-1, vmax=1, cmap='coolwarm')
        plt.colorbar(im, ax=axes[i,2])
        axes[i,2].set_title(f"Camera {i+1}\nError")
        axes[i,2].set_xlabel('[pixel]')
        axes[i,2].set_ylabel('[pixel]')

    fig.suptitle(save_string)
    fig.savefig(save_string + str(i) + ".jpg", bbox_inches='tight')
    plt.show()


print("Initialize lens parameters.")
DM.scene.lensgroup.surfaces[0].c = torch.Tensor([0.00]).to(device) # 1st surface curvature
DM.scene.lensgroup.surfaces[1].c = torch.Tensor([0.00]).to(device) # 2nd surface curvature
DM.scene.lensgroup.surfaces[1].d = torch.Tensor([3.00]).to(device) # lens thickness
DM.scene.lensgroup.theta_x = torch.Tensor([0.00]).to(device) # lens X-tilt angle
DM.scene.lensgroup.theta_y = torch.Tensor([0.00]).to(device) # lens Y-tilt angle
DM.scene.lensgroup.update()

print("Visualize initial status.")
ps_current = forward()
visualize(ps_current, save_string="initial")


print("Set optimization parameters.")
diff_names = [
    'lensgroup.surfaces[0].c',
    'lensgroup.surfaces[1].c',
    'lensgroup.surfaces[1].d',
    'lensgroup.origin',
    'lensgroup.theta_x',
    'lensgroup.theta_y'
]
def loss(ps):
    return torch.sum((ps[valid_cap,...] - ps_cap[valid_cap,...])**2, axis=-1).mean()

def func_yref_y(ps):
    b = valid_cap[...,None] * (ps_cap - ps)
    b[torch.isnan(b)] = 0.0 # handle NaN ... otherwise LM won't work!
    return b

# Optimize
ls = DM.solve(diff_names, forward, loss, func_yref_y, option='LM', R='I')
print("Done. Show results (Spot RMS loss):")
show_parameters()

plt.figure()
plt.semilogy(ls, '-o', color='k')
plt.xlabel('LM iteration')
plt.ylabel('Loss')
plt.title("Optimization Loss")

print("Visualize optimized status.")
ps_current = forward()
visualize(ps_current, save_string="optimized")

# Print mean displacement error
T = ps_current - ps_cap
E = torch.sqrt(torch.sum(T[valid_cap, ...]**2, axis=-1)).mean()
print("error = {} [um]".format(E*1e3))

--------------------------------------------------------------------------------
/diffmetrology/__init__.py:
--------------------------------------------------------------------------------
# image formation model
from .basics import *
from .shapes import *
from .optics import *
from .scene import *

# algorithmic solvers
from .solvers import *

# utilities
from .utils import *

--------------------------------------------------------------------------------
/diffmetrology/basics.py:
--------------------------------------------------------------------------------
import math
from enum import Enum
import torch
import numpy as np

# ----------------------------------------------------------------------------------------

class PrettyPrinter():
    def __str__(self):
        lines = [self.__class__.__name__ + ':']
        for key, val in vars(self).items():
            if val.__class__.__name__ in ('list', 'tuple'):
                for i, v in enumerate(val):
                    lines += '{}[{}]: {}'.format(key, i, v).split('\n')

            elif val.__class__.__name__ == 'dict':
                pass # ignore outputs for hash tables
            elif key == key.upper() and len(key) > 5:
                pass # ignore all upper-case variables > 5 chars (constants)
            else:
                lines += '{}: {}'.format(key, val).split('\n')
        return '\n '.join(lines)

    def to(self, device=torch.device('cpu')):
        # Move tensor attributes to `device`, recursing into nested PrettyPrinter
        # objects and into tensors/PrettyPrinters stored in lists (or tuples).
        for key, val in vars(self).items():
            if torch.is_tensor(val):
                setattr(self, key, val.to(device))
            elif issubclass(type(val), PrettyPrinter):
                val.to(device)
            elif val.__class__.__name__ in ('list', 'tuple'):
                for i, v in enumerate(val):
                    if torch.is_tensor(v):
                        val[i] = v.to(device)
                    elif issubclass(type(v), PrettyPrinter):
                        v.to(device)


# ----------------------------------------------------------------------------------------

class Ray(PrettyPrinter):
    def __init__(self, o, d, wavelength, device=torch.device('cpu')):
        self.o = o # ray origin
        self.d = d # ray direction (normalized)

        # scalar-version
        self.wavelength = wavelength # [nm]
        self.mint = 1e-5 # [mm]
        self.maxt = 1e5 # [mm]
        self.to(device)

    def __call__(self, t):
        return self.o + t[..., None] * self.d

class Transformation(PrettyPrinter):
    def __init__(self, R, t):
        if torch.is_tensor(R):
            self.R = R
        else:
            self.R = torch.Tensor(R)
        if torch.is_tensor(t):
            self.t = t
        else:
            self.t = torch.Tensor(t)

    def transform_point(self, o):
        return torch.squeeze(self.R @ o[..., None]) + self.t

    def transform_vector(self, d):
        return torch.squeeze(self.R @ d[..., None])

    def transform_ray(self, ray):
        o = self.transform_point(ray.o)
        d = self.transform_vector(ray.d)
        if o.is_cuda:
            return Ray(o, d, ray.wavelength, device=torch.device('cuda'))
        else:
            return Ray(o, d, ray.wavelength)

    def inverse(self):
        RT = self.R.T
        t = self.t
        return Transformation(RT, -RT @ t)

class Sampler(PrettyPrinter):
    def __init__(self):
        self.to()
        self.pi_over_2 = np.pi / 2
        self.pi_over_4 = np.pi / 4

    def concentric_sample_disk(self, x, y):
        # https://pbr-book.org/3ed-2018/Monte_Carlo_Integration/2D_Sampling_with_Multidimensional_Transformations

        # map uniform random numbers to [-1,1]^2
        x = 2 * x - 1
        y = 2 * y - 1

        # the degeneracy at the origin (x == y == 0) is avoided by the eps added to the denominators below

        # apply concentric mapping to point
        cond = np.abs(x) > np.abs(y)
        r = np.where(cond, x, y)
        theta = np.where(cond,
            self.pi_over_4 * (y / (x + np.finfo(float).eps)),
            self.pi_over_2 - self.pi_over_4 * (x / (y + np.finfo(float).eps))
        )

        return r * np.cos(theta), r * np.sin(theta)

# ----------------------------------------------------------------------------------------
class Filter(PrettyPrinter):
    def __init__(self, radius):
        self.radius = radius
    def eval(self, p):
        raise NotImplementedError()

class Box(Filter):
    def __init__(self, radius=None):
        if radius is None:
            radius = [0.5, 0.5]
        Filter.__init__(self, radius)
    def eval(self, x):
        return torch.ones_like(x)

class Triangle(Filter):
    def __init__(self, radius=None):
        if radius is None:
            radius = [2.0, 2.0]
        Filter.__init__(self, radius)
    def eval(self, p):
        # separable triangle (tent) filter: max(0, radius - |offset|) along each axis
        x, y = p[...,0], p[...,1]
        return (torch.maximum(torch.zeros_like(x), self.radius[0] - torch.abs(x)) *
                torch.maximum(torch.zeros_like(y), self.radius[1] - torch.abs(y)))

# ----------------------------------------------------------------------------------------

class Material(PrettyPrinter):
    def __init__(self, name=None):
        self.name = 'vacuum' if name is None else name.lower()
        self.MATERIAL_TABLE = { # [nD, Abbe number]
            "vacuum": [1., math.inf],
            "air": [1.000293, math.inf],
            "occulder": [1., math.inf],
            "f2": [1.620, 36.37],
            "f15": [1.60570, 37.831],
            "uvfs": [1.458, 67.82],

            # https://shop.schott.com/advanced_optics/
            "bk10": [1.49780, 66.954],
            "n-baf10": [1.67003, 47.11],
            "n-bk7": [1.51680, 64.17],
            "n-sf1": [1.71736, 29.62],
            "n-sf2": [1.64769, 33.82],
            "n-sf4": [1.75513, 27.38],
            "n-sf5": [1.67271, 32.25],
            "n-sf6": [1.80518, 25.36],
            "n-sf6ht": [1.80518, 25.36],
            "n-sf8": [1.68894, 31.31],
            "n-sf10": [1.72828, 28.53],
            "n-sf11": [1.78472, 25.68],
            "sf1": [1.71736, 29.51],
            "sf2": [1.64769, 33.85],
            "sf4": [1.75520, 27.58],
            "sf5": [1.67270, 32.21],
            "sf6": [1.80518, 25.43],
            "sf18": [1.72150, 29.245],

            # HIKARI.AGF
            "baf10": [1.67, 47.05],

            # SUMITA.AGF
            "sk16": [1.62040, 60.306],
            "sk1": [1.61030, 56.712],
            "ssk4": [1.61770, 55.116],

            # https://www.pgo-online.com/intl/B270.html
            "b270": [1.52290, 58.50],

            # https://refractiveindex.info, nd at 589.3 [nm]
            "s-nph1": [1.8078, 22.76],
            "d-k59": [1.5175, 63.50],

            "flint": [1.6200, 36.37],
            "pmma": [1.491756, 58.00],
            "polycarb": [1.585470, 30.00]
        }
        self.A, self.B = self._lookup_material()

    def ior(self, wavelength):
        """Computes the index of refraction at a given wavelength (in [nm]) with Cauchy's model A + B / wavelength**2."""
        return self.A + self.B / wavelength**2

    @staticmethod
    def nV_to_AB(n, V):
        # Fit Cauchy's two-term model n(lambda) = A + B / lambda**2 from nd (at 589.3 nm)
        # and the Abbe number V = (nd - 1) / (nF - nC), with lambda_C = 656.3 nm, lambda_F = 486.1 nm.
        def ivs(a): return 1./a**2
        lambdas = [656.3, 589.3, 486.1]
        B = (n - 1) / V / ( ivs(lambdas[2]) - ivs(lambdas[0]) )
        A = n - B * ivs(lambdas[1])
        return A, B

    def _lookup_material(self):
        out = self.MATERIAL_TABLE.get(self.name)
        if isinstance(out, list):
            n, V = out
        elif out is None:
            # try parsing input as a n/V pair
            tmp = self.name.split('/')
            n, V = float(tmp[0]), float(tmp[1])
        return self.nV_to_AB(n, V)


class InterpolationMode(Enum):
    nearest = 1
    linear = 2

class BoundaryMode(Enum):
    zero = 1
    replicate = 2
    symmetric = 3
    periodic = 4

class SimulationMode(Enum):
    render = 1
    trace = 2

# ----------------------------------------------------------------------------------------

def init():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("DiffMetrology is using: {}".format(device))
    torch.set_default_tensor_type('torch.FloatTensor')
    return device

def normalize(d):
    return d / torch.sqrt(torch.sum(d**2, axis=-1))[..., None]

def set_zeros(x, valid=None):
    if valid is None:
        return torch.where(torch.isnan(x), torch.zeros_like(x), x)
    else:
        mask = valid[...,None] if len(x.shape) > len(valid.shape) else valid
        return torch.where(~mask, torch.zeros_like(x), x)

def rodrigues_rotation_matrix(k, theta): # k: unit-length axis (torch tensor), theta: [rad]
    # cross-product matrix
    kx, ky, kz = k[0], k[1], k[2]
    K = torch.Tensor([
        [ 0, -kz, ky],
        [ kz, 0, -kx],
        [-ky, kx, 0]
    ]).to(k.device)
    if not torch.is_tensor(theta):
        theta = torch.Tensor(np.asarray(theta)).to(k.device)
    # Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2
    return torch.eye(3, device=k.device) + torch.sin(theta) * K + (1 - torch.cos(theta)) * K @ K

def set_axes_equal(ax, scale=np.ones(3)):
    """
    Make axes of 3D plot have equal scale (or scaled by `scale`).
258 | """ 259 | limits = np.array([ 260 | ax.get_xlim3d(), 261 | ax.get_ylim3d(), 262 | ax.get_zlim3d() 263 | ]) 264 | tmp = np.abs(limits[:,1]-limits[:,0]) 265 | ax.set_box_aspect(scale * tmp/np.min(tmp)) 266 | 267 | # ---------------------------------------------------------------------------------------- 268 | 269 | def generate_test_rays(): 270 | filmsize = np.array([4, 2]) 271 | 272 | o = np.array([3,4,-200]) 273 | o = np.tile(o[None, None, ...], [*filmsize, 1]) 274 | o = torch.Tensor(o) 275 | 276 | # d = np.array([0.02, -0.03, 1]) 277 | # d = np.tile(d[None, None, ...], [*filmsize, 1]) 278 | dx = 0.1 * torch.rand(*filmsize) 279 | dy = 0.1 * torch.rand(*filmsize) 280 | d = normalize(torch.stack((dx, dy, torch.ones_like(dx)), axis=-1)) 281 | 282 | wavelength = 500 # [nm] 283 | return Ray(o, d, wavelength) 284 | 285 | def generate_test_transformation(): 286 | k = np.random.rand(3) 287 | k = k / np.sqrt(np.sum(k**2)) 288 | theta = 1 # [rad] 289 | R = rodrigues_rotation_matrix(k, theta) 290 | t = np.random.rand(3) 291 | return Transformation(R, t) 292 | 293 | def generate_test_material(): 294 | return Material('N-BK7') 295 | 296 | # ---------------------------------------------------------------------------------------- 297 | 298 | 299 | if __name__ == "__main__": 300 | init() 301 | 302 | rays = generate_test_rays() 303 | to_world = generate_test_transformation() 304 | # print(to_world) 305 | 306 | rays_new = to_world.transform_ray(rays) 307 | o_old = rays.o[2,1,...].numpy() 308 | o_new = rays_new.o[2,1,...].numpy() 309 | assert np.sum(np.abs(to_world.R.numpy() @ o_old + to_world.t.numpy() - o_new)) < 1e-15 310 | 311 | material = generate_test_material() 312 | # print(material) 313 | -------------------------------------------------------------------------------- /diffmetrology/optics.py: -------------------------------------------------------------------------------- 1 | from .basics import * 2 | from .shapes import * 3 | from scipy.interpolate import 
LSQBivariateSpline 4 | from datetime import datetime 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class Step(torch.autograd.Function): # Heaviside step in the forward pass, Gaussian surrogate gradient exp(-(eps*x)^2) in the backward pass 9 | @staticmethod 10 | def forward(ctx, input, eps): 11 | ctx.constant = eps 12 | ctx.save_for_backward(input) 13 | return (input > 0).float() 14 | 15 | @staticmethod 16 | def backward(ctx, grad_output): 17 | input, = ctx.saved_tensors 18 | return grad_output * torch.exp(-(ctx.constant*input)**2), None 19 | 20 | def ind(x, eps=0.5): 21 | return Step.apply(x, eps) 22 | 23 | 24 | class Lensgroup(Endpoint): 25 | """ 26 | The Lensgroup (consisting of multiple optical surfaces) is mounted on a rod, whose 27 | origin is `origin`. The Lensgroup can rotate freely around the x/y/z 28 | axes, with the rotation angles defined as `theta_x`, `theta_y`, and `theta_z` (in degrees). 29 | 30 | In the Lensgroup's coordinates (i.e. object-frame coordinates), surfaces are allocated 31 | starting from `z = 0`. There is an additional, comparatively small 3D origin shift 32 | (`shift`) between the surface center (0,0,0) and the origin of the mount, i.e. 33 | origin + R @ shift = lensgroup_origin, where R is the Lensgroup rotation. 34 | 35 | There are two ray-tracing configurations: forward and backward. In forward mode, 36 | rays start from the `d = 0` surface and propagate along the +z axis; in backward mode, 37 | rays start from the `d = d_max` surface and propagate along the -z axis. 38 | """ 39 | def __init__(self, origin, shift, theta_x=0., theta_y=0., theta_z=0., device=torch.device('cpu')): 40 | self.origin = torch.Tensor(origin).to(device) 41 | self.shift = torch.Tensor(shift).to(device) 42 | self.theta_x = torch.Tensor(np.asarray(theta_x)).to(device) 43 | self.theta_y = torch.Tensor(np.asarray(theta_y)).to(device) 44 | self.theta_z = torch.Tensor(np.asarray(theta_z)).to(device) 45 | self.device = device 46 | 47 | Endpoint.__init__(self, self._compute_transformation(), device) 48 | 49 | # TODO: in case you would like to render something ... 
50 | self.mts_prepared = False 51 | 52 | def load_file(self, filename): 53 | LENSPATH = './lenses/' 54 | filename = filename if filename[0] == '.' else LENSPATH + filename 55 | self.surfaces, self.materials, self.r_last, d_last = self.read_lensfile(filename) 56 | self.d_sensor = d_last + self.surfaces[-1].d 57 | self._sync() 58 | 59 | def load(self, surfaces, materials): 60 | self.surfaces = surfaces 61 | self.materials = materials 62 | self._sync() 63 | 64 | def _sync(self): 65 | for i in range(len(self.surfaces)): 66 | self.surfaces[i].to(self.device) 67 | 68 | def update(self, _x=0.0, _y=0.0): 69 | self.to_world = self._compute_transformation(_x, _y) 70 | self.to_object = self.to_world.inverse() 71 | 72 | def _compute_transformation(self, _x=0.0, _y=0.0, _z=0.0): 73 | # we compute to_world transformation given the input positional parameters (angles) 74 | R = ( rodrigues_rotation_matrix(torch.Tensor([1, 0, 0]).to(self.device), torch.deg2rad(self.theta_x+_x)) @ 75 | rodrigues_rotation_matrix(torch.Tensor([0, 1, 0]).to(self.device), torch.deg2rad(self.theta_y+_y)) @ 76 | rodrigues_rotation_matrix(torch.Tensor([0, 0, 1]).to(self.device), torch.deg2rad(self.theta_z+_z)) ) 77 | t = self.origin + R @ self.shift 78 | return Transformation(R, t) 79 | 80 | @staticmethod 81 | def read_lensfile(filename): 82 | surfaces = [] 83 | materials = [] 84 | ds = [] # no use for now 85 | with open(filename) as file: 86 | line_no = 0 87 | d_total = 0. 88 | for line in file: 89 | if line_no < 2: # first two lines are comments; ignore them 90 | line_no += 1 91 | else: 92 | ls = line.split() 93 | surface_type, d, r = ls[0], float(ls[1]), float(ls[3])/2 94 | roc = float(ls[2]) 95 | if roc != 0: roc = 1/roc 96 | materials.append(Material(ls[4])) 97 | 98 | d_total += d 99 | ds.append(d) 100 | 101 | if surface_type == 'O': # object 102 | d_total = 0. 
103 | ds.pop() 104 | elif surface_type == 'X': # XY-polynomial 105 | del roc 106 | ai = [] 107 | for ac in range(5, len(ls)): 108 | if ac == 5: 109 | b = float(ls[5]) 110 | else: 111 | ai.append(float(ls[ac])) 112 | surfaces.append(XYPolynomial(r, d_total, J=2, ai=ai, b=b)) 113 | elif surface_type == 'B': # B-spline 114 | del roc 115 | ai = [] 116 | for ac in range(5, len(ls)): 117 | if ac == 5: 118 | nx = int(ls[5]) 119 | elif ac == 6: 120 | ny = int(ls[6]) 121 | else: 122 | ai.append(float(ls[ac])) 123 | tx = ai[:nx+8] # nx + 2*(p+1) knots for cubic (p = 3) splines 124 | ai = ai[nx+8:] 125 | ty = ai[:ny+8] 126 | ai = ai[ny+8:] 127 | c = ai 128 | surfaces.append(BSpline(r, d_total, size=[nx, ny], tx=tx, ty=ty, c=c)) 129 | elif surface_type == 'M': # mixed-type of X and B 130 | raise NotImplementedError() 131 | elif surface_type == 'S': # aspheric surface 132 | if len(ls) <= 5: 133 | surfaces.append(Aspheric(r, d_total, roc)) 134 | else: 135 | ai = [] 136 | for ac in range(5, len(ls)): 137 | if ac == 5: 138 | conic = float(ls[5]) 139 | else: 140 | ai.append(float(ls[ac])) 141 | surfaces.append(Aspheric(r, d_total, roc, conic, ai)) 142 | elif surface_type == 'A': # aperture 143 | surfaces.append(Aspheric(r, d_total, roc)) 144 | elif surface_type == 'I': # sensor 145 | d_total -= d 146 | ds.pop() 147 | materials.pop() 148 | r_last = r 149 | d_last = d 150 | return surfaces, materials, r_last, d_last 151 | 152 | def reverse(self): 153 | # reverse surfaces 154 | d_total = self.surfaces[-1].d 155 | for i in range(len(self.surfaces)): 156 | self.surfaces[i].d = d_total - self.surfaces[i].d 157 | self.surfaces[i].reverse() 158 | self.surfaces.reverse() 159 | 160 | # reverse materials 161 | self.materials.reverse() 162 | 163 | # ------------------------------------------------------------------------------------ 164 | # Analysis 165 | # ------------------------------------------------------------------------------------ 166 | def rms(self, ps, units=1e3, option='centroid'): 167 | ps = ps[...,:2] * units 168 | if option == 
'centroid': 169 | ps_mean = torch.mean(ps, axis=0) # centroid 170 | ps = ps - ps_mean[None,...] # center ps at the centroid 171 | spot_rms = torch.sqrt(torch.mean(torch.sum(ps**2, axis=-1))) 172 | return spot_rms 173 | 174 | def spot_diagram(self, ps, show=True, xlims=None, ylims=None, color='b.'): 175 | """ 176 | Plot the spot diagram (ray positions relative to their centroid) and save it as a timestamped PDF. 177 | """ 178 | units = 1e3 179 | units_str = '[um]' 180 | # units = 1 181 | # units_str = '[mm]' 182 | spot_rms = float(self.rms(ps, units)) 183 | ps = ps.cpu().detach().numpy()[...,:2] * units 184 | ps_mean = np.mean(ps, axis=0) # centroid 185 | ps = ps - ps_mean[None,...] # center ps at the centroid 186 | 187 | fig = plt.figure() 188 | ax = plt.axes() 189 | ax.plot(ps[...,1], ps[...,0], color) # swap axes 0 and 1 190 | plt.gca().set_aspect('equal', adjustable='box') 191 | plt.xlabel('x ' + units_str) 192 | plt.ylabel('y ' + units_str) 193 | plt.title("Spot diagram, RMS = " + str(round(spot_rms,3)) + ' ' + units_str) 194 | if xlims is not None: 195 | plt.xlim(*xlims) 196 | if ylims is not None: 197 | plt.ylim(*ylims) 198 | ax.set_aspect(1./ax.get_data_ratio()) 199 | 200 | fig.savefig("spotdiagram_" + datetime.now().strftime('%Y%m%d-%H%M%S-%f') + ".pdf", bbox_inches='tight') 201 | if show: plt.show() 202 | else: plt.close() 203 | return spot_rms 204 | 205 | # ------------------------------------------------------------------------------------ 206 | 207 | # ------------------------------------------------------------------------------------ 208 | # IO and visualizations 209 | # ------------------------------------------------------------------------------------ 210 | def draw_points(self, ax, options, seq=range(3)): 211 | for surface in self.surfaces: 212 | points_world = self._generate_points(surface) 213 | ax.plot(points_world[seq[0]], points_world[seq[1]], points_world[seq[2]], options) 214 | 215 | # ------------------------------------------------------------------------------------ 216 | 217 | # 
------------------------------------------------------------------------------------ 218 | 219 | def trace(self, ray, stop_ind=None): 220 | # update transformation when doing pose estimation 221 | if ( 222 | self.origin.requires_grad or self.shift.requires_grad 223 | or 224 | self.theta_x.requires_grad or self.theta_y.requires_grad or self.theta_z.requires_grad 225 | ): 226 | self.update() 227 | 228 | # in local 229 | ray_in = self.to_object.transform_ray(ray) 230 | valid, mask_g, ray_out = self._trace(ray_in, stop_ind=stop_ind, record=False) 231 | 232 | # in world 233 | ray_final = self.to_world.transform_ray(ray_out) 234 | 235 | return ray_final, valid, mask_g 236 | 237 | # ------------------------------------------------------------------------------------ 238 | 239 | def _refract(self, wi, n, eta, approx=False): 240 | """ 241 | Snell's law (surface normal n defined along the positive z axis). 242 | """ 243 | if np.prod(eta.shape) > 1: 244 | eta_ = eta[..., None] 245 | else: 246 | eta_ = eta 247 | 248 | cosi = torch.sum(wi * n, axis=-1) 249 | 250 | if approx: 251 | tmp = 1. - eta**2 * (1. - cosi) 252 | g = tmp 253 | valid = tmp > 0. 254 | wt = tmp[..., None] * n + eta_ * (wi - cosi[..., None] * n) 255 | else: 256 | cost2 = 1. - (1. - cosi**2) * eta**2 257 | 258 | # 1. get valid map; 2. zero out invalid points; 3. add eps to avoid NaN grad at cost2==0. 259 | g = cost2 260 | valid = cost2 > 0. 
261 | cost2 = torch.clamp(cost2, min=1e-8) 262 | tmp = torch.sqrt(cost2) 263 | 264 | wt = tmp[..., None] * n + eta_ * (wi - cosi[..., None] * n) 265 | return valid, wt, g 266 | 267 | def _trace(self, ray, stop_ind=None, record=False): 268 | if stop_ind is None: 269 | stop_ind = len(self.surfaces)-1 # last index to stop 270 | is_forward = (ray.d[..., 2] > 0).all() 271 | 272 | if is_forward: 273 | return self._forward_tracing(ray, stop_ind, record) 274 | else: 275 | return self._backward_tracing(ray, stop_ind, record) 276 | 277 | def _forward_tracing(self, ray, stop_ind, record): 278 | wavelength = ray.wavelength 279 | dim = ray.o[..., 2].shape 280 | 281 | if record: 282 | oss = [] 283 | for i in range(dim[0]): 284 | oss.append([ray.o[i,:].cpu().detach().numpy()]) 285 | 286 | valid = torch.ones(dim, device=self.device).bool() 287 | mask = torch.ones(dim, device=self.device) 288 | for i in range(stop_ind+1): 289 | eta = self.materials[i].ior(wavelength) / self.materials[i+1].ior(wavelength) 290 | 291 | # ray intersecting surface 292 | valid_o, p, g_o = self.surfaces[i].ray_surface_intersection(ray, valid) 293 | 294 | # get surface normal and refract 295 | n = self.surfaces[i].normal(p[..., 0], p[..., 1]) 296 | valid_d, d, g_d = self._refract(ray.d, -n, eta) 297 | 298 | # check validity 299 | mask = mask * ind(g_o) * ind(g_d) 300 | valid = valid & valid_o & valid_d 301 | if not valid.any(): 302 | break 303 | 304 | # update ray {o,d} 305 | if record: 306 | for os, v, pp in zip(oss, valid.cpu().detach().numpy(), p.cpu().detach().numpy()): 307 | if v: os.append(pp) 308 | ray.o = p 309 | ray.d = d 310 | 311 | if record: 312 | return valid, mask, ray, oss 313 | else: 314 | return valid, mask, ray 315 | 316 | def _backward_tracing(self, ray, stop_ind, record): 317 | wavelength = ray.wavelength 318 | dim = ray.o[..., 2].shape 319 | 320 | if record: 321 | oss = [] 322 | for i in range(dim[0]): 323 | oss.append([ray.o[i,:].cpu().detach().numpy()]) 324 | 325 | valid = 
torch.ones(dim, device=ray.o.device).bool() 326 | mask = torch.ones(dim, device=ray.o.device) 327 | for i in np.flip(range(stop_ind+1)): 328 | surface = self.surfaces[i] 329 | eta = self.materials[i+1].ior(wavelength) / self.materials[i].ior(wavelength) 330 | 331 | # ray intersecting surface 332 | valid_o, p, g_o = surface.ray_surface_intersection(ray, valid) 333 | 334 | # get surface normal and refract 335 | n = surface.normal(p[..., 0], p[..., 1]) 336 | valid_d, d, g_d = self._refract(ray.d, n, eta) # backward: no need to flip the normal 337 | 338 | # check validity 339 | mask = mask * ind(g_o) * ind(g_d) 340 | valid = valid & valid_o & valid_d 341 | if not valid.any(): 342 | break 343 | 344 | # update ray {o,d} 345 | if record: 346 | for os, v, pp in zip(oss, valid.cpu().detach().numpy(), p.cpu().detach().numpy()): 347 | if v: os.append(pp) 348 | ray.o = p 349 | ray.d = d 350 | 351 | if record: 352 | return valid, mask, ray, oss 353 | else: 354 | return valid, mask, ray 355 | 356 | def _generate_points(self, surface, with_boundary=False): 357 | R = surface.r 358 | x = y = torch.linspace(-R, R, surface.APERTURE_SAMPLING, device=self.device) 359 | X, Y = torch.meshgrid(x, y) 360 | Z = surface.surface_with_offset(X, Y) 361 | valid = X**2 + Y**2 <= R**2 362 | if with_boundary: 363 | from scipy import ndimage 364 | tmp = ndimage.convolve(valid.cpu().numpy().astype('float'), np.array([[0,1,0],[1,0,1],[0,1,0]])) # number of 4-connected neighbours inside the aperture 365 | boundary = valid.cpu().numpy() & (tmp != 4) 366 | boundary = boundary[valid.cpu().numpy()].flatten() 367 | points_local = torch.stack(tuple(v[valid].flatten() for v in [X, Y, Z]), axis=-1) 368 | points_world = self.to_world.transform_point(points_local).T.cpu().detach().numpy() 369 | if with_boundary: 370 | return points_world, boundary 371 | else: 372 | return points_world 373 | 374 | class Surface(PrettyPrinter): 375 | def __init__(self, r, d, device=torch.device('cpu')): 376 | # self.r = torch.Tensor(np.array(r)) 377 | if torch.is_tensor(d): 378 | self.d = d 379 | 
else: 380 | self.d = torch.Tensor(np.asarray(float(d))).to(device) 381 | self.r = float(r) 382 | self.device = device 383 | self.NEWTONS_MAXITER = 10 384 | self.NEWTONS_TOLERANCE_TIGHT = 50e-6 # in [mm], i.e. 50 [nm] here (up to <10 [nm]) 385 | self.NEWTONS_TOLERANCE_LOOSE = 300e-6 # in [mm], i.e. 300 [nm] here (up to <10 [nm]) 386 | self.APERTURE_SAMPLING = 11 387 | 388 | # === Common methods (must not be overridden) 389 | def surface_with_offset(self, x, y): 390 | return self.surface(x, y) + self.d 391 | 392 | def normal(self, x, y): 393 | ds_dxyz = self.surface_derivatives(x, y) 394 | return normalize(torch.stack(ds_dxyz, axis=-1)) 395 | 396 | def surface_area(self): 397 | return math.pi * self.r**2 398 | 399 | def ray_surface_intersection(self, ray, active=None): 400 | """ 401 | Returns: 402 | - valid_o: True where the ray is active, Newton's method converged, and the hit lies inside the clear aperture (g > 0) 403 | - p: intersection point (in local coordinates) 404 | - g: aperture function r**2 - (x**2 + y**2), positive inside the clear aperture 405 | """ 406 | solution_found, local = self.newtons_method(ray.maxt, ray.o, ray.d) 407 | r2 = local[..., 0]**2 + local[..., 1]**2 408 | g = self.r**2 - r2 409 | if active is None: 410 | valid_o = solution_found & ind(g > 0.).bool() 411 | else: 412 | valid_o = active & solution_found & ind(g > 0.).bool() 413 | return valid_o, local, g 414 | 415 | def newtons_method_impl(self, maxt, t0, dx, dy, dz, ox, oy, oz, A, B, C): 416 | t_delta = torch.zeros_like(oz) 417 | 418 | # Iterate until the intersection error is small 419 | t = maxt * torch.ones_like(oz) 420 | residual = maxt * torch.ones_like(oz) 421 | it = 0 422 | while (torch.abs(residual) > self.NEWTONS_TOLERANCE_TIGHT).any() and (it < self.NEWTONS_MAXITER): 423 | it += 1 424 | t = t0 + t_delta 425 | residual, s_derivatives_dot_D = self.surface_and_derivatives_dot_D( 426 | t, dx, dy, dz, ox, oy, t_delta * dz, A, B, C # here z = t_delta * dz 427 | ) 428 | t_delta -= residual / s_derivatives_dot_D 429 | t = t0 + t_delta 430 | valid = (torch.abs(residual) < self.NEWTONS_TOLERANCE_LOOSE) & (t <= maxt) 431 | return t, t_delta, valid 432 | 433 |
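The loop in `newtons_method_impl` above finds, for each ray, the parameter t at which f(t) = g(x(t), y(t)) + h(z(t)) vanishes, with z measured from the surface vertex `self.d`. It starts from t0 = (self.d - o_z) / D_z, i.e. where the ray meets the vertex plane, and applies t <- t - f(t) / f'(t). For `Aspheric`, r2(t) = A t^2 + B t + C, so f'(t) = g'(r2) * (2*A*t + B) - D_z, which is what `surface_and_derivatives_dot_D` returns. In the default 'implicit' mode the source runs this loop without autodiff and then re-attaches the gradient with one final Newton step. The stdlib-only sketch below reproduces just the forward iteration for one ray and a spherical surface (the k = 0, no higher-order-term case of `Aspheric`); the names `sag`, `dsag_dr2` and `intersect` are hypothetical, `tol` mirrors `NEWTONS_TOLERANCE_TIGHT` in [mm], and no gradients are computed.

```python
import math


def sag(r2, c):
    """Sag of a spherical surface (curvature c, conic k = 0) at squared radius r2."""
    return c * r2 / (1.0 + math.sqrt(1.0 - c * c * r2))


def dsag_dr2(r2, c):
    """d(sag)/d(r2): the k = 0 case of the closed form in Aspheric._dgd."""
    s = math.sqrt(1.0 - c * c * r2)
    return c * (1.0 + s - 0.5 * c * c * r2) / (s * (1.0 + s) ** 2)


def intersect(o, D, c, z_vertex, tol=50e-6, maxiter=10):
    """Hypothetical helper: ray parameter t where o + t*D meets z = z_vertex + sag(x^2 + y^2)."""
    ox, oy, oz = o
    dx, dy, dz = D
    t = (z_vertex - oz) / dz  # initial guess: hit on the vertex plane
    for _ in range(maxiter):
        x, y = ox + t * dx, oy + t * dy
        z = oz + t * dz - z_vertex         # height above the vertex
        r2 = x * x + y * y
        f = sag(r2, c) - z                 # residual f(t); zero on the surface
        if abs(f) < tol:
            break
        df = dsag_dr2(r2, c) * 2.0 * (x * dx + y * dy) - dz  # f'(t)
        t -= f / df                        # Newton update
    return t


# Oblique ray from z = -10 toward a sphere of radius 50 mm with its vertex at z = 0
norm = math.sqrt(0.05**2 + 0.02**2 + 1.0)
D = (0.05 / norm, -0.02 / norm, 1.0 / norm)
o = (1.0, 0.5, -10.0)
t = intersect(o, D, 1.0 / 50.0, 0.0)
p = tuple(o[i] + t * D[i] for i in range(3))
print(t, p)
```

Because the first guess already lies on the vertex plane and the sag is small over the aperture, a handful of iterations is typically enough, which is consistent with the small `NEWTONS_MAXITER` used in the source.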
def newtons_method(self, maxt, o, D, option='implicit'): 434 | # Newton's method to find the ray-surface intersection point. 435 | # Two modes are supported here: 436 | # 437 | # 1. 'explicit': Runs the whole loop under autodiff, so gradients are exact 438 | # for o, D, and the surface parameters. Slow and memory-consuming. 439 | # 440 | # 2. 'implicit': As proposed in the paper, finds the solution without autodiff, 441 | # then re-attaches the gradient with one extra Newton step. Uses less memory. 442 | 443 | # pre-compute constants 444 | ox, oy, oz = (o[..., i].clone() for i in range(3)) 445 | dx, dy, dz = (D[..., i].clone() for i in range(3)) 446 | A = dx**2 + dy**2 447 | B = 2 * (dx * ox + dy * oy) 448 | C = ox**2 + oy**2 449 | 450 | # initial guess of t: where the ray meets the vertex plane z = self.d 451 | t0 = (self.d - oz) / dz 452 | 453 | if option == 'explicit': 454 | t, t_delta, valid = self.newtons_method_impl( 455 | maxt, t0, dx, dy, dz, ox, oy, oz, A, B, C 456 | ) 457 | elif option == 'implicit': 458 | with torch.no_grad(): 459 | t, t_delta, valid = self.newtons_method_impl( 460 | maxt, t0, dx, dy, dz, ox, oy, oz, A, B, C 461 | ) 462 | s_derivatives_dot_D = self.surface_and_derivatives_dot_D( 463 | t, dx, dy, dz, ox, oy, t_delta * dz, A, B, C 464 | )[1] 465 | t = t0 + t_delta # re-engage autodiff 466 | 467 | t = t - (self.g(ox + t * dx, oy + t * dy) + self.h(oz + t * dz) + self.d)/s_derivatives_dot_D 468 | else: 469 | raise Exception('option={} is not available!'.format(option)) 470 | 471 | p = o + t[..., None] * D 472 | return valid, p 473 | 474 | # === Virtual methods (must be overridden) 475 | def g(self, x, y): 476 | raise NotImplementedError() 477 | 478 | def dgd(self, x, y): 479 | """ 480 | Derivatives of g: (g'x, g'y). 481 | """ 482 | raise NotImplementedError() 483 | 484 | def h(self, z): 485 | raise NotImplementedError() 486 | 487 | def dhd(self, z): 488 | """ 489 | Derivative of h.
490 | """ 491 | raise NotImplementedError() 492 | 493 | def surface(self, x, y): 494 | """ 495 | Solve z from h(z) = -g(x,y). 496 | """ 497 | raise NotImplementedError() 498 | 499 | def reverse(self): 500 | raise NotImplementedError() 501 | 502 | # === Default methods (should be overridden for efficiency) 503 | def surface_derivatives(self, x, y): 504 | r""" 505 | Returns \nabla f = \nabla (g(x,y) + h(z)) = (dg/dx, dg/dy, dh/dz). 506 | (Note: this default implementation is not efficient) 507 | """ 508 | gx, gy = self.dgd(x, y) 509 | z = self.surface(x, y) 510 | return gx, gy, self.dhd(z) 511 | 512 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C): 513 | """ 514 | Returns g(x,y)+h(z) and dot((g'x,g'y,h'), (dx,dy,dz)). 515 | (Note: this default implementation is not efficient) 516 | """ 517 | x = ox + t * dx 518 | y = oy + t * dy 519 | s = self.g(x,y) + self.h(z) 520 | sx, sy = self.dgd(x, y) 521 | sz = self.dhd(z) 522 | return s, sx*dx + sy*dy + sz*dz 523 | 524 | 525 | class Aspheric(Surface): 526 | """ 527 | Aspheric surface: https://en.wikipedia.org/wiki/Aspheric_lens.
528 | """ 529 | def __init__(self, r, d, c=0., k=0., ai=None, device=torch.device('cpu')): 530 | Surface.__init__(self, r, d, device) 531 | self.c, self.k = (torch.Tensor(np.array(v)) for v in [c, k]) 532 | self.ai = None 533 | if ai is not None: 534 | self.ai = torch.Tensor(np.array(ai)) 535 | 536 | # === Common methods 537 | def g(self, x, y): 538 | return self._g(x**2 + y**2) 539 | 540 | def dgd(self, x, y): 541 | dsdr2 = 2 * self._dgd(x**2 + y**2) 542 | return dsdr2*x, dsdr2*y 543 | 544 | def h(self, z): 545 | return -z 546 | 547 | def dhd(self, z): 548 | return -torch.ones_like(z) 549 | 550 | def surface(self, x, y): 551 | return self._g(x**2 + y**2) 552 | 553 | def reverse(self): 554 | self.c = -self.c 555 | if self.ai is not None: 556 | self.ai = -self.ai 557 | 558 | def surface_derivatives(self, x, y): 559 | dsdr2 = 2 * self._dgd(x**2 + y**2) 560 | return dsdr2*x, dsdr2*y, -torch.ones_like(x) 561 | 562 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C): 563 | r2 = A * t**2 + B * t + C 564 | return self._g(r2) - z, self._dgd(r2) * (2*A*t + B) - dz 565 | 566 | # === Private methods 567 | def _g(self, r2): 568 | tmp = r2*self.c 569 | total_surface = tmp / (1 + torch.sqrt(1 - (1+self.k) * tmp*self.c)) 570 | higher_surface = 0 571 | if self.ai is not None: 572 | for i in np.flip(range(len(self.ai))): 573 | higher_surface = r2 * higher_surface + self.ai[i] 574 | higher_surface = higher_surface * r2**2 575 | return total_surface + higher_surface 576 | 577 | def _dgd(self, r2): 578 | alpha_r2 = (1 + self.k) * self.c**2 * r2 579 | tmp = torch.sqrt(1 - alpha_r2) # TODO: potential NaN grad 580 | total_derivative = self.c * (1 + tmp - 0.5*alpha_r2) / (tmp * (1 + tmp)**2) 581 | 582 | higher_derivative = 0 583 | if self.ai is not None: 584 | for i in np.flip(range(len(self.ai))): 585 | higher_derivative = r2 * higher_derivative + (i+2) * self.ai[i] 586 | return total_derivative + higher_derivative * r2 587 | 588 | 589 | # 
---------------------------------------------------------------------------------------- 590 | 591 | class BSpline(Surface): 592 | """ 593 | Tensor-product B-spline surface of degrees (px, py), evaluated with de Boor's algorithm. 594 | """ 595 | def __init__(self, r, d, size, px=3, py=3, tx=None, ty=None, c=None, device=torch.device('cpu')): # input c is 1D 596 | Surface.__init__(self, r, d, device) 597 | self.px = px 598 | self.py = py 599 | self.size = np.asarray(size) 600 | 601 | # knots 602 | if tx is None: 603 | self.tx = None 604 | else: 605 | if len(tx) != size[0] + 2*(self.px + 1): 606 | raise Exception('len(tx) is not correct!') 607 | self.tx = torch.Tensor(np.asarray(tx)).to(self.device) 608 | if ty is None: 609 | self.ty = None 610 | else: 611 | if len(ty) != size[1] + 2*(self.py + 1): 612 | raise Exception('len(ty) is not correct!') 613 | self.ty = torch.Tensor(np.asarray(ty)).to(self.device) 614 | 615 | # c is the only differentiable parameter 616 | c_shape = size + np.array([self.px, self.py]) + 1 617 | if c is None: 618 | self.c = None 619 | else: 620 | c = np.asarray(c) 621 | if c.size != np.prod(c_shape): 622 | raise Exception('len(c) is not correct!') 623 | self.c = torch.Tensor(c.reshape(*c_shape)).to(self.device) 624 | 625 | if (self.tx is None) or (self.ty is None) or (self.c is None): 626 | self.tx = self._generate_knots(self.r, size[0], p=px, device=device) 627 | self.ty = self._generate_knots(self.r, size[1], p=py, device=device) 628 | self.c = torch.zeros(*c_shape, device=device) 629 | else: 630 | self.to(self.device) 631 | 632 | @staticmethod 633 | def _generate_knots(R, n, p=3, device=torch.device('cpu')): 634 | t = np.linspace(-R, R, n) 635 | step = t[1] - t[0] 636 | T = t[0] - 0.9 * step 637 | # pad each end with p+1 boundary knots just outside [-R, R] 638 | t = np.concatenate((np.ones(p+1)*T, t, -np.ones(p+1)*T), axis=0) 639 | return torch.Tensor(t).to(device) 640 | 641 | def fit(self, x, y, z, eps=1e-3): 642 | x, y, z = (v.flatten() for v in [x, y, z]) 643 | 644 | # knot positions within [-r, r]^2
645 | X = np.linspace(-self.r, self.r, self.size[0]) 646 | Y = np.linspace(-self.r, self.r, self.size[1]) 647 | bs = LSQBivariateSpline(x, y, z, X, Y, kx=self.px, ky=self.py, eps=eps) 648 | tx, ty = bs.get_knots() 649 | c = bs.get_coeffs().reshape(len(tx)-self.px-1, len(ty)-self.py-1) 650 | 651 | # convert to torch.Tensor 652 | self.tx, self.ty, self.c = (torch.Tensor(v).to(self.device) for v in [tx, ty, c]) 653 | 654 | # === Common methods 655 | def g(self, x, y): 656 | return self._deBoor2(x, y) 657 | 658 | def dgd(self, x, y): 659 | return self._deBoor2(x, y, dx=1), self._deBoor2(x, y, dy=1) 660 | 661 | def h(self, z): 662 | return -z 663 | 664 | def dhd(self, z): 665 | return -torch.ones_like(z) 666 | 667 | def surface(self, x, y): 668 | return self._deBoor2(x, y) 669 | 670 | def surface_derivatives(self, x, y): 671 | return self._deBoor2(x, y, dx=1), self._deBoor2(x, y, dy=1), -torch.ones_like(x) 672 | 673 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C): 674 | x = ox + t * dx 675 | y = oy + t * dy 676 | s, sx, sy = self._deBoor2(x, y, dx=-1, dy=-1) 677 | return s - z, sx*dx + sy*dy - dz 678 | 679 | def reverse(self): 680 | self.c = -self.c 681 | 682 | # === Private methods 683 | def _deBoor(self, x, t, c, p=3, is2Dfinal=False, dx=0): 684 | """ 685 | Arguments 686 | --------- 687 | x: Position. 688 | t: Array of knot positions, padded with p+1 boundary knots at each end (e.g. as returned by `_generate_knots`). 689 | c: Array of control points. 690 | p: Degree of B-spline. is2Dfinal: if True, c holds one column per query point (final pass of the 2D tensor-product evaluation). 691 | dx: 692 | - 0: surface only 693 | - 1: surface 1st derivative only 694 | - -1: surface and its 1st derivative 695 | """ 696 | k = torch.sum((x[None,...] > t[...,None]).int(), axis=0) - (p+1) 697 | 698 | if is2Dfinal: 699 | inds = np.indices(k.shape)[0] 700 | def _c(jk): return c[jk, inds] 701 | else: 702 | def _c(jk): return c[jk, ...]
703 | 704 | need_newdim = (len(c.shape) > 1) & (not is2Dfinal) 705 | 706 | def f(a, b, alpha): 707 | if need_newdim: 708 | alpha = alpha[...,None] 709 | return (1.0 - alpha) * a + alpha * b 710 | 711 | # surface only 712 | if dx == 0: 713 | d = [_c(j+k) for j in range(0, p+1)] 714 | 715 | for r in range(-p, 0): 716 | for j in range(p, p+r, -1): 717 | left = j+k 718 | t_left = t[left] 719 | t_right = t[left-r] 720 | alpha = (x - t_left) / (t_right - t_left) 721 | d[j] = f(d[j-1], d[j], alpha) 722 | return d[p] 723 | 724 | # surface 1st derivative only 725 | if dx == 1: 726 | q = [] 727 | for j in range(1, p+1): 728 | jk = j+k 729 | tmp = t[jk+p] - t[jk] 730 | if need_newdim: 731 | tmp = tmp[..., None] 732 | q.append(p * (_c(jk) - _c(jk-1)) / tmp) 733 | 734 | for r in range(-p, -1): 735 | for j in range(p-1, p+r, -1): 736 | left = j+k 737 | t_right = t[left-r] 738 | t_left_ = t[left+1] 739 | alpha = (x - t_left_) / (t_right - t_left_) 740 | q[j] = f(q[j-1], q[j], alpha) 741 | return q[p-1] 742 | 743 | # surface and its derivative (all) 744 | if dx < 0: 745 | d, q = [], [] 746 | for j in range(0, p+1): 747 | jk = j+k 748 | c_jk = _c(jk) 749 | d.append(c_jk) 750 | if j > 0: 751 | tmp = t[jk+p] - t[jk] 752 | if need_newdim: 753 | tmp = tmp[..., None] 754 | q.append(p * (c_jk - _c(jk-1)) / tmp) 755 | 756 | for r in range(-p, 0): 757 | for j in range(p, p+r, -1): 758 | left = j+k 759 | t_left = t[left] 760 | t_right = t[left-r] 761 | alpha = (x - t_left) / (t_right - t_left) 762 | d[j] = f(d[j-1], d[j], alpha) 763 | 764 | if (r < -1) & (j < p): 765 | t_left_ = t[left+1] 766 | alpha = (x - t_left_) / (t_right - t_left_) 767 | q[j] = f(q[j-1], q[j], alpha) 768 | return d[p], q[p-1] 769 | 770 | def _deBoor2(self, x, y, dx=0, dy=0): 771 | """ 772 | Arguments 773 | --------- 774 | x, y : Position. 
775 | dx, dy: Derivative selectors. (0, 0) returns the surface; (1, 0) its x-derivative; (0, 1) its y-derivative; any other pair, e.g. (-1, -1), returns (s, s_x, s_y). 776 | """ 777 | if not torch.is_tensor(x): 778 | x = torch.Tensor(np.asarray(x)).to(self.device) 779 | if not torch.is_tensor(y): 780 | y = torch.Tensor(np.asarray(y)).to(self.device) 781 | dim = x.shape 782 | 783 | x = x.flatten() 784 | y = y.flatten() 785 | 786 | # handle boundary issue 787 | x = torch.clamp(x, min=-self.r, max=self.r) 788 | y = torch.clamp(y, min=-self.r, max=self.r) 789 | 790 | if (dx == 0) & (dy == 0): # spline 791 | s_tmp = self._deBoor(x, self.tx, self.c, self.px) 792 | s = self._deBoor(y, self.ty, s_tmp.T, self.py, True) 793 | return s.reshape(dim) 794 | elif (dx == 1) & (dy == 0): # x-derivative 795 | s_tmp = self._deBoor(y, self.ty, self.c.T, self.py) 796 | s_x = self._deBoor(x, self.tx, s_tmp.T, self.px, True, dx) 797 | return s_x.reshape(dim) 798 | elif (dy == 1) & (dx == 0): # y-derivative 799 | s_tmp = self._deBoor(x, self.tx, self.c, self.px) 800 | s_y = self._deBoor(y, self.ty, s_tmp.T, self.py, True, dy) 801 | return s_y.reshape(dim) 802 | else: # return all 803 | s_tmpx = self._deBoor(x, self.tx, self.c, self.px) 804 | s_tmpy = self._deBoor(y, self.ty, self.c.T, self.py) 805 | s, s_x = self._deBoor(x, self.tx, s_tmpy.T, self.px, True, -abs(dx)) 806 | s_y = self._deBoor(y, self.ty, s_tmpx.T, self.py, True, abs(dy)) 807 | return s.reshape(dim), s_x.reshape(dim), s_y.reshape(dim) 808 | 809 | 810 | class XYPolynomial(Surface): 811 | r""" 812 | General XY polynomial surface with coefficients a_ij and b: 813 | 814 | implicit: b z^2 - z + \sum{i,j} a_ij x^i y^{j-i} = 0 815 | explicit: (denote c = \sum{i,j} a_ij x^i y^{j-i}) 816 | z = (1 - \sqrt{1 - 4 b c}) / (2b) 817 | (z = c when b = 0) 818 | differential of the implicit form: 819 | (2 b z - 1) dz + \sum{i,j} a_ij x^{i-1} y^{j-i-1} ( i y dx + (j-i) x dy ) = 0 820 | partial derivatives: 821 | F_x = \sum{i,j} a_ij i x^{i-1} y^{j-i} 822 | F_y = \sum{i,j} a_ij (j-i) x^{i} y^{j-i-1} 823 | F_z = 2 b z - 1 824 | """ 825 | def __init__(self, r, d, J=0, ai=None, b=None, device=torch.device('cpu')): 826 | Surface.__init__(self, r, d, device)
827 | self.J = J 828 | # differentiable parameters (default: all ai's and b are zeros) 829 | if ai is None: 830 | self.ai = torch.zeros(self.J2aisize(J)) # J2aisize(0) == 1, so this also covers J = 0 831 | else: 832 | if len(ai) != self.J2aisize(J): 833 | raise Exception("len(ai) != (J+1)*(J+2)/2 !") 834 | self.ai = torch.Tensor(ai).to(device) 835 | if b is None: 836 | b = 0. 837 | self.b = torch.Tensor(np.asarray(b)).to(device) 838 | print('ai.size = {}'.format(self.ai.shape[0])) 839 | self.to(self.device) 840 | 841 | @staticmethod 842 | def J2aisize(J): 843 | return int((J+1)*(J+2)/2) 844 | 845 | def center(self): 846 | x0 = -self.ai[2]/self.ai[5] 847 | y0 = -self.ai[1]/self.ai[3] 848 | return x0, y0 849 | 850 | def fit(self, x, y, z): 851 | x, y, z = (torch.Tensor(v.flatten()) for v in [x, y, z]) 852 | A, AT = self._construct_A(x, y, z**2) # linear least squares for z = b z^2 + sum a_ij x^i y^(j-i) 853 | coeffs = torch.linalg.solve(AT @ A, AT @ z[...,None]) # normal equations: (AT A) coeffs = AT z 854 | self.b = coeffs[0][0] 855 | self.ai = coeffs[1:].flatten() 856 | 857 | # === Common methods 858 | def g(self, x, y): 859 | c = torch.zeros_like(x) 860 | count = 0 861 | for j in range(self.J+1): 862 | for i in range(j+1): 863 | c = c + self.ai[count] * torch.pow(x, i) * torch.pow(y, j-i) 864 | count += 1 865 | return c 866 | 867 | def dgd(self, x, y): 868 | sx = torch.zeros_like(x) 869 | sy = torch.zeros_like(x) 870 | count = 0 871 | for j in range(self.J+1): 872 | for i in range(j+1): 873 | if j > 0: 874 | sx = sx + self.ai[count] * i * torch.pow(x, max(i-1,0)) * torch.pow(y, j-i) 875 | sy = sy + self.ai[count] * (j-i) * torch.pow(x, i) * torch.pow(y, max(j-i-1,0)) 876 | count += 1 877 | return sx, sy 878 | 879 | def h(self, z): 880 | return self.b * z**2 - z 881 | 882 | def dhd(self, z): 883 | return 2 * self.b * z - torch.ones_like(z) 884 | 885 | def surface(self, x, y): 886 | x, y = (v if torch.is_tensor(v) else torch.Tensor(v) for v in [x, y]) 887 | c = self.g(x, y) 888 | return self._solve_for_z(c) 889 | 890 | def reverse(self): 891 | self.b = -self.b 892 | self.ai = 
-self.ai 893 | 894 | def surface_derivatives(self, x, y): 895 | x, y = (v if torch.is_tensor(v) else torch.Tensor(v) for v in [x, y]) 896 | sx = torch.zeros_like(x) 897 | sy = torch.zeros_like(x) 898 | c = torch.zeros_like(x) 899 | count = 0 900 | for j in range(self.J+1): 901 | for i in range(j+1): 902 | c = c + self.ai[count] * torch.pow(x, i) * torch.pow(y, j-i) 903 | if j > 0: 904 | sx = sx + self.ai[count] * i * torch.pow(x, max(i-1,0)) * torch.pow(y, j-i) 905 | sy = sy + self.ai[count] * (j-i) * torch.pow(x, i) * torch.pow(y, max(j-i-1,0)) 906 | count += 1 907 | z = self._solve_for_z(c) 908 | return sx, sy, self.dhd(z) 909 | 910 | def surface_and_derivatives_dot_D(self, t, dx, dy, dz, ox, oy, z, A, B, C): 911 | x = ox + t * dx 912 | y = oy + t * dy 913 | sx = torch.zeros_like(x) 914 | sy = torch.zeros_like(x) 915 | c = torch.zeros_like(x) 916 | count = 0 917 | for j in range(self.J+1): 918 | for i in range(j+1): 919 | c = c + self.ai[count] * torch.pow(x, i) * torch.pow(y, j-i) 920 | if j > 0: 921 | sx = sx + self.ai[count] * i * torch.pow(x, max(i-1,0)) * torch.pow(y, j-i) 922 | sy = sy + self.ai[count] * (j-i) * torch.pow(x, i) * torch.pow(y, max(j-i-1,0)) 923 | count += 1 924 | s = c + self.h(z) 925 | return s, sx*dx + sy*dy + self.dhd(z)*dz 926 | 927 | # === Private methods 928 | def _construct_A(self, x, y, A_init=None): # rows of AT: A_init (if given), then x^i y^(j-i) for every monomial 929 | A = torch.zeros_like(x) if A_init is None else A_init 930 | for j in range(self.J+1): 931 | for i in range(j+1): 932 | A = torch.vstack((A, torch.pow(x, i) * torch.pow(y, j-i))) 933 | AT = A[1:,:] if A_init is None else A 934 | return AT.T, AT 935 | 936 | def _solve_for_z(self, c): 937 | if self.b == 0: 938 | return c 939 | else: 940 | return (1. - torch.sqrt(1. 
- 4*self.b*c)) / (2*self.b) 941 | 942 | 943 | # ---------------------------------------------------------------------------------------- 944 | 945 | def generate_test_lensgroup(): 946 | origin_mount = np.array([0, 0, -70]) 947 | origin_shift = np.array([0.1, 0.2, 0.3]) 948 | theta_x = 180 949 | theta_y = 0 950 | 951 | lensname = 'Thorlabs/AL50100-A.txt' 952 | lensgroup = Lensgroup(lensname, origin_mount, origin_shift, theta_x, theta_y) 953 | return lensgroup 954 | 955 | # ---------------------------------------------------------------------------------------- 956 | 957 | 958 | if __name__ == "__main__": 959 | init() 960 | 961 | lensgroup = generate_test_lensgroup() 962 | # print(lensgroup) 963 | 964 | ray = generate_test_rays() 965 | ray_out, valid, mask = lensgroup.trace(ray) 966 | 967 | ray_out.d = -ray_out.d 968 | ray_out.update() 969 | 970 | ray_new, valid_, mask_ = lensgroup.trace(ray_out) 971 | # assert np.sum(np.abs(ray.d.numpy() + ray_new.d.numpy())) < 1e-5 972 | -------------------------------------------------------------------------------- /diffmetrology/scene.py: -------------------------------------------------------------------------------- 1 | from .basics import * 2 | from .shapes import * 3 | from .optics import * 4 | 5 | class Scene(PrettyPrinter): 6 | def __init__(self, cameras, screen, lensgroup=None, device=torch.device('cpu')): 7 | self.cameras = cameras 8 | self.screen = screen 9 | self.lensgroup = lensgroup 10 | self.device = device 11 | 12 | self.camera_count = len(self.cameras) 13 | 14 | # rendering options 15 | self.wavelength = 500 # [nm] 16 | 17 | def render(self, i=None, with_element=True, mask=None, to_numpy=False): 18 | im = self._simulate(i, with_element, mask, SimulationMode.render) 19 | if to_numpy: 20 | im = [x.cpu().detach().numpy() for x in im] 21 | return im 22 | 23 | def trace(self, i=None, with_element=True, mask=None, to_numpy=False): 24 | results = self._simulate(i, with_element, mask, SimulationMode.trace) 25 | p = 
[x[0].cpu().detach().numpy() if to_numpy else x[0] for x in results] 26 | valid = [x[1].cpu().detach().numpy() if to_numpy else x[1] for x in results] 27 | mask_g = [x[2].cpu().detach().numpy() if to_numpy else x[2] for x in results] 28 | return p, valid, mask_g 29 | 30 | def plot_setup(self): 31 | fig = plt.figure() 32 | ax = fig.add_subplot(111, projection='3d') 33 | 34 | # generate color-ring 35 | if self.camera_count <= 6: 36 | colors = ['b','r','g','c','m','y'] 37 | else: 38 | colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] 39 | colors = colors * np.ceil(self.camera_count/len(colors)).astype(int) 40 | colors = colors[:self.camera_count] 41 | 42 | # draw screen 43 | seq = [0,2,1] # draw scene in [x,z,y] order 44 | self.screen.draw_points(ax, 'k', seq) 45 | 46 | # draw metrology part 47 | if self.lensgroup is not None: 48 | self.lensgroup.draw_points(ax, 'y.', seq) 49 | 50 | # draw cameras 51 | for i, camera in enumerate(self.cameras): 52 | camera.draw_points(ax, colors[i], seq) 53 | 54 | # pretty plot 55 | labels = 'xyz' 56 | scales = np.array([2,2,1]) 57 | plt.locator_params(nbins=5) 58 | ax.set_xlabel(labels[seq[0]] + ' [mm]') 59 | ax.set_ylabel(labels[seq[1]] + ' [mm]') 60 | ax.set_zlabel(labels[seq[2]] + ' [mm]') 61 | ax.legend(['Display', 'Camera 1', 'Camera 2']) 62 | ax.get_legend().legendHandles[0].set_color('k') 63 | for i in range(self.camera_count): 64 | ax.get_legend().legendHandles[i+1].set_color(colors[i]) 65 | ax.set_title('setup') 66 | set_axes_equal(ax, np.array([scales[seq[i]] for i in range(3)])) 67 | plt.show() 68 | 69 | def to(self, device=torch.device('cpu')): 70 | super().to(device) 71 | 72 | # set device name (TODO: make it more elegant) 73 | self.device = device 74 | self.lensgroup.device = device 75 | self.screen.device = device 76 | for i in range(self.camera_count): 77 | self.cameras[i].device = device 78 | for i in range(len(self.lensgroup.surfaces)): 79 | self.lensgroup.surfaces[i].device = device 80 | 81 | def 
_simulate(self, i=None, with_element=True, mask=None, smode=SimulationMode.render): 82 | def simulate(i): 83 | if mask is None: # default: full sensor rendering 84 | ray = self.cameras[i].sample_ray() 85 | 86 | # interaction with metrology part 87 | if with_element and self.lensgroup is not None: 88 | ray, valid_ray, mask_g = self.lensgroup.trace(ray) 89 | else: 90 | mask_g = torch.ones(ray.o.shape[0:2], device=self.device) 91 | valid_ray = mask_g.clone().bool() 92 | 93 | # interaction with screen 94 | # We permute the axes to be compatible with the image viewer (and the data). 95 | p, uv, valid_screen = self.screen.intersect(ray) 96 | valid = valid_screen & valid_ray 97 | 98 | if smode is SimulationMode.render: 99 | del p 100 | return self.screen.shading(uv, valid).permute(1,0) 101 | 102 | elif smode is SimulationMode.trace: 103 | del uv 104 | return p.permute(1,0,2), valid.permute(1,0), mask_g.permute(1,0) 105 | 106 | else: # get corresponding indices from the mask 107 | mask_ = mask[i].permute(1,0) # we transpose mask to align with our code 108 | ix, iy = torch.where(mask_) 109 | p2 = self.cameras[i].generate_position_sample(mask_) 110 | ray = self.cameras[i].sample_ray(p2) 111 | 112 | # interaction with metrology part 113 | if with_element and self.lensgroup is not None: 114 | ray, valid_ray, mask_g_ = self.lensgroup.trace(ray) 115 | else: 116 | mask_g_ = torch.ones(ray.o.shape[:-1], device=self.device) # masked rays are 1D: one entry per ray 117 | valid_ray = mask_g_.clone().bool() 118 | 119 | # interaction with screen 120 | # We permute the axes to be compatible with the image viewer (and the data). 121 | p_, uv, valid_screen = self.screen.intersect(ray) 122 | 123 | if smode is SimulationMode.render: 124 | del p_ 125 | raise NotImplementedError() 126 | 127 | elif smode is SimulationMode.trace: 128 | del uv 129 | p = torch.zeros(*self.cameras[i].filmsize, 3, device=self.device) 130 | p[ix, iy, ...]
= p_ 131 | valid = torch.zeros(*self.cameras[i].filmsize, device=self.device).bool() 132 | valid[ix, iy] = mask_[ix, iy] 133 | mask_g = torch.zeros(*self.cameras[i].filmsize, device=self.device) 134 | mask_g[ix, iy] = mask_g_ 135 | return p.permute(1,0,2), valid.permute(1,0), mask_g.permute(1,0) 136 | 137 | return [simulate(j) for j in range(self.camera_count)] if i is None else simulate(i) 138 | 139 | # ---------------------------------------------------------------------------------------- 140 | 141 | def generate_test_scene(): 142 | R = np.eye(3) 143 | ts = [np.array([0, 0, -300]), np.array([0, 0, -240])] 144 | cameras = [generate_test_camera(R, t) for t in ts] 145 | screen = generate_test_screen() 146 | lensgroup = generate_test_lensgroup() 147 | return Scene(cameras, screen, lensgroup) 148 | 149 | # ---------------------------------------------------------------------------------------- 150 | 151 | 152 | if __name__ == "__main__": 153 | init() 154 | 155 | scene = generate_test_scene() 156 | scene.plot_setup() 157 | 158 | # render images 159 | imgs = scene.render() 160 | fig, ax = plt.subplots(1,2) 161 | for i, img in enumerate(imgs): 162 | ax[i].imshow(img, cmap='gray') 163 | ax[i].set_title('camera ' + str(i+1)) 164 | plt.show() 165 | -------------------------------------------------------------------------------- /diffmetrology/shapes.py: -------------------------------------------------------------------------------- 1 | from .basics import * 2 | from matplotlib.image import imread 3 | 4 | 5 | class Endpoint(PrettyPrinter): 6 | def __init__(self, transformation, device=torch.device('cpu')): 7 | self.to_world = transformation 8 | self.to_object = transformation.inverse() 9 | self.to_world.to(device) 10 | self.to_object.to(device) 11 | self.device = device 12 | 13 | def intersect(self, ray): 14 | raise NotImplementedError() 15 | 16 | def sample_ray(self, position_sample=None): 17 | raise NotImplementedError() 18 | 19 | def draw_points(self, ax, options, 
seq=range(3)): 20 | raise NotImplementedError() 21 | 22 | 23 | class Screen(Endpoint): 24 | """ 25 | The local frame spans [-w/2, w/2] x [-h/2, h/2], where (w, h) is the screen size. 26 | """ 27 | def __init__(self, transformation, size, pixelsize, texture, device=torch.device('cpu')): 28 | self.size = torch.Tensor(np.float32(size)) # screen dimension [mm] 29 | self.halfsize = self.size/2 # screen half-dimension [mm] 30 | self.pixelsize = torch.Tensor([pixelsize]) # screen pixel size [mm] 31 | self.texture = torch.Tensor(texture) # screen image 32 | self.texturesize = torch.Tensor(np.array(texture.shape[0:2])) # screen image dimension [pixel] 33 | self.texturesize_np = self.texturesize.cpu().detach().numpy() # screen image dimension [pixel] 34 | self.texture_shift = torch.zeros(2) # screen image shift [mm] 35 | Endpoint.__init__(self, transformation, device) 36 | self.to(device) 37 | 38 | def intersect(self, ray): 39 | ray_in = self.to_object.transform_ray(ray) 40 | t = - ray_in.o[..., 2] / ray_in.d[..., 2] # well-posed for nonzero dz (TODO: potential NaN grad) 41 | local = ray_in(t) 42 | 43 | # Is intersection within ray segment and rectangle?
44 | valid = ( 45 | (t >= ray_in.mint) & 46 | (t <= ray_in.maxt) & 47 | (torch.abs(local[..., 0] - self.texture_shift[0]) <= self.halfsize[0]) & 48 | (torch.abs(local[..., 1] - self.texture_shift[1]) <= self.halfsize[1]) 49 | ) 50 | 51 | # uv map 52 | uv = (local[..., 0:2] + self.halfsize - self.texture_shift) / self.size 53 | 54 | # force uv to be valid in [0,1]^2 (just a sanity check: uv should be in [0,1]^2) 55 | uv = torch.clamp(uv, min=0.0, max=1.0) 56 | 57 | return local, uv, valid 58 | 59 | def shading(self, uv, valid, bmode=BoundaryMode.replicate, lmode=InterpolationMode.linear): 60 | p = uv * (self.texturesize-1) 61 | p_floor = torch.floor(p).long() 62 | 63 | def tex(x, y): # texture indexing function 64 | if bmode is BoundaryMode.zero: 65 | raise NotImplementedError() 66 | elif bmode is BoundaryMode.replicate: 67 | x = torch.clamp(x, min=0, max=self.texturesize_np[0]-1) 68 | y = torch.clamp(y, min=0, max=self.texturesize_np[1]-1) 69 | elif bmode is BoundaryMode.symmetric: 70 | raise NotImplementedError() 71 | elif bmode is BoundaryMode.periodic: 72 | raise NotImplementedError() 73 | img = self.texture[x.flatten(), y.flatten()] 74 | return img.reshape(x.shape) 75 | 76 | if lmode is InterpolationMode.nearest: 77 | val = tex(p_floor[...,0], p_floor[...,1]) 78 | elif lmode is InterpolationMode.linear: 79 | x0, y0 = p_floor[...,0], p_floor[...,1] 80 | s00 = tex( x0, y0) 81 | s01 = tex( x0, 1+y0) 82 | s10 = tex(1+x0, y0) 83 | s11 = tex(1+x0, 1+y0) 84 | w1 = p - p_floor 85 | w0 = 1. 
- w1 86 | val = ( 87 | w0[...,0] * (w0[...,1] * s00 + w1[...,1] * s01) + 88 | w1[...,0] * (w0[...,1] * s10 + w1[...,1] * s11) 89 | ) 90 | 91 | val[~valid] = 0.0 92 | return val 93 | 94 | def draw_points(self, ax, options, seq=range(3)): 95 | coeffs = np.array([ 96 | [ 1, 1, 1], 97 | [-1, 1, 1], 98 | [-1,-1, 1], 99 | [ 1,-1, 1], 100 | [ 1, 1, 1] 101 | ]) 102 | points_local = torch.Tensor(coeffs * np.append(self.halfsize.cpu().detach().numpy(), 0)).to(self.device) 103 | points_world = self.to_world.transform_point(points_local).T.cpu().detach().numpy() 104 | ax.plot(points_world[seq[0]], points_world[seq[1]], points_world[seq[2]], options) 105 | 106 | 107 | class Camera(Endpoint): 108 | def __init__(self, transformation, 109 | filmsize, f=np.zeros(2), c=np.zeros(2), k=np.zeros(3), p=np.zeros(2), device=torch.device('cpu')): 110 | self.filmsize = filmsize 111 | self.f = torch.Tensor(np.float32(f)) # focal lengths [pixel] 112 | self.c = torch.Tensor(np.float32(c)) # centers [pixel] 113 | Endpoint.__init__(self, transformation, device) 114 | 115 | # zero-initialized for now: 116 | self.crop_offset = torch.zeros(2, device=device) # [pixel] 117 | 118 | # distortion coefficients (unused for now): 119 | self.k = torch.Tensor(np.float32(k)) 120 | if len(self.k) < 3: self.k = torch.cat((self.k, torch.zeros(3 - len(self.k)))) # pad to length 3, keep as a tensor 121 | self.p = torch.Tensor(np.float32(p)) 122 | 123 | # configurations 124 | self.NEWTONS_MAXITER = 5 125 | self.NEWTONS_TOLERANCE = 50e-6 # in [mm], i.e. 50 [nm] here 126 | self.use_approximation = False 127 | self.to(device) 128 | 129 | def generate_position_sample(self, mask=None): 130 | """ 131 | Generate deterministic position samples at pixel centers (not a uniform random sampler), optionally restricted by a 2D mask.
132 | """ 133 | dim = self.filmsize 134 | X, Y = torch.meshgrid( 135 | 0.5 + dim[0] * torch.linspace(0, 1, 1+dim[0], device=self.device)[:-1], 136 | 0.5 + dim[1] * torch.linspace(0, 1, 1+dim[1], device=self.device)[:-1], 137 | ) 138 | if mask is not None: 139 | X, Y = X[mask], Y[mask] 140 | # X.shape could be 1D (masked) or 2D (no masked) 141 | return torch.stack((X, Y), axis=len(X.shape)) 142 | 143 | def sample_ray(self, position_sample=None, is_sampler=False): 144 | """ 145 | Sample ray(s) from sensor pixels. 146 | """ 147 | wavelength = torch.Tensor(np.asarray(562.0)).to(self.device) # 562 [nm] 148 | 149 | if position_sample is None: # default: full-sensor deterministic rendering 150 | dim = self.filmsize 151 | position_sample = self.generate_position_sample() 152 | is_sampler = False 153 | else: 154 | dim = position_sample.shape[:-1] 155 | 156 | if is_sampler: 157 | uv = position_sample * np.float32(dim) 158 | else: 159 | uv = position_sample 160 | 161 | # in local 162 | xy = self._uv2xy(uv) 163 | dz = torch.ones((*dim, 1), device=self.device) 164 | d = torch.cat((xy, dz), axis=-1) 165 | d = normalize(d) 166 | o = torch.zeros((*dim, 3), device=self.device) 167 | 168 | # in world 169 | o = self.to_world.transform_point(o) 170 | d = self.to_world.transform_vector(d) 171 | ray = Ray(o, d, wavelength, self.device) 172 | return ray 173 | 174 | def draw_points(self, ax, options, seq=range(3)): 175 | origin = np.zeros(3) 176 | scales = np.append(self.filmsize/100, 20) 177 | coeffs = np.array([ 178 | [ 1, 1, 1], 179 | [-1, 1, 1], 180 | [-1,-1, 1], 181 | [ 1,-1, 1], 182 | [ 1, 1, 1] 183 | ]) 184 | sensor_corners = torch.Tensor(coeffs * scales).to(self.device) 185 | ps = self.to_world.transform_point(sensor_corners).T.cpu().detach().numpy() 186 | ax.plot(ps[seq[0]], ps[seq[1]], ps[seq[2]], options) 187 | 188 | for i in range(4): 189 | coeff = coeffs[i] * scales 190 | line = torch.Tensor(np.array([ 191 | origin, 192 | coeff 193 | ])).to(self.device) 194 | ps = 
self.to_world.transform_point(line).T.cpu().detach().numpy() 195 | ax.plot(ps[seq[0]], ps[seq[1]], ps[seq[2]], options) 196 | 197 | def _uv2xy(self, uv): 198 | xy_distorted = (uv + self.crop_offset - self.c) / self.f 199 | xy = xy_distorted 200 | return xy 201 | 202 | # ---------------------------------------------------------------------------------------- 203 | 204 | def generate_test_camera(R=np.eye(3), t=np.array([0, 0, -400])): 205 | to_world = Transformation(R, t) 206 | 207 | filmsize = np.array([360, 480]) 208 | f = np.array([1000,1010]) # [pixel] 209 | c = np.array([150,160]) # [pixel] 210 | return Camera(to_world, filmsize, f, c) 211 | 212 | def generate_test_screen(): 213 | R = np.eye(3) 214 | t = np.zeros(3) 215 | to_world = Transformation(R, t) 216 | 217 | screensize = np.array([81., 80.]) # [mm] 218 | pixelsize = 0.115 # [mm] 219 | 220 | # read texture image 221 | im = imread('./images/checkerboard.png') 222 | im = np.mean(im, axis=-1) # for now we use grayscale 223 | im = im[200:500,200:600] 224 | # plt.figure() 225 | # plt.imshow(im, cmap='gray') 226 | # plt.show() 227 | 228 | return Screen(to_world, screensize, pixelsize, im) 229 | 230 | def test_camera(): 231 | camera = generate_test_camera() 232 | # print(camera) 233 | 234 | position_sample = torch.rand((*camera.filmsize, 2)) 235 | ray = camera.sample_ray(position_sample) 236 | print(ray) 237 | 238 | def test_screen(): 239 | screen = generate_test_screen() 240 | print(screen) 241 | 242 | # ---------------------------------------------------------------------------------------- 243 | 244 | 245 | if __name__ == "__main__": 246 | init() 247 | 248 | test_camera() 249 | test_screen() 250 | -------------------------------------------------------------------------------- /diffmetrology/solvers.py: -------------------------------------------------------------------------------- 1 | from .basics import * 2 | import numpy as np 3 | import torch.autograd.functional as F 4 | from skimage.restoration import 
unwrap_phase 5 | 6 | 7 | class Fringe(PrettyPrinter): 8 | """ 9 | Fringe image analysis to resolve displacements. 10 | """ 11 | def __init__(self): 12 | self.PHASE_SHIFT_COUNT = 4 13 | self.XY_COUNT = 2 14 | 15 | def solve(self, fs): 16 | """ 17 | ----- old ----- 18 | ref.shape = [len(Ts), self.XY_COUNT, camera_count, img_size] 19 | ----- old ----- 20 | 21 | Outputs: 22 | a.shape = b.shape = p.shape = 23 | [2, original size] 24 | 25 | where 2 denotes x and y. 26 | """ 27 | def single(fs): 28 | ax, bx, psix = self._solve(fs[0:self.PHASE_SHIFT_COUNT]) 29 | ay, by, psiy = self._solve(fs[self.PHASE_SHIFT_COUNT:self.XY_COUNT*self.PHASE_SHIFT_COUNT]) 30 | return np.array([ax, ay]), np.array([bx, by]), np.array([psix, psiy]) 31 | 32 | # TODO: following does not work when `len(Ts) == self.XY_COUNT*self.PHASE_SHIFT_COUNT` ... 33 | fsize = list(fs.shape) 34 | xy_index = fsize.index(self.XY_COUNT*self.PHASE_SHIFT_COUNT) 35 | inds = [i for i in range(len(fsize))] 36 | inds.remove(xy_index) 37 | inds = [xy_index] + inds 38 | 39 | # run the algorithm 40 | a, b, p = single(fs.transpose(inds)) 41 | 42 | return a, b, p 43 | 44 | def unwrap(self, fs, Ts, valid=None): 45 | print('unwrapping ...') 46 | if valid is None: 47 | valid = 1.0 48 | fs_masked = valid * fs # local name avoids shadowing the module alias `F` 49 | fs_unwrapped = np.zeros(fs.shape) 50 | for xy in range(fs.shape[0]): 51 | print(f'xy = {xy} ...') 52 | for T in range(fs.shape[1]): 53 | print(f't = {T} ...') 54 | t = Ts[T] 55 | for i in range(fs.shape[2]): 56 | fs_unwrapped[xy,T,i] = unwrap_phase(fs_masked[xy,T,i,...]) * t / (2*np.pi) 57 | return fs_unwrapped 58 | 59 | @staticmethod 60 | def _solve(fs): 61 | r"""Solver for four-step phase shifting: f(x) = a + b cos(\phi + \psi).
62 | b cos(\psi) = fs[0] - a 63 | - b sin(\psi) = fs[1] - a 64 | - b cos(\psi) = fs[2] - a 65 | b sin(\psi) = fs[3] - a 66 | """ 67 | a = np.mean(fs, axis=0) 68 | b = 0.0 69 | for f in fs: 70 | b += (f - a)**2 71 | b = b/2.0 # NOTE: sum_k (fs[k] - a)**2 = 2 b**2, so this value is b**2, not b 72 | psi = np.arctan2(fs[3] - fs[1], fs[0] - fs[2]) 73 | return a, b, psi 74 | 75 | 76 | class Optimization(PrettyPrinter): 77 | """ 78 | General class for design optimization. 79 | """ 80 | def __init__(self, diff_variables): 81 | self.diff_variables = diff_variables 82 | for v in self.diff_variables: 83 | v.requires_grad = True 84 | self.optimizer = None # to be initialized by subclasses 85 | 86 | class Adam(Optimization): 87 | def __init__(self, diff_variables, lr, lrs=None, beta=0.99): 88 | Optimization.__init__(self, diff_variables) 89 | if lrs is None: 90 | lrs = [1] * len(self.diff_variables) 91 | self.optimizer = torch.optim.Adam( 92 | [{"params": v, "lr": lr*l} for v, l in zip(self.diff_variables, lrs)], 93 | betas=(beta,0.999), amsgrad=True 94 | ) 95 | 96 | def optimize(self, func, loss, maxit=300, record=True): 97 | print('optimizing ...') 98 | ls = [] 99 | with torch.autograd.set_detect_anomaly(False): #True 100 | for it in range(maxit): 101 | L = loss(func()) 102 | self.optimizer.zero_grad() 103 | L.backward(retain_graph=True) 104 | 105 | if record: 106 | grads = torch.Tensor([torch.mean(torch.abs(v.grad)) for v in self.diff_variables]) 107 | print('iter = {}: loss = {:.4e}, grad_bar = {:.4e}'.format( 108 | it, L.item(), torch.mean(grads) 109 | )) 110 | ls.append(L.cpu().detach().numpy()) 111 | 112 | self.optimizer.step() 113 | 114 | return np.array(ls) 115 | 116 | 117 | class LM(Optimization): 118 | """ 119 | The Levenberg–Marquardt algorithm.
120 | """ 121 | def __init__(self, diff_variables, lamb, mu=None, option='diag'): 122 | Optimization.__init__(self, diff_variables) 123 | self.lamb = lamb # damping factor 124 | self.mu = 2.0 if mu is None else mu # damping rate (>1) 125 | self.option = option 126 | 127 | def jacobian(self, func, inputs, create_graph=False, strict=False): 128 | """Constructs an M-by-N Jacobian matrix where M >> N. 129 | 130 | Here, computing the Jacobian only makes sense for a tall Jacobian matrix. In this case, 131 | column-wise evaluation (forward-mode, or jvp) is more efficient for constructing the Jacobian. 132 | 133 | This function is modified from torch.autograd.functional.jvp(). 134 | """ 135 | 136 | Js = [] 137 | outputs = func() 138 | M = outputs.shape 139 | 140 | grad_outputs = (torch.zeros_like(outputs, requires_grad=True),) 141 | for x in inputs: 142 | grad_inputs = F._autograd_grad( 143 | (outputs,), x, grad_outputs, create_graph=True 144 | ) 145 | 146 | F._check_requires_grad(grad_inputs, "grad_inputs", strict=strict) 147 | 148 | # Construct Jacobian matrix 149 | N = torch.numel(x) 150 | if N == 1: 151 | J = F._autograd_grad( 152 | grad_inputs, grad_outputs, (torch.ones_like(x),), 153 | create_graph=create_graph, 154 | retain_graph=True 155 | )[0][...,None] 156 | else: 157 | J = torch.zeros((*M, N), device=x.device) 158 | v = torch.zeros(N, device=x.device) 159 | for i in range(N): 160 | v[i] = 1.0 161 | J[...,i] = F._autograd_grad( 162 | grad_inputs, grad_outputs, (v.view(x.shape),), 163 | create_graph=create_graph, 164 | retain_graph=True 165 | )[0] 166 | 167 | v[i] = 0.0 168 | Js.append(J) 169 | return torch.cat(Js, axis=-1) 170 | 171 | def optimize(self, func, change_parameters, func_yref_y, maxit=300, record=True): 172 | """ 173 | Inputs: 174 | - func: Evaluate `y = f(x)`, where `x` denotes the parameters held in `self.diff_variables` (defined outside the class) 175 | - change_parameters: Apply an update to `self.diff_variables` (defined outside the class) 176 | - func_yref_y: Compute `y_ref -
y` 177 | 178 | Outputs: 179 | - ls: Loss history (mean squared residual, one value per iteration). 180 | """ 181 | print('optimizing ...') 182 | Ns = [x.numel() for x in self.diff_variables] 183 | NS = [[*x.shape] for x in self.diff_variables] 184 | 185 | ls = [] 186 | lamb = self.lamb 187 | with torch.autograd.set_detect_anomaly(False): 188 | for it in range(maxit): 189 | y = func() 190 | with torch.no_grad(): 191 | L = torch.mean(func_yref_y(y)**2).item() 192 | if L < 1e-16: 193 | print('L too small; terminate.') 194 | break 195 | 196 | # Obtain Jacobian 197 | J = self.jacobian(func, self.diff_variables, create_graph=False) 198 | J = J.view(-1, J.shape[-1]) 199 | JtJ = J.T @ J 200 | N = JtJ.shape[0] 201 | 202 | # Regularization matrix 203 | if self.option == 'I': 204 | R = torch.eye(N, device=JtJ.device) 205 | elif self.option == 'diag': 206 | R = torch.diag(torch.diag(JtJ).abs()) 207 | else: 208 | R = torch.diag(self.option) 209 | 210 | # Compute b = J.T @ (y_ref - y) 211 | bb = [ 212 | torch.autograd.grad(outputs=y, inputs=x, grad_outputs=func_yref_y(y), retain_graph=True)[0] 213 | for x in self.diff_variables 214 | ] 215 | for i, bx in enumerate(bb): 216 | if len(bx.shape) == 0: # single scalar 217 | bb[i] = torch.Tensor([bx.item()]).to(y.device) 218 | if len(bx.shape) > 1: # multi-dimension 219 | bb[i] = torch.Tensor(bx.cpu().detach().numpy().flatten()).to(y.device) 220 | b = torch.cat(bb, axis=-1) 221 | del J, bb, y 222 | 223 | # Damping loop 224 | L_current = L + 1.0 225 | it_inner = 0 226 | while L_current >= L: 227 | it_inner += 1 228 | if it_inner > 20: 229 | print('Too many inner iterations; Exiting damping loop.') 230 | break 231 | 232 | A = JtJ + lamb * R 233 | x_delta = torch.linalg.solve(A, b) 234 | if torch.isnan(x_delta).sum(): 235 | print('x_delta NaN; Exiting damping loop') 236 | break 237 | x_delta_s = torch.split(x_delta, Ns) 238 | 239 | # Reshape if x is not a 1D array 240 | x_delta_s = [*x_delta_s] 241 | for xi in range(len(x_delta_s)): 242 | x_delta_s[xi] = torch.reshape(x_delta_s[xi], NS[xi])
243 | 244 | # Update `x += x_delta` (this is done in external function `change_parameters`) 245 | self.diff_variables = change_parameters(x_delta_s, sign=True) 246 | 247 | # Calculate new error 248 | with torch.no_grad(): 249 | L_current = torch.mean(func_yref_y(func())**2).item() 250 | 251 | del A 252 | 253 | # Terminate 254 | if L_current < L: 255 | lamb /= self.mu 256 | del x_delta_s 257 | break 258 | 259 | # Else, increase damping and undo the update 260 | lamb *= 2.0*self.mu 261 | self.diff_variables = change_parameters(x_delta_s, sign=False) 262 | 263 | if lamb > 1e16: 264 | print('lambda too big; Exiting damping loop.') 265 | del x_delta_s 266 | break 267 | 268 | del JtJ, R, b 269 | 270 | if record: 271 | x_increment = torch.mean(torch.abs(x_delta)).item() 272 | print('iter = {}: loss = {:.4e}, |x_delta| = {:.4e}'.format( 273 | it, L, x_increment 274 | )) 275 | ls.append(L) 276 | if it > 0: 277 | dls = np.abs(ls[-2] - L) 278 | if dls < 1e-8: 279 | print(r"|\Delta loss| = {:.4e} < 1e-8; Exiting LM loop.".format(dls)) 280 | break 281 | 282 | if x_increment < 1e-8: 283 | print("|x_delta| = {:.4e} < 1e-8; Exiting LM loop.".format(x_increment)) 284 | break 285 | return ls 286 | -------------------------------------------------------------------------------- /diffmetrology/utils.py: -------------------------------------------------------------------------------- 1 | from .scene import * 2 | from .solvers import * 3 | import scipy 4 | import scipy.optimize 5 | import scipy.io 6 | import scipy.ndimage 7 | from matplotlib.image import imread 8 | import matplotlib.ticker as plticker 9 | import time 10 | 11 | 12 | def var2string(variable): 13 | for name in globals(): 14 | if eval(name) == variable: 15 | return name 16 | return '' 17 | 18 | class DiffMetrology(PrettyPrinter): 19 | """ 20 | Main class: holds the calibrated scene (cameras, screen, lens group) and provides simulation, tracing, and visualization utilities. 21 | """ 22 | def __init__(self, 23 | calibration_path, rotation_path, 24 | origin_shift, lut_path=None, thetas=None, angles=0., 25 | scale=1,
device=torch.device('cpu')): 26 | 27 | self.device = device 28 | self.MAX_VAL = 2**16 - 1.0 # 16-bit image 29 | 30 | # geometry setup 31 | mat_g = scipy.io.loadmat(calibration_path + 'cams.mat') 32 | cameras = self._init_camera(mat_g, scale) 33 | screen = self._init_screen(mat_g) 34 | self.scene = Scene(cameras, screen, device=device) 35 | 36 | # cache calibration checkerboard image for testing 37 | self._checkerboard = imread(calibration_path + 'checkerboard.png') 38 | self._checkerboard = np.mean(self._checkerboard, axis=-1) # for now we use grayscale 39 | self._checkerboard = np.flip(self._checkerboard, axis=1).copy() 40 | self._checkerboard_wh = np.array([mat_g['w'], mat_g['h']]) 41 | 42 | # lensgroup metrology part setup 43 | mat_r = scipy.io.loadmat(rotation_path) 44 | p_rotation = torch.Tensor(np.stack((mat_r['p1'][0], mat_r['p2'][0]), axis=-1)).T 45 | origin_mount = self._compute_mount_geometry(p_rotation*scale, verbose=True) 46 | if thetas is None: 47 | self.scene.lensgroup = Lensgroup(origin_mount, origin_shift, 0.0, 0.0, 0.0, device) 48 | else: 49 | self.scene.lensgroup = Lensgroup(origin_mount, origin_shift, thetas[0], thetas[1], 0.0, device) 50 | 51 | if type(angles) is not list: 52 | angles = [angles] 53 | self.angles = angles 54 | 55 | # load sensor LUT 56 | if lut_path is not None: 57 | self.lut = scipy.io.loadmat(lut_path)['Js'][:,:self.scene.camera_count] 58 | self.bbd = scipy.io.loadmat(lut_path)['bs'].reshape((2,self.scene.camera_count)) 59 | 60 | self.ROTATION_ANGLE = -25.0 61 | 62 | # === Utility methods === 63 | def solve_for_intersections(self, fs_cap, fs_ref, Ts): 64 | """ 65 | Obtain the intersection points from phase-shifting images. 
66 | """ 67 | FR = Fringe() 68 | a_ref, b_ref, psi_ref = FR.solve(fs_ref) 69 | a_cap, b_cap, psi_cap = FR.solve(fs_cap) 70 | 71 | VERBOSE = False 72 | # VERBOSE = True 73 | 74 | def find_center(valid_cap): 75 | x, y = np.argwhere(valid_cap==1).sum(0)/ valid_cap.sum() 76 | return np.array([x, y]) 77 | 78 | # get valid map for two cameras 79 | valid_map = [] 80 | A = np.mean(a_cap, axis=(0,1)) 81 | B = np.mean(b_cap, axis=(0,1)) 82 | I = np.abs(np.mean(fs_ref - fs_cap, axis=(0,1))) 83 | for j in range(self.scene.camera_count): 84 | # thres = 0.005 85 | # thres = 0.1 86 | thres = 0.07 87 | valid_ab = (A[j] > thres) & (B[j] > thres) & (I[j] < 2.0*thres) 88 | if VERBOSE: 89 | plt.imshow(valid_ab); plt.show() 90 | label, num_features = scipy.ndimage.label(valid_ab) 91 | if num_features < 2: 92 | label_target = 1 93 | else: # count labels and get the area as the target measurement area 94 | counts = np.array([np.count_nonzero(label == i) for i in range(num_features)]) 95 | label_targets = np.where((200**2 < counts) & (counts < 500**2))[0] 96 | if len(label_targets) > 1: # find which label target is our lens 97 | Dm = np.inf 98 | for l in label_targets: 99 | c = find_center(label == l) 100 | D = np.abs(self.scene.cameras[0].filmsize/2 - c).sum() 101 | if D < Dm: 102 | Dm = D 103 | label_target = l 104 | else: 105 | label_target = label_targets[0] 106 | V = label == label_target 107 | valid_map.append(V) 108 | valid_map = np.array(valid_map) 109 | 110 | psi_unwrap = FR.unwrap(psi_cap - psi_ref, Ts, valid=valid_map[None,None,...]) 111 | 112 | # Given the unwrapped phase, we try to remove the DC term so that |psi_unwrap|^2 is minimized 113 | # NOTE: This method is not yet robust 114 | # remove_DC = False 115 | remove_DC = True 116 | if remove_DC: 117 | k_DC = np.arange(-10,11,1) 118 | for t in range(len(Ts)): 119 | DCs = k_DC * Ts[t] 120 | psi_current = psi_unwrap[:,t,:,...] 
121 | psi_with_dc = valid_map[:,None,:,:,None] * (psi_current[...,None] + DCs[None,None,None,None,...]) 122 | DC_target = DCs[np.argmin(np.sum(psi_with_dc**2, axis=(2,3)), axis=-1)] 123 | print("t = {}, DC_target =\n{}".format(Ts[t], DC_target)) 124 | psi_unwrap[:,t,:,...] += valid_map[None,:,...] * np.transpose(DC_target,(0,1))[...,None,None] 125 | 126 | if VERBOSE: 127 | for t in range(len(Ts)): 128 | for ii in range(2): 129 | plt.figure() 130 | plt.imshow(psi_unwrap[0,t,ii,...], cmap='coolwarm') 131 | plt.show() 132 | 133 | # Convert unit from [pixel number] to [mm] 134 | psi_x = psi_unwrap[0, ...] * self.scene.screen.pixelsize.item() 135 | psi_y = psi_unwrap[1, ...] * self.scene.screen.pixelsize.item() 136 | 137 | # Get mean of the values across different Ts 138 | psi_x = np.mean(psi_x, axis=0) 139 | psi_y = np.mean(psi_y, axis=0) 140 | 141 | # Compute intersection points when there are no elements 142 | ps_ref = torch.stack(self.trace(with_element=False, angles=0.0)[0]) 143 | ps_ref = valid_map[...,None] * ps_ref[...,0:2].cpu().detach().numpy() 144 | 145 | # Compute final shift (here, the valid map is the valid map of x of the first T) 146 | # NOTE: here we flip the sequence of (x,y) to fit the format used in later processing ... 147 | p = ps_ref - np.stack((psi_y, psi_x), axis=-1) 148 | 149 | # Find valid map centers 150 | xs = [] 151 | ys = [] 152 | for i in range(len(valid_map)): 153 | x, y = np.argwhere(valid_map[i]==1).sum(0)/ valid_map[i].sum() 154 | xs.append(x) 155 | ys.append(y) 156 | centers = np.stack((np.array(xs), np.array(ys)), axis=-1) 157 | centers = np.fliplr(centers).copy() # flip to fit our style 158 | 159 | return torch.Tensor(p).to(self.device), torch.Tensor(valid_map).bool().to(self.device), torch.Tensor(centers).to(self.device) 160 | 161 | 162 | def simulation(self, sinusoid_path, Ts, i=None, angles=None, to_numpy=False): 163 | """ 164 | Render fringe images.
165 | 166 | img.shape = [ len(Ts), 8, self.camera_count, img_size ] 167 | """ 168 | print('Simulation ...') 169 | 170 | # cache current screen 171 | screen_org = self.scene.screen 172 | pixelsize = screen_org.pixelsize.item() 173 | 174 | def single_impl(with_element): 175 | """ 176 | imgs_all.shape = len(Ts) * [0-8] * camera_count * img_size. 177 | """ 178 | imgs_all = [] 179 | for T in Ts: 180 | print(f'Now at T = {T}') 181 | img_path = sinusoid_path + 'T=' + str(T) + '/' 182 | 183 | # with elements 184 | imgs = [] 185 | for i in range(8): 186 | # read sinusoid images 187 | im = imread(img_path + str(i) + '.png') 188 | im = np.mean(im, axis=-1) # for now we use grayscale 189 | im = np.flip(im).copy() # NOTE: (i) our display is rotated by 90 deg; 190 | # (ii) to be consistent with XY convention here. So the flip. 191 | 192 | # set screen to be sinusoid patterns 193 | sizenew = pixelsize * np.array(im.shape) 194 | t = np.array([sizenew[0]/2-50.0, sizenew[1]/2-80.0, 0]) 195 | self.scene.screen = Screen(Transformation(np.eye(3), t), sizenew, pixelsize, im, self.device) 196 | 197 | # render 198 | tmp = self.scene.render(with_element=with_element, to_numpy=True) 199 | imgs.append(np.array(tmp)) 200 | imgs_all.append(np.array(imgs)) 201 | 202 | return np.array(imgs_all) 203 | 204 | def single(angle): 205 | self.scene.lensgroup.update(_y=angle) 206 | imgs = single_impl(True) 207 | self.scene.lensgroup.update(_y=-angle) 208 | return imgs 209 | 210 | # we only capture reference once 211 | print('Simulating reference ...') 212 | refs = single_impl(False) 213 | 214 | # here, we measure testing part in each angle 215 | print('Simulating measurements ...') 216 | if angles is None: 217 | """ 218 | imgs_rendered.shape = len(angles) * len(Ts) * [0-8] * camera_count * img_size. 
219 |             """
220 |             ims = []
221 |             for angle in self.angles:
222 |                 ims.append(single(angle))
223 |             ims = np.array(ims)
224 |         else:
225 |             print(f'Now at angle = {angles}')
226 |             ims = single(angles)
227 |
228 |         # revert back to the original screen
229 |         self.scene.screen = screen_org
230 |
231 |         print('Done ...')
232 |         return ims, refs
233 |
234 |     def render(self, i=None, with_element=True, mask=None, angles=None, to_numpy=False):
235 |         """
236 |         Render images for every angle in self.angles, or only at the angle given by `angles`.
237 |         """
238 |         def single(angle):
239 |             self.scene.lensgroup.update(_y=angle)
240 |             im = self.scene.render(i, with_element, mask, to_numpy)
241 |             self.scene.lensgroup.update(_y=-angle)
242 |             return im
243 |
244 |         if angles is None:
245 |             ims = []
246 |             for angle in self.angles:
247 |                 ims += single(angle)
248 |         else:
249 |             ims = single(angles)
250 |         return ims
251 |
252 |     def trace(self, i=None, with_element=True, mask=None, angles=None, to_numpy=False):
253 |         """
254 |         Trace rays for every angle in self.angles, or only at the angle given by `angles`.
255 |         """
256 |         def single(angle):
257 |             self.scene.lensgroup.update(_y=angle)
258 |             ps, valid, mask_g = self.scene.trace(i, with_element, mask, to_numpy)
259 |             self.scene.lensgroup.update(_y=-angle)
260 |             return ps, valid, mask_g
261 |
262 |         if angles is None:
263 |             ps = []
264 |             valid = []
265 |             mask_g = []
266 |             for angle in self.angles:
267 |                 pss, valids, mask_gs = single(angle)
268 |                 ps += pss
269 |                 valid += valids
270 |                 mask_g += mask_gs
271 |         else:
272 |             ps, valid, mask_g = single(angles)
273 |         return ps, valid, mask_g
274 |
275 |     # =====================
276 |
277 |     def to(self, device=torch.device('cpu')):
278 |         super().to(device)
279 |         self.device = device
280 |         self.scene.to(device)
281 |
282 |     # === Visualizations ===
283 |     def imshow(self, imgs):
284 |         N = self.scene.camera_count
285 |         self._imshow(imgs[0:N], title='front')
286 |         if len(imgs) > N:
287 |             self._imshow(imgs[N:2*N], title='back')
288 |         plt.show()
289 |
290 |     def _imshow(self, imgs, title=''):
291 |         ax = plt.subplots(1,len(imgs))[1]
292 |         for i, img in enumerate(imgs):
293 |             ax[i].imshow(img, cmap='gray')
294 |             ax[i].set_title(title + ': camera ' + str(i+1))
295 |
296 |     def spot_diagram(self, ps_ref, ps_cap, valid=True, angle=None, with_grid=True):
297 |         """
298 |         Plot the spot diagrams of all cameras at the given rotation `angle`.
299 |         """
300 |         N = self.scene.camera_count
301 |         for j, a in enumerate(self.angles):
302 |             if angle == a:
303 |                 i = j
304 |         try:
305 |             i
306 |         except NameError:
307 |             i = 0
308 |         figure = self._spot_diagram(ps_ref[N*i:N*(i+1)], ps_cap[N*i:N*(i+1)], valid[N*i:N*(i+1)], title=f'angle={angle}', with_grid=with_grid)
309 |         figure.suptitle('Spot Diagram')
310 |         return figure
311 |
312 |     def _spot_diagram(self, ps_ref, ps_cap, valid=True, title='', with_grid=False):
313 |         """
314 |         Plot one spot-diagram subplot per camera.
315 |         """
316 |         figure, ax = plt.subplots(1, self.scene.camera_count)
317 |
318 |         def sub_sampling(x):
319 |             Ns = [8,8]
320 |             return x[::Ns[0],::Ns[1],...]
321 |         for i in range(len(ax)):
322 |             mask = sub_sampling(valid[i])
323 |             ref = sub_sampling(ps_ref[i])[mask].cpu().detach().numpy()
324 |             cap = sub_sampling(ps_cap[i])[mask].cpu().detach().numpy()
325 |             ax[i].plot(ref[...,0], ref[...,1], 'b.', label='Measurement')
326 |             ax[i].plot(cap[...,0], cap[...,1], 'r.', label='Modeled (reprojection)')
327 |             ax[i].legend()
328 |             ax[i].set_xlabel('[mm]')
329 |             ax[i].set_ylabel('[mm]')
330 |             ax[i].set_aspect(1)
331 |             ax[i].set_title(title + ': camera ' + str(i+1))
332 |
333 |             # Add the grid
334 |             if with_grid:
335 |                 loc = plticker.MultipleLocator(base=4*self.scene.screen.pixelsize.item())
336 |                 ax[i].xaxis.set_major_locator(loc)
337 |                 ax[i].yaxis.set_major_locator(loc)
338 |                 ax[i].grid(which='major', axis='both', linestyle='-')
339 |                 ax[i].tick_params(axis='both', which='minor', width=0)
340 |         return figure
341 |
342 |     def generate_grid(self, R):
343 |         N = 513
344 |         x = y = torch.linspace(-R, R, N, device=self.device)
345 |         X, Y = torch.meshgrid(x, y)
346 |         valid = X**2 + Y**2 <= R**2
347 |         return X, Y, valid
348 |
349 |     def show_surfaces(self, verbose=True):
350 |         if verbose:
351 |             ax = plt.subplots(1, len(self.scene.lensgroup.surfaces))[1]
352 |         Zs = []
353 |         valids = []
354 |         for i, surface in enumerate(self.scene.lensgroup.surfaces):
355 |             X, Y, valid = self.generate_grid(surface.r)
356 |             Z = surface.surface(X, Y)
357 |             Z_mean = Z[valid].mean().item()
358 |             Z = torch.where(valid, Z - Z_mean, torch.zeros_like(Z)).cpu().detach().numpy()
359 |             valids.append(valid)
360 |             Zs.append(Z)
361 |             if verbose:
362 |                 im = ax[i].imshow(Z, cmap='jet')
363 |                 ax[i].set_title('surface ' + str(i))
364 |                 plt.colorbar(im, ax=ax[i])
365 |         if verbose:
366 |             plt.show()
367 |         return Zs, valids
368 |
369 |     def print_surfaces(self):
370 |         if self.scene.lensgroup is None:
371 |             print('No surfaces found; please initialize lensgroup!')
372 |         else:
373 |             for i, s in enumerate(self.scene.lensgroup.surfaces):
374 |                 print("surface[{}] = {}".format(i, s))
375 |
376 |     # =====================
377 |
378 |     # === Optimizations ===
379 |     def change_parameters(self, diff_parameters_names, xs, sign=True):
380 |         diff_parameters = []
381 |         for i, name in enumerate(diff_parameters_names):
382 |             if sign:
383 |                 exec('self.scene.{name} = self.scene.{name} + xs[{i}]'.format(name=name,i=i))
384 |             else:
385 |                 exec('self.scene.{name} = self.scene.{name} - xs[{i}]'.format(name=name,i=i))
386 |             exec('diff_parameters.append(self.scene.{})'.format(name))
387 |         return diff_parameters
388 |
389 |     def solve(self, diff_parameters_names, forward, loss, func_yref_y=None, option='LM', R=None):
390 |         """
391 |         Solve for unknown parameters.
392 |         """
393 |         # def loss(I):
394 |         #     return (I - I0).mean()
395 |
396 |         # def func_yref_y(I):
397 |         #     return I0 - I
398 |
399 |         time_start = time.time()
400 |
401 |         diff_parameters = []
402 |         for name in diff_parameters_names:
403 |             try:
404 |                 exec('self.scene.{}.requires_grad = True'.format(name))
405 |             except RuntimeError: # non-leaf tensor: detach it first, then enable gradients
406 |                 exec('self.scene.{name} = self.scene.{name}.detach()'.format(name=name))
407 |                 exec('self.scene.{}.requires_grad = True'.format(name))
408 |             exec('diff_parameters.append(self.scene.{})'.format(name))
409 |
410 |         if option == 'Adam':
411 |             O = Adam(
412 |                 diff_variables=diff_parameters,
413 |                 lr=1e-1,
414 |                 beta=0.99
415 |             )
416 |             ls = O.optimize(
417 |                 forward,
418 |                 loss,
419 |                 maxit=200
420 |             )
421 |
422 |         elif option == 'LM':
423 |             if func_yref_y is None:
424 |                 raise Exception("func_yref_y is not given!")
425 |
426 |             if R is None:
427 |                 Ropt = 'I'
428 |             else:
429 |                 Ropt = R
430 |             O = LM(
431 |                 diff_variables=diff_parameters,
432 |                 lamb=1e-1, # 1e-4
433 |                 option=Ropt
434 |             )
435 |             ls = O.optimize(
436 |                 forward,
437 |                 lambda xs, sign: self.change_parameters(diff_parameters_names, xs, sign),
438 |                 func_yref_y=func_yref_y,
439 |                 maxit=100
440 |             )
441 |
442 |         else:
443 |             raise NotImplementedError()
444 |
445 |         for name in diff_parameters_names:
446 |             print('self.scene.{} = '.format(name), end='')
447 |             exec('print(self.scene.{}.cpu().detach().numpy())'.format(name))
448 |
449 |         if torch.cuda.is_available(): torch.cuda.synchronize() # also runs on CPU-only machines
450 |         time_end = time.time()
451 |
452 |         print('Elapsed time = {:e} seconds'.format(time_end - time_start))
453 |
454 |         return ls
455 |     # =====================
456 |
457 |
458 |     # === Optimizations ===
459 |     def init_diff_parameters(self, dicts=None, pose_dict=None):
460 |         self.diff_parameters = {}
461 |         print('Initializing differentiable parameters:')
462 |         if dicts is not None:
463 |             for i, dictionary in enumerate(dicts):
464 |                 if dictionary is None:
465 |                     continue # skip if there is none
466 |                 if type(dictionary) is dict:
467 |                     for key, value in dictionary.items():
468 |                         keystr = 'surfaces[{}].{}'.format(i, key)
469 |                         full_keystr = 'self.scene.lensgroup.' + keystr
470 |                         print('--- ' + full_keystr)
471 |                         self.diff_parameters[keystr] = torch.Tensor(np.asarray(value)).to(self.device)
472 |                         exec('{} = self.diff_parameters[keystr].clone()'.format(full_keystr))
473 |                         exec('{}.requires_grad = True'.format(full_keystr))
474 |                 elif type(dictionary) is set:
475 |                     for key in dictionary:
476 |                         keystr = 'surfaces[{}].{}'.format(i, key)
477 |                         full_keystr = 'self.scene.lensgroup.' + keystr
478 |                         print('--- ' + full_keystr)
479 |                         exec('self.diff_parameters[keystr] = {}.clone()'.format(full_keystr))
480 |                         exec('{}.requires_grad = True'.format(full_keystr))
481 |                 else:
482 |                     raise Exception("wrong type in dicts: expected dict or set!")
483 |         if pose_dict is not None:
484 |             if type(pose_dict) is dict:
485 |                 for keystr, value in pose_dict.items():
486 |                     full_keystr = 'self.scene.lensgroup.' + keystr
487 |                     print('--- ' + full_keystr)
488 |                     self.diff_parameters[keystr] = torch.Tensor(np.asarray(value)).to(self.device)
489 |                     exec('{} = self.diff_parameters[keystr].clone()'.format(full_keystr))
490 |                     exec('{}.requires_grad = True'.format(full_keystr))
491 |             elif type(pose_dict) is set:
492 |                 for keystr in pose_dict:
493 |                     full_keystr = 'self.scene.lensgroup.' + keystr
494 |                     print('--- ' + full_keystr)
495 |                     exec('self.diff_parameters[keystr] = {}.clone()'.format(full_keystr))
496 |                     exec('{}.requires_grad = True'.format(full_keystr))
497 |             else:
498 |                 raise Exception("wrong type of pose_dict: expected dict or set!")
499 |         print('... Done.')
500 |
501 |     def print_diff_parameters(self):
502 |         for key in self.diff_parameters.keys():
503 |             full_keystr = 'self.scene.lensgroup.' + key
504 |             exec("print('-- {{}} = {{}}'.format(key, {}.cpu().detach().numpy()))".format(full_keystr))
505 |
506 |     @staticmethod
507 |     def compRMS(Zs, Zs_gt, valid):
508 |         def RMS(x, y, valid):
509 |             x = x[valid]
510 |             y = y[valid]
511 |             x -= x.mean()
512 |             y -= y.mean()
513 |             tmp = (x - y)**2
514 |             return np.sqrt(np.mean(tmp))
515 |
516 |         rmss = []
517 |         for i in range(len(Zs_gt)):
518 |             rms = RMS(Zs[i], Zs_gt[i], valid[i].cpu().detach().numpy())
519 |             print('RMS = {} [mm]'.format(rms))
520 |             rmss.append(rms)
521 |         return rmss
522 |
523 |     def ploterror(self, Zs, Zs_gt, verbose=True):
524 |         figure, ax = plt.subplots(1, len(self.scene.lensgroup.surfaces), figsize=(9,3.5))
525 |         for i, s in enumerate(self.scene.lensgroup.surfaces):
526 |             tmp = Zs[i] - Zs_gt[i]
527 |             im = ax[i].imshow(tmp, cmap='jet')
528 |             ax[i].set_title('surface ' + str(i))
529 |             plt.colorbar(im, ax=ax[i])
530 |         if verbose:
531 |             plt.show()
532 |         return figure
533 |
534 |     # =====================
535 |
536 |     def set_texture(self, textures):
537 |         if len(textures.shape) > 2:
538 |             texture = textures[0]
539 |         else:
540 |             texture = textures
541 |         pixelsize = self.scene.screen.pixelsize.item()
542 |         sizenew = pixelsize * np.array(texture.shape)
543 |         t = np.zeros(3)
544 |         self.scene.screen = Screen(Transformation(np.eye(3), t), sizenew, pixelsize, np.flip(texture).copy(), self.device)
545 |
546 |     # === Tests ===
547 |     def test_setup(self, verbose=True):
548 |         """
549 |         Test if the setup is correct: render the checkerboard without the lens to check rendered and measured images for consistency.
550 |         """
551 |         # cache current screen
552 |         screen_org = self.scene.screen
553 |         lensgroup = self.scene.lensgroup
554 |         self.scene.lensgroup = None
555 |
556 |         # set screen to be checkerboard
557 |         pixelsize = screen_org.pixelsize.item()
558 |         sizenew = pixelsize * np.array(self._checkerboard.shape)
559 |         t = np.array([sizenew[0]/2, sizenew[1]/2, 0])
560 |         t[0] -= sizenew[0]/self._checkerboard_wh[0]
561 |         t[1] -= sizenew[1]/self._checkerboard_wh[1]
562 |         self.scene.screen = Screen(Transformation(np.eye(3), t), sizenew, pixelsize, self._checkerboard, self.device)
563 |
564 |         # render
565 |         imgs_rendered = self.scene.render()
566 |         if verbose:
567 |             self.scene.plot_setup()
568 |             plt.show()
569 |
570 |         # revert back to the original screen
571 |         self.scene.screen = screen_org
572 |         self.scene.lensgroup = lensgroup
573 |         return imgs_rendered
574 |     # =====================
575 |
576 |     # === Internal methods ===
577 |     # parse cameras and screen parameters
578 |     def _init_camera(self, mat, scale=1):
579 |         filmsize = (mat['filmsize'] * scale).astype(int)
580 |         f = mat['f'] * scale
581 |         c = mat['c'] * scale
582 |         k = mat['k'] * scale
583 |         p = mat['p'] * scale
584 |         R = mat['R']
585 |         t = mat['t']
586 |         def matlab2ours(R, t):
587 |             """Convert MATLAB [R | t] convention to ours:
588 |             In MATLAB (https://www.mathworks.com/help/vision/ug/camera-calibration.html):
589 |                 w_scale [x_obj y_obj 1] = [x_world y_world z_world] R_matlab + t_matlab
590 |                         [     1x3     ]   [          1x3          ] [3x3]      [1x3]
591 |
592 |             Ours:
593 |                 [x_obj y_obj 1]' = R_object [x_world y_world z_world]' + t_object
594 |                 [     3x1     ]    [ 3x3  ] [          3x1          ]    [3x1]
595 |
596 |             We would like to get the world-coordinate [R | t]. This requires:
597 |                 R_world = R_object.T
598 |                 t_world = -R_object.T @ t_object
599 |
600 |             where R_object = R_matlab.T and t_object = t_matlab.T, hence R_world = R_matlab and t_world = -R_matlab @ t_matlab.T.
601 |             """
602 |             return R, -R @ t
603 |         return [Camera(
604 |             Transformation(*matlab2ours(R[...,i], t[...,i])),
605 |             filmsize[i], f[i], c[i], k[i], p[i], self.device
606 |         ) for i in range(len(f))]
607 |
608 |     def _init_screen(self, mat):
609 |         pixelsize = 1e-3 * mat['display_pixel_size'][0][0] # in [mm]
610 |         size = pixelsize * np.array([1600, 2560]) # in [mm] (MacBook Pro 13.3")
611 |         im = imread('./imgs/checkerboard.png')
612 |         im = np.mean(im, axis=-1) # for now we use grayscale
613 |
614 |         return Screen(Transformation(np.eye(3), np.zeros(3)), size, pixelsize, im, self.device)
615 |
616 |     def _compute_mount_geometry(self, p_rotation, verbose=True):
617 |         """
618 |         Estimate the intersection of two camera rays; since they are generally skew, return the midpoint of their closest points:
619 |
620 |             L1 = o1 + t1 d1
621 |             L2 = o2 + t2 d2
622 |
623 |         where we solve for t1 and t2 by over-determined least squares
624 |         (3 equations in R^3, 2 unknowns):
625 |
626 |             min   || L1 - L2 ||^2 = || (o1-o2) + [d1,-d2] [t1;t2] ||^2
627 |            t1,t2                        [3x1]     [3x2]    [2x1]
628 |
629 |         =>
630 |
631 |             min   || o + d t ||^2
632 |              t      [3x1] [3x2] [2x1]
633 |
634 |         whose solution is t = (d.T d)^{-1} d.T (-o).
635 |         """
636 |         N = self.scene.camera_count
637 |         rays = [self.scene.cameras[i].sample_ray(
638 |             p_rotation[i][None,None,...].to(self.device), is_sampler=False) for i in range(N)]
639 |
640 |         t, r = np.linalg.lstsq(
641 |             torch.stack((rays[0].d, -rays[1].d), axis=-1).cpu().detach().numpy(),
642 |             -(rays[0].o - rays[1].o).cpu().detach().numpy(), rcond=None
643 |         )[0:2]
644 |
645 |         t_pt = torch.Tensor(t).to(self.device)
646 |         os = [rays[i](t_pt[i]) for i in range(N)]
647 |         if verbose:
648 |             for i, o in enumerate(os):
649 |                 print('intersection point {}: {}'.format(i, o))
650 |             print('|intersection points distance error| = {} mm'.format(np.sqrt(r[0])))
651 |         return torch.mean(torch.stack(os), axis=0).cpu().detach().numpy()
652 |     # =====================
--------------------------------------------------------------------------------
/imgs/checkerboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/checkerboard.png
--------------------------------------------------------------------------------
/imgs/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/loss.png
--------------------------------------------------------------------------------
/imgs/results_initial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/results_initial.png
--------------------------------------------------------------------------------
/imgs/results_optimized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/results_optimized.png
--------------------------------------------------------------------------------
/imgs/setup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/setup.png
--------------------------------------------------------------------------------
/imgs/spot_diagram_initial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/spot_diagram_initial.png
--------------------------------------------------------------------------------
/imgs/spot_diagram_optimized.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/spot_diagram_optimized.png
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffDeflectometry/c30fb27b405ed3203a218b91f9a34f3d9e999f8b/imgs/teaser.jpg
--------------------------------------------------------------------------------
/lenses/ThorLabs/LE1234-A.txt:
--------------------------------------------------------------------------------
1 | Thorlabs-LE1234
2 | type distance roc diameter material
3 | O 0 0 0 AIR
4 | S 0 -82.23 25.4 N-BK7
5 | S 3.59 -32.14 25.4 AIR
6 | I 95.9 0 25.4 AIR
7 |
--------------------------------------------------------------------------------
/metrology_calibrate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | import diffmetrology as dm
5 | from matplotlib.image import imread
6 |
7 | # load setup information
8 | data_path = './20210403'
9 | device = dm.init()
10 | # device = torch.device('cpu')
11 |
12 | print("Initialize a DiffMetrology object.")
13 | origin_shift = np.array([0.0, 0.0, 0.0])
14 | DM = dm.DiffMetrology(
15 |     calibration_path = data_path + '/calibration/',
16 |     rotation_path = data_path + '/rotation_calibration/rotation.mat',
17 |     lut_path = data_path + '/gamma_calibration/gammas.mat',
18 |     origin_shift = origin_shift,
19 |     scale=1.0,
20 |     device=device
21 | )
22 |
23 | print("Crop the region of interest in the original images.")
24 | filmsize = np.array([1024, 1024])
25 | crop_offset = ((2048 - filmsize)/2).astype(int)
26 | for cam in DM.scene.cameras:
27 |     cam.filmsize = filmsize
28 |     cam.crop_offset = torch.Tensor(crop_offset).to(device)
29 | def crop(x):
30 |     return x[..., crop_offset[0]:crop_offset[0]+filmsize[0], crop_offset[1]:crop_offset[1]+filmsize[1]]
31 |
32 |
33 | # ==== Read measurements
34 | lens_name = 'LE1234-A'
35 |
36 | # load data
37 | data = np.load(data_path + '/measurement/' + lens_name + '/data_new.npz')
38 | refs = data['refs']
39 | refs = crop(refs)
40 | del data
41 |
42 | Ts = np.array([70, 100, 110]) # periods of the sinusoidal patterns
43 | t = 0
44 |
45 | # change display pattern
46 | xs = [0]
47 | sinusoid_path = './camera_acquisitions/images/sinusoids/T=' + str(Ts[t])
48 | ims = [ np.mean(imread(sinusoid_path + '/' + str(x) + '.png'), axis=-1) for x in xs ] # for now we use grayscale
49 | ims = np.array([ im/im.max() for im in ims ])
50 | ims = np.sum(ims, axis=0)
51 | DM.set_texture(ims)
52 | ims = torch.Tensor(ims).to(device)
53 |
54 | # reference image
55 | I0 = torch.Tensor(np.array([refs[t,x,...] for x in xs])).to(device)
56 | I0 = torch.sum(I0, axis=0)
57 |
58 | # define functions
59 | def forward():
60 |     I = torch.stack(DM.render(with_element=False, angles=0.0))
61 |     return I #/ I.max() * I0.max()
62 | def loss(I):
63 |     return (I - I0).mean()
64 | def func_yref_y(I):
65 |     return I0 - I
66 |
67 | def show_img(I, string):
68 |     fig = plt.figure()
69 |     plt.imshow(I[0].cpu().detach(), vmin=0, vmax=1, cmap='gray')
70 |     plt.colorbar()
71 |     plt.title(string)
72 |     plt.axis('off')
73 |     fig.savefig("img_" + string + ".jpg", bbox_inches='tight')
74 |
75 | def show_error(I, string):
76 |     fig = plt.figure()
77 |     plt.imshow(I[0].cpu().detach() - I0[0].cpu(), vmin=-1, vmax=1, cmap='coolwarm')
78 |     plt.colorbar()
79 |     plt.title(string)
80 |     plt.axis('off')
81 |     fig.savefig("photo_" + string + ".jpg", bbox_inches='tight')
82 |
83 |
84 | # initialize parameters
85 | DM.scene.screen.texture_shift = torch.Tensor([0.0, 0.0]).to(device)
86 |
87 | # parameters
88 | diff_names = ['screen.texture_shift']
89 |
90 | # initial
91 | I = forward()
92 | show_img(I0, 'Measurement')
93 | show_img(I, 'Modeled')
94 | show_error(I, 'Initial')
95 |
96 | # optimize
97 | ls = DM.solve(diff_names, forward, loss, func_yref_y, option='LM')
98 |
99 | # plot loss
100 | plt.figure()
101 | plt.semilogy(ls, '-o', color='k')
102 | plt.xlabel('LM iteration')
103 | plt.ylabel('Loss')
104 | plt.title("Optimization Loss")
105 |
106 | I = forward()
107 | show_img(I0, 'Measurement')
108 | show_img(I, 'Modeled')
109 | show_error(I, 'Optimized')
110 |
111 | plt.show()
112 |
--------------------------------------------------------------------------------
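In the `DiffMetrology` method whose tail opens this chunk, the `psi_with_dc` / `DC_target` lines resolve an unknown constant offset of each unwrapped map by trying every candidate in `DCs` and keeping the one that minimizes the valid-masked sum of squares. `DCs` is built outside this chunk, so the sketch below assumes, purely for illustration, candidates that are integer multiples of a period `T`; `pick_dc_offset` is a hypothetical single-map helper, not part of the repository.

```python
import numpy as np

def pick_dc_offset(psi, valid, dcs):
    """Return the candidate offset dc minimizing sum over valid pixels of (psi + dc)^2.

    psi:   2-D map with an unknown constant offset
    valid: 2-D 0/1 (or bool) mask of the same shape
    dcs:   1-D array of candidate offsets (hypothetical: multiples of a period)
    """
    dcs = np.asarray(dcs, dtype=float)
    # Broadcast (H, W, 1) + (K,) -> (H, W, K); masked-out pixels contribute 0.
    shifted = valid[..., None] * (psi[..., None] + dcs)
    cost = np.sum(shifted ** 2, axis=(0, 1))  # one cost per candidate, shape (K,)
    return dcs[np.argmin(cost)]

# Made-up example: a map that is off by roughly two periods of T = 100.
T = 100.0
psi = np.full((4, 4), -203.0)
valid = np.ones((4, 4))
candidates = T * np.arange(-3, 4)
print(pick_dc_offset(psi, valid, candidates))
```

Since the cost is a parabola in `dc` with its vertex at minus the masked mean of `psi`, the search simply returns the candidate closest to that value; the brute-force form is kept because it vectorizes directly over cameras and axes, as in the original.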
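The "Find valid map centers" loop in the `DiffMetrology` class averages the `np.argwhere` coordinates of each camera's valid pixels, which yields the centroid in (row, column) order, and `np.fliplr` then swaps it to (column, row). For a single 0/1 mask the same quantity can be sketched more compactly; `mask_centroid_xy` is a hypothetical name, and `np.nonzero` matches `valid_map[i]==1` only when the mask contains just 0s and 1s.

```python
import numpy as np

def mask_centroid_xy(mask):
    """Centroid of the nonzero pixels of a 2-D mask, returned as (column, row)."""
    rows, cols = np.nonzero(mask)
    return np.array([cols.mean(), rows.mean()])

mask = np.zeros((5, 5), dtype=int)
mask[1:3, 2:5] = 1  # rows 1-2, columns 2-4
print(mask_centroid_xy(mask))
```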
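`matlab2ours` in `_init_camera` returns `R, -R @ t`. Under the row-vector convention quoted in its docstring (camera coordinates = world row vector times R, plus t), that pair is the camera-to-world transform. The sketch below checks the round trip numerically with made-up extrinsics; `matlab_to_cam2world` and `rotation_z` are hypothetical names, and other MATLAB functions or releases may store poses in a column-vector convention, so the check should be repeated against the actual calibration files.

```python
import numpy as np

def matlab_to_cam2world(R_matlab, t_matlab):
    """Sketch of matlab2ours: with x_cam = X @ R + t (row vectors), the inverse
    map is X = R @ x_cam - R @ t, i.e. a camera-to-world [R | t]."""
    R = np.asarray(R_matlab, dtype=float)
    t = np.asarray(t_matlab, dtype=float).reshape(3)
    return R, -R @ t

def rotation_z(theta):
    """Rotation about the z-axis by theta radians (helper for the check below)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

# Map a world point into the camera with the MATLAB convention,
# then back to the world with the converted transform.
R_m, t_m = rotation_z(0.3), np.array([1.0, 2.0, 3.0])
X = np.array([0.5, -1.0, 4.0])
x_cam = X @ R_m + t_m
R_w, t_w = matlab_to_cam2world(R_m, t_m)
print(np.allclose(R_w @ x_cam + t_w, X))
```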
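`_compute_mount_geometry` follows the derivation in its docstring: stack `[d1, -d2]` into a 3x2 matrix, solve for `(t1, t2)` with `np.linalg.lstsq`, evaluate both rays, and average the two points. A standalone sketch with plain NumPy vectors in place of the class's ray objects (`ray_midpoint` is a hypothetical name):

```python
import numpy as np

def ray_midpoint(o1, d1, o2, d2):
    """Least-squares 'intersection' of lines o1 + t1*d1 and o2 + t2*d2.

    Solves min_t || (o1 - o2) + [d1, -d2] t ||^2 (3 equations, 2 unknowns) and
    returns the midpoint of the two closest points and the distance between them.
    """
    o1, d1, o2, d2 = (np.asarray(v, dtype=float) for v in (o1, d1, o2, d2))
    A = np.stack((d1, -d2), axis=-1)  # 3x2
    b = -(o1 - o2)                    # 3
    (t1, t2), *_ = np.linalg.lstsq(A, b, rcond=None)
    p1, p2 = o1 + t1 * d1, o2 + t2 * d2
    return (p1 + p2) / 2, np.linalg.norm(p1 - p2)

# Two skew lines: the x-axis, and a line parallel to z through (0, 1, 1).
mid, gap = ray_midpoint([0, 0, 0], [1, 0, 0], [0, 1, 1], [0, 0, 1])
print(mid, gap)
```

For non-parallel rays, `lstsq` also returns the residual sum of squares, which equals the squared distance between the two closest points, so the `np.sqrt(r[0])` printed by the method is the same gap.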
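`compRMS` removes each height map's mean (piston) inside the valid aperture before taking the RMS of the difference, so a constant height offset between the estimate and the ground truth is not counted as error. A standalone sketch of that metric for one surface; `rms_piston_removed` is a hypothetical name, and `valid` must be a boolean array for the indexing to select pixels.

```python
import numpy as np

def rms_piston_removed(z, z_gt, valid):
    """RMS of (z - z_gt) over `valid` after removing each map's mean there."""
    x = z[valid] - z[valid].mean()
    y = z_gt[valid] - z_gt[valid].mean()
    return float(np.sqrt(np.mean((x - y) ** 2)))

z_gt = np.zeros((2, 2))
valid = np.ones((2, 2), dtype=bool)
print(rms_piston_removed(z_gt + 5.0, z_gt, valid))  # 0.0: pure piston is ignored
print(rms_piston_removed(np.array([[1.0, -1.0], [-1.0, 1.0]]), z_gt, valid))  # 1.0
```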
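`lenses/ThorLabs/LE1234-A.txt` is a small whitespace-separated table: a name line, a header (`type distance roc diameter material`), then one row per entry. The package's own reader is not in this chunk, so the sketch below only splits the rows into typed fields and deliberately does not interpret them (for example, whether `distance` is a gap or a cumulative position, the units, or what `roc = 0` means), since the source does not state this. `LensRow` and `parse_lens_table` are hypothetical names.

```python
from dataclasses import dataclass

@dataclass
class LensRow:
    kind: str        # first column: "O", "S" or "I" in LE1234-A.txt
    distance: float
    roc: float
    diameter: float
    material: str

def parse_lens_table(text):
    """Parse a prescription laid out like LE1234-A.txt: a name line,
    a header line, then whitespace-separated rows of five fields."""
    lines = [ln.split() for ln in text.strip().splitlines() if ln.strip()]
    name = lines[0][0]
    rows = [LensRow(k, float(d), float(r), float(D), m) for k, d, r, D, m in lines[2:]]
    return name, rows

LE1234 = """Thorlabs-LE1234
type distance roc diameter material
O 0 0 0 AIR
S 0 -82.23 25.4 N-BK7
S 3.59 -32.14 25.4 AIR
I 95.9 0 25.4 AIR
"""
name, rows = parse_lens_table(LE1234)
print(name, [r.kind for r in rows])  # Thorlabs-LE1234 ['O', 'S', 'S', 'I']
```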