├── examples ├── example08_data01.xlsx ├── example08_data02.xlsx ├── example08_data03.xlsx ├── example08_data04.xlsx └── example17_data.xlsx ├── MANIFEST.in ├── pyproject.toml ├── environment.yml ├── pyfoomb ├── __init__.py ├── constants.py ├── oed.py ├── utils.py ├── parameter.py ├── model_checking.py ├── modelling.py └── generalized_islands.py ├── LICENSE.txt ├── tests ├── test_oed.py ├── test_visualization.py ├── test_parameter.py ├── test_utils.py ├── test_model_checking.py ├── test_generalized_island.py ├── test_modelling.py ├── test_simulation.py ├── test_datatypes.py └── modelling_library.py ├── setup.py ├── .github └── workflows │ └── Tests.yml ├── README.md └── .gitignore /examples/example08_data01.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicroPhen/pyFOOMB/HEAD/examples/example08_data01.xlsx -------------------------------------------------------------------------------- /examples/example08_data02.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicroPhen/pyFOOMB/HEAD/examples/example08_data02.xlsx -------------------------------------------------------------------------------- /examples/example08_data03.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicroPhen/pyFOOMB/HEAD/examples/example08_data03.xlsx -------------------------------------------------------------------------------- /examples/example08_data04.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MicroPhen/pyFOOMB/HEAD/examples/example08_data04.xlsx -------------------------------------------------------------------------------- /examples/example17_data.xlsx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MicroPhen/pyFOOMB/HEAD/examples/example17_data.xlsx -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # inspired by https://hynek.me/articles/sharing-your-labor-of-love-pypi-quick-and-dirty/ 2 | 3 | # Tests 4 | recursive-include tests *.py 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # inspired by https://hynek.me/articles/sharing-your-labor-of-love-pypi-quick-and-dirty/ 2 | 3 | [build-system] 4 | requires = ["setuptools", "wheel"] 5 | build-backend = "setuptools.build_meta" 6 | 7 | [tool.black] 8 | line-length = 110 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pyfoomb 2 | 3 | channels: 4 | - defaults 5 | - conda-forge 6 | 7 | dependencies: 8 | - python>=3.7 9 | - numpy 10 | - scipy 11 | - joblib 12 | - pandas>=0.24 13 | - openpyxl 14 | - matplotlib-base 15 | - seaborn-base 16 | - psutil 17 | - pip 18 | - assimulo 19 | - pygmo>=2.14 -------------------------------------------------------------------------------- /pyfoomb/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.17.7' 2 | 3 | 4 | # import major classes the user will interact with 5 | from .modelling import BioprocessModel 6 | from .modelling import ObservationFunction 7 | from .caretaker import Caretaker 8 | from .parameter import ParameterMapper 9 | from .generalized_islands import LossCalculator 10 | 11 | # import datatypes 12 | from .datatypes import ModelState 13 | from .datatypes import Observation 14 | from .datatypes import Measurement 15 | from .datatypes import TimeSeries 16 | 17 | # import 
utils 18 | from .utils import Helpers 19 | from .visualization import Visualization 20 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 J. Hemmerich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/test_oed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import pytest 4 | 5 | from pyfoomb.oed import CovOptimality 6 | 7 | 8 | @pytest.fixture() 9 | def cov_evaluator(): 10 | cov_evaluator = CovOptimality() 11 | return cov_evaluator 12 | 13 | 14 | class TestCovOptimality(): 15 | 16 | @pytest.mark.parametrize( 17 | "criterion", 18 | [ 19 | 'A', 20 | 'D', 21 | 'E', 22 | 'E_mod', 23 | 'unknown_criterion' 24 | ] 25 | ) 26 | def test_calculate_criteria(self, criterion, cov_evaluator): 27 | Cov = np.random.rand(3, 3) + 0.001 28 | if criterion == 'unknown_criterion': 29 | with pytest.raises(KeyError): 30 | cov_evaluator.get_value(criterion, Cov) 31 | else: 32 | cov_evaluator.get_value(criterion, Cov) 33 | 34 | def test_bad_Cov(self, cov_evaluator): 35 | # Can only use a square Cov 36 | with pytest.raises(ValueError): 37 | cov_evaluator.get_value( 38 | criterion='A', 39 | Cov=np.random.rand(3, 4) + 0.001, 40 | ) 41 | # Return nan for Covs with inf entries 42 | bad_Cov = np.full(shape=(3, 3), fill_value=1.0) 43 | bad_Cov[0, 0] = np.inf 44 | assert np.isnan(cov_evaluator.get_value(criterion='A', Cov=bad_Cov,)) 45 | -------------------------------------------------------------------------------- /pyfoomb/constants.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | class Constants(): 4 | eps_float64 = numpy.finfo(numpy.float64).eps 5 | single_id = None 6 | observed_state_key = 'observed_state' 7 | pretty_metrics = { 8 | 'negLL' : 'negative-log-likelihood', 9 | 'SS' : 'sum-of-squares', 10 | 'WSS' : 'weighted-sum-of-squares', 11 | } 12 | handled_CVodeErrors = [-1, -4] 13 | 14 | 15 | class Messages(): 16 | bad_unknowns = 'Bad type of unknowns, must be either of type list or dict' 17 | cvode_boundzero = 'Detected CVodeError, probably due to some bounds being 0' 
18 | invalid_measurements = 'Detected invalid measurement keys' 19 | invalid_unknowns = 'Detected invalid unknowns to be estimated' 20 | invalid_initial_values_type = 'Initial values must be provided as dictionary' 21 | missing_bounds = 'Must provide bounds for global parameter optimization' 22 | missing_values = 'Missing values' 23 | missing_sw_arg = 'Detected event handling in model. Provide "sw" argument in rhs signature: rhs(self, t, y, sw).' 24 | non_unique_ids = 'Detected non-unique (case-insensitive) keys/items/ids' 25 | unpacking_state_vector = 'Unpacking the state vector y is required in alphabetical case-insensitve order' 26 | wrong_return_order_state_derivatives = 'State derivatives are returned in the wrong order or do not match the pattern "d(state)dt"' 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | __packagename__ = 'pyfoomb' 4 | 5 | def get_version(): 6 | import os 7 | import re 8 | VERSIONFILE = os.path.join(__packagename__, '__init__.py') 9 | initfile_lines = open(VERSIONFILE, 'rt').readlines() 10 | VSRE = r"^__version__ = ['\"]([^'\"]*)['\"]" 11 | for line in initfile_lines: 12 | mo = re.search(VSRE, line, re.M) 13 | if mo: 14 | return mo.group(1) 15 | raise RuntimeError('Unable to find version string in %s.' % (VERSIONFILE,)) 16 | 17 | __version__ = get_version() 18 | 19 | 20 | setuptools.setup(name = __packagename__, 21 | packages = setuptools.find_packages(exclude=['examples', '*test*']), 22 | version=__version__, 23 | zip_safe=False, 24 | description='Package for handling systems of ordinary differential equations (ODEs) with discontinuities. 
Relies on assimulo package for ODE integration and pygmo package for optimization.', 25 | author='Johannes Hemmerich', 26 | author_email='hemmerich@outlook.com', 27 | url='https://github.com/MicroPhen/pyFOOMB', 28 | license='MIT', 29 | classifiers= [ 30 | 'Programming Language :: Python :: 3 :: Only', 31 | 'Operating System :: OS Independent', 32 | 'Intended Audience :: Developers' 33 | ], 34 | python_requires='>=3.7', 35 | install_requires=[ 36 | 'numpy', 37 | 'scipy', 38 | 'pandas>=0.24', 39 | 'openpyxl', 40 | 'joblib', 41 | 'matplotlib', 42 | 'seaborn', 43 | 'assimulo', 44 | 'psutil', 45 | ] 46 | ) 47 | -------------------------------------------------------------------------------- /.github/workflows/Tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test-unittest: 7 | name: Test (${{ matrix.python-version }}, ${{ matrix.os }}) 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | max-parallel: 6 11 | matrix: 12 | os: [windows-latest, ubuntu-latest] 13 | python-version: [3.7, 3.8, 3.9] 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - uses: conda-incubator/setup-miniconda@v2 18 | with: 19 | miniconda-version: "latest" 20 | python-version: ${{ matrix.python-version }} 21 | auto-activate-base: true 22 | add-pip-as-python-dependency: true 23 | 24 | - name: Run tests on Windows 25 | if: matrix.os == 'windows-latest' 26 | shell: powershell 27 | run: | 28 | conda init 29 | conda env update --name test --file environment.yml 30 | conda activate test 31 | conda install pytest pytest-cov 32 | pip install -e . 33 | pytest --cov-report xml --cov-report term-missing --cov=pyfoomb tests/ -v 34 | 35 | - name: Run tests on Ubuntu 36 | if: matrix.os == 'ubuntu-latest' 37 | shell: bash -l {0} 38 | run: | 39 | conda init 40 | conda env update --name test --file environment.yml 41 | conda activate test 42 | conda install pytest pytest-cov 43 | pip install -e . 
44 | pytest --cov-report xml --cov-report term-missing --cov=pyfoomb tests/ -v 45 | 46 | - name: Upload coverage 47 | if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9 48 | uses: "codecov/codecov-action@v1" 49 | with: 50 | token: ${{ secrets.CODECOV_TOKEN }} 51 | file: ./coverage.xml 52 | 53 | - name: Build wheel 54 | if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9 55 | shell: bash -l {0} 56 | run: | 57 | conda init 58 | conda env create --name build-test --file environment.yml 59 | conda activate build-test 60 | pip install -U pip pep517 twine 61 | python -m pep517.build . 62 | cd dist 63 | pip install pyfoomb*.whl 64 | python -c "import pyfoomb; print(f'Can successfully import {pyfoomb.__name__} {pyfoomb.__version__}')" 65 | openssl sha256 pyfoomb*.tar.gz 66 | -------------------------------------------------------------------------------- /pyfoomb/oed.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | class CovOptimality(): 4 | """ 5 | Manages evaluation of a parameter variance-covariance matrix w.r.t. different optimality measures. 6 | Typically used for optimal experimental design (OED) methods. 7 | """ 8 | 9 | def get_value(self, criterion:str, Cov:numpy.ndarray) -> float: 10 | """ 11 | Arguments 12 | --------- 13 | criterion : str 14 | Must be one of the implemented criteria, describing the mapping of the Cov matrix to a scalar. 15 | Cov : numpy.ndarray 16 | Variance-covariance matrix, must be sqaure and positive semi-definite. 17 | 18 | Returns 19 | ------- 20 | float 21 | The result of the requested mapping function Cov -> scalar. 22 | 23 | Raises 24 | ------ 25 | ValueError 26 | Cov is not square. 
27 | """ 28 | 29 | if Cov.shape[0] != Cov.shape[1]: 30 | raise ValueError('Parameter covariance matrix must be square') 31 | if numpy.isinf(Cov).any(): 32 | return numpy.nan 33 | opt_fun = self._get_optimality_function(criterion) 34 | return opt_fun(Cov) 35 | 36 | 37 | def _get_optimality_function(self, criterion:str): 38 | """ 39 | Selects the criterion function. 40 | """ 41 | 42 | opt_functions = { 43 | 'A' : self._A_optimality, 44 | 'D' : self._D_optimality, 45 | 'E' : self._E_optimality, 46 | 'E_mod' : self._E_mod_optimality, 47 | } 48 | 49 | return opt_functions[criterion] 50 | 51 | 52 | def _A_optimality(self, Cov:numpy.ndarray) -> float: 53 | """ 54 | The A criterion simply adds up the parameter variances, 55 | neglecting parameter covariances. 56 | """ 57 | 58 | return numpy.trace(Cov) 59 | 60 | 61 | def _D_optimality(self, Cov:numpy.ndarray) -> float: 62 | """ 63 | Evaluates the hypervolume of the parameter joint confidence ellipsoid. 64 | """ 65 | 66 | return numpy.linalg.det(Cov) 67 | 68 | 69 | def _E_optimality(self, Cov) -> float: 70 | """ 71 | Evaluate the major axis of the parameter joint confidence ellipsoid. 72 | """ 73 | 74 | return numpy.max(numpy.linalg.eigvals(Cov)) 75 | 76 | 77 | def _E_mod_optimality(self, Cov:numpy.ndarray) -> float: 78 | """ 79 | Evaluates the 'sphericalitcity' (i.e., shape) of the parameter joint confidence ellipsoid. 80 | Sometimes also referred to as K-criterion. 
81 | """ 82 | 83 | eig_vals = numpy.linalg.eigvals(Cov) 84 | return numpy.min(eig_vals) / numpy.max(eig_vals) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![codecov](https://codecov.io/gh/MicroPhen/pyFOOMB/branch/main/graph/badge.svg?token=7WALTIPP6O)](https://codecov.io/gh/MicroPhen/pyFOOMB) 2 | [![Tests](https://github.com/MicroPhen/pyFOOMB/workflows/Tests/badge.svg)](https://github.com/MicroPhen/pyFOOMB/actions) 3 | [![DOI](https://zenodo.org/badge/309308898.svg)](https://zenodo.org/badge/latestdoi/309308898) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 5 | ![GitHub release (latest by date)](https://img.shields.io/github/v/release/MicroPhen/pyFOOMB) 6 | 7 | # pyFOOMB 8 | 9 | __*Py*thon *F*ramework for *O*bject *O*riented *M*odelling of *B*ioprocesses__ 10 | 11 | Intented application is the acessible modelling of simple to medium complex bioprocess models, by programmatic means. In contrast to 'full-blown' software suites, `pyFOOMB` can be used by scientists with little programming skills in the easy-access language Python. 12 | `pyFOOMB` comes with a MIT license, and anyone interested in using, understanding, or contributing to pyFOOMB is happily invited to do so. 13 | 14 | `pyFOOMB` relies on the `assimulo` package (), providing an interface to the SUNDIALS CVode integrator for systems of differential equations, as well as event handling routines. For optimization, i.e. model calibration from data, the `pygmo` package is used, which provides Python bindings for the `pagmo2` package implementing the Asynchronous Generalized Islands Model. 15 | 16 | To faciliate rapid starting for new users, a continously growing collection of Jupyter notebooks is provided. 
These serve to demonstrate basic and advanced concepts and functionalities (also beyond the pure functions of the `pyFOOMB` package). Also, the examples can be used as building blocks for developing own bioprocess models and corresponding workflows. 17 | 18 | Check also our open access [publication](https://onlinelibrary.wiley.com/doi/full/10.1002/elsc.202000088) at Engineering in Life Sciences introducing `pyFOOMB` with two more elaborated application examples that reproduce real-world data from literature. 19 | 20 | Literature: 21 | 22 | * Andersson C, Führer C, and Akesson J (2015). Assimulo: A unified framework for ODE solvers. _Math Comp Simul_ __116__:26-43 23 | * Biscani F, Izzo D (2020). A parallel global multiobjective framework for optimization: pagmo. _J Open Source Softw_ __5__:2338 24 | * Hindmarsh AC, _et al_ (2005). SUNDIALS: Suite of nonlinear and differential/algebraic equation solvers. _ACM Trans Math Softw_ __31__:363-396 25 | 26 | ## Requirements (provided as environment.yml) 27 | 28 | * python 3.7, 3.8 or 3.9 29 | * numpy 30 | * scipy 31 | * joblib 32 | * pandas 33 | * openpyxl 34 | * matplotlib(-base) 35 | * seaborn(-base) 36 | * psutil 37 | * assimulo (via conda-forge) 38 | * pygmo (via conda-forge) 39 | 40 | ## Easy installation 41 | 42 | 1) Open a terminal / shell 43 | 2) Optional: Create a new environment with `conda env create -n my-pyfoomb-env python=3.9` and activate it with `conda activate my-pyfoomb-env` 44 | 3) Install `pyFOOMB` by executing `conda install -c conda-forge pyfoomb` 45 | 46 | ## Development installation 47 | 48 | 1) Download the code repository to your computer, this is done the best way using `git clone`: In a shell, navigate to the folder where you want the repository to be located. 
49 | 2) Open a terminal / shell and clone the repository 50 | via `git clone https://github.com/MicroPhen/pyFOOMB.git` 51 | 3) cd (*change directory*) into the newly cloned repository : `cd pyfoomb` 52 | 4) Verify that you are in the repo folder, where the file `environment.yml` is found (`dir` for Windows, `ls` for Linux/Mac). 53 | 5) Execute `conda env create -f environment.yml`. 54 | This will create a conda environment named `pyfoomb`, with the current version of the just cloned git repository. 55 | 6) Don't forget to activate the newly created environment to install the `pyFOOMB` package in the next step 56 | 7) To make sure your environment always refers to the state of your git repo (i.e., after own code modifications or after pulling from remote), run `pip install -e ../pyfoomb`. 57 | -------------------------------------------------------------------------------- /tests/test_visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('agg') 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | from pyfoomb.caretaker import Caretaker 8 | from pyfoomb.datatypes import Measurement 9 | from pyfoomb.datatypes import TimeSeries 10 | from pyfoomb.visualization import Visualization 11 | from pyfoomb.visualization import VisualizationHelpers 12 | 13 | from modelling_library import ModelLibrary 14 | 15 | class StaticHelpers(): 16 | 17 | data_single = [Measurement(name='y0', timepoints=[1, 2, 3], values=[10, 10 ,10], errors=[0.1, 0.2, 0.3])] 18 | data_multi = [ 19 | Measurement(name='y0', timepoints=[1, 2, 3], values=[10, 10 ,10], errors=[0.1, 0.2, 0.3], replicate_id='1st'), 20 | Measurement(name='y0', timepoints=[1, 5, 10], values=[10, 10 ,10], errors=[0.1, 0.2, 0.3], replicate_id='2nd'), 21 | ] 22 | many_time_series_1 = [ 23 | [TimeSeries(name='T1', timepoints=[1, 2, 3, 4], values=[10, 10, 10, 10], replicate_id='1st')], 24 | [TimeSeries(name='T1', timepoints=[1, 2, 3], values=[10, 10, 
10], replicate_id='1st')], 25 | ] 26 | unknowns = ['y00', 'y10'] 27 | bounds = [(-100, 100), (-100, 100)] 28 | 29 | @pytest.fixture 30 | def caretaker_single(): 31 | name = 'model01' 32 | return Caretaker( 33 | bioprocess_model_class=ModelLibrary.modelclasses[name], 34 | model_parameters=ModelLibrary.model_parameters[name], 35 | initial_values=ModelLibrary.initial_values[name], 36 | ) 37 | 38 | @pytest.fixture 39 | def caretaker_multi(): 40 | name = 'model01' 41 | return Caretaker( 42 | bioprocess_model_class=ModelLibrary.modelclasses[name], 43 | model_parameters=ModelLibrary.model_parameters[name], 44 | initial_values=ModelLibrary.initial_values[name], 45 | replicate_ids=['1st', '2nd'] 46 | ) 47 | 48 | class TestVisualizationHelpers(): 49 | 50 | @pytest.mark.parametrize('n', [10, 20, 21, 10.5, 10.6, 20.5, 20.6]) 51 | def test_colors(self, n): 52 | VisualizationHelpers.get_n_colors(n) 53 | 54 | 55 | class TestVisualization(): 56 | 57 | @pytest.mark.parametrize('measurements', [StaticHelpers.data_single, StaticHelpers.data_multi]) 58 | def test_show_kinetic_data(self, measurements): 59 | Visualization.show_kinetic_data(time_series=measurements) 60 | 61 | @pytest.mark.filterwarnings('ignore:All-NaN slice encountered') 62 | @pytest.mark.parametrize( 63 | 'time_series' , 64 | [ 65 | [StaticHelpers.data_multi]*2, 66 | StaticHelpers.many_time_series_1, 67 | ] 68 | ) 69 | def test_show_kinetic_data_many(self, time_series): 70 | figsax = Visualization.show_kinetic_data_many(time_series=time_series) 71 | for _key in figsax: 72 | for _line in figsax[_key][1][0].lines: 73 | for _value in _line._y: 74 | assert _value == 10 or np.isnan(_value) 75 | 76 | @pytest.mark.parametrize('measurements', [StaticHelpers.data_single, StaticHelpers.data_multi]) 77 | def test_compare_estimates(self, caretaker_single, measurements): 78 | Visualization.compare_estimates( 79 | parameters={_p : 10 for _p in StaticHelpers.unknowns}, 80 | measurements=measurements, 81 | caretaker=caretaker_single, 
82 | ) 83 | 84 | @pytest.mark.parametrize('show_measurements_only', [False, True]) 85 | @pytest.mark.parametrize( 86 | 'caretaker, data', 87 | [ 88 | ('caretaker_multi', StaticHelpers.data_multi), 89 | ('caretaker_single', StaticHelpers.data_single), 90 | ] 91 | ) 92 | def test_compare_estimates_many(self, caretaker, data, show_measurements_only, request): 93 | caretaker = request.getfixturevalue(caretaker) 94 | Visualization.compare_estimates_many( 95 | parameter_collections={_p : [10]*3 for _p in StaticHelpers.unknowns}, 96 | measurements=data, 97 | caretaker=caretaker, 98 | show_measurements_only=show_measurements_only, 99 | ) 100 | 101 | @pytest.mark.parametrize('show_corr_coeffs', [True, False]) 102 | @pytest.mark.parametrize('estimates', [None, {'p1': 2.5, 'p2' : 5.5}]) 103 | def test_show_parameter_distributions(self, estimates, show_corr_coeffs): 104 | Visualization.show_parameter_distributions( 105 | parameter_collections={ 106 | 'p1' : [1, 2, 3], 107 | 'p2' : [4, 5, 6] 108 | }, 109 | estimates=estimates, 110 | show_corr_coeffs=show_corr_coeffs, 111 | ) 112 | -------------------------------------------------------------------------------- /tests/test_parameter.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | 4 | from pyfoomb.parameter import Parameter 5 | from pyfoomb.parameter import ParameterManager 6 | from pyfoomb.parameter import ParameterMapper 7 | 8 | 9 | class TestParameter(): 10 | 11 | @pytest.mark.parametrize( 12 | "local_name, value", 13 | [ 14 | (None, None), 15 | ('p_local', None), 16 | ('p_local', 1000) 17 | ], 18 | ) 19 | def test_init(self, local_name, value): 20 | Parameter(global_name='p_global', replicate_id='1st', local_name=local_name, value=value) 21 | 22 | 23 | class TestParameterMapper(): 24 | 25 | def test_init(self): 26 | replicate_id = '1st' 27 | global_name = 'p_global' 28 | local_name = 'p_local' 29 | value = 1 30 | 31 | ParameterMapper(replicate_id, global_name, 
local_name, value) 32 | 33 | # Without specifying a local name, this is build from the global name and replicate_id 34 | pm = ParameterMapper(replicate_id, global_name) 35 | assert pm.local_name == f'{global_name}_{replicate_id}' 36 | # The value defaults to None 37 | assert pm.value is None 38 | 39 | # Must use a local name when mapping shall be applied to all or several replicate_ids 40 | with pytest.raises(ValueError): 41 | ParameterMapper(replicate_id='all', global_name=global_name) 42 | with pytest.raises(ValueError): 43 | ParameterMapper(replicate_id=['1st', '2nd'], global_name=global_name) 44 | 45 | 46 | class TestParameterManager(): 47 | 48 | replicate_ids = ['1st', '2nd'] 49 | parameters = {'p1' : 1, 'p2' : 2} 50 | 51 | @pytest.fixture() 52 | def parameter_manager(self): 53 | return ParameterManager(replicate_ids=self.replicate_ids, parameters=self.parameters) 54 | 55 | @pytest.fixture() 56 | def mappings(self): 57 | return [ 58 | ParameterMapper(_replicate_id, _global_parameter) 59 | for _replicate_id in self.replicate_ids 60 | for _global_parameter in self.parameters 61 | ] 62 | 63 | def test_init(self, parameter_manager): 64 | # Must provide case-senstitive unique replicate_ids 65 | with pytest.raises(ValueError): 66 | ParameterManager(replicate_ids=['1st', '1ST'], parameters=self.parameters) 67 | # Can set replicate_ids only during instantiation 68 | with pytest.raises(AttributeError): 69 | parameter_manager.replicate_ids = ['1st', '2nd'] 70 | # Can set global parameters only during instantiation 71 | with pytest.raises(AttributeError): 72 | parameter_manager.global_parameters = self.parameters 73 | 74 | def test_apply_parameter_mappings(self, parameter_manager, mappings): 75 | # Apply a single mapping 76 | parameter_manager.apply_mappings(mappings[0]) 77 | # Apply a list of mappings 78 | parameter_manager.apply_mappings(mappings[1:]) 79 | # Now set some parameter values, there will be a Warning issued for the unknown parameter 80 | with 
pytest.warns(UserWarning): 81 | parameter_manager.set_parameter_values( 82 | { 83 | 'p1' : 1000, # a global parameter 84 | 'p1_1st' : 100, # a local parameter 85 | 'p_unknown' : 10, # unknown parameter 86 | } 87 | ) 88 | # One can define a local parameter for multiple replicate ids 89 | parameter_manager.apply_mappings( 90 | ParameterMapper(replicate_id=['1st', '2nd'], global_name='p1', local_name='p1_local', value=10000), 91 | ) 92 | # Can also be applied to all replicates 93 | parameter_manager.apply_mappings( 94 | ParameterMapper(replicate_id='all', global_name='p1', local_name='p1_local'), 95 | ) 96 | # Must use only known replicate ids 97 | with pytest.raises(ValueError): 98 | parameter_manager.apply_mappings( 99 | ParameterMapper(replicate_id=['1st', '2nd', 'invalid'], global_name='p1', local_name='p1_local'), 100 | ) 101 | with pytest.raises(ValueError): 102 | parameter_manager.apply_mappings( 103 | ParameterMapper(replicate_id='invalid', global_name='p1', local_name='p1_local'), 104 | ) 105 | # Must use only known global parameters 106 | with pytest.raises(ValueError): 107 | parameter_manager.apply_mappings( 108 | ParameterMapper(replicate_id='1st', global_name='p_unknown'), 109 | ) 110 | 111 | def test_parameter_other_mapping_related_methods(self, parameter_manager): 112 | # The managed mappings can be shown as DataFrame 113 | parameter_manager.parameter_mapping 114 | # Can get the current parameter mappings as list of ParameterMappers 115 | parameter_manager.get_parameter_mappers() 116 | # Get the current parameter values for a specific replicate_id 117 | parameter_manager.get_parameters_for_replicate(replicate_id='1st') 118 | # There is a private method to check the parameter mappings before applying them 119 | # Check for being ParameterMapper objects 120 | with pytest.raises(TypeError): 121 | parameter_manager._check_mappings(mappings=['I am a string']) 122 | # Each unique local parameter name must have the same value for the mapping 123 | with 
pytest.raises(ValueError): 124 | parameter_manager._check_mappings( 125 | [ 126 | ParameterMapper(replicate_id='1st', global_name='p1', local_name='p1_local', value=100), 127 | ParameterMapper(replicate_id='2nd', global_name='p1', local_name='p1_local', value=1000), 128 | ] 129 | ) 130 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import pytest 5 | 6 | from pyfoomb.datatypes import TimeSeries 7 | from pyfoomb.datatypes import Measurement 8 | 9 | from pyfoomb.utils import Calculations 10 | from pyfoomb.utils import Helpers 11 | 12 | 13 | class StaticData(): 14 | measurements_wo_errors = Measurement(name='M1', timepoints=[1, 2, 3], values=[10, 20, 30], replicate_id='1st') 15 | measurements_w_errors = Measurement(name='M2', timepoints=[2, 3, 4], values=[20, 30, 40], errors=[1/20, 1/30, 1/40], replicate_id='1st') 16 | time_series_1 = TimeSeries(name='TS1', timepoints=[1, 2], values=[10, 20], replicate_id='1st') 17 | time_series_2 = TimeSeries(name='TS2', timepoints=[2, 3], values=[20, 30], replicate_id='2nd') 18 | 19 | 20 | class TestCalculations(): 21 | 22 | def test_corr(self): 23 | 24 | matrix = np.array( 25 | [ 26 | [1, 2], [3, 4] 27 | ] 28 | ) 29 | Calculations.cov_into_corr(matrix) 30 | 31 | # can use only square matrices 32 | non_square_matrix = np.array( 33 | [ 34 | [1, 2, 3], [4, 5, 6] 35 | ] 36 | ) 37 | with pytest.raises(ValueError): 38 | Calculations.cov_into_corr(non_square_matrix) 39 | 40 | 41 | class TestHelpers(): 42 | 43 | def test_bounds_to_floats(self): 44 | int_bounds = [(0, 1), (2, 3)] 45 | float_bounds = Helpers.bounds_to_floats(int_bounds) 46 | for _bounds in float_bounds: 47 | assert isinstance(_bounds[0], float) 48 | assert isinstance(_bounds[1], float) 49 | 50 | @pytest.mark.parametrize( 51 | 'ok_ids', 52 | [ 53 | ({'a01' : 1, 'b01' : 2}), 54 | ({'b01' : 1, 'a01' : 2}), 
55 | (['a01', 'b01']), 56 | (['a01']) 57 | 58 | ] 59 | ) 60 | def test_unique_ids_ok(self, ok_ids): 61 | """ 62 | To ensure that ids (replicate_ids, states, etc) are case-insensitive unique 63 | """ 64 | assert Helpers.has_unique_ids(ok_ids) 65 | 66 | @pytest.mark.parametrize( 67 | 'not_ok_ids', 68 | [ 69 | ({'a01' : 1, 'A01' : 2}), 70 | (['a01', 'b01', 'b01']), 71 | (['a01', 'B01', 'b01']), 72 | ] 73 | ) 74 | def test_unique_ids_not_ok(self, not_ok_ids): 75 | """ 76 | To ensure that ids (replicate_ids, states, etc) are case-insensitive unique 77 | """ 78 | assert not Helpers.has_unique_ids(not_ok_ids) 79 | 80 | def test_unique_ids_must_be_list_or_dict(self): 81 | with pytest.raises(TypeError): 82 | Helpers.has_unique_ids(('a01', 'b01')) 83 | 84 | def test_utils_for_datatypes(self): 85 | 86 | # To check whether all measurements in a list of those hve errors or not 87 | assert not Helpers.all_measurements_have_errors([StaticData.measurements_wo_errors, StaticData.measurements_w_errors]) 88 | assert Helpers.all_measurements_have_errors([StaticData.measurements_w_errors, StaticData.measurements_w_errors]) 89 | assert not Helpers.all_measurements_have_errors([StaticData.measurements_wo_errors, StaticData.measurements_wo_errors]) 90 | 91 | # Get the joint time vector of several TimeSeries objects 92 | actual = Helpers.get_unique_timepoints([StaticData.measurements_wo_errors, StaticData.measurements_w_errors]) 93 | for _actual, _expected in zip(actual, np.array([1., 2., 3., 4.])): 94 | assert _actual == _expected 95 | 96 | # Extract a specific TimeSeries from a list 97 | timeseries_list = [StaticData.measurements_wo_errors, StaticData.measurements_w_errors] 98 | assert isinstance(Helpers.extract_time_series(timeseries_list, replicate_id='1st', name='M1'), TimeSeries) 99 | # In case not match is found, the method returns None 100 | with pytest.warns(UserWarning): 101 | assert Helpers.extract_time_series(timeseries_list, replicate_id='2nd', name='M1', 
no_extraction_warning=True) is None 102 | # More than one match is found 103 | with pytest.raises(ValueError): 104 | Helpers.extract_time_series(timeseries_list*2, replicate_id='1st', name='M1') 105 | 106 | def test_parameter_collections(self): 107 | """ 108 | Methods related to parameter distributions from MC sampling or parameter scanning studies 109 | """ 110 | 111 | parameter_collection_not_ok = { 112 | 'p1' : [1, 2, 3], 113 | 'p2' : [10, 20, 30, 40] 114 | } 115 | parameter_collection_ok = { 116 | 'p1' : [1, 2, 3], 117 | 'p2' : [10, 20, 30] 118 | } 119 | # The parameters shall all have the same length 120 | assert Helpers.get_parameters_length(parameter_collection_ok) == 3 121 | # The parameters are not allowed to have different lengths 122 | with pytest.raises(ValueError): 123 | Helpers.get_parameters_length(parameter_collection_not_ok) 124 | 125 | # Parameter collections can be sliced for, e.g. get predictions for a particular slice 126 | parameter_slices = Helpers.split_parameters_distributions(parameter_collection_ok) 127 | for parameter_slice in parameter_slices: 128 | assert list(parameter_slice.keys()) == list(parameter_collection_ok.keys()) 129 | 130 | def test_unique_timepoints(self): 131 | t_all = Helpers.get_unique_timepoints( 132 | [ 133 | StaticData.time_series_1, 134 | StaticData.time_series_2, 135 | StaticData.measurements_w_errors, 136 | StaticData.measurements_wo_errors, 137 | ] 138 | ) 139 | assert len(t_all) == 4 140 | assert all(np.equal(t_all, np.array([1, 2, 3, 4]))) 141 | -------------------------------------------------------------------------------- /tests/test_model_checking.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | 4 | from pyfoomb.simulation import ExtendedSimulator 5 | from pyfoomb.model_checking import ModelChecker 6 | 7 | import modelling_library 8 | from modelling_library import ModelLibrary 9 | from modelling_library import ObservationFunctionLibrary 10 | 11 | 
@pytest.fixture
def model_checker():
    """A fresh ModelChecker instance for every test."""
    return ModelChecker()


