├── tests ├── __init__.py ├── test_cadet.py ├── test_duplicate_keys.py ├── test_parallelization.py ├── test_meta_class_install_path.py ├── common.py ├── test_install_path_settings.py ├── test_save_as_python.py ├── test_h5.py └── test_dll.py ├── cadet ├── __init__.py ├── cadet_dll_parameterprovider.py ├── runner.py ├── cadet_dll_utils.py ├── cadet.py └── h5.py ├── .gitignore ├── .github ├── dependabot.yml └── workflows │ ├── python-publish.yml │ └── pipeline.yml ├── LICENSE ├── .zenodo.json ├── pyproject.toml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /cadet/__init__.py: -------------------------------------------------------------------------------- 1 | name = "CADET-Python" 2 | 3 | __version__ = "1.1.0" 4 | 5 | from .h5 import H5 6 | from .cadet import Cadet 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /.vs 3 | *.h5 4 | *.csv 5 | /__pycache__ 6 | /CADET.egg-info 7 | /dist 8 | /build/lib/cadet 9 | /cadet/__pycache__ 10 | /CADET_Python.egg-info 11 | .idea 12 | *tmp* -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" # See documentation for possible values 4 | directory: "/" # Location of package manifests 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /tests/test_cadet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test general properties of Cadet. 
3 | """ 4 | 5 | 6 | import re 7 | import pytest 8 | from cadet import Cadet 9 | 10 | 11 | @pytest.mark.parametrize("use_dll", [True, False]) 12 | def test_version(use_dll): 13 | # Assuming Cadet has a method to set or configure the use of DLL 14 | cadet = Cadet(use_dll=use_dll) 15 | 16 | assert re.match(r"\d\.\d\.\d", cadet.version), "Version format should be X.X.X" 17 | 18 | 19 | if __name__ == '__main__': 20 | pytest.main([__file__]) 21 | -------------------------------------------------------------------------------- /tests/test_duplicate_keys.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | 5 | from cadet import Cadet 6 | 7 | 8 | @pytest.fixture 9 | def temp_cadet_file(): 10 | """ 11 | Create a new Cadet object for use in tests. 12 | """ 13 | model = Cadet() 14 | 15 | with tempfile.NamedTemporaryFile() as temp: 16 | model.filename = temp 17 | yield model 18 | 19 | 20 | def test_duplicate_keys(temp_cadet_file): 21 | """ 22 | Test that the Cadet class raises a KeyError exception when duplicate keys are set on it. 23 | """ 24 | with pytest.raises(KeyError): 25 | temp_cadet_file.root.input.foo = 1 26 | temp_cadet_file.root.input.Foo = 1 27 | 28 | temp_cadet_file.save() 29 | 30 | 31 | if __name__ == "__main__": 32 | pytest.main() 33 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /tests/test_parallelization.py: -------------------------------------------------------------------------------- 1 | from cadet import Cadet 2 | from joblib import Parallel, delayed 3 | import pytest 4 | from tests.test_dll import setup_model 5 | 6 | n_jobs = 2 7 | 8 | 9 | def run_simulation(model): 10 | model.save() 11 | data = model.run_simulation() 12 | model.delete_file() 13 | 14 | return data 15 | 16 | 17 | def test_parallelization_io(): 18 | model1 = Cadet() 19 | model1.root.input = {'model': 1} 20 | model1.filename = "sim_1.h5" 21 | model2 = Cadet() 22 | model2.root.input = {'model': 2} 23 | model2.filename = "sim_2.h5" 24 | 25 | models = [model1, model2] 26 | 27 | results_sequential = [run_simulation(model) for model in models] 28 | 29 | results_parallel = Parallel(n_jobs=n_jobs, verbose=0)( 30 | delayed(run_simulation)(model, ) for model in models 31 | ) 32 | assert results_sequential == results_parallel 33 | 34 | 35 | def test_parallelization_simulation(): 36 | models = [setup_model(Cadet.autodetect_cadet(), file_name=f"LWE_{i}.h5") for i in range(2)] 37 | 38 | results_sequential = [run_simulation(model) for model in models] 39 | 40 | results_parallel = Parallel(n_jobs=n_jobs, verbose=0)( 41 
| delayed(run_simulation)(model, ) for model in models 42 | ) 43 | assert results_sequential == results_parallel 44 | 45 | 46 | if __name__ == "__main__": 47 | pytest.main([__file__]) 48 | -------------------------------------------------------------------------------- /tests/test_meta_class_install_path.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pytest 3 | 4 | from cadet import Cadet 5 | 6 | 7 | """ These tests require two distinct CADET installations to compare between and should not run in the CI""" 8 | 9 | # Full path to cadet.dll or cadet.so, that is different from the system/conda cadet 10 | full_path_dll = Path("path/to/cadet") 11 | 12 | install_path_conda = Cadet.autodetect_cadet() 13 | 14 | 15 | @pytest.mark.local 16 | def test_meta_class(): 17 | if full_path_dll == Path("path/to/cadet"): 18 | raise ValueError("This test requires a secondary CADET installation. Please set the full_path_dll variable.") 19 | Cadet.cadet_path = full_path_dll 20 | assert Cadet.use_dll 21 | 22 | # With a path set in the meta class, the sim instance should not autodetect and use the meta class cadet path 23 | sim = Cadet() 24 | assert sim.use_dll 25 | assert sim.install_path == full_path_dll.parent.parent 26 | assert sim.cadet_dll_path == full_path_dll 27 | assert sim.cadet_cli_path.parent.parent == full_path_dll.parent.parent 28 | 29 | # With an install path given, the sim instance should use the given install path 30 | sim = Cadet(install_path=install_path_conda) 31 | assert sim.install_path == install_path_conda 32 | assert sim.cadet_dll_path.parent.parent == install_path_conda 33 | assert sim.cadet_cli_path.parent.parent == install_path_conda 34 | 35 | # Reset Path 36 | Cadet.cadet_path = None 37 | 38 | 39 | if __name__ == "__main__": 40 | pytest.main([__file__]) 41 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Copyright (c) 2019 - 2024, The CADET-Python Authors 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of the CADET project nor the 12 | names of its contributors may be used to endorse or promote products 13 | derived from this software without specific prior written permission. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY 19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | -------------------------------------------------------------------------------- /.zenodo.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "CADET-Python: Version 1.0.3", 3 | "upload_type": "software", 4 | "creators": [ 5 | { 6 | "name": "Heymann, William", 7 | "orcid": "0000-0002-5093-0797", 8 | "affiliation": "Forschungszentrum Jülich" 9 | }, 10 | { 11 | "name": "Schmölder, Johannes", 12 | "orcid": "0000-0003-0446-7209", 13 | "affiliation": "Forschungszentrum Jülich" 14 | }, 15 | { 16 | "name": "Jäpel, Ronald", 17 | "orcid": "0000-0002-4912-5176", 18 | "affiliation": "Forschungszentrum Jülich" 19 | }, 20 | { 21 | "name": "Lanzrath, Hannah", 22 | "orcid": "0000-0002-2675-9002", 23 | "affiliation": "Forschungszentrum Jülich" 24 | }, 25 | { 26 | "name": "Leweke, Samuel", 27 | "orcid": "0000-0001-9471-4511", 28 | "affiliation": "Forschungszentrum Jülich" 29 | }, 30 | { 31 | "name": "von Lieres, Eric", 32 | "orcid": "0000-0002-0309-8408", 33 | "affiliation": "Forschungszentrum Jülich" 34 | } 35 | ], 36 | "license": "BSD-3-Clause", 37 | "keywords": [ 38 | "modeling", 39 | "simulation", 40 | "biotechnology", 41 | "process", 42 | "chromatography", 43 | "CADET", 44 | "general rate model", 45 | "Python" 46 | ], 47 | "version": "1.0.3", 48 | "access_right": "open", 49 | "communities": [{"identifier": "open-source"}] 50 | } 51 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=69", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "CADET-Python" 10 | dynamic = ["version"] 11 | authors = [ 12 | { name = "William Heymann", email = "w.heymann@fz-juelich.de" }, 13 | { name = "Samuel Leweke", email = "s.leweke@fz-juelich.de" }, 14 | { name = "Johannes Schmölder", email = "j.schmoelder@fz-juelich.de" 
}, 15 | { name = "Ronald Jäpel", email = "r.jaepel@fz-juelich.de" }, 16 | ] 17 | description = "CADET-Python is a Python interface to the CADET-Core simulator" 18 | readme = "README.md" 19 | requires-python = ">=3.10" 20 | keywords = ["process modeling", "process optimization", "chromatography"] 21 | license = { text = "BSD-3-Clause" } 22 | classifiers = [ 23 | "Programming Language :: Python :: 3", 24 | "Operating System :: OS Independent", 25 | "License :: OSI Approved :: BSD License", 26 | "Intended Audience :: Science/Research", 27 | ] 28 | dependencies = [ 29 | "addict", 30 | "numpy", 31 | "h5py", 32 | "filelock", 33 | ] 34 | 35 | [project.optional-dependencies] 36 | testing = [ 37 | "pytest", 38 | "joblib" 39 | ] 40 | 41 | [project.urls] 42 | "homepage" = "https://github.com/cadet/CADET-Python" 43 | "Bug Tracker" = "https://github.com/cadet/CADET-Python/issues" 44 | 45 | [tool.setuptools.dynamic] 46 | version = { attr = "cadet.__version__" } 47 | 48 | [tool.setuptools.packages.find] 49 | include = ["cadet*"] 50 | 51 | [tool.ruff] 52 | src = ["cadet"] 53 | line-length = 88 54 | 55 | [tool.pytest.ini_options] 56 | testpaths = ["tests"] 57 | markers = [ 58 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 59 | "local: marks tests as only useful on local installs (deselect with '-m \"not local\"')", 60 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CADET-Python 2 | 3 | **CADET-Python** provides a file-based Python interface for **CADET-Core**, which must be installed separately. For this, please refer to the [installation instructions](https://cadet.github.io/master/getting_started/installation.html) and the [CADET-Core repository](https://github.com/cadet/CADET-Core). 
4 | 5 | The CADET-Python package simplifies access by mapping to the [CADET interface](https://cadet.github.io/master/interface/index.html#), **with all dataset names in lowercase**. 6 | 7 | ## Installation 8 | 9 | To install CADET-Python, use: 10 | 11 | ```bash 12 | pip install cadet-python 13 | ``` 14 | 15 | ## Usage Example 16 | 17 | The package includes two primary classes: 18 | 19 | - **`Cadet`**: The main class to configure and run simulations. 20 | - **`H5`**: A general-purpose HDF5 interface. 21 | 22 | ### Setting Up a Simulation 23 | 24 | As an example, consider setting the column porosity of unit 001. 25 | 26 | In the CADET interface, this parameter is located at: 27 | ``` 28 | /input/model/unit_001/COL_POROSITY 29 | ``` 30 | In CADET-Python, the same parameter is set using the lowercase path: 31 | ```python 32 | from cadet import Cadet 33 | 34 | # Initialize simulation 35 | sim = Cadet() 36 | 37 | # Set column porosity for unit 001 38 | sim.root.input.model.unit_001.col_porosity = 0.33 39 | ``` 40 | ### Saving the Simulation File 41 | 42 | Before running, save the simulation configuration to a file: 43 | ```python 44 | sim.filename = "/path/to/your/file.hdf5" 45 | sim.save() 46 | ``` 47 | ### Setting the Path to CADET 48 | 49 | By default, CADET-Python autodetects an installed **CADET-Core**. To use a specific installation instead, set the root directory of that installation via `install_path` (assigning `sim.cadet_path` is deprecated): 50 | ```python 51 | sim.install_path = '/path/to/cadet' 52 | ``` 53 | ### Running the Simulation and Loading Data 54 | 55 | Run the simulation and load the output data with: 56 | ```python 57 | print(sim.run()) 58 | sim.load() 59 | ``` 60 | ### Reading Data from a Pre-Simulated File 61 | 62 | If you have a pre-simulated file, you can read it directly: 63 | ```python 64 | # Initialize a new simulation object 65 | sim = Cadet() 66 | 67 | # Set the filename for the existing simulation data 68 | sim.filename = "/path/to/your/file.hdf5" 69 | sim.load() 70 | ``` 71 | At this point, any data in the file can be accessed.
72 | -------------------------------------------------------------------------------- /.github/workflows/pipeline.yml: -------------------------------------------------------------------------------- 1 | name: pipeline 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - dev 8 | pull_request: 9 | 10 | jobs: 11 | test-job: 12 | runs-on: ${{ matrix.os }} 13 | 14 | defaults: 15 | run: 16 | shell: bash -l {0} 17 | 18 | strategy: 19 | matrix: 20 | os: [ubuntu-latest] 21 | python-version: ["3.10", "3.11", "3.12"] 22 | include: 23 | - os: windows-latest 24 | python-version: "3.12" 25 | - os: macos-13 26 | python-version: "3.12" 27 | 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Get Date 32 | id: get-date 33 | run: echo "today=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_OUTPUT 34 | shell: bash 35 | 36 | - name: Setup Conda Environment 37 | uses: conda-incubator/setup-miniconda@v3 38 | with: 39 | miniforge-version: latest 40 | use-mamba: true 41 | activate-environment: cadet-python 42 | channels: conda-forge 43 | 44 | - name: Cache conda 45 | uses: actions/cache@v3 46 | env: 47 | # Increase this value to reset cache if environment.yml has not changed 48 | CACHE_NUMBER: 0 49 | with: 50 | path: ${{ env.CONDA }}/envs 51 | key: ${{ matrix.os }}-python_${{ matrix.python-version }}-${{ steps.get-date.outputs.today }}-${{ hashFiles(env.CONDA_FILE) }}-${{ env.CACHE_NUMBER }} 52 | id: cache 53 | 54 | - name: Set up python env 55 | run: | 56 | conda install python=${{ matrix.python-version }} 57 | conda run pip install . 58 | 59 | - name: Install pypa/build 60 | run: | 61 | conda run python -m pip install build --user 62 | 63 | - name: Build binary wheel and source tarball 64 | run: | 65 | conda run python -m build --sdist --wheel --outdir dist/ . 66 | 67 | - name: Test Wheel install and import 68 | run: | 69 | conda run python -c "import cadet; print(cadet.__version__)" 70 | cd .. 
71 | 72 | - name: Test with pytest 73 | run: | 74 | conda run pip install ".[testing]" 75 | conda install -c conda-forge "cadet>=5.0.3" 76 | pytest tests --rootdir=tests -m "not slow and not local" 77 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | from cadet import Cadet 2 | 3 | common = Cadet() 4 | 5 | root = common.root 6 | 7 | root.input.model.solver.gs_type = 1 8 | root.input.model.solver.max_krylov = 0 9 | root.input.model.solver.max_restarts = 10 10 | root.input.model.solver.schur_safety = 1e-8 11 | 12 | # CADET 3.1 and CADET-dev flags are in here so that it works with both 13 | root.input['return'].write_solution_times = 1 14 | root.input['return'].split_components_data = 1 15 | root.input['return'].unit_000.write_sens_bulk = 0 16 | root.input['return'].unit_000.write_sens_flux = 0 17 | root.input['return'].unit_000.write_sens_inlet = 1 18 | root.input['return'].unit_000.write_sens_outlet = 1 19 | root.input['return'].unit_000.write_sens_particle = 0 20 | root.input['return'].unit_000.write_solution_bulk = 0 21 | root.input['return'].unit_000.write_solution_flux = 0 22 | root.input['return'].unit_000.write_solution_inlet = 1 23 | root.input['return'].unit_000.write_solution_outlet = 1 24 | root.input['return'].unit_000.write_solution_particle = 0 25 | root.input['return'].unit_000.write_sens_column = 0 26 | root.input['return'].unit_000.write_sens_column_inlet = 1 27 | root.input['return'].unit_000.write_sens_column_outlet = 1 28 | root.input['return'].unit_000.write_solution_column = 0 29 | root.input['return'].unit_000.write_solution_column_inlet = 1 30 | root.input['return'].unit_000.write_solution_column_outlet = 1 31 | 32 | root.input['return'].unit_001 = root.input['return'].unit_000 33 | root.input['return'].unit_002 = root.input['return'].unit_000 34 | 35 | root.input.model.unit_001.discretization.par_disc_type =
'EQUIDISTANT_PAR' 36 | root.input.model.unit_001.discretization.schur_safety = 1.0e-8 37 | root.input.model.unit_001.discretization.use_analytic_jacobian = 1 38 | root.input.model.unit_001.discretization.weno.boundary_model = 0 39 | root.input.model.unit_001.discretization.weno.weno_eps = 1e-10 40 | root.input.model.unit_001.discretization.weno.weno_order = 3 41 | root.input.model.unit_001.discretization.gs_type = 1 42 | root.input.model.unit_001.discretization.max_krylov = 0 43 | root.input.model.unit_001.discretization.max_restarts = 10 44 | 45 | root.input.solver.time_integrator.abstol = 1e-10 46 | root.input.solver.time_integrator.algtol = 1e-12 47 | root.input.solver.time_integrator.init_step_size = 1e-6 48 | root.input.solver.time_integrator.max_steps = 1000000 49 | root.input.solver.time_integrator.reltol = 1e-10 50 | 51 | root.input.solver.nthreads = 0 52 | -------------------------------------------------------------------------------- /tests/test_install_path_settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pytest 5 | import re 6 | 7 | from cadet import Cadet 8 | 9 | """ These tests require two distinct CADET installations to compare between and should not run in the CI""" 10 | 11 | 12 | # Full path to cadet.dll or cadet.so, that is different from the system/conda cadet 13 | full_path_dll = Path("path/to/cadet") 14 | home_path_dll = Path("path/to/cadet") 15 | 16 | install_path_conda = Cadet.autodetect_cadet() 17 | 18 | 19 | def test_autodetection(): 20 | sim = Cadet() 21 | assert sim.install_path == install_path_conda 22 | assert sim.cadet_dll_path.parent.parent == install_path_conda 23 | assert sim.cadet_cli_path.parent.parent == install_path_conda 24 | assert sim.cadet_runner.cadet_path.suffix not in [".dll", ".so"] 25 | 26 | 27 | @pytest.mark.local 28 | def test_install_path(): 29 | if full_path_dll == Path("path/to/cadet"): 30 | raise 
ValueError("This test requires a secondary CADET installation. Please set the full_path_dll variable.") 31 | sim = Cadet(install_path=full_path_dll, use_dll=True) 32 | assert sim.cadet_dll_path == full_path_dll 33 | assert sim.cadet_runner.cadet_path.suffix in [".dll", ".so"] 34 | 35 | # Set root directory of CADET installation 36 | sim = Cadet() 37 | sim.install_path = full_path_dll.parent.parent 38 | sim.use_dll = True 39 | assert sim.cadet_dll_path == full_path_dll 40 | assert sim.cadet_runner.cadet_path.suffix in [".dll", ".so"] 41 | 42 | # Set root directory of CADET installation (with user home (`~`)) 43 | sim = Cadet() 44 | sim.install_path = home_path_dll.parent.parent 45 | sim.use_dll = True 46 | assert sim.cadet_dll_path == full_path_dll 47 | assert sim.cadet_runner.cadet_path.suffix in [".dll", ".so"] 48 | 49 | # Set cli/dll path (deprecated) 50 | sim = Cadet() 51 | with pytest.deprecated_call(): 52 | sim.cadet_path = full_path_dll.parent.parent 53 | 54 | sim.use_dll = True 55 | assert sim.cadet_dll_path == full_path_dll 56 | assert sim.cadet_runner.cadet_path.suffix in [".dll", ".so"] 57 | 58 | 59 | @pytest.mark.local 60 | def test_dll_runner_attrs(): 61 | if full_path_dll == Path("path/to/cadet"): 62 | raise ValueError("This test requires a secondary CADET installation. Please set the full_path_dll variable.") 63 | cadet = Cadet(full_path_dll.parent.parent) 64 | cadet_runner = cadet._cadet_dll_runner 65 | assert re.match(r"\d\.\d\.\d", cadet_runner.cadet_version) 66 | assert isinstance(cadet_runner.cadet_branch, str) 67 | assert isinstance(cadet_runner.cadet_build_type, str | None) 68 | assert isinstance(cadet_runner.cadet_commit_hash, str) 69 | assert isinstance(cadet_runner.cadet_path, str | os.PathLike) 70 | 71 | 72 | @pytest.mark.local 73 | def test_cli_runner_attrs(): 74 | if full_path_dll == Path("path/to/cadet"): 75 | raise ValueError("This test requires a secondary CADET installation. 
Please set the full_path_dll variable.") 76 | cadet = Cadet(full_path_dll.parent.parent) 77 | cadet_runner = cadet._cadet_cli_runner 78 | assert re.match(r"\d\.\d\.\d", cadet_runner.cadet_version) 79 | assert isinstance(cadet_runner.cadet_branch, str) 80 | assert isinstance(cadet_runner.cadet_build_type, str | None) 81 | assert isinstance(cadet_runner.cadet_commit_hash, str) 82 | assert isinstance(cadet_runner.cadet_path, str | os.PathLike) 83 | 84 | 85 | if __name__ == '__main__': 86 | pytest.main(["test_install_path_settings.py"]) 87 | -------------------------------------------------------------------------------- /tests/test_save_as_python.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import numpy as np 4 | import pytest 5 | from addict import Dict 6 | 7 | from cadet import Cadet 8 | 9 | 10 | @pytest.fixture 11 | def original_model(): 12 | """ 13 | Create a new Cadet object for use in tests. 14 | """ 15 | with tempfile.NamedTemporaryFile() as temp: 16 | model = Cadet().create_lwe(file_path=temp.name+".h5") 17 | model.run_simulation() 18 | yield model 19 | 20 | 21 | def test_save_as_python(original_model): 22 | """ 23 | Test saving and regenerating a Cadet model as Python code. 24 | 25 | Verifies that a Cadet model can be serialized to a Python script and 26 | accurately reconstructed by executing the generated script. This ensures 27 | that model parameters, including arrays and edge-case values, are preserved. 28 | 29 | Parameters 30 | ---------- 31 | original_model : Cadet 32 | A Cadet model instance to populate and serialize for testing. 33 | 34 | Raises 35 | ------ 36 | AssertionError 37 | If the regenerated model does not match the original model within 38 | a specified relative tolerance. 
39 | """ 40 | # initialize "model" variable to be overwritten by the exec lines later 41 | # it needs to be called "model", as that is the variable that the generated code overwrites 42 | model = Cadet() 43 | 44 | # Populate original_model with all tricky cases currently known 45 | original_model.root.input.foo = 1 46 | original_model.root.input.food = 1.9 47 | original_model.root.input.bar.baryon = np.arange(10) 48 | original_model.root.input.bar.barometer = np.linspace(0, 10, 9) 49 | original_model.root.input.bar.init_q = np.array([], dtype=np.float64) 50 | original_model.root.input.bar.init_qt = np.array([0., 0.0011666666666666668, 0.0023333333333333335]) 51 | original_model.root.input.bar.par_disc_type = np.array([b'EQUIDISTANT_PAR'], dtype='|S15') 52 | original_model.root.input["return"].split_foobar = 1 53 | 54 | code_lines = original_model.save_as_python_script( 55 | filename="temp.py", only_return_pythonic_representation=True 56 | ) 57 | 58 | # remove code lines that save the file 59 | code_lines = code_lines[:-2] 60 | 61 | # populate "model" using the generated code lines 62 | for line in code_lines: 63 | exec(line) 64 | 65 | # test that "model" is equal to "original_model" 66 | recursive_equality_check(original_model.root, model.root, rtol=1e-5) 67 | 68 | 69 | def recursive_equality_check(dict_a: dict, dict_b: dict, rtol=1e-5): 70 | """ 71 | Recursively compare two nested dictionaries for equality. 72 | 73 | Compares the keys and values of two dictionaries. If a value is a nested 74 | dictionary, the function recurses. NumPy arrays are compared using 75 | `np.testing.assert_allclose`, except for single-element byte-string arrays of dtype S15, 76 | whose elements are compared directly. 77 | 78 | Parameters 79 | ---------- 80 | dict_a : dict 81 | First dictionary to compare. 82 | dict_b : dict 83 | Second dictionary to compare. 84 | rtol : float, optional 85 | Relative tolerance for comparing NumPy arrays, by default 1e-5.
86 | 87 | Returns 88 | ------- 89 | bool 90 | True if the dictionaries are equal; otherwise, an assertion is raised. 91 | 92 | Raises 93 | ------ 94 | AssertionError 95 | If keys do not match, or values are not equal within the given tolerance. 96 | """ 97 | assert dict_a.keys() == dict_b.keys() 98 | for key in dict_a.keys(): 99 | value_a = dict_a[key] 100 | value_b = dict_b[key] 101 | if type(value_a) in (dict, Dict): 102 | recursive_equality_check(value_a, value_b) 103 | elif isinstance(value_a, np.ndarray): 104 | # This catches cases where strings are stored in arrays, and the dtype S15 causes numpy problems 105 | # which can happen if reading a simulation file back from an H5 file from disk 106 | if value_a.dtype == np.dtype("S15") and len(value_a) == 1 and len(value_b) == 1: 107 | assert value_a[0] == value_b[0] 108 | else: 109 | np.testing.assert_allclose(value_a, value_b, rtol=rtol) 110 | else: 111 | assert value_a == value_b 112 | return True 113 | 114 | 115 | if __name__ == "__main__": 116 | pytest.main([__file__]) 117 | -------------------------------------------------------------------------------- /tests/test_h5.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tempfile 3 | import numpy as np 4 | import json 5 | import os 6 | from pathlib import Path 7 | from addict import Dict 8 | import h5py 9 | from cadet import H5 10 | from cadet.h5 import recursively_save, recursively_load, convert_from_numpy, recursively_load_dict 11 | 12 | 13 | @pytest.fixture 14 | def h5_instance(): 15 | return H5({ 16 | "keyString": "value1", 17 | "keyInt": 42, 18 | "keyArray": np.array([1, 2, 3]), 19 | "keyNone": None, 20 | "keyDict": { 21 | "nestedKeyFloat": 12.345, 22 | "nestedKeyList": [1, 2, 3, 4], 23 | "nestedKeyNone": None, 24 | } 25 | }) 26 | 27 | 28 | @pytest.fixture 29 | def temp_h5_file(): 30 | with tempfile.NamedTemporaryFile(delete=False, suffix=".h5") as tmp: 31 | yield tmp.name 32 | 
os.remove(tmp.name) 33 | 34 | 35 | @pytest.fixture 36 | def temp_json_file(): 37 | with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp: 38 | yield tmp.name 39 | os.remove(tmp.name) 40 | 41 | 42 | def test_init(h5_instance): 43 | assert isinstance(h5_instance.root, Dict) 44 | assert h5_instance.root.keyString == "value1" 45 | assert h5_instance.root.keyInt == 42 46 | 47 | 48 | def test_save_and_load_h5(h5_instance, temp_h5_file): 49 | h5_instance.filename = temp_h5_file 50 | h5_instance.save() 51 | 52 | new_instance = H5() 53 | new_instance.filename = temp_h5_file 54 | new_instance.load_from_file() 55 | 56 | assert new_instance.root.keyString == b"value1" 57 | assert new_instance.root.keyInt == 42 58 | assert "keyNone" not in new_instance.root 59 | assert all(new_instance.root.keyDict["nestedKeyList"] == [1, 2, 3, 4]) 60 | assert "nestedKeyNone" not in new_instance.root.keyDict 61 | assert np.array_equal(new_instance.root.keyArray, h5_instance.root.keyArray) 62 | 63 | 64 | def test_save_and_load_json(h5_instance, temp_json_file): 65 | h5_instance.save_json(temp_json_file) 66 | 67 | new_instance = H5() 68 | new_instance.load_json(temp_json_file) 69 | 70 | assert new_instance.root.keyString == "value1" 71 | assert new_instance.root.keyInt == 42 72 | assert new_instance.root.keyArray == [1, 2, 3] 73 | 74 | 75 | def test_append_data(h5_instance, temp_h5_file): 76 | h5_instance.filename = temp_h5_file 77 | h5_instance.save() 78 | 79 | h5_instance["key4"] = "new_value" 80 | 81 | with pytest.raises(KeyError): 82 | # This correctly raises a KeyError because h5_instance still contains 83 | # e.g. 
keyString and .append would try to over-write keyString 84 | h5_instance.append() 85 | 86 | addition_h5_instance = H5() 87 | addition_h5_instance.filename = temp_h5_file 88 | 89 | addition_h5_instance["key4"] = "new_value" 90 | addition_h5_instance.append() 91 | 92 | new_instance = H5() 93 | new_instance.filename = temp_h5_file 94 | new_instance.load_from_file() 95 | 96 | assert new_instance.root.key4 == b"new_value" 97 | 98 | 99 | def test_update(h5_instance): 100 | other_instance = H5({"keyInt": 100, "key4": "added"}) 101 | h5_instance.update(other_instance) 102 | 103 | assert h5_instance.root.keyInt == 100 104 | assert h5_instance.root.key4 == "added" 105 | 106 | 107 | def test_recursively_save_and_load(h5_instance, temp_h5_file): 108 | data = Dict({"group": {"dataset": np.array([10, 20, 30])}}) 109 | 110 | with h5py.File(temp_h5_file, "w") as h5file: 111 | recursively_save(h5file, "/", data, lambda x: x) 112 | 113 | with h5py.File(temp_h5_file, "r") as h5file: 114 | loaded_data = recursively_load(h5file, "/", lambda x: x, None) 115 | 116 | assert np.array_equal(loaded_data["group"]["dataset"], np.array([10, 20, 30])) 117 | 118 | 119 | def test_transform_methods(): 120 | instance = H5() 121 | data = np.array([1, 2, 3]) 122 | 123 | transformed = instance.transform(data) 124 | inverse_transformed = instance.inverse_transform(transformed) 125 | 126 | assert np.array_equal(inverse_transformed, data) 127 | 128 | 129 | def test_convert_from_numpy(): 130 | data = Dict({"array": np.array([1, 2, 3]), "scalar": np.int32(10)}) 131 | converted = convert_from_numpy(data) 132 | 133 | assert converted["array"] == [1, 2, 3] 134 | assert converted["scalar"] == 10 135 | 136 | 137 | def test_recursively_load_dict(): 138 | data = {"nested": {"value": np.int32(42), "bytes": b"text"}} 139 | loaded = recursively_load_dict(data) 140 | 141 | assert loaded.nested.value == 42 142 | assert loaded.nested.bytes == "text" 143 | 144 | 145 | def test_get_set_item(h5_instance): 146 | 
h5_instance["key4"] = "test_value" 147 | assert h5_instance["key4"] == "test_value" 148 | 149 | h5_instance["nested/key5"] = 123 150 | assert h5_instance["nested/key5"] == 123 151 | 152 | 153 | def test_string_representation(h5_instance): 154 | representation = str(h5_instance) 155 | assert "Filename = None" in representation 156 | assert "keyString" in representation 157 | assert "keyInt" in representation 158 | 159 | 160 | def test_load_nonexistent_file(): 161 | instance = H5() 162 | instance.filename = "nonexistent_file.h5" 163 | with pytest.raises(OSError): 164 | instance.load_from_file() 165 | 166 | 167 | def test_save_without_filename(h5_instance): 168 | with pytest.raises(ValueError): 169 | h5_instance.save() 170 | 171 | 172 | def test_load_json_with_invalid_data(temp_json_file): 173 | invalid_data = "{invalid_json: true}" 174 | Path(temp_json_file).write_text(invalid_data) 175 | 176 | instance = H5() 177 | with pytest.raises(json.JSONDecodeError): 178 | instance.load_json(temp_json_file) 179 | -------------------------------------------------------------------------------- /cadet/cadet_dll_parameterprovider.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import cadet.cadet_dll_utils as utils 3 | import addict 4 | from typing import Any, Dict, Optional, Union 5 | 6 | 7 | c_cadet_result = ctypes.c_int 8 | array_double = ctypes.POINTER(ctypes.POINTER(ctypes.c_double)) 9 | point_int = ctypes.POINTER(ctypes.c_int) 10 | 11 | 12 | def null(*args: Any) -> None: 13 | pass 14 | 15 | 16 | if 0: 17 | log_print = print 18 | else: 19 | log_print = null 20 | 21 | 22 | class NestedDictReader: 23 | """ 24 | Utility class to read and navigate through nested dictionaries. 
25 | """ 26 | 27 | def __init__(self, data: Dict[str, Any]) -> None: 28 | self._root = data 29 | self._cursor = [data] 30 | self.buffer: Optional[Any] = None 31 | 32 | def push_scope(self, scope: str) -> bool: 33 | """ 34 | Enter a nested scope within the dictionary. 35 | 36 | Parameters 37 | ---------- 38 | scope : str 39 | The key representing the nested scope. 40 | 41 | Returns 42 | ------- 43 | bool 44 | True if the scope exists and was entered, otherwise False. 45 | """ 46 | if scope in self._cursor[-1]: 47 | log_print(f'Entering scope {scope}') 48 | self._cursor.append(self._cursor[-1][scope]) 49 | return True 50 | return False 51 | 52 | def pop_scope(self) -> None: 53 | """ 54 | Exit the current scope. 55 | """ 56 | if len(self._cursor) > 1: 57 | self._cursor.pop() 58 | log_print('Exiting scope') 59 | 60 | def current(self) -> Any: 61 | """ 62 | Get the current scope data. 63 | 64 | Returns 65 | ------- 66 | Any 67 | The current data under the scope. 68 | """ 69 | return self._cursor[-1] 70 | 71 | 72 | def recursively_convert_dict(data: Dict[str, Any]) -> addict.Dict: 73 | """ 74 | Recursively convert dictionary keys to uppercase while preserving nested structure. 75 | 76 | Parameters 77 | ---------- 78 | data : dict 79 | The dictionary to convert. 80 | 81 | Returns 82 | ------- 83 | addict.Dict 84 | A new dictionary with all keys converted to uppercase. 85 | """ 86 | ans = addict.Dict() 87 | for key_original, item in data.items(): 88 | if isinstance(item, dict): 89 | ans[key_original] = recursively_convert_dict(item) 90 | else: 91 | key = str.upper(key_original) 92 | ans[key] = item 93 | return ans 94 | 95 | 96 | class PARAMETERPROVIDER(ctypes.Structure): 97 | """ 98 | Implement the CADET Parameter Provider interface, allowing querying Python for parameters. 99 | 100 | This class exposes various function pointers as fields in a ctypes structure 101 | to be used with CADET's C-API. 
102 | 103 | Parameters 104 | ---------- 105 | simulation : Cadet 106 | The simulation object containing the input data. 107 | """ 108 | 109 | def __init__(self, simulation: "Cadet") -> None: 110 | sim_input = recursively_convert_dict(simulation.root.input) 111 | self.userData = NestedDictReader(sim_input) 112 | 113 | # Assign function pointers at instance level 114 | self.getDouble = self._fields_[1][1](utils.param_provider_get_double) 115 | self.getInt = self._fields_[2][1](utils.param_provider_get_int) 116 | self.getBool = self._fields_[3][1](utils.param_provider_get_bool) 117 | self.getString = self._fields_[4][1](utils.param_provider_get_string) 118 | 119 | self.getDoubleArray = self._fields_[5][1](utils.param_provider_get_double_array) 120 | self.getIntArray = self._fields_[6][1](utils.param_provider_get_int_array) 121 | self.getBoolArray = ctypes.cast(None, self._fields_[7][1]) 122 | self.getStringArray = ctypes.cast(None, self._fields_[8][1]) 123 | 124 | self.getDoubleArrayItem = self._fields_[9][1](utils.param_provider_get_double_array_item) 125 | self.getIntArrayItem = self._fields_[10][1](utils.param_provider_get_int_array_item) 126 | self.getBoolArrayItem = self._fields_[11][1](utils.param_provider_get_bool_array_item) 127 | self.getStringArrayItem = self._fields_[12][1](utils.param_provider_get_string_array_item) 128 | 129 | self.exists = self._fields_[13][1](utils.param_provider_exists) 130 | self.isArray = self._fields_[14][1](utils.param_provider_is_array) 131 | self.numElements = self._fields_[15][1](utils.param_provider_num_elements) 132 | self.pushScope = self._fields_[16][1](utils.param_provider_push_scope) 133 | self.popScope = self._fields_[17][1](utils.param_provider_pop_scope) 134 | 135 | _fields_ = [ 136 | # 0 (Position must match indices in __init__ method.) 
137 | ('userData', ctypes.py_object), 138 | 139 | # 1 140 | ('getDouble', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.POINTER(ctypes.c_double))), 141 | ('getInt', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, point_int)), 142 | ('getBool', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.POINTER(ctypes.c_uint8))), 143 | ('getString', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p))), 144 | 145 | # 5 146 | ('getDoubleArray', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, point_int, array_double)), 147 | ('getIntArray', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, point_int, ctypes.POINTER(point_int))), 148 | ('getBoolArray', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, point_int, ctypes.POINTER(ctypes.POINTER(ctypes.c_uint8)))), 149 | ('getStringArray', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, point_int, ctypes.POINTER(ctypes.POINTER(ctypes.c_char_p)))), 150 | 151 | # 9 152 | ('getDoubleArrayItem', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_double))), 153 | ('getIntArrayItem', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.c_int, point_int)), 154 | ('getBoolArrayItem', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_uint8))), 155 | ('getStringArrayItem', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p))), 156 | 157 | # 13 158 | ('exists', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.py_object, ctypes.c_char_p)), 159 | ('isArray', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p, ctypes.POINTER(ctypes.c_uint8))), 160 | ('numElements', ctypes.CFUNCTYPE(ctypes.c_int, ctypes.py_object, ctypes.c_char_p)), 161 | ('pushScope', 
ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object, ctypes.c_char_p)), 162 | ('popScope', ctypes.CFUNCTYPE(c_cadet_result, ctypes.py_object)), 163 | ] 164 | -------------------------------------------------------------------------------- /cadet/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import re 4 | import subprocess 5 | from abc import ABC, abstractmethod 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | from typing import Optional 9 | 10 | 11 | @dataclass 12 | class ReturnInformation: 13 | """ 14 | Class to store information about a CADET run return status. 15 | 16 | Parameters 17 | ---------- 18 | return_code : int 19 | An integer representing the return code. 0 indicates success, non-zero values indicate errors. 20 | error_message : str 21 | A string containing the error message if an error occurred. Empty if no error. 22 | log : str 23 | A string containing log information. 24 | """ 25 | return_code: int 26 | error_message: str 27 | log: str 28 | 29 | 30 | class CadetRunnerBase(ABC): 31 | """ 32 | Abstract base class for CADET runners. 33 | 34 | Subclasses must implement the `run`, `clear`, and `load_results` methods. 35 | """ 36 | 37 | @abstractmethod 38 | def run( 39 | self, 40 | simulation: "Cadet", 41 | timeout: Optional[int] = None, 42 | ) -> ReturnInformation: 43 | """ 44 | Run a CADET simulation. 45 | 46 | Parameters 47 | ---------- 48 | simulation : Cadet 49 | The simulation object. 50 | timeout : Optional[int] 51 | Maximum time allowed for the simulation to run, in seconds. 52 | 53 | Returns 54 | ------- 55 | ReturnInformation 56 | Information about the simulation run. 57 | """ 58 | pass 59 | 60 | @abstractmethod 61 | def clear(self) -> None: 62 | """ 63 | Clear the simulation data. 
64 | """ 65 | pass 66 | 67 | @abstractmethod 68 | def load_results(self, sim: "Cadet") -> None: 69 | """ 70 | Load the results of the simulation into the provided object. 71 | 72 | Parameters 73 | ---------- 74 | sim : Cadet 75 | The simulation object where results will be loaded. 76 | """ 77 | pass 78 | 79 | @property 80 | @abstractmethod 81 | def cadet_version(self) -> str: 82 | pass 83 | 84 | @property 85 | @abstractmethod 86 | def cadet_branch(self) -> str: 87 | pass 88 | 89 | @property 90 | @abstractmethod 91 | def cadet_build_type(self) -> str: 92 | pass 93 | 94 | @property 95 | @abstractmethod 96 | def cadet_commit_hash(self) -> str: 97 | pass 98 | 99 | @property 100 | @abstractmethod 101 | def cadet_path(self) -> Optional[os.PathLike]: 102 | pass 103 | 104 | 105 | class CadetCLIRunner(CadetRunnerBase): 106 | """ 107 | File-based CADET runner. 108 | 109 | This class runs CADET simulations using a command-line interface (CLI) executable. 110 | """ 111 | 112 | def __init__(self, cadet_path: str | os.PathLike) -> None: 113 | """ 114 | Initialize the CadetFileRunner. 115 | 116 | Parameters 117 | ---------- 118 | cadet_path : os.PathLike 119 | Path to the CADET CLI executable. 120 | """ 121 | cadet_path = Path(cadet_path) 122 | 123 | self._cadet_path = cadet_path 124 | self._get_cadet_version() 125 | 126 | def run( 127 | self, 128 | simulation: "Cadet", 129 | timeout: Optional[int] = None, 130 | ) -> ReturnInformation: 131 | """ 132 | Run a CADET simulation using the CLI executable. 133 | 134 | Parameters 135 | ---------- 136 | simulation : Cadet 137 | Not used in this runner. 138 | timeout : Optional[int] 139 | Maximum time allowed for the simulation to run, in seconds. 140 | 141 | Raises 142 | ------ 143 | RuntimeError 144 | If the simulation process returns a non-zero exit code. 145 | 146 | Returns 147 | ------- 148 | ReturnInformation 149 | Information about the simulation run. 
150 | """ 151 | if simulation.filename is None: 152 | raise ValueError("Filename must be set before run can be used") 153 | 154 | data = subprocess.run( 155 | [self.cadet_path, str(simulation.filename)], 156 | timeout=timeout, 157 | capture_output=True 158 | ) 159 | 160 | return_info = ReturnInformation( 161 | return_code=data.returncode, 162 | error_message=data.stderr.decode('utf-8'), 163 | log=data.stdout.decode('utf-8') 164 | ) 165 | 166 | return return_info 167 | 168 | def clear(self) -> None: 169 | """ 170 | Clear the simulation data. 171 | 172 | This method can be extended if any cleanup is required. 173 | """ 174 | pass 175 | 176 | def load_results(self, sim: "Cadet") -> None: 177 | """ 178 | Load the results of the simulation into the provided object. 179 | 180 | Parameters 181 | ---------- 182 | sim : Cadet 183 | The simulation object where results will be loaded. 184 | """ 185 | sim.load_from_file(paths=["/meta", "/output"], update=True) 186 | 187 | def _get_cadet_version(self) -> dict: 188 | """ 189 | Get version and branch name of the currently instanced CADET build. 190 | Returns 191 | ------- 192 | dict 193 | Dictionary containing: cadet_version as x.x.x, cadet_branch, cadet_build_type, cadet_commit_hash 194 | Raises 195 | ------ 196 | ValueError 197 | If version and branch name cannot be found in the output string. 198 | RuntimeError 199 | If any unhandled event during running the subprocess occurs. 
200 | """ 201 | try: 202 | result = subprocess.run( 203 | [self.cadet_path, '--version'], 204 | check=True, 205 | stdout=subprocess.PIPE, 206 | stderr=subprocess.PIPE, 207 | text=True 208 | ) 209 | version_output = result.stdout.strip() 210 | 211 | version_match = re.search( 212 | r'cadet-cli version ([\d.]+) \((.*) branch\)\n', 213 | version_output 214 | ) 215 | 216 | commit_hash_match = re.search( 217 | "Built from commit (.*)\n", 218 | version_output 219 | ) 220 | 221 | build_variant_match = re.search( 222 | "Build variant (.*)\n", 223 | version_output 224 | ) 225 | 226 | if version_match: 227 | self._cadet_version = version_match.group(1) 228 | self._cadet_branch = version_match.group(2) 229 | self._cadet_commit_hash = commit_hash_match.group(1) 230 | if build_variant_match: 231 | self._cadet_build_type = build_variant_match.group(1) 232 | else: 233 | self._cadet_build_type = None 234 | else: 235 | raise ValueError("CADET version or branch name missing from output.") 236 | except subprocess.CalledProcessError as e: 237 | raise RuntimeError(f"Command execution failed: {e}") 238 | 239 | @property 240 | def cadet_version(self) -> str: 241 | return self._cadet_version 242 | 243 | @property 244 | def cadet_branch(self) -> str: 245 | return self._cadet_branch 246 | 247 | @property 248 | def cadet_build_type(self) -> str: 249 | return self._cadet_build_type 250 | 251 | @property 252 | def cadet_commit_hash(self) -> str: 253 | return self._cadet_commit_hash 254 | 255 | @property 256 | def cadet_path(self) -> os.PathLike: 257 | return self._cadet_path 258 | -------------------------------------------------------------------------------- /cadet/cadet_dll_utils.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | from typing import Any 3 | 4 | import numpy as np 5 | 6 | 7 | def null(*args: Any) -> None: 8 | """Do nothing (used as a placeholder function).""" 9 | pass 10 | 11 | 12 | log_print = print if 0 else null 13 | 
14 | 15 | # %% Single entries 16 | 17 | def param_provider_get_double( 18 | reader: Any, 19 | name: ctypes.c_char_p, 20 | val: ctypes.POINTER(ctypes.c_double) 21 | ) -> int: 22 | """ 23 | Retrieve a double value from the reader based on the provided name. 24 | 25 | Parameters 26 | ---------- 27 | reader : Any 28 | The reader object containing the current data scope. 29 | name : ctypes.c_char_p 30 | The name of the parameter to retrieve. 31 | val : ctypes.POINTER(ctypes.c_double) 32 | A pointer to store the retrieved double value. 33 | 34 | Returns 35 | ------- 36 | int 37 | 0 if the value was found and retrieved successfully, -1 otherwise. 38 | """ 39 | n = name.decode('utf-8') 40 | c = reader.current() 41 | 42 | if n not in c: 43 | log_print(f"Parameter {n} not found.") 44 | return -1 45 | 46 | o = c[n] 47 | try: 48 | float_val = float(o) 49 | except TypeError: 50 | float_val = float(o[0]) 51 | 52 | val[0] = ctypes.c_double(float_val) 53 | log_print(f"GET scalar [double] {n}: {float(val[0])}") 54 | return 0 55 | 56 | 57 | def param_provider_get_int( 58 | reader: Any, 59 | name: ctypes.c_char_p, 60 | val: ctypes.POINTER(ctypes.c_int) 61 | ) -> int: 62 | """ 63 | Retrieve an integer value from the reader based on the provided name. 64 | 65 | Parameters 66 | ---------- 67 | reader : Any 68 | The reader object containing the current data scope. 69 | name : ctypes.c_char_p 70 | The name of the parameter to retrieve. 71 | val : ctypes.POINTER(ctypes.c_int) 72 | A pointer to store the retrieved integer value. 73 | 74 | Returns 75 | ------- 76 | int 77 | 0 if the value was found and retrieved successfully, -1 otherwise. 
78 | """ 79 | n = name.decode('utf-8') 80 | c = reader.current() 81 | 82 | if n not in c: 83 | log_print(f"Parameter {n} not found.") 84 | return -1 85 | 86 | o = c[n] 87 | try: 88 | int_val = int(o) 89 | except TypeError: 90 | int_val = int(o[0]) 91 | 92 | val[0] = ctypes.c_int(int_val) 93 | 94 | log_print(f"GET scalar [int] {n}: {int(val[0])}") 95 | return 0 96 | 97 | 98 | def param_provider_get_bool( 99 | reader: Any, 100 | name: ctypes.c_char_p, 101 | val: ctypes.POINTER(ctypes.c_uint8) 102 | ) -> int: 103 | """ 104 | Retrieve a boolean value from the reader based on the provided name. 105 | 106 | Parameters 107 | ---------- 108 | reader : Any 109 | The reader object containing the current data scope. 110 | name : ctypes.c_char_p 111 | The name of the parameter to retrieve. 112 | val : ctypes.POINTER(ctypes.c_uint8) 113 | A pointer to store the retrieved boolean value. 114 | 115 | Returns 116 | ------- 117 | int 118 | 0 if the value was found and retrieved successfully, -1 otherwise. 119 | """ 120 | n = name.decode('utf-8') 121 | c = reader.current() 122 | 123 | if n not in c: 124 | log_print(f"Parameter {n} not found.") 125 | return -1 126 | 127 | o = c[n] 128 | try: 129 | int_val = int(o) 130 | except TypeError: 131 | int_val = int(o[0]) 132 | 133 | val[0] = ctypes.c_uint8(int_val) 134 | 135 | log_print(f"GET scalar [bool] {n}: {bool(val[0])}") 136 | return 0 137 | 138 | 139 | def param_provider_get_string( 140 | reader: Any, 141 | name: ctypes.c_char_p, 142 | val: ctypes.POINTER(ctypes.c_char_p) 143 | ) -> int: 144 | """ 145 | Retrieve a string value from the reader based on the provided name. 146 | 147 | Parameters 148 | ---------- 149 | reader : Any 150 | The reader object containing the current data scope. 151 | name : ctypes.c_char_p 152 | The name of the parameter to retrieve. 153 | val : ctypes.POINTER(ctypes.c_char_p) 154 | A pointer to store the retrieved string value. 
155 | 156 | Returns 157 | ------- 158 | int 159 | 0 if the value was found and retrieved successfully, -1 otherwise. 160 | """ 161 | n = name.decode('utf-8') 162 | c = reader.current() 163 | 164 | if n not in c: 165 | log_print(f"Parameter {n} not found.") 166 | return -1 167 | 168 | o = c[n] 169 | 170 | if hasattr(o, 'encode'): 171 | bytes_val = o.encode('utf-8') 172 | elif hasattr(o, 'decode'): 173 | bytes_val = o 174 | elif hasattr(o[0], 'encode'): 175 | bytes_val = o[0].encode('utf-8') 176 | elif hasattr(o[0], 'decode'): 177 | bytes_val = o[0] 178 | 179 | reader.buffer = bytes_val 180 | val[0] = ctypes.cast(reader.buffer, ctypes.c_char_p) 181 | return 0 182 | 183 | 184 | # %% Arrays 185 | 186 | def param_provider_get_double_array( 187 | reader: Any, 188 | name: ctypes.c_char_p, 189 | n_elem: ctypes.POINTER(ctypes.c_int), 190 | val: ctypes.POINTER(ctypes.POINTER(ctypes.c_double)) 191 | ) -> int: 192 | """ 193 | Retrieve a double array from the reader based on the provided name. 194 | 195 | Parameters 196 | ---------- 197 | reader : Any 198 | The reader object containing the current data scope. 199 | name : ctypes.c_char_p 200 | The name of the parameter to retrieve. 201 | n_elem : ctypes.POINTER(ctypes.c_int) 202 | A pointer to store the number of elements in the array. 203 | val : ctypes.POINTER(ctypes.POINTER(ctypes.c_double)) 204 | A pointer to store the retrieved array. 205 | 206 | Returns 207 | ------- 208 | int 209 | 0 if the array was found and retrieved successfully, -1 otherwise. 
210 | """ 211 | n = name.decode('utf-8') 212 | c = reader.current() 213 | 214 | if n not in c: 215 | log_print(f"Parameter {n} not found.") 216 | return -1 217 | 218 | o = c[n] 219 | 220 | # Ensure object is a properly aligned numpy array 221 | if isinstance(o, list): # Convert lists to numpy arrays 222 | o = np.array(o, dtype=np.double) 223 | c[n] = o # Update the reader's storage 224 | 225 | # Validate the array 226 | if not isinstance(o, np.ndarray) or o.dtype != np.double or not o.flags.c_contiguous: 227 | log_print(f"Error: Parameter {n} is not a contiguous double array.") 228 | return -1 229 | 230 | # Provide array data to the caller 231 | n_elem[0] = ctypes.c_int(o.size) 232 | val[0] = np.ctypeslib.as_ctypes(o.ravel()) 233 | 234 | log_print(f"GET array [double] {n}: {o}") 235 | return 0 236 | 237 | 238 | def param_provider_get_int_array( 239 | reader: Any, 240 | name: ctypes.c_char_p, 241 | n_elem: ctypes.POINTER(ctypes.c_int), 242 | val: ctypes.POINTER(ctypes.POINTER(ctypes.c_int)) 243 | ) -> int: 244 | """ 245 | Retrieve an integer array from the reader based on the provided name. 246 | 247 | Parameters 248 | ---------- 249 | reader : Any 250 | The reader object containing the current data scope. 251 | name : ctypes.c_char_p 252 | The name of the parameter to retrieve. 253 | n_elem : ctypes.POINTER(ctypes.c_int) 254 | A pointer to store the number of elements in the array. 255 | val : ctypes.POINTER(ctypes.POINTER(ctypes.c_int)) 256 | A pointer to store the retrieved array. 257 | 258 | Returns 259 | ------- 260 | int 261 | 0 if the array was found and retrieved successfully, -1 otherwise. 
262 | """ 263 | n = name.decode('utf-8') 264 | c = reader.current() 265 | 266 | if n not in c: 267 | log_print(f"Parameter {n} not found.") 268 | return -1 269 | 270 | o = c[n] 271 | 272 | # Ensure object is a properly aligned numpy array 273 | if isinstance(o, list): # Convert lists to numpy arrays 274 | o = np.array(o, dtype=np.double) 275 | c[n] = o # Update the reader's storage 276 | 277 | # Validate the array 278 | if not isinstance(o, np.ndarray) or o.dtype != np.int32 or not o.flags.c_contiguous: 279 | log_print(f"Error: Parameter {n} is not a contiguous int array.") 280 | return -1 281 | 282 | # Provide array data to the caller 283 | n_elem[0] = ctypes.c_int(o.size) 284 | val[0] = np.ctypeslib.as_ctypes(o) 285 | 286 | log_print(f"GET array [int] {n}: {o}") 287 | return 0 288 | 289 | 290 | # %% Array items 291 | 292 | def param_provider_get_double_array_item( 293 | reader: Any, 294 | name: ctypes.c_char_p, 295 | index: int, 296 | val: ctypes.POINTER(ctypes.c_double) 297 | ) -> int: 298 | """ 299 | Retrieve an item from a double array in the reader based on the provided name and index. 300 | 301 | Parameters 302 | ---------- 303 | reader : Any 304 | The reader object containing the current data scope. 305 | name : ctypes.c_char_p 306 | The name of the parameter to retrieve. 307 | index : int 308 | The index of the array item to retrieve. 309 | val : ctypes.POINTER(ctypes.c_double) 310 | A pointer to store the retrieved double value. 311 | 312 | Returns 313 | ------- 314 | int 315 | 0 if the value was found and retrieved successfully, -1 otherwise. 
316 | """ 317 | n = name.decode('utf-8') 318 | c = reader.current() 319 | 320 | if n not in c: 321 | log_print(f"Parameter {n} not found.") 322 | return -1 323 | 324 | o = c[n] 325 | 326 | try: 327 | float_val = float(o) 328 | except TypeError: 329 | float_val = float(o[index]) 330 | 331 | val[0] = ctypes.c_double(float_val) 332 | 333 | log_print(f"GET array [double] ({index}) {n}: {val[0]}") 334 | return 0 335 | 336 | 337 | def param_provider_get_int_array_item( 338 | reader: Any, 339 | name: ctypes.c_char_p, 340 | index: int, 341 | val: ctypes.POINTER(ctypes.c_int) 342 | ) -> int: 343 | """ 344 | Retrieve an item from an integer array in the reader based on the provided name and index. 345 | 346 | Parameters 347 | ---------- 348 | reader : Any 349 | The reader object containing the current data scope. 350 | name : ctypes.c_char_p 351 | The name of the parameter to retrieve. 352 | index : int 353 | The index of the array item to retrieve. 354 | val : ctypes.POINTER(ctypes.c_int) 355 | A pointer to store the retrieved integer value. 356 | 357 | Returns 358 | ------- 359 | int 360 | 0 if the value was found and retrieved successfully, -1 otherwise. 361 | """ 362 | n = name.decode('utf-8') 363 | c = reader.current() 364 | 365 | if n not in c: 366 | log_print(f"Parameter {n} not found.") 367 | return -1 368 | 369 | o = c[n] 370 | 371 | try: 372 | int_val = int(o) 373 | except TypeError: 374 | int_val = int(o[index]) 375 | 376 | val[0] = ctypes.c_int(int_val) 377 | 378 | log_print(f"GET array [int] ({index}) {n}: {val[0]}") 379 | return 0 380 | 381 | 382 | def param_provider_get_bool_array_item( 383 | reader: Any, 384 | name: ctypes.c_char_p, 385 | index: int, 386 | val: ctypes.POINTER(ctypes.c_uint8) 387 | ) -> int: 388 | """ 389 | Retrieve an item from a boolean array in the reader based on the provided name and index. 390 | 391 | Parameters 392 | ---------- 393 | reader : Any 394 | The reader object containing the current data scope. 
395 | name : ctypes.c_char_p 396 | The name of the parameter to retrieve. 397 | index : int 398 | The index of the array item to retrieve. 399 | val : ctypes.POINTER(ctypes.c_uint8) 400 | A pointer to store the retrieved boolean value. 401 | 402 | Returns 403 | ------- 404 | int 405 | 0 if the value was found and retrieved successfully, -1 otherwise. 406 | """ 407 | n = name.decode('utf-8') 408 | c = reader.current() 409 | 410 | if n not in c: 411 | log_print(f"Parameter {n} not found.") 412 | return -1 413 | 414 | o = c[n] 415 | 416 | try: 417 | int_val = int(o) 418 | except TypeError: 419 | int_val = int(o[index]) 420 | 421 | val[0] = ctypes.c_uint8(int_val) 422 | 423 | log_print(f"GET array [bool] ({index}) {n}: {bool(val[0])}") 424 | return 0 425 | 426 | 427 | def param_provider_get_string_array_item( 428 | reader: Any, 429 | name: ctypes.c_char_p, 430 | index: int, 431 | val: ctypes.POINTER(ctypes.c_char_p) 432 | ) -> int: 433 | """ 434 | Retrieve an item from a string array in the reader based on the provided name and index. 435 | 436 | Parameters 437 | ---------- 438 | reader : Any 439 | The reader object containing the current data scope. 440 | name : ctypes.c_char_p 441 | The name of the parameter to retrieve. 442 | index : int 443 | The index of the array item to retrieve. 444 | val : ctypes.POINTER(ctypes.c_char_p) 445 | A pointer to store the retrieved string value. 446 | 447 | Returns 448 | ------- 449 | int 450 | 0 if the value was found and retrieved successfully, -1 otherwise. 451 | """ 452 | n = name.decode('utf-8') 453 | c = reader.current() 454 | 455 | if n not in c: 456 | log_print(f"Parameter {n} not found.") 457 | return -1 458 | 459 | o = c[n] 460 | if isinstance(o, bytes): 461 | bytes_val = o 462 | elif isinstance(o, str): 463 | bytes_val = o.encode('utf-8') 464 | elif isinstance(o, (np.ndarray, list)): 465 | bytes_val = o[index] 466 | else: 467 | raise TypeError( 468 | "Unexpected type for name {n}: {type(o)}. 
" 469 | "Must be of type bytes, str, or np.ndarray." 470 | ) 471 | 472 | reader.buffer = bytes_val 473 | val[0] = ctypes.cast(reader.buffer, ctypes.c_char_p) 474 | 475 | log_print(f"GET array [string] ({index}) {n}: {bytes_val}") 476 | return 0 477 | 478 | 479 | # %% Misc 480 | 481 | def param_provider_exists( 482 | reader: Any, 483 | name: ctypes.c_char_p 484 | ) -> int: 485 | """ 486 | Check if a given parameter name exists in the reader. 487 | 488 | Parameters 489 | ---------- 490 | reader : Any 491 | The reader object containing the current data scope. 492 | name : ctypes.c_char_p 493 | The name of the parameter to check. 494 | 495 | Returns 496 | ------- 497 | int 498 | 1 if the name exists, 0 otherwise. 499 | """ 500 | n = name.decode('utf-8') 501 | c = reader.current() 502 | 503 | log_print(f"EXISTS {n}: {n in c}") 504 | 505 | return 1 if n in c else 0 506 | 507 | 508 | def param_provider_is_array( 509 | reader: Any, 510 | name: ctypes.c_char_p, 511 | res: ctypes.POINTER(ctypes.c_uint8) 512 | ) -> int: 513 | """ 514 | Check if a given parameter is an array. 515 | 516 | Parameters 517 | ---------- 518 | reader : Any 519 | The reader object containing the current data scope. 520 | name : ctypes.c_char_p 521 | The name of the parameter to check. 522 | res : ctypes.POINTER(ctypes.c_uint8) 523 | A pointer to store the result (1 if the parameter is an array, 0 otherwise). 524 | 525 | Returns 526 | ------- 527 | int 528 | 0 if the check was successful, -1 if the parameter does not exist. 
529 | """ 530 | n = name.decode('utf-8') 531 | c = reader.current() 532 | 533 | if n not in c: 534 | log_print(f"Parameter {n} not found.") 535 | return -1 536 | 537 | o = c[n] 538 | res[0] = ctypes.c_uint8(1 if isinstance(o, (list, np.ndarray)) else 0) 539 | log_print(f"ISARRAY {n}: {bool(res[0])}") 540 | 541 | return 0 542 | 543 | 544 | def param_provider_num_elements( 545 | reader: Any, 546 | name: ctypes.c_char_p 547 | ) -> int: 548 | """ 549 | Get the number of elements in a given parameter if it is an array. 550 | 551 | Parameters 552 | ---------- 553 | reader : Any 554 | The reader object containing the current data scope. 555 | name : ctypes.c_char_p 556 | The name of the parameter to check. 557 | 558 | Returns 559 | ------- 560 | int 561 | The number of elements if the parameter is an array, 1 otherwise. 562 | """ 563 | n = name.decode('utf-8') 564 | c = reader.current() 565 | 566 | if n not in c: 567 | log_print(f"Parameter {n} not found.") 568 | return -1 569 | 570 | o = c[n] 571 | if isinstance(o, list): 572 | log_print(f"NUMELEMENTS {n}: {len(o)}") 573 | return len(o) 574 | elif isinstance(o, np.ndarray): 575 | log_print(f"NUMELEMENTS {n}: {o.size}") 576 | return o.size 577 | 578 | log_print(f"NUMELEMENTS {n}: 1") 579 | return 1 580 | 581 | 582 | def param_provider_push_scope( 583 | reader: Any, 584 | name: ctypes.c_char_p 585 | ) -> int: 586 | """ 587 | Push a new scope in the reader based on the provided name. 588 | 589 | Parameters 590 | ---------- 591 | reader : Any 592 | The reader object containing the current data scope. 593 | name : ctypes.c_char_p 594 | The name of the scope to push. 595 | 596 | Returns 597 | ------- 598 | int 599 | 0 if the scope was successfully pushed, -1 otherwise. 600 | """ 601 | n = name.decode('utf-8') 602 | 603 | if reader.push_scope(n): 604 | return 0 605 | else: 606 | return -1 607 | 608 | 609 | def param_provider_pop_scope(reader: Any) -> int: 610 | """ 611 | Pop the current scope from the reader. 
612 | 613 | Parameters 614 | ---------- 615 | reader : Any 616 | The reader object containing the current data scope. 617 | 618 | Returns 619 | ------- 620 | int 621 | 0 if the scope was successfully popped. 622 | """ 623 | reader.pop_scope() 624 | return 0 625 | -------------------------------------------------------------------------------- /cadet/cadet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import platform 4 | import shutil 5 | import subprocess 6 | from typing import Optional 7 | import warnings 8 | 9 | from addict import Dict 10 | 11 | from cadet.h5 import H5 12 | from cadet.runner import CadetRunnerBase, CadetCLIRunner, ReturnInformation 13 | from cadet.cadet_dll import CadetDLLRunner 14 | 15 | 16 | def is_dll(path: os.PathLike) -> bool: 17 | """ 18 | Determine if the given path points to a shared library. 19 | 20 | Parameters 21 | ---------- 22 | path : os.PathLike 23 | Path to the file. 24 | 25 | Returns 26 | ------- 27 | bool 28 | True if the file has a shared library extension (.so, .dll), False otherwise. 29 | """ 30 | suffix = Path(path).suffix 31 | return suffix in {'.so', '.dll'} 32 | 33 | 34 | def resolve_cadet_paths( 35 | install_path: Optional[os.PathLike] 36 | ) -> tuple[Optional[Path], Optional[Path], Optional[Path], Optional[Path]]: 37 | """ 38 | Resolve paths from the installation path of CADET. 39 | 40 | Parameters 41 | ---------- 42 | install_path : Optional[os.PathLike] 43 | Path to the root of the CADET installation or the executable file 'cadet-cli'. 44 | If a file path is provided, the root directory will be inferred. 
45 | 46 | Returns 47 | ------- 48 | tuple[Optional[Path], Optional[Path], Optional[Path], Optional[Path]] 49 | tuple with CADET installation paths 50 | (root_path, cadet_cli_path, cadet_dll_path, cadet_create_lwe_path) 51 | """ 52 | if install_path is None: 53 | return None, None, None, None 54 | 55 | install_path = Path(install_path).expanduser() 56 | 57 | if install_path.is_file(): 58 | cadet_root = install_path.parent.parent 59 | warnings.warn( 60 | "The specified install_path is not the root of the CADET installation. " 61 | "It has been inferred from the file path." 62 | ) 63 | else: 64 | cadet_root = install_path 65 | 66 | root_path = cadet_root 67 | 68 | cli_executable = 'cadet-cli' 69 | lwe_executable = 'createLWE' 70 | 71 | if platform.system() == 'Windows': 72 | cli_executable += '.exe' 73 | lwe_executable += '.exe' 74 | 75 | cadet_cli_path = cadet_root / 'bin' / cli_executable 76 | if not cadet_cli_path.is_file(): 77 | raise FileNotFoundError( 78 | "CADET CLI could not be found. Please check the path." 79 | ) 80 | 81 | cadet_create_lwe_path = cadet_root / 'bin' / lwe_executable 82 | if not cadet_create_lwe_path.is_file(): 83 | raise FileNotFoundError( 84 | "CADET createLWE could not be found. Please check the path." 85 | ) 86 | 87 | if platform.system() == 'Windows': 88 | dll_path = cadet_root / 'bin' / 'cadet.dll' 89 | dll_debug_path = cadet_root / 'bin' / 'cadet_d.dll' 90 | elif platform.system() == 'Darwin': 91 | dll_path = cadet_root / 'lib' / 'libcadet.dylib' 92 | dll_debug_path = cadet_root / 'lib' / 'libcadet_d.dylib' 93 | else: 94 | dll_path = cadet_root / 'lib' / 'libcadet.so' 95 | dll_debug_path = cadet_root / 'lib' / 'libcadet_d.so' 96 | 97 | # Look for debug dll if dll is not found. 
98 | if not dll_path.is_file() and dll_debug_path.is_file(): 99 | dll_path = dll_debug_path 100 | 101 | cadet_dll_path = dll_path if dll_path.is_file() else None 102 | 103 | return root_path, cadet_cli_path, cadet_dll_path, cadet_create_lwe_path 104 | 105 | 106 | class CadetMeta(type): 107 | """ 108 | Meta class for the CADET interface. 109 | 110 | This meta class allows setting the `cadet_path` attribute for all instances of the 111 | `Cadet` class. 112 | """ 113 | 114 | use_dll = False 115 | cadet_cli_path = None 116 | cadet_dll_path = None 117 | cadet_create_lwe_path = None 118 | 119 | @property 120 | def cadet_path(cls) -> Optional[Path]: 121 | """ 122 | Get the current CADET path. 123 | 124 | Returns 125 | ------- 126 | Optional[Path] 127 | The current CADET path if set, otherwise None. 128 | """ 129 | if cls.use_dll and cls.cadet_dll_path is not None: 130 | return cls.cadet_dll_path 131 | elif cls.cadet_cli_path is not None: 132 | return cls.cadet_cli_path 133 | else: 134 | return None 135 | 136 | @cadet_path.setter 137 | def cadet_path(cls, cadet_path: Optional[os.PathLike]) -> None: 138 | """ 139 | Set the CADET path and initialize the appropriate runner. 140 | 141 | Parameters 142 | ---------- 143 | cadet_path : os.PathLike 144 | Path to the CADET executable or library. 145 | 146 | Notes 147 | ----- 148 | If the path is a DLL, a `CadetDLLRunner` runner is used. 149 | Otherwise, a `CadetFileRunner` runner is used. 150 | """ 151 | warnings.warn( 152 | "Support for setting Cadet.cadet_path will be removed in a future version. 
" 153 | "Please set the `install_path` on instance level.", 154 | DeprecationWarning 155 | ) 156 | if cadet_path is None: 157 | cls.use_dll = False 158 | cls._install_path = None 159 | cls.cadet_cli_path = None 160 | cls.cadet_dll_path = None 161 | cls.cadet_create_lwe_path = None 162 | return 163 | 164 | cadet_path = Path(cadet_path) 165 | 166 | cls.use_dll = cadet_path.suffix in [".dll", ".so"] 167 | 168 | install_path, cadet_cli_path, cadet_dll_path, cadet_create_lwe_path = \ 169 | resolve_cadet_paths(cadet_path) 170 | 171 | cls._install_path = install_path 172 | cls.cadet_create_lwe_path = cadet_create_lwe_path 173 | cls.cadet_cli_path = cadet_cli_path 174 | cls.cadet_dll_path = cadet_dll_path 175 | 176 | 177 | class Cadet(H5, metaclass=CadetMeta): 178 | """ 179 | CADET interface class. 180 | 181 | This class manages the CADET runner, whether it's based on the CLI executable or 182 | the in-memory interface and provides methods for running simulations and loading 183 | results. 184 | 185 | Attributes 186 | ---------- 187 | install_path : Optional[Path] 188 | The root directory of the CADET installation. 189 | cadet_cli_path : Optional[Path] 190 | Path to the 'cadet-cli' executable. 191 | cadet_dll_path : Optional[Path] 192 | Path to the 'cadet.dll' or equivalent shared library. 193 | cadet_create_lwe_path : Optional[Path] 194 | Path to the 'createLWE' executable. 195 | return_information : Optional[dict] 196 | Stores the information returned after a simulation run. 197 | """ 198 | 199 | def __init__( 200 | self, 201 | install_path: Optional[Path] = None, 202 | use_dll: bool = False, 203 | *data 204 | ) -> None: 205 | """ 206 | Initialize a new instance of the Cadet class. 207 | 208 | Priority order of install_paths is: 209 | 1. install_path set in __init__ args 210 | 2. install_path set in CadetMeta 211 | 3. 
auto-detected install_path 212 | 213 | Parameters 214 | ---------- 215 | *data : tuple 216 | Additional data to be passed to the H5 base class initialization. 217 | """ 218 | super().__init__(*data) 219 | 220 | self.cadet_create_lwe_path: Optional[Path] = None 221 | self.return_information: Optional[dict] = None 222 | 223 | self._cadet_cli_runner: Optional[CadetCLIRunner] = None 224 | self._cadet_dll_runner: Optional[CadetDLLRunner] = None 225 | 226 | # Regardless of settings in the Meta Class, if we get an install_path, we 227 | # respect the install_path 228 | if install_path is not None: 229 | self.use_dll = use_dll 230 | self.install_path = install_path # This will automatically set the runners. 231 | return 232 | 233 | # Use CLIRunner of the Meta class, if provided. 234 | if hasattr(self, "cadet_cli_path") and self.cadet_cli_path is not None: 235 | self._cadet_cli_runner: Optional[CadetCLIRunner] = CadetCLIRunner( 236 | self.cadet_cli_path 237 | ) 238 | else: 239 | self._cadet_cli_runner: Optional[CadetCLIRunner] = None 240 | self.use_dll = use_dll 241 | 242 | # Use DLLRunner of the Meta class, if provided. 243 | if hasattr(self, "cadet_dll_path") and self.cadet_dll_path is not None: 244 | try: 245 | self._cadet_dll_runner: Optional[CadetDLLRunner] = CadetDLLRunner( 246 | self.cadet_dll_path 247 | ) 248 | except ValueError: 249 | self.cadet_dll_path = None 250 | self._cadet_dll_runner: Optional[CadetCLIRunner] = None 251 | self.use_dll = False 252 | else: 253 | self._cadet_dll_runner: Optional[CadetCLIRunner] = None 254 | self.use_dll = use_dll 255 | 256 | if self._cadet_cli_runner is not None or self._cadet_dll_runner is not None: 257 | return 258 | 259 | # Auto-detect Cadet if neither Meta Class nor install_path are given. 260 | self.install_path = self.autodetect_cadet() 261 | 262 | @property 263 | def install_path(self) -> Optional[Path]: 264 | """ 265 | Path to the installation of CADET. 
266 | 267 | Returns 268 | ------- 269 | Optional[Path] 270 | The root directory of the CADET installation or the path to 'cadet-cli'. 271 | """ 272 | return self._install_path 273 | 274 | @install_path.setter 275 | def install_path(self, install_path: Optional[os.PathLike]) -> None: 276 | """ 277 | Set the installation path of CADET. 278 | 279 | Parameters 280 | ---------- 281 | install_path : Optional[os.PathLike] 282 | Path to the root of the CADET installation or the 'cadet-cli' executable. 283 | If a file path is provided, the root directory will be inferred. 284 | """ 285 | if install_path is None: 286 | self._install_path = None 287 | self.cadet_cli_path = None 288 | self.cadet_dll_path = None 289 | self.cadet_create_lwe_path = None 290 | return 291 | 292 | root_path, cadet_cli_path, cadet_dll_path, create_lwe_path = \ 293 | resolve_cadet_paths(install_path) 294 | 295 | self._install_path = root_path 296 | self.cadet_create_lwe_path = create_lwe_path 297 | 298 | if cadet_cli_path is not None: 299 | self._cadet_cli_runner = CadetCLIRunner(cadet_cli_path) 300 | self.cadet_cli_path = cadet_cli_path 301 | 302 | self.cadet_dll_path = cadet_dll_path 303 | if cadet_dll_path is not None: 304 | try: 305 | self._cadet_dll_runner = CadetDLLRunner(cadet_dll_path) 306 | except ValueError: 307 | pass 308 | 309 | @property 310 | def cadet_path(self) -> Optional[Path]: 311 | """ 312 | Get the path to the current CADET executable or library. 313 | 314 | Returns 315 | ------- 316 | Path 317 | The path to the current CADET executable or library if set, otherwise None. 318 | """ 319 | runner = self.cadet_runner 320 | if runner is not None: 321 | return runner.cadet_path 322 | return None 323 | 324 | @cadet_path.setter 325 | def cadet_path(self, cadet_path: os.PathLike) -> None: 326 | """ 327 | Set the CADET path and initialize the appropriate runner. 328 | 329 | Parameters 330 | ---------- 331 | cadet_path : os.PathLike 332 | Path to the CADET executable or library. 
333 | 334 | Notes 335 | ----- 336 | If the path is a DLL, a `CadetDLLRunner` runner is used. 337 | Otherwise, a `CadetFileRunner` runner is used. 338 | """ 339 | cadet_path = Path(cadet_path) 340 | warnings.warn( 341 | "Deprecation warning: Support for setting cadet.cadet_path will be removed " 342 | " in a future version. Use `install_path` instead.", 343 | FutureWarning 344 | ) 345 | self.install_path = cadet_path 346 | 347 | @staticmethod 348 | def autodetect_cadet() -> Optional[Path]: 349 | """ 350 | Autodetect the CADET installation path. 351 | 352 | Returns 353 | ------- 354 | Optional[Path] 355 | The root directory of the CADET installation. 356 | 357 | Raises 358 | ------ 359 | FileNotFoundError 360 | If CADET cannot be found in the system path. 361 | """ 362 | executable = 'cadet-cli' 363 | if platform.system() == 'Windows': 364 | executable += '.exe' 365 | 366 | path = shutil.which(executable) 367 | 368 | if path is None: 369 | raise FileNotFoundError( 370 | "Could not autodetect CADET installation. Please provide path." 371 | ) 372 | 373 | cli_path = Path(path) 374 | cadet_root = cli_path.parent.parent if cli_path else None 375 | 376 | return cadet_root 377 | 378 | @property 379 | def version(self) -> str: 380 | """str: The version of the CADET-Core installation.""" 381 | return self.cadet_runner.cadet_version 382 | 383 | @property 384 | def cadet_runner(self) -> CadetRunnerBase: 385 | """ 386 | Get the current CADET runner instance. 387 | 388 | Returns 389 | ------- 390 | Optional[CadetRunnerBase] 391 | The current runner instance, either a DLL or file-based runner. 392 | """ 393 | if self.use_dll and self.found_dll: 394 | return self._cadet_dll_runner 395 | 396 | if self.use_dll and not self.found_dll: 397 | raise ValueError("Set Cadet to use_dll but no dll interface found.") 398 | 399 | return self._cadet_cli_runner 400 | 401 | def create_lwe(self, file_path=None): 402 | """Create basic LWE example and loads the configuration into self. 
403 | 404 | Parameters 405 | ---------- 406 | file_path : Path, optional 407 | Path to store HDF5 file. If None, temporary file will be created and 408 | deleted after simulation. 409 | 410 | """ 411 | file_path_input = file_path 412 | if file_path is None: 413 | file_name = "LWE.h5" 414 | cwd = os.getcwd() 415 | file_path = Path(cwd) / file_name 416 | else: 417 | file_path = Path(file_path).absolute() 418 | file_name = file_path.name 419 | cwd = file_path.parent.as_posix() 420 | 421 | ret = subprocess.run( 422 | [self.cadet_create_lwe_path, '-o', file_name], 423 | stdout=subprocess.PIPE, 424 | stderr=subprocess.PIPE, 425 | cwd=cwd 426 | ) 427 | if ret.returncode != 0: 428 | if ret.stdout: 429 | print('Output', ret.stdout.decode('utf-8')) 430 | if ret.stderr: 431 | print('Errors', ret.stderr.decode('utf-8')) 432 | raise RuntimeError( 433 | "Failure: Creation of test simulation ran into problems" 434 | ) 435 | 436 | self.filename = file_path 437 | 438 | self.load_from_file() 439 | 440 | if file_path_input is None: 441 | self.delete_file() 442 | 443 | return self 444 | 445 | @property 446 | def found_dll(self): 447 | """ 448 | Check if a DLL interface was found. 449 | 450 | Returns 451 | ------- 452 | bool 453 | True if a cadet DLL interface was found. 454 | False otherwise. 455 | """ 456 | return self.cadet_dll_path is not None 457 | 458 | def transform(self, x: str) -> str: 459 | """ 460 | Transform the input string to uppercase. 461 | 462 | Parameters 463 | ---------- 464 | x : str 465 | Input string. 466 | 467 | Returns 468 | ------- 469 | str 470 | Transformed string in uppercase. 471 | """ 472 | return str.upper(x) 473 | 474 | def inverse_transform(self, x: str) -> str: 475 | """ 476 | Transform the input string to lowercase. 477 | 478 | Parameters 479 | ---------- 480 | x : str 481 | Input string. 482 | 483 | Returns 484 | ------- 485 | str 486 | Transformed string in lowercase. 
487 | """ 488 | return str.lower(x) 489 | 490 | def run_load( 491 | self, 492 | timeout: Optional[int] = None, 493 | clear: bool = True 494 | ) -> ReturnInformation: 495 | """ 496 | Run the CADET simulation and load the results. 497 | 498 | Parameters 499 | ---------- 500 | timeout : Optional[int] 501 | Maximum time allowed for the simulation to run, in seconds. 502 | clear : bool 503 | If True, clear the simulation results from the current runner instance. 504 | 505 | Returns 506 | ------- 507 | ReturnInformation 508 | Information about the simulation run. 509 | """ 510 | warnings.warn( 511 | "Cadet.run_load() will be removed in a future release. " 512 | "Please use Cadet.run_simulation()", 513 | category=FutureWarning 514 | ) 515 | return_information = self.run_simulation(timeout=timeout, clear=clear) 516 | 517 | return return_information 518 | 519 | def run_simulation( 520 | self, 521 | timeout: Optional[int] = None, 522 | clear: bool = True 523 | ) -> ReturnInformation: 524 | """ 525 | Run the CADET simulation and load the results. 526 | 527 | Parameters 528 | ---------- 529 | timeout : Optional[int] 530 | Maximum time allowed for the simulation to run, in seconds. 531 | clear : bool 532 | If True, clear the simulation results from the current runner instance. 533 | 534 | Returns 535 | ------- 536 | ReturnInformation 537 | Information about the simulation run. 538 | """ 539 | return_information = self.cadet_runner.run( 540 | simulation=self, 541 | timeout=timeout 542 | ) 543 | 544 | if return_information.return_code == 0: 545 | self.cadet_runner.load_results(self) 546 | 547 | if clear: 548 | self.clear() 549 | 550 | return return_information 551 | 552 | def run( 553 | self, 554 | timeout: Optional[int] = None, 555 | ) -> ReturnInformation: 556 | """ 557 | Run the CADET simulation. 558 | 559 | Parameters 560 | ---------- 561 | timeout : Optional[int] 562 | Maximum time allowed for the simulation to run, in seconds. 
563 | 564 | Returns 565 | ------- 566 | ReturnInformation 567 | Information about the simulation run. 568 | """ 569 | warnings.warn( 570 | "Cadet.run() will be removed in a future release. \n" 571 | "Please use Cadet.run_simulation()", 572 | category=FutureWarning 573 | ) 574 | return_information = self.cadet_runner.run( 575 | self, 576 | timeout=timeout, 577 | ) 578 | 579 | return return_information 580 | 581 | def load_results(self) -> None: 582 | """Load the results of the last simulation run into the current instance.""" 583 | warnings.warn( 584 | "Cadet.load_results() will be removed in a future release. \n" 585 | "Please use Cadet.load_from_file() to load results from a file " 586 | "or use Cadet.run_simulation() to run a simulation and directly load the " 587 | "simulation results.", 588 | category=FutureWarning 589 | ) 590 | self.load_from_file() 591 | 592 | def load(self) -> None: 593 | """Load the results of the last simulation run into the current instance.""" 594 | warnings.warn( 595 | "Cadet.load() will be removed in a future release. 
\n" 596 | "Please use Cadet.load_from_file() to load results from a file or use " 597 | "Cadet.run_simulation() to run a simulation and directly load the " 598 | "simulation results.", 599 | category=FutureWarning 600 | ) 601 | self.load_from_file() 602 | 603 | def clear(self) -> None: 604 | """Clear the simulation results from the current runner instance.""" 605 | runner = self.cadet_runner 606 | if runner is not None: 607 | runner.clear() 608 | 609 | def __del__(self): 610 | self.clear() 611 | del self._cadet_dll_runner 612 | del self._cadet_cli_runner 613 | 614 | def __getstate__(self): 615 | state = self.__dict__.copy() 616 | return state 617 | 618 | def __setstate__(self, state): 619 | # Restore the state and cast to addict.Dict() to add __frozen attributes 620 | state = Dict(state) 621 | self.__dict__.update(state) 622 | -------------------------------------------------------------------------------- /cadet/h5.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from pathlib import Path 5 | import pprint 6 | from typing import Optional, Any 7 | import warnings 8 | 9 | import numpy as np 10 | from addict import Dict 11 | with warnings.catch_warnings(): 12 | warnings.filterwarnings("ignore", category=FutureWarning) 13 | import h5py 14 | import numpy 15 | 16 | import filelock 17 | import contextlib 18 | 19 | 20 | class H5: 21 | """ 22 | A class for handling hierarchical HDF5 data structures and JSON representations. 23 | 24 | Attributes 25 | ---------- 26 | root : Dict 27 | The root data structure holding the HDF5/JSON data. 28 | filename : Optional[str] 29 | Path to the HDF5 file. 30 | 31 | Methods 32 | ------- 33 | transform(x: Any) -> Any 34 | Applies a transformation to the data before saving. 35 | inverse_transform(x: Any) -> Any 36 | Applies an inverse transformation to the data after loading. 
37 | load(paths: Optional[List[str]] = None, update: bool = False, lock: bool = False) -> None 38 | Loads data from the specified HDF5 file. 39 | save(lock: bool = False) -> None 40 | Saves the current data to the specified HDF5 file. 41 | save_json(filename: Union[str, Path]) -> None 42 | Saves the current data to a JSON file. 43 | load_json(filename: Union[str, Path], update: bool = False) -> None 44 | Loads data from a JSON file. 45 | append(lock: bool = False) -> None 46 | Appends new keys to the HDF5 file without reading existing data. 47 | update(other: "H5") -> None 48 | Merges another H5 object's data with the current one. 49 | """ 50 | 51 | pp = pprint.PrettyPrinter(indent=4) 52 | 53 | def transform(self, x: Any) -> Any: 54 | """ 55 | Transform the data before saving. 56 | 57 | Parameters 58 | ---------- 59 | x : Any 60 | Data to be transformed. 61 | 62 | Returns 63 | ------- 64 | Any 65 | Transformed data. 66 | """ 67 | return x 68 | 69 | def inverse_transform(self, x: Any) -> Any: 70 | """ 71 | Apply an inverse transformation to the data after loading. 72 | 73 | Parameters 74 | ---------- 75 | x : Any 76 | Data to be transformed back. 77 | 78 | Returns 79 | ------- 80 | Any 81 | Inversely transformed data. 82 | """ 83 | return x 84 | 85 | def __init__(self, *data: Any): 86 | """ 87 | Initialize an H5 object with optional data. 88 | 89 | Parameters 90 | ---------- 91 | data : Any 92 | Optional initial data to populate the object. 93 | """ 94 | self.root = Dict() 95 | self.filename: Optional[str] = None 96 | for i in data: 97 | self.root.update(copy.deepcopy(i)) 98 | 99 | def load( 100 | self, 101 | paths: Optional[list[str]] = None, 102 | update: bool = False, 103 | lock: bool = False 104 | ) -> None: 105 | """ 106 | Load data from the specified HDF5 file. 107 | 108 | Parameters 109 | ---------- 110 | paths : Optional[List[str]], optional 111 | Specific paths to load within the HDF5 file. 
112 | update : bool, optional 113 | If True, update the existing data with the loaded data, 114 | i.e. keep existing data and ADD loaded data. 115 | If False, discard existing data and only keep loaded data. 116 | lock : bool, optional 117 | If True, uses a file lock while loading. 118 | """ 119 | warnings.warn( 120 | "Deprecation warning: Support for `load` will be removed in a future " 121 | "version. Use `load_from_file` instead.", 122 | FutureWarning 123 | ) 124 | self.load_from_file(paths=paths, update=update, lock=lock) 125 | 126 | def load_from_file( 127 | self, 128 | paths: Optional[list[str]] = None, 129 | update: bool = False, 130 | lock: bool = False 131 | ) -> None: 132 | """ 133 | Load data from the specified HDF5 file. 134 | 135 | Parameters 136 | ---------- 137 | paths : Optional[List[str]], optional 138 | Specific paths to load within the HDF5 file. 139 | update : bool, optional 140 | If True, update the existing data with the loaded data, 141 | i.e. keep existing data and ADD loaded data. 142 | If False, discard existing data and only keep loaded data. 143 | lock : bool, optional 144 | If True, uses a file lock while loading. 145 | """ 146 | if self.filename is not None: 147 | lock_file = filelock.FileLock( 148 | self.filename + '.lock' 149 | ) if lock else contextlib.nullcontext() 150 | 151 | with lock_file: 152 | with h5py.File(self.filename, 'r') as h5file: 153 | data = Dict( 154 | recursively_load(h5file, '/', self.inverse_transform, paths) 155 | ) 156 | if update: 157 | self.root.update(data) 158 | else: 159 | self.root = data 160 | else: 161 | print('Filename must be set before load can be used') 162 | 163 | def save(self, lock: bool = False) -> None: 164 | """ 165 | Save the current data to the specified HDF5 file. 166 | 167 | Parameters 168 | ---------- 169 | lock : bool, optional 170 | If True, uses a file lock while saving. 171 | 172 | Raises 173 | ------ 174 | ValueError 175 | If the filename is not set before attempting to save. 
176 | """ 177 | if self.filename is not None: 178 | lock_file = filelock.FileLock( 179 | self.filename + '.lock' 180 | ) if lock else contextlib.nullcontext() 181 | 182 | with lock_file: 183 | with h5py.File(self.filename, 'w') as h5file: 184 | recursively_save(h5file, '/', self.root, self.transform) 185 | else: 186 | raise ValueError("Filename must be set before save can be used") 187 | 188 | def save_as_python_script( 189 | self, 190 | filename: str, 191 | only_return_pythonic_representation: bool = False 192 | ) -> None | list[str]: 193 | """ 194 | Save the current state as a Python script. 195 | 196 | Parameters 197 | ---------- 198 | filename : str 199 | The name of the file to save the script to. Must end with ".py". 200 | only_return_pythonic_representation : bool, optional 201 | If True, returns the Python code as a list of strings instead of writing 202 | to a file. Defaults to False. 203 | 204 | Returns 205 | ------- 206 | None | list[str] 207 | If `only_return_pythonic_representation` is True, returns a list of strings 208 | representing the Python code. Otherwise, returns None. 209 | 210 | Raises 211 | ------ 212 | Warning 213 | If the filename does not end with ".py". 214 | 215 | """ 216 | if not filename.endswith(".py"): 217 | raise Warning( 218 | "Unexpected filename extension. Consider setting a '.py' file." 
219 | ) 220 | 221 | code_lines_list = [ 222 | "import numpy as np", 223 | f"from cadet import {self.__class__.__name__}", 224 | "", 225 | f"model = {self.__class__.__name__}()", 226 | ] 227 | 228 | code_lines_list = recursively_turn_dict_to_python_list( 229 | dictionary=self.root, 230 | current_lines_list=code_lines_list, 231 | prefix="model.root" 232 | ) 233 | 234 | filename_for_reproduced_h5_file = filename.replace(".py", ".h5") 235 | code_lines_list.append(f"model.filename = '{filename_for_reproduced_h5_file}'") 236 | code_lines_list.append("model.save()") 237 | 238 | if not only_return_pythonic_representation: 239 | with open(filename, "w") as handle: 240 | handle.writelines([line + "\n" for line in code_lines_list]) 241 | return 242 | else: 243 | return code_lines_list 244 | 245 | def delete_file(self) -> None: 246 | """Delete the file associated with the current instance.""" 247 | if self.filename is not None: 248 | try: 249 | os.remove(self.filename) 250 | except FileNotFoundError: 251 | pass 252 | 253 | def save_json(self, filename: str | Path) -> None: 254 | """ 255 | Save the current data to a JSON file. 256 | 257 | Parameters 258 | ---------- 259 | filename : str | Path 260 | Path to the JSON file. 261 | """ 262 | with Path(filename).open("w") as fp: 263 | data = convert_from_numpy(self.root, self.transform) 264 | json.dump(data, fp, indent=4, sort_keys=True) 265 | 266 | def load_json(self, filename: str | Path, update: bool = False) -> None: 267 | """ 268 | Load data from a JSON file. 269 | 270 | Parameters 271 | ---------- 272 | filename : str | Path 273 | Path to the JSON file. 274 | update : bool, optional 275 | If True, updates the existing data with the loaded data. 
276 | """ 277 | with Path(filename).open("r") as fp: 278 | data = json.load(fp) 279 | data = recursively_load_dict(data, self.inverse_transform) 280 | if update: 281 | self.root.update(data) 282 | else: 283 | self.root = data 284 | 285 | def append(self, lock: bool = False) -> None: 286 | """ 287 | Append new keys to the HDF5 file without reading existing data. 288 | 289 | Parameters 290 | ---------- 291 | lock : bool, optional 292 | If True, uses a file lock while appending. 293 | """ 294 | if self.filename is not None: 295 | lock_file = filelock.FileLock( 296 | self.filename + '.lock' 297 | ) if lock else contextlib.nullcontext() 298 | 299 | with lock_file: 300 | with h5py.File(self.filename, 'a') as h5file: 301 | recursively_save(h5file, '/', self.root, self.transform) 302 | else: 303 | print("Filename must be set before save can be used") 304 | 305 | def __str__(self) -> str: 306 | """ 307 | Return a string representation of the object. 308 | 309 | Returns 310 | ------- 311 | str 312 | String representation of the filename and root data. 313 | """ 314 | temp = [] 315 | temp.append(f'Filename = {self.filename}') 316 | temp.append(self.pp.pformat(self.root)) 317 | return '\n'.join(temp) 318 | 319 | def update(self, other: "H5") -> None: 320 | """ 321 | Merge another H5 object's data with the current one. 322 | 323 | Parameters 324 | ---------- 325 | other : H5 326 | Another H5 object whose data will be merged. 327 | """ 328 | self.root.update(other.root) 329 | 330 | def __getitem__(self, key: str) -> Any: 331 | """ 332 | Access data by key. 333 | 334 | Parameters 335 | ---------- 336 | key : str 337 | Key for accessing nested data. 338 | 339 | Returns 340 | ------- 341 | Any 342 | Retrieved data. 343 | """ 344 | key = key.lower() 345 | obj = self.root 346 | for i in key.split('/'): 347 | if i: 348 | obj = obj[i] 349 | return obj 350 | 351 | def __setitem__(self, key: str, value: Any) -> None: 352 | """ 353 | Set data by key. 
354 | 355 | Parameters 356 | ---------- 357 | key : str 358 | Key for accessing nested data. 359 | value : Any 360 | Value to set for the given key. 361 | """ 362 | key = key.lower() 363 | obj = self.root 364 | parts = key.split('/') 365 | for i in parts[:-1]: 366 | if i: 367 | obj = obj[i] 368 | obj[parts[-1]] = value 369 | 370 | 371 | def convert_from_numpy(data: Dict, func: Optional[callable] = None) -> Dict: 372 | """ 373 | Convert a dictionary with NumPy objects into native Python types. 374 | 375 | Parameters 376 | ---------- 377 | data : dict 378 | The input dictionary with potential NumPy types. 379 | func : callable 380 | A function to transform the keys. 381 | 382 | Returns 383 | ------- 384 | dict 385 | A dictionary with transformed keys and native Python types. 386 | """ 387 | ans = {} 388 | for key, item in data.items(): 389 | if func is not None: 390 | key = func(key) 391 | 392 | # Handle NumPy-specific types 393 | if isinstance(item, numpy.ndarray): 394 | item = item.tolist() 395 | elif isinstance(item, numpy.generic): 396 | item = item.item() 397 | 398 | # Handle bytes 399 | elif isinstance(item, bytes): 400 | item = item.decode('utf-8') 401 | 402 | # Recursive handling of nested dictionaries 403 | if isinstance(item, dict): # Assuming Dict is replaced with dict 404 | ans[key] = convert_from_numpy(item, func) 405 | else: 406 | ans[key] = item 407 | 408 | return ans 409 | 410 | 411 | def recursively_load_dict(data: dict, func: Optional[callable] = None) -> Dict: 412 | """ 413 | Recursively load data from a dictionary. 414 | 415 | Parameters 416 | ---------- 417 | data : dict 418 | Input dictionary to load. 419 | func : callable 420 | Transformation function for dictionary keys. 421 | 422 | Returns 423 | ------- 424 | Dict 425 | Dictionary with loaded data. 
426 | """ 427 | ans = Dict() 428 | for key, item in data.items(): 429 | if func is not None: 430 | key = func(key) 431 | 432 | if isinstance(item, dict): 433 | ans[key] = recursively_load_dict(item, func) 434 | else: 435 | # Handle bytes 436 | if isinstance(item, numpy.int32): 437 | item = int(item) 438 | elif isinstance(item, bytes): 439 | item = item.decode('utf-8') 440 | 441 | ans[key] = item 442 | return ans 443 | 444 | 445 | def set_path(obj: Dict[str, Any], path: str, value: Any) -> None: 446 | """ 447 | Set a value within a nested dictionary given a slash-separated path. 448 | 449 | Parameters 450 | ---------- 451 | obj : Dict[str, Any] 452 | Dictionary to set the value in. 453 | path : str 454 | Slash-separated path indicating where to set the value. 455 | value : Any 456 | Value to be set at the specified path. 457 | """ 458 | path_parts = [i for i in path.split('/') if i] 459 | 460 | temp = obj 461 | for part in path_parts[:-1]: 462 | if part not in temp or not isinstance(temp[part], dict): 463 | temp[part] = {} # Create intermediate dictionaries as needed 464 | temp = temp[part] 465 | 466 | value = recursively_load_dict(value) 467 | 468 | temp[path_parts[-1]] = value 469 | 470 | 471 | def recursively_load( 472 | h5file: h5py.File, 473 | path: str, 474 | func: callable, 475 | paths: Optional[list[str]] 476 | ) -> Dict: 477 | """ 478 | Recursively load data from an HDF5 file. 479 | 480 | Parameters 481 | ---------- 482 | h5file : h5py.File 483 | The HDF5 file to load data from. 484 | path : str 485 | Path within the HDF5 file. 486 | func : callable 487 | Transformation function for dictionary keys. 488 | paths : Optional[List[str]] 489 | Specific paths to load, or None to load everything. 490 | 491 | Returns 492 | ------- 493 | Dict 494 | Loaded data. 
495 | """ 496 | ans = Dict() 497 | if paths is not None: 498 | for path in paths: 499 | item = h5file.get(path, None) 500 | if item is not None: 501 | if isinstance(item, h5py._hl.dataset.Dataset): 502 | set_path(ans, path, item[()]) 503 | elif isinstance(item, h5py._hl.group.Group): 504 | set_path( 505 | ans, path, recursively_load(h5file, path + '/', func, None) 506 | ) 507 | else: 508 | for key_original in h5file[path].keys(): 509 | key = func(key_original) 510 | local_path = path + key 511 | item = h5file[path][key_original] 512 | if isinstance(item, h5py._hl.dataset.Dataset): 513 | ans[key] = item[()] 514 | elif isinstance(item, h5py._hl.group.Group): 515 | ans[key] = recursively_load(h5file, local_path + '/', func, None) 516 | return ans 517 | 518 | 519 | def recursively_save(h5file: h5py.File, path: str, dic: Dict, func: callable) -> None: 520 | """ 521 | Recursively save data to an HDF5 file. 522 | 523 | Parameters 524 | ---------- 525 | h5file : h5py.File 526 | The HDF5 file to save data to. 527 | path : str 528 | Path within the HDF5 file. 529 | dic : Dict 530 | Dictionary of data to save. 531 | func : callable 532 | Transformation function for dictionary keys. 533 | 534 | Raises 535 | ------ 536 | ValueError 537 | If path or h5file types are invalid, or if the dictionary contains unsupported 538 | data types. 
539 | """ 540 | if not isinstance(path, str): 541 | raise ValueError("path must be a string") 542 | if not isinstance(h5file, h5py._hl.files.File): 543 | raise ValueError("must be an open h5py file") 544 | if not isinstance(dic, dict): 545 | raise ValueError("must provide a dictionary") 546 | 547 | for key, item in dic.items(): 548 | key = str(key) 549 | 550 | if item is None: 551 | continue 552 | 553 | if not isinstance(key, str): 554 | raise ValueError("dict keys must be strings to save to hdf5") 555 | 556 | if isinstance(item, dict): 557 | recursively_save(h5file, path + key + '/', item, func) 558 | continue 559 | elif isinstance(item, str): 560 | value = numpy.array(item.encode('utf-8')) 561 | elif isinstance(item, list) and all(isinstance(i, str) for i in item): 562 | value = numpy.array([i.encode('utf-8') for i in item]) 563 | else: 564 | try: 565 | value = numpy.array(item) 566 | except TypeError: 567 | raise ValueError( 568 | f'Cannot save {path}/{func(key)} key with {type(item)} type.' 569 | ) 570 | 571 | try: 572 | h5file[path + func(key)] = value 573 | except OSError as e: 574 | if str(e) == 'Unable to create link (name already exists)': 575 | raise KeyError( 576 | 'Name conflict with upper and lower case entries for key ' 577 | f'"{path}{key}".' 578 | ) 579 | else: 580 | raise 581 | 582 | 583 | def recursively_turn_dict_to_python_list(dictionary: dict, current_lines_list: list = None, prefix: str = None): 584 | """ 585 | Recursively convert a nested dictionary (including addict.Dict) into a list of Python code lines 586 | that can regenerate the original nested structure. 587 | 588 | Parameters 589 | ---------- 590 | dictionary : dict 591 | The nested dictionary or addict.Dict to convert. 592 | current_lines_list : list, optional 593 | A list that accumulates the Python code lines as the recursion progresses. 594 | If None, a new list is created. 
595 | prefix : str, optional 596 | A prefix used to build fully-qualified variable names representing nested keys. 597 | 598 | Returns 599 | ------- 600 | list of str 601 | List of Python code lines that, when executed, recreate the nested dictionary. 602 | """ 603 | 604 | def merge_to_absolute_key(prefix, key): 605 | """ 606 | Combine prefix and key into a dot-separated path unless the prefix is None. 607 | 608 | Parameters 609 | ---------- 610 | prefix : str or None 611 | The existing path prefix. 612 | key : str 613 | The current key to append. 614 | 615 | Returns 616 | ------- 617 | str 618 | Dot-separated key path if prefix is not None; otherwise, the key itself. 619 | """ 620 | if prefix is None: 621 | return key 622 | else: 623 | return f"{prefix}.{key}" 624 | 625 | def clean_up_key(absolute_key: str): 626 | """ 627 | Sanitize a key path by replacing problematic substrings like '.return'. 628 | 629 | Parameters 630 | ---------- 631 | absolute_key : str 632 | A dot-separated key path. 633 | 634 | Returns 635 | ------- 636 | str 637 | A cleaned key path with special keywords properly escaped. 638 | """ 639 | absolute_key = absolute_key.replace(".return", "['return']") 640 | return absolute_key 641 | 642 | def get_pythonic_representation_of_value(value): 643 | """ 644 | Convert a value to a Python code representation, with NumPy-style modifications. 645 | 646 | Parameters 647 | ---------- 648 | value : any 649 | The value to be represented. 650 | 651 | Returns 652 | ------- 653 | str 654 | A string representation using `repr()`, with `array` replaced by `np.array`. 
655 | """ 656 | if isinstance(value, np.ndarray): 657 | if len(value) > 1e7: 658 | raise ValueError("Array is too long to be serialized") 659 | value_representation = np.array2string(value, separator=',', threshold=int(1e7)) 660 | value_representation = f"np.array({value_representation})" 661 | else: 662 | value_representation = repr(value) 663 | return value_representation 664 | 665 | if current_lines_list is None: 666 | current_lines_list = [] 667 | 668 | for key in sorted(dictionary.keys()): 669 | value = dictionary[key] 670 | 671 | absolute_key = merge_to_absolute_key(prefix, key) 672 | 673 | if type(value) in (dict, Dict): 674 | current_lines_list = recursively_turn_dict_to_python_list( 675 | value, 676 | current_lines_list, 677 | prefix=absolute_key 678 | ) 679 | else: 680 | value_representation = get_pythonic_representation_of_value(value) 681 | 682 | absolute_key = clean_up_key(absolute_key) 683 | 684 | current_lines_list.append(f"{absolute_key} = {value_representation}") 685 | 686 | return current_lines_list 687 | -------------------------------------------------------------------------------- /tests/test_dll.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import subprocess 3 | import pytest 4 | 5 | from cadet import Cadet 6 | 7 | 8 | # %% Utility methods 9 | 10 | # Use this to specify custom cadet_roots if you require it. 11 | cadet_root = None 12 | 13 | 14 | def setup_model( 15 | cadet_root, 16 | use_dll=True, 17 | model='GENERAL_RATE_MODEL', 18 | n_partypes=1, 19 | ncol=10, 20 | npar=4, 21 | include_sensitivity=False, 22 | file_name='LWE.h5', 23 | n_components=4 24 | ): 25 | """ 26 | Set up and initialize a CADET model template. 27 | 28 | This function prepares a CADET model template by invoking the `createLWE` executable 29 | with specified parameters. 
It supports the configuration of the model type, number 30 | of particle types, inclusion of sensitivity analysis, and the name of the output 31 | file. Depending on the operating system, it adjusts the executable name accordingly. 32 | After creating the model, it initializes a Cadet instance with the specified or 33 | default CADET binary and the created model file. 34 | 35 | Parameters 36 | ---------- 37 | cadet_root : str or Path 38 | The root directory where the CADET software is located. 39 | use_dll : bool, optional 40 | If True, use the in-memory interface for CADET. Otherwise, use the CLI. 41 | The default is True. 42 | model : str, optional 43 | The model type to set up. The default is 'GENERAL_RATE_MODEL'. 44 | n_partypes : int, optional 45 | The number of particle types. The default is 1. 46 | ncol : int, optional 47 | The number of axial cells in the unit operation. The default is 10. 48 | npar : int, optional 49 | The number of particle cells in the unit operation. The default is 4. 50 | include_sensitivity : bool, optional 51 | If True, included parameter sensitivities in template. The default is False. 52 | file_name : str, optional 53 | The name of the file to which the CADET model is written. 54 | The default is 'LWE.h5'. 55 | n_components : int, optional 56 | Number of components for the simulation. The default is 4. 57 | 58 | Returns 59 | ------- 60 | Cadet 61 | An initialized Cadet instance with the model loaded. 62 | 63 | Raises 64 | ------ 65 | Exception 66 | If the creation of the test simulation encounters problems, 67 | detailed in the subprocess's stdout and stderr. 68 | FileNotFoundError 69 | If the CADET executable or DLL file cannot be found at the specified paths. 70 | 71 | Notes 72 | ----- 73 | The function assumes the presence of `createLWE` executable within the `bin` 74 | directory of the `cadet_root` path. The sensitivity analysis, if included, is 75 | configured for column porosity. 
76 | 77 | See Also 78 | -------- 79 | Cadet : The class representing a CADET simulation model. 80 | 81 | Examples 82 | -------- 83 | >>> cadet_model = setup_model( 84 | '/path/to/cadet', 85 | use_dll=False, 86 | model='GENERAL_RATE_MODEL', 87 | n_partypes=2, 88 | include_sensitivity=True, 89 | file_name='my_model.h5' 90 | ) 91 | This example sets up a GENERAL_RATE_MODEL with 2 particle types, includes 92 | sensitivity analysis, and writes the model to 'my_model.h5', using the command-line 93 | interface. 94 | """ 95 | 96 | cadet_model = Cadet(install_path=cadet_root, use_dll=use_dll) 97 | 98 | args = [ 99 | cadet_model.cadet_create_lwe_path, 100 | f'--out {file_name}', 101 | f'--unit {model}', 102 | f'--parTypes {n_partypes}', 103 | f'--col {ncol}', 104 | f'--par {npar}', 105 | ] 106 | 107 | if include_sensitivity: 108 | args.append('--sens COL_POROSITY/-1/-1/-1/-1/-1/-1/0') 109 | 110 | ret = subprocess.run( 111 | args, 112 | stdout=subprocess.PIPE, 113 | stderr=subprocess.PIPE, 114 | cwd='./' 115 | ) 116 | 117 | if ret.returncode != 0: 118 | if ret.stdout: 119 | print('Output', ret.stdout.decode('utf-8')) 120 | if ret.stderr: 121 | print('Errors', ret.stderr.decode('utf-8')) 122 | raise Exception( 123 | "Failure: Creation of test simulation ran into problems" 124 | ) 125 | 126 | cadet_model.filename = file_name 127 | cadet_model.load_from_file() 128 | if n_components < 4: 129 | unit_000 = cadet_model.root.input.model.unit_000 130 | unit_000.update({ 131 | 'adsorption_model': 'LINEAR', 132 | 'col_dispersion': 5.75e-08, 133 | 'col_length': 0.014, 134 | 'col_porosity': 0.37, 135 | 'cross_section_area': 0.0003141592653589793, 136 | 'film_diffusion': [6.9e-06, ] * n_components, 137 | 'film_diffusion_multiplex': 0, 138 | 'init_c': [0., ] * n_components, 139 | 'init_q': [0., ] * n_components, 140 | 'nbound': [1, ] * n_components, 141 | 'ncomp': 1, 142 | 'npartype': 1, 143 | 'par_coreradius': 0.0, 144 | 'par_diffusion': [7.00e-10, ] * n_components, 145 | 'par_geom': 
'SPHERE', 146 | 'par_porosity': 0.75, 147 | 'par_radius': 4.5e-05, 148 | 'par_surfdiffusion': [0., ] * n_components, 149 | 'unit_type': 'GENERAL_RATE_MODEL', 150 | 'velocity': 1.0, 151 | 'adsorption': { 152 | 'is_kinetic': 0, 153 | 'lin_ka': [0.] * n_components, 154 | 'lin_kd': [1.] * n_components 155 | }, 156 | }) 157 | cadet_model.root.input.model.unit_001.update({ 158 | 'inlet_type': b'PIECEWISE_CUBIC_POLY', 159 | 'ncomp': 1, 'unit_type': b'INLET', 160 | 'sec_000': { 161 | 'const_coeff': [50., ], 162 | 'cube_coeff': [0., ], 163 | 'lin_coeff': [0., ], 164 | 'quad_coeff': [0., ] 165 | }, 166 | 'sec_001': { 167 | 'const_coeff': [50., ], 168 | 'cube_coeff': [0., ], 169 | 'lin_coeff': [0., ], 170 | 'quad_coeff': [0., ] 171 | }, 172 | 'sec_002': { 173 | 'const_coeff': [100., ], 174 | 'cube_coeff': [0.2, ], 175 | 'lin_coeff': [0., ], 176 | 'quad_coeff': [0., ] 177 | } 178 | } 179 | ) 180 | # if we don't save and re-load the model we get windows access violations. 181 | # Interesting case for future tests, not what I want to test now. 182 | cadet_model.save() 183 | cadet_model = Cadet(install_path=cadet_root, use_dll=use_dll) 184 | cadet_model.filename = file_name 185 | cadet_model.load_from_file() 186 | 187 | return cadet_model 188 | 189 | 190 | def setup_solution_recorder( 191 | model, 192 | split_components=0, 193 | split_ports=0, 194 | single_as_multi_port=0, 195 | ): 196 | """ 197 | Configure the solution recorder for the simulation. 198 | 199 | This function adjusts the model's settings to specify what simulation data should 200 | be recorded, including solutions at various points (inlet, outlet, bulk, etc.), 201 | sensitivities, and their derivatives. It allows for the configuration of how 202 | components and ports are treated in the output data, potentially splitting them for 203 | individual analysis or aggregating them for a more holistic view. 
204 | 205 | Parameters 206 | ---------- 207 | model : Cadet 208 | The model instance to be configured for solution recording. 209 | split_components : int, optional 210 | If 1, split component data in the output. The default is 0. 211 | split_ports : int, optional 212 | If 1, split port data in the output. The default is 0. 213 | single_as_multi_port : int, optional 214 | If 1, treat single ports as multiple ports in the output. The default is 0. 215 | 216 | Examples 217 | -------- 218 | >>> model = Cadet() 219 | >>> setup_solution_recorder(model, split_components=1, split_ports=1, single_as_multi_port=1) 220 | This example demonstrates configuring a Cadet model instance for detailed solution 221 | recording, with component and port data split, and single ports treated as multiple 222 | ports. 223 | 224 | """ 225 | # Recorder flags that apply to the whole simulation (not to a single unit). 226 | model.root.input['return'].write_solution_times = 1 227 | model.root.input['return'].write_solution_last = 1 228 | model.root.input['return'].write_sens_last = 1 229 | # Output layout: optional splitting by component and by port. 230 | model.root.input['return'].split_components_data = split_components 231 | model.root.input['return'].split_ports_data = split_ports 232 | model.root.input['return'].single_as_multi_port = single_as_multi_port 233 | # unit_000 settings below serve as the template applied to every unit at the end. 234 | model.root.input['return'].unit_000.write_coordinates = 1 235 | # solution_*: the solution itself. 236 | model.root.input['return'].unit_000.write_solution_inlet = 1 237 | model.root.input['return'].unit_000.write_solution_outlet = 1 238 | model.root.input['return'].unit_000.write_solution_bulk = 1 239 | model.root.input['return'].unit_000.write_solution_particle = 1 240 | model.root.input['return'].unit_000.write_solution_solid = 1 241 | model.root.input['return'].unit_000.write_solution_flux = 1 242 | model.root.input['return'].unit_000.write_solution_volume = 1 243 | # soldot_*: derivatives of the solution at the same locations. 244 | model.root.input['return'].unit_000.write_soldot_inlet = 1 245 | model.root.input['return'].unit_000.write_soldot_outlet = 1 246 | model.root.input['return'].unit_000.write_soldot_bulk = 1 247 |
model.root.input['return'].unit_000.write_soldot_particle = 1 248 | model.root.input['return'].unit_000.write_soldot_solid = 1 249 | model.root.input['return'].unit_000.write_soldot_flux = 1 250 | model.root.input['return'].unit_000.write_soldot_volume = 1 251 | # sens_*: parameter sensitivities at the same locations. 252 | model.root.input['return'].unit_000.write_sens_inlet = 1 253 | model.root.input['return'].unit_000.write_sens_outlet = 1 254 | model.root.input['return'].unit_000.write_sens_bulk = 1 255 | model.root.input['return'].unit_000.write_sens_particle = 1 256 | model.root.input['return'].unit_000.write_sens_solid = 1 257 | model.root.input['return'].unit_000.write_sens_flux = 1 258 | model.root.input['return'].unit_000.write_sens_volume = 1 259 | # sensdot_*: derivatives of the sensitivities. 260 | model.root.input['return'].unit_000.write_sensdot_inlet = 1 261 | model.root.input['return'].unit_000.write_sensdot_outlet = 1 262 | model.root.input['return'].unit_000.write_sensdot_bulk = 1 263 | model.root.input['return'].unit_000.write_sensdot_particle = 1 264 | model.root.input['return'].unit_000.write_sensdot_solid = 1 265 | model.root.input['return'].unit_000.write_sensdot_flux = 1 266 | model.root.input['return'].unit_000.write_sensdot_volume = 1 267 | # Per-unit last state (last_state_y / last_state_ydot in the results). 268 | model.root.input['return'].unit_000.write_solution_last_unit = 1 269 | model.root.input['return'].unit_000.write_soldot_last_unit = 1 270 | # Apply the template to every unit. NOTE(review): appears to store the same unit_000 object under each key rather than a copy. 271 | for unit in range(model.root.input.model['nunits']): 272 | model.root.input['return']['unit_{0:03d}'.format(unit)] = model.root.input['return'].unit_000 273 | # Persist the recorder settings only when a target file has been set. 274 | if model.filename is not None: 275 | model.save() 276 | 277 | 278 | def run_simulation_with_options(use_dll, model_options, solution_recorder_options): 279 | """Run a simulation with specified options for the model and solution recorder. 280 | 281 | Initializes and configures a simulation model with given options, sets up the 282 | solution recording parameters, and executes the simulation.
This function leverages 283 | `setup_model` to create and initialize the model and `setup_solution_recorder` to 284 | configure how the simulation results should be recorded based on the specified 285 | options. 286 | 287 | Parameters 288 | ---------- 289 | use_dll : bool, optional 290 | If True, use the in-memory interface for CADET. Otherwise, use the CLI. 291 | The default is True. 292 | model_options : dict 293 | A dictionary of options to pass to `setup_model` for initializing the model. 294 | Keys should match the parameter names of `setup_model`, excluding `use_dll`. 295 | solution_recorder_options : dict 296 | A dictionary of options to pass to `setup_solution_recorder` for configuring the 297 | solution recorder. Keys should match the parameter names of 298 | `setup_solution_recorder`. 299 | 300 | Returns 301 | ------- 302 | Cadet 303 | An instance of the Cadet class with the model simulated and loaded. 304 | 305 | Examples 306 | -------- 307 | >>> use_dll = True 308 | >>> model_options = { 309 | ... 'model': 'GENERAL_RATE_MODEL', 310 | ... 'n_partypes': 2, 311 | ... 'include_sensitivity': True, 312 | ... 'file_name': 'model_output.h5' 313 | ... } 314 | >>> solution_recorder_options = { 315 | ... 'split_components': 1, 316 | ... 'split_ports': 1, 317 | ... 'single_as_multi_port': True 318 | ... } 319 | >>> model = run_simulation_with_options(use_dll, model_options, solution_recorder_options) 320 | This example configures and runs a GENERAL_RATE_MODEL with sensitivity analysis 321 | and two particle types, records the solution with specific options, and loads the 322 | simulation results for further analysis. 
323 | """ 324 | model = setup_model(cadet_root, use_dll, **model_options) 325 | setup_solution_recorder(model, **solution_recorder_options) 326 | 327 | return_info = model.run_simulation() 328 | 329 | if return_info.return_code != 0: 330 | raise RuntimeError(return_info) 331 | 332 | return model 333 | 334 | 335 | # %% Model templates 336 | 337 | cstr_template = { 338 | 'model': 'CSTR', 339 | 'n_partypes': 1, 340 | 'include_sensitivity': False, 341 | } 342 | 343 | lrm_template = { 344 | 'model': 'LUMPED_RATE_MODEL_WITHOUT_PORES', 345 | 'ncol': 10, 346 | 'n_partypes': 1, 347 | 'include_sensitivity': False, 348 | } 349 | 350 | lrmp_template = { 351 | 'model': 'LUMPED_RATE_MODEL_WITH_PORES', 352 | 'ncol': 10, 353 | 'n_partypes': 1, 354 | 'include_sensitivity': False, 355 | } 356 | 357 | grm_template = { 358 | 'model': 'GENERAL_RATE_MODEL', 359 | 'ncol': 10, 360 | 'npar': 5, 361 | 'n_partypes': 1, 362 | 'include_sensitivity': False, 363 | } 364 | 365 | grm_template_1_comp = { 366 | 'model': 'GENERAL_RATE_MODEL', 367 | 'n_partypes': 1, 368 | 'ncol': 10, 369 | 'npar': 5, 370 | 'include_sensitivity': False, 371 | 'n_components': 1, 372 | } 373 | 374 | grm_template_sens = { 375 | 'model': 'GENERAL_RATE_MODEL', 376 | 'n_partypes': 1, 377 | 'ncol': 10, 378 | 'npar': 5, 379 | 'include_sensitivity': True, 380 | } 381 | 382 | grm_template_partypes = { 383 | 'model': 'GENERAL_RATE_MODEL', 384 | 'n_partypes': 2, 385 | 'ncol': 10, 386 | 'npar': 5, 387 | 'include_sensitivity': False, 388 | } 389 | 390 | _2dgrm_template = { 391 | 'model': 'GENERAL_RATE_MODEL_2D', 392 | 'n_partypes': 1, 393 | 'ncol': 10, 394 | 'npar': 5, 395 | 'include_sensitivity': False, 396 | } 397 | 398 | 399 | # %% Solution recorder templates 400 | 401 | no_split = { 402 | 'split_components': 0, 403 | 'split_ports': 0, 404 | 'single_as_multi_port': 0, 405 | } 406 | 407 | split_components = { 408 | 'split_components': 1, 409 | 'split_ports': 0, 410 | 'single_as_multi_port': 0, 411 | } 412 | 413 | split_ports = { 
414 | 'split_components': 0, 415 | 'split_ports': 1, 416 | 'single_as_multi_port': 0, 417 | } 418 | 419 | split_ports_single_as_multi = { 420 | 'split_components': 0, 421 | 'split_ports': 1, 422 | 'single_as_multi_port': 1, 423 | } 424 | 425 | split_all = { 426 | 'split_components': 1, 427 | 'split_ports': 1, 428 | 'single_as_multi_port': 1, 429 | } 430 | 431 | 432 | # %% Test cases 433 | 434 | class Case(): 435 | def __init__(self, name, model_options, solution_recorder_options, expected_results): 436 | self.name = name 437 | self.model_options = model_options 438 | self.solution_recorder_options = solution_recorder_options 439 | self.expected_results = expected_results 440 | 441 | def __str__(self): 442 | return self.name 443 | 444 | def __repr__(self): 445 | return \ 446 | f"Case('{self.name}', {self.model_options}, " \ 447 | f"{self.solution_recorder_options}, {self.expected_results})" 448 | 449 | 450 | # %% CSTR 451 | 452 | cstr = Case( 453 | name='cstr', 454 | model_options=cstr_template, 455 | solution_recorder_options=no_split, 456 | expected_results={ 457 | 'solution_times': (1501,), 458 | 'last_state_y': (21,), 459 | 'last_state_ydot': (21,), 460 | 'coordinates_unit_000': {}, 461 | 'coordinates_unit_001': {}, 462 | 'solution_unit_000': { 463 | 'last_state_y': (13,), 464 | 'last_state_ydot': (13,), 465 | 466 | 'solution_inlet': (1501, 4), 467 | 'solution_outlet': (1501, 4), 468 | 'solution_bulk': (1501, 4), 469 | 'solution_solid': (1501, 4), 470 | 'solution_volume': (1501,), 471 | 472 | 'soldot_inlet': (1501, 4), 473 | 'soldot_outlet': (1501, 4), 474 | 'soldot_bulk': (1501, 4), 475 | 'soldot_solid': (1501, 4), 476 | 'soldot_volume': (1501,), 477 | }, 478 | 'solution_unit_001': { 479 | 'last_state_y': (4,), 480 | 'last_state_ydot': (4,), 481 | 482 | 'solution_inlet': (1501, 4), 483 | 'solution_outlet': (1501, 4), 484 | 485 | 'soldot_inlet': (1501, 4), 486 | 'soldot_outlet': (1501, 4), 487 | }, 488 | }, 489 | ) 490 | 491 | 492 | # %% LRM 493 | 494 | lrm = 
Case( 495 | name='lrm', 496 | model_options=lrm_template, 497 | solution_recorder_options=no_split, 498 | expected_results={ 499 | 'solution_times': (1501,), 500 | 'last_state_y': (92,), 501 | 'last_state_ydot': (92,), 502 | 'coordinates_unit_000': { 503 | 'axial_coordinates': (10,), 504 | }, 505 | 'coordinates_unit_001': {}, 506 | 'solution_unit_000': { 507 | 'last_state_y': (84,), 508 | 'last_state_ydot': (84,), 509 | 510 | 'solution_inlet': (1501, 4), 511 | 'solution_outlet': (1501, 4), 512 | 'solution_bulk': (1501, 10, 4), 513 | 'solution_solid': (1501, 10, 4), 514 | 515 | 'soldot_inlet': (1501, 4), 516 | 'soldot_outlet': (1501, 4), 517 | 'soldot_bulk': (1501, 10, 4), 518 | 'soldot_solid': (1501, 10, 4), 519 | }, 520 | 'solution_unit_001': { 521 | 'last_state_y': (4,), 522 | 'last_state_ydot': (4,), 523 | 524 | 'solution_inlet': (1501, 4), 525 | 'solution_outlet': (1501, 4), 526 | 527 | 'soldot_inlet': (1501, 4), 528 | 'soldot_outlet': (1501, 4), 529 | 530 | }, 531 | }, 532 | ) 533 | 534 | 535 | # %% LRMP 536 | 537 | lrmp = Case( 538 | name='lrmp', 539 | model_options=lrmp_template, 540 | solution_recorder_options=no_split, 541 | expected_results={ 542 | 'solution_times': (1501,), 543 | 'last_state_y': (172,), 544 | 'last_state_ydot': (172,), 545 | 'coordinates_unit_000': { 546 | 'axial_coordinates': (10,), 547 | }, 548 | 'coordinates_unit_001': {}, 549 | 'solution_unit_000': { 550 | 'last_state_y': (164,), 551 | 'last_state_ydot': (164,), 552 | 553 | 'solution_inlet': (1501, 4), 554 | 'solution_outlet': (1501, 4), 555 | 'solution_bulk': (1501, 10, 4), 556 | 'solution_particle': (1501, 10, 4), 557 | 'solution_solid': (1501, 10, 4), 558 | 'solution_flux': (1501, 1, 10, 4), 559 | 560 | 'soldot_inlet': (1501, 4), 561 | 'soldot_outlet': (1501, 4), 562 | 'soldot_bulk': (1501, 10, 4), 563 | 'soldot_particle': (1501, 10, 4), 564 | 'soldot_solid': (1501, 10, 4), 565 | 'soldot_flux': (1501, 1, 10, 4), 566 | }, 567 | 'solution_unit_001': { 568 | 'last_state_y': (4,), 569 
| 'last_state_ydot': (4,), 570 | 571 | 'solution_inlet': (1501, 4), 572 | 'solution_outlet': (1501, 4), 573 | 574 | 'soldot_inlet': (1501, 4), 575 | 'soldot_outlet': (1501, 4), 576 | }, 577 | }, 578 | ) 579 | 580 | 581 | # %% GRM (no_split) 582 | 583 | grm = Case( 584 | name='grm', 585 | model_options=grm_template, 586 | solution_recorder_options=no_split, 587 | expected_results={ 588 | 'solution_times': (1501,), 589 | 'last_state_y': (492,), 590 | 'last_state_ydot': (492,), 591 | 'coordinates_unit_000': { 592 | 'axial_coordinates': (10,), 593 | 'particle_coordinates_000': (5,), 594 | }, 595 | 'coordinates_unit_001': {}, 596 | 'solution_unit_000': { 597 | 'last_state_y': (484,), 598 | 'last_state_ydot': (484,), 599 | 600 | 'solution_inlet': (1501, 4), 601 | 'solution_outlet': (1501, 4), 602 | 'solution_bulk': (1501, 10, 4), 603 | 'solution_particle': (1501, 10, 5, 4), 604 | 'solution_solid': (1501, 10, 5, 4), 605 | 'solution_flux': (1501, 1, 10, 4), 606 | 607 | 'soldot_inlet': (1501, 4), 608 | 'soldot_outlet': (1501, 4), 609 | 'soldot_bulk': (1501, 10, 4), 610 | 'soldot_particle': (1501, 10, 5, 4), 611 | 'soldot_solid': (1501, 10, 5, 4), 612 | 'soldot_flux': (1501, 1, 10, 4), 613 | }, 614 | 'solution_unit_001': { 615 | 'last_state_y': (4,), 616 | 'last_state_ydot': (4,), 617 | 618 | 'solution_inlet': (1501, 4), 619 | 'solution_outlet': (1501, 4), 620 | 621 | 'soldot_inlet': (1501, 4), 622 | 'soldot_outlet': (1501, 4), 623 | }, 624 | }, 625 | ) 626 | 627 | 628 | # %% GRM (split_components) 629 | 630 | grm_split_components = Case( 631 | name='grm_split_components', 632 | model_options=grm_template, 633 | solution_recorder_options=split_components, 634 | expected_results={ 635 | 'solution_times': (1501,), 636 | 'last_state_y': (492,), 637 | 'last_state_ydot': (492,), 638 | 'coordinates_unit_000': { 639 | 'axial_coordinates': (10,), 640 | 'particle_coordinates_000': (5,), 641 | }, 642 | 'coordinates_unit_001': {}, 643 | 'solution_unit_000': { 644 | 'last_state_y': 
(484,), 645 | 'last_state_ydot': (484,), 646 | 647 | 'solution_inlet_comp_000': (1501,), 648 | 'solution_inlet_comp_001': (1501,), 649 | 'solution_inlet_comp_002': (1501,), 650 | 'solution_inlet_comp_003': (1501,), 651 | 'solution_outlet_comp_000': (1501,), 652 | 'solution_outlet_comp_001': (1501,), 653 | 'solution_outlet_comp_002': (1501,), 654 | 'solution_outlet_comp_003': (1501,), 655 | 'solution_bulk': (1501, 10, 4), 656 | 'solution_particle': (1501, 10, 5, 4), 657 | 'solution_solid': (1501, 10, 5, 4), 658 | 'solution_flux': (1501, 1, 10, 4), 659 | 660 | 'soldot_inlet_comp_000': (1501,), 661 | 'soldot_inlet_comp_001': (1501,), 662 | 'soldot_inlet_comp_002': (1501,), 663 | 'soldot_inlet_comp_003': (1501,), 664 | 'soldot_outlet_comp_000': (1501,), 665 | 'soldot_outlet_comp_001': (1501,), 666 | 'soldot_outlet_comp_002': (1501,), 667 | 'soldot_outlet_comp_003': (1501,), 668 | 'soldot_bulk': (1501, 10, 4), 669 | 'soldot_particle': (1501, 10, 5, 4), 670 | 'soldot_solid': (1501, 10, 5, 4), 671 | 'soldot_flux': (1501, 1, 10, 4), 672 | }, 673 | 'solution_unit_001': { 674 | 'last_state_y': (4,), 675 | 'last_state_ydot': (4,), 676 | 677 | 'solution_inlet_comp_000': (1501,), 678 | 'solution_inlet_comp_001': (1501,), 679 | 'solution_inlet_comp_002': (1501,), 680 | 'solution_inlet_comp_003': (1501,), 681 | 'solution_outlet_comp_000': (1501,), 682 | 'solution_outlet_comp_001': (1501,), 683 | 'solution_outlet_comp_002': (1501,), 684 | 'solution_outlet_comp_003': (1501,), 685 | 686 | 'soldot_inlet_comp_000': (1501,), 687 | 'soldot_inlet_comp_001': (1501,), 688 | 'soldot_inlet_comp_002': (1501,), 689 | 'soldot_inlet_comp_003': (1501,), 690 | 'soldot_outlet_comp_000': (1501,), 691 | 'soldot_outlet_comp_001': (1501,), 692 | 'soldot_outlet_comp_002': (1501,), 693 | 'soldot_outlet_comp_003': (1501,), 694 | }, 695 | }, 696 | ) 697 | 698 | 699 | # %% GRM (split_ports) 700 | 701 | grm_split_ports = Case( 702 | name='grm_split_ports', 703 | model_options=grm_template, 704 | 
solution_recorder_options=split_ports, 705 | expected_results={ 706 | 'solution_times': (1501,), 707 | 'last_state_y': (492,), 708 | 'last_state_ydot': (492,), 709 | 'coordinates_unit_000': { 710 | 'axial_coordinates': (10,), 711 | 'particle_coordinates_000': (5,), 712 | }, 713 | 'coordinates_unit_001': {}, 714 | 'solution_unit_000': { 715 | 'last_state_y': (484,), 716 | 'last_state_ydot': (484,), 717 | 718 | 'solution_inlet': (1501, 4), 719 | 'solution_outlet': (1501, 4), 720 | 'solution_bulk': (1501, 10, 4), 721 | 'solution_particle': (1501, 10, 5, 4), 722 | 'solution_solid': (1501, 10, 5, 4), 723 | 'solution_flux': (1501, 1, 10, 4), 724 | 725 | 'soldot_inlet': (1501, 4), 726 | 'soldot_outlet': (1501, 4), 727 | 'soldot_bulk': (1501, 10, 4), 728 | 'soldot_particle': (1501, 10, 5, 4), 729 | 'soldot_solid': (1501, 10, 5, 4), 730 | 'soldot_flux': (1501, 1, 10, 4), 731 | }, 732 | 'solution_unit_001': { 733 | 'last_state_y': (4,), 734 | 'last_state_ydot': (4,), 735 | 736 | 'solution_inlet': (1501, 4), 737 | 'solution_outlet': (1501, 4), 738 | 739 | 'soldot_inlet': (1501, 4), 740 | 'soldot_outlet': (1501, 4), 741 | }, 742 | }, 743 | ) 744 | 745 | # %% GRM (split_ports_single_as_multi) 746 | 747 | grm_split_ports_single_as_multi = Case( 748 | name='grm_split_ports_single_as_multi', 749 | model_options=grm_template, 750 | solution_recorder_options=split_ports_single_as_multi, 751 | expected_results={ 752 | 'solution_times': (1501,), 753 | 'last_state_y': (492,), 754 | 'last_state_ydot': (492,), 755 | 'coordinates_unit_000': { 756 | 'axial_coordinates': (10,), 757 | 'particle_coordinates_000': (5,), 758 | }, 759 | 'coordinates_unit_001': {}, 760 | 'solution_unit_000': { 761 | 'last_state_y': (484,), 762 | 'last_state_ydot': (484,), 763 | 764 | 'solution_inlet_port_000': (1501, 4), 765 | 'solution_outlet_port_000': (1501, 4), 766 | 'solution_bulk': (1501, 10, 4), 767 | 'solution_particle': (1501, 10, 5, 4), 768 | 'solution_solid': (1501, 10, 5, 4), 769 | 'solution_flux': 
(1501, 1, 10, 4), 770 | 771 | 'soldot_inlet_port_000': (1501, 4), 772 | 'soldot_outlet_port_000': (1501, 4), 773 | 'soldot_bulk': (1501, 10, 4), 774 | 'soldot_particle': (1501, 10, 5, 4), 775 | 'soldot_solid': (1501, 10, 5, 4), 776 | 'soldot_flux': (1501, 1, 10, 4), 777 | }, 778 | 'solution_unit_001': { 779 | 'last_state_y': (4,), 780 | 'last_state_ydot': (4,), 781 | 782 | 'solution_inlet_port_000': (1501, 4), 783 | 'solution_outlet_port_000': (1501, 4), 784 | 785 | 'soldot_inlet_port_000': (1501, 4), 786 | 'soldot_outlet_port_000': (1501, 4), 787 | }, 788 | }, 789 | ) 790 | 791 | 792 | # %% GRM (split_all) 793 | 794 | grm_split_all = Case( 795 | name='grm_split_all', 796 | model_options=grm_template, 797 | solution_recorder_options=split_all, 798 | expected_results={ 799 | 'solution_times': (1501,), 800 | 'last_state_y': (492,), 801 | 'last_state_ydot': (492,), 802 | 'coordinates_unit_000': { 803 | 'axial_coordinates': (10,), 804 | 'particle_coordinates_000': (5,), 805 | }, 806 | 'coordinates_unit_001': {}, 807 | 'solution_unit_000': { 808 | 'last_state_y': (484,), 809 | 'last_state_ydot': (484,), 810 | 811 | 'solution_inlet_port_000_comp_000': (1501,), 812 | 'solution_inlet_port_000_comp_001': (1501,), 813 | 'solution_inlet_port_000_comp_002': (1501,), 814 | 'solution_inlet_port_000_comp_003': (1501,), 815 | 'solution_outlet_port_000_comp_000': (1501,), 816 | 'solution_outlet_port_000_comp_001': (1501,), 817 | 'solution_outlet_port_000_comp_002': (1501,), 818 | 'solution_outlet_port_000_comp_003': (1501,), 819 | 'solution_bulk': (1501, 10, 4), 820 | 'solution_particle': (1501, 10, 5, 4), 821 | 'solution_solid': (1501, 10, 5, 4), 822 | 'solution_flux': (1501, 1, 10, 4), 823 | 824 | 'soldot_inlet_port_000_comp_000': (1501,), 825 | 'soldot_inlet_port_000_comp_001': (1501,), 826 | 'soldot_inlet_port_000_comp_002': (1501,), 827 | 'soldot_inlet_port_000_comp_003': (1501,), 828 | 'soldot_outlet_port_000_comp_000': (1501,), 829 | 'soldot_outlet_port_000_comp_001': 
(1501,), 830 | 'soldot_outlet_port_000_comp_002': (1501,), 831 | 'soldot_outlet_port_000_comp_003': (1501,), 832 | 'soldot_bulk': (1501, 10, 4), 833 | 'soldot_particle': (1501, 10, 5, 4), 834 | 'soldot_solid': (1501, 10, 5, 4), 835 | 'soldot_flux': (1501, 1, 10, 4), 836 | }, 837 | 'solution_unit_001': { 838 | 'last_state_y': (4,), 839 | 'last_state_ydot': (4,), 840 | 841 | 'solution_inlet_port_000_comp_000': (1501,), 842 | 'solution_inlet_port_000_comp_001': (1501,), 843 | 'solution_inlet_port_000_comp_002': (1501,), 844 | 'solution_inlet_port_000_comp_003': (1501,), 845 | 'solution_outlet_port_000_comp_000': (1501,), 846 | 'solution_outlet_port_000_comp_001': (1501,), 847 | 'solution_outlet_port_000_comp_002': (1501,), 848 | 'solution_outlet_port_000_comp_003': (1501,), 849 | 850 | 'soldot_inlet_port_000_comp_000': (1501,), 851 | 'soldot_inlet_port_000_comp_001': (1501,), 852 | 'soldot_inlet_port_000_comp_002': (1501,), 853 | 'soldot_inlet_port_000_comp_003': (1501,), 854 | 'soldot_outlet_port_000_comp_000': (1501,), 855 | 'soldot_outlet_port_000_comp_001': (1501,), 856 | 'soldot_outlet_port_000_comp_002': (1501,), 857 | 'soldot_outlet_port_000_comp_003': (1501,), 858 | }, 859 | }, 860 | ) 861 | 862 | 863 | # %% GRM 1 Comp 864 | 865 | grm_1_comp = Case( 866 | name='grm_1_comp', 867 | model_options=grm_template_1_comp, 868 | solution_recorder_options=no_split, 869 | expected_results={ 870 | 'solution_times': (1501,), 871 | 'last_state_y': (123,), 872 | 'last_state_ydot': (123,), 873 | 'coordinates_unit_000': { 874 | 'axial_coordinates': (10,), 875 | 'particle_coordinates_000': (5,), 876 | }, 877 | 'coordinates_unit_001': {}, 878 | 'solution_unit_000': { 879 | 'last_state_y': (121,), 880 | 'last_state_ydot': (121,), 881 | 882 | 'solution_inlet': (1501, 1), 883 | 'solution_outlet': (1501, 1), 884 | 'solution_bulk': (1501, 10, 1), 885 | 'solution_particle': (1501, 10, 5, 1), 886 | 'solution_solid': (1501, 10, 5, 1), 887 | 'solution_flux': (1501, 1, 10, 1), 888 | 889 | 
'soldot_inlet': (1501, 1), 890 | 'soldot_outlet': (1501, 1), 891 | 'soldot_bulk': (1501, 10, 1), 892 | 'soldot_particle': (1501, 10, 5, 1), 893 | 'soldot_solid': (1501, 10, 5, 1), 894 | 'soldot_flux': (1501, 1, 10, 1), 895 | }, 896 | 'solution_unit_001': { 897 | 'last_state_y': (1,), 898 | 'last_state_ydot': (1,), 899 | 900 | 'solution_inlet': (1501, 1), 901 | 'solution_outlet': (1501, 1), 902 | 'soldot_inlet': (1501, 1), 903 | 'soldot_outlet': (1501, 1), 904 | }, 905 | }, 906 | ) 907 | 908 | 909 | # %% GRM Sens 910 | 911 | grm_sens = Case( 912 | name='grm_sens', 913 | model_options=grm_template_sens, 914 | solution_recorder_options=no_split, 915 | expected_results={ 916 | 'solution_times': (1501,), 917 | 'last_state_y': (492,), 918 | 'last_state_ydot': (492,), 919 | 'coordinates_unit_000': { 920 | 'axial_coordinates': (10,), 921 | 'particle_coordinates_000': (5,), 922 | }, 923 | 'coordinates_unit_001': {}, 924 | 'solution_unit_000': { 925 | 'last_state_y': (484,), 926 | 'last_state_ydot': (484,), 927 | 928 | 'solution_inlet': (1501, 4), 929 | 'solution_outlet': (1501, 4), 930 | 'solution_bulk': (1501, 10, 4), 931 | 'solution_particle': (1501, 10, 5, 4), 932 | 'solution_solid': (1501, 10, 5, 4), 933 | 'solution_flux': (1501, 1, 10, 4), 934 | 935 | 'soldot_inlet': (1501, 4), 936 | 'soldot_outlet': (1501, 4), 937 | 'soldot_bulk': (1501, 10, 4), 938 | 'soldot_particle': (1501, 10, 5, 4), 939 | 'soldot_solid': (1501, 10, 5, 4), 940 | 'soldot_flux': (1501, 1, 10, 4), 941 | }, 942 | 'solution_unit_001': { 943 | 'last_state_y': (4,), 944 | 'last_state_ydot': (4,), 945 | 946 | 'solution_inlet': (1501, 4), 947 | 'solution_outlet': (1501, 4), 948 | 949 | 'soldot_inlet': (1501, 4), 950 | 'soldot_outlet': (1501, 4), 951 | }, 952 | 'sens_param_000_unit_000': { 953 | 'sens_inlet': (1501, 4), 954 | 'sens_outlet': (1501, 4), 955 | 'sens_bulk': (1501, 10, 4), 956 | 'sens_particle': (1501, 10, 5, 4), 957 | 'sens_solid': (1501, 10, 5, 4), 958 | 'sens_flux': (1501, 1, 10, 4), 959 | 
960 | 'sensdot_inlet': (1501, 4), 961 | 'sensdot_outlet': (1501, 4), 962 | 'sensdot_bulk': (1501, 10, 4), 963 | 'sensdot_particle': (1501, 10, 5, 4), 964 | 'sensdot_solid': (1501, 10, 5, 4), 965 | 'sensdot_flux': (1501, 1, 10, 4), 966 | }, 967 | 'sens_param_000_unit_001': { 968 | 'sens_inlet': (1501, 4), 969 | 'sens_outlet': (1501, 4), 970 | 971 | 'sensdot_inlet': (1501, 4), 972 | 'sensdot_outlet': (1501, 4), 973 | }, 974 | }, 975 | ) 976 | 977 | # %% GRM ParTypes 978 | 979 | grm_par_types = Case( 980 | name='grm_par_types', 981 | model_options=grm_template_partypes, 982 | solution_recorder_options=no_split, 983 | expected_results={ 984 | 'solution_times': (1501,), 985 | 'last_state_y': (932,), 986 | 'last_state_ydot': (932,), 987 | 'coordinates_unit_000': { 988 | 'axial_coordinates': (10,), 989 | 'particle_coordinates_000': (5,), 990 | 'particle_coordinates_001': (5,), 991 | }, 992 | 'coordinates_unit_001': {}, 993 | 'solution_unit_000': { 994 | 'last_state_y': (924,), 995 | 'last_state_ydot': (924,), 996 | 997 | 'solution_inlet': (1501, 4), 998 | 'solution_outlet': (1501, 4), 999 | 'solution_bulk': (1501, 10, 4), 1000 | 'solution_particle_partype_000': (1501, 10, 5, 4), 1001 | 'solution_particle_partype_001': (1501, 10, 5, 4), 1002 | 'solution_solid_partype_000': (1501, 10, 5, 4), 1003 | 'solution_solid_partype_001': (1501, 10, 5, 4), 1004 | 'solution_flux': (1501, 2, 10, 4), 1005 | 1006 | 'soldot_inlet': (1501, 4), 1007 | 'soldot_outlet': (1501, 4), 1008 | 'soldot_bulk': (1501, 10, 4), 1009 | 'soldot_particle_partype_000': (1501, 10, 5, 4), 1010 | 'soldot_particle_partype_001': (1501, 10, 5, 4), 1011 | 'soldot_solid_partype_000': (1501, 10, 5, 4), 1012 | 'soldot_solid_partype_001': (1501, 10, 5, 4), 1013 | 'soldot_flux': (1501, 2, 10, 4), 1014 | }, 1015 | 'solution_unit_001': { 1016 | 'last_state_y': (4,), 1017 | 'last_state_ydot': (4,), 1018 | 1019 | 'solution_inlet': (1501, 4), 1020 | 'solution_outlet': (1501, 4), 1021 | 1022 | 'soldot_inlet': (1501, 4), 1023 
| 'soldot_outlet': (1501, 4), 1024 | }, 1025 | }, 1026 | ) 1027 | 1028 | # %% 2D GRM 1029 | _2dgrm = Case( 1030 | name='_2dgrm', 1031 | model_options=_2dgrm_template, 1032 | solution_recorder_options=no_split, 1033 | expected_results={ 1034 | 'solution_times': (1501,), 1035 | 'last_state_y': (1468,), 1036 | 'last_state_ydot': (1468,), 1037 | 'coordinates_unit_000': { 1038 | 'axial_coordinates': (10,), 1039 | 'particle_coordinates_000': (5,), 1040 | 'radial_coordinates': (3,), 1041 | }, 1042 | 'coordinates_unit_001': {}, 1043 | 'solution_unit_000': { 1044 | 'last_state_y': (1452,), 1045 | 'last_state_ydot': (1452,), 1046 | 1047 | 'solution_inlet': (1501, 3, 4), 1048 | 'solution_outlet': (1501, 3, 4), 1049 | 'solution_bulk': (1501, 10, 3, 4), 1050 | 'solution_particle': (1501, 10, 3, 5, 4), 1051 | 'solution_solid': (1501, 10, 3, 5, 4), 1052 | 'solution_flux': (1501, 1, 10, 3, 4), 1053 | 1054 | 'soldot_inlet': (1501, 3, 4), 1055 | 'soldot_outlet': (1501, 3, 4), 1056 | 'soldot_bulk': (1501, 10, 3, 4), 1057 | 'soldot_particle': (1501, 10, 3, 5, 4), 1058 | 'soldot_solid': (1501, 10, 3, 5, 4), 1059 | 'soldot_flux': (1501, 1, 10, 3, 4), 1060 | }, 1061 | 'solution_unit_001': { 1062 | 'last_state_y': (4,), 1063 | 'last_state_ydot': (4,), 1064 | 1065 | 'solution_inlet': (1501, 4), 1066 | 'solution_outlet': (1501, 4), 1067 | 1068 | 'soldot_inlet': (1501, 4), 1069 | 'soldot_outlet': (1501, 4), 1070 | }, 1071 | }, 1072 | ) 1073 | 1074 | # %% 2D GRM Split Ports (single_as_multi_port=False) 1075 | 1076 | _2dgrm_split_ports = Case( 1077 | name='_2dgrm_split_ports', 1078 | model_options=_2dgrm_template, 1079 | solution_recorder_options=split_ports, 1080 | expected_results={ 1081 | 'solution_times': (1501,), 1082 | 'last_state_y': (1468,), 1083 | 'last_state_ydot': (1468,), 1084 | 'coordinates_unit_000': { 1085 | 'axial_coordinates': (10,), 1086 | 'particle_coordinates_000': (5,), 1087 | 'radial_coordinates': (3,), 1088 | }, 1089 | 'coordinates_unit_001': {}, 1090 | 
'solution_unit_000': { 1091 | 'last_state_y': (1452,), 1092 | 'last_state_ydot': (1452,), 1093 | 1094 | 'solution_inlet_port_000': (1501, 4), 1095 | 'solution_inlet_port_001': (1501, 4), 1096 | 'solution_inlet_port_002': (1501, 4), 1097 | 'solution_outlet_port_000': (1501, 4), 1098 | 'solution_outlet_port_001': (1501, 4), 1099 | 'solution_outlet_port_002': (1501, 4), 1100 | 'solution_bulk': (1501, 10, 3, 4), 1101 | 'solution_particle': (1501, 10, 3, 5, 4), 1102 | 'solution_solid': (1501, 10, 3, 5, 4), 1103 | 'solution_flux': (1501, 1, 10, 3, 4), 1104 | 1105 | 'soldot_inlet_port_000': (1501, 4), 1106 | 'soldot_inlet_port_001': (1501, 4), 1107 | 'soldot_inlet_port_002': (1501, 4), 1108 | 'soldot_outlet_port_000': (1501, 4), 1109 | 'soldot_outlet_port_001': (1501, 4), 1110 | 'soldot_outlet_port_002': (1501, 4), 1111 | 'soldot_bulk': (1501, 10, 3, 4), 1112 | 'soldot_particle': (1501, 10, 3, 5, 4), 1113 | 'soldot_solid': (1501, 10, 3, 5, 4), 1114 | 'soldot_flux': (1501, 1, 10, 3, 4), 1115 | 1116 | }, 1117 | 'solution_unit_001': { 1118 | 'last_state_y': (4,), 1119 | 'last_state_ydot': (4,), 1120 | 1121 | 'solution_inlet': (1501, 4), 1122 | 'solution_outlet': (1501, 4), 1123 | 1124 | 'soldot_inlet': (1501, 4), 1125 | 'soldot_outlet': (1501, 4), 1126 | }, 1127 | }, 1128 | ) 1129 | 1130 | # %% 2D GRM Split All 1131 | 1132 | _2dgrm_split_all = Case( 1133 | name='_2dgrm_split_all', 1134 | model_options=_2dgrm_template, 1135 | solution_recorder_options=split_all, 1136 | expected_results={ 1137 | 'solution_times': (1501,), 1138 | 'last_state_y': (1468,), 1139 | 'last_state_ydot': (1468,), 1140 | 'coordinates_unit_000': { 1141 | 'axial_coordinates': (10,), 1142 | 'particle_coordinates_000': (5,), 1143 | 'radial_coordinates': (3,), 1144 | }, 1145 | 'coordinates_unit_001': {}, 1146 | 'solution_unit_000': { 1147 | 'last_state_y': (1452,), 1148 | 'last_state_ydot': (1452,), 1149 | 1150 | 'solution_inlet_port_000_comp_000': (1501,), 1151 | 'solution_inlet_port_000_comp_001': (1501,), 
1152 | 'solution_inlet_port_000_comp_002': (1501,), 1153 | 'solution_inlet_port_000_comp_003': (1501,), 1154 | 'solution_inlet_port_001_comp_000': (1501,), 1155 | 'solution_inlet_port_001_comp_001': (1501,), 1156 | 'solution_inlet_port_001_comp_002': (1501,), 1157 | 'solution_inlet_port_001_comp_003': (1501,), 1158 | 'solution_inlet_port_002_comp_000': (1501,), 1159 | 'solution_inlet_port_002_comp_001': (1501,), 1160 | 'solution_inlet_port_002_comp_002': (1501,), 1161 | 'solution_inlet_port_002_comp_003': (1501,), 1162 | 'solution_outlet_port_000_comp_000': (1501,), 1163 | 'solution_outlet_port_000_comp_001': (1501,), 1164 | 'solution_outlet_port_000_comp_002': (1501,), 1165 | 'solution_outlet_port_000_comp_003': (1501,), 1166 | 'solution_outlet_port_001_comp_000': (1501,), 1167 | 'solution_outlet_port_001_comp_001': (1501,), 1168 | 'solution_outlet_port_001_comp_002': (1501,), 1169 | 'solution_outlet_port_001_comp_003': (1501,), 1170 | 'solution_outlet_port_002_comp_000': (1501,), 1171 | 'solution_outlet_port_002_comp_001': (1501,), 1172 | 'solution_outlet_port_002_comp_002': (1501,), 1173 | 'solution_outlet_port_002_comp_003': (1501,), 1174 | 'solution_bulk': (1501, 10, 3, 4), 1175 | 'solution_particle': (1501, 10, 3, 5, 4), 1176 | 'solution_solid': (1501, 10, 3, 5, 4), 1177 | 'solution_flux': (1501, 1, 10, 3, 4), 1178 | 1179 | 'soldot_inlet_port_000_comp_000': (1501,), 1180 | 'soldot_inlet_port_000_comp_001': (1501,), 1181 | 'soldot_inlet_port_000_comp_002': (1501,), 1182 | 'soldot_inlet_port_000_comp_003': (1501,), 1183 | 'soldot_inlet_port_001_comp_000': (1501,), 1184 | 'soldot_inlet_port_001_comp_001': (1501,), 1185 | 'soldot_inlet_port_001_comp_002': (1501,), 1186 | 'soldot_inlet_port_001_comp_003': (1501,), 1187 | 'soldot_inlet_port_002_comp_000': (1501,), 1188 | 'soldot_inlet_port_002_comp_001': (1501,), 1189 | 'soldot_inlet_port_002_comp_002': (1501,), 1190 | 'soldot_inlet_port_002_comp_003': (1501,), 1191 | 'soldot_outlet_port_000_comp_000': (1501,), 
1192 | 'soldot_outlet_port_000_comp_001': (1501,), 1193 | 'soldot_outlet_port_000_comp_002': (1501,), 1194 | 'soldot_outlet_port_000_comp_003': (1501,), 1195 | 'soldot_outlet_port_001_comp_000': (1501,), 1196 | 'soldot_outlet_port_001_comp_001': (1501,), 1197 | 'soldot_outlet_port_001_comp_002': (1501,), 1198 | 'soldot_outlet_port_001_comp_003': (1501,), 1199 | 'soldot_outlet_port_002_comp_000': (1501,), 1200 | 'soldot_outlet_port_002_comp_001': (1501,), 1201 | 'soldot_outlet_port_002_comp_002': (1501,), 1202 | 'soldot_outlet_port_002_comp_003': (1501,), 1203 | 'soldot_bulk': (1501, 10, 3, 4), 1204 | 'soldot_particle': (1501, 10, 3, 5, 4), 1205 | 'soldot_solid': (1501, 10, 3, 5, 4), 1206 | 'soldot_flux': (1501, 1, 10, 3, 4), 1207 | }, 1208 | 'solution_unit_001': { 1209 | 'last_state_y': (4,), 1210 | 'last_state_ydot': (4,), 1211 | 'solution_inlet_port_000_comp_000': (1501,), 1212 | 'solution_inlet_port_000_comp_001': (1501,), 1213 | 'solution_inlet_port_000_comp_002': (1501,), 1214 | 'solution_inlet_port_000_comp_003': (1501,), 1215 | 'solution_outlet_port_000_comp_000': (1501,), 1216 | 'solution_outlet_port_000_comp_001': (1501,), 1217 | 'solution_outlet_port_000_comp_002': (1501,), 1218 | 'solution_outlet_port_000_comp_003': (1501,), 1219 | 1220 | 'soldot_inlet_port_000_comp_000': (1501,), 1221 | 'soldot_inlet_port_000_comp_001': (1501,), 1222 | 'soldot_inlet_port_000_comp_002': (1501,), 1223 | 'soldot_inlet_port_000_comp_003': (1501,), 1224 | 'soldot_outlet_port_000_comp_000': (1501,), 1225 | 'soldot_outlet_port_000_comp_001': (1501,), 1226 | 'soldot_outlet_port_000_comp_002': (1501,), 1227 | 'soldot_outlet_port_000_comp_003': (1501,), 1228 | }, 1229 | }, 1230 | ) 1231 | 1232 | 1233 | 1234 | # %% Testing utils 1235 | 1236 | def assert_keys(model_dict: dict, expected_dict: dict): 1237 | """ 1238 | Assert that the keys of two dictionaries are identical. 
1239 | 1240 | Parameters 1241 | ---------- 1242 | model_dict : dict 1243 | The dictionary whose keys are to be compared. 1244 | expected_dict : dict 1245 | The dictionary containing the expected set of keys. 1246 | 1247 | Raises 1248 | ------ 1249 | AssertionError 1250 | If the keys of `model_dict` and `expected_dict` do not match. 1251 | 1252 | Examples 1253 | -------- 1254 | >>> assert_keys({"a": 1, "b": 2}, {"b": 3, "a": 4}) 1255 | True 1256 | >>> assert_keys({"a": 1, "b": 2}, {"b": 3, "c": 4}) 1257 | Traceback (most recent call last): 1258 | ... 1259 | AssertionError: Key mismatch. Expected {'b', 'c'}, but got {'b', 'a'}. 1260 | """ 1261 | model_keys = set(model_dict.keys()) 1262 | expected_keys = set(expected_dict.keys()) 1263 | assert model_keys == expected_keys, ( 1264 | f"Key mismatch. Expected {expected_keys}, but got {model_keys}." 1265 | ) 1266 | 1267 | 1268 | def assert_shape(array_shape, expected_shape, context, key, unit_id=None): 1269 | """ 1270 | Assert that the shape of an array matches the expected shape. 1271 | 1272 | Parameters 1273 | ---------- 1274 | array_shape : tuple 1275 | The shape of the actual array to validate. 1276 | expected_shape : tuple 1277 | The expected shape to compare against. 1278 | context : str 1279 | High-level context for the assertion, 1280 | e.g., 'last_state', 'coordinates', 'solution'. 1281 | key : str 1282 | Specific key or identifier within the context. 1283 | unit_id : str, optional 1284 | Unit identifier, e.g., 'unit_000'. If not provided, it is assumed the context 1285 | does not require unit-specific validation. 1286 | 1287 | Raises 1288 | ------ 1289 | AssertionError 1290 | If the actual shape does not match the expected shape, including detailed context. 1291 | 1292 | """ 1293 | unit_info = f"in unit '{unit_id}'" if unit_id else "" 1294 | assert array_shape == expected_shape, ( 1295 | f"Shape mismatch {unit_info} for {context}[{key}]. " 1296 | f"Expected {expected_shape}, but got {array_shape}." 
1297 | ) 1298 | 1299 | 1300 | # %% Actual tests 1301 | 1302 | use_dll = [False, True] 1303 | 1304 | test_cases = [ 1305 | cstr, 1306 | lrm, 1307 | lrmp, 1308 | grm, 1309 | grm_split_components, 1310 | grm_split_ports, 1311 | grm_split_ports_single_as_multi, 1312 | grm_split_all, 1313 | grm_1_comp, 1314 | grm_sens, 1315 | grm_par_types, 1316 | _2dgrm, 1317 | _2dgrm_split_ports, 1318 | _2dgrm_split_all 1319 | ] 1320 | 1321 | 1322 | @pytest.mark.parametrize("use_dll", use_dll, ids=[f"{case}" for case in use_dll]) 1323 | @pytest.mark.parametrize("test_case", test_cases, ids=[case.name for case in test_cases]) 1324 | def test_simulator_options(use_dll, test_case): 1325 | model_options = test_case.model_options 1326 | solution_recorder_options = test_case.solution_recorder_options 1327 | expected_results = test_case.expected_results 1328 | 1329 | model = run_simulation_with_options( 1330 | use_dll, model_options, solution_recorder_options 1331 | ) 1332 | 1333 | # Assert solution_times shape 1334 | assert_shape( 1335 | model.root.output.solution.solution_times.shape, 1336 | expected_results['solution_times'], 1337 | context="solution", 1338 | key="solution_times" 1339 | ) 1340 | 1341 | # Assert last_state shapes 1342 | assert_shape( 1343 | model.root.output.last_state_y.shape, 1344 | expected_results['last_state_y'], 1345 | context="last_state", 1346 | key="y" 1347 | ) 1348 | assert_shape( 1349 | model.root.output.last_state_ydot.shape, 1350 | expected_results['last_state_ydot'], 1351 | context="last_state", 1352 | key="ydot" 1353 | ) 1354 | 1355 | # Check coordinates 1356 | unit = "unit_000" 1357 | excpected_coordinates = expected_results[f'coordinates_{unit}'] 1358 | coordinates_unit = model.root.output.coordinates[unit] 1359 | assert_keys(coordinates_unit, excpected_coordinates) 1360 | 1361 | for key, value in excpected_coordinates.items(): 1362 | coordinates_shape = coordinates_unit[key].shape 1363 | assert_shape( 1364 | coordinates_shape, 1365 | value, 1366 | 
context="coordinates", 1367 | key=key, 1368 | unit_id=unit, 1369 | ) 1370 | 1371 | unit = "unit_001" 1372 | excpected_coordinates = expected_results[f'coordinates_{unit}'] 1373 | coordinates_unit = model.root.output.coordinates[unit] 1374 | assert_keys(coordinates_unit, excpected_coordinates) 1375 | 1376 | for key, value in excpected_coordinates.items(): 1377 | coordinates_shape = coordinates_unit[key].shape 1378 | assert_shape( 1379 | coordinates_shape, 1380 | value, 1381 | context="coordinates", 1382 | key=key, 1383 | unit_id=unit, 1384 | ) 1385 | 1386 | # Check solution 1387 | unit = "unit_000" 1388 | excpected_solution = expected_results[f'solution_{unit}'] 1389 | solution_unit = model.root.output.solution[unit] 1390 | assert_keys(excpected_solution, solution_unit) 1391 | 1392 | for key, value in excpected_solution.items(): 1393 | shape = solution_unit[key].shape 1394 | assert_shape( 1395 | shape, 1396 | value, 1397 | context="solution", 1398 | key=key, 1399 | unit_id=unit, 1400 | ) 1401 | 1402 | unit = "unit_001" 1403 | excpected_solution = expected_results[f'solution_{unit}'] 1404 | solution_unit = model.root.output.solution[unit] 1405 | assert_keys(excpected_solution, solution_unit) 1406 | 1407 | for key, value in excpected_solution.items(): 1408 | shape = solution_unit[key].shape 1409 | assert_shape( 1410 | shape, 1411 | value, 1412 | context="solution", 1413 | key=key, 1414 | unit_id=unit, 1415 | ) 1416 | 1417 | # Check sensitivity 1418 | if model_options['include_sensitivity']: 1419 | unit = "unit_000" 1420 | excpected_sensitivity = expected_results[f'sens_param_000_{unit}'] 1421 | sensitivity_unit = model.root.output.sensitivity.param_000[unit] 1422 | assert_keys(excpected_sensitivity, sensitivity_unit) 1423 | 1424 | for key, value in excpected_sensitivity.items(): 1425 | shape = sensitivity_unit[key].shape 1426 | assert_shape( 1427 | shape, 1428 | value, 1429 | context="sensitivity", 1430 | key=key, 1431 | unit_id=unit, 1432 | ) 1433 | 1434 | unit = 
"unit_001" 1435 | excpected_sensitivity = expected_results[f'sens_param_000_{unit}'] 1436 | sensitivity_unit = model.root.output.sensitivity.param_000[unit] 1437 | assert_keys(excpected_sensitivity, sensitivity_unit) 1438 | 1439 | for key, value in excpected_sensitivity.items(): 1440 | shape = sensitivity_unit[key].shape 1441 | assert_shape( 1442 | shape, 1443 | value, 1444 | context="sensitivity", 1445 | key=key, 1446 | unit_id=unit, 1447 | ) 1448 | 1449 | 1450 | @pytest.mark.parametrize("use_dll", use_dll) 1451 | @pytest.mark.parametrize("test_case", [grm]) 1452 | def test_meta(use_dll, test_case): 1453 | model_options = test_case.model_options 1454 | solution_recorder_options = test_case.solution_recorder_options 1455 | expected_results = test_case.expected_results 1456 | 1457 | model = run_simulation_with_options( 1458 | use_dll, model_options, solution_recorder_options 1459 | ) 1460 | 1461 | meta_information = { 1462 | 'cadet_branch': str, 1463 | 'cadet_commit': str, 1464 | 'cadet_version': str, 1465 | 'file_format': int, 1466 | 'time_sim': float, 1467 | } 1468 | 1469 | assert model.root.meta.keys() == meta_information.keys() 1470 | 1471 | for meta_key, meta_type in meta_information.items(): 1472 | assert isinstance(model.root.meta[meta_key], meta_type) 1473 | 1474 | 1475 | if __name__ == "__main__": 1476 | pytest.main(["test_dll.py"]) 1477 | --------------------------------------------------------------------------------