├── documents ├── results.png ├── principles.png └── uniform-diffuser.png ├── __pycache__ ├── ASASM.cpython-37.pyc ├── LSASM.cpython-37.pyc ├── utils.cpython-37.pyc ├── ASASM.cpython-310.pyc ├── utils.cpython-310.pyc ├── input_field.cpython-310.pyc ├── input_field.cpython-37.pyc ├── phase_plates.cpython-310.pyc └── phase_plates.cpython-37.pyc ├── .vscode └── launch.json ├── utils.py ├── README.md ├── RS.py ├── main.py ├── phase_plates.py ├── input_field.py ├── environment.yml └── LSASM.py /documents/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/documents/results.png -------------------------------------------------------------------------------- /documents/principles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/documents/principles.png -------------------------------------------------------------------------------- /documents/uniform-diffuser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/documents/uniform-diffuser.png -------------------------------------------------------------------------------- /__pycache__/ASASM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/ASASM.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/LSASM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/LSASM.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/ASASM.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/ASASM.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/input_field.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/input_field.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/input_field.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/input_field.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/phase_plates.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/phase_plates.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/phase_plates.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whywww/ASASM/HEAD/__pycache__/phase_plates.cpython-37.pyc -------------------------------------------------------------------------------- /.vscode/launch.json: 
-------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Remote Attach", 9 | "type": "python", 10 | "request": "attach", 11 | "connect": { 12 | "host": "localhost", 13 | "port": 5678 14 | }, 15 | "pathMappings": [ 16 | { 17 | "localRoot": "${workspaceFolder}", 18 | "remoteRoot": "." 19 | } 20 | ], 21 | "justMyCode": true 22 | } 23 | ] 24 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | These are the util functions. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | 10 | Technical Paper: 11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023) 12 | """ 13 | 14 | import numpy as np 15 | from PIL import Image 16 | from matplotlib import cm 17 | 18 | 19 | def effective_bandwidth(D, wvls=None, is_plane_wave=False, zf=None, s=1.): 20 | if is_plane_wave: 21 | bandwidth = 41.2 * s / D 22 | else: 23 | assert zf is not None, "Wave origin should be provided!" 
24 | bandwidth = s * D / wvls / zf 25 | 26 | return bandwidth 27 | 28 | 29 | def save_image(image, save_path, cmap='gray'): 30 | 31 | imarray = np.array(image / image.max()) # 0~1 32 | if cmap == 'viridis': 33 | imarray = cm.viridis(imarray) 34 | elif cmap == 'twilight': 35 | imarray = cm.twilight(imarray) 36 | elif cmap == 'magma': 37 | imarray = cm.magma(imarray) 38 | elif cmap == 'plasma': 39 | imarray = cm.plasma(imarray) 40 | im = Image.fromarray(np.uint8(imarray * 255)) 41 | im.save(save_path) 42 | 43 | 44 | def remove_linear_phase(phi, thetaX, thetaY, x, y, k): 45 | 46 | linear_phiX = -np.sin(thetaX / 180 * np.pi) * k 47 | linear_phiY = -np.sin(thetaY / 180 * np.pi) * k 48 | 49 | xx, yy = np.meshgrid(x, y, indexing='xy') 50 | phi_new = phi - xx * linear_phiX - yy * linear_phiY 51 | 52 | return np.remainder(phi_new, 2 * np.pi) 53 | 54 | 55 | def snr(u_hat, u_ref): 56 | u_hat /= abs(u_hat).max() 57 | u_ref /= abs(u_ref).max() 58 | signal = np.sum(abs(u_hat)**2) 59 | alpha = np.sum(u_hat * np.conjugate(u_ref)) / np.sum(abs(u_ref)**2) 60 | snr = signal / np.sum(abs(u_hat - alpha * u_ref)**2) 61 | return 10 * np.log10(snr) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Author Contributions 2 | To provide essential information regarding author contributions omitted from the published letter: 3 | 4 | H.W. and X.L. conceived the project and wrote the paper; X.L. constructed the storyline and led the formulation and methodology development with H.W.'s discussion and derivation. 5 | 6 | H.W. and X.L. built the code; 7 | 8 | X.L. and H.W. co-authored the manuscript with valuable feedback from Y.P., E.Y.L., and X.H.; 9 | 10 | H.W. created all figures with input from X.L.; 11 | 12 | Y.P. provided project supervision. 13 | 14 | ***For more insights beyond this work, you may correspond with [Dr. 
Xin Liu](https://liux2018.github.io) and refer to [On computational optics](https://github.com/LiuX2018/On-computational-optics).*** 15 | 16 | # LS-ASM 17 | This repository provides the official open-source code for the following paper: 18 | 19 | **Modeling off-axis diffraction with the least-sampling angular spectrum method**\ 20 | Haoyu Wei*, Xin Liu*, Xiang Hao, Edmund Y. Lam, Yifan Peng\ 21 | [Paper](https://doi.org/10.1364/OPTICA.490223), [Project page](https://whywww.github.io/LSASM_page/) \ 22 | Correspondence: [Dr. Peng](https://www.eee.hku.hk/~evanpeng/) and [Prof. Lam](https://www.eee.hku.hk/~elam/). For implementation and experiment details, please contact Haoyu (haoyu.wei97@gmail.com). 23 | 24 | principle 25 | 26 | ## Quick start 27 | This repository contains implementations of the LS-ASM and Rayleigh-Sommerfeld (RS) algorithms, with a spherical-wave input and thin-lens and diffuser modulations. 28 | 29 | ### Prerequisites 30 | Create a conda environment from the yml file: 31 | ``` 32 | conda env create -f environment.yml 33 | ``` 34 | If you are running on a GPU, please install a PyTorch version that matches the CUDA version on your machine. 35 | 36 | ### Config and Run 37 | Configurations are in `main.py`.\ 38 | Run the script below; LS-ASM results are written to the `results` folder and RS results to the `RS` folder (the script does not create these folders, so create them first). 39 | ``` 40 | python main.py 41 | ``` 42 | 43 | ## Performance 44 | We show the LS-ASM speedup for incident angles from 0 to 20 degrees.\ 45 | results 46 | 47 | Diffuser results closely resemble those of RS.\ 48 | diffuser 49 | 50 | ## Citation 51 | 52 | If you use this code and find our work valuable, please cite our paper. 53 | ``` 54 | @article{Wei:23, 55 | title = {Modeling Off-Axis Diffraction with the Least-Sampling Angular Spectrum Method}, 56 | author = {Haoyu Wei and Xin Liu and Xiang Hao and Edmund Y.
Lam and Yifan Peng}, 57 | journal = {Optica}, 58 | volume = {10}, number = {7}, pages = {959--962}, 59 | publisher = {Optica Publishing Group}, 60 | year = {2023}, 61 | month = {Jul}, 62 | doi = {10.1364/OPTICA.490223} 63 | } 64 | ``` 65 | 66 | ## License 67 | 68 | Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License. 69 | -------------------------------------------------------------------------------- /RS.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the implementation of the Rayleigh-Sommerfeld algorithm. Refer to Goodman, Joseph W. 3 | Introduction to Fourier optics. Roberts and Company Publishers, 2005, for principle details. 4 | This code is adapted from a Matlab script from Xin Liu and converted into a GPU parallel 5 | -computing Python script by Haoyu Wei (haoyu.wei97@gmail.com). 6 | 7 | 8 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 9 | # The license is only for non-commercial use (commercial licenses can be obtained from authors). 10 | # The material is provided as-is, with no warranties whatsoever. 11 | # If you publish any code, data, or scientific work based on this, please cite our work. 12 | 13 | 14 | Technical Paper: 15 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis 16 | diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023) 17 | """ 18 | 19 | 20 | import torch 21 | import math 22 | from tqdm import tqdm 23 | 24 | 25 | class RSDiffraction_GPU(): 26 | ''' 27 | Optimized for parallel computing 28 | ''' 29 | def __init__(self, z, xvec, yvec, svec, tvec, wavelengths, device) -> None: 30 | ''' 31 | x,s are horizontal. y,t are vertical. 
32 | ''' 33 | 34 | self.device = device 35 | self.k = 2 * torch.pi / wavelengths 36 | self.z = z 37 | 38 | xvec, yvec = torch.tensor(xvec), torch.tensor(yvec) 39 | svec, tvec = torch.tensor(svec), torch.tensor(tvec) 40 | xx, yy = torch.meshgrid(xvec, yvec, indexing='xy') 41 | ss, tt = torch.meshgrid(svec, tvec, indexing='xy') 42 | self.ss, self.tt = ss.to(device), tt.to(device) 43 | self.xx, self.yy = xx.to(device), yy.to(device) 44 | 45 | self.block_sz = 100 # depends on your memory, e.g., 128 needs ~24GB GPU memory 46 | 47 | 48 | def __call__(self, E0): 49 | 50 | E0 = torch.tensor(E0, dtype=torch.complex128, device=self.device) 51 | 52 | LX, LY = E0.shape[-2:] 53 | LS, LT = self.ss.shape 54 | 55 | Eout = [] 56 | for bt in tqdm(range(math.ceil(LT / self.block_sz)), desc='tvec', position=0): 57 | Erow = [] 58 | for bs in tqdm(range(math.ceil(LS / self.block_sz)), desc='svec', position=1, leave=False): 59 | ss_ = self.ss[bt*self.block_sz : (bt+1)*self.block_sz, bs*self.block_sz : (bs+1)*self.block_sz] 60 | tt_ = self.tt[bt*self.block_sz : (bt+1)*self.block_sz, bs*self.block_sz : (bs+1)*self.block_sz] 61 | block_sum = torch.zeros_like(ss_, dtype=E0.dtype) 62 | for by in tqdm(range(math.ceil(LY / self.block_sz)), desc='yvec', position=2, leave=False): 63 | for bx in tqdm(range(math.ceil(LX / self.block_sz)), desc='xvec', position=3, leave=False): 64 | E0_ = E0[by*self.block_sz : (by+1)*self.block_sz, bx*self.block_sz : (bx+1)*self.block_sz] 65 | xx_ = self.xx[by*self.block_sz : (by+1)*self.block_sz, bx*self.block_sz : (bx+1)*self.block_sz] 66 | yy_ = self.yy[by*self.block_sz : (by+1)*self.block_sz, bx*self.block_sz : (bx+1)*self.block_sz] 67 | xx_st = xx_[..., None, None] 68 | yy_st = yy_[..., None, None] 69 | xy_ss = ss_.expand(*xx_.shape, *ss_.shape) 70 | xy_tt = tt_.expand(*xx_.shape, *tt_.shape) 71 | r = torch.sqrt((xy_ss - xx_st)**2 + (xy_tt - yy_st)**2 + self.z**2) 72 | h = -1 / (2 * torch.pi) * (1j * self.k - 1 / r) * torch.exp(1j * self.k * r) * self.z / r**2 
73 | block_sum += torch.einsum('xy, xyst', E0_, h) 74 | Erow.append(block_sum) 75 | Eout.append(torch.hstack(Erow)) 76 | Eout = torch.vstack(Eout) 77 | 78 | return Eout.cpu().numpy() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | LS-ASM: 3 | 4 | This is the main executive script used for the diffraction field calculation using LS-ASM. 5 | 6 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 7 | # The license is only for non-commercial use (commercial licenses can be obtained from authors). 8 | # The material is provided as-is, with no warranties whatsoever. 9 | # If you publish any code, data, or scientific work based on this, please cite our work. 10 | 11 | @article{Wei:23, 12 | title = {Modeling Off-Axis Diffraction with the Least-Sampling Angular Spectrum Method}, 13 | author = {Haoyu Wei and Xin Liu and Xiang Hao and Edmund Y. 
Lam and Yifan Peng}, 14 | journal = {Optica}, 15 | volume = {10}, number = {7}, pages = {959--962}, 16 | publisher = {Optica Publishing Group}, 17 | year = {2023}, 18 | month = {Jul}, 19 | doi = {10.1364/OPTICA.490223} 20 | } 21 | 22 | ----- 23 | 24 | $ python main.py 25 | """ 26 | 27 | 28 | import numpy as np 29 | import time 30 | from utils import save_image, remove_linear_phase, snr 31 | import glob 32 | from input_field import InputField 33 | 34 | 35 | ############################### hyperparameters ############################ 36 | 37 | wvls = 500e-9 # wavelength of light in vacuum 38 | k = 2 * np.pi / wvls # wavenumebr 39 | f = 35e-3 # focal length of lens (if applicable) 40 | z0 = 1.7 # source-aperture distance 41 | zf = 1/(1/f - 1/z0) # image-side focal distance 42 | z = zf # aperture-sensor distance 43 | r = f / 16 / 2 # radius of aperture 44 | thetaX = 0 # incident angle in degree 45 | thetaY = 5 # incident angle in degree 46 | 47 | s_LSASM = 1.5 # oversampling factor for LSASM 48 | s_RS = 4 # oversampling factor for Rayleigh-Sommerfeld 49 | compensate = True # LPC 50 | use_LSASM = True 51 | use_RS = False 52 | result_folder = 'results' 53 | RS_folder = 'RS' 54 | calculate_SNR = False 55 | 56 | # define observation window 57 | Mx, My = 512, 512 58 | l = r * 0.25 59 | # l = 0.0136/1.5 # first term in Eq6 scaled by 1/1.5 to estimate OW size, used for diffuser 60 | # l = r * 8. 
# 35 degrees 61 | xc = - z * np.sin(thetaX / 180 * np.pi) / np.sqrt(1 - np.sin(thetaX / 180 * np.pi)**2 - np.sin(thetaY / 180 * np.pi)**2) 62 | yc = - z * np.sin(thetaY / 180 * np.pi) / np.sqrt(1 - np.sin(thetaX / 180 * np.pi)**2 - np.sin(thetaY / 180 * np.pi)**2) 63 | 64 | x = np.linspace(-l / 2 + xc, l / 2 + xc, Mx, endpoint=True) 65 | y = np.linspace(-l / 2 + yc, l / 2 + yc, My, endpoint=True) 66 | print(f'observation window diamter = {l}.') 67 | 68 | if use_LSASM: 69 | print('----------------- Propagating with ASASM -----------------') 70 | # use "12" for thin lens + spherical wave 71 | # use "3" for diffuser 72 | Uin = InputField("12", wvls, (thetaX, thetaY), r, z0, f, zf, s_LSASM) 73 | 74 | from LSASM import LeastSamplingASM 75 | device = 'cuda:0' 76 | # device = 'cpu' 77 | prop2 = LeastSamplingASM(Uin, x, y, z, device) 78 | path = f'{result_folder}/LSASM({len(Uin.xi)},{len(prop2.fx)})-{thetaX}-{s_LSASM:.2f}' 79 | 80 | start = time.time() 81 | U2 = prop2(Uin.E0) 82 | end = time.time() 83 | runtime = end - start 84 | print(f'Time elapsed for LSASM: {runtime:.2f}') 85 | 86 | save_image(abs(U2), f'{path}.png', cmap='gray') 87 | phase = remove_linear_phase(np.angle(U2), thetaX, thetaY, x, y, k) # for visualization 88 | save_image(phase, f'{path}-Phi.png', cmap='twilight') 89 | 90 | if calculate_SNR: 91 | if glob.glob(f'{RS_folder}/RS*-{thetaX}-{s_RS:.1f}.npy') != []: 92 | u_GT = np.load(glob.glob(f'{RS_folder}/RS*-{thetaX}-{s_RS:.1f}.npy')[0]) 93 | print(f'SNR is {snr(U2, u_GT):.2f}') 94 | 95 | 96 | if use_RS: 97 | print('-------------- Propagating with RS integral --------------') 98 | Uin = InputField("12", wvls, (thetaX, thetaY), r, z0, f, zf, s_RS) 99 | 100 | from RS import RSDiffraction_GPU 101 | prop = RSDiffraction_GPU(z, Uin.xi, Uin.eta, x, y, wvls, 'cuda:0') 102 | path = f'{RS_folder}/RS({len(Uin.xi)})-{thetaX}-{s_RS:.1f}' 103 | start = time.time() 104 | U0 = prop(Uin.E0) 105 | end = time.time() 106 | print(f'Time elapsed for RS: {end-start:.2f}') 107 | 
save_image(abs(U0), f'{path}.png', cmap='gray') 108 | phase = remove_linear_phase(np.angle(U0), thetaX, thetaY, x, y, k) # for visualization 109 | save_image(phase, f'{path}-Phi.png', cmap='twilight') 110 | np.save(f'{path}', U0) -------------------------------------------------------------------------------- /phase_plates.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script includes the modulations and components of the input field. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | 10 | Technical Paper: 11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023) 12 | """ 13 | 14 | 15 | import numpy as np 16 | from utils import effective_bandwidth 17 | import cv2 18 | 19 | 20 | class SphericalWave(): 21 | def __init__(self, k, x0, y0, z0, angles, zf) -> None: 22 | 23 | self.k = k 24 | thetaX, thetaY = angles 25 | self.fcX = - np.sin(thetaX / 180 * np.pi) * k / (2 * np.pi) 26 | self.fcY = - np.sin(thetaY / 180 * np.pi) * k / (2 * np.pi) 27 | self.x0, self.y0, self.z0 = x0, y0, z0 28 | self.zf = zf 29 | 30 | 31 | def forward(self, E0, xi_, eta_): 32 | ''' 33 | Apply a spherical phase shift to E0 at coordinates xi_ and eta_ 34 | ''' 35 | 36 | radius = np.sqrt(self.z0**2 + (xi_ - self.x0)**2 + (eta_ - self.y0)**2) 37 | phase = self.k * radius 38 | amplitude = 1 / radius 39 | 40 | E = amplitude * np.exp(1j * phase) 41 | E *= np.exp(1j * 2 * np.pi * (-self.fcX * xi_ - self.fcY * eta_)) # LPC 42 | 43 | return E0 * E 44 | 45 | 46 | def 
phase_gradient(self, xi, eta): 47 | ''' 48 | Compute phase gradients at point (xi, eta) 49 | ''' 50 | 51 | denom = np.sqrt((xi - self.x0)**2 + (eta - self.y0)**2 + self.z0**2) 52 | grad_uX = self.k * (xi - self.x0) / denom 53 | grad_uY = self.k * (eta - self.y0) / denom 54 | 55 | grad_linearX = 2 * np.pi * self.fcX 56 | grad_linearY = 2 * np.pi * self.fcY 57 | 58 | gradientX = grad_uX - grad_linearX 59 | gradientY = grad_uY - grad_linearY 60 | 61 | return gradientX, gradientY 62 | 63 | 64 | class PlaneWave(): 65 | def __init__(self, k, r, x0, y0, z0) -> None: 66 | 67 | self.k = k 68 | self.r = r 69 | self.fcX = self.fcY = 0 70 | self.x0, self.y0, self.z0 = x0, y0, z0 71 | 72 | 73 | def forward(self, E0, xi_, eta_): 74 | 75 | vec = np.array([-self.x0, -self.y0, self.z0]) 76 | kx, ky, kz = vec / np.sqrt(np.dot(vec, vec)) 77 | phase = self.k * (kx * xi_ + ky * eta_ + kz) 78 | 79 | return E0 * np.exp(1j * phase) 80 | 81 | 82 | def phase_gradient(self, xi, eta): 83 | 84 | return 0, 0 85 | 86 | 87 | class ThinLens(): 88 | def __init__(self, k, f) -> None: 89 | 90 | self.k = k 91 | self.f = f 92 | self.fcX = self.fcY = 0 93 | 94 | 95 | def forward(self, E0, xi_, eta_): 96 | 97 | phase = self.k / 2 * (-1 / self.f) * (xi_**2 + eta_**2) 98 | 99 | return E0 * np.exp(1j * phase) 100 | 101 | 102 | def phase_gradient(self, xi, eta): 103 | 104 | grad_uX = -self.k / self.f * xi 105 | grad_uY = -self.k / self.f * eta 106 | 107 | return grad_uX, grad_uY 108 | 109 | 110 | class Diffuser(): 111 | def __init__(self, r, interpolation='nearest', rand_phase=True, rand_amp=False) -> None: 112 | ''' 113 | Two types of diffusers: 'nearest' or 'linear' interpolated 114 | ''' 115 | 116 | self.fcX = self.fcY = 0 117 | self.pitch = r / 10 118 | self.N = int(r * 2 / self.pitch) 119 | np.random.seed(0) 120 | self.plate = np.random.rand(self.N, self.N) 121 | self.interp = interpolation 122 | self.rand_phase = rand_phase 123 | self.rand_amp = rand_amp 124 | 125 | 126 | def forward(self, E0, xi_, 
eta_): 127 | 128 | if self.interp == 'nearest': 129 | plate_sample = cv2.resize(self.plate, xi_.shape, interpolation=cv2.INTER_NEAREST) 130 | elif self.interp == 'linear': 131 | plate_sample = cv2.resize(self.plate, xi_.shape, interpolation=cv2.INTER_LINEAR) 132 | else: 133 | raise NotImplementedError 134 | 135 | amp = np.ones_like(plate_sample) 136 | phase = np.zeros_like(plate_sample) 137 | if self.rand_phase: 138 | phase = plate_sample * 4 * np.pi # random phase from 0 to 4pi 139 | if self.rand_amp: 140 | amp = plate_sample # random amplitude from 0 to 1 141 | 142 | return E0 * amp * np.exp(1j * phase) 143 | 144 | 145 | def phase_gradient(self): 146 | ''' 147 | :return: maximum phase gradient 148 | ''' 149 | 150 | # Second term in Eq. 4 151 | grad_max = effective_bandwidth(self.pitch, is_plane_wave = True) 152 | 153 | # nearest interpolation does not have phase gradient 154 | # but linear interpolation does 155 | if self.interp == 'linear': 156 | if self.rand_phase: 157 | grad_max += 4 / self.pitch 158 | if self.rand_amp: 159 | grad_max += 1 / self.pitch / np.pi 160 | 161 | return grad_max, grad_max -------------------------------------------------------------------------------- /input_field.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is where the input field is integrated. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | 10 | Technical Paper: 11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. 
Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023) 12 | """ 13 | 14 | 15 | import numpy as np 16 | from phase_plates import SphericalWave, PlaneWave, ThinLens, Diffuser 17 | from operator import add 18 | from utils import effective_bandwidth 19 | 20 | 21 | class InputField(): 22 | ''' 23 | Prepare compensated input field and spatial sampling 24 | ''' 25 | def __init__(self, type:str, wvls:float, angles, r:float, 26 | z0=None, f=None, zf=None, s=1.5) -> None: 27 | 28 | self.wvls = wvls # wavelength of light in vacuum 29 | self.k = 2 * np.pi / self.wvls # wavenumebr 30 | thetaX, thetaY = angles 31 | 32 | # define incident wave 33 | r0 = z0 / np.sqrt(1 - np.sin(thetaX / 180 * np.pi)**2 - np.sin(thetaY / 180 * np.pi)**2) 34 | x0, y0 = r0 * np.sin(thetaX / 180 * np.pi), r0 * np.sin(thetaY / 180 * np.pi) 35 | print(f'aperture diameter = {2 * r}, offset = {x0:.4f}, theta = {thetaX}.') 36 | 37 | # prepare wave components 38 | typelist = [*type] 39 | wavelist = [] 40 | fcX = 0 # frequency centers 41 | fcY = 0 42 | 43 | print('Input field contains:') 44 | if "0" in typelist: 45 | print('\t Plane wave') 46 | 47 | field = PlaneWave(self.k, r, x0, y0, z0) 48 | fcX += field.fcX 49 | fcY += field.fcY 50 | wavelist.append(field) 51 | 52 | if "1" in typelist: 53 | print('\t Spherical wave') 54 | 55 | field = SphericalWave(self.k, x0, y0, z0, angles, zf) 56 | fcX += field.fcX 57 | fcY += field.fcY 58 | wavelist.append(field) 59 | 60 | if "2" in typelist: 61 | print('\t Convex lens') 62 | lens = ThinLens(self.k, f) 63 | fcX += lens.fcX 64 | fcY += lens.fcY 65 | wavelist.append(lens) 66 | 67 | if "3" in typelist: 68 | print('\t Random diffuser') 69 | 70 | phase_plate = Diffuser(r, interpolation='linear', rand_phase=True, rand_amp=True) 71 | fcX += phase_plate.fcX 72 | fcY += phase_plate.fcY 73 | wavelist.append(phase_plate) 74 | 75 | # Compute spatial sampling 76 | Nx, Ny, fbX, fbY = self.spatial_sampling(r, 
s, wavelist) 77 | 78 | # Prepare input field 79 | self.set_input_plane(r, Nx, Ny) 80 | E0 = self.pupil 81 | for wave in wavelist: 82 | E0 = wave.forward(E0, self.xi_, self.eta_) 83 | 84 | self.fcX = fcX 85 | self.fcY = fcY 86 | self.fbX = fbX 87 | self.fbY = fbY 88 | self.E0 = E0 89 | self.s = s 90 | self.zf = zf 91 | self.D = 2 * r 92 | self.type = type 93 | 94 | 95 | def spatial_sampling(self, r, s, wavelist): 96 | ''' 97 | :param r: aperture radius 98 | :param s: oversampling factor 99 | :param wavelist: a list of input wave components 100 | :return: number of samples in both dimensions, bandwidths in both dimensions 101 | ''' 102 | 103 | # Second term in Eq4, the size of Airy disk 104 | fplane = effective_bandwidth(r*2, is_plane_wave = True) 105 | 106 | # First term in Eq4, maximum phase gradient 107 | # as the phase terms are all monotonic here, 108 | # we use the two boundaries of aperture (+-r) to find max 109 | grad1 = [0, 0] 110 | grad2 = [0, 0] 111 | diffuser = False 112 | for wave in wavelist: 113 | if isinstance(wave, Diffuser): 114 | diffuser = True 115 | grad = wave.phase_gradient() 116 | fbX_diffuser = grad[0] * s 117 | fbY_diffuser = grad[1] * s 118 | else: 119 | grad1 = list(map(add, grad1, wave.phase_gradient(-r, -r))) 120 | grad2 = list(map(add, grad2, wave.phase_gradient(r, r))) 121 | fbX = (max(abs(grad1[0]), abs(grad2[0])) / np.pi + fplane) * s 122 | fbY = (max(abs(grad1[1]), abs(grad2[1])) / np.pi + fplane) * s 123 | 124 | if diffuser: 125 | fbX = max(fbX, fbX_diffuser) 126 | fbY = max(fbY, fbY_diffuser) 127 | 128 | Nx = int(np.ceil(fbX * 2 * r)) 129 | Ny = int(np.ceil(fbY * 2 * r)) 130 | print(f'spatial sampling number = {Nx, Ny}.') 131 | 132 | return Nx, Ny, (Nx - 1) / (2 * r), (Ny - 1) / (2 * r) 133 | 134 | 135 | def set_input_plane(self, r, Nx, Ny): 136 | 137 | # coordinates of aperture 138 | xi = np.linspace(-r, r, Nx, endpoint=True) 139 | eta = np.linspace(-r, r, Ny, endpoint=True) 140 | xi_, eta_ = np.meshgrid(xi, eta, indexing='xy') 
141 | 142 | # circular aperture 143 | pupil = np.where(xi_**2 + eta_**2 <= r**2, 1, 0) 144 | 145 | self.pupil = pupil 146 | self.xi, self.eta = xi, eta 147 | self.xi_, self.eta_ = xi_, eta_ 148 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lsasm 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - backcall=0.2.0=pyhd3eb1b0_0 11 | - blas=1.0=mkl 12 | - blosc=1.21.0=h4ff587b_1 13 | - bottleneck=1.3.5=py37h7deecbd_0 14 | - brotli=1.0.9=h5eee18b_7 15 | - brotli-bin=1.0.9=h5eee18b_7 16 | - brunsli=0.1=h2531618_0 17 | - bzip2=1.0.8=h7b6447c_0 18 | - c-ares=1.18.1=h7f8727e_0 19 | - ca-certificates=2023.01.10=h06a4308_0 20 | - cairo=1.16.0=hf32fb01_1 21 | - certifi=2022.12.7=py37h06a4308_0 22 | - cfitsio=3.470=h5893167_7 23 | - charls=2.2.0=h2531618_0 24 | - cloudpickle=2.0.0=pyhd3eb1b0_0 25 | - colorama=0.4.5=py37h06a4308_0 26 | - cudatoolkit=11.3.1=h9edb442_10 27 | - cycler=0.11.0=pyhd3eb1b0_0 28 | - cytoolz=0.12.0=py37h5eee18b_0 29 | - dask-core=2021.10.0=pyhd3eb1b0_0 30 | - dbus=1.13.18=hb2f20db_0 31 | - debugpy=1.5.1=py37h295c915_0 32 | - decorator=5.1.1=pyhd3eb1b0_0 33 | - entrypoints=0.4=py37h06a4308_0 34 | - expat=2.4.9=h6a678d5_0 35 | - ffmpeg=4.3=hf484d3e_0 36 | - flit-core=3.6.0=pyhd3eb1b0_0 37 | - fontconfig=2.13.1=h6c09931_0 38 | - freetype=2.12.1=h4a9f257_0 39 | - fsspec=2022.11.0=py37h06a4308_0 40 | - gettext=0.21.0=hf68c758_0 41 | - giflib=5.2.1=h7b6447c_0 42 | - glib=2.68.4=h9c3ff4c_0 43 | - glib-tools=2.68.4=h9c3ff4c_0 44 | - gmp=6.2.1=h295c915_3 45 | - gnutls=3.6.15=he1e5248_0 46 | - graphite2=1.3.14=h295c915_1 47 | - gst-plugins-base=1.14.5=h0935bb2_2 48 | - gstreamer=1.18.5=h76c114f_0 49 | - harfbuzz=2.7.2=ha5b49bf_1 50 | - hdf5=1.10.6=h3ffc7dd_1 51 | - icu=67.1=he1b5a44_0 52 | - 
imagecodecs=2021.8.26=py37hf0132c2_1 53 | - imageio=2.19.3=py37h06a4308_0 54 | - intel-openmp=2021.4.0=h06a4308_3561 55 | - ipykernel=6.15.2=py37h06a4308_0 56 | - ipython=7.31.1=py37h06a4308_1 57 | - jasper=1.900.1=hd497a04_4 58 | - jedi=0.18.1=py37h06a4308_1 59 | - jpeg=9e=h7f8727e_0 60 | - jupyter_client=7.4.7=py37h06a4308_0 61 | - jupyter_core=4.11.2=py37h06a4308_0 62 | - jxrlib=1.1=h7b6447c_2 63 | - kiwisolver=1.4.2=py37h295c915_0 64 | - krb5=1.19.2=hac12032_0 65 | - lame=3.100=h7b6447c_0 66 | - lcms2=2.12=h3be6417_0 67 | - ld_impl_linux-64=2.38=h1181459_1 68 | - lerc=3.0=h295c915_0 69 | - libaec=1.0.4=he6710b0_1 70 | - libblas=3.9.0=12_linux64_mkl 71 | - libbrotlicommon=1.0.9=h5eee18b_7 72 | - libbrotlidec=1.0.9=h5eee18b_7 73 | - libbrotlienc=1.0.9=h5eee18b_7 74 | - libcblas=3.9.0=12_linux64_mkl 75 | - libclang=11.1.0=default_ha53f305_1 76 | - libcurl=7.87.0=h91b91d3_0 77 | - libdeflate=1.8=h7f8727e_5 78 | - libedit=3.1.20221030=h5eee18b_0 79 | - libev=4.33=h7f8727e_1 80 | - libevent=2.1.10=h9b69904_4 81 | - libffi=3.3=he6710b0_2 82 | - libgcc-ng=11.2.0=h1234567_1 83 | - libgfortran-ng=11.2.0=h00389a5_1 84 | - libgfortran5=11.2.0=h1234567_1 85 | - libglib=2.68.4=h3e27bee_0 86 | - libgomp=11.2.0=h1234567_1 87 | - libiconv=1.16=h7f8727e_2 88 | - libidn2=2.3.2=h7f8727e_0 89 | - liblapack=3.9.0=12_linux64_mkl 90 | - liblapacke=3.9.0=12_linux64_mkl 91 | - libllvm11=11.1.0=h9e868ea_6 92 | - libnghttp2=1.46.0=hce63b2e_0 93 | - libopencv=4.4.0=py37_2 94 | - libpng=1.6.37=hbc83047_0 95 | - libpq=12.9=h16c4e8d_3 96 | - libsodium=1.0.18=h7b6447c_0 97 | - libssh2=1.10.0=h8f2d780_0 98 | - libstdcxx-ng=11.2.0=h1234567_1 99 | - libtasn1=4.16.0=h27cfd23_0 100 | - libtiff=4.4.0=hecacb30_2 101 | - libunistring=0.9.10=h27cfd23_0 102 | - libuuid=1.41.5=h5eee18b_0 103 | - libuv=1.40.0=h7b6447c_0 104 | - libwebp=1.2.4=h11a3e52_0 105 | - libwebp-base=1.2.4=h5eee18b_0 106 | - libxcb=1.15=h7f8727e_0 107 | - libxkbcommon=1.0.3=he3ba5ed_0 108 | - libxml2=2.9.10=h68273f3_2 109 | - 
libzopfli=1.0.3=he6710b0_0 110 | - locket=1.0.0=py37h06a4308_0 111 | - lz4-c=1.9.4=h6a678d5_0 112 | - matplotlib=3.2.2=1 113 | - matplotlib-base=3.2.2=py37h1d35a4c_1 114 | - matplotlib-inline=0.1.6=py37h06a4308_0 115 | - mkl=2021.4.0=h06a4308_640 116 | - mkl-service=2.4.0=py37h7f8727e_0 117 | - mkl_fft=1.3.1=py37hd3c417c_0 118 | - mkl_random=1.2.2=py37h51133e4_0 119 | - mysql-common=8.0.25=ha770c72_2 120 | - mysql-libs=8.0.25=hfa10184_2 121 | - ncurses=6.3=h5eee18b_3 122 | - nest-asyncio=1.5.5=py37h06a4308_0 123 | - nettle=3.7.3=hbbd107a_1 124 | - networkx=2.6.3=pyhd3eb1b0_0 125 | - nspr=4.33=h295c915_0 126 | - nss=3.74=h0370c37_0 127 | - numexpr=2.8.3=py37h807cd23_0 128 | - numpy=1.21.5=py37h6c91a56_3 129 | - numpy-base=1.21.5=py37ha15fc14_3 130 | - opencv=4.4.0=py37_2 131 | - openh264=2.1.1=h4ff587b_0 132 | - openjpeg=2.4.0=h3ad879b_0 133 | - openssl=1.1.1s=h7f8727e_0 134 | - packaging=21.3=pyhd3eb1b0_0 135 | - pandas=1.3.5=py37h8c16a72_0 136 | - parso=0.8.3=pyhd3eb1b0_0 137 | - partd=1.2.0=pyhd3eb1b0_1 138 | - patsy=0.5.3=pyhd8ed1ab_0 139 | - pcre=8.45=h295c915_0 140 | - pexpect=4.8.0=pyhd3eb1b0_3 141 | - pickleshare=0.7.5=pyhd3eb1b0_1003 142 | - pillow=9.3.0=py37hace64e9_0 143 | - pip=22.3.1=py37h06a4308_0 144 | - pixman=0.40.0=h7f8727e_1 145 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 146 | - psutil=5.9.0=py37h5eee18b_0 147 | - ptyprocess=0.7.0=pyhd3eb1b0_2 148 | - py-opencv=4.4.0=py37h43977f1_2 149 | - pygments=2.11.2=pyhd3eb1b0_0 150 | - pyparsing=3.0.9=py37h06a4308_0 151 | - python=3.7.15=haa1d7c7_0 152 | - python-dateutil=2.8.2=pyhd3eb1b0_0 153 | - python_abi=3.7=2_cp37m 154 | - pytorch=1.10.0=py3.7_cuda11.3_cudnn8.2.0_0 155 | - pytorch-mutex=1.0=cuda 156 | - pytz=2022.1=py37h06a4308_0 157 | - pywavelets=1.3.0=py37h7f8727e_0 158 | - pyyaml=6.0=py37h5eee18b_1 159 | - pyzmq=23.2.0=py37h6a678d5_0 160 | - qt=5.12.9=h763d07f_1 161 | - readline=8.2=h5eee18b_0 162 | - scikit-image=0.19.3=py37h6a678d5_1 163 | - scipy=1.7.3=py37hf2a6cf1_0 164 | - 
seaborn=0.12.2=hd8ed1ab_0 165 | - seaborn-base=0.12.2=pyhd8ed1ab_0 166 | - setuptools=65.5.0=py37h06a4308_0 167 | - six=1.16.0=pyhd3eb1b0_1 168 | - snappy=1.1.9=h295c915_0 169 | - sqlite=3.40.0=h5082296_0 170 | - statsmodels=0.13.2=py37h7f8727e_0 171 | - tifffile=2021.7.2=pyhd3eb1b0_2 172 | - tk=8.6.12=h1ccaba5_0 173 | - tmux=3.2a=h385fc29_0 174 | - toolz=0.12.0=py37h06a4308_0 175 | - torchaudio=0.10.0=py37_cu113 176 | - torchvision=0.11.0=py37_cu113 177 | - tornado=6.2=py37h5eee18b_0 178 | - tqdm=4.64.1=pyhd8ed1ab_0 179 | - traitlets=5.7.1=py37h06a4308_0 180 | - typing_extensions=4.4.0=py37h06a4308_0 181 | - wcwidth=0.2.5=pyhd3eb1b0_0 182 | - wheel=0.37.1=pyhd3eb1b0_0 183 | - xz=5.2.8=h5eee18b_0 184 | - yaml=0.2.5=h7b6447c_0 185 | - zeromq=4.3.4=h2531618_0 186 | - zfp=0.5.5=h295c915_6 187 | - zlib=1.2.13=h5eee18b_0 188 | - zstd=1.5.2=ha4553b6_0 189 | - pip: 190 | - data==0.4 191 | - finufft==2.1.0 192 | - funcsigs==1.0.2 193 | - future==0.18.2 194 | - latex==0.7.0 195 | - shutilwhich==1.1.0 196 | - tempdir==0.7.1 197 | prefix: /home/hywei/anaconda3/envs/sampling 198 | -------------------------------------------------------------------------------- /LSASM.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the implementation of the algorithm LS-ASM. 3 | 4 | This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell: 5 | # The license is only for non-commercial use (commercial licenses can be obtained from authors). 6 | # The material is provided as-is, with no warranties whatsoever. 7 | # If you publish any code, data, or scientific work based on this, please cite our work. 8 | 9 | 10 | Technical Paper: 11 | Haoyu Wei, Xin Liu, Xiang Hao, Edmund Y. 
Lam, and Yifan Peng, "Modeling off-axis diffraction with the least-sampling angular spectrum method," Optica 10, 959-962 (2023) 12 | """ 13 | 14 | import torch 15 | import math 16 | 17 | 18 | def mdft(in_matrix, x, y, fx, fy): 19 | x = x.unsqueeze(-1) 20 | y = y.unsqueeze(-2) 21 | fx = fx.unsqueeze(-2) 22 | fy = fy.unsqueeze(-1) 23 | mx = torch.exp(-2 * torch.pi * 1j * torch.matmul(x, fx)) 24 | my = torch.exp(-2 * torch.pi * 1j * torch.matmul(fy, y)) 25 | out_matrix = torch.matmul(torch.matmul(my, in_matrix), mx) 26 | 27 | lx = torch.numel(x) 28 | ly = torch.numel(y) 29 | if lx == 1: 30 | dx = 1 31 | else: 32 | dx = (torch.squeeze(x)[-1] - torch.squeeze(x)[0]) / (lx - 1) 33 | 34 | if ly == 1: 35 | dy = 1 36 | else: 37 | dy = (torch.squeeze(y)[-1] - torch.squeeze(y)[0]) / (ly - 1) 38 | 39 | out_matrix = out_matrix * dx * dy # the result is only valid for uniform sampling 40 | return out_matrix 41 | 42 | 43 | def midft(in_matrix, x, y, fx, fy): 44 | x = x.unsqueeze(-2) 45 | y = y.unsqueeze(-1) 46 | fx = fx.unsqueeze(-1) 47 | fy = fy.unsqueeze(-2) 48 | mx = torch.exp(2 * torch.pi * 1j * torch.matmul(fx, x)) 49 | my = torch.exp(2 * torch.pi * 1j * torch.matmul(y, fy)) 50 | out_matrix = torch.matmul(torch.matmul(my, in_matrix), mx) 51 | 52 | lfx = torch.numel(fx) 53 | lfy = torch.numel(fy) 54 | if lfx == 1: 55 | dfx = 1 56 | else: 57 | dfx = (torch.squeeze(fx)[-1] - torch.squeeze(fx)[0]) / (lfx - 1) 58 | 59 | if lfy == 1: 60 | dfy = 1 61 | else: 62 | dfy = (torch.squeeze(fy)[-1] - torch.squeeze(fy)[0]) / (lfy - 1) 63 | 64 | out_matrix = out_matrix * dfx * dfy # the result is only valid for uniform sampling 65 | return out_matrix 66 | 67 | 68 | class LeastSamplingASM(): 69 | def __init__(self, Uin, xvec, yvec, z, device): 70 | ''' 71 | :param Uin: input field object 72 | :param xvec, yvec: vectors of destination coordinates 73 | :param z: propagation distance 74 | ''' 75 | 76 | super().__init__() 77 | 78 | dtype = torch.double 79 | complex_dtype = torch.complex128 80 | 
81 | xivec, etavec = torch.as_tensor(Uin.xi, device=device), torch.as_tensor(Uin.eta, device=device) 82 | xvec, yvec = torch.as_tensor(xvec, device=device), torch.as_tensor(yvec, device=device) 83 | z = torch.as_tensor(z, device=device) 84 | wavelength = torch.as_tensor(Uin.wvls, device=device) 85 | 86 | # maximum wavelength 87 | n = 1 88 | k = 2 * math.pi / wavelength * n 89 | 90 | # bandwidth of aperture 91 | Lfx = Uin.fbX 92 | Lfy = Uin.fbY 93 | 94 | # off-axis offset 95 | xc, yc = xvec[len(xvec) // 2], yvec[len(yvec) // 2] 96 | wx = xvec[-1] - xvec[0] 97 | wy = yvec[-1] - yvec[0] 98 | offx = torch.as_tensor(Uin.fcX, device=device) 99 | offy = torch.as_tensor(Uin.fcY, device=device) 100 | 101 | # shifted frequencies 102 | fxmax = Lfx / 2 + abs(offx) 103 | fymax = Lfy / 2 + abs(offy) 104 | 105 | # drop the evanescent wave 106 | fxmax = torch.clamp(fxmax, -1 / wavelength, 1 / wavelength) 107 | fymax = torch.clamp(fymax, -1 / wavelength, 1 / wavelength) 108 | if 1 - (wavelength * fxmax)**2 - (wavelength * fymax) ** 2 <= 0: 109 | # if frequencies exceed this range, some information is lost because of evanescent wave 110 | # fxmax, fymax < 1 / wavelength 111 | # thetax_max = torch.asin(1 - wavelength * Lfx / 2) / math.pi * 180 112 | # thetay_max = torch.asin(1 - wavelength * Lfy / 2) / math.pi * 180 113 | # print(f'The oblique angle should not exceed ({thetax_max:.1f}, {thetay_max:.1f}) degrees.') 114 | eps = 1e-9 115 | beta = torch.atan2(fymax, fxmax) 116 | fxmax = torch.clamp(fxmax, max = torch.cos(beta) / ((wavelength + eps))) 117 | fymax = torch.clamp(fymax, max = torch.sin(beta) / ((wavelength + eps))) 118 | Lfx = (fxmax - abs(offx)) * 2 119 | Lfy = (fymax - abs(offy)) * 2 120 | 121 | # combined phase gradient analysis 122 | gx1, gy1 = self.grad_H(wavelength, z, Lfx / 2 + offx, Lfy / 2 + offy) 123 | gx2, gy2 = self.grad_H(wavelength, z, -Lfx / 2 + offx, -Lfy / 2 + offy) 124 | FHcx = (gx1 + gx2) / (4 * torch.pi) 125 | FHcy = (gy1 + gy2) / (4 * torch.pi) 126 | 127 
| # specify the frequency sampling for each type of input field 128 | if Uin.type == "12": 129 | hx = k * Uin.zf * wavelength**2 * Lfx / 2 130 | hy = k * Uin.zf * wavelength**2 * Lfy / 2 131 | FUHbx = abs((hx + gx1) - (-hx + gx2)) / (2 * torch.pi) 132 | FUHby = abs((hy + gy1) - (-hy + gy2)) / (2 * torch.pi) 133 | 134 | deltax = self.compute_shift_of_H(FHcx, FUHbx + 2 * Uin.D, xc, wx) 135 | deltay = self.compute_shift_of_H(FHcy, FUHby + 2 * Uin.D, yc, wy) 136 | FUHcx_shifted = FHcx + deltax 137 | FUHcy_shifted = FHcy + deltay 138 | 139 | tau_UHx = 2 * abs(FUHcx_shifted) + FUHbx + 2 * Uin.D 140 | tau_UHy = 2 * abs(FUHcy_shifted) + FUHby + 2 * Uin.D 141 | else: 142 | tau_UHx = tau_UHy = torch.inf 143 | 144 | # upper bound 145 | FHbx = abs(gx1 - gx2) / (2 * torch.pi) 146 | FHby = abs(gy1 - gy2) / (2 * torch.pi) 147 | 148 | deltax = self.compute_shift_of_H(FHcx, FHbx + Uin.D, xc, wx) 149 | deltay = self.compute_shift_of_H(FHcy, FHby + Uin.D, yc, wy) 150 | FHcx_shifted = FHcx + deltax 151 | FHcy_shifted = FHcy + deltay 152 | 153 | tau_fx_bound = 2 * abs(FHcx_shifted) + FHbx + Uin.D 154 | tau_fy_bound = 2 * abs(FHcy_shifted) + FHby + Uin.D 155 | 156 | # final phase gradient 157 | tau_UHx = min(tau_UHx, tau_fx_bound) + 41.2 / Uin.fbX 158 | tau_UHy = min(tau_UHy, tau_fy_bound) + 41.2 / Uin.fbY 159 | 160 | dfxMax1 = 1 / tau_UHx 161 | dfyMax1 = 1 / tau_UHy 162 | 163 | # maximum sampling interval limited by OW 164 | dfxMax2 = 1 / (2 * abs(xc - deltax) + wx) 165 | dfyMax2 = 1 / (2 * abs(yc - deltay) + wy) 166 | 167 | # minimum requirements of sampling interval in k space 168 | dfx = min(dfxMax1, dfxMax2) 169 | dfy = min(dfyMax1, dfyMax2) 170 | 171 | LRfx = math.ceil(Lfx / dfx * Uin.s) 172 | LRfy = math.ceil(Lfy / dfy * Uin.s) 173 | 174 | dfx2 = Lfx / LRfx 175 | dfy2 = Lfy / LRfy 176 | 177 | print(f'frequency sampling number = {LRfx, LRfy}, bandwidth = {Lfx:.2f}.') 178 | 179 | # spatial frequency coordinates 180 | fx = torch.linspace(-Lfx / 2, Lfx / 2 - dfx2, LRfx, 
device=device, dtype=complex_dtype) 181 | fy = torch.linspace(-Lfy / 2, Lfy / 2 - dfy2, LRfy, device=device, dtype=complex_dtype) 182 | fx_shift, fy_shift = fx + offx, fy + offy 183 | 184 | fxx, fyy = torch.meshgrid(fx_shift, fy_shift, indexing='xy') 185 | # self.H = torch.exp(1j * k * z * torch.sqrt(1 - (wavelength * fxx) ** 2 - (wavelength * fyy) ** 2)) 186 | # shifted H 187 | self.H = torch.exp(1j * k * (wavelength * fxx * deltax + wavelength * fyy * deltay 188 | + z * torch.sqrt(1 - (fxx * wavelength)**2 - (fyy * wavelength)**2))) 189 | 190 | self.xi = xivec.to(dtype = complex_dtype) 191 | self.eta = etavec.to(dtype = complex_dtype) 192 | self.x = xvec.to(dtype = complex_dtype) - deltax # shift the observation window back to origin 193 | self.y = yvec.to(dtype = complex_dtype) - deltay 194 | self.offx, self.offy = offx, offy 195 | self.device = device 196 | self.fx = fx_shift 197 | self.fy = fy_shift 198 | self.fbX = Uin.fbX 199 | self.fbY = Uin.fbY 200 | 201 | 202 | def __call__(self, E0): 203 | ''' 204 | :param E0: input field 205 | ''' 206 | 207 | E0 = torch.as_tensor(E0, dtype=torch.complex128, device=self.device) 208 | 209 | fx = self.fx.unsqueeze(0) 210 | fy = self.fy.unsqueeze(0) 211 | 212 | Fu = mdft(E0, self.xi, self.eta, fx - self.offx, fy - self.offy) 213 | 214 | Eout = midft(Fu * self.H, self.x, self.y, fx, fy) 215 | # Eout /= abs(Eout).max() # we dont need to normalize using MTP. 
216 | 217 | return Eout[0].cpu().numpy() 218 | 219 | 220 | def grad_H(self, lam, z, fx, fy): 221 | 222 | eps = torch.tensor(1e-9, device = fx.device) 223 | denom = torch.max(1 - (lam * fx)**2 - (lam * fy) ** 2, eps) 224 | gradx = - z * 2 * torch.pi * lam * fx / torch.sqrt(1 - (lam * fx)**2 - (lam * fy)**2) 225 | grady = - z * 2 * torch.pi * lam * fy / torch.sqrt(1 - (lam * fx)**2 - (lam * fy)**2) 226 | return gradx, grady 227 | 228 | 229 | def compute_shift_of_H(self, C1, C2, pc, w): 230 | 231 | if (w > -2 * C1 - 2 * pc + C2) and (w < 2 * C1 + 2 * pc + C2): 232 | delta = pc / 2 + w / 4 - C1 / 2 - C2 / 4 233 | elif (w > 2 * C1 + 2 * pc + C2) and (w < -2 * C1 - 2 * pc + C2): 234 | delta = pc / 2 - w / 4 - C1 / 2 + C2 / 4 235 | elif (w > 2 * C1 + 2 * pc + C2) and (w > -2 * C1 - 2 * pc + C2): 236 | delta = pc 237 | elif (w < 2 * C1 + 2 * pc + C2) and (w < -2 * C1 - 2 * pc + C2): 238 | delta = -C1 239 | 240 | return delta --------------------------------------------------------------------------------