class TestCheckBioprocessModel():

    @pytest.mark.parametrize('model', ModelLibrary.variants_model03)
    def test_bioprocess_model_checking(self, model_checker, model):
        # Assemble an extended simulator from the library's building blocks.
        simulator = ExtendedSimulator(
            bioprocess_model_class=model,
            model_parameters=ModelLibrary.model_parameters['model03'],
            initial_values=ModelLibrary.initial_values['model03'],
        )
        # Consistent model variants must pass the check without warnings.
        model_checker.check_model_consistency(simulator)

    @pytest.mark.parametrize('bad_model', ModelLibrary.bad_variants_model03)
    def test_bioprocess_bad_model_checking(self, model_checker, bad_model):
        simulator = ExtendedSimulator(
            bioprocess_model_class=bad_model,
            model_parameters=ModelLibrary.model_parameters['model03'],
            initial_values=ModelLibrary.initial_values['model03'],
        )
        # Each bad variant is expected to warn for a different reason
        # (cf. the specific model in the modelling library for details).
        with pytest.warns(UserWarning):
            model_checker.check_model_consistency(simulator)

    @pytest.mark.parametrize(
        'model_variant, initial_switches, expected_behavior',
        [
            (modelling_library.Model06, None, 'pass'),
            (modelling_library.Model06_V02, [False]*4, 'UserWarning'),
            (modelling_library.Model06_V02, None, 'pass'),
            (modelling_library.Model06_V03, None, 'UserWarning'),
            (modelling_library.Model06_Bad01, None, 'UserWarning'),
            (modelling_library.Model06_Bad02, None, 'UserWarning'),
            (modelling_library.Model06_Bad03, None, 'NameError'),
            (modelling_library.Model06_Bad04, None, 'NameError'),
            (modelling_library.Model06_Bad05, [False]*3, 'UserWarning'),
            (modelling_library.Model06_Bad06, None, 'UserWarning'),
            (modelling_library.Model06_Bad07, None, 'UserWarning'),
            (modelling_library.Model06_Bad08, None, 'UserWarning'),
        ]
    )
    def test_model06_variants(self, model_checker, model_variant, initial_switches, expected_behavior):
        simulator = ExtendedSimulator(
            bioprocess_model_class=model_variant,
            model_parameters=ModelLibrary.model_parameters['model06'],
            initial_values=ModelLibrary.initial_values['model06'],
            initial_switches=initial_switches,
        )
        # Dispatch on the outcome expected for this variant.
        if expected_behavior == 'UserWarning':
            with pytest.warns(UserWarning):
                model_checker.check_model_consistency(simulator)
        elif expected_behavior == 'NameError':
            with pytest.raises(NameError):
                model_checker.check_model_consistency(simulator)
        else:
            model_checker.check_model_consistency(simulator)


class TestObservationFunction():

    @pytest.mark.parametrize('model', ModelLibrary.variants_model03)
    @pytest.mark.parametrize('obsfun', ObservationFunctionLibrary.variants_obsfun01)
    def test_observation_function_checking(self, model_checker, model, obsfun):
        # Building blocks for the observation function under test.
        observed_state = ObservationFunctionLibrary.observed_states['obsfun01']
        obs_params = ObservationFunctionLibrary.observation_function_parameters['obsfun01']
        # An extended simulator carrying the observation function.
        simulator = ExtendedSimulator(
            bioprocess_model_class=model,
            model_parameters=ModelLibrary.model_parameters['model03'],
            initial_values=ModelLibrary.initial_values['model03'],
            observation_functions_parameters=[
                (
                    obsfun,
                    {
                        **obs_params,
                        'observed_state' : observed_state,
                    }
                ),
            ]
        )
        # Valid observation functions must check out without warnings.
        model_checker.check_model_consistency(simulator)

    @pytest.mark.parametrize('model', ModelLibrary.variants_model03)
    @pytest.mark.parametrize('bad_obsfun', ObservationFunctionLibrary.bad_variants_obsfun01)
    def test_bad_observation_function_checking(self, model_checker, model, bad_obsfun):
        # Building blocks for the defective observation function.
        observed_state = ObservationFunctionLibrary.observed_states['obsfun01']
        obs_params = ObservationFunctionLibrary.observation_function_parameters['obsfun01']
        simulator = ExtendedSimulator(
            bioprocess_model_class=model,
            model_parameters=ModelLibrary.model_parameters['model03'],
            initial_values=ModelLibrary.initial_values['model03'],
            observation_functions_parameters=[
                (
                    bad_obsfun,
                    {
                        **obs_params,
                        'observed_state' : observed_state,
                    }
                ),
            ]
        )
        # These checks should raise warnings.
        with pytest.warns(UserWarning):
            model_checker.check_model_consistency(simulator)
@pytest.fixture
def constrainted_loss_calculator():
    """A LossCalculator subclass imposing three simple parameter constraints."""

    class ConstrainedLossCalculator(LossCalculator):
        # First unknown must be negative.
        def constraint_1(self):
            return self.current_parameters[StaticHelpers.unknowns[0]] < 0
        # Second unknown must be positive.
        def constraint_2(self):
            return self.current_parameters[StaticHelpers.unknowns[1]] > 0
        # The sum of both unknowns must not exceed 1.
        def constraint_3(self):
            _first = self.current_parameters[StaticHelpers.unknowns[0]]
            _second = self.current_parameters[StaticHelpers.unknowns[1]]
            return (_first + _second) <= 1
        def check_constraints(self) -> List[bool]:
            return [self.constraint_1(), self.constraint_2(), self.constraint_3()]

    return ConstrainedLossCalculator

@pytest.fixture
def evolutions_trail():
    """Mocked bookkeeping for three consecutive evolutions."""
    return {
        'cum_runtime_min' : [1, 2, 3],
        'evo_time_min' : [0.5, 0.5, 0.4],
        'best_losses' : [300, 200, 100],
        'estimates_info' : [
            {'losses': np.array([302, 301, 300])},
            {'losses': np.array([202, 201, 200])},
            {'losses': np.array([102, 101, 100])}
        ],
    }


class TestLossCalculator():

    @pytest.mark.parametrize('metric', ['negLL', 'WSS', 'SS'])
    @pytest.mark.parametrize(
        'fitness_vector',
        [
            [0, 0],
            [100, 100],
        ]
    )
    def test_fitness(self, caretaker_single, fitness_vector, metric):
        problem = LossCalculator(
            StaticHelpers.unknowns,
            StaticHelpers.bounds,
            metric,
            StaticHelpers.data_single,
            caretaker_single.loss_function,
        )
        # Both the fitness and its gradient must be computable.
        problem.fitness(fitness_vector)
        problem.gradient(fitness_vector)

    @pytest.mark.parametrize(
        'fitness_vector, inf_loss',
        [
            ([-1, 1], False),
            ([0, 1], True),
            ([-1, 0], True),
            ([0, 0], True),
            ([1, 1], True),
            ([-1, 100], True),
        ]
    )
    def test_contrained_loss_calculator(self, caretaker_single, constrainted_loss_calculator, fitness_vector, inf_loss):
        problem = constrainted_loss_calculator(
            StaticHelpers.unknowns,
            StaticHelpers.bounds,
            'negLL',
            StaticHelpers.data_single,
            caretaker_single.loss_function,
        )
        # Violating any constraint must result in an infinite loss.
        current_loss = problem.fitness(fitness_vector)
        assert np.isinf(current_loss) == inf_loss


class TestPygmoOptimizers():

    @pytest.mark.parametrize('algo_name', list(PygmoOptimizers.optimizers.keys()))
    def test_get_algo_instance_defaults(self, algo_name):
        # Every known optimizer must be instantiable with default settings ...
        PygmoOptimizers.get_optimizer_algo_instance(name=algo_name)

    @pytest.mark.parametrize(
        'algo_name, kwargs',
        [
            ('mbh', {'perturb' : 0.05, 'inner_stop_range' : 1e-3}),
            ('de1220', {'gen' : 10})
        ]
    )
    def test_get_algo_instance(self, algo_name, kwargs):
        # ... and with explicit hyperparameters.
        PygmoOptimizers.get_optimizer_algo_instance(name=algo_name, kwargs=kwargs)


class TestParallelEstimationInfo():

    def test_properties_methods(self, evolutions_trail):
        est_info = ParallelEstimationInfo(archipelago='archipelago_mock', evolutions_trail=evolutions_trail)
        # Smoke-test all derived properties ...
        est_info.average_loss_trail
        est_info.best_loss_trail
        est_info.losses_trail
        est_info.std_loss_trail
        est_info.runtime_trail
        # ... and the plotting method with both axis scalings.
        est_info.plot_loss_trail()
        est_info.plot_loss_trail(x_log=True)
