├── documents
├── results.png
├── principles.png
└── uniform-diffuser.png
├── __pycache__
├── ASASM.cpython-37.pyc
├── LSASM.cpython-37.pyc
├── utils.cpython-37.pyc
├── ASASM.cpython-310.pyc
├── utils.cpython-310.pyc
├── input_field.cpython-310.pyc
├── input_field.cpython-37.pyc
├── phase_plates.cpython-310.pyc
└── phase_plates.cpython-37.pyc
├── .vscode
└── launch.json
├── utils.py
├── README.md
├── RS.py
├── main.py
├── phase_plates.py
├── input_field.py
├── environment.yml
└── LSASM.py
/documents/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/documents/results.png
--------------------------------------------------------------------------------
/documents/principles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/documents/principles.png
--------------------------------------------------------------------------------
/documents/uniform-diffuser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/documents/uniform-diffuser.png
--------------------------------------------------------------------------------
/__pycache__/ASASM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/ASASM.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/LSASM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/LSASM.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/ASASM.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/ASASM.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/input_field.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/input_field.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/input_field.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/input_field.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/phase_plates.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/phase_plates.cpython-310.pyc
--------------------------------------------------------------------------------
/__pycache__/phase_plates.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/phase_plates.cpython-37.pyc
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Remote Attach",
9 | "type": "python",
10 | "request": "attach",
11 | "connect": {
12 | "host": "localhost",
13 | "port": 5678
14 | },
15 | "pathMappings": [
16 | {
17 | "localRoot": "${workspaceFolder}",
18 | "remoteRoot": "."
19 | }
20 | ],
21 | "justMyCode": true
22 | }
23 | ]
24 | }
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | These are the util functions.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 |
10 | Technical Paper:
11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023)
12 | """
13 |
14 | import numpy as np
15 | from PIL import Image
16 | from matplotlib import cm
17 |
18 |
def effective_bandwidth(D, wvls=None, is_plane_wave=False, zf=None, s=1.):
    """Estimate the effective spatial bandwidth associated with a size ``D``.

    :param D: aperture (or feature) size.
    :param wvls: wavelength; required when ``is_plane_wave`` is False.
    :param is_plane_wave: if True, return the plane-wave estimate
        ``41.2 * s / D`` (callers use this as the Airy-disk term).
    :param zf: distance to the wave origin; required when ``is_plane_wave``
        is False.
    :param s: oversampling factor multiplied into the result.
    :return: ``41.2 * s / D`` for a plane wave, otherwise ``s * D / wvls / zf``.
    :raises ValueError: if a non-plane wave is requested without ``zf`` or
        ``wvls`` (explicit check instead of ``assert``, which ``-O`` strips).
    """
    if is_plane_wave:
        # NOTE(review): 41.2 is an empirical constant; its derivation is not shown here.
        return 41.2 * s / D

    if zf is None:
        raise ValueError("Wave origin should be provided!")
    if wvls is None:
        raise ValueError("Wavelength should be provided!")
    return s * D / wvls / zf
27 |
28 |
def save_image(image, save_path, cmap='gray'):
    """Normalize ``image`` by its maximum and write it as an 8-bit image.

    :param image: real-valued array (e.g. an amplitude or a wrapped phase map).
    :param save_path: output path; PIL infers the format from the extension.
    :param cmap: 'gray' (single channel, the default) or one of 'viridis',
        'twilight', 'magma', 'plasma' (RGBA output). Any other name falls
        back to gray, as before.
    """
    imarray = np.asarray(image)
    peak = imarray.max()
    # An all-zero image would otherwise divide by zero and produce NaNs.
    if peak != 0:
        imarray = imarray / peak  # 0~1 for non-negative input
    # Clip so negative values cannot wrap around in the uint8 cast below.
    imarray = np.clip(imarray, 0, 1)

    colormaps = {
        'viridis': cm.viridis,
        'twilight': cm.twilight,
        'magma': cm.magma,
        'plasma': cm.plasma,
    }
    if cmap in colormaps:
        imarray = colormaps[cmap](imarray)

    im = Image.fromarray(np.uint8(imarray * 255))
    im.save(save_path)
42 |
43 |
def remove_linear_phase(phi, thetaX, thetaY, x, y, k):
    """Strip the tilt-induced linear phase from ``phi`` for visualization.

    The phase ramp of a wave incident at (thetaX, thetaY) degrees has slopes
    ``-sin(theta) * k`` along x and y. This ramp, evaluated on the (x, y)
    grid, is subtracted from ``phi``, and the result is wrapped into [0, 2*pi).
    """
    grid_x, grid_y = np.meshgrid(x, y, indexing='xy')
    slope_x = -np.sin(thetaX / 180 * np.pi) * k
    slope_y = -np.sin(thetaY / 180 * np.pi) * k

    flattened = phi - grid_x * slope_x - grid_y * slope_y
    return np.mod(flattened, 2 * np.pi)
53 |
54 |
def snr(u_hat, u_ref):
    """Return the signal-to-noise ratio (dB) of ``u_hat`` against ``u_ref``.

    Both fields are normalized by their peak magnitude. The reference is then
    scaled by the least-squares complex factor ``alpha``. The SNR is the
    energy of ``u_hat`` divided by the residual energy
    ``|u_hat - alpha * u_ref|^2``.

    The inputs are NOT modified. The previous in-place ``/=`` rescaled the
    caller's arrays and failed on integer arrays.
    """
    u_hat = np.asarray(u_hat)
    u_ref = np.asarray(u_ref)
    u_hat = u_hat / np.abs(u_hat).max()
    u_ref = u_ref / np.abs(u_ref).max()

    signal = np.sum(np.abs(u_hat) ** 2)
    alpha = np.sum(u_hat * np.conjugate(u_ref)) / np.sum(np.abs(u_ref) ** 2)
    noise = np.sum(np.abs(u_hat - alpha * u_ref) ** 2)
    return 10 * np.log10(signal / noise)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Author Contributions
2 | To provide essential information regarding author contributions omitted from the published letter:
3 |
4 | H.W. and X.L. conceived the project and wrote the paper; X.L. constructed the storyline and led the formulation and methodology development with H.W.'s discussion and derivation.
5 |
6 | H.W. and X.L. built the code;
7 |
8 | X.L. and H.W. co-authored the manuscript with valuable feedback from Y.P., E.Y.L., and X.H.;
9 |
10 | H.W. created all figures with input from X.L.;
11 |
12 | Y.P. provided project supervision.
13 |
14 | ***For more insights beyond this work, you may correspond with [Dr. Xin Liu](https://liux2018.github.io) and refer to [On computational optics](https://github.com/LiuX2018/On-computational-optics).***
15 |
16 | # LS-ASM
17 | This repository provides the official open-source code of the following paper:
18 |
19 | **Modeling off-axis diffraction with the least-sampling angular spectrum method**\
20 | Haoyu Wei*, Xin Liu*, Xiang Hao, Edmund Y. Lam, Yifan Peng\
21 | [Paper](https://doi.org/10.1364/OPTICA.490223), [Project page](https://whywww.github.io/LSASM_page/) \
22 | Correspondence: [Dr. Peng](https://www.eee.hku.hk/~evanpeng/) and [Prof. Lam](https://www.eee.hku.hk/~elam/). For implementation and experiment details please contact Haoyu (haoyu.wei97@gmail.com).
23 |
24 |
25 |
26 | ## Quick start
27 | This repository contains implementations of the LS-ASM and Rayleigh-Sommerfeld (RS) algorithms, with spherical-wave input and thin-lens and diffuser modulations.
28 |
29 | ### Prerequisites
30 | Create a conda environment from yml file:
31 | ```
32 | conda env create -f environment.yml
33 | ```
34 | If you are running on a GPU, please install a PyTorch version that matches the CUDA version on your machine.
35 |
36 | ### Config and Run
37 | Configurations are in `main.py`.\
38 | Run the script below; results are saved to the `results` folder.
39 | ```
40 | python main.py
41 | ```
42 |
43 | ## Performance
44 | We show the LS-ASM speedup for incident angles from 0 to 20 degrees.\
45 |
46 |
47 | For the diffuser, LS-ASM results closely match the RS results.\
48 |
49 |
50 | ## Citation
51 |
52 | If you use this code and find our work valuable, please cite our paper.
53 | ```
54 | @article{Wei:23,
55 | title = {Modeling Off-Axis Diffraction with the Least-Sampling Angular Spectrum Method},
56 | author = {Haoyu Wei and Xin Liu and Xiang Hao and Edmund Y. Lam and Yifan Peng},
57 | journal = {Optica},
58 | volume = {10}, number = {7}, pages = {959--962},
59 | publisher = {Optica Publishing Group},
60 | year = {2023},
61 | month = {Jul},
62 | doi = {10.1364/OPTICA.490223}
63 | }
64 | ```
65 |
66 | ## License
67 |
68 | 
This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
69 |
--------------------------------------------------------------------------------
/RS.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the implementation of the Rayleigh-Sommerfeld algorithm. Refer to Goodman, Joseph W.
3 | Introduction to Fourier optics. Roberts and Company Publishers, 2005, for principle details.
4 | This code is adapted from a Matlab script from Xin Liu and converted into a GPU parallel
5 | -computing Python script by Haoyu Wei (haoyu.wei97@gmail.com).
6 |
7 |
8 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
9 | # The license is only for non-commercial use (commercial licenses can be obtained from authors).
10 | # The material is provided as-is, with no warranties whatsoever.
11 | # If you publish any code, data, or scientific work based on this, please cite our work.
12 |
13 |
14 | Technical Paper:
15 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis
16 | diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023)
17 | """
18 |
19 |
20 | import torch
21 | import math
22 | from tqdm import tqdm
23 |
24 |
class RSDiffraction_GPU():
    '''
    Rayleigh-Sommerfeld diffraction by direct summation over all input and
    output samples, processed in square blocks for parallel computing.

    NOTE: the summation does not multiply by the input sample area, so the
    output is correct up to a constant scale factor.
    '''
    def __init__(self, z, xvec, yvec, svec, tvec, wavelengths, device) -> None:
        '''
        x,s are horizontal. y,t are vertical.

        :param z: distance between the input and output planes.
        :param xvec, yvec: 1D input-plane coordinates (horizontal, vertical).
        :param svec, tvec: 1D output-plane coordinates (horizontal, vertical).
        :param wavelengths: wavelength; k = 2*pi / wavelengths.
        :param device: torch device on which the summation runs.
        '''

        self.device = device
        self.k = 2 * torch.pi / wavelengths
        self.z = z

        xvec, yvec = torch.tensor(xvec), torch.tensor(yvec)
        svec, tvec = torch.tensor(svec), torch.tensor(tvec)
        # 'xy' indexing: grids have shape (len(yvec), len(xvec)) and
        # (len(tvec), len(svec)), i.e. rows are vertical, columns horizontal.
        xx, yy = torch.meshgrid(xvec, yvec, indexing='xy')
        ss, tt = torch.meshgrid(svec, tvec, indexing='xy')
        self.ss, self.tt = ss.to(device), tt.to(device)
        self.xx, self.yy = xx.to(device), yy.to(device)

        self.block_sz = 100  # depends on your memory, e.g., 128 needs ~24GB GPU memory


    def __call__(self, E0):
        '''
        :param E0: input field sampled on the (yvec, xvec) grid, i.e. with
            shape (len(yvec), len(xvec)).
        :return: numpy array of the output field on the (tvec, svec) grid.
        '''

        E0 = torch.tensor(E0, dtype=torch.complex128, device=self.device)

        # Grids are (rows, cols) = (vertical, horizontal). Unpacking them as
        # (horizontal, vertical) skipped rows/columns on non-square grids.
        LY, LX = E0.shape[-2:]
        LT, LS = self.ss.shape
        bsz = self.block_sz

        Eout = []
        for bt in tqdm(range(math.ceil(LT / bsz)), desc='tvec', position=0):
            out_rows = slice(bt * bsz, (bt + 1) * bsz)
            Erow = []
            for bs in tqdm(range(math.ceil(LS / bsz)), desc='svec', position=1, leave=False):
                out_cols = slice(bs * bsz, (bs + 1) * bsz)
                ss_ = self.ss[out_rows, out_cols]
                tt_ = self.tt[out_rows, out_cols]
                block_sum = torch.zeros_like(ss_, dtype=E0.dtype)
                for by in tqdm(range(math.ceil(LY / bsz)), desc='yvec', position=2, leave=False):
                    in_rows = slice(by * bsz, (by + 1) * bsz)
                    for bx in tqdm(range(math.ceil(LX / bsz)), desc='xvec', position=3, leave=False):
                        in_cols = slice(bx * bsz, (bx + 1) * bsz)
                        E0_ = E0[in_rows, in_cols]
                        xx_ = self.xx[in_rows, in_cols]
                        yy_ = self.yy[in_rows, in_cols]
                        # Broadcast to (in_rows, in_cols, out_rows, out_cols).
                        xx_st = xx_[..., None, None]
                        yy_st = yy_[..., None, None]
                        xy_ss = ss_.expand(*xx_.shape, *ss_.shape)
                        xy_tt = tt_.expand(*xx_.shape, *tt_.shape)
                        r = torch.sqrt((xy_ss - xx_st)**2 + (xy_tt - yy_st)**2 + self.z**2)
                        # RS impulse response: z/(2*pi) * (1/r - jk) * exp(jkr) / r^2
                        h = -1 / (2 * torch.pi) * (1j * self.k - 1 / r) * torch.exp(1j * self.k * r) * self.z / r**2
                        # Sum over the input block's two axes -> (out_rows, out_cols).
                        block_sum += torch.einsum('xy, xyst', E0_, h)
                Erow.append(block_sum)
            Eout.append(torch.hstack(Erow))
        Eout = torch.vstack(Eout)

        return Eout.cpu().numpy()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | LS-ASM:
