├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   ├── source
│   │   ├── acknowledgements.rst
│   │   ├── api.rst
│   │   ├── api_domain.rst
│   │   ├── api_helmholtzdomain.rst
│   │   ├── api_iteration.rst
│   │   ├── api_multidomain.rst
│   │   ├── api_utilities.rst
│   │   ├── conclusion.rst
│   │   ├── conf.py
│   │   ├── development.rst
│   │   ├── index.rst
│   │   ├── index_latex.rst
│   │   ├── index_markdown.rst
│   │   ├── readme.rst
│   │   └── references.bib
│   └── tex.bat
├── environment.yml
├── examples
│   ├── README.rst
│   ├── __init__.py
│   ├── check_mem.py
│   ├── helmholtz_1d.py
│   ├── helmholtz_1d_analytical.py
│   ├── helmholtz_2d.py
│   ├── helmholtz_2d_homogeneous.py
│   ├── helmholtz_2d_low_contrast.py
│   ├── helmholtz_3d_disordered.py
│   ├── logo_structure_vector.png
│   ├── matlab_results.mat
│   ├── mem_snapshot.py
│   ├── paper_code
│   │   ├── __init__.py
│   │   ├── fig1_splitting.py
│   │   ├── fig2_decompose.py
│   │   ├── fig3_correction_matrix.py
│   │   ├── fig4_truncate.py
│   │   ├── fig5_dd_validation.py
│   │   ├── fig6_dd_truncation.py
│   │   ├── fig7_dd_convergence.py
│   │   └── fig8_dd_large_simulation.py
│   ├── run_example.py
│   └── timing_test.py
├── guides
│   ├── ipynb2slides_guide.md
│   ├── misc_tips.md
│   └── pytorch_gpu_setup.md
├── pyproject.toml
├── pyproject_cpu.toml
├── requirements.txt
├── tests
│   ├── __init__.py
│   ├── test_analytical.py
│   ├── test_basics.py
│   ├── test_examples.py
│   ├── test_operators.py
│   └── test_utilities.py
└── wavesim
    ├── __init__.py
    ├── domain.py
    ├── helmholtzdomain.py
    ├── iteration.py
    ├── multidomain.py
    └── utilities.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Simulation output
7 | *.pdf
8 | *.npz
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
115 | .pdm.toml
116 | .pdm-python
117 | .pdm-build/
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | .idea/
168 |
169 | # Custom
170 | *.html
171 | *.ipynb
172 | *.json
173 | *.mp4
174 | *.pdf
175 | *.png
176 | *.sh
177 | .vscode
178 | __pycache__
179 | logs
180 | *.nsys-rep
181 | *.sqlite
182 |
183 | # Ignore certain files in the docs directory
184 | auto_examples
185 | sg_execution_times.rst
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.11"
7 |
8 | jobs:
9 | post_create_environment:
10 | - pip install poetry
11 | - pip install poetry-plugin-export
12 | - poetry export -f requirements.txt -o requirements.txt --with docs,dev
13 | - cat requirements.txt
14 |
15 | python:
16 | install:
17 | - requirements: requirements.txt
18 |
19 | sphinx:
20 | configuration: docs/source/conf.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Ivo Vellekoop, Swapnil Mache - University of Twente
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Wavesim
4 |
5 |
6 |
7 | ## What is Wavesim?
8 |
9 | Wavesim is a tool to simulate the propagation of waves in complex, inhomogeneous structures. Whereas most available solvers use the popular finite difference time domain (FDTD) method [[1](#id27), [2](#id21), [3](#id16), [4](#id15)], Wavesim is based on the modified Born series (MBS) approach, which has lower memory requirements, no numerical dispersion, and is faster as compared to FDTD [[5](#id17), [6](#id23)].
10 |
11 | This package [[7](#id25)] is a Python implementation of the MBS approach for solving the Helmholtz equation in arbitrarily large media through domain decomposition [[8](#id13)]. With this new framework, we simulated a complex 3D structure of a remarkable $315\times 315\times 315$ wavelengths $\left( 3.1\cdot 10^7 \right)$ in size in just $1.4$ hours by solving over two GPUs. This represents a factor of $1.93$ increase over the largest possible simulation on a single GPU without domain decomposition.
12 |
13 | When using Wavesim in your work, please cite:
14 |
15 | > [[5](#id17)] [Osnabrugge, G., Leedumrongwatthanakun, S., & Vellekoop, I. M. (2016). A convergent Born series for solving the inhomogeneous Helmholtz equation in arbitrarily large media. *Journal of computational physics, 322*, 113-124.](https://doi.org/10.1016/j.jcp.2016.06.034)
16 |
17 | > [[8](#id13)] [Mache, S., & Vellekoop, I. M. (2024). Domain decomposition of the modified Born series approach for large-scale wave propagation simulations. *arXiv preprint arXiv:2410.02395*.](https://arxiv.org/abs/2410.02395)
18 |
19 | If you use the code in your research, please cite this repository as well [[7](#id25)].
20 |
21 | Examples and documentation for this project are available at [Read the Docs](https://wavesim.readthedocs.io/en/latest/) [[9](#id24)]. For more information (and to participate in the forum for discussions, queries, and requests), please visit our website [www.wavesim.org](https://www.wavesim.org/).
22 |
23 | ## Installation
24 |
25 | Wavesim requires [Python >=3.11.0 and <3.13.0](https://www.python.org/downloads/) and uses [PyTorch](https://pytorch.org/) for GPU acceleration.
26 |
27 | First, clone the repository and navigate to the directory:
28 |
29 | ```default
30 | git clone https://github.com/IvoVellekoop/wavesim_py.git
31 | cd wavesim_py
32 | ```
33 |
34 | Then, you can install the dependencies in a couple of ways:
35 |
36 | [1. Using pip](#pip-installation)
37 |
38 | [2. Using conda](#conda-installation)
39 |
40 | [3. Using Poetry](#poetry-installation)
41 |
42 | We recommend working with a virtual environment to avoid conflicts with other packages.
43 |
44 | <a id="pip-installation"></a>
45 |
46 | ### 1. **Using pip**
47 |
48 | If you prefer to use pip, you can install the required packages using [requirements.txt](https://github.com/IvoVellekoop/wavesim_py/blob/main/requirements.txt):
49 |
50 | 1. **Create a virtual environment and activate it** (optional but recommended)
51 | * First, [create a virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments) using the following command:
52 | ```default
53 | python -m venv path/to/venv
54 | ```
55 | * Then, activate the virtual environment. The command depends on your operating system and shell ([How venvs work](https://docs.python.org/3/library/venv.html#how-venvs-work)):
56 | ```default
57 | source path/to/venv/bin/activate # for Linux/macOS
58 | path/to/venv/Scripts/activate.bat # for Windows (cmd)
59 | path/to/venv/Scripts/Activate.ps1 # for Windows (PowerShell)
60 | ```
61 | 2. **Install packages**:
62 | ```default
63 | pip install -r requirements.txt
64 | ```
65 |
66 | <a id="conda-installation"></a>
67 |
68 | ### 2. **Using conda**
69 |
70 | We recommend using [Miniconda](https://docs.anaconda.com/miniconda/) (a much lighter counterpart of Anaconda) to install Python and the required packages (contained in [environment.yml](https://github.com/IvoVellekoop/wavesim_py/blob/main/environment.yml)) within a conda environment.
71 |
72 | 1. **Download Miniconda**, choosing the appropriate [Python installer](https://docs.anaconda.com/miniconda/) for your operating system (Windows/macOS/Linux).
73 | 2. **Install Miniconda**, following the [installation instructions](https://docs.anaconda.com/miniconda/miniconda-install/) for your OS. Follow the prompts on the installer screens. If you are unsure about any setting, accept the defaults. You can change them later. (If you cannot immediately activate conda, close and re-open your terminal window to make the changes take effect).
74 | 3. **Test your installation**. Open Anaconda Prompt and run the below command. Alternatively, open an editor like [Visual Studio Code](https://code.visualstudio.com/) or [PyCharm](https://www.jetbrains.com/pycharm/), select the Python interpreter in the `miniconda3/` directory with the label `('base')`, and run the command:
75 | ```default
76 | conda list
77 | ```
78 |
79 | A list of installed packages appears if it has been installed correctly.
80 | 4. **Set up a conda environment**. Avoid using the base environment altogether. It is a good backup environment to fall back on if and when the other environments are corrupted/don’t work. Create a new environment using [environment.yml](https://github.com/IvoVellekoop/wavesim_py/blob/main/environment.yml) and activate:
81 | ```default
82 | conda env create -f environment.yml
83 | conda activate wavesim
84 | ```
85 |
86 | The [Miniconda environment management guide](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) has more details if you need them.
87 |
88 | Alternatively, you can create a conda environment with a specific Python version, and then use the [requirements.txt](https://github.com/IvoVellekoop/wavesim_py/blob/main/requirements.txt) file to install the dependencies:
89 | ```default
90 | conda create -n wavesim python'>=3.11.0,<3.13'
91 | conda activate wavesim
92 | pip install -r requirements.txt
93 | ```
94 |
95 | <a id="poetry-installation"></a>
96 |
97 | ### 3. **Using Poetry**
98 |
99 | 1. Install [Poetry](https://python-poetry.org/).
100 | 2. Install dependencies by running the following command:
101 | ```default
102 | poetry install
103 | ```
104 |
105 | To run tests using pytest, you can install the development dependencies as well:
106 | ```default
107 | poetry install --with dev
108 | ```
109 | 3. [Activate](https://python-poetry.org/docs/managing-environments/#activating-the-environment) the virtual environment created by Poetry.
110 |
111 | ## Running the code
112 |
113 | Once the virtual environment is set up with all the required packages, you are ready to run the code. You can go through any of the scripts in the `examples` [directory](https://github.com/IvoVellekoop/wavesim_py/tree/main/examples) for the basic steps needed to run a simulation. The directory contains examples of 1D, 2D, and 3D problems.
114 |
115 | You can run the code with just three inputs:
116 |
117 | * `permittivity`, i.e. refractive index distribution squared (a 3-dimensional array on a regular grid),
118 | * `source`, the same size as permittivity, and
119 | * `periodic`, a tuple of three booleans to indicate whether the domain is periodic in each dimension [`True`] or not [`False`].
120 |
121 | [Listing 1.1](#helmholtz-1d-analytical) shows a simple example of a 1D problem with a homogeneous medium ([helmholtz_1d_analytical.py](https://github.com/IvoVellekoop/wavesim_py/blob/main/examples/helmholtz_1d_analytical.py)) to explain these and other inputs.
122 |
123 | <a id="helmholtz-1d-analytical"></a>
124 | ```python
125 | """
126 | Helmholtz 1D analytical test
127 | ============================
128 | Test to compare the result of Wavesim to analytical results.
129 | Compare 1D free-space propagation with analytic solution.
130 | """
131 |
132 | import torch
133 | import numpy as np
134 | from time import time
135 | import sys
136 | sys.path.append(".")
137 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
138 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
139 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
140 | from wavesim.utilities import analytical_solution, preprocess, relative_error
141 | from __init__ import plot
142 |
143 |
144 | # Parameters
145 | wavelength = 1. # wavelength in micrometer (um)
146 | n_size = (256, 1, 1) # size of simulation domain (in pixels in x, y, and z direction)
147 | n = np.ones(n_size, dtype=np.complex64) # permittivity (refractive index²) map
148 | boundary_widths = 16 # width of the boundary in pixels
149 |
150 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
151 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
152 |
153 | # Source term. This way is more efficient than dense tensor
154 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain
155 | values = torch.tensor([1.0]) # Amplitude: 1
156 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
157 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
158 |
159 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains)
160 | # 1-domain, periodic boundaries (without wrapping correction)
161 | periodic = (True, True, True) # periodic boundaries, wrapped field.
162 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
163 | # # OR. Uncomment to test domain decomposition
164 | # periodic = (False, True, True) # wrapping correction
165 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=(3, 1, 1))
166 |
167 | # Run the wavesim iteration and get the computed field
168 | start = time()
169 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000)
170 | end = time() - start
171 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
172 | u_computed = u_computed.squeeze().cpu().numpy()[boundary_widths:-boundary_widths]
173 | u_ref = analytical_solution(n_size[0], domain.pixel_size, wavelength)
174 |
175 | # Compute relative error with respect to the analytical solution
176 | re = relative_error(u_computed, u_ref)
177 | print(f'Relative error: {re:.2e}')
178 | threshold = 1.e-3
179 | assert re < threshold, f"Relative error higher than {threshold}"
180 |
181 | # Plot the results
182 | plot(u_computed, u_ref, re)
183 | ```
184 |
185 | Apart from the inputs `permittivity`, `source`, and `periodic`, all other parameters have defaults. Details about these are given below (with the default values, if defined).
186 |
187 | Parameters in the `Domain` class: `HelmholtzDomain` (for a single domain without domain decomposition) or `MultiDomain` (to solve a problem with domain decomposition)
188 |
189 | * `permittivity`: 3-dimensional array with refractive index-squared distribution in x, y, and z direction. To set up a 1 or 2-dimensional problem, leave the other dimension(s) as 1.
190 | * `periodic`: indicates for each dimension whether the simulation is periodic (`True`) or not (`False`). For periodic dimensions, i.e., `periodic` `= [True, True, True]`, the field is wrapped around the domain.
191 | * `pixel_size` `:float = 0.25`: points per wavelength.
192 | * `wavelength` `:float = None`: wavelength in micrometer (um). If not given, i.e. `None`, it is calculated as `1/pixel_size = 4 um`.
193 | * `n_domains` `: tuple[int, int, int] = (1, 1, 1)`: number of domains to split the simulation into. If the domain size is not divisible by n_domains, the last domain will be slightly smaller than the other ones. The default `(1, 1, 1)` indicates no domain decomposition.
194 | * `n_boundary` `: int = 8`: number of points used in the wrapping and domain transfer correction. Applicable when `periodic` is False in a dimension, or `n_domains > 1` in a dimension.
195 | * `device` `: str = None`:
196 | > * `'cpu'` to use the cpu,
197 | > * `'cuda:x'` to use a specific cuda device
198 | > * `'cuda'` or a list of strings, e.g., `['cuda:0', 'cuda:1']`, to distribute the simulation over the available/given cuda devices in a round-robin fashion
199 | > * `None`, which is equivalent to `'cuda'` if cuda devices are available, and `'cpu'` if they are not.
200 | * `debug` `: bool = False`: set to `True` for testing to return `inverse_propagator_kernel` as output.
201 |
202 | Parameters in the `run_algorithm()` function
203 |
204 | * `domain`: the domain object created by HelmholtzDomain() or MultiDomain()
205 | * `source`: source term, a 3-dimensional array, with the same size as permittivity. Set up amplitude(s) at the desired location(s), following the same principle as permittivity for 1, 2, or 3-dimensional problems.
206 | * `alpha` `: float = 0.75`: relaxation parameter for the Richardson iteration
207 | * `max_iterations` `: int = 1000`: maximum number of iterations
208 | * `threshold` `: float = 1.e-6`: threshold for the residual norm for stopping the iteration
209 |
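The following sketch is not one of the shipped examples; it is a hypothetical variation of [Listing 1.1](#helmholtz-1d-analytical) that shows how the parameters above fit together for a 2D problem split into two subdomains along x and, when two cuda devices are available, distributed over both GPUs. Sizes, values, and the device list are illustrative only.

```python
import numpy as np
import torch
from wavesim.multidomain import MultiDomain   # domain decomposition
from wavesim.iteration import run_algorithm   # run the wavesim iteration
from wavesim.utilities import preprocess

wavelength = 1.0                              # wavelength in micrometer (um)
n_size = (128, 128, 1)                        # 2D problem: leave the third dimension as 1
n = np.ones(n_size, dtype=np.complex64)       # refractive index map (homogeneous for brevity)
boundary_widths = 8                           # width of the absorbing boundary in pixels
n, boundary_array = preprocess(n**2, boundary_widths)  # permittivity (n²) with boundaries added

# Sparse point source in the centre of the padded domain, as in Listing 1.1
n_ext = tuple(np.array(n_size) + 2 * boundary_array)
indices = torch.tensor([[int(n_ext[i] // 2) for i in range(3)]]).T
source = torch.sparse_coo_tensor(indices, torch.tensor([1.0]), n_ext, dtype=torch.complex64)

# Two domains along x; non-periodic in x, so the wrapping/transfer correction is applied there.
# With two or more cuda devices, the subdomains are distributed over them in a round-robin fashion.
devices = ['cuda:0', 'cuda:1'] if torch.cuda.device_count() > 1 else None
domain = MultiDomain(permittivity=n, periodic=(False, True, True), wavelength=wavelength,
                     n_domains=(2, 1, 1), device=devices)

# The default values of alpha, max_iterations, and threshold are spelled out for illustration
u, iterations, residual_norm = run_algorithm(domain, source, alpha=0.75,
                                             max_iterations=1000, threshold=1.e-6)
```
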
210 | ## Acknowledgements
211 |
212 | This work was supported by the European Research Council’s Proof of Concept Grant n° [101069402].
213 |
214 | ## Conflict of interest statement
215 |
216 | The authors declare no conflict of interest.
217 |
218 | ## References
219 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
--------------------------------------------------------------------------------
/docs/source/acknowledgements.rst:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | ----------------
3 |
4 | This work was supported by the European Research Council's Proof of Concept Grant n° [101069402].
5 |
6 | Conflict of interest statement
7 | ------------------------------
8 | The authors declare no conflict of interest.
9 |
10 | References
11 | ----------
12 | .. bibliography::
13 |
--------------------------------------------------------------------------------
/docs/source/api.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | ==============
3 | .. toctree::
4 | :maxdepth: 2
5 |
6 | api_domain
7 | api_helmholtzdomain
8 | api_multidomain
9 | api_iteration
10 | api_utilities
--------------------------------------------------------------------------------
/docs/source/api_domain.rst:
--------------------------------------------------------------------------------
1 | wavesim.domain
2 | ----------------------
3 | .. automodule:: wavesim.domain
4 | :members:
5 | :imported-members:
6 | :undoc-members:
7 | :noindex:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/api_helmholtzdomain.rst:
--------------------------------------------------------------------------------
1 | wavesim.helmholtzdomain
2 | -----------------------
3 | .. automodule:: wavesim.helmholtzdomain
4 | :members:
5 | :imported-members:
6 | :undoc-members:
7 | :noindex:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/api_iteration.rst:
--------------------------------------------------------------------------------
1 | wavesim.iteration
2 | ----------------------
3 | .. automodule:: wavesim.iteration
4 | :members:
5 | :imported-members:
6 | :undoc-members:
7 | :noindex:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/api_multidomain.rst:
--------------------------------------------------------------------------------
1 | wavesim.multidomain
2 | ----------------------
3 | .. automodule:: wavesim.multidomain
4 | :members:
5 | :imported-members:
6 | :undoc-members:
7 | :noindex:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/api_utilities.rst:
--------------------------------------------------------------------------------
1 | wavesim.utilities
2 | ----------------------
3 | .. automodule:: wavesim.utilities
4 | :members:
5 | :imported-members:
6 | :undoc-members:
7 | :noindex:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/conclusion.rst:
--------------------------------------------------------------------------------
1 | Conclusion
2 | ==========
3 |
4 | In this work :cite:`mache2024domain`, we have introduced a domain decomposition of the modified Born series (MBS) approach :cite:`osnabrugge2016convergent, vettenburg2023universal` applied to the Helmholtz equation. With the new framework, we simulated a complex 3D structure of a remarkable :math:`3.1\cdot 10^7` wavelengths in size in just :math:`1.4` hours by solving over two GPUs. This represents a factor of :math:`1.93` increase over the largest possible simulation on a single GPU without domain decomposition.
5 |
6 | Our decomposition framework hinges on the ability to split the linear system as :math:`A=L+V`. Instead of the traditional splitting, where :math:`V` is a scattering potential that acts locally on each voxel, we introduced a :math:`V` that includes the communication between subdomains and corrections for wraparound artefacts. As a result, the operator :math:`(L+I)^{-1}` in the MBS iteration can be evaluated locally on each subdomain using a fast convolution. Therefore, this operator, which is the most computationally intensive step of the iteration, can be evaluated in parallel on multiple GPUs.
7 |
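Schematically, with :math:`B = I - V`, source term :math:`s`, and relaxation parameter :math:`\alpha`, one step of this preconditioned Richardson iteration reads

.. math::

   u_{k+1} = u_k + \alpha B \left[ (L+I)^{-1} \left( B u_k + s \right) - u_k \right],

where the only non-local operation, :math:`(L+I)^{-1}`, is the fast convolution that each subdomain evaluates independently on its own part of the field.
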
8 | Despite the significant overhead of our domain splitting method due to an increased number of iterations, and communication and synchronisation overhead, the ability to split a simulation over multiple GPUs results in a very significant speedup. Already, with the current dual-GPU system, we were able to solve a problem of :math:`315\times 315\times 315` wavelengths :math:`13.2\times` faster than without domain decomposition since the non-decomposed problem is too large to fit on a single GPU. Moreover, there is only a little overhead associated with adding more subdomains along an axis after the first splitting. This favourable scaling paves the way for distributing simulations over more GPUs or compute nodes in a cluster.
9 |
10 | In this work, we have already introduced strategies to reduce the overhead of the domain decomposition through truncating corrections to only a few points close to the edge of the subdomain and only activating certain subdomains in the iteration. We anticipate that further developments and optimisation of the code may help reduce the overhead of the lock-step execution.
11 |
12 | Finally, due to the generality of our approach, we expect it to be readily extended to include Maxwell's equations :cite:`kruger2017solution` and birefringent media :cite:`vettenburg2019calculating`. Given the rapid developments of GPU hardware and compute clusters, we anticipate that optical simulations at a cubic-millimetre scale can soon be performed in a matter of minutes.
13 |
14 | Code availability
15 | -----------------
16 | The code for Wavesim is available on GitHub :cite:`wavesim_py` and is licensed under the MIT license. When using Wavesim in your work, please cite :cite:`mache2024domain, osnabrugge2016convergent`. Examples and documentation for this project are available at `Read the Docs <https://wavesim.readthedocs.io/en/latest/>`_ :cite:`wavesim_documentation`.
17 |
18 | %endmatter%
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Path setup --------------------------------------------------------------
7 | import os
8 | import sys
9 | import shutil
10 | from pathlib import Path
11 |
12 | from sphinx_markdown_builder import MarkdownBuilder
13 |
14 | # path setup (relevant for both local and read-the-docs builds)
15 | docs_source_dir = os.path.dirname(__file__)
16 | root_dir = os.path.dirname(os.path.dirname(docs_source_dir))
17 | sys.path.append(docs_source_dir)
18 | sys.path.append(root_dir)
19 |
20 | # -- Project information -----------------------------------------------------
21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
22 | project = 'wavesim'
23 | copyright = '2024, Ivo Vellekoop, and Swapnil Mache, University of Twente'
24 | # author = 'Swapnil Mache, Ivo M. Vellekoop'
25 | release = '0.1.0-alpha.2'
26 | html_title = "Wavesim - A Python package for wave propagation simulation"
27 |
28 | # -- latex configuration -----------------------------------------------------
29 | latex_elements = {
30 | 'preamble': r"""
31 | \usepackage{authblk}
32 | """,
33 | 'maketitle': r"""
34 | \author[1,2]{Swapnil~Mache}
35 | \author[1*]{Ivo~M.~Vellekoop}
36 | \affil[1]{University of Twente, Biomedical Photonic Imaging, TechMed Institute, P. O. Box 217, 7500 AE Enschede, The Netherlands}
37 | \affil[2]{Currently at: Rayfos Ltd., Winton House, Winton Square, Basingstoke, United Kingdom}
38 | \affil[*]{Corresponding author: i.m.vellekoop@utwente.nl}
39 | \publishers{%
40 | \normalfont\normalsize%
41 | \parbox{0.8\linewidth}{%
42 | \vspace{0.5cm}
43 | The modified Born series (MBS) method is a fast and accurate method for simulating wave
44 | propagation in complex structures. The major limitation of MBS is that the size of the structure
45 | is limited by the working memory of a single computer or graphics processing unit (GPU).
46 | Through this package, we present a domain decomposition method that removes this limitation. We
47 | decompose large problems over subdomains while maintaining the accuracy, memory efficiency, and guaranteed monotonic convergence of the method. With this work, we have been able to obtain a
48 | factor of $1.93$ increase in size over the single-domain MBS simulations without domain decomposition through a 3D simulation using 2 GPUs. For the Helmholtz problem, we solved a complex structure of size $315 \times 315 \times 315$ wavelengths in just 1.4 hours on a dual-GPU system.
49 | }
50 | }
51 | \maketitle
52 | """,
53 | 'tableofcontents': "",
54 | 'makeindex': "",
55 | 'printindex': "",
56 | 'figure_align': "",
57 | 'extraclassoptions': 'notitlepage',
58 | }
59 | latex_docclass = {
60 | 'manual': 'scrartcl',
61 | 'howto': 'scrartcl',
62 | }
63 | latex_documents = [("index_latex",
64 | "wavesim.tex",
65 | "Wavesim - A Python package for wave propagation simulation",
66 | "",
67 | "howto")]
68 | latex_toplevel_sectioning = 'section'
69 | bibtex_default_style = 'unsrt'
70 | bibtex_bibfiles = ['references.bib']
71 | numfig = True
72 |
73 | # -- General configuration ---------------------------------------------------
74 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
75 | extensions = ['sphinx.ext.napoleon',
76 | 'sphinx.ext.autodoc',
77 | 'sphinx.ext.mathjax',
78 | 'sphinx.ext.viewcode',
79 | 'sphinx_autodoc_typehints',
80 | 'sphinxcontrib.bibtex',
81 | 'sphinx.ext.autosectionlabel',
82 | 'sphinx_markdown_builder',
83 | 'sphinx_gallery.gen_gallery']
84 |
85 | templates_path = ['_templates']
86 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'acknowledgements.rst', 'sg_execution_times.rst']
87 | master_doc = ''
88 | include_patterns = ['**']
89 | napoleon_use_rtype = False
90 | napoleon_use_param = True
91 | typehints_document_rtype = False
92 | latex_engine = 'xelatex'
93 | add_module_names = False
94 | autodoc_preserve_defaults = True
95 |
96 | # -- Options for HTML output -------------------------------------------------
97 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
98 | html_theme = 'sphinx_rtd_theme'
99 |
100 | # -- Options for sphinx-gallery ----------------------------------------------
101 | # https://sphinx-gallery.github.io/stable/configuration.html
102 | sphinx_gallery_conf = {
103 | 'examples_dirs': '../../examples', # path to your example scripts
104 | 'ignore_pattern': '__init__.py|check_mem.py|timing_test.py',
105 | 'gallery_dirs': 'auto_examples', # path to where to save gallery generated output
106 | }
107 |
108 |
109 | # -- Monkey-patch the MarkdownTranslator class to support citations ------------
110 | def visit_citation(self, node):
111 | """Patch-in function for markdown builder to support citations."""
112 | id = node['ids'][0]
113 | self.add(f'<a name="{id}"></a>')
114 |
115 |
116 | def visit_label(self, node):
117 | """Patch-in function for markdown builder to support citations."""
118 | pass
119 |
120 |
121 | def setup(app):
122 | """Setup function for the Sphinx extension."""
123 | # register event handlers
124 | # app.connect("autodoc-skip-member", skip)
125 | app.connect("build-finished", build_finished)
126 | app.connect("builder-inited", builder_inited)
127 | app.connect("source-read", source_read)
128 | # # app.connect("autodoc-skip-member", skip)
129 | # app.connect("build-finished", copy_readme)
130 | # app.connect("builder-inited", builder_inited)
131 | # app.connect("source-read", source_read)
132 |
133 | # monkey-patch the MarkdownTranslator class to support citations
134 | # TODO: this should be done in the markdown builder itself
135 | cls = MarkdownBuilder.default_translator_class
136 | # cls.visit_citation = visit_citation
137 | # cls.visit_label = visit_label
138 |
139 |
140 | def source_read(app, docname, source):
141 | """Modify the source of the readme and conclusion files based on the builder."""
142 | if docname == 'readme' or docname == 'conclusion':
143 | if (app.builder.name == 'latex') == (docname == 'conclusion'):
144 | source[0] = source[0].replace('%endmatter%', '.. include:: acknowledgements.rst')
145 | else:
146 | source[0] = source[0].replace('%endmatter%', '')
147 |
148 |
149 | def builder_inited(app):
150 | """Set the master document and exclude patterns based on the builder."""
151 | if app.builder.name == 'html':
152 | exclude_patterns.extend(['conclusion.rst', 'index_latex.rst', 'index_markdown.rst'])
153 | app.config.master_doc = 'index'
154 | elif app.builder.name == 'latex':
155 | exclude_patterns.extend(['auto_examples/*', 'index_markdown.rst', 'index.rst', 'api*'])
156 | app.config.master_doc = 'index_latex'
157 | elif app.builder.name == 'markdown':
158 | include_patterns.clear()
159 | include_patterns.extend(['readme.rst', 'index_markdown.rst'])
160 | app.config.master_doc = 'index_markdown'
161 |
162 |
163 | def build_finished(app, exception):
164 | if exception:
165 | return
166 |
167 | if app.builder.name == "markdown":
168 | # Copy the readme file to the root of the documentation directory.
169 | source_file = Path(app.outdir) / "readme.md"
170 | destination_dir = Path(app.confdir).parents[1] / "README.md"
171 | shutil.copy(source_file, destination_dir)
172 |
173 | elif app.builder.name == "latex":
174 | # The latex builder adds an empty author field to the title page.
175 | # This code removes it.
176 | # Define the path to the .tex file
177 | tex_file = Path(app.outdir) / "wavesim.tex"
178 |
179 | # Read the file
180 | with open(tex_file, "r") as file:
181 | content = file.read()
182 |
183 | # Remove \author{} from the file
184 | content = content.replace(r"\author{}", "")
185 |
186 | # Write the modified content back to the file
187 | with open(tex_file, "w") as file:
188 | file.write(content)
189 |
--------------------------------------------------------------------------------
/docs/source/development.rst:
--------------------------------------------------------------------------------
1 | .. _section-development:
2 |
3 | Wavesim Development
4 | ==============================================
5 |
6 | Running the tests and examples
7 | --------------------------------------------------
8 | To download the source code, including tests and examples, clone the repository from GitHub :cite:`wavesim_py`. Wavesim uses `poetry` :cite:`Poetry` for package management, so you have to download and install Poetry first. Then, navigate to the location where you want to store the source code, and execute the following commands to clone the repository, set up the poetry environment, and run the tests.
9 |
10 | .. code-block:: shell
11 |
12 | git clone https://github.com/IvoVellekoop/wavesim_py
13 | cd wavesim_py
14 | poetry install --with dev --with docs
15 | poetry run pytest
16 |
17 | The examples are located in the ``examples`` directory. Note that a lot of functionality is also demonstrated in the automatic tests located in the ``tests`` directory. As an alternative to downloading the source code, the samples can also be copied directly from the example gallery on the documentation website :cite:`readthedocs_Wavesim`.
18 |
19 | Building the documentation
20 | --------------------------------------------------
21 |
22 | .. only:: html or markdown
23 |
24 | The html and pdf versions of the documentation, as well as the `README.md` file in the root directory of the repository, are automatically generated from the docstrings in the source code and reStructuredText source files in the repository.
25 |
26 | .. only:: latex
27 |
28 | The html version of the documentation, as well as the `README.md` file in the root directory of the repository, and the pdf document you are currently reading are automatically generated from the docstrings in the source code and reStructuredText source files in the repository.
29 |
30 | Note that for building the pdf version of the documentation, you need to have `xelatex` installed, which comes with the MiKTeX distribution of LaTeX :cite:`MiKTeX`. Then, run the following commands to build the html and pdf versions of the documentation, and to auto-generate `README.md`.
31 |
32 | .. code-block:: shell
33 |
34 | poetry shell
35 | cd docs
36 | make clean
37 | make html
38 | make markdown
39 | make latex
40 | cd _build/latex
41 | xelatex wavesim
42 | xelatex wavesim
43 |
44 |
45 | Reporting bugs and contributing
46 | --------------------------------------------------
47 | Bugs can be reported through the GitHub issue tracking system. Better than reporting bugs, we encourage users to *contribute bug fixes, optimizations, and other improvements*. These contributions can be made in the form of a pull request :cite:`zandonellaMassiddaOpenScience2022`, which will be reviewed by the development team and integrated into the package when appropriate. Please contact the current development team through GitHub :cite:`wavesim_py` or the `www.wavesim.org <https://www.wavesim.org/>`_ forum to coordinate such contributions.
48 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Wavesim - A Python package for wave propagation simulation
2 | ==========================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :numbered:
7 |
8 | Introduction <readme>
9 | development
10 | api
11 | auto_examples/index
--------------------------------------------------------------------------------
/docs/source/index_latex.rst:
--------------------------------------------------------------------------------
1 | Wavesim - A Python package for wave propagation simulation
2 | ==========================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :numbered:
7 |
8 | Introduction <readme>
9 | development
10 | conclusion
11 | auto_examples/index
--------------------------------------------------------------------------------
/docs/source/index_markdown.rst:
--------------------------------------------------------------------------------
1 | Wavesim - A Python package for wave propagation simulation
2 | ==========================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :numbered:
7 |
8 | Introduction <readme>
--------------------------------------------------------------------------------
/docs/source/readme.rst:
--------------------------------------------------------------------------------
1 | .. _root_label:
2 |
3 | Wavesim
4 | =====================================
5 |
6 | ..
7 | NOTE: README.MD IS AUTO-GENERATED FROM DOCS/SOURCE/README.RST. DO NOT EDIT README.MD DIRECTLY.
8 |
9 | .. only:: html
10 |
11 | .. image:: https://readthedocs.org/projects/wavesim/badge/?version=latest
12 | :target: https://wavesim.readthedocs.io/en/latest/?badge=latest
13 | :alt: Documentation Status
14 |
15 |
16 | What is Wavesim?
17 | ----------------------------------------------
18 |
19 | Wavesim is a tool to simulate the propagation of waves in complex, inhomogeneous structures. Whereas most available solvers use the popular finite difference time domain (FDTD) method :cite:`yee1966numerical, taflove1995computational, oskooi2010meep, nabavi2007new`, Wavesim is based on the modified Born series (MBS) approach, which has lower memory requirements, no numerical dispersion, and is faster as compared to FDTD :cite:`osnabrugge2016convergent, vettenburg2023universal`.
20 |
21 | This package :cite:`wavesim_py` is a Python implementation of the MBS approach for solving the Helmholtz equation in arbitrarily large media through domain decomposition :cite:`mache2024domain`. With this new framework, we simulated a complex 3D structure of a remarkable :math:`315\times 315\times 315` wavelengths :math:`\left( 3.1\cdot 10^7 \right)` in size in just :math:`1.4` hours by solving over two GPUs. This represents a factor of :math:`1.93` increase over the largest possible simulation on a single GPU without domain decomposition.
22 |
23 | When using Wavesim in your work, please cite:
24 |
25 | :cite:`osnabrugge2016convergent` |osnabrugge2016|_
26 |
27 | :cite:`mache2024domain` |mache2024|_
28 |
29 | .. _osnabrugge2016: https://doi.org/10.1016/j.jcp.2016.06.034
30 | .. |osnabrugge2016| replace:: Osnabrugge, G., Leedumrongwatthanakun, S., & Vellekoop, I. M. (2016). A convergent Born series for solving the inhomogeneous Helmholtz equation in arbitrarily large media. *Journal of computational physics, 322*\ , 113-124.
31 |
32 | .. _mache2024: https://arxiv.org/abs/2410.02395
33 | .. |mache2024| replace:: Mache, S., & Vellekoop, I. M. (2024). Domain decomposition of the modified Born series approach for large-scale wave propagation simulations. *arXiv preprint arXiv:2410.02395*.
34 |
35 | If you use the code in your research, please cite this repository as well :cite:`wavesim_py`.
36 |
37 | Examples and documentation for this project are available at `Read the Docs <https://wavesim.readthedocs.io/en/latest/>`_ :cite:`wavesim_documentation`. For more information (and to participate in the forum for discussions, queries, and requests), please visit our website `www.wavesim.org <https://www.wavesim.org/>`_.
38 |
39 | Installation
40 | ----------------------------------------------
41 |
42 | Wavesim requires `Python >=3.11.0 and <3.13.0 <https://www.python.org/downloads/>`_ and uses `PyTorch <https://pytorch.org/>`_ for GPU acceleration.
43 |
44 | First, clone the repository and navigate to the directory::
45 |
46 | git clone https://github.com/IvoVellekoop/wavesim_py.git
47 | cd wavesim_py
48 |
49 | Then, you can install the dependencies in a couple of ways:
50 |
51 | :ref:`pip_installation`
52 |
53 | :ref:`conda_installation`
54 |
55 | :ref:`poetry_installation`
56 |
57 | We recommend working with a virtual environment to avoid conflicts with other packages.
58 |
59 | .. _pip_installation:
60 |
61 | 1. **Using pip**
62 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
63 |
64 | If you prefer to use pip, you can install the required packages using `requirements.txt <https://github.com/IvoVellekoop/wavesim_py/blob/main/requirements.txt>`_:
65 |
66 | 1. **Create a virtual environment and activate it** (optional but recommended)
67 |
68 | * First, `create a virtual environment <https://docs.python.org/3/library/venv.html#creating-virtual-environments>`_ using the following command::
69 |
70 | python -m venv path/to/venv
71 |
72 | * Then, activate the virtual environment. The command depends on your operating system and shell (`How venvs work <https://docs.python.org/3/library/venv.html#how-venvs-work>`_)::
73 |
74 | source path/to/venv/bin/activate # for Linux/macOS
75 | path/to/venv/Scripts/activate.bat # for Windows (cmd)
76 | path/to/venv/Scripts/Activate.ps1 # for Windows (PowerShell)
77 |
78 | 2. **Install packages**::
79 |
80 | pip install -r requirements.txt
81 |
82 | .. _conda_installation:
83 |
84 | 2. **Using conda**
85 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
86 |
87 | We recommend using `Miniconda <https://docs.anaconda.com/miniconda/>`_ (a much lighter counterpart of Anaconda) to install Python and the required packages (contained in `environment.yml <https://github.com/IvoVellekoop/wavesim_py/blob/main/environment.yml>`_) within a conda environment.
88 |
89 | 1. **Download Miniconda**, choosing the appropriate `Python installer <https://docs.anaconda.com/miniconda/>`_ for your operating system (Windows/macOS/Linux).
90 |
91 | 2. **Install Miniconda**, following the `installation instructions <https://docs.anaconda.com/miniconda/miniconda-install/>`_ for your OS. Follow the prompts on the installer screens. If you are unsure about any setting, accept the defaults. You can change them later. (If you cannot immediately activate conda, close and re-open your terminal window to make the changes take effect).
92 |
93 | 3. **Test your installation**. Open Anaconda Prompt and run the below command. Alternatively, open an editor like `Visual Studio Code <https://code.visualstudio.com/>`_ or `PyCharm <https://www.jetbrains.com/pycharm/>`_, select the Python interpreter in the ``miniconda3/`` directory with the label ``('base')``, and run the command::
94 |
95 | conda list
96 |
97 | A list of installed packages appears if it has been installed correctly.
98 |
99 | 4. **Set up a conda environment**. Avoid using the base environment altogether. It is a good backup environment to fall back on if and when the other environments are corrupted/don't work. Create a new environment using `environment.yml <https://github.com/IvoVellekoop/wavesim_py/blob/main/environment.yml>`_ and activate::
100 |
101 | conda env create -f environment.yml
102 | conda activate wavesim
103 |
104 | The `Miniconda environment management guide <https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html>`_ has more details if you need them.
105 |
106 | Alternatively, you can create a conda environment with a specific Python version, and then use the `requirements.txt <https://github.com/IvoVellekoop/wavesim_py/blob/main/requirements.txt>`_ file to install the dependencies::
107 |
108 | conda create -n wavesim python'>=3.11.0,<3.13'
109 | conda activate wavesim
110 | pip install -r requirements.txt
111 |
112 | .. _poetry_installation:
113 |
114 | 3. **Using Poetry**
115 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
116 |
117 | 1. Install `Poetry <https://python-poetry.org/>`_.
118 | 2. Install dependencies by running the following command::
119 |
120 | poetry install
121 |
122 | To run tests using pytest, you can install the development dependencies as well::
123 |
124 | poetry install --with dev
125 |
126 | 3. `Activate <https://python-poetry.org/docs/managing-environments/#activating-the-environment>`_ the virtual environment created by Poetry.
127 |
128 | Running the code
129 | ----------------
130 |
131 | Once the virtual environment is set up with all the required packages, you are ready to run the code. You can go through any of the scripts in the ``examples`` `directory <https://github.com/IvoVellekoop/wavesim_py/tree/main/examples>`_ for the basic steps needed to run a simulation. The directory contains examples of 1D, 2D, and 3D problems.
132 |
133 | You can run the code with just three inputs:
134 |
135 | * :attr:`~.Domain.permittivity`, i.e. refractive index distribution squared (a 3-dimensional array on a regular grid),
136 |
137 | * :attr:`~.Domain.periodic`, a tuple of three booleans to indicate whether the domain is periodic in each dimension [``True``] or not [``False``], and
138 |
139 | * :attr:`~.Domain.source`, the same size as permittivity.
140 |
141 | :numref:`helmholtz_1d_analytical` shows a simple example of a 1D problem with a homogeneous medium (`helmholtz_1d_analytical.py <https://github.com/IvoVellekoop/wavesim_py/blob/main/examples/helmholtz_1d_analytical.py>`_) to explain these and other inputs.
142 |
143 | .. _helmholtz_1d_analytical:
144 | .. literalinclude:: ../../examples/helmholtz_1d_analytical.py
145 | :language: python
146 | :caption: ``helmholtz_1d_analytical.py``. A simple example of a 1D problem with a homogeneous medium.
147 |
148 | Apart from the inputs :attr:`~.Domain.permittivity`, :attr:`~.Domain.source`, and :attr:`~.Domain.periodic`, all other parameters have defaults. Details about all parameters are given below (with the default values, if defined).
149 |
150 | Parameters in the :class:`~.domain.Domain` class: :class:`~.helmholtzdomain.HelmholtzDomain` (for a single domain without domain decomposition) or :class:`~.multidomain.MultiDomain` (to solve a problem with domain decomposition)
151 |
152 | * :attr:`~.Domain.permittivity`: 3-dimensional array with refractive index-squared distribution in x, y, and z direction. To set up a 1 or 2-dimensional problem, leave the other dimension(s) as 1.
153 |
154 | * :attr:`~.Domain.periodic`: indicates for each dimension whether the simulation is periodic (``True``) or not (``False``). For periodic dimensions, i.e., :attr:`~.Domain.periodic` ``= [True, True, True]``, the field is wrapped around the domain.
155 |
156 | * :attr:`~.pixel_size` ``:float = 0.25``: points per wavelength.
157 |
158 | * :attr:`~.wavelength` ``:float = None``: wavelength in micrometer (um). If not given, i.e. ``None``, it is calculated as ``1/pixel_size = 4 um``.
159 |
160 | * :attr:`~.n_domains` ``: tuple[int, int, int] = (1, 1, 1)``: number of domains to split the simulation into. If the domain size is not divisible by n_domains, the last domain will be slightly smaller than the other ones. The default ``(1, 1, 1)`` indicates no domain decomposition.
161 |
162 | * :attr:`~.n_boundary` ``: int = 8``: number of points used in the wrapping and domain transfer correction. Applicable when :attr:`~.Domain.periodic` is False in a dimension, or ``n_domains > 1`` in a dimension.
163 |
164 | * :attr:`~.device` ``: str = None``:
165 |
166 | * ``'cpu'`` to use the cpu,
167 |
168 | * ``'cuda:x'`` to use a specific cuda device
169 |
170 | * ``'cuda'`` or a list of strings, e.g., ``['cuda:0', 'cuda:1']``, to distribute the simulation over the available/given cuda devices in a round-robin fashion
171 |
172 | * ``None``, which is equivalent to ``'cuda'`` if cuda devices are available, and ``'cpu'`` if they are not.
173 |
174 | * :attr:`~.debug` ``: bool = False``: set to ``True`` for testing to return :attr:`~.inverse_propagator_kernel` as output.
175 |
176 | Parameters in the :func:`run_algorithm` function
177 |
178 | * :attr:`~.domain`: the domain object created by HelmholtzDomain() or MultiDomain()
179 |
180 | * :attr:`~.Domain.source`: source term, a 3-dimensional array, with the same size as permittivity. Set up amplitude(s) at the desired location(s), following the same principle as permittivity for 1, 2, or 3-dimensional problems.
181 |
182 | * :attr:`~.alpha` ``: float = 0.75``: relaxation parameter for the Richardson iteration
183 |
184 | * :attr:`~.max_iterations` ``: int = 1000``: maximum number of iterations
185 |
186 | * :attr:`~.threshold` ``: float = 1.e-6``: threshold for the residual norm for stopping the iteration
187 |
188 | %endmatter%
189 |
--------------------------------------------------------------------------------
/docs/source/references.bib:
--------------------------------------------------------------------------------
1 | @article{kruger2017solution,
2 | author = {Benjamin Kr\"{u}ger and Thomas Brenner and Alwin Kienle},
3 | journal = "{Optics Express}",
4 | year = {2017},
5 | number = {21},
6 | pages = {25165--25182},
7 | publisher = {Optica Publishing Group},
8 | title = {Solution of the inhomogeneous {M}axwell's equations using a {B}orn series},
9 | volume = {25},
10 | month = {Oct},
11 | urlintro = {https://opg.optica.org/oe/abstract.cfm?URI=oe-25-21-25165},
12 | doi = {10.1364/OE.25.025165},
13 | }
14 |
15 | @article{mache2024domain,
16 | author={Swapnil Mache and Ivo M. Vellekoop},
17 | title={Domain decomposition of the modified Born series approach for large-scale wave propagation simulations},
18 | journal={arXiv preprint},
19 | year={2024},
20 | eprint={2410.02395},
21 | archivePrefix={arXiv},
22 | primaryClass={physics.comp-ph},
23 | urlintro = {https://arxiv.org/abs/2410.02395},
24 | doiprefix = {10.48550/arXiv.2410.02395}
25 | }
26 |
27 | @misc{MiKTeX,
28 | title = {{MiKTeX: An Up-to-Date Implementation of TeX/LaTeX and Related Programs}},
29 | url = {https://miktex.org/},
30 | year = {2024},
31 | note = {Version 24.1}
32 | }
33 |
34 | @article{nabavi2007new,
35 | author = {Majid Nabavi and M.H. Kamran Siddiqui and Javad Dargahi},
36 | title = {A new 9-point sixth-order accurate compact finite-difference method for the {H}elmholtz equation},
37 | journal = {Journal of Sound and Vibration},
38 | year = {2007},
39 | volume = {307},
40 | number = {3},
41 | pages = {972-982},
42 | issn = {0022-460X},
43 | doi = {10.1016/j.jsv.2007.06.070},
44 | urlintro = {https://www.sciencedirect.com/science/article/pii/S0022460X07004877},
45 | }
46 |
47 | @article{oskooi2010meep,
48 | author = {Ardavan F. Oskooi and David Roundy and Mihai Ibanescu and Peter Bermel and J.D. Joannopoulos and Steven G. Johnson},
49 | title = {{MEEP}: A flexible free-software package for electromagnetic simulations by the {FDTD} method},
50 | journal = {Computer Physics Communications},
51 | year = {2010},
52 | volume = {181},
53 | number = {3},
54 | pages = {687-702},
55 | issn = {0010-4655},
56 | doi = {10.1016/j.cpc.2009.11.008},
57 | urlintro = {https://www.sciencedirect.com/science/article/pii/S001046550900383X},
58 | keywords = {Computational electromagnetism, FDTD, Maxwell solver},
59 | }
60 |
61 | @article{osnabrugge2016convergent,
62 | author={Osnabrugge, Gerwin and Leedumrongwatthanakun, Saroch and Vellekoop, Ivo M},
63 | title={A convergent {B}orn series for solving the inhomogeneous {H}elmholtz equation in arbitrarily large media},
64 | journal={Journal of Computational Physics},
65 | year={2016},
66 | volume={322},
67 | pages={113--124},
68 | publisher={Elsevier},
69 | doi={10.1016/j.jcp.2016.06.034},
70 | }
71 |
72 | @article{osnabrugge2021ultra,
73 | author={Osnabrugge, Gerwin and Benedictus, Maaike and Vellekoop, Ivo M},
74 | title={Ultra-thin boundary layer for high-accuracy simulations of light propagation},
75 | journal={Optics express},
76 | year={2021},
77 | volume={29},
78 | number={2},
79 | pages={1649--1658},
80 | publisher={Optica Publishing Group},
81 | doi={10.1364/OE.412833},
82 | }
83 |
84 | @misc{Poetry,
85 | title = {{Poetry: Python Dependency Management and Packaging Made Easy}},
86 | author = {Sébastien Eustace and The Poetry Contributors},
87 | year = {2023},
88 | note = {Version 1.7.1},
89 | url = {https://python-poetry.org/},
90 | }
91 |
92 | @misc{readthedocs_Wavesim,
93 | author = {Ivo M. Vellekoop},
94 | title = {{W}avesim | {R}ead the {D}ocs},
95 | url = {https://readthedocs.org/projects/wavesim/},
96 | note = {[Published DD-MM-2024]},
97 | }
98 |
99 | @book{taflove1995computational,
100 | author={Taflove, Allen and Hagness, Susan C},
101 | title={Computational electrodynamics: The Finite-Difference Time-Domain Method},
102 | publisher={Artech House},
103 | year={1995}
104 | }
105 |
106 | @article{vettenburg2019calculating,
107 | author = {T. Vettenburg and S. A. R. Horsley and J. Bertolotti},
108 | title = {Calculating coherent light-wave propagation in large heterogeneous media},
109 | journal = {Opt. Express},
110 | year = {2019},
111 | month = {Apr},
112 | volume = {27},
113 | number = {9},
114 | pages = {11946--11967},
115 | publisher = {Optica Publishing Group},
116 | keywords = {Chiral media; Finite element method; Material properties; Negative index materials; Optical microscopy; Scattering media},
117 | url = {https://opg.optica.org/oe/abstract.cfm?URI=oe-27-9-11946},
118 | doi = {10.1364/OE.27.011946},
119 | }
120 |
121 | @article{vettenburg2023universal,
122 | author={Tom Vettenburg and Ivo M. Vellekoop},
123 | title={A universal matrix-free split preconditioner for the fixed-point iterative solution of non-symmetric linear systems},
124 | journal={arXiv preprint},
125 | year={2023},
126 | eprint={2207.14222},
127 | archivePrefix={arXiv},
128 | primaryClass={math.NA},
129 | url = {https://arxiv.org/abs/2207.14222},
130 | doi = {10.48550/arXiv.2207.14222}
131 | }
132 |
133 | @misc{wavesim_documentation,
134 | title = {Wavesim documentation},
135 | url = {https://wavesim.readthedocs.io/en/latest/},
136 | }
137 |
138 | @misc{wavesim_py,
139 | author = {Swapnil Mache and Ivo M. Vellekoop},
140 | title = {Wavesim - A Python package for wave propagation simulation},
141 | url = {https://github.com/IvoVellekoop/wavesim_py}
142 | }
143 |
144 | @misc{wavesim_matlab,
145 | author = {G. Osnabrugge and I. M. Vellekoop},
146 | title = {{WaveSim}},
147 | note = {\url{https://github.com/ivovellekoop/wavesim}},
148 | year = {2020}
149 | }
150 |
151 | @article{yee1966numerical,
152 | author={Kane Yee},
153 | title={Numerical solution of initial boundary value problems involving {M}axwell's equations in isotropic media},
154 | journal={IEEE Transactions on Antennas and Propagation},
155 | year={1966},
156 | volume={14},
157 | number={3},
158 | pages={302-307},
159 | doi={10.1109/TAP.1966.1138693}
160 | }
161 |
162 | @book{zandonellaMassiddaOpenScience2022,
163 | title = {The Open Science Manual: Make Your Scientific Research Accessible and Reproducible},
164 | author = {Zandonella Callegher, Claudio and Massidda, Davide},
165 | year = {2022},
166 | publisher = {},
167 | url = {https://arca-dpss.github.io/manual-open-science/},
168 | doi = {10.5281/zenodo.6521850}
169 | }
170 |
--------------------------------------------------------------------------------
/docs/tex.bat:
--------------------------------------------------------------------------------
1 | .\make.bat clean & .\make.bat latex & cd _build\latex & xelatex wavesim.tex & xelatex wavesim.tex & cd .. & cd ..
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: wavesim
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | dependencies:
7 | - python>=3.11.0,<3.13.0
8 | - numpy<2.0.0
9 | - matplotlib
10 | - scipy
11 | - pytest
12 | - pytorch
13 | - pytorch-cuda
14 | - porespy
15 | - scikit-image<0.23
16 |
--------------------------------------------------------------------------------
/examples/README.rst:
--------------------------------------------------------------------------------
1 | Example gallery
2 | =====================
3 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from scipy.signal.windows import gaussian
5 | from scipy.fft import fftn, ifftn, fftshift
6 | from wavesim.utilities import normalize, relative_error
7 |
8 |
9 | def random_permittivity(shape):
10 | np.random.seed(0) # Set the random seed for reproducibility
11 | n = (1.0 + np.random.rand(*shape).astype(np.float32) +
12 | 0.03j * np.random.rand(*shape).astype(np.float32)) # Random refractive index
13 | n = smooth_n(n, shape) # Low pass filter to remove sharp edges
14 | n.real = normalize(n.real, a=1.0, b=2.0) # Normalize to [1, 2]
15 | n.imag = normalize(n.imag, a=0.0, b=0.03) # Normalize to [0, 0.03]
16 | # make sure that the imaginary part of n² is positive
17 | mask = (n ** 2).imag < 0
18 | n.imag[mask] *= -1.0
19 |
20 | n[0:5, :, :] = 1
21 | n[-5:, :, :] = 1
22 | return n ** 2
23 |
24 |
25 | def smooth_n(n, shape):
26 | """Low pass filter to remove sharp edges"""
27 | n_fft = fftn(n)
28 | w = (window(shape[1]).T @ window(shape[0])).T[:, :, None] * window(shape[2]).reshape(1, 1, shape[2])
29 | n = ifftn(n_fft * fftshift(w))
30 | n = np.clip(n.real, a_min=1.0, a_max=None) + 1.0j * np.clip(n.imag, a_min=0.0, a_max=None)
31 |
32 | assert (n ** 2).imag.min() >= 0, 'Imaginary part of n² is negative'
33 | assert n.shape == shape, 'n and shape do not match'
34 | assert n.dtype == np.complex64, 'n is not complex64'
35 | return n
36 |
37 |
38 | def window(x):
39 | """Create a window function for low pass filtering"""
40 | c0 = round(x / 4)
41 | cl = (x - c0) // 2
42 | cr = cl
43 | if c0 + cl + cr != x:
44 | c0 = x - cl - cr
45 | return np.concatenate((np.zeros((1, cl), dtype=np.complex64),
46 | np.ones((1, c0), dtype=np.complex64),
47 | np.zeros((1, cr), dtype=np.complex64)), axis=1)
48 |
49 |
50 | def construct_source(source_type, at, shape):
51 | if source_type == 'point':
52 | return torch.sparse_coo_tensor(at[:, None], torch.tensor([1.0]), shape, dtype=torch.complex64)
53 | elif source_type == 'plane_wave':
54 | return source_plane_wave(at, shape)
55 | elif source_type == 'gaussian_beam':
56 | return source_gaussian_beam(at, shape)
57 | else:
58 | raise ValueError(f"Unknown source type: {source_type}")
59 |
60 |
61 | def source_plane_wave(at, shape):
62 | """ Set up source, with size same as permittivity,
63 | and a plane wave source on one edge of the domain """
64 | # TODO: use CSR format instead?
65 | data = np.ones((1, shape[1], shape[2]), dtype=np.float32) # the source itself
66 | return torch.sparse_coo_tensor(at, data, shape, dtype=torch.complex64)
67 |
68 |
69 | def source_gaussian_beam(at, shape):
70 | """ Set up source, with size same as permittivity,
71 | and a Gaussian beam source on one edge of the domain """
72 | # TODO: use CSR format instead?
73 | std = (shape[1] - 1) / (2 * 3)
74 | source_amplitude = gaussian(shape[1], std).astype(np.float32)
75 | source_amplitude = np.outer(gaussian(shape[2], std).astype(np.float32),
76 | source_amplitude.astype(np.float32))
77 | source_amplitude = torch.tensor(source_amplitude[None, ...])
78 | data = torch.zeros((1, shape[1], shape[2]), dtype=torch.complex64)
79 | data[0,
80 | 0:shape[1],
81 | 0:shape[2]] = source_amplitude
82 | return torch.sparse_coo_tensor(at, data, shape, dtype=torch.complex64)
83 |
84 |
85 | def plot(x, x_ref, re=None, normalize_x=True):
86 | """Plot the computed field x and the reference field x_ref.
87 | If x and x_ref are 1D arrays, the real and imaginary parts are plotted separately.
88 | If x and x_ref are 2D arrays, the absolute values are plotted.
89 | If x and x_ref are 3D arrays, the central slice is plotted.
90 | If normalize_x is True, the values are normalized to the same range.
91 | The relative error is (computed, if needed, and) displayed.
92 | """
93 |
94 | re = relative_error(x, x_ref) if re is None else re
95 |
96 | if x.ndim == 1 and x_ref.ndim == 1:
97 | plt.subplot(211)
98 | plt.plot(x_ref.real, label='Analytic')
99 | plt.plot(x.real, label='Computed')
100 | plt.legend()
101 | plt.title(f'Real part (RE = {relative_error(x.real, x_ref.real):.2e})')
102 | plt.grid()
103 |
104 | plt.subplot(212)
105 | plt.plot(x_ref.imag, label='Analytic')
106 | plt.plot(x.imag, label='Computed')
107 | plt.legend()
108 | plt.title(f'Imaginary part (RE = {relative_error(x.imag, x_ref.imag):.2e})')
109 | plt.grid()
110 |
111 | plt.suptitle(f'Relative error (RE) = {re:.2e}')
112 | plt.tight_layout()
113 | plt.show()
114 | else:
115 | if x.ndim == 3 and x_ref.ndim == 3:
116 | x = x[x.shape[0]//2, ...]
117 | x_ref = x_ref[x_ref.shape[0]//2, ...]
118 |
119 | x = np.abs(x)
120 | x_ref = np.abs(x_ref)
121 | if normalize_x:
122 | min_val = min(np.min(x), np.min(x_ref))
123 | max_val = max(np.max(x), np.max(x_ref))
124 | a = 0
125 | b = 1
126 | x = normalize(x, min_val, max_val, a, b)
127 | x_ref = normalize(x_ref, min_val, max_val, a, b)
128 | else:
129 | a = None
130 | b = None
131 |
132 | cmap = 'inferno'
133 |
134 | plt.figure(figsize=(10, 5))
135 | plt.subplot(121)
136 | plt.imshow(x_ref, cmap=cmap, vmin=a, vmax=b)
137 | plt.colorbar(fraction=0.046, pad=0.04)
138 | plt.title('Reference')
139 |
140 | plt.subplot(122)
141 | plt.imshow(x, cmap=cmap, vmin=a, vmax=b)
142 | plt.colorbar(fraction=0.046, pad=0.04)
143 | plt.title('Computed')
144 |
145 | plt.suptitle(f'Relative error (RE) = {re:.2e}')
146 | plt.tight_layout()
147 | plt.show()
148 |
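# Minimal usage sketch of the helpers above (illustrative only: the shape and source index below
# are arbitrary, and wavesim must be importable, as required by the imports at the top of this file).
if __name__ == "__main__":
    shape = (64, 32, 32)                       # hypothetical simulation size in pixels
    permittivity = random_permittivity(shape)  # complex64 permittivity (n²) map
    at = torch.tensor([shape[0] // 2, 0, 0])   # index of a single point source
    src = construct_source('point', at, shape)
    print(permittivity.shape, permittivity.dtype, src.shape)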
--------------------------------------------------------------------------------
/examples/check_mem.py:
--------------------------------------------------------------------------------
1 | from torch.cuda.memory import _record_memory_history, _dump_snapshot
2 | from torch import zeros, complex64, cat
3 | from torch.cuda import empty_cache
4 |
5 | _record_memory_history(True, trace_alloc_max_entries=100000,
6 | trace_alloc_record_context=True)
7 |
8 | t = [None] * 6
9 | t[0] = zeros(8, 500, 500, dtype=complex64, device='cuda')
10 | t[1] = zeros(8, 500, 500, dtype=complex64, device='cuda')
11 |
12 | s = cat([i for i in t if i is not None], dim=0)
13 | print(s.shape)
14 | del s
15 | # s = t[0].clone()
16 |
17 | # for tt in t:
18 | # if tt is not None:
19 | # del tt
20 |
21 | # del s
22 |
23 | _dump_snapshot(f"logs/mem_snapshot.pickle")
24 | _record_memory_history(enabled=None)
25 |
--------------------------------------------------------------------------------
/examples/helmholtz_1d.py:
--------------------------------------------------------------------------------
1 | """
2 | Helmholtz 1D example with glass plate
3 | =====================================
4 | Test for 1D propagation through glass plate.
5 | Compare with reference solution (matlab repo result).
6 | """
7 |
8 | import os
9 | import torch
10 | import numpy as np
11 | from time import time
12 | from scipy.io import loadmat
13 | import sys
14 | sys.path.append(".")
15 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
16 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
17 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
18 | from wavesim.utilities import preprocess, relative_error
19 | from __init__ import plot
20 |
21 | if os.path.basename(os.getcwd()) == 'examples':
22 | os.chdir('..')
23 |
24 |
25 | # Parameters
26 | wavelength = 1. # wavelength in micrometer (um)
27 | n_size = (256, 1, 1) # size of simulation domain (in pixels in x, y, and z direction)
28 | n = np.ones(n_size, dtype=np.complex64) # refractive index map
29 | n[99:130] = 1.5 # glass plate
30 | boundary_widths = 24 # width of the boundary in pixels
31 |
32 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
33 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
34 |
35 | # Source term. This way is more efficient than dense tensor
36 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: left edge of the domain
37 | values = torch.tensor([1.0]) # Amplitude: 1
38 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
39 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
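# (For comparison: an equivalent dense source would allocate the full padded volume up front,
# e.g. dense = torch.zeros(n_ext, dtype=torch.complex64); dense[tuple(indices[:, 0])] = 1.0,
# whereas the sparse COO tensor above stores only the single non-zero entry.)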
40 |
41 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains)
42 | # 1-domain, periodic boundaries (without wrapping correction)
43 | periodic = (True, True, True) # periodic boundaries, wrapped field.
44 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
45 | # # OR. Uncomment to test domain decomposition
46 | # periodic = (False, True, True) # wrapping correction
47 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=1., n_domains=(2, 1, 1))
48 |
49 | # Run the wavesim iteration and get the computed field
50 | start = time()
51 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000)
52 | end = time() - start
53 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
54 | u_computed = u_computed.squeeze().cpu().numpy()[boundary_widths:-boundary_widths]
55 |
56 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
57 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u'])
58 |
59 | # Compute relative error with respect to the reference solution
60 | re = relative_error(u_computed, u_ref)
61 | print(f'Relative error: {re:.2e}')
62 | threshold = 1.e-3
63 | assert re < threshold, f"Relative error higher than {threshold}"
64 |
65 | # Plot the results
66 | plot(u_computed, u_ref, re)
67 |
--------------------------------------------------------------------------------
/examples/helmholtz_1d_analytical.py:
--------------------------------------------------------------------------------
1 | """
2 | Helmholtz 1D analytical test
3 | ============================
4 | Test to compare the result of Wavesim to analytical results.
5 | Compare 1D free-space propagation with analytic solution.
6 | """
7 |
8 | import torch
9 | import numpy as np
10 | from time import time
11 | import sys
12 | sys.path.append(".")
13 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
14 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
15 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
16 | from wavesim.utilities import analytical_solution, preprocess, relative_error
17 | from __init__ import plot
18 |
19 |
20 | # Parameters
21 | wavelength = 1. # wavelength in micrometer (um)
22 | n_size = (256, 1, 1) # size of simulation domain (in pixels in x, y, and z direction)
23 | n = np.ones(n_size, dtype=np.complex64) # permittivity (refractive index²) map
24 | boundary_widths = 16 # width of the boundary in pixels
25 |
26 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
27 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
28 |
29 | # Source term. This way is more efficient than dense tensor
30 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: left edge of the domain
31 | values = torch.tensor([1.0]) # Amplitude: 1
32 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
33 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
34 |
35 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains)
36 | # 1-domain, periodic boundaries (without wrapping correction)
37 | periodic = (True, True, True) # periodic boundaries, wrapped field.
38 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
39 | # # OR. Uncomment to test domain decomposition
40 | # periodic = (False, True, True) # wrapping correction
41 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=(3, 1, 1))
42 |
43 | # Run the wavesim iteration and get the computed field
44 | start = time()
45 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000)
46 | end = time() - start
47 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
48 | u_computed = u_computed.squeeze().cpu().numpy()[boundary_widths:-boundary_widths]
49 | u_ref = analytical_solution(n_size[0], domain.pixel_size, wavelength)
50 |
51 | # Compute relative error with respect to the analytical solution
52 | re = relative_error(u_computed, u_ref)
53 | print(f'Relative error: {re:.2e}')
54 | threshold = 1.e-3
55 | assert re < threshold, f"Relative error higher than {threshold}"
56 |
57 | # Plot the results
58 | plot(u_computed, u_ref, re)
59 |
--------------------------------------------------------------------------------
/examples/helmholtz_2d.py:
--------------------------------------------------------------------------------
1 | """
2 | Helmholtz 2D high contrast test
3 | ===============================
4 | Test for propagation in 2D structure made of iron,
5 | with high refractive index contrast.
6 | Compare with reference solution (matlab repo result).
7 | """
8 |
9 | import os
10 | import torch
11 | import numpy as np
12 | from time import time
13 | from scipy.io import loadmat
14 | from PIL.Image import BILINEAR, fromarray, open
15 | import sys
16 | sys.path.append(".")
17 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
18 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
19 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
20 | from wavesim.utilities import pad_boundaries, preprocess, relative_error
21 | from __init__ import plot
22 |
23 | if os.path.basename(os.getcwd()) == 'examples':
24 | os.chdir('..')
25 |
26 |
27 | # Parameters
28 | n_iron = 2.8954 + 2.9179j
29 | n_contrast = n_iron - 1
30 | wavelength = 0.532 # Wavelength in micrometers
31 | pixel_size = wavelength / (3 * np.max(abs(n_contrast + 1))) # Pixel size in wavelength units
32 |
33 | # Load image and create refractive index map
34 | oversampling = 0.25
35 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255
36 | n_im = ((np.where(im[:, :, 2] > 0.25, 1, 0) * n_contrast) + 1) # Refractive index map
37 | n_roi = int(oversampling * n_im.shape[0]) # Size of ROI in pixels
38 | n = np.asarray(fromarray(n_im.real).resize((n_roi, n_roi), BILINEAR)) + 1j * np.asarray(
39 | fromarray(n_im.imag).resize((n_roi, n_roi), BILINEAR)) # Refractive index map
40 | boundary_widths = 8 # Width of the boundary in pixels
41 |
42 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
43 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
44 |
45 | # Source term
46 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR))
47 | source = pad_boundaries(source, boundary_array)
48 | source = torch.tensor(source, dtype=torch.complex64)
49 |
50 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains)
51 | # 1-domain, periodic boundaries (without wrapping correction)
52 | periodic = (True, True, True) # periodic boundaries, wrapped field.
53 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength)
54 | # # OR. Uncomment to test domain decomposition
55 | # periodic = (True, False, True) # wrapping correction
56 | # domain = MultiDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength,
57 | # n_domains=(1, 2, 1))
58 |
59 | # Run the wavesim iteration and get the computed field
60 | start = time()
61 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=int(1.e+5))
62 | end = time() - start
63 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
64 | u_computed = u_computed.squeeze().cpu().numpy()[*([slice(boundary_widths, -boundary_widths)]*2)]
65 |
66 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
67 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_hc'])
68 |
69 | # Compute relative error with respect to the reference solution
70 | re = relative_error(u_computed, u_ref)
71 | print(f'Relative error: {re:.2e}')
72 | threshold = 1.e-3
73 | assert re < threshold, f"Relative error {re} higher than {threshold}"
74 |
75 | # Plot the results
76 | plot(u_computed, u_ref, re)
77 |
--------------------------------------------------------------------------------
/examples/helmholtz_2d_homogeneous.py:
--------------------------------------------------------------------------------
1 | """
2 | Helmholtz 2D homogeneous medium test
3 | ====================================
4 | Test for propagation in 2D homogeneous medium.
5 | """
6 |
7 | import torch
8 | import numpy as np
9 | from time import time
10 | import sys
11 | sys.path.append(".")
12 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
13 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
14 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
15 | from wavesim.utilities import preprocess
16 |
17 |
18 | # Parameters
19 | wavelength = 1. # Wavelength in micrometers
20 | pixel_size = wavelength/4 # Pixel size in wavelength units
21 | sim_size = np.array([50, 50, 1]) # Simulation size in micrometers
22 | # Size of simulation domain (in pixels in x, y, and z direction)
23 | n_size = (sim_size * wavelength/pixel_size).astype(int)
24 | n = np.ones(n_size, dtype=np.complex64) # Refractive index map
25 | boundary_widths = 8 # Width of the boundary in pixels
26 |
27 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
28 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
29 |
30 | # Source term. This way is more efficient than dense tensor
31 | indices = torch.tensor([[int(v/2 - 1) + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain
32 | values = torch.tensor([1.0]) # Amplitude: 1
33 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
34 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
35 |
36 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains)
37 | # 1-domain, periodic boundaries (without wrapping correction)
38 | periodic = (True, True, True) # periodic boundaries, wrapped field.
39 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
40 | # # OR. Uncomment to test domain decomposition
41 | # periodic = (False, True, True) # wrapping correction
42 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength,
43 | # n_domains=(2, 1, 1))
44 |
45 | # Run the wavesim iteration and get the computed field
46 | start = time()
47 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000)
48 | end = time() - start
49 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
50 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*2)]
51 |
--------------------------------------------------------------------------------
/examples/helmholtz_2d_low_contrast.py:
--------------------------------------------------------------------------------
1 | """
2 | Helmholtz 2D low contrast test
3 | ===============================
4 | Test for propagation in 2D structure with low refractive index contrast
5 | (made of fat and water to mimic biological tissue).
6 | Compare with reference solution (matlab repo result).
7 | """
8 |
9 | import os
10 | import torch
11 | import numpy as np
12 | from time import time
13 | from scipy.io import loadmat
14 | from PIL.Image import BILINEAR, fromarray, open
15 | import sys
16 | sys.path.append(".")
17 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
18 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
19 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
20 | from wavesim.utilities import pad_boundaries, preprocess, relative_error
21 | from __init__ import plot
22 |
23 | if os.path.basename(os.getcwd()) == 'examples':
24 | os.chdir('..')
25 |
26 |
27 | # Parameters
28 | n_water = 1.33
29 | n_fat = 1.46
30 | wavelength = 0.532 # Wavelength in micrometers
31 | pixel_size = wavelength / (3 * abs(n_fat)) # Pixel size in wavelength units
32 |
33 | # Load image and create refractive index map
34 | oversampling = 0.25
35 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255
36 | n_im = (np.where(im[:, :, 2] > 0.25, 1, 0) * (n_fat - n_water)) + n_water
37 | n_roi = int(oversampling * n_im.shape[0]) # Size of ROI in pixels
38 | n = np.asarray(fromarray(n_im).resize((n_roi, n_roi), BILINEAR)) # Refractive index map
39 | boundary_widths = 40 # Width of the boundary in pixels
40 |
41 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
42 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
43 |
44 | # Source term
45 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR))
46 | source = pad_boundaries(source, boundary_array)
47 | source = torch.tensor(source, dtype=torch.complex64)
48 |
49 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains)
50 | # 1-domain, periodic boundaries (without wrapping correction)
51 | periodic = (True, True, True) # periodic boundaries, wrapped field.
52 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength)
53 | # # OR. Uncomment to test domain decomposition
54 | # periodic = (False, True, True) # wrapping correction
55 | # domain = MultiDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength,
56 | # n_domains=(3, 1, 1))
57 |
58 | # Run the wavesim iteration and get the computed field
59 | start = time()
60 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=10000)
61 | end = time() - start
62 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
63 | u_computed = u_computed.squeeze().cpu().numpy()[*([slice(boundary_widths, -boundary_widths)]*2)]
64 |
65 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
66 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_lc'])
67 |
68 | # Compute relative error with respect to the reference solution
69 | re = relative_error(u_computed, u_ref)
70 | print(f'Relative error: {re:.2e}')
71 | threshold = 1.e-3
72 | assert re < threshold, f"Relative error higher than {threshold}"
73 |
74 | # Plot the results
75 | plot(u_computed, u_ref, re)
76 |
--------------------------------------------------------------------------------
/examples/helmholtz_3d_disordered.py:
--------------------------------------------------------------------------------
1 | """
2 | Helmholtz 3D disordered medium test
3 | ===================================
4 | Test for propagation in a 3D disordered medium.
5 | Compare with reference solution (matlab repo result).
6 | """
7 |
8 | import os
9 | import torch
10 | import numpy as np
11 | from time import time
12 | from scipy.io import loadmat
13 | import sys
14 | sys.path.append(".")
15 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
16 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
17 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
18 | from wavesim.utilities import preprocess, relative_error
19 | from __init__ import plot
20 |
21 | if os.path.basename(os.getcwd()) == 'examples':
22 | os.chdir('..')
23 |
24 |
25 | # Parameters
26 | wavelength = 1. # Wavelength in micrometers
27 | n_size = (128, 48, 96) # Size of the domain in pixels (x, y, z)
28 | n = np.ascontiguousarray(loadmat('examples/matlab_results.mat')['n3d_disordered']) # Refractive index map
29 | boundary_widths = 8 # Width of the boundary in pixels
30 |
31 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
32 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
33 |
34 | # Source term. This way is more efficient than dense tensor
35 | indices = torch.tensor([[int(v/2 - 1) + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain
36 | values = torch.tensor([1.0]) # Amplitude: 1
37 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
38 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
39 |
40 | # # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains)
41 | # # 1-domain, periodic boundaries (without wrapping correction)
42 | # periodic = (True, True, True) # periodic boundaries, wrapped field.
43 | # domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
44 | # OR. Uncomment to test domain decomposition
45 | periodic = (False, True, True) # wrapping correction
46 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength,
47 | n_domains=(2, 1, 1))
48 |
49 | # Run the wavesim iteration and get the computed field
50 | start = time()
51 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=1000)
52 | end = time() - start
53 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
54 | u_computed = u_computed.squeeze().cpu().numpy()[*([slice(boundary_widths, -boundary_widths)]*3)]
55 |
56 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
57 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u3d_disordered'])
58 |
59 | re = relative_error(u_computed, u_ref)
60 | print(f'Relative error: {re:.2e}')
61 | threshold = 1.e-3
62 | assert re < threshold, f"Relative error {re} higher than {threshold}"
63 |
64 | # Plot the results
65 | plot(u_computed, u_ref, re)
66 |
--------------------------------------------------------------------------------
/examples/logo_structure_vector.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/examples/logo_structure_vector.png
--------------------------------------------------------------------------------
/examples/matlab_results.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/examples/matlab_results.mat
--------------------------------------------------------------------------------
/examples/mem_snapshot.py:
--------------------------------------------------------------------------------
1 | """
2 | PyTorch memory snapshot example
3 | ===============================
4 | This script captures a memory snapshot of the GPU memory usage during the simulation.
5 | """
6 |
7 | import os
8 | import sys
9 | import torch
10 | import platform
11 | import numpy as np
12 | from time import time
13 | from paper_code.__init__ import random_refractive_index, construct_source
14 | sys.path.append(".")
15 | from wavesim.helmholtzdomain import HelmholtzDomain
16 | from wavesim.multidomain import MultiDomain
17 | from wavesim.iteration import run_algorithm
18 | from wavesim.utilities import preprocess
19 |
20 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
21 | os.environ["TORCH_USE_CUDA_DSA"] = "1"
22 | if os.path.basename(os.getcwd()) == 'examples':
23 | os.chdir('..')
24 | os.makedirs("logs", exist_ok=True)
25 |
26 |
27 | def is_supported_platform():
28 | return platform.system().lower() == "linux" and sys.maxsize > 2**32
29 |
30 |
31 | if is_supported_platform():
32 | torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=100000,
33 | trace_alloc_record_context=True)
34 | else:
35 | print(f"Pytorch emory snapshot functionality is not supported on {platform.system()} (non-linux non-x86_64 platforms).")
36 | # On Windows, gives "RuntimeError: record_context_cpp is not supported on non-linux non-x86_64 platforms"
37 |
38 | # generate a refractive index map
39 | sim_size = 100 * np.array([2, 1, 1]) # Simulation size in micrometers
40 | wavelength = 1.
41 | pixel_size = 0.25
42 | boundary_widths = 20
43 | n_dims = len(sim_size.squeeze())
44 |
45 | # Size of the simulation domain in pixels
46 | n_size = sim_size * wavelength / pixel_size
47 | n_size = n_size - 2 * boundary_widths # Subtract the boundary widths
48 | n_size = tuple(n_size.astype(int)) # Convert to integer for indexing
49 |
50 | n = random_refractive_index(n_size)
51 | print(f"Size of n: {n_size}")
52 | print(f"Size of n in GB: {n.nbytes / (1024**3):.2f}")
53 | assert n.imag.min() >= 0, 'Imaginary part of n is negative'
54 |
55 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
56 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n
57 | assert n.imag.min() >= 0, 'Imaginary part of n² is negative'
58 |
59 | source = construct_source(n_size, boundary_array)
60 |
61 | n_domains = (2, 1, 1) # number of domains in each direction
62 | periodic = (False, True, True) # True for 1 domain in that direction, False otherwise
63 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size,
64 | n_domains=n_domains)
65 |
66 | # Run the wavesim iteration and get the computed field
67 | start = time()
68 | u, iterations, residual_norm = run_algorithm(domain, source, max_iterations=5)
69 | end = time() - start
70 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
71 |
72 | if is_supported_platform():
73 | try:
74 | torch.cuda.memory._dump_snapshot(f"logs/mem_snapshot_cluster.pickle")
75 | # To view the memory snapshot, open this link in a browser window: https://pytorch.org/memory_viz
76 | # Then drag and drop the file "mem_snapshot_cluster.pickle" into the browser window.
77 | # From the dropdown menus in the top left corner, open the second one and select "Allocator State History".
78 | except Exception as e:
79 | # logger.error(f"Failed to capture memory snapshot {e}")
80 | print(f"Failed to capture memory snapshot {e}")
81 |
82 | # Stop recording memory snapshot history
83 | torch.cuda.memory._record_memory_history(enabled=None)
84 |
85 |
86 | # %% Postprocessing
87 |
88 | file_name = 'logs/size'
89 | for i in range(n_dims):
90 | file_name += f'{n_size[i]}_'
91 | file_name += f'bw{boundary_widths}_domains'
92 | for i in range(n_dims):
93 | file_name += f'{n_domains[i]}'
94 |
95 | output = (f'Size {n_size}; Boundaries {boundary_widths}; Domains {n_domains}; '
96 | + f'Time {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e} \n')
97 | with open('logs/output.txt', 'a') as file:
98 | file.write(output)
99 |
--------------------------------------------------------------------------------
/examples/paper_code/fig1_splitting.py:
--------------------------------------------------------------------------------
1 | """
2 | This script generates a figure showing the splitting of matrices A, L, and V.
3 |
4 | It sets up the problem parameters, computes the matrices A, L, and V,
5 | and visualizes them in a figure with four subplots:
6 | - Subplot (a): Matrix A
7 | - Subplot (b): Matrix L without wraparound artifacts
8 | - Subplot (c): Matrix V
9 | - Subplot (d): Matrix L with wraparound artifacts
10 |
11 | The figure is saved as a PDF file.
12 | """
13 |
14 | # import packages
15 | import os
16 | import sys
17 | import torch
18 | import numpy as np
19 | import matplotlib.pyplot as plt
20 | from matplotlib import rc, rcParams, colors
21 |
22 | sys.path.append(".")
23 | sys.path.append("..")
24 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
25 | from wavesim.utilities import full_matrix, normalize, preprocess
26 | from wavesim.iteration import domain_operator
27 |
28 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13}
29 | rc('font', **font)
30 | rcParams['mathtext.fontset'] = 'cm'
31 |
32 | if os.path.basename(os.getcwd()) == 'paper_code':
33 | os.chdir('..')
34 | os.makedirs('paper_figures', exist_ok=True)
35 | filename = 'paper_figures/fig1_splitting.pdf'
36 | else:
37 | try:
38 | os.makedirs('examples/paper_figures', exist_ok=True)
39 | filename = 'examples/paper_figures/fig1_splitting.pdf'
40 | except FileNotFoundError:
41 | filename = 'fig1_splitting.pdf'
42 |
43 | # Define problem parameters
44 | boundary_widths = 0
45 | n_size = (40, 1, 1)
46 | # Random refractive index distribution
47 | torch.manual_seed(0) # Set the random seed for reproducibility
48 | n = (torch.normal(mean=1.3, std=0.1, size=n_size, dtype=torch.float32)
49 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=n_size, dtype=torch.float32)))
50 | assert n.imag.min() >= 0, 'Imaginary part of n is negative'
51 |
52 | wavelength = 1.
53 | pixel_size = wavelength / 4
54 | periodic = (True, True, True)
55 |
56 | # Get matrices of A, L (with wrapping artifacts), and V operators
57 | domain = HelmholtzDomain(permittivity=n**2, periodic=periodic,
58 | wavelength=wavelength, pixel_size=pixel_size,
59 | debug=True)
60 | domain.set_source(0)
61 | b = full_matrix(domain_operator(domain, 'medium')).cpu().numpy() # B = I - V
62 | l_plus1 = full_matrix(domain_operator(domain, 'inverse_propagator')).cpu().numpy() # L + I
63 |
64 | I = np.eye(np.prod(n_size), dtype=np.complex64)
65 | v = I - b
66 | l = l_plus1 - I
67 |
68 | # Get matrices of A, L, and V for large domain without wraparound artifacts
69 | boundary_widths = 100
70 | n, boundary_array = preprocess((n**2).numpy(), boundary_widths)
71 | assert n.imag.min() >= 0, 'Imaginary part of n is negative'
72 | domain_o = HelmholtzDomain(permittivity=n, periodic=periodic,
73 | wavelength=wavelength, pixel_size=pixel_size,
74 | debug=True)
75 |
76 | crop2roi = (slice(boundary_array[0], -boundary_array[0]),
77 | slice(boundary_array[0], -boundary_array[0]))
78 | b_o = full_matrix(domain_operator(domain_o, 'medium'))[crop2roi].cpu().numpy() # B = I - V
79 | l_plus1_o = full_matrix(domain_operator(domain_o, 'inverse_propagator'))[crop2roi].cpu().numpy() # L + I
80 |
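# Undo the scaling and shift applied inside the large, artifact-free domain (domain_o), then
# re-apply the scaling and shift of the small periodic domain so that A, L, and V can be
# compared entry-wise with the matrices computed above.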
81 | v_o = (I - b_o) / domain_o.scale + domain_o.shift.item()*I # V = (I - B) / scaling
82 | l_o = (l_plus1_o - I) / domain_o.scale - domain_o.shift.item()*I # L = (L + I - I) / scaling
83 |
84 | v_o = (v_o - domain.shift.item()*I) * domain.scale
85 | l_o = (l_o + domain.shift.item()*I) * domain.scale
86 |
87 | a_o = l_o + v_o # A = L + V
88 |
89 | # Normalize matrices for visualization
90 | l = l.imag
91 | v = v.imag
92 | a_o = a_o.imag
93 | l_o = l_o.imag
94 |
95 | max_val = max(np.max(l), np.max(a_o), np.max(l_o), np.max(v))
96 | min_val = min(np.min(l), np.min(a_o), np.min(l_o), np.min(v))
97 | extremum = max(abs(min_val), abs(max_val))
98 | vmin = -1
99 | vmax = 1
100 |
101 | l = normalize(l, -extremum, extremum, vmin, vmax)
102 | v = normalize(v, -extremum, extremum, vmin, vmax)
103 | a_o = normalize(a_o, -extremum, extremum, vmin, vmax)
104 | l_o = normalize(l_o, -extremum, extremum, vmin, vmax)
105 |
106 | # Create a figure with four subplots in one row
107 | fig, axs = plt.subplots(1, 4, figsize=(12, 3), sharex=True, sharey=True,
108 | gridspec_kw={'wspace': 0.15, 'width_ratios': [1, 1, 1, 1.094]})
109 | fraction = 0.046
110 | pad = 0.04
111 | extent = np.array([0, n_size[0], n_size[0], 0]) # * base.pixel_size
112 | cmap = 'seismic'
113 |
114 | ax0 = axs[0]
115 | im0 = ax0.imshow(a_o, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax)
116 | ax0.set_title('$A$')
117 |
118 | ax1 = axs[1]
119 | im1 = ax1.imshow(l_o, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax)
120 | ax1.set_title('$L$')
121 |
122 | ax2 = axs[2]
123 | im2 = ax2.imshow(v, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax)
124 | ax2.set_title('$V$')
125 |
126 | ax3 = axs[3]
127 | im3 = ax3.imshow(l, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax)
128 | ax3.set_title('$L$ with wraparound artifacts')
129 | fig.colorbar(im3, ax=ax3, fraction=fraction, pad=pad)
130 |
131 | # Add text boxes with labels (a), (b), (c), ...
132 | labels = ['(a)', '(b)', '(c)', '(d)']
133 | for i, ax in enumerate(axs.flat):
134 | ax.text(0.5, -0.23, labels[i], transform=ax.transAxes, ha='center')
135 | ax.set_xticks(np.arange(0, 41, 10))
136 | ax.set_yticks(np.arange(0, 41, 10))
137 |
138 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300)
139 | plt.close('all')
140 | print(f'Saved: {filename}')
141 |
--------------------------------------------------------------------------------
/examples/paper_code/fig2_decompose.py:
--------------------------------------------------------------------------------
1 | """
2 | This script demonstrates the domain decomposition of operator A, L, and V.
3 | It visualizes the matrices in a figure with three subplots:
4 | - Subplot (a): Matrix A
5 | - Subplot (b): Matrix L
6 | - Subplot (c): Matrix V
7 |
8 | The figure is saved as a PDF file.
9 | """
10 |
11 | # import packages
12 | import os
13 | import sys
14 | import torch
15 | import numpy as np
16 | import matplotlib.pyplot as plt
17 | from matplotlib import rc, rcParams, colors
18 |
19 | sys.path.append(".")
20 | sys.path.append("..")
21 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
22 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
23 | from wavesim.utilities import full_matrix, normalize, preprocess
24 | from wavesim.iteration import domain_operator
25 |
26 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13}
27 | rc('font', **font)
28 | rcParams['mathtext.fontset'] = 'cm'
29 |
30 | if os.path.basename(os.getcwd()) == 'paper_code':
31 | os.chdir('..')
32 | os.makedirs('paper_figures', exist_ok=True)
33 | filename = 'paper_figures/fig2_decompose.pdf'
34 | else:
35 | try:
36 | os.makedirs('examples/paper_figures', exist_ok=True)
37 | filename = 'examples/paper_figures/fig2_decompose.pdf'
38 | except FileNotFoundError:
39 | filename = 'fig2_decompose.pdf'
40 |
41 | # Define problem parameters
42 | boundary_widths = 0
43 | n_size = (40, 1, 1)
44 | # Random refractive index distribution
45 | torch.manual_seed(0) # Set the random seed for reproducibility
46 | n = (torch.normal(mean=1.3, std=0.1, size=n_size, dtype=torch.float32)
47 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=n_size, dtype=torch.float32)))**2
48 | assert n.imag.min() >= 0, 'Imaginary part of n is negative'
49 |
50 | wavelength = 1.
51 | pixel_size = wavelength / 4
52 | periodic = (True, True, True)
53 |
54 | # Get matrices of A, L, and V for large domain without wraparound artifacts
55 | boundary_widths = 100
56 | n_o, boundary_array = preprocess(n.numpy(), boundary_widths)
57 | assert n_o.imag.min() >= 0, 'Imaginary part of n_o is negative'
58 |
59 | domain_o = HelmholtzDomain(permittivity=n_o, periodic=periodic, wavelength=wavelength,
60 | pixel_size=pixel_size, debug=True)
61 |
62 | crop2roi = (slice(boundary_array[0], -boundary_array[0]),
63 | slice(boundary_array[0], -boundary_array[0]))
64 |
65 | I = np.eye(np.prod(n_size), dtype=np.complex64)
66 |
67 | l_plus1_o = full_matrix(domain_operator(domain_o, 'inverse_propagator'))[crop2roi].cpu().numpy() # L + I
68 | l_o = (l_plus1_o - I) / domain_o.scale - domain_o.shift.item()*I # L = (L + I - I) / scaling
69 |
70 | # Get matrices of A, L, and V operators decomposed into two domains
71 | domain_2 = MultiDomain(permittivity=n, periodic=(True, True, True), wavelength=wavelength,
72 | pixel_size=pixel_size, n_domains=(2,1,1), debug=True, n_boundary=0)
73 |
74 | b_2 = full_matrix(domain_operator(domain_2, 'medium')).cpu().numpy() # B = I - V
75 | l_plus1_2 = full_matrix(domain_operator(domain_2, 'inverse_propagator')).cpu().numpy() # L + I
76 |
77 | v_2 = I - b_2
78 | l_2 = l_plus1_2 - I
79 |
80 | l_o = (l_o + domain_2.shift.item()*I) * domain_2.scale
81 |
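# The part of L that is lost by decomposing into periodic subdomains (l_diff) is folded into V,
# so the wrapping and transfer corrections appear as the labelled blocks of V in the plot below.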
82 | l_diff = l_o - l_2
83 | v_l_diff = v_2 + l_diff
84 | a_2 = l_2 + v_l_diff # A = L + V
85 |
86 | # Normalize matrices for visualization
87 | l_2 = l_2.imag
88 | v_l_diff = v_l_diff.imag
89 | a_2 = a_2.imag
90 |
91 | max_val = max(np.max(l_2), np.max(v_l_diff), np.max(a_2))
92 | min_val = min(np.min(l_2), np.min(v_l_diff), np.min(a_2))
93 | extremum = max(abs(min_val), abs(max_val))
94 | vmin = -1
95 | vmax = 1
96 |
97 | l_2 = normalize(l_2, -extremum, extremum, vmin, vmax)
98 | v_l_diff = normalize(v_l_diff, -extremum, extremum, vmin, vmax)
99 | a_2 = normalize(a_2, -extremum, extremum, vmin, vmax)
100 |
101 | # Create a figure with three subplots in one row
102 | fig, axs = plt.subplots(1, 3, figsize=(9, 3), sharex=True, sharey=True,
103 | gridspec_kw={'wspace': 0.15, 'width_ratios': [1, 1, 1.094]})
104 | fraction = 0.046
105 | pad = 0.04
106 | extent = np.array([0, n_size[0], n_size[0], 0])
107 | cmap = 'seismic'
108 |
109 | ax0 = axs[0]
110 | im0 = ax0.imshow(a_2, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax)
111 | ax0.set_title('$A$')
112 | kwargs0 = dict(transform=ax0.transAxes, ha='center', va='center',
113 | bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.1',
114 | alpha=0.5, linewidth=0.0))
115 | ax0.text(0.09, 0.57, '$A_{11}$', **kwargs0)
116 | ax0.text(0.9, 0.92, '$A_{12}$', **kwargs0)
117 | ax0.text(0.09, 0.07, '$A_{21}$', **kwargs0)
118 | ax0.text(0.9, 0.42, '$A_{22}$', **kwargs0)
119 |
120 | ax1 = axs[1]
121 | im1 = ax1.imshow(l_2, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax)
122 | ax1.set_title('$L$')
123 |
124 | ax2 = axs[2]
125 | im2 = ax2.imshow(v_l_diff, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax)
126 | ax2.set_title('$V$')
127 | kwargs2 = dict(transform=ax2.transAxes, ha='center', va='center',
128 | bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.1',
129 | alpha=0.5, linewidth=0.0))
130 | ax2.text(0.18, 0.64, '$C_{11}$', **kwargs2)
131 | ax2.text(0.84, 0.84, '$A_{12}$', **kwargs2)
132 | ax2.text(0.18, 0.15, '$A_{21}$', **kwargs2)
133 | ax2.text(0.84, 0.34, '$C_{22}$', **kwargs2)
134 | fig.colorbar(im2, ax=ax2, fraction=fraction, pad=pad)
135 |
136 | # Add text boxes with labels (a), (b), (c), ...
137 | labels = ['(a)', '(b)', '(c)']
138 | for i, ax in enumerate(axs.flat):
139 | ax.text(0.5, -0.23, labels[i], transform=ax.transAxes, ha='center')
140 | ax.axhline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation
141 | ax.axvline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation
142 | ax.set_yticks(np.arange(0, 41, 10))
143 | ax.set_xticks(np.arange(0, 41, 10))
144 |
145 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300)
146 | plt.close('all')
147 | print(f'Saved: {filename}')
148 |
--------------------------------------------------------------------------------
/examples/paper_code/fig3_correction_matrix.py:
--------------------------------------------------------------------------------
1 | """
2 | This script generates a figure demonstrating the fast convolution of the Laplacian with a point source
3 | for a 1-domain and 2-domain case, and the construction of the correction matrix A_{12}.
4 |
5 | The figure consists of three subplots:
6 | - Subplot (a): Fast convolution with a point source for 1-domain case
7 | - Subplot (b): Fast convolution with a point source for 2-domain case, where the fast convolution is performed over 1
8 | subdomain. The subplot also shows the difference between the unwrapped and wrapped fields.
9 | - Subplot (c): Correction matrix A_{12}, a non-cyclic convolution matrix that computes the wrapping artifacts.
10 |
11 | The figure is saved as a PDF file.
12 | """
13 |
14 | # import packages
15 | import os
16 | import sys
17 | import torch
18 | import numpy as np
19 | import matplotlib.pyplot as plt
20 | from matplotlib.gridspec import GridSpec
21 | from matplotlib import rc, rcParams, colors
22 |
23 | sys.path.append(".")
24 | sys.path.append("..")
25 | from wavesim.domain import Domain
26 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
27 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
28 | from wavesim.utilities import normalize, pad_boundaries, preprocess
29 |
30 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 14}
31 | rc('font', **font)
32 | rcParams['mathtext.fontset'] = 'cm'
33 |
34 | if os.path.basename(os.getcwd()) == 'paper_code':
35 | os.chdir('..')
36 | os.makedirs('paper_figures', exist_ok=True)
37 | filename = 'paper_figures/fig3_correction_matrix.pdf'
38 | else:
39 | try:
40 | os.makedirs('examples/paper_figures', exist_ok=True)
41 | filename = 'examples/paper_figures/fig3_correction_matrix.pdf'
42 | except FileNotFoundError:
43 | filename = 'fig3_correction_matrix.pdf'
44 |
45 |
46 | def coordinates_f_sq(n_, pixel_size=0.25):
47 | """ Calculate the coordinates in the frequency domain and returns squared values
48 | :param n_: Number of points
49 | :param pixel_size: Pixel size. Defaults to 0.25.
50 | :return: Tensor containing the coordinates in the frequency domain """
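# (2*pi*f)**2 sampled with fftfreq is the Fourier-space symbol of -∇² on this grid, so multiplying
# an FFT by it and inverse transforming applies (minus) the Laplacian as a cyclic (fast) convolution.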
51 | return (2 * torch.pi * torch.fft.fftfreq(n_, pixel_size))**2
52 |
53 | # Define problem parameters
54 | boundary_widths = 0.
55 | n_size = 40
56 |
57 | # Fast convolution of Laplacian with point source
58 |
59 | # Case 1: 1 domain
60 | d = int(n_size/2 - 1) # 1 point before the center
61 | # Point source
62 | side = torch.zeros(n_size)
63 | side[d] = 1.0
64 | # Fast convolution result
65 | fc1 = (torch.fft.ifftn(coordinates_f_sq(n_size)
66 | * torch.fft.fftn(side))).real.cpu().numpy() # discard tiny imaginary part due to numerical errors
67 |
68 | # Case 2: 2 domains. Fast convolution of Laplacian with point source over 1 subdomain
69 | # Point source
70 | side2 = torch.zeros(n_size // 2)
71 | side2[d] = 1.0
72 | # discard tiny imaginary part due to numerical errors in the fast conv result
73 | fc2 = (torch.fft.ifftn(coordinates_f_sq(n_size // 2)
74 | * torch.fft.fftn(side2))).real.cpu().numpy()
75 | fc2 = np.concatenate((fc2, np.zeros_like(fc2))) # zero padding for the 2nd domain
76 | diff_ = fc1 - fc2 # difference between unwrapped and wrapped fields
77 |
78 | # construct a non-cyclic convolution matrix that computes the wrapping artifacts only
79 | t = 8 # number of correction points for constructing correction matrix
80 | a_12 = np.zeros((t, t), dtype=np.complex64)
81 | for r in range(t):
82 | a_12[r, :] = diff_[r:r + t]
83 | a_12 = np.flip(a_12.real, axis=0) # flip the matrix
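# Before the flip, a_12[r, c] = diff_[r + c], i.e. a Hankel matrix built from the wrap-around error;
# flipping it vertically yields the Toeplitz-structured correction block shown in subplot (c).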
84 |
85 | # Normalize matrices for visualization
86 | min_val = min(fc1.min(), fc2.min(), diff_.min(), a_12.min())
87 | max_val = max(fc1.max(), fc2.max(), diff_.max(), a_12.max())
88 | extremum = max(abs(min_val), abs(max_val))
89 | vmin = -1
90 | vmax = 1
91 |
92 | fc1 = normalize(fc1, -extremum, extremum, vmin, vmax)
93 | fc2 = normalize(fc2, -extremum, extremum, vmin, vmax)
94 | diff_ = normalize(diff_, -extremum, extremum, vmin, vmax)
95 | a_12 = normalize(a_12, -extremum, extremum, vmin, vmax)
96 |
97 | # Plot limits
98 | c = n_size//2 # center of the domain
99 |
100 | # Plot
101 | fig = plt.figure(figsize=(12, 6), layout='constrained')
102 | gs = GridSpec(2, 2, figure=fig)
103 |
104 | # 1-domain case
105 | ax1 = fig.add_subplot(gs[0, 0])
106 | ax1.plot(fc1, 'k') # 1-domain, no wrapping field
107 | ax1.set_title(r'$\nabla^2$ kernel in infinite domain')
108 | ax1.set_xlim([-1, n_size])
109 | ax1.set_xticks(np.arange(0, n_size + 1, 5))
110 | ax1.set_xticklabels([])
111 | ax1.grid(True, which='major', linestyle='--', linewidth=0.5)
112 | ax1.text(0.5, -0.15, '(a)', transform=ax1.transAxes, ha='center')
113 |
114 | # 2-domain case
115 | ax2 = fig.add_subplot(gs[1, 0])
116 | ax2.plot(fc2, 'k', label='Field') # 2-domain, with wrapping field
117 | ax2.axvspan(c, n_size - 1, color='gray', alpha=0.3) # Patch to demarcate 2nd subdomain
118 |
119 | # difference between unwrapped and wrapped fields
120 | ax2.plot(diff_[:t], 'r--', label='Corrections') # wrapping correction
121 | ax2.text(0.14, 0.5, '$C_{11}$', color='r', transform=ax2.transAxes, ha='center')
122 |
123 | ax2.plot(np.arange(c, c + t), diff_[c:c + t], 'r--') # transfer correction
124 | ax2.text(0.62, 0.23, '$A_{12}$', color='r', transform=ax2.transAxes, ha='center')
125 |
126 | ax2.set_xlim([-1, n_size])
127 | ax2.set_xticks(np.arange(0, n_size + 1, 5))
128 | ax2.set_xlabel('x')
129 | ax2.text(0.26, 0.09, 'Subdomain 1', transform=ax2.transAxes, ha='center')
130 | ax2.text(0.76, 0.09, 'Subdomain 2', transform=ax2.transAxes, ha='center')
131 | ax2.set_title(r'$\nabla^2$ kernel in periodic subdomains')
132 | ax2.grid(True, which='major', linestyle='--', linewidth=0.5)
133 | ax2.legend()
134 | ax2.text(0.5, -0.36, '(b)', transform=ax2.transAxes, ha='center')
135 |
136 | # Correction matrix
137 | ax3 = fig.add_subplot(gs[:, 1])
138 | im3 = ax3.imshow(a_12, cmap='seismic', extent=[0, t, 0, t], norm=colors.CenteredNorm())
139 | cb3 = plt.colorbar(mappable=im3, fraction=0.05, pad=0.01, ax=ax3)
140 | ax3.set_title('$A_{12}$')
141 | ax3.text(0.5, -0.15, '(c)', transform=ax3.transAxes, ha='center')
142 |
143 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300)
144 | plt.close('all')
145 |
146 | # Print min and max a_12 values. min/max should be <<1% (for truncation to make sense)
147 | min_a_12 = a_12[0, -1]
148 | max_a_12 = a_12[-1, 0]
149 | percent = (min_a_12/max_a_12) * 100
150 | print(f'Wrapping artifact amplitudes: Min {min_a_12:.3f}, Max {max_a_12:.3f}')
151 | print(f'Min/Max of A_{12} = {percent:.2f} %')
152 | assert percent < 1, f"Min/Max of A_{12} ratio exceeds 1%: {percent:.2f} %"
153 | print('Done.')
154 |
--------------------------------------------------------------------------------
/examples/paper_code/fig4_truncate.py:
--------------------------------------------------------------------------------
1 | """
2 | This script generates a figure to demonstrate the truncation of the wrapping and transfer
3 | corrections. The figure consists of three subplots:
4 |
5 | - Subplot (a): Matrix A with truncated corrections
6 | - Subplot (b): Matrix L
7 | - Subplot (c): Matrix V with truncated corrections
8 |
9 | The figure is saved as a PDF file.
10 | """
11 |
12 | # import packages
13 | import os
14 | import sys
15 | import torch
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | from matplotlib import rc, rcParams, colors
19 |
20 | sys.path.append(".")
21 | sys.path.append("..")
22 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
23 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
24 | from wavesim.utilities import full_matrix, normalize, preprocess, relative_error
25 | from wavesim.iteration import domain_operator
26 |
27 | current_dir = os.path.dirname(__file__)
28 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13}
29 | rc('font', **font)
30 | rcParams['mathtext.fontset'] = 'cm'
31 |
32 | if os.path.basename(os.getcwd()) == 'paper_code':
33 | os.chdir('..')
34 | os.makedirs('paper_figures', exist_ok=True)
35 | filename = 'paper_figures/fig4_truncate.pdf'
36 | else:
37 | try:
38 | os.makedirs('examples/paper_figures', exist_ok=True)
39 | filename = 'examples/paper_figures/fig4_truncate.pdf'
40 | except FileNotFoundError:
41 | filename = 'fig4_truncate.pdf'
42 |
43 | # Define problem parameters
44 | n_size = (40, 1, 1)
45 | # Random refractive index distribution
46 | torch.manual_seed(0) # Set the random seed for reproducibility
47 | n = (torch.normal(mean=1.3, std=0.1, size=n_size, dtype=torch.float32)
48 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=n_size, dtype=torch.float32)))**2
49 | assert n.imag.min() >= 0, 'Imaginary part of n is negative'
50 |
51 | wavelength = 1.
52 | pixel_size = wavelength / 4
53 | periodic = (True, True, True)
54 |
55 | I = np.eye(np.prod(n_size), dtype=np.complex64)
56 |
57 | # Get matrices of A, L, and V operators (with truncated corrections) decomposed into two domains
58 | domain_c = MultiDomain(permittivity=n, periodic=(False, True, True), wavelength=wavelength,
59 | pixel_size=pixel_size, n_domains=(2,1,1), n_boundary=6, debug=True)
60 |
61 | b_c = full_matrix(domain_operator(domain_c, 'medium')).cpu().numpy() # B = I - V
62 | l_plus1_c = full_matrix(domain_operator(domain_c, 'inverse_propagator')).cpu().numpy() # L + I
63 |
64 | v_c = I - b_c
65 | l_c = l_plus1_c - I
66 | a_c = l_c + v_c # A = L + V
67 |
68 | # Normalize matrices for visualization
69 | a_c = a_c.imag
70 | l_c = l_c.imag
71 | v_c = v_c.imag
72 |
73 | max_val = max(np.max(a_c), np.max(l_c), np.max(v_c))
74 | min_val = min(np.min(a_c), np.min(l_c), np.min(v_c))
75 | extremum = max(abs(min_val), abs(max_val))
76 | vmin = -1
77 | vmax = 1
78 |
79 | a_c = normalize(a_c, -extremum, extremum, vmin, vmax)
80 | v_c = normalize(v_c, -extremum, extremum, vmin, vmax)
81 | l_c = normalize(l_c, -extremum, extremum, vmin, vmax)
82 |
83 | # Plot (a) A, (b) L, and (c) V with truncated corrections
84 |
85 | # Create a figure with three subplots in one row
86 | fig, axs = plt.subplots(1, 3, figsize=(9, 3), sharex=True, sharey=True,
87 | gridspec_kw={'wspace': 0.15, 'width_ratios': [1, 1, 1.094]})
88 | fraction = 0.046
89 | pad = 0.04
90 | extent = np.array([0, n_size[0], n_size[0], 0])
91 | cmap = 'seismic'
92 |
93 | ax0 = axs[0]
94 | im0 = ax0.imshow(a_c, cmap=cmap, extent=extent, norm=colors.SymLogNorm(linthresh=0.7, vmin=vmin, vmax=vmax))
95 | ax0.set_title('$A$ (truncated corrections)')
96 |
97 | ax1 = axs[1]
98 | im1 = ax1.imshow(l_c, cmap=cmap, extent=extent, norm=colors.SymLogNorm(linthresh=0.7, vmin=vmin, vmax=vmax))
99 | ax1.set_title('$L$')
100 |
101 | ax2 = axs[2]
102 | im2 = ax2.imshow(v_c, cmap=cmap, extent=extent, norm=colors.SymLogNorm(linthresh=0.7, vmin=vmin, vmax=vmax))
103 | ax2.set_title('$V$ (truncated corrections)')
104 | fig.colorbar(im2, ax=ax2, fraction=fraction, pad=pad, ticks=[-1, -0.5, 0, 0.5, 1], format='%.1f')
105 |
106 | # Add text boxes with labels (a), (b), (c), ...
107 | labels = ['(a)', '(b)', '(c)', '(d)']
108 | for i, ax in enumerate(axs.flat):
109 | ax.text(0.5, -0.23, labels[i], transform=ax.transAxes, ha='center')
110 | ax.axhline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation
111 | ax.axvline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation
112 | ax.set_yticks(np.arange(0, 41, 10))
113 | ax.set_xticks(np.arange(0, 41, 10))
114 |
115 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300)
116 | plt.close('all')
117 | print(f'Saved: {filename}')
118 |
--------------------------------------------------------------------------------
/examples/paper_code/fig5_dd_validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from __init__ import sim_3d_random, plot_validation
4 |
5 | if os.path.basename(os.getcwd()) == 'paper_code':
6 | os.chdir('..')
7 | os.makedirs('paper_data', exist_ok=True)
8 | os.makedirs('paper_figures', exist_ok=True)
9 | filename = 'paper_data/fig5_dd_validation_'
10 | figname = 'paper_figures/fig5_dd_validation.pdf'
11 | else:
12 | try:
13 | os.makedirs('examples/paper_data', exist_ok=True)
14 | os.makedirs('examples/paper_figures', exist_ok=True)
15 | filename = 'examples/paper_data/fig5_dd_validation_'
16 | figname = 'examples/paper_figures/fig5_dd_validation.pdf'
17 | except FileNotFoundError:
18 | print("Directory not found. Please run the script from the 'paper_code' directory.")
19 |
20 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries)
21 | full_residuals = True
22 |
23 | # Run the simulations
24 | sim_ref = sim_3d_random(filename, sim_size, n_domains=None, n_boundary=0, r=12, clearance=0, full_residuals=full_residuals)
25 | sim = sim_3d_random(filename, sim_size, n_domains=(3, 1, 1), r=12, clearance=0, full_residuals=full_residuals)
26 |
27 | # plot the field
28 | plot_validation(figname, sim_ref, sim, plt_norm='log')
29 | print('Done.')
30 |
--------------------------------------------------------------------------------
/examples/paper_code/fig6_dd_truncation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | from time import time
5 | import matplotlib.pyplot as plt
6 | from matplotlib import rc, rcParams
7 | from matplotlib.ticker import LogLocator, MultipleLocator
8 |
9 | sys.path.append(".")
10 | sys.path.append("..")
11 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
12 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
13 | from wavesim.iteration import run_algorithm # to run the wavesim iteration
14 | from wavesim.utilities import preprocess, relative_error
15 | from __init__ import random_spheres_refractive_index, construct_source
16 |
17 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13}
18 | rc('font', **font)
19 | rcParams['mathtext.fontset'] = 'cm'
20 |
21 | if os.path.basename(os.getcwd()) == 'paper_code':
22 | os.chdir('..')
23 | os.makedirs('paper_data', exist_ok=True)
24 | os.makedirs('paper_figures', exist_ok=True)
25 | current_dir = 'paper_data/'
26 | figname = (f'paper_figures/fig6_dd_truncation.pdf')
27 | else:
28 | try:
29 | os.makedirs('examples/paper_data', exist_ok=True)
30 | os.makedirs('examples/paper_figures', exist_ok=True)
31 | current_dir = 'examples/paper_data/'
32 | figname = (f'examples/paper_figures/fig6_dd_truncation.pdf')
33 | except FileNotFoundError:
34 | print("Directory not found. Please run the script from the 'paper_code' directory.")
35 |
36 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries)
37 |
38 | wavelength = 1. # Wavelength in micrometers
39 | pixel_size = wavelength/4 # Pixel size in wavelength units
40 | boundary_wavelengths = 5 # Boundary width in wavelengths
41 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size), 0, 0] # Boundary width in pixels
42 | # Periodic boundaries True (no wrapping correction) if boundary width is 0, else False (wrapping correction)
43 | periodic = tuple(np.where(np.array(boundary_widths) == 0, True, False))
44 | n_dims = np.count_nonzero(sim_size != 1) # Number of dimensions
45 |
46 | # Size of the simulation domain
47 | n_size = np.ones_like(sim_size, dtype=int)
48 | n_size[:n_dims] = sim_size[:n_dims] * wavelength / pixel_size # Size of the simulation domain in pixels
49 | n_size = tuple(n_size.astype(int)) # Convert to integer for indexing
50 |
51 | filename = os.path.join(current_dir, f'fig6_dd_truncation.npz')
52 | if os.path.exists(filename):
53 | print(f"File {filename} already exists. Loading data...")
54 | data = np.load(filename)
55 | corrs = data['corrs']
56 | sim_time = data['sim_time']
57 | iterations = data['iterations']
58 | ure_list = data['ure_list']
59 | print(f"Loaded data from {filename}. Now plotting...")
60 | else:
61 | n = random_spheres_refractive_index(n_size, r=12, clearance=0) # Random refractive index
62 |
63 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
64 | n, boundary_array = preprocess((n**2), boundary_widths) # permittivity is n², but uses the same variable n
65 |
66 | print(f"Size of n: {n_size}")
67 | print(f"Size of n in GB: {n.nbytes / (1024**3):.2f}")
68 | assert n.imag.min() >= 0, 'Imaginary part of n² is negative'
69 | assert (n.shape == np.asarray(n_size) + 2*boundary_array).all(), 'n and n_size do not match'
70 | assert n.dtype == np.complex64, f'n is not complex64, but {n.dtype}'
71 |
72 | source = construct_source(n_size, boundary_array)
73 |
74 | domain_ref = HelmholtzDomain(permittivity=n, periodic=periodic,
75 | wavelength=wavelength, pixel_size=pixel_size)
76 |
77 | start_ref = time()
78 | # Field u and state object with information about the run
79 | u_ref, iterations_ref, residual_norm_ref = run_algorithm(domain_ref, source,
80 | max_iterations=10000)
81 | sim_time_ref = time() - start_ref
82 | print(f'\nTime {sim_time_ref:2.2f} s; Iterations {iterations_ref}; Residual norm {residual_norm_ref:.3e}')
83 | # crop the field to the region of interest
84 | u_ref = u_ref[*(slice(boundary_widths[i],
85 | u_ref.shape[i] - boundary_widths[i]) for i in range(3))].cpu().numpy()
86 |
87 | n_ext = np.array(n_size) + 2*boundary_array
88 | corrs = np.arange(n_ext[0] // 4 + 1)
89 | print(f"Number of correction points: {corrs[-1]}")
90 |
91 | ure_list = []
92 | iterations = []
93 | sim_time = []
94 |
95 | n_domains = (2, 1, 1) # number of domains in each direction
96 |
97 | for n_boundary in corrs:
98 | print(f'n_boundary {n_boundary}/{corrs[-1]}', end='\r')
99 | domain_n = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength,
100 | pixel_size=pixel_size, n_domains=n_domains, n_boundary=n_boundary)
101 |
102 | start_n = time()
103 | u_n, iterations_n, residual_norm_n = run_algorithm(domain_n, source,
104 | max_iterations=10000)
105 | sim_time_n = time() - start_n
106 | print(f'\nTime {sim_time_n:2.2f} s; Iterations {iterations_n}; Residual norm {residual_norm_n:.3e}')
107 | # crop the field to the region of interest
108 | u_n = u_n[*(slice(boundary_widths[i],
109 | u_n.shape[i] - boundary_widths[i]) for i in range(3))].cpu().numpy()
110 |
111 | ure_list.append(relative_error(u_n, u_ref))
112 | iterations.append(iterations_n)
113 | sim_time.append(sim_time_n)
114 |
115 | sim_time = np.array(sim_time)
116 | iterations = np.array(iterations)
117 | ure_list = np.array(ure_list)
118 | np.savez_compressed(filename, corrs=corrs,
119 | sim_time=sim_time,
120 | iterations=iterations,
121 | ure_list=ure_list)
122 | print(f'Saved: {filename}. Now plotting...')
123 |
124 |
125 | # Plot
126 | length = int(len(ure_list) * 2/3)
127 | x = np.arange(length)
128 | ncols = 3
129 | figsize = (12, 3)
130 |
131 | fig, axs = plt.subplots(1, ncols, figsize=figsize, gridspec_kw={'hspace': 0., 'wspace': 0.29})
132 |
133 | ax0 = axs[0]
134 | ax0.semilogy(x, ure_list[:length], 'r', lw=1., marker='x', markersize=3)
135 | ax0.set_xlabel('Number of correction points')
136 | ax0.set_ylabel('Relative Error')
137 | ax0.set_xticks(np.arange(0, round(length,-1)+1, 10))
138 | ax0.set_xlim([-2 if n_dims == 3 else -10, length + 1 if n_dims == 3 else length + 9])
139 | ax0.grid(True, which='major', linestyle='--', linewidth=0.5)
140 | ax0.grid(True, which='minor', linestyle=':', linewidth=0.3, axis='y')
141 | ax0.yaxis.set_minor_locator(LogLocator(numticks=12,subs=np.arange(2,10)))
142 | ax0.xaxis.set_minor_locator(MultipleLocator(1))
143 |
144 | ax1 = axs[1]
145 | start = 4
146 | ax1.plot(x[start:], iterations[start:length], 'g', lw=1., marker='+', markersize=3)
147 | ax1.set_xlabel('Number of correction points')
148 | ax1.set_ylabel('Iterations')
149 | ax1.set_xticks(np.arange(0, round(length,-1)+1, 10))
150 | ax1.set_xlim([start-1, length+1])
151 | ax1.grid(True, which='major', linestyle='--', linewidth=0.5)
152 | ax1.xaxis.set_minor_locator(MultipleLocator(1))
153 |
154 | ax2 = axs[2]
155 | ax2.plot(x[start:], sim_time[start:length], 'b', lw=1., marker='*', markersize=3)
156 | ax2.set_xlabel('Number of correction points')
157 | ax2.set_ylabel('Time (s)')
158 | ax2.set_xticks(np.arange(0, round(length,-1)+1, 10))
159 | ax2.set_xlim([start-1, length+1])
160 | ax2.grid(True, which='major', linestyle='--', linewidth=0.5)
161 | ax2.xaxis.set_minor_locator(MultipleLocator(1))
162 |
163 | # Add text boxes with labels (a), (b), (c), ...
164 | labels = ['(a)', '(b)', '(c)']
165 | for i, ax in enumerate(axs.flat):
166 | ax.text(0.5, -0.3, labels[i], transform=ax.transAxes, ha='center')
167 |
168 | plt.savefig(figname, bbox_inches='tight', pad_inches=0.03, dpi=300)
169 | plt.close('all')
170 | print(f'Saved: {figname}')
171 |
--------------------------------------------------------------------------------
/examples/paper_code/fig7_dd_convergence.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | from time import time
5 | from itertools import product
6 | import matplotlib.pyplot as plt
7 | from matplotlib import rc, rcParams
8 | from matplotlib.ticker import MultipleLocator
9 |
10 | sys.path.append(".")
11 | sys.path.append("..")
12 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1
13 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1
14 | from wavesim.iteration import run_algorithm # to run the anysim iteration
15 | from wavesim.utilities import preprocess
16 | from __init__ import random_spheres_refractive_index, construct_source
17 |
18 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13}
19 | rc('font', **font)
20 | rcParams['mathtext.fontset'] = 'cm'
21 |
22 | if os.path.basename(os.getcwd()) == 'paper_code':
23 | os.chdir('..')
24 | os.makedirs('paper_data', exist_ok=True)
25 | os.makedirs('paper_figures', exist_ok=True)
26 | filename = f'paper_data/fig7_dd_convergence.txt'
27 | figname = f'paper_figures/fig7_dd_convergence.pdf'
28 | else:
29 | try:
30 | os.makedirs('examples/paper_data', exist_ok=True)
31 | os.makedirs('examples/paper_figures', exist_ok=True)
32 | filename = f'examples/paper_data/fig7_dd_convergence.txt'
33 | figname = (f'examples/paper_figures/fig7_dd_convergence.pdf')
34 | except FileNotFoundError:
35 | print("Directory not found. Please run the script from the 'paper_code' directory.")
36 |
37 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries)
38 | wavelength = 1. # Wavelength in micrometers
39 | pixel_size = wavelength/4 # Pixel size in wavelength units
40 | boundary_wavelengths = 5 # Boundary width in wavelengths
41 | n_dims = np.count_nonzero(sim_size != 1) # Number of dimensions
42 |
43 | # Size of the simulation domain
44 | n_size = np.ones_like(sim_size, dtype=int)
45 | n_size[:n_dims] = sim_size[:n_dims] * wavelength / pixel_size # Size of the simulation domain in pixels
46 | n_size = tuple(n_size.astype(int)) # Convert to integer for indexing
47 |
48 | if os.path.exists(filename):
49 | print(f"File {filename} already exists. Loading data and plotting...")
50 | else:
51 | domains = range(1, 11)
52 | for nx, ny in product(domains, domains):
53 | print(f'Domains {nx}/{domains[-1]}, {ny}/{domains[-1]}', end='\r')
54 |
55 | if nx == 1 and ny == 1:
56 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size), 0, 0]
57 | elif nx > 1 and ny == 1:
58 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size), 0, 0]
59 | elif nx == 1 and ny > 1:
60 | boundary_widths = [0, round(boundary_wavelengths * wavelength / pixel_size), 0]
61 | else:
62 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size)]*2 + [0]
63 |
64 | periodic = tuple(np.where(np.array(boundary_widths) == 0, True, False))
65 | n = random_spheres_refractive_index(n_size, r=12, clearance=0) # Random refractive index
66 |
67 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
68 | n, boundary_array = preprocess((n**2), boundary_widths) # permittivity is n², but uses the same variable n
69 |
70 | print(f"Size of n: {n_size}")
71 | print(f"Size of n in GB: {n.nbytes / (1024**3):.2f}")
72 | assert n.imag.min() >= 0, 'Imaginary part of n² is negative'
73 | assert (n.shape == np.asarray(n_size) + 2*boundary_array).all(), 'n and n_size do not match'
74 | assert n.dtype == np.complex64, f'n is not complex64, but {n.dtype}'
75 |
76 | source = construct_source(n_size, boundary_array)
77 |
78 | n_domains = (nx, ny, 1)
79 | domain = MultiDomain(permittivity=n, periodic=periodic,
80 | wavelength=wavelength, pixel_size=pixel_size,
81 | n_boundary=8, n_domains=n_domains)
82 |
83 | start = time()
84 | u, iterations, residual_norm = run_algorithm(domain, source, max_iterations=10000) # Field u and state object with information about the run
85 | end = time() - start
86 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
87 |
88 | #%% Save data to file
89 | data = (f'Size {n_size}; Boundaries {boundary_widths}; Domains {n_domains}; '
90 | + f'Time {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e} \n')
91 | with open(filename, 'a') as file:
92 | file.write(data)
93 |
94 | #%% Domains in x AND y direction vs iterations and time
95 | data = np.loadtxt(filename, dtype=str, delimiter=';')
96 |
97 | num_domains = [data[i, 2] for i in range(len(data))]
98 | num_domains = [num_domains[i].split('(', maxsplit=1)[-1] for i in range(len(num_domains))]
99 | num_domains = [num_domains[i].split(', 1)', maxsplit=1)[0] for i in range(len(num_domains))]
100 | num_domains = [(int(num_domains[i].split(',', maxsplit=1)[0]), int(num_domains[i].split(',', maxsplit=1)[-1])) for i in range(len(num_domains))]
101 |
102 | x, y = max(num_domains, key=lambda x: x[0])[0], max(num_domains, key=lambda x: x[1])[1]
103 |
104 | #%% Both subplots in one figure
105 |
106 | iterations = [data[i, 4] for i in range(len(data))]
107 | iterations = np.array([int(iterations[i].split(' ')[2]) for i in range(len(iterations))])
108 | iterations = np.reshape(iterations, (x, y), order='F')
109 |
110 | times = [data[i, 3] for i in range(len(data))]
111 | times = [float(times[i].split(' ')[2]) for i in range(len(times))]
112 | times = np.reshape(times, (x, y), order='F')
113 |
114 | fig, ax = plt.subplots(figsize=(9, 3), nrows=1, ncols=2, sharey=True, gridspec_kw={'wspace': 0.05})
115 | cmap = 'inferno'
116 |
117 | im0 = ax[0].imshow(np.flipud(iterations), cmap=cmap, extent=[0.5, x+0.5, 0.5, y+0.5])
118 | ax[0].set_xlabel('Domains in x direction')
119 | ax[0].set_ylabel('Domains in y direction')
120 | plt.colorbar(im0, label='Iterations', fraction=0.046, pad=0.04)
121 | ax[0].set_title('Iterations vs Number of domains')
122 | ax[0].text(0.5, -0.27, '(a)', color='k', ha='center', va='center', transform=ax[0].transAxes)
123 | ax[0].xaxis.set_minor_locator(MultipleLocator(1))
124 | ax[0].yaxis.set_minor_locator(MultipleLocator(1))
125 | ax[0].set_xticks(np.arange(2, x+1, 2))
126 | ax[0].set_yticks(np.arange(2, y+1, 2))
127 |
128 | im1 = ax[1].imshow(np.flipud(times), cmap=cmap, extent=[0.5, x+0.5, 0.5, y+0.5])
129 | ax[1].set_xlabel('Domains in x direction')
130 | plt.colorbar(im1, label='Time (s)', fraction=0.046, pad=0.04)
131 | ax[1].set_title('Time vs Number of domains')
132 | ax[1].text(0.5, -0.27, '(b)', color='k', ha='center', va='center', transform=ax[1].transAxes)
133 | ax[1].xaxis.set_minor_locator(MultipleLocator(1))
134 | ax[1].yaxis.set_minor_locator(MultipleLocator(1))
135 | ax[1].set_xticks(np.arange(2, x+1, 2))
136 | ax[1].set_yticks(np.arange(2, y+1, 2))
137 |
138 | plt.savefig(figname, bbox_inches='tight', pad_inches=0.03, dpi=300)
139 | plt.close('all')
140 | print(f'Saved: {figname}')
141 |
--------------------------------------------------------------------------------
/examples/paper_code/fig8_dd_large_simulation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from __init__ import sim_3d_random, plot_validation
4 |
5 | if os.path.basename(os.getcwd()) == 'paper_code':
6 | os.chdir('..')
7 | os.makedirs('paper_data', exist_ok=True)
8 | os.makedirs('paper_figures', exist_ok=True)
9 | filename = 'paper_data/fig8_dd_large_simulation_'
10 | figname = 'paper_figures/fig8_dd_large_simulation.pdf'
11 | else:
12 | try:
13 | os.makedirs('examples/paper_data', exist_ok=True)
14 | os.makedirs('examples/paper_figures', exist_ok=True)
15 | filename = 'examples/paper_data/fig8_dd_large_simulation_'
16 | figname = 'examples/paper_figures/fig8_dd_large_simulation.pdf'
17 | except FileNotFoundError:
18 | print("Directory not found. Please run the script from the 'paper_code' directory.")
19 |
20 | sim_size = 315 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries)
21 | full_residuals = True
22 |
23 | # Run the simulations
24 | sim_gpu = sim_3d_random(filename, sim_size, n_domains=(2, 1, 1), r=24, clearance=24, full_residuals=full_residuals)
25 | sim_cpu = sim_3d_random(filename, sim_size, n_domains=None, n_boundary=0, r=24, clearance=24, full_residuals=full_residuals,
26 | device='cpu')
27 |
28 | # plot the field
29 | plot_validation(figname, sim_cpu, sim_gpu, plt_norm='log', inset=True)
30 | print('Done.')
31 |
--------------------------------------------------------------------------------
/examples/run_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Run Helmholtz example
3 | =====================
4 | Example script to run a simulation of a point source in a random refractive index map using the Helmholtz equation.
5 | """
6 |
7 | import os
8 | import sys
9 |
10 | import torch
11 | import numpy as np
12 | from time import time
13 | from scipy.signal.windows import gaussian
14 | from torch.fft import fftn, ifftn, fftshift
15 | import matplotlib.pyplot as plt
16 | from matplotlib import colors
17 |
18 | sys.path.append(".")
19 | from __init__ import random_permittivity, construct_source
20 | from wavesim.helmholtzdomain import HelmholtzDomain
21 | from wavesim.multidomain import MultiDomain
22 | from wavesim.iteration import run_algorithm
23 | from wavesim.utilities import preprocess, normalize
24 |
25 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
26 | os.environ["TORCH_USE_CUDA_DSA"] = "1"
27 | if os.path.basename(os.getcwd()) == 'examples':
28 | os.chdir('..')
29 |
30 | # generate a refractive index map
31 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers
32 | periodic = (False, True, True)
33 | wavelength = 1. # Wavelength in micrometers
34 | pixel_size = wavelength/4 # Pixel size in wavelength units
35 | boundary_wavelengths = 10 # Boundary width in wavelengths
36 | boundary_widths = [int(boundary_wavelengths * wavelength / pixel_size), 0, 0] # Boundary width in pixels
37 | n_dims = len(sim_size.squeeze()) # Number of dimensions
38 |
39 | # Size of the simulation domain
40 | n_size = sim_size * wavelength / pixel_size # Size of the simulation domain in pixels
41 | n_size = tuple(n_size.astype(int)) # Convert to integer for indexing
42 |
43 | # return permittivity (n²) with absorbing boundaries
44 | permittivity = random_permittivity(n_size)
45 | permittivity = preprocess(permittivity, boundary_widths)[0]
46 | assert permittivity.imag.min() >= 0, 'Imaginary part of n² is negative'
47 | assert (permittivity.shape == np.asarray(n_size) + 2*np.asarray(boundary_widths)).all(), 'permittivity and n_size do not match'
48 | assert permittivity.dtype == np.complex64, f'permittivity is not complex64, but {permittivity.dtype}'
49 |
50 | # construct a source at the center of the domain
51 | source = construct_source(source_type='point', at=np.asarray(permittivity.shape) // 2, shape=permittivity.shape)
52 | # source = construct_source(source_type='plane_wave', at=[[boundary_widths[0]]], shape=permittivity.shape)
53 | # source = construct_source(source_type='gaussian_beam', at=[[boundary_widths[0]]], shape=permittivity.shape)
54 |
55 | # # 1-domain
56 | # n_domains = (1, 1, 1) # number of domains in each direction
57 | # domain = HelmholtzDomain(permittivity=permittivity, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size)
58 |
59 | # 1-domain or more with domain decomposition
60 | n_domains = (1, 1, 1) # number of domains in each direction
61 | domain = MultiDomain(permittivity=permittivity, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size,
62 | n_domains=n_domains)
63 |
64 | start = time()
65 | # Field u and state object with information about the run
66 | u, iterations, residual_norm = run_algorithm(domain, source, max_iterations=1000)
67 | end = time() - start
68 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}')
69 |
70 | # %% Postprocessing
71 |
72 | file_name = './logs/size'
73 | for i in range(n_dims):
74 | file_name += f'{n_size[i]}_'
75 | file_name += f'bw{boundary_widths}_domains'
76 | for i in range(n_dims):
77 | file_name += f'{n_domains[i]}'
78 |
79 | output = (f'Size {n_size}; Boundaries {boundary_widths}; Domains {n_domains}; '
80 | + f'Time {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e} \n')
81 | if not os.path.exists('logs'):
82 | os.makedirs('logs')
83 | with open('logs/output.txt', 'a') as file:
84 | file.write(output)
85 |
86 | # %% crop and save the field
87 | # crop the field to the region of interest
88 | u = u.squeeze()[*([slice(boundary_widths[i],
89 | u.shape[i] - boundary_widths[i]) for i in range(3)])].cpu().numpy()
90 | np.savez_compressed(f'{file_name}.npz', u=u) # save the field
91 |
92 | # %% plot the field
93 | extent = np.array([0, n_size[0], n_size[1], 0])*pixel_size
94 | u = normalize(np.abs(u[:, :, u.shape[2]//2].T))
95 | plt.imshow(u, cmap='inferno', extent=extent, norm=colors.LogNorm())
96 | plt.xlabel(r'$x~(\mu m)$')
97 | plt.ylabel(r'$y~(\mu m)$')
98 | cbar = plt.colorbar(fraction=0.046, pad=0.04)
99 | cbar.ax.set_title(r'$|E|$')
100 | plt.tight_layout()
101 | plt.savefig(f'{file_name}.pdf', bbox_inches='tight', pad_inches=0.03, dpi=300)
102 | plt.show()
103 | # plt.close('all')
104 |
--------------------------------------------------------------------------------
/examples/timing_test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 |
4 | ## Tests for the fastest way to calculate the Euclidean norm of a complex vector
5 | # It seems that on the CPU, 'norm' is consistently slower than 'vdot', but gives a slightly different, perhaps more accurate value.
6 | # On the GPU, 'norm' is only very slightly faster (1%) than 'vdot' and gives the same result as vdot, indicating that the same
7 | # computations are used (perhaps with one cached memory load less in the 'norm' case)
8 |
9 | size = 100000000 # One hundred million elements
10 |
11 |
12 | def norm_using_vdot(vector):
13 | squared_magnitude = torch.vdot(vector, vector)
14 | return squared_magnitude.sqrt().real
15 |
16 |
17 | def norm_using_matmul(vector):
18 | x = torch.unsqueeze(vector, 0)
19 | y = torch.unsqueeze(vector, 0)
20 |     squared_magnitude = torch.matmul(x, y.H).squeeze()  # (1,N) @ (N,1) -> 0-d complex tensor
21 |     return squared_magnitude.sqrt().real
22 |
23 |
24 | def norm_using_linalg_norm(vector):
25 | return torch.linalg.norm(vector)
26 |
27 |
28 | def timeit(func, name, complex_vector):
29 | start_time = time.time()
30 | for x in range(100):
31 | value = func(complex_vector)
32 | print(f"Euclidean norm using {name}: {value}, Time taken: {time.time() - start_time} seconds")
33 |
34 |
35 | for device in ['cpu', 'cuda']:
36 | complex_vector = torch.randn(size, dtype=torch.float32, device=device) + 1j * torch.randn(size,
37 | dtype=torch.float32,
38 | device=device)
39 | print(device)
40 | for repeat in range(3):
41 | print(f"Repeat {repeat + 1}")
42 | # timeit(norm_using_matmul, 'torch.matmul', complex_vector) # always much slower
43 | timeit(norm_using_linalg_norm, 'torch.linalg.norm', complex_vector)
44 | timeit(norm_using_vdot, 'torch.vdot', complex_vector)
45 |
--------------------------------------------------------------------------------
/guides/ipynb2slides_guide.md:
--------------------------------------------------------------------------------
1 | ## Working with conda and jupyter notebooks (.ipynb)
2 |
3 | We need to add the conda environment to Jupyter as a kernel so that it can be selected from the Select Kernel option in a Jupyter notebook.
4 |
5 | 1. Activate the desired conda environment
6 |
7 |         conda activate <environment name>
8 |
9 | 2. Install the ipykernel package
10 |
11 | conda install ipykernel
12 |
13 | 3. Add/install the environment to the ipykernel
14 |
15 |         python -m ipykernel install --user --name=<environment name>
16 |
17 | ### Additional step for Visual Studio Code
18 |
19 | 4. Install the Jupyter extension through the GUI
20 |
21 | ### To open in browser (Google Chrome)
22 |
23 | 5. Although these packages should already be installed through the .yml file while setting up the environment, Jupyter may not open in a browser window after entering the command:
24 |
25 | jupyter notebook
26 |
27 | 1. Set up jupyterlab and jupyter notebook with the following commands
28 |
29 | conda install -c conda-forge jupyterlab
30 | conda install -c anaconda notebook
31 |
32 | 2. Run the command below again, and Jupyter should now open in a browser window:
33 |
34 | jupyter notebook
35 |
36 | 6. The environment should now be visible in the Select Kernel dropdown.
37 |
38 |
39 | ### To convert the current jupyter notebook into a presentation (plotly plots stay interactive)
40 |
41 | 1. Open the jupyter notebook in a browser window and check that running all cells gives the expected output
42 |
43 | 2. In the toolbar at the top, click **View** --> **Cell Toolbar** --> **Slideshow**
44 |
45 | 3. Each cell in the notebook will now have a toolbar at the top with a dropdown named **Slide Type**. In the dropdown, select **Slide** for all the cells you want to include in the presentation.
46 |
47 | 4. Convert the .ipynb notebook to an .html presentation with one of the commands below, choosing the options as needed (a programmatic Python alternative is sketched after the list)
48 |
49 | * Default options
50 |
51 |         jupyter nbconvert --to slides <notebook name>.ipynb
52 |
53 | * If you don't want to show the code, and just the outputs
54 |
55 |         jupyter nbconvert --to slides --no-input <notebook name>.ipynb
56 |
57 | * In addition to above, if you don't want any transitions
58 |
59 |         jupyter nbconvert --to slides --no-input <notebook name>.ipynb --SlidesExporter.reveal_transition=none
60 |
61 | * In addition to above, if you want a specific theme (here, serif)
62 |
63 |         jupyter nbconvert --to slides --no-input <notebook name>.ipynb --SlidesExporter.reveal_transition=none --SlidesExporter.reveal_theme=serif
64 |
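The same conversion can also be scripted from Python through nbconvert's exporter API, which is convenient when several notebooks need identical options. The sketch below is illustrative only: it assumes nbconvert is installed in the active environment, 'analysis.ipynb' is a hypothetical notebook name, and the exporter traits mirror the CLI flags listed above.

```
# Programmatic equivalent of the nbconvert CLI calls above (illustrative sketch;
# assumes nbconvert is installed; 'analysis.ipynb' is a hypothetical notebook name)
from nbconvert import SlidesExporter

exporter = SlidesExporter()
exporter.exclude_input = True          # same effect as --no-input
exporter.reveal_transition = 'none'    # same as --SlidesExporter.reveal_transition=none
exporter.reveal_theme = 'serif'        # same as --SlidesExporter.reveal_theme=serif

body, _resources = exporter.from_filename('analysis.ipynb')
with open('analysis.slides.html', 'w', encoding='utf-8') as f:
    f.write(body)
```
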
65 | 5. Double-click the .html presentation (\<notebook name\>.slides.html by default) that should now be in the current working directory.
66 |
67 | 6. To convert the .html slides to pdf, add _?print-pdf_ to the URL in the web browser between _.html_ and _#_.
--------------------------------------------------------------------------------
/guides/misc_tips.md:
--------------------------------------------------------------------------------
1 | ### Some useful conda environment management commands
2 |
3 | The [Miniconda environment management guide](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) has more details if you need them.
4 |
5 | * To update the current conda environment from a .yml file:
6 |
7 | ```
8 | conda env update --name wavesim --file environment.yml --prune
9 | ```
10 |
11 | * To export the current environment to a .yml file:
12 |
13 | ```
14 | conda env export > <filename>.yml
15 | ```
16 |
17 | * To install any packages within an environment, first go into the environment and then install the package:
18 |
19 | ```
20 | conda activate <environment name>
21 | conda install <package name>
22 | ```
23 |
24 | * If conda does not have the package, and an online search suggests installing it via pip, use this command to install it specifically within the current environment and not globally (always prefer conda over pip; only fall back to pip if the package is not available through conda):
25 |
26 | ```
27 | python -m pip install <package name>
28 | ```
29 |
30 | * After updating conda, setting up a new environment, or installing packages, it is a good idea to clean up any cached installation packages and tarballs, as they are not needed anymore:
31 |
32 | ```
33 | conda clean --all
34 | ```
35 |
36 | ### If there is a problem with specifying fonts in matplotlib.rc
37 |
38 | Example of an error: "findfont: Generic family 'sans-serif' not found because none of the following families were found: Times New Roman"
39 |
40 | 1. Check whether the 'mscorefonts' package is installed in conda (using conda list). If not, install it:
41 |
42 | ```
43 | conda install -c conda-forge mscorefonts
44 | ```
45 |
46 | 2. Clear the matplotlib cache. This step is equally important.
47 |
48 | ```
49 | rm ~/.cache/matplotlib -rf
50 | ```
51 |
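After these two steps, a quick check from Python can confirm that matplotlib actually sees the font. This is only a sketch (it assumes matplotlib is installed in the active environment); the rc call is of the same form as the font setup in the example scripts.

```
# Verify that matplotlib can find the font after installing mscorefonts
# and clearing the cache (assumes matplotlib is installed)
from matplotlib import font_manager, rc

available = {f.name for f in font_manager.fontManager.ttflist}
print('Times New Roman' in available)  # should print True

# font configuration of the same form used in the example scripts
rc('font', **{'family': 'serif', 'serif': ['Times New Roman'], 'size': 13})
```
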
52 | ### If there is a problem with tex in matplotlib
53 |
54 | ```
55 | sudo apt install texlive texlive-latex-extra texlive-fonts-recommended dvipng cm-super
56 |
57 | python -m pip install latex
58 | ```
59 |
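To check that the TeX toolchain is picked up by matplotlib, the following sketch (assuming matplotlib is installed) renders a LaTeX label with text.usetex enabled, similar to the axis labels used in the example scripts. 'tex_check.pdf' is an arbitrary output name.

```
# Check that matplotlib can render labels through LaTeX
# (assumes the packages above and matplotlib are installed)
import matplotlib
matplotlib.rcParams['text.usetex'] = True

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlabel(r'$x~(\mu m)$')  # LaTeX-style label, as in the example scripts
fig.savefig('tex_check.pdf')
plt.close(fig)
```
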
60 | ### For animations, the ffmpeg package is needed (the command below is for Linux)
61 |
62 | ```
63 | sudo apt-get install ffmpeg
64 | ```
65 |
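To confirm that matplotlib can use the ffmpeg writer, a minimal animation can be saved to disk. This is only a sketch with made-up data; it assumes matplotlib is installed, ffmpeg is on the PATH, and 'ffmpeg_check.mp4' is an arbitrary output name.

```
# Save a tiny matplotlib animation via the ffmpeg writer
# (assumes matplotlib is installed and ffmpeg is on the PATH)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

x = np.linspace(0, 2 * np.pi, 100)
fig, ax = plt.subplots()
line, = ax.plot(x, np.sin(x))

def update(frame):
    line.set_ydata(np.sin(x + 0.1 * frame))
    return line,

anim = FuncAnimation(fig, update, frames=50)
anim.save('ffmpeg_check.mp4', writer='ffmpeg', fps=10)
plt.close(fig)
```
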
--------------------------------------------------------------------------------
/guides/pytorch_gpu_setup.md:
--------------------------------------------------------------------------------
1 | ## [Setting up NVIDIA CUDA on WSL (Windows Subsystem for Linux)](https://docs.nvidia.com/cuda/wsl-user-guide/index.html)
2 |
3 | ### Main steps
4 | 1. Install the NVIDIA driver for GPU support. Select and download the appropriate [driver](https://www.nvidia.com/Download/index.aspx) for your GPU and operating system
5 |
6 | 2. Install WSL 2 (follow the instructions in [2.2. Step 2](https://docs.nvidia.com/cuda/wsl-user-guide/index.html))
7 |
8 | 3. Install CUDA Toolkit using WSL-Ubuntu Package, following the [command line instructions](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local).
9 |
10 | ## Install pytorch-gpu
11 |
12 | 1. If doing this for the first time (or perhaps just in general), preferably do this in a separate conda environment, so that if anything breaks during installation, the new conda environment can simply be deleted without losing anything.
13 |
14 | * Create a new conda environment
15 | ```
16 | conda create --name <environment name>
17 | ```
18 |
19 | * OR, if a .yml file is available,
20 | ```
21 | conda env create -f <filename>.yml
22 | ```
23 |
24 | 2. Obtain the appropriate install command from https://pytorch.org/, depending on the build (prefer stable), OS, package manager, language (here, Python), and compute platform (CPU, or a CUDA version for GPU). Example:
25 | ```
26 | python -m pip install torch --index-url https://download.pytorch.org/whl/cu126
27 | ```
28 |
29 | 3. Check that pytorch works (and can use the GPU) with the following series of commands in a WSL (Ubuntu) window/session, inside the conda pytorch environment (a combined Python snippet is sketched below the checks)
30 | ```
31 | python
32 | import torch
33 | ```
34 |
35 | * To check if pytorch works
36 | ```
37 | a = torch.randn(5,3)
38 | print(a)
39 | ```
40 |
41 | * To check if it works with GPU(s)
42 | ```
43 | torch.cuda.is_available()
44 | ```
45 | Output should be: "True"
46 |
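The commands above can be combined into one short check that also places a tensor on the GPU, which is the operation that actually fails when the CUDA build is not picked up. A minimal sketch, assuming the CUDA-enabled PyTorch build installed in step 2:

```
# Minimal GPU sanity check (assumes a CUDA-enabled PyTorch build)
import torch

print(torch.__version__)
print(torch.cuda.is_available())          # should print True
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first GPU
    a = torch.randn(5, 3, device='cuda')  # tensor allocated on the GPU
    print(a @ a.T)                        # matrix product computed on the GPU
```
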
47 | ### Monitoring GPU usage
48 |
49 | * To be able to see the GPU(s) and usage, the [nvtop](https://github.com/Syllo/nvtop) package is quite useful.
50 | ```
51 | sudo apt install nvtop
52 | ```
53 |
54 | * For WSL, nvtop installed using the above approach might not work. Instead, [build it from source](https://github.com/Syllo/nvtop#nvtop-build)
55 |
56 | * For Windows, nvtop is not available, but an alternative tool, [nvitop](https://pypi.org/project/nvitop/), can be used instead (see also the in-script memory check sketched below).
57 |
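* To monitor memory from inside a running script rather than from a separate terminal, PyTorch's own memory counters can be printed. A minimal sketch, assuming a CUDA-enabled PyTorch installation:

```
# Report GPU memory use from within a PyTorch script (assumes CUDA is available)
import torch

if torch.cuda.is_available():
    print(f"allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")
    print(f"reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GiB")
```
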
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "wavesim"
3 | version = "0.1.0a2"
4 | description = "A library for simulating wave propagation using the modified Born series"
5 | authors = [ {name = "Ivo Vellekoop", email = "i.m.vellekoop@utwente.nl"},
6 | {name = "Swapnil Mache", email = "s.mache@utwente.nl"} ]
7 | readme = "README.md"
8 | repository = "https://github.com/IvoVellekoop/wavesim_py"
9 | documentation = "https://wavesim.readthedocs.io/en/latest/"
10 | classifiers = [
11 | 'Programming Language :: Python :: 3',
12 | 'Operating System :: OS Independent',
13 | ]
14 | requires-python = ">=3.11,<3.13"
15 | dependencies = [
16 | "numpy<2.0.0",
17 | "matplotlib>=3.9.1",
18 | "scipy>=1.14.0",
19 | "porespy",
20 | "scikit-image<0.23",
21 | "torch"
22 | ]
23 |
24 | [build-system]
25 | requires = ["poetry-core"]
26 | build-backend = "poetry.core.masonry.api"
27 |
28 | [[tool.poetry.source]]
29 | name = "pytorch-gpu"
30 | url = "https://download.pytorch.org/whl/cu126"
31 | priority = "explicit"
32 |
33 | [tool.poetry.dependencies]
34 | torch = { source = "pytorch-gpu" }
35 |
36 | [tool.poetry.group.dev]
37 | optional = true
38 |
39 | [tool.poetry.group.dev.dependencies]
40 | pytest = "^8.2.2"
41 |
42 | [tool.poetry.group.docs]
43 | optional = true
44 |
45 | [tool.poetry.group.docs.dependencies]
46 | sphinx = ">=4.1.2"
47 | sphinx_mdinclude = ">= 0.5.0"
48 | sphinx-rtd-theme = ">= 2.0.0"
49 | sphinx-autodoc-typehints = ">= 1.11.0"
50 | sphinxcontrib-bibtex = ">= 2.6.0"
51 | sphinx-markdown-builder = ">= 0.6.6"
52 | sphinx-gallery = ">= 0.15.0"
53 |
--------------------------------------------------------------------------------
/pyproject_cpu.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "wavesim"
3 | version = "0.1.0a2"
4 | description = "A library for simulating wave propagation using the modified Born series"
5 | authors = [ {name = "Ivo Vellekoop", email = "i.m.vellekoop@utwente.nl"},
6 | {name = "Swapnil Mache", email = "s.mache@utwente.nl"} ]
7 | readme = "README.md"
8 | repository = "https://github.com/IvoVellekoop/wavesim_py"
9 | documentation = "https://wavesim.readthedocs.io/en/latest/"
10 | classifiers = [
11 | 'Programming Language :: Python :: 3',
12 | 'Operating System :: OS Independent',
13 | ]
14 | requires-python = ">=3.11,<3.13"
15 | dependencies = [
16 | "numpy<2.0.0",
17 | "matplotlib>=3.9.1",
18 | "scipy>=1.14.0",
19 | "porespy",
20 | "scikit-image<0.23",
21 | "torch"
22 | ]
23 |
24 | [build-system]
25 | requires = ["poetry-core"]
26 | build-backend = "poetry.core.masonry.api"
27 |
28 | [tool.poetry.group.dev]
29 | optional = true
30 |
31 | [tool.poetry.group.dev.dependencies]
32 | pytest = "^8.2.2"
33 |
34 | [tool.poetry.group.docs]
35 | optional = true
36 |
37 | [tool.poetry.group.docs.dependencies]
38 | sphinx = ">=4.1.2"
39 | sphinx_mdinclude = ">= 0.5.0"
40 | sphinx-rtd-theme = ">= 2.0.0"
41 | sphinx-autodoc-typehints = ">= 1.11.0"
42 | sphinxcontrib-bibtex = ">= 2.6.0"
43 | sphinx-markdown-builder = ">= 0.6.6"
44 | sphinx-gallery = ">= 0.15.0"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/requirements.txt
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | dtype = torch.complex64
4 |
5 |
6 | def allclose(a, b, rtol=0.0, atol=0.0, ulptol=100):
7 | """Check if two tensors are close to each other.
8 |
9 | Condition: |a-b| <= atol + rtol * maximum(|b|,|a|) + ulptol * ulp
10 | Where ulp is the size of the smallest representable difference between two numbers of magnitude ~max(|b[...]|)
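
    Example (illustrative only):
        a = torch.ones(10, dtype=dtype)
        allclose(a, a + 1e-7)  # a few ulp apart at magnitude ~1 -> True
        allclose(a, a + 1e-3)  # far outside the default 100-ulp tolerance -> False (and prints the error)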
11 | """
12 |
13 | # make sure that a and b are tensors of the same dtype and device
14 | if not torch.is_tensor(a):
15 | a = torch.tensor(a, dtype=b.dtype)
16 | if not torch.is_tensor(b):
17 | b = torch.tensor(b, dtype=a.dtype)
18 | if a.dtype != b.dtype:
19 | a = a.type(b.dtype)
20 | if a.device != b.device:
21 | a = a.to('cpu')
22 | b = b.to('cpu')
23 | a = a.to_dense()
24 | b = b.to_dense()
25 |
26 | # compute the size of a single ULP
27 | ab_max = torch.maximum(a.abs(), b.abs())
28 | exponent = ab_max.max().log2().ceil().item()
29 | ulp = torch.finfo(b.dtype).eps * 2 ** exponent
30 | tolerance = atol + rtol * ab_max + ulptol * ulp
31 | diff = (a - b).abs()
32 |
33 | if (diff - tolerance).max() <= 0.0:
34 | return True
35 | else:
36 | abs_err = diff.max().item()
37 | rel_err = (diff / ab_max).max()
38 | print(f"\nabsolute error {abs_err} = {abs_err / ulp} ulp\nrelative error {rel_err}")
39 | return False
40 |
41 |
42 | def random_vector(n_size, device='cuda' if torch.cuda.is_available() else 'cpu', dtype=dtype):
43 | """Construct a random vector for testing operators"""
44 | return torch.randn(n_size, device=device, dtype=dtype) + 1.0j * torch.randn(n_size, device=device, dtype=dtype)
45 |
46 |
47 | def random_refractive_index(n_size, device='cuda' if torch.cuda.is_available() else 'cpu', dtype=dtype):
48 | """Construct a random refractive index between 1 and 2 with a small positive imaginary part
49 |
50 | The sign of the imaginary part is such that the imaginary part of n² is positive
51 | """
52 | n = (1.0 + torch.rand(n_size, device=device, dtype=dtype) +
53 | 0.1j * torch.rand(n_size, device=device, dtype=dtype))
54 |
55 | # make sure that the imaginary part of n² is positive
56 | mask = (n ** 2).imag < 0
57 | n.imag[mask] *= -1.0
58 | return n
59 |
--------------------------------------------------------------------------------
/tests/test_analytical.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from wavesim.helmholtzdomain import HelmholtzDomain
5 | from wavesim.multidomain import MultiDomain
6 | from wavesim.iteration import domain_operator, preconditioned_iteration, preconditioner, run_algorithm
7 | from wavesim.utilities import analytical_solution, preprocess, relative_error
8 | from . import allclose, random_vector, random_refractive_index
9 |
10 | """Tests to compare the result of Wavesim to analytical results"""
11 |
12 |
13 | def test_no_propagation():
14 | """Basic test where the L-component is zero
15 | By manually removing the laplacian, we are solving the equation (2 π n / λ)² x = y
16 | """
17 | n = random_refractive_index((2, 3, 4))
18 | domain = HelmholtzDomain(permittivity=(n ** 2), periodic=(True, True, True))
19 | x = random_vector(domain.shape)
20 |
21 | # manually disable the propagator, and test if, indeed, we are solving the system (2 π n / λ)² x = y
22 | L1 = 1.0 + domain.shift * domain.scale
23 | domain.propagator_kernel = 1.0 / L1
24 | domain.inverse_propagator_kernel = L1
25 | k2 = -(2 * torch.pi * n * domain.pixel_size) ** 2 # -(2 π n / λ)²
26 | B = (1.0 - (k2 - domain.shift) * domain.scale)
27 | assert allclose(domain_operator(domain, 'inverse_propagator')(x), x * L1)
28 | assert allclose(domain_operator(domain, 'propagator')(x), x / L1)
29 | assert allclose(domain_operator(domain, 'medium')(x), B * x)
30 |
31 | y = domain_operator(domain, 'forward')(x)
32 | assert allclose(y, k2 * x)
33 |
34 | domain.set_source(y)
35 | alpha = 0.75
36 | M = domain_operator(domain, 'richardson', alpha=alpha)
37 | x_wavesim = M(0)
38 | assert allclose(x_wavesim, (domain.scale * alpha / L1) * B * y)
39 |
40 | for _ in range(500):
41 | x_wavesim = M(x_wavesim)
42 |
43 | assert allclose(x_wavesim, x)
44 |
45 | x_wavesim = run_algorithm(domain, y, threshold=1.e-16)[0]
46 | assert allclose(x_wavesim, x)
47 |
48 |
49 | @pytest.mark.parametrize("size", [[32, 1, 1], [7, 15, 1], [13, 25, 46]])
50 | @pytest.mark.parametrize("boundary_widths", [0, 10])
51 | @pytest.mark.parametrize("periodic", [[True, True, True], # periodic boundaries, wrapped field
52 | [False, True, True]]) # wrapping correction
53 | def test_residual(size, boundary_widths, periodic):
54 | """ Check that the residual_norm at first iteration == 1
55 | residual_norm is normalized with the preconditioned source
56 | residual_norm = norm ( B(x - (L+1)⁻¹ (B·x + c·y)) )
57 | norm of preconditioned source = norm( B(L+1)⁻¹y )
58 | """
59 | torch.manual_seed(0) # Set the random seed for reproducibility
60 | n = (torch.normal(mean=1.3, std=0.1, size=size, dtype=torch.float32)
61 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=size, dtype=torch.float32))).numpy()
62 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
63 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n
64 |
65 |     indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(size)]]).T # Location: first pixel inside the boundary (corner of the ROI)
66 | values = torch.tensor([1.0]) # Amplitude: 1
67 | n_ext = tuple(np.array(size) + 2*boundary_array)
68 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
69 |
70 | wavelength = 1.
71 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
72 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=1., n_domains=n_domains)
73 |
74 | # Reset the field u to zero
75 | slot_x = 0
76 | slot_tmp = 1
77 | domain.clear(slot_x)
78 | domain.set_source(source)
79 |
80 | # compute initial residual
81 | domain.add_source(slot_x, weight=1.) # [x] = y
82 | preconditioner(domain, slot_x, slot_x) # [x] = B(L+1)⁻¹y
83 |     init_norm = domain.inner_product(slot_x, slot_x) # squared norm of the preconditioned source, ‖[x]‖²
84 | domain.clear(slot_x) # Clear [x]
85 |
86 | residual_norm = preconditioned_iteration(domain, slot_x, slot_x, slot_tmp, alpha=0.75, compute_norm2=True)
87 |
88 | assert np.allclose(residual_norm, init_norm)
89 |
90 |
91 | @pytest.mark.parametrize("n_domains, periodic", [
92 | ((1, 1, 1), (True, True, True)), # periodic boundaries, wrapped field.
93 | ((1, 1, 1), (False, True, True)), # wrapping correction (here and beyond)
94 | ((2, 1, 1), (False, True, True)),
95 | ((3, 1, 1), (False, True, True)),
96 | ])
97 | def test_1d_analytical(n_domains, periodic):
98 | """ Test for 1D free-space propagation. Compare with analytic solution """
99 | wavelength = 1. # wavelength in micrometer (um)
100 | pixel_size = wavelength / 4 # pixel size in wavelength units
101 | n_size = (512, 1, 1) # size of simulation domain (in pixels in x, y, and z direction)
102 | n = np.ones(n_size, dtype=np.complex64)
103 | boundary_widths = 16 # width of the boundary in pixels
104 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
105 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n
106 |
107 |     indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: first pixel inside the boundary (corner of the ROI)
108 | values = torch.tensor([1.0]) # Amplitude: 1
109 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
110 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
111 |
112 | # domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
113 | domain = MultiDomain(permittivity=n, periodic=periodic,
114 | wavelength=wavelength, pixel_size=pixel_size, n_domains=n_domains)
115 | u_computed = run_algorithm(domain, source, max_iterations=10000)[0]
116 | u_computed = u_computed.squeeze()[boundary_widths:-boundary_widths]
117 | u_ref = analytical_solution(n_size[0], domain.pixel_size, wavelength)
118 |
119 | re = relative_error(u_computed.cpu().numpy(), u_ref)
120 | print(f'Relative error: {re:.2e}')
121 |
122 | assert re <= 1.e-3, f'Relative error: {re:.2e}'
123 | assert allclose(u_computed, u_ref, atol=1.e-3, rtol=1.e-3)
124 |
--------------------------------------------------------------------------------
/tests/test_basics.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from wavesim.helmholtzdomain import HelmholtzDomain
4 | from wavesim.multidomain import MultiDomain
5 | from . import allclose, random_vector, random_refractive_index, dtype
6 |
7 | """ Performs a set of basic consistency checks for the Domain class and the HelmholtzBase multi-domain class. """
8 |
9 |
10 | def construct_domain(n_size, n_domains, n_boundary, periodic=(False, False, True)):
11 | """ Construct a domain or multi-domain"""
12 | torch.manual_seed(12345)
13 | n = random_refractive_index(n_size)
14 | if n_domains is None: # single domain
15 | return HelmholtzDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary, debug=True)
16 | else:
17 | return MultiDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary,
18 | n_domains=n_domains, debug=True)
19 |
20 |
21 | def construct_source(n_size):
22 | """ Construct a sparse-matrix source with some points at the corners and in the center"""
23 | locations = torch.tensor([
24 | [n_size[0] // 2, 0, 0],
25 | [n_size[1] // 2, 0, 0],
26 | [n_size[2] // 2, 0, n_size[2] - 1]])
27 |
28 | return torch.sparse_coo_tensor(locations, torch.tensor([1, 1, 1]), n_size, dtype=dtype)
29 |
30 |
31 | @pytest.mark.parametrize("n_size", [(128, 100, 93), (50, 49, 1)])
32 | @pytest.mark.parametrize("n_domains", [None, (1, 1, 1), (3, 2, 1)])
33 | def test_basics(n_size: tuple[int, int, int], n_domains: tuple[int, int, int] | None):
34 | """Tests the basic functionality of the Domain and MultiDomain classes
35 |
36 | Tests constructing a domain, subdividing data over subdomains,
37 | concatenating data from the subdomains, adding sparse sources,
38 | and computing the inner product.
39 | """
40 |
41 | # construct the (multi-) domain operator
42 | domain = construct_domain(n_size, n_domains, n_boundary=8)
43 |
44 | # test coordinates
45 | assert domain.shape == n_size
46 | for dim in range(3):
47 | coordinates = domain.coordinates(dim)
48 | assert coordinates.shape[dim] == n_size[dim]
49 | assert coordinates.numel() == n_size[dim]
50 |
51 | coordinates_f = domain.coordinates_f(dim)
52 | assert coordinates_f.shape == coordinates.shape
53 | assert coordinates_f[0, 0, 0] == 0
54 |
55 | if n_size[dim] > 1:
56 | assert allclose(coordinates.flatten()[1] - coordinates.flatten()[0], domain.pixel_size)
57 | assert allclose(coordinates_f.flatten()[1] - coordinates_f.flatten()[0],
58 | 2.0 * torch.pi / (n_size[dim] * domain.pixel_size))
59 |
60 | # construct a random vector for testing operators
61 | x = random_vector(n_size, device=domain.device)
62 | y = random_vector(n_size, device=domain.device)
63 |
64 | # perform some very basic checks
65 | # mainly, this tests if the partitioning and composition works correctly
66 | domain.set(0, x)
67 | domain.set(1, y)
68 | assert x.device == domain.device
69 | assert allclose(domain.get(0), x)
70 | assert allclose(domain.get(1), y)
71 |
72 | inp = domain.inner_product(0, 1)
73 | assert allclose(inp, torch.vdot(x.flatten(), y.flatten()))
74 |
75 | # construct a source and test adding it
76 | domain.clear(0)
77 | assert allclose(domain.get(0), 0.0)
78 | source = construct_source(n_size)
79 | domain.set_source(source)
80 | domain.add_source(0, 0.9)
81 | domain.add_source(0, 1.1)
82 | assert allclose(domain.get(0), 2.0 * source)
83 | x[0, 0, 0] = 1
84 | y[0, 0, 0] = 2
85 | # test mixing: α x + β y
86 | # make sure to include the special cases α=0, β=0, α=1, β=1 and α+β=1
87 | # since they may be optimized and thus have different code paths
88 | for alpha in [0.0, 1.0, 0.25, -0.1]:
89 | for beta in [0.0, 1.0, 0.75]:
90 | for out_slot in [0, 1]:
91 | domain.set(0, x)
92 | domain.set(1, y)
93 | domain.mix(alpha, 0, beta, 1, out_slot)
94 | assert allclose(domain.get(out_slot), alpha * x + beta * y)
95 |
96 |
97 | @pytest.mark.parametrize("n_size", [(128, 100, 93), (50, 49, 1)])
98 | @pytest.mark.parametrize("n_domains", [None, (1, 1, 1), (3, 2, 1)])
99 | def test_propagator(n_size: tuple[int, int, int], n_domains: tuple[int, int, int] | None):
100 | """Tests the forward and inverse propagator
101 |
102 | The wavesim algorithm only needs the propagator (L+1)^(-1) to be implemented.
103 | For testing, and for evaluating the final residue, the Domain and MultiDomain classes
104 | also implement the 'inverse propagator L+1', which is basically the homogeneous part of the forward operator A.
105 |
106 | This test checks that the forward and inverse propagator are consistent, namely (L+1)^(-1) (L+1) x = x.
107 | todo: check if the operators are actually correct (not just consistent)
108 | Note that the propagator is domain-local, so the wrapping correction and domain
109 | transfer functions are not tested here.
110 | """
111 |
112 | # construct the (multi-) domain operator
113 | domain = construct_domain(n_size, n_domains, n_boundary=8)
114 |
115 | # assert that (L+1) (L+1)^-1 x = x
116 | x = random_vector(n_size)
117 | domain.set(0, x)
118 | domain.propagator(0, 0)
119 | domain.inverse_propagator(0, 0)
120 | x_reconstructed = domain.get(0)
121 | assert allclose(x, x_reconstructed)
122 |
123 | # also assert that (L+1)^-1 (L+1) x = x, use different slots for input and output
124 | domain.set(0, x)
125 | domain.inverse_propagator(0, 1)
126 | domain.propagator(1, 1)
127 | x_reconstructed = domain.get(1)
128 | assert allclose(x, x_reconstructed)
129 |
130 | # for the non-decomposed case, test if the propagator gives the correct value
131 | if n_domains is None:
132 | n_size = torch.tensor(n_size, dtype=torch.float64)
133 | # choose |k| < Nyquist, make sure k is at exact grid point in Fourier space
134 | k_relative = torch.tensor((0.2, -0.15, 0.4), dtype=torch.float64)
135 | k = 2 * torch.pi * torch.round(k_relative * n_size) / n_size # in 1/pixels
136 | k[n_size == 1] = 0.0
137 | plane_wave = torch.exp(1j * (
138 | k[0] * torch.arange(n_size[0], device=domain.device).reshape(-1, 1, 1) +
139 | k[1] * torch.arange(n_size[1], device=domain.device).reshape(1, -1, 1) +
140 | k[2] * torch.arange(n_size[2], device=domain.device).reshape(1, 1, -1)))
141 | domain.set(0, plane_wave)
142 | domain.inverse_propagator(0, 0)
143 | result = domain.get(0)
144 | laplace_kernel = (k[0]**2 + k[1]**2 + k[2]**2) / domain.pixel_size ** 2 # -∇² [negative of laplace kernel]
145 | correct_result = (1.0 + domain.scale * (laplace_kernel + domain.shift)) * plane_wave # L+1 = scale·(-∇²) + 1.
146 | # note: the result is not exactly the same because wavesim is using the real-space kernel, and we compare to
147 | # the Fourier-space kernel
148 | assert allclose(result, correct_result, rtol=0.01)
149 |
150 |
151 | def test_basic_wrapping():
152 | """Simple test if the wrapping correction is applied at the correct position.
153 |
154 | Constructs a 1-D domain and splits it in two. A source is placed at the right edge of the left domain.
155 | """
156 | n_size = (10, 1, 1)
157 | n_boundary = 2
158 | source = torch.sparse_coo_tensor(torch.tensor([[(n_size[0] - 1) // 2, 0, 0]]).T, torch.tensor([1.0]), n_size,
159 | dtype=dtype)
160 | domain = MultiDomain(permittivity=torch.ones(n_size, dtype=dtype), n_domains=(2, 1, 1),
161 | n_boundary=n_boundary, periodic=(False, True, True))
162 | domain.clear(0)
163 | domain.set_source(source)
164 | domain.add_source(0, 1.0)
165 | left = torch.squeeze(domain.domains[0, 0, 0].get(0))
166 | right = torch.squeeze(domain.domains[1, 0, 0].get(0))
167 | total = torch.squeeze(domain.get(0))
168 | assert allclose(torch.concat([left.to(domain.device), right.to(domain.device)]), total)
169 | assert torch.all(right == 0.0)
170 | assert torch.all(left[:-2] == 0.0)
171 | assert left[-1] != 0.0
172 |
173 | domain.medium(0, 1)
174 |
175 | # periodic in 2nd and 3rd dimension: no edges
176 | left_edges = domain.domains[0, 0, 0].edges
177 | right_edges = domain.domains[1, 0, 0].edges
178 | for edge in range(2, 6):
179 | assert left_edges[edge] is None
180 | assert right_edges[edge] is None
181 |
182 | # right domain should have zero edge corrections (since domain is empty)
183 | assert torch.all(right_edges[0] == 0.0)
184 | assert torch.all(right_edges[1] == 0.0)
185 |
186 | # left domain should have wrapping correction at the right edge
187 | # and nothing at the left edge
188 | assert torch.all(left_edges[0] == 0.0)
189 | assert left_edges[1].abs().max() > 1e-3
190 |
191 |     # after applying the correction, the first n_boundary elements of the left domain
192 |     # should be non-zero (wrapping correction), and the first n_boundary elements of the
193 |     # right domain should be non-zero (transfer correction), with opposite signs
194 | total2 = torch.squeeze(domain.get(1))
195 | assert allclose(total2.real[0:n_boundary], -total2.real[n_size[0] // 2:n_size[0] // 2 + n_boundary])
196 |
197 |
198 | def test_wrapped_propagator():
199 | """Tests the inverse propagator L+1 with wrapping corrections
200 |
201 | This test compares the situation of a single large domain to that of a multi-domain.
202 | If the wrapping and transfer corrections are implemented correctly, the results should be the same
203 | up to the difference in scaling factor.
204 | """
205 | # n_size = (128, 100, 93, 1)
206 | n_size = (3 * 32 * 1024, 1, 1)
207 | n_boundary = 16
208 | domain_single = construct_domain(n_size, n_domains=None, n_boundary=n_boundary, periodic=(True, True, True))
209 | domain_multi = construct_domain(n_size, n_domains=(3, 1, 1), n_boundary=n_boundary, periodic=(True, True, True))
210 | source = torch.sparse_coo_tensor(torch.tensor([[0, 0, 0]]).T, torch.tensor([1.0]), n_size, dtype=dtype)
211 |
212 | x = [None, None]
213 | for i, domain in enumerate([domain_single, domain_multi]):
214 | # evaluate L+1-B = L + Vscat + wrapping correction for the multi-domain,
215 | # and L+1-B = L + Vscat for the full domain
216 | # Note that we need to compensate for scaling squared,
217 | # because scaling affects both the source and operators L and B
218 | B = 0
219 | L1 = 1
220 | domain.clear(0)
221 | domain.set_source(source)
222 | domain.add_source(0, 1.0)
223 | domain.inverse_propagator(0, L1) # (L+1) y
224 | domain.medium(0, B) # (1-V) y
225 | domain.mix(1.0, L1, -1.0, B, 0) # (L+V) y
226 | x[i] = domain.get(0) / domain.scale
227 |
228 | # first non-compensated point
229 | pos = domain_multi.domains[0].shape[0] - n_boundary - 1
230 | atol = x[0][pos, 0, 0].abs().item()
231 | assert allclose(x[0], x[1], atol=atol)
232 |
--------------------------------------------------------------------------------
/tests/test_examples.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 | import torch
4 | import numpy as np
5 | from scipy.io import loadmat
6 | from PIL.Image import BILINEAR, fromarray, open
7 | from wavesim.helmholtzdomain import HelmholtzDomain
8 | from wavesim.multidomain import MultiDomain
9 | from wavesim.iteration import run_algorithm
10 | from wavesim.utilities import pad_boundaries, preprocess, relative_error
11 |
12 | if os.path.basename(os.getcwd()) == 'tests':
13 | os.chdir('..')
14 |
15 |
16 | @pytest.mark.parametrize("n_domains, periodic", [
17 | (None, (True, True, True)), # periodic boundaries, wrapped field.
18 | ((1, 1, 1), (False, True, True)), # wrapping correction (here and beyond)
19 | ((2, 1, 1), (False, True, True)),
20 | ((3, 1, 1), (False, True, True)),
21 | ])
22 | def test_1d_glass_plate(n_domains, periodic):
23 | """ Test for 1D propagation through glass plate. Compare with reference solution (matlab repo result). """
24 | wavelength = 1.
25 | n_size = (256, 1, 1)
26 | n = np.ones(n_size, dtype=np.complex64)
27 | n[99:130] = 1.5
28 | boundary_widths = 24
29 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
30 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n
31 |
32 |     indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T  # Location: first pixel of the ROI
33 | values = torch.tensor([1.0]) # Amplitude: 1
34 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
35 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
36 |
37 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction)
38 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
39 | else: # OR. Domain decomposition
40 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=n_domains)
41 |
42 | u_computed = run_algorithm(domain, source, max_iterations=2000)[0]
43 | u_computed = u_computed.squeeze()[boundary_widths:-boundary_widths]
44 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
45 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u'])
46 |
47 | re = relative_error(u_computed.cpu().numpy(), u_ref)
48 | print(f'Relative error: {re:.2e}')
49 | threshold = 1.e-3
50 | assert re < threshold, f"Relative error higher than {threshold}"
51 |
52 |
53 | @pytest.mark.parametrize("n_domains", [
54 | None, # periodic boundaries, wrapped field.
55 | (1, 1, 1), # wrapping correction (here and beyond)
56 | (2, 1, 1),
57 | (3, 1, 1),
58 | (1, 2, 1),
59 | (1, 3, 1),
60 | (2, 2, 1),
61 | ])
62 | def test_2d_low_contrast(n_domains):
63 | """ Test for propagation in 2D structure with low refractive index contrast (made of fat and water to mimic
64 | biological tissue). Compare with reference solution (matlab repo result). """
65 | oversampling = 0.25
66 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255
67 | n_water = 1.33
68 | n_fat = 1.46
69 | n_im = (np.where(im[:, :, 2] > 0.25, 1, 0) * (n_fat - n_water)) + n_water
70 | n_roi = int(oversampling * n_im.shape[0])
71 | n = np.asarray(fromarray(n_im).resize((n_roi, n_roi), BILINEAR))
72 | boundary_widths = 40
73 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
74 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n
75 |
76 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR))
77 | source = pad_boundaries(source, boundary_array)
78 | source = torch.tensor(source, dtype=torch.complex64)
79 |
80 | wavelength = 0.532
81 | pixel_size = wavelength / (3 * abs(n_fat))
82 |
83 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction)
84 | periodic = (True, True, True) # periodic boundaries, wrapped field.
85 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength)
86 | else: # OR. Domain decomposition
87 | periodic = np.where(np.array(n_domains) == 1, True, False) # True for 1 domain in direction, False otherwise
88 | periodic = tuple(periodic)
89 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size,
90 | n_domains=n_domains)
91 |
92 | u_computed = run_algorithm(domain, source, max_iterations=10000)[0]
93 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*2)]
94 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
95 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_lc'])
96 |
97 | re = relative_error(u_computed.cpu().numpy(), u_ref)
98 | print(f'Relative error: {re:.2e}')
99 | threshold = 1.e-3
100 | assert re < threshold, f"Relative error higher than {threshold}"
101 |
102 |
103 | @pytest.mark.parametrize("n_domains", [
104 | None, # periodic boundaries, wrapped field.
105 | (1, 1, 1), # wrapping correction (here and beyond)
106 | (1, 2, 1),
107 | ])
108 | def test_2d_high_contrast(n_domains):
109 | """ Test for propagation in 2D structure made of iron, with high refractive index contrast.
110 | Compare with reference solution (matlab repo result). """
111 |
112 | oversampling = 0.25
113 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255
114 | n_iron = 2.8954 + 2.9179j
115 | n_contrast = n_iron - 1
116 | n_im = ((np.where(im[:, :, 2] > 0.25, 1, 0) * n_contrast) + 1)
117 | n_roi = int(oversampling * n_im.shape[0])
118 | n = np.asarray(fromarray(n_im.real).resize((n_roi, n_roi), BILINEAR)) + 1j * np.asarray(
119 | fromarray(n_im.imag).resize((n_roi, n_roi), BILINEAR))
120 | boundary_widths = 8
121 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
122 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n
123 |
124 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR))
125 | source = pad_boundaries(source, boundary_array)
126 | source = torch.tensor(source, dtype=torch.complex64)
127 |
128 | wavelength = 0.532
129 | pixel_size = wavelength / (3 * np.max(abs(n_contrast + 1)))
130 |
131 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction)
132 | periodic = (True, True, True) # periodic boundaries, wrapped field.
133 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength)
134 | else: # OR. Domain decomposition
135 | periodic = np.where(np.array(n_domains) == 1, True, False) # True for 1 domain in direction, False otherwise
136 | periodic = tuple(periodic)
137 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size,
138 | n_domains=n_domains)
139 |
140 | u_computed = run_algorithm(domain, source, max_iterations=int(1.e+5))[0]
141 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*2)]
142 |
143 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
144 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_hc'])
145 |
146 | re = relative_error(u_computed.cpu().numpy(), u_ref)
147 | print(f'Relative error: {re:.2e}')
148 | threshold = 1.e-3
149 | assert re < threshold, f"Relative error {re} higher than {threshold}"
150 |
151 |
152 | @pytest.mark.parametrize("n_domains", [
153 | None, # periodic boundaries, wrapped field.
154 | (1, 1, 1), # wrapping correction (here and beyond)
155 | (2, 1, 1),
156 | (3, 1, 1),
157 | (1, 2, 1),
158 | (1, 3, 1),
159 | (1, 1, 2),
160 | (2, 2, 1),
161 | (2, 1, 2),
162 | (1, 2, 2),
163 | (2, 2, 2),
164 | (3, 2, 1),
165 | (3, 1, 2),
166 | (1, 3, 2),
167 | (1, 2, 3),
168 | ])
169 | def test_3d_disordered(n_domains):
170 | """ Test for propagation in a 3D disordered medium. Compare with reference solution (matlab repo result). """
171 | wavelength = 1.
172 | n_size = (128, 48, 96)
173 | n = np.ascontiguousarray(loadmat('examples/matlab_results.mat')['n3d_disordered'])
174 | boundary_widths = 12
175 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2)
176 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n
177 |
178 | # Source: single point source in the center of the domain
179 | indices = torch.tensor([[int(v/2 - 1) + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location
180 | values = torch.tensor([1.0]) # Amplitude: 1
181 | n_ext = tuple(np.array(n_size) + 2*boundary_array)
182 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64)
183 |
184 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction)
185 | periodic = (True, True, True) # periodic boundaries, wrapped field.
186 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength)
187 | else: # OR. Domain decomposition
188 | periodic = np.where(np.array(n_domains) == 1, True, False) # True for 1 domain in direction, False otherwise
189 | periodic = tuple(periodic)
190 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=n_domains)
191 |
192 | u_computed = run_algorithm(domain, source, max_iterations=1000)[0]
193 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*3)]
194 |
195 | # load dictionary of results from matlab wavesim/anysim for comparison and validation
196 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u3d_disordered'])
197 |
198 | re = relative_error(u_computed.cpu().numpy(), u_ref)
199 | print(f'Relative error: {re:.2e}')
200 | threshold = 1.e-3
201 | assert re < threshold, f"Relative error {re} higher than {threshold}"
202 |
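# A small standalone sketch (not one of the tests above; sizes and widths are illustrative)
# of the boundary convention used throughout this file: preprocess() pads the permittivity
# with absorbing boundaries, the source must be defined on the padded grid, and the computed
# field is cropped back to the region of interest afterwards.
if __name__ == "__main__":
    n = np.ones((32, 1, 1), dtype=np.complex64)                 # refractive index of the region of interest
    permittivity, boundary_array = preprocess(n ** 2, boundary_widths=5)
    print(permittivity.shape, boundary_array)                   # (42, 1, 1) and [5 0 0]: padding along axis 0 only
    u_ext = np.zeros(permittivity.shape, dtype=np.complex64)    # stand-in for a field computed on the padded grid
    u_roi = u_ext[boundary_array[0]:-boundary_array[0]]         # crop back to the 32-point region of interest
    print(u_roi.shape)                                          # (32, 1, 1)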
--------------------------------------------------------------------------------
/tests/test_operators.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from wavesim.helmholtzdomain import HelmholtzDomain
4 | from wavesim.multidomain import MultiDomain
5 | from wavesim.iteration import domain_operator
6 | from wavesim.utilities import full_matrix
7 | from . import random_vector, allclose, dtype
8 |
9 | """ Performs checks on the operators represented as matrices (accretivity, norm)."""
10 |
11 | parameters = [
12 | {'n_size': (1, 1, 12), 'n_domains': None, 'n_boundary': 0, 'periodic': (False, False, True)},
13 | {'n_size': (1, 1, 12), 'n_domains': (1, 1, 1), 'n_boundary': 0, 'periodic': (False, False, True)},
14 | {'n_size': (1, 1, 12), 'n_domains': (1, 1, 1), 'n_boundary': 5, 'periodic': (True, True, False)},
15 | {'n_size': (1, 1, 32), 'n_domains': (1, 1, 2), 'n_boundary': 5, 'periodic': (True, True, True)},
16 | {'n_size': (1, 1, 32), 'n_domains': (1, 1, 2), 'n_boundary': 5, 'periodic': (True, True, False)},
17 | {'n_size': (1, 32, 1), 'n_domains': (1, 2, 1), 'n_boundary': 5, 'periodic': (True, True, True)},
18 | {'n_size': (1, 32, 1), 'n_domains': (1, 2, 1), 'n_boundary': 5, 'periodic': (True, False, True)},
19 | {'n_size': (32, 1, 1), 'n_domains': (2, 1, 1), 'n_boundary': 5, 'periodic': (False, True, True)},
20 | {'n_size': (32, 1, 1), 'n_domains': (2, 1, 1), 'n_boundary': 5, 'periodic': (True, True, True)},
21 | # test different-sized domains
22 | {'n_size': (23, 1, 1), 'n_domains': (3, 1, 1), 'n_boundary': 2, 'periodic': (True, True, True)},
23 | {'n_size': (23, 1, 1), 'n_domains': (3, 1, 1), 'n_boundary': 3, 'periodic': (False, True, True)},
24 | {'n_size': (1, 5, 19), 'n_domains': (1, 1, 2), 'n_boundary': 3, 'periodic': (True, True, True)},
25 | {'n_size': (1, 14, 19), 'n_domains': (1, 2, 2), 'n_boundary': 3, 'periodic': (True, False, True)},
26 | {'n_size': (17, 30, 1), 'n_domains': (2, 3, 1), 'n_boundary': 3, 'periodic': (True, True, True)},
27 | {'n_size': (8, 8, 8), 'n_domains': (2, 2, 2), 'n_boundary': 2, 'periodic': (False, False, True)},
28 | {'n_size': (8, 12, 8), 'n_domains': (2, 3, 2), 'n_boundary': 2, 'periodic': (True, True, True)},
29 | # these parameters are very slow for test_accretivity and should be run only when needed
30 | # {'n_size': (12, 12, 12), 'n_domains': (2, 2, 2), 'n_boundary': 3, 'periodic': (False, False, True)},
31 | # {'n_size': (18, 24, 18), 'n_domains': (2, 3, 2), 'n_boundary': 3, 'periodic': (True, True, True)},
32 | # {'n_size': (17, 23, 19), 'n_domains': (2, 3, 2), 'n_boundary': 3, 'periodic': (True, True, True)},
33 | ]
34 |
35 |
36 | def construct_domain(n_size, n_domains, n_boundary, periodic=(False, False, True)):
37 | """ Construct a domain or multi-domain"""
38 | torch.manual_seed(12345)
39 | n = torch.rand(n_size, device='cuda' if torch.cuda.is_available() else 'cpu',
40 | dtype=dtype) + 1.0 # random refractive index between 1 and 2
41 | n.imag = 0.1 * torch.maximum(n.imag, torch.tensor(0.0)) # a positive imaginary part of n corresponds to absorption
42 | if n_domains is None: # single domain
43 | return HelmholtzDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary, debug=True)
44 | else:
45 | return MultiDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary,
46 | n_domains=n_domains, debug=True)
47 |
48 |
49 | @pytest.mark.parametrize("params", parameters)
50 | def test_operators(params):
51 | """ Check that operator definitions are consistent:
52 | - forward = inverse_propagator - medium: A= L + 1 - B
53 | - preconditioned_operator = preconditioned(operator)
54 | - richardson = x + α (Γ⁻¹b - Γ⁻¹A x)
55 | """
56 | domain = construct_domain(**params)
57 | x = random_vector(domain.shape, device=domain.device)
58 | B = domain_operator(domain, 'medium')
59 | L1 = domain_operator(domain, 'inverse_propagator')
60 | A = domain_operator(domain, 'forward')
61 | Ax = A(x)
62 | assert allclose(domain.scale * Ax, L1(x) - B(x))
63 |
64 | Γ = domain_operator(domain, 'preconditioner')
65 | ΓA = domain_operator(domain, 'preconditioned_operator')
66 | ΓAx = ΓA(x)
67 | assert allclose(ΓAx, Γ(Ax))
68 |
69 | α = 0.1
70 | b = random_vector(domain.shape)
71 | Γb = Γ(b)
72 | domain.set_source(b)
73 | M = domain_operator(domain, 'richardson', alpha=α)
74 | assert allclose(M(0), α * Γb)
75 |
76 | residual = Γb - ΓAx
77 | assert allclose(M(x), x.to(residual.device) + α * residual)
78 |
79 |
80 | @pytest.mark.parametrize("params", parameters)
81 | def test_accretivity(params):
82 | """ Checks norm and lower bound of real part for various operators
83 |
84 | B (medium) should have real part between -0.05 and 1.0 (if we don't put the absorption in V0. If we do, the upper
85 | limit may be 1.95)
86 | The operator B-1 should have a norm of less than 0.95
87 |
88 | L + 1 (inverse propagator) should be accretive with a real part of at least 1.0
89 | (L+1)^-1 (propagator) should be accretive with a real part of at least 0.0
90 | A (forward) should be accretive with a real part of at least 0.0
91 | ΓA (preconditioned_operator) should be such that 1-ΓA is a contraction (a norm of less than 1.0)
92 | """
93 | domain = construct_domain(**params)
94 | domain.set_source(0)
95 | assert_accretive(domain_operator(domain, 'medium'), 'B', real_min=0.05, real_max=1.0, norm_max=0.95,
96 | norm_offset=1.0)
97 | assert_accretive(domain_operator(domain, 'inverse_propagator'), 'L + 1', real_min=1.0)
98 | assert_accretive(domain_operator(domain, 'propagator'), '(L + 1)^-1', real_min=0.0)
99 | assert_accretive(domain_operator(domain, 'preconditioned_operator'), 'ΓA', norm_max=1.0, norm_offset=1.0)
100 | assert_accretive(domain_operator(domain, 'richardson', alpha=0.75), '1- α ΓA', norm_max=1.0)
101 | assert_accretive(domain_operator(domain, 'forward'), 'A', real_min=0.0, pre_factor=domain.scale)
102 |
103 |
104 | def assert_accretive(operator, name, *, real_min=None, real_max=None, norm_max=None, norm_offset=None, pre_factor=None):
105 | """ Helper function to check if an operator is accretive, and to compute the norm around a given offset.
106 | This function constructs a full matrix from the operator, so it only works if the domain is not too large.
107 | """
108 | M = full_matrix(operator)
109 | if pre_factor is not None:
110 | M *= pre_factor
111 |
112 | if norm_max is not None:
113 | if norm_offset is not None:
114 | M.diagonal().add_(-norm_offset)
115 | norm = torch.linalg.norm(M, ord=2)
116 | print(f'norm {norm:.2e}')
117 | assert norm <= norm_max, f'operator {name} has norm {norm} > {norm_max}'
118 |
119 | if real_min is not None or real_max is not None:
120 | M.add_(M.mH)
121 | eigs = torch.linalg.eigvalsh(M)
122 | if norm_offset is not None:
123 | eigs.add_(norm_offset)
124 | if real_min is not None:
125 | acc = eigs.min()
126 | print(f'acc {acc:.2e}')
127 | assert acc >= real_min, f'operator {name} is not accretive, min λ_(A+A*) = {acc} < {real_min}'
128 | if real_max is not None:
129 | acc = eigs.max()
130 | print(f'acc {acc:.2e}')
131 | assert acc <= real_max, (f'operator {name} has eigenvalues that are too large, '
132 | f'max λ_(A+A*) = {acc} > {real_max}')
133 |
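# A small standalone sketch (not a test; parameters are illustrative): build a tiny single
# domain, assemble the matrix of the preconditioned operator ΓA with full_matrix, and print
# the norm of ΓA - 1, which the contraction argument in test_accretivity requires not to
# exceed 1.
if __name__ == "__main__":
    domain = construct_domain((1, 1, 12), n_domains=None, n_boundary=0, periodic=(False, False, True))
    domain.set_source(0)
    M = full_matrix(domain_operator(domain, 'preconditioned_operator'))
    M.diagonal().add_(-1.0)             # ΓA - 1, which has the same norm as 1 - ΓA
    print(torch.linalg.norm(M, ord=2))  # expected to be at most 1 for this configuration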
--------------------------------------------------------------------------------
/tests/test_utilities.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from wavesim.utilities import partition, combine
4 | from . import allclose
5 |
6 | """Test of the utility functions."""
7 |
8 |
9 | @pytest.mark.parametrize("size", [(5, 4, 6), (7, 15, 32), (3, 5, 6)])
10 | @pytest.mark.parametrize("n_domains", [(1, 2, 3), (3, 3, 3), (2, 4, 1)])
11 | @pytest.mark.parametrize("type", ['full', 'sparse', 'hybrid1', 'hybrid2'])
12 | @pytest.mark.parametrize("expanded", [False, True])
13 | def test_partition_combine(size: tuple[int, int, int], n_domains: tuple[int, int, int], type: str, expanded: bool):
14 | if expanded:
15 | x = torch.tensor(1.0, dtype=torch.complex64).expand(size)
16 | else:
17 | x = torch.randn(size, dtype=torch.complex64) + 1j * torch.randn(size, dtype=torch.complex64)
18 | if type == 'sparse':
19 | x[x.real < 0.5] = 0
20 | x = x.to_sparse()
21 | elif type == 'hybrid1':
22 | # select half of the slices, make rest zero
23 |         indices = torch.arange(0, size[0], 2, dtype=torch.int64)  # indices of the (even) slices that are kept
24 | values = x[0::2, :, :]
25 | x = torch.sparse_coo_tensor(indices.reshape(1, -1), values, size)
26 | elif type == 'hybrid2':
27 |         indices0 = torch.arange(0, size[0], 2, dtype=torch.int64)
28 |         indices1 = torch.arange(0, size[1], 2, dtype=torch.int64)
29 |         i0, i1 = torch.meshgrid(indices0, indices1, indexing='ij')
30 | indices = torch.stack((i0.reshape(-1), i1.reshape(-1)), dim=0)
31 | values = x[0::2, 0::2, :].reshape(-1, x.shape[2])
32 | x = torch.sparse_coo_tensor(indices, values, size)
33 |
34 | partitions = partition(x, n_domains)
35 | assert partitions.shape == n_domains
36 |
37 | combined = combine(partitions)
38 | assert allclose(combined, x)
39 |
40 |
41 | @pytest.mark.parametrize("n_domains", [(0, 0, 0), (4, 3, 3)])
42 | def test_partition_with_invalid_input(n_domains):
43 | array = torch.randn((3, 3, 3), dtype=torch.complex64)
44 | with pytest.raises(ValueError):
45 | partition(array, n_domains)
46 |
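# A standalone sketch (sizes are illustrative) of the 'hybrid' layout constructed above:
# a COO tensor that is sparse along the first dimension only, with dense 2-d values, as
# produced by torch.sparse_coo_tensor when the values carry extra trailing dimensions.
if __name__ == "__main__":
    size = (5, 4, 6)
    dense = torch.randn(size, dtype=torch.complex64)
    idx = torch.arange(0, size[0], 2, dtype=torch.int64)         # keep only the even slices
    hybrid = torch.sparse_coo_tensor(idx.reshape(1, -1), dense[0::2], size)
    print(hybrid.sparse_dim(), hybrid.dense_dim())               # 1 sparse dimension, 2 dense dimensions
    print(torch.allclose(hybrid.to_dense()[0::2], dense[0::2]))  # True: the even slices are preserved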
--------------------------------------------------------------------------------
/wavesim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/wavesim/__init__.py
--------------------------------------------------------------------------------
/wavesim/domain.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class Domain(metaclass=ABCMeta):
6 | """Base class for all simulation domains
7 |
8 | This base class defines the interface for operations that are common for all simulation types,
9 | and for MultiDomain.
10 | todo: the design using slots minimizes memory use, but it is a suboptimal design because it
11 | mixes mutable and immutable state. This design should be revisited so that the Domain is
12 | immutable, and the code that runs the algorithms performs the memory management.
13 | """
14 |
15 | def __init__(self, pixel_size: float, shape, device):
16 | self.pixel_size = pixel_size
17 | self.scale = None
18 | self.shift = None
19 | self.shape = shape
20 | self.device = device
21 |
22 | @abstractmethod
23 | def add_source(self, slot, weight: float):
24 | pass
25 |
26 | @abstractmethod
27 | def clear(self, slot):
28 | """Clears the data in the specified slot"""
29 | pass
30 |
31 | @abstractmethod
32 | def get(self, slot: int, copy=False):
33 | """Returns the data in the specified slot.
34 |
35 | :param slot: slot from which to return the data
36 |         :param copy: if True, returns a copy of the data. Otherwise, may return the original data itself.
37 |
38 | Note that this data may be overwritten by the next call to domain.
39 | """
40 | pass
41 |
42 | @abstractmethod
43 | def set(self, slot, data):
44 |         """Copy the data into the specified slot"""
45 | pass
46 |
47 | @abstractmethod
48 | def inner_product(self, slot_a, slot_b):
49 | """Computes the inner product of two data vectors
50 |
51 | Note:
52 | The vectors may be represented as multidimensional arrays,
53 | but these arrays must be contiguous for this operation to work.
54 | Although it would be possible to use flatten(), this would create a
55 | copy when the array is not contiguous, causing a hidden performance hit.
56 | """
57 | pass
58 |
59 | @abstractmethod
60 | def medium(self, slot_in, slot_out, mnum):
61 | """Applies the operator 1-Vscat."""
62 | pass
63 |
64 | @abstractmethod
65 | def mix(self, weight_a, slot_a, weight_b, slot_b, slot_out):
66 | """Mixes two data arrays and stores the result in the specified slot"""
67 | pass
68 |
69 | @abstractmethod
70 | def propagator(self, slot_in, slot_out):
71 | """Applies the operator (L+1)^-1 x.
72 | """
73 | pass
74 |
75 | @abstractmethod
76 | def inverse_propagator(self, slot_in, slot_out):
77 | """Applies the operator (L+1) x .
78 |
79 | This operation is not needed for the Wavesim algorithm, but is provided for testing purposes,
80 | and can be used to evaluate the residue of the solution.
81 | """
82 | pass
83 |
84 | @abstractmethod
85 | def set_source(self, source):
86 | """Sets the source term for this domain."""
87 | pass
88 |
89 | @abstractmethod
90 | def create_empty_vdot(self):
91 | """Create an empty tensor for the Vdot tensor"""
92 | pass
93 |
94 | def coordinates_f(self, dim):
95 | """Returns the Fourier-space coordinates along the specified dimension"""
96 | shapes = [[-1, 1, 1], [1, -1, 1], [1, 1, -1]]
97 | return (2 * torch.pi * torch.fft.fftfreq(self.shape[dim], self.pixel_size, device=self.device,
98 | dtype=torch.float64)).reshape(shapes[dim]).to(torch.complex64)
99 |
100 | def coordinates(self, dim, type: str = 'linear'):
101 | """Returns the real-space coordinates along the specified dimension, starting at 0"""
102 | shapes = [[-1, 1, 1], [1, -1, 1], [1, 1, -1]]
103 | x = torch.arange(self.shape[dim], device=self.device, dtype=torch.float64) * self.pixel_size
104 | if type == 'periodic':
105 | x -= self.pixel_size * (self.shape[dim] // 2)
106 | x = torch.fft.ifftshift(x) # todo: or fftshift?
107 | elif type == 'centered':
108 | x -= self.pixel_size * (self.shape[dim] // 2)
109 | elif type == 'linear':
110 | pass
111 | else:
112 | raise ValueError(f"Unknown type {type}")
113 |
114 | return x.reshape(shapes[dim])
115 |
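# A standalone sketch (not part of the class; the 8-point dimension and pixel_size 0.25 are
# hypothetical): what coordinates() and coordinates_f() produce, written out with plain torch
# calls that mirror the formulas above.
if __name__ == "__main__":
    n, pixel_size = 8, 0.25
    linear = torch.arange(n, dtype=torch.float64) * pixel_size      # 0.00, 0.25, ..., 1.75
    centered = linear - pixel_size * (n // 2)                       # -1.00 ... 0.75, with 0 in the middle
    periodic = torch.fft.ifftshift(centered)                        # 0.00 ... 0.75, -1.00 ... -0.25
    freq = 2 * torch.pi * torch.fft.fftfreq(n, pixel_size, dtype=torch.float64)  # Fourier-space coordinates
    print(linear, centered, periodic, freq, sep="\n")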
--------------------------------------------------------------------------------
/wavesim/iteration.py:
--------------------------------------------------------------------------------
1 | from .domain import Domain
2 | from .utilities import is_zero
3 |
4 |
5 | def run_algorithm(domain: Domain, source, alpha=0.75, max_iterations=1000, threshold=1.e-6, full_residuals=False):
6 | """ WaveSim update
7 |
8 | :param domain: Helmholtz base parameters
9 | :param source: source field
10 | :param alpha: relaxation parameter for the Richardson iteration
11 | :param max_iterations: maximum number of iterations
12 | :param threshold: threshold for the residual norm
13 | :param full_residuals: when True, returns list of residuals for all iterations. Otherwise, returns final residual
14 | :return: u, iteration count, residuals """
15 |
16 | # Reset the field u to zero
17 | slot_x = 0
18 | slot_tmp = 1
19 | domain.clear(slot_x)
20 | domain.set_source(source)
21 |
22 | # compute initial residual norm (with preconditioned source) for normalization
23 | domain.create_empty_vdot() # create empty slot for Vdot tensor. Always 8.1 MiB
24 | domain.add_source(slot_x, weight=1.) # [x] = y
25 | preconditioner(domain, slot_x, slot_x) # [x] = B(L+1)⁻¹y
26 | init_norm_inv = 1 / domain.inner_product(slot_x, slot_x) # inverse of initial norm: 1 / norm([x])
27 | domain.clear(slot_x) # Clear [x]
28 |
29 | # save list of residuals if requested
30 | residuals = [] if full_residuals else None
31 |
32 | for i in range(max_iterations):
33 | residual_norm = preconditioned_iteration(domain, slot_x, slot_x, slot_tmp, alpha, compute_norm2=True)
34 | # normalize residual norm with preconditioned source (i.e., with norm of B(L+1)⁻¹y)
35 | residual_norm = residual_norm * init_norm_inv # norm(B(x - (L+1)⁻¹ (B·x + c·y))) / norm(B(L+1)⁻¹y)
36 | print('.', end='', flush=True) if (i + 1) % 100 == 0 else None
37 | residuals.append(residual_norm) if full_residuals else None
38 | if residual_norm < threshold:
39 | break
40 |
41 | # return u and u_iter cropped to roi, residual arrays, and state object with information on run
42 | return domain.get(slot_x), (i + 1), residuals if full_residuals else residual_norm
43 |
44 |
45 | def preconditioned_iteration(domain, slot_in: int = 0, slot_out: int = 0, slot_tmp: int = 1, alpha=0.75,
46 | compute_norm2=False):
47 | """ Run one preconditioned iteration.
48 |
49 | Args:
50 | domain: Domain object
51 | slot_in: slot holding input x. This slot will be overwritten!
52 | slot_out: output slot that will receive the result
53 | slot_tmp: slot for temporary storage. Cannot be equal to slot_in, may be equal to slot_out
54 | alpha: relaxation parameter for the Richardson iteration
55 | compute_norm2: when True, returns the squared norm of the residual. Otherwise, returns 0.0
56 |
57 | Richardson iteration:
58 | x -> x + α (y - A x)
59 |
60 | Preconditioned Richardson iteration:
61 | x -> x + α Γ⁻¹ (y - A x)
62 | = x + α c B (L+1)⁻¹ (y - A x)
63 | = x + α c B (L+1)⁻¹ (y - c⁻¹ [L+V] x)
64 | = x + α c B (L+1)⁻¹ (y + c⁻¹ [1-V] x - c⁻¹ [L+1] x)
65 | = x + α B [(L+1)⁻¹ (c y + B x) - x]
66 | = x - α B x + α B (L+1)⁻¹ (c y + B x)
67 | """
68 | if slot_tmp == slot_in:
69 | raise ValueError("slot_in and slot_tmp should be different")
70 |
71 | domain.medium(slot_in, slot_tmp, mnum=0) # [tmp] = B·x
72 | domain.add_source(slot_tmp, domain.scale) # [tmp] = B·x + c·y
73 | domain.propagator(slot_tmp, slot_tmp) # [tmp] = (L+1)⁻¹ (B·x + c·y)
74 | domain.mix(1.0, slot_in, -1.0, slot_tmp, slot_tmp) # [tmp] = x - (L+1)⁻¹ (B·x + c·y)
75 | domain.medium(slot_tmp, slot_tmp, mnum=1) # [tmp] = B(x - (L+1)⁻¹ (B·x + c·y))
76 | # optionally compute norm of residual of preconditioned system
77 | retval = domain.inner_product(slot_tmp, slot_tmp) if compute_norm2 else 0.0
78 | domain.mix(1.0, slot_in, -alpha, slot_tmp, slot_out) # [out] = x - α B x + α B (L+1)⁻¹ (B·x + c·y)
79 | return retval
80 |
81 |
82 | def forward(domain: Domain, slot_in: int, slot_out: int):
83 | """ Evaluates the forward operator A= c⁻¹ (L + V)
84 |
85 | Args:
86 | domain: Domain object
87 | slot_in: slot holding input x. This slot will be overwritten!
88 | slot_out: output slot that will receive A x
89 | """
90 | if slot_in == slot_out:
91 | raise ValueError("slot_in and slot_out must be different")
92 |
93 | domain.medium(slot_in, slot_out) # (1-V) x
94 | domain.inverse_propagator(slot_in, slot_in) # (L+1) x
95 | domain.mix(1.0 / domain.scale, slot_in, -1.0 / domain.scale, slot_out, slot_out) # c⁻¹ (L+V) x
96 |
97 |
98 | def preconditioned_operator(domain: Domain, slot_in: int, slot_out: int):
99 | """ Evaluates the preconditioned operator Γ⁻¹ A
100 |
101 | Where Γ⁻¹ = c B (L+1)⁻¹
102 |
103 | Note: the scale factor c that makes A accretive and V a contraction is
104 | included in the preconditioner. The Richardson step size is _not_.
105 |
106 | Operator A is the original non-scaled operator, and we have (L+V) = c A
107 | Then:
108 |
109 | Γ⁻¹ A = c B (L+1)⁻¹ A
110 | = c B (L+1)⁻¹ c⁻¹ (L+V)
111 | = B (L+1)⁻¹ (L+V)
112 | = B (L+1)⁻¹ ([L+1] - [1-V])
113 | = B - B (L+1)⁻¹ B
114 |
115 | Args:
116 | domain: Domain object
117 | slot_in: slot holding input x. This slot will be overwritten!
118 | slot_out: output slot that will receive A x
119 | """
120 | if slot_in == slot_out:
121 | raise ValueError("slot_in and slot_out must be different")
122 |
123 | domain.medium(slot_in, slot_in) # B x
124 | domain.propagator(slot_in, slot_out) # (L+1)⁻¹ B x
125 | domain.medium(slot_out, slot_out) # B (L+1)⁻¹ B x
126 | domain.mix(1.0, slot_in, -1.0, slot_out, slot_out) # B - B (L+1)⁻¹ B x
127 |
128 |
129 | def preconditioner(domain: Domain, slot_in: int, slot_out: int):
130 | """ Evaluates Γ⁻¹ = c B(L+1)⁻¹
131 |
132 | Args:
133 | domain: Domain object
134 | slot_in: slot holding input x. This slot will be overwritten!
135 | slot_out: output slot that will receive A x
136 | """
137 | domain.propagator(slot_in, slot_in) # (L+1)⁻¹ x
138 | domain.medium(slot_in, slot_out) # B (L+1)⁻¹ x
139 | domain.mix(0.0, slot_out, domain.scale, slot_out, slot_out) # c B (L+1)⁻¹ x
140 |
141 |
142 | def domain_operator(domain: Domain, function: str, **kwargs):
143 | """Constructs various operators by combining calls to 'medium', 'propagator', etc.
144 |
145 | todo: this is currently very inefficient because of the overhead of copying data to and from the device
146 | """
147 |
148 | def potential_(domain, slot_in, slot_out):
149 | domain.medium(slot_in, slot_out)
150 | domain.mix(1.0, slot_in, -1.0, slot_out, slot_out)
151 |
152 | fn = {
153 | 'medium': domain.medium,
154 | 'propagator': domain.propagator,
155 | 'inverse_propagator': domain.inverse_propagator,
156 | 'potential': potential_,
157 | 'forward': lambda slot_in, slot_out: forward(domain, slot_in, slot_out),
158 | 'preconditioned_operator': lambda slot_in, slot_out: preconditioned_operator(domain, slot_in, slot_out),
159 | 'preconditioner': lambda slot_in, slot_out: preconditioner(domain, slot_in, slot_out),
160 | 'richardson': lambda slot_in, slot_out: preconditioned_iteration(domain, slot_in, slot_out, slot_out, **kwargs)
161 | }[function]
162 |
163 | def operator_(x):
164 | if is_zero(x):
165 | domain.clear(0)
166 | else:
167 | domain.set(0, x)
168 | fn(0, 1)
169 | return domain.get(1, copy=True)
170 |
171 | return operator_
172 |
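# A minimal usage sketch (illustrative values, not part of this module): build a small
# homogeneous, slightly absorbing 1-d Helmholtz domain, place a point source in the centre,
# and run the preconditioned Richardson iteration defined above.
if __name__ == "__main__":
    import torch
    from wavesim.helmholtzdomain import HelmholtzDomain

    n_size = (64, 1, 1)
    permittivity = torch.full(n_size, 1.0 + 0.1j, dtype=torch.complex64)  # a little absorption aids convergence
    source = torch.sparse_coo_tensor(torch.tensor([[n_size[0] // 2, 0, 0]]).T,
                                     torch.tensor([1.0]), n_size, dtype=torch.complex64)
    domain = HelmholtzDomain(permittivity=permittivity, periodic=(True, True, True))
    u, iterations, residual_norm = run_algorithm(domain, source, max_iterations=1000)
    print("iterations:", iterations, "final residual norm:", residual_norm)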
--------------------------------------------------------------------------------
/wavesim/multidomain.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .domain import Domain
4 | from .helmholtzdomain import HelmholtzDomain
5 | from .utilities import partition, combine, list_to_array, is_zero
6 | from torch.cuda import empty_cache
7 |
8 |
9 | class MultiDomain(Domain):
10 | """ Class for generating medium (B) and propagator (L+1)^(-1) operators, scaling,
11 | and setting up wrapping and transfer corrections """
12 |
13 | def __init__(self,
14 | permittivity,
15 | periodic: tuple[bool, bool, bool],
16 | pixel_size: float = 0.25,
17 | wavelength: float = None,
18 | n_domains: tuple[int, int, int] = (1, 1, 1),
19 | n_boundary: int = 8,
20 | device: str = None,
21 | debug: bool = False):
22 | """ Takes input parameters for the HelmholtzBase class (and sets up the operators)
23 |
24 | Args:
25 | permittivity: Permittivity distribution, must be 3-d.
26 | periodic: Indicates for each dimension whether the simulation is periodic or not.
27 |                 For periodic dimensions, the field is wrapped around the domain.
28 | pixel_size: Grid spacing in wavelengths.
29 | wavelength: wavelength in micrometer (um).
30 | n_domains: number of domains to split the simulation into.
31 |                 If the domain size is not divisible by n_domains, the last domain will be slightly smaller than the
32 |                 other ones. In the future, the domain size may be adjusted to allow an efficient Fourier transform.
33 |                 If n_domains is (1, 1, 1), no domain decomposition is performed.
34 | n_boundary: Number of points used in the wrapping and domain transfer correction. Default is 8.
35 | device: 'cpu' to use the cpu, 'cuda' to distribute the simulation over all available cuda devices, 'cuda:x'
36 | to use a specific cuda device, a list of strings, e.g., ['cuda:0', 'cuda:1'] to distribute the
37 | simulation over these devices in a round-robin fashion, or None, which is equivalent to 'cuda' if
38 | cuda devices are available, and 'cpu' if they are not.
39 | todo: implement
40 | debug: set to True to return inverse_propagator_kernel as output
41 | """
42 |
43 | # Takes the input parameters and returns these in the appropriate format, with more parameters for setting up
44 | # the Medium (+corrections) and Propagator operators, and scaling
45 | # (self.n_roi, self.s, self.n_dims, self.boundary_widths, self.boundary_pre, self.boundary_post,
46 | # self.n_domains, self.domain_size, self.omega, self.v_min, self.v_raw) = (
47 | # preprocess(n, pixel_size, n_domains))
48 |
49 |         # validate input parameters
50 | if not permittivity.ndim == 3:
51 | raise ValueError("The permittivity must be a 3D array")
52 | if not len(n_domains) == 3:
53 | raise ValueError("The number of domains must be a 3-tuple")
54 |
55 | # enumerate the cuda devices. We will assign the domains to the devices in a round-robin fashion.
56 | # we use the first GPU as primary device
57 | if device is None or device == 'cuda':
58 | devices = [f'cuda:{device_id}' for device_id in
59 | range(torch.cuda.device_count())] if torch.cuda.is_available() else ['cpu']
60 | else:
61 | devices = [device]
62 |
63 | if not torch.is_tensor(permittivity):
64 | permittivity = torch.tensor(permittivity)
65 | super().__init__(pixel_size, permittivity.shape, torch.device(devices[0]))
66 | self.periodic = np.array(periodic)
67 |
68 | # compute domain boundaries in each dimension
69 | self.domains = np.empty(n_domains, dtype=HelmholtzDomain)
70 | self.n_domains = n_domains
71 |
72 | # distribute the permittivity map over the subdomains.
73 | p_domains = partition(permittivity, self.n_domains)
74 | subdomain_periodic = [periodic[i] and n_domains[i] == 1 for i in range(3)]
75 | Vwrap = None
76 | for domain_index, p_domain in enumerate(p_domains.flat):
77 | # p_domain = torch.tensor(p_domain, device=devices[domain_index % len(devices)])
78 | self.domains.flat[domain_index] = HelmholtzDomain(permittivity=p_domain.to(devices[domain_index %
79 | len(devices)]),
80 | pixel_size=pixel_size, wavelength=wavelength,
81 | n_boundary=n_boundary, periodic=subdomain_periodic,
82 | stand_alone=False, debug=debug, Vwrap=Vwrap)
83 | Vwrap = self.domains.flat[domain_index].Vwrap # re-use wrapping matrix
84 |
85 | # determine the optimal shift
86 | limits = np.array([domain.V_bounds for domain in self.domains.flat])
87 | r_min = np.min(limits[:, 0])
88 | r_max = np.max(limits[:, 1])
89 | # i_min = np.min(limits[:, 2])
90 | # i_max = np.max(limits[:, 3])
91 | center = 0.5 * (r_min + r_max) # + 0.5j * (i_min + i_max)
92 |
93 | # shift L and V to minimize norm of V
94 | Vscat_norm = 0.0
95 | Vwrap_norm = 0.0
96 | for domain in self.domains.flat:
97 | Vscat_norm = np.maximum(Vscat_norm, domain.initialize_shift(center))
98 | Vwrap_norm = np.maximum(Vwrap_norm, domain.Vwrap_norm)
99 |
100 | # the factor 2 is because the same matrix is used twice (for domain transfer and wrapping correction)
101 | Vwrap_norm = 2 * Vwrap_norm if max(n_domains) > 1 else Vwrap_norm
102 |
103 | # compute the scaling factor
104 | # apply the scaling to compute the final form of all operators in the iteration
105 | self.shift = center
106 | self.scale = 0.95j / (Vscat_norm + Vwrap_norm)
107 | for domain in self.domains.flat:
108 | domain.initialize_scale(self.scale)
109 | empty_cache() # free up memory before going to run_algorithm
110 |
111 | # Functions implementing the domain interface
112 | # add_source()
113 | # clear()
114 | # get()
115 | # inner_product()
116 | # medium()
117 | # mix()
118 | # propagator()
119 | # set_source()
120 | def add_source(self, slot: int, weight: float):
121 | """ Add the source to the field in slot """
122 | for domain in self.domains.flat:
123 | domain.add_source(slot, weight)
124 |
125 | def clear(self, slot: int):
126 | """ Clear the field in the specified slot """
127 | for domain in self.domains.flat:
128 | domain.clear(slot)
129 |
130 | def get(self, slot: int, copy=False, device=None):
131 |         """ Get the field in the specified slot. This gathers the fields from all subdomains and puts them in
132 |         one big array.
133 |
134 | :param slot: slot to get the data from
135 |         :param copy: if True, returns a copy of the data. Otherwise, may return the original data itself.
136 | Note that this data may be overwritten by the next call to domain.
137 | :param device: device on which to store the data. Defaults to the primary device
138 | """
139 | domain_data = list_to_array([domain.get(slot) for domain in self.domains.flat], 1).reshape(self.domains.shape)
140 | return combine(domain_data, device)
141 |
142 | def set(self, slot: int, data):
143 |         """Copy the data into the specified slot"""
144 | parts = partition(data, self.n_domains)
145 | for domain, part in zip(self.domains.flat, parts.flat):
146 | domain.set(slot, part)
147 |
148 | def inner_product(self, slot_a: int, slot_b: int):
149 | """ Compute the inner product of the fields in slots a and b
150 |
151 | Note:
152 | Use sqrt(inner_product(slot_a, slot_a)) to compute the norm of the field in slot_a.
153 | There is a large but inconsistent difference in performance between vdot and linalg.norm.
154 | Execution time can vary a factor of 3 or more between the two, depending on the input size
155 | and whether the function is executed on the CPU or the GPU.
156 | """
157 | inner_product = 0.0
158 | for domain in self.domains.flat:
159 | inner_product += domain.inner_product(slot_a, slot_b)
160 | return inner_product
161 |
162 | def medium(self, slot_in: int, slot_out: int, mnum=None):
163 | """ Apply the medium operator B, including wrapping corrections.
164 |
165 | Args:
166 | slot_in: slot holding the input field
167 | slot_out: slot that will receive the result
168 | mnum: # of the medium() call in preconditioned iteration.
169 | 0 for first, 1 for second medium call.
170 | """
171 |
172 | # compute the corrections for each domain, before applying the medium operator
173 | domain_edges = [domain.compute_corrections(slot_in) for domain in self.domains.flat]
174 | domain_edges = list_to_array(domain_edges, 2).reshape(*self.domains.shape, 6)
175 |
176 | # Only applies the operator B=1-Vscat. The corrections are applied in the next step
177 | for domain in self.domains.flat:
178 | domain.medium(slot_in, slot_out)
179 |
180 | # apply wrapping corrections. We subtract each correction from
181 | # the opposite side of the domain to compensate for the wrapping.
182 | # also, we add each correction to the opposite side of the neighbouring domain
183 | for idx, domain in enumerate(self.domains.flat):
184 | x = np.unravel_index(idx, self.domains.shape)
185 | # for the wrap corrections, take the corrections for this domain and swap them
186 | wrap_corrections = domain_edges[*x, (1, 0, 3, 2, 5, 4)]
187 |
188 | # for the transfer corrections, take the corrections from the neighbors
189 | def get_neighbor(edge):
190 | dim = edge // 2
191 | offset = -1 if edge % 2 == 0 else 1
192 | x_neighbor = np.array(x)
193 | x_neighbor[dim] += offset
194 | if self.periodic[dim]:
195 | x_neighbor = np.mod(x_neighbor, self.domains.shape)
196 | else:
197 | if x_neighbor[dim] < 0 or x_neighbor[dim] >= self.domains.shape[dim]:
198 | return None
199 | return domain_edges[*tuple(x_neighbor), edge - offset]
200 |
201 | transfer_corrections = [get_neighbor(edge) for edge in range(6)]
202 |
203 | # check if domain should be active in the iteration or not
204 |             if mnum is None or domain._debug:  # always active outside iteration (mnum==None) or in debug mode
205 | domain.active = True
206 | else: # check based on the norm of the transfer corrections
207 | tc_norm = [a for a in transfer_corrections if a is not None]
208 | if tc_norm:
209 | if domain.counter < 25: # counter for the number of iterations with increasing norm
210 | tc_norm = max([torch.vdot(a.view(-1), a.view(-1)).item().real for a in tc_norm])
211 |
212 | if mnum == 0: # first medium call in preconditioned iteration
213 | domain.mnum0[1] = tc_norm
214 | elif mnum == 1: # second medium call in preconditioned iteration
215 | domain.mnum1[1] = tc_norm
216 |
217 | # if norm is high, domain is set to active
218 | if domain.mnum0[1] >= 1.e-7 or domain.mnum1[1] >= 1.e-7:
219 | domain.active = True
220 | domain.counter = 25
221 | # if norm is monotonically increasing, increase the counter
222 | elif domain.mnum0[-1] > domain.mnum0[-2] and domain.mnum1[-1] > domain.mnum1[-2]:
223 | domain.counter += 1
224 | else:
225 | domain.counter = 0
226 | # if the norm is not increasing and source is zero, domain is set to inactive
227 | if domain._source is not None and torch.sum(domain._source).item().real == 0.0:
228 | domain.active = False
229 | else:
230 | domain.active = True
231 | domain.counter = 25
232 |
233 | # current norm becomes previous norm
234 | domain.mnum0[0] = domain.mnum0[1]
235 | domain.mnum1[0] = domain.mnum1[1]
236 | else:
237 | domain.active = True
238 |
239 | if domain.active:
240 | domain.apply_corrections(wrap_corrections, transfer_corrections, slot_out)
241 |
242 | def mix(self, weight_a: float, slot_a: int, weight_b: float, slot_b: int, slot_out: int):
243 | """ Mix the fields in slots a and b and store the result in slot_out """
244 | for domain in self.domains.flat:
245 | domain.mix(weight_a, slot_a, weight_b, slot_b, slot_out)
246 |
247 | def propagator(self, slot_in: int, slot_out: int):
248 | """ Apply propagator operators (L+1)^-1 to subdomains/patches of x."""
249 | for domain in self.domains.flat:
250 | domain.propagator(slot_in, slot_out)
251 |
252 | def inverse_propagator(self, slot_in: int, slot_out: int):
253 | """ Apply inverse propagator operators L+1 to subdomains/patches of x."""
254 | for domain in self.domains.flat:
255 | domain.inverse_propagator(slot_in, slot_out)
256 |
257 | def set_source(self, source):
258 | """ Split the source into subdomains and store in the subdomain states."""
259 | if source is None or is_zero(source):
260 | for domain in self.domains.flat:
261 | domain.set_source(None)
262 | else:
263 | for domain, source in zip(self.domains.flat, partition(source, self.n_domains).flat):
264 | domain.set_source(source)
265 |
266 | def create_empty_vdot(self):
267 | """ Create an empty tensor for the Vdot tensor """
268 | for domain in self.domains.flat:
269 | domain.create_empty_vdot()
270 |
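# A minimal construction sketch (illustrative values, not part of this module): split a
# homogeneous permittivity map into two subdomains along the first axis. That axis is marked
# non-periodic, mirroring how the tests configure dimensions that are split, so the wrapping
# and transfer corrections are applied there; the other two axes keep periodic boundaries.
if __name__ == "__main__":
    permittivity = torch.ones((64, 32, 32), dtype=torch.complex64)
    domain = MultiDomain(permittivity=permittivity,
                         periodic=(False, True, True),
                         n_domains=(2, 1, 1),
                         n_boundary=8)
    print(domain.domains.shape, domain.scale)  # (2, 1, 1) grid of HelmholtzDomain objects, and the scale factor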
--------------------------------------------------------------------------------
/wavesim/utilities.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Sequence
4 | from itertools import chain
5 | from scipy.special import exp1
6 |
7 |
8 | # Preprocessing functions. These functions are used to preprocess the input parameters, i.e.,
9 | # to add absorption and boundaries to the permittivity (refractive index²).
10 | def preprocess(permittivity, boundary_widths=10):
11 | """ Preprocess the input parameters for the simulation.
12 | Add absorption and boundaries to the permittivity (refractive index²),
13 | and return the preprocessed permittivity and boundaries in the format (ax0, ax1, ax2).
14 |
15 | :param permittivity: Refractive index²
16 | :param boundary_widths: Boundary widths (in pixels)
17 | :return: Preprocessed permittivity (refractive index²) with boundaries and absorption
18 | """
19 | permittivity = check_input_dims(permittivity) # Ensure permittivity is a 3-d array
20 | if permittivity.dtype != np.complex64:
21 | permittivity = permittivity.astype(np.complex64)
22 | n_dims = get_dims(permittivity) # Number of dimensions in simulation
23 | n_roi = np.array(permittivity.shape) # Num of points in ROI (Region of Interest)
24 |
25 | # Ensure boundary_widths is a 3-element array of ints with 0s after n_dims
26 | boundary_widths = check_input_len(boundary_widths, 0, n_dims).astype(int)
27 |
28 | permittivity = add_absorption(permittivity, boundary_widths, n_roi, n_dims)
29 |
30 | return permittivity, boundary_widths
31 |
32 |
33 | def add_absorption(m, boundary_widths, n_roi, n_dims):
34 | """ Add (weighted) absorption to the permittivity (refractive index squared)
35 |
36 | :param m: array (permittivity)
37 | :param boundary_widths: Boundary widths
38 | :param n_roi: Number of points in the region of interest
39 | :param n_dims: Number of dimensions
40 | :return: m with absorption
41 | """
42 | w = np.ones_like(m) # Weighting function (1 everywhere)
43 | w = pad_boundaries(w, boundary_widths, mode='linear_ramp') # pad w using linear_ramp
44 | a = 1 - w # for absorption, inverse weighting 1 - w
45 | for i in range(n_dims):
46 | left_boundary = boundary_(boundary_widths[i]) # boundary_ is a linear window function
47 |         right_boundary = np.flip(left_boundary)  # mirror the window for the right boundary
48 | full_filter = np.concatenate((left_boundary, np.ones(n_roi[i], dtype=np.float32), right_boundary))
49 | a = np.moveaxis(a, i, -1) * full_filter # transpose to last dimension, apply filter
50 | a = np.moveaxis(a, -1, i) # transpose back to original position
51 | a = 1j * a # absorption is imaginary
52 |
53 | m = pad_boundaries(m, boundary_widths, mode='edge') # pad m using edge values
54 | m = w * m + a # add absorption to m
55 | return m
56 |
57 |
58 | def boundary_(x):
59 | """ Anti-reflection boundary layer (ARL). Linear window function
60 |
61 | :param x: Size of the ARL
62 | """
63 | return ((np.arange(1, x + 1) - 0.21).T / (x + 0.66)).astype(np.float32)
64 |
65 |
66 | def check_input_dims(x):
67 | """ Expand arrays to 3 dimensions (e.g. permittivity (refractive index²) or source)
68 |
69 | :param x: Input array
70 | :return: x with 3 dimensions
71 | """
72 | for _ in range(3 - x.ndim):
73 | x = np.expand_dims(x, axis=-1) # Expand dimensions to 3
74 | return x
75 |
76 |
77 | def check_input_len(x, e, n_dims):
78 | """ Check the length of input arrays and expand them to 3 elements if necessary. Either repeat or add 'e'
79 |
80 | :param x: Input array
81 | :param e: Element to add
82 | :param n_dims: Number of dimensions
83 | :return: Array with 3 elements
84 | """
85 | if isinstance(x, int) or isinstance(x, float): # If x is a single number
86 | x = n_dims * tuple((x,)) + (3 - n_dims) * (e,) # Repeat the number n_dims times, and add (3-n_dims) e's
87 | elif len(x) == 1: # If x is a single element list or tuple
88 | x = n_dims * tuple(x) + (3 - n_dims) * (e,) # Repeat the element n_dims times, and add (3-n_dims) e's
89 | elif isinstance(x, list) or isinstance(x, tuple): # If x is a list or tuple
90 | x += (3 - len(x)) * (e,) # Add (3-len(x)) e's
91 | if isinstance(x, np.ndarray): # If x is a numpy array
92 | x = np.concatenate((x, np.zeros(3 - len(x)))) # Concatenate with (3-len(x)) zeros
93 | return np.array(x)
94 |
95 |
96 | def get_dims(x):
97 | """ Get the number of dimensions of 'x'
98 |
99 | :param x: Input array
100 | :return: Number of dimensions
101 | """
102 | x = squeeze_(x) # Squeeze the last dimension if it is 1
103 | return x.ndim # Number of dimensions
104 |
105 |
106 | def pad_boundaries(x, boundary_widths, boundary_post=None, mode='constant'):
107 | """ Pad 'x' with boundaries in all dimensions using numpy pad (if x is np.ndarray) or PyTorch nn.functional.pad
108 | (if x is torch.Tensor).
109 | If boundary_post is specified separately, pad with boundary_widths (before) and boundary_post (after).
110 |
111 | :param x: Input array
112 | :param boundary_widths: Boundary widths for padding before and after (or just before if boundary_post not None)
113 | :param boundary_post: Boundary widths for padding after
114 | :param mode: Padding mode
115 | :return: Padded array
116 | """
117 | x = check_input_dims(x) # Ensure x is a 3-d array
118 |
119 | if boundary_post is None:
120 | boundary_post = boundary_widths
121 |
122 | if isinstance(x, np.ndarray):
123 | pad_width = tuple(zip(boundary_widths, boundary_post)) # pairs ((a0, b0), (a1, b1), (a2, b2))
124 | return np.pad(x, pad_width, mode)
125 | elif torch.is_tensor(x):
126 | t = zip(boundary_widths[::-1], boundary_post[::-1]) # reversed pairs (a2, b2) (a1, b1) (a0, b0)
127 | pad_width = tuple(chain.from_iterable(t)) # flatten to (a2, b2, a1, b1, a0, b0)
128 | return torch.nn.functional.pad(x, pad_width, mode)
129 | else:
130 | raise ValueError("Input must be a numpy array or a torch tensor")
131 |
132 |
133 | def squeeze_(x):
134 | """ Squeeze the last dimension of 'x' if it is 1
135 |
136 | :param x: Input array
137 | :return: Squeezed array
138 | """
139 | while x.shape[-1] == 1:
140 | x = np.squeeze(x, axis=-1)
141 | return x
142 |
143 |
144 | # Domain decomposition functions.
145 | def combine(domains: np.ndarray, device='cpu') -> torch.Tensor:
146 | """ Concatenates a 3-d array of 3-d tensors"""
147 |
148 | # Calculate total size for each dimension
149 | total_size = [
150 | sum(tensor.shape[0] for tensor in domains[:, 0, 0]),
151 | sum(tensor.shape[1] for tensor in domains[0, :, 0]),
152 | sum(tensor.shape[2] for tensor in domains[0, 0, :]),
153 | ]
154 |
155 | # allocate memory
156 | template = domains[0, 0, 0]
157 | result_tensor = torch.empty(size=total_size, dtype=template.dtype, device=device)
158 |
159 | # Fill the pre-allocated tensor
160 | index0 = 0
161 | for i, tensor_slice0 in enumerate(domains[:, 0, 0]):
162 | index1 = 0
163 | for j, tensor_slice1 in enumerate(domains[0, :, 0]):
164 | index2 = 0
165 | for k, tensor in enumerate(domains[0, 0, :]):
166 | tensor = domains[i, j, k]
167 | if tensor.is_sparse:
168 | tensor = tensor.to_dense()
169 | end0 = index0 + tensor.shape[0]
170 | end1 = index1 + tensor.shape[1]
171 | end2 = index2 + tensor.shape[2]
172 | result_tensor[index0:end0, index1:end1, index2:end2] = tensor
173 | index2 += tensor.shape[2]
174 | index1 += domains[i, j, 0].shape[1]
175 | index0 += tensor_slice0.shape[0]
176 |
177 | return result_tensor
178 |
179 |
180 | def list_to_array(input: list, depth: int) -> np.ndarray:
181 | """ Convert a nested list of depth `depth` to a numpy object array """
182 | # first determine the size of the final array
183 | size = np.zeros(depth, dtype=int)
184 | outer = input
185 | for i in range(depth):
186 | size[i] = len(outer)
187 | outer = outer[0]
188 |
189 | # allocate memory
190 | array = np.empty(size, dtype=object)
191 |
192 | # flatten the input array
193 | for i in range(depth - 1):
194 | input = sum(input, input[0][0:0]) # works both for tuples and lists
195 |
196 | # copy to the output array
197 | ra = array.reshape(-1)
198 | assert ra.base is not None # must be a view
199 | for i in range(ra.size):
200 | if input[i] is None or input[i].is_sparse or input[i].is_contiguous():
201 | ra[i] = input[i]
202 | else:
203 | ra[i] = input[i].contiguous()
204 | return array
205 |
206 |
207 | def partition(array: torch.Tensor, n_domains: tuple[int, int, int]) -> np.ndarray:
208 | """ Split a 3-D array into a 3-D set of sub-arrays of approximately equal sizes."""
209 |     n_domains = np.array(n_domains)  # convert to a numpy array for elementwise comparisons
210 | size = np.array(array.shape)
211 | if any(size < n_domains) or any(n_domains <= 0) or len(n_domains) != 3:
212 |         raise ValueError(f"Number of domains {n_domains} must be a 3-tuple of positive values, each "
213 |                          f"less than or equal to the corresponding size of the array {array.shape}")
214 |
215 | # Calculate the size of each domain
216 | large_domain_size = np.ceil(size / n_domains).astype(int)
217 | small_domain_count = large_domain_size * n_domains - size
218 | large_domain_count = n_domains - small_domain_count
219 | subdomain_sizes = [(large_domain_size[dim],) * large_domain_count[dim] + (large_domain_size[dim] - 1,)
220 | * small_domain_count[dim] for dim in range(3)]
221 |
222 | split = _sparse_split if array.is_sparse else torch.split
223 |
224 | array = split(array, subdomain_sizes[0], dim=0)
225 | array = [split(part, subdomain_sizes[1], dim=1) for part in array]
226 | array = [[split(part, subdomain_sizes[2], dim=2) for part in subpart] for subpart in array]
227 | return list_to_array(array, depth=3)
228 |
229 |
230 | def _sparse_split(tensor: torch.Tensor, sizes: Sequence[int], dim: int) -> np.ndarray:
231 | """ Split a COO-sparse tensor into a 3-D set of sub-arrays of approximately equal sizes."""
232 | if len(sizes) == 1:
233 | return [tensor] # no need to split
234 |
235 | tensor = tensor.coalesce()
236 | indices = tensor.indices().cpu().numpy()
237 | values = tensor.values()
238 |
239 | if dim >= tensor.sparse_dim():
240 | values = torch.tensor(values.detach().clone().cpu().numpy()) # for troubleshooting access violation
241 | value_list = list(torch.split(values, sizes, dim - tensor.sparse_dim() + 1)) # split dense tensor component
242 | sz = list(tensor.shape)
243 | for i in range(len(value_list)):
244 | sz[dim] = sizes[i]
245 | v = np.array(
246 | value_list[i].cpu().numpy()) # should not be necessary, workaround for access violation bug in torch
247 | value_list[i] = torch.sparse_coo_tensor(indices, v, tuple(sz))
248 | # print(indices, indices.dtype, value_list[i], value_list[i].shape, sz)
249 | value_list[i].to_dense() # for troubleshooting access violation
250 | return value_list
251 |
252 | coordinate_to_domain = np.array(sum([(idx,) * size for idx, size in enumerate(sizes)], ()))
253 | domain_starts = np.cumsum((0,) + sizes)
254 | domains = coordinate_to_domain[indices[dim, :]]
255 |
256 | def extract_subarray(domain: int) -> torch.Tensor:
257 | mask = domains == domain
258 | domain_indices = indices[:, mask]
259 | if len(domain_indices) == 0:
260 | return None
261 | domain_values = values[mask]
262 | domain_indices[dim, :] -= domain_starts[domain]
263 | size = list(tensor.shape)
264 | size[dim] = sizes[domain]
265 | return torch.sparse_coo_tensor(domain_indices, domain_values, tuple(size))
266 |
267 | return [extract_subarray(d) for d in range(len(sizes))]
268 |
269 |
270 | # Used in tests
271 | def full_matrix(operator):
272 | """ Converts operator to a 2D square matrix of size np.prod(d) x np.prod(d)
273 |
274 | :param operator: Operator to convert to a matrix. This function must be able to accept a 0 scalar, and
275 | return a vector of the size and data type of the domain.
276 | """
277 | y = operator(0.0)
278 | n_size = y.shape
279 | nf = np.prod(n_size)
280 | M = torch.zeros((nf, nf), dtype=y.dtype, device=y.device)
281 | b = torch.zeros(n_size, dtype=y.dtype, device=y.device)
282 | for i in range(nf):
283 | b.view(-1)[i] = 1
284 | M[:, i] = torch.ravel(operator(b))
285 | b.view(-1)[i] = 0
286 |
287 | return M
288 |
289 |
290 | # Metrics
291 | def max_abs_error(e, e_true):
292 | """ (Normalized) Maximum Absolute Error (MAE) ||e-e_true||_{inf} / ||e_true||
293 |
294 | :param e: Computed field
295 | :param e_true: True field
296 | :return: (Normalized) MAE
297 | """
298 | return np.max(np.abs(e - e_true)) / np.linalg.norm(e_true)
299 |
300 |
301 | def max_relative_error(e, e_true):
302 | """Computes the maximum error, normalized by the rms of the true field
303 |
304 | :param e: Computed field
305 | :param e_true: True field
306 | :return: (Normalized) Maximum Relative Error
307 | """
308 | return np.max(np.abs(e - e_true)) / np.sqrt(np.mean(np.abs(e_true) ** 2))
309 |
310 |
311 | def relative_error(e, e_true):
312 | """ Relative error ``⟨|e-e_true|^2⟩ / ⟨|e_true|^2⟩``
313 |
314 | :param e: Computed field
315 | :param e_true: True field
316 | :return: Relative Error
317 | """
318 | return np.nanmean(np.abs(e - e_true) ** 2) / np.nanmean(np.abs(e_true) ** 2)
319 |
320 |
321 | # Miscellaneous functions
322 | ## 1D analytical solution for Helmholtz equation
323 | def analytical_solution(n_size0, pixel_size, wavelength=None):
324 | """ Compute analytic solution for 1D case """
325 | x = np.arange(0, n_size0 * pixel_size, pixel_size, dtype=np.float32)
326 | x = np.pad(x, (n_size0, n_size0), mode='constant', constant_values=np.nan)
327 | h = pixel_size
328 | # wavenumber (k)
329 | if wavelength is None:
330 | k = 1. * 2. * np.pi * pixel_size
331 | else:
332 | k = 1. * 2. * np.pi / wavelength
333 | phi = k * x
334 | u_theory = (1.0j * h / (2 * k) * np.exp(1.0j * phi) # propagating plane wave
335 | - h / (4 * np.pi * k) * (
336 | np.exp(1.0j * phi) * (exp1(1.0j * (k - np.pi / h) * x) - exp1(1.0j * (k + np.pi / h) * x)) -
337 | np.exp(-1.0j * phi) * (-exp1(-1.0j * (k - np.pi / h) * x) + exp1(-1.0j * (k + np.pi / h) * x)))
338 | )
339 | small = np.abs(k * x) < 1.e-10 # special case for values close to 0
340 | u_theory[small] = 1.0j * h / (2 * k) * (1 + 2j * np.arctanh(h * k / np.pi) / np.pi) # exact value at 0.
341 | return u_theory[n_size0:-n_size0]
342 |
343 |
344 | def is_zero(x):
345 | """ Check if x is zero
346 |
347 | Some functions allow specifying 0 or 0.0 instead of a torch tensor, to indicate that the array should be cleared.
348 | This function returns True if x is a scalar 0 or 0.0. It raises an error if x is a scalar that is not equal to 0 or
349 | 0.0, and returns False otherwise.
350 | """
351 | if isinstance(x, float) or isinstance(x, int):
352 | if x != 0:
353 |             raise ValueError("Cannot set a field to a non-zero scalar; only scalar 0.0 is supported")
354 | return True
355 | else:
356 | return False
357 |
358 |
359 | def normalize(x, min_val=None, max_val=None, a=0, b=1):
360 | """ Normalize x to the range [a, b]
361 |
362 | :param x: Input array
363 | :param min_val: Minimum value (of x)
364 | :param max_val: Maximum value (of x)
365 | :param a: Lower bound for normalization
366 | :param b: Upper bound for normalization
367 | :return: Normalized x
368 | """
369 | if min_val is None:
370 | min_val = x.min()
371 | if max_val is None:
372 | max_val = x.max()
373 | normalized_x = (x - min_val) / (max_val - min_val) * (b - a) + a
374 | return normalized_x
375 |
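# A small self-check sketch (sizes are illustrative, not part of this module): partition a
# random volume into 2x1x3 subdomains, verify that combine() reassembles it exactly, and
# evaluate the relative_error metric on the round trip.
if __name__ == "__main__":
    x = torch.randn((6, 5, 9), dtype=torch.complex64)
    parts = partition(x, (2, 1, 3))              # 3-d numpy object array of sub-tensors
    print(parts.shape)                           # (2, 1, 3)
    y = combine(parts)                           # concatenate the parts back into one tensor
    print(relative_error(y.numpy(), x.numpy()))  # 0.0 for an exact round trip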
--------------------------------------------------------------------------------