├── .flake8 ├── .github └── workflows │ ├── pythonpublish.yml │ ├── test_code.yml │ └── test_docs.yml ├── .gitignore ├── LICENSE ├── README.md ├── doc └── source │ ├── Makefile │ ├── _static │ └── css │ │ └── custom.css │ ├── conf.py │ ├── index.rst │ ├── introduction.md │ ├── make.bat │ ├── modules.rst │ ├── requirements.txt │ └── simphox.rst ├── requirements.txt ├── setup.py ├── simphox ├── __init__.py ├── bpm.py ├── circuit │ ├── __init__.py │ ├── component.py │ ├── coupling.py │ ├── envelope.py │ ├── forward.py │ ├── matrix.py │ ├── rectangular.py │ └── vector.py ├── fdfd.py ├── fdtd.py ├── grid.py ├── mkl.py ├── mode.py ├── opt.py ├── parse.py ├── primitives.py ├── sim.py ├── transform.py ├── typing.py ├── utils.py └── viz.py └── tests ├── circuit_test.py ├── fdfd_test.py ├── grid_test.py ├── mode_test.py ├── primitives_test.py ├── sim_test.py └── transform_test.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | per-file-ignores = __init__.py:F401 3 | ignore = E226,E302,E41 4 | max-line-length = 160 -------------------------------------------------------------------------------- /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created, published] 6 | push: 7 | branches: [master] 8 | tags: [v*] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: "3.9" 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install setuptools wheel twine 23 | - name: Build and publish 24 | env: 25 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 26 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 27 | run: | 28 | python setup.py sdist bdist_wheel 29 | twine upload dist/* --verbose 30 | -------------------------------------------------------------------------------- /.github/workflows/test_code.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Lint and Test 3 | # https://docs.github.com/en/free-pro-team@latest/actions/guides/building-and-testing-python 4 | 5 | on: 6 | pull_request: 7 | push: 8 | schedule: 9 | - cron: 0 2 * * * # run at 2 AM UTC 10 | 11 | jobs: 12 | build: 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | max-parallel: 12 16 | matrix: 17 | python-version: [3.9] 18 | os: [ubuntu-latest] 19 | 20 | steps: 21 | - name: Cancel Workflow Action 22 | uses: styfle/cancel-workflow-action@0.9.1 23 | - uses: actions/checkout@v2 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install flake8 pytest 32 | pip install . 33 | - name: Lint with flake8 34 | run: | 35 | flake8 . 
36 | - name: Test with pytest 37 | run: pytest 38 | -------------------------------------------------------------------------------- /.github/workflows/test_docs.yml: -------------------------------------------------------------------------------- 1 | name: Test documentation 2 | 3 | on: 4 | pull_request: 5 | push: 6 | schedule: 7 | - cron: "0 2 * * *" # run at 2 AM UTC 8 | - 9 | 10 | jobs: 11 | build-linux: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Cancel Workflow Action 16 | uses: styfle/cancel-workflow-action@0.9.1 17 | - uses: actions/checkout@v2 18 | - name: Set up Python 3.9 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: 3.9 22 | - name: Install dependencies 23 | run: | 24 | pip install -r requirements.txt 25 | pip install . 26 | pip install -r doc/source/requirements.txt 27 | sudo apt install pandoc 28 | sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super 29 | - name: Test documentation 30 | run: | 31 | cd doc/source 32 | make html 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py? 2 | *~ 3 | *.swp 4 | .cache 5 | __pycache__ 6 | *.egg-info 7 | .env 8 | .env* 9 | .idea 10 | .vscode 11 | .DS_Store 12 | data/* 13 | aim_lib/* 14 | doc/source/_build 15 | *.gds 16 | *.jpg 17 | *.png 18 | *.tif 19 | build/* 20 | dist/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Fan Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![simphox](https://user-images.githubusercontent.com/7623867/131265616-4e438679-f3b6-4a9f-b401-130a41cb8ab7.png) 2 | # 3 | ![Build Status](https://img.shields.io/travis/fancompute/simphox/master.svg?style=for-the-badge) 4 | ![Docs](https://readthedocs.org/projects/simphox/badge/?style=for-the-badge) 5 | ![PiPy](https://img.shields.io/pypi/v/simphox.svg?style=for-the-badge) 6 | ![CodeCov](https://img.shields.io/codecov/c/github/fancompute/simphox/master.svg?style=for-the-badge) 7 | 8 | 9 | 10 | 11 | Another inverse design library (wip) 12 | -------------------------------------------------------------------------------- /doc/source/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../../')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'simphox' 21 | copyright = '2021, Sunil Pai' 22 | author = 'Sunil Pai' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '0.0.1alpha' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | master_doc = 'index' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | 'sphinx.ext.napoleon', 36 | 'sphinx.ext.mathjax', 37 | 'sphinx.ext.autodoc', 38 | 'sphinx_autodoc_typehints', 39 | 'sphinx.ext.viewcode', 40 | 'sphinx.ext.inheritance_diagram', 41 | 'myst_parser' 42 | ] 43 | 44 | myst_enable_extensions = [ 45 | "colon_fence", 46 | ] 47 | 48 | mathjax_path = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML" 49 | 50 | # Add any paths that contain templates here, relative to this directory. 
51 | templates_path = ['_templates'] 52 | 53 | # List of patterns, relative to source directory, that match files and 54 | # directories to ignore when looking for source files. 55 | # This pattern also affects html_static_path and html_extra_path. 56 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 57 | 58 | 59 | # -- Options for HTML output ------------------------------------------------- 60 | 61 | # The theme to use for HTML and HTML Help pages. See the documentation for 62 | # a list of builtin themes. 63 | # 64 | html_theme_path = ["_themes"] 65 | html_theme = 'sphinx_rtd_theme' 66 | html_theme_options = { 67 | "logo_only": True 68 | } 69 | 70 | # Add any paths that contain custom static files (such as style sheets) here, 71 | # relative to this directory. They are copied after the builtin static files, 72 | # so a file named "default.css" will overwrite the builtin "default.css". 73 | html_static_path = ['_static'] 74 | html_logo = "https://user-images.githubusercontent.com/7623867/131265616-4e438679-f3b6-4a9f-b401-130a41cb8ab7.png" 75 | 76 | 77 | def setup(app): 78 | app.add_css_file('css/custom.css') 79 | -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to simphox's documentation! 2 | =================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | 9 | 10 | Indices and tables 11 | ================== 12 | 13 | * :ref:`genindex` 14 | * :ref:`modindex` 15 | * :ref:`search` 16 | -------------------------------------------------------------------------------- /doc/source/introduction.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fancompute/simphox/8aa2560beb7f9bf3d57db02be768695fb2a7acca/doc/source/introduction.md -------------------------------------------------------------------------------- /doc/source/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/source/modules.rst: -------------------------------------------------------------------------------- 1 | simphox 2 | ======= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | simphox 8 | -------------------------------------------------------------------------------- /doc/source/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | xarray 5 | pydantic 6 | pandas 7 | bokeh 8 | holoviews 9 | jaxlib 10 | jax 11 | scikit-image 12 | bokeh 13 | sphinx-autodoc-typehints 14 | sphinx-autodoc-annotation 15 | autodoc_pydantic 16 | myst_parser 17 | absl-py -------------------------------------------------------------------------------- /doc/source/simphox.rst: -------------------------------------------------------------------------------- 1 | simphox package 2 | =============== 3 | 4 | Submodules 5 | ---------- 6 | 7 | 8 | simphox.circuit module 9 | ---------------------- 10 | 11 | .. automodule:: simphox.circuit 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: 15 | 16 | simphox.fdfd module 17 | ------------------- 18 | 19 | .. automodule:: simphox.fdfd 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | 25 | simphox.grid module 26 | ------------------- 27 | 28 | .. automodule:: simphox.grid 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | simphox.mkl module 34 | ------------------ 35 | 36 | .. automodule:: simphox.mkl 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | simphox.mode module 42 | ------------------- 43 | 44 | .. automodule:: simphox.mode 45 | :members: 46 | :undoc-members: 47 | :show-inheritance: 48 | 49 | simphox.opt module 50 | ------------------ 51 | 52 | .. automodule:: simphox.opt 53 | :members: 54 | :undoc-members: 55 | :show-inheritance: 56 | 57 | simphox.primitives module 58 | ------------------------- 59 | 60 | .. automodule:: simphox.primitives 61 | :members: 62 | :undoc-members: 63 | :show-inheritance: 64 | 65 | simphox.sim module 66 | ------------------ 67 | 68 | .. automodule:: simphox.sim 69 | :members: 70 | :undoc-members: 71 | :show-inheritance: 72 | 73 | simphox.transform module 74 | ------------------------ 75 | 76 | .. automodule:: simphox.transform 77 | :members: 78 | :undoc-members: 79 | :show-inheritance: 80 | 81 | simphox.utils module 82 | -------------------- 83 | 84 | .. automodule:: simphox.utils 85 | :members: 86 | :undoc-members: 87 | :show-inheritance: 88 | 89 | simphox.viz module 90 | ------------------ 91 | 92 | .. automodule:: simphox.viz 93 | :members: 94 | :undoc-members: 95 | :show-inheritance: 96 | 97 | Module contents 98 | --------------- 99 | 100 | .. 
automodule:: simphox 101 | :members: 102 | :undoc-members: 103 | :show-inheritance: 104 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pandas 4 | pydantic 5 | matplotlib 6 | jax 7 | jaxlib 8 | scikit-image 9 | absl-py 10 | dm-haiku 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | project_name = "simphox" 5 | 6 | setup( 7 | name=project_name, 8 | version="0.0.1a8", 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'numpy', 12 | 'scipy', 13 | 'pandas', 14 | 'jaxlib', 15 | 'jax', 16 | 'scikit-image', 17 | 'pydantic', 18 | 'xarray', 19 | 'dm-haiku', 20 | 'absl-py' 21 | ], 22 | extras_require={ 23 | 'interactive': ['matplotlib', 24 | 'jupyterlab', 25 | 'holoviews', 26 | 'bokeh'] 27 | } 28 | ) 29 | -------------------------------------------------------------------------------- /simphox/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import Grid, YeeGrid 2 | from .fdfd import FDFD 3 | from .fdtd import FDTD 4 | from .mode import ModeLibrary, ModeSolver 5 | from .utils import Box, Material, SILICON, POLYSILICON, NITRIDE, OXIDE, TEST_INF, AIR 6 | 7 | from jax.config import config 8 | config.update("jax_enable_x64", True) 9 | -------------------------------------------------------------------------------- /simphox/bpm.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | from scipy.linalg import solve_banded 5 | 6 | from .sim import SimGrid 7 | from .typing import Shape, Size, Spacing, Optional, Union, Size3 8 | 9 | 10 | class BPM(SimGrid): 11 | def __init__(self, size: Size, spacing: Spacing, eps: Union[float, np.ndarray] = 1, 12 | wavelength: float = 1.55, bloch_phase: Union[Size, float] = 0.0, 13 | pml: Optional[Union[Shape, Size]] = None, pml_params: Size3 = (4, -16, 1, 5), 14 | yee_avg: bool = True, no_grad: bool = True, not_implemented: bool = True): 15 | 16 | if not_implemented: # this is just to avoid annoying pycharm linting (TODO: remove this when fixed) 17 | raise NotImplementedError("This class is still WIP") 18 | 19 | self.wavelength = wavelength 20 | self.k0 = 2 * np.pi / self.wavelength # defines the units for the simulation! 21 | self.no_grad = no_grad 22 | 23 | super(BPM, self).__init__( 24 | size=size, 25 | spacing=spacing, 26 | eps=eps, 27 | bloch_phase=bloch_phase, 28 | pml=pml, 29 | pml_params=pml_params 30 | ) 31 | 32 | if self.ndim == 1: 33 | raise ValueError(f"Simulation dimension ndim must be 2 or 3 but got {self.ndim}.") 34 | self.init() 35 | 36 | def init(self, center: Tuple[float, ...] = None, shape: Tuple[float, ...] 
= None, axis: int = 0): 37 | # initial scalar fields for fdtd 38 | center = (0, self.shape[1] // 2, self.shape[2] // 2) if center is None else center 39 | shape = self.eps[0].shape if shape is None else shape 40 | self.x = center[0] 41 | # self.beta, _, self.e, self.h = mode_profile(self, center=center, size=shape, axis=axis) 42 | 43 | def adi_polarized(self, te: bool = True): 44 | """The ADI step for beam propagation method based on https://publik.tuwien.ac.at/files/PubDat_195610.pdf 45 | 46 | Returns: 47 | 48 | """ 49 | d, _ = self._dxes 50 | if self.ndim == 3: 51 | s, e = d[1], d[0] 52 | n, w = np.roll(s, 1, axis=1), np.roll(e, 1, axis=0) 53 | n[0], w[0], s[-1], e[-1] = 0, 0, 0, 0 # set to zero to make life easy later 54 | 55 | a_x = np.tile(2 / (w * (e + w)).flatten(), 2) 56 | c_x = np.tile(2 / (e * (e + w)).flatten(), 2) 57 | a_y = np.tile(2 / (n * (n + s)).flatten(), 2) 58 | c_y = np.tile(2 / (s * (s + n)).flatten(), 2) 59 | 60 | eps = self.eps[self.x, :, :] 61 | e = self.e[1, self.x, :, :] if te else self.e[0, self.x, :, :] 62 | h = self.h[0, self.x, :, :] if te else self.h[1, self.x, :, :] 63 | phi = np.stack(e.flatten(), h.flatten()) 64 | 65 | if te: 66 | eps_e = np.roll(eps, 1, axis=0) 67 | eps_w = np.roll(eps, -1, axis=0) 68 | a_x *= np.hstack(((2 * eps_w / (eps + eps_w)).flatten(), (2 * eps / (eps + eps_w)).flatten())) 69 | c_x *= np.hstack(((2 * eps_e / (eps + eps_e)).flatten(), (2 * eps / (eps + eps_e)).flatten())) 70 | else: 71 | eps_n = np.roll(eps, -1, axis=1) 72 | eps_s = np.roll(eps, 1, axis=1) 73 | a_y *= np.hstack(((2 * eps_n / (eps + eps_n)).flatten(), (2 * eps / (eps + eps_n)).flatten())) 74 | c_y *= np.hstack(((2 * eps_s / (eps + eps_s)).flatten(), (2 * eps / (eps + eps_s)).flatten())) 75 | 76 | b_x = -(c_x + a_x) 77 | b_y = -(a_y + c_y) 78 | 79 | if te: 80 | adjustment = -4 / (e * w).flatten() 81 | b_x = np.hstack(adjustment, np.zeros_like(adjustment)) - b_x 82 | else: 83 | adjustment = -4 / (n * s).flatten() 84 | b_y = np.hstack(adjustment, np.zeros_like(adjustment)) - b_y 85 | 86 | # ADI algorithm 87 | 88 | b_x += (self.k0 ** 2 * eps.flatten() - self.beta ** 2) / 2 89 | b_y += (self.k0 ** 2 * eps.flatten() - self.beta ** 2) / 2 90 | t_x = np.vstack([-a_x, -b_x - 4 * 1j * self.beta / self.spacing[-1], -c_x]) 91 | t_y = np.vstack([-a_y, -b_y - 4 * 1j * self.beta / self.spacing[-1], -c_y]) 92 | d_x = np.roll(phi, -1) * a_y + phi * b_y + np.roll(phi, 1) * c_y 93 | phi_x = solve_banded((1, 1), t_x, d_x) 94 | d_y = np.roll(phi, -1) * a_x + phi_x * b_x + np.roll(phi_x, 1) * c_x 95 | new_phi = solve_banded((1, 1), t_y, d_y) 96 | if te: 97 | self.e[1, self.x, :, :].flat, self.h[0, self.x, :, :].flat = np.hsplit(new_phi, 2) 98 | else: 99 | self.e[0, self.x, :, :].flat, self.h[1, self.x, :, :].flat = np.hsplit(new_phi, 2) 100 | self.x += 1 101 | -------------------------------------------------------------------------------- /simphox/circuit/__init__.py: -------------------------------------------------------------------------------- 1 | from .matrix import cascade, tree_cascade, triangular, psvd 2 | from .vector import balanced_tree, unbalanced_tree, vector_unit 3 | from .forward import ForwardMesh 4 | from .coupling import CouplingNode, PhaseStyle 5 | from .rectangular import rectangular 6 | from .component import Component 7 | -------------------------------------------------------------------------------- /simphox/circuit/component.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import xarray as xr 3 | 
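# dphox is an optional dependency here: Pattern and Device (the geometry types a Component
# can carry) are imported only if available, and DPHOX_IMPORTED records whether that succeeded.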
4 | try: 5 | DPHOX_IMPORTED = True 6 | from dphox.device import Device 7 | from dphox.pattern import Pattern 8 | except ImportError: 9 | DPHOX_IMPORTED = False 10 | 11 | from ..fdfd import FDFD 12 | from ..typing import Callable, Iterable, List, Optional, Size, Union 13 | 14 | 15 | class Component: 16 | def __init__(self, structure: Union["Pattern", "Device"], 17 | model: Union[xr.DataArray, Callable[[jnp.ndarray], xr.DataArray]], name: str): 18 | """A component in a circuit will have some structure that can be simulated 19 | (a pattern or device defined in DPhox), a model, and a name string. 20 | 21 | Args: 22 | structure: Structure of the device. 23 | model: Model of the device (in terms of wavelength). 24 | name: Name of the component (string representing the model name). 25 | """ 26 | self.structure = structure 27 | self.model = model 28 | self.name = name 29 | 30 | @classmethod 31 | def from_fdfd(cls, pattern: "Pattern", core_eps: float, clad_eps: float, spacing: float, 32 | wavelengths: Iterable[float], 33 | boundary: Size, pml: float, name: str, in_ports: Optional[List[str]] = None, 34 | out_ports: Optional[List[str]] = None, component_t: float = 0, component_zmin: Optional[float] = None, 35 | rib_t: float = 0, sub_z: float = 0, height: float = 0, bg_eps: float = 1, 36 | profile_size_factor: int = 3, 37 | pbar: Optional[Callable] = None): 38 | """From FDFD, this classmethod produces a component model based on a provided pattern 39 | and simulation attributes (currently configured for scalar photonics problems). 40 | 41 | Args: 42 | pattern: component provided by DPhox 43 | core_eps: core epsilon 44 | clad_eps: clad epsilon 45 | spacing: spacing required 46 | wavelengths: wavelengths 47 | boundary: boundary size around component 48 | pml: PML size (see :code:`FDFD` class for details) 49 | name: component name 50 | in_ports: input ports 51 | out_ports: output ports 52 | height: height for 3d simulation 53 | sub_z: substrate minimum height 54 | component_zmin: component height (defaults to substrate_z) 55 | component_t: component thickness 56 | rib_t: rib thickness for component (partial etch) 57 | bg_eps: background epsilon (usually 1 or air) 58 | profile_size_factor: profile size factor (multiply port size dimensions to get mode dimensions at each port) 59 | pbar: Progress bar (e.g. TQDM in a notebook which can be a valuable progress indicator). 60 | 61 | Returns: 62 | Initialize a component which contains a structure (for port specificication and visualization purposes) 63 | and model describing the component behavior. 
64 | 65 | """ 66 | sparams = [] 67 | 68 | iterator = wavelengths if pbar is None else pbar(wavelengths) 69 | for wl in iterator: 70 | fdfd = FDFD.from_pattern( 71 | component=pattern, 72 | core_eps=core_eps, 73 | clad_eps=clad_eps, 74 | spacing=spacing, 75 | height=height, 76 | boundary=boundary, 77 | pml=pml, 78 | component_t=component_t, 79 | component_zmin=component_zmin, 80 | wavelength=wl, 81 | rib_t=rib_t, 82 | sub_z=sub_z, 83 | bg_eps=bg_eps, 84 | name=f'{name}_{wl}um' 85 | ) 86 | sparams_wl = [] 87 | for port in fdfd.port: 88 | s, _ = fdfd.get_sim_sparams_fn(port, profile_size_factor=profile_size_factor)(fdfd.eps) 89 | sparams_wl.append(s) 90 | sparams.append(sparams_wl) 91 | 92 | model = xr.DataArray( 93 | data=sparams, 94 | dims=["wavelengths", "in_ports", "out_ports"], 95 | coords={ 96 | "wavelengths": wavelengths, 97 | "in_ports": in_ports, 98 | "out_ports": out_ports 99 | } 100 | ) 101 | 102 | return cls(pattern, model=model, name=name) 103 | -------------------------------------------------------------------------------- /simphox/circuit/coupling.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import numpy as np 4 | from enum import Enum 5 | from pydantic.dataclasses import dataclass 6 | from ..utils import fix_dataclass_init_docs 7 | 8 | 9 | def phase_matrix(top: float, bottom: float = 0): 10 | return np.array([ 11 | [np.exp(1j * top), 0], 12 | [0, np.exp(1j * bottom)] 13 | ]) 14 | 15 | 16 | def coupling_matrix_s(s: float): 17 | return np.array([ 18 | [np.sqrt(1 - s), np.sqrt(s)], 19 | [np.sqrt(s), -np.sqrt(1 - s)] 20 | ]) 21 | 22 | 23 | def loss2insertion(loss_db: float): 24 | return np.sqrt(np.exp(np.log(10) * loss_db / 10)) 25 | 26 | 27 | def coupling_matrix_phase(theta: float, split_error: float = 0, loss_error: float = 0): 28 | insertion = loss2insertion(loss_error) 29 | return np.array([ 30 | [insertion * np.cos(theta / 2 + split_error / 2), insertion * 1j * np.sin(theta / 2 + split_error / 2)], 31 | [1j * np.sin(theta / 2 + split_error / 2), np.sin(theta / 2 + split_error / 2)] 32 | ]) 33 | 34 | 35 | def _embed_2x2(mat: np.ndarray, n: int, i: int, j: int): 36 | if mat.shape != (2, 2): 37 | raise AttributeError(f"Expected shape (2, 2), but got {mat.shape}.") 38 | out = np.eye(n, dtype=np.complex128) 39 | out[i, i] = mat[0, 0] 40 | out[i, j] = mat[0, 1] 41 | out[j, i] = mat[1, 0] 42 | out[j, j] = mat[1, 1] 43 | return out 44 | 45 | 46 | class PhaseStyle(str, Enum): 47 | """Enumeration for the different phase styles (differential, common, top, bottom). 48 | 49 | A phase style is defined as 50 | 51 | Attributes: 52 | TOP: Top phase shift 53 | BOTTOM: Bottom phase shift 54 | DIFFERENTIAL: Differential phase shift 55 | SYMMETRIC: Symmetric phase shift 56 | 57 | """ 58 | TOP = 'top' 59 | BOTTOM = 'bottom' 60 | DIFFERENTIAL = 'differential' 61 | SYMMETRIC = 'symmetric' 62 | 63 | 64 | @fix_dataclass_init_docs 65 | @dataclass 66 | class CouplingNode: 67 | """A simple programmable 2x2 coupling node model. 68 | 69 | Attributes: 70 | node_id: The index of the coupling node (useful in networks). 71 | loss: The differential loss error of the overall coupling node in dB (one at each of the theta and phi phase shifters). 72 | bs_error: The splitting error of the coupling node (MZI coupling errors). 73 | n: Total number of inputs/outputs. 74 | top: Top input/output index. 75 | bottom: Bottom input/output index. 
76 | num_top: Total number of inputs connected to top port (for tree architectures initialization). 77 | num_bottom: Total number of inputs connected to bottom port (for tree architecture initialization). 78 | column: The column label assigned to the node 79 | 80 | """ 81 | node_id: int = 0 82 | loss: Tuple[float, float] = (0., 0.) 83 | bs_error: Union[float, Tuple[float, float]] = 0. 84 | n: int = 2 85 | top: int = 0 86 | bottom: int = 1 87 | alpha: int = 1 88 | beta: int = 1 89 | column: int = 0 90 | 91 | def __post_init_post_parse__(self): 92 | self.stride = self.bottom - self.top 93 | self.top_descendants = np.array([]) 94 | self.bot_descendants = np.array([]) 95 | self.bs_error = (self.bs_error, self.bs_error) if np.isscalar(self.bs_error) else self.bs_error 96 | 97 | @property 98 | def mzi_terms(self): 99 | return [ 100 | np.cos(np.pi / 4 + self.bs_error[1]) * np.cos(np.pi / 4 + self.bs_error[0]), 101 | np.cos(np.pi / 4 + self.bs_error[1]) * np.sin(np.pi / 4 + self.bs_error[0]), 102 | np.sin(np.pi / 4 + self.bs_error[1]) * np.cos(np.pi / 4 + self.bs_error[0]), 103 | np.sin(np.pi / 4 + self.bs_error[1]) * np.sin(np.pi / 4 + self.bs_error[0]), 104 | ] 105 | 106 | def ideal_node(self, s: float = 0, phi: float = 0): 107 | """Ideal node with parameters s and phi that can be embedded in a circuit. 108 | 109 | Args: 110 | s: Cross split ratio :math:`s \\in [0, 1]` (:math:`s=1` means cross state). 111 | phi: Differential phase :math:`\\phi \\in [0, 2\\pi)` (set phase between inputs to the node). 112 | 113 | Returns: 114 | The embedded ideal node. 115 | 116 | """ 117 | mat = phase_matrix(phi) @ coupling_matrix_s(s) 118 | return _embed_2x2(mat, self.n, self.top, self.bottom) 119 | 120 | def mzi_node_matrix(self, theta: float = 0, phi: float = 0, embed: bool = True): 121 | """Tunable Mach-Zehnder interferometer node matrix. 122 | 123 | Args: 124 | theta: MMI phase between odd/even modes :math:`\\theta \\in [0, \\pi]` (:math:`\\theta=0` means cross state). 125 | phi: Differential phase :math:`\\phi \\in [0, 2\\pi)` (set phase between inputs to the node). 126 | embed: Whether to return the embedded matrix in the n-waveguide system (specified in node). 127 | 128 | Returns: 129 | Tunable MMI node matrix embedded in an :math:`N`-waveguide system. 130 | 131 | """ 132 | mat = self.dc(right=True) @ phase_matrix(theta) @ self.dc(right=False) @ phase_matrix(phi) 133 | return _embed_2x2(mat, self.n, self.top, self.bottom) if embed else mat 134 | 135 | def phase_matrix(self, top: float = 0, bottom: float = 0): 136 | """Embedded phase matrix. 137 | 138 | Args: 139 | top: Top phase of the phase matrix 140 | bottom: Bottom phase of the phase matrix 141 | 142 | Returns: 143 | Embedded phase matrix. 144 | 145 | """ 146 | return _embed_2x2(phase_matrix(top, bottom), self.n, self.top, self.bottom) 147 | 148 | def dc(self, right: bool = False) -> np.ndarray: 149 | """Directional coupler matrix with error. 150 | 151 | Args: 152 | right: Whether to use the left or right error (:code:`error` and :code:`error_right` respectively). 153 | 154 | Returns: 155 | A directional coupler matrix with error. 
156 | 157 | """ 158 | error = self.bs_error[right] 159 | insertion = loss2insertion(self.loss[right]) 160 | return np.array([ 161 | [np.cos(np.pi / 4 + error) * insertion, 1j * np.sin(np.pi / 4 + error) * insertion], 162 | [1j * np.sin(np.pi / 4 + error), np.cos(np.pi / 4 + error)] 163 | ]) 164 | 165 | def mmi_node_matrix(self, theta: float = 0, phi: float = 0, embed: bool = True): 166 | """Tunable multimode interferometer node matrix. 167 | 168 | Args: 169 | theta: MZI arm phase :math:`\\theta \\in [0, \\pi]` (:math:`\\theta=0` means cross state). 170 | phi: Differential phase :math:`\\phi \\in [0, 2\\pi)` (set phase between inputs to the node). 171 | embed: Whether to return the embedded matrix in the n-waveguide system (specified in node). 172 | 173 | Returns: 174 | Tunable MMI node matrix embedded in an :math:`N`-waveguide system. 175 | 176 | """ 177 | mat = coupling_matrix_phase(theta, self.bs_error, self.loss) @ phase_matrix(phi) 178 | return _embed_2x2(mat, self.n, self.top, self.bottom) if embed else mat 179 | 180 | def nullify(self, vector: np.ndarray, idx: int, lower_theta: bool = False, lower_phi: bool = False): 181 | theta = np.arctan2(np.abs(vector[idx]), np.abs(vector[idx + 1])) * 2 182 | theta = -theta if lower_theta else theta 183 | phi = np.angle(vector[idx + 1]) - np.angle(vector[idx]) 184 | phi = -phi if lower_phi else phi 185 | mat = self.mzi_node_matrix(theta, phi) 186 | nullified_vector = mat @ vector 187 | return nullified_vector, mat, np.mod(theta, 2 * np.pi), np.mod(phi, 2 * np.pi) 188 | 189 | def set_descendants(self, top_descendants: np.ndarray, bot_descendants: np.ndarray): 190 | self.top_descendants = top_descendants 191 | self.bot_descendants = bot_descendants 192 | return self 193 | 194 | 195 | def direct_transmissivity(top: np.ndarray, bottom: np.ndarray): 196 | """Get the direct transmissivity between top and bottom 197 | 198 | Args: 199 | top: Top vector elements 200 | bottom: Bottom vector elements 201 | 202 | Returns: 203 | The transmissivities 204 | 205 | """ 206 | return np.abs(top) ** 2 / (np.abs(top) ** 2 + np.abs(bottom) ** 2 + np.spacing(1)) 207 | 208 | 209 | def transmissivity_to_phase(s: Union[float, np.ndarray], mzi_terms: np.ndarray = None): 210 | """Convert transmissivity :math:`\\boldsymbol{s}` to phase :math:`\\boldsymbol{\\theta}`. 211 | 212 | Args: 213 | s: The transmissivity float or array. 214 | mzi_terms: The splitting terms :code:`(cs, sc)` for an MZI node. If :code:`None`, assumes 0.5 power for each. 215 | 216 | Returns: 217 | The phase :math:`\\boldsymbol{\\theta}` corresponding to the transmissivity :math:`\\boldsymbol{s}`. 218 | 219 | """ 220 | 221 | if mzi_terms is not None: 222 | _, cs, sc, _ = mzi_terms 223 | else: 224 | cs = sc = 0.5 225 | return np.arccos(np.minimum(np.maximum((s - cs ** 2 - sc ** 2) / (2 * cs * sc), -1), 1)) 226 | -------------------------------------------------------------------------------- /simphox/circuit/envelope.py: -------------------------------------------------------------------------------- 1 | """This file is for back-of-the-envelope calculations / figures based on them.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def binary_svd_depth_size(n, k, interport_distance=25, device_length=200, loss_db: float = 0.3): 7 | """Binary SVD architecture depth and size. 8 | 9 | Args: 10 | n: Number of inputs. 11 | k: Number of outputs (generally want :math:`k << n`). 12 | interport_distance: Distance between the ports of the device. 13 | device_length: Overall device length. 
14 | loss_db: Loss in dB for the circuit. 15 | 16 | Returns: 17 | Information about number of layers, length, height, footprint, loss, etc. 18 | 19 | """ 20 | num_input_layers = np.ceil(np.log2(n)) 21 | num_unitary_layers = k * np.ceil(np.log2(n)) 22 | attenuate_and_unitary_layer = k + 1 23 | num_waveguides_height = np.ceil(n / k) * n 24 | layers = num_input_layers + num_unitary_layers + attenuate_and_unitary_layer 25 | height = num_waveguides_height * interport_distance 26 | length = device_length * layers 27 | return { 28 | 'layers': layers, 29 | 'length (cm)': length / 1e4, 30 | 'height (cm)': height / 1e4, 31 | 'footprint (cm^2)': length * height / 1e8, 32 | 'loss (dB)': -layers * loss_db + 10 * np.log10(1 / np.ceil(n / k)) + 10 * np.log10(1 / n) - 3 33 | } 34 | 35 | 36 | def rectangular_depth_size(n, interport_distance=25, device_length=200, loss_db: float = 0.3, svd: bool = True): 37 | """Rectangular architecture depth and size. 38 | 39 | 40 | Args: 41 | n: The number of outputs of the binary tree 42 | interport_distance: Distance between each port in the network (include the phase shifters) 43 | device_length: Length of the device 44 | loss_db: Loss of the device 45 | svd: Whether to use an SVD architecture (doubles the number of layers in the architecture) 46 | 47 | Returns: 48 | Information about number of layers, length, height, footprint, loss, etc. 49 | 50 | """ 51 | num_input_layers = np.ceil(np.log2(n)) 52 | num_unitary_layers = n * (1 + svd) + svd 53 | num_waveguides_height = n 54 | layers = num_input_layers + num_unitary_layers 55 | height = num_waveguides_height * interport_distance 56 | length = device_length * layers 57 | return { 58 | 'layers': layers, 59 | 'length (cm)': length / 1e4, 60 | 'height (cm)': height / 1e4, 61 | 'footprint (cm^2)': length * height / 1e8, 62 | 'loss (dB)': -layers * loss_db - 10 * np.log10(n) - 3 * svd 63 | } 64 | 65 | 66 | def binary_equiv_cascade_size(n, n_equiv, interport_distance=25, device_length=200, loss_db: float = 0.3): 67 | """Binary architecture cascade size. 68 | 69 | Args: 70 | n: The number of outputs of the binary tree 71 | n_equiv: Find the k (number of inputs) required to match the flops of n_equiv x n_equiv matrix 72 | interport_distance: Distance between each port in the network (include the phase shifters) 73 | device_length: Length of the device 74 | loss_db: Loss of the device 75 | 76 | Returns: 77 | Information about number of layers, length, height, footprint, loss, etc. 
78 | 79 | """ 80 | k = np.ceil(n_equiv ** 2 / n) 81 | num_input_layers = np.ceil(np.log2(n)) 82 | num_unitary_layers = k * np.ceil(np.log2(n)) 83 | attenuate_and_unitary_layer = k + 1 84 | num_waveguides_height = n 85 | layers = num_input_layers + num_unitary_layers + attenuate_and_unitary_layer 86 | height = num_waveguides_height * interport_distance 87 | length = device_length * layers 88 | return { 89 | 'layers': layers, 90 | 'length (cm)': length / 1e4, 91 | 'height (cm)': height / 1e4, 92 | 'footprint (cm^2)': length * height / 1e8, 93 | 'loss (dB)': -layers * loss_db + 10 * np.log10(1 / n) 94 | } 95 | -------------------------------------------------------------------------------- /simphox/circuit/matrix.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import scipy as sp 4 | 5 | 6 | from .coupling import PhaseStyle 7 | from .rectangular import rectangular 8 | from .vector import vector_unit 9 | from scipy.stats import unitary_group 10 | import numpy as np 11 | 12 | from scipy.linalg import svd, qr, block_diag 13 | from .coupling import CouplingNode 14 | from .forward import ForwardMesh 15 | 16 | 17 | def cascade(u: np.ndarray, balanced: bool = True, phase_style: str = PhaseStyle.TOP, 18 | error_mean_std: Tuple[float, float] = (0., 0.), loss_mean_std: Tuple[float, float] = (0., 0.)): 19 | """Generate an architecture based on our recursive definitions programmed to implement unitary :math:`U`, 20 | or a set of :math:`K` mutually orthogonal basis vectors. 21 | 22 | Args: 23 | u: The (:math:`k \\times n`) mutually orthogonal basis vectors (unitary if :math:`k=n`) to be configured. 24 | balanced: If balanced, does balanced tree (:code:`m = n // 2`) otherwise linear chain (:code:`m = n - 1`). 25 | phase_style: Phase style for the nodes (see the :code:`PhaseStyle` enum). 26 | error_mean_std: Mean and standard deviation for errors (in radians). 27 | loss_mean_std: Mean and standard deviation for losses (in dB). 28 | 29 | Returns: 30 | Node list, thetas and phis. 31 | 32 | """ 33 | subunits = [] 34 | thetas = np.array([]) 35 | phis = np.array([]) 36 | gammas = np.array([]) 37 | n_rails = u.shape[0] 38 | num_columns = 0 39 | num_nodes = 0 40 | 41 | w = u.conj().T.copy() 42 | for i in reversed(range(n_rails + 1 - u.shape[1], n_rails)): 43 | # Generate the architecture as well as the theta and phi for each row of u. 44 | network, w = vector_unit(w[:i + 1, :i + 1], n_rails, balanced, phase_style, 45 | error_mean_std, loss_mean_std) 46 | 47 | # Update the phases. 48 | thetas = np.hstack((thetas, network.thetas)) 49 | phis = np.hstack((phis, network.phis)) 50 | gammas = np.hstack((network.gammas[-1], gammas)) 51 | 52 | # We need to index the thetas and phis correctly based on the number of programmed nodes in previous subunits 53 | # For unbalanced architectures (linear chains), we can actually pack them more efficiently into a triangular 54 | # architecture. 
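        # offset(num_nodes) shifts this subunit's node indices past the nodes already placed, and
        # offset_column shifts its column labels, so successive subunits do not collide in the mesh.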
55 | network.offset(num_nodes).offset_column(num_columns if balanced else 2 * (n_rails - 1 - i)) 56 | 57 | # Add the nodes list to the subunits 58 | subunits.append(network) 59 | 60 | # The number of columns and nodes in the architecture are incremented by the subunit size (log_2(i)) 61 | num_columns += subunits[-1].num_columns 62 | num_nodes += subunits[-1].num_nodes 63 | gammas = np.hstack((-np.angle(w[0, 0]), gammas)) 64 | unit = ForwardMesh.aggregate(subunits) 65 | unit.params = thetas, phis, gammas 66 | return unit 67 | 68 | 69 | def triangular(u: np.ndarray, phase_style: str = PhaseStyle.TOP, error_mean_std: Tuple[float, float] = (0., 0.), 70 | loss_mean_std: Tuple[float, float] = (0., 0.)): 71 | """Triangular mesh that analyzes a unitary matrix :code:`u`. 72 | 73 | Args: 74 | u: Unitary matrix or integer representing the number of inputs and outputs 75 | phase_style: Phase style for the nodes of the mesh. 76 | error_mean_std: Split error mean and standard deviation 77 | loss_mean_std: Loss error mean and standard deviation (dB) 78 | 79 | Returns: 80 | A triangular mesh object. 81 | 82 | """ 83 | u = unitary_group.rvs(u) if np.isscalar(u) else u 84 | return cascade(u, balanced=False, phase_style=phase_style, 85 | error_mean_std=error_mean_std, loss_mean_std=loss_mean_std) 86 | 87 | 88 | def tree_cascade(u: np.ndarray, phase_style: str = PhaseStyle.TOP, error_mean_std: Tuple[float, float] = (0., 0.), 89 | loss_mean_std: Tuple[float, float] = (0., 0.)): 90 | """Balanced cascade mesh that analyzes a unitary matrix :code:`u`. 91 | 92 | Args: 93 | u: Unitary matrix 94 | phase_style: Phase style for the nodes of the mesh. 95 | error_mean_std: Split error mean and standard deviation 96 | loss_mean_std: Loss error mean and standard deviation (dB) 97 | 98 | Returns: 99 | A tree cascade mesh object. 100 | 101 | """ 102 | u = unitary_group.rvs(u) if np.isscalar(u) else u 103 | return cascade(u, balanced=True, phase_style=phase_style, 104 | error_mean_std=error_mean_std, loss_mean_std=loss_mean_std) 105 | 106 | 107 | def dirichlet_matrix(v, embed_dim=None): 108 | phases = np.exp(-1j * np.angle(v)) 109 | y = np.abs(v) ** 2 110 | yop = np.sqrt(np.outer(y, y)) 111 | ysum = np.cumsum(y) 112 | yden = 1 / np.sqrt(ysum[:-1] * ysum[1:]) 113 | u = np.zeros_like(yop, dtype=np.complex128) 114 | u[1:, :] = yden[:, np.newaxis] * yop[1:, :] 115 | u[np.triu_indices(v.size)] = 0 116 | u[1:, 1:][np.diag_indices(v.size - 1)] = -ysum[:-1] * yden 117 | u[0] = np.sqrt(y / ysum[-1]) 118 | u *= phases 119 | u = np.roll(u, -1, axis=0) if embed_dim is None else sp.linalg.block_diag(np.roll(u, -1, axis=0), 120 | np.eye(embed_dim - v.size)) 121 | return u.T 122 | 123 | def cs(mat: np.ndarray): 124 | """Cosine-sine decomposition of arbitrary matrix :math:`U`(:code:`u`) 125 | 126 | Args: 127 | mat: The unitary matrix 128 | 129 | Even-partition cosine decomposition: 130 | [ q00 | q01 ] [ l0 | 0 ] [ s | c ] [ r0 | 0 ] 131 | u = [-----------] = [---------] [---------] [---------] . 132 | [ q10 | q11 ] [ 0 | l1 ] [ c | -s ] [ 0 | r1 ] 133 | 134 | c = diag(cos(theta)) 135 | s = diag(sin(theta)) 136 | where theta is in the range [0, pi / 2] 137 | 138 | Returns: 139 | The tuple of the four matrices :code:`l0`, :code:`l1`, :code:`r0`, :code:`r1`, and 140 | cosine-sine phases :code:`theta` in order from top to bottom. 
141 | 142 | """ 143 | n = mat.shape[0] 144 | m = n // 2 145 | q00 = mat[:m, :m] 146 | q10 = mat[m:, :m] 147 | q01 = mat[:m, m:] 148 | q11 = mat[m:, m:] 149 | l0, d00, r0 = svd(q00) 150 | r1hp, d01 = qr(q01.conj().T @ l0) 151 | theta = np.arcsin(d00) 152 | d01 = np.append(np.diag(d01), 1) if n % 2 else np.diag(d01) 153 | r1 = (r1hp * np.sign(d01)).conj().T 154 | l1p, d10 = qr(q10 @ r0.conj().T) 155 | d10 = np.append(np.diag(d10), 1) if n % 2 else np.diag(d10) 156 | l1 = l1p * np.sign(d10) 157 | phasor = (l1.conj().T @ q11 @ r1.conj().T)[-1, -1] if n % 2 else None 158 | if n % 2: 159 | r1[-1] *= phasor 160 | return l0, l1, r0, r1, theta 161 | 162 | 163 | def csinv(l0: np.ndarray, l1: np.ndarray, r0: np.ndarray, r1: np.ndarray, theta: np.ndarray): 164 | """Runs the inverse of the :code:`cs` function 165 | 166 | Args: 167 | l0: top left 168 | l1: bottom left 169 | r0: top right 170 | r1: bottom right 171 | theta: cosine-sine phases 172 | 173 | Returns: 174 | The final unitary matrix :code:`u`. 175 | 176 | """ 177 | left = block_diag(l0, l1) 178 | right = block_diag(r0, r1) 179 | c = np.cos(theta) 180 | s = np.sin(theta) 181 | d = np.block([[np.diag(s), np.diag(c)], 182 | [np.diag(c), -np.diag(s)]]) 183 | if r0.shape[0] != r1.shape[1]: 184 | d = block_diag(d, 1).astype(np.complex128) 185 | return left @ d @ right 186 | 187 | 188 | def _bowtie(u: np.ndarray, n_rails: int, thetas: np.ndarray, phis: np.ndarray, start: int, layer: int = None): 189 | """Recursive step for the cosine-sine bowtie architecture 190 | 191 | Args: 192 | u: Unitary matrix u 193 | n_rails: Number of total rails in the architecture 194 | thetas: The internal phase shifts or coupling phase terms :math:`\\theta`. 195 | phis: The external phase shifts or differential input phase terms :math:`\\phi`. 196 | start: Start index for interfering modes. 197 | layer: Layer of the bowtie recursion 198 | 199 | Returns: 200 | The list of :code:`CouplingNode`. 201 | 202 | """ 203 | nodes = [] 204 | n = u.shape[0] 205 | m = n // 2 206 | if n == 1: 207 | phis[layer][start] += np.angle(u[0][0]) 208 | return nodes 209 | l0, l1, r0, r1, theta = cs(u) 210 | thetas[layer][start:start + m * 2][::2] = theta 211 | nodes.extend([CouplingNode(n=n_rails, top=start + shift, bottom=start + shift + m, column=layer) 212 | for shift in range(m)]) 213 | nodes.extend(_bowtie(l0, n_rails, thetas, phis, start, layer - m)) 214 | nodes.extend(_bowtie(r0, n_rails, thetas, phis, start, layer + m)) 215 | nodes.extend(_bowtie(l1, n_rails, thetas, phis, start + m, layer - m)) 216 | nodes.extend(_bowtie(r1, n_rails, thetas, phis, start + m, layer + m)) 217 | return nodes 218 | 219 | 220 | def bowtie(u: np.ndarray): 221 | """Cosine-sine bowtie architecture. 222 | 223 | Args: 224 | u: The unitary matrix :code:`u` to parametrize the system. 225 | 226 | Returns: 227 | The bowtie fractal architecture. 
228 | 229 | """ 230 | n = u.shape[0] 231 | thetas = np.zeros((2 * n - 3, n)) 232 | phis = np.zeros((2 * n - 1, n)) 233 | circuit = ForwardMesh(_bowtie(u, n, thetas, phis, 0, n - 2)) 234 | phis = phis[1:] 235 | theta = np.zeros(int(n * (n - 1) / 2)) 236 | phi = np.zeros(int(n * (n - 1) / 2)) 237 | columns = circuit.columns 238 | for col_idx, col in enumerate(columns): 239 | theta[(col.node_idxs,)] = thetas[col_idx][np.nonzero(thetas[col_idx])] 240 | phi[(col.node_idxs,)] = phis[col_idx][(col.top,)] - phis[col_idx][(col.bottom,)] 241 | phis[col_idx][(col.top,)] = phis[col_idx][(col.bottom,)] 242 | if col_idx < len(columns): 243 | phis[col_idx + 1] += phis[col_idx] 244 | phis[col_idx + 1] = np.mod(phis[col_idx + 1], 2 * np.pi) 245 | circuit.params = theta * 2, phi, phis[-1] 246 | return circuit 247 | 248 | 249 | def psvd(a: np.ndarray): 250 | """Photonic SVD architecture 251 | 252 | Args: 253 | a: The matrix for which to perform the svd 254 | 255 | Returns: 256 | A tuple of singular values and the two corresponding SVD architectures :math:`U` and :math:`V^\\dagger`. 257 | 258 | """ 259 | l, d, r = svd(a) 260 | return rectangular(l), d, rectangular(r) 261 | -------------------------------------------------------------------------------- /simphox/circuit/rectangular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | 4 | from ..typing import Callable, Union 5 | from .coupling import CouplingNode 6 | from .forward import ForwardMesh 7 | 8 | 9 | def checkerboard_to_param(checkerboard: np.ndarray, units: int): 10 | param = np.zeros((units, units // 2)) 11 | if units % 2: 12 | param[::2, :] = checkerboard.T[::2, :-1:2] 13 | else: 14 | param[::2, :] = checkerboard.T[::2, ::2] 15 | param[1::2, :] = checkerboard.T[1::2, 1::2] 16 | return param 17 | 18 | 19 | def get_alpha_checkerboard(n: int): 20 | """Get the sensitivity index for each of the nodes in a rectangular architecture. 21 | 22 | The sensitivity values are arranged in a checkerboard form for easy spatial mapping to 23 | coupling nodes in a rectangular architecture. 24 | 25 | Args: 26 | n: The number of inputs in the rectangular architecture 27 | 28 | Returns: 29 | 30 | """ 31 | def rectangular_alpha(length: int, parity_odd: bool = False): 32 | odd_nums = list(length + 1 - np.flip(np.arange(1, length + 1, 2), axis=0)) 33 | even_nums = list(length + 1 - np.arange(2, 2 * (length - len(odd_nums)) + 1, 2)) 34 | nums = np.asarray(odd_nums + even_nums) 35 | if parity_odd: 36 | nums = nums[::-1] 37 | return nums 38 | alpha_checkerboard = np.zeros((n, n)) 39 | diagonal_length_to_sequence = [rectangular_alpha(i, bool(n % 2)) for i in range(1, n + 1)] 40 | for i in range(n - 1): 41 | for j in range(n): 42 | if (i + j) % 2 == 0: 43 | if j > i: 44 | diagonal_length = n - np.abs(i - j) 45 | elif i > 0 and j < i: 46 | diagonal_length = n - np.abs(i - j) - 1 47 | else: 48 | diagonal_length = n - 1 49 | alpha_checkerboard[i, j] = 1 if diagonal_length == 1 else \ 50 | diagonal_length_to_sequence[int(diagonal_length) - 1][min(i, j)] 51 | return alpha_checkerboard 52 | 53 | 54 | def grid_common_mode_flow(external_phases: np.ndarray, gamma: np.ndarray = None): 55 | """In a grid mesh (e.g., triangular, rectangular meshes), phases may need to be re-arranged. 56 | This is achieved using a procedure called "common mode flow" where common modes are shifted 57 | throughout the mesh until phases are correctly set. 
58 | 59 | Args: 60 | external_phases: external phases in the grid mesh 61 | gamma: input phase shifts 62 | 63 | Returns: 64 | new external phases shifts and new gamma resulting 65 | 66 | """ 67 | gamma = np.zeros(external_phases.shape[1]) if gamma is None else gamma 68 | units, num_layers = external_phases.shape 69 | phase_shifts = np.hstack((external_phases, gamma[:, np.newaxis])).T 70 | new_phase_shifts = np.zeros_like(external_phases.T) 71 | 72 | for i in range(num_layers): 73 | start_idx = i % 2 74 | end_idx = units - (i + units) % 2 75 | 76 | # calculate upper and lower phases 77 | upper_phase = phase_shifts[i][start_idx:end_idx][::2] 78 | lower_phase = phase_shifts[i][start_idx:end_idx][1::2] 79 | upper_phase = np.mod(upper_phase, 2 * np.pi) 80 | lower_phase = np.mod(lower_phase, 2 * np.pi) 81 | 82 | # upper - lower 83 | new_phase_shifts[i][start_idx:end_idx][::2] = upper_phase - lower_phase 84 | 85 | # lower_phase is now the common mode for all phase shifts in this layer 86 | phase_shifts[i] -= new_phase_shifts[i] 87 | 88 | # shift the phases to the next layer in parallel 89 | phase_shifts[i + 1] += np.mod(phase_shifts[i], 2 * np.pi) 90 | new_gamma = np.mod(phase_shifts[-1], 2 * np.pi) 91 | return np.mod(new_phase_shifts.T, 2 * np.pi), new_gamma 92 | 93 | 94 | def rectangular(u: Union[int, np.ndarray], pbar: Callable = None): 95 | """Get a rectangular architecture for the unitary matrix :code:`u` using the Clements decomposition. 96 | 97 | Args: 98 | u: The unitary matrix 99 | pbar: The progress bar for the clements decomposition (useful for larger unitaries) 100 | 101 | Returns: 102 | 103 | """ 104 | u_hat = u.copy() 105 | n = u.shape[0] 106 | # odd and even layer dimensions 107 | theta_checkerboard = np.zeros_like(u, dtype=np.float64) 108 | phi_checkerboard = np.zeros_like(u, dtype=np.float64) 109 | phi_checkerboard = np.hstack((np.zeros((n, 1)), phi_checkerboard)) 110 | iterator = pbar(range(n - 1)) if pbar else range(n - 1) 111 | for i in iterator: 112 | if i % 2: 113 | for j in range(i + 1): 114 | pairwise_index = n + j - i - 2 115 | target_row, target_col = n + j - i - 1, j 116 | theta = np.arctan2(np.abs(u_hat[target_row - 1, target_col]), np.abs(u_hat[target_row, target_col])) * 2 117 | phi = np.angle(u_hat[target_row, target_col]) - np.angle(u_hat[target_row - 1, target_col]) 118 | left = CouplingNode(n=n, top=pairwise_index, bottom=pairwise_index + 1) 119 | # Need an equivalent of differential internal phase shift to invert the matrix properly. 
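                # The common -theta / 2 phase cancels the MZI's global exp(1j * theta / 2) factor, so this
                # left-multiplication acts only on rows (target_row - 1, target_row) and nulls
                # u_hat[target_row, target_col], progressively reducing u_hat to a diagonal matrix of phases.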
120 | u_hat = left.phase_matrix(-theta / 2, -theta / 2) @ left.mzi_node_matrix(theta, phi) @ u_hat 121 | theta_checkerboard[pairwise_index, -j - 1] = theta 122 | phi_checkerboard[pairwise_index, -j - 1] = -phi - theta / 2 + np.pi 123 | phi_checkerboard[pairwise_index + 1, -j - 1] = -theta / 2 + np.pi 124 | else: 125 | for j in range(i + 1): 126 | pairwise_index = i - j 127 | target_row, target_col = n - j - 1, i - j 128 | theta = np.arctan2(np.abs(u_hat[target_row, target_col + 1]), np.abs(u_hat[target_row, target_col])) * 2 129 | phi = np.angle(-u_hat[target_row, target_col]) - np.angle(u_hat[target_row, target_col + 1]) 130 | right = CouplingNode(n=n, top=pairwise_index, bottom=pairwise_index + 1) 131 | u_hat = u_hat @ right.mzi_node_matrix(theta, phi).conj().T 132 | theta_checkerboard[pairwise_index, j] = theta 133 | phi_checkerboard[pairwise_index, j] = phi 134 | 135 | diag_phases = np.angle(np.diag(u_hat)) 136 | theta = checkerboard_to_param(theta_checkerboard, n) 137 | alpha_checkerboard = get_alpha_checkerboard(n) 138 | if n % 2: 139 | phi_checkerboard[:, :-1] += np.fliplr(np.diag(diag_phases)) 140 | else: 141 | phi_checkerboard[:, 1:] += np.fliplr(np.diag(diag_phases)) 142 | 143 | # Run the common mode flow algorithm to move phase front to the last layer of the mesh 144 | phi, gamma = grid_common_mode_flow(external_phases=phi_checkerboard[:, :-1], gamma=phi_checkerboard[:, -1]) 145 | phi = checkerboard_to_param(phi, n) 146 | alpha = checkerboard_to_param(alpha_checkerboard, n) 147 | 148 | # Set up the rectangular mesh nodes 149 | nodes = [] 150 | thetas = np.array([]) 151 | phis = np.array([]) 152 | node_id = 0 153 | for i in range(n): 154 | num_to_interfere = theta.shape[1] - (i % 2) * (1 - n % 2) 155 | nodes += [CouplingNode(node_id=node_id + j, n=n, column=i, 156 | top=2 * j + i % 2, bottom=2 * j + 1 + i % 2, 157 | alpha=1, beta=alpha[i, j]) 158 | for j in range(num_to_interfere)] 159 | thetas = np.hstack([thetas, theta[i, :num_to_interfere]]) 160 | phis = np.hstack([phis, phi[i, :num_to_interfere]]) 161 | node_id += num_to_interfere 162 | 163 | unit = ForwardMesh(nodes) 164 | unit.params = thetas, phis, gamma 165 | return unit 166 | 167 | 168 | def rectangular_rows(rectangular_mesh: ForwardMesh): 169 | n = rectangular_mesh.n 170 | rows = [[] for _ in range(n - 1)] 171 | for i in range(n - 1): 172 | for j in range(i + 1): 173 | pairwise_index = n + j - i - 2 if i % 2 else i - j 174 | rows[i].append(CouplingNode(n=n, top=n - pairwise_index - 2, bottom=n - pairwise_index - 1, 175 | column=n - j - 1 if i % 2 else j)) 176 | return [ForwardMesh(row).column_ordered for row in rows] 177 | 178 | 179 | def rectangular_phase_shift_powers(prop: np.ndarray, use_jax: bool = False): 180 | xp = jnp if use_jax else np 181 | n = prop.shape[1] 182 | y = xp.abs(prop) ** 2 183 | y = y / xp.sum(y, axis=1)[:, xp.newaxis] 184 | phi_even = y[::8, ::2] 185 | phi_odd = y[4::8, 1::2] 186 | theta_even = y[2::8, ::2] 187 | theta_odd = y[6::8, 1::2] 188 | theta_p = xp.hstack([xp.hstack((theta_even[i], theta_odd[i]))[:-1] for i in range(n // 2)]) 189 | phi_p = xp.hstack([xp.hstack((phi_even[i], phi_odd[i]))[:-1] for i in range(n // 2)]) 190 | gamma_p = y[-1] 191 | return theta_p, phi_p, gamma_p 192 | 193 | def random_theta(n: int): 194 | param = checkerboard_to_param(get_alpha_checkerboard(n), n) 195 | return param.flatten() if n % 2 else np.hstack([np.hstack((pe, po))[:-1] for pe, po in zip(param[::2], param[1::2])]) 196 | -------------------------------------------------------------------------------- 
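A minimal, self-contained NumPy sketch of the single nulling step that rectangular() above repeats when it peels off one matrix element per node. The null_element helper and the example values below are illustrative only and are not part of simphox.

import numpy as np
from scipy.stats import unitary_group


def null_element(u: np.ndarray, row: int, col: int) -> np.ndarray:
    """Zero u[row, col] by mixing rows (row - 1, row), as in the odd layers of rectangular()."""
    theta = np.arctan2(np.abs(u[row - 1, col]), np.abs(u[row, col])) * 2
    phi = np.angle(u[row, col]) - np.angle(u[row - 1, col])
    # This 2x2 block matches phase_matrix(-theta / 2, -theta / 2) @ mzi_node_matrix(theta, phi)
    # for an ideal (error-free, lossless) CouplingNode restricted to the two interfering rows.
    t = 1j * np.array([
        [np.sin(theta / 2) * np.exp(1j * phi), np.cos(theta / 2)],
        [np.cos(theta / 2) * np.exp(1j * phi), -np.sin(theta / 2)]
    ])
    out = u.copy()
    out[[row - 1, row]] = t @ u[[row - 1, row]]
    return out


u = unitary_group.rvs(4)
assert np.abs(null_element(u, row=3, col=0)[3, 0]) < 1e-10  # the targeted element is nulled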
/simphox/circuit/vector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .coupling import CouplingNode, PhaseStyle 4 | from .forward import ForwardMesh 5 | from ..typing import Callable, List, Optional, Tuple 6 | from ..utils import random_vector 7 | 8 | 9 | def _tree(indices: np.ndarray, n_rails: int, start: int = 0, column: int = 0, 10 | balanced: bool = True) -> List[CouplingNode]: 11 | """Recursive helper function to generate balanced binary tree architecture. 12 | 13 | Our network data structure is similar to how arrays are generally treated as tree data structures 14 | in computer science, where the waveguide rail index refers to the index of the vector propagating through 15 | the coupling network. This is the basis for our recursive definitions of binary trees, 16 | which ultimately form universal photonic networks when cascaded. 17 | 18 | Args: 19 | indices: Ordered indices into the tree nodes (splitting ratio, losses, errors, etc. all use this). 20 | n_rails: Number of rails in the system. 21 | start: Starting index for the tree. 22 | column: column index for the tree (leaves have largest column index :math:`\\log_2 n`). 23 | balanced: If balanced, does balanced tree (:code:`m = n // 2`) otherwise linear chain (:code:`m = n - 1`). 24 | 25 | Returns: 26 | A list of :code:`CouplingNode`'s in order that modes visit them. 27 | 28 | """ 29 | nodes = [] 30 | n = indices.size + 1 31 | if not balanced: 32 | return [CouplingNode(i, top=i, bottom=i + 1, n=n_rails, 33 | alpha=1, beta=n - i - 1, column=n - 2 - i).set_descendants(indices[:i], 34 | np.array([], dtype=np.int32)) 35 | for i in reversed(range(n - 1))] 36 | m = n // 2 if balanced else n - 1 37 | if n == 1: 38 | return nodes 39 | top = indices[:m - 1] 40 | bottom = indices[m:] 41 | nodes.append(CouplingNode(indices[m - 1], top=start + m - 1, bottom=start + n - 1, n=n_rails, 42 | alpha=bottom.size + 1, beta=top.size + 1, column=column).set_descendants(top, bottom)) 43 | nodes.extend(_tree(top, n_rails, start, column + 1, balanced)) 44 | nodes.extend(_tree(bottom, n_rails, start + m, column + 1, balanced)) 45 | return nodes 46 | 47 | 48 | def _butterfly(n: int, n_rails: int, start: int = 0, column: int = 0) -> List[CouplingNode]: 49 | """Recursive helper function to generate a balanced butterfly architecture (works best with powers of 2). 50 | 51 | Our network data structure is similar to how arrays are generally treated as tree data structures 52 | in computer science, where the waveguide rail index refers to the index of the vector propagating through 53 | the coupling network. 54 | 55 | Args: 56 | n: The number of modes in the binary tree (uses the top :math:`n` modes of the system). 57 | n_rails: Number of rails in the system. 58 | start: Starting index for the tree. 59 | column: column index for the tree (leaves have largest column index :math:`\\log_2 n`). 60 | 61 | Returns: 62 | A list of :code:`CouplingNode`'s in order that modes visit them. 
63 | 64 | """ 65 | nodes = [] 66 | m = n // 2 67 | if n == 1: 68 | return nodes 69 | nodes.extend([CouplingNode(top=start + i, bottom=start + m + i, n=n_rails, 70 | alpha=m + 1, beta=m + 1, column=column) for i in reversed(range(m))]) 71 | nodes.extend(_butterfly(m, n_rails, start, column + 1)) 72 | nodes.extend(_butterfly(n - m, n_rails, start + m, column + 1)) 73 | return nodes 74 | 75 | 76 | def tree(n: int, n_rails: Optional[int] = None, balanced: bool = True, 77 | phase_style: str = PhaseStyle.TOP) -> ForwardMesh: 78 | """Return a balanced or linear chain tree of MZIs. 79 | 80 | Args: 81 | n: Number of inputs into the tree. 82 | n_rails: Embed the first :code:`n` rails in an :code:`n_rails`-rail system (default :code:`n_rails == n`). 83 | balanced: If balanced, does balanced tree (:code:`m = n // 2`) otherwise linear chain (:code:`m = n - 1`). 84 | 85 | Returns: 86 | A :code:`CouplingCircuit` consisting of :code:`CouplingNode`'s arranged in a tree network. 87 | 88 | """ 89 | n_rails = n if n_rails is None else n_rails 90 | return ForwardMesh(_tree(np.arange(n - 1), n_rails, balanced=balanced), phase_style=phase_style).invert_columns() 91 | 92 | 93 | def butterfly(n: int, n_rails: Optional[int] = None) -> ForwardMesh: 94 | """Return a butterfly architecture 95 | 96 | Args: 97 | n: Number of inputs into the tree. 98 | n_rails: Embed the first :code:`n` rails in an :code:`n_rails`-rail system (default :code:`n_rails == n`). 99 | 100 | Returns: 101 | A :code:`CouplingCircuit` consisting of :code:`CouplingNode`'s arranged in a tree network. 102 | 103 | """ 104 | n_rails = n if n_rails is None else n_rails 105 | return ForwardMesh(_butterfly(n, n_rails)).invert_columns() 106 | 107 | 108 | def _program_vector_unit(v: np.ndarray, network: ForwardMesh): 109 | """Code for programming a vector unit that already exists. 110 | 111 | Note: 112 | This is a private method since we cannot assume that the input network is the appropriate 113 | vector unit (in size/structure) without defining a separate dataclass for it. 114 | 115 | Args: 116 | v: The vector / matrix (use final row) to program into the network 117 | network: The network to program 118 | 119 | Returns: 120 | The programmed vector unit 121 | 122 | """ 123 | v = v + 0j # cast to complex 124 | thetas = np.zeros(v.shape[0] - 1) 125 | phis = np.zeros(v.shape[0] - 1) 126 | w = v.copy() 127 | w = w[:, np.newaxis] if w.ndim == 1 else w 128 | for nc in network.columns: 129 | # grab the elements for the top and bottom arms of the mzi. 130 | top = w[nc.top] 131 | bottom = w[nc.bottom] 132 | 133 | theta, phi = nc.parallel_nullify(w, network.mzi_terms) 134 | 135 | # Vectorized (efficient!) 
parallel mzi elements 136 | t11, t12, t21, t22 = nc.parallel_mzi(theta, phi) 137 | t11, t12, t21, t22 = t11[:, np.newaxis], t12[:, np.newaxis], t21[:, np.newaxis], t22[:, np.newaxis] 138 | 139 | # these are the top port powers before nulling 140 | network.pnsn[nc.node_idxs] = np.abs(top[..., -1] 141 | if network.phase_style == PhaseStyle.TOP else bottom[..., -1]) ** 2 142 | 143 | # The final vector after the vectorized multiply 144 | w[nc.top + nc.bottom] = np.vstack([t11 * top + t21 * bottom, 145 | t12 * top + t22 * bottom]) 146 | 147 | # these are the relative powers after nulling 148 | network.pn[nc.node_idxs] = np.abs(w[nc.bottom][..., -1]) ** 2 149 | 150 | # The resulting thetas and phis, indexed according to the coupling network specifications 151 | thetas[nc.node_idxs] = theta 152 | phis[nc.node_idxs] = np.mod(phi, 2 * np.pi) 153 | 154 | final_basis_vec = np.zeros(v.shape[0]) 155 | final_basis_vec[-1] = 1 156 | gammas = -np.angle(final_basis_vec * w[-1, -1]) 157 | 158 | network.params = thetas, phis, gammas 159 | # print(network.params) 160 | 161 | return network, w.squeeze() 162 | 163 | 164 | def vector_unit(v: np.ndarray, n_rails: int = None, balanced: bool = True, phase_style: str = PhaseStyle.TOP, 165 | bs_error_mean_std: Tuple[float, float] = (0., 0.), loss_mean_std: Tuple[float, float] = (0., 0.)): 166 | """Generate an architecture based on our recursive definitions programmed to implement normalized vector :code:`v`. 167 | 168 | Args: 169 | v: The number of inputs or vector to be configured. If a matrix is provided, use the final row vector of matrix. 170 | n_rails: Embed the first :code:`n` rails in an :code:`n_rails`-rail system (default :code:`n_rails == n`). 171 | balanced: If balanced, does balanced tree (:code:`m = n // 2`) otherwise linear chain (:code:`m = n - 1`). 172 | phase_style: Phase style for the nodes (see the :code:`PhaseStyle` enum). 173 | bs_error_mean_std: Mean and standard deviation for beamsplitter errors (in radians). 174 | loss_mean_std: Mean and standard deviation for losses (in dB). 175 | 176 | Returns: 177 | A tuple of the programmed coupling network, the matrix after being fed through the network. 178 | """ 179 | network = tree(v.shape[0] if not np.isscalar(v) else v, n_rails=n_rails, balanced=balanced, phase_style=phase_style) 180 | error_mean, error_std = bs_error_mean_std 181 | loss_mean, loss_std = loss_mean_std 182 | network = network.add_error_mean(error_mean, loss_mean).add_error_variance(error_std, loss_std) 183 | if np.isscalar(v): 184 | return network, None 185 | return _program_vector_unit(v, network) 186 | 187 | 188 | def balanced_tree(v: np.ndarray, phase_style: str = PhaseStyle.TOP, 189 | bs_error_mean_std: Tuple[float, float] = (0., 0.), 190 | loss_mean_std: Tuple[float, float] = (0., 0.)): 191 | """Balanced tree mesh that analyzes a vector :code:`v`. 192 | 193 | Args: 194 | v: Vector unit. 195 | phase_style: Phase style for the nodes of the mesh. 196 | bs_error_mean_std: Mean and standard deviation for beamsplitter errors (in radians). 197 | loss_mean_std: Mean and standard deviation for losses (in dB). 198 | 199 | Returns: 200 | A tree mesh object analyzing a vector. 
201 | 202 | """ 203 | return vector_unit(v.conj().T if not np.isscalar(v) else v, 204 | phase_style=phase_style, bs_error_mean_std=bs_error_mean_std, loss_mean_std=loss_mean_std)[0] 205 | 206 | 207 | def unbalanced_tree(v: np.ndarray, phase_style: str = PhaseStyle.TOP, bs_error_mean_std: Tuple[float, float] = (0., 0.), 208 | loss_mean_std: Tuple[float, float] = (0., 0.)): 209 | """Linear chain that analyzes a vector :code:`v`. 210 | 211 | Args: 212 | v: Vector unit 213 | phase_style: Phase style for the nodes of the mesh. 214 | bs_error_mean_std: Split error mean and standard deviation 215 | loss_mean_std: Loss error mean and standard deviation (dB) 216 | 217 | Returns: 218 | A linear chain mesh object analyzing a vector. 219 | 220 | """ 221 | return vector_unit(v.conj().T if not np.isscalar(v) else v, 222 | phase_style=phase_style, bs_error_mean_std=bs_error_mean_std, loss_mean_std=loss_mean_std, 223 | balanced=False)[0] 224 | 225 | 226 | def hessian_vector_unit(v: np.ndarray, balanced: bool = True): 227 | """Compute the Hessian for a vector unit if size code:`n` using finite differences assuming TOP phase style. 228 | 229 | We use the self-configuring dynamic programming approach to generate the necessary power quantities 230 | at each node required to generate the Hessian directly from the vector unit. 231 | 232 | Args: 233 | v: Vector to be programmed on the vector unit. 234 | balanced: Whether to use the balanced or the unbalanced tree vector unit. 235 | 236 | Returns: 237 | The Hessian matrix with block matrices 238 | :math:`\\mathcal{H}_{\\theta \\to \\theta}`, :math:`\\mathcal{H}_{\\phi \\to \\phi}`, 239 | :math:`\\mathcal{H}_{\\theta \\to \\phi}`, which give the Hessian magnitudes for the matrix. 240 | 241 | """ 242 | v = v / np.linalg.norm(v) 243 | mesh = balanced_tree(v) if balanced else unbalanced_tree(v) 244 | pn, pnsn = mesh.pn, mesh.pnsn 245 | 246 | theta_theta = np.diag(pn) 247 | phi_phi = 2 * np.diag(pnsn) 248 | theta_phi = np.zeros_like(phi_phi) 249 | phi_theta = np.diag(pnsn) 250 | for i in range(v.size - 1): 251 | nid = mesh.nodes[i].node_id 252 | theta_theta[nid][(mesh.nodes[i].top_descendants,)] = mesh.pn[(mesh.nodes[i].top_descendants,)] / 2 253 | theta_theta[nid][(mesh.nodes[i].bot_descendants,)] = mesh.pn[(mesh.nodes[i].bot_descendants,)] / 2 254 | theta_theta[..., nid] = theta_theta[nid] 255 | phi_phi[nid][(mesh.nodes[i].top_descendants,)] = 2 * mesh.pnsn[(mesh.nodes[i].top_descendants,)] 256 | phi_phi[..., nid] = phi_phi[nid] 257 | theta_phi[..., nid][(mesh.nodes[i].top_descendants,)] = mesh.pn[(mesh.nodes[i].top_descendants,)] 258 | phi_theta[nid][(mesh.nodes[i].top_descendants,)] = mesh.pnsn[(mesh.nodes[i].top_descendants,)] 259 | phi_theta[nid][(mesh.nodes[i].bot_descendants,)] = mesh.pnsn[(mesh.nodes[i].bot_descendants,)] 260 | h = np.block([[theta_theta, (theta_phi + phi_theta).T], [(theta_phi + phi_theta), phi_phi]]) 261 | return h 262 | 263 | 264 | def hessian_fd(v: np.ndarray, error=0.0001, balanced=False): 265 | """Compute the Hessian for a vector unit if size code:`n` using finite differences. 266 | This is mostly useful for testing, but it takes way too long in practice. 267 | 268 | The finite difference evaluation is given by the central differencing scheme: 269 | .. 
math:: 270 | \\mathcal{H}_{ij} = \\frac{\\partial^2 \\epsilon^2}{\\partial \\delta_{i} \\partial \\delta_{j}} \\approx 271 | \\frac{\\epsilon^2(\\delta_i \\boldsymbol{e}_i + \\delta_j \\boldsymbol{e}_j) - \\epsilon^2( 272 | \\delta_i \\boldsymbol{e}_i - \\delta_j \\boldsymbol{e}_j)}{2\\delta_{i} \\delta_{j}} 273 | where we allow :math:`\\delta_{i} = \\delta_{j}` be the phase error applied to phases :math:`i, j` in the network. 274 | 275 | Args: 276 | v: Vector to be programmed on the vector unit. 277 | error: The tiny error to use to compute the second-order Hessian matrix. 278 | balanced: Whether to use the balanced or the unbalanced tree vector unit. 279 | 280 | Returns: 281 | The Hessian matrix with block matrices 282 | :math:`\\mathcal{H}_{\\theta \\to \\theta}`, :math:`\\mathcal{H}_{\\phi \\to \\phi}`, 283 | :math:`\\mathcal{H}_{\\theta \\to \\phi}`, which give the Hessian magnitudes for the matrix. 284 | 285 | """ 286 | n = v.size 287 | v = v / np.linalg.norm(v) 288 | h = np.zeros((2 * n - 2, 2 * n - 2), dtype=np.complex128) 289 | e = np.eye(n - 1) * error 290 | mesh = balanced_tree(v) if balanced else unbalanced_tree(v) 291 | 292 | def err(params): 293 | return 2 - 2 * mesh.matrix_fn()(params, v.conj())[-1] 294 | 295 | for i in range(n - 1): 296 | for j in range(n - 1): 297 | theta_pp = (mesh.thetas + e[i] + e[j], mesh.phis, mesh.gammas) 298 | theta_nn = (mesh.thetas - e[i] - e[j], mesh.phis, mesh.gammas) 299 | theta_pn = (mesh.thetas + e[i] - e[j], mesh.phis, mesh.gammas) 300 | theta_np = (mesh.thetas - e[i] + e[j], mesh.phis, mesh.gammas) 301 | phi_pp = (mesh.thetas, mesh.phis + e[i] + e[j], mesh.gammas) 302 | phi_nn = (mesh.thetas, mesh.phis - e[i] - e[j], mesh.gammas) 303 | phi_pn = (mesh.thetas, mesh.phis + e[i] - e[j], mesh.gammas) 304 | phi_np = (mesh.thetas, mesh.phis - e[i] + e[j], mesh.gammas) 305 | theta_phi_pp = (mesh.thetas + e[i], mesh.phis + e[j], mesh.gammas) 306 | theta_phi_pn = (mesh.thetas + e[i], mesh.phis - e[j], mesh.gammas) 307 | theta_phi_np = (mesh.thetas - e[i], mesh.phis + e[j], mesh.gammas) 308 | theta_phi_nn = (mesh.thetas - e[i], mesh.phis - e[j], mesh.gammas) 309 | h[i, j] = np.real(err(theta_pp) + err(theta_nn) - err(theta_pn) - err(theta_np)) / (4 * error ** 2) 310 | h[i + n - 1, j + n - 1] = np.real(err(phi_pp) + err(phi_nn) - err(phi_pn) - err(phi_np)) / (4 * error ** 2) 311 | h[i + n - 1, j] = np.real(err(theta_phi_pp) + err(theta_phi_nn) - err(theta_phi_pn) - err(theta_phi_np)) / ( 312 | 4 * error ** 2) 313 | h[j, i + n - 1] = h[i + n - 1, j] 314 | return h 315 | 316 | 317 | def hessian_distribution(n: int, balanced: bool = False, n_samples=1, pbar: Callable = None): 318 | """Compute the Hessian distribution for a vector unit of size code:`n`. 319 | 320 | Args: 321 | n: Number of inputs 322 | balanced: Whether to use the balanced or the unbalanced tree vector unit. 323 | n_samples: Number of samples (if 1, return the mesh along with the errors). 324 | pbar: Progress bar. 325 | 326 | Returns: 327 | A tuple of :math:`\\mathcal{H}_{\\theta \\to \\theta}`, :math:`\\mathcal{H}_{\\phi \\to \\phi}`, 328 | :math:`\\mathcal{H}_{\\theta \\to \\phi}`, which give the Hessian magnitudes for the matrix. 
329 | 330 | """ 331 | hessians = [] 332 | iterator = range(n_samples) if pbar is None else pbar(range(n_samples)) 333 | for _ in iterator: 334 | hessians.append(hessian_vector_unit(random_vector(n, normed=True), balanced=balanced)[0]) 335 | return np.array(hessians) 336 | -------------------------------------------------------------------------------- /simphox/fdtd.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import Tuple, List, Callable 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | from .parse import parse_source_port 8 | from .sim import SimGrid 9 | from .typing import Array, Shape, Spacing, Optional, Union, State, Size3, Size, Source 10 | from .utils import pml_sigma, curl_pml_fn, yee_avg, yee_avg_jax 11 | 12 | try: 13 | from dphox.pattern import Pattern 14 | DPHOX_INSTALLED = True 15 | except ImportError: 16 | DPHOX_INSTALLED = False 17 | 18 | 19 | class FDTD(SimGrid): 20 | """Stateless Finite Difference Time Domain (FDTD) implementation. 21 | 22 | The FDTD update consists of updating the fields and auxiliary vectors that comprise the system "state." This class 23 | ideally makes use of the jit capability of JAX. 24 | 25 | Attributes: 26 | size: size of the simulation 27 | spacing: spacing among the different dimensions 28 | eps: epsilon permittivity 29 | pml: perfectly matched layers (PML) 30 | pml_params: The PML parameters of the form :code:`(exp_scale, log_reflectivity, pml_eps)`. 31 | use_jax: Whether to use jax 32 | name: Name of the simulator 33 | """ 34 | 35 | def __init__(self, size: Size, spacing: Spacing, eps: Union[float, np.ndarray] = 1, 36 | pml: Optional[Union[Shape, Size, float]] = None, pml_params: Size3 = (3, -25, 1), 37 | pml_sep: int = 5, use_jax: bool = True, name: str = 'fdtd'): 38 | super(FDTD, self).__init__(size, spacing, eps, pml=pml, pml_params=pml_params, pml_sep=pml_sep, name=name) 39 | self.dt = 1 / np.sqrt(np.sum(1 / self.spacing ** 2)) # includes courant condition! 40 | self.use_jax = use_jax 41 | self.xp = jnp if use_jax else np 42 | self.pml_regions = [] 43 | self.sigma = None 44 | self.cpml_b, self.cpml_c = [], [] 45 | self._curl_e_pml, self._curl_h_pml = [], [] 46 | # pml (internal to the grid / does not affect params, so specified here!) 47 | if self.pml_shape is not None: 48 | self._set_pml(pml_params) 49 | self._curl_e = self.curl_fn(use_jax=self.use_jax) 50 | self._curl_h = self.curl_fn(of_h=True, use_jax=self.use_jax) 51 | 52 | # raise NotImplementedError("This class is still WIP") 53 | 54 | @classmethod 55 | def from_pattern(cls, component: "Pattern", core_eps: float, clad_eps: float, spacing: float, boundary: Size, 56 | pml: float, component_t: float = 0, component_zmin: Optional[float] = None, 57 | rib_t: float = 0, sub_z: float = 0, height: float = 0, bg_eps: float = 1, name: str = 'fdfd'): 58 | """Initialize an FDFD from a Pattern defined in DPhox. 
59 | 60 | Args: 61 | component: pattern provided by DPhox 62 | core_eps: core epsilon (in the pattern mask region) 63 | clad_eps: clad epsilon 64 | spacing: spacing required 65 | boundary: boundary size around component 66 | pml: PML boundary size 67 | height: height for 3d simulation 68 | sub_z: substrate minimum height 69 | component_zmin: component height (defaults to substrate_z) 70 | component_t: component thickness 71 | rib_t: rib thickness for component (partial etch) 72 | bg_eps: background epsilon (usually 1 or air/vacuum) 73 | name: Name of the component 74 | 75 | Returns: 76 | A Grid object for the component 77 | 78 | """ 79 | if not DPHOX_INSTALLED: 80 | raise ImportError('DPhox not installed, but it is required to run this function.') 81 | b = component.size 82 | x = b[0] + 2 * boundary[0] 83 | y = b[1] + 2 * boundary[1] 84 | component_zmin = sub_z if component_zmin is None else component_zmin 85 | spacing = spacing * np.ones(2 + (component_t > 0)) if isinstance(spacing, float) else np.asarray(spacing) 86 | size = (x, y, height) if height > 0 else (x, y) 87 | grid = cls(size, spacing, eps=bg_eps, pml=pml, name=name) 88 | grid.fill(sub_z + rib_t, core_eps) 89 | grid.fill(sub_z, clad_eps) 90 | grid.add(component, core_eps, component_zmin, component_t) 91 | return grid 92 | 93 | @property 94 | def zero_state(self) -> State: 95 | """Zero state, the default initial state for the FDTD 96 | 97 | Returns: 98 | Hidden state of the form: 99 | e: current :math:`\\mathbf{E}` 100 | 101 | h: current :math:`\\mathbf{H}` 102 | 103 | psi_e: current :math:`\\boldsymbol{\\Psi}_E` for CPML updates (otherwise :code:`None`) 104 | 105 | psi_h: current :math:`\\boldsymbol{\\Psi}_H` for CPML updates (otherwise :code:`None`) 106 | 107 | """ 108 | # stored fields for fdtd 109 | e = self.xp.zeros(self.field_shape, dtype=np.complex128) 110 | h = self.xp.zeros_like(e) 111 | # for pml updates 112 | psi_e = None if self.pml_shape is None else [self.xp.zeros_like(self.xp.vstack([e[s], e[s]])) 113 | for s in self.pml_regions] 114 | psi_h = None if self.pml_shape is None else [self.xp.zeros_like(p) for p in psi_e] 115 | return e, h, psi_e, psi_h 116 | 117 | def _step_e(self, state: State, sources: List[Tuple[np.ndarray, np.ndarray, Tuple[slice, ...]]]): 118 | e, h, psi_e, psi_h = state 119 | # update pml in pml regions if specified 120 | phi_e = [] 121 | for pml_idx, pml_region in enumerate(self.pml_regions): 122 | psi_e[pml_idx], p = self._curl_h_pml[pml_idx](h, psi_e[pml_idx], self.cpml_b[pml_idx][1]) 123 | phi_e.append(p) 124 | 125 | # update e field 126 | e += self._curl_h(h) / self.eps_t * self.dt 127 | 128 | for pml_idx, pml_region in enumerate(self.pml_regions): 129 | if self.use_jax: 130 | e = e.at[pml_region].add(phi_e[pml_idx] / self.eps_t[pml_region] * self.dt) 131 | else: 132 | e[pml_region] += phi_e[pml_idx] / self.eps_t[pml_region] * self.dt 133 | 134 | # add source 135 | for source, _, source_region in sources: 136 | if source is not None: 137 | if self.use_jax: 138 | e = e.at[source_region].set(-source.squeeze() / self.eps_t[source_region] * self.dt) 139 | else: 140 | e[source_region] -= source.squeeze() / self.eps_t[source_region] * self.dt 141 | 142 | return e, h, psi_e, psi_h 143 | 144 | def _step_h(self, state: State, sources: List[Source]): 145 | e, h, psi_e, psi_h = state 146 | # update h field in pml regions if specified 147 | phi_h = [] 148 | for pml_idx, pml_region in enumerate(self.pml_regions): 149 | psi_h[pml_idx], p = self._curl_e_pml[pml_idx](e, psi_h[pml_idx], 
self.cpml_b[pml_idx][0]) 150 | phi_h.append(p) 151 | 152 | # update h field 153 | h -= self._curl_e(e) * self.dt 154 | 155 | for pml_idx, pml_region in enumerate(self.pml_regions): 156 | if self.use_jax: 157 | h = h.at[pml_region].add(-phi_h[pml_idx] * self.dt) 158 | else: 159 | h[pml_region] -= phi_h[pml_idx] * self.dt 160 | 161 | # add source 162 | for _, source, source_region in sources: 163 | if source is not None: 164 | if self.use_jax: 165 | h = h.at[source_region].set(-source.squeeze() * self.dt) 166 | else: 167 | h[source_region] -= source.squeeze() * self.dt 168 | 169 | return e, h, psi_e, psi_h 170 | 171 | def step(self, state: State, sources: List[Source]) -> State: 172 | """FDTD step (in the form of an RNNCell) 173 | 174 | Notes: 175 | The updates are of the form: 176 | 177 | .. math:: 178 | \\mathbf{E}(t + \\mathrm{d}t) &= \\mathbf{E}(t) + \\mathrm{d}t 179 | \\frac{\\mathrm{d}\\mathbf{E}}{\\mathrm{d}t} 180 | 181 | \\mathbf{H}(t + \\mathrm{d}t) &= \\mathbf{H}(t) + 182 | \\mathrm{d}t \\frac{\\mathrm{d}\\mathbf{H}}{\\mathrm{d}t} 183 | 184 | From Maxwell's equations, we have (for current source :math:`\\mathbf{J}(t)`): 185 | 186 | .. math:: 187 | \\frac{\\mathrm{d}\\mathbf{E}}{\\mathrm{d}t} &= \\frac{1}{\\epsilon} \\nabla 188 | \\times \\mathbf{H}(t) + \\mathbf{J}(t) 189 | 190 | \\frac{\\mathrm{d}\\mathbf{H}}{\\mathrm{d}t} &= -\\frac{1}{\\mu} \\nabla \\times 191 | \\mathbf{E}(t) + \\mathbf{M}(t) 192 | 193 | The recurrent update assumes that :math:`\\mu = c = 1, \\mathbf{M}(t) = \\mathbf{0}` and factors in 194 | perfectly-matched layers (PML), which requires storing two additional PML arrays in the system's state 195 | vector, namely :math:`\\boldsymbol{\\Psi}_E(t)` and :math:`\\boldsymbol{\\Psi}_H(t)`. 196 | 197 | .. math:: 198 | \\mathbf{\\Psi_E}^{(t+1/2)} &= \\mathbf{b} \\mathbf{\\Psi_E}^{(t-1/2)} + 199 | \\nabla_{\\mathbf{c}} \\times \\mathbf{H}^{(t)} 200 | 201 | \\mathbf{\\Psi_H}^{(t + 1)} &= \\mathbf{b} \\mathbf{\\Psi_H}^{(t)} + 202 | \\nabla_{\\mathbf{c}} \\times \\mathbf{E}^{(t)} 203 | 204 | \\mathbf{E}^{(t+1/2)} &= \\mathbf{E}^{(t-1/2)} + \\frac{\\Delta t}{\\epsilon} \\left(\\nabla \\times 205 | \\mathbf{H}^{(t)} + \\mathbf{J}^{(t)} + \\mathbf{\\Psi_E}^{(t+1/2)}\\right) 206 | 207 | \\mathbf{H}^{(t + 1)} &= \\mathbf{H}^{(t)} - \\Delta t \\left(\\nabla \\times \\mathbf{E}^{(t+1/2)} + 208 | \\mathbf{\\Psi_H}^{(t + 1)}\\right) 209 | 210 | 211 | Note, in Einstein notation, the weighted curl operator is given by: 212 | :math:`\\nabla_{\\mathbf{c}} \\times \\mathbf{v} := \\epsilon_{ijk} c_i \\partial_j v_k`. 213 | 214 | Args: 215 | state: current state of the form :code:`(e, h, psi_e, psi_h)` = :math:`(\\mathbf{E}(t), 216 | \\mathbf{H}(t), \\boldsymbol{\\Psi}_E(t), \\boldsymbol{\\Psi}_H(t))`. 217 | sources: The sources :math:`\\mathbf{J}_i(t)`, i.e. the input excitations to the system, 218 | and the corresponding slice or mask of the added source to be added to E in the update, 219 | which must be the same shape as :math:`\\mathbf{J}_i(t)`. 220 | 221 | Returns: 222 | a new :code:`State` of the form :code:`(e, h, psi_e, psi_h)` = :math:`(\\mathbf{E}(t), 223 | \\mathbf{H}(t), \\boldsymbol{\\Psi}_E(t), \\boldsymbol{\\Psi}_H(t))`. 
224 | 225 | """ 226 | 227 | state = self._step_e(state, sources) 228 | state = self._step_h(state, sources) 229 | return state 230 | 231 | def run_cw_port(self, num_time_steps: int, wavelength: float = 1.55, 232 | source_port: Union[str, List[Tuple[str, int]]] = 'a0', 233 | measure_port: Union[str, List[Tuple[str, int]]] = None, 234 | tm_2d: bool = True, profile_size_factor: float = 3, pbar: Callable = None, 235 | initial_state: Optional[State] = None, viz_pipes: Optional[dict] = None, 236 | viz_interval: int = 1, viz_h: bool = True, viz_axis: int = 2): 237 | """Run the FDTD using a harmonic source with an eigenmode at a port at a specified wavelength. 238 | 239 | Args: 240 | num_time_steps: Total number of time steps to run the simulation. 241 | wavelength: Wavelength of the CW harmonic source. 242 | source_port: The source port(s); default :code:`a0`, generally the designated default input port. 243 | measure_port: The measure port(s), measure at all ports if None. 244 | tm_2d: If 2D, use the TM mode, else use the TE mode. Ignored if 3D. 245 | profile_size_factor: Profile size factor (multiply the port size to get the profile sim region size). 246 | pbar: Progress bar handle (e.g., tqdm). 247 | initial_state: Initial state for the FDTD (default is the zero state given by :code:`fdtd.zero_state`) 248 | viz_pipes: Visualization streaming handle structure mapping port names to visualization pipes. 249 | This is useful for streaming live simulation results in a notebook! We assume a multilayer device for now, 250 | which means the z-dimension is where we perform a cross section. However, you can look along another 251 | dimension (e.g. y) if you supply a tuple of the port and the dimension 0 or 1. 252 | viz_interval: Interval (in time steps) at which to send streaming data. 253 | 254 | Returns: 255 | flux: The flux measurements (not averaged) as an array, which can be averaged during post-processing.
256 | state: final state of the form :code:`(e, h, psi_e, psi_h)` 257 | -:code:`e` refers to electric field :math:`\\mathbf{E}(t)` 258 | -:code:`h` refers to magnetic field :math:`\\mathbf{H}(t)` 259 | -:code:`psi_e` refers to :math:`\\boldsymbol{\\Psi}_E(t)` (for debugging PML) 260 | -:code:`psi_h` refers to :math:`\\boldsymbol{\\Psi}_H(t)` (for debugging PML) 261 | 262 | """ 263 | source_excitation = list(parse_source_port(source_port).keys()) 264 | source_modes = self.port_modes(excitation=source_excitation, 265 | profile_size_factor=profile_size_factor, wavelength=wavelength) 266 | measure_fn = self.get_measure_fn(measure_port, use_jax=self.use_jax, 267 | profile_size_factor=profile_size_factor, tm_2d=tm_2d) 268 | state = self.zero_state if initial_state is None else initial_state 269 | iterator = range(num_time_steps) if pbar is None else pbar(np.arange(num_time_steps)) 270 | flux = [measure_fn(state[:2])] 271 | k0 = 2 * np.pi / wavelength 272 | for step in iterator: 273 | sources = [] 274 | for p, midx in source_excitation: 275 | mode = source_modes[p] 276 | time_shift = np.exp(1j * step * self.dt * k0) 277 | src_e, src_h, src_slice = mode.profile(midx), mode.profile(midx, use_h=True), mode.slice(self) 278 | sources.append((src_e * time_shift, src_h * time_shift, src_slice)) 279 | _, h_before, _, _ = state 280 | state = self.step(state, sources) 281 | e, h, _, _ = state 282 | synchronized_fields = np.stack((e, (h_before + h) / 2)) 283 | flux.append(measure_fn(synchronized_fields)) 284 | if viz_pipes and step % viz_interval == 0: 285 | self._viz(viz_pipes, e, h, viz_h, viz_axis) 286 | return flux, state 287 | 288 | def _viz(self, viz_pipes: dict, e: Array, h: Array, viz_h: bool, viz_axis: int): 289 | """Visualize the fields.""" 290 | for port_name in viz_pipes: 291 | eps_pipe, field_pipe, power_pipe = viz_pipes[port_name] 292 | if self.ndim == 3: 293 | port_name = (port_name, 2) if not isinstance(port_name, tuple) else port_name 294 | idx = int(np.around(self.port[port_name[0]].xyz[port_name[1]] / self.spacing[port_name[1]])) 295 | eps_slice = tuple([idx if ax == port_name[1] else slice(None) for ax in range(3)]) 296 | eps = self.eps[eps_slice].T 297 | f = np.array(h[(slice(None), *eps_slice)] if viz_h else e[(slice(None), *eps_slice)]) 298 | else: 299 | eps = self.eps.T 300 | f = np.array(h if viz_h else e) 301 | eps_pipe.send((eps - np.min(eps)) / (np.max(eps) - np.min(eps))) 302 | field_pipe.send(f[viz_axis].T.real / np.max(f[viz_axis].T.real + np.spacing(1))) 303 | power = np.abs(f[viz_axis].T) ** 2 304 | power_pipe.send(power / np.max(power + np.spacing(1))) 305 | 306 | def _set_pml(self, pml_params: Size3): 307 | exp_scale, log_reflection, absorption_corr = pml_params 308 | kappa, alpha = 1, 1e-8 # TODO: make these params 309 | self.sigma = [-pml_sigma(self.pos[ax], thickness=self.pml_shape[ax], exp_scale=exp_scale, 310 | log_reflection=log_reflection, absorption_corr=absorption_corr) for ax in range(3)] 311 | # for memory and time purposes, we only update the pml slices, NOT the full field 312 | # therefore, we need to specify the pml regions for the fields. 
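# pml_regions below lists six 4-D slices into the field array (component, x, y, z):
# one slab of thickness pml_shape[ax] at each boundary of each spatial axis.
# For every slab we precompute the standard CPML recursion coefficients on the two
# Yee-shifted sigma profiles: b = exp(-(alpha + sigma / kappa) * dt) and
# c = (b - 1) * sigma / (sigma * kappa + alpha * kappa**2), stored in cpml_b and cpml_c.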
313 | self.pml_regions = ((slice(None), slice(None, self.pml_shape[0]), slice(None), slice(None)), 314 | (slice(None), slice(-self.pml_shape[0], None), slice(None), slice(None)), 315 | (slice(None), slice(None), slice(None, self.pml_shape[1]), slice(None)), 316 | (slice(None), slice(None), slice(-self.pml_shape[1], None), slice(None)), 317 | (slice(None), slice(None), slice(None), slice(None, self.pml_shape[2])), 318 | (slice(None), slice(None), slice(None), slice(-self.pml_shape[2], None))) 319 | 320 | for i, region in enumerate(self.pml_regions): 321 | ax = i // 2 322 | if self.pml_shape[ax]: 323 | pml_slice = tuple([None if idx == slice(None) else idx for idx in region[1:]]) 324 | pml_shape = np.array((2,) + self.field_shape) 325 | pml_shape[ax + 2] = self.pml_shape[ax] 326 | sigma_ax = np.zeros(pml_shape, dtype=np.complex128) 327 | sigma_ax[0, ax] = self.sigma[ax][0][pml_slice] 328 | sigma_ax[1, ax] = self.sigma[ax][1][pml_slice] 329 | self.cpml_b.append(np.exp(-(alpha + sigma_ax / kappa) * self.dt)) 330 | self.cpml_c.append((self.cpml_b[-1] - 1) * sigma_ax / (sigma_ax * kappa + alpha * kappa ** 2)) 331 | self._curl_h_pml = [self.curl_h_pml(pml_idx) for pml_idx in range(len(self.pml_regions))] 332 | self._curl_e_pml = [self.curl_e_pml(pml_idx) for pml_idx in range(len(self.pml_regions))] 333 | 334 | def curl_e_pml(self, pml_idx: int) -> Callable[[Array, Array, Array], Array]: 335 | dx, _ = self._dxes 336 | c, s = self.cpml_c[pml_idx][0], self.pml_regions[pml_idx][1:] 337 | 338 | def de(e, ax): 339 | return c[ax] * (self.xp.roll(e, -1, axis=ax)[s] - e[s]) / dx[ax][s] 340 | return curl_pml_fn(de, use_jax=self.use_jax) 341 | 342 | def curl_h_pml(self, pml_idx: int) -> Callable[[Array, Array, Array], Array]: 343 | _, dx = self._dxes 344 | c, s = self.cpml_c[pml_idx][1], self.pml_regions[pml_idx][1:] 345 | 346 | def dh(h, ax): 347 | return c[ax] * (h[s] - self.xp.roll(h, 1, axis=ax)[s]) / dx[ax][s] 348 | return curl_pml_fn(dh, use_jax=self.use_jax) 349 | 350 | @property 351 | @lru_cache() 352 | def eps_t(self): 353 | """The epsilon tensor (assumed to be diagonal, i.e. no off-diagonal components for now). 
354 | 355 | Returns: 356 | 357 | """ 358 | yee = yee_avg_jax if self.use_jax else yee_avg 359 | eps_t = yee(self.xp.array(self.eps.reshape(self.shape3))) 360 | return eps_t 361 | -------------------------------------------------------------------------------- /simphox/mkl.py: -------------------------------------------------------------------------------- 1 | from .typing import Optional 2 | import sys 3 | from ctypes import CDLL, byref, c_char, c_int, c_int64, POINTER, c_float, c_double 4 | 5 | import numpy as np 6 | import scipy.sparse as sp 7 | from typing import Tuple 8 | 9 | 10 | libname = {'linux': 'libmkl_rt.so', # python3 11 | 'linux2': 'libmkl_rt.so', # python2 12 | 'darwin': 'libmkl_rt.dylib', 13 | 'win32': 'mkl_rt.dll'} 14 | mkl = CDLL(libname[sys.platform]) 15 | 16 | pardisoinit = mkl.pardisoinit 17 | pardisoinit.argtypes = [POINTER(c_int64), 18 | POINTER(c_int), 19 | POINTER(c_int)] 20 | pardisoinit.restype = None 21 | 22 | pardiso = mkl.pardiso 23 | pardiso.argtypes = [POINTER(c_int64), # pt 24 | POINTER(c_int), # maxfct 25 | POINTER(c_int), # mnum 26 | POINTER(c_int), # mtype 27 | POINTER(c_int), # phase 28 | POINTER(c_int), # n 29 | POINTER(None), # a 30 | POINTER(c_int), # ia 31 | POINTER(c_int), # ja 32 | POINTER(c_int), # perm 33 | POINTER(c_int), # nrhs 34 | POINTER(c_int), # iparm 35 | POINTER(c_int), # msglvl 36 | POINTER(None), # rhs 37 | POINTER(None), # x 38 | POINTER(c_int)] # error 39 | pardiso.restype = None 40 | 41 | feastinit = mkl.feastinit 42 | feastinit.argtypes = [POINTER(c_int)] 43 | feastinit.restype = None 44 | 45 | feast_argtypes = [POINTER(c_char), # uplo 46 | POINTER(c_int), # n 47 | POINTER(None), # a 48 | POINTER(c_int), # ia 49 | POINTER(c_int), # ja 50 | POINTER(c_int), # fpm 51 | POINTER(None), # epsout 52 | POINTER(c_int), # loop 53 | POINTER(None), # emin 54 | POINTER(None), # emax 55 | POINTER(c_int), # m0 56 | POINTER(None), # e 57 | POINTER(None), # x 58 | POINTER(c_int), # m 59 | POINTER(None), # res 60 | POINTER(c_int)] # info 61 | 62 | sfeast_scsrev, dfeast_scsrev, cfeast_hcsrev, zfeast_hcsrev = mkl.sfeast_scsrev, mkl.dfeast_scsrev,\ 63 | mkl.cfeast_hcsrev, mkl.zfeast_hcsrev 64 | sfeast_scsrev.argtypes = dfeast_scsrev.argtypes = cfeast_hcsrev.argtypes = zfeast_hcsrev.argtypes = feast_argtypes 65 | sfeast_scsrev.restype = dfeast_scsrev.restype = cfeast_hcsrev.restype = zfeast_hcsrev.restype = None 66 | 67 | 68 | PARDISO_FREEFACTOR = -1 69 | PARDISO_FREEALL = 0 70 | PARDISO_SOLVE = 33 71 | PARDISO_FACTORIZE = 12 72 | PARDISO_FULLSOLVE = 13 73 | 74 | 75 | # a cleaner version of pyMKL with some additional factorization optimizations 76 | class Pardiso: 77 | def __init__(self, mtype: int = 13): 78 | self.mtype = mtype 79 | if mtype in (1, 3): 80 | raise NotImplementedError(f"mtype = {mtype} - structurally symmetric not supported") 81 | if self.is_complex: 82 | self.dtype = np.complex128 83 | elif self.is_real: 84 | self.dtype = np.float64 85 | else: 86 | raise ValueError(f"mtype = {mtype} - invalid mtype, need (2, -2, 4, -4, 6, 11, 13)") 87 | self.ctypes_dtype = np.ctypeslib.ndpointer(self.dtype) 88 | 89 | self.pt = np.zeros(64, np.int64) 90 | self.pt_ = self.pt.ctypes.data_as(POINTER(c_int64)) 91 | 92 | self.iparm = np.zeros(64, dtype=np.int32) 93 | self.iparm_ = self.iparm.ctypes.data_as(POINTER(c_int)) 94 | 95 | pardisoinit(self.pt_, byref(c_int(self.mtype)), self.iparm_) 96 | 97 | # from pyMKL 98 | self.iparm[1] = 3 # Parallel nested dissection for reordering 99 | self.iparm[23] = 1 # Parallel factorization 100 | self.iparm[34] = 1 # 
Zero-indexing 101 | 102 | self.phase = PARDISO_FULLSOLVE 103 | self._mat_hash = 0 # no matrix has been factorized yet, and this is an unlikely hash for it to be assigned 104 | 105 | @property 106 | def is_complex(self) -> bool: 107 | return self.mtype in (4, -4, 6, 13) 108 | 109 | @property 110 | def is_real(self) -> bool: 111 | return self.mtype in (2, -2, 11) 112 | 113 | def _set_mat(self, mat: sp.csr_matrix): 114 | # If mat is symmetric, store only the upper triangular portion 115 | if self.mtype in [2, -2, 4, -4, 6]: 116 | mat = sp.triu(mat, format='csr') 117 | 118 | if mat.dtype != self.dtype: 119 | raise ValueError(f"Expected mat.dtype to match chosen mtype but got {mat.dtype} != {self.dtype}") 120 | if mat.shape[0] != mat.shape[1] or mat.ndim > 2: 121 | raise ValueError(f'Expected mat square (i.e., shape (n, n)), but has shape: {mat.shape}') 122 | 123 | if not mat.has_sorted_indices: 124 | mat.sort_indices() 125 | 126 | self.mat: sp.csr_matrix = mat 127 | 128 | self.a = self.mat.data 129 | self.a_ = self.a.ctypes.data_as(self.ctypes_dtype) 130 | 131 | self.ia = self.mat.indptr 132 | self.ia_ = self.ia.ctypes.data_as(POINTER(c_int)) 133 | 134 | self.ja = self.mat.indices 135 | self.ja_ = self.ja.ctypes.data_as(POINTER(c_int)) 136 | 137 | self.n = mat.shape[0] 138 | 139 | def free(self, complete: bool = True): 140 | self.phase = PARDISO_FREEALL if complete else PARDISO_FREEFACTOR 141 | self.pardiso() 142 | 143 | def factor(self, mat: sp.csr_matrix): 144 | self.phase = PARDISO_FACTORIZE 145 | mat_hash = hash(str(mat)) 146 | self._set_mat(mat) 147 | self._mat_hash = mat_hash 148 | self.pardiso() 149 | 150 | def solve(self, mat: sp.csr_matrix, rhs: np.ndarray) -> np.ndarray: 151 | mat_hash = hash(str(mat)) 152 | if mat_hash != self._mat_hash: 153 | self._set_mat(mat) 154 | self._mat_hash = mat_hash 155 | self.phase = PARDISO_FULLSOLVE 156 | else: 157 | self.phase = PARDISO_SOLVE 158 | return self.pardiso(rhs) 159 | 160 | def pardiso(self, rhs: Optional[np.ndarray] = None) -> np.ndarray: 161 | if self._mat_hash == 0: 162 | raise RuntimeError('Mat information not stored in Pardiso.') 163 | if rhs is not None and rhs.shape[0] != self.n: 164 | raise RuntimeError(f'Expected rhs.shape[0] == {self.n}, but got {rhs.shape[0]}') 165 | 166 | nrhs = 0 if rhs is None else (np.prod(rhs.shape[1:]) if rhs.ndim > 1 else 1) 167 | rhs = np.zeros(1) if rhs is None else rhs.astype(self.dtype).flatten(order='f') 168 | x = np.zeros(1) if rhs is None else np.zeros(nrhs * self.n, dtype=self.dtype) 169 | rhs_ = rhs.ctypes.data_as(self.ctypes_dtype) 170 | x_ = x.ctypes.data_as(self.ctypes_dtype) 171 | 172 | err_c = c_int(0) 173 | 174 | mkl.pardiso( 175 | self.pt_, # pt 176 | byref(c_int(1)), # maxfct 177 | byref(c_int(1)), # mnum 178 | byref(c_int(self.mtype)), # mtype 179 | byref(c_int(self.phase)), # phase 180 | byref(c_int(self.n)), # n 181 | self.a_, # a 182 | self.ia_, # ia 183 | self.ja_, # ja 184 | byref(c_int(0)), # perm 185 | byref(c_int(nrhs)), # nrhs 186 | self.iparm_, # iparm 187 | byref(c_int(0)), # msglvl 188 | rhs_, # rhs 189 | x_, # x 190 | byref(err_c) # error 191 | ) 192 | 193 | if self.iparm[13] > 0: 194 | raise RuntimeError(f"Pardiso - Number of perturbed pivot elements = {repr(self.iparm[13])}. " 195 | f"This could mean that the matrix is singular.") 196 | 197 | if err_c.value != 0: 198 | raise RuntimeError(f"Pardiso returned an error with code {err_c.value}. 
" 199 | f"Check error codes in manual: https://pardiso-project.org/manual/manual.pdf") 200 | 201 | return x.reshape(rhs.shape, order='f') if nrhs > 1 else x 202 | 203 | 204 | class Feast: 205 | def __init__(self, num_contours: int = 8, stopping: int = 12, max_refinement_loops_dp: int = 20, 206 | max_refinement_loops_sp: int = 5, one_contour: bool = False, 207 | check_input: bool = False, check_pos_definite: bool = False): 208 | self.fpm = np.zeros(128, np.int32) 209 | self.fpm_ = self.fpm.ctypes.data_as(POINTER(c_int)) 210 | self.iparm = self.fpm[-64:] 211 | 212 | feastinit(self.fpm_) 213 | 214 | self.fpm[1] = num_contours 215 | self.fpm[2] = stopping 216 | self.fpm[3] = max_refinement_loops_dp 217 | self.fpm[6] = max_refinement_loops_sp 218 | self.fpm[13] = int(one_contour) 219 | self.fpm[26] = int(check_input) 220 | self.fpm[27] = int(check_pos_definite) 221 | self.fpm[63] = 1 # Enable iparm settings 222 | # from pyMKL 223 | self.iparm[1] = 3 # Parallel nested dissection for reordering 224 | self.iparm[23] = 1 # Parallel factorization 225 | 226 | def feast(self, mat: sp.csr_matrix, m0: int, erange: Tuple[float, float], symmetric: bool = True): 227 | dtype = mat.dtype 228 | ctypes_dtype = np.ctypeslib.ndpointer(dtype) 229 | single_precision = dtype == np.complex64 or dtype == np.float32 230 | rdtype = np.float32 if single_precision else np.float64 231 | ctypes_rdtype = np.ctypeslib.ndpointer(rdtype) 232 | if dtype == np.complex128 or dtype == np.complex64: 233 | feast_fn = cfeast_hcsrev if single_precision else zfeast_hcsrev 234 | elif dtype == np.float64 or dtype == np.float32: 235 | feast_fn = sfeast_scsrev if single_precision else dfeast_scsrev 236 | else: 237 | raise TypeError(f'Expected mat.dtype to be one of (np.complex128, np.complex64, np.float64, np.float32),' 238 | f'but got {mat.dtype}.') 239 | 240 | # If mat is symmetric, store only the upper triangular portion 241 | # if symmetric: 242 | # mat = sp.triu(mat, format='csr') 243 | 244 | if mat.shape[0] != mat.shape[1] or mat.ndim > 2: 245 | raise ValueError(f'Expected mat square (i.e., shape (n, n)), but has shape: {mat.shape}') 246 | 247 | if not mat.has_sorted_indices: 248 | mat.sort_indices() 249 | 250 | n = mat.shape[0] 251 | 252 | emin, emax = erange 253 | # uplo_ = byref(c_char(b'U')) if symmetric else byref(c_char(b'F')) 254 | uplo_ = byref(c_char(b'F')) 255 | emin_ = byref(c_float(emin)) if single_precision else byref(c_double(emin)) 256 | emax_ = byref(c_float(emax)) if single_precision else byref(c_double(emax)) 257 | epsout_ = byref(c_float(0.0)) if single_precision else byref(c_double(0.0)) 258 | loop = c_int(0) 259 | m = c_int(0) 260 | x = np.zeros(m0 * n, dtype=dtype) 261 | e = np.zeros(m0, dtype=rdtype) 262 | res = np.zeros(m0, dtype=rdtype) 263 | res_ = res.ctypes.data_as(ctypes_rdtype) 264 | info = c_int(0) 265 | 266 | feast_fn( 267 | uplo_, # uplo 268 | byref(c_int(n)), # n 269 | mat.data.ctypes.data_as(ctypes_dtype), # a 270 | (mat.indptr + 1).ctypes.data_as(POINTER(c_int)), # ia (one-based indexing) 271 | (mat.indices + 1).ctypes.data_as(POINTER(c_int)), # ja (one-based indexing) 272 | self.fpm_, # fpm 273 | epsout_, # epsout 274 | byref(loop), # loop 275 | emin_, # emin 276 | emax_, # emax 277 | c_int(m0), # m0 278 | e.ctypes.data_as(ctypes_rdtype), # e 279 | x.ctypes.data_as(ctypes_dtype), # x 280 | byref(m), # m 281 | res_, # res 282 | byref(info) # info 283 | ) 284 | 285 | if self.iparm[13] > 0: 286 | raise RuntimeError(f"Pardiso - Number of perturbed pivot elements = {repr(self.iparm[13])}. 
" 287 | f"This could mean that the matrix is singular.") 288 | 289 | if info.value != 0: 290 | raise RuntimeError(f"Feast returned an error/warning with code {info.value}. " 291 | f"Check error codes in manual: " 292 | f"https://software.intel.com/sites/default/files/mkl-2020-developer-reference-fortran.pdf") 293 | 294 | return e, x.reshape((n, m0), order='f'), m.value, loop.value, res, info.value 295 | 296 | 297 | pardiso = Pardiso() 298 | feast = Feast() 299 | 300 | 301 | def spsolve_pardiso(mat: sp.spmatrix, rhs: np.ndarray): 302 | if not isinstance(mat, sp.spmatrix): 303 | raise TypeError(f'mat must be an instance of spmatrix but got {type(mat)}') 304 | if not isinstance(rhs, np.ndarray): 305 | raise TypeError(f'mat must be an instance of ndarray but got {type(rhs)}') 306 | return pardiso.solve(mat.tocsr(), rhs) 307 | 308 | 309 | def feast_eigs(mat: sp.spmatrix, erange: Tuple[float, float], k: int = 6, symmetric: bool = True): 310 | return feast.feast(mat.tocsr(), k, erange, symmetric) 311 | -------------------------------------------------------------------------------- /simphox/opt.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from .utils import fix_dataclass_init_docs 4 | from .sim import SimGrid 5 | from .typing import Optional, Callable, Union, List, Tuple, Dict 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | from jax.config import config 10 | from jax.example_libraries.optimizers import adam 11 | import numpy as np 12 | import dataclasses 13 | import xarray 14 | 15 | try: 16 | HOLOVIEWS_IMPORTED = True 17 | import holoviews as hv 18 | from holoviews.streams import Pipe 19 | import panel as pn 20 | except ImportError: 21 | HOLOVIEWS_IMPORTED = False 22 | 23 | from .viz import scalar_metrics_viz 24 | 25 | config.parse_flags_with_absl() 26 | 27 | 28 | @fix_dataclass_init_docs 29 | @dataclasses.dataclass 30 | class OptProblem: 31 | """An optimization problem 32 | 33 | An optimization problem consists of a neural network defined at least by input parameters :code:`rho`, 34 | the transform function :code:`T(rho)` (:math:`T(\\rho(x, y))`) (default identity), 35 | and objective function :code:`C(T(rho))` (:math:`C(T(\\rho(x, y)))`), which maps to a scalar. 36 | For use with an inverse design problem (the primary use case in this module), the user can include an 37 | FDFD simulation and a source (to be fed into the FDFD solver). The FDFD simulation and source are then 38 | used to define a function :code:`S(eps) == S(T(rho))` that solves the FDFD problem 39 | where `eps == T(rho)` (:math:`\\epsilon(x, y)) := T(\\rho(x, y))`), 40 | in which case the objective function evaluates :code:`C(S(T(rho)))` 41 | (:math:`\\epsilon(x, y)) := C(S(T((\\rho(x, y))))`). 42 | 43 | Args: 44 | transform_fn: The JAX-transformable transform function to yield epsilon (identity if None, 45 | must be a single :code:`transform_fn` (to be broadcast to all) 46 | or a list to match the FDFD objects respectively). Examples of transform_fn 47 | could be smoothing functions, symmetry functions, and more (which can be compounded appropriately). 48 | cost_fn: The JAX-transformable cost function (or tuple of such functions) 49 | corresponding to src that takes in output of solve_fn from :code:`opt_solver`. 
50 | sim: SimGrid(s) used to generate the solver (FDFD is not run is :code:`fdfd` is :code:`None`) 51 | source: A numpy array source (FDFD is not run is :code:`source` is :code:`None`) 52 | metrics_fn: A metric_fn that returns useful dictionary data based on fields and FDFD object 53 | at certain time intervals (specified in opt). Each problem is supplied this metric_fn 54 | (Optional, ignored if :code:`None`). 55 | 56 | """ 57 | transform_fn: Callable 58 | cost_fn: Callable 59 | sim: SimGrid 60 | source: str 61 | metrics_fn: Optional[Callable[[np.ndarray, SimGrid], Dict]] = None 62 | 63 | def __post_init__(self): 64 | self.fn = self.sim.get_sim_sparams_fn(self.source, self.transform_fn)\ 65 | if self.source is not None else self.transform_fn 66 | 67 | 68 | @fix_dataclass_init_docs 69 | @dataclasses.dataclass 70 | class OptViz: 71 | """An optimization visualization object 72 | 73 | An optimization visualization object consists of a plot for monitoring the 74 | history and current state of an optimization in real time. 75 | 76 | Args: 77 | cost_dmap: Cost dynamic map for streaming cost fn over time 78 | simulations_panel: Simulations panel for visualizing simulation results from last iteration 79 | costs_pipe: Costs pipe for streaming cost fn over time 80 | simulations_pipes: Simulations pipes of form :code:`eps, field, power` 81 | for visualizing simulation results from last iteration 82 | metrics_panels: Metrics panels for streaming metrics over time for each simulation (e.g. powers/power ratios) 83 | metrics_pipes: Metrics pipes for streaming metrics over time for each simulation 84 | metric_config: Metric config (a dictionary that describes how to plot/group the real-time metrics) 85 | 86 | """ 87 | cost_dmap: "hv.DynamicMap" 88 | simulations_panels: Dict[str, "pn.layout.Panel"] 89 | costs_pipe: "Pipe" 90 | simulations_pipes: Dict[str, Tuple["Pipe", "Pipe", "Pipe"]] 91 | metric_config: Optional[Dict[str, List[str]]] = None 92 | metrics_panels: Optional[Dict[str, "hv.DynamicMap"]] = None 93 | metrics_pipes: Optional[Dict[str, Dict[str, "Pipe"]]] = None 94 | 95 | 96 | @fix_dataclass_init_docs 97 | @dataclasses.dataclass 98 | class OptRecord: 99 | """An optimization record 100 | 101 | We need an object to hold the history, which includes a list of costs (we avoid the term loss 102 | as it may be related to denoted 103 | 104 | Attributes: 105 | costs: List of costs 106 | params: Params (:math:`\rho`) transformed into the design 107 | metrics: An xarray for metrics with dimensions :code:`name`, :code:`metric`, :code:`iteration` 108 | eps: An xarray for relative permittivity with dimensions :code:`name`, :code:`x`, :code:`y` 109 | fields: An xarray for a selected field component with dimensions :code:`name`, :code:`x`, :code:`y` 110 | 111 | """ 112 | costs: np.ndarray 113 | params: jnp.ndarray 114 | metrics: xarray.DataArray 115 | eps: xarray.DataArray 116 | fields: xarray.DataArray 117 | 118 | 119 | def opt_run(opt_problem: Union[OptProblem, List[OptProblem]], init_params: np.ndarray, num_iters: int, 120 | pbar: Optional[Callable] = None, step_size: float = 1, viz_interval: int = 0, metric_interval: int = 0, 121 | viz: Optional[OptViz] = None, backend: str = 'cpu', 122 | eps_interval: int = 0, field_interval: int = 0) -> OptRecord: 123 | """Run the optimization. 124 | 125 | The optimization can be done over multiple simulations as long as those simulations 126 | share the same set of params provided by :code:`init_params`. 
127 | 128 | Args: 129 | opt_problem: An :code:`OptProblem` or list of :code:`OptProblem`'s. If a list is provided, 130 | the optimization optimizes the sum of all objective functions. 131 | If the user wants to weight the objective functions, weights must be inlcuded in the objective function 132 | definition itself, but we may provide support for this feature at a later time if needed. 133 | init_params: Initial parameters for the optimizer (:code:`eps` if :code:`None`) 134 | num_iters: Number of iterations to run 135 | pbar: Progress bar to keep track of optimization progress with ideally a simple tqdm interface 136 | step_size: For the Adam update, specify the step size needed. 137 | viz_interval: The optimization intermediate results are recorded every :code:`record_interval` steps 138 | (default of 0 means do not visualize anything) 139 | metric_interval: The interval over which a recorded object (e.g. metric, param) 140 | are recorded in a given :code:`OptProblem` (default of 0 means do not record anything). 141 | viz: The :code:`OptViz` object required for visualizing the optimization in real time. 142 | backend: Recommended backend for :code:`ndim == 2` is :code:`'cpu'` and :code:`ndim == 3` is :code:`'gpu'` 143 | eps_interval: Whether to record the eps at the specified :code:`eps_interval`. 144 | Beware, this can use up a lot of memory during the opt so use judiciously. 145 | field_interval: Whether to record the field at the specified :code:`field_interval`. 146 | Beware, this can use up a lot of memory during the opt so use judiciously. 147 | 148 | Returns: 149 | A tuple of the final eps distribution (:code:`transform_fn(p)`) and parameters :code:`p` 150 | 151 | """ 152 | 153 | opt_init, opt_update, get_params = adam(step_size=step_size) 154 | opt_state = opt_init(init_params) 155 | 156 | # define opt_problems 157 | opt_problems = [opt_problem] if isinstance(opt_problem, OptProblem) else opt_problem 158 | n_problems = len(opt_problems) 159 | 160 | # opt problems that include both an FDFD sim and a source sim 161 | sim_opt_problems = [op for op in opt_problems if op.sim is not None and op.source is not None] 162 | 163 | if viz is not None: 164 | if not len(viz.simulations_pipes) == len(sim_opt_problems): 165 | raise ValueError("Number of viz_pipes must match number of opt problems") 166 | 167 | # Define the simulation and objective function acting on parameters rho 168 | solve_fn = [None if (op.source is None or op.sim is None) else op.fn for op in opt_problems] 169 | 170 | def overall_cost_fn(rho: jnp.ndarray): 171 | evals = [op.cost_fn(s(rho)) if s is not None else op.cost_fn(rho) for op, s in zip(opt_problems, solve_fn)] 172 | return jnp.array([obj for obj, _ in evals]).sum() / n_problems, [aux for _, aux in evals] 173 | 174 | # Define a compiled update step 175 | def step_(current_step, state): 176 | vaux, g = jax.value_and_grad(overall_cost_fn, has_aux=True)(get_params(state)) 177 | v, aux = vaux 178 | return v, opt_update(current_step, g, state), aux 179 | 180 | def _update_eps(state): 181 | rho = get_params(state) 182 | for op in opt_problems: 183 | op.sim.eps = np.asarray(jax.lax.stop_gradient(op.transform_fn(rho))) 184 | 185 | step = jax.jit(step_, backend=backend) 186 | 187 | iterator = pbar(range(num_iters)) if pbar is not None else range(num_iters) 188 | 189 | costs = [] 190 | history = defaultdict(list) 191 | 192 | for i in iterator: 193 | v, opt_state, data = step(i, opt_state) 194 | _update_eps(opt_state) 195 | for sop, sparams_fields in zip(sim_opt_problems, 
data): 196 | sim = sop.sim 197 | sparams, e, h = sim.decorate(*sparams_fields) 198 | hz = np.asarray(h[2]).squeeze().T 199 | if viz_interval > 0 and i % viz_interval == 0 and viz is not None: 200 | eps_pipe, field_pipe, power_pipe = viz.simulations_pipes[sim.name] 201 | eps_pipe.send((sim.eps.T - np.min(sim.eps)) / (np.max(sim.eps) - np.min(sim.eps))) 202 | field_pipe.send(hz.real / np.max(hz.real)) 203 | power = np.abs(hz) ** 2 204 | power_pipe.send(power / np.max(power)) 205 | if metric_interval > 0 and i % metric_interval == 0 and viz is not None: 206 | metrics = sop.metrics_fn(sparams) 207 | for metric_name, metric_value in metrics.items(): 208 | history[f'{metric_name}/{sop.sim.name}'].append(metric_value) 209 | for title in viz.metrics_pipes[sop.sim.name]: 210 | viz.metrics_pipes[sop.sim.name][title].send( 211 | xarray.DataArray( 212 | data=np.asarray([history[f'{metric_name}/{sop.sim.name}'] 213 | for metric_name in viz.metric_config[title]]), 214 | coords={ 215 | 'metric': viz.metric_config[title], 216 | 'iteration': np.arange(i + 1) 217 | }, 218 | dims=['metric', 'iteration'], 219 | name=title 220 | ) 221 | ) 222 | if eps_interval > 0 and i % eps_interval == 0: 223 | history[f'eps/{sop.sim.name}'].append((i, sop.sim.eps)) 224 | if field_interval > 0 and i % field_interval == 0: 225 | history[f'field/{sop.sim.name}'].append((i, hz.T)) 226 | iterator.set_description(f"𝓛: {v:.5f}") 227 | costs.append(jax.lax.stop_gradient(v)) 228 | if viz is not None: 229 | viz.costs_pipe.send(np.asarray(costs)) 230 | _update_eps(opt_state) 231 | 232 | all_metric_names = sum([metric_names for _, metric_names in viz.metric_config.items()], []) 233 | metrics = xarray.DataArray( 234 | data=np.array([[history[f'{metric_name}/{sop.sim.name}'] 235 | for metric_name in all_metric_names] for sop in sim_opt_problems]), 236 | coords={ 237 | 'name': [sop.sim.name for sop in sim_opt_problems], 238 | 'metric': all_metric_names, 239 | 'iteration': np.arange(num_iters) 240 | }, 241 | dims=['name', 'metric', 'iteration'], 242 | name='metrics' 243 | ) if sim_opt_problems and metric_interval != 0 else [] 244 | eps = xarray.DataArray( 245 | data=np.array([[eps for _, eps in history[f'eps/{sop.sim.name}']] if eps_interval > 0 else [] 246 | for sop in sim_opt_problems]), 247 | coords={ 248 | 'name': [sop.sim.name for sop in sim_opt_problems], 249 | 'iteration': [it for it, _ in history[f'eps/{sim_opt_problems[0].sim.name}']], 250 | 'x': np.arange(sim_opt_problems[0].sim.shape[0]), 251 | 'y': np.arange(sim_opt_problems[0].sim.shape[1]), 252 | }, 253 | dims=['name', 'iteration', 'x', 'y'], 254 | name='eps' 255 | ) if sim_opt_problems and eps_interval != 0 else [] 256 | fields = xarray.DataArray( 257 | data=np.asarray([[field for _, field in history[f'field/{sop.sim.name}']] if field_interval > 0 else [] 258 | for sop in sim_opt_problems]), 259 | coords={ 260 | 'name': [sop.sim.name for sop in sim_opt_problems], 261 | 'iteration': [it for it, _ in history[f'field/{sim_opt_problems[0].sim.name}']], 262 | 'x': np.arange(sim_opt_problems[0].sim.shape[0]), 263 | 'y': np.arange(sim_opt_problems[0].sim.shape[1]), 264 | }, 265 | dims=['name', 'iteration', 'x', 'y'], 266 | name='fields' 267 | ) if sim_opt_problems and field_interval != 0 else [] 268 | return OptRecord(costs=np.asarray(costs), params=get_params(opt_state), metrics=metrics, eps=eps, fields=fields) 269 | 270 | 271 | def opt_viz(opt_problem: Union[OptProblem, List[OptProblem]], metric_config: Dict[str, List[str]]) -> OptViz: 272 | """Optimization visualization panel 
273 | 274 | Args: 275 | opt_problem: An :code:`OptProblem` or list of :code:`OptProblem`'s. 276 | metric_config: A dictionary of titles mapped to lists of metrics to plot in the graph (for overlay) 277 | 278 | Returns: 279 | A tuple of visualization panel, loss curve pipe, and visualization pipes 280 | 281 | """ 282 | opt_problems = [opt_problem] if isinstance(opt_problem, OptProblem) else opt_problem 283 | viz_panel_pipes = {op.sim.name: op.sim.viz_panel() 284 | for op in opt_problems if op.sim is not None and op.source is not None} 285 | costs_pipe = Pipe(data=[]) 286 | 287 | metrics_panel_pipes = {op.sim.name: scalar_metrics_viz(metric_config=metric_config) 288 | for op in opt_problems if op.sim is not None and op.source is not None} 289 | 290 | return OptViz( 291 | cost_dmap=hv.DynamicMap(hv.Curve, streams=[costs_pipe]).opts(title='Cost Fn (𝓛)'), 292 | simulations_panels={name: v[0] for name, v in viz_panel_pipes.items()}, 293 | costs_pipe=costs_pipe, 294 | simulations_pipes={name: v[1] for name, v in viz_panel_pipes.items()}, 295 | metrics_panels={name: m[0] for name, m in metrics_panel_pipes.items()}, 296 | metrics_pipes={name: m[1] for name, m in metrics_panel_pipes.items()}, 297 | metric_config=metric_config 298 | ) 299 | -------------------------------------------------------------------------------- /simphox/parse.py: -------------------------------------------------------------------------------- 1 | from .typing import Excitation, SourceLabel 2 | 3 | 4 | def parse_excitation(excitation: Excitation): 5 | """Parse any excitation format into a list of tuples consisting of port name and mode index. 6 | 7 | Args: 8 | excitation: Excitation of various types 9 | 10 | Returns: 11 | List of tuples consisting of port name and mode index. 12 | 13 | """ 14 | if isinstance(excitation, str): 15 | return [(excitation, 0)] 16 | elif isinstance(excitation, tuple) or isinstance(excitation, list): 17 | if (isinstance(excitation[0], str) or isinstance(excitation[0], int)) and isinstance(excitation[1], int) and len(excitation) == 2: 18 | return [excitation] 19 | return sum([parse_excitation(mi) for mi in excitation], []) 20 | elif isinstance(excitation, dict): 21 | return sum([parse_excitation([(mi, idx) for idx in excitation[mi]]) for mi in excitation], []) 22 | 23 | 24 | def parse_source_port(source: SourceLabel): 25 | """Parse any acceptable source format into a dict between tuples of port name and mode index and weight. 26 | 27 | Args: 28 | source: Source of various types/formats 29 | 30 | Returns: 31 | Dictionary of tuples consisting of port name and mode index mapped to weights. 
32 | 33 | """ 34 | if isinstance(source, str): 35 | return {(source, 0): 1} 36 | elif isinstance(source, tuple) or isinstance(source, list): 37 | if isinstance(source[0], str) and isinstance(source[1], int) and len(source) == 2: 38 | return {tuple(source): 1} 39 | return {k: v for mi in source for k, v in parse_source_port(mi).items()} 40 | elif isinstance(source, dict): 41 | return {(s, 0) if isinstance(s, str) else s: w for s, w in source.items()} 42 | -------------------------------------------------------------------------------- /simphox/primitives.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | try: # pardiso (using Intel MKL) is much faster than scipy's solver 4 | from .mkl import spsolve_pardiso as _spsolve 5 | except OSError: # if mkl isn't installed 6 | from scipy.sparse.linalg import spsolve as _spsolve 7 | 8 | import numpy as np 9 | import scipy.sparse as sp 10 | import jax.numpy as jnp 11 | import jax 12 | from jax.config import config 13 | import jax.experimental.host_callback as hcb 14 | 15 | config.parse_flags_with_absl() 16 | 17 | 18 | def _spsolve_hcb(ab: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]) -> jnp.ndarray: 19 | a_entries, b, a_indices = ab 20 | a = sp.coo_matrix((a_entries, (a_indices[0], a_indices[1])), shape=(b.size, b.size)) 21 | # caching is not necessary since the pardiso spsolve we are using caches the matrix factorization by default 22 | # replace with a better solver if available 23 | return _spsolve(a, b.flatten()) 24 | 25 | 26 | @jax.custom_vjp 27 | def spsolve(a_entries: jnp.ndarray, b: jnp.ndarray, a_indices: jnp.ndarray) -> jnp.ndarray: 28 | return hcb.call(_spsolve_hcb, (a_entries, b, a_indices), result_shape=b) 29 | 30 | 31 | def spsolve_fwd(a_entries: jnp.ndarray, b: jnp.ndarray, a_indices: jnp.ndarray) -> Tuple[jnp.ndarray, 32 | Tuple[jnp.ndarray, ...]]: 33 | x = spsolve(a_entries, b, a_indices) 34 | return x, (a_entries, x, a_indices) 35 | 36 | 37 | def spsolve_bwd(res, g): 38 | a_entries, x, a_indices = res 39 | lambda_ = spsolve(a_entries, g, a_indices[::-1]) 40 | i, j = a_indices 41 | return -lambda_[i] * x[j], lambda_, None 42 | 43 | 44 | spsolve.defvjp(spsolve_fwd, spsolve_bwd) 45 | 46 | 47 | def _coo_to_jnp(mat): 48 | mat.sort_indices() 49 | mat = mat.tocoo() 50 | return jnp.array(mat.data, dtype=np.complex128), jnp.vstack((jnp.array(mat.row), jnp.array(mat.col))) 51 | 52 | 53 | class TMOperator: 54 | """This class generates some helpful TE primitives based on the input discrete derivatives provided by the FDFD 55 | class for a 2D problem. 56 | 57 | Attributes: 58 | df: A list of forward discrete derivative in order (:code:`df_x`, :code:`df_y`, :code:`df_z`). 59 | db: A list of backward discrete derivative in order (:code:`db_x`, :code:`db_y`, :code:`db_z`). 60 | """ 61 | 62 | def __init__(self, df: List[sp.spmatrix], db: List[sp.spmatrix]): 63 | self.df, self.db = df, db 64 | data_x, self.x_indices = _coo_to_jnp(self.df[0] @ self.db[0]) 65 | data_y, self.y_indices = _coo_to_jnp(self.df[1] @ self.db[1]) 66 | self.size = (data_x.size, data_y.size) 67 | self.n = df[0].diagonal().size 68 | 69 | def compile_operator_along_axis(self, axis: int): 70 | """Compiles the TE mode operator along a certain axis (0 or 1) 71 | 72 | Args: 73 | axis: Axis along which to compute the operator. 74 | 75 | Returns: 76 | The contribution to the TE operator along axis 0 or 1 specified by the input. 
77 | 78 | """ 79 | if axis != 0 and axis != 1: 80 | raise ValueError("axis must be either 0 or 1.") 81 | 82 | n = self.n 83 | size = self.size[axis] 84 | a = self.db[axis] 85 | b = self.df[axis] 86 | c_indices = (self.x_indices, self.y_indices)[axis] 87 | 88 | def _te_hcb(t: jnp.ndarray): 89 | tm = sp.diags(t) 90 | c = a.dot(tm).dot(b) 91 | c.sort_indices() 92 | c = c.tocoo() 93 | return c.data 94 | 95 | def _te_backward_hcb(g: jnp.ndarray) -> jnp.ndarray: 96 | g = sp.coo_matrix((g, (c_indices[1], c_indices[0])), shape=(n, n)) 97 | complex_res = b.dot(g.dot(a)).diagonal() 98 | return complex_res.real 99 | 100 | @jax.custom_vjp 101 | def te(t: jnp.ndarray): 102 | return hcb.call(_te_hcb, t, result_shape=jax.ShapeDtypeStruct((size,), np.complex128)) 103 | 104 | def te_fwd(t: jnp.ndarray): 105 | return te(t), None 106 | 107 | def te_bwd(_, g): 108 | v = hcb.call(_te_backward_hcb, g, result_shape=jax.ShapeDtypeStruct((n,), np.float)) 109 | return v, 110 | 111 | te.defvjp(te_fwd, te_bwd) 112 | 113 | return te 114 | -------------------------------------------------------------------------------- /simphox/transform.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | from jax.scipy.signal import convolve as conv 4 | from skimage.draw import disk 5 | 6 | from .typing import Union, List 7 | from .utils import Box 8 | 9 | 10 | def get_smooth_fn(beta: float, radius: float, eta: float = 0.5): 11 | """Using the sigmoid function and convolutional kernel provided in jax, we return a function that 12 | effectively binarizes the design respectively and smooths the density parameters. 13 | 14 | Args: 15 | beta: A multiplicative factor in the tanh function to effectively define how binarized the design should be 16 | radius: The radius of the convolutional kernel for smoothing 17 | eta: The average value of the design 18 | 19 | Returns: 20 | The smoothing function 21 | 22 | """ 23 | rr, cc = disk((radius, radius), radius + 1) 24 | kernel = np.zeros((2 * radius + 1, 2 * radius + 1), dtype=np.float64) 25 | kernel[rr, cc] = 1 26 | kernel = kernel / kernel.sum() 27 | 28 | def smooth(rho: jnp.ndarray): 29 | rho = conv(rho, kernel, mode='same') 30 | return jnp.divide(jnp.tanh(beta * eta) + jnp.tanh(beta * (rho - eta)), 31 | jnp.tanh(beta * eta) + jnp.tanh(beta * (1 - eta))) 32 | 33 | return smooth 34 | 35 | 36 | def get_symmetry_fn(ortho_x: bool = False, ortho_y: bool = False, diag_p: bool = False, diag_n: bool = False, 37 | avg: bool = False): 38 | """Get the array-based reflection symmetry function based on orthogonal or diagonal axes. 
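When :code:`avg` is :code:`False`, each orthogonal symmetry is imposed by reflecting the first half of the
array onto the second half along that axis; when :code:`True`, the array is instead averaged with its mirror
image. Diagonal symmetries always average the array with its (anti-)transpose, so the parameters must be square.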
39 | 40 | Args: 41 | ortho_x: symmetry along x-axis (axis 0) 42 | ortho_y: symmetry along y-axis (axis 1) 43 | diag_p: symmetry along positive ([1, 1] plane) diagonal (shape of params must be square) 44 | diag_n: symmetry along negative ([1, -1] plane) diagonal (shape of params must be square) 45 | avg: Whether the symmetry should take the average (applies to ortho symmetries ONLY) 46 | 47 | Returns: 48 | The overall symmetry function 49 | """ 50 | identity = (lambda x: x) 51 | diag_n_fn = (lambda x: (x + x.T) / 2) if diag_p else identity 52 | diag_p_fn = (lambda x: (x + x[::-1, ::-1].T) / 2) if diag_n else identity 53 | if avg: 54 | ortho_x_fn = (lambda x: (x + x[::-1]) / 2) if ortho_x else identity 55 | ortho_y_fn = (lambda x: (x + x[:, ::-1]) / 2) if ortho_y else identity 56 | else: 57 | ortho_x_fn = (lambda x: x.at[-(x.shape[0] // 2 + 1):, :].set(x[:x.shape[0] // 2 + 1:, :][::-1, :])) if ortho_x else identity 58 | ortho_y_fn = (lambda x: x.at[:, -(x.shape[1] // 2 + 1):].set(x[:, :x.shape[1] // 2 + 1][:, ::-1])) if ortho_y else identity 59 | return lambda x: diag_p_fn(diag_n_fn(ortho_x_fn(ortho_y_fn(x)))) 60 | 61 | 62 | def get_mask_fn(rho_init: jnp.ndarray, box: Union[Box, List[Box]]): 63 | """Given an initial param set, this function defines the box region(s) where the params are allowed to change. 64 | 65 | Args: 66 | rho_init: initial rho definition 67 | box: Box (or list of boxes) defines position and orientation of the design region(s) 68 | 69 | Returns: 70 | The mask function 71 | 72 | """ 73 | mask = box.mask(rho_init) if isinstance(box, Box) else (sum([b.mask(rho_init) for b in box]) > 0).astype(np.float) 74 | return lambda rho: jnp.array(rho_init) * (1 - mask) + rho * mask 75 | -------------------------------------------------------------------------------- /simphox/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, List, Optional, Dict, Callable, Iterable 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import scipy.sparse as sp 6 | 7 | Shape2 = Tuple[int, int] 8 | Shape3 = Tuple[int, int, int] 9 | Size2 = Tuple[float, float] 10 | Size3 = Tuple[float, float, float] 11 | Size4 = Tuple[float, float, float, float] 12 | Shape = Union[Shape2, Shape3] 13 | Size = Union[Size2, Size3] 14 | Spacing = Union[float, Tuple[float, float, float]] 15 | Op = Callable[[np.ndarray], np.ndarray] 16 | SpSolve = Callable[[sp.spmatrix, np.ndarray], np.ndarray] 17 | Array = Union[jnp.ndarray, np.ndarray] 18 | State = Tuple[Array, Array, Optional[List[Array]], Optional[List[Array]]] 19 | MeasureInfo = Dict[str, List[int]] 20 | Excitation = Union[str, Tuple[str, int], Dict[str, List[int]], Iterable[Union[str, Tuple[str, int]]]] 21 | SourceLabel = Union[str, Dict[Tuple[str, int], float], Dict[str, float]] 22 | Source = Tuple[Optional[np.ndarray], Optional[np.ndarray], Tuple[slice, ...]] 23 | PortLabel = Union[str, int] 24 | PhaseParams = Tuple[np.ndarray, np.ndarray, np.ndarray] 25 | IndexSelect = Tuple[Union[slice, np.ndarray], ...] 
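# Example (illustrative only; `place_source` below is a hypothetical signature, not part of simphox):
#
#     def place_source(center: Size3, size: Size3, spacing: Spacing) -> Source:
#         ...
#
# i.e. `Size3` annotates 3-tuples of floats, `Spacing` is either a scalar or a per-axis spacing, and
# `Source` bundles two optional arrays with the tuple of slices describing their placement.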
26 | -------------------------------------------------------------------------------- /simphox/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import unitary_group 3 | import scipy.sparse as sp 4 | from pydantic.dataclasses import dataclass 5 | import jax.numpy as jnp 6 | 7 | from typing import Tuple, Union, Optional 8 | from copy import deepcopy 9 | import xarray as xr 10 | from scipy.special import beta as beta_func 11 | 12 | from .typing import List, Callable, Size2, Size3 13 | 14 | SMALL_NUMBER = 1e-20 15 | 16 | 17 | def fix_dataclass_init_docs(cls): 18 | """Fix the ``__init__`` documentation for a :class:`dataclasses.dataclass`. 19 | 20 | See Also: 21 | https://github.com/agronholm/sphinx-autodoc-typehints/issues/123 22 | 23 | Attributes: 24 | cls: The class whose docstring needs fixing 25 | 26 | Returns: 27 | The class that was passed so this function can be used as a decorator 28 | """ 29 | cls.__init__.__qualname__ = f'{cls.__name__}.__init__' 30 | return cls 31 | 32 | 33 | @fix_dataclass_init_docs 34 | @dataclass 35 | class Material: 36 | """Helper class for materials. 37 | 38 | Attributes: 39 | name: Name of the material. 40 | eps: Constant epsilon (relative permittivity) assigned for the material. 41 | facecolor: Facecolor in red-green-blue (RGB) for drawings (default is black or :code:`(0, 0, 0)`). 42 | """ 43 | name: str 44 | eps: float = 1. 45 | facecolor: Size3 = (0, 0, 0) 46 | 47 | def __str__(self): 48 | return self.name 49 | 50 | 51 | SILICON = Material('Silicon', 3.4784 ** 2, (0.3, 0.3, 0.3)) 52 | POLYSILICON = Material('Poly-Si', 3.4784 ** 2, (0.5, 0.5, 0.5)) 53 | AIR = Material('Air') 54 | OXIDE = Material('Oxide', 1.4442 ** 2, (0.6, 0, 0)) 55 | NITRIDE = Material('Nitride', 1.996 ** 2, (0, 0, 0.7)) 56 | LS_NITRIDE = Material('Low-Stress Nitride', facecolor=(0, 0.4, 1)) 57 | LT_OXIDE = Material('Low-Temp Oxide', 1.4442 ** 2, (0.8, 0.2, 0.2)) 58 | ALUMINUM = Material('Aluminum', facecolor=(0, 0.5, 0)) 59 | ALUMINA = Material('Alumina', 1.75, (0.2, 0, 0.2)) 60 | ETCH = Material('Etch') 61 | 62 | TEST_ZERO = Material('Zero', 0, (0, 0, 0)) 63 | TEST_ONE = Material('One', 1, (0, 0, 0)) 64 | TEST_INF = Material('Inf', 1e10, (0, 0, 0)) 65 | 66 | 67 | @fix_dataclass_init_docs 68 | @dataclass 69 | class Box: 70 | """Helper class for quickly generating functions for design region placements. 71 | 72 | Attributes: 73 | size: size of box 74 | spacing: spacing for pixelation 75 | material: :code:`Material` for this Box 76 | min: min x and min y of box 77 | """ 78 | size: Union[float, Size2] 79 | material: Optional[Material] = None 80 | spacing: float = 1 81 | min: Size2 = (0., 0.) 
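    # Example usage (illustrative sketch; the values and the `eps` array are hypothetical):
    #
    #     design = Box((1., 1.), material=SILICON, spacing=0.1).translate(2, 2)
    #     region = design.mask(eps)  # ones inside the 1 x 1 box with lower-left corner at (2, 2), zeros elsewhere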
82 | 83 | def __post_init_post_parse__(self): 84 | self.size = (self.size, 0) if isinstance(self.size, float) else self.size 85 | self.eps = self.material.eps if self.material is not None else None 86 | 87 | @property 88 | def max(self): 89 | return self.min[0] + self.size[0], self.min[1] + self.size[1] 90 | 91 | @property 92 | def min_i(self): 93 | return int(self.min[0] / self.spacing), int(self.min[1] / self.spacing) 94 | 95 | @property 96 | def max_i(self): 97 | return int(self.max[0] / self.spacing), int(self.max[1] / self.spacing) 98 | 99 | @property 100 | def shape(self): 101 | return self.max_i[0] - self.min_i[0], self.max_i[1] - self.min_i[1] 102 | 103 | @property 104 | def center(self) -> Size2: 105 | return self.min[0] + self.size[0] / 2, self.min[1] + self.size[1] / 2 106 | 107 | @property 108 | def slice(self) -> Tuple[slice, slice]: 109 | return slice(self.min_i[0], self.max_i[0]), slice(self.min_i[1], self.max_i[1]) 110 | 111 | @property 112 | def copy(self) -> "Box": 113 | return deepcopy(self) 114 | 115 | def mask(self, array: Union[np.ndarray, jnp.ndarray]): 116 | mask = np.zeros_like(array) 117 | mask[self.slice[0], self.slice[1]] = 1.0 118 | return mask 119 | 120 | def translate(self, dx: float = 0, dy: float = 0) -> "Box": 121 | self.min = (self.min[0] + dx, self.min[1] + dy) 122 | return self 123 | 124 | def align(self, c: Union["Box", Tuple[float, float]]) -> "Box": 125 | center = c.center if isinstance(c, Box) else c 126 | self.translate(center[0] - self.center[0], center[1] - self.center[1]) 127 | return self 128 | 129 | def halign(self, c: Union["Box", float], left: bool = True, opposite: bool = True): 130 | x = self.min[0] if left else self.max[0] 131 | p = c if isinstance(c, float) or isinstance(c, int) \ 132 | else (c.min[0] if left and not opposite or opposite and not left else c.max[0]) 133 | self.translate(dx=p - x) 134 | return self 135 | 136 | def valign(self, c: Union["Box", float], bottom: bool = True, opposite: bool = True): 137 | y = self.min[1] if bottom else self.max[1] 138 | p = c if isinstance(c, float) or isinstance(c, int) \ 139 | else (c.min[1] if bottom and not opposite or opposite and not bottom else c.max[1]) 140 | self.translate(dy=p - y) 141 | return self 142 | 143 | 144 | def poynting_fn(axis: int = 2, use_jax: bool = False): 145 | ax = np.roll((1, 2, 0), -axis) 146 | xp = jnp if use_jax else np 147 | 148 | def poynting(e: np.ndarray, h: np.ndarray): 149 | e_cross = xp.stack([(e[ax[0]] + xp.roll(e[ax[0]], shift=1, axis=1)) / 2, 150 | (e[ax[1]] + xp.roll(e[ax[1]], shift=1, axis=0)) / 2]) 151 | h_cross = xp.stack([(h[ax[0]] + xp.roll(h[ax[0]], shift=1, axis=0)) / 2, 152 | (h[ax[1]] + xp.roll(h[ax[1]], shift=1, axis=1)) / 2]) 153 | return e_cross[ax[0]] * h_cross.conj()[ax[1]] - e_cross[ax[1]] * h_cross.conj()[ax[0]] 154 | 155 | return poynting 156 | 157 | 158 | def d2curl_op(d: List[sp.spmatrix]) -> sp.spmatrix: 159 | o = sp.csr_matrix((d[0].shape[0], d[0].shape[0])) 160 | return sp.bmat([[o, -d[2], d[1]], 161 | [d[2], o, -d[0]], 162 | [-d[1], d[0], o]]) 163 | 164 | 165 | def curl_fn(df: Callable[[np.ndarray, int], np.ndarray], use_jax: bool = False, beta: float = None): 166 | xp = jnp if use_jax else np 167 | if beta is not None: 168 | def _curl(f: np.ndarray): 169 | return xp.stack([df(f[2], 1) + 1j * beta * f[1], 170 | -1j * beta * f[0] - df(f[2], 0), 171 | df(f[1], 0) - df(f[0], 1)]) 172 | else: 173 | def _curl(f: np.ndarray): 174 | return xp.stack([df(f[2], 1) - df(f[1], 2), 175 | df(f[0], 2) - df(f[2], 0), 176 | df(f[1], 0) - df(f[0], 
1)]) 177 | return _curl 178 | 179 | 180 | def curl_pml_fn(df: Callable[[np.ndarray, int], np.ndarray], use_jax: bool = False): 181 | xp = jnp if use_jax else np 182 | 183 | def _curl(f: np.ndarray, prev_df: np.ndarray, b_pml: np.ndarray): 184 | next_df = xp.stack( 185 | [df(f[2], 1), df(f[1], 2), 186 | df(f[0], 2), df(f[2], 0), 187 | df(f[1], 0), df(f[0], 1)] 188 | ) 189 | return next_df, xp.stack([next_df[0] + prev_df[0] * b_pml[1] - next_df[1] - prev_df[1] * b_pml[2], 190 | next_df[2] + prev_df[2] * b_pml[2] - next_df[3] - prev_df[3] * b_pml[0], 191 | next_df[4] + prev_df[4] * b_pml[0] - next_df[5] - prev_df[5] * b_pml[1]]) 192 | return _curl 193 | 194 | 195 | def yee_avg(params: np.ndarray, shift: int = 1) -> np.ndarray: 196 | p = params 197 | p_x = (p + np.roll(p, shift=shift, axis=1)) / 2 198 | p_y = (p + np.roll(p, shift=shift, axis=0)) / 2 199 | p_z = (p_y + np.roll(p_y, shift=shift, axis=1)) / 2 200 | return np.stack([p_x, p_y, p_z]) 201 | 202 | 203 | def yee_avg_2d_z(params: jnp.ndarray) -> jnp.ndarray: 204 | p = params 205 | p_y = (p + jnp.roll(p, shift=1, axis=0)) / 2 206 | p_z = (p_y + jnp.roll(p_y, shift=1, axis=1)) / 2 207 | return p_z 208 | 209 | 210 | def yee_avg_jax(params: jnp.ndarray) -> jnp.ndarray: 211 | p = params 212 | p_x = (p + jnp.roll(p, shift=1, axis=1)) / 2 213 | p_y = (p + jnp.roll(p, shift=1, axis=0)) / 2 214 | p_z = (p_y + jnp.roll(p_y, shift=1, axis=1)) / 2 215 | return jnp.stack((p_x, p_y, p_z)) 216 | 217 | 218 | def pml_sigma(pos: np.ndarray, thickness: int, exp_scale: float, log_reflection: float, absorption_corr: float): 219 | d = np.vstack(((pos[:-1] + pos[1:]) / 2, pos[:-1])).T 220 | d_pml = np.vstack(( 221 | (d[thickness] - d[:thickness]) / (d[thickness] - pos[0]), 222 | np.zeros_like(d[thickness:-thickness]), 223 | (d[-thickness:] - d[-thickness]) / (pos[-1] - d[-thickness]) 224 | )).T 225 | return (exp_scale + 1) * (d_pml ** exp_scale) * log_reflection / (2 * absorption_corr) 226 | 227 | 228 | # Real-time splitter metrics 229 | def splitter_metrics(sparams: xr.DataArray): 230 | powers = np.abs(sparams) ** 2 231 | return { 232 | 'reflectivity': powers.loc["b0"] / (powers.loc["b0"] + powers.loc["b1"]), 233 | 'transmissivity': powers.loc["b1"] / (powers.loc["b0"] + powers.loc["b1"]), 234 | 'reflection': powers.loc["a0"], 235 | 'insertion': powers.sum(), 236 | 'upper': powers.loc["b0"], 237 | 'lower': powers.loc["b1"], 238 | } 239 | 240 | 241 | def random_vector(n: int, normed: bool = False, is_complex: bool = True): 242 | """Generate a random complex normal tensor. 243 | 244 | Args: 245 | n: Number of inputs. 246 | normed: Whether to norm the random complex vector so that the norm of the vector is 1. 247 | is_complex: Return a complex vector 248 | 249 | Returns: 250 | The random complex normal vector. 251 | 252 | """ 253 | z = random_tensor(n, is_complex) 254 | return z / np.linalg.norm(z) if normed else z 255 | 256 | 257 | def random_tensor(size: Union[int, Tuple], is_complex: bool = True) -> np.ndarray: 258 | """Generate a random complex normal tensor. 259 | 260 | Args: 261 | size: Number of inputs or shape. 262 | is_complex: Return a complex vector 263 | 264 | Returns: 265 | The random complex normal tensor. 266 | 267 | """ 268 | size = (int(size),) if np.isscalar(size) else size 269 | return np.array(0.5 * np.random.randn(*size) + 0.5 * np.random.randn(*size) * 1j) if is_complex else np.random.randn(*size) 270 | 271 | 272 | def random_unitary(n: int) -> np.ndarray: 273 | """Generate a random unitary matrix. 
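The matrix is drawn from the Haar (uniform) measure over the unitary group via :code:`scipy.stats.unitary_group`.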
274 | 275 | Args: 276 | n: Number of inputs and outputs 277 | 278 | Returns: 279 | The random complex normal vector. 280 | 281 | """ 282 | return unitary_group.rvs(n) 283 | 284 | 285 | def normalized_error(u: np.ndarray, use_jax: bool = False): 286 | """Normalized fidelity cost function. 287 | 288 | Args: 289 | u: the true (target) unitary, :math:`U \\in \\mathrm{U}(N)`. 290 | use_jax: Use JAX for the normalized fidelity function (for optimizations) 291 | 292 | Returns: 293 | A function that accepts :code:`uhat` the estimated unitary (not necessarily unitary), :math:`\\widehat{U}` 294 | and returns the fidelity measurement. 295 | 296 | """ 297 | 298 | xp = jnp if use_jax else np 299 | u = jnp.array(u) if use_jax else u 300 | return lambda uhat: xp.sqrt( 301 | 1 - xp.abs(xp.trace(u.conj().T @ uhat)) ** 2 / xp.abs(xp.trace(uhat.conj().T @ uhat)) ** 2) 302 | 303 | 304 | def beta_pdf(x, a, b): 305 | return (x ** (a - 1) * (1 - x) ** (b - 1)) / beta_func(a, b) 306 | 307 | 308 | def beta_phase(theta, a, b): 309 | x = np.cos(theta / 2) ** 2 310 | return beta_pdf(x, a, b) * np.sin(theta / 2) * np.cos(theta / 2) / np.pi 311 | 312 | 313 | def gaussian_fft(profiles: np.ndarray, pulse_width: float, center_wavelength: float, dt: float, 314 | t0: float = None, linear_chirp: float = 0): 315 | """Gaussian FFT for measurement. 316 | 317 | Args: 318 | profiles: profiles measured over time 319 | pulse_width: Gaussian pulse width 320 | center_wavelength: center wavelength 321 | dt: time step size 322 | t0: peak time (default to be central time step) 323 | linear_chirp: linear chirp coefficient (default to be 0) 324 | 325 | Returns: 326 | the Gaussian source discretized in time 327 | 328 | """ 329 | k0 = 2 * np.pi / center_wavelength 330 | t = np.arange(profiles.shape[0]) * dt 331 | t0 = t[t.size // 2] if t0 is None else t0 332 | g = np.fft.fft(np.exp(1j * k0 * (t - t0)) * np.exp((-pulse_width + 1j * linear_chirp) * (t - t0) ** 2)) 333 | return np.fft.ifft(g * profiles, axis=0) 334 | 335 | 336 | def gaussian_fn(wavelength: float, pulse_width: float = 0, fwidth: float = np.inf, 337 | start_time: float = 0, center_time_factor: float = 5.0, linear_chirp: float = 0): 338 | """A Gaussian function for sources. 339 | 340 | Args: 341 | wavelength: The carrier wavelength for the electromagnetic radiation. 342 | pulse_width: The Gaussian envelope pulse width :math:`w` in wavelength units. 343 | fwidth: The Gaussian envelope pulse width in :math:`w_f = 2 \\pi / w` frequency units. 344 | start_time: The start time for the Gaussian. 345 | center_time_factor: Decide the time :math:`t_0`: to center the Gaussian, 346 | such that :code:`t0 = center_factor * k0`. 347 | linear_chirp: linear chirp coefficient (default to be 0) 348 | 349 | Returns: 350 | 351 | """ 352 | if pulse_width <= 0 and pulse_width == np.inf: 353 | raise ValueError("Bandwidth must be positive or fwidth must be noninfinite.") 354 | 355 | fwidth = 2 * np.pi / pulse_width if pulse_width > 0 else fwidth 356 | pulse_width = 2 * np.pi / fwidth 357 | k0 = 2 * np.pi / wavelength 358 | t0 = start_time + pulse_width * center_time_factor 359 | 360 | def _gaussian(t): 361 | return np.exp(1j * k0 * (t - t0)) * np.exp((-fwidth + 1j * linear_chirp) * (t - t0) ** 2)\ 362 | if t > start_time else 0 363 | 364 | return _gaussian 365 | 366 | 367 | def shift_slice(slice_to_shift: Tuple[Union[slice, int], ...], shift: int = 1, axis=0): 368 | """Shift slice tuple by some amount. 
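For example, shifting :code:`(slice(3, 5), 2)` by :code:`shift=1` along :code:`axis=1` gives :code:`(slice(3, 5), 3)`.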
369 | 370 | Args: 371 | slice_to_shift: 372 | shift: Shift 373 | axis: Axis to shift (ignore if the slice start OR stop is None) 374 | 375 | Returns: 376 | 377 | """ 378 | slices = list(slice_to_shift) 379 | if isinstance(slices[axis], int): 380 | slices[axis] += shift 381 | elif isinstance(slices[axis], slice): 382 | if isinstance(slices[axis].start, int) and isinstance(slices[axis].stop, int): 383 | slices[axis] += slice(slices[axis].start + shift, slices[axis].stop + shift) 384 | return tuple(slices) 385 | -------------------------------------------------------------------------------- /simphox/viz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | 4 | import xarray 5 | from .typing import Tuple, Optional, List 6 | 7 | try: 8 | HOLOVIEWS_IMPORTED = True 9 | import holoviews as hv 10 | from holoviews.streams import Pipe 11 | from holoviews import opts 12 | import panel as pn 13 | from bokeh.models import Range1d, LinearAxis 14 | from bokeh.models.renderers import GlyphRenderer 15 | from bokeh.plotting.figure import Figure 16 | except ImportError: 17 | HOLOVIEWS_IMPORTED = False 18 | 19 | try: 20 | K3D_IMPORTED = True 21 | import k3d 22 | from k3d import Plot 23 | except ImportError: 24 | K3D_IMPORTED = False 25 | 26 | from matplotlib import colors as mcolors 27 | 28 | 29 | def _plot_twinx_bokeh(plot, _): 30 | """Hook to plot data on a secondary (twin) axis on a Holoviews Plot with Bokeh backend. 31 | 32 | Args: 33 | plot: Holoviews plot object to hook for twinx 34 | 35 | See Also: 36 | The code was copied from a comment in https://github.com/holoviz/holoviews/issues/396. 37 | - http://holoviews.org/user_guide/Customizing_Plots.html#plot-hooks 38 | - https://docs.bokeh.org/en/latest/docs/user_guide/plotting.html#twin-axes 39 | 40 | """ 41 | fig: Figure = plot.state 42 | glyph_first: GlyphRenderer = fig.renderers[0] # will be the original plot 43 | glyph_last: GlyphRenderer = fig.renderers[-1] # will be the new plot 44 | right_axis_name = "twiny" 45 | # Create both axes if right axis does not exist 46 | if right_axis_name not in fig.extra_y_ranges.keys(): 47 | # Recreate primary axis (left) 48 | y_first_name = glyph_first.glyph.y 49 | y_first_min = glyph_first.data_source.data[y_first_name].min() 50 | y_first_max = glyph_first.data_source.data[y_first_name].max() 51 | y_first_offset = (y_first_max - y_first_min) * 0.1 52 | fig.y_range = Range1d( 53 | start=y_first_min - y_first_offset, 54 | end=y_first_max + y_first_offset 55 | ) 56 | fig.y_range.name = glyph_first.y_range_name 57 | # Create secondary axis (right) 58 | y_last_name = glyph_last.glyph.y 59 | y_last_min = glyph_last.data_source.data[y_last_name].min() 60 | y_last_max = glyph_last.data_source.data[y_last_name].max() 61 | y_last_offset = (y_last_max - y_last_min) * 0.1 62 | fig.extra_y_ranges = {right_axis_name: Range1d( 63 | start=y_last_min - y_last_offset, 64 | end=y_last_max + y_last_offset 65 | )} 66 | fig.add_layout(LinearAxis(y_range_name=right_axis_name, axis_label=glyph_last.glyph.y), "right") 67 | # Set right axis for the last glyph added to the figure 68 | glyph_last.y_range_name = right_axis_name 69 | 70 | 71 | def get_extent_2d(shape, spacing: Optional[float] = None): 72 | """2D extent 73 | 74 | Args: 75 | shape: shape of the elements to plot 76 | spacing: spacing between grid points (assumed to be isotropic) 77 | 78 | Returns: 79 | The extent in 2D. 
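For example, :code:`get_extent_2d((4, 3), spacing=0.5)` returns :code:`(0, 2.0, 0, 1.5)`, i.e. a
matplotlib-style :code:`(left, right, bottom, top)` extent.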
80 | 81 | """ 82 | return (0, shape[0] * spacing, 0, shape[1] * spacing) if spacing else (0, shape[0], 0, shape[1]) 83 | 84 | 85 | def plot_eps_2d(ax, eps: np.ndarray, spacing: Optional[float] = None, cmap: str = 'gray'): 86 | """Plot eps in 2D 87 | 88 | Args: 89 | ax: Matplotlib axis handle 90 | eps: epsilon permittivity 91 | spacing: spacing between grid points (assumed to be isotropic) 92 | cmap: colormap for field array (we highly recommend RdBu) 93 | 94 | """ 95 | extent = get_extent_2d(eps.shape, spacing) 96 | if spacing: # in microns! 97 | ax.set_ylabel(r'$y$ ($\mu$m)') 98 | ax.set_xlabel(r'$x$ ($\mu$m)') 99 | ax.imshow(eps.T, cmap=cmap, origin='lower', alpha=1, extent=extent) 100 | 101 | 102 | def plot_field_2d(ax, field: np.ndarray, eps: Optional[np.ndarray] = None, spacing: Optional[float] = None, 103 | cmap: str = 'RdBu', mat_cmap: str = 'gray', alpha: float = 0.8, vmax=None): 104 | """Plot field in 2D 105 | 106 | Args: 107 | ax: Matplotlib axis handle 108 | field: field to plot 109 | eps: epsilon permittivity for overlaying field onto materials 110 | spacing: spacing between grid points (assumed to be isotropic) 111 | cmap: colormap for field array (we highly recommend RdBu) 112 | mat_cmap: colormap for eps array (we recommend gray) 113 | alpha: transparency of the plots to visualize overlay 114 | 115 | """ 116 | extent = get_extent_2d(field.shape, spacing) 117 | if spacing: # in microns! 118 | ax.set_ylabel(r'$y$ ($\mu$m)') 119 | ax.set_xlabel(r'$x$ ($\mu$m)') 120 | if eps is not None: 121 | plot_eps_2d(ax, eps, spacing, mat_cmap) 122 | im_val = field 123 | vmax = np.max(im_val * np.sign(field.flat[np.abs(field).argmax()])) if vmax is None else vmax 124 | norm = mcolors.TwoSlopeNorm(vcenter=0, vmin=-vmax, vmax=vmax) 125 | ax.imshow(im_val.T, cmap=cmap, origin='lower', alpha=alpha, extent=extent, norm=norm) 126 | 127 | 128 | def plot_eps_1d(ax, eps: Optional[np.ndarray], spacing: Optional[float] = None, 129 | color: str = 'blue', units: str = r"$\mu$m", axis_label_rotation: float = 90): 130 | """Plot eps in 1D. 131 | 132 | Args: 133 | ax: Matplotlib axis handle 134 | eps: epsilon permittivity for overlaying field onto materials 135 | spacing: spacing between grid points (assumed to be isotropic) 136 | color: Color to plot the epsilon 137 | units: Units for plotting (default microns) 138 | axis_label_rotation: Rotate the axis label in case a plot is made with shared axes. 139 | 140 | """ 141 | x = np.arange(eps.shape[0]) * spacing 142 | if spacing: 143 | ax.set_xlabel(rf'$x$ ({units})') 144 | ax.set_ylabel(r'Relative permittivity ($\epsilon$)', color=color, 145 | rotation=axis_label_rotation, labelpad=15) 146 | ax.plot(x, eps, color=color) 147 | ax.tick_params(axis='y', labelcolor=color) 148 | 149 | 150 | def plot_field_1d(ax, field: np.ndarray, field_name: str, eps: Optional[np.ndarray] = None, 151 | spacing: Optional[float] = None, color: str = 'red', eps_color: str = 'blue', 152 | units: str = r"$\mu$m"): 153 | """Plot field in 1D 154 | 155 | Args: 156 | ax: Matplotlib axis handle. 157 | field: Field to plot. 158 | field_name: Name of the field being plotted 159 | spacing: spacing between grid points (assumed to be isotropic). 160 | color: Color to plot the epsilon 161 | units: Units for plotting (default microns) 162 | 163 | """ 164 | x = np.arange(field.shape[0]) * spacing 165 | if spacing: # in microns! 
166 | ax.set_xlabel(rf'$x$ ({units})') 167 | ax.set_ylabel(rf'{field_name}', color=color) 168 | ax.plot(x, field.real, color=color) 169 | ax.tick_params(axis='y', labelcolor=color) 170 | if eps is not None: 171 | ax_eps = ax.twinx() 172 | plot_eps_1d(ax_eps, eps, spacing, eps_color, units, axis_label_rotation=270) 173 | 174 | 175 | def hv_field_1d(field: np.ndarray, eps: Optional[np.ndarray] = None, spacing: Optional[float] = None, 176 | width: float = 600): 177 | x = np.arange(field.shape[0]) * spacing 178 | field = field.squeeze().real / np.max(np.abs(field)) 179 | c1 = hv.Curve((x, (field + 1) / 2), kdims='x', vdims='field').opts( 180 | width=width, show_grid=True, framewise=True, yaxis='left', ylim=(-1, 1)) 181 | c2 = hv.Curve((x, eps), kdims='x', vdims='eps').opts(width=width, show_grid=True, framewise=True, color='red', 182 | hooks=[_plot_twinx_bokeh]) 183 | return c1 * c2 184 | 185 | 186 | def hv_field_2d(field: np.ndarray, eps: Optional[np.ndarray] = None, spacing: Optional[float] = None, 187 | cmap: str = 'RdBu', mat_cmap: str = 'gray', alpha: float = 0.2, width: float = 600): 188 | extent = get_extent_2d(field.squeeze().T.shape, spacing) 189 | bounds = (extent[0], extent[2], extent[1], extent[3]) 190 | aspect = (extent[3] - extent[2]) / (extent[1] - extent[0]) 191 | field_img = hv.Image(field.squeeze().T.real / np.max(np.abs(field)), 192 | bounds=bounds, vdims='field').opts(cmap=cmap, aspect=aspect, frame_width=width) 193 | eps_img = hv.Image(eps.T / np.max(eps), bounds=bounds).opts(cmap=mat_cmap, alpha=alpha, aspect=aspect, frame_width=width) 194 | return field_img.redim.range(field=(-1, 1)) * eps_img 195 | 196 | 197 | def hv_power_1d(power: np.ndarray, eps: Optional[np.ndarray] = None, spacing: Optional[float] = None, 198 | width: float = 600): 199 | x = np.arange(power.shape[0]) * spacing 200 | power = power.squeeze().real / np.max(np.abs(power)) 201 | c1 = hv.Curve((x, power), kdims='x', vdims='field').opts(width=width, show_grid=True, framewise=True, 202 | yaxis='left', ylim=(-1, 1)) 203 | c2 = hv.Curve((x, eps), kdims='x', vdims='eps').opts(width=width, show_grid=True, framewise=True, color='red', 204 | hooks=[_plot_twinx_bokeh]) 205 | return c1 * c2 206 | 207 | 208 | def hv_power_2d(power: np.ndarray, eps: Optional[np.ndarray] = None, spacing: Optional[float] = None, 209 | cmap: str = 'hot', mat_cmap: str = 'gray', alpha: float = 0.2, width: float = 600): 210 | extent = get_extent_2d(power.squeeze().T.shape, spacing) 211 | bounds = (extent[0], extent[2], extent[1], extent[3]) 212 | aspect = (extent[3] - extent[2]) / (extent[1] - extent[0]) 213 | power_img = hv.Image(power.squeeze().T.real / np.max(np.abs(power)), 214 | bounds=bounds, vdims='power').opts(cmap=cmap, aspect=aspect, frame_width=width) 215 | eps_img = hv.Image(eps.T / np.max(eps), bounds=bounds).opts(cmap=mat_cmap, alpha=alpha, 216 | aspect=aspect, frame_width=width) 217 | return power_img.redim.range(power=(0, 1)) * eps_img 218 | 219 | 220 | def plot_power_2d(ax, power: np.ndarray, eps: Optional[np.ndarray] = None, spacing: Optional[float] = None, 221 | cmap: str = 'hot', mat_cmap: str = 'gray', alpha: float = 0.8, vmax=None): 222 | """Plot the power (computed using Poynting) in 2D 223 | 224 | Args: 225 | ax: Matplotlib axis handle 226 | power: power array of size (X, Y) 227 | eps: epsilon for overlay with materials 228 | spacing: spacing between grid points (assumed to be isotropic) 229 | cmap: colormap for power array 230 | mat_cmap: colormap for eps array (we recommend gray) 231 | alpha: transparency of the 
plots to visualize overlay 232 | 233 | """ 234 | extent = get_extent_2d(power.shape, spacing) 235 | if spacing: # in microns! 236 | ax.set_ylabel(r'$y$ ($\mu$m)') 237 | ax.set_xlabel(r'$x$ ($\mu$m)') 238 | if eps is not None: 239 | plot_eps_2d(ax, eps, spacing, mat_cmap) 240 | vmax = np.max(power) if vmax is None else vmax 241 | ax.imshow(power.T, cmap=cmap, origin='lower', alpha=alpha, extent=extent, vmax=vmax) 242 | 243 | 244 | def plot_power_3d(plot: "Plot", power: np.ndarray, eps: Optional[np.ndarray] = None, axis: int = 0, 245 | spacing: float = 1, color_range: Tuple[float, float] = None, alpha: float = 100, 246 | samples: float = 1200): 247 | """Plot the 3d power in a notebook given the fields :math:`E` and :math:`H`. 248 | 249 | Args: 250 | plot: K3D plot handle (NOTE: this is for plotting in a Jupyter notebook) 251 | power: power (either Poynting field of size (3, X, Y, Z) or power of size (X, Y, Z)) 252 | eps: permittivity (if specified, plot with default options) 253 | axis: pick the correct axis if inputting power in Poynting field 254 | spacing: spacing between grid points (assumed to be isotropic) 255 | color_range: color range for visualization (if none, use half maximum value of field) 256 | alpha: alpha for k3d plot 257 | samples: samples for k3d plot rendering 258 | 259 | Returns: 260 | 261 | """ 262 | 263 | if not K3D_IMPORTED: 264 | raise ImportError("Need to install k3d for this function to work.") 265 | 266 | power = power[axis] if power.ndim == 4 else power 267 | color_range = (0, np.max(power) / 2) if color_range is None else color_range 268 | 269 | if eps is not None: 270 | plot_eps_3d(plot, eps, spacing=spacing) # use defaults for now 271 | 272 | power_volume = k3d.volume( 273 | power.transpose((2, 1, 0)), 274 | alpha_coef=alpha, 275 | samples=samples, 276 | color_range=color_range, 277 | color_map=(np.array(k3d.colormaps.matplotlib_color_maps.hot).reshape(-1, 4)).astype(np.float32), 278 | compression_level=8, 279 | name='power' 280 | ) 281 | 282 | bounds = [0, power.shape[0] * spacing, 0, power.shape[1] * spacing, 0, power.shape[2] * spacing] 283 | power_volume.transform.bounds = bounds 284 | plot += power_volume 285 | 286 | 287 | def plot_field_3d(plot: "Plot", field: np.ndarray, eps: Optional[np.ndarray] = None, axis: int = 1, 288 | imag: bool = False, spacing: float = 1, 289 | alpha: float = 100, samples: float = 1200, color_range: Tuple[float, float] = None): 290 | """ 291 | 292 | Args: 293 | plot: K3D plot handle (NOTE: this is for plotting in a Jupyter notebook) 294 | field: field to plot 295 | eps: permittivity (if specified, plot with default options) 296 | axis: pick the correct axis for power in Poynting vector form 297 | imag: whether to use the imaginary (instead of real) component of the field 298 | spacing: spacing between grid points (assumed to be isotropic) 299 | color_range: color range for visualization (if none, use half maximum value of field) 300 | alpha: alpha for k3d plot 301 | samples: samples for k3d plot rendering 302 | 303 | Returns: 304 | 305 | """ 306 | 307 | if not K3D_IMPORTED: 308 | raise ImportError("Need to install k3d for this function to work.") 309 | 310 | field = field[axis] if field.ndim == 4 else field 311 | field = field.imag if imag else field.real 312 | color_range = np.asarray((0, np.max(field)) if color_range is None else color_range) 313 | 314 | if eps is not None: 315 | plot_eps_3d(plot, eps, spacing=spacing) # use defaults for now 316 | 317 | bounds = [0, field.shape[0] * spacing, 0, field.shape[1] * spacing, 0, 
field.shape[2] * spacing] 318 | 319 | pos_e_volume = k3d.volume( 320 | volume=field.transpose((2, 1, 0)), 321 | alpha_coef=alpha, 322 | samples=samples, 323 | color_range=color_range, 324 | color_map=(np.array(k3d.colormaps.matplotlib_color_maps.RdBu).reshape(-1, 4)).astype(np.float32), 325 | compression_level=8, 326 | name='pos' 327 | ) 328 | 329 | neg_e_volume = k3d.volume( 330 | volume=-field.transpose((2, 1, 0)), 331 | alpha_coef=alpha, 332 | samples=1200, 333 | color_range=color_range, 334 | color_map=(np.array(k3d.colormaps.matplotlib_color_maps.RdBu_r).reshape(-1, 4)).astype(np.float32), 335 | compression_level=8, 336 | name='neg' 337 | ) 338 | 339 | neg_e_volume.transform.bounds = bounds 340 | pos_e_volume.transform.bounds = bounds 341 | 342 | plot += neg_e_volume 343 | plot += pos_e_volume 344 | 345 | 346 | def plot_eps_3d(plot: "Plot", eps: Optional[np.ndarray] = None, spacing: float = 1, 347 | color_range: Tuple[float, float] = None, alpha: float = 100, samples: float = 1200): 348 | """ 349 | 350 | Args: 351 | plot: K3D plot handle (NOTE: this is for plotting in a Jupyter notebook) 352 | eps: relative permittivity 353 | spacing: spacing between grid points (assumed to be isotropic) 354 | color_range: color range for visualization (if none, use half maximum value of field) 355 | alpha: alpha for k3d plot 356 | samples: samples for k3d plot rendering 357 | 358 | Returns: 359 | 360 | """ 361 | 362 | if not K3D_IMPORTED: 363 | raise ImportError("Need to install k3d for this function to work.") 364 | 365 | color_range = (1, np.max(eps)) if color_range is None else color_range 366 | 367 | eps_volume = k3d.volume( 368 | eps.transpose((2, 1, 0)), 369 | alpha_coef=alpha, 370 | samples=samples, 371 | color_map=(np.array(k3d.colormaps.matplotlib_color_maps.Greens).reshape(-1, 4)).astype(np.float32), 372 | compression_level=8, 373 | color_range=color_range, 374 | name='epsilon' 375 | ) 376 | 377 | bounds = [0, eps.shape[0] * spacing, 0, eps.shape[1] * spacing, 0, eps.shape[2] * spacing] 378 | eps_volume.transform.bounds = bounds 379 | plot += eps_volume 380 | 381 | 382 | def scalar_metrics_viz(metric_config: Dict[str, List[str]]): 383 | if not HOLOVIEWS_IMPORTED: 384 | raise ImportError("Holoviews not imported, cannot visualize") 385 | metrics_pipe = {title: Pipe(data=xarray.DataArray( 386 | data=np.asarray([[] for _ in metric_config[title]]), 387 | coords={ 388 | 'metric': metric_config[title], 389 | 'iteration': np.arange(0) 390 | }, 391 | dims=['metric', 'iteration'], 392 | name=title 393 | )) for title in metric_config} 394 | metrics_dmaps = [ 395 | hv.DynamicMap(lambda data: hv.Dataset(data).to(hv.Curve, kdims=['iteration']).overlay('metric'), 396 | streams=[metrics_pipe[title]]).opts(opts.Curve(framewise=True, shared_axes=False, title=title)) 397 | for title in metric_config 398 | ] 399 | return pn.Row(*metrics_dmaps), metrics_pipe 400 | -------------------------------------------------------------------------------- /tests/circuit_test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from simphox.utils import random_vector 3 | from simphox.circuit.vector import tree, hessian_fd, hessian_vector_unit, PhaseStyle 4 | from simphox.circuit import cascade, vector_unit, rectangular, balanced_tree 5 | from scipy.stats import unitary_group 6 | import pytest 7 | from itertools import product, zip_longest 8 | 9 | import numpy as np 10 | import jax 11 | import jax.numpy as jnp 12 | from jax import grad 13 | 
jax.config.update('jax_platform_name', 'cpu') 14 | jax.config.update("jax_enable_x64", True) 15 | 16 | 17 | N = [2, 4, 7, 10, 15, 16] 18 | 19 | np.random.seed(0) 20 | RAND_VECS = [random_vector(n, normed=True) for n in N] 21 | RAND_UNITARIES = [unitary_group.rvs(n) for n in N] 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "n, balanced, expected_node_idxs, expected_num_columns, expected_num_top, expected_num_bottom", 26 | [ 27 | (6, True, [2, 0, 1, 3, 4], 3, [3, 1, 1, 1, 1], [3, 2, 1, 2, 1]), 28 | (8, True, [3, 1, 0, 2, 5, 4, 6], 3, [4, 2, 1, 1, 2, 1, 1], [4, 2, 1, 1, 2, 1, 1]), 29 | (11, True, [4, 1, 0, 2, 3, 7, 5, 6, 8, 9], 4, [5, 2, 1, 1, 1, 3, 1, 1, 1, 1], [6, 3, 1, 2, 1, 3, 2, 1, 2, 1]), 30 | (6, False, [4, 3, 2, 1, 0], 5, [1, 2, 3, 4, 5], [1, 1, 1, 1, 1]), 31 | (8, False, [6, 5, 4, 3, 2, 1, 0], 7, [1, 2, 3, 4, 5, 6, 7], [1, 1, 1, 1, 1, 1, 1]), 32 | (11, False, [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 10, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 33 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 34 | ] 35 | ) 36 | def test_tree_network(n: int, balanced: bool, expected_node_idxs: np.ndarray, expected_num_columns: np.ndarray, 37 | expected_num_top: np.ndarray, expected_num_bottom: np.ndarray): 38 | circuit = tree(n, balanced=balanced) 39 | np.testing.assert_allclose(circuit.node_idxs, expected_node_idxs) 40 | np.testing.assert_allclose(circuit.num_columns, expected_num_columns) 41 | np.testing.assert_allclose(circuit.beta, expected_num_top) 42 | np.testing.assert_allclose(circuit.alpha, expected_num_bottom) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "v, balanced, phase_style", 47 | product(RAND_VECS, [True, False], [PhaseStyle.TOP, PhaseStyle.BOTTOM]) 48 | ) 49 | def test_vector_configure(v: np.ndarray, balanced: bool, phase_style: PhaseStyle): 50 | np.random.seed(0) 51 | circuit, _ = vector_unit(v, balanced=balanced, phase_style=phase_style) 52 | res = circuit.matrix_fn(use_jax=False)(circuit.params) @ v 53 | np.testing.assert_allclose(res, np.eye(v.size)[v.size - 1], atol=1e-6) 54 | 55 | 56 | @pytest.mark.parametrize( 57 | "u, balanced", 58 | product(RAND_UNITARIES, [True, False]) 59 | ) 60 | def test_unitary_configure(u: np.ndarray, balanced: bool): 61 | circuit = cascade(u, balanced=balanced) 62 | np.testing.assert_allclose(circuit.matrix(), u, atol=1e-6) 63 | 64 | 65 | @pytest.mark.parametrize( 66 | "u, num_columns", 67 | zip_longest(RAND_UNITARIES, [2 * n - 3 for n in N]) 68 | ) 69 | def test_triangular_columns(u: np.ndarray, num_columns: int): 70 | circuit = cascade(u, balanced=False) 71 | np.testing.assert_allclose(circuit.num_columns, num_columns, atol=1e-6) 72 | 73 | 74 | @pytest.mark.parametrize( 75 | "u, num_columns", 76 | zip_longest(RAND_UNITARIES, [1, 5, 14, 25, 45, 49]) 77 | ) 78 | def test_cascade_columns(u: np.ndarray, num_columns: int): 79 | circuit = cascade(u, balanced=True) 80 | np.testing.assert_allclose(circuit.num_columns, num_columns, atol=1e-6) 81 | 82 | 83 | @pytest.mark.parametrize( 84 | "u", RAND_UNITARIES 85 | ) 86 | def test_rectangular(u: np.ndarray): 87 | circuit = rectangular(u) 88 | np.testing.assert_allclose(circuit.matrix(), u, atol=1e-6) 89 | 90 | 91 | @pytest.mark.parametrize( 92 | "u", RAND_UNITARIES 93 | ) 94 | def test_inverse(u: np.ndarray): 95 | circuit = rectangular(u) 96 | np.testing.assert_allclose(circuit.matrix(), circuit.matrix(back=True).T, atol=1e-6) 97 | 98 | 99 | @pytest.mark.parametrize( 100 | "u", RAND_UNITARIES 101 | ) 102 | def test_program_null_basis(u: np.ndarray): 103 | circuit = rectangular(u) 104 | basis = circuit.nullification_basis 105 | params = 
copy.deepcopy(circuit.params) 106 | circuit.program_by_null_basis(basis) 107 | for param, param_expected in zip(params, circuit.params): 108 | np.testing.assert_allclose(param, param_expected, rtol=1e-4) 109 | 110 | 111 | @pytest.mark.parametrize( 112 | "u", RAND_UNITARIES 113 | ) 114 | def test_error_correction(u: np.ndarray): 115 | tree = balanced_tree(u) 116 | tree_error = balanced_tree(u, bs_error_mean_std=(0, 0.02)) 117 | x = tree.matrix_fn()()[-1] 118 | np.testing.assert_allclose(np.abs(tree_error.matrix_fn()(inputs=x.conj())[-1]), 1, atol=1e-3) 119 | 120 | 121 | @pytest.mark.parametrize( 122 | "u, balanced", product(RAND_UNITARIES[:3], [True, False]) 123 | ) 124 | def test_hessian(u: np.ndarray, balanced: bool): 125 | h = hessian_vector_unit(u[0], balanced=balanced) 126 | h_fd = hessian_fd(u[0], balanced=balanced) 127 | np.testing.assert_allclose(h, h_fd, atol=1e-4) 128 | 129 | 130 | @pytest.mark.parametrize( 131 | "u, balanced", product(RAND_UNITARIES[:3], [True, False]) 132 | ) 133 | def test_hessian_correlated_error(u: np.ndarray, balanced: bool): 134 | mesh = balanced_tree(u[0]) 135 | vhat = mesh.matrix(params=(mesh.thetas + 0.0001, mesh.phis + 0.0001, mesh.gammas))[-1] 136 | np.testing.assert_allclose( 137 | np.sum(hessian_vector_unit(u[0], balanced=True)), 138 | 2 * np.linalg.norm(u[0] - vhat) ** 2 / 0.0001 ** 2, rtol=1e-3 139 | ) 140 | 141 | @pytest.mark.parametrize( 142 | "u, input_type, all_analog", product(RAND_UNITARIES[:3], ['ones', 'id'], [True, False]) 143 | ) 144 | def test_in_situ_matrix_fn(u: np.ndarray, input_type: str, all_analog: bool): 145 | mesh = rectangular(u) 146 | inputs = jnp.ones(u.shape[0], dtype=jnp.complex64) if input_type == 'ones' else jnp.eye(u.shape[0], dtype=jnp.complex64) 147 | in_situ_matrix_fn = mesh.in_situ_matrix_fn(all_analog=all_analog) 148 | matrix_fn = mesh.matrix_fn(use_jax=True) 149 | 150 | def tr(u): 151 | return jnp.abs(u[0, 0]) ** 2 152 | 153 | def fn(params): 154 | return tr(matrix_fn(params, inputs)) 155 | 156 | def in_situ_fn(params): 157 | return tr(in_situ_matrix_fn(params, inputs)) 158 | grad_fn = grad(fn) 159 | grad_in_situ_fn = grad(in_situ_fn) 160 | for expected, actual in zip(grad_in_situ_fn(mesh.params), grad_fn(mesh.params)): 161 | np.testing.assert_allclose(expected, actual, rtol=1e-5, atol=1e-6) 162 | -------------------------------------------------------------------------------- /tests/fdfd_test.py: -------------------------------------------------------------------------------- 1 | from jax.config import config 2 | import numpy as np 3 | import pytest 4 | 5 | from simphox.fdfd import FDFD 6 | from simphox.typing import Size, Size3, Optional, List, Union 7 | 8 | np.random.seed(0) 9 | 10 | 11 | config.update("jax_enable_x64", True) 12 | # config.update('jax_platform_name', 'cpu') 13 | 14 | EPS_3_3_2 = 1 + np.random.rand(3, 3, 2).astype(np.complex128) 15 | SOURCE_3_3_2 = np.random.rand(54).astype(np.complex128) 16 | EPS_3_3 = 1 + np.random.rand(3, 3).astype(np.complex128) 17 | SOURCE_3_3 = np.random.rand(9).astype(np.complex128) 18 | EPS_10 = 1 + np.random.rand(10) 19 | EPS_6_5_10 = 1 + np.random.rand(6, 5, 10) 20 | EPS_5_5 = 1 + np.random.rand(5, 5) 21 | SOURCE_5_5 = np.random.rand(25).astype(np.complex128) 22 | EPS_10_10 = 1 + np.random.rand(10, 10) 23 | EPS_10_10_10 = 1 + np.random.rand(10, 10, 10) 24 | SOURCE_10 = np.random.rand(10).astype(np.complex128) 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "size, spacing, pml, pml_params, selected_indices, expected_df_data, expected_df_indices", 29 | [ 30 | ((5, 5, 5), 
0.5, 2, (4, -16, 1), [0, 2, 10, 40, 100, 1000, 400, 203, 314], 31 | [0.05134884 + 0.31632416j, 0.05134884 + 0.31632416j, 32 | 0.05134884 + 0.31632416j, 0.05134884 + 0.31632416j, 0.05134884 + 0.31632416j, 2., 33 | 1.74179713 + 0.67062435j, -0.41673505 - 0.81228197j, 0.41673505 + 0.81228197j], 34 | [100, 101, 105, 120, 150, 600, 300, 101, 257]), 35 | ((10, 10), 1, 4, (4, -16, 1), [0, 2, 10, 40, 100, 130, 190, 80], 36 | [0.02567442 + 0.15816208j, 0.02567442 + 0.15816208j, 37 | 0.02567442 + 0.15816208j, 0.87089856 + 0.33531218j, 1., 1., 38 | -0.03404911 - 0.18135536j, 1.], 39 | [10, 11, 15, 30, 60, 75, 95, 50]), 40 | ((20,), 2, 8, (4, -16, 1), [0, 2, 4, 6, 8], 41 | [0.01283721 + 0.07908104j, 0.10418376 + 0.20307049j, 0.43544928 + 0.16765609j, 42 | 0.49971064 + 0.01202487j, 0.5], [1, 2, 3, 4, 5]), 43 | ], 44 | ) 45 | def test_df_pml_selected_indices(size: Size, spacing: Size, pml: Optional[Size], pml_params: Size3, 46 | selected_indices: List[int], 47 | expected_df_data: np.ndarray, expected_df_indices: np.ndarray): 48 | grid = FDFD(size, spacing, pml=pml, pml_params=pml_params, pml_sep=1) 49 | actual_df = grid.deriv_forward 50 | np.testing.assert_allclose(actual_df[0].data[selected_indices], expected_df_data) 51 | np.testing.assert_allclose(actual_df[0].indices[selected_indices], expected_df_indices) 52 | 53 | 54 | @pytest.mark.parametrize( 55 | "size, spacing, pml, pml_params, selected_indices, expected_db_data, expected_db_indices", 56 | [ 57 | ((5, 5, 5), 0.5, 2, (4, -16, 1), [0, 2, 10, 40, 100, 1000, 400, 203, 314], 58 | [-0.02033147 - 0.20062297j, -0.02033147 - 0.20062297j, -0.02033147 - 0.20062297j, -0.02033147 - 0.20062297j, 59 | -0.02033147 - 0.20062297j, 2., 1.44890765 + 0.89357816j, -0.18608182 - 0.58097951j, 60 | 0.18608182 + 0.58097951j], 61 | [900, 901, 905, 920, 950, 500, 200, 1, 157]), 62 | ((10, 10), 1, 4, (4, -16, 1), [0, 2, 10, 40, 100, 130, 190, 80], 63 | [-0.01016574 - 0.10031149j, -0.01016574 - 0.10031149j, 64 | -0.01016574 - 0.10031149j, 0.72445382 + 0.44678908j, 65 | 1., 1., 0.09304091 + 0.29048976j, 1.], 66 | [90, 91, 95, 20, 50, 65, 95, 40]), 67 | ((20,), 2, 8, (4, -16, 1), [0, 2, 4, 6, 8], 68 | [-0.00508287 - 0.05015574j, -0.04652045 - 0.14524488j, -0.36222691 - 0.22339454j, -0.49925823 - 0.01924408j, 69 | -0.5], [9, 0, 1, 2, 3]), 70 | ], 71 | ) 72 | def test_db_pml_selected_indices(size: Size, spacing: Size, pml: Optional[Size], pml_params: Size3, 73 | selected_indices: List[int], 74 | expected_db_data: np.ndarray, expected_db_indices: np.ndarray): 75 | grid = FDFD(size, spacing, pml=pml, pml_params=pml_params, pml_sep=1) 76 | actual_db = grid.deriv_backward 77 | np.testing.assert_allclose(actual_db[0].data[selected_indices], expected_db_data) 78 | np.testing.assert_allclose(actual_db[0].indices[selected_indices], expected_db_indices) 79 | 80 | 81 | @pytest.mark.parametrize( 82 | "size, spacing, eps, pml, pml_params, selected_indices, expected_db_data, expected_db_indices", 83 | [ 84 | ((5, 5, 5), 0.5, EPS_10_10_10, 2, (4, -16, 1), [0, 2, 10, 40, 100, 1000, 400, 203, 314], 85 | [-30.165956 + 0.075539j, 0.071384 - 0.021037j, -0.062418 + 0.016733j, 86 | -16.707972 + 1.761025j, 0.078196 + 0.635662j, 0.102698 + 0.632648j, 87 | -0.062418 + 0.016733j, -0.174223 + 0.088695j, -18.585439 + 4.17361j], 88 | [0, 9, 2009, 3, 2006, 2176, 2039, 1115, 24]), 89 | ((3, 2.5, 5), 0.5, EPS_6_5_10, None, (4, -16, 1), [0, 2, 10, 40, 100, 1000, 400, 203, 314], 90 | [-13.367367, -4., 4., -15.665815, 4., 4., 4., 4., -10.600918], 91 | [0, 9, 609, 3, 606, 726, 639, 365, 24]), 92 | ((10, 10,), 
1, EPS_10_10, None, (4, -16, 1), [0, 2, 10, 40, 100, 130, 190, 80], 93 | [-24.113075, -1., 1., -1., -1., -1., -28.132924, 1.], 94 | [0, 9, 100, 114, 15, 118, 27, 110]), 95 | ((20,), 2, EPS_10, None, (4, -16, 1), [0, 2, 4, 6, 8], 96 | [-21.66702, -18.597956, -21.187809, -26.069936, -30.053552], 97 | [0, 2, 4, 6, 8]) 98 | ], 99 | ) 100 | def test_mat_selected_indices(size: Size, spacing: Size, eps: Union[float, np.ndarray], 101 | pml: Optional[Size], pml_params: Size3, selected_indices: List[int], 102 | expected_db_data: np.ndarray, expected_db_indices: np.ndarray): 103 | grid = FDFD(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1) 104 | actual_mat = grid.mat 105 | np.testing.assert_allclose(actual_mat.data[selected_indices], expected_db_data, rtol=1e-5) 106 | np.testing.assert_allclose(actual_mat.indices[selected_indices], expected_db_indices) 107 | 108 | 109 | @pytest.mark.parametrize( 110 | "size, spacing, eps, pml, pml_params, selected_indices, expected_mat_ez_data, expected_mat_ez_indices", 111 | [ 112 | ((10, 10), 1, EPS_10_10, 4, (4, -16, 1), [0, 2, 10, 40, 100, 130, 190, 80], 113 | [-2.272260e+01 + 0.018885j, 1.784589e-02 - 0.005259j, 114 | 3.050671e-02 - 0.387327j, -6.920820e-01 - 0.492298j, 115 | 3.050671e-02 - 0.387327j, 3.050671e-02 - 0.387327j, 116 | -8.567010e-01 - 0.368334j, 4.355569e-02 - 0.022174j], 117 | [0, 9, 1, 7, 10, 16, 28, 6]), 118 | ((20,), 2, EPS_10, 8, (4, -16, 1), [0, 2, 4, 6, 8], 119 | [-1.909656e+01 + 0.002361j, 4.461473e-03 - 0.001315j, 120 | -2.456868e+01 + 0.030123j, 7.626676e-03 - 0.096832j, 121 | -1.202780e-01 - 0.158007j], [0, 9, 1, 1, 3]), 122 | ], 123 | ) 124 | def test_mat_ez_selected_indices(size: Size, spacing: Size, eps: Union[float, np.ndarray], 125 | pml: Optional[Size], pml_params: Size3, selected_indices: List[int], 126 | expected_mat_ez_data: np.ndarray, expected_mat_ez_indices: np.ndarray): 127 | grid = FDFD(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1) 128 | actual_mat_ez = grid.mat_ez 129 | np.testing.assert_allclose(actual_mat_ez.data[selected_indices], expected_mat_ez_data, rtol=1e-4) 130 | np.testing.assert_allclose(actual_mat_ez.indices[selected_indices], expected_mat_ez_indices, rtol=1e-4) 131 | 132 | 133 | @pytest.mark.parametrize( 134 | "size, spacing, eps, pml, pml_params, selected_indices, expected_mat_hz_data, expected_mat_hz_indices", 135 | [ 136 | ((10, 10), 1, EPS_10_10, 4, (4, -16, 1), [0, 2, 10, 40, 100, 130, 190, 80], 137 | [0.017608 - 0.005189j, 0.013573 - 0.003639j, 0.011075 - 0.003264j, 138 | 0.011218 - 0.003306j, -0.346769 - 0.455543j, -0.305578 - 0.401431j, 139 | -0.503263 - 0.031538j, 0.062817 - 0.062641j], 140 | [90, 10, 92, 98, 30, 36, 48, 26]), 141 | ((20,), 2, EPS_10, 8, (4, -16, 1), [0, 2, 4, 6, 8], 142 | [3.148939e-03 - 0.000928j, 3.358308e-03 - 0.0009j, 143 | -1.645811e+01 + 0.021235j, 5.108325e-03 - 0.064858j, 144 | -8.593715e-02 - 0.112894j], [9, 1, 1, 1, 3]), 145 | ], 146 | ) 147 | def test_mat_hz_selected_indices(size: Size, spacing: Size, eps: Union[float, np.ndarray], 148 | pml: Optional[Size], pml_params: Size3, selected_indices: List[int], 149 | expected_mat_hz_data: np.ndarray, expected_mat_hz_indices: np.ndarray): 150 | grid = FDFD(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1) 151 | actual_mat_hz = grid.mat_hz 152 | np.testing.assert_allclose(actual_mat_hz.data[selected_indices], expected_mat_hz_data, rtol=3e-4) 153 | np.testing.assert_allclose(actual_mat_hz.indices[selected_indices], expected_mat_hz_indices, rtol=3e-4) 154 | 155 | 156 | 
@pytest.mark.parametrize( 157 | "size, spacing, eps, pml, pml_params, src, expected, tm_2d", 158 | [ 159 | ((5, 5), 1, EPS_5_5, None, (4, -16, 1), SOURCE_5_5, 160 | [-0.07167871 - 0.j, -0.14298599 - 0.j, 0.01107421 + 0.j, -0.1477444 - 0.j, 161 | -0.02976408 - 0.j, -0.03151538 - 0.j, -0.1121591 - 0.j, -0.03932945 - 0.j, 162 | -0.11261177 - 0.j, -0.08187067 - 0.j, -0.09351526 - 0.j, -0.10450915 - 0.j, 163 | -0.0381832 - 0.j, -0.15488618 - 0.j, -0.07711778 - 0.j, -0.06523334 - 0.j, 164 | -0.17933909 - 0.j, -0.22449252 - 0.j, 0.03195563 + 0.j, -0.03765066 - 0.j, 165 | -0.04298976 - 0.j, -0.12478716 - 0.j, -0.1828108 - 0.j, -0.22549907 - 0.j, 166 | -0.17781745 - 0.j], False), 167 | ((5,), 0.5, EPS_10, 2, (4, -16, 1), SOURCE_10, 168 | [-0.17089647 - 1.02674552e-05j, -0.16491481 - 2.76797540e-04j, 169 | -0.15285368 - 6.93158247e-03j, -0.09097409 + 4.74820947e-03j, 170 | 0.07009409 - 3.09620835e-04j, -0.33076997 + 4.20679477e-04j, 171 | 0.09635553 - 9.85560547e-04j, -0.08525008 + 3.21510276e-03j, 172 | -0.16716785 - 8.39563779e-03j, -0.16756679 - 1.39350922e-04j], False), 173 | ((5, 5), 1, EPS_5_5, None, (4, -16, 1), SOURCE_5_5, 174 | [-0.12267507, -0.26522232, 0.02130549, -0.20746095, -0.05265192, -0.04341361, -0.17348322, -0.05401931, 175 | -0.16864395, -0.13477245, -0.10720724, -0.13857351, -0.04744684, -0.19776741, -0.09765291, -0.08476391, 176 | -0.2110588, -0.25837641, 0.03507694, -0.04967016, -0.07139072, -0.20058036, -0.25911829, -0.25718025, 177 | -0.24592517], True), 178 | ((5,), 0.5, EPS_10, 2, (4, -16, 1), SOURCE_10, 179 | [-0.19882722 + 7.14214969e-05j, -0.24548843 - 8.43838043e-04j, -0.21431744 - 9.28589671e-03j, 180 | -0.12438099 + 7.56693867e-03j, 0.07728352 - 9.30301413e-04j, -0.3811222 + 9.69347403e-04j, 181 | 0.12419099 - 2.39105654e-03j, -0.10017246 + 7.44759445e-03j, -0.24136862 - 1.68686342e-02j, 182 | -0.23727428 - 3.29575035e-04j], True), 183 | ], 184 | ) 185 | def test_solve_2d(size: Size, spacing: Size, eps: Union[float, np.ndarray], 186 | pml: Optional[Size], pml_params: Size3, src: np.ndarray, expected: np.ndarray, tm_2d: bool): 187 | grid = FDFD(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1) 188 | actual = grid.solve(src, tm_2d=tm_2d)[int(tm_2d), 2].ravel() # hz if tm_2d, ez if not tm_2d 189 | np.testing.assert_allclose(actual, expected, rtol=1e-4) 190 | 191 | 192 | @pytest.mark.parametrize( 193 | "size, spacing, eps, pml, pml_params, src, expected", 194 | [ 195 | ((1.5, 1.5, 1), 0.5, EPS_3_3_2, None, (4, -16, 1), SOURCE_3_3_2, 196 | [-1.22372285, -0.2033493, -2.93547346, -1.57145287, 197 | 3.60707297, 1.71612853, -0.5059581, 0.47474773, 198 | -1.52122203, -0.17910732, 1.99674375, -1.06473848, 199 | -0.27784698, 0.99227531, 2.72554383, 4.7693412, 200 | -3.30200237, -6.50685786, 0.86746106, -3.05483556, 201 | 5.96681911, 4.1779766, -0.18377168, -4.25069235, 202 | 0.12800289, 0.14021134, -1.86431569, -0.3491506, 203 | 0.66900647, 0.88677132, -0.40220728, 1.94473361, 204 | -5.38257397, -1.9850864, 0.36696246, 2.42108452, 205 | 0.16731768, 0.30936384, 0.93888465, 0.04561664, 206 | -1.42294774, -0.71444213, -0.25400074, -0.1811002, 207 | 0.02535975, -0.93109824, -0.13218408, 1.52497524, 208 | 0.13498791, -0.23291909, 0.19398719, -0.48559814, 209 | -0.62364214, 0.46228806]) 210 | ], 211 | ) 212 | def test_solve_full(size: Size, spacing: Size, eps: Union[float, np.ndarray], 213 | pml: Optional[Size], pml_params: Size3, src: np.ndarray, expected: np.ndarray): 214 | grid = FDFD(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1) 215 | actual = 
grid.solve(src)[0].ravel() # just check the e field 216 | np.testing.assert_allclose(actual, expected, atol=1e-6) 217 | 218 | 219 | @pytest.mark.parametrize( 220 | "size, spacing, eps, pml, pml_params, src", 221 | [ 222 | ((3, 3), 1, EPS_3_3, None, (4, -16, 1), np.ones(27)), 223 | ((5,), 0.5, EPS_10, 2, (4, -16, 1), np.ones(30)), 224 | ], 225 | ) 226 | def test_solve_src_size_error(size: Size, spacing: Size, eps: Union[float, np.ndarray], 227 | pml: Optional[Size], pml_params: Size3, src: np.ndarray): 228 | with pytest.raises(ValueError, match='Expected src.size == '): 229 | FDFD(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1).solve(src) 230 | 231 | 232 | @pytest.mark.parametrize( 233 | "size, spacing, eps, pml, pml_params, src, tm_2d", 234 | [ 235 | # ((3, 3, 2), 1, EPS_3_3_2, None, (4, -16, 1), SOURCE_3_3_2, False), # takes a while, bicgstab test 236 | ((3, 3), 1, EPS_3_3, None, (4, -16, 1), SOURCE_3_3, True), 237 | ((5,), 0.5, EPS_10, 2, (4, -16, 1), SOURCE_10, True), 238 | ((3, 3), 1, EPS_3_3, None, (4, -16, 1), SOURCE_3_3, False), 239 | ((5,), 0.5, EPS_10, 2, (4, -16, 1), SOURCE_10, False) 240 | ] 241 | ) 242 | def test_solve_fn_jax(size: Size, spacing: Size, eps: Union[float, np.ndarray], 243 | pml: Optional[Size], pml_params: Size3, src: np.ndarray, tm_2d: bool): 244 | grid = FDFD(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1) 245 | solve_jax = grid.get_fields_fn(src, tm_2d=tm_2d) 246 | jax_result = solve_jax(grid.eps).ravel() 247 | numpy_result = grid.solve(src, tm_2d=tm_2d).ravel() 248 | np.testing.assert_allclose(jax_result, numpy_result, rtol=2e-3) 249 | -------------------------------------------------------------------------------- /tests/grid_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from typing import List, Tuple, Union, Optional 3 | 4 | from simphox.typing import Shape, Size, Size2, Size3 5 | from simphox.utils import TEST_ZERO, TEST_ONE, Box 6 | from simphox.grid import Grid, YeeGrid 7 | 8 | import numpy as np 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "size, spacing, eps, expected_cell_sizes", 13 | [ 14 | ((2.5, 2.5, 1), 0.5, 1, [np.array([0.5, 0.5, 0.5, 0.5, 0.5]), np.array([0.5, 0.5, 0.5, 0.5, 0.5]), 15 | np.array([0.5, 0.5])]), 16 | ((1, 1), 0.2, 1, [np.array([0.2, 0.2, 0.2, 0.2, 0.2]), np.array([0.2, 0.2, 0.2, 0.2, 0.2]), np.array([1])]), 17 | ((1, 0.8), 0.2, 1, [np.ones(5) * 0.2, np.ones(4) * 0.2, np.array([1])]), 18 | ((15,), 3, 1, [np.ones(5) * 3, np.array([1]), np.array([1])]), 19 | ((5, 6, 6), (1, 2, 3), 1, [np.ones(5) * 1, np.ones(3) * 2, np.ones(2) * 3]) 20 | ], 21 | ) 22 | def test_cell_size(size: Size, spacing: Size, eps: Union[float, np.ndarray], 23 | expected_cell_sizes: List[np.ndarray]): 24 | grid = Grid(size, spacing, eps) 25 | for actual, expected in zip(grid.cells, expected_cell_sizes): 26 | np.testing.assert_allclose(actual, expected) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "size, spacing, eps, expected_pos", 31 | [ 32 | ((2.5, 2.5, 1), 0.5, 1, 33 | [np.array([0, 0.5, 1, 1.5, 2, 2.5]), np.array([0, 0.5, 1, 1.5, 2, 2.5]), np.array([0, 0.5, 1])]), 34 | ((1, 1), 0.2, 1, [np.array([0, 0.2, 0.4, 0.6, 0.8, 1]), np.array([0, 0.2, 0.4, 0.6, 0.8, 1]), np.array([0])]), 35 | ((1, 0.8), 0.2, 1, [np.array([0, 0.2, 0.4, 0.6, 0.8, 1]), np.array([0, 0.2, 0.4, 0.6, 0.8]), np.array([0])]), 36 | ((15,), 3, 1, [np.arange(6) * 3, np.array([0]), np.array([0])]), 37 | ((5, 6, 6), (1, 2, 3), 1, [np.arange(6) * 1, np.arange(4) * 2, np.arange(3) * 3]) 38 | ], 39 
| ) 40 | def test_pos(size: Size, spacing: Size, eps: Union[float, np.ndarray], 41 | expected_pos: List[np.ndarray]): 42 | grid = Grid(size, spacing, eps) 43 | for actual, expected in zip(grid.pos, expected_pos): 44 | np.testing.assert_allclose(actual, expected) 45 | 46 | 47 | @pytest.mark.parametrize( 48 | "size, spacing, eps, expected_spacing", 49 | [ 50 | ((5, 5, 2), 0.5, 1, np.asarray((0.5, 0.5, 0.5))), 51 | ((5, 5), 0.2, 1, np.ones(2) * 0.2), 52 | ((5, 4), 0.2, 1, np.ones(2) * 0.2), 53 | ((5, 3, 2), (1, 2, 3), 1, np.array((1, 2, 3))) 54 | ], 55 | ) 56 | def test_spacing(size: Shape, spacing: Size, 57 | eps: Union[float, np.ndarray], expected_spacing: np.ndarray): 58 | grid = Grid(size, spacing, eps) 59 | np.testing.assert_allclose(grid.spacing, expected_spacing) 60 | 61 | 62 | @pytest.mark.parametrize( 63 | "shape, eps", 64 | [ 65 | ((2, 3), np.asarray(((1, 1), (1, 1)))), 66 | ((2,), np.asarray(((1, 1), (1, 1)))) 67 | ], 68 | ) 69 | def test_error_raised_for_shape_eps_mismatch(shape: Shape, eps: Union[float, np.ndarray]): 70 | with pytest.raises(AttributeError, match='Require grid.shape == eps.shape but got '): 71 | Grid(shape, 1, eps) 72 | 73 | 74 | @pytest.mark.parametrize( 75 | "shape, spacing", 76 | [ 77 | ((2, 3), (1, 1, 1)), 78 | ((2, 3, 2), (1, 1)) 79 | ], 80 | ) 81 | def test_error_raised_for_shape_spacing_mismatch(shape: Shape, spacing: Size): 82 | with pytest.raises(AttributeError, match='Require size.size == ndim == spacing.size but got '): 83 | Grid(shape, spacing) 84 | 85 | 86 | @pytest.mark.parametrize( 87 | "shape, spacing, size", 88 | [ 89 | ((5, 5, 2), 0.5, (2.5, 2.5, 1)), 90 | ((5, 5), 0.2, (1, 1)), 91 | ], 92 | ) 93 | def test_shape(shape: Shape, spacing: Size, size: Size): 94 | grid = Grid(size, spacing) 95 | np.testing.assert_allclose(grid.shape, shape) 96 | 97 | 98 | @pytest.mark.parametrize( 99 | "shape, spacing, size", 100 | [ 101 | ((5, 5, 2), 0.5, (2.5, 2.5, 1)), 102 | ((5, 5, 1), 0.2, (1, 1)), 103 | ], 104 | ) 105 | def test_shape3(shape: Shape, spacing: Size, size: Size): 106 | grid = Grid(size, spacing) 107 | np.testing.assert_allclose(grid.shape3, shape) 108 | 109 | 110 | @pytest.mark.parametrize( 111 | "sim_spacing3, spacing, size", 112 | [ 113 | ((0.5, 0.5, 0.5), 0.5, (2.5, 2.5, 1)), 114 | ((0.2, 0.2, np.inf), 0.2, (1, 1)), 115 | ], 116 | ) 117 | def test_spacing3(sim_spacing3: Size, spacing: Size, size: Size): 118 | grid = Grid(size, spacing) 119 | np.testing.assert_allclose(grid.spacing3, sim_spacing3) 120 | 121 | 122 | @pytest.mark.parametrize( 123 | "sim_size, spacing, center, size, squeezed, expected_slice", 124 | [ 125 | ((2.5, 2.5, 1), 0.5, (1, 1, 1), (0.5, 1, 1), True, [slice(2, 3, None), slice(1, 3, None), slice(1, 3, None)]), 126 | ((2.5, 2.5, 1), 0.5, (1, 1, 1), (0.5, 0.1, 1), True, [slice(2, 3, None), 2, slice(1, 3, None)]), 127 | ( 128 | (2.5, 2.5, 1), 0.5, (1, 1, 1), (0.5, 0.1, 1), False, [slice(2, 3, None), slice(2, 3, None), slice(1, 3, None)]), 129 | ((1, 1), 0.2, (1, 1, 0), (0.5, 1, 1), True, [slice(4, 6, None), slice(3, 8, None), 0]), 130 | ], 131 | ) 132 | def test_slice(sim_size: Shape, spacing: Size, center: Size3, size: Size3, squeezed: bool, 133 | expected_slice: Tuple[Union[slice, int]]): 134 | grid = Grid(sim_size, spacing) 135 | actual = grid.slice(center, size, squeezed=squeezed) 136 | assert tuple(actual) == tuple(expected_slice) 137 | 138 | 139 | @pytest.mark.parametrize( 140 | "size, spacing, pml, expected_df_data, expected_df_indices", 141 | [ 142 | ((1.5, 1.5, 1), 0.5, None, 143 | [2., -2., 2., -2., 2., -2., 2., -2., 2., 
-2., 2., -2., 2., -2., 144 | 2., -2., 2., -2., 2., -2., 2., -2., 2., -2., -2., 2., -2., 2., 145 | -2., 2., -2., 2., -2., 2., -2., 2.], 146 | [6, 0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13, 7, 14, 147 | 8, 15, 9, 16, 10, 17, 11, 12, 0, 13, 1, 14, 2, 15, 3, 16, 4, 148 | 17, 5] 149 | ), 150 | ((3, 3), 1, None, 151 | [1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1., -1., -1., 1., 152 | -1., 1., -1., 1.], 153 | [3, 0, 4, 1, 5, 2, 6, 3, 7, 4, 8, 5, 6, 0, 7, 1, 8, 2] 154 | ), 155 | ((6,), 2, None, 156 | [0.5, -0.5, 0.5, -0.5, 0.5, -0.5], 157 | [1, 0, 2, 1, 0, 2] 158 | ), 159 | ], 160 | ) 161 | def test_df(size: Size, spacing: Size, pml: Optional[Size3], expected_df_data: np.ndarray, 162 | expected_df_indices: np.ndarray): 163 | grid = YeeGrid(size, spacing, pml=pml) 164 | actual_df = grid.deriv_forward 165 | np.testing.assert_allclose(actual_df[0].data, expected_df_data) 166 | np.testing.assert_allclose(actual_df[0].indices, expected_df_indices) 167 | 168 | 169 | @pytest.mark.parametrize( 170 | "size, spacing, pml, expected_db_data, expected_db_indices", 171 | [ 172 | ((1.5, 1.5, 1), 0.5, None, 173 | [-2., 2., -2., 2., -2., 2., -2., 2., -2., 2., -2., 2., 2., 174 | -2., 2., -2., 2., -2., 2., -2., 2., -2., 2., -2., 2., -2., 175 | 2., -2., 2., -2., 2., -2., 2., -2., 2., -2.], 176 | [12, 0, 13, 1, 14, 2, 15, 3, 16, 4, 17, 5, 6, 0, 7, 1, 8, 177 | 2, 9, 3, 10, 4, 11, 5, 12, 6, 13, 7, 14, 8, 15, 9, 16, 10, 178 | 17, 11] 179 | ), 180 | ((3, 3), 1, None, 181 | [-1., 1., -1., 1., -1., 1., 1., -1., 1., -1., 1., -1., 1., 182 | -1., 1., -1., 1., -1.], 183 | [6, 0, 7, 1, 8, 2, 3, 0, 4, 1, 5, 2, 6, 3, 7, 4, 8, 5] 184 | ), 185 | ((6,), 2, None, 186 | [-0.5, 0.5, -0.5, 0.5, -0.5, 0.5], 187 | [2, 0, 0, 1, 1, 2] 188 | ), 189 | ], 190 | ) 191 | def test_db(size: Size, spacing: Size, pml: Optional[Size3], expected_db_data: np.ndarray, 192 | expected_db_indices: np.ndarray): 193 | grid = YeeGrid(size, spacing, pml=pml) 194 | actual_db = grid.deriv_backward 195 | np.testing.assert_allclose(actual_db[0].data, expected_db_data) 196 | np.testing.assert_allclose(actual_db[0].indices, expected_db_indices) 197 | 198 | 199 | @pytest.mark.parametrize( 200 | "waveguide, sub, size, wg_height, spacing, rib_y, vertical, block, gap, seps, expected", 201 | [ 202 | (Box((0.2, 0.4), material=TEST_ZERO), (1.4, 0.2), (1.4, 1), 0.2, 0.2, 0, False, 203 | None, 0.2, (0.2, 0.4), np.array( 204 | [[1., 1., 1., 1., 1.], 205 | [1., 1., 1., 1., 1.], 206 | [1., 0., 0., 1., 1.], 207 | [1., 1., 1., 1., 1.], 208 | [1., 0., 0., 1., 1.], 209 | [1., 1., 1., 1., 1.], 210 | [1., 1., 1., 1., 1.]] 211 | )), 212 | (Box((0.2, 0.4), material=TEST_ZERO), (1.4, 0.2), (1.4, 1), 0.2, 0.2, 0, False, 213 | Box((0.2, 0.2), material=TEST_ZERO), 0.2, (0.2, 0.2), np.array([ 214 | [1, 1, 0, 1, 1], 215 | [1, 1, 1, 1, 1], 216 | [1, 0, 0, 1, 1], 217 | [1, 1, 1, 1, 1], 218 | [1, 0, 0, 1, 1], 219 | [1, 1, 1, 1, 1], 220 | [1, 1, 0, 1, 1] 221 | ])), 222 | (Box((0.2, 0.4), material=TEST_ZERO), (1.4, 0.2), (1.4, 1), 0.2, 0.2, 0, True, 223 | Box((0.2, 0.2), material=TEST_ZERO), 0.2, (0.2, 0.2), np.array([ 224 | [1., 1., 1., 1., 1.], 225 | [1., 1., 1., 1., 1.], 226 | [1., 0., 0., 1., 0.], 227 | [1., 1., 1., 1., 1.], 228 | [1., 0., 0., 1., 0.], 229 | [1., 1., 1., 1., 1.], 230 | [1., 1., 1., 1., 1.] 
231 | ])), 232 | (Box((0.6, 0.6), material=TEST_ZERO), (1, 1), (1, 1), 0.2, 0.2, 0, False, None, 0, 0, np.array([ 233 | [1, 1, 1, 1, 1], 234 | [1, 0, 0, 0, 1], 235 | [1, 0, 0, 0, 1], 236 | [1, 0, 0, 0, 1], 237 | [1, 1, 1, 1, 1] 238 | ])), 239 | (Box((0.4, 0.4), material=TEST_ZERO), (1, 1), (1, 1), 0.2, 0.2, 0, False, None, 0, 0, np.array([ 240 | [1, 1, 1, 1, 1], 241 | [1, 0, 0, 1, 1], 242 | [1, 0, 0, 1, 1], 243 | [1, 1, 1, 1, 1], 244 | [1, 1, 1, 1, 1] 245 | ])), 246 | (Box((0.4, 0.4), material=TEST_ZERO), (1, 0.2), (1, 1), 0.2, 0.2, 0.2, False, None, 0, 0, np.array([ 247 | [1, 0, 1, 1, 1], 248 | [1, 0, 0, 1, 1], 249 | [1, 0, 0, 1, 1], 250 | [1, 0, 1, 1, 1], 251 | [1, 0, 1, 1, 1] 252 | ])), 253 | (Box((0.2, 0.4), material=TEST_ZERO), (1, 0.2), (1, 1), 0.2, 0.2, 0, False, 254 | Box((0.2, 0.4), material=TEST_ZERO), 0, 0.2, np.array([ 255 | [1, 0, 0, 1, 1], 256 | [1, 1, 1, 1, 1], 257 | [1, 0, 0, 1, 1], 258 | [1, 1, 1, 1, 1], 259 | [1, 0, 0, 1, 1] 260 | ])), 261 | (Box((0.4, 0.2), material=TEST_ZERO), (1, 0.2), (1, 1), 0.2, 0.2, 0, True, 262 | Box((0.4, 0.2), material=TEST_ZERO), 0, 0.2, np.array([ 263 | [1., 1., 1., 1., 1.], 264 | [1., 1., 0., 1., 0.], 265 | [1., 1., 0., 1., 0.], 266 | [1., 1., 1., 1., 1.], 267 | [1., 1., 1., 1., 1.] 268 | ])), 269 | ], 270 | ) 271 | def test_block_design_eps_matches_expected(waveguide: Box, sub: Size2, size: Size2, wg_height: float, spacing: float, 272 | rib_y: float, vertical: bool, block: Box, gap: float, 273 | seps: Size2, expected: np.ndarray): 274 | actual = Grid(size, spacing).block_design(waveguide=waveguide, 275 | wg_height=wg_height, 276 | sub_height=wg_height, 277 | sub_eps=TEST_ONE.eps, 278 | gap=gap, 279 | rib_y=rib_y, 280 | block=block, 281 | vertical=vertical, 282 | sep=seps 283 | ).eps 284 | np.testing.assert_allclose(actual, expected) 285 | -------------------------------------------------------------------------------- /tests/mode_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from simphox.mode import ModeSolver, ModeLibrary 5 | from simphox.grid import Grid 6 | from simphox.typing import Size2 7 | from simphox.utils import TEST_ONE, TEST_INF, SILICON, AIR, Box 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "waveguide, sub, size, wg_height, spacing, rib_y, vertical, block, gap, seps, expected_beta", 12 | [ 13 | (Box((0.2, 0.4), material=SILICON), Box((1.4, 0.2), material=AIR), 14 | (1.4, 1), 0.2, 0.2, 0, False, Box((0.2, 0.2), material=SILICON), 0.2, (0.2, 0.4), [8.511193, 8.208999, 15 | 7.785885, 6.190572, 16 | 5.509391, 4.78587]), 17 | (Box((0.2, 0.4), material=SILICON), Box((1.4, 0.2), material=AIR), 18 | (1.4, 1), 0.2, 0.2, 0, False, None, 0.2, (0, 0), [8.497423, 8.193003, 19 | 7.732571, 5.94686, 20 | 5.247048, 3.93605]), 21 | (Box((0.2, 0.4), material=SILICON), Box((1.4, 0.2), material=AIR), 22 | (1.4, 1), 0.2, 0.2, 0, True, Box((0.2, 0.2), material=SILICON), 0.2, (0.2, 0.4), [8.539984, 8.22446, 23 | 7.776173, 6.022444, 24 | 5.283903, 4.618694]), 25 | ], 26 | ) 27 | def test_mode_matches_expected_beta(waveguide: Box, sub: Box, size: Size2, wg_height: float, spacing: float, 28 | rib_y: float, vertical: bool, block: Box, gap: float, seps: Size2, 29 | expected_beta: float): 30 | actual_beta, _ = ModeSolver(size, spacing).block_design(waveguide=waveguide, 31 | wg_height=wg_height, 32 | sub_height=wg_height, 33 | sub_eps=TEST_ONE.eps, 34 | gap=gap, 35 | rib_y=rib_y, 36 | block=block, 37 | vertical=vertical, 38 | sep=seps 39 | ).solve() 40 | 
np.testing.assert_allclose(actual_beta, expected_beta, atol=1e-6) 41 | 42 | 43 | @pytest.mark.skip(reason="This currently fails on Travis CI; unsure why.") 44 | @pytest.mark.parametrize( 45 | "waveguide, sub, size, wg_height, spacing, rib_y, vertical, block, gap, seps, expected_mean, expected_std", 46 | [ 47 | (Box((0.2, 0.4), material=SILICON), Box((1.4, 0.2), material=AIR), 48 | (1.4, 1), 0.2, 0.2, 0, False, Box((0.2, 0.2), material=SILICON), 0.2, (0.2, 0.4), 49 | [0.025206, -0.000639, 0.03445, 0.002703, 0.019679, 0.02683], 50 | [0.105274, 0.107789, 0.094922, 0.121647, 0.147402, 0.099317]), 51 | (Box((0.2, 0.4), material=SILICON), Box((1.4, 0.2), material=AIR), 52 | (1.4, 1), 0.2, 0.2, 0, False, None, 0.2, (0, 0), 53 | [-9.200602e-03, 0, 3.316460e-02, 0, 2.053727e-02, 0], 54 | [0.105418, 0.107831, 0.096243, 0.124457, 0.167402, 0.110291]), 55 | (Box((0.2, 0.4), material=SILICON), Box((1.4, 0.2), material=AIR), 56 | (1.4, 1), 0.2, 0.2, 0, True, Box((0.2, 0.2), material=SILICON), 0.2, (0.2, 0.4), 57 | [-0.02573, -0.001929, 0.03454, 0.00172, 0.021061, -0.025712], 58 | [0.104182, 0.107348, 0.095774, 0.124315, 0.166756, 0.146658]), 59 | ], 60 | ) 61 | def test_mode_matches_expected_mean_std(waveguide: Box, sub: Box, size: Size2, wg_height: float, spacing: float, 62 | rib_y: float, vertical: bool, block: Box, gap: float, seps: Size2, 63 | expected_mean: float, expected_std: float): 64 | _, actual_modes = ModeSolver(size, spacing).block_design(waveguide=waveguide, 65 | wg_height=wg_height, 66 | sub_height=wg_height, 67 | sub_eps=TEST_ONE.eps, 68 | gap=gap, 69 | rib_y=rib_y, 70 | block=block, 71 | vertical=vertical, 72 | sep=seps 73 | ).solve() 74 | actual_mean = np.mean(actual_modes, axis=1).real 75 | actual_std = np.std(actual_modes, axis=1).real 76 | np.testing.assert_allclose(actual_mean, expected_mean, atol=1e-6) 77 | np.testing.assert_allclose(actual_std, expected_std, atol=1e-6) 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "waveguide, size, wg_height, spacing", 82 | [ 83 | (Box((0.36, 0.16), material=TEST_INF), (0.48, 0.24), 0.04, 0.02), 84 | (Box((0.16, 0.36), material=TEST_INF), (0.36, 0.64), 0.06, 0.01), 85 | (Box((0.36, 0), material=TEST_INF), (0.48,), 0.06, 0.001), # need very high resolution to match the analytical result in the 1D mode case 86 | ], 87 | ) 88 | def test_mode_matches_expected_analytical_2d(waveguide: Box, size: Size2, wg_height: float, spacing: float): 89 | modes = ModeLibrary.from_block_design( 90 | size=size, 91 | spacing=spacing, 92 | waveguide=waveguide, 93 | wg_height=wg_height 94 | ) 95 | wg_grid = Grid(waveguide.size, spacing) 96 | y, x, z = np.meshgrid(wg_grid.pos[1], wg_grid.pos[0] + spacing / 2, wg_grid.pos[2]) 97 | if len(size) == 2: 98 | analytical = (np.sin(y / waveguide.size[1] * np.pi) * np.sin(x / waveguide.size[0] * np.pi))[:-1, :-1].squeeze() 99 | else: 100 | analytical = np.sin((x - spacing / 2) / waveguide.size[0] * np.pi)[1:].squeeze() 101 | numerical = np.abs(modes.h(0)[1][modes.eps == 1e10].reshape(analytical.shape)) 102 | numerical = numerical / np.max(numerical) 103 | np.testing.assert_allclose(numerical, analytical, rtol=1e-2, atol=1e-2) 104 | -------------------------------------------------------------------------------- /tests/primitives_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from jax import vjp 3 | import jax.numpy as jnp 4 | from jax.config import config 5 | import jax.test_util as jtu 6 | import numpy as np 7 | import scipy.sparse as sp 8 | from scipy.sparse.linalg import spsolve as
spsolve_scipy 9 | 10 | from simphox.primitives import spsolve, TMOperator 11 | 12 | np.random.seed(0) 13 | config.update("jax_enable_x64", True) 14 | config.update('jax_platform_name', 'cpu') 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "mat, v", 19 | [ 20 | (sp.spdiags(np.array([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]]), [0, 1], 5, 5), np.ones(5, dtype=np.complex128)), 21 | (sp.spdiags(np.array([[1, 2, 3, 8], [6, 5, 8, 300]]), [0, 1], 4, 4).transpose(), np.arange(4, dtype=np.complex128)) 22 | ], 23 | ) 24 | def test_spsolve_matches_scipy(mat: sp.spmatrix, v: np.ndarray): 25 | mat = mat.tocsr() 26 | expected = spsolve_scipy(mat, v) 27 | mat = mat.tocoo() 28 | mat_entries = jnp.array(mat.data, dtype=np.complex128) 29 | mat_indices = jnp.vstack((jnp.array(mat.row), jnp.array(mat.col))) 30 | x = spsolve(mat_entries, jnp.array(v), mat_indices) 31 | np.testing.assert_allclose(x, expected) 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "mat, v, g, expected", 36 | [ 37 | (sp.spdiags(np.array([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]]), [0, 1], 5, 5), np.ones(5, dtype=np.complex128), 38 | np.ones(5, dtype=np.complex128), np.array([(1, -2, 17 / 3, -12.5, 25.2)], dtype=np.complex128)), 39 | ], 40 | ) 41 | def test_spsolve_vjp_b(mat: sp.spmatrix, v: np.ndarray, g: np.ndarray, expected: np.ndarray): 42 | mat = mat.tocoo() 43 | mat_entries = jnp.array(mat.data, dtype=np.complex128) 44 | mat_indices = jnp.vstack((jnp.array(mat.row), jnp.array(mat.col))) 45 | _, vjp_fun = vjp(lambda x: spsolve(mat_entries, jnp.asarray(x), mat_indices), v) 46 | np.testing.assert_allclose(vjp_fun(g), expected) 47 | 48 | 49 | @pytest.mark.parametrize( 50 | "mat, v, g, expected", 51 | [ 52 | (sp.spdiags(np.array([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]]), [0, 1], 5, 5), 53 | np.ones(5, dtype=np.complex128), np.ones(5, dtype=np.complex128), 54 | np.array([[-726, -276, -221, -112.5, -181.44, 138, 78, 51, 90]], dtype=np.complex128) / 36), 55 | ], 56 | ) 57 | def test_spsolve_vjp_mat(mat: sp.spmatrix, v: np.ndarray, g: np.ndarray, expected: np.ndarray): 58 | mat = mat.tocoo() 59 | mat_entries = jnp.array(mat.data, dtype=np.complex128) 60 | mat_indices = jnp.vstack((jnp.array(mat.row), jnp.array(mat.col))) 61 | _, vjp_fun = vjp(lambda x: spsolve(x, jnp.asarray(v), mat_indices), mat_entries) 62 | np.testing.assert_allclose(vjp_fun(g), expected) 63 | 64 | 65 | # These only work when run individually at the moment... 
66 | 67 | 68 | @pytest.mark.skip(reason="This currently fails at the test tree column...") 69 | @pytest.mark.parametrize( 70 | "mat1, mat2, v", 71 | [ 72 | (sp.spdiags(np.array([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]]), [0, 1], 5, 5), 73 | sp.spdiags(np.array([[6, 5, 8, 9, 10], [1, 2, 3, 4, 5]]), [0, 1], 5, 5), 74 | np.ones(5, dtype=np.complex128)), 75 | ], 76 | ) 77 | def test_tmoperator_numerical_grads(mat1: sp.spmatrix, mat2: sp.spmatrix, v: np.ndarray): 78 | operator = TMOperator([mat1, mat2], [mat2, mat1]) 79 | op = operator.compile_operator_along_axis(axis=0) 80 | 81 | def f(x): 82 | return jnp.sum(op(x)).real 83 | jtu.check_grads(f, (v,), order=1, modes=['rev']) 84 | 85 | 86 | @pytest.mark.skip(reason="This currently fails at the test tree column...") 87 | @pytest.mark.parametrize( 88 | "mat, v", 89 | [ 90 | (sp.spdiags(np.array([[1, 2, 3, 4, 5], [6, 5, 8, 9, 10]]), [0, 1], 5, 5), 91 | np.ones(5, dtype=np.complex128)), 92 | ], 93 | ) 94 | def test_spsolve_numerical_grads(mat, v): 95 | mat = mat.tocoo() 96 | mat_entries = jnp.array(mat.data, dtype=np.complex128) 97 | mat_indices = jnp.vstack((jnp.array(mat.row), jnp.array(mat.col))) 98 | 99 | def f(x): 100 | return jnp.sum(spsolve(x, jnp.asarray(v), mat_indices).real) 101 | jtu.check_grads(f, (mat_entries,), order=1, modes=['rev']) 102 | -------------------------------------------------------------------------------- /tests/sim_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from simphox.grid import Port 5 | from simphox.sim import SimGrid 6 | from simphox.typing import Size, Size2, Size3, Size4, Union, List, MeasureInfo 7 | 8 | np.random.seed(0) 9 | 10 | EPS_20_20 = 1 + np.random.rand(20, 20) + np.random.rand(20, 20) * 1j 11 | SOURCE_20_20 = np.random.rand(2, 3, 20, 20) + np.random.rand(2, 3, 20, 20) * 1j 12 | EPS_20_20_15 = 1 + np.random.rand(20, 20, 15) + np.random.rand(20, 20, 15) * 1j 13 | SOURCE_20_20_15 = np.random.rand(2, 3, 20, 20, 15) + np.random.rand(2, 3, 20, 20, 15) * 1j 14 | EPS_40_40_30 = 1 + np.random.rand(40, 40, 30) + np.random.rand(40, 40, 30) * 1j 15 | SOURCE_40_40_30 = np.random.rand(2, 3, 40, 40, 30) + 1j * np.random.rand(2, 3, 40, 40, 30) 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "size, spacing, eps, port_list, fields, measure_info, profile_size_factor, expected_params", 20 | [ 21 | ((20, 20), 1, EPS_20_20, [(7, 7, 0, 2), (13, 13, 0, 2)], SOURCE_20_20, None, 2, 22 | [[-0.287084 + 1.077556j, 0.532823 + 0.852614j], [-0.444322 + 0.312377j, 0.018376 + 0.435319j]]), 23 | ((20, 20, 15), 1, EPS_20_20_15, [(7, 7, 0, 2), (13, 13, 0, 2)], SOURCE_20_20_15, 24 | None, 2, [[0.587553 + 6.300795j, 4.783141 + 4.648525j], [-0.607148 - 0.019959j, 0.454386 + 0.905611j]]), 25 | ((20, 20, 15), 0.5, EPS_40_40_30, [(7, 7, 0, 2, 7, 3), (13, 13, 0, 2, 8, 3)], # x,y,a,w,z,h 26 | SOURCE_40_40_30, None, 2, 27 | [[-1.613978 - 1.048766j, 4.775639 + 1.626212j], [-4.093684 - 5.091917j, 1.649647 + 1.33246j]]) 28 | ] 29 | ) 30 | def test_measure_fn(size: Size, spacing: Size, eps: Union[float, np.ndarray], 31 | port_list: List[Size4], fields: np.ndarray, measure_info: MeasureInfo, profile_size_factor: float, 32 | expected_params: np.ndarray): 33 | grid = SimGrid(size, spacing, eps=eps, pml_sep=1) 34 | grid.port = {i: Port(*port_tuple) for i, port_tuple in enumerate(port_list)} 35 | measure_fn = grid.get_measure_fn(measure_port=measure_info) 36 | actual = measure_fn(fields) 37 | np.testing.assert_allclose(actual, expected_params, rtol=1e-5) 38 | 39 | 40 | 
@pytest.mark.parametrize( 41 | "size, spacing, eps, pml, pml_params, port_list," 42 | "fields, measure_info, profile_size_factor, expected_center, expected_size," 43 | "expected_mode_mean_std", 44 | [ 45 | ((20, 20), 1, EPS_20_20, None, (4, -16, 1, 1), [(7, 7, 0, 2), (13, 13, 0, 2)], 46 | SOURCE_20_20, None, 2, [(7, 7, 0), (13, 13, 0)], 47 | [(0, 4, 0), (0, 4, 0)], [[-0.085174, 0.408666], [-0.109273, 0.420226]]), 48 | ((20, 20, 15), 1, EPS_20_20_15, None, (4, -16, 1, 1), [(7, 7, 0, 2, 7, 3), (13, 13, 0, 2, 8, 3)], 49 | SOURCE_20_20_15, None, 2, [(7, 7, 7), (13, 13, 8)], [(0, 4, 6), (0, 4, 6)], 50 | [[0.027211, 0.117614], [0.010568, 0.118461]]), 51 | ((20, 20), 1, EPS_20_20, 8, (4, -16, 1, 1), [(7, 7, 0, 2), (13, 13, 0, 2)], 52 | SOURCE_20_20, None, 2, [(9, 9, 0), (10, 10, 0)], 53 | [(0, 4, 0), (0, 4, 0)], [[0.330118, 0.374979], [0.34629, 0.355712]]), 54 | ((20, 20), 1, EPS_20_20, 8, (4, -16, 1, 1), [(7, 7, 0, 2), (13, 13, 90, 2)], 55 | SOURCE_20_20, None, 2, [(9, 9, 0), (10, 10, 0)], 56 | [(0, 4, 0), (4, 0, 0)], [[0.330118, 0.374979], [0.200134, 0.40919]]), 57 | ((20, 20, 15), 0.5, EPS_40_40_30, None, (4, -16, 1, 1), 58 | [(7, 7, 0, 2, 7, 3), (13, 13, 90, 2, 8, 4)], SOURCE_40_40_30, None, 2, [(7, 7, 7), (13, 13, 8)], 59 | [(0, 4, 6), (4, 0, 8)], [[0.000196, 0.061155], [-0.012119, 0.049334]]) 60 | ] 61 | ) 62 | def test_port_modes(size: Size, spacing: Size, eps: Union[float, np.ndarray], pml: Size, pml_params: Size3, 63 | port_list: List[Size4], fields: np.ndarray, measure_info: MeasureInfo, profile_size_factor: float, 64 | expected_center: List[Size3], expected_size: List[Size3], expected_mode_mean_std: List[Size2]): 65 | grid = SimGrid(size, spacing, eps=eps, pml=pml, pml_params=pml_params, pml_sep=1) 66 | grid.port = {i: Port(*port_tuple) for i, port_tuple in enumerate(port_list)} 67 | port_modes = grid.port_modes(profile_size_factor=profile_size_factor) 68 | actual_center = [port_modes[i].center for i in port_modes] 69 | actual_size = [port_modes[i].size for i in port_modes] 70 | actual_mode_mean_std = [(np.mean(port_modes[i].io.modes[0]).real, np.std(port_modes[i].io.modes[0]).real) 71 | for i in port_modes] 72 | np.testing.assert_allclose(actual_center, expected_center) 73 | np.testing.assert_allclose(actual_size, expected_size) 74 | np.testing.assert_allclose(actual_mode_mean_std, expected_mode_mean_std, atol=1e-6) 75 | -------------------------------------------------------------------------------- /tests/transform_test.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pytest 4 | 5 | from simphox.transform import get_mask_fn, get_smooth_fn, get_symmetry_fn 6 | from simphox.typing import List, Union 7 | from simphox.utils import Box 8 | 9 | np.random.seed(0) 10 | 11 | TEST_ARRAY = jnp.array([[4., 0., 4., 4., 0.], 12 | [0., 4., 0., 4., 4.], 13 | [4., 0., 0., 4., 0.], 14 | [0., 4., 4., 0., 0.], 15 | [4., 0., 0., 4., 4.], 16 | [4., 4., 4., 0., 4.], 17 | [0., 0., 0., 4., 0.]]) 18 | 19 | TEST_ARRAY_SQUARE = jnp.array([[0., 0., 0., 8., 0.], 20 | [8., 8., 0., 0., 8.], 21 | [8., 0., 0., 8., 0.], 22 | [0., 8., 8., 0., 0.], 23 | [8., 0., 0., 8., 8.]]) 24 | 25 | TEST_ARRAY_SQUARE_ONES = jnp.ones_like(TEST_ARRAY_SQUARE) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "ortho_x, ortho_y, expected_output", 30 | [ 31 | (False, False, TEST_ARRAY), 32 | (False, True, [[2., 2., 4., 2., 2.], 33 | [2., 4., 0., 4., 2.], 34 | [2., 2., 0., 2., 2.], 35 | [0., 2., 4., 2., 0.], 36 | [4., 2., 0., 2., 4.], 37 | [4., 
2., 4., 2., 4.], 38 | [0., 2., 0., 2., 0.]]), 39 | (True, False, [[2., 0., 2., 4., 0.], 40 | [2., 4., 2., 2., 4.], 41 | [4., 0., 0., 4., 2.], 42 | [0., 4., 4., 0., 0.], 43 | [4., 0., 0., 4., 2.], 44 | [2., 4., 2., 2., 4.], 45 | [2., 0., 2., 4., 0.]]), 46 | (True, True, [[1., 2., 2., 2., 1.], 47 | [3., 3., 2., 3., 3.], 48 | [3., 2., 0., 2., 3.], 49 | [0., 2., 4., 2., 0.], 50 | [3., 2., 0., 2., 3.], 51 | [3., 3., 2., 3., 3.], 52 | [1., 2., 2., 2., 1.]]) 53 | ] 54 | ) 55 | def test_get_symmetry_fn(ortho_x: bool, ortho_y: bool, expected_output: np.ndarray): 56 | actual = get_symmetry_fn(ortho_x, ortho_y, avg=True)(TEST_ARRAY) 57 | np.testing.assert_allclose(actual, expected_output) 58 | 59 | 60 | @pytest.mark.parametrize( 61 | "ortho_x, ortho_y, diag_p, diag_n, expected_output", 62 | [ 63 | (False, False, False, False, TEST_ARRAY_SQUARE), 64 | (False, False, False, True, [[4., 0., 0., 8., 0.], 65 | [8., 4., 4., 0., 8.], 66 | [4., 4., 0., 4., 0.], 67 | [0., 8., 4., 4., 0.], 68 | [8., 0., 4., 8., 4.]]), 69 | (False, False, True, False, [[0., 4., 4., 4., 4.], 70 | [4., 8., 0., 4., 4.], 71 | [4., 0., 0., 8., 0.], 72 | [4., 4., 8., 0., 4.], 73 | [4., 4., 0., 4., 8.]]), 74 | (False, False, True, True, [[4., 4., 2., 4., 4.], 75 | [4., 4., 4., 4., 4.], 76 | [2., 4., 0., 4., 2.], 77 | [4., 4., 4., 4., 4.], 78 | [4., 4., 2., 4., 4.]]), 79 | (True, False, True, False, [[4., 2., 4., 6., 4.], 80 | [2., 8., 2., 4., 2.], 81 | [4., 2., 0., 6., 0.], 82 | [6., 4., 6., 0., 6.], 83 | [4., 2., 0., 6., 4.]]), 84 | (False, True, True, False, [[0., 6., 2., 2., 4.], 85 | [6., 4., 2., 4., 6.], 86 | [2., 2., 0., 6., 2.], 87 | [2., 4., 6., 4., 2.], 88 | [4., 6., 2., 2., 8.]]), 89 | (True, True, True, False, [[4., 4., 2., 4., 4.], 90 | [4., 4., 4., 4., 4.], 91 | [2., 4., 0., 4., 2.], 92 | [4., 4., 4., 4., 4.], 93 | [4., 4., 2., 4., 4.]]), 94 | (True, True, True, True, [[4., 4., 2., 4., 4.], 95 | [4., 4., 4., 4., 4.], 96 | [2., 4., 0., 4., 2.], 97 | [4., 4., 4., 4., 4.], 98 | [4., 4., 2., 4., 4.]]), 99 | ] 100 | ) 101 | def test_get_symmetry_fn_square(ortho_x: bool, ortho_y: bool, diag_p: bool, diag_n: bool, expected_output: np.ndarray): 102 | actual = get_symmetry_fn(ortho_x, ortho_y, diag_p, diag_n, avg=True)(TEST_ARRAY_SQUARE) 103 | np.testing.assert_allclose(actual, expected_output) 104 | 105 | 106 | @pytest.mark.parametrize( 107 | "rho_init, box, rho, expected_output", 108 | [ 109 | (TEST_ARRAY, Box((2, 2), min=(2, 2)), 1, [[4., 0., 4., 4., 0.], 110 | [0., 4., 0., 4., 4.], 111 | [4., 0., 1., 1., 0.], 112 | [0., 4., 1., 1., 0.], 113 | [4., 0., 0., 4., 4.], 114 | [4., 4., 4., 0., 4.], 115 | [0., 0., 0., 4., 0.]]), 116 | (TEST_ARRAY, Box((4, 2), min=(1, 2)), 1, [[4., 0., 4., 4., 0.], 117 | [0., 4., 1., 1., 4.], 118 | [4., 0., 1., 1., 0.], 119 | [0., 4., 1., 1., 0.], 120 | [4., 0., 1., 1., 4.], 121 | [4., 4., 4., 0., 4.], 122 | [0., 0., 0., 4., 0.]]), 123 | (TEST_ARRAY, [Box((4, 2), min=(1, 2)), Box((2, 4), min=(4, 1))], 1, [[4., 0., 4., 4., 0.], 124 | [0., 4., 1., 1., 4.], 125 | [4., 0., 1., 1., 0.], 126 | [0., 4., 1., 1., 0.], 127 | [4., 1., 1., 1., 1.], 128 | [4., 1., 1., 1., 1.], 129 | [0., 0., 0., 4., 0.]]), 130 | (TEST_ARRAY_SQUARE, 131 | [Box((4, 2), min=(1, 2)), Box((2, 4), min=(4, 1))], TEST_ARRAY_SQUARE_ONES, [[0., 0., 0., 8., 0.], 132 | [8., 8., 1., 1., 8.], 133 | [8., 0., 1., 1., 0.], 134 | [0., 8., 1., 1., 0.], 135 | [8., 1., 1., 1., 1.]]), 136 | ] 137 | ) 138 | def test_mask_fn(rho_init: np.ndarray, box: Union[Box, List[Box]], rho: np.ndarray, expected_output: np.ndarray): 139 | actual = 
get_mask_fn(rho_init, box)(rho) 140 | np.testing.assert_allclose(actual, expected_output) 141 | 142 | 143 | @pytest.mark.parametrize( 144 | "rho, eta, beta, radius, expected_output", 145 | [ 146 | (TEST_ARRAY, 0.5, 1, 2, [[0.65049475, 1.0963078, 1.2062135, 0.96534103, 0.81519336], 147 | [0.96534103, 1.2954934, 1.3661214, 1.2062135, 0.96534103], 148 | [1.0963078, 1.4208316, 1.4939058, 1.3661214, 1.2062135], 149 | [1.2062135, 1.4208316, 1.5173032, 1.4208316, 1.2062135], 150 | [1.0963078, 1.3661213, 1.4625252, 1.2954934, 1.0963076], 151 | [0.96534103, 1.2062135, 1.3661213, 1.2062135, 0.96534103], 152 | [0.65049475, 0.96534103, 1.2062135, 0.96534103, 0.81519336]]), 153 | (TEST_ARRAY, 0.5, 2, 2, [[0.6791669, 1.0550566, 1.1009896, 0.97656447, 0.8525824], 154 | [0.97656447, 1.1266409, 1.1405926, 1.1009896, 0.97656447], 155 | [1.0550566, 1.148072, 1.1541586, 1.1405926, 1.1009896], 156 | [1.1009896, 1.148072, 1.1552726, 1.148072, 1.1009896], 157 | [1.0550566, 1.1405926, 1.1520507, 1.1266409, 1.0550565], 158 | [0.97656447, 1.1009896, 1.1405926, 1.1009896, 0.97656447], 159 | [0.6791669, 0.97656447, 1.1009896, 0.97656447, 0.8525824]]), 160 | (TEST_ARRAY, 0.8, 1, 3, [[0.7708701, 0.87378895, 0.87378895, 0.87378895, 0.5666377], 161 | [0.87378883, 1.0733042, 1.167072, 1.0733042, 0.87378883], 162 | [1.2553324, 1.4124601, 1.4806038, 1.4124601, 1.0733042], 163 | [1.167072, 1.5417402, 1.5417402, 1.4806038, 1.2553324], 164 | [0.9751025, 1.2553325, 1.3372947, 1.3372947, 0.9751025], 165 | [0.77087015, 1.0733042, 1.0733042, 0.9751025, 0.77087015], 166 | [0.66795135, 0.87378895, 0.87378895, 0.87378895, 0.5666377]]), 167 | ] 168 | ) 169 | def test_smooth_fn(rho: np.ndarray, eta: float, beta: float, radius: float, expected_output: float): 170 | actual = get_smooth_fn(beta, radius, eta)(rho) 171 | np.testing.assert_allclose(actual, expected_output, rtol=1e-6) 172 | --------------------------------------------------------------------------------