3 |
4 | This is the main script for computing the diffraction field with LS-ASM.
5 |
6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
7 | # The license is only for non-commercial use (commercial licenses can be obtained from authors).
8 | # The material is provided as-is, with no warranties whatsoever.
9 | # If you publish any code, data, or scientific work based on this, please cite our work.
10 |
11 | @article{Wei:23,
12 | title = {Modeling Off-Axis Diffraction with the Least-Sampling Angular Spectrum Method},
13 | author = {Haoyu Wei and Xin Liu and Xiang Hao and Edmund Y. Lam and Yifan Peng},
14 | journal = {Optica},
15 | volume = {10}, number = {7}, pages = {959--962},
16 | publisher = {Optica Publishing Group},
17 | year = {2023},
18 | month = {Jul},
19 | doi = {10.1364/OPTICA.490223}
20 | }
21 |
22 | -----
23 |
24 | $ python main.py
25 | """
26 |
27 |
import glob
import os
import time

import numpy as np

from input_field import InputField
from utils import save_image, remove_linear_phase, snr
33 |
34 |
35 | ############################### hyperparameters ############################
36 |
wvls = 500e-9  # wavelength of light in vacuum
k = 2 * np.pi / wvls  # wavenumber
f = 35e-3  # focal length of lens (if applicable)
z0 = 1.7  # source-aperture distance
zf = 1/(1/f - 1/z0)  # image-side focal distance (thin-lens equation)
z = zf  # aperture-sensor distance
r = f / 16 / 2  # radius of aperture
thetaX = 0  # incident angle in degree
thetaY = 5  # incident angle in degree

s_LSASM = 1.5  # oversampling factor for LSASM
s_RS = 4  # oversampling factor for Rayleigh-Sommerfeld
compensate = True  # LPC; NOTE(review): not referenced elsewhere in this script
use_LSASM = True
use_RS = False
result_folder = 'results'
RS_folder = 'RS'
calculate_SNR = False

# define observation window (square, side length l, Mx x My samples)
Mx, My = 512, 512
l = r * 0.25
# l = 0.0136/1.5 # first term in Eq6 scaled by 1/1.5 to estimate OW size, used for diffuser
# l = r * 8. # 35 degrees

# Center the window where the incident direction (-sinX, -sinY, cos) lands
# after propagating z.
sinX = np.sin(thetaX / 180 * np.pi)
sinY = np.sin(thetaY / 180 * np.pi)
cos_z = np.sqrt(1 - sinX**2 - sinY**2)
xc = - z * sinX / cos_z
yc = - z * sinY / cos_z

x = np.linspace(-l / 2 + xc, l / 2 + xc, Mx, endpoint=True)
y = np.linspace(-l / 2 + yc, l / 2 + yc, My, endpoint=True)
print(f'observation window diameter = {l}.')
67 |
if use_LSASM:
    print('----------------- Propagating with LSASM -----------------')
    # use "12" for thin lens + spherical wave
    # use "3" for diffuser
    Uin = InputField("12", wvls, (thetaX, thetaY), r, z0, f, zf, s_LSASM)

    from LSASM import LeastSamplingASM
    device = 'cuda:0'
    # device = 'cpu'
    prop2 = LeastSamplingASM(Uin, x, y, z, device)
    # NOTE: file names encode thetaX but not thetaY.
    path = f'{result_folder}/LSASM({len(Uin.xi)},{len(prop2.fx)})-{thetaX}-{s_LSASM:.2f}'
    # The output folder is not part of the repository; create it so saving works.
    os.makedirs(result_folder, exist_ok=True)

    start = time.time()
    U2 = prop2(Uin.E0)
    end = time.time()
    runtime = end - start
    print(f'Time elapsed for LSASM: {runtime:.2f}')

    save_image(abs(U2), f'{path}.png', cmap='gray')
    phase = remove_linear_phase(np.angle(U2), thetaX, thetaY, x, y, k)  # for visualization
    save_image(phase, f'{path}-Phi.png', cmap='twilight')

    if calculate_SNR:
        # Reference fields are written by the use_RS branch (possibly in an earlier run).
        # Sorted so the choice is deterministic when several files match.
        rs_matches = sorted(glob.glob(f'{RS_folder}/RS*-{thetaX}-{s_RS:.1f}.npy'))
        if rs_matches:
            u_GT = np.load(rs_matches[0])
            print(f'SNR is {snr(U2, u_GT):.2f}')
        else:
            print(f'No RS reference found in {RS_folder}; skipping SNR.')


if use_RS:
    print('-------------- Propagating with RS integral --------------')
    Uin = InputField("12", wvls, (thetaX, thetaY), r, z0, f, zf, s_RS)

    from RS import RSDiffraction_GPU
    prop = RSDiffraction_GPU(z, Uin.xi, Uin.eta, x, y, wvls, 'cuda:0')
    path = f'{RS_folder}/RS({len(Uin.xi)})-{thetaX}-{s_RS:.1f}'
    # The output folder is not part of the repository; create it so saving works.
    os.makedirs(RS_folder, exist_ok=True)
    start = time.time()
    U0 = prop(Uin.E0)
    end = time.time()
    print(f'Time elapsed for RS: {end-start:.2f}')
    save_image(abs(U0), f'{path}.png', cmap='gray')
    phase = remove_linear_phase(np.angle(U0), thetaX, thetaY, x, y, k)  # for visualization
    save_image(phase, f'{path}-Phi.png', cmap='twilight')
    np.save(f'{path}', U0)  # np.save appends '.npy', matching the SNR lookup above
--------------------------------------------------------------------------------
/phase_plates.py:
--------------------------------------------------------------------------------
1 | """
2 | This script includes the modulations and components of the input field.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 |
10 | Technical Paper:
11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023)
12 | """
13 |
14 |
15 | import numpy as np
16 | from utils import effective_bandwidth
17 | import cv2
18 |
19 |
class SphericalWave():
    '''
    Spherical wave from a point source at lateral offset (x0, y0) and axial
    distance z0 from the aperture plane. It includes linear phase
    compensation (LPC) that removes the carrier frequency of the incident angles.
    '''
    def __init__(self, k, x0, y0, z0, angles, zf) -> None:
        '''
        :param k: wavenumber.
        :param x0, y0, z0: source position relative to the aperture.
        :param angles: (thetaX, thetaY) incident angles in degrees.
        :param zf: stored as ``self.zf``; not used by this class' methods.
        '''
        angle_x, angle_y = angles
        self.k = k
        self.x0, self.y0, self.z0 = x0, y0, z0
        self.zf = zf
        # carrier (center) spatial frequencies implied by the incident angles
        self.fcX, self.fcY = [-np.sin(angle / 180 * np.pi) * k / (2 * np.pi)
                              for angle in (angle_x, angle_y)]

    def forward(self, E0, xi_, eta_):
        '''
        Multiply E0 by the spherical wave exp(j*k*R)/R sampled at (xi_, eta_).
        Also apply the LPC factor exp(j*2*pi*(-fcX*xi_ - fcY*eta_)).
        '''
        dist = np.sqrt(self.z0**2 + (xi_ - self.x0)**2 + (eta_ - self.y0)**2)
        wave = (1 / dist) * np.exp(1j * (self.k * dist))
        lpc = np.exp(1j * 2 * np.pi * (-self.fcX * xi_ - self.fcY * eta_))
        return E0 * (wave * lpc)

    def phase_gradient(self, xi, eta):
        '''
        Return the (x, y) gradient of the compensated phase at (xi, eta).
        It equals the spherical term k*(xi - x0)/R minus the LPC slope 2*pi*fc.
        '''
        dist = np.sqrt((xi - self.x0)**2 + (eta - self.y0)**2 + self.z0**2)
        return (self.k * (xi - self.x0) / dist - 2 * np.pi * self.fcX,
                self.k * (eta - self.y0) / dist - 2 * np.pi * self.fcY)
62 |
63 |
class PlaneWave():
    '''
    Tilted plane wave travelling along (-x0, -y0, z0), i.e. from a source at
    lateral offset (x0, y0) toward the aperture center.
    No linear phase compensation is applied (fcX = fcY = 0), so the full tilt
    remains in the field.
    '''
    def __init__(self, k, r, x0, y0, z0) -> None:

        self.k = k
        self.r = r
        self.fcX = self.fcY = 0
        self.x0, self.y0, self.z0 = x0, y0, z0

    def _direction(self):
        '''Return the unit direction cosines (kx, ky, kz) of the wave vector.'''
        vec = np.array([-self.x0, -self.y0, self.z0])
        return vec / np.sqrt(np.dot(vec, vec))

    def forward(self, E0, xi_, eta_):
        '''Multiply E0 by exp(j*k*(kx*xi_ + ky*eta_ + kz)).'''
        kx, ky, kz = self._direction()
        phase = self.k * (kx * xi_ + ky * eta_ + kz)

        return E0 * np.exp(1j * phase)

    def phase_gradient(self, xi, eta):
        '''
        Return the phase gradient, which is constant because the phase is linear.

        No LPC removes this tilt, so the spatial sampling must resolve it.
        The previous (0, 0) under-sampled any tilted plane wave; on-axis waves
        still return (0, 0).
        '''
        kx, ky, _ = self._direction()
        return self.k * kx, self.k * ky
85 |
86 |
class ThinLens():
    '''
    Ideal thin lens of focal length f.
    It applies the quadratic phase -k/(2f) * (xi^2 + eta^2) with no LPC.
    '''
    def __init__(self, k, f) -> None:
        self.k, self.f = k, f
        # a centered lens introduces no carrier frequency
        self.fcX = 0
        self.fcY = 0

    def forward(self, E0, xi_, eta_):
        '''Multiply E0 by the lens transmission sampled at (xi_, eta_).'''
        lens_phase = self.k / 2 * (-1 / self.f) * (xi_**2 + eta_**2)
        return E0 * np.exp(1j * lens_phase)

    def phase_gradient(self, xi, eta):
        '''Return the lens phase gradient (-k/f * xi, -k/f * eta).'''
        slope = -self.k / self.f
        return slope * xi, slope * eta
108 |
109 |
class Diffuser():
    def __init__(self, r, interpolation='nearest', rand_phase=True, rand_amp=False) -> None:
        '''
        Random diffuser plate of N x N cells with pitch r/10, seeded with 0.

        Two types of diffusers: 'nearest' or 'linear' interpolated.

        :param r: aperture radius.
        :param interpolation: 'nearest' or 'linear' resampling of the plate.
        :param rand_phase: apply a random phase in [0, 4*pi).
        :param rand_amp: apply a random amplitude in [0, 1).
        '''

        self.fcX = self.fcY = 0
        self.pitch = r / 10
        # round() guards against 2r/pitch evaluating just below 20 (e.g. 19.999...).
        self.N = int(round(r * 2 / self.pitch))
        # Local generator: same values as np.random.seed(0); np.random.rand(N, N),
        # but without reseeding NumPy's global RNG as a side effect.
        self.plate = np.random.RandomState(0).rand(self.N, self.N)
        self.interp = interpolation
        self.rand_phase = rand_phase
        self.rand_amp = rand_amp

    def forward(self, E0, xi_, eta_):
        '''Resample the plate onto the (xi_, eta_) grid and apply it to E0.'''

        if self.interp == 'nearest':
            flag = cv2.INTER_NEAREST
        elif self.interp == 'linear':
            flag = cv2.INTER_LINEAR
        else:
            raise NotImplementedError

        rows, cols = xi_.shape
        # cv2.resize takes dsize as (width, height) = (cols, rows). Passing
        # xi_.shape directly transposed the result on non-square grids.
        plate_sample = cv2.resize(self.plate, (cols, rows), interpolation=flag)

        amp = np.ones_like(plate_sample)
        phase = np.zeros_like(plate_sample)
        if self.rand_phase:
            phase = plate_sample * 4 * np.pi  # random phase from 0 to 4pi
        if self.rand_amp:
            amp = plate_sample  # random amplitude from 0 to 1

        return E0 * amp * np.exp(1j * phase)

    def phase_gradient(self, xi=None, eta=None):
        '''
        :param xi, eta: ignored; accepted so the signature matches the other
            components (the bound does not depend on position).
        :return: maximum phase gradient, identical for both axes
        '''

        # Second term in Eq. 4
        grad_max = effective_bandwidth(self.pitch, is_plane_wave = True)

        # nearest interpolation does not have phase gradient
        # but linear interpolation does
        if self.interp == 'linear':
            if self.rand_phase:
                grad_max += 4 / self.pitch
            if self.rand_amp:
                grad_max += 1 / self.pitch / np.pi

        return grad_max, grad_max
--------------------------------------------------------------------------------
/input_field.py:
--------------------------------------------------------------------------------
1 | """
2 | This module assembles the input field from its wave components and determines its spatial sampling.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 |
10 | Technical Paper:
11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023)
12 | """
13 |
14 |
15 | import numpy as np
16 | from phase_plates import SphericalWave, PlaneWave, ThinLens, Diffuser
17 | from operator import add
18 | from utils import effective_bandwidth
19 |
20 |
class InputField():
    '''
    Prepare compensated input field and spatial sampling
    '''
    def __init__(self, type:str, wvls:float, angles, r:float,
                 z0=None, f=None, zf=None, s=1.5) -> None:
        '''
        :param type: string of component codes: '0' plane wave, '1' spherical
            wave, '2' thin lens, '3' random diffuser (e.g. "12"). Components
            are applied in the order 0, 1, 2, 3 regardless of string order.
        :param wvls: wavelength of light in vacuum.
        :param angles: (thetaX, thetaY) incident angles in degrees.
        :param r: aperture radius.
        :param z0: source-aperture distance (used to place the source).
        :param f: focal length, used by the thin lens.
        :param zf: forwarded to the spherical wave and stored as ``self.zf``.
        :param s: oversampling factor.
        '''
        self.wvls = wvls  # wavelength of light in vacuum
        self.k = 2 * np.pi / self.wvls  # wavenumber
        thetaX, thetaY = angles
        sinX = np.sin(thetaX / 180 * np.pi)
        sinY = np.sin(thetaY / 180 * np.pi)

        # source position for the incident wave
        r0 = z0 / np.sqrt(1 - sinX**2 - sinY**2)
        x0, y0 = r0 * sinX, r0 * sinY
        print(f'aperture diameter = {2 * r}, offset = {x0:.4f}, theta = {thetaX}.')

        # (code, message, constructor) in the fixed application order
        builders = (
            ("0", '\t Plane wave', lambda: PlaneWave(self.k, r, x0, y0, z0)),
            ("1", '\t Spherical wave', lambda: SphericalWave(self.k, x0, y0, z0, angles, zf)),
            ("2", '\t Convex lens', lambda: ThinLens(self.k, f)),
            ("3", '\t Random diffuser',
             lambda: Diffuser(r, interpolation='linear', rand_phase=True, rand_amp=True)),
        )
        requested = set(type)
        wavelist = []
        fcX = fcY = 0  # accumulated frequency centers

        print('Input field contains:')
        for code, label, build in builders:
            if code not in requested:
                continue
            print(label)
            component = build()
            fcX += component.fcX
            fcY += component.fcY
            wavelist.append(component)

        Nx, Ny, fbX, fbY = self.spatial_sampling(r, s, wavelist)

        # apply every component, in order, to the circular pupil
        self.set_input_plane(r, Nx, Ny)
        field = self.pupil
        for component in wavelist:
            field = component.forward(field, self.xi_, self.eta_)

        self.fcX, self.fcY = fcX, fcY
        self.fbX, self.fbY = fbX, fbY
        self.E0 = field
        self.s = s
        self.zf = zf
        self.D = 2 * r
        self.type = type

    def spatial_sampling(self, r, s, wavelist):
        '''
        :param r: aperture radius
        :param s: oversampling factor
        :param wavelist: a list of input wave components
        :return: number of samples in both dimensions, bandwidths in both dimensions
        '''
        # Second term in Eq4, the size of Airy disk
        airy = effective_bandwidth(2 * r, is_plane_wave=True)

        # First term in Eq4: phase terms are monotonic here, so the maximum
        # gradient is found at the aperture corners (-r, -r) and (r, r).
        lo_x = lo_y = 0
        hi_x = hi_y = 0
        diffuser_band = None
        for wave in wavelist:
            if isinstance(wave, Diffuser):
                gx, gy = wave.phase_gradient()
                diffuser_band = (gx * s, gy * s)
                continue
            gx, gy = wave.phase_gradient(-r, -r)
            lo_x, lo_y = lo_x + gx, lo_y + gy
            gx, gy = wave.phase_gradient(r, r)
            hi_x, hi_y = hi_x + gx, hi_y + gy

        fbX = (max(abs(lo_x), abs(hi_x)) / np.pi + airy) * s
        fbY = (max(abs(lo_y), abs(hi_y)) / np.pi + airy) * s
        if diffuser_band is not None:
            fbX = max(fbX, diffuser_band[0])
            fbY = max(fbY, diffuser_band[1])

        Nx, Ny = [int(np.ceil(fb * 2 * r)) for fb in (fbX, fbY)]
        print(f'spatial sampling number = {Nx, Ny}.')

        span = 2 * r
        return Nx, Ny, (Nx - 1) / span, (Ny - 1) / span

    def set_input_plane(self, r, Nx, Ny):
        '''Sample the aperture [-r, r]^2 on an Ny x Nx grid and build a circular pupil.'''
        self.xi = np.linspace(-r, r, Nx, endpoint=True)
        self.eta = np.linspace(-r, r, Ny, endpoint=True)
        self.xi_, self.eta_ = np.meshgrid(self.xi, self.eta, indexing='xy')

        # 1 inside the circular aperture, 0 outside
        inside = self.xi_**2 + self.eta_**2 <= r**2
        self.pupil = np.where(inside, 1, 0)
148 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: lsasm
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - anaconda
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - backcall=0.2.0=pyhd3eb1b0_0
11 | - blas=1.0=mkl
12 | - blosc=1.21.0=h4ff587b_1
13 | - bottleneck=1.3.5=py37h7deecbd_0
14 | - brotli=1.0.9=h5eee18b_7
15 | - brotli-bin=1.0.9=h5eee18b_7
16 | - brunsli=0.1=h2531618_0
17 | - bzip2=1.0.8=h7b6447c_0
18 | - c-ares=1.18.1=h7f8727e_0
19 | - ca-certificates=2023.01.10=h06a4308_0
20 | - cairo=1.16.0=hf32fb01_1
21 | - certifi=2022.12.7=py37h06a4308_0
22 | - cfitsio=3.470=h5893167_7
23 | - charls=2.2.0=h2531618_0
24 | - cloudpickle=2.0.0=pyhd3eb1b0_0
25 | - colorama=0.4.5=py37h06a4308_0
26 | - cudatoolkit=11.3.1=h9edb442_10
27 | - cycler=0.11.0=pyhd3eb1b0_0
28 | - cytoolz=0.12.0=py37h5eee18b_0
29 | - dask-core=2021.10.0=pyhd3eb1b0_0
30 | - dbus=1.13.18=hb2f20db_0
31 | - debugpy=1.5.1=py37h295c915_0
32 | - decorator=5.1.1=pyhd3eb1b0_0
33 | - entrypoints=0.4=py37h06a4308_0
34 | - expat=2.4.9=h6a678d5_0
35 | - ffmpeg=4.3=hf484d3e_0
36 | - flit-core=3.6.0=pyhd3eb1b0_0
37 | - fontconfig=2.13.1=h6c09931_0
38 | - freetype=2.12.1=h4a9f257_0
39 | - fsspec=2022.11.0=py37h06a4308_0
40 | - gettext=0.21.0=hf68c758_0
41 | - giflib=5.2.1=h7b6447c_0
42 | - glib=2.68.4=h9c3ff4c_0
43 | - glib-tools=2.68.4=h9c3ff4c_0
44 | - gmp=6.2.1=h295c915_3
45 | - gnutls=3.6.15=he1e5248_0
46 | - graphite2=1.3.14=h295c915_1
47 | - gst-plugins-base=1.14.5=h0935bb2_2
48 | - gstreamer=1.18.5=h76c114f_0
49 | - harfbuzz=2.7.2=ha5b49bf_1
50 | - hdf5=1.10.6=h3ffc7dd_1
51 | - icu=67.1=he1b5a44_0
52 | - imagecodecs=2021.8.26=py37hf0132c2_1
53 | - imageio=2.19.3=py37h06a4308_0
54 | - intel-openmp=2021.4.0=h06a4308_3561
55 | - ipykernel=6.15.2=py37h06a4308_0
56 | - ipython=7.31.1=py37h06a4308_1
57 | - jasper=1.900.1=hd497a04_4
58 | - jedi=0.18.1=py37h06a4308_1
59 | - jpeg=9e=h7f8727e_0
60 | - jupyter_client=7.4.7=py37h06a4308_0
61 | - jupyter_core=4.11.2=py37h06a4308_0
62 | - jxrlib=1.1=h7b6447c_2
63 | - kiwisolver=1.4.2=py37h295c915_0
64 | - krb5=1.19.2=hac12032_0
65 | - lame=3.100=h7b6447c_0
66 | - lcms2=2.12=h3be6417_0
67 | - ld_impl_linux-64=2.38=h1181459_1
68 | - lerc=3.0=h295c915_0
69 | - libaec=1.0.4=he6710b0_1
70 | - libblas=3.9.0=12_linux64_mkl
71 | - libbrotlicommon=1.0.9=h5eee18b_7
72 | - libbrotlidec=1.0.9=h5eee18b_7
73 | - libbrotlienc=1.0.9=h5eee18b_7
74 | - libcblas=3.9.0=12_linux64_mkl
75 | - libclang=11.1.0=default_ha53f305_1
76 | - libcurl=7.87.0=h91b91d3_0
77 | - libdeflate=1.8=h7f8727e_5
78 | - libedit=3.1.20221030=h5eee18b_0
79 | - libev=4.33=h7f8727e_1
80 | - libevent=2.1.10=h9b69904_4
81 | - libffi=3.3=he6710b0_2
82 | - libgcc-ng=11.2.0=h1234567_1
83 | - libgfortran-ng=11.2.0=h00389a5_1
84 | - libgfortran5=11.2.0=h1234567_1
85 | - libglib=2.68.4=h3e27bee_0
86 | - libgomp=11.2.0=h1234567_1
87 | - libiconv=1.16=h7f8727e_2
88 | - libidn2=2.3.2=h7f8727e_0
89 | - liblapack=3.9.0=12_linux64_mkl
90 | - liblapacke=3.9.0=12_linux64_mkl
91 | - libllvm11=11.1.0=h9e868ea_6
92 | - libnghttp2=1.46.0=hce63b2e_0
93 | - libopencv=4.4.0=py37_2
94 | - libpng=1.6.37=hbc83047_0
95 | - libpq=12.9=h16c4e8d_3
96 | - libsodium=1.0.18=h7b6447c_0
97 | - libssh2=1.10.0=h8f2d780_0
98 | - libstdcxx-ng=11.2.0=h1234567_1
99 | - libtasn1=4.16.0=h27cfd23_0
100 | - libtiff=4.4.0=hecacb30_2
101 | - libunistring=0.9.10=h27cfd23_0
102 | - libuuid=1.41.5=h5eee18b_0
103 | - libuv=1.40.0=h7b6447c_0
104 | - libwebp=1.2.4=h11a3e52_0
105 | - libwebp-base=1.2.4=h5eee18b_0
106 | - libxcb=1.15=h7f8727e_0
107 | - libxkbcommon=1.0.3=he3ba5ed_0
108 | - libxml2=2.9.10=h68273f3_2
109 | - libzopfli=1.0.3=he6710b0_0
110 | - locket=1.0.0=py37h06a4308_0
111 | - lz4-c=1.9.4=h6a678d5_0
112 | - matplotlib=3.2.2=1
113 | - matplotlib-base=3.2.2=py37h1d35a4c_1
114 | - matplotlib-inline=0.1.6=py37h06a4308_0
115 | - mkl=2021.4.0=h06a4308_640
116 | - mkl-service=2.4.0=py37h7f8727e_0
117 | - mkl_fft=1.3.1=py37hd3c417c_0
118 | - mkl_random=1.2.2=py37h51133e4_0
119 | - mysql-common=8.0.25=ha770c72_2
120 | - mysql-libs=8.0.25=hfa10184_2
121 | - ncurses=6.3=h5eee18b_3
122 | - nest-asyncio=1.5.5=py37h06a4308_0
123 | - nettle=3.7.3=hbbd107a_1
124 | - networkx=2.6.3=pyhd3eb1b0_0
125 | - nspr=4.33=h295c915_0
126 | - nss=3.74=h0370c37_0
127 | - numexpr=2.8.3=py37h807cd23_0
128 | - numpy=1.21.5=py37h6c91a56_3
129 | - numpy-base=1.21.5=py37ha15fc14_3
130 | - opencv=4.4.0=py37_2
131 | - openh264=2.1.1=h4ff587b_0
132 | - openjpeg=2.4.0=h3ad879b_0
133 | - openssl=1.1.1s=h7f8727e_0
134 | - packaging=21.3=pyhd3eb1b0_0
135 | - pandas=1.3.5=py37h8c16a72_0
136 | - parso=0.8.3=pyhd3eb1b0_0
137 | - partd=1.2.0=pyhd3eb1b0_1
138 | - patsy=0.5.3=pyhd8ed1ab_0
139 | - pcre=8.45=h295c915_0
140 | - pexpect=4.8.0=pyhd3eb1b0_3
141 | - pickleshare=0.7.5=pyhd3eb1b0_1003
142 | - pillow=9.3.0=py37hace64e9_0
143 | - pip=22.3.1=py37h06a4308_0
144 | - pixman=0.40.0=h7f8727e_1
145 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
146 | - psutil=5.9.0=py37h5eee18b_0
147 | - ptyprocess=0.7.0=pyhd3eb1b0_2
148 | - py-opencv=4.4.0=py37h43977f1_2
149 | - pygments=2.11.2=pyhd3eb1b0_0
150 | - pyparsing=3.0.9=py37h06a4308_0
151 | - python=3.7.15=haa1d7c7_0
152 | - python-dateutil=2.8.2=pyhd3eb1b0_0
153 | - python_abi=3.7=2_cp37m
154 | - pytorch=1.10.0=py3.7_cuda11.3_cudnn8.2.0_0
155 | - pytorch-mutex=1.0=cuda
156 | - pytz=2022.1=py37h06a4308_0
157 | - pywavelets=1.3.0=py37h7f8727e_0
158 | - pyyaml=6.0=py37h5eee18b_1
159 | - pyzmq=23.2.0=py37h6a678d5_0
160 | - qt=5.12.9=h763d07f_1
161 | - readline=8.2=h5eee18b_0
162 | - scikit-image=0.19.3=py37h6a678d5_1
163 | - scipy=1.7.3=py37hf2a6cf1_0
164 | - seaborn=0.12.2=hd8ed1ab_0
165 | - seaborn-base=0.12.2=pyhd8ed1ab_0
166 | - setuptools=65.5.0=py37h06a4308_0
167 | - six=1.16.0=pyhd3eb1b0_1
168 | - snappy=1.1.9=h295c915_0
169 | - sqlite=3.40.0=h5082296_0
170 | - statsmodels=0.13.2=py37h7f8727e_0
171 | - tifffile=2021.7.2=pyhd3eb1b0_2
172 | - tk=8.6.12=h1ccaba5_0
173 | - tmux=3.2a=h385fc29_0
174 | - toolz=0.12.0=py37h06a4308_0
175 | - torchaudio=0.10.0=py37_cu113
176 | - torchvision=0.11.0=py37_cu113
177 | - tornado=6.2=py37h5eee18b_0
178 | - tqdm=4.64.1=pyhd8ed1ab_0
179 | - traitlets=5.7.1=py37h06a4308_0
180 | - typing_extensions=4.4.0=py37h06a4308_0
181 | - wcwidth=0.2.5=pyhd3eb1b0_0
182 | - wheel=0.37.1=pyhd3eb1b0_0
183 | - xz=5.2.8=h5eee18b_0
184 | - yaml=0.2.5=h7b6447c_0
185 | - zeromq=4.3.4=h2531618_0
186 | - zfp=0.5.5=h295c915_6
187 | - zlib=1.2.13=h5eee18b_0
188 | - zstd=1.5.2=ha4553b6_0
189 | - pip:
190 | - data==0.4
191 | - finufft==2.1.0
192 | - funcsigs==1.0.2
193 | - future==0.18.2
194 | - latex==0.7.0
195 | - shutilwhich==1.1.0
196 | - tempdir==0.7.1
197 | prefix: /home/hywei/anaconda3/envs/sampling
198 |
--------------------------------------------------------------------------------
/LSASM.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the implementation of the algorithm LS-ASM.
3 |
4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors).
6 | # The material is provided as-is, with no warranties whatsoever.
7 | # If you publish any code, data, or scientific work based on this, please cite our work.
8 |
9 |
10 | Technical Paper:
11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023)
12 | """
13 |
14 | import torch
15 | import math
16 |
17 |
def mdft(in_matrix, x, y, fx, fy):
    """Forward matrix Fourier transform of a sampled 2-D field.

    Evaluates sum_{y, x} in_matrix[y, x] * exp(-2*pi*i*(x*fx + y*fy)) * dx * dy
    at arbitrary frequencies using two matrix products, so the frequency grid
    does not have to be the FFT grid.

    :param in_matrix: field samples; rows are indexed by ``y`` and columns by ``x``
    :param x, y: spatial sample coordinates (presumably 1-D tensors — confirm at call sites)
    :param fx, fy: frequencies at which the transform is evaluated; extra leading
        dimensions broadcast through ``torch.matmul``
    :return: spectrum with rows indexed by ``fy`` and columns by ``fx``, scaled by dx * dy
    """
    phase = -2 * torch.pi * 1j
    # (x, fx) kernel applied from the right, (fy, y) kernel applied from the left
    kernel_x = torch.exp(phase * torch.matmul(x.unsqueeze(-1), fx.unsqueeze(-2)))
    kernel_y = torch.exp(phase * torch.matmul(fy.unsqueeze(-1), y.unsqueeze(-2)))
    spectrum = torch.matmul(torch.matmul(kernel_y, in_matrix), kernel_x)

    def _step(coords):
        # Spacing inferred from the end points; a single sample gets unit weight.
        count = torch.numel(coords)
        if count == 1:
            return 1
        flat = torch.squeeze(coords)
        return (flat[-1] - flat[0]) / (count - 1)

    # Rectangle-rule weight: the result is only valid for uniform sampling.
    return spectrum * _step(x) * _step(y)
