├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── make.bat ├── source │ ├── acknowledgements.rst │ ├── api.rst │ ├── api_domain.rst │ ├── api_helmholtzdomain.rst │ ├── api_iteration.rst │ ├── api_multidomain.rst │ ├── api_utilities.rst │ ├── conclusion.rst │ ├── conf.py │ ├── development.rst │ ├── index.rst │ ├── index_latex.rst │ ├── index_markdown.rst │ ├── readme.rst │ └── references.bib └── tex.bat ├── environment.yml ├── examples ├── README.rst ├── __init__.py ├── check_mem.py ├── helmholtz_1d.py ├── helmholtz_1d_analytical.py ├── helmholtz_2d.py ├── helmholtz_2d_homogeneous.py ├── helmholtz_2d_low_contrast.py ├── helmholtz_3d_disordered.py ├── logo_structure_vector.png ├── matlab_results.mat ├── mem_snapshot.py ├── paper_code │ ├── __init__.py │ ├── fig1_splitting.py │ ├── fig2_decompose.py │ ├── fig3_correction_matrix.py │ ├── fig4_truncate.py │ ├── fig5_dd_validation.py │ ├── fig6_dd_truncation.py │ ├── fig7_dd_convergence.py │ └── fig8_dd_large_simulation.py ├── run_example.py └── timing_test.py ├── guides ├── ipynb2slides_guide.md ├── misc_tips.md └── pytorch_gpu_setup.md ├── pyproject.toml ├── pyproject_cpu.toml ├── requirements.txt ├── tests ├── __init__.py ├── test_analytical.py ├── test_basics.py ├── test_examples.py ├── test_operators.py └── test_utilities.py └── wavesim ├── __init__.py ├── domain.py ├── helmholtzdomain.py ├── iteration.py ├── multidomain.py └── utilities.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Simulation output 7 | *.pdf 8 | *.npz 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 115 | .pdm.toml 116 | .pdm-python 117 | .pdm-build/ 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | .idea/ 168 | 169 | # Custom 170 | *.html 171 | *.ipynb 172 | *.json 173 | *.mp4 174 | *.pdf 175 | *.png 176 | *.sh 177 | .vscode 178 | __pycache__ 179 | logs 180 | *.nsys-rep 181 | *.sqlite 182 | 183 | # Ignore certain files in the docs directory 184 | auto_examples 185 | sg_execution_times.rst -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.11" 7 | 8 | jobs: 9 | post_create_environment: 10 | - pip install poetry 11 | - pip install poetry-plugin-export 12 | - poetry export -f requirements.txt -o requirements.txt --with docs,dev 13 | - cat requirements.txt 14 | 15 | python: 16 | install: 17 | - requirements: requirements.txt 18 | 19 | sphinx: 20 | configuration: docs/source/conf.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ivo Vellekoop, Swapnil Mache - University of Twente 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Wavesim 4 | 5 | 6 | 7 | ## What is Wavesim? 8 | 9 | Wavesim is a tool to simulate the propagation of waves in complex, inhomogeneous structures. Whereas most available solvers use the popular finite difference time domain (FDTD) method [[1](#id27), [2](#id21), [3](#id16), [4](#id15)], Wavesim is based on the modified Born series (MBS) approach, which has lower memory requirements, no numerical dispersion, and is faster as compared to FDTD [[5](#id17), [6](#id23)]. 10 | 11 | This package [[7](#id25)] is a Python implementation of the MBS approach for solving the Helmholtz equation in arbitrarily large media through domain decomposition [[8](#id13)]. With this new framework, we simulated a complex 3D structure of a remarkable $315\times 315\times 315$ wavelengths $\left( 3.1\cdot 10^7 \right)$ in size in just $1.4$ hours by solving over two GPUs. This represents a factor of $1.93$ increase over the largest possible simulation on a single GPU without domain decomposition. 
12 | 13 | When using Wavesim in your work, please cite: 14 | 15 | > [[5](#id17)] [Osnabrugge, G., Leedumrongwatthanakun, S., & Vellekoop, I. M. (2016). A convergent Born series for solving the inhomogeneous Helmholtz equation in arbitrarily large media. *Journal of computational physics, 322*, 113-124.](https://doi.org/10.1016/j.jcp.2016.06.034) 16 | 17 | > [[8](#id13)] [Mache, S., & Vellekoop, I. M. (2024). Domain decomposition of the modified Born series approach for large-scale wave propagation simulations. *arXiv preprint arXiv:2410.02395*.](https://arxiv.org/abs/2410.02395) 18 | 19 | If you use the code in your research, please cite this repository as well [[7](#id25)]. 20 | 21 | Examples and documentation for this project are available at [Read the Docs](https://wavesim.readthedocs.io/en/latest/) [[9](#id24)]. For more information (and to participate in the forum for discussions, queries, and requests), please visit our website [www.wavesim.org](https://www.wavesim.org/). 22 | 23 | ## Installation 24 | 25 | Wavesim requires [Python >=3.11.0 and <3.13.0](https://www.python.org/downloads/) and uses [PyTorch](https://pytorch.org/) for GPU acceleration. 26 | 27 | First, clone the repository and navigate to the directory: 28 | 29 | ```default 30 | git clone https://github.com/IvoVellekoop/wavesim_py.git 31 | cd wavesim_py 32 | ``` 33 | 34 | Then, you can install the dependencies in a couple of ways: 35 | 36 | [1. Using pip](#pip-installation) 37 | 38 | [2. Using conda](#conda-installation) 39 | 40 | [3. Using Poetry](#poetry-installation) 41 | 42 | We recommend working with a virtual environment to avoid conflicts with other packages. 43 | 44 | 45 | 46 | ### 1. **Using pip** 47 | 48 | If you prefer to use pip, you can install the required packages using [requirements.txt](https://github.com/IvoVellekoop/wavesim_py/blob/main/requirements.txt): 49 | 50 | 1. **Create a virtual environment and activate it** (optional but recommended) 51 | * First, [create a virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments) using the following command: 52 | ```default 53 | python -m venv path/to/venv 54 | ``` 55 | * Then, activate the virtual environment. The command depends on your operating system and shell ([How venvs work](https://docs.python.org/3/library/venv.html#how-venvs-work)): 56 | ```default 57 | source path/to/venv/bin/activate # for Linux/macOS 58 | path/to/venv/Scripts/activate.bat # for Windows (cmd) 59 | path/to/venv/Scripts/Activate.ps1 # for Windows (PowerShell) 60 | ``` 61 | 2. **Install packages**: 62 | ```default 63 | pip install -r requirements.txt 64 | ``` 65 | 66 | 67 | 68 | ### 2. **Using conda** 69 | 70 | We recommend using [Miniconda](https://docs.anaconda.com/miniconda/) (a much lighter counterpart of Anaconda) to install Python and the required packages (contained in [environment.yml](https://github.com/IvoVellekoop/wavesim_py/blob/main/environment.yml)) within a conda environment. 71 | 72 | 1. **Download Miniconda**, choosing the appropriate [Python installer](https://docs.anaconda.com/miniconda/) for your operating system (Windows/macOS/Linux). 73 | 2. **Install Miniconda**, following the [installation instructions](https://docs.anaconda.com/miniconda/miniconda-install/) for your OS. Follow the prompts on the installer screens. If you are unsure about any setting, accept the defaults. You can change them later. (If you cannot immediately activate conda, close and re-open your terminal window to make the changes take effect). 74 | 3. 
**Test your installation**. Open Anaconda Prompt and run the below command. Alternatively, open an editor like [Visual Studio Code](https://code.visualstudio.com/) or [PyCharm](https://www.jetbrains.com/pycharm/), select the Python interpreter in the `miniconda3/` directory with the label `('base')`, and run the command: 75 | ```default 76 | conda list 77 | ``` 78 | 79 | A list of installed packages appears if it has been installed correctly. 80 | 4. **Set up a conda environment**. Avoid using the base environment altogether. It is a good backup environment to fall back on if and when the other environments are corrupted/don’t work. Create a new environment using [environment.yml](https://github.com/IvoVellekoop/wavesim_py/blob/main/environment.yml) and activate: 81 | ```default 82 | conda env create -f environment.yml 83 | conda activate wavesim 84 | ``` 85 | 86 | The [Miniconda environment management guide](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) has more details if you need them. 87 | 88 | Alternatively, you can create a conda environment with a specific Python version, and then use the [requirements.txt](https://github.com/IvoVellekoop/wavesim_py/blob/main/requirements.txt) file to install the dependencies: 89 | ```default 90 | conda create -n wavesim python'>=3.11.0,<3.13' 91 | conda activate wavesim 92 | pip install -r requirements.txt 93 | ``` 94 | 95 | 96 | 97 | ### 3. **Using Poetry** 98 | 99 | 1. Install [Poetry](https://python-poetry.org/). 100 | 2. Install dependencies by running the following command: 101 | ```default 102 | poetry install 103 | ``` 104 | 105 | To run tests using pytest, you can install the development dependencies as well: 106 | ```default 107 | poetry install --with dev 108 | ``` 109 | 3. [Activate](https://python-poetry.org/docs/managing-environments/#activating-the-environment) the virtual environment created by Poetry. 110 | 111 | ## Running the code 112 | 113 | Once the virtual environment is set up with all the required packages, you are ready to run the code. You can go through any of the scripts in the `examples` [directory](https://github.com/IvoVellekoop/wavesim_py/tree/main/examples) for the basic steps needed to run a simulation. The directory contains examples of 1D, 2D, and 3D problems. 114 | 115 | You can run the code with just three inputs: 116 | 117 | * `permittivity`, i.e. refractive index distribution squared (a 3-dimensional array on a regular grid), 118 | * `source`, the same size as permittivity, and 119 | * `periodic`, a tuple of three booleans to indicate whether the domain is periodic in each dimension [`True`] or not [`False`]. 120 | 121 | [Listing 1.1](#helmholtz-1d-analytical) shows a simple example of a 1D problem with a homogeneous medium ([helmholtz_1d_analytical.py](https://github.com/IvoVellekoop/wavesim_py/blob/main/examples/helmholtz_1d_analytical.py)) to explain these and other inputs. 122 | 123 | 124 | ```python 125 | """ 126 | Helmholtz 1D analytical test 127 | ============================ 128 | Test to compare the result of Wavesim to analytical results. 129 | Compare 1D free-space propagation with analytic solution. 
130 | """ 131 | 132 | import torch 133 | import numpy as np 134 | from time import time 135 | import sys 136 | sys.path.append(".") 137 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 138 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 139 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 140 | from wavesim.utilities import analytical_solution, preprocess, relative_error 141 | from __init__ import plot 142 | 143 | 144 | # Parameters 145 | wavelength = 1. # wavelength in micrometer (um) 146 | n_size = (256, 1, 1) # size of simulation domain (in pixels in x, y, and z direction) 147 | n = np.ones(n_size, dtype=np.complex64) # permittivity (refractive index²) map 148 | boundary_widths = 16 # width of the boundary in pixels 149 | 150 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 151 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n 152 | 153 | # Source term. This way is more efficient than dense tensor 154 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain 155 | values = torch.tensor([1.0]) # Amplitude: 1 156 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 157 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 158 | 159 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains) 160 | # 1-domain, periodic boundaries (without wrapping correction) 161 | periodic = (True, True, True) # periodic boundaries, wrapped field. 162 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 163 | # # OR. Uncomment to test domain decomposition 164 | # periodic = (False, True, True) # wrapping correction 165 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=(3, 1, 1)) 166 | 167 | # Run the wavesim iteration and get the computed field 168 | start = time() 169 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000) 170 | end = time() - start 171 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 172 | u_computed = u_computed.squeeze().cpu().numpy()[boundary_widths:-boundary_widths] 173 | u_ref = analytical_solution(n_size[0], domain.pixel_size, wavelength) 174 | 175 | # Compute relative error with respect to the analytical solution 176 | re = relative_error(u_computed, u_ref) 177 | print(f'Relative error: {re:.2e}') 178 | threshold = 1.e-3 179 | assert re < threshold, f"Relative error higher than {threshold}" 180 | 181 | # Plot the results 182 | plot(u_computed, u_ref, re) 183 | ``` 184 | 185 | Apart from the inputs `permittivity`, `source`, and `periodic`, all other parameters have defaults. Details about these are given below (with the default values, if defined). 186 | 187 | Parameters in the `Domain` class: `HelmholtzDomain` (for a single domain without domain decomposition) or `MultiDomain` (to solve a problem with domain decomposition) 188 | 189 | * `permittivity`: 3-dimensional array with refractive index-squared distribution in x, y, and z direction. To set up a 1 or 2-dimensional problem, leave the other dimension(s) as 1. 190 | * `periodic`: indicates for each dimension whether the simulation is periodic (`True`) or not (`False`). 
For periodic dimensions, i.e., `periodic` `= [True, True, True]`, the field is wrapped around the domain. 191 | * `pixel_size` `:float = 0.25`: points per wavelength. 192 | * `wavelength` `:float = None`: wavelength in micrometer (um). If not given, i.e. `None`, it is calculated as `1/pixel_size = 4 um`. 193 | * `n_domains` `: tuple[int, int, int] = (1, 1, 1)`: number of domains to split the simulation into. If the domain size is not divisible by n_domains, the last domain will be slightly smaller than the other ones. If `(1, 1, 1)`, indicates no domain decomposition. 194 | * `n_boundary` `: int = 8`: number of points used in the wrapping and domain transfer correction. Applicable when `periodic` is False in a dimension, or `n_domains > 1` in a dimension. 195 | * `device` `: str = None`: 196 | > * `'cpu'` to use the cpu, 197 | > * `'cuda:x'` to use a specific cuda device 198 | > * `'cuda'` or a list of strings, e.g., `['cuda:0', 'cuda:1']`, to distribute the simulation over the available/given cuda devices in a round-robin fashion 199 | > * `None`, which is equivalent to `'cuda'` if cuda devices are available, and `'cpu'` if they are not. 200 | * `debug` `: bool = False`: set to `True` for testing to return `inverse_propagator_kernel` as output. 201 | 202 | Parameters in the `run_algorithm()` function 203 | 204 | * `domain`: the domain object created by HelmholtzDomain() or MultiDomain() 205 | * `source`: source term, a 3-dimensional array, with the same size as permittivity. Set up amplitude(s) at the desired location(s), following the same principle as permittivity for 1, 2, or 3-dimensional problems. 206 | * `alpha` `: float = 0.75`: relaxation parameter for the Richardson iteration 207 | * `max_iterations` `: int = 1000`: maximum number of iterations 208 | * `threshold` `: float = 1.e-6`: threshold for the residual norm for stopping the iteration 209 | 210 | ## Acknowledgements 211 | 212 | This work was supported by the European Research Council’s Proof of Concept Grant n° [101069402]. 213 | 214 | ## Conflict of interest statement 215 | 216 | The authors declare no conflict of interest. 217 | 218 | ## References 219 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd -------------------------------------------------------------------------------- /docs/source/acknowledgements.rst: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | ---------------- 3 | 4 | This work was supported by the European Research Council's Proof of Concept Grant n° [101069402]. 5 | 6 | Conflict of interest statement 7 | ------------------------------ 8 | The authors declare no conflict of interest. 9 | 10 | References 11 | ---------- 12 | .. bibliography:: 13 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============== 3 | .. toctree:: 4 | :maxdepth: 2 5 | 6 | api_domain 7 | api_helmholtzdomain 8 | api_multidomain 9 | api_iteration 10 | api_utilities -------------------------------------------------------------------------------- /docs/source/api_domain.rst: -------------------------------------------------------------------------------- 1 | wavesim.domain 2 | ---------------------- 3 | .. automodule:: wavesim.domain 4 | :members: 5 | :imported-members: 6 | :undoc-members: 7 | :noindex: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/api_helmholtzdomain.rst: -------------------------------------------------------------------------------- 1 | wavesim.helmholtzdomain 2 | ----------------------- 3 | .. automodule:: wavesim.helmholtzdomain 4 | :members: 5 | :imported-members: 6 | :undoc-members: 7 | :noindex: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/api_iteration.rst: -------------------------------------------------------------------------------- 1 | wavesim.iteration 2 | ---------------------- 3 | .. automodule:: wavesim.iteration 4 | :members: 5 | :imported-members: 6 | :undoc-members: 7 | :noindex: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/api_multidomain.rst: -------------------------------------------------------------------------------- 1 | wavesim.multidomain 2 | ---------------------- 3 | .. automodule:: wavesim.multidomain 4 | :members: 5 | :imported-members: 6 | :undoc-members: 7 | :noindex: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/api_utilities.rst: -------------------------------------------------------------------------------- 1 | wavesim.utilities 2 | ---------------------- 3 | .. 
automodule:: wavesim.utilities 4 | :members: 5 | :imported-members: 6 | :undoc-members: 7 | :noindex: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/conclusion.rst: -------------------------------------------------------------------------------- 1 | Conclusion 2 | ========== 3 | 4 | In this work :cite:`mache2024domain`, we have introduced a domain decomposition of the modified Born series (MBS) approach :cite:`osnabrugge2016convergent, vettenburg2023universal` applied to the Helmholtz equation. With the new framework, we simulated a complex 3D structure of a remarkable :math:`3.1\cdot 10^7` wavelengths in size in just :math:`1.4` hours by solving over two GPUs. This represents a factor of :math:`1.93` increase over the largest possible simulation on a single GPU without domain decomposition. 5 | 6 | Our decomposition framework hinges on the ability to split the linear system as :math:`A=L+V`. Instead of the traditional splitting, where :math:`V` is a scattering potential that acts locally on each voxel, we introduced a :math:`V` that includes the communication between subdomains and corrections for wraparound artefacts. As a result, the operator :math:`(L+I)^{-1}` in the MBS iteration can be evaluated locally on each subdomain using a fast convolution. Therefore, this operator, which is the most computationally intensive step of the iteration, can be evaluated in parallel on multiple GPUs. 7 | 8 | Despite the significant overhead of our domain splitting method due to an increased number of iterations, and communication and synchronisation overhead, the ability to split a simulation over multiple GPUs results in a very significant speedup. Already, with the current dual-GPU system, we were able to solve a problem of :math:`315\times 315\times 315` wavelengths :math:`13.2\times` faster than without domain decomposition since the non-decomposed problem is too large to fit on a single GPU. Moreover, there is only a little overhead associated with adding more subdomains along an axis after the first splitting. This favourable scaling paves the way for distributing simulations over more GPUs or compute nodes in a cluster. 9 | 10 | In this work, we have already introduced strategies to reduce the overhead of the domain decomposition through truncating corrections to only a few points close to the edge of the subdomain and only activating certain subdomains in the iteration. We anticipate that further developments and optimisation of the code may help reduce the overhead of the lock-step execution. 11 | 12 | Finally, due to the generality of our approach, we expect it to be readily extended to include Maxwell's equations :cite:`kruger2017solution` and birefringent media :cite:`vettenburg2019calculating`. Given the rapid developments of GPU hardware and compute clusters, we anticipate that optical simulations at a cubic-millimetre scale can soon be performed in a matter of minutes. 13 | 14 | Code availability 15 | ----------------- 16 | The code for Wavesim is available on GitHub :cite:`wavesim_py`, it is licensed under the MIT license. When using Wavesim in your work, please cite :cite:`mache2024domain, osnabrugge2016convergent`. Examples and documentation for this project are available at `Read the Docs `_ :cite:`wavesim_documentation`. 
17 | 18 | %endmatter% -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Path setup -------------------------------------------------------------- 7 | import os 8 | import sys 9 | import shutil 10 | from pathlib import Path 11 | 12 | from sphinx_markdown_builder import MarkdownBuilder 13 | 14 | # path setup (relevant for both local and read-the-docs builds) 15 | docs_source_dir = os.path.dirname(__file__) 16 | root_dir = os.path.dirname(os.path.dirname(docs_source_dir)) 17 | sys.path.append(docs_source_dir) 18 | sys.path.append(root_dir) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 22 | project = 'wavesim' 23 | copyright = '2024, Ivo Vellekoop, and Swapnil Mache, University of Twente' 24 | # author = 'Swapnil Mache, Ivo M. Vellekoop' 25 | release = '0.1.0-alpha.2' 26 | html_title = "Wavesim - A Python package for wave propagation simulation" 27 | 28 | # -- latex configuration ----------------------------------------------------- 29 | latex_elements = { 30 | 'preamble': r""" 31 | \usepackage{authblk} 32 | """, 33 | 'maketitle': r""" 34 | \author[1,2]{Swapnil~Mache} 35 | \author[1*]{Ivo~M.~Vellekoop} 36 | \affil[1]{University of Twente, Biomedical Photonic Imaging, TechMed Institute, P. O. Box 217, 7500 AE Enschede, The Netherlands} 37 | \affil[2]{Currently at: Rayfos Ltd., Winton House, Winton Square, Basingstoke, United Kingdom} 38 | \affil[*]{Corresponding author: i.m.vellekoop@utwente.nl} 39 | \publishers{% 40 | \normalfont\normalsize% 41 | \parbox{0.8\linewidth}{% 42 | \vspace{0.5cm} 43 | The modified Born series (MBS) method is a fast and accurate method for simulating wave 44 | propagation in complex structures. The major limitation of MBS is that the size of the structure 45 | is limited by the working memory of a single computer or graphics processing unit (GPU). 46 | Through this package, we present a domain decomposition method that removes this limitation. We 47 | decompose large problems over subdomains while maintaining the accuracy, memory efficiency, and guaranteed monotonic convergence of the method. With this work, we have been able to obtain a 48 | factor of $1.93$ increase in size over the single-domain MBS simulations without domain decomposition through a 3D simulation using 2 GPUs. For the Helmholtz problem, we solved a complex structure of size $315 \times 315 \times 315$ wavelengths in just 1.4 hours on a dual-GPU system. 
49 | } 50 | } 51 | \maketitle 52 | """, 53 | 'tableofcontents': "", 54 | 'makeindex': "", 55 | 'printindex': "", 56 | 'figure_align': "", 57 | 'extraclassoptions': 'notitlepage', 58 | } 59 | latex_docclass = { 60 | 'manual': 'scrartcl', 61 | 'howto': 'scrartcl', 62 | } 63 | latex_documents = [("index_latex", 64 | "wavesim.tex", 65 | "Wavesim - A Python package for wave propagation simulation", 66 | "", 67 | "howto")] 68 | latex_toplevel_sectioning = 'section' 69 | bibtex_default_style = 'unsrt' 70 | bibtex_bibfiles = ['references.bib'] 71 | numfig = True 72 | 73 | # -- General configuration --------------------------------------------------- 74 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 75 | extensions = ['sphinx.ext.napoleon', 76 | 'sphinx.ext.autodoc', 77 | 'sphinx.ext.mathjax', 78 | 'sphinx.ext.viewcode', 79 | 'sphinx_autodoc_typehints', 80 | 'sphinxcontrib.bibtex', 81 | 'sphinx.ext.autosectionlabel', 82 | 'sphinx_markdown_builder', 83 | 'sphinx_gallery.gen_gallery'] 84 | 85 | templates_path = ['_templates'] 86 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'acknowledgements.rst', 'sg_execution_times.rst'] 87 | master_doc = '' 88 | include_patterns = ['**'] 89 | napoleon_use_rtype = False 90 | napoleon_use_param = True 91 | typehints_document_rtype = False 92 | latex_engine = 'xelatex' 93 | add_module_names = False 94 | autodoc_preserve_defaults = True 95 | 96 | # -- Options for HTML output ------------------------------------------------- 97 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 98 | html_theme = 'sphinx_rtd_theme' 99 | 100 | # -- Options for sphinx-gallery ---------------------------------------------- 101 | # https://sphinx-gallery.github.io/stable/configuration.html 102 | sphinx_gallery_conf = { 103 | 'examples_dirs': '../../examples', # path to your example scripts 104 | 'ignore_pattern': '__init__.py|check_mem.py|timing_test.py', 105 | 'gallery_dirs': 'auto_examples', # path to where to save gallery generated output 106 | } 107 | 108 | 109 | # -- Monkey-patch the MarkdownTranslator class to support citations ------------ 110 | def visit_citation(self, node): 111 | """Patch-in function for markdown builder to support citations.""" 112 | id = node['ids'][0] 113 | self.add(f'') 114 | 115 | 116 | def visit_label(self, node): 117 | """Patch-in function for markdown builder to support citations.""" 118 | pass 119 | 120 | 121 | def setup(app): 122 | """Setup function for the Sphinx extension.""" 123 | # register event handlers 124 | # app.connect("autodoc-skip-member", skip) 125 | app.connect("build-finished", build_finished) 126 | app.connect("builder-inited", builder_inited) 127 | app.connect("source-read", source_read) 128 | # # app.connect("autodoc-skip-member", skip) 129 | # app.connect("build-finished", copy_readme) 130 | # app.connect("builder-inited", builder_inited) 131 | # app.connect("source-read", source_read) 132 | 133 | # monkey-patch the MarkdownTranslator class to support citations 134 | # TODO: this should be done in the markdown builder itself 135 | cls = MarkdownBuilder.default_translator_class 136 | # cls.visit_citation = visit_citation 137 | # cls.visit_label = visit_label 138 | 139 | 140 | def source_read(app, docname, source): 141 | """Modify the source of the readme and conclusion files based on the builder.""" 142 | if docname == 'readme' or docname == 'conclusion': 143 | if (app.builder.name == 'latex') == (docname == 'conclusion'): 144 | source[0] = 
source[0].replace('%endmatter%', '.. include:: acknowledgements.rst') 145 | else: 146 | source[0] = source[0].replace('%endmatter%', '') 147 | 148 | 149 | def builder_inited(app): 150 | """Set the master document and exclude patterns based on the builder.""" 151 | if app.builder.name == 'html': 152 | exclude_patterns.extend(['conclusion.rst', 'index_latex.rst', 'index_markdown.rst']) 153 | app.config.master_doc = 'index' 154 | elif app.builder.name == 'latex': 155 | exclude_patterns.extend(['auto_examples/*', 'index_markdown.rst', 'index.rst', 'api*']) 156 | app.config.master_doc = 'index_latex' 157 | elif app.builder.name == 'markdown': 158 | include_patterns.clear() 159 | include_patterns.extend(['readme.rst', 'index_markdown.rst']) 160 | app.config.master_doc = 'index_markdown' 161 | 162 | 163 | def build_finished(app, exception): 164 | if exception: 165 | return 166 | 167 | if app.builder.name == "markdown": 168 | # Copy the readme file to the root of the documentation directory. 169 | source_file = Path(app.outdir) / "readme.md" 170 | destination_dir = Path(app.confdir).parents[1] / "README.md" 171 | shutil.copy(source_file, destination_dir) 172 | 173 | elif app.builder.name == "latex": 174 | # The latex builder adds an empty author field to the title page. 175 | # This code removes it. 176 | # Define the path to the .tex file 177 | tex_file = Path(app.outdir) / "wavesim.tex" 178 | 179 | # Read the file 180 | with open(tex_file, "r") as file: 181 | content = file.read() 182 | 183 | # Remove \author{} from the file 184 | content = content.replace(r"\author{}", "") 185 | 186 | # Write the modified content back to the file 187 | with open(tex_file, "w") as file: 188 | file.write(content) 189 | -------------------------------------------------------------------------------- /docs/source/development.rst: -------------------------------------------------------------------------------- 1 | .. _section-development: 2 | 3 | Wavesim Development 4 | ============================================== 5 | 6 | Running the tests and examples 7 | -------------------------------------------------- 8 | To download the source code, including tests and examples, clone the repository from GitHub :cite:`wavesim_py`. Wavesim uses `poetry` :cite:`Poetry` for package management, so you have to download and install Poetry first. Then, navigate to the location where you want to store the source code, and execute the following commands to clone the repository, set up the poetry environment, and run the tests. 9 | 10 | .. code-block:: shell 11 | 12 | git clone https://github.com/IvoVellekoop/wavesim_py 13 | cd wavesim_py 14 | poetry install --with dev --with docs 15 | poetry run pytest 16 | 17 | The examples are located in the ``examples`` directory. Note that a lot of functionality is also demonstrated in the automatic tests located in the ``tests`` directory. As an alternative to downloading the source code, the samples can also be copied directly from the example gallery on the documentation website :cite:`readthedocs_Wavesim`. 18 | 19 | Building the documentation 20 | -------------------------------------------------- 21 | 22 | .. only:: html or markdown 23 | 24 | The html, and pdf versions of the documentation, as well as the `README.md` file in the root directory of the repository, are automatically generated from the docstrings in the source code and reStructuredText source files in the repository. 25 | 26 | .. 
only:: latex 27 | 28 | The html version of the documentation, as well as the `README.md` file in the root directory of the repository, and the pdf document you are currently reading are automatically generated from the docstrings in the source code and reStructuredText source files in the repository. 29 | 30 | Note that for building the pdf version of the documentation, you need to have `xelatex` installed, which comes with the MiKTeX distribution of LaTeX :cite:`MiKTeX`. Then, run the following commands to build the html and pdf versions of the documentation, and to auto-generate `README.md`. 31 | 32 | .. code-block:: shell 33 | 34 | poetry shell 35 | cd docs 36 | make clean 37 | make html 38 | make markdown 39 | make latex 40 | cd _build/latex 41 | xelatex wavesim 42 | xelatex wavesim 43 | 44 | 45 | Reporting bugs and contributing 46 | -------------------------------------------------- 47 | Bugs can be reported through the GitHub issue tracking system. Better than reporting bugs, we encourage users to *contribute bug fixes, optimizations, and other improvements*. These contributions can be made in the form of a pull request :cite:`zandonellaMassiddaOpenScience2022`, which will be reviewed by the development team and integrated into the package when appropriate. Please contact the current development team through GitHub :cite:`wavesim_py` or the `www.wavesim.org `_ forum to coordinate such contributions. 48 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Wavesim - A Python package for wave propagation simulation 2 | ========================================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :numbered: 7 | 8 | Introduction 9 | development 10 | api 11 | auto_examples/index -------------------------------------------------------------------------------- /docs/source/index_latex.rst: -------------------------------------------------------------------------------- 1 | Wavesim - A Python package for wave propagation simulation 2 | ========================================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :numbered: 7 | 8 | Introduction 9 | development 10 | conclusion 11 | auto_examples/index -------------------------------------------------------------------------------- /docs/source/index_markdown.rst: -------------------------------------------------------------------------------- 1 | Wavesim - A Python package for wave propagation simulation 2 | ========================================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :numbered: 7 | 8 | Introduction -------------------------------------------------------------------------------- /docs/source/readme.rst: -------------------------------------------------------------------------------- 1 | .. _root_label: 2 | 3 | Wavesim 4 | ===================================== 5 | 6 | .. 7 | NOTE: README.MD IS AUTO-GENERATED FROM DOCS/SOURCE/README.RST. DO NOT EDIT README.MD DIRECTLY. 8 | 9 | .. only:: html 10 | 11 | .. image:: https://readthedocs.org/projects/wavesim/badge/?version=latest 12 | :target: https://wavesim.readthedocs.io/en/latest/?badge=latest 13 | :alt: Documentation Status 14 | 15 | 16 | What is Wavesim? 17 | ---------------------------------------------- 18 | 19 | Wavesim is a tool to simulate the propagation of waves in complex, inhomogeneous structures. 
Whereas most available solvers use the popular finite difference time domain (FDTD) method :cite:`yee1966numerical, taflove1995computational, oskooi2010meep, nabavi2007new`, Wavesim is based on the modified Born series (MBS) approach, which has lower memory requirements, no numerical dispersion, and is faster as compared to FDTD :cite:`osnabrugge2016convergent, vettenburg2023universal`. 20 | 21 | This package :cite:`wavesim_py` is a Python implementation of the MBS approach for solving the Helmholtz equation in arbitrarily large media through domain decomposition :cite:`mache2024domain`. With this new framework, we simulated a complex 3D structure of a remarkable :math:`315\times 315\times 315` wavelengths :math:`\left( 3.1\cdot 10^7 \right)` in size in just :math:`1.4` hours by solving over two GPUs. This represents a factor of :math:`1.93` increase over the largest possible simulation on a single GPU without domain decomposition. 22 | 23 | When using Wavesim in your work, please cite: 24 | 25 | :cite:`osnabrugge2016convergent` |osnabrugge2016|_ 26 | 27 | :cite:`mache2024domain` |mache2024|_ 28 | 29 | .. _osnabrugge2016: https://doi.org/10.1016/j.jcp.2016.06.034 30 | .. |osnabrugge2016| replace:: Osnabrugge, G., Leedumrongwatthanakun, S., & Vellekoop, I. M. (2016). A convergent Born series for solving the inhomogeneous Helmholtz equation in arbitrarily large media. *Journal of computational physics, 322*\ , 113-124. 31 | 32 | .. _mache2024: https://arxiv.org/abs/2410.02395 33 | .. |mache2024| replace:: Mache, S., & Vellekoop, I. M. (2024). Domain decomposition of the modified Born series approach for large-scale wave propagation simulations. *arXiv preprint arXiv:2410.02395*. 34 | 35 | If you use the code in your research, please cite this repository as well :cite:`wavesim_py`. 36 | 37 | Examples and documentation for this project are available at `Read the Docs `_ :cite:`wavesim_documentation`. For more information (and to participate in the forum for discussions, queries, and requests), please visit our website `www.wavesim.org `_. 38 | 39 | Installation 40 | ---------------------------------------------- 41 | 42 | Wavesim requires `Python >=3.11.0 and <3.13.0 `_ and uses `PyTorch `_ for GPU acceleration. 43 | 44 | First, clone the repository and navigate to the directory:: 45 | 46 | git clone https://github.com/IvoVellekoop/wavesim_py.git 47 | cd wavesim_py 48 | 49 | Then, you can install the dependencies in a couple of ways: 50 | 51 | :ref:`pip_installation` 52 | 53 | :ref:`conda_installation` 54 | 55 | :ref:`poetry_installation` 56 | 57 | We recommend working with a virtual environment to avoid conflicts with other packages. 58 | 59 | .. _pip_installation: 60 | 61 | 1. **Using pip** 62 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 63 | 64 | If you prefer to use pip, you can install the required packages using `requirements.txt `_: 65 | 66 | 1. **Create a virtual environment and activate it** (optional but recommended) 67 | 68 | * First, `create a virtual environment `_ using the following command:: 69 | 70 | python -m venv path/to/venv 71 | 72 | * Then, activate the virtual environment. The command depends on your operating system and shell (`How venvs work `_):: 73 | 74 | source path/to/venv/bin/activate # for Linux/macOS 75 | path/to/venv/Scripts/activate.bat # for Windows (cmd) 76 | path/to/venv/Scripts/Activate.ps1 # for Windows (PowerShell) 77 | 78 | 2. **Install packages**:: 79 | 80 | pip install -r requirements.txt 81 | 82 | .. _conda_installation: 83 | 84 | 2. 
**Using conda** 85 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 86 | 87 | We recommend using `Miniconda `_ (a much lighter counterpart of Anaconda) to install Python and the required packages (contained in `environment.yml `_) within a conda environment. 88 | 89 | 1. **Download Miniconda**, choosing the appropriate `Python installer `_ for your operating system (Windows/macOS/Linux). 90 | 91 | 2. **Install Miniconda**, following the `installation instructions `_ for your OS. Follow the prompts on the installer screens. If you are unsure about any setting, accept the defaults. You can change them later. (If you cannot immediately activate conda, close and re-open your terminal window to make the changes take effect). 92 | 93 | 3. **Test your installation**. Open Anaconda Prompt and run the below command. Alternatively, open an editor like `Visual Studio Code `_ or `PyCharm `_, select the Python interpreter in the ``miniconda3/`` directory with the label ``('base')``, and run the command:: 94 | 95 | conda list 96 | 97 | A list of installed packages appears if it has been installed correctly. 98 | 99 | 4. **Set up a conda environment**. Avoid using the base environment altogether. It is a good backup environment to fall back on if and when the other environments are corrupted/don't work. Create a new environment using `environment.yml `_ and activate:: 100 | 101 | conda env create -f environment.yml 102 | conda activate wavesim 103 | 104 | The `Miniconda environment management guide `_ has more details if you need them. 105 | 106 | Alternatively, you can create a conda environment with a specific Python version, and then use the `requirements.txt `_ file to install the dependencies:: 107 | 108 | conda create -n wavesim python'>=3.11.0,<3.13' 109 | conda activate wavesim 110 | pip install -r requirements.txt 111 | 112 | .. _poetry_installation: 113 | 114 | 3. **Using Poetry** 115 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 116 | 117 | 1. Install `Poetry `_. 118 | 2. Install dependencies by running the following command:: 119 | 120 | poetry install 121 | 122 | To run tests using pytest, you can install the development dependencies as well:: 123 | 124 | poetry install --with dev 125 | 126 | 3. `Activate `_ the virtual environment created by Poetry. 127 | 128 | Running the code 129 | ---------------- 130 | 131 | Once the virtual environment is set up with all the required packages, you are ready to run the code. You can go through any of the scripts in the ``examples`` `directory `_ for the basic steps needed to run a simulation. The directory contains examples of 1D, 2D, and 3D problems. 132 | 133 | You can run the code with just three inputs: 134 | 135 | * :attr:`~.Domain.permittivity`, i.e. refractive index distribution squared (a 3-dimensional array on a regular grid), 136 | 137 | * :attr:`~.Domain.periodic`, a tuple of three booleans to indicate whether the domain is periodic in each dimension [``True``] or not [``False``], and 138 | 139 | * :attr:`~.Domain.source`, the same size as permittivity. 140 | 141 | :numref:`helmholtz_1d_analytical` shows a simple example of a 1D problem with a homogeneous medium (`helmholtz_1d_analytical.py `_) to explain these and other inputs. 142 | 143 | .. _helmholtz_1d_analytical: 144 | .. literalinclude:: ../../examples/helmholtz_1d_analytical.py 145 | :language: python 146 | :caption: ``helmholtz_1d_analytical.py``. A simple example of a 1D problem with a homogeneous medium. 
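
The listing uses a single ``HelmholtzDomain``. As a variation, the same problem can be decomposed into subdomains and distributed over several GPUs by constructing a :class:`~.multidomain.MultiDomain` instead, as in the commented-out lines of the listing. The sketch below is illustrative only: it assumes the ``n_domains`` and ``device`` parameters described in the list that follows are accepted as keyword arguments, and the particular split ``(3, 1, 1)`` and device list are arbitrary example choices, not defaults.

.. code-block:: python

    # Illustrative sketch (not part of the example script): split the problem into
    # three subdomains along x and distribute them over two GPUs in round-robin fashion.
    # n_domains=(3, 1, 1) and device=['cuda:0', 'cuda:1'] are example values, not defaults.
    periodic = (False, True, True)  # x is no longer periodic, so wrapping/transfer corrections apply
    domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength,
                         n_domains=(3, 1, 1), device=['cuda:0', 'cuda:1'])
    u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000)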
147 | 148 | Apart from the inputs :attr:`~.Domain.permittivity`, :attr:`~.Domain.source`, and :attr:`~.Domain.periodic`, all other parameters have defaults. Details about all parameters are given below (with the default values, if defined). 149 | 150 | Parameters in the :class:`~.domain.Domain` class: :class:`~.helmholtzdomain.HelmholtzDomain` (for a single domain without domain decomposition) or :class:`~.multidomain.MultiDomain` (to solve a problem with domain decomposition) 151 | 152 | * :attr:`~.Domain.permittivity`: 3-dimensional array with refractive index-squared distribution in x, y, and z direction. To set up a 1 or 2-dimensional problem, leave the other dimension(s) as 1. 153 | 154 | * :attr:`~.Domain.periodic`: indicates for each dimension whether the simulation is periodic (``True``) or not (``False``). For periodic dimensions, i.e., :attr:`~.Domain.periodic` ``= [True, True, True]``, the field is wrapped around the domain. 155 | 156 | * :attr:`~.pixel_size` ``:float = 0.25``: points per wavelength. 157 | 158 | * :attr:`~.wavelength` ``:float = None``: wavelength in micrometer (um). If not given, i.e. ``None``, it is calculated as ``1/pixel_size = 4 um``. 159 | 160 | * :attr:`~.n_domains` ``: tuple[int, int, int] = (1, 1, 1)``: number of domains to split the simulation into. If the domain size is not divisible by n_domains, the last domain will be slightly smaller than the other ones. If ``(1, 1, 1)``, indicates no domain decomposition. 161 | 162 | * :attr:`~.n_boundary` ``: int = 8``: number of points used in the wrapping and domain transfer correction. Applicable when :attr:`~.Domain.periodic` is False in a dimension, or ``n_domains > 1`` in a dimension. 163 | 164 | * :attr:`~.device` ``: str = None``: 165 | 166 | * ``'cpu'`` to use the cpu, 167 | 168 | * ``'cuda:x'`` to use a specific cuda device 169 | 170 | * ``'cuda'`` or a list of strings, e.g., ``['cuda:0', 'cuda:1']``, to distribute the simulation over the available/given cuda devices in a round-robin fashion 171 | 172 | * ``None``, which is equivalent to ``'cuda'`` if cuda devices are available, and ``'cpu'`` if they are not. 173 | 174 | * :attr:`~.debug` ``: bool = False``: set to ``True`` for testing to return :attr:`~.inverse_propagator_kernel` as output. 175 | 176 | Parameters in the :func:`run_algorithm` function 177 | 178 | * :attr:`~.domain`: the domain object created by HelmholtzDomain() or MultiDomain() 179 | 180 | * :attr:`~.Domain.source`: source term, a 3-dimensional array, with the same size as permittivity. Set up amplitude(s) at the desired location(s), following the same principle as permittivity for 1, 2, or 3-dimensional problems. 
181 | 182 | * :attr:`~.alpha` ``: float = 0.75``: relaxation parameter for the Richardson iteration 183 | 184 | * :attr:`~.max_iterations` ``: int = 1000``: maximum number of iterations 185 | 186 | * :attr:`~.threshold` ``: float = 1.e-6``: threshold for the residual norm for stopping the iteration 187 | 188 | %endmatter% 189 | -------------------------------------------------------------------------------- /docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @article{kruger2017solution, 2 | author = {Benjamin Kr\"{u}ger and Thomas Brenner and Alwin Kienle}, 3 | journal = "{Optics Express}", 4 | year = {2017}, 5 | number = {21}, 6 | pages = {25165--25182}, 7 | publisher = {Optica Publishing Group}, 8 | title = {Solution of the inhomogeneous {M}axwell's equations using a {B}orn series}, 9 | volume = {25}, 10 | month = {Oct}, 11 | urlintro = {https://opg.optica.org/oe/abstract.cfm?URI=oe-25-21-25165}, 12 | doi = {10.1364/OE.25.025165}, 13 | } 14 | 15 | @article{mache2024domain, 16 | author={Swapnil Mache and Ivo M. Vellekoop}, 17 | title={Domain decomposition of the modified Born series approach for large-scale wave propagation simulations}, 18 | journal={arXiv preprint}, 19 | year={2024}, 20 | eprint={2410.02395}, 21 | archivePrefix={arXiv}, 22 | primaryClass={physics.comp-ph}, 23 | urlintro = {https://arxiv.org/abs/2410.02395}, 24 | doiprefix = {10.48550/arXiv.2410.02395} 25 | } 26 | 27 | @misc{MiKTeX, 28 | title = {{MikTeX: An Up-to-Date Implementation of TeX/LaTeX and Related Programs}}, 29 | url = {https://miktex.org/}, 30 | year = {2024}, 31 | note = {Version 24.1} 32 | } 33 | 34 | @article{nabavi2007new, 35 | author = {Majid Nabavi and M.H. Kamran Siddiqui and Javad Dargahi}, 36 | title = {A new 9-point sixth-order accurate compact finite-difference method for the {H}elmholtz equation}, 37 | journal = {Journal of Sound and Vibration}, 38 | year = {2007}, 39 | volume = {307}, 40 | number = {3}, 41 | pages = {972-982}, 42 | issn = {0022-460X}, 43 | doi = {10.1016/j.jsv.2007.06.070}, 44 | urlintro = {https://www.sciencedirect.com/science/article/pii/S0022460X07004877}, 45 | } 46 | 47 | @article{oskooi2010meep, 48 | author = {Ardavan F. Oskooi and David Roundy and Mihai Ibanescu and Peter Bermel and J.D. Joannopoulos and Steven G. 
Johnson}, 49 | title = {{MEEP}: A flexible free-software package for electromagnetic simulations by the {FDTD} method}, 50 | journal = {Computer Physics Communications}, 51 | year = {2010}, 52 | volume = {181}, 53 | number = {3}, 54 | pages = {687-702}, 55 | issn = {0010-4655}, 56 | doi = {10.1016/j.cpc.2009.11.008}, 57 | urlintro = {https://www.sciencedirect.com/science/article/pii/S001046550900383X}, 58 | keywords = {Computational electromagnetism, FDTD, Maxwell solver}, 59 | } 60 | 61 | @article{osnabrugge2016convergent, 62 | author={Osnabrugge, Gerwin and Leedumrongwatthanakun, Saroch and Vellekoop, Ivo M}, 63 | title={A convergent {B}orn series for solving the inhomogeneous {H}elmholtz equation in arbitrarily large media}, 64 | journal={Journal of Computational Physics}, 65 | year={2016}, 66 | volume={322}, 67 | pages={113--124}, 68 | publisher={Elsevier}, 69 | doi={10.1016/j.jcp.2016.06.034}, 70 | } 71 | 72 | @article{osnabrugge2021ultra, 73 | author={Osnabrugge, Gerwin and Benedictus, Maaike and Vellekoop, Ivo M}, 74 | title={Ultra-thin boundary layer for high-accuracy simulations of light propagation}, 75 | journal={Optics express}, 76 | year={2021}, 77 | volume={29}, 78 | number={2}, 79 | pages={1649--1658}, 80 | publisher={Optica Publishing Group}, 81 | doi={10.1364/OE.412833}, 82 | } 83 | 84 | @misc{Poetry, 85 | title = {{Poetry: Python Dependency Management and Packaging Made Easy}}, 86 | author = {Sébastien Eustace and The Poetry Contributors}, 87 | year = {2023}, 88 | note = {Version 1.7.1}, 89 | url = {https://python-poetry.org/}, 90 | } 91 | 92 | @misc{readthedocs_Wavesim, 93 | author = {Ivo M. Vellekoop}, 94 | title = {{W}avesim | {R}ead the {D}ocs}, 95 | url = {https://readthedocs.org/projects/wavesim/}, 96 | note = {[Published DD-MM-2024]}, 97 | } 98 | 99 | @book{taflove1995computational, 100 | author={Taflove, Allen and Hagnes, Susan C}, 101 | title={Computational electrodynamics: The Finite-Difference Time-Domain Method}, 102 | publisher={Artech House}, 103 | year={1995} 104 | } 105 | 106 | @article{vettenburg2019calculating, 107 | author = {T. Vettenburg and S. A. R. Horsley and J. Bertolotti}, 108 | title = {Calculating coherent light-wave propagation in large heterogeneous media}, 109 | journal = {Opt. Express}, 110 | year = {2019}, 111 | month = {Apr}, 112 | volume = {27}, 113 | number = {9}, 114 | pages = {11946--11967}, 115 | publisher = {Optica Publishing Group}, 116 | keywords = {Chiral media; Finite element method; Material properties; Negative index materials; Optical microscopy; Scattering media}, 117 | urlintro = {https://opg.optica.org/oe/abstract.cfm?URI=oe-27-9-11946}, 118 | doi = {10.1364/OE.27.011946}, 119 | } 120 | 121 | @article{vettenburg2023universal, 122 | author={Tom Vettenburg and Ivo M. Vellekoop}, 123 | title={A universal matrix-free split preconditioner for the fixed-point iterative solution of non-symmetric linear systems}, 124 | journal={arXiv preprint}, 125 | year={2023}, 126 | eprint={2207.14222}, 127 | archivePrefix={arXiv}, 128 | primaryClass={math.NA}, 129 | urlintro = {https://arxiv.org/abs/2207.14222}, 130 | doiprefix = {10.48550/arXiv.2207.14222} 131 | } 132 | 133 | @misc{wavesim_documentation, 134 | title = {Wavesim documentation}, 135 | url = {https://wavesim.readthedocs.io/en/latest/}, 136 | } 137 | 138 | @misc{wavesim_py, 139 | author = {Swapnil Mache and Ivo M. 
Vellekoop}, 140 | title = {Wavesim - A Python package for wave propagation simulation}, 141 | url = {https://github.com/IvoVellekoop/wavesim_py} 142 | } 143 | 144 | @misc{wavesim_matlab, 145 | author = {G. Osnabrugge and I. M. Vellekoop}, 146 | title = "{WaveSim}", 147 | note = {\url{https://github.com/ivovellekoop/wavesim}}, 148 | year = {2020} 149 | } 150 | 151 | @article{yee1966numerical, 152 | author={Kane Yee}, 153 | title={Numerical solution of initial boundary value problems involving {M}axwell's equations in isotropic media}, 154 | journal={IEEE Transactions on Antennas and Propagation}, 155 | year={1966}, 156 | volume={14}, 157 | number={3}, 158 | pages={302-307}, 159 | doi={10.1109/TAP.1966.1138693} 160 | } 161 | 162 | @book{zandonellaMassiddaOpenScience2022, 163 | title = {The Open Science Manual: Make Your Scientific Research Accessible and Reproducible}, 164 | author = {Zandonella Callegher, Claudio and Massidda, Davide}, 165 | year = {2022}, 166 | publisher = {}, 167 | urlintro = {https://arca-dpss.github.io/manual-open-science/}, 168 | doi = {10.5281/zenodo.6521850} 169 | } 170 | -------------------------------------------------------------------------------- /docs/tex.bat: -------------------------------------------------------------------------------- 1 | .\make.bat clean & .\make.bat latex & cd _build\latex & xelatex wavesim.tex & xelatex wavesim.tex & cd .. & cd .. -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: wavesim 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | dependencies: 7 | - python>=3.11.0,<3.13.0 8 | - numpy<2.0.0 9 | - matplotlib 10 | - scipy 11 | - pytest 12 | - pytorch 13 | - pytorch-cuda 14 | - porespy 15 | - scikit-image<0.23 16 | -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | Example gallery 2 | ===================== 3 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from scipy.signal.windows import gaussian 5 | from scipy.fft import fftn, ifftn, fftshift 6 | from wavesim.utilities import normalize, relative_error 7 | 8 | 9 | def random_permittivity(shape): 10 | np.random.seed(0) # Set the random seed for reproducibility 11 | n = (1.0 + np.random.rand(*shape).astype(np.float32) + 12 | 0.03j * np.random.rand(*shape).astype(np.float32)) # Random refractive index 13 | n = smooth_n(n, shape) # Low pass filter to remove sharp edges 14 | n.real = normalize(n.real, a=1.0, b=2.0) # Normalize to [1, 2] 15 | n.imag = normalize(n.imag, a=0.0, b=0.03) # Normalize to [0, 0.03] 16 | # make sure that the imaginary part of n² is positive 17 | mask = (n ** 2).imag < 0 18 | n.imag[mask] *= -1.0 19 | 20 | n[0:5, :, :] = 1 21 | n[-5:, :, :] = 1 22 | return n ** 2 23 | 24 | 25 | def smooth_n(n, shape): 26 | """Low pass filter to remove sharp edges""" 27 | n_fft = fftn(n) 28 | w = (window(shape[1]).T @ window(shape[0])).T[:, :, None] * window(shape[2]).reshape(1, 1, shape[2]) 29 | n = ifftn(n_fft * fftshift(w)) 30 | n = np.clip(n.real, a_min=1.0, a_max=None) + 1.0j * np.clip(n.imag, a_min=0.0, a_max=None) 31 | 32 | assert (n ** 2).imag.min() >= 0, 'Imaginary part of 
n² is negative' 33 | assert n.shape == shape, 'n and shape do not match' 34 | assert n.dtype == np.complex64, 'n is not complex64' 35 | return n 36 | 37 | 38 | def window(x): 39 | """Create a window function for low pass filtering""" 40 | c0 = round(x / 4) 41 | cl = (x - c0) // 2 42 | cr = cl 43 | if c0 + cl + cr != x: 44 | c0 = x - cl - cr 45 | return np.concatenate((np.zeros((1, cl), dtype=np.complex64), 46 | np.ones((1, c0), dtype=np.complex64), 47 | np.zeros((1, cr), dtype=np.complex64)), axis=1) 48 | 49 | 50 | def construct_source(source_type, at, shape): 51 | if source_type == 'point': 52 | return torch.sparse_coo_tensor(at[:, None], torch.tensor([1.0]), shape, dtype=torch.complex64) 53 | elif source_type == 'plane_wave': 54 | return source_plane_wave(at, shape) 55 | elif source_type == 'gaussian_beam': 56 | return source_gaussian_beam(at, shape) 57 | else: 58 | raise ValueError(f"Unknown source type: {source_type}") 59 | 60 | 61 | def source_plane_wave(at, shape): 62 | """ Set up source, with size same as permittivity, 63 | and a plane wave source on one edge of the domain """ 64 | # TODO: use CSR format instead? 65 | data = np.ones((1, shape[1], shape[2]), dtype=np.float32) # the source itself 66 | return torch.sparse_coo_tensor(at, data, shape, dtype=torch.complex64) 67 | 68 | 69 | def source_gaussian_beam(at, shape): 70 | """ Set up source, with size same as permittivity, 71 | and a Gaussian beam source on one edge of the domain """ 72 | # TODO: use CSR format instead? 73 | std = (shape[1] - 1) / (2 * 3) 74 | source_amplitude = gaussian(shape[1], std).astype(np.float32) 75 | source_amplitude = np.outer(gaussian(shape[2], std).astype(np.float32), 76 | source_amplitude.astype(np.float32)) 77 | source_amplitude = torch.tensor(source_amplitude[None, ...]) 78 | data = torch.zeros((1, shape[1], shape[2]), dtype=torch.complex64) 79 | data[0, 80 | 0:shape[1], 81 | 0:shape[2]] = source_amplitude 82 | return torch.sparse_coo_tensor(at, data, shape, dtype=torch.complex64) 83 | 84 | 85 | def plot(x, x_ref, re=None, normalize_x=True): 86 | """Plot the computed field x and the reference field x_ref. 87 | If x and x_ref are 1D arrays, the real and imaginary parts are plotted separately. 88 | If x and x_ref are 2D arrays, the absolute values are plotted. 89 | If x and x_ref are 3D arrays, the central slice is plotted. 90 | If normalize_x is True, the values are normalized to the same range. 91 | The relative error is (computed, if needed, and) displayed. 92 | """ 93 | 94 | re = relative_error(x, x_ref) if re is None else re 95 | 96 | if x.ndim == 1 and x_ref.ndim == 1: 97 | plt.subplot(211) 98 | plt.plot(x_ref.real, label='Analytic') 99 | plt.plot(x.real, label='Computed') 100 | plt.legend() 101 | plt.title(f'Real part (RE = {relative_error(x.real, x_ref.real):.2e})') 102 | plt.grid() 103 | 104 | plt.subplot(212) 105 | plt.plot(x_ref.imag, label='Analytic') 106 | plt.plot(x.imag, label='Computed') 107 | plt.legend() 108 | plt.title(f'Imaginary part (RE = {relative_error(x.imag, x_ref.imag):.2e})') 109 | plt.grid() 110 | 111 | plt.suptitle(f'Relative error (RE) = {re:.2e}') 112 | plt.tight_layout() 113 | plt.show() 114 | else: 115 | if x.ndim == 3 and x_ref.ndim == 3: 116 | x = x[x.shape[0]//2, ...] 117 | x_ref = x_ref[x_ref.shape[0]//2, ...] 
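        # For 3-D inputs only the central slice along axis 0 is kept at this point, so the
        # normalization and imshow comparison below operate on 2-D arrays (e.g. the
        # (128, 48, 96) field of the 3-D disordered example reduces to a (48, 96) slice).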
118 | 119 | x = np.abs(x) 120 | x_ref = np.abs(x_ref) 121 | if normalize_x: 122 | min_val = min(np.min(x), np.min(x_ref)) 123 | max_val = max(np.max(x), np.max(x_ref)) 124 | a = 0 125 | b = 1 126 | x = normalize(x, min_val, max_val, a, b) 127 | x_ref = normalize(x_ref, min_val, max_val, a, b) 128 | else: 129 | a = None 130 | b = None 131 | 132 | cmap = 'inferno' 133 | 134 | plt.figure(figsize=(10, 5)) 135 | plt.subplot(121) 136 | plt.imshow(x_ref, cmap=cmap, vmin=a, vmax=b) 137 | plt.colorbar(fraction=0.046, pad=0.04) 138 | plt.title('Reference') 139 | 140 | plt.subplot(122) 141 | plt.imshow(x, cmap=cmap, vmin=a, vmax=b) 142 | plt.colorbar(fraction=0.046, pad=0.04) 143 | plt.title('Computed') 144 | 145 | plt.suptitle(f'Relative error (RE) = {re:.2e}') 146 | plt.tight_layout() 147 | plt.show() 148 | -------------------------------------------------------------------------------- /examples/check_mem.py: -------------------------------------------------------------------------------- 1 | from torch.cuda.memory import _record_memory_history, _dump_snapshot 2 | from torch import zeros, complex64, cat 3 | from torch.cuda import empty_cache 4 | 5 | _record_memory_history(True, trace_alloc_max_entries=100000, 6 | trace_alloc_record_context=True) 7 | 8 | t = [None] * 6 9 | t[0] = zeros(8, 500, 500, dtype=complex64, device='cuda') 10 | t[1] = zeros(8, 500, 500, dtype=complex64, device='cuda') 11 | 12 | s = cat([i for i in t if i is not None], dim=0) 13 | print(s.shape) 14 | del s 15 | # s = t[0].clone() 16 | 17 | # for tt in t: 18 | # if tt is not None: 19 | # del tt 20 | 21 | # del s 22 | 23 | _dump_snapshot(f"logs/mem_snapshot.pickle") 24 | _record_memory_history(enabled=None) 25 | -------------------------------------------------------------------------------- /examples/helmholtz_1d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helmholtz 1D example with glass plate 3 | ===================================== 4 | Test for 1D propagation through glass plate. 5 | Compare with reference solution (matlab repo result). 6 | """ 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | from time import time 12 | from scipy.io import loadmat 13 | import sys 14 | sys.path.append(".") 15 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 16 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 17 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 18 | from wavesim.utilities import preprocess, relative_error 19 | from __init__ import plot 20 | 21 | if os.path.basename(os.getcwd()) == 'examples': 22 | os.chdir('..') 23 | 24 | 25 | # Parameters 26 | wavelength = 1. # wavelength in micrometer (um) 27 | n_size = (256, 1, 1) # size of simulation domain (in pixels in x, y, and z direction) 28 | n = np.ones(n_size, dtype=np.complex64) # refractive index map 29 | n[99:130] = 1.5 # glass plate 30 | boundary_widths = 24 # width of the boundary in pixels 31 | 32 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 33 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n 34 | 35 | # Source term. 
This way is more efficient than dense tensor 36 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain 37 | values = torch.tensor([1.0]) # Amplitude: 1 38 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 39 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 40 | 41 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains) 42 | # 1-domain, periodic boundaries (without wrapping correction) 43 | periodic = (True, True, True) # periodic boundaries, wrapped field. 44 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 45 | # # OR. Uncomment to test domain decomposition 46 | # periodic = (False, True, True) # wrapping correction 47 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=1., n_domains=(2, 1, 1)) 48 | 49 | # Run the wavesim iteration and get the computed field 50 | start = time() 51 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000) 52 | end = time() - start 53 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 54 | u_computed = u_computed.squeeze().cpu().numpy()[boundary_widths:-boundary_widths] 55 | 56 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 57 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u']) 58 | 59 | # Compute relative error with respect to the reference solution 60 | re = relative_error(u_computed, u_ref) 61 | print(f'Relative error: {re:.2e}') 62 | threshold = 1.e-3 63 | assert re < threshold, f"Relative error higher than {threshold}" 64 | 65 | # Plot the results 66 | plot(u_computed, u_ref, re) 67 | -------------------------------------------------------------------------------- /examples/helmholtz_1d_analytical.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helmholtz 1D analytical test 3 | ============================ 4 | Test to compare the result of Wavesim to analytical results. 5 | Compare 1D free-space propagation with analytic solution. 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | from time import time 11 | import sys 12 | sys.path.append(".") 13 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 14 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 15 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 16 | from wavesim.utilities import analytical_solution, preprocess, relative_error 17 | from __init__ import plot 18 | 19 | 20 | # Parameters 21 | wavelength = 1. # wavelength in micrometer (um) 22 | n_size = (256, 1, 1) # size of simulation domain (in pixels in x, y, and z direction) 23 | n = np.ones(n_size, dtype=np.complex64) # permittivity (refractive index²) map 24 | boundary_widths = 16 # width of the boundary in pixels 25 | 26 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 27 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n 28 | 29 | # Source term. 
This way is more efficient than dense tensor 30 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain 31 | values = torch.tensor([1.0]) # Amplitude: 1 32 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 33 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 34 | 35 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains) 36 | # 1-domain, periodic boundaries (without wrapping correction) 37 | periodic = (True, True, True) # periodic boundaries, wrapped field. 38 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 39 | # # OR. Uncomment to test domain decomposition 40 | # periodic = (False, True, True) # wrapping correction 41 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=(3, 1, 1)) 42 | 43 | # Run the wavesim iteration and get the computed field 44 | start = time() 45 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000) 46 | end = time() - start 47 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 48 | u_computed = u_computed.squeeze().cpu().numpy()[boundary_widths:-boundary_widths] 49 | u_ref = analytical_solution(n_size[0], domain.pixel_size, wavelength) 50 | 51 | # Compute relative error with respect to the analytical solution 52 | re = relative_error(u_computed, u_ref) 53 | print(f'Relative error: {re:.2e}') 54 | threshold = 1.e-3 55 | assert re < threshold, f"Relative error higher than {threshold}" 56 | 57 | # Plot the results 58 | plot(u_computed, u_ref, re) 59 | -------------------------------------------------------------------------------- /examples/helmholtz_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helmholtz 2D high contrast test 3 | =============================== 4 | Test for propagation in 2D structure made of iron, 5 | with high refractive index contrast. 6 | Compare with reference solution (matlab repo result). 
7 | """ 8 | 9 | import os 10 | import torch 11 | import numpy as np 12 | from time import time 13 | from scipy.io import loadmat 14 | from PIL.Image import BILINEAR, fromarray, open 15 | import sys 16 | sys.path.append(".") 17 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 18 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 19 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 20 | from wavesim.utilities import pad_boundaries, preprocess, relative_error 21 | from __init__ import plot 22 | 23 | if os.path.basename(os.getcwd()) == 'examples': 24 | os.chdir('..') 25 | 26 | 27 | # Parameters 28 | n_iron = 2.8954 + 2.9179j 29 | n_contrast = n_iron - 1 30 | wavelength = 0.532 # Wavelength in micrometers 31 | pixel_size = wavelength / (3 * np.max(abs(n_contrast + 1))) # Pixel size in wavelength units 32 | 33 | # Load image and create refractive index map 34 | oversampling = 0.25 35 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255 36 | n_im = ((np.where(im[:, :, 2] > 0.25, 1, 0) * n_contrast) + 1) # Refractive index map 37 | n_roi = int(oversampling * n_im.shape[0]) # Size of ROI in pixels 38 | n = np.asarray(fromarray(n_im.real).resize((n_roi, n_roi), BILINEAR)) + 1j * np.asarray( 39 | fromarray(n_im.imag).resize((n_roi, n_roi), BILINEAR)) # Refractive index map 40 | boundary_widths = 8 # Width of the boundary in pixels 41 | 42 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 43 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n 44 | 45 | # Source term 46 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR)) 47 | source = pad_boundaries(source, boundary_array) 48 | source = torch.tensor(source, dtype=torch.complex64) 49 | 50 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains) 51 | # 1-domain, periodic boundaries (without wrapping correction) 52 | periodic = (True, True, True) # periodic boundaries, wrapped field. 53 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength) 54 | # # OR. 
Uncomment to test domain decomposition 55 | # periodic = (True, False, True) # wrapping correction 56 | # domain = MultiDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength, 57 | # n_domains=(1, 2, 1)) 58 | 59 | # Run the wavesim iteration and get the computed field 60 | start = time() 61 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=int(1.e+5)) 62 | end = time() - start 63 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 64 | u_computed = u_computed.squeeze().cpu().numpy()[*([slice(boundary_widths, -boundary_widths)]*2)] 65 | 66 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 67 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_hc']) 68 | 69 | # Compute relative error with respect to the reference solution 70 | re = relative_error(u_computed, u_ref) 71 | print(f'Relative error: {re:.2e}') 72 | threshold = 1.e-3 73 | assert re < threshold, f"Relative error {re} higher than {threshold}" 74 | 75 | # Plot the results 76 | plot(u_computed, u_ref, re) 77 | -------------------------------------------------------------------------------- /examples/helmholtz_2d_homogeneous.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helmholtz 2D homogeneous medium test 3 | ==================================== 4 | Test for propagation in 2D homogeneous medium. 5 | """ 6 | 7 | import torch 8 | import numpy as np 9 | from time import time 10 | import sys 11 | sys.path.append(".") 12 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 13 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 14 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 15 | from wavesim.utilities import preprocess 16 | 17 | 18 | # Parameters 19 | wavelength = 1. # Wavelength in micrometers 20 | pixel_size = wavelength/4 # Pixel size in wavelength units 21 | # Size of simulation domain (in pixels in x, y, and z direction) 22 | sim_size = np.array([50, 50, 1]) # Simulation size in micrometers 23 | n_size = (sim_size * wavelength/pixel_size).astype(int) 24 | n = np.ones(n_size, dtype=np.complex64) # Refractive index map 25 | boundary_widths = 8 # Width of the boundary in pixels 26 | 27 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 28 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n 29 | 30 | # Source term. This way is more efficient than dense tensor 31 | indices = torch.tensor([[int(v/2 - 1) + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain 32 | values = torch.tensor([1.0]) # Amplitude: 1 33 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 34 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 35 | 36 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains) 37 | # 1-domain, periodic boundaries (without wrapping correction) 38 | periodic = (True, True, True) # periodic boundaries, wrapped field. 39 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 40 | # # OR. 
Uncomment to test domain decomposition 41 | # periodic = (False, True, True) # wrapping correction 42 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, 43 | # n_domains=(2, 1, 1)) 44 | 45 | # Run the wavesim iteration and get the computed field 46 | start = time() 47 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=2000) 48 | end = time() - start 49 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 50 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*2)] 51 | -------------------------------------------------------------------------------- /examples/helmholtz_2d_low_contrast.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helmholtz 2D low contrast test 3 | =============================== 4 | Test for propagation in 2D structure with low refractive index contrast 5 | (made of fat and water to mimic biological tissue). 6 | Compare with reference solution (matlab repo result). 7 | """ 8 | 9 | import os 10 | import torch 11 | import numpy as np 12 | from time import time 13 | from scipy.io import loadmat 14 | from PIL.Image import BILINEAR, fromarray, open 15 | import sys 16 | sys.path.append(".") 17 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 18 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 19 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 20 | from wavesim.utilities import pad_boundaries, preprocess, relative_error 21 | from __init__ import plot 22 | 23 | if os.path.basename(os.getcwd()) == 'examples': 24 | os.chdir('..') 25 | 26 | 27 | # Parameters 28 | n_water = 1.33 29 | n_fat = 1.46 30 | wavelength = 0.532 # Wavelength in micrometers 31 | pixel_size = wavelength / (3 * abs(n_fat)) # Pixel size in wavelength units 32 | 33 | # Load image and create refractive index map 34 | oversampling = 0.25 35 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255 36 | n_im = (np.where(im[:, :, 2] > 0.25, 1, 0) * (n_fat - n_water)) + n_water 37 | n_roi = int(oversampling * n_im.shape[0]) # Size of ROI in pixels 38 | n = np.asarray(fromarray(n_im).resize((n_roi, n_roi), BILINEAR)) # Refractive index map 39 | boundary_widths = 40 # Width of the boundary in pixels 40 | 41 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 42 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n 43 | 44 | # Source term 45 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR)) 46 | source = pad_boundaries(source, boundary_array) 47 | source = torch.tensor(source, dtype=torch.complex64) 48 | 49 | # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains) 50 | # 1-domain, periodic boundaries (without wrapping correction) 51 | periodic = (True, True, True) # periodic boundaries, wrapped field. 52 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength) 53 | # # OR. 
Uncomment to test domain decomposition 54 | # periodic = (False, True, True) # wrapping correction 55 | # domain = MultiDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength, 56 | # n_domains=(3, 1, 1)) 57 | 58 | # Run the wavesim iteration and get the computed field 59 | start = time() 60 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=10000) 61 | end = time() - start 62 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 63 | u_computed = u_computed.squeeze().cpu().numpy()[*([slice(boundary_widths, -boundary_widths)]*2)] 64 | 65 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 66 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_lc']) 67 | 68 | # Compute relative error with respect to the reference solution 69 | re = relative_error(u_computed, u_ref) 70 | print(f'Relative error: {re:.2e}') 71 | threshold = 1.e-3 72 | assert re < threshold, f"Relative error higher than {threshold}" 73 | 74 | # Plot the results 75 | plot(u_computed, u_ref, re) 76 | -------------------------------------------------------------------------------- /examples/helmholtz_3d_disordered.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helmholtz 3D disordered medium test 3 | =================================== 4 | Test for propagation in a 3D disordered medium. 5 | Compare with reference solution (matlab repo result). 6 | """ 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | from time import time 12 | from scipy.io import loadmat 13 | import sys 14 | sys.path.append(".") 15 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 16 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 17 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 18 | from wavesim.utilities import preprocess, relative_error 19 | from __init__ import plot 20 | 21 | if os.path.basename(os.getcwd()) == 'examples': 22 | os.chdir('..') 23 | 24 | 25 | # Parameters 26 | wavelength = 1. # Wavelength in micrometers 27 | n_size = (128, 48, 96) # Size of the domain in pixels (x, y, z) 28 | n = np.ascontiguousarray(loadmat('examples/matlab_results.mat')['n3d_disordered']) # Refractive index map 29 | boundary_widths = 8 # Width of the boundary in pixels 30 | 31 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 32 | n, boundary_array = preprocess(n**2, boundary_widths) # permittivity is n², but uses the same variable n 33 | 34 | # Source term. This way is more efficient than dense tensor 35 | indices = torch.tensor([[int(v/2 - 1) + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain 36 | values = torch.tensor([1.0]) # Amplitude: 1 37 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 38 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 39 | 40 | # # Set up the domain operators (HelmholtzDomain() or MultiDomain() depending on number of domains) 41 | # # 1-domain, periodic boundaries (without wrapping correction) 42 | # periodic = (True, True, True) # periodic boundaries, wrapped field. 43 | # domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 44 | # OR. 
Uncomment to test domain decomposition 45 | periodic = (False, True, True)  # wrapping correction 46 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, 47 | n_domains=(2, 1, 1)) 48 | 49 | # Run the wavesim iteration and get the computed field 50 | start = time() 51 | u_computed, iterations, residual_norm = run_algorithm(domain, source, max_iterations=1000) 52 | end = time() - start 53 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 54 | u_computed = u_computed.squeeze().cpu().numpy()[*([slice(boundary_widths, -boundary_widths)]*3)] 55 | 56 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 57 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u3d_disordered']) 58 | 59 | re = relative_error(u_computed, u_ref) 60 | print(f'Relative error: {re:.2e}') 61 | threshold = 1.e-3 62 | assert re < threshold, f"Relative error {re} higher than {threshold}" 63 | 64 | # Plot the results 65 | plot(u_computed, u_ref, re) 66 | -------------------------------------------------------------------------------- /examples/logo_structure_vector.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/examples/logo_structure_vector.png -------------------------------------------------------------------------------- /examples/matlab_results.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/examples/matlab_results.mat -------------------------------------------------------------------------------- /examples/mem_snapshot.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch memory snapshot example 3 | =============================== 4 | This script captures a memory snapshot of the GPU memory usage during the simulation. 5 | """ 6 | 7 | import os 8 | import sys 9 | import torch 10 | import platform 11 | import numpy as np 12 | from time import time 13 | from paper_code.__init__ import random_refractive_index, construct_source 14 | sys.path.append(".") 15 | from wavesim.helmholtzdomain import HelmholtzDomain 16 | from wavesim.multidomain import MultiDomain 17 | from wavesim.iteration import run_algorithm 18 | from wavesim.utilities import preprocess 19 | 20 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 21 | os.environ["TORCH_USE_CUDA_DSA"] = "1" 22 | if os.path.basename(os.getcwd()) == 'examples': 23 | os.chdir('..') 24 | os.makedirs("logs", exist_ok=True) 25 | 26 | 27 | def is_supported_platform(): 28 | return platform.system().lower() == "linux" and sys.maxsize > 2**32 29 | 30 | 31 | if is_supported_platform(): 32 | torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=100000, 33 | trace_alloc_record_context=True) 34 | else: 35 | print(f"PyTorch memory snapshot functionality is not supported on {platform.system()} (non-linux non-x86_64 platforms).") 36 | # On Windows, gives "RuntimeError: record_context_cpp is not supported on non-linux non-x86_64 platforms" 37 | 38 | # generate a refractive index map 39 | sim_size = 100 * np.array([2, 1, 1])  # Simulation size in micrometers 40 | wavelength = 1.
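# Worked example of the grid sizes used below (assuming the values set on the following
# lines, pixel_size = 0.25 and boundary_widths = 20, are kept): sim_size = (200, 100, 100)
# micrometers at 4 pixels per wavelength gives (800, 400, 400) pixels; subtracting
# 2 * 20 boundary pixels per axis leaves n_size = (760, 360, 360), and preprocess()
# further below pads the permittivity back up to (800, 400, 400).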
41 | pixel_size = 0.25 42 | boundary_widths = 20 43 | n_dims = len(sim_size.squeeze()) 44 | 45 | # Size of the simulation domain in pixels 46 | n_size = sim_size * wavelength / pixel_size 47 | n_size = n_size - 2 * boundary_widths  # Subtract the boundary widths 48 | n_size = tuple(n_size.astype(int))  # Convert to integer for indexing 49 | 50 | n = random_refractive_index(n_size) 51 | print(f"Size of n: {n_size}") 52 | print(f"Size of n in GB: {n.nbytes / (1024**3):.2f}") 53 | assert n.imag.min() >= 0, 'Imaginary part of n is negative' 54 | 55 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 56 | n, boundary_array = preprocess(n**2, boundary_widths)  # permittivity is n², but uses the same variable n 57 | assert n.imag.min() >= 0, 'Imaginary part of n² is negative' 58 | 59 | source = construct_source(n_size, boundary_array) 60 | 61 | n_domains = (2, 1, 1)  # number of domains in each direction 62 | periodic = (False, True, True)  # True for 1 domain in that direction, False otherwise 63 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size, 64 | n_domains=n_domains) 65 | 66 | # Run the wavesim iteration and get the computed field 67 | start = time() 68 | u, iterations, residual_norm = run_algorithm(domain, source, max_iterations=5) 69 | end = time() - start 70 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 71 | 72 | if is_supported_platform(): 73 | try: 74 | torch.cuda.memory._dump_snapshot(f"logs/mem_snapshot_cluster.pickle") 75 | # To view the memory snapshot, go to this link in a browser window: https://pytorch.org/memory_viz 76 | # Then drag and drop the file "mem_snapshot_cluster.pickle" into the browser window. 77 | # From the dropdown menus in the top left corner, open the second one and select "Allocator State History". 78 | except Exception as e: 79 | # logger.error(f"Failed to capture memory snapshot {e}") 80 | print(f"Failed to capture memory snapshot {e}") 81 | 82 | # Stop recording memory snapshot history 83 | torch.cuda.memory._record_memory_history(enabled=None) 84 | 85 | 86 | # %% Postprocessing 87 | 88 | file_name = 'logs/size' 89 | for i in range(n_dims): 90 | file_name += f'{n_size[i]}_' 91 | file_name += f'bw{boundary_widths}_domains' 92 | for i in range(n_dims): 93 | file_name += f'{n_domains[i]}' 94 | 95 | output = (f'Size {n_size}; Boundaries {boundary_widths}; Domains {n_domains}; ' 96 | + f'Time {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e} \n') 97 | with open('logs/output.txt', 'a') as file: 98 | file.write(output) 99 | -------------------------------------------------------------------------------- /examples/paper_code/fig1_splitting.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script generates a figure showing the splitting of matrices A, L, and V. 3 | 4 | It sets up the problem parameters, computes the matrices A, L, and V, 5 | and visualizes them in a figure with four subplots: 6 | - Subplot (a): Matrix A 7 | - Subplot (b): Matrix L without wraparound artifacts 8 | - Subplot (c): Matrix V 9 | - Subplot (d): Matrix L with wraparound artifacts 10 | 11 | The figure is saved as a PDF file.
12 | """ 13 | 14 | # import packages 15 | import os 16 | import sys 17 | import torch 18 | import numpy as np 19 | import matplotlib.pyplot as plt 20 | from matplotlib import rc, rcParams, colors 21 | 22 | sys.path.append(".") 23 | sys.path.append("..") 24 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 25 | from wavesim.utilities import full_matrix, normalize, preprocess 26 | from wavesim.iteration import domain_operator 27 | 28 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13} 29 | rc('font', **font) 30 | rcParams['mathtext.fontset'] = 'cm' 31 | 32 | if os.path.basename(os.getcwd()) == 'paper_code': 33 | os.chdir('..') 34 | os.makedirs('paper_figures', exist_ok=True) 35 | filename = 'paper_figures/fig1_splitting.pdf' 36 | else: 37 | try: 38 | os.makedirs('examples/paper_figures', exist_ok=True) 39 | filename = 'examples/paper_figures/fig1_splitting.pdf' 40 | except FileNotFoundError: 41 | filename = 'fig1_splitting.pdf' 42 | 43 | # Define problem parameters 44 | boundary_widths = 0 45 | n_size = (40, 1, 1) 46 | # Random refractive index distribution 47 | torch.manual_seed(0) # Set the random seed for reproducibility 48 | n = (torch.normal(mean=1.3, std=0.1, size=n_size, dtype=torch.float32) 49 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=n_size, dtype=torch.float32))) 50 | assert n.imag.min() >= 0, 'Imaginary part of n is negative' 51 | 52 | wavelength = 1. 53 | pixel_size = wavelength / 4 54 | periodic = (True, True, True) 55 | 56 | # Get matrices of A, L (with wrapping artifacts), and V operators 57 | domain = HelmholtzDomain(permittivity=n**2, periodic=periodic, 58 | wavelength=wavelength, pixel_size=pixel_size, 59 | debug=True) 60 | domain.set_source(0) 61 | b = full_matrix(domain_operator(domain, 'medium')).cpu().numpy() # B = I - V 62 | l_plus1 = full_matrix(domain_operator(domain, 'inverse_propagator')).cpu().numpy() # L + I 63 | 64 | I = np.eye(np.prod(n_size), dtype=np.complex64) 65 | v = I - b 66 | l = l_plus1 - I 67 | 68 | # Get matrices of A, L, and V for large domain without wraparound artifacts 69 | boundary_widths = 100 70 | n, boundary_array = preprocess((n**2).numpy(), boundary_widths) 71 | assert n.imag.min() >= 0, 'Imaginary part of n is negative' 72 | domain_o = HelmholtzDomain(permittivity=n, periodic=periodic, 73 | wavelength=wavelength, pixel_size=pixel_size, 74 | debug=True) 75 | 76 | crop2roi = (slice(boundary_array[0], -boundary_array[0]), 77 | slice(boundary_array[0], -boundary_array[0])) 78 | b_o = full_matrix(domain_operator(domain_o, 'medium'))[crop2roi].cpu().numpy() # B = I - V 79 | l_plus1_o = full_matrix(domain_operator(domain_o, 'inverse_propagator'))[crop2roi].cpu().numpy() # L + I 80 | 81 | v_o = (I - b_o) / domain_o.scale + domain_o.shift.item()*I # V = (I - B) / scaling 82 | l_o = (l_plus1_o - I) / domain_o.scale - domain_o.shift.item()*I # L = (L + I - I) / scaling 83 | 84 | v_o = (v_o - domain.shift.item()*I) * domain.scale 85 | l_o = (l_o + domain.shift.item()*I) * domain.scale 86 | 87 | a_o = l_o + v_o # A = L + V 88 | 89 | # Normalize matrices for visualization 90 | l = l.imag 91 | v = v.imag 92 | a_o = a_o.imag 93 | l_o = l_o.imag 94 | 95 | max_val = max(np.max(l), np.max(a_o), np.max(l_o), np.max(v)) 96 | min_val = min(np.min(l), np.min(a_o), np.min(l_o), np.min(v)) 97 | extremum = max(abs(min_val), abs(max_val)) 98 | vmin = -1 99 | vmax = 1 100 | 101 | l = normalize(l, -extremum, extremum, vmin, vmax) 102 | v = normalize(v, -extremum, extremum, vmin, vmax) 103 | a_o = 
normalize(a_o, -extremum, extremum, vmin, vmax) 104 | l_o = normalize(l_o, -extremum, extremum, vmin, vmax) 105 | 106 | # Create a figure with four subplots in one row 107 | fig, axs = plt.subplots(1, 4, figsize=(12, 3), sharex=True, sharey=True, 108 | gridspec_kw={'wspace': 0.15, 'width_ratios': [1, 1, 1, 1.094]}) 109 | fraction = 0.046 110 | pad = 0.04 111 | extent = np.array([0, n_size[0], n_size[0], 0]) # * base.pixel_size 112 | cmap = 'seismic' 113 | 114 | ax0 = axs[0] 115 | im0 = ax0.imshow(a_o, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax) 116 | ax0.set_title('$A$') 117 | 118 | ax1 = axs[1] 119 | im1 = ax1.imshow(l_o, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax) 120 | ax1.set_title('$L$') 121 | 122 | ax2 = axs[2] 123 | im2 = ax2.imshow(v, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax) 124 | ax2.set_title('$V$') 125 | 126 | ax3 = axs[3] 127 | im3 = ax3.imshow(l, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax) 128 | ax3.set_title('$L$ with wraparound artifacts') 129 | fig.colorbar(im3, ax=ax3, fraction=fraction, pad=pad) 130 | 131 | # Add text boxes with labels (a), (b), (c), ... 132 | labels = ['(a)', '(b)', '(c)', '(d)'] 133 | for i, ax in enumerate(axs.flat): 134 | ax.text(0.5, -0.23, labels[i], transform=ax.transAxes, ha='center') 135 | ax.set_xticks(np.arange(0, 41, 10)) 136 | ax.set_yticks(np.arange(0, 41, 10)) 137 | 138 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300) 139 | plt.close('all') 140 | print(f'Saved: {filename}') 141 | -------------------------------------------------------------------------------- /examples/paper_code/fig2_decompose.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script demonstrates the domain decomposition of operator A, L, and V. 3 | It visualizes the matrices in a figure with three subplots: 4 | - Subplot (a): Matrix A 5 | - Subplot (b): Matrix L 6 | - Subplot (c): Matrix V 7 | 8 | The figure is saved as a PDF file. 9 | """ 10 | 11 | # import packages 12 | import os 13 | import sys 14 | import torch 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | from matplotlib import rc, rcParams, colors 18 | 19 | sys.path.append(".") 20 | sys.path.append("..") 21 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 22 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 23 | from wavesim.utilities import full_matrix, normalize, preprocess 24 | from wavesim.iteration import domain_operator 25 | 26 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13} 27 | rc('font', **font) 28 | rcParams['mathtext.fontset'] = 'cm' 29 | 30 | if os.path.basename(os.getcwd()) == 'paper_code': 31 | os.chdir('..') 32 | os.makedirs('paper_figures', exist_ok=True) 33 | filename = 'paper_figures/fig2_decompose.pdf' 34 | else: 35 | try: 36 | os.makedirs('examples/paper_figures', exist_ok=True) 37 | filename = 'examples/paper_figures/fig2_decompose.pdf' 38 | except FileNotFoundError: 39 | filename = 'fig2_decompose.pdf' 40 | 41 | # Define problem parameters 42 | boundary_widths = 0 43 | n_size = (40, 1, 1) 44 | # Random refractive index distribution 45 | torch.manual_seed(0) # Set the random seed for reproducibility 46 | n = (torch.normal(mean=1.3, std=0.1, size=n_size, dtype=torch.float32) 47 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=n_size, dtype=torch.float32)))**2 48 | assert n.imag.min() >= 0, 'Imaginary part of n is negative' 49 | 50 | wavelength = 1. 
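# With the 40-pixel grid defined above and pixel_size = wavelength / 4 (next line), the
# problem spans 10 wavelengths; the MultiDomain call further below splits it with
# n_domains=(2, 1, 1) into two 20-pixel (5-wavelength) subdomains, which is where the
# A_{11}...A_{22} block labels and the dashed demarcation at pixel 20 in the figure come from.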
51 | pixel_size = wavelength / 4 52 | periodic = (True, True, True) 53 | 54 | # Get matrices of A, L, and V for large domain without wraparound artifacts 55 | boundary_widths = 100 56 | n_o, boundary_array = preprocess(n.numpy(), boundary_widths) 57 | assert n_o.imag.min() >= 0, 'Imaginary part of n_o is negative' 58 | 59 | domain_o = HelmholtzDomain(permittivity=n_o, periodic=periodic, wavelength=wavelength, 60 | pixel_size=pixel_size, debug=True) 61 | 62 | crop2roi = (slice(boundary_array[0], -boundary_array[0]), 63 | slice(boundary_array[0], -boundary_array[0])) 64 | 65 | I = np.eye(np.prod(n_size), dtype=np.complex64) 66 | 67 | l_plus1_o = full_matrix(domain_operator(domain_o, 'inverse_propagator'))[crop2roi].cpu().numpy() # L + I 68 | l_o = (l_plus1_o - I) / domain_o.scale - domain_o.shift.item()*I # L = (L + I - I) / scaling 69 | 70 | # Get matrices of A, L, and V operators decomposed into two domains 71 | domain_2 = MultiDomain(permittivity=n, periodic=(True, True, True), wavelength=wavelength, 72 | pixel_size=pixel_size, n_domains=(2,1,1), debug=True, n_boundary=0) 73 | 74 | b_2 = full_matrix(domain_operator(domain_2, 'medium')).cpu().numpy() # B = I - V 75 | l_plus1_2 = full_matrix(domain_operator(domain_2, 'inverse_propagator')).cpu().numpy() # L + I 76 | 77 | v_2 = I - b_2 78 | l_2 = l_plus1_2 - I 79 | 80 | l_o = (l_o + domain_2.shift.item()*I) * domain_2.scale 81 | 82 | l_diff = l_o - l_2 83 | v_l_diff = v_2 + l_diff 84 | a_2 = l_2 + v_l_diff # A = L + V 85 | 86 | # Normalize matrices for visualization 87 | l_2 = l_2.imag 88 | v_l_diff = v_l_diff.imag 89 | a_2 = a_2.imag 90 | 91 | max_val = max(np.max(l_2), np.max(v_l_diff), np.max(a_2)) 92 | min_val = min(np.min(l_2), np.min(v_l_diff), np.min(a_2)) 93 | extremum = max(abs(min_val), abs(max_val)) 94 | vmin = -1 95 | vmax = 1 96 | 97 | l_2 = normalize(l_2, -extremum, extremum, vmin, vmax) 98 | v_l_diff = normalize(v_l_diff, -extremum, extremum, vmin, vmax) 99 | a_2 = normalize(a_2, -extremum, extremum, vmin, vmax) 100 | 101 | # Create a figure with three subplots in one row 102 | fig, axs = plt.subplots(1, 3, figsize=(9, 3), sharex=True, sharey=True, 103 | gridspec_kw={'wspace': 0.15, 'width_ratios': [1, 1, 1.094]}) 104 | fraction = 0.046 105 | pad = 0.04 106 | extent = np.array([0, n_size[0], n_size[0], 0]) 107 | cmap = 'seismic' 108 | 109 | ax0 = axs[0] 110 | im0 = ax0.imshow(a_2, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax) 111 | ax0.set_title('$A$') 112 | kwargs0 = dict(transform=ax0.transAxes, ha='center', va='center', 113 | bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.1', 114 | alpha=0.5, linewidth=0.0)) 115 | ax0.text(0.09, 0.57, '$A_{11}$', **kwargs0) 116 | ax0.text(0.9, 0.92, '$A_{12}$', **kwargs0) 117 | ax0.text(0.09, 0.07, '$A_{21}$', **kwargs0) 118 | ax0.text(0.9, 0.42, '$A_{22}$', **kwargs0) 119 | 120 | ax1 = axs[1] 121 | im1 = ax1.imshow(l_2, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax) 122 | ax1.set_title('$L$') 123 | 124 | ax2 = axs[2] 125 | im2 = ax2.imshow(v_l_diff, cmap=cmap, extent=extent, vmin=vmin, vmax=vmax) 126 | ax2.set_title('$V$') 127 | kwargs2 = dict(transform=ax2.transAxes, ha='center', va='center', 128 | bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.1', 129 | alpha=0.5, linewidth=0.0)) 130 | ax2.text(0.18, 0.64, '$C_{11}$', **kwargs2) 131 | ax2.text(0.84, 0.84, '$A_{12}$', **kwargs2) 132 | ax2.text(0.18, 0.15, '$A_{21}$', **kwargs2) 133 | ax2.text(0.84, 0.34, '$C_{22}$', **kwargs2) 134 | fig.colorbar(im2, ax=ax2, fraction=fraction, pad=pad) 135 
| 136 | # Add text boxes with labels (a), (b), (c), ... 137 | labels = ['(a)', '(b)', '(c)'] 138 | for i, ax in enumerate(axs.flat): 139 | ax.text(0.5, -0.23, labels[i], transform=ax.transAxes, ha='center') 140 | ax.axhline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation 141 | ax.axvline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation 142 | ax.set_yticks(np.arange(0, 41, 10)) 143 | ax.set_xticks(np.arange(0, 41, 10)) 144 | 145 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300) 146 | plt.close('all') 147 | print(f'Saved: {filename}') 148 | -------------------------------------------------------------------------------- /examples/paper_code/fig3_correction_matrix.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script generates a figure demonstrating the fast convolution of the Laplacian with a point source 3 | for a 1-domain and 2-domain case, and the construction of the correction matrix A_{12}. 4 | 5 | The figure consists of three subplots: 6 | - Subplot (a): Fast convolution with a point source for 1-domain case 7 | - Subplot (b): Fast convolution with a point source for 2-domain case, where the fast convolution is performed over 1 8 | subdomain. The subplot also shows the difference between the unwrapped and wrapped fields. 9 | - Subplot (c): Correction matrix A_{12}, a non-cyclic convolution matrix that computes the wrapping artifacts. 10 | 11 | The figure is saved as a PDF file. 12 | """ 13 | 14 | # import packages 15 | import os 16 | import sys 17 | import torch 18 | import numpy as np 19 | import matplotlib.pyplot as plt 20 | from matplotlib.gridspec import GridSpec 21 | from matplotlib import rc, rcParams, colors 22 | 23 | sys.path.append(".") 24 | sys.path.append("..") 25 | from wavesim.domain import Domain 26 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 27 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 28 | from wavesim.utilities import normalize, pad_boundaries, preprocess 29 | 30 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 14} 31 | rc('font', **font) 32 | rcParams['mathtext.fontset'] = 'cm' 33 | 34 | if os.path.basename(os.getcwd()) == 'paper_code': 35 | os.chdir('..') 36 | os.makedirs('paper_figures', exist_ok=True) 37 | filename = 'paper_figures/fig3_correction_matrix.pdf' 38 | else: 39 | try: 40 | os.makedirs('examples/paper_figures', exist_ok=True) 41 | filename = 'examples/paper_figures/fig3_correction_matrix.pdf' 42 | except FileNotFoundError: 43 | filename = 'fig3_correction_matrix.pdf' 44 | 45 | 46 | def coordinates_f_sq(n_, pixel_size=0.25): 47 | """ Calculate the coordinates in the frequency domain and returns squared values 48 | :param n_: Number of points 49 | :param pixel_size: Pixel size. Defaults to 0.25. 50 | :return: Tensor containing the coordinates in the frequency domain """ 51 | return (2 * torch.pi * torch.fft.fftfreq(n_, pixel_size))**2 52 | 53 | # Define problem parameters 54 | boundary_widths = 0. 
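# Illustrative sketch (not part of the original script): multiplying by coordinates_f_sq()
# in the Fourier domain applies the Laplacian as a circular convolution, so the kernel of a
# point source wraps around the edges of the grid. For instance, on a small 8-point grid,
#     delta = torch.zeros(8)
#     delta[0] = 1.0
#     kernel = torch.fft.ifftn(coordinates_f_sq(8) * torch.fft.fftn(delta)).real
# gives a response whose tail re-enters from the opposite edge. The difference between the
# kernel computed on the full grid and on a half-size, zero-padded grid (fc1 - fc2 below) is
# exactly the wrapping artifact that the correction matrix A_{12} encodes.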
55 | n_size = 40 56 | 57 | # Fast convolution of Laplacian with point source 58 | 59 | # Case 1: 1 domain 60 | d = int(n_size/2 - 1) # 1 point before the center 61 | # Point source 62 | side = torch.zeros(n_size) 63 | side[d] = 1.0 64 | # Fast convolution result 65 | fc1 = (torch.fft.ifftn(coordinates_f_sq(n_size) 66 | * torch.fft.fftn(side))).real.cpu().numpy() # discard tiny imaginary part due to numerical errors 67 | 68 | # Case 2: 2 domains. Fast convolution of Laplacian with point source over 1 subdomain 69 | # Point source 70 | side2 = torch.zeros(n_size // 2) 71 | side2[d] = 1.0 72 | # discard tiny imaginary part due to numerical errors in the fast conv result 73 | fc2 = (torch.fft.ifftn(coordinates_f_sq(n_size // 2) 74 | * torch.fft.fftn(side2))).real.cpu().numpy() 75 | fc2 = np.concatenate((fc2, np.zeros_like(fc2))) # zero padding for the 2nd domain 76 | diff_ = fc1 - fc2 # difference between unwrapped and wrapped fields 77 | 78 | # construct a non-cyclic convolution matrix that computes the wrapping artifacts only 79 | t = 8 # number of correction points for constructing correction matrix 80 | a_12 = np.zeros((t, t), dtype=np.complex64) 81 | for r in range(t): 82 | a_12[r, :] = diff_[r:r + t] 83 | a_12 = np.flip(a_12.real, axis=0) # flip the matrix 84 | 85 | # Normalize matrices for visualization 86 | min_val = min(fc1.min(), fc2.min(), diff_.min(), a_12.min()) 87 | max_val = max(fc1.max(), fc2.max(), diff_.max(), a_12.max()) 88 | extremum = max(abs(min_val), abs(max_val)) 89 | vmin = -1 90 | vmax = 1 91 | 92 | fc1 = normalize(fc1, -extremum, extremum, vmin, vmax) 93 | fc2 = normalize(fc2, -extremum, extremum, vmin, vmax) 94 | diff_ = normalize(diff_, -extremum, extremum, vmin, vmax) 95 | a_12 = normalize(a_12, -extremum, extremum, vmin, vmax) 96 | 97 | # Plot limits 98 | c = n_size//2 # center of the domain 99 | 100 | # Plot 101 | fig = plt.figure(figsize=(12, 6), layout='constrained') 102 | gs = GridSpec(2, 2, figure=fig) 103 | 104 | # 1-domain case 105 | ax1 = fig.add_subplot(gs[0, 0]) 106 | ax1.plot(fc1, 'k') # 1-domain, no wrapping field 107 | ax1.set_title(r'$\nabla^2$ kernel in infinite domain') 108 | ax1.set_xlim([-1, n_size]) 109 | ax1.set_xticks(np.arange(0, n_size + 1, 5)) 110 | ax1.set_xticklabels([]) 111 | ax1.grid(True, which='major', linestyle='--', linewidth=0.5) 112 | ax1.text(0.5, -0.15, '(a)', transform=ax1.transAxes, ha='center') 113 | 114 | # 2-domain case 115 | ax2 = fig.add_subplot(gs[1, 0]) 116 | ax2.plot(fc2, 'k', label='Field') # 2-domain, with wrapping field 117 | ax2.axvspan(c, n_size - 1, color='gray', alpha=0.3) # Patch to demarcate 2nd subdomain 118 | 119 | # difference between unwrapped and wrapped fields 120 | ax2.plot(diff_[:t], 'r--', label='Corrections') # wrapping correction 121 | ax2.text(0.14, 0.5, '$C_{11}$', color='r', transform=ax2.transAxes, ha='center') 122 | 123 | ax2.plot(np.arange(c, c + t), diff_[c:c + t], 'r--') # transfer correction 124 | ax2.text(0.62, 0.23, '$A_{12}$', color='r', transform=ax2.transAxes, ha='center') 125 | 126 | ax2.set_xlim([-1, n_size]) 127 | ax2.set_xticks(np.arange(0, n_size + 1, 5)) 128 | ax2.set_xlabel('x') 129 | ax2.text(0.26, 0.09, 'Subdomain 1', transform=ax2.transAxes, ha='center') 130 | ax2.text(0.76, 0.09, 'Subdomain 2', transform=ax2.transAxes, ha='center') 131 | ax2.set_title(r'$\nabla^2$ kernel in periodic subdomains') 132 | ax2.grid(True, which='major', linestyle='--', linewidth=0.5) 133 | ax2.legend() 134 | ax2.text(0.5, -0.36, '(b)', transform=ax2.transAxes, ha='center') 135 | 136 | # 
Correction matrix 137 | ax3 = fig.add_subplot(gs[:, 1]) 138 | im3 = ax3.imshow(a_12, cmap='seismic', extent=[0, t, 0, t], norm=colors.CenteredNorm()) 139 | cb3 = plt.colorbar(mappable=im3, fraction=0.05, pad=0.01, ax=ax3) 140 | ax3.set_title('$A_{12}$') 141 | ax3.text(0.5, -0.15, '(c)', transform=ax3.transAxes, ha='center') 142 | 143 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300) 144 | plt.close('all') 145 | 146 | # Print min and max a_12 values. min/max should be <<1% (for truncation to make sense) 147 | min_a_12 = a_12[0, -1] 148 | max_a_12 = a_12[-1, 0] 149 | percent = (min_a_12/max_a_12) * 100 150 | print(f'Wrapping artifact amplitudes: Min {min_a_12:.3f}, Max {max_a_12:.3f}') 151 | print(f'Min/Max of A_{12} = {percent:.2f} %') 152 | assert percent < 1, f"Min/Max of A_{12} ratio exceeds 1%: {percent:.2f} %" 153 | print('Done.') 154 | -------------------------------------------------------------------------------- /examples/paper_code/fig4_truncate.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script generates a figure to demonstrate the truncation of wrapping and transfer 3 | corrections. The figure consists of three subplots: 4 | 5 | - Subplot (a): Matrix A with truncated corrections 6 | - Subplot (b): Matrix L 7 | - Subplot (c): Matrix V with truncated corrections 8 | 9 | The figure is saved as a PDF file. 10 | """ 11 | 12 | # import packages 13 | import os 14 | import sys 15 | import torch 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | from matplotlib import rc, rcParams, colors 19 | 20 | sys.path.append(".") 21 | sys.path.append("..") 22 | from wavesim.helmholtzdomain import HelmholtzDomain  # when number of domains is 1 23 | from wavesim.multidomain import MultiDomain  # for domain decomposition, when number of domains is >= 1 24 | from wavesim.utilities import full_matrix, normalize, preprocess, relative_error 25 | from wavesim.iteration import domain_operator 26 | 27 | current_dir = os.path.dirname(__file__) 28 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13} 29 | rc('font', **font) 30 | rcParams['mathtext.fontset'] = 'cm' 31 | 32 | if os.path.basename(os.getcwd()) == 'paper_code': 33 | os.chdir('..') 34 | os.makedirs('paper_figures', exist_ok=True) 35 | filename = 'paper_figures/fig4_truncate.pdf' 36 | else: 37 | try: 38 | os.makedirs('examples/paper_figures', exist_ok=True) 39 | filename = 'examples/paper_figures/fig4_truncate.pdf' 40 | except FileNotFoundError: 41 | filename = 'fig4_truncate.pdf' 42 | 43 | # Define problem parameters 44 | n_size = (40, 1, 1) 45 | # Random refractive index distribution 46 | torch.manual_seed(0)  # Set the random seed for reproducibility 47 | n = (torch.normal(mean=1.3, std=0.1, size=n_size, dtype=torch.float32) 48 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=n_size, dtype=torch.float32)))**2 49 | assert n.imag.min() >= 0, 'Imaginary part of n is negative' 50 | 51 | wavelength = 1.
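# Note: the MultiDomain call below passes n_boundary=6, i.e. the wrapping and transfer
# corrections are truncated to 6 correction points per subdomain edge, so only small
# correction blocks remain near the interface at pixel 20 instead of the full off-diagonal
# blocks of fig2_decompose.py (which uses n_boundary=0).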
52 | pixel_size = wavelength / 4 53 | periodic = (True, True, True) 54 | 55 | I = np.eye(np.prod(n_size), dtype=np.complex64) 56 | 57 | # Get matrices of A, L, and V operators (with truncated corrections) decomposed into two domains 58 | domain_c = MultiDomain(permittivity=n, periodic=(False, True, True), wavelength=wavelength, 59 | pixel_size=pixel_size, n_domains=(2,1,1), n_boundary=6, debug=True) 60 | 61 | b_c = full_matrix(domain_operator(domain_c, 'medium')).cpu().numpy() # B = I - V 62 | l_plus1_c = full_matrix(domain_operator(domain_c, 'inverse_propagator')).cpu().numpy() # L + I 63 | 64 | v_c = I - b_c 65 | l_c = l_plus1_c - I 66 | a_c = l_c + v_c # A = L + V 67 | 68 | # Normalize matrices for visualization 69 | a_c = a_c.imag 70 | l_c = l_c.imag 71 | v_c = v_c.imag 72 | 73 | max_val = max(np.max(a_c), np.max(l_c), np.max(v_c)) 74 | min_val = min(np.min(a_c), np.min(l_c), np.min(v_c)) 75 | extremum = max(abs(min_val), abs(max_val)) 76 | vmin = -1 77 | vmax = 1 78 | 79 | a_c = normalize(a_c, -extremum, extremum, vmin, vmax) 80 | v_c = normalize(v_c, -extremum, extremum, vmin, vmax) 81 | l_c = normalize(l_c, -extremum, extremum, vmin, vmax) 82 | 83 | # Plot (a) A, (b) L, and (c) V with truncated corrections 84 | 85 | # Create a figure with three subplots in one row 86 | fig, axs = plt.subplots(1, 3, figsize=(9, 3), sharex=True, sharey=True, 87 | gridspec_kw={'wspace': 0.15, 'width_ratios': [1, 1, 1.094]}) 88 | fraction = 0.046 89 | pad = 0.04 90 | extent = np.array([0, n_size[0], n_size[0], 0]) 91 | cmap = 'seismic' 92 | 93 | ax0 = axs[0] 94 | im0 = ax0.imshow(a_c, cmap=cmap, extent=extent, norm=colors.SymLogNorm(linthresh=0.7, vmin=vmin, vmax=vmax)) 95 | ax0.set_title('$A$ (truncated corrections)') 96 | 97 | ax1 = axs[1] 98 | im1 = ax1.imshow(l_c, cmap=cmap, extent=extent, norm=colors.SymLogNorm(linthresh=0.7, vmin=vmin, vmax=vmax)) 99 | ax1.set_title('$L$') 100 | 101 | ax2 = axs[2] 102 | im2 = ax2.imshow(v_c, cmap=cmap, extent=extent, norm=colors.SymLogNorm(linthresh=0.7, vmin=vmin, vmax=vmax)) 103 | ax2.set_title('$V$ (truncated corrections)') 104 | fig.colorbar(im2, ax=ax2, fraction=fraction, pad=pad, ticks=[-1, -0.5, 0, 0.5, 1], format='%.1f') 105 | 106 | # Add text boxes with labels (a), (b), (c), ... 
107 | labels = ['(a)', '(b)', '(c)', '(d)'] 108 | for i, ax in enumerate(axs.flat): 109 | ax.text(0.5, -0.23, labels[i], transform=ax.transAxes, ha='center') 110 | ax.axhline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation 111 | ax.axvline(20, color='gray', linestyle='--', alpha=0.5) # subdomain demarcation 112 | ax.set_yticks(np.arange(0, 41, 10)) 113 | ax.set_xticks(np.arange(0, 41, 10)) 114 | 115 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.03, dpi=300) 116 | plt.close('all') 117 | print(f'Saved: {filename}') 118 | -------------------------------------------------------------------------------- /examples/paper_code/fig5_dd_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from __init__ import sim_3d_random, plot_validation 4 | 5 | if os.path.basename(os.getcwd()) == 'paper_code': 6 | os.chdir('..') 7 | os.makedirs('paper_data', exist_ok=True) 8 | os.makedirs('paper_figures', exist_ok=True) 9 | filename = 'paper_data/fig5_dd_validation_' 10 | figname = 'paper_figures/fig5_dd_validation.pdf' 11 | else: 12 | try: 13 | os.makedirs('examples/paper_data', exist_ok=True) 14 | os.makedirs('examples/paper_figures', exist_ok=True) 15 | filename = 'examples/paper_data/fig5_dd_validation_' 16 | figname = 'examples/paper_figures/fig5_dd_validation.pdf' 17 | except FileNotFoundError: 18 | print("Directory not found. Please run the script from the 'paper_code' directory.") 19 | 20 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries) 21 | full_residuals = True 22 | 23 | # Run the simulations 24 | sim_ref = sim_3d_random(filename, sim_size, n_domains=None, n_boundary=0, r=12, clearance=0, full_residuals=full_residuals) 25 | sim = sim_3d_random(filename, sim_size, n_domains=(3, 1, 1), r=12, clearance=0, full_residuals=full_residuals) 26 | 27 | # plot the field 28 | plot_validation(figname, sim_ref, sim, plt_norm='log') 29 | print('Done.') 30 | -------------------------------------------------------------------------------- /examples/paper_code/fig6_dd_truncation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from time import time 5 | import matplotlib.pyplot as plt 6 | from matplotlib import rc, rcParams 7 | from matplotlib.ticker import LogLocator, MultipleLocator 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 12 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 13 | from wavesim.iteration import run_algorithm # to run the wavesim iteration 14 | from wavesim.utilities import preprocess, relative_error 15 | from __init__ import random_spheres_refractive_index, construct_source 16 | 17 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13} 18 | rc('font', **font) 19 | rcParams['mathtext.fontset'] = 'cm' 20 | 21 | if os.path.basename(os.getcwd()) == 'paper_code': 22 | os.chdir('..') 23 | os.makedirs('paper_data', exist_ok=True) 24 | os.makedirs('paper_figures', exist_ok=True) 25 | current_dir = 'paper_data/' 26 | figname = (f'paper_figures/fig6_dd_truncation.pdf') 27 | else: 28 | try: 29 | os.makedirs('examples/paper_data', exist_ok=True) 30 | os.makedirs('examples/paper_figures', exist_ok=True) 31 | current_dir = 'examples/paper_data/' 32 | figname = 
(f'examples/paper_figures/fig6_dd_truncation.pdf') 33 | except FileNotFoundError: 34 | print("Directory not found. Please run the script from the 'paper_code' directory.") 35 | 36 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries) 37 | 38 | wavelength = 1. # Wavelength in micrometers 39 | pixel_size = wavelength/4 # Pixel size in wavelength units 40 | boundary_wavelengths = 5 # Boundary width in wavelengths 41 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size), 0, 0] # Boundary width in pixels 42 | # Periodic boundaries True (no wrapping correction) if boundary width is 0, else False (wrapping correction) 43 | periodic = tuple(np.where(np.array(boundary_widths) == 0, True, False)) 44 | n_dims = np.count_nonzero(sim_size != 1) # Number of dimensions 45 | 46 | # Size of the simulation domain 47 | n_size = np.ones_like(sim_size, dtype=int) 48 | n_size[:n_dims] = sim_size[:n_dims] * wavelength / pixel_size # Size of the simulation domain in pixels 49 | n_size = tuple(n_size.astype(int)) # Convert to integer for indexing 50 | 51 | filename = os.path.join(current_dir, f'fig6_dd_truncation.npz') 52 | if os.path.exists(filename): 53 | print(f"File {filename} already exists. Loading data...") 54 | data = np.load(filename) 55 | corrs = data['corrs'] 56 | sim_time = data['sim_time'] 57 | iterations = data['iterations'] 58 | ure_list = data['ure_list'] 59 | print(f"Loaded data from {filename}. Now plotting...") 60 | else: 61 | n = random_spheres_refractive_index(n_size, r=12, clearance=0) # Random refractive index 62 | 63 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 64 | n, boundary_array = preprocess((n**2), boundary_widths) # permittivity is n², but uses the same variable n 65 | 66 | print(f"Size of n: {n_size}") 67 | print(f"Size of n in GB: {n.nbytes / (1024**3):.2f}") 68 | assert n.imag.min() >= 0, 'Imaginary part of n² is negative' 69 | assert (n.shape == np.asarray(n_size) + 2*boundary_array).all(), 'n and n_size do not match' 70 | assert n.dtype == np.complex64, f'n is not complex64, but {n.dtype}' 71 | 72 | source = construct_source(n_size, boundary_array) 73 | 74 | domain_ref = HelmholtzDomain(permittivity=n, periodic=periodic, 75 | wavelength=wavelength, pixel_size=pixel_size) 76 | 77 | start_ref = time() 78 | # Field u and state object with information about the run 79 | u_ref, iterations_ref, residual_norm_ref = run_algorithm(domain_ref, source, 80 | max_iterations=10000) 81 | sim_time_ref = time() - start_ref 82 | print(f'\nTime {sim_time_ref:2.2f} s; Iterations {iterations_ref}; Residual norm {residual_norm_ref:.3e}') 83 | # crop the field to the region of interest 84 | u_ref = u_ref[*(slice(boundary_widths[i], 85 | u_ref.shape[i] - boundary_widths[i]) for i in range(3))].cpu().numpy() 86 | 87 | n_ext = np.array(n_size) + 2*boundary_array 88 | corrs = np.arange(n_ext[0] // 4 + 1) 89 | print(f"Number of correction points: {corrs[-1]}") 90 | 91 | ure_list = [] 92 | iterations = [] 93 | sim_time = [] 94 | 95 | n_domains = (2, 1, 1) # number of domains in each direction 96 | 97 | for n_boundary in corrs: 98 | print(f'n_boundary {n_boundary}/{corrs[-1]}', end='\r') 99 | domain_n = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, 100 | pixel_size=pixel_size, n_domains=n_domains, n_boundary=n_boundary) 101 | 102 | start_n = time() 103 | u_n, iterations_n, residual_norm_n = run_algorithm(domain_n, source, 104 | max_iterations=10000) 105 | sim_time_n = time() - 
start_n 106 | print(f'\nTime {sim_time_n:2.2f} s; Iterations {iterations_n}; Residual norm {residual_norm_n:.3e}') 107 | # crop the field to the region of interest 108 | u_n = u_n[*(slice(boundary_widths[i], 109 | u_n.shape[i] - boundary_widths[i]) for i in range(3))].cpu().numpy() 110 | 111 | ure_list.append(relative_error(u_n, u_ref)) 112 | iterations.append(iterations_n) 113 | sim_time.append(sim_time_n) 114 | 115 | sim_time = np.array(sim_time) 116 | iterations = np.array(iterations) 117 | ure_list = np.array(ure_list) 118 | np.savez_compressed(filename, corrs=corrs, 119 | sim_time=sim_time, 120 | iterations=iterations, 121 | ure_list=ure_list) 122 | print(f'Saved: {filename}. Now plotting...') 123 | 124 | 125 | # Plot 126 | length = int(len(ure_list) * 2/3) 127 | x = np.arange(length) 128 | ncols = 3 129 | figsize = (12, 3) 130 | 131 | fig, axs = plt.subplots(1, ncols, figsize=figsize, gridspec_kw={'hspace': 0., 'wspace': 0.29}) 132 | 133 | ax0 = axs[0] 134 | ax0.semilogy(x, ure_list[:length], 'r', lw=1., marker='x', markersize=3) 135 | ax0.set_xlabel('Number of correction points') 136 | ax0.set_ylabel('Relative Error') 137 | ax0.set_xticks(np.arange(0, round(length,-1)+1, 10)) 138 | ax0.set_xlim([-2 if n_dims == 3 else -10, length + 1 if n_dims == 3 else length + 9]) 139 | ax0.grid(True, which='major', linestyle='--', linewidth=0.5) 140 | ax0.grid(True, which='minor', linestyle=':', linewidth=0.3, axis='y') 141 | ax0.yaxis.set_minor_locator(LogLocator(numticks=12,subs=np.arange(2,10))) 142 | ax0.xaxis.set_minor_locator(MultipleLocator(1)) 143 | 144 | ax1 = axs[1] 145 | start = 4 146 | ax1.plot(x[start:], iterations[start:length], 'g', lw=1., marker='+', markersize=3) 147 | ax1.set_xlabel('Number of correction points') 148 | ax1.set_ylabel('Iterations') 149 | ax1.set_xticks(np.arange(0, round(length,-1)+1, 10)) 150 | ax1.set_xlim([start-1, length+1]) 151 | ax1.grid(True, which='major', linestyle='--', linewidth=0.5) 152 | ax1.xaxis.set_minor_locator(MultipleLocator(1)) 153 | 154 | ax2 = axs[2] 155 | ax2.plot(x[start:], sim_time[start:length], 'b', lw=1., marker='*', markersize=3) 156 | ax2.set_xlabel('Number of correction points') 157 | ax2.set_ylabel('Time (s)') 158 | ax2.set_xticks(np.arange(0, round(length,-1)+1, 10)) 159 | ax2.set_xlim([start-1, length+1]) 160 | ax2.grid(True, which='major', linestyle='--', linewidth=0.5) 161 | ax2.xaxis.set_minor_locator(MultipleLocator(1)) 162 | 163 | # Add text boxes with labels (a), (b), (c), ... 
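# Note: panels (b) and (c) above are plotted from n_boundary = 4 onwards (the 'start' index), while panel (a) shows all points up to 'length'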
164 | labels = ['(a)', '(b)', '(c)'] 165 | for i, ax in enumerate(axs.flat): 166 | ax.text(0.5, -0.3, labels[i], transform=ax.transAxes, ha='center') 167 | 168 | plt.savefig(figname, bbox_inches='tight', pad_inches=0.03, dpi=300) 169 | plt.close('all') 170 | print(f'Saved: {figname}') 171 | -------------------------------------------------------------------------------- /examples/paper_code/fig7_dd_convergence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from time import time 5 | from itertools import product 6 | import matplotlib.pyplot as plt 7 | from matplotlib import rc, rcParams 8 | from matplotlib.ticker import MultipleLocator 9 | 10 | sys.path.append(".") 11 | sys.path.append("..") 12 | from wavesim.helmholtzdomain import HelmholtzDomain # when number of domains is 1 13 | from wavesim.multidomain import MultiDomain # for domain decomposition, when number of domains is >= 1 14 | from wavesim.iteration import run_algorithm # to run the anysim iteration 15 | from wavesim.utilities import preprocess 16 | from __init__ import random_spheres_refractive_index, construct_source 17 | 18 | font = {'family': 'serif', 'serif': ['Times New Roman'], 'size': 13} 19 | rc('font', **font) 20 | rcParams['mathtext.fontset'] = 'cm' 21 | 22 | if os.path.basename(os.getcwd()) == 'paper_code': 23 | os.chdir('..') 24 | os.makedirs('paper_data', exist_ok=True) 25 | os.makedirs('paper_figures', exist_ok=True) 26 | filename = f'paper_data/fig7_dd_convergence.txt' 27 | figname = f'paper_figures/fig7_dd_convergence.pdf' 28 | else: 29 | try: 30 | os.makedirs('examples/paper_data', exist_ok=True) 31 | os.makedirs('examples/paper_figures', exist_ok=True) 32 | filename = f'examples/paper_data/fig7_dd_convergence.txt' 33 | figname = (f'examples/paper_figures/fig7_dd_convergence.pdf') 34 | except FileNotFoundError: 35 | print("Directory not found. Please run the script from the 'paper_code' directory.") 36 | 37 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries) 38 | wavelength = 1. # Wavelength in micrometers 39 | pixel_size = wavelength/4 # Pixel size in wavelength units 40 | boundary_wavelengths = 5 # Boundary width in wavelengths 41 | n_dims = np.count_nonzero(sim_size != 1) # Number of dimensions 42 | 43 | # Size of the simulation domain 44 | n_size = np.ones_like(sim_size, dtype=int) 45 | n_size[:n_dims] = sim_size[:n_dims] * wavelength / pixel_size # Size of the simulation domain in pixels 46 | n_size = tuple(n_size.astype(int)) # Convert to integer for indexing 47 | 48 | if os.path.exists(filename): 49 | print(f"File {filename} already exists. 
Loading data and plotting...") 50 | else: 51 | domains = range(1, 11) 52 | for nx, ny in product(domains, domains): 53 | print(f'Domains {nx}/{domains[-1]}, {ny}/{domains[-1]}', end='\r') 54 | 55 | if nx == 1 and ny == 1: 56 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size), 0, 0] 57 | elif nx > 1 and ny == 1: 58 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size), 0, 0] 59 | elif nx == 1 and ny > 1: 60 | boundary_widths = [0, round(boundary_wavelengths * wavelength / pixel_size), 0] 61 | else: 62 | boundary_widths = [round(boundary_wavelengths * wavelength / pixel_size)]*2 + [0] 63 | 64 | periodic = tuple(np.where(np.array(boundary_widths) == 0, True, False)) 65 | n = random_spheres_refractive_index(n_size, r=12, clearance=0) # Random refractive index 66 | 67 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 68 | n, boundary_array = preprocess((n**2), boundary_widths) # permittivity is n², but uses the same variable n 69 | 70 | print(f"Size of n: {n_size}") 71 | print(f"Size of n in GB: {n.nbytes / (1024**3):.2f}") 72 | assert n.imag.min() >= 0, 'Imaginary part of n² is negative' 73 | assert (n.shape == np.asarray(n_size) + 2*boundary_array).all(), 'n and n_size do not match' 74 | assert n.dtype == np.complex64, f'n is not complex64, but {n.dtype}' 75 | 76 | source = construct_source(n_size, boundary_array) 77 | 78 | n_domains = (nx, ny, 1) 79 | domain = MultiDomain(permittivity=n, periodic=periodic, 80 | wavelength=wavelength, pixel_size=pixel_size, 81 | n_boundary=8, n_domains=n_domains) 82 | 83 | start = time() 84 | u, iterations, residual_norm = run_algorithm(domain, source, max_iterations=10000) # Field u and state object with information about the run 85 | end = time() - start 86 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 87 | 88 | #%% Save data to file 89 | data = (f'Size {n_size}; Boundaries {boundary_widths}; Domains {n_domains}; ' 90 | + f'Time {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e} \n') 91 | with open(filename, 'a') as file: 92 | file.write(data) 93 | 94 | #%% Domains in x AND y direction vs iterations and time 95 | data = np.loadtxt(filename, dtype=str, delimiter=';') 96 | 97 | num_domains = [data[i, 2] for i in range(len(data))] 98 | num_domains = [num_domains[i].split('(', maxsplit=1)[-1] for i in range(len(num_domains))] 99 | num_domains = [num_domains[i].split(', 1)', maxsplit=1)[0] for i in range(len(num_domains))] 100 | num_domains = [(int(num_domains[i].split(',', maxsplit=1)[0]), int(num_domains[i].split(',', maxsplit=1)[-1])) for i in range(len(num_domains))] 101 | 102 | x, y = max(num_domains, key=lambda x: x[0])[0], max(num_domains, key=lambda x: x[1])[1] 103 | 104 | #%% Both subplots in one figure 105 | 106 | iterations = [data[i, 4] for i in range(len(data))] 107 | iterations = np.array([int(iterations[i].split(' ')[2]) for i in range(len(iterations))]) 108 | iterations = np.reshape(iterations, (x, y), order='F') 109 | 110 | times = [data[i, 3] for i in range(len(data))] 111 | times = [float(times[i].split(' ')[2]) for i in range(len(times))] 112 | times = np.reshape(times, (x, y), order='F') 113 | 114 | fig, ax = plt.subplots(figsize=(9, 3), nrows=1, ncols=2, sharey=True, gridspec_kw={'wspace': 0.05}) 115 | cmap = 'inferno' 116 | 117 | im0 = ax[0].imshow(np.flipud(iterations), cmap=cmap, extent=[0.5, x+0.5, 0.5, y+0.5]) 118 | ax[0].set_xlabel('Domains in x direction') 119 | 
ax[0].set_ylabel('Domains in y direction') 120 | plt.colorbar(im0, label='Iterations', fraction=0.046, pad=0.04) 121 | ax[0].set_title('Iterations vs Number of domains') 122 | ax[0].text(0.5, -0.27, '(a)', color='k', ha='center', va='center', transform=ax[0].transAxes) 123 | ax[0].xaxis.set_minor_locator(MultipleLocator(1)) 124 | ax[0].yaxis.set_minor_locator(MultipleLocator(1)) 125 | ax[0].set_xticks(np.arange(2, x+1, 2)) 126 | ax[0].set_yticks(np.arange(2, y+1, 2)) 127 | 128 | im1 = ax[1].imshow(np.flipud(times), cmap=cmap, extent=[0.5, x+0.5, 0.5, y+0.5]) 129 | ax[1].set_xlabel('Domains in x direction') 130 | plt.colorbar(im1, label='Time (s)', fraction=0.046, pad=0.04) 131 | ax[1].set_title('Time vs Number of domains') 132 | ax[1].text(0.5, -0.27, '(b)', color='k', ha='center', va='center', transform=ax[1].transAxes) 133 | ax[1].xaxis.set_minor_locator(MultipleLocator(1)) 134 | ax[1].yaxis.set_minor_locator(MultipleLocator(1)) 135 | ax[1].set_xticks(np.arange(2, x+1, 2)) 136 | ax[1].set_yticks(np.arange(2, y+1, 2)) 137 | 138 | plt.savefig(figname, bbox_inches='tight', pad_inches=0.03, dpi=300) 139 | plt.close('all') 140 | print(f'Saved: {figname}') 141 | -------------------------------------------------------------------------------- /examples/paper_code/fig8_dd_large_simulation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from __init__ import sim_3d_random, plot_validation 4 | 5 | if os.path.basename(os.getcwd()) == 'paper_code': 6 | os.chdir('..') 7 | os.makedirs('paper_data', exist_ok=True) 8 | os.makedirs('paper_figures', exist_ok=True) 9 | filename = 'paper_data/fig8_dd_large_simulation_' 10 | figname = 'paper_figures/fig8_dd_large_simulation.pdf' 11 | else: 12 | try: 13 | os.makedirs('examples/paper_data', exist_ok=True) 14 | os.makedirs('examples/paper_figures', exist_ok=True) 15 | filename = 'examples/paper_data/fig8_dd_large_simulation_' 16 | figname = 'examples/paper_figures/fig8_dd_large_simulation.pdf' 17 | except FileNotFoundError: 18 | print("Directory not found. Please run the script from the 'paper_code' directory.") 19 | 20 | sim_size = 315 * np.array([1, 1, 1]) # Simulation size in micrometers (excluding boundaries) 21 | full_residuals = True 22 | 23 | # Run the simulations 24 | sim_gpu = sim_3d_random(filename, sim_size, n_domains=(2, 1, 1), r=24, clearance=24, full_residuals=full_residuals) 25 | sim_cpu = sim_3d_random(filename, sim_size, n_domains=None, n_boundary=0, r=24, clearance=24, full_residuals=full_residuals, 26 | device='cpu') 27 | 28 | # plot the field 29 | plot_validation(figname, sim_cpu, sim_gpu, plt_norm='log', inset=True) 30 | print('Done.') 31 | -------------------------------------------------------------------------------- /examples/run_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run Helmholtz example 3 | ===================== 4 | Example script to run a simulation of a point source in a random refractive index map using the Helmholtz equation. 
5 | """ 6 | 7 | import os 8 | import sys 9 | 10 | import torch 11 | import numpy as np 12 | from time import time 13 | from scipy.signal.windows import gaussian 14 | from torch.fft import fftn, ifftn, fftshift 15 | import matplotlib.pyplot as plt 16 | from matplotlib import colors 17 | 18 | sys.path.append(".") 19 | from __init__ import random_permittivity, construct_source 20 | from wavesim.helmholtzdomain import HelmholtzDomain 21 | from wavesim.multidomain import MultiDomain 22 | from wavesim.iteration import run_algorithm 23 | from wavesim.utilities import preprocess, normalize 24 | 25 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 26 | os.environ["TORCH_USE_CUDA_DSA"] = "1" 27 | if os.path.basename(os.getcwd()) == 'examples': 28 | os.chdir('..') 29 | 30 | # generate a refractive index map 31 | sim_size = 50 * np.array([1, 1, 1]) # Simulation size in micrometers 32 | periodic = (False, True, True) 33 | wavelength = 1. # Wavelength in micrometers 34 | pixel_size = wavelength/4 # Pixel size in wavelength units 35 | boundary_wavelengths = 10 # Boundary width in wavelengths 36 | boundary_widths = [int(boundary_wavelengths * wavelength / pixel_size), 0, 0] # Boundary width in pixels 37 | n_dims = len(sim_size.squeeze()) # Number of dimensions 38 | 39 | # Size of the simulation domain 40 | n_size = sim_size * wavelength / pixel_size # Size of the simulation domain in pixels 41 | n_size = tuple(n_size.astype(int)) # Convert to integer for indexing 42 | 43 | # return permittivity (n²) with absorbing boundaries 44 | permittivity = random_permittivity(n_size) 45 | permittivity = preprocess(permittivity, boundary_widths)[0] 46 | assert permittivity.imag.min() >= 0, 'Imaginary part of n² is negative' 47 | assert (permittivity.shape == np.asarray(n_size) + 2*np.asarray(boundary_widths)).all(), 'permittivity and n_size do not match' 48 | assert permittivity.dtype == np.complex64, f'permittivity is not complex64, but {permittivity.dtype}' 49 | 50 | # construct a source at the center of the domain 51 | source = construct_source(source_type='point', at=np.asarray(permittivity.shape) // 2, shape=permittivity.shape) 52 | # source = construct_source(source_type='plane_wave', at=[[boundary_widths[0]]], shape=permittivity.shape) 53 | # source = construct_source(source_type='gaussian_beam', at=[[boundary_widths[0]]], shape=permittivity.shape) 54 | 55 | # # 1-domain 56 | # n_domains = (1, 1, 1) # number of domains in each direction 57 | # domain = HelmholtzDomain(permittivity=permittivity, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size) 58 | 59 | # 1-domain or more with domain decomposition 60 | n_domains = (1, 1, 1) # number of domains in each direction 61 | domain = MultiDomain(permittivity=permittivity, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size, 62 | n_domains=n_domains) 63 | 64 | start = time() 65 | # Field u and state object with information about the run 66 | u, iterations, residual_norm = run_algorithm(domain, source, max_iterations=1000) 67 | end = time() - start 68 | print(f'\nTime {end:2.2f} s; Iterations {iterations}; Residual norm {residual_norm:.3e}') 69 | 70 | # %% Postprocessing 71 | 72 | file_name = './logs/size' 73 | for i in range(n_dims): 74 | file_name += f'{n_size[i]}_' 75 | file_name += f'bw{boundary_widths}_domains' 76 | for i in range(n_dims): 77 | file_name += f'{n_domains[i]}' 78 | 79 | output = (f'Size {n_size}; Boundaries {boundary_widths}; Domains {n_domains}; ' 80 | + f'Time {end:2.2f} s; Iterations {iterations}; Residual norm 
{residual_norm:.3e} \n') 81 | if not os.path.exists('logs'): 82 | os.makedirs('logs') 83 | with open('logs/output.txt', 'a') as file: 84 | file.write(output) 85 | 86 | # %% crop and save the field 87 | # crop the field to the region of interest 88 | u = u.squeeze()[*([slice(boundary_widths[i], 89 | u.shape[i] - boundary_widths[i]) for i in range(3)])].cpu().numpy() 90 | np.savez_compressed(f'{file_name}.npz', u=u) # save the field 91 | 92 | # %% plot the field 93 | extent = np.array([0, n_size[0], n_size[1], 0])*pixel_size 94 | u = normalize(np.abs(u[:, :, u.shape[2]//2].T)) 95 | plt.imshow(u, cmap='inferno', extent=extent, norm=colors.LogNorm()) 96 | plt.xlabel(r'$x~(\mu m)$') 97 | plt.ylabel(r'$y~(\mu m)$') 98 | cbar = plt.colorbar(fraction=0.046, pad=0.04) 99 | cbar.ax.set_title(r'$|E|$') 100 | plt.tight_layout() 101 | plt.savefig(f'{file_name}.pdf', bbox_inches='tight', pad_inches=0.03, dpi=300) 102 | plt.show() 103 | # plt.close('all') 104 | -------------------------------------------------------------------------------- /examples/timing_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | ## Tests for the fastest way to calculate the Euclidean norm of a complex vector 5 | # It seems that on the CPU, 'norm' is consistently slower than 'vdot', but gives a slightly different, perhaps more accurate value. 6 | # On the GPU, 'norm' is only very slightly faster (1%) than 'vdot' and gives the same result as vdot, indicating that the same 7 | # computations are used (perhaps with one cached memory load less in the 'norm' case) 8 | 9 | size = 100000000 # One hundred million elements 10 | 11 | 12 | def norm_using_vdot(vector): 13 | squared_magnitude = torch.vdot(vector, vector) 14 | return squared_magnitude.sqrt().real 15 | 16 | 17 | def norm_using_matmul(vector): 18 | x = torch.unsqueeze(vector, 0) 19 | y = torch.unsqueeze(vector, 0) 20 | squared_magnitude = torch.matmul(x, y.H) # keep the (1, 1) tensor so that .sqrt() below works 21 | return squared_magnitude.sqrt().real 22 | 23 | 24 | def norm_using_linalg_norm(vector): 25 | return torch.linalg.norm(vector) 26 | 27 | 28 | def timeit(func, name, complex_vector): 29 | start_time = time.time() 30 | for x in range(100): 31 | value = func(complex_vector) 32 | print(f"Euclidean norm using {name}: {value}, Time taken: {time.time() - start_time} seconds") 33 | 34 | 35 | for device in ['cpu', 'cuda']: 36 | complex_vector = torch.randn(size, dtype=torch.float32, device=device) + 1j * torch.randn(size, 37 | dtype=torch.float32, 38 | device=device) 39 | print(device) 40 | for repeat in range(3): 41 | print(f"Repeat {repeat + 1}") 42 | # timeit(norm_using_matmul, 'torch.matmul', complex_vector) # always much slower 43 | timeit(norm_using_linalg_norm, 'torch.linalg.norm', complex_vector) 44 | timeit(norm_using_vdot, 'torch.vdot', complex_vector) 45 | -------------------------------------------------------------------------------- /guides/ipynb2slides_guide.md: -------------------------------------------------------------------------------- 1 | ## Working with conda and jupyter notebooks (.ipynb) 2 | 3 | The conda environment needs to be registered with Jupyter so that it can be selected in the Select Kernel option of a jupyter notebook. 4 | 5 | 1. Activate the desired conda environment 6 | 7 | conda activate <environment_name> 8 | 9 | 2. Install the ipykernel package 10 | 11 | conda install ipykernel 12 | 13 | 3.
Add/install the environment to the ipykernel 14 | 15 | python -m ipykernel install --user --name=<environment_name> 16 | 17 | ### Additional step for Visual Studio code 18 | 19 | 4. Install the Jupyter extension through the GUI 20 | 21 | ### To open in browser (Google Chrome) 22 | 23 | 5. The required packages should already be installed through the .yml file while setting up the environment. However, if jupyter does not open in a browser window after entering the command: 24 | 25 | jupyter notebook 26 | 27 | 1. Set up jupyterlab and jupyter notebook with the following commands 28 | 29 | conda install -c conda-forge jupyterlab 30 | conda install -c anaconda notebook 31 | 32 | 2. Run the below command again, and now jupyter should open in a browser window: 33 | 34 | jupyter notebook 35 | 36 | 6. The environment should now be visible in the Select Kernel dropdown. 37 | 38 | 39 | ### To convert the current jupyter notebook into a presentation (plotly plots stay interactive) 40 | 41 | 1. Open the jupyter notebook in a browser window and check that running all cells gives the expected output 42 | 43 | 2. In the toolbar at the top, click **View** --> **Cell Toolbar** --> **Slideshow** 44 | 45 | 3. Each cell in the notebook will now have a toolbar at the top with a dropdown named **Slide Type**. In the dropdown, select **Slide** for all the cells you want to include in the presentation. 46 | 47 | 4. Convert the .ipynb notebook to a .html presentation with one of the following commands, depending on the options you want 48 | 49 | * Default options 50 | 51 | jupyter nbconvert --to slides <notebook_name>.ipynb 52 | 53 | * If you don't want to show the code, and just the outputs 54 | 55 | jupyter nbconvert --to slides --no-input <notebook_name>.ipynb 56 | 57 | * In addition to above, if you don't want any transitions 58 | 59 | jupyter nbconvert --to slides --no-input <notebook_name>.ipynb --SlidesExporter.reveal_transition=none 60 | 61 | * In addition to above, if you want a specific theme (here, serif) 62 | 63 | jupyter nbconvert --to slides --no-input <notebook_name>.ipynb --SlidesExporter.reveal_transition=none --SlidesExporter.reveal_theme=serif 64 | 65 | 5. Double-click the .html presentation <notebook_name>.html that should now be in the current working directory. 66 | 67 | 6. To convert the .html slides to pdf, add _?print-pdf_ in the URL in the web browser between _html_ and _#_. -------------------------------------------------------------------------------- /guides/misc_tips.md: -------------------------------------------------------------------------------- 1 | ### Some useful conda environment management commands 2 | 3 | The [Miniconda environment management guide](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) has more details if you need them. 4 | 5 | * To update the current conda environment from a .yml file: 6 | 7 | ``` 8 | conda env update --name wavesim --file environment.yml --prune 9 | ``` 10 | 11 | * To export the current environment to a .yml file: 12 | 13 | ``` 14 | conda env export > <environment_name>.yml 15 | ``` 16 | 17 | * To install any packages within an environment, first go into the environment and then install the package: 18 | 19 | ``` 20 | conda activate <environment_name> 21 | conda install <package_name> 22 | ``` 23 | 24 | * If conda does not have the package, and googling it suggests installing it via pip, use this command to install it specifically within the current environment and not globally (always prefer conda over pip.
Only go to pip if the package is not available through conda): 25 | 26 | ``` 27 | python -m pip install <package_name> 28 | ``` 29 | 30 | * After updating conda, setting up a new environment, or installing packages, it is a good idea to clean up the cached installation packages and tarballs, as they are no longer needed: 31 | 32 | ``` 33 | conda clean --all 34 | ``` 35 | 36 | ### If there is a problem with specifying fonts in matplotlib.rc 37 | 38 | Example of an error: "findfont: Generic family 'sans-serif' not found because none of the following families were found: Times New Roman" 39 | 40 | 1. Check if the 'mscorefonts' package is installed in conda (using conda list). If not, 41 | 42 | ``` 43 | conda install -c conda-forge mscorefonts 44 | ``` 45 | 46 | 2. Clear the matplotlib cache. This is an equally important step. 47 | 48 | ``` 49 | rm ~/.cache/matplotlib -rf 50 | ``` 51 | 52 | ### If there is a problem with TeX in matplotlib 53 | 54 | ``` 55 | sudo apt install texlive texlive-latex-extra texlive-fonts-recommended dvipng cm-super 56 | 57 | python -m pip install latex 58 | ``` 59 | 60 | ### For animations, the ffmpeg package is needed (the command below is for Linux) 61 | 62 | ``` 63 | sudo apt-get install ffmpeg 64 | ``` 65 | -------------------------------------------------------------------------------- /guides/pytorch_gpu_setup.md: -------------------------------------------------------------------------------- 1 | ## [Setting up NVIDIA CUDA on WSL (Windows Subsystem for Linux)](https://docs.nvidia.com/cuda/wsl-user-guide/index.html) 2 | 3 | ### Main steps 4 | 1. Install the NVIDIA driver for GPU support. Select and download the appropriate [driver](https://www.nvidia.com/Download/index.aspx) for your GPU and operating system 5 | 6 | 2. Install WSL 2 (follow the instructions in [2.2. Step 2](https://docs.nvidia.com/cuda/wsl-user-guide/index.html)) 7 | 8 | 3. Install the CUDA Toolkit using the WSL-Ubuntu package, following the [command line instructions](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local). 9 | 10 | ## Install pytorch-gpu 11 | 12 | 1. If doing this for the first time (or perhaps just in general), preferably do this in a separate conda environment, so that if anything breaks during installation, the new conda environment can simply be discarded without losing anything. 13 | 14 | * Create a new conda environment 15 | ``` 16 | conda create --name <environment_name> 17 | ``` 18 | 19 | * OR, if a .yml file is available, 20 | ``` 21 | conda env create -f <environment_name>.yml 22 | ``` 23 | 24 | 2. Obtain the appropriate command to run from https://pytorch.org/, depending on the build (prefer stable), OS, package manager (here, conda), language (here, Python), and compute platform (CPU, or a CUDA version for GPU). Example: 25 | ``` 26 | python -m pip install torch --index-url https://download.pytorch.org/whl/cu126 27 | ``` 28 | 29 | 3. Check that pytorch works (and with GPU) with the following series of commands in a WSL (Ubuntu) window/session, inside the conda pytorch environment 30 | ``` 31 | python 32 | import torch 33 | ``` 34 | 35 | * To check if pytorch works 36 | ``` 37 | a = torch.randn(5,3) 38 | print(a) 39 | ``` 40 | 41 | * To check if it works with GPU(s) 42 | ``` 43 | torch.cuda.is_available() 44 | ``` 45 | Output should be: "True" 46 | 47 | ### Monitoring GPU usage 48 | 49 | * To be able to see the GPU(s) and usage, the [nvtop](https://github.com/Syllo/nvtop) package is quite useful. 50 | ``` 51 | sudo apt install nvtop 52 | ``` 53 | 54 | * For WSL, nvtop installed using the above approach might not work.
Instead, [build from source](https://github.com/Syllo/nvtop#nvtop-build) 55 | 56 | * For Windows, nvtop is not available, but an alternate tool [nvitop](https://pypi.org/project/nvitop/) can be used instead. 57 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "wavesim" 3 | version = "0.1.0a2" 4 | description = "A library for simulating wave propagation using the modified Born series" 5 | authors = [ {name = "Ivo Vellekoop", email = "i.m.vellekoop@utwente.nl"}, 6 | {name = "Swapnil Mache", email = "s.mache@utwente.nl"} ] 7 | readme = "README.md" 8 | repository = "https://github.com/IvoVellekoop/wavesim_py" 9 | documentation = "https://wavesim.readthedocs.io/en/latest/" 10 | classifiers = [ 11 | 'Programming Language :: Python :: 3', 12 | 'Operating System :: OS Independent', 13 | ] 14 | requires-python = ">=3.11,<3.13" 15 | dependencies = [ 16 | "numpy<2.0.0", 17 | "matplotlib>=3.9.1", 18 | "scipy>=1.14.0", 19 | "porespy", 20 | "scikit-image<0.23", 21 | "torch" 22 | ] 23 | 24 | [build-system] 25 | requires = ["poetry-core"] 26 | build-backend = "poetry.core.masonry.api" 27 | 28 | [[tool.poetry.source]] 29 | name = "pytorch-gpu" 30 | url = "https://download.pytorch.org/whl/cu126" 31 | priority = "explicit" 32 | 33 | [tool.poetry.dependencies] 34 | torch = { source = "pytorch-gpu" } 35 | 36 | [tool.poetry.group.dev] 37 | optional = true 38 | 39 | [tool.poetry.group.dev.dependencies] 40 | pytest = "^8.2.2" 41 | 42 | [tool.poetry.group.docs] 43 | optional = true 44 | 45 | [tool.poetry.group.docs.dependencies] 46 | sphinx = ">=4.1.2" 47 | sphinx_mdinclude = ">= 0.5.0" 48 | sphinx-rtd-theme = ">= 2.0.0" 49 | sphinx-autodoc-typehints = ">= 1.11.0" 50 | sphinxcontrib-bibtex = ">= 2.6.0" 51 | sphinx-markdown-builder = ">= 0.6.6" 52 | sphinx-gallery = ">= 0.15.0" 53 | -------------------------------------------------------------------------------- /pyproject_cpu.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "wavesim" 3 | version = "0.1.0a2" 4 | description = "A library for simulating wave propagation using the modified Born series" 5 | authors = [ {name = "Ivo Vellekoop", email = "i.m.vellekoop@utwente.nl"}, 6 | {name = "Swapnil Mache", email = "s.mache@utwente.nl"} ] 7 | readme = "README.md" 8 | repository = "https://github.com/IvoVellekoop/wavesim_py" 9 | documentation = "https://wavesim.readthedocs.io/en/latest/" 10 | classifiers = [ 11 | 'Programming Language :: Python :: 3', 12 | 'Operating System :: OS Independent', 13 | ] 14 | requires-python = ">=3.11,<3.13" 15 | dependencies = [ 16 | "numpy<2.0.0", 17 | "matplotlib>=3.9.1", 18 | "scipy>=1.14.0", 19 | "porespy", 20 | "scikit-image<0.23", 21 | "torch" 22 | ] 23 | 24 | [build-system] 25 | requires = ["poetry-core"] 26 | build-backend = "poetry.core.masonry.api" 27 | 28 | [tool.poetry.group.dev] 29 | optional = true 30 | 31 | [tool.poetry.group.dev.dependencies] 32 | pytest = "^8.2.2" 33 | 34 | [tool.poetry.group.docs] 35 | optional = true 36 | 37 | [tool.poetry.group.docs.dependencies] 38 | sphinx = ">=4.1.2" 39 | sphinx_mdinclude = ">= 0.5.0" 40 | sphinx-rtd-theme = ">= 2.0.0" 41 | sphinx-autodoc-typehints = ">= 1.11.0" 42 | sphinxcontrib-bibtex = ">= 2.6.0" 43 | sphinx-markdown-builder = ">= 0.6.6" 44 | sphinx-gallery = ">= 0.15.0" -------------------------------------------------------------------------------- 
/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/requirements.txt -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | dtype = torch.complex64 4 | 5 | 6 | def allclose(a, b, rtol=0.0, atol=0.0, ulptol=100): 7 | """Check if two tensors are close to each other. 8 | 9 | Condition: |a-b| <= atol + rtol * maximum(|b|,|a|) + ulptol * ulp 10 | Where ulp is the size of the smallest representable difference between two numbers of magnitude ~max(|b[...]|) 11 | """ 12 | 13 | # make sure that a and b are tensors of the same dtype and device 14 | if not torch.is_tensor(a): 15 | a = torch.tensor(a, dtype=b.dtype) 16 | if not torch.is_tensor(b): 17 | b = torch.tensor(b, dtype=a.dtype) 18 | if a.dtype != b.dtype: 19 | a = a.type(b.dtype) 20 | if a.device != b.device: 21 | a = a.to('cpu') 22 | b = b.to('cpu') 23 | a = a.to_dense() 24 | b = b.to_dense() 25 | 26 | # compute the size of a single ULP 27 | ab_max = torch.maximum(a.abs(), b.abs()) 28 | exponent = ab_max.max().log2().ceil().item() 29 | ulp = torch.finfo(b.dtype).eps * 2 ** exponent 30 | tolerance = atol + rtol * ab_max + ulptol * ulp 31 | diff = (a - b).abs() 32 | 33 | if (diff - tolerance).max() <= 0.0: 34 | return True 35 | else: 36 | abs_err = diff.max().item() 37 | rel_err = (diff / ab_max).max() 38 | print(f"\nabsolute error {abs_err} = {abs_err / ulp} ulp\nrelative error {rel_err}") 39 | return False 40 | 41 | 42 | def random_vector(n_size, device='cuda' if torch.cuda.is_available() else 'cpu', dtype=dtype): 43 | """Construct a random vector for testing operators""" 44 | return torch.randn(n_size, device=device, dtype=dtype) + 1.0j * torch.randn(n_size, device=device, dtype=dtype) 45 | 46 | 47 | def random_refractive_index(n_size, device='cuda' if torch.cuda.is_available() else 'cpu', dtype=dtype): 48 | """Construct a random refractive index between 1 and 2 with a small positive imaginary part 49 | 50 | The sign of the imaginary part is such that the imaginary part of n² is positive 51 | """ 52 | n = (1.0 + torch.rand(n_size, device=device, dtype=dtype) + 53 | 0.1j * torch.rand(n_size, device=device, dtype=dtype)) 54 | 55 | # make sure that the imaginary part of n² is positive 56 | mask = (n ** 2).imag < 0 57 | n.imag[mask] *= -1.0 58 | return n 59 | -------------------------------------------------------------------------------- /tests/test_analytical.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from wavesim.helmholtzdomain import HelmholtzDomain 5 | from wavesim.multidomain import MultiDomain 6 | from wavesim.iteration import domain_operator, preconditioned_iteration, preconditioner, run_algorithm 7 | from wavesim.utilities import analytical_solution, preprocess, relative_error 8 | from . 
import allclose, random_vector, random_refractive_index 9 | 10 | """Tests to compare the result of Wavesim to analytical results""" 11 | 12 | 13 | def test_no_propagation(): 14 | """Basic test where the L-component is zero 15 | By manually removing the laplacian, we are solving the equation (2 π n / λ)² x = y 16 | """ 17 | n = random_refractive_index((2, 3, 4)) 18 | domain = HelmholtzDomain(permittivity=(n ** 2), periodic=(True, True, True)) 19 | x = random_vector(domain.shape) 20 | 21 | # manually disable the propagator, and test if, indeed, we are solving the system (2 π n / λ)² x = y 22 | L1 = 1.0 + domain.shift * domain.scale 23 | domain.propagator_kernel = 1.0 / L1 24 | domain.inverse_propagator_kernel = L1 25 | k2 = -(2 * torch.pi * n * domain.pixel_size) ** 2 # -(2 π n / λ)² 26 | B = (1.0 - (k2 - domain.shift) * domain.scale) 27 | assert allclose(domain_operator(domain, 'inverse_propagator')(x), x * L1) 28 | assert allclose(domain_operator(domain, 'propagator')(x), x / L1) 29 | assert allclose(domain_operator(domain, 'medium')(x), B * x) 30 | 31 | y = domain_operator(domain, 'forward')(x) 32 | assert allclose(y, k2 * x) 33 | 34 | domain.set_source(y) 35 | alpha = 0.75 36 | M = domain_operator(domain, 'richardson', alpha=alpha) 37 | x_wavesim = M(0) 38 | assert allclose(x_wavesim, (domain.scale * alpha / L1) * B * y) 39 | 40 | for _ in range(500): 41 | x_wavesim = M(x_wavesim) 42 | 43 | assert allclose(x_wavesim, x) 44 | 45 | x_wavesim = run_algorithm(domain, y, threshold=1.e-16)[0] 46 | assert allclose(x_wavesim, x) 47 | 48 | 49 | @pytest.mark.parametrize("size", [[32, 1, 1], [7, 15, 1], [13, 25, 46]]) 50 | @pytest.mark.parametrize("boundary_widths", [0, 10]) 51 | @pytest.mark.parametrize("periodic", [[True, True, True], # periodic boundaries, wrapped field 52 | [False, True, True]]) # wrapping correction 53 | def test_residual(size, boundary_widths, periodic): 54 | """ Check that the residual_norm at first iteration == 1 55 | residual_norm is normalized with the preconditioned source 56 | residual_norm = norm ( B(x - (L+1)⁻¹ (B·x + c·y)) ) 57 | norm of preconditioned source = norm( B(L+1)⁻¹y ) 58 | """ 59 | torch.manual_seed(0) # Set the random seed for reproducibility 60 | n = (torch.normal(mean=1.3, std=0.1, size=size, dtype=torch.float32) 61 | + 1j * abs(torch.normal(mean=0.05, std=0.02, size=size, dtype=torch.float32))).numpy() 62 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 63 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n 64 | 65 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(size)]]).T # Location: center of the domain 66 | values = torch.tensor([1.0]) # Amplitude: 1 67 | n_ext = tuple(np.array(size) + 2*boundary_array) 68 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 69 | 70 | wavelength = 1. 71 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 72 | # domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=1., n_domains=n_domains) 73 | 74 | # Reset the field u to zero 75 | slot_x = 0 76 | slot_tmp = 1 77 | domain.clear(slot_x) 78 | domain.set_source(source) 79 | 80 | # compute initial residual 81 | domain.add_source(slot_x, weight=1.) 
# [x] = y 82 | preconditioner(domain, slot_x, slot_x) # [x] = B(L+1)⁻¹y 83 | init_norm = domain.inner_product(slot_x, slot_x) # ⟨x, x⟩: squared norm of the preconditioned source 84 | domain.clear(slot_x) # Clear [x] 85 | 86 | residual_norm = preconditioned_iteration(domain, slot_x, slot_x, slot_tmp, alpha=0.75, compute_norm2=True) 87 | 88 | assert np.allclose(residual_norm, init_norm) 89 | 90 | 91 | @pytest.mark.parametrize("n_domains, periodic", [ 92 | ((1, 1, 1), (True, True, True)), # periodic boundaries, wrapped field. 93 | ((1, 1, 1), (False, True, True)), # wrapping correction (here and beyond) 94 | ((2, 1, 1), (False, True, True)), 95 | ((3, 1, 1), (False, True, True)), 96 | ]) 97 | def test_1d_analytical(n_domains, periodic): 98 | """ Test for 1D free-space propagation. Compare with analytic solution """ 99 | wavelength = 1. # wavelength in micrometer (um) 100 | pixel_size = wavelength / 4 # pixel size in wavelength units 101 | n_size = (512, 1, 1) # size of simulation domain (in pixels in x, y, and z direction) 102 | n = np.ones(n_size, dtype=np.complex64) 103 | boundary_widths = 16 # width of the boundary in pixels 104 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 105 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n 106 | 107 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain 108 | values = torch.tensor([1.0]) # Amplitude: 1 109 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 110 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 111 | 112 | # domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 113 | domain = MultiDomain(permittivity=n, periodic=periodic, 114 | wavelength=wavelength, pixel_size=pixel_size, n_domains=n_domains) 115 | u_computed = run_algorithm(domain, source, max_iterations=10000)[0] 116 | u_computed = u_computed.squeeze()[boundary_widths:-boundary_widths] 117 | u_ref = analytical_solution(n_size[0], domain.pixel_size, wavelength) 118 | 119 | re = relative_error(u_computed.cpu().numpy(), u_ref) 120 | print(f'Relative error: {re:.2e}') 121 | 122 | assert re <= 1.e-3, f'Relative error: {re:.2e}' 123 | assert allclose(u_computed, u_ref, atol=1.e-3, rtol=1.e-3) 124 | -------------------------------------------------------------------------------- /tests/test_basics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from wavesim.helmholtzdomain import HelmholtzDomain 4 | from wavesim.multidomain import MultiDomain 5 | from . import allclose, random_vector, random_refractive_index, dtype 6 | 7 | """ Performs a set of basic consistency checks for the Domain class and the HelmholtzBase multi-domain class.
""" 8 | 9 | 10 | def construct_domain(n_size, n_domains, n_boundary, periodic=(False, False, True)): 11 | """ Construct a domain or multi-domain""" 12 | torch.manual_seed(12345) 13 | n = random_refractive_index(n_size) 14 | if n_domains is None: # single domain 15 | return HelmholtzDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary, debug=True) 16 | else: 17 | return MultiDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary, 18 | n_domains=n_domains, debug=True) 19 | 20 | 21 | def construct_source(n_size): 22 | """ Construct a sparse-matrix source with some points at the corners and in the center""" 23 | locations = torch.tensor([ 24 | [n_size[0] // 2, 0, 0], 25 | [n_size[1] // 2, 0, 0], 26 | [n_size[2] // 2, 0, n_size[2] - 1]]) 27 | 28 | return torch.sparse_coo_tensor(locations, torch.tensor([1, 1, 1]), n_size, dtype=dtype) 29 | 30 | 31 | @pytest.mark.parametrize("n_size", [(128, 100, 93), (50, 49, 1)]) 32 | @pytest.mark.parametrize("n_domains", [None, (1, 1, 1), (3, 2, 1)]) 33 | def test_basics(n_size: tuple[int, int, int], n_domains: tuple[int, int, int] | None): 34 | """Tests the basic functionality of the Domain and MultiDomain classes 35 | 36 | Tests constructing a domain, subdividing data over subdomains, 37 | concatenating data from the subdomains, adding sparse sources, 38 | and computing the inner product. 39 | """ 40 | 41 | # construct the (multi-) domain operator 42 | domain = construct_domain(n_size, n_domains, n_boundary=8) 43 | 44 | # test coordinates 45 | assert domain.shape == n_size 46 | for dim in range(3): 47 | coordinates = domain.coordinates(dim) 48 | assert coordinates.shape[dim] == n_size[dim] 49 | assert coordinates.numel() == n_size[dim] 50 | 51 | coordinates_f = domain.coordinates_f(dim) 52 | assert coordinates_f.shape == coordinates.shape 53 | assert coordinates_f[0, 0, 0] == 0 54 | 55 | if n_size[dim] > 1: 56 | assert allclose(coordinates.flatten()[1] - coordinates.flatten()[0], domain.pixel_size) 57 | assert allclose(coordinates_f.flatten()[1] - coordinates_f.flatten()[0], 58 | 2.0 * torch.pi / (n_size[dim] * domain.pixel_size)) 59 | 60 | # construct a random vector for testing operators 61 | x = random_vector(n_size, device=domain.device) 62 | y = random_vector(n_size, device=domain.device) 63 | 64 | # perform some very basic checks 65 | # mainly, this tests if the partitioning and composition works correctly 66 | domain.set(0, x) 67 | domain.set(1, y) 68 | assert x.device == domain.device 69 | assert allclose(domain.get(0), x) 70 | assert allclose(domain.get(1), y) 71 | 72 | inp = domain.inner_product(0, 1) 73 | assert allclose(inp, torch.vdot(x.flatten(), y.flatten())) 74 | 75 | # construct a source and test adding it 76 | domain.clear(0) 77 | assert allclose(domain.get(0), 0.0) 78 | source = construct_source(n_size) 79 | domain.set_source(source) 80 | domain.add_source(0, 0.9) 81 | domain.add_source(0, 1.1) 82 | assert allclose(domain.get(0), 2.0 * source) 83 | x[0, 0, 0] = 1 84 | y[0, 0, 0] = 2 85 | # test mixing: α x + β y 86 | # make sure to include the special cases α=0, β=0, α=1, β=1 and α+β=1 87 | # since they may be optimized and thus have different code paths 88 | for alpha in [0.0, 1.0, 0.25, -0.1]: 89 | for beta in [0.0, 1.0, 0.75]: 90 | for out_slot in [0, 1]: 91 | domain.set(0, x) 92 | domain.set(1, y) 93 | domain.mix(alpha, 0, beta, 1, out_slot) 94 | assert allclose(domain.get(out_slot), alpha * x + beta * y) 95 | 96 | 97 | @pytest.mark.parametrize("n_size", [(128, 100, 93), (50, 49, 1)]) 98 | 
@pytest.mark.parametrize("n_domains", [None, (1, 1, 1), (3, 2, 1)]) 99 | def test_propagator(n_size: tuple[int, int, int], n_domains: tuple[int, int, int] | None): 100 | """Tests the forward and inverse propagator 101 | 102 | The wavesim algorithm only needs the propagator (L+1)^(-1) to be implemented. 103 | For testing, and for evaluating the final residue, the Domain and MultiDomain classes 104 | also implement the 'inverse propagator L+1', which is basically the homogeneous part of the forward operator A. 105 | 106 | This test checks that the forward and inverse propagator are consistent, namely (L+1)^(-1) (L+1) x = x. 107 | todo: check if the operators are actually correct (not just consistent) 108 | Note that the propagator is domain-local, so the wrapping correction and domain 109 | transfer functions are not tested here. 110 | """ 111 | 112 | # construct the (multi-) domain operator 113 | domain = construct_domain(n_size, n_domains, n_boundary=8) 114 | 115 | # assert that (L+1) (L+1)^-1 x = x 116 | x = random_vector(n_size) 117 | domain.set(0, x) 118 | domain.propagator(0, 0) 119 | domain.inverse_propagator(0, 0) 120 | x_reconstructed = domain.get(0) 121 | assert allclose(x, x_reconstructed) 122 | 123 | # also assert that (L+1)^-1 (L+1) x = x, use different slots for input and output 124 | domain.set(0, x) 125 | domain.inverse_propagator(0, 1) 126 | domain.propagator(1, 1) 127 | x_reconstructed = domain.get(1) 128 | assert allclose(x, x_reconstructed) 129 | 130 | # for the non-decomposed case, test if the propagator gives the correct value 131 | if n_domains is None: 132 | n_size = torch.tensor(n_size, dtype=torch.float64) 133 | # choose |k| < Nyquist, make sure k is at exact grid point in Fourier space 134 | k_relative = torch.tensor((0.2, -0.15, 0.4), dtype=torch.float64) 135 | k = 2 * torch.pi * torch.round(k_relative * n_size) / n_size # in 1/pixels 136 | k[n_size == 1] = 0.0 137 | plane_wave = torch.exp(1j * ( 138 | k[0] * torch.arange(n_size[0], device=domain.device).reshape(-1, 1, 1) + 139 | k[1] * torch.arange(n_size[1], device=domain.device).reshape(1, -1, 1) + 140 | k[2] * torch.arange(n_size[2], device=domain.device).reshape(1, 1, -1))) 141 | domain.set(0, plane_wave) 142 | domain.inverse_propagator(0, 0) 143 | result = domain.get(0) 144 | laplace_kernel = (k[0]**2 + k[1]**2 + k[2]**2) / domain.pixel_size ** 2 # -∇² [negative of laplace kernel] 145 | correct_result = (1.0 + domain.scale * (laplace_kernel + domain.shift)) * plane_wave # L+1 = scale·(-∇²) + 1. 146 | # note: the result is not exactly the same because wavesim is using the real-space kernel, and we compare to 147 | # the Fourier-space kernel 148 | assert allclose(result, correct_result, rtol=0.01) 149 | 150 | 151 | def test_basic_wrapping(): 152 | """Simple test if the wrapping correction is applied at the correct position. 153 | 154 | Constructs a 1-D domain and splits it in two. A source is placed at the right edge of the left domain. 
155 | """ 156 | n_size = (10, 1, 1) 157 | n_boundary = 2 158 | source = torch.sparse_coo_tensor(torch.tensor([[(n_size[0] - 1) // 2, 0, 0]]).T, torch.tensor([1.0]), n_size, 159 | dtype=dtype) 160 | domain = MultiDomain(permittivity=torch.ones(n_size, dtype=dtype), n_domains=(2, 1, 1), 161 | n_boundary=n_boundary, periodic=(False, True, True)) 162 | domain.clear(0) 163 | domain.set_source(source) 164 | domain.add_source(0, 1.0) 165 | left = torch.squeeze(domain.domains[0, 0, 0].get(0)) 166 | right = torch.squeeze(domain.domains[1, 0, 0].get(0)) 167 | total = torch.squeeze(domain.get(0)) 168 | assert allclose(torch.concat([left.to(domain.device), right.to(domain.device)]), total) 169 | assert torch.all(right == 0.0) 170 | assert torch.all(left[:-2] == 0.0) 171 | assert left[-1] != 0.0 172 | 173 | domain.medium(0, 1) 174 | 175 | # periodic in 2nd and 3rd dimension: no edges 176 | left_edges = domain.domains[0, 0, 0].edges 177 | right_edges = domain.domains[1, 0, 0].edges 178 | for edge in range(2, 6): 179 | assert left_edges[edge] is None 180 | assert right_edges[edge] is None 181 | 182 | # right domain should have zero edge corrections (since domain is empty) 183 | assert torch.all(right_edges[0] == 0.0) 184 | assert torch.all(right_edges[1] == 0.0) 185 | 186 | # left domain should have wrapping correction at the right edge 187 | # and nothing at the left edge 188 | assert torch.all(left_edges[0] == 0.0) 189 | assert left_edges[1].abs().max() > 1e-3 190 | 191 | # after applying the correction, the first n_boundary elements 192 | # of the left domain should be non-zero (wrapping correction) 193 | # and the first n_boundary elements of the right domain should be non-zero 194 | total2 = torch.squeeze(domain.get(1)) 195 | assert allclose(total2.real[0:n_boundary], -total2.real[n_size[0] // 2:n_size[0] // 2 + n_boundary]) 196 | 197 | 198 | def test_wrapped_propagator(): 199 | """Tests the inverse propagator L+1 with wrapping corrections 200 | 201 | This test compares the situation of a single large domain to that of a multi-domain. 202 | If the wrapping and transfer corrections are implemented correctly, the results should be the same 203 | up to the difference in scaling factor. 
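The tolerance of the comparison is taken as the magnitude of the single-domain result at the first point that is not covered by the wrapping correction.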
204 | """ 205 | # n_size = (128, 100, 93, 1) 206 | n_size = (3 * 32 * 1024, 1, 1) 207 | n_boundary = 16 208 | domain_single = construct_domain(n_size, n_domains=None, n_boundary=n_boundary, periodic=(True, True, True)) 209 | domain_multi = construct_domain(n_size, n_domains=(3, 1, 1), n_boundary=n_boundary, periodic=(True, True, True)) 210 | source = torch.sparse_coo_tensor(torch.tensor([[0, 0, 0]]).T, torch.tensor([1.0]), n_size, dtype=dtype) 211 | 212 | x = [None, None] 213 | for i, domain in enumerate([domain_single, domain_multi]): 214 | # evaluate L+1-B = L + Vscat + wrapping correction for the multi-domain, 215 | # and L+1-B = L + Vscat for the full domain 216 | # Note that we need to compensate for scaling squared, 217 | # because scaling affects both the source and operators L and B 218 | B = 0 219 | L1 = 1 220 | domain.clear(0) 221 | domain.set_source(source) 222 | domain.add_source(0, 1.0) 223 | domain.inverse_propagator(0, L1) # (L+1) y 224 | domain.medium(0, B) # (1-V) y 225 | domain.mix(1.0, L1, -1.0, B, 0) # (L+V) y 226 | x[i] = domain.get(0) / domain.scale 227 | 228 | # first non-compensated point 229 | pos = domain_multi.domains[0].shape[0] - n_boundary - 1 230 | atol = x[0][pos, 0, 0].abs().item() 231 | assert allclose(x[0], x[1], atol=atol) 232 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import torch 4 | import numpy as np 5 | from scipy.io import loadmat 6 | from PIL.Image import BILINEAR, fromarray, open 7 | from wavesim.helmholtzdomain import HelmholtzDomain 8 | from wavesim.multidomain import MultiDomain 9 | from wavesim.iteration import run_algorithm 10 | from wavesim.utilities import pad_boundaries, preprocess, relative_error 11 | 12 | if os.path.basename(os.getcwd()) == 'tests': 13 | os.chdir('..') 14 | 15 | 16 | @pytest.mark.parametrize("n_domains, periodic", [ 17 | (None, (True, True, True)), # periodic boundaries, wrapped field. 18 | ((1, 1, 1), (False, True, True)), # wrapping correction (here and beyond) 19 | ((2, 1, 1), (False, True, True)), 20 | ((3, 1, 1), (False, True, True)), 21 | ]) 22 | def test_1d_glass_plate(n_domains, periodic): 23 | """ Test for 1D propagation through glass plate. Compare with reference solution (matlab repo result). """ 24 | wavelength = 1. 25 | n_size = (256, 1, 1) 26 | n = np.ones(n_size, dtype=np.complex64) 27 | n[99:130] = 1.5 28 | boundary_widths = 24 29 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 30 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n 31 | 32 | indices = torch.tensor([[0 + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location: center of the domain 33 | values = torch.tensor([1.0]) # Amplitude: 1 34 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 35 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 36 | 37 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction) 38 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 39 | else: # OR. 
Domain decomposition 40 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=n_domains) 41 | 42 | u_computed = run_algorithm(domain, source, max_iterations=2000)[0] 43 | u_computed = u_computed.squeeze()[boundary_widths:-boundary_widths] 44 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 45 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u']) 46 | 47 | re = relative_error(u_computed.cpu().numpy(), u_ref) 48 | print(f'Relative error: {re:.2e}') 49 | threshold = 1.e-3 50 | assert re < threshold, f"Relative error higher than {threshold}" 51 | 52 | 53 | @pytest.mark.parametrize("n_domains", [ 54 | None, # periodic boundaries, wrapped field. 55 | (1, 1, 1), # wrapping correction (here and beyond) 56 | (2, 1, 1), 57 | (3, 1, 1), 58 | (1, 2, 1), 59 | (1, 3, 1), 60 | (2, 2, 1), 61 | ]) 62 | def test_2d_low_contrast(n_domains): 63 | """ Test for propagation in 2D structure with low refractive index contrast (made of fat and water to mimic 64 | biological tissue). Compare with reference solution (matlab repo result). """ 65 | oversampling = 0.25 66 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255 67 | n_water = 1.33 68 | n_fat = 1.46 69 | n_im = (np.where(im[:, :, 2] > 0.25, 1, 0) * (n_fat - n_water)) + n_water 70 | n_roi = int(oversampling * n_im.shape[0]) 71 | n = np.asarray(fromarray(n_im).resize((n_roi, n_roi), BILINEAR)) 72 | boundary_widths = 40 73 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 74 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n 75 | 76 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR)) 77 | source = pad_boundaries(source, boundary_array) 78 | source = torch.tensor(source, dtype=torch.complex64) 79 | 80 | wavelength = 0.532 81 | pixel_size = wavelength / (3 * abs(n_fat)) 82 | 83 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction) 84 | periodic = (True, True, True) # periodic boundaries, wrapped field. 85 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength) 86 | else: # OR. Domain decomposition 87 | periodic = np.where(np.array(n_domains) == 1, True, False) # True for 1 domain in direction, False otherwise 88 | periodic = tuple(periodic) 89 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size, 90 | n_domains=n_domains) 91 | 92 | u_computed = run_algorithm(domain, source, max_iterations=10000)[0] 93 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*2)] 94 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 95 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_lc']) 96 | 97 | re = relative_error(u_computed.cpu().numpy(), u_ref) 98 | print(f'Relative error: {re:.2e}') 99 | threshold = 1.e-3 100 | assert re < threshold, f"Relative error higher than {threshold}" 101 | 102 | 103 | @pytest.mark.parametrize("n_domains", [ 104 | None, # periodic boundaries, wrapped field. 105 | (1, 1, 1), # wrapping correction (here and beyond) 106 | (1, 2, 1), 107 | ]) 108 | def test_2d_high_contrast(n_domains): 109 | """ Test for propagation in 2D structure made of iron, with high refractive index contrast. 110 | Compare with reference solution (matlab repo result). 
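The grid spacing is set to wavelength / (3 * max|n|), which gives at least three grid points per wavelength everywhere in the structure.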
""" 111 | 112 | oversampling = 0.25 113 | im = np.asarray(open('examples/logo_structure_vector.png')) / 255 114 | n_iron = 2.8954 + 2.9179j 115 | n_contrast = n_iron - 1 116 | n_im = ((np.where(im[:, :, 2] > 0.25, 1, 0) * n_contrast) + 1) 117 | n_roi = int(oversampling * n_im.shape[0]) 118 | n = np.asarray(fromarray(n_im.real).resize((n_roi, n_roi), BILINEAR)) + 1j * np.asarray( 119 | fromarray(n_im.imag).resize((n_roi, n_roi), BILINEAR)) 120 | boundary_widths = 8 121 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 122 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n 123 | 124 | source = np.asarray(fromarray(im[:, :, 1]).resize((n_roi, n_roi), BILINEAR)) 125 | source = pad_boundaries(source, boundary_array) 126 | source = torch.tensor(source, dtype=torch.complex64) 127 | 128 | wavelength = 0.532 129 | pixel_size = wavelength / (3 * np.max(abs(n_contrast + 1))) 130 | 131 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction) 132 | periodic = (True, True, True) # periodic boundaries, wrapped field. 133 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, pixel_size=pixel_size, wavelength=wavelength) 134 | else: # OR. Domain decomposition 135 | periodic = np.where(np.array(n_domains) == 1, True, False) # True for 1 domain in direction, False otherwise 136 | periodic = tuple(periodic) 137 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, pixel_size=pixel_size, 138 | n_domains=n_domains) 139 | 140 | u_computed = run_algorithm(domain, source, max_iterations=int(1.e+5))[0] 141 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*2)] 142 | 143 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 144 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u2d_hc']) 145 | 146 | re = relative_error(u_computed.cpu().numpy(), u_ref) 147 | print(f'Relative error: {re:.2e}') 148 | threshold = 1.e-3 149 | assert re < threshold, f"Relative error {re} higher than {threshold}" 150 | 151 | 152 | @pytest.mark.parametrize("n_domains", [ 153 | None, # periodic boundaries, wrapped field. 154 | (1, 1, 1), # wrapping correction (here and beyond) 155 | (2, 1, 1), 156 | (3, 1, 1), 157 | (1, 2, 1), 158 | (1, 3, 1), 159 | (1, 1, 2), 160 | (2, 2, 1), 161 | (2, 1, 2), 162 | (1, 2, 2), 163 | (2, 2, 2), 164 | (3, 2, 1), 165 | (3, 1, 2), 166 | (1, 3, 2), 167 | (1, 2, 3), 168 | ]) 169 | def test_3d_disordered(n_domains): 170 | """ Test for propagation in a 3D disordered medium. Compare with reference solution (matlab repo result). """ 171 | wavelength = 1. 
172 | n_size = (128, 48, 96) 173 | n = np.ascontiguousarray(loadmat('examples/matlab_results.mat')['n3d_disordered']) 174 | boundary_widths = 12 175 | # return permittivity (n²) with boundaries, and boundary_widths in format (ax0, ax1, ax2) 176 | n, boundary_array = preprocess(n ** 2, boundary_widths) # permittivity is n², but uses the same variable n 177 | 178 | # Source: single point source in the center of the domain 179 | indices = torch.tensor([[int(v/2 - 1) + boundary_array[i] for i, v in enumerate(n_size)]]).T # Location 180 | values = torch.tensor([1.0]) # Amplitude: 1 181 | n_ext = tuple(np.array(n_size) + 2*boundary_array) 182 | source = torch.sparse_coo_tensor(indices, values, n_ext, dtype=torch.complex64) 183 | 184 | if n_domains is None: # 1-domain, periodic boundaries (without wrapping correction) 185 | periodic = (True, True, True) # periodic boundaries, wrapped field. 186 | domain = HelmholtzDomain(permittivity=n, periodic=periodic, wavelength=wavelength) 187 | else: # OR. Domain decomposition 188 | periodic = np.where(np.array(n_domains) == 1, True, False) # True for 1 domain in direction, False otherwise 189 | periodic = tuple(periodic) 190 | domain = MultiDomain(permittivity=n, periodic=periodic, wavelength=wavelength, n_domains=n_domains) 191 | 192 | u_computed = run_algorithm(domain, source, max_iterations=1000)[0] 193 | u_computed = u_computed.squeeze()[*([slice(boundary_widths, -boundary_widths)]*3)] 194 | 195 | # load dictionary of results from matlab wavesim/anysim for comparison and validation 196 | u_ref = np.squeeze(loadmat('examples/matlab_results.mat')['u3d_disordered']) 197 | 198 | re = relative_error(u_computed.cpu().numpy(), u_ref) 199 | print(f'Relative error: {re:.2e}') 200 | threshold = 1.e-3 201 | assert re < threshold, f"Relative error {re} higher than {threshold}" 202 | -------------------------------------------------------------------------------- /tests/test_operators.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from wavesim.helmholtzdomain import HelmholtzDomain 4 | from wavesim.multidomain import MultiDomain 5 | from wavesim.iteration import domain_operator 6 | from wavesim.utilities import full_matrix 7 | from . 
import random_vector, allclose, dtype 8 | 9 | """ Performs checks on the operators represented as matrices (accretivity, norm).""" 10 | 11 | parameters = [ 12 | {'n_size': (1, 1, 12), 'n_domains': None, 'n_boundary': 0, 'periodic': (False, False, True)}, 13 | {'n_size': (1, 1, 12), 'n_domains': (1, 1, 1), 'n_boundary': 0, 'periodic': (False, False, True)}, 14 | {'n_size': (1, 1, 12), 'n_domains': (1, 1, 1), 'n_boundary': 5, 'periodic': (True, True, False)}, 15 | {'n_size': (1, 1, 32), 'n_domains': (1, 1, 2), 'n_boundary': 5, 'periodic': (True, True, True)}, 16 | {'n_size': (1, 1, 32), 'n_domains': (1, 1, 2), 'n_boundary': 5, 'periodic': (True, True, False)}, 17 | {'n_size': (1, 32, 1), 'n_domains': (1, 2, 1), 'n_boundary': 5, 'periodic': (True, True, True)}, 18 | {'n_size': (1, 32, 1), 'n_domains': (1, 2, 1), 'n_boundary': 5, 'periodic': (True, False, True)}, 19 | {'n_size': (32, 1, 1), 'n_domains': (2, 1, 1), 'n_boundary': 5, 'periodic': (False, True, True)}, 20 | {'n_size': (32, 1, 1), 'n_domains': (2, 1, 1), 'n_boundary': 5, 'periodic': (True, True, True)}, 21 | # test different-sized domains 22 | {'n_size': (23, 1, 1), 'n_domains': (3, 1, 1), 'n_boundary': 2, 'periodic': (True, True, True)}, 23 | {'n_size': (23, 1, 1), 'n_domains': (3, 1, 1), 'n_boundary': 3, 'periodic': (False, True, True)}, 24 | {'n_size': (1, 5, 19), 'n_domains': (1, 1, 2), 'n_boundary': 3, 'periodic': (True, True, True)}, 25 | {'n_size': (1, 14, 19), 'n_domains': (1, 2, 2), 'n_boundary': 3, 'periodic': (True, False, True)}, 26 | {'n_size': (17, 30, 1), 'n_domains': (2, 3, 1), 'n_boundary': 3, 'periodic': (True, True, True)}, 27 | {'n_size': (8, 8, 8), 'n_domains': (2, 2, 2), 'n_boundary': 2, 'periodic': (False, False, True)}, 28 | {'n_size': (8, 12, 8), 'n_domains': (2, 3, 2), 'n_boundary': 2, 'periodic': (True, True, True)}, 29 | # these parameters are very slow for test_accretivity and should be run only when needed 30 | # {'n_size': (12, 12, 12), 'n_domains': (2, 2, 2), 'n_boundary': 3, 'periodic': (False, False, True)}, 31 | # {'n_size': (18, 24, 18), 'n_domains': (2, 3, 2), 'n_boundary': 3, 'periodic': (True, True, True)}, 32 | # {'n_size': (17, 23, 19), 'n_domains': (2, 3, 2), 'n_boundary': 3, 'periodic': (True, True, True)}, 33 | ] 34 | 35 | 36 | def construct_domain(n_size, n_domains, n_boundary, periodic=(False, False, True)): 37 | """ Construct a domain or multi-domain""" 38 | torch.manual_seed(12345) 39 | n = torch.rand(n_size, device='cuda' if torch.cuda.is_available() else 'cpu', 40 | dtype=dtype) + 1.0 # random refractive index between 1 and 2 41 | n.imag = 0.1 * torch.maximum(n.imag, torch.tensor(0.0)) # a positive imaginary part of n corresponds to absorption 42 | if n_domains is None: # single domain 43 | return HelmholtzDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary, debug=True) 44 | else: 45 | return MultiDomain(permittivity=n, periodic=periodic, n_boundary=n_boundary, 46 | n_domains=n_domains, debug=True) 47 | 48 | 49 | @pytest.mark.parametrize("params", parameters) 50 | def test_operators(params): 51 | """ Check that operator definitions are consistent: 52 | - forward = inverse_propagator - medium: A= L + 1 - B 53 | - preconditioned_operator = preconditioned(operator) 54 | - richardson = x + α (Γ⁻¹b - Γ⁻¹A x) 55 | """ 56 | domain = construct_domain(**params) 57 | x = random_vector(domain.shape, device=domain.device) 58 | B = domain_operator(domain, 'medium') 59 | L1 = domain_operator(domain, 'inverse_propagator') 60 | A = domain_operator(domain, 'forward') 61 | Ax = A(x) 62 | 
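# forward() evaluates A x = c⁻¹ ((L+1)x - Bx), so scaling Ax by c (= domain.scale) should reproduce L1(x) - B(x):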
assert allclose(domain.scale * Ax, L1(x) - B(x)) 63 | 64 | Γ = domain_operator(domain, 'preconditioner') 65 | ΓA = domain_operator(domain, 'preconditioned_operator') 66 | ΓAx = ΓA(x) 67 | assert allclose(ΓAx, Γ(Ax)) 68 | 69 | α = 0.1 70 | b = random_vector(domain.shape) 71 | Γb = Γ(b) 72 | domain.set_source(b) 73 | M = domain_operator(domain, 'richardson', alpha=α) 74 | assert allclose(M(0), α * Γb) 75 | 76 | residual = Γb - ΓAx 77 | assert allclose(M(x), x.to(residual.device) + α * residual) 78 | 79 | 80 | @pytest.mark.parametrize("params", parameters) 81 | def test_accretivity(params): 82 | """ Checks norm and lower bound of real part for various operators 83 | 84 | B (medium) should have real part between -0.05 and 1.0 (if we don't put the absorption in V0. If we do, the upper 85 | limit may be 1.95) 86 | The operator B-1 should have a norm of less than 0.95 87 | 88 | L + 1 (inverse propagator) should be accretive with a real part of at least 1.0 89 | (L+1)^-1 (propagator) should be accretive with a real part of at least 0.0 90 | A (forward) should be accretive with a real part of at least 0.0 91 | ΓA (preconditioned_operator) should be such that 1-ΓA is a contraction (a norm of less than 1.0) 92 | """ 93 | domain = construct_domain(**params) 94 | domain.set_source(0) 95 | assert_accretive(domain_operator(domain, 'medium'), 'B', real_min=0.05, real_max=1.0, norm_max=0.95, 96 | norm_offset=1.0) 97 | assert_accretive(domain_operator(domain, 'inverse_propagator'), 'L + 1', real_min=1.0) 98 | assert_accretive(domain_operator(domain, 'propagator'), '(L + 1)^-1', real_min=0.0) 99 | assert_accretive(domain_operator(domain, 'preconditioned_operator'), 'ΓA', norm_max=1.0, norm_offset=1.0) 100 | assert_accretive(domain_operator(domain, 'richardson', alpha=0.75), '1- α ΓA', norm_max=1.0) 101 | assert_accretive(domain_operator(domain, 'forward'), 'A', real_min=0.0, pre_factor=domain.scale) 102 | 103 | 104 | def assert_accretive(operator, name, *, real_min=None, real_max=None, norm_max=None, norm_offset=None, pre_factor=None): 105 | """ Helper function to check if an operator is accretive, and to compute the norm around a given offset. 106 | This function constructs a full matrix from the operator, so it only works if the domain is not too large. 107 | """ 108 | M = full_matrix(operator) 109 | if pre_factor is not None: 110 | M *= pre_factor 111 | 112 | if norm_max is not None: 113 | if norm_offset is not None: 114 | M.diagonal().add_(-norm_offset) 115 | norm = torch.linalg.norm(M, ord=2) 116 | print(f'norm {norm:.2e}') 117 | assert norm <= norm_max, f'operator {name} has norm {norm} > {norm_max}' 118 | 119 | if real_min is not None or real_max is not None: 120 | M.add_(M.mH) 121 | eigs = torch.linalg.eigvalsh(M) 122 | if norm_offset is not None: 123 | eigs.add_(norm_offset) 124 | if real_min is not None: 125 | acc = eigs.min() 126 | print(f'acc {acc:.2e}') 127 | assert acc >= real_min, f'operator {name} is not accretive, min λ_(A+A*) = {acc} < {real_min}' 128 | if real_max is not None: 129 | acc = eigs.max() 130 | print(f'acc {acc:.2e}') 131 | assert acc <= real_max, (f'operator {name} has eigenvalues that are too large, ' 132 | f'max λ_(A+A*) = {acc} > {real_max}') 133 | -------------------------------------------------------------------------------- /tests/test_utilities.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from wavesim.utilities import partition, combine 4 | from . 
import allclose 5 | 6 | """Test of the utility functions.""" 7 | 8 | 9 | @pytest.mark.parametrize("size", [(5, 4, 6), (7, 15, 32), (3, 5, 6)]) 10 | @pytest.mark.parametrize("n_domains", [(1, 2, 3), (3, 3, 3), (2, 4, 1)]) 11 | @pytest.mark.parametrize("type", ['full', 'sparse', 'hybrid1', 'hybrid2']) 12 | @pytest.mark.parametrize("expanded", [False, True]) 13 | def test_partition_combine(size: tuple[int, int, int], n_domains: tuple[int, int, int], type: str, expanded: bool): 14 | if expanded: 15 | x = torch.tensor(1.0, dtype=torch.complex64).expand(size) 16 | else: 17 | x = torch.randn(size, dtype=torch.complex64) + 1j * torch.randn(size, dtype=torch.complex64) 18 | if type == 'sparse': 19 | x[x.real < 0.5] = 0 20 | x = x.to_sparse() 21 | elif type == 'hybrid1': 22 | # select half of the slices, make rest zero 23 | indices = torch.range(0, size[0] - 1, 2, dtype=torch.int64) # construct indices for the other half 24 | values = x[0::2, :, :] 25 | x = torch.sparse_coo_tensor(indices.reshape(1, -1), values, size) 26 | elif type == 'hybrid2': 27 | indices0 = torch.range(0, size[0] - 1, 2, dtype=torch.int64) 28 | indices1 = torch.range(0, size[1] - 1, 2, dtype=torch.int64) 29 | i0, i1 = torch.meshgrid(indices0, indices1) 30 | indices = torch.stack((i0.reshape(-1), i1.reshape(-1)), dim=0) 31 | values = x[0::2, 0::2, :].reshape(-1, x.shape[2]) 32 | x = torch.sparse_coo_tensor(indices, values, size) 33 | 34 | partitions = partition(x, n_domains) 35 | assert partitions.shape == n_domains 36 | 37 | combined = combine(partitions) 38 | assert allclose(combined, x) 39 | 40 | 41 | @pytest.mark.parametrize("n_domains", [(0, 0, 0), (4, 3, 3)]) 42 | def test_partition_with_invalid_input(n_domains): 43 | array = torch.randn((3, 3, 3), dtype=torch.complex64) 44 | with pytest.raises(ValueError): 45 | partition(array, n_domains) 46 | -------------------------------------------------------------------------------- /wavesim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IvoVellekoop/wavesim_py/3fc81f6ef9f3ee523575bd81f196c4f34ee0b788/wavesim/__init__.py -------------------------------------------------------------------------------- /wavesim/domain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class Domain(metaclass=ABCMeta): 6 | """Base class for all simulation domains 7 | 8 | This base class defines the interface for operations that are common for all simulation types, 9 | and for MultiDomain. 10 | todo: the design using slots minimizes memory use, but it is a suboptimal design because it 11 | mixes mutable and immutable state. This design should be revisited so that the Domain is 12 | immutable, and the code that runs the algorithms performs the memory management. 13 | """ 14 | 15 | def __init__(self, pixel_size: float, shape, device): 16 | self.pixel_size = pixel_size 17 | self.scale = None 18 | self.shift = None 19 | self.shape = shape 20 | self.device = device 21 | 22 | @abstractmethod 23 | def add_source(self, slot, weight: float): 24 | pass 25 | 26 | @abstractmethod 27 | def clear(self, slot): 28 | """Clears the data in the specified slot""" 29 | pass 30 | 31 | @abstractmethod 32 | def get(self, slot: int, copy=False): 33 | """Returns the data in the specified slot. 34 | 35 | :param slot: slot from which to return the data 36 | :param copy: if True, returns a copy of the data. 
Otherwise, may return the original data possible. 37 | 38 | Note that this data may be overwritten by the next call to domain. 39 | """ 40 | pass 41 | 42 | @abstractmethod 43 | def set(self, slot, data): 44 | """Copy the date into the specified slot""" 45 | pass 46 | 47 | @abstractmethod 48 | def inner_product(self, slot_a, slot_b): 49 | """Computes the inner product of two data vectors 50 | 51 | Note: 52 | The vectors may be represented as multidimensional arrays, 53 | but these arrays must be contiguous for this operation to work. 54 | Although it would be possible to use flatten(), this would create a 55 | copy when the array is not contiguous, causing a hidden performance hit. 56 | """ 57 | pass 58 | 59 | @abstractmethod 60 | def medium(self, slot_in, slot_out, mnum): 61 | """Applies the operator 1-Vscat.""" 62 | pass 63 | 64 | @abstractmethod 65 | def mix(self, weight_a, slot_a, weight_b, slot_b, slot_out): 66 | """Mixes two data arrays and stores the result in the specified slot""" 67 | pass 68 | 69 | @abstractmethod 70 | def propagator(self, slot_in, slot_out): 71 | """Applies the operator (L+1)^-1 x. 72 | """ 73 | pass 74 | 75 | @abstractmethod 76 | def inverse_propagator(self, slot_in, slot_out): 77 | """Applies the operator (L+1) x . 78 | 79 | This operation is not needed for the Wavesim algorithm, but is provided for testing purposes, 80 | and can be used to evaluate the residue of the solution. 81 | """ 82 | pass 83 | 84 | @abstractmethod 85 | def set_source(self, source): 86 | """Sets the source term for this domain.""" 87 | pass 88 | 89 | @abstractmethod 90 | def create_empty_vdot(self): 91 | """Create an empty tensor for the Vdot tensor""" 92 | pass 93 | 94 | def coordinates_f(self, dim): 95 | """Returns the Fourier-space coordinates along the specified dimension""" 96 | shapes = [[-1, 1, 1], [1, -1, 1], [1, 1, -1]] 97 | return (2 * torch.pi * torch.fft.fftfreq(self.shape[dim], self.pixel_size, device=self.device, 98 | dtype=torch.float64)).reshape(shapes[dim]).to(torch.complex64) 99 | 100 | def coordinates(self, dim, type: str = 'linear'): 101 | """Returns the real-space coordinates along the specified dimension, starting at 0""" 102 | shapes = [[-1, 1, 1], [1, -1, 1], [1, 1, -1]] 103 | x = torch.arange(self.shape[dim], device=self.device, dtype=torch.float64) * self.pixel_size 104 | if type == 'periodic': 105 | x -= self.pixel_size * (self.shape[dim] // 2) 106 | x = torch.fft.ifftshift(x) # todo: or fftshift? 107 | elif type == 'centered': 108 | x -= self.pixel_size * (self.shape[dim] // 2) 109 | elif type == 'linear': 110 | pass 111 | else: 112 | raise ValueError(f"Unknown type {type}") 113 | 114 | return x.reshape(shapes[dim]) 115 | -------------------------------------------------------------------------------- /wavesim/iteration.py: -------------------------------------------------------------------------------- 1 | from .domain import Domain 2 | from .utilities import is_zero 3 | 4 | 5 | def run_algorithm(domain: Domain, source, alpha=0.75, max_iterations=1000, threshold=1.e-6, full_residuals=False): 6 | """ WaveSim update 7 | 8 | :param domain: Helmholtz base parameters 9 | :param source: source field 10 | :param alpha: relaxation parameter for the Richardson iteration 11 | :param max_iterations: maximum number of iterations 12 | :param threshold: threshold for the residual norm 13 | :param full_residuals: when True, returns list of residuals for all iterations. 
Otherwise, returns final residual 14 | :return: u, iteration count, residuals """ 15 | 16 | # Reset the field u to zero 17 | slot_x = 0 18 | slot_tmp = 1 19 | domain.clear(slot_x) 20 | domain.set_source(source) 21 | 22 | # compute initial residual norm (with preconditioned source) for normalization 23 | domain.create_empty_vdot() # create empty slot for Vdot tensor. Always 8.1 MiB 24 | domain.add_source(slot_x, weight=1.) # [x] = y 25 | preconditioner(domain, slot_x, slot_x) # [x] = B(L+1)⁻¹y 26 | init_norm_inv = 1 / domain.inner_product(slot_x, slot_x) # inverse of initial norm: 1 / norm([x]) 27 | domain.clear(slot_x) # Clear [x] 28 | 29 | # save list of residuals if requested 30 | residuals = [] if full_residuals else None 31 | 32 | for i in range(max_iterations): 33 | residual_norm = preconditioned_iteration(domain, slot_x, slot_x, slot_tmp, alpha, compute_norm2=True) 34 | # normalize residual norm with preconditioned source (i.e., with norm of B(L+1)⁻¹y) 35 | residual_norm = residual_norm * init_norm_inv # norm(B(x - (L+1)⁻¹ (B·x + c·y))) / norm(B(L+1)⁻¹y) 36 | print('.', end='', flush=True) if (i + 1) % 100 == 0 else None 37 | residuals.append(residual_norm) if full_residuals else None 38 | if residual_norm < threshold: 39 | break 40 | 41 | # return u and u_iter cropped to roi, residual arrays, and state object with information on run 42 | return domain.get(slot_x), (i + 1), residuals if full_residuals else residual_norm 43 | 44 | 45 | def preconditioned_iteration(domain, slot_in: int = 0, slot_out: int = 0, slot_tmp: int = 1, alpha=0.75, 46 | compute_norm2=False): 47 | """ Run one preconditioned iteration. 48 | 49 | Args: 50 | domain: Domain object 51 | slot_in: slot holding input x. This slot will be overwritten! 52 | slot_out: output slot that will receive the result 53 | slot_tmp: slot for temporary storage. Cannot be equal to slot_in, may be equal to slot_out 54 | alpha: relaxation parameter for the Richardson iteration 55 | compute_norm2: when True, returns the squared norm of the residual. Otherwise, returns 0.0 56 | 57 | Richardson iteration: 58 | x -> x + α (y - A x) 59 | 60 | Preconditioned Richardson iteration: 61 | x -> x + α Γ⁻¹ (y - A x) 62 | = x + α c B (L+1)⁻¹ (y - A x) 63 | = x + α c B (L+1)⁻¹ (y - c⁻¹ [L+V] x) 64 | = x + α c B (L+1)⁻¹ (y + c⁻¹ [1-V] x - c⁻¹ [L+1] x) 65 | = x + α B [(L+1)⁻¹ (c y + B x) - x] 66 | = x - α B x + α B (L+1)⁻¹ (c y + B x) 67 | """ 68 | if slot_tmp == slot_in: 69 | raise ValueError("slot_in and slot_tmp should be different") 70 | 71 | domain.medium(slot_in, slot_tmp, mnum=0) # [tmp] = B·x 72 | domain.add_source(slot_tmp, domain.scale) # [tmp] = B·x + c·y 73 | domain.propagator(slot_tmp, slot_tmp) # [tmp] = (L+1)⁻¹ (B·x + c·y) 74 | domain.mix(1.0, slot_in, -1.0, slot_tmp, slot_tmp) # [tmp] = x - (L+1)⁻¹ (B·x + c·y) 75 | domain.medium(slot_tmp, slot_tmp, mnum=1) # [tmp] = B(x - (L+1)⁻¹ (B·x + c·y)) 76 | # optionally compute norm of residual of preconditioned system 77 | retval = domain.inner_product(slot_tmp, slot_tmp) if compute_norm2 else 0.0 78 | domain.mix(1.0, slot_in, -alpha, slot_tmp, slot_out) # [out] = x - α B x + α B (L+1)⁻¹ (B·x + c·y) 79 | return retval 80 | 81 | 82 | def forward(domain: Domain, slot_in: int, slot_out: int): 83 | """ Evaluates the forward operator A= c⁻¹ (L + V) 84 | 85 | Args: 86 | domain: Domain object 87 | slot_in: slot holding input x. This slot will be overwritten! 
88 | slot_out: output slot that will receive A x 89 | """ 90 | if slot_in == slot_out: 91 | raise ValueError("slot_in and slot_out must be different") 92 | 93 | domain.medium(slot_in, slot_out) # (1-V) x 94 | domain.inverse_propagator(slot_in, slot_in) # (L+1) x 95 | domain.mix(1.0 / domain.scale, slot_in, -1.0 / domain.scale, slot_out, slot_out) # c⁻¹ (L+V) x 96 | 97 | 98 | def preconditioned_operator(domain: Domain, slot_in: int, slot_out: int): 99 | """ Evaluates the preconditioned operator Γ⁻¹ A 100 | 101 | Where Γ⁻¹ = c B (L+1)⁻¹ 102 | 103 | Note: the scale factor c that makes A accretive and V a contraction is 104 | included in the preconditioner. The Richardson step size is _not_. 105 | 106 | Operator A is the original non-scaled operator, and we have (L+V) = c A 107 | Then: 108 | 109 | Γ⁻¹ A = c B (L+1)⁻¹ A 110 | = c B (L+1)⁻¹ c⁻¹ (L+V) 111 | = B (L+1)⁻¹ (L+V) 112 | = B (L+1)⁻¹ ([L+1] - [1-V]) 113 | = B - B (L+1)⁻¹ B 114 | 115 | Args: 116 | domain: Domain object 117 | slot_in: slot holding input x. This slot will be overwritten! 118 | slot_out: output slot that will receive A x 119 | """ 120 | if slot_in == slot_out: 121 | raise ValueError("slot_in and slot_out must be different") 122 | 123 | domain.medium(slot_in, slot_in) # B x 124 | domain.propagator(slot_in, slot_out) # (L+1)⁻¹ B x 125 | domain.medium(slot_out, slot_out) # B (L+1)⁻¹ B x 126 | domain.mix(1.0, slot_in, -1.0, slot_out, slot_out) # B - B (L+1)⁻¹ B x 127 | 128 | 129 | def preconditioner(domain: Domain, slot_in: int, slot_out: int): 130 | """ Evaluates Γ⁻¹ = c B(L+1)⁻¹ 131 | 132 | Args: 133 | domain: Domain object 134 | slot_in: slot holding input x. This slot will be overwritten! 135 | slot_out: output slot that will receive A x 136 | """ 137 | domain.propagator(slot_in, slot_in) # (L+1)⁻¹ x 138 | domain.medium(slot_in, slot_out) # B (L+1)⁻¹ x 139 | domain.mix(0.0, slot_out, domain.scale, slot_out, slot_out) # c B (L+1)⁻¹ x 140 | 141 | 142 | def domain_operator(domain: Domain, function: str, **kwargs): 143 | """Constructs various operators by combining calls to 'medium', 'propagator', etc. 
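The returned object behaves like a plain function: it copies its argument into slot 0 (or clears that slot when the argument is the scalar 0), applies the requested operation into slot 1, and returns a copy of slot 1. For example, domain_operator(domain, 'forward') yields a callable x -> A x, which the tests pass to full_matrix.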
144 | 145 | todo: this is currently very inefficient because of the overhead of copying data to and from the device 146 | """ 147 | 148 | def potential_(domain, slot_in, slot_out): 149 | domain.medium(slot_in, slot_out) 150 | domain.mix(1.0, slot_in, -1.0, slot_out, slot_out) 151 | 152 | fn = { 153 | 'medium': domain.medium, 154 | 'propagator': domain.propagator, 155 | 'inverse_propagator': domain.inverse_propagator, 156 | 'potential': potential_, 157 | 'forward': lambda slot_in, slot_out: forward(domain, slot_in, slot_out), 158 | 'preconditioned_operator': lambda slot_in, slot_out: preconditioned_operator(domain, slot_in, slot_out), 159 | 'preconditioner': lambda slot_in, slot_out: preconditioner(domain, slot_in, slot_out), 160 | 'richardson': lambda slot_in, slot_out: preconditioned_iteration(domain, slot_in, slot_out, slot_out, **kwargs) 161 | }[function] 162 | 163 | def operator_(x): 164 | if is_zero(x): 165 | domain.clear(0) 166 | else: 167 | domain.set(0, x) 168 | fn(0, 1) 169 | return domain.get(1, copy=True) 170 | 171 | return operator_ 172 | -------------------------------------------------------------------------------- /wavesim/multidomain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .domain import Domain 4 | from .helmholtzdomain import HelmholtzDomain 5 | from .utilities import partition, combine, list_to_array, is_zero 6 | from torch.cuda import empty_cache 7 | 8 | 9 | class MultiDomain(Domain): 10 | """ Class for generating medium (B) and propagator (L+1)^(-1) operators, scaling, 11 | and setting up wrapping and transfer corrections """ 12 | 13 | def __init__(self, 14 | permittivity, 15 | periodic: tuple[bool, bool, bool], 16 | pixel_size: float = 0.25, 17 | wavelength: float = None, 18 | n_domains: tuple[int, int, int] = (1, 1, 1), 19 | n_boundary: int = 8, 20 | device: str = None, 21 | debug: bool = False): 22 | """ Takes input parameters for the HelmholtzBase class (and sets up the operators) 23 | 24 | Args: 25 | permittivity: Permittivity distribution, must be 3-d. 26 | periodic: Indicates for each dimension whether the simulation is periodic or not. 27 | periodic dimensions, the field is wrapped around the domain. 28 | pixel_size: Grid spacing in wavelengths. 29 | wavelength: wavelength in micrometer (um). 30 | n_domains: number of domains to split the simulation into. 31 | the domain size is not divisible by n_domains, the last domain will be slightly smaller than the other 32 | ones. In the future, the domain size may be adjusted to have an efficient fourier transform. 33 | is (1, 1, 1), no domain decomposition. 34 | n_boundary: Number of points used in the wrapping and domain transfer correction. Default is 8. 35 | device: 'cpu' to use the cpu, 'cuda' to distribute the simulation over all available cuda devices, 'cuda:x' 36 | to use a specific cuda device, a list of strings, e.g., ['cuda:0', 'cuda:1'] to distribute the 37 | simulation over these devices in a round-robin fashion, or None, which is equivalent to 'cuda' if 38 | cuda devices are available, and 'cpu' if they are not. 
39 | todo: implement 40 | debug: set to True to return inverse_propagator_kernel as output 41 | """ 42 | 43 | # Takes the input parameters and returns these in the appropriate format, with more parameters for setting up 44 | # the Medium (+corrections) and Propagator operators, and scaling 45 | # (self.n_roi, self.s, self.n_dims, self.boundary_widths, self.boundary_pre, self.boundary_post, 46 | # self.n_domains, self.domain_size, self.omega, self.v_min, self.v_raw) = ( 47 | # preprocess(n, pixel_size, n_domains)) 48 | 49 | # validata input parameters 50 | if not permittivity.ndim == 3: 51 | raise ValueError("The permittivity must be a 3D array") 52 | if not len(n_domains) == 3: 53 | raise ValueError("The number of domains must be a 3-tuple") 54 | 55 | # enumerate the cuda devices. We will assign the domains to the devices in a round-robin fashion. 56 | # we use the first GPU as primary device 57 | if device is None or device == 'cuda': 58 | devices = [f'cuda:{device_id}' for device_id in 59 | range(torch.cuda.device_count())] if torch.cuda.is_available() else ['cpu'] 60 | else: 61 | devices = [device] 62 | 63 | if not torch.is_tensor(permittivity): 64 | permittivity = torch.tensor(permittivity) 65 | super().__init__(pixel_size, permittivity.shape, torch.device(devices[0])) 66 | self.periodic = np.array(periodic) 67 | 68 | # compute domain boundaries in each dimension 69 | self.domains = np.empty(n_domains, dtype=HelmholtzDomain) 70 | self.n_domains = n_domains 71 | 72 | # distribute the permittivity map over the subdomains. 73 | p_domains = partition(permittivity, self.n_domains) 74 | subdomain_periodic = [periodic[i] and n_domains[i] == 1 for i in range(3)] 75 | Vwrap = None 76 | for domain_index, p_domain in enumerate(p_domains.flat): 77 | # p_domain = torch.tensor(p_domain, device=devices[domain_index % len(devices)]) 78 | self.domains.flat[domain_index] = HelmholtzDomain(permittivity=p_domain.to(devices[domain_index % 79 | len(devices)]), 80 | pixel_size=pixel_size, wavelength=wavelength, 81 | n_boundary=n_boundary, periodic=subdomain_periodic, 82 | stand_alone=False, debug=debug, Vwrap=Vwrap) 83 | Vwrap = self.domains.flat[domain_index].Vwrap # re-use wrapping matrix 84 | 85 | # determine the optimal shift 86 | limits = np.array([domain.V_bounds for domain in self.domains.flat]) 87 | r_min = np.min(limits[:, 0]) 88 | r_max = np.max(limits[:, 1]) 89 | # i_min = np.min(limits[:, 2]) 90 | # i_max = np.max(limits[:, 3]) 91 | center = 0.5 * (r_min + r_max) # + 0.5j * (i_min + i_max) 92 | 93 | # shift L and V to minimize norm of V 94 | Vscat_norm = 0.0 95 | Vwrap_norm = 0.0 96 | for domain in self.domains.flat: 97 | Vscat_norm = np.maximum(Vscat_norm, domain.initialize_shift(center)) 98 | Vwrap_norm = np.maximum(Vwrap_norm, domain.Vwrap_norm) 99 | 100 | # the factor 2 is because the same matrix is used twice (for domain transfer and wrapping correction) 101 | Vwrap_norm = 2 * Vwrap_norm if max(n_domains) > 1 else Vwrap_norm 102 | 103 | # compute the scaling factor 104 | # apply the scaling to compute the final form of all operators in the iteration 105 | self.shift = center 106 | self.scale = 0.95j / (Vscat_norm + Vwrap_norm) 107 | for domain in self.domains.flat: 108 | domain.initialize_scale(self.scale) 109 | empty_cache() # free up memory before going to run_algorithm 110 | 111 | # Functions implementing the domain interface 112 | # add_source() 113 | # clear() 114 | # get() 115 | # inner_product() 116 | # medium() 117 | # mix() 118 | # propagator() 119 | # set_source() 120 | def 
add_source(self, slot: int, weight: float): 121 | """ Add the source to the field in slot """ 122 | for domain in self.domains.flat: 123 | domain.add_source(slot, weight) 124 | 125 | def clear(self, slot: int): 126 | """ Clear the field in the specified slot """ 127 | for domain in self.domains.flat: 128 | domain.clear(slot) 129 | 130 | def get(self, slot: int, copy=False, device=None): 131 | """ Get the field in the specified slot, this gathers the fields from all subdomains and puts them in 132 | one big array 133 | 134 | :param slot: slot to get the data from 135 | :param copy: if True, returns a copy of the data. Otherwise, may return the original data possible. 136 | Note that this data may be overwritten by the next call to domain. 137 | :param device: device on which to store the data. Defaults to the primary device 138 | """ 139 | domain_data = list_to_array([domain.get(slot) for domain in self.domains.flat], 1).reshape(self.domains.shape) 140 | return combine(domain_data, device) 141 | 142 | def set(self, slot: int, data): 143 | """Copy the date into the specified slot""" 144 | parts = partition(data, self.n_domains) 145 | for domain, part in zip(self.domains.flat, parts.flat): 146 | domain.set(slot, part) 147 | 148 | def inner_product(self, slot_a: int, slot_b: int): 149 | """ Compute the inner product of the fields in slots a and b 150 | 151 | Note: 152 | Use sqrt(inner_product(slot_a, slot_a)) to compute the norm of the field in slot_a. 153 | There is a large but inconsistent difference in performance between vdot and linalg.norm. 154 | Execution time can vary a factor of 3 or more between the two, depending on the input size 155 | and whether the function is executed on the CPU or the GPU. 156 | """ 157 | inner_product = 0.0 158 | for domain in self.domains.flat: 159 | inner_product += domain.inner_product(slot_a, slot_b) 160 | return inner_product 161 | 162 | def medium(self, slot_in: int, slot_out: int, mnum=None): 163 | """ Apply the medium operator B, including wrapping corrections. 164 | 165 | Args: 166 | slot_in: slot holding the input field 167 | slot_out: slot that will receive the result 168 | mnum: # of the medium() call in preconditioned iteration. 169 | 0 for first, 1 for second medium call. 170 | """ 171 | 172 | # compute the corrections for each domain, before applying the medium operator 173 | domain_edges = [domain.compute_corrections(slot_in) for domain in self.domains.flat] 174 | domain_edges = list_to_array(domain_edges, 2).reshape(*self.domains.shape, 6) 175 | 176 | # Only applies the operator B=1-Vscat. The corrections are applied in the next step 177 | for domain in self.domains.flat: 178 | domain.medium(slot_in, slot_out) 179 | 180 | # apply wrapping corrections. We subtract each correction from 181 | # the opposite side of the domain to compensate for the wrapping. 
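# (per axis, the domain's own left and right edge corrections are swapped, which is what the re-indexing (1, 0, 3, 2, 5, 4) of wrap_corrections below implements);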
182 | # also, we add each correction to the opposite side of the neighbouring domain 183 | for idx, domain in enumerate(self.domains.flat): 184 | x = np.unravel_index(idx, self.domains.shape) 185 | # for the wrap corrections, take the corrections for this domain and swap them 186 | wrap_corrections = domain_edges[*x, (1, 0, 3, 2, 5, 4)] 187 | 188 | # for the transfer corrections, take the corrections from the neighbors 189 | def get_neighbor(edge): 190 | dim = edge // 2 191 | offset = -1 if edge % 2 == 0 else 1 192 | x_neighbor = np.array(x) 193 | x_neighbor[dim] += offset 194 | if self.periodic[dim]: 195 | x_neighbor = np.mod(x_neighbor, self.domains.shape) 196 | else: 197 | if x_neighbor[dim] < 0 or x_neighbor[dim] >= self.domains.shape[dim]: 198 | return None 199 | return domain_edges[*tuple(x_neighbor), edge - offset] 200 | 201 | transfer_corrections = [get_neighbor(edge) for edge in range(6)] 202 | 203 | # check if domain should be active in the iteration or not 204 | if mnum is None or domain._debug: # always active outside iteration (mnum==None) or in debug mod 205 | domain.active = True 206 | else: # check based on the norm of the transfer corrections 207 | tc_norm = [a for a in transfer_corrections if a is not None] 208 | if tc_norm: 209 | if domain.counter < 25: # counter for the number of iterations with increasing norm 210 | tc_norm = max([torch.vdot(a.view(-1), a.view(-1)).item().real for a in tc_norm]) 211 | 212 | if mnum == 0: # first medium call in preconditioned iteration 213 | domain.mnum0[1] = tc_norm 214 | elif mnum == 1: # second medium call in preconditioned iteration 215 | domain.mnum1[1] = tc_norm 216 | 217 | # if norm is high, domain is set to active 218 | if domain.mnum0[1] >= 1.e-7 or domain.mnum1[1] >= 1.e-7: 219 | domain.active = True 220 | domain.counter = 25 221 | # if norm is monotonically increasing, increase the counter 222 | elif domain.mnum0[-1] > domain.mnum0[-2] and domain.mnum1[-1] > domain.mnum1[-2]: 223 | domain.counter += 1 224 | else: 225 | domain.counter = 0 226 | # if the norm is not increasing and source is zero, domain is set to inactive 227 | if domain._source is not None and torch.sum(domain._source).item().real == 0.0: 228 | domain.active = False 229 | else: 230 | domain.active = True 231 | domain.counter = 25 232 | 233 | # current norm becomes previous norm 234 | domain.mnum0[0] = domain.mnum0[1] 235 | domain.mnum1[0] = domain.mnum1[1] 236 | else: 237 | domain.active = True 238 | 239 | if domain.active: 240 | domain.apply_corrections(wrap_corrections, transfer_corrections, slot_out) 241 | 242 | def mix(self, weight_a: float, slot_a: int, weight_b: float, slot_b: int, slot_out: int): 243 | """ Mix the fields in slots a and b and store the result in slot_out """ 244 | for domain in self.domains.flat: 245 | domain.mix(weight_a, slot_a, weight_b, slot_b, slot_out) 246 | 247 | def propagator(self, slot_in: int, slot_out: int): 248 | """ Apply propagator operators (L+1)^-1 to subdomains/patches of x.""" 249 | for domain in self.domains.flat: 250 | domain.propagator(slot_in, slot_out) 251 | 252 | def inverse_propagator(self, slot_in: int, slot_out: int): 253 | """ Apply inverse propagator operators L+1 to subdomains/patches of x.""" 254 | for domain in self.domains.flat: 255 | domain.inverse_propagator(slot_in, slot_out) 256 | 257 | def set_source(self, source): 258 | """ Split the source into subdomains and store in the subdomain states.""" 259 | if source is None or is_zero(source): 260 | for domain in self.domains.flat: 261 | 
domain.set_source(None) 262 | else: 263 | for domain, source in zip(self.domains.flat, partition(source, self.n_domains).flat): 264 | domain.set_source(source) 265 | 266 | def create_empty_vdot(self): 267 | """ Create an empty tensor for the Vdot tensor """ 268 | for domain in self.domains.flat: 269 | domain.create_empty_vdot() 270 | -------------------------------------------------------------------------------- /wavesim/utilities.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Sequence 4 | from itertools import chain 5 | from scipy.special import exp1 6 | 7 | 8 | # Preprocessing functions. These functions are used to preprocess the input parameters , i.e., 9 | # to add absorption and boundaries to the permittivity (refractive index²). 10 | def preprocess(permittivity, boundary_widths=10): 11 | """ Preprocess the input parameters for the simulation. 12 | Add absorption and boundaries to the permittivity (refractive index²), 13 | and return the preprocessed permittivity and boundaries in the format (ax0, ax1, ax2). 14 | 15 | :param permittivity: Refractive index² 16 | :param boundary_widths: Boundary widths (in pixels) 17 | :return: Preprocessed permittivity (refractive index²) with boundaries and absorption 18 | """ 19 | permittivity = check_input_dims(permittivity) # Ensure permittivity is a 3-d array 20 | if permittivity.dtype != np.complex64: 21 | permittivity = permittivity.astype(np.complex64) 22 | n_dims = get_dims(permittivity) # Number of dimensions in simulation 23 | n_roi = np.array(permittivity.shape) # Num of points in ROI (Region of Interest) 24 | 25 | # Ensure boundary_widths is a 3-element array of ints with 0s after n_dims 26 | boundary_widths = check_input_len(boundary_widths, 0, n_dims).astype(int) 27 | 28 | permittivity = add_absorption(permittivity, boundary_widths, n_roi, n_dims) 29 | 30 | return permittivity, boundary_widths 31 | 32 | 33 | def add_absorption(m, boundary_widths, n_roi, n_dims): 34 | """ Add (weighted) absorption to the permittivity (refractive index squared) 35 | 36 | :param m: array (permittivity) 37 | :param boundary_widths: Boundary widths 38 | :param n_roi: Number of points in the region of interest 39 | :param n_dims: Number of dimensions 40 | :return: m with absorption 41 | """ 42 | w = np.ones_like(m) # Weighting function (1 everywhere) 43 | w = pad_boundaries(w, boundary_widths, mode='linear_ramp') # pad w using linear_ramp 44 | a = 1 - w # for absorption, inverse weighting 1 - w 45 | for i in range(n_dims): 46 | left_boundary = boundary_(boundary_widths[i]) # boundary_ is a linear window function 47 | right_boundary = np.flip(left_boundary) # flip is a vertical flip 48 | full_filter = np.concatenate((left_boundary, np.ones(n_roi[i], dtype=np.float32), right_boundary)) 49 | a = np.moveaxis(a, i, -1) * full_filter # transpose to last dimension, apply filter 50 | a = np.moveaxis(a, -1, i) # transpose back to original position 51 | a = 1j * a # absorption is imaginary 52 | 53 | m = pad_boundaries(m, boundary_widths, mode='edge') # pad m using edge values 54 | m = w * m + a # add absorption to m 55 | return m 56 | 57 | 58 | def boundary_(x): 59 | """ Anti-reflection boundary layer (ARL). Linear window function 60 | 61 | :param x: Size of the ARL 62 | """ 63 | return ((np.arange(1, x + 1) - 0.21).T / (x + 0.66)).astype(np.float32) 64 | 65 | 66 | def check_input_dims(x): 67 | """ Expand arrays to 3 dimensions (e.g. 
permittivity (refractive index²) or source) 68 | 69 | :param x: Input array 70 | :return: x with 3 dimensions 71 | """ 72 | for _ in range(3 - x.ndim): 73 | x = np.expand_dims(x, axis=-1) # Expand dimensions to 3 74 | return x 75 | 76 | 77 | def check_input_len(x, e, n_dims): 78 | """ Check the length of input arrays and expand them to 3 elements if necessary. Either repeat or add 'e' 79 | 80 | :param x: Input array 81 | :param e: Element to add 82 | :param n_dims: Number of dimensions 83 | :return: Array with 3 elements 84 | """ 85 | if isinstance(x, int) or isinstance(x, float): # If x is a single number 86 | x = n_dims * tuple((x,)) + (3 - n_dims) * (e,) # Repeat the number n_dims times, and add (3-n_dims) e's 87 | elif len(x) == 1: # If x is a single element list or tuple 88 | x = n_dims * tuple(x) + (3 - n_dims) * (e,) # Repeat the element n_dims times, and add (3-n_dims) e's 89 | elif isinstance(x, list) or isinstance(x, tuple): # If x is a list or tuple 90 | x += (3 - len(x)) * (e,) # Add (3-len(x)) e's 91 | if isinstance(x, np.ndarray): # If x is a numpy array 92 | x = np.concatenate((x, np.zeros(3 - len(x)))) # Concatenate with (3-len(x)) zeros 93 | return np.array(x) 94 | 95 | 96 | def get_dims(x): 97 | """ Get the number of dimensions of 'x' 98 | 99 | :param x: Input array 100 | :return: Number of dimensions 101 | """ 102 | x = squeeze_(x) # Squeeze the last dimension if it is 1 103 | return x.ndim # Number of dimensions 104 | 105 | 106 | def pad_boundaries(x, boundary_widths, boundary_post=None, mode='constant'): 107 | """ Pad 'x' with boundaries in all dimensions using numpy pad (if x is np.ndarray) or PyTorch nn.functional.pad 108 | (if x is torch.Tensor). 109 | If boundary_post is specified separately, pad with boundary_widths (before) and boundary_post (after). 110 | 111 | :param x: Input array 112 | :param boundary_widths: Boundary widths for padding before and after (or just before if boundary_post not None) 113 | :param boundary_post: Boundary widths for padding after 114 | :param mode: Padding mode 115 | :return: Padded array 116 | """ 117 | x = check_input_dims(x) # Ensure x is a 3-d array 118 | 119 | if boundary_post is None: 120 | boundary_post = boundary_widths 121 | 122 | if isinstance(x, np.ndarray): 123 | pad_width = tuple(zip(boundary_widths, boundary_post)) # pairs ((a0, b0), (a1, b1), (a2, b2)) 124 | return np.pad(x, pad_width, mode) 125 | elif torch.is_tensor(x): 126 | t = zip(boundary_widths[::-1], boundary_post[::-1]) # reversed pairs (a2, b2) (a1, b1) (a0, b0) 127 | pad_width = tuple(chain.from_iterable(t)) # flatten to (a2, b2, a1, b1, a0, b0) 128 | return torch.nn.functional.pad(x, pad_width, mode) 129 | else: 130 | raise ValueError("Input must be a numpy array or a torch tensor") 131 | 132 | 133 | def squeeze_(x): 134 | """ Squeeze the last dimension of 'x' if it is 1 135 | 136 | :param x: Input array 137 | :return: Squeezed array 138 | """ 139 | while x.shape[-1] == 1: 140 | x = np.squeeze(x, axis=-1) 141 | return x 142 | 143 | 144 | # Domain decomposition functions. 
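# A minimal usage sketch (shapes chosen arbitrarily here, not taken from the test suite):
#
#   x = torch.randn((4, 4, 4), dtype=torch.complex64)
#   parts = partition(x, n_domains=(2, 1, 1))  # 2x1x1 object array of tensors of shape (2, 4, 4)
#   assert torch.allclose(combine(parts), x)   # combine() concatenates the sub-arrays back together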
145 | def combine(domains: np.ndarray, device='cpu') -> torch.Tensor: 146 | """ Concatenates a 3-d array of 3-d tensors""" 147 | 148 | # Calculate total size for each dimension 149 | total_size = [ 150 | sum(tensor.shape[0] for tensor in domains[:, 0, 0]), 151 | sum(tensor.shape[1] for tensor in domains[0, :, 0]), 152 | sum(tensor.shape[2] for tensor in domains[0, 0, :]), 153 | ] 154 | 155 | # allocate memory 156 | template = domains[0, 0, 0] 157 | result_tensor = torch.empty(size=total_size, dtype=template.dtype, device=device) 158 | 159 | # Fill the pre-allocated tensor 160 | index0 = 0 161 | for i, tensor_slice0 in enumerate(domains[:, 0, 0]): 162 | index1 = 0 163 | for j, tensor_slice1 in enumerate(domains[0, :, 0]): 164 | index2 = 0 165 | for k, tensor in enumerate(domains[0, 0, :]): 166 | tensor = domains[i, j, k] 167 | if tensor.is_sparse: 168 | tensor = tensor.to_dense() 169 | end0 = index0 + tensor.shape[0] 170 | end1 = index1 + tensor.shape[1] 171 | end2 = index2 + tensor.shape[2] 172 | result_tensor[index0:end0, index1:end1, index2:end2] = tensor 173 | index2 += tensor.shape[2] 174 | index1 += domains[i, j, 0].shape[1] 175 | index0 += tensor_slice0.shape[0] 176 | 177 | return result_tensor 178 | 179 | 180 | def list_to_array(input: list, depth: int) -> np.ndarray: 181 | """ Convert a nested list of depth `depth` to a numpy object array """ 182 | # first determine the size of the final array 183 | size = np.zeros(depth, dtype=int) 184 | outer = input 185 | for i in range(depth): 186 | size[i] = len(outer) 187 | outer = outer[0] 188 | 189 | # allocate memory 190 | array = np.empty(size, dtype=object) 191 | 192 | # flatten the input array 193 | for i in range(depth - 1): 194 | input = sum(input, input[0][0:0]) # works both for tuples and lists 195 | 196 | # copy to the output array 197 | ra = array.reshape(-1) 198 | assert ra.base is not None # must be a view 199 | for i in range(ra.size): 200 | if input[i] is None or input[i].is_sparse or input[i].is_contiguous(): 201 | ra[i] = input[i] 202 | else: 203 | ra[i] = input[i].contiguous() 204 | return array 205 | 206 | 207 | def partition(array: torch.Tensor, n_domains: tuple[int, int, int]) -> np.ndarray: 208 | """ Split a 3-D array into a 3-D set of sub-arrays of approximately equal sizes.""" 209 | n_domains = np.array(n_domains) # Add 1 to the end to make it a 3-element array 210 | size = np.array(array.shape) 211 | if any(size < n_domains) or any(n_domains <= 0) or len(n_domains) != 3: 212 | raise ValueError(f"Number of domains {n_domains} must be larger than 1 and " 213 | f"less than or equal to the size of the array {array.shape}") 214 | 215 | # Calculate the size of each domain 216 | large_domain_size = np.ceil(size / n_domains).astype(int) 217 | small_domain_count = large_domain_size * n_domains - size 218 | large_domain_count = n_domains - small_domain_count 219 | subdomain_sizes = [(large_domain_size[dim],) * large_domain_count[dim] + (large_domain_size[dim] - 1,) 220 | * small_domain_count[dim] for dim in range(3)] 221 | 222 | split = _sparse_split if array.is_sparse else torch.split 223 | 224 | array = split(array, subdomain_sizes[0], dim=0) 225 | array = [split(part, subdomain_sizes[1], dim=1) for part in array] 226 | array = [[split(part, subdomain_sizes[2], dim=2) for part in subpart] for subpart in array] 227 | return list_to_array(array, depth=3) 228 | 229 | 230 | def _sparse_split(tensor: torch.Tensor, sizes: Sequence[int], dim: int) -> np.ndarray: 231 | """ Split a COO-sparse tensor into a 3-D set of sub-arrays of 
approximately equal sizes.""" 232 | if len(sizes) == 1: 233 | return [tensor] # no need to split 234 | 235 | tensor = tensor.coalesce() 236 | indices = tensor.indices().cpu().numpy() 237 | values = tensor.values() 238 | 239 | if dim >= tensor.sparse_dim(): 240 | values = torch.tensor(values.detach().clone().cpu().numpy()) # for troubleshooting access violation 241 | value_list = list(torch.split(values, sizes, dim - tensor.sparse_dim() + 1)) # split dense tensor component 242 | sz = list(tensor.shape) 243 | for i in range(len(value_list)): 244 | sz[dim] = sizes[i] 245 | v = np.array( 246 | value_list[i].cpu().numpy()) # should not be necessary, workaround for access violation bug in torch 247 | value_list[i] = torch.sparse_coo_tensor(indices, v, tuple(sz)) 248 | # print(indices, indices.dtype, value_list[i], value_list[i].shape, sz) 249 | value_list[i].to_dense() # for troubleshooting access violation 250 | return value_list 251 | 252 | coordinate_to_domain = np.array(sum([(idx,) * size for idx, size in enumerate(sizes)], ())) 253 | domain_starts = np.cumsum((0,) + sizes) 254 | domains = coordinate_to_domain[indices[dim, :]] 255 | 256 | def extract_subarray(domain: int) -> torch.Tensor: 257 | mask = domains == domain 258 | domain_indices = indices[:, mask] 259 | if len(domain_indices) == 0: 260 | return None 261 | domain_values = values[mask] 262 | domain_indices[dim, :] -= domain_starts[domain] 263 | size = list(tensor.shape) 264 | size[dim] = sizes[domain] 265 | return torch.sparse_coo_tensor(domain_indices, domain_values, tuple(size)) 266 | 267 | return [extract_subarray(d) for d in range(len(sizes))] 268 | 269 | 270 | # Used in tests 271 | def full_matrix(operator): 272 | """ Converts operator to a 2D square matrix of size np.prod(d) x np.prod(d) 273 | 274 | :param operator: Operator to convert to a matrix. This function must be able to accept a 0 scalar, and 275 | return a vector of the size and data type of the domain. 
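The matrix is assembled column by column by applying the operator to each unit basis vector, so this is only practical for the small domains used in the tests.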
276 | """ 277 | y = operator(0.0) 278 | n_size = y.shape 279 | nf = np.prod(n_size) 280 | M = torch.zeros((nf, nf), dtype=y.dtype, device=y.device) 281 | b = torch.zeros(n_size, dtype=y.dtype, device=y.device) 282 | for i in range(nf): 283 | b.view(-1)[i] = 1 284 | M[:, i] = torch.ravel(operator(b)) 285 | b.view(-1)[i] = 0 286 | 287 | return M 288 | 289 | 290 | # Metrics 291 | def max_abs_error(e, e_true): 292 | """ (Normalized) Maximum Absolute Error (MAE) ||e-e_true||_{inf} / ||e_true|| 293 | 294 | :param e: Computed field 295 | :param e_true: True field 296 | :return: (Normalized) MAE 297 | """ 298 | return np.max(np.abs(e - e_true)) / np.linalg.norm(e_true) 299 | 300 | 301 | def max_relative_error(e, e_true): 302 | """Computes the maximum error, normalized by the rms of the true field 303 | 304 | :param e: Computed field 305 | :param e_true: True field 306 | :return: (Normalized) Maximum Relative Error 307 | """ 308 | return np.max(np.abs(e - e_true)) / np.sqrt(np.mean(np.abs(e_true) ** 2)) 309 | 310 | 311 | def relative_error(e, e_true): 312 | """ Relative error ``⟨|e-e_true|^2⟩ / ⟨|e_true|^2⟩`` 313 | 314 | :param e: Computed field 315 | :param e_true: True field 316 | :return: Relative Error 317 | """ 318 | return np.nanmean(np.abs(e - e_true) ** 2) / np.nanmean(np.abs(e_true) ** 2) 319 | 320 | 321 | # Miscellaneous functions 322 | ## 1D analytical solution for Helmholtz equation 323 | def analytical_solution(n_size0, pixel_size, wavelength=None): 324 | """ Compute analytic solution for 1D case """ 325 | x = np.arange(0, n_size0 * pixel_size, pixel_size, dtype=np.float32) 326 | x = np.pad(x, (n_size0, n_size0), mode='constant', constant_values=np.nan) 327 | h = pixel_size 328 | # wavenumber (k) 329 | if wavelength is None: 330 | k = 1. * 2. * np.pi * pixel_size 331 | else: 332 | k = 1. * 2. * np.pi / wavelength 333 | phi = k * x 334 | u_theory = (1.0j * h / (2 * k) * np.exp(1.0j * phi) # propagating plane wave 335 | - h / (4 * np.pi * k) * ( 336 | np.exp(1.0j * phi) * (exp1(1.0j * (k - np.pi / h) * x) - exp1(1.0j * (k + np.pi / h) * x)) - 337 | np.exp(-1.0j * phi) * (-exp1(-1.0j * (k - np.pi / h) * x) + exp1(-1.0j * (k + np.pi / h) * x))) 338 | ) 339 | small = np.abs(k * x) < 1.e-10 # special case for values close to 0 340 | u_theory[small] = 1.0j * h / (2 * k) * (1 + 2j * np.arctanh(h * k / np.pi) / np.pi) # exact value at 0. 341 | return u_theory[n_size0:-n_size0] 342 | 343 | 344 | def is_zero(x): 345 | """ Check if x is zero 346 | 347 | Some functions allow specifying 0 or 0.0 instead of a torch tensor, to indicate that the array should be cleared. 348 | This function returns True if x is a scalar 0 or 0.0. It raises an error if x is a scalar that is not equal to 0 or 349 | 0.0, and returns False otherwise. 
350 | """ 351 | if isinstance(x, float) or isinstance(x, int): 352 | if x != 0: 353 | raise ValueError("Cannot set a field to a scalar to a field, only scalar 0.0 is supported") 354 | return True 355 | else: 356 | return False 357 | 358 | 359 | def normalize(x, min_val=None, max_val=None, a=0, b=1): 360 | """ Normalize x to the range [a, b] 361 | 362 | :param x: Input array 363 | :param min_val: Minimum value (of x) 364 | :param max_val: Maximum value (of x) 365 | :param a: Lower bound for normalization 366 | :param b: Upper bound for normalization 367 | :return: Normalized x 368 | """ 369 | if min_val is None: 370 | min_val = x.min() 371 | if max_val is None: 372 | max_val = x.max() 373 | normalized_x = (x - min_val) / (max_val - min_val) * (b - a) + a 374 | return normalized_x 375 | --------------------------------------------------------------------------------