├── .github
│   └── workflows
│       └── pytests.yml
├── LICENSE
├── README.md
├── examples
│   ├── MPI_logo.png
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── autodiff_evaluation_xlumina.py
│   ├── discovered_solution_4f_system.npy
│   ├── examples.ipynb
│   ├── noisy_4f_system.ipynb
│   ├── noisy_optimization.ipynb
│   ├── numerical_methods_evaluation_diffractio.py
│   ├── scalar_diffractio.py
│   ├── scalar_xlumina.py
│   ├── test_diffractio_vs_xlumina.ipynb
│   ├── vectorial_diffractio.py
│   └── vectorial_xlumina.py
├── experiments
│   ├── __init__.py
│   ├── four_f_optical_table.py
│   ├── four_f_optimizer.py
│   ├── generate_synthetic_data.py
│   ├── hybrid_optimizer.py
│   ├── hybrid_sharp_optical_table.py
│   ├── hybrid_sted_optical_table.py
│   ├── hybrid_with_fixed_PM.py
│   └── six_times_six_ansatz_with_fixed_PM.py
├── miscellaneous
│   ├── noise-aware.png
│   ├── performance.png
│   ├── performance_convergence.png
│   ├── propagation_comparison.png
│   └── workflow.png
├── setup.py
├── tests
│   ├── __init__.py
│   ├── pytest.ini
│   ├── test_optical_elements.py
│   ├── test_toolbox.py
│   ├── test_vectorized_optics.py
│   └── test_wave_optics.py
└── xlumina
    ├── __init__.py
    ├── loss_functions.py
    ├── optical_elements.py
    ├── toolbox.py
    ├── vectorized_optics.py
    └── wave_optics.py
/.github/workflows/pytests.yml:
--------------------------------------------------------------------------------
1 | name: pytests
2 |
3 | on: [push]
4 |
5 | jobs:
6 |   build:
7 |     runs-on: ubuntu-latest
8 |     strategy:
9 |       matrix:
10 |         python-version: ['3.10', '3.11']
11 |
12 |     steps:
13 |     - uses: actions/checkout@v4
14 |     - name: Set up Python ${{ matrix.python-version }}
15 |       uses: actions/setup-python@v3
16 |       with:
17 |         python-version: ${{ matrix.python-version }}
18 |     - name: Install dependencies
19 |       run: |
20 |         python -m pip install --upgrade pip
21 |         pip install -e .
22 |         pip install pytest
23 |         pip install h5py
24 |     - name: Test with pytest
25 |       run: pytest
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Artificial Scientist Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ✨ XLuminA ✨
2 |
3 | [Nature Communications](https://www.nature.com/articles/s41467-024-54696-y)
4 | [Downloads](https://pepy.tech/project/xlumina)
5 | [License: MIT](https://opensource.org/licenses/MIT)
6 |
7 | [PyPI](https://pypi.org/project/xlumina/)
8 | [PyPI version](https://badge.fury.io/py/xlumina)
9 | [GitHub stars](https://github.com/artificial-scientist-lab/XLuminA/stargazers)
10 |
11 | **XLuminA, a highly-efficient, auto-differentiating discovery framework for super-resolution microscopy**
12 |
13 | 📄 Read our paper here: \
14 | [**Automated discovery of experimental designs in super-resolution microscopy with XLuminA**](https://doi.org/10.1038/s41467-024-54696-y)\
15 | *Carla Rodríguez, Sören Arlt, Leonhard Möckl and Mario Krenn*
16 |
17 | 📰 Read the press release: \
18 | [**10,000 times faster than traditional methods: new computational framework automatically discovers experimental designs in microscopy**](https://mpl.mpg.de/news/article/10000-times-faster-than-traditional-methods-new-computational-framework-automatically-discovers-experimental-designs-in-microscopy)
19 |
20 | 📚 Related works featuring XLuminA:
21 |
22 | [**Machine Learning and the Physical Sciences Workshop**](https://ml4physicalsciences.github.io/2024/) - *Poster*
23 |
24 | [**AI4Science Workshop**](https://openreview.net/forum?id=ik9YuAHq6J&referrer=%5Bthe%20profile%20of%20Carla%20Rodríguez%5D(%2Fprofile%3Fid%3D~Carla_Rodríguez1)) - *Oral contribution*
25 |
26 | [**AI4Science Workshop**](https://openreview.net/forum?id=J8HGMimNYe&referrer=%5Bthe%20profile%20of%20Carla%20Rodríguez%5D(%2Fprofile%3Fid%3D~Carla_Rodríguez1)) - *Oral contribution*
27 |
28 | ## 💻 Installation:
29 |
30 | ### Using PyPI:
31 |
32 | Create a new conda environment and install `xlumina` from PyPI. We recommend using `python=3.11`:
33 | ```
34 | conda create -n xlumina_env python=3.11
35 |
36 | conda activate xlumina_env
37 |
38 | pip install xlumina
39 | ```
40 |
41 | It should be installed in about 10 seconds. The package automatically installs:
42 |
43 | 1. [**JAX (CPU only) and jaxlib**](https://jax.readthedocs.io/en/latest/index.html) (the version of JAX used in this project is v0.4.33),
44 |
45 | 2. [**Optax**](https://github.com/google-deepmind/optax/tree/master) (the version of Optax used in this project is v0.2.3),
46 |
47 | 3. [**SciPy**](https://scipy.org) (the version of SciPy used in this project is v1.14.1).
48 |
49 | ### Clone repository:
50 | ```
51 | git clone https://github.com/artificial-scientist-lab/XLuminA.git
52 | ```
53 |
54 | ### GPU compatibility:
55 |
56 | To install [JAX with NVIDIA GPU support](https://jax.readthedocs.io/en/latest/installation.html) (**Note: wheels are only available on Linux**), use the CUDA 12 installation:
57 |
58 | ```
59 | pip install --upgrade "jax[cuda12]"
60 | ```
61 |
62 | XLuminA has been tested on the following operating system:
63 |
64 | Linux Enterprise Server 15 SP4 15.4.
65 |
66 | It has also been successfully installed on Windows 10 (64-bit) and macOS Monterey 12.6.2.
67 |
68 |
69 | # 👾 Features:
70 |
71 | XLuminA simulates classical light propagation through optical hardware configurations in a (*very*) fast and efficient way, and enables the optimization and automated discovery of new setup designs.
72 |
73 |
74 |
75 | The simulator contains many features:
76 |
77 | ✦ Light sources (of any wavelength and power) using either scalar or vectorial optical fields.
78 |
79 | ✦ Phase masks (e.g., spatial light modulators (SLMs), polarizers and general variable retarders (LCDs)).
80 |
81 | ✦ Amplitude masks (e.g., spatial light modulators (SLMs) and pre-defined circles, triangles and squares).
82 |
83 | ✦ Beam splitters, fluorescence model for STED, and more!
84 |
85 | ✦ The light propagation methods available in XLuminA are:
86 |
87 | - [Fast-Fourier-transform (FFT) based numerical integration of the Rayleigh-Sommerfeld diffraction integral](https://doi.org/10.1364/AO.45.001102).
88 |
89 | - [Chirped z-transform](https://doi.org/10.1038/s41377-020-00362-z). This algorithm is an accelerated version of the Rayleigh-Sommerfeld method, which allows for arbitrary selection and sampling of the region of interest.
90 |
91 | - Propagation through [high NA objective lenses](https://doi.org/10.1016/j.optcom.2010.07.030) is available to replicate strong focusing conditions using polarized light.
92 |
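The FFT-based propagators listed above share one pattern: transform the field, multiply by a propagation kernel in frequency space, and transform back. The sketch below is a minimal NumPy version of the closely related angular spectrum method, not XLuminA's `RS_propagation`; the function name `angular_spectrum_propagate`, the grid size, and the Gaussian test beam are illustrative assumptions. Lengths are in microns, following XLuminA's convention.

```python
import numpy as np


def angular_spectrum_propagate(field, dx, wavelength, z):
    """Propagate a sampled 2D complex field over a distance z (same units as dx).

    Angular spectrum method: FFT -> multiply by exp(i*kz*z) -> inverse FFT.
    Evanescent components (kx^2 + ky^2 > k^2) are discarded.
    """
    n_y, n_x = field.shape
    k = 2 * np.pi / wavelength
    # Angular spatial frequencies in numpy's FFT ordering (axis 0 = y, axis 1 = x).
    kx = 2 * np.pi * np.fft.fftfreq(n_x, d=dx)
    ky = 2 * np.pi * np.fft.fftfreq(n_y, d=dx)
    kx_grid, ky_grid = np.meshgrid(kx, ky)
    kz_sq = k**2 - kx_grid**2 - ky_grid**2
    propagating = kz_sq > 0
    kz = np.sqrt(np.where(propagating, kz_sq, 0.0))
    transfer = np.where(propagating, np.exp(1j * kz * z), 0.0)
    return np.fft.ifft2(np.fft.fft2(field) * transfer)


# Usage: Gaussian beam (w0 = 100 um, wavelength = 0.65 um) on a 256 x 256 grid, 2 um pixels.
n, dx, wavelength, w0 = 256, 2.0, 0.65, 100.0
coords = (np.arange(n) - n / 2) * dx
x, y = np.meshgrid(coords, coords)
beam = np.exp(-(x**2 + y**2) / w0**2).astype(complex)
out = angular_spectrum_propagate(beam, dx, wavelength, z=10_000.0)  # 10 mm

# No evanescent components on this grid, so |transfer| = 1 and total power is conserved.
print(np.sum(np.abs(beam) ** 2), np.sum(np.abs(out) ** 2))
```

The power check at the end is the same idea behind an energy-conservation test such as `is_conserving_energy`: when the sampling is adequate, propagation should redistribute intensity without creating or losing it.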
93 | # 📝 Examples of usage:
94 |
95 | Examples of some experiments that can be reproduced with XLuminA are:
96 |
97 | * Optical telescope (or 4f-correlator),
98 | * Polarization-based beam shaping as used in [STED (stimulated emission depletion) microscopy](https://opg.optica.org/ol/fulltext.cfm?uri=ol-19-11-780&id=12352),
99 | * The [sharp focus of a radially polarized light beam](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.91.233901).
100 |
101 | The code for each of these optical setups is provided in the Jupyter notebook [examples.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/examples.ipynb).
102 |
103 | ➤ A **step-by-step guide on how to add noise to the optical elements** can be found in [noisy_4f_system.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/noisy_4f_system.ipynb).
104 |
105 | # 🚀 Testing XLuminA's efficiency:
106 |
107 | We evaluated our framework by conducting several tests; see [Figure 1](https://arxiv.org/abs/2310.08408#). The experiments were run on an Intel Xeon Gold 6130 CPU and an NVIDIA Quadro RTX 6000 GPU.
108 |
109 | (1) Average execution time (in seconds) over 100 runs, within a computational window size of $2048\times 2048$, for scalar and vectorial field propagation using Rayleigh-Sommerfeld (RS, VRS) and Chirped z-transform (CZT, VCZT) in [Diffractio](https://pypi.org/project/diffractio/) and XLuminA. Times for XLuminA reflect runs with pre-compiled (jitted) functions. The Python files used to test the light propagation algorithms are [scalar_diffractio.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/scalar_diffractio.py) and [vectorial_diffractio.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/vectorial_diffractio.py) for Diffractio, and [scalar_xlumina.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/scalar_xlumina.py) and [vectorial_xlumina.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/vectorial_xlumina.py) for XLuminA.
110 |
111 |
112 |
113 | (2) We compare the gradient evaluation time of numerical methods (using Diffractio's optical simulator and SciPy's [BFGS optimizer](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-bfgs.html#optimize-minimize-bfgs)) *vs* automatic (analytical) differentiation (using XLuminA's optical simulator with JAX's [ADAM optimizer](https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html)) across various resolutions:
114 |
115 |
116 |
117 | (3) We compare the convergence time of numerical methods (using Diffractio's optical simulator and SciPy's [BFGS optimizer](https://docs.scipy.org/doc/scipy/reference/optimize.minimize-bfgs.html#optimize-minimize-bfgs)) *vs* automatic (analytical) differentiation (using XLuminA's optical simulator with JAX's [ADAM optimizer](https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html)) across various resolutions:
118 |
119 |
120 |
121 | ➤ The Jupyter notebook used for running these simulations is provided as [test_diffractio_vs_xlumina.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/test_diffractio_vs_xlumina.ipynb).
122 |
123 | ➤ The Python files for the numerical and autodiff evaluations are [numerical_methods_evaluation_diffractio.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/numerical_methods_evaluation_diffractio.py) and [autodiff_evaluation_xlumina.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/autodiff_evaluation_xlumina.py).
124 |
125 | *If you want to run the comparison test of the propagation functions, you need to install [**Diffractio**](https://pypi.org/project/diffractio/). The version of Diffractio used in this project is v0.1.1.*
126 |
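One reason numerical gradients scale poorly with resolution: when no analytic gradient is supplied, SciPy's BFGS estimates it by finite differences, which costs roughly one extra loss evaluation per parameter. An $N\times N$ phase mask therefore needs about $N^2$ loss evaluations per gradient, whereas reverse-mode autodiff returns the full gradient at a cost comparable to a few loss evaluations. The pure-Python sketch below only illustrates the evaluation count on a toy quadratic loss; `forward_difference_grad` and `make_counting_loss` are hypothetical helpers, not code from the benchmark scripts.

```python
def forward_difference_grad(loss, params, h=1e-6):
    """Forward-difference gradient: one extra call to `loss` per parameter."""
    base = loss(params)
    grad = []
    for i in range(len(params)):
        shifted = list(params)
        shifted[i] += h
        grad.append((loss(shifted) - base) / h)
    return grad


def make_counting_loss(target):
    """Toy quadratic loss sum((p - t)^2) that records how often it is called."""
    calls = {"n": 0}

    def loss(params):
        calls["n"] += 1
        return sum((p - t) ** 2 for p, t in zip(params, target))

    return loss, calls


# Usage: a flattened 8 x 8 "phase mask" has 64 parameters.
target = [0.01 * i for i in range(64)]
params = [0.0] * 64
loss, calls = make_counting_loss(target)
numerical = forward_difference_grad(loss, params)
analytic = [2 * (p - t) for p, t in zip(params, target)]  # what autodiff would return exactly
print(calls["n"])  # 65 loss evaluations for a single gradient
```

Doubling the mask resolution quadruples the number of parameters, and with it the number of loss evaluations per finite-difference gradient.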
127 | # 🤖🔎 Discovery of new optical setups:
128 |
129 | With XLuminA we rediscovered three foundational optics experiments and, going beyond them, discovered new, superior topologies together with their parameter settings using purely continuous optimization. The three experiments are:
130 |
131 | ➤ Optical telescope (or 4f-correlator),
132 |
133 | ➤ Polarization-based beam shaping as used in [STED (stimulated emission depletion) microscopy](https://opg.optica.org/ol/fulltext.cfm?uri=ol-19-11-780&id=12352),
134 |
135 | ➤ The [sharp focus of a radially polarized light beam](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.91.233901).
136 |
137 | The Python files used for the discovery of these optical setups, as detailed in [our paper](https://arxiv.org/abs/2310.08408#), are organized in pairs of `optical_table` and `optimizer` as follows:
138 |
139 | | **Experiment name** | 🔬 Optical table | 🤖 Optimizer | 📄 File for data |
140 | |----------------|---------------|-----------|----------|
141 | | ***Optical telescope*** | [four_f_optical_table.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/four_f_optical_table.py) | [four_f_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/four_f_optimizer.py)| [generate_synthetic_data.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/generate_synthetic_data.py) |
142 | | ***Pure topological discovery: large-scale sharp focus (Dorn, Quabis and Leuchs, 2004)*** | [hybrid_with_fixed_PM.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_with_fixed_PM.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
143 | | ***Pure topological discovery: STED microscopy*** | [hybrid_with_fixed_PM.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_with_fixed_PM.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
144 | | ***6x6 grid: pure topological discovery*** | [six_times_six_ansatz_with_fixed_PM.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/six_times_six_ansatz_with_fixed_PM.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
145 | | ***Large-scale polarization-based STED*** | [hybrid_sted_optical_table.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_sted_optical_table.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
146 | | ***Large-scale sharp focus (Dorn, Quabis and Leuchs, 2004)*** | [hybrid_sharp_optical_table.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_sharp_optical_table.py) | [hybrid_optimizer.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/experiments/hybrid_optimizer.py)| N/A |
147 |
148 | # 🦾🤖 Robustness and parallelized optimization of multiple optical tables with our noise-aware scheme:
149 |
150 | ✦ To ensure that our simulations approximate real-world experimental conditions, we have included imperfections, misalignment, and noise sources in all optical components (during post-processing and/or during optimization). **All the results presented in the paper are computed considering a wide variety of experimental errors**.
151 |
152 | ➤ A **step-by-step guide on how to set up the optimization using this scheme** can be found in [noisy_optimization.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/noisy_optimization.ipynb).
153 |
154 | ➤ A **step-by-step guide on how to add noise to the optical elements** can be found in [noisy_4f_system.ipynb](https://github.com/artificial-scientist-lab/XLuminA/blob/main/examples/noisy_4f_system.ipynb).
155 |
156 |
157 |
158 | ✦ The optimization procedure is as follows: for each optimization step, we execute $N$ parallel optical tables using `vmap`. Then, we sample random noise and apply it to all available physical variables across each of the $N$ optical tables. The random noise is **uniformly distributed** and includes:
159 | * Phase noise for spatial light modulators (SLMs) and wave plates (WP) in the range of $\pm$ (0.01 to 0.1) radians, covering the quality range of currently available experimental devices.
160 | * Misalignment ranging from $\pm$ (0.01 to 0.1) millimeters, covering both expert-level precision ($\pm$ 0.01 mm) and beginner-level accuracy ($\pm$ 0.1 mm).
161 | * 1\% imperfection in the transmissivity/reflectivity of beam splitters (BS), which is realistic given the high quality of currently available hardware.
162 |
163 | We then simulate the optical setup for each of the $N$ tables simultaneously, incorporating the sampled noise. The loss function is computed independently for each setup. Afterwards, we calculate the mean loss value across all optical tables, which provides an average performance metric that accounts for the introduced experimental variability (noise). The gradients, and therefore the update of the system parameters, are computed from this mean loss value.
164 |
165 | Importantly, before applying the updated parameters and proceeding to the next iteration, we resample new random noise for each optical table. This ensures that each optimization step encounters different noise values, further enhancing the robustness of the solution. This procedure is repeated iteratively until convergence.
166 |
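The noise-aware loop described above can be sketched with NumPy on a toy problem. Everything here is an assumption made for illustration and is not XLuminA's optimizer: the "optical table" is reduced to a vector of phases `phi` compared against a target, the loss is a plain MSE with an analytic gradient, broadcasting over a leading table axis stands in for `jax.vmap`, and plain gradient descent replaces Adam. Fresh uniform noise is drawn at every iteration, so each step sees different noise values.

```python
import numpy as np


def mean_loss_and_grad(phi, target, noise):
    """Mean loss over N noisy copies of a toy setup, and its gradient w.r.t. phi.

    phi, target: (P,) phases; noise: (N, P), one row of phase noise per table.
    """
    # Broadcasting over the leading table axis plays the role of jax.vmap here.
    residual = phi[None, :] + noise - target[None, :]  # (N, P)
    per_table_loss = np.mean(residual**2, axis=1)  # one loss value per table
    mean_loss = per_table_loss.mean()  # average over the N tables
    grad = 2.0 * residual.mean(axis=0) / phi.size  # d(mean_loss)/d(phi)
    return mean_loss, grad


def noise_aware_optimize(target, n_tables=8, noise_amp=0.1, lr=4.0, steps=500, seed=0):
    rng = np.random.default_rng(seed)
    phi = np.zeros_like(target)
    for _ in range(steps):
        # Fresh uniform noise in [-noise_amp, noise_amp] for every table at every step.
        noise = rng.uniform(-noise_amp, noise_amp, size=(n_tables, target.size))
        loss, grad = mean_loss_and_grad(phi, target, noise)
        phi = phi - lr * grad  # plain gradient descent (the README's benchmarks use Adam)
    return phi, loss


target = np.linspace(-1.0, 1.0, 16)
phi, final_loss = noise_aware_optimize(target)
print(np.max(np.abs(phi - target)), final_loss)
```

Because the gradient is taken on the loss averaged over all noisy copies, the parameters settle where the setup performs well on average rather than for one particular noise realization.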
167 | # 👀 Overview:
168 |
169 | In this section we list the available functions in each file, together with a brief description:
170 |
171 | 1. In [wave_optics.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/wave_optics.py): module for scalar optical fields.
172 |
173 | |*Class*|*Functions*|*Description*|
174 | |---------------|----|-----------|
175 | | `ScalarLight` | | Class for scalar optical fields defined in the XY plane: complex amplitude $U(r) = A(r)\,e^{-ikz}$. |
176 | | | `.draw` | Plots intensity and phase. |
177 | | | `.apply_circular_mask` | Apply a circular mask of variable radius. |
178 | | | `.apply_triangular_mask` | Apply a triangular mask of variable size. |
179 | | | `.apply_rectangular_mask` | Apply a rectangular mask of variable size. |
180 | | | `.apply_annular_aperture` | Apply annular aperture of variable size. |
181 | | | `.RS_propagation` | [Rayleigh-Sommerfeld](https://doi.org/10.1364/AO.45.001102) diffraction integral in z-direction (z>0 and z<0). |
182 | | | `.get_RS_minimum_z` | Given a quality factor, determines the minimum (trustworthy) distance for `RS_propagation`.|
183 | | | `.CZT` | [Chirped z-transform](https://doi.org/10.1038/s41377-020-00362-z) - efficient diffraction using the Bluestein method.|
184 | | `LightSource` | | Class for scalar optical fields defined in the XY plane - defines light source beams. |
185 | | | `.gaussian_beam` | Gaussian beam. |
186 | | | `.plane_wave` | Plane wave. |
187 |
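The chirp z-transform evaluates a Fourier sum on an arbitrary, finely sampled frequency window using FFTs of padded length, thanks to Bluestein's identity $nk = \tfrac{1}{2}\left[n^2 + k^2 - (k-n)^2\right]$. Below is a minimal 1D NumPy sketch of that transform (a "zoom DFT"); `zoom_dft` and its parameters `f0` and `df` (in cycles per sample) are hypothetical names. XLuminA's `CZT` computes 2D diffraction on top of this kind of transform; the sketch shows only the transform itself.

```python
import numpy as np


def zoom_dft(x, m, f0, df):
    """Chirp z-transform on the unit circle via Bluestein's algorithm.

    Returns X[k] = sum_n x[n] * exp(-2j*pi*n*(f0 + k*df)) for k = 0..m-1,
    with f0 and df in cycles per sample.
    """
    x = np.asarray(x, dtype=complex)
    n = x.size
    length = 1 << int(np.ceil(np.log2(n + m - 1)))  # FFT size >= n + m - 1 avoids wrap-around
    nn = np.arange(n)
    kk = np.arange(m)
    # Pre-chirped input: x[n] * exp(-2j*pi*f0*n) * exp(-1j*pi*df*n^2)
    a_seq = np.zeros(length, dtype=complex)
    a_seq[:n] = x * np.exp(-2j * np.pi * f0 * nn) * np.exp(-1j * np.pi * df * nn**2)
    # Chirp filter exp(1j*pi*df*j^2) for j = -(n-1)..(m-1), stored circularly.
    b_seq = np.zeros(length, dtype=complex)
    b_seq[:m] = np.exp(1j * np.pi * df * kk**2)
    if n > 1:
        jj = np.arange(n - 1, 0, -1)  # j = -(n-1)..-1 sit at indices length-(n-1)..length-1
        b_seq[length - (n - 1):] = np.exp(1j * np.pi * df * jj**2)
    # Linear convolution via FFTs, followed by the output chirp.
    conv = np.fft.ifft(np.fft.fft(a_seq) * np.fft.fft(b_seq))
    return np.exp(-1j * np.pi * df * kk**2) * conv[:m]


# Usage: 37 frequency samples between 0.1 and 0.208 cycles/sample of a 50-sample signal.
rng = np.random.default_rng(1)
signal = rng.normal(size=50) + 1j * rng.normal(size=50)
zoomed = zoom_dft(signal, m=37, f0=0.1, df=0.003)
```

With `f0 = 0`, `df = 1/n` and `m = n` the same function reproduces an ordinary DFT, which is a convenient sanity check.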
188 |
189 | 2. In [vectorized_optics.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/vectorized_optics.py): module for vectorized optical fields.
190 |
191 | |*Class*| *Functions* |*Description*|
192 | |---------------|----|-----------|
193 | | `VectorizedLight` | | Class for vectorized optical fields defined in the XY plane: $\vec{E} = (E_x, E_y, E_z)$|
194 | | | `.draw` | Plots intensity, phase and amplitude. |
195 | | | `.draw_intensity_profile` | Plots intensity profile. |
196 | | | `.VRS_propagation` | [Vectorial Rayleigh-Sommerfeld](https://iopscience.iop.org/article/10.1088/1612-2011/10/6/065004) diffraction integral in z-direction (z>0 and z<0). |
197 | | | `.get_VRS_minimum_z` | Given a quality factor, determines the minimum (trustworthy) distance for `VRS_propagation`.|
198 | | | `.VCZT` | [Vectorized Chirped z-transform](https://doi.org/10.1038/s41377-020-00362-z) - efficient diffraction using the Bluestein method.|
199 | | `PolarizedLightSource` | | Class for polarized optical fields defined in the XY plane - defines light source beams. |
200 | | | `.gaussian_beam` | Gaussian beam. |
201 | | | `.plane_wave` | Plane wave. |
202 |
203 |
204 | 3. In [optical_elements.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/optical_elements.py): shelf with all the optical elements available.
205 |
206 | | *Function* |*Description*|
207 | |---------------|----|
208 | | ***Scalar light devices*** | - |
209 | | `phase_scalar_SLM` | Phase mask for the spatial light modulator available for scalar fields. |
210 | | `SLM` | Spatial light modulator: applies a phase mask to incident scalar field. |
211 | | ***Jones matrices*** | - |
212 | | `jones_LP` | Jones matrix of a [linear polarizer](https://doi.org/10.1201/b19711)|
213 | | `jones_general_retarder` | Jones matrix of a [general retarder](https://www.researchgate.net/publication/235963739_Obtainment_of_the_polarizing_and_retardation_parameters_of_a_non-depolarizing_optical_system_from_the_polar_decomposition_of_its_Mueller_matrix). |
214 | | `jones_sSLM` | Jones matrix of the *superSLM*. |
215 | | `jones_sSLM_with_amplitude` | Jones matrix of the *superSLM* that modulates phase & amplitude. |
216 | | `jones_LCD` | Jones matrix of liquid crystal display (LCD).|
217 | | ***Polarization-based devices*** | - |
218 | |`sSLM` | *super*-Spatial Light Modulator: adds phase mask (pixel-wise) to $E_x$ and $E_y$ independently. |
219 | | `sSLM_with_amplitude` | *super*-Spatial Light Modulator: adds phase mask and amplitude mask (pixel-wise) to $E_x$ and $E_y$ independently. |
220 | | `LCD` | Liquid crystal device: builds any linear wave-plate. |
221 | | `linear_polarizer` | Linear polarizer.|
222 | | `BS_symmetric` | Symmetric beam splitter.|
223 | | `BS_symmetric_SI` | Symmetric beam splitter with single input.|
224 | | `BS` | Single-side coated dielectric beam splitter.|
225 | | `high_NA_objective_lens` | High NA objective lens (only for `VectorizedLight`).|
226 | | `VCZT_objective_lens` | Propagation through high NA objective lens (only for `VectorizedLight`).|
227 | | ***General elements*** | - |
228 | | `lens` | Transparent lens of variable size and focal length.|
229 | | `cylindrical_lens` | Transparent plano-convex cylindrical lens of variable focal length. |
230 | | `axicon_lens` | Axicon lens function that produces a Bessel beam. |
231 | | `circular_mask` | Circular mask of variable size. |
232 | | `triangular_mask` | Triangular mask of variable size and orientation.|
233 | | `rectangular_mask` | Rectangular mask of variable size and orientation.|
234 | | `annular_aperture` | Annular aperture of variable size.|
235 | | `forked_grating` | Forked grating of variable size, orientation, and topological charge. |
236 | | 👷♀️ ***Pre-built optical setups*** | - |
237 | | `bb_amplitude_and_phase_mod` | Basic building unit. Consists of a `sSLM` (amp & phase modulation), and `LCD` linked via `VRS_propagation`. |
238 | | `building_block` | Basic building unit. Consists of a `sSLM`, and `LCD` linked via `VRS_propagation`. |
239 | | `fluorescence`| Fluorescence model.|
240 | | `vectorized_fluorophores` | Vectorized version of `fluorescence`: computes the effective intensity across an array of detectors.|
241 | | `robust_discovery` | 3x3 setup for hybrid (topology + optical settings) discovery with single wavelength. Longitudinal intensity (Iz) is measured across all detectors. Includes noise for robustness. |
242 | | `hybrid_setup_fixed_slms_fluorophores`| 3x3 optical table with SLMs randomly positioned displaying fixed phase masks; to be used for pure topological discovery; contains the fluorescence model in all detectors. (*Fig. 4a* of [our paper](https://arxiv.org/abs/2310.08408#))|
243 | | `hybrid_setup_fixed_slms`| 3x3 optical table with SLMs randomly positioned displaying fixed phase masks; to be used for pure topological discovery. (*Fig. 4b* of [our paper](https://arxiv.org/abs/2310.08408#))|
244 | | `hybrid_setup_fluorophores`| 3x3 optical table to be used for hybrid (topological + optical parameter) discovery; contains the fluorescence model in all detectors. (*Fig. 5a* and *Fig. 6* of [our paper](https://arxiv.org/abs/2310.08408#))|
245 | | `hybrid_setup_sharp_focus`| 3x3 optical table to be used for hybrid (topological + optical parameter) discovery. (*Fig. 5b* of [our paper](https://arxiv.org/abs/2310.08408#))|
246 | | `six_times_six_ansatz`| 6x6 optical table to be used for pure topological discovery. (*Extended Data Fig. 6* of [our paper](https://arxiv.org/abs/2310.08408#))|
247 | | 🫨 ***Add noise to the optical elements*** | - |
248 | |`shake_setup`| Literally shakes the setup: creates noise and misalignment for the optical elements. Accepts noise settings (dictionary) as argument. Can't be used with `jit` across parallel optical tables. |
249 | |`shake_setup_jit`| Same as `shake_setup`. Doesn't accept noise settings as argument. Intended to be pasted in the optimizer file to enable `jit` compilation across parallel optical tables.|
250 |
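For orientation, the textbook Jones matrix of an ideal linear polarizer with its transmission axis at angle $\theta$ is $\begin{pmatrix}\cos^2\theta & \sin\theta\cos\theta\\ \sin\theta\cos\theta & \sin^2\theta\end{pmatrix}$, and a pixel-wise phase modulator such as the `sSLM` multiplies $E_x$ and $E_y$ by $e^{i\phi_x}$ and $e^{i\phi_y}$. The NumPy sketch below applies both to sampled fields; the function names and signatures are hypothetical and do not reproduce XLuminA's `jones_LP` or `sSLM`.

```python
import numpy as np


def jones_linear_polarizer(theta):
    """Ideal linear polarizer with its transmission axis at angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c * c, c * s], [c * s, s * s]])


def apply_jones(jones, ex, ey):
    """Apply one 2x2 Jones matrix to every pixel of the (Ex, Ey) fields."""
    return jones[0, 0] * ex + jones[0, 1] * ey, jones[1, 0] * ex + jones[1, 1] * ey


def pixelwise_phase_modulator(ex, ey, phase_x, phase_y):
    """Add independent pixel-wise phase masks to Ex and Ey (sSLM-like behaviour)."""
    return ex * np.exp(1j * phase_x), ey * np.exp(1j * phase_y)


# Usage: x-polarized plane wave on a 4 x 4 grid through a polarizer at 60 degrees.
ex = np.ones((4, 4), dtype=complex)
ey = np.zeros((4, 4), dtype=complex)
ex_out, ey_out = apply_jones(jones_linear_polarizer(np.pi / 3), ex, ey)
intensity = np.abs(ex_out) ** 2 + np.abs(ey_out) ** 2  # Malus's law: cos^2(60 deg) = 0.25
```

A pure phase modulator leaves the intensity of each component untouched; only a polarizer-like element (or interference after propagation) turns those phases into intensity changes.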
251 | 4. In [toolbox.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/toolbox.py): file with useful functions.
252 |
253 | | *Function* |*Description*|
254 | |---------------|----|
255 | | ***Basic operations*** | - |
256 | | `space` | Builds the space where light is placed. |
257 | | `wrap_phase` | Wraps any phase mask into $[-\pi, \pi]$ range.|
258 | | `is_conserving_energy` | Computes the total intensity from the light source and compares it with the propagated light - [Ref](https://doi.org/10.1117/12.482883).|
259 | | `softmin` | Differentiable version of the min() function.|
260 | | `delta_kronecker` | Kronecker delta.|
261 | | `build_LCD_cell` | Builds the cell for `LCD`.|
262 | | `draw_sSLM` | Plots the two phase masks of `sSLM`.|
263 | | `draw_sSLM_amplitude` | Plots the two amplitude masks of `sSLM`.|
264 | | `moving_avg` | Computes the moving average of a dataset.|
265 | | `image_to_binary_mask`| Converts image (png, jpeg) to binary mask. |
266 | | `rotate_mask` | Rotates the (X, Y) frame w.r.t. given point. |
267 | | `gaussian` | Defines a 1D Gaussian distribution. |
268 | | `gaussian_2d` | Defines a 2D Gaussian distribution. |
269 | | `lorentzian` | Defines a 1D Lorentzian distribution. |
270 | | `lorentzian_2d` | Defines a 2D Lorentzian distribution. |
271 | | `fwhm_1d_fit` | Computes FWHM in 1D using fit for `gaussian` or `lorentzian`.|
272 | | `fwhm_2d_fit` | Computes FWHM in 2D using fit for `gaussian_2d` or `lorentzian_2d`. |
273 | | `profile` | Determines the profile of a given input without using interpolation.|
274 | | `spot_size` | Computes the spot size as $\pi (\text{FWHM}_x \cdot \text{FWHM}_y) /\lambda^2$. |
275 | | `compute_fwhm` | Computes FWHM in 1D or 2D using fit: `gaussian`, `gaussian_2d`, `lorentzian`, `lorentzian_2d`. |
276 | | 📑 ***Data loader*** | - |
277 | | `MultiHDF5DataLoader` | Data loader class for 4f system training. |
278 |
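As a worked example of the quantities behind `compute_fwhm` and `spot_size`: for a Gaussian profile $e^{-x^2/(2\sigma^2)}$, $\text{FWHM} = 2\sqrt{2\ln 2}\,\sigma \approx 2.355\,\sigma$, and the spot size is $\pi\,\text{FWHM}_x\,\text{FWHM}_y/\lambda^2$. XLuminA's functions estimate the FWHM by fitting; the pure-Python sketch below uses a simpler, fit-free estimate instead (linear interpolation at half maximum), and its helper names are hypothetical.

```python
import math


def fwhm_from_samples(xs, ys):
    """Fit-free FWHM of a single-peaked profile via linear interpolation at half max."""
    peak = max(range(len(ys)), key=ys.__getitem__)
    half = ys[peak] / 2.0
    i = peak
    while i > 0 and ys[i] >= half:  # walk left until the profile drops below half max
        i -= 1
    j = peak
    while j < len(ys) - 1 and ys[j] >= half:  # walk right until below half max
        j += 1
    if ys[i] >= half or ys[j] >= half:
        raise ValueError("profile does not fall below half maximum on both sides")
    x_left = xs[i] + (half - ys[i]) * (xs[i + 1] - xs[i]) / (ys[i + 1] - ys[i])
    x_right = xs[j - 1] + (half - ys[j - 1]) * (xs[j] - xs[j - 1]) / (ys[j] - ys[j - 1])
    return x_right - x_left


def spot_size(fwhm_x, fwhm_y, wavelength):
    """Spot size pi * FWHM_x * FWHM_y / wavelength^2, using the formula listed for `spot_size`."""
    return math.pi * fwhm_x * fwhm_y / wavelength**2


# Usage: Gaussian with sigma = 0.5 um sampled every 1 nm (lengths in microns).
sigma = 0.5
xs = [-5.0 + 0.001 * k for k in range(10001)]
ys = [math.exp(-x**2 / (2 * sigma**2)) for x in xs]
fwhm = fwhm_from_samples(xs, ys)  # close to 2 * sqrt(2 * ln 2) * 0.5, i.e. about 1.1774
print(fwhm, spot_size(fwhm, fwhm, wavelength=0.65))
```

A fit-based estimate is more robust on noisy or coarsely sampled profiles, which is presumably why the toolbox fits Gaussian or Lorentzian models instead.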
279 | 5. In [loss_functions.py](https://github.com/artificial-scientist-lab/XLuminA/blob/main/xlumina/loss_functions.py): file with loss functions.
280 |
281 | | *Function* |*Description*|
282 | |---------------|----|
283 | | `small_area_hybrid` | Small area loss function valid for hybrid (topology + optical parameters) optimization|
284 | | `vMSE_Intensity` | Parallel computation of Mean Squared Error (Intensity) for a given electric field component $E_x$, $E_y$ or $E_z$. |
285 | | `MSE_Intensity` | Mean Squared Error (Intensity) for a given electric field component $E_x$, $E_y$ or $E_z$. |
286 | | `vMSE_Phase` | Parallel computation of Mean Squared Error (Phase) for a given electric field component $E_x$, $E_y$ or $E_z$. |
287 | | `MSE_Phase` | Mean Squared Error (Phase) for a given electric field component $E_x$, $E_y$ or $E_z$. |
288 | | `vMSE_Amplitude` | Parallel computation of Mean Squared Error (Amplitude) for a given electric field component $E_x$, $E_y$ or $E_z$. |
289 | | `MSE_Amplitude` | Mean Squared Error (Amplitude) for a given electric field component $E_x$, $E_y$ or $E_z$. |
290 | | `mean_batch_MSE_Intensity` | Batch-based `MSE_Intensity`.|
291 |
292 | # ⚠️ Considerations when using XLuminA:
293 |
294 | 1. By default, JAX uses `float32` precision. If you need double precision, call `jax.config.update("jax_enable_x64", True)` at the beginning of the file.
295 |
296 | 2. Basic units are microns (um) and radians. Other units (centimeters, millimeters, nanometers, and degrees) are available in `__init__.py`.
297 |
298 | 3. **IMPORTANT** - RAYLEIGH-SOMMERFELD PROPAGATION:
299 | [FFT-based diffraction calculation algorithms](https://doi.org/10.1117/12.482883) can be inaccurate depending on the computational window size (sampling).\
300 | Before propagating light, check the minimum propagation distance for which the simulation is accurate.\
301 | You can use the following functions:
302 |
303 | `get_RS_minimum_z`, for `ScalarLight` class, and `get_VRS_minimum_z`, for `VectorizedLight` class.
304 |
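Since every length is expressed in microns, unit constants are plain scale factors. The sketch below mirrors that convention; the constant names match those imported in the example scripts (`um`, `nm`, `mm`, `degrees`, `radians`, plus `cm`), but the values are derived here from the micron/radian convention rather than copied from XLuminA's `__init__.py`.

```python
import math

# Assumed values, derived from "base length unit = micron, base angle unit = radian".
um = 1.0
mm = 1000.0 * um
cm = 10.0 * mm
nm = 1e-3 * um
radians = 1.0
degrees = math.pi / 180.0 * radians

wavelength = 650 * nm  # 0.65 um
z = 25 * mm  # 25 mm expressed in microns
angle = 90 * degrees  # pi / 2 radians
```

Writing quantities as `650*nm` or `25*mm`, as the example scripts do, keeps the numbers readable while everything is stored internally in microns.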
305 |
306 | # 💻 Development:
307 |
308 | *Some functionalities of XLuminA’s optics simulator (e.g., optical propagation algorithms, lens or amplitude masks) are inspired by an open-source NumPy-based Python module for diffraction and interferometry simulation, [Diffractio](https://pypi.org/project/diffractio/). **We have rewritten and modified these approaches to combine them with JAX just-in-time (jit) functionality**. We labeled these functions as such in the docstrings. On top of that, **we developed completely new functions** (e.g., sSLMs, beam splitters, LCDs or propagation through high NA objective lens with CZT methods, to name a few) **which significantly expand the software capabilities.***
309 |
310 | # 📝 How to cite XLuminA:
311 |
312 | If you use this software, please cite as:
313 |
314 | Rodríguez, C., Arlt, S., Möckl, L. and Krenn, M. Automated discovery of experimental designs in super-resolution microscopy with XLuminA. *Nat Commun* **15**, 10658 (2024). https://doi.org/10.1038/s41467-024-54696-y
315 |
316 | BibTeX format:
317 |
318 | @article{NatCommun.15.10658,
319 | title={Automated discovery of experimental designs in super-resolution microscopy with {XLuminA}},
320 | author={Rodríguez, Carla and Arlt, Sören and Möckl, Leonhard and Krenn, Mario},
321 | journal={Nature Communications},
322 | volume={15},
323 | pages={10658},
324 | year={2024},
325 | publisher={Nature Publishing Group},
326 | doi={10.1038/s41467-024-54696-y}
327 | }
328 |
--------------------------------------------------------------------------------
/examples/MPI_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/MPI_logo.png
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/__init__.py
--------------------------------------------------------------------------------
/examples/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/examples/autodiff_evaluation_xlumina.py:
--------------------------------------------------------------------------------
1 | import os, sys  # needed for the sys.path setup below
2 | import numpy as np
3 | # Setting the path for XLuminA modules:
4 | current_path = os.path.abspath(os.path.join('..'))
5 | module_path = os.path.join(current_path)
6 |
7 | if module_path not in sys.path:
8 | sys.path.append(module_path)
9 |
10 | from xlumina.__init__ import um, nm, mm, degrees, radians
11 | from xlumina.wave_optics import *
12 | from xlumina.vectorized_optics import *
13 | from xlumina.optical_elements import *
14 | from xlumina.loss_functions import *
15 | from xlumina.toolbox import space
16 | import jax.numpy as jnp
17 |
18 | """ Evaluates the convergence and gradient evaluation times for XLuminA's autodiff optimization """
19 |
20 | # System specs:
21 | sensor_lateral_size =500 # Resolution
22 | wavelength = 650*nm
23 | x_total = 1000*um
24 | x, y = space(x_total, sensor_lateral_size)
25 | shape = jnp.shape(x)[0]
26 |
27 | # Light source specs
28 | w0 = (1200*um , 1200*um)
29 | gb = LightSource(x, y, wavelength)
30 | gb.gaussian_beam(w0=w0, E0=1)
31 | gb_gt = ScalarLight(x, y, wavelength)
32 | # Set spiral phase for the ground truth
33 | gb_gt.set_spiral()
34 |
35 | # Define the setup
36 | def setup(gb, parameters):
37 | gb_propagated, _ = gb.RS_propagation(z=25000*mm) # keep the propagated light; the second output is discarded
38 | gb_modulated, _ = SLM(gb_propagated, parameters, gb.x.shape[0])
39 | return gb_modulated
40 |
41 | def mse_phase(input_light, target_light):
42 | return jnp.sum((jnp.angle(input_light.field) - jnp.angle(target_light.field)) ** 2) / sensor_lateral_size**2
43 | def loss(parameters):
44 | out = setup(gb, parameters)
45 | loss_val = mse_phase(out, gb_gt)
46 | return loss_val
47 |
48 | # Optimizer for phase mask
49 | import time
50 | import jax
51 | from jax import grad, jit
52 | from jax.example_libraries import optimizers
53 |
54 | # Print device info (GPU or CPU)
55 | print(jax.devices(), flush=True)
56 |
57 | # Define the update:
58 | @jit
59 | def update(step_index, optimizer_state):
60 | # define single update step
61 | parameters = get_params(optimizer_state)
62 | # Call the loss function and compute the gradients
63 | computed_loss = loss_value(parameters)
64 | computed_gradients = grad(loss_value, allow_int=True)(parameters)
65 |
66 | return opt_update(step_index, computed_gradients, optimizer_state), computed_loss, computed_gradients
67 |
68 | # Define the loss function and compute its gradients:
69 | loss_value = jit(loss)
70 |
71 | # Optimizer settings
72 | STEP_SIZE = 0.1
73 | num_iterations = 50000
74 | n_best = 50
75 | best_loss = 1e10
76 | best_params = None
77 | best_step = 0
78 | num_samples = 5
79 |
80 | steps = []
81 | times = []
82 | ratio = []
83 |
84 | for i in range(num_samples):
85 | # Parameters for STED
86 | parameters = jnp.array([np.random.uniform(-jnp.pi, jnp.pi, (shape, shape))], dtype=jnp.float64)
87 |
88 | # Define the optimizer and initialize it
89 | opt_init, opt_update, get_params = optimizers.adam(STEP_SIZE)
90 | init_params = parameters
91 | opt_state = opt_init(init_params)
92 |
93 | print('Starting Optimization', flush=True)
94 | best_loss, best_params, best_step = 1e10, None, 0 # reset best-loss tracking for every sample
95 | tic = time.perf_counter()
96 |
97 | # Optimize in a loop
98 | for step in range(num_iterations):
99 |
100 | opt_state, step_loss, gradients = update(step, opt_state) # don't overwrite the jitted loss_value function
101 |
102 | if step_loss < best_loss:
103 | # Best loss value
104 | best_loss = step_loss
105 | # Best optimized parameters
106 | best_params = get_params(opt_state)
107 | best_step = step
108 | # print('Best loss value is updated')
109 |
110 | if step % 100 == 0:
111 | # Stopping criterion (checked every 100 steps): stop if the best loss has not improved for more than n_best steps
112 | if step - best_step > n_best:
113 | print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
114 | break
115 |
116 | steps.append(step)
117 | times.append(time.perf_counter() - tic)
118 | ratio.append((time.perf_counter() - tic)/step)
119 |
120 | filename = f"xlumina_cpu_eval_{sensor_lateral_size}.npy"
121 | np.save(filename, {"Time": times, "Step": steps, "t/step": ratio})
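Both this script and the notebook's `fit` loop track the best loss. They stop once it has not improved for more than `n_best` steps, but they test that rule only every 100 steps. Below is a minimal pure-Python sketch of the same patience rule. The helpers `should_stop` and `run` and the toy loss curve are hypothetical and are not part of XLuminA.

```python
def should_stop(step, best_step, patience, check_every=100):
    """True when this is a check step and the best loss is older than `patience` steps."""
    return step % check_every == 0 and step - best_step > patience


def run(losses, patience, check_every=100):
    """Track the best loss over a sequence and apply the stopping rule above."""
    best_loss, best_step = float("inf"), 0
    for step, loss in enumerate(losses):
        if loss < best_loss:
            best_loss, best_step = loss, step
        if should_stop(step, best_step, patience, check_every):
            return step, best_loss, best_step
    return len(losses) - 1, best_loss, best_step


# Toy loss: improves for 30 steps, then plateaus above the best value.
losses = [1.0 / (s + 1) for s in range(30)] + [0.5] * 1000
stop_step, best_loss, best_step = run(losses, patience=50)
print(stop_step, best_step)  # 100 29
```

The test runs only on multiples of `check_every`, so a run can continue for up to about `patience + check_every` steps after its last improvement. Here it stops at step 100, which is 71 steps after the best loss at step 29.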
--------------------------------------------------------------------------------
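`mse_phase` (in the XLuminA script above and in the Diffractio benchmark) compares raw `angle()` values. Two phases that differ by almost exactly 2π describe nearly the same field, yet they still produce a large loss. If that matters for a new experiment, one option is to wrap the difference first; this is our suggestion, not what the benchmarks use. The NumPy sketch below uses `angle(a * conj(b))`, which equals the phase difference wrapped to (-π, π]. The function names are illustrative.

```python
import numpy as np


def mse_phase_naive(field_a, field_b):
    """Mean squared difference of the raw angles, as in the benchmark scripts."""
    return np.mean((np.angle(field_a) - np.angle(field_b)) ** 2)


def mse_phase_wrapped(field_a, field_b):
    """Mean squared phase difference, with the difference wrapped to (-pi, pi]."""
    # angle(a * conj(b)) equals angle(a) - angle(b) modulo 2*pi
    return np.mean(np.angle(field_a * np.conj(field_b)) ** 2)


# Two 2-pixel fields whose phases differ by almost exactly 2*pi at the second pixel.
a = np.exp(1j * np.array([0.0, np.pi - 1e-3]))
b = np.exp(1j * np.array([0.0, -np.pi + 1e-3]))
print(mse_phase_naive(a, b), mse_phase_wrapped(a, b))
```

Dividing the sum by `sensor_lateral_size**2`, as the scripts do, is the same as taking the mean over the N×N grid, so `np.mean` keeps the same normalization.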
/examples/discovered_solution_4f_system.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/examples/discovered_solution_4f_system.npy
--------------------------------------------------------------------------------
/examples/noisy_optimization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 🦾🤖 Noise-aware optimization scheme with ✨ XLuminA ✨: \n",
8 | "\n",
9 | "This notebook is a step-by-step guide for building a robust (noise-aware) optimization scheme in XLuminA.\n",
10 | "\n",
11 | "We will set up an optimization scheme for *the sharp focus of a radially polarized light beam*, using **robust_discovery** from **optical_elements.py**."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import os\n",
21 | "import sys\n",
22 | "\n",
23 | "# Setting the path for XLuminA modules:\n",
24 | "current_path = os.path.abspath(os.path.join('..'))\n",
25 | "module_path = os.path.join(current_path)\n",
26 | "\n",
27 | "if module_path not in sys.path:\n",
28 | " sys.path.append(module_path)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 17,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "from xlumina.__init__ import um, nm, mm\n",
38 | "from xlumina.vectorized_optics import *\n",
39 | "from xlumina.optical_elements import robust_discovery\n",
40 | "from xlumina.toolbox import space, softmin\n",
41 | "from xlumina.loss_functions import vectorized_loss_hybrid\n",
42 | "import jax.numpy as jnp\n",
43 | "import jax\n",
44 | "from jax import random, jit, vmap\n",
45 | "import time, optax, numpy as np"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "## System specs: define the light source, the output dimensions and the parameters that stay fixed during optimization:"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 5,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "# 1. System specs:\n",
62 | "sensor_lateral_size = 512 # Resolution\n",
63 | "wavelength1 = 650*nm\n",
64 | "x_total = 2500*um\n",
65 | "x, y = space(x_total, sensor_lateral_size)\n",
66 | "shape = jnp.shape(x)[0]\n",
67 | "\n",
68 | "# 2. Define the light source: a Gaussian beam with diagonal polarization (Jones vector (1, 1)):\n",
69 | "w0 = (1200*um, 1200*um) \n",
70 | "ls1 = PolarizedLightSource(x, y, wavelength1)\n",
71 | "ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))\n",
72 | "\n",
73 | "# 3. Define the output (High Resolution) detection:\n",
74 | "x_out, y_out = jnp.array(space(10*um, 400)) # Pixel size detector: 20 um / 400 pix \n",
75 | "\n",
76 | "# 4. High NA objective lens specs:\n",
77 | "NA = 0.9 \n",
78 | "radius_lens = 3.6*mm/2 \n",
79 | "f_lens = radius_lens / NA\n",
80 | "\n",
81 | "# 5. Static parameters - don't change during optimization:\n",
82 | "fixed_params = [radius_lens, f_lens, x_out, y_out]"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {},
88 | "source": [
89 | "## Define the optical setup:"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "1. Vectorized version of the optical setup [`optical_elements.py` > `robust_discovery`] over a new axis (defined by the noise).\n",
97 | "\n",
98 | " The arguments are: the light sources `ls1`-`ls6`, `parameters`, `fixed_params` and `distance_offset` (all common to every table), and the noise arrays (DIFFERENT for each table)."
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 6,
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "def batch_robust_discovery(ls1, ls2, ls3, ls4, ls5, ls6, parameters, fixed_params, noise_distances, noise_slms, noise_wps, noise_amps, distance_offset):\n",
108 | " \"\"\"\n",
109 | " Vectorized (efficient) version of robust_discovery() for batch optimization. \n",
110 | " \n",
111 | " Parameters: \n",
112 | " ls1, ls2, ls3, ls4, ls5, ls6 (PolarizedLightSource)\n",
113 | " parameters (jnp.array): parameters to pass to the optimizer\n",
114 | " BB 1: [phase1_1, phase1_2, eta1, theta1, z1_1, z1_2]\n",
115 | " BB 2: [phase2_1, phase2_2, eta2, theta2, z2_1, z2_2] \n",
116 | " BB 3: [phase3_1, phase3_2, eta3, theta3, z3_1, z3_2]\n",
117 | " BS ratios: [bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9] <- automatically contains \n",
118 | " Extra distances: [z4, z5]\n",
119 | " fixed_params (jnp.array): parameters kept fixed during optimization [r, f, xout, yout]: radius and focal length of the objective lens, and the detector coordinates.\n",
120 | " \n",
121 | " noise_distances (jnp.array): Misalignment (in microns): [noise_z1_1, noise_z1_2, ...]\n",
122 | " noise_slms (jnp.array): Noise (in radians): [noise_phase1_1, noise_phase1_2, ...]\n",
123 | " noise_wps (jnp.array): Noise (in radians): [noise_eta1, noise_theta1, ...]\n",
124 | " noise_amps (jnp.array): Noise (in AU): [noise_A1, noise_A2, ...]\n",
125 | " \n",
126 | " Returns vectorized version of detected light (intensity tensor): (# tables, (6, resolution, resolution))\n",
127 | " \"\"\"\n",
128 | " # Noise shapes are: \n",
129 | " # distances: (#tables, 8); \n",
130 | " # slms and amps: (#tables, 6, resolution, resolution); \n",
131 | " # wps: (#tables, 6)\n",
132 | " # vmap in axes 0 -> across optical tables\n",
133 | " detected_intensities_z = vmap(robust_discovery, in_axes=(None, None, None, None, None, None, None, None, \n",
134 | " 0, 0, 0, 0, None))(ls1, ls2, ls3, ls4, ls5, ls6, parameters, fixed_params, \n",
135 | " noise_distances, noise_slms, noise_wps, noise_amps, distance_offset)\n",
136 | " return detected_intensities_z"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "2. Vectorize the loss function: `vmap` the loss (`vectorized_loss_hybrid`, imported from `loss_functions.py`) across the optical tables."
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 7,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "def mean_batch_discover(detected_light):\n",
153 | " \"\"\" vmap loss over optical tables: \"\"\"\n",
154 | " # detected_light with shape (#optical tables, (6, N, N))\n",
155 | " # vmap loss in axis #optical tables \n",
156 | " return vmap(vectorized_loss_hybrid, in_axes=(0,))(detected_light)"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "3. Define the loss: first compute the detected light from the parallelized optical tables (`batch_robust_discovery`), then compute the loss for each detector on each optical table. Finally, average the loss across the optical tables and take a smooth minimum over the detectors with `softmin`."
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 8,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "@jit \n",
173 | "def loss_batch_discovery(parameters, noise_d, noise_slm, noise_wp, noise_amp):\n",
174 | " \"\"\"\n",
175 | " Loss function. It computes L= Area/I_{epsilon} across detectors.\n",
176 | " \n",
177 | " Parameters:\n",
178 | " parameters (list): Optimized parameters.\n",
179 | " noise_d (jnp.array): Misalignment (in microns): [noise_d1, noise_d2, ...]\n",
180 | " noise_slm (jnp.array): Noise (in radians): [noise_slm_1, noise_slm_2, ...]\n",
181 | " noise_wp (jnp.array): Noise (in radians): [noise_eta, noise_theta, ...]\n",
182 | " noise_amp (jnp.array): Noise (in AU): [noise_A1, noise_A2, ...]\n",
183 | "\n",
184 | " Returns the softmin over detectors of the loss averaged across optical tables.\n",
185 | " \"\"\"\n",
186 | " # Output from batch_robust_discovery is (#optical tables, (6, N, N)): for 6 detectors each\n",
187 | " detected_z_intensities = batch_robust_discovery(ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params, noise_d, noise_slm, noise_wp, noise_amp, distance_offset = 15) \n",
188 | "\n",
189 | " # Get the minimum value within loss value array\n",
190 | " # output from mean_batch_discover is (#optical tables, (6, 1)). \n",
191 | " # Compute the mean across #optical tables and get the minimum value using softmin.\n",
192 | " loss_val = softmin(jnp.mean(mean_batch_discover(detected_z_intensities), axis=0, keepdims=True))\n",
193 | " return loss_val # scalar: jax.value_and_grad in fit() needs a scalar loss"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {},
199 | "source": [
200 | "## Optimizer settings:"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 9,
206 | "metadata": {},
207 | "outputs": [],
208 | "source": [
209 | "# Global variable\n",
210 | "shape = jnp.array([sensor_lateral_size, sensor_lateral_size])\n",
211 | "# Define the loss function:\n",
212 | "loss_function = loss_batch_discovery"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 10,
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "# Optimization settings\n",
222 | "OS = {'n_best': 500,\n",
223 | " 'best_loss': 3*1e2,\n",
224 | " 'num_iterations': 50000,\n",
225 | " 'num_samples': 1,\n",
226 | " 'WEIGHT_DECAY': 1e-3,\n",
227 | " 'BASE_lr': 0.05,\n",
228 | " 'END_lr': 0.001,\n",
229 | " 'DECAY_STEPS': 4000\n",
230 | " }"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {},
236 | "source": [
237 | "## [!!] Define the noise settings dictionary:\n",
238 | "\n",
239 | "**NS (dict)**: \n",
240 | "\n",
241 | " NS = {'n_tables': __, \n",
242 | " 'number of distances': __,\n",
243 | " 'number of sSLM': __, \n",
244 | " 'number of wps': __,\n",
245 | " 'noise_level': __, \n",
246 | " 'misalignment': (minval, maxval), \n",
247 | " 'phase_noise': (minval, maxval),\n",
248 | " 'discretize': __}\n",
249 | "\n",
250 | "where, \n",
251 | "\n",
252 | "**n_tables (int)**: number of optical tables to compute in parallel\n",
253 | "\n",
254 | "**number of distances, number of sSLM, number of wps (int)**: number of distances, sSLMs and wave plates in the optical setup.\n",
255 | "\n",
256 | "**noise_level (str)**:\n",
257 | "\n",
258 | " 1. low: noise in SLMs and WPs $\\pm$(0.01 to 0.05) rads and misalignment of $\\pm$(0.01 to 0.05) mm \n",
259 | " 2. mild: noise in SLMs and WPs $\\pm$(0.05 to 0.5) rads and misalignment of $\\pm$(0.05 to 0.5) mm \n",
260 | " 3. high: noise in SLMs and WPs $\\pm$(0.5 to 1) rads and misalignment of $\\pm$(0.5 to 1) mm \n",
261 | " 4. all: noise in SLMs and WPs $\\pm$(0.01 to 1) rads and misalignment of $\\pm$(0.01 to 1) mm \n",
262 | " 5. tunable: ranges set by the `misalignment` (in um) and `phase_noise` (in rads) entries of NS\n",
263 | "\n",
264 | "**discretize (bool)**: if true, discretize SLM noise to 8-bit.\n"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 11,
270 | "metadata": {},
271 | "outputs": [],
272 | "source": [
273 | "# Noise settings:\n",
274 | "NS = {'n_tables': 3, \n",
275 | " 'number of distances': 8,\n",
276 | " 'number of sSLM': 3, \n",
277 | " 'number of wps': 3,\n",
278 | " 'noise_level': 'tunable',\n",
279 | " 'misalignment': (10, 100), \n",
280 | " 'phase_noise': (0.01, 0.1),\n",
281 | " 'discretize': False}"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {},
287 | "source": [
288 | "1. Define a keychain: one PRNG key per optical table (the number of tables is set in NS)."
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 12,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "# Keychain for optical tables\n",
298 | "def keychain_optical_tables(seed, number_of_tables):\n",
299 | " \"\"\" \n",
300 | " Generates a keychain with one PRNG key per optical table.\n",
301 | " \"\"\"\n",
302 | " keychain = []\n",
303 | " for num in range(number_of_tables):\n",
304 | " key_table = random.PRNGKey(seed + num)\n",
305 | " keychain.append(key_table)\n",
306 | " \n",
307 | " return jnp.array(keychain)"
308 | ]
309 | },
310 | {
311 | "cell_type": "markdown",
312 | "metadata": {},
313 | "source": [
314 | "2. Define `shake_setup_jit()` and `batch_shake_setup()` to resample the noise on each optical table at every iteration.\n",
315 | "\n",
316 | " `optical_elements.py` provides two shaking (noise) functions:\n",
317 | "\n",
318 | " 1. `shake_setup` takes the noise settings `NS` (dict) as an argument, so a `batch_shake_setup` built on it cannot be compiled with `@jit` or `@partial(jit)`.\n",
319 | "\n",
320 | " 2. `shake_setup_jit` reads `NS` as a global variable instead. To `@jit` `batch_shake_setup`, copy `shake_setup_jit` into your optimizer file.\n",
321 | "\n",
322 | "Here we copy `shake_setup_jit` and use `NS` as a global variable."
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 13,
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "def shake_setup_jit(key, resolution):\n",
332 | " \"\"\"\n",
333 | " [THIS FUNCTION IS INTENDED TO BE PASTED IN THE OPTIMIZER FILE TO ENABLE @jit COMPILATION FOR `batch_shake_setup`]\n",
334 | " \n",
335 | " Creates noise for all the different optical variables on an optical table.\n",
336 | " \n",
337 | " Parameters:\n",
338 | " key (PRNGKey): JAX random key for reproducibility\n",
339 | " resolution (int): number of pixels for space\n",
340 | " \n",
341 | " global variable NS (dict): noise settings as\n",
342 | "\n",
343 | " NS = {'n_tables': __, 'number of distances': __,\n",
344 | " 'number of sSLM': __, 'number of wps': __,\n",
345 | " 'noise_level': __, \n",
346 | " 'misalignment': (minval, maxval), \n",
347 | " 'phase_noise': (minval, maxval),\n",
348 | " 'discretize': __}\n",
349 | " \n",
350 | " Returns:\n",
351 | " random_noise_distances, random_noise_slms, random_noise_wps, random_noise_amps, key0 (new key to split in the next iteration), key (old key)\n",
352 | " \"\"\"\n",
353 | " num_physical_variables = 4 # number of noise sources: distances, SLM phases, WPs and amplitudes\n",
354 | " # split as many times as variables + 1 to renew the key0 each step\n",
355 | " key0, key1, key2, key3, key4 = random.split(key, num_physical_variables+1)\n",
356 | " d_type = 'int8'\n",
357 | "\n",
358 | " # NS is not an input to ensure vmap during optimization.\n",
359 | " level = NS['noise_level']\n",
360 | " discretize = NS['discretize']\n",
361 | " \n",
362 | " # level can be: 'low' == 0, 'mild' == 1, 'high' == 2, 'all' == 3 and 'tunable' == 4\n",
363 | "\n",
364 | " if level == 'low':\n",
365 | " # Misalignment (um)\n",
366 | " minval_d = 10*um # 0.01 mm\n",
367 | " maxval_d = 50*um # 0.05 mm\n",
368 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n",
369 | " minval_phase = 0.01 \n",
370 | " maxval_phase = 0.05\n",
371 | "\n",
372 | " if level == 'mild':\n",
373 | " # Misalignment (um)\n",
374 | " minval_d = 50*um # 0.05 mm\n",
375 | " maxval_d = 500*um # 0.5 mm\n",
376 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n",
377 | " minval_phase = 0.05 \n",
378 | " maxval_phase = 0.5\n",
379 | "\n",
380 | " if level == 'high':\n",
381 | " # Misalignment (um)\n",
382 | " minval_d = 500*um # 0.5 mm\n",
383 | " maxval_d = 1000*um # 1 mm\n",
384 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n",
385 | " minval_phase = 0.5\n",
386 | " maxval_phase = 1\n",
387 | "\n",
388 | " if level == 'all':\n",
389 | " # Misalignment (um)\n",
390 | " minval_d = 10*um # 0.01 mm\n",
391 | " maxval_d = 1000*um # 1 mm\n",
392 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n",
393 | " minval_phase = 0.01 \n",
394 | " maxval_phase = 1\n",
395 | "\n",
396 | " if level == 'tunable':\n",
397 | " # Misalignment (um)\n",
398 | " minval_d, maxval_d = NS['misalignment'] # in um\n",
399 | " # SLM / WP phase and amplitude (rads and AU, respectively)\n",
400 | " minval_phase, maxval_phase = NS['phase_noise']\n",
401 | "\n",
402 | " if discretize: \n",
403 | " d_type = 'uint8'\n",
404 | "\n",
405 | " # noise for distances: shape = (NS['number of distances'],)\n",
406 | " random_noise_distances = jnp.squeeze(random.uniform(key1, shape=(1, NS['number of distances']), minval=minval_d, maxval=maxval_d), axis=0) \n",
407 | " # noise for SLM phases and amplitudes: shape = (2*NS['number of sSLM'], resolution, resolution); sign and magnitude use separate keys (fold_in) so the draws are independent\n",
408 | " random_noise_amps = random.choice(random.fold_in(key4, 0), jnp.array([-1,1]), shape=(2*NS['number of sSLM'], resolution, resolution)).astype(d_type) * random.uniform(random.fold_in(key4, 1), shape=(2*NS['number of sSLM'], resolution, resolution), minval=minval_phase, maxval=maxval_phase) \n",
409 | " random_noise_slms = random.choice(random.fold_in(key2, 0), jnp.array([-1,1]), shape=(2*NS['number of sSLM'], resolution, resolution)).astype(d_type) * random.uniform(random.fold_in(key2, 1), shape=(2*NS['number of sSLM'], resolution, resolution), minval=minval_phase, maxval=maxval_phase) \n",
410 | " # noise for WP angles (eta and theta): shape = (2*NS['number of wps'],)\n",
411 | " random_noise_wps = jnp.squeeze(random.choice(random.fold_in(key3, 0), jnp.array([-1,1]), shape=(1, 2*NS['number of wps'])).astype('int8'), axis=0) * jnp.squeeze(random.uniform(random.fold_in(key3, 1), shape=(1, 2*NS['number of wps']), minval=minval_phase, maxval=maxval_phase), axis=0) \n",
412 | " \n",
413 | " return random_noise_distances, random_noise_slms, random_noise_wps, random_noise_amps, key0, key\n",
414 | "\n",
415 | "\n",
416 | "@jit\n",
417 | "def batch_shake_setup(key_array, array_for_shape):\n",
418 | " \"\"\"\n",
419 | " Creates noise for all the different optimizable variables on multiple optical tables given by size(key_array).\n",
420 | " \n",
421 | " Parameters: \n",
422 | " key_array (PRNGKey): Array with different keys -- will change for each step in the optimization. \n",
423 | " The dimension of this array is decided by # of optical tables to compute in parallel.\n",
424 | " array_for_shape (jnp.array): array of shape [resolution, resolution]; only its (static) shape is read, which keeps the function jit-compatible.\n",
425 | " \n",
426 | " Returns (in this order):\n",
427 | " random_noise_distances [with shape = (size(key_array), NS['number of distances'])], \n",
428 | " random_noise_slms [with shape = (size(key_array), 2*NS['number of sSLM'], resolution, resolution)], \n",
429 | " random_noise_wps [with shape = (size(key_array), 2*NS['number of wps'])], \n",
430 | " random_noise_amps [with shape = (size(key_array), 2*NS['number of sSLM'], resolution, resolution)], \n",
431 | " key0 (PRNGKey): array of new keys to split in the next iteration step, followed by the keys used in this step\n",
432 | " \"\"\"\n",
433 | " return vmap(shake_setup_jit, in_axes = (0, None))(key_array, jnp.shape(array_for_shape)[0])"
434 | ]
435 | },
436 | {
437 | "cell_type": "markdown",
438 | "metadata": {},
439 | "source": [
440 | "## Define the optimizer (AdamW; learning-rate schedule optional)"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": 14,
446 | "metadata": {},
447 | "outputs": [],
448 | "source": [
449 | "def adamw_schedule(base_lr, end_lr, decay_steps, weight_decay) -> optax.GradientTransformation:\n",
450 | " \"\"\"\n",
451 | " Custom optimizer: AdamW with learning rate base_lr.\n",
452 | " The linear schedule from base_lr to end_lr (commented out below) is currently disabled,\n",
453 | " so end_lr and decay_steps are unused.\n",
454 | " \"\"\"\n",
455 | " lr_schedule = base_lr\n",
456 | " #lr_schedule = optax.linear_schedule(init_value= base_lr, end_value = end_lr, transition_steps = decay_steps, transition_begin = 500) \n",
457 | " return optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay)"
458 | ]
459 | },
460 | {
461 | "cell_type": "markdown",
462 | "metadata": {},
463 | "source": [
464 | "## Optimization loop: "
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": 15,
470 | "metadata": {},
471 | "outputs": [],
472 | "source": [
473 | "def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations, keys, x) -> optax.Params:\n",
474 | " \n",
475 | " # Init the optimizer with initial parameters\n",
476 | " opt_state = optimizer.init(params)\n",
477 | "\n",
478 | " @jit\n",
479 | " def update(parameters, opt_state, noise_d, noise_slm, noise_wp, noise_amp):\n",
480 | " # Define single update step - contains noise_array: \n",
481 | " loss_value, grads = jax.value_and_grad(loss_function)(parameters, noise_d, noise_slm, noise_wp, noise_amp)\n",
482 | " # Update the state of the optimizer\n",
483 | " updates, state = optimizer.update(grads, opt_state, parameters)\n",
484 | " # Update the parameters\n",
485 | " new_params = optax.apply_updates(parameters, updates)\n",
486 | " \n",
487 | " return new_params, parameters, state, loss_value, updates\n",
488 | "\n",
489 | " # Initialize some parameters \n",
490 | " n_best = OS['n_best']\n",
491 | " best_loss = OS['best_loss']\n",
492 | " best_params = None\n",
493 | " best_keys = None\n",
494 | " best_step = 0\n",
495 | " \n",
496 | " print('Starting Optimization', flush=True)\n",
497 | " \n",
498 | " for step in range(num_iterations):\n",
499 | " \n",
500 | " # Add noise: update noise and keys each iteration\n",
501 | " noise_d, noise_slm, noise_wp, noise_amp, keys, old_keys = batch_shake_setup(keys, x) # 'x' is the space variable from optical table\n",
502 | " \n",
503 | " # Apply update step\n",
504 | " params, old_params, opt_state, loss_value, updates = update(params, opt_state, noise_d, noise_slm, noise_wp, noise_amp)\n",
505 | " \n",
506 | " print(f\"Step {step}\")\n",
507 | " print(f\"Loss {loss_value}\")\n",
508 | " \n",
509 | " # Update the `best_loss` value:\n",
510 | " if loss_value < best_loss:\n",
511 | " # Best loss value\n",
512 | " best_loss = loss_value\n",
513 | " # Best optimized parameters\n",
514 | " best_params = old_params\n",
515 | " # Keys for best params\n",
516 | " best_keys = old_keys\n",
517 | " best_step = step\n",
518 | " print('Best loss value is updated')\n",
519 | "\n",
520 | " if step % 100 == 0:\n",
521 | " # Stopping criterion (checked every 100 steps): stop if best_loss has not improved for more than n_best steps.\n",
522 | " if step - best_step > n_best:\n",
523 | " print(f'Stopping criterion: no improvement in loss value for {n_best} steps')\n",
524 | " break\n",
525 | " \n",
526 | " print(f'Best loss: {best_loss} at step {best_step}')\n",
527 | " print(f'Best parameters: {best_params}') \n",
528 | " return best_params, best_loss, best_keys"
529 | ]
530 | },
531 | {
532 | "cell_type": "code",
533 | "execution_count": null,
534 | "metadata": {},
535 | "outputs": [],
536 | "source": [
537 | "# Optimizer settings\n",
538 | "num_iterations = OS['num_iterations']\n",
539 | "num_samples = OS['num_samples']\n",
540 | "\n",
541 | "for i in range(num_samples):\n",
542 | " tic = time.perf_counter()\n",
543 | " \n",
544 | " # seed1 to ensure randomness among samples\n",
545 | " seed1 = np.random.randint(9999)\n",
546 | " \n",
547 | " # Init keychain for noise -- as many init keys as optical tables to parallelize\n",
548 | " keys = keychain_optical_tables(seed1, NS['n_tables'])\n",
549 | " \n",
550 | " # Optimizer settings\n",
551 | " WEIGHT_DECAY = OS['WEIGHT_DECAY']\n",
552 | " BASE_lr = OS['BASE_lr']\n",
553 | " END_lr = OS['END_lr']\n",
554 | " DECAY_STEPS = OS['DECAY_STEPS']\n",
555 | " \n",
556 | " # Random init parameters:\n",
557 | " phase1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
558 | " phase1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
559 | " a1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
560 | " a1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
561 | " \n",
562 | " phase2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
563 | " phase2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
564 | " a2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
565 | " a2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
566 | " \n",
567 | " phase3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
568 | " phase3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
569 | " a3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
570 | " a3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]\n",
571 | " \n",
572 | " eta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
573 | " theta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
574 | " eta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
575 | " theta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
576 | " eta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
577 | " theta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
578 | " eta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
579 | " theta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]\n",
580 | " \n",
581 | " z1_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
582 | " z1_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
583 | " z2_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
584 | " z2_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
585 | " z3_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
586 | " z3_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
587 | " z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
588 | " z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
589 | " \n",
590 | " bs1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
591 | " bs2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
592 | " bs3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
593 | " bs4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
594 | " bs5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
595 | " bs6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
596 | " bs7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
597 | " bs8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
598 | " bs9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)\n",
599 | " \n",
600 | " # Init params for 3x3 robust discovery\n",
601 | " init_params = [phase1_1, phase1_2, a1_1, a1_2, eta1, theta1, z1_1, z1_2, \n",
602 | " phase2_1, phase2_2, a2_1, a2_2, eta2, theta2, z2_1, z2_2, \n",
603 | " phase3_1, phase3_2, a3_1, a3_2, eta3, theta3, z3_1, z3_2, \n",
604 | " bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, \n",
605 | " z4, z5]\n",
606 | " \n",
607 | " # Init optimizer:\n",
608 | " optimizer = adamw_schedule(BASE_lr, END_lr, DECAY_STEPS, WEIGHT_DECAY)\n",
609 | "\n",
610 | " # Apply fit function:\n",
611 | " best_params, best_loss, keys_noise = fit(init_params, optimizer, num_iterations, keys, x) # fit returns 3 values\n",
612 | " \n",
613 | " print(f\"Time taken to optimize one sample - in seconds {(time.perf_counter() - tic):.4f}\")"
614 | ]
615 | },
616 | {
617 | "cell_type": "code",
618 | "execution_count": null,
619 | "metadata": {},
620 | "outputs": [],
621 | "source": []
622 | }
623 | ],
624 | "metadata": {
625 | "kernelspec": {
626 | "display_name": "dummy_env",
627 | "language": "python",
628 | "name": "python3"
629 | },
630 | "language_info": {
631 | "codemirror_mode": {
632 | "name": "ipython",
633 | "version": 3
634 | },
635 | "file_extension": ".py",
636 | "mimetype": "text/x-python",
637 | "name": "python",
638 | "nbconvert_exporter": "python",
639 | "pygments_lexer": "ipython3",
640 | "version": "3.11.10"
641 | }
642 | },
643 | "nbformat": 4,
644 | "nbformat_minor": 2
645 | }
646 |
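Each `noise_level` preset in the notebook's NS description amounts to a random ± sign times a magnitude drawn uniformly from a range. The exception is misalignment, which `shake_setup_jit` draws without a sign. Below is a NumPy sketch of that sampling for the preset ranges. The `LEVELS` table and the `sample_signed_noise` helper are illustrative and are not XLuminA's `shake_setup`.

```python
import numpy as np

# (min, max) misalignment in mm and (min, max) phase noise in rad, as listed in the NS description.
LEVELS = {
    "low": ((0.01, 0.05), (0.01, 0.05)),
    "mild": ((0.05, 0.5), (0.05, 0.5)),
    "high": ((0.5, 1.0), (0.5, 1.0)),
    "all": ((0.01, 1.0), (0.01, 1.0)),
}


def sample_signed_noise(rng, shape, minval, maxval):
    """Random sign (+1 or -1) times a magnitude drawn uniformly from [minval, maxval)."""
    sign = rng.choice(np.array([-1.0, 1.0]), size=shape)
    magnitude = rng.uniform(minval, maxval, size=shape)
    return sign * magnitude


rng = np.random.default_rng(seed=0)
(d_min, d_max), (p_min, p_max) = LEVELS["mild"]
n_tables, n_slm_arrays, resolution = 3, 6, 8  # 6 = 2 * number of sSLM in the notebook
slm_noise = sample_signed_noise(rng, (n_tables, n_slm_arrays, resolution, resolution), p_min, p_max)
distance_noise = rng.uniform(d_min, d_max, size=(n_tables, 8))  # unsigned, like the distances in shake_setup_jit
print(slm_noise.shape)  # (3, 6, 8, 8)
```

A NumPy `Generator` advances its state on every call, so the sign and magnitude draws are independent. In JAX, the same independence requires a distinct key for each draw, obtained from `random.split` or `random.fold_in`.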
--------------------------------------------------------------------------------
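The notebook's `adamw_schedule` keeps an `optax.linear_schedule(init_value=base_lr, end_value=end_lr, transition_steps=decay_steps, transition_begin=500)` call commented out. Below is a pure-Python sketch of the learning rate we would expect that schedule to produce. It assumes optax's documented linear-schedule behaviour: constant before `transition_begin`, linear over `transition_steps`, then constant at `end_value`. `linear_lr` is a hypothetical helper, not part of optax.

```python
def linear_lr(step, init_value, end_value, transition_steps, transition_begin=0):
    """Learning rate at `step` for a linear schedule from init_value to end_value."""
    if transition_steps <= 0:
        return init_value
    # Steps spent inside the transition window, clipped to [0, transition_steps].
    count = min(max(step - transition_begin, 0), transition_steps)
    frac = count / transition_steps
    return init_value + frac * (end_value - init_value)


# Values from the notebook's OS dictionary and the commented-out call.
BASE_lr, END_lr, DECAY_STEPS = 0.05, 0.001, 4000
for step in (0, 500, 2500, 4500, 10000):
    print(step, round(linear_lr(step, BASE_lr, END_lr, DECAY_STEPS, transition_begin=500), 6))
```

With the OS values, the rate stays at 0.05 for the first 500 steps, reaches 0.0255 at step 2500, and settles at 0.001 from step 4500 on.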
/examples/numerical_methods_evaluation_diffractio.py:
--------------------------------------------------------------------------------
1 | from diffractio import np, degrees, um, mm, nm
2 | from diffractio.scalar_sources_XY import Scalar_source_XY
3 | from diffractio.scalar_fields_XY import Scalar_field_XY
4 | import time
5 |
6 | """ Evaluates the convergence and gradient evaluation times for Diffractio + numerical optimization """
7 |
8 | # Light source settings
9 | wavelength = 650*nm # same wavelength as autodiff_evaluation_xlumina.py
10 | w0 = (1200*um, 1200*um)
11 | sensor_lateral_size = 10 # Resolution
12 |
13 | x_total = 1000*um
14 | x = np.linspace(-x_total , x_total , sensor_lateral_size)
15 | y = np.linspace(-x_total , x_total , sensor_lateral_size)
16 | gb = Scalar_source_XY(x, y, wavelength, info='Light source')
17 | gb.gauss_beam(r0=(0 * um, 0 * um), w0=w0, z0=(0,0), A=1, theta=0 * degrees, phi=0 * degrees)
18 |
19 | # Spiral phase for ground truth
20 | gb_gt = Scalar_field_XY(x, y, wavelength)
21 | phase_mask = np.arctan2(gb.Y,gb.X)
22 | gb_gt.u = gb.u * np.exp(1j * phase_mask)
23 |
24 | # Optical setup
25 | def setup(gb, parameters):
26 | gb_propagated = gb.RS(25000*mm)
27 | gb_modulated, _ = npSLM(gb_propagated.u, parameters, gb.x.shape[0])
28 | return gb_modulated
29 |
30 | def mse_phase(input_light, target_light):
31 | return np.sum((np.angle(input_light) - np.angle(target_light.u)) ** 2) / sensor_lateral_size**2
32 |
33 | def loss(parameters_flat):
34 | parameters = parameters_flat.reshape(sensor_lateral_size,sensor_lateral_size)
35 | out = setup(gb, parameters)
36 | loss_val = mse_phase(out, gb_gt)
37 | return loss_val
38 |
39 | def phase(phase):
40 | return np.exp(1j * phase)
41 |
42 | def npSLM(input_field, phase_array, shape):
43 | slm = np.fromfunction(lambda i, j: phase(phase_array[i, j]),
44 | (shape, shape), dtype=int)
45 | light_out = input_field * slm # Multiplies element-wise
46 | return light_out, slm
47 |
48 | from scipy.optimize import minimize
49 |
50 |
51 | results = []
52 | times = []
53 |
54 | parameters = np.random.uniform(-np.pi, np.pi, (sensor_lateral_size, sensor_lateral_size)).flatten()
55 | tic = time.perf_counter()
56 |
57 | res = minimize(loss, parameters, method='BFGS', options={'disp': True})
58 |
59 | time_to_conv = time.perf_counter() - tic
60 | times.append(time_to_conv)
61 | results.append(res)
62 |
63 | # `res` is stored so that the total time can be divided by BFGS's `njev` (number of gradient evaluations) to obtain the time per gradient evaluation.
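Normalizing the wall-clock time by the number of gradient evaluations can be sketched with SciPy's `minimize` on its own. This is a minimal sketch, assuming a hypothetical quadratic `toy_loss` in place of the Diffractio-based `loss`. Because no `jac` is passed, BFGS approximates every gradient by finite differences, so each of the `res.njev` gradient evaluations costs roughly one extra loss call per parameter (`res.nfev` counts all loss calls).

```python
import time

import numpy as np
from scipy.optimize import minimize


def toy_loss(params_flat):
    """Hypothetical stand-in for the Diffractio loss: a quadratic in the phase parameters."""
    return float(np.sum((params_flat - 0.5) ** 2))


rng = np.random.default_rng(0)
x0 = rng.uniform(-np.pi, np.pi, 16)  # 16 parameters, e.g. a 4x4 phase mask

tic = time.perf_counter()
res = minimize(toy_loss, x0, method='BFGS')
total_time = time.perf_counter() - tic

# Average wall-clock cost of one (finite-difference) gradient evaluation.
time_per_gradient = total_time / res.njev
print(f"njev={res.njev}, nfev={res.nfev}, time per gradient: {time_per_gradient:.2e} s")
```

With an autodiff framework the same division gives a like-for-like comparison, since there each gradient is one forward plus one backward pass instead of one loss call per parameter.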
--------------------------------------------------------------------------------
/examples/scalar_diffractio.py:
--------------------------------------------------------------------------------
1 | from diffractio import np, degrees, um, mm
2 | from diffractio.scalar_sources_XY import Scalar_source_XY
3 | from diffractio.scalar_fields_XY import Scalar_field_XY
4 | import time
5 |
6 | """ Computes the running times for the scalar versions of the Rayleigh-Sommerfeld (RS) and Chirped z-transform (CZT) algorithms using Diffractio """
7 |
8 | # Light source settings
9 | wavelength = .6328 * um
10 | w0 = (1200*um , 1200*um)
11 | x = np.linspace(-15 * mm, 15 * mm, 2048)
12 | y = np.linspace(-15 * mm, 15 * mm, 2048)
13 | x_out = np.linspace(-15 * mm, 15 * mm, 2048)
14 | y_out = np.linspace(-15 * mm, 15 * mm, 2048)
15 |
16 | ls = Scalar_source_XY(x, y, wavelength, info='Light source')
17 | ls.gauss_beam(r0=(0 * um, 0 * um), w0=w0, z0=(0,0), A=1, theta=0 * degrees, phi=0 * degrees)
18 |
19 |
20 | time_RS_diffractio = []
21 | time_CZT_diffractio = []
22 |
23 | for i in range(101):
24 | tic = time.perf_counter()
25 | ls_propagated = ls.RS(z=5*mm)
26 | time_RS_diffractio.append(time.perf_counter() - tic)
27 |
28 | for i in range(101):
29 | tic = time.perf_counter()
30 | ls_propagated = ls.CZT(z=5*mm, xout=x_out, yout=y_out, verbose=False)
31 | time_CZT_diffractio.append(time.perf_counter() - tic)
32 |
33 | filename = f"Scalar_propagation_Diffractio.npy"
34 | np.save(filename, {"RS_Diffractio": time_RS_diffractio, "CZT_Diffractio": time_CZT_diffractio})
--------------------------------------------------------------------------------
/examples/scalar_xlumina.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # Setting the path for XLuminA modules:
4 | current_path = os.path.abspath(os.path.join('..'))
5 | module_path = os.path.join(current_path)
6 | if module_path not in sys.path:
7 |     sys.path.append(module_path)
8 |
9 | from xlumina.__init__ import um, mm
10 | from xlumina.wave_optics import *
11 | from xlumina.vectorized_optics import *
12 | from xlumina.optical_elements import *
13 | from xlumina.loss_functions import *
14 | from xlumina.toolbox import space
15 | import jax.numpy as jnp
16 | import numpy as np
17 | import time
18 |
19 | """ Computes the running times for the scalar versions of the Rayleigh-Sommerfeld (RS) and Chirped z-transform (CZT) algorithms using XLuminA """
20 |
21 | # Light source settings
22 | resolution = 2048
23 | wavelength = .6328 * um
24 | w0 = (1200*um , 1200*um)
25 | x, y = space(15*mm, resolution)
26 | x_out, y_out = jnp.array(space(15*mm, resolution))
27 |
28 | gb = LightSource(x, y, wavelength)
29 | gb.gaussian_beam(w0=w0, E0=1)
30 |
31 | # Rayleigh-Sommerfeld:
32 | tic = time.perf_counter()
33 | ls_propagated = gb.RS_propagation(z=5*mm)
34 | print("Time taken for 1st RS propagation - in seconds", time.perf_counter() - tic)
35 |
36 | time_RS_xlumina = []
37 | time_CZT_xlumina = []
38 |
39 | for i in range(101):
40 | tic = time.perf_counter()
41 | ls_propagated = gb.RS_propagation(z=5*mm)
42 | t = time.perf_counter() - tic
43 | time_RS_xlumina.append(t)
44 |
45 |
46 | # Chirped z-transform:
47 | tic = time.perf_counter()
48 | ls_propagated = gb.CZT(z=5*mm, xout=x_out, yout=y_out)
49 | print("Time taken for 1st CZT propagation - in seconds", time.perf_counter() - tic)
50 |
51 | for i in range(101):
52 | tic = time.perf_counter()
53 | ls_propagated = gb.CZT(z=5*mm, xout=x_out, yout=y_out)
54 | t = time.perf_counter() - tic
55 | time_CZT_xlumina.append(t)
56 |
57 | filename = f"Scalar_propagation_xlumina_GPU.npy"
58 | np.save(filename, {"RS_xlumina": time_RS_xlumina, "CZT_xlumina": time_CZT_xlumina})
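Timing JAX code has two pitfalls. The first call includes tracing and XLA compilation, which is why it is timed separately above. JAX also dispatches work asynchronously, so `time.perf_counter()` can stop before the device has finished unless the result is synchronized, e.g. with `block_until_ready()`. Whether `RS_propagation` and `CZT` are jit-compiled internally is an assumption here. The sketch below shows the pattern with a hypothetical FFT-based `toy_propagate`, not XLuminA's RS implementation.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_propagate(field, transfer):
    """Hypothetical propagation step: multiply the field spectrum by a transfer function."""
    return jnp.fft.ifft2(jnp.fft.fft2(field) * transfer)


n = 256
field = jnp.ones((n, n), dtype=jnp.complex64)
transfer = jnp.exp(1j * jnp.linspace(0.0, 1.0, n * n).reshape(n, n)).astype(jnp.complex64)

# First call: tracing + compilation + execution, reported separately.
tic = time.perf_counter()
toy_propagate(field, transfer).block_until_ready()
compile_and_run = time.perf_counter() - tic

# Steady state: the compiled kernel is reused; block_until_ready() makes the
# clock stop only after the asynchronously dispatched computation has finished.
times = []
for _ in range(10):
    tic = time.perf_counter()
    toy_propagate(field, transfer).block_until_ready()
    times.append(time.perf_counter() - tic)
```

Without the synchronization, the loop may mostly measure dispatch overhead rather than the propagation itself.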
--------------------------------------------------------------------------------
/examples/vectorial_diffractio.py:
--------------------------------------------------------------------------------
1 | from diffractio import np, degrees, um, mm
2 | from diffractio.scalar_sources_XY import Scalar_source_XY
3 | from diffractio.scalar_fields_XY import Scalar_field_XY
4 | from diffractio.vector_sources_XY import Vector_source_XY
5 | from diffractio.vector_fields_XY import Vector_field_XY
6 | import time
7 |
8 | """ Computes the running times for the vectorial versions of the Rayleigh-Sommerfeld (VRS) and Chirped z-transform (VCZT) algorithms using Diffractio """
9 |
10 | # Light source settings
11 | wavelength = .6328 * um
12 | w0 = (1200*um , 1200*um)
13 | x = np.linspace(-15 * mm, 15 * mm, 2048)
14 | y = np.linspace(-15 * mm, 15 * mm, 2048)
15 | x_out = np.linspace(-15 * mm, 15 * mm, 2048)
16 | y_out = np.linspace(-15 * mm, 15 * mm, 2048)
17 |
18 |
19 | ls = Scalar_source_XY(x, y, wavelength, info='Light source')
20 | ls.gauss_beam(r0=(0 * um, 0 * um), w0=w0, z0=(0,0), A=1, theta=0 * degrees, phi=0 * degrees)
21 |
22 | vls = Vector_source_XY(x, y, wavelength=wavelength, info='Light source polarization')
23 | vls.constant_polarization(u=ls, v=(1, 0), has_normalization=False, radius=(15*mm, 15*mm))
24 |
25 | time_VRS_diffractio = []
26 | time_VCZT_diffractio = []
27 |
28 | for i in range(101):
29 | tic = time.perf_counter()
30 | vls.VRS(z=5*mm, n=1, new_field=False, verbose=False, amplification=(1, 1))
31 | time_VRS_diffractio.append(time.perf_counter() - tic)
32 |
33 | for i in range(101):
34 | tic = time.perf_counter()
35 | vls.CZT(z=5*mm, xout=x_out, yout=y_out, verbose=False)
36 | time_VCZT_diffractio.append(time.perf_counter() - tic)
37 |
38 | filename = f"vectorial_propagation_Diffractio.npy"
39 | np.save(filename, {"VRS_Diffractio": time_VRS_diffractio, "VCZT_Diffractio": time_VCZT_diffractio})
--------------------------------------------------------------------------------
/examples/vectorial_xlumina.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # Setting the path for XLuminA modules:
4 | current_path = os.path.abspath(os.path.join('..'))
5 | module_path = os.path.join(current_path)
6 | if module_path not in sys.path:
7 |     sys.path.append(module_path)
8 | from xlumina.__init__ import um, mm
9 | from xlumina.wave_optics import *
10 | from xlumina.vectorized_optics import *
11 | from xlumina.optical_elements import *
12 | from xlumina.loss_functions import *
13 | from xlumina.toolbox import space
14 | import jax.numpy as jnp
15 | import numpy as np
16 | import time
17 |
18 | """ Computes the running times for the vectorial versions of the Rayleigh-Sommerfeld (VRS) and Chirped z-transform (VCZT) algorithms using XLuminA """
19 |
20 | # Light source settings
21 | resolution = 2048
22 | wavelength = .6328 * um
23 | w0 = (1200*um , 1200*um)
24 | x, y = space(15*mm, resolution)
25 | x_out, y_out = jnp.array(space(15*mm, resolution))
26 |
27 | gb_lp = PolarizedLightSource(x, y, wavelength)
28 | gb_lp.gaussian_beam(w0=w0, jones_vector=(1, 0))
29 |
30 | # Rayleigh-Sommerfeld:
31 | tic = time.perf_counter()
32 | gb_propagated = gb_lp.VRS_propagation(z=5*mm)
33 | print("Time taken for 1st VRS propagation - in seconds", time.perf_counter() - tic)
34 |
35 | time_VRS_xlumina = []
36 | time_VCZT_xlumina = []
37 |
38 | for i in range(101):
39 | tic = time.perf_counter()
40 | gb_propagated = gb_lp.VRS_propagation(z=5*mm)
41 | time_VRS_xlumina.append(time.perf_counter() - tic)
42 |
43 | # Chirped z-transform:
44 | tic = time.perf_counter()
45 | gb_propagated = gb_lp.VCZT(z=5*mm, xout=x_out, yout=y_out)
46 | print("Time taken for 1st VCZT propagation - in seconds", time.perf_counter() - tic)
47 |
48 | for i in range(101):
49 | tic = time.perf_counter()
50 | gb_propagated = gb_lp.VCZT(z=5*mm, xout=x_out, yout=y_out)
51 | time_VCZT_xlumina.append(time.perf_counter() - tic)
52 |
53 | filename = f"vectorial_propagation_xlumina_GPU.npy"
54 | np.save(filename, {"VRS_xlumina": time_VRS_xlumina, "VCZT_xlumina": time_VCZT_xlumina})
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/experiments/__init__.py
--------------------------------------------------------------------------------
/experiments/four_f_optical_table.py:
--------------------------------------------------------------------------------
1 | # Standard-library imports used to set up the module path:
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | from xlumina.__init__ import um, nm, cm
13 | from xlumina.wave_optics import *
14 | from xlumina.optical_elements import SLM
15 | from xlumina.toolbox import space
16 | from jax import vmap, jit
17 | import jax.numpy as jnp
18 |
19 | """
20 | OPTICAL TABLE WITH AN OPTICAL TELESCOPE (4F-SYSTEM).
21 | """
22 |
23 | # 1. System specs:
24 | sensor_lateral_size = 1024 # Resolution
25 | wavelength = 632.8*nm
26 | x_total = 1500*um
27 | x, y = space(x_total, sensor_lateral_size)
28 | shape = jnp.shape(x)[0]
29 |
30 | # 2. Define the light source:
31 | w0 = (1200*um , 1200*um)
32 | input_light = LightSource(x, y, wavelength)
33 | input_light.gaussian_beam(w0=w0, E0=1)
34 |
35 | # 3. Define the optical functions:
36 | def batch_dualSLM_4f(input_mask, x, y, wavelength, parameters):
37 | """
38 | [4f system coded exclusively for batch optimization purposes].
39 |
40 |     Defines an optical table with a 4f system composed of two SLMs (to be used with ScalarLight).
41 |
42 | Illustrative scheme:
43 |     U(x,y) --> input_mask --> Propagate: RS(z1) --> SLM(phase1) --> Propagate: RS(z2) --> SLM(phase2) --> Propagate: RS(z3) --> Detect
44 |
45 | Parameters:
46 | input_mask (jnp.array): Input mask, comes in the form of an array
47 | parameters (list): Parameters to pass to the optimizer [z1, z2, z3, phase1 and phase2] for RS propagation and the two SLMs.
48 |
49 |     Returns the intensity (jnp.array) after the final propagation to the detector, and the phase masks slm_1 and slm_2.
50 |
51 |     + The optimizer works with parameters in (0, 1); they are mapped back to physical values here [the offset is determined by .get_RS_minimum_z() for the corresponding pixel resolution].
52 |       Distances (in cm): z = |p| * 100 + offset, i.e. (0, 1) -> (offset, 100 + offset) cm.
53 |       Phases (in radians): phi = p * 2*pi - pi, i.e. (0, 1) -> (-pi, pi).
54 | """
55 | global shape
56 |
57 | # From get_RS_minimum_z()
58 | offset = 1.2
59 |
60 | # Apply input mask (comes from vmap)
61 | input_light.field = input_light.field * input_mask
62 |
63 | """ Stage 0: Propagation """
64 | # Propagate light from mask
65 | light_stage0, _ = input_light.RS_propagation(z=(jnp.abs(parameters[0]) * 100 + offset)*cm)
66 |
67 | """ Stage 0: Modulation """
68 |     # Feed SLM_1 with parameters[3] and apply the mask to the forward beam
69 | modulated_slm1, slm_1 = SLM(light_stage0, parameters[3] * (2*jnp.pi) - jnp.pi, shape)
70 |
71 | """ Stage 1: Propagation """
72 | # Propagate the SLM_1 output beam to another distance z
73 | light_stage1, _ = modulated_slm1.RS_propagation(z=(jnp.abs(parameters[1]) * 100 + offset)*cm)
74 |
75 | """ Stage 1: Modulation """
76 |     # Feed SLM_2 with parameters[4] and apply the mask to the forward beam
77 | modulated_slm2, slm_2 = SLM(light_stage1, parameters[4] * (2*jnp.pi) - jnp.pi, shape)
78 |
79 | """ Stage 2: Propagation """
80 | # Propagate the SLM_2 output beam to another distance z
81 | fw_to_detector, _ = modulated_slm2.RS_propagation(z=(jnp.abs(parameters[2]) * 100 + offset)*cm)
82 |
83 | return jnp.abs(fw_to_detector.field)**2, slm_1, slm_2
84 |
85 | def vector_dualSLM_4f_system(input_masks, x, y, wavelength, parameters):
86 | """
87 | [Coded exclusively for the batch optimization].
88 |
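89 |     Vectorized version of the 4f system for batch optimization: batch_dualSLM_4f is mapped with jax.vmap over the first axis of input_masks, while x, y, wavelength and parameters are shared.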
90 |
91 | Parameters:
92 | input_masks (jnp.array): Array with input masks
93 | x, y, wavelength (jnp.arrays and float): Light specs to pass to batch_dualSLM_4f.
94 | parameters (list): Parameters to pass to the optimizer [z1, z2, z3, phase1 and phase2] for RS propagation and the two SLMs.
95 |
96 | Returns vectorized version of detected light (intensity).
97 | """
98 | detected_intensity, _, _ = vmap(batch_dualSLM_4f, in_axes=(0, None, None, None, None))(input_masks, x, y, wavelength, parameters)
99 | return detected_intensity
100 |
101 |
102 | # 4. Define the loss function for batch optimization.
103 | def loss_dualSLM(parameters, input_masks, target_intensities):
104 | """
105 |     Loss function for 4f system batch optimization. It computes the MSE between the detected and the target intensities for every sample in the batch.
106 |
107 | Parameters:
108 | parameters (list): Optimized parameters.
109 | input_masks (jnp.array): Array with input masks.
110 | target_intensities (jnp.array): Array with target intensities.
111 |
112 | Returns the mean value of the loss computed for all the inputs.
113 | """
114 | global x, y, wavelength
115 | # Input fields and target fields are arrays with synthetic data. Global variables defined in the optical table script.
116 | optimized_intensities = vector_dualSLM_4f_system(input_masks, x, y, wavelength, parameters)
117 | mean_loss, loss_array = mean_batch_MSE_Intensity(optimized_intensities, target_intensities)
118 | return mean_loss
119 |
120 | def mean_batch_MSE_Intensity(optimized, target):
121 | """
122 | [Computed for batch optimization in 4f system]. Vectorized version of MSE_Intensity.
123 |
124 |     Returns the mean value of all the MSE for each (optimized, target) pair and a jnp.array with the MSE value of each pair.
125 | """
126 | MSE = vmap(MSE_Intensity, in_axes=(0, 0))(optimized, target)
127 | return jnp.mean(MSE), MSE
128 |
129 | @jit
130 | def MSE_Intensity(input_light, target_light):
131 | """
132 | Computes the Mean Squared Error (in Intensity) for a given electric field component Ex, Ey or Ez.
133 |
134 | Parameters:
135 | input_light (array): intensity: input_light = jnp.abs(input_light.field)**2
136 | target_light (array): Ground truth - intensity in the detector: target_light = jnp.abs(target_light.field)**2
137 |
138 | Returns the MSE (jnp.array).
139 | """
140 | num_pix = jnp.shape(input_light)[0] * jnp.shape(input_light)[1]
141 | return jnp.sum((input_light - target_light)** 2) / num_pix
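The parameter conversions described in the `batch_dualSLM_4f` docstring can be sketched in isolation. The helper names below are hypothetical. The 1.2 cm offset is the value hard-coded in that function, and the distance is returned as a plain number of centimeters instead of being multiplied by XLuminA's `cm` unit.

```python
import jax.numpy as jnp

# Offset (in cm) hard-coded in batch_dualSLM_4f; per its comment, it comes from get_RS_minimum_z().
OFFSET_CM = 1.2


def to_distance_cm(p, offset=OFFSET_CM):
    """Map an optimizer parameter to a propagation distance in cm: (0, 1) -> (offset, 100 + offset)."""
    # jnp.abs folds negative values back, so distances never drop below `offset`.
    return jnp.abs(p) * 100 + offset


def to_phase_rad(p):
    """Map an optimizer parameter to an SLM phase in radians: (0, 1) -> (-pi, pi)."""
    return p * (2 * jnp.pi) - jnp.pi


def phase_to_param(phi):
    """Hypothetical inverse of to_phase_rad, e.g. to initialize from a known phase mask."""
    return (phi + jnp.pi) / (2 * jnp.pi)


params = jnp.array([0.0, 0.5, 1.0])
distances_cm = to_distance_cm(params)  # approx. [1.2, 51.2, 101.2]
phases = to_phase_rad(params)          # approx. [-pi, 0, pi]
```

Keeping the optimizer in the unit interval lets one learning rate serve both distances and phases, which is presumably why the script uses this normalization.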
--------------------------------------------------------------------------------
/experiments/four_f_optimizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # Setting the path for XLuminA modules:
5 | current_path = os.path.abspath(os.path.join('..'))
6 | module_path = os.path.join(current_path)
7 |
8 | if module_path not in sys.path:
9 | sys.path.append(module_path)
10 |
11 | from four_f_optical_table import *
12 | from xlumina.toolbox import MultiHDF5DataLoader
13 | import time
14 | import jax
15 | import optax
16 | from jax import jit
17 | import numpy as np
18 | import jax.numpy as jnp
19 |
20 | """
21 | OPTIMIZER FOR THE OPTICAL TELESCOPE (4F-SYSTEM).
22 | """
23 |
24 | # Print device info (GPU or CPU)
25 | print(jax.devices(), flush=True)
26 |
27 | # Call the data loader and set batchsize
28 | dataloader = MultiHDF5DataLoader("training_data_4f", batch_size=10)
29 |
30 | # Define the loss function and compute its gradients:
31 | loss_function = jit(loss_dualSLM)
32 |
33 | # ----------------------------------------------------
34 |
35 | def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations) -> optax.Params:
36 |
37 | opt_state = optimizer.init(params)
38 |
39 | @jit
40 | def update(params, opt_state, input_fields, target_fields):
41 | # Define single update step:
42 |         # Compute the loss and its gradients w.r.t. params (loss_function is already jitted)
43 | loss_value, grads = jax.value_and_grad(loss_function, allow_int=True)(params, input_fields, target_fields)
44 | # Update the state of the optimizer
45 | updates, opt_state = optimizer.update(grads, opt_state, params)
46 | params = optax.apply_updates(params, updates)
47 | return params, opt_state, loss_value
48 |
49 | # Initialize some parameters
50 | iteration_steps=[]
51 | loss_list=[]
52 |
53 | # Optimizer settings
54 | n_best = 500
55 | best_loss = 1e2
56 | best_params = None
57 | best_step = 0
58 |
59 | print('Starting Optimization', flush=True)
60 |
61 | for step in range(num_iterations):
62 | # Load data:
63 | input_fields, target_fields = next(dataloader)
64 | params, opt_state, loss_value = update(params, opt_state, input_fields, target_fields)
65 |
66 | print(f"Step {step}")
67 | print(f"Loss {loss_value}")
68 |
69 | iteration_steps.append(step)
70 | loss_list.append(loss_value)
71 |
72 | # Update the `best_loss` value:
73 | if loss_value < best_loss:
74 | # Best loss value
75 | best_loss = loss_value
76 | # Best optimized parameters
77 | best_params = params
78 | best_step = step
79 | print('Best loss value is updated')
80 |
81 | if step % 100 == 0:
82 |             # Stopping criterion (checked every 100 steps): stop if best_loss has not improved for more than n_best steps.
83 | if step - best_step > n_best:
84 | print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
85 | break
86 |
87 | print(f'Best loss: {best_loss} at step {best_step}')
88 | print(f'Best parameters: {best_params}')
89 | return best_params, best_loss, iteration_steps, loss_list
90 |
91 | # ----------------------------------------------------
92 |
93 | # Optimizer settings
94 | num_iterations = 50000
95 | num_samples = 1
96 | # Learning rate (step size) and weight decay for AdamW:
97 | STEP_SIZE = 0.01
98 | WEIGHT_DECAY = 0.0001
99 |
100 | for i in range(num_samples):
101 | tic = time.perf_counter()
102 |
103 | # Init random parameters
104 | phase_mask_slm1 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
105 | phase_mask_slm2 = jnp.array([np.random.uniform(0, 1, (shape, shape))], dtype=jnp.float64)[0]
106 | distance_0 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
107 | distance_1 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
108 | distance_2 = jnp.array([np.random.uniform(0.027, 1)], dtype=jnp.float64)
109 | init_params = [distance_0, distance_1, distance_2, phase_mask_slm1, phase_mask_slm2]
110 |
111 | # Init optimizer:
112 | optimizer = optax.adamw(STEP_SIZE, weight_decay=WEIGHT_DECAY)
113 |
114 | # Apply fit function:
115 | best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)
116 |
117 | print("Time taken to optimize one sample - in seconds", time.perf_counter() - tic)
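The best-loss bookkeeping and stopping rule in `fit()` can be checked without any optics. The sketch below assumes a hypothetical precomputed list `losses` standing in for the `loss_value` returned by `update()`. Because `update()` evaluates the loss before applying the optimizer step, each `loss_value` belongs to the pre-update parameters. `hybrid_optimizer.py` stores `old_params` for the best step, whereas this script stores the post-update `params`.

```python
def track_best_and_stop(losses, n_best=500, check_every=100, init_best=1e2):
    """
    Sketch of the bookkeeping in fit(): keep the lowest loss and the step where it
    occurred, and (only every `check_every` steps) stop once more than `n_best`
    steps have passed without improvement. `init_best` mirrors best_loss = 1e2.

    Returns (best_loss, best_step, last_step).
    """
    best_loss, best_step, last_step = init_best, 0, 0
    for step, loss_value in enumerate(losses):
        last_step = step
        if loss_value < best_loss:
            best_loss, best_step = loss_value, step
        if step % check_every == 0 and step - best_step > n_best:
            break
    return best_loss, best_step, last_step


# Hypothetical loss curve: improves for 1000 steps, then plateaus.
losses = [1.0 / (s + 1) for s in range(1000)] + [1.0] * 5000
best_loss, best_step, last_step = track_best_and_stop(losses)
```

Here the best loss occurs at step 999. The rule only fires on multiples of 100, so the loop stops at step 1500, the first checked step more than 500 steps after the last improvement.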
--------------------------------------------------------------------------------
/experiments/generate_synthetic_data.py:
--------------------------------------------------------------------------------
1 | # Standard-library imports used to set up the module path:
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | import numpy as np
13 | import jax.numpy as jnp
14 | from xlumina.__init__ import um, nm
15 | from xlumina.wave_optics import *
16 | from xlumina.optical_elements import *
17 | from xlumina.toolbox import space
18 | import h5py
19 |
20 | """
21 | Generation of synthetic data batches for a 4f system with 2x magnification:
22 | - input_masks = jnp.array(in1, in2, ...)
23 | - target_intensity = jnp.array(out1, out2, ...)
24 | """
25 |
26 | # System characteristics:
27 | sensor_lateral_size = 1024 # Pixel resolution
28 | wavelength = 632.8*nm
29 | x_total = 1500*um
30 | x, y = space(x_total, sensor_lateral_size)
31 | X, Y = jnp.meshgrid(x,y)
32 |
33 | # Define the light source:
34 | w0 = (1200*um , 1200*um)
35 | gb = LightSource(x, y, wavelength)
36 | gb.plane_wave(A=0.5)
37 |
38 | def generate_synthetic_circles(gb, num_samples):
39 | in_circles = []
40 | out_circles = []
41 | for i in range(num_samples):
42 | r1 = jnp.array(np.random.uniform(100, 1000))
43 | r2 = jnp.array(np.random.uniform(100, 1000))
44 |
45 | # Store only the mask (binary)
46 | in_circle = circular_mask(X, Y, r=(r1, r2))
47 | in_circles.append(in_circle)
48 |
49 |         # Magnification is 2: store only the intensity
50 | out_circle = gb.apply_circular_mask(r=(2*r1, 2*r2))
51 | out_circles.append(jnp.abs(out_circle.field)**2)
52 | return jnp.array(in_circles), jnp.array(out_circles)
53 |
54 | def generate_synthetic_squares(gb, num_samples):
55 | in_squares = []
56 | out_squares = []
57 | for i in range(num_samples):
58 | width = jnp.array(np.random.uniform(100, 1000))
59 | height = jnp.array(np.random.uniform(100, 1000))
60 | angle = jnp.array(np.random.uniform(0, 2*jnp.pi))
61 |
62 | # Binary mask only
63 | in_square = rectangular_mask(X, Y, center=(0,0), width=width, height=height, angle=angle)
64 | in_squares.append(in_square)
65 |
66 | # Magnification is 2 - we only need intensity
67 | out_square = gb.apply_rectangular_mask(center=(0,0), width=2*width, height=2*height, angle=-angle)
68 | out_squares.append(jnp.abs(out_square.field)**2)
69 |
70 | return jnp.array(in_squares), jnp.array(out_squares)
71 |
72 | def generate_synthetic_annular(gb, num_samples):
73 | in_annulars = []
74 | out_annulars = []
75 | for i in range(num_samples):
76 | di = jnp.array(np.random.uniform(100, 500))
77 | do = jnp.array(np.random.uniform(550, 1000))
78 |
79 | # Binary mask only:
80 | in_annular = annular_aperture(di, do, X, Y)
81 | in_annulars.append(in_annular)
82 |
83 | # Magnification is 2 - we only need intensity
84 | out_annular = gb.apply_annular_aperture(di=2*di, do=2*do)
85 | out_annulars.append(jnp.abs(out_annular.field)**2)
86 |
87 | return jnp.array(in_annulars), jnp.array(out_annulars)
88 |
89 | # Data generation loop:
90 | num_samples = 30
91 |
92 | for s in range(50):
93 | input_circles, target_circles = generate_synthetic_circles(gb, num_samples)
94 | input_squares, target_squares = generate_synthetic_squares(gb, num_samples)
95 | input_annular, target_annular = generate_synthetic_annular(gb, num_samples)
96 |
97 | input_fields = jnp.vstack([input_circles, input_squares, input_annular])
98 | target_fields = jnp.vstack([target_circles, target_squares, target_annular])
99 |
100 | filename = f"new_training_data_4f/new_synthetic_data_{s+150}.hdf5"
101 |
102 | with h5py.File(filename, 'w') as hdf:
103 | # Create datasets for your data
104 | hdf.create_dataset("Input fields", data=input_fields)
105 | hdf.create_dataset("Target fields", data=target_fields)
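The generators above build each target by applying an aperture scaled by the magnification (2x) to a plane wave and storing only the intensity. A NumPy-only sketch of the same idea for elliptical apertures follows. `elliptical_mask` is a hypothetical stand-in for XLuminA's `circular_mask`, and we assume `r` holds the two semi-axes. The 256-point grid is illustrative rather than the 1024-pixel `space()` grid used in the script. For origin-centered, axis-aligned apertures the image inversion of a 4f system leaves the mask unchanged, so only the scaling matters.

```python
import numpy as np


def elliptical_mask(X, Y, r):
    """Binary aperture centered at the origin with semi-axes r = (rx, ry) (assumed convention)."""
    rx, ry = r
    return ((X / rx) ** 2 + (Y / ry) ** 2 <= 1.0).astype(np.float32)


def make_circle_pair(X, Y, rng, magnification=2.0, amplitude=0.5):
    """One (input mask, target intensity) pair for an ideal magnifying imager."""
    r1, r2 = rng.uniform(100, 1000, size=2)
    in_mask = elliptical_mask(X, Y, (r1, r2))
    # Plane wave of the given amplitude through the magnified aperture; keep |E|^2 only.
    out_field = amplitude * elliptical_mask(X, Y, (magnification * r1, magnification * r2))
    return in_mask, np.abs(out_field) ** 2


x = np.linspace(-1500, 1500, 256)  # micrometers, mirroring x_total = 1500*um
X, Y = np.meshgrid(x, x)
rng = np.random.default_rng(0)
pairs = [make_circle_pair(X, Y, rng) for _ in range(4)]
inputs = np.stack([p[0] for p in pairs])   # shape (4, 256, 256)
targets = np.stack([p[1] for p in pairs])  # shape (4, 256, 256)
```

Stacking the pairs along a leading batch axis matches the layout that `vmap` expects in `vector_dualSLM_4f_system`.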
--------------------------------------------------------------------------------
/experiments/hybrid_optimizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # Setting the path for XLuminA modules:
5 | current_path = os.path.abspath(os.path.join('..'))
6 | module_path = os.path.join(current_path)
7 |
8 | if module_path not in sys.path:
9 | sys.path.append(module_path)
10 |
11 | import time
12 | import jax
13 | from jax import grad, jit
14 | import optax
15 | import numpy as np
16 | import jax.numpy as jnp
17 | import gc # Garbage collector
18 |
19 | # Use this for pure topological discovery:
20 | from six_times_six_ansatz_with_fixed_PM import * #<--- use this for 6x6 ansatz
21 | # from hybrid_with_fixed_PM import * # <--- use this for 3x3 with fixed masks
22 |
23 | # Use this for hybrid optimization:
24 | # from hybrid_sharp_optical_table import * # <--- use this for sharp focus
25 | # from hybrid_sted_optical_table import * # <--- use this for sted
26 |
27 | """
28 | OPTIMIZER - LARGE-SCALE SETUPS
29 | """
30 |
31 | # Print device info (GPU or CPU)
32 | print(jax.devices(), flush=True)
33 |
34 | # Global variable
35 | shape = jnp.array([sensor_lateral_size, sensor_lateral_size])
36 |
37 | # Define the loss function and compute its gradients:
38 | # loss_function = jit(loss_hybrid_sharp_focus) # <--- use this for sharp focus
39 | # loss_function = jit(loss_hybrid_sted) # <--- use this for sted
40 | loss_function = jit(loss_hybrid_fixed_PM) # <--- use this for sharp focus with fixed phase masks
41 |
42 | # ----------------------------------------------------
43 |
44 | def clip_adamw(learning_rate, weight_decay) -> optax.GradientTransformation:
45 | """
46 |     Custom optimizer: thin wrapper around optax.adamw (Adam with decoupled weight decay), which applies in sequence:
47 |     1) Adam
48 |     2) Decoupled weight decay. Despite the name, no gradient clipping is applied.
49 | """
50 | return optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay)
51 |
52 | def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations) -> optax.Params:
53 |
54 | # Init the optimizer with initial parameters
55 | opt_state = optimizer.init(params)
56 |
57 | @jit
58 | def update(parameters, opt_state):
59 | # Define single update step:
60 | loss_value, grads = jax.value_and_grad(loss_function)(parameters)
61 |
62 | # Update the state of the optimizer
63 | updates, state = optimizer.update(grads, opt_state, parameters)
64 |
65 | # Update the parameters
66 | new_params = optax.apply_updates(parameters, updates)
67 |
68 | return new_params, parameters, state, loss_value, updates
69 |
70 |
71 | # Initialize some parameters
72 | iteration_steps=[]
73 | loss_list=[]
74 |
75 | n_best = 500
76 | best_loss = 3*1e2
77 | best_params = None
78 | best_step = 0
79 |
80 | print('Starting Optimization', flush=True)
81 |
82 | for step in range(num_iterations):
83 |
84 |         params, old_params, opt_state, loss_value, updates = update(params, opt_state)  # loss_value is evaluated at old_params
85 |
86 | print(f"Step {step}")
87 | print(f"Loss {loss_value}")
88 | iteration_steps.append(step)
89 | loss_list.append(loss_value)
90 |
91 | # Update the `best_loss` value:
92 | if loss_value < best_loss:
93 | # Best loss value
94 | best_loss = loss_value
95 | # Best optimized parameters
96 | best_params = old_params
97 | best_step = step
98 | print('Best loss value is updated')
99 |
100 | if step % 100 == 0:
101 |             # Stopping criterion (checked every 100 steps): stop if best_loss has not improved for more than n_best steps.
102 | if step - best_step > n_best:
103 | print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
104 | break
105 |
106 | print(f'Best loss: {best_loss} at step {best_step}')
107 | print(f'Best parameters: {best_params}')
108 | return best_params, best_loss, iteration_steps, loss_list
109 |
110 | # ----------------------------------------------------
111 |
112 | # Optimizer settings
113 | num_iterations = 100000
114 | num_samples = 20
115 |
116 | for i in range(num_samples):
117 |
118 | STEP_SIZE = 0.05
119 | WEIGHT_DECAY = 0.0001
120 |
121 | gc.collect()
122 | tic = time.perf_counter()
123 |
124 |     # Parameters: comment out the ones not used by the setup you want to optimize.
125 | # super-SLM phase masks:
126 | phase1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
127 | phase1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
128 | phase2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
129 | phase2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
130 | phase3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
131 | phase3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
132 |
133 | # Wave plate variables:
134 | eta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
135 | theta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
136 | eta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
137 | theta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
138 | eta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
139 | theta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
140 | eta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
141 | theta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
142 |
143 | # Distances:
144 | z1_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
145 | z1_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
146 | z2_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
147 | z2_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
148 | z3_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
149 | z3_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
150 | z4_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
151 | z4_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
152 |     # z4 and z5 (also used by the 3x3 hybrid setup) are sampled below,
153 |     # together with z1-z12 for the 6x6 ansatz.
154 | z1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
155 | z2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
156 | z3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
157 | z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
158 | z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
159 | z6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
160 | z7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
161 | z8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
162 | z9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
163 | z10 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
164 | z11 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
165 | z12 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
166 |
167 | # Beam splitter ratios
168 | bs1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
169 | bs2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
170 | bs3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
171 | bs4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
172 | bs5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
173 | bs6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
174 | bs7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
175 | bs8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
176 | bs9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
177 | bs10 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
178 | bs11 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
179 | bs12 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
180 | bs13 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
181 | bs14 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
182 | bs15 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
183 | bs16 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
184 | bs17 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
185 | bs18 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
186 | bs19 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
187 | bs20 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
188 | bs21 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
189 | bs22 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
190 | bs23 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
191 | bs24 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
192 | bs25 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
193 | bs26 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
194 | bs27 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
195 | bs28 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
196 | bs29 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
197 | bs30 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
198 | bs31 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
199 | bs32 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
200 | bs33 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
201 | bs34 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
202 | bs35 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
203 | bs36 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
204 |
205 | # Select which set of initial parameters to use:
206 | # REMEMBER TO COMMENT OUT (#) THE init_params LISTS YOU DON'T USE!
207 |
208 | # 1. For 3x3 hybrid optimization (topology + optical parameters):
209 | # init_params = [phase1_1, phase1_2, eta1, theta1, z1_1, z1_2, phase2_1, phase2_2, eta2, theta2, z2_1, z2_2, phase3_1, phase3_2, eta3, theta3, z3_1, z3_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, z4, z5]
210 |
211 | # 2. Parameters for pure topological optimization on 3x3 systems with fixed phase masks at random positions:
212 | # init_params = [z1_1, z1_2, z2_1, z2_2, z3_1, z3_2, z4_1, z4_2, bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, eta1, theta1, eta2, theta2, eta3, theta3, eta4, theta4]
213 |
214 | # 3. Parameters for pure topological optimization on the 6x6 system with fixed phase masks:
215 | init_params = [z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12,
216 | bs1, bs2, bs3, bs4, bs5, bs6,
217 | bs7, bs8, bs9, bs10, bs11, bs12,
218 | bs13, bs14, bs15, bs16, bs17, bs18,
219 | bs19, bs20, bs21, bs22, bs23, bs24,
220 | bs25, bs26, bs27, bs28, bs29, bs30,
221 | bs31, bs32, bs33, bs34, bs35, bs36,
222 | eta1, theta1, eta2, theta2]
223 |
224 | # Init optimizer:
225 | optimizer = clip_adamw(STEP_SIZE, WEIGHT_DECAY)
226 |
227 | # Apply fit function:
228 | best_params, best_loss, iteration_steps, loss_list = fit(init_params, optimizer, num_iterations)
--------------------------------------------------------------------------------
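The 6x6 `init_params` list above is assembled from 48 individually named scalars drawn with the global, unseeded `np.random`. A more compact and reproducible alternative is sketched below, assuming every entry (including `eta1, theta1, eta2, theta2`, whose initialization is outside this excerpt) is drawn uniformly from [0, 1). `make_6x6_init_params` is a hypothetical helper, not part of XLuminA. Note that the original requests `jnp.float64`, which JAX only honors when x64 mode is enabled; the sketch keeps JAX's default dtype.

```python
import jax

def make_6x6_init_params(key, n_distances=12, n_bs=36, n_lcd=2):
    """Hypothetical helper mirroring the 6x6 init_params layout:
    [z1..z12, bs1..bs36, eta1, theta1, eta2, theta2], each a shape-(1,) array
    drawn uniformly from [0, 1) (an assumption for the eta/theta entries)."""
    k_z, k_bs, k_lcd = jax.random.split(key, 3)
    # Iterating over an (n, 1) array yields n arrays of shape (1,)
    z = list(jax.random.uniform(k_z, (n_distances, 1)))
    bs = list(jax.random.uniform(k_bs, (n_bs, 1)))
    lcd = list(jax.random.uniform(k_lcd, (2 * n_lcd, 1)))  # eta1, theta1, eta2, theta2
    return z + bs + lcd

# An explicit PRNG key makes the starting point reproducible across runs
init_params = make_6x6_init_params(jax.random.PRNGKey(0))
```

The resulting list has the same length and ordering as the hand-written one, so it could be passed to `fit` in its place.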
/experiments/hybrid_sharp_optical_table.py:
--------------------------------------------------------------------------------
1 | # Setting the path for XLuminA modules:
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | from xlumina.__init__ import um, nm, cm, mm
13 | from xlumina.vectorized_optics import *
14 | from xlumina.optical_elements import hybrid_setup_sharp_focus
15 | from xlumina.loss_functions import vectorized_loss_hybrid
16 | from xlumina.toolbox import space, softmin
17 | import jax.numpy as jnp
18 |
19 | """
20 | Large-scale setup for Dorn, Quabis and Leuchs (2004) benchmark rediscovery:
21 |
22 | 3x3 initial setup - light gets detected across 6 detectors.
23 | """
24 |
25 | # 1. System specs:
26 | sensor_lateral_size = 1024 # Resolution
27 | wavelength = 635*nm
28 | x_total = 2500*um
29 | x, y = space(x_total, sensor_lateral_size)
30 | shape = jnp.shape(x)[0]
31 |
32 | # 2. Define the light source: a diagonally polarized Gaussian beam:
33 | w0 = (1200*um, 1200*um)
34 | ls1 = PolarizedLightSource(x, y, wavelength)
35 | ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))
36 |
37 | # 3. Define the output (High Resolution) detection:
38 | x_out, y_out = jnp.array(space(10*um, 400))
39 |
40 | # 4. High NA objective lens specs:
41 | NA = 0.9
42 | radius_lens = 3.6*mm/2
43 | f_lens = radius_lens / NA
44 |
45 | # 5. Static parameters - don't change during optimization:
46 | fixed_params = [radius_lens, f_lens, x_out, y_out]
47 |
48 | # 6. Define the loss function:
49 | @jit
50 | def loss_hybrid_sharp_focus(parameters):
51 | # Output from hybrid_setup_sharp_focus: array of shape (6, N, N), one intensity map per detector
52 | detected_z_intensities, _ = hybrid_setup_sharp_focus(ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params)
53 |
54 | # Smooth minimum (softmin) over the per-detector loss values, shape (6, 1, 1)
55 | loss_val = softmin(vectorized_loss_hybrid(detected_z_intensities))
56 |
57 | return loss_val
--------------------------------------------------------------------------------
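Each loss function in these experiments collapses the per-detector loss values into one scalar with `softmin`, so the optimizer is presumably rewarded as soon as any one detector reaches a good value. The sketch below shows one common differentiable smooth minimum, a softmax-weighted average. It is an assumption for illustration, not necessarily how `xlumina.toolbox.softmin` is implemented; `smooth_min` and `beta` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def smooth_min(values, beta=100.0):
    """Hypothetical differentiable minimum: softmax-weighted average of `values`.

    Large `beta` pushes the weights toward a one-hot on the smallest entry,
    so the result approaches jnp.min(values) while staying differentiable.
    """
    flat = jnp.ravel(values)                 # e.g. shape (6, 1, 1) -> (6,)
    weights = jax.nn.softmax(-beta * flat)   # smallest value gets the largest weight
    return jnp.sum(weights * flat)

# Toy per-detector losses with the (6, 1, 1) shape used above
per_detector_loss = jnp.array([0.8, 0.3, 0.5, 0.9, 0.7, 0.6]).reshape(6, 1, 1)
loss = smooth_min(per_detector_loss)
grads = jax.grad(smooth_min)(per_detector_loss)
```

With a large `beta` almost all of the gradient flows to the best detector; a smaller `beta` spreads some gradient to the runner-up detectors as well.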
/experiments/hybrid_sted_optical_table.py:
--------------------------------------------------------------------------------
1 | # Setting the path for XLuminA modules:
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | from xlumina.__init__ import um, nm, cm, mm
13 | from xlumina.vectorized_optics import *
14 | from xlumina.optical_elements import hybrid_setup_fluorophores
15 | from xlumina.loss_functions import vectorized_loss_hybrid
16 | from xlumina.toolbox import space, softmin
17 | import jax.numpy as jnp
18 |
19 | """
20 | Large-scale setup for STED microscopy baseline rediscovery:
21 |
22 | 3x3 initial setup - light gets detected across 6 detectors.
23 | """
24 |
25 | # 1. System specs:
26 | sensor_lateral_size = 824 # Resolution
27 | wavelength1 = 650*nm
28 | wavelength2 = 532*nm
29 | x_total = 2500*um
30 | x, y = space(x_total, sensor_lateral_size)
31 | shape = jnp.shape(x)[0]
32 |
33 | # 2. Define the light sources: two diagonally polarized Gaussian beams at 650 nm and 532 nm:
34 | w0 = (1200*um, 1200*um)
35 | ls1 = PolarizedLightSource(x, y, wavelength1)
36 | ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))
37 | ls2 = PolarizedLightSource(x, y, wavelength2)
38 | ls2.gaussian_beam(w0=w0, jones_vector=(1, 1))
39 |
40 | # 3. Define the output (High Resolution) detection:
41 | x_out, y_out = jnp.array(space(10*um, 400))
42 |
43 | # 4. High NA objective lens specs:
44 | NA = 0.9
45 | radius_lens = 3.6*mm/2
46 | f_lens = radius_lens / NA
47 |
48 | # 5. Static parameters - don't change during optimization:
49 | fixed_params = [radius_lens, f_lens, x_out, y_out]
50 |
51 | # 6. Define the loss function:
52 | @jit
53 | def loss_hybrid_sted(parameters):
54 | # Output from hybrid_setup_fluorophores: array of shape (6, N, N), one effective intensity map per detector
55 | i_effective = hybrid_setup_fluorophores(ls1, ls2, ls1, ls2, ls1, ls2, parameters, fixed_params, distance_offset = 10)
56 |
57 | # Smooth minimum (softmin) over the per-detector loss values, shape (6, 1, 1)
58 | loss_val = softmin(vectorized_loss_hybrid(i_effective))
59 |
60 | return loss_val
--------------------------------------------------------------------------------
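The objective specs set `f_lens = radius_lens / NA`. For an aplanatic lens in air, the Abbe sine condition gives the aperture radius as f·sin(θ_max) = f·NA, so this choice fixes both the focal length and the maximum focusing half-angle. A quick worked example with the values used in these scripts, assuming XLuminA's base length unit is the micrometre (as the tests suggest):

```python
import numpy as np

# Assumption: lengths in micrometres, so 1 mm = 1e3
mm = 1e3
NA = 0.9
radius_lens = 3.6 * mm / 2          # aperture radius: 1800 um
f_lens = radius_lens / NA           # 2000 um, i.e. 2 mm
theta_max = np.arcsin(NA)           # marginal-ray half-angle in air (n = 1)
print(f"f = {f_lens:.1f} um, half-angle = {np.degrees(theta_max):.2f} deg")
```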
/experiments/hybrid_with_fixed_PM.py:
--------------------------------------------------------------------------------
1 | # Setting the path for XLuminA modules:
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | from xlumina.__init__ import um, nm, mm
13 | from xlumina.vectorized_optics import *
14 | from xlumina.optical_elements import hybrid_setup_fixed_slms_fluorophores, hybrid_setup_fixed_slms
15 | from xlumina.loss_functions import vectorized_loss_hybrid
16 | from xlumina.toolbox import space, softmin
17 | import jax.numpy as jnp
18 |
19 | """
20 | Large-scale setup using fixed phase masks in random positions:
21 |
22 | 3x3 initial setup - light gets detected across 6 detectors.
23 |
24 | This script is valid for rediscovering
25 |
26 | (1) Dorn, Quabis and Leuchs (2004) - use hybrid_setup_fixed_slms() in the loss function,
27 |
28 | (2) STED microscopy - use hybrid_setup_fixed_slms_fluorophores() in the loss function.
29 | """
30 |
31 | # 1. System specs:
32 | sensor_lateral_size = 824 # Resolution
33 | wavelength_1 = 632.8*nm
34 | wavelength_2 = 530*nm
35 | x_total = 2500*um
36 | x, y = space(x_total, sensor_lateral_size)
37 | shape = jnp.shape(x)[0]
38 |
39 | # 2. Define the light sources: two diagonally polarized Gaussian beams at 632.8 nm and 530 nm:
40 | w0 = (1200*um, 1200*um)
41 | ls1 = PolarizedLightSource(x, y, wavelength_1)
42 | ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))
43 | ls2 = PolarizedLightSource(x, y, wavelength_2)
44 | ls2.gaussian_beam(w0=w0, jones_vector=(1, 1))
45 |
46 | # 3. Define the output (High Resolution) detection:
47 | x_out, y_out = jnp.array(space(10*um, 400))
48 | X, Y = jnp.meshgrid(x,y)
49 |
50 | # 4. High NA objective lens specs:
51 | NA = 0.9
52 | radius_lens = 3.6*mm/2
53 | f_lens = radius_lens / NA
54 |
55 | # 4.1 Fixed phase masks:
56 | # Polarization converter in Dorn, Quabis, Leuchs (2004):
57 | pi_half = (jnp.pi - jnp.pi/2) * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
58 | minus_pi_half = - jnp.pi/2 * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
59 | PM1_1 = jnp.concatenate((jnp.concatenate((minus_pi_half, pi_half), axis=1), jnp.concatenate((minus_pi_half, pi_half), axis=1)), axis=0)
60 | PM1_2 = jnp.concatenate((jnp.concatenate((minus_pi_half, minus_pi_half), axis=1), jnp.concatenate((pi_half, pi_half), axis=1)), axis=0)
61 | # Spiral phase (STED microscopy)
62 | PM2 = jnp.arctan2(Y,X)
63 | # Forked grating
64 | PM3 = jnp.cos(2 * PM2 - 2 * jnp.pi * X/1000) * jnp.pi
65 | # Linear grating
66 | PM4_1 = jnp.sin(2*jnp.pi * Y/1000) * jnp.pi
67 | PM4_2 = jnp.sin(2*jnp.pi * X/1000) * jnp.pi
68 |
69 | # 5. Static parameters - don't change during optimization:
70 | fixed_params = [radius_lens, f_lens, x_out, y_out, PM1_1, PM1_2, PM2, PM3, PM4_1, PM4_2]
71 |
72 | # 6. Define the loss function:
73 | def loss_hybrid_fixed_PM(parameters):
74 | # Output from the hybrid setup: array of shape (6, N, N), one map per detector
75 |
76 | # Use (1) for Dorn, Quabis and Leuchs benchmark / Use (2) for STED microscopy benchmark
77 |
78 | # (1):
79 | # i_effective, _ = hybrid_setup_fixed_slms(ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params)
80 |
81 | # (2):
82 | i_effective = hybrid_setup_fixed_slms_fluorophores(ls1, ls2, ls1, ls2, ls1, ls2, parameters, fixed_params)
83 |
84 | # Smooth minimum (softmin) over the per-detector loss values, shape (6, 1, 1)
85 | loss_val = softmin(vectorized_loss_hybrid(i_effective))
86 |
87 | return loss_val
--------------------------------------------------------------------------------
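The polarization-converter masks `PM1_1` and `PM1_2` are built by concatenating four constant blocks, where `pi_half = (jnp.pi - jnp.pi/2)` is simply π/2. An equivalent, arguably easier-to-read construction uses index comparisons. `quadrant_masks` below is a hypothetical helper, not part of XLuminA, and it assumes an even `sensor_lateral_size` (824 here) so the two halves tile the full grid.

```python
import jax.numpy as jnp

def quadrant_masks(n):
    """Hypothetical helper: rebuild the two half-plane phase masks on an n x n grid.

    Returns (pm_lr, pm_tb):
      pm_lr is -pi/2 on the left half of the columns and +pi/2 on the right (like PM1_1);
      pm_tb is -pi/2 on the top half of the rows and +pi/2 on the bottom (like PM1_2).
    """
    half = n // 2
    rows, cols = jnp.meshgrid(jnp.arange(2 * half), jnp.arange(2 * half), indexing="ij")
    pm_lr = jnp.where(cols < half, -jnp.pi / 2, jnp.pi / 2)
    pm_tb = jnp.where(rows < half, -jnp.pi / 2, jnp.pi / 2)
    return pm_lr, pm_tb

pm_lr, pm_tb = quadrant_masks(824)
```

Each mask carries a π phase step across one midline (vertical for `pm_lr`, horizontal for `pm_tb`); comparing against the concatenated versions with `jnp.allclose` is a quick sanity check.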
/experiments/six_times_six_ansatz_with_fixed_PM.py:
--------------------------------------------------------------------------------
1 | # Setting the path for XLuminA modules:
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | from xlumina.__init__ import um, nm, cm, mm
13 | from xlumina.vectorized_optics import *
14 | from xlumina.optical_elements import six_times_six_ansatz
15 | from xlumina.loss_functions import vectorized_loss_hybrid
16 | from xlumina.toolbox import space, softmin
17 | import jax.numpy as jnp
18 |
19 | """
20 | Pure topological discovery within 6x6 ansatz for Dorn, Quabis and Leuchs (2004)
21 | """
22 |
23 | # 1. System specs:
24 | sensor_lateral_size = 824 # Resolution
25 | wavelength_1 = 635.0*nm
26 | x_total = 2500*um
27 | x, y = space(x_total, sensor_lateral_size)
28 | shape = jnp.shape(x)[0]
29 |
30 | # 2. Define the light source: an antidiagonally polarized Gaussian beam:
31 | w0 = (1200*um, 1200*um)
32 | ls1 = PolarizedLightSource(x, y, wavelength_1)
33 | ls1.gaussian_beam(w0=w0, jones_vector=(1, -1))
34 |
35 | # 3. Define the output (High Resolution) detection:
36 | x_out, y_out = jnp.array(space(10*um, 400))
37 | X, Y = jnp.meshgrid(x,y)
38 |
39 | # 4. High NA objective lens specs:
40 | NA = 0.9
41 | radius_lens = 3.6*mm/2
42 | f_lens = radius_lens / NA
43 |
44 | # 4.1 Fixed phase masks:
45 | # Polarization converter in Dorn, Quabis, Leuchs (2004):
46 | pi_half = (jnp.pi - jnp.pi/2) * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
47 | minus_pi_half = - jnp.pi/2 * jnp.ones(shape=(sensor_lateral_size // 2, sensor_lateral_size // 2))
48 | PM1_1 = jnp.concatenate((jnp.concatenate((minus_pi_half, pi_half), axis=1), jnp.concatenate((minus_pi_half, pi_half), axis=1)), axis=0)
49 | PM1_2 = jnp.concatenate((jnp.concatenate((minus_pi_half, minus_pi_half), axis=1), jnp.concatenate((pi_half, pi_half), axis=1)), axis=0)
50 |
51 | # Linear grating
52 | PM2_1 = jnp.sin(2*jnp.pi * Y/1000) * jnp.pi
53 | PM2_2 = jnp.sin(2*jnp.pi * X/1000) * jnp.pi
54 |
55 | # 5. Static parameters - don't change during optimization:
56 | fixed_params = [radius_lens, f_lens, x_out, y_out, PM1_1, PM1_2, PM2_1, PM2_2]
57 |
58 | # 6. Define the loss function:
59 | def loss_hybrid_fixed_PM(parameters):
60 | # Output from six_times_six_ansatz: array of shape (12, N, N), one map per detector
61 | i_effective = six_times_six_ansatz(ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params, distance_offset = 9.5)
62 | # Smooth minimum (softmin) over the per-detector loss values, shape (12, 1, 1)
63 | loss_val = softmin(vectorized_loss_hybrid(i_effective))
64 | return loss_val
--------------------------------------------------------------------------------
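The linear gratings `PM2_1` and `PM2_2` (and `PM4_1`, `PM4_2` in the 3x3 script) are sinusoidal phase profiles π·sin(2πx/1000), i.e. a 1000 µm period assuming micrometre units. By the Jacobi-Anger expansion, exp(i·a·sin φ) = Σ_m J_m(a)·exp(i·m·φ), so diffraction order m carries a fraction J_m(a)² of the power. The sketch checks this numerically with an FFT over one period; `grating_order_amplitudes` is a hypothetical helper, not part of XLuminA.

```python
import numpy as np
from scipy.special import jv

def grating_order_amplitudes(phase_amplitude=np.pi, n_samples=256, max_order=3):
    """Hypothetical helper: Fourier amplitudes of a sinusoidal phase grating."""
    x = np.arange(n_samples) / n_samples              # one period, in units of the period
    field = np.exp(1j * phase_amplitude * np.sin(2 * np.pi * x))
    coeffs = np.fft.fft(field) / n_samples            # coefficient of exp(+i m 2 pi x)
    orders = np.arange(-max_order, max_order + 1)
    return orders, coeffs[orders]                     # negative indices wrap to the end

orders, amps = grating_order_amplitudes()
efficiency = np.abs(amps) ** 2                        # fraction of power per order
print(dict(zip(orders.tolist(), np.round(efficiency, 4).tolist())))
```

With a phase amplitude of π, the zeroth order is weak and the ±2 orders carry the most power, so these masks mainly steer light sideways rather than passing it straight through.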
/miscellaneous/noise-aware.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/noise-aware.png
--------------------------------------------------------------------------------
/miscellaneous/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/performance.png
--------------------------------------------------------------------------------
/miscellaneous/performance_convergence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/performance_convergence.png
--------------------------------------------------------------------------------
/miscellaneous/propagation_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/propagation_comparison.png
--------------------------------------------------------------------------------
/miscellaneous/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/miscellaneous/workflow.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from os.path import exists
4 | from pathlib import Path
5 | import re
6 |
7 | from setuptools import setup, find_packages
8 |
9 | author = 'artificial-scientist-lab'
10 | email = 'carla.rodriguez@mpl.mpg.de, soeren.arlt@mpl.mpg.de, mario.krenn@mpl.mpg.de'
11 | description = 'XLuminA: An Auto-differentiating Discovery Framework for Super-Resolution Microscopy'
12 | dist_name = 'xlumina'
13 | package_name = 'xlumina'
14 | year = '2023'
15 | url = 'https://github.com/artificial-scientist-lab/XLuminA'
16 |
17 | setup(
18 | name=dist_name,
19 | author=author,
20 | author_email=email,
21 | url=url,
22 | version="1.0.0",
23 | packages=find_packages(),
24 | package_dir={dist_name: package_name},
25 | include_package_data=True,
26 | license='MIT',
27 | description=description,
28 | long_description=Path('README.md').read_text() if Path('README.md').exists() else '',
29 | long_description_content_type="text/markdown",
30 | install_requires=[
31 | 'jax==0.4.33',
32 | 'numpy',
33 | 'optax==0.2.3',
34 | 'scipy==1.14.1',
35 | 'matplotlib'
36 | ],
37 | python_requires=">=3.10",
38 | classifiers=[
39 | 'Operating System :: OS Independent',
40 | 'Programming Language :: Python :: 3.10',
41 | 'Programming Language :: Python :: 3.11',
42 | ],
43 | platforms=['ALL'],
44 | py_modules=[package_name],
45 | )
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/artificial-scientist-lab/XLuminA/299da752da7c198fa173b69a7f12c3b9ce198af6/tests/__init__.py
--------------------------------------------------------------------------------
/tests/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | filterwarnings =
3 | ignore:.*invalid escape sequence.*:DeprecationWarning
--------------------------------------------------------------------------------
/tests/test_optical_elements.py:
--------------------------------------------------------------------------------
1 | # Test for optical elements
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | import unittest
13 | import jax.numpy as jnp
14 | import numpy as np
15 | from xlumina.optical_elements import (
16 | SLM, jones_LP, jones_general_retarder, jones_sSLM, jones_LCD,
17 | sSLM, LCD, linear_polarizer, BS_symmetric, high_NA_objective_lens,
18 | VCZT_objective_lens, lens, cylindrical_lens, axicon_lens, building_block
19 | )
20 |
21 | from xlumina.vectorized_optics import VectorizedLight, PolarizedLightSource
22 | from xlumina.wave_optics import LightSource
23 |
24 | class TestOpticalElements(unittest.TestCase):
25 | def setUp(self):
26 | self.wavelength = 633e-3 # 633 nm, expressed in µm
27 | self.resolution = 512
28 | self.x = np.linspace(-1500, 1500, self.resolution)
29 | self.y = np.linspace(-1500, 1500, self.resolution)
30 | self.k = 2 * jnp.pi / self.wavelength
31 |
32 | def test_slm(self):
33 | light = LightSource(self.x, self.y, self.wavelength)
34 | light.gaussian_beam(w0=(1200, 1200), E0=1)
35 | phase = jnp.zeros((self.resolution, self.resolution))
36 | slm_output, _ = SLM(light, phase, self.resolution)
37 | self.assertEqual(slm_output.field.shape, (self.resolution, self.resolution)) # Check output shape == input shape
38 | self.assertTrue(jnp.allclose(slm_output.field, light.field)) # Phase added by SLM is 0, field shouldn't change.
39 |
40 | def test_shape_jones_matrices(self):
41 | lp = jones_LP(jnp.pi/4)
42 | self.assertEqual(lp.shape, (2, 2))
43 |
44 | retarder = jones_general_retarder(jnp.pi/2, jnp.pi/4, 0)
45 | self.assertEqual(retarder.shape, (2, 2))
46 |
47 | sslm = jones_sSLM(jnp.pi/2, jnp.pi/4)
48 | self.assertEqual(sslm.shape, (2, 2))
49 |
50 | lcd = jones_LCD(jnp.pi/2, jnp.pi/4)
51 | self.assertEqual(lcd.shape, (2, 2))
52 |
53 | def test_polarization_devices(self):
54 | light = PolarizedLightSource(self.x, self.y, self.wavelength)
55 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1))
56 |
57 | # super-SLM with zero phase -- input SoP is diagonal
58 | alpha = jnp.zeros((self.resolution, self.resolution))
59 | phi = jnp.zeros((self.resolution, self.resolution))
60 | sslm_output = sSLM(light, alpha, phi)
61 | self.assertEqual(sslm_output.Ex.shape, (self.resolution, self.resolution))
62 | self.assertEqual(sslm_output.Ey.shape, (self.resolution, self.resolution))
63 | self.assertEqual(sslm_output.Ez.shape, (self.resolution, self.resolution))
64 | self.assertTrue(jnp.allclose(sslm_output.Ex, light.Ex))
65 | self.assertTrue(jnp.allclose(sslm_output.Ey, light.Ey))
66 | self.assertTrue(jnp.allclose(sslm_output.Ez, light.Ez))
67 |
68 | # super-SLM with pi phase in Ex and Ey -- input SoP is diagonal
69 | alpha = jnp.pi * jnp.ones((self.resolution, self.resolution))
70 | phi = jnp.pi * jnp.ones((self.resolution, self.resolution))
71 | sslm_output = sSLM(light, alpha, phi)
72 | self.assertTrue(jnp.allclose(sslm_output.Ex, light.Ex * jnp.exp(1j * jnp.pi)))
73 | self.assertTrue(jnp.allclose(sslm_output.Ey, light.Ey * jnp.exp(1j * jnp.pi)))
74 | self.assertTrue(jnp.allclose(sslm_output.Ez, light.Ez))
75 |
76 | # LCD
77 | light = PolarizedLightSource(self.x, self.y, self.wavelength)
78 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
79 | lcd_output = LCD(light, 0, 0)
80 | self.assertEqual(lcd_output.Ex.shape, (self.resolution, self.resolution))
81 | self.assertEqual(lcd_output.Ey.shape, (self.resolution, self.resolution))
82 | self.assertEqual(lcd_output.Ez.shape, (self.resolution, self.resolution))
83 | self.assertTrue(jnp.allclose(lcd_output.Ex, light.Ex))
84 | self.assertTrue(jnp.allclose(lcd_output.Ey, light.Ey))
85 | self.assertTrue(jnp.allclose(lcd_output.Ez, light.Ez))
86 |
87 | # LP aligned with incident SoP
88 | empty = jnp.zeros((self.resolution, self.resolution))
89 | lp_output = linear_polarizer(light, empty)
90 | self.assertEqual(lp_output.Ex.shape, (self.resolution, self.resolution))
91 | self.assertEqual(lp_output.Ey.shape, (self.resolution, self.resolution))
92 | self.assertEqual(lp_output.Ez.shape, (self.resolution, self.resolution))
93 | self.assertTrue(jnp.allclose(lp_output.Ex, light.Ex))
94 | self.assertTrue(jnp.allclose(lp_output.Ey, light.Ey))
95 | self.assertTrue(jnp.allclose(lp_output.Ez, light.Ez))
96 |
97 | # LP crossed to input SoP
98 | pi_half = jnp.pi/2 * jnp.ones_like(empty)
99 | lp_output = linear_polarizer(light, pi_half)
100 | self.assertTrue(jnp.allclose(lp_output.Ex, empty))
101 | self.assertTrue(jnp.allclose(lp_output.Ey, empty))
102 |
103 | def test_beam_splitter(self):
104 | light1 = PolarizedLightSource(self.x, self.y, self.wavelength)
105 | light1.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
106 | light2 = PolarizedLightSource(self.x, self.y, self.wavelength)
107 | light2.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
108 |
109 | c, d = BS_symmetric(light1, light2, 0) # fully transmissive
110 | # The transmitted path picks up a pi/2 phase: jnp.exp(1j * jnp.pi/2) = 1j
111 | # Expected amplitudes subtract a fixed noise term of 0.01 from both T and R
112 | T = jnp.abs(jnp.cos(0))
113 | R = jnp.abs(jnp.sin(0))
114 | noise = 0.01
115 | self.assertTrue(jnp.allclose(c.Ex, (T - noise) * 1j * light2.Ex + (R - noise) * light1.Ex))
116 | self.assertTrue(jnp.allclose(c.Ey, (T - noise) * 1j * light2.Ey + (R - noise) * light1.Ey))
117 | self.assertTrue(jnp.allclose(d.Ex, (T - noise) * 1j * light1.Ex + (R - noise) * light2.Ex))
118 | self.assertTrue(jnp.allclose(d.Ey, (T - noise) * 1j * light1.Ey + (R - noise) * light2.Ey))
119 |
120 | def test_high_na_objective_lens(self):
121 | radius_lens = 3.6*1e3/2 # 1.8 mm, expressed in µm
122 | f_lens = radius_lens / 0.9
123 | light = PolarizedLightSource(self.x, self.y, self.wavelength)
124 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
125 | output, _ = high_NA_objective_lens(light, radius_lens, f_lens)
126 | self.assertEqual(output.shape, (3, self.resolution, self.resolution))
127 |
128 | def test_vczt_objective_lens(self):
129 | radius_lens = 3.6*1e3/2 # 1.8 mm, expressed in µm
130 | f_lens = radius_lens / 0.9
131 | light = PolarizedLightSource(self.x, self.y, self.wavelength)
132 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
133 | output = VCZT_objective_lens(light, radius_lens, f_lens, self.x, self.y)
134 | self.assertEqual(output.Ex.shape, (self.resolution, self.resolution))
135 | self.assertEqual(output.Ey.shape, (self.resolution, self.resolution))
136 | self.assertEqual(output.Ez.shape, (self.resolution, self.resolution))
137 |
138 | def test_lenses_scalar(self):
139 | light = LightSource(self.x, self.y, self.wavelength)
140 | light.gaussian_beam(w0=(1200, 1200), E0=1)
141 | lens_output, _ = lens(light, (50, 50), (1000, 1000))
142 | self.assertEqual(lens_output.field.shape, (self.resolution, self.resolution))
143 | cyl_lens_output, _ = cylindrical_lens(light, 1000)
144 | self.assertEqual(cyl_lens_output.field.shape, (self.resolution, self.resolution))
145 | axicon_output, _ = axicon_lens(light, 0.1)
146 | self.assertEqual(axicon_output.field.shape, (self.resolution, self.resolution))
147 |
148 | def test_lenses_vectorial(self):
149 | ls = PolarizedLightSource(self.x, self.y, self.wavelength)
150 | ls.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
151 | light = VectorizedLight(self.x, self.y, self.wavelength)
152 | light.Ex = ls.Ex
153 | light.Ey = ls.Ey
154 | light.Ez = ls.Ez
155 | lens_output, _ = lens(light, (50, 50), (1000, 1000))
156 | self.assertEqual(lens_output.Ex.shape, (self.resolution, self.resolution))
157 | self.assertEqual(lens_output.Ey.shape, (self.resolution, self.resolution))
158 | cyl_lens_output, _ = cylindrical_lens(light, 1000)
159 | self.assertEqual(cyl_lens_output.Ex.shape, (self.resolution, self.resolution))
160 | self.assertEqual(cyl_lens_output.Ey.shape, (self.resolution, self.resolution))
161 | axicon_output, _ = axicon_lens(light, 0.1)
162 | self.assertEqual(axicon_output.Ex.shape, (self.resolution, self.resolution))
163 | self.assertEqual(axicon_output.Ey.shape, (self.resolution, self.resolution))
164 |
165 | def test_building_block(self):
166 | light = PolarizedLightSource(self.x, self.y, self.wavelength)
167 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
168 | output = building_block(light, jnp.zeros((self.resolution, self.resolution)), jnp.zeros((self.resolution, self.resolution)), 1000, jnp.pi/2, jnp.pi/4)
169 | self.assertEqual(output.Ex.shape, (self.resolution, self.resolution))
170 | self.assertEqual(output.Ey.shape, (self.resolution, self.resolution))
171 | self.assertEqual(output.Ez.shape, (self.resolution, self.resolution))
172 |
173 | if __name__ == '__main__':
174 | unittest.main()
--------------------------------------------------------------------------------
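`test_beam_splitter` expects `BS_symmetric` to combine its inputs with amplitudes T = |cos θ| and R = |sin θ|, the transmitted path picking up a factor 1j, minus a small 0.01 noise term. Dropping that noise term gives an ideal lossless splitter. The hypothetical `ideal_symmetric_bs` below is not XLuminA's implementation; it shows why the 1j factor matters: with real R and T and R² + T² = 1, the interference cross terms between the two outputs cancel and total power is conserved.

```python
import numpy as np
import jax.numpy as jnp

def ideal_symmetric_bs(a, b, theta):
    """Hypothetical noise-free symmetric beam splitter acting on two complex fields."""
    T = jnp.abs(jnp.cos(theta))  # transmission amplitude
    R = jnp.abs(jnp.sin(theta))  # reflection amplitude
    c = R * a + 1j * T * b       # same arrangement as the expected c.Ex in the test
    d = 1j * T * a + R * b       # same arrangement as the expected d.Ex in the test
    return c, d

def total_power(*fields):
    """Sum of |E|^2 over all pixels of all given fields."""
    return sum(float(jnp.sum(jnp.abs(f) ** 2)) for f in fields)

rng = np.random.default_rng(0)
a = jnp.asarray(rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4)))
b = jnp.asarray(rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4)))
c, d = ideal_symmetric_bs(a, b, 0.3)
```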
/tests/test_toolbox.py:
--------------------------------------------------------------------------------
1 | # Test for toolbox module
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | import unittest
13 | import jax.numpy as jnp
14 | import numpy as np
15 | from jax import random
16 | from xlumina.toolbox import (
17 | space, wrap_phase, is_conserving_energy, softmin, delta_kronecker,
18 | build_LCD_cell, rotate_mask, nearest,
19 | extract_profile, gaussian, lorentzian, fwhm_1d_fit, spot_size,
20 | compute_fwhm, find_max_min, gaussian_2d
21 | )
22 | from xlumina.vectorized_optics import VectorizedLight, PolarizedLightSource
23 |
24 | class TestToolbox(unittest.TestCase):
25 | def setUp(self):
26 | seed = 9999
27 | self.key = random.PRNGKey(seed)
28 | self.resolution = 512
29 | self.x = np.linspace(-1500, 1500, self.resolution)
30 | self.y = np.linspace(-1500, 1500, self.resolution)
31 | self.wavelength = 633e-3
32 |
33 | def test_space(self):
34 | x, y = space(1500, self.resolution)
35 | self.assertTrue(jnp.allclose(x, self.x))
36 | self.assertTrue(jnp.allclose(y, self.y))
37 |
38 | def test_wrap_phase(self):
39 | phase = jnp.array([0, jnp.pi, 2*jnp.pi, 3*jnp.pi, -3*jnp.pi])
40 | wrapped = wrap_phase(phase)
41 | self.assertTrue(jnp.allclose(wrapped, jnp.array([0, jnp.pi, 0, jnp.pi, -jnp.pi])))
42 |
43 | def test_is_conserving_energy(self):
44 | light1 = VectorizedLight(self.x, self.y, self.wavelength)
45 | light2 = VectorizedLight(self.x, self.y, self.wavelength)
46 | light1.Ex = jnp.ones((self.resolution, self.resolution))
47 | light2.Ex = jnp.ones((self.resolution, self.resolution))
48 | conservation = is_conserving_energy(light1, light2)
49 | self.assertAlmostEqual(conservation, 1.0, places=6)
50 | light2.Ex = 0*light2.Ex
51 | conservation = is_conserving_energy(light1, light2)
52 | self.assertEqual(conservation, 0)
53 |
54 | def test_softmin(self):
55 | result = softmin(jnp.array([1.0, 2.0, 3.0]))
56 | self.assertTrue(result == 1.0)
57 |
58 | def test_delta_kronecker(self):
59 | self.assertEqual(delta_kronecker(1, 1), 1)
60 | self.assertEqual(delta_kronecker(1, 2), 0)
61 |
62 | def test_build_LCD_cell(self):
63 | eta, theta = build_LCD_cell(jnp.pi/2, jnp.pi/4, self.resolution)
64 | self.assertTrue(jnp.allclose(eta, jnp.pi/2 * jnp.ones((self.resolution, self.resolution))))
65 | self.assertTrue(jnp.allclose(theta, jnp.pi/4 * jnp.ones((self.resolution, self.resolution))))
66 |
67 | def test_rotate_mask(self):
68 | X, Y = jnp.meshgrid(self.x, self.y)
69 | Xrot, Yrot = rotate_mask(X, Y, jnp.pi/4)
70 | self.assertEqual(Xrot.shape, (self.resolution, self.resolution))
71 | self.assertEqual(Yrot.shape, (self.resolution, self.resolution))
72 |
73 | def test_nearest(self):
74 | array = jnp.array([1, 2, 3, 4, 5])
75 | idx, value, distance = nearest(array, 3.7)
76 | self.assertEqual(idx, 3)
77 | self.assertEqual(value, 4)
78 | self.assertAlmostEqual(distance, 0.3, places=6)
79 |
80 | def test_extract_profile(self):
81 | data_2d = jnp.ones((10, 10))
82 | x_points = jnp.array([0, 1, 2])
83 | y_points = jnp.array([0, 1, 2])
84 | profile = extract_profile(data_2d, x_points, y_points)
85 | self.assertEqual(profile.shape, x_points.shape)
86 |
87 | def test_gaussian(self):
88 | y = gaussian(self.x, 1, 0, 1)
89 | self.assertEqual(y.shape, self.x.shape)
90 |
91 | def test_lorentzian(self):
92 | y = lorentzian(self.x, 0, 1)
93 | self.assertEqual(y.shape, self.x.shape)
94 |
95 | def test_fwhm_1d_fit(self):
96 | sigma = 120
97 | amplitude = 1000
98 | mean = 0
99 | y = gaussian(self.x, amplitude, mean, sigma)
100 | _, fwhm, _ = fwhm_1d_fit(self.x, y, fit='gaussian')
101 | fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2)) # 2*sigma*sqrt(2*ln2) is the theoretical FWHM for a gaussian
102 | self.assertAlmostEqual(fwhm, fwhm_theoretical, places=2)
103 |
104 | def test_spot_size(self):
105 | size = spot_size(1, 1, self.wavelength)
106 | self.assertGreater(size, 0)
107 |
108 | def test_compute_fwhm(self):
109 | sigma = 120
110 | light_1d = gaussian(self.x, 1000, 0, sigma)
111 | XY = jnp.meshgrid(self.x, self.y)
112 | light_2d = gaussian_2d(XY, 1000, 0, 0, sigma, sigma)
113 |
114 | popt, fwhm, r_squared = compute_fwhm(light_1d, [self.x, self.y], field='Intensity', fit = 'gaussian', dimension='1D')
115 | fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2))
116 | self.assertAlmostEqual(fwhm, fwhm_theoretical, places=4)
117 |
118 | popt, fwhm, r_squared = compute_fwhm(light_2d, [self.x, self.y], field='Intensity', fit = 'gaussian', dimension='2D')
119 | fwhm_x, fwhm_y = fwhm
120 | fwhm_theoretical = 2*sigma*jnp.sqrt(2*jnp.log(2))
121 | self.assertAlmostEqual(fwhm_x, fwhm_theoretical, places=4)
122 | self.assertAlmostEqual(fwhm_y, fwhm_theoretical, places=4)
123 |
124 | def test_find_max_min(self):
125 | value = jnp.array([[1, 2], [3, 4]])
126 | idx, xy, ext_value = find_max_min(value, self.x[:2], self.y[:2], kind='max')
127 | self.assertEqual(idx.shape, (1, 2))
128 | self.assertEqual(xy.shape, (1, 2))
129 | self.assertEqual(ext_value, 4)
130 | idx, xy, ext_value = find_max_min(value, self.x[:2], self.y[:2], kind='min')
131 | self.assertEqual(ext_value, 1)
132 |
133 | if __name__ == '__main__':
134 | unittest.main()
--------------------------------------------------------------------------------
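`test_fwhm_1d_fit` and `test_compute_fwhm` compare fitted widths against the analytic Gaussian result FWHM = 2σ√(2 ln 2) ≈ 2.355σ. As an independent, fit-free cross-check, the two half-maximum crossings of a sampled profile can be located by linear interpolation. `fwhm_from_samples` is a hypothetical helper that assumes a single peak whose half-maximum crossings lie inside the array; it is not how `fwhm_1d_fit` works.

```python
import numpy as np

def gaussian_profile(x, amplitude, mean, sigma):
    return amplitude * np.exp(-((x - mean) ** 2) / (2 * sigma ** 2))

def fwhm_from_samples(x, y):
    """Hypothetical fit-free FWHM: interpolate the two half-maximum crossings."""
    half = y.max() / 2.0
    above = np.flatnonzero(y >= half)
    i_l, i_r = above[0], above[-1]
    # Rising edge: y[i_l - 1] < half <= y[i_l], so the xp values are increasing.
    x_l = np.interp(half, [y[i_l - 1], y[i_l]], [x[i_l - 1], x[i_l]])
    # Falling edge: y[i_r + 1] < half <= y[i_r]; list the lower sample first for np.interp.
    x_r = np.interp(half, [y[i_r + 1], y[i_r]], [x[i_r + 1], x[i_r]])
    return x_r - x_l

x = np.linspace(-1500, 1500, 512)   # same grid as the toolbox tests
sigma = 120
width = fwhm_from_samples(x, gaussian_profile(x, 1000, 0, sigma))
print(width, 2 * sigma * np.sqrt(2 * np.log(2)))
```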
/tests/test_vectorized_optics.py:
--------------------------------------------------------------------------------
1 | # Test for vectorized optics module
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | import unittest
13 | import jax.numpy as jnp
14 | import numpy as np
15 | from xlumina.vectorized_optics import PolarizedLightSource, VectorizedLight
16 |
17 | class TestVectorizedOptics(unittest.TestCase):
18 | def setUp(self):
19 | self.wavelength = 633e-3 # um (633 nm)
20 | self.resolution = 1024
21 | self.x = np.linspace(-1500, 1500, self.resolution)
22 | self.y = np.linspace(-1500, 1500, self.resolution)
23 | self.k = 2 * jnp.pi / self.wavelength
24 |
25 | def test_vectorized_light(self):
26 | light = VectorizedLight(self.x, self.y, self.wavelength)
27 | self.assertEqual(light.wavelength, self.wavelength)
28 | self.assertEqual(light.k, self.k)
29 | self.assertEqual(light.Ex.shape, (self.resolution, self.resolution))
30 | self.assertEqual(light.Ey.shape, (self.resolution, self.resolution))
31 | self.assertEqual(light.Ez.shape, (self.resolution, self.resolution))
32 |
33 | def test_polarized_light_source_horizontal(self):
34 | source = PolarizedLightSource(self.x, self.y, self.wavelength)
35 | source.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 0))
36 | # TEST POLARIZATION
37 | self.assertGreater(jnp.sum(jnp.abs(source.Ex)**2), 0)
38 | self.assertEqual(jnp.sum(jnp.abs(source.Ey)**2), 0)
39 |
40 | def test_polarized_light_source_vertical(self):
41 | source = PolarizedLightSource(self.x, self.y, self.wavelength)
42 | source.gaussian_beam(w0=(1200, 1200), jones_vector=(0, 1))
43 | # TEST POLARIZATION
44 | self.assertGreater(jnp.sum(jnp.abs(source.Ey)**2), 0)
45 | self.assertEqual(jnp.sum(jnp.abs(source.Ex)**2), 0)
46 |
47 | def test_polarized_light_source_diagonal(self):
48 | source = PolarizedLightSource(self.x, self.y, self.wavelength)
49 | source.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1))
50 | # TEST POLARIZATION
51 | self.assertGreater(jnp.sum(jnp.abs(source.Ex)**2), 0)
52 | self.assertGreater(jnp.sum(jnp.abs(source.Ey)**2), 0)
53 |
54 | def test_vrs_propagation(self):
55 | light = PolarizedLightSource(self.x, self.y, self.wavelength)
56 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1))
57 | propagated, _ = light.VRS_propagation(z=1000)
58 | self.assertEqual(propagated.Ex.shape, (self.resolution, self.resolution))
59 | self.assertEqual(propagated.Ey.shape, (self.resolution, self.resolution))
60 | self.assertEqual(propagated.Ez.shape, (self.resolution, self.resolution))
61 | # Check that both polarization components are still present after propagation
62 | self.assertGreater(jnp.sum(jnp.abs(propagated.Ex)**2), 0)
63 | self.assertGreater(jnp.sum(jnp.abs(propagated.Ey)**2), 0)
64 |
65 | def test_vczt(self):
66 | light = PolarizedLightSource(self.x, self.y, self.wavelength)
67 | light.gaussian_beam(w0=(1200, 1200), jones_vector=(1, 1))
68 | propagated = light.VCZT(1000, self.x, self.y)
69 | self.assertEqual(propagated.Ex.shape, (self.resolution, self.resolution))
70 | self.assertEqual(propagated.Ey.shape, (self.resolution, self.resolution))
71 | self.assertEqual(propagated.Ez.shape, (self.resolution, self.resolution))
72 | # Check that both polarization components are still present after propagation
73 | self.assertGreater(jnp.sum(jnp.abs(propagated.Ex)**2), 0)
74 | self.assertGreater(jnp.sum(jnp.abs(propagated.Ey)**2), 0)
75 |
76 | if __name__ == '__main__':
77 | unittest.main()
78 |
--------------------------------------------------------------------------------
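The three `test_polarized_light_source_*` cases check that the Jones vector routes energy into the expected field components. The same idea can be reproduced without XLuminA by multiplying a scalar Gaussian envelope by a normalized Jones vector. In this sketch `polarized_gaussian` is a hypothetical helper; the envelope exp(-(x²/w0x² + y²/w0y²)) and the unit-norm Jones vector are assumptions, not necessarily how XLuminA's `gaussian_beam` is implemented.

```python
import numpy as np


def polarized_gaussian(x, y, w0, jones_vector):
    """Hypothetical helper: Gaussian envelope times a unit-norm Jones vector (Ex, Ey)."""
    X, Y = np.meshgrid(x, y)
    envelope = np.exp(-(X ** 2 / w0[0] ** 2 + Y ** 2 / w0[1] ** 2))
    jones = np.asarray(jones_vector, dtype=complex)
    # Unit norm: total power does not depend on the chosen Jones vector
    jx, jy = jones / np.linalg.norm(jones)
    return jx * envelope, jy * envelope


def component_power(E):
    return float(np.sum(np.abs(E) ** 2))


x = y = np.linspace(-1500, 1500, 256)
Ex_h, Ey_h = polarized_gaussian(x, y, (1200, 1200), (1, 0))  # horizontal
Ex_d, Ey_d = polarized_gaussian(x, y, (1200, 1200), (1, 1))  # diagonal

print(component_power(Ey_h))                                     # 0.0: no vertical component
print(np.isclose(component_power(Ex_d), component_power(Ey_d)))  # True: equal split
```

With this normalization a diagonal beam carries the same total power as a horizontal one, split equally between Ex and Ey, which is exactly what the diagonal test probes (it only checks that both parts are non-zero).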
/tests/test_wave_optics.py:
--------------------------------------------------------------------------------
1 | # Test for wave optics module
2 | import os
3 | import sys
4 |
5 | # Setting the path for XLuminA modules:
6 | current_path = os.path.abspath(os.path.join('..'))
7 | module_path = os.path.join(current_path)
8 |
9 | if module_path not in sys.path:
10 | sys.path.append(module_path)
11 |
12 | import unittest
13 | import jax.numpy as jnp
14 | import numpy as np
15 | from xlumina.wave_optics import ScalarLight, LightSource
16 |
17 | class TestWaveOptics(unittest.TestCase):
18 | def setUp(self):
19 | self.wavelength = 633e-3 # um (633 nm)
20 | self.resolution = 1024
21 | self.x = np.linspace(-1500, 1500, self.resolution)
22 | self.y = np.linspace(-1500, 1500, self.resolution)
23 | self.k = 2 * jnp.pi / self.wavelength
24 |
25 | def test_scalar_light(self):
26 | light = ScalarLight(self.x, self.y, self.wavelength)
27 | self.assertEqual(light.wavelength, self.wavelength)
28 | self.assertEqual(light.k, self.k)
29 | self.assertEqual(light.field.shape, (self.resolution, self.resolution))
30 |
31 | def test_light_source_gb(self):
32 | source = LightSource(self.x, self.y, self.wavelength)
33 | source.gaussian_beam(w0=(1200, 1200), E0=1)
34 | self.assertEqual(source.wavelength, self.wavelength)
35 | self.assertEqual(source.field.shape, (self.resolution, self.resolution))
36 | self.assertGreater(jnp.sum(jnp.abs(source.field)**2), 0)
37 |
38 | def test_light_source_pw(self):
39 | source = LightSource(self.x, self.y, self.wavelength)
40 | source.plane_wave(A=1, theta=0, phi=0, z0=0)
41 | self.assertGreater(jnp.sum(jnp.abs(source.field)**2), 0)
42 |
43 | def test_rs_propagation(self):
44 | light = LightSource(self.x, self.y, self.wavelength)
45 | light.gaussian_beam(w0=(1200, 1200), E0=1)
46 | propagated, _ = light.RS_propagation(z=1000)
47 | self.assertEqual(propagated.field.shape, (self.resolution, self.resolution))
48 |
49 | def test_czt(self):
50 | light = LightSource(self.x, self.y, self.wavelength)
51 | light.gaussian_beam(w0=(1200, 1200), E0=1)
52 | propagated = light.CZT(z=1000)
53 | self.assertEqual(propagated.field.shape, (self.resolution, self.resolution))
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
--------------------------------------------------------------------------------
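The propagation tests above only check output shapes. A stronger, physics-based check is energy conservation, the ratio reported by `toolbox.is_conserving_energy`. The sketch below illustrates that check with a plain NumPy angular-spectrum propagator. This is a swapped-in method chosen for brevity, not XLuminA's `RS_propagation` or `CZT`, and the grid and beam parameters are assumptions. When every sampled spatial frequency satisfies kx² + ky² < k², the transfer function has unit modulus, so Parseval's theorem makes the ratio equal to 1 up to rounding.

```python
import numpy as np


def angular_spectrum_propagate(field, dx, wavelength, z):
    """Propagate a sampled scalar field by z with the angular spectrum method.

    Evanescent plane-wave components (kx**2 + ky**2 > k**2) are discarded.
    """
    ny, nx = field.shape
    k = 2 * np.pi / wavelength
    kx = 2 * np.pi * np.fft.fftfreq(nx, d=dx)
    ky = 2 * np.pi * np.fft.fftfreq(ny, d=dx)
    KX, KY = np.meshgrid(kx, ky)
    kz_sq = k ** 2 - KX ** 2 - KY ** 2
    propagating = kz_sq > 0
    kz = np.sqrt(np.where(propagating, kz_sq, 0.0))
    transfer = np.where(propagating, np.exp(1j * kz * z), 0.0)
    return np.fft.ifft2(np.fft.fft2(field) * transfer)


def energy_ratio(source, propagated):
    # Same ratio as is_conserving_energy computes for a scalar field
    return np.sum(np.abs(propagated) ** 2) / np.sum(np.abs(source) ** 2)


wavelength = 0.633                     # um
x = y = np.linspace(-1500, 1500, 256)  # um
dx = x[1] - x[0]
X, Y = np.meshgrid(x, y)
source = np.exp(-(X ** 2 + Y ** 2) / 300.0 ** 2)  # Gaussian beam, w0 = 300 um (assumed)

propagated = angular_spectrum_propagate(source, dx, wavelength, z=1000.0)
print(energy_ratio(source, propagated))
```

A ratio noticeably below 1 in a real simulation usually means light left the computational window or evanescent components were dropped, as the `is_conserving_energy` docstring notes.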
/xlumina/__init__.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 |
3 | """
4 | Define basic units:
5 |
6 | um = microns -> we set um = 1.
7 | nm = nanometers
8 | mm = millimeters
9 | cm = centimeters
10 |
11 | radians -> we set radians = 1.
12 | degrees = 180 / jnp.pi -> multiplying an angle in radians by degrees converts it to degrees.
13 | """
14 |
15 | um = 1
16 | nm = 1e-3
17 | mm = 1e3
18 | cm = 1e4
19 |
20 | radians = 1
21 | degrees = 180/jnp.pi
--------------------------------------------------------------------------------
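These constants make microns the base length unit and radians the base angle unit: multiplying a number by `nm`, `mm` or `cm` converts it to microns, and multiplying an angle in radians by `degrees` expresses it in degrees (dividing by `degrees` goes the other way). A minimal self-contained illustration, which re-declares the same constants with NumPy instead of importing `xlumina`:

```python
import numpy as np

# Same values as xlumina/__init__.py (microns and radians are the base units)
um = 1
nm = 1e-3
mm = 1e3
cm = 1e4
radians = 1
degrees = 180 / np.pi

wavelength = 633 * nm                     # 0.633 um, written as 633e-3 in the tests
half_window = 1.5 * mm                    # 1500 um, the x/y extent used in the tests
angle_in_degrees = (np.pi / 4) * degrees  # radians -> degrees
angle_in_radians = 45 / degrees           # degrees -> radians

print(wavelength, half_window, angle_in_degrees, angle_in_radians)
```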
/xlumina/loss_functions.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from jax import jit, vmap, config
3 | from .__init__ import um
4 |
5 | # Set this to False if f32 is enough precision for you.
6 | enable_float64 = True
7 | if enable_float64:
8 | config.update("jax_enable_x64", True)
9 |
10 | """ Loss functions:
11 |
12 | - small_area_hybrid
13 | - vectorized_loss_hybrid
14 | - mean_batch_MSE_Intensity
15 | - vMSE_Amplitude
16 | - vMSE_Phase
17 | - vMSE_Intensity
18 | - MSE_Amplitude
19 | - MSE_Phase
20 | - MSE_Intensity
21 |
22 | """
23 |
24 | def small_area_hybrid(detected_intensity):
25 | """
26 | [Small area loss function valid for hybrid (topology + optical parameters) optimization:]
27 |
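28 | Computes the area of a mask (pixels whose normalized intensity exceeds epsilon * max) divided by the fraction of the total intensity comprised inside that mask. Lower values correspond to smaller spots that hold a larger share of the power.
29 |
30 | Parameters:
31 | detected_intensity (jnp.array): Detected intensity array.
32 | epsilon (float, fixed to 0.7 inside the function): fraction of the maximum intensity used as the mask threshold.
33 |
34 | Returns the loss value (scalar jnp.array).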
35 | """
36 | epsilon = 0.7
37 | eps = 1e-08
38 | I = detected_intensity / (jnp.sum(detected_intensity) + eps)
39 | mask = jnp.where(I > epsilon*jnp.max(I), 1, 0)
40 | return jnp.sum(mask) / (jnp.sum(mask * I) + eps)
41 |
42 | @jit
43 | def vectorized_loss_hybrid(detected_intensities):
44 | """[For loss_hybrid]: vectorizes loss function to be used across various detectors"""
45 | # Input field has (M, N, N) shape
46 | vloss = vmap(small_area_hybrid, in_axes = (0))
47 | # Call the vectorized function
48 | loss_val = vloss(detected_intensities)
49 | # Returns shape (M,): one loss value per detector
50 | return loss_val
51 |
52 | def mean_batch_MSE_Intensity(optimized, target):
53 | """
54 | [Computed for batch optimization in 4f system]. Vectorized version of MSE_Intensity.
55 |
56 | Returns the mean value of all the MSE for each (optimized, target) pairs and a jnp.array with MSE values from each pair.
57 | """
58 | MSE = vmap(MSE_Intensity, in_axes=(0, 0))(optimized, target)
59 | return jnp.mean(MSE), MSE
60 |
61 | def vMSE_Amplitude(input_light, target_light):
62 | """
63 | Computes the Mean Squared Error (in Amplitude) for each electric field component (computed in parallel) Ei (i = x, y, z).
64 |
65 | Parameters:
66 | input_light (object): VectorizedLight in the focal plane of an objective lens.
67 | target_light (object): VectorizedLight in the focal plane of an objective lens.
68 |
69 | Returns the MSE in jnp.array [MSEx, MSEy, MSEz].
70 | """
71 | E_in = jnp.stack([input_light.Ex, input_light.Ey, input_light.Ez], axis=-1)
72 | E_target = jnp.stack([target_light.Ex, target_light.Ey, target_light.Ez], axis=-1)
73 | vectorized_MSE = vmap(MSE_Amplitude, in_axes=(2, 2))
74 | MSE_out = vectorized_MSE(E_in, E_target)
75 | return MSE_out
76 |
77 | def vMSE_Phase(input_light, target_light):
78 | """
79 | Computes the Mean Squared Error (in Phase) for each electric field component (computed in parallel) Ei (i = x, y, z).
80 |
81 | Parameters:
82 | input_light (object): VectorizedLight in the focal plane of an objective lens.
83 | target_light (object): VectorizedLight in the focal plane of an objective lens.
84 |
85 | Returns the MSE in jnp.array [MSEx, MSEy, MSEz].
86 | """
87 | E_in = jnp.stack([input_light.Ex, input_light.Ey, input_light.Ez], axis=-1)
88 | E_target = jnp.stack([target_light.Ex, target_light.Ey, target_light.Ez], axis=-1)
89 | vectorized_MSE = vmap(MSE_Phase, in_axes=(2, 2))
90 | MSE_out = vectorized_MSE(E_in, E_target)
91 | return MSE_out
92 |
93 | def vMSE_Intensity(input_light, target_light):
94 | """
95 | Computes the Mean Squared Error (in Intensity) for each electric field component (computed in parallel) Ei (i = x, y, z).
96 |
97 | Parameters:
98 | input_light (object): VectorizedLight in the focal plane of an objective lens.
99 | target_light (object): VectorizedLight in the focal plane of an objective lens.
100 |
101 | Returns the MSE in jnp.array [MSEx, MSEy, MSEz].
102 | """
103 | E_in = jnp.stack([input_light.Ex, input_light.Ey, input_light.Ez], axis=-1)
104 | E_target = jnp.stack([target_light.Ex, target_light.Ey, target_light.Ez], axis=-1)
105 | vectorized_MSE = vmap(MSE_Intensity, in_axes=(2, 2))
106 | MSE_out = vectorized_MSE(E_in, E_target)
107 | return MSE_out
108 |
109 | @jit
110 | def MSE_Amplitude(input_light, target_light):
111 | """
112 | Computes the Mean Squared Error (in Amplitude) for a given electric field component Ex, Ey or Ez.
113 |
114 | Parameters:
115 | input_light (array): If origin light is VectorizedLight, field Ex, Ey or Ez in the detector. For ScalarLight, it corresponds to .field.
116 | target_light (array): Ground truth - field Ex, Ey or Ez in the detector.
117 |
118 | Returns the MSE (jnp.array).
119 | """
120 | num_pix = input_light.shape[0] * input_light.shape[1]
121 | return jnp.sum((jnp.abs(input_light) - jnp.abs(target_light)) ** 2) / num_pix
122 |
123 | @jit
124 | def MSE_Phase(input_light, target_light):
125 | """
126 | Computes the Mean Squared Error (in Phase) for a given electric field component Ex, Ey or Ez.
127 |
128 | Parameters:
129 | input_light (array): If origin light is VectorizedLight, field Ex, Ey or Ez in the detector. For ScalarLight, it corresponds to .field.
130 | target_light (array): Ground truth - field Ex, Ey or Ez in the detector.
131 |
132 | Returns the MSE (jnp.array).
133 | """
134 | num_pix = input_light.shape[0] * input_light.shape[1]
135 | return jnp.sum((jnp.angle(input_light) - jnp.angle(target_light)) ** 2) / num_pix
136 |
137 | @jit
138 | def MSE_Intensity(input_light, target_light):
139 | """
140 | Computes the Mean Squared Error (in Intensity) for a given electric field component Ex, Ey or Ez.
141 |
142 | Parameters:
143 | input_light (array): If origin light is VectorizedLight, field Ex, Ey or Ez in the detector. For ScalarLight, it corresponds to .field.
144 | target_light (array): Ground truth - field Ex, Ey or Ez in the detector.
145 |
146 | Returns the MSE (jnp.array).
147 | """
148 | num_pix = jnp.shape(input_light)[0] * jnp.shape(input_light)[1]
149 | return jnp.sum((jnp.abs(input_light)**2 - jnp.abs(target_light)**2) ** 2) / num_pix
--------------------------------------------------------------------------------
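`small_area_hybrid` favours spots that are small and hold a large share of the total power: the mask counts pixels above 70% of the peak, and dividing by the fraction of power inside that mask penalizes spots that leak energy elsewhere. The vectorized MSE helpers compute one MSE per field component by mapping over the last axis of an (N, N, 3) stack, and since `MSE_Intensity` divides the summed squared error by the pixel count, each entry is a mean over the two pixel axes. The NumPy sketch below mirrors both computations to show that a narrower Gaussian spot gets a lower loss. The function names are illustrative, not part of XLuminA, and exposing `epsilon` as a keyword argument is an assumption (the original hard-codes 0.7).

```python
import numpy as np


def small_area_loss(intensity, epsilon=0.7, eps=1e-8):
    """NumPy mirror of small_area_hybrid (epsilon exposed as a keyword)."""
    I = intensity / (np.sum(intensity) + eps)        # normalize to unit total power
    mask = np.where(I > epsilon * np.max(I), 1, 0)   # pixels above epsilon * peak
    return np.sum(mask) / (np.sum(mask * I) + eps)   # mask area / power fraction inside


def per_component_mse_intensity(E_in, E_target):
    """MSE in intensity per component; inputs have shape (N, N, 3) = (Ex, Ey, Ez)."""
    diff = np.abs(E_in) ** 2 - np.abs(E_target) ** 2
    return np.mean(diff ** 2, axis=(0, 1))           # -> [MSEx, MSEy, MSEz]


x = np.linspace(-100, 100, 201)
X, Y = np.meshgrid(x, x)


def spot(sigma):
    return np.exp(-(X ** 2 + Y ** 2) / (2 * sigma ** 2))


loss_narrow = small_area_loss(spot(5.0))
loss_wide = small_area_loss(spot(15.0))
print(loss_narrow < loss_wide)  # True: the tighter spot has the lower loss

E = np.stack([spot(5.0), spot(15.0), np.zeros_like(X)], axis=-1).astype(complex)
mse = per_component_mse_intensity(E, E)
print(mse)  # [0. 0. 0.] for identical fields
```

For a Gaussian spot the fraction of power inside the 70% mask is roughly constant, so the loss grows with the mask area, i.e. with σ².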
/xlumina/toolbox.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax.numpy as jnp
3 | import h5py
4 | import random
5 | from PIL import Image
6 | from jax import config, jit, nn
7 | from scipy.optimize import curve_fit
8 | import matplotlib.pyplot as plt
9 |
10 | # Set this to False if f32 is enough precision for you.
11 | enable_float64 = True
12 | if enable_float64:
13 | config.update("jax_enable_x64", True)
14 | d_type = np.float64 if enable_float64 else np.float32
15 |
16 | """
17 | Contains useful functions:
18 |
19 | - space
20 | - wrap_phase
21 | - is_conserving_energy
22 | - softmin
23 | - delta_kronecker
24 | - build_LCD_cell
25 | - draw_sSLM
26 | NEW:- draw_sSLM_amplitude
27 | - moving_avg
28 | NEW:- image_to_binary_mask
29 | - rotate_mask
30 | - nearest
31 | - extract_profile
32 | - profile
33 | NEW:- gaussian
34 | NEW:- gaussian_2d
35 | NEW:- lorentzian
36 | NEW:- lorentzian_2d
37 | NEW:- fwhm_1d_fit
38 | NEW:- fwhm_2d_fit
39 | - spot_size
40 | NEW:- compute_fwhm
41 | - find_max_min
42 | - fwhm_1d (no fitting)
43 | - fwhm_2d (no fitting)
44 | NEW:> Functions to process datasets (hdf5 files):
45 | CLASS MultiHDF5DataLoader
46 |
47 | """
48 |
49 | def space(x_total, num_pix):
50 | """
51 | Define the space of the simulation.
52 |
53 | Parameters:
54 | x_total (float): Length of half the array (in microns).
55 | num_pix (int): 1D resolution (number of samples per axis).
56 |
57 | Returns x and y (np.ndarray).
58 | """
59 | x = np.linspace(-x_total, x_total, num_pix, dtype=d_type)
60 | y = np.linspace(-x_total, x_total, num_pix, dtype=d_type)
61 | return x, y
62 |
63 | def wrap_phase(phase):
64 | """
65 | Wraps the input phase into [-pi, pi] range.
66 |
67 | Parameters:
68 | phase (jnp.array): Phase to wrap.
69 |
70 | Returns the wrapped-phase (jnp.array).
71 | """
72 | return jnp.arctan2(jnp.sin(phase), jnp.cos(phase))
73 |
74 | def is_conserving_energy(light_source, propagated_light):
75 | """
76 | Computes the total intensity of the light source and compares it with the total intensity of the propagated light.
77 | [Ref: J. Li, Z. Fan, Y. Fu, Proc. SPIE 4915, (2002)].
78 |
79 | Parameters:
80 | light_source (object): Light source (ScalarLight or VectorizedLight).
81 | propagated_light (object): Propagated light of the same type.
82 |
83 | Ideally, Itotal_propagated / I_source = 1.
84 | Values <1 can happen when light is lost (i.e., light gets outside the computational window).
85 |
86 | Returns Itotal_propagated / I_source (jnp.array).
87 | """
88 | if light_source.info == 'Wave optics light' or light_source.info == 'Wave optics light source':
89 | I_source = jnp.sum(jnp.abs(light_source.field**2))
90 | I_propagated = jnp.sum(jnp.abs(propagated_light.field**2))
91 |
92 | else:
93 | I_source = jnp.sum(jnp.abs(light_source.Ex**2)) + jnp.sum(jnp.abs(light_source.Ey**2)) + jnp.sum(jnp.abs(light_source.Ez**2))
94 | I_propagated = jnp.sum(jnp.abs(propagated_light.Ex**2)) + jnp.sum(jnp.abs(propagated_light.Ey**2)) + jnp.sum(jnp.abs(propagated_light.Ez**2))
95 |
96 | return I_propagated / I_source
97 |
98 | @jit
99 | def softmin(args, beta=90):
100 | """
101 | Differentiable approximation of the min() function; it approaches min(args) from below as beta grows.
102 | """
103 | return - nn.logsumexp(-beta * args) / beta
104 |
105 |
106 | def delta_kronecker(a, b):
107 | """
108 | Computes the Kronecker delta.
109 |
110 | Parameters:
111 | a (float): Number
112 | b (float): Number
113 |
114 | Returns (int).
115 | """
116 | if a == b:
117 | return 1
118 | else:
119 | return 0
120 |
121 | def build_LCD_cell(eta, theta, shape):
122 | """
123 | Builds the LCD cell: eta and theta are constant across the cell [not pixel-wise modulation!!].
124 |
125 | Parameters:
126 | eta (float): Phase difference between Ex and Ey (in radians).
127 | theta (float): Tilt of the fast axis w.r.t. horizontal (in radians).
128 | shape (int): 1D resolution.
129 |
130 | Returns the phase and tilt (jnp.array).
131 | """
132 | # Builds constant eta and theta LCD cell
133 | eta_array = eta * jnp.ones(shape=(shape, shape))
134 | theta_array = theta * jnp.ones(shape=(shape, shape))
135 | return eta_array, theta_array
136 |
137 | def draw_sSLM(alpha, phi, extent, extra_title='', save_file=False, filename=''):
138 | """
139 | Plots the phase masks of the sSLM (for VectorizedLight).
140 |
141 | Parameters:
142 | alpha (jnp.array): Phase mask to be applied to Ex (in radians).
143 | phi (jnp.array): Phase mask to be applied to Ey (in radians).
144 | extent (jnp.array): Limits for x and y for plotting purposes.
145 | extra_title (str): Adds extra info to the plot title.
146 | save_file (bool): If True, saves the figure.
147 | filename (str): Name of the figure.
148 | """
149 | fig, axes = plt.subplots(1, 2, figsize=(14, 3))
150 | cmap = 'twilight'
151 |
152 | ax = axes[0]
153 | im = ax.imshow(alpha, cmap=cmap, extent=extent, origin='lower')
154 | ax.set_title(f"SLM #1. {extra_title}")
155 | ax.set_xlabel(r'$x (\mu m)$')
156 | ax.set_ylabel(r'$y (\mu m)$')
157 | fig.colorbar(im, ax=ax)
158 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi)
159 |
160 | ax = axes[1]
161 | im = ax.imshow(phi, cmap=cmap, extent=extent, origin='lower')
162 | ax.set_title(f"SLM #2. {extra_title}")
163 | ax.set_xlabel(r'$x (\mu m)$')
164 | ax.set_ylabel(r'$y (\mu m)$')
165 | fig.colorbar(im, ax=ax)
166 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi)
167 |
168 | plt.tight_layout()
169 |
170 | if save_file is True:
171 | plt.savefig(filename)
172 | print(f"Plot saved as {filename}")
173 |
174 | plt.show()
175 |
176 | def draw_sSLM_amplitude(A1, A2, extent, extra_title='', save_file=False, filename=''):
177 | """
178 | Plots the amplitude masks of the sSLM (for VectorizedLight).
179 |
180 | Parameters:
181 | A1 (jnp.array): Amp mask to be applied to Ex (AU).
182 | A2 (jnp.array): Amp mask to be applied to Ey (AU).
183 | extent (jnp.array): Limits for x and y for plotting purposes.
184 | extra_title (str): Adds extra info to the plot title.
185 | save_file (bool): If True, saves the figure.
186 | filename (str): Name of the figure.
187 | """
188 | fig, axes = plt.subplots(1, 2, figsize=(14, 3))
189 | cmap = 'Greys_r'
190 |
191 | ax = axes[0]
192 | im = ax.imshow(A1, cmap=cmap, extent=extent, origin='lower')
193 | ax.set_title(f"SLM #1: amplitude mask {extra_title}")
194 | ax.set_xlabel(r'$x (\mu m)$')
195 | ax.set_ylabel(r'$y (\mu m)$')
196 | fig.colorbar(im, ax=ax)
197 | im.set_clim(vmin=0, vmax=1)
198 |
199 | ax = axes[1]
200 | im = ax.imshow(A2, cmap=cmap, extent=extent, origin='lower')
201 | ax.set_title(f"SLM #2: amplitude mask {extra_title}")
202 | ax.set_xlabel(r'$x (\mu m)$')
203 | ax.set_ylabel(r'$y (\mu m)$')
204 | fig.colorbar(im, ax=ax)
205 | im.set_clim(vmin=0, vmax=1)
206 |
207 | plt.tight_layout()
208 |
209 | if save_file is True:
210 | plt.savefig(filename)
211 | print(f"Plot saved as {filename}")
212 |
213 | plt.show()
214 |
215 |
216 | def moving_avg(window_size, data):
217 | """
218 | Compute the moving average of a dataset.
219 |
220 | Parameters:
221 | window_size (int): Number of datapoints to compute the avg.
222 | data (jnp.array): Data.
223 |
224 | Returns moving average (jnp.array).
225 | """
226 | return jnp.convolve(jnp.array(data), jnp.ones(window_size)/window_size, mode='valid')
227 |
228 | def image_to_binary_mask(filename, x, y, mirror = 'vertical', normalize=True, invert=False, threshold=0.5):
229 | """
230 | Converts an image into a binary mask (given a threshold).
231 |
232 | Parameters:
233 | filename (str): os path to the image file
234 | x, y (jnp.array): corresponds to the space dimensions where the mask is placed
235 | mirror (str): direction to mirror the image. Can be 'horizontal', 'vertical' or 'both'. Default is 'vertical'
236 | normalize (bool): if True, normalizes the image
237 | invert (bool): if True, inverts the binarization
238 | threshold (float): pixel-value threshold for binarization (in [0, 1] when normalize=True)
239 |
240 | Returns:
241 | Binary mask (jnp.array)
242 | """
243 | with Image.open(filename) as img:
244 | img_gray = img.convert('L')
245 |
246 | if mirror == 'horizontal':
247 | img_gray = img_gray.transpose(Image.FLIP_LEFT_RIGHT)
248 | if mirror == 'vertical':
249 | img_gray = img_gray.transpose(Image.FLIP_TOP_BOTTOM)
250 | if mirror == 'both':
251 | img_gray = img_gray.transpose(Image.FLIP_LEFT_RIGHT)
252 | img_gray = img_gray.transpose(Image.FLIP_TOP_BOTTOM)
253 |
254 | size_x, size_y = jnp.size(x), jnp.size(y)
255 | img_gray = img_gray.resize((size_y, size_x))
256 |
257 | img_array = jnp.array(img_gray) # convert into jax array
258 |
259 | if normalize:
260 | img_array = (img_array - jnp.min(img_array)) / (jnp.max(img_array) - jnp.min(img_array))
261 |
262 | if invert:
263 | img_array = jnp.max(img_array) - img_array
264 |
265 | binary_mask = (img_array > threshold).astype(jnp.uint8)
266 |
267 | return binary_mask
268 |
269 | def rotate_mask(X, Y, angle, origin=None):
270 | """
271 | Rotates the (X, Y) frame of a mask w.r.t. origin.
272 |
273 | Parameters:
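274 | X, Y (jnp.array): Meshgrid coordinates of the mask (in microns).
275 | angle (float): Rotation angle (in radians).
276 | origin (float, float): Coordinates w.r.t. which the rotation is performed (in microns). If None, the center of the grid is used.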
277 | Returns the rotated meshgrid X, Y (jnp.array).
278 |
279 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
280 | """
281 | if origin is None:
282 | x0 = (jnp.max(X) + jnp.min(X)) / 2 # grid center; valid for 1D arrays and meshgrids
283 | y0 = (jnp.max(Y) + jnp.min(Y)) / 2
284 | else:
285 | x0, y0 = origin
286 |
287 | Xrot = (X - x0) * jnp.cos(angle) + (Y - y0) * jnp.sin(angle)
288 | Yrot = -(X - x0) * jnp.sin(angle) + (Y - y0) * jnp.cos(angle)
289 | return Xrot, Yrot
290 |
291 | def nearest(array, value):
292 | """
293 | Finds the nearest value and its index in an array.
294 |
295 | Parameters:
296 | array (jnp.array): Array to analyze.
297 | value (float): Value whose nearest element in the array is sought.
298 |
299 | Returns index (idx), the value of the array at idx position and the distance.
300 |
301 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
302 | """
303 | idx = (jnp.abs(array - value)).argmin()
304 | return idx, array[idx], abs(array[idx] - value)
305 |
306 | def extract_profile(data_2d, x_points, y_points):
307 | """
308 | [From profile] Extract the values along a line defined by x_points and y_points.
309 |
310 | Parameters:
311 | data_2d (jnp.array): Input data from which extract the profile.
312 | x_points, y_points (jnp.array): Column and row pixel indices (possibly fractional) of the points along the line.
313 |
314 | Returns the profile (jnp.array).
315 | """
316 | x_indices = jnp.round(x_points).astype(int)
317 | y_indices = jnp.round(y_points).astype(int)
318 |
319 | x_indices = jnp.clip(x_indices, 0, data_2d.shape[1] - 1)
320 | y_indices = jnp.clip(y_indices, 0, data_2d.shape[0] - 1)
321 |
322 | profile = [data_2d[y, x] for x, y in zip(x_indices, y_indices)]
323 | return jnp.array(profile)
324 |
325 | def profile(data_2d, x, y, point1='', point2=''):
326 | """
327 | Determine profile for a given input without using interpolation.
328 |
329 | Parameters:
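330 | data_2d (jnp.array): Input 2D array from which to extract the profile.
331 | x, y (jnp.array): 1D coordinate arrays of data_2d (in microns).
332 | point1 (float, float): Initial point (x1, y1).
333 | point2 (float, float): Final point (x2, y2).
334 | Returns h (distance along the line, centered at 0) and the profile z_profile (jnp.array).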
335 | """
336 | x1, y1 = point1
337 | x2, y2 = point2
338 |
339 | ix1, value, distance = nearest(x, x1)
340 | ix2, value, distance = nearest(x, x2)
341 | iy1, value, distance = nearest(y, y1)
342 | iy2, value, distance = nearest(y, y2)
343 |
344 | # Create a set of x and y points along the line between point1 and point2
345 | x_points = jnp.linspace(ix1, ix2, int(jnp.hypot(ix2-ix1, iy2-iy1)))
346 | y_points = jnp.linspace(iy1, iy2, int(jnp.hypot(ix2-ix1, iy2-iy1)))
347 |
348 | h = jnp.linspace(0, jnp.sqrt((y2 - y1)**2 + (x2 - x1)**2), len(x_points))
349 | h = h - h[-1] / 2
350 |
351 | z_profile = extract_profile(data_2d, x_points, y_points)
352 |
353 | return h, z_profile
354 |
355 | def gaussian(x, amplitude, mean, std_dev):
356 | """
357 | [In fwhm_1d_fit]
358 | Returns 1D Gaussian (Normal distribution) for FWHM calculation
359 | Parameters:
360 | x (jnp.array): 1D-position array
361 | amplitude (float): peak value of the distribution
362 | mean (float): location of the peak of the distribution in X or Y
363 | std_dev (float): standard deviation in X or Y
364 | """
365 | return amplitude * jnp.exp(-((x - mean) / std_dev)**2 / 2)
366 |
367 | def gaussian_2d(xy, amplitude, mean_x, mean_y, stdev_x, stdev_y):
368 | """
369 | [In fwhm_2d_fit]
370 | Returns 2D Gaussian (Normal distribution) for FWHM 2D calculation
371 | Parameters:
372 | xy (tuple): contains two arrays, X and Y: the meshgrid of x and y coordinates (fwhm_2d_fit passes them raveled)
373 | amplitude (float): peak value of the distribution
374 | mean_x (float): location of the peak of the distribution in X
375 | mean_y (float): location of the peak of the distribution in Y
376 | stdev_x (float): standard deviation in X
377 | stdev_y (float): standard deviation in Y.
378 | """
379 | X, Y = xy
380 | return amplitude * jnp.exp(-((X - mean_x)**2 / (2 * stdev_x**2) + (Y - mean_y)**2 / (2 * stdev_y**2)))
381 |
382 | def lorentzian(x, x0, gamma):
383 | """
384 | [In fwhm_1d_fit]
385 | Returns Lorentzian -- pathological distribution (expected value and variance are undefined)
386 | Parameters:
387 | x (jnp.array): 1D-position array
388 | x0 (float): location of the peak of the distribution
389 | gamma (float): scale parameter. Specifies FWHM = 2 * gamma.
390 | """
391 | return (1/jnp.pi) * (gamma / ((x-x0)**2 + gamma**2))
392 |
393 | def lorentzian_2d(xy, amplitude, x0, y0, gamma_x, gamma_y):
394 | """
395 | [In fwhm_2d_fit]
396 | Returns 2D Lorentzian -- pathological distribution (expected value and variance are undefined)
397 |
398 | Parameters:
399 | xy (tuple): contains two arrays, X and Y: the meshgrid of x and y coordinates (fwhm_2d_fit passes them raveled)
400 | amplitude (float): amplitude of the peak
401 | x0 (float): location of the peak of the distribution in X
402 | y0 (float): location of the peak of the distribution in Y
403 | gamma_x (float): scale parameter in X.
404 | gamma_y (float): scale parameter in Y.
405 | """
406 | X, Y = xy
407 | return amplitude / (1 + ((X - x0) / gamma_x)**2 + ((Y - y0) / gamma_y)**2)
408 |
409 | def fwhm_1d_fit(x, intensity, fit = 'gaussian'):
410 | """
411 | Compute FWHM of a 1D-array using fit function (gaussian or lorentzian)
412 |
413 | Parameters:
414 | x (jnp.array): 1D-position array.
415 | intensity (jnp.array): 1-dimensional intensity array.
416 | fit (str): can be 'gaussian' or 'lorentzian'
417 |
418 | Returns:
419 | popt ((amplitude_fit, mean_fit, stddev_fit) for 'gaussian'; (x0_fit, gamma_fit) for 'lorentzian'),
420 | fwhm (float, in um)
421 | and r_squared (float, r-squared metric of the fit).
422 | """
423 | if fit == 'lorentzian':
424 | # initial guess (p0) for curve_fit
425 | x0_guess = x[jnp.argmax(intensity)] # position of the intensity peak
426 | gamma_guess = 0.5
427 | # lorentzian fit -- call scipy.curve_fit
428 | popt, _ = curve_fit(lorentzian, np.array(x), np.array(intensity), p0=[x0_guess, gamma_guess], maxfev=100000)
429 | _, gamma_fit = popt
430 |
431 | fwhm = 2 * gamma_fit
432 |
433 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination
434 | # r^2 = 1 - SSres / SStot
435 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot)
436 | ss_res = jnp.sum((intensity - lorentzian(x, *popt))**2)
437 |
438 | elif fit == 'gaussian':
439 | # initial guess (p0) for curve_fit
440 | a_guess = jnp.max(intensity)
441 | mean_guess = x[jnp.argmax(intensity)]
442 | std_dev_guess = jnp.sqrt(jnp.sum((x - mean_guess)**2 * intensity) / jnp.sum(intensity))
443 |
444 | # gaussian fit -- call scipy.curve_fit
445 | popt, _ = curve_fit(gaussian, np.array(x), np.array(intensity), p0=[a_guess, mean_guess, std_dev_guess], maxfev=100000)
446 | _, _, stddev_fit = popt
447 |
448 | # FWHM normal distribution = 2*sqrt(2ln2)*sigma
449 | fwhm = 2 * jnp.sqrt(2 * jnp.log(2)) * stddev_fit
450 |
451 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination
452 | # r^2 = 1 - SSres / SStot
453 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot)
454 | ss_res = jnp.sum((intensity - gaussian(x, *popt))**2)
455 |
456 | else:
457 | raise ValueError("fit must be either 'gaussian' or 'lorentzian'")
458 |
459 | ss_tot = jnp.sum((intensity - jnp.mean(intensity))**2)
460 | r_squared = 1 - (ss_res / ss_tot)
461 |
462 | return popt, fwhm, r_squared
463 |
464 | def fwhm_2d_fit(x, y, intensity, fit = 'gaussian'):
465 | """
466 | Computes the FWHM of a 2D intensity array using a fit function (gaussian or lorentzian)
467 |
468 | Parameters:
469 | x (jnp.array): 1D-position array.
470 | y (jnp.array): 1D-position array.
471 | intensity (jnp.array): 2-dimensional intensity array.
472 | fit (str): can be 'gaussian' or 'lorentzian'
473 |
474 | Returns:
475 | popt (amplitude_fit, x0_fit, y0_fit, width_x_fit, width_y_fit), where the widths are sigmas (gaussian) or gammas (lorentzian),
476 | FWHM_2D = (fwhm_x, fwhm_y) (tuple, in um)
477 | and r_squared (float, r-squared metric of the fit).
478 | """
479 | X, Y = jnp.meshgrid(x, y)
480 | xy = jnp.vstack((X.ravel(), Y.ravel())) # vertical stack of ravel arrays for scipy's curve_fit (accepts indep. var. as a single arg)
481 |
482 | if fit == 'lorentzian':
483 | amplitude_guess = jnp.max(intensity)
484 | # get the (x, y) coordinates of the intensity maximum
485 | mean_x_guess, mean_y_guess = (
486 | x[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[1]],
487 | y[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[0]]
488 | )
489 | gamma_x_guess, gamma_y_guess = 0.5, 0.5
490 |
491 | popt, _ = curve_fit(lorentzian_2d, xy, intensity.ravel(),
492 | p0=[amplitude_guess, mean_x_guess, mean_y_guess, gamma_x_guess, gamma_y_guess],
493 | maxfev=100000)
494 |
495 | _, _, _, gamma_x_fit, gamma_y_fit = popt
496 | FWHM2D = 2 * gamma_x_fit, 2 * gamma_y_fit
497 |
498 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination
499 | # r^2 = 1 - SSres / SStot
500 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot)
501 | ss_res = jnp.sum((intensity - lorentzian_2d(xy, *popt).reshape(intensity.shape))**2)
502 |
503 | elif fit == 'gaussian':
504 | # initial guess for curve_fit
505 | amplitude_guess = jnp.max(intensity)
506 | # get the (x, y) coordinates of the intensity maximum
507 | x0_guess, y0_guess = (
508 | x[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[1]],
509 | y[jnp.unravel_index(jnp.argmax(intensity), intensity.shape)[0]]
510 | )
511 | sigma_x_guess = jnp.sqrt(jnp.sum((X - x0_guess)**2 * intensity) / jnp.sum(intensity))
512 | sigma_y_guess = jnp.sqrt(jnp.sum((Y - y0_guess)**2 * intensity) / jnp.sum(intensity))
513 |
514 | popt, _ = curve_fit(gaussian_2d, xy, intensity.ravel(),
515 | p0=[amplitude_guess, x0_guess, y0_guess, sigma_x_guess, sigma_y_guess],
516 | maxfev=100000)
517 |
518 | _, _, _, sigma_x_fit, sigma_y_fit = popt
519 | FWHM2D = 2 * jnp.sqrt(2 * jnp.log(2)) * sigma_x_fit, 2 * jnp.sqrt(2 * jnp.log(2)) * sigma_y_fit
520 |
521 | # compute r-squared
522 | # from https://en.wikipedia.org/wiki/Coefficient_of_determination
523 | # r^2 = 1 - SSres / SStot
524 | # here we compute residual sum of squares (SSres) and total sum of squares (SStot)
525 | ss_res = jnp.sum((intensity - gaussian_2d(xy, *popt).reshape(intensity.shape))**2)
526 |
527 | else:
528 | raise ValueError("fit must be either 'gaussian' or 'lorentzian'")
529 |
530 | ss_tot = jnp.sum((intensity - jnp.mean(intensity))**2)
531 | r_squared = 1 - (ss_res / ss_tot)
532 |
533 | return popt, FWHM2D, r_squared
534 |
535 | def spot_size(fwhm_x, fwhm_y, wavelength):
536 | """
537 | Computes the spot size, i.e., the area of the ellipse with semi-axes fwhm_x/2 and fwhm_y/2, in units of wavelength**2.
538 |
539 | Parameters:
540 | fwhm_x (float): FWHM in x (in microns).
541 | fwhm_y (float): FWHM in y (in microns).
542 | wavelength (float): Wavelength (in microns).
543 |
544 | Returns the spot size (jnp.array).
545 | """
546 | return jnp.pi * (fwhm_x/2) * (fwhm_y/2) / wavelength**2
547 |
548 | def compute_fwhm(light, light_specs, field='', fit = 'gaussian', dimension = '2D', pix_slice = None):
549 | """
550 | Computes FWHM in 1D or 2D (in um).
551 |
552 | Parameters:
553 | light (object): VectorizedLight object, or a jnp.array if field is set to 'Intensity'.
554 | light_specs (list): list with light specs - measurement plane [x, y].
555 | field (str): component for which to compute the FWHM. Can be 'Ex', 'Ey', 'Ez', 'r' (radial), 'rz' (total field) or 'Intensity' if the input is not a field but an intensity array.
556 | fit (str): fitting model, 'gaussian' or 'lorentzian'.
557 | dimension (str): can be '1D', '1D_x', '1D_y' or '2D'.
558 | pix_slice (int): pixel index at which the slice for the 1D calculation is taken. E.g., in the center: pix_slice = resolution // 2
559 | Returns:
560 | popt, fwhm, r_squared; where fwhm = FWHM_1D or FWHM_2D (tuple)
561 | """
562 | if field == 'Ex':
563 | intensity = (jnp.abs(light.Ex)) ** 2
564 | elif field == 'Ey':
565 | intensity = (jnp.abs(light.Ey)) ** 2
566 | elif field == 'Ez':
567 | intensity = (jnp.abs(light.Ez)) ** 2
568 | elif field == 'r':
569 | intensity = jnp.abs(light.Ex) ** 2 + (jnp.abs(light.Ey)) ** 2
570 | elif field == 'rz':
571 | intensity = (jnp.abs(light.Ex)) ** 2 + (jnp.abs(light.Ey)) ** 2 + (jnp.abs(light.Ez)) ** 2
572 | elif field == 'Intensity':
573 | intensity = jnp.abs(light)
574 |
575 | if '1D' in dimension:
576 | if dimension == '1D_x':
577 | intensity = intensity[:, pix_slice] # Slice the intensity array at pixel index pix_slice
578 | axis = light_specs[0]
579 | elif dimension == '1D_y':
580 | intensity = intensity[pix_slice, :] # Slice the intensity array at pixel index pix_slice
581 | axis = light_specs[1]
582 | else:
583 | axis = light_specs[0]
584 | popt, fwhm, r_squared = fwhm_1d_fit(axis, intensity, fit)
585 |
586 | if dimension == '2D':
587 | popt, fwhm, r_squared = fwhm_2d_fit(light_specs[0], light_specs[1], intensity, fit)
588 | # FWHM (tuple) = FWHM_x, FWHM_y
589 |
590 | return popt, fwhm, r_squared
591 |
592 | def find_max_min(value, x, y, kind =''):
593 | """
594 | Find the position of maximum and minimum values within a 2D array.
595 |
596 | Parameters:
597 | value (jnp.array): 2-dimensional array with values.
598 | x (jnp.array): x-position array
599 | y (jnp.array): y-position array
600 | kind (str): choose whether to detect minimum 'min' or maximum 'max'.
601 |
602 | Returns:
603 | idx (int, int): indexes of the position of max/min.
604 | xy (float, float): space position of max/min.
605 | ext_value (float): max/min value.
606 |
607 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
608 | """
609 | if kind =='min':
610 | val = jnp.where(value==jnp.min(value))
611 | if kind =='max':
612 | val = jnp.where(value==jnp.max(value))
613 |
614 | # Extract coordinates into separate arrays:
615 | coordinates = jnp.array(list(zip(val[1], val[0])))
616 | coords_0 = coordinates[:, 0]
617 | coords_1 = coordinates[:, 1]
618 |
619 | # Index array:
620 | idx = coordinates.astype(int)
621 |
622 | # Array with space positions:
623 | xy = jnp.stack([x[coords_0], y[coords_1]], axis=1)
624 |
625 | # Array with extreme values:
626 | ext_value = value[coords_1, coords_0]
627 |
628 | return idx, xy, ext_value
629 |
630 | def fwhm_2d(x, y, intensity):
631 | """
632 | Computes FWHM of an 2-dimensional intensity array.
633 |
634 | Parameters:
635 | x (jnp.array): x-position array.
636 | y (jnp.array): y-position array.
637 | intensity (jnp.array): 2-dimensional intensity array.
638 |
639 | Returns:
640 | fwhm_x and fwhm_y
641 |
642 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
643 | """
644 | i_position, _, _ = find_max_min(jnp.transpose(intensity), x, y, kind='max')
645 |
646 | Ix = intensity[:, i_position[0, 1]]
647 | Iy = intensity[i_position[0, 0], :]
648 |
649 | fwhm_x, _, _, _ = fwhm_1d(x, Ix)
650 | fwhm_y, _, _, _ = fwhm_1d(y, Iy)
651 |
652 | return fwhm_x, fwhm_y
653 |
654 | def fwhm_1d(x, intensity, I_max = None):
655 | """
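656 | Computes the FWHM of a 1-dimensional intensity array.
657 |
658 | Parameters:
659 | x (jnp.array): 1D-position array.
660 | intensity (jnp.array): 1-dimensional intensity array.
661 | I_max (float): Peak intensity used as reference; defaults to jnp.max(intensity).
662 | Returns:
663 | fwhm, x_left, x_right, I_half: FWHM, positions of the left and right half-maximum crossings, and half-maximum intensity.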
664 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
665 | """
666 | # Setting-up:
667 | dx = x[1] - x[0]
668 | if I_max is None:
669 | I_max = jnp.max(intensity)
670 | I_half = I_max * 0.5
671 |
672 | # Pixels with I max:
673 | i_max = jnp.argmax(intensity)
674 | # i_max = jnp.where(intensity == I_max)
675 | # # Compute the pixel location:
676 | # i_max = int(i_max[0][0])
677 |
678 | # Compute the slopes:
679 | i_left, _, distance_left = nearest(intensity[0:i_max], I_half)
680 | slope_left = (intensity[i_left + 1] - intensity[i_left]) / dx
681 |
682 | i_right, _, distance_right = nearest(intensity[i_max::], I_half)
683 | i_right += i_max
684 | slope_right = (intensity[i_right] - intensity[i_right - 1]) / dx
685 |
686 | x_right = x[i_right] - distance_right / slope_right
687 | x_left = x[i_left] - distance_left / slope_left
688 |
689 | # Compute fwhm:
690 | fwhm = x_right - x_left
691 |
692 | return fwhm, x_left, x_right, I_half
693 |
694 | # -----------------------------------------------------------
695 |
696 | """ Functions to process datasets (hdf5 files) """
697 |
698 | class MultiHDF5DataLoader:
699 | """
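700 | JAX data loader: infinite iterator that yields random (input, target) batches, as jnp.arrays, from the HDF5 files in a directory.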
701 | """
702 | def __init__(self, directory, batch_size):
703 | self.directory = directory
704 | self.batch_size = batch_size
705 |
706 | # Get a list of all HDF5 files in the directory
707 | self.files = [f for f in os.listdir(self.directory) if f.endswith('.hdf5')]
708 |
709 | if not self.files:
710 | raise ValueError(f"No HDF5 files found in directory: {self.directory}")
711 |
712 | def __iter__(self):
713 | return self
714 |
715 | def __next__(self):
716 | """
717 | Randomly selects one HDF5 file and randomly picks batch_size samples from it. Returns (input_batch, target_batch) as jnp.arrays.
718 | """
719 | # Randomly select one of the HDF5 files
720 | selected_file = random.choice(self.files)
721 | filepath = os.path.join(self.directory, selected_file)
722 |
723 | # Open the selected HDF5 file to get the total number of samples;
724 | # This is where we open the selected HDF5 file in read mode. However, we're not reading the entire dataset into memory.
725 | # Instead, we're just accessing its shape to get the total number of samples. The data remains on disk.
726 | with h5py.File(filepath, 'r') as hf:
727 | total_samples = hf["Input fields"].shape[0]
728 |
729 | # Randomly select indices for the batch
730 | batch_indices = sorted(random.sample(range(total_samples), self.batch_size))
731 |
732 | # Load the batch from the selected HDF5 file;
733 | # We open the HDF5 file in read mode again, but this time, we're using the randomly selected indices (batch_indices) to fetch only a specific subset of the data.
734 | # Only the data corresponding to these indices is loaded into memory, and not the entire dataset.
735 | with h5py.File(filepath, 'r') as hf:
736 | input_batch = hf["Input fields"][batch_indices]
737 | target_batch = hf["Target fields"][batch_indices]
738 |
739 | return jnp.array(input_batch), jnp.array(target_batch)
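The FWHM helpers above rest on three closed-form relations. For a Gaussian, FWHM = 2·sqrt(2·ln 2)·σ ≈ 2.3548·σ. For a Lorentzian, FWHM = 2γ, assuming γ is the half-width at half-maximum, as `FWHM2D = 2 * gamma_x_fit` implies. The fit quality uses r² = 1 - SS_res/SS_tot.

Below is a minimal NumPy sketch, independent of xlumina, that checks these relations on synthetic 1D profiles. `half_max_width` is a hypothetical helper in the spirit of `fwhm_1d`: it locates the two half-maximum crossings around the peak and interpolates linearly between neighbouring samples. It is not the library's implementation. It assumes a single, well-sampled peak whose tails drop below half maximum inside the window.

```python
import numpy as np

# Conversion factors matching the Gaussian and Lorentzian branches of fwhm_2d_fit.
GAUSS_FWHM_PER_SIGMA = 2.0 * np.sqrt(2.0 * np.log(2.0))  # about 2.3548
LORENTZ_FWHM_PER_GAMMA = 2.0  # assumes gamma is the half-width at half-maximum


def half_max_width(x, intensity):
    """Hypothetical helper: FWHM of a single-peaked 1D profile, found by
    linearly interpolating the two half-maximum crossings around the peak."""
    x = np.asarray(x, dtype=float)
    intensity = np.asarray(intensity, dtype=float)
    i_peak = int(np.argmax(intensity))
    half = 0.5 * intensity[i_peak]
    # Last sample left of the peak below half maximum (rising crossing follows it).
    left = np.nonzero(intensity[:i_peak] < half)[0][-1]
    # First sample right of the peak below half maximum (falling crossing precedes it).
    right = i_peak + np.nonzero(intensity[i_peak:] < half)[0][0]
    # np.interp needs increasing sample values, so the falling edge is reversed.
    x_left = np.interp(half, intensity[left:left + 2], x[left:left + 2])
    x_right = np.interp(half, intensity[right - 1:right + 1][::-1], x[right - 1:right + 1][::-1])
    return x_right - x_left


def r_squared(data, model):
    """Coefficient of determination, r^2 = 1 - SS_res / SS_tot."""
    ss_res = np.sum((data - model) ** 2)
    ss_tot = np.sum((data - np.mean(data)) ** 2)
    return 1.0 - ss_res / ss_tot


def spot_size(fwhm_x, fwhm_y, wavelength):
    """Area of the ellipse with semi-axes FWHM/2, in units of wavelength**2."""
    return np.pi * (fwhm_x / 2) * (fwhm_y / 2) / wavelength ** 2


# Usage: a Gaussian with sigma = 0.8 um sampled every 5 nm on a 10 um window.
x = np.linspace(-5.0, 5.0, 2001)
sigma = 0.8
gauss = np.exp(-x ** 2 / (2 * sigma ** 2))
print(half_max_width(x, gauss), GAUSS_FWHM_PER_SIGMA * sigma)
```

The two printed values should agree closely, because linear interpolation between 5 nm samples is accurate for a smooth peak. A perfect model gives r² = 1, and a constant model equal to the data mean gives r² = 0.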
--------------------------------------------------------------------------------
/xlumina/vectorized_optics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax.numpy as jnp
3 | from jax import jit, vmap, config
4 | from functools import partial
5 | import matplotlib.pyplot as plt
6 | import time
7 |
8 | from .toolbox import profile
9 | from .wave_optics import build_grid, RS_propagation_jit, build_CZT_grid, CZT_jit, CZT_for_high_NA_jit
10 |
11 | # Set this to False if float32 (complex64) precision is enough for you.
12 | enable_float64 = True
13 | d_type = jnp.complex64
14 | if enable_float64:
15 | config.update("jax_enable_x64", True)
16 | d_type = jnp.complex128
17 |
18 | """
19 | Module for vectorized optical fields:
20 |
21 | - VectorizedLight:
22 | - draw
23 | - draw_intensity_profile
24 | - VRS_propagation
25 | - get_VRS_minimum_z
26 | - VCZT
27 | - VRS_propagation_jit
28 | - VCZT_jit
29 | - vectorized_CZT_for_high_NA
30 |
31 | - PolarizedLightSource:
32 | - gaussian_beam
33 | - plane_wave
34 | """
35 |
36 | class VectorizedLight:
37 | """ Class for Vectorial EM fields - (Ex, Ey, Ez) """
38 | def __init__(self, x=None, y=None, wavelength=None):
39 | self.x = x
40 | self.y = y
41 | self.X, self.Y = jnp.meshgrid(self.x, self.y)
42 | self.wavelength = wavelength
43 | self.k = 2 * jnp.pi / wavelength
44 | self.n = 1
45 | shape = (jnp.shape(x)[0], jnp.shape(y)[0])
46 | self.Ex = jnp.zeros(shape, dtype=d_type)
47 | self.Ey = jnp.zeros(shape, dtype=d_type)
48 | self.Ez = jnp.zeros(shape, dtype=d_type)
49 | self.info = 'Vectorized light'
50 |
51 | def draw(self, xlim='', ylim='', kind='', extra_title='', save_file=False, filename=''):
52 | """
53 | Plots VectorizedLight.
54 |
55 | Parameters:
56 | xlim (float, float): x-axis limits of the plot.
57 | ylim (float, float): y-axis limits of the plot.
58 | kind (str): Feature to plot: 'Intensity', 'Phase' or 'Field'.
59 | extra_title (str): Adds extra info to the plot title.
60 | save_file (bool): If True, saves the figure.
61 | filename (str): Name of the figure.
62 | """
63 | extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
64 | if kind == 'Intensity':
65 | # Compute intensity
66 | Ix = jnp.abs(self.Ex) ** 2 # Ex
67 | Iy = jnp.abs(self.Ey) ** 2 # Ey
68 | Iz = jnp.abs(self.Ez) ** 2 # Ez
69 | Ir = Ix + Iy # Er
70 |
71 | fig, axes = plt.subplots(2, 3, figsize=(14, 7))
72 | cmap = 'gist_heat'
73 |
74 | ax = axes[0,0]
75 | im = ax.imshow(Ix, cmap=cmap, extent=extent, origin='lower')
76 | ax.set_title(f"Intensity x. {extra_title}")
77 | ax.set_xlabel(r'$x (\mu m)$')
78 | ax.set_ylabel(r'$y (\mu m)$')
79 | fig.colorbar(im, ax=ax)
80 | im.set_clim(vmin=jnp.min(Ix), vmax=jnp.max(Ix))
81 |
82 | ax = axes[0,1]
83 | im = ax.imshow(Iy, cmap=cmap, extent=extent, origin='lower')
84 | ax.set_title(f"Intensity y. {extra_title}")
85 | ax.set_xlabel(r'$x (\mu m)$')
86 | ax.set_ylabel(r'$y (\mu m)$')
87 | fig.colorbar(im, ax=ax)
88 | im.set_clim(vmin=jnp.min(Iy), vmax=jnp.max(Iy))
89 |
90 | ax = axes[1,0]
91 | im = ax.imshow(Iz, cmap=cmap, extent=extent, origin='lower')
92 | ax.set_title(f"Intensity z. {extra_title}")
93 | ax.set_xlabel(r'$x (\mu m)$')
94 | ax.set_ylabel(r'$y (\mu m)$')
95 | fig.colorbar(im, ax=ax)
96 | im.set_clim(vmin=jnp.min(Iz), vmax=jnp.max(Iz))
97 |
98 | ax = axes[0,2]
99 | im = ax.imshow(Ir, cmap=cmap, extent=extent, origin='lower')
100 | ax.set_title(f"Intensity r. {extra_title}")
101 | ax.set_xlabel(r'$x (\mu m)$')
102 | ax.set_ylabel(r'$y (\mu m)$')
103 | fig.colorbar(im, ax=ax)
104 | im.set_clim(vmin=jnp.min(Ir), vmax=jnp.max(Ir))
105 |
106 | axes[1,1].axis('off')
107 | axes[1,2].axis('off')
108 | plt.subplots_adjust(wspace=0.6, hspace=0.6)
109 |
110 |
111 | elif kind == 'Phase':
112 | # Compute phase
113 | phi_x = jnp.angle(self.Ex) # Ex
114 | phi_y = jnp.angle(self.Ey) # Ey
115 | phi_z = jnp.angle(self.Ez) # Ez
116 |
117 | fig, axes = plt.subplots(1, 3, figsize=(14, 3))
118 | cmap = 'twilight'
119 |
120 | ax = axes[0]
121 | im = ax.imshow(phi_x, cmap=cmap, extent=extent, origin='lower')
122 | ax.set_title(f"Phase x (in radians). {extra_title}")
123 | ax.set_xlabel(r'$x (\mu m)$')
124 | ax.set_ylabel(r'$y (\mu m)$')
125 | fig.colorbar(im, ax=ax)
126 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi)
127 |
128 | ax = axes[1]
129 | im = ax.imshow(phi_y, cmap=cmap, extent=extent, origin='lower')
130 | ax.set_title(f"Phase y (in radians). {extra_title}")
131 | ax.set_xlabel(r'$x (\mu m)$')
132 | ax.set_ylabel(r'$y (\mu m)$')
133 | fig.colorbar(im, ax=ax)
134 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi)
135 |
136 | ax = axes[2]
137 | im = ax.imshow(phi_z, cmap=cmap, extent=extent, origin='lower')
138 | ax.set_title(f"Phase z (in radians). {extra_title}")
139 | ax.set_xlabel(r'$x (\mu m)$')
140 | ax.set_ylabel(r'$y (\mu m)$')
141 | fig.colorbar(im, ax=ax)
142 | im.set_clim(vmin=-jnp.pi, vmax=jnp.pi)
143 |
144 | elif kind == 'Field':
145 | # Compute field amplitudes
146 | Ax = jnp.abs(self.Ex) # Ex
147 | Ay = jnp.abs(self.Ey) # Ey
148 | Az = jnp.abs(self.Ez) # Ez
149 |
150 | fig, axes = plt.subplots(1, 3, figsize=(14, 3))
151 | cmap = 'viridis'
152 |
153 | ax = axes[0]
154 | im = ax.imshow(Ax, cmap=cmap, extent=extent, origin='lower')
155 | ax.set_title(f"Amplitude x. {extra_title}")
156 | ax.set_xlabel(r'$x (\mu m)$')
157 | ax.set_ylabel(r'$y (\mu m)$')
158 | fig.colorbar(im, ax=ax)
159 | im.set_clim(vmin=jnp.min(Ax), vmax=jnp.max(Ax))
160 |
161 | ax = axes[1]
162 | im = ax.imshow(Ay, cmap=cmap, extent=extent, origin='lower')
163 | ax.set_title(f"Amplitude y. {extra_title}")
164 | ax.set_xlabel(r'$x (\mu m)$')
165 | ax.set_ylabel(r'$y (\mu m)$')
166 | fig.colorbar(im, ax=ax)
167 | im.set_clim(vmin=jnp.min(Ay), vmax=jnp.max(Ay))
168 |
169 | ax = axes[2]
170 | im = ax.imshow(Az, cmap=cmap, extent=extent, origin='lower')
171 | ax.set_title(f"Amplitude z. {extra_title}")
172 | ax.set_xlabel(r'$x (\mu m)$')
173 | ax.set_ylabel(r'$y (\mu m)$')
174 | fig.colorbar(im, ax=ax)
175 | im.set_clim(vmin=jnp.min(Az), vmax=jnp.max(Az))
176 |
177 | else:
178 | raise ValueError(f"Invalid kind option: {kind}. Please choose 'Intensity', 'Phase' or 'Field'.")
179 |
180 | plt.tight_layout()
181 |
182 | if save_file is True:
183 | plt.savefig(filename)
184 | print(f"Plot saved as {filename}")
185 |
186 | plt.show()
187 |
188 | def draw_intensity_profile(self, p1='', p2=''):
189 | """
190 | Draws the intensity profile of VectorizedLight.
191 |
192 | Parameters:
193 | p1 (float, float): Initial point.
194 | p2 (float, float): Final point.
195 | """
196 | h, z_profile_x = profile(jnp.abs(self.Ex)**2, self.x, self.y, point1=p1, point2=p2)
197 | _, z_profile_y = profile(jnp.abs(self.Ey)**2, self.x, self.y, point1=p1, point2=p2)
198 | _, z_profile_z = profile(jnp.abs(self.Ez)**2, self.x, self.y, point1=p1, point2=p2)
199 | _, z_profile_r = profile(jnp.abs(self.Ex)**2 + jnp.abs(self.Ey)**2, self.x, self.y, point1=p1, point2=p2)
200 | _, z_profile_total = profile(jnp.abs(self.Ex)**2 + jnp.abs(self.Ey)**2 + jnp.abs(self.Ez)**2, self.x, self.y, point1=p1, point2=p2)
201 |
202 | fig, axes = plt.subplots(3, 2, figsize=(14, 14))
203 |
204 | ax = axes[0, 0]
205 | im = ax.plot(h, z_profile_x, 'k', lw=2)
206 | ax.set_title(f"Ix profile")
207 | ax.set_xlabel(r'$\mu m$')
208 | ax.set_ylabel('$Ix$')
209 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_x.min(), z_profile_x.max()))
210 |
211 | ax = axes[0, 1]
212 | im = ax.plot(h, z_profile_y, 'k', lw=2)
213 | ax.set_title(f"Iy profile")
214 | ax.set_xlabel(r'$\mu m$')
215 | ax.set_ylabel('$Iy$')
216 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_y.min(), z_profile_y.max()))
217 |
218 | ax = axes[1, 0]
219 | im = ax.plot(h, z_profile_z, 'k', lw=2)
220 | ax.set_title(f"Iz profile")
221 | ax.set_xlabel(r'$\mu m$')
222 | ax.set_ylabel('$Iz$')
223 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_z.min(), z_profile_z.max()))
224 |
225 | ax = axes[1, 1]
226 | im = ax.plot(h, z_profile_r, 'k', lw=2)
227 | ax.set_title(f"Ir profile")
228 | ax.set_xlabel(r'$\mu m$')
229 | ax.set_ylabel('$Ir$')
230 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_r.min(), z_profile_r.max()))
231 |
232 | ax = axes[2, 0]
233 | im = ax.plot(h, z_profile_total, 'k', lw=2)
234 | ax.set_title(f"Itotal profile")
235 | ax.set_xlabel(r'$\mu m$')
236 | ax.set_ylabel('$Itotal$')
237 | ax.set(xlim=(h.min(), h.max()), ylim=(z_profile_total.min(), z_profile_total.max()))
238 |
239 | axes[2, 1].axis('off')
240 | plt.subplots_adjust(wspace=0.3, hspace=0.4)
241 |
242 | plt.show()
243 |
244 | def VRS_propagation(self, z):
245 | """
246 | Rayleigh-Sommerfeld diffraction integral for VectorizedLight, valid for both z>0 and z<0.
247 | [Ref 1: Laser Phys. Lett., 10(6), 065004 (2013)].
248 | [Ref 2: Optics and laser tech., 39(4), 10.1016/j.optlastec.2006.03.006].
249 | [Ref 3: J. Li, Z. Fan, Y. Fu, Proc. SPIE 4915, (2002)].
250 |
251 | Parameters:
252 | z (float): Distance to propagate.
253 |
254 | Returns VectorizedLight object after propagation and the quality factor of the algorithm.
255 | """
256 | tic = time.perf_counter()
257 | # Define r [From Ref 1, eq. 1a-1c]:
258 | r = jnp.sqrt(self.X ** 2 + self.Y ** 2 + z ** 2)
259 |
260 | # Set the value of Ez:
261 | Ez = jnp.array(self.Ex * self.X / r + self.Ey * self.Y / r)
262 | nx, ny, dx, dy, Xext, Yext = build_grid(self.x, self.y)
263 |
264 | # Quality factor for accurate simulation [Eq. 22 in Ref1]:
265 | dr_real = jnp.sqrt(dx**2 + dy**2)
266 | # Rho
267 | rmax = jnp.sqrt(jnp.max(self.x**2) + jnp.max(self.y**2))
268 | # Delta rho ideal
269 | dr_ideal = jnp.sqrt((self.wavelength)**2 + rmax**2 + 2 * (self.wavelength) * jnp.sqrt(rmax**2 + z**2)) - rmax
270 | quality_factor = dr_ideal / dr_real
271 |
272 | # Stack the input field in a (3, N, N) shape and pass to jit.
273 | E_in = jnp.stack([self.Ex, self.Ey, Ez], axis=0)
274 | E_out = VRS_propagation_jit(E_in, z, nx, ny, dx, dy, Xext, Yext, self.k)
275 | E_out = jnp.moveaxis(E_out, [0, 1, 2], [2, 0, 1])
276 |
277 | # Define the output light:
278 | light_out = VectorizedLight(self.x, self.y, self.wavelength)
279 | light_out.Ex = E_out[:, :, 0]
280 | light_out.Ey = E_out[:, :, 1]
281 | light_out.Ez = E_out[:, :, 2]
282 |
283 | print(f"Time taken to perform one VRS propagation (in seconds): {(time.perf_counter() - tic):.4f}")
284 | return light_out, quality_factor
285 |
286 | def get_VRS_minimum_z(self, n=1, quality_factor=1):
287 | """
288 | Given a quality factor, determines the minimum distance for which VRS_propagation() gives trustworthy results.
289 | [Ref 1: Laser Phys. Lett., 10(6), 065004 (2013)].
290 |
291 | Parameters:
292 | n (float): refractive index of the surrounding medium.
293 | quality_factor (int): Defaults to 1.
294 |
295 | Returns the minimum distance z (in microns) necessary to achieve qualities larger than quality_factor.
296 |
297 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
298 | """
299 | # Check sampling
300 | range_x = self.x[-1] - self.x[0]
301 | range_y = self.y[-1] - self.y[0]
302 | num_x = jnp.size(self.x)
303 | num_y = jnp.size(self.y)
304 |
305 | dx = range_x / num_x
306 | dy = range_y / num_y
307 | # Delta rho
308 | dr_real = jnp.sqrt(dx**2 + dy**2)
309 | # Rho
310 | rmax = jnp.sqrt(range_x**2 + range_y**2)
311 |
312 | factor = (((quality_factor * dr_real + rmax)**2 - (self.wavelength / n)**2 - rmax**2) / (2 * self.wavelength / n))**2 - rmax**2
313 |
314 | if factor > 0:
315 | z_min = jnp.sqrt(factor)
316 | else:
317 | z_min = 0
318 | print("Minimum distance to propagate (in um):", z_min)
319 | return z_min
320 |
321 | def VCZT(self, z, xout=None, yout=None):
322 | """
323 | Vectorial version of the chirp z-transform (CZT) propagation: efficient RS diffraction using the Bluestein method.
324 | Useful for imaging light in the focal plane: allows a high-resolution zoom into the output plane.
325 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020).
326 |
327 | Parameters:
328 | z (float): Propagation distance.
329 | xout (jnp.array): Array with the x-positions for the output plane. Defaults to self.x.
330 | yout (jnp.array): Array with the y-positions for the output plane. Defaults to self.y.
331 | Returns VectorizedLight object after propagation.
332 | """
333 | tic = time.perf_counter()
334 | if xout is None:
335 | xout = self.x
336 |
337 | if yout is None:
338 | yout = self.y
339 |
340 | # Define r:
341 | r = jnp.sqrt(self.X ** 2 + self.Y ** 2 + z ** 2)
342 |
343 | # Set the value of Ez:
344 | Ez = jnp.array((self.Ex * self.X / r + self.Ey * self.Y / r) * z / r)
345 |
346 | # Define main set of parameters
347 | nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1, fx_2 = build_CZT_grid(z, self.wavelength, self.x, self.y, xout, yout)
348 |
349 | # Stack the input field in a (3, N, N) shape and pass to jit.
350 | E_in = jnp.stack([self.Ex, self.Ey, Ez], axis=0)
351 | E_out = VCZT_jit(E_in, z, self.wavelength, self.k, nx, ny, dx, dy, Xout, Yout, self.X, self.Y, Dm, fy_1, fy_2, fx_1, fx_2)
352 | E_out = jnp.moveaxis(E_out, [0, 1, 2], [2, 0, 1])
353 |
354 | # Define the output light:
355 | light_out = VectorizedLight(xout, yout, self.wavelength)
356 | light_out.Ex = E_out[:, :, 0]
357 | light_out.Ey = E_out[:, :, 1]
358 | light_out.Ez = E_out[:, :, 2]
359 |
360 | print(f"Time taken to perform one VCZT propagation (in seconds): {(time.perf_counter() - tic):.4f}")
361 | return light_out
362 |
363 |
364 | @partial(jit, static_argnums=(2, 3, 4, 5, 8))
365 | def VRS_propagation_jit(input_field, z, nx, ny, dx, dy, Xext, Yext, k):
366 | """[From VRS_propagation]: JIT function that vectorizes the propagation and calls RS_propagation_jit from wave_optics.py."""
367 |
368 | # Input field has (3, N, N) shape
369 | vectorized_RS_propagation = vmap(RS_propagation_jit,
370 | in_axes=(0, None, None, None, None, None, None, None, None))
371 | # Call the vectorized function
372 | E_out = vectorized_RS_propagation(input_field, z, nx, ny, dx, dy, Xext, Yext, k)
373 | return E_out # (3, N, N) -> ([Ex, Ey, Ez], N, N)
374 |
375 | def VCZT_jit(field, z, wavelength, k, nx, ny, dx, dy, Xout, Yout, X, Y, Dm, fy_1, fy_2, fx_1, fx_2):
376 | """[From VCZT]: Function that vectorizes the propagation over (Ex, Ey, Ez) with vmap and calls CZT_jit from wave_optics.py."""
377 |
378 | # Input field has (3, N, N) shape
379 | vectorized_CZT = vmap(CZT_jit,
380 | in_axes=(0, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None))
381 |
382 | # Call the vectorized function
383 | E_out = vectorized_CZT(field, z, wavelength, k, nx, ny, dx, dy, Xout, Yout, X, Y, Dm, fy_1, fy_2, fx_1, fx_2)
384 | return E_out # (3, N, N) -> ([Ex, Ey, Ez], N, N)
385 |
386 | def vectorized_CZT_for_high_NA(field, nx, ny, Dm, fy_1, fy_2, fx_1, fx_2):
387 | """[From VCZT_objective_lens - in optical_elements.py]: JIT function that vectorizes the propagation and calls CZT_for_high_NA_jit from wave_optics.py."""
388 |
389 | # Input field has (3, N, N) shape
390 | vectorized = vmap(CZT_for_high_NA_jit, in_axes=(0, None, None, None, None, None, None, None))
391 |
392 | # Call the vectorized function
393 | E_out = vectorized(field, nx, ny, Dm, fy_1, fy_2, fx_1, fx_2)
394 | return E_out # (3, N, N) -> ([Ex, Ey, Ez], N, N)
395 |
396 | class PolarizedLightSource(VectorizedLight):
397 | """ Class for generating polarized light source beams."""
398 | def __init__(self, x, y, wavelength):
399 | super().__init__(x, y, wavelength)
400 | self.info = 'Vectorized light source'
401 |
402 | def gaussian_beam(self, w0, jones_vector, center=(0, 0), z_w0=(0, 0), alpha=0):
403 | """
404 | Defines a gaussian beam.
405 |
406 | Parameters:
407 | w0 (float, float): Waist radius (in microns).
408 | jones_vector (float, float): (Ex, Ey) at the origin (r=0, z=0). Doesn't need to be normalized.
409 | center (float, float): Position of the center of the beam (in microns).
410 | z_w0 (float, float): Position of the waist for (x, y) (in microns).
411 | alpha (float): Rotation angle of the amplitude profile (in radians).
412 |
413 | Sets self.Ex, self.Ey and self.Ez of the PolarizedLightSource object in place.
414 | """
415 | # Waist radius
416 | w0_x, w0_y = w0
417 |
418 | # (x, y) center position
419 | x0, y0 = center
420 |
421 | # z-position of the beam waist
422 | z_w0x, z_w0y = z_w0
423 |
424 | # Rayleigh range
425 | Rayleigh_x = self.k * w0_x ** 2 * self.n / 2
426 | Rayleigh_y = self.k * w0_y ** 2 * self.n / 2
427 |
428 | # Gouy phase
429 | Gouy_phase_x = jnp.arctan2(z_w0x, Rayleigh_x)
430 | Gouy_phase_y = jnp.arctan2(z_w0y, Rayleigh_y)
431 |
432 | # Spot size (radius of the beam at position z)
433 | w_x = w0_x * jnp.sqrt(1 + (z_w0x / Rayleigh_x) ** 2)
434 | w_y = w0_y * jnp.sqrt(1 + (z_w0y / Rayleigh_y) ** 2)
435 |
436 | # Radius of curvature
437 | if z_w0x == 0:
438 | R_x = 1e12
439 | else:
440 | R_x = z_w0x * (1 + (Rayleigh_x / z_w0x) ** 2)
441 | if z_w0y == 0:
442 | R_y = 1e12
443 | else:
444 | R_y = z_w0y * (1 + (Rayleigh_y / z_w0y) ** 2)
445 |
446 | # Gaussian beam coordinates
447 | # Accounting the rotation of the coordinates by alpha:
448 | x_rot = self.X * jnp.cos(alpha) + self.Y * jnp.sin(alpha)
449 | y_rot = -self.X * jnp.sin(alpha) + self.Y * jnp.cos(alpha)
450 |
451 | # Define the phase and amplitude of the field:
452 | phase = jnp.exp(-1j * ((self.k * z_w0x + self.k * self.X ** 2 / (2 * R_x) - Gouy_phase_x) + (
453 | self.k * z_w0y + self.k * self.Y ** 2 / (2 * R_y) - Gouy_phase_y)))
454 |
455 | # Normalize Jones vector:
456 | normalized_jones = np.array(jones_vector) / jnp.linalg.norm(np.array(jones_vector))
457 |
458 | Ex = normalized_jones[0] * (w0_x / w_x) * (w0_y / w_y) * jnp.exp(
459 | -(x_rot - x0) ** 2 / (w_x ** 2) - (y_rot - y0) ** 2 / (w_y ** 2))
460 | Ey = normalized_jones[1] * (w0_x / w_x) * (w0_y / w_y) * jnp.exp(
461 | -(x_rot - x0) ** 2 / (w_x ** 2) - (y_rot - y0) ** 2 / (w_y ** 2))
462 |
463 | self.Ex = Ex*phase
464 | self.Ey = Ey*phase
465 | self.Ez = jnp.zeros(self.X.shape, dtype=d_type)
466 |
467 | def plane_wave(self, jones_vector, theta=0, phi=0, z0=0):
468 | """
469 | Defines a plane wave.
470 |
471 | Parameters:
472 | jones_vector (float, float): (Ex, Ey) at the origin (r=0, z=0). Doesn't need to be normalized.
473 | theta (float): Angle (in radians).
474 | phi (float): Angle (in radians).
475 | z0 (float): Constant value for phase shift.
476 |
477 | Equation:
478 | self.field = A * exp(1.j * k * (self.X * sin(theta) * cos(phi) + self.Y * sin(theta) * sin(phi) + z0 * cos(theta)))
479 |
480 | Sets self.Ex, self.Ey and self.Ez of the PolarizedLightSource object in place.
481 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
482 | """
483 | # Normalize Jones vector:
484 | normalized_jones = np.array(jones_vector) / jnp.linalg.norm(np.array(jones_vector))
485 |
486 | pw = jnp.exp(1j * self.k * (self.X * jnp.sin(theta) * jnp.cos(phi) + self.Y * jnp.sin(theta) * jnp.sin(phi) + z0 * jnp.cos(theta)))
487 |
488 | self.Ex = normalized_jones[0] * pw
489 | self.Ey = normalized_jones[1] * pw
490 | self.Ez = jnp.zeros(self.X.shape, dtype=d_type)
491 |
492 |
493 |
494 |
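Two groups of closed-form expressions in this module can be checked without running a propagation:

- the Gaussian-beam parameters in `PolarizedLightSource.gaussian_beam`;
- the sampling quality factor shared by `VRS_propagation` and `get_VRS_minimum_z`.

Below is a minimal NumPy sketch. It assumes a refractive index n = 1 for the quality factor. It uses `np.inf` for the flat wavefront at the waist, where `gaussian_beam` uses the finite sentinel `1e12`. The helper names `gaussian_beam_parameters`, `rs_quality_factor` and `rs_minimum_z` are hypothetical and not part of xlumina.

The sketch reproduces three textbook checks at z = z_R: w(z_R) = sqrt(2)·w0, R(z_R) = 2·z_R, and a Gouy phase of π/4. It also confirms that the expression used in `get_VRS_minimum_z` is the exact inverse of the quality-factor expression.

```python
import numpy as np


def gaussian_beam_parameters(w0, wavelength, z, n=1.0):
    """Rayleigh range, beam radius, radius of curvature and Gouy phase of a
    fundamental Gaussian beam at distance z from its waist (same expressions
    as PolarizedLightSource.gaussian_beam, per transverse axis)."""
    k = 2 * np.pi / wavelength
    z_r = k * w0 ** 2 * n / 2                # Rayleigh range = pi * w0**2 * n / wavelength
    w = w0 * np.sqrt(1 + (z / z_r) ** 2)     # beam radius at z
    # Flat wavefront at the waist (gaussian_beam uses the finite value 1e12 instead).
    r_curv = np.inf if z == 0 else z * (1 + (z_r / z) ** 2)
    gouy = np.arctan2(z, z_r)                # Gouy phase
    return z_r, w, r_curv, gouy


def rs_quality_factor(z, wavelength, dr_real, rmax):
    """Sampling quality factor dr_ideal / dr_real as in (V)RS_propagation, with n = 1."""
    dr_ideal = np.sqrt(wavelength ** 2 + rmax ** 2
                       + 2 * wavelength * np.sqrt(rmax ** 2 + z ** 2)) - rmax
    return dr_ideal / dr_real


def rs_minimum_z(quality_factor, wavelength, dr_real, rmax):
    """Smallest z whose quality factor reaches quality_factor: the closed-form
    inverse of rs_quality_factor, as in get_VRS_minimum_z with n = 1."""
    factor = (((quality_factor * dr_real + rmax) ** 2 - wavelength ** 2 - rmax ** 2)
              / (2 * wavelength)) ** 2 - rmax ** 2
    return float(np.sqrt(factor)) if factor > 0 else 0.0


# Usage: 1 um pixels on a +/-100 um window at 650 nm.
wavelength = 0.65
dr_real = np.hypot(1.0, 1.0)
rmax = np.hypot(100.0, 100.0)
z_min = rs_minimum_z(1.0, wavelength, dr_real, rmax)
print(z_min, rs_quality_factor(z_min, wavelength, dr_real, rmax))
```

At z = 0 the quality factor reduces to wavelength / dr_real. A grid whose diagonal pixel pitch is already below the wavelength therefore gives z_min = 0.

The two library methods do not feed identical values into these expressions for the same grid:

- `VRS_propagation` takes `rmax` from max(x²) + max(y²) and the pixel pitch from `build_grid`;
- `get_VRS_minimum_z` uses the full ranges `range_x`, `range_y` and `range / num`.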
--------------------------------------------------------------------------------
/xlumina/wave_optics.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import numpy as np
3 | from jax import jit, config
4 | from functools import partial
5 | import matplotlib.pyplot as plt
6 | import time
7 |
8 | from .toolbox import rotate_mask
9 |
11 | # Set this to False if float32 precision is enough for you.
11 | enable_float64 = True
12 | if enable_float64:
13 | config.update("jax_enable_x64", True)
14 |
15 | """
16 | Module for scalar optical fields.
17 |
18 | - ScalarLight:
19 | - draw
20 | - apply_circular_mask
21 | - apply_triangular_mask
22 | - apply_rectangular_mask
23 | - apply_annular_aperture
24 | - RS_propagation
25 | - get_RS_minimum_z
26 | - CZT
27 | - build_grid
28 | - RS_propagation_jit
29 | - transfer_function_RS
30 | - build_CZT_grid
31 | - CZT_jit
32 | - CZT_for_high_NA_jit
33 | - compute_np2
34 | - compute_fft
35 | - Bluestein_method
36 |
37 | - LightSource:
38 | - gaussian_beam
39 | - plane_wave
40 | """
41 |
42 |
43 | class ScalarLight:
44 | """ Class for Scalar fields - complex amplitude U(r) = A(r)*exp(-ikz). """
45 | def __init__(self, x, y, wavelength):
46 | self.x = x
47 | self.y = y
48 | self.X, self.Y = jnp.meshgrid(x, y)
49 | self.wavelength = wavelength
50 | self.k = 2 * jnp.pi / wavelength
51 | self.n = 1
52 | self.field = jnp.zeros((jnp.shape(x)[0], jnp.shape(y)[0]))
53 | self.info = 'Wave optics light'
54 |
55 | def draw(self, xlim='', ylim='', kind='', extra_title='', save_file=False, filename=''):
56 | """
57 | Plots ScalarLight.
58 |
59 | Parameters:
60 | xlim (float, float): x-axis limits.
61 | ylim (float, float): y-axis limits.
62 | kind (str): Feature to plot: 'Intensity' or 'Phase'.
63 | extra_title (str): Adds extra info to the plot title.
64 | save_file (bool): If True, saves the figure.
65 | filename (str): Name of the figure.
66 | """
67 | extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
68 | if kind == 'Intensity':
69 | field_to_plot = jnp.abs(self.field) ** 2 # Compute intensity (magnitude squared)
70 | title = f"Detected intensity. {extra_title}"
71 | cmap = 'gist_heat'
72 | plt.imshow(field_to_plot, cmap=cmap, extent=extent, origin='lower')
73 | plt.colorbar(orientation='vertical')
74 |
75 | elif kind == 'Phase':
76 | field_to_plot = jnp.angle(self.field) # Calculate phase (in radians)
77 | title = f"Phase in radians. {extra_title}"
78 | cmap = 'twilight'
79 | plt.imshow(field_to_plot, cmap=cmap, extent=extent, origin='lower')
80 | plt.colorbar(orientation='vertical')
81 | plt.clim(vmin=-jnp.pi, vmax=jnp.pi)
82 |
83 | else:
84 | raise ValueError(f"Invalid kind option: {kind}. Please choose 'Intensity' or 'Phase'.")
85 |
86 | plt.title(title)
87 | plt.xlabel(r'$x(\mu m)$')
88 | plt.ylabel(r'$y(\mu m)$')
89 |
90 | if save_file is True:
91 | plt.savefig(filename)
92 | print(f"Plot saved as {filename}")
93 |
94 | plt.show()
95 |
96 | def apply_circular_mask(self, r):
97 | """
98 | Apply a circular mask of variable radius.
99 |
100 | Parameters:
101 | r (float, float): Radii (rx, ry) of the circular (or elliptical) mask along x and y (in microns).
102 |
103 | Returns ScalarLight object after applying the pupil mask.
104 | """
105 | rx, ry = r
106 | pupil = jnp.where((self.X**2 / rx**2 + self.Y**2 / ry**2) < 1, 1, 0)
107 | output = ScalarLight(self.x, self.y, self.wavelength)
108 | output.field = self.field * pupil
109 | return output
110 |
111 | def apply_triangular_mask(self, r, angle, m, height):
112 | """
113 | Apply a triangular mask of variable size.
114 |
115 | Equation to generate the triangle: y = -m (x - x0) + y0.
116 |
117 | Parameters:
118 | r (float, float): Coordinates of the top corner of the triangle (in microns).
119 | angle (float): Rotation of the triangle (in radians).
120 | m (float): Slope of the edges.
121 | height (float): Distance between the top corner and the basis (in microns).
122 |
123 | Returns ScalarLight object after applying the triangular mask.
124 |
125 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
126 | """
127 | x0, y0 = r
128 | Xrot, Yrot = rotate_mask(self.X, self.Y, angle, origin=(x0, y0))
129 | Y = -m * jnp.abs(Xrot - x0) + y0
130 | mask = jnp.where((Yrot < Y) & (Yrot > (y0 - height)), 1, 0)
131 | output = ScalarLight(self.x, self.y, self.wavelength)
132 | output.field = self.field * mask
133 | return output
134 |
135 | def apply_rectangular_mask(self, center, width, height, angle):
136 | """
137 | Apply a rectangular mask of variable size. Can generate rectangles and squares, and rotate them to create diamond shapes.
138 |
139 | Parameters:
140 | center (float, float): Coordinates of the center (in microns).
141 | width (float): Width of the rectangle (in microns).
142 | height (float): Height of the rectangle (in microns).
143 | angle (float): Angle of rotation of the rectangle (in degrees).
144 |
145 | Returns ScalarLight object after applying the rectangular mask.
146 | """
147 | x0, y0 = center
148 | angle = angle * (jnp.pi/180)
149 | Xrot, Yrot = rotate_mask(self.X, self.Y, angle, center)
150 | mask = jnp.where((Xrot < (width/2)) & (Xrot > (-width/2)) & (Yrot < (height/2)) & (Yrot > (-height/2)), 1, 0)
151 | output = ScalarLight(self.x, self.y, self.wavelength)
152 | output.field = self.field * mask
153 | return output
154 |
155 | def apply_annular_aperture(self, di, do):
156 | """
157 | Apply an annular aperture of variable size.
158 |
159 | Parameters:
160 | di (float): Diameter of the inner circle (in microns).
161 | do (float): Diameter of the outer circle (in microns).
162 |
163 | Returns ScalarLight object after applying the annular mask.
164 | """
165 | di = di/2
166 | do = do/2
167 | stop = jnp.where(((self.X**2 + self.Y**2) / di**2) < 1, 0, 1)
168 | ring = jnp.where(((self.X**2 + self.Y**2) / do**2) < 1, 1, 0)
169 | output = ScalarLight(self.x, self.y, self.wavelength)
170 | output.field = self.field * stop*ring
171 | return output
172 |
173 | def RS_propagation(self, z):
174 | """
175 | Rayleigh-Sommerfeld diffraction integral for ScalarLight, valid for both z > 0 and z < 0.
176 | [Ref 1: F. Shen and A. Wang, Appl. Opt. 45, 1102-1110 (2006)].
177 | [Ref 2: J. Li, Z. Fan, Y. Fu, Proc. SPIE 4915, (2002)].
178 |
179 | Parameters:
180 | z (float): Distance to propagate (in microns).
181 |
182 | Returns the propagated ScalarLight object and the quality factor of the sampling [Eq. 17 in Ref 2].
183 | """
184 | tic = time.perf_counter()
185 | nx, ny, dx, dy, Xext, Yext = build_grid(self.x, self.y)
186 |
187 | # Quality factor for accurate simulation [Eq. 17 in Ref 2]:
188 | dr_real = jnp.sqrt(dx**2 + dy**2)
189 | rmax = jnp.sqrt((self.x**2).max() + (self.y**2).max())
190 | dr_ideal = jnp.sqrt((self.wavelength / 1)**2 + rmax**2 + 2 * (self.wavelength / 1) * jnp.sqrt(rmax**2 + z**2)) - rmax
191 | quality_factor = dr_ideal / dr_real
192 |
193 | propagated_light = ScalarLight(self.x, self.y, self.wavelength)
194 | propagated_light.field = RS_propagation_jit(self.field, z, nx, ny, dx, dy, Xext, Yext, self.k)
195 | print(f"Time taken to perform one RS propagation (in seconds): {(time.perf_counter() - tic):.4f}")
196 | return propagated_light, quality_factor
197 |
198 | def get_RS_minimum_z(self, n=1, quality_factor=1):
199 | """
200 | Given a quality factor, determines the minimum available (trustworthy) distance for RS_propagation().
201 | [Ref 1: Laser Phys. Lett., 10(6), 065004 (2013)].
202 |
203 | Parameters:
204 | n (float): Refractive index of the medium.
205 | quality_factor (float): Target quality factor. Defaults to 1.
206 |
207 | Returns the minimum distance z (in microns) necessary to achieve qualities larger than quality_factor.
208 |
209 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
210 | """
211 | # Check sampling
212 | range_x = self.x[-1] - self.x[0]
213 | range_y = self.y[-1] - self.y[0]
214 | num_x = jnp.size(self.x)
215 | num_y = jnp.size(self.y)
216 |
217 | dx = range_x / num_x
218 | dy = range_y / num_y
219 | # Delta rho
220 | dr_real = jnp.sqrt(dx**2 + dy**2)
221 | # Rho
222 | rmax = jnp.sqrt(range_x**2 + range_y**2)
223 |
224 | factor = (((quality_factor * dr_real + rmax)**2 - (self.wavelength / n)**2 - rmax**2) / (2 * self.wavelength / n))**2 - rmax**2
225 |
226 | if factor > 0:
227 | z_min = jnp.sqrt(factor)
228 | else:
229 | z_min = 0
230 | print("Minimum distance to propagate (in microns):", z_min)
231 | return z_min
232 |
233 | def CZT(self, z, xout=None, yout=None):
234 | """
235 | Chirped z-transform propagation - efficient diffraction using the Bluestein method.
236 | Useful for imaging light in the focal plane: it allows a high-resolution zoom into a region of the output plane at distance z.
237 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020).
238 |
239 | Parameters:
240 | z (float): Propagation distance (in microns).
241 | xout (jnp.array): Array with the x-positions for the output plane.
242 | yout (jnp.array): Array with the y-positions for the output plane.
243 |
244 | Returns ScalarLight object after propagation.
245 | """
246 | tic = time.perf_counter()
247 | if xout is None:
248 | xout = self.x
249 |
250 | if yout is None:
251 | yout = self.y
252 |
253 | # Define main set of parameters
254 | nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1, fx_2 = build_CZT_grid(z, self.wavelength, self.x, self.y, xout, yout)
255 |
256 | # Compute the diffraction integral using Bluestein method
257 | field_at_z = CZT_jit(self.field, z, self.wavelength, self.k, nx, ny, dx, dy, Xout, Yout, self.X, self.Y, Dm, fy_1, fy_2, fx_1, fx_2)
258 |
259 | # Build ScalarLight object with output field.
260 | field_out = ScalarLight(xout, yout, self.wavelength)
261 | field_out.field = field_at_z
262 | print(f"Time taken to perform one CZT propagation (in seconds): {(time.perf_counter() - tic):.4f}")
263 | return field_out
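# --- Illustrative sketch (hypothetical helpers, not part of the XLuMinA API) ---
# The ScalarLight mask methods above build binary arrays on the (X, Y) meshgrid and
# multiply them into the field. apply_annular_aperture halves di and do, so its
# arguments behave as diameters. A minimal NumPy sketch of the same idea, assuming
# a uniform grid centred on the optical axis and no rotation:
import numpy as np


def _sketch_annular_mask(X, Y, d_inner, d_outer):
    """Binary ring: 1 where d_inner/2 <= r < d_outer/2, 0 elsewhere (diameters in microns)."""
    r2 = X**2 + Y**2
    stop = r2 >= (d_inner / 2) ** 2   # blocks the central disk
    ring = r2 < (d_outer / 2) ** 2    # passes everything inside the outer circle
    return (stop & ring).astype(float)


def _sketch_rectangular_mask(X, Y, width, height):
    """Binary, unrotated rectangle centred at the origin (sizes in microns)."""
    inside_x = np.abs(X) < width / 2
    inside_y = np.abs(Y) < height / 2
    return (inside_x & inside_y).astype(float)

# Usage (assumed grid, microns):
# x = np.linspace(-50, 50, 101); X, Y = np.meshgrid(x, x)
# ring = _sketch_annular_mask(X, Y, d_inner=20, d_outer=60)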
264 |
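# --- Illustrative sketch (hypothetical helpers, not part of the XLuMinA API) ---
# RS_propagation reports Q = dr_ideal / dr_real with
#   dr_ideal = sqrt((lam/n)**2 + rmax**2 + 2*(lam/n)*sqrt(rmax**2 + z**2)) - rmax,
# and get_RS_minimum_z solves Q(z) = Q_target for z in closed form:
#   sqrt(rmax**2 + z**2) = ((Q*dr + rmax)**2 - (lam/n)**2 - rmax**2) / (2*lam/n).
# Q grows with z, so distances beyond z_min give Q >= Q_target. The two methods
# estimate rmax differently, so this sketch takes rmax as an input and checks the
# round trip numerically; argument names are assumptions.
import numpy as np


def _sketch_rs_quality(z, wavelength, rmax, dr_real, n=1.0):
    lam = wavelength / n
    dr_ideal = np.sqrt(lam**2 + rmax**2 + 2 * lam * np.sqrt(rmax**2 + z**2)) - rmax
    return dr_ideal / dr_real


def _sketch_rs_minimum_z(quality, wavelength, rmax, dr_real, n=1.0):
    lam = wavelength / n
    factor = (((quality * dr_real + rmax) ** 2 - lam**2 - rmax**2) / (2 * lam)) ** 2 - rmax**2
    return np.sqrt(factor) if factor > 0 else 0.0

# Usage (assumed numbers, microns): 0.6328 um light, rmax = 500, dr = 1
# z_min = _sketch_rs_minimum_z(1.0, 0.6328, 500.0, 1.0)
# _sketch_rs_quality(z_min, 0.6328, 500.0, 1.0) is then approximately 1.0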
265 | def build_grid(x, y):
266 | """[From RS_propagation]: Returns the grid where the transfer function is defined."""
267 | nx = len(x)
268 | ny = len(y)
269 | dx = x[1] - x[0]
270 | dy = y[1] - y[0]
271 | # Build 2N-1 x 2N-1 (X, Y) space:
272 | x_padded = jnp.pad((x[0] - x[::-1]), (0, jnp.size(x) - 1), 'reflect')
273 | y_padded = jnp.pad((y[0] - y[::-1]), (0, jnp.size(y) - 1), 'reflect')
274 | # Convert the right half into positive values:
275 | I = jnp.ones((1, int(len(x_padded) / 2) + 1))
276 | II = -jnp.ones((1, int(len(x_padded) / 2)))
277 | III = jnp.ravel(jnp.concatenate((I, II), 1))
278 | Xext, Yext = jnp.meshgrid(x_padded * III, y_padded * III)
279 | return nx, ny, dx, dy, Xext, Yext
280 |
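# --- Illustrative sketch (hypothetical helper, not part of the XLuMinA API) ---
# build_grid pads an N-sample axis to 2N-1 samples so that the FFT-based
# convolution in RS_propagation_jit is linear rather than circular. For a uniform
# axis, the padded and sign-flipped coordinates reduce to the offsets
# -(N-1)*dx, ..., 0, ..., (N-1)*dx. This NumPy sketch reproduces that 1D step.
import numpy as np


def _sketch_extended_axis(x):
    x = np.asarray(x, dtype=float)
    n = x.size
    # (x[0] - reversed x) runs from -(N-1)*dx up to 0; 'reflect' mirrors it without repeating 0.
    padded = np.pad(x[0] - x[::-1], (0, n - 1), mode="reflect")
    # Flip the sign of the mirrored right half so those offsets become positive.
    sign = np.concatenate([np.ones(n), -np.ones(n - 1)])
    return padded * sign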
281 | @partial(jit, static_argnums=(2, 3, 4, 5, 8))
282 | def RS_propagation_jit(input_field, z, nx, ny, dx, dy, Xext, Yext, k):
283 | """[From RS_propagation]: JIT function for Equation (10) in [Ref 1]."""
284 | # input_field is a jnp.array of shape (ny, nx)
285 | H = transfer_function_RS(z, Xext, Yext, k)
286 | U = jnp.zeros((2 * ny - 1, 2 * nx - 1), dtype=complex)
287 | U = U.at[0:ny, 0:nx].set(input_field)
288 | output_field = (jnp.fft.ifft2(jnp.fft.fft2(U) * jnp.fft.fft2(H)) * dx * dy)[ny - 1:, nx - 1:]
289 | return output_field
290 |
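# --- Illustrative sketch (hypothetical helpers, not part of the XLuMinA API) ---
# RS_propagation_jit evaluates the RS integral as a discrete convolution,
# U_out[p] = sum_m U_in[m] * h(x_p - x_m) * dx (times dy in 2D), using zero-padded
# FFTs on the 2N-1 grid and keeping the slice [N-1:]. A 1D NumPy sketch, assuming
# the kernel is sampled at offsets -(N-1)..(N-1) as produced by build_grid,
# compared against the direct sum:
import numpy as np


def _sketch_fft_linear_convolution(u, h_ext, dx):
    n = u.size
    if h_ext.size != 2 * n - 1:
        raise ValueError("kernel must be sampled on the 2N-1 extended grid")
    u_pad = np.zeros(2 * n - 1, dtype=complex)
    u_pad[:n] = u  # zero padding prevents circular wrap-around
    full = np.fft.ifft(np.fft.fft(u_pad) * np.fft.fft(h_ext)) * dx
    return full[n - 1:]  # outputs aligned with the input sample positions


def _sketch_direct_convolution(u, h_ext, dx):
    n = u.size
    # h_ext[j] holds the kernel at offset (j - (n - 1)) * dx
    return np.array([sum(u[m] * h_ext[p - m + n - 1] for m in range(n)) * dx for p in range(n)])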
291 | @partial(jit, static_argnums=(3,))
292 | def transfer_function_RS(z, Xext, Yext, k):
293 | """[From RS_propagation]: JIT function for optical transfer function."""
294 | r = jnp.sqrt(Xext ** 2 + Yext ** 2 + z ** 2)
295 | factor = 1 / (2 * jnp.pi) * z / r ** 2 * (1 / r - 1j * k)
296 | result = jnp.where(z > 0, jnp.exp(1j * k * r) * factor, jnp.exp(-1j * k * r) * factor)
297 | return result
298 |
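# --- Illustrative sketch (hypothetical helper, not part of the XLuMinA API) ---
# transfer_function_RS samples the RS impulse response
#   h(x, y; z) = (1 / (2*pi)) * (z / r**2) * (1/r - i*k) * exp(+/- i*k*r),  r = sqrt(x**2 + y**2 + z**2),
# taking exp(+i*k*r) for z > 0 and exp(-i*k*r) otherwise (back-propagation).
# For k*r >> 1 the 1/r term is negligible, so |h| is close to k*|z| / (2*pi*r**2).
# A NumPy version with the same conventions, for quick checks outside JIT:
import numpy as np


def _sketch_rs_impulse_response(x, y, z, k):
    r = np.sqrt(x**2 + y**2 + z**2)
    factor = z / (2 * np.pi * r**2) * (1 / r - 1j * k)
    sign = 1.0 if z > 0 else -1.0
    return np.exp(sign * 1j * k * r) * factor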
299 | def build_CZT_grid(z, wavelength, xin, yin, xout, yout):
300 | """
301 | [From CZT]: Defines the resolution / sampling of initial and output planes.
302 |
303 | Parameters:
304 | z (float): Propagation distance (in microns).
305 | wavelength (float): Wavelength of the light (in microns).
306 | xin, yin (jnp.array): Arrays with the x- and y-positions of the input plane.
307 | xout, yout (jnp.array): Arrays with the x- and y-positions of the output plane.
308 |
309 | Returns the set of parameters: nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1 and fx_2.
310 | """
311 | # Resolution of the output plane:
312 | nx = len(xout)
313 | ny = len(yout)
314 | Xout, Yout = jnp.meshgrid(xout, yout)
315 |
316 | # Sampling of initial plane:
317 | dx = xin[1] - xin[0]
318 | dy = yin[1] - yin[0]
319 |
320 | # For Bluestein method implementation:
321 | # Dimension of the output field - Eq. (11) in [Ref].
322 | Dm = wavelength * z / dx
323 |
324 | # (1) for FFT in Y-dimension:
325 | fy_1 = yout[0] + Dm / 2
326 | fy_2 = yout[-1] + Dm / 2
327 | # (2) for FFT in X-dimension:
328 | fx_1 = xout[0] + Dm / 2
329 | fx_2 = xout[-1] + Dm / 2
330 |
331 | return nx, ny, dx, dy, Xout, Yout, Dm, fy_1, fy_2, fx_1, fx_2
332 |
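# --- Illustrative sketch (hypothetical helper, not part of the XLuMinA API) ---
# build_CZT_grid sets Dm = wavelength * z / dx (Eq. (11) in [Ref]) and shifts the
# requested output limits by Dm / 2 before they reach Bluestein_method. Assuming
# the output is periodic with period Dm (as for a DFT of samples spaced by dx),
# an output window wider than Dm would include periodic replicas. This sketch
# reproduces the arithmetic and flags that case.
import numpy as np


def _sketch_czt_window(z, wavelength, dx, xout):
    Dm = wavelength * z / dx
    f_start = xout[0] + Dm / 2
    f_end = xout[-1] + Dm / 2
    fits_one_period = (xout[-1] - xout[0]) <= Dm
    return Dm, f_start, f_end, fits_one_period

# Usage (assumed numbers, microns): 0.5 um light, z = 1000 and dx = 1 give Dm = 500.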
333 | def CZT_jit(field, z, wavelength, k, nx, ny, dx, dy, Xout, Yout, X, Y, Dm, fy_1, fy_2, fx_1, fx_2):
334 | """
335 | [From CZT]: Diffraction integral implementation using Bluestein method.
336 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020).
337 | """
338 | # Compute the scalar diffraction integral using RS transfer function:
339 | # See Eq.(3) in [Ref].
340 | F0 = transfer_function_RS(z, Xout, Yout, k)
341 | F = transfer_function_RS(z, X, Y, k)
342 |
343 | # Compute (E0 x F) in Eq.(6) in [Ref].
344 | field = field * F
345 |
346 | # Bluestein method implementation:
347 |
348 | # (1) FFT in Y-dimension:
349 | U = Bluestein_method(field, fy_1, fy_2, Dm, ny)
350 |
351 | # (2) FFT in X-dimension using output from (1):
352 | U = Bluestein_method(U, fx_1, fx_2, Dm, nx)
353 |
354 | # Compute Eq.(6) in [Ref].
355 | field_at_z = F0 * U * z * dx * dy * wavelength
356 |
357 | return field_at_z
358 |
359 | def CZT_for_high_NA_jit(field, nx, ny, Dm, fy_1, fy_2, fx_1, fx_2):
360 | """
361 | [From VCZT_objective_lens - in optical_elements.py]: Function for Debye integral implementation using Bluestein method.
362 | [Ref] Hu, Y., et al. Light Sci Appl 9, 119 (2020).
363 | """
364 | # Bluestein method implementation:
365 | # (1) FFT in Y-dimension:
366 | U = Bluestein_method(field, fy_1, fy_2, Dm, ny)
367 |
368 | # (2) FFT in X-dimension using output from (1):
369 | U = Bluestein_method(U, fx_1, fx_2, Dm, nx)
370 |
371 | return U
372 |
373 | def compute_np2(x):
374 | """
375 | [For Bluestein method]: Next power of 2 greater than or equal to x.
376 |
377 | Parameters:
378 | x (float): Value (e.g., the required FFT length).
379 |
380 | Returns the smallest power of two, 2**p, that satisfies 2**p >= x (element-wise for arrays).
381 | """
382 | np2 = 2**(np.ceil(np.log2(x))).astype(int)
383 | return np2
384 |
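# --- Illustrative sketch (hypothetical helper, not part of the XLuMinA API) ---
# compute_np2 returns the padded FFT length 2**ceil(log2(x)) itself, not the
# exponent. For positive integer lengths, an integer-only alternative avoids
# floating-point log2 and its rounding near exact powers of two:
def _sketch_next_pow2(n):
    if n < 1:
        raise ValueError("n must be a positive integer")
    # (n - 1).bit_length() is the number of bits needed for n - 1, i.e. ceil(log2(n)).
    return 1 << (int(n) - 1).bit_length()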
385 | @partial(jit, static_argnums=(4, 5, 6, 7, 8))
386 | def compute_fft(x, D1, D2, Dm, m, n, mp, M_out, np2):
387 | """
388 | [From Bluestein_method]: JIT-computes the FFT part of the algorithm.
389 | """
390 | # A-Complex exponential term
391 | A = jnp.exp(1j * 2 * jnp.pi * D1 / Dm)
392 | # W-Complex exponential term
393 | W = jnp.exp(-1j * 2 * jnp.pi * (D2 - D1) / (M_out * Dm))
394 |
395 | # Window function
396 | h = jnp.arange(-m + 1, max(M_out -1, m -1) +1)
397 | h = W**(h**2/ 2)
398 | h_sliced = h[:mp + 1]
399 |
400 | # Compute the 1D Fourier Transform of 1/h up to length 2**nextpow2(mp)
401 | ft = jnp.fft.fft(1 / h_sliced, np2)
402 | # Compute intermediate result for Bluestein's algorithm
403 | b = A**(-(jnp.arange(m))) * h[jnp.arange(m - 1, 2 * m -1)]
404 | tmp = jnp.tile(b, (n, 1)).T
405 | # Compute the 1D Fourier Transform of input data * intermediate result
406 | b = jnp.fft.fft(x * tmp, np2, axis=0)
407 | # Compute the Inverse Fourier Transform
408 | b = jnp.fft.ifft(b * jnp.tile(ft, (n, 1)).T, axis=0)
409 |
410 | return b, h
411 |
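# --- Illustrative sketch (hypothetical helper, not part of the XLuMinA API) ---
# Bluestein's method rewrites the DFT kernel with n*k = (n**2 + k**2 - (k - n)**2) / 2,
# so W**(n*k) = W**(n**2/2) * W**(k**2/2) * W**(-(k - n)**2/2). The last factor
# depends only on the offset k - n, which turns the transform into a convolution
# that FFTs can evaluate; compute_fft builds this chirp as h = W**(h**2/2) over the
# offsets -m+1 .. max(M_out-1, m-1). A NumPy check of the identity:
import numpy as np


def _sketch_chirp_factorisation(W, n, k):
    direct = W ** (n * k)
    factored = W ** (n**2 / 2) * W ** (k**2 / 2) * W ** (-((k - n) ** 2) / 2)
    return direct, factored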
412 | def Bluestein_method(x, f1, f2, Dm, M_out):
413 | """
414 | [From CZT]: Performs the DFT using Bluestein method.
415 | [Ref1]: Hu, Y., et al. Light Sci Appl 9, 119 (2020).
416 | [Ref2]: L. Bluestein, IEEE Trans. Au. and Electro., 18(4), 451-455 (1970).
417 | [Ref3]: L. Rabiner, et al., IEEE Trans. Au. and Electro., 17(2), 86-92 (1969).
418 |
419 | Parameters:
420 | x (jnp.array): Input sequence, x[n] in Eq.(12) in [Ref 1].
421 | f1 (float): Starting point in frequency range.
422 | f2 (float): End point in frequency range.
423 | Dm (float): Dimension of the imaging plane.
424 | M_out (int): Length of the transform (resolution of the output plane).
425 |
426 | Returns the output X[m] (jnp.array).
427 |
428 | >> Adapted from MATLAB code provided by https://github.com/yanleihu/Bluestein-Method<<
429 | """
430 |
431 | # m: length of the input sequence along the transform axis; n: number of columns transformed at once.
432 | m, n = x.shape
433 |
434 | # Lower frequency limit (shifted start point)
435 | D1 = f1 + (M_out * Dm + f2 - f1) / (2 * M_out)
436 | # Upper frequency limit (shifted end point)
437 | D2 = f2 + (M_out * Dm + f2 - f1) / (2 * M_out)
438 |
439 | # Length of the linear convolution of the input with the chirp (m + M_out - 1)
440 | mp = m + M_out - 1
441 | np2 = compute_np2(mp)
442 | b, h = compute_fft(x, D1, D2, Dm, m, n, mp, M_out, np2)
443 |
444 | # Extract the relevant portion and multiply by the window function, h
445 | if M_out > 1:
446 | b = b[m:mp +1, 0:n].T * jnp.tile(h[m - 1:mp], (n, 1))
447 | else:
448 | b = b[0] * h[0]
449 |
450 | # Create a linearly spaced array from 0 to M_out-1
451 | l = jnp.linspace(0, M_out - 1, M_out)
452 | # Scale the array to the frequency range [D1, D2]
453 | l = l / M_out * (D2 - D1) + D1
454 |
455 | # Eq. S14 in Supplementary Information Section 3 in [Ref1]. Frequency shift to center the spectrum.
456 | M_shift = -m / 2
457 | M_shift = jnp.tile(jnp.exp(-1j * 2 * jnp.pi * l * (M_shift + 1 / 2) / Dm), (n, 1))
458 | # Apply the frequency shift to the final output
459 | b = b * M_shift
460 | return b
461 |
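# --- Illustrative sketch (hypothetical helper, not part of the XLuMinA API) ---
# Bluestein_method evaluates a zoomed DFT (a chirp z-transform) with three FFTs.
# This is a generic 1D version for a single vector, assuming samples taken at
# rate fs and M output frequencies f_k = f1 + k * (f2 - f1) / M, with
# A = exp(2i*pi*f1/fs) and W = exp(-2i*pi*(f2 - f1)/(M*fs)) as in compute_fft.
# It omits the D1/D2 offsets and the M_shift re-centring of Bluestein_method,
# so it is not a drop-in replacement.
import numpy as np


def _sketch_chirp_z_transform(x, f1, f2, fs, M):
    x = np.asarray(x, dtype=complex)
    N = x.size
    A = np.exp(2j * np.pi * f1 / fs)
    W = np.exp(-2j * np.pi * (f2 - f1) / (M * fs))
    # FFT length: next power of two >= N + M - 1, so the circular convolution is linear.
    L = 1 << (N + M - 2).bit_length()
    n = np.arange(N)
    k = np.arange(M)
    # Pre-chirped input x[n] * A**(-n) * W**(n**2 / 2), zero-padded to length L.
    a = np.zeros(L, dtype=complex)
    a[:N] = x * A ** (-n) * W ** (n**2 / 2)
    # Chirp filter W**(-d**2 / 2) for offsets d = -(N-1)..(M-1), stored circularly.
    v = np.zeros(L, dtype=complex)
    v[:M] = W ** (-(k**2) / 2)
    d_neg = np.arange(-(N - 1), 0)
    v[L - (N - 1):] = W ** (-(d_neg**2) / 2)
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(v))[:M]
    # Post-chirp W**(k**2 / 2) completes X[k] = sum_n x[n] * A**(-n) * W**(n*k).
    return W ** (k**2 / 2) * conv

# With f1 = 0, f2 = fs and M = N, the result reduces to the ordinary DFT (np.fft.fft).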
462 | class LightSource(ScalarLight):
463 | """ Class for generating 2D wave optics light source beams. """
464 | def __init__(self, x, y, wavelength):
465 | super().__init__(x, y, wavelength)
466 | self.info = 'Wave optics light source'
467 |
468 | def gaussian_beam(self, w0, E0, center=(0, 0), z_w0=(0, 0), alpha=0):
469 | """
470 | Defines a Gaussian beam.
471 |
472 | Parameters:
473 | w0 (float, float): Waist radii (w0_x, w0_y) (in microns).
474 | E0 (float): Electric field amplitude at the origin (r=0, z=0).
475 | center (float, float): Position of the center of the beam (in microns).
476 | z_w0 (float, float): Position of the waist for (x, y) (in microns).
477 | alpha (float): Rotation angle of the amplitude profile (in radians).
478 |
479 | Updates self.field in place with the Gaussian beam (the method returns None).
480 | """
481 | # Waist radius
482 | w0_x, w0_y = w0
483 |
484 | # (x, y) center position
485 | x0, y0 = center
486 |
487 | # z-position of the beam waist
488 | z_w0x, z_w0y = z_w0
489 |
490 | # Rayleigh range
491 | Rayleigh_x = self.k * w0_x ** 2 * self.n / 2
492 | Rayleigh_y = self.k * w0_y ** 2 * self.n / 2
493 |
494 | # Gouy phase
495 | Gouy_phase_x = jnp.arctan2(z_w0x, Rayleigh_x)
496 | Gouy_phase_y = jnp.arctan2(z_w0y, Rayleigh_y)
497 |
498 | # Spot size (radius of the beam at position z)
499 | w_x = w0_x * jnp.sqrt(1 + (z_w0x / Rayleigh_x) ** 2)
500 | w_y = w0_y * jnp.sqrt(1 + (z_w0y / Rayleigh_y) ** 2)
501 |
502 | # Radius of curvature
503 | if z_w0x == 0:
504 | R_x = 1e12
505 | else:
506 | R_x = z_w0x * (1 + (Rayleigh_x / z_w0x) ** 2)
507 | if z_w0y == 0:
508 | R_y = 1e12
509 | else:
510 | R_y = z_w0y * (1 + (Rayleigh_y / z_w0y) ** 2)
511 |
512 | # Gaussian beam coordinates
513 | X, Y = jnp.meshgrid(self.x, self.y)
514 | # Accounting the rotation of the coordinates by alpha:
515 | x_rot = X * jnp.cos(alpha) + Y * jnp.sin(alpha)
516 | y_rot = -X * jnp.sin(alpha) + Y * jnp.cos(alpha)
517 |
518 | # Define the phase and amplitude of the field:
519 | phase = jnp.exp(-1j * ((self.k * z_w0x + self.k * X ** 2 / (2 * R_x) - Gouy_phase_x) + (
520 | self.k * z_w0y + self.k * Y ** 2 / (2 * R_y) - Gouy_phase_y)))
521 |
522 | self.field = (E0 * (w0_x / w_x) * (w0_y / w_y) * jnp.exp(
523 | -(x_rot - x0) ** 2 / (w_x ** 2) - (y_rot - y0) ** 2 / (w_y ** 2))) * phase
524 |
525 | def plane_wave(self, A=1, theta=0, phi=0, z0=0):
526 | """
527 | Defines a plane wave.
528 |
529 | Parameters:
530 | A (float): Maximum amplitude.
531 | theta (float): Polar angle of the wave vector with respect to the z-axis (in radians).
532 | phi (float): Azimuthal angle of the wave vector in the XY-plane (in radians).
533 | z0 (float): Axial position that adds the constant phase k * z0 * cos(theta).
534 |
535 | Equation:
536 | self.field = A * exp(1.j * k * (self.X * sin(theta) * cos(phi) + self.Y * sin(theta) * sin(phi) + z0 * cos(theta)))
537 |
538 | Updates self.field in place with the plane wave (the method returns None).
539 |
540 | >> Diffractio-adapted function (https://pypi.org/project/diffractio/) <<
541 | """
542 | self.field = A * jnp.exp(1j * self.k * (self.X * jnp.sin(theta) * jnp.cos(phi) + self.Y * jnp.sin(theta) * jnp.sin(phi) + z0 * jnp.cos(theta)))
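# --- Illustrative sketch (hypothetical helper, not part of the XLuMinA API) ---
# gaussian_beam uses, for each transverse axis, the standard Gaussian-beam
# quantities at a distance z from the waist:
#   Rayleigh range z_R = k * n * w0**2 / 2 (= pi * n * w0**2 / wavelength for k = 2*pi/wavelength),
#   spot size w(z) = w0 * sqrt(1 + (z / z_R)**2),
#   radius of curvature R(z) = z * (1 + (z_R / z)**2), infinite (flat wavefront) at z = 0,
#   Gouy phase = arctan(z / z_R).
# A NumPy sketch, assuming k is the vacuum wavenumber 2*pi/wavelength; here
# R = inf stands in for the 1e12 placeholder used in gaussian_beam.
import numpy as np


def _sketch_gaussian_beam_params(w0, z, wavelength, n=1.0):
    k = 2 * np.pi / wavelength
    z_R = k * n * w0**2 / 2
    w = w0 * np.sqrt(1 + (z / z_R) ** 2)
    R = np.inf if z == 0 else z * (1 + (z_R / z) ** 2)
    gouy = np.arctan2(z, z_R)
    return z_R, w, R, gouy

# At z = z_R the spot size grows by sqrt(2), R = 2 * z_R and the Gouy phase is pi/4.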
--------------------------------------------------------------------------------