41 |
42 |
def midft(in_matrix, x, y, fx, fy):
    """Inverse matrix Fourier transform of a sampled 2-D spectrum.

    Evaluates sum_{fy, fx} in_matrix[fy, fx] * exp(2*pi*i*(fx*x + fy*y)) * dfx * dfy
    at arbitrary spatial coordinates using two matrix products.

    :param in_matrix: spectrum samples; rows are indexed by ``fy`` and columns by ``fx``
    :param x, y: spatial coordinates at which the field is evaluated
    :param fx, fy: frequency sample coordinates; extra leading dimensions
        broadcast through ``torch.matmul``
    :return: field with rows indexed by ``y`` and columns by ``x``, scaled by dfx * dfy
    """
    phase = 2 * torch.pi * 1j
    # (fx, x) kernel applied from the right, (y, fy) kernel applied from the left
    kernel_x = torch.exp(phase * torch.matmul(fx.unsqueeze(-1), x.unsqueeze(-2)))
    kernel_y = torch.exp(phase * torch.matmul(y.unsqueeze(-1), fy.unsqueeze(-2)))
    field = torch.matmul(torch.matmul(kernel_y, in_matrix), kernel_x)

    def _step(coords):
        # Spacing inferred from the end points; a single sample gets unit weight.
        count = torch.numel(coords)
        if count == 1:
            return 1
        flat = torch.squeeze(coords)
        return (flat[-1] - flat[0]) / (count - 1)

    # Rectangle-rule weight: the result is only valid for uniform sampling.
    return field * _step(fx) * _step(fy)
