├── .github
│   └── workflows
│       ├── check_docs.yml
│       ├── codeql-analysis.yml
│       ├── coverage_report.yml
│       ├── release_pip.yml
│       ├── tests_all.yml
│       ├── tests_minversion.yml
│       ├── tests_mpi.yml
│       └── tests_types.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── build_all.sh
│   ├── methods
│   │   ├── boundary_discretization
│   │   │   ├── BoundaryDiscretization.nb
│   │   │   ├── boundary_discretization.pdf
│   │   │   └── boundary_discretization.tex
│   │   ├── common.tex
│   │   ├── stencils
│   │   │   └── Cartesian2D.nb
│   │   └── vector_calculus
│   │       ├── Bipolar.nb
│   │       ├── Bispherical.nb
│   │       ├── Cartesian.nb
│   │       ├── Cylindrical.nb
│   │       ├── Polar.nb
│   │       └── Spherical.nb
│   ├── paper
│   │   ├── paper.bib
│   │   └── paper.md
│   ├── requirements.txt
│   ├── source
│   │   ├── .gitignore
│   │   ├── _images
│   │   │   ├── discretization.key
│   │   │   ├── discretization.pdf
│   │   │   ├── discretization_cropped.pdf
│   │   │   ├── discretization_cropped.svg
│   │   │   ├── logo.png
│   │   │   ├── logo_small.png
│   │   │   ├── performance_noflux.pdf
│   │   │   ├── performance_noflux.png
│   │   │   ├── performance_periodic.pdf
│   │   │   └── performance_periodic.png
│   │   ├── _static
│   │   │   ├── custom.css
│   │   │   ├── requirements_main.csv
│   │   │   └── requirements_optional.csv
│   │   ├── conf.py
│   │   ├── create_performance_plots.py
│   │   ├── gallery.rst
│   │   ├── getting_started.rst
│   │   ├── index.rst
│   │   ├── manual
│   │   │   ├── advanced_usage.rst
│   │   │   ├── basic_usage.rst
│   │   │   ├── citing.rst
│   │   │   ├── code_of_conduct.rst
│   │   │   ├── contributing.rst
│   │   │   ├── index.rst
│   │   │   ├── mathematical_basics.rst
│   │   │   └── performance.rst
│   │   ├── parse_examples.py
│   │   └── run_autodoc.py
│   └── sphinx_ext
│       ├── package_config.py
│       ├── simplify_typehints.py
│       └── toctree_filter.py
├── examples
│   ├── README.txt
│   ├── advanced_pdes
│   │   ├── README.txt
│   │   ├── custom_noise.py
│   │   ├── heterogeneous_bcs.py
│   │   ├── mpi_parallel_run.py
│   │   ├── pde_1d_class.py
│   │   ├── pde_brusselator_class.py
│   │   ├── pde_coupled.py
│   │   ├── pde_custom_class.py
│   │   ├── pde_custom_numba.py
│   │   ├── pde_sir.py
│   │   ├── post_step_hook.py
│   │   ├── post_step_hook_class.py
│   │   └── solver_comparison.py
│   ├── fields
│   │   ├── README.txt
│   │   ├── analyze_scalar_field.py
│   │   ├── finite_differences.py
│   │   ├── plot_cylindrical_field.py
│   │   ├── plot_polar_grid.py
│   │   ├── plot_vector_field.py
│   │   ├── random_fields.py
│   │   └── show_3d_field_interactively.py
│   ├── jupyter
│   │   ├── .gitignore
│   │   ├── Different solvers.ipynb
│   │   ├── Discretized Fields.ipynb
│   │   ├── Solve PDEs.ipynb
│   │   ├── Tutorial 1 - Grids and fields.ipynb
│   │   └── Tutorial 2 - Solving pre-defined partial differential equations.ipynb
│   ├── output
│   │   ├── README.txt
│   │   ├── logarithmic_kymograph.py
│   │   ├── make_movie_live.py
│   │   ├── make_movie_storage.py
│   │   ├── py_modelrunner.py
│   │   ├── storages.py
│   │   ├── tracker_interactive.py
│   │   ├── trackers.py
│   │   └── trajectory_io.py
│   └── simple_pdes
│       ├── README.txt
│       ├── boundary_conditions.py
│       ├── cartesian_grid.py
│       ├── laplace_eq_2d.py
│       ├── pde_1d_expression.py
│       ├── pde_brusselator_expression.py
│       ├── pde_custom_expression.py
│       ├── pde_heterogeneous_diffusion.py
│       ├── pde_schroedinger.py
│       ├── poisson_eq_1d.py
│       ├── simple.py
│       ├── spherical_grid.py
│       ├── stochastic_simulation.py
│       └── time_dependent_bcs.py
├── pde
│   ├── __init__.py
│   ├── fields
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── collection.py
│   │   ├── datafield_base.py
│   │   ├── scalar.py
│   │   ├── tensorial.py
│   │   └── vectorial.py
│   ├── grids
│   │   ├── __init__.py
│   │   ├── _mesh.py
│   │   ├── base.py
│   │   ├── boundaries
│   │   │   ├── __init__.py
│   │   │   ├── axes.py
│   │   │   ├── axis.py
│   │   │   └── local.py
│   │   ├── cartesian.py
│   │   ├── coordinates
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── bipolar.py
│   │   │   ├── bispherical.py
│   │   │   ├── cartesian.py
│   │   │   ├── cylindrical.py
│   │   │   ├── polar.py
│   │   │   └── spherical.py
│   │   ├── cylindrical.py
│   │   ├── operators
│   │   │   ├── __init__.py
│   │   │   ├── cartesian.py
│   │   │   ├── common.py
│   │   │   ├── cylindrical_sym.py
│   │   │   ├── polar_sym.py
│   │   │   └── spherical_sym.py
│   │   └── spherical.py
│   ├── pdes
│   │   ├── __init__.py
│   │   ├── allen_cahn.py
│   │   ├── base.py
│   │   ├── cahn_hilliard.py
│   │   ├── diffusion.py
│   │   ├── kpz_interface.py
│   │   ├── kuramoto_sivashinsky.py
│   │   ├── laplace.py
│   │   ├── pde.py
│   │   ├── swift_hohenberg.py
│   │   └── wave.py
│   ├── py.typed
│   ├── solvers
│   │   ├── __init__.py
│   │   ├── adams_bashforth.py
│   │   ├── base.py
│   │   ├── controller.py
│   │   ├── crank_nicolson.py
│   │   ├── explicit.py
│   │   ├── explicit_mpi.py
│   │   ├── implicit.py
│   │   └── scipy.py
│   ├── storage
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── file.py
│   │   ├── memory.py
│   │   ├── modelrunner.py
│   │   └── movie.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── cache.py
│   │   ├── config.py
│   │   ├── cuboid.py
│   │   ├── docstrings.py
│   │   ├── expressions.py
│   │   ├── ffmpeg.py
│   │   ├── math.py
│   │   ├── misc.py
│   │   ├── modelrunner.py
│   │   ├── mpi.py
│   │   ├── numba.py
│   │   ├── output.py
│   │   ├── parameters.py
│   │   ├── parse_duration.py
│   │   ├── plotting.py
│   │   ├── resources
│   │   │   ├── requirements_basic.txt
│   │   │   ├── requirements_full.txt
│   │   │   └── requirements_mpi.txt
│   │   ├── spectral.py
│   │   └── typing.py
│   ├── trackers
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── interactive.py
│   │   ├── interrupts.py
│   │   └── trackers.py
│   └── visualization
│       ├── __init__.py
│       ├── movies.py
│       └── plotting.py
├── pyproject.toml
├── requirements.txt
├── runtime.txt
├── scripts
│   ├── _templates
│   │   ├── _pyproject.toml
│   │   └── _runtime.txt
│   ├── create_requirements.py
│   ├── create_storage_test_resources.py
│   ├── format_code.sh
│   ├── performance_boundaries.py
│   ├── performance_laplace.py
│   ├── performance_solvers.py
│   ├── profile_import.py
│   ├── run_tests.py
│   ├── show_environment.py
│   ├── tests_all.sh
│   ├── tests_codestyle.sh
│   ├── tests_coverage.sh
│   ├── tests_debug.sh
│   ├── tests_extensive.sh
│   ├── tests_mpi.sh
│   ├── tests_parallel.sh
│   ├── tests_run.sh
│   └── tests_types.sh
└── tests
    ├── _notebooks
    │   └── Test PlotTracker for different backend.ipynb
    ├── conftest.py
    ├── fields
    │   ├── fixtures
    │   │   ├── __init__.py
    │   │   └── fields.py
    │   ├── test_field_collections.py
    │   ├── test_generic_fields.py
    │   ├── test_scalar_fields.py
    │   ├── test_tensorial_fields.py
    │   └── test_vectorial_fields.py
    ├── grids
    │   ├── boundaries
    │   │   ├── test_axes_boundaries.py
    │   │   ├── test_axes_boundaries_legacy.py
    │   │   ├── test_axis_boundaries.py
    │   │   └── test_local_boundaries.py
    │   ├── operators
    │   │   ├── test_cartesian_operators.py
    │   │   ├── test_common_operators.py
    │   │   ├── test_cylindrical_operators.py
    │   │   ├── test_polar_operators.py
    │   │   └── test_spherical_operators.py
    │   ├── test_cartesian_grids.py
    │   ├── test_coordinates.py
    │   ├── test_cylindrical_grids.py
    │   ├── test_generic_grids.py
    │   ├── test_grid_mesh.py
    │   └── test_spherical_grids.py
    ├── pdes
    │   ├── test_diffusion_pdes.py
    │   ├── test_generic_pdes.py
    │   ├── test_laplace_pdes.py
    │   ├── test_pde_class.py
    │   ├── test_pdes_mpi.py
    │   └── test_wave_pdes.py
    ├── requirements.txt
    ├── requirements_full.txt
    ├── requirements_min.txt
    ├── requirements_mpi.txt
    ├── resources
    │   └── run_pde.py
    ├── solvers
    │   ├── test_adams_bashforth_solver.py
    │   ├── test_controller.py
    │   ├── test_explicit_mpi_solvers.py
    │   ├── test_explicit_solvers.py
    │   ├── test_generic_solvers.py
    │   ├── test_implicit_solvers.py
    │   └── test_scipy_solvers.py
    ├── storage
    │   ├── resources
    │   │   ├── empty.avi
    │   │   ├── no_metadata.avi
    │   │   ├── storage_1.avi
    │   │   ├── storage_1.hdf5
    │   │   ├── storage_2.avi
    │   │   ├── storage_2.avi.times
    │   │   └── storage_2.hdf5
    │   ├── test_file_storages.py
    │   ├── test_generic_storages.py
    │   ├── test_memory_storages.py
    │   ├── test_modelrunner_storages.py
    │   └── test_movie_storages.py
    ├── test_examples.py
    ├── test_integration.py
    ├── tools
    │   ├── test_cache.py
    │   ├── test_config.py
    │   ├── test_cuboid.py
    │   ├── test_expressions.py
    │   ├── test_ffmpeg.py
    │   ├── test_math.py
    │   ├── test_misc.py
    │   ├── test_mpi.py
    │   ├── test_numba.py
    │   ├── test_output.py
    │   ├── test_parameters.py
    │   ├── test_parse_duration.py
    │   ├── test_plotting_tools.py
    │   └── test_spectral.py
    ├── trackers
    │   ├── test_interrupts.py
    │   └── test_trackers.py
    └── visualization
        ├── test_movies.py
        └── test_plotting.py
/.github/workflows/check_docs.yml:
--------------------------------------------------------------------------------
1 | name: "Check documentation"
2 |
3 | on: [push]
4 |
5 | jobs:
6 | docs:
7 | runs-on: ubuntu-latest
8 | timeout-minutes: 15
9 |
10 | steps:
11 | - uses: actions/checkout@v4
12 |
13 | - uses: ammaraskar/sphinx-action@master
14 | with:
15 | docs-folder: "docs/"
16 | pre-build-command: "pip install --upgrade pip"
17 |
--------------------------------------------------------------------------------
/.github/workflows/codeql-analysis.yml:
--------------------------------------------------------------------------------
1 | name: "Code quality analysis"
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | pull_request:
7 | # The branches below must be a subset of the branches above
8 | branches: [master]
9 | schedule:
10 | - cron: '0 10 * * 5'
11 |
12 | jobs:
13 | analyze:
14 | name: Analyze
15 | runs-on: ubuntu-latest
16 |
17 | strategy:
18 | fail-fast: false
19 | matrix:
20 | # Override automatic language detection by changing the below list
21 | # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python']
22 | language: ['python']
23 | # Learn more...
24 | # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection
25 |
26 | steps:
27 | - name: Checkout repository
28 | uses: actions/checkout@v4
29 | with:
30 | # We must fetch at least the immediate parents so that if this is
31 | # a pull request then we can checkout the head.
32 | fetch-depth: 2
33 |
34 | # Initializes the CodeQL tools for scanning.
35 | - name: Initialize CodeQL
36 | uses: github/codeql-action/init@v2
37 | with:
38 | languages: ${{ matrix.language }}
39 |
40 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
41 | # If this step fails, then you should remove it and run the build manually (see below)
42 | - name: Autobuild
43 | uses: github/codeql-action/autobuild@v2
44 |
45 | # ℹ️ Command-line programs to run using the OS shell.
46 | # 📚 https://git.io/JvXDl
47 |
48 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
49 | # and modify them (or add more) to build your code if your project
50 | # uses a compiled language
51 |
52 | #- run: |
53 | # make bootstrap
54 | # make release
55 |
56 | - name: Perform CodeQL Analysis
57 | uses: github/codeql-action/analyze@v2
58 |
--------------------------------------------------------------------------------
/.github/workflows/coverage_report.yml:
--------------------------------------------------------------------------------
1 | name: "Generate coverage report"
2 |
3 | on: [push]
4 |
5 | jobs:
6 | coverage_report:
7 | runs-on: ubuntu-latest
8 | timeout-minutes: 30
9 |
10 | steps:
11 | - uses: actions/checkout@v4
12 |
13 | - name: Set up Python
14 | uses: actions/setup-python@v5
15 | with:
16 | python-version: '3.11'
17 |
18 | - name: Setup FFmpeg
19 | uses: AnimMouse/setup-ffmpeg@v1
20 |
21 | - name: Install dependencies
22 | # install all requirements. Note that the full requirements are installed separately
23 | # so the job does not fail if one of the packages cannot be installed. This allows
24 | # testing the package for newer Python versions even when some of the optional
25 | # packages are not yet available.
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install -r requirements.txt
29 | cat tests/requirements_full.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 -I % sh -c "pip install % || true"
30 | pip install -r tests/requirements.txt
31 |
32 | - name: Generate serial coverage report
33 | env:
34 | NUMBA_DISABLE_JIT: 1
35 | MPLBACKEND: agg
36 | PYPDE_TESTRUN: 1
37 | run: |
38 | export PYTHONPATH="${PYTHONPATH}:`pwd`"
39 | pytest --cov-config=pyproject.toml --cov=pde -n auto tests
40 |
41 | - name: Setup MPI
42 | uses: mpi4py/setup-mpi@v1
43 |
44 | - name: Generate parallel coverage report
45 | env:
46 | NUMBA_DISABLE_JIT: 1
47 | MPLBACKEND: agg
48 | PYPDE_TESTRUN: 1
49 | run: |
50 | export PYTHONPATH="${PYTHONPATH}:`pwd`"
51 | pip install -r tests/requirements_mpi.txt
52 | mpiexec -n 2 pytest --cov-config=pyproject.toml --cov-append --cov=pde --use_mpi tests
53 |
54 | - name: Create coverage report
55 | run: |
56 | coverage xml -o coverage_report.xml
57 |
58 | - name: Upload coverage to Codecov
59 | uses: codecov/codecov-action@v1
60 | with:
61 | token: ${{ secrets.CODECOV_TOKEN }}
62 | file: ./coverage_report.xml
63 | flags: unittests
64 | name: codecov-pydev
65 | fail_ci_if_error: true
66 |
--------------------------------------------------------------------------------
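
The install step above uses a small shell trick worth spelling out: `sed` strips comment
and blank lines from `tests/requirements_full.txt`, and `xargs` then runs a separate
`pip install` for each remaining requirement with `|| true`, so a single unavailable
optional package cannot fail the whole job. A minimal Python sketch of the same idea
(the script name is hypothetical; the file path is taken from the workflow):

# install_optional.py -- hypothetical helper mirroring the CI shell pipeline
import subprocess
import sys
from pathlib import Path

for line in Path("tests/requirements_full.txt").read_text().splitlines():
    req = line.strip()
    if not req or req.startswith("#"):
        continue  # skip blank lines and comments, like the sed filter
    # check=False mimics "|| true": a failed install does not abort the loop
    subprocess.run([sys.executable, "-m", "pip", "install", req], check=False)
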
/.github/workflows/release_pip.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [released]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | timeout-minutes: 15
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 | with:
15 | fetch-depth: 0 # https://github.com/pypa/setuptools_scm/issues/480
16 |
17 | - name: Set up Python
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: '3.9'
21 |
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | pip install build twine
26 |
27 | - name: Prepare build
28 | run: |
29 | python -m build 2>&1 | tee build.log
30 | # exit `fgrep -i warning build.log | wc -l`
31 |
32 | - name: Check the package
33 | run: twine check --strict dist/*
34 |
35 | - name: Build and publish
36 | env:
37 | TWINE_USERNAME: __token__
38 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
39 | run: python -m twine upload dist/*
--------------------------------------------------------------------------------
/.github/workflows/tests_all.yml:
--------------------------------------------------------------------------------
1 | name: "Serial tests"
2 |
3 | on: [push]
4 |
5 | jobs:
6 | serial_tests:
7 | strategy:
8 | matrix:
9 | os: [macos-latest, ubuntu-latest, windows-latest]
10 | runs-on: ${{ matrix.os }}
11 | timeout-minutes: 60
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v5
18 | with:
19 | python-version: '3.13'
20 |
21 | - name: Setup FFmpeg
22 | uses: AnimMouse/setup-ffmpeg@v1
23 |
24 | - name: Install dependencies
25 | # install all requirements. Note that the full requirements are installed separately
26 | # so the job does not fail if one of the packages cannot be installed. This allows
27 | # testing the package for newer Python versions even when some of the optional
28 | # packages are not yet available.
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install -r requirements.txt
32 | cat tests/requirements_full.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -I % sh -c "pip install % || true"
33 | pip install -r tests/requirements.txt
34 |
35 | - name: Run serial tests with pytest
36 | env:
37 | NUMBA_WARNINGS: 1
38 | MPLBACKEND: agg
39 | run: |
40 | cd scripts
41 | python run_tests.py --unit --runslow --num_cores auto --showconfig
42 |
--------------------------------------------------------------------------------
/.github/workflows/tests_minversion.yml:
--------------------------------------------------------------------------------
1 | name: "Tests with minimal requirements"
2 |
3 | on: [push]
4 |
5 | jobs:
6 | test_minversion:
7 | strategy:
8 | matrix:
9 | os: [macos-latest, ubuntu-latest] # windows-latest
10 | runs-on: ${{ matrix.os }}
11 | timeout-minutes: 60
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.9'
19 |
20 | - name: Install dependencies
21 | # install packages with the minimal supported versions pinned in tests/requirements_min.txt
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install -r tests/requirements_min.txt
25 | pip install -r tests/requirements.txt
26 |
27 | - name: Test with pytest
28 | env:
29 | NUMBA_WARNINGS: 1
30 | MPLBACKEND: agg
31 | run: |
32 | cd scripts
33 | python run_tests.py --unit --runslow --num_cores auto --showconfig
34 |
--------------------------------------------------------------------------------
/.github/workflows/tests_mpi.yml:
--------------------------------------------------------------------------------
1 | name: "Multiprocessing tests"
2 |
3 | on: [push]
4 |
5 | jobs:
6 | test_mpi:
7 | strategy:
8 | matrix:
9 | include:
10 | - os: "ubuntu-latest"
11 | mpi: "openmpi"
12 | - os: "macos-13"
13 | mpi: "mpich"
14 | - os: "windows-latest"
15 | mpi: "intelmpi"
16 | runs-on: ${{ matrix.os }}
17 | timeout-minutes: 30
18 |
19 | steps:
20 | - uses: actions/checkout@v4
21 |
22 | - name: Set up Python
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: '3.13'
26 |
27 | - name: Setup MPI
28 | uses: mpi4py/setup-mpi@v1
29 | with:
30 | mpi: ${{ matrix.mpi }}
31 |
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install -r tests/requirements_mpi.txt
36 | pip install -r tests/requirements.txt
37 |
38 | - name: Run parallel tests with pytest
39 | env:
40 | NUMBA_WARNINGS: 1
41 | MPLBACKEND: agg
42 | run: |
43 | cd scripts
44 | python run_tests.py --unit --use_mpi --showconfig
45 |
--------------------------------------------------------------------------------
/.github/workflows/tests_types.yml:
--------------------------------------------------------------------------------
1 | name: "Static type checking"
2 |
3 | on: [push]
4 |
5 | jobs:
6 | test_types:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | pyversion: ['3.9', '3.13']
11 | timeout-minutes: 30
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v5
18 | with:
19 | python-version: ${{ matrix.pyversion }}
20 |
21 | - name: Install dependencies
22 | # install all requirements. Note that the full requirements are installed separately
23 | # so the job does not fail if one of the packages cannot be installed. This allows
24 | # testing the package for newer Python versions even when some of the optional
25 | # packages are not yet available.
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install -r requirements.txt
29 | cat tests/requirements_full.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -I % sh -c "pip install % || true"
30 | pip install -r tests/requirements.txt
31 | pip install types-PyYAML
32 |
33 | - name: Test types with mypy
34 | continue-on-error: true
35 | run: |
36 | python -m mypy --config-file pyproject.toml --pretty --package pde
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # special project folders
2 | docs/source/packages/
3 | docs/source/examples_gallery/
4 | docs/paper/paper.pdf
5 | tests/coverage/
6 | tests/mypy-report/
7 | pde/_version.py
8 | scripts/flamegraph.html
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 | docs/source/examples
15 | .idea
16 |
17 | # C extensions
18 | *.so
19 |
20 | # LaTeX
21 | *.aux
22 | *.bbl
23 | *.blg
24 | *.bgl
25 | *.out
26 | *.synctex.gz
27 | *.toc
28 | *Notes.bib
29 |
30 |
31 | # Distribution / packaging
32 | .Python
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .coverage
64 | .coverage.*
65 | .cache
66 | nosetests.xml
67 | coverage.xml
68 | *.cover
69 | .hypothesis/
70 | .pytest_cache/
71 |
72 | # Translations
73 | *.mo
74 | *.pot
75 |
76 | # Django stuff:
77 | *.log
78 | local_settings.py
79 | db.sqlite3
80 |
81 | # Flask stuff:
82 | instance/
83 | .webassets-cache
84 |
85 | # Scrapy stuff:
86 | .scrapy
87 |
88 | # Sphinx documentation
89 | docs/_build/
90 |
91 | # PyBuilder
92 | target/
93 |
94 | # Jupyter Notebook
95 | .ipynb_checkpoints
96 |
97 | # pyenv
98 | .python-version
99 |
100 | # celery beat schedule file
101 | celerybeat-schedule
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | /.pydevproject
128 | /.project
129 | .DS_Store
130 |
131 | # IDE settings
132 | .settings
133 | .vscode
134 | *.code-workspace
135 | examples/output/allen_cahn.avi
136 | examples/output/allen_cahn.hdf
137 | debug
138 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: check-yaml
6 |
7 | - repo: https://github.com/astral-sh/ruff-pre-commit
8 | rev: v0.6.1
9 | hooks:
10 | - id: ruff
11 | args: [--fix, --show-fixes]
12 | - id: ruff-format
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.11"
13 | apt_packages:
14 | - ffmpeg
15 | - graphviz
16 |
17 | # Build documentation in the docs/ directory with Sphinx
18 | sphinx:
19 | configuration: docs/source/conf.py
20 |
21 | # If using Sphinx, optionally build your docs in additional formats such as PDF
22 | formats:
23 | - epub
24 | - pdf
25 |
26 | # Optionally declare the Python requirements required to build your docs
27 | python:
28 | install:
29 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at david.zwicker@ds.mpg.de. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 David Zwicker
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 |
3 | # You can set these variables from the command line.
4 | SPHINXOPTS =
5 | SPHINXBUILD = sphinx-build
6 | SOURCEDIR = source
7 | BUILDDIR = build
8 |
9 | # Put it first so that "make" without argument is like "make help".
10 | help:
11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
12 |
13 | # clean:
14 | # rm -rf $(BUILDDIR)/*
15 | # rm -rf examples_gallery/*
16 |
17 | .PHONY: help Makefile
18 |
19 | # Define default target, which can be used in VS code
20 | .PHONY: default
21 | default: html
22 |
23 | # special make target for latex document to exclude the gallery
24 | latexpdf: Makefile
25 | @$(SPHINXBUILD) -M latexpdf "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -E -t exclude_gallery
26 |
27 | # special make target for linkcheck to be nit-picky
28 | linkcheck: Makefile
29 | @$(SPHINXBUILD) -M linkcheck "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -E -n
30 |
31 | # Catch-all target: route all unknown targets to Sphinx using the new
32 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
33 | %: Makefile
34 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -E
35 |
--------------------------------------------------------------------------------
/docs/build_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | make html
4 | make latexpdf
5 | make linkcheck
--------------------------------------------------------------------------------
/docs/methods/boundary_discretization/boundary_discretization.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/methods/boundary_discretization/boundary_discretization.pdf
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | -r ../requirements.txt
2 | ffmpeg-python>=0.2
3 | h5py>=2.10
4 | pandas>=2
5 | Pillow>=7.0
6 | py-modelrunner>=0.19
7 | pydot>=3
8 | Sphinx>=4
9 | sphinx-autodoc-annotation>=1.0
10 | sphinx-gallery>=0.6
11 | sphinx-rtd-theme>=1
12 | utilitiez>=0.3
13 |
--------------------------------------------------------------------------------
/docs/source/.gitignore:
--------------------------------------------------------------------------------
1 | /sg_execution_times.rst
2 |
--------------------------------------------------------------------------------
/docs/source/_images/discretization.key:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/discretization.key
--------------------------------------------------------------------------------
/docs/source/_images/discretization.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/discretization.pdf
--------------------------------------------------------------------------------
/docs/source/_images/discretization_cropped.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/discretization_cropped.pdf
--------------------------------------------------------------------------------
/docs/source/_images/discretization_cropped.svg:
--------------------------------------------------------------------------------
(SVG markup omitted from this dump; see the file in the repository)
--------------------------------------------------------------------------------
/docs/source/_images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/logo.png
--------------------------------------------------------------------------------
/docs/source/_images/logo_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/logo_small.png
--------------------------------------------------------------------------------
/docs/source/_images/performance_noflux.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_noflux.pdf
--------------------------------------------------------------------------------
/docs/source/_images/performance_noflux.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_noflux.png
--------------------------------------------------------------------------------
/docs/source/_images/performance_periodic.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_periodic.pdf
--------------------------------------------------------------------------------
/docs/source/_images/performance_periodic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/docs/source/_images/performance_periodic.png
--------------------------------------------------------------------------------
/docs/source/_static/custom.css:
--------------------------------------------------------------------------------
1 | .sphx-glr-download-link-note {
2 | display: none;
3 | }
4 |
5 | .sphx-glr-script-out {
6 | display: none;
7 | }
8 |
9 | .wy-table-responsive table td {
10 | white-space: inherit;
11 | }
--------------------------------------------------------------------------------
/docs/source/_static/requirements_main.csv:
--------------------------------------------------------------------------------
1 | Package,Minimal version,Usage
2 | matplotlib,3.1,Visualizing results
3 | numba,0.59,Just-in-time compilation to accelerate numerics
4 | numpy,1.22,Handling numerical data
5 | scipy,1.10,Miscellaneous scientific functions
6 | sympy,1.9,Dealing with user-defined mathematical expressions
7 | tqdm,4.66,Displaying progress bars during calculations
8 |
--------------------------------------------------------------------------------
/docs/source/_static/requirements_optional.csv:
--------------------------------------------------------------------------------
1 | Package,Minimal version,Usage
2 | ffmpeg-python,0.2,Reading and writing videos
3 | h5py,2.10,Storing data in the hierarchical file format
4 | ipywidgets,8,Jupyter notebook support
5 | mpi4py,3,Parallel processing using MPI
6 | napari,0.4.8,Displaying images interactively
7 | numba-mpi,0.22,Parallel processing using MPI+numba
8 | pandas,2,Handling tabular data
9 | py-modelrunner,0.19,Running simulations and handling I/O
10 | pyfftw,0.12,Faster Fourier transforms
11 | rocket-fft,0.2.4,Numba-compiled fast Fourier transforms
12 |
--------------------------------------------------------------------------------
/docs/source/gallery.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | These are example scripts using the `py-pde` package, which illustrate some of its
5 | most important features.
6 |
7 |
8 | .. include:: /examples_gallery/fields/index.rst
9 | .. include:: /examples_gallery/simple_pdes/index.rst
10 | .. include:: /examples_gallery/output/index.rst
11 | .. include:: /examples_gallery/advanced_pdes/index.rst
12 |
13 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | 'py-pde' Python package
2 | =======================
3 |
4 | .. figure:: _images/logo.png
5 | :figwidth: 25%
6 | :align: right
7 | :alt: Logo of the py-pde package
8 |
9 | The `py-pde` Python package provides methods and classes useful for solving
10 | partial differential equations (PDEs) of the form
11 |
12 | .. math::
13 | \partial_t u(\boldsymbol x, t) = \mathcal D[u(\boldsymbol x, t)]
14 | + \eta(u, \boldsymbol x, t) \;,
15 |
16 | where :math:`\mathcal D` is a (non-linear) operator containing spatial derivatives that
17 | defines the time evolution of a (set of) physical fields :math:`u` with possibly
18 | tensorial character, which depend on spatial coordinates :math:`\boldsymbol x` and time
19 | :math:`t`.
20 | The framework also supports stochastic differential equations in the Itô
21 | representation, where the noise is represented by :math:`\eta` above.
22 |
23 | The main audience for the package is researchers and students who want to
24 | investigate the behavior of a PDE and get an intuitive understanding of the
25 | role of the different terms and the boundary conditions.
26 | To support this, `py-pde` evaluates PDEs using the method of lines with a
27 | finite-difference approximation of the differential operators.
28 | Consequently, the mathematical operator :math:`\mathcal D` can be naturally
29 | translated to a function evaluating the evolution rate of the PDE.
30 |
31 |
32 | **Contents**
33 |
34 | .. toctree-filt::
35 | :maxdepth: 2
36 | :numbered:
37 |
38 | getting_started
39 | gallery
40 | manual/index
41 | packages/pde
42 |
43 |
44 | **Indices and tables**
45 |
46 | * :ref:`genindex`
47 | * :ref:`modindex`
48 | * :ref:`search`
49 |
--------------------------------------------------------------------------------
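
The method-of-lines workflow described in index.rst maps directly onto the package's
high-level API. A minimal sketch, assuming the classes (`UnitGrid`, `ScalarField`,
`DiffusionPDE`) that the examples elsewhere in this repository import from `pde`:

from pde import DiffusionPDE, ScalarField, UnitGrid

grid = UnitGrid([64, 64])                      # discretize space on a 64x64 grid
state = ScalarField.random_uniform(grid)       # initial field u(x, t=0)
eq = DiffusionPDE(diffusivity=1.0)             # D[u] is here a diffusion operator
result = eq.solve(state, t_range=10, dt=0.01)  # integrate along the time "lines"
result.plot()

Here, `eq.solve` advances the spatially discretized field in time, which is exactly how
the operator D is "naturally translated to a function evaluating the evolution rate".
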
/docs/source/manual/citing.rst:
--------------------------------------------------------------------------------
1 | Citing the package
2 | ^^^^^^^^^^^^^^^^^^
3 |
4 | To cite or reference `py-pde` in other work, please refer to the `publication in
5 | the Journal of Open Source Software <https://doi.org/10.21105/joss.02158>`_.
6 | Here are the respective bibliographic records in Bibtex format:
7 |
8 | .. code-block:: bibtex
9 |
10 | @article{py-pde,
11 | Author = {David Zwicker},
12 | Doi = {10.21105/joss.02158},
13 | Journal = {Journal of Open Source Software},
14 | Number = {48},
15 | Pages = {2158},
16 | Publisher = {The Open Journal},
17 | Title = {py-pde: A Python package for solving partial differential equations},
18 | Url = {https://doi.org/10.21105/joss.02158},
19 | Volume = {5},
20 | Year = {2020}
21 | }
22 |
23 | and in RIS format:
24 |
25 | .. code-block:: text
26 |
27 | TY - JOUR
28 | AU - Zwicker, David
29 | JO - Journal of Open Source Software
30 | IS - 48
31 | SP - 2158
32 | PB - The Open Journal
33 | T1 - py-pde: A Python package for solving partial differential equations
34 | UR - https://doi.org/10.21105/joss.02158
35 | VL - 5
36 | PY - 2020
37 |
--------------------------------------------------------------------------------
/docs/source/manual/code_of_conduct.rst:
--------------------------------------------------------------------------------
1 | Code of Conduct
2 | ===============
3 |
4 | Our Pledge
5 | ----------
6 |
7 | In the interest of fostering an open and welcoming environment, we as
8 | contributors and maintainers pledge to making participation in our
9 | project and our community a harassment-free experience for everyone,
10 | regardless of age, body size, disability, ethnicity, sex
11 | characteristics, gender identity and expression, level of experience,
12 | education, socio-economic status, nationality, personal appearance,
13 | race, religion, or sexual identity and orientation.
14 |
15 | Our Standards
16 | -------------
17 |
18 | Examples of behavior that contributes to creating a positive environment
19 | include:
20 |
21 | - Using welcoming and inclusive language
22 | - Being respectful of differing viewpoints and experiences
23 | - Gracefully accepting constructive criticism
24 | - Focusing on what is best for the community
25 | - Showing empathy towards other community members
26 |
27 | Examples of unacceptable behavior by participants include:
28 |
29 | - The use of sexualized language or imagery and unwelcome sexual
30 | attention or advances
31 | - Trolling, insulting/derogatory comments, and personal or political
32 | attacks
33 | - Public or private harassment
34 | - Publishing others’ private information, such as a physical or
35 | electronic address, without explicit permission
36 | - Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | Our Responsibilities
40 | --------------------
41 |
42 | Project maintainers are responsible for clarifying the standards of
43 | acceptable behavior and are expected to take appropriate and fair
44 | corrective action in response to any instances of unacceptable behavior.
45 |
46 | Project maintainers have the right and responsibility to remove, edit,
47 | or reject comments, commits, code, wiki edits, issues, and other
48 | contributions that are not aligned to this Code of Conduct, or to ban
49 | temporarily or permanently any contributor for other behaviors that they
50 | deem inappropriate, threatening, offensive, or harmful.
51 |
52 | Scope
53 | -----
54 |
55 | This Code of Conduct applies both within project spaces and in public
56 | spaces when an individual is representing the project or its community.
57 | Examples of representing a project or community include using an
58 | official project e-mail address, posting via an official social media
59 | account, or acting as an appointed representative at an online or
60 | offline event. Representation of a project may be further defined and
61 | clarified by project maintainers.
62 |
63 | Enforcement
64 | -----------
65 |
66 | Instances of abusive, harassing, or otherwise unacceptable behavior may
67 | be reported by contacting the project team at david.zwicker@ds.mpg.de.
68 | All complaints will be reviewed and investigated and will result in a
69 | response that is deemed necessary and appropriate to the circumstances.
70 | The project team is obligated to maintain confidentiality with regard to
71 | the reporter of an incident. Further details of specific enforcement
72 | policies may be posted separately.
73 |
74 | Project maintainers who do not follow or enforce the Code of Conduct in
75 | good faith may face temporary or permanent repercussions as determined
76 | by other members of the project’s leadership.
77 |
78 | Attribution
79 | -----------
80 |
81 | This Code of Conduct is adapted from the `Contributor
82 | Covenant <https://www.contributor-covenant.org>`__, version 1.4,
83 | available at
84 | https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
85 |
86 | For answers to common questions about this code of conduct, see
87 | https://www.contributor-covenant.org/faq
88 |
--------------------------------------------------------------------------------
/docs/source/manual/index.rst:
--------------------------------------------------------------------------------
1 | User manual
2 | ===========
3 |
4 |
5 | .. toctree::
6 | :maxdepth: 3
7 |
8 |
9 | mathematical_basics
10 | basic_usage
11 | advanced_usage
12 | performance
13 | contributing
14 | citing
15 | code_of_conduct
--------------------------------------------------------------------------------
/docs/source/parse_examples.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import pathlib
4 |
5 | # Root directory of the package
6 | ROOT = pathlib.Path(__file__).absolute().parents[2]
7 | # directory where all the examples reside
8 | INPUT = ROOT / "examples"
9 | # directory to which the documents are written
10 | OUTPUT = ROOT / "docs" / "source" / "examples"
11 |
12 |
13 | def main():
14 | """Parse all examples and write them in a special example module."""
15 | # create the output directory
16 | OUTPUT.mkdir(parents=True, exist_ok=True)
17 |
18 | # iterate over all examples
19 | for path_in in INPUT.glob("*.py"):
20 | path_out = OUTPUT / (path_in.stem + ".rst")
21 | print(f"Found example {path_in}")
22 | with path_in.open("r") as file_in, path_out.open("w") as file_out:
23 | # write the header for the rst file
24 | file_out.write(".. code-block:: python\n\n")
25 |
26 | # add the actual code lines
27 | header = True
28 | for line in file_in:
29 | # skip the shebang, comments and empty lines in the beginning
30 | if header and (line.startswith("#") or len(line.strip()) == 0):
31 | continue
32 | header = False # first real line was reached
33 | file_out.write(" " + line)
34 |
35 |
36 | if __name__ == "__main__":
37 | main()
38 |
--------------------------------------------------------------------------------
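
The header-skipping loop in `main()` above is the only subtle part of this script: it
drops the shebang, leading comments, and blank lines, then indents everything else under
a `.. code-block:: python` directive. A self-contained check of the same logic on an
invented toy example (the four-space indent is an assumption about the stripped
whitespace in the script above):

example = """#!/usr/bin/env python3
# leading comment that is dropped

from pde import UnitGrid
grid = UnitGrid([32])
"""

lines = [".. code-block:: python", ""]
header = True
for line in example.splitlines():
    # skip the shebang, comments, and empty lines at the beginning
    if header and (line.startswith("#") or not line.strip()):
        continue
    header = False  # the first real code line was reached
    lines.append("    " + line)

print("\n".join(lines))
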
/docs/source/run_autodoc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import glob
4 | import logging
5 | import os
6 | import subprocess as sp
7 | from pathlib import Path
8 |
9 | logging.basicConfig(level=logging.INFO)
10 |
11 | OUTPUT_PATH = "packages"
12 | REPLACEMENTS = {
13 | "Submodules\n----------\n\n": "",
14 | "Subpackages\n-----------": "**Subpackages:**",
15 | "pde package\n===========": "Reference manual\n================",
16 | }
17 |
18 |
19 | def replace_in_file(infile, replacements, outfile=None):
20 | """Reads in a file, replaces the given data using python formatting and writes back
21 | the result to a file.
22 |
23 | Args:
24 | infile (str):
25 | File to be read
26 | replacements (dict):
27 | The replacements old => new in a dictionary format {old: new}
28 | outfile (str):
29 | Output file to which the data is written. If it is omitted, the
30 | input file will be overwritten instead
31 | """
32 | if outfile is None:
33 | outfile = infile
34 |
35 | with Path(infile).open() as fp:
36 | content = fp.read()
37 |
38 | for key, value in replacements.items():
39 | content = content.replace(key, value)
40 |
41 | with Path(outfile).open("w") as fp:
42 | fp.write(content)
43 |
44 |
45 | def main():
46 | """Run the autodoc call."""
47 | logger = logging.getLogger("autodoc")
48 |
49 | # remove old files
50 | for path in Path(OUTPUT_PATH).glob("*.rst"):
51 | logger.info("Remove file `%s`", path)
52 | path.unlink()
53 |
54 | # run sphinx-apidoc
55 | sp.check_call(
56 | [
57 | "sphinx-apidoc",
58 | "--separate",
59 | "--maxdepth",
60 | "4",
61 | "--output-dir",
62 | OUTPUT_PATH,
63 | "--module-first",
64 | "../../pde", # path of the package
65 | "../../pde/version.py", # ignored file
66 | "../../**/conftest.py", # ignored file
67 | "../../**/tests", # ignored path
68 | ]
69 | )
70 |
71 | # replace unwanted information
72 | for path in Path(OUTPUT_PATH).glob("*.rst"):
73 | logger.info("Patch file `%s`", path)
74 | replace_in_file(path, REPLACEMENTS)
75 |
76 |
77 | if __name__ == "__main__":
78 | main()
79 |
--------------------------------------------------------------------------------
/docs/sphinx_ext/package_config.py:
--------------------------------------------------------------------------------
1 | from docutils import nodes
2 | from sphinx.util.docutils import SphinxDirective
3 |
4 |
5 | class PackageConfigDirective(SphinxDirective):
6 | """Directive that displays all package configuration items."""
7 |
8 | has_content = True
9 | required_arguments = 0
10 | optional_arguments = 0
11 | final_argument_whitespace = False
12 |
13 | def run(self):
14 | from pde.tools.config import Config
15 |
16 | c = Config()
17 | items = []
18 |
19 | for p in c.data.values():
20 | description = nodes.paragraph(text=p.description + " ")
21 | description += nodes.strong(text=f"(Default value: {c[p.name]!r})")
22 |
23 | items.append(nodes.definition_list_item(
24 | "",
25 | nodes.term(text=p.name),
26 | nodes.definition("", description),
27 | ))
28 |
29 | return [nodes.definition_list("", *items)]
30 |
31 |
32 | def setup(app):
33 | app.add_directive("package_configuration", PackageConfigDirective)
34 | return {"version": "1.0.0"}
35 |
--------------------------------------------------------------------------------
/docs/sphinx_ext/simplify_typehints.py:
--------------------------------------------------------------------------------
1 | """Simple sphinx plug-in that simplifies type information in function signatures."""
2 |
3 | import re
4 |
5 | # simple (literal) replacement rules
6 | REPLACEMENTS = [
7 | # numbers and numerical arrays
8 | ("Union[int, float, complex, numpy.generic, numpy.ndarray]", "NumberOrArray"),
9 | ("Union[int, float, complex, numpy.ndarray]", "NumberOrArray"),
10 | ("Union[int, float, complex]", "Number"),
11 | (
12 | "Optional[Union[_SupportsArray[dtype], _NestedSequence[_SupportsArray[dtype]], "
13 | "bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, "
14 | "float, complex, str, bytes]]]]",
15 | "NumberOrArray",
16 | ),
17 | (
18 | "Union[dtype[Any], None, Type[Any], _SupportsDType[dtype[Any]], str, "
19 | "Tuple[Any, int], Tuple[Any, Union[SupportsIndex, Sequence[SupportsIndex]]], "
20 | "List[Any], _DTypeDict, Tuple[Any, Any]]",
21 | "DType",
22 | ),
23 | # Complex types describing the boundary conditions
24 | (
25 | "Dict[str, Dict | str | BCBase] | Dict | str | BCBase | "
26 | "Tuple[Dict | str | BCBase, Dict | str | BCBase] | BoundaryAxisBase | "
27 | "Sequence[Dict[str, Dict | str | BCBase] | Dict | str | BCBase | "
28 | "Tuple[Dict | str | BCBase, Dict | str | BCBase] | BoundaryAxisBase]",
29 | "BoundariesData",
30 | ),
31 | (
32 | "Dict[str, Dict | str | BCBase] | Dict | str | BCBase | "
33 | "Tuple[Dict | str | BCBase, Dict | str | BCBase] | BoundaryAxisBase",
34 | "BoundariesPairData",
35 | ),
36 | ("Dict | str | BCBase", "BoundaryData"),
37 | # Other compound data types
38 | ("Union[List[Union[TrackerBase, str]], TrackerBase, str, None]", "TrackerData"),
39 | (
40 | "Optional[Union[List[Union[TrackerBase, str]], TrackerBase, str]]",
41 | "TrackerData",
42 | ),
43 | ]
44 |
45 |
46 | # replacement rules based on regular expressions
47 | REPLACEMENTS_REGEX = {
48 | # remove full package path and only leave the module/class identifier
49 | r"pde\.(\w+\.)*": "",
50 | r"typing\.": "",
51 | }
52 |
53 |
54 | def process_signature(
55 | app, what: str, name: str, obj, options, signature, return_annotation
56 | ):
57 | """Process signature by applying replacement rules."""
58 |
59 | def process(sig_obj):
60 | """Process the signature object."""
61 | if sig_obj is not None:
62 | for key, value in REPLACEMENTS_REGEX.items():
63 | sig_obj = re.sub(key, value, sig_obj)
64 | for key, value in REPLACEMENTS:
65 | sig_obj = sig_obj.replace(key, value)
66 | return sig_obj
67 |
68 | signature = process(signature)
69 | return_annotation = process(return_annotation)
70 |
71 | return signature, return_annotation
72 |
73 |
74 | def process_docstring(app, what: str, name: str, obj, options, lines):
75 | """Process docstring by applying replacement rules."""
76 | for i, line in enumerate(lines):
77 | for key, value in REPLACEMENTS:
78 | line = line.replace(key, value)
79 | lines[i] = line
80 |
81 |
82 | def setup(app):
83 | """Set up hooks for this sphinx plugin."""
84 | app.connect("autodoc-process-signature", process_signature)
85 | app.connect("autodoc-process-docstring", process_docstring)
86 |
--------------------------------------------------------------------------------
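
The two rule sets are easiest to understand on a concrete signature string. A minimal
sketch of what the inner `process()` helper does, using one literal rule and the two
regex rules from the plug-in above (the sample signature is invented for illustration):

import re

REPLACEMENTS = [("Union[int, float, complex]", "Number")]
REPLACEMENTS_REGEX = {r"pde\.(\w+\.)*": "", r"typing\.": ""}

sig = "(state: pde.fields.scalar.ScalarField, t: typing.Union[int, float, complex])"
for key, value in REPLACEMENTS_REGEX.items():
    sig = re.sub(key, value, sig)  # strip package paths and the typing prefix
for key, value in REPLACEMENTS:
    sig = sig.replace(key, value)  # collapse verbose unions into short aliases

print(sig)  # (state: ScalarField, t: Number)
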
/docs/sphinx_ext/toctree_filter.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from sphinx.directives.other import TocTree
4 |
5 |
6 | class TocTreeFilter(TocTree):
7 | """Directive to filter table-of-contents entries."""
8 |
9 | hasPat = re.compile(r"^\s*:(.+):(.+)$")
10 |
11 | # Remove any entries that should be excluded and strip the filter
12 | # prefix from the entries that are kept, so that the prefix does
13 | # not interfere with the file name.
14 | def filter_entries(self, entries):
15 | excl = self.state.document.settings.env.config.toc_filter_exclude
16 | filtered = []
17 | for e in entries:
18 | m = self.hasPat.match(e)
19 | if m is not None:
20 | if m.groups()[0] not in excl:
21 | filtered.append(m.groups()[1])
22 | else:
23 | filtered.append(e)
24 | return filtered
25 |
26 | def run(self):
27 | # Remove all TOC entries that should not be on display
28 | self.content = self.filter_entries(self.content)
29 | return super().run()
30 |
31 |
32 | def setup(app):
33 | app.add_config_value("toc_filter_exclude", [], "html")
34 | app.add_directive("toctree-filt", TocTreeFilter)
35 | return {"version": "1.0.0"}
36 |
--------------------------------------------------------------------------------
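
To make the filtering concrete: toctree entries of the form `:tag:name` are kept (with
the prefix stripped) unless their tag appears in `toc_filter_exclude`. A small sketch
with invented entry names; the `exclude_gallery` tag matches the one passed by the
`latexpdf` target in docs/Makefile, though the mapping into `toc_filter_exclude` happens
in conf.py, which is not shown in this dump:

import re

has_pat = re.compile(r"^\s*:(.+):(.+)$")  # matches ":tag:entry" lines
excl = ["exclude_gallery"]  # assumed value of toc_filter_exclude

entries = ["getting_started", ":exclude_gallery:gallery", "manual/index"]
filtered = []
for e in entries:
    m = has_pat.match(e)
    if m is not None:
        if m.groups()[0] not in excl:
            filtered.append(m.groups()[1])  # keep entry, strip the ":tag:" prefix
    else:
        filtered.append(e)  # untagged entries are always kept

print(filtered)  # ['getting_started', 'manual/index']
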
/examples/README.txt:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | These are example scripts using the `py-pde` package, which illustrate some of its
5 | most important features.
--------------------------------------------------------------------------------
/examples/advanced_pdes/README.txt:
--------------------------------------------------------------------------------
1 | Advanced PDEs
2 | -------------
3 |
4 | These examples demonstrate more advanced usage of the package.
--------------------------------------------------------------------------------
/examples/advanced_pdes/custom_noise.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom noise
3 | ============
4 |
5 | This example solves a diffusion equation with a custom noise.
6 | """
7 |
8 | import numpy as np
9 |
10 | from pde import DiffusionPDE, ScalarField, UnitGrid
11 | from pde.tools.numba import jit
12 |
13 |
14 | class DiffusionCustomNoisePDE(DiffusionPDE):
15 | """Diffusion PDE with custom noise implementations."""
16 |
17 | def noise_realization(self, state, t):
18 | """Numpy implementation of spatially-dependent noise."""
19 | noise_field = ScalarField.random_uniform(state.grid, -self.noise, self.noise)
20 | return state.grid.cell_coords[..., 0] * noise_field
21 |
22 | def _make_noise_realization_numba(self, state):
23 | """Numba implementation of spatially-dependent noise."""
24 | noise = float(self.noise)
25 | x_values = state.grid.cell_coords[..., 0]
26 |
27 | @jit
28 | def noise_realization(state_data, t):
29 | return x_values * np.random.uniform(-noise, noise, size=state_data.shape)
30 |
31 | return noise_realization
32 |
33 |
34 | eq = DiffusionCustomNoisePDE(diffusivity=0.1, noise=0.1) # define the pde
35 | state = ScalarField.random_uniform(UnitGrid([64, 64])) # generate initial condition
36 | result = eq.solve(state, t_range=10, dt=0.01)
37 | result.plot()
38 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/heterogeneous_bcs.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Heterogeneous boundary conditions
3 | =================================
4 |
5 | This example implements a diffusion equation with a boundary condition specified by a
6 | function, which can in principle depend on time.
7 | """
8 |
9 | import numpy as np
10 |
11 | from pde import CartesianGrid, DiffusionPDE, ScalarField
12 |
13 | # define grid and an initial state
14 | grid = CartesianGrid([[-5, 5], [-5, 5]], 32)
15 | field = ScalarField(grid)
16 |
17 |
18 | # define the boundary conditions, which here are calculated from a function
19 | def bc_value(adjacent_value, dx, x, y, t):
20 | """Return boundary value."""
21 | return np.sign(x)
22 |
23 |
24 | # define and solve a simple diffusion equation
25 | eq = DiffusionPDE(bc={"*": {"derivative": 0}, "y+": {"value_expression": bc_value}})
26 | res = eq.solve(field, t_range=10, dt=0.01, backend="numpy")
27 | res.plot()
28 |
--------------------------------------------------------------------------------
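
The docstring above notes that the boundary value can in principle depend on time. A
minimal variant of `bc_value` illustrating this, keeping the same call signature used in
the example (the particular oscillation is invented for illustration):

import numpy as np

def bc_value_oscillating(adjacent_value, dx, x, y, t):
    """Hypothetical boundary value that oscillates in time."""
    return np.sign(x) * np.cos(t)
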
/examples/advanced_pdes/mpi_parallel_run.py:
--------------------------------------------------------------------------------
1 | """
2 | Use multiprocessing via MPI
3 | ===========================
4 |
5 | Use multiple cores to solve a PDE. The implementation here uses the `Message Passing
6 | Interface (MPI) <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_, and the
7 | script thus needs to be run using :code:`mpiexec -n 2 python mpi_parallel_run.py`, where
8 | `2` denotes the number of cores used. Note that macOS might require an additional hint
9 | on how to connect the processes. The following line might work:
10 | `mpiexec -n 2 -host localhost:2 python3 mpi_parallel_run.py`
11 |
12 | Such parallel simulations need extra care, since multiple instances of the same program
13 | are started. In particular, in the example below, the initial state is created on all
14 | cores. However, only the state of the first core will actually be used and distributed
15 | automatically by `py-pde`. Note that also only the first (or main) core will run the
16 | trackers and receive the result of the simulation. On all other cores, the simulation
17 | result will be `None`.
18 | """
19 |
20 | from pde import DiffusionPDE, ScalarField, UnitGrid
21 |
22 | grid = UnitGrid([64, 64]) # generate grid
23 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition
24 |
25 | eq = DiffusionPDE(diffusivity=0.1) # define the pde
26 | result = eq.solve(state, t_range=10, dt=0.1, solver="explicit_mpi")
27 |
28 | if result is not None: # check whether we are on the main core
29 | result.plot()
30 |
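The `result is not None` check above is one way to identify the main core. Rank information can also be queried explicitly; a minimal sketch, assuming :mod:`pde.tools.mpi` exposes the `is_main` flag and `rank` attribute:

.. code-block:: python

    # hedged sketch: explicit rank handling (assumes pde.tools.mpi attributes)
    from pde.tools import mpi

    if mpi.is_main:  # True only on the main core, where `result` is set
        result.plot()
    else:
        print(f"core {mpi.rank} finished; result is None here")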
--------------------------------------------------------------------------------
/examples/advanced_pdes/pde_1d_class.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 1D problem - Using custom class
3 | ===============================
4 |
5 | This example implements a PDE that is only defined in one dimension.
6 | Here, we chose the `Korteweg-de Vries equation
7 | <https://en.wikipedia.org/wiki/Korteweg-de_Vries_equation>`_, given by
8 |
9 | .. math::
10 | \partial_t \phi = 6 \phi \partial_x \phi - \partial_x^3 \phi
11 |
12 | which we implement using a custom PDE class below.
13 | """
14 |
15 | from math import pi
16 |
17 | from pde import CartesianGrid, MemoryStorage, PDEBase, ScalarField, plot_kymograph
18 |
19 |
20 | class KortewegDeVriesPDE(PDEBase):
21 | """Korteweg-de Vries equation."""
22 |
23 | def evolution_rate(self, state, t=0):
24 | """Implement the python version of the evolution equation."""
25 | assert state.grid.dim == 1 # ensure the state is one-dimensional
26 | grad_x = state.gradient("auto_periodic_neumann")[0]
27 | return 6 * state * grad_x - grad_x.laplace("auto_periodic_neumann")
28 |
29 |
30 | # initialize the equation and the space
31 | grid = CartesianGrid([[0, 2 * pi]], [32], periodic=True)
32 | state = ScalarField.from_expression(grid, "sin(x)")
33 |
34 | # solve the equation and store the trajectory
35 | storage = MemoryStorage()
36 | eq = KortewegDeVriesPDE()
37 | eq.solve(state, t_range=3, solver="scipy", tracker=storage.tracker(0.1))
38 |
39 | # plot the trajectory as a space-time plot
40 | plot_kymograph(storage)
41 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/pde_brusselator_class.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Brusselator - Using custom class
3 | ================================
4 |
5 | This example implements the `Brusselator
6 | <https://en.wikipedia.org/wiki/Brusselator>`_ with spatial coupling,
7 |
8 | .. math::
9 |
10 | \partial_t u &= D_0 \nabla^2 u + a - (1 + b) u + v u^2 \\
11 | \partial_t v &= D_1 \nabla^2 v + b u - v u^2
12 |
13 | Here, :math:`D_0` and :math:`D_1` are the respective diffusivities and the
14 | parameters :math:`a` and :math:`b` are related to reaction rates.
15 |
16 | Note that the PDE can also be implemented using the :class:`~pde.pdes.pde.PDE`
17 | class; see :doc:`the example <../simple_pdes/pde_brusselator_expression>`. However, that
18 | implementation is less flexible and might be more difficult to extend later.
19 | """
20 |
21 | import numba as nb
22 | import numpy as np
23 |
24 | from pde import FieldCollection, PDEBase, PlotTracker, ScalarField, UnitGrid
25 |
26 |
27 | class BrusselatorPDE(PDEBase):
28 | """Brusselator with diffusive mobility."""
29 |
30 | def __init__(self, a=1, b=3, diffusivity=None, bc="auto_periodic_neumann"):
31 | super().__init__()
32 | self.a = a
33 | self.b = b
34 | self.diffusivity = [1, 0.1] if diffusivity is None else diffusivity
35 | self.bc = bc # boundary condition
36 |
37 | def get_initial_state(self, grid):
38 | """Prepare a useful initial state."""
39 | u = ScalarField(grid, self.a, label="Field $u$")
40 | v = self.b / self.a + 0.1 * ScalarField.random_normal(grid, label="Field $v$")
41 | return FieldCollection([u, v])
42 |
43 | def evolution_rate(self, state, t=0):
44 | """Pure python implementation of the PDE."""
45 | u, v = state
46 | rhs = state.copy()
47 | d0, d1 = self.diffusivity
48 | rhs[0] = d0 * u.laplace(self.bc) + self.a - (self.b + 1) * u + u**2 * v
49 | rhs[1] = d1 * v.laplace(self.bc) + self.b * u - u**2 * v
50 | return rhs
51 |
52 | def _make_pde_rhs_numba(self, state):
53 | """Nunmba-compiled implementation of the PDE."""
54 | d0, d1 = self.diffusivity
55 | a, b = self.a, self.b
56 | laplace = state.grid.make_operator("laplace", bc=self.bc)
57 |
58 | @nb.njit
59 | def pde_rhs(state_data, t):
60 | u = state_data[0]
61 | v = state_data[1]
62 |
63 | rate = np.empty_like(state_data)
64 | rate[0] = d0 * laplace(u) + a - (1 + b) * u + v * u**2
65 | rate[1] = d1 * laplace(v) + b * u - v * u**2
66 | return rate
67 |
68 | return pde_rhs
69 |
70 |
71 | # initialize state
72 | grid = UnitGrid([64, 64])
73 | eq = BrusselatorPDE(diffusivity=[1, 0.1])
74 | state = eq.get_initial_state(grid)
75 |
76 | # simulate the pde
77 | tracker = PlotTracker(interrupts=1, plot_args={"kind": "merged", "vmin": 0, "vmax": 5})
78 | sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker)
79 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/pde_coupled.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Custom Class for coupled PDEs
3 | =============================
4 |
5 | This example shows how to solve a set of coupled PDEs, the
6 | spatially coupled `FitzHugh–Nagumo model
7 | <https://en.wikipedia.org/wiki/FitzHugh-Nagumo_model>`_, which is a simple model
8 | for the excitable dynamics of coupled neurons:
9 |
10 | .. math::
11 |
12 | \partial_t u &= \nabla^2 u + u (u - \alpha) (1 - u) + w \\
13 | \partial_t w &= \epsilon u
14 |
15 | Here, :math:`\alpha` denotes the external stimulus and :math:`\epsilon` defines
16 | the recovery time scale. We implement this as a custom PDE class below.
17 | """
18 |
19 | from pde import FieldCollection, PDEBase, UnitGrid
20 |
21 |
22 | class FitzhughNagumoPDE(PDEBase):
23 | """FitzHugh–Nagumo model with diffusive coupling."""
24 |
25 | def __init__(self, stimulus=0.5, τ=10, a=0, b=0, bc="auto_periodic_neumann"):
26 | super().__init__()
27 | self.bc = bc
28 | self.stimulus = stimulus
29 | self.τ = τ
30 | self.a = a
31 | self.b = b
32 |
33 | def evolution_rate(self, state, t=0):
34 | v, w = state # membrane potential and recovery variable
35 |
36 | v_t = v.laplace(bc=self.bc) + v - v**3 / 3 - w + self.stimulus
37 | w_t = (v + self.a - self.b * w) / self.τ
38 |
39 | return FieldCollection([v_t, w_t])
40 |
41 |
42 | grid = UnitGrid([32, 32])
43 | state = FieldCollection.scalar_random_uniform(2, grid)
44 |
45 | eq = FitzhughNagumoPDE()
46 | result = eq.solve(state, t_range=100, dt=0.01)
47 | result.plot()
48 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/pde_custom_class.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Kuramoto-Sivashinsky - Using custom class
3 | =========================================
4 |
5 | This example implements a scalar PDE using a custom class. We here consider the
6 | `Kuramoto–Sivashinsky equation
7 | <https://en.wikipedia.org/wiki/Kuramoto-Sivashinsky_equation>`_, which for instance
8 | describes the dynamics of flame fronts:
9 |
10 | .. math::
11 | \partial_t u = -\frac12 |\nabla u|^2 - \nabla^2 u - \nabla^4 u
12 | """
13 |
14 | from pde import PDEBase, ScalarField, UnitGrid
15 |
16 |
17 | class KuramotoSivashinskyPDE(PDEBase):
18 | """Implementation of the normalized Kuramoto–Sivashinsky equation."""
19 |
20 | def evolution_rate(self, state, t=0):
21 | """Implement the python version of the evolution equation."""
22 | state_lap = state.laplace(bc="auto_periodic_neumann")
23 | state_lap2 = state_lap.laplace(bc="auto_periodic_neumann")
24 | state_grad = state.gradient(bc="auto_periodic_neumann")
25 | return -state_grad.to_scalar("squared_sum") / 2 - state_lap - state_lap2
26 |
27 |
28 | grid = UnitGrid([32, 32]) # generate grid
29 | state = ScalarField.random_uniform(grid) # generate initial condition
30 |
31 | eq = KuramotoSivashinskyPDE() # define the pde
32 | result = eq.solve(state, t_range=10, dt=0.01)
33 | result.plot()
34 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/pde_custom_numba.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Kuramoto-Sivashinsky - Compiled methods
3 | =======================================
4 |
5 | This example implements a scalar PDE using a custom class with a numba-compiled method
6 | for accelerated calculations. We here consider the `Kuramoto–Sivashinsky equation
7 | <https://en.wikipedia.org/wiki/Kuramoto-Sivashinsky_equation>`_, which for instance
8 | describes the dynamics of flame fronts:
9 |
10 | .. math::
11 | \partial_t u = -\frac12 |\nabla u|^2 - \nabla^2 u - \nabla^4 u
12 | """
13 |
14 | import numba as nb
15 |
16 | from pde import PDEBase, ScalarField, UnitGrid
17 |
18 |
19 | class KuramotoSivashinskyPDE(PDEBase):
20 | """Implementation of the normalized Kuramoto–Sivashinsky equation."""
21 |
22 | def __init__(self, bc="auto_periodic_neumann"):
23 | super().__init__()
24 | self.bc = bc
25 |
26 | def evolution_rate(self, state, t=0):
27 | """Implement the python version of the evolution equation."""
28 | state_lap = state.laplace(bc=self.bc)
29 | state_lap2 = state_lap.laplace(bc=self.bc)
30 | state_grad_sq = state.gradient_squared(bc=self.bc)
31 | return -state_grad_sq / 2 - state_lap - state_lap2
32 |
33 | def _make_pde_rhs_numba(self, state):
34 | """Nunmba-compiled implementation of the PDE."""
35 | gradient_squared = state.grid.make_operator("gradient_squared", bc=self.bc)
36 | laplace = state.grid.make_operator("laplace", bc=self.bc)
37 |
38 | @nb.njit
39 | def pde_rhs(data, t):
40 | return -0.5 * gradient_squared(data) - laplace(data + laplace(data))
41 |
42 | return pde_rhs
43 |
44 |
45 | grid = UnitGrid([32, 32]) # generate grid
46 | state = ScalarField.random_uniform(grid) # generate initial condition
47 |
48 | eq = KuramotoSivashinskyPDE() # define the pde
49 | result = eq.solve(state, t_range=10, dt=0.01)
50 | result.plot()
51 |
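The compiled right-hand side can also be extracted and evaluated directly, which helps when debugging or benchmarking the numba path; a minimal sketch using only the names defined above:

.. code-block:: python

    # hedged sketch: call the numba-compiled right-hand side directly
    rhs = eq._make_pde_rhs_numba(state)  # factory defined in the class above
    rate = rhs(state.data, 0.0)          # operates on the raw data array
    print(rate.shape)                    # matches state.data.shape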
--------------------------------------------------------------------------------
/examples/advanced_pdes/pde_sir.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Custom PDE class: SIR model
3 | ===========================
4 |
5 | This example implements a `spatially coupled SIR model
6 | <https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology>`_ with
7 | the following dynamics for the density of susceptible, infected, and recovered
8 | individuals:
9 |
10 | .. math::
11 |
12 |     \partial_t s &= D \nabla^2 s - \beta i s \\
13 |     \partial_t i &= D \nabla^2 i + \beta i s - \gamma i \\
14 | \partial_t r &= D \nabla^2 r + \gamma i
15 |
16 | Here, :math:`D` is the diffusivity, :math:`\beta` the infection rate, and
17 | :math:`\gamma` the recovery rate.
18 | """
19 |
20 | from pde import FieldCollection, PDEBase, PlotTracker, ScalarField, UnitGrid
21 |
22 |
23 | class SIRPDE(PDEBase):
24 | """SIR-model with diffusive mobility."""
25 |
26 | def __init__(
27 | self, beta=0.3, gamma=0.9, diffusivity=0.1, bc="auto_periodic_neumann"
28 | ):
29 | super().__init__()
30 | self.beta = beta # transmission rate
31 | self.gamma = gamma # recovery rate
32 | self.diffusivity = diffusivity # spatial mobility
33 | self.bc = bc # boundary condition
34 |
35 | def get_state(self, s, i):
36 | """Generate a suitable initial state."""
37 | norm = (s + i).data.max() # maximal density
38 | if norm > 1:
39 | s /= norm
40 | i /= norm
41 | s.label = "Susceptible"
42 | i.label = "Infected"
43 |
44 | # create recovered field
45 | r = ScalarField(s.grid, data=1 - s - i, label="Recovered")
46 | return FieldCollection([s, i, r])
47 |
48 | def evolution_rate(self, state, t=0):
49 | s, i, r = state
50 | diff = self.diffusivity
51 | ds_dt = diff * s.laplace(self.bc) - self.beta * i * s
52 | di_dt = diff * i.laplace(self.bc) + self.beta * i * s - self.gamma * i
53 | dr_dt = diff * r.laplace(self.bc) + self.gamma * i
54 | return FieldCollection([ds_dt, di_dt, dr_dt])
55 |
56 |
57 | eq = SIRPDE(beta=2, gamma=0.1)
58 |
59 | # initialize state
60 | grid = UnitGrid([32, 32])
61 | s = ScalarField(grid, 1)
62 | i = ScalarField(grid, 0)
63 | i.data[0, 0] = 1
64 | state = eq.get_state(s, i)
65 |
66 | # simulate the pde
67 | tracker = PlotTracker(interrupts=10, plot_args={"vmin": 0, "vmax": 1})
68 | sol = eq.solve(state, t_range=50, dt=1e-2, tracker=["progress", tracker])
69 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/post_step_hook.py:
--------------------------------------------------------------------------------
1 | """
2 | Post-step hook function
3 | =======================
4 |
5 | Demonstrate the simple hook function in :class:`~pde.pdes.PDE`, which is called after
6 | each time step and may modify the state or abort the simulation.
7 | """
8 |
9 | from pde import PDE, ScalarField, UnitGrid
10 |
11 |
12 | def post_step_hook(state_data, t):
13 | """Helper function called after every time step."""
14 | state_data[24:40, 24:40] = 1 # set central region to given value
15 |
16 | if t > 1e3:
17 | raise StopIteration # abort simulation at given time
18 |
19 |
20 | eq = PDE({"c": "laplace(c)"}, post_step_hook=post_step_hook)
21 | state = ScalarField(UnitGrid([64, 64]))
22 | result = eq.solve(state, dt=0.1, t_range=1e4)
23 | result.plot()
24 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/post_step_hook_class.py:
--------------------------------------------------------------------------------
1 | """
2 | Post-step hook function in a custom class
3 | =========================================
4 |
5 | The hook function created by :meth:`~pde.pdes.PDEBase.make_post_step_hook` is called
6 | after each time step. The function can modify the state, keep track of additional
7 | information, and abort the simulation.
8 | """
9 |
10 | from pde import PDEBase, ScalarField, UnitGrid
11 |
12 |
13 | class CustomPDE(PDEBase):
14 | def make_post_step_hook(self, state):
15 | """Create a hook function that is called after every time step."""
16 |
17 | def post_step_hook(state_data, t, post_step_data):
18 | """Limit state 1 and abort when standard deviation exceeds 1."""
19 | i = state_data > 1 # get violating entries
20 | overshoot = (state_data[i] - 1).sum() # get total correction
21 | state_data[i] = 1 # limit data entries
22 | post_step_data += overshoot # accumulate total correction
23 | if post_step_data > 400:
24 | # Abort simulation when correction exceeds 400
25 | # Note that the `post_step_data` of the previous step will be returned.
26 | raise StopIteration
27 |
28 | return post_step_hook, 0.0 # hook function and initial value for data
29 |
30 | def evolution_rate(self, state, t=0):
31 | """Evaluate the right hand side of the evolution equation."""
32 | return state.__class__(state.grid, data=1) # constant growth
33 |
34 |
35 | grid = UnitGrid([64, 64]) # generate grid
36 | state = ScalarField.random_uniform(grid, 0.0, 0.5) # generate initial condition
37 |
38 | eq = CustomPDE()
39 | result = eq.solve(state, dt=0.1, t_range=1e4)
40 | result.plot(title=f"Total correction={eq.diagnostics['solver']['post_step_data']}")
41 |
--------------------------------------------------------------------------------
/examples/advanced_pdes/solver_comparison.py:
--------------------------------------------------------------------------------
1 | """
2 | Solver comparison
3 | =================
4 |
5 | This example shows how to set up solvers explicitly and how to extract
6 | diagnostic information.
7 | """
8 |
9 | import pde
10 |
11 | # initialize the grid, an initial condition, and the PDE
12 | grid = pde.UnitGrid([32, 32])
13 | field = pde.ScalarField.random_uniform(grid, -1, 1)
14 | eq = pde.DiffusionPDE()
15 |
16 | # try the explicit solver
17 | solver1 = pde.ExplicitSolver(eq)
18 | controller1 = pde.Controller(solver1, t_range=1, tracker=None)
19 | sol1 = controller1.run(field, dt=1e-3)
20 | sol1.label = "explicit solver"
21 | print("Diagnostic information from first run:")
22 | print(controller1.diagnostics)
23 | print()
24 |
25 | # try an explicit solver with adaptive time steps
26 | solver2 = pde.ExplicitSolver(eq, scheme="runge-kutta", adaptive=True)
27 | controller2 = pde.Controller(solver2, t_range=1, tracker=None)
28 | sol2 = controller2.run(field, dt=1e-3)
29 | sol2.label = "explicit, adaptive solver"
30 | print("Diagnostic information from second run:")
31 | print(controller2.diagnostics)
32 | print()
33 |
34 | # try the standard scipy solver
35 | solver3 = pde.ScipySolver(eq)
36 | controller3 = pde.Controller(solver3, t_range=1, tracker=None)
37 | sol3 = controller3.run(field)
38 | sol3.label = "scipy solver"
39 | print("Diagnostic information from third run:")
40 | print(controller3.diagnostics)
41 | print()
42 |
43 | # plot both fields and give the deviation as the title
44 | deviation12 = ((sol1 - sol2) ** 2).average
45 | deviation13 = ((sol1 - sol3) ** 2).average
46 | title = f"Deviation: {deviation12:.2g}, {deviation13:.2g}"
47 | pde.FieldCollection([sol1, sol2, sol3]).plot(title=title)
48 |
--------------------------------------------------------------------------------
/examples/fields/README.txt:
--------------------------------------------------------------------------------
1 | Grids and fields
2 | ----------------
3 |
4 | These examples show how to define and manipulate discretized fields.
--------------------------------------------------------------------------------
/examples/fields/analyze_scalar_field.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualizing a scalar field
3 | ==========================
4 |
5 | This example displays methods for visualizing scalar fields.
6 | """
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 |
11 | from pde import CylindricalSymGrid, ScalarField
12 |
13 | # create a scalar field with some noise
14 | grid = CylindricalSymGrid(7, [0, 4 * np.pi], 64)
15 | data = ScalarField.from_expression(grid, "sin(z) * exp(-r / 3)")
16 | data += 0.05 * ScalarField.random_normal(grid)
17 |
18 | # manipulate the field
19 | smoothed = data.smooth() # Gaussian smoothing to get rid of the noise
20 | projected = data.project("r") # integrate along the radial direction
21 | sliced = smoothed.slice({"z": 1}) # slice the smoothed data
22 |
23 | # create four plots of the field and the modifications
24 | fig, axes = plt.subplots(nrows=2, ncols=2)
25 | data.plot(ax=axes[0, 0], title="Original field")
26 | smoothed.plot(ax=axes[1, 0], title="Smoothed field")
27 | projected.plot(ax=axes[0, 1], title="Projection on axial coordinate")
28 | sliced.plot(ax=axes[1, 1], title="Slice of smoothed field at $z=1$")
29 | plt.subplots_adjust(hspace=0.8)
30 | plt.show()
31 |
--------------------------------------------------------------------------------
/examples/fields/finite_differences.py:
--------------------------------------------------------------------------------
1 | """
2 | Finite differences approximation
3 | ================================
4 |
5 | This example displays various finite difference (FD) approximations of derivatives of
6 | a simple harmonic function.
7 | """
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 |
12 | from pde import CartesianGrid, ScalarField
13 | from pde.tools.expressions import evaluate
14 |
15 | # create two grids with different resolutions to emphasize the finite difference approximation
16 | grid_fine = CartesianGrid([(0, 2 * np.pi)], 256, periodic=True)
17 | grid_coarse = CartesianGrid([(0, 2 * np.pi)], 10, periodic=True)
18 |
19 | # create figure to present plots of the derivative
20 | fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
21 |
22 | # plot first derivatives of sin(x)
23 | f = ScalarField.from_expression(grid_coarse, "sin(x)")
24 | f_grad = f.gradient("periodic") # first derivative (from gradient vector field)
25 | ScalarField.from_expression(grid_fine, "cos(x)").plot(
26 | ax=axes[0, 0], label="Expected f'"
27 | )
28 | f_grad.plot(ax=axes[0, 0], label="FD grad(f)", ls="", marker="o")
29 | plt.legend(frameon=True)
30 | plt.ylabel("")
31 | plt.xlabel("")
32 | plt.title(r"First derivative of $f(x) = \sin(x)$")
33 |
34 | # plot second derivatives of sin(x)
35 | f_laplace = f.laplace("periodic") # second derivative
36 | f_grad2 = f_grad.divergence("periodic") # second derivative using composition
37 | ScalarField.from_expression(grid_fine, "-sin(x)").plot(
38 | ax=axes[0, 1], label="Expected f''"
39 | )
40 | f_laplace.plot(ax=axes[0, 1], label="FD laplace(f)", ls="", marker="o")
41 | f_grad2.plot(ax=axes[0, 1], label="FD div(grad(f))", ls="", marker="o")
42 | plt.legend(frameon=True)
43 | plt.xlabel("")
44 | plt.title(r"Second derivative of $f(x) = \sin(x)$")
45 |
46 | # plot first derivatives of sin(x)**2
47 | g_fine = ScalarField.from_expression(grid_fine, "sin(x)**2")
48 | g = g_fine.interpolate_to_grid(grid_coarse)
49 | expected = evaluate("2 * cos(x) * sin(x)", {"g": g_fine})
50 | fd_1 = evaluate("d_dx(g)", {"g": g}) # first derivative (from directional derivative)
51 | expected.plot(ax=axes[1, 0], label="Expected g'")
52 | fd_1.plot(ax=axes[1, 0], label="FD grad(g)", ls="", marker="o")
53 | plt.legend(frameon=True)
54 | plt.title(r"First derivative of $g(x) = \sin(x)^2$")
55 |
56 | # plot second derivatives of sin(x)**2
57 | expected = evaluate("2 * cos(2 * x)", {"g": g_fine})
58 | fd_2 = evaluate("d2_dx2(g)", {"g": g}) # second derivative
59 | fd_11 = evaluate("d_dx(d_dx(g))", {"g": g}) # composition of first derivatives
60 | expected.plot(ax=axes[1, 1], label="Expected g''")
61 | fd_2.plot(ax=axes[1, 1], label="FD laplace(g)", ls="", marker="o")
62 | fd_11.plot(ax=axes[1, 1], label="FD div(grad(g))", ls="", marker="o")
63 | plt.legend(frameon=True)
64 | plt.title(r"Second derivative of $g(x) = \sin(x)^2$")
65 |
66 | # finalize plot
67 | plt.tight_layout()
68 | plt.show()
69 |
--------------------------------------------------------------------------------
/examples/fields/plot_cylindrical_field.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Plotting a scalar field in cylindrical coordinates
3 | ==================================================
4 |
5 | This example shows how to initialize and visualize the scalar field
6 | :math:`u = \sqrt{z} \, \exp(-r^2)` in cylindrical coordinates.
7 | """
8 |
9 | from pde import CylindricalSymGrid, ScalarField
10 |
11 | grid = CylindricalSymGrid(radius=3, bounds_z=[0, 4], shape=16)
12 | field = ScalarField.from_expression(grid, "sqrt(z) * exp(-r**2)")
13 | field.plot(title="Scalar field in cylindrical coordinates")
14 |
--------------------------------------------------------------------------------
/examples/fields/plot_polar_grid.py:
--------------------------------------------------------------------------------
1 | """
2 | Plot a polar grid
3 | =================
4 |
5 | This example shows how to initialize a polar grid with a hole inside and angular
6 | symmetry, so that fields only depend on the radial coordinate.
7 | """
8 |
9 | from pde import PolarSymGrid
10 |
11 | grid = PolarSymGrid((2, 5), 8)
12 | grid.plot(title=f"Area={grid.volume:.5g}")
13 |
--------------------------------------------------------------------------------
/examples/fields/plot_vector_field.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Plotting a vector field
3 | =======================
4 |
5 | This example shows how to initialize and visualize the vector field
6 | :math:`\boldsymbol u = \bigl(\sin(x), \cos(x)\bigr)`.
7 | """
8 |
9 | from pde import CartesianGrid, VectorField
10 |
11 | grid = CartesianGrid([[-2, 2], [-2, 2]], 32)
12 | field = VectorField.from_expression(grid, ["sin(x)", "cos(x)"])
13 | field.plot(method="streamplot", title="Stream plot")
14 |
--------------------------------------------------------------------------------
/examples/fields/random_fields.py:
--------------------------------------------------------------------------------
1 | """
2 | Random scalar fields
3 | ====================
4 |
5 | This example showcases several random fields.
6 | """
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 |
11 | from pde import ScalarField, UnitGrid
12 |
13 | # initialize grid and plot figure
14 | grid = UnitGrid([256, 256], periodic=True)
15 | fig, axes = plt.subplots(nrows=2, ncols=2)
16 |
17 | f1 = ScalarField.random_uniform(grid, -2.5, 2.5)
18 | f1.plot(ax=axes[0, 0], title="Uniform, uncorrelated")
19 |
20 | f2 = ScalarField.random_normal(grid, correlation="power law", exponent=-6)
21 | f2.plot(ax=axes[0, 1], title="Gaussian, power-law correlated")
22 |
23 | f3 = ScalarField.random_normal(grid, correlation="cosine", length_scale=30)
24 | f3.plot(ax=axes[1, 0], title="Gaussian, cosine correlated")
25 |
26 | f4 = ScalarField.random_harmonic(grid, modes=4)
27 | f4.plot(ax=axes[1, 1], title="Combined harmonic functions")
28 |
29 | plt.subplots_adjust(hspace=0.8)
30 | plt.show()
31 |
--------------------------------------------------------------------------------
/examples/fields/show_3d_field_interactively.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualizing a 3d field interactively
3 | ====================================
4 |
5 | This example demonstrates how to display 3d data interactively using the
6 | `napari <https://napari.org>`_ viewer.
7 | """
8 |
9 | import numpy as np
10 |
11 | from pde import CartesianGrid, ScalarField
12 |
13 | # create a scalar field with some noise
14 | grid = CartesianGrid([[0, 2 * np.pi]] * 3, 64)
15 | data = ScalarField.from_expression(grid, "(cos(2 * x) * sin(3 * y) + cos(2 * z))**2")
16 | data += ScalarField.random_normal(grid, std=0.1)
17 | data.label = "3D Field"
18 |
19 | data.plot_interactive()
20 |
--------------------------------------------------------------------------------
/examples/jupyter/.gitignore:
--------------------------------------------------------------------------------
1 | *.hdf
--------------------------------------------------------------------------------
/examples/jupyter/Discretized Fields.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "\n",
11 | "sys.path.append(\"../..\") # add the pde package to the python path\n",
12 | "import pde"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "# define a simple grid\n",
22 | "grid = pde.UnitGrid([32, 32])\n",
23 | "grid.plot()"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# define scalar field, initially filled with zeros\n",
33 | "field = pde.ScalarField(grid)\n",
34 | "field.average"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "# do computations on the field\n",
44 | "field += 1\n",
45 | "field.average"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "# define a scalar field initialized with random colored noise and plot it\n",
55 | "scalar = pde.ScalarField.random_normal(grid, correlation=\"power law\", exponent=-2)\n",
56 | "scalar.plot(colorbar=True);"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "# apply operators to the field\n",
66 | "smoothed = scalar.smooth(1)\n",
67 | "laplace = smoothed.laplace(bc=\"auto_periodic_neumann\")\n",
68 | "laplace.plot(colorbar=True);"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "# initialize a vector field and plot it\n",
78 | "vector = pde.VectorField.random_normal(grid, correlation=\"power law\", exponent=-4)\n",
79 | "vector.plot(method=\"streamplot\");"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# plot the first component of the vector field\n",
89 | "vector[0].plot()"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": []
98 | }
99 | ],
100 | "metadata": {
101 | "kernelspec": {
102 | "display_name": "Python 3 (ipykernel)",
103 | "language": "python",
104 | "name": "python3"
105 | },
106 | "language_info": {
107 | "codemirror_mode": {
108 | "name": "ipython",
109 | "version": 3
110 | },
111 | "file_extension": ".py",
112 | "mimetype": "text/x-python",
113 | "name": "python",
114 | "nbconvert_exporter": "python",
115 | "pygments_lexer": "ipython3",
116 | "version": "3.12.7"
117 | },
118 | "toc": {
119 | "base_numbering": 1,
120 | "nav_menu": {},
121 | "number_sections": true,
122 | "sideBar": true,
123 | "skip_h1_title": false,
124 | "title_cell": "Table of Contents",
125 | "title_sidebar": "Contents",
126 | "toc_cell": false,
127 | "toc_position": {},
128 | "toc_section_display": true,
129 | "toc_window_display": false
130 | }
131 | },
132 | "nbformat": 4,
133 | "nbformat_minor": 4
134 | }
135 |
--------------------------------------------------------------------------------
/examples/output/README.txt:
--------------------------------------------------------------------------------
1 | Output and analysis
2 | -------------------
3 |
4 | These examples demonstrate how to store and analyze output.
--------------------------------------------------------------------------------
/examples/output/logarithmic_kymograph.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Logarithmic kymograph
3 | =====================
4 |
5 | This example demonstrates a space-time plot with a logarithmic time axis, which is useful
6 | to analyze coarsening processes. Here, we use :func:`utilitiez.densityplot` for plotting.
7 | """
8 |
9 | import matplotlib.pyplot as plt
10 | from utilitiez import densityplot
11 |
12 | import pde
13 |
14 | # define grid, initial field, and the PDE
15 | grid = pde.UnitGrid([128])
16 | field = pde.ScalarField.random_uniform(grid, -0.1, 0.1)
17 | eq = pde.CahnHilliardPDE(interface_width=2)
18 |
19 | # run the simulation and store data in logarithmically spaced time intervals
20 | storage = pde.MemoryStorage()
21 | res = eq.solve(
22 | field, t_range=1e5, adaptive=True, tracker=[storage.tracker("geometric(10, 1.1)")]
23 | )
24 |
25 | # create the density plot, which detects the logarithmically scaled time
26 | densityplot(storage.data, storage.times, grid.axes_coords[0])
27 | plt.xlabel("Time")
28 | plt.ylabel("Space")
29 |
--------------------------------------------------------------------------------
/examples/output/make_movie_live.py:
--------------------------------------------------------------------------------
1 | """
2 | Create a movie live
3 | ===================
4 |
5 | This example shows how to create a movie while running the simulation. Making movies
6 | requires that `ffmpeg` is installed in a standard location.
7 | """
8 |
9 | from pde import DiffusionPDE, PlotTracker, ScalarField, UnitGrid
10 |
11 | grid = UnitGrid([16, 16]) # generate grid
12 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition
13 |
14 | tracker = PlotTracker(movie="diffusion.mov") # create movie tracker
15 |
16 | eq = DiffusionPDE() # define the physics
17 | eq.solve(state, t_range=2, dt=0.005, tracker=tracker)
18 |
--------------------------------------------------------------------------------
/examples/output/make_movie_storage.py:
--------------------------------------------------------------------------------
1 | """
2 | Create a movie from a storage
3 | =============================
4 |
5 | This example shows how to create a movie from data stored during a simulation. Making
6 | movies requires that `ffmpeg` is installed in a standard location.
7 | """
8 |
9 | from pde import DiffusionPDE, MemoryStorage, ScalarField, UnitGrid, movie_scalar
10 |
11 | grid = UnitGrid([16, 16]) # generate grid
12 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition
13 |
14 | storage = MemoryStorage() # create storage
15 | tracker = storage.tracker(interrupts=1) # create associated tracker
16 |
17 | eq = DiffusionPDE() # define the physics
18 | eq.solve(state, t_range=2, dt=0.005, tracker=tracker)
19 |
20 | # create movie from stored data
21 | movie_scalar(storage, "diffusion.mov")
22 |
--------------------------------------------------------------------------------
/examples/output/py_modelrunner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -m modelrunner.run --output data.hdf5 --method foreground
2 | """
3 | Using :mod:`modelrunner`
4 | ========================
5 |
6 | This example shows how `py-pde` can be combined with :mod:`modelrunner`. The magic first
7 | line allows running this example as a script using :code:`./py_modelrunner.py`, which
8 | runs the function defined below and stores all results in the file `data.hdf5`.
9 |
10 | The results can be read by the following code
11 |
12 | .. code-block:: python
13 |
14 | from modelrunner import Result
15 |
16 | r = Result.from_file("data.hdf5")
17 | r.result.plot() # plots the final state
18 | r.storage["trajectory"] # allows accessing the stored trajectory
19 | """
20 |
21 | from pde import DiffusionPDE, ModelrunnerStorage, ScalarField, UnitGrid
22 |
23 |
24 | def run(storage, diffusivity=0.1):
25 | """Function that runs the model.
26 |
27 | Args:
28 | storage (:mod:`~modelrunner.storage.group.StorageGroup`):
29 | Automatically supplied storage, to which extra data can be written
30 | diffusivity (float):
31 | Example for a parameter used in the model
32 | """
33 | # initialize the model
34 | state = ScalarField.random_uniform(UnitGrid([64, 64]), 0.2, 0.3)
35 | storage["initia_state"] = state # store initial state with simulation
36 | eq = DiffusionPDE(diffusivity=diffusivity)
37 |
38 | # store trajectory in storage
39 | tracker = ModelrunnerStorage(storage, loc="trajectory").tracker(1)
40 | final_state = eq.solve(state, t_range=5, tracker=tracker)
41 |
42 |     # return the final state as the result, which will be stored by modelrunner
43 | return final_state
44 |
--------------------------------------------------------------------------------
/examples/output/storages.py:
--------------------------------------------------------------------------------
1 | """
2 | Storage examples
3 | ================
4 |
5 | This example shows how to use :mod:`~pde.storage` to store data persistently.
6 | """
7 |
8 | from pde import AllenCahnPDE, FileStorage, MovieStorage, ScalarField, UnitGrid
9 |
10 | # initialize the model
11 | state = ScalarField.random_uniform(UnitGrid([128, 128]), -0.01, 0.01)
12 | eq = AllenCahnPDE()
13 |
14 | # initialize empty storages
15 | file_write = FileStorage("allen_cahn.hdf")
16 | movie_write = MovieStorage("allen_cahn.avi", vmin=-1, vmax=1)
17 |
18 | # store trajectory in storage
19 | final_state = eq.solve(
20 | state,
21 | t_range=100,
22 | adaptive=True,
23 | tracker=[file_write.tracker(2), movie_write.tracker(1)],
24 | )
25 |
26 | # read storage and plot last frame
27 | movie_read = MovieStorage("allen_cahn.avi")
28 | movie_read[-1].plot()
29 |
--------------------------------------------------------------------------------
/examples/output/tracker_interactive.py:
--------------------------------------------------------------------------------
1 | """
2 | Using an interactive tracker
3 | ============================
4 |
5 | This example illustrates how a simulation can be analyzed live using the
6 | `napari <https://napari.org>`_ viewer.
7 | """
8 |
9 | import pde
10 |
11 |
12 | def main():
13 | grid = pde.UnitGrid([64, 64])
14 | field = pde.ScalarField.random_uniform(grid, label="Density")
15 |
16 | eq = pde.CahnHilliardPDE()
17 | eq.solve(field, t_range=1e3, dt=1e-3, tracker=["progress", "interactive"])
18 |
19 |
20 | if __name__ == "__main__":
21 | # this safeguard is required since the interactive tracker uses multiprocessing
22 | main()
23 |
--------------------------------------------------------------------------------
/examples/output/trackers.py:
--------------------------------------------------------------------------------
1 | """
2 | Using simulation trackers
3 | =========================
4 |
5 | This example illustrates how trackers can be used to analyze simulations.
6 | """
7 |
8 | import pde
9 |
10 | grid = pde.UnitGrid([32, 32]) # generate grid
11 | state = pde.ScalarField.random_uniform(grid) # generate initial condition
12 |
13 | storage = pde.MemoryStorage()
14 |
15 | trackers = [
16 | "progress", # show progress bar during simulation
17 | "steady_state", # abort when steady state is reached
18 | storage.tracker(interrupts=1), # store data every simulation time unit
19 | pde.PlotTracker(show=True), # show images during simulation
20 | # print some output every 5 real seconds:
21 | pde.PrintTracker(interrupts=pde.RealtimeInterrupts(duration=5)),
22 | ]
23 |
24 | eq = pde.DiffusionPDE(0.1) # define the PDE
25 | eq.solve(state, 3, dt=0.1, tracker=trackers)
26 |
27 | for field in storage:
28 | print(field.integral)
29 |
--------------------------------------------------------------------------------
/examples/output/trajectory_io.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Writing and reading trajectory data
3 | ===================================
4 |
5 | This example illustrates how to store intermediate data to a file for later
6 | post-processing. The storage frequency is an argument to the tracker.
7 | """
8 |
9 | from tempfile import NamedTemporaryFile
10 |
11 | import pde
12 |
13 | # define grid, state and pde
14 | grid = pde.UnitGrid([32])
15 | state = pde.FieldCollection(
16 | [pde.ScalarField.random_uniform(grid), pde.VectorField.random_uniform(grid)]
17 | )
18 | eq = pde.PDE({"s": "-0.1 * s", "v": "-v"})
19 |
20 | # get a temporary file to write data to
21 | with NamedTemporaryFile(suffix=".hdf5") as path:
22 | # run a simulation and write the results
23 | writer = pde.FileStorage(path.name, write_mode="truncate")
24 | eq.solve(state, t_range=32, dt=0.01, tracker=writer.tracker(1))
25 |
26 | # read the simulation back in again
27 | reader = pde.FileStorage(path.name, write_mode="read_only")
28 | pde.plot_kymographs(reader)
29 |
--------------------------------------------------------------------------------
/examples/simple_pdes/README.txt:
--------------------------------------------------------------------------------
1 | Simple PDEs
2 | -----------
3 |
4 | These examples demonstrate basic usage of the package to solve PDEs.
--------------------------------------------------------------------------------
/examples/simple_pdes/boundary_conditions.py:
--------------------------------------------------------------------------------
1 | """
2 | Setting boundary conditions
3 | ===========================
4 |
5 | This example shows how different boundary conditions can be specified.
6 | """
7 |
8 | from pde import DiffusionPDE, ScalarField, UnitGrid
9 |
10 | grid = UnitGrid([32, 32], periodic=[False, True]) # generate grid
11 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition
12 |
13 | # set boundary conditions `bc` for all axes
14 | eq = DiffusionPDE(
15 | bc={"x-": {"derivative": 0.1}, "x+": {"value": "sin(y / 2)"}, "y": "periodic"}
16 | )
17 |
18 | result = eq.solve(state, t_range=10, dt=0.005)
19 | result.plot()
20 |
--------------------------------------------------------------------------------
/examples/simple_pdes/cartesian_grid.py:
--------------------------------------------------------------------------------
1 | """
2 | Diffusion on a Cartesian grid
3 | =============================
4 |
5 | This example shows how to solve the diffusion equation on a Cartesian grid.
6 | """
7 |
8 | from pde import CartesianGrid, DiffusionPDE, ScalarField
9 |
10 | grid = CartesianGrid([[-1, 1], [0, 2]], [30, 16]) # generate grid
11 | state = ScalarField(grid) # generate initial condition
12 | state.insert([0, 1], 1)
13 |
14 | eq = DiffusionPDE(0.1) # define the pde
15 | result = eq.solve(state, t_range=1, dt=0.01)
16 | result.plot(cmap="magma")
17 |
--------------------------------------------------------------------------------
/examples/simple_pdes/laplace_eq_2d.py:
--------------------------------------------------------------------------------
1 | """
2 | Solving Laplace's equation in 2d
3 | ================================
4 |
5 | This example shows how to solve a 2d Laplace equation with spatially varying
6 | boundary conditions.
7 | """
8 |
9 | import numpy as np
10 |
11 | from pde import CartesianGrid, solve_laplace_equation
12 |
13 | grid = CartesianGrid([[0, 2 * np.pi], [0, 2 * np.pi]], 64)
14 | bcs = {"x": {"value": "sin(y)"}, "y": {"value": "sin(x)"}}
15 |
16 | res = solve_laplace_equation(grid, bcs)
17 | res.plot()
18 |
--------------------------------------------------------------------------------
/examples/simple_pdes/pde_1d_expression.py:
--------------------------------------------------------------------------------
1 | r"""
2 | 1D problem - Using `PDE` class
3 | ==============================
4 |
5 | This example implements a PDE that is only defined in one dimension.
6 | Here, we chose the `Korteweg-de Vries equation
7 | <https://en.wikipedia.org/wiki/Korteweg-de_Vries_equation>`_, given by
8 |
9 | .. math::
10 | \partial_t \phi = 6 \phi \partial_x \phi - \partial_x^3 \phi
11 |
12 | which we implement using the :class:`~pde.pdes.pde.PDE` class.
13 | """
14 |
15 | from math import pi
16 |
17 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph
18 |
19 | # initialize the equation and the space
20 | eq = PDE({"φ": "6 * φ * d_dx(φ) - laplace(d_dx(φ))"})
21 | grid = CartesianGrid([[0, 2 * pi]], [32], periodic=True)
22 | state = ScalarField.from_expression(grid, "sin(x)")
23 |
24 | # solve the equation and store the trajectory
25 | storage = MemoryStorage()
26 | eq.solve(state, t_range=3, solver="scipy", tracker=storage.tracker(0.1))
27 |
28 | # plot the trajectory as a space-time plot
29 | plot_kymograph(storage)
30 |
--------------------------------------------------------------------------------
/examples/simple_pdes/pde_brusselator_expression.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Brusselator - Using the `PDE` class
3 | ===================================
4 |
5 | This example uses the :class:`~pde.pdes.pde.PDE` class to implement the
6 | `Brusselator <https://en.wikipedia.org/wiki/Brusselator>`_ with spatial
7 | coupling,
8 |
9 | .. math::
10 |
11 | \partial_t u &= D_0 \nabla^2 u + a - (1 + b) u + v u^2 \\
12 | \partial_t v &= D_1 \nabla^2 v + b u - v u^2
13 |
14 | Here, :math:`D_0` and :math:`D_1` are the respective diffusivities and the
15 | parameters :math:`a` and :math:`b` are related to reaction rates.
16 |
17 | Note that the same result can also be achieved with a
18 | :doc:`full implementation of a custom class <../advanced_pdes/pde_brusselator_class>`,
19 | which allows for more flexibility at the cost of code complexity.
20 | """
21 |
22 | from pde import PDE, FieldCollection, PlotTracker, ScalarField, UnitGrid
23 |
24 | # define the PDE
25 | a, b = 1, 3
26 | d0, d1 = 1, 0.1
27 | eq = PDE(
28 | {
29 | "u": f"{d0} * laplace(u) + {a} - ({b} + 1) * u + u**2 * v",
30 | "v": f"{d1} * laplace(v) + {b} * u - u**2 * v",
31 | }
32 | )
33 |
34 | # initialize state
35 | grid = UnitGrid([64, 64])
36 | u = ScalarField(grid, a, label="Field $u$")
37 | v = b / a + 0.1 * ScalarField.random_normal(grid, label="Field $v$")
38 | state = FieldCollection([u, v])
39 |
40 | # simulate the pde
41 | tracker = PlotTracker(interrupts=1, plot_args={"vmin": 0, "vmax": 5})
42 | sol = eq.solve(state, t_range=20, dt=1e-3, tracker=tracker)
43 |
--------------------------------------------------------------------------------
/examples/simple_pdes/pde_custom_expression.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Kuramoto-Sivashinsky - Using `PDE` class
3 | ========================================
4 |
5 | This example implements a scalar PDE using the :class:`~pde.pdes.pde.PDE` class. We here
6 | consider the `Kuramoto–Sivashinsky equation
7 | <https://en.wikipedia.org/wiki/Kuramoto-Sivashinsky_equation>`_, which for instance
8 | describes the dynamics of flame fronts:
9 |
10 | .. math::
11 | \partial_t u = -\frac12 |\nabla u|^2 - \nabla^2 u - \nabla^4 u
12 | """
13 |
14 | from pde import PDE, ScalarField, UnitGrid
15 |
16 | grid = UnitGrid([32, 32]) # generate grid
17 | state = ScalarField.random_uniform(grid) # generate initial condition
18 |
19 | eq = PDE({"u": "-gradient_squared(u) / 2 - laplace(u + laplace(u))"}) # define the pde
20 | result = eq.solve(state, t_range=10, dt=0.01)
21 | result.plot()
22 |
--------------------------------------------------------------------------------
/examples/simple_pdes/pde_heterogeneous_diffusion.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Diffusion equation with spatial dependence
3 | ==========================================
4 |
5 | This example solves the
6 | `Diffusion equation <https://en.wikipedia.org/wiki/Diffusion_equation>`_ with a
7 | heterogeneous diffusivity:
8 |
9 | .. math::
10 |     \partial_t c = \nabla \cdot \bigl( D(\boldsymbol r) \nabla c \bigr)
11 |
12 | using the :class:`~pde.pdes.pde.PDE` class. In particular, we consider
13 | :math:`D(x) = 1.01 + \tanh(x)`, which gives a low diffusivity on the left side of the
14 | domain.
15 |
16 | Note that the naive implementation,
17 | :code:`PDE({"c": "divergence((1.01 + tanh(x)) * gradient(c))"})`, has numerical
18 | instabilities. This is because two finite difference approximations are nested. To
19 | arrive at a more stable numerical scheme, it is advisable to expand the divergence,
20 |
21 | .. math::
22 |     \partial_t c = D \nabla^2 c + \nabla D \cdot \nabla c
23 | """
24 |
25 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph
26 |
27 | # Expanded definition of the PDE
28 | diffusivity = "1.01 + tanh(x)"
29 | term_1 = f"({diffusivity}) * laplace(c)"
30 | term_2 = f"dot(gradient({diffusivity}), gradient(c))"
31 | eq = PDE({"c": f"{term_1} + {term_2}"}, bc={"value": 0})
32 |
33 |
34 | grid = CartesianGrid([[-5, 5]], 64) # generate grid
35 | field = ScalarField(grid, 1) # generate initial condition
36 |
37 | storage = MemoryStorage() # store intermediate information of the simulation
38 | res = eq.solve(field, 100, dt=1e-3, tracker=storage.tracker(1)) # solve the PDE
39 |
40 | plot_kymograph(storage) # visualize the result in a space-time plot
41 |
--------------------------------------------------------------------------------
/examples/simple_pdes/pde_schroedinger.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Schrödinger's Equation
3 | ======================
4 |
5 | This example implements a complex PDE using the :class:`~pde.pdes.pde.PDE` class. We
6 | here chose the `Schrödinger equation <https://en.wikipedia.org/wiki/Schrödinger_equation>`_
7 | without a spatial potential in non-dimensional form:
8 |
9 | .. math::
10 | i \partial_t \psi = -\nabla^2 \psi
11 |
12 | Note that the example imposes Neumann conditions at the wall, so the wave packet is
13 | expected to reflect off the wall.
14 | """
15 |
16 | from math import sqrt
17 |
18 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph
19 |
20 | grid = CartesianGrid([[0, 20]], 128, periodic=False) # generate grid
21 |
22 | # create a (normalized) wave packet with a certain form as an initial condition
23 | initial_state = ScalarField.from_expression(grid, "exp(I * 5 * x) * exp(-(x - 10)**2)")
24 | initial_state /= sqrt(initial_state.to_scalar("norm_squared").integral.real)
25 |
26 | eq = PDE({"ψ": "I * laplace(ψ)"}) # define the pde
27 |
28 | # solve the pde and store intermediate data
29 | storage = MemoryStorage()
30 | eq.solve(initial_state, t_range=2.5, dt=1e-5, tracker=[storage.tracker(0.02)])
31 |
32 | # visualize the results as a space-time plot
33 | plot_kymograph(storage, scalar="norm_squared")
34 |
--------------------------------------------------------------------------------
/examples/simple_pdes/poisson_eq_1d.py:
--------------------------------------------------------------------------------
1 | """
2 | Solving Poisson's equation in 1d
3 | ================================
4 |
5 | This example shows how to solve a 1d Poisson equation with boundary conditions.
6 | """
7 |
8 | from pde import CartesianGrid, ScalarField, solve_poisson_equation
9 |
10 | grid = CartesianGrid([[0, 1]], 32, periodic=False)
11 | field = ScalarField(grid, 1)
12 | result = solve_poisson_equation(field, bc={"x-": {"value": 0}, "x+": {"derivative": 1}})
13 |
14 | result.plot()
15 |
--------------------------------------------------------------------------------
/examples/simple_pdes/simple.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple diffusion equation
3 | =========================
4 |
5 | This example solves a simple diffusion equation in two dimensions.
6 | """
7 |
8 | from pde import DiffusionPDE, ScalarField, UnitGrid
9 |
10 | grid = UnitGrid([64, 64]) # generate grid
11 | state = ScalarField.random_uniform(grid, 0.2, 0.3) # generate initial condition
12 |
13 | eq = DiffusionPDE(diffusivity=0.1) # define the pde
14 | result = eq.solve(state, t_range=10)
15 | result.plot()
16 |
--------------------------------------------------------------------------------
/examples/simple_pdes/spherical_grid.py:
--------------------------------------------------------------------------------
1 | """
2 | Spherically symmetric PDE
3 | =========================
4 |
5 | This example illustrates how to solve a PDE in a spherically symmetric geometry.
6 | """
7 |
8 | from pde import DiffusionPDE, ScalarField, SphericalSymGrid
9 |
10 | grid = SphericalSymGrid(radius=[1, 5], shape=128) # generate grid
11 | state = ScalarField.random_uniform(grid) # generate initial condition
12 |
13 | eq = DiffusionPDE(0.1) # define the PDE
14 | result = eq.solve(state, t_range=0.1, dt=0.001)
15 |
16 | result.plot(kind="image")
17 |
--------------------------------------------------------------------------------
/examples/simple_pdes/stochastic_simulation.py:
--------------------------------------------------------------------------------
1 | """
2 | Stochastic simulation
3 | =====================
4 |
5 | This example illustrates how a stochastic simulation can be done.
6 | """
7 |
8 | from pde import KPZInterfacePDE, MemoryStorage, ScalarField, UnitGrid, plot_kymograph
9 |
10 | grid = UnitGrid([64]) # generate grid
11 | state = ScalarField.random_harmonic(grid) # generate initial condition
12 |
13 | eq = KPZInterfacePDE(noise=1) # define the SDE
14 | storage = MemoryStorage()
15 | eq.solve(state, t_range=10, dt=0.01, tracker=storage.tracker(0.5))
16 | plot_kymograph(storage)
17 |
--------------------------------------------------------------------------------
/examples/simple_pdes/time_dependent_bcs.py:
--------------------------------------------------------------------------------
1 | """
2 | Time-dependent boundary conditions
3 | ==================================
4 |
5 | This example solves a simple diffusion equation in one dimension with time-dependent
6 | boundary conditions.
7 | """
8 |
9 | from pde import PDE, CartesianGrid, MemoryStorage, ScalarField, plot_kymograph
10 |
11 | grid = CartesianGrid([[0, 10]], [64]) # generate grid
12 | state = ScalarField(grid) # generate initial condition
13 |
14 | eq = PDE({"c": "laplace(c)"}, bc={"value_expression": "sin(t)"})
15 |
16 | storage = MemoryStorage()
17 | eq.solve(state, t_range=20, dt=1e-4, tracker=storage.tracker(0.1))
18 |
19 | # plot the trajectory as a space-time plot
20 | plot_kymograph(storage)
21 |
--------------------------------------------------------------------------------
/pde/__init__.py:
--------------------------------------------------------------------------------
1 | """The py-pde package provides classes and methods for solving partial differential
2 | equations."""
3 |
4 | # determine the package version
5 | try:
6 | # try reading version of the automatically generated module
7 | from ._version import __version__
8 | except ImportError:
9 |     # determine version automatically from VCS information
10 | from importlib.metadata import PackageNotFoundError, version
11 |
12 | try:
13 | __version__ = version("pde")
14 | except PackageNotFoundError:
15 | # package is not installed, so we cannot determine any version
16 | __version__ = "unknown"
17 | del PackageNotFoundError, version # clean name space
18 |
19 | # initialize the configuration
20 | from .tools.config import Config, environment
21 |
22 | config = Config() # initialize the default configuration
23 |
24 | import contextlib
25 |
26 | # import all other modules that should occupy the main name space
27 | from .fields import *
28 | from .grids import *
29 | from .pdes import *
30 | from .solvers import *
31 | from .storage import *
32 | from .tools.parameters import Parameter
33 | from .trackers import *
34 | from .visualization import *
35 |
36 | with contextlib.suppress(ImportError):
37 | from .tools.modelrunner import *
38 |
39 | del contextlib, Config # clean name space
40 |
--------------------------------------------------------------------------------
/pde/fields/__init__.py:
--------------------------------------------------------------------------------
1 | """Defines fields, which contain the actual data stored on a discrete grid.
2 |
3 | .. autosummary::
4 | :nosignatures:
5 |
6 | ~scalar.ScalarField
7 | ~vectorial.VectorField
8 | ~tensorial.Tensor2Field
9 | ~collection.FieldCollection
10 |
11 |
12 | Inheritance structure of the classes:
13 |
14 |
15 | .. inheritance-diagram:: scalar.ScalarField vectorial.VectorField tensorial.Tensor2Field
16 | collection.FieldCollection
17 | :parts: 1
18 |
19 | The details of the classes are explained below:
20 |
21 | .. codeauthor:: David Zwicker
22 | """
23 |
24 | from .base import FieldBase
25 | from .collection import FieldCollection
26 | from .datafield_base import DataFieldBase
27 | from .scalar import ScalarField
28 | from .tensorial import Tensor2Field
29 | from .vectorial import VectorField
30 |
31 | # DataFieldBase has been moved to its own module on 2024-04-18.
32 | # Add it back to `base` for the time being, so dependent code doesn't break
33 | from . import base # isort:skip
34 |
35 | base.DataFieldBase = DataFieldBase # type: ignore
36 | del base # clean namespaces
37 |
--------------------------------------------------------------------------------
/pde/grids/__init__.py:
--------------------------------------------------------------------------------
1 | """Grids define the domains on which PDEs will be solved. In particular, symmetries,
2 | periodicities, and the discretizations are defined by the underlying grid.
3 |
4 | We only consider regular, orthogonal grids, which are constructed from orthogonal
5 | coordinate systems with equidistant discretizations along each axis. The dimension of
6 | the space that the grid describes is given by the attribute :attr:`dim`. Cartesian
7 | coordinates can be mapped to grid coordinates and the corresponding discretization cells
8 | using the method :meth:`transform`.
9 |
10 | .. autosummary::
11 | :nosignatures:
12 |
13 | ~cartesian.UnitGrid
14 | ~cartesian.CartesianGrid
15 | ~spherical.PolarSymGrid
16 | ~spherical.SphericalSymGrid
17 | ~cylindrical.CylindricalSymGrid
18 |
19 | Inheritance structure of the classes:
20 |
21 | .. inheritance-diagram:: cartesian.UnitGrid cartesian.CartesianGrid
22 | spherical.PolarSymGrid spherical.SphericalSymGrid cylindrical.CylindricalSymGrid
23 | :parts: 1
24 |
25 | .. codeauthor:: David Zwicker
26 | """
27 |
28 | from . import operators # import all operator modules to register them
29 | from .base import registered_operators
30 | from .boundaries import *
31 | from .cartesian import CartesianGrid, UnitGrid
32 | from .cylindrical import CylindricalSymGrid
33 | from .spherical import PolarSymGrid, SphericalSymGrid
34 |
35 | del operators # remove the name from the namespace
36 |
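As a minimal sketch of the :meth:`transform` method mentioned in the docstring above (the keyword names ``source`` and ``target`` are assumptions about the exact signature):

.. code-block:: python

    # hedged sketch: map a Cartesian point to grid coordinates and its cell
    from pde import CartesianGrid

    grid = CartesianGrid([[0, 2], [0, 1]], [8, 4])
    point = [1.0, 0.5]
    print(grid.transform(point, source="cartesian", target="grid"))
    print(grid.transform(point, source="cartesian", target="cell"))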
--------------------------------------------------------------------------------
/pde/grids/coordinates/__init__.py:
--------------------------------------------------------------------------------
1 | """Package collecting classes representing orthonormal coordinate systems.
2 |
3 | .. autosummary::
4 | :nosignatures:
5 |
6 | ~bipolar.BipolarCoordinates
7 | ~bispherical.BisphericalCoordinates
8 | ~cartesian.CartesianCoordinates
9 | ~cylindrical.CylindricalCoordinates
10 | ~polar.PolarCoordinates
11 | ~spherical.SphericalCoordinates
12 | """
13 |
14 | from .base import CoordinatesBase, DimensionError
15 | from .bipolar import BipolarCoordinates
16 | from .bispherical import BisphericalCoordinates
17 | from .cartesian import CartesianCoordinates
18 | from .cylindrical import CylindricalCoordinates
19 | from .polar import PolarCoordinates
20 | from .spherical import SphericalCoordinates
21 |
--------------------------------------------------------------------------------
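Example: a round-trip sketch with polar coordinates; it assumes the private
`_pos_to_cart`/`_pos_from_cart` methods defined in the modules below have public
counterparts `pos_to_cart`/`pos_from_cart` on :class:`~base.CoordinatesBase`:

    import numpy as np
    from pde.grids.coordinates import PolarCoordinates

    c = PolarCoordinates()
    p = np.array([2.0, np.pi / 2])                     # (r, φ)
    x = c.pos_to_cart(p)                               # approximately (0, 2)
    np.testing.assert_allclose(c.pos_from_cart(x), p)  # round trip recovers (r, φ)
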
/pde/grids/coordinates/bipolar.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import numpy as np
8 | from numpy.typing import ArrayLike
9 |
10 | from .base import CoordinatesBase
11 |
12 |
13 | class BipolarCoordinates(CoordinatesBase):
14 | """2-dimensional bipolar coordinates."""
15 |
16 | dim = 2
17 | axes = ["σ", "τ"]
18 | _axes_alt = {"σ": ["sigma"], "τ": ["tau"]}
19 | coordinate_limits = [(0, 2 * np.pi), (-np.inf, np.inf)]
20 |
21 | def __init__(self, scale_parameter: float = 1):
22 | super().__init__()
23 | if scale_parameter <= 0:
24 | raise ValueError("Scale parameter must be positive")
25 | self.scale_parameter = scale_parameter
26 |
27 | def __repr__(self) -> str:
28 | """Return instance as string."""
29 | return f"{self.__class__.__name__}(scale_parameter={self.scale_parameter})"
30 |
31 | def __eq__(self, other):
32 | return (
33 | self.__class__ is other.__class__
34 | and self.scale_parameter == other.scale_parameter
35 | )
36 |
37 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray:
38 | σ, τ = points[..., 0], points[..., 1]
39 | denom = np.cosh(τ) - np.cos(σ)
40 | x = self.scale_parameter * np.sinh(τ) / denom
41 | y = self.scale_parameter * np.sin(σ) / denom
42 | return np.stack((x, y), axis=-1) # type: ignore
43 |
44 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray:
45 | x, y = points[..., 0], points[..., 1]
46 | a = self.scale_parameter
47 | h2 = x**2 + y**2
48 | denom = a**2 - h2 + np.sqrt((a**2 - h2) ** 2 + 4 * a**2 * y**2)
49 | σ = np.mod(np.pi - 2 * np.arctan2(2 * a * y, denom), 2 * np.pi)
50 | τ = 0.5 * np.log(((x + a) ** 2 + y**2) / ((x - a) ** 2 + y**2))
51 | return np.stack((σ, τ), axis=-1) # type: ignore
52 |
53 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray:
54 | σ, τ = points[..., 0], points[..., 1]
55 |
56 | sinσ = np.sin(σ)
57 | cosσ = np.cos(σ)
58 | sinhτ = np.sinh(τ)
59 | coshτ = np.cosh(τ)
60 | factor = self.scale_parameter * (cosσ - coshτ) ** -2
61 |
62 | return factor * np.array( # type: ignore
63 | [
64 | [-sinσ * sinhτ, 1 - cosσ * coshτ],
65 | [cosσ * coshτ - 1, -sinσ * sinhτ],
66 | ]
67 | )
68 |
69 | def _volume_factor(self, points: np.ndarray) -> ArrayLike:
70 | σ, τ = points[..., 0], points[..., 1]
71 | return self.scale_parameter**2 * (np.cosh(τ) - np.cos(σ)) ** -2
72 |
73 | def _scale_factors(self, points: np.ndarray) -> np.ndarray:
74 | σ, τ = points[..., 0], points[..., 1]
75 | sf = self.scale_parameter / (np.cosh(τ) - np.cos(σ))
76 | return np.array([sf, sf]) # type: ignore
77 |
78 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray:
79 | σ, τ = points[..., 0], points[..., 1]
80 |
81 | sinσ = np.sin(σ)
82 | cosσ = np.cos(σ)
83 | sinhτ = np.sinh(τ)
84 | coshτ = np.cosh(τ)
85 | factor = 1 / (cosσ - coshτ)
86 |
87 | return factor * np.array( # type: ignore
88 | [
89 | [sinσ * sinhτ, 1 - cosσ * coshτ],
90 | [cosσ * coshτ - 1, sinσ * sinhτ],
91 | ]
92 | )
93 |
--------------------------------------------------------------------------------
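The forward and inverse maps above are easy to get wrong, so a quick numerical
consistency check (reusing the private helpers shown above verbatim, with an
arbitrary scale parameter) is reassuring:

    import numpy as np
    from pde.grids.coordinates import BipolarCoordinates

    c = BipolarCoordinates(scale_parameter=1.5)
    pts = np.array([[1.0, 0.3], [4.0, -2.0]])    # (σ, τ) pairs within the limits
    cart = c._pos_to_cart(pts)                   # map to Cartesian coordinates
    np.testing.assert_allclose(c._pos_from_cart(cart), pts)  # ... and back again
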
/pde/grids/coordinates/cartesian.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import numpy as np
8 | from numpy.typing import ArrayLike
9 |
10 | from .base import CoordinatesBase
11 |
12 |
13 | class CartesianCoordinates(CoordinatesBase):
14 | """N-dimensional Cartesian coordinates."""
15 |
16 | _objs: dict[int, CartesianCoordinates] = {}
17 |
18 | def __new__(cls, dim: int):
19 | # cache the instances for each dimension
20 | if dim not in cls._objs:
21 | cls._objs[dim] = super().__new__(cls)
22 | return cls._objs[dim]
23 |
24 | def __getnewargs__(self):
25 | return (self.dim,)
26 |
27 | def __init__(self, dim: int):
28 | """
29 | Args:
30 | dim (int):
31 | Dimension of the Cartesian coordinate system
32 | """
33 | if dim <= 0:
34 | raise ValueError("`dim` must be a positive integer")
35 | self.dim = dim
36 | if self.dim <= 3:
37 | self.axes = list("xyz"[: self.dim])
38 | else:
39 | self.axes = [chr(97 + i) for i in range(self.dim)]
40 | self.coordinate_limits = [(-np.inf, np.inf)] * self.dim
41 |
42 | def __repr__(self) -> str:
43 | """Return instance as string."""
44 | return f"{self.__class__.__name__}(dim={self.dim})"
45 |
46 | def __eq__(self, other):
47 | return self.__class__ is other.__class__ and self.dim == other.dim
48 |
49 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray:
50 | return points
51 |
52 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray:
53 | return points
54 |
55 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray:
56 | jac = np.zeros((self.dim, self.dim) + points.shape[:-1])
57 | jac[range(self.dim), range(self.dim)] = 1
58 | return jac # type: ignore
59 |
60 | def _volume_factor(self, points: np.ndarray) -> ArrayLike:
61 | return np.ones(points.shape[:-1])
62 |
63 | def _cell_volume(self, c_low: np.ndarray, c_high: np.ndarray):
64 | return np.prod(c_high - c_low, axis=-1)
65 |
66 | def _scale_factors(self, points: np.ndarray) -> np.ndarray:
67 | return np.ones_like(points) # type: ignore
68 |
69 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray:
70 | return np.eye(self.dim) # type: ignore
71 |
--------------------------------------------------------------------------------
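The `__new__` caching above means that requesting the same dimension twice returns the
identical object, so identity checks are cheap:

    from pde.grids.coordinates import CartesianCoordinates

    c_a = CartesianCoordinates(2)
    c_b = CartesianCoordinates(2)
    assert c_a is c_b                           # cached: one instance per dimension
    assert CartesianCoordinates(3) is not c_a   # another dimension, another object
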
/pde/grids/coordinates/cylindrical.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import numpy as np
8 | from numpy.typing import ArrayLike
9 |
10 | from .base import CoordinatesBase
11 |
12 |
13 | class CylindricalCoordinates(CoordinatesBase):
14 | """3-dimensional cylindrical coordinates."""
15 |
16 | _singleton: CylindricalCoordinates | None = None
17 | dim = 3
18 | coordinate_limits = [(0, np.inf), (0, 2 * np.pi), (-np.inf, np.inf)]
19 | axes = ["r", "φ", "z"]
20 | _axes_alt = {"φ": ["phi"]}
21 |
22 | def __new__(cls):
23 | # cache the singleton instance
24 | if cls._singleton is None:
25 | cls._singleton = super().__new__(cls)
26 | return cls._singleton
27 |
28 | def __eq__(self, other):
29 | return self.__class__ is other.__class__
30 |
31 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray:
32 | r, φ, z = points[..., 0], points[..., 1], points[..., 2]
33 | x = r * np.cos(φ)
34 | y = r * np.sin(φ)
35 | return np.stack((x, y, z), axis=-1) # type: ignore
36 |
37 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray:
38 | x, y, z = points[..., 0], points[..., 1], points[..., 2]
39 | r = np.hypot(x, y)
40 | φ = np.arctan2(y, x)
41 | return np.stack((r, φ, z), axis=-1) # type: ignore
42 |
43 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray:
44 | r, φ = points[..., 0], points[..., 1]
45 | sinφ, cosφ = np.sin(φ), np.cos(φ)
46 | zero = np.zeros_like(r)
47 | return np.array( # type: ignore
48 | [
49 | [cosφ, -r * sinφ, zero],
50 | [sinφ, r * cosφ, zero],
51 | [zero, zero, zero + 1],
52 | ]
53 | )
54 |
55 | def _volume_factor(self, points: np.ndarray) -> ArrayLike:
56 | return points[..., 0]
57 |
58 | def _cell_volume(self, c_low: np.ndarray, c_high: np.ndarray):
59 | r1, φ1, z1 = c_low[..., 0], c_low[..., 1], c_low[..., 2]
60 | r2, φ2, z2 = c_high[..., 0], c_high[..., 1], c_high[..., 2]
61 | return (φ2 - φ1) * (z2 - z1) * (r2**2 - r1**2) / 2
62 |
63 | def _scale_factors(self, points: np.ndarray) -> np.ndarray:
64 | r = points[..., 0]
65 | ones = np.ones_like(r)
66 | return np.array([ones, r, ones]) # type: ignore
67 |
68 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray:
69 | φ = points[..., 1]
70 | sinφ, cosφ = np.sin(φ), np.cos(φ)
71 | zero = np.zeros_like(φ)
72 | return np.array( # type: ignore
73 | [
74 | [cosφ, sinφ, zero],
75 | [-sinφ, cosφ, zero],
76 | [zero, zero, zero + 1],
77 | ]
78 | )
79 |
--------------------------------------------------------------------------------
/pde/grids/coordinates/polar.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import numpy as np
8 | from numpy.typing import ArrayLike
9 |
10 | from .base import CoordinatesBase
11 |
12 |
13 | class PolarCoordinates(CoordinatesBase):
14 | """2-dimensional polar coordinates."""
15 |
16 | dim = 2
17 | axes = ["r", "φ"]
18 | _axes_alt = {"r": ["radius"], "φ": ["phi"]}
19 | coordinate_limits = [(0, np.inf), (0, 2 * np.pi)]
20 |
21 | _singleton: PolarCoordinates | None = None
22 |
23 | def __new__(cls):
24 | # cache the singleton instance
25 | if cls._singleton is None:
26 | cls._singleton = super().__new__(cls)
27 | return cls._singleton
28 |
29 | def __repr__(self) -> str:
30 | """Return instance as string."""
31 | return f"{self.__class__.__name__}()"
32 |
33 | def __eq__(self, other):
34 | return self.__class__ is other.__class__
35 |
36 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray:
37 | r, φ = points[..., 0], points[..., 1]
38 | x = r * np.cos(φ)
39 | y = r * np.sin(φ)
40 | return np.stack((x, y), axis=-1) # type: ignore
41 |
42 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray:
43 | x, y = points[..., 0], points[..., 1]
44 | r = np.hypot(x, y)
45 | φ = np.arctan2(y, x)
46 | return np.stack((r, φ), axis=-1) # type: ignore
47 |
48 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray:
49 | r, φ = points[..., 0], points[..., 1]
50 | sinφ, cosφ = np.sin(φ), np.cos(φ)
51 | return np.array([[cosφ, -r * sinφ], [sinφ, r * cosφ]]) # type: ignore
52 |
53 | def _volume_factor(self, points: np.ndarray) -> ArrayLike:
54 | return points[..., 0]
55 |
56 | def _cell_volume(self, c_low: np.ndarray, c_high: np.ndarray) -> np.ndarray:
57 | r1, φ1 = c_low[..., 0], c_low[..., 1]
58 | r2, φ2 = c_high[..., 0], c_high[..., 1]
59 | return (φ2 - φ1) * (r2**2 - r1**2) / 2 # type: ignore
60 |
61 | def _scale_factors(self, points: np.ndarray) -> np.ndarray:
62 | r = points[..., 0]
63 | return np.array([np.ones_like(r), r]) # type: ignore
64 |
65 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray:
66 | φ = points[..., 1]
67 | sinφ, cosφ = np.sin(φ), np.cos(φ)
68 | return np.array([[cosφ, sinφ], [-sinφ, cosφ]]) # type: ignore
69 |
--------------------------------------------------------------------------------
/pde/grids/coordinates/spherical.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | import numpy as np
8 | from numpy.typing import ArrayLike
9 |
10 | from .base import CoordinatesBase
11 |
12 |
13 | class SphericalCoordinates(CoordinatesBase):
14 | """3-dimensional spherical coordinates."""
15 |
16 | dim = 3
17 | axes = ["r", "θ", "φ"]
18 | _axes_alt = {"r": ["radius"], "θ": ["theta"], "φ": ["phi"]}
19 | coordinate_limits = [(0, np.inf), (0, np.pi), (0, 2 * np.pi)]
20 | major_axis = 0
21 |
22 | _singleton: SphericalCoordinates | None = None
23 |
24 | def __new__(cls):
25 | # cache the singleton instance
26 | if cls._singleton is None:
27 | cls._singleton = super().__new__(cls)
28 | return cls._singleton
29 |
30 | def __repr__(self) -> str:
31 | """Return instance as string."""
32 | return f"{self.__class__.__name__}()"
33 |
34 | def __eq__(self, other):
35 | return self.__class__ is other.__class__
36 |
37 | def _pos_to_cart(self, points: np.ndarray) -> np.ndarray:
38 | r, θ, φ = points[..., 0], points[..., 1], points[..., 2]
39 | rsinθ = r * np.sin(θ)
40 | x = rsinθ * np.cos(φ)
41 | y = rsinθ * np.sin(φ)
42 | z = r * np.cos(θ)
43 | return np.stack((x, y, z), axis=-1) # type:ignore
44 |
45 | def _pos_from_cart(self, points: np.ndarray) -> np.ndarray:
46 | x, y, z = points[..., 0], points[..., 1], points[..., 2]
47 | r = np.linalg.norm(points, axis=-1)
48 | θ = np.arctan2(np.hypot(x, y), z)
49 | φ = np.arctan2(y, x)
50 | return np.stack((r, θ, φ), axis=-1) # type:ignore
51 |
52 | def _mapping_jacobian(self, points: np.ndarray) -> np.ndarray:
53 | r, θ, φ = points[..., 0], points[..., 1], points[..., 2]
54 | sinθ, cosθ = np.sin(θ), np.cos(θ)
55 | sinφ, cosφ = np.sin(φ), np.cos(φ)
56 | return np.array( # type:ignore
57 | [
58 | [cosφ * sinθ, r * cosφ * cosθ, -r * sinφ * sinθ],
59 | [sinφ * sinθ, r * sinφ * cosθ, r * cosφ * sinθ],
60 | [cosθ, -r * sinθ, np.zeros_like(θ)],
61 | ]
62 | )
63 |
64 | def _volume_factor(self, points: np.ndarray) -> ArrayLike:
65 | r, θ = points[..., 0], points[..., 1]
66 | return r**2 * np.sin(θ)
67 |
68 | def _cell_volume(self, c_low: np.ndarray, c_high: np.ndarray):
69 | r1, θ1, φ1 = c_low[..., 0], c_low[..., 1], c_low[..., 2]
70 | r2, θ2, φ2 = c_high[..., 0], c_high[..., 1], c_high[..., 2]
71 | return (φ2 - φ1) * (np.cos(θ1) - np.cos(θ2)) * (r2**3 - r1**3) / 3
72 |
73 | def _scale_factors(self, points: np.ndarray) -> np.ndarray:
74 | r, θ = points[..., 0], points[..., 1]
75 | return np.array([np.ones_like(r), r, r * np.sin(θ)]) # type: ignore
76 |
77 | def _basis_rotation(self, points: np.ndarray) -> np.ndarray:
78 | θ, φ = points[..., 1], points[..., 2]
79 | sinθ, cosθ = np.sin(θ), np.cos(θ)
80 | sinφ, cosφ = np.sin(φ), np.cos(φ)
81 | return np.array( # type: ignore
82 | [
83 | [cosφ * sinθ, sinφ * sinθ, cosθ],
84 | [cosφ * cosθ, sinφ * cosθ, -sinθ],
85 | [-sinφ, cosφ, np.zeros_like(θ)],
86 | ]
87 | )
88 |
--------------------------------------------------------------------------------
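For orthogonal coordinates, the volume factor equals the product of the scale factors;
for the spherical implementation above this is 1 · r · r sin(θ) = r² sin(θ), which can
be spot-checked numerically with the private helpers:

    import numpy as np
    from pde.grids.coordinates import SphericalCoordinates

    c = SphericalCoordinates()
    p = np.array([2.0, 0.7, 1.2])   # (r, θ, φ)
    np.testing.assert_allclose(
        np.prod(c._scale_factors(p), axis=0),   # h_r * h_θ * h_φ
        c._volume_factor(p),                    # r² sin(θ)
    )
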
/pde/grids/operators/__init__.py:
--------------------------------------------------------------------------------
1 | """Package collecting modules defining discretized operators for different grids.
2 |
3 | These operators can be used directly, or they are imported implicitly by the
4 | respective methods defined on fields and grids.
5 |
6 | .. autosummary::
7 | :nosignatures:
8 |
9 | cartesian
10 | cylindrical_sym
11 | polar_sym
12 | spherical_sym
13 |
14 | common.make_derivative
15 | common.make_derivative2
16 | """
17 |
18 | from . import cartesian, cylindrical_sym, polar_sym, spherical_sym
19 | from .common import make_derivative, make_derivative2
20 |
--------------------------------------------------------------------------------
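Example: using `make_derivative` directly, mirroring how the package's own tests drive
it. The returned operator acts on the full data array including ghost cells, so the
boundary conditions must be applied first:

    import numpy as np

    from pde import CartesianGrid, ScalarField
    from pde.grids.operators import make_derivative

    grid = CartesianGrid([[0, 2 * np.pi]], 32, periodic=True)
    field = ScalarField.from_expression(grid, "sin(x)")
    bcs = grid.get_boundary_conditions("auto_periodic_neumann")

    diff = make_derivative(grid, axis=0, method="central")  # ∂/∂x finite difference
    out = np.empty(grid.shape)
    field.set_ghost_cells(bcs)          # fill ghost cells before applying the stencil
    diff(field._data_full, out=out)     # `out` now approximates cos(x)
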
/pde/pdes/__init__.py:
--------------------------------------------------------------------------------
1 | """Package that defines PDEs describing physical systems.
2 |
3 | The examples in this package are often simple versions of classical PDEs,
4 | meant to demonstrate various aspects of the `py-pde` package. Clearly, not all
5 | extensions to these PDEs can be covered here, but they should serve as a
6 | starting point for custom investigations.
7 |
8 | Publicly available methods should take fields with grid information and also
9 | only return such fields. There might be corresponding private methods that
10 | deal with raw data for faster simulations.
11 |
12 |
13 | .. autosummary::
14 | :nosignatures:
15 |
16 | ~pde.PDE
17 | ~allen_cahn.AllenCahnPDE
18 | ~cahn_hilliard.CahnHilliardPDE
19 | ~diffusion.DiffusionPDE
20 | ~kpz_interface.KPZInterfacePDE
21 | ~kuramoto_sivashinsky.KuramotoSivashinskyPDE
22 | ~swift_hohenberg.SwiftHohenbergPDE
23 | ~wave.WavePDE
24 |
25 | Additionally, we offer two solvers for typical elliptic PDEs:
26 |
27 |
28 | .. autosummary::
29 | :nosignatures:
30 |
31 | ~laplace.solve_laplace_equation
32 | ~laplace.solve_poisson_equation
33 |
34 |
35 | .. codeauthor:: David Zwicker
36 | """
37 |
38 | from .allen_cahn import AllenCahnPDE
39 | from .base import PDEBase
40 | from .cahn_hilliard import CahnHilliardPDE
41 | from .diffusion import DiffusionPDE
42 | from .kpz_interface import KPZInterfacePDE
43 | from .kuramoto_sivashinsky import KuramotoSivashinskyPDE
44 | from .laplace import solve_laplace_equation, solve_poisson_equation
45 | from .pde import PDE
46 | from .swift_hohenberg import SwiftHohenbergPDE
47 | from .wave import WavePDE
48 |
--------------------------------------------------------------------------------
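Example: a minimal end-to-end run with one of the predefined equations; the parameters
are illustrative, and `diffusivity` is assumed to be the constructor argument of
:class:`~diffusion.DiffusionPDE`:

    import pde

    grid = pde.UnitGrid([64, 64])
    state = pde.ScalarField.random_uniform(grid)

    eq = pde.DiffusionPDE(diffusivity=0.5)        # predefined diffusion equation
    result = eq.solve(state, t_range=10, dt=0.1)  # explicit stepping until t = 10
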
/pde/pdes/laplace.py:
--------------------------------------------------------------------------------
1 | """Solvers for Poisson's and Laplace's equation.
2 |
3 | .. codeauthor:: David Zwicker
4 | """
5 |
6 | from __future__ import annotations
7 |
8 | from ..fields import ScalarField
9 | from ..grids.base import GridBase
10 | from ..grids.boundaries.axes import BoundariesData
11 | from ..tools.docstrings import fill_in_docstring
12 |
13 |
14 | @fill_in_docstring
15 | def solve_poisson_equation(
16 | rhs: ScalarField,
17 | bc: BoundariesData,
18 | *,
19 | label: str = "Solution to Poisson's equation",
20 | **kwargs,
21 | ) -> ScalarField:
22 | r"""Solve Laplace's equation on a given grid.
23 |
24 | Denoting the current field by :math:`u`, we thus solve for :math:`f`, defined by the
25 | equation
26 |
27 | .. math::
28 | \nabla^2 u(\boldsymbol r) = -f(\boldsymbol r)
29 |
30 | with boundary conditions specified by `bc`.
31 |
32 | Note:
33 | In case of periodic or Neumann boundary conditions, the right hand side
34 | :math:`f(\boldsymbol r)` needs to satisfy the following condition
35 |
36 | .. math::
37 | \int f \, \mathrm{d}V = \oint g \, \mathrm{d}S \;,
38 |
39 | where :math:`g` denotes the function specifying the outward
40 | derivative for Neumann conditions. Note that :math:`g` vanishes
41 | for periodic boundaries, so this condition implies that the
42 | integral over :math:`f` must vanish for neutral Neumann or
43 | periodic conditions.
44 |
45 | Args:
46 | rhs (:class:`~pde.fields.scalar.ScalarField`):
47 | The scalar field :math:`f` describing the right hand side
48 | bc:
49 | The boundary conditions applied to the field.
50 | {ARG_BOUNDARIES}
51 | label (str):
52 | The label of the returned field.
53 |
54 | Returns:
55 | :class:`~pde.fields.scalar.ScalarField`: The field :math:`u` that solves
56 | the equation. This field will be defined on the same grid as `rhs`.
57 | """
58 | # get the operator information
59 | operator = rhs.grid._get_operator_info("poisson_solver")
60 | # get the boundary conditions
61 | bcs = rhs.grid.get_boundary_conditions(bc)
62 | # get the actual solver
63 | solver = operator.factory(bcs=bcs, **kwargs)
64 |
65 | # solve the poisson problem
66 | result = ScalarField(rhs.grid, label=label)
67 | try:
68 | solver(rhs.data, result.data)
69 | except RuntimeError as err:
70 | magnitude = rhs.magnitude
71 | if magnitude > 1e-10:
72 | raise RuntimeError(
73 | "Could not solve the Poisson problem. One possible reason for this is "
74 | "that only periodic or Neumann conditions are applied although the "
75 | f"magnitude of the field is {magnitude} and thus non-zero."
76 | ) from err
77 | else:
78 | raise # another error occurred
79 |
80 | return result
81 |
82 |
83 | @fill_in_docstring
84 | def solve_laplace_equation(
85 | grid: GridBase, bc: BoundariesData, *, label: str = "Solution to Laplace's equation"
86 | ) -> ScalarField:
87 | """Solve Laplace's equation on a given grid.
88 |
89 | This is implemented by calling :func:`solve_poisson_equation` with a
90 | vanishing right hand side.
91 |
92 | Args:
93 | grid (:class:`~pde.grids.base.GridBase`):
94 | The grid on which the equation is solved
95 | bc:
96 | The boundary conditions applied to the field.
97 | {ARG_BOUNDARIES}
98 | label (str):
99 | The label of the returned field.
100 |
101 | Returns:
102 | :class:`~pde.fields.scalar.ScalarField`: The field that solves the
103 | equation. This field will be defined on the given `grid`.
104 | """
105 | rhs = ScalarField(grid, data=0)
106 | return solve_poisson_equation(rhs, bc=bc, label=label)
107 |
--------------------------------------------------------------------------------
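Example: a sketch of both helpers; the Dirichlet condition `{"value": 0}` follows the
boundary-condition format used elsewhere in the package:

    import pde
    from pde.pdes import solve_laplace_equation, solve_poisson_equation

    grid = pde.CartesianGrid([[0, 1]], 32)
    f = pde.ScalarField.from_expression(grid, "sin(pi * x)")   # right hand side

    u_poisson = solve_poisson_equation(f, bc={"value": 0})     # ∇²u = -f
    u_laplace = solve_laplace_equation(grid, bc={"value": 1})  # ∇²u = 0, so u ≡ 1
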
/pde/py.typed:
--------------------------------------------------------------------------------
1 | # This file indicates that the py-pde package supports typing compliant with PEP 561
--------------------------------------------------------------------------------
/pde/solvers/__init__.py:
--------------------------------------------------------------------------------
1 | """Solvers define how a PDE is solved, i.e., how the initial state is advanced in time.
2 |
3 | .. autosummary::
4 | :nosignatures:
5 |
6 | ~controller.Controller
7 | ~explicit.ExplicitSolver
8 | ~explicit_mpi.ExplicitMPISolver
9 | ~implicit.ImplicitSolver
10 | ~crank_nicolson.CrankNicolsonSolver
11 | ~adams_bashforth.AdamsBashforthSolver
12 | ~scipy.ScipySolver
13 | ~registered_solvers
14 |
15 |
16 | Inheritance structure of the classes:
17 |
18 |
19 | .. inheritance-diagram:: adams_bashforth.AdamsBashforthSolver
20 | crank_nicolson.CrankNicolsonSolver
21 | explicit.ExplicitSolver
22 | implicit.ImplicitSolver
23 | scipy.ScipySolver
24 | explicit_mpi.ExplicitMPISolver
25 | :parts: 1
26 |
27 | .. codeauthor:: David Zwicker
28 | """
29 |
30 | from .adams_bashforth import AdamsBashforthSolver
31 | from .controller import Controller
32 | from .crank_nicolson import CrankNicolsonSolver
33 | from .explicit import ExplicitSolver
34 | from .implicit import ImplicitSolver
35 | from .scipy import ScipySolver
36 |
37 | try:
38 | from .explicit_mpi import ExplicitMPISolver
39 | except ImportError:
40 | # MPI modules do not seem to be properly available
41 | ExplicitMPISolver = None # type: ignore
42 |
43 |
44 | def registered_solvers() -> list[str]:
45 | """Returns all solvers that are currently registered.
46 |
47 | Returns:
48 | list of str: List with the names of the solvers
49 | """
50 | from .base import SolverBase
51 |
52 | return SolverBase.registered_solvers # type: ignore
53 |
54 |
55 | __all__ = [
56 | "Controller",
57 | "ExplicitSolver",
58 | "ImplicitSolver",
59 | "CrankNicolsonSolver",
60 | "AdamsBashforthSolver",
61 | "ScipySolver",
62 | "registered_solvers",
63 | ]
64 |
--------------------------------------------------------------------------------
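Example: the solver classes are usually paired with a :class:`~controller.Controller`,
as in the package's own performance scripts:

    import pde
    from pde.solvers import Controller, ExplicitSolver

    grid = pde.UnitGrid([32, 32])
    field = pde.ScalarField.random_uniform(grid)
    eq = pde.DiffusionPDE()

    solver = ExplicitSolver(eq, scheme="euler", adaptive=True)  # adaptive Euler
    controller = Controller(solver, t_range=1.0)
    result = controller.run(field, dt=1e-3)                     # dt: initial step
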
/pde/solvers/adams_bashforth.py:
--------------------------------------------------------------------------------
1 | """Defines an explicit Adams-Bashforth solver.
2 |
3 | .. codeauthor:: David Zwicker
4 | """
5 |
6 | from __future__ import annotations
7 |
8 | from typing import Any, Callable
9 |
10 | import numba as nb
11 | import numpy as np
12 |
13 | from ..fields.base import FieldBase
14 | from ..tools.numba import jit
15 | from .base import SolverBase
16 |
17 |
18 | class AdamsBashforthSolver(SolverBase):
19 | """Explicit Adams-Bashforth multi-step solver."""
20 |
21 | name = "adams-bashforth"
22 |
23 | def _make_fixed_stepper(
24 | self, state: FieldBase, dt: float
25 | ) -> Callable[[np.ndarray, float, int, Any], float]:
26 | """Return a stepper function using an explicit scheme with fixed time steps.
27 |
28 | Args:
29 | state (:class:`~pde.fields.base.FieldBase`):
30 | An example for the state from which the grid and other information can
31 | be extracted
32 | dt (float):
33 | Time step of the explicit stepping
34 | """
35 | if self.pde.is_sde:
36 | raise NotImplementedError
37 |
38 | rhs_pde = self._make_pde_rhs(state, backend=self.backend)
39 | post_step_hook = self._make_post_step_hook(state)
40 |
41 | def single_step(
42 | state_data: np.ndarray, t: float, state_prev: np.ndarray
43 | ) -> None:
44 | """Perform a single Adams-Bashforth step."""
45 | rhs_prev = rhs_pde(state_prev, t - dt).copy()
46 | rhs_cur = rhs_pde(state_data, t)
47 | state_prev[:] = state_data # save the previous state
48 | state_data += dt * (1.5 * rhs_cur - 0.5 * rhs_prev)
49 |
50 | # allocate memory to store the state of the previous time step
51 | state_prev = np.empty_like(state.data)
52 | init_state_prev = True
53 |
54 | if self._compiled:
55 | sig_single_step = (nb.typeof(state.data), nb.double, nb.typeof(state_prev))
56 | single_step = jit(sig_single_step)(single_step)
57 |
58 | def fixed_stepper(
59 | state_data: np.ndarray, t_start: float, steps: int, post_step_data
60 | ) -> float:
61 | """Perform `steps` steps with fixed time steps."""
62 | nonlocal state_prev, init_state_prev
63 |
64 | if init_state_prev:
65 | # initialize the state_prev with an estimate of the previous step
66 | state_prev[:] = state_data - dt * rhs_pde(state_data, t_start)
67 | init_state_prev = False
68 |
69 | for i in range(steps):
70 | # calculate the right hand side
71 | t = t_start + i * dt
72 | single_step(state_data, t, state_prev)
73 | post_step_hook(state_data, t, post_step_data=post_step_data)
74 |
75 | return t + dt
76 |
77 | self._logger.info("Init explicit Adams-Bashforth stepper with dt=%g", dt)
78 |
79 | return fixed_stepper
80 |
--------------------------------------------------------------------------------
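The update in `single_step` above is the classical two-step Adams-Bashforth rule,
u_{n+1} = u_n + dt * (3/2 f(u_n, t_n) - 1/2 f(u_{n-1}, t_{n-1})). A self-contained
sketch of the same arithmetic on the scalar ODE du/dt = -u:

    import numpy as np

    def f(u):
        return -u                  # right hand side of du/dt = -u

    dt = 0.01
    u_prev = np.exp(dt)            # exact solution e^{-t} at t = -dt seeds the scheme
    u = 1.0                        # initial condition at t = 0
    for _ in range(100):           # integrate up to t = 1
        u, u_prev = u + dt * (1.5 * f(u) - 0.5 * f(u_prev)), u

    print(u, np.exp(-1.0))         # the two values agree to a few digits
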
/pde/storage/__init__.py:
--------------------------------------------------------------------------------
1 | """Module defining classes for storing simulation data.
2 |
3 | .. autosummary::
4 | :nosignatures:
5 |
6 | ~memory.get_memory_storage
7 | ~memory.MemoryStorage
8 | ~modelrunner.ModelrunnerStorage
9 | ~file.FileStorage
10 | ~movie.MovieStorage
11 |
12 | .. codeauthor:: David Zwicker
13 | """
14 |
15 | import contextlib
16 |
17 | from .file import FileStorage
18 | from .memory import MemoryStorage, get_memory_storage
19 | from .movie import MovieStorage
20 |
21 | with contextlib.suppress(ImportError):
22 | from .modelrunner import ModelrunnerStorage
23 |
--------------------------------------------------------------------------------
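Example: attach a storage's tracker to a simulation so snapshots are recorded while it
runs (the interval of 0.1 and the time range are illustrative):

    import pde

    grid = pde.UnitGrid([32])
    field = pde.ScalarField.random_uniform(grid)

    storage = pde.MemoryStorage()
    eq = pde.DiffusionPDE()
    eq.solve(field, t_range=1.0, dt=0.01, tracker=storage.tracker(0.1))
    print(storage.times)    # times at which snapshots were stored
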
/pde/tools/__init__.py:
--------------------------------------------------------------------------------
1 | """Package containing several tools required in py-pde.
2 |
3 | .. autosummary::
4 | :nosignatures:
5 |
6 | cache
7 | config
8 | cuboid
9 | docstrings
10 | expressions
11 | ffmpeg
12 | math
13 | misc
14 | mpi
15 | numba
16 | output
17 | parameters
18 | parse_duration
19 | plotting
20 | spectral
21 | typing
22 |
23 | .. codeauthor:: David Zwicker
24 | """
25 |
--------------------------------------------------------------------------------
/pde/tools/modelrunner.py:
--------------------------------------------------------------------------------
1 | """Establishes hooks for the interplay between :mod:`pde` and :mod:`modelrunner`
2 |
3 | This package is usually loaded automatically during import if :mod:`modelrunner` is
4 | available. In this case, grids and fields of :mod:`pde` can be directly written to
5 | storages from :mod:`modelrunner.storage`.
6 |
7 | .. codeauthor:: David Zwicker
8 | """
9 |
10 | from collections.abc import Sequence
11 |
12 | from modelrunner.storage import StorageBase, storage_actions
13 | from modelrunner.storage.utils import decode_class
14 |
15 | from ..fields.base import FieldBase
16 | from ..grids.base import GridBase
17 |
18 |
19 | # these actions are inherited by all subclasses by default
20 | def load_grid(storage: StorageBase, loc: Sequence[str]) -> GridBase:
21 | """Function loading a grid from a modelrunner storage.
22 |
23 | Args:
24 | storage (:class:`~modelrunner.storage.group.StorageGroup`):
25 | Storage to load data from
26 | loc (Location):
27 | Location in the storage
28 |
29 | Returns:
30 | :class:`~pde.grids.base.GridBase`: the loaded grid
31 | """
32 | # get grid class that was stored
33 | stored_cls = decode_class(storage._read_attrs(loc).get("__class__"))
34 | state = storage.read_attrs(loc)
35 | return stored_cls.from_state(state) # type: ignore
36 |
37 |
38 | storage_actions.register("read_item", GridBase, load_grid)
39 |
40 |
41 | def save_grid(storage: StorageBase, loc: Sequence[str], grid: GridBase) -> None:
42 | """Function saving a grid to a modelrunner storage.
43 |
44 | Args:
45 | storage (:class:`~modelrunner.storage.group.StorageGroup`):
46 | Storage to save data to
47 | loc (Location):
48 | Location in the storage
49 | grid (:class:`~pde.grids.base.GridBase`):
50 | the grid to store
51 | """
52 | storage.write_object(loc, None, attrs=grid.state, cls=grid.__class__)
53 |
54 |
55 | storage_actions.register("write_item", GridBase, save_grid)
56 |
57 |
58 | # these actions are inherited by all subclasses by default
59 | def load_field(storage: StorageBase, loc: Sequence[str]) -> FieldBase:
60 | """Function loading a field from a modelrunner storage.
61 |
62 | Args:
63 | storage (:class:`~modelrunner.storage.group.StorageGroup`):
64 | Storage to load data from
65 | loc (Location):
66 | Location in the storage
67 |
68 | Returns:
69 | :class:`~pde.fields.base.FieldBase`: the loaded field
70 | """
71 | # get field class that was stored
72 | stored_cls = decode_class(storage._read_attrs(loc).get("__class__"))
73 | attributes = stored_cls.unserialize_attributes(storage.read_attrs(loc)) # type: ignore
74 | return stored_cls.from_state(attributes, data=storage.read_array(loc)) # type: ignore
75 |
76 |
77 | storage_actions.register("read_item", FieldBase, load_field)
78 |
79 |
80 | def save_field(storage: StorageBase, loc: Sequence[str], field: FieldBase) -> None:
81 | """Function saving a field to a modelrunner storage.
82 |
83 | Args:
84 | storage (:class:`~modelrunner.storage.group.StorageGroup`):
85 | Storage to save data to
86 | loc (Location):
87 | Location in the storage
88 | field (:class:`~pde.fields.base.FieldBase`):
89 | the field to store
90 | """
91 | storage.write_array(
92 | loc, field.data, attrs=field.attributes_serialized, cls=field.__class__
93 | )
94 |
95 |
96 | storage_actions.register("write_item", FieldBase, save_field)
97 |
98 |
99 | __all__: list[str] = [] # module only registers hooks and does not export any functions
100 |
--------------------------------------------------------------------------------
/pde/tools/resources/requirements_basic.txt:
--------------------------------------------------------------------------------
1 | # These are the basic requirements for the package
2 | matplotlib>=3.1
3 | numba>=0.59
4 | numpy>=1.22
5 | scipy>=1.10
6 | sympy>=1.9
7 | tqdm>=4.66
8 |
--------------------------------------------------------------------------------
/pde/tools/resources/requirements_full.txt:
--------------------------------------------------------------------------------
1 | # These are the full requirements used to test all functions
2 | ffmpeg-python>=0.2
3 | h5py>=2.10
4 | ipywidgets>=8
5 | matplotlib>=3.1
6 | numba>=0.59
7 | numpy>=1.22
8 | pandas>=2
9 | py-modelrunner>=0.19
10 | rocket-fft>=0.2.4
11 | scipy>=1.10
12 | sympy>=1.9
13 | tqdm>=4.66
14 |
--------------------------------------------------------------------------------
/pde/tools/resources/requirements_mpi.txt:
--------------------------------------------------------------------------------
1 | # These are requirements for supporting multiprocessing
2 | h5py>=2.10
3 | matplotlib>=3.1
4 | mpi4py>=3
5 | numba>=0.59
6 | numba-mpi>=0.22
7 | numpy>=1.22
8 | pandas>=2
9 | scipy>=1.10
10 | sympy>=1.9
11 | tqdm>=4.66
12 |
--------------------------------------------------------------------------------
/pde/tools/typing.py:
--------------------------------------------------------------------------------
1 | """Provides support for mypy type checking of the package.
2 |
3 | .. codeauthor:: David Zwicker
4 | """
5 |
6 | from __future__ import annotations
7 |
8 | from typing import TYPE_CHECKING, Literal, Protocol, Union
9 |
10 | import numpy as np
11 | from numpy.typing import ArrayLike
12 |
13 | if TYPE_CHECKING:
14 | from ..grids.base import GridBase
15 |
16 | Real = Union[int, float]
17 | Number = Union[Real, complex]
18 | NumberOrArray = Union[Number, np.ndarray]
19 | FloatNumerical = Union[float, np.ndarray]
20 | BackendType = Literal["auto", "numpy", "numba"]
21 |
22 |
23 | class OperatorType(Protocol):
24 | """An operator that acts on an array."""
25 |
26 | def __call__(self, arr: np.ndarray, out: np.ndarray) -> None:
27 | """Evaluate the operator."""
28 |
29 |
30 | class OperatorFactory(Protocol):
31 | """A factory function that creates an operator for a particular grid."""
32 |
33 | def __call__(self, grid: GridBase, **kwargs) -> OperatorType:
34 | """Create the operator."""
35 |
36 |
37 | class CellVolume(Protocol):
38 | def __call__(self, *args: int) -> float:
39 | """Calculate the volume of the cell at the given position."""
40 |
41 |
42 | class VirtualPointEvaluator(Protocol):
43 | def __call__(self, arr: np.ndarray, idx: tuple[int, ...], args=None) -> float:
44 | """Evaluate the virtual point at the given position."""
45 |
46 |
47 | class GhostCellSetter(Protocol):
48 | def __call__(self, data_full: np.ndarray, args=None) -> None:
49 | """Set the ghost cells."""
50 |
51 |
52 | class StepperHook(Protocol):
53 | def __call__(
54 | self, state_data: np.ndarray, t: float, post_step_data: np.ndarray
55 | ) -> None:
56 | """Function analyzing and potentially modifying the current state."""
57 |
--------------------------------------------------------------------------------
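Because these are :class:`~typing.Protocol` classes, any callable with a matching
signature satisfies them structurally; a sketch for :class:`OperatorType`:

    import numpy as np

    from pde.tools.typing import OperatorType

    def negate(arr: np.ndarray, out: np.ndarray) -> None:
        """Trivial operator writing -arr into out."""
        out[:] = -arr

    op: OperatorType = negate  # accepted by mypy: the signature matches the protocol
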
/pde/trackers/__init__.py:
--------------------------------------------------------------------------------
1 | """Classes for tracking simulation results in controlled interrupts.
2 |
3 | Trackers are classes that periodically receive the state of the simulation to analyze,
4 | store, or output it. The trackers defined in this module are:
5 |
6 | .. autosummary::
7 | :nosignatures:
8 |
9 | ~trackers.CallbackTracker
10 | ~trackers.ProgressTracker
11 | ~trackers.PrintTracker
12 | ~trackers.PlotTracker
13 | ~trackers.LivePlotTracker
14 | ~trackers.DataTracker
15 | ~trackers.SteadyStateTracker
16 | ~trackers.RuntimeTracker
17 | ~trackers.ConsistencyTracker
18 | ~interactive.InteractivePlotTracker
19 |
20 | Some trackers can also be referenced by name for convenience when using them in
21 | simulations. The list of supported names is returned by
22 | :func:`~pde.trackers.base.get_named_trackers`.
23 |
24 | Multiple trackers can be collected in a :class:`~base.TrackerCollection`, which provides
25 | methods for handling them efficiently. Moreover, custom trackers can be implemented by
26 | deriving from :class:`~.trackers.base.TrackerBase`. Note that trackers generally receive
27 | a view into the current state, implying that they can adjust the state by modifying it
28 | in-place. Moreover, trackers can abort the simulation by raising the special exception
29 | :class:`StopIteration`.
30 |
31 |
32 | For each tracker, the time at which the simulation is interrupted can be decided using
33 | one of the following classes:
34 |
35 | .. autosummary::
36 | :nosignatures:
37 |
38 | ~interrupts.FixedInterrupts
39 | ~interrupts.ConstantInterrupts
40 | ~interrupts.LogarithmicInterrupts
41 | ~interrupts.GeometricInterrupts
42 | ~interrupts.RealtimeInterrupts
43 |
44 | In particular, interrupts can be specified conveniently using
45 | :func:`~interrupts.parse_interrupt`.
46 |
47 | .. codeauthor:: David Zwicker
48 | """
49 |
50 | from .base import get_named_trackers
51 | from .interactive import InteractivePlotTracker
52 | from .interrupts import (
53 | ConstantInterrupts,
54 | FixedInterrupts,
55 | LogarithmicInterrupts,
56 | RealtimeInterrupts,
57 | parse_interrupt,
58 | )
59 | from .trackers import *
60 |
--------------------------------------------------------------------------------
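Example: combining a named tracker with an explicit tracker object; the `interrupts`
keyword of the tracker constructor is an assumption based on the interrupt classes
listed above:

    import pde
    from pde.trackers import ConstantInterrupts, PrintTracker

    grid = pde.UnitGrid([32])
    state = pde.ScalarField.random_uniform(grid)

    printer = PrintTracker(interrupts=ConstantInterrupts(0.5))  # print every Δt=0.5
    eq = pde.DiffusionPDE()
    eq.solve(state, t_range=2.0, dt=0.01, tracker=["progress", printer])
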
/pde/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | """Functions and classes for visualizing simulations.
2 |
3 | .. autosummary::
4 | :nosignatures:
5 |
6 | movies
7 | plotting
8 |
9 | .. codeauthor:: David Zwicker
10 | """
11 |
12 | from .movies import movie, movie_multiple, movie_scalar
13 | from .plotting import plot_interactive, plot_kymograph, plot_kymographs, plot_magnitudes
14 |
--------------------------------------------------------------------------------
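Example: feeding a stored trajectory to :func:`~plotting.plot_kymograph`, assuming the
function accepts a storage object holding a 1d scalar field over time:

    import pde
    from pde.visualization import plot_kymograph

    grid = pde.UnitGrid([64])
    state = pde.ScalarField.random_uniform(grid)

    storage = pde.MemoryStorage()
    pde.DiffusionPDE().solve(state, t_range=5, dt=0.01, tracker=storage.tracker(0.1))

    plot_kymograph(storage)   # space-time plot of the stored field
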
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.1
2 | numba>=0.59
3 | numpy>=1.22
4 | scipy>=1.10
5 | sympy>=1.9
6 | tqdm>=4.66
7 |
--------------------------------------------------------------------------------
/runtime.txt:
--------------------------------------------------------------------------------
1 | python-3.9
--------------------------------------------------------------------------------
/scripts/_templates/_runtime.txt:
--------------------------------------------------------------------------------
1 | python-$MIN_PYTHON_VERSION
--------------------------------------------------------------------------------
/scripts/create_storage_test_resources.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """This script creates storage files for backwards compatibility tests."""
3 |
4 | from __future__ import annotations
5 |
6 | import sys
7 | from pathlib import Path
8 |
9 | PACKAGE_PATH = Path(__file__).resolve().parents[1]
10 | sys.path.insert(0, str(PACKAGE_PATH))
11 |
12 | import pde
13 |
14 |
15 | def create_storage_test_resources(path, num):
16 | """Test storing scalar field as movie."""
17 | grid = pde.CylindricalSymGrid(3, [1, 2], [2, 2])
18 | field = pde.ScalarField(grid, [[1, 3], [2, 4]])
19 | eq = pde.DiffusionPDE()
20 | info = {"payload": "storage-test"}
21 | movie_writer = pde.MovieStorage(
22 | path / f"storage_{num}.avi",
23 | info=info,
24 | vmax=4,
25 | bits_per_channel=16,
26 | write_times=True,
27 | )
28 | file_writer = pde.FileStorage(path / f"storage_{num}.hdf5", info=info)
29 | interrupts = pde.FixedInterrupts([0.1, 0.7, 2.9])
30 | eq.solve(
31 | field,
32 | t_range=3.5,
33 | dt=0.1,
34 | backend="numpy",
35 | tracker=[movie_writer.tracker(interrupts), file_writer.tracker(interrupts)],
36 | )
37 |
38 |
39 | def main():
40 | """Main function creating all the requirements."""
41 | root = Path(PACKAGE_PATH)
42 | create_storage_test_resources(root / "tests" / "storage" / "resources", 2)
43 |
44 |
45 | if __name__ == "__main__":
46 | main()
47 |
--------------------------------------------------------------------------------
/scripts/format_code.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # This script formats the code of this package
3 |
4 | echo "Formatting import statements..."
5 | ruff check --fix --config=../pyproject.toml ..
6 |
7 | echo "Formatting docstrings..."
8 | docformatter --in-place --black --recursive ..
9 |
10 | echo "Formatting source code..."
11 | ruff format --config=../pyproject.toml ..
--------------------------------------------------------------------------------
/scripts/performance_boundaries.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """This script tests the performance of the implementation of different boundary
3 | conditions."""
4 |
5 | import sys
6 | from pathlib import Path
7 |
8 | PACKAGE_PATH = Path(__file__).resolve().parents[1]
9 | sys.path.insert(0, str(PACKAGE_PATH))
10 |
11 | import numpy as np
12 |
13 | from pde import ScalarField, UnitGrid
14 | from pde.tools.misc import estimate_computation_speed
15 | from pde.tools.numba import numba_dict
16 |
17 |
18 | def main():
19 | """Main routine testing the performance."""
20 | print("Reports calls-per-second (larger is better)\n")
21 |
22 | # Cartesian grid with different shapes and boundary conditions
23 | for size in [32, 512]:
24 | grid = UnitGrid([size, size], periodic=False)
25 | print(grid)
26 |
27 | field = ScalarField.random_normal(grid)
28 | bc_value = np.ones(size)
29 | result = field.laplace(bc={"value": 1}).data
30 |
31 | for bc in ["scalar", "array", "function", "time-dependent", "linked"]:
32 | if bc == "scalar":
33 | bcs = {"value": 1}
34 | elif bc == "array":
35 | bcs = {"value": bc_value}
36 | elif bc == "function":
37 | bcs = grid.get_boundary_conditions({"virtual_point": "2 - value"})
38 | elif bc == "time-dependent":
39 | bcs = grid.get_boundary_conditions({"value_expression": "t"})
40 | elif bc == "linked":
41 | bcs = grid.get_boundary_conditions({"value": bc_value})
42 | for ax, upper in grid._iter_boundaries():
43 | bcs[ax][upper].link_value(bc_value)
44 | else:
45 | raise RuntimeError
46 |
47 | # create the operator with these conditions
48 | laplace = grid.make_operator("laplace", bc=bcs)
49 | if bc == "time-dependent":
50 | args = numba_dict(t=1)
51 | # call once to pre-compile and test result
52 | np.testing.assert_allclose(laplace(field.data, args=args), result)
53 | # estimate the speed
54 | speed = estimate_computation_speed(laplace, field.data, args=args)
55 |
56 | else:
57 | # call once to pre-compile and test result
58 | np.testing.assert_allclose(laplace(field.data), result)
59 | # estimate the speed
60 | speed = estimate_computation_speed(laplace, field.data)
61 |
62 | print(f"{bc:>14s}:{int(speed):>9d}")
63 |
64 | print()
65 |
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/scripts/performance_solvers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """This script tests the performance of different solvers."""
3 |
4 | import sys
5 | from pathlib import Path
6 | from typing import Literal
7 |
8 | PACKAGE_PATH = Path(__file__).resolve().parents[1]
9 | sys.path.insert(0, str(PACKAGE_PATH))
10 |
11 | import numpy as np
12 |
13 | from pde import CahnHilliardPDE, Controller, DiffusionPDE, ScalarField, UnitGrid
14 | from pde.solvers import (
15 | AdamsBashforthSolver,
16 | CrankNicolsonSolver,
17 | ExplicitSolver,
18 | ImplicitSolver,
19 | ScipySolver,
20 | )
21 |
22 |
23 | def main(
24 | equation: Literal["diffusion", "cahn-hilliard"] = "cahn-hilliard",
25 | t_range: float = 100,
26 | size: int = 32,
27 | ):
28 | """Main routine testing the performance.
29 |
30 | Args:
31 | equation (str):
32 | Chooses the equation to consider
33 | t_range (float):
34 | Sets the total duration that should be solved for
35 | size (int):
36 | The number of grid points along each axis
37 | """
38 | print("Reports duration in seconds (smaller is better)\n")
39 |
40 | # determine grid and initial state
41 | grid = UnitGrid([size, size], periodic=False)
42 | field = ScalarField.random_uniform(grid)
43 | print(f"GRID: {grid}")
44 |
45 | # determine the equation to solve
46 | if equation == "diffusion":
47 | eq = DiffusionPDE()
48 | elif equation == "cahn-hilliard":
49 | eq = CahnHilliardPDE()
50 | else:
51 | raise ValueError(f"Undefined equation `{equation}`")
52 | print(f"EQUATION: ∂c/∂t = {eq.expression}\n")
53 |
54 | print("Determine ground truth...")
55 | expected = eq.solve(field, t_range=t_range, dt=1e-4, tracker=["progress"])
56 |
57 | print("\nSOLVER PERFORMANCE:")
58 | solvers = {
59 | "Euler, fixed": (1e-3, ExplicitSolver(eq, scheme="euler", adaptive=False)),
60 | "Euler, adaptive": (1e-3, ExplicitSolver(eq, scheme="euler", adaptive=True)),
61 | "Runge-Kutta, fixed": (1e-2, ExplicitSolver(eq, scheme="rk", adaptive=False)),
62 | "Runge-Kutta, adaptive": (1e-2, ExplicitSolver(eq, scheme="rk", adaptive=True)),
63 | "Implicit": (1e-2, ImplicitSolver(eq)),
64 | "Adams-Bashforth": (1e-2, AdamsBashforthSolver(eq)),
65 | "Crank-Nicolson": (1e-2, CrankNicolsonSolver(eq)),
66 | "Scipy": (None, ScipySolver(eq)),
67 | }
68 |
69 | for name, (dt, solver) in solvers.items():
70 | # run the simulation with the given solver
71 | solver.backend = "numba"
72 | controller = Controller(solver, t_range=t_range, tracker=None)
73 | result = controller.run(field, dt=dt)
74 |
75 | # determine the deviation from the ground truth
76 | error = np.linalg.norm(result.data - expected.data)
77 | error_str = f"{error:.4g}"
78 |
79 | # report the runtime
80 | runtime_str = f"{controller.info['profiler']['solver']:.3g}"
81 |
82 | # report information about the time step
83 | if solver.info.get("dt_adaptive", False):
84 | stats = solver.info["dt_statistics"]
85 | dt_str = f"{stats['min']:.3g} .. {stats['max']:.3g}"
86 | elif solver.info["dt"] is None:
87 | dt_str = "automatic"
88 | else:
89 | dt_str = f"{solver.info['dt']:.3g}"
90 |
91 | print(f"{name:>21s}: runtime={runtime_str:8} error={error_str:11} dt={dt_str}")
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
--------------------------------------------------------------------------------
/scripts/profile_import.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """This scripts measures the total time it takes to import the module.
3 |
4 | The total time should ideally be below 1 second.
5 | """
6 |
7 | import sys
8 | from pathlib import Path
9 |
10 | PACKAGE_PATH = Path(__file__).resolve().parents[1]
11 | sys.path.insert(0, str(PACKAGE_PATH))
12 |
13 |
14 | from pyinstrument import Profiler
15 |
16 | with Profiler() as profiler:
17 | import pde
18 |
19 | print(profiler.open_in_browser())
20 |
--------------------------------------------------------------------------------
/scripts/show_environment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """This script shows important information about the current python environment and the
3 | associated installed packages.
4 |
5 | This information can be helpful in understanding issues that occur with the package
6 | """
7 |
8 | import sys
9 | from pathlib import Path
10 |
11 | PACKAGE_PATH = Path(__file__).resolve().parents[1]
12 | sys.path.insert(0, str(PACKAGE_PATH))
13 |
14 | from pde import environment
15 |
16 | for category, data in environment().items():
17 | if hasattr(data, "items"):
18 | print(f"\n{category}:")
19 | for key, value in data.items():
20 | print(f" {key}: {value}")
21 | else:
22 | data_formatted = str(data).replace("\n", "\n ")
23 | print(f"{category}: {data_formatted}")
24 |
--------------------------------------------------------------------------------
/scripts/tests_all.sh:
--------------------------------------------------------------------------------
1 | echo "Test codestyle"
2 | echo "--------------"
3 | ./tests_codestyle.sh
4 |
5 | echo ""
6 | echo "Test types"
7 | echo "----------"
8 | ./tests_types.sh
9 |
10 | echo ""
11 | echo "Run serial unittests"
12 | echo "-------------"
13 | ./tests_parallel.sh
14 |
15 | echo ""
16 | echo "Run parallel unittests"
17 | echo "-------------"
18 | ./tests_mpi.sh
--------------------------------------------------------------------------------
/scripts/tests_codestyle.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # This script checks the code format of this package without changing files
4 | #
5 |
6 | ./run_tests.py --style
--------------------------------------------------------------------------------
/scripts/tests_coverage.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo 'Run serial tests to determine coverage...'
4 | ./run_tests.py --unit --coverage --nojit --num_cores auto --runslow
5 |
6 | echo 'Run parallel tests to determine coverage...'
7 | ./run_tests.py --unit --coverage --nojit --use_mpi --runslow -- --cov-append
8 |
--------------------------------------------------------------------------------
/scripts/tests_debug.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export PYTHONPATH=../py-pde # likely path of pde package, relative to current base path
4 |
5 | if [ -n "$1" ]
6 | then
7 | # test pattern was specified
8 | echo 'Run unittests with pattern '$1':'
9 | ./run_tests.py --unit --runslow --nojit --pattern "$1" -- \
10 | -o log_cli=true --log-cli-level=debug -vv
11 | else
12 | # test pattern was not specified
13 | echo 'Run all unittests:'
14 | ./run_tests.py --unit --nojit -- \
15 | -o log_cli=true --log-cli-level=debug -vv
16 | fi
17 |
--------------------------------------------------------------------------------
/scripts/tests_extensive.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # test pattern was not specified
4 | echo 'Run all unittests:'
5 | ./run_tests.py --unit --runslow --num_cores auto -- -rsx
6 |
--------------------------------------------------------------------------------
/scripts/tests_mpi.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -n "$1" ]
4 | then
5 | # test pattern was specified
6 | echo 'Run unittests with pattern '$1':'
7 | ./run_tests.py --unit --use_mpi --runslow --pattern "$1"
8 | else
9 | # test pattern was not specified
10 | echo 'Run all unittests:'
11 | ./run_tests.py --unit --use_mpi
12 | fi
13 |
--------------------------------------------------------------------------------
/scripts/tests_parallel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -n "$1" ]
4 | then
5 | # test pattern was specified
6 | echo 'Run unittests with pattern '$1':'
7 | ./run_tests.py --unit --runslow --num_cores auto --pattern "$1"
8 | else
9 | # test pattern was not specified
10 | echo 'Run all unittests:'
11 | ./run_tests.py --unit --num_cores auto
12 | fi
13 |
--------------------------------------------------------------------------------
/scripts/tests_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ -n "$1" ]
4 | then
5 | # test pattern was specified
6 | echo 'Run unittests with pattern '$1':'
7 | ./run_tests.py --unit --runslow --pattern "$1"
8 | else
9 | # test pattern was not specified
10 | echo 'Run all unittests:'
11 | ./run_tests.py --unit
12 | fi
13 |
--------------------------------------------------------------------------------
/scripts/tests_types.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ./run_tests.py --types
4 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """This file is used to configure the test environment when running py.test.
2 |
3 | .. codeauthor:: David Zwicker
4 | """
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pytest
9 |
10 | from pde import config
11 | from pde.tools.misc import module_available
12 | from pde.tools.numba import random_seed
13 |
14 | # ensure we use the Agg backend, so figures are not displayed
15 | plt.switch_backend("agg")
16 |
17 |
18 | @pytest.fixture(autouse=True)
19 | def _setup_and_teardown():
20 | """Helper function adjusting environment before and after tests."""
21 | # raise all numpy errors, but ignore underflows
22 | np.seterr(all="raise", under="ignore")
23 |
24 | # run the actual test
25 | with config({"boundaries.accept_lists": False}):
26 | yield
27 |
28 | # clean up open matplotlib figures after the test
29 | plt.close("all")
30 |
31 |
32 | @pytest.fixture(autouse=False, name="rng")
33 | def init_random_number_generators():
34 | """Get a random number generator and set the seed of the random number generator.
35 |
36 | The function returns an instance of :func:`~numpy.random.default_rng()` and
37 | initializes the default generators of both :mod:`numpy` and :mod:`numba`.
38 | """
39 | random_seed()
40 | return np.random.default_rng(0)
41 |
42 |
43 | def pytest_configure(config):
44 | """Add markers to the configuration."""
45 | config.addinivalue_line("markers", "interactive: test is interactive")
46 | config.addinivalue_line("markers", "multiprocessing: test requires multiprocessing")
47 | config.addinivalue_line("markers", "slow: test runs slowly")
48 |
49 |
50 | def pytest_addoption(parser):
51 | """Pytest hook to add command line options parsed by pytest."""
52 | parser.addoption(
53 | "--runslow",
54 | action="store_true",
55 | default=False,
56 | help="also run tests marked by `slow`",
57 | )
58 | parser.addoption(
59 | "--runinteractive",
60 | action="store_true",
61 | default=False,
62 | help="also run tests marked by `interactive`",
63 | )
64 | parser.addoption(
65 | "--use_mpi",
66 | action="store_true",
67 | default=False,
68 | help="only run tests marked by `multiprocessing`",
69 | )
70 |
71 |
72 | def pytest_collection_modifyitems(config, items):
73 | """Pytest hook to filter a collection of tests."""
74 | # parse options provided to py.test
75 | running_cov = config.getvalue("--cov")
76 | runslow = config.getoption("--runslow", default=False)
77 | runinteractive = config.getoption("--runinteractive", default=False)
78 | use_mpi = config.getoption("--use_mpi", default=False)
79 | has_numba_mpi = module_available("numba_mpi") and module_available("mpi4py")
80 |
81 | # prepare markers
82 | skip_cov = pytest.mark.skip(reason="skipped during coverage run")
83 | skip_slow = pytest.mark.skip(reason="need --runslow option to run")
84 | skip_interactive = pytest.mark.skip(reason="need --runinteractive option to run")
85 | skip_serial = pytest.mark.skip(reason="serial test, but --use_mpi option was set")
86 | skip_mpi = pytest.mark.skip(reason="mpi test, but `numba_mpi` not available")
87 |
88 | # check each test item
89 | for item in items:
90 | if "no_cover" in item.keywords and running_cov:
91 | item.add_marker(skip_cov)
92 | if "slow" in item.keywords and not runslow:
93 | item.add_marker(skip_slow)
94 | if "interactive" in item.keywords and not runinteractive:
95 | item.add_marker(skip_interactive)
96 |
97 | if "multiprocessing" in item.keywords and not has_numba_mpi:
98 | item.add_marker(skip_mpi)
99 | if use_mpi and "multiprocessing" not in item.keywords:
100 | item.add_marker(skip_serial)
101 |
--------------------------------------------------------------------------------
/tests/fields/fixtures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/fields/fixtures/__init__.py
--------------------------------------------------------------------------------
/tests/fields/fixtures/fields.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 |
7 | from pde import (
8 | CartesianGrid,
9 | CylindricalSymGrid,
10 | FieldCollection,
11 | PolarSymGrid,
12 | ScalarField,
13 | SphericalSymGrid,
14 | Tensor2Field,
15 | UnitGrid,
16 | VectorField,
17 | )
18 |
19 |
20 | def iter_grids():
21 | """Generator providing some test grids."""
22 | for periodic in [True, False]:
23 | yield UnitGrid([3], periodic=periodic)
24 | yield UnitGrid([3, 3, 3], periodic=periodic)
25 | yield CartesianGrid([[-1, 2], [0, 3]], [5, 7], periodic=periodic)
26 | yield CylindricalSymGrid(3, [-1, 2], [7, 8], periodic_z=periodic)
27 | yield PolarSymGrid(3, 4)
28 | yield SphericalSymGrid(3, 4)
29 |
30 |
31 | def iter_fields():
32 | """Generator providing some test fields."""
33 | yield ScalarField(UnitGrid([1, 2, 3]), 1)
34 | yield VectorField.from_expression(PolarSymGrid(2, 3), ["r**2", "r"])
35 | yield Tensor2Field.random_normal(
36 | CylindricalSymGrid(3, [-1, 2], [7, 8], periodic_z=True)
37 | )
38 |
39 | grid = CartesianGrid([[0, 2], [-1, 1]], [3, 4], [True, False])
40 | yield FieldCollection([ScalarField(grid, 1), VectorField(grid, 2)])
41 |
42 |
43 | def get_cartesian_grid(dim=2, periodic=True):
44 | """Return a random Cartesian grid of given dimension."""
45 | rng = np.random.default_rng(0)
46 | bounds = [[0, 1 + rng.random()] for _ in range(dim)]
47 | shape = rng.integers(32, 64, size=dim)
48 | return CartesianGrid(bounds, shape, periodic=periodic)
49 |
--------------------------------------------------------------------------------
/tests/grids/boundaries/test_axis_boundaries.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import itertools
6 |
7 | import pytest
8 |
9 | from pde import UnitGrid
10 | from pde.grids.boundaries.axis import BoundaryPair, get_boundary_axis
11 | from pde.grids.boundaries.local import BCBase
12 |
13 |
14 | def test_boundary_pair():
15 | """Test setting boundary conditions for whole axis."""
16 | g = UnitGrid([2, 3])
17 | b = ["value", {"type": "derivative", "value": 1}]
18 | for bl, bh in itertools.product(b, b):
19 | bc = BoundaryPair.from_data(g, 0, [bl, bh])
20 | blo = BCBase.from_data(g, 0, upper=False, data=bl)
21 | bho = BCBase.from_data(g, 0, upper=True, data=bh)
22 |
23 | assert bc.low == blo
24 | assert bc.high == bho
25 | assert bc == BoundaryPair(blo, bho)
26 | if bl == bh:
27 | assert bc == BoundaryPair.from_data(g, 0, bl)
28 | assert list(bc) == [blo, bho]
29 | assert isinstance(str(bc), str)
30 | assert isinstance(repr(bc), str)
31 |
32 | bc.check_value_rank(0)
33 | with pytest.raises(RuntimeError):
34 | bc.check_value_rank(1)
35 |
36 | data = {"low": {"value": 1}, "high": {"derivative": 2}}
37 | bc1 = BoundaryPair.from_data(g, 0, data)
38 | bc2 = BoundaryPair.from_data(g, 0, data)
39 | assert bc1 == bc2
40 | assert bc1 is not bc2
41 | bc2 = BoundaryPair.from_data(g, 1, data)
42 | assert bc1 != bc2
43 | assert bc1 is not bc2
44 |
45 | # miscellaneous methods
46 | data = {"low": {"value": 0}, "high": {"derivative": 0}}
47 | bc1 = BoundaryPair.from_data(g, 0, data)
48 | b_lo, b_hi = bc1
49 | assert b_lo == BCBase.from_data(g, 0, False, {"value": 0})
50 | assert b_hi == BCBase.from_data(g, 0, True, {"derivative": 0})
51 | assert b_lo is bc1[0]
52 | assert b_lo is bc1[False]
53 | assert b_hi is bc1[1]
54 | assert b_hi is bc1[True]
55 |
56 |
57 | def test_get_axis_boundaries():
58 | """Test setting boundary conditions including periodic ones."""
59 | for data in ["value", "derivative", "periodic", "anti-periodic"]:
60 | g = UnitGrid([2], periodic=("periodic" in data))
61 | b = get_boundary_axis(g, 0, data)
62 | assert str(b) == '"' + data + '"'
63 | b1, b2 = b.get_mathematical_representation("field")
64 | assert "field" in b1
65 | assert "field" in b2
66 |
67 | if "periodic" in data:
68 | assert b.periodic
69 | assert len(list(b)) == 2
70 | assert b.flip_sign == (data == "anti-periodic")
71 | else:
72 | assert not b.periodic
73 | assert len(list(b)) == 2
74 |
75 | # check double setting
76 | c = get_boundary_axis(g, 0, (data, data))
77 | assert b == c
78 |
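79 | # Note on the indexing convention exercised in test_boundary_pair above: a
80 | # BoundaryPair acts like a two-element sequence, where index 0 (or False)
81 | # selects the lower boundary and index 1 (or True) selects the upper one.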
--------------------------------------------------------------------------------
/tests/grids/operators/test_common_operators.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import random
6 |
7 | import numpy as np
8 | import pytest
9 |
10 | from pde import CartesianGrid, ScalarField
11 | from pde.grids.operators import common as ops
12 |
13 |
14 | @pytest.mark.parametrize("ndim,axis", [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
15 | def test_make_derivative(ndim, axis, rng):
16 | """Test the make_derivative function."""
17 | periodic = random.choice([True, False])
18 | grid = CartesianGrid([[0, 6 * np.pi]] * ndim, 16, periodic=periodic)
19 | field = ScalarField.random_harmonic(grid, modes=1, axis_combination=np.add, rng=rng)
20 |
21 | bcs = grid.get_boundary_conditions("auto_periodic_neumann")
22 | grad = field.gradient(bcs)
23 | for method in ["central", "forward", "backward"]:
24 | msg = f"method={method}, periodic={periodic}"
25 | diff = ops.make_derivative(grid, axis=axis, method=method)
26 | res = field.copy()
27 | res.data[:] = 0
28 | field.set_ghost_cells(bcs)
29 | diff(field._data_full, out=res.data)
30 | np.testing.assert_allclose(
31 | grad.data[axis], res.data, atol=0.1, rtol=0.1, err_msg=msg
32 | )
33 |
34 |
35 | @pytest.mark.parametrize("ndim,axis", [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
36 | def test_make_derivative2(ndim, axis, rng):
37 | """Test the make_derivative2 function."""
38 | periodic = random.choice([True, False])
39 | grid = CartesianGrid([[0, 6 * np.pi]] * ndim, 16, periodic=periodic)
40 | field = ScalarField.random_harmonic(grid, modes=1, axis_combination=np.add, rng=rng)
41 |
42 | bcs = grid.get_boundary_conditions("auto_periodic_neumann")
43 | grad = field.gradient(bcs)[axis]
44 | grad2 = grad.gradient(bcs)[axis]
45 |
46 | diff = ops.make_derivative2(grid, axis=axis)
47 | res = field.copy()
48 | res.data[:] = 0
49 | field.set_ghost_cells(bcs)
50 | diff(field._data_full, out=res.data)
51 | np.testing.assert_allclose(grad2.data, res.data, atol=0.1, rtol=0.1)
52 |
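53 | # For reference, the loose tolerances above reflect the truncation errors of
54 | # the standard finite-difference stencils (assuming the textbook forms on a
55 | # uniform spacing dx):
56 | #   central:   (f[i+1] - f[i-1]) / (2*dx)          error O(dx**2)
57 | #   forward:   (f[i+1] - f[i]) / dx                error O(dx)
58 | #   backward:  (f[i] - f[i-1]) / dx                error O(dx)
59 | #   2nd deriv: (f[i+1] - 2*f[i] + f[i-1]) / dx**2  error O(dx**2)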
--------------------------------------------------------------------------------
/tests/grids/test_cylindrical_grids.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import CartesianGrid, CylindricalSymGrid, ScalarField
9 | from pde.grids.boundaries.local import NeumannBC
10 |
11 |
12 | @pytest.mark.parametrize("periodic", [True, False])
13 | @pytest.mark.parametrize("r_inner", [0, 2])
14 | def test_cylindrical_grid(periodic, r_inner, rng):
15 | """Test simple cylindrical grid."""
16 | grid = CylindricalSymGrid((r_inner, 4), (-1, 2), (8, 9), periodic_z=periodic)
17 | if r_inner == 0:
18 | assert grid == CylindricalSymGrid(4, (-1, 2), (8, 9), periodic_z=periodic)
19 | rs, zs = grid.axes_coords
20 |
21 | assert grid.dim == 3
22 | assert grid.numba_type == "f8[:, :]"
23 | assert grid.shape == (8, 9)
24 | assert grid.length == pytest.approx(3)
25 | assert grid.discretization[1] == pytest.approx(1 / 3)
26 | assert grid.volume == pytest.approx(3 * np.pi * (4**2 - r_inner**2))  # V = pi * (R**2 - r**2) * L
27 | assert not grid.uniform_cell_volumes
28 | assert grid.volume == pytest.approx(grid.integrate(1))
29 | np.testing.assert_allclose(zs, np.linspace(-1 + 1 / 6, 2 - 1 / 6, 9))
30 |
31 | if r_inner == 0:
32 | assert grid.discretization[0] == pytest.approx(0.5)
33 | np.testing.assert_array_equal(grid.discretization, np.array([0.5, 1 / 3]))
34 | np.testing.assert_allclose(rs, np.linspace(0.25, 3.75, 8))
35 | else:
36 | assert grid.discretization[0] == pytest.approx(0.25)
37 | np.testing.assert_array_equal(grid.discretization, np.array([0.25, 1 / 3]))
38 | np.testing.assert_allclose(rs, np.linspace(2.125, 3.875, 8))
39 |
40 | assert grid.contains_point(grid.get_random_point(coords="cartesian", rng=rng))
41 | ps = [grid.get_random_point(coords="cartesian", rng=rng) for _ in range(2)]
42 | assert all(grid.contains_point(ps))
43 | ps = grid.get_random_point(coords="cartesian", boundary_distance=1.49, rng=rng)
44 | assert grid.contains_point(ps)
45 | assert "laplace" in grid.operators
46 |
47 |
48 | def test_cylindrical_to_cartesian():
49 | """Test conversion of cylindrical grid to Cartesian."""
50 | expr_cyl = "cos(z / 2) / (1 + r**2)"
51 | expr_cart = expr_cyl.replace("r**2", "(x**2 + y**2)")
52 |
53 | z_range = (-np.pi, 2 * np.pi)
54 | grid_cyl = CylindricalSymGrid(10, z_range, (16, 33))
55 | pf_cyl = ScalarField.from_expression(grid_cyl, expression=expr_cyl)
56 |
57 | grid_cart = CartesianGrid([[-7, 7], [-6, 7], z_range], [16, 16, 16])
58 | pf_cart1 = pf_cyl.interpolate_to_grid(grid_cart)
59 | pf_cart2 = ScalarField.from_expression(grid_cart, expression=expr_cart)
60 | np.testing.assert_allclose(pf_cart1.data, pf_cart2.data, atol=0.1)
61 |
62 |
63 | def test_setting_boundary_conditions():
64 | """Test various ways of setting boundary conditions for cylindrical grids."""
65 | grid = CylindricalSymGrid(1, [0, 1], [2, 2], periodic_z=False)
66 | grid.get_boundary_conditions("auto_periodic_neumann")
67 | grid.get_boundary_conditions({"r": "derivative", "z": "derivative"})
68 | with pytest.raises(RuntimeError):
69 | grid.get_boundary_conditions({"r": "derivative", "z": "periodic"})
70 |
71 | b_inner = NeumannBC(grid, 0, upper=False)
72 | assert grid.get_boundary_conditions("auto_periodic_neumann")[0].low == b_inner
73 | assert grid.get_boundary_conditions({"value": 2})[0].low != b_inner
74 |
75 | grid = CylindricalSymGrid(1, [0, 1], [2, 2], periodic_z=True)
76 | grid.get_boundary_conditions("auto_periodic_neumann")
77 | grid.get_boundary_conditions({"r": "derivative", "z": "periodic"})
78 | with pytest.raises(RuntimeError):
79 | grid.get_boundary_conditions({"r": "derivative", "z": "derivative"})
80 |
81 |
82 | def test_mixed_derivatives():
83 | """Test mixed derivatives of scalar fields."""
84 | grid = CylindricalSymGrid(1, [-1, 0.5], [7, 9])
85 | field = ScalarField.random_normal(grid, label="c")
86 |
87 | res1 = field.apply("d_dz(d_dr(c))")
88 | res2 = field.apply("d_dr(d_dz(c))")
89 | np.testing.assert_allclose(res1.data, res2.data)
90 |
--------------------------------------------------------------------------------
/tests/pdes/test_generic_pdes.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import ScalarField, UnitGrid, pdes
9 |
10 |
11 | @pytest.mark.parametrize("dim", [1, 2])
12 | @pytest.mark.parametrize(
13 | "pde_class",
14 | [
15 | pdes.AllenCahnPDE,
16 | pdes.CahnHilliardPDE,
17 | pdes.DiffusionPDE,
18 | pdes.KPZInterfacePDE,
19 | pdes.KuramotoSivashinskyPDE,
20 | pdes.SwiftHohenbergPDE,
21 | ],
22 | )
23 | def test_pde_consistency(pde_class, dim, rng):
24 | """Test some methods of generic PDE models."""
25 | eq = pde_class()
26 | assert isinstance(str(eq), str)
27 | assert isinstance(repr(eq), str)
28 |
29 | # compare numba to numpy implementation
30 | grid = UnitGrid([4] * dim)
31 | state = ScalarField.random_uniform(grid, rng=rng)
32 | field = eq.evolution_rate(state)
33 | assert field.grid == grid
34 | rhs = eq._make_pde_rhs_numba(state)
35 | res = rhs(state.data, 0)
36 | np.testing.assert_allclose(field.data, res)
37 |
38 | # compare to generic implementation
39 | assert isinstance(eq.expression, str)
40 | eq2 = pdes.PDE({"c": eq.expression})
41 | np.testing.assert_allclose(field.data, eq2.evolution_rate(state).data)
42 |
43 |
44 | def test_pde_consistency_test(rng):
45 | """Test whether the consistency of a PDE implementation is checked."""
46 |
47 | class TestPDE(pdes.PDEBase):
48 | def evolution_rate(self, field, t=0):
49 | return 2 * field
50 |
51 | def _make_pde_rhs_numba(self, state):
52 | def impl(state_data, t):
53 | return 3 * state_data
54 |
55 | return impl
56 |
57 | eq = TestPDE()
58 | state = ScalarField.random_uniform(UnitGrid([4]), rng=rng)
59 | with pytest.raises(AssertionError):
60 | eq.solve(state, t_range=5, tracker=None)
61 |
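62 | # The deliberate mismatch in TestPDE (the numpy implementation returns
63 | # 2 * field while the numba implementation returns 3 * state_data) should
64 | # trigger the solver's internal consistency check, which compares the two
65 | # backends on the initial state; hence the expected AssertionError.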
--------------------------------------------------------------------------------
/tests/pdes/test_laplace_pdes.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import CartesianGrid, ScalarField, UnitGrid
9 | from pde.pdes import solve_laplace_equation, solve_poisson_equation
10 |
11 |
12 | def test_pde_poisson_solver_1d():
13 | """Test the Poisson solver on 1d grids."""
14 | # solve Laplace's equation
15 | grid = UnitGrid([4])
16 | res = solve_laplace_equation(grid, bc={"x-": {"value": -1}, "x+": {"value": 3}})
17 | np.testing.assert_allclose(res.data, grid.axes_coords[0] - 1)
18 |
19 | res = solve_laplace_equation(
20 | grid, bc={"x-": {"value": -1}, "x+": {"derivative": 1}}
21 | )
22 | np.testing.assert_allclose(res.data, grid.axes_coords[0] - 1)
23 |
24 | # test Poisson equation with 2nd Order BC
25 | res = solve_laplace_equation(grid, bc={"x-": {"value": -1}, "x+": "extrapolate"})
26 |
27 | # solve Poisson's equation
28 | grid = CartesianGrid([[0, 1]], 4)
29 | field = ScalarField(grid, data=1)
30 |
31 | res = solve_poisson_equation(
32 | field, bc={"x-": {"value": 1}, "x+": {"derivative": 1}}
33 | )
34 | xs = grid.axes_coords[0]
35 | np.testing.assert_allclose(res.data, 1 + 0.5 * xs**2, rtol=1e-2)
36 |
37 | # test an inconsistent problem: u'' = 1 with pure no-flux BCs violates the solvability condition (the source must integrate to zero)
38 | field.data = 1
39 | with pytest.raises(RuntimeError, match="Neumann"):
40 | solve_poisson_equation(field, {"derivative": 0})
41 |
42 |
43 | def test_pde_poisson_solver_2d():
44 | """Test the Poisson solver on 2d grids."""
45 | grid = CartesianGrid([[0, 2 * np.pi]] * 2, 16)
46 | bcs = {"x": {"value": "sin(y)"}, "y": {"value": "sin(x)"}}
47 |
48 | # solve Laplace's equation
49 | res = solve_laplace_equation(grid, bcs)
50 | xs = grid.cell_coords[..., 0]
51 | ys = grid.cell_coords[..., 1]
52 |
53 | # analytical solution was obtained with Mathematica
54 | expect = (
55 | np.cosh(np.pi - ys) * np.sin(xs) + np.cosh(np.pi - xs) * np.sin(ys)
56 | ) / np.cosh(np.pi)
57 | np.testing.assert_allclose(res.data, expect, atol=1e-2, rtol=1e-2)
58 |
59 | # test more complex case for exceptions
60 | bcs = {"x": {"value": "sin(y)"}, "y": {"curvature": "sin(x)"}}
61 | res = solve_laplace_equation(grid, bc=bcs)
62 |
--------------------------------------------------------------------------------
/tests/pdes/test_wave_pdes.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import PDE, ScalarField, UnitGrid, WavePDE
9 |
10 |
11 | @pytest.mark.parametrize("dim", [1, 2])
12 | def test_wave_consistency(dim, rng):
13 | """Test some methods of the wave model."""
14 | eq = WavePDE()
15 | assert isinstance(str(eq), str)
16 | assert isinstance(repr(eq), str)
17 |
18 | # compare numba to numpy implementation
19 | grid = UnitGrid([4] * dim)
20 | state = eq.get_initial_condition(ScalarField.random_uniform(grid, rng=rng))
21 | field = eq.evolution_rate(state)
22 | assert field.grid == grid
23 | rhs = eq._make_pde_rhs_numba(state)
24 | np.testing.assert_allclose(field.data, rhs(state.data, 0))
25 |
26 | # compare to generic implementation
27 | assert isinstance(eq.expressions, dict)
28 | eq2 = PDE(eq.expressions)
29 | np.testing.assert_allclose(field.data, eq2.evolution_rate(state).data)
30 |
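31 | # Background: WavePDE evolves the wave equation as a first-order system
32 | # (schematically u_t = v, v_t = c**2 * laplace(u)), which is why
33 | # get_initial_condition() augments the scalar field with an auxiliary field
34 | # and why eq.expressions is a dict with one expression per component.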
--------------------------------------------------------------------------------
/tests/requirements.txt:
--------------------------------------------------------------------------------
1 | -r ../requirements.txt
2 | docformatter>=1.7
3 | importlib-metadata>=5
4 | jupyter_contrib_nbextensions>=0.5
5 | mypy>=1.8
6 | notebook>=7
7 | pre-commit>=3
8 | pytest>=5.4
9 | pytest-cov>=2.8
10 | pytest-xdist>=1.30
11 | ruff>=0.6
12 | utilitiez>=0.3
13 |
--------------------------------------------------------------------------------
/tests/requirements_full.txt:
--------------------------------------------------------------------------------
1 | # These are the full requirements used to test all functions
2 | ffmpeg-python>=0.2
3 | h5py>=2.10
4 | ipywidgets>=8
5 | matplotlib>=3.1
6 | numba>=0.59
7 | numpy>=1.22
8 | pandas>=2
9 | py-modelrunner>=0.19
10 | rocket-fft>=0.2.4
11 | scipy>=1.10
12 | sympy>=1.9
13 | tqdm>=4.66
14 |
--------------------------------------------------------------------------------
/tests/requirements_min.txt:
--------------------------------------------------------------------------------
1 | # These are the minimal requirements used to test compatibility
2 | matplotlib~=3.1
3 | numba~=0.59
4 | numpy~=1.22
5 | scipy~=1.10
6 | sympy~=1.9
7 | tqdm~=4.66
8 |
--------------------------------------------------------------------------------
/tests/requirements_mpi.txt:
--------------------------------------------------------------------------------
1 | # These are requirements used to test multiprocessing
2 | h5py>=2.10
3 | matplotlib>=3.1
4 | mpi4py>=3
5 | numba>=0.59
6 | numba-mpi>=0.22
7 | numpy>=1.22
8 | pandas>=2
9 | scipy>=1.10
10 | sympy>=1.9
11 | tqdm>=4.66
12 |
--------------------------------------------------------------------------------
/tests/resources/run_pde.py:
--------------------------------------------------------------------------------
1 | import pde
2 |
3 |
4 | def run_pde(t_range, storage):
5 | """Run a PDE and store its trajectory."""
6 | field = pde.ScalarField.random_uniform(pde.UnitGrid([8, 8]))
7 | storage["initial_state"] = field
8 |
9 | eq = pde.DiffusionPDE()
10 | result = eq.solve(
11 | field,
12 | t_range=t_range,
13 | dt=0.1,
14 | backend="numpy",
15 | tracker=pde.ModelrunnerStorage(storage).tracker(1),
16 | )
17 | return {"field": result}
18 |
--------------------------------------------------------------------------------
/tests/solvers/test_adams_bashforth_solver.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 |
7 | import pde
8 |
9 |
10 | def test_adams_bashforth():
11 | """Test the Adams-Bashforth solver."""
12 | eq = pde.PDE({"y": "y"})
13 | state = pde.ScalarField(pde.UnitGrid([1]), 1)
14 | storage = pde.MemoryStorage()
15 | eq.solve(
16 | state,
17 | t_range=2.1,
18 | dt=0.5,
19 | solver="adams-bashforth",
20 | tracker=storage.tracker(0.5),
21 | )
22 | np.testing.assert_allclose(
23 | np.ravel([f.data for f in storage]),
24 | [1, 13 / 8, 83 / 32, 529 / 128, 3371 / 512, 21481 / 2048],
25 | )
26 |
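27 | # Where the reference values come from, for dy/dt = y, y0 = 1, dt = 1/2: the
28 | # two-step Adams-Bashforth rule is
29 | #   y[n+1] = y[n] + dt * (3/2 * f(y[n]) - 1/2 * f(y[n-1])),
30 | # and the values are consistent with bootstrapping from a fictitious previous
31 | # state y[-1] = y[0] - dt * f(y[0]), which gives
32 | #   y[1] = 1 + 1/2 * (3/2 * 1 - 1/2 * 1/2) = 13/8,
33 | #   y[2] = 13/8 + 1/2 * (3/2 * 13/8 - 1/2 * 1) = 83/32, and so on.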
--------------------------------------------------------------------------------
/tests/solvers/test_controller.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import PDEBase, ScalarField, UnitGrid
9 | from pde.solvers import Controller
10 |
11 |
12 | def test_controller_abort():
13 | """Test how the controller deals with errors."""
14 |
15 | class ErrorPDEException(RuntimeError): ...
16 |
17 | class ErrorPDE(PDEBase):
18 | def evolution_rate(self, state, t):
19 | if t < 1:
20 | return 0 * state
21 | else:
22 | raise ErrorPDEException
23 |
24 | field = ScalarField(UnitGrid([16]), 1)
25 | eq = ErrorPDE()
26 |
27 | with pytest.raises(ErrorPDEException):
28 | eq.solve(field, t_range=2, dt=0.2, backend="numpy")
29 |
30 | assert eq.diagnostics["last_tracker_time"] >= 0
31 | assert eq.diagnostics["last_state"] == field
32 |
33 |
34 | def test_controller_foreign_solver():
35 | """Test whether the Controller can deal with a minimal foreign solver."""
36 |
37 | class MySolver:
38 | def make_stepper(self, state, dt):
39 | def stepper(state, t, t_break):
40 | return t_break
41 |
42 | return stepper
43 |
44 | c = Controller(MySolver(), t_range=1)
45 | res = c.run(np.arange(3))
46 | np.testing.assert_allclose(res, np.arange(3))
47 |
--------------------------------------------------------------------------------
/tests/solvers/test_explicit_mpi_solvers.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import PDE, DiffusionPDE, FieldCollection, ScalarField, UnitGrid
9 | from pde.solvers import Controller, ExplicitMPISolver
10 | from pde.tools import mpi
11 |
12 |
13 | @pytest.mark.multiprocessing
14 | @pytest.mark.parametrize("backend", ["numpy", "numba"])
15 | @pytest.mark.parametrize(
16 | "scheme, adaptive, decomposition",
17 | [
18 | ("euler", False, "auto"),
19 | ("euler", True, [1, -1]),
20 | ("runge-kutta", True, [-1, 1]),
21 | ],
22 | )
23 | def test_simple_pde_mpi(backend, scheme, adaptive, decomposition, rng):
24 | """Test solving a simple PDE with the explicit MPI solver."""
25 | grid = UnitGrid([8, 8], periodic=[True, False])
26 |
27 | field = ScalarField.random_uniform(grid, rng=rng)
28 | eq = DiffusionPDE()
29 |
30 | args = {
31 | "state": field,
32 | "t_range": 1.01,
33 | "dt": 0.1,
34 | "adaptive": adaptive,
35 | "scheme": scheme,
36 | "tracker": None,
37 | "ret_info": True,
38 | }
39 | res_mpi, info_mpi = eq.solve(
40 | backend=backend, solver="explicit_mpi", decomposition=decomposition, **args
41 | )
42 |
43 | if mpi.is_main:
44 | # check results in the main process
45 | expect, info2 = eq.solve(backend="numpy", solver="explicit", **args)
46 | np.testing.assert_allclose(res_mpi.data, expect.data)
47 |
48 | assert info_mpi["solver"]["steps"] == info2["solver"]["steps"]
49 | assert info_mpi["solver"]["use_mpi"]
50 | if decomposition != "auto":
51 | for i in range(2):
52 | if decomposition[i] == 1:
53 | assert info_mpi["solver"]["grid_decomposition"][i] == 1
54 | else:
55 | assert info_mpi["solver"]["grid_decomposition"][i] == mpi.size
56 |
57 |
58 | @pytest.mark.multiprocessing
59 | @pytest.mark.parametrize("backend", ["numba", "numpy"])
60 | def test_stochastic_mpi_solvers(backend, rng):
61 | """Test a simple version of the stochastic MPI solver."""
62 | field = ScalarField.random_uniform(UnitGrid([16]), -1, 1, rng=rng)
63 | eq = DiffusionPDE()
64 | seq = DiffusionPDE(noise=1e-10)
65 |
66 | solver1 = ExplicitMPISolver(eq, backend=backend)
67 | c1 = Controller(solver1, t_range=1, tracker=None)
68 | s1 = c1.run(field, dt=1e-3)
69 |
70 | solver2 = ExplicitMPISolver(seq, backend=backend)
71 | c2 = Controller(solver2, t_range=1, tracker=None)
72 | s2 = c2.run(field, dt=1e-3)
73 |
74 | if mpi.is_main:
75 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-4, atol=1e-4)
76 | assert not solver1.info["stochastic"]
77 | assert solver2.info["stochastic"]
78 |
79 | assert not solver1.info["dt_adaptive"]
80 | assert not solver2.info["dt_adaptive"]
81 |
82 |
83 | @pytest.mark.multiprocessing
84 | @pytest.mark.parametrize("backend", ["numpy", "numba"])
85 | def test_multiple_pdes_mpi(backend, rng):
86 | """Test solving coupled PDEs with the explicit MPI solver."""
87 | grid = UnitGrid([8, 8], periodic=[True, False])
88 |
89 | fields = FieldCollection.scalar_random_uniform(2, grid, rng=rng)
90 | eq = PDE({"a": "laplace(a) - b", "b": "laplace(b) + a"})
91 |
92 | args = {
93 | "state": fields,
94 | "t_range": 1.01,
95 | "dt": 0.1,
96 | "adaptive": True,
97 | "scheme": "euler",
98 | "tracker": None,
99 | "ret_info": True,
100 | }
101 | res_mpi, info_mpi = eq.solve(backend=backend, solver="explicit_mpi", **args)
102 |
103 | if mpi.is_main:
104 | # check results in the main process
105 | expect, info2 = eq.solve(backend="numpy", solver="explicit", **args)
106 | np.testing.assert_allclose(res_mpi.data, expect.data)
107 |
108 | assert info_mpi["solver"]["steps"] == info2["solver"]["steps"]
109 | assert info_mpi["solver"]["use_mpi"]
110 |
--------------------------------------------------------------------------------
/tests/solvers/test_generic_solvers.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import PDE, DiffusionPDE, FieldCollection, MemoryStorage, ScalarField, UnitGrid
9 | from pde.solvers import (
10 | AdamsBashforthSolver,
11 | Controller,
12 | CrankNicolsonSolver,
13 | ExplicitSolver,
14 | ImplicitSolver,
15 | ScipySolver,
16 | registered_solvers,
17 | )
18 | from pde.solvers.base import AdaptiveSolverBase
19 |
20 | SOLVER_CLASSES = [
21 | ExplicitSolver,
22 | ImplicitSolver,
23 | CrankNicolsonSolver,
24 | AdamsBashforthSolver,
25 | ScipySolver,
26 | ]
27 |
28 |
29 | def test_solver_registration():
30 | """Test solver registration."""
31 | solvers = registered_solvers()
32 | assert "explicit" in solvers
33 | assert "implicit" in solvers
34 | assert "crank-nicolson" in solvers
35 | assert "scipy" in solvers
36 |
37 |
38 | def test_solver_in_pde_class(rng):
39 | """Test whether solver instances can be used in pde instances."""
40 | field = ScalarField.random_uniform(UnitGrid([16, 16]), -1, 1, rng=rng)
41 | eq = DiffusionPDE()
42 | eq.solve(field, t_range=1, solver=ScipySolver, tracker=None)
43 |
44 |
45 | @pytest.mark.parametrize("solver_class", SOLVER_CLASSES)
46 | def test_compare_solvers(solver_class, rng):
47 | """Compare several solvers."""
48 | field = ScalarField.random_uniform(UnitGrid([8, 8]), -1, 1, rng=rng)
49 | eq = DiffusionPDE()
50 |
51 | # ground truth
52 | c1 = Controller(ExplicitSolver(eq, scheme="runge-kutta"), t_range=0.1, tracker=None)
53 | s1 = c1.run(field, dt=5e-3)
54 |
55 | c2 = Controller(solver_class(eq), t_range=0.1, tracker=None)
56 | with np.errstate(under="ignore"):
57 | s2 = c2.run(field, dt=5e-3)
58 |
59 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-2, atol=1e-2)
60 |
61 |
62 | @pytest.mark.parametrize("solver_class", SOLVER_CLASSES)
63 | @pytest.mark.parametrize("backend", ["numpy", "numba"])
64 | def test_solvers_complex(solver_class, backend):
65 | """Test solvers with a complex PDE."""
66 | r = FieldCollection.scalar_random_uniform(2, UnitGrid([3]), labels=["a", "b"])
67 | c = r["a"] + 1j * r["b"]
68 | assert c.is_complex
69 |
70 | # with c = a + i*b, dc/dt = -i*laplace(c) splits into da/dt = laplace(b), db/dt = -laplace(a)
71 | eq_c = PDE({"c": "-I * laplace(c)"})
72 | eq_r = PDE({"a": "laplace(b)", "b": "-laplace(a)"})
73 | res_r = eq_r.solve(r, t_range=1e-2, dt=1e-3, backend="numpy", tracker=None)
74 | exp_c = res_r[0].data + 1j * res_r[1].data
75 |
76 | solver = solver_class(eq_c, backend=backend)
77 | controller = Controller(solver, t_range=1e-2, tracker=None)
78 | res_c = controller.run(c, dt=1e-3)
79 | np.testing.assert_allclose(res_c.data, exp_c, rtol=1e-3, atol=1e-3)
80 |
81 |
82 | def test_basic_adaptive_solver():
83 | """Test basic adaptive solvers."""
84 | grid = UnitGrid([4])
85 | y0 = np.array([1e-3, 1e-3, 1e3, 1e3])
86 | field = ScalarField(grid, y0)
87 | eq = PDE({"c": "c"})
88 |
89 | dt = 0.1
90 |
91 | solver = AdaptiveSolverBase(eq, tolerance=1e-1)
92 | storage = MemoryStorage()
93 | controller = Controller(solver, t_range=10.1, tracker=storage.tracker(1.0))
94 | res = controller.run(field, dt=dt)
95 |
96 | np.testing.assert_allclose(res.data, y0 * np.exp(10.1), rtol=0.02)
97 | assert solver.info["steps"] != pytest.approx(10.1 / dt, abs=1)
98 | assert solver.info["dt_adaptive"]
99 | assert solver.info["dt_statistics"]["min"] < 0.0005
100 | assert np.allclose(storage.times, np.arange(11))
101 |
--------------------------------------------------------------------------------
/tests/solvers/test_implicit_solvers.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import PDE, DiffusionPDE, ScalarField, UnitGrid
9 | from pde.solvers import Controller, ImplicitSolver
10 | from pde.tools import mpi
11 |
12 |
13 | @pytest.mark.parametrize("backend", ["numpy", "numba"])
14 | def test_implicit_solvers_simple_fixed(backend):
15 | """Test implicit solvers."""
16 | grid = UnitGrid([4])
17 | xs = grid.axes_coords[0]
18 | field = ScalarField.from_expression(grid, "x")
19 | eq = PDE({"c": "c"})
20 |
21 | dt = 0.01
22 | solver = ImplicitSolver(eq, backend=backend)
23 | controller = Controller(solver, t_range=10.0, tracker=None)
24 | res = controller.run(field, dt=dt)
25 |
26 | if mpi.is_main:
27 | np.testing.assert_allclose(res.data, xs * np.exp(10), rtol=0.1)
28 | assert solver.info["steps"] == pytest.approx(10 / dt, abs=1)
29 | assert not solver.info.get("dt_adaptive", False)
30 |
31 |
32 | @pytest.mark.parametrize("backend", ["numpy", "numba"])
33 | def test_implicit_stochastic_solvers(backend, rng):
34 | """Test a simple version of the stochastic implicit solver."""
35 | field = ScalarField.random_uniform(UnitGrid([16]), -1, 1, rng=rng)
36 | eq = DiffusionPDE()
37 | seq = DiffusionPDE(noise=1e-10)
38 |
39 | solver1 = ImplicitSolver(eq, backend=backend)
40 | c1 = Controller(solver1, t_range=1, tracker=None)
41 | s1 = c1.run(field, dt=1e-3)
42 |
43 | solver2 = ImplicitSolver(seq, backend=backend)
44 | c2 = Controller(solver2, t_range=1, tracker=None)
45 | s2 = c2.run(field, dt=1e-3)
46 |
47 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-4, atol=1e-4)
48 | assert not solver1.info["stochastic"]
49 | assert solver2.info["stochastic"]
50 |
51 | assert not solver1.info.get("dt_adaptive", False)
52 | assert not solver2.info.get("dt_adaptive", False)
53 |
--------------------------------------------------------------------------------
/tests/solvers/test_scipy_solvers.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 |
7 | from pde import PDE, DiffusionPDE, FieldCollection, ScalarField, UnitGrid
8 | from pde.solvers import Controller, ScipySolver
9 |
10 |
11 | def test_scipy_no_dt(rng):
12 | """Test scipy solver without timestep."""
13 | grid = UnitGrid([16])
14 | field = ScalarField.random_uniform(grid, -1, 1, rng=rng)
15 | eq = DiffusionPDE()
16 |
17 | c1 = Controller(ScipySolver(eq), t_range=1, tracker=None)
18 | s1 = c1.run(field, dt=1e-3)
19 |
20 | c2 = Controller(ScipySolver(eq), t_range=1, tracker=None)
21 | s2 = c2.run(field)
22 |
23 | np.testing.assert_allclose(s1.data, s2.data, rtol=1e-3, atol=1e-3)
24 |
25 |
26 | def test_scipy_field_collection():
27 | """Test scipy solver with field collection."""
28 | grid = UnitGrid([2])
29 | field = FieldCollection.from_scalar_expressions(grid, ["x", "0"])
30 | eq = PDE({"a": "1", "b": "a"})
31 |
32 | res = eq.solve(field, t_range=1, dt=1e-2, solver="scipy")
33 | np.testing.assert_allclose(res.data, np.array([[1.5, 2.5], [1.0, 2.0]]))
34 |
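35 | # The expected values follow analytically: with a' = 1, b' = a and initial
36 | # data a(0) = x, b(0) = 0, the solution is a(t) = x + t and
37 | # b(t) = x*t + t**2/2; at t = 1 on the cell centers x = [0.5, 1.5] this
38 | # gives a = [1.5, 2.5] and b = [1.0, 2.0], the array asserted above.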
--------------------------------------------------------------------------------
/tests/storage/resources/empty.avi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/empty.avi
--------------------------------------------------------------------------------
/tests/storage/resources/no_metadata.avi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/no_metadata.avi
--------------------------------------------------------------------------------
/tests/storage/resources/storage_1.avi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_1.avi
--------------------------------------------------------------------------------
/tests/storage/resources/storage_1.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_1.hdf5
--------------------------------------------------------------------------------
/tests/storage/resources/storage_2.avi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_2.avi
--------------------------------------------------------------------------------
/tests/storage/resources/storage_2.avi.times:
--------------------------------------------------------------------------------
1 | 0.1
2 | 0.7
3 | 2.9
4 |
--------------------------------------------------------------------------------
/tests/storage/resources/storage_2.hdf5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zwicker-group/py-pde/748e801f30d6e6b383bbcc5d918c5bd3b6d563ea/tests/storage/resources/storage_2.hdf5
--------------------------------------------------------------------------------
/tests/storage/test_memory_storages.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde import MemoryStorage, UnitGrid
9 | from pde.fields import FieldCollection, ScalarField, Tensor2Field, VectorField
10 |
11 |
12 | def test_memory_storage():
13 | """Test methods specific to memory storage."""
14 | sf = ScalarField(UnitGrid([1]))
15 | s1 = MemoryStorage()
16 | s1.start_writing(sf)
17 | sf.data = 0
18 | s1.append(sf, 0)
19 | sf.data = 2
20 | s1.append(sf, 1)
21 |
22 | s2 = MemoryStorage()
23 | s2.start_writing(sf)
24 | sf.data = 1
25 | s2.append(sf, 0)
26 | sf.data = 3
27 | s2.append(sf, 1)
28 |
29 | # test from_fields
30 | s3 = MemoryStorage.from_fields(s1.times, [s1[0], s1[1]])
31 | assert s3.times == s1.times
32 | np.testing.assert_allclose(s3.data, s1.data)
33 |
34 | # test from_collection
35 | s3 = MemoryStorage.from_collection([s1, s2])
36 | assert s3.times == s1.times
37 | np.testing.assert_allclose(np.ravel(s3.data), np.arange(4))
38 |
39 |
40 | @pytest.mark.parametrize("cls", [ScalarField, VectorField, Tensor2Field])
41 | def test_field_type_guessing_fields(cls, rng):
42 | """Test the ability to guess the field type."""
43 | grid = UnitGrid([3])
44 | field = cls.random_normal(grid, rng=rng)
45 | s = MemoryStorage()
46 | s.start_writing(field)
47 | s.append(field, 0)
48 | s.append(field, 1)
49 |
50 | # delete information
51 | s._field = None
52 | s.info = {}
53 |
54 | assert not s.has_collection
55 | assert len(s) == 2
56 | assert s[0] == field
57 |
58 |
59 | def test_field_type_guessing_collection(rng):
60 | """Test the ability to guess the field type of a collection."""
61 | grid = UnitGrid([3])
62 | field = FieldCollection([ScalarField(grid), VectorField(grid)])
63 | s = MemoryStorage()
64 | s.start_writing(field)
65 | s.append(field, 0)
66 |
67 | assert s.has_collection
68 |
69 | # delete information
70 | s._field = None
71 | s.info = {}
72 |
73 | with pytest.raises(RuntimeError):
74 | s[0]
75 |
--------------------------------------------------------------------------------
/tests/storage/test_modelrunner_storages.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | import pde
9 | from pde.tools.misc import module_available
10 |
11 |
12 | @pytest.mark.skipif(
13 | not module_available("modelrunner"), reason="requires `py-modelrunner` package"
14 | )
15 | def test_storage_write_trajectory(tmp_path):
16 | """Test simple storage writing."""
17 | import modelrunner as mr
18 |
19 | path = tmp_path / "storage.json"
20 | storage = mr.open_storage(path, mode="truncate")
21 |
22 | field = pde.ScalarField.random_uniform(pde.UnitGrid([8, 8]))
23 | eq = pde.DiffusionPDE()
24 | eq.solve(
25 | field,
26 | t_range=2.5,
27 | dt=0.1,
28 | backend="numpy",
29 | tracker=pde.ModelrunnerStorage(storage).tracker(1),
30 | )
31 |
32 | assert path.is_file()
33 | assert len(pde.ModelrunnerStorage(path)) == 3
34 | np.testing.assert_allclose(pde.ModelrunnerStorage(path).times, [0, 1, 2])
35 |
--------------------------------------------------------------------------------
/tests/test_examples.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import os
6 | import subprocess as sp
7 | import sys
8 | from pathlib import Path
9 |
10 | import pytest
11 |
12 | from pde.tools.misc import module_available
13 | from pde.visualization.movies import Movie
14 |
15 | PACKAGE_PATH = Path(__file__).resolve().parents[1]
16 | EXAMPLES = (PACKAGE_PATH / "examples").glob("*/*.py")
17 | NOTEBOOKS = (PACKAGE_PATH / "examples").glob("*/*.ipynb")
18 |
19 | SKIP_EXAMPLES: list[str] = []
20 | if not Movie.is_available():
21 | SKIP_EXAMPLES.extend(["make_movie_live.py", "make_movie_storage.py", "storages.py"])
22 | if not module_available("mpi4py"):
23 | SKIP_EXAMPLES.extend(["mpi_parallel_run"])
24 | if not module_available("napari"):
25 | SKIP_EXAMPLES.extend(["tracker_interactive", "show_3d_field_interactively"])
26 | if not module_available("h5py"):
27 | SKIP_EXAMPLES.extend(["trajectory_io"])
28 | if not all(module_available(m) for m in ["modelrunner", "h5py"]):
29 | SKIP_EXAMPLES.extend(["py_modelrunner"])
30 | if not module_available("utilitiez"):
31 | SKIP_EXAMPLES.extend(["logarithmic_kymograph"])
32 |
33 |
34 | @pytest.mark.slow
35 | @pytest.mark.no_cover
36 | @pytest.mark.skipif(sys.platform == "win32", reason="Assumes unix setup")
37 | @pytest.mark.parametrize("path", EXAMPLES)
38 | def test_example_scripts(path):
39 | """Run the example script given by path."""
40 | # check whether this test needs to be run
41 | if path.name.startswith("_"):
42 | pytest.skip("Skip examples starting with an underscore")
43 | if any(name in str(path) for name in SKIP_EXAMPLES):
44 | pytest.skip(f"Skip test {path}")
45 |
46 | # run the actual test in a separate python process
47 | env = os.environ.copy()
48 | env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + env.get("PYTHONPATH", "")
49 | proc = sp.Popen([sys.executable, path], env=env, stdout=sp.PIPE, stderr=sp.PIPE)
50 | try:
51 | outs, errs = proc.communicate(timeout=30)
52 | except sp.TimeoutExpired:
53 | proc.kill()
54 | outs, errs = proc.communicate()
55 |
56 | # delete files that might be created by the test
57 | try:
58 | (PACKAGE_PATH / "diffusion.mov").unlink()
59 | (PACKAGE_PATH / "allen_cahn.avi").unlink()
60 | (PACKAGE_PATH / "allen_cahn.hdf").unlink()
61 | except OSError:
62 | pass
63 |
64 | # prepare output
65 | msg = f"Script `{path}` failed with the following output:"
66 | if outs:
67 | msg = f"{msg}\nSTDOUT:\n{outs}"
68 | if errs:
69 | msg = f"{msg}\nSTDERR:\n{errs}"
70 | assert proc.returncode <= 0, msg  # negative codes mean the process was killed by a signal
71 |
72 |
73 | @pytest.mark.slow
74 | @pytest.mark.no_cover
75 | @pytest.mark.skipif(not module_available("h5py"), reason="requires `h5py`")
76 | @pytest.mark.skipif(not module_available("jupyter"), reason="requires `jupyter`")
77 | @pytest.mark.skipif(not module_available("notebook"), reason="requires `notebook`")
78 | @pytest.mark.parametrize("path", NOTEBOOKS)
79 | def test_jupyter_notebooks(path, tmp_path):
80 | """Run the Jupyter notebooks."""
81 | import notebook as jupyter_notebook
82 |
83 | if int(jupyter_notebook.__version__.split(".")[0]) < 7:
84 | raise RuntimeError("The notebook package must be at least version 7")
85 |
86 | if path.name.startswith("_"):
87 | pytest.skip("Skip examples starting with an underscore")
88 |
89 | # adjust python environment
90 | my_env = os.environ.copy()
91 | my_env["PYTHONPATH"] = str(PACKAGE_PATH) + ":" + my_env.get("PYTHONPATH", "")
92 |
93 | # run the notebook
94 | sp.check_call([sys.executable, "-m", "jupyter", "execute", path], env=my_env)
95 |
--------------------------------------------------------------------------------
/tests/tools/test_config.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import pytest
6 |
7 | from pde.tools.config import Config, environment, packages_from_requirements
8 |
9 |
10 | def test_environment():
11 | """Test the environment function."""
12 | assert isinstance(environment(), dict)
13 |
14 |
15 | def test_config():
16 | """Test configuration system."""
17 | c = Config()
18 |
19 | assert c["numba.multithreading_threshold"] > 0
20 |
21 | assert "numba.multithreading_threshold" in c
22 | assert any(k == "numba.multithreading_threshold" for k in c)
23 | assert any(k == "numba.multithreading_threshold" and v > 0 for k, v in c.items())
24 | assert "numba.multithreading_threshold" in c.to_dict()
25 | assert isinstance(repr(c), str)
26 |
27 |
28 | def test_config_modes():
29 | """Test configuration system running in different modes."""
30 | c = Config(mode="insert")
31 | assert c["numba.multithreading_threshold"] > 0
32 | c["numba.multithreading_threshold"] = 0
33 | assert c["numba.multithreading_threshold"] == 0
34 | c["new_value"] = "value"
35 | assert c["new_value"] == "value"
36 | c.update({"new_value2": "value2"})
37 | assert c["new_value2"] == "value2"
38 | del c["new_value"]
39 | with pytest.raises(KeyError):
40 | c["new_value"]
41 | with pytest.raises(KeyError):
42 | c["undefined"]
43 |
44 | c = Config(mode="update")
45 | assert c["numba.multithreading_threshold"] > 0
46 | c["numba.multithreading_threshold"] = 0
47 |
48 | with pytest.raises(KeyError):
49 | c["new_value"] = "value"
50 | with pytest.raises(KeyError):
51 | c.update({"new_value": "value"})
52 | with pytest.raises(RuntimeError):
53 | del c["numba.multithreading_threshold"]
54 | with pytest.raises(KeyError):
55 | c["undefined"]
56 |
57 | c = Config(mode="locked")
58 | assert c["numba.multithreading_threshold"] > 0
59 | with pytest.raises(RuntimeError):
60 | c["numba.multithreading_threshold"] = 0
61 | with pytest.raises(RuntimeError):
62 | c.update({"numba.multithreading_threshold": 0})
63 | with pytest.raises(RuntimeError):
64 | c["new_value"] = "value"
65 | with pytest.raises(RuntimeError):
66 | del c["numba.multithreading_threshold"]
67 | with pytest.raises(KeyError):
68 | c["undefined"]
69 |
70 | c = Config(mode="undefined")
71 | assert c["numba.multithreading_threshold"] > 0
72 | with pytest.raises(ValueError):
73 | c["numba.multithreading_threshold"] = 0
74 | with pytest.raises(ValueError):
75 | c.update({"numba.multithreading_threshold": 0})
76 | with pytest.raises(RuntimeError):
77 | del c["numba.multithreading_threshold"]
78 |
79 | c = Config({"new_value": "value"}, mode="locked")
80 | assert c["new_value"] == "value"
81 |
82 |
83 | def test_config_contexts():
84 | """Test context manager temporarily changing configuration."""
85 | c = Config()
86 |
87 | assert c["numba.multithreading_threshold"] > 0
88 | with c({"numba.multithreading_threshold": 0}):
89 | assert c["numba.multithreading_threshold"] == 0
90 | with c({"numba.multithreading_threshold": 1}):
91 | assert c["numba.multithreading_threshold"] == 1
92 | assert c["numba.multithreading_threshold"] == 0
93 |
94 | assert c["numba.multithreading_threshold"] > 0
95 |
96 |
97 | def test_config_special_values():
98 | """Test special values of the configuration system."""
99 | c = Config()
100 | c["numba.multithreading"] = True
101 | assert c["numba.multithreading"] == "always"
102 | assert c.use_multithreading()
103 | c["numba.multithreading"] = False
104 | assert c["numba.multithreading"] == "never"
105 | assert not c.use_multithreading()
106 |
107 |
108 | def test_packages_from_requirements():
109 | """Test the packages_from_requirements function."""
110 | results = packages_from_requirements("file_not_existing")
111 | assert len(results) == 1
112 | assert "Could not open" in results[0]
113 | assert "file_not_existing" in results[0]
114 |
--------------------------------------------------------------------------------
/tests/tools/test_ffmpeg.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import pytest
6 |
7 | from pde.tools.ffmpeg import find_format
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "channels,bits_per_channel,result",
12 | [(1, 8, "gray"), (2, 7, "rgb24"), (3, 9, "gbrp16le"), (5, 8, None), (1, 17, None)],
13 | )
14 | def test_find_format(channels, bits_per_channel, result):
15 | """Test the find_format function."""
16 | assert find_format(channels, bits_per_channel) == result
17 |
--------------------------------------------------------------------------------
/tests/tools/test_math.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde.tools.math import OnlineStatistics, SmoothData1D
9 |
10 |
11 | def test_SmoothData1D(rng):
12 | """Test smoothing."""
13 | x = rng.uniform(0, 1, 500)
14 | xs = np.linspace(0, 1, 16)[1:-1]
15 |
16 | s = SmoothData1D(x, np.ones_like(x), 0.1)
17 | np.testing.assert_allclose(s(xs), 1)
18 | np.testing.assert_allclose(s.derivative(xs), 0, atol=1e-14)
19 |
20 | s = SmoothData1D(x, x, 0.02)
21 | np.testing.assert_allclose(s(xs), xs, atol=0.1)
22 | np.testing.assert_allclose(s.derivative(xs), 1, atol=0.3)
23 |
24 | s = SmoothData1D(x, np.sin(x), 0.05)
25 | np.testing.assert_allclose(s(xs), np.sin(xs), atol=0.1)
26 | np.testing.assert_allclose(s.derivative(xs)[1:-1], np.cos(xs)[1:-1], atol=0.3)
27 |
28 | assert -0.1 not in s
29 | assert x.min() in s
30 | assert 0.5 in s
31 | assert x.max() in s
32 | assert 1.1 not in s
33 |
34 | x = np.arange(3)
35 | y = [0, 1, np.nan]
36 | s = SmoothData1D(x, y)
37 | assert s(0.5) == pytest.approx(0.5)
38 |
39 |
40 | def test_online_statistics():
41 | """Test OnlineStatistics class."""
42 | stat = OnlineStatistics()
43 |
44 | stat.add(1)
45 | stat.add(2)
46 |
47 | assert stat.mean == pytest.approx(1.5)
48 | assert stat.std == pytest.approx(0.5)
49 | assert stat.var == pytest.approx(0.25)
50 | assert stat.count == 2
51 | assert stat.to_dict() == pytest.approx(
52 | {"min": 1, "max": 2, "mean": 1.5, "std": 0.5, "count": 2}
53 | )
54 |
55 | stat = OnlineStatistics()
56 | assert stat.to_dict()["count"] == 0
57 |
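58 | # Quick check of the expected statistics for the samples {1, 2}: mean = 1.5
59 | # and population variance ((1 - 1.5)**2 + (2 - 1.5)**2) / 2 = 0.25, i.e.
60 | # std = 0.5; the assertions above thus rely on the population (rather than
61 | # the sample) variance convention.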
--------------------------------------------------------------------------------
/tests/tools/test_misc.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import json
6 | import os
7 | from pathlib import Path
8 |
9 | import numpy as np
10 | import pytest
11 |
12 | from pde.tools import misc
13 |
14 |
15 | def test_ensure_directory_exists(tmp_path):
16 | """Test the ensure_directory_exists function."""
17 | # create temporary name
18 | path = tmp_path / "test_ensure_directory_exists"
19 | assert not path.exists()
20 | # create the folder
21 | misc.ensure_directory_exists(path)
22 | assert path.is_dir()
23 | # check that a second call has the same result
24 | misc.ensure_directory_exists(path)
25 | assert path.is_dir()
26 | # remove the folder again
27 | Path.rmdir(path)
28 | assert not path.exists()
29 |
30 |
31 | def test_preserve_scalars():
32 | """Test the preserve_scalars decorator."""
33 |
34 | class Test:
35 | @misc.preserve_scalars
36 | def meth(self, arr):
37 | return arr + 1
38 |
39 | t = Test()
40 |
41 | assert t.meth(1) == 2
42 | np.testing.assert_equal(t.meth(np.ones(2)), np.full(2, 2))
43 |
44 |
45 | def test_hybridmethod():
46 | """Test the hybridmethod decorator."""
47 |
48 | class Test:
49 | @misc.hybridmethod
50 | def method(cls):
51 | return "class"
52 |
53 | @method.instancemethod
54 | def method(self):
55 | return "instance"
56 |
57 | assert Test.method() == "class"
58 | assert Test().method() == "instance"
59 |
60 |
61 | def test_estimate_computation_speed():
62 | """Test estimate_computation_speed method."""
63 |
64 | def f(x):
65 | return 2 * x
66 |
67 | def g(x):
68 | return np.sin(x) * np.cos(x) ** 2
69 |
70 | assert misc.estimate_computation_speed(f, 1) > misc.estimate_computation_speed(g, 1)
71 |
72 |
73 | def test_classproperty():
74 | """Test classproperty decorator."""
75 |
76 | class Test:
77 | _value = 2
78 |
79 | @misc.classproperty
80 | def value(cls):
81 | return cls._value
82 |
83 | assert Test.value == 2
84 |
85 |
86 | @pytest.mark.skipif(not misc.module_available("h5py"), reason="requires `h5py` module")
87 | def test_hdf_write_attributes(tmp_path):
88 | """Test hdf_write_attributes function."""
89 | import h5py
90 |
91 | path = tmp_path / "test_hdf_write_attributes.hdf5"
92 |
93 | # test normal case
94 | data = {"a": 3, "b": "asd"}
95 | with h5py.File(path, "w") as hdf_file:
96 | misc.hdf_write_attributes(hdf_file, data)
97 | data2 = {k: json.loads(v) for k, v in hdf_file.attrs.items()}
98 |
99 | assert data == data2
100 | assert data is not data2
101 |
102 | # test silencing of problematic items
103 | with h5py.File(path, "w") as hdf_file:
104 | misc.hdf_write_attributes(hdf_file, {"a": 1, "b": object()})
105 | data2 = {k: json.loads(v) for k, v in hdf_file.attrs.items()}
106 | assert data2 == {"a": 1}
107 |
108 | # test raising problematic items
109 | with h5py.File(path, "w") as hdf_file, pytest.raises(TypeError):
110 | misc.hdf_write_attributes(
111 | hdf_file, {"a": object()}, raise_serialization_error=True
112 | )
113 |
--------------------------------------------------------------------------------
/tests/tools/test_mpi.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from pde.tools.mpi import mpi_allreduce, mpi_recv, mpi_send, rank, size
9 |
10 |
11 | @pytest.mark.multiprocessing
12 | def test_send_recv():
13 | """Test basic send and receive."""
14 | if size == 1:
15 | pytest.skip("requires multiple MPI processes")
16 |
17 | data = np.arange(5)
18 | if rank == 0:
19 | out = np.empty_like(data)
20 | mpi_recv(out, 1, 1)
21 | np.testing.assert_allclose(out, data)
22 | elif rank == 1:
23 | mpi_send(data, 0, 1)
24 |
25 |
26 | @pytest.mark.multiprocessing
27 | @pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"])
28 | def test_allreduce(operator, rng):
29 | """Test the MPI allreduce function with scalar data."""
30 | data = rng.uniform(size=size)
31 | result = mpi_allreduce(data[rank], operator=operator)
32 |
33 | if operator == "MAX":
34 | assert result == data.max()
35 | elif operator == "MIN":
36 | assert result == data.min()
37 | elif operator == "SUM":
38 | assert result == data.sum()
39 | else:
40 | raise NotImplementedError
41 |
42 |
43 | @pytest.mark.multiprocessing
44 | @pytest.mark.parametrize("operator", ["MAX", "MIN", "SUM"])
45 | def test_allreduce_array(operator, rng):
46 | """Test the MPI allreduce function with array data."""
47 | data = rng.uniform(size=(size, 3))
48 | result = mpi_allreduce(data[rank], operator=operator)
49 |
50 | if operator == "MAX":
51 | np.testing.assert_allclose(result, data.max(axis=0))
52 | elif operator == "MIN":
53 | np.testing.assert_allclose(result, data.min(axis=0))
54 | elif operator == "SUM":
55 | np.testing.assert_allclose(result, data.sum(axis=0))
56 | else:
57 | raise NotImplementedError
58 |
--------------------------------------------------------------------------------
/tests/tools/test_numba.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import numba
6 | import numpy as np
7 | import pytest
8 |
9 | from pde.tools.numba import (
10 | Counter,
11 | flat_idx,
12 | jit,
13 | make_array_constructor,
14 | numba_dict,
15 | numba_environment,
16 | )
17 |
18 |
19 | def test_environment():
20 | """Test the numba_environment function."""
21 | assert isinstance(numba_environment(), dict)
22 |
23 |
24 | def test_flat_idx():
25 | """Test flat_idx function."""
26 | # testing the numpy version
27 | assert flat_idx(2, 1) == 2
28 | assert flat_idx(np.arange(2), 1) == 1
29 | assert flat_idx(np.arange(4).reshape(2, 2), 1) == 1
30 |
31 | # testing the numba compiled version
32 | @jit
33 | def get_sparse_matrix_data(data):
34 | return flat_idx(data, 1)
35 |
36 | assert get_sparse_matrix_data(2) == 2
37 | assert get_sparse_matrix_data(np.arange(2)) == 1
38 | assert get_sparse_matrix_data(np.arange(4).reshape(2, 2)) == 1
39 |
40 |
41 | def test_counter():
42 | """Test Counter implementation."""
43 | c1 = Counter()
44 | assert int(c1) == 0
45 | assert c1 == 0
46 | assert str(c1) == "0"
47 |
48 | c1.increment()
49 | assert int(c1) == 1
50 |
51 | c1 += 2
52 | assert int(c1) == 3
53 |
54 | c2 = Counter(3)
55 | assert c1 is not c2
56 | assert c1 == c2
57 |
58 |
59 | @pytest.mark.parametrize(
60 | "arr", [np.arange(5), np.linspace(0, 1, 3), np.arange(12).reshape(3, 4)[1:, 2:]]
61 | )
62 | def test_make_array_constructor(arr):
63 | """Test implementation to create array."""
64 | constructor = jit(make_array_constructor(arr))
65 | arr2 = constructor()
66 | np.testing.assert_equal(arr, arr2)
67 | assert np.shares_memory(arr, arr2)
68 |
69 |
70 | def test_numba_dict():
71 | """Test numba_dict function."""
72 | cls = dict if numba.config.DISABLE_JIT else numba.typed.Dict
73 |
74 | # test empty dictionaries
75 | for d in [numba_dict(), numba_dict({})]:
76 | assert len(d) == 0
77 | assert isinstance(d, cls)
78 |
79 | # test initializing dictionaries in different ways
80 | for d in [
81 | numba_dict({"a": 1, "b": 2}),
82 | numba_dict(a=1, b=2),
83 | numba_dict({"a": 1}, b=2),
84 | numba_dict({"a": 1, "b": 3}, b=2),
85 | ]:
86 | assert isinstance(d, cls)
87 | assert len(d) == 2
88 | assert d["a"] == 1
89 | assert d["b"] == 2
90 |
91 | # test edge case
92 | d = numba_dict(data=1)
93 | assert d["data"] == 1
94 | assert isinstance(d, cls)
95 |
--------------------------------------------------------------------------------
/tests/tools/test_output.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | from pde.tools import output
6 |
7 |
8 | def test_progress_bars():
9 | """Test progress bars."""
10 | pb_cls = output.get_progress_bar_class()
11 | tot = 0
12 | for i in pb_cls(range(4)):
13 | tot += i
14 | assert tot == 6
15 |
16 |
17 | def test_in_jupyter_notebook():
18 | """Test the function in_jupyter_notebook."""
19 | assert isinstance(output.in_jupyter_notebook(), bool)
20 |
21 |
22 | def test_display_progress(capsys):
23 | """Test that display_progress writes to stderr, not stdout."""
24 | for _ in output.display_progress(range(2)):
25 | pass
26 | out, err = capsys.readouterr()
27 | assert out == ""
28 | assert len(err) > 0
29 |
--------------------------------------------------------------------------------
/tests/tools/test_parse_duration.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | from pde.tools.parse_duration import parse_duration
6 |
7 |
8 | def test_parse_duration():
9 | """Test the parse_duration function."""
10 |
11 | def p(value):
12 | return parse_duration(value).total_seconds()
13 |
14 | assert p("0") == 0
15 | assert p("1") == 1
16 | assert p("1:2") == 62
17 | assert p("1:2:3") == 3723
18 |
--------------------------------------------------------------------------------
/tests/tools/test_plotting_tools.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import matplotlib.pyplot as plt
6 | import matplotlib.testing.compare
7 | import pytest
8 |
9 | from pde.tools.plotting import add_scaled_colorbar, plot_on_axes, plot_on_figure
10 |
11 |
12 | def test_plot_on_axes(tmp_path):
13 | """Test the plot_on_axes decorator."""
14 |
15 | @plot_on_axes
16 | def plot(ax):
17 | ax.plot([0, 1], [0, 1])
18 |
19 | path = tmp_path / "test.png"
20 | plot(title="Test", filename=path)
21 | assert path.stat().st_size > 0
22 |
23 |
24 | def test_plot_on_figure(tmp_path):
25 | """Test the plot_on_figure decorator."""
26 |
27 | @plot_on_figure
28 | def plot(fig):
29 | ax1, ax2 = fig.subplots(1, 2)
30 | ax1.plot([0, 1], [0, 1])
31 | ax2.plot([0, 1], [0, 1])
32 |
33 | path = tmp_path / "test.png"
34 | plot(title="Test", filename=path)
35 | assert path.stat().st_size > 0
36 |
37 |
38 | @pytest.mark.interactive
39 | def test_plot_colorbar(tmp_path, rng):
40 | """Test the add_scaled_colorbar helper."""
41 | data = rng.normal(size=(3, 3))
42 |
43 | # do not specify axis
44 | img = plt.imshow(data)
45 | add_scaled_colorbar(img, label="Label")
46 | plt.savefig(tmp_path / "img1.png")
47 | plt.clf()
48 |
49 | # specify axis explicitly
50 | ax = plt.gca()
51 | img = ax.imshow(data)
52 | add_scaled_colorbar(img, ax=ax, label="Label")
53 | plt.savefig(tmp_path / "img2.png")
54 |
55 | # compare the two results
56 | cmp = matplotlib.testing.compare.compare_images(
57 | str(tmp_path / "img1.png"), str(tmp_path / "img2.png"), tol=0.1
58 | )
59 | assert cmp is None
60 |
--------------------------------------------------------------------------------
/tests/visualization/test_movies.py:
--------------------------------------------------------------------------------
1 | """
2 | .. codeauthor:: David Zwicker
3 | """
4 |
5 | import pytest
6 |
7 | from pde.fields import ScalarField
8 | from pde.grids import UnitGrid
9 | from pde.pdes import DiffusionPDE
10 | from pde.storage import MemoryStorage
11 | from pde.visualization import movies
12 |
13 |
14 | @pytest.mark.skipif(not movies.Movie.is_available(), reason="no ffmpeg")
15 | def test_movie_class(tmp_path):
16 | """Test Movie class."""
17 | import matplotlib.pyplot as plt
18 |
19 | path = tmp_path / "test_movie.mov"
20 |
21 | try:
22 | with movies.Movie(path) as movie:
23 | # iterate over all time steps
24 | plt.plot([0, 1], [0, 1])
25 | movie.add_figure()
26 | movie.add_figure()
27 |
28 | # save movie
29 | movie.save()
30 | except RuntimeError:
31 | pass # can happen when ffmpeg is not installed
32 | else:
33 | assert path.stat().st_size > 0
34 |
35 |
36 | @pytest.mark.skipif(not movies.Movie.is_available(), reason="no ffmpeg")
37 | @pytest.mark.parametrize("movie_func", [movies.movie_scalar, movies.movie])
38 | def test_movie_scalar(movie_func, tmp_path, rng):
39 | """Test the movie_scalar and movie functions."""
40 | # create some data
41 | state = ScalarField.random_uniform(UnitGrid([4, 4]), rng=rng)
42 | eq = DiffusionPDE()
43 | storage = MemoryStorage()
44 | tracker = storage.tracker(interrupts=1)
45 | eq.solve(state, t_range=2, dt=1e-2, backend="numpy", tracker=tracker)
46 |
47 | # check creating the movie
48 | path = tmp_path / "test_movie.mov"
49 |
50 | try:
51 | movie_func(storage, filename=path, progress=False)
52 | except RuntimeError:
53 | pass # can happen when ffmpeg is not installed
54 | else:
55 | assert path.stat().st_size > 0
56 |
57 |
58 | @pytest.mark.skipif(not movies.Movie.is_available(), reason="no ffmpeg")
59 | def test_movie_wrong_path(tmp_path):
60 | """Test whether there is a useful error message when path doesn't exist."""
61 | with pytest.raises(OSError):
62 | movies.Movie(tmp_path / "unavailable" / "test.mov")
63 |
--------------------------------------------------------------------------------