├── test ├── __init__.py ├── test_aggregateevents.py ├── test_errorcheck.py ├── test_datacorrection.py └── test_datacheck.py ├── MANIFEST.in ├── data └── example.pkl ├── requirements.txt ├── .travis.yml ├── ctmc ├── __init__.py ├── generatormatrix.py ├── aggregateevents.py ├── errorcheck.py ├── datacorrection.py ├── simulate.py ├── datacheck.py ├── ctmc_class.py └── ctmc_func.py ├── CHANGES.md ├── setup.py ├── LICENSE ├── README.md ├── profile ├── speed (timeit).ipynb ├── funcalls (prun).ipynb ├── memory (mprun).ipynb └── linebyline (lprun).ipynb ├── .gitignore └── examples ├── demo datacorrection.ipynb ├── demo datacheck.ipynb ├── automatic data correction.ipynb ├── demo ctmc.ipynb ├── internal data format of ctmc_fit.ipynb └── demo Ctmc sklearn API.ipynb /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | recursive-include test *.py 3 | -------------------------------------------------------------------------------- /data/example.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kmedian/ctmc/HEAD/data/example.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | line_profiler>=2.1.2 2 | memory_profiler>=0.54.0 3 | matplotlib>=2.2.2 4 | nose>=1.3.7 5 | numpy>=1.14.5 6 | scipy>=1.1.0 7 | scikit-learn>=0.19.2 8 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: python 3 | python: 4 | - "3.6" 5 | install: 6 | - pip install flake8>=3.5.0 7 | - pip install -r 
def generatormatrix(transcount: np.ndarray,
                    statetime: np.ndarray) -> np.ndarray:
    """Build the generator (rate) matrix from aggregated event data.

    Parameters
    ----------
    transcount : ndarray
        Square matrix; transcount[i, j] is the number of observed
        transitions from state i to state j.

    statetime : ndarray
        Vector; statetime[i] is the accumulated time spent in state i
        and is used as the divisor of row i.

    Returns
    -------
    genmat : ndarray
        Float matrix with the same shape as `transcount`. Off-diagonal
        entries are transcount[i, j] / statetime[i]; each diagonal entry
        is the negative sum of the off-diagonal entries of its row, so
        every row sums to zero.

    Notes
    -----
    The input arrays are not modified. A zero in `statetime` is not
    guarded here; NumPy's usual division-by-zero behaviour applies.
    """
    # Work on a float copy so the caller's count matrix stays untouched.
    rates = np.array(transcount, dtype=float)

    # Self-transitions are not events of the chain; ignore any stray
    # diagonal count so that the rows are guaranteed to sum to zero.
    np.fill_diagonal(rates, 0.0)
    np.fill_diagonal(rates, -rates.sum(axis=1))

    # Divide row i by statetime[i] in a single broadcast operation.
    return rates / np.asarray(statetime, dtype=float)[:, np.newaxis]
def aggregateevents(data: list, numstates: int) -> (np.ndarray, np.ndarray):
    """Aggregate event histories into transition counts and state times.

    Parameters
    ----------
    data : list
        List of examples. For each example, example[0] is the sequence
        of visited state indices and example[1] the sequence of
        durations, matched position by position.

    numstates : int
        Number of states; determines the size of both results.

    Returns
    -------
    transcount : ndarray
        Integer matrix of shape (numstates, numstates);
        transcount[a, b] counts how often state b directly follows
        state a within an example.

    statetime : ndarray
        Float vector of length numstates with the summed durations
        per state.
    """
    # The result is returned dense, so count directly in a dense array
    # instead of going through a sparse intermediate.
    transcount = np.zeros((numstates, numstates), dtype=int)
    statetime = np.zeros(numstates, dtype=float)

    for example in data:
        states = example[0]
        times = example[1]

        for pos, state in enumerate(states):
            statetime[state] += times[pos]
            # The first entry of an example has no predecessor.
            if pos:
                transcount[states[pos - 1], state] += 1

    return transcount, statetime
class Test_Aggregateevents(unittest.TestCase):
    """Unit tests for ctmc.aggregateevents."""

    def test1(self):
        """A single example that visits the states 0 -> 1 -> 0."""
        history = [([0, 1, 0], [0.5, 0.5, 0.5])]
        expected_counts = [[0, 1], [1, 0]]
        expected_times = [1, .5]

        counts, times = ctmc.aggregateevents(history, 2)

        npt.assert_allclose(counts, expected_counts)
        npt.assert_allclose(times, expected_times)
class Test_Errorcheck(unittest.TestCase):
    """Unit tests for ctmc.errorcheck."""

    def test1(self):
        """A non-zero diagonal in the transition counts is rejected."""
        counts = np.array([[0, 99], [99, 1]])  # error: diag <> 0
        durations = np.array([12, 34])

        with self.assertRaises(Exception) as context:
            ctmc.errorcheck(counts, durations, 1e-8)

        self.assertIn(
            'Transition Count Matrix have diagonal',
            str(context.exception))

    def test2(self):
        """A state whose accumulated time is below toltime is rejected."""
        counts = np.array([[0, 99], [99, 0]])
        durations = np.array([12, .0])  # error: statetime for id=1 too small

        with self.assertRaises(Exception) as context:
            ctmc.errorcheck(counts, durations, 1e-8)

        self.assertIn(
            'is smaller than toltime',
            str(context.exception))
def datacorrection(datalist: list, toltime: float = 1e-8) -> list:
    """Clean event histories so that only usable examples remain.

    For every example the function

    1. drops entries whose duration is smaller than `toltime`,
    2. merges runs of consecutive equal states into a single entry
       whose duration is the sum of the run,
    3. discards the example if fewer than two entries are left at any
       stage (no transition can be observed in it).

    Parameters
    ----------
    datalist : list
        List of examples. For each example, example[0] is the sequence
        of state labels and example[1] the sequence of durations.

    toltime : float
        (Default: 1e-8) Minimum duration an entry needs to be kept.

    Returns
    -------
    newlist : list
        One ``[states, times]`` pair of tuples per surviving example,
        in input order. The input list is not modified.
    """
    newlist = []

    for example in datalist:
        states = example[0]
        times = example[1]

        # a history with a single entry cannot contain a transition
        if len(states) < 2:
            continue

        # delete durations < toltime
        kept = [(s, t) for s, t in zip(states, times) if t >= toltime]
        if len(kept) < 2:
            continue

        # merge consecutive states that are the same
        merged = [list(kept[0])]
        for state, duration in kept[1:]:
            # Compare labels by value. An identity test (`is`) only works
            # by accident for small cached ints and fails for larger
            # ints, floats, strings or NumPy scalars.
            if state == merged[-1][0]:
                merged[-1][1] += duration
            else:
                merged.append([state, duration])

        # only one distinct state left after merging
        if len(merged) < 2:
            continue

        # transpose [[state, time], ...] into [states, times]
        newlist.append(list(zip(*merged)))

    return newlist