current_losses=np.array([10.1, 10.2, 9.9]), 161 | atol_islands=atol, 162 | rtol_islands=rtol, 163 | current_runtime_min=curr_runtime, 164 | max_runtime_min=max_runtime, 165 | current_evotime_min=curr_evotime, 166 | max_evotime_min=max_evotime, 167 | max_memory_share=max_memory_share, 168 | ) 169 | 170 | @pytest.mark.parametrize('report_level', [0, 1, 2, 3, 4]) 171 | def test_report_evolution_results(self, report_level): 172 | reps = 2 173 | mock_evolution_results = { 174 | 'evo_time_min' : [1]*reps, 175 | 'best_losses' : [[1111, 1111]]*reps, 176 | 'best_estimates' : [ 177 | {'p1' : 1, 'p2' : 10}, 178 | ], 179 | 'estimates_info' : [ 180 | { 181 | 'losses' : [1000, 1100, 1110, 1111], 182 | } 183 | ]*reps, 184 | } 185 | ArchipelagoHelpers.report_evolution_result(mock_evolution_results, report_level) 186 | -------------------------------------------------------------------------------- /tests/test_modelling.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | 5 | import pytest 6 | 7 | from pyfoomb import BioprocessModel 8 | from pyfoomb import Observation 9 | from pyfoomb import ModelState 10 | 11 | from modelling_library import ModelLibrary 12 | from modelling_library import ObservationFunctionLibrary 13 | 14 | 15 | class TestBioprocessModel(): 16 | 17 | @pytest.mark.parametrize('name', ModelLibrary.modelnames) 18 | def test_init_model(self, name): 19 | # Collect required parts to create the model instance 20 | modelclass = ModelLibrary.modelclasses[name] 21 | states = ModelLibrary.states[name] 22 | model_parameters_list = list(ModelLibrary.model_parameters[name].keys()) 23 | # Instantiate the model class, expected to work 24 | modelclass(states=states, model_parameters=model_parameters_list, model_name='my_model') 25 | model = modelclass(states=states, model_parameters=model_parameters_list) 26 | # The states argument for instatiation must must be a list as they have no values like parameters do 27 | 
with pytest.raises(TypeError): 28 | modelclass(states={_state : 0 for _state in states}, model_parameters=model_parameters_list) 29 | # The states can only be set during instatiation 30 | with pytest.raises(AttributeError): 31 | model.states = states 32 | # There is also a str method 33 | print(model) 34 | # States must be unique 35 | with pytest.raises(KeyError): 36 | modelclass(states=states*2, model_parameters=model_parameters_list) 37 | 38 | 39 | def test_init_model_with_events(self): 40 | # Model03 has events 41 | name = 'model03' 42 | modelclass = ModelLibrary.modelclasses[name] 43 | states = ModelLibrary.states[name] 44 | model_parameters = ModelLibrary.model_parameters[name] 45 | # The number of initial_switches can be autodetected 46 | model_v01 = modelclass(states=states, model_parameters=model_parameters) 47 | # Can also explicitly set the intial_switches 48 | model_v02 = modelclass(states=states, model_parameters=model_parameters, initial_switches=[False]) 49 | assert model_v01.initial_switches == model_v02.initial_switches 50 | 51 | def test_set_parameters(self): 52 | # Get a model instance to work with 53 | name = 'model03' 54 | modelclass = ModelLibrary.modelclasses[name] 55 | states = ModelLibrary.states[name] 56 | model_parameters = ModelLibrary.model_parameters[name] 57 | initial_values = ModelLibrary.initial_values[name] 58 | model = modelclass(states=states, model_parameters=list(model_parameters.keys())) 59 | # The BioprocessModel object has a dedicated method to set parameters (model parameters & initial values) 60 | model.set_parameters(model_parameters) 61 | model.set_parameters(initial_values) 62 | # Parameter names must be case-insensitive unique 63 | with pytest.raises(KeyError): 64 | non_unique_params = {str.upper(_iv) : 1 for _iv in initial_values} 65 | non_unique_params.update(initial_values) 66 | model.set_parameters(non_unique_params) 67 | # Keys for initial values must match the state names extended by "0" 68 | model.initial_values 
= {f'{_state}0' : 1 for _state in states} 69 | with pytest.raises(KeyError): 70 | model.initial_values = {f'{_state}X' : 1 for _state in states} 71 | # Initial values and model parameters must be a dict 72 | with pytest.raises(TypeError): 73 | model.initial_values = [_iv for _iv in initial_values] 74 | with pytest.raises(TypeError): 75 | model.model_parameters = [_p for _p in model_parameters] 76 | # After init, no new model_parameters can be introduced 77 | with pytest.raises(KeyError): 78 | model.model_parameters = {**model_parameters, 'new_p' : 1} 79 | # Number of initial_switches cannot be changed after init 80 | initial_switches = ModelLibrary.initial_switches[name] 81 | with pytest.raises(ValueError): 82 | model.initial_switches = initial_switches*2 83 | # Initial switches must be a list of booleans 84 | with pytest.raises(ValueError): 85 | model.initial_switches = ['False' for _ in initial_switches] 86 | 87 | 88 | class TestObservationFunction(): 89 | 90 | @pytest.mark.parametrize('name', ObservationFunctionLibrary.names) 91 | def test_init(self, name): 92 | obsfun = ObservationFunctionLibrary.observation_functions[name] 93 | observed_state = ObservationFunctionLibrary.observed_states[name] 94 | observation_parameters = ObservationFunctionLibrary.observation_function_parameters[name] 95 | obsfun(observed_state=observed_state, observation_parameters=list(observation_parameters.keys())) 96 | 97 | def test_get_observations(self): 98 | # Create an ObservationFunction 99 | name = 'obsfun01' 100 | obsfun = ObservationFunctionLibrary.observation_functions[name] 101 | observed_state = ObservationFunctionLibrary.observed_states[name] 102 | observation_parameters = ObservationFunctionLibrary.observation_function_parameters[name] 103 | observation_function = obsfun(observed_state=observed_state, observation_parameters=list(observation_parameters.keys())) 104 | 105 | # After creating the ObservationFunction, all parameter values are None, regardless if a list of 
dictionary is used as argument for parameters 106 | for _p in observation_parameters: 107 | assert observation_function.observation_parameters[_p] is None 108 | 109 | # One must explicitly set the parameter values 110 | observation_function.set_parameters(observation_parameters) 111 | for _p in observation_parameters: 112 | assert observation_function.observation_parameters[_p] is not None 113 | 114 | # Create a ModelState that now can be observed observe 115 | modelstate = ModelState( 116 | name=ObservationFunctionLibrary.observed_states[name], 117 | timepoints=[1, 2, 3], 118 | values=[10, 20, 30], 119 | ) 120 | observation_function.get_observation(modelstate) 121 | 122 | # The ModelState to be observed must match the ObservationsFunction's replicate_id 123 | modelstate.replicate_id = '1st' 124 | with pytest.raises(ValueError): 125 | observation_function.get_observation(modelstate) 126 | 127 | # Same for the state name 128 | modelstate.name = 'other_state' 129 | with pytest.raises(KeyError): 130 | observation_function.get_observation(modelstate) 131 | 132 | def test_properties(self): 133 | # Create an ObservationFunction 134 | name = 'obsfun01' 135 | obsfun = ObservationFunctionLibrary.observation_functions[name] 136 | observed_state = ObservationFunctionLibrary.observed_states[name] 137 | observation_parameters = ObservationFunctionLibrary.observation_function_parameters[name] 138 | observation_function = obsfun(observed_state=observed_state, observation_parameters=list(observation_parameters.keys())) 139 | observation_function.set_parameters(observation_parameters) 140 | 141 | # Can't change the observed state after instantiation 142 | with pytest.raises(AttributeError): 143 | observation_function.observed_state = 'new_state' 144 | 145 | # Cant't set unknown parameters 146 | with pytest.raises(KeyError): 147 | observation_function.observation_parameters = {'unknown_parameter' : 1000} 148 | 149 | # Observed state parameter must match the corresponding property of 
ObservationFunction 150 | with pytest.raises(ValueError): 151 | observation_function.observation_parameters = {**observation_parameters, 'observed_state' : 'unknown_state'} 152 | 153 | # Must use a dictionary to set the property 154 | with pytest.raises(ValueError): 155 | observation_function.observation_parameters = list(observation_parameters.keys()) 156 | 157 | # There is a str method 158 | print(observation_function) 159 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | *.coverage* 13 | 14 | # User-specific files (MonoDevelop/Xamarin Studio) 15 | *.userprefs 16 | 17 | # Build results 18 | [Dd]ebug/ 19 | [Dd]ebugPublic/ 20 | [Rr]elease/ 21 | [Rr]eleases/ 22 | x64/ 23 | x86/ 24 | [Aa][Rr][Mm]/ 25 | [Aa][Rr][Mm]64/ 26 | bld/ 27 | [Bb]in/ 28 | [Oo]bj/ 29 | [Ll]og/ 30 | 31 | # Visual Studio 2015/2017 cache/options directory 32 | .vs/ 33 | # Uncomment if you have tasks that create the project's static files in wwwroot 34 | #wwwroot/ 35 | 36 | # Visual Studio 2017 auto generated files 37 | Generated\ Files/ 38 | 39 | # MSTest test Results 40 | [Tt]est[Rr]esult*/ 41 | [Bb]uild[Ll]og.* 42 | 43 | # NUNIT 44 | *.VisualState.xml 45 | TestResult.xml 46 | 47 | # Build Results of an ATL Project 48 | [Dd]ebugPS/ 49 | [Rr]eleasePS/ 50 | dlldata.c 51 | 52 | # Benchmark Results 53 | BenchmarkDotNet.Artifacts/ 54 | 55 | # .NET Core 56 | project.lock.json 57 | project.fragment.lock.json 58 | artifacts/ 59 | 60 | # StyleCop 61 | StyleCopReport.xml 62 | 63 | # Files built by Visual Studio 64 | *_i.c 65 | *_p.c 66 | *_h.h 67 | *.ilk 
68 | *.meta 69 | *.obj 70 | *.iobj 71 | *.pch 72 | *.pdb 73 | *.ipdb 74 | *.pgc 75 | *.pgd 76 | *.rsp 77 | *.sbr 78 | *.tlb 79 | *.tli 80 | *.tlh 81 | *.tmp 82 | *.tmp_proj 83 | *_wpftmp.csproj 84 | *.log 85 | *.vspscc 86 | *.vssscc 87 | .builds 88 | *.pidb 89 | *.svclog 90 | *.scc 91 | 92 | # Chutzpah Test files 93 | _Chutzpah* 94 | 95 | # Visual C++ cache files 96 | ipch/ 97 | *.aps 98 | *.ncb 99 | *.opendb 100 | *.opensdf 101 | *.sdf 102 | *.cachefile 103 | *.VC.db 104 | *.VC.VC.opendb 105 | 106 | # Visual Studio profiler 107 | *.psess 108 | *.vsp 109 | *.vspx 110 | *.sap 111 | 112 | # Visual Studio Trace Files 113 | *.e2e 114 | 115 | # TFS 2012 Local Workspace 116 | $tf/ 117 | 118 | # Guidance Automation Toolkit 119 | *.gpState 120 | 121 | # ReSharper is a .NET coding add-in 122 | _ReSharper*/ 123 | *.[Rr]e[Ss]harper 124 | *.DotSettings.user 125 | 126 | # JustCode is a .NET coding add-in 127 | .JustCode 128 | 129 | # TeamCity is a build add-in 130 | _TeamCity* 131 | 132 | # DotCover is a Code Coverage Tool 133 | *.dotCover 134 | 135 | # AxoCover is a Code Coverage Tool 136 | .axoCover/* 137 | !.axoCover/settings.json 138 | 139 | # Visual Studio code coverage results 140 | *.coverage 141 | *.coveragexml 142 | 143 | # NCrunch 144 | _NCrunch_* 145 | .*crunch*.local.xml 146 | nCrunchTemp_* 147 | 148 | # MightyMoose 149 | *.mm.* 150 | AutoTest.Net/ 151 | 152 | # Web workbench (sass) 153 | .sass-cache/ 154 | 155 | # Installshield output folder 156 | [Ee]xpress/ 157 | 158 | # DocProject is a documentation generator add-in 159 | DocProject/buildhelp/ 160 | DocProject/Help/*.HxT 161 | DocProject/Help/*.HxC 162 | DocProject/Help/*.hhc 163 | DocProject/Help/*.hhk 164 | DocProject/Help/*.hhp 165 | DocProject/Help/Html2 166 | DocProject/Help/html 167 | 168 | # Click-Once directory 169 | publish/ 170 | 171 | # Publish Web Output 172 | *.[Pp]ublish.xml 173 | *.azurePubxml 174 | # Note: Comment the next line if you want to checkin your web deploy settings, 175 | # but database 
connection strings (with potential passwords) will be unencrypted 176 | *.pubxml 177 | *.publishproj 178 | 179 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 180 | # checkin your Azure Web App publish settings, but sensitive information contained 181 | # in these scripts will be unencrypted 182 | PublishScripts/ 183 | 184 | # NuGet Packages 185 | *.nupkg 186 | # The packages folder can be ignored because of Package Restore 187 | **/[Pp]ackages/* 188 | # except build/, which is used as an MSBuild target. 189 | !**/[Pp]ackages/build/ 190 | # Uncomment if necessary however generally it will be regenerated when needed 191 | #!**/[Pp]ackages/repositories.config 192 | # NuGet v3's project.json files produces more ignorable files 193 | *.nuget.props 194 | *.nuget.targets 195 | 196 | # Microsoft Azure Build Output 197 | csx/ 198 | *.build.csdef 199 | 200 | # Microsoft Azure Emulator 201 | ecf/ 202 | rcf/ 203 | 204 | # Windows Store app package directories and files 205 | AppPackages/ 206 | BundleArtifacts/ 207 | Package.StoreAssociation.xml 208 | _pkginfo.txt 209 | *.appx 210 | 211 | # Visual Studio cache files 212 | # files ending in .cache can be ignored 213 | *.[Cc]ache 214 | # but keep track of directories ending in .cache 215 | !?*.[Cc]ache/ 216 | 217 | # Others 218 | ClientBin/ 219 | ~$* 220 | *~ 221 | *.dbmdl 222 | *.dbproj.schemaview 223 | *.jfm 224 | *.pfx 225 | *.publishsettings 226 | orleans.codegen.cs 227 | 228 | # Including strong name files can present a security risk 229 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 230 | #*.snk 231 | 232 | # Since there are multiple workflows, uncomment next line to ignore bower_components 233 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 234 | #bower_components/ 235 | 236 | # RIA/Silverlight projects 237 | Generated_Code/ 238 | 239 | # Backup & report files from converting an old project file 240 | # to a newer Visual Studio version. 
Backup files are not needed, 241 | # because we have git ;-) 242 | _UpgradeReport_Files/ 243 | Backup*/ 244 | UpgradeLog*.XML 245 | UpgradeLog*.htm 246 | ServiceFabricBackup/ 247 | *.rptproj.bak 248 | 249 | # SQL Server files 250 | *.mdf 251 | *.ldf 252 | *.ndf 253 | 254 | # Business Intelligence projects 255 | *.rdl.data 256 | *.bim.layout 257 | *.bim_*.settings 258 | *.rptproj.rsuser 259 | *- Backup*.rdl 260 | 261 | # Microsoft Fakes 262 | FakesAssemblies/ 263 | 264 | # GhostDoc plugin setting file 265 | *.GhostDoc.xml 266 | 267 | # Node.js Tools for Visual Studio 268 | .ntvs_analysis.dat 269 | node_modules/ 270 | 271 | # Visual Studio 6 build log 272 | *.plg 273 | 274 | # Visual Studio 6 workspace options file 275 | *.opt 276 | 277 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 278 | *.vbw 279 | 280 | # Visual Studio LightSwitch build output 281 | **/*.HTMLClient/GeneratedArtifacts 282 | **/*.DesktopClient/GeneratedArtifacts 283 | **/*.DesktopClient/ModelManifest.xml 284 | **/*.Server/GeneratedArtifacts 285 | **/*.Server/ModelManifest.xml 286 | _Pvt_Extensions 287 | 288 | # Paket dependency manager 289 | .paket/paket.exe 290 | paket-files/ 291 | 292 | # FAKE - F# Make 293 | .fake/ 294 | 295 | # JetBrains Rider 296 | .idea/ 297 | *.sln.iml 298 | 299 | # CodeRush personal settings 300 | .cr/personal 301 | 302 | # Python Tools for Visual Studio (PTVS) 303 | __pycache__/ 304 | *.pyc 305 | 306 | # Cake - Uncomment if you are using it 307 | # tools/** 308 | # !tools/packages.config 309 | 310 | # Tabs Studio 311 | *.tss 312 | 313 | # Telerik's JustMock configuration file 314 | *.jmconfig 315 | 316 | # BizTalk build output 317 | *.btp.cs 318 | *.btm.cs 319 | *.odx.cs 320 | *.xsd.cs 321 | 322 | # OpenCover UI analysis results 323 | OpenCover/ 324 | 325 | # Azure Stream Analytics local run output 326 | ASALocalRun/ 327 | 328 | # MSBuild Binary and Structured Log 329 | *.binlog 330 | 331 | # NVidia Nsight GPU debugger configuration 
class OwnDict(collections.OrderedDict):
    """
    Extends OrderedDict with a `to_numpy()` method.
    """

    def to_numpy(self) -> numpy.ndarray:
        """Returns this dict's values as a 1-D numpy array, in insertion order."""
        return numpy.array(list(self.values()))


class Calculations():

    @staticmethod
    def cov_into_corr(Cov:numpy.ndarray) -> numpy.ndarray:
        """
        Calculates correlation matrix from variance-covariance matrix.

        Arguments
        ---------
        Cov : numpy.ndarray
            Variance-covariance matrix, must be square and positive semi-definite.

        Returns
        -------
        Corr : numpy.ndarray
            Correlation matrix for Cov.

        Raises
        ------
        ValueError
            Cov is not square.
        """

        if Cov.shape[0] != Cov.shape[1]:
            raise ValueError('Cov must be square')

        # Corr[i, j] = Cov[i, j] / (sd_i * sd_j), computed in one vectorized
        # step via the outer product of the standard deviations, replacing the
        # accidental O(n^2) Python double loop of the original implementation.
        _sd = numpy.sqrt(numpy.diag(Cov))
        Corr = Cov / numpy.outer(_sd, _sd)
        return Corr
