├── .github └── workflows │ ├── check_docs.yml │ ├── codeql-analysis.yml │ ├── coverage_report.yml │ ├── release_pip.yml │ ├── tests_all.yml │ ├── tests_minversion.yml │ ├── tests_mpi.yml │ └── tests_types.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── build_all.sh ├── methods │ ├── boundary_discretization │ │ ├── BoundaryDiscretization.nb │ │ ├── boundary_discretization.pdf │ │ └── boundary_discretization.tex │ ├── common.tex │ ├── stencils │ │ └── Cartesian2D.nb │ └── vector_calculus │ │ ├── Bipolar.nb │ │ ├── Bispherical.nb │ │ ├── Cartesian.nb │ │ ├── Cylindrical.nb │ │ ├── Polar.nb │ │ └── Spherical.nb ├── paper │ ├── paper.bib │ └── paper.md ├── requirements.txt ├── source │ ├── .gitignore │ ├── _images │ │ ├── discretization.key │ │ ├── discretization.pdf │ │ ├── discretization_cropped.pdf │ │ ├── discretization_cropped.svg │ │ ├── logo.png │ │ ├── logo_small.png │ │ ├── performance_noflux.pdf │ │ ├── performance_noflux.png │ │ ├── performance_periodic.pdf │ │ └── performance_periodic.png │ ├── _static │ │ ├── custom.css │ │ ├── requirements_main.csv │ │ └── requirements_optional.csv │ ├── conf.py │ ├── create_performance_plots.py │ ├── gallery.rst │ ├── getting_started.rst │ ├── index.rst │ ├── manual │ │ ├── advanced_usage.rst │ │ ├── basic_usage.rst │ │ ├── citing.rst │ │ ├── code_of_conduct.rst │ │ ├── contributing.rst │ │ ├── index.rst │ │ ├── mathematical_basics.rst │ │ └── performance.rst │ ├── parse_examples.py │ └── run_autodoc.py └── sphinx_ext │ ├── package_config.py │ ├── simplify_typehints.py │ └── toctree_filter.py ├── examples ├── README.txt ├── advanced_pdes │ ├── README.txt │ ├── custom_noise.py │ ├── heterogeneous_bcs.py │ ├── mpi_parallel_run.py │ ├── pde_1d_class.py │ ├── pde_brusselator_class.py │ ├── pde_coupled.py │ ├── pde_custom_class.py │ ├── pde_custom_numba.py │ ├── pde_sir.py │ ├── post_step_hook.py │ ├── post_step_hook_class.py │ └── 
solver_comparison.py ├── fields │ ├── README.txt │ ├── analyze_scalar_field.py │ ├── finite_differences.py │ ├── plot_cylindrical_field.py │ ├── plot_polar_grid.py │ ├── plot_vector_field.py │ ├── random_fields.py │ └── show_3d_field_interactively.py ├── jupyter │ ├── .gitignore │ ├── Different solvers.ipynb │ ├── Discretized Fields.ipynb │ ├── Solve PDEs.ipynb │ ├── Tutorial 1 - Grids and fields.ipynb │ └── Tutorial 2 - Solving pre-defined partial differential equations.ipynb ├── output │ ├── README.txt │ ├── logarithmic_kymograph.py │ ├── make_movie_live.py │ ├── make_movie_storage.py │ ├── py_modelrunner.py │ ├── storages.py │ ├── tracker_interactive.py │ ├── trackers.py │ └── trajectory_io.py └── simple_pdes │ ├── README.txt │ ├── boundary_conditions.py │ ├── cartesian_grid.py │ ├── laplace_eq_2d.py │ ├── pde_1d_expression.py │ ├── pde_brusselator_expression.py │ ├── pde_custom_expression.py │ ├── pde_heterogeneous_diffusion.py │ ├── pde_schroedinger.py │ ├── poisson_eq_1d.py │ ├── simple.py │ ├── spherical_grid.py │ ├── stochastic_simulation.py │ └── time_dependent_bcs.py ├── pde ├── __init__.py ├── fields │ ├── __init__.py │ ├── base.py │ ├── collection.py │ ├── datafield_base.py │ ├── scalar.py │ ├── tensorial.py │ └── vectorial.py ├── grids │ ├── __init__.py │ ├── _mesh.py │ ├── base.py │ ├── boundaries │ │ ├── __init__.py │ │ ├── axes.py │ │ ├── axis.py │ │ └── local.py │ ├── cartesian.py │ ├── coordinates │ │ ├── __init__.py │ │ ├── base.py │ │ ├── bipolar.py │ │ ├── bispherical.py │ │ ├── cartesian.py │ │ ├── cylindrical.py │ │ ├── polar.py │ │ └── spherical.py │ ├── cylindrical.py │ ├── operators │ │ ├── __init__.py │ │ ├── cartesian.py │ │ ├── common.py │ │ ├── cylindrical_sym.py │ │ ├── polar_sym.py │ │ └── spherical_sym.py │ └── spherical.py ├── pdes │ ├── __init__.py │ ├── allen_cahn.py │ ├── base.py │ ├── cahn_hilliard.py │ ├── diffusion.py │ ├── kpz_interface.py │ ├── kuramoto_sivashinsky.py │ ├── laplace.py │ ├── pde.py │ ├── swift_hohenberg.py │ 
└── wave.py ├── py.typed ├── solvers │ ├── __init__.py │ ├── adams_bashforth.py │ ├── base.py │ ├── controller.py │ ├── crank_nicolson.py │ ├── explicit.py │ ├── explicit_mpi.py │ ├── implicit.py │ └── scipy.py ├── storage │ ├── __init__.py │ ├── base.py │ ├── file.py │ ├── memory.py │ ├── modelrunner.py │ └── movie.py ├── tools │ ├── __init__.py │ ├── cache.py │ ├── config.py │ ├── cuboid.py │ ├── docstrings.py │ ├── expressions.py │ ├── ffmpeg.py │ ├── math.py │ ├── misc.py │ ├── modelrunner.py │ ├── mpi.py │ ├── numba.py │ ├── output.py │ ├── parameters.py │ ├── parse_duration.py │ ├── plotting.py │ ├── resources │ │ ├── requirements_basic.txt │ │ ├── requirements_full.txt │ │ └── requirements_mpi.txt │ ├── spectral.py │ └── typing.py ├── trackers │ ├── __init__.py │ ├── base.py │ ├── interactive.py │ ├── interrupts.py │ └── trackers.py └── visualization │ ├── __init__.py │ ├── movies.py │ └── plotting.py ├── pyproject.toml ├── requirements.txt ├── runtime.txt ├── scripts ├── _templates │ ├── _pyproject.toml │ └── _runtime.txt ├── create_requirements.py ├── create_storage_test_resources.py ├── format_code.sh ├── performance_boundaries.py ├── performance_laplace.py ├── performance_solvers.py ├── profile_import.py ├── run_tests.py ├── show_environment.py ├── tests_all.sh ├── tests_codestyle.sh ├── tests_coverage.sh ├── tests_debug.sh ├── tests_extensive.sh ├── tests_mpi.sh ├── tests_parallel.sh ├── tests_run.sh └── tests_types.sh └── tests ├── _notebooks └── Test PlotTracker for different backend.ipynb ├── conftest.py ├── fields ├── fixtures │ ├── __init__.py │ └── fields.py ├── test_field_collections.py ├── test_generic_fields.py ├── test_scalar_fields.py ├── test_tensorial_fields.py └── test_vectorial_fields.py ├── grids ├── boundaries │ ├── test_axes_boundaries.py │ ├── test_axes_boundaries_legacy.py │ ├── test_axis_boundaries.py │ └── test_local_boundaries.py ├── operators │ ├── test_cartesian_operators.py │ ├── test_common_operators.py │ ├── 
test_cylindrical_operators.py │ ├── test_polar_operators.py │ └── test_spherical_operators.py ├── test_cartesian_grids.py ├── test_coordinates.py ├── test_cylindrical_grids.py ├── test_generic_grids.py ├── test_grid_mesh.py └── test_spherical_grids.py ├── pdes ├── test_diffusion_pdes.py ├── test_generic_pdes.py ├── test_laplace_pdes.py ├── test_pde_class.py ├── test_pdes_mpi.py └── test_wave_pdes.py ├── requirements.txt ├── requirements_full.txt ├── requirements_min.txt ├── requirements_mpi.txt ├── resources └── run_pde.py ├── solvers ├── test_adams_bashforth_solver.py ├── test_controller.py ├── test_explicit_mpi_solvers.py ├── test_explicit_solvers.py ├── test_generic_solvers.py ├── test_implicit_solvers.py └── test_scipy_solvers.py ├── storage ├── resources │ ├── empty.avi │ ├── no_metadata.avi │ ├── storage_1.avi │ ├── storage_1.hdf5 │ ├── storage_2.avi │ ├── storage_2.avi.times │ └── storage_2.hdf5 ├── test_file_storages.py ├── test_generic_storages.py ├── test_memory_storages.py ├── test_modelrunner_storages.py └── test_movie_storages.py ├── test_examples.py ├── test_integration.py ├── tools ├── test_cache.py ├── test_config.py ├── test_cuboid.py ├── test_expressions.py ├── test_ffmpeg.py ├── test_math.py ├── test_misc.py ├── test_mpi.py ├── test_numba.py ├── test_output.py ├── test_parameters.py ├── test_parse_duration.py ├── test_plotting_tools.py └── test_spectral.py ├── trackers ├── test_interrupts.py └── test_trackers.py └── visualization ├── test_movies.py └── test_plotting.py /.github/workflows/check_docs.yml: -------------------------------------------------------------------------------- 1 | name: "Check documentation" 2 | 3 | on: [push] 4 | 5 | jobs: 6 | docs: 7 | runs-on: ubuntu-latest 8 | timeout-minutes: 15 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - uses: ammaraskar/sphinx-action@master 14 | with: 15 | docs-folder: "docs/" 16 | pre-build-command: "pip install --upgrade pip" 17 | 
-------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | name: "Code quality analysis" 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | # The branches below must be a subset of the branches above 8 | branches: [master] 9 | schedule: 10 | - cron: '0 10 * * 5' 11 | 12 | jobs: 13 | analyze: 14 | name: Analyze 15 | runs-on: ubuntu-latest 16 | 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | # Override automatic language detection by changing the below list 21 | # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python'] 22 | language: ['python'] 23 | # Learn more... 24 | # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection 25 | 26 | steps: 27 | - name: Checkout repository 28 | uses: actions/checkout@v4 29 | with: 30 | # We must fetch at least the immediate parents so that if this is 31 | # a pull request then we can checkout the head. 32 | fetch-depth: 2 33 | 34 | # Initializes the CodeQL tools for scanning. 35 | - name: Initialize CodeQL 36 | uses: github/codeql-action/init@v2 37 | with: 38 | languages: ${{ matrix.language }} 39 | 40 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 41 | # If this step fails, then you should remove it and run the build manually (see below) 42 | - name: Autobuild 43 | uses: github/codeql-action/autobuild@v2 44 | 45 | # ℹ️ Command-line programs to run using the OS shell. 
46 | # 📚 https://git.io/JvXDl 47 | 48 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 49 | # and modify them (or add more) to build your code if your project 50 | # uses a compiled language 51 | 52 | #- run: | 53 | # make bootstrap 54 | # make release 55 | 56 | - name: Perform CodeQL Analysis 57 | uses: github/codeql-action/analyze@v2 58 | -------------------------------------------------------------------------------- /.github/workflows/coverage_report.yml: -------------------------------------------------------------------------------- 1 | name: "Generate coverage report" 2 | 3 | on: [push] 4 | 5 | jobs: 6 | coverage_report: 7 | runs-on: ubuntu-latest 8 | timeout-minutes: 30 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | 13 | - name: Set up Python 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.11' 17 | 18 | - name: Setup FFmpeg 19 | uses: AnimMouse/setup-ffmpeg@v1 20 | 21 | - name: Install dependencies 22 | # install all requirements. Note that the full requirements are installed separately 23 | # so the job does not fail if one of the packages cannot be installed. This allows 24 | # testing the package for newer python version even when some of the optional 25 | # packages are not yet available.
26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install -r requirements.txt 29 | cat tests/requirements_full.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 -I % sh -c "pip install % || true" 30 | pip install -r tests/requirements.txt 31 | 32 | - name: Generate serial coverage report 33 | env: 34 | NUMBA_DISABLE_JIT: 1 35 | MPLBACKEND: agg 36 | PYPDE_TESTRUN: 1 37 | run: | 38 | export PYTHONPATH="${PYTHONPATH}:`pwd`" 39 | pytest --cov-config=pyproject.toml --cov=pde -n auto tests 40 | 41 | - name: Setup MPI 42 | uses: mpi4py/setup-mpi@v1 43 | 44 | - name: Generate parallel coverage report 45 | env: 46 | NUMBA_DISABLE_JIT: 1 47 | MPLBACKEND: agg 48 | PYPDE_TESTRUN: 1 49 | run: | 50 | export PYTHONPATH="${PYTHONPATH}:`pwd`" 51 | pip install -r tests/requirements_mpi.txt 52 | mpiexec -n 2 pytest --cov-config=pyproject.toml --cov-append --cov=pde --use_mpi tests 53 | 54 | - name: Create coverage report 55 | run: | 56 | coverage xml -o coverage_report.xml 57 | 58 | - name: Upload coverage to Codecov 59 | uses: codecov/codecov-action@v1 60 | with: 61 | token: ${{ secrets.CODECOV_TOKEN }} 62 | file: ./coverage_report.xml 63 | flags: unittests 64 | name: codecov-pydev 65 | fail_ci_if_error: true 66 | -------------------------------------------------------------------------------- /.github/workflows/release_pip.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [released] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | timeout-minutes: 15 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | with: 15 | fetch-depth: 0 # https://github.com/pypa/setuptools_scm/issues/480 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: '3.9' 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install build twine 26 | 27 | - name: Prepare build 28 | run: | 29 | 
python -m build 2>&1 | tee build.log 30 | # exit `fgrep -i warning build.log | wc -l` 31 | 32 | - name: Check the package 33 | run: twine check --strict dist/* 34 | 35 | - name: Build and publish 36 | env: 37 | TWINE_USERNAME: __token__ 38 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 39 | run: python -m twine upload dist/* -------------------------------------------------------------------------------- /.github/workflows/tests_all.yml: -------------------------------------------------------------------------------- 1 | name: "Serial tests" 2 | 3 | on: [push] 4 | 5 | jobs: 6 | serial_tests: 7 | strategy: 8 | matrix: 9 | os: [macos-latest, ubuntu-latest, windows-latest] 10 | runs-on: ${{ matrix.os }} 11 | timeout-minutes: 60 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.13' 20 | 21 | - name: Setup FFmpeg 22 | uses: AnimMouse/setup-ffmpeg@v1 23 | 24 | - name: Install dependencies 25 | # install all requirements. Note that the full requirements are installed separately 26 | # so the job does not fail if one of the packages cannot be installed. This allows 27 | # testing the package for newer python version even when some of the optional 28 | # packages are not yet available. 
29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install -r requirements.txt 32 | cat tests/requirements_full.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -I % sh -c "pip install % || true" 33 | pip install -r tests/requirements.txt 34 | 35 | - name: Run serial tests with pytest 36 | env: 37 | NUMBA_WARNINGS: 1 38 | MPLBACKEND: agg 39 | run: | 40 | cd scripts 41 | python run_tests.py --unit --runslow --num_cores auto --showconfig 42 | -------------------------------------------------------------------------------- /.github/workflows/tests_minversion.yml: -------------------------------------------------------------------------------- 1 | name: "Tests with minimal requirements" 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test_minversion: 7 | strategy: 8 | matrix: 9 | os: [macos-latest, ubuntu-latest] # windows-latest 10 | runs-on: ${{ matrix.os }} 11 | timeout-minutes: 60 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.9' 19 | 20 | - name: Install dependencies 21 | # install packages in the exact versions given in tests/requirements_min.txt 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -r tests/requirements_min.txt 25 | pip install -r tests/requirements.txt 26 | 27 | - name: Test with pytest 28 | env: 29 | NUMBA_WARNINGS: 1 30 | MPLBACKEND: agg 31 | run: | 32 | cd scripts 33 | python run_tests.py --unit --runslow --num_cores auto --showconfig 34 | -------------------------------------------------------------------------------- /.github/workflows/tests_mpi.yml: -------------------------------------------------------------------------------- 1 | name: "Multiprocessing tests" 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test_mpi: 7 | strategy: 8 | matrix: 9 | include: 10 | - os: "ubuntu-latest" 11 | mpi: "openmpi" 12 | - os: "macos-13" 13 | mpi: "mpich" 14 | - os: "windows-latest" 15 | mpi: "intelmpi" 16 | runs-on: ${{ matrix.os }} 17 | timeout-minutes: 30
18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: '3.13' 26 | 27 | - name: Setup MPI 28 | uses: mpi4py/setup-mpi@v1 29 | with: 30 | mpi: ${{ matrix.mpi }} 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install -r tests/requirements_mpi.txt 36 | pip install -r tests/requirements.txt 37 | 38 | - name: Run parallel tests with pytest 39 | env: 40 | NUMBA_WARNINGS: 1 41 | MPLBACKEND: agg 42 | run: | 43 | cd scripts 44 | python run_tests.py --unit --use_mpi --showconfig 45 | -------------------------------------------------------------------------------- /.github/workflows/tests_types.yml: -------------------------------------------------------------------------------- 1 | name: "Static type checking" 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test_types: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | pyversion: ['3.9', '3.13'] 11 | timeout-minutes: 30 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.pyversion }} 20 | 21 | - name: Install dependencies 22 | # install all requirements. Note that the full requirements are installed separately 23 | # so the job does not fail if one of the packages cannot be installed. This allows 24 | # testing the package for newer python version even when some of the optional 25 | # packages are not yet available.
26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install -r requirements.txt 29 | cat tests/requirements_full.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -I % sh -c "pip install % || true" 30 | pip install -r tests/requirements.txt 31 | pip install types-PyYAML 32 | 33 | - name: Test types with mypy 34 | continue-on-error: true 35 | run: | 36 | python -m mypy --config-file pyproject.toml --pretty --package pde 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # special project folders 2 | docs/source/packages/ 3 | docs/source/examples_gallery/ 4 | docs/paper/paper.pdf 5 | tests/coverage/ 6 | tests/mypy-report/ 7 | pde/_version.py 8 | scripts/flamegraph.html 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | docs/source/examples 15 | .idea 16 | 17 | # C extensions 18 | *.so 19 | 20 | # LaTeX 21 | *.aux 22 | *.bbl 23 | *.blg 24 | *.bgl 25 | *.out 26 | *.synctex.gz 27 | *.toc 28 | *Notes.bib 29 | 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # celery beat schedule file 101 | celerybeat-schedule 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | /.pydevproject 128 | /.project 129 | .DS_Store 130 | 131 | # IDE settings 132 | .settings 133 | .vscode 134 | *.code-workspace 135 | examples/output/allen_cahn.avi 136 | examples/output/allen_cahn.hdf 137 | debug 138 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | 7 | - repo: https://github.com/astral-sh/ruff-pre-commit 8 | rev: v0.6.1 9 | hooks: 10 | - id: ruff 11 | args: [--fix, --show-fixes] 12 | - id: ruff-format -------------------------------------------------------------------------------- /.readthedocs.yaml: 
-------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | apt_packages: 14 | - ffmpeg 15 | - graphviz 16 | 17 | # Build documentation in the docs/ directory with Sphinx 18 | sphinx: 19 | configuration: docs/source/conf.py 20 | 21 | # If using Sphinx, optionally build your docs in additional formats such as PDF 22 | formats: 23 | - epub 24 | - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at david.zwicker@ds.mpg.de. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 David Zwicker 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 
13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line. 4 | SPHINXOPTS = 5 | SPHINXBUILD = sphinx-build 6 | SOURCEDIR = source 7 | BUILDDIR = build 8 | 9 | # Put it first so that "make" without argument is like "make help". 10 | help: 11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 12 | 13 | # clean: 14 | # rm -rf $(BUILDDIR)/* 15 | # rm -rf examples_gallery/* 16 | 17 | .PHONY: help Makefile 18 | 19 | # Define default target, which can be used in VS code 20 | .PHONY: default 21 | default: html 22 | 23 | # special make target for latex document to exclude the gallery 24 | latexpdf: Makefile 25 | @$(SPHINXBUILD) -M latexpdf "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -E -t exclude_gallery 26 | 27 | # special make target for linkcheck to be nit-picky 28 | linkcheck: Makefile 29 | @$(SPHINXBUILD) -M linkcheck "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -E -n 30 | 31 | # Catch-all target: route all unknown targets to Sphinx using the new 32 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
33 | %: Makefile 34 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -E 35 | -------------------------------------------------------------------------------- /docs/build_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | make html 4 | make latexpdf 5 | make linkcheck -------------------------------------------------------------------------------- /docs/methods/boundary_discretization/boundary_discretization.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/methods/boundary_discretization/boundary_discretization.pdf -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | -r ../requirements.txt 2 | ffmpeg-python>=0.2 3 | h5py>=2.10 4 | pandas>=2 5 | Pillow>=7.0 6 | py-modelrunner>=0.19 7 | pydot>=3 8 | Sphinx>=4 9 | sphinx-autodoc-annotation>=1.0 10 | sphinx-gallery>=0.6 11 | sphinx-rtd-theme>=1 12 | utilitiez>=0.3 13 | -------------------------------------------------------------------------------- /docs/source/.gitignore: -------------------------------------------------------------------------------- 1 | /sg_execution_times.rst 2 | -------------------------------------------------------------------------------- /docs/source/_images/discretization.key: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/discretization.key -------------------------------------------------------------------------------- /docs/source/_images/discretization.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/discretization.pdf -------------------------------------------------------------------------------- /docs/source/_images/discretization_cropped.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/discretization_cropped.pdf -------------------------------------------------------------------------------- /docs/source/_images/discretization_cropped.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 20 | 21 | 22 | 23 | 24 | 25 | x 26 | min 27 | 28 | 29 | 30 | x 31 | 0 32 | x 33 | 1 34 | x 35 | 2 36 | 37 | 38 | 39 | 40 | 41 | x 42 | N 43 | 44 | 1 45 | x 46 | N 47 | 48 | 2 49 | 50 | 51 | x 52 | max 53 | { 54 | Δ 55 | x 56 | 57 | -------------------------------------------------------------------------------- /docs/source/_images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/logo.png -------------------------------------------------------------------------------- /docs/source/_images/logo_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/logo_small.png -------------------------------------------------------------------------------- /docs/source/_images/performance_noflux.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_noflux.pdf 
-------------------------------------------------------------------------------- /docs/source/_images/performance_noflux.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_noflux.png -------------------------------------------------------------------------------- /docs/source/_images/performance_periodic.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_periodic.pdf -------------------------------------------------------------------------------- /docs/source/_images/performance_periodic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_periodic.png -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | .sphx-glr-download-link-note { 2 | display: none; 3 | } 4 | 5 | .sphx-glr-script-out { 6 | display: none; 7 | } 8 | 9 | .wy-table-responsive table td { 10 | white-space: inherit; 11 | } -------------------------------------------------------------------------------- /docs/source/_static/requirements_main.csv: -------------------------------------------------------------------------------- 1 | Package,Minimal version,Usage 2 | matplotlib,3.1,Visualizing results 3 | numba,0.59,Just-in-time compilation to accelerate numerics 4 | numpy,1.22,Handling numerical data 5 | scipy,1.10,Miscellaneous scientific functions 6 | sympy,1.9,Dealing with user-defined mathematical expressions 7 | tqdm,4.66,Display progress bars during calculations 8 | 
-------------------------------------------------------------------------------- /docs/source/_static/requirements_optional.csv: -------------------------------------------------------------------------------- 1 | Package,Minimal version,Usage 2 | ffmpeg-python,0.2,Reading and writing videos 3 | h5py,2.10,Storing data in the hierarchical file format 4 | ipywidgets,8,Jupyter notebook support 5 | mpi4py,3,Parallel processing using MPI 6 | napari,0.4.8,Displaying images interactively 7 | numba-mpi,0.22,Parallel processing using MPI+numba 8 | pandas,2,Handling tabular data 9 | py-modelrunner,0.19,Running simulations and handling I/O 10 | pyfftw,0.12,Faster Fourier transforms 11 | rocket-fft,0.2.4,Numba-compiled fast Fourier transforms 12 | -------------------------------------------------------------------------------- /docs/source/gallery.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | These are example scripts using the `py-pde` package, which illustrate some of the most 5 | important features of the package. 6 | 7 | 8 | .. include:: /examples_gallery/fields/index.rst 9 | .. include:: /examples_gallery/simple_pdes/index.rst 10 | .. include:: /examples_gallery/output/index.rst 11 | .. include:: /examples_gallery/advanced_pdes/index.rst 12 | 13 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 'py-pde' python package 2 | ======================= 3 | 4 | .. figure:: _images/logo.png 5 | :figwidth: 25% 6 | :align: right 7 | :alt: Logo of the py-pde package 8 | 9 | The `py-pde` python package provides methods and classes useful for solving 10 | partial differential equations (PDEs) of the form 11 | 12 | .. 
math:: 13 | \partial_t u(\boldsymbol x, t) = \mathcal D[u(\boldsymbol x, t)] 14 | + \eta(u, \boldsymbol x, t) \;, 15 | 16 | where :math:`\mathcal D` is a (non-linear) operator containing spatial derivatives that 17 | defines the time evolution of a (set of) physical fields :math:`u` with possibly 18 | tensorial character, which depend on spatial coordinates :math:`\boldsymbol x` and time 19 | :math:`t`. 20 | The framework also supports stochastic differential equations in the Itô 21 | representation, where the noise is represented by :math:`\eta` above. 22 | 23 | The main audience for the package are researchers and students who want to 24 | investigate the behavior of a PDE and get an intuitive understanding of the 25 | role of the different terms and the boundary conditions. 26 | To support this, `py-pde` evaluates PDEs using the method of lines with a 27 | finite-difference approximation of the differential operators. 28 | Consequently, the mathematical operator :math:`\mathcal D` can be naturally 29 | translated to a function evaluating the evolution rate of the PDE. 30 | 31 | 32 | **Contents** 33 | 34 | .. toctree-filt:: 35 | :maxdepth: 2 36 | :numbered: 37 | 38 | getting_started 39 | gallery 40 | manual/index 41 | packages/pde 42 | 43 | 44 | **Indices and tables** 45 | 46 | * :ref:`genindex` 47 | * :ref:`modindex` 48 | * :ref:`search` 49 | -------------------------------------------------------------------------------- /docs/source/manual/citing.rst: -------------------------------------------------------------------------------- 1 | Citing the package 2 | ^^^^^^^^^^^^^^^^^^ 3 | 4 | To cite or reference `py-pde` in other work, please refer to the `publication in 5 | the Journal of Open Source Software `_. 6 | Here are the respective bibliographic records in BibTeX format: 7 | 8 | .. 
code-block:: bibtex 9 | 10 | @article{py-pde, 11 | Author = {David Zwicker}, 12 | Doi = {10.21105/joss.02158}, 13 | Journal = {Journal of Open Source Software}, 14 | Number = {48}, 15 | Pages = {2158}, 16 | Publisher = {The Open Journal}, 17 | Title = {py-pde: A Python package for solving partial differential equations}, 18 | Url = {https://doi.org/10.21105/joss.02158}, 19 | Volume = {5}, 20 | Year = {2020} 21 | } 22 | 23 | and in RIS format: 24 | 25 | .. code-block:: text 26 | 27 | TY - JOUR 28 | AU - Zwicker, David 29 | JO - Journal of Open Source Software 30 | IS - 48 31 | SP - 2158 32 | PB - The Open Journal 33 | T1 - py-pde: A Python package for solving partial differential equations 34 | UR - https://doi.org/10.21105/joss.02158 35 | VL - 5 36 | PY - 2020 37 | -------------------------------------------------------------------------------- /docs/source/manual/code_of_conduct.rst: -------------------------------------------------------------------------------- 1 | Code of Conduct 2 | =============== 3 | 4 | Our Pledge 5 | ---------- 6 | 7 | In the interest of fostering an open and welcoming environment, we as 8 | contributors and maintainers pledge to making participation in our 9 | project and our community a harassment-free experience for everyone, 10 | regardless of age, body size, disability, ethnicity, sex 11 | characteristics, gender identity and expression, level of experience, 12 | education, socio-economic status, nationality, personal appearance, 13 | race, religion, or sexual identity and orientation. 
14 | 15 | Our Standards 16 | ------------- 17 | 18 | Examples of behavior that contributes to creating a positive environment 19 | include: 20 | 21 | - Using welcoming and inclusive language 22 | - Being respectful of differing viewpoints and experiences 23 | - Gracefully accepting constructive criticism 24 | - Focusing on what is best for the community 25 | - Showing empathy towards other community members 26 | 27 | Examples of unacceptable behavior by participants include: 28 | 29 | - The use of sexualized language or imagery and unwelcome sexual 30 | attention or advances 31 | - Trolling, insulting/derogatory comments, and personal or political 32 | attacks 33 | - Public or private harassment 34 | - Publishing others’ private information, such as a physical or 35 | electronic address, without explicit permission 36 | - Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | Our Responsibilities 40 | -------------------- 41 | 42 | Project maintainers are responsible for clarifying the standards of 43 | acceptable behavior and are expected to take appropriate and fair 44 | corrective action in response to any instances of unacceptable behavior. 45 | 46 | Project maintainers have the right and responsibility to remove, edit, 47 | or reject comments, commits, code, wiki edits, issues, and other 48 | contributions that are not aligned to this Code of Conduct, or to ban 49 | temporarily or permanently any contributor for other behaviors that they 50 | deem inappropriate, threatening, offensive, or harmful. 51 | 52 | Scope 53 | ----- 54 | 55 | This Code of Conduct applies both within project spaces and in public 56 | spaces when an individual is representing the project or its community. 57 | Examples of representing a project or community include using an 58 | official project e-mail address, posting via an official social media 59 | account, or acting as an appointed representative at an online or 60 | offline event. 
Representation of a project may be further defined and 61 | clarified by project maintainers. 62 | 63 | Enforcement 64 | ----------- 65 | 66 | Instances of abusive, harassing, or otherwise unacceptable behavior may 67 | be reported by contacting the project team at david.zwicker@ds.mpg.de. 68 | All complaints will be reviewed and investigated and will result in a 69 | response that is deemed necessary and appropriate to the circumstances. 70 | The project team is obligated to maintain confidentiality with regard to 71 | the reporter of an incident. Further details of specific enforcement 72 | policies may be posted separately. 73 | 74 | Project maintainers who do not follow or enforce the Code of Conduct in 75 | good faith may face temporary or permanent repercussions as determined 76 | by other members of the project’s leadership. 77 | 78 | Attribution 79 | ----------- 80 | 81 | This Code of Conduct is adapted from the `Contributor 82 | Covenant `__, version 1.4, 83 | available at 84 | https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 85 | 86 | For answers to common questions about this code of conduct, see 87 | https://www.contributor-covenant.org/faq 88 | -------------------------------------------------------------------------------- /docs/source/manual/index.rst: -------------------------------------------------------------------------------- 1 | User manual 2 | =========== 3 | 4 | 5 | .. 
toctree:: 6 | :maxdepth: 3 7 | 8 | 9 | mathematical_basics 10 | basic_usage 11 | advanced_usage 12 | performance 13 | contributing 14 | citing 15 | code_of_conduct -------------------------------------------------------------------------------- /docs/source/parse_examples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import pathlib 4 | 5 | # Root directory of the package 6 | ROOT = pathlib.Path(__file__).absolute().parents[2] 7 | # directory where all the examples reside 8 | INPUT = ROOT / "examples" 9 | # directory to which the documents are written 10 | OUTPUT = ROOT / "docs" / "source" / "examples" 11 | 12 | 13 | def main(): 14 | """Parse all examples and write each one as a `code-block` into its own rst file.""" 15 | # create the output directory 16 | OUTPUT.mkdir(parents=True, exist_ok=True) 17 | 18 | # iterate over all examples 19 | for path_in in INPUT.glob("*.py"): # NOTE(review): non-recursive glob; examples now live in subfolders such as examples/advanced_pdes, so this may find nothing -- confirm this script is still used 20 | path_out = OUTPUT / (path_in.stem + ".rst") 21 | print(f"Found example {path_in}") 22 | with path_in.open("r") as file_in, path_out.open("w") as file_out: 23 | # write the header for the rst file 24 | file_out.write(".. 
code-block:: python\n\n") 25 | 26 | # add the actual code lines 27 | header = True 28 | for line in file_in: 29 | # skip the shebang, comments and empty lines in the beginning 30 | if header and (line.startswith("#") or len(line.strip()) == 0): 31 | continue 32 | header = False # first real line was reached 33 | file_out.write(" " + line) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /docs/source/run_autodoc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import glob 4 | import logging 5 | import os 6 | import subprocess as sp 7 | from pathlib import Path 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | 11 | OUTPUT_PATH = "packages" 12 | REPLACEMENTS = { 13 | "Submodules\n----------\n\n": "", 14 | "Subpackages\n-----------": "**Subpackages:**", 15 | "pde package\n===========": "Reference manual\n================", 16 | } 17 | 18 | 19 | def replace_in_file(infile, replacements, outfile=None): 20 | """Reads in a file, replaces substrings literally (via `str.replace`) and writes back 21 | the result to a file. 22 | 23 | Args: 24 | infile (str): 25 | File to be read 26 | replacements (dict): 27 | The replacements old => new in a dictionary format {old: new} 28 | outfile (str): 29 | Output file to which the data is written.
If it is omitted, the 30 | input file will be overwritten instead 31 | """ 32 | if outfile is None: 33 | outfile = infile 34 | 35 | with Path(infile).open() as fp: 36 | content = fp.read() 37 | 38 | for key, value in replacements.items(): 39 | content = content.replace(key, value) 40 | 41 | with Path(outfile).open("w") as fp: 42 | fp.write(content) 43 | 44 | 45 | def main(): 46 | """Run the autodoc call.""" 47 | logger = logging.getLogger("autodoc") 48 | 49 | # remove old files 50 | for path in Path(OUTPUT_PATH).glob("*.rst"): 51 | logger.info("Remove file `%s`", path) 52 | path.unlink() 53 | 54 | # run sphinx-apidoc 55 | sp.check_call( 56 | [ 57 | "sphinx-apidoc", 58 | "--separate", 59 | "--maxdepth", 60 | "4", 61 | "--output-dir", 62 | OUTPUT_PATH, 63 | "--module-first", 64 | "../../pde", # path of the package 65 | "../../pde/version.py", # ignored file 66 | "../../**/conftest.py", # ignored file 67 | "../../**/tests", # ignored path 68 | ] 69 | ) 70 | 71 | # replace unwanted information 72 | for path in Path(OUTPUT_PATH).glob("*.rst"): 73 | logger.info("Patch file `%s`", path) 74 | replace_in_file(path, REPLACEMENTS) 75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | -------------------------------------------------------------------------------- /docs/sphinx_ext/package_config.py: -------------------------------------------------------------------------------- 1 | from docutils import nodes 2 | from sphinx.util.docutils import SphinxDirective 3 | 4 | 5 | class PackageConfigDirective(SphinxDirective): 6 | """Directive that displays all package configuration items.""" 7 | 8 | has_content = True 9 | required_arguments = 0 10 | optional_arguments = 0 11 | final_argument_whitespace = False 12 | 13 | def run(self): 14 | from pde.tools.config import Config 15 | 16 | c = Config() 17 | items = [] 18 | 19 | for p in c.data.values(): 20 | description = nodes.paragraph(text=p.description + " ") 21 | description += nodes.strong(text=f"(Default value: 
{c[p.name]!r})") 22 | 23 | items += nodes.definition_list_item( 24 | "", 25 | nodes.term(text=p.name), 26 | nodes.definition("", description), 27 | ) 28 | 29 | return [nodes.definition_list("", *items)] 30 | 31 | 32 | def setup(app): 33 | app.add_directive("package_configuration", PackageConfigDirective) 34 | return {"version": "1.0.0"} 35 | -------------------------------------------------------------------------------- /docs/sphinx_ext/simplify_typehints.py: -------------------------------------------------------------------------------- 1 | """Simple sphinx plug-in that simplifies type information in function signatures.""" 2 | 3 | import re 4 | 5 | # simple (literal) replacement rules 6 | REPLACEMENTS = [ 7 | # numbers and numerical arrays 8 | ("Union[int, float, complex, numpy.generic, numpy.ndarray]", "NumberOrArray"), 9 | ("Union[int, float, complex, numpy.ndarray]", "NumberOrArray"), 10 | ("Union[int, float, complex]", "Number"), 11 | ( 12 | "Optional[Union[_SupportsArray[dtype], _NestedSequence[_SupportsArray[dtype]], " 13 | "bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, " 14 | "float, complex, str, bytes]]]]", 15 | "NumberOrArray", 16 | ), 17 | ( 18 | "Union[dtype[Any], None, Type[Any], _SupportsDType[dtype[Any]], str, " 19 | "Tuple[Any, int], Tuple[Any, Union[SupportsIndex, Sequence[SupportsIndex]]], " 20 | "List[Any], _DTypeDict, Tuple[Any, Any]]", 21 | "DType", 22 | ), 23 | # Complex types describing the boundary conditions 24 | ( 25 | "Dict[str, Dict | str | BCBase] | Dict | str | BCBase | " 26 | "Tuple[Dict | str | BCBase, Dict | str | BCBase] | BoundaryAxisBase | " 27 | "Sequence[Dict[str, Dict | str | BCBase] | Dict | str | BCBase | " 28 | "Tuple[Dict | str | BCBase, Dict | str | BCBase] | BoundaryAxisBase]", 29 | "BoundariesData", 30 | ), 31 | ( 32 | "Dict[str, Dict | str | BCBase] | Dict | str | BCBase | " 33 | "Tuple[Dict | str | BCBase, Dict | str | BCBase] | BoundaryAxisBase", 34 | "BoundariesPairData", 35 | ), 36 | 
("Dict | str | BCBase", "BoundaryData"), 37 | # Other compound data types 38 | ("Union[List[Union[TrackerBase, str]], TrackerBase, str, None]", "TrackerData"), 39 | ( 40 | "Optional[Union[List[Union[TrackerBase, str]], TrackerBase, str]]", 41 | "TrackerData", 42 | ), 43 | ] 44 | 45 | 46 | # replacement rules based on regular expressions 47 | REPLACEMENTS_REGEX = { 48 | # remove full package path and only leave the module/class identifier 49 | r"pde\.(\w+\.)*": "", 50 | r"typing\.": "", 51 | } 52 | 53 | 54 | def process_signature( 55 | app, what: str, name: str, obj, options, signature, return_annotation 56 | ): 57 | """Process signature by applying replacement rules.""" 58 | 59 | def process(sig_obj): 60 | """Process the signature object.""" 61 | if sig_obj is not None: 62 | for key, value in REPLACEMENTS_REGEX.items(): 63 | sig_obj = re.sub(key, value, sig_obj) 64 | for key, value in REPLACEMENTS: 65 | sig_obj = sig_obj.replace(key, value) 66 | return sig_obj 67 | 68 | signature = process(signature) 69 | return_annotation = process(return_annotation) 70 | 71 | return signature, return_annotation 72 | 73 | 74 | def process_docstring(app, what: str, name: str, obj, options, lines): 75 | """Process docstring by applying replacement rules.""" 76 | for i, line in enumerate(lines): 77 | for key, value in REPLACEMENTS: 78 | line = line.replace(key, value) 79 | lines[i] = line 80 | 81 | 82 | def setup(app): 83 | """Set up hooks for this sphinx plugin.""" 84 | app.connect("autodoc-process-signature", process_signature) 85 | app.connect("autodoc-process-docstring", process_docstring) 86 | -------------------------------------------------------------------------------- /docs/sphinx_ext/toctree_filter.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from sphinx.directives.other import TocTree 4 | 5 | 6 | class TocTreeFilter(TocTree): 7 | """Directive to filter table-of-contents entries.""" 8 | 9 | hasPat = 
re.compile(r"^\s*:(.+):(.+)$") 10 | 11 | # Remove any entries in the content that we dont want and strip 12 | # out any filter prefixes that we want but obviously don't want the 13 | # prefix to mess up the file name. 14 | def filter_entries(self, entries): 15 | excl = self.state.document.settings.env.config.toc_filter_exclude 16 | filtered = [] 17 | for e in entries: 18 | m = self.hasPat.match(e) 19 | if m != None: 20 | if not m.groups()[0] in excl: 21 | filtered.append(m.groups()[1]) 22 | else: 23 | filtered.append(e) 24 | return filtered 25 | 26 | def run(self): 27 | # Remove all TOC entries that should not be on display 28 | self.content = self.filter_entries(self.content) 29 | return super().run() 30 | 31 | 32 | def setup(app): 33 | app.add_config_value("toc_filter_exclude", [], "html") 34 | app.add_directive("toctree-filt", TocTreeFilter) 35 | return {"version": "1.0.0"} 36 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | These are example scripts using the `py-pde` package, which illustrates some of the most 5 | important features of the package. -------------------------------------------------------------------------------- /examples/advanced_pdes/README.txt: -------------------------------------------------------------------------------- 1 | Advanced PDEs 2 | ------------- 3 | 4 | These examples demonstrate more advanced usage of the package. -------------------------------------------------------------------------------- /examples/advanced_pdes/custom_noise.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom noise 3 | ============ 4 | 5 | This example solves a diffusion equation with a custom noise. 
6 | """ 7 | 8 | import numpy as np 9 | 10 | from pde import DiffusionPDE, ScalarField, UnitGrid 11 | from pde.tools.numba import jit 12 | 13 | 14 | class DiffusionCustomNoisePDE(DiffusionPDE): 15 | """Diffusion PDE with custom noise implementations.""" 16 | 17 | def noise_realization(self, state, t): 18 | """Numpy implementation of spatially-dependent noise.""" 19 | noise_field = ScalarField.random_uniform(state.grid, -self.noise, self.noise) 20 | return state.grid.cell_coords[..., 0] * noise_field 21 | 22 | def _make_noise_realization_numba(self, state): 23 | """Numba implementation of spatially-dependent noise.""" 24 | noise = float(self.noise) 25 | x_values = state.grid.cell_coords[..., 0] 26 | 27 | @jit 28 | def noise_realization(state_data, t): 29 | return x_values * np.random.uniform(-noise, noise, size=state_data.shape) 30 | 31 | return noise_realization 32 | 33 | 34 | eq = DiffusionCustomNoisePDE(diffusivity=0.1, noise=0.1) # define the pde 35 | state = ScalarField.random_uniform(UnitGrid([64, 64])) # generate initial condition 36 | result = eq.solve(state, t_range=10, dt=0.01) 37 | result.plot() 38 | -------------------------------------------------------------------------------- /examples/advanced_pdes/heterogeneous_bcs.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Heterogeneous boundary conditions 3 | ================================= 4 | 5 | This example implements a diffusion equation with a boundary condition specified by a 6 | function, which can in principle depend on time. 
7 | """ 8 | 9 | import numpy as np 10 | 11 | from pde import CartesianGrid, DiffusionPDE, ScalarField 12 | 13 | # define grid and an initial state 14 | grid = CartesianGrid([[-5, 5], [-5, 5]], 32) 15 | field = ScalarField(grid) 16 | 17 | 18 | # define the boundary conditions, which here are calculated from a function 19 | def bc_value(adjacent_value, dx, x, y, t): 20 | """Return boundary value.""" 21 | return np.sign(x) 22 | 23 | 24 | # define and solve a simple diffusion equation 25 | eq = DiffusionPDE(bc={"*": {"derivative": 0}, "y+": {"value_expression": bc_value}}) 26 | res = eq.solve(field, t_range=10, dt=0.01, backend="numpy") 27 | res.plot() 28 | -------------------------------------------------------------------------------- /examples/advanced_pdes/mpi_parallel_run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use multiprocessing via MPI 3 | =========================== 4 | 5 | Use multiple cores to solve a PDE. The implementation here uses the `Message Passing 6 | Interface (MPI) `_, and the 7 | script thus needs to be run using :code:`mpiexec -n 2 python mpi_parallel_run.py`, where 8 | `2` denotes the number of cores used. Note that macOS might require an additional hint 9 | on how to connect the processes. The following line might work: 10 | `mpiexec -n 2 -host localhost:2 python3 mpi_parallel_run.py` 11 | 12 | Such parallel simulations need extra care, since multiple instances of the same program 13 | are started. In particular, in the example below, the initial state is created on all 14 | cores. However, only the state of the first core will actually be used and distributed 15 | automatically by `py-pde`. Note that also only the first (or main) core will run the 16 | trackers and receive the result of the simulation. On all other cores, the simulation 17 | result will be `None`. 
18 | """ 19 | 20 | from pde import DiffusionPDE, ScalarField, UnitGrid 21 | 22 | grid = UnitGrid([64, 64]) # generate grid 23 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition 24 | 25 | eq = DiffusionPDE(diffusivity=0.1) # define the pde 26 | result = eq.solve(state, t_range=10, dt=0.1, solver="explicit_mpi") 27 | 28 | if result is not None: # check whether we are on the main core 29 | result.plot() 30 | -------------------------------------------------------------------------------- /examples/advanced_pdes/pde_1d_class.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 1D problem - Using custom class 3 | =============================== 4 | 5 | This example implements a PDE that is only defined in one dimension. 6 | Here, we chose the `Korteweg-de Vries equation 7 | `_, given by 8 | 9 | .. math:: 10 | \partial_t \phi = 6 \phi \partial_x \phi - \partial_x^3 \phi 11 | 12 | which we implement using a custom PDE class below. 
13 | """ 14 | 15 | from math import pi 16 | 17 | from pde import CartesianGrid, MemoryStorage, PDEBase, ScalarField, plot_kymograph 18 | 19 | 20 | class KortewegDeVriesPDE(PDEBase): 21 | """Korteweg-de Vries equation.""" 22 | 23 | def evolution_rate(self, state, t=0): 24 | """Implement the python version of the evolution equation.""" 25 | assert state.grid.dim == 1 # ensure the state is one-dimensional 26 | grad_x = state.gradient("auto_periodic_neumann")[0] 27 | return 6 * state * grad_x - grad_x.laplace("auto_periodic_neumann") 28 | 29 | 30 | # initialize the equation and the space 31 | grid = CartesianGrid([[0, 2 * pi]], [32], periodic=True) 32 | state = ScalarField.from_expression(grid, "sin(x)") 33 | 34 | # solve the equation and store the trajectory 35 | storage = MemoryStorage() 36 | eq = KortewegDeVriesPDE() 37 | eq.solve(state, t_range=3, solver="scipy", tracker=storage.tracker(0.1)) 38 | 39 | # plot the trajectory as a space-time plot 40 | plot_kymograph(storage) 41 | -------------------------------------------------------------------------------- /examples/advanced_pdes/pde_brusselator_class.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Brusselator - Using custom class 3 | ================================ 4 | 5 | This example implements the `Brusselator 6 | `_ with spatial coupling, 7 | 8 | .. math:: 9 | 10 | \partial_t u &= D_0 \nabla^2 u + a - (1 + b) u + v u^2 \\ 11 | \partial_t v &= D_1 \nabla^2 v + b u - v u^2 12 | 13 | Here, :math:`D_0` and :math:`D_1` are the respective diffusivity and the 14 | parameters :math:`a` and :math:`b` are related to reaction rates. 15 | 16 | Note that the PDE can also be implemented using the :class:`~pde.pdes.pde.PDE` 17 | class; see :doc:`the example <../simple_pdes/pde_brusselator_expression>`. However, that 18 | implementation is less flexible and might be more difficult to extend later. 
19 | """ 20 | 21 | import numba as nb 22 | import numpy as np 23 | 24 | from pde import FieldCollection, PDEBase, PlotTracker, ScalarField, UnitGrid 25 | 26 | 27 | class BrusselatorPDE(PDEBase): 28 | """Brusselator with diffusive mobility.""" 29 | 30 | def __init__(self, a=1, b=3, diffusivity=None, bc="auto_periodic_neumann"): 31 | super().__init__() 32 | self.a = a 33 | self.b = b 34 | self.diffusivity = [1, 0.1] if diffusivity is None else diffusivity 35 | self.bc = bc # boundary condition 36 | 37 | def get_initial_state(self, grid): 38 | """Prepare a useful initial state.""" 39 | u = ScalarField(grid, self.a, label="Field $u$") 40 | v = self.b / self.a + 0.1 * ScalarField.random_normal(grid, label="Field $v$") 41 | return FieldCollection([u, v]) 42 | 43 | def evolution_rate(self, state, t=0): 44 | """Pure python implementation of the PDE.""" 45 | u, v = state 46 | rhs = state.copy() 47 | d0, d1 = self.diffusivity 48 | rhs[0] = d0 * u.laplace(self.bc) + self.a - (self.b + 1) * u + u**2 * v 49 | rhs[1] = d1 * v.laplace(self.bc) + self.b * u - u**2 * v 50 | return rhs 51 | 52 | def _make_pde_rhs_numba(self, state): 53 | """Numba-compiled implementation of the PDE.""" 54 | d0, d1 = self.diffusivity 55 | a, b = self.a, self.b 56 | laplace = state.grid.make_operator("laplace", bc=self.bc) 57 | 58 | @nb.njit 59 | def pde_rhs(state_data, t): 60 | u = state_data[0] 61 | v = state_data[1] 62 | 63 | rate = np.empty_like(state_data) 64 | rate[0] = d0 * laplace(u) + a - (1 + b) * u + v * u**2 65 | rate[1] = d1 * laplace(v) + b * u - v * u**2 66 | return rate 67 | 68 | return pde_rhs 69 | 70 | 71 | # initialize state 72 | grid = UnitGrid([64, 64]) 73 | eq = BrusselatorPDE(diffusivity=[1, 0.1]) 74 | state = eq.get_initial_state(grid) 75 | 76 | # simulate the pde 77 | tracker = PlotTracker(interrupts=1, plot_args={"kind": "merged", "vmin": 0, "vmax": 5}) 78 | sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker) 79 |
-------------------------------------------------------------------------------- /examples/advanced_pdes/pde_coupled.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Custom Class for coupled PDEs 3 | ============================= 4 | 5 | This example shows how to solve a set of coupled PDEs, the 6 | spatially coupled `FitzHugh–Nagumo model 7 | `_, which is a simple model 8 | for the excitable dynamics of coupled neurons: 9 | 10 | .. math:: 11 | 12 | \partial_t v &= \nabla^2 v + v - \frac{v^3}{3} - w + I \\ 13 | \tau \partial_t w &= v + a - b w 14 | 15 | Here, :math:`v` is the membrane potential, :math:`w` the recovery variable, :math:`I` the 16 | external stimulus, :math:`\tau` the recovery time scale, and :math:`a`, :math:`b` are model parameters. We implement this as a custom PDE class below. 17 | """ 18 | 19 | from pde import FieldCollection, PDEBase, UnitGrid 20 | 21 | 22 | class FitzhughNagumoPDE(PDEBase): 23 | """FitzHugh–Nagumo model with diffusive coupling.""" 24 | 25 | def __init__(self, stimulus=0.5, τ=10, a=0, b=0, bc="auto_periodic_neumann"): 26 | super().__init__() 27 | self.bc = bc 28 | self.stimulus = stimulus 29 | self.τ = τ 30 | self.a = a 31 | self.b = b 32 | 33 | def evolution_rate(self, state, t=0): 34 | v, w = state # membrane potential and recovery variable 35 | 36 | v_t = v.laplace(bc=self.bc) + v - v**3 / 3 - w + self.stimulus 37 | w_t = (v + self.a - self.b * w) / self.τ 38 | 39 | return FieldCollection([v_t, w_t]) 40 | 41 | 42 | grid = UnitGrid([32, 32]) 43 | state = FieldCollection.scalar_random_uniform(2, grid) 44 | 45 | eq = FitzhughNagumoPDE() 46 | result = eq.solve(state, t_range=100, dt=0.01) 47 | result.plot() 48 | -------------------------------------------------------------------------------- /examples/advanced_pdes/pde_custom_class.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Kuramoto-Sivashinsky - Using custom class 3 | ========================================= 4 | 5 | This example implements a scalar PDE using a 
custom
class. We here consider the 6 | `Kuramoto–Sivashinsky equation 7 | `_, which for instance 8 | describes the dynamics of flame fronts: 9 | 10 | .. math:: 11 | \partial_t u = -\frac12 |\nabla u|^2 - \nabla^2 u - \nabla^4 u 12 | """ 13 | 14 | from pde import PDEBase, ScalarField, UnitGrid 15 | 16 | 17 | class KuramotoSivashinskyPDE(PDEBase): 18 | """Implementation of the normalized Kuramoto–Sivashinsky equation.""" 19 | 20 | def evolution_rate(self, state, t=0): 21 | """Implement the python version of the evolution equation.""" 22 | state_lap = state.laplace(bc="auto_periodic_neumann") 23 | state_lap2 = state_lap.laplace(bc="auto_periodic_neumann") 24 | state_grad = state.gradient(bc="auto_periodic_neumann") 25 | return -state_grad.to_scalar("squared_sum") / 2 - state_lap - state_lap2 26 | 27 | 28 | grid = UnitGrid([32, 32]) # generate grid 29 | state = ScalarField.random_uniform(grid) # generate initial condition 30 | 31 | eq = KuramotoSivashinskyPDE() # define the pde 32 | result = eq.solve(state, t_range=10, dt=0.01) 33 | result.plot() 34 | -------------------------------------------------------------------------------- /examples/advanced_pdes/pde_custom_numba.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Kuramoto-Sivashinsky - Compiled methods 3 | ======================================= 4 | 5 | This example implements a scalar PDE using a custom class with a numba-compiled method 6 | for accelerated calculations. We here consider the `Kuramoto–Sivashinsky equation 7 | `_, which for instance 8 | describes the dynamics of flame fronts: 9 | 10 | .. 
math:: 11 | \partial_t u = -\frac12 |\nabla u|^2 - \nabla^2 u - \nabla^4 u 12 | """ 13 | 14 | import numba as nb 15 | 16 | from pde import PDEBase, ScalarField, UnitGrid 17 | 18 | 19 | class KuramotoSivashinskyPDE(PDEBase): 20 | """Implementation of the normalized Kuramoto–Sivashinsky equation.""" 21 | 22 | def __init__(self, bc="auto_periodic_neumann"): 23 | super().__init__() 24 | self.bc = bc 25 | 26 | def evolution_rate(self, state, t=0): 27 | """Implement the python version of the evolution equation.""" 28 | state_lap = state.laplace(bc=self.bc) 29 | state_lap2 = state_lap.laplace(bc=self.bc) 30 | state_grad_sq = state.gradient_squared(bc=self.bc) 31 | return -state_grad_sq / 2 - state_lap - state_lap2 32 | 33 | def _make_pde_rhs_numba(self, state): 34 | """Numba-compiled implementation of the PDE.""" 35 | gradient_squared = state.grid.make_operator("gradient_squared", bc=self.bc) 36 | laplace = state.grid.make_operator("laplace", bc=self.bc) 37 | 38 | @nb.njit 39 | def pde_rhs(data, t): 40 | return -0.5 * gradient_squared(data) - laplace(data + laplace(data)) 41 | 42 | return pde_rhs 43 | 44 | 45 | grid = UnitGrid([32, 32]) # generate grid 46 | state = ScalarField.random_uniform(grid) # generate initial condition 47 | 48 | eq = KuramotoSivashinskyPDE() # define the pde 49 | result = eq.solve(state, t_range=10, dt=0.01) 50 | result.plot() 51 | -------------------------------------------------------------------------------- /examples/advanced_pdes/pde_sir.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Custom PDE class: SIR model 3 | =========================== 4 | 5 | This example implements a `spatially coupled SIR model 6 | `_ with 7 | the following dynamics for the density of susceptible, infected, and recovered 8 | individuals: 9 | 10 | ..
math:: 11 | 12 | \partial_t s &= D \nabla^2 s - \beta is \\ 13 | \partial_t i &= D \nabla^2 i + \beta is - \gamma i \\ 14 | \partial_t r &= D \nabla^2 r + \gamma i 15 | 16 | Here, :math:`D` is the diffusivity, :math:`\beta` the infection rate, and 17 | :math:`\gamma` the recovery rate. 18 | """ 19 | 20 | from pde import FieldCollection, PDEBase, PlotTracker, ScalarField, UnitGrid 21 | 22 | 23 | class SIRPDE(PDEBase): 24 | """SIR-model with diffusive mobility.""" 25 | 26 | def __init__( 27 | self, beta=0.3, gamma=0.9, diffusivity=0.1, bc="auto_periodic_neumann" 28 | ): 29 | super().__init__() 30 | self.beta = beta # transmission rate 31 | self.gamma = gamma # recovery rate 32 | self.diffusivity = diffusivity # spatial mobility 33 | self.bc = bc # boundary condition 34 | 35 | def get_state(self, s, i): 36 | """Generate a suitable initial state.""" 37 | norm = (s + i).data.max() # maximal density 38 | if norm > 1: 39 | s /= norm 40 | i /= norm 41 | s.label = "Susceptible" 42 | i.label = "Infected" 43 | 44 | # create recovered field 45 | r = ScalarField(s.grid, data=1 - s - i, label="Recovered") 46 | return FieldCollection([s, i, r]) 47 | 48 | def evolution_rate(self, state, t=0): 49 | s, i, r = state 50 | diff = self.diffusivity 51 | ds_dt = diff * s.laplace(self.bc) - self.beta * i * s 52 | di_dt = diff * i.laplace(self.bc) + self.beta * i * s - self.gamma * i 53 | dr_dt = diff * r.laplace(self.bc) + self.gamma * i 54 | return FieldCollection([ds_dt, di_dt, dr_dt]) 55 | 56 | 57 | eq = SIRPDE(beta=2, gamma=0.1) 58 | 59 | # initialize state 60 | grid = UnitGrid([32, 32]) 61 | s = ScalarField(grid, 1) 62 | i = ScalarField(grid, 0) 63 | i.data[0, 0] = 1 64 | state = eq.get_state(s, i) 65 | 66 | # simulate the pde 67 | tracker = PlotTracker(interrupts=10, plot_args={"vmin": 0, "vmax": 1}) 68 | sol = eq.solve(state, t_range=50, dt=1e-2, tracker=["progress", tracker]) 69 | -------------------------------------------------------------------------------- 
/examples/advanced_pdes/post_step_hook.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post-step hook function 3 | ======================= 4 | 5 | Demonstrate the simple hook function in :class:`~pde.pdes.PDE`, which is called after 6 | each time step and may modify the state and abort the simulation. 7 | """ 8 | 9 | from pde import PDE, ScalarField, UnitGrid 10 | 11 | 12 | def post_step_hook(state_data, t): 13 | """Helper function called after every time step.""" 14 | state_data[24:40, 24:40] = 1 # set central region to given value 15 | 16 | if t > 1e3: 17 | raise StopIteration # abort simulation at given time 18 | 19 | 20 | eq = PDE({"c": "laplace(c)"}, post_step_hook=post_step_hook) 21 | state = ScalarField(UnitGrid([64, 64])) 22 | result = eq.solve(state, dt=0.1, t_range=1e4) 23 | result.plot() 24 | -------------------------------------------------------------------------------- /examples/advanced_pdes/post_step_hook_class.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post-step hook function in a custom class 3 | ========================================= 4 | 5 | The hook function created by :meth:`~pde.pdes.PDEBase.make_post_step_hook` is called 6 | after each time step. The function can modify the state, keep track of additional 7 | information, and abort the simulation. 
8 | """ 9 | 10 | from pde import PDEBase, ScalarField, UnitGrid 11 | 12 | 13 | class CustomPDE(PDEBase): 14 | def make_post_step_hook(self, state): 15 | """Create a hook function that is called after every time step.""" 16 | 17 | def post_step_hook(state_data, t, post_step_data): 18 | """Limit state to 1 and abort when the accumulated correction exceeds 400.""" 19 | i = state_data > 1 # get violating entries 20 | overshoot = (state_data[i] - 1).sum() # get total correction 21 | state_data[i] = 1 # limit data entries 22 | post_step_data += overshoot # accumulate total correction 23 | if post_step_data > 400: 24 | # Abort simulation when correction exceeds 400 25 | # Note that the `post_step_data` of the previous step will be returned. 26 | raise StopIteration 27 | 28 | return post_step_hook, 0.0 # hook function and initial value for data 29 | 30 | def evolution_rate(self, state, t=0): 31 | """Evaluate the right hand side of the evolution equation.""" 32 | return state.__class__(state.grid, data=1) # constant growth 33 | 34 | 35 | grid = UnitGrid([64, 64]) # generate grid 36 | state = ScalarField.random_uniform(grid, 0.0, 0.5) # generate initial condition 37 | 38 | eq = CustomPDE() 39 | result = eq.solve(state, dt=0.1, t_range=1e4) 40 | result.plot(title=f"Total correction={eq.diagnostics['solver']['post_step_data']}") 41 | -------------------------------------------------------------------------------- /examples/advanced_pdes/solver_comparison.py: -------------------------------------------------------------------------------- 1 | """ 2 | Solver comparison 3 | ================= 4 | 5 | This example shows how to set up solvers explicitly and how to extract 6 | diagnostic information.
7 | """ 8 | 9 | import pde 10 | 11 | # initialize the grid, an initial condition, and the PDE 12 | grid = pde.UnitGrid([32, 32]) 13 | field = pde.ScalarField.random_uniform(grid, -1, 1) 14 | eq = pde.DiffusionPDE() 15 | 16 | # try the explicit solver 17 | solver1 = pde.ExplicitSolver(eq) 18 | controller1 = pde.Controller(solver1, t_range=1, tracker=None) 19 | sol1 = controller1.run(field, dt=1e-3) 20 | sol1.label = "explicit solver" 21 | print("Diagnostic information from first run:") 22 | print(controller1.diagnostics) 23 | print() 24 | 25 | # try an explicit solver with adaptive time steps 26 | solver2 = pde.ExplicitSolver(eq, scheme="runge-kutta", adaptive=True) 27 | controller2 = pde.Controller(solver2, t_range=1, tracker=None) 28 | sol2 = controller2.run(field, dt=1e-3) 29 | sol2.label = "explicit, adaptive solver" 30 | print("Diagnostic information from second run:") 31 | print(controller2.diagnostics) 32 | print() 33 | 34 | # try the standard scipy solver 35 | solver3 = pde.ScipySolver(eq) 36 | controller3 = pde.Controller(solver3, t_range=1, tracker=None) 37 | sol3 = controller3.run(field) 38 | sol3.label = "scipy solver" 39 | print("Diagnostic information from third run:") 40 | print(controller3.diagnostics) 41 | print() 42 | 43 | # plot all three fields; the title lists the mean squared deviations of sol2 and sol3 from sol1 44 | deviation12 = ((sol1 - sol2) ** 2).average 45 | deviation13 = ((sol1 - sol3) ** 2).average 46 | title = f"Deviation: {deviation12:.2g}, {deviation13:.2g}" 47 | pde.FieldCollection([sol1, sol2, sol3]).plot(title=title) 48 | -------------------------------------------------------------------------------- /examples/fields/README.txt: -------------------------------------------------------------------------------- 1 | Grids and fields 2 | ---------------- 3 | 4 | These examples show how to define and manipulate discretized fields.
-------------------------------------------------------------------------------- /examples/fields/analyze_scalar_field.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualizing a scalar field 3 | ========================== 4 | 5 | This example displays methods for visualizing scalar fields. 6 | """ 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | from pde import CylindricalSymGrid, ScalarField 12 | 13 | # create a scalar field with some noise 14 | grid = CylindricalSymGrid(7, [0, 4 * np.pi], 64) 15 | data = ScalarField.from_expression(grid, "sin(z) * exp(-r / 3)") 16 | data += 0.05 * ScalarField.random_normal(grid) 17 | 18 | # manipulate the field 19 | smoothed = data.smooth() # Gaussian smoothing to get rid of the noise 20 | projected = data.project("r") # integrate along the radial direction 21 | sliced = smoothed.slice({"z": 1}) # slice the smoothed data 22 | 23 | # create four plots of the field and the modifications 24 | fig, axes = plt.subplots(nrows=2, ncols=2) 25 | data.plot(ax=axes[0, 0], title="Original field") 26 | smoothed.plot(ax=axes[1, 0], title="Smoothed field") 27 | projected.plot(ax=axes[0, 1], title="Projection on axial coordinate") 28 | sliced.plot(ax=axes[1, 1], title="Slice of smoothed field at $z=1$") 29 | plt.subplots_adjust(hspace=0.8) 30 | plt.show() 31 | -------------------------------------------------------------------------------- /examples/fields/finite_differences.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finite differences approximation 3 | ================================ 4 | 5 | This example displays various finite difference (FD) approximations of derivatives of 6 | simple harmonic functions.
7 | """ 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | from pde import CartesianGrid, ScalarField 13 | from pde.tools.expressions import evaluate 14 | 15 | # create two grids with different resolution to emphasize finite difference approximation 16 | grid_fine = CartesianGrid([(0, 2 * np.pi)], 256, periodic=True) 17 | grid_coarse = CartesianGrid([(0, 2 * np.pi)], 10, periodic=True) 18 | 19 | # create figure to present plots of the derivative 20 | fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True) 21 | 22 | # plot first derivatives of sin(x) 23 | f = ScalarField.from_expression(grid_coarse, "sin(x)") 24 | f_grad = f.gradient("periodic") # first derivative (from gradient vector field) 25 | ScalarField.from_expression(grid_fine, "cos(x)").plot( 26 | ax=axes[0, 0], label="Expected f'" 27 | ) 28 | f_grad.plot(ax=axes[0, 0], label="FD grad(f)", ls="", marker="o") 29 | plt.legend(frameon=True) 30 | plt.ylabel("") 31 | plt.xlabel("") 32 | plt.title(r"First derivative of $f(x) = \sin(x)$") 33 | 34 | # plot second derivatives of sin(x) 35 | f_laplace = f.laplace("periodic") # second derivative 36 | f_grad2 = f_grad.divergence("periodic") # second derivative using composition 37 | ScalarField.from_expression(grid_fine, "-sin(x)").plot( 38 | ax=axes[0, 1], label="Expected f''" 39 | ) 40 | f_laplace.plot(ax=axes[0, 1], label="FD laplace(f)", ls="", marker="o") 41 | f_grad2.plot(ax=axes[0, 1], label="FD div(grad(f))", ls="", marker="o") 42 | plt.legend(frameon=True) 43 | plt.xlabel("") 44 | plt.title(r"Second derivative of $f(x) = \sin(x)$") 45 | 46 | # plot first derivatives of sin(x)**2 47 | g_fine = ScalarField.from_expression(grid_fine, "sin(x)**2") 48 | g = g_fine.interpolate_to_grid(grid_coarse) 49 | expected = evaluate("2 * cos(x) * sin(x)", {"g": g_fine}) 50 | fd_1 = evaluate("d_dx(g)", {"g": g}) # first derivative (from directional derivative) 51 | expected.plot(ax=axes[1, 0], label="Expected g'") 52 | fd_1.plot(ax=axes[1, 
0], label="FD grad(g)", ls="", marker="o") 53 | plt.legend(frameon=True) 54 | plt.title(r"First derivative of $g(x) = \sin(x)^2$") 55 | 56 | # plot second derivatives of sin(x)**2 57 | expected = evaluate("2 * cos(2 * x)", {"g": g_fine}) 58 | fd_2 = evaluate("d2_dx2(g)", {"g": g}) # second derivative 59 | fd_11 = evaluate("d_dx(d_dx(g))", {"g": g}) # composition of first derivatives 60 | expected.plot(ax=axes[1, 1], label="Expected g''") 61 | fd_2.plot(ax=axes[1, 1], label="FD laplace(g)", ls="", marker="o") 62 | fd_11.plot(ax=axes[1, 1], label="FD div(grad(g))", ls="", marker="o") 63 | plt.legend(frameon=True) 64 | plt.title(r"Second derivative of $g(x) = \sin(x)^2$") 65 | 66 | # finalize plot 67 | plt.tight_layout() 68 | plt.show() 69 | -------------------------------------------------------------------------------- /examples/fields/plot_cylindrical_field.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Plotting a scalar field in cylindrical coordinates 3 | ================================================== 4 | 5 | This example shows how to initialize and visualize the scalar field 6 | :math:`u = \sqrt{z} \, \exp(-r^2)` in cylindrical coordinates. 7 | """ 8 | 9 | from pde import CylindricalSymGrid, ScalarField 10 | 11 | grid = CylindricalSymGrid(radius=3, bounds_z=[0, 4], shape=16) 12 | field = ScalarField.from_expression(grid, "sqrt(z) * exp(-r**2)") 13 | field.plot(title="Scalar field in cylindrical coordinates") 14 | -------------------------------------------------------------------------------- /examples/fields/plot_polar_grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot a polar grid 3 | ================= 4 | 5 | This example shows how to initialize a polar grid with a hole inside and angular 6 | symmetry, so that fields only depend on the radial coordinate. 
7 | """ 8 | 9 | from pde import PolarSymGrid 10 | 11 | grid = PolarSymGrid((2, 5), 8) 12 | grid.plot(title=f"Area={grid.volume:.5g}") 13 | -------------------------------------------------------------------------------- /examples/fields/plot_vector_field.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Plotting a vector field 3 | ======================= 4 | 5 | This example shows how to initialize and visualize the vector field 6 | :math:`\boldsymbol u = \bigl(\sin(x), \cos(x)\bigr)`. 7 | """ 8 | 9 | from pde import CartesianGrid, VectorField 10 | 11 | grid = CartesianGrid([[-2, 2], [-2, 2]], 32) 12 | field = VectorField.from_expression(grid, ["sin(x)", "cos(x)"]) 13 | field.plot(method="streamplot", title="Stream plot") 14 | -------------------------------------------------------------------------------- /examples/fields/random_fields.py: -------------------------------------------------------------------------------- 1 | """ 2 | Random scalar fields 3 | ==================== 4 | 5 | This example showcases several random fields 6 | """ 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | from pde import ScalarField, UnitGrid 12 | 13 | # initialize grid and plot figure 14 | grid = UnitGrid([256, 256], periodic=True) 15 | fig, axes = plt.subplots(nrows=2, ncols=2) 16 | 17 | f1 = ScalarField.random_uniform(grid, -2.5, 2.5) 18 | f1.plot(ax=axes[0, 0], title="Uniform, uncorrelated") 19 | 20 | f2 = ScalarField.random_normal(grid, correlation="power law", exponent=-6) 21 | f2.plot(ax=axes[0, 1], title="Gaussian, power-law correlated") 22 | 23 | f3 = ScalarField.random_normal(grid, correlation="cosine", length_scale=30) 24 | f3.plot(ax=axes[1, 0], title="Gaussian, cosine correlated") 25 | 26 | f4 = ScalarField.random_harmonic(grid, modes=4) 27 | f4.plot(ax=axes[1, 1], title="Combined harmonic functions") 28 | 29 | plt.subplots_adjust(hspace=0.8) 30 | plt.show() 31 | 
-------------------------------------------------------------------------------- /examples/fields/show_3d_field_interactively.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualizing a 3d field interactively 3 | ==================================== 4 | 5 | This example demonstrates how to display 3d data interactively using the 6 | `napari `_ viewer. 7 | """ 8 | 9 | import numpy as np 10 | 11 | from pde import CartesianGrid, ScalarField 12 | 13 | # create a scalar field with some noise 14 | grid = CartesianGrid([[0, 2 * np.pi]] * 3, 64) 15 | data = ScalarField.from_expression(grid, "(cos(2 * x) * sin(3 * y) + cos(2 * z))**2") 16 | data += ScalarField.random_normal(grid, std=0.1) 17 | data.label = "3D Field" 18 | 19 | data.plot_interactive() 20 | -------------------------------------------------------------------------------- /examples/jupyter/.gitignore: -------------------------------------------------------------------------------- 1 | *.hdf -------------------------------------------------------------------------------- /examples/jupyter/Discretized Fields.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "\n", 11 | "sys.path.append(\"../..\") # add the pde package to the python path\n", 12 | "import pde" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# define a simple grid\n", 22 | "grid = pde.UnitGrid([32, 32])\n", 23 | "grid.plot()" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# define scalar field, initially filled with zeros\n", 33 | "field = pde.ScalarField(grid)\n", 34 | "field.average" 35 | ] 36 | }, 37 | { 38 | 
"cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# do computations on the field\n", 44 | "field += 1\n", 45 | "field.average" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# define a scalar field initialized with random colored noise and plot it\n", 55 | "scalar = pde.ScalarField.random_normal(grid, correlation=\"power law\", exponent=-2)\n", 56 | "scalar.plot(colorbar=True);" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# apply operators to the field\n", 66 | "smoothed = scalar.smooth(1)\n", 67 | "laplace = smoothed.laplace(bc=\"auto_periodic_neumann\")\n", 68 | "laplace.plot(colorbar=True);" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# initialize a vector field and plot it\n", 78 | "vector = pde.VectorField.random_normal(grid, correlation=\"power law\", exponent=-4)\n", 79 | "vector.plot(method=\"streamplot\");" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# plot the first component of the vector field\n", 89 | "vector[0].plot()" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "Python 3 (ipykernel)", 103 | "language": "python", 104 | "name": "python3" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.12.7" 
117 | }, 118 | "toc": { 119 | "base_numbering": 1, 120 | "nav_menu": {}, 121 | "number_sections": true, 122 | "sideBar": true, 123 | "skip_h1_title": false, 124 | "title_cell": "Table of Contents", 125 | "title_sidebar": "Contents", 126 | "toc_cell": false, 127 | "toc_position": {}, 128 | "toc_section_display": true, 129 | "toc_window_display": false 130 | } 131 | }, 132 | "nbformat": 4, 133 | "nbformat_minor": 4 134 | } 135 | -------------------------------------------------------------------------------- /examples/output/README.txt: -------------------------------------------------------------------------------- 1 | Output and analysis 2 | ------------------- 3 | 4 | These examples demonstrate how to store and analyze output. -------------------------------------------------------------------------------- /examples/output/logarithmic_kymograph.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Logarithmic kymograph 3 | ===================== 4 | 5 | This example demonstrates a space-time plot with a logarithmic time axis, which is useful 6 | to analyze coarsening processes. Here, we use :func:`utilitiez.densityplot` for plotting. 
7 | """ 8 | 9 | import matplotlib.pyplot as plt 10 | from utilitiez import densityplot 11 | 12 | import pde 13 | 14 | # define grid, initial field, and the PDE 15 | grid = pde.UnitGrid([128]) 16 | field = pde.ScalarField.random_uniform(grid, -0.1, 0.1) 17 | eq = pde.CahnHilliardPDE(interface_width=2) 18 | 19 | # run the simulation and store data in logarithmically spaced time intervals 20 | storage = pde.MemoryStorage() 21 | res = eq.solve( 22 | field, t_range=1e5, adaptive=True, tracker=[storage.tracker("geometric(10, 1.1)")] 23 | ) 24 | 25 | # create the density plot, which detects the logarithmically scaled time 26 | densityplot(storage.data, storage.times, grid.axes_coords[0]) 27 | plt.xlabel("Time") 28 | plt.ylabel("Space") 29 | -------------------------------------------------------------------------------- /examples/output/make_movie_live.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create a movie live 3 | =================== 4 | 5 | This example shows how to create a movie while running the simulation. Making movies 6 | requires that `ffmpeg` is installed in a standard location. 7 | """ 8 | 9 | from pde import DiffusionPDE, PlotTracker, ScalarField, UnitGrid 10 | 11 | grid = UnitGrid([16, 16]) # generate grid 12 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition 13 | 14 | tracker = PlotTracker(movie="diffusion.mov") # create movie tracker 15 | 16 | eq = DiffusionPDE() # define the physics 17 | eq.solve(state, t_range=2, dt=0.005, tracker=tracker) 18 | -------------------------------------------------------------------------------- /examples/output/make_movie_storage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create a movie from a storage 3 | ============================= 4 | 5 | This example shows how to create a movie from data stored during a simulation. 
Making 6 | movies requires that `ffmpeg` is installed in a standard location. 7 | """ 8 | 9 | from pde import DiffusionPDE, MemoryStorage, ScalarField, UnitGrid, movie_scalar 10 | 11 | grid = UnitGrid([16, 16]) # generate grid 12 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition 13 | 14 | storage = MemoryStorage() # create storage 15 | tracker = storage.tracker(interrupts=1) # create associated tracker 16 | 17 | eq = DiffusionPDE() # define the physics 18 | eq.solve(state, t_range=2, dt=0.005, tracker=tracker) 19 | 20 | # create movie from stored data 21 | movie_scalar(storage, "diffusion.mov") 22 | -------------------------------------------------------------------------------- /examples/output/py_modelrunner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -m modelrunner.run --output data.hdf5 --method foreground 2 | """ 3 | Using :mod:`modelrunner` 4 | ======================== 5 | 6 | This example shows how `py-pde` can be combined with :mod:`modelrunner`. The magic first 7 | line allows running this example as a script using :code:`./py_modelrunner.py`, which 8 | runs the function defined below and stores all results in the file `data.hdf5`. 9 | 10 | The results can be read by the following code 11 | 12 | .. code-block:: python 13 | 14 | from modelrunner import Result 15 | 16 | r = Result.from_file("data.hdf5") 17 | r.result.plot() # plots the final state 18 | r.storage["trajectory"] # allows accessing the stored trajectory 19 | """ 20 | 21 | from pde import DiffusionPDE, ModelrunnerStorage, ScalarField, UnitGrid 22 | 23 | 24 | def run(storage, diffusivity=0.1): 25 | """Function that runs the model. 
26 | 27 | Args: 28 | storage (:mod:`~modelrunner.storage.group.StorageGroup`): 29 | Automatically supplied storage, to which extra data can be written 30 | diffusivity (float): 31 | Example for a parameter used in the model 32 | """ 33 | # initialize the model 34 | state = ScalarField.random_uniform(UnitGrid([64, 64]), 0.2, 0.3) 35 | storage["initia_state"] = state # store initial state with simulation 36 | eq = DiffusionPDE(diffusivity=diffusivity) 37 | 38 | # store trajectory in storage 39 | tracker = ModelrunnerStorage(storage, loc="trajectory").tracker(1) 40 | final_state = eq.solve(state, t_range=5, tracker=tracker) 41 | 42 | # returns the final state as the result, which will be stored by modelrunner 43 | return final_state 44 | -------------------------------------------------------------------------------- /examples/output/storages.py: -------------------------------------------------------------------------------- 1 | """ 2 | Storage examples 3 | ================ 4 | 5 | This example shows how to use :mod:`~pde.storage` to store data persistently. 
6 | """ 7 | 8 | from pde import AllenCahnPDE, FileStorage, MovieStorage, ScalarField, UnitGrid 9 | 10 | # initialize the model 11 | state = ScalarField.random_uniform(UnitGrid([128, 128]), -0.01, 0.01) 12 | eq = AllenCahnPDE() 13 | 14 | # initialize empty storages 15 | file_write = FileStorage("allen_cahn.hdf") 16 | movie_write = MovieStorage("allen_cahn.avi", vmin=-1, vmax=1) 17 | 18 | # store trajectory in storage 19 | final_state = eq.solve( 20 | state, 21 | t_range=100, 22 | adaptive=True, 23 | tracker=[file_write.tracker(2), movie_write.tracker(1)], 24 | ) 25 | 26 | # read storage and plot last frame 27 | movie_read = MovieStorage("allen_cahn.avi") 28 | movie_read[-1].plot() 29 | -------------------------------------------------------------------------------- /examples/output/tracker_interactive.py: -------------------------------------------------------------------------------- 1 | """ 2 | Using an interactive tracker 3 | ============================ 4 | 5 | This example illustrates how a simulation can be analyzed live using the 6 | `napari `_ viewer. 7 | """ 8 | 9 | import pde 10 | 11 | 12 | def main(): 13 | grid = pde.UnitGrid([64, 64]) 14 | field = pde.ScalarField.random_uniform(grid, label="Density") 15 | 16 | eq = pde.CahnHilliardPDE() 17 | eq.solve(field, t_range=1e3, dt=1e-3, tracker=["progress", "interactive"]) 18 | 19 | 20 | if __name__ == "__main__": 21 | # this safeguard is required since the interactive tracker uses multiprocessing 22 | main() 23 | -------------------------------------------------------------------------------- /examples/output/trackers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Using simulation trackers 3 | ========================= 4 | 5 | This example illustrates how trackers can be used to analyze simulations. 
6 | """ 7 | 8 | import pde 9 | 10 | grid = pde.UnitGrid([32, 32]) # generate grid 11 | state = pde.ScalarField.random_uniform(grid) # generate initial condition 12 | 13 | storage = pde.MemoryStorage() 14 | 15 | trackers = [ 16 | "progress", # show progress bar during simulation 17 | "steady_state", # abort when steady state is reached 18 | storage.tracker(interrupts=1), # store data every simulation time unit 19 | pde.PlotTracker(show=True), # show images during simulation 20 | # print some output every 5 real seconds: 21 | pde.PrintTracker(interrupts=pde.RealtimeInterrupts(duration=5)), 22 | ] 23 | 24 | eq = pde.DiffusionPDE(0.1) # define the PDE 25 | eq.solve(state, 3, dt=0.1, tracker=trackers) 26 | 27 | for field in storage: 28 | print(field.integral) 29 | -------------------------------------------------------------------------------- /examples/output/trajectory_io.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Writing and reading trajectory data 3 | =================================== 4 | 5 | This example illustrates how to store intermediate data to a file for later 6 | post-processing. The storage frequency is an argument to the tracker. 
7 | """ 8 | 9 | from tempfile import NamedTemporaryFile 10 | 11 | import pde 12 | 13 | # define grid, state and pde 14 | grid = pde.UnitGrid([32]) 15 | state = pde.FieldCollection( 16 | [pde.ScalarField.random_uniform(grid), pde.VectorField.random_uniform(grid)] 17 | ) 18 | eq = pde.PDE({"s": "-0.1 * s", "v": "-v"}) 19 | 20 | # get a temporary file to write data to 21 | with NamedTemporaryFile(suffix=".hdf5") as path: 22 | # run a simulation and write the results 23 | writer = pde.FileStorage(path.name, write_mode="truncate") 24 | eq.solve(state, t_range=32, dt=0.01, tracker=writer.tracker(1)) 25 | 26 | # read the simulation back in again 27 | reader = pde.FileStorage(path.name, write_mode="read_only") 28 | pde.plot_kymographs(reader) 29 | -------------------------------------------------------------------------------- /examples/simple_pdes/README.txt: -------------------------------------------------------------------------------- 1 | Simple PDEs 2 | ----------- 3 | 4 | These examples demonstrate basic usage of the package to solve PDEs. -------------------------------------------------------------------------------- /examples/simple_pdes/boundary_conditions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setting boundary conditions 3 | =========================== 4 | 5 | This example shows how different boundary conditions can be specified. 
6 | """ 7 | 8 | from pde import DiffusionPDE, ScalarField, UnitGrid 9 | 10 | grid = UnitGrid([32, 32], periodic=[False, True]) # generate grid 11 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition 12 | 13 | # set boundary conditions `bc` for all axes 14 | eq = DiffusionPDE( 15 | bc={"x-": {"derivative": 0.1}, "x+": {"value": "sin(y / 2)"}, "y": "periodic"} 16 | ) 17 | 18 | result = eq.solve(state, t_range=10, dt=0.005) 19 | result.plot() 20 | -------------------------------------------------------------------------------- /examples/simple_pdes/cartesian_grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Diffusion on a Cartesian grid 3 | ============================= 4 | 5 | This example shows how to solve the diffusion equation on a Cartesian grid. 6 | """ 7 | 8 | from pde import CartesianGrid, DiffusionPDE, ScalarField 9 | 10 | grid = CartesianGrid([[-1, 1], [0, 2]], [30, 16]) # generate grid 11 | state = ScalarField(grid) # generate initial condition 12 | state.insert([0, 1], 1) 13 | 14 | eq = DiffusionPDE(0.1) # define the pde 15 | result = eq.solve(state, t_range=1, dt=0.01) 16 | result.plot(cmap="magma") 17 | -------------------------------------------------------------------------------- /examples/simple_pdes/laplace_eq_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Solving Laplace's equation in 2d 3 | ================================ 4 | 5 | This example shows how to solve a 2d Laplace equation with spatially varying 6 | boundary conditions. 
7 | """ 8 | 9 | import numpy as np 10 | 11 | from pde import CartesianGrid, solve_laplace_equation 12 | 13 | grid = CartesianGrid([[0, 2 * np.pi], [0, 2 * np.pi]], 64) 14 | bcs = {"x": {"value": "sin(y)"}, "y": {"value": "sin(x)"}} 15 | 16 | res = solve_laplace_equation(grid, bcs) 17 | res.plot() 18 | -------------------------------------------------------------------------------- /examples/simple_pdes/pde_1d_expression.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 1D problem - Using `PDE` class 3 | ============================== 4 | 5 | This example implements a PDE that is only defined in one dimension. 6 | Here, we chose the `Korteweg-de Vries equation 7 | `_, given by 8 | 9 | .. math:: 10 | \partial_t \phi = 6 \phi \partial_x \phi - \partial_x^3 \phi 11 | 12 | which we implement using the :class:`~pde.pdes.pde.PDE`. 13 | """ 14 | 15 | from math import pi 16 | 17 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph 18 | 19 | # initialize the equation and the space 20 | eq = PDE({"φ": "6 * φ * d_dx(φ) - laplace(d_dx(φ))"}) 21 | grid = CartesianGrid([[0, 2 * pi]], [32], periodic=True) 22 | state = ScalarField.from_expression(grid, "sin(x)") 23 | 24 | # solve the equation and store the trajectory 25 | storage = MemoryStorage() 26 | eq.solve(state, t_range=3, solver="scipy", tracker=storage.tracker(0.1)) 27 | 28 | # plot the trajectory as a space-time plot 29 | plot_kymograph(storage) 30 | -------------------------------------------------------------------------------- /examples/simple_pdes/pde_brusselator_expression.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Brusselator - Using the `PDE` class 3 | =================================== 4 | 5 | This example uses the :class:`~pde.pdes.pde.PDE` class to implement the 6 | `Brusselator `_ with spatial 7 | coupling, 8 | 9 | .. 
math:: 10 | 11 | \partial_t u &= D_0 \nabla^2 u + a - (1 + b) u + v u^2 \\ 12 | \partial_t v &= D_1 \nabla^2 v + b u - v u^2 13 | 14 | Here, :math:`D_0` and :math:`D_1` are the respective diffusivities and the 15 | parameters :math:`a` and :math:`b` are related to reaction rates. 16 | 17 | Note that the same result can also be achieved with a 18 | :doc:`full implementation of a custom class <../advanced_pdes/pde_brusselator_class>`, 19 | which allows for more flexibility at the cost of code complexity. 20 | """ 21 | 22 | from pde import PDE, FieldCollection, PlotTracker, ScalarField, UnitGrid 23 | 24 | # define the PDE 25 | a, b = 1, 3 26 | d0, d1 = 1, 0.1 27 | eq = PDE( 28 | { 29 | "u": f"{d0} * laplace(u) + {a} - ({b} + 1) * u + u**2 * v", 30 | "v": f"{d1} * laplace(v) + {b} * u - u**2 * v", 31 | } 32 | ) 33 | 34 | # initialize state 35 | grid = UnitGrid([64, 64]) 36 | u = ScalarField(grid, a, label="Field $u$") 37 | v = b / a + 0.1 * ScalarField.random_normal(grid, label="Field $v$") 38 | state = FieldCollection([u, v]) 39 | 40 | # simulate the pde 41 | tracker = PlotTracker(interrupts=1, plot_args={"vmin": 0, "vmax": 5}) 42 | sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker) 43 | -------------------------------------------------------------------------------- /examples/simple_pdes/pde_custom_expression.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Kuramoto-Sivashinsky - Using `PDE` class 3 | ======================================== 4 | 5 | This example implements a scalar PDE using the :class:`~pde.pdes.pde.PDE`. We here 6 | consider the `Kuramoto–Sivashinsky equation 7 | `_, which for instance 8 | describes the dynamics of flame fronts: 9 | 10 | ..
math:: 11 | \partial_t u = -\frac12 |\nabla u|^2 - \nabla^2 u - \nabla^4 u 12 | """ 13 | 14 | from pde import PDE, ScalarField, UnitGrid 15 | 16 | grid = UnitGrid([32, 32]) # generate grid 17 | state = ScalarField.random_uniform(grid) # generate initial condition 18 | 19 | eq = PDE({"u": "-gradient_squared(u) / 2 - laplace(u + laplace(u))"}) # define the pde 20 | result = eq.solve(state, t_range=10, dt=0.01) 21 | result.plot() 22 | -------------------------------------------------------------------------------- /examples/simple_pdes/pde_heterogeneous_diffusion.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Diffusion equation with spatial dependence 3 | ========================================== 4 | 5 | This example solves the 6 | `Diffusion equation `_ with a 7 | heterogeneous diffusivity: 8 | 9 | .. math:: 10 | \partial_t c = \nabla \cdot \bigl( D(\boldsymbol r) \nabla c \bigr) 11 | 12 | using the :class:`~pde.pdes.pde.PDE` class. In particular, we consider 13 | :math:`D(x) = 1.01 + \tanh(x)`, which gives a low diffusivity on the left side of the 14 | domain. 15 | 16 | Note that the naive implementation, 17 | :code:`PDE({"c": "divergence((1.01 + tanh(x)) * gradient(c))"})`, has numerical 18 | instabilities. This is because two finite difference approximations are nested. To 19 | arrive at a more stable numerical scheme, it is advisable to expand the divergence, 20 | 21 | .. math:: 22 | \partial_t c = D \nabla^2 c + \nabla D \cdot
\nabla c 23 | """ 24 | 25 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph 26 | 27 | # Expanded definition of the PDE 28 | diffusivity = "1.01 + tanh(x)" 29 | term_1 = f"({diffusivity}) * laplace(c)" 30 | term_2 = f"dot(gradient({diffusivity}), gradient(c))" 31 | eq = PDE({"c": f"{term_1} + {term_2}"}, bc={"value": 0}) 32 | 33 | 34 | grid = CartesianGrid([[-5, 5]], 64) # generate grid 35 | field = ScalarField(grid, 1) # generate initial condition 36 | 37 | storage = MemoryStorage() # store intermediate information of the simulation 38 | res = eq.solve(field, 100, dt=1e-3, tracker=storage.tracker(1)) # solve the PDE 39 | 40 | plot_kymograph(storage) # visualize the result in a space-time plot 41 | -------------------------------------------------------------------------------- /examples/simple_pdes/pde_schroedinger.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Schrödinger's Equation 3 | ====================== 4 | 5 | This example implements a complex PDE using the :class:`~pde.pdes.pde.PDE`. We here 6 | chose the `Schrödinger equation `_ 7 | without a spatial potential in non-dimensional form: 8 | 9 | .. math:: 10 | i \partial_t \psi = -\nabla^2 \psi 11 | 12 | Note that the example imposes Neumann conditions at the wall, so the wave packet is 13 | expected to reflect off the wall. 
14 | """ 15 | 16 | from math import sqrt 17 | 18 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph 19 | 20 | grid = CartesianGrid([[0, 20]], 128, periodic=False) # generate grid 21 | 22 | # create a (normalized) wave packet with a certain form as an initial condition 23 | initial_state = ScalarField.from_expression(grid, "exp(I * 5 * x) * exp(-(x - 10)**2)") 24 | initial_state /= sqrt(initial_state.to_scalar("norm_squared").integral.real) 25 | 26 | eq = PDE({"ψ": "I * laplace(ψ)"}) # define the pde 27 | 28 | # solve the pde and store intermediate data 29 | storage = MemoryStorage() 30 | eq.solve(initial_state, t_range=2.5, dt=1e-5, tracker=[storage.tracker(0.02)]) 31 | 32 | # visualize the results as a space-time plot 33 | plot_kymograph(storage, scalar="norm_squared") 34 | -------------------------------------------------------------------------------- /examples/simple_pdes/poisson_eq_1d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Solving Poisson's equation in 1d 3 | ================================ 4 | 5 | This example shows how to solve a 1d Poisson equation with boundary conditions. 6 | """ 7 | 8 | from pde import CartesianGrid, ScalarField, solve_poisson_equation 9 | 10 | grid = CartesianGrid([[0, 1]], 32, periodic=False) 11 | field = ScalarField(grid, 1) 12 | result = solve_poisson_equation(field, bc={"x-": {"value": 0}, "x+": {"derivative": 1}}) 13 | 14 | result.plot() 15 | -------------------------------------------------------------------------------- /examples/simple_pdes/simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple diffusion equation 3 | ========================= 4 | 5 | This example solves a simple diffusion equation in two dimensions. 
6 | """ 7 | 8 | from pde import DiffusionPDE, ScalarField, UnitGrid 9 | 10 | grid = UnitGrid([64, 64]) # generate grid 11 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition 12 | 13 | eq = DiffusionPDE(diffusivity=0.1) # define the pde 14 | result = eq.solve(state, t_range=10) 15 | result.plot() 16 | -------------------------------------------------------------------------------- /examples/simple_pdes/spherical_grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spherically symmetric PDE 3 | ========================= 4 | 5 | This example illustrates how to solve a PDE in a spherically symmetric geometry. 6 | """ 7 | 8 | from pde import DiffusionPDE, ScalarField, SphericalSymGrid 9 | 10 | grid = SphericalSymGrid(radius=[1, 5], shape=128) # generate grid 11 | state = ScalarField.random_uniform(grid) # generate initial condition 12 | 13 | eq = DiffusionPDE(0.1) # define the PDE 14 | result = eq.solve(state, t_range=0.1, dt=0.001) 15 | 16 | result.plot(kind="image") 17 | -------------------------------------------------------------------------------- /examples/simple_pdes/stochastic_simulation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stochastic simulation 3 | ===================== 4 | 5 | This example illustrates how a stochastic simulation can be done. 
6 | """ 7 | 8 | from pde import KPZInterfacePDE, MemoryStorage, ScalarField, UnitGrid, plot_kymograph 9 | 10 | grid = UnitGrid([64]) # generate grid 11 | state = ScalarField.random_harmonic(grid) # generate initial condition 12 | 13 | eq = KPZInterfacePDE(noise=1) # define the SDE 14 | storage = MemoryStorage() 15 | eq.solve(state, t_range=10, dt=0.01, tracker=storage.tracker(0.5)) 16 | plot_kymograph(storage) 17 | -------------------------------------------------------------------------------- /examples/simple_pdes/time_dependent_bcs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Time-dependent boundary conditions 3 | ================================== 4 | 5 | This example solves a simple diffusion equation in one dimension with time-dependent 6 | boundary conditions. 7 | """ 8 | 9 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph 10 | 11 | grid = CartesianGrid([[0, 10]], [64]) # generate grid 12 | state = ScalarField(grid) # generate initial condition 13 | 14 | eq = PDE({"c": "laplace(c)"}, bc={"value_expression": "sin(t)"}) 15 | 16 | storage = MemoryStorage() 17 | eq.solve(state, t_range=20, dt=1e-4, tracker=storage.tracker(0.1)) 18 | 19 | # plot the trajectory as a space-time plot 20 | plot_kymograph(storage) 21 | -------------------------------------------------------------------------------- /pde/__init__.py: -------------------------------------------------------------------------------- 1 | """The py-pde package provides classes and methods for solving partial differential 2 | equations.""" 3 | 4 | # determine the package version 5 | try: 6 | # try reading the version from the automatically generated module 7 | from ._version import __version__ 8 | except ImportError: 9 | # fall back to the metadata of the installed package 10 | from importlib.metadata import PackageNotFoundError, version 11 | 12 | try: 13 | __version__ = version("pde") 14 | except PackageNotFoundError: 15 | #
package is not installed, so we cannot determine any version 16 | __version__ = "unknown" 17 | del PackageNotFoundError, version # clean name space 18 | 19 | # initialize the configuration 20 | from .tools.config import Config, environment 21 | 22 | config = Config() # initialize the default configuration 23 | 24 | import contextlib 25 | 26 | # import all other modules that should occupy the main name space 27 | from .fields import * 28 | from .grids import * 29 | from .pdes import * 30 | from .solvers import * 31 | from .storage import * 32 | from .tools.parameters import Parameter 33 | from .trackers import * 34 | from .visualization import * 35 | 36 | with contextlib.suppress(ImportError): 37 | from .tools.modelrunner import * 38 | 39 | del contextlib, Config # clean name space 40 | -------------------------------------------------------------------------------- /pde/fields/__init__.py: -------------------------------------------------------------------------------- 1 | """Defines fields, which contain the actual data stored on a discrete grid. 2 | 3 | .. autosummary:: 4 | :nosignatures: 5 | 6 | ~scalar.ScalarField 7 | ~vectorial.VectorField 8 | ~tensorial.Tensor2Field 9 | ~collection.FieldCollection 10 | 11 | 12 | Inheritance structure of the classes: 13 | 14 | 15 | .. inheritance-diagram:: scalar.ScalarField vectorial.VectorField tensorial.Tensor2Field 16 | collection.FieldCollection 17 | :parts: 1 18 | 19 | The details of the classes are explained below: 20 | 21 | .. codeauthor:: David Zwicker 22 | """ 23 | 24 | from .base import FieldBase 25 | from .collection import FieldCollection 26 | from .datafield_base import DataFieldBase 27 | from .scalar import ScalarField 28 | from .tensorial import Tensor2Field 29 | from .vectorial import VectorField 30 | 31 | # DataFieldBase has been moved to its own module on 2024-04-18. 32 | # Add it back to `base` for the time being, so dependent code doesn't break 33 | from . 
import base # isort:skip 34 | 35 | base.DataFieldBase = DataFieldBase # type: ignore 36 | del base # clean namespaces 37 | -------------------------------------------------------------------------------- /pde/grids/__init__.py: -------------------------------------------------------------------------------- 1 | """Grids define the domains on which PDEs will be solved. In particular, symmetries, 2 | periodicities, and the discretizations are defined by the underlying grid. 3 | 4 | We only consider regular, orthogonal grids, which are constructed from orthogonal 5 | coordinate systems with equidistant discretizations along each axis. The dimension of 6 | the space that the grid describes is given by the attribute :attr:`dim`. Cartesian 7 | coordinates can be mapped to grid coordinates and the corresponding discretization cells 8 | using the method :meth:`transform`. 9 | 10 | .. autosummary:: 11 | :nosignatures: 12 | 13 | ~cartesian.UnitGrid 14 | ~cartesian.CartesianGrid 15 | ~spherical.PolarSymGrid 16 | ~spherical.SphericalSymGrid 17 | ~cylindrical.CylindricalSymGrid 18 | 19 | Inheritance structure of the classes: 20 | 21 | .. inheritance-diagram:: cartesian.UnitGrid cartesian.CartesianGrid 22 | spherical.PolarSymGrid spherical.SphericalSymGrid cylindrical.CylindricalSymGrid 23 | :parts: 1 24 | 25 | .. codeauthor:: David Zwicker 26 | """ 27 | 28 | from . 
import operators # import all operator modules to register them 29 | from .base import registered_operators 30 | from .boundaries import * 31 | from .cartesian import CartesianGrid, UnitGrid 32 | from .cylindrical import CylindricalSymGrid 33 | from .spherical import PolarSymGrid, SphericalSymGrid 34 | 35 | del operators # remove the name from the namespace 36 | -------------------------------------------------------------------------------- /pde/grids/coordinates/__init__.py: -------------------------------------------------------------------------------- 1 | """Package collecting classes representing orthonormal coordinate systems. 2 | 3 | .. autosummary:: 4 | :nosignatures: 5 | 6 | ~bipolar.BipolarCoordinates 7 | ~bispherical.BisphericalCoordinates 8 | ~cartesian.CartesianCoordinates 9 | ~cylindrical.CylindricalCoordinates 10 | ~polar.PolarCoordinates 11 | ~spherical.SphericalCoordinates 12 | """ 13 | 14 | from .base import CoordinatesBase, DimensionError 15 | from .bipolar import BipolarCoordinates 16 | from .bispherical import BisphericalCoordinates 17 | from .cartesian import CartesianCoordinates 18 | from .cylindrical import CylindricalCoordinates 19 | from .polar import PolarCoordinates 20 | from .spherical import SphericalCoordinates 21 | -------------------------------------------------------------------------------- /pde/grids/coordinates/bipolar.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import numpy as np 8 | from numpy.typing import ArrayLike 9 | 10 | from .base import CoordinatesBase 11 | 12 | 13 | class BipolarCoordinates(CoordinatesBase): 14 | """2-dimensional bipolar coordinates.""" 15 | 16 | dim = 2 17 | axes = ["σ", "τ"] 18 | _axes_alt = {"σ": ["sigma"], "τ": ["tau"]} 19 | coordinate_limits = [(0, 2 * np.pi), (-np.inf, np.inf)] 20 | 21 | def __init__(self, scale_parameter: float = 1): 22 | super().__init__() 23 | if scale_parameter <= 0: 24 | raise ValueError("Scale parameter must be positive") 25 | self.scale_parameter = scale_parameter 26 | 27 | def __repr__(self) -> str: 28 | """Return instance as string.""" 29 | return f"{self.__class__.__name__}(scale_parameter={self.scale_parameter})" 30 | 31 | def __eq__(self, other): 32 | return ( 33 | self.__class__ is other.__class__ 34 | and self.scale_parameter == other.scale_parameter 35 | ) 36 | 37 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray: 38 | σ, τ = points[..., 0], points[..., 1] 39 | denom = np.cosh(τ) - np.cos(σ) 40 | x = self.scale_parameter * np.sinh(τ) / denom 41 | y = self.scale_parameter * np.sin(σ) / denom 42 | return np.stack((x, y), axis=-1) # type: ignore 43 | 44 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray: 45 | x, y = points[..., 0], points[..., 1] 46 | a = self.scale_parameter 47 | h2 = x**2 + y**2 48 | denom = a**2 - h2 + np.sqrt((a**2 - h2) ** 2 + 4 * a**2 * y**2) 49 | σ = np.mod(np.pi - 2 * np.arctan2(2 * a * y, denom), 2 * np.pi) 50 | τ = 0.5 * np.log(((x + a) ** 2 + y**2) / ((x - a) ** 2 + y**2)) 51 | return np.stack((σ, τ), axis=-1) # type: ignore 52 | 53 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray: 54 | σ, τ = points[..., 0], points[..., 1] 55 | 56 | sinσ = np.sin(σ) 57 | cosσ = np.cos(σ) 58 | sinhτ = np.sinh(τ) 59 | coshτ = np.cosh(τ) 60 | factor = self.scale_parameter * (cosσ - coshτ) ** -2 61 | 62 | return factor * np.array( # 
type: ignore 63 | [ 64 | [-sinσ * sinhτ, 1 - cosσ * coshτ], 65 | [cosσ * coshτ - 1, -sinσ * sinhτ], 66 | ] 67 | ) 68 | 69 | def _volume_factor(self, points: np.ndarray) -> ArrayLike: 70 | σ, τ = points[..., 0], points[..., 1] 71 | return self.scale_parameter**2 * (np.cosh(τ) - np.cos(σ)) ** -2 72 | 73 | def _scale_factors(self, points: np.ndarray) -> np.ndarray: 74 | σ, τ = points[..., 0], points[..., 1] 75 | sf = self.scale_parameter / (np.cosh(τ) - np.cos(σ)) 76 | return np.array([sf, sf]) # type: ignore 77 | 78 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray: 79 | σ, τ = points[..., 0], points[..., 1] 80 | 81 | sinσ = np.sin(σ) 82 | cosσ = np.cos(σ) 83 | sinhτ = np.sinh(τ) 84 | coshτ = np.cosh(τ) 85 | factor = 1 / (cosσ - coshτ) 86 | 87 | return factor * np.array( # type: ignore 88 | [ 89 | [sinσ * sinhτ, 1 - cosσ * coshτ], 90 | [cosσ * coshτ - 1, sinσ * sinhτ], 91 | ] 92 | ) 93 | -------------------------------------------------------------------------------- /pde/grids/coordinates/cartesian.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import numpy as np 8 | from numpy.typing import ArrayLike 9 | 10 | from .base import CoordinatesBase 11 | 12 | 13 | class CartesianCoordinates(CoordinatesBase): 14 | """N-dimensional Cartesian coordinates.""" 15 | 16 | _objs: dict[int, CartesianCoordinates] = {} 17 | 18 | def __new__(cls, dim: int): 19 | # cache the instances for each dimension 20 | if dim not in cls._objs: 21 | cls._objs[dim] = super().__new__(cls) 22 | return cls._objs[dim] 23 | 24 | def __getnewargs__(self): 25 | return (self.dim,) 26 | 27 | def __init__(self, dim: int): 28 | """ 29 | Args: 30 | dim (int): 31 | Dimension of the Cartesian coordinate system 32 | """ 33 | if dim <= 0: 34 | raise ValueError("`dim` must be positive integer") 35 | self.dim = dim 36 | if self.dim <= 3: 37 | self.axes = list("xyz"[: self.dim]) 38 | else: 39 | self.axes = [chr(97 + i) for i in range(self.dim)] 40 | self.coordinate_limits = [(-np.inf, np.inf)] * self.dim 41 | 42 | def __repr__(self) -> str: 43 | """Return instance as string.""" 44 | return f"{self.__class__.__name__}(dim={self.dim})" 45 | 46 | def __eq__(self, other): 47 | return self.__class__ is other.__class__ and self.dim == other.dim 48 | 49 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray: 50 | return points 51 | 52 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray: 53 | return points 54 | 55 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray: 56 | jac = np.zeros((self.dim, self.dim) + points.shape[:-1]) 57 | jac[range(self.dim), range(self.dim)] = 1 58 | return jac # type: ignore 59 | 60 | def _volume_factor(self, points: np.ndarray) -> ArrayLike: 61 | return np.ones(points.shape[:-1]) 62 | 63 | def _cell_volume(self, c_low: np.ndarray, c_high: np.ndarray): 64 | return np.prod(c_high - c_low, axis=-1) 65 | 66 | def _scale_factors(self, points: np.ndarray) -> np.ndarray: 67 | return np.ones_like(points) # type: ignore 68 | 69 | def 
_basis_rotation(self, points: np.ndarray) -> np.ndarray: 70 | return np.eye(self.dim) # type: ignore 71 | -------------------------------------------------------------------------------- /pde/grids/coordinates/cylindrical.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import numpy as np 8 | from numpy.typing import ArrayLike 9 | 10 | from .base import CoordinatesBase 11 | 12 | 13 | class CylindricalCoordinates(CoordinatesBase): 14 | """3-dimensional cylindrical coordinates.""" 15 | 16 | _singleton: CylindricalCoordinates | None = None 17 | dim = 3 18 | coordinate_limits = [(0, np.inf), (0, 2 * np.pi), (-np.inf, np.inf)] 19 | axes = ["r", "φ", "z"] 20 | _axes_alt = {"φ": ["phi"]} 21 | 22 | def __new__(cls): 23 | # cache the instances for each dimension 24 | if cls._singleton is None: 25 | cls._singleton = super().__new__(cls) 26 | return cls._singleton 27 | 28 | def __eq__(self, other): 29 | return self.__class__ is other.__class__ 30 | 31 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray: 32 | r, φ, z = points[..., 0], points[..., 1], points[..., 2] 33 | x = r * np.cos(φ) 34 | y = r * np.sin(φ) 35 | return np.stack((x, y, z), axis=-1) # type: ignore 36 | 37 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray: 38 | x, y, z = points[..., 0], points[..., 1], points[..., 2] 39 | r = np.hypot(x, y) 40 | φ = np.arctan2(y, x) 41 | return np.stack((r, φ, z), axis=-1) # type: ignore 42 | 43 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray: 44 | r, φ = points[..., 0], points[..., 1] 45 | sinφ, cosφ = np.sin(φ), np.cos(φ) 46 | zero = np.zeros_like(r) 47 | return np.array( # type: ignore 48 | [ 49 | [cosφ, -r * sinφ, zero], 50 | [sinφ, r * cosφ, zero], 51 | [zero, zero, zero + 1], 52 | ] 53 | ) 54 | 55 | def _volume_factor(self, points: np.ndarray) -> ArrayLike: 56 | return points[..., 0] 57 | 58 | def 
_cell_volume(self, c_low: np.ndarray, c_high: np.ndarray): 59 | r1, φ1, z1 = c_low[..., 0], c_low[..., 1], c_low[..., 2] 60 | r2, φ2, z2 = c_high[..., 0], c_high[..., 1], c_high[..., 2] 61 | return (φ2 - φ1) * (z2 - z1) * (r2**2 - r1**2) / 2 62 | 63 | def _scale_factors(self, points: np.ndarray) -> np.ndarray: 64 | r = points[..., 0] 65 | ones = np.ones_like(r) 66 | return np.array([ones, r, ones]) # type: ignore 67 | 68 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray: 69 | φ = points[..., 1] 70 | sinφ, cosφ = np.sin(φ), np.cos(φ) 71 | zero = np.zeros_like(φ) 72 | return np.array( # type: ignore 73 | [ 74 | [cosφ, sinφ, zero], 75 | [-sinφ, cosφ, zero], 76 | [zero, zero, zero + 1], 77 | ] 78 | ) 79 | -------------------------------------------------------------------------------- /pde/grids/coordinates/polar.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import numpy as np 8 | from numpy.typing import ArrayLike 9 | 10 | from .base import CoordinatesBase 11 | 12 | 13 | class PolarCoordinates(CoordinatesBase): 14 | """2-dimensional polar coordinates.""" 15 | 16 | dim = 2 17 | axes = ["r", "φ"] 18 | _axes_alt = {"r": ["radius"], "φ": ["phi"]} 19 | coordinate_limits = [(0, np.inf), (0, 2 * np.pi)] 20 | 21 | _singleton: PolarCoordinates | None = None 22 | 23 | def __new__(cls): 24 | # cache the instances for each dimension 25 | if cls._singleton is None: 26 | cls._singleton = super().__new__(cls) 27 | return cls._singleton 28 | 29 | def __repr__(self) -> str: 30 | """Return instance as string.""" 31 | return f"{self.__class__.__name__}()" 32 | 33 | def __eq__(self, other): 34 | return self.__class__ is other.__class__ 35 | 36 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray: 37 | r, φ = points[..., 0], points[..., 1] 38 | x = r * np.cos(φ) 39 | y = r * np.sin(φ) 40 | return np.stack((x, y), axis=-1) # type: 
ignore 41 | 42 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray: 43 | x, y = points[..., 0], points[..., 1] 44 | r = np.hypot(x, y) 45 | φ = np.arctan2(y, x) 46 | return np.stack((r, φ), axis=-1) # type: ignore 47 | 48 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray: 49 | r, φ = points[..., 0], points[..., 1] 50 | sinφ, cosφ = np.sin(φ), np.cos(φ) 51 | return np.array([[cosφ, -r * sinφ], [sinφ, r * cosφ]]) # type: ignore 52 | 53 | def _volume_factor(self, points: np.ndarray) -> ArrayLike: 54 | return points[..., 0] 55 | 56 | def _cell_volume(self, c_low: np.ndarray, c_high: np.ndarray) -> np.ndarray: 57 | r1, φ1 = c_low[..., 0], c_low[..., 1] 58 | r2, φ2 = c_high[..., 0], c_high[..., 1] 59 | return (φ2 - φ1) * (r2**2 - r1**2) / 2 # type: ignore 60 | 61 | def _scale_factors(self, points: np.ndarray) -> np.ndarray: 62 | r = points[..., 0] 63 | return np.array([np.ones_like(r), r]) # type: ignore 64 | 65 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray: 66 | φ = points[..., 1] 67 | sinφ, cosφ = np.sin(φ), np.cos(φ) 68 | return np.array([[cosφ, sinφ], [-sinφ, cosφ]]) # type: ignore 69 | -------------------------------------------------------------------------------- /pde/grids/coordinates/spherical.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | import numpy as np 8 | from numpy.typing import ArrayLike 9 | 10 | from .base import CoordinatesBase 11 | 12 | 13 | class SphericalCoordinates(CoordinatesBase): 14 | """3-dimensional spherical coordinates.""" 15 | 16 | dim = 3 17 | axes = ["r", "θ", "φ"] 18 | _axes_alt = {"r": ["radius"], "θ": ["theta"], "φ": ["phi"]} 19 | coordinate_limits = [(0, np.inf), (0, np.pi), (0, 2 * np.pi)] 20 | major_axis = 0 21 | 22 | _singleton: SphericalCoordinates | None = None 23 | 24 | def __new__(cls): 25 | # cache the instances for each dimension 26 | if cls._singleton is None: 27 | cls._singleton = super().__new__(cls) 28 | return cls._singleton 29 | 30 | def __repr__(self) -> str: 31 | """Return instance as string.""" 32 | return f"{self.__class__.__name__}()" 33 | 34 | def __eq__(self, other): 35 | return self.__class__ is other.__class__ 36 | 37 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray: 38 | r, θ, φ = points[..., 0], points[..., 1], points[..., 2] 39 | rsinθ = r * np.sin(θ) 40 | x = rsinθ * np.cos(φ) 41 | y = rsinθ * np.sin(φ) 42 | z = r * np.cos(θ) 43 | return np.stack((x, y, z), axis=-1) # type:ignore 44 | 45 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray: 46 | x, y, z = points[..., 0], points[..., 1], points[..., 2] 47 | r = np.linalg.norm(points, axis=-1) 48 | θ = np.arctan2(np.hypot(x, y), z) 49 | φ = np.arctan2(y, x) 50 | return np.stack((r, θ, φ), axis=-1) # type:ignore 51 | 52 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray: 53 | r, θ, φ = points[..., 0], points[..., 1], points[..., 2] 54 | sinθ, cosθ = np.sin(θ), np.cos(θ) 55 | sinφ, cosφ = np.sin(φ), np.cos(φ) 56 | return np.array( # type:ignore 57 | [ 58 | [cosφ * sinθ, r * cosφ * cosθ, -r * sinφ * sinθ], 59 | [sinφ * sinθ, r * sinφ * cosθ, r * cosφ * sinθ], 60 | [cosθ, -r * sinθ, np.zeros_like(θ)], 61 | ] 62 | ) 63 | 64 | def _volume_factor(self, points: np.ndarray) -> ArrayLike: 65 
| r, θ = points[..., 0], points[..., 1] 66 | return r**2 * np.sin(θ) 67 | 68 | def _cell_volume(self, c_low: np.ndarray, c_high: np.ndarray): 69 | r1, θ1, φ1 = c_low[..., 0], c_low[..., 1], c_low[..., 2] 70 | r2, θ2, φ2 = c_high[..., 0], c_high[..., 1], c_high[..., 2] 71 | return (φ2 - φ1) * (np.cos(θ1) - np.cos(θ2)) * (r2**3 - r1**3) / 3 72 | 73 | def _scale_factors(self, points: np.ndarray) -> np.ndarray: 74 | r, θ = points[..., 0], points[..., 1] 75 | return np.array([np.ones_like(r), r, r * np.sin(θ)]) # type: ignore 76 | 77 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray: 78 | θ, φ = points[..., 1], points[..., 2] 79 | sinθ, cosθ = np.sin(θ), np.cos(θ) 80 | sinφ, cosφ = np.sin(φ), np.cos(φ) 81 | return np.array( # type: ignore 82 | [ 83 | [cosφ * sinθ, sinφ * sinθ, cosθ], 84 | [cosφ * cosθ, sinφ * cosθ, -sinθ], 85 | [-sinφ, cosφ, np.zeros_like(θ)], 86 | ] 87 | ) 88 | -------------------------------------------------------------------------------- /pde/grids/operators/__init__.py: -------------------------------------------------------------------------------- 1 | """Package collecting modules defining discretized operators for different grids. 2 | 3 | These operators can either be used directly or they are imported by the respective 4 | methods defined on fields and grids. 5 | 6 | .. autosummary:: 7 | :nosignatures: 8 | 9 | cartesian 10 | cylindrical_sym 11 | polar_sym 12 | spherical_sym 13 | 14 | common.make_derivative 15 | common.make_derivative2 16 | """ 17 | 18 | from . import cartesian, cylindrical_sym, polar_sym, spherical_sym 19 | from .common import make_derivative, make_derivative2 20 | -------------------------------------------------------------------------------- /pde/pdes/__init__.py: -------------------------------------------------------------------------------- 1 | """Package that defines PDEs describing physical systems. 
2 | 3 | The examples in this package are often simple versions of classical PDEs to 4 | demonstrate various aspects of the `py-pde` package. Clearly, not all extensions 5 | to these PDEs can be covered here, but this should serve as a starting point for 6 | custom investigations. 7 | 8 | Publicly available methods should take fields with grid information and also 9 | only return such fields. There might be corresponding private methods that 10 | deal with raw data for faster simulations. 11 | 12 | 13 | .. autosummary:: 14 | :nosignatures: 15 | 16 | ~pde.PDE 17 | ~allen_cahn.AllenCahnPDE 18 | ~cahn_hilliard.CahnHilliardPDE 19 | ~diffusion.DiffusionPDE 20 | ~kpz_interface.KPZInterfacePDE 21 | ~kuramoto_sivashinsky.KuramotoSivashinskyPDE 22 | ~swift_hohenberg.SwiftHohenbergPDE 23 | ~wave.WavePDE 24 | 25 | Additionally, we offer two solvers for typical elliptic PDEs: 26 | 27 | 28 | .. autosummary:: 29 | :nosignatures: 30 | 31 | ~laplace.solve_laplace_equation 32 | ~laplace.solve_poisson_equation 33 | 34 | 35 | .. codeauthor:: David Zwicker 36 | """ 37 | 38 | from .allen_cahn import AllenCahnPDE 39 | from .base import PDEBase 40 | from .cahn_hilliard import CahnHilliardPDE 41 | from .diffusion import DiffusionPDE 42 | from .kpz_interface import KPZInterfacePDE 43 | from .kuramoto_sivashinsky import KuramotoSivashinskyPDE 44 | from .laplace import solve_laplace_equation, solve_poisson_equation 45 | from .pde import PDE 46 | from .swift_hohenberg import SwiftHohenbergPDE 47 | from .wave import WavePDE 48 | -------------------------------------------------------------------------------- /pde/pdes/laplace.py: -------------------------------------------------------------------------------- 1 | """Solvers for Poisson's and Laplace's equation. 2 | 3 | ..
codeauthor:: David Zwicker 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from ..fields import ScalarField 9 | from ..grids.base import GridBase 10 | from ..grids.boundaries.axes import BoundariesData 11 | from ..tools.docstrings import fill_in_docstring 12 | 13 | 14 | @fill_in_docstring 15 | def solve_poisson_equation( 16 | rhs: ScalarField, 17 | bc: BoundariesData, 18 | *, 19 | label: str = "Solution to Poisson's equation", 20 | **kwargs, 21 | ) -> ScalarField: 22 | r"""Solve Laplace's equation on a given grid. 23 | 24 | Denoting the current field by :math:`u`, we thus solve for :math:`f`, defined by the 25 | equation 26 | 27 | .. math:: 28 | \nabla^2 u(\boldsymbol r) = -f(\boldsymbol r) 29 | 30 | with boundary conditions specified by `bc`. 31 | 32 | Note: 33 | In case of periodic or Neumann boundary conditions, the right hand side 34 | :math:`f(\boldsymbol r)` needs to satisfy the following condition 35 | 36 | .. math:: 37 | \int f \, \mathrm{d}V = \oint g \, \mathrm{d}S \;, 38 | 39 | where :math:`g` denotes the function specifying the outwards 40 | derivative for Neumann conditions. Note that for periodic boundaries 41 | :math:`g` vanishes, so that this condition implies that the integral 42 | over 43 | :math:`f` must vanish for neutral Neumann or periodic conditions. 44 | 45 | Args: 46 | rhs (:class:`~pde.fields.scalar.ScalarField`): 47 | The scalar field :math:`f` describing the right hand side 48 | bc: 49 | The boundary conditions applied to the field. 50 | {ARG_BOUNDARIES} 51 | label (str): 52 | The label of the returned field. 53 | 54 | Returns: 55 | :class:`~pde.fields.scalar.ScalarField`: The field :math:`u` that solves 56 | the equation. This field will be defined on the same grid as `rhs`. 
57 | """ 58 | # get the operator information 59 | operator = rhs.grid._get_operator_info("poisson_solver") 60 | # get the boundary conditions 61 | bcs = rhs.grid.get_boundary_conditions(bc) 62 | # get the actual solver 63 | solver = operator.factory(bcs=bcs, **kwargs) 64 | 65 | # solve the poisson problem 66 | result = ScalarField(rhs.grid, label=label) 67 | try: 68 | solver(rhs.data, result.data) 69 | except RuntimeError as err: 70 | magnitude = rhs.magnitude 71 | if magnitude > 1e-10: 72 | raise RuntimeError( 73 | "Could not solve the Poisson problem. One possible reason for this is " 74 | "that only periodic or Neumann conditions are applied although the " 75 | f"magnitude of the field is {magnitude} and thus non-zero." 76 | ) from err 77 | else: 78 | raise # another error occurred 79 | 80 | return result 81 | 82 | 83 | @fill_in_docstring 84 | def solve_laplace_equation( 85 | grid: GridBase, bc: BoundariesData, *, label: str = "Solution to Laplace's equation" 86 | ) -> ScalarField: 87 | """Solve Laplace's equation on a given grid. 88 | 89 | This is implemented by calling :func:`solve_poisson_equation` with a 90 | vanishing right hand side. 91 | 92 | Args: 93 | grid (:class:`~pde.grids.base.GridBase`): 94 | The grid on which the equation is solved 95 | bc: 96 | The boundary conditions applied to the field. 97 | {ARG_BOUNDARIES} 98 | label (str): 99 | The label of the returned field. 100 | 101 | Returns: 102 | :class:`~pde.fields.scalar.ScalarField`: The field that solves the 103 | equation. This field will be defined on the given `grid`. 
104 | """ 105 | rhs = ScalarField(grid, data=0) 106 | return solve_poisson_equation(rhs, bc=bc, label=label) 107 | -------------------------------------------------------------------------------- /pde/py.typed: -------------------------------------------------------------------------------- 1 | # This file indicates that the py-pde package supports typing compliant with PEP 561 -------------------------------------------------------------------------------- /pde/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | """Solvers define how a PDE is solved, i.e., how the initial state is advanced in time. 2 | 3 | .. autosummary:: 4 | :nosignatures: 5 | 6 | ~controller.Controller 7 | ~explicit.ExplicitSolver 8 | ~explicit_mpi.ExplicitMPISolver 9 | ~implicit.ImplicitSolver 10 | ~crank_nicolson.CrankNicolsonSolver 11 | ~adams_bashforth.AdamsBashforthSolver 12 | ~scipy.ScipySolver 13 | ~registered_solvers 14 | 15 | 16 | Inheritance structure of the classes: 17 | 18 | 19 | .. inheritance-diagram:: adams_bashforth.AdamsBashforthSolver 20 | crank_nicolson.CrankNicolsonSolver 21 | explicit.ExplicitSolver 22 | implicit.ImplicitSolver 23 | scipy.ScipySolver 24 | explicit_mpi.ExplicitMPISolver 25 | :parts: 1 26 | 27 | .. codeauthor:: David Zwicker 28 | """ 29 | 30 | from .adams_bashforth import AdamsBashforthSolver 31 | from .controller import Controller 32 | from .crank_nicolson import CrankNicolsonSolver 33 | from .explicit import ExplicitSolver 34 | from .implicit import ImplicitSolver 35 | from .scipy import ScipySolver 36 | 37 | try: 38 | from .explicit_mpi import ExplicitMPISolver 39 | except ImportError: 40 | # MPI modules do not seem to be properly available 41 | ExplicitMPISolver = None # type: ignore 42 | 43 | 44 | def registered_solvers() -> list[str]: 45 | """Returns all solvers that are currently registered. 
46 | 47 | Returns: 48 | list of str: List with the names of the solvers 49 | """ 50 | from .base import SolverBase 51 | 52 | return SolverBase.registered_solvers # type: ignore 53 | 54 | 55 | __all__ = [ 56 | "Controller", 57 | "ExplicitSolver", 58 | "ImplicitSolver", 59 | "CrankNicolsonSolver", 60 | "AdamsBashforthSolver", 61 | "ScipySolver", 62 | "registered_solvers", 63 | ] 64 | -------------------------------------------------------------------------------- /pde/solvers/adams_bashforth.py: -------------------------------------------------------------------------------- 1 | """Defines an explicit Adams-Bashforth solver. 2 | 3 | .. codeauthor:: David Zwicker 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from typing import Any, Callable 9 | 10 | import numba as nb 11 | import numpy as np 12 | 13 | from ..fields.base import FieldBase 14 | from ..tools.numba import jit 15 | from .base import SolverBase 16 | 17 | 18 | class AdamsBashforthSolver(SolverBase): 19 | """Explicit Adams-Bashforth multi-step solver.""" 20 | 21 | name = "adams–bashforth" 22 | 23 | def _make_fixed_stepper( 24 | self, state: FieldBase, dt: float 25 | ) -> Callable[[np.ndarray, float, int, Any], float]: 26 | """Return a stepper function using an explicit scheme with fixed time steps. 
27 | 28 | Args: 29 | state (:class:`~pde.fields.base.FieldBase`): 30 | An example for the state from which the grid and other information can 31 | be extracted 32 | dt (float): 33 | Time step of the explicit stepping 34 | """ 35 | if self.pde.is_sde: 36 | raise NotImplementedError 37 | 38 | rhs_pde = self._make_pde_rhs(state, backend=self.backend) 39 | post_step_hook = self._make_post_step_hook(state) 40 | 41 | def single_step( 42 | state_data: np.ndarray, t: float, state_prev: np.ndarray 43 | ) -> None: 44 | """Perform a single Adams-Bashforth step.""" 45 | rhs_prev = rhs_pde(state_prev, t - dt).copy() 46 | rhs_cur = rhs_pde(state_data, t) 47 | state_prev[:] = state_data # save the previous state 48 | state_data += dt * (1.5 * rhs_cur - 0.5 * rhs_prev) 49 | 50 | # allocate memory to store the state of the previous time step 51 | state_prev = np.empty_like(state.data) 52 | init_state_prev = True 53 | 54 | if self._compiled: 55 | sig_single_step = (nb.typeof(state.data), nb.double, nb.typeof(state_prev)) 56 | single_step = jit(sig_single_step)(single_step) 57 | 58 | def fixed_stepper( 59 | state_data: np.ndarray, t_start: float, steps: int, post_step_data 60 | ) -> float: 61 | """Perform `steps` steps with fixed time steps.""" 62 | nonlocal state_prev, init_state_prev 63 | 64 | if init_state_prev: 65 | # initialize the state_prev with an estimate of the previous step 66 | state_prev[:] = state_data - dt * rhs_pde(state_data, t_start) 67 | init_state_prev = False 68 | 69 | for i in range(steps): 70 | # calculate the right hand side 71 | t = t_start + i * dt 72 | single_step(state_data, t, state_prev) 73 | post_step_hook(state_data, t, post_step_data=post_step_data) 74 | 75 | return t + dt 76 | 77 | self._logger.info("Init explicit Adams-Bashforth stepper with dt=%g", dt) 78 | 79 | return fixed_stepper 80 | -------------------------------------------------------------------------------- /pde/storage/__init__.py: 
-------------------------------------------------------------------------------- 1 | """Module defining classes for storing simulation data. 2 | 3 | .. autosummary:: 4 | :nosignatures: 5 | 6 | ~memory.get_memory_storage 7 | ~memory.MemoryStorage 8 | ~modelrunner.ModelrunnerStorage 9 | ~file.FileStorage 10 | ~movie.MovieStorage 11 | 12 | .. codeauthor:: David Zwicker 13 | """ 14 | 15 | import contextlib 16 | 17 | from .file import FileStorage 18 | from .memory import MemoryStorage, get_memory_storage 19 | from .movie import MovieStorage 20 | 21 | with contextlib.suppress(ImportError): # ModelrunnerStorage is optional; it is skipped if its import fails 22 | from .modelrunner import ModelrunnerStorage 23 | -------------------------------------------------------------------------------- /pde/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """Package containing several tools required in py-pde. 2 | 3 | .. autosummary:: 4 | :nosignatures: 5 | 6 | cache 7 | config 8 | cuboid 9 | docstrings 10 | expressions 11 | ffmpeg 12 | math 13 | misc 14 | mpi 15 | numba 16 | output 17 | parameters 18 | parse_duration 19 | plotting 20 | spectral 21 | typing 22 | 23 | .. codeauthor:: David Zwicker 24 | """ 25 | -------------------------------------------------------------------------------- /pde/tools/modelrunner.py: -------------------------------------------------------------------------------- 1 | """Establishes hooks for the interplay between :mod:`pde` and :mod:`modelrunner`. 2 | 3 | This module is usually loaded automatically during import if :mod:`modelrunner` is 4 | available. In this case, grids and fields of :mod:`pde` can be directly written to 5 | storages from :mod:`modelrunner.storage`. 6 | 7 | ..
codeauthor:: David Zwicker 8 | """ 9 | 10 | from collections.abc import Sequence 11 | 12 | from modelrunner.storage import StorageBase, storage_actions 13 | from modelrunner.storage.utils import decode_class 14 | 15 | from ..fields.base import FieldBase 16 | from ..grids.base import GridBase 17 | 18 | 19 | # these actions are inherited by all subclasses by default 20 | def load_grid(storage: StorageBase, loc: Sequence[str]) -> GridBase: # NOTE(review): annotated as StorageBase, but the docstring names StorageGroup - confirm which is passed 21 | """Function loading a grid from a modelrunner storage. 22 | 23 | Args: 24 | storage (:class:`~modelrunner.storage.group.StorageGroup`): 25 | Storage to load data from 26 | loc (Location): 27 | Location in the storage 28 | 29 | Returns: 30 | :class:`~pde.grids.base.GridBase`: the loaded grid 31 | """ 32 | # get grid class that was stored 33 | stored_cls = decode_class(storage._read_attrs(loc).get("__class__")) # private reader used for the "__class__" entry; presumably read_attrs does not expose it - verify 34 | state = storage.read_attrs(loc) 35 | return stored_cls.from_state(state) # type: ignore 36 | 37 | 38 | storage_actions.register("read_item", GridBase, load_grid) 39 | 40 | 41 | def save_grid(storage: StorageBase, loc: Sequence[str], grid: GridBase) -> None: 42 | """Function saving a grid to a modelrunner storage. 43 | 44 | Args: 45 | storage (:class:`~modelrunner.storage.group.StorageGroup`): 46 | Storage to save data to 47 | loc (Location): 48 | Location in the storage 49 | grid (:class:`~pde.grids.base.GridBase`): 50 | the grid to store 51 | """ 52 | storage.write_object(loc, None, attrs=grid.state, cls=grid.__class__) # only attributes (the grid state) are stored; the object payload is None 53 | 54 | 55 | storage_actions.register("write_item", GridBase, save_grid) 56 | 57 | 58 | # these actions are inherited by all subclasses by default 59 | def load_field(storage: StorageBase, loc: Sequence[str]) -> FieldBase: 60 | """Function loading a field from a modelrunner storage.
61 | 62 | Args: 63 | storage (:class:`~modelrunner.storage.group.StorageGroup`): 64 | Storage to load data from 65 | loc (Location): 66 | Location in the storage 67 | 68 | Returns: 69 | :class:`~pde.fields.base.FieldBase`: the loaded field 70 | """ 71 | # get field class that was stored 72 | stored_cls = decode_class(storage._read_attrs(loc).get("__class__")) 73 | attributes = stored_cls.unserialize_attributes(storage.read_attrs(loc)) # type: ignore 74 | return stored_cls.from_state(attributes, data=storage.read_array(loc)) # type: ignore 75 | 76 | 77 | storage_actions.register("read_item", FieldBase, load_field) 78 | 79 | 80 | def save_field(storage: StorageBase, loc: Sequence[str], field: FieldBase) -> None: 81 | """Function saving a field to a modelrunner storage. 82 | 83 | Args: 84 | storage (:class:`~modelrunner.storage.group.StorageGroup`): 85 | Storage to save data to 86 | loc (Location): 87 | Location in the storage 88 | field (:class:`~pde.fields.base.FieldBase`): 89 | the field to store 90 | """ 91 | storage.write_array( # field data is stored as an array with the serialized attributes attached 92 | loc, field.data, attrs=field.attributes_serialized, cls=field.__class__ 93 | ) 94 | 95 | 96 | storage_actions.register("write_item", FieldBase, save_field) 97 | 98 | 99 | __all__: list[str] = [] # module only registers hooks and does not export any functions 100 | -------------------------------------------------------------------------------- /pde/tools/resources/requirements_basic.txt: -------------------------------------------------------------------------------- 1 | # These are the basic requirements for the package 2 | matplotlib>=3.1 3 | numba>=0.59 4 | numpy>=1.22 5 | scipy>=1.10 6 | sympy>=1.9 7 | tqdm>=4.66 8 | -------------------------------------------------------------------------------- /pde/tools/resources/requirements_full.txt: -------------------------------------------------------------------------------- 1 | # These are the full requirements used to test all functions 2 | ffmpeg-python>=0.2 3 | h5py>=2.10 4 |
ipywidgets>=8 5 | matplotlib>=3.1 6 | numba>=0.59 7 | numpy>=1.22 8 | pandas>=2 9 | py-modelrunner>=0.19 10 | rocket-fft>=0.2.4 11 | scipy>=1.10 12 | sympy>=1.9 13 | tqdm>=4.66 14 | -------------------------------------------------------------------------------- /pde/tools/resources/requirements_mpi.txt: -------------------------------------------------------------------------------- 1 | # These are requirements for supporting multiprocessing 2 | h5py>=2.10 3 | matplotlib>=3.1 4 | mpi4py>=3 5 | numba>=0.59 6 | numba-mpi>=0.22 7 | numpy>=1.22 8 | pandas>=2 9 | scipy>=1.10 10 | sympy>=1.9 11 | tqdm>=4.66 12 | -------------------------------------------------------------------------------- /pde/tools/typing.py: -------------------------------------------------------------------------------- 1 | """Provides support for mypy type checking of the package. 2 | 3 | .. codeauthor:: David Zwicker 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from typing import TYPE_CHECKING, Literal, Protocol, Union 9 | 10 | import numpy as np 11 | from numpy.typing import ArrayLike 12 | 13 | if TYPE_CHECKING: 14 | from ..grids.base import GridBase 15 | 16 | Real = Union[int, float] 17 | Number = Union[Real, complex] 18 | NumberOrArray = Union[Number, np.ndarray] 19 | FloatNumerical = Union[float, np.ndarray] 20 | BackendType = Literal["auto", "numpy", "numba"] 21 | 22 | 23 | class OperatorType(Protocol): 24 | """An operator that acts on an array.""" 25 | 26 | def __call__(self, arr: np.ndarray, out: np.ndarray) -> None: 27 | """Evaluate the operator.""" 28 | 29 | 30 | class OperatorFactory(Protocol): 31 | """A factory function that creates an operator for a particular grid.""" 32 | 33 | def __call__(self, grid: GridBase, **kwargs) -> OperatorType: 34 | """Create the operator.""" 35 | 36 | 37 | class CellVolume(Protocol): # callable mapping integer cell indices to the cell volume 38 | def __call__(self, *args: int) -> float: 39 | """Calculate the volume of the cell at the given position.""" 40 | 41 | 42 | class
VirtualPointEvaluator(Protocol): # callable returning the value of the virtual point at index `idx` 43 | def __call__(self, arr: np.ndarray, idx: tuple[int, ...], args=None) -> float: 44 | """Evaluate the virtual point at the given position.""" 45 | 46 | 47 | class GhostCellSetter(Protocol): # callable setting the ghost cells of `data_full`; presumably in-place, since it returns None 48 | def __call__(self, data_full: np.ndarray, args=None) -> None: 49 | """Set the ghost cells.""" 50 | 51 | 52 | class StepperHook(Protocol): # callable receiving the state and time after a solver step (e.g. a post-step hook) 53 | def __call__( 54 | self, state_data: np.ndarray, t: float, post_step_data: np.ndarray 55 | ) -> None: 56 | """Function analyzing and potentially modifying the current state.""" 57 | -------------------------------------------------------------------------------- /pde/trackers/__init__.py: -------------------------------------------------------------------------------- 1 | """Classes for tracking simulation results in controlled interrupts. 2 | 3 | Trackers are classes that periodically receive the state of the simulation to analyze, 4 | store, or output it. The trackers defined in this module are: 5 | 6 | .. autosummary:: 7 | :nosignatures: 8 | 9 | ~trackers.CallbackTracker 10 | ~trackers.ProgressTracker 11 | ~trackers.PrintTracker 12 | ~trackers.PlotTracker 13 | ~trackers.LivePlotTracker 14 | ~trackers.DataTracker 15 | ~trackers.SteadyStateTracker 16 | ~trackers.RuntimeTracker 17 | ~trackers.ConsistencyTracker 18 | ~interactive.InteractivePlotTracker 19 | 20 | Some trackers can also be referenced by name for convenience when using them in 21 | simulations. The list of supported names is returned by 22 | :func:`~pde.trackers.base.get_named_trackers`. 23 | 24 | Multiple trackers can be collected in a :class:`~base.TrackerCollection`, which provides 25 | methods for handling them efficiently. Moreover, custom trackers can be implemented by 26 | deriving from :class:`~pde.trackers.base.TrackerBase`. Note that trackers generally receive 27 | a view into the current state, implying that they can adjust the state by modifying it 28 | in-place.
Moreover, trackers can abort the simulation by raising the special exception 29 | :class:`StopIteration`. 30 | 31 | 32 | For each tracker, the time at which the simulation is interrupted can be decided using 33 | one of the following classes: 34 | 35 | .. autosummary:: 36 | :nosignatures: 37 | 38 | ~interrupts.FixedInterrupts 39 | ~interrupts.ConstantInterrupts 40 | ~interrupts.LogarithmicInterrupts 41 | ~interrupts.GeometricInterrupts 42 | ~interrupts.RealtimeInterrupts 43 | 44 | In particular, interrupts can be specified conveniently using 45 | :func:`~interrupts.parse_interrupt`. 46 | 47 | .. codeauthor:: David Zwicker 48 | """ 49 | 50 | from .base import get_named_trackers 51 | from .interactive import InteractivePlotTracker 52 | from .interrupts import ( # NOTE(review): GeometricInterrupts is listed in the docstring but not re-exported here - confirm intent 53 | ConstantInterrupts, 54 | FixedInterrupts, 55 | LogarithmicInterrupts, 56 | RealtimeInterrupts, 57 | parse_interrupt, 58 | ) 59 | from .trackers import * # presumably re-exports the tracker classes listed in the docstring 60 | -------------------------------------------------------------------------------- /pde/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | """Functions and classes for visualizing simulations. 2 | 3 | .. autosummary:: 4 | :nosignatures: 5 | 6 | movies 7 | plotting 8 | 9 | ..
codeauthor:: David Zwicker 10 | """ 11 | 12 | from .movies import movie, movie_multiple, movie_scalar 13 | from .plotting import plot_interactive, plot_kymograph, plot_kymographs, plot_magnitudes 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.1 2 | numba>=0.59 3 | numpy>=1.22 4 | scipy>=1.10 5 | sympy>=1.9 6 | tqdm>=4.66 7 | -------------------------------------------------------------------------------- /runtime.txt: -------------------------------------------------------------------------------- 1 | python-3.9 -------------------------------------------------------------------------------- /scripts/_templates/_runtime.txt: -------------------------------------------------------------------------------- 1 | python-$MIN_PYTHON_VERSION -------------------------------------------------------------------------------- /scripts/create_storage_test_resources.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This script creates storage files for backwards compatibility tests.""" 3 | 4 | from __future__ import annotations 5 | 6 | import sys 7 | from pathlib import Path 8 | 9 | PACKAGE_PATH = Path(__file__).resolve().parents[1] 10 | sys.path.insert(0, str(PACKAGE_PATH)) 11 | 12 | import pde 13 | 14 | 15 | def create_storage_test_resources(path, num): 16 | """Simulate a short diffusion run and write it to `storage_{num}.avi` (movie) and `storage_{num}.hdf5` (file) in `path`.""" 17 | grid = pde.CylindricalSymGrid(3, [1, 2], [2, 2]) 18 | field = pde.ScalarField(grid, [[1, 3], [2, 4]]) 19 | eq = pde.DiffusionPDE() 20 | info = {"payload": "storage-test"} 21 | movie_writer = pde.MovieStorage( 22 | path / f"storage_{num}.avi", 23 | info=info, 24 | vmax=4, 25 | bits_per_channel=16, 26 | write_times=True, 27 | ) 28 | file_writer = pde.FileStorage(path / f"storage_{num}.hdf5", info=info) 29 | interrupts = pde.FixedInterrupts([0.1, 0.7, 2.9])
30 | eq.solve( 31 | field, 32 | t_range=3.5, 33 | dt=0.1, 34 | backend="numpy", 35 | tracker=[movie_writer.tracker(interrupts), file_writer.tracker(interrupts)], 36 | ) 37 | 38 | 39 | def main(): 40 | """Main function creating all the test resources.""" 41 | root = Path(PACKAGE_PATH) 42 | create_storage_test_resources(root / "tests" / "storage" / "resources", 2) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/format_code.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # This script formats the code of this package 3 | 4 | echo "Formatting import statements..." 5 | ruff check --fix --config=../pyproject.toml .. 6 | 7 | echo "Formatting docstrings..." 8 | docformatter --in-place --black --recursive .. 9 | 10 | echo "Formatting source code..." 11 | ruff format --config=../pyproject.toml .. -------------------------------------------------------------------------------- /scripts/performance_boundaries.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This script tests the performance of the implementation of different boundary 3 | conditions.""" 4 | 5 | import sys 6 | from pathlib import Path 7 | 8 | PACKAGE_PATH = Path(__file__).resolve().parents[1] 9 | sys.path.insert(0, str(PACKAGE_PATH)) 10 | 11 | import numpy as np 12 | 13 | from pde import ScalarField, UnitGrid 14 | from pde.tools.misc import estimate_computation_speed 15 | from pde.tools.numba import numba_dict 16 | 17 | 18 | def main(): 19 | """Main routine testing the performance.""" 20 | print("Reports calls-per-second (larger is better)\n") 21 | 22 | # Cartesian grid with different shapes and boundary conditions 23 | for size in [32, 512]: 24 | grid = UnitGrid([size, size], periodic=False) 25 | print(grid) 26 | 27 | field = ScalarField.random_normal(grid) 28 | bc_value =
np.ones(size) 29 | result = field.laplace(bc={"value": 1}).data 30 | 31 | for bc in ["scalar", "array", "function", "time-dependent", "linked"]: 32 | if bc == "scalar": 33 | bcs = {"value": 1} 34 | elif bc == "array": 35 | bcs = {"value": bc_value} 36 | elif bc == "function": 37 | bcs = grid.get_boundary_conditions({"virtual_point": "2 - value"}) 38 | elif bc == "time-dependent": 39 | bcs = grid.get_boundary_conditions({"value_expression": "t"}) 40 | elif bc == "linked": 41 | bcs = grid.get_boundary_conditions({"value": bc_value}) 42 | for ax, upper in grid._iter_boundaries(): 43 | bcs[ax][upper].link_value(bc_value) 44 | else: 45 | raise RuntimeError # unreachable: every entry of the list above is handled 46 | 47 | # create the operator with these conditions 48 | laplace = grid.make_operator("laplace", bc=bcs) 49 | if bc == "time-dependent": 50 | args = numba_dict(t=1) 51 | # call once to pre-compile and test result 52 | np.testing.assert_allclose(laplace(field.data, args=args), result) 53 | # estimate the speed 54 | speed = estimate_computation_speed(laplace, field.data, args=args) 55 | 56 | else: 57 | # call once to pre-compile and test result 58 | np.testing.assert_allclose(laplace(field.data), result) 59 | # estimate the speed 60 | speed = estimate_computation_speed(laplace, field.data) 61 | 62 | print(f"{bc:>14s}:{int(speed):>9d}") 63 | 64 | print() 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /scripts/performance_solvers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This script tests the performance of different solvers.""" 3 | 4 | import sys 5 | from pathlib import Path 6 | from typing import Literal 7 | 8 | PACKAGE_PATH = Path(__file__).resolve().parents[1] 9 | sys.path.insert(0, str(PACKAGE_PATH)) 10 | 11 | import numpy as np 12 | 13 | from pde import CahnHilliardPDE, Controller, DiffusionPDE, ScalarField, UnitGrid 14 | from pde.solvers import (
15 | AdamsBashforthSolver, 16 | CrankNicolsonSolver, 17 | ExplicitSolver, 18 | ImplicitSolver, 19 | ScipySolver, 20 | ) 21 | 22 | 23 | def main( 24 | equation: Literal["diffusion", "cahn-hilliard"] = "cahn-hilliard", 25 | t_range: float = 100, 26 | size: int = 32, 27 | ): 28 | """Main routine testing the performance. 29 | 30 | Args: 31 | equation (str): 32 | Chooses the equation to consider 33 | t_range (float): 34 | Sets the total duration that should be solved for 35 | size (int): 36 | The number of grid points along each axis 37 | """ 38 | print("Reports duration in seconds (smaller is better)\n") 39 | 40 | # determine grid and initial state 41 | grid = UnitGrid([size, size], periodic=False) 42 | field = ScalarField.random_uniform(grid) 43 | print(f"GRID: {grid}") 44 | 45 | # determine the equation to solve 46 | if equation == "diffusion": 47 | eq = DiffusionPDE() 48 | elif equation == "cahn-hilliard": 49 | eq = CahnHilliardPDE() 50 | else: 51 | raise ValueError(f"Undefined equation `{equation}`") 52 | print(f"EQUATION: ∂c/∂t = {eq.expression}\n") 53 | 54 | print("Determine ground truth...") 55 | expected = eq.solve(field, t_range=t_range, dt=1e-4, tracker=["progress"]) # reference solution using a small time step 56 | 57 | print("\nSOLVER PERFORMANCE:") 58 | solvers = { 59 | "Euler, fixed": (1e-3, ExplicitSolver(eq, scheme="euler", adaptive=False)), 60 | "Euler, adaptive": (1e-3, ExplicitSolver(eq, scheme="euler", adaptive=True)), 61 | "Runge-Kutta, fixed": (1e-2, ExplicitSolver(eq, scheme="rk", adaptive=False)), 62 | "Runge-Kutta, adaptive": (1e-2, ExplicitSolver(eq, scheme="rk", adaptive=True)), 63 | "Implicit": (1e-2, ImplicitSolver(eq)), 64 | "Adams-Bashforth": (1e-2, AdamsBashforthSolver(eq)), 65 | "Crank-Nicolson": (1e-2, CrankNicolsonSolver(eq)), 66 | "Scipy": (None, ScipySolver(eq)), 67 | } 68 | 69 | for name, (dt, solver) in solvers.items(): 70 | # run the simulation with the given solver 71 | solver.backend = "numba" 72 | controller = Controller(solver, t_range=t_range, tracker=None) 73 | result
= controller.run(field, dt=dt) 74 | 75 | # determine the deviation from the ground truth 76 | error = np.linalg.norm(result.data - expected.data) 77 | error_str = f"{error:.4g}" 78 | 79 | # report the runtime 80 | runtime_str = f"{controller.info['profiler']['solver']:.3g}" 81 | 82 | # report information about the time step 83 | if solver.info.get("dt_adaptive", False): 84 | stats = solver.info["dt_statistics"] 85 | dt_str = f"{stats['min']:.3g} .. {stats['max']:.3g}" 86 | elif solver.info["dt"] is None: 87 | dt_str = "automatic" 88 | else: 89 | dt_str = f"{solver.info['dt']:.3g}" 90 | 91 | print(f"{name:>21s}: runtime={runtime_str:8} error={error_str:11} dt={dt_str}") 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /scripts/profile_import.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This script measures the total time it takes to import the module. 3 | 4 | The total time should ideally be below 1 second. 5 | """ 6 | 7 | import sys 8 | from pathlib import Path 9 | 10 | PACKAGE_PATH = Path(__file__).resolve().parents[1] 11 | sys.path.insert(0, str(PACKAGE_PATH)) 12 | 13 | 14 | from pyinstrument import Profiler 15 | 16 | with Profiler() as profiler: 17 | import pde 18 | 19 | print(profiler.open_in_browser()) 20 | -------------------------------------------------------------------------------- /scripts/show_environment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This script shows important information about the current python environment and the 3 | associated installed packages.
4 | 5 | This information can be helpful in understanding issues that occur with the package. 6 | """ 7 | 8 | import sys 9 | from pathlib import Path 10 | 11 | PACKAGE_PATH = Path(__file__).resolve().parents[1] 12 | sys.path.insert(0, str(PACKAGE_PATH)) 13 | 14 | from pde import environment 15 | 16 | for category, data in environment().items(): 17 | if hasattr(data, "items"): 18 | print(f"\n{category}:") 19 | for key, value in data.items(): 20 | print(f" {key}: {value}") 21 | else: 22 | data_formatted = str(data).replace("\n", "\n ") 23 | print(f"{category}: {data_formatted}") 24 | -------------------------------------------------------------------------------- /scripts/tests_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Runs the code-style checks, the type checks, and the unit tests (without and with MPI) 3 | 4 | echo "Test codestyle" 5 | echo "--------------" 6 | ./tests_codestyle.sh 7 | 8 | echo "" 9 | echo "Test types" 10 | echo "----------" 11 | ./tests_types.sh 12 | 13 | echo "" 14 | echo "Run serial unittests" 15 | echo "-------------" 16 | ./tests_parallel.sh # non-MPI tests, distributed over local cores via --num_cores auto 17 | 18 | echo "" 19 | echo "Run parallel unittests" 20 | echo "-------------" 21 | ./tests_mpi.sh -------------------------------------------------------------------------------- /scripts/tests_codestyle.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # This script checks the code format of this package without changing files 4 | # 5 | 6 | ./run_tests.py --style -------------------------------------------------------------------------------- /scripts/tests_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Run serial tests to determine coverage...' 4 | ./run_tests.py --unit --coverage --nojit --num_cores auto --runslow 5 | 6 | echo 'Run parallel tests to determine coverage...'
7 | ./run_tests.py --unit --coverage --nojit --use_mpi --runslow -- --cov-append 8 | -------------------------------------------------------------------------------- /scripts/tests_debug.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PYTHONPATH=../py-pde # likely path of pde package, relative to current base path 4 | 5 | if [ -n "$1" ] # quoted so that patterns containing spaces form a single test argument 6 | then 7 | # test pattern was specified 8 | echo "Run unittests with pattern $1:" 9 | ./run_tests.py --unit --runslow --nojit --pattern "$1" -- \ 10 | -o log_cli=true --log-cli-level=debug -vv 11 | else 12 | # test pattern was not specified 13 | echo 'Run all unittests:' 14 | ./run_tests.py --unit --nojit -- \ 15 | -o log_cli=true --log-cli-level=debug -vv 16 | fi 17 | -------------------------------------------------------------------------------- /scripts/tests_extensive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # test pattern was not specified 4 | echo 'Run all unittests:' 5 | ./run_tests.py --unit --runslow --num_cores auto -- -rsx 6 | -------------------------------------------------------------------------------- /scripts/tests_mpi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -n "$1" ] # quoted so that patterns containing spaces form a single test argument 4 | then 5 | # test pattern was specified 6 | echo "Run unittests with pattern $1:" 7 | ./run_tests.py --unit --use_mpi --runslow --pattern "$1" 8 | else 9 | # test pattern was not specified 10 | echo 'Run all unittests:' 11 | ./run_tests.py --unit --use_mpi 12 | fi 13 | -------------------------------------------------------------------------------- /scripts/tests_parallel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -n "$1" ] # quoted so that patterns containing spaces form a single test argument
4 | then 5 | # test pattern was specified 6 | echo "Run unittests with pattern $1:" 7 | ./run_tests.py --unit --runslow --num_cores auto --pattern "$1" 8 | else 9 | # test pattern was not specified 10 | echo 'Run all unittests:' 11 | ./run_tests.py --unit --num_cores auto 12 | fi 13 | -------------------------------------------------------------------------------- /scripts/tests_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -n "$1" ] # quoted so that patterns containing spaces form a single test argument 4 | then 5 | # test pattern was specified 6 | echo "Run unittests with pattern $1:" 7 | ./run_tests.py --unit --runslow --pattern "$1" 8 | else 9 | # test pattern was not specified 10 | echo 'Run all unittests:' 11 | ./run_tests.py --unit 12 | fi 13 | -------------------------------------------------------------------------------- /scripts/tests_types.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ./run_tests.py --types 4 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """This file is used to configure the test environment when running py.test. 2 | 3 | ..
codeauthor:: David Zwicker 4 | """ 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pytest 9 | 10 | from pde import config 11 | from pde.tools.misc import module_available 12 | from pde.tools.numba import random_seed 13 | 14 | # ensure we use the Agg backend, so figures are not displayed 15 | plt.switch_backend("agg") 16 | 17 | 18 | @pytest.fixture(autouse=True) 19 | def _setup_and_teardown(): 20 | """Helper function adjusting environment before and after tests.""" 21 | # raise all underflow errors 22 | np.seterr(all="raise", under="ignore") 23 | 24 | # run the actual test 25 | with config({"boundaries.accept_lists": False}): 26 | yield 27 | 28 | # clean up open matplotlib figures after the test 29 | plt.close("all") 30 | 31 | 32 | @pytest.fixture(autouse=False, name="rng") 33 | def init_random_number_generators(): 34 | """Get a random number generator and set the seed of the random number generator. 35 | 36 | The function returns an instance of :func:`~numpy.random.default_rng()` and 37 | initializes the default generators of both :mod:`numpy` and :mod:`numba`. 
38 | """ 39 | random_seed() 40 | return np.random.default_rng(0) 41 | 42 | 43 | def pytest_configure(config): 44 | """Add markers to the configuration.""" 45 | config.addinivalue_line("markers", "interactive: test is interactive") 46 | config.addinivalue_line("markers", "multiprocessing: test requires multiprocessing") 47 | config.addinivalue_line("markers", "slow: test runs slowly") 48 | 49 | 50 | def pytest_addoption(parser): 51 | """Pytest hook to add command line options parsed by pytest.""" 52 | parser.addoption( 53 | "--runslow", 54 | action="store_true", 55 | default=False, 56 | help="also run tests marked by `slow`", 57 | ) 58 | parser.addoption( 59 | "--runinteractive", 60 | action="store_true", 61 | default=False, 62 | help="also run tests marked by `interactive`", 63 | ) 64 | parser.addoption( 65 | "--use_mpi", 66 | action="store_true", 67 | default=False, 68 | help="only run tests marked by `multiprocessing`", 69 | ) 70 | 71 | 72 | def pytest_collection_modifyitems(config, items): 73 | """Pytest hook to filter a collection of tests.""" 74 | # parse options provided to py.test 75 | running_cov = config.getvalue("--cov") 76 | runslow = config.getoption("--runslow", default=False) 77 | runinteractive = config.getoption("--runinteractive", default=False) 78 | use_mpi = config.getoption("--use_mpi", default=False) 79 | has_numba_mpi = module_available("numba_mpi") and module_available("mpi4py") 80 | 81 | # prepare markers 82 | skip_cov = pytest.mark.skip(reason="skipped during coverage run") 83 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 84 | skip_interactive = pytest.mark.skip(reason="need --runinteractive option to run") 85 | skip_serial = pytest.mark.skip(reason="serial test, but --use_mpi option was set") 86 | skip_mpi = pytest.mark.skip(reason="mpi test, but `numba_mpi` not available") 87 | 88 | # check each test item 89 | for item in items: 90 | if "no_cover" in item.keywords and running_cov: 91 | item.add_marker(skip_cov) 92 
| if "slow" in item.keywords and not runslow: 93 | item.add_marker(skip_slow) 94 | if "interactive" in item.keywords and not runinteractive: 95 | item.add_marker(skip_interactive) 96 | 97 | if "multiprocessing" in item.keywords and not has_numba_mpi: 98 | item.add_marker(skip_mpi) 99 | if use_mpi and "multiprocessing" not in item.keywords: 100 | item.add_marker(skip_serial) 101 | -------------------------------------------------------------------------------- /tests/fields/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/fields/fixtures/__init__.py -------------------------------------------------------------------------------- /tests/fields/fixtures/fields.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | 7 | from pde import ( 8 | CartesianGrid, 9 | CylindricalSymGrid, 10 | FieldCollection, 11 | PolarSymGrid, 12 | ScalarField, 13 | SphericalSymGrid, 14 | Tensor2Field, 15 | UnitGrid, 16 | VectorField, 17 | ) 18 | 19 | 20 | def iter_grids(): 21 | """Generator providing some test grids.""" 22 | for periodic in [True, False]: 23 | yield UnitGrid([3], periodic=periodic) 24 | yield UnitGrid([3, 3, 3], periodic=periodic) 25 | yield CartesianGrid([[-1, 2], [0, 3]], [5, 7], periodic=periodic) 26 | yield CylindricalSymGrid(3, [-1, 2], [7, 8], periodic_z=periodic) 27 | yield PolarSymGrid(3, 4) 28 | yield SphericalSymGrid(3, 4) 29 | 30 | 31 | def iter_fields(): 32 | """Generator providing some test fields.""" 33 | yield ScalarField(UnitGrid([1, 2, 3]), 1) 34 | yield VectorField.from_expression(PolarSymGrid(2, 3), ["r**2", "r"]) 35 | yield Tensor2Field.random_normal( 36 | CylindricalSymGrid(3, [-1, 2], [7, 8], periodic_z=True) 37 | ) 38 | 39 | grid = CartesianGrid([[0, 2], [-1, 1]], [3, 4], [True, False]) 40 | 
yield FieldCollection([ScalarField(grid, 1), VectorField(grid, 2)]) 41 | 42 | 43 | def get_cartesian_grid(dim=2, periodic=True): 44 | """Return a random Cartesian grid of given dimension.""" 45 | rng = np.random.default_rng(0) 46 | bounds = [[0, 1 + rng.random()] for _ in range(dim)] 47 | shape = rng.integers(32, 64, size=dim) 48 | return CartesianGrid(bounds, shape, periodic=periodic) 49 | -------------------------------------------------------------------------------- /tests/grids/boundaries/test_axis_boundaries.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import itertools 6 | 7 | import pytest 8 | 9 | from pde import UnitGrid 10 | from pde.grids.boundaries.axis import BoundaryPair, get_boundary_axis 11 | from pde.grids.boundaries.local import BCBase 12 | 13 | 14 | def test_boundary_pair(): 15 | """Test setting boundary conditions for whole axis.""" 16 | g = UnitGrid([2, 3]) 17 | b = ["value", {"type": "derivative", "value": 1}] 18 | for bl, bh in itertools.product(b, b): 19 | bc = BoundaryPair.from_data(g, 0, [bl, bh]) 20 | blo = BCBase.from_data(g, 0, upper=False, data=bl) 21 | bho = BCBase.from_data(g, 0, upper=True, data=bh) 22 | 23 | assert bc.low == blo 24 | assert bc.high == bho 25 | assert bc == BoundaryPair(blo, bho) 26 | if bl == bh: 27 | assert bc == BoundaryPair.from_data(g, 0, bl) 28 | assert list(bc) == [blo, bho] 29 | assert isinstance(str(bc), str) 30 | assert isinstance(repr(bc), str) 31 | 32 | bc.check_value_rank(0) 33 | with pytest.raises(RuntimeError): 34 | bc.check_value_rank(1) 35 | 36 | data = {"low": {"value": 1}, "high": {"derivative": 2}} 37 | bc1 = BoundaryPair.from_data(g, 0, data) 38 | bc2 = BoundaryPair.from_data(g, 0, data) 39 | assert bc1 == bc2 40 | assert bc1 is not bc2 41 | bc2 = BoundaryPair.from_data(g, 1, data) 42 | assert bc1 != bc2 43 | assert bc1 is not bc2 44 | 45 | # miscellaneous methods 46 | data = {"low": {"value": 0}, 
"high": {"derivative": 0}} 47 | bc1 = BoundaryPair.from_data(g, 0, data) 48 | b_lo, b_hi = bc1 49 | assert b_lo == BCBase.from_data(g, 0, False, {"value": 0}) 50 | assert b_hi == BCBase.from_data(g, 0, True, {"derivative": 0}) 51 | assert b_lo is bc1[0] 52 | assert b_lo is bc1[False] 53 | assert b_hi is bc1[1] 54 | assert b_hi is bc1[True] 55 | 56 | 57 | def test_get_axis_boundaries(): 58 | """Test setting boundary conditions including periodic ones.""" 59 | for data in ["value", "derivative", "periodic", "anti-periodic"]: 60 | g = UnitGrid([2], periodic=("periodic" in data)) 61 | b = get_boundary_axis(g, 0, data) 62 | assert str(b) == '"' + data + '"' 63 | b1, b2 = b.get_mathematical_representation("field") 64 | assert "field" in b1 65 | assert "field" in b2 66 | 67 | if "periodic" in data: 68 | assert b.periodic 69 | assert len(list(b)) == 2 70 | assert b.flip_sign == (data == "anti-periodic") 71 | else: 72 | assert not b.periodic 73 | assert len(list(b)) == 2 74 | 75 | # check double setting 76 | c = get_boundary_axis(g, 0, (data, data)) 77 | assert b == c 78 | -------------------------------------------------------------------------------- /tests/grids/operators/test_common_operators.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import random 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | from pde import CartesianGrid, ScalarField 11 | from pde.grids.operators import common as ops 12 | 13 | 14 | @pytest.mark.parametrize("ndim,axis", [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) 15 | def test_make_derivative(ndim, axis, rng): 16 | """Test the _make_derivative function.""" 17 | periodic = random.choice([True, False]) 18 | grid = CartesianGrid([[0, 6 * np.pi]] * ndim, 16, periodic=periodic) 19 | field = ScalarField.random_harmonic(grid, modes=1, axis_combination=np.add, rng=rng) 20 | 21 | bcs = grid.get_boundary_conditions("auto_periodic_neumann") 22 | grad = field.gradient(bcs) 23 | for method in ["central", "forward", "backward"]: 24 | msg = f"method={method}, periodic={periodic}" 25 | diff = ops.make_derivative(grid, axis=axis, method=method) 26 | res = field.copy() 27 | res.data[:] = 0 28 | field.set_ghost_cells(bcs) 29 | diff(field._data_full, out=res.data) 30 | np.testing.assert_allclose( 31 | grad.data[axis], res.data, atol=0.1, rtol=0.1, err_msg=msg 32 | ) 33 | 34 | 35 | @pytest.mark.parametrize("ndim,axis", [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) 36 | def test_make_derivative2(ndim, axis, rng): 37 | """Test the _make_derivative2 function.""" 38 | periodic = random.choice([True, False]) 39 | grid = CartesianGrid([[0, 6 * np.pi]] * ndim, 16, periodic=periodic) 40 | field = ScalarField.random_harmonic(grid, modes=1, axis_combination=np.add, rng=rng) 41 | 42 | bcs = grid.get_boundary_conditions("auto_periodic_neumann") 43 | grad = field.gradient(bcs)[axis] 44 | grad2 = grad.gradient(bcs)[axis] 45 | 46 | diff = ops.make_derivative2(grid, axis=axis) 47 | res = field.copy() 48 | res.data[:] = 0 49 | field.set_ghost_cells(bcs) 50 | diff(field._data_full, out=res.data) 51 | np.testing.assert_allclose(grad2.data, res.data, atol=0.1, rtol=0.1) 52 | -------------------------------------------------------------------------------- 
/tests/grids/test_cylindrical_grids.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import CartesianGrid, CylindricalSymGrid, ScalarField 9 | from pde.grids.boundaries.local import NeumannBC 10 | 11 | 12 | @pytest.mark.parametrize("periodic", [True, False]) 13 | @pytest.mark.parametrize("r_inner", [0, 2]) 14 | def test_cylindrical_grid(periodic, r_inner, rng): 15 | """Test simple cylindrical grid.""" 16 | grid = CylindricalSymGrid((r_inner, 4), (-1, 2), (8, 9), periodic_z=periodic) 17 | if r_inner == 0: 18 | assert grid == CylindricalSymGrid(4, (-1, 2), (8, 9), periodic_z=periodic) 19 | rs, zs = grid.axes_coords 20 | 21 | assert grid.dim == 3 22 | assert grid.numba_type == "f8[:, :]" 23 | assert grid.shape == (8, 9) 24 | assert grid.length == pytest.approx(3) 25 | assert grid.discretization[1] == pytest.approx(1 / 3) 26 | assert grid.volume == pytest.approx(3 * np.pi * (4**2 - r_inner**2)) 27 | assert not grid.uniform_cell_volumes 28 | assert grid.volume == pytest.approx(grid.integrate(1)) 29 | np.testing.assert_allclose(zs, np.linspace(-1 + 1 / 6, 2 - 1 / 6, 9)) 30 | 31 | if r_inner == 0: 32 | assert grid.discretization[0] == pytest.approx(0.5) 33 | np.testing.assert_array_equal(grid.discretization, np.array([0.5, 1 / 3])) 34 | np.testing.assert_allclose(rs, np.linspace(0.25, 3.75, 8)) 35 | else: 36 | assert grid.discretization[0] == pytest.approx(0.25) 37 | np.testing.assert_array_equal(grid.discretization, np.array([0.25, 1 / 3])) 38 | np.testing.assert_allclose(rs, np.linspace(2.125, 3.875, 8)) 39 | 40 | assert grid.contains_point(grid.get_random_point(coords="cartesian", rng=rng)) 41 | ps = [grid.get_random_point(coords="cartesian", rng=rng) for _ in range(2)] 42 | assert all(grid.contains_point(ps)) 43 | ps = grid.get_random_point(coords="cartesian", boundary_distance=1.49, rng=rng) 44 | assert 
grid.contains_point(ps) 45 | assert "laplace" in grid.operators 46 | 47 | 48 | def test_cylindrical_to_cartesian(): 49 | """Test conversion of cylindrical grid to Cartesian.""" 50 | expr_cyl = "cos(z / 2) / (1 + r**2)" 51 | expr_cart = expr_cyl.replace("r**2", "(x**2 + y**2)") 52 | 53 | z_range = (-np.pi, 2 * np.pi) 54 | grid_cyl = CylindricalSymGrid(10, z_range, (16, 33)) 55 | pf_cyl = ScalarField.from_expression(grid_cyl, expression=expr_cyl) 56 | 57 | grid_cart = CartesianGrid([[-7, 7], [-6, 7], z_range], [16, 16, 16]) 58 | pf_cart1 = pf_cyl.interpolate_to_grid(grid_cart) 59 | pf_cart2 = ScalarField.from_expression(grid_cart, expression=expr_cart) 60 | np.testing.assert_allclose(pf_cart1.data, pf_cart2.data, atol=0.1) 61 | 62 | 63 | def test_setting_boundary_conditions(): 64 | """Test various versions of settings bcs for cylindrical grids.""" 65 | grid = CylindricalSymGrid(1, [0, 1], [2, 2], periodic_z=False) 66 | grid.get_boundary_conditions("auto_periodic_neumann") 67 | grid.get_boundary_conditions({"r": "derivative", "z": "derivative"}) 68 | with pytest.raises(RuntimeError): 69 | grid.get_boundary_conditions({"r": "derivative", "z": "periodic"}) 70 | 71 | b_inner = NeumannBC(grid, 0, upper=False) 72 | assert grid.get_boundary_conditions("auto_periodic_neumann")[0].low == b_inner 73 | assert grid.get_boundary_conditions({"value": 2})[0].low != b_inner 74 | 75 | grid = CylindricalSymGrid(1, [0, 1], [2, 2], periodic_z=True) 76 | grid.get_boundary_conditions("auto_periodic_neumann") 77 | grid.get_boundary_conditions({"r": "derivative", "z": "periodic"}) 78 | with pytest.raises(RuntimeError): 79 | grid.get_boundary_conditions({"r": "derivative", "z": "derivative"}) 80 | 81 | 82 | def test_mixed_derivatives(): 83 | """Test mixed derivatives of scalar fields.""" 84 | grid = CylindricalSymGrid(1, [-1, 0.5], [7, 9]) 85 | field = ScalarField.random_normal(grid, label="c") 86 | 87 | res1 = field.apply("d_dz(d_dr(c))") 88 | res2 = field.apply("d_dr(d_dz(c))") 89 | 
np.testing.assert_allclose(res1.data, res2.data) 90 | -------------------------------------------------------------------------------- /tests/pdes/test_generic_pdes.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import ScalarField, UnitGrid, pdes 9 | 10 | 11 | @pytest.mark.parametrize("dim", [1, 2]) 12 | @pytest.mark.parametrize( 13 | "pde_class", 14 | [ 15 | pdes.AllenCahnPDE, 16 | pdes.CahnHilliardPDE, 17 | pdes.DiffusionPDE, 18 | pdes.KPZInterfacePDE, 19 | pdes.KuramotoSivashinskyPDE, 20 | pdes.SwiftHohenbergPDE, 21 | ], 22 | ) 23 | def test_pde_consistency(pde_class, dim, rng): 24 | """Test some methods of generic PDE models.""" 25 | eq = pde_class() 26 | assert isinstance(str(eq), str) 27 | assert isinstance(repr(eq), str) 28 | 29 | # compare numba to numpy implementation 30 | grid = UnitGrid([4] * dim) 31 | state = ScalarField.random_uniform(grid, rng=rng) 32 | field = eq.evolution_rate(state) 33 | assert field.grid == grid 34 | rhs = eq._make_pde_rhs_numba(state) 35 | res = rhs(state.data, 0) 36 | np.testing.assert_allclose(field.data, res) 37 | 38 | # compare to generic implementation 39 | assert isinstance(eq.expression, str) 40 | eq2 = pdes.PDE({"c": eq.expression}) 41 | np.testing.assert_allclose(field.data, eq2.evolution_rate(state).data) 42 | 43 | 44 | def test_pde_consistency_test(rng): 45 | """Test whether the consistency of a pde implementation is checked.""" 46 | 47 | class TestPDE(pdes.PDEBase): 48 | def evolution_rate(self, field, t=0): 49 | return 2 * field 50 | 51 | def _make_pde_rhs_numba(self, state): 52 | def impl(state_data, t): 53 | return 3 * state_data 54 | 55 | return impl 56 | 57 | eq = TestPDE() 58 | state = ScalarField.random_uniform(UnitGrid([4]), rng=rng) 59 | with pytest.raises(AssertionError): 60 | eq.solve(state, t_range=5, tracker=None) 61 | 
-------------------------------------------------------------------------------- /tests/pdes/test_laplace_pdes.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import CartesianGrid, ScalarField, UnitGrid 9 | from pde.pdes import solve_laplace_equation, solve_poisson_equation 10 | 11 | 12 | def test_pde_poisson_solver_1d(): 13 | """Test the poisson solver on 1d grids.""" 14 | # solve Laplace's equation 15 | grid = UnitGrid([4]) 16 | res = solve_laplace_equation(grid, bc={"x-": {"value": -1}, "x+": {"value": 3}}) 17 | np.testing.assert_allclose(res.data, grid.axes_coords[0] - 1) 18 | 19 | res = solve_laplace_equation( 20 | grid, bc={"x-": {"value": -1}, "x+": {"derivative": 1}} 21 | ) 22 | np.testing.assert_allclose(res.data, grid.axes_coords[0] - 1) 23 | 24 | # test Poisson equation with 2nd Order BC 25 | res = solve_laplace_equation(grid, bc={"x-": {"value": -1}, "x+": "extrapolate"}) 26 | 27 | # solve Poisson's equation 28 | grid = CartesianGrid([[0, 1]], 4) 29 | field = ScalarField(grid, data=1) 30 | 31 | res = solve_poisson_equation( 32 | field, bc={"x-": {"value": 1}, "x+": {"derivative": 1}} 33 | ) 34 | xs = grid.axes_coords[0] 35 | np.testing.assert_allclose(res.data, 1 + 0.5 * xs**2, rtol=1e-2) 36 | 37 | # test inconsistent problem 38 | field.data = 1 39 | with pytest.raises(RuntimeError, match="Neumann"): 40 | solve_poisson_equation(field, {"derivative": 0}) 41 | 42 | 43 | def test_pde_poisson_solver_2d(): 44 | """Test the poisson solver on 2d grids.""" 45 | grid = CartesianGrid([[0, 2 * np.pi]] * 2, 16) 46 | bcs = {"x": {"value": "sin(y)"}, "y": {"value": "sin(x)"}} 47 | 48 | # solve Laplace's equation 49 | res = solve_laplace_equation(grid, bcs) 50 | xs = grid.cell_coords[..., 0] 51 | ys = grid.cell_coords[..., 1] 52 | 53 | # analytical solution was obtained with Mathematica 54 | expect = ( 55 | np.cosh(np.pi 
- ys) * np.sin(xs) + np.cosh(np.pi - xs) * np.sin(ys) 56 | ) / np.cosh(np.pi) 57 | np.testing.assert_allclose(res.data, expect, atol=1e-2, rtol=1e-2) 58 | 59 | # test more complex case for exceptions 60 | bcs = {"x": {"value": "sin(y)"}, "y": {"curvature": "sin(x)"}} 61 | res = solve_laplace_equation(grid, bc=bcs) 62 | -------------------------------------------------------------------------------- /tests/pdes/test_wave_pdes.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import PDE, ScalarField, UnitGrid, WavePDE 9 | 10 | 11 | @pytest.mark.parametrize("dim", [1, 2]) 12 | def test_wave_consistency(dim, rng): 13 | """Test some methods of the wave model.""" 14 | eq = WavePDE() 15 | assert isinstance(str(eq), str) 16 | assert isinstance(repr(eq), str) 17 | 18 | # compare numba to numpy implementation 19 | grid = UnitGrid([4] * dim) 20 | state = eq.get_initial_condition(ScalarField.random_uniform(grid, rng=rng)) 21 | field = eq.evolution_rate(state) 22 | assert field.grid == grid 23 | rhs = eq._make_pde_rhs_numba(state) 24 | np.testing.assert_allclose(field.data, rhs(state.data, 0)) 25 | 26 | # compare to generic implementation 27 | assert isinstance(eq.expressions, dict) 28 | eq2 = PDE(eq.expressions) 29 | np.testing.assert_allclose(field.data, eq2.evolution_rate(state).data) 30 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | -r ../requirements.txt 2 | docformatter>=1.7 3 | importlib-metadata>=5 4 | jupyter_contrib_nbextensions>=0.5 5 | mypy>=1.8 6 | notebook>=7 7 | pre-commit>=3 8 | pytest>=5.4 9 | pytest-cov>=2.8 10 | pytest-xdist>=1.30 11 | ruff>=0.6 12 | utilitiez>=0.3 13 | -------------------------------------------------------------------------------- 
/tests/requirements_full.txt: -------------------------------------------------------------------------------- 1 | # These are the full requirements used to test all functions 2 | ffmpeg-python>=0.2 3 | h5py>=2.10 4 | ipywidgets>=8 5 | matplotlib>=3.1 6 | numba>=0.59 7 | numpy>=1.22 8 | pandas>=2 9 | py-modelrunner>=0.19 10 | rocket-fft>=0.2.4 11 | scipy>=1.10 12 | sympy>=1.9 13 | tqdm>=4.66 14 | -------------------------------------------------------------------------------- /tests/requirements_min.txt: -------------------------------------------------------------------------------- 1 | # These are the minimal requirements used to test compatibility 2 | matplotlib~=3.1 3 | numba~=0.59 4 | numpy~=1.22 5 | scipy~=1.10 6 | sympy~=1.9 7 | tqdm~=4.66 8 | -------------------------------------------------------------------------------- /tests/requirements_mpi.txt: -------------------------------------------------------------------------------- 1 | # These are requirements used to test multiprocessing 2 | h5py>=2.10 3 | matplotlib>=3.1 4 | mpi4py>=3 5 | numba>=0.59 6 | numba-mpi>=0.22 7 | numpy>=1.22 8 | pandas>=2 9 | scipy>=1.10 10 | sympy>=1.9 11 | tqdm>=4.66 12 | -------------------------------------------------------------------------------- /tests/resources/run_pde.py: -------------------------------------------------------------------------------- 1 | import pde 2 | 3 | 4 | def run_pde(t_range, storage): 5 | """Run a pde and store trajectory.""" 6 | field = pde.ScalarField.random_uniform(pde.UnitGrid([8, 8])) 7 | storage["initial_state"] = field 8 | 9 | eq = pde.DiffusionPDE() 10 | result = eq.solve( 11 | field, 12 | t_range=t_range, 13 | dt=0.1, 14 | backend="numpy", 15 | tracker=pde.ModelrunnerStorage(storage).tracker(1), 16 | ) 17 | return {"field": result} 18 | -------------------------------------------------------------------------------- /tests/solvers/test_adams_bashforth_solver.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | 7 | import pde 8 | 9 | 10 | def test_adams_bashforth(): 11 | """Test the adams_bashforth method.""" 12 | eq = pde.PDE({"y": "y"}) 13 | state = pde.ScalarField(pde.UnitGrid([1]), 1) 14 | storage = pde.MemoryStorage() 15 | eq.solve( 16 | state, 17 | t_range=2.1, 18 | dt=0.5, 19 | solver="adams–bashforth", 20 | tracker=storage.tracker(0.5), 21 | ) 22 | np.testing.assert_allclose( 23 | np.ravel([f.data for f in storage]), 24 | [1, 13 / 8, 83 / 32, 529 / 128, 3371 / 512, 21481 / 2048], 25 | ) 26 | -------------------------------------------------------------------------------- /tests/solvers/test_controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import PDEBase, ScalarField, UnitGrid 9 | from pde.solvers import Controller 10 | 11 | 12 | def test_controller_abort(): 13 | """Test how controller deals with errors.""" 14 | 15 | class ErrorPDEException(RuntimeError): ... 
16 | 17 | class ErrorPDE(PDEBase): 18 | def evolution_rate(self, state, t): 19 | if t < 1: 20 | return 0 * state 21 | else: 22 | raise ErrorPDEException 23 | 24 | field = ScalarField(UnitGrid([16]), 1) 25 | eq = ErrorPDE() 26 | 27 | with pytest.raises(ErrorPDEException): 28 | eq.solve(field, t_range=2, dt=0.2, backend="numpy") 29 | 30 | assert eq.diagnostics["last_tracker_time"] >= 0 31 | assert eq.diagnostics["last_state"] == field 32 | 33 | 34 | def test_controller_foreign_solver(): 35 | """Test whether the Controller can deal with a minimal foreign solver.""" 36 | 37 | class MySolver: 38 | def make_stepper(self, state, dt): 39 | def stepper(state, t, t_break): 40 | return t_break 41 | 42 | return stepper 43 | 44 | c = Controller(MySolver(), t_range=1) 45 | res = c.run(np.arange(3)) 46 | np.testing.assert_allclose(res, np.arange(3)) 47 | -------------------------------------------------------------------------------- /tests/solvers/test_explicit_mpi_solvers.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import PDE, DiffusionPDE, FieldCollection, ScalarField, UnitGrid 9 | from pde.solvers import Controller, ExplicitMPISolver 10 | from pde.tools import mpi 11 | 12 | 13 | @pytest.mark.multiprocessing 14 | @pytest.mark.parametrize("backend", ["numpy", "numba"]) 15 | @pytest.mark.parametrize( 16 | "scheme, adaptive, decomposition", 17 | [ 18 | ("euler", False, "auto"), 19 | ("euler", True, [1, -1]), 20 | ("runge-kutta", True, [-1, 1]), 21 | ], 22 | ) 23 | def test_simple_pde_mpi(backend, scheme, adaptive, decomposition, rng): 24 | """Test setting boundary conditions using numba.""" 25 | grid = UnitGrid([8, 8], periodic=[True, False]) 26 | 27 | field = ScalarField.random_uniform(grid, rng=rng) 28 | eq = DiffusionPDE() 29 | 30 | args = { 31 | "state": field, 32 | "t_range": 1.01, 33 | "dt": 0.1, 34 | "adaptive": adaptive, 35 | "scheme": scheme, 36 | "tracker": None, 37 | "ret_info": True, 38 | } 39 | res_mpi, info_mpi = eq.solve( 40 | backend=backend, solver="explicit_mpi", decomposition=decomposition, **args 41 | ) 42 | 43 | if mpi.is_main: 44 | # check results in the main process 45 | expect, info2 = eq.solve(backend="numpy", solver="explicit", **args) 46 | np.testing.assert_allclose(res_mpi.data, expect.data) 47 | 48 | assert info_mpi["solver"]["steps"] == info2["solver"]["steps"] 49 | assert info_mpi["solver"]["use_mpi"] 50 | if decomposition != "auto": 51 | for i in range(2): 52 | if decomposition[i] == 1: 53 | assert info_mpi["solver"]["grid_decomposition"][i] == 1 54 | else: 55 | assert info_mpi["solver"]["grid_decomposition"][i] == mpi.size 56 | 57 | 58 | @pytest.mark.multiprocessing 59 | @pytest.mark.parametrize("backend", ["numba", "numpy"]) 60 | def test_stochastic_mpi_solvers(backend, rng): 61 | """Test simple version of the stochastic solver.""" 62 | field = ScalarField.random_uniform(UnitGrid([16]), -1, 1, rng=rng) 63 | eq = DiffusionPDE() 64 | seq = 
DiffusionPDE(noise=1e-10) 65 | 66 | solver1 = ExplicitMPISolver(eq, backend=backend) 67 | c1 = Controller(solver1, t_range=1, tracker=None) 68 | s1 = c1.run(field, dt=1e-3) 69 | 70 | solver2 = ExplicitMPISolver(seq, backend=backend) 71 | c2 = Controller(solver2, t_range=1, tracker=None) 72 | s2 = c2.run(field, dt=1e-3) 73 | 74 | if mpi.is_main: 75 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-4, atol=1e-4) 76 | assert not solver1.info["stochastic"] 77 | assert solver2.info["stochastic"] 78 | 79 | assert not solver1.info["dt_adaptive"] 80 | assert not solver2.info["dt_adaptive"] 81 | 82 | 83 | @pytest.mark.multiprocessing 84 | @pytest.mark.parametrize("backend", ["numpy", "numba"]) 85 | def test_multiple_pdes_mpi(backend, rng): 86 | """Test setting boundary conditions using numba.""" 87 | grid = UnitGrid([8, 8], periodic=[True, False]) 88 | 89 | fields = FieldCollection.scalar_random_uniform(2, grid, rng=rng) 90 | eq = PDE({"a": "laplace(a) - b", "b": "laplace(b) + a"}) 91 | 92 | args = { 93 | "state": fields, 94 | "t_range": 1.01, 95 | "dt": 0.1, 96 | "adaptive": True, 97 | "scheme": "euler", 98 | "tracker": None, 99 | "ret_info": True, 100 | } 101 | res_mpi, info_mpi = eq.solve(backend=backend, solver="explicit_mpi", **args) 102 | 103 | if mpi.is_main: 104 | # check results in the main process 105 | expect, info2 = eq.solve(backend="numpy", solver="explicit", **args) 106 | np.testing.assert_allclose(res_mpi.data, expect.data) 107 | 108 | assert info_mpi["solver"]["steps"] == info2["solver"]["steps"] 109 | assert info_mpi["solver"]["use_mpi"] 110 | -------------------------------------------------------------------------------- /tests/solvers/test_generic_solvers.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import PDE, DiffusionPDE, FieldCollection, MemoryStorage, ScalarField, UnitGrid 9 | from pde.solvers import ( 10 | AdamsBashforthSolver, 11 | Controller, 12 | CrankNicolsonSolver, 13 | ExplicitSolver, 14 | ImplicitSolver, 15 | ScipySolver, 16 | registered_solvers, 17 | ) 18 | from pde.solvers.base import AdaptiveSolverBase 19 | 20 | SOLVER_CLASSES = [ 21 | ExplicitSolver, 22 | ImplicitSolver, 23 | CrankNicolsonSolver, 24 | AdamsBashforthSolver, 25 | ScipySolver, 26 | ] 27 | 28 | 29 | def test_solver_registration(): 30 | """Test solver registration.""" 31 | solvers = registered_solvers() 32 | assert "explicit" in solvers 33 | assert "implicit" in solvers 34 | assert "crank-nicolson" in solvers 35 | assert "scipy" in solvers 36 | 37 | 38 | def test_solver_in_pde_class(rng): 39 | """Test whether solver instances can be used in pde instances.""" 40 | field = ScalarField.random_uniform(UnitGrid([16, 16]), -1, 1, rng=rng) 41 | eq = DiffusionPDE() 42 | eq.solve(field, t_range=1, solver=ScipySolver, tracker=None) 43 | 44 | 45 | @pytest.mark.parametrize("solver_class", SOLVER_CLASSES) 46 | def test_compare_solvers(solver_class, rng): 47 | """Compare several solvers.""" 48 | field = ScalarField.random_uniform(UnitGrid([8, 8]), -1, 1, rng=rng) 49 | eq = DiffusionPDE() 50 | 51 | # ground truth 52 | c1 = Controller(ExplicitSolver(eq, scheme="runge-kutta"), t_range=0.1, tracker=None) 53 | s1 = c1.run(field, dt=5e-3) 54 | 55 | c2 = Controller(solver_class(eq), t_range=0.1, tracker=None) 56 | with np.errstate(under="ignore"): 57 | s2 = c2.run(field, dt=5e-3) 58 | 59 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-2, atol=1e-2) 60 | 61 | 62 | @pytest.mark.parametrize("solver_class", SOLVER_CLASSES) 63 | @pytest.mark.parametrize("backend", ["numpy", "numba"]) 64 | def test_solvers_complex(solver_class, backend): 65 | """Test solvers with a complex PDE.""" 66 | r = 
FieldCollection.scalar_random_uniform(2, UnitGrid([3]), labels=["a", "b"]) 67 | c = r["a"] + 1j * r["b"] 68 | assert c.is_complex 69 | 70 | # assume c = a + i * b 71 | eq_c = PDE({"c": "-I * laplace(c)"}) 72 | eq_r = PDE({"a": "laplace(b)", "b": "-laplace(a)"}) 73 | res_r = eq_r.solve(r, t_range=1e-2, dt=1e-3, backend="numpy", tracker=None) 74 | exp_c = res_r[0].data + 1j * res_r[1].data 75 | 76 | solver = solver_class(eq_c, backend=backend) 77 | controller = Controller(solver, t_range=1e-2, tracker=None) 78 | res_c = controller.run(c, dt=1e-3) 79 | np.testing.assert_allclose(res_c.data, exp_c, rtol=1e-3, atol=1e-3) 80 | 81 | 82 | def test_basic_adaptive_solver(): 83 | """Test basic adaptive solvers.""" 84 | grid = UnitGrid([4]) 85 | y0 = np.array([1e-3, 1e-3, 1e3, 1e3]) 86 | field = ScalarField(grid, y0) 87 | eq = PDE({"c": "c"}) 88 | 89 | dt = 0.1 90 | 91 | solver = AdaptiveSolverBase(eq, tolerance=1e-1) 92 | storage = MemoryStorage() 93 | controller = Controller(solver, t_range=10.1, tracker=storage.tracker(1.0)) 94 | res = controller.run(field, dt=dt) 95 | 96 | np.testing.assert_allclose(res.data, y0 * np.exp(10.1), rtol=0.02) 97 | assert solver.info["steps"] != pytest.approx(10.1 / dt, abs=1) 98 | assert solver.info["dt_adaptive"] 99 | assert solver.info["dt_statistics"]["min"] < 0.0005 100 | assert np.allclose(storage.times, np.arange(11)) 101 | -------------------------------------------------------------------------------- /tests/solvers/test_implicit_solvers.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import PDE, DiffusionPDE, ScalarField, UnitGrid 9 | from pde.solvers import Controller, ImplicitSolver 10 | from pde.tools import mpi 11 | 12 | 13 | @pytest.mark.parametrize("backend", ["numpy", "numba"]) 14 | def test_implicit_solvers_simple_fixed(backend): 15 | """Test implicit solvers.""" 16 | grid = UnitGrid([4]) 17 | xs = grid.axes_coords[0] 18 | field = ScalarField.from_expression(grid, "x") 19 | eq = PDE({"c": "c"}) 20 | 21 | dt = 0.01 22 | solver = ImplicitSolver(eq, backend=backend) 23 | controller = Controller(solver, t_range=10.0, tracker=None) 24 | res = controller.run(field, dt=dt) 25 | 26 | if mpi.is_main: 27 | np.testing.assert_allclose(res.data, xs * np.exp(10), rtol=0.1) 28 | assert solver.info["steps"] == pytest.approx(10 / dt, abs=1) 29 | assert not solver.info.get("dt_adaptive", False) 30 | 31 | 32 | @pytest.mark.parametrize("backend", ["numpy", "numba"]) 33 | def test_implicit_stochastic_solvers(backend, rng): 34 | """Test simple version of the stochastic implicit solver.""" 35 | field = ScalarField.random_uniform(UnitGrid([16]), -1, 1, rng=rng) 36 | eq = DiffusionPDE() 37 | seq = DiffusionPDE(noise=1e-10) 38 | 39 | solver1 = ImplicitSolver(eq, backend=backend) 40 | c1 = Controller(solver1, t_range=1, tracker=None) 41 | s1 = c1.run(field, dt=1e-3) 42 | 43 | solver2 = ImplicitSolver(seq, backend=backend) 44 | c2 = Controller(solver2, t_range=1, tracker=None) 45 | s2 = c2.run(field, dt=1e-3) 46 | 47 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-4, atol=1e-4) 48 | assert not solver1.info["stochastic"] 49 | assert solver2.info["stochastic"] 50 | 51 | assert not solver1.info.get("dt_adaptive", False) 52 | assert not solver2.info.get("dt_adaptive", False) 53 | -------------------------------------------------------------------------------- /tests/solvers/test_scipy_solvers.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | 7 | from pde import PDE, DiffusionPDE, FieldCollection, ScalarField, UnitGrid 8 | from pde.solvers import Controller, ScipySolver 9 | 10 | 11 | def test_scipy_no_dt(rng): 12 | """Test scipy solver without timestep.""" 13 | grid = UnitGrid([16]) 14 | field = ScalarField.random_uniform(grid, -1, 1, rng=rng) 15 | eq = DiffusionPDE() 16 | 17 | c1 = Controller(ScipySolver(eq), t_range=1, tracker=None) 18 | s1 = c1.run(field, dt=1e-3) 19 | 20 | c2 = Controller(ScipySolver(eq), t_range=1, tracker=None) 21 | s2 = c2.run(field) 22 | 23 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-3, atol=1e-3) 24 | 25 | 26 | def test_scipy_field_collection(): 27 | """Test scipy solver with field collection.""" 28 | grid = UnitGrid([2]) 29 | field = FieldCollection.from_scalar_expressions(grid, ["x", "0"]) 30 | eq = PDE({"a": "1", "b": "a"}) 31 | 32 | res = eq.solve(field, t_range=1, dt=1e-2, solver="scipy") 33 | np.testing.assert_allclose(res.data, np.array([[1.5, 2.5], [1.0, 2.0]])) 34 | -------------------------------------------------------------------------------- /tests/storage/resources/empty.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/empty.avi -------------------------------------------------------------------------------- /tests/storage/resources/no_metadata.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/no_metadata.avi -------------------------------------------------------------------------------- /tests/storage/resources/storage_1.avi: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_1.avi -------------------------------------------------------------------------------- /tests/storage/resources/storage_1.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_1.hdf5 -------------------------------------------------------------------------------- /tests/storage/resources/storage_2.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_2.avi -------------------------------------------------------------------------------- /tests/storage/resources/storage_2.avi.times: -------------------------------------------------------------------------------- 1 | 0.1 2 | 0.7 3 | 2.9 4 | -------------------------------------------------------------------------------- /tests/storage/resources/storage_2.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_2.hdf5 -------------------------------------------------------------------------------- /tests/storage/test_memory_storages.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde import MemoryStorage, UnitGrid 9 | from pde.fields import FieldCollection, ScalarField, Tensor2Field, VectorField 10 | 11 | 12 | def test_memory_storage(): 13 | """Test methods specific to memory storage.""" 14 | sf = ScalarField(UnitGrid([1])) 15 | s1 = MemoryStorage() 16 | s1.start_writing(sf) 17 | sf.data = 0 18 | s1.append(sf, 0) 19 | sf.data = 2 20 | s1.append(sf, 1) 21 | 22 | s2 = MemoryStorage() 23 | s2.start_writing(sf) 24 | sf.data = 1 25 | s2.append(sf, 0) 26 | sf.data = 3 27 | s2.append(sf, 1) 28 | 29 | # test from_fields 30 | s3 = MemoryStorage.from_fields(s1.times, [s1[0], s1[1]]) 31 | assert s3.times == s1.times 32 | np.testing.assert_allclose(s3.data, s1.data) 33 | 34 | # test from_collection 35 | s3 = MemoryStorage.from_collection([s1, s2]) 36 | assert s3.times == s1.times 37 | np.testing.assert_allclose(np.ravel(s3.data), np.arange(4)) 38 | 39 | 40 | @pytest.mark.parametrize("cls", [ScalarField, VectorField, Tensor2Field]) 41 | def test_field_type_guessing_fields(cls, rng): 42 | """Test the ability to guess the field type.""" 43 | grid = UnitGrid([3]) 44 | field = cls.random_normal(grid, rng=rng) 45 | s = MemoryStorage() 46 | s.start_writing(field) 47 | s.append(field, 0) 48 | s.append(field, 1) 49 | 50 | # delete information 51 | s._field = None 52 | s.info = {} 53 | 54 | assert not s.has_collection 55 | assert len(s) == 2 56 | assert s[0] == field 57 | 58 | 59 | def test_field_type_guessing_collection(rng): 60 | """Test the ability to guess the field type of a collection.""" 61 | grid = UnitGrid([3]) 62 | field = FieldCollection([ScalarField(grid), VectorField(grid)]) 63 | s = MemoryStorage() 64 | s.start_writing(field) 65 | s.append(field, 0) 66 | 67 | assert s.has_collection 68 | 69 | # delete information 70 | s._field = None 71 | s.info = {} 72 | 73 | with pytest.raises(RuntimeError): 74 | s[0] 75 | 
-------------------------------------------------------------------------------- /tests/storage/test_modelrunner_storages.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | import pde 9 | from pde.tools.misc import module_available 10 | 11 | 12 | @pytest.mark.skipif( 13 | not module_available("modelrunner"), reason="requires `py-modelrunner` package" 14 | ) 15 | def test_storage_write_trajectory(tmp_path): 16 | """Test simple storage writing.""" 17 | import modelrunner as mr 18 | 19 | path = tmp_path / "storage.json" 20 | storage = mr.open_storage(path, mode="truncate") 21 | 22 | field = pde.ScalarField.random_uniform(pde.UnitGrid([8, 8])) 23 | eq = pde.DiffusionPDE() 24 | eq.solve( 25 | field, 26 | t_range=2.5, 27 | dt=0.1, 28 | backend="numpy", 29 | tracker=pde.ModelrunnerStorage(storage).tracker(1), 30 | ) 31 | 32 | assert path.is_file() 33 | assert len(pde.ModelrunnerStorage(path)) == 3 34 | np.testing.assert_allclose(pde.ModelrunnerStorage(path).times, [0, 1, 2]) 35 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import os 6 | import subprocess as sp 7 | import sys 8 | from pathlib import Path 9 | 10 | import pytest 11 | 12 | from pde.tools.misc import module_available 13 | from pde.visualization.movies import Movie 14 | 15 | PACKAGE_PATH = Path(__file__).resolve().parents[1] 16 | EXAMPLES = (PACKAGE_PATH / "examples").glob("*/*.py") 17 | NOTEBOOKS = (PACKAGE_PATH / "examples").glob("*/*.ipynb") 18 | 19 | SKIP_EXAMPLES: list[str] = [] 20 | if not Movie.is_available(): 21 | SKIP_EXAMPLES.extend(["make_movie_live.py", "make_movie_storage.py", "storages.py"]) 22 | if not module_available("mpi4py"): 23 | SKIP_EXAMPLES.extend(["mpi_parallel_run"]) 24 | if not module_available("napari"): 25 | SKIP_EXAMPLES.extend(["tracker_interactive", "show_3d_field_interactively"]) 26 | if not module_available("h5py"): 27 | SKIP_EXAMPLES.extend(["trajectory_io"]) 28 | if not all(module_available(m) for m in ["modelrunner", "h5py"]): 29 | SKIP_EXAMPLES.extend(["py_modelrunner"]) 30 | if not module_available("utilitiez"): 31 | SKIP_EXAMPLES.extend(["logarithmic_kymograph"]) 32 | 33 | 34 | @pytest.mark.slow 35 | @pytest.mark.no_cover 36 | @pytest.mark.skipif(sys.platform == "win32", reason="Assumes unix setup") 37 | @pytest.mark.parametrize("path", EXAMPLES) 38 | def test_example_scripts(path): 39 | """Runs an example script given by path.""" 40 | # check whether this test needs to be run 41 | if path.name.startswith("_"): 42 | pytest.skip("skip examples starting with an underscore") 43 | if any(name in str(path) for name in SKIP_EXAMPLES): 44 | pytest.skip(f"Skip test {path}") 45 | 46 | # run the actual test in a separate python process 47 | env = os.environ.copy() 48 | env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + env.get("PYTHONPATH", "") 49 | proc = sp.Popen([sys.executable, path], env=env, stdout=sp.PIPE, stderr=sp.PIPE) 50 | try: 51 | outs, errs = proc.communicate(timeout=30) 52 | except sp.TimeoutExpired: 53 | proc.kill() 54 | outs, errs = 
proc.communicate() 55 | 56 | # delete files that might be created by the test 57 | try: 58 | (PACKAGE_PATH / "diffusion.mov").unlink() 59 | (PACKAGE_PATH / "allen_cahn.avi").unlink() 60 | (PACKAGE_PATH / "allen_cahn.hdf").unlink() 61 | except OSError: 62 | pass 63 | 64 | # prepare output 65 | msg = f"Script `{path}` failed with following output:" 66 | if outs: 67 | msg = f"{msg}\nSTDOUT:\n{outs}" 68 | if errs: 69 | msg = f"{msg}\nSTDERR:\n{errs}" 70 | assert proc.returncode <= 0, msg 71 | 72 | 73 | @pytest.mark.slow 74 | @pytest.mark.no_cover 75 | @pytest.mark.skipif(not module_available("h5py"), reason="requires `h5py`") 76 | @pytest.mark.skipif(not module_available("jupyter"), reason="requires `jupyter`") 77 | @pytest.mark.skipif(not module_available("notebook"), reason="requires `notebook`") 78 | @pytest.mark.parametrize("path", NOTEBOOKS) 79 | def test_jupyter_notebooks(path, tmp_path): 80 | """Run the jupyter notebooks.""" 81 | import notebook as jupyter_notebook 82 | 83 | if int(jupyter_notebook.__version__.split(".")[0]) < 7: 84 | raise RuntimeError("Jupyter notebooks must be at least version 7") 85 | 86 | if path.name.startswith("_"): 87 | pytest.skip("Skip examples starting with an underscore") 88 | 89 | # adjust python environment 90 | my_env = os.environ.copy() 91 | my_env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + my_env.get("PYTHONPATH", "") 92 | 93 | # run the notebook 94 | sp.check_call([sys.executable, "-m", "jupyter", "execute", path], env=my_env) 95 | -------------------------------------------------------------------------------- /tests/tools/test_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import pytest 6 | 7 | from pde.tools.config import Config, environment, packages_from_requirements 8 | 9 | 10 | def test_environment(): 11 | """Test the environment function.""" 12 | assert isinstance(environment(), dict) 13 | 14 | 15 | def test_config(): 16 | """Test configuration system.""" 17 | c = Config() 18 | 19 | assert c["numba.multithreading_threshold"] > 0 20 | 21 | assert "numba.multithreading_threshold" in c 22 | assert any(k == "numba.multithreading_threshold" for k in c) 23 | assert any(k == "numba.multithreading_threshold" and v > 0 for k, v in c.items()) 24 | assert "numba.multithreading_threshold" in c.to_dict() 25 | assert isinstance(repr(c), str) 26 | 27 | 28 | def test_config_modes(): 29 | """Test configuration system running in different modes.""" 30 | c = Config(mode="insert") 31 | assert c["numba.multithreading_threshold"] > 0 32 | c["numba.multithreading_threshold"] = 0 33 | assert c["numba.multithreading_threshold"] == 0 34 | c["new_value"] = "value" 35 | assert c["new_value"] == "value" 36 | c.update({"new_value2": "value2"}) 37 | assert c["new_value2"] == "value2" 38 | del c["new_value"] 39 | with pytest.raises(KeyError): 40 | c["new_value"] 41 | with pytest.raises(KeyError): 42 | c["undefined"] 43 | 44 | c = Config(mode="update") 45 | assert c["numba.multithreading_threshold"] > 0 46 | c["numba.multithreading_threshold"] = 0 47 | 48 | with pytest.raises(KeyError): 49 | c["new_value"] = "value" 50 | with pytest.raises(KeyError): 51 | c.update({"new_value": "value"}) 52 | with pytest.raises(RuntimeError): 53 | del c["numba.multithreading_threshold"] 54 | with pytest.raises(KeyError): 55 | c["undefined"] 56 | 57 | c = Config(mode="locked") 58 | assert c["numba.multithreading_threshold"] > 0 59 | with pytest.raises(RuntimeError): 60 | c["numba.multithreading_threshold"] = 0 61 | with pytest.raises(RuntimeError): 62 | c.update({"numba.multithreading_threshold": 0}) 63 | with 
pytest.raises(RuntimeError): 64 | c["new_value"] = "value" 65 | with pytest.raises(RuntimeError): 66 | del c["numba.multithreading_threshold"] 67 | with pytest.raises(KeyError): 68 | c["undefined"] 69 | 70 | c = Config(mode="undefined") 71 | assert c["numba.multithreading_threshold"] > 0 72 | with pytest.raises(ValueError): 73 | c["numba.multithreading_threshold"] = 0 74 | with pytest.raises(ValueError): 75 | c.update({"numba.multithreading_threshold": 0}) 76 | with pytest.raises(RuntimeError): 77 | del c["numba.multithreading_threshold"] 78 | 79 | c = Config({"new_value": "value"}, mode="locked") 80 | assert c["new_value"] == "value" 81 | 82 | 83 | def test_config_contexts(): 84 | """Test context manager temporarily changing configuration.""" 85 | c = Config() 86 | 87 | assert c["numba.multithreading_threshold"] > 0 88 | with c({"numba.multithreading_threshold": 0}): 89 | assert c["numba.multithreading_threshold"] == 0 90 | with c({"numba.multithreading_threshold": 1}): 91 | assert c["numba.multithreading_threshold"] == 1 92 | assert c["numba.multithreading_threshold"] == 0 93 | 94 | assert c["numba.multithreading_threshold"] > 0 95 | 96 | 97 | def test_config_special_values(): 98 | """Test configuration system running in different modes.""" 99 | c = Config() 100 | c["numba.multithreading"] = True 101 | assert c["numba.multithreading"] == "always" 102 | assert c.use_multithreading() 103 | c["numba.multithreading"] = False 104 | assert c["numba.multithreading"] == "never" 105 | assert not c.use_multithreading() 106 | 107 | 108 | def test_packages_from_requirements(): 109 | """Test the packages_from_requirements function.""" 110 | results = packages_from_requirements("file_not_existing") 111 | assert len(results) == 1 112 | assert "Could not open" in results[0] 113 | assert "file_not_existing" in results[0] 114 | -------------------------------------------------------------------------------- /tests/tools/test_ffmpeg.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import pytest 6 | 7 | from pde.tools.ffmpeg import find_format 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "channels,bits_per_channel,result", 12 | [(1, 8, "gray"), (2, 7, "rgb24"), (3, 9, "gbrp16le"), (5, 8, None), (1, 17, None)], 13 | ) 14 | def test_find_format(channels, bits_per_channel, result): 15 | """test_find_format function.""" 16 | assert find_format(channels, bits_per_channel) == result 17 | -------------------------------------------------------------------------------- /tests/tools/test_math.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde.tools.math import OnlineStatistics, SmoothData1D 9 | 10 | 11 | def test_SmoothData1D(rng): 12 | """Test smoothing.""" 13 | x = rng.uniform(0, 1, 500) 14 | xs = np.linspace(0, 1, 16)[1:-1] 15 | 16 | s = SmoothData1D(x, np.ones_like(x), 0.1) 17 | np.testing.assert_allclose(s(xs), 1) 18 | np.testing.assert_allclose(s.derivative(xs), 0, atol=1e-14) 19 | 20 | s = SmoothData1D(x, x, 0.02) 21 | np.testing.assert_allclose(s(xs), xs, atol=0.1) 22 | np.testing.assert_allclose(s.derivative(xs), 1, atol=0.3) 23 | 24 | s = SmoothData1D(x, np.sin(x), 0.05) 25 | np.testing.assert_allclose(s(xs), np.sin(xs), atol=0.1) 26 | np.testing.assert_allclose(s.derivative(xs)[1:-1], np.cos(xs)[1:-1], atol=0.3) 27 | 28 | assert -0.1 not in s 29 | assert x.min() in s 30 | assert 0.5 in s 31 | assert x.max() in s 32 | assert 1.1 not in s 33 | 34 | x = np.arange(3) 35 | y = [0, 1, np.nan] 36 | s = SmoothData1D(x, y) 37 | assert s(0.5) == pytest.approx(0.5) 38 | 39 | 40 | def test_online_statistics(): 41 | """Test OnlineStatistics class.""" 42 | stat = OnlineStatistics() 43 | 44 | stat.add(1) 45 | stat.add(2) 46 | 47 | assert stat.mean == pytest.approx(1.5) 48 | 
assert stat.std == pytest.approx(0.5) 49 | assert stat.var == pytest.approx(0.25) 50 | assert stat.count == 2 51 | assert stat.to_dict() == pytest.approx( 52 | {"min": 1, "max": 2, "mean": 1.5, "std": 0.5, "count": 2} 53 | ) 54 | 55 | stat = OnlineStatistics() 56 | assert stat.to_dict()["count"] == 0 57 | -------------------------------------------------------------------------------- /tests/tools/test_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import json 6 | import os 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import pytest 11 | 12 | from pde.tools import misc 13 | 14 | 15 | def test_ensure_directory_exists(tmp_path): 16 | """Tests the ensure_directory_exists function.""" 17 | # create temporary name 18 | path = tmp_path / "test_ensure_directory_exists" 19 | assert not path.exists() 20 | # create the folder 21 | misc.ensure_directory_exists(path) 22 | assert path.is_dir() 23 | # check that a second call has the same result 24 | misc.ensure_directory_exists(path) 25 | assert path.is_dir() 26 | # remove the folder again 27 | Path.rmdir(path) 28 | assert not path.exists() 29 | 30 | 31 | def test_preserve_scalars(): 32 | """Test the preserve_scalars decorator.""" 33 | 34 | class Test: 35 | @misc.preserve_scalars 36 | def meth(self, arr): 37 | return arr + 1 38 | 39 | t = Test() 40 | 41 | assert t.meth(1) == 2 42 | np.testing.assert_equal(t.meth(np.ones(2)), np.full(2, 2)) 43 | 44 | 45 | def test_hybridmethod(): 46 | """Test the hybridmethod decorator.""" 47 | 48 | class Test: 49 | @misc.hybridmethod 50 | def method(cls): 51 | return "class" 52 | 53 | @method.instancemethod 54 | def method(self): 55 | return "instance" 56 | 57 | assert Test.method() == "class" 58 | assert Test().method() == "instance" 59 | 60 | 61 | def test_estimate_computation_speed(): 62 | """Test estimate_computation_speed method.""" 63 | 64 | def f(x): 65 | return 2 * x 66 | 
67 | def g(x): 68 | return np.sin(x) * np.cos(x) ** 2 69 | 70 | assert misc.estimate_computation_speed(f, 1) > misc.estimate_computation_speed(g, 1) 71 | 72 | 73 | def test_classproperty(): 74 | """Test classproperty decorator.""" 75 | 76 | class Test: 77 | _value = 2 78 | 79 | @misc.classproperty 80 | def value(cls): 81 | return cls._value 82 | 83 | assert Test.value == 2 84 | 85 | 86 | @pytest.mark.skipif(not misc.module_available("h5py"), reason="requires `h5py` module") 87 | def test_hdf_write_attributes(tmp_path): 88 | """Test hdf_write_attributes function.""" 89 | import h5py 90 | 91 | path = tmp_path / "test_hdf_write_attributes.hdf5" 92 | 93 | # test normal case 94 | data = {"a": 3, "b": "asd"} 95 | with h5py.File(path, "w") as hdf_file: 96 | misc.hdf_write_attributes(hdf_file, data) 97 | data2 = {k: json.loads(v) for k, v in hdf_file.attrs.items()} 98 | 99 | assert data == data2 100 | assert data is not data2 101 | 102 | # test silencing of problematic items 103 | with h5py.File(path, "w") as hdf_file: 104 | misc.hdf_write_attributes(hdf_file, {"a": 1, "b": object()}) 105 | data2 = {k: json.loads(v) for k, v in hdf_file.attrs.items()} 106 | assert data2 == {"a": 1} 107 | 108 | # test raising problematic items 109 | with h5py.File(path, "w") as hdf_file, pytest.raises(TypeError): 110 | misc.hdf_write_attributes( 111 | hdf_file, {"a": object()}, raise_serialization_error=True 112 | ) 113 | -------------------------------------------------------------------------------- /tests/tools/test_mpi.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from pde.tools.mpi import mpi_allreduce, mpi_recv, mpi_send, rank, size 9 | 10 | 11 | @pytest.mark.multiprocessing 12 | def test_send_recv(): 13 | """Test basic send and receive.""" 14 | if size == 1: 15 | pytest.skip("Run without multiprocessing") 16 | 17 | data = np.arange(5) 18 | if rank == 0: 19 | out = np.empty_like(data) 20 | mpi_recv(out, 1, 1) 21 | np.testing.assert_allclose(out, data) 22 | elif rank == 1: 23 | mpi_send(data, 0, 1) 24 | 25 | 26 | @pytest.mark.multiprocessing 27 | @pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"]) 28 | def test_allreduce(operator, rng): 29 | """Test MPI allreduce function.""" 30 | data = rng.uniform(size=size) 31 | result = mpi_allreduce(data[rank], operator=operator) 32 | 33 | if operator == "MAX": 34 | assert result == data.max() 35 | elif operator == "MIN": 36 | assert result == data.min() 37 | elif operator == "SUM": 38 | assert result == data.sum() 39 | else: 40 | raise NotImplementedError 41 | 42 | 43 | @pytest.mark.multiprocessing 44 | @pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"]) 45 | def test_allreduce_array(operator, rng): 46 | """Test MPI allreduce function.""" 47 | data = rng.uniform(size=(size, 3)) 48 | result = mpi_allreduce(data[rank], operator=operator) 49 | 50 | if operator == "MAX": 51 | np.testing.assert_allclose(result, data.max(axis=0)) 52 | elif operator == "MIN": 53 | np.testing.assert_allclose(result, data.min(axis=0)) 54 | elif operator == "SUM": 55 | np.testing.assert_allclose(result, data.sum(axis=0)) 56 | else: 57 | raise NotImplementedError 58 | -------------------------------------------------------------------------------- /tests/tools/test_numba.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | import numba 6 | import numpy as np 7 | import pytest 8 | 9 | from pde.tools.numba import ( 10 | Counter, 11 | flat_idx, 12 | jit, 13 | make_array_constructor, 14 | numba_dict, 15 | numba_environment, 16 | ) 17 | 18 | 19 | def test_environment(): 20 | """Test function signature checks.""" 21 | assert isinstance(numba_environment(), dict) 22 | 23 | 24 | def test_flat_idx(): 25 | """Test flat_idx function.""" 26 | # testing the numpy version 27 | assert flat_idx(2, 1) == 2 28 | assert flat_idx(np.arange(2), 1) == 1 29 | assert flat_idx(np.arange(4).reshape(2, 2), 1) == 1 30 | 31 | # testing the numba compiled version 32 | @jit 33 | def get_sparse_matrix_data(data): 34 | return flat_idx(data, 1) 35 | 36 | assert get_sparse_matrix_data(2) == 2 37 | assert get_sparse_matrix_data(np.arange(2)) == 1 38 | assert get_sparse_matrix_data(np.arange(4).reshape(2, 2)) == 1 39 | 40 | 41 | def test_counter(): 42 | """Test Counter implementation.""" 43 | c1 = Counter() 44 | assert int(c1) == 0 45 | assert c1 == 0 46 | assert str(c1) == "0" 47 | 48 | c1.increment() 49 | assert int(c1) == 1 50 | 51 | c1 += 2 52 | assert int(c1) == 3 53 | 54 | c2 = Counter(3) 55 | assert c1 is not c2 56 | assert c1 == c2 57 | 58 | 59 | @pytest.mark.parametrize( 60 | "arr", [np.arange(5), np.linspace(0, 1, 3), np.arange(12).reshape(3, 4)[1:, 2:]] 61 | ) 62 | def test_make_array_constructor(arr): 63 | """Test implementation to create array.""" 64 | constructor = jit(make_array_constructor(arr)) 65 | arr2 = constructor() 66 | np.testing.assert_equal(arr, arr2) 67 | assert np.shares_memory(arr, arr2) 68 | 69 | 70 | def test_numba_dict(): 71 | """Test numba_dict function.""" 72 | cls = dict if numba.config.DISABLE_JIT else numba.typed.Dict 73 | 74 | # test empty dictionaries 75 | for d in [numba_dict(), numba_dict({})]: 76 | assert len(d) == 0 77 | assert isinstance(d, cls) 78 | 79 | # test initializing dictionaries in different ways 80 | for d in [ 81 | 
numba_dict({"a": 1, "b": 2}), 82 | numba_dict(a=1, b=2), 83 | numba_dict({"a": 1}, b=2), 84 | numba_dict({"a": 1, "b": 3}, b=2), 85 | ]: 86 | assert isinstance(d, cls) 87 | assert len(d) == 2 88 | assert d["a"] == 1 89 | assert d["b"] == 2 90 | 91 | # test edge case 92 | d = numba_dict(data=1) 93 | assert d["data"] == 1 94 | assert isinstance(d, cls) 95 | -------------------------------------------------------------------------------- /tests/tools/test_output.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | from pde.tools import output 6 | 7 | 8 | def test_progress_bars(): 9 | """Test progress bars.""" 10 | pb_cls = output.get_progress_bar_class() 11 | tot = 0 12 | for i in pb_cls(range(4)): 13 | tot += i 14 | assert tot == 6 15 | 16 | 17 | def test_in_jupyter_notebook(): 18 | """Test the function in_jupyter_notebook.""" 19 | assert isinstance(output.in_jupyter_notebook(), bool) 20 | 21 | 22 | def test_display_progress(capsys): 23 | """Test whether this works.""" 24 | for _ in output.display_progress(range(2)): 25 | pass 26 | out, err = capsys.readouterr() 27 | assert out == "" 28 | assert len(err) > 0 29 | -------------------------------------------------------------------------------- /tests/tools/test_parse_duration.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
codeauthor:: David Zwicker 3 | """ 4 | 5 | from pde.tools.parse_duration import parse_duration 6 | 7 | 8 | def test_parse_duration(): 9 | """Test function signature checks.""" 10 | 11 | def p(value): 12 | return parse_duration(value).total_seconds() 13 | 14 | assert p("0") == 0 15 | assert p("1") == 1 16 | assert p("1:2") == 62 17 | assert p("1:2:3") == 3723 18 | -------------------------------------------------------------------------------- /tests/tools/test_plotting_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import matplotlib.pyplot as plt 6 | import matplotlib.testing.compare 7 | import pytest 8 | 9 | from pde.tools.plotting import add_scaled_colorbar, plot_on_axes, plot_on_figure 10 | 11 | 12 | def test_plot_on_axes(tmp_path): 13 | """Test the plot_on_axes decorator.""" 14 | 15 | @plot_on_axes 16 | def plot(ax): 17 | ax.plot([0, 1], [0, 1]) 18 | 19 | path = tmp_path / "test.png" 20 | plot(title="Test", filename=path) 21 | assert path.stat().st_size > 0 22 | 23 | 24 | def test_plot_on_figure(tmp_path): 25 | """Test the plot_on_figure decorator.""" 26 | 27 | @plot_on_figure 28 | def plot(fig): 29 | ax1, ax2 = fig.subplots(1, 2) 30 | ax1.plot([0, 1], [0, 1]) 31 | ax2.plot([0, 1], [0, 1]) 32 | 33 | path = tmp_path / "test.png" 34 | plot(title="Test", filename=path) 35 | assert path.stat().st_size > 0 36 | 37 | 38 | @pytest.mark.interactive 39 | def test_plot_colorbar(tmp_path, rng): 40 | """Test the plot_on_axes decorator.""" 41 | data = rng.normal(size=(3, 3)) 42 | 43 | # do not specify axis 44 | img = plt.imshow(data) 45 | add_scaled_colorbar(img, label="Label") 46 | plt.savefig(tmp_path / "img1.png") 47 | plt.clf() 48 | 49 | # specify axis explicitly 50 | ax = plt.gca() 51 | img = ax.imshow(data) 52 | add_scaled_colorbar(img, ax=ax, label="Label") 53 | plt.savefig(tmp_path / "img2.png") 54 | 55 | # compare the two results 56 | cmp = 
matplotlib.testing.compare.compare_images( 57 | str(tmp_path / "img1.png"), str(tmp_path / "img2.png"), tol=0.1 58 | ) 59 | assert cmp is None 60 | -------------------------------------------------------------------------------- /tests/visualization/test_movies.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. codeauthor:: David Zwicker 3 | """ 4 | 5 | import pytest 6 | 7 | from pde.fields import ScalarField 8 | from pde.grids import UnitGrid 9 | from pde.pdes import DiffusionPDE 10 | from pde.storage import MemoryStorage 11 | from pde.visualization import movies 12 | 13 | 14 | @pytest.mark.skipif(not movies.Movie.is_available(), reason="no ffmpeg") 15 | def test_movie_class(tmp_path): 16 | """Test Movie class.""" 17 | import matplotlib.pyplot as plt 18 | 19 | path = tmp_path / "test_movie.mov" 20 | 21 | try: 22 | with movies.Movie(path) as movie: 23 | # iterate over all time steps 24 | plt.plot([0, 1], [0, 1]) 25 | movie.add_figure() 26 | movie.add_figure() 27 | 28 | # save movie 29 | movie.save() 30 | except RuntimeError: 31 | pass # can happen when ffmpeg is not installed 32 | else: 33 | assert path.stat().st_size > 0 34 | 35 | 36 | @pytest.mark.skipif(not movies.Movie.is_available(), reason="no ffmpeg") 37 | @pytest.mark.parametrize("movie_func", [movies.movie_scalar, movies.movie]) 38 | def test_movie_scalar(movie_func, tmp_path, rng): 39 | """Test Movie class.""" 40 | # create some data 41 | state = ScalarField.random_uniform(UnitGrid([4, 4]), rng=rng) 42 | eq = DiffusionPDE() 43 | storage = MemoryStorage() 44 | tracker = storage.tracker(interrupts=1) 45 | eq.solve(state, t_range=2, dt=1e-2, backend="numpy", tracker=tracker) 46 | 47 | # check creating the movie 48 | path = tmp_path / "test_movie.mov" 49 | 50 | try: 51 | movie_func(storage, filename=path, progress=False) 52 | except RuntimeError: 53 | pass # can happen when ffmpeg is not installed 54 | else: 55 | assert path.stat().st_size > 0 56 | 57 | 58 | 
@pytest.mark.skipif(not movies.Movie.is_available(), reason="no ffmpeg") 59 | def test_movie_wrong_path(tmp_path): 60 | """Test whether there is a useful error message when path doesn't exist.""" 61 | with pytest.raises(OSError): 62 | movies.Movie(tmp_path / "unavailable" / "test.mov") 63 | --------------------------------------------------------------------------------