def datacheck(data: list, numstates: int, toltime: float) -> bool:
    """Validate event histories and raise on the first violation found.

    Each example is checked, in this order, for

    * state labels outside ``range(numstates)``,
    * fewer than two distinct states,
    * two consecutive entries with the same state,
    * a duration smaller than `toltime`.

    Parameters
    ----------
    data : list
        List of examples; example[0] holds the state labels and
        example[1] the durations.

    numstates : int
        Number of admissible states.

    toltime : float
        Minimum admissible duration of an entry.

    Returns
    -------
    bool
        Always False, i.e. no error was found.

    Raises
    ------
    Exception
        With a message naming the example id (and entry positions)
        of the first failed check.
    """
    valid_labels = range(numstates)

    for exid, example in enumerate(data):
        states, times = example[0], example[1]

        if not np.isin(states, valid_labels).all():
            message = ("The example id={:d} has faulty state "
                       "labels/encodings")
            raise Exception(message.format(exid))

        if np.unique(states).size < 2:
            message = ("The example id={:d} has only 1 distinct "
                       "state")
            raise Exception(message.format(exid))

        for pos in range(len(states) - 1):
            if states[pos] == states[pos + 1]:
                message = ("The example id={:d} has two consequtive entries "
                           "state[{:d}]==state[{:d}]")
                raise Exception(message.format(exid, pos, pos + 1))

        first_short = next(
            (pos for pos, t in enumerate(times) if t < toltime), None)
        if first_short is not None:
            message = ("The example id={:d} has a state[{:d}] that have not "
                       "been active for longer than toltime")
            raise Exception(message.format(exid, first_short))

    return False