66 |
67 |
class LeastSamplingASM():
    '''
    Least-sampling angular spectrum method (LS-ASM) propagator.

    __init__ chooses the frequency sampling from the input field's bandwidth,
    its off-axis centre frequency and the observation window. It also
    precomputes the shifted transfer function ``self.H``. Calling the
    instance propagates a sampled input field to the destination coordinates.
    '''

    def __init__(self, Uin, xvec, yvec, z, device):
        '''
        :param Uin: input field object; the attributes read here are xi, eta,
            wvls, fbX, fbY, fcX, fcY, type, zf (only when type == "12"), D and s
        :param xvec, yvec: vectors of destination coordinates
        :param z: propagation distance
        :param device: torch device on which all tensors are created
        '''

        super().__init__()

        complex_dtype = torch.complex128

        xivec, etavec = torch.as_tensor(Uin.xi, device=device), torch.as_tensor(Uin.eta, device=device)
        xvec, yvec = torch.as_tensor(xvec, device=device), torch.as_tensor(yvec, device=device)
        z = torch.as_tensor(z, device=device)
        wavelength = torch.as_tensor(Uin.wvls, device=device)

        # wavenumber k = 2*pi*n / wavelength, with n (refractive index) fixed to 1
        n = 1
        k = 2 * math.pi / wavelength * n

        # bandwidth of aperture
        Lfx = Uin.fbX
        Lfy = Uin.fbY

        # observation window centre/width and the off-axis centre frequency of the input
        xc, yc = xvec[len(xvec) // 2], yvec[len(yvec) // 2]
        wx = xvec[-1] - xvec[0]
        wy = yvec[-1] - yvec[0]
        offx = torch.as_tensor(Uin.fcX, device=device)
        offy = torch.as_tensor(Uin.fcY, device=device)

        # shifted frequencies
        fxmax = Lfx / 2 + abs(offx)
        fymax = Lfy / 2 + abs(offy)

        # drop the evanescent wave
        fxmax = torch.clamp(fxmax, -1 / wavelength, 1 / wavelength)
        fymax = torch.clamp(fymax, -1 / wavelength, 1 / wavelength)
        if 1 - (wavelength * fxmax)**2 - (wavelength * fymax) ** 2 <= 0:
            # The band corner lies outside the propagating disc |f| < 1 / wavelength, so part
            # of the spectrum is evanescent and lost. Pull the corner back to just inside
            # the circle along the same direction, then shrink the bandwidth to match.
            eps = 1e-9
            beta = torch.atan2(fymax, fxmax)
            fxmax = torch.clamp(fxmax, max = torch.cos(beta) / ((wavelength + eps)))
            fymax = torch.clamp(fymax, max = torch.sin(beta) / ((wavelength + eps)))
            Lfx = (fxmax - abs(offx)) * 2
            Lfy = (fymax - abs(offy)) * 2

        # combined phase gradient analysis: gradients of H at the two band corners;
        # FHc* is their midpoint divided by 2*pi
        gx1, gy1 = self.grad_H(wavelength, z, Lfx / 2 + offx, Lfy / 2 + offy)
        gx2, gy2 = self.grad_H(wavelength, z, -Lfx / 2 + offx, -Lfy / 2 + offy)
        FHcx = (gx1 + gx2) / (4 * torch.pi)
        FHcy = (gy1 + gy2) / (4 * torch.pi)

        # specify the frequency sampling for each type of input field
        if Uin.type == "12":
            hx = k * Uin.zf * wavelength**2 * Lfx / 2
            hy = k * Uin.zf * wavelength**2 * Lfy / 2
            FUHbx = abs((hx + gx1) - (-hx + gx2)) / (2 * torch.pi)
            FUHby = abs((hy + gy1) - (-hy + gy2)) / (2 * torch.pi)

            deltax = self.compute_shift_of_H(FHcx, FUHbx + 2 * Uin.D, xc, wx)
            deltay = self.compute_shift_of_H(FHcy, FUHby + 2 * Uin.D, yc, wy)
            FUHcx_shifted = FHcx + deltax
            FUHcy_shifted = FHcy + deltay

            tau_UHx = 2 * abs(FUHcx_shifted) + FUHbx + 2 * Uin.D
            tau_UHy = 2 * abs(FUHcy_shifted) + FUHby + 2 * Uin.D
        else:
            tau_UHx = tau_UHy = torch.inf

        # upper bound: extent of H's phase gradient (divided by 2*pi) over the band
        FHbx = abs(gx1 - gx2) / (2 * torch.pi)
        FHby = abs(gy1 - gy2) / (2 * torch.pi)

        deltax = self.compute_shift_of_H(FHcx, FHbx + Uin.D, xc, wx)
        deltay = self.compute_shift_of_H(FHcy, FHby + Uin.D, yc, wy)
        FHcx_shifted = FHcx + deltax
        FHcy_shifted = FHcy + deltay

        tau_fx_bound = 2 * abs(FHcx_shifted) + FHbx + Uin.D
        tau_fy_bound = 2 * abs(FHcy_shifted) + FHby + Uin.D

        # final phase gradient
        # NOTE(review): 41.2 / fbX is an undocumented constant margin; confirm its origin.
        tau_UHx = min(tau_UHx, tau_fx_bound) + 41.2 / Uin.fbX
        tau_UHy = min(tau_UHy, tau_fy_bound) + 41.2 / Uin.fbY

        dfxMax1 = 1 / tau_UHx
        dfyMax1 = 1 / tau_UHy

        # maximum sampling interval limited by the observation window (OW)
        dfxMax2 = 1 / (2 * abs(xc - deltax) + wx)
        dfyMax2 = 1 / (2 * abs(yc - deltay) + wy)

        # minimum requirements of sampling interval in k space
        dfx = min(dfxMax1, dfxMax2)
        dfy = min(dfyMax1, dfyMax2)

        LRfx = math.ceil(Lfx / dfx * Uin.s)
        LRfy = math.ceil(Lfy / dfy * Uin.s)

        dfx2 = Lfx / LRfx
        dfy2 = Lfy / LRfy

        print(f'frequency sampling number = {LRfx, LRfy}, bandwidth = {Lfx:.2f}.')

        # spatial frequency coordinates
        fx = torch.linspace(-Lfx / 2, Lfx / 2 - dfx2, LRfx, device=device, dtype=complex_dtype)
        fy = torch.linspace(-Lfy / 2, Lfy / 2 - dfy2, LRfy, device=device, dtype=complex_dtype)
        fx_shift, fy_shift = fx + offx, fy + offy

        fxx, fyy = torch.meshgrid(fx_shift, fy_shift, indexing='xy')
        # Shifted H: the ASM transfer function times a linear phase exp(i*2*pi*(fx*deltax + fy*deltay)),
        # which moves the output by (deltax, deltay); self.x / self.y are shifted back below.
        self.H = torch.exp(1j * k * (wavelength * fxx * deltax + wavelength * fyy * deltay
                                     + z * torch.sqrt(1 - (fxx * wavelength)**2 - (fyy * wavelength)**2)))

        self.xi = xivec.to(dtype = complex_dtype)
        self.eta = etavec.to(dtype = complex_dtype)
        self.x = xvec.to(dtype = complex_dtype) - deltax  # shift the observation window back to origin
        self.y = yvec.to(dtype = complex_dtype) - deltay
        self.offx, self.offy = offx, offy
        self.device = device
        self.fx = fx_shift
        self.fy = fy_shift
        self.fbX = Uin.fbX
        self.fbY = Uin.fbY

    def __call__(self, E0):
        '''
        Propagate an input field.

        :param E0: input field sampled at (self.xi, self.eta); converted to complex128 on self.device
        :return: propagated field as a NumPy array (leading batch dimension removed)
        '''

        E0 = torch.as_tensor(E0, dtype=torch.complex128, device=self.device)

        fx = self.fx.unsqueeze(0)
        fy = self.fy.unsqueeze(0)

        # spectrum of the input, evaluated relative to its centre frequency (offx, offy)
        Fu = mdft(E0, self.xi, self.eta, fx - self.offx, fy - self.offy)

        Eout = midft(Fu * self.H, self.x, self.y, fx, fy)
        # The output is intentionally not normalized (original note: not needed when using MTP).

        return Eout[0].cpu().numpy()

    def grad_H(self, lam, z, fx, fy):
        '''
        Gradient of the ASM transfer-function phase k*z*sqrt(1 - (lam*fx)**2 - (lam*fy)**2),
        with k = 2*pi/lam, with respect to (fx, fy).

        The radicand is clamped to a small positive value. As a result, frequencies on or
        beyond the evanescent cutoff produce a large but finite gradient instead of inf or NaN.

        :return: (gradx, grady)
        '''

        eps = 1e-9
        radicand = torch.as_tensor(1 - (lam * fx)**2 - (lam * fy)**2)
        root = torch.sqrt(torch.clamp(radicand, min=eps))
        gradx = - z * 2 * torch.pi * lam * fx / root
        grady = - z * 2 * torch.pi * lam * fy / root
        return gradx, grady

    def compute_shift_of_H(self, C1, C2, pc, w):
        '''
        Compute the spatial shift delta that is added to the transfer-function centre
        (FHc + delta) and subtracted from the output coordinates.

        :param C1: centre of H's phase gradient divided by 2*pi (FHcx / FHcy in __init__)
        :param C2: extent term (bandwidth plus D, as passed from __init__)
        :param pc: centre of the observation window
        :param w: width of the observation window
        :return: delta; piecewise linear and continuous in w, with breakpoints at
            2*C1 + 2*pc + C2 and -2*C1 - 2*pc + C2 (breakpoints are included)
        :raises ValueError: if the inputs are NaN, so that no comparison holds
        '''

        bound_a = 2 * C1 + 2 * pc + C2
        bound_b = -2 * C1 - 2 * pc + C2

        # The neighbouring pieces agree at the breakpoints, so including them is safe.
        # Strict comparisons left w == breakpoint unhandled (UnboundLocalError).
        if bound_b <= w <= bound_a:
            delta = pc / 2 + w / 4 - C1 / 2 - C2 / 4
        elif bound_a <= w <= bound_b:
            delta = pc / 2 - w / 4 - C1 / 2 + C2 / 4
        elif w > bound_a and w > bound_b:
            delta = pc
        elif w < bound_a and w < bound_b:
            delta = -C1
        else:
            raise ValueError('compute_shift_of_H received NaN input')

        return delta
--------------------------------------------------------------------------------