├── .github
│   └── workflows
│       └── pytests.yml
├── LICENSE
├── README.md
├── examples
│   ├── MPI_logo.png
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── autodiff_evaluation_xlumina.py
│   ├── discovered_solution_4f_system.npy
│   ├── examples.ipynb
│   ├── noisy_4f_system.ipynb
│   ├── noisy_optimization.ipynb
│   ├── numerical_methods_evaluation_diffractio.py
│   ├── scalar_diffractio.py
│   ├── scalar_xlumina.py
│   ├── test_diffractio_vs_xlumina.ipynb
│   ├── vectorial_diffractio.py
│   └── vectorial_xlumina.py
├── experiments
│   ├── __init__.py
│   ├── four_f_optical_table.py
│   ├── four_f_optimizer.py
│   ├── generate_synthetic_data.py
│   ├── hybrid_optimizer.py
│   ├── hybrid_sharp_optical_table.py
│   ├── hybrid_sted_optical_table.py
│   ├── hybrid_with_fixed_PM.py
│   └── six_times_six_ansatz_with_fixed_PM.py
├── miscellaneous
│   ├── noise-aware.png
│   ├── performance.png
│   ├── performance_convergence.png
│   ├── propagation_comparison.png
│   └── workflow.png
├── setup.py
├── tests
│   ├── __init__.py
│   ├── pytest.ini
│   ├── test_optical_elements.py
│   ├── test_toolbox.py
│   ├── test_vectorized_optics.py
│   └── test_wave_optics.py
└── xlumina
    ├── __init__.py
    ├── loss_functions.py
    ├── optical_elements.py
    ├── toolbox.py
    ├── vectorized_optics.py
    └── wave_optics.py

/.github/workflows/pytests.yml:
--------------------------------------------------------------------------------
name: pytests

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11']

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -e .
        pip install pytest
        pip install h5py
    - name: Test with pytest
      run: pytest
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Artificial Scientist Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ✨ XLuminA ✨

[![Nature Communications](https://img.shields.io/badge/Nat_Commun-15,_10658_(2024)-b31b1b.svg)](https://www.nature.com/articles/s41467-024-54696-y)
[![Downloads](https://pepy.tech/badge/xlumina)](https://pepy.tech/project/xlumina)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
![Tests](https://github.com/artificial-scientist-lab/XLuminA/actions/workflows/pytests.yml/badge.svg)
[![Python Versions](https://img.shields.io/pypi/pyversions/xlumina.svg)](https://pypi.org/project/xlumina/)
[![PyPI version](https://badge.fury.io/py/xlumina.svg)](https://badge.fury.io/py/xlumina)
[![GitHub stars](https://img.shields.io/github/stars/artificial-scientist-lab/XLuminA.svg)](https://github.com/artificial-scientist-lab/XLuminA/stargazers)

**XLuminA, a highly efficient, auto-differentiating discovery framework for super-resolution microscopy**

📄 Read our paper here: \
[**Automated discovery of experimental designs in super-resolution microscopy with XLuminA**](https://doi.org/10.1038/s41467-024-54696-y)\
*Carla Rodríguez, Sören Arlt, Leonhard Möckl and Mario Krenn*

📰 Read the press release: \
[**10,000 times faster than traditional methods: new computational framework automatically discovers experimental designs in microscopy**](https://mpl.mpg.de/news/article/10000-times-faster-than-traditional-methods-new-computational-framework-automatically-discovers-experimental-designs-in-microscopy)

📚 Related works featuring XLuminA:

[![NeurIPS](https://img.shields.io/badge/NeurIPS-2024-red.svg)](https://ml4physicalsciences.github.io/2024/) **Machine Learning and the Physical Sciences Workshop -** *Poster*

[![ICML](https://img.shields.io/badge/ICML-2024-blue.svg)](https://openreview.net/forum?id=ik9YuAHq6J&referrer=%5Bthe%20profile%20of%20Carla%20Rodríguez%5D(%2Fprofile%3Fid%3D~Carla_Rodríguez1)) **AI4Science Workshop -** *Oral contribution*

[![NeurIPS](https://img.shields.io/badge/NeurIPS-2023-purple.svg)](https://openreview.net/forum?id=J8HGMimNYe&referrer=%5Bthe%20profile%20of%20Carla%20Rodríguez%5D(%2Fprofile%3Fid%3D~Carla_Rodríguez1)) **AI4Science Workshop -** *Oral contribution*

## 💻 Installation:

### Using PyPI:

Create a new conda environment and install `xlumina` from PyPI. We recommend using `python=3.11`:
```
conda create -n xlumina_env python=3.11

conda activate xlumina_env

pip install xlumina
```

Installation should take about 10 seconds. The package automatically installs:

1. [**JAX (CPU only) and jaxlib**](https://jax.readthedocs.io/en/latest/index.html) (the version of JAX used in this project is v0.4.33),

2. [**Optax**](https://github.com/google-deepmind/optax/tree/master) (the version of Optax used in this project is v0.2.3),

3. [**SciPy**](https://scipy.org) (the version of SciPy used in this project is v1.14.1).
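To compare your environment with the versions listed above, a small check like the following may help. It is a sketch that only reads installed package metadata through the standard-library `importlib.metadata`; `TESTED_VERSIONS` and `installed_versions` are illustrative names, not part of XLuminA, and a mismatch only flags a difference, not necessarily a problem.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions quoted in this README for the environment used in the project.
TESTED_VERSIONS = {"jax": "0.4.33", "optax": "0.2.3", "scipy": "1.14.1"}


def installed_versions(packages):
    """Return {distribution name: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, installed in installed_versions(TESTED_VERSIONS).items():
    note = "same as" if installed == TESTED_VERSIONS[name] else "differs from"
    print(f"{name}: installed {installed}, {note} the tested v{TESTED_VERSIONS[name]}")
```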

### Clone repository:
```
git clone https://github.com/artificial-scientist-lab/XLuminA.git
```

### GPU compatibility:

To install [JAX with NVIDIA GPU support](https://jax.readthedocs.io/en/latest/installation.html) (**Note: wheels are only available on Linux**), use the CUDA 12 installation:

```
pip install --upgrade "jax[cuda12]"
```

XLuminA has been tested on Linux Enterprise Server 15 SP4 (15.4), and it has been successfully installed on Windows 10 (64-bit) and macOS Monterey 12.6.2.


# 👾 Features:

XLuminA allows for the simulation, in a (*very*) fast and efficient way, of classical light propagation through optics hardware configurations, and enables the optimization and automated discovery of new setup designs.

*[Figure: workflow]*

The simulator contains many features:

✦ Light sources (of any wavelength and power) using either scalar or vectorial optical fields.

✦ Phase masks (e.g., spatial light modulators (SLMs), polarizers and general variable retarders (LCDs)).

✦ Amplitude masks (e.g., spatial light modulators (SLMs) and pre-defined circles, triangles and squares).

✦ Beam splitters, a fluorescence model for STED, and more!

✦ The light propagation methods available in XLuminA are:

- [Fast-Fourier-transform (FFT) based numerical integration of the Rayleigh-Sommerfeld diffraction integral](https://doi.org/10.1364/AO.45.001102).

- [Chirped z-transform](https://doi.org/10.1038/s41377-020-00362-z). This algorithm is an accelerated version of the Rayleigh-Sommerfeld method, which allows for arbitrary selection and sampling of the region of interest.

- Propagation through [high NA objective lenses](https://doi.org/10.1016/j.optcom.2010.07.030), to replicate strong focusing conditions using polarized light.
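For intuition on the first method, the sketch below evaluates the Rayleigh-Sommerfeld integral as a zero-padded FFT convolution with the first Rayleigh-Sommerfeld impulse response, in the spirit of the FFT-based direct-integration approach referenced above. It rests on our own assumptions (square grid, same sampling in the input and output planes, $z > 0$, an $e^{+ikr}$ sign convention, lengths in microns) and is not XLuminA's `RS_propagation`; `rs_impulse_response` and `rs_propagate` are hypothetical names.

```python
import math

import jax.numpy as jnp


def rs_impulse_response(x, y, z, wavelength):
    """First Rayleigh-Sommerfeld impulse response, using the exp(+ikr) sign convention."""
    k = 2 * jnp.pi / wavelength
    r = jnp.sqrt(x**2 + y**2 + z**2)
    return z / (2 * jnp.pi * r**2) * (1 / r - 1j * k) * jnp.exp(1j * k * r)


def rs_propagate(u0, dx, z, wavelength):
    """Propagate a square n x n field u0 (grid spacing dx) by a distance z > 0."""
    n = u0.shape[0]
    # Coordinate differences x - x' span 2n - 1 samples.
    offsets = (jnp.arange(2 * n - 1) - (n - 1)) * dx
    xd, yd = jnp.meshgrid(offsets, offsets, indexing="ij")
    h = rs_impulse_response(xd, yd, z, wavelength)
    # Zero-pad the input so the circular FFT convolution equals the linear one
    # on the output samples that are kept.
    u_pad = jnp.zeros((2 * n - 1, 2 * n - 1), dtype=h.dtype).at[:n, :n].set(u0)
    conv = jnp.fft.ifft2(jnp.fft.fft2(u_pad) * jnp.fft.fft2(h)) * dx * dx
    # Output sample m sits at index m + n - 1 of the convolution.
    return conv[n - 1:, n - 1:]


# Usage: Gaussian beam with waist 20 um at 650 nm (lengths in microns).
n, dx, wavelength, w0, z = 64, 2.0, 0.65, 20.0, 2000.0
x = (jnp.arange(n) - n // 2) * dx
X, Y = jnp.meshgrid(x, x, indexing="ij")
u0 = jnp.exp(-(X**2 + Y**2) / w0**2)
u_z = rs_propagate(u0, dx, z, wavelength)

on_axis_ratio = float(jnp.abs(u_z[n // 2, n // 2]) ** 2 / jnp.abs(u0[n // 2, n // 2]) ** 2)
z_rayleigh = math.pi * w0**2 / wavelength
paraxial_ratio = 1.0 / (1.0 + (z / z_rayleigh) ** 2)  # Gaussian-beam on-axis intensity drop
print(on_axis_ratio, paraxial_ratio)
```

The usage lines compare the on-axis intensity after propagation with the paraxial Gaussian-beam prediction. Padding each axis to $2n-1$ samples is what prevents the circular FFT convolution from wrapping around onto the output samples that are kept.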

# 📝 Examples of usage:

Examples of some experiments that can be reproduced with XLuminA are:

* Optical telescope (or 4f-correlator),
* Polarization-based beam shaping as used in [STED (stimulated emission depletion) microscopy](https://opg.optica.org/ol/fulltext.cfm?uri=ol-19-11-780&id=12352),
* The [sharp focus of a radially polarized light beam](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.91.233901).

The code for each of these optical setups is provided in the Jupyter notebook [examples.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/examples.ipynb).

➤ A **step-by-step guide on how to add noise to the optical elements** can be found in [noisy_4f_system.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/noisy_4f_system.ipynb).

# 🚀 Testing XLuminA's efficiency:

We evaluated our framework by conducting several tests (see [Figure 1](https://arxiv.org/abs/2310.08408#)). The experiments were run on an Intel Xeon Gold 6130 CPU and an Nvidia Quadro RTX 6000 GPU.

(1) Average execution time (in seconds) over 100 runs, within a computational window size of $2048\times 2048$, for scalar and vectorial field propagation using Rayleigh-Sommerfeld (RS, VRS) and Chirped z-transform (CZT, VCZT) in [Diffractio](https://pypi.org/project/diffractio/) and XLuminA. Times for XLuminA reflect the run with pre-compiled jitted functions.
The Python files used to test the light propagation algorithms are [scalar_diffractio.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/scalar_diffractio.py) and [vectorial_diffractio.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/vectorial_diffractio.py) for Diffractio, and [scalar_xlumina.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/scalar_xlumina.py) and [vectorial_xlumina.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/vectorial_xlumina.py) for XLuminA.

*[Figure: propagation]*

(2) We compare the gradient evaluation time of numerical methods (using Diffractio's optical simulator and SciPy's [BFGS optimizer](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-bfgs.html#optimize-minimize-bfgs)) *vs* automatic (analytical) differentiation (using XLuminA's optical simulator with JAX's [ADAM optimizer](https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html)) across various resolutions:

*[Figure: performance]*

(3) We compare the convergence time of numerical methods (using Diffractio's optical simulator and SciPy's [BFGS optimizer](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-bfgs.html#optimize-minimize-bfgs)) *vs* automatic (analytical) differentiation (using XLuminA's optical simulator with JAX's [ADAM optimizer](https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html)) across various resolutions:

*[Figure: performance]*

➤ The Jupyter notebook used for running these simulations is provided as [test_diffractio_vs_xlumina.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/test_diffractio_vs_xlumina.ipynb).
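The gap in (2) and (3) comes down to how gradients are obtained: a finite-difference scheme (what SciPy's BFGS falls back to when no gradient function is supplied) needs extra loss evaluations for every parameter, while reverse-mode autodiff returns the whole gradient from one forward and one backward pass. The toy sketch below shows both on a hypothetical phase-mask loss (the on-axis far-field power, i.e. the zero-frequency FFT bin); it is not the repository's benchmark code. As in the XLuminA timings, the jitted functions are compiled once before timing, and `block_until_ready()` makes the timer wait for JAX's asynchronous dispatch.

```python
import time

import jax
import jax.numpy as jnp

# float64 keeps the finite-difference estimate accurate enough to compare.
jax.config.update("jax_enable_x64", True)


def loss(phase):
    """Hypothetical toy loss: minus the normalised on-axis far-field power of a phase mask."""
    far_field = jnp.fft.fft2(jnp.exp(1j * phase))
    return -jnp.abs(far_field[0, 0]) ** 2 / phase.size**2


loss_jit = jax.jit(loss)
grad_jit = jax.jit(jax.grad(loss))


def finite_difference_grad(f, params, eps=1e-6):
    """Central differences: two loss evaluations for every parameter."""
    flat = params.ravel()
    grads = []
    for i in range(flat.size):
        step = jnp.zeros_like(flat).at[i].set(eps)
        f_plus = f((flat + step).reshape(params.shape))
        f_minus = f((flat - step).reshape(params.shape))
        grads.append((f_plus - f_minus) / (2 * eps))
    return jnp.stack(grads).reshape(params.shape)


phase = jax.random.uniform(jax.random.PRNGKey(0), (16, 16), minval=-jnp.pi, maxval=jnp.pi)

# Compile once before timing, then wait for results with block_until_ready().
loss_jit(phase).block_until_ready()
grad_jit(phase).block_until_ready()

t0 = time.perf_counter()
g_fd = finite_difference_grad(loss_jit, phase).block_until_ready()
t_fd = time.perf_counter() - t0

t0 = time.perf_counter()
g_ad = grad_jit(phase).block_until_ready()
t_ad = time.perf_counter() - t0

print(f"loss evaluations for finite differences: {2 * phase.size}")
print(f"max |FD - AD| = {float(jnp.max(jnp.abs(g_fd - g_ad))):.1e}")
print(f"finite differences: {t_fd:.4f} s, autodiff: {t_ad:.4f} s")
```

The number of finite-difference evaluations grows with the number of parameters (here $2 \times 16 \times 16$), which is why the gap widens with resolution.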

➤ The Python files corresponding to the numerical/autodiff evaluations are [numerical_methods_evaluation_diffractio.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/numerical_methods_evaluation_diffractio.py) and [autodiff_evaluation_xlumina.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/autodiff_evaluation_xlumina.py).

*If you want to run the comparison test of the propagation functions, you need to install [**Diffractio**](https://pypi.org/project/diffractio/). The version of Diffractio used in this project is v0.1.1.*

# 🤖🔎 Discovery of new optical setups:

With XLuminA we were able to rediscover three foundational optics experiments. In addition, we discover new, superior topologies together with their parameter settings using purely continuous optimization.

➤ Optical telescope (or 4f-correlator),

➤ Polarization-based beam shaping as used in [STED (stimulated emission depletion) microscopy](https://opg.optica.org/ol/fulltext.cfm?uri=ol-19-11-780&id=12352),

➤ The [sharp focus of a radially polarized light beam](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.91.233901).
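The optical telescope (4f system) has a compact Fourier-optics description: in the paraxial, scalar picture, a lens of focal length $f$ maps the field in its front focal plane to a scaled Fourier transform in its back focal plane, so two lenses with equal focal lengths spaced $2f$ apart apply two Fourier transforms and return an inverted copy of the input. The discrete sketch below checks this with FFTs under those idealisations (no finite apertures, constant phase and scale factors dropped); `ideal_4f` is a hypothetical helper, not an XLuminA function, and the optional `fourier_mask` marks where a filter would sit in a 4f correlator.

```python
import jax.numpy as jnp


def ideal_4f(field, fourier_mask=None):
    """Idealised 4f system: two Fourier transforms, with an optional Fourier-plane mask."""
    spectrum = jnp.fft.fft2(field)              # first lens: object plane -> Fourier plane
    if fourier_mask is not None:
        spectrum = spectrum * fourier_mask      # filter placed in the shared focal plane
    return jnp.fft.fft2(spectrum) / field.size  # second lens: Fourier plane -> image plane


# An asymmetric test object, so the inversion is visible.
obj = jnp.arange(48.0).reshape(8, 6) / 48.0
image = ideal_4f(obj)

# Two forward DFTs give obj[-i mod N, -j mod M]: a point reflection about the origin pixel.
expected = jnp.roll(jnp.flip(obj, axis=(0, 1)), shift=(1, 1), axis=(0, 1))
print(bool(jnp.allclose(image, expected, atol=1e-4)))  # True
```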

The Python files used for the discovery of these optical setups, as detailed in [our paper](https://arxiv.org/abs/2310.08408#), are organized in pairs of `optical_table` and `optimizer` as follows:

| **Experiment name** | 🔬 Optical table | 🤖 Optimizer | 📄 File for data |
|----------------|---------------|-----------|----------|
| ***Optical telescope*** | [four_f_optical_table.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/four_f_optical_table.py) | [four_f_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/four_f_optimizer.py)| [generate_synthetic_data.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/generate_synthetic_data.py) |
| ***Pure topological discovery: large-scale sharp focus (Dorn, Quabis and Leuchs, 2004)*** | [hybrid_with_fixed_PM.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_with_fixed_PM.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
| ***Pure topological discovery: STED microscopy*** | [hybrid_with_fixed_PM.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_with_fixed_PM.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
| ***6x6 grid: pure topological discovery*** | [six_times_six_ansatz_with_fixed_PM.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/six_times_six_ansatz_with_fixed_PM.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
| ***Large-scale polarization-based STED*** | [hybrid_sted_optical_table.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_sted_optical_table.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
| ***Large-scale sharp focus (Dorn, Quabis and Leuchs, 2004)*** | [hybrid_sharp_optical_table.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_sharp_optical_table.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |

# 🦾🤖 Robustness and parallelized optimization of multiple optical tables with our noise-aware scheme:

✦ Importantly, to ensure that our simulations approximate real-world experimental conditions, we have included imperfections, misalignment, and noise sources in all optical components (during post-processing and/or during optimization). **All the results presented in the paper are computed considering a wide variety of experimental errors**.

➤ A **step-by-step guide on how to set up the optimization using this scheme** can be found in [noisy_optimization.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/noisy_optimization.ipynb).

➤ A **step-by-step guide on how to add noise to the optical elements** can be found in [noisy_4f_system.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/noisy_4f_system.ipynb).

*[Figure: noise_aware]*

✦ The optimization procedure is as follows: for each optimization step, we execute $N$ parallel optical tables using `vmap`. Then, we sample random noise and apply it to all available physical variables across each of the $N$ optical tables. The random noise is **uniformly distributed** and includes:
* Phase values for spatial light modulators (SLMs) and wave plates (WP) in the range of $\pm$ (0.01 to 0.1) radians, covering all qualities available in current experimental devices.
* Misalignment ranging from $\pm$ (0.01 to 0.1) millimeters, covering both expert-level precision ($\pm$ 0.01 mm) and beginner-level accuracy ($\pm$ 0.1 mm).
* 1\% imperfection on the transmissivity/reflectivity of beam splitters (BS), which is a realistic approach given the high quality of the currently available hardware.

We then simulate the optical setup for each of the $N$ tables simultaneously, incorporating the sampled noise. The loss function is computed independently for each of the setups. Afterwards, we calculate the mean loss value across all optical tables, which provides an average performance metric that accounts for the introduced experimental variability (noise). The gradients, and hence the update of the system parameters, are computed from this mean loss value.

Importantly, before applying the updated parameters and proceeding to the next iteration, we resample new random noise for each optical table. This ensures that each optimization step encounters different noise values, further enhancing the robustness of the solution. This procedure is repeated iteratively until convergence.

# 👀 Overview:

In this section we list the available functions in each file, with a brief description:

1. In [wave_optics.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/wave_optics.py): module for scalar optical fields.

|*Class*|*Functions*|*Description*|
|---------------|----|-----------|
| `ScalarLight` | | Class for scalar optical fields defined in the XY plane: complex amplitude $U(r) = A(r)e^{-ikz}$. |
| | `.draw` | Plots intensity and phase. |
| | `.apply_circular_mask` | Applies a circular mask of variable radius. |
| | `.apply_triangular_mask` | Applies a triangular mask of variable size. |
| | `.apply_rectangular_mask` | Applies a rectangular mask of variable size. |
| | `.apply_annular_aperture` | Applies an annular aperture of variable size. |
| | `.RS_propagation` | [Rayleigh-Sommerfeld](https://doi.org/10.1364/AO.45.001102) diffraction integral in z-direction (z>0 and z<0). |
| | `.get_RS_minimum_z` | Given a quality factor, determines the minimum (trustworthy) distance for `RS_propagation`.|
| | `.CZT` | [Chirped z-transform](https://doi.org/10.1038/s41377-020-00362-z) - efficient diffraction using the Bluestein method.|
| `LightSource` | | Class for scalar optical fields defined in the XY plane - defines light source beams. |
| | `.gaussian_beam` | Gaussian beam. |
| | `.plane_wave` | Plane wave. |


2. In [vectorized_optics.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/vectorized_optics.py): module for vectorized optical fields.

|*Class*| *Functions* |*Description*|
|---------------|----|-----------|
| `VectorizedLight` | | Class for vectorized optical fields defined in the XY plane: $\vec{E} = (E_x, E_y, E_z)$|
| | `.draw` | Plots intensity, phase and amplitude. |
| | `.draw_intensity_profile` | Plots intensity profile. |
| | `.VRS_propagation` | [Vectorial Rayleigh-Sommerfeld](https://iopscience.iop.org/article/10.1088/1612-2011/10/6/065004) diffraction integral in z-direction (z>0 and z<0). |
| | `.get_VRS_minimum_z` | Given a quality factor, determines the minimum (trustworthy) distance for `VRS_propagation`.|
| | `.VCZT` | [Vectorized Chirped z-transform](https://doi.org/10.1038/s41377-020-00362-z) - efficient diffraction using the Bluestein method.|
| `PolarizedLightSource` | | Class for polarized optical fields defined in the XY plane - defines light source beams. |
| | `.gaussian_beam` | Gaussian beam. |
| | `.plane_wave` | Plane wave. |


3. In [optical_elements.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/optical_elements.py): shelf with all the available optical elements.

| *Function* |*Description*|
|---------------|----|
| ***Scalar light devices*** | - |
| `phase_scalar_SLM` | Phase mask for the spatial light modulator available for scalar fields. |
| `SLM` | Spatial light modulator: applies a phase mask to an incident scalar field. |
| ***Jones matrices*** | - |
| `jones_LP` | Jones matrix of a [linear polarizer](https://doi.org/10.1201/b19711)|
| `jones_general_retarder` | Jones matrix of a [general retarder](https://www.researchgate.net/publication/235963739_Obtainment_of_the_polarizing_and_retardation_parameters_of_a_non-depolarizing_optical_system_from_the_polar_decomposition_of_its_Mueller_matrix). |
| `jones_sSLM` | Jones matrix of the *superSLM*. |
| `jones_sSLM_with_amplitude` | Jones matrix of the *superSLM* that modulates phase & amplitude. |
| `jones_LCD` | Jones matrix of a liquid crystal display (LCD).|
| ***Polarization-based devices*** | - |
|`sSLM` | *super*-Spatial Light Modulator: adds a phase mask (pixel-wise) to $E_x$ and $E_y$ independently. |
| `sSLM_with_amplitude` | *super*-Spatial Light Modulator: adds a phase mask and an amplitude mask (pixel-wise) to $E_x$ and $E_y$ independently. |
| `LCD` | Liquid crystal device: builds any linear wave-plate. |
| `linear_polarizer` | Linear polarizer.|
| `BS_symmetric` | Symmetric beam splitter.|
| `BS_symmetric_SI` | Symmetric beam splitter with single input.|
| `BS` | Single-side coated dielectric beam splitter.|
| `high_NA_objective_lens` | High NA objective lens (only for `VectorizedLight`).|
| `VCZT_objective_lens` | Propagation through high NA objective lens (only for `VectorizedLight`).|
| ***General elements*** | - |
| `lens` | Transparent lens of variable size and focal length.|
| `cylindrical_lens` | Transparent plano-convex cylindrical lens of variable focal length. |
| `axicon_lens` | Axicon lens function that produces a Bessel beam. |
| `circular_mask` | Circular mask of variable size. |
| `triangular_mask` | Triangular mask of variable size and orientation.|
| `rectangular_mask` | Rectangular mask of variable size and orientation.|
| `annular_aperture` | Annular aperture of variable size.|
| `forked_grating` | Forked grating of variable size, orientation, and topological charge. |
| 👷‍♀️ ***Pre-built optical setups*** | - |
| `bb_amplitude_and_phase_mod` | Basic building unit. Consists of an `sSLM` (amplitude & phase modulation) and an `LCD`, linked via `VRS_propagation`. |
| `building_block` | Basic building unit. Consists of an `sSLM` and an `LCD`, linked via `VRS_propagation`. |
| `fluorescence`| Fluorescence model.|
| `vectorized_fluorophores` | Vectorized version of `fluorescence`: computes the effective intensity across an array of detectors.|
| `robust_discovery` | 3x3 setup for hybrid (topology + optical settings) discovery with a single wavelength. Longitudinal intensity (Iz) is measured across all detectors. Includes noise for robustness. |
| `hybrid_setup_fixed_slms_fluorophores`| 3x3 optical table with SLMs randomly positioned displaying fixed phase masks; to be used for pure topological discovery; contains the fluorescence model in all detectors. (*Fig. 4a* of [our paper](https://arxiv.org/abs/2310.08408#))|
| `hybrid_setup_fixed_slms`| 3x3 optical table with SLMs randomly positioned displaying fixed phase masks; to be used for pure topological discovery. (*Fig. 4b* of [our paper](https://arxiv.org/abs/2310.08408#))|
| `hybrid_setup_fluorophores`| 3x3 optical table to be used for hybrid (topological + optical parameter) discovery; contains the fluorescence model in all detectors. (*Fig. 5a* and *Fig. 6* of [our paper](https://arxiv.org/abs/2310.08408#))|
| `hybrid_setup_sharp_focus`| 3x3 optical table to be used for hybrid (topological + optical parameter) discovery. (*Fig. 5b* of [our paper](https://arxiv.org/abs/2310.08408#))|
| `six_times_six_ansatz`| 6x6 optical table to be used for pure topological discovery. (*Extended Data Fig. 6* of [our paper](https://arxiv.org/abs/2310.08408#))|
| 🫨 ***Add noise to the optical elements*** | - |
|`shake_setup`| Literally shakes the setup: creates noise and misalignment for the optical elements. Accepts noise settings (dictionary) as argument. Can't be used with `jit` across parallel optical tables. |
|`shake_setup_jit`| Same as `shake_setup`, but doesn't accept noise settings as argument. Intended to be pasted into the optimizer file to enable `jit` compilation across parallel optical tables.|

4. In [toolbox.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/toolbox.py): file with useful functions.

| *Function* |*Description*|
|---------------|----|
| ***Basic operations*** | - |
| `space` | Builds the space where light is placed. |
| `wrap_phase` | Wraps any phase mask into the $[-\pi, \pi]$ range.|
| `is_conserving_energy` | Computes the total intensity from the light source and compares it with the propagated light - [Ref](https://doi.org/10.1117/12.482883).|
| `softmin` | Differentiable version of the min() function.|
| `delta_kronecker` | Kronecker delta.|
| `build_LCD_cell` | Builds the cell for `LCD`.|
| `draw_sSLM` | Plots the two phase masks of `sSLM`.|
| `draw_sSLM_amplitude` | Plots the two amplitude masks of `sSLM`.|
| `moving_avg` | Computes the moving average of a dataset.|
| `image_to_binary_mask`| Converts an image (png, jpeg) to a binary mask. |
| `rotate_mask` | Rotates the (X, Y) frame w.r.t. a given point. |
| `gaussian` | Defines a 1D Gaussian distribution. |
| `gaussian_2d` | Defines a 2D Gaussian distribution. |
| `lorentzian` | Defines a 1D Lorentzian distribution. |
| `lorentzian_2d` | Defines a 2D Lorentzian distribution. |
| `fwhm_1d_fit` | Computes FWHM in 1D using fit for `gaussian` or `lorentzian`.|
| `fwhm_2d_fit` | Computes FWHM in 2D using fit for `gaussian_2d` or `lorentzian_2d`. |
| `profile` | Determines the profile of a given input without using interpolation.|
| `spot_size` | Computes the spot size as $\pi (\text{FWHM}_x \cdot \text{FWHM}_y) /\lambda^2$. |
| `compute_fwhm` | Computes FWHM in 1D or 2D using fit: `gaussian`, `gaussian_2d`, `lorentzian`, `lorentzian_2d`. |
| 📑 ***Data loader*** | - |
| `MultiHDF5DataLoader` | Data loader class for 4f system training. |

5. In [loss_functions.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/loss_functions.py): file with loss functions.

| *Function* |*Description*|
|---------------|----|
| `small_area_hybrid` | Small area loss function valid for hybrid (topology + optical parameters) optimization.|
| `vMSE_Intensity` | Parallel computation of Mean Squared Error (Intensity) for a given electric field component $E_x$, $E_y$ or $E_z$. |
| `MSE_Intensity` | Mean Squared Error (Intensity) for a given electric field component $E_x$, $E_y$ or $E_z$. |
| `vMSE_Phase` | Parallel computation of Mean Squared Error (Phase) for a given electric field component $E_x$, $E_y$ or $E_z$. |
| `MSE_Phase` | Mean Squared Error (Phase) for a given electric field component $E_x$, $E_y$ or $E_z$. |
| `vMSE_Amplitude` | Parallel computation of Mean Squared Error (Amplitude) for a given electric field component $E_x$, $E_y$ or $E_z$. |
| `MSE_Amplitude` | Mean Squared Error (Amplitude) for a given electric field component $E_x$, $E_y$ or $E_z$. |
| `mean_batch_MSE_Intensity` | Batch-based `MSE_Intensity`.|

# ⚠️ Considerations when using XLuminA:

1. By default, JAX uses `float32` precision. If necessary, enable `jax.config.update("jax_enable_x64", True)` at the beginning of the file.

2. Basic units are microns (um) and radians. Other units (centimeters, millimeters, nanometers, and degrees) are available in `__init__.py`.

3. **IMPORTANT** - RAYLEIGH-SOMMERFELD PROPAGATION:
[FFT-based diffraction calculation algorithms](https://doi.org/10.1117/12.482883) can be inaccurate depending on the computational window size (sampling).\
Before propagating light, check the minimum distance for which the simulation is accurate.\
You can use the following functions:

`get_RS_minimum_z`, for the `ScalarLight` class, and `get_VRS_minimum_z`, for the `VectorizedLight` class.
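Considerations 1 and 2 interact: with microns as the base unit, long propagation distances produce large phase values. The sketch below uses illustrative numbers taken from `examples/autodiff_evaluation_xlumina.py` (650 nm wavelength, `z=25000*mm`), assumes 1 mm equals 1000 microns, and compares the gap between neighbouring representable numbers near the accumulated phase $kz$ in `float32` and `float64`, then enables 64-bit mode in JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative values in microns (XLuminA's base unit); 1 mm is assumed to be 1000 um.
wavelength_um = 0.65   # 650 nm
z_um = 25000 * 1000.0  # z = 25000 mm, as in examples/autodiff_evaluation_xlumina.py
kz = 2 * np.pi / wavelength_um * z_um  # accumulated propagation phase in radians

# Gap between neighbouring representable numbers around kz in each precision.
spacing_f32 = np.spacing(np.float32(kz))
spacing_f64 = np.spacing(np.float64(kz))
print(f"kz = {kz:.4e} rad")
print(f"float32 spacing near kz: {spacing_f32} rad")  # 16.0 rad, larger than 2*pi
print(f"float64 spacing near kz: {spacing_f64:.1e} rad")

# Consideration 1: switch JAX to 64-bit before creating arrays.
jax.config.update("jax_enable_x64", True)
print(jnp.asarray(1.0).dtype)  # float64
```

A `float32` spacing of 16 rad near $kz$ means the phase modulo $2\pi$ cannot be represented at all, which is one situation where enabling `jax_enable_x64` may be worth the extra memory and compute.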


# 💻 Development:

*Some functionalities of XLuminA’s optics simulator (e.g., optical propagation algorithms, lens or amplitude masks) are inspired by an open-source NumPy-based Python module for diffraction and interferometry simulation, [Diffractio](https://pypi.org/project/diffractio/). **We have rewritten and modified these approaches to combine them with JAX just-in-time (jit) functionality**. We labeled these functions as such in the docstrings. On top of that, **we developed completely new functions** (e.g., sSLMs, beam splitters, LCDs or propagation through high NA objective lens with CZT methods, to name a few) **which significantly expand the software capabilities.***

# 📝 How to cite XLuminA:

If you use this software, please cite as:

Rodríguez, C., Arlt, S., Möckl, L. and Krenn, M. Automated discovery of experimental designs in super-resolution microscopy with XLuminA. *Nat Commun* **15**, 10658 (2024). https://doi.org/10.1038/s41467-024-54696-y

BibTeX format:

    @article{NatCommun.15.10658,
        title={Automated discovery of experimental designs in super-resolution microscopy with {XLuminA}},
        author={Rodríguez, Carla and Arlt, Sören and Möckl, Leonhard and Krenn, Mario},
        journal={Nature Communications},
        volume={15},
        pages={10658},
        year={2024},
        publisher={Nature Publishing Group},
        doi={10.1038/s41467-024-54696-y}
    }

--------------------------------------------------------------------------------
/examples/MPI_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/MPI_logo.png
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/__init__.py
--------------------------------------------------------------------------------
/examples/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/examples/autodiff_evaluation_xlumina.py:
--------------------------------------------------------------------------------
import os
import sys

import numpy as np

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

from xlumina.__init__ import um, nm, mm, degrees, radians
from xlumina.wave_optics import *
from xlumina.vectorized_optics import *
from xlumina.optical_elements import *
from xlumina.loss_functions import *
from xlumina.toolbox import space
import jax.numpy as jnp

""" Evaluates the convergence and gradient evaluation times for XLuminA's autodiff optimization """

# System specs:
sensor_lateral_size = 500  # Resolution
wavelength = 650*nm
x_total = 1000*um
x, y = space(x_total, sensor_lateral_size)
shape = jnp.shape(x)[0]

# Light source specs
w0 = (1200*um, 1200*um)
gb = LightSource(x, y, wavelength)
gb.gaussian_beam(w0=w0, E0=1)
gb_gt = ScalarLight(x, y, wavelength)
# Set spiral phase for the ground truth
gb_gt.set_spiral()

# Define the setup
def setup(gb, parameters):
    # Note: gb_propagated is computed but not used below; the SLM acts on the source field gb.
    gb_propagated = gb.RS_propagation(z=25000*mm)
    gb_modulated, _ = SLM(gb, parameters, gb.x.shape[0])
    return gb_modulated

| def mse_phase(input_light, target_light): 42 | return jnp.sum((jnp.angle(input_light.field) - jnp.angle(target_light.field)) ** 2) / sensor_lateral_size**2 43 | def loss(parameters): 44 | out = setup(gb, parameters) 45 | loss_val = mse_phase(out, gb_gt) 46 | return loss_val 47 | 48 | # Optimizer for phase mask 49 | import time 50 | import jax 51 | from jax import grad, jit 52 | from jax.example_libraries import optimizers 53 | 54 | # Print device info (GPU or CPU) 55 | print(jax.devices(), flush=True) 56 | 57 | # Define the update: 58 | @jit 59 | def update(step_index, optimizer_state): 60 | # Define a single update step 61 | parameters = get_params(optimizer_state) 62 | # Call the loss function and compute the gradients 63 | computed_loss = loss_value(parameters) 64 | computed_gradients = grad(loss_value, allow_int=True)(parameters) 65 | 66 | return opt_update(step_index, computed_gradients, optimizer_state), computed_loss, computed_gradients 67 | 68 | # JIT-compile the loss function (its gradients are computed inside `update`): 69 | loss_value = jit(loss) 70 | 71 | # Optimizer settings 72 | STEP_SIZE = 0.1 73 | num_iterations = 50000 74 | n_best = 50 75 | best_loss = 1e10 76 | best_params = None 77 | best_step = 0 78 | num_samples = 5 79 | 80 | steps = [] 81 | times = [] 82 | ratio = [] 83 | 84 | for i in range(num_samples): 85 | best_loss, best_params, best_step = 1e10, None, 0 # Reset the best-loss tracking for each sample 86 | parameters = jnp.array([np.random.uniform(-jnp.pi, jnp.pi, (shape, shape))], dtype=jnp.float64) # Random initial SLM phase mask 87 | 88 | # Define the optimizer and initialize it 89 | opt_init, opt_update, get_params = optimizers.adam(STEP_SIZE) 90 | init_params = parameters 91 | opt_state = opt_init(init_params) 92 | 93 | print('Starting Optimization', flush=True) 94 | 95 | tic = time.perf_counter() 96 | 97 | # Optimize in a loop 98 | for step in range(num_iterations): 99 | 100 | opt_state, current_loss, gradients = update(step, opt_state) 101 | 102 | if current_loss < best_loss: 103 | # Best loss value 104 | best_loss = current_loss 105 | # Best optimized parameters 106
| best_params = get_params(opt_state) 107 | best_step = step 108 | # print('Best loss value is updated') 109 | 110 | if step % 100 == 0: 111 | # Stopping criterion: stop if the best loss has not improved for more than n_best steps (checked every 100 steps) 112 | if step - best_step > n_best: 113 | print(f'Stopping criterion: no improvement in loss value for {n_best} steps') 114 | break 115 | elapsed = time.perf_counter() - tic 116 | steps.append(step) 117 | times.append(elapsed) 118 | ratio.append(elapsed / (step + 1)) # Average time per iteration (step is 0-indexed) 119 | 120 | filename = f"xlumina_cpu_eval_{sensor_lateral_size}.npy" 121 | np.save(filename, {"Time": times, "Step": steps, "t/step": ratio}) -------------------------------------------------------------------------------- /examples/discovered_solution_4f_system.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/discovered_solution_4f_system.npy -------------------------------------------------------------------------------- /examples/noisy_optimization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 🦾🤖 Noise-aware optimization scheme with ✨ XLuminA ✨: \n", 8 | "\n", 9 | "This notebook is a step-by-step guide for building a robust (noise-aware) optimization scheme in XLuminA.\n", 10 | "\n", 11 | "We will set up an optimization scheme for *the sharp focus of a radially polarized light beam* (using **robust_discovery** from **optical_elements.py**)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "import sys\n", 22 | "\n", 23 | "# Setting the path for XLuminA modules:\n", 24 | "current_path = os.path.abspath(os.path.join('..'))\n", 25 | "module_path = os.path.join(current_path)\n", 26 | "\n", 27 | "if
module_path not in sys.path:\n", 28 | " sys.path.append(module_path)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 17, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from xlumina.__init__ import um, nm, mm\n", 38 | "from xlumina.vectorized_optics import *\n", 39 | "from xlumina.optical_elements import robust_discovery\n", 40 | "from xlumina.toolbox import space, softmin\n", 41 | "from xlumina.loss_functions import vectorized_loss_hybrid\n", 42 | "import jax.numpy as jnp\n", 43 | "import jax\n", 44 | "from jax import random, jit, vmap\n", 45 | "import optax" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## System specs: define the light source, output dimensions and the parameters kept static during optimization:" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# 1. System specs:\n", 62 | "sensor_lateral_size = 512 # Resolution\n", 63 | "wavelength1 = 650*nm\n", 64 | "x_total = 2500*um\n", 65 | "x, y = space(x_total, sensor_lateral_size)\n", 66 | "shape = jnp.shape(x)[0]\n", 67 | "\n", 68 | "# 2. Define the light source: a diagonally polarized Gaussian beam (Jones vector (1, 1)):\n", 69 | "w0 = (1200*um, 1200*um) \n", 70 | "ls1 = PolarizedLightSource(x, y, wavelength1)\n", 71 | "ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))\n", 72 | "\n", 73 | "# 3. Define the output (High Resolution) detection:\n", 74 | "x_out, y_out = jnp.array(space(10*um, 400)) # Detector pixel size: 20 um / 400 pixels \n", 75 | "\n", 76 | "# 4. High NA objective lens specs:\n", 77 | "NA = 0.9 \n", 78 | "radius_lens = 3.6*mm/2 \n", 79 | "f_lens = radius_lens / NA\n", 80 | "\n", 81 | "# 5.
Static parameters - don't change during optimization:\n", 82 | "fixed_params = [radius_lens, f_lens, x_out, y_out]" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## Define the optical setup:" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "1. Vectorized version of the optical setup [`optical_elements.py` > `robust_discovery`] over a new axis (defined by the noise).\n", 97 | "\n", 98 | " Here the args of the function are: `light source` (ls1 - ls6, common to all tables), `parameters` (common to all tables), `fixed_params` (common to all tables), `noise` (DIFFERENT for each table). " 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 6, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "def batch_robust_discovery(ls1, ls2, ls3, ls4, ls5, ls6, parameters, fixed_params, noise_distances, noise_slms, noise_wps, noise_amps, distance_offset):\n", 108 | " \"\"\"\n", 109 | " Vectorized (efficient) version of robust_discovery() for batch optimization. 
\n", 110 | " \n", 111 | " Parameters: \n", 112 | " ls1, ls2, ls3, ls4, ls5, ls6 (PolarizedLightSource)\n", 113 | " parameters (list): parameters to optimize\n", 114 | " BB 1: [phase1_1, phase1_2, a1_1, a1_2, eta1, theta1, z1_1, z1_2]\n", 115 | " BB 2: [phase2_1, phase2_2, a2_1, a2_2, eta2, theta2, z2_1, z2_2] \n", 116 | " BB 3: [phase3_1, phase3_2, a3_1, a3_2, eta3, theta3, z3_1, z3_2]\n", 117 | " BS ratios: [bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9] <- automatically contains \n", 118 | " Extra distances: [z4, z5]\n", 119 | " fixed_params (list): parameters kept fixed during optimization [r, f, xout, yout]: radius and focal length of the objective lens, and the output-plane coordinates.\n", 120 | " \n", 121 | " noise_distances (jnp.array): Misalignment (in microns): [noise_z1_1, noise_z1_2, ...]\n", 122 | " noise_slms (jnp.array): Noise (in radians): [noise_phase1_1, noise_phase1_2, ...]\n", 123 | " noise_wps (jnp.array): Noise (in radians): [noise_eta1, noise_theta1, ...]\n", 124 | " noise_amps (jnp.array): Noise (in AU): [noise_A1, noise_A2, ...]\n", 125 | " \n", 126 | " Returns vectorized version of detected light (intensity tensor): (# tables, (6, resolution, resolution))\n", 127 | " \"\"\"\n", 128 | " # Noise shapes are: \n", 129 | " # distance: (#tables, 8); \n", 130 | " # slms (amp and phase): (#tables, 6, resolution, resolution); \n", 131 | " # wp: (#tables, 6)\n", 132 | " # vmap over axis 0 of the noise arrays -> across optical tables\n", 133 | " detected_intensities_z = vmap(robust_discovery, in_axes=(None, None, None, None, None, None, None, None, \n", 134 | " 0, 0, 0, 0, None))(ls1, ls2, ls3, ls4, ls5, ls6, parameters, fixed_params, \n", 135 | " noise_distances, noise_slms, noise_wps, noise_amps, distance_offset)\n", 136 | " return detected_intensities_z" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "2.
Vectorize the loss function: `vmap` the computation of the loss across the optical tables (the loss, `vectorized_loss_hybrid`, is imported from `loss_functions.py`)." 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 7, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "def mean_batch_discover(detected_light):\n", 153 | " \"\"\" vmap the loss over the optical tables \"\"\"\n", 154 | " # detected_light with shape (#optical tables, (6, N, N))\n", 155 | " # vmap the loss over axis 0 (optical tables) \n", 156 | " return vmap(vectorized_loss_hybrid, in_axes=(0,))(detected_light)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "3. Define the loss: first compute the light from the parallelized optical tables (`batch_robust_discovery`), then compute the loss for each detector in each optical table. Finally, average the loss across the optical tables and take the (soft) minimum over the detectors using `softmin`." 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 8, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "@jit \n", 173 | "def loss_batch_discovery(parameters, noise_d, noise_slm, noise_wp, noise_amp):\n", 174 | " \"\"\"\n", 175 | " Loss function. It computes L = Area/I_{epsilon} across detectors.\n", 176 | " \n", 177 | " Parameters:\n", 178 | " parameters (list): Optimized parameters.\n", 179 | " noise_d (jnp.array): Misalignment (in microns): [noise_d1, noise_d2, ...]\n", 180 | " noise_slm (jnp.array): Noise (in radians): [noise_slm_1, noise_slm_2, ...]\n", 181 | " noise_wp (jnp.array): Noise (in radians): [noise_eta, noise_theta, ...]\n", 182 | " noise_amp (jnp.array): Noise (in AU): [noise_A1, noise_A2, ...]\n", 183 | "\n", 184 | " Returns the soft minimum over detectors of the loss averaged across the optical tables.
\n", 185 | " \"\"\"\n", 186 | " # Output from batch_robust_discovery is (#optical tables, (6, N, N)): for 6 detectors each\n", 187 | " detected_z_intensities = batch_robust_discovery(ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params, noise_d, noise_slm, noise_wp, noise_amp, distance_offset = 15) \n", 188 | "\n", 189 | " # Reduce the loss array to a single value:\n", 190 | " # output from mean_batch_discover is (#optical tables, (6, 1)). \n", 191 | " # Average across the optical tables, then take the soft minimum over detectors using softmin.\n", 192 | " loss_val = softmin(jnp.mean(mean_batch_discover(detected_z_intensities), axis=0, keepdims=True))\n", 193 | " return loss_val # scalar loss value" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "## Optimizer settings:" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 9, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "# Global variable\n", 210 | "shape = jnp.array([sensor_lateral_size, sensor_lateral_size])\n", 211 | "# Define the loss function:\n", 212 | "loss_function = loss_batch_discovery" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 10, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# Optimization settings\n", 222 | "OS = {'n_best': 500,\n", 223 | " 'best_loss': 3*1e2,\n", 224 | " 'num_iterations': 50000,\n", 225 | " 'num_samples': 1,\n", 226 | " 'WEIGHT_DECAY': 1e-3,\n", 227 | " 'BASE_lr': 0.05,\n", 228 | " 'END_lr': 0.001,\n", 229 | " 'DECAY_STEPS': 4000\n", 230 | " }" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "## [!!]
Define the noise settings dictionary:\n", 238 | "\n", 239 | "**NS (dict)**: \n", 240 | "\n", 241 | " NS = {'n_tables': __, \n", 242 | " 'number of distances': __,\n", 243 | " 'number of sSLM': __, \n", 244 | " 'number of wps': __,\n", 245 | " 'noise_level': __, \n", 246 | " 'misalignment': (minval, maxval), \n", 247 | " 'phase_noise': (minval, maxval),\n", 248 | " 'discretize': __}\n", 249 | "\n", 250 | "where \n", 251 | "\n", 252 | "**n_tables (int)**: number of optical tables to compute in parallel\n", 253 | "\n", 254 | "**number of distances, number of sSLM, number of wps (int)**: number of distances, sSLMs and wave plates in the optical setup.\n", 255 | "\n", 256 | "**noise_level (str)**:\n", 257 | "\n", 258 | " 1. low: noise in SLMs and WPs $\\pm$(0.01 to 0.05) rads and misalignment of $\\pm$(0.01 to 0.05) mm \n", 259 | " 2. mild: noise in SLMs and WPs $\\pm$(0.05 to 0.5) rads and misalignment of $\\pm$(0.05 to 0.5) mm \n", 260 | " 3. high: noise in SLMs and WPs $\\pm$(0.5 to 1) rads and misalignment of $\\pm$(0.5 to 1) mm \n", 261 | " 4. all: noise in SLMs and WPs $\\pm$(0.01 to 1) rads and misalignment of $\\pm$(0.01 to 1) mm \n", 262 | " 5. tunable: noise ranges taken from the `misalignment` (in um) and `phase_noise` (in rads) entries of NS\n", 263 | "\n", 264 | "**discretize (bool)**: if True, discretize SLM noise to 8-bit.\n" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 11, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "# Noise settings:\n", 274 | "NS = {'n_tables': 3, \n", 275 | " 'number of distances': 8,\n", 276 | " 'number of sSLM': 3, \n", 277 | " 'number of wps': 3,\n", 278 | " 'noise_level': 'tunable',\n", 279 | " 'misalignment': (10, 100), \n", 280 | " 'phase_noise': (0.01, 0.1),\n", 281 | " 'discretize': False}" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "1. Define a keychain (one PRNG key per optical table) for the number of optical tables specified in NS."
289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 12, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "# Keychain for optical tables\n", 298 | "def keychain_optical_tables(seed, number_of_tables):\n", 299 | " \"\"\" \n", 300 | " Generates a keychain with one PRNG key per optical table (number_of_tables keys in total) \n", 301 | " \"\"\"\n", 302 | " keychain = []\n", 303 | " for num in range(number_of_tables):\n", 304 | " key_table = random.PRNGKey(seed + num)\n", 305 | " keychain.append(key_table)\n", 306 | " \n", 307 | " return jnp.array(keychain)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "2. Define `shake_setup_jit()` and `batch_shake_setup()` to draw new noise for the setup at every iteration. \n", 315 | "\n", 316 | " Two types of shaking functions are provided in `optical_elements.py`. \n", 317 | "\n", 318 | " 1. `shake_setup` takes the noise settings (NS: dict) as an argument, so a `batch_shake_setup` built on it can't be compiled with @jit or @partial(jit). \n", 319 | "\n", 320 | " 2. If you want to @jit `batch_shake_setup`, copy-paste `shake_setup_jit` into your optimizer file instead: it reads NS as a global variable rather than taking it as an argument. \n", 321 | "\n", 322 | "Here we copy-paste `shake_setup_jit` and use NS as a global variable. 
" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 13, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "def shake_setup_jit(key, resolution):\n", 332 | " \"\"\"\n", 333 | " [THIS FUNCTION IS INTENDED TO BE PASTED IN THE OPTIMIZER FILE TO ENABLE @jit COMPILATION FOR `batch_shake_setup`]\n", 334 | " \n", 335 | " Creates noise for all the different optical variables on an optical table.\n", 336 | " \n", 337 | " Parameters:\n", 338 | " key (PRNGKey): JAX random key for reproducibility\n", 339 | " resolution (int): number of pixels for space\n", 340 | " \n", 341 | " global variable NS (dict): noise settings as\n", 342 | "\n", 343 | " NS = {'n_tables': __, 'number of distances': __,\n", 344 | " 'number of sSLM': __, 'number of wps': __,\n", 345 | " 'noise_level': __, \n", 346 | " 'misalignment': (minval, maxval), \n", 347 | " 'phase_noise': (minval, maxval),\n", 348 | " 'discretize': __}\n", 349 | " \n", 350 | " Returns:\n", 351 | " random_noise_distances, random_noise_slms, random_noise_wps, random_noise_amps, key0 (new key to split in the next iteration), key (old key)\n", 352 | " \"\"\"\n", 353 | " num_physical_variables = 4 # of physical variables (e.g., distance, slm phase, ,...) 
to optimize.\n", 354 | " # Split into (number of variables + 1) keys; key0 is returned and becomes the new key in the next step\n", 355 | " key0, key1, key2, key3, key4 = random.split(key, num_physical_variables+1)\n", 356 | " d_type = 'int8'\n", 357 | "\n", 358 | " # NS is read as a global (not passed as an input) so that this function can be jit-compiled and vmapped during optimization.\n", 359 | " level = NS['noise_level']\n", 360 | " discretize = NS['discretize']\n", 361 | " \n", 362 | " # level can be: 'low', 'mild', 'high', 'all' or 'tunable'\n", 363 | "\n", 364 | " if level == 'low':\n", 365 | " # Misalignment (um)\n", 366 | " minval_d = 10*um # 0.01 mm\n", 367 | " maxval_d = 50*um # 0.05 mm\n", 368 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n", 369 | " minval_phase = 0.01 \n", 370 | " maxval_phase = 0.05\n", 371 | "\n", 372 | " if level == 'mild':\n", 373 | " # Misalignment (um)\n", 374 | " minval_d = 50*um # 0.05 mm\n", 375 | " maxval_d = 500*um # 0.5 mm\n", 376 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n", 377 | " minval_phase = 0.05 \n", 378 | " maxval_phase = 0.5\n", 379 | "\n", 380 | " if level == 'high':\n", 381 | " # Misalignment (um)\n", 382 | " minval_d = 500*um # 0.5 mm\n", 383 | " maxval_d = 1000*um # 1 mm\n", 384 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n", 385 | " minval_phase = 0.5\n", 386 | " maxval_phase = 1\n", 387 | "\n", 388 | " if level == 'all':\n", 389 | " # Misalignment (um)\n", 390 | " minval_d = 10*um # 0.01 mm\n", 391 | " maxval_d = 1000*um # 1 mm\n", 392 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n", 393 | " minval_phase = 0.01 \n", 394 | " maxval_phase = 1\n", 395 | "\n", 396 | " if level == 'tunable':\n", 397 | " # Misalignment (um)\n", 398 | " minval_d, maxval_d = NS['misalignment'] # in um\n", 399 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n", 400 | " minval_phase, maxval_phase = NS['phase_noise']\n", 401 | "\n", 402 | " if discretize: \n", 403 | " d_type = 'uint8'\n", 404 | "\n", 405 | " # noise
for distances: shape = (NS['number of distances'],)\n", 406 | " random_noise_distances = jnp.squeeze(random.uniform(key1, shape=(1, NS['number of distances']), minval=minval_d, maxval=maxval_d), axis=0) \n", 407 | " # noise for SLM phases and amplitudes: shape = (2*NS['number of sSLM'], resolution, resolution)\n", 408 | " random_noise_amps = random.choice(key4, jnp.array([-1,1]), shape=(2*NS['number of sSLM'], resolution, resolution)).astype(d_type) * random.uniform(random.fold_in(key4, 1), shape=(2*NS['number of sSLM'], resolution, resolution), minval=minval_phase, maxval=maxval_phase) # fresh subkey: magnitudes independent of the phase noise drawn with key2 \n", 409 | " random_noise_slms = random.choice(key2, jnp.array([-1,1]), shape=(2*NS['number of sSLM'], resolution, resolution)).astype(d_type) * random.uniform(key2, shape=(2*NS['number of sSLM'], resolution, resolution), minval=minval_phase, maxval=maxval_phase) \n", 410 | " # noise for WP angles (eta and theta): shape = (2*NS['number of wps'],)\n", 411 | " random_noise_wps = jnp.squeeze(random.choice(key3, jnp.array([-1,1]), shape=(1, 2*NS['number of wps'])).astype('int8'), axis=0) * jnp.squeeze(random.uniform(key3, shape=(1, 2*NS['number of wps']), minval=minval_phase, maxval=maxval_phase), axis=0) \n", 412 | " \n", 413 | " return random_noise_distances, random_noise_slms, random_noise_wps, random_noise_amps, key0, key\n", 414 | "\n", 415 | "\n", 416 | "@jit\n", 417 | "def batch_shake_setup(key_array, array_for_shape):\n", 418 | " \"\"\"\n", 419 | " Creates noise for all the optimizable variables on multiple optical tables (one table per key in key_array).\n", 420 | " \n", 421 | " Parameters: \n", 422 | " key_array (PRNGKey): Array with different keys -- will change for each step in the optimization. \n", 423 | " The dimension of this array is decided by # of optical tables to compute in parallel.\n", 424 | " array_for_shape (jnp.array): array whose first dimension equals the resolution (e.g., the space vector x); only its shape is used, which keeps the function jit-compatible.
\n", 425 | " \n", 426 | " Returns:\n", 427 | " random_noise_distances [with shape = (size(key_array), NS['number of distances'])], \n", 428 | " random_noise_slms [with shape = (size(key_array), 2*NS['number of sSLM'], resolution, resolution)], \n", 429 | " random_noise_wps [with shape = (size(key_array), 2*NS['number of wps'])], \n", 430 | " random_noise_amps [with shape = (size(key_array), 2*NS['number of sSLM'], resolution, resolution)], \n", 431 | " key0 (PRNGKey): array of new keys to split in the next iteration step, and key (PRNGKey): the keys used in this step\n", 432 | " \"\"\"\n", 433 | " return vmap(shake_setup_jit, in_axes = (0, None))(key_array, jnp.shape(array_for_shape)[0])" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "## Define the optimizer (AdamW, optional learning-rate schedule)" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 14, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "def adamw_schedule(base_lr, end_lr, decay_steps, weight_decay) -> optax.GradientTransformation:\n", 450 | " \"\"\"\n", 451 | " Custom optimizer: AdamW (optax.adamw) with weight decay.\n", 452 | " The linear learning-rate schedule below is commented out,\n", 453 | " so a constant learning rate (base_lr) is used.\n", 454 | " \"\"\"\n", 455 | " lr_schedule = base_lr\n", 456 | " #lr_schedule = optax.linear_schedule(init_value= base_lr, end_value = end_lr, transition_steps = decay_steps, transition_begin = 500) \n", 457 | " return optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "## Optimization loop: " 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 15, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [ 473 | "def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations, keys, x) -> tuple:\n", 474 | " \n", 475 | " # Init the optimizer with initial parameters\n", 476 | " opt_state =
optimizer.init(params)\n", 477 | "\n", 478 | " @jit\n", 479 | " def update(parameters, opt_state, noise_d, noise_slm, noise_wp, noise_amp):\n", 480 | " # Single update step for the current noise realization: \n", 481 | " loss_value, grads = jax.value_and_grad(loss_function)(parameters, noise_d, noise_slm, noise_wp, noise_amp)\n", 482 | " # Update the state of the optimizer\n", 483 | " updates, state = optimizer.update(grads, opt_state, parameters)\n", 484 | " # Update the parameters\n", 485 | " new_params = optax.apply_updates(parameters, updates)\n", 486 | " \n", 487 | " return new_params, parameters, state, loss_value, updates\n", 488 | "\n", 489 | " # Initialize some parameters \n", 490 | " n_best = OS['n_best']\n", 491 | " best_loss = OS['best_loss']\n", 492 | " best_params = None\n", 493 | " best_keys = None\n", 494 | " best_step = 0\n", 495 | " \n", 496 | " print('Starting Optimization', flush=True)\n", 497 | " \n", 498 | " for step in range(num_iterations):\n", 499 | " \n", 500 | " # Add noise: update noise and keys each iteration\n", 501 | " noise_d, noise_slm, noise_wp, noise_amp, keys, old_keys = batch_shake_setup(keys, x) # 'x' is the space variable from optical table\n", 502 | " \n", 503 | " # Apply update step\n", 504 | " params, old_params, opt_state, loss_value, updates = update(params, opt_state, noise_d, noise_slm, noise_wp, noise_amp)\n", 505 | " \n", 506 | " print(f\"Step {step}\")\n", 507 | " print(f\"Loss {loss_value}\")\n", 508 | " \n", 509 | " # Update the `best_loss` value:\n", 510 | " if loss_value < best_loss:\n", 511 | " # Best loss value\n", 512 | " best_loss = loss_value\n", 513 | " # Best optimized parameters\n", 514 | " best_params = old_params\n", 515 | " # Keys for best params\n", 516 | " best_keys = old_keys\n", 517 | " best_step = step\n", 518 | " print('Best loss value is updated')\n", 519 | "\n", 520 | " if step % 100 == 0:\n", 521 | " # Stopping criterion: stop if best_loss has not improved for more than n_best steps (checked every 100 steps).\n", 522 | " if step -
best_step > n_best:\n", 523 | " print(f'Stopping criterion: no improvement in loss value for {n_best} steps')\n", 524 | " break\n", 525 | " \n", 526 | " print(f'Best loss: {best_loss} at step {best_step}')\n", 527 | " print(f'Best parameters: {best_params}') \n", 528 | " return best_params, best_loss, best_keys" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "# Optimizer settings\n", 538 | "num_iterations = OS['num_iterations']\n", 539 | "num_samples = OS['num_samples']\n", 540 | "\n", 541 | "for i in range(num_samples):\n", 542 | " tic = time.perf_counter()\n", 543 | " \n", 544 | " # seed1 to ensure randomness among samples\n", 545 | " seed1 = np.random.randint(9999)\n", 546 | " \n", 547 | " # Init keychain for noise -- as many init keys as optical tables to parallelize\n", 548 | " keys = keychain_optical_tables(seed1, NS['n_tables'])\n", 549 | " \n", 550 | " # Optimizer settings\n", 551 | " WEIGHT_DECAY = OS['WEIGHT_DECAY']\n", 552 | " BASE_lr = OS['BASE_lr']\n", 553 | " END_lr = OS['END_lr']\n", 554 | " DECAY_STEPS = OS['DECAY_STEPS']\n", 555 | " \n", 556 | " # Random init parameters:\n", 557 | " phase1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 558 | " phase1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 559 | " a1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 560 | " a1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 561 | " \n", 562 | " phase2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 563 | " phase2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 564 | " a2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 565 | " a2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 566 | " \n", 567 | " phase3_1 = jnp.array([np.random.uniform(0, 1, shape)], 
dtype=jnp.float64)[0]\n", 568 | " phase3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 569 | " a3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 570 | " a3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n", 571 | " \n", 572 | " eta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 573 | " theta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 574 | " eta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 575 | " theta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 576 | " eta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 577 | " theta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 578 | " eta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 579 | " theta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n", 580 | " \n", 581 | " z1_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 582 | " z1_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 583 | " z2_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 584 | " z2_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 585 | " z3_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 586 | " z3_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 587 | " z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 588 | " z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 589 | " \n", 590 | " bs1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 591 | " bs2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 592 | " bs3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 593 | " bs4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 594 | " bs5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 595 | " bs6 = jnp.array([np.random.uniform(0, 1)], 
dtype=jnp.float64)\n", 596 | " bs7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 597 | " bs8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 598 | " bs9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n", 599 | " \n", 600 | " # Init params for 3x3 robust discovery\n", 601 | " init_params = [phase1_1, phase1_2, a1_1, a1_2, eta1, theta1, z1_1, z1_2, \n", 602 | " phase2_1, phase2_2, a2_1, a2_2, eta2, theta2, z2_1, z2_2, \n", 603 | " phase3_1, phase3_2, a3_1, a3_2, eta3, theta3, z3_1, z3_2, \n", 604 | " bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, \n", 605 | " z4, z5]\n", 606 | " \n", 607 | " # Init optimizer:\n", 608 | " optimizer = adamw_schedule(BASE_lr, END_lr, DECAY_STEPS, WEIGHT_DECAY)\n", 609 | "\n", 610 | " # Apply fit function (returns best_params, best_loss, best_keys):\n", 611 | " best_params, best_loss, best_keys = fit(init_params, optimizer, num_iterations, keys, x)\n", 612 | " \n", 613 | " print(f\"Time taken to optimize one sample - in seconds {(time.perf_counter() - tic):.4f}\")" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": null, 619 | "metadata": {}, 620 | "outputs": [], 621 | "source": [] 622 | } 623 | ], 624 | "metadata": { 625 | "kernelspec": { 626 | "display_name": "dummy_env", 627 | "language": "python", 628 | "name": "python3" 629 | }, 630 | "language_info": { 631 | "codemirror_mode": { 632 | "name": "ipython", 633 | "version": 3 634 | }, 635 | "file_extension": ".py", 636 | "mimetype": "text/x-python", 637 | "name": "python", 638 | "nbconvert_exporter": "python", 639 | "pygments_lexer": "ipython3", 640 | "version": "3.11.10" 641 | } 642 | }, 643 | "nbformat": 4, 644 | "nbformat_minor": 2 645 | } 646 | -------------------------------------------------------------------------------- /examples/numerical_methods_evaluation_diffractio.py: -------------------------------------------------------------------------------- 1 | from diffractio import np, degrees, um, mm, nm 2 | from
diffractio.scalar_sources_XY import Scalar_source_XY 3 | from diffractio.scalar_fields_XY import Scalar_field_XY 4 | import time 5 | 6 | """ Evaluates the convergence and gradient evaluation times for Diffractio + numerical optimization """ 7 | 8 | # Light source settings 9 | wavelength = .6328 * um 10 | w0 = (1200*um , 1200*um) 11 | sensor_lateral_size = 10 # Resolution 12 | wavelength = 650*nm 13 | x_total = 1000*um 14 | x = np.linspace(-x_total , x_total , sensor_lateral_size) 15 | y = np.linspace(-x_total , x_total , sensor_lateral_size) 16 | gb = Scalar_source_XY(x, y, wavelength, info='Light source') 17 | gb.gauss_beam(r0=(0 * um, 0 * um), w0=w0, z0=(0,0), A=1, theta=0 * degrees, phi=0 * degrees) 18 | 19 | # Spiral phase for ground truth 20 | gb_gt = Scalar_field_XY(x, y, wavelength) 21 | phase_mask = np.arctan2(gb.Y,gb.X) 22 | gb_gt.u = gb.u * np.exp(1j * phase_mask) 23 | 24 | # Optical setup 25 | def setup(gb, parameters): 26 | gb_propagated = gb.RS(25000*mm) 27 | gb_modulated, _ = npSLM(gb_propagated.u, parameters, gb.x.shape[0]) 28 | return gb_modulated 29 | 30 | def mse_phase(input_light, target_light): 31 | return np.sum((np.angle(input_light) - np.angle(target_light.u)) ** 2) / sensor_lateral_size**2 32 | 33 | def loss(parameters_flat): 34 | parameters = parameters_flat.reshape(sensor_lateral_size,sensor_lateral_size) 35 | out = setup(gb, parameters) 36 | loss_val = mse_phase(out, gb_gt) 37 | return loss_val 38 | 39 | def phase(phase): 40 | return np.exp(1j * phase) 41 | 42 | def npSLM(input_field, phase_array, shape): 43 | slm = np.fromfunction(lambda i, j: phase(phase_array[i, j]), 44 | (shape, shape), dtype=int) 45 | light_out = input_field * slm # Multiplies element-wise 46 | return light_out, slm 47 | 48 | from scipy.optimize import minimize 49 | import time 50 | 51 | results = [] 52 | times = [] 53 | 54 | parameters = np.random.uniform(-np.pi, np.pi, (sensor_lateral_size, sensor_lateral_size)).flatten() 55 | tic = time.perf_counter() 56 | 57 | 
res = minimize(loss, parameters, method='BFGS', options={'disp': True}) 58 | 59 | time_to_conv = time.perf_counter() - tic 60 | times.append(time_to_conv) 61 | results.append(res) 62 | 63 | # The output (res) is stored so that the total time can be divided by the number of gradient evaluations ('njev') reported by BFGS. -------------------------------------------------------------------------------- /examples/scalar_diffractio.py: -------------------------------------------------------------------------------- 1 | from diffractio import np, degrees, um, mm 2 | from diffractio.scalar_sources_XY import Scalar_source_XY 3 | from diffractio.scalar_fields_XY import Scalar_field_XY 4 | import time 5 | 6 | """ Computes the running times for the scalar versions of the Rayleigh-Sommerfeld (RS) and chirp z-transform (CZT) algorithms using Diffractio """ 7 | 8 | # Light source settings 9 | wavelength = .6328 * um 10 | w0 = (1200*um , 1200*um) 11 | x = np.linspace(-15 * mm, 15 * mm, 2048) 12 | y = np.linspace(-15 * mm, 15 * mm, 2048) 13 | x_out = np.linspace(-15 * um, 15 * um, 2048) 14 | y_out = np.linspace(-15 * um, 15 * um, 2048) 15 | 16 | ls = Scalar_source_XY(x, y, wavelength, info='Light source') 17 | ls.gauss_beam(r0=(0 * um, 0 * um), w0=w0, z0=(0,0), A=1, theta=0 * degrees, phi=0 * degrees) 18 | 19 | 20 | time_RS_diffractio = [] 21 | time_CZT_diffractio = [] 22 | 23 | for i in range(101): 24 | tic = time.perf_counter() 25 | ls_propagated = ls.RS(z=5*mm) 26 | time_RS_diffractio.append(time.perf_counter() - tic) 27 | 28 | for i in range(101): 29 | tic = time.perf_counter() 30 | ls_propagated = ls.CZT(z=5*mm, xout=x_out, yout=y_out, verbose=False) 31 | time_CZT_diffractio.append(time.perf_counter() - tic) 32 | 33 | filename = f"Scalar_propagation_Diffractio.npy" 34 | np.save(filename, {"RS_Diffractio": time_RS_diffractio, "CZT_Diffractio": time_CZT_diffractio}) -------------------------------------------------------------------------------- /examples/scalar_xlumina.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # Setting the path for XLuminA modules: 4 | current_path = os.path.abspath(os.path.join('..')) 5 | module_path = os.path.join(current_path) 6 | if module_path not in sys.path: 7 | sys.path.append(module_path) 8 | 9 | from xlumina.__init__ import um, mm 10 | from xlumina.wave_optics import * 11 | from xlumina.vectorized_optics import * 12 | from xlumina.optical_elements import * 13 | from xlumina.loss_functions import * 14 | from xlumina.toolbox import space 15 | import jax.numpy as jnp 16 | import numpy as np 17 | import time 18 | 19 | """ Computes the running times for the scalar versions of the Rayleigh-Sommerfeld (RS) and chirp z-transform (CZT) algorithms using XLuminA """ 20 | 21 | # Light source settings 22 | resolution = 2048 23 | wavelength = .6328 * um 24 | w0 = (1200*um , 1200*um) 25 | x, y = space(15*mm, resolution) 26 | x_out, y_out = jnp.array(space(15*mm, resolution)) 27 | 28 | gb = LightSource(x, y, wavelength) 29 | gb.gaussian_beam(w0=w0, E0=1) 30 | 31 | # Rayleigh-Sommerfeld: 32 | tic = time.perf_counter() 33 | ls_propagated = gb.RS_propagation(z=5*mm) 34 | print("Time taken for 1st RS propagation - in seconds", time.perf_counter() - tic) 35 | 36 | time_RS_xlumina = [] 37 | time_CZT_xlumina = [] 38 | 39 | for i in range(101): 40 | tic = time.perf_counter() 41 | ls_propagated = gb.RS_propagation(z=5*mm) 42 | t = time.perf_counter() - tic 43 | time_RS_xlumina.append(t) 44 | 45 | 46 | # Chirp z-transform: 47 | tic = time.perf_counter() 48 | ls_propagated = gb.CZT(z=5*mm, xout=x_out, yout=y_out) 49 | print("Time taken for 1st CZT propagation - in seconds", time.perf_counter() - tic) 50 | 51 | for i in range(101): 52 | tic = time.perf_counter() 53 | ls_propagated = gb.CZT(z=5*mm, xout=x_out, yout=y_out) 54 | t = time.perf_counter() - tic 55 | time_CZT_xlumina.append(t) 56 | 57 | filename = f"Scalar_propagation_xlumina_GPU.npy" 58 | np.save(filename, {"RS_xlumina": time_RS_xlumina,
"CZT_xlumina": time_CZT_xlumina}) -------------------------------------------------------------------------------- /examples/vectorial_diffractio.py: -------------------------------------------------------------------------------- 1 | from diffractio import np, degrees, um, mm 2 | from diffractio.scalar_sources_XY import Scalar_source_XY 3 | from diffractio.scalar_fields_XY import Scalar_field_XY 4 | from diffractio.vector_sources_XY import Vector_source_XY 5 | from diffractio.vector_fields_XY import Vector_field_XY 6 | import time 7 | 8 | """ Computes the running times for vectorial version of Rayleigh-Sommerfeld (VRS) and Chirped z-transform (VCZT) algorithms using Diffractio """ 9 | 10 | # Light source settings 11 | wavelength = .6328 * um 12 | w0 = (1200*um , 1200*um) 13 | x = np.linspace(-15 * mm, 15 * um, 2048) 14 | y = np.linspace(-15 * um, 15 * um, 2048) 15 | x_out = np.linspace(-15 * um, 15 * um, 2048) 16 | y_out = np.linspace(-15 * um, 15 * um, 2048) 17 | 18 | 19 | ls = Scalar_source_XY(x, y, wavelength, info='Light source') 20 | ls.gauss_beam(r0=(0 * um, 0 * um), w0=w0, z0=(0,0), A=1, theta=0 * degrees, phi=0 * degrees) 21 | 22 | vls = Vector_source_XY(x, y, wavelength=wavelength, info='Light source polarization') 23 | vls.constant_polarization(u=ls, v=(1, 0), has_normalization=False, radius=(15*mm, 15*mm)) 24 | 25 | time_VRS_diffractio = [] 26 | time_VCZT_diffractio = [] 27 | 28 | for i in range(101): 29 | tic = time.perf_counter() 30 | vls.VRS(z=5*mm, n=1, new_field=False, verbose=False, amplification=(1, 1)) 31 | time_VRS_diffractio.append(time.perf_counter() - tic) 32 | 33 | for i in range(101): 34 | tic = time.perf_counter() 35 | vls.CZT(z=5*mm, xout=x_out, yout=y_out, verbose=False) 36 | time_VCZT_diffractio.append(time.perf_counter() - tic) 37 | 38 | filename = f"vectorial_propagation_Diffractio.npy" 39 | np.save(filename, {"VRS_Diffractio": time_VRS_diffractio, "VCZT_Diffractio": time_VCZT_diffractio}) 
-------------------------------------------------------------------------------- /examples/vectorial_xlumina.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # Setting the path for XLuminA modules: 4 | current_path = os.path.abspath(os.path.join('..')) 5 | module_path = os.path.join(current_path) 6 | if module_path not in sys.path: 7 | sys.path.append(module_path) 8 | from xlumina.__init__ import um, mm 9 | from xlumina.wave_optics import * 10 | from xlumina.vectorized_optics import * 11 | from xlumina.optical_elements import * 12 | from xlumina.loss_functions import * 13 | from xlumina.toolbox import space 14 | import jax.numpy as jnp 15 | import numpy as np 16 | import time 17 | 18 | """ Computes the running times for the vectorial versions of the Rayleigh-Sommerfeld (VRS) and chirp z-transform (VCZT) algorithms using XLuminA """ 19 | 20 | # Light source settings 21 | resolution = 2048 22 | wavelength = .6328 * um 23 | w0 = (1200*um , 1200*um) 24 | x, y = space(15*mm, resolution) 25 | x_out, y_out = jnp.array(space(15*mm, resolution)) 26 | 27 | gb_lp = PolarizedLightSource(x, y, wavelength) 28 | gb_lp.gaussian_beam(w0=w0, jones_vector=(1, 0)) 29 | 30 | # Rayleigh-Sommerfeld: 31 | tic = time.perf_counter() 32 | gb_propagated = gb_lp.VRS_propagation(z=5*mm) 33 | print("Time taken for 1st VRS propagation - in seconds", time.perf_counter() - tic) 34 | 35 | time_VRS_xlumina = [] 36 | time_VCZT_xlumina = [] 37 | 38 | for i in range(101): 39 | tic = time.perf_counter() 40 | gb_propagated = gb_lp.VRS_propagation(z=5*mm) 41 | time_VRS_xlumina.append(time.perf_counter() - tic) 42 | 43 | # Chirp z-transform: 44 | tic = time.perf_counter() 45 | gb_propagated = gb_lp.VCZT(z=5*mm, xout=x_out, yout=y_out) 46 | print("Time taken for 1st VCZT propagation - in seconds", time.perf_counter() - tic) 47 | 48 | for i in range(101): 49 | tic = time.perf_counter() 50 | gb_propagated = gb_lp.VCZT(z=5*mm, xout=x_out, yout=y_out) 51 |
time_VCZT_xlumina.append(time.perf_counter() - tic) 52 | 53 | filename = f"vectorial_propagation_xlumina_GPU.npy" 54 | np.save(filename, {"VRS_xlumina": time_VRS_xlumina, "VCZT_xlumina": time_VCZT_xlumina}) -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/four_f_optical_table.py: -------------------------------------------------------------------------------- 1 | # Setting the path for XLuminA modules: 2 | import os 3 | import sys 4 | 5 | # Setting the path for XLuminA modules: 6 | current_path = os.path.abspath(os.path.join('..')) 7 | module_path = os.path.join(current_path) 8 | 9 | if module_path not in sys.path: 10 | sys.path.append(module_path) 11 | 12 | from xlumina.__init__ import um, nm, cm 13 | from xlumina.wave_optics import * 14 | from xlumina.optical_elements import SLM 15 | from xlumina.toolbox import space 16 | from jax import jit, vmap 17 | import jax.numpy as jnp 18 | 19 | """ 20 | OPTICAL TABLE WITH AN OPTICAL TELESCOPE (4F-SYSTEM). 21 | """ 22 | 23 | # 1. System specs: 24 | sensor_lateral_size = 1024 # Resolution 25 | wavelength = 632.8*nm 26 | x_total = 1500*um 27 | x, y = space(x_total, sensor_lateral_size) 28 | shape = jnp.shape(x)[0] 29 | 30 | # 2. Define the light source: 31 | w0 = (1200*um , 1200*um) 32 | input_light = LightSource(x, y, wavelength) 33 | input_light.gaussian_beam(w0=w0, E0=1) 34 | 35 | # 3. Define the optical functions: 36 | def batch_dualSLM_4f(input_mask, x, y, wavelength, parameters): 37 | """ 38 | [4f system coded exclusively for batch optimization purposes]. 39 | 40 | Define an optical table with a 4f system composed of 2 SLMs (to be used with ScalarLight).
41 | 42 | Illustrative scheme: 43 | U(x,y) --> input_mask --> Propagate: RS(z1) --> SLM(phase1) --> Propagate: RS(z2) --> SLM(phase2) --> Propagate: RS(z3) --> Detect 44 | 45 | Parameters: 46 | input_mask (jnp.array): Input mask, comes in the form of an array 47 | parameters (list): Parameters optimized by the optimizer [z1, z2, z3, phase1, phase2] for the three RS propagations and the two SLMs. 48 | 49 | Returns the intensity (jnp.array) after the final (third) propagation, and phase masks slm1 and slm2. 50 | 51 | + Parameters in the optimizer live in (0,1) and must be converted back to physical units [the offset is determined by .get_RS_minimum_z() for the corresponding pixel resolution]. 52 | Distance (in cm): z = |p| * 100 + offset, i.e. (0, 1) -> (offset, 100 + offset). 53 | Phase (in radians): phase = 2*pi*p - pi, i.e. (0, 1) -> (-pi, pi) 54 | """ 55 | global shape 56 | 57 | # From get_RS_minimum_z() 58 | offset = 1.2 59 | 60 | # Apply input mask (comes from vmap) 61 | input_light.field = input_light.field * input_mask 62 | 63 | """ Stage 0: Propagation """ 64 | # Propagate light from mask 65 | light_stage0, _ = input_light.RS_propagation(z=(jnp.abs(parameters[0]) * 100 + offset)*cm) 66 | 67 | """ Stage 0: Modulation """ 68 | # Feed SLM_1 with parameters[3] and apply the mask to the forward beam 69 | modulated_slm1, slm_1 = SLM(light_stage0, parameters[3] * (2*jnp.pi) - jnp.pi, shape) 70 | 71 | """ Stage 1: Propagation """ 72 | # Propagate the SLM_1 output beam to another distance z 73 | light_stage1, _ = modulated_slm1.RS_propagation(z=(jnp.abs(parameters[1]) * 100 + offset)*cm) 74 | 75 | """ Stage 1: Modulation """ 76 | # Feed SLM_2 with parameters[4] and apply the mask to the forward beam 77 | modulated_slm2, slm_2 = SLM(light_stage1, parameters[4] * (2*jnp.pi) - jnp.pi, shape) 78 | 79 | """ Stage 2: Propagation """ 80 | # Propagate the SLM_2 output beam to another distance z 81 | fw_to_detector, _ = modulated_slm2.RS_propagation(z=(jnp.abs(parameters[2]) * 100 + offset)*cm) 82 | 83 | return jnp.abs(fw_to_detector.field)**2, slm_1, slm_2 84 |
85 | def vector_dualSLM_4f_system(input_masks, x, y, wavelength, parameters): 86 | """ 87 | [Coded exclusively for batch optimization]. 88 | 89 | Vectorized (efficient) version of the 4f system for batch optimization. 90 | 91 | Parameters: 92 | input_masks (jnp.array): Array with input masks 93 | x, y, wavelength (jnp.arrays and float): Light specs to pass to batch_dualSLM_4f. 94 | parameters (list): Parameters optimized by the optimizer [z1, z2, z3, phase1, phase2] for the three RS propagations and the two SLMs. 95 | 96 | Returns the vectorized detected light (intensity). 97 | """ 98 | detected_intensity, _, _ = vmap(batch_dualSLM_4f, in_axes=(0, None, None, None, None))(input_masks, x, y, wavelength, parameters) 99 | return detected_intensity 100 | 101 | 102 | # 4. Define the loss function for batch optimization. 103 | def loss_dualSLM(parameters, input_masks, target_intensities): 104 | """ 105 | Loss function for 4f system batch optimization. It computes the MSE between the optimized and the target intensities. 106 | 107 | Parameters: 108 | parameters (list): Optimized parameters. 109 | input_masks (jnp.array): Array with input masks. 110 | target_intensities (jnp.array): Array with target intensities. 111 | 112 | Returns the mean value of the loss computed for all the inputs. 113 | """ 114 | global x, y, wavelength 115 | # Input masks and target intensities are arrays with synthetic data; x, y and wavelength are global variables defined in this optical table script. 116 | optimized_intensities = vector_dualSLM_4f_system(input_masks, x, y, wavelength, parameters) 117 | mean_loss, loss_array = mean_batch_MSE_Intensity(optimized_intensities, target_intensities) 118 | return mean_loss 119 | 120 | def mean_batch_MSE_Intensity(optimized, target): 121 | """ 122 | [Computed for batch optimization in 4f system]. Vectorized version of MSE_Intensity. 123 | 124 | Returns the mean value of all the MSE values over the (optimized, target) pairs and a jnp.array with the MSE value of each pair.
125 | """ 126 | MSE = vmap(MSE_Intensity, in_axes=(0, 0))(optimized, target) 127 | return jnp.mean(MSE), MSE 128 | 129 | @jit 130 | def MSE_Intensity(input_light, target_light): 131 | """ 132 | Computes the Mean Squared Error (in Intensity) for a given electric field component Ex, Ey or Ez. 133 | 134 | Parameters: 135 | input_light (array): intensity: input_light = jnp.abs(input_light.field)**2 136 | target_light (array): Ground truth - intensity in the detector: target_light = jnp.abs(target_light.field)**2 137 | 138 | Returns the MSE (jnp.array). 139 | """ 140 | num_pix = jnp.shape(input_light)[0] * jnp.shape(input_light)[1] 141 | return jnp.sum((input_light - target_light)** 2) / num_pix -------------------------------------------------------------------------------- /experiments/four_f_optimizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Setting the path for XLuminA modules: 5 | current_path = os.path.abspath(os.path.join('..')) 6 | module_path = os.path.join(current_path) 7 | 8 | if module_path not in sys.path: 9 | sys.path.append(module_path) 10 | 11 | from four_f_optical_table import * 12 | from xlumina.toolbox import MultiHDF5DataLoader 13 | import time 14 | import jax 15 | import optax 16 | from jax import jit 17 | import numpy as np 18 | import jax.numpy as jnp 19 | 20 | """ 21 | OPTIMIZER FOR THE OPTICAL TELESCOPE (4F-SYSTEM). 
22 | """ 23 | 24 | # Print device info (GPU or CPU) 25 | print(jax.devices(), flush=True) 26 | 27 | # Create the data loader and set the batch size 28 | dataloader = MultiHDF5DataLoader("training_data_4f", batch_size=10) 29 | 30 | # JIT-compile the loss function (its gradients are computed in update()): 31 | loss_function = jit(loss_dualSLM) 32 | 33 | # ---------------------------------------------------- 34 | 35 | def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations) -> optax.Params: 36 | 37 | opt_state = optimizer.init(params) 38 | 39 | @jit 40 | def update(params, opt_state, input_fields, target_fields): 41 | # Define single update step: 42 | # Compute the loss and its gradients with respect to params 43 | loss_value, grads = jax.value_and_grad(loss_function, allow_int=True)(params, input_fields, target_fields) 44 | # Update the state of the optimizer 45 | updates, opt_state = optimizer.update(grads, opt_state, params) 46 | new_params = optax.apply_updates(params, updates) 47 | return new_params, params, opt_state, loss_value 48 | 49 | # Initialize some parameters 50 | iteration_steps=[] 51 | loss_list=[] 52 | 53 | # Early-stopping settings 54 | n_best = 500 55 | best_loss = 1e2 56 | best_params = None 57 | best_step = 0 58 | 59 | print('Starting Optimization', flush=True) 60 | 61 | for step in range(num_iterations): 62 | # Load data: 63 | input_fields, target_fields = next(dataloader) 64 | params, old_params, opt_state, loss_value = update(params, opt_state, input_fields, target_fields) 65 | 66 | print(f"Step {step}") 67 | print(f"Loss {loss_value}") 68 | 69 | iteration_steps.append(step) 70 | loss_list.append(loss_value) 71 | 72 | # Update the `best_loss` value: 73 | if loss_value < best_loss: 74 | # Best loss value 75 | best_loss = loss_value 76 | # Best optimized parameters (the ones that produced loss_value, i.e. before this step's update) 77 | best_params = old_params 78 | best_step = step 79 | print('Best loss value is updated') 80 | 81 | if step % 100 == 0: 82 | # Stopping criterion (checked every 100 steps): stop if best_loss has not improved for more than n_best steps.
83 | if step - best_step > n_best: 84 | print(f'Stopping criterion: no improvement in loss value for {n_best} steps') 85 | break 86 | 87 | print(f'Best loss: {best_loss} at step {best_step}') 88 | print(f'Best parameters: {best_params}') 89 | return best_params, best_loss, iteration_steps, loss_list 90 | 91 | # ---------------------------------------------------- 92 | 93 | # Optimizer settings 94 | num_iterations = 50000 95 | num_samples = 1 96 | # Step size engineering: 97 | STEP_SIZE = 0.01 98 | WEIGHT_DECAY = 0.0001 99 | 100 | for i in range(num_samples): 101 | tic = time.perf_counter() 102 | 103 | # Init random parameters 104 | phase_mask_slm1 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0] 105 | phase_mask_slm2 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0] 106 | distance_0 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64) 107 | distance_1 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64) 108 | distance_2 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64) 109 | init_params = [distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2] 110 | 111 | # Init optimizer: 112 | optimizer = optax.adamw(STEP_SIZE, weight_decay=WEIGHT_DECAY) 113 | 114 | # Apply fit function: 115 | best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations) 116 | 117 | print("Time taken to optimize one sample - in seconds", time.perf_counter() - tic) -------------------------------------------------------------------------------- /experiments/generate_synthetic_data.py: -------------------------------------------------------------------------------- 1 | # Setting the path for XLuminA modules: 2 | import os 3 | import sys 4 | 5 | # Setting the path for XLuminA modules: 6 | current_path = os.path.abspath(os.path.join('..')) 7 | module_path = os.path.join(current_path) 8 | 9 | if module_path not in sys.path: 10 | sys.path.append(module_path) 11 | 12 | 
import numpy as np 13 | import jax.numpy as jnp 14 | from xlumina.__init__ import um, nm 15 | from xlumina.wave_optics import * 16 | from xlumina.optical_elements import * 17 | from xlumina.toolbox import space 18 | import h5py 19 | 20 | """ 21 | Generation of synthetic data batches: 4f system with 2x magnification. 22 | - input_masks = jnp.array(in1, in2, ...) 23 | - target_intensity = jnp.array(out1, out2, ...) 24 | """ 25 | 26 | # System characteristics: 27 | sensor_lateral_size = 1024 # Pixel resolution 28 | wavelength = 632.8*nm 29 | x_total = 1500*um 30 | x, y = space(x_total, sensor_lateral_size) 31 | X, Y = jnp.meshgrid(x,y) 32 | 33 | # Define the light source: 34 | w0 = (1200*um , 1200*um) 35 | gb = LightSource(x, y, wavelength) 36 | gb.plane_wave(A=0.5) 37 | 38 | def generate_synthetic_circles(gb, num_samples): 39 | in_circles = [] 40 | out_circles = [] 41 | for i in range(num_samples): 42 | r1 = jnp.array(np.random.uniform(100, 1000)) 43 | r2 = jnp.array(np.random.uniform(100, 1000)) 44 | 45 | # Store only the mask (binary) 46 | in_circle = circular_mask(X, Y, r=(r1, r2)) 47 | in_circles.append(in_circle) 48 | 49 | # Magnification is 2, store only the intensity 50 | out_circle = gb.apply_circular_mask(r=(2*r1, 2*r2)) 51 | out_circles.append(jnp.abs(out_circle.field)**2) 52 | return jnp.array(in_circles), jnp.array(out_circles) 53 | 54 | def generate_synthetic_squares(gb, num_samples): 55 | in_squares = [] 56 | out_squares = [] 57 | for i in range(num_samples): 58 | width = jnp.array(np.random.uniform(100, 1000)) 59 | height = jnp.array(np.random.uniform(100, 1000)) 60 | angle = jnp.array(np.random.uniform(0, 2*jnp.pi)) 61 | 62 | # Binary mask only 63 | in_square = rectangular_mask(X, Y, center=(0,0), width=width, height=height, angle=angle) 64 | in_squares.append(in_square) 65 | 66 | # Magnification is 2 - we only need intensity 67 | out_square = gb.apply_rectangular_mask(center=(0,0), width=2*width, height=2*height, angle=-angle) 68 |
out_squares.append(jnp.abs(out_square.field)**2) 69 | 70 | return jnp.array(in_squares), jnp.array(out_squares) 71 | 72 | def generate_synthetic_annular(gb, num_samples): 73 | in_annulars = [] 74 | out_annulars = [] 75 | for i in range(num_samples): 76 | di = jnp.array(np.random.uniform(100, 500)) 77 | do = jnp.array(np.random.uniform(550, 1000)) 78 | 79 | # Binary mask only: 80 | in_annular = annular_aperture(di, do, X, Y) 81 | in_annulars.append(in_annular) 82 | 83 | # Magnification is 2 - we only need intensity 84 | out_annular = gb.apply_annular_aperture(di=2*di, do=2*do) 85 | out_annulars.append(jnp.abs(out_annular.field)**2) 86 | 87 | return jnp.array(in_annulars), jnp.array(out_annulars) 88 | 89 | # Data generation loop: 90 | num_samples = 30 91 | 92 | for s in range(50): 93 | input_circles, target_circles = generate_synthetic_circles(gb, num_samples) 94 | input_squares, target_squares = generate_synthetic_squares(gb, num_samples) 95 | input_annular, target_annular = generate_synthetic_annular(gb, num_samples) 96 | 97 | input_fields = jnp.vstack([input_circles, input_squares, input_annular]) 98 | target_fields = jnp.vstack([target_circles, target_squares, target_annular]) 99 | 100 | filename = f"new_training_data_4f/new_synthetic_data_{s+150}.hdf5" 101 | 102 | with h5py.File(filename, 'w') as hdf: 103 | # Create datasets for your data 104 | hdf.create_dataset("Input fields", data=input_fields) 105 | hdf.create_dataset("Target fields", data=target_fields) -------------------------------------------------------------------------------- /experiments/hybrid_optimizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Setting the path for XLuminA modules: 5 | current_path = os.path.abspath(os.path.join('..')) 6 | module_path = os.path.join(current_path) 7 | 8 | if module_path not in sys.path: 9 | sys.path.append(module_path) 10 | 11 | import time 12 | import jax 13 | from jax import grad, 
jit 14 | import optax 15 | import numpy as np 16 | import jax.numpy as jnp 17 | import gc # Garbage collector 18 | 19 | # Use this for pure topological discovery (these optical-table modules live in experiments/): 20 | from six_times_six_ansatz_with_fixed_PM import * # <--- use this for 6x6 ansatz 21 | # from hybrid_with_fixed_PM import * # <--- use this for 3x3 with fixed masks 22 | 23 | # Use this for hybrid optimization: 24 | # from hybrid_sharp_optical_table import * # <--- use this for sharp focus 25 | # from hybrid_sted_optical_table import * # <--- use this for sted 26 | 27 | """ 28 | OPTIMIZER - LARGE-SCALE SETUPS 29 | """ 30 | 31 | # Print device info (GPU or CPU) 32 | print(jax.devices(), flush=True) 33 | 34 | # Global variable 35 | shape = jnp.array([sensor_lateral_size, sensor_lateral_size]) 36 | 37 | # Select the loss function and JIT-compile it (its gradients are computed in update()): 38 | # loss_function = jit(loss_hybrid_sharp_focus) # <--- use this for sharp focus 39 | # loss_function = jit(loss_hybrid_sted) # <--- use this for sted 40 | loss_function = jit(loss_hybrid_fixed_PM) # <--- use this for sharp focus with fixed phase masks 41 | 42 | # ---------------------------------------------------- 43 | 44 | def clip_adamw(learning_rate, weight_decay) -> optax.GradientTransformation: 45 | """ 46 | Custom optimizer - AdamW (Adam with decoupled weight decay). 47 | Despite its name, this function applies no gradient clipping: 48 | it simply returns optax.adamw(learning_rate, weight_decay). 49 | """ 50 | return optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay) 51 | 52 | def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations) -> optax.Params: 53 | 54 | # Init the optimizer with initial parameters 55 | opt_state = optimizer.init(params) 56 | 57 | @jit 58 | def update(parameters, opt_state): 59 | # Define single update step: 60 | loss_value, grads = jax.value_and_grad(loss_function)(parameters) 61 | 62 | # Update the state of the optimizer 63 | updates, state = optimizer.update(grads, opt_state, parameters) 64 | 65 | # Update the
parameters 66 | new_params = optax.apply_updates(parameters, updates) 67 | 68 | return new_params, parameters, state, loss_value, updates 69 | 70 | 71 | # Initialize some parameters 72 | iteration_steps=[] 73 | loss_list=[] 74 | 75 | n_best = 500 76 | best_loss = 3*1e2 77 | best_params = None 78 | best_step = 0 79 | 80 | print('Starting Optimization', flush=True) 81 | 82 | for step in range(num_iterations): 83 | 84 | params, old_params, opt_state, loss_value, updates = update(params, opt_state) 85 | 86 | print(f"Step {step}") 87 | print(f"Loss {loss_value}") 88 | iteration_steps.append(step) 89 | loss_list.append(loss_value) 90 | 91 | # Update the `best_loss` value: 92 | if loss_value < best_loss: 93 | # Best loss value 94 | best_loss = loss_value 95 | # Best optimized parameters (the ones that produced loss_value, i.e. before this step's update) 96 | best_params = old_params 97 | best_step = step 98 | print('Best loss value is updated') 99 | 100 | if step % 100 == 0: 101 | # Stopping criterion (checked every 100 steps): stop if best_loss has not improved for more than n_best steps. 102 | if step - best_step > n_best: 103 | print(f'Stopping criterion: no improvement in loss value for {n_best} steps') 104 | break 105 | 106 | print(f'Best loss: {best_loss} at step {best_step}') 107 | print(f'Best parameters: {best_params}') 108 | return best_params, best_loss, iteration_steps, loss_list 109 | 110 | # ---------------------------------------------------- 111 | 112 | # Optimizer settings 113 | num_iterations = 100000 114 | num_samples = 20 115 | 116 | for i in range(num_samples): 117 | 118 | STEP_SIZE = 0.05 119 | WEIGHT_DECAY = 0.0001 120 | 121 | gc.collect() 122 | tic = time.perf_counter() 123 | 124 | # Parameters: comment out the ones not used by the setup you want to optimize. 125 | # super-SLM phase masks: 126 | phase1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0] 127 | phase1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0] 128 | phase2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0] 129 | phase2_2
= jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0] 130 | phase3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0] 131 | phase3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0] 132 | 133 | # Wave plate variables: 134 | eta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 135 | theta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 136 | eta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 137 | theta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 138 | eta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 139 | theta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 140 | eta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 141 | theta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0] 142 | 143 | # Distances: 144 | z1_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 145 | z1_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 146 | z2_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 147 | z2_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 148 | z3_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 149 | z3_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 150 | z4_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 151 | z4_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 152 | z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 153 | z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 154 | z1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 155 | z2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 156 | z3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 157 | z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 158 | z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 159 | z6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64) 160 | z7 
= jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
z8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
z9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
z10 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
z11 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
z12 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)

# Beam splitter ratios
bs1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs10 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs11 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs12 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs13 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs14 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs15 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs16 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs17 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs18 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs19 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs20 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs21 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs22 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs23 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs24 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs25 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs26 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs27 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs28 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs29 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs30 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs31 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs32 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs33 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs34 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs35 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
bs36 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)

# Set which set of init parameters to use:
# REMEMBER TO COMMENT OUT (#) THE init_params OPTIONS YOU DON'T USE!

# 1. For 3x3 hybrid optimization (topology + optical parameters):
# init_params = [phase1_1, phase1_2, eta1, theta1, z1_1, z1_2, phase2_1, phase2_2, eta2, theta2, z2_1, z2_2, phase3_1, phase3_2, eta3, theta3, z3_1, z3_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, z4, z5]

# 2. Parameters for pure topological optimization on 3x3 systems with fixed phase masks at random positions:
# init_params = [z1_1, z1_2, z2_1, z2_2, z3_1, z3_2, z4_1, z4_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, eta1, theta1, eta2, theta2, eta3, theta3, eta4, theta4]

# 3. Parameters for pure topological optimization on the 6x6 system with fixed phase masks:
init_params = [z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12,
               bs1, bs2, bs3, bs4, bs5, bs6,
               bs7, bs8, bs9, bs10, bs11, bs12,
               bs13, bs14, bs15, bs16, bs17, bs18,
               bs19, bs20, bs21, bs22, bs23, bs24,
               bs25, bs26, bs27, bs28, bs29, bs30,
               bs31, bs32, bs33, bs34, bs35, bs36,
               eta1, theta1, eta2, theta2]

# Init optimizer:
optimizer = clip_adamw(STEP_SIZE, WEIGHT_DECAY)

# Apply fit function:
best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)
--------------------------------------------------------------------------------
/experiments/hybrid_sharp_optical_table.py:
--------------------------------------------------------------------------------
# Setting the path for XLuminA modules:
import os
import sys

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

from xlumina.__init__ import um, nm, cm, mm
from xlumina.vectorized_optics import *
from xlumina.optical_elements import hybrid_setup_sharp_focus
from xlumina.loss_functions import vectorized_loss_hybrid
from xlumina.toolbox import space, softmin
import jax.numpy as jnp

"""
Large-scale setup for Dorn, Quabis and Leuchs (2004) benchmark rediscovery:

3x3 initial setup - light gets detected across 6 detectors.
"""

# 1. System specs:
sensor_lateral_size = 1024  # Resolution
wavelength = 635*nm
x_total = 2500*um
x, y = space(x_total, sensor_lateral_size)
shape = jnp.shape(x)[0]

# 2. Define the optical functions: one diagonally polarized Gaussian beam:
w0 = (1200*um, 1200*um)
ls1 = PolarizedLightSource(x, y, wavelength)
ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))

# 3. Define the output (High Resolution) detection:
x_out, y_out = jnp.array(space(10*um, 400))

# 4. High NA objective lens specs:
NA = 0.9
radius_lens = 3.6*mm/2
f_lens = radius_lens / NA

# 5. Static parameters - don't change during optimization:
fixed_params = [radius_lens, f_lens, x_out, y_out]

# 6. Define the loss function:
@jit
def loss_hybrid_sharp_focus(parameters):
    # Output from hybrid_setup_sharp_focus is jnp.array(6, N, N): for 6 detectors
    detected_z_intensities, _ = hybrid_setup_sharp_focus(ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params)

    # Get the minimum value within loss value array of shape (6, 1, 1)
    loss_val = softmin(vectorized_loss_hybrid(detected_z_intensities))

    return loss_val
--------------------------------------------------------------------------------
/experiments/hybrid_sted_optical_table.py:
--------------------------------------------------------------------------------
# Setting the path for XLuminA modules:
import os
import sys

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

from xlumina.__init__ import um, nm, cm, mm
from xlumina.vectorized_optics import *
from xlumina.optical_elements import hybrid_setup_fluorophores
from xlumina.loss_functions import vectorized_loss_hybrid
from xlumina.toolbox import space, softmin
import jax.numpy as jnp

"""
Large-scale setup for STED microscopy baseline rediscovery:

3x3 initial setup - light gets detected across 6 detectors.
"""

# 1. System specs:
sensor_lateral_size = 824  # Resolution
wavelength1 = 650*nm
wavelength2 = 532*nm
x_total = 2500*um
x, y = space(x_total, sensor_lateral_size)
shape = jnp.shape(x)[0]

# 2. Define the optical functions: two diagonally polarized Gaussian beams at different wavelengths:
w0 = (1200*um, 1200*um)
ls1 = PolarizedLightSource(x, y, wavelength1)
ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))
ls2 = PolarizedLightSource(x, y, wavelength2)
ls2.gaussian_beam(w0=w0, jones_vector=(1, 1))

# 3. Define the output (High Resolution) detection:
x_out, y_out = jnp.array(space(10*um, 400))

# 4. High NA objective lens specs:
NA = 0.9
radius_lens = 3.6*mm/2
f_lens = radius_lens / NA

# 5. Static parameters - don't change during optimization:
fixed_params = [radius_lens, f_lens, x_out, y_out]

# 6. Define the loss function:
@jit
def loss_hybrid_sted(parameters):
    # Output from hybrid_setup_fluorophores is jnp.array(6, N, N): for 6 detectors
    i_effective = hybrid_setup_fluorophores(ls1, ls2, ls1, ls2, ls1, ls2, parameters, fixed_params, distance_offset=10)

    # Get the minimum value within loss value array of shape (6, 1, 1)
    loss_val = softmin(vectorized_loss_hybrid(i_effective))

    return loss_val
--------------------------------------------------------------------------------
/experiments/hybrid_with_fixed_PM.py:
--------------------------------------------------------------------------------
# Setting the path for XLuminA modules:
import os
import sys

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

from xlumina.__init__ import um, nm, mm
from xlumina.vectorized_optics import *
from xlumina.optical_elements import hybrid_setup_fixed_slms_fluorophores, hybrid_setup_fixed_slms
from xlumina.loss_functions import vectorized_loss_hybrid
from xlumina.toolbox import space, softmin
import jax.numpy as jnp

"""
Large-scale setup using fixed phase masks in random positions:

3x3 initial setup - light gets detected across 6 detectors.

This script is valid for rediscovering

(1) Dorn, Quabis and Leuchs (2004) - use hybrid_setup_fixed_slms() in the loss function,

(2) STED microscopy - use hybrid_setup_fixed_slms_fluorophores() in the loss function.
"""

# 1. System specs:
sensor_lateral_size = 824  # Resolution
wavelength_1 = 632.8*nm
wavelength_2 = 530*nm
x_total = 2500*um
x, y = space(x_total, sensor_lateral_size)
shape = jnp.shape(x)[0]

# 2. Define the optical functions: two diagonally polarized Gaussian beams at different wavelengths:
w0 = (1200*um, 1200*um)
ls1 = PolarizedLightSource(x, y, wavelength_1)
ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))
ls2 = PolarizedLightSource(x, y, wavelength_2)
ls2.gaussian_beam(w0=w0, jones_vector=(1, 1))

# 3. Define the output (High Resolution) detection:
x_out, y_out = jnp.array(space(10*um, 400))
X, Y = jnp.meshgrid(x, y)

# 4. High NA objective lens specs:
NA = 0.9
radius_lens = 3.6*mm/2
f_lens = radius_lens / NA

# 4.1 Fixed phase masks:
# Polarization converter in Dorn, Quabis, Leuchs (2004):
pi_half = (jnp.pi - jnp.pi/2) * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
minus_pi_half = - jnp.pi/2 * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
PM1_1 = jnp.concatenate((jnp.concatenate((minus_pi_half, pi_half), axis=1), jnp.concatenate((minus_pi_half, pi_half), axis=1)), axis=0)
PM1_2 = jnp.concatenate((jnp.concatenate((minus_pi_half, minus_pi_half), axis=1), jnp.concatenate((pi_half, pi_half), axis=1)), axis=0)
# Spiral phase (STED microscopy)
PM2 = jnp.arctan2(Y, X)
# Forked grating
PM3 = jnp.cos(2 * PM2 - 2 * jnp.pi * X/1000) * jnp.pi
# Linear grating
PM4_1 = jnp.sin(2*jnp.pi * Y/1000) * jnp.pi
PM4_2 = jnp.sin(2*jnp.pi * X/1000) * jnp.pi

# 5. Static parameters - don't change during optimization:
fixed_params = [radius_lens, f_lens, x_out, y_out, PM1_1, PM1_2, PM2, PM3, PM4_1, PM4_2]

# 6. Define the loss function:
def loss_hybrid_fixed_PM(parameters):
    # Output from hybrid_setup is jnp.array(6, N, N): for 6 detectors

    # Use (1) for Dorn, Quabis and Leuchs benchmark / Use (2) for STED microscopy benchmark

    # (1): uncomment both lines so the loss below receives the detected intensities
    # detected_z_intensities, _ = hybrid_setup_fixed_slms(ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params)
    # i_effective = detected_z_intensities

    # (2):
    i_effective = hybrid_setup_fixed_slms_fluorophores(ls1, ls2, ls1, ls2, ls1, ls2, parameters, fixed_params)

    # Get the minimum value within loss value array of shape (6, 1, 1)
    loss_val = softmin(vectorized_loss_hybrid(i_effective))

    return loss_val
--------------------------------------------------------------------------------
/experiments/six_times_six_ansatz_with_fixed_PM.py:
--------------------------------------------------------------------------------
# Setting the path for XLuminA modules:
import os
import sys

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

from xlumina.__init__ import um, nm, cm, mm
from xlumina.vectorized_optics import *
from xlumina.optical_elements import six_times_six_ansatz
from xlumina.loss_functions import vectorized_loss_hybrid
from xlumina.toolbox import space, softmin
import jax.numpy as jnp

"""
Pure topological discovery within 6x6 ansatz for Dorn, Quabis and Leuchs (2004)
"""

# 1. System specs:
sensor_lateral_size = 824  # Resolution
wavelength_1 = 635.0*nm
x_total = 2500*um
x, y = space(x_total, sensor_lateral_size)
shape = jnp.shape(x)[0]

# 2. Define the optical functions: one anti-diagonally polarized Gaussian beam:
w0 = (1200*um, 1200*um)
ls1 = PolarizedLightSource(x, y, wavelength_1)
ls1.gaussian_beam(w0=w0, jones_vector=(1, -1))

# 3. Define the output (High Resolution) detection:
x_out, y_out = jnp.array(space(10*um, 400))
X, Y = jnp.meshgrid(x, y)

# 4. High NA objective lens specs:
NA = 0.9
radius_lens = 3.6*mm/2
f_lens = radius_lens / NA

# 4.1 Fixed phase masks:
# Polarization converter in Dorn, Quabis, Leuchs (2004):
pi_half = (jnp.pi - jnp.pi/2) * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
minus_pi_half = - jnp.pi/2 * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
PM1_1 = jnp.concatenate((jnp.concatenate((minus_pi_half, pi_half), axis=1), jnp.concatenate((minus_pi_half, pi_half), axis=1)), axis=0)
PM1_2 = jnp.concatenate((jnp.concatenate((minus_pi_half, minus_pi_half), axis=1), jnp.concatenate((pi_half, pi_half), axis=1)), axis=0)

# Linear grating
PM2_1 = jnp.sin(2*jnp.pi * Y/1000) * jnp.pi
PM2_2 = jnp.sin(2*jnp.pi * X/1000) * jnp.pi

# 5. Static parameters - don't change during optimization:
fixed_params = [radius_lens, f_lens, x_out, y_out, PM1_1, PM1_2, PM2_1, PM2_2]

# 6. Define the loss function:
def loss_hybrid_fixed_PM(parameters):
    # Output from six_times_six_ansatz is jnp.array(12, N, N): for 12 detectors
    i_effective = six_times_six_ansatz(ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params, distance_offset=9.5)
    # Get the minimum value within loss value array of shape (12, 1, 1)
    loss_val = softmin(vectorized_loss_hybrid(i_effective))
    return loss_val
--------------------------------------------------------------------------------
/miscellaneous/noise-aware.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/noise-aware.png
--------------------------------------------------------------------------------
/miscellaneous/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/performance.png
--------------------------------------------------------------------------------
/miscellaneous/performance_convergence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/performance_convergence.png
--------------------------------------------------------------------------------
/miscellaneous/propagation_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/propagation_comparison.png
--------------------------------------------------------------------------------
/miscellaneous/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/workflow.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from os.path import exists
from pathlib import Path
import re

from setuptools import setup, find_packages

author = 'artificial-scientist-lab'
email = 'carla.rodriguez@mpl.mpg.de, soeren.arlt@mpl.mpg.de, mario.krenn@mpl.mpg.de'
description = 'XLuminA: An Auto-differentiating Discovery Framework for Super-Resolution Microscopy'
dist_name = 'xlumina'
package_name = 'xlumina'
year = '2023'
url = 'https://github.com/artificial-scientist-lab/XLuminA'

setup(
    name=dist_name,
    author=author,
    author_email=email,
    url=url,
    version="1.0.0",
    packages=find_packages(),
    package_dir={dist_name: package_name},
    include_package_data=True,
    license='MIT',
    description=description,
    long_description=Path('README.md').read_text() if Path('README.md').exists() else '',
    long_description_content_type="text/markdown",
    install_requires=[
        'jax==0.4.33',
        'numpy',
        'optax==0.2.3',
        'scipy==1.14.1',
        'matplotlib'
    ],
    python_requires=">=3.10",
    classifiers=[
        'Operating System :: OS Independent',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
    ],
    platforms=['ALL'],
    py_modules=[package_name],
)
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/tests/__init__.py
--------------------------------------------------------------------------------
/tests/pytest.ini:
--------------------------------------------------------------------------------
[pytest]
filterwarnings =
    ignore:.*invalid escape sequence.*:DeprecationWarning
--------------------------------------------------------------------------------
/tests/test_optical_elements.py:
--------------------------------------------------------------------------------
# Test for optical elements
import os
import sys

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

import unittest
import jax.numpy as jnp
import numpy as np
from xlumina.optical_elements import (
    SLM, jones_LP, jones_general_retarder, jones_sSLM, jones_LCD,
    sSLM, LCD, linear_polarizer, BS_symmetric, high_NA_objective_lens,
    VCZT_objective_lens, lens, cylindrical_lens, axicon_lens, building_block
)

from xlumina.vectorized_optics import VectorizedLight, PolarizedLightSource
from xlumina.wave_optics import LightSource

class TestOpticalElements(unittest.TestCase):
    def setUp(self):
        self.wavelength = 633e-3  # um (i.e. 633 nm)
        self.resolution = 512
        self.x = np.linspace(-1500, 1500, self.resolution)
        self.y = np.linspace(-1500, 1500, self.resolution)
        self.k = 2 * jnp.pi / self.wavelength

    def test_slm(self):
        light = LightSource(self.x, self.y, self.wavelength)
        light.gaussian_beam(w0=(1200, 1200), E0=1)
        phase = jnp.zeros((self.resolution, self.resolution))
        slm_output, _ = SLM(light, phase, self.resolution)
        self.assertEqual(slm_output.field.shape, (self.resolution, self.resolution))  # Check output shape == input shape
        self.assertTrue(jnp.allclose(slm_output.field, light.field))  # Phase added by SLM is 0, field shouldn't change.

    def test_shape_jones_matrices(self):
        lp = jones_LP(jnp.pi/4)
        self.assertEqual(lp.shape, (2, 2))

        retarder = jones_general_retarder(jnp.pi/2, jnp.pi/4, 0)
        self.assertEqual(retarder.shape, (2, 2))

        sslm = jones_sSLM(jnp.pi/2, jnp.pi/4)
        self.assertEqual(sslm.shape, (2, 2))

        lcd = jones_LCD(jnp.pi/2, jnp.pi/4)
        self.assertEqual(lcd.shape, (2, 2))

    def test_polarization_devices(self):
        light = PolarizedLightSource(self.x, self.y, self.wavelength)
        light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1))

        # super-SLM with zero phase -- input SoP is diagonal
        alpha = jnp.zeros((self.resolution, self.resolution))
        phi = jnp.zeros((self.resolution, self.resolution))
        sslm_output = sSLM(light, alpha, phi)
        self.assertEqual(sslm_output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(sslm_output.Ey.shape, (self.resolution, self.resolution))
        self.assertEqual(sslm_output.Ez.shape, (self.resolution, self.resolution))
        self.assertTrue(jnp.allclose(sslm_output.Ex, light.Ex))
        self.assertTrue(jnp.allclose(sslm_output.Ey, light.Ey))
        self.assertTrue(jnp.allclose(sslm_output.Ez, light.Ez))

        # super-SLM with pi phase in Ex and Ey -- input SoP is diagonal
        alpha = jnp.pi * jnp.ones((self.resolution, self.resolution))
        phi = jnp.pi * jnp.ones((self.resolution, self.resolution))
        sslm_output = sSLM(light, alpha, phi)
        self.assertTrue(jnp.allclose(sslm_output.Ex, light.Ex * jnp.exp(1j * jnp.pi)))
        self.assertTrue(jnp.allclose(sslm_output.Ey, light.Ey * jnp.exp(1j * jnp.pi)))
        self.assertTrue(jnp.allclose(sslm_output.Ez, light.Ez))

        # LCD
        light = PolarizedLightSource(self.x, self.y, self.wavelength)
        light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
        lcd_output = LCD(light, 0, 0)
        self.assertEqual(lcd_output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(lcd_output.Ey.shape, (self.resolution, self.resolution))
        self.assertEqual(lcd_output.Ez.shape, (self.resolution, self.resolution))
        self.assertTrue(jnp.allclose(lcd_output.Ex, light.Ex))
        self.assertTrue(jnp.allclose(lcd_output.Ey, light.Ey))
        self.assertTrue(jnp.allclose(lcd_output.Ez, light.Ez))

        # LP aligned with incident SoP
        empty = jnp.zeros((self.resolution, self.resolution))
        lp_output = linear_polarizer(light, empty)
        self.assertEqual(lp_output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(lp_output.Ey.shape, (self.resolution, self.resolution))
        self.assertEqual(lp_output.Ez.shape, (self.resolution, self.resolution))
        self.assertTrue(jnp.allclose(lp_output.Ex, light.Ex))
        self.assertTrue(jnp.allclose(lp_output.Ey, light.Ey))
        self.assertTrue(jnp.allclose(lp_output.Ez, light.Ez))

        # LP crossed to input SoP
        pi_half = jnp.pi/2 * jnp.ones_like(empty)
        lp_output = linear_polarizer(light, pi_half)
        self.assertTrue(jnp.allclose(lp_output.Ex, empty))
        self.assertTrue(jnp.allclose(lp_output.Ey, empty))

    def test_beam_splitter(self):
        light1 = PolarizedLightSource(self.x, self.y, self.wavelength)
        light1.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
        light2 = PolarizedLightSource(self.x, self.y, self.wavelength)
        light2.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))

        c, d = BS_symmetric(light1, light2, 0)  # fully transmissive
        # Transmission adds a pi/2 phase: jnp.exp(1j * jnp.pi/2) = 1j
        # Noise = T*0.01
        T = jnp.abs(jnp.cos(0))
        R = jnp.abs(jnp.sin(0))
        noise = 0.01
        self.assertTrue(jnp.allclose(c.Ex, (T - noise) * 1j * light2.Ex + (R - noise) * light1.Ex))
        self.assertTrue(jnp.allclose(c.Ey, (T - noise) * 1j * light2.Ey + (R - noise) * light1.Ey))
        self.assertTrue(jnp.allclose(d.Ex, (T - noise) * 1j * light1.Ex + (R - noise) * light2.Ex))
        self.assertTrue(jnp.allclose(d.Ey, (T - noise) * 1j * light1.Ey + (R - noise) * light2.Ey))

    def test_high_na_objective_lens(self):
        radius_lens = 3.6*1e3/2  # 1.8 mm, expressed in um
        f_lens = radius_lens / 0.9
        light = PolarizedLightSource(self.x, self.y, self.wavelength)
        light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
        output, _ = high_NA_objective_lens(light, radius_lens, f_lens)
        self.assertEqual(output.shape, (3, self.resolution, self.resolution))

    def test_vczt_objective_lens(self):
        radius_lens = 3.6*1e3/2  # 1.8 mm, expressed in um
        f_lens = radius_lens / 0.9
        light = PolarizedLightSource(self.x, self.y, self.wavelength)
        light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
        output = VCZT_objective_lens(light, radius_lens, f_lens, self.x, self.y)
        self.assertEqual(output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(output.Ey.shape, (self.resolution, self.resolution))
        self.assertEqual(output.Ez.shape, (self.resolution, self.resolution))

    def test_lenses_scalar(self):
        light = LightSource(self.x, self.y, self.wavelength)
        light.gaussian_beam(w0=(1200, 1200), E0=1)
        lens_output, _ = lens(light, (50, 50), (1000, 1000))
        self.assertEqual(lens_output.field.shape, (self.resolution, self.resolution))
        cyl_lens_output, _ = cylindrical_lens(light, 1000)
        self.assertEqual(cyl_lens_output.field.shape, (self.resolution, self.resolution))
        axicon_output, _ = axicon_lens(light, 0.1)
        self.assertEqual(axicon_output.field.shape, (self.resolution, self.resolution))

    def test_lenses_vectorial(self):
        ls = PolarizedLightSource(self.x, self.y, self.wavelength)
        ls.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
        light = VectorizedLight(self.x, self.y, self.wavelength)
        light.Ex = ls.Ex
        light.Ey = ls.Ey
        light.Ez = ls.Ez
        lens_output, _ = lens(light, (50, 50), (1000, 1000))
        self.assertEqual(lens_output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(lens_output.Ey.shape, (self.resolution, self.resolution))
        cyl_lens_output, _ = cylindrical_lens(light, 1000)
        self.assertEqual(cyl_lens_output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(cyl_lens_output.Ey.shape, (self.resolution, self.resolution))
        axicon_output, _ = axicon_lens(light, 0.1)
        self.assertEqual(axicon_output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(axicon_output.Ey.shape, (self.resolution, self.resolution))

    def test_building_block(self):
        light = PolarizedLightSource(self.x, self.y, self.wavelength)
        light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
        output = building_block(light, jnp.zeros((self.resolution, self.resolution)), jnp.zeros((self.resolution, self.resolution)), 1000, jnp.pi/2, jnp.pi/4)
        self.assertEqual(output.Ex.shape, (self.resolution, self.resolution))
        self.assertEqual(output.Ey.shape, (self.resolution, self.resolution))
        self.assertEqual(output.Ez.shape, (self.resolution, self.resolution))

if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/tests/test_toolbox.py:
--------------------------------------------------------------------------------
# Test for toolbox module
import os
import sys

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

import unittest
import jax.numpy as jnp
import numpy as np
from jax import random
from xlumina.toolbox import (
    space, wrap_phase, is_conserving_energy, softmin, delta_kronecker,
    build_LCD_cell, rotate_mask, nearest,
    extract_profile, gaussian, lorentzian, fwhm_1d_fit, spot_size,
    compute_fwhm, find_max_min, gaussian_2d
)
from xlumina.vectorized_optics import VectorizedLight, PolarizedLightSource

class TestToolbox(unittest.TestCase):
    def setUp(self):
        seed = 9999
        self.key = random.PRNGKey(seed)
        self.resolution = 512
        self.x = np.linspace(-1500, 1500, self.resolution)
        self.y = np.linspace(-1500, 1500, self.resolution)
        self.wavelength = 633e-3  # um (i.e. 633 nm)

    def test_space(self):
        x, y = space(1500, self.resolution)
        self.assertTrue(jnp.allclose(x, self.x))
        self.assertTrue(jnp.allclose(y, self.y))

    def test_wrap_phase(self):
        phase = jnp.array([0, jnp.pi, 2*jnp.pi, 3*jnp.pi, -3*jnp.pi])
        wrapped = wrap_phase(phase)
        self.assertTrue(jnp.allclose(wrapped, jnp.array([0, jnp.pi, 0, jnp.pi, -jnp.pi])))

    def test_is_conserving_energy(self):
        light1 = VectorizedLight(self.x, self.y, self.wavelength)
        light2 = VectorizedLight(self.x, self.y, self.wavelength)
        light1.Ex = jnp.ones((self.resolution, self.resolution))
        light2.Ex = jnp.ones((self.resolution, self.resolution))
        conservation = is_conserving_energy(light1, light2)
        self.assertAlmostEqual(conservation, 1.0, places=6)
        light2.Ex = 0*light2.Ex
        conservation = is_conserving_energy(light1, light2)
        self.assertEqual(conservation, 0)

    def test_softmin(self):
        result = softmin(jnp.array([1.0, 2.0, 3.0]))
        self.assertTrue(result == 1.0)

    def test_delta_kronecker(self):
        self.assertEqual(delta_kronecker(1, 1), 1)
        self.assertEqual(delta_kronecker(1, 2), 0)

    def test_build_LCD_cell(self):
        eta, theta = build_LCD_cell(jnp.pi/2, jnp.pi/4, self.resolution)
        self.assertTrue(jnp.allclose(eta, jnp.pi/2 * jnp.ones((self.resolution, self.resolution))))
        self.assertTrue(jnp.allclose(theta, jnp.pi/4 * jnp.ones((self.resolution, self.resolution))))

    def test_rotate_mask(self):
        X, Y = jnp.meshgrid(self.x, self.y)
        Xrot, Yrot = rotate_mask(X, Y, jnp.pi/4)
        self.assertEqual(Xrot.shape, (self.resolution, self.resolution))
        self.assertEqual(Yrot.shape, (self.resolution, self.resolution))

    def test_nearest(self):
        array = jnp.array([1, 2, 3, 4, 5])
        idx, value, distance = nearest(array, 3.7)
        self.assertEqual(idx, 3)
        self.assertEqual(value, 4)
        self.assertAlmostEqual(distance, 0.3, places=6)

    def test_extract_profile(self):
        data_2d = jnp.ones((10, 10))
        x_points = jnp.array([0, 1, 2])
        y_points = jnp.array([0, 1, 2])
        profile = extract_profile(data_2d, x_points, y_points)
        self.assertEqual(profile.shape, x_points.shape)

    def test_gaussian(self):
        y = gaussian(self.x, 1, 0, 1)
        self.assertEqual(y.shape, self.x.shape)

    def test_lorentzian(self):
        y = lorentzian(self.x, 0, 1)
        self.assertEqual(y.shape, self.x.shape)

    def test_fwhm_1d_fit(self):
        sigma = 120
        amplitude = 1000
        mean = 0
        y = gaussian(self.x, amplitude, mean, sigma)
        _, fwhm, _ = fwhm_1d_fit(self.x, y, fit='gaussian')
        fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2))  # 2*sigma*sqrt(2*ln2) is the theoretical FWHM for a gaussian
        self.assertAlmostEqual(fwhm, fwhm_theoretical, places=2)

    def test_spot_size(self):
        size = spot_size(1, 1, self.wavelength)
        self.assertGreater(size, 0)

    def test_compute_fwhm(self):
        sigma = 120
        light_1d = gaussian(self.x, 1000, 0, sigma)
        XY = jnp.meshgrid(self.x, self.y)
        light_2d = gaussian_2d(XY, 1000, 0, 0, sigma, sigma)

        popt, fwhm, r_squared = compute_fwhm(light_1d, [self.x, self.y], field='Intensity', fit='gaussian', dimension='1D')
        fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2))
        self.assertAlmostEqual(fwhm, fwhm_theoretical, places=4)
117 | 118 | popt, fwhm, r_squared = compute_fwhm(light_2d, [self.x, self.y], field='Intensity', fit = 'gaussian', dimension='2D') 119 | fwhm_x, fwhm_y = fwhm 120 | fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2)) 121 | self.assertAlmostEqual(fwhm_x, fwhm_theoretical, places=4) 122 | self.assertAlmostEqual(fwhm_y, fwhm_theoretical, places=4) 123 | 124 | def test_find_max_min(self): 125 | value = jnp.array([[1, 2], [3, 4]]) 126 | idx, xy, ext_value = find_max_min(value, self.x[:2], self.y[:2], kind='max') 127 | self.assertEqual(idx.shape, (1, 2)) 128 | self.assertEqual(xy.shape, (1, 2)) 129 | self.assertEqual(ext_value, 4) 130 | idx, xy, ext_value = find_max_min(value, self.x[:2], self.y[:2], kind='min') 131 | self.assertEqual(ext_value, 1) 132 | 133 | if __name__ == '__main__': 134 | unittest.main() -------------------------------------------------------------------------------- /tests/test_vectorized_optics.py: -------------------------------------------------------------------------------- 1 | # Test for vectorized optics module 2 | import os 3 | import sys 4 | 5 | # Setting the path for XLuminA modules: 6 | current_path = os.path.abspath(os.path.join('..')) 7 | module_path = os.path.join(current_path) 8 | 9 | if module_path not in sys.path: 10 | sys.path.append(module_path) 11 | 12 | import unittest 13 | import jax.numpy as jnp 14 | import numpy as np 15 | from xlumina.vectorized_optics import PolarizedLightSource, VectorizedLight 16 | 17 | class TestVectorizedOptics(unittest.TestCase): 18 | def setUp(self): 19 | self.wavelength = 633e-3 #nm 20 | self.resolution = 1024 21 | self.x = np.linspace(-1500, 1500, self.resolution) 22 | self.y = np.linspace(-1500, 1500, self.resolution) 23 | self.k = 2 * jnp.pi / self.wavelength 24 | 25 | def test_vectorized_light(self): 26 | light = VectorizedLight(self.x, self.y, self.wavelength) 27 | self.assertEqual(light.wavelength, self.wavelength) 28 | self.assertEqual(light.k, self.k) 29 | self.assertEqual(light.Ex.shape, 
(self.resolution, self.resolution)) 30 | self.assertEqual(light.Ey.shape, (self.resolution, self.resolution)) 31 | self.assertEqual(light.Ez.shape, (self.resolution, self.resolution)) 32 | 33 | def test_polarized_light_source_horizontal(self): 34 | source = PolarizedLightSource(self.x, self.y, self.wavelength) 35 | source.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0)) 36 | # TEST POLARIZATION 37 | self.assertGreater(jnp.sum(jnp.abs(source.Ex)**2), 0) 38 | self.assertEqual(jnp.sum(jnp.abs(source.Ey)**2), 0) 39 | 40 | def test_polarized_light_source_vertical(self): 41 | source = PolarizedLightSource(self.x, self.y, self.wavelength) 42 | source.gaussian_beam(w0=(1200, 1200), jones_vector=(0, 1)) 43 | # TEST POLARIZATION 44 | self.assertGreater(jnp.sum(jnp.abs(source.Ey)**2), 0) 45 | self.assertEqual(jnp.sum(jnp.abs(source.Ex)**2), 0) 46 | 47 | def test_polarized_light_source_diagonal(self): 48 | source = PolarizedLightSource(self.x, self.y, self.wavelength) 49 | source.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1)) 50 | # TEST POLARIZATION 51 | self.assertGreater(jnp.sum(jnp.abs(source.Ex)**2), 0) 52 | self.assertGreater(jnp.sum(jnp.abs(source.Ey)**2), 0) 53 | 54 | def test_vrs_propagation(self): 55 | light = PolarizedLightSource(self.x, self.y, self.wavelength) 56 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1)) 57 | propagated, _ = light.VRS_propagation(z=1000) 58 | self.assertEqual(propagated.Ex.shape, (self.resolution, self.resolution)) 59 | self.assertEqual(propagated.Ey.shape, (self.resolution, self.resolution)) 60 | self.assertEqual(propagated.Ez.shape, (self.resolution, self.resolution)) 61 | # Check same SoP 62 | self.assertGreater(jnp.sum(jnp.abs(light.Ex)**2), 0) 63 | self.assertGreater(jnp.sum(jnp.abs(light.Ey)**2), 0) 64 | 65 | def test_vczt(self): 66 | light = PolarizedLightSource(self.x, self.y, self.wavelength) 67 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1)) 68 | propagated = light.VCZT(1000, self.x, self.y) 69 | 
self.assertEqual(propagated.Ex.shape, (self.resolution, self.resolution)) 70 | self.assertEqual(propagated.Ey.shape, (self.resolution, self.resolution)) 71 | self.assertEqual(propagated.Ez.shape, (self.resolution, self.resolution)) 72 | # Check the SoP is preserved: both components of the propagated field remain non-zero 73 | self.assertGreater(jnp.sum(jnp.abs(propagated.Ex)**2), 0) 74 | self.assertGreater(jnp.sum(jnp.abs(propagated.Ey)**2), 0) 75 | 76 | if __name__ == '__main__': 77 | unittest.main() 78 | -------------------------------------------------------------------------------- /tests/test_wave_optics.py: -------------------------------------------------------------------------------- 1 | # Test for wave optics module 2 | import os 3 | import sys 4 | 5 | # Setting the path for XLuminA modules: 6 | current_path = os.path.abspath(os.path.join('..')) 7 | module_path = os.path.join(current_path) 8 | 9 | if module_path not in sys.path: 10 | sys.path.append(module_path) 11 | 12 | import unittest 13 | import jax.numpy as jnp 14 | import numpy as np 15 | from xlumina.wave_optics import ScalarLight, LightSource 16 | 17 | class TestWaveOptics(unittest.TestCase): 18 | def setUp(self): 19 | self.wavelength = 633e-3 # in microns (um = 1), i.e., 633 nm 20 | self.resolution = 1024 21 | self.x = np.linspace(-1500, 1500, self.resolution) 22 | self.y = np.linspace(-1500, 1500, self.resolution) 23 | self.k = 2 * jnp.pi / self.wavelength 24 | 25 | def test_scalar_light(self): 26 | light = ScalarLight(self.x, self.y, self.wavelength) 27 | self.assertEqual(light.wavelength, self.wavelength) 28 | self.assertEqual(light.k, self.k) 29 | self.assertEqual(light.field.shape, (self.resolution, self.resolution)) 30 | 31 | def test_light_source_gb(self): 32 | source = LightSource(self.x, self.y, self.wavelength) 33 | source.gaussian_beam(w0=(1200, 1200), E0=1) 34 | self.assertEqual(source.wavelength, self.wavelength) 35 | self.assertEqual(source.field.shape, (self.resolution, self.resolution)) 36 | self.assertGreater(jnp.sum(jnp.abs(source.field)**2), 0) 37 | 38 | def test_light_source_pw(self):
39 | source = LightSource(self.x, self.y, self.wavelength) 40 | source.plane_wave(A=1, theta=0, phi=0, z0=0) 41 | self.assertGreater(jnp.sum(jnp.abs(source.field)**2), 0) 42 | 43 | def test_rs_propagation(self): 44 | light = LightSource(self.x, self.y, self.wavelength) 45 | light.gaussian_beam(w0=(1200, 1200), E0=1) 46 | propagated, _ = light.RS_propagation(z=1000) 47 | self.assertEqual(propagated.field.shape, (self.resolution, self.resolution)) 48 | 49 | def test_czt(self): 50 | light = LightSource(self.x, self.y, self.wavelength) 51 | light.gaussian_beam(w0=(1200, 1200), E0=1) 52 | propagated = light.CZT(z=1000) 53 | self.assertEqual(propagated.field.shape, (self.resolution, self.resolution)) 54 | 55 | if __name__ == '__main__': 56 | unittest.main() -------------------------------------------------------------------------------- /xlumina/__init__.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | """ 4 | Define basic units: 5 | 6 | um = microns -> we set um = 1. 7 | nm = nanometers 8 | mm = millimeters 9 | cm = centimeters 10 | 11 | radians -> we set radians = 1. 12 | degrees = 180 / jnp.pi -> multiplying an angle in radians by degrees converts it to degrees. 13 | """ 14 | 15 | um = 1 16 | nm = 1e-3 17 | mm = 1e3 18 | cm = 1e4 19 | 20 | radians = 1 21 | degrees = 180/jnp.pi -------------------------------------------------------------------------------- /xlumina/loss_functions.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import jit, vmap, config 3 | from .__init__ import um 4 | 5 | # Set this to False if float32 precision is enough for you.
6 | enable_float64 = True 7 | if enable_float64: 8 | config.update("jax_enable_x64", True) 9 | 10 | """ Loss functions: 11 | 12 | - small_area_hybrid 13 | - vectorized_loss_hybrid 14 | - mean_batch_MSE_Intensity 15 | - vMSE_Amplitude 16 | - vMSE_Phase 17 | - vMSE_Intensity 18 | - MSE_Amplitude 19 | - MSE_Phase 20 | - MSE_Intensity 21 | 22 | """ 23 | 24 | def small_area_hybrid(detected_intensity): 25 | """ 26 | [Small area loss function valid for hybrid (topology + optical parameters) optimization:] 27 | 28 | Computes the ratio between the mask area (number of pixels) and the fraction of the total intensity enclosed by the mask. The mask keeps the pixels whose normalized intensity exceeds epsilon * max(I). 29 | 30 | Parameters: 31 | detected_intensity (jnp.array): Detected intensity array 32 | + epsilon (float, fixed internally to 0.7): fraction of the peak intensity used as the mask threshold. 33 | 34 | Returns the loss value (jnp.array). 35 | """ 36 | epsilon = 0.7 37 | eps = 1e-08 38 | I = detected_intensity / (jnp.sum(detected_intensity) + eps) 39 | mask = jnp.where(I > epsilon*jnp.max(I), 1, 0) 40 | return jnp.sum(mask) / (jnp.sum(mask * I) + eps) 41 | 42 | @jit 43 | def vectorized_loss_hybrid(detected_intensities): 44 | """[For loss_hybrid]: vectorizes the loss function so it can be evaluated across several detectors.""" 45 | # Input intensities have (M, N, N) shape 46 | vloss = vmap(small_area_hybrid, in_axes = (0)) 47 | # Call the vectorized function 48 | loss_val = vloss(detected_intensities) 49 | # Returns (M,) shape: one loss value per detector 50 | return loss_val 51 | 52 | def mean_batch_MSE_Intensity(optimized, target): 53 | """ 54 | [Used for batch optimization in the 4f system]. Vectorized version of MSE_Intensity. 55 | 56 | Returns the mean of the MSE values over all (optimized, target) pairs, and a jnp.array with the MSE of each pair. 57 | """ 58 | MSE = vmap(MSE_Intensity, in_axes=(0, 0))(optimized, target) 59 | return jnp.mean(MSE), MSE 60 | 61 | def vMSE_Amplitude(input_light, target_light): 62 | """ 63 | Computes the Mean Squared Error (in Amplitude) for each electric field component (computed in parallel) Ei (i = x, y, z).
64 | 65 | Parameters: 66 | input_light (object): VectorizedLight in the focal plane of an objective lens. 67 | target_light (object): VectorizedLight in the focal plane of an objective lens. 68 | 69 | Returns the MSE in jnp.array [MSEx, MSEy, MSEz]. 70 | """ 71 | E_in = jnp.stack([input_light.Ex, input_light.Ey, input_light.Ez], axis=-1) 72 | E_target = jnp.stack([target_light.Ex, target_light.Ey, target_light.Ez], axis=-1) 73 | vectorized_MSE = vmap(MSE_Amplitude, in_axes=(2, 2)) 74 | MSE_out = vectorized_MSE(E_in, E_target) 75 | return MSE_out 76 | 77 | def vMSE_Phase(input_light, target_light): 78 | """ 79 | Computes the Mean Squared Error (in Phase) for each electric field component (computed in parallel) Ei (i = x, y, z). 80 | 81 | Parameters: 82 | input_light (object): VectorizedLight in the focal plane of an objective lens. 83 | target_light (object): VectorizedLight in the focal plane of an objective lens. 84 | 85 | Returns the MSE in jnp.array [MSEx, MSEy, MSEz]. 86 | """ 87 | E_in = jnp.stack([input_light.Ex, input_light.Ey, input_light.Ez], axis=-1) 88 | E_target = jnp.stack([target_light.Ex, target_light.Ey, target_light.Ez], axis=-1) 89 | vectorized_MSE = vmap(MSE_Phase, in_axes=(2, 2)) 90 | MSE_out = vectorized_MSE(E_in, E_target) 91 | return MSE_out 92 | 93 | def vMSE_Intensity(input_light, target_light): 94 | """ 95 | Computes the Mean Squared Error (in Intensity) for each electric field component (computed in parallel) Ei (i = x, y, z). 96 | 97 | Parameters: 98 | input_light (object): VectorizedLight in the focal plane of an objective lens. 99 | target_light (object): VectorizedLight in the focal plane of an objective lens. 100 | 101 | Returns the MSE in jnp.array [MSEx, MSEy, MSEz].
102 | """ 103 | E_in = jnp.stack([input_light.Ex, input_light.Ey, input_light.Ez], axis=-1) 104 | E_target = jnp.stack([target_light.Ex, target_light.Ey, target_light.Ez], axis=-1) 105 | vectorized_MSE = vmap(MSE_Intensity, in_axes=(2, 2)) 106 | MSE_out = vectorized_MSE(E_in, E_target) 107 | return MSE_out 108 | 109 | @jit 110 | def MSE_Amplitude(input_light, target_light): 111 | """ 112 | Computes the Mean Squared Error (in Amplitude) for a given electric field component Ex, Ey or Ez. 113 | 114 | Parameters: 115 | input_light (array): If origin light is VectorizedLight, field Ex, Ey or Ez in the detector. For ScalarLight, it corresponds to .field. 116 | target_light (array): Ground truth - field Ex, Ey or Ez in the detector. 117 | 118 | Returns the MSE (jnp.array). 119 | """ 120 | num_pix = input_light.shape[0] * input_light.shape[1] 121 | return jnp.sum((jnp.abs(input_light) - jnp.abs(target_light)) ** 2) / num_pix 122 | 123 | @jit 124 | def MSE_Phase(input_light, target_light): 125 | """ 126 | Computes the Mean Squared Error (in Phase) for a given electric field component Ex, Ey or Ez. 127 | 128 | Parameters: 129 | input_light (array): If origin light is VectorizedLight, field Ex, Ey or Ez in the detector. For ScalarLight, it corresponds to .field. 130 | target_light (array): Ground truth - field Ex, Ey or Ez in the detector. 131 | 132 | Returns the MSE (jnp.array). 133 | """ 134 | num_pix = input_light.shape[0] * input_light.shape[1] 135 | return jnp.sum((jnp.angle(input_light) - jnp.angle(target_light)) ** 2) / num_pix 136 | 137 | @jit 138 | def MSE_Intensity(input_light, target_light): 139 | """ 140 | Computes the Mean Squared Error (in Intensity) for a given electric field component Ex, Ey or Ez. 141 | 142 | Parameters: 143 | input_light (array): If origin light is VectorizedLight, field Ex, Ey or Ez in the detector. For ScalarLight, it corresponds to .field. 144 | target_light (array): Ground truth - field Ex, Ey or Ez in the detector. 
145 | 146 | Returns the MSE (jnp.array). 147 | """ 148 | num_pix = jnp.shape(input_light)[0] * jnp.shape(input_light)[1] 149 | return jnp.sum((jnp.abs(input_light)**2 - jnp.abs(target_light)**2) ** 2) / num_pix -------------------------------------------------------------------------------- /xlumina/toolbox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import jax.numpy as jnp 4 | import h5py 5 | import random 6 | from PIL import Image 7 | from jax import config, jit, nn 8 | from scipy.optimize import curve_fit 9 | import matplotlib.pyplot as plt 10 | # Set this to False if float32 precision is enough for you. 11 | enable_float64 = True 12 | if enable_float64: 13 | config.update("jax_enable_x64", True) 14 | d_type = np.float64 15 | 16 | """ 17 | Contains useful functions: 18 | 19 | - space 20 | - wrap_phase 21 | - is_conserving_energy 22 | - softmin 23 | - delta_kronecker 24 | - build_LCD_cell 25 | - draw_sSLM 26 | NEW:- draw_sSLM_amplitude 27 | - moving_avg 28 | NEW:- image_to_binary_mask 29 | - rotate_mask 30 | - nearest 31 | - extract_profile 32 | - profile 33 | NEW:- gaussian 34 | NEW:- gaussian_2d 35 | NEW:- lorentzian 36 | NEW:- lorentzian_2d 37 | NEW:- fwhm_1d_fit 38 | NEW:- fwhm_2d_fit 39 | - spot_size 40 | NEW:- compute_fwhm 41 | - find_max_min 42 | - fwhm_1d (no fitting) 43 | - fwhm_2d (no fitting) 44 | NEW:> Functions to process datasets (hdf5 files): 45 | CLASS MultiHDF5DataLoader 46 | 47 | """ 48 | 49 | def space(x_total, num_pix): 50 | """ 51 | Defines the space of the simulation. 52 | 53 | Parameters: 54 | x_total (float): Half-length of the computational window (in microns). 55 | num_pix (int): 1D resolution. 56 | 57 | Returns x and y (np.ndarray). 58 | """ 59 | x = np.linspace(-x_total, x_total, num_pix, dtype=d_type) 60 | y = np.linspace(-x_total, x_total, num_pix, dtype=d_type) 61 | return x, y 62 | 63 | def wrap_phase(phase): 64 | """ 65 | Wraps the input phase into the [-pi, pi] range.
66 | 67 | Parameters: 68 | phase (jnp.array): Phase to wrap. 69 | 70 | Returns the wrapped phase (jnp.array). 71 | """ 72 | return jnp.arctan2(jnp.sin(phase), jnp.cos(phase)) 73 | 74 | def is_conserving_energy(light_source, propagated_light): 75 | """ 76 | Computes the total intensity of the light source and compares it with that of the propagated light. 77 | [Ref: J. Li, Z. Fan, Y. Fu, Proc. SPIE 4915, (2002)]. 78 | 79 | Parameters: 80 | light_source (object): VectorizedLight or ScalarLight light source. 81 | propagated_light (object): Propagated light (same type as light_source). 82 | 83 | Ideally, Itotal_propagated / I_source = 1. 84 | Values < 1 occur when light is lost (e.g., when it leaves the computational window). 85 | 86 | Returns Itotal_propagated / I_source (jnp.array). 87 | """ 88 | if light_source.info == 'Wave optics light' or light_source.info == 'Wave optics light source': 89 | I_source = jnp.sum(jnp.abs(light_source.field**2)) 90 | I_propagated = jnp.sum(jnp.abs(propagated_light.field**2)) 91 | 92 | else: 93 | I_source = jnp.sum(jnp.abs(light_source.Ex**2)) + jnp.sum(jnp.abs(light_source.Ey**2)) + jnp.sum(jnp.abs(light_source.Ez**2)) 94 | I_propagated = jnp.sum(jnp.abs(propagated_light.Ex**2)) + jnp.sum(jnp.abs(propagated_light.Ey**2)) + jnp.sum(jnp.abs(propagated_light.Ez**2)) 95 | 96 | return I_propagated / I_source 97 | 98 | @jit 99 | def softmin(args, beta=90): 100 | """ 101 | Differentiable approximation of the min() function (via log-sum-exp); larger beta gives a closer approximation. 102 | """ 103 | return - nn.logsumexp(-beta * args) / beta 104 | 105 | 106 | def delta_kronecker(a, b): 107 | """ 108 | Computes the Kronecker delta. 109 | 110 | Parameters: 111 | a (float): Number 112 | b (float): Number 113 | 114 | Returns (int). 115 | """ 116 | if a == b: 117 | return 1 118 | else: 119 | return 0 120 | 121 | def build_LCD_cell(eta, theta, shape): 122 | """ 123 | Builds the LCD cell: eta and theta are constant across the cell [not pixel-wise modulation!!]. 124 | 125 | Parameters: 126 | eta (float): Phase difference between Ex and Ey (in radians).
127 | theta (float): Tilt of the fast axis w.r.t. horizontal (in radians). 128 | shape (int): 1D resolution. 129 | 130 | Returns the eta and theta arrays (jnp.array). 131 | """ 132 | # Builds constant eta and theta LCD cell 133 | eta_array = eta * jnp.ones(shape=(shape, shape)) 134 | theta_array = theta * jnp.ones(shape=(shape, shape)) 135 | return eta_array, theta_array 136 | 137 | def draw_sSLM(alpha, phi, extent, extra_title=None, save_file=False, filename=''): 138 | """ 139 | Plots the phase masks of the sSLM (for VectorizedLight). 140 | 141 | Parameters: 142 | alpha (jnp.array): Phase mask to be applied to Ex (in radians). 143 | phi (jnp.array): Phase mask to be applied to Ey (in radians). 144 | extent (jnp.array): Limits for x and y for plotting purposes. 145 | extra_title (str): Adds extra info to the plot title. 146 | save_file (bool): If True, saves the figure. 147 | filename (str): Name of the figure. 148 | """ 149 | fig, axes = plt.subplots(1, 2, figsize=(14, 3)) 150 | cmap = 'twilight' 151 | 152 | ax = axes[0] 153 | im = ax.imshow(alpha, cmap=cmap, extent=extent, origin='lower') 154 | ax.set_title(f"SLM #1. {extra_title}") 155 | ax.set_xlabel(r'$x (\mu m)$') 156 | ax.set_ylabel(r'$y (\mu m)$') 157 | fig.colorbar(im, ax=ax) 158 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi) 159 | 160 | ax = axes[1] 161 | im = ax.imshow(phi, cmap=cmap, extent=extent, origin='lower') 162 | ax.set_title(f"SLM #2. {extra_title}") 163 | ax.set_xlabel(r'$x (\mu m)$') 164 | ax.set_ylabel(r'$y (\mu m)$') 165 | fig.colorbar(im, ax=ax) 166 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi) 167 | 168 | plt.tight_layout() 169 | 170 | if save_file is True: 171 | plt.savefig(filename) 172 | print(f"Plot saved as {filename}") 173 | 174 | plt.show() 175 | 176 | def draw_sSLM_amplitude(A1, A2, extent, extra_title=None, save_file=False, filename=''): 177 | """ 178 | Plots the amplitude masks of the sSLM (for VectorizedLight). 179 | 180 | Parameters: 181 | A1 (jnp.array): Amplitude mask to be applied to Ex (arbitrary units).
182 | A2 (jnp.array): Amplitude mask to be applied to Ey (arbitrary units). 183 | extent (jnp.array): Limits for x and y for plotting purposes. 184 | extra_title (str): Adds extra info to the plot title. 185 | save_file (bool): If True, saves the figure. 186 | filename (str): Name of the figure. 187 | """ 188 | fig, axes = plt.subplots(1, 2, figsize=(14, 3)) 189 | cmap = 'Greys_r' 190 | 191 | ax = axes[0] 192 | im = ax.imshow(A1, cmap=cmap, extent=extent, origin='lower') 193 | ax.set_title(f"SLM #1: amplitude mask {extra_title}") 194 | ax.set_xlabel(r'$x (\mu m)$') 195 | ax.set_ylabel(r'$y (\mu m)$') 196 | fig.colorbar(im, ax=ax) 197 | im.set_clim(vmin=0, vmax=1) 198 | 199 | ax = axes[1] 200 | im = ax.imshow(A2, cmap=cmap, extent=extent, origin='lower') 201 | ax.set_title(f"SLM #2: amplitude mask {extra_title}") 202 | ax.set_xlabel(r'$x (\mu m)$') 203 | ax.set_ylabel(r'$y (\mu m)$') 204 | fig.colorbar(im, ax=ax) 205 | im.set_clim(vmin=0, vmax=1) 206 | 207 | plt.tight_layout() 208 | 209 | if save_file is True: 210 | plt.savefig(filename) 211 | print(f"Plot saved as {filename}") 212 | 213 | plt.show() 214 | 215 | 216 | def moving_avg(window_size, data): 217 | """ 218 | Computes the moving average of a dataset. 219 | 220 | Parameters: 221 | window_size (int): Number of data points averaged in each window. 222 | data (jnp.array): Data. 223 | 224 | Returns moving average (jnp.array). 225 | """ 226 | return jnp.convolve(jnp.array(data), jnp.ones(window_size)/window_size, mode='valid') 227 | 228 | def image_to_binary_mask(filename, x, y, mirror = 'vertical', normalize=True, invert=False, threshold=0.5): 229 | """ 230 | Converts an image into a binary mask (given a threshold). 231 | 232 | Parameters: 233 | filename (str): os path to the image file 234 | x, y (jnp.array): corresponds to the space dimensions where the mask is placed 235 | mirror (str): direction to mirror the image. Can be 'horizontal', 'vertical' or 'both'.
Default is 'vertical' 236 | normalize (bool): if True, normalizes the image 237 | invert (bool): if True, inverts the binarization 238 | threshold (float): pix value threshold for binarization (0 to 1) 239 | 240 | Returns: 241 | Binary mask (jnp.array) 242 | """ 243 | with Image.open(filename) as img: 244 | img_gray = img.convert('L') 245 | 246 | if mirror == 'horizontal': 247 | img_gray = img_gray.transpose(Image.FLIP_LEFT_RIGHT) 248 | if mirror == 'vertical': 249 | img_gray = img_gray.transpose(Image.FLIP_TOP_BOTTOM) 250 | if mirror == 'both': 251 | img_gray = img_gray.transpose(Image.FLIP_LEFT_RIGHT) 252 | img_gray = img_gray.transpose(Image.FLIP_TOP_BOTTOM) 253 | 254 | size_x, size_y = jnp.size(x), jnp.size(y) 255 | img_gray = img_gray.resize((size_y, size_x)) 256 | 257 | img_array = jnp.array(img_gray) # convert into jax array 258 | 259 | if normalize: 260 | img_array = (img_array - jnp.min(img_array)) / (jnp.max(img_array) - jnp.min(img_array)) 261 | 262 | if invert: 263 | img_array = jnp.max(img_array) - img_array 264 | 265 | binary_mask = (img_array > threshold).astype(jnp.uint8) 266 | 267 | return binary_mask 268 | 269 | def rotate_mask(X, Y, angle, origin=None): 270 | """ 271 | Rotates the (X, Y) frame of a mask w.r.t. origin. 272 | 273 | Parameters: 274 | origin (float, float): Coordinates w.r.t. which perform the rotation (in microns). 275 | angle (float): Rotation angle (in radians). 276 | 277 | Returns the rotated meshgrid X, Y (jnp.array). 278 | 279 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 280 | """ 281 | if origin is None: 282 | x0 = (X[-1] + X[0]) / 2 283 | y0 = (Y[-1] + Y[0]) / 2 284 | else: 285 | x0, y0 = origin 286 | 287 | Xrot = (X - x0) * jnp.cos(angle) + (Y - y0) * jnp.sin(angle) 288 | Yrot = -(X - x0) * jnp.sin(angle) + (Y - y0) * jnp.cos(angle) 289 | return Xrot, Yrot 290 | 291 | def nearest(array, value): 292 | """ 293 | Finds the nearest value and its index in an array. 
294 | 295 | Parameters: 296 | array (jnp.array): Array to analyze. 297 | value (float): Number to which determine the position. 298 | 299 | Returns index (idx), the value of the array at idx position and the distance. 300 | 301 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 302 | """ 303 | idx = (jnp.abs(array - value)).argmin() 304 | return idx, array[idx], abs(array[idx] - value) 305 | 306 | def extract_profile(data_2d, x_points, y_points): 307 | """ 308 | [From profile] Extract the values along a line defined by x_points and y_points. 309 | 310 | Parameters: 311 | data_2d (jnp.array): Input data from which extract the profile. 312 | x_points, y_points (jnp.array): X and Y arrays. 313 | 314 | Returns the profile (jnp.array). 315 | """ 316 | x_indices = jnp.round(x_points).astype(int) 317 | y_indices = jnp.round(y_points).astype(int) 318 | 319 | x_indices = jnp.clip(x_indices, 0, data_2d.shape[1] - 1) 320 | y_indices = jnp.clip(y_indices, 0, data_2d.shape[0] - 1) 321 | 322 | profile = [data_2d[y, x] for x, y in zip(x_indices, y_indices)] 323 | return jnp.array(profile) 324 | 325 | def profile(data_2d, x, y, point1='', point2=''): 326 | """ 327 | Determine profile for a given input without using interpolation. 328 | 329 | Parameters: 330 | data_2d (jnp.array): Input 2D array from which extract the profile. 331 | point1 (float, float): Initial point. 332 | point2 (float, float): Final point. 333 | 334 | Returns the profile (h and z) of the input (jnp.array). 
335 | """ 336 | x1, y1 = point1 337 | x2, y2 = point2 338 | 339 | ix1, value, distance = nearest(x, x1) 340 | ix2, value, distance = nearest(x, x2) 341 | iy1, value, distance = nearest(y, y1) 342 | iy2, value, distance = nearest(y, y2) 343 | 344 | # Create a set of x and y points along the line between point1 and point2 345 | x_points = jnp.linspace(ix1, ix2, int(jnp.hypot(ix2-ix1, iy2-iy1))) 346 | y_points = jnp.linspace(iy1, iy2, int(jnp.hypot(ix2-ix1, iy2-iy1))) 347 | 348 | h = jnp.linspace(0, jnp.sqrt((y2 - y1)**2 + (x2 - x1)**2), len(x_points)) 349 | h = h - h[-1] / 2 350 | 351 | z_profile = extract_profile(data_2d, x_points, y_points) 352 | 353 | return h, z_profile 354 | 355 | def gaussian(x, amplitude, mean, std_dev): 356 | """ 357 | [In fwhm_1d_fit] 358 | Returns 1D Gaussian (Normal distribution) for FWHM calculation 359 | 360 | Parameters: 361 | x (jnp.array): 1D-position array 362 | amplitude (float): peak height; mean (float): location of the peak of the distribution in X or Y 363 | std_dev (float): standard deviation in X or Y 364 | """ 365 | return amplitude * jnp.exp(-((x - mean) / std_dev)**2 / 2) 366 | 367 | def gaussian_2d(xy, amplitude, mean_x, mean_y, stdev_x, stdev_y): 368 | """ 369 | [In fwhm_2d_fit] 370 | Returns 2D Gaussian (Normal distribution) for FWHM 2D calculation 371 | 372 | Parameters: 373 | xy (tuple): contains two 1D arrays, X and Y, which are the meshgrid of x and y coordinates 374 | mean_x (float): location of the peak of the distribution in X 375 | mean_y (float): location of the peak of the distribution in Y 376 | stdev_x (float): standard deviation in X 377 | stdev_y (float): standard deviation in Y.
378 | """ 379 | X, Y = xy 380 | return amplitude * jnp.exp(-((X - mean_x)**2 / (2 * stdev_x**2) + (Y - mean_y)**2 / (2 * stdev_y**2))) 381 | 382 | def lorentzian(x, x0, gamma): 383 | """ 384 | [In fwhm_1d_fit] 385 | Returns Lorentzian -- pathological distribution (expected value and variance are undefined) 386 | Parameters: 387 | x (jnp.array): 1D-position array 388 | x0 (float): location of the peak of the distribution 389 | gamma (float): scale parameter. Specifies FWHM = 2 * gamma. 390 | """ 391 | return (1/jnp.pi) * (gamma / ((x-x0)**2 + gamma**2)) 392 | 393 | def lorentzian_2d(xy, amplitude, x0, y0, gamma_x, gamma_y): 394 | """ 395 | [In fwhm_2d_fit] 396 | Returns 2D Lorentzian -- pathological distribution (expected value and variance are undefined) 397 | 398 | Parameters: 399 | xy (tuple): contains two 1D arrays, X and Y, which are the meshgrid of x and y coordinates 400 | amplitude (float): amplitude of the peak 401 | x0 (float): location of the peak of the distribution in X 402 | y0 (float): location of the peak of the distribution in Y 403 | gamma_x (float): scale parameter in X. 404 | gamma_y (float): scale parameter in Y. 405 | """ 406 | X, Y = xy 407 | return amplitude / (1 + ((X - x0) / gamma_x)**2 + ((Y - y0) / gamma_y)**2) 408 | 409 | def fwhm_1d_fit(x, intensity, fit = 'gaussian'): 410 | """ 411 | Compute FWHM of a 1D-array using fit function (gaussian or lorentzian) 412 | 413 | Parameters: 414 | x (jnp.array): 1D-position array. 415 | intensity (jnp.array): 1-dimensional intensity array. 416 | fit (str): can be 'gaussian' or 'lorentzian' 417 | 418 | Returns: 419 | popt (amplitude_fit, mean_fit, stddev_fit), 420 | fwhm (float, in um) 421 | and r_squared (float, r-squared metric of the fit). 
422 | """ 423 | if fit == 'lorentzian': 424 | # initial guess (p0) for curve_fit 425 | x0_guess = x[jnp.argmax(intensity)] 426 | gamma_guess = 0.5 427 | # lorentzian fit -- call scipy.curve_fit 428 | popt, _ = curve_fit(lorentzian, np.array(x), np.array(intensity), p0=[x0_guess, gamma_guess], maxfev=100000) 429 | _, gamma_fit = popt 430 | 431 | fwhm = 2 * gamma_fit 432 | 433 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination 434 | # r^2 = 1 - SSres / SStot 435 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot) 436 | ss_res = jnp.sum((intensity - lorentzian(x, *popt))**2) 437 | 438 | elif fit == 'gaussian': 439 | # initial guess (p0) for curve_fit 440 | a_guess = jnp.max(intensity) 441 | mean_guess = x[jnp.argmax(intensity)] 442 | std_dev_guess = jnp.sqrt(jnp.sum((x - mean_guess)**2 * intensity) / jnp.sum(intensity)) 443 | 444 | # gaussian fit -- call scipy.curve_fit 445 | popt, _ = curve_fit(gaussian, np.array(x), np.array(intensity), p0=[a_guess, mean_guess, std_dev_guess], maxfev=100000) 446 | _, _, stddev_fit = popt 447 | 448 | # FWHM normal distribution = 2*sqrt(2ln2)*sigma 449 | fwhm = 2 * jnp.sqrt(2 * jnp.log(2)) * stddev_fit 450 | 451 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination 452 | # r^2 = 1 - SSres / SStot 453 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot) 454 | ss_res = jnp.sum((intensity - gaussian(x, *popt))**2) 455 | 456 | else: 457 | raise ValueError("fit must be either 'gaussian' or 'lorentzian'") 458 | 459 | ss_tot = jnp.sum((intensity - jnp.mean(intensity))**2) 460 | r_squared = 1 - (ss_res / ss_tot) 461 | 462 | return popt, fwhm, r_squared 463 | 464 | def fwhm_2d_fit(x, y, intensity, fit = 'gaussian'): 465 | """ 466 | Computes the FWHM of a 2-dimensional intensity array using a fit function (gaussian or lorentzian) 467 | 468 | Parameters: 469 | x (jnp.array): 1D-position array. 470 | y (jnp.array): 1D-position array.
471 | intensity (jnp.array): 2-dimensional intensity array. 472 | fit (str): can be 'gaussian' or 'lorentzian' 473 | 474 | Returns: 475 | popt (fitted parameters: amplitude, x0, y0 and the widths in x and y), 476 | FWHM_2D = (fwhm_x, fwhm_y) (tuple, in um) 477 | and r_squared (float, r-squared metric of the fit). 478 | """ 479 | X, Y = jnp.meshgrid(x, y) 480 | xy = jnp.vstack((X.ravel(), Y.ravel())) # vertical stack of ravel arrays for scipy's curve_fit (accepts indep. var. as a single arg) 481 | 482 | if fit == 'lorentzian': 483 | amplitude_guess = jnp.max(intensity) 484 | # x and y coordinates of the maximum intensity 485 | mean_x_guess, mean_y_guess = ( 486 | x[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[1]], 487 | y[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[0]] 488 | ) 489 | gamma_x_guess, gamma_y_guess = 0.5, 0.5 490 | 491 | popt, _ = curve_fit(lorentzian_2d, xy, intensity.ravel(), 492 | p0=[amplitude_guess, mean_x_guess, mean_y_guess, gamma_x_guess, gamma_y_guess], 493 | maxfev=100000) 494 | 495 | _, _, _, gamma_x_fit, gamma_y_fit = popt 496 | FWHM2D = 2 * gamma_x_fit, 2 * gamma_y_fit 497 | 498 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination 499 | # r^2 = 1 - SSres / SStot 500 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot) 501 | ss_res = jnp.sum((intensity - lorentzian_2d(xy, *popt).reshape(intensity.shape))**2) 502 | 503 | elif fit == 'gaussian': 504 | # initial guess for curve_fit 505 | amplitude_guess = jnp.max(intensity) 506 | # x and y coordinates of the maximum intensity 507 | x0_guess, y0_guess = ( 508 | x[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[1]], 509 | y[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[0]] 510 | ) 511 | sigma_x_guess = jnp.sqrt(jnp.sum((X - x0_guess)**2 * intensity) / jnp.sum(intensity)) 512 | sigma_y_guess = jnp.sqrt(jnp.sum((Y - y0_guess)**2 * intensity) / jnp.sum(intensity)) 513 | 514 | popt, _ = curve_fit(gaussian_2d,
xy, intensity.ravel(), 515 | p0=[amplitude_guess, x0_guess, y0_guess, sigma_x_guess, sigma_y_guess], 516 | maxfev=100000) 517 | 518 | _, _, _, sigma_x_fit, sigma_y_fit = popt 519 | FWHM2D = 2 * jnp.sqrt(2 * jnp.log(2)) * sigma_x_fit, 2 * jnp.sqrt(2 * jnp.log(2)) * sigma_y_fit 520 | 521 | # compute r-squared 522 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination 523 | # r^2 = 1 - SSres / SStot 524 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot) 525 | ss_res = jnp.sum((intensity - gaussian_2d(xy, *popt).reshape(intensity.shape))**2) 526 | 527 | else: 528 | raise ValueError("fit must be either 'gaussian' or 'lorentzian'") 529 | 530 | ss_tot = jnp.sum((intensity - jnp.mean(intensity))**2) 531 | r_squared = 1 - (ss_res / ss_tot) 532 | 533 | return popt, FWHM2D, r_squared 534 | 535 | def spot_size(fwhm_x, fwhm_y, wavelength): 536 | """ 537 | Computes the spot size, i.e., the area of the ellipse whose axes are fwhm_x and fwhm_y, in wavelength**2 units. 538 | 539 | Parameters: 540 | fwhm_x (float): FWHM in x (in microns). 541 | fwhm_y (float): FWHM in y (in microns). 542 | wavelength (float): Wavelength (in microns). 543 | 544 | Returns the spot size (jnp.array). 545 | """ 546 | return jnp.pi * (fwhm_x/2) * (fwhm_y/2) / wavelength**2 547 | 548 | def compute_fwhm(light, light_specs, field='', fit = 'gaussian', dimension = '2D', pix_slice = None): 549 | """ 550 | Computes FWHM in 1D or 2D (in um). 551 | 552 | Parameters: 553 | light (object): VectorizedLight, or a jnp.array intensity if field is set to 'Intensity'. 554 | light_specs (list): list with light specs - measurement plane [x, y]. 555 | field (str): component for which to compute the FWHM. Can be 'Ex', 'Ey', 'Ez', 'r' (transverse: |Ex|^2 + |Ey|^2), 'rz' (total field) or 'Intensity' if the input is an intensity array rather than a field. 556 | fit (str): 'gaussian' or 'lorentzian'. dimension (str): can be '1D_x', '1D_y' (1D slice at pix_slice) or '2D'. 557 | pix_slice (int): pixel index at which to slice the intensity for the 1D calculation.
E.g., in the center: pix_slice = resolution // 2 558 | 559 | Returns: 560 | popt, fwhm, r_squared; where fwhm = FWHM_1D or FWHM_2D (tuple) 561 | """ 562 | if field == 'Ex': 563 | intensity = (jnp.abs(light.Ex)) ** 2 564 | elif field == 'Ey': 565 | intensity = (jnp.abs(light.Ey)) ** 2 566 | elif field == 'Ez': 567 | intensity = (jnp.abs(light.Ez)) ** 2 568 | elif field == 'r': 569 | intensity = jnp.abs(light.Ex) ** 2 + (jnp.abs(light.Ey)) ** 2 570 | elif field == 'rz': 571 | intensity = (jnp.abs(light.Ex)) ** 2 + (jnp.abs(light.Ey)) ** 2 + (jnp.abs(light.Ez)) ** 2 572 | elif field == 'Intensity': 573 | intensity = jnp.abs(light) 574 | 575 | if '1D' in dimension: 576 | if dimension == '1D_x': 577 | intensity = intensity[:, pix_slice] # Slice intensity array in pix_slice 578 | axis = light_specs[0] 579 | elif dimension == '1D_y': 580 | intensity = intensity[pix_slice, :] # Slice intensity array in pix_slice 581 | axis = light_specs[1] 582 | else: 583 | axis = light_specs[0] 584 | popt, fwhm, r_squared = fwhm_1d_fit(axis, intensity, fit) 585 | 586 | if dimension == '2D': 587 | popt, fwhm, r_squared = fwhm_2d_fit(light_specs[0], light_specs[1], intensity, fit) 588 | # FWHM (tuple) = FWHM_x, FWHM_y 589 | 590 | return popt, fwhm, r_squared 591 | 592 | def find_max_min(value, x, y, kind =''): 593 | """ 594 | Find the position of maximum and minimum values within a 2D array. 595 | 596 | Parameters: 597 | value (jnp.array): 2-dimensional array with values. 598 | x (jnp.array): x-position array 599 | y (jnp.array): y-position array 600 | kind (str): choose whether to detect minimum 'min' or maximum 'max'. 601 | 602 | Returns: 603 | idx (int, int): indexes of the position of max/min. 604 | xy (float, float): space position of max/min. 605 | ext_value (float): max/min value. 
606 | 607 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 608 | """ 609 | if kind =='min': 610 | val = jnp.where(value==jnp.min(value)) 611 | if kind =='max': 612 | val = jnp.where(value==jnp.max(value)) 613 | 614 | # Extract coordinates into separate arrays: 615 | coordinates = jnp.array(list(zip(val[1], val[0]))) 616 | coords_0 = coordinates[:, 0] 617 | coords_1 = coordinates[:, 1] 618 | 619 | # Index array: 620 | idx = coordinates.astype(int) 621 | 622 | # Array with space positions: 623 | xy = jnp.stack([x[coords_0], y[coords_1]], axis=1) 624 | 625 | # Array with extreme values: 626 | ext_value = value[coords_1, coords_0] 627 | 628 | return idx, xy, ext_value 629 | 630 | def fwhm_2d(x, y, intensity): 631 | """ 632 | Computes the FWHM of a 2-dimensional intensity array (no fitting). 633 | 634 | Parameters: 635 | x (jnp.array): x-position array. 636 | y (jnp.array): y-position array. 637 | intensity (jnp.array): 2-dimensional intensity array. 638 | 639 | Returns: 640 | fwhm_x and fwhm_y 641 | 642 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 643 | """ 644 | i_position, _, _ = find_max_min(jnp.transpose(intensity), x, y, kind='max') 645 | 646 | Ix = intensity[:, i_position[0, 1]] 647 | Iy = intensity[i_position[0, 0], :] 648 | 649 | fwhm_x, _, _, _ = fwhm_1d(x, Ix) 650 | fwhm_y, _, _, _ = fwhm_1d(y, Iy) 651 | 652 | return fwhm_x, fwhm_y 653 | 654 | def fwhm_1d(x, intensity, I_max = None): 655 | """ 656 | Computes the FWHM of a 1-dimensional intensity array (no fitting). 657 | 658 | Parameters: 659 | x (jnp.array): 1D-position array. 660 | intensity (jnp.array): 1-dimensional intensity array.
661 | Returns: 662 | fwhm, x_left, x_right, I_half: the FWHM, the left and right half-maximum positions, and the half-maximum intensity. 663 | 664 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 665 | """ 666 | # Setting-up: 667 | dx = x[1] - x[0] 668 | if I_max is None: 669 | I_max = jnp.max(intensity) 670 | I_half = I_max * 0.5 671 | 672 | # Pixel with maximum intensity: 673 | i_max = jnp.argmax(intensity) 674 | # i_max = jnp.where(intensity == I_max) 675 | # # Compute the pixel location: 676 | # i_max = int(i_max[0][0]) 677 | 678 | # Compute the slopes: 679 | i_left, _, distance_left = nearest(intensity[0:i_max], I_half) 680 | slope_left = (intensity[i_left + 1] - intensity[i_left]) / dx 681 | 682 | i_right, _, distance_right = nearest(intensity[i_max::], I_half) 683 | i_right += i_max 684 | slope_right = (intensity[i_right] - intensity[i_right - 1]) / dx 685 | 686 | x_right = x[i_right] - distance_right / slope_right 687 | x_left = x[i_left] - distance_left / slope_left 688 | 689 | # Compute fwhm: 690 | fwhm = x_right - x_left 691 | 692 | return fwhm, x_left, x_right, I_half 693 | 694 | # ----------------------------------------------------------- 695 | 696 | """ Functions to process datasets (hdf5 files) """ 697 | 698 | class MultiHDF5DataLoader: 699 | """ 700 | Data loader that yields random (input, target) batches, as jnp.arrays, from the HDF5 files in a directory. 701 | """ 702 | def __init__(self, directory, batch_size): 703 | self.directory = directory 704 | self.batch_size = batch_size 705 | 706 | # Get a list of all HDF5 files in the directory 707 | self.files = [f for f in os.listdir(self.directory) if f.endswith('.hdf5')] 708 | 709 | if not self.files: 710 | raise ValueError(f"No HDF5 files found in directory: {self.directory}") 711 | 712 | def __iter__(self): 713 | return self 714 | 715 | def __next__(self): 716 | """ 717 | Randomly selects one HDF5 file and randomly samples batch_size entries from it. Returns them as jnp.arrays.
718 | """ 719 | # Randomly select one of the HDF5 files 720 | selected_file = random.choice(self.files) 721 | filepath = os.path.join(self.directory, selected_file) 722 | 723 | # Open the selected HDF5 file in read mode to get the total number of samples. 724 | # Accessing the dataset shape does not read the data into memory; 725 | # the data remains on disk. 726 | with h5py.File(filepath, 'r') as hf: 727 | total_samples = hf["Input fields"].shape[0] 728 | 729 | # Randomly select indices for the batch (sorted, since h5py fancy indexing requires increasing indices) 730 | batch_indices = sorted(random.sample(range(total_samples), self.batch_size)) 731 | 732 | # Load the batch from the selected HDF5 file. 733 | # Indexing the datasets with batch_indices reads only that subset of samples into memory, 734 | # not the entire dataset. 735 | with h5py.File(filepath, 'r') as hf: 736 | input_batch = hf["Input fields"][batch_indices] 737 | target_batch = hf["Target fields"][batch_indices] 738 | 739 | return jnp.array(input_batch), jnp.array(target_batch) -------------------------------------------------------------------------------- /xlumina/vectorized_optics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | from jax import jit, vmap, config 4 | from functools import partial 5 | import matplotlib.pyplot as plt 6 | import time 7 | 8 | from .toolbox import profile 9 | from .wave_optics import build_grid, RS_propagation_jit, build_CZT_grid, CZT_jit, CZT_for_high_NA_jit 10 | 11 | # Set this to False if float32 precision is enough for you.
12 | enable_float64 = True 13 | d_type = jnp.complex64 14 | if enable_float64: 15 | config.update("jax_enable_x64", True) 16 | d_type = jnp.complex128 17 | 18 | """ 19 | Module for vectorized optical fields: 20 | 21 | - VectorizedLight: 22 | - draw 23 | - draw_intensity_profile 24 | - VRS_propagation 25 | - get_VRS_minimum_z 26 | - VCZT 27 | - VRS_propagation_jit 28 | - VCZT_jit 29 | - vectorized_CZT_for_high_NA 30 | 31 | - PolarizedLightSource: 32 | - gaussian_beam 33 | - plane_wave 34 | """ 35 | 36 | class VectorizedLight: 37 | """ Class for vectorial EM fields - (Ex, Ey, Ez) """ 38 | def __init__(self, x=None, y=None, wavelength=None): 39 | self.x = x 40 | self.y = y 41 | self.X, self.Y = jnp.meshgrid(self.x, self.y) 42 | self.wavelength = wavelength 43 | self.k = 2 * jnp.pi / wavelength 44 | self.n = 1 45 | shape = (jnp.shape(x)[0], jnp.shape(y)[0]) 46 | self.Ex = jnp.zeros(shape, dtype=d_type) 47 | self.Ey = jnp.zeros(shape, dtype=d_type) 48 | self.Ez = jnp.zeros(shape, dtype=d_type) 49 | self.info = 'Vectorized light' 50 | 51 | def draw(self, xlim='', ylim='', kind='', extra_title='', save_file=False, filename=''): 52 | """ 53 | Plots VectorizedLight. 54 | 55 | Parameters: 56 | xlim (float, float): x-axis limits of the plot. 57 | ylim (float, float): y-axis limits of the plot. 58 | kind (str): Feature to plot: 'Intensity', 'Phase' or 'Field'. 59 | extra_title (str): Adds extra info to the plot title. 60 | save_file (bool): If True, saves the figure. 61 | filename (str): Name of the figure. 62 | """ 63 | extent = [xlim[0], xlim[1], ylim[0], ylim[1]] 64 | if kind == 'Intensity': 65 | # Compute intensity 66 | Ix = jnp.abs(self.Ex) ** 2 # Ex 67 | Iy = jnp.abs(self.Ey) ** 2 # Ey 68 | Iz = jnp.abs(self.Ez) ** 2 # Ez 69 | Ir = Ix + Iy # Er 70 | 71 | fig, axes = plt.subplots(2, 3, figsize=(14, 7)) 72 | cmap = 'gist_heat' 73 | 74 | ax = axes[0,0] 75 | im = ax.imshow(Ix, cmap=cmap, extent=extent, origin='lower') 76 | ax.set_title(f"Intensity x.
{extra_title}") 77 | ax.set_xlabel(r'$x (\mu m)$') 78 | ax.set_ylabel(r'$y (\mu m)$') 79 | fig.colorbar(im, ax=ax) 80 | im.set_clim(vmin=jnp.min(Ix), vmax=jnp.max(Ix)) 81 | 82 | ax = axes[0,1] 83 | im = ax.imshow(Iy, cmap=cmap, extent=extent, origin='lower') 84 | ax.set_title(f"Intensity y. {extra_title}") 85 | ax.set_xlabel(r'$x (\mu m)$') 86 | ax.set_ylabel(r'$y (\mu m)$') 87 | fig.colorbar(im, ax=ax) 88 | im.set_clim(vmin=jnp.min(Iy), vmax=jnp.max(Iy)) 89 | 90 | ax = axes[1,0] 91 | im = ax.imshow(Iz, cmap=cmap, extent=extent, origin='lower') 92 | ax.set_title(f"Intensity z. {extra_title}") 93 | ax.set_xlabel(r'$x (\mu m)$') 94 | ax.set_ylabel(r'$y (\mu m)$') 95 | fig.colorbar(im, ax=ax) 96 | im.set_clim(vmin=jnp.min(Iz), vmax=jnp.max(Iz)) 97 | 98 | ax = axes[0,2] 99 | im = ax.imshow(Ir, cmap=cmap, extent=extent, origin='lower') 100 | ax.set_title(f"Intensity r. {extra_title}") 101 | ax.set_xlabel(r'$x (\mu m)$') 102 | ax.set_ylabel(r'$y (\mu m)$') 103 | fig.colorbar(im, ax=ax) 104 | im.set_clim(vmin=jnp.min(Ir), vmax=jnp.max(Ir)) 105 | 106 | axes[1,1].axis('off') 107 | axes[1,2].axis('off') 108 | plt.subplots_adjust(wspace=0.6, hspace=0.6) 109 | 110 | 111 | elif kind == 'Phase': 112 | # Compute phase 113 | phi_x = jnp.angle(self.Ex) # Ex 114 | phi_y = jnp.angle(self.Ey) # Ey 115 | phi_z = jnp.angle(self.Ez) # Ez 116 | 117 | fig, axes = plt.subplots(1, 3, figsize=(14, 3)) 118 | cmap = 'twilight' 119 | 120 | ax = axes[0] 121 | im = ax.imshow(phi_x, cmap=cmap, extent=extent, origin='lower') 122 | ax.set_title(f"Phase x (in radians). {extra_title}") 123 | ax.set_xlabel(r'$x (\mu m)$') 124 | ax.set_ylabel(r'$y (\mu m)$') 125 | fig.colorbar(im, ax=ax) 126 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi) 127 | 128 | ax = axes[1] 129 | im = ax.imshow(phi_y, cmap=cmap, extent=extent, origin='lower') 130 | ax.set_title(f"Phase y (in radians).
{extra_title}") 131 | ax.set_xlabel(r'$x (\mu m)$') 132 | ax.set_ylabel(r'$y (\mu m)$') 133 | fig.colorbar(im, ax=ax) 134 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi) 135 | 136 | ax = axes[2] 137 | im = ax.imshow(phi_z, cmap=cmap, extent=extent, origin='lower') 138 | ax.set_title(f"Phase z (in radians). {extra_title}") 139 | ax.set_xlabel(r'$x (\mu m)$') 140 | ax.set_ylabel(r'$y (\mu m)$') 141 | fig.colorbar(im, ax=ax) 142 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi) 143 | 144 | elif kind == 'Field': 145 | # Compute field amplitudes 146 | Ax = jnp.abs(self.Ex) # Ex 147 | Ay = jnp.abs(self.Ey) # Ey 148 | Az = jnp.abs(self.Ez) # Ez 149 | 150 | fig, axes = plt.subplots(1, 3, figsize=(14, 3)) 151 | cmap = 'viridis' 152 | 153 | ax = axes[0] 154 | im = ax.imshow(Ax, cmap=cmap, extent=extent, origin='lower') 155 | ax.set_title(f"Amplitude x. {extra_title}") 156 | ax.set_xlabel(r'$x (\mu m)$') 157 | ax.set_ylabel(r'$y (\mu m)$') 158 | fig.colorbar(im, ax=ax) 159 | im.set_clim(vmin=jnp.min(Ax), vmax=jnp.max(Ax)) 160 | 161 | ax = axes[1] 162 | im = ax.imshow(Ay, cmap=cmap, extent=extent, origin='lower') 163 | ax.set_title(f"Amplitude y. {extra_title}") 164 | ax.set_xlabel(r'$x (\mu m)$') 165 | ax.set_ylabel(r'$y (\mu m)$') 166 | fig.colorbar(im, ax=ax) 167 | im.set_clim(vmin=jnp.min(Ay), vmax=jnp.max(Ay)) 168 | 169 | ax = axes[2] 170 | im = ax.imshow(Az, cmap=cmap, extent=extent, origin='lower') 171 | ax.set_title(f"Amplitude z. {extra_title}") 172 | ax.set_xlabel(r'$x (\mu m)$') 173 | ax.set_ylabel(r'$y (\mu m)$') 174 | fig.colorbar(im, ax=ax) 175 | im.set_clim(vmin=jnp.min(Az), vmax=jnp.max(Az)) 176 | 177 | else: 178 | raise ValueError(f"Invalid kind option: {kind}. Please choose 'Intensity', 'Phase' or 'Field'.") 179 | 180 | plt.tight_layout() 181 | 182 | if save_file is True: 183 | plt.savefig(filename) 184 | print(f"Plot saved as {filename}") 185 | 186 | plt.show() 187 | 188 | def draw_intensity_profile(self, p1='', p2=''): 189 | """ 190 | Draws the intensity profile of VectorizedLight.
191 | 192 | Parameters: 193 | p1 (float, float): Initial point. 194 | p2 (float, float): Final point. 195 | """ 196 | h, z_profile_x = profile(jnp.abs(self.Ex)**2, self.x, self.y, point1=p1, point2=p2) 197 | _, z_profile_y = profile(jnp.abs(self.Ey)**2, self.x, self.y, point1=p1, point2=p2) 198 | _, z_profile_z = profile(jnp.abs(self.Ez)**2, self.x, self.y, point1=p1, point2=p2) 199 | _, z_profile_r = profile(jnp.abs(self.Ex)**2 + jnp.abs(self.Ey)**2, self.x, self.y, point1=p1, point2=p2) 200 | _, z_profile_total = profile(jnp.abs(self.Ex)**2 + jnp.abs(self.Ey)**2 + jnp.abs(self.Ez)**2, self.x, self.y, point1=p1, point2=p2) 201 | 202 | fig, axes = plt.subplots(3, 2, figsize=(14, 14)) 203 | 204 | ax = axes[0, 0] 205 | im = ax.plot(h, z_profile_x, 'k', lw=2) 206 | ax.set_title("Ix profile") 207 | ax.set_xlabel(r'$\mu m$') 208 | ax.set_ylabel('$Ix$') 209 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_x.min(), z_profile_x.max())) 210 | 211 | ax = axes[0, 1] 212 | im = ax.plot(h, z_profile_y, 'k', lw=2) 213 | ax.set_title("Iy profile") 214 | ax.set_xlabel(r'$\mu m$') 215 | ax.set_ylabel('$Iy$') 216 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_y.min(), z_profile_y.max())) 217 | 218 | ax = axes[1, 0] 219 | im = ax.plot(h, z_profile_z, 'k', lw=2) 220 | ax.set_title("Iz profile") 221 | ax.set_xlabel(r'$\mu m$') 222 | ax.set_ylabel('$Iz$') 223 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_z.min(), z_profile_z.max())) 224 | 225 | ax = axes[1, 1] 226 | im = ax.plot(h, z_profile_r, 'k', lw=2) 227 | ax.set_title("Ir profile") 228 | ax.set_xlabel(r'$\mu m$') 229 | ax.set_ylabel('$Ir$') 230 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_r.min(), z_profile_r.max())) 231 | 232 | ax = axes[2, 0] 233 | im = ax.plot(h, z_profile_total, 'k', lw=2) 234 | ax.set_title("Itotal profile") 235 | ax.set_xlabel(r'$\mu m$') 236 | ax.set_ylabel('$Itotal$') 237 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_total.min(), z_profile_total.max())) 238 | 239 | axes[2, 1].axis('off')
240 | plt.subplots_adjust(wspace=0.3, hspace=0.4) 241 | 242 | plt.show() 243 | 244 | def VRS_propagation(self, z): 245 | """ 246 | Rayleigh-Sommerfeld diffraction integral for VectorizedLight, valid for both z > 0 and z < 0. 247 | [Ref 1: Laser Phys. Lett., 10(6), 065004 (2013)]. 248 | [Ref 2: Optics and laser tech., 39(4), 10.1016/j.optlastec.2006.03.006]. 249 | [Ref 3: J. Li, Z. Fan, Y. Fu, Proc. SPIE 4915, (2002)]. 250 | 251 | Parameters: 252 | z (float): Distance to propagate (in microns). 253 | 254 | Returns VectorizedLight object after propagation and the quality factor of the algorithm. 255 | """ 256 | tic = time.perf_counter() 257 | # Define r [From Ref 1, eq. 1a-1c]: 258 | r = jnp.sqrt(self.X ** 2 + self.Y ** 2 + z ** 2) 259 | 260 | # Set the value of Ez: 261 | Ez = jnp.array(self.Ex * self.X / r + self.Ey * self.Y / r) 262 | nx, ny, dx, dy, Xext, Yext = build_grid(self.x, self.y) 263 | 264 | # Quality factor for accurate simulation [Eq. 22 in Ref 1]: 265 | dr_real = jnp.sqrt(dx**2 + dy**2) 266 | # Rho 267 | rmax = jnp.sqrt(jnp.max(self.x**2) + jnp.max(self.y**2)) 268 | # Delta rho ideal 269 | dr_ideal = jnp.sqrt((self.wavelength)**2 + rmax**2 + 2 * (self.wavelength) * jnp.sqrt(rmax**2 + z**2)) - rmax 270 | quality_factor = dr_ideal / dr_real 271 | 272 | # Stack the input field in a (3, N, N) shape and pass it to the jitted function.
273 | E_in = jnp.stack([self.Ex, self.Ey, Ez], axis=0) 274 | E_out = VRS_propagation_jit(E_in, z, nx, ny, dx, dy, Xext, Yext, self.k) 275 | E_out = jnp.moveaxis(E_out, [0, 1, 2], [2, 0, 1]) 276 | 277 | # Define the output light: 278 | light_out = VectorizedLight(self.x, self.y, self.wavelength) 279 | light_out.Ex = E_out[:, :, 0] 280 | light_out.Ey = E_out[:, :, 1] 281 | light_out.Ez = E_out[:, :, 2] 282 | 283 | print(f"Time taken to perform one VRS propagation (in seconds): {(time.perf_counter() - tic):.4f}") 284 | return light_out, quality_factor 285 | 286 | def get_VRS_minimum_z(self, n=1, quality_factor=1): 287 | """ 288 | Given a quality factor, determines the minimum trustworthy distance for VRS_propagation(). 289 | [Ref 1: Laser Phys. Lett., 10(6), 065004 (2013)]. 290 | 291 | Parameters: 292 | n (float): Refractive index of the surrounding medium. 293 | quality_factor (float): Target quality factor. Defaults to 1. 294 | 295 | Prints and returns the minimum distance z (in microns) needed to reach a quality factor of at least quality_factor. 296 | 297 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 298 | """ 299 | # Check sampling 300 | range_x = self.x[-1] - self.x[0] 301 | range_y = self.y[-1] - self.y[0] 302 | num_x = jnp.size(self.x) 303 | num_y = jnp.size(self.y) 304 | 305 | dx = range_x / num_x 306 | dy = range_y / num_y 307 | # Delta rho 308 | dr_real = jnp.sqrt(dx**2 + dy**2) 309 | # Rho 310 | rmax = jnp.sqrt(range_x**2 + range_y**2) 311 | 312 | factor = (((quality_factor * dr_real + rmax)**2 - (self.wavelength / n)**2 - rmax**2) / (2 * self.wavelength / n))**2 - rmax**2 313 | 314 | if factor > 0: 315 | z_min = jnp.sqrt(factor) 316 | else: 317 | z_min = 0 318 | print("Minimum distance to propagate (in microns):", z_min) 319 | return z_min 320 | 321 | def VCZT(self, z, xout=None, yout=None): 322 | """ 323 | Vectorial version of the Chirped z-transform propagation - efficient RS diffraction using the Bluestein method.
324 | Useful for imaging light in the focal plane: allows a high-resolution zoom in the output plane. 325 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020). 326 | 327 | Parameters: 328 | z (float): Propagation distance (in microns). 329 | xout (jnp.array): Array with the x-positions for the output plane. If None, uses self.x. 330 | yout (jnp.array): Array with the y-positions for the output plane. If None, uses self.y. 331 | Returns VectorizedLight object after propagation. 332 | """ 333 | tic = time.perf_counter() 334 | if xout is None: 335 | xout = self.x 336 | 337 | if yout is None: 338 | yout = self.y 339 | 340 | # Define r: 341 | r = jnp.sqrt(self.X ** 2 + self.Y ** 2 + z ** 2) 342 | 343 | # Set the value of Ez: 344 | Ez = jnp.array((self.Ex * self.X / r + self.Ey * self.Y / r) * z / r) 345 | 346 | # Define main set of parameters 347 | nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1, fx_2 = build_CZT_grid(z, self.wavelength, self.x, self.y, xout, yout) 348 | 349 | # Stack the input field in a (3, N, N) shape and pass it to the vectorized CZT. 350 | E_in = jnp.stack([self.Ex, self.Ey, Ez], axis=0) 351 | E_out = VCZT_jit(E_in, z, self.wavelength, self.k, nx, ny, dx, dy, Xout, Yout, self.X, self.Y, Dm, fy_1, fy_2, fx_1, fx_2) 352 | E_out = jnp.moveaxis(E_out, [0, 1, 2], [2, 0, 1]) 353 | 354 | # Define the output light: 355 | light_out = VectorizedLight(xout, yout, self.wavelength) 356 | light_out.Ex = E_out[:, :, 0] 357 | light_out.Ey = E_out[:, :, 1] 358 | light_out.Ez = E_out[:, :, 2] 359 | 360 | print(f"Time taken to perform one VCZT propagation (in seconds): {(time.perf_counter() - tic):.4f}") 361 | return light_out 362 | 363 | 364 | @partial(jit, static_argnums=(2, 3, 4, 5, 8)) 365 | def VRS_propagation_jit(input_field, z, nx, ny, dx, dy, Xext, Yext, k): 366 | """[From VRS_propagation]: JIT function that vectorizes the propagation and calls RS_propagation_jit from wave_optics.py.""" 367 | 368 | # Input field has (3, N, N) shape 369 | vectorized_RS_propagation = vmap(RS_propagation_jit, 370 | in_axes=(0, None, None, None, None, None, None, None, None)) 371 | # Call the vectorized function 372 | E_out =
vectorized_RS_propagation(input_field, z, nx, ny, dx, dy, Xext, Yext, k) 373 | return E_out # (3, N, N) -> ([Ex, Ey, Ez], N, N) 374 | 375 | def VCZT_jit(field, z, wavelength, k, nx, ny, dx, dy, Xout, Yout, X, Y, Dm, fy_1, fy_2, fx_1, fx_2): 376 | """[From VCZT]: Vectorizes the CZT propagation over the (Ex, Ey, Ez) components by calling CZT_jit from wave_optics.py.""" 377 | 378 | # Input field has (3, N, N) shape 379 | vectorized_CZT = vmap(CZT_jit, 380 | in_axes=(0, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None)) 381 | 382 | # Call the vectorized function 383 | E_out = vectorized_CZT(field, z, wavelength, k, nx, ny, dx, dy, Xout, Yout, X, Y, Dm, fy_1, fy_2, fx_1, fx_2) 384 | return E_out # (3, N, N) -> ([Ex, Ey, Ez], N, N) 385 | 386 | def vectorized_CZT_for_high_NA(field, nx, ny, Dm, fy_1, fy_2, fx_1, fx_2): 387 | """[From VCZT_objective_lens - in optical_elements.py]: Vectorizes the high-NA propagation over the (Ex, Ey, Ez) components by calling CZT_for_high_NA_jit from wave_optics.py.""" 388 | 389 | # Input field has (3, N, N) shape 390 | vectorized = vmap(CZT_for_high_NA_jit, in_axes=(0, None, None, None, None, None, None, None)) 391 | 392 | # Call the vectorized function 393 | E_out = vectorized(field, nx, ny, Dm, fy_1, fy_2, fx_1, fx_2) 394 | return E_out # (3, N, N) -> ([Ex, Ey, Ez], N, N) 395 | 396 | class PolarizedLightSource(VectorizedLight): 397 | """ Class for generating polarized light source beams.""" 398 | def __init__(self, x, y, wavelength): 399 | super().__init__(x, y, wavelength) 400 | self.info = 'Vectorized light source' 401 | 402 | def gaussian_beam(self, w0, jones_vector, center=(0, 0), z_w0=(0, 0), alpha=0): 403 | """ 404 | Defines a Gaussian beam. 405 | 406 | Parameters: 407 | w0 (float, float): Waist radii (w0_x, w0_y) (in microns). 408 | jones_vector (float, float): (Ex, Ey) at the origin (r=0, z=0). Doesn't need to be normalized. 409 | center (float, float): Position of the center of the beam (in microns).
410 | z_w0 (float, float): z-position of the beam waist along x and y (in microns). 411 | alpha (float): Rotation angle of the amplitude profile (in radians). 412 | 413 | Sets Ex, Ey and Ez of this PolarizedLightSource in place (returns None). 414 | """ 415 | # Waist radius 416 | w0_x, w0_y = w0 417 | 418 | # (x, y) center position 419 | x0, y0 = center 420 | 421 | # z-position of the beam waist 422 | z_w0x, z_w0y = z_w0 423 | 424 | # Rayleigh range 425 | Rayleigh_x = self.k * w0_x ** 2 * self.n / 2 426 | Rayleigh_y = self.k * w0_y ** 2 * self.n / 2 427 | 428 | # Gouy phase 429 | Gouy_phase_x = jnp.arctan2(z_w0x, Rayleigh_x) 430 | Gouy_phase_y = jnp.arctan2(z_w0y, Rayleigh_y) 431 | 432 | # Spot size (radius of the beam at position z) 433 | w_x = w0_x * jnp.sqrt(1 + (z_w0x / Rayleigh_x) ** 2) 434 | w_y = w0_y * jnp.sqrt(1 + (z_w0y / Rayleigh_y) ** 2) 435 | 436 | # Radius of curvature 437 | if z_w0x == 0: 438 | R_x = 1e12 439 | else: 440 | R_x = z_w0x * (1 + (Rayleigh_x / z_w0x) ** 2) 441 | if z_w0y == 0: 442 | R_y = 1e12 443 | else: 444 | R_y = z_w0y * (1 + (Rayleigh_y / z_w0y) ** 2) 445 | 446 | # Gaussian beam coordinates 447 | # Account for the rotation of the coordinates by alpha: 448 | x_rot = self.X * jnp.cos(alpha) + self.Y * jnp.sin(alpha) 449 | y_rot = -self.X * jnp.sin(alpha) + self.Y * jnp.cos(alpha) 450 | 451 | # Define the phase and amplitude of the field: 452 | phase = jnp.exp(-1j * ((self.k * z_w0x + self.k * self.X ** 2 / (2 * R_x) - Gouy_phase_x) + ( 453 | self.k * z_w0y + self.k * self.Y ** 2 / (2 * R_y) - Gouy_phase_y))) 454 | 455 | # Normalize Jones vector: 456 | normalized_jones = np.array(jones_vector) / jnp.linalg.norm(np.array(jones_vector)) 457 | 458 | Ex = normalized_jones[0] * (w0_x / w_x) * (w0_y / w_y) * jnp.exp( 459 | -(x_rot - x0) ** 2 / (w_x ** 2) - (y_rot - y0) ** 2 / (w_y ** 2)) 460 | Ey = normalized_jones[1] * (w0_x / w_x) * (w0_y / w_y) * jnp.exp( 461 | -(x_rot - x0) ** 2 / (w_x ** 2) - (y_rot - y0) ** 2 / (w_y ** 2)) 462 | 463 | self.Ex = Ex*phase 464 | self.Ey = Ey*phase 465 | self.Ez =
jnp.zeros_like(self.Ex) 466 | 467 | def plane_wave(self, jones_vector, theta=0, phi=0, z0=0): 468 | """ 469 | Defines a plane wave. 470 | 471 | Parameters: 472 | jones_vector (float, float): (Ex, Ey) at the origin (r=0, z=0). Doesn't need to be normalized. 473 | theta (float): Polar angle of the propagation direction (in radians). 474 | phi (float): Azimuthal angle of the propagation direction (in radians). 475 | z0 (float): z-position that sets a constant phase shift (in microns). 476 | 477 | Equation: 478 | self.field = A * exp(1.j * k * (self.X * sin(theta) * cos(phi) + self.Y * sin(theta) * sin(phi) + z0 * cos(theta))) 479 | 480 | Sets Ex, Ey and Ez of this PolarizedLightSource in place (returns None). 481 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 482 | """ 483 | # Normalize Jones vector: 484 | normalized_jones = np.array(jones_vector) / jnp.linalg.norm(np.array(jones_vector)) 485 | 486 | pw = jnp.exp(1j * self.k * (self.X * jnp.sin(theta) * jnp.cos(phi) + self.Y * jnp.sin(theta) * jnp.sin(phi) + z0 * jnp.cos(theta))) 487 | 488 | self.Ex = normalized_jones[0] * pw 489 | self.Ey = normalized_jones[1] * pw 490 | self.Ez = jnp.zeros_like(self.Ex) 491 | 492 | 493 | 494 | -------------------------------------------------------------------------------- /xlumina/wave_optics.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | from jax import jit, config 4 | from functools import partial 5 | import matplotlib.pyplot as plt 6 | import time 7 | 8 | from .toolbox import rotate_mask 9 | 10 | # Set this to False if float32 precision is enough for you. 11 | enable_float64 = True 12 | if enable_float64: 13 | config.update("jax_enable_x64", True) 14 | 15 | """ 16 | Module for scalar optical fields.
17 | 18 | - ScalarLight: 19 | - draw 20 | - apply_circular_mask 21 | - apply_triangular_mask 22 | - apply_rectangular_mask 23 | - apply_annular_aperture 24 | - RS_propagation 25 | - get_RS_minimum_z 26 | - CZT 27 | - build_grid 28 | - RS_propagation_jit 29 | - transfer_function_RS 30 | - build_CZT_grid 31 | - CZT_jit 32 | - CZT_for_high_NA_jit 33 | - compute_np2 34 | - compute_fft 35 | - Bluestein_method 36 | 37 | - LightSource: 38 | - gaussian_beam 39 | - plane_wave 40 | """ 41 | 42 | 43 | class ScalarLight: 44 | """ Class for Scalar fields - complex amplitude U(r) = A(r)*exp(-ikz). """ 45 | def __init__(self, x, y, wavelength): 46 | self.x = x 47 | self.y = y 48 | self.X, self.Y = jnp.meshgrid(x, y) 49 | self.wavelength = wavelength 50 | self.k = 2 * jnp.pi / wavelength 51 | self.n = 1 52 | self.field = jnp.zeros((jnp.shape(x)[0], jnp.shape(y)[0])) 53 | self.info = 'Wave optics light' 54 | 55 | def draw(self, xlim='', ylim='', kind='', extra_title='', save_file=False, filename=''): 56 | """ 57 | Plots ScalarLight. 58 | 59 | Parameters: 60 | xlim (float, float): x-axis limits. 61 | ylim (float, float): y-axis limits. 62 | kind (str): Feature to plot: 'Intensity' or 'Phase'. 63 | extra_title (str): Adds extra info to the plot title. 64 | save_file (bool): If True, saves the figure. 65 | filename (str): Name of the figure. 66 | """ 67 | extent = [xlim[0], xlim[1], ylim[0], ylim[1]] 68 | if kind == 'Intensity': 69 | field_to_plot = jnp.abs(self.field) ** 2 # Compute intensity (magnitude squared) 70 | title = f"Detected intensity. {extra_title}" 71 | cmap = 'gist_heat' 72 | plt.imshow(field_to_plot, cmap=cmap, extent=extent, origin='lower') 73 | plt.colorbar(orientation='vertical') 74 | 75 | elif kind == 'Phase': 76 | field_to_plot = jnp.angle(self.field) # Calculate phase (in radians) 77 | title = f"Phase in radians. 
{extra_title}" 78 | cmap = 'twilight' 79 | plt.imshow(field_to_plot, cmap=cmap, extent=extent, origin='lower') 80 | plt.colorbar(orientation='vertical') 81 | plt.clim(vmin=-jnp.pi, vmax=jnp.pi) 82 | 83 | else: 84 | raise ValueError(f"Invalid kind option: {kind}. Please choose 'Intensity' or 'Phase'.") 85 | 86 | plt.title(title) 87 | plt.xlabel(r'$x(\mu m)$') 88 | plt.ylabel(r'$y(\mu m)$') 89 | 90 | if save_file is True: 91 | plt.savefig(filename) 92 | print(f"Plot saved as {filename}") 93 | 94 | plt.show() 95 | 96 | def apply_circular_mask(self, r): 97 | """ 98 | Apply a circular (or elliptical) mask of variable radius. 99 | 100 | Parameters: 101 | r (float, float): Radii (rx, ry) along x and y (in microns); rx = ry gives a circle. 102 | 103 | Returns ScalarLight object after applying the pupil mask. 104 | """ 105 | rx, ry = r 106 | pupil = jnp.where((self.X**2 / rx**2 + self.Y**2 / ry**2) < 1, 1, 0) 107 | output = ScalarLight(self.x, self.y, self.wavelength) 108 | output.field = self.field * pupil 109 | return output 110 | 111 | def apply_triangular_mask(self, r, angle, m, height): 112 | """ 113 | Apply a triangular mask of variable size. 114 | 115 | Equation to generate the triangle: y = -m |x - x0| + y0. 116 | 117 | Parameters: 118 | r (float, float): Coordinates of the top corner of the triangle (in microns). 119 | angle (float): Rotation of the triangle (in radians). 120 | m (float): Slope of the edges. 121 | height (float): Distance between the top corner and the base (in microns). 122 | 123 | Returns ScalarLight object after applying the triangular mask.
124 | 125 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 126 | """ 127 | x0, y0 = r 128 | Xrot, Yrot = rotate_mask(self.X, self.Y, angle, origin=(x0, y0)) 129 | Y = -m * jnp.abs(Xrot - x0) + y0 130 | mask = jnp.where((Yrot < Y) & (Yrot > (y0 - height)), 1, 0) 131 | output = ScalarLight(self.x, self.y, self.wavelength) 132 | output.field = self.field * mask 133 | return output 134 | 135 | def apply_rectangular_mask(self, center, width, height, angle): 136 | """ 137 | Apply a rectangular mask of variable size. Can generate rectangles and squares, and rotate them to create diamond shapes. 138 | 139 | Parameters: 140 | center (float, float): Coordinates of the center (in microns). 141 | width (float): Width of the rectangle (in microns). 142 | height (float): Height of the rectangle (in microns). 143 | angle (float): Angle of rotation of the rectangle (in degrees). 144 | 145 | Returns ScalarLight object after applying the rectangular mask. 146 | """ 147 | x0, y0 = center 148 | angle = angle * (jnp.pi/180) 149 | Xrot, Yrot = rotate_mask(self.X, self.Y, angle, center) 150 | mask = jnp.where((Xrot < (width/2)) & (Xrot > (-width/2)) & (Yrot < (height/2)) & (Yrot > (-height/2)), 1, 0) 151 | output = ScalarLight(self.x, self.y, self.wavelength) 152 | output.field = self.field * mask 153 | return output 154 | 155 | def apply_annular_aperture(self, di, do): 156 | """ 157 | Apply an annular aperture of variable size. 158 | 159 | Parameters: 160 | di (float): Diameter of the inner circle (in microns). 161 | do (float): Diameter of the outer circle (in microns). 162 | 163 | Returns ScalarLight object after applying the annular mask.
164 | """ 165 | di = di/2 166 | do = do/2 167 | stop = jnp.where(((self.X**2 + self.Y**2) / di**2) < 1, 0, 1) 168 | ring = jnp.where(((self.X**2 + self.Y**2) / do**2) < 1, 1, 0) 169 | output = ScalarLight(self.x, self.y, self.wavelength) 170 | output.field = self.field * stop * ring 171 | return output 172 | 173 | def RS_propagation(self, z): 174 | """ 175 | Rayleigh-Sommerfeld diffraction integral for ScalarLight, valid for both z > 0 and z < 0. 176 | [Ref 1: F. Shen and A. Wang, Appl. Opt. 45, 1102-1110 (2006)]. 177 | [Ref 2: J. Li, Z. Fan, Y. Fu, Proc. SPIE 4915, (2002)]. 178 | 179 | Parameters: 180 | z (float): Distance to propagate (in microns). 181 | 182 | Returns ScalarLight object after propagation and the quality factor of the algorithm. 183 | """ 184 | tic = time.perf_counter() 185 | nx, ny, dx, dy, Xext, Yext = build_grid(self.x, self.y) 186 | 187 | # Quality factor for accurate simulation [Eq. 17 in Ref 2]: 188 | dr_real = jnp.sqrt(dx**2 + dy**2) 189 | rmax = jnp.sqrt((self.x**2).max() + (self.y**2).max()) 190 | dr_ideal = jnp.sqrt((self.wavelength / 1)**2 + rmax**2 + 2 * (self.wavelength / 1) * jnp.sqrt(rmax**2 + z**2)) - rmax 191 | quality_factor = dr_ideal / dr_real 192 | 193 | propagated_light = ScalarLight(self.x, self.y, self.wavelength) 194 | propagated_light.field = RS_propagation_jit(self.field, z, nx, ny, dx, dy, Xext, Yext, self.k) 195 | print(f"Time taken to perform one RS propagation (in seconds): {(time.perf_counter() - tic):.4f}") 196 | return propagated_light, quality_factor 197 | 198 | def get_RS_minimum_z(self, n=1, quality_factor=1): 199 | """ 200 | Given a quality factor, determines the minimum trustworthy distance for RS_propagation(). 201 | [Ref 1: Laser Phys. Lett., 10(6), 065004 (2013)]. 202 | 203 | Parameters: 204 | n (float): Refractive index of the medium. 205 | quality_factor (float): Target quality factor. Defaults to 1. 206 | 207 | Returns the minimum distance z (in microns) needed to reach a quality factor of at least quality_factor.
208 | 209 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 210 | """ 211 | # Check sampling 212 | range_x = self.x[-1] - self.x[0] 213 | range_y = self.y[-1] - self.y[0] 214 | num_x = jnp.size(self.x) 215 | num_y = jnp.size(self.y) 216 | 217 | dx = range_x / num_x 218 | dy = range_y / num_y 219 | # Delta rho 220 | dr_real = jnp.sqrt(dx**2 + dy**2) 221 | # Rho 222 | rmax = jnp.sqrt(range_x**2 + range_y**2) 223 | 224 | factor = (((quality_factor * dr_real + rmax)**2 - (self.wavelength / n)**2 - rmax**2) / (2 * self.wavelength / n))**2 - rmax**2 225 | 226 | if factor > 0: 227 | z_min = jnp.sqrt(factor) 228 | else: 229 | z_min = 0 230 | print("Minimum distance to propagate (in microns):", z_min) 231 | return z_min 232 | 233 | def CZT(self, z, xout=None, yout=None): 234 | """ 235 | Chirped z-transform propagation - efficient diffraction using the Bluestein method. 236 | Useful for imaging light in the focal plane: allows a high-resolution zoom in the output plane. 237 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020). 238 | 239 | Parameters: 240 | z (float): Propagation distance (in microns). 241 | xout (jnp.array): Array with the x-positions for the output plane. If None, uses self.x. 242 | yout (jnp.array): Array with the y-positions for the output plane. If None, uses self.y. 243 | 244 | Returns ScalarLight object after propagation. 245 | """ 246 | tic = time.perf_counter() 247 | if xout is None: 248 | xout = self.x 249 | 250 | if yout is None: 251 | yout = self.y 252 | 253 | # Define main set of parameters 254 | nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1, fx_2 = build_CZT_grid(z, self.wavelength, self.x, self.y, xout, yout) 255 | 256 | # Compute the diffraction integral using the Bluestein method 257 | field_at_z = CZT_jit(self.field, z, self.wavelength, self.k, nx, ny, dx, dy, Xout, Yout, self.X, self.Y, Dm, fy_1, fy_2, fx_1, fx_2) 258 | 259 | # Build ScalarLight object with output field.
260 | field_out = ScalarLight(xout, yout, self.wavelength) 261 | field_out.field = field_at_z 262 | print(f"Time taken to perform one CZT propagation (in seconds): {(time.perf_counter() - tic):.4f}") 263 | return field_out 264 | 265 | def build_grid(x, y): 266 | """[From RS_propagation]: Returns the grid where the transfer function is defined.""" 267 | nx = len(x) 268 | ny = len(y) 269 | dx = x[1] - x[0] 270 | dy = y[1] - y[0] 271 | # Build 2N-1 x 2N-1 (X, Y) space: 272 | x_padded = jnp.pad((x[0] - x[::-1]), (0, jnp.size(x) - 1), 'reflect') 273 | y_padded = jnp.pad((y[0] - y[::-1]), (0, jnp.size(y) - 1), 'reflect') 274 | # Convert the right half into positive values: 275 | I = jnp.ones((1, int(len(x_padded) / 2) + 1)) 276 | II = -jnp.ones((1, int(len(x_padded) / 2))) 277 | III = jnp.ravel(jnp.concatenate((I, II), 1)) 278 | Xext, Yext = jnp.meshgrid(x_padded * III, y_padded * III) 279 | return nx, ny, dx, dy, Xext, Yext 280 | 281 | @partial(jit, static_argnums=(2, 3, 4, 5, 8)) 282 | def RS_propagation_jit(input_field, z, nx, ny, dx, dy, Xext, Yext, k): 283 | """[From RS_propagation]: JIT function for Equation (10) in [Ref 1].""" 284 | # input_field is jnp.array of (N, N) 285 | H = transfer_function_RS(z, Xext, Yext, k) 286 | U = jnp.zeros((2 * ny - 1, 2 * nx - 1), dtype=complex) 287 | U = U.at[0:ny, 0:nx].set(input_field) 288 | output_field = (jnp.fft.ifft2(jnp.fft.fft2(U) * jnp.fft.fft2(H)) * dx * dy)[ny - 1:, nx - 1:] 289 | return output_field 290 | 291 | @partial(jit, static_argnums=(3,)) 292 | def transfer_function_RS(z, Xext, Yext, k): 293 | """[From RS_propagation]: JIT function for optical transfer function.""" 294 | r = jnp.sqrt(Xext ** 2 + Yext ** 2 + z ** 2) 295 | factor = 1 / (2 * jnp.pi) * z / r ** 2 * (1 / r - 1j * k) 296 | result = jnp.where(z > 0, jnp.exp(1j * k * r) * factor, jnp.exp(-1j * k * r) * factor) 297 | return result 298 | 299 | def build_CZT_grid(z, wavelength, xin, yin, xout, yout): 300 | """ 301 | [From CZT]: Defines the resolution / 
sampling of initial and output planes. 302 | 303 | Parameters: 304 | xin (jnp.array): Array with the x-positions of the input plane. 305 | yin (jnp.array): Array with the y-positions of the input plane. 306 | xout (jnp.array): Array with the x-positions of the output plane. 307 | yout (jnp.array): Array with the y-positions of the output plane. 308 | 309 | Returns the set of parameters: nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1 and fx_2. 310 | """ 311 | # Resolution of the output plane: 312 | nx = len(xout) 313 | ny = len(yout) 314 | Xout, Yout = jnp.meshgrid(xout, yout) 315 | 316 | # Sampling of initial plane: 317 | dx = xin[1] - xin[0] 318 | dy = yin[1] - yin[0] 319 | 320 | # For Bluestein method implementation: 321 | # Dimension of the output field - Eq. (11) in [Ref]. 322 | Dm = wavelength * z / dx 323 | 324 | # (1) for FFT in Y-dimension: 325 | fy_1 = yout[0] + Dm / 2 326 | fy_2 = yout[-1] + Dm / 2 327 | # (2) for FFT in X-dimension: 328 | fx_1 = xout[0] + Dm / 2 329 | fx_2 = xout[-1] + Dm / 2 330 | 331 | return nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1, fx_2 332 | 333 | def CZT_jit(field, z, wavelength, k, nx, ny, dx, dy, Xout, Yout, X, Y, Dm, fy_1, fy_2, fx_1, fx_2): 334 | """ 335 | [From CZT]: Diffraction integral implementation using Bluestein method. 336 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020). 337 | """ 338 | # Compute the scalar diffraction integral using RS transfer function: 339 | # See Eq.(3) in [Ref]. 340 | F0 = transfer_function_RS(z, Xout, Yout, k) 341 | F = transfer_function_RS(z, X, Y, k) 342 | 343 | # Compute (E0 x F) in Eq.(6) in [Ref]. 344 | field = field * F 345 | 346 | # Bluestein method implementation: 347 | 348 | # (1) FFT in Y-dimension: 349 | U = Bluestein_method(field, fy_1, fy_2, Dm, ny) 350 | 351 | # (2) FFT in X-dimension using output from (1): 352 | U = Bluestein_method(U, fx_1, fx_2, Dm, nx) 353 | 354 | # Compute Eq.(6) in [Ref].
355 | field_at_z = F0 * U * z * dx * dy * wavelength 356 | 357 | return field_at_z 358 | 359 | def CZT_for_high_NA_jit(field, nx, ny, Dm, fy_1, fy_2, fx_1, fx_2): 360 | """ 361 | [From VCZT_objective_lens - in optical_elements.py]: Function for Debye integral implementation using Bluestein method. 362 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020). 363 | """ 364 | # Bluestein method implementation: 365 | # (1) FFT in Y-dimension: 366 | U = Bluestein_method(field, fy_1, fy_2, Dm, ny) 367 | 368 | # (2) FFT in X-dimension using output from (1): 369 | U = Bluestein_method(U, fx_1, fx_2, Dm, nx) 370 | 371 | return U 372 | 373 | def compute_np2(x): 374 | """ 375 | [For Bluestein method]: Next power of 2 greater than or equal to x. 376 | 377 | Parameters: 378 | x (float): Input value (e.g. the required convolution length). 379 | 380 | Returns 2**p, where p is the smallest integer such that 2**p >= x (applied element-wise for arrays). 381 | """ 382 | np2 = 2**(np.ceil(np.log2(x))).astype(int) 383 | return np2 384 | 385 | @partial(jit, static_argnums=(4, 5, 6, 7, 8)) 386 | def compute_fft(x, D1, D2, Dm, m, n, mp, M_out, np2): 387 | """ 388 | [From CZT]: JIT-computes the FFT part of the algorithm.
""" 390 | # A-Complex exponential term 391 | A = jnp.exp(1j * 2 * jnp.pi * D1 / Dm) 392 | # W-Complex exponential term 393 | W = jnp.exp(-1j * 2 * jnp.pi * (D2 - D1) / (M_out * Dm)) 394 | 395 | # Chirp sequence: h[k] = W**(k**2 / 2) for k = -m+1, ..., max(M_out-1, m-1) 396 | h = jnp.arange(-m + 1, max(M_out -1, m -1) +1) 397 | h = W**(h**2/ 2) 398 | h_sliced = h[:mp + 1] 399 | 400 | # Compute the 1D Fourier Transform of 1/h up to length 2**nextpow2(mp) 401 | ft = jnp.fft.fft(1 / h_sliced, np2) 402 | # Compute intermediate result for Bluestein's algorithm 403 | b = A**(-(jnp.arange(m))) * h[jnp.arange(m - 1, 2 * m -1)] 404 | tmp = jnp.tile(b, (n, 1)).T 405 | # Compute the 1D Fourier Transform of input data * intermediate result 406 | b = jnp.fft.fft(x * tmp, np2, axis=0) 407 | # Compute the Inverse Fourier Transform 408 | b = jnp.fft.ifft(b * jnp.tile(ft, (n, 1)).T, axis=0) 409 | 410 | return b, h 411 | 412 | def Bluestein_method(x, f1, f2, Dm, M_out): 413 | """ 414 | [From CZT]: Performs the DFT using the Bluestein method. 415 | [Ref1]: Hu, Y., et al. Light Sci Appl 9, 119 (2020). 416 | [Ref2]: L. Bluestein, IEEE Trans. Au. and Electro., 18(4), 451-455 (1970). 417 | [Ref3]: L. Rabiner, et al., IEEE Trans. Au. and Electro., 17(2), 86-92 (1969). 418 | 419 | Parameters: 420 | x (jnp.array): Input sequence, x[n] in Eq.(12) in [Ref 1]. 421 | f1 (float): Starting point in frequency range. 422 | f2 (float): End point in frequency range. 423 | Dm (float): Dimension of the imaging plane. 424 | M_out (int): Length of the transform (resolution of the output plane). 425 | 426 | Returns the output X[m] (jnp.array). 427 | 428 | >> Adapted from MATLAB code provided by https://github.com/yanleihu/Bluestein-Method << 429 | """ 430 | 431 | # Shape of the input: m samples along the transform axis (axis 0), n columns.
m, n = x.shape 433 | 434 | # Lower frequency limit 435 | D1 = f1 + (M_out * Dm + f2 - f1) / (2 * M_out) 436 | # Upper frequency limit 437 | D2 = f2 + (M_out * Dm + f2 - f1) / (2 * M_out) 438 | 439 | # Length of the linear convolution (input length + output length - 1) 440 | mp = m + M_out - 1 441 | np2 = compute_np2(mp) 442 | b, h = compute_fft(x, D1, D2, Dm, m, n, mp, M_out, np2) 443 | 444 | # Extract the relevant portion and multiply by the chirp, h 445 | if M_out > 1: 446 | b = b[m:mp +1, 0:n].T * jnp.tile(h[m - 1:mp], (n, 1)) 447 | else: 448 | b = b[0] * h[0] 449 | 450 | # Create a linearly spaced array from 0 to M_out-1 451 | l = jnp.linspace(0, M_out - 1, M_out) 452 | # Scale the array to the frequency range [D1, D2] 453 | l = l / M_out * (D2 - D1) + D1 454 | 455 | # Eq. S14 in Supplementary Information Section 3 in [Ref1]. Frequency shift to center the spectrum. 456 | M_shift = -m / 2 457 | M_shift = jnp.tile(jnp.exp(-1j * 2 * jnp.pi * l * (M_shift + 1 / 2) / Dm), (n, 1)) 458 | # Apply the frequency shift to the final output 459 | b = b * M_shift 460 | return b 461 | 462 | class LightSource(ScalarLight): 463 | """ Class for generating 2D wave optics light source beams. """ 464 | def __init__(self, x, y, wavelength): 465 | super().__init__(x, y, wavelength) 466 | self.info = 'Wave optics light source' 467 | 468 | def gaussian_beam(self, w0, E0, center=(0, 0), z_w0=(0, 0), alpha=0): 469 | """ 470 | Defines a gaussian beam. 471 | 472 | Parameters: 473 | w0 (float, float): Waist radius (in microns). 474 | E0 (float): Electric field amplitude at the origin (r=0, z=0). 475 | center (float, float): Position of the center of the beam (in microns). 476 | z_w0 (float, float): Position of the waist for (x, y) (in microns). 477 | alpha (float): Rotation angle of the amplitude profile in the XY plane (in radians). 478 | 479 | Sets self.field of this LightSource object in place (returns None).
480 | """ 481 | # Waist radius 482 | w0_x, w0_y = w0 483 | 484 | # (x, y) center position 485 | x0, y0 = center 486 | 487 | # z-position of the beam waist 488 | z_w0x, z_w0y = z_w0 489 | 490 | # Rayleigh range 491 | Rayleigh_x = self.k * w0_x ** 2 * self.n / 2 492 | Rayleigh_y = self.k * w0_y ** 2 * self.n / 2 493 | 494 | # Gouy phase 495 | Gouy_phase_x = jnp.arctan2(z_w0x, Rayleigh_x) 496 | Gouy_phase_y = jnp.arctan2(z_w0y, Rayleigh_y) 497 | 498 | # Spot size (radius of the beam at position z) 499 | w_x = w0_x * jnp.sqrt(1 + (z_w0x / Rayleigh_x) ** 2) 500 | w_y = w0_y * jnp.sqrt(1 + (z_w0y / Rayleigh_y) ** 2) 501 | 502 | # Radius of curvature 503 | if z_w0x == 0: 504 | R_x = 1e12 505 | else: 506 | R_x = z_w0x * (1 + (Rayleigh_x / z_w0x) ** 2) 507 | if z_w0y == 0: 508 | R_y = 1e12 509 | else: 510 | R_y = z_w0y * (1 + (Rayleigh_y / z_w0y) ** 2) 511 | 512 | # Gaussian beam coordinates 513 | X, Y = jnp.meshgrid(self.x, self.y) 514 | # Account for the rotation of the coordinates by alpha: 515 | x_rot = X * jnp.cos(alpha) + Y * jnp.sin(alpha) 516 | y_rot = -X * jnp.sin(alpha) + Y * jnp.cos(alpha) 517 | 518 | # Define the phase and amplitude of the field: 519 | phase = jnp.exp(-1j * ((self.k * z_w0x + self.k * X ** 2 / (2 * R_x) - Gouy_phase_x) + ( 520 | self.k * z_w0y + self.k * Y ** 2 / (2 * R_y) - Gouy_phase_y))) 521 | 522 | self.field = (E0 * (w0_x / w_x) * (w0_y / w_y) * jnp.exp( 523 | -(x_rot - x0) ** 2 / (w_x ** 2) - (y_rot - y0) ** 2 / (w_y ** 2))) * phase 524 | 525 | def plane_wave(self, A=1, theta=0, phi=0, z0=0): 526 | """ 527 | Defines a plane wave. 528 | 529 | Parameters: 530 | A (float): Maximum amplitude. 531 | theta (float): Polar angle of the propagation direction with respect to the z-axis (in radians). 532 | phi (float): Azimuthal angle of the propagation direction in the XY plane (in radians). 533 | z0 (float): Constant value for phase shift. 534 | 535 | Equation: 536 | self.field = A * exp(1.j * k * (self.X * sin(theta) * cos(phi) + self.Y * sin(theta) * sin(phi) + z0 * cos(theta))) 537 | 538 | Sets self.field of this LightSource object in place (returns None).
539 | 540 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) << 541 | """ 542 | self.field = A * jnp.exp(1j * self.k * (self.X * jnp.sin(theta) * jnp.cos(phi) + self.Y * jnp.sin(theta) * jnp.sin(phi) + z0 * jnp.cos(theta))) --------------------------------------------------------------------------------
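`compute_fft` and `Bluestein_method` evaluate a chirp z-transform: DFT samples on an arbitrary frequency interval, obtained through Bluestein's identity l*k = (l**2 + k**2 - (l - k)**2) / 2, which turns the sum into a convolution that FFTs can compute. The following is a minimal 1D NumPy sketch of the textbook algorithm, included only for orientation. The names `czt` and `next_pow2` are hypothetical, and the sketch does not reproduce the index offsets or the final `M_shift` phase factor that `Bluestein_method` applies.

```python
import numpy as np


def next_pow2(n):
    """Smallest power of two >= n (the role compute_np2 plays above)."""
    return 1 << (int(n) - 1).bit_length()


def czt(x, M, W, A):
    """Chirp z-transform of a 1D sequence (hypothetical helper, textbook form).

    Returns X[l] = sum_k x[k] * A**(-k) * W**(l * k) for l = 0, ..., M - 1,
    computed with Bluestein's identity l*k = (l**2 + k**2 - (l - k)**2) / 2.
    """
    x = np.asarray(x, dtype=complex)
    m = x.size
    L = next_pow2(m + M - 1)          # FFT length long enough to avoid wrap-around
    k = np.arange(max(m, M))
    chirp = W ** (k**2 / 2.0)          # W**(k**2 / 2)

    # Pre-multiplied input x[k] * A**(-k) * W**(k**2 / 2), zero-padded to L
    a = np.zeros(L, dtype=complex)
    a[:m] = x * A ** (-np.arange(m)) * chirp[:m]

    # Kernel W**(-j**2 / 2) for j = -(m - 1), ..., M - 1, stored circularly
    v = np.zeros(L, dtype=complex)
    v[:M] = 1.0 / chirp[:M]
    v[L - m + 1:] = 1.0 / chirp[1:m][::-1]

    # Linear convolution via FFT, then post-multiply by W**(l**2 / 2)
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(v))
    return conv[:M] * chirp[:M]
```

With A = 1 and W = exp(-2j*pi/N) this reduces to the ordinary DFT. For a zoom onto normalized frequencies [f1, f2), use A = exp(2j*pi*f1) and W = exp(-2j*pi*(f2 - f1)/M). `compute_fft` uses the same parametrization with frequencies expressed in units of 1/Dm: A = exp(i*2*pi*D1/Dm) and W = exp(-i*2*pi*(D2 - D1)/(M_out*Dm)).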
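`build_grid` and `RS_propagation_jit` compute the Rayleigh-Sommerfeld integral as a linear convolution. The input is zero-padded to 2N-1 samples, and the kernel is sampled on the offsets -(N-1)*dx, ..., 0, ..., (N-1)*dx. Only the last N samples of the circular FFT product are kept, which is the `[ny - 1:, nx - 1:]` slice. The 1D NumPy sketch below checks that this recipe reproduces a direct sum. The function name `fft_linear_convolution_1d` and the test kernel are hypothetical choices for illustration, and the quadrature weight `dx * dy` applied in `RS_propagation_jit` is left out.

```python
import numpy as np


def fft_linear_convolution_1d(u, kernel, x):
    """Return out[p] = sum_k u[k] * kernel(x[p] - x[k]) on a uniform grid x.

    1D analogue of build_grid + RS_propagation_jit (hypothetical helper).
    """
    n = len(x)
    dx = x[1] - x[0]
    # Kernel offsets -(n-1)*dx, ..., 0, ..., (n-1)*dx (2n - 1 samples)
    offsets = (np.arange(2 * n - 1) - (n - 1)) * dx
    h = kernel(offsets)
    # Zero-pad the input to 2n - 1 samples so the circular product does not wrap
    u_pad = np.zeros(2 * n - 1, dtype=complex)
    u_pad[:n] = u
    full = np.fft.ifft(np.fft.fft(u_pad) * np.fft.fft(h))
    # Samples n-1, ..., 2n-2 hold the linear convolution on the original grid
    return full[n - 1:]
```

The padding choice is what makes the FFT result exact rather than periodic. Without the extra N-1 samples, contributions from one edge of the window would wrap around onto the other.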
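Per transverse axis, `gaussian_beam` builds the standard paraxial Gaussian-beam parameters:
- Rayleigh range z_R = k * w0**2 * n / 2
- spot size w(z) = w0 * sqrt(1 + (z / z_R)**2)
- radius of curvature R(z) = z * (1 + (z_R / z)**2)
- Gouy phase arctan(z / z_R)

Below is a minimal standard-library sketch of these formulas. It assumes that `self.k` is the vacuum wavenumber 2*pi/wavelength. The helper name `gaussian_axis_parameters` is hypothetical, and it returns `math.inf` for the flat wavefront at the waist, where the method substitutes 1e12. Each axis is guarded by its own waist offset.

```python
import math


def gaussian_axis_parameters(w0, z, wavelength, n=1.0):
    """Paraxial Gaussian-beam parameters along one transverse axis.

    Hypothetical helper mirroring the per-axis formulas in gaussian_beam.
    w0, z and wavelength share one length unit (e.g. microns).
    """
    k = 2 * math.pi / wavelength             # assumed vacuum wavenumber
    z_R = k * w0**2 * n / 2                  # Rayleigh range
    w = w0 * math.sqrt(1 + (z / z_R) ** 2)   # spot size at z
    gouy = math.atan2(z, z_R)                # Gouy phase
    # Flat wavefront at the waist; gaussian_beam uses 1e12 as a stand-in
    R = math.inf if z == 0 else z * (1 + (z_R / z) ** 2)
    return z_R, w, R, gouy
```

At z = z_R the spot size grows by a factor sqrt(2), the radius of curvature reaches its minimum 2*z_R, and the Gouy phase equals pi/4. These values make convenient sanity checks for the per-axis quantities in the method.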
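The sampling check at the top of this section (adapted from diffractio) reduces to a closed-form estimate of the minimum propagation distance it reports for the given grid. A standard-library transcription of that formula follows, useful for exploring how the estimate depends on grid size. The name `rs_minimum_z` and its signature are hypothetical. Like the original, it uses range / number of points (not number of points - 1) as the sampling step.

```python
import math


def rs_minimum_z(x, y, wavelength, n=1.0, quality_factor=1.0):
    """Minimum propagation distance from the sampling check above (hypothetical helper).

    x, y: sequences of grid coordinates; all lengths share one unit (e.g. microns).
    """
    range_x = x[-1] - x[0]
    range_y = y[-1] - y[0]
    dx = range_x / len(x)                    # same step definition as the original
    dy = range_y / len(y)
    dr_real = math.hypot(dx, dy)             # "Delta rho": diagonal sampling step
    rmax = math.hypot(range_x, range_y)      # "Rho": largest distance across the grid
    lam = wavelength / n                     # wavelength in the medium
    factor = (((quality_factor * dr_real + rmax) ** 2 - lam**2 - rmax**2) / (2 * lam)) ** 2 - rmax**2
    return math.sqrt(factor) if factor > 0 else 0.0
```

A coarser grid (larger dr_real) pushes the estimate up, while a fine enough grid drives `factor` negative and the estimate to zero.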