├── requirements.txt ├── maxent ├── version.py ├── __init__.py ├── core.py └── hyper.py ├── paper ├── mbp_files │ └── mbp_cs.npz ├── epidemiology_files │ ├── LS_traj_folds.npy │ ├── abc_biased_traj.npy │ ├── abc_traj_folds.npy │ ├── ls_biased_traj.npy │ ├── MaxEnt_traj_folds.npy │ └── maxent_biased_me_w.npy ├── requirements.in ├── sbi_gravitation.py ├── MBP.ipynb ├── gaussian.ipynb ├── requirements.txt ├── gravitation.ipynb └── epidemiology.ipynb ├── docs ├── source │ ├── changelog.rst │ ├── index.md │ ├── toc.rst │ ├── api.rst │ └── conf.py ├── requirements.in ├── Makefile ├── make.bat ├── _templates │ └── breadcrumbs.html └── requirements.txt ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── test.yml │ ├── docs.yml │ ├── build.yml │ └── paper.yml ├── setup.py ├── .gitignore ├── tests └── test_maxent.py ├── README.md └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | pre-commit 2 | pytest 3 | -------------------------------------------------------------------------------- /maxent/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | -------------------------------------------------------------------------------- /paper/mbp_files/mbp_cs.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ur-whitelab/maxent/HEAD/paper/mbp_files/mbp_cs.npz -------------------------------------------------------------------------------- /docs/source/changelog.rst: -------------------------------------------------------------------------------- 1 | Change Log 2 | ========== 3 | 4 | v1.0.0 (2022-4-5) 5 | ------------------- 6 | 7 | * Initial public release 8 | -------------------------------------------------------------------------------- /paper/epidemiology_files/LS_traj_folds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ur-whitelab/maxent/HEAD/paper/epidemiology_files/LS_traj_folds.npy -------------------------------------------------------------------------------- /paper/epidemiology_files/abc_biased_traj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ur-whitelab/maxent/HEAD/paper/epidemiology_files/abc_biased_traj.npy -------------------------------------------------------------------------------- /paper/epidemiology_files/abc_traj_folds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ur-whitelab/maxent/HEAD/paper/epidemiology_files/abc_traj_folds.npy -------------------------------------------------------------------------------- /paper/epidemiology_files/ls_biased_traj.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ur-whitelab/maxent/HEAD/paper/epidemiology_files/ls_biased_traj.npy -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | Getting Started 2 | ================ 3 | 4 | ```{include} ../../README.md 5 | :relative-images: 6 | :start-line: 1 7 | ``` 8 | -------------------------------------------------------------------------------- /paper/epidemiology_files/MaxEnt_traj_folds.npy: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ur-whitelab/maxent/HEAD/paper/epidemiology_files/MaxEnt_traj_folds.npy -------------------------------------------------------------------------------- /paper/epidemiology_files/maxent_biased_me_w.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ur-whitelab/maxent/HEAD/paper/epidemiology_files/maxent_biased_me_w.npy -------------------------------------------------------------------------------- /docs/requirements.in: -------------------------------------------------------------------------------- 1 | sphinx 2 | myst-parser~=0.15.2 3 | sphinx-rtd-theme 4 | sphinx_autodoc_typehints 5 | myst-nb 6 | sbi 7 | pynmrstar 8 | git+https://github.com/ur-whitelab/py0.git@nature_compsci 9 | pyabc 10 | seaborn 11 | scipy <= 1.8.0 12 | -------------------------------------------------------------------------------- /docs/source/toc.rst: -------------------------------------------------------------------------------- 1 | maxent 2 | ============== 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | 7 | index.md 8 | paper/gaussian.ipynb 9 | paper/gravitation.ipynb 10 | paper/epidemiology.ipynb 11 | paper/MBP.ipynb 12 | api.rst 13 | -------------------------------------------------------------------------------- /paper/requirements.in: -------------------------------------------------------------------------------- 1 | tensorflow-probability==0.12 2 | tensorflow==2.4 3 | sbi==0.15.1 4 | pyabc 5 | pandas==1.2.4 6 | numpy==1.19.2 7 | matplotlib 8 | seaborn 9 | tqdm 10 | jupyter 11 | git+https://github.com/ur-whitelab/py0.git@nature_compsci 12 | pynmrstar 13 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | ===================== 3 | 4 | .. automodule:: maxent.core 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | :exclude-members: _AvgLayerLaplace, _ReweightLayerLaplace, _AvgLayer, _ReweightLayer 9 | 10 | ..
automodule:: maxent.hyper 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | -------------------------------------------------------------------------------- /maxent/__init__.py: -------------------------------------------------------------------------------- 1 | """maxent - Maximum Entropy""" 2 | 3 | from .version import __version__ 4 | 5 | from .hyper import ParameterJoint, TrainableInputLayer, HyperMaxentModel 6 | from .core import ( 7 | Prior, 8 | EmptyPrior, 9 | Laplace, 10 | MaxentModel, 11 | Restraint, 12 | _AvgLayerLaplace, 13 | _AvgLayer, 14 | _ReweightLayer, 15 | _ReweightLayerLaplace, 16 | ) 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.2.3 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: check-yaml 7 | - id: end-of-file-fixer 8 | - id: mixed-line-ending 9 | - repo: https://github.com/psf/black 10 | rev: "22.3.0" 11 | hooks: 12 | - id: black 13 | - repo: https://github.com/tomcatling/black-nb 14 | rev: "0.7" 15 | hooks: 16 | - id: black-nb 17 | description: strip output and black source 18 | additional_dependencies: ['black[jupyter]'] 19 | args: ["--clear-output"] 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | 2 | name: tests 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.7, 3.8, 3.9] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install flake8 pytest pytest-cov 28 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 29 | - name: Check pre-commit 30 | run: pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) 31 | - name: Install 32 | run: | 33 | pip install . 34 | - name: Run Test 35 | run: | 36 | pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=maxent --cov-report=xml --cov-report=html 37 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | 2 | name: docs 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | env: 11 | TF_CPP_MIN_LOG_LEVEL: 3 12 | 13 | jobs: 14 | docs: 15 | 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python 3.8 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: '3.8' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install flake8 pytest pytest-cov 28 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 29 | - name: Install 30 | run: | 31 | pip install . && pip install -r docs/requirements.txt 32 | - name: Sphinx build 33 | run: | 34 | cp -R paper docs/source 35 | cd docs && make html 36 | - name: Deploy 37 | uses: peaceiris/actions-gh-pages@v3 38 | if: github.ref == 'refs/heads/main' 39 | with: 40 | github_token: ${{ secrets.GITHUB_TOKEN }} 41 | publish_dir: docs/build/html 42 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | 2 | name: publish 3 | 4 | on: 5 | release: 6 | types: 7 | - created 8 | workflow_dispatch: 9 | 10 | 11 | jobs: 12 | publish: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python "3.8" 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: "3.8" 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 26 | - name: Check pre-commit 27 | run: pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) 28 | - name: Install 29 | run: | 30 | pip install .
31 | - name: Run Test 32 | run: | 33 | pytest tests 34 | - name: Build a binary wheel and a source tarball 35 | run: | 36 | pip install build 37 | python -m build --sdist --wheel --outdir dist/ . 38 | - name: Publish distribution 📦 to PyPI 39 | uses: pypa/gh-action-pypi-publish@master 40 | with: 41 | password: ${{ secrets.PYPI_API_TOKEN }} 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | from setuptools import find_packages 6 | from setuptools import setup 7 | 8 | exec(open("maxent/version.py").read()) 9 | 10 | 11 | with open("README.md", "r", encoding="utf-8") as fh: 12 | long_description = fh.read() 13 | 14 | 15 | setup( 16 | name="maxent-infer", 17 | version=__version__, 18 | url="https://github.com/ur-whitelab/maxent", 19 | license="GPL v2", 20 | author="Mehrad Ansari , Rainier Barrett , Andrew White ", 21 | author_email="andrew.white@rochester.edu", 22 | description="Maximum entropy inference Keras implementation", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | packages=find_packages(exclude=("tests",)), 26 | install_requires=["numpy", "tensorflow", "tensorflow_probability"], 27 | classifiers=[ 28 | "Development Status :: 2 - Pre-Alpha", 29 | "License :: OSI Approved :: GNU General Public License v2 (GPLv2)", 30 | "Programming Language :: Python", 31 | "Programming Language :: Python :: 3", 32 | "Programming Language :: Python :: 3.8", 33 | "Programming Language :: Python :: 3.5", 34 | "Programming Language :: Python :: 3.6", 35 | "Programming Language :: Python :: 3.7", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /.github/workflows/paper.yml: -------------------------------------------------------------------------------- 1 | 2 | name: paper 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 3.8 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: "3.8" 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install flake8 pytest pytest-cov 25 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 26 | - name: Install 27 | run: | 28 | pip install . 
29 | - name: Install paper depends 30 | run: | 31 | pip install --upgrade --use-deprecated=legacy-resolver -r paper/requirements.txt 32 | - name: Run Gaussian System 33 | run: jupyter nbconvert --ExecutePreprocessor.timeout=-1 --execute "paper/gaussian.ipynb" --to notebook --output-dir='temp' --clear-output 34 | - name: Run Epidemiology System 35 | run: jupyter nbconvert --ExecutePreprocessor.timeout=-1 --execute "paper/epidemiology.ipynb" --to notebook --output-dir='temp' --clear-output 36 | - name: Run Gravitation System 37 | run: jupyter nbconvert --ExecutePreprocessor.timeout=-1 --execute "paper/gravitation.ipynb" --to notebook --output-dir='temp' --clear-output 38 | - name: Run MBP System 39 | run: jupyter nbconvert --ExecutePreprocessor.timeout=-1 --execute "paper/MBP.ipynb" --to notebook --output-dir='temp' --clear-output 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../..")) 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "maxent" 21 | copyright = "2021, Rainier Barrett, Mehrad Ansari, Andrew D White" 22 | author = "Rainier Barrett, Mehrad Ansari, Andrew D White" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "myst_nb", 32 | "sphinx.ext.autodoc", 33 | "sphinx_autodoc_typehints", 34 | "sphinx.ext.autosectionlabel", 35 | "sphinx.ext.intersphinx", 36 | "sphinx.ext.githubpages", 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ["_templates"] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = [] 46 | 47 | 48 | # -- Options for HTML output ------------------------------------------------- 49 | 50 | # The theme to use for HTML and HTML Help pages. See the documentation for 51 | # a list of builtin themes.
52 | # 53 | html_theme = "sphinx_rtd_theme" 54 | html_context = { 55 | "display_github": True, # Add 'Edit on Github' link instead of 'View page source' 56 | "last_updated": True, 57 | "commit": False, 58 | } 59 | 60 | 61 | autosectionlabel_prefix_document = True 62 | add_module_names = False 63 | 64 | intersphinx_mapping = { 65 | "tf": ( 66 | "https://www.tensorflow.org/api_docs/python", 67 | "https://github.com/mr-ubik/tensorflow-intersphinx/raw/master/tf2_py_objects.inv", 68 | ), 69 | "tensorflow_probability": ( 70 | "https://www.tensorflow.org/probability/api_docs/python", 71 | "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tfp_py_objects.inv", 72 | ), 73 | "matplotlib": ("https://matplotlib.org/stable/", None), 74 | "numpy": ("https://numpy.org/doc/stable/", None), 75 | } 76 | 77 | master_doc = "toc" 78 | 79 | myst_enable_extensions = [ 80 | "amsmath", 81 | "colon_fence", 82 | "deflist", 83 | "dollarmath", 84 | "html_image", 85 | ] 86 | myst_url_schemes = ["http", "https", "mailto"] 87 | execution_timeout = -1 88 | -------------------------------------------------------------------------------- /docs/_templates/breadcrumbs.html: -------------------------------------------------------------------------------- 1 | {%- if meta is defined and meta is not none %} 2 | {%- set check_meta = True %} 3 | {%- else %} 4 | {%- set check_meta = False %} 5 | {%- endif %} 6 | 7 | {%- if check_meta and 'github_url' in meta %} 8 | {%- set display_github = True %} 9 | {%- endif %} 10 | 11 | {%- if check_meta and 'bitbucket_url' in meta %} 12 | {%- set display_bitbucket = True %} 13 | {%- endif %} 14 | 15 | {%- if check_meta and 'gitlab_url' in meta %} 16 | {%- set display_gitlab = True %} 17 | {%- endif %} 18 | 19 | {%- set display_vcs_links = display_vcs_links if display_vcs_links is defined else True %} 20 | 21 | {#- Translators: This is an ARIA section label for page links, including previous/next page link and links to GitHub/GitLab/etc. -#} 22 |
23 | 64 | 65 | {%- if (theme_prev_next_buttons_location == 'top' or theme_prev_next_buttons_location == 'both') and (next or prev) %} 66 | {#- Translators: This is an ARIA section label for sequential page links, such as previous and next page links. -#} 67 | 75 | {%- endif %} 76 |
77 |
78 | -------------------------------------------------------------------------------- /tests/test_maxent.py: -------------------------------------------------------------------------------- 1 | import maxent 2 | import unittest 3 | import numpy as np 4 | import numpy.testing as npt 5 | import tensorflow as tf 6 | import tensorflow_probability as tfp 7 | 8 | tfd = tfp.distributions 9 | 10 | np.random.seed(0) 11 | tf.random.set_seed(0) 12 | 13 | 14 | class TestPriors(unittest.TestCase): 15 | def test_empty(self): 16 | p = maxent.EmptyPrior() 17 | assert p.expected(1) == 0 18 | 19 | def test_laplace(self): 20 | p = maxent.Laplace(0.1) 21 | 22 | def test_restraint(self): 23 | r = maxent.Restraint(lambda x: x**2, 4, maxent.EmptyPrior()) 24 | assert r(2) == 0 25 | 26 | 27 | class TestLayers(unittest.TestCase): 28 | def test_rw_layer(self): 29 | l = maxent._ReweightLayer(10) 30 | w = l(np.arange(10, dtype=np.float32)) 31 | assert len(w) == 1 32 | 33 | def test_avg_layer(self): 34 | l = maxent._ReweightLayer(10) 35 | la = maxent._AvgLayer(l) 36 | gk = np.arange(10, dtype=np.float32) 37 | w = l(gk) 38 | la(gk, w) 39 | 40 | def test_lrw_layer(self): 41 | l = maxent._ReweightLayerLaplace(np.random.normal(size=10).astype(np.float32)) 42 | w = l(np.arange(10, dtype=np.float32)) 43 | assert len(w) == 1 44 | 45 | def test_lavg_layer(self): 46 | l = maxent._ReweightLayerLaplace(np.random.normal(size=10).astype(np.float32)) 47 | la = maxent._AvgLayerLaplace(l) 48 | gk = np.arange(10, dtype=np.float32) 49 | w = l(gk) 50 | la(gk, w) 51 | 52 | 53 | class TestModel(unittest.TestCase): 54 | def test_me(self): 55 | data = np.random.normal(size=256).astype(np.float32) 56 | r = maxent.Restraint(lambda x: x**2, 2, maxent.EmptyPrior()) 57 | model = maxent.MaxentModel([r]) 58 | model.compile(tf.keras.optimizers.Adam(0.1), "mean_squared_error") 59 | model.fit(data, epochs=128, verbose=0) 60 | # check we fit somewhat close 61 | e = np.sum(data**2 * model.traj_weights) 62 | npt.assert_array_almost_equal(e, 2.0, decimal=2) 63 | 64 | def test_lme(self): 65 | data = np.random.normal(size=256).astype(np.float32) 66 | r = maxent.Restraint(lambda x: x**2, 2, maxent.Laplace(0.01)) 67 | model = maxent.MaxentModel([r]) 68 | model.compile(tf.keras.optimizers.Adam(0.1), "mean_squared_error") 69 | model.fit(data, epochs=128, verbose=0) 70 | # check we fit somewhat close 71 | e = np.sum(data**2 * model.traj_weights) 72 | npt.assert_array_almost_equal(e, 2.0, decimal=1) 73 | 74 | 75 | class TestHyperModel(unittest.TestCase): 76 | def test_reshaper(self): 77 | # make a model for sampling parameters 78 | x = np.array([1.0, 1.0]) 79 | i = tf.keras.Input((1,)) 80 | l = maxent.TrainableInputLayer(x)(i) 81 | d = tfp.layers.DistributionLambda( 82 | lambda x: tfd.Normal(loc=x[..., 0], scale=tf.math.exp(x[..., 1])) 83 | )(l) 84 | model = maxent.ParameterJoint(inputs=i, outputs=[d]) 85 | model.compile(tf.keras.optimizers.Adam(0.1)) 86 | model.sample(1) 87 | 88 | def test_hme(self): 89 | # make a model for sampling parameters 90 | x = np.array([1.0, 1.0]) 91 | tf.random.set_seed(0) 92 | 93 | i = tf.keras.Input((1,)) 94 | l = maxent.TrainableInputLayer(x)(i) 95 | d = tfp.layers.DistributionLambda( 96 | lambda x: tfd.Normal(loc=x[..., 0], scale=tf.math.exp(x[..., 1])) 97 | )(l) 98 | model = maxent.ParameterJoint(inputs=i, outputs=[d]) 99 | model.compile(tf.keras.optimizers.Adam(0.1)) 100 | 101 | # make simulator 102 | def simulate(x): 103 | y = np.random.normal(loc=x, scale=0.1) 104 | return y 105 | 106 | # make ME model 107 | r = 
maxent.Restraint(lambda x: x, 8, maxent.EmptyPrior()) 108 | hme_model = maxent.HyperMaxentModel([r], model, simulate) 109 | hme_model.compile(tf.keras.optimizers.Adam(0.5), "mean_squared_error") 110 | hme_model.fit(epochs=64, outter_epochs=2) 111 | e = np.sum(hme_model.trajs[:, 0] * hme_model.traj_weights) 112 | assert abs(e - 8.0) < 0.25 113 | 114 | def test_error(self): 115 | # make a model for sampling parameters 116 | x = np.array([1.0, 1.0]) 117 | i = tf.keras.Input((1,)) 118 | l = maxent.TrainableInputLayer(x)(i) 119 | d = tfp.layers.DistributionLambda( 120 | lambda x: tfd.Normal(loc=x[..., 0], scale=tf.math.exp(x[..., 1])) 121 | )(l) 122 | model = maxent.ParameterJoint([lambda x: x], inputs=i, outputs=[d]) 123 | model.compile(tf.keras.optimizers.Adam(0.1)) 124 | 125 | # make bad simulator 126 | def simulate(x): 127 | y = np.random.normal(loc=2, scale=0.1) 128 | return y 129 | 130 | # make ME model 131 | r = maxent.Restraint(lambda x: x, 8, maxent.EmptyPrior()) 132 | hme_model = maxent.HyperMaxentModel([r], model, simulate) 133 | with self.assertRaises(ValueError) as e: 134 | hme_model.fit(epochs=1, outter_epochs=1) 135 | 136 | 137 | if __name__ == "__main__": 138 | unittest.main() 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Maximum Entropy Inference 2 | 3 | [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/ur-whitelab/maxent) 4 | ![tests](https://github.com/ur-whitelab/maxent/actions/workflows/test.yml/badge.svg) 5 | ![paper](https://github.com/ur-whitelab/maxent/actions/workflows/paper.yml/badge.svg) 6 | [![docs](https://github.com/ur-whitelab/maxent/actions/workflows/docs.yml/badge.svg)](https://ur-whitelab.github.io/maxent/) 7 | [![PyPI version](https://badge.fury.io/py/maxent-infer.svg)](https://badge.fury.io/py/maxent-infer) 8 | 9 | This provides a Keras implementation of maximum entropy simulation-based inference. The point of this package is to reweight outcomes from a simulator to agree with observations, rather than trying to optimize your simulator's input parameters. The simulator must necessarily give multiple outcomes - either because you're trying multiple sets of input parameters or it has intrinsic noise. The assumption of this model is that your simulator is approximately correct. The observations being fit could have come from the distribution of outcomes of your simulator. 10 | 11 | ## About maximum entropy 12 | 13 | Maximum entropy reweighting is a straightforward black-box method that can be applied to arbitrary simulators with few observations. Its runtime is independent of the number of parameters used by the simulator, and it has been shown analytically to minimally change the prior to agree with observations. This method fills a niche in the small-data, high-complexity regime of SBI parameter inference, because it accurately and minimally biases a prior to match observations and does not scale in runtime with the number of model parameters. 14 | 15 | ## Installation 16 | 17 | ```sh 18 | pip install maxent-infer 19 | ``` 20 | 21 | ## Quick Start 22 | 23 | 24 | ### A Pandas Data Frame 25 | 26 | Consider a data frame representing outcomes from our prior model/simulator. We would like to 27 | regress these outcomes to data.
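If you do not yet have simulator output on disk, a placeholder file is enough to follow along. The snippet below writes a toy `data.csv` of normally distributed outcomes; the file name, shape, and values here are illustrative assumptions, not part of the package:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for real simulator output: 1000 outcomes (rows),
# each summarized by 5 numbers (columns). Any tabular outcomes work here.
rng = np.random.default_rng(0)
pd.DataFrame(rng.normal(size=(1000, 5))).to_csv("data.csv", index=False)
```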
28 | 29 | ```python 30 | import pandas as pd 31 | import numpy as np 32 | import maxent 33 | 34 | 35 | data = pd.read_csv('data.csv') 36 | ``` 37 | 38 | Perhaps we have a single observation we would like to match. We will reweight our rows (outcomes) to agree with the single observation. This is under-determined, but there is one solution because of the maximum entropy condition. To fit this one observation, we specify two things: a function to compute the observation from outcomes of our prior/simulator and the value of the observation (called `target`). Let's say our observation is just the value at column index 3 of a row that came from one outcome: 39 | 40 | ```python 41 | 42 | def observe(single_row): 43 | return single_row[3] 44 | 45 | r = maxent.Restraint(observe, target=1.5) 46 | ``` 47 | 48 | Do you have uncertainty with your observation? No problem. Here we specify our uncertainty as Laplace-distributed with a variance of 2 (Laplace scale parameter 1): 49 | 50 | ```python 51 | r = maxent.Restraint(observe, target=1.5, prior=maxent.Laplace(1)) 52 | ``` 53 | 54 | Now we'll fit our outcomes to the single observation. 55 | ```python 56 | model = maxent.MaxentModel(r) 57 | model.compile() 58 | model.fit(data.values) 59 | ``` 60 | 61 | We now have a set of weights -- one per row -- that we can use to compute other expressions. 62 | For example, here is the most likely outcome (mode): 63 | 64 | ```python 65 | i = np.argmax(model.traj_weights) 66 | mode = data.iloc[i, :] 67 | ``` 68 | 69 | Here are the new column averages: 70 | ```python 71 | col_avg = np.sum(data.values * model.traj_weights[:, np.newaxis], axis=0) 72 | ``` 73 | 74 | ### A simulator 75 | 76 | Here we show how to take a random walk simulator and use `maxent` to reweight the random walk so that the average end point is at x = 2, y = 1. 77 | 78 | ```python 79 | # simulate 80 | def random_walk_simulator(T=10): 81 | x = [0,0] 82 | traj = np.empty((T,2)) 83 | for i in range(T): 84 | traj[i] = x 85 | x += np.random.normal(size=2) 86 | return traj 87 | 88 | N = 500 89 | trajs = [random_walk_simulator() for _ in range(N)] 90 | 91 | # now have N x T x 2 tensor 92 | trajs = np.array(trajs) 93 | 94 | # here is a plot of these trajectories 95 | ``` 96 | 97 | ![image](https://user-images.githubusercontent.com/908389/130389256-2710cb73-617f-4e71-b3ba-e32bd0f85d6a.png) 98 | 99 | 100 | ```python 101 | # we want the random walk to have average end of 2,1 102 | rx = maxent.Restraint(lambda traj: traj[-1,0], target=2) 103 | ry = maxent.Restraint(lambda traj: traj[-1,1], target=1) 104 | 105 | # create model by passing in restraints 106 | model = maxent.MaxentModel([rx, ry]) 107 | 108 | # convert model to be differentiable/GPU (if available) 109 | model.compile() 110 | # fit to data 111 | h = model.fit(trajs) 112 | 113 | # can now compute other average properties 114 | # with new weights 115 | model.traj_weights 116 | 117 | # plot showing weights of trajectories: 118 | ``` 119 | 120 | ![image](https://user-images.githubusercontent.com/908389/130389259-3a081e19-110a-4c80-9f91-3b3902444e21.png) 121 | 122 | ## Further Examples 123 | 124 | You can find the examples used in the manuscript, including comparisons with competing methods, [here](https://ur-whitelab.github.io/maxent/toc.html). These examples use the latest package versions, so the figures will not exactly match those in the manuscript.
If you would like to reproduce the manuscript exactly, install the packages in `paper/requirements.txt` and execute the notebooks in `paper` (this is the output from the `paper` workflow above). 125 | 126 | ## API 127 | 128 | [API](https://ur-whitelab.github.io/maxent/api.html) 129 | 130 | ## Citation 131 | 132 | [See paper](https://iopscience.iop.org/article/10.1088/2632-2153/ac6286/meta) and the citation: 133 | 134 | ```bibtex 135 | @article{barrett2022simulation, 136 | title={Simulation-Based Inference with Approximately Correct Parameters via Maximum Entropy}, 137 | author={Barrett, Rainier and Ansari, Mehrad and Ghoshal, Gourab and White, Andrew D}, 138 | journal={Machine Learning: Science and Technology}, 139 | year={2022} 140 | } 141 | ``` 142 | 143 | ## License 144 | 145 | [![License: GPL v2](https://img.shields.io/badge/License-GPL%20v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html) 146 | -------------------------------------------------------------------------------- /paper/sbi_gravitation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | 5 | from matplotlib.collections import LineCollection 6 | 7 | 8 | def get_observation_points(traj): 9 | return traj[19:101:20] 10 | 11 | 12 | # colorline code from matplotlib examples https://nbviewer.jupyter.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb 13 | # Data manipulation: 14 | 15 | 16 | def make_segments(x, y): 17 | """ 18 | Create list of line segments from x and y coordinates, in the correct format for LineCollection: 19 | an array of the form numlines x (points per line) x 2 (x and y) array 20 | """ 21 | 22 | points = np.array([x, y]).T.reshape(-1, 1, 2) 23 | segments = np.concatenate([points[:-1], points[1:]], axis=1) 24 | 25 | return segments 26 | 27 | 28 | # Interface to LineCollection: 29 | 30 | 31 | def colorline( 32 | x, 33 | y, 34 | z=None, 35 | cmap=plt.get_cmap("copper"), 36 | norm=plt.Normalize(0.0, 1.0), 37 | linewidth=3, 38 | alpha=1.0, 39 | linestyle=None, 40 | label=None, 41 | ): 42 | """ 43 | Plot a colored line with coordinates x and y 44 | Optionally specify colors in the array z 45 | Optionally specify a colormap, a norm function and a line width 46 | """ 47 | 48 | # Default colors equally spaced on [0,1]: 49 | if z is None: 50 | z = np.linspace(0.0, 1.0, len(x)) 51 | 52 | # Special case if a single number: 53 | if not hasattr(z, "__iter__"): # to check for numerical input -- this is a hack 54 | z = np.array([z]) 55 | 56 | z = np.asarray(z) 57 | 58 | segments = make_segments(x, y) 59 | lc = LineCollection( 60 | segments, 61 | array=z, 62 | cmap=cmap, 63 | norm=norm, 64 | linewidth=linewidth, 65 | alpha=alpha, 66 | linestyle=linestyle, 67 | label=label, 68 | ) 69 | 70 | ax = plt.gca() 71 | ax.add_collection(lc) 72 | 73 | return lc 74 | 75 | 76 | class GravitySimulator: 77 | def __init__( 78 | self, 79 | m1=45, 80 | m2=33, 81 | m3=60, 82 | v0=[50.0, 0.0], 83 | G=1.90809e5, 84 | dt=1e-3, 85 | nsteps=100, 86 | random_noise=False, 87 | noise_size=3.0, 88 | ): 89 | # always start at origin 90 | self.m0 = 1.0 91 | self.m1, self.m2, self.m3, self.v0, self.G, self.dt, self.nsteps = ( 92 | np.array(m1), 93 | np.array(m2), 94 | np.array(m3), 95 | np.array(v0), 96 | G, 97 | dt, 98 | nsteps, 99 | ) 100 | self.masses = [self.m1, self.m2, self.m3] 101 | self.positions = np.zeros([self.nsteps, 2]) 102 | self.attractor_positions = np.array([[20.0, 20.0], [50.0, -15.0], [80.0, 
25.0]]) 103 | # first step special case 104 | self.positions[1] = ( 105 | self.positions[0] 106 | + self.v0 * self.dt 107 | + 0.5 * self.dt**2 * self.A(self.positions[0]) 108 | ) 109 | self.iter_idx = 2 110 | self.random_noise = random_noise 111 | self.noise_size = noise_size 112 | 113 | def rsquare(self, x1, x2): 114 | # distance between two points (note: despite the name, this returns the Euclidean distance, not its square) 115 | return np.linalg.norm(x1 - x2) 116 | 117 | def A(self, x): 118 | """Take the position of the small particle, x, and return 119 | the sum of forces on it from the three attractors.""" 120 | # acceleration = Force/mass 121 | # F = G * m1 * m2 / dist, attractive (noise, if enabled, is added to positions in run()) 122 | forces = np.zeros([3, 2]) 123 | for i, mass in enumerate(self.masses): 124 | # since the small particle has unit mass, just G * m 125 | dist = self.rsquare(x, self.attractor_positions[i]) 126 | force = self.G * mass / dist 127 | unit_vec = (self.attractor_positions[i] - x) / dist 128 | # point the force in the correct direction (attractive) 129 | force *= unit_vec 130 | forces[i] = force 131 | # sum up the three force vectors 132 | return np.sum(forces, axis=0) 133 | 134 | def run(self): 135 | np.random.seed(12656) 136 | while self.iter_idx < self.nsteps: 137 | self.step() 138 | if self.random_noise: 139 | self.positions = np.random.normal(self.positions, self.noise_size) 140 | return self.positions 141 | 142 | def step(self): 143 | # single step of integration with the Verlet method 144 | last_last_x = self.positions[self.iter_idx - 2] 145 | last_x = self.positions[self.iter_idx - 1] 146 | self.positions[self.iter_idx] = ( 147 | 2 * last_x - last_last_x + self.A(last_x) * self.dt**2 148 | ) 149 | self.iter_idx += 1 150 | 151 | def plot_traj( 152 | self, 153 | name="trajectory.png", 154 | fig=None, 155 | axes=None, 156 | save=True, 157 | make_colorbar=False, 158 | alpha=0.5, 159 | cmap=plt.get_cmap("Blues").reversed(), 160 | color="blue", 161 | fade_lines=True, 162 | linestyle="-", 163 | linewidth=2, 164 | label=None, 165 | label_attractors=False, 166 | ): 167 | if fig is None and axes is None: 168 | fig, axes = plt.subplots() 169 | x, y = self.positions[:, 0], self.positions[:, 1] 170 | if fade_lines: 171 | lc = colorline( 172 | x, 173 | y, 174 | alpha=alpha, 175 | cmap=cmap, 176 | linestyle=linestyle, 177 | linewidth=linewidth, 178 | label=label, 179 | ) 180 | else: 181 | axes.plot( 182 | x, 183 | y, 184 | alpha=alpha, 185 | color=color, 186 | linestyle=linestyle, 187 | linewidth=linewidth, 188 | label=label, 189 | ) 190 | if make_colorbar: 191 | fig.colorbar(lc) 192 | xmin = min( 193 | x.min(), 194 | np.min(self.attractor_positions[0, :]) 195 | - 0.1 * abs(np.min(self.attractor_positions[0, :])), 196 | ) 197 | xmax = max( 198 | x.max(), 199 | np.max(self.attractor_positions[0, :]) 200 | + 0.1 * np.max(self.attractor_positions[0, :]), 201 | ) 202 | plt.xlim(xmin, xmax) 203 | ymin = min( 204 | y.min(), 205 | np.min(self.attractor_positions[1, :]) 206 | - 0.1 * abs(np.min(self.attractor_positions[1, :])), 207 | ) 208 | ymax = max( 209 | y.max(), 210 | np.max(self.attractor_positions[1, :]) 211 | + 0.1 * np.max(self.attractor_positions[1, :]), 212 | ) 213 | plt.ylim(ymin, ymax) 214 | axes.scatter( 215 | self.attractor_positions[:, 0], 216 | self.attractor_positions[:, 1], 217 | color="black", 218 | label=("Attractors" if label_attractors else None), 219 | ) 220 | if save: 221 | plt.savefig(name) 222 | 223 | def set_traj(self, trajectory): 224 | self.positions = trajectory 225 | 226 | 227 | def
sim_wrapper(params_list): 228 | """params_list should be: m1, m2, m3, v0[0], v0[1] in that order""" 229 | m1, m2, m3 = float(params_list[0]), float(params_list[1]), float(params_list[2]) 230 | v0 = np.array([params_list[3], params_list[4]], dtype=np.float64) 231 | this_sim = GravitySimulator(m1, m2, m3, v0, random_noise=True) 232 | this_traj = this_sim.run() 233 | summary_stats = torch.as_tensor(get_observation_points(this_traj).flatten()) 234 | return summary_stats 235 | -------------------------------------------------------------------------------- /paper/MBP.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2e71345f", 6 | "metadata": {}, 7 | "source": [ 8 | "## MBP Protein NMR Example" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "cf7e046f", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import seaborn as sns\n", 22 | "import tensorflow as tf\n", 23 | "import maxent\n", 24 | "\n", 25 | "sns.set_context(\"paper\")\n", 26 | "sns.set_style(\n", 27 | " \"whitegrid\",\n", 28 | " {\n", 29 | " \"xtick.bottom\": True,\n", 30 | " \"ytick.left\": True,\n", 31 | " \"xtick.color\": \"#333333\",\n", 32 | " \"ytick.color\": \"#333333\",\n", 33 | " },\n", 34 | ")\n", 35 | "# plt.rcParams[\"font.family\"] = \"serif\"\n", 36 | "plt.rcParams[\"mathtext.fontset\"] = \"dejavuserif\"\n", 37 | "colors = [\"#1b9e77\", \"#d95f02\", \"#7570b3\", \"#e7298a\", \"#66a61e\"]\n", 38 | "import pynmrstar\n", 39 | "from functools import partialmethod\n", 40 | "from tqdm import tqdm\n", 41 | "\n", 42 | "tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "154e7ea9", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "# load data from brmrb\n", 53 | "bmrb = pynmrstar.Entry.from_database(20062, convert_data_types=True)\n", 54 | "cs_result_sets = []\n", 55 | "for chemical_shift_loop in bmrb.get_loops_by_category(\"Atom_chem_shift\"):\n", 56 | " cs_result_sets.append(\n", 57 | " chemical_shift_loop.get_tag(\n", 58 | " [\"Comp_index_ID\", \"Comp_ID\", \"Atom_ID\", \"Atom_type\", \"Val\", \"Val_err\"]\n", 59 | " )\n", 60 | " )\n", 61 | "ref_data = pd.DataFrame(\n", 62 | " cs_result_sets[0], columns=[\"id\", \"res\", \"atom\", \"type\", \"shift\", \"error\"]\n", 63 | ")\n", 64 | "\n", 65 | "ref_resids = ref_data[ref_data.atom == \"H\"].id.values\n", 66 | "ref_data[ref_data.atom == \"H\"].head(25)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "3c30d90b", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "HAVE_MD_FILE = False\n", 77 | "\n", 78 | "ref_hdata = ref_data[ref_data.atom == \"H\"]\n", 79 | "# cut GLU because proton type mismatch\n", 80 | "ref_hdata = ref_hdata[\"shift\"].values[1:].astype(float)\n", 81 | "resnames = ref_data[ref_data.atom == \"H\"].res[1:]\n", 82 | "if HAVE_MD_FILE:\n", 83 | " data = pd.read_csv(\"./cs.csv\")\n", 84 | " data.head(10)\n", 85 | " # only need weights, so we extract only shifts that will be biased\n", 86 | " hdata_df = data[data.names == \"HN\"]\n", 87 | " hdata_df = hdata_df[hdata_df[\"resids\"].isin(ref_resids)]\n", 88 | " hdata_c = hdata_df.confident.values.reshape(len(data.frame.unique()), -1)\n", 89 | " hdata = 
hdata_df.peaks.values.reshape(len(data.frame.unique()), -1)\n", 90 | " assert hdata.shape[-1] == ref_hdata.shape[0]\n", 91 | " np.savez(\"mbp_files/mbp_cs.npz\", hdata=hdata, hdata_c=hdata_c)\n", 92 | "data = np.load(\"mbp_files/mbp_cs.npz\")\n", 93 | "hdata, hdata_c = data[\"hdata\"], data[\"hdata_c\"]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "2dda8568", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "plt.plot(np.mean(hdata, axis=0), \"o-\")\n", 104 | "plt.plot(ref_hdata, \"o-\")\n", 105 | "plt.show()" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "5b840e7b", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# fill in unconfident peaks with mean\n", 116 | "hdata_m = np.sum(hdata * hdata_c, axis=0) / np.sum(hdata_c, axis=0)\n", 117 | "total_fill = 0\n", 118 | "for i in range(hdata.shape[1]):\n", 119 | " hdata[:, i][~hdata_c[:, i]] = hdata_m[i]\n", 120 | " total_fill += np.sum(~hdata_c[:, i])\n", 121 | "print(\"Filled\", total_fill)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "f2accfa3", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "plt.plot(np.mean(hdata, axis=0), \"o-\")\n", 132 | "plt.plot(ref_hdata, \"o-\")\n", 133 | "plt.show()" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "b4f98fe8", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# make restraints\n", 144 | "restraints = []\n", 145 | "do_restrain = range(len(ref_hdata) // 2)\n", 146 | "for i in do_restrain:\n", 147 | " restraints.append(\n", 148 | " maxent.Restraint(lambda h, i=i: h[i], ref_hdata[i], prior=maxent.Laplace(0.05))\n", 149 | " )" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "id": "273827bb", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "model = maxent.MaxentModel(restraints)\n", 160 | "model.compile(tf.keras.optimizers.Adam(0.1), \"mean_squared_error\")\n", 161 | "history = model.fit(hdata, epochs=500, verbose=0)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "1938a4f8", 168 | "metadata": { 169 | "scrolled": true 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "plt.plot(history.history[\"loss\"])\n", 174 | "print(history.history[\"loss\"][-1])" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "54f99247", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "np.mean(np.abs(np.sum(hdata * model.traj_weights[..., np.newaxis], axis=0) - ref_hdata))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "981888b8", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "model.lambdas" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "8c14e63f", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "plt.plot(model.traj_weights)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "931e9e71", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "plt.figure(figsize=(3, 2), dpi=300)\n", 215 | "seq_dict = {\n", 216 | " \"CYS\": \"C\",\n", 217 | " \"ASP\": \"D\",\n", 218 | " \"SER\": \"S\",\n", 219 | " \"GLN\": \"Q\",\n", 220 | " \"LYS\": \"K\",\n", 221 | " \"ILE\": \"I\",\n", 222 | " \"PRO\": \"P\",\n", 223 | " 
\"THR\": \"T\",\n", 224 | " \"PHE\": \"F\",\n", 225 | " \"ASN\": \"N\",\n", 226 | " \"GLY\": \"G\",\n", 227 | " \"HIS\": \"H\",\n", 228 | " \"LEU\": \"L\",\n", 229 | " \"ARG\": \"R\",\n", 230 | " \"TRP\": \"W\",\n", 231 | " \"ALA\": \"A\",\n", 232 | " \"VAL\": \"V\",\n", 233 | " \"GLU\": \"E\",\n", 234 | " \"TYR\": \"Y\",\n", 235 | " \"MET\": \"M\",\n", 236 | "}\n", 237 | "plt.plot(\n", 238 | " np.sum(hdata * model.traj_weights[..., np.newaxis], axis=0), \"o-\", label=\"Posterior\"\n", 239 | ")\n", 240 | "plt.plot(np.mean(hdata, axis=0), \"o-\", label=\"Prior\")\n", 241 | "plt.plot(ref_hdata, \"*\", label=\"Experiment\")\n", 242 | "plt.axvline(x=len(ref_hdata) // 2 - 0.5, color=\"gray\", linestyle=\"--\")\n", 243 | "plt.xticks(range(len(ref_hdata)), [seq_dict[r] for r in resnames])\n", 244 | "plt.legend(loc=\"center left\", bbox_to_anchor=(1.0, 0.8))\n", 245 | "plt.text(len(ref_hdata) // 5, 8.55, \"Biased\")\n", 246 | "plt.text(len(ref_hdata) // 2, 8.55, \"Unbiased\")\n", 247 | "plt.xlabel(\"Sequence\")\n", 248 | "plt.ylabel(\"Chemical Shift [ppm]\")\n", 249 | "plt.savefig(\"protein.pdf\")" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "d0d608ca", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "print(\"most favored clusters\", np.argsort(model.traj_weights)[-3:])" 260 | ] 261 | } 262 | ], 263 | "metadata": { 264 | "kernelspec": { 265 | "display_name": "Python 3", 266 | "language": "python", 267 | "name": "python3" 268 | }, 269 | "language_info": { 270 | "codemirror_mode": { 271 | "name": "ipython", 272 | "version": 3 273 | }, 274 | "file_extension": ".py", 275 | "mimetype": "text/x-python", 276 | "name": "python", 277 | "nbconvert_exporter": "python", 278 | "pygments_lexer": "ipython3", 279 | "version": "3.8.3" 280 | } 281 | }, 282 | "nbformat": 4, 283 | "nbformat_minor": 5 284 | } 285 | -------------------------------------------------------------------------------- /maxent/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from math import sqrt 4 | from typing import * 5 | 6 | EPS = np.finfo(np.float32).tiny 7 | #: Array-like type 8 | Array = Union[np.ndarray, tf.TensorArray, float, List[float]] 9 | 10 | 11 | class Prior: 12 | """Prior distribution for expected deviation from target for restraint""" 13 | 14 | def expected(self, l: float) -> float: 15 | """Expected disagreement 16 | 17 | :param l: The lagrange multiplier 18 | :return: expected disagreement 19 | """ 20 | raise NotImplementedError() 21 | 22 | 23 | class EmptyPrior(Prior): 24 | """No prior deviation from target for restraint (exact agreement)""" 25 | 26 | def expected(self, l: float): 27 | return 0.0 28 | 29 | 30 | class Laplace(Prior): 31 | """Laplace distribution prior expected deviation from target for restraint 32 | 33 | :param sigma: Parameter for Laplace prior - higher means more allowable disagreement 34 | """ 35 | 36 | def __init__(self, sigma: float) -> float: 37 | self.sigma = sigma 38 | 39 | def expected(self, l: float) -> float: 40 | return -1.0 * l * self.sigma**2 / (1.0 - l**2 * self.sigma**2 / 2) 41 | 42 | 43 | class Restraint: 44 | """Restraint - includes function, target, and prior belief in deviation from target 45 | 46 | :param fxn: callable that returns scalar 47 | :param target: desired scalar value 48 | :param prior: prior is a :class:`Prior` for expected deviation from that target 49 | """ 50 | 51 | def __init__( 52 | self, fxn: 
Callable[[Array], float], target: float, prior: Prior = EmptyPrior() 53 | ): 54 | self.target = target 55 | self.fxn = fxn 56 | self.prior = prior 57 | 58 | def __call__(self, traj: Array) -> float: 59 | return self.fxn(traj) - self.target 60 | 61 | 62 | class _ReweightLayerLaplace(tf.keras.layers.Layer): 63 | """Trainable layer containing weights for maxent method""" 64 | 65 | def __init__(self, sigmas: Array): 66 | super(_ReweightLayerLaplace, self).__init__() 67 | l_init = tf.zeros_initializer() 68 | restraint_dim = len(sigmas) 69 | self.l = tf.Variable( 70 | initial_value=l_init(shape=(restraint_dim,), dtype="float32"), 71 | trainable=True, 72 | name="maxent-lambda", 73 | constraint=lambda x: tf.clip_by_value( 74 | x, -sqrt(2) / (1e-10 + sigmas), sqrt(2) / (1e-10 + sigmas) 75 | ), 76 | ) 77 | self.sigmas = sigmas 78 | 79 | def call(self, gk: Array, input_weights: Array = None) -> tf.TensorArray: 80 | # add priors 81 | mask = tf.cast(tf.equal(self.sigmas, 0), tf.float32) 82 | two_sig = tf.math.divide_no_nan(sqrt(2), self.sigmas) 83 | prior_term = mask * tf.math.log( 84 | tf.clip_by_value( 85 | 1.0 / (self.l + two_sig) + 1.0 / (two_sig - self.l), 1e-20, 1e8 86 | ) 87 | ) 88 | # sum-up constraint terms 89 | logits = tf.reduce_sum( 90 | -self.l[tf.newaxis, :] * gk + prior_term[tf.newaxis, :], axis=1 91 | ) 92 | # compute per-trajectory weights 93 | weights = tf.math.softmax(logits) 94 | if input_weights is not None: 95 | weights = weights * tf.reshape(input_weights, (-1,)) 96 | weights /= tf.reduce_sum(weights) 97 | self.add_metric( 98 | tf.reduce_sum(-weights * tf.math.log(weights + EPS)), 99 | aggregation="mean", 100 | name="weight-entropy", 101 | ) 102 | return weights 103 | 104 | 105 | class _AvgLayerLaplace(tf.keras.layers.Layer): 106 | """Layer that returns reweighted expected value for observations""" 107 | 108 | def __init__(self, reweight_layer: _ReweightLayerLaplace): 109 | super(_AvgLayerLaplace, self).__init__() 110 | if type(reweight_layer) != _ReweightLayerLaplace: 111 | raise TypeError() 112 | self.rl = reweight_layer 113 | 114 | def call(self, gk: Array, weights: Array) -> tf.TensorArray: 115 | # sum over trajectories 116 | e_gk = tf.reduce_sum(gk * weights[:, tf.newaxis], axis=0) 117 | # add laplace term 118 | # cannot rely on mask due to no clip 119 | err_e_gk = e_gk + -1.0 * self.rl.l * self.rl.sigmas**2 / ( 120 | 1.0 - self.rl.l**2 * self.rl.sigmas**2 / 2 121 | ) 122 | return err_e_gk 123 | 124 | 125 | class _ReweightLayer(tf.keras.layers.Layer): 126 | """Trainable layer containing weights for maxent method""" 127 | 128 | def __init__(self, restraint_dim: int): 129 | super(_ReweightLayer, self).__init__() 130 | l_init = tf.zeros_initializer() 131 | self.l = tf.Variable( 132 | initial_value=l_init(shape=(restraint_dim,), dtype="float32"), 133 | trainable=True, 134 | name="maxent-lambda", 135 | ) 136 | 137 | def call(self, gk: Array, input_weights: Array = None) -> tf.TensorArray: 138 | # sum-up constraint terms 139 | logits = tf.reduce_sum(-self.l[tf.newaxis, :] * gk, axis=1) 140 | # compute per-trajectory weights 141 | weights = tf.math.softmax(logits) 142 | if input_weights is not None: 143 | weights = weights * tf.reshape(input_weights, (-1,)) 144 | weights /= tf.reduce_sum(weights) 145 | self.add_metric( 146 | tf.reduce_sum(-weights * tf.math.log(weights + 1e-30)), 147 | aggregation="mean", 148 | name="weight-entropy", 149 | ) 150 | return weights 151 | 152 | 153 | class _AvgLayer(tf.keras.layers.Layer): 154 | """Layer that returns reweighted expected value for 
observations""" 155 | 156 | def __init__(self, reweight_layer: _ReweightLayer): 157 | super(_AvgLayer, self).__init__() 158 | if type(reweight_layer) != _ReweightLayer: 159 | raise TypeError() 160 | self.rl = reweight_layer 161 | 162 | def call(self, gk: Array, weights: Array) -> tf.TensorArray: 163 | # sum over trajectories 164 | e_gk = tf.reduce_sum(gk * weights[:, tf.newaxis], axis=0) 165 | return e_gk 166 | 167 | 168 | def _compute_restraints(trajs, restraints): 169 | N = trajs.shape[0] 170 | K = len(restraints) 171 | gk = np.empty((N, K)) 172 | for i in range(N): 173 | gk[i, :] = np.array([r(trajs[i]) for r in restraints]) 174 | return gk 175 | 176 | 177 | class MaxentModel(tf.keras.Model): 178 | """Keras Maximum entropy model 179 | 180 | :param restraints: List of :class:`Restraint` 181 | :param name: Name of model 182 | """ 183 | 184 | def __init__( 185 | self, restraints: List[Restraint], name: str = "maxent-model", **kwargs 186 | ): 187 | super(MaxentModel, self).__init__(name=name, **kwargs) 188 | if type(restraints) == Restraint: 189 | restraints = [restraints] 190 | self.restraints = restraints 191 | restraint_dim = len(restraints) 192 | # identify prior 193 | prior = type(restraints[0].prior) 194 | # double-check 195 | for r in restraints: 196 | if type(r.prior) != prior: 197 | raise ValueError("Can only do restraints of one type") 198 | if prior == Laplace: 199 | sigmas = np.array([r.prior.sigma for r in restraints], dtype=np.float32) 200 | self.weight_layer = _ReweightLayerLaplace(sigmas) 201 | self.avg_layer = _AvgLayerLaplace(self.weight_layer) 202 | else: 203 | self.weight_layer = _ReweightLayer(restraint_dim) 204 | self.avg_layer = _AvgLayer(self.weight_layer) 205 | self.lambdas = self.weight_layer.l 206 | self.prior = prior 207 | 208 | def reset_weights(self): 209 | """Zero out the weights of the model""" 210 | w = self.weight_layer.get_weights() 211 | self.weight_layer.set_weights(tf.zeros_like(w)) 212 | 213 | def call(self, inputs: Union[Array, List[Array], Tuple[Array]]) -> tf.TensorArray: 214 | """Compute reweighted restraint values 215 | 216 | :param inputs: Restraint values 217 | :return: Weighted restraint values 218 | """ 219 | input_weights = None 220 | if (type(inputs) == tuple or type(inputs) == list) and len(inputs) == 2: 221 | input_weights = inputs[1] 222 | inputs = inputs[0] 223 | weights = self.weight_layer(inputs, input_weights=input_weights) 224 | wgk = self.avg_layer(inputs, weights) 225 | return wgk 226 | 227 | # docstring from parent uses special sphinx stuff we cannot replicate 228 | def compile( 229 | self, 230 | optimizer=tf.keras.optimizers.Adam(0.1), 231 | loss="mean_squared_error", 232 | metrics=None, 233 | loss_weights=None, 234 | weighted_metrics=None, 235 | run_eagerly=None, 236 | steps_per_execution=None, 237 | **kwargs 238 | ): 239 | """See ``compile`` method of :class:`tf.keras.Model`""" 240 | super(MaxentModel, self).compile( 241 | optimizer, 242 | loss, 243 | metrics, 244 | loss_weights, 245 | weighted_metrics, 246 | run_eagerly, 247 | steps_per_execution, 248 | **kwargs 249 | ) 250 | 251 | def fit( 252 | self, 253 | trajs: Array, 254 | input_weights: Array = None, 255 | batch_size: int = None, 256 | epochs: int = 128, 257 | **kwargs 258 | ) -> tf.keras.callbacks.History: 259 | """Fit to given observations with restraints 260 | 261 | :param trajs: Observations, which can be input to :class:`Restraint` 262 | :param input_weights: Array of starting weights for the trajectories (defaults to uniform) 263 | :param batch_size: Almost always should be equal to number of
trajs, unless you want to mix your Lagrange multipliers across trajectories 264 | :param kwargs: See :class:tf.keras.Model ``fit`` method for further optional arguments, like ``verbose=0`` to hide output 265 | :return: The history of fit 266 | """ 267 | 268 | # process kwargs 269 | if "verbose" not in kwargs: 270 | kwargs["verbose"] = 0 271 | 272 | gk = _compute_restraints(trajs, self.restraints) 273 | inputs = gk.astype(np.float32) 274 | if batch_size is None: 275 | batch_size = len(gk) 276 | if input_weights is None: 277 | input_weights = tf.ones((tf.shape(gk)[0], 1)) 278 | result = super(MaxentModel, self).fit( 279 | [inputs, input_weights], 280 | tf.zeros_like(gk), 281 | batch_size=batch_size, 282 | epochs=epochs, 283 | **kwargs 284 | ) 285 | self.traj_weights = self.weight_layer(inputs, input_weights) 286 | self.restraint_values = gk 287 | return result 288 | -------------------------------------------------------------------------------- /paper/gaussian.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Gaussian Distribution Example" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import tensorflow_probability as tfp\n", 17 | "\n", 18 | "tfd = tfp.distributions\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import tensorflow as tf\n", 21 | "import seaborn as sns\n", 22 | "import numpy as np\n", 23 | "import scipy\n", 24 | "import tqdm\n", 25 | "import maxent\n", 26 | "import os\n", 27 | "\n", 28 | "tf.random.set_seed(0)\n", 29 | "np.random.seed(0)\n", 30 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n", 31 | "sns.set_context(\"paper\")\n", 32 | "sns.set_style(\n", 33 | " \"whitegrid\",\n", 34 | " {\n", 35 | " \"xtick.bottom\": True,\n", 36 | " \"ytick.left\": True,\n", 37 | " \"xtick.color\": \"#333333\",\n", 38 | " \"ytick.color\": \"#333333\",\n", 39 | " },\n", 40 | ")\n", 41 | "# plt.rcParams[\"font.family\"] = \"serif\"\n", 42 | "plt.rcParams[\"mathtext.fontset\"] = \"dejavuserif\"\n", 43 | "colors = [\"#1b9e77\", \"#d95f02\", \"#7570b3\", \"#e7298a\", \"#66a61e\"]\n", 44 | "%matplotlib inline" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "### Set-up Prior Distribution" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "x = np.array([1.0, 1.0])\n", 61 | "i = tf.keras.Input((1,))\n", 62 | "l = maxent.TrainableInputLayer(x)(i)\n", 63 | "d = tfp.layers.DistributionLambda(\n", 64 | " lambda x: tfd.Normal(loc=x[..., 0], scale=tf.math.exp(x[..., 1]))\n", 65 | ")(l)\n", 66 | "model = maxent.ParameterJoint([lambda x: x], inputs=i, outputs=[d])\n", 67 | "model.compile(tf.keras.optimizers.Adam(0.1))\n", 68 | "model(tf.constant([1.0]))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "### Simulator" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "def simulate(x):\n", 85 | " y = np.random.normal(loc=x, scale=0.1)\n", 86 | " return y\n", 87 | "\n", 88 | "\n", 89 | "plt.figure()\n", 90 | "unbiased_params = model.sample(100000)\n", 91 | "y = simulate(*unbiased_params)\n", 92 | "y = np.squeeze(y)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | 
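For orientation, here is a minimal usage sketch of the `MaxentModel` API from `maxent/core.py` above. It is illustrative only: the sample data and target value are invented, and it assumes `Restraint`, `Laplace`, and `MaxentModel` are exported at the package top level (the notebooks below confirm this for `Restraint`, `EmptyPrior`, and `MaxentModel`; `Laplace(sigma)` is an assumption based on how `core.py` reads `r.prior.sigma`).

import numpy as np
import tensorflow as tf
import maxent

# 1000 scalar "trajectories" drawn from a prior simulation (made-up data)
trajs = np.random.normal(0.0, 1.0, size=1000).astype(np.float32)

# restrain the ensemble mean to 0.5; the Laplace prior expresses
# uncertainty (sigma) in that target value
r = maxent.Restraint(lambda t: t, 0.5, maxent.Laplace(0.1))

model = maxent.MaxentModel([r])
model.compile(tf.keras.optimizers.Adam(0.1), "mean_squared_error")
model.fit(trajs, epochs=128, verbose=0)

# per-trajectory maximum-entropy weights (sum to 1)
print(np.sum(model.traj_weights * trajs))  # reweighted mean, pulled toward 0.5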
"metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "pdf = scipy.stats.gaussian_kde(y)\n", 102 | "x = np.linspace(-10, 10, 100)\n", 103 | "plt.plot(x, pdf.pdf(x), color=colors[0], linewidth=2)\n", 104 | "plt.axvline(np.mean(y), color=colors[0])" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "### Maximum Entropy Method" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "r = maxent.Restraint(lambda x: x, 4, maxent.EmptyPrior())\n", 121 | "\n", 122 | "me_model = maxent.MaxentModel([r])\n", 123 | "me_model.compile(tf.keras.optimizers.Adam(0.01), \"mean_squared_error\")\n", 124 | "result = me_model.fit(y, epochs=4, batch_size=128)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "plt.axvline(x=4, color=colors[0])\n", 134 | "wpdf = scipy.stats.gaussian_kde(\n", 135 | " np.squeeze(y), weights=np.squeeze(me_model.traj_weights)\n", 136 | ")\n", 137 | "x = np.linspace(-10, 10, 100)\n", 138 | "plt.plot(x, wpdf.pdf(x), color=colors[0], linewidth=2)\n", 139 | "\n", 140 | "\n", 141 | "plt.plot(x, pdf.pdf(x), color=colors[1], linewidth=2)\n", 142 | "plt.axvline(np.mean(np.squeeze(y)), color=colors[1])" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### Variational MaxEnt\n", 150 | "\n", 151 | "Try to fit to more extreme value - 10" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "r = maxent.Restraint(lambda x: x, 10, maxent.EmptyPrior())\n", 161 | "hme_model = maxent.HyperMaxentModel([r], model, simulate)\n", 162 | "hme_model.compile(tf.keras.optimizers.SGD(0.005), \"mean_squared_error\")\n", 163 | "result = hme_model.fit(epochs=4, sample_batch_size=len(y) // 4, verbose=0)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "w2pdf = scipy.stats.gaussian_kde(\n", 173 | " np.squeeze(hme_model.trajs), weights=np.squeeze(hme_model.traj_weights)\n", 174 | ")" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "plt.figure(figsize=(3, 2), dpi=300)\n", 184 | "x = np.linspace(-10, 25, 100)\n", 185 | "plt.plot(\n", 186 | " x, w2pdf.pdf(x), color=colors[2], linewidth=2, label=\"Variational MaxEnt Posterior\"\n", 187 | ")\n", 188 | "plt.axvline(x=10, color=colors[2])\n", 189 | "\n", 190 | "plt.plot(x, pdf.pdf(x), color=colors[1], linewidth=2, label=\"Prior\")\n", 191 | "plt.axvline(np.mean(np.squeeze(y)), color=colors[1])\n", 192 | "\n", 193 | "plt.plot(x, wpdf.pdf(x), color=colors[0], linewidth=2, label=\"MaxEnt Posterior\")\n", 194 | "plt.axvline(x=4, color=colors[0])\n", 195 | "plt.ylim(0, 0.30)\n", 196 | "plt.xlabel(r\"$r$\")\n", 197 | "plt.ylabel(r\"$P(r)$\")\n", 198 | "plt.title(\"a) MaxEnt\")\n", 199 | "plt.legend()\n", 200 | "plt.savefig(\"maxent.svg\")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "### Bayesian Inference Setting" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "scrolled": true 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "# 
https://pubmed.ncbi.nlm.nih.gov/26723635/\n", 219 | "plt.figure(figsize=(3, 2), dpi=300)\n", 220 | "x = np.linspace(-10, 25, 1000)\n", 221 | "cmap = plt.get_cmap(\"magma\")\n", 222 | "prior_theta = 10 ** np.linspace(-1, 4, 10)\n", 223 | "bpdf = np.exp(-((y - 10) ** 2) / (2 * prior_theta[:, np.newaxis]))\n", 224 | "bpdf /= np.sum(bpdf, axis=1)[:, np.newaxis]\n", 225 | "for i, p in enumerate(prior_theta):\n", 226 | " ppdf = scipy.stats.gaussian_kde(np.squeeze(y), weights=bpdf[i])\n", 227 | " plt.plot(\n", 228 | " x,\n", 229 | " ppdf.pdf(x),\n", 230 | " color=cmap(i / len(prior_theta)),\n", 231 | " label=f\"$\\\\theta/\\\\sigma$ = {p:.2f}\",\n", 232 | " )\n", 233 | "plt.legend(fontsize=6)\n", 234 | "plt.xlim(-10, 15)\n", 235 | "plt.xlabel(r\"$r$\")\n", 236 | "plt.ylabel(r\"$P(r)$\")\n", 237 | "plt.title(\"b) Bayesian\")\n", 238 | "plt.savefig(\"bayes.svg\")\n", 239 | "plt.show()" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "### Effects of Observable" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": { 253 | "scrolled": true 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "bayesian_results = []\n", 258 | "# scipy.stats.wasserstein_distance(y, y, u_weights=np.ones_like(y) / len(y), v_weights=bpdf[i])])\n", 259 | "x2 = np.linspace(-20, 20, 10000)\n", 260 | "for i in range(len(prior_theta)):\n", 261 | " ppdf = scipy.stats.gaussian_kde(np.squeeze(y), weights=bpdf[i])\n", 262 | " bayesian_results.append(\n", 263 | " [\n", 264 | " np.sum(ppdf.pdf(x) * x * (x[1] - x[0])),\n", 265 | " -np.nansum((x[1] - x[0]) * ppdf.pdf(x) * np.log(ppdf.pdf(x))),\n", 266 | " ]\n", 267 | " )\n", 268 | " print(i, bayesian_results[-1])" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "me_results = []\n", 278 | "for i in range(-5, 10):\n", 279 | " r = maxent.Restraint(lambda x: x, i, maxent.EmptyPrior())\n", 280 | " m = maxent.MaxentModel([r])\n", 281 | " m.compile(tf.keras.optimizers.Adam(0.001), \"mean_squared_error\")\n", 282 | " m.fit(y, epochs=4, batch_size=256, verbose=0)\n", 283 | " # d = scipy.stats.wasserstein_distance(y, y, u_weights=m.traj_weights)\n", 284 | " ppdf = scipy.stats.gaussian_kde(y, weights=m.traj_weights)\n", 285 | " d = -np.nansum((x[1] - x[0]) * ppdf.pdf(x) * np.log(ppdf.pdf(x)))\n", 286 | " me_results.append([i, d])\n", 287 | " print(np.sum(y * m.traj_weights), d)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": { 294 | "scrolled": true 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "plt.figure(figsize=(3, 2), dpi=300)\n", 299 | "me_result = np.array(me_results)\n", 300 | "bayesian_results = np.array(bayesian_results)\n", 301 | "plt.plot(me_result[:, 0], me_result[:, 1], label=\"MaxEnt\", color=colors[0])\n", 302 | "plt.plot(\n", 303 | " bayesian_results[:, 0],\n", 304 | " bayesian_results[:, 1],\n", 305 | " linestyle=\"--\",\n", 306 | " label=\"Bayesian Inference\",\n", 307 | " color=colors[1],\n", 308 | ")\n", 309 | "plt.ylabel(\"Posterior Entropy\")\n", 310 | "plt.xlabel(\"$E[r]$\")\n", 311 | "plt.legend()\n", 312 | "plt.title(\"c) Posterior Entropy\")\n", 313 | "plt.savefig(\"post.svg\")\n", 314 | "plt.show()" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "bayesian_results[:]" 324 | ] 325 | } 326 | ], 327 | 
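A useful sanity check on what the fitted models in this notebook are doing: with `EmptyPrior` restraints, the weights `MaxentModel` computes are exactly a softmax of the negative restraint values contracted with the Lagrange multipliers (see `_ReweightLayer.call` in `maxent/core.py`). A short sketch, assuming `m` is one of the fitted models from the loop above:

import numpy as np

# w_i is proportional to exp(-sum_k lambda_k * g_k(x_i)); reproduce it in numpy
lam = m.lambdas.numpy()     # Lagrange multipliers, shape (K,)
gk = m.restraint_values     # restraint values g_k(x_i), shape (N, K)
logits = -gk @ lam
w = np.exp(logits - logits.max())
w /= w.sum()                # numerically stable softmax
assert np.allclose(w, np.squeeze(m.traj_weights), atol=1e-5)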
"metadata": { 328 | "kernelspec": { 329 | "display_name": "Python 3", 330 | "language": "python", 331 | "name": "python3" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.8.3" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 4 348 | } 349 | -------------------------------------------------------------------------------- /paper/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.8 3 | # To update, run: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | absl-py==0.15.0 8 | # via 9 | # tensorboard 10 | # tensorflow 11 | argon2-cffi==21.3.0 12 | # via notebook 13 | argon2-cffi-bindings==21.2.0 14 | # via argon2-cffi 15 | asttokens==2.0.5 16 | # via stack-data 17 | astunparse==1.6.3 18 | # via tensorflow 19 | async-timeout==4.0.2 20 | # via redis 21 | attrs==21.4.0 22 | # via jsonschema 23 | backcall==0.2.0 24 | # via ipython 25 | beautifulsoup4==4.10.0 26 | # via nbconvert 27 | bleach==4.1.0 28 | # via nbconvert 29 | cachetools==5.0.0 30 | # via google-auth 31 | certifi==2021.10.8 32 | # via requests 33 | cffi==1.15.0 34 | # via argon2-cffi-bindings 35 | charset-normalizer==2.0.12 36 | # via requests 37 | click==8.1.2 38 | # via 39 | # distributed 40 | # pyabc 41 | cloudpickle==2.0.0 42 | # via 43 | # dask 44 | # distributed 45 | # pyabc 46 | # tensorflow-probability 47 | cycler==0.11.0 48 | # via matplotlib 49 | dask==2022.04.0 50 | # via distributed 51 | debugpy==1.6.0 52 | # via ipykernel 53 | decorator==5.1.1 54 | # via 55 | # ipython 56 | # tensorflow-probability 57 | defusedxml==0.7.1 58 | # via nbconvert 59 | deprecated==1.2.13 60 | # via redis 61 | distributed==2022.4.0 62 | # via pyabc 63 | dm-tree==0.1.6 64 | # via tensorflow-probability 65 | entrypoints==0.4 66 | # via 67 | # jupyter-client 68 | # nbconvert 69 | executing==0.8.3 70 | # via stack-data 71 | fastjsonschema==2.15.3 72 | # via nbformat 73 | flatbuffers==1.12 74 | # via tensorflow 75 | fonttools==4.31.2 76 | # via matplotlib 77 | fsspec==2022.3.0 78 | # via dask 79 | gast==0.3.3 80 | # via 81 | # tensorflow 82 | # tensorflow-probability 83 | gitdb==4.0.9 84 | # via gitpython 85 | gitpython==3.1.27 86 | # via pyabc 87 | google-auth==2.6.2 88 | # via 89 | # google-auth-oauthlib 90 | # tensorboard 91 | google-auth-oauthlib==0.4.6 92 | # via tensorboard 93 | google-pasta==0.2.0 94 | # via tensorflow 95 | greenlet==1.1.2 96 | # via sqlalchemy 97 | grpcio==1.32.0 98 | # via 99 | # tensorboard 100 | # tensorflow 101 | h5py==2.10.0 102 | # via tensorflow 103 | heapdict==1.0.1 104 | # via zict 105 | idna==3.3 106 | # via requests 107 | importlib-metadata==4.11.3 108 | # via markdown 109 | importlib-resources==5.6.0 110 | # via jsonschema 111 | ipykernel==6.12.1 112 | # via 113 | # ipywidgets 114 | # jupyter 115 | # jupyter-console 116 | # notebook 117 | # qtconsole 118 | ipython==8.2.0 119 | # via 120 | # ipykernel 121 | # ipywidgets 122 | # jupyter-console 123 | ipython-genutils==0.2.0 124 | # via 125 | # ipywidgets 126 | # notebook 127 | # qtconsole 128 | ipywidgets==7.7.0 129 | # via jupyter 130 | jabbar==0.0.15 131 | # via pyabc 132 | jedi==0.18.1 133 | # via ipython 134 | jinja2==3.1.1 135 | # via 136 | # distributed 137 | # nbconvert 138 | # notebook 139 | 
joblib==1.1.0 140 | # via 141 | # sbi 142 | # scikit-learn 143 | jsonschema==4.4.0 144 | # via nbformat 145 | jupyter==1.0.0 146 | # via -r requirements.in 147 | jupyter-client==7.2.1 148 | # via 149 | # ipykernel 150 | # jupyter-console 151 | # nbclient 152 | # notebook 153 | # qtconsole 154 | jupyter-console==6.4.3 155 | # via jupyter 156 | jupyter-core==4.9.2 157 | # via 158 | # jupyter-client 159 | # nbconvert 160 | # nbformat 161 | # notebook 162 | # qtconsole 163 | jupyterlab-pygments==0.1.2 164 | # via nbconvert 165 | jupyterlab-widgets==1.1.0 166 | # via ipywidgets 167 | keras-preprocessing==1.1.2 168 | # via tensorflow 169 | kiwisolver==1.4.2 170 | # via matplotlib 171 | locket==0.2.1 172 | # via partd 173 | markdown==3.3.6 174 | # via tensorboard 175 | markupsafe==2.1.1 176 | # via 177 | # jinja2 178 | # nbconvert 179 | matplotlib==3.5.1 180 | # via 181 | # -r requirements.in 182 | # maxentep 183 | # nflows 184 | # pyabc 185 | # pyknos 186 | # sbi 187 | # seaborn 188 | matplotlib-inline==0.1.3 189 | # via 190 | # ipykernel 191 | # ipython 192 | maxentep @ git+https://github.com/ur-whitelab/py0.git@nature_compsci 193 | # via -r requirements.in 194 | mistune==0.8.4 195 | # via nbconvert 196 | msgpack==1.0.3 197 | # via distributed 198 | nbclient==0.5.13 199 | # via nbconvert 200 | nbconvert==6.4.5 201 | # via 202 | # jupyter 203 | # notebook 204 | nbformat==5.3.0 205 | # via 206 | # ipywidgets 207 | # nbclient 208 | # nbconvert 209 | # notebook 210 | nest-asyncio==1.5.5 211 | # via 212 | # ipykernel 213 | # jupyter-client 214 | # nbclient 215 | # notebook 216 | nflows==0.14 217 | # via pyknos 218 | notebook==6.4.10 219 | # via 220 | # jupyter 221 | # widgetsnbextension 222 | numpy==1.19.2 223 | # via 224 | # -r requirements.in 225 | # h5py 226 | # keras-preprocessing 227 | # matplotlib 228 | # maxentep 229 | # nflows 230 | # opt-einsum 231 | # pandas 232 | # pyabc 233 | # pyknos 234 | # pyro-ppl 235 | # sbi 236 | # scikit-learn 237 | # scipy 238 | # seaborn 239 | # tensorboard 240 | # tensorflow 241 | # tensorflow-probability 242 | oauthlib==3.2.0 243 | # via requests-oauthlib 244 | opt-einsum==3.3.0 245 | # via 246 | # pyro-ppl 247 | # tensorflow 248 | packaging==21.3 249 | # via 250 | # bleach 251 | # dask 252 | # distributed 253 | # ipykernel 254 | # matplotlib 255 | # qtpy 256 | # redis 257 | pandas==1.2.4 258 | # via 259 | # -r requirements.in 260 | # pyabc 261 | # seaborn 262 | pandocfilters==1.5.0 263 | # via nbconvert 264 | parso==0.8.3 265 | # via jedi 266 | partd==1.2.0 267 | # via dask 268 | pexpect==4.8.0 269 | # via ipython 270 | pickleshare==0.7.5 271 | # via ipython 272 | pillow==9.1.0 273 | # via 274 | # matplotlib 275 | # sbi 276 | prometheus-client==0.13.1 277 | # via notebook 278 | prompt-toolkit==3.0.29 279 | # via 280 | # ipython 281 | # jupyter-console 282 | protobuf==3.20.0 283 | # via 284 | # tensorboard 285 | # tensorflow 286 | psutil==5.9.0 287 | # via 288 | # distributed 289 | # ipykernel 290 | ptyprocess==0.7.0 291 | # via 292 | # pexpect 293 | # terminado 294 | pure-eval==0.2.2 295 | # via stack-data 296 | pyabc==0.12.2 297 | # via -r requirements.in 298 | pyasn1==0.4.8 299 | # via 300 | # pyasn1-modules 301 | # rsa 302 | pyasn1-modules==0.2.8 303 | # via google-auth 304 | pycparser==2.21 305 | # via cffi 306 | pygments==2.11.2 307 | # via 308 | # ipython 309 | # jupyter-console 310 | # jupyterlab-pygments 311 | # nbconvert 312 | # qtconsole 313 | pyknos==0.14.2 314 | # via sbi 315 | pynmrstar==3.3.0 316 | # via -r requirements.in 317 | pyparsing==3.0.7 
318 | # via 319 | # matplotlib 320 | # packaging 321 | pyro-api==0.1.2 322 | # via pyro-ppl 323 | pyro-ppl==1.8.1 324 | # via sbi 325 | pyrsistent==0.18.1 326 | # via jsonschema 327 | python-dateutil==2.8.2 328 | # via 329 | # jupyter-client 330 | # matplotlib 331 | # pandas 332 | pytz==2022.1 333 | # via pandas 334 | pyyaml==6.0 335 | # via 336 | # dask 337 | # distributed 338 | pyzmq==22.3.0 339 | # via 340 | # jupyter-client 341 | # notebook 342 | # qtconsole 343 | qtconsole==5.3.0 344 | # via jupyter 345 | qtpy==2.0.1 346 | # via qtconsole 347 | redis==4.2.2 348 | # via pyabc 349 | requests==2.27.1 350 | # via 351 | # pynmrstar 352 | # requests-oauthlib 353 | # tensorboard 354 | requests-oauthlib==1.3.1 355 | # via google-auth-oauthlib 356 | rsa==4.8 357 | # via google-auth 358 | sbi==0.15.1 359 | # via -r requirements.in 360 | scikit-learn==1.0.2 361 | # via 362 | # pyabc 363 | # sbi 364 | scipy==1.8.0 365 | # via 366 | # maxentep 367 | # pyabc 368 | # sbi 369 | # scikit-learn 370 | # seaborn 371 | seaborn==0.11.2 372 | # via -r requirements.in 373 | send2trash==1.8.0 374 | # via notebook 375 | six==1.15.0 376 | # via 377 | # absl-py 378 | # asttokens 379 | # astunparse 380 | # bleach 381 | # dm-tree 382 | # google-auth 383 | # google-pasta 384 | # grpcio 385 | # h5py 386 | # keras-preprocessing 387 | # python-dateutil 388 | # tensorflow 389 | # tensorflow-probability 390 | smmap==5.0.0 391 | # via gitdb 392 | sortedcontainers==2.4.0 393 | # via distributed 394 | soupsieve==2.3.1 395 | # via beautifulsoup4 396 | sqlalchemy==1.4.34 397 | # via pyabc 398 | stack-data==0.2.0 399 | # via ipython 400 | tblib==1.7.0 401 | # via distributed 402 | tensorboard==2.8.0 403 | # via 404 | # nflows 405 | # pyknos 406 | # sbi 407 | # tensorflow 408 | tensorboard-data-server==0.6.1 409 | # via tensorboard 410 | tensorboard-plugin-wit==1.8.1 411 | # via tensorboard 412 | tensorflow==2.4 413 | # via -r requirements.in 414 | tensorflow-estimator==2.4.0 415 | # via tensorflow 416 | tensorflow-probability==0.12 417 | # via -r requirements.in 418 | termcolor==1.1.0 419 | # via tensorflow 420 | terminado==0.13.3 421 | # via notebook 422 | testpath==0.6.0 423 | # via nbconvert 424 | threadpoolctl==3.1.0 425 | # via scikit-learn 426 | toolz==0.11.2 427 | # via 428 | # dask 429 | # distributed 430 | # partd 431 | torch==1.11.0 432 | # via 433 | # nflows 434 | # pyknos 435 | # pyro-ppl 436 | # sbi 437 | tornado==6.1 438 | # via 439 | # distributed 440 | # ipykernel 441 | # jupyter-client 442 | # notebook 443 | # terminado 444 | tqdm==4.64.0 445 | # via 446 | # -r requirements.in 447 | # maxentep 448 | # nflows 449 | # pyknos 450 | # pyro-ppl 451 | # sbi 452 | traitlets==5.1.1 453 | # via 454 | # ipykernel 455 | # ipython 456 | # ipywidgets 457 | # jupyter-client 458 | # jupyter-core 459 | # matplotlib-inline 460 | # nbclient 461 | # nbconvert 462 | # nbformat 463 | # notebook 464 | # qtconsole 465 | typing-extensions==3.7.4.3 466 | # via 467 | # tensorflow 468 | # torch 469 | urllib3==1.26.9 470 | # via 471 | # distributed 472 | # requests 473 | wcwidth==0.2.5 474 | # via prompt-toolkit 475 | webencodings==0.5.1 476 | # via bleach 477 | werkzeug==2.1.1 478 | # via tensorboard 479 | wheel==0.37.1 480 | # via 481 | # astunparse 482 | # tensorboard 483 | # tensorflow 484 | widgetsnbextension==3.6.0 485 | # via ipywidgets 486 | wrapt==1.12.1 487 | # via 488 | # deprecated 489 | # tensorflow 490 | zict==2.1.0 491 | # via distributed 492 | zipp==3.8.0 493 | # via 494 | # importlib-metadata 495 | # importlib-resources 
496 | 497 | # The following packages are considered to be unsafe in a requirements file: 498 | # setuptools 499 | -------------------------------------------------------------------------------- /maxent/hyper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow_probability as tfp 4 | from .core import * 5 | 6 | tfd = tfp.distributions 7 | tfb = tfp.bijectors 8 | 9 | EPS = np.finfo(np.float32).tiny 10 | 11 | 12 | def _merge_history(base, other, prefix=""): 13 | if base is None: 14 | return other 15 | if other is None: 16 | return base 17 | for k, v in other.history.items(): 18 | if prefix + k in base.history: 19 | base.history[prefix + k].extend(v) 20 | else: 21 | base.history[prefix + k] = v 22 | return base 23 | 24 | 25 | def negloglik(y: Array, rv_y: tfd.Distribution) -> Array: 26 | """ 27 | Negative log likelihood 28 | 29 | :param y: observations 30 | :param rv_y: distribution 31 | :return: negative log likelihood of y 32 | """ 33 | logp = rv_y.log_prob(y + EPS) 34 | logp = tf.reduce_sum(tf.reshape(logp, (tf.shape(y)[0], -1)), axis=1) 35 | return -logp 36 | 37 | 38 | class ParameterJoint(tf.keras.Model): 39 | """Prior parameter model joint distribution 40 | 41 | This packages up how you want to sample prior parameters into one joint distribution. 42 | It can also reshape the outputs of your distributions, in case your simulation requires 43 | matrices, constraint application, or projections. 44 | 45 | :param inputs: :class:`tf.keras.Input` or tuple of them. 46 | :param outputs: list of :py:class:`tfp.distributions.Distribution` 47 | :param reshapers: optional list of callables that will be called on outputs from your distribution 48 | """ 49 | 50 | def __init__( 51 | self, 52 | reshapers: List[Callable[[Array], Array]] = None, 53 | inputs: Union[tf.keras.Input, Tuple[tf.keras.Input]] = None, 54 | outputs: List[tfd.Distribution] = None, 55 | **kwargs 56 | ): 57 | if inputs is None or outputs is None: 58 | raise ValueError("Must pass inputs and outputs to construct model") 59 | if reshapers: 60 | self.reshapers = reshapers 61 | self.output_count = len(reshapers) 62 | else: 63 | self.output_count = len(outputs) 64 | self.reshapers = [lambda x: x for _ in range(self.output_count)] 65 | super(ParameterJoint, self).__init__(inputs=inputs, outputs=outputs, **kwargs) 66 | 67 | def compile(self, optimizer: object, **kwargs): 68 | """See ``compile`` method of :class:`tf.keras.Model`""" 69 | if "loss" in kwargs: 70 | raise ValueError("Do not set loss") 71 | super(ParameterJoint, self).compile( 72 | optimizer, loss=self.output_count * [negloglik], **kwargs 73 | ) 74 | 75 | def sample( 76 | self, N: int, return_joint: bool = False 77 | ) -> Union[Tuple[Array, Array, Any], Array]: 78 | """Generate samples 79 | 80 | :param N: Number of samples (events) 81 | :param return_joint: return a joint :py:class:`tfp.distributions.Distribution` that can be called on ``y`` 82 | :return: the reshaped output samples, plus (if ``return_joint``) the raw samples ``y`` which can be used to compute probabilities and the joint :py:class:`tfp.distributions.Distribution` 83 | """ 84 | joint = self(tf.constant([1.0])) 85 | if type(joint) != list: 86 | joint = [joint] 87 | y = [j.sample(N) for j in joint] 88 | v = [self.reshapers[i](s) for i, s in enumerate(y)] 89 | if return_joint: 90 | return v, y, joint 91 | else: 92 | return v 93 | 94 | 95 | def _reweight( 96 | samples: Array, unbiased_joint: List[tfd.Distribution], joint: List[tfd.Distribution] 
97 | ) -> Array: 98 | batch_dim = samples[0].shape[0] 99 | logit = tf.zeros((batch_dim,)) 100 | for i, (uj, j) in enumerate(zip(unbiased_joint, joint)): 101 | # reduce across other axis (summing independent variable log ps) 102 | logitdiff = uj.log_prob(samples[i] + EPS) - j.log_prob(samples[i] + EPS) 103 | logit += tf.reduce_sum(tf.reshape(logitdiff, (batch_dim, -1)), axis=1) 104 | return tf.math.softmax(logit) 105 | 106 | 107 | class TrainableInputLayer(tf.keras.layers.Layer): 108 | """Create trainable input layer for :py:class:`tfp.distributions.Distribution` 109 | 110 | Given a fake input, this returns a trainable weight initialized to ``initial_value``. Use 111 | it to feed into distributions that can be trained. 112 | 113 | :param initial_value: starting value, determines shape/dtype of output 114 | :param constraint: Callable that maps the weight to a constrained value. See :py:class:`tf.keras.layers.Layer` 115 | :param kwargs: See :py:class:`tf.keras.layers.Layer` for additional arguments 116 | """ 117 | 118 | def __init__( 119 | self, 120 | initial_value: Array, 121 | constraint: Callable[[Array], Array] = None, 122 | **kwargs 123 | ): 124 | super(TrainableInputLayer, self).__init__(**kwargs) 125 | flat = initial_value.flatten() 126 | self.initial_value = initial_value 127 | self.w = self.add_weight( 128 | "value", 129 | shape=initial_value.shape, 130 | initializer=tf.constant_initializer(flat), 131 | constraint=constraint, 132 | dtype=self.dtype, 133 | trainable=True, 134 | ) 135 | 136 | def call(self, inputs: Array) -> Array: 137 | """See call of :class:`tf.keras.layers.Layer`""" 138 | batch_dim = tf.shape(inputs)[:1] 139 | return tf.tile( 140 | self.w[tf.newaxis, ...], 141 | tf.concat((batch_dim, tf.ones(tf.rank(self.w), dtype=tf.int32)), axis=0), 142 | ) 143 | 144 | 145 | class HyperMaxentModel(MaxentModel): 146 | """Keras Maximum entropy model with a trainable (variational) prior 147 | 148 | :param restraints: List of :class:`Restraint` 149 | :param prior_model: :class:`ParameterJoint` that specifies prior 150 | :param simulation: Callable that will generate observations given the output from ``prior_model`` 151 | :param reweight: If True, reweight samples to remove the effect of training updates to ``prior_model``, keeping the result as close as possible to the given untrained ``prior_model`` 152 | :param name: Name of model 153 | """ 154 | 155 | def __init__( 156 | self, 157 | restraints: List[Restraint], 158 | prior_model: ParameterJoint, 159 | simulation: Callable[[Array], Array], 160 | reweight: bool = True, 161 | name: str = "hyper-maxent-model", 162 | **kwargs 163 | ): 164 | super(HyperMaxentModel, self).__init__( 165 | restraints=restraints, name=name, **kwargs 166 | ) 167 | self.prior_model = prior_model 168 | self.reweight = reweight 169 | self.unbiased_joint = prior_model(tf.constant([1.0])) 170 | # self.trajs = trajs 171 | if hasattr(self.unbiased_joint, "sample"): 172 | self.unbiased_joint = [self.unbiased_joint] 173 | self.simulation = simulation 174 | 175 | def fit( 176 | self, 177 | sample_batch_size: int = 256, 178 | final_batch_multiplier: int = 4, 179 | param_epochs: int = None, 180 | outer_epochs: int = 10, 181 | **kwargs 182 | ) -> tf.keras.callbacks.History: 183 | """Fit to given outcomes from ``simulation`` 184 | 185 | :param sample_batch_size: Number of observations to sample per outer epoch 186 | :param final_batch_multiplier: Sets the number of final MaxEnt fitting steps after training ``prior_model``. 
Final number of MaxEnt steps will be ``final_batch_multiplier * sample_batch_size`` 187 | :param param_epochs: Number of times ``prior_model`` will be fit to sampled observations 188 | :param outer_epochs: Number of loops of sampling/``prior_model`` fitting 189 | :param kwargs: See :class:tf.keras.Model ``fit`` method for further optional arguments, like ``verbose=0`` to hide output 190 | :return: The :class:`tf.keras.callbacks.History` of fit 191 | """ 192 | me_history, prior_history = None, None 193 | 194 | # backwards compatible for my bad spelling 195 | if "outter_epochs" in kwargs: 196 | outer_epochs = kwargs["outter_epochs"] 197 | del kwargs["outter_epochs"] 198 | 199 | # we want to reset optimizer state each time we have 200 | # new trajectories 201 | # but compile, new object assignment both 202 | # don't work. 203 | # So I guess use SGD? 204 | 205 | def new_optimizer(): 206 | return self.optimizer.__class__(**self.optimizer.get_config()) 207 | 208 | uni_flags = ["verbose"] 209 | prior_kwargs = {} 210 | for u in uni_flags: 211 | if u in kwargs: 212 | prior_kwargs[u] = kwargs[u] 213 | 214 | if param_epochs is None: 215 | param_epochs = 10 216 | if "epochs" in kwargs: 217 | param_epochs = kwargs["epochs"] 218 | for i in range(outer_epochs - 1): 219 | # sample parameters 220 | psample, y, joint = self.prior_model.sample(sample_batch_size, True) 221 | trajs = self.simulation(*psample) 222 | try: 223 | if trajs.shape[0] != sample_batch_size: 224 | raise ValueError( 225 | "Simulation must take in batched samples and return batched outputs" 226 | ) 227 | except TypeError as e: 228 | raise ValueError( 229 | "Simulation must take in batched samples and return batched outputs" 230 | ) 231 | # get reweight, so we keep original parameter 232 | # probs 233 | rw = _reweight(y, self.unbiased_joint, joint) 234 | # TODO reset optimizer state 235 | if self.reweight: 236 | hm = super(HyperMaxentModel, self).fit(trajs, rw, **kwargs) 237 | else: 238 | hm = super(HyperMaxentModel, self).fit(trajs, **kwargs) 239 | fake_x = tf.constant(sample_batch_size * [1.0]) 240 | hp = self.prior_model.fit( 241 | fake_x, 242 | y, 243 | sample_weight=self.traj_weights, 244 | epochs=param_epochs, 245 | **prior_kwargs 246 | ) 247 | if me_history is None: 248 | me_history = hm 249 | prior_history = hp 250 | else: 251 | me_history = _merge_history(me_history, hm) 252 | prior_history = _merge_history(prior_history, hp) 253 | 254 | # For final fit use more samples 255 | outs = [] 256 | rws = [] 257 | for i in range(final_batch_multiplier): 258 | psample, y, joint = self.prior_model.sample(sample_batch_size, True) 259 | trajs = self.simulation(*psample) 260 | outs.append(trajs) 261 | rw = _reweight(y, self.unbiased_joint, joint) 262 | rws.append(rw) 263 | trajs = np.concatenate(outs, axis=0) 264 | rw = np.concatenate(rws, axis=0) 265 | self.weights_hyper = rw 266 | self.trajs = trajs 267 | # TODO reset optimizer state 268 | self.reset_weights() 269 | if self.reweight: 270 | hm = super(HyperMaxentModel, self).fit(trajs, rw, **kwargs) 271 | else: 272 | hm = super(HyperMaxentModel, self).fit(trajs, **kwargs) 273 | me_history = _merge_history(me_history, hm) 274 | return _merge_history(me_history, prior_history, "prior-") 275 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.8 3 | # To update, run: 4 | # 5 | # pip-compile 
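To tie `maxent/hyper.py` together, here is the variational pattern condensed from `paper/gaussian.ipynb` above: a trainable prior built from `TrainableInputLayer` and a `tfp` distribution, wrapped in `ParameterJoint`, then trained through `HyperMaxentModel` against a batched simulator. Every call mirrors one that appears in this repository; only the simulator and the target value of 10 are illustrative.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import maxent

tfd = tfp.distributions

# trainable prior: Normal(loc, exp(log_scale)); both entries are learnable
start = np.array([1.0, 1.0])
i = tf.keras.Input((1,))
l = maxent.TrainableInputLayer(start)(i)
d = tfp.layers.DistributionLambda(
    lambda t: tfd.Normal(loc=t[..., 0], scale=tf.math.exp(t[..., 1]))
)(l)
prior_model = maxent.ParameterJoint([lambda x: x], inputs=i, outputs=[d])
prior_model.compile(tf.keras.optimizers.Adam(0.1))


def simulate(theta):
    # batched simulator: parameters in, observations out (same leading dim)
    return np.random.normal(loc=theta, scale=0.1)


r = maxent.Restraint(lambda x: x, 10, maxent.EmptyPrior())
hme = maxent.HyperMaxentModel([r], prior_model, simulate)
hme.compile(tf.keras.optimizers.SGD(0.005), "mean_squared_error")
hme.fit(sample_batch_size=256, outer_epochs=10, epochs=4, verbose=0)

# reweighted expectation over the final sampled trajectories, near 10
print(np.sum(hme.traj_weights * np.squeeze(hme.trajs)))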
requirements.in 6 | # 7 | absl-py==1.0.0 8 | # via tensorboard 9 | alabaster==0.7.12 10 | # via sphinx 11 | anyio==3.5.0 12 | # via jupyter-server 13 | argon2-cffi==21.3.0 14 | # via 15 | # jupyter-server 16 | # notebook 17 | argon2-cffi-bindings==21.2.0 18 | # via argon2-cffi 19 | asttokens==2.0.5 20 | # via stack-data 21 | async-timeout==4.0.2 22 | # via redis 23 | attrs==21.4.0 24 | # via 25 | # jsonschema 26 | # jupyter-cache 27 | # markdown-it-py 28 | babel==2.9.1 29 | # via sphinx 30 | backcall==0.2.0 31 | # via ipython 32 | beautifulsoup4==4.10.0 33 | # via nbconvert 34 | bleach==4.1.0 35 | # via nbconvert 36 | cachetools==5.0.0 37 | # via google-auth 38 | certifi==2021.10.8 39 | # via requests 40 | cffi==1.15.0 41 | # via argon2-cffi-bindings 42 | charset-normalizer==2.0.12 43 | # via requests 44 | click==8.1.2 45 | # via 46 | # distributed 47 | # pyabc 48 | cloudpickle==2.0.0 49 | # via 50 | # dask 51 | # distributed 52 | # pyabc 53 | colorama==0.4.4 54 | # via nbdime 55 | cycler==0.11.0 56 | # via matplotlib 57 | dask==2022.04.0 58 | # via distributed 59 | debugpy==1.6.0 60 | # via ipykernel 61 | decorator==5.1.1 62 | # via ipython 63 | defusedxml==0.7.1 64 | # via nbconvert 65 | deprecated==1.2.13 66 | # via redis 67 | distributed==2022.4.0 68 | # via pyabc 69 | docutils==0.17.1 70 | # via 71 | # myst-nb 72 | # myst-parser 73 | # sphinx 74 | # sphinx-rtd-theme 75 | # sphinx-togglebutton 76 | entrypoints==0.4 77 | # via 78 | # jupyter-client 79 | # nbconvert 80 | executing==0.8.3 81 | # via stack-data 82 | fonttools==4.31.2 83 | # via matplotlib 84 | fsspec==2022.3.0 85 | # via dask 86 | gitdb==4.0.9 87 | # via gitpython 88 | gitpython==3.1.27 89 | # via 90 | # nbdime 91 | # pyabc 92 | google-auth==2.6.2 93 | # via 94 | # google-auth-oauthlib 95 | # tensorboard 96 | google-auth-oauthlib==0.4.6 97 | # via tensorboard 98 | greenlet==1.1.2 99 | # via sqlalchemy 100 | grpcio==1.44.0 101 | # via tensorboard 102 | heapdict==1.0.1 103 | # via zict 104 | idna==3.3 105 | # via 106 | # anyio 107 | # requests 108 | imagesize==1.3.0 109 | # via sphinx 110 | importlib-metadata==4.11.3 111 | # via 112 | # markdown 113 | # myst-nb 114 | # sphinx 115 | importlib-resources==5.6.0 116 | # via jsonschema 117 | ipykernel==6.11.0 118 | # via 119 | # ipywidgets 120 | # notebook 121 | ipython==8.2.0 122 | # via 123 | # ipykernel 124 | # ipywidgets 125 | # jupyter-sphinx 126 | # myst-nb 127 | ipython-genutils==0.2.0 128 | # via 129 | # ipywidgets 130 | # notebook 131 | ipywidgets==7.7.0 132 | # via 133 | # jupyter-sphinx 134 | # myst-nb 135 | jabbar==0.0.15 136 | # via pyabc 137 | jedi==0.18.1 138 | # via ipython 139 | jinja2==3.1.1 140 | # via 141 | # distributed 142 | # jupyter-server 143 | # myst-parser 144 | # nbconvert 145 | # nbdime 146 | # notebook 147 | # sphinx 148 | joblib==1.1.0 149 | # via 150 | # sbi 151 | # scikit-learn 152 | jsonschema==4.4.0 153 | # via nbformat 154 | jupyter-cache==0.4.3 155 | # via myst-nb 156 | jupyter-client==7.2.1 157 | # via 158 | # ipykernel 159 | # jupyter-server 160 | # nbclient 161 | # notebook 162 | jupyter-core==4.9.2 163 | # via 164 | # jupyter-client 165 | # jupyter-server 166 | # nbconvert 167 | # nbformat 168 | # notebook 169 | jupyter-server==1.16.0 170 | # via 171 | # jupyter-server-mathjax 172 | # nbdime 173 | jupyter-server-mathjax==0.2.5 174 | # via nbdime 175 | jupyter-sphinx==0.3.2 176 | # via myst-nb 177 | jupyterlab-pygments==0.1.2 178 | # via nbconvert 179 | jupyterlab-widgets==1.1.0 180 | # via ipywidgets 181 | kiwisolver==1.4.2 182 | # via 
matplotlib 183 | locket==0.2.1 184 | # via partd 185 | markdown==3.3.6 186 | # via tensorboard 187 | markdown-it-py==1.1.0 188 | # via 189 | # mdit-py-plugins 190 | # myst-parser 191 | markupsafe==2.1.1 192 | # via 193 | # jinja2 194 | # nbconvert 195 | matplotlib==3.5.1 196 | # via 197 | # maxentep 198 | # nflows 199 | # pyabc 200 | # pyknos 201 | # sbi 202 | # seaborn 203 | matplotlib-inline==0.1.3 204 | # via 205 | # ipykernel 206 | # ipython 207 | maxentep @ git+https://github.com/ur-whitelab/py0.git@nature_compsci 208 | # via -r requirements.in 209 | mdit-py-plugins==0.2.8 210 | # via myst-parser 211 | mistune==0.8.4 212 | # via nbconvert 213 | msgpack==1.0.3 214 | # via distributed 215 | myst-nb==0.13.2 216 | # via -r requirements.in 217 | myst-parser==0.15.2 218 | # via 219 | # -r requirements.in 220 | # myst-nb 221 | nbclient==0.5.13 222 | # via 223 | # jupyter-cache 224 | # nbconvert 225 | nbconvert==6.4.5 226 | # via 227 | # jupyter-server 228 | # jupyter-sphinx 229 | # myst-nb 230 | # notebook 231 | nbdime==3.1.1 232 | # via jupyter-cache 233 | nbformat==5.2.0 234 | # via 235 | # ipywidgets 236 | # jupyter-cache 237 | # jupyter-server 238 | # jupyter-sphinx 239 | # myst-nb 240 | # nbclient 241 | # nbconvert 242 | # nbdime 243 | # notebook 244 | nest-asyncio==1.5.5 245 | # via 246 | # ipykernel 247 | # jupyter-client 248 | # nbclient 249 | # notebook 250 | nflows==0.14 251 | # via pyknos 252 | notebook==6.4.10 253 | # via widgetsnbextension 254 | numpy==1.22.3 255 | # via 256 | # matplotlib 257 | # maxentep 258 | # nflows 259 | # opt-einsum 260 | # pandas 261 | # pyabc 262 | # pyknos 263 | # pyro-ppl 264 | # sbi 265 | # scikit-learn 266 | # scipy 267 | # seaborn 268 | # tensorboard 269 | oauthlib==3.2.0 270 | # via requests-oauthlib 271 | opt-einsum==3.3.0 272 | # via pyro-ppl 273 | packaging==21.3 274 | # via 275 | # bleach 276 | # dask 277 | # distributed 278 | # jupyter-server 279 | # matplotlib 280 | # redis 281 | # sphinx 282 | pandas==1.4.2 283 | # via 284 | # pyabc 285 | # seaborn 286 | pandocfilters==1.5.0 287 | # via nbconvert 288 | parso==0.8.3 289 | # via jedi 290 | partd==1.2.0 291 | # via dask 292 | pexpect==4.8.0 293 | # via ipython 294 | pickleshare==0.7.5 295 | # via ipython 296 | pillow==9.1.0 297 | # via 298 | # matplotlib 299 | # sbi 300 | prometheus-client==0.13.1 301 | # via 302 | # jupyter-server 303 | # notebook 304 | prompt-toolkit==3.0.28 305 | # via ipython 306 | protobuf==3.20.0 307 | # via tensorboard 308 | psutil==5.9.0 309 | # via 310 | # distributed 311 | # ipykernel 312 | ptyprocess==0.7.0 313 | # via 314 | # pexpect 315 | # terminado 316 | pure-eval==0.2.2 317 | # via stack-data 318 | pyabc==0.12.2 319 | # via -r requirements.in 320 | pyasn1==0.4.8 321 | # via 322 | # pyasn1-modules 323 | # rsa 324 | pyasn1-modules==0.2.8 325 | # via google-auth 326 | pycparser==2.21 327 | # via cffi 328 | pygments==2.11.2 329 | # via 330 | # ipython 331 | # jupyterlab-pygments 332 | # nbconvert 333 | # nbdime 334 | # sphinx 335 | pyknos==0.14.2 336 | # via sbi 337 | pynmrstar==3.3.0 338 | # via -r requirements.in 339 | pyparsing==3.0.7 340 | # via 341 | # matplotlib 342 | # packaging 343 | pyro-api==0.1.2 344 | # via pyro-ppl 345 | pyro-ppl==1.8.1 346 | # via sbi 347 | pyrsistent==0.18.1 348 | # via jsonschema 349 | python-dateutil==2.8.2 350 | # via 351 | # jupyter-client 352 | # matplotlib 353 | # pandas 354 | pytz==2022.1 355 | # via 356 | # babel 357 | # pandas 358 | pyyaml==6.0 359 | # via 360 | # dask 361 | # distributed 362 | # myst-nb 363 | # myst-parser 
364 | pyzmq==22.3.0 365 | # via 366 | # jupyter-client 367 | # jupyter-server 368 | # notebook 369 | redis==4.2.1 370 | # via pyabc 371 | requests==2.27.1 372 | # via 373 | # nbdime 374 | # pynmrstar 375 | # requests-oauthlib 376 | # sphinx 377 | # tensorboard 378 | requests-oauthlib==1.3.1 379 | # via google-auth-oauthlib 380 | rsa==4.8 381 | # via google-auth 382 | sbi==0.18.0 383 | # via -r requirements.in 384 | scikit-learn==1.0.2 385 | # via 386 | # pyabc 387 | # sbi 388 | scipy==1.7.3 389 | # via 390 | # -r requirements.in 391 | # maxentep 392 | # pyabc 393 | # sbi 394 | # scikit-learn 395 | # seaborn 396 | seaborn==0.11.2 397 | # via -r requirements.in 398 | send2trash==1.8.0 399 | # via 400 | # jupyter-server 401 | # notebook 402 | six==1.16.0 403 | # via 404 | # absl-py 405 | # asttokens 406 | # bleach 407 | # google-auth 408 | # grpcio 409 | # python-dateutil 410 | smmap==5.0.0 411 | # via gitdb 412 | sniffio==1.2.0 413 | # via anyio 414 | snowballstemmer==2.2.0 415 | # via sphinx 416 | sortedcontainers==2.4.0 417 | # via distributed 418 | soupsieve==2.3.1 419 | # via beautifulsoup4 420 | sphinx==4.5.0 421 | # via 422 | # -r requirements.in 423 | # jupyter-sphinx 424 | # myst-nb 425 | # myst-parser 426 | # sphinx-autodoc-typehints 427 | # sphinx-rtd-theme 428 | # sphinx-togglebutton 429 | sphinx-autodoc-typehints==1.17.0 430 | # via -r requirements.in 431 | sphinx-rtd-theme==1.0.0 432 | # via -r requirements.in 433 | sphinx-togglebutton==0.3.1 434 | # via myst-nb 435 | sphinxcontrib-applehelp==1.0.2 436 | # via sphinx 437 | sphinxcontrib-devhelp==1.0.2 438 | # via sphinx 439 | sphinxcontrib-htmlhelp==2.0.0 440 | # via sphinx 441 | sphinxcontrib-jsmath==1.0.1 442 | # via sphinx 443 | sphinxcontrib-qthelp==1.0.3 444 | # via sphinx 445 | sphinxcontrib-serializinghtml==1.1.5 446 | # via sphinx 447 | sqlalchemy==1.4.34 448 | # via 449 | # jupyter-cache 450 | # pyabc 451 | stack-data==0.2.0 452 | # via ipython 453 | tblib==1.7.0 454 | # via distributed 455 | tensorboard==2.8.0 456 | # via 457 | # nflows 458 | # pyknos 459 | # sbi 460 | tensorboard-data-server==0.6.1 461 | # via tensorboard 462 | tensorboard-plugin-wit==1.8.1 463 | # via tensorboard 464 | terminado==0.13.3 465 | # via 466 | # jupyter-server 467 | # notebook 468 | testpath==0.6.0 469 | # via nbconvert 470 | threadpoolctl==3.1.0 471 | # via scikit-learn 472 | toolz==0.11.2 473 | # via 474 | # dask 475 | # distributed 476 | # partd 477 | torch==1.11.0 478 | # via 479 | # nflows 480 | # pyknos 481 | # pyro-ppl 482 | # sbi 483 | tornado==6.1 484 | # via 485 | # distributed 486 | # ipykernel 487 | # jupyter-client 488 | # jupyter-server 489 | # nbdime 490 | # notebook 491 | # terminado 492 | tqdm==4.63.1 493 | # via 494 | # maxentep 495 | # nflows 496 | # pyknos 497 | # pyro-ppl 498 | # sbi 499 | traitlets==5.1.1 500 | # via 501 | # ipykernel 502 | # ipython 503 | # ipywidgets 504 | # jupyter-client 505 | # jupyter-core 506 | # jupyter-server 507 | # matplotlib-inline 508 | # nbclient 509 | # nbconvert 510 | # nbformat 511 | # notebook 512 | typing-extensions==4.1.1 513 | # via torch 514 | urllib3==1.26.9 515 | # via 516 | # distributed 517 | # requests 518 | wcwidth==0.2.5 519 | # via prompt-toolkit 520 | webencodings==0.5.1 521 | # via bleach 522 | websocket-client==1.3.2 523 | # via jupyter-server 524 | werkzeug==2.1.1 525 | # via tensorboard 526 | wheel==0.37.1 527 | # via 528 | # sphinx-togglebutton 529 | # tensorboard 530 | widgetsnbextension==3.6.0 531 | # via ipywidgets 532 | wrapt==1.14.0 533 | # via deprecated 534 | 
zict==2.1.0 535 | # via distributed 536 | zipp==3.8.0 537 | # via 538 | # importlib-metadata 539 | # importlib-resources 540 | 541 | # The following packages are considered to be unsafe in a requirements file: 542 | # setuptools 543 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 
58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. 
But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 
176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. 
If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 
287 |  288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 |  293 | <one line to give the program's name and a brief idea of what it does.> 294 | Copyright (C) <year> <name of author> 295 |  296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 |  301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 |  306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 |  310 | Also add information on how to contact you by electronic and paper mail. 311 |  312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 |  315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 |  320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 |  325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 |  329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 |  332 | <signature of Ty Coon>, 1 April 1989 333 | Ty Coon, President of Vice 334 |  335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License.
340 | -------------------------------------------------------------------------------- /paper/gravitation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Gravitation Example" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import os\n", 18 | "import tensorflow as tf\n", 19 | "import torch\n", 20 | "import maxent\n", 21 | "from sbi_gravitation import GravitySimulator, sim_wrapper, get_observation_points\n", 22 | "from torch.distributions.multivariate_normal import MultivariateNormal\n", 23 | "from sbi.inference import infer\n", 24 | "import scipy\n", 25 | "\n", 26 | "import pandas as pd\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "from matplotlib.lines import Line2D\n", 29 | "import seaborn as sns\n", 30 | "from functools import partialmethod\n", 31 | "from tqdm import tqdm\n", 32 | "\n", 33 | "tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)\n", 34 | "\n", 35 | "np.random.seed(12656)\n", 36 | "sns.set_context(\"paper\")\n", 37 | "sns.set_style(\n", 38 | " \"white\",\n", 39 | " {\n", 40 | " \"xtick.bottom\": True,\n", 41 | " \"ytick.left\": True,\n", 42 | " \"xtick.color\": \"#333333\",\n", 43 | " \"ytick.color\": \"#333333\",\n", 44 | " },\n", 45 | ")\n", 46 | "plt.rcParams[\"mathtext.fontset\"] = \"dejavuserif\"\n", 47 | "colors = [\"#1b9e77\", \"#d95f02\", \"#7570b3\", \"#e7298a\", \"#66a61e\"]" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# set up true parameters\n", 57 | "m1 = 100.0 # solar masses\n", 58 | "m2 = 50.0 # solar masses\n", 59 | "m3 = 75 # solar masses\n", 60 | "G = 1.90809e5 # solar radius / solar mass * (km/s)^2\n", 61 | "v0 = np.array([15.0, -40.0]) # km/s\n", 62 | "\n", 63 | "true_params = [m1, m2, m3, v0[0], v0[1]]\n", 64 | "\n", 65 | "# set prior means\n", 66 | "prior_means = [85.0, 40.0, 70.0, 12.0, -30.0]\n", 67 | "prior_cov = np.eye(5) * 50" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "# generate true trajectory and apply some noise to it\n", 77 | "if os.path.exists(\"true_trajectory.txt\"):\n", 78 | " true_traj = np.genfromtxt(\"true_trajectory.txt\")\n", 79 | "else:\n", 80 | " sim = GravitySimulator(m1, m2, m3, v0, G, random_noise=False)\n", 81 | " true_traj = sim.run()\n", 82 | " np.savetxt(\"true_trajectory.txt\", true_traj)\n", 83 | "\n", 84 | "if os.path.exists(\"noisy_trajectory.txt\"):\n", 85 | " noisy_traj = np.genfromtxt(\"noisy_trajectory.txt\")\n", 86 | "else:\n", 87 | " sim = GravitySimulator(m1, m2, m3, v0, G, random_noise=True)\n", 88 | " noisy_traj = sim.run()\n", 89 | " np.savetxt(\"noisy_trajectory.txt\", noisy_traj)\n", 90 | "\n", 91 | "observed_points = get_observation_points(noisy_traj)\n", 92 | "observation_summary_stats = observed_points.flatten()\n", 93 | "sim = GravitySimulator(m1, m2, m3, v0, G, random_noise=False)\n", 94 | "sim.run()\n", 95 | "sim.plot_traj()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "tags": [ 103 | "hide-output" 104 | ] 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "# perform SNL inference\n", 109 | "prior = MultivariateNormal(\n", 110 | " loc=torch.as_tensor(prior_means),\n", 111 | " 
covariance_matrix=torch.eye(5) * 50,  # torch.eye already returns a tensor\n", 112 |    ")\n", 113 |    "\n", 114 |    "posterior = infer(\n", 115 |    "    sim_wrapper, prior, method=\"SNLE\", num_simulations=2048, num_workers=16\n", 116 |    ")" 117 |   ] 118 |  }, 119 |  { 120 |   "cell_type": "code", 121 |   "execution_count": null, 122 |   "metadata": { 123 |    "tags": [ 124 |     "hide-output" 125 |    ] 126 |   }, 127 |   "outputs": [], 128 |   "source": [ 129 |    "# sample from SNL posterior\n", 130 |    "samples = posterior.sample((2000,), x=observation_summary_stats)\n", 131 |    "snl_data = np.array(samples)\n", 132 |    "np.savetxt(\"wide_prior_samples.txt\", snl_data)" 133 |   ] 134 |  }, 135 |  { 136 |   "cell_type": "code", 137 |   "execution_count": null, 138 |   "metadata": {}, 139 |   "outputs": [], 140 |   "source": [ 141 |    "# set up restraints for maxent\n", 142 |    "# restraint structure: [value, uncertainty, time index, coordinate index]\n", 143 |    "restraints = []\n", 144 |    "for i, point in enumerate(observed_points):\n", 145 |    "    value1 = point[0]\n", 146 |    "    value2 = point[1]\n", 147 |    "    uncertainty = 25\n", 148 |    "    index = 20 * i + 19  # based on how we slice in get_observation_points()\n", 149 |    "    restraints.append([value1, uncertainty, index, 0])\n", 150 |    "    restraints.append([value2, uncertainty, index, 1])" 151 |   ] 152 |  }, 153 |  { 154 |   "cell_type": "code", 155 |   "execution_count": null, 156 |   "metadata": {}, 157 |   "outputs": [], 158 |   "source": [ 159 |    "# set up maxent restraints\n", 160 |    "maxent_restraints = []\n", 161 |    "\n", 162 |    "for i in range(len(restraints)):\n", 163 |    "    traj_index = tuple(restraints[i][2:])\n", 164 |    "    value = restraints[i][0]\n", 165 |    "    uncertainty = restraints[i][1]\n", 166 |    "    p = maxent.EmptyPrior()\n", 167 |    "    r = maxent.Restraint(lambda traj, i=traj_index: traj[i], value, p)  # i=traj_index binds the index at definition time\n", 168 |    "    maxent_restraints.append(r)" 169 |   ] 170 |  }, 171 |  { 172 |   "cell_type": "code", 173 |   "execution_count": null, 174 |   "metadata": {}, 175 |   "outputs": [], 176 |   "source": [ 177 |    "# sample from prior for maxent\n", 178 |    "if os.path.exists(\"maxent_prior_samples.npy\"):\n", 179 |    "    prior_dist = np.load(\"maxent_prior_samples.npy\")\n", 180 |    "else:\n", 181 |    "    prior_dist = np.random.multivariate_normal(prior_means, prior_cov, size=2048)\n", 182 |    "    np.save(\"maxent_prior_samples.npy\", prior_dist)" 183 |   ] 184 |  }, 185 |  { 186 |   "cell_type": "code", 187 |   "execution_count": null, 188 |   "metadata": {}, 189 |   "outputs": [], 190 |   "source": [ 191 |    "# generate trajectories for maxent from prior samples\n", 192 |    "trajs = np.zeros([prior_dist.shape[0], 100, 2])\n", 193 |    "\n", 194 |    "for i, sample in enumerate(prior_dist):\n", 195 |    "    m1, m2, m3, v0 = sample[0], sample[1], sample[2], sample[3:]\n", 196 |    "    sim = GravitySimulator(m1, m2, m3, v0, random_noise=False)\n", 197 |    "    traj = sim.run()\n", 198 |    "    trajs[i] = traj\n", 199 |    "\n", 200 |    "maxent_trajs = trajs\n", 201 |    "np.save(\"maxent_raw_trajectories.npy\", trajs)" 202 |   ] 203 |  },
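 { "cell_type": "markdown", "metadata": {}, "source": [ "*Editor's note: the next cell is an added sanity check, not part of the original analysis.* It verifies that the prior ensemble holds one 100-step (x, y) trajectory per sampled parameter vector before any reweighting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# added sanity check (editor's sketch, not in the original notebook):\n", "# one trajectory per prior sample, 100 time steps, 2 coordinates\n", "assert trajs.shape == (prior_dist.shape[0], 100, 2)\n", "print(\"parameter sample means:\", prior_dist.mean(axis=0).round(1))" ] },
 204 |  { 205 |   "cell_type": "code", 206 |   "execution_count": null, 207 |   "metadata": { 208 |    "scrolled": true 209 |   }, 210 |   "outputs": [], 211 |   "source": [ 212 |    "# run maxent on trajectories\n", 213 |    "batch_size = prior_dist.shape[0]\n", 214 |    "\n", 215 |    "model = maxent.MaxentModel(maxent_restraints)\n", 216 |    "model.compile(tf.keras.optimizers.Adam(1e-4), \"mean_squared_error\")\n", 217 |    "# short burn-in\n", 218 |    "h = model.fit(trajs, batch_size=batch_size, epochs=5000, verbose=0)\n", 219 |    "# restart to reset learning rate\n", 220 |    "h = model.fit(trajs, batch_size=batch_size, epochs=25000, verbose=0)\n", 221 | 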
"\n", 222 | "np.savetxt(\"maxent_loss.txt\", h.history[\"loss\"])\n", 223 | "\n", 224 | "maxent_weights = model.traj_weights\n", 225 | "np.savetxt(\"maxent_traj_weights.txt\", maxent_weights)\n", 226 | "\n", 227 | "maxent_avg_traj = np.sum(trajs * maxent_weights[:, np.newaxis, np.newaxis], axis=0)\n", 228 | "np.savetxt(\"maxent_avg_traj.txt\", maxent_avg_traj)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "### Plotting Results" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "# simulate traj generated by prior means\n", 245 | "sim = GravitySimulator(prior_means[0], prior_means[1], prior_means[2], prior_means[3:])\n", 246 | "prior_means_traj = sim.run()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "# simulate trajectories from SNL samples\n", 256 | "snl_trajs = np.zeros([snl_data.shape[0], noisy_traj.shape[0], noisy_traj.shape[1]])\n", 257 | "for i, sample in enumerate(snl_data):\n", 258 | " m1, m2, m3, v0 = sample[0], sample[1], sample[2], [sample[3], sample[4]]\n", 259 | " sim = GravitySimulator(m1, m2, m3, v0)\n", 260 | " traj = sim.run()\n", 261 | " snl_trajs[i] = traj\n", 262 | "\n", 263 | "mean_snl_traj = np.mean(snl_trajs, axis=0)\n", 264 | "np.savetxt(\"mean_snl_traj.txt\", mean_snl_traj)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "alpha_val = 0.7\n", 274 | "fig, axes = plt.subplots(figsize=(5, 3), dpi=300)\n", 275 | "\n", 276 | "# plot the observation points\n", 277 | "axes.scatter(\n", 278 | " observed_points[:, 0],\n", 279 | " observed_points[:, 1],\n", 280 | " color=\"black\",\n", 281 | " zorder=10,\n", 282 | " marker=\"*\",\n", 283 | " label=\"Observed Points\",\n", 284 | ")\n", 285 | "\n", 286 | "# plot the trajectory generated by prior means\n", 287 | "sim.set_traj(prior_means_traj)\n", 288 | "sim.plot_traj(\n", 289 | " fig=fig,\n", 290 | " axes=axes,\n", 291 | " make_colorbar=False,\n", 292 | " save=False,\n", 293 | " cmap=plt.get_cmap(\"Greys\").reversed(),\n", 294 | " color=\"grey\",\n", 295 | " fade_lines=False,\n", 296 | " alpha=alpha_val,\n", 297 | " linestyle=\"-.\",\n", 298 | " linewidth=1,\n", 299 | " label=\"Prior Mean\",\n", 300 | ")\n", 301 | "\n", 302 | "# plot the SNL mean trajectory\n", 303 | "sim.set_traj(mean_snl_traj)\n", 304 | "sim.plot_traj(\n", 305 | " fig=fig,\n", 306 | " axes=axes,\n", 307 | " make_colorbar=False,\n", 308 | " save=False,\n", 309 | " cmap=plt.get_cmap(\"Greens\").reversed(),\n", 310 | " color=colors[0],\n", 311 | " fade_lines=False,\n", 312 | " linewidth=1,\n", 313 | " alpha=alpha_val,\n", 314 | " label=\"SNL\",\n", 315 | ")\n", 316 | "\n", 317 | "# plot the true trajectory\n", 318 | "sim.set_traj(true_traj)\n", 319 | "sim.plot_traj(\n", 320 | " fig=fig,\n", 321 | " axes=axes,\n", 322 | " make_colorbar=False,\n", 323 | " save=False,\n", 324 | " cmap=plt.get_cmap(\"Reds\").reversed(),\n", 325 | " color=\"black\",\n", 326 | " fade_lines=False,\n", 327 | " alpha=alpha_val,\n", 328 | " linestyle=\":\",\n", 329 | " linewidth=1,\n", 330 | " label=\"True Path\",\n", 331 | " label_attractors=False,\n", 332 | ")\n", 333 | "\n", 334 | "# plot the maxent average trajectory\n", 335 | "sim.set_traj(maxent_avg_traj)\n", 336 | "sim.plot_traj(\n", 337 | " fig=fig,\n", 338 | " 
axes=axes,\n", 339 | " make_colorbar=False,\n", 340 | " save=False,\n", 341 | " cmap=plt.get_cmap(\"Oranges\").reversed(),\n", 342 | " color=colors[2],\n", 343 | " fade_lines=False,\n", 344 | " alpha=alpha_val,\n", 345 | " linestyle=\"-\",\n", 346 | " linewidth=1,\n", 347 | " label=\"MaxEnt\",\n", 348 | " label_attractors=True,\n", 349 | ")\n", 350 | "\n", 351 | "# set limits manually\n", 352 | "axes.set_xlim(-5, 130)\n", 353 | "axes.set_ylim(-30, 75)\n", 354 | "\n", 355 | "plt.legend(loc=\"upper left\", bbox_to_anchor=(1.05, 1.0))\n", 356 | "plt.tight_layout()\n", 357 | "\n", 358 | "# plt.savefig('paths_compare.png')\n", 359 | "# plt.savefig('paths_compare.svg')\n", 360 | "plt.show()" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "# set up KDE plotting of posteriors\n", 370 | "column_names = [\"m1\", \"m2\", \"m3\", \"v0x\", \"v0y\"]\n", 371 | "\n", 372 | "snl_dist = np.array(snl_data)\n", 373 | "snl_frame = pd.DataFrame(snl_dist, columns=column_names)\n", 374 | "\n", 375 | "maxent_dist = np.load(\"maxent_prior_samples.npy\")\n", 376 | "maxent_frame = pd.DataFrame(maxent_dist, columns=column_names)\n", 377 | "\n", 378 | "fig, axes = plt.subplots(nrows=5, ncols=1, figsize=(5, 5), dpi=300, sharex=False)\n", 379 | "\n", 380 | "# iterate over the five parameters\n", 381 | "n_bins = 30\n", 382 | "for i, key in enumerate(column_names):\n", 383 | " sns.histplot(\n", 384 | " data=snl_frame,\n", 385 | " x=key,\n", 386 | " ax=axes[i],\n", 387 | " color=colors[0],\n", 388 | " stat=\"probability\",\n", 389 | " element=\"step\",\n", 390 | " kde=True,\n", 391 | " fill=False,\n", 392 | " bins=n_bins,\n", 393 | " lw=1.0,\n", 394 | " )\n", 395 | " sns.histplot(\n", 396 | " data=maxent_frame,\n", 397 | " x=key,\n", 398 | " ax=axes[i],\n", 399 | " color=colors[2],\n", 400 | " stat=\"probability\",\n", 401 | " element=\"step\",\n", 402 | " kde=True,\n", 403 | " fill=False,\n", 404 | " bins=n_bins,\n", 405 | " weights=maxent_weights,\n", 406 | " lw=1.0,\n", 407 | " )\n", 408 | " sns.histplot(\n", 409 | " data=maxent_frame,\n", 410 | " x=key,\n", 411 | " ax=axes[i],\n", 412 | " color=colors[3],\n", 413 | " stat=\"probability\",\n", 414 | " element=\"step\",\n", 415 | " kde=True,\n", 416 | " fill=False,\n", 417 | " bins=n_bins,\n", 418 | " lw=1.0,\n", 419 | " )\n", 420 | " axes[i].axvline(prior_means[i], ls=\"-.\", color=\"grey\", lw=1.2)\n", 421 | " axes[i].axvline(true_params[i], ls=\":\", color=\"black\", lw=1.2)\n", 422 | " axes[i].set_xlabel(key)\n", 423 | "\n", 424 | "# custom lines object for making legend\n", 425 | "custom_lines = [\n", 426 | " Line2D([0], [0], color=colors[3], lw=2),\n", 427 | " Line2D([0], [0], color=colors[0], lw=2),\n", 428 | " Line2D([0], [0], color=colors[2], lw=2),\n", 429 | " Line2D([0], [0], color=\"black\", ls=\":\", lw=2),\n", 430 | " Line2D([0], [0], color=\"grey\", ls=\"-.\", lw=2),\n", 431 | "]\n", 432 | "axes[0].legend(\n", 433 | " custom_lines,\n", 434 | " [\"Prior\", \"SNL\", \"MaxEnt\", \"True Parameters\", \"Prior Mean\"],\n", 435 | " loc=\"upper left\",\n", 436 | " bbox_to_anchor=(1.05, 1.0),\n", 437 | ")\n", 438 | "plt.tight_layout()\n", 439 | "\n", 440 | "# plt.savefig('posterior_compare.png')\n", 441 | "# plt.savefig('posterior_compare.svg')\n", 442 | "plt.show()" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "metadata": {}, 449 | "outputs": [], 450 | "source": [ 451 | "# calculating cross-entropy values\n", 452 | "def 
get_crossent(\n", 453 |    "    prior_samples,\n", 454 |    "    posterior_samples,\n", 455 |    "    epsilon=1e-7,\n", 456 |    "    x_range=[-100, 100],\n", 457 |    "    nbins=40,\n", 458 |    "    post_weights=None,\n", 459 |    "):\n", 460 |    "    prior_dists = []\n", 461 |    "    posterior_dists = []\n", 462 |    "    crossents = []\n", 463 |    "    for i in range(5):\n", 464 |    "        prior_dist, _ = np.histogram(\n", 465 |    "            prior_samples[:, i], bins=nbins, range=x_range, density=True\n", 466 |    "        )\n", 467 |    "        prior_dists.append(prior_dist)\n", 468 |    "        posterior_dist, _ = np.histogram(\n", 469 |    "            posterior_samples[:, i],\n", 470 |    "            bins=nbins,\n", 471 |    "            range=x_range,\n", 472 |    "            density=True,\n", 473 |    "            weights=post_weights,\n", 474 |    "        )\n", 475 |    "        posterior_dists.append(posterior_dist)\n", 476 |    "        crossents.append(np.log(posterior_dist + epsilon) * (prior_dist + epsilon))\n", 477 |    "    return -np.sum(crossents)\n", 478 |    "\n", 479 |    "\n", 480 |    "snl_prior = np.random.multivariate_normal(\n", 481 |    "    mean=prior_means, cov=np.eye(5) * 50, size=snl_dist.shape[0]\n", 482 |    ")\n", 483 |    "snl_crossent = get_crossent(snl_prior, snl_dist)\n", 484 |    "\n", 485 |    "maxent_prior = np.random.multivariate_normal(prior_means, np.eye(5) * 50, size=2048)\n", 486 |    "maxent_crossent = get_crossent(maxent_prior, maxent_dist, post_weights=maxent_weights)  # the weights pair with the maxent_dist samples\n", 487 |    "\n", 488 |    "print(f\"CROSS-ENTROPY:\\nSNL: {snl_crossent}\\nMaxEnt: {maxent_crossent}\")\n", 489 |    "\n", 490 |    "crossent_values = [snl_crossent, maxent_crossent]\n", 491 |    "np.savetxt(\"crossent_values.txt\", np.array(crossent_values), header=\"SNL, MaxEnt\")" 492 |   ] 493 |  },
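 { "cell_type": "markdown", "metadata": {}, "source": [ "*Editor's note (added for clarity; not in the original notebook):* the value reported above is the cross-entropy $H(p, q) = -\\sum_i p_i \\log q_i$ between the histogrammed prior $p$ and posterior $q$. Because $H(p, q) = H(p) + D_{KL}(p \\| q)$, a lower value means a smaller KL divergence from the prior, i.e. a posterior that perturbs the prior less, which is the minimal-update behavior MaxEnt targets." ] },
 494 |  { 495 |   "cell_type": "markdown", 496 |   "metadata": {}, 497 |   "source": [ 498 |    "### MaxEnt With Variational" 499 |   ] 500 |  }, 501 |  { 502 |   "cell_type": "code", 503 |   "execution_count": null, 504 |   "metadata": {}, 505 |   "outputs": [], 506 |   "source": [ 507 |    "import tensorflow_probability as tfp\n", 508 |    "\n", 509 |    "tfd = tfp.distributions" 510 |   ] 511 |  }, 512 |  { 513 |   "cell_type": "code", 514 |   "execution_count": null, 515 |   "metadata": { 516 |    "scrolled": false 517 |   }, 518 |   "outputs": [], 519 |   "source": [ 520 |    "x = np.array(prior_means, dtype=np.float32)\n", 521 |    "y = np.array(prior_cov, dtype=np.float32)\n", 522 |    "i = tf.keras.Input((100, 2))\n", 523 |    "l = maxent.TrainableInputLayer(x)(i)\n", 524 |    "d = tfp.layers.DistributionLambda(\n", 525 |    "    lambda x: tfd.MultivariateNormalFullCovariance(loc=x, covariance_matrix=y)\n", 526 |    ")(l)\n", 527 |    "model = maxent.ParameterJoint([lambda x: x], inputs=i, outputs=[d])\n", 528 |    "model.compile(tf.keras.optimizers.SGD(1e-3))\n", 529 |    "model.summary()\n", 530 |    "model(tf.constant([1.0, 2.0, 3.0, 4.0, 5.0]))" 531 |   ] 532 |  }, 533 |  { 534 |   "cell_type": "code", 535 |   "execution_count": null, 536 |   "metadata": {}, 537 |   "outputs": [], 538 |   "source": [ 539 |    "def simulate(x, nsteps=100):\n", 540 |    "    \"\"\"params_list should be: m1, m2, m3, v0[0], v0[1] in that order\"\"\"\n", 541 |    "    # double nsteps b/c we flatten the (x,y) coordinates\n", 542 |    "    output = np.zeros((x.shape[0], nsteps, 2))\n", 543 |    "    for i in range(x.shape[0]):\n", 544 |    "        params_list = x[i, 0, :]\n", 545 |    "        m1, m2, m3 = float(params_list[0]), float(params_list[1]), float(params_list[2])\n", 546 |    "        v0 = np.array([params_list[3], params_list[4]], dtype=np.float64)\n", 547 |    "        this_sim = GravitySimulator(m1, m2, m3, v0, random_noise=False, nsteps=nsteps)\n", 548 |    "        # set to 1D to make hypermaxent setup easier\n", 549 |    "        this_traj = this_sim.run()  # .flatten()\n", 550 |    "        output[i] = this_traj\n",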
551 |    "    return output" 552 |   ] 553 |  }, 554 |  { 555 |   "cell_type": "code", 556 |   "execution_count": null, 557 |   "metadata": { 558 |    "scrolled": true 559 |   }, 560 |   "outputs": [], 561 |   "source": [ 562 |    "def get_observation_points_from_flat(flat_traj):\n", 563 |    "    recovered_traj = flat_traj.reshape([-1, 2])\n", 564 |    "    return get_observation_points(recovered_traj)  # .flatten()" 565 |   ] 566 |  }, 567 |  { 568 |   "cell_type": "code", 569 |   "execution_count": null, 570 |   "metadata": {}, 571 |   "outputs": [], 572 |   "source": [ 573 |    "r = []\n", 574 |    "true_points = get_observation_points(noisy_traj)\n", 575 |    "true_points_flat = true_points.flatten()\n", 576 |    "for i, point in enumerate(true_points_flat):\n", 577 |    "    r.append(\n", 578 |    "        maxent.Restraint(\n", 579 |    "            lambda x, i=i: get_observation_points_from_flat(x).flatten()[i], point, maxent.EmptyPrior()  # i=i avoids late binding; flatten to match true_points_flat\n", 580 |    "        )\n", 581 |    "    )\n", 582 |    "hme_model = maxent.HyperMaxentModel(maxent_restraints, model, simulate)\n", 583 |    "hme_model.compile(tf.keras.optimizers.Adam(1e-4), \"mean_squared_error\")" 584 |   ] 585 |  }, 586 |  { 587 |   "cell_type": "code", 588 |   "execution_count": null, 589 |   "metadata": { 590 |    "scrolled": true 591 |   }, 592 |   "outputs": [], 593 |   "source": [ 594 |    "hme_results = hme_model.fit(\n", 595 |    "    epochs=30000, sample_batch_size=2048 // 4, outter_epochs=4, verbose=0\n", 596 |    ")  # one-quarter of plain maxent batch size" 597 |   ] 598 |  }, 599 |  { 600 |   "cell_type": "code", 601 |   "execution_count": null, 602 |   "metadata": {}, 603 |   "outputs": [], 604 |   "source": [ 605 |    "hme_predicted_params = hme_model.weights[1]\n", 606 |    "hme_trajectory_weights = hme_model.traj_weights\n", 607 |    "variational_trajs = hme_model.trajs.reshape([hme_model.trajs.shape[0], -1, 2])\n", 608 |    "maxent_variational_avg_traj = np.sum(\n", 609 |    "    variational_trajs * hme_trajectory_weights[:, np.newaxis, np.newaxis], axis=0\n", 610 |    ")\n", 611 |    "np.savetxt(\"maxent_variational_avg_traj.txt\", maxent_variational_avg_traj)" 612 |   ] 613 |  }, 614 |  { 615 |   "cell_type": "code", 616 |   "execution_count": null, 617 |   "metadata": {}, 618 |   "outputs": [], 619 |   "source": [ 620 |    "# simulate traj generated by prior means\n", 621 |    "sim = GravitySimulator(prior_means[0], prior_means[1], prior_means[2], prior_means[3:])\n", 622 |    "prior_means_traj = sim.run()" 623 |   ] 624 |  }, 625 |  { 626 |   "cell_type": "code", 627 |   "execution_count": null, 628 |   "metadata": {}, 629 |   "outputs": [], 630 |   "source": [ 631 |    "mean_snl_traj = np.genfromtxt(\"mean_snl_traj.txt\")\n", 632 |    "maxent_avg_traj = np.genfromtxt(\"maxent_avg_traj.txt\")\n", 633 |    "maxent_variational_avg_traj = np.genfromtxt(\"maxent_variational_avg_traj.txt\")" 634 |   ] 635 |  },
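 { "cell_type": "markdown", "metadata": {}, "source": [ "*Editor's note: the next cell is an added comparison, not part of the original notebook.* It prints the variationally fitted parameter means (read out above as `hme_predicted_params`) next to the ground-truth values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# added check (editor's sketch): fitted parameter means vs. ground truth\n", "print(\"fitted params:\", np.array(hme_predicted_params).flatten().round(1))\n", "print(\"true params:  \", true_params)" ] },
 636 |  { 637 |   "cell_type": "code", 638 |   "execution_count": null, 639 |   "metadata": {}, 640 |   "outputs": [], 641 |   "source": [ 642 |    "alpha_val = 0.7\n", 643 |    "fig, axes = plt.subplots(figsize=(5, 3), dpi=300)\n", 644 |    "\n", 645 |    "# plot the observation points\n", 646 |    "axes.scatter(\n", 647 |    "    observed_points[:, 0],\n", 648 |    "    observed_points[:, 1],\n", 649 |    "    color=\"black\",\n", 650 |    "    zorder=10,\n", 651 |    "    marker=\"*\",\n", 652 |    "    label=\"Observed Points\",\n", 653 |    ")\n", 654 |    "\n", 655 |    "# plot the trajectory generated by prior means\n", 656 |    "sim.set_traj(prior_means_traj)\n", 657 |    "sim.plot_traj(\n", 658 |    "    fig=fig,\n", 659 |    "    axes=axes,\n", 660 |    "    make_colorbar=False,\n", 661 |    "    save=False,\n", 662 |    "    cmap=plt.get_cmap(\"Greys\").reversed(),\n", 663 |    "    color=\"grey\",\n", 664 |    "    fade_lines=False,\n", 665 |    "    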
alpha=alpha_val,\n", 666 | " linestyle=\"-.\",\n", 667 | " linewidth=1,\n", 668 | " label=\"Prior Mean\",\n", 669 | ")\n", 670 | "\n", 671 | "# plot the SNL mean trajectory\n", 672 | "sim.set_traj(mean_snl_traj)\n", 673 | "sim.plot_traj(\n", 674 | " fig=fig,\n", 675 | " axes=axes,\n", 676 | " make_colorbar=False,\n", 677 | " save=False,\n", 678 | " cmap=plt.get_cmap(\"Greens\").reversed(),\n", 679 | " color=colors[0],\n", 680 | " fade_lines=False,\n", 681 | " linewidth=1,\n", 682 | " alpha=alpha_val,\n", 683 | " label=\"SNL\",\n", 684 | ")\n", 685 | "\n", 686 | "# plot the true trajectory\n", 687 | "sim.set_traj(true_traj)\n", 688 | "sim.plot_traj(\n", 689 | " fig=fig,\n", 690 | " axes=axes,\n", 691 | " make_colorbar=False,\n", 692 | " save=False,\n", 693 | " cmap=plt.get_cmap(\"Reds\").reversed(),\n", 694 | " color=\"black\",\n", 695 | " fade_lines=False,\n", 696 | " alpha=alpha_val,\n", 697 | " linestyle=\":\",\n", 698 | " linewidth=1,\n", 699 | " label=\"True Path\",\n", 700 | " label_attractors=False,\n", 701 | ")\n", 702 | "\n", 703 | "# plot the maxent average trajectory\n", 704 | "sim.set_traj(maxent_avg_traj)\n", 705 | "sim.plot_traj(\n", 706 | " fig=fig,\n", 707 | " axes=axes,\n", 708 | " make_colorbar=False,\n", 709 | " save=False,\n", 710 | " cmap=plt.get_cmap(\"Oranges\").reversed(),\n", 711 | " color=colors[2],\n", 712 | " fade_lines=False,\n", 713 | " alpha=alpha_val,\n", 714 | " linestyle=\"-\",\n", 715 | " linewidth=1,\n", 716 | " label=\"MaxEnt\",\n", 717 | " label_attractors=False,\n", 718 | ")\n", 719 | "\n", 720 | "# plot the maxent average trajectory\n", 721 | "sim.set_traj(maxent_variational_avg_traj)\n", 722 | "sim.plot_traj(\n", 723 | " fig=fig,\n", 724 | " axes=axes,\n", 725 | " make_colorbar=False,\n", 726 | " save=False,\n", 727 | " cmap=plt.get_cmap(\"Oranges\").reversed(),\n", 728 | " color=colors[3],\n", 729 | " fade_lines=False,\n", 730 | " alpha=alpha_val,\n", 731 | " linestyle=\"-\",\n", 732 | " linewidth=1,\n", 733 | " label=\"Variational MaxEnt\",\n", 734 | " label_attractors=True,\n", 735 | ")\n", 736 | "\n", 737 | "# set limits manually\n", 738 | "axes.set_xlim(-5, 130)\n", 739 | "axes.set_ylim(-30, 75)\n", 740 | "\n", 741 | "plt.legend(loc=\"upper left\", bbox_to_anchor=(1.05, 1.0))\n", 742 | "plt.tight_layout()\n", 743 | "\n", 744 | "plt.savefig(\"paths_compare.png\")\n", 745 | "plt.savefig(\"paths_compare.svg\")\n", 746 | "plt.show()" 747 | ] 748 | } 749 | ], 750 | "metadata": { 751 | "celltoolbar": "Tags", 752 | "kernelspec": { 753 | "display_name": "Python 3", 754 | "language": "python", 755 | "name": "python3" 756 | }, 757 | "language_info": { 758 | "codemirror_mode": { 759 | "name": "ipython", 760 | "version": 3 761 | }, 762 | "file_extension": ".py", 763 | "mimetype": "text/x-python", 764 | "name": "python", 765 | "nbconvert_exporter": "python", 766 | "pygments_lexer": "ipython3", 767 | "version": "3.8.3" 768 | } 769 | }, 770 | "nbformat": 4, 771 | "nbformat_minor": 4 772 | } 773 | -------------------------------------------------------------------------------- /paper/epidemiology.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Epidemiology Example" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Packages" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# to speed-up 
execution, set this to True\n", 24 |    "USE_CACHED_RESULTS = False\n", 25 |    "# cross-fold crashes Github CI\n", 26 |    "USE_CACHED_CV5_RESULTS = True" 27 |   ] 28 |  }, 29 |  { 30 |   "cell_type": "code", 31 |   "execution_count": null, 32 |   "metadata": {}, 33 |   "outputs": [], 34 |   "source": [ 35 |    "import os\n", 36 |    "\n", 37 |    "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", 38 |    "import maxent\n", 39 |    "import maxentep as py0\n", 40 |    "import tensorflow as tf\n", 41 |    "\n", 42 |    "tf.get_logger().setLevel(\"INFO\")\n", 43 |    "import matplotlib.pyplot as plt\n", 44 |    "import numpy as np\n", 45 |    "import seaborn as sns\n", 46 |    "import pyabc\n", 47 |    "\n", 48 |    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n", 49 |    "sns.set_context(\"paper\")\n", 50 |    "from tqdm import tqdm\n", 51 |    "from functools import partialmethod\n", 52 |    "\n", 53 |    "tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)" 54 |   ] 55 |  }, 56 |  { 57 |   "cell_type": "markdown", 58 |   "metadata": {}, 59 |   "source": [ 60 |    "### Set Up the SEAIR Model" 61 |   ] 62 |  }, 63 |  { 64 |   "cell_type": "code", 65 |   "execution_count": null, 66 |   "metadata": {}, 67 |   "outputs": [], 68 |   "source": [ 69 |    "# Make up some populations (taken as known)\n", 70 |    "M = 3\n", 71 |    "np.random.seed(0)\n", 72 |    "population = np.maximum(\n", 73 |    "    1000, np.round(np.random.normal(loc=100000, scale=25000, size=(M,)), 0)\n", 74 |    ")\n", 75 |    "area = np.maximum(250, np.round(np.random.normal(loc=2000, scale=1000, size=(M,)), 0))\n", 76 |    "print(area, population)\n", 77 |    "population_fraction = population / np.sum(population)" 78 |   ] 79 |  }, 80 |  { 81 |   "cell_type": "code", 82 |   "execution_count": null, 83 |   "metadata": {}, 84 |   "outputs": [], 85 |   "source": [ 86 |    "# compartment parameters\n", 87 |    "compartments = [\"E\", \"A\", \"I\", \"R\"]\n", 88 |    "infections_compartments = [1, 2]\n", 89 |    "infect_fxn = py0.contact_infection_func(infections_compartments)\n", 90 |    "full_compartments = [\"S\"] + compartments\n", 91 |    "R = np.array([[1000, 400, 10], [0, 300, 300], [300, 300, 1000]])\n", 92 |    "R_norm = R / np.sum(R, axis=1)\n", 93 |    "\n", 94 |    "\n", 95 |    "def metapop_wraper(start_infected, start_asymptomatic, E_time, A_time, I_time):\n", 96 |    "    beta = 0.025\n", 97 |    "    start = np.zeros((3, 4))\n", 98 |    "    start[0, 2] = start_infected\n", 99 |    "    start[0, 1] = start_asymptomatic\n", 100 |    "    tmat = py0.TransitionMatrix(compartments, infections_compartments)\n", 101 |    "    tmat.add_transition(\"E\", \"A\", E_time, 0)\n", 102 |    "    tmat.add_transition(\"A\", \"I\", A_time, 0)\n", 103 |    "    tmat.add_transition(\"I\", \"R\", I_time, 0)\n", 104 |    "    timesteps = 250\n", 105 |    "    meta_model = py0.MetaModel(infect_fxn, timesteps, populations=population_fraction)\n", 106 |    "    trajectory = meta_model(R_norm, tmat.value, start, beta)[0]\n", 107 |    "    return trajectory\n", 108 |    "\n", 109 |    "\n", 110 |    "ref_inputs = [0.02, 0.05, 7, 5, 14]\n", 111 |    "ref_traj = metapop_wraper(*ref_inputs)\n", 112 |    "# ref_traj = ref_traj[np.newaxis,...]\n", 113 |    "fig, axs = plt.subplots(nrows=1, ncols=M, figsize=(8, 2), dpi=300)\n", 114 |    "fig.suptitle(\"Reference Model\", y=1.2, fontsize=18)\n", 115 |    "for i in range(M):\n", 116 |    "    plt.setp(axs[i], xlabel=\"Time\", title=\"Patch {}\".format(i + 1), ylim=[0, 1])\n", 117 |    "    axs[i].plot(ref_traj[:, i], linestyle=\"--\", label=full_compartments)\n", 118 |    "axs[0].set_ylabel(\"Population Fraction\")\n", 119 |    "plt.legend(bbox_to_anchor=(1, 1))\n", 120 |    "plt.show()" 121 |   ] 122 |  },
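 { "cell_type": "markdown", "metadata": {}, "source": [ "*Editor's note: the next cell is an added sanity check, not part of the original analysis.* If the metapopulation model conserves population, the compartment fractions within each patch should sum to ~1 at every time step." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# added check (editor's sketch): per-patch compartment fractions should sum to ~1\n", "print(\"max |sum(compartments) - 1|:\", np.abs(np.sum(np.array(ref_traj), axis=-1) - 1).max())" ] },
 123 |  { 124 |   "cell_type": "code", 125 |   "execution_count": null, 126 | 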
"metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "npoints = 5\n", 130 | "np.random.seed(0)\n", 131 | "prior = maxent.Laplace(0.01)\n", 132 | "restrained_compartments = [3] # (infected patch)\n", 133 | "restrained_compartments_names = [full_compartments[m] for m in restrained_compartments]\n", 134 | "number_of_restrained_compartments = len(restrained_compartments)\n", 135 | "restrained_patches = np.array([0])\n", 136 | "print(\"Patches restrained:\", restrained_patches)\n", 137 | "print(\n", 138 | " \"Total number of restraints: \",\n", 139 | " npoints * number_of_restrained_compartments * len(restrained_patches),\n", 140 | ")\n", 141 | "print(\"Compartments restrained: \", restrained_compartments_names)\n", 142 | "restraints, plot_fxns_list = py0.compartment_restrainer(\n", 143 | " restrained_patches,\n", 144 | " restrained_compartments,\n", 145 | " npoints,\n", 146 | " ref_traj,\n", 147 | " prior,\n", 148 | " noise=0.05,\n", 149 | " start_time=0,\n", 150 | " end_time=100,\n", 151 | ")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "ref_traj.shape" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "### Sample SEAIR Trajectories" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "tf.random.set_seed(0)\n", 177 | "if not USE_CACHED_RESULTS or not os.path.exists(\n", 178 | " \"epidemiology_files/maxent_sample_trajs\"\n", 179 | "):\n", 180 | " tmat = py0.TransitionMatrix(compartments, infections_compartments)\n", 181 | " tmat.add_transition(\"E\", \"A\", 2, 1)\n", 182 | " tmat.add_transition(\"A\", \"I\", 2, 4)\n", 183 | " tmat.add_transition(\"I\", \"R\", 10, 5)\n", 184 | " start_logits = np.zeros((M))\n", 185 | " hyper_pram = py0.ParameterHypers()\n", 186 | " hyper_pram.beta_var = 0.000001\n", 187 | " hyper_pram.beta_start = 0.025\n", 188 | " hyper_pram.beta_high = 0.025002\n", 189 | " hyper_pram.beta_low = 0.025001\n", 190 | " hyper_pram.start_mean = 0.001\n", 191 | " hyper_pram.start_high = 0.4\n", 192 | " hyper_pram.start_var = 0.8\n", 193 | " hyper_pram.R_var = 0.00001\n", 194 | " param_model = py0.MetaParameterJoint(\n", 195 | " start_logits, R, tmat, name=\"unbiased model\", hypers=hyper_pram\n", 196 | " )\n", 197 | " # Fxing beta and mobility matrix\n", 198 | " R_norm_sample = tf.convert_to_tensor(R_norm)\n", 199 | " beta_sample = tf.convert_to_tensor([0.025])\n", 200 | " N = 2048\n", 201 | " batches = 4\n", 202 | " batch_size = N * batches\n", 203 | " outs = []\n", 204 | " timesteps = 250\n", 205 | " model = py0.MetaModel(infect_fxn, timesteps, populations=population_fraction)\n", 206 | " for b in tqdm(range(batches)):\n", 207 | " psample = param_model.sample(N)\n", 208 | " outs.append(model(*psample))\n", 209 | " trajs = np.concatenate(outs, axis=0)\n", 210 | " np.save(\"epidemiology_files/maxent_sample_trajs\", trajs)\n", 211 | "else:\n", 212 | " trajs = np.load(\"epidemiology_files/maxent_sample_trajs.npy\")" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "fig, axs = plt.subplots(nrows=1, ncols=M, figsize=(8, 2), dpi=300)\n", 222 | "fig.suptitle(\"Unbiased Model\", y=1.2, fontsize=18)\n", 223 | "py0.traj_quantile(\n", 224 | " trajs[:, :, 0, :],\n", 225 | " names=full_compartments,\n", 226 | " plot_means=True,\n", 227 
| " ax=axs[0],\n", 228 | " add_legend=False,\n", 229 | " alpha=0.2,\n", 230 | ")\n", 231 | "py0.traj_quantile(\n", 232 | " trajs[:, :, 1, :],\n", 233 | " names=full_compartments,\n", 234 | " plot_means=True,\n", 235 | " ax=axs[1],\n", 236 | " add_legend=False,\n", 237 | " alpha=0.2,\n", 238 | ")\n", 239 | "py0.traj_quantile(\n", 240 | " trajs[:, :, 2, :], names=full_compartments, plot_means=True, ax=axs[2], alpha=0.2\n", 241 | ")\n", 242 | "for i in range(M):\n", 243 | " plt.setp(axs[i], xlabel=\"Time\", title=\"Patch {}\".format(i + 1), ylim=[0, 1.0])\n", 244 | " axs[i].plot(ref_traj[:, i], linestyle=\"--\")\n", 245 | "axs[0].set_ylabel(\"Population Fraction\")\n", 246 | "plt.show()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "if not USE_CACHED_RESULTS:\n", 256 | " me_model = maxent.MaxentModel(restraints)\n", 257 | " me_model.compile(tf.keras.optimizers.Adam(learning_rate=1e-1), \"mean_squared_error\")\n", 258 | " me_history = me_model.fit(trajs, batch_size=batch_size, epochs=1000, verbose=0)\n", 259 | " me_w = me_model.traj_weights\n", 260 | " np.save(\"epidemiology_files/maxent_biased_me_w\", me_w)\n", 261 | "else:\n", 262 | " me_w = np.load(\"epidemiology_files/maxent_biased_me_w.npy\")" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "fig, axs = plt.subplots(nrows=1, ncols=M, figsize=(8, 2), dpi=300)\n", 272 | "fig.suptitle(\"Maxent Biased Model\", y=1.2, fontsize=18)\n", 273 | "py0.traj_quantile(\n", 274 | " trajs[:, :, 0, :],\n", 275 | " weights=me_w,\n", 276 | " names=full_compartments,\n", 277 | " plot_means=True,\n", 278 | " ax=axs[0],\n", 279 | " add_legend=False,\n", 280 | " alpha=0.2,\n", 281 | ")\n", 282 | "py0.traj_quantile(\n", 283 | " trajs[:, :, 1, :],\n", 284 | " weights=me_w,\n", 285 | " names=full_compartments,\n", 286 | " plot_means=True,\n", 287 | " ax=axs[1],\n", 288 | " add_legend=False,\n", 289 | " alpha=0.2,\n", 290 | ")\n", 291 | "py0.traj_quantile(\n", 292 | " trajs[:, :, 2, :],\n", 293 | " weights=me_w,\n", 294 | " names=full_compartments,\n", 295 | " plot_means=True,\n", 296 | " ax=axs[2],\n", 297 | " alpha=0.2,\n", 298 | ")\n", 299 | "for i in range(M):\n", 300 | " plt.setp(axs[i], xlabel=\"Time\", title=\"Patch {}\".format(i + 1), ylim=[0, 1.0])\n", 301 | " axs[i].plot(ref_traj[:, i], linestyle=\"--\")\n", 302 | " if i in restrained_patches:\n", 303 | " for _, pf in enumerate(plot_fxns_list[restrained_patches.tolist().index(i)]):\n", 304 | " pf(axs[i], 0, color=\"C3\")\n", 305 | " axs[i].spines[\"bottom\"].set_color(\"y\")\n", 306 | " axs[i].spines[\"top\"].set_color(\"y\")\n", 307 | " axs[i].spines[\"right\"].set_color(\"y\")\n", 308 | " axs[i].spines[\"left\"].set_color(\"y\")\n", 309 | " axs[i].spines[\"left\"].set_linewidth(2)\n", 310 | " axs[i].spines[\"top\"].set_linewidth(2)\n", 311 | " axs[i].spines[\"right\"].set_linewidth(2)\n", 312 | " axs[i].spines[\"bottom\"].set_linewidth(2)\n", 313 | "axs[0].set_ylabel(\"Population Fraction\")\n", 314 | "plt.show()" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "metadata": {}, 320 | "source": [ 321 | "### Least squares regression" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "if not USE_CACHED_RESULTS:\n", 331 | " C = len(compartments)\n", 332 | " start = np.zeros((M, C))\n", 333 | " 
start[0, 1] = 0.001\n", 334 |    "    start[0, 2] = 0.001\n", 335 |    "    beta_start = 0.1\n", 336 |    "    infect = py0.ContactInfectionLayer(beta_start, infections_compartments)\n", 337 |    "    # make \"agreement\" function\n", 338 |    "    def agreement(traj, rs=restraints):\n", 339 |    "        s = 0\n", 340 |    "        for r in rs:\n", 341 |    "            s += (r(traj[0]) ** 2)[tf.newaxis, ...]\n", 342 |    "        return s\n", 343 |    "\n", 344 |    "    rmodel = py0.TrainableMetaModel(\n", 345 |    "        start, R_norm, tmat.value, infect, timesteps, agreement\n", 346 |    "    )\n", 347 |    "    rmodel.compile(optimizer=tf.keras.optimizers.Nadam(0.01))\n", 348 |    "    utraj = rmodel.get_traj()\n", 349 |    "    rmodel.fit(steps=timesteps, verbose=0)\n", 350 |    "    regressed_traj = rmodel.get_traj()\n", 351 |    "    np.save(\"epidemiology_files/ls_biased_traj\", regressed_traj)\n", 352 |    "else:\n", 353 |    "    regressed_traj = np.load(\"epidemiology_files/ls_biased_traj.npy\")" 354 |   ] 355 |  }, 356 |  { 357 |   "cell_type": "code", 358 |   "execution_count": null, 359 |   "metadata": {}, 360 |   "outputs": [], 361 |   "source": [ 362 |    "fig, axs = plt.subplots(nrows=1, ncols=M, figsize=(8, 2), dpi=300)\n", 363 |    "fig.suptitle(\"Least Squares Regression Biased Model\", y=1.2, fontsize=18)\n", 364 |    "py0.traj_quantile(\n", 365 |    "    regressed_traj[:, :, 0, :],\n", 366 |    "    names=full_compartments,\n", 367 |    "    plot_means=True,\n", 368 |    "    ax=axs[0],\n", 369 |    "    add_legend=False,\n", 370 |    "    alpha=0.2,\n", 371 |    ")\n", 372 |    "py0.traj_quantile(\n", 373 |    "    regressed_traj[:, :, 1, :],\n", 374 |    "    names=full_compartments,\n", 375 |    "    plot_means=True,\n", 376 |    "    ax=axs[1],\n", 377 |    "    add_legend=False,\n", 378 |    "    alpha=0.2,\n", 379 |    ")\n", 380 |    "py0.traj_quantile(\n", 381 |    "    regressed_traj[:, :, 2, :],\n", 382 |    "    names=full_compartments,\n", 383 |    "    plot_means=True,\n", 384 |    "    ax=axs[2],\n", 385 |    "    alpha=0.2,\n", 386 |    ")\n", 387 |    "for i in range(M):\n", 388 |    "    plt.setp(axs[i], xlabel=\"Time\", title=\"Patch {}\".format(i + 1), ylim=[0, 1.0])\n", 389 |    "    axs[i].plot(ref_traj[:, i], linestyle=\"--\")\n", 390 |    "    if i in restrained_patches:\n", 391 |    "        for _, pf in enumerate(plot_fxns_list[restrained_patches.tolist().index(i)]):\n", 392 |    "            pf(axs[i], 0, color=\"C3\")\n", 393 |    "        axs[i].spines[\"bottom\"].set_color(\"y\")\n", 394 |    "        axs[i].spines[\"top\"].set_color(\"y\")\n", 395 |    "        axs[i].spines[\"right\"].set_color(\"y\")\n", 396 |    "        axs[i].spines[\"left\"].set_color(\"y\")\n", 397 |    "        axs[i].spines[\"left\"].set_linewidth(2)\n", 398 |    "        axs[i].spines[\"top\"].set_linewidth(2)\n", 399 |    "        axs[i].spines[\"right\"].set_linewidth(2)\n", 400 |    "        axs[i].spines[\"bottom\"].set_linewidth(2)\n", 401 |    "axs[0].set_ylabel(\"Population Fraction\")\n", 402 |    "plt.show()" 403 |   ] 404 |  }, 405 |  { 406 |   "cell_type": "markdown", 407 |   "metadata": {}, 408 |   "source": [ 409 |    "### ABC" 410 |   ] 411 |  }, 412 |  { 413 |   "cell_type": "code", 414 |   "execution_count": null, 415 |   "metadata": {}, 416 |   "outputs": [], 417 |   "source": [ 418 |    "np.random.seed(0)\n", 419 |    "start_infected = 0.001\n", 420 |    "start_asymptomatic = 0.001\n", 421 |    "E_time = 2\n", 422 |    "A_time = 2\n", 423 |    "I_time = 10\n", 424 |    "abc_param_keys = [\"start_infected\", \"start_asymptomatic\", \"E_time\", \"A_time\", \"I_time\"]\n", 425 |    "abc_param_values = [start_infected, start_asymptomatic, E_time, A_time, I_time]\n", 426 |    "abc_parameters = dict(zip(abc_param_keys, abc_param_values))\n", 427 |    "\n", 428 |    "\n", 429 |    "def abc_model(parameter):\n", 430 |    "    trajectory = metapop_wraper(\n", 431 |    "        float(parameter[\"start_infected\"]),\n", 432 |    "    
float(parameter[\"start_asymptomatic\"]),\n", 433 | " +float(parameter[\"E_time\"]),\n", 434 | " float(parameter[\"A_time\"]),\n", 435 | " float(parameter[\"I_time\"]),\n", 436 | " )\n", 437 | " restrainted_time_values = [59, 45, 31, 80, 17]\n", 438 | " values = np.array([trajectory[m, 0, 3].numpy() for m in restrainted_time_values])\n", 439 | " return {\"data\": values}\n", 440 | "\n", 441 | "\n", 442 | "def distance(x, y):\n", 443 | " d = np.sum(np.abs(x[\"data\"] - y[\"data\"]))\n", 444 | " return d\n", 445 | "\n", 446 | "\n", 447 | "if not USE_CACHED_RESULTS:\n", 448 | " parameter_priors = pyabc.Distribution(\n", 449 | " start_infected=pyabc.RV(\"truncnorm\", 0, 0.5, abc_param_values[0], 0.8),\n", 450 | " start_asymptomatic=pyabc.RV(\"truncnorm\", 0, 0.5, abc_param_values[1], 0.8),\n", 451 | " E_time=pyabc.RV(\"norm\", abc_param_values[2], 1),\n", 452 | " A_time=pyabc.RV(\"norm\", abc_param_values[3], 4),\n", 453 | " I_time=pyabc.RV(\"norm\", abc_param_values[4], 5),\n", 454 | " )\n", 455 | "\n", 456 | " abc = pyabc.ABCSMC(abc_model, parameter_priors, distance)\n", 457 | " db_path = \"sqlite:///\" + os.path.join(os.getcwd(), \"abc_SEAIR.db\")\n", 458 | " observation = np.array([r.target for r in restraints])\n", 459 | " abc.new(db_path, {\"data\": observation})\n", 460 | " history = abc.run(minimum_epsilon=0.1, max_nr_populations=5)\n", 461 | " df, w_abc = history.get_distribution(m=0, t=history.max_t)\n", 462 | " abc_trajs = np.empty(\n", 463 | " (len(df), ref_traj.shape[0], ref_traj.shape[1], ref_traj.shape[2])\n", 464 | " )\n", 465 | " for i, row in enumerate(tqdm(np.array(df))):\n", 466 | " (\n", 467 | " A_time_abc,\n", 468 | " E_time_abc,\n", 469 | " I_time_abc,\n", 470 | " start_asymptomatic_abc,\n", 471 | " start_infected_abc,\n", 472 | " ) = (row[0], row[1], row[2], row[3], row[4])\n", 473 | " traj = metapop_wraper(\n", 474 | " start_infected_abc,\n", 475 | " start_asymptomatic_abc,\n", 476 | " E_time_abc,\n", 477 | " A_time_abc,\n", 478 | " I_time_abc,\n", 479 | " )\n", 480 | " abc_trajs[i] = traj\n", 481 | " abc_biased_traj = np.sum(\n", 482 | " abc_trajs * w_abc[:, np.newaxis, np.newaxis, np.newaxis], axis=0\n", 483 | " )\n", 484 | " np.save(\"epidemiology_files/abc_biased_traj.npy\", abc_biased_traj)\n", 485 | "else:\n", 486 | " abc_biased_traj = np.load(\"epidemiology_files/abc_biased_traj.npy\")" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "### 5-fold cross validation" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": null, 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "if not USE_CACHED_CV5_RESULTS:\n", 503 | " LS_traj_folds = np.empty((npoints, trajs.shape[1], trajs.shape[2], trajs.shape[-1]))\n", 504 | " MaxEnt_traj_folds = np.empty_like(LS_traj_folds)\n", 505 | " abc_traj_folds = []\n", 506 | " for i in range(npoints):\n", 507 | " sampled_restraints = [n for n in restraints if n != restraints[i]]\n", 508 | " ## MaxEnt\n", 509 | " print(f\"Initializing LS for fold {i}\")\n", 510 | "\n", 511 | " def new_agreement(traj, rs=sampled_restraints):\n", 512 | " s = 0\n", 513 | " for r in rs:\n", 514 | " s += (r(traj[0]) ** 2)[tf.newaxis, ...]\n", 515 | " return s\n", 516 | "\n", 517 | " rmodel_sampled = py0.TrainableMetaModel(\n", 518 | " start, R_norm, tmat.value, infect, timesteps, new_agreement\n", 519 | " )\n", 520 | " rmodel_sampled.compile(optimizer=tf.keras.optimizers.Nadam(0.01))\n", 521 | " utraj_sampled = rmodel_sampled.get_traj()\n", 522 | " 
rmodel_sampled.fit(timesteps)\n", 523 |    "        regressed_traj_sampled = rmodel_sampled.get_traj()\n", 524 |    "        LS_traj_folds[i, :, :, :] = regressed_traj_sampled\n", 525 |    "        ## MaxEnt\n", 526 |    "        print(f\"Initializing MaxEnt for fold {i}\")\n", 527 |    "        me_model_sample = py0.MaxentModel(\n", 528 |    "            sampled_restraints, trajs=trajs, population_fraction=population_fraction\n", 529 |    "        )\n", 530 |    "        me_model_sample.compile(\n", 531 |    "            tf.keras.optimizers.Adam(learning_rate=1e-1), \"mean_squared_error\"\n", 532 |    "        )\n", 533 |    "        me_history_sample = me_model_sample.fit(\n", 534 |    "            trajs, batch_size=batch_size, epochs=1000\n", 535 |    "        )\n", 536 |    "        MaxEnt_traj_folds[i, :, :, :] = np.sum(\n", 537 |    "            me_model_sample.trajs\n", 538 |    "            * me_model_sample.traj_weights[:, np.newaxis, np.newaxis, np.newaxis],\n", 539 |    "            axis=0,\n", 540 |    "        )\n", 541 |    "        ## ABC\n", 542 |    "        print(f\"Initializing ABC for fold {i}\")\n", 543 |    "\n", 544 |    "        def abc_model(parameter):\n", 545 |    "            trajectory = metapop_wraper(\n", 546 |    "                float(parameter[\"start_infected\"]),\n", 547 |    "                float(parameter[\"start_asymptomatic\"]),\n", 548 |    "                float(parameter[\"E_time\"]),\n", 549 |    "                float(parameter[\"A_time\"]),\n", 550 |    "                float(parameter[\"I_time\"]),\n", 551 |    "            )\n", 552 |    "            restrained_time_values = [59, 45, 31, 80, 17]\n", 553 |    "            restrained_time_values.pop(i)\n", 554 |    "            values = np.array(\n", 555 |    "                [trajectory[m, 0, 3].numpy() for m in restrained_time_values]\n", 556 |    "            )\n", 557 |    "            return {\"data\": values}\n", 558 |    "\n", 559 |    "        parameter_priors = pyabc.Distribution(\n", 560 |    "            start_infected=pyabc.RV(\"truncnorm\", 0, 0.5, abc_param_values[0], 0.8),\n", 561 |    "            start_asymptomatic=pyabc.RV(\"truncnorm\", 0, 0.5, abc_param_values[1], 0.8),\n", 562 |    "            E_time=pyabc.RV(\"norm\", abc_param_values[2], 1),\n", 563 |    "            A_time=pyabc.RV(\"norm\", abc_param_values[3], 4),\n", 564 |    "            I_time=pyabc.RV(\"norm\", abc_param_values[4], 5),\n", 565 |    "        )\n", 566 |    "\n", 567 |    "        abc = pyabc.ABCSMC(abc_model, parameter_priors, distance)\n", 568 |    "        db_path = \"sqlite:///\" + os.path.join(os.getcwd(), \"abc_SEAIR.db\")\n", 569 |    "        observation_abc = np.array([r.target for r in sampled_restraints])\n", 570 |    "        abc.new(db_path, {\"data\": observation_abc})\n", 571 |    "        history = abc.run(minimum_epsilon=0.1, max_nr_populations=5)\n", 572 |    "        df, w_abc = history.get_distribution(m=0, t=history.max_t)\n", 573 |    "        abc_trajs = np.empty(\n", 574 |    "            (len(df), ref_traj.shape[0], ref_traj.shape[1], ref_traj.shape[2])\n", 575 |    "        )\n", 576 |    "        for j, row in enumerate(tqdm(np.array(df))):\n", 577 |    "            (\n", 578 |    "                A_time_abc,\n", 579 |    "                E_time_abc,\n", 580 |    "                I_time_abc,\n", 581 |    "                start_asymptomatic_abc,\n", 582 |    "                start_infected_abc,\n", 583 |    "            ) = (row[0], row[1], row[2], row[3], row[4])\n", 584 |    "            traj = metapop_wraper(\n", 585 |    "                abs(start_infected_abc),\n", 586 |    "                abs(start_asymptomatic_abc),\n", 587 |    "                E_time_abc,\n", 588 |    "                A_time_abc,\n", 589 |    "                I_time_abc,\n", 590 |    "            )\n", 591 |    "            abc_trajs[j] = traj\n", 592 |    "        mean_abc_traj = np.sum(\n", 593 |    "            abc_trajs * w_abc[:, np.newaxis, np.newaxis, np.newaxis], axis=0\n", 594 |    "        )\n", 595 |    "        abc_traj_folds.append(mean_abc_traj)\n", 596 |    "    np.save(\"epidemiology_files/abc_traj_folds.npy\", abc_traj_folds)\n", 597 |    "    np.save(\"epidemiology_files/MaxEnt_traj_folds.npy\", MaxEnt_traj_folds)\n", 598 |    "    np.save(\"epidemiology_files/LS_traj_folds.npy\", LS_traj_folds)\n", 599 |    "abc_traj_folds = np.load(\"epidemiology_files/abc_traj_folds.npy\")\n", 600 |    "MaxEnt_traj_folds = 
np.load(\"epidemiology_files/MaxEnt_traj_folds.npy\")\n", 601 | "LS_traj_folds = np.load(\"epidemiology_files/LS_traj_folds.npy\")" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "def find_std(sample_trajs):\n", 611 | " std_0 = np.std(sample_trajs[:, 0, 2, 4])\n", 612 | " std_mid = np.std(sample_trajs[:, 125, 2, 4])\n", 613 | " std_end = np.std(sample_trajs[:, -1, 2, 4])\n", 614 | " return [std_0, std_mid, std_end]\n", 615 | "\n", 616 | "\n", 617 | "std_abc = find_std(abc_traj_folds)\n", 618 | "std_MaxEnt = find_std(MaxEnt_traj_folds)\n", 619 | "std_LS = find_std(LS_traj_folds)\n", 620 | "print(\"MaxEnt std : \", std_MaxEnt)\n", 621 | "print(\"Least-squares std : \", std_LS)\n", 622 | "print(\"ABC std : \", std_abc)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": null, 628 | "metadata": {}, 629 | "outputs": [], 630 | "source": [ 631 | "def weighted_quantile(\n", 632 | " values, quantiles, sample_weight=None, values_sorted=False, old_style=False\n", 633 | "):\n", 634 | " \"\"\"Very close to numpy.percentile, but supports weights.\n", 635 | " NOTE: quantiles should be in [0, 1]!\n", 636 | " :param values: numpy.array with data\n", 637 | " :param quantiles: array-like with many quantiles needed\n", 638 | " :param sample_weight: array-like of the same length as `array`\n", 639 | " :param values_sorted: bool, if True, then will avoid sorting of\n", 640 | " initial array\n", 641 | " :param old_style: if True, will correct output to be consistent\n", 642 | " with numpy.percentile.\n", 643 | " :return: numpy.array with computed quantiles.\n", 644 | " \"\"\"\n", 645 | " values = np.array(values)\n", 646 | " quantiles = np.array(quantiles)\n", 647 | " if sample_weight is None:\n", 648 | " sample_weight = np.ones(len(values))\n", 649 | " sample_weight = np.array(sample_weight)\n", 650 | " assert np.all(quantiles >= 0) and np.all(\n", 651 | " quantiles <= 1\n", 652 | " ), \"quantiles should be in [0, 1]\"\n", 653 | "\n", 654 | " if not values_sorted:\n", 655 | " sorter = np.argsort(values)\n", 656 | " values = values[sorter]\n", 657 | " sample_weight = sample_weight[sorter]\n", 658 | "\n", 659 | " weighted_quantiles = np.cumsum(sample_weight) - 0.5 * sample_weight\n", 660 | " if old_style:\n", 661 | " # To be convenient with numpy.percentile\n", 662 | " weighted_quantiles -= weighted_quantiles[0]\n", 663 | " weighted_quantiles /= weighted_quantiles[-1]\n", 664 | " else:\n", 665 | " weighted_quantiles /= np.sum(sample_weight)\n", 666 | " return np.interp(quantiles, weighted_quantiles, values)" 667 | ] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "execution_count": null, 672 | "metadata": {}, 673 | "outputs": [], 674 | "source": [ 675 | "new_stds = np.array(np.round([std_MaxEnt, std_LS, std_abc], 3))\n", 676 | "\n", 677 | "\n", 678 | "def set_align_for_column(table, col, align=\"left\"):\n", 679 | " cells = [key for key in table._cells if key[1] == col]\n", 680 | " for cell in cells:\n", 681 | " table._cells[cell]._loc = align\n", 682 | "\n", 683 | "\n", 684 | "plt.rc(\"axes\", titlesize=8)\n", 685 | "from matplotlib.font_manager import FontProperties\n", 686 | "\n", 687 | "sns.set_context(\"paper\")\n", 688 | "sns.set_style(\n", 689 | " \"darkgrid\",\n", 690 | " {\n", 691 | " \"xtick.bottom\": True,\n", 692 | " \"ytick.left\": True,\n", 693 | " \"xtick.color\": \"#333333\",\n", 694 | " \"ytick.color\": \"#333333\",\n", 695 | " },\n", 696 | ")\n", 697 | 
"plt.rcParams[\"mathtext.fontset\"] = \"dejavuserif\"\n", 698 | "colors = [\"#1b9e77\", \"#d95f02\", \"#7570b3\", \"#e7298a\", \"#66a61e\"]\n", 699 | "\n", 700 | "\n", 701 | "import matplotlib.markers as mmark\n", 702 | "\n", 703 | "plt.rcParams[\"axes.grid\"] = True\n", 704 | "plt.rcParams[\"grid.alpha\"] = 0.9\n", 705 | "plt.rcParams[\"grid.color\"] = \"#cccccc\"\n", 706 | "fig, axs = plt.subplots(\n", 707 | " nrows=1,\n", 708 | " ncols=3,\n", 709 | " figsize=(11, 3.5),\n", 710 | " dpi=300,\n", 711 | " gridspec_kw={\"width_ratios\": [5, 5, 2.5]},\n", 712 | ")\n", 713 | "\n", 714 | "py0.traj_quantile(\n", 715 | " trajs[:, :, 0, :],\n", 716 | " names=full_compartments,\n", 717 | " plot_means=True,\n", 718 | " ax=axs[0],\n", 719 | " add_legend=True,\n", 720 | " alpha=0.2,\n", 721 | ")\n", 722 | "axs[0].plot(ref_traj[:, 0], linestyle=\"--\")\n", 723 | "axs[0].set_ylabel(\"Population Fraction\")\n", 724 | "plt.setp(\n", 725 | " axs[0],\n", 726 | " xlabel=\"Time\",\n", 727 | " title=\"a) Patch 1\",\n", 728 | " xlim=[0, 250],\n", 729 | " ylim=[0, 1.0],\n", 730 | " facecolor=\"white\",\n", 731 | ")\n", 732 | "for _, pf in enumerate(plot_fxns_list[restrained_patches.tolist().index(0)]):\n", 733 | " pf(axs[0], 0, color=\"black\")\n", 734 | "Reference_line = plt.Line2D((0, 1), (0, 0), color=\"k\", linestyle=\"--\")\n", 735 | "rs_marker = plt.Line2D((0, 1), (0, 0), color=\"k\", marker=\"o\", linestyle=\"\", ms=3)\n", 736 | "hand, labl = axs[0].get_legend_handles_labels()\n", 737 | "Reference_label = \"Reference model\"\n", 738 | "Restraints_label = \"Restraints\"\n", 739 | "axs[0].legend(\n", 740 | " [handle for i, handle in enumerate(hand)] + [Reference_line, rs_marker],\n", 741 | " [label for i, label in enumerate(labl)] + [Reference_label, Restraints_label],\n", 742 | " bbox_to_anchor=(1.02, 0.7),\n", 743 | " fontsize=8,\n", 744 | " frameon=True,\n", 745 | " fancybox=True,\n", 746 | " facecolor=\"white\",\n", 747 | ")\n", 748 | "\n", 749 | "\n", 750 | "plt.setp(\n", 751 | " axs[1],\n", 752 | " xlabel=\"Time\",\n", 753 | " title=\"b) Compartment R in Patch 3\",\n", 754 | " xlim=[0, 250],\n", 755 | " ylim=[0, 1.0],\n", 756 | " facecolor=\"white\",\n", 757 | ")\n", 758 | "w = np.ones(trajs.shape[0])\n", 759 | "w /= np.sum(w)\n", 760 | "mtrajs_unbiased = np.sum(trajs * w[:, np.newaxis, np.newaxis, np.newaxis], axis=0)\n", 761 | "mtrajs_maxentbiased = np.sum(\n", 762 | " trajs * me_w[:, np.newaxis, np.newaxis, np.newaxis], axis=0\n", 763 | ")\n", 764 | "mtrajs_regressedbiased = np.sum(regressed_traj, axis=0)\n", 765 | "x = range(trajs.shape[1])\n", 766 | "qtrajs_maxentbiased = np.apply_along_axis(\n", 767 | " lambda x: weighted_quantile(x, [1 / 3, 1 / 2, 2 / 3], sample_weight=me_w), 0, trajs\n", 768 | ")\n", 769 | "qtrajs_maxentbiased[0, :, :] = np.clip(\n", 770 | " qtrajs_maxentbiased[0, :, :] - qtrajs_maxentbiased[1, :, :] + mtrajs_maxentbiased,\n", 771 | " 0,\n", 772 | " 1,\n", 773 | ")\n", 774 | "qtrajs_maxentbiased[2, :, :] = np.clip(\n", 775 | " qtrajs_maxentbiased[2, :, :] - qtrajs_maxentbiased[1, :, :] + mtrajs_maxentbiased,\n", 776 | " 0,\n", 777 | " 1,\n", 778 | ")\n", 779 | "qtrajs_maxentbiased[1, :, :] = mtrajs_maxentbiased\n", 780 | "print(mtrajs_unbiased.shape)\n", 781 | "\n", 782 | "axs[1].plot(x, mtrajs_unbiased[:, 2, 4], color=\"#4a8c76\", label=\"Unbiased\")\n", 783 | "axs[1].plot(x, mtrajs_maxentbiased[:, 2, 4], color=\"#D03D9A\", label=\"MaxEnt\")\n", 784 | "axs[1].plot(x, abc_biased_traj[:, 2, 4], color=\"#fcec03\")\n", 785 | "axs[1].plot(ref_traj[:, 2, 4], linestyle=\"--\", 
color=\"k\")\n", 786 | "axs[1].fill_between(\n", 787 | " x,\n", 788 | " qtrajs_maxentbiased[0, :, 2, 4],\n", 789 | " qtrajs_maxentbiased[-1, :, 2, 4],\n", 790 | " color=\"#D03D9A\",\n", 791 | " alpha=0.2,\n", 792 | ")\n", 793 | "axs[1].plot(x, mtrajs_regressedbiased[:, 2, 4], color=\"#35a9d4\")\n", 794 | "axs[1].legend(bbox_to_anchor=(0.45, 0.38), fontsize=6)\n", 795 | "hand, labl = axs[1].get_legend_handles_labels()\n", 796 | "predicted_label_LS = \"Least-squares\"\n", 797 | "predicted_line_LS = plt.Line2D((0, 1), (0, 0), color=\"#35a9d4\")\n", 798 | "predicted_label_abc = \"ABC\"\n", 799 | "predicted_line_abc = plt.Line2D((0, 1), (0, 0), color=\"#fcec03\")\n", 800 | "ref_label = \"Reference model\"\n", 801 | "ref_line = plt.Line2D((0, 1), (0, 0), color=\"k\", linestyle=\"--\")\n", 802 | "axs[1].legend(\n", 803 | " [handle for i, handle in enumerate(hand)]\n", 804 | " + [predicted_line_LS, predicted_line_abc, ref_line],\n", 805 | " [label for i, label in enumerate(labl)]\n", 806 | " + [predicted_label_LS, predicted_label_abc, ref_label],\n", 807 | " bbox_to_anchor=(1.6, 0.67),\n", 808 | " fontsize=8,\n", 809 | " frameon=True,\n", 810 | " fancybox=True,\n", 811 | " facecolor=\"white\",\n", 812 | ")\n", 813 | "\n", 814 | "collabel = (\"$\\sigma_{t=0}$\", \"$\\sigma_{t=125}$\", \"$\\sigma_{t=250}$\")\n", 815 | "axs[2].axis(\"tight\")\n", 816 | "axs[2].axis(\"off\")\n", 817 | "tb = axs[2].table(\n", 818 | " cellText=np.round(new_stds, 3),\n", 819 | " colLabels=collabel,\n", 820 | " rowLabels=[\"MaxEnt\", \"Least-squares\", \"ABC\"],\n", 821 | " loc=\"center\",\n", 822 | " cellLoc=\"center\",\n", 823 | " rowLoc=\"center\",\n", 824 | " colWidths=[0.2 for x in collabel],\n", 825 | " fontsize=12,\n", 826 | " edges=\"BRTL\",\n", 827 | " bbox=[-0.38, 0.19, 0.65, 0.45],\n", 828 | " alpha=0.35,\n", 829 | ")\n", 830 | "\n", 831 | "for (row, col), cell in tb.get_celld().items():\n", 832 | " if (row == 0) or (col == -1):\n", 833 | " cell.set_text_props(fontproperties=FontProperties(weight=\"bold\"))\n", 834 | "\n", 835 | "for key, cell in tb.get_celld().items():\n", 836 | " cell.set_linewidth(0.8)\n", 837 | " cell.set_edgecolor(\"#545350\")\n", 838 | " cell.set_facecolor(\"white\")\n", 839 | " cell.set_alpha(0.9)\n", 840 | "\n", 841 | "set_align_for_column(tb, col=0, align=\"center\")\n", 842 | "set_align_for_column(tb, col=1, align=\"center\")\n", 843 | "plt.tight_layout()\n", 844 | "plt.subplots_adjust(wspace=0.9)\n", 845 | "axs[2].set_position([0.652, 0.041, 0.15, 0.7])" 846 | ] 847 | }, 848 | { 849 | "cell_type": "markdown", 850 | "metadata": {}, 851 | "source": [ 852 | "### Variational Inference" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": null, 858 | "metadata": {}, 859 | "outputs": [], 860 | "source": [ 861 | "fit_param_model = py0.MetaParameterJoint(\n", 862 | " start_logits, R, tmat, name=\"unbiased_model\", hypers=hyper_pram\n", 863 | ")\n", 864 | "fit_param_model.compile(tf.optimizers.Adam(0.1))\n", 865 | "hme_model = maxent.HyperMaxentModel(restraints, fit_param_model, model, reweight=False)\n", 866 | "hme_model.compile(tf.keras.optimizers.Adam(learning_rate=0.1), \"mean_squared_error\")" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": null, 872 | "metadata": {}, 873 | "outputs": [], 874 | "source": [ 875 | "hme_history = hme_model.fit(\n", 876 | " N,\n", 877 | " final_batch_multiplier=batches,\n", 878 | " outter_epochs=3,\n", 879 | " param_epochs=250,\n", 880 | " batch_size=batch_size // 4,\n", 881 | " epochs=1000,\n", 882 | " 
"hme_history = hme_model.fit(\n",
"    N,\n",
"    final_batch_multiplier=batches,\n",
"    outter_epochs=3,\n",
"    param_epochs=250,\n",
"    batch_size=batch_size // 4,\n",
"    epochs=1000,\n",
"    verbose=0,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hme_final_history = hme_model.fit(N, final_batch_multiplier=batches, outter_epochs=1, batch_size=batch_size, epochs=1000, verbose=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 3, figsize=(10, 3), dpi=200)\n",
"ax[0].plot(me_history.history[\"loss\"], label=\"MaxEnt\")\n",
"ax[0].plot(hme_history.history[\"loss\"], label=\"Hyper-MaxEnt\")\n",
"ax[0].set_title(\n",
"    \"loss (number of patches restrained: {})\".format(len(restrained_patches))\n",
")\n",
"ax[0].set_yscale(\"log\")\n",
"ax[0].legend()\n",
"\n",
"ax[1].plot(hme_history.history[\"weight-entropy\"])\n",
"ax[1].set_title(\"weight entropy\")\n",
"\n",
"ax[2].plot(hme_history.history[\"prior-loss\"])\n",
"ax[2].set_title(\"neg log-likelihood\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
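"# per-patch quantile plots under the variational MaxEnt weights; restrained\n",
"# patches are outlined with yellow spines below\n",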
| "nbformat": 4, 990 | "nbformat_minor": 4 991 | } 992 | --------------------------------------------------------------------------------