[Commands](#commands) 11 | * [Support](#support) 12 | * [Contributing](#contributing) 13 | 14 | 15 | ## Installation 16 | The `ctmc` [git repo](http://github.com/kmedian/ctmc) is available as a [PyPI package](https://pypi.org/project/ctmc) 17 | 18 | ``` 19 | pip install ctmc 20 | ``` 21 | 22 | 23 | ## Usage 24 | Check the [examples](https://github.com/kmedian/ctmc/tree/master/examples) folder for notebooks. 25 | 26 | 27 | ## Commands 28 | * Check syntax: `flake8 --ignore=F401` 29 | * Run Unit Tests: `python -W ignore -m unittest discover` 30 | * Remove `.pyc` files: `find . -type f -name "*.pyc" | xargs rm` 31 | * Remove `__pycache__` folders: `find . -type d -name "__pycache__" | xargs rm -rf` 32 | * Upload to PyPI with twine: `python setup.py sdist && twine upload -r pypi dist/*` 33 | 34 | 35 | ## Debugging 36 | * Notebooks to profile Python code are in the [profile](https://github.com/kmedian/ctmc/tree/master/profile) folder 37 | 38 | 39 | ## Support 40 | Please [open an issue](https://github.com/kmedian/ctmc/issues/new) for support. 41 | 42 | 43 | ## Contributing 44 | Please contribute using [GitHub Flow](https://guides.github.com/introduction/flow/). Create a branch, add commits, and [open a pull request](https://github.com/kmedian/ctmc/compare/). 
class Ctmc(BaseEstimator):
    """Continuous Time Markov Chain, sklearn API class.

    Thin estimator wrapper: the constructor only stores its arguments,
    `fit` forwards them to the `ctmc` function and `predict` forwards
    to the `simulate` function.

    Parameters
    ----------
    numstates : int
        (Default: None) Number of unique states, passed on to `ctmc`.

    transintv : float
        (Default: 1.0) Time interval, passed on to `ctmc`.

    toltime : float
        (Default: 1e-8) Duration tolerance, passed on to `ctmc`.

    autocorrect : bool
        (Default: False) Passed on to `ctmc`.

    debug : bool
        (Default: False) Passed on to `ctmc`.
    """

    def __init__(self, numstates: int = None, transintv: float = 1.0,
                 toltime: float = 1e-8, autocorrect: bool = False,
                 debug: bool = False):
        # Store the arguments unchanged; no work is done before `fit`.
        self.numstates = numstates
        self.transintv = transintv
        self.toltime = toltime
        self.autocorrect = autocorrect
        self.debug = debug

    def fit(self, X: list, y=None):
        """Calls the ctmc.ctmc function

        Parameters
        ----------
        X : list of lists
            (see ctmc function 'data')

        y
            not used, present for API consistency purposes.

        Returns
        -------
        self
            The instance itself, with the four results of `ctmc` stored
            as the attributes `transmat`, `genmat`, `transcount` and
            `statetime`.
        """
        self.transmat, self.genmat, self.transcount, self.statetime = ctmc(
            X, numstates=self.numstates,
            transintv=self.transintv,
            toltime=self.toltime,
            autocorrect=self.autocorrect,
            debug=self.debug)
        return self

    def predict(self, X: np.ndarray, steps: int = 1) -> np.ndarray:
        """Simulate future state vectors with the fitted transition matrix.

        Requires a previous call to `fit`, which sets `self.transmat`.

        Parameters
        ----------
        X : ndarray
            vector with state variables at t

        steps : int
            (Default: 1) Number of steps, passed on to `simulate`.

        Returns
        -------
        C : ndarray
            The result of `simulate(X, self.transmat, steps)`; for
            steps=1 the vector with state variables at t+1.
        """
        return simulate(X, self.transmat, steps)
| "import sys\n", 18 | "sys.path.append('..')\n", 19 | "import ctmc" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 29 | " datalist = pickle.load(f)\n", 30 | "numstates = 9" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "3.19 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "%timeit ctmc.ctmc(datalist, numstates, 1.0)" 48 | ] 49 | } 50 | ], 51 | "metadata": { 52 | "kernelspec": { 53 | "display_name": "Python 3", 54 | "language": "python", 55 | "name": "python3" 56 | }, 57 | "language_info": { 58 | "codemirror_mode": { 59 | "name": "ipython", 60 | "version": 3 61 | }, 62 | "file_extension": ".py", 63 | "mimetype": "text/x-python", 64 | "name": "python", 65 | "nbconvert_exporter": "python", 66 | "pygments_lexer": "ipython3", 67 | "version": "3.6.2" 68 | } 69 | }, 70 | "nbformat": 4, 71 | "nbformat_minor": 2 72 | } 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | # other 108 | .vscode 109 | profile/data* 110 | -------------------------------------------------------------------------------- /profile/funcalls (prun).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Execution Time by Function Calls" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 4, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pickle\n", 17 | "import sys\n", 18 | "sys.path.append('..')\n", 19 | "import ctmc" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 5, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 29 | " datalist = pickle.load(f)\n", 30 | "numstates = 
class Test_Datacorrection(unittest.TestCase):
    """Unit tests for ctmc.datacorrection."""

    def _check(self, raw, expected):
        """Run datacorrection on `raw` and compare the flattened result."""
        corrected = ctmc.datacorrection(raw)
        npt.assert_allclose(flatten(corrected), flatten(expected))

    def test1(self):
        """An example with a single entry is dropped."""
        self._check(
            [([1], [0.5]),
             ([2, 3], [1.6, 2.7])],
            [[(2, 3), (1.6, 2.7)]])

    def test2(self):
        """Consecutive equal states are merged and their times summed."""
        self._check(
            [([2, 3], [1.6, 2.7]),
             ([4, 5, 5, 6], [.1, .1, .1, .1])],
            [[(2, 3), (1.6, 2.7)], [(4, 5, 6), (0.1, 0.2, 0.1)]])

    def test3(self):
        """An example that collapses to one state is dropped."""
        self._check(
            [([4, 5, 5, 6], [.1, .1, .1, .1]),
             ([7, 7, 7, 7], [.1, .1, .1, .1])],
            [[(4, 5, 6), (0.1, 0.2, 0.1)]])

    def test4(self):
        """An entry with zero duration is removed."""
        self._check(
            [([1, 2, 3], [.1, .1, .1]),
             ([4, 5, 6], [.1, .0, .1])],
            [[(1, 2, 3), (0.1, 0.1, 0.1)], [(4, 6), (0.1, 0.1)]])

    def test5(self):
        """Same input and expectation as test4."""
        self._check(
            [([1, 2, 3], [.1, .1, .1]),
             ([4, 5, 6], [.1, .0, .1])],
            [[(1, 2, 3), (0.1, 0.1, 0.1)], [(4, 6), (0.1, 0.1)]])

    def test6(self):
        """An example left with one entry after filtering is dropped."""
        self._check(
            [([4, 5, 6], [.1, .0, .1]),
             ([7, 8, 9], [.1, .0, .0])],
            [[(4, 6), (0.1, 0.1)]])
print("\ntest1: " + str(context.exception)) 19 | self.assertTrue( 20 | 'has faulty state' 21 | in str(context.exception)) 22 | 23 | def test2(self): 24 | data = [ 25 | ([0, 1], [.7, .7]), # ok 26 | ([0], [.3]) # example id=1 doesn't work 27 | ] 28 | n_states = 2 29 | toltime = 1e-8 30 | 31 | with self.assertRaises(Exception) as context: 32 | ctmc.datacheck(data, n_states, toltime) 33 | 34 | # print("\ntest2: " + str(context.exception)) 35 | self.assertTrue( 36 | 'has only 1 distinct' 37 | in str(context.exception)) 38 | 39 | def test3(self): 40 | data = [ 41 | ([0, 1], [.7, .7]), # ok 42 | ([0, 1, 1], [.3, .3, .3]) # example id=1 doesn't work 43 | ] 44 | n_states = 2 45 | toltime = 1e-8 46 | 47 | with self.assertRaises(Exception) as context: 48 | ctmc.datacheck(data, n_states, toltime) 49 | 50 | # print("\ntest3: " + str(context.exception)) 51 | self.assertTrue( 52 | 'has two consequtive entries' 53 | in str(context.exception)) 54 | 55 | def test4(self): 56 | data = [ 57 | ([0, 1], [.7, .7]), # ok 58 | ([0, 1], [.3, .0]) # example id=1 doesn't work 59 | ] 60 | n_states = 2 61 | toltime = 1e-8 62 | 63 | with self.assertRaises(Exception) as context: 64 | ctmc.datacheck(data, n_states, toltime) 65 | 66 | # print("\ntest4: " + str(context.exception)) 67 | self.assertTrue( 68 | 'that have not been active for longer than toltime' 69 | in str(context.exception)) 70 | 71 | def test5(self): 72 | data = [ 73 | ([0, 1], [.7, .7]), # ok example 74 | ] 75 | n_states = 2 76 | toltime = 1e-8 77 | 78 | flag = ctmc.datacheck(data, n_states, toltime) 79 | self.assertFalse(flag) 80 | 81 | 82 | # run 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /ctmc/ctmc_func.py: -------------------------------------------------------------------------------- 1 | from .datacheck import datacheck 2 | from .aggregateevents import aggregateevents 3 | from .errorcheck import errorcheck 4 | from .generatormatrix 
def ctmc(data: list, numstates: int, transintv: float = 1.0,
         toltime: float = 1e-8, autocorrect: bool = False, debug: bool = False
         ) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
    """Estimate a Continuous Time Markov Chain from event data.

    Parameters
    ----------
    data : list of lists
        A python list of N examples (e.g. rating histories of N companies,
        the event data of N basketball games, etc.). The i-th example
        consists of one list with M_i encoded state labels and one list
        with the M_i durations or time periods the states lasted.

    numstates : int
        Number of unique states.

    transintv : float
        (Default: 1.0) The time interval. The generator matrix is
        multiplied by this value before the matrix exponential is
        computed.

    toltime : float
        (Default: 1e-8) Tolerance for state durations. It is passed to
        ctmc.datacorrection (if autocorrect=True) and to ctmc.datacheck
        and ctmc.errorcheck (if debug=True). With debug=True an
        exception is thrown if the aggregated state duration or
        aggregated time periods of any state is smaller than toltime.

    autocorrect : bool
        (Default: False) If True run the ctmc.datacorrection function
        on 'data' before anything else.

    debug : bool
        (Default: False) If True run the ctmc.datacheck function before
        and the ctmc.errorcheck function after the aggregation.
        Enable this flag if you want to check if your 'data' variable
        has been processed correctly.

    Returns
    -------
    transmat : ndarray
        The estimated transition/stochastic matrix.

    genmat : ndarray
        The estimated generator matrix.

    transcount : ndarray
        The aggregated transition counts returned by aggregateevents.

    statetime : ndarray
        The aggregated state durations returned by aggregateevents.

    Errors:
    -------
    - Unless autocorrect=True, ctmc assumes a clean data object and
        does not correct any errors in it.

    The main error sources are

    - transition counting (e.g. two consecutive identical states
        have not been aggregated, only one distinct state
        reported) and
    - a state is modeled or required that does not occur
        in the dataset (e.g. you have a certain scale in mind
        and just assume it's in the data) or resp. is not involved
        in any transition (e.g. an example with just one
        state)

    You can enable error checking and exceptions by setting
    debug=True. You should do this for the first run on a
    smaller dataset.

    Example:
    --------
    Use `datacheck` to check the dataset during preprocessing

        data = ...
        ctmc.datacheck(data, numstates, toltime)

    Disable checks in `ctmc` (pass `toltime` by keyword, the third
    positional parameter is `transintv`)

        transmat, genmat, transcount, statetime = ctmc.ctmc(
            data, numstates, toltime=toltime, debug=False)

    Check afterwards if there has been an error

        ctmc.errorcheck(transcount, statetime, toltime)

    """
    # auto-correct data list
    if autocorrect:
        data = datacorrection(data, toltime)

    # raise an exception if the data format is wrong
    if debug:
        datacheck(data, numstates, toltime)

    # aggregate event data
    transcount, statetime = aggregateevents(data, numstates)

    # raise an exception if the event data aggregation failed
    if debug:
        errorcheck(transcount, statetime, toltime)

    # create generator matrix
    genmat = generatormatrix(transcount, statetime)

    # compute matrix exponential of the generator matrix
    transmat = scipy.linalg.expm(genmat * transintv)

    # done
    return transmat, genmat, transcount, statetime
`toltime` => Delete these observations\n", 20 | "* Two consecutive states refer to the same state, i.e. there is no transition => Merge these observations\n", 21 | "* An example has just one state as observation => Delete the example" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Examples how ctmc.datacorrection behaves" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "[[(2, 3), (1.6, 2.7)]]" 40 | ] 41 | }, 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "test = [([1], [0.5]), \n", 49 | " ([2, 3], [1.6, 2.7])]\n", 50 | "res = ctmc.datacorrection(test)\n", 51 | "res" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "[[(2, 3), (1.6, 2.7)], [(4, 5, 6), (0.1, 0.2, 0.1)]]" 63 | ] 64 | }, 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "test = [([2, 3], [1.6, 2.7]),\n", 72 | " ([4, 5, 5, 6], [.1, .1, .1, .1])]\n", 73 | "res = ctmc.datacorrection(test)\n", 74 | "res" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "[[(4, 5, 6), (0.1, 0.2, 0.1)]]" 86 | ] 87 | }, 88 | "execution_count": 4, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "test = [([4, 5, 5, 6], [.1, .1, .1, .1]),\n", 95 | " ([7, 7, 7, 7], [.1, .1, .1, .1])]\n", 96 | "res = ctmc.datacorrection(test)\n", 97 | "res" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 5, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "[[(1, 2, 3), (0.1, 0.1, 0.1)], [(4, 6), (0.1, 
0.1)]]" 109 | ] 110 | }, 111 | "execution_count": 5, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "test = [([1, 2, 3], [.1, .1, .1]),\n", 118 | " ([4, 5, 6], [.1, .0, .1])]\n", 119 | "res = ctmc.datacorrection(test)\n", 120 | "res" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 6, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "[[(4, 6), (0.1, 0.1)]]" 132 | ] 133 | }, 134 | "execution_count": 6, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "test = [([4, 5, 6], [.1, .0, .1]),\n", 141 | " ([7, 8, 9], [.1, .0, .0])]\n", 142 | "res = ctmc.datacorrection(test)\n", 143 | "res" 144 | ] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "Python 3", 150 | "language": "python", 151 | "name": "python3" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": "ipython", 156 | "version": 3 157 | }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.6.2" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 2 168 | } 169 | -------------------------------------------------------------------------------- /examples/demo datacheck.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle\n", 10 | "\n", 11 | "import pprint\n", 12 | "pp = pprint.PrettyPrinter(indent=4)\n", 13 | "\n", 14 | "import sys\n", 15 | "sys.path.append('..')\n", 16 | "import ctmc" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "## Load Demo Dataset\n", 24 | "A preprocessed data list is used." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 34 | " datalist = pickle.load(f)\n", 35 | "numstates = 9" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## ctmc.datacheck\n", 43 | "However, the file can contain inconsistencies." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "The example id=40 has a state[2] that have not been active for longer than toltime\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "try:\n", 61 | " ctmc.datacheck(datalist, numstates, toltime=1e-8)\n", 62 | "except Exception as e:\n", 63 | " print(e)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "## Correct the error\n", 71 | "The 40-th example has a state $8$ with a duration of $0.0$."
72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "([3, 4, 8, 5],\n", 83 | " [1.8164383561643835, 0.1178082191780822, 0.0, 1.3415300546448088])" 84 | ] 85 | }, 86 | "execution_count": 4, 87 | "metadata": {}, 88 | "output_type": "execute_result" 89 | } 90 | ], 91 | "source": [ 92 | "datalist[40]" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "The quick fix is to remove it" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 5, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "([3, 4, 5], [1.8164383561643835, 0.1178082191780822, 1.3415300546448088])" 111 | ] 112 | }, 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "del datalist[40][0][2]\n", 120 | "del datalist[40][1][2]\n", 121 | "datalist[40]" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## Try it again" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "try:\n", 138 | " ctmc.datacheck(datalist, numstates, toltime=1e-8)\n", 139 | "except Exception as e:\n", 140 | " print(e)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "We were able to find the needle in the haystack.\n", 148 | "However, correcting each error manually is not really user-friendly.\n", 149 | "\n", 150 | "Better try \n", 151 | "\n", 152 | "```\n", 153 | "newlist = ctmc.datacorrection(oldlist)\n", 154 | "```" 155 | ] 156 | } 157 | ], 158 | "metadata": { 159 | "kernelspec": { 160 | "display_name": "Python 3", 161 | "language": "python", 162 | "name": "python3" 163 | }, 164 | "language_info": { 165 | "codemirror_mode": { 166 | "name": 
"ipython", 167 | "version": 3 168 | }, 169 | "file_extension": ".py", 170 | "mimetype": "text/x-python", 171 | "name": "python", 172 | "nbconvert_exporter": "python", 173 | "pygments_lexer": "ipython3", 174 | "version": "3.6.2" 175 | } 176 | }, 177 | "nbformat": 4, 178 | "nbformat_minor": 2 179 | } 180 | -------------------------------------------------------------------------------- /examples/automatic data correction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle\n", 10 | "import numpy as np\n", 11 | "\n", 12 | "import sys\n", 13 | "sys.path.append('..')\n", 14 | "import ctmc" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "## Load Demo Dataset\n", 22 | "A preprocessed data list is used." 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 32 | " datalist = pickle.load(f)\n", 33 | "numstates = 9" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "# Fit with uncorrected data" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "array([[0.96, 0.04, 0. , 0. , 0. , 0. , 0. , 0. , 0. ],\n", 52 | " [0.04, 0.9 , 0.06, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 53 | " [0. , 0.03, 0.89, 0.08, 0. , 0. , 0. , 0. , 0. ],\n", 54 | " [0. , 0. , 0.06, 0.86, 0.06, 0.01, 0. , 0. , 0. ],\n", 55 | " [0. , 0. , 0. , 0.11, 0.8 , 0.08, 0. , 0. , 0.01],\n", 56 | " [0. , 0. , 0. , 0.01, 0.1 , 0.79, 0.05, 0.01, 0.04],\n", 57 | " [0. , 0. , 0. , 0. , 0.01, 0.21, 0.53, 0.17, 0.07],\n", 58 | " [0. , 0. , 0. , 0. , 0.03, 0.38, 0.06, 0.52, 0.01],\n", 59 | " [0. , 0. , 0. 
, 0. , 0. , 0.03, 0. , 0. , 0.97]])" 60 | ] 61 | }, 62 | "execution_count": 3, 63 | "metadata": {}, 64 | "output_type": "execute_result" 65 | } 66 | ], 67 | "source": [ 68 | "model1 = ctmc.Ctmc(numstates, transintv=1.0, toltime=1e-8, debug=False)\n", 69 | "model1 = model1.fit(datalist)\n", 70 | "mat1 = model1.transmat\n", 71 | "mat1.round(2)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# Fit with auto-corrected data" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/plain": [ 89 | "array([[0.96, 0.04, 0. , 0. , 0. , 0. , 0. , 0. , 0. ],\n", 90 | " [0.04, 0.9 , 0.06, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 91 | " [0. , 0.03, 0.89, 0.08, 0. , 0. , 0. , 0. , 0. ],\n", 92 | " [0. , 0. , 0.06, 0.86, 0.06, 0.01, 0. , 0. , 0. ],\n", 93 | " [0. , 0. , 0. , 0.11, 0.8 , 0.08, 0. , 0. , 0. ],\n", 94 | " [0. , 0. , 0. , 0.01, 0.1 , 0.79, 0.05, 0.01, 0.04],\n", 95 | " [0. , 0. , 0. , 0. , 0.01, 0.21, 0.53, 0.17, 0.07],\n", 96 | " [0. , 0. , 0. , 0. , 0.03, 0.38, 0.06, 0.52, 0.01],\n", 97 | " [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 1. ]])" 98 | ] 99 | }, 100 | "execution_count": 4, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "model2 = ctmc.Ctmc(numstates, transintv=1.0, toltime=1e-8, autocorrect=True, debug=False)\n", 107 | "model2 = model2.fit(datalist)\n", 108 | "mat2 = model2.transmat\n", 109 | "mat2.round(2)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "# Differences" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "array([[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , -0. ],\n", 128 | " [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , -0. ],\n", 129 | " [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , -0. 
],\n", 130 | " [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , -0. ],\n", 131 | " [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , -0. ],\n", 132 | " [-0. , -0. , -0. , -0. , -0. , -0. , -0. , -0. , 0. ],\n", 133 | " [-0. , -0. , -0. , -0. , -0. , -0. , -0. , -0. , 0. ],\n", 134 | " [-0. , -0. , -0. , -0. , -0. , -0. , -0. , -0. , 0. ],\n", 135 | " [-0. , -0. , -0. , -0. , -0. , -0.03, -0. , -0. , 0.03]])" 136 | ] 137 | }, 138 | "execution_count": 5, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "(mat2 - mat1).round(2)" 145 | ] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.7.1" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 2 169 | } 170 | -------------------------------------------------------------------------------- /examples/demo ctmc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle\n", 10 | "\n", 11 | "import pprint\n", 12 | "pp = pprint.PrettyPrinter(indent=4)\n", 13 | "\n", 14 | "import sys\n", 15 | "sys.path.append('..')\n", 16 | "import ctmc" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "## Load Demo Dataset\n", 24 | "A preprocessed data list is used." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 34 | " datalist = pickle.load(f)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "The number of states is 9." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "numstates = 9" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## Visual Inspection\n", 58 | "`datalist` is a list of examples.\n", 59 | "Each example consists of two lists.\n", 60 | "The first list contains encoded state labels.\n", 61 | "The second list contains the durations or time periods the corresponding states have been active." 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "[ ([4, 3], [8.467213114754099, 4.371584699453552]),\n", 74 | " ([4, 3, 2], [0.6147540983606558, 10.616438356164384, 5.576502732240437])]\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "pp.pprint(datalist[49:51])" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "## Estimate Markov Model\n", 87 | "`ctmc` with `debug=True` will throw an exception if something is wrong."
88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "The example id=40 has a state[2] that have not been active for longer than toltime\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "try:\n", 105 | " transmat, genmat, transcount, statetime = ctmc.ctmc(\n", 106 | " datalist, numstates, 1.0, toltime=1e-8, debug=True)\n", 107 | "except Exception as e:\n", 108 | " print(e)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "With `debug=False` (Default) `ctmc` is very fast but might crash at a later point or generate bogus results.\n", 116 | "Therefore, the faulty data should be corrected." 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## Repeat Estimation" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "newlist = ctmc.datacorrection(datalist, toltime=1e-8)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 7, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "try:\n", 142 | " transmat, genmat, transcount, statetime = ctmc.ctmc(\n", 143 | " newlist, numstates, 1.0, toltime=1e-8, debug=True)\n", 144 | "except Exception as e:\n", 145 | " print(e)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 8, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "text/plain": [ 156 | "array([[0.96, 0.04, 0. , 0. , 0. , 0. , 0. , 0. , 0. ],\n", 157 | " [0.04, 0.9 , 0.06, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 158 | " [0. , 0.03, 0.89, 0.08, 0. , 0. , 0. , 0. , 0. ],\n", 159 | " [0. , 0. , 0.06, 0.86, 0.06, 0.01, 0. , 0. , 0. ],\n", 160 | " [0. , 0. , 0. , 0.11, 0.8 , 0.08, 0. , 0. , 0. ],\n", 161 | " [0. , 0. , 0.
, 0.01, 0.1 , 0.79, 0.05, 0.01, 0.04],\n", 162 | " [0. , 0. , 0. , 0. , 0.01, 0.21, 0.53, 0.17, 0.07],\n", 163 | " [0. , 0. , 0. , 0. , 0.03, 0.38, 0.06, 0.52, 0.01],\n", 164 | " [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 1. ]])" 165 | ] 166 | }, 167 | "execution_count": 8, 168 | "metadata": {}, 169 | "output_type": "execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "transmat.round(2)" 174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "Python 3", 180 | "language": "python", 181 | "name": "python3" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 3 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython3", 193 | "version": "3.6.2" 194 | } 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 2 198 | } 199 | -------------------------------------------------------------------------------- /examples/internal data format of ctmc_fit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### The Internal Data Format for ctmc\n", 8 | "The function `ctmc` or `Ctmc` class expect the data to be structured as follows" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": { 15 | "ExecuteTime": { 16 | "end_time": "2018-09-01T20:16:48.521661Z", 17 | "start_time": "2018-09-01T20:16:48.511886Z" 18 | } 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "data = [([0, 1, 2, 1], [2.2, 3.35, 9.4, 1.3]), \n", 23 | " ([1, 0, 1], [4.0, 1.25, 1.7])]" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "Each example or event chain is one element in a array `data`.\n", 31 | "\n", 32 | "* The first entry of entry of an example row is a list of **states**, \n", 33 | "* the second 
entry a list **time periods** a state lasted.\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "### How does it work in ctmc?" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "Initialize variables" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": { 54 | "ExecuteTime": { 55 | "end_time": "2018-09-01T20:17:32.967965Z", 56 | "start_time": "2018-09-01T20:17:32.961972Z" 57 | } 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "import numpy as np\n", 62 | "numstates = 3\n", 63 | "statetime = np.zeros(numstates, dtype=float)\n", 64 | "transcount = np.zeros(shape=(numstates, numstates), dtype=int)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "Loop over all examples, \n", 72 | "and cumulate time periods and count transitions across all examples." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": { 79 | "ExecuteTime": { 80 | "end_time": "2018-09-01T20:20:46.174968Z", 81 | "start_time": "2018-09-01T20:20:46.169468Z" 82 | } 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "for _, example in enumerate(data):\n", 87 | " states = example[0]\n", 88 | " times = example[1]\n", 89 | " \n", 90 | " for i,s in enumerate(states):\n", 91 | " statetime[s] += times[i]\n", 92 | " if i: transcount[states[i-1], s] += 1" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "The intermediate results are" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 4, 105 | "metadata": { 106 | "ExecuteTime": { 107 | "end_time": "2018-09-01T20:21:29.070819Z", 108 | "start_time": "2018-09-01T20:21:29.059698Z" 109 | } 110 | }, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/plain": [ 115 | "array([ 3.45, 10.35, 9.4 ])" 116 | ] 117 | }, 118 | "execution_count": 4, 119 | "metadata": {}, 120 | "output_type": 
"execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "statetime" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 5, 130 | "metadata": { 131 | "ExecuteTime": { 132 | "end_time": "2018-09-01T20:21:29.759006Z", 133 | "start_time": "2018-09-01T20:21:29.750955Z" 134 | } 135 | }, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/plain": [ 140 | "array([[0, 2, 0],\n", 141 | " [1, 0, 1],\n", 142 | " [0, 1, 0]])" 143 | ] 144 | }, 145 | "execution_count": 5, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 | "transcount" 152 | ] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 3", 158 | "language": "python", 159 | "name": "python3" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.6.2" 172 | }, 173 | "toc": { 174 | "base_numbering": 1, 175 | "nav_menu": {}, 176 | "number_sections": true, 177 | "sideBar": true, 178 | "skip_h1_title": false, 179 | "title_cell": "Table of Contents", 180 | "title_sidebar": "Contents", 181 | "toc_cell": false, 182 | "toc_position": {}, 183 | "toc_section_display": true, 184 | "toc_window_display": false 185 | }, 186 | "varInspector": { 187 | "cols": { 188 | "lenName": 16, 189 | "lenType": 16, 190 | "lenVar": 40 191 | }, 192 | "kernels_config": { 193 | "python": { 194 | "delete_cmd_postfix": "", 195 | "delete_cmd_prefix": "del ", 196 | "library": "var_list.py", 197 | "varRefreshCmd": "print(var_dic_list())" 198 | }, 199 | "r": { 200 | "delete_cmd_postfix": ") ", 201 | "delete_cmd_prefix": "rm(", 202 | "library": "var_list.r", 203 | "varRefreshCmd": "cat(var_dic_list()) " 204 | } 205 | }, 206 | "types_to_exclude": [ 207 | "module", 208 | "function", 209 | 
"builtin_function_or_method", 210 | "instance", 211 | "_Feature" 212 | ], 213 | "window_display": false 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 2 218 | } 219 | -------------------------------------------------------------------------------- /profile/memory (mprun).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Memory Profile" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pickle\n", 17 | "import sys\n", 18 | "sys.path.append('..')\n", 19 | "import ctmc" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 29 | " datalist = pickle.load(f)\n", 30 | "numstates = 9" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "%load_ext memory_profiler" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "peak memory: 53.62 MiB, increment: 0.24 MiB\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "%memit ctmc.ctmc(datalist, numstates, 1.0)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## ctmc.ctmc" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 5, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "\n" 76 | ] 77 | }, 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "Filename: ../ctmc/ctmc_func.py\n", 82 | "\n", 83 | "Line # Mem usage Increment Line Contents\n", 84 | "================================================\n", 85 | " 
9 53.9 MiB 53.9 MiB def ctmc(data, numstates, transintv=1.0, toltime=1e-8, debug=False):\n", 86 | " 10 \"\"\" Continous Time Markov Chain\n", 87 | " 11 \n", 88 | " 12 Parameters\n", 89 | " 13 ----------\n", 90 | " 14 data : list of lists\n", 91 | " 15 A python list of N examples (e.g. rating histories of N companies,\n", 92 | " 16 the event data of N basketball games, etc.). The i-th example\n", 93 | " 17 consist of one list with M_i encoded state labels and M_i the\n", 94 | " 18 durations or time periods the state lasted since the recording\n", 95 | " 19 started.\n", 96 | " 20 \n", 97 | " 21 numstates : int\n", 98 | " 22 number of unique states\n", 99 | " 23 \n", 100 | " 24 transintv : float\n", 101 | " 25 The time interval\n", 102 | " 26 \n", 103 | " 27 toltime : float\n", 104 | " 28 (If debug=True) Will throw an exception if the aggregated state\n", 105 | " 29 duration or aggregated time periods of any state is smaller\n", 106 | " 30 than toltime.\n", 107 | " 31 \n", 108 | " 32 debug : bool\n", 109 | " 33 (Default: False) If True run the ctmc.datacheck function.\n", 110 | " 34 Enable this flag if you to check if your 'data' variable\n", 111 | " 35 has been processed correctly.\n", 112 | " 36 \n", 113 | " 37 Returns\n", 114 | " 38 -------\n", 115 | " 39 transmat : ndarray\n", 116 | " 40 The estimated transition/stochastic matrix.\n", 117 | " 41 \n", 118 | " 42 genmat : ndarray\n", 119 | " 43 The estimated generator matrix\n", 120 | " 44 \n", 121 | " 45 transcount : ndarray\n", 122 | " 46 \n", 123 | " 47 statetime : ndarray\n", 124 | " 48 \n", 125 | " 49 \n", 126 | " 50 Errors:\n", 127 | " 51 -------\n", 128 | " 52 - ctmc assumes a clean data object and does not\n", 129 | " 53 autocorrect any errors as result of it\n", 130 | " 54 \n", 131 | " 55 The main error sources are\n", 132 | " 56 \n", 133 | " 57 - transitions counting (e.g. 
two consequtive states\n", 134 | " 58 has not been aggregated, only one distinct state\n", 135 | " 59 reported) and\n", 136 | " 60 - a state is modeled ore required that does not occur\n", 137 | " 61 in the dataset (e.g. you a certain scale in mind\n", 138 | " 62 and just assume it's in the data) or resp. involved\n", 139 | " 63 in any transition (e.g. an example with just one\n", 140 | " 64 state)\n", 141 | " 65 \n", 142 | " 66 You can enable error checking and exceptions by setting\n", 143 | " 67 debug=True. You should do this for the first run on a\n", 144 | " 68 smaller dataset.\n", 145 | " 69 \n", 146 | " 70 Example:\n", 147 | " 71 --------\n", 148 | " 72 Use `datacheck` to check during preprocessing the\n", 149 | " 73 dataset\n", 150 | " 74 \n", 151 | " 75 data = ...\n", 152 | " 76 ctmc.datacheck(data, numstates, toltime)\n", 153 | " 77 \n", 154 | " 78 Disable checks in `ctmc`\n", 155 | " 79 \n", 156 | " 80 transmat, genmat, transcount, statetime = ctmc.ctmc(\n", 157 | " 81 data, numstates, toltime, checks=False)\n", 158 | " 82 \n", 159 | " 83 Check aftwards if there has been an error\n", 160 | " 84 \n", 161 | " 85 ctmc.errorcheck(transcount, statetime, toltime)\n", 162 | " 86 \n", 163 | " 87 \"\"\"\n", 164 | " 88 # raise an exception if the data format is wrong\n", 165 | " 89 53.9 MiB 0.0 MiB if debug:\n", 166 | " 90 datacheck(data, numstates, toltime)\n", 167 | " 91 \n", 168 | " 92 # aggregate event data\n", 169 | " 93 53.9 MiB 0.0 MiB transcount, statetime = aggregateevents(data, numstates)\n", 170 | " 94 \n", 171 | " 95 # raise an exception if the event data aggregation failed\n", 172 | " 96 53.9 MiB 0.0 MiB if debug:\n", 173 | " 97 errorcheck(transcount, statetime, toltime)\n", 174 | " 98 \n", 175 | " 99 # create generator matrix\n", 176 | " 100 53.9 MiB 0.0 MiB genmat = generatormatrix(transcount, statetime)\n", 177 | " 101 \n", 178 | " 102 # compute matrix exponential of the generator matrix\n", 179 | " 103 53.9 MiB 0.0 MiB transmat = 
scipy.linalg.expm(genmat * transintv)\n", 180 | " 104 \n", 181 | " 105 # done\n", 182 | " 106 53.9 MiB 0.0 MiB return transmat, genmat, transcount, statetime" 183 | ] 184 | }, 185 | "metadata": {}, 186 | "output_type": "display_data" 187 | } 188 | ], 189 | "source": [ 190 | "%mprun -f ctmc.ctmc ctmc.ctmc(datalist, numstates, 1.0)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "## ctmc.aggregateevents" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 6, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "\n" 210 | ] 211 | }, 212 | { 213 | "data": { 214 | "text/plain": [ 215 | "Filename: ../ctmc/aggregateevents.py\n", 216 | "\n", 217 | "Line # Mem usage Increment Line Contents\n", 218 | "================================================\n", 219 | " 6 53.9 MiB 53.9 MiB def aggregateevents(data, numstates):\n", 220 | " 7 \n", 221 | " 8 53.9 MiB 0.0 MiB transcount = scipy.sparse.lil_matrix((numstates, numstates), dtype=int)\n", 222 | " 9 53.9 MiB 0.0 MiB statetime = np.zeros(numstates, dtype=float)\n", 223 | " 10 \n", 224 | " 11 53.9 MiB 0.0 MiB for _, example in enumerate(data):\n", 225 | " 12 53.9 MiB 0.0 MiB states = example[0]\n", 226 | " 13 53.9 MiB 0.0 MiB times = example[1]\n", 227 | " 14 \n", 228 | " 15 53.9 MiB 0.0 MiB for i, s in enumerate(states):\n", 229 | " 16 53.9 MiB 0.0 MiB statetime[s] += times[i]\n", 230 | " 17 53.9 MiB 0.0 MiB if i:\n", 231 | " 18 53.9 MiB 0.0 MiB transcount[states[i - 1], s] += 1\n", 232 | " 19 \n", 233 | " 20 53.9 MiB 0.0 MiB return transcount.toarray(), statetime" 234 | ] 235 | }, 236 | "metadata": {}, 237 | "output_type": "display_data" 238 | } 239 | ], 240 | "source": [ 241 | "%mprun -f ctmc.aggregateevents ctmc.ctmc(datalist, numstates, 1.0)" 242 | ] 243 | } 244 | ], 245 | "metadata": { 246 | "kernelspec": { 247 | "display_name": "Python 3", 248 | "language": "python", 249 
| "name": "python3" 250 | }, 251 | "language_info": { 252 | "codemirror_mode": { 253 | "name": "ipython", 254 | "version": 3 255 | }, 256 | "file_extension": ".py", 257 | "mimetype": "text/x-python", 258 | "name": "python", 259 | "nbconvert_exporter": "python", 260 | "pygments_lexer": "ipython3", 261 | "version": "3.6.2" 262 | } 263 | }, 264 | "nbformat": 4, 265 | "nbformat_minor": 2 266 | } 267 | -------------------------------------------------------------------------------- /profile/linebyline (lprun).ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Line by Line Execution Time" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pickle\n", 17 | "import sys\n", 18 | "sys.path.append('..')\n", 19 | "import ctmc" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 29 | " datalist = pickle.load(f)\n", 30 | "numstates = 9" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "%load_ext line_profiler" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## ctmc.ctmc" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "Timer unit: 1e-06 s\n", 58 | "\n", 59 | "Total time: 0.01663 s\n", 60 | "File: ../ctmc/ctmc_func.py\n", 61 | "Function: ctmc at line 9\n", 62 | "\n", 63 | "Line # Hits Time Per Hit % Time Line Contents\n", 64 | "==============================================================\n", 65 | " 9 def ctmc(data, numstates, transintv=1.0, toltime=1e-8, 
debug=False):\n", 66 | " 10 \"\"\" Continous Time Markov Chain\n", 67 | " 11 \n", 68 | " 12 Parameters\n", 69 | " 13 ----------\n", 70 | " 14 data : list of lists\n", 71 | " 15 A python list of N examples (e.g. rating histories of N companies,\n", 72 | " 16 the event data of N basketball games, etc.). The i-th example\n", 73 | " 17 consist of one list with M_i encoded state labels and M_i the\n", 74 | " 18 durations or time periods the state lasted since the recording\n", 75 | " 19 started.\n", 76 | " 20 \n", 77 | " 21 numstates : int\n", 78 | " 22 number of unique states\n", 79 | " 23 \n", 80 | " 24 transintv : float\n", 81 | " 25 The time interval\n", 82 | " 26 \n", 83 | " 27 toltime : float\n", 84 | " 28 (If debug=True) Will throw an exception if the aggregated state\n", 85 | " 29 duration or aggregated time periods of any state is smaller\n", 86 | " 30 than toltime.\n", 87 | " 31 \n", 88 | " 32 debug : bool\n", 89 | " 33 (Default: False) If True run the ctmc.datacheck function.\n", 90 | " 34 Enable this flag if you to check if your 'data' variable\n", 91 | " 35 has been processed correctly.\n", 92 | " 36 \n", 93 | " 37 Returns\n", 94 | " 38 -------\n", 95 | " 39 transmat : ndarray\n", 96 | " 40 The estimated transition/stochastic matrix.\n", 97 | " 41 \n", 98 | " 42 genmat : ndarray\n", 99 | " 43 The estimated generator matrix\n", 100 | " 44 \n", 101 | " 45 transcount : ndarray\n", 102 | " 46 \n", 103 | " 47 statetime : ndarray\n", 104 | " 48 \n", 105 | " 49 \n", 106 | " 50 Errors:\n", 107 | " 51 -------\n", 108 | " 52 - ctmc assumes a clean data object and does not\n", 109 | " 53 autocorrect any errors as result of it\n", 110 | " 54 \n", 111 | " 55 The main error sources are\n", 112 | " 56 \n", 113 | " 57 - transitions counting (e.g. two consequtive states\n", 114 | " 58 has not been aggregated, only one distinct state\n", 115 | " 59 reported) and\n", 116 | " 60 - a state is modeled ore required that does not occur\n", 117 | " 61 in the dataset (e.g. 
you a certain scale in mind\n", 118 | " 62 and just assume it's in the data) or resp. involved\n", 119 | " 63 in any transition (e.g. an example with just one\n", 120 | " 64 state)\n", 121 | " 65 \n", 122 | " 66 You can enable error checking and exceptions by setting\n", 123 | " 67 debug=True. You should do this for the first run on a\n", 124 | " 68 smaller dataset.\n", 125 | " 69 \n", 126 | " 70 Example:\n", 127 | " 71 --------\n", 128 | " 72 Use `datacheck` to check during preprocessing the\n", 129 | " 73 dataset\n", 130 | " 74 \n", 131 | " 75 data = ...\n", 132 | " 76 ctmc.datacheck(data, numstates, toltime)\n", 133 | " 77 \n", 134 | " 78 Disable checks in `ctmc`\n", 135 | " 79 \n", 136 | " 80 transmat, genmat, transcount, statetime = ctmc.ctmc(\n", 137 | " 81 data, numstates, toltime, checks=False)\n", 138 | " 82 \n", 139 | " 83 Check aftwards if there has been an error\n", 140 | " 84 \n", 141 | " 85 ctmc.errorcheck(transcount, statetime, toltime)\n", 142 | " 86 \n", 143 | " 87 \"\"\"\n", 144 | " 88 # raise an exception if the data format is wrong\n", 145 | " 89 1 11.0 11.0 0.1 if debug:\n", 146 | " 90 datacheck(data, numstates, toltime)\n", 147 | " 91 \n", 148 | " 92 # aggregate event data\n", 149 | " 93 1 14128.0 14128.0 85.0 transcount, statetime = aggregateevents(data, numstates)\n", 150 | " 94 \n", 151 | " 95 # raise an exception if the event data aggregation failed\n", 152 | " 96 1 3.0 3.0 0.0 if debug:\n", 153 | " 97 errorcheck(transcount, statetime, toltime)\n", 154 | " 98 \n", 155 | " 99 # create generator matrix\n", 156 | " 100 1 264.0 264.0 1.6 genmat = generatormatrix(transcount, statetime)\n", 157 | " 101 \n", 158 | " 102 # compute matrix exponential of the generator matrix\n", 159 | " 103 1 2222.0 2222.0 13.4 transmat = scipy.linalg.expm(genmat * transintv)\n", 160 | " 104 \n", 161 | " 105 # done\n", 162 | " 106 1 2.0 2.0 0.0 return transmat, genmat, transcount, statetime" 163 | ] 164 | }, 165 | "metadata": {}, 166 | "output_type": "display_data" 
167 | } 168 | ], 169 | "source": [ 170 | "%lprun -f ctmc.ctmc ctmc.ctmc(datalist, numstates, 1.0)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## ctmc.aggregateevents" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 5, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/plain": [ 188 | "Timer unit: 1e-06 s\n", 189 | "\n", 190 | "Total time: 0.009175 s\n", 191 | "File: ../ctmc/aggregateevents.py\n", 192 | "Function: aggregateevents at line 6\n", 193 | "\n", 194 | "Line # Hits Time Per Hit % Time Line Contents\n", 195 | "==============================================================\n", 196 | " 6 def aggregateevents(data, numstates):\n", 197 | " 7 \n", 198 | " 8 1 151.0 151.0 1.6 transcount = scipy.sparse.lil_matrix((numstates, numstates), dtype=int)\n", 199 | " 9 1 6.0 6.0 0.1 statetime = np.zeros(numstates, dtype=float)\n", 200 | " 10 \n", 201 | " 11 72 88.0 1.2 1.0 for _, example in enumerate(data):\n", 202 | " 12 71 86.0 1.2 0.9 states = example[0]\n", 203 | " 13 71 75.0 1.1 0.8 times = example[1]\n", 204 | " 14 \n", 205 | " 15 316 506.0 1.6 5.5 for i, s in enumerate(states):\n", 206 | " 16 245 509.0 2.1 5.5 statetime[s] += times[i]\n", 207 | " 17 245 310.0 1.3 3.4 if i:\n", 208 | " 18 174 7388.0 42.5 80.5 transcount[states[i - 1], s] += 1\n", 209 | " 19 \n", 210 | " 20 1 56.0 56.0 0.6 return transcount.toarray(), statetime" 211 | ] 212 | }, 213 | "metadata": {}, 214 | "output_type": "display_data" 215 | } 216 | ], 217 | "source": [ 218 | "%lprun -f ctmc.aggregateevents ctmc.ctmc(datalist, numstates, 1.0)" 219 | ] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": 
"python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.6.2" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 2 243 | } 244 | -------------------------------------------------------------------------------- /examples/demo Ctmc sklearn API.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle\n", 10 | "import numpy as np\n", 11 | "\n", 12 | "import sys\n", 13 | "sys.path.append('..')\n", 14 | "import ctmc\n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "%matplotlib inline" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Load Demo Dataset\n", 25 | "A preprocessed data list is used." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "with open(\"../data/example.pkl\", \"rb\") as f:\n", 35 | " datalist = pickle.load(f)\n", 36 | "numstates = 9" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## Correct Data Errors" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "newlist = ctmc.datacorrection(datalist, toltime=1e-8)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## Fit" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "model = ctmc.Ctmc(numstates, transintv=1.0, toltime=1e-8, debug=False)\n", 69 | "model = model.fit(newlist)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/plain": [ 80 | "array([[0.96, 0.04, 
0. , 0. , 0. , 0. , 0. , 0. , 0. ],\n", 81 | " [0.04, 0.9 , 0.06, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 82 | " [0. , 0.03, 0.89, 0.08, 0. , 0. , 0. , 0. , 0. ],\n", 83 | " [0. , 0. , 0.06, 0.86, 0.06, 0.01, 0. , 0. , 0. ],\n", 84 | " [0. , 0. , 0. , 0.11, 0.8 , 0.08, 0. , 0. , 0. ],\n", 85 | " [0. , 0. , 0. , 0.01, 0.1 , 0.79, 0.05, 0.01, 0.04],\n", 86 | " [0. , 0. , 0. , 0. , 0.01, 0.21, 0.53, 0.17, 0.07],\n", 87 | " [0. , 0. , 0. , 0. , 0.03, 0.38, 0.06, 0.52, 0.01],\n", 88 | " [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 1. ]])" 89 | ] 90 | }, 91 | "execution_count": 5, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | } 95 | ], 96 | "source": [ 97 | "model.transmat.round(2)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "## Predict" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 6, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "s0 = np.zeros(shape=(numstates,))\n", 114 | "s0[0] = 1" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 7, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "array([9.59042201e-01, 3.96720153e-02, 1.24751277e-03, 3.75024034e-05,\n", 126 | " 7.05714400e-07, 6.12327701e-08, 8.50006366e-10, 4.32672048e-11,\n", 127 | " 4.85809264e-10])" 128 | ] 129 | }, 130 | "execution_count": 7, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "out = model.predict(s0)\n", 137 | "out" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 8, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "#sum(out)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "## Multi-Step Simulation" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 9, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "data": { 163 | 
"text/plain": [ 164 | "array([[1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],\n", 165 | " [0.96, 0.04, 0. , 0. , 0. , 0. , 0. , 0. , 0. ],\n", 166 | " [0.92, 0.07, 0. , 0. , 0. , 0. , 0. , 0. , 0. ],\n", 167 | " [0.89, 0.1 , 0.01, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 168 | " [0.86, 0.13, 0.02, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 169 | " [0.83, 0.15, 0.02, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 170 | " [0.8 , 0.17, 0.03, 0. , 0. , 0. , 0. , 0. , 0. ],\n", 171 | " [0.77, 0.18, 0.04, 0.01, 0. , 0. , 0. , 0. , 0. ],\n", 172 | " [0.75, 0.2 , 0.04, 0.01, 0. , 0. , 0. , 0. , 0. ],\n", 173 | " [0.73, 0.21, 0.05, 0.01, 0. , 0. , 0. , 0. , 0. ]])" 174 | ] 175 | }, 176 | "execution_count": 9, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "out = model.predict(s0, steps=500)\n", 183 | "out[:10].round(2)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 10, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "data": { 193 | "image/png": "\n", 194 | "text/plain": [ 195 | "
" 196 | ] 197 | }, 198 | "metadata": {}, 199 | "output_type": "display_data" 200 | } 201 | ], 202 | "source": [ 203 | "fig, ax = plt.subplots(dpi=150, facecolor='white')\n", 204 | "ax.stackplot(list(range(len(out))), out.T * 100, labels=list(range(numstates)));\n", 205 | "ax.legend(loc='upper right');\n", 206 | "ax.set_xlabel('step');\n", 207 | "ax.set_ylabel('percentage');\n", 208 | "ax.set_xlim(0, len(out));\n", 209 | "ax.set_ylim(0, 100);" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.6.2" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 2 241 | } 242 | --------------------------------------------------------------------------------