80 | """ 81 | 82 | success = True 83 | 84 | if isinstance(values, set): 85 | return success 86 | if len(values) == 1: 87 | return success 88 | 89 | _values = copy.deepcopy(values) 90 | 91 | if isinstance(_values, list): 92 | _values.sort(key=str.lower) 93 | values_str_lower = [_value.lower() for _value in _values] 94 | if len(_values) > len(set(values_str_lower)): 95 | success = False 96 | elif isinstance(_values, (dict, OwnDict)): 97 | _values = list(_values.keys()) 98 | values_str_lower = [_value.lower() for _value in _values] 99 | if len(_values) > len(set(values_str_lower)): 100 | success = False 101 | else: 102 | raise TypeError(f'Type {type(values)} cannot be handled.') 103 | if not success and report: 104 | print(f'Bad, non-unique (case-insensitive) ids: {_values}') 105 | 106 | return success 107 | 108 | @staticmethod 109 | def all_measurements_have_errors(measurements:List[Measurement]) -> bool: 110 | """ 111 | Checks whether if Measurement objects have errors. 112 | """ 113 | 114 | with_errors = [] 115 | for measurement in measurements: 116 | if measurement.errors is None: 117 | with_errors.append(False) 118 | else: 119 | with_errors.append(True) 120 | return all(with_errors) 121 | 122 | 123 | @staticmethod 124 | def get_unique_timepoints(time_series:List[TimeSeries]) -> numpy.ndarray: 125 | """ 126 | Creates a joint unique time vector from all timepoints of a list of TimeSeries objects. 127 | 128 | Arguments 129 | --------- 130 | time_series : List[TimeSeries] 131 | The list of TimeSeries (and subclasses thereof) for which a joint time vector is wanted. 132 | 133 | Returns 134 | ------- 135 | t_all : numpy.ndarray 136 | The joint vector of time points. 
137 | """ 138 | _t = [ 139 | _timepoint 140 | for _time_series in time_series 141 | for _timepoint in _time_series.timepoints.flatten() 142 | ] 143 | return numpy.unique(_t) 144 | 145 | @staticmethod 146 | def extract_time_series(time_series:List[TimeSeries], name:str, replicate_id:str, no_extraction_warning:bool=False) -> TimeSeries: 147 | """ 148 | Extract a specific TimeSeries object, identified by its properties `name` and `replicate_id`. 149 | In case no match is found, None is returned. 150 | 151 | Arguments 152 | --------- 153 | time_series : List[TimeSeries] 154 | The list from which the specific TimeSeries object shall be extracted. 155 | name : str 156 | The identifying `name` property. 157 | replicate_id : str 158 | The identifying `replicate_id` property. 159 | 160 | Keyword arguments 161 | ----------------- 162 | no_extraction_warning : bool 163 | Whether to raise a warning when no TimeSeries object can be extracted. 164 | Default is False 165 | 166 | Returns 167 | ------- 168 | extracted_time_series : TimeSeries or None 169 | 170 | Raises 171 | ------ 172 | ValueError 173 | Multiple TimeSeries objects have the same `name` and `replicate_id` property. 174 | 175 | Warns 176 | ----- 177 | UserWarning 178 | No TimeSeries object match the criteria. 179 | Only raised for `no_extraction_warning` set to True. 180 | """ 181 | 182 | _extracted_time_series = [ 183 | _time_series for _time_series in time_series 184 | if _time_series.name == name and _time_series.replicate_id == replicate_id 185 | ] 186 | 187 | if len(_extracted_time_series) > 1: 188 | raise ValueError('List of (subclassed) TimeSeries objects is ambigous. 
Found multiple occurences ') 189 | if len(_extracted_time_series) == 0: 190 | extracted_time_series = None 191 | if no_extraction_warning: 192 | warnings.warn(f'Could not extract a TimeSeries object with replicate_id {replicate_id} and name {name}') 193 | else: 194 | extracted_time_series = _extracted_time_series[0] 195 | 196 | return extracted_time_series 197 | 198 | @staticmethod 199 | def get_parameters_length(parameter_collections:Dict[str, numpy.ndarray]) -> int: 200 | """ 201 | Arguments 202 | --------- 203 | parameter_collections : Dict[str, numpy.ndarray] 204 | A set of parameters (model parameters, initial values, observation parameters). 205 | 206 | Returns 207 | ------- 208 | length : int 209 | The number of values for each parameter 210 | 211 | Raises 212 | ------ 213 | ValueError 214 | Parameters have different number of estimated values. 215 | """ 216 | 217 | lengths = set([len(parameter_collections[_p]) for _p in parameter_collections]) 218 | if len(lengths) > 1: 219 | raise ValueError('Parameters have different number of estimated values.') 220 | length = list(lengths)[0] 221 | return length 222 | 223 | @staticmethod 224 | def split_parameters_distributions(parameter_collections:Dict[str, numpy.ndarray]) -> List[Dict]: 225 | """ 226 | Arguments 227 | --------- 228 | parameter_collections : Dict[str, numpy.ndarray] 229 | A set of parameters (model parameters, initial values, observation parameters). 230 | 231 | Returns 232 | ------- 233 | splits : List[Dict] 234 | A list of separate parameter dictonaries for each slice of the parameter collections. 
235 | """ 236 | 237 | _length = Helpers.get_parameters_length(parameter_collections) 238 | splits = [ 239 | { 240 | _p : parameter_collections[_p][i] 241 | for _p in parameter_collections 242 | } 243 | for i in range(_length) 244 | ] 245 | return splits -------------------------------------------------------------------------------- /tests/test_simulation.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import warnings 4 | 5 | from assimulo.solvers.sundials import CVodeError 6 | 7 | import pytest 8 | 9 | from pyfoomb.datatypes import ModelState 10 | from pyfoomb.datatypes import Measurement 11 | 12 | from pyfoomb.simulation import Simulator 13 | from pyfoomb.simulation import ExtendedSimulator 14 | from pyfoomb.simulation import ModelObserver 15 | 16 | import modelling_library 17 | from modelling_library import ModelLibrary 18 | from modelling_library import ObservationFunctionLibrary 19 | 20 | 21 | @pytest.fixture(params=ModelLibrary.modelnames) 22 | def simulator(request): 23 | modelclass = ModelLibrary.modelclasses[request.param] 24 | model_parameters = ModelLibrary.model_parameters[request.param] 25 | initial_values = ModelLibrary.initial_values[request.param] 26 | return Simulator(bioprocess_model_class=modelclass, model_parameters=model_parameters, initial_values=initial_values) 27 | 28 | 29 | @pytest.fixture(params=ModelLibrary.modelnames) 30 | def extended_simulator(request): 31 | modelclass = ModelLibrary.modelclasses[request.param] 32 | model_parameters = ModelLibrary.model_parameters[request.param] 33 | initial_values = ModelLibrary.initial_values[request.param] 34 | return ExtendedSimulator(bioprocess_model_class=modelclass, model_parameters=model_parameters, initial_values=initial_values) 35 | 36 | 37 | class TestSimulator(): 38 | 39 | @pytest.mark.parametrize('name', ModelLibrary.modelnames) 40 | def test_init(self, name): 41 | modelclass = ModelLibrary.modelclasses[name] 42 | states = 
@pytest.mark.parametrize('t', [24, [0, 1, 2, 3]])
def test_simulate(self, simulator, t):
    """Simulation accepts both a scalar end time and an explicit time vector."""
    simulator.simulate(t=t)
    # Unknown parameters have no effect and pass silently
    simulator.simulate(t=t, parameters={'unknown' : np.nan})
    # Non-critical integrator messages sent to stdout are suppressed by
    # default; here suppression is explicitly turned off
    simulator.simulate(t=t, suppress_stdout=False)
def test_simulator_with_observations(self):
    """A Simulator built with observation functions simulates states and observations."""
    # Building blocks for the BioprocessModel
    model_name = 'model01'
    model_class = ModelLibrary.modelclasses[model_name]
    # Building blocks for the ObservationFunction
    observer_name = 'obsfun01'
    observation_function = ObservationFunctionLibrary.observation_functions[observer_name]
    observation_parameters = {
        **ObservationFunctionLibrary.observation_function_parameters[observer_name],
        'observed_state' : ObservationFunctionLibrary.observed_states[observer_name],
    }
    simulator = Simulator(
        bioprocess_model_class=model_class,
        model_parameters=ModelLibrary.model_parameters[model_name],
        initial_values=ModelLibrary.initial_values[model_name],
        observation_functions_parameters=[(observation_function, observation_parameters)],
    )
    simulator.simulate(t=24)
    simulator.simulate(t=24, reset_afterwards=True)

    # The observation cannot target a not define modelstate
    bad_parameters = {
        **ObservationFunctionLibrary.observation_function_parameters[observer_name],
        'observed_state' : 'unknown_state',
    }
    with pytest.raises(ValueError):
        simulator = Simulator(
            bioprocess_model_class=model_class,
            model_parameters=ModelLibrary.model_parameters[model_name],
            initial_values=ModelLibrary.initial_values[model_name],
            observation_functions_parameters=[(observation_function, bad_parameters)],
        )
@pytest.mark.parametrize('t', [24, [1, 2, 3]])
@pytest.mark.parametrize('metric', ['SS', 'WSS', 'negLL'])
@pytest.mark.parametrize('handle_CVodeError', [True, False])
def test_get_loss(self, extended_simulator, t, metric, handle_CVodeError):
    """Loss evaluation works for all metrics, with and without CVodeError handling."""
    # Artificial data: measurements built from the simulator's own predictions
    artificial_data = [
        Measurement(
            name=_sim.name,
            timepoints=_sim.timepoints,
            values=_sim.values,
            errors=np.ones_like(_sim.values),
        )
        for _sim in extended_simulator.simulate(t=t)
    ]
    extended_simulator._get_loss(metric=metric, measurements=artificial_data)

    # Loss will be nan in case no relevant measurements are provided,
    # i.e. measurements for states that are not defined
    irrelevant_data = [
        Measurement(name='y1000', timepoints=[100, 200], values=[1, 2], errors=[10, 20]),
    ]
    assert np.isnan(extended_simulator._get_loss(metric=metric, measurements=irrelevant_data))

    # Get loss for other parameters, as a minimizer would do several times
    _params = extended_simulator.get_all_parameters()
    extended_simulator._get_loss_for_minimzer(
        metric=metric,
        guess_dict={_p : _params[_p] * 0.95 for _p in _params},
        measurements=artificial_data,
        handle_CVodeError=handle_CVodeError,
        verbosity_CVodeError=False,
    )
def test_extended_simulator_with_observations(self):
    """An ExtendedSimulator with observation functions evaluates a loss on artificial data."""
    model_name = 'model01'
    observer_name = 'obsfun01'
    observation_function = ObservationFunctionLibrary.observation_functions[observer_name]
    observation_parameters = {
        **ObservationFunctionLibrary.observation_function_parameters[observer_name],
        'observed_state' : ObservationFunctionLibrary.observed_states[observer_name],
    }

    # Set new values for parameters, using an extended simulator
    extended_simulator = ExtendedSimulator(
        bioprocess_model_class=ModelLibrary.modelclasses[model_name],
        model_parameters=ModelLibrary.model_parameters[model_name],
        initial_values=ModelLibrary.initial_values[model_name],
        observation_functions_parameters=[(observation_function, observation_parameters)],
    )
    _params = extended_simulator.get_all_parameters()
    extended_simulator.set_parameters({_p : _params[_p] * 1.05 for _p in _params})

    # Predictions serve as artificial data for the loss evaluation
    artificial_data = [
        Measurement(
            name=_sim.name,
            timepoints=_sim.timepoints,
            values=_sim.values,
            errors=np.ones_like(_sim.values),
        )
        for _sim in extended_simulator.simulate(t=24)
    ]
    extended_simulator._get_loss(metric='negLL', measurements=artificial_data)
@dataclass
class ParameterMapper():
    """
    Maps global parameter names to local ones, specific for a certain replicate
    """

    replicate_id : str
    global_name : str
    local_name : str = None
    value : float = None

    def __post_init__(self):
        # Auto-derive the local name only when none was given. This is only
        # well-defined for a single, concrete replicate_id.
        if self.local_name is not None:
            return
        if isinstance(self.replicate_id, list) or self.replicate_id == 'all':
            raise ValueError('Argument `local_name` cannot be None when `replicate_id` is "all" or a list of replicate_ids')
        self.local_name = f'{self.global_name}_{self.replicate_id}'
65 | """ 66 | 67 | def __init__(self, replicate_ids:list, parameters:dict): 68 | """ 69 | Arguments 70 | --------- 71 | replicate_ids : list 72 | A list of case-sensitive unique replicate ids. 73 | 74 | parameters : dict 75 | Parameters that are set. Keys are the global parameter names, which are used as local names. 76 | Values are set correspondingly. 77 | """ 78 | 79 | self._is_init = True 80 | self.global_parameters = parameters 81 | self.replicate_ids = replicate_ids 82 | self._parameters = [ 83 | Parameter( 84 | global_name=p, 85 | replicate_id=_id, 86 | value=self.global_parameters[p], 87 | ) 88 | for p in self.global_parameters.keys() 89 | for _id in self.replicate_ids 90 | ] 91 | self._is_init = False 92 | 93 | 94 | #%% Properties 95 | 96 | @property 97 | def replicate_ids(self) -> list: 98 | return self._replicate_ids 99 | 100 | 101 | @replicate_ids.setter 102 | def replicate_ids(self, value:list): 103 | if self._is_init: 104 | if not Helpers.has_unique_ids(value): 105 | raise ValueError(Messages.non_unique_ids) 106 | self._replicate_ids = value 107 | else: 108 | raise AttributeError('Replicate ids can only be set during initialization of the ParameterManager') 109 | 110 | 111 | @property 112 | def global_parameters(self) -> list: 113 | return self._global_parameters 114 | 115 | 116 | @global_parameters.setter 117 | def global_parameters(self, value): 118 | if self._is_init: 119 | if not Helpers.has_unique_ids(value): 120 | raise ValueError(Messages.non_unique_ids) 121 | if isinstance(value, list): 122 | _value = {p : numpy.nan for p in sorted(value, key=str.lower)} 123 | elif isinstance(value, dict): 124 | _value = {p : value[p] for p in sorted(value.keys(), key=str.lower)} 125 | self._global_parameters = _value 126 | else: 127 | raise AttributeError('Global parameters can only be set during initialization of the ParameterManager') 128 | 129 | 130 | @property 131 | def parameter_mapping(self) -> pandas.DataFrame: 132 | _values = [p.value for p in 
def set_parameter_values(self, parameters:dict):
    """
    Assigns values to some parameters.
    Valid keys are the global names or local names or model parameters, initial values,
    or observation parameters, according to the current parameter mapping.

    Arguments
    ---------
    parameters : dict
        The parameter names and corresponding values to be set.

    Warns
    -----
    UserWarning
        Values for unknown parameters are set.
    """

    known_parameters = {_parameter.local_name for _parameter in self._parameters}
    unknown_parameters = set(parameters.keys()).difference(known_parameters)
    if unknown_parameters:
        warnings.warn(f'Detected unknown parameters, which are ignored: {unknown_parameters}', UserWarning)

    # Single pass over the managed parameters with an O(1) dict lookup,
    # replacing the former O(len(parameters) * len(self._parameters)) nested loop.
    for _parameter in self._parameters:
        if _parameter.local_name in parameters:
            _parameter.value = parameters[_parameter.local_name]
def get_parameter_mappers(self) -> 'List[ParameterMapper]':
    """
    Returns a list of ParameterMapper objects representing the current parameter mapping.
    """

    # One mapper per managed Parameter, preserving the internal order.
    return [
        ParameterMapper(
            replicate_id=p.replicate_id,
            global_name=p.global_name,
            local_name=p.local_name,
            value=p.value,
        )
        for p in self._parameters
    ]
Extracts the parameters for a specific replicate. 238 | 239 | Arguments 240 | --------- 241 | replicate_id : str 242 | The specific (unique) id of a replicate. 243 | 244 | Returns 245 | ------- 246 | OwnDict with keys as global parameter names and corresponding values. 247 | """ 248 | 249 | parameters_dict = {} 250 | for _parameter in self._parameters: 251 | if _parameter.replicate_id == replicate_id: 252 | parameters_dict[_parameter.global_name] = _parameter.value 253 | return OwnDict(parameters_dict) 254 | 255 | 256 | #%% Private methods 257 | 258 | def _apply_single_mapping(self, replicate_id:list, global_name:str, local_name:str, value:float=None): 259 | """ 260 | Helper method that applies a single mapping. 261 | 262 | Arguments 263 | --------- 264 | replicate_id : list, or str, or 'all' 265 | Identifies the replicates for which the mapping is applied. 266 | Can be a single id, a list of those, or 'all'. 267 | global_name : str 268 | Identifies the global parameter that is mapped. 269 | local_name : str 270 | The local, replicate-specific name of the global parameter. 271 | 272 | Keyword arguments 273 | ----------------- 274 | value : float 275 | The parameters value for the mapping. 276 | Default is None, which uses the value of the corresponding global parameter. 
277 | """ 278 | 279 | if replicate_id == 'all': 280 | replicate_id = self.replicate_ids 281 | 282 | # make a single item list in case only one replicate_id is addressed 283 | if isinstance(replicate_id, str): 284 | replicate_id = [replicate_id] 285 | 286 | for _parameter in self._parameters: 287 | for _replicate_id in replicate_id: 288 | if _parameter.replicate_id == _replicate_id and _parameter.global_name == global_name: 289 | _parameter.local_name = local_name 290 | if value is not None: 291 | _parameter.value = value 292 | 293 | 294 | def _check_mappings(self, mappings:List[ParameterMapper]): 295 | """ 296 | Checks that the mappings have unique pairs for local_names and value, 297 | as well as each mappings item is a Parameter Mapping objects. 298 | 299 | Arguments 300 | --------- 301 | mappings : List[ParameterMapper] 302 | The list of parameter mappings to be applied. 303 | 304 | Raises 305 | ------ 306 | TypeError 307 | An item of mappings is not a Parameter object. 308 | ValueError 309 | Same parameters in mappings have not unique values. 310 | """ 311 | 312 | local_names = [] 313 | values = [] 314 | for mapping in mappings: 315 | 316 | if not isinstance(mapping, ParameterMapper): 317 | raise TypeError(f'Items of mappings must be of type ParameterMapper. Invalid mapping item: {mapping}') 318 | local_names.append(mapping.local_name) 319 | values.append(mapping.value) 320 | 321 | name_value_mapping = {p : None for p in sorted(set(local_names), key=str.lower)} 322 | for _name, _value in zip(local_names, values): 323 | if name_value_mapping[_name] is None: 324 | name_value_mapping[_name] = _value 325 | elif name_value_mapping[_name] != _value: 326 | raise ValueError( 327 | f'Parameters of mappings to be applied must have unique values. Parameter with at least two different values detected. "{_name}" with {name_value_mapping[_name]} and {_value}.' 
328 | ) 329 | 330 | 331 | def _check_joint_uniqueness_local_names_and_values(self, backup_parameters:List[Parameter]): 332 | """ 333 | Checks that the application of a valid mapping does not result in non-unique pairs of local_name and value 334 | 335 | Arguments 336 | --------- 337 | backup_parameters : List[Parameter] 338 | The backup parameters will be applied in case a ValueError is raised 339 | 340 | Raises 341 | ------ 342 | ValueError 343 | A parameter among different replicates has different values. 344 | """ 345 | 346 | local_names = [_parameter.local_name for _parameter in self._parameters] 347 | values = [ _parameter.value for _parameter in self._parameters] 348 | name_value_mapping = {p : None for p in sorted(set(local_names), key=str.lower)} 349 | for _name, _value in zip(local_names, values): 350 | if name_value_mapping[_name] is None: 351 | name_value_mapping[_name] = _value 352 | elif name_value_mapping[_name] != _value: 353 | self._parameters = backup_parameters 354 | raise ValueError( 355 | f'Parameters must have unique values. Parameter with at least two different values detected. "{_name}" with {name_value_mapping[_name]} and {_value}.' 
356 | ) 357 | -------------------------------------------------------------------------------- /tests/test_datatypes.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pytest 4 | import scipy 5 | 6 | from pyfoomb.datatypes import Measurement 7 | from pyfoomb.datatypes import ModelState 8 | from pyfoomb.datatypes import Observation 9 | from pyfoomb.datatypes import Sensitivity 10 | from pyfoomb.datatypes import TimeSeries 11 | 12 | 13 | class StaticHelpers(): 14 | 15 | name = 'TestName' 16 | timepoints = [1, 2, 3, 4, 5] 17 | values = [100, 200, 300, 400, 500] 18 | errors = [1/1, 1/2, 1/3, 1/4, 1/5] 19 | replicate_id = '1st' 20 | state = 'y1' 21 | parameter = 'p1' 22 | 23 | 24 | class StaticDatatypes(): 25 | 26 | timeseries = TimeSeries( 27 | name=StaticHelpers.name, 28 | timepoints=StaticHelpers.timepoints, 29 | values=StaticHelpers.values, 30 | replicate_id=StaticHelpers.replicate_id 31 | ) 32 | 33 | modelstate = ModelState( 34 | name=StaticHelpers.name, 35 | timepoints=StaticHelpers.timepoints, 36 | values=StaticHelpers.values, 37 | replicate_id=StaticHelpers.replicate_id 38 | ) 39 | 40 | measurement_wo_errs = Measurement( 41 | name=StaticHelpers.name, 42 | timepoints=StaticHelpers.timepoints, 43 | values=StaticHelpers.values, 44 | replicate_id=StaticHelpers.replicate_id 45 | ) 46 | 47 | measurement_w_errs = Measurement( 48 | name=StaticHelpers.name, 49 | timepoints=StaticHelpers.timepoints, 50 | values=StaticHelpers.values, 51 | errors=StaticHelpers.errors, 52 | replicate_id=StaticHelpers.replicate_id 53 | ) 54 | 55 | observation = Observation( 56 | name=StaticHelpers.name, 57 | timepoints=StaticHelpers.timepoints, 58 | values=StaticHelpers.values, 59 | observed_state=StaticHelpers.state, 60 | replicate_id=StaticHelpers.replicate_id 61 | ) 62 | 63 | sensitivity = Sensitivity( 64 | timepoints=StaticHelpers.timepoints, 65 | values=StaticHelpers.values, 66 | response=StaticHelpers.state, 67 | 
class StaticErrorModelHelpers():
    """Error models and matching parameter sets used across the datatype tests."""

    constant_error_model_parameters = {
        'offset' : 0,
    }

    linear_error_model_parameters = {
        'offset' : 0,
        'slope' : 1,
    }

    squared_error_model_parameters = {
        'w0' : 1,
        'w1' : 0.1,
        'w2' : 0.02,
    }

    @staticmethod
    def constant_error_model(values, parameters):
        # Identical constant error for every value
        return np.ones_like(values) * parameters['offset']

    @staticmethod
    def linear_error_model(values, parameters):
        # Error grows linearly with the measured value
        return parameters['slope'] * values + parameters['offset']

    @staticmethod
    def squared_error_model(values, parameters):
        # Second-order polynomial in the measured values
        w0, w1, w2 = (parameters[_key] for _key in ('w0', 'w1', 'w2'))
        return w0 + w1 * values + w2 * np.square(values)
class TestInstantiationVariousDatatypes():
    """Typical and failing instantiations of the pyfoomb datatypes."""

    @pytest.mark.parametrize(
        'values, errors, info, replicate_id',
        [
            ([[10], [20], [30], [40], [50]], None, None, None),
            (StaticHelpers.values, None, None, None),
            (StaticHelpers.values, StaticHelpers.errors, None, None),
            (StaticHelpers.values, StaticHelpers.errors, 'TestInfo', None),
            (StaticHelpers.values, StaticHelpers.errors, 'TestInfo', '1st'),
        ]
    )
    def test_init_datatypes(self, values, errors, info, replicate_id):
        """
        To test typical instantiations of datatypes.
        """

        # Standard instantiations
        TimeSeries(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=values, info=info, replicate_id=replicate_id)
        ModelState(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=values, info=info, replicate_id=replicate_id)
        Measurement(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=values, errors=errors, info=info, replicate_id=replicate_id)
        Observation(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=values, observed_state='y1', replicate_id=replicate_id)
        Sensitivity(timepoints=StaticHelpers.timepoints, values=values, response='y1', parameter='p1', replicate_id=replicate_id)
        Sensitivity(timepoints=StaticHelpers.timepoints, values=values, response='y1', parameter='p1', h=1e-8, replicate_id=replicate_id)

        # Must provide timepoints
        with pytest.raises(ValueError):
            TimeSeries(name=StaticHelpers.name, timepoints=None, values=values)

        # Must provide values
        with pytest.raises(ValueError):
            TimeSeries(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=None)

        # Measurements can be created with error_models
        Measurement(
            name=StaticHelpers.name,
            timepoints=StaticHelpers.timepoints,
            values=values,
            error_model=StaticErrorModelHelpers.constant_error_model,
            error_model_parameters=StaticErrorModelHelpers.constant_error_model_parameters,
        )

        # Must provide a subclass of rv_continuous as p.d.f.
        with pytest.raises(ValueError):
            Measurement(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=values, error_distribution=scipy.stats.bernoulli)

        # Error values must be >0
        with pytest.raises(ValueError):
            Measurement(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=values, errors=[0]*len(StaticHelpers.values))
        with pytest.raises(ValueError):
            Measurement(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=values, errors=[-1]*len(StaticHelpers.values))

    @pytest.mark.parametrize(
        'input_vector, masked_vector',
        [
            ([10, None, 30, 40, 50], [10, 30, 40, 50]),
            ([10, 20, np.nan, 40, 50], [10, 20, 40, 50]),
            ([10, 20, 30, np.inf, 50], [10, 20, 30, 50]),
            ([10, 20, 30, 40, -np.inf], [10, 20, 30, 40]),
            ([10, None, np.nan, np.inf, -np.inf], [10]),
        ]
    )
    def test_init_datatypes_masking_non_numeric(self, input_vector, masked_vector):
        """
        Non-numeric values implicitly define a mask, which is in turn applied to all vectors of the corresponding datatypes.
        """

        # BUGFIX: the former assertions `all(x) == all(y)` only compared two booleans
        # (truthiness of each whole vector) and were therefore vacuous. They are
        # replaced by element-wise resp. length comparisons against the mask.
        _timeseries = TimeSeries(name=StaticHelpers.name, timepoints=input_vector, values=StaticHelpers.values)
        assert list(_timeseries.timepoints) == list(masked_vector)
        assert _timeseries.length == len(masked_vector)

        # Here the mask stems from `values`, so only the masked length can be compared.
        _timeseries = TimeSeries(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=input_vector)
        assert len(_timeseries.timepoints) == len(masked_vector)
        assert _timeseries.length == len(masked_vector)

        # Here the mask stems from `errors`, so only the masked length can be compared.
        _measurement = Measurement(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=StaticHelpers.values, errors=input_vector)
        assert len(_measurement.timepoints) == len(masked_vector)
        assert _measurement.length == len(masked_vector)


class TestSetters():

    @pytest.mark.parametrize(
        'bad_vector',
        [
            ([10, 20, 30, 40]),  # length does not match the length of the other vectors
            ([[10, 20, 30, 40, 50], [10, 20, 30, 40, 50]]),  # Cannot be cast into 1D vector
            ([[10, 20, 30, 40, 50]]),
        ],
    )
    def test_setters_reject_bad_vectors(self, bad_vector):
        """
        Testing that the setters do not accept bad_vectors.
        """

        with pytest.raises(ValueError):
            StaticDatatypes.timeseries.values = bad_vector
        with pytest.raises(ValueError):
            StaticDatatypes.measurement_w_errs.errors = bad_vector


class TestPlot():

    @pytest.mark.parametrize(
        'datatype', [
            (StaticDatatypes.timeseries),
            (StaticDatatypes.modelstate),
            (StaticDatatypes.measurement_wo_errs),
            (StaticDatatypes.observation),
            (StaticDatatypes.sensitivity)
        ]
    )
    def test_plotting(self, datatype):
        """
        The pyfoomb datatypes come with an own plot method for rapid development.
        Based on some properties of the datatype, the plot will auto-generated legend and title in different ways.
        Some arguments are also tested.
        """

        datatype.plot()
        datatype.plot(title='Some title')
        datatype.replicate_id = StaticHelpers.replicate_id
        datatype.plot()
        datatype.info = 'Some info'
        datatype.plot()


class TestMeasurementErrorModels():

    @pytest.mark.parametrize(
        'error_model, error_model_parameters',
        [
            (StaticErrorModelHelpers.constant_error_model, StaticErrorModelHelpers.constant_error_model_parameters),
            (StaticErrorModelHelpers.linear_error_model, StaticErrorModelHelpers.linear_error_model_parameters),
            (StaticErrorModelHelpers.squared_error_model, StaticErrorModelHelpers.squared_error_model_parameters),
        ]
    )
    def test_update_error_models_parameters(self, error_model, error_model_parameters):
        """
        Updates error_models for existing Measurement objects
        """

        # create measurement first
        measurement = Measurement(name=StaticHelpers.name, timepoints=StaticHelpers.timepoints, values=StaticHelpers.values)

        # To use a different (new) error_model, it must be passed with its corresponding error_model_parameters
        measurement.update_error_model(error_model=error_model, error_model_parameters=error_model_parameters)

        # Parameter values can be updated, as long as all parameters are present in the new dictionary
        measurement.error_model_parameters = {_p : error_model_parameters[_p]*1.5 for _p in error_model_parameters}

        # In case the error model is applied, a warning can be given for overwriting the error vector
        with pytest.warns(UserWarning):
            measurement.apply_error_model(report_level=1)

        # Setting unknown parameter names won't work
        with pytest.raises(KeyError):
            measurement.error_model_parameters = {'bad_parameter' : 1000}

    @pytest.mark.parametrize(
        'metric',
        [
            ('negLL'),
            ('WSS'),
            ('SS'),
            ('bad_metric')
        ],
    )
    # BUGFIX: renamed from `test_metrics_and_loss_caluclation` (typo in method name).
    def test_metrics_and_loss_calculation(self, metric):

        if metric == 'bad_metric':
            with pytest.raises(NotImplementedError):
                StaticDatatypes.measurement_wo_errs.get_loss(metric=metric, predictions=[StaticDatatypes.modelstate])
        # Only for metric SS (sum-of-squares) the loss can be calculated from Measurement objects without having errors
        elif metric == 'SS':
            assert not np.isnan(StaticDatatypes.measurement_wo_errs.get_loss(metric=metric, predictions=[StaticDatatypes.modelstate]))
            assert not np.isnan(StaticDatatypes.measurement_w_errs.get_loss(metric=metric, predictions=[StaticDatatypes.modelstate]))
        else:
            with pytest.raises(AttributeError):
                StaticDatatypes.measurement_wo_errs.get_loss(metric=metric, predictions=[StaticDatatypes.modelstate])
            assert not np.isnan(StaticDatatypes.measurement_w_errs.get_loss(metric=metric, predictions=[StaticDatatypes.modelstate]))

        # Fail fast for ambiguous, non-unique predictions: list of prediction with the same 'name' and 'replicate_id'
        with pytest.raises(ValueError):
            StaticDatatypes.measurement_w_errs.get_loss(metric=metric, predictions=[StaticDatatypes.modelstate]*2)

        modelstate = ModelState(
            name=StaticHelpers.name,
            timepoints=StaticHelpers.timepoints[::2],  # Use less timepoints
            values=StaticHelpers.values[::2],  # Use less values, corresponding to using less timepoints
        )
        # Using predictions that have not matching replicate_ids returns nan loss
        assert np.isnan(StaticDatatypes.measurement_wo_errs.get_loss(metric='SS', predictions=[modelstate]))

        # When adding at least one matching prediction, a loss can be calculated
        # NOTE(review): for metric='bad_metric' this call presumably raises — confirm intended coverage.
        assert not np.isnan(StaticDatatypes.measurement_w_errs.get_loss(metric=metric, predictions=[modelstate, StaticDatatypes.modelstate]))


class TestMiscellaneous():

    def test_str(self):
        # Smoke test for the string representation of a TimeSeries object.
        print(StaticDatatypes.timeseries)

    def test_other_distributions(self):

        # Create a Measurement object having an error t-distribution
        measurement_t = Measurement(
            name=StaticHelpers.name,
            timepoints=StaticHelpers.timepoints,
            values=StaticHelpers.values,
            errors=StaticHelpers.values,
            error_distribution=scipy.stats.t,
            distribution_kwargs={'df' : 1},
            replicate_id='1st'
        )

        # Get rvs values for, e.g. MC sampling
        measurement_t._get_random_samples_values()

        loss_1 = measurement_t.get_loss(metric='negLL', predictions=[StaticDatatypes.modelstate])
        assert not np.isnan(loss_1)
        loss_2 = measurement_t.get_loss(metric='negLL', predictions=[StaticDatatypes.modelstate], distribution_kwargs={'df' : 100})
        assert not np.isnan(loss_2)

        # Does not work in case there are no errors
        with pytest.raises(AttributeError):
            StaticDatatypes.measurement_wo_errs._get_random_samples_values()


# --- tests/modelling_library.py ---

import numpy as np

from pyfoomb import BioprocessModel
from pyfoomb import ObservationFunction


#%% Models that work

# Two linear rates; parameters unpacked at once via to_numpy().
class Model01(BioprocessModel):
    def rhs(self, t, y, sw=None):
        y0, y1 = y
        rate0, rate1 = self.model_parameters.to_numpy()
        dy0dt = rate0
        dy1dt = rate1
        return np.array([dy0dt, dy1dt])


# First-order decay with a single parameter k.
class Model02(BioprocessModel):
    def rhs(self, t, y, sw=None):
        k = self.model_parameters['k']
        dydt = -k * y
        return dydt
# Two-rate linear model, unpacking each parameter explicitly by its key
# instead of all at once via to_numpy().
class Model04(BioprocessModel):
    def rhs(self, t, y, sw=None):
        y0, y1 = y
        rate0 = self.model_parameters['rate0']
        rate1 = self.model_parameters['rate1']
        dy0dt = rate0
        dy1dt = rate1
        return np.array([dy0dt, dy1dt])


# Four-state linear model; note the rhs signature carries no `sw` argument.
class Model05(BioprocessModel):
    def rhs(self, t, y):
        yA, yB, yC, yD = y
        rateA, rateB, rateC, rateD, rateE = self.model_parameters.to_numpy()
        dyAdt = rateA
        dyBdt = rateB
        dyCdt = rateC
        dyDdt = rateD - rateE
        return np.array([dyAdt, dyBdt, dyCdt, dyDdt])


# Event-driven model with three timed events and a state reset; its rhs divides
# by the parameters when a switch is active, hence the note below.
class Model06(BioprocessModel):
    # This model cannot have rate0 and rate1 to be zero
    def rhs(self, t, y, sw):
        y0, y1 = y
        rate0, rate1 = self.model_parameters.to_numpy()
        if sw[0]:
            dy0dt = 1/rate1
        else:
            dy0dt = rate0

        if sw[1]:
            dy1dt = 1/rate0
        else:
            dy1dt = rate1
        return np.array([dy0dt, dy1dt])

    def state_events(self, t, y, sw):
        y0, y1 = y
        event_t1 = t - 5
        event_t2 = t - 2
        event_t3 = t - 1
        return np.array([event_t1, event_t2, event_t3])

    def change_states(self, t, y, sw):
        y0, y1 = y
        if sw[2]:
            y0 = self.initial_values['y00']
            y1 = self.initial_values['y10']
        return [y0, y1]


# Variant of Model06 returning the events as a plain list instead of an array.
class Model06_V02(Model06):
    # The auto-detection of events will work on the Simulator level, so the user need to provided the initial switches
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_t1 = t - 5
        event_t2 = t - 2
        event_t3 = t - 1
        events = [event_t1, event_t2, event_t3]
        return events
# Variant of Model02 with a timed event that adds 10 to the state.
class Model07(Model02):

    def state_events(self, t, y, sw):
        event_t = t - 10
        return [event_t]

    def change_states(self, t, y, sw):
        if sw[0]:
            y = y + 10
        return y


# Variants of Model03
# The following variants differ only in how many events are returned and how
# the event array literal is spelled (with/without trailing comma).
class Model03_V02(Model03):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y1 = y1
        return np.array([event_y1,])


class Model03_V03(Model03):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        return np.array([event_y0, event_y1])


class Model03_V04(Model03):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        return np.array([event_y0, event_y1,])


class Model03_V05(Model03):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([event_y0, event_y1, event_t])


class Model03_V06(Model03):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([event_y0, event_y1, event_t,])
# Autodetection of number of events from return of method `state_events`
# (events spelled one per line, with a trailing comma).
class Model03_V08(Model03_V07):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([
            event_y0,
            event_y1,
            event_t,
        ])


# Autodetection of number of events from return of method `state_events`
# (events spelled one per line, without a trailing comma).
class Model03_V09(Model03_V07):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([
            event_y0,
            event_y1,
            event_t
        ])


# Autodetection of number of events from return of method `state_events`
# (single wrapped line, without a trailing comma).
class Model03_V10(Model03_V07):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([
            event_y0, event_y1, event_t
        ])


# Autodetection of number of events from return of method `state_events`
# (single wrapped line, with a trailing comma).
class Model03_V11(Model03_V07):
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([
            event_y0, event_y1, event_t,
        ])


# Autodetection of number of events from return of method `state_events`
class Model03_V12(Model03):
    def state_events(self, t, y, sw):
        event_t = t - 5
        return np.array([event_t])


# Autodetection of number of events from return of method `state_events`
class Model03_V13(Model03):
    def state_events(self, t, y, sw):
        event_t = t - 5
        return np.array([event_t,])
# Deliberately inconsistent variants: each one is an intentional fixture that
# should trip a specific ModelChecker warning. Do not "fix" these.
class Model03_BadV02(BioprocessModel):
    # derivatives of state vector are return in the wrong order
    def rhs(self, t, y):
        y0, y1 = y
        rate0, rate1 = self.model_parameters.to_numpy()
        dy0dt = rate0
        dy1dt = rate1
        return np.array([dy1dt, dy0dt])


class Model03_BadV03(BioprocessModel):
    # state vector is unpacked in the wrong order
    # derivatives of state vector are return in the wrong order
    def rhs(self, t, y):
        y1, y0 = y
        rate0, rate1 = self.model_parameters.to_numpy()
        dy0dt = rate0
        dy1dt = rate1
        return np.array([dy1dt, dy0dt])


class Model03_BadV04(BioprocessModel):
    # name of parameter variable does not match the corresponding key
    def rhs(self, t, y):
        y0, y1 = y
        rate0 = self.model_parameters['rate0']
        any_parameter = self.model_parameters['rate1']
        dy0dt = rate0
        dy1dt = any_parameter
        return np.array([dy0dt, dy1dt])


class Model03_BadV05(BioprocessModel):
    # parameters are unpacked in wrong order
    def rhs(self, t, y):
        y0, y1 = y
        rate1, rate0 = self.model_parameters.to_numpy()
        dy0dt = rate0
        dy1dt = rate1
        return np.array([dy0dt, dy1dt])


class Model03_BadV06(Model03_V06):
    # state vector is unpacked in the wrong order
    def state_events(self, t, y, sw):
        y1, y0 = y
        rate0, rate1 = self.model_parameters.to_numpy()
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([event_y0, event_y1, event_t,])
# More intentional fixtures for the ModelChecker: inconsistent unpacking and
# undefined variables that only fail when the method is actually called.
class Model03_BadV08(Model03_V06):
    # name of parameter variable does not match the corresponding key
    def state_events(self, t, y, sw):
        y0, y1 = y
        rate0 = self.model_parameters['rate0']
        any_parameter = self.model_parameters['rate1']
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([event_y0, event_y1, event_t,])


class Model03_BadV09(Model03_V06):
    # name of parameter variable does not match the corresponding key
    def state_events(self, t, y, sw):
        y0, y1 = y
        any_parameter = self.model_parameters['rate1']
        event_y0 = y0
        event_y1 = y1
        event_t = t - 5
        return np.array([event_y0, event_y1, event_t,])


class Model06_Bad01(Model06):
    # Has an undefined variable
    def rhs(self, t, y, sw):
        y0, y1 = y
        rate0, rate1 = self.model_parameters.to_numpy()
        if sw[0]:
            dy0dt = rate99
        else:
            dy0dt = rate0

        if sw[1]:
            dy1dt = rate0
        else:
            dy1dt = rate1
        return np.array([dy0dt, dy1dt])


class Model06_Bad02(Model06):
    # Has an undefined variable
    def rhs(self, t, y, sw):
        y0, y1 = y
        rate0, rate1 = self.model_parameters.to_numpy()
        if sw[0]:
            dy0dt = rate1
        else:
            dy0dt = rate99

        if sw[1]:
            dy1dt = rate0
        else:
            dy1dt = rate1
        return np.array([dy0dt, dy1dt])


class Model06_Bad03(Model06):

    # Has an undefined variable
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_t1 = t - 5
        event_t3 = t - 1
        return np.array([event_t1, event_t2, event_t3])
class Model06_Bad05(Model06):
    # the number of events depends on the switches
    def state_events(self, t, y, sw):
        y0, y1 = y
        event_t1 = t - 5
        event_t2 = t - 2
        event_t3 = t - 1
        if sw[1]:
            events = [event_t1, event_t3]
        else:
            events = [event_t1, event_t2, event_t3]
        return events


class Model06_Bad06(Model06):
    # Inconsistent parameter unpacking
    def change_states(self, t, y, sw):
        y0, y1 = y
        rate00000 = self.model_parameters['rate0']
        return [y0, y1]


class Model06_Bad07(Model06):
    # Inconsistent parameter unpacking
    def change_states(self, t, y, sw):
        y0, y1 = y
        rate1, rate0 = self.model_parameters.to_numpy()
        return [y0, y1]


class Model06_Bad08(Model06):
    # Inconsistent state unpacking
    def change_states(self, t, y, sw):
        y1, y0 = y
        rate0 = self.model_parameters['rate0']
        return [y0, y1]



class ModelLibrary():
    """Catalog of the models above plus matching parametrizations for the tests."""

    modelnames = [
        'model01',
        'model02',
        'model03',
        'model04',
        'model05',
        'model06',
        'model07',
    ]

    modelclasses = {
        'model01' : Model01,
        'model02' : Model02,
        'model03' : Model03,
        'model04' : Model04,
        'model05' : Model05,
        'model06' : Model06,
        'model07' : Model07
    }

    states = {
        'model01' : ['y0' , 'y1'],
        'model02' : ['y'],
        'model03' : ['y0' , 'y1'],
        'model04' : ['y0' , 'y1'],
        'model05' : ['yA', 'yB', 'yC', 'yD'],
        'model06' : ['y0' , 'y1'],
        'model07' : ['y'],
    }

    model_parameters = {
        'model01' : {'rate0' : 0.0, 'rate1' : 1.0},
        'model02' : {'k' : 0.02},
        'model03' : {'rate0' : 2.0, 'rate1' : 3.0},
        'model04' : {'rate0' : 4.0, 'rate1' : 5.0},
        'model05' : {'rateA' : 10.0, 'rateB' : 11.0, 'rateC' : 12.0, 'rateD' : 13.0, 'rateE' : 14.0},
        'model06' : {'rate0' : -2.0, 'rate1' : -3.0},
        'model07' : {'k' : 0.02},
    }

    initial_values = {
        'model01' : {'y00' : 0.0, 'y10' : 1.0},
        'model02' : {'y0' : 100.0},
        'model03' : {'y00' : 2.0, 'y10' : 3.0},
        'model04' : {'y00' : 4.0, 'y10' : 5.0},
        'model05' : {'yA0' : 100.0, 'yB0' : 200.0, 'yC0' : 300.0, 'yD0': 400.0},
        'model06' : {'y00' : 20.0, 'y10' : 30.0},
        'model07' : {'y0' : 100.0},
    }

    initial_switches = {
        'model01' : None,
        'model02' : None,
        'model03' : [False],
        'model04' : None,
        'model05' : None,
        'model06' : [False, False, False],
        'model07' : None,
    }

    # BUGFIX: Model03_BadV01 was accidentally listed twice, causing a redundant
    # parametrized test run; the duplicate entry was removed.
    bad_variants_model03 = [
        Model03_BadV01,
        Model03_BadV02,
        Model03_BadV03,
        Model03_BadV04,
        Model03_BadV05,
        Model03_BadV06,
        Model03_BadV07,
        Model03_BadV08,
        Model03_BadV09,
    ]

    bad_variants_model06 = [
        Model06_Bad01,
        Model06_Bad02,
        Model06_Bad03,
        Model06_Bad04,
        Model06_Bad05,
        Model06_Bad06,
        Model06_Bad07,
        Model06_Bad08,
    ]

    variants_model03 = [
        Model03_V02,
        Model03_V03,
        Model03_V04,
        Model03_V05,
        Model03_V06,
        Model03_V07,
        Model03_V08,
        Model03_V09,
        Model03_V10,
        Model03_V11,
        Model03_V12,
        Model03_V13
    ]


# Linear observation of a single state, unpacking each parameter by key.
class ObservationFunction01(ObservationFunction):

    def observe(self, state_values):
        slope_01 = self.observation_parameters['slope_01']
        offset_01 = self.observation_parameters['offset_01']
        return state_values * slope_01 + offset_01
# Variants of ObservationFunction02 that unpack all five parameters at once,
# using backslash line continuations in different spellings (the ModelChecker
# must recognize all of them as correct).
class ObservationFunction02_V02(ObservationFunction):
    def observe(self, state_values):
        p1, \
        p2, p3, \
        p4, p5 \
        = self.observation_parameters.to_numpy()
        return state_values + p1 + p2 + p3 + p4 + p5


class ObservationFunction02_V03(ObservationFunction):
    def observe(self, state_values):
        p1, \
        p2, p3, \
        p4, p5 = self.observation_parameters.to_numpy()
        return state_values + p1 + p2 + p3 + p4 + p5


# Variant unpacking each parameter explicitly by its key.
class ObservationFunction02_V04(ObservationFunction):
    def observe(self, state_values):
        p1 = self.observation_parameters['p1']
        p2 = self.observation_parameters['p2']
        p3 = self.observation_parameters['p3']
        p4 = self.observation_parameters['p4']
        p5 = self.observation_parameters['p5']
        return state_values + p1 + p2 + p3 + p4 + p5


# Deliberately inconsistent fixtures for the ModelChecker warnings. Do not "fix".
class ObservationFunction01_Bad01(ObservationFunction):
    # observation parameters are unpacked in the wrong order
    def observe(self, model_values):
        slope_01, offset_01 = self.observation_parameters.to_numpy()
        return model_values * slope_01 + offset_01


class ObservationFunction01_Bad02(ObservationFunction):
    # observation parameter variable name does not match corresponding keys
    def observe(self, model_values):
        slope_01 = self.observation_parameters['slope_01']
        some_offset = self.observation_parameters['offset_01']
        return model_values * slope_01 + some_offset
# --- pyfoomb/model_checking.py ---

import inspect
import numpy

from typing import Callable, List
import warnings

from .constants import Messages
from .modelling import BioprocessModel
from .modelling import ObservationFunction
from .simulation import Simulator


class ModelChecker():
    """
    A helper class providing methods to assist users in consistent model implementation.
    """

    def check_model_consistency(self, simulator:Simulator, report:bool=True) -> bool:
        """
        Runs several consistency checks for the implemented bioprocess model and observations functions for a Simulator instance.

        Arguments
        ---------
        simulator : Simulator

        Keyword arguments
        -----------------
        report : bool
            Reports if a model and/or observer is not fully specified.
            Default is True

        Returns
        -------
        bool
            The status of the currently implemented model consistency checking routines.

        Warns
        -----
        UserWarning
            Some consistency checks failed.
        """

        # NOTE(review): the `report` keyword is currently not referenced in this
        # body — confirm whether warnings were meant to be gated on it.
        checks_ok = []

        # Check for correct unpacking of states in rhs method
        _check_ok = self._check_state_unpacking(simulator.bioprocess_model.rhs, simulator.bioprocess_model.states)
        if not _check_ok:
            warnings.warn('A possible inconsistency for state vector unpacking in method `rhs` was detected.')
        checks_ok.append(_check_ok)

        # Check for correct order of state derivatives
        _check_ok = self._check_rhs_derivatives_order(simulator.bioprocess_model.rhs, simulator.bioprocess_model.states)
        if not _check_ok:
            warnings.warn('A possible inconsistency for returning order of state derivatives in method `rhs` was detected.')
        checks_ok.append(_check_ok)

        # Check for correct parameter unpacking in rhs method
        _check_ok = self._check_parameter_unpacking(simulator.bioprocess_model.rhs, simulator.bioprocess_model.model_parameters)
        if not _check_ok:
            warnings.warn('A possible inconsistency for parameter unpacking in method `rhs` was detected.')
        checks_ok.append(_check_ok)

        # Event-related methods are only checked for models that define switches.
        if simulator.bioprocess_model.initial_switches is not None:

            # Check for sw arg in rhs signature
            _check_ok = self._check_sw_arg(simulator.bioprocess_model.rhs)
            checks_ok.append(_check_ok)

            # Check for correct unpacking of states in state_events method
            _check_ok = self._check_state_unpacking(simulator.bioprocess_model.state_events, simulator.bioprocess_model.states)
            if not _check_ok:
                warnings.warn('A possible inconsistency for state vector unpacking in method `state_events` was detected.')
            checks_ok.append(_check_ok)

            # Check for correct parameter unpacking in state_events method
            _check_ok = self._check_parameter_unpacking(simulator.bioprocess_model.state_events, simulator.bioprocess_model.model_parameters)
            if not _check_ok:
                warnings.warn('A possible inconsistency for parameter unpacking in method `state_events` was detected.')
            checks_ok.append(_check_ok)

            # Check for correct unpacking of states in change_states method
            _check_ok = self._check_state_unpacking(simulator.bioprocess_model.change_states, simulator.bioprocess_model.states)
            if not _check_ok:
                warnings.warn('A possible inconsistency for state vector unpacking in method `change_states` was detected.')
            checks_ok.append(_check_ok)

            # Check for correct parameter unpacking in change_states method
            _check_ok = self._check_parameter_unpacking(simulator.bioprocess_model.change_states, simulator.bioprocess_model.model_parameters)
            if not _check_ok:
                # BUGFIX: this warning previously named `state_events` (copy-paste);
                # the check above is for `change_states`.
                warnings.warn('A possible inconsistency for parameter unpacking in method `change_states` was detected.')
            checks_ok.append(_check_ok)

        # Call methods to see if there are any issues
        # NOTE(review): placement relative to the switches branch was reconstructed
        # from a flattened source — confirm against the original indentation.
        checks_ok.append(
            self._call_checks_bioprocess_model_methods(simulator, True)
        )

        if simulator.observer is not None:

            for _obs_fun in simulator.observer.observation_functions:
                _observation_function = simulator.observer.observation_functions[_obs_fun]

                checks_ok.append(
                    self._check_observe_method(_observation_function.observe, _observation_function.observation_parameters)
                )

                self._call_check_observe_method(_observation_function)

        return all(checks_ok)
# NOTE: methods of ModelChecker (shown unindented in this flattened extract).
def _check_observe_method(self, method:Callable, observation_parameters:dict) -> bool:
    """
    Checks the observe method of an ObservationFunction subclass for inconsistencies.

    Arguments
    ---------
    method : Callable
        The `observe` method of an ObservationFunction subclass.

    observation_parameters : dict
        The corresponding observation parameter values.

    Returns
    -------
    check_ok : bool
        The status of the currently implemented consistency check

    Warns
    -----
    UserWarning
        Parameters are unpacked in the wrong order.
        Variable names do not match the keys of the `observation_parameters`.
    """

    check_ok = True

    _lines = inspect.getsourcelines(method)
    # BUGFIX: comments are now stripped per line *before* joining. Previously
    # `.split('#')[0]` was applied to the joined text, so a '#' anywhere in the
    # method source truncated everything after it and silently skipped checks.
    all_in_one = ''.join(_line.split('#')[0] for _line in _lines[0]).replace('\n', '').replace(' ', '').replace('\\', '').replace(',]', ']')

    # check for correct parameter unpacking when model_parameters are unpacked at once
    if 'self.observation_parameters.to_numpy()' in all_in_one:
        search_str = str(list(observation_parameters.keys())).replace(' ', '').replace("'", "").replace('[','').replace(']','')+'=self.observation_parameters.to_numpy()'
        if not search_str in all_in_one:
            correct_str = search_str.replace(',', ', ').replace('=', ' = ')
            warnings.warn(
                f'Detected wrong order of parameter unpacking at once. Correct order is {correct_str}',
                UserWarning
            )
            check_ok = False

    for _line in _lines[0]:
        curr_line = _line.replace(' ', '').replace('\n','').split('#')[0]
        # check correct variable naming for explicit parameter unpacking
        if 'self.observation_parameters[' in curr_line:
            ok_unpack = False
            for p in list(observation_parameters.keys()):
                valid_par_var1 = f"{p}=self.observation_parameters['{p}']"
                valid_par_var2 = f'{p}=self.observation_parameters["{p}"]'
                if valid_par_var1 in curr_line or valid_par_var2 in curr_line:
                    ok_unpack = True
                    break
            if not ok_unpack:
                _line_msg = _line.replace('\n','')
                warnings.warn(
                    f'Variable names from explicit parameter unpacking must match those of the corresponding keys.\nThis line is bad: {_line_msg}',
                    UserWarning
                )
                check_ok = False

    return check_ok


def _call_check_observe_method(self, observation_function:'ObservationFunction'):
    """
    Calls the `observe` method of an ObservationFunction object to check for errors.

    Arguments
    ---------
    observation_function : ObservationFunction
        An ObservationFunction object of the current Simulator instance under investigation.
    """

    # Exercise scalar and vector inputs, including negative values and zero.
    state_values = [
        -1.0,
        0.0,
        1.0,
        numpy.array([0, 1,]),
        numpy.array([-1, 0, 1,])
    ]

    for _state_values in state_values:
        observation_function.observe(_state_values)
208 | 209 | Returns 210 | ------- 211 | check_ok : bool 212 | The status of the currently implemented consistency check 213 | 214 | Warns 215 | ----- 216 | UserWarning 217 | States are unpacked in the wrong order. 218 | """ 219 | 220 | check_ok = True 221 | 222 | _lines = inspect.getsourcelines(method) 223 | _code_text = ''.join(_lines[0]).replace('\n', '').replace(' ', '').replace('\\', '').replace(',]', ']').split('#')[0] 224 | _doc = inspect.getdoc(method) 225 | _doc_text = _doc.replace('\n', '').replace(' ', '').replace('\\', '').replace(',]', ']') 226 | code_text = _code_text.replace(_doc_text, '') 227 | 228 | # Check for correct unpacking of states 229 | if '=y' in code_text: 230 | states_str = str(states).replace("'", '').replace('[','').replace(']','').replace(' ','')+'=y' 231 | if states_str not in code_text: 232 | correct_states_str = states_str.replace(',', ', ').replace('=', ' = ') 233 | warnings.warn( 234 | f'{Messages.unpacking_state_vector}. Correct order would be {correct_states_str}', 235 | UserWarning, 236 | ) 237 | check_ok = False 238 | 239 | return check_ok 240 | 241 | 242 | def _check_parameter_unpacking(self, method:Callable, model_parameters:dict) -> bool: 243 | """ 244 | Checks a methhod for consitent parameter unpacking. 245 | 246 | Arguments 247 | --------- 248 | method : Callable 249 | the method to be checked. 250 | 251 | model_parameters : dict 252 | The corresponding model parameter values. 253 | 254 | Returns 255 | ------- 256 | check_ok : bool 257 | The status of the currently implemented consistency check. 258 | 259 | Warns 260 | ----- 261 | UserWarning 262 | Parameters are unpacked in the wrong order. 263 | Variable names do not match the keys of the `model_parameters`. 
264 | """ 265 | 266 | check_ok = True 267 | 268 | _lines = inspect.getsourcelines(method) 269 | _code_text = ''.join(_lines[0]).replace('\n', '').replace(' ', '').replace('\\', '').replace(',]', ']').split('#')[0] 270 | _doc = inspect.getdoc(method) 271 | _doc_text = _doc.replace('\n', '').replace(' ', '').replace('\\', '').replace(',]', ']') 272 | code_text = _code_text.replace(_doc_text, '') 273 | 274 | # Check for correct parameter unpacking when model_parameters are unpacked at once 275 | if 'self.model_parameters.to_numpy()' in code_text: 276 | search_str = str(list(model_parameters.keys())).replace(' ', '').replace("'", "").replace('[','').replace(']','')+'=self.model_parameters.to_numpy()' 277 | if not search_str in code_text: 278 | correct_str = search_str.replace(',', ', ').replace('=', ' = ') 279 | warnings.warn( 280 | f'Detected wrong order of parameter unpacking at once. Correct order would be {correct_str}', 281 | UserWarning, 282 | ) 283 | check_ok = False 284 | 285 | # Check correct parameter unpacking using the model_parameters dict keys 286 | # First get the lines of the doc string and remove any whitespaces 287 | _doc_lines = _doc.replace(' ', '').split('\n') 288 | for _line in _lines[0]: 289 | # Make sure that not the lines of the docstring 290 | if _line.replace(' ', '').replace('\n','') not in _doc_lines: 291 | curr_line = _line.replace(' ', '').replace('\n','').split('#')[0] 292 | # Check correct variable naming for explicit parameter unpacking 293 | if 'self.model_parameters[' in curr_line: 294 | ok_unpack = False 295 | for p in list(model_parameters.keys()): 296 | valid_par_var1 = f"{p}=self.model_parameters['{p}']" 297 | valid_par_var2 = f'{p}=self.model_parameters["{p}"]' 298 | if valid_par_var1 in curr_line or valid_par_var2 in curr_line: 299 | ok_unpack = True 300 | break 301 | if not ok_unpack: 302 | _line_msg = _line.replace('\n','') 303 | warnings.warn( 304 | f'Variable names from explicit parameter unpacking should match those of the 
corresponding keys.\nThis line seems bad: {_line_msg}.\nValid model parameters are {list(model_parameters.keys())}', 305 | UserWarning, 306 | ) 307 | check_ok = False 308 | 309 | return check_ok 310 | 311 | 312 | def _check_sw_arg(self, method:Callable) -> bool: 313 | """ 314 | Checks for use of `sw` argument in corresponding bioprocess model methods. 315 | 316 | Arguments 317 | --------- 318 | method : Callable 319 | The method in whose signature the `sw` argument shall be used. 320 | 321 | Returns 322 | ------- 323 | check_ok : bool 324 | The status of the currently implemented consistency check. 325 | 326 | Warns 327 | ----- 328 | UserWarning 329 | The `sw` argument is missing. 330 | """ 331 | 332 | check_ok = True 333 | 334 | _lines = inspect.getsourcelines(method) 335 | code_text = ''.join(_lines[0]).replace('\n', '').replace(' ', '').replace('\\', '').replace(',]', ']') 336 | 337 | first_line = _lines[0][0].replace(' ', '').replace('\n','') 338 | if not 'sw' in first_line: 339 | warnings.warn(Messages.missing_sw_arg, UserWarning) 340 | check_ok = False 341 | 342 | return check_ok 343 | 344 | 345 | def _check_rhs_derivatives_order(self, rhs_method:Callable, states:list) -> bool: 346 | """ 347 | Check that the order of the returned state derivatives is correct. 348 | 349 | Arguments 350 | --------- 351 | rhs_method : Callable 352 | The right-hand-side method implemented by the user. 353 | states : list 354 | The states of the bioprocess model. 355 | Returns 356 | ------- 357 | check_ok : bool 358 | The status of the currently implemented consistency check. 359 | 360 | Warns 361 | ----- 362 | UserWarning 363 | The returned derivatives are not in the correct order. 
364 | """ 365 | 366 | check_ok = True 367 | 368 | _lines = inspect.getsourcelines(rhs_method) 369 | code_text = ''.join(_lines[0]).replace('\n', '').replace(' ', '').replace('\\', '').replace(',]', ']') 370 | 371 | # Check rhs return for correct order of derivatives 372 | ders = [f'd{_state}dt' for _state in states] 373 | ders_str = str(ders).replace("'", '').replace('[','').replace(']','').replace(' ','') 374 | if not ders_str in code_text: 375 | correct_ders_str = ders_str.replace(',',', ') 376 | warnings.warn( 377 | f'{Messages.wrong_return_order_state_derivatives}. Correct return is numpy.array([{correct_ders_str}])', 378 | UserWarning, 379 | ) 380 | check_ok = False 381 | 382 | return check_ok 383 | 384 | 385 | def _call_checks_bioprocess_model_methods(self, simulator:Simulator, report:bool=False) -> bool: 386 | """ 387 | Runs several call checks for the implemented bioprocess model and observations functions for a Simulator instance. 388 | 389 | Arguments 390 | --------- 391 | simulator : Simulator 392 | 393 | Keyword arguments 394 | ----------------- 395 | report : bool 396 | Reports if a model and/or observer is not fully specified. 397 | Default is True 398 | 399 | Returns 400 | ------- 401 | bool 402 | The status of the currently implemented model consistency checking routines. 403 | 404 | Warns 405 | ----- 406 | UserWarning 407 | The `state_events` methods returns a different number of events in certain situations. 408 | The number of initial switches does not match the number of events returned by the state_events method. 409 | """ 410 | 411 | check_ok = True 412 | 413 | # Call method `rhs` 414 | if simulator.bioprocess_model.initial_switches is not None: 415 | try: 416 | simulator.bioprocess_model.rhs( 417 | t=0, 418 | y=simulator.bioprocess_model.initial_values.to_numpy(), 419 | sw=simulator.bioprocess_model.initial_switches, 420 | ) 421 | except Exception as e: 422 | check_ok = False 423 | warnings.warn(f'Set `initial_switches` argument. 
Autodetection for number of events failed: {e}', UserWarning) 424 | return check_ok 425 | try: 426 | # Invert the switches 427 | simulator.bioprocess_model.rhs( 428 | t=0, 429 | y=simulator.bioprocess_model.initial_values.to_numpy(), 430 | sw=numpy.invert(simulator.bioprocess_model.initial_switches), 431 | ) 432 | except Exception as e: 433 | check_ok = False 434 | warnings.warn(f'Set `initial_switches` argument. Autodetection for number of events failed: {e}', UserWarning) 435 | return check_ok 436 | else: 437 | simulator.bioprocess_model.rhs( 438 | t=0, 439 | y=simulator.bioprocess_model.initial_values.to_numpy(), 440 | ) 441 | 442 | # Call method 'state_events' 443 | state_events_list_01 = simulator.bioprocess_model.state_events( 444 | t=0, 445 | y=simulator.bioprocess_model.initial_values.to_numpy(), 446 | sw=simulator.bioprocess_model.initial_switches, 447 | ) 448 | if simulator.bioprocess_model.initial_switches is not None: 449 | state_events_list_02 = simulator.bioprocess_model.state_events( 450 | t=0, 451 | y=simulator.bioprocess_model.initial_values.to_numpy(), 452 | sw=numpy.invert(simulator.bioprocess_model.initial_switches), 453 | ) 454 | 455 | # Call method `change_states` 456 | simulator.bioprocess_model.change_states( 457 | t=0, 458 | y=simulator.bioprocess_model.initial_values.to_numpy(), 459 | sw=simulator.bioprocess_model.initial_switches, 460 | ) 461 | if simulator.bioprocess_model.initial_switches is not None: 462 | simulator.bioprocess_model.change_states( 463 | t=0, 464 | y=simulator.bioprocess_model.initial_values.to_numpy(), 465 | sw=numpy.invert(simulator.bioprocess_model.initial_switches), 466 | ) 467 | 468 | # Check length of returned event list with length of initial switches 469 | if simulator.bioprocess_model.initial_switches is not None: 470 | if len(state_events_list_01) != len(state_events_list_02): 471 | warnings.warn( 472 | 'The number of returned events seems to vary with the states of the switches', 473 | UserWarning, 474 | ) 
475 | check_ok = False 476 | elif len(state_events_list_01) != len(simulator.bioprocess_model.initial_switches): 477 | warnings.warn( 478 | f'Number of initial switches does not match with number of events: {len(simulator.bioprocess_model.initial_switches)} vs. {len(state_events_list_01)}', 479 | UserWarning, 480 | ) 481 | check_ok = False 482 | 483 | return check_ok -------------------------------------------------------------------------------- /pyfoomb/modelling.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import copy 3 | import inspect 4 | import numpy 5 | from typing import List 6 | import warnings 7 | 8 | from assimulo.problem import Explicit_Problem 9 | from assimulo.solvers.sundials import CVode 10 | 11 | from .constants import Constants 12 | from .constants import Messages 13 | 14 | from .datatypes import ModelState 15 | from .datatypes import Observation 16 | 17 | from .utils import Helpers 18 | from .utils import OwnDict 19 | 20 | OBSERVED_STATE_KEY = Constants.observed_state_key 21 | 22 | 23 | class BioprocessModel(Explicit_Problem): 24 | """ 25 | The abstract base class for bioprocess models, implemented as system of ODEs. 26 | Supports event handling and corresponding modification of states, parameters or whole equations. 27 | For integration of the ODE system, the CVode solver by Sundials is used, which is made available via the assimulo package. 28 | The BioprocessModel class subclasses assimulos `Explicit_Problem`, as recommended, see https://jmodelica.org/assimulo/ODE_CVode.html. 29 | """ 30 | 31 | def __init__(self, model_parameters:list, states:list, initial_switches:list=None, model_name:str=None, replicate_id:str=None): 32 | """ 33 | Arguments 34 | --------- 35 | model_parameters : list 36 | The (time-invariant) model parameters. 37 | states : list 38 | The names of the model states. 
39 | 40 | Keyword arguments 41 | ----------------- 42 | initial_switches : 43 | A list of booleans, indicating the initial state of switches. 44 | Number of switches must correpond to the number of return events in method `state_events`, 45 | if this method is implemented by the inheriting class. 46 | Default is None, which enables auto-detection of initial switches, which all will be False. 47 | model_name : str 48 | A descriptive model name. 49 | Default is None. 50 | replicate_id : str 51 | Makes this `BioprocessModel` instance know about the `replicate_id` is is assigned to. 52 | Default is None, which implies a single replicate model. 53 | """ 54 | 55 | self._is_init = True 56 | self.replicate_id = replicate_id 57 | self.states = states 58 | self.initial_values = {f'{state}0' : numpy.nan for state in self.states} 59 | self.model_parameters = {model_parameter : numpy.nan for model_parameter in model_parameters} 60 | 61 | if initial_switches is None: 62 | try: 63 | _no_of_events = len(self.state_events(t=0, y=self.initial_values.to_numpy(), sw=None)) 64 | except Exception as e: 65 | print(f'Falling back to detect number of events: {e}') 66 | _no_of_events = self._auto_detect_no_of_events() 67 | print(f'Detected {_no_of_events} events') 68 | self.initial_switches = [False] * _no_of_events 69 | else: 70 | self.initial_switches = initial_switches 71 | 72 | if model_name is not None: 73 | self._name = model_name 74 | else: 75 | self._name = self.__class__.__name__ 76 | 77 | super(BioprocessModel, self).__init__( 78 | y0=self.initial_values.to_numpy(), 79 | sw0=self.initial_switches, 80 | name=self._name, 81 | ) 82 | 83 | self._is_init = False 84 | 85 | 86 | #%% Methods that implement the actual model 87 | 88 | @abc.abstractmethod 89 | def rhs(self, t:float, y:numpy.ndarray, sw:List[bool]) -> List[float]: 90 | """ 91 | Defines the right-hand-side of the explicit ODE formulation. This method will be integrated by the solver. 
92 | 93 | Arguments 94 | --------- 95 | t : float 96 | The current time. 97 | y : numpy.ndarray 98 | The vector holding the current values of the model states. 99 | sw : List[bool] 100 | The current switch states. A switch is turned after its corresponding event was hit. 101 | Use this argument only if the model implements events. 102 | 103 | NOTE: An event equals not zero at one instant timepoint, while its corresponding switch is turned afterwards, 104 | and maintains its state until the event occurs again. 105 | 106 | Returns 107 | ------- 108 | List[float] or numpy.array 109 | The corresponding vector of derivatives for argument y. Must be the same order as `y`. 110 | """ 111 | 112 | 113 | def state_events(self, t:float, y:numpy.ndarray, sw:List[bool]) -> List[float]: 114 | """ 115 | Defines the roots (events) of the model states. 116 | An event is defined as y_i = 0, detected by an change in sign of y_i. 117 | 118 | Arguments 119 | --------- 120 | t : float 121 | The current time. 122 | y : numpy.ndarray 123 | The vector holding the current values of the model states. 124 | sw : list of bool 125 | The current switch states. A switch is turned after its corresponding event was hit. 126 | Use this argument only if the model implements events 127 | 128 | NOTE: An event equals not zero at one instant timepoint, while its corresponding switch is turned afterwards, 129 | and maintains its state until the event occurs again. 
130 | 131 | Returns 132 | ------- 133 | List[float] or numpy.ndarray 134 | 135 | Example 136 | ------- 137 | # unpack the state vector for more convenient reading 138 | P, S, X = y 139 | 140 | X_ind = self.model_parameters['X_ind'] 141 | 142 | # event is when y[2] - X_ind = 0 143 | event_X = X - X_ind 144 | # event is hit when integration time is 20 145 | event_t = 20 - t 146 | 147 | return [event_X, event_t] 148 | """ 149 | 150 | return numpy.array([]) 151 | 152 | 153 | def change_states(self, t:float, y:numpy.ndarray, sw:List[bool]) -> List[float]: 154 | """ 155 | Initialize the ODE system with the new conditions, i.e. change the values of state variables 156 | depending on the value of an state_event_info list (can be 1, -1, or 0). 157 | 158 | NOTE: This method is only called in case ANY event is hit. 159 | One can filter which event was hit by evaluating `solver.sw` and `state_event_info`. 160 | 161 | Arguments 162 | --------- 163 | t : float 164 | The current time. 165 | y : array or list 166 | The vector holding the current values of the model states. 167 | sw : list of bool 168 | The current switch states. A switch is turned after its corresponding event was hit. 169 | 170 | Returns 171 | ------- 172 | List[float] or numpy.ndarray 173 | The updated state vector for restart of integration. 174 | 175 | Example 176 | ------- 177 | # Unpacks the state vector. The states are alphabetically ordered. 178 | A, B, C = y 179 | 180 | # Change state A when the second event is hit. 181 | if sw[1]: 182 | A_add = self.model_parameters['A_add'] 183 | A = A + A_add 184 | 185 | return [A, B, C] 186 | """ 187 | 188 | return y 189 | 190 | 191 | #%% Helper methods for handling the model implementation, need normally not to be implemented by the subclass 192 | 193 | def handle_event(self, solver:CVode, event_info:list): 194 | """ 195 | Handling events that are discovered during the integration process. 196 | Normally, this method does not need to be overridden by the subclass. 
197 | """ 198 | 199 | state_event_info = event_info[0] # Not the 'time events', has their own method (event_info is a list) 200 | 201 | while True: 202 | # turn event switches of the solver instance 203 | self.event_switch(solver, state_event_info) 204 | # Collect event values before changing states 205 | before_mode = self.state_events(solver.t, solver.y, solver.sw) 206 | # Can now change the states 207 | solver.y = numpy.array(self.change_states(solver.t, solver.y, solver.sw)) 208 | # Collect event values after changing states 209 | after_mode = self.state_events(solver.t, solver.y, solver.sw) 210 | event_iter = self.check_event_iter(before_mode, after_mode) 211 | # Check if event values have been changes because the states were changed by the user 212 | if not True in event_iter: # Breaks the iteration loop 213 | break 214 | 215 | 216 | def event_switch(self, solver:CVode, state_event_info:List[int]): 217 | """ 218 | Turns the switches if a correponding event was hit. 219 | Helper method for method `handle_event`. 220 | 221 | Arguments 222 | --------- 223 | solver : CVode 224 | The solver instance. 225 | state_event_info : List[int] 226 | Indicates for which state an event was hit (0: no event, -1 and 1 indicate a zero crossing) 227 | """ 228 | 229 | for i in range(len(state_event_info)): #Loop across all event functions 230 | if state_event_info[i] != 0: 231 | solver.sw[i] = not solver.sw[i] #Turn the switch 232 | 233 | 234 | def check_event_iter(self, before:List[float], after:List[float]) -> List[bool]: 235 | """ 236 | Helper method for method `handle_event` to change the states at an timpoint of event. 237 | 238 | Arguments 239 | --------- 240 | before : List[float] 241 | The list of event monitoring values BEFORE the solver states (may) have been changed. 242 | after : List[float] 243 | The list of event monitoring values AFTER the solver states (may) have been changed. 
244 | 245 | Returns 246 | ------- 247 | event_iter : List[bool] 248 | Indicates changes in state values at corresponding positions. 249 | """ 250 | 251 | event_iter = [False]*len(before) 252 | 253 | for i in range(len(before)): 254 | if (before[i] <= 0.0 and after[i] > 0.0) or \ 255 | (before[i] >= 0.0 and after[i] < 0.0) or \ 256 | (before[i] < 0.0 and after[i] >= 0.0) or \ 257 | (before[i] > 0.0 and after[i] <= 0.0): 258 | event_iter[i] = True 259 | 260 | return event_iter 261 | 262 | 263 | #%% Other public methods 264 | 265 | def set_parameters(self, values:dict): 266 | """ 267 | Assigns specfic values to the models initial values and / or model parameters. 268 | 269 | Arguments 270 | --------- 271 | values : dict 272 | Key-value pairs for parameters that are to be set. 273 | Keys must match the names of initial values or model parameters. 274 | 275 | Raises 276 | ------ 277 | KeyError 278 | The parameters values to be set contain a key 279 | that is neither an initial value, nor a model parameter. 280 | """ 281 | 282 | if not Helpers.has_unique_ids(values): 283 | raise KeyError(Messages.non_unique_ids) 284 | 285 | existing_keys = [] 286 | existing_keys.extend(self.initial_values.keys()) 287 | existing_keys.extend(self.model_parameters.keys()) 288 | 289 | _initial_values = copy.deepcopy(self.initial_values) 290 | _model_parameters = copy.deepcopy(self.model_parameters) 291 | 292 | for key in values.keys(): 293 | if key in _initial_values.keys(): 294 | _initial_values[key] = values[key] 295 | if key in _model_parameters.keys(): 296 | _model_parameters[key] = values[key] 297 | 298 | self.initial_values = _initial_values 299 | self.model_parameters = _model_parameters 300 | 301 | 302 | #%% Private methods 303 | 304 | def __str__(self): 305 | return self.__class__.__name__ 306 | 307 | 308 | def _auto_detect_no_of_events(self) -> int: 309 | """ 310 | Convenient auto-detection of event to define initial switches. 
311 | 312 | Returns 313 | ------- 314 | no_of_events : int 315 | The automatically detected number of events. 316 | Works only for explicitly states events in the return of methods `state_events`. 317 | 318 | NOTE: Does not work with joblib parallel loky backend and IPython. 319 | """ 320 | 321 | _lines = inspect.getsourcelines(self.state_events) 322 | all_in_one = ''.join(_lines[0]).replace('\n', '').replace(' ', '').replace('\\', '') 323 | after_return = all_in_one.split('return')[-1] 324 | 325 | if '[]' in after_return: # there are no detectable events 326 | no_of_events = 0 327 | else: 328 | # detect automatically the number of returned events by the number of commas 329 | final_comma = after_return.count(',]') 330 | no_of_events = after_return.count(',') - final_comma + 1 331 | return no_of_events 332 | 333 | 334 | #%% Properties 335 | 336 | @property 337 | def states(self): 338 | return self._states 339 | 340 | 341 | @states.setter 342 | def states(self, value:list): 343 | if self._is_init: 344 | if not Helpers.has_unique_ids(value): 345 | raise KeyError(Messages.non_unique_ids) 346 | if not isinstance(value, list): 347 | raise TypeError('Model states must be a list.') 348 | self._states = sorted(value, key=str.lower) 349 | else: 350 | raise AttributeError(f'Cannot set states after instantiation of {self.__class__.__name__}') 351 | 352 | 353 | @property 354 | def initial_values(self): 355 | return self._initial_values 356 | 357 | 358 | @initial_values.setter 359 | def initial_values(self, value): 360 | if not isinstance(value, dict): 361 | raise TypeError(Messages.invalid_initial_values_type) 362 | 363 | if not Helpers.has_unique_ids(value): 364 | raise KeyError(Messages.non_unique_ids) 365 | 366 | _dict = OwnDict() 367 | for key, state in zip(sorted(value, key=str.lower), [f'{_state}0' for _state in self._states]): 368 | if key != state: 369 | raise KeyError(f'Initial value keys must match the state names {self.states}, extended by "0"') 370 | _dict[key] = 
value[key] 371 | self._initial_values = _dict 372 | 373 | # updates y0 in case the property is set after model initialization 374 | if not self._is_init: 375 | self.y0 = self.initial_values.to_numpy() 376 | 377 | 378 | @property 379 | def model_parameters(self): 380 | return self._model_parameters 381 | 382 | 383 | @model_parameters.setter 384 | def model_parameters(self, value): 385 | if not isinstance(value, dict): 386 | raise TypeError('Model parameters must be provided as dictionary') 387 | 388 | if not Helpers.has_unique_ids(value): 389 | raise KeyError(Messages.non_unique_ids) 390 | 391 | if not self._is_init: 392 | old_keys = sorted(self.model_parameters, key=str.lower) 393 | new_keys = sorted(value, key=str.lower) 394 | if old_keys != new_keys: 395 | raise KeyError(f'Cannot set values for unknown parameters: {new_keys} vs. {old_keys}') 396 | 397 | _dict = OwnDict() 398 | for key in sorted(value, key=str.lower): 399 | _dict[key] = value[key] 400 | self._model_parameters = _dict 401 | 402 | 403 | @property 404 | def initial_switches(self): 405 | return self._initial_switches 406 | 407 | 408 | @initial_switches.setter 409 | def initial_switches(self, value): 410 | 411 | if value == [] or value is None: 412 | self._initial_switches = None 413 | else: 414 | if not self._is_init and len(value) != len(self.initial_switches): 415 | raise ValueError(f'Invalid number of initial switches provided') 416 | for _value in value: 417 | if type(_value) != bool: 418 | raise ValueError('Initial switch states must be of type boolean') 419 | self._initial_switches = value 420 | 421 | 422 | class ObservationFunction(abc.ABC): 423 | """ 424 | Base class for observation functions that observe model states. Each model state can be observed, 425 | while the mapping is described by the specific observation function with its own parameters. 426 | A model state can be observed by multiple observation functions. 
427 | """ 428 | 429 | def __init__(self, observed_state:str, observation_parameters:list, replicate_id:str=None): 430 | """ 431 | Arguments 432 | --------- 433 | observed_state : str 434 | The name of the model state that is observed by this object. 435 | observation_parameters : list 436 | The names of observation parameters for this ObservationFunction. 437 | 438 | Keyword arguments 439 | ----------------- 440 | replicate_id : str 441 | Makes this `ObservationFunction` instance know about the `replicate_id` is is assigned to. 442 | Default is None, which implies a single replicate model. 443 | """ 444 | 445 | self._is_init = True 446 | self.replicate_id = replicate_id 447 | self.observed_state = observed_state 448 | self.observation_parameters = {p : None for p in observation_parameters if p != OBSERVED_STATE_KEY} 449 | self._is_init = False 450 | 451 | 452 | #%% Public methods 453 | 454 | @abc.abstractmethod 455 | def observe(self, state_values:numpy.ndarray) -> numpy.ndarray: 456 | """ 457 | Describes the mapping of model state into observation. 458 | 459 | Arguments 460 | --------- 461 | state_values : numpy.ndarray 462 | 463 | Returns 464 | ------- 465 | numpy.ndarray 466 | """ 467 | 468 | raise NotImplementedError('Method must be implemented by the inheriting class.') 469 | 470 | 471 | def get_observation(self, model_state:ModelState, replicate_id:str=None): 472 | """ 473 | Applies the observation function on a ModelState object. 474 | 475 | Arguments 476 | --------- 477 | model_state : ModelState 478 | An instance of ModelState, as returned by the simulate method of a Simulator object, 479 | together with several other ModelState objects. 480 | 481 | Returns 482 | ------- 483 | observation : Observation 484 | An instance of the Observation class. 485 | 486 | Raises 487 | ------ 488 | KeyError 489 | The name of the model state to be observed and `observed_state` property of this observation function do not match. 
490 | ValueError 491 | The replicate ids of this ObservationFunction object and the ModelState object to be observed do not match. 492 | """ 493 | 494 | if model_state.name != self.observed_state: 495 | raise KeyError(f'Model state and observed state do not match. {model_state.name} vs. {self.observed_state}') 496 | if model_state.replicate_id != self.replicate_id: 497 | raise ValueError(f'Replicate ids of model state and observation functions do not match: {model_state.replicate_id} vs. {self.replicate_id}') 498 | observation = Observation( 499 | name=self.name, 500 | observed_state=self.observed_state, 501 | timepoints=model_state.timepoints, 502 | values=self.observe(model_state.values), 503 | replicate_id=replicate_id, 504 | ) 505 | return observation 506 | 507 | 508 | def set_parameters(self, parameters:dict): 509 | """ 510 | Assigns specfic values to observation parameters. 511 | 512 | Arguments 513 | --------- 514 | values : dict 515 | Key-value pairs for parameters that are to be set. 516 | Keys must match the names of observation parameters. 517 | """ 518 | 519 | _observation_parameters = copy.deepcopy(self.observation_parameters) 520 | for key in parameters.keys(): 521 | if key in self.observation_parameters.keys(): 522 | _observation_parameters[key] = parameters[key] 523 | self.observation_parameters = _observation_parameters 524 | 525 | 526 | #%% Private methods 527 | 528 | def __str__(self): 529 | return self.__class__.__name__ 530 | 531 | 532 | #%% Properties 533 | 534 | @property 535 | def name(self): 536 | return self.__class__.__name__ 537 | 538 | 539 | @property 540 | def observed_state(self): 541 | return self._observed_state 542 | 543 | 544 | @observed_state.setter 545 | def observed_state(self, value): 546 | if self._is_init: 547 | if not isinstance(value, str): 548 | raise ValueError(f'Bad value: {value}. 
Must provide a str.') 549 | self._observed_state = value 550 | else: 551 | raise AttributeError(f'Cannot set observed_state after instantiation of {self.__class__.__name__}') 552 | 553 | 554 | @property 555 | def observation_parameters(self): 556 | return self._observation_parameters 557 | 558 | 559 | @observation_parameters.setter 560 | def observation_parameters(self, value): 561 | if not isinstance(value, dict): 562 | raise ValueError('Observation parameters must be provided as dictionary.') 563 | 564 | if not Helpers.has_unique_ids(value): 565 | raise KeyError(Messages.non_unique_ids) 566 | 567 | if not self._is_init: 568 | old_keys = sorted(self.observation_parameters, key=str.lower) 569 | new_keys = sorted(value, key=str.lower) 570 | if OBSERVED_STATE_KEY in new_keys: 571 | new_keys.remove(OBSERVED_STATE_KEY) 572 | 573 | if old_keys != new_keys: 574 | raise KeyError(f'Cannot set values for unknown parameters: {new_keys} vs. {old_keys}') 575 | 576 | _dict = OwnDict() 577 | for key in sorted(value, key=str.lower): 578 | if key == OBSERVED_STATE_KEY: 579 | if value[key] != self.observed_state: 580 | raise ValueError(f'Parameter observed_state {value[key]} does not match {self.observed_state}') 581 | else: 582 | _dict[key] = value[key] 583 | self._observation_parameters = _dict -------------------------------------------------------------------------------- /pyfoomb/generalized_islands.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | import joblib 4 | import matplotlib 5 | from matplotlib import pyplot 6 | import numpy 7 | import psutil 8 | from typing import Callable, Dict, List, Tuple 9 | 10 | from assimulo.solvers.sundials import CVodeError 11 | import pygmo 12 | 13 | from .datatypes import Measurement 14 | 15 | pyplot.style.use('ggplot') 16 | 17 | 18 | class LossCalculator(): 19 | """ 20 | Defines the objective that is used to create an pygmo problem instance. 
21 | See pygmo docu for further information (e.g., https://esa.github.io/pagmo2/docs/python/tutorials/coding_udp_simple.html). 22 | """ 23 | 24 | def __init__(self, 25 | unknowns:list, bounds:list, metric:str, measurements:List[Measurement], caretaker_loss_fun:Callable, 26 | handle_CVodeError:bool=True, verbosity_CVodeError:bool=False, 27 | ): 28 | """ 29 | Arguments 30 | --------- 31 | unknowns : list 32 | The unknowns to be estimated. 33 | bounds : list 34 | Corresponding list of (upper, lower) bounds. 35 | metric : str 36 | The loss metric, which is minimized for model calibration. 37 | measurements : List[Measurement] 38 | The measurements for which the model will be calibrated. 39 | caretaker_loss_fun : Callable 40 | The Caretaker's loss function. 41 | 42 | Keyword arguments 43 | ----------------- 44 | handle_CVodeError : bool 45 | to handle arising CVodeErrors. 46 | Default is True, which returns an infinite loss. 47 | verbosity_CVodeError : bool 48 | To report about handled CVodeErrros. 49 | Default is False. 50 | """ 51 | 52 | self.unknowns = unknowns 53 | self.metric = metric 54 | self.lower_bounds = [_bounds[0] for _bounds in bounds] 55 | self.upper_bounds = [_bounds[1] for _bounds in bounds] 56 | self.measurements = measurements 57 | self.caretaker_loss_fun = caretaker_loss_fun 58 | self.handle_CVodeError = handle_CVodeError 59 | self.verbosity_CVodeError = verbosity_CVodeError 60 | 61 | 62 | @property 63 | def current_parameters(self) -> dict: 64 | return self._current_parameters 65 | 66 | 67 | @current_parameters.setter 68 | def current_parameters(self, value): 69 | self._current_parameters = {unknown : _x for unknown, _x in zip(self.unknowns, value)} 70 | 71 | 72 | def check_constraints(self) -> List[bool]: 73 | return [True] 74 | 75 | 76 | def get_model_loss(self) -> float: 77 | """ 78 | Calculates the loss for the current parameter values. 
79 | 80 | Returns 81 | ------- 82 | loss : float 83 | """ 84 | try: 85 | loss = self.caretaker_loss_fun( 86 | self.current_parameters, 87 | self.metric, 88 | self.measurements, 89 | self.handle_CVodeError, 90 | self.verbosity_CVodeError, 91 | ) 92 | if numpy.isnan(loss): 93 | loss = numpy.inf 94 | except CVodeError: 95 | loss = numpy.inf 96 | return loss 97 | 98 | 99 | def fitness(self, x) -> List[float]: 100 | """ 101 | Method for fitness calculation, as demanded by the pygmo package. 102 | """ 103 | 104 | # (1): Create the current parameter dictionary from the current guess vector 105 | self.current_parameters = x 106 | 107 | # (2) Check if any constraint is violated 108 | constraints_ok = self.check_constraints() 109 | if not all(constraints_ok): 110 | loss = numpy.inf 111 | 112 | # (3) Evaluate the Caretakers objective function only is no constraints have been violated 113 | else: 114 | loss = self.get_model_loss() 115 | 116 | # TODO: Regularization can be added to the loss here 117 | 118 | return [loss] 119 | 120 | 121 | def get_bounds(self) -> tuple: 122 | """ 123 | Method for checking the parameter bounds, as demanded by the pygmo package. 124 | """ 125 | 126 | return (self.lower_bounds, self.upper_bounds) 127 | 128 | 129 | def gradient(self, x): 130 | """ 131 | Method for gradient calculation, as demanded by the pygmo package for some pygmo optimizers. 132 | """ 133 | 134 | return pygmo.estimate_gradient_h(lambda x: self.fitness(x), x) 135 | 136 | 137 | class PyfoombArchipelago(pygmo.archipelago): 138 | """ 139 | An archipelago subclass, extended with specific properties needed for the pyFOOMB package. 
140 | """ 141 | 142 | def __init__(self, *args, **kwargs): 143 | super().__init__(*args, **kwargs) 144 | self.mc_info = None 145 | self.finished = None 146 | self.problem = LossCalculator 147 | 148 | @property 149 | def mc_info(self): 150 | return self._mc_info 151 | 152 | @mc_info.setter 153 | def mc_info(self, value): 154 | self._mc_info = value 155 | 156 | @property 157 | def finished(self): 158 | return self._finished 159 | 160 | @finished.setter 161 | def finished(self, value): 162 | self._finished = value 163 | 164 | @property 165 | def problem(self): 166 | return self._problem 167 | 168 | @problem.setter 169 | def problem(self, value) -> LossCalculator: 170 | self._problem = value 171 | 172 | 173 | class PygmoOptimizers(): 174 | """ 175 | Static class that conveniently handles pygmo optimizers, as well as default kwargs that are considered useful. 176 | """ 177 | 178 | optimizers = { 179 | 'bee_colony' : pygmo.bee_colony, 180 | #'cmaes' : pygmo.cmaes, 181 | 'compass_search' : pygmo.compass_search, 182 | 'de' : pygmo.de, 183 | 'de1220' : pygmo.de1220, 184 | 'gaco' : pygmo.gaco, 185 | 'ihs' : pygmo.ihs, 186 | 'maco' : pygmo.maco, 187 | 'mbh' : pygmo.mbh, 188 | 'moead' : pygmo.moead, 189 | 'nlopt' : pygmo.nlopt, 190 | 'nsga2' : pygmo.nsga2, 191 | 'nspso' : pygmo.nspso, 192 | 'pso' : pygmo.pso, 193 | 'pso_gen' : pygmo.pso_gen, 194 | 'sade' : pygmo.sade, 195 | 'sea' : pygmo.sea, 196 | 'sga' : pygmo.sga, 197 | 'simulated_annealing' : pygmo.simulated_annealing, 198 | #'xnes' : pygmo.xnes, 199 | } 200 | 201 | 202 | default_kwargs = { 203 | 'bee_colony' : {'limit' : 2, 'gen' : 10}, 204 | #'cmaes' : {'gen' : 10, 'force_bounds' : False, 'ftol' : 1e-8, 'xtol' : 1e-8}, 205 | 'compass_search' : {'max_fevals' : 100, 'start_range' : 1, 'stop_range' : 1e-6}, 206 | 'de' : {'gen' : 10, 'ftol' : 1e-8, 'xtol' : 1e-8}, 207 | 'de1220' : {'gen' : 10, 'variant_adptv' : 2, 'ftol' : 1e-8, 'xtol' : 1e-8}, 208 | 'gaco' : {'gen' : 10}, 209 | 'ihs' : {'gen' : 10*4}, 210 | 'maco' : {'gen' 
: 10}, 211 | 'mbh' : {'algo' : 'compass_search', 'perturb' : 0.1, 'stop' : 2}, 212 | 'moead' : {'gen' : 10}, 213 | 'nlopt' : {'solver' : 'lbfgs'}, 214 | 'nsga2' : {'gen' : 10}, 215 | 'nspso' : {'gen' : 10}, 216 | 'pso' : {'gen' : 10}, 217 | 'pso_gen' : {'gen' : 10}, 218 | 'sade' : { 'gen' : 10, 'variant_adptv' : 2, 'ftol' : 1e-8, 'xtol' : 1e-8}, 219 | 'sea' : {'gen' : 10*4}, 220 | 'sga' : {'gen' : 10}, 221 | 'simulated_annealing' : {}, 222 | #'xnes' : {'gen' : 10, 'ftol' : 1e-8, 'xtol' : 1e-8, 'eta_mu' : 0.05}, 223 | } 224 | 225 | 226 | @staticmethod 227 | def get_optimizer_algo_instance(name:str, kwargs:dict=None) -> pygmo.algorithm: 228 | """ 229 | Get an pygmo optimizer instance. 230 | In case 'mbh' is chosen, key for the inner algorithm corresponds the respective names prefixed with `inner_`. 231 | 232 | Arguments 233 | --------- 234 | name : str 235 | The name of the desired optimizer. 236 | 237 | Keyword arguments 238 | ----------------- 239 | kwargs : dict 240 | Additional kwargs for creation of the optimizer instance, 241 | beside the default kwargs (see corresponding class attribute). 242 | 243 | Returns 244 | ------- 245 | pygmo.algorithm 246 | 247 | Raises 248 | ------ 249 | ValueError 250 | Argument `name` specifies an unsupported optimizer. 251 | """ 252 | 253 | if name not in PygmoOptimizers.optimizers.keys(): 254 | raise ValueError(f'Unsupported optimizer {name} chosen. 
Valid choices are {PygmoOptimizers.optimizers.keys()}') 255 | 256 | _kwargs = {} 257 | if name in PygmoOptimizers.default_kwargs.keys(): 258 | _kwargs.update(PygmoOptimizers.default_kwargs[name]) 259 | if kwargs is not None: 260 | for _kwarg in kwargs: 261 | _kwargs[_kwarg] = kwargs[_kwarg] 262 | 263 | if name == 'mbh': 264 | # (1) Get inner algorithm 265 | _inner_algo = _kwargs['algo'] 266 | _inner_kwargs = {} 267 | _outer_kwargs = {} 268 | for key in _kwargs: 269 | if key.startswith('inner_'): 270 | _inner_kwargs[key[6:]] = _kwargs[key] 271 | else: 272 | _outer_kwargs[key] = _kwargs[key] 273 | 274 | _algo_instance = PygmoOptimizers.get_optimizer_algo_instance(_inner_algo, _inner_kwargs) 275 | # (2) get new kwargs, cleaned from inner kwargs 276 | _kwargs = _outer_kwargs 277 | _kwargs['algo'] = _algo_instance 278 | 279 | return pygmo.algorithm(PygmoOptimizers.optimizers[name](**_kwargs)) 280 | 281 | 282 | class ParallelEstimationInfo(): 283 | 284 | def __init__(self, archipelago:PyfoombArchipelago, evolutions_trail:dict=None): 285 | """ 286 | Arguments 287 | --------- 288 | archipelago : PyfoombArchipelago 289 | The archipelago for which the evolutions have been run. 290 | 291 | Keyword arguments 292 | ----------------- 293 | evolutions_trail : dict 294 | Information about previous evolutions run with the archipelago. 295 | Default is None, which causes creation of a new dictionary. 
296 | """ 297 | self.archipelago = archipelago 298 | if evolutions_trail is None: 299 | evolutions_trail = {} 300 | evolutions_trail['cum_runtime_min'] = [] 301 | evolutions_trail['evo_time_min'] = [] 302 | evolutions_trail['best_losses'] = [] 303 | evolutions_trail['best_estimates'] = [] 304 | evolutions_trail['estimates_info'] = [] 305 | self.evolutions_trail = evolutions_trail 306 | 307 | @property 308 | def runtime_trail(self): 309 | return numpy.array(self.evolutions_trail['cum_runtime_min']) 310 | 311 | 312 | @property 313 | def evotime_trail(self): 314 | return numpy.cumsum(self.evolutions_trail['evo_time_min']) 315 | 316 | 317 | @property 318 | def losses_trail(self): 319 | return numpy.array([_info['losses'].flatten() for _info in self.evolutions_trail['estimates_info']]) 320 | 321 | 322 | @property 323 | def best_loss_trail(self): 324 | return numpy.min(self.losses_trail, axis=1) 325 | 326 | 327 | @property 328 | def average_loss_trail(self) -> numpy.ndarray: 329 | return numpy.mean(self.losses_trail, axis=1) 330 | 331 | 332 | @property 333 | def std_loss_trail(self) -> numpy.ndarray: 334 | return numpy.std(self.losses_trail, axis=1, ddof=1) 335 | 336 | 337 | @property 338 | def estimates(self) -> dict: 339 | return ArchipelagoHelpers.estimates_from_archipelago(self.archipelago) 340 | 341 | 342 | def plot_loss_trail(self, x_log:bool=True): 343 | """ 344 | Shows the progression of the loss during the estimation process, more specifically the development 345 | of the best loss, the average loss and the correopnding CV among the parallel optimizations. 346 | 347 | Keyword arguments 348 | ----------------- 349 | x_log : bool 350 | To show the x-axis (the runtime) in log scale or not. 351 | Default is True. 
352 | 353 | Returns 354 | ------- 355 | fig : The figure object 356 | ax : The axis object 357 | """ 358 | 359 | fig, ax = pyplot.subplots(nrows=2, ncols=1, dpi=100, figsize=(10, 5), sharex=True) 360 | ax[0].plot(self.evotime_trail, self.best_loss_trail, marker='.', linestyle='--', label='Best', zorder=2) 361 | ax[0].plot(self.evotime_trail, self.average_loss_trail, marker='.', linestyle='--', label='Average', zorder=1) 362 | ax[0].set_ylabel('Loss', size=14) 363 | ax[1].plot( 364 | self.evotime_trail, numpy.abs(self.std_loss_trail/self.average_loss_trail*100), 365 | marker='.', linestyle='--', label='CV of losses', 366 | ) 367 | ax[1].set_ylabel('CV in %', size=14) 368 | ax[1].set_xlabel('Cumulated evolution time in min', size=14) 369 | for _ax in ax.flat: 370 | _ax.legend(frameon=False) 371 | _ax.xaxis.set_tick_params(labelsize=12) 372 | _ax.yaxis.set_tick_params(labelsize=12) 373 | if x_log: 374 | _ax.set_xscale('log') 375 | fig.tight_layout() 376 | return fig, ax 377 | 378 | 379 | class ArchipelagoHelpers(): 380 | 381 | @staticmethod 382 | def estimates_from_archipelago(archipelago:PyfoombArchipelago) -> dict: 383 | """ 384 | Extracts the current estimated values for the optimization probelm of an archipelago. 385 | 386 | Arguments 387 | --------- 388 | archipelago : PyfoombArchipelago 389 | The evolved archipelago. 390 | 391 | Returns 392 | ------- 393 | dict : The current estimates. 394 | 395 | """ 396 | 397 | unknowns = ArchipelagoHelpers.problem_from_archipelago(archipelago).unknowns 398 | best_idx = numpy.argmin(numpy.array(archipelago.get_champions_f()).flatten()) 399 | estimates = { 400 | _unknown : _x 401 | for _unknown, _x in zip(unknowns, archipelago[int(best_idx)].get_population().champion_x) 402 | } 403 | return estimates.copy() # maybe a deep copy needed? 
404 | 405 | 406 | @staticmethod 407 | def problem_from_archipelago(archipelago:PyfoombArchipelago) -> LossCalculator: 408 | """ 409 | Extracts the optimization problem from an archipelago, implemented as (subclass of) LossCalculator. 410 | 411 | Arguments 412 | --------- 413 | archipelago : PyfoombArchipelago 414 | The evolved archipelago. 415 | 416 | Returns 417 | ------- 418 | LossCalculator 419 | """ 420 | 421 | return archipelago[0].get_population().problem.extract(archipelago.problem) 422 | 423 | 424 | @staticmethod 425 | def create_population(pg_problem, pop_size, seed): 426 | return pygmo.population(pg_problem, pop_size, seed=seed) 427 | 428 | @staticmethod 429 | def parallel_create_population(arg): 430 | pg_problem, pop_size, seed = arg 431 | return ArchipelagoHelpers.create_population(pg_problem, pop_size, seed) 432 | 433 | @staticmethod 434 | def create_archipelago(unknowns:list, 435 | optimizers:list, 436 | optimizers_kwargs:list, 437 | pg_problem:pygmo.problem, 438 | rel_pop_size:float, 439 | archipelago_kwargs:dict, 440 | log_each_nth_gen:int, 441 | report_level:int, 442 | ) -> PyfoombArchipelago: 443 | """ 444 | Helper method for parallelized estimation using the generalized island model. 445 | Creates the archipelago object for running several rounds of evolutions. 446 | 447 | Arguments 448 | --------- 449 | unknowns : list 450 | The unknowns, sorted alphabetically and case-insensitive. 451 | optimizers : list 452 | A list of optimizers to be used on individual islands. 453 | optimizers_kwargs : list 454 | A list of corresponding kwargs. 455 | pg_problem : pygmo.problem 456 | An pygmo problem instance. 457 | archipelago_kwargs : dict 458 | Additional kwargs for archipelago creation. 459 | log_each_nth_gen : int 460 | Specifies at which each n-th generation the algorithm stores logs. 461 | report_level : int 462 | Prints information on the archipelago creation for values >= 1. 
463 | 464 | Returns 465 | ------- 466 | archipelago : PyfoombArchipelago 467 | """ 468 | 469 | _cpus = joblib.cpu_count() 470 | 471 | # There is one optimizer with a set of kwargs 472 | if len(optimizers) == 1 and len(optimizers_kwargs) == 1: 473 | optimizers = optimizers * _cpus 474 | optimizers_kwargs = optimizers_kwargs * _cpus 475 | # Several optimizers with the same kwargs 476 | elif len(optimizers) > 1 and len(optimizers_kwargs) == 1: 477 | optimizers_kwargs = optimizers_kwargs * len(optimizers) 478 | # Several kwargs for the same optimizer 479 | elif len(optimizers) == 1 and len(optimizers_kwargs) > 1: 480 | optimizers = optimizers * len(optimizers_kwargs) 481 | elif len(optimizers) != len(optimizers_kwargs): 482 | raise ValueError('Number of optimizers does not match number of corresponding kwarg dicts') 483 | 484 | # Get the optimizer intances 485 | algos = [ 486 | PygmoOptimizers.get_optimizer_algo_instance( 487 | name=_optimizers, kwargs=_optimizers_kwargs 488 | ) 489 | for _optimizers, _optimizers_kwargs in zip(optimizers, optimizers_kwargs) 490 | ] 491 | 492 | # Update number of islands 493 | n_islands = len(algos) 494 | 495 | if report_level >= 1: 496 | print(f'Creating archipelago with {n_islands} islands. 
May take some time...') 497 | 498 | pop_size = int(numpy.ceil(rel_pop_size*len(unknowns))) 499 | prop_create_args = ( 500 | (pg_problem, pop_size, seed*numpy.random.randint(0, 1e4)) 501 | for seed, pop_size in enumerate([pop_size] * n_islands) 502 | ) 503 | try: 504 | parallel_verbose = 0 if report_level == 0 else 1 505 | with joblib.parallel_backend('loky', n_jobs=n_islands): 506 | pops = joblib.Parallel(verbose=parallel_verbose)(map(joblib.delayed(ArchipelagoHelpers.parallel_create_population), prop_create_args)) 507 | except Exception as ex: 508 | print(f'Parallelized archipelago creation failed, falling back to sequential\n{ex}') 509 | pops = (ArchipelagoHelpers.parallel_create_population(prop_create_arg) for prop_create_arg in prop_create_args) 510 | 511 | # Now create the empyty archipelago 512 | if not 't' in archipelago_kwargs.keys(): 513 | archipelago_kwargs['t'] = pygmo.fully_connected() 514 | archi = PyfoombArchipelago(**archipelago_kwargs) 515 | archi.set_migrant_handling(pygmo.migrant_handling.preserve) 516 | 517 | # Add the populations to the archipelago and wait for its construction 518 | with contextlib.redirect_stdout(io.StringIO()): 519 | for _pop, _algo in zip(pops, algos): 520 | if log_each_nth_gen is not None: 521 | _algo.set_verbosity(int(log_each_nth_gen)) 522 | _island = pygmo.island(algo=_algo, pop=_pop, udi=pygmo.mp_island()) 523 | archi.push_back(_island) 524 | archi.wait_check() 525 | 526 | return archi 527 | 528 | 529 | @staticmethod 530 | def extract_archipelago_results(archipelago:PyfoombArchipelago) -> Tuple[dict, float, dict]: 531 | """ 532 | Get the essential and further informative results from an archipelago object. 533 | 534 | Arguments 535 | --------- 536 | archipelago : PyfoombArchipelago 537 | The archipelago object after finished evolution. 538 | 539 | Returns 540 | ------- 541 | Tuple[dict, float, dict] 542 | The best estimates as dict, according to the best (smallest) loss. 543 | The best loss. 
544 | A dictionary with several informative results. 545 | """ 546 | 547 | estimates_info = {} 548 | best_estimates = ArchipelagoHelpers.estimates_from_archipelago(archipelago) 549 | unknowns = list(best_estimates.keys()) 550 | 551 | best_idx = numpy.argmin(numpy.array(archipelago.get_champions_f()).flatten()) 552 | best_loss = float(archipelago[int(best_idx)].get_population().champion_f) 553 | 554 | estimates = { 555 | _unknown : _x 556 | for _unknown, _x in zip(unknowns, numpy.array([island.get_population().champion_x for island in archipelago], dtype=float).T) 557 | } 558 | losses = numpy.array([float(island.get_population().champion_f) for island in archipelago], dtype=float) 559 | 560 | estimates_info['best_estimates'] = best_estimates 561 | estimates_info['best_loss'] = best_loss 562 | estimates_info['estimates'] = estimates 563 | estimates_info['losses'] = losses 564 | 565 | return best_estimates, best_loss, estimates_info 566 | 567 | 568 | @staticmethod 569 | def check_evolution_stop(current_losses:numpy.ndarray, 570 | atol_islands:float, rtol_islands:float, 571 | current_runtime_min:float, max_runtime_min:float, 572 | current_evotime_min:float, max_evotime_min:float, 573 | max_memory_share:float, 574 | ) -> dict: 575 | """ 576 | Checks if losses between islands have been sufficiently converged. 577 | 578 | Arguments 579 | --------- 580 | current_losses : numpy.ndarray 581 | The best losses of all islands after an evolution. 582 | atol_islands : float 583 | stop_criterion = atol_islands + rtol_islands * numpy.abs(numpy.mean(current_losses)) 584 | rtol_islands : float 585 | stop_criterion = atol_islands + rtol_islands * numpy.abs(numpy.mean(current_losses)) 586 | current_runtime : float 587 | The current runtime in min of the estimation process after a completed evolution. 588 | max_runtime : float 589 | The maximal runtime in min the estimation process can take. 
590 | max_memory_share : float 591 | The maximum relative memory occupation for which evolutions are run 592 | 593 | Returns 594 | ------- 595 | stopping_criteria : dict 596 | """ 597 | 598 | stopping_criteria = { 599 | 'convergence' : False, 600 | 'max_runtime' : False, 601 | 'max_evotime' : False, 602 | 'max_memory_share' : False, 603 | } 604 | 605 | # Check convergence 606 | if atol_islands is None: 607 | atol_islands = 0.0 608 | if rtol_islands is None: 609 | rtol_islands = 0.0 610 | 611 | _stop_criterion = atol_islands + rtol_islands * numpy.abs(numpy.mean(current_losses)) 612 | _abs_std = numpy.std(current_losses, ddof=1) 613 | if _abs_std < _stop_criterion: 614 | stopping_criteria['convergence'] = True 615 | 616 | # Check runtime 617 | if (current_runtime_min is not None) and (max_runtime_min is not None) and (current_runtime_min > max_runtime_min): 618 | stopping_criteria['max_runtime'] = True 619 | 620 | # Check evolution time 621 | if (current_evotime_min is not None) and (max_evotime_min is not None) and (current_evotime_min > max_evotime_min): 622 | stopping_criteria['max_evotime'] = True 623 | 624 | # Check memory occupation 625 | curr_memory_share = psutil.virtual_memory().percent/100 626 | if curr_memory_share > max_memory_share: 627 | stopping_criteria['max_memory_share'] = True 628 | 629 | return stopping_criteria 630 | 631 | 632 | @staticmethod 633 | def report_evolution_result(evolutions_results:dict, report_level:int): 634 | """ 635 | Helper method for parallel estimation method to report progress. 636 | 637 | Arguments 638 | --------- 639 | evolutions_result : dict 640 | Contains information on the result of an evolution. 641 | report_level : int 642 | Controls the output that is printed. 643 | 2 = prints the best loss, as well as information about archipelago creation and evolution. 644 | 3 = prints additionally average loss among all islands, and the runtime of the evolution. 
645 | 4 = prints additionally the parameter values for the best loss, and the average parameter values 646 | among the champions of all islands in the `archipelago` after the evolutions. 647 | """ 648 | 649 | if report_level < 2: 650 | return 651 | 652 | if report_level >= 2: 653 | _evolution = len(evolutions_results['evo_time_min']) 654 | print(f'-------------Finished evolution {_evolution}-------------') 655 | _best_loss = evolutions_results['best_losses'][-1] 656 | print(f'Current best loss: {_best_loss}') 657 | 658 | if report_level >= 3: 659 | _estimates_info = evolutions_results['estimates_info'][-1] 660 | _mean = numpy.mean(_estimates_info['losses']) 661 | _std = numpy.std(_estimates_info['losses'], ddof=1) 662 | _cv = numpy.abs(_std/_mean*100) 663 | print(f'Average loss among the islands: {_mean:.6f} +/- {_std:.6f} ({_cv:.6f} %)') 664 | 665 | if report_level >= 4: 666 | _evo_time_min = evolutions_results['evo_time_min'][-1] 667 | print(f'Run time for this evolution was {_evo_time_min:.2f} min') 668 | --------------------------------------------------------------------------------