├── .github ├── chained_cue_navigation_v1.gif ├── chained_cue_navigation_v2.gif ├── pymdp_logo_2-removebg.png ├── pymdp_logo_2.jpeg └── workflows │ └── python-package.yml ├── .gitignore ├── .readthedocs.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ └── pymdp_logo_2-removebg.png ├── agent.rst ├── algos │ ├── fpi.rst │ ├── index.rst │ └── mmp.rst ├── conf.py ├── control.rst ├── env.rst ├── index.rst ├── inference.rst ├── installation.rst ├── learning.rst ├── make.bat ├── notebooks │ ├── active_inference_from_scratch.ipynb │ ├── cue_chaining_demo.ipynb │ ├── free_energy_calculation.ipynb │ ├── pymdp_fundamentals.ipynb │ ├── tmaze_demo.ipynb │ └── using_the_agent_class.ipynb └── requirements.txt ├── examples ├── A_matrix_demo.ipynb ├── A_matrix_demo.py ├── __init__.py ├── agent_demo.ipynb ├── agent_demo.py ├── building_up_agent_loop.ipynb ├── free_energy_calculation.ipynb ├── gridworld_tutorial_1.ipynb ├── gridworld_tutorial_2.ipynb ├── inductive_inference_example.ipynb ├── inductive_inference_gridworld.ipynb ├── inference_and_learning │ └── inference_methods_comparison.ipynb ├── learning │ └── learning_gridworld.ipynb ├── model_inversion.ipynb ├── testing_large_latent_spaces.ipynb ├── tmaze_demo.ipynb ├── tmaze_learning_demo.ipynb └── tmp_dir │ └── my_a_matrix.xlsx ├── paper ├── paper.bib └── paper.md ├── pymdp ├── __init__.py ├── agent.py ├── algos │ ├── __init__.py │ ├── fpi.py │ ├── mmp.py │ └── mmp_old.py ├── control.py ├── default_models.py ├── envs │ ├── __init__.py │ ├── env.py │ ├── grid_worlds.py │ ├── tmaze.py │ └── visual_foraging.py ├── inference.py ├── jax │ ├── __init__.py │ ├── agent.py │ ├── algos.py │ ├── control.py │ ├── inference.py │ ├── learning.py │ ├── likelihoods.py │ ├── maths.py │ ├── task.py │ └── utils.py ├── learning.py ├── maths.py └── utils.py ├── requirements.txt ├── setup.py └── test ├── __init__.py ├── matlab_crossval ├── generation │ ├── bmr_matlab_test_a.m │ ├── bmr_matlab_test_b.m │ ├── mmp_matlab_test_a.m │ ├── mmp_matlab_test_b.m │ ├── mmp_matlab_test_c.m │ ├── mmp_matlab_test_d.m │ ├── run_mmp.m │ ├── vb_x_matlab_test_1a.m │ └── vb_x_matlab_test_1b.m └── output │ ├── bmr_test_a.mat │ ├── bmr_test_b.mat │ ├── cross_a.mat │ ├── cross_b.mat │ ├── cross_c.mat │ ├── cross_d.mat │ ├── cross_e.mat │ ├── dot_a.mat │ ├── dot_b.mat │ ├── dot_c.mat │ ├── dot_d.mat │ ├── dot_e.mat │ ├── mmp_a.mat │ ├── mmp_b.mat │ ├── mmp_c.mat │ ├── mmp_d.mat │ ├── vbx_test_1a.mat │ ├── wnorm_a.mat │ └── wnorm_b.mat ├── test_SPM_validation.py ├── test_agent.py ├── test_agent_jax.py ├── test_control.py ├── test_control_jax.py ├── test_demos.py ├── test_fpi.py ├── test_inference.py ├── test_inference_jax.py ├── test_learning.py ├── test_learning_jax.py ├── test_message_passing_jax.py ├── test_mmp.py ├── test_utils.py └── test_wrappers.py /.github/chained_cue_navigation_v1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/.github/chained_cue_navigation_v1.gif -------------------------------------------------------------------------------- /.github/chained_cue_navigation_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/.github/chained_cue_navigation_v2.gif -------------------------------------------------------------------------------- /.github/pymdp_logo_2-removebg.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/.github/pymdp_logo_2-removebg.png -------------------------------------------------------------------------------- /.github/pymdp_logo_2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/.github/pymdp_logo_2.jpeg -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.10", "3.11", "3.12"] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install flake8 pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 34 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 35 | - name: Test with pytest 36 | run: | 37 | pytest test 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .DS_Store 4 | .ipynb_checkpoints 5 | .rope* 6 | .vscode/ 7 | .ipynb_checkpoints/ 8 | .pytest_cache 9 | env/ 10 | pymdp.egg-info 11 | inferactively_pymdp.egg-info 12 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: "ubuntu-20.04" 10 | tools: 11 | python: "3.8" 12 | 13 | # Optionally set the version of Python and requirements required to build your docs 14 | python: 15 | install: 16 | - requirements: docs/requirements.txt 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/conf.py 21 | fail_on_warning: false 22 | 23 | # Optionally build your docs in additional formats such as PDF and ePub 24 | formats: 25 | - htmlzip 26 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # pymdp 2 | 3 | *Borrowed below list from [here](https://github.com/netsiphds/netrd)* 4 | 5 | Welcome to `pymdp` and thanks for your interest in contributing! 6 | During development please make sure to keep the following checklists handy. 
7 | They contain a summary of all the important steps you need to take to contribute 8 | to the package. As a general statement, the more familiar you already are 9 | with git(hub), the less relevant the detailed instructions below will be for you. 10 | 11 | 12 | ## Types of Contribution 13 | 14 | There are multiple ways to contribute to `pymdp` (borrowed below list from [here](https://github.com/uzhdag/pathpy/blob/master/CONTRIBUTING.rst)): 15 | 16 | #### Report Bugs 17 | 18 | To report a bug in the package, open an issue at https://github.com/infer-actively/pymdp/issues. 19 | 20 | Please include in your bug report: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | #### Fix Bugs 27 | 28 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 29 | wanted" is open to whoever wants to implement it. 30 | 31 | #### Implement Features or New Methods 32 | 33 | Look through the GitHub issues for features. Anything tagged with "enhancement" 34 | and "help wanted" is open to whomever wants to implement it. If you know of a 35 | method that is implemented in another programming language, feel free to 36 | translate it into python here. If you don't want to translate it yourself, feel 37 | free to add an issue at https://github.com/infer-actively/pymdp/issues. If you have 38 | read through this document and still have questions, also open an issue. When 39 | in doubt, open an issue. 40 | 41 | #### Improve Documentation 42 | 43 | Documentation is just as important as the code it documents. Please feel 44 | free to submit PRs that are focused on fixing, improving, correcting, or 45 | refactoring documentation. 46 | 47 | #### Submit Feedback 48 | 49 | The best way to send feedback is to open an issue. 50 | 51 | If you are proposing to implement a function, feature, etc. 52 | see more details below. 53 | 54 | If you are proposing a feature not directly related to implementing a new method: 55 | 56 | * Explain in detail why the feature is desirable and how it would work. 57 | * Keep the scope as narrow as possible, to make it easier to implement. 58 | * Remember that this is a volunteer-driven project, and that your contributions 59 | are welcome! 60 | 61 | ##### A Brief Note On Licensing 62 | Often, python code for an algorithm of interest already exists. In the interest of avoiding repeated reinvention of the wheel, we welcome code from other sources being integrated into `pymdp`. If you are doing this, we ask that you be explicit and transparent about where the code came from and which license it is released under. The safest thing to do is copy the license from the original code into the header documentation of your file. For reference, this software is [licensed under MIT](https://github.com/tlarock/netrd/blob/master/LICENSE). 63 | 64 | ## Setup 65 | Before starting your contribution, you need to complete the following instructions once. 66 | The goal of this process is to fork, download and install the latest version of `pymdp`. 67 | 68 | 1. Log in to GitHub. 69 | 70 | 2. Fork this repository by pressing 'Fork' at the top right of this 71 | page. This will lead you to 'github.com//infer-actively'. We refer 72 | to this as your personal fork (or just 'your fork'), as opposed to this repository 73 | (github.com/infer-actively/pymdp), which we refer to as the 'upstream repository'. 74 | 75 | 3. 
Clone your fork to your machine by opening a console and doing 76 | 77 | ``` 78 | git clone https://github.com//infer-actively.git 79 | ``` 80 | 81 | Make sure to clone your fork, not the upstream repo. This will create a 82 | directory called 'infer-actively/'. Navigate to it and execute 83 | 84 | ``` 85 | git remote add upstream https://github.com/infer-actively/pymdp.git 86 | ``` 87 | 88 | In this way, your machine will know of both your fork (which git calls 89 | `origin`) and the upstream repository (`upstream`). 90 | 91 | 4. During development, you will probably want to play around with your 92 | code. For this, you need to install the `pymdp` package and have it 93 | reflect your changes as you go along. For this, open the console and 94 | navigate to the `infer-actively/` directory, and execute 95 | 96 | ``` 97 | pip install -e . 98 | ``` 99 | 100 | From now on, you can open a Jupyter notebook, ipython console, or your 101 | favorite IDE from anywhere in your computer and type `import pymdp`. 102 | 103 | 104 | These steps need to be taken only once. Now anything you do in the `infer-actively/` 105 | directory in your machine can be `push`ed into your fork. Once it is in 106 | your fork you can then request one of the organizers to `pull` from your 107 | fork into the upstream repository (by submitting a 'pull request'). More on this later! 108 | 109 | 110 | ## Before you start coding 111 | 112 | Once you have completed the above steps, you are ready to choose an algorithm to implement and begin coding. 113 | 114 | 1. Choose which algorithm you are interested in working on. 115 | 116 | 2. Open an issue at https://github.com/infer-actively/pymdp/issues by clicking the "New Issue" button. 117 | 118 | * Title the issue "Implement XYZ method", where XYZ method is a shorthand name for whatever function / method / environment class you plan to implement. 119 | * Leave a comment that includes a brief motivation for why you want to see this method in `pymdp`, as well as any key citations. 120 | * If such an issue already exists for the method you are going to write, it is not necessary to open another. However, it is a good idea to leave a comment letting others know you are going to work on it. 121 | 122 | 2. In your machine, create the file where your algorithm is going to 123 | live. If you chose a softmax algorithm, copy 124 | an existing file, such as `/pymdp/functions.py`, into 125 | `/pymdp/.py`. Please keep in mind that 126 | will be used inside the code, so try to choose 127 | something that looks "pythonic". In particular, cannot 128 | include spaces, should not include upper case letters, and should use underscores 129 | rather than hyphens. 130 | 131 | 3. Open the newly created file and edit as follows. At the very top you 132 | will find a string describing the algorithm. Edit this to describe the algorithm you 133 | are about to code, and preferably include a citation and link to any relevant papers. 134 | Also add your name and email address (optional). 135 | 136 | ## After you finish coding 137 | 138 | 1. After updating your local code, the first thing to do is tell git which files 139 | you have been working on. (This is called staging.) If you worked on a softmax 140 | function, for example, do 141 | 142 | ``` 143 | git add pymdp/.py 144 | ``` 145 | 146 | 2. Next tell git to commit (or save) your changes: 147 | 148 | ``` 149 | git commit -m 'Write a commit message here. This will be public and 150 | should be descriptive of the work you have done. 
Please be as explicit 151 | as possible, but at least make sure to include the name of the method 152 | you implemented. For example, the commit message may be: add 153 | implementation of SomeMethod, based on SomeAuthor and/or SomeCode.' 154 | ``` 155 | 156 | 3. Now you have to tell git to do two things. First, `pull` the latest changes from 157 | the upstream repository (in case someone made changes while you were coding), 158 | then `push` your changes and the updated code from your machine to your fork: 159 | 160 | ``` 161 | git pull upstream master 162 | git push origin master 163 | ``` 164 | 165 | NOTE: If you edited already existing files, the `pull` may result in 166 | conflicts that must be merged. If you run in to trouble here, ask 167 | for help! 168 | 169 | 4. Finally, you need to tell this (the upstream) repository to include your 170 | contributions. For this, we use the GitHub web interface. At the top of 171 | this page, there is a 'New Pull Request' button. Click on it, and it 172 | will take you to a page titled 'Compare Changes'. Right below the title, 173 | click on the blue text that reads 'compare across forks'. This will show 174 | four buttons. Make sure that the first button reads 'base fork: 175 | infer-actively/pymdp', the second button reads 'base: master', the third 176 | button reads 'head fork: /infer-actively', and the fourth button 177 | reads 'compare: master'. (If everything has gone according to plan, the 178 | only button you should have to change is the third one - make sure you 179 | find your username, not someone elses.) After you find your username, 180 | GitHub will show a rundown of the differences that you are adding to the 181 | upstream repository, so you will be able to see what changes you are 182 | contributing. If everything looks correct, press 'Create Pull 183 | Request'. 184 | NOTE: Advanced git users may want to develop on branches other 185 | than master on their fork. That is totally fine, we won't know 186 | the difference in the end anyway. 187 | 188 | 189 | That's it! After you've completed these steps, maintainers will be notified 190 | and will review your code and changes to make sure that everything is in place. 191 | Some automated tests will also run in the background to make sure that your 192 | code can be imported correctly and other sanity checks. Once that is all done, 193 | one of us will either accept your Pull Request, or leave a message requesting some 194 | changes (you will receive an email either way). 195 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Conor Heins and Alec Tschantz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

[image: pymdp logo — .github/pymdp_logo_2-removebg.png]

A Python package for simulating Active Inference agents in Markov Decision Process environments.
Please see our companion paper, published in the Journal of Open Source Software: ["pymdp: A Python library for active inference in discrete state spaces"](https://joss.theoj.org/papers/10.21105/joss.04098), for an overview of the package and its motivation. For a more in-depth, tutorial-style introduction to the package and a mathematical overview of active inference in Markov Decision Processes, see the [longer arXiv version](https://arxiv.org/abs/2201.03904) of the paper.

This package is hosted on the [`infer-actively`](https://github.com/infer-actively) GitHub organization, which was built with the intention of hosting open-source software related to active inference and the free energy principle.

Most of the low-level mathematical operations are [NumPy](https://github.com/numpy/numpy) ports of their equivalent functions from the `SPM` [implementation](https://www.fil.ion.ucl.ac.uk/spm/doc/) in MATLAB. We have benchmarked and validated most of these functions against their SPM counterparts.

## Status

![status](https://img.shields.io/badge/status-active-green)
![PyPI version](https://img.shields.io/pypi/v/inferactively-pymdp)
[![Documentation Status](https://readthedocs.org/projects/pymdp-rtd/badge/?version=latest)](https://pymdp-rtd.readthedocs.io/en/latest/?badge=latest)
[![DOI](https://joss.theoj.org/papers/10.21105/joss.04098/status.svg)](https://doi.org/10.21105/joss.04098)

# ``pymdp`` in action

Here's a visualization of ``pymdp`` agents in action. One of the defining features of active inference agents is the drive to maximize "epistemic value" (i.e. curiosity). Equipped with such a drive, in environments with uncertain yet disclosable hidden structure, active inference agents can simultaneously learn about the environment and maximize reward.

The simulation below (see the associated notebook [here](https://pymdp-rtd.readthedocs.io/en/latest/notebooks/cue_chaining_demo.html)) demonstrates what might be called "epistemic chaining," where an agent (here, analogized to a mouse seeking food) forages for a chain of cues, each of which discloses the location of the subsequent cue in the chain. The final cue (here, "Cue 2") reveals the location of a hidden reward. This is similar in spirit to the "behavior chaining" used in operant conditioning, except that here, each successive action in the behavioral sequence doesn't need to be learned through instrumental conditioning. Rather, active inference agents will naturally forage the sequence of cues based on an intrinsic desire to disclose information, ultimately reaching the hidden reward source in the fewest possible number of moves.

You can run the code behind simulating tasks like this one and others in the **Examples** section of the [official documentation](https://pymdp-rtd.readthedocs.io/en/stable/).
[Animated GIFs of the epistemic-chaining simulation, shown side by side — left: "Cue 2 in Location 1, Reward on Top" (.github/chained_cue_navigation_v1.gif); right: "Cue 2 in Location 3, Reward on Bottom" (.github/chained_cue_navigation_v2.gif)]
## Quick-start: Installation and Usage

In order to use `pymdp` to build and develop active inference agents, we recommend installing it with the package installer [`pip`](https://pip.pypa.io/en/stable/), which will install `pymdp` locally as well as its dependencies. This can also be done in a virtual environment (e.g. one created with `venv`).

When pip installing `pymdp`, use the package name `inferactively-pymdp`:

```bash
pip install inferactively-pymdp
```

Once in Python, you can then directly import `pymdp`, its sub-packages, and functions.

```python
import pymdp
from pymdp import utils
from pymdp.agent import Agent

num_obs = [3, 5] # observation modality dimensions
num_states = [3, 2, 2] # hidden state factor dimensions
num_controls = [3, 1, 1] # control state factor dimensions
A_matrix = utils.random_A_matrix(num_obs, num_states) # create sensory likelihood (A matrix)
B_matrix = utils.random_B_matrix(num_states, num_controls) # create transition likelihood (B matrix)

C_vector = utils.obj_array_uniform(num_obs) # uniform preferences

# instantiate a quick agent using your A, B and C arrays
my_agent = Agent(A=A_matrix, B=B_matrix, C=C_vector)

# give the agent a random observation and get the optimized posterior beliefs

observation = [1, 4] # a list specifying the index of the observation, for each observation modality

qs = my_agent.infer_states(observation) # get posterior over hidden states (a multi-factor belief)

# Do active inference

q_pi, neg_efe = my_agent.infer_policies() # return the posterior over policies and the negative expected free energy of each policy

action = my_agent.sample_action() # sample an action

# ... and so on ...
```

## Getting started / introductory material

We recommend starting with the Installation/Usage section of the [official documentation](https://pymdp-rtd.readthedocs.io/en/stable/), which provides a series of useful pedagogical notebooks that introduce active inference and how to build agents in `pymdp`.

For users new to `pymdp`, we specifically recommend stepping through the following three Jupyter notebooks (which can also be run on Google Colab):

- [`Pymdp` fundamentals](https://pymdp-rtd.readthedocs.io/en/latest/notebooks/pymdp_fundamentals.html)
- [Active Inference from Scratch](https://pymdp-rtd.readthedocs.io/en/latest/notebooks/active_inference_from_scratch.html)
- [The `Agent` API](https://pymdp-rtd.readthedocs.io/en/latest/notebooks/using_the_agent_class.html)

Special thanks to [Beren Millidge](https://github.com/BerenMillidge) and [Daphne Demekas](https://github.com/daphnedemekas) for their help in prototyping earlier versions of the [Active Inference from Scratch](https://pymdp-rtd.readthedocs.io/en/latest/notebooks/active_inference_from_scratch.html) tutorial, which were originally based on a grid world POMDP environment created by [Alec Tschantz](https://github.com/alec-tschantz).
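The Quick-start snippet runs a single cycle of perception and planning. In a full simulation, those calls sit inside a loop that alternates state inference, policy inference, and action selection at every timestep. Below is a minimal sketch of such an action-perception loop; it reuses the randomly-initialized agent from the Quick-start and, purely for illustration, feeds it randomly sampled observation indices in place of a real environment. The horizon `T` and the random observations are assumptions of this sketch — in an actual simulation, observations would come from an `Env` instance, as in the demos linked below.

```python
import numpy as np

from pymdp import utils
from pymdp.agent import Agent

# same generative model dimensions as in the Quick-start snippet above
num_obs = [3, 5]          # observation modality dimensions
num_states = [3, 2, 2]    # hidden state factor dimensions
num_controls = [3, 1, 1]  # control state factor dimensions

A_matrix = utils.random_A_matrix(num_obs, num_states)
B_matrix = utils.random_B_matrix(num_states, num_controls)
C_vector = utils.obj_array_uniform(num_obs)

my_agent = Agent(A=A_matrix, B=B_matrix, C=C_vector)

T = 10  # illustrative number of timesteps (an assumption of this sketch)
for t in range(T):
    # stand-in for a real environment: sample one random observation index per
    # modality (an Env subclass would instead return these from env.step(action))
    observation = [np.random.randint(o_dim) for o_dim in num_obs]

    qs = my_agent.infer_states(observation)    # perception: posterior over hidden states
    q_pi, neg_efe = my_agent.infer_policies()  # planning: posterior over policies
    action = my_agent.sample_action()          # action selection: one action per control factor
```

Swapping the random observations for an `Env` subclass (see the `pymdp.envs` module) turns this sketch into a complete agent-environment simulation loop.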
111 | 112 | We also have (and are continuing to build) a series of notebooks that walk through active inference agents performing different types of tasks, such as the classic [T-Maze environment](https://pymdp-rtd.readthedocs.io/en/latest/notebooks/tmaze_demo.html) and the newer [Epistemic Chaining](https://pymdp-rtd.readthedocs.io/en/latest/notebooks/cue_chaining_demo.html) demo. 113 | 114 | ## Contributing 115 | 116 | This package is under active development. If you would like to contribute, please refer to [this file](CONTRIBUTING.md) 117 | 118 | If you would like to contribute to this repo, we recommend using venv and pip 119 | ```bash 120 | cd 121 | python3 -m venv env 122 | source env/bin/activate 123 | pip install -r requirements.txt 124 | pip install -e ./ # This will install pymdp as a local dev package 125 | ``` 126 | 127 | You should then be able to run tests locally with `pytest` 128 | ```bash 129 | pytest test 130 | ``` 131 | 132 | ## Citing `pymdp` 133 | If you use `pymdp` in your work or research, please consider citing our [paper](https://joss.theoj.org/papers/10.21105/joss.04098) (open-access) published in the Journal of Open-Source Software: 134 | 135 | ``` 136 | @article{Heins2022, 137 | doi = {10.21105/joss.04098}, 138 | url = {https://doi.org/10.21105/joss.04098}, 139 | year = {2022}, 140 | publisher = {The Open Journal}, 141 | volume = {7}, 142 | number = {73}, 143 | pages = {4098}, 144 | author = {Conor Heins and Beren Millidge and Daphne Demekas and Brennan Klein and Karl Friston and Iain D. Couzin and Alexander Tschantz}, 145 | title = {pymdp: A Python library for active inference in discrete state spaces}, 146 | journal = {Journal of Open Source Software} 147 | } 148 | ``` 149 | 150 | For a more in-depth, tutorial-style introduction to the package and a mathematical overview of active inference in Markov Decision Processes, you can also consult the [longer arxiv version](https://arxiv.org/abs/2201.03904) of the paper. 151 | 152 | ## Authors 153 | 154 | - Conor Heins [@conorheins](https://github.com/conorheins) 155 | - Alec Tschantz [@alec-tschantz](https://github.com/alec-tschantz) 156 | - Beren Millidge [@BerenMillidge](https://github.com/BerenMillidge) 157 | - Brennan Klein [@jkbren](https://github.com/jkbren) 158 | - Arun Niranjan [@Arun-Niranjan](https://github.com/Arun-Niranjan) 159 | - Daphne Demekas [@daphnedemekas](https://github.com/daphnedemekas) 160 | - Aswin Paul [@aswinpaul](https://github.com/aswinpaul) 161 | - Tim Verbelen [@tverbele](https://github.com/tverbele) 162 | - Dimitrije Markovic [@dimarkov](https://github.com/dimarkov) 163 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/pymdp_logo_2-removebg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/docs/_static/pymdp_logo_2-removebg.png -------------------------------------------------------------------------------- /docs/agent.rst: -------------------------------------------------------------------------------- 1 | Agent class 2 | ================================= 3 | 4 | .. autoclass:: pymdp.agent.Agent 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/algos/fpi.rst: -------------------------------------------------------------------------------- 1 | FPI (Fixed Point Iteration) 2 | ================================= 3 | 4 | .. automodule:: pymdp.algos.fpi 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/algos/index.rst: -------------------------------------------------------------------------------- 1 | Algos 2 | ================================= 3 | 4 | The ``algos.py`` library contains the functions for implementing message passing algorithms for variational inference on POMDP generative models 5 | 6 | Sub-libraries 7 | --------------- 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | 12 | fpi 13 | mmp -------------------------------------------------------------------------------- /docs/algos/mmp.rst: -------------------------------------------------------------------------------- 1 | MMP (Marginal Message Passing) 2 | ================================= 3 | 4 | .. automodule:: pymdp.algos.mmp 5 | :members: -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | project = 'pymdp' 20 | copyright = '2021, infer-actively' 21 | author = 'infer-actively' 22 | 23 | # The full version, including alpha/beta/rc tags 24 | release = '0.0.7.1' 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 
31 | extensions = ['sphinx.ext.autodoc', 32 | 'sphinx.ext.doctest', 33 | 'sphinx.ext.coverage', 34 | 'sphinx.ext.napoleon', 35 | 'sphinx.ext.autosummary', 36 | 'myst_nb' 37 | ] 38 | 39 | source_suffix = { 40 | '.rst': 'restructuredtext', 41 | '.ipynb': 'myst-nb' 42 | } 43 | 44 | # Add any paths that contain templates here, relative to this directory. 45 | templates_path = ['_templates'] 46 | 47 | # List of patterns, relative to source directory, that match files and 48 | # directories to ignore when looking for source files. 49 | # This pattern also affects html_static_path and html_extra_path. 50 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 51 | 52 | 53 | # -- Options for HTML output ------------------------------------------------- 54 | 55 | # The theme to use for HTML and HTML Help pages. See the documentation for 56 | # a list of builtin themes. 57 | # 58 | html_theme = 'sphinx_rtd_theme' 59 | 60 | # Theme options are theme-specific and customize the look and feel of a theme 61 | # further. For a list of options available for each theme, see the 62 | # documentation. 63 | html_theme_options = { 64 | 'logo_only': True, 65 | } 66 | 67 | # The name of an image file (relative to this directory) to place at the top 68 | # of the sidebar. 69 | html_logo = '_static/pymdp_logo_2-removebg.png' 70 | 71 | html_favicon = '_static/pymdp_logo_2-removebg.png' 72 | 73 | # Add any paths that contain custom static files (such as style sheets) here, 74 | # relative to this directory. They are copied after the builtin static files, 75 | # so a file named "default.css" will overwrite the builtin "default.css". 76 | html_static_path = ['_static'] 77 | 78 | 79 | # -- Options for myst ---------------------------------------------- 80 | jupyter_execute_notebooks = "cache" 81 | jupyter_cache = "notebooks" 82 | -------------------------------------------------------------------------------- /docs/control.rst: -------------------------------------------------------------------------------- 1 | Control 2 | ================================= 3 | 4 | The ``control.py`` module contains the functions for performing inference of policies (sequences of control states) in POMDP generative models, 5 | according to active inference. 6 | 7 | .. automodule:: pymdp.control 8 | :members: -------------------------------------------------------------------------------- /docs/env.rst: -------------------------------------------------------------------------------- 1 | Env 2 | ======== 3 | 4 | The OpenAIGym-inspired ``Env`` base class is the main API that represents the environmental dynamics or "generative process" with 5 | which agents exchange observations and actions 6 | 7 | Base class 8 | ---------- 9 | .. autoclass:: pymdp.envs.Env 10 | 11 | Specific environment implementations 12 | ---------- 13 | 14 | All of the following dynamics inherit from ``Env`` and have the 15 | same general usage as above. 16 | 17 | .. autosummary:: 18 | :nosignatures: 19 | 20 | pymdp.envs.GridWorldEnv 21 | pymdp.envs.DGridWorldEnv 22 | pymdp.envs.VisualForagingEnv 23 | pymdp.envs.TMazeEnv 24 | pymdp.envs.TMazeEnvNullOutcome 25 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. pymdp documentation master file, created by 2 | sphinx-quickstart on Fri Oct 29 13:27:58 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | Welcome to pymdp's documentation! 7 | ================================= 8 | 9 | ``pymdp`` is a Python package for simulating active inference agents in 10 | discrete space and time, using partially-observed Markov Decision Processes 11 | (POMDPs) as a generative model class. The package is designed to be modular and flexible, to 12 | enable users to design and simulate bespoke active inference models with varying levels of 13 | specificity to a given task. 14 | 15 | For a theoretical overview of active inference and the motivations for developing this package, 16 | please see our companion paper_: "pymdp: A Python library for active inference in discrete state spaces". 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: Installation & Usage 21 | 22 | installation 23 | notebooks/pymdp_fundamentals 24 | notebooks/active_inference_from_scratch 25 | notebooks/using_the_agent_class 26 | 27 | .. toctree:: 28 | :maxdepth: 1 29 | :caption: Examples 30 | 31 | notebooks/tmaze_demo 32 | notebooks/cue_chaining_demo 33 | 34 | .. toctree:: 35 | :maxdepth: 2 36 | :caption: Modules 37 | 38 | inference 39 | control 40 | learning 41 | algos/index 42 | 43 | .. toctree:: 44 | :maxdepth: 2 45 | :caption: Agent and environment API 46 | 47 | agent 48 | env 49 | 50 | .. toctree:: 51 | :maxdepth: 1 52 | :caption: Additional learning materials 53 | 54 | notebooks/free_energy_calculation 55 | 56 | Indices and tables 57 | ================== 58 | 59 | * :ref:`genindex` 60 | * :ref:`modindex` 61 | * :ref:`search` 62 | 63 | .. _paper: https://joss.theoj.org/papers/10.21105/joss.04098 64 | -------------------------------------------------------------------------------- /docs/inference.rst: -------------------------------------------------------------------------------- 1 | Inference 2 | ================================= 3 | 4 | The ``inference.py`` module contains the functions for performing inference of discrete hidden states (categorical distributions) in POMDP generative models. 5 | 6 | .. automodule:: pymdp.inference 7 | :members: -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ================================= 3 | 4 | We recommend installing ``pymdp`` using the package installer pip_, which will install the package locally as well as its dependencies. 5 | This can also be done in a virtual environment (e.g. one created using ``venv`` or ``conda``). 6 | 7 | When pip installing ``pymdp``, use the full package name: ``inferactively-pymdp``: 8 | 9 | .. code-block:: console 10 | 11 | (.venv) $ pip install inferactively-pymdp 12 | 13 | .. _pip: https://pip.pypa.io/en/stable/ -------------------------------------------------------------------------------- /docs/learning.rst: -------------------------------------------------------------------------------- 1 | Learning 2 | ================================= 3 | 4 | The ``learning.py`` module contains the functions for updating parameters of Dirichlet posteriors (that paramaterise categorical priors and likelihoods) in POMDP generative models. 5 | 6 | .. 
automodule:: pymdp.learning 7 | :members: -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # sphinx <4 required by myst-nb v0.12.0 (Feb 2021) 2 | # sphinx >=3 required by sphinx-autodoc-typehints v1.11.1 (Oct 2020) 3 | # sphinx >=3, <4 # old version of sphinx dependency, based on two comments above 4 | 5 | # New commit to myst-nb (September 2021) suggests it can work with sphinx >4 (see https://github.com/executablebooks/MyST-NB/commit/9f6fae2da3f1ce3e726320eade079b8e25a878fa) 6 | sphinx==4.2.0 7 | sphinx_rtd_theme 8 | sphinx-autodoc-typehints==1.11.1 9 | jupyter-sphinx>=0.3.2 10 | myst-nb 11 | jinja2==3.0.0 12 | 13 | # Packages used for notebook execution 14 | matplotlib 15 | numpy 16 | seaborn 17 | . -------------------------------------------------------------------------------- /examples/A_matrix_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Generative Model Demo: Constructing a simple likelihood model \n", 8 | "This demo notebook provides a walk-through of how to build a simple A matrix (or likelihood mapping) that encodes an aegnt's beliefs about how hidden states 'cause' or probabilistically relate to observations" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "### Imports\n", 16 | "\n", 17 | "First, import `pymdp` and the modules we'll need." 
18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import os\n", 27 | "import sys\n", 28 | "import pathlib\n", 29 | "\n", 30 | "import numpy as np\n", 31 | "import itertools\n", 32 | "import pandas as pd\n", 33 | "\n", 34 | "path = pathlib.Path(os.getcwd())\n", 35 | "module_path = str(path.parent) + '/'\n", 36 | "sys.path.append(module_path)\n", 37 | "\n", 38 | "import pymdp.utils as utils\n", 39 | "from pymdp.utils import create_A_matrix_stub, read_A_matrix\n", 40 | "from pymdp.algos import run_vanilla_fpi" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## The world (as represented by the agent's generative model)\n", 48 | "\n", 49 | "### Hidden states\n", 50 | "\n", 51 | "We assume the agent's \"represents\" (this should make you think: generative _model_ , not _process_ ) its environment using two latent variables that are statistically independent of one another - we can thus represent them using two _hidden state factors._\n", 52 | "\n", 53 | "We refer to these two hidden state factors are `DID_IT_RAIN` and `WAS_SPRINKLER_ON`. \n", 54 | "\n", 55 | "#### 1. `DID_IT_RAIN`\n", 56 | "The first factor is a binary variable representing whether or not it rained earlier today.\n", 57 | "\n", 58 | "#### 2. `WAS_SPRINKLER_ON`\n", 59 | "\n", 60 | "The second factor is a binary variable representing whether or not the sprinkler was on or off earlier today.\n", 61 | "\n", 62 | "### Observations\n", 63 | "\n", 64 | "The agent believes that these two hidden states probabilistically relate to two observation modalities, i.e. two independent 'sensory channels', which we can call `GRASS_OBSERVATION` and `WEATHER_OBSERVATION`. \n", 65 | "\n", 66 | "#### 1. `GRASS_OBSERVATION`\n", 67 | "The first modality is a binary variable representing the agent's observation (e.g. via vision, for instance) of the grass being wet or being dry.\n", 68 | "\n", 69 | "#### 2. `WEATHER_OBSERVATION`\n", 70 | "\n", 71 | "The second modality is a ternary (3-valued) variable representing the agent's observation of the state of the weather, e.g. by looking at the sky. In this example, it can either look `clear`, `rainy`, or `cloudy`\n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "model_labels = {\n", 81 | " \"observations\": {\n", 82 | " \"grass_observation\": [\n", 83 | " \"wet\",\n", 84 | " \"dry\" \n", 85 | " ],\n", 86 | " \"weather_observation\": [\n", 87 | " \"clear\",\n", 88 | " \"rainy\",\n", 89 | " \"cloudy\"\n", 90 | " ]\n", 91 | " },\n", 92 | " \"states\": {\n", 93 | " \"did_it_rain\": [\"rained\", \"did_not_rain\"],\n", 94 | " \"was_sprinkler_on\": [\"on\", \"off\"],\n", 95 | " },\n", 96 | " }\n", 97 | "\n", 98 | "num_obs, _, n_states, n_factors = utils.get_model_dimensions_from_labels(model_labels)\n", 99 | "\n", 100 | "read_from_excel = True\n", 101 | "pre_specified_excel = True\n", 102 | "\n", 103 | "A_stub = create_A_matrix_stub(model_labels)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "### Option 1. Write the empty A matrix stub to an excel file, fill it out separately (e.g. manually in excel, and then read it back into memory). Remember, these represent the agent's generative model, not the true probabilities that relate states to observations. 
So you can think of these as the agent's personal/subjective 'assumptions' about how hidden states relate to observations." 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "if read_from_excel:\n", 120 | " ## Option 1: fill out A matrix 'offline' (e.g. in an excel spreadsheet)\n", 121 | "\n", 122 | " excel_dir = 'tmp_dir'\n", 123 | " if not os.path.exists(excel_dir):\n", 124 | " os.mkdir(excel_dir)\n", 125 | "\n", 126 | " excel_path = os.path.join(excel_dir, 'my_a_matrix.xlsx')\n", 127 | "\n", 128 | " if not pre_specified_excel:\n", 129 | " A_stub.to_excel(excel_path)\n", 130 | " print(f'Go fill out the A matrix in {excel_path} and then continue running this code\\n')" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "After you've filled out the Excel sheet separately (e.g. opening up Microsoft Excel and filling out the cells, you can read it back into memory)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "if read_from_excel:\n", 147 | " A_stub = read_A_matrix(excel_path, n_factors)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### Option 2. Fill out the A matrix using the desired probabilities. Remember, these represent the agent's generative model, not the true probabilities that relate states to observations. So you can think of these as the agent's personal/subjective 'assumptions' about how hidden states relate to observations." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "if not read_from_excel:\n", 164 | " A_stub.loc[('grass_observation','wet'),('rained', 'on')] = 1.0\n", 165 | "\n", 166 | " A_stub.loc[('grass_observation','wet'),('rained', 'off')] = 0.7\n", 167 | " A_stub.loc[('grass_observation','dry'),('rained', 'off')] = 0.3\n", 168 | "\n", 169 | " A_stub.loc[('grass_observation','wet'),('did_not_rain', 'on')] = 0.5\n", 170 | " A_stub.loc[('grass_observation','dry'),('did_not_rain', 'on')] = 0.5\n", 171 | "\n", 172 | " A_stub.loc[('grass_observation','dry'),('did_not_rain', 'off')] = 1.0\n", 173 | "\n", 174 | " A_stub.loc['weather_observation','rained'] = np.tile(np.array([0.1, 0.65, 0.25]).reshape(-1,1), (1,2)) \n", 175 | "\n", 176 | " A_stub.loc[('weather_observation'),('did_not_rain')] = np.tile(np.array([0.9, 0.05, 0.05]).reshape(-1,1), (1,2)) \n" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "### Now we can use a utility function `convert_stub_to_ndarray` to convert the human-readable A matrix into the multi-dimensional tensor form needed by `pymdp` to achieve things like inference and action selection" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "A = utils.convert_A_stub_to_ndarray(A_stub, model_labels)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "## Sample a random observation" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "obs_idx = [np.random.randint(o_dim) for o_dim in num_obs]\n", 209 | "# obs_idx = [0, 1] # wet and rainy\n", 210 | 
"\n", 211 | "observation = utils.obj_array_zeros(num_obs)\n", 212 | "\n", 213 | "for g, modality_name in enumerate(model_labels['observations'].keys()):\n", 214 | " observation[g][obs_idx[g]] = 1.0\n", 215 | " print('%s: %s'%(modality_name, model_labels['observations'][modality_name][obs_idx[g]]))" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "## Given the observation and your A matrix, perform inference to optimize a simple posterior belief about the state of the world " 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "qs = run_vanilla_fpi(A, observation, num_obs, n_states, prior=None, num_iter=10, dF=1.0, dF_tol=0.001)\n", 232 | "\n", 233 | "print('Belief that it rained: %.2f'%(qs[0][0]))\n", 234 | "print('Belief that the sprinkler was on: %.2f'%(qs[1][0]))" 235 | ] 236 | } 237 | ], 238 | "metadata": { 239 | "interpreter": { 240 | "hash": "43ee964e2ad3601b7244370fb08e7f23a81bd2f0e3c87ee41227da88c57ff102" 241 | }, 242 | "kernelspec": { 243 | "display_name": "Python 3.7.10 64-bit ('pymdp_env': conda)", 244 | "name": "python3" 245 | }, 246 | "language_info": { 247 | "codemirror_mode": { 248 | "name": "ipython", 249 | "version": 3 250 | }, 251 | "file_extension": ".py", 252 | "mimetype": "text/x-python", 253 | "name": "python", 254 | "nbconvert_exporter": "python", 255 | "pygments_lexer": "ipython3", 256 | "version": "3.7.10" 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 2 261 | } 262 | -------------------------------------------------------------------------------- /examples/A_matrix_demo.py: -------------------------------------------------------------------------------- 1 | # %% This notebook is supposed to be stepped through cell-by-cell, like a jupyter notebook 2 | 3 | import os 4 | import sys 5 | import pathlib 6 | 7 | import numpy as np 8 | import itertools 9 | import pandas as pd 10 | 11 | path = pathlib.Path(os.getcwd()) 12 | module_path = str(path.parent) + '/' 13 | sys.path.append(module_path) 14 | 15 | from pymdp import utils 16 | from pymdp.utils import create_A_matrix_stub, read_A_matrix 17 | from pymdp.algos import run_vanilla_fpi 18 | 19 | # %% Create an empty A matrix 20 | model_labels = { 21 | "observations": { 22 | "grass_observation": [ 23 | "wet", 24 | "dry" 25 | ], 26 | "weather_observation": [ 27 | "clear", 28 | "rainy", 29 | "cloudy" 30 | ] 31 | }, 32 | "states": { 33 | "did_it_rain": ["rained", "did_not_rain"], 34 | "was_sprinkler_on": ["on", "off"], 35 | }, 36 | } 37 | 38 | num_obs, _, n_states, n_factors = utils.get_model_dimensions_from_labels(model_labels) 39 | 40 | read_from_excel = True 41 | pre_specified_excel = True 42 | 43 | A_stub = create_A_matrix_stub(model_labels) 44 | 45 | if read_from_excel: 46 | ## Option 1: fill out A matrix 'offline' (e.g. 
in an excel spreadsheet) 47 | 48 | excel_dir = 'examples/tmp_dir' 49 | if not os.path.exists(excel_dir): 50 | os.mkdir(excel_dir) 51 | 52 | excel_path = os.path.join(excel_dir, 'my_a_matrix.xlsx') 53 | 54 | if not pre_specified_excel: 55 | A_stub.to_excel(excel_path) 56 | print(f'Go fill out the A matrix in {excel_path} and then continue running this code\n') 57 | 58 | if read_from_excel: 59 | 60 | A_stub = read_A_matrix(excel_path, n_factors) 61 | 62 | if not read_from_excel: 63 | ## Option 2: fill out the A matrix here in Python, using our knowledge of the dependencies in the system and pandas multindexing assignments 64 | 65 | A_stub.loc[('grass_observation','wet'),('rained', 'on')] = 1.0 66 | 67 | A_stub.loc[('grass_observation','wet'),('rained', 'off')] = 0.7 68 | A_stub.loc[('grass_observation','dry'),('rained', 'off')] = 0.3 69 | 70 | A_stub.loc[('grass_observation','wet'),('did_not_rain', 'on')] = 0.5 71 | A_stub.loc[('grass_observation','dry'),('did_not_rain', 'on')] = 0.5 72 | 73 | A_stub.loc[('grass_observation','dry'),('did_not_rain', 'off')] = 1.0 74 | 75 | A_stub.loc['weather_observation','rained'] = np.tile(np.array([0.1, 0.65, 0.25]).reshape(-1,1), (1,2)) 76 | 77 | A_stub.loc[('weather_observation'),('did_not_rain')] = np.tile(np.array([0.9, 0.05, 0.05]).reshape(-1,1), (1,2)) 78 | 79 | # %% now convert the A matrix into a sequence of appopriately shaped numpy arrays 80 | 81 | A = utils.convert_A_stub_to_ndarray(A_stub, model_labels) 82 | 83 | obs_idx = [np.random.randint(o_dim) for o_dim in num_obs] 84 | # obs_idx = [0, 1] # wet and rainy 85 | 86 | observation = utils.obj_array_zeros(num_obs) 87 | 88 | for g, modality_name in enumerate(model_labels['observations'].keys()): 89 | observation[g][obs_idx[g]] = 1.0 90 | print('%s: %s'%(modality_name, model_labels['observations'][modality_name][obs_idx[g]])) 91 | 92 | qs = run_vanilla_fpi(A, observation, num_obs, n_states, prior=None, num_iter=10, dF=1.0, dF_tol=0.001) 93 | 94 | print('Belief that it rained: %.2f'%(qs[0][0])) 95 | print('Belief that the sprinkler was on: %.2f'%(qs[1][0])) 96 | 97 | # %% 98 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/examples/__init__.py -------------------------------------------------------------------------------- /examples/agent_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pymdp.agent import Agent 3 | from pymdp import utils 4 | from pymdp.maths import softmax 5 | import copy 6 | 7 | obs_names = ["state_observation", "reward", "decision_proprioceptive"] 8 | state_names = ["reward_level", "decision_state"] 9 | action_names = ["uncontrolled", "decision_state"] 10 | 11 | num_obs = [3, 3, 3] 12 | num_states = [2, 3] 13 | num_modalities = len(num_obs) 14 | num_factors = len(num_states) 15 | 16 | A = utils.obj_array_zeros([[o] + num_states for _, o in enumerate(num_obs)]) 17 | 18 | A[0][:, :, 0] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0] 19 | A[0][:, :, 1] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0] 20 | A[0][:, :, 2] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]]) 21 | 22 | A[1][2, :, 0] = np.ones(num_states[0]) 23 | A[1][0:2, :, 1] = softmax(np.eye(num_obs[1] - 1)) # bandit statistics (mapping between reward-state (first hidden state factor) and rewards (Good vs 
Bad)) 24 | A[1][2, :, 2] = np.ones(num_states[0]) 25 | 26 | # establish a proprioceptive mapping that determines how the agent perceives its own `decision_state` 27 | A[2][0,:,0] = 1.0 28 | A[2][1,:,1] = 1.0 29 | A[2][2,:,2] = 1.0 30 | 31 | control_fac_idx = [1] 32 | B = utils.obj_array(num_factors) 33 | for f, ns in enumerate(num_states): 34 | B[f] = np.eye(ns) 35 | if f in control_fac_idx: 36 | B[f] = B[f].reshape(ns, ns, 1) 37 | B[f] = np.tile(B[f], (1, 1, ns)) 38 | B[f] = B[f].transpose(1, 2, 0) 39 | else: 40 | B[f] = B[f].reshape(ns, ns, 1) 41 | 42 | C = utils.obj_array_zeros(num_obs) 43 | C[1][0] = 1.0 # put a 'reward' over first observation 44 | C[1][1] = -2.0 # put a 'punishment' over first observation 45 | # this implies that C[1][2] is 'neutral' 46 | 47 | agent = Agent(A=A, B=B, C=C, control_fac_idx=[1]) 48 | 49 | # initial state 50 | T = 5 51 | o = [2, 2, 0] 52 | s = [0, 0] 53 | 54 | # transition/observation matrices characterising the generative process 55 | A_gp = copy.deepcopy(A) 56 | B_gp = copy.deepcopy(B) 57 | 58 | for t in range(T): 59 | 60 | for g in range(num_modalities): 61 | print(f"{t}: Observation {obs_names[g]}: {o[g]}") 62 | 63 | qx = agent.infer_states(o) 64 | 65 | for f in range(num_factors): 66 | print(f"{t}: Beliefs about {state_names[f]}: {qx[f]}") 67 | 68 | agent.infer_policies() 69 | action = agent.sample_action() 70 | 71 | for f, s_i in enumerate(s): 72 | s[f] = utils.sample(B_gp[f][:, s_i, int(action[f])]) 73 | 74 | for g, _ in enumerate(o): 75 | o[g] = utils.sample(A_gp[g][:, s[0], s[1]]) 76 | 77 | print(np.argmax(s)) 78 | print(f"{t}: Action: {action} / State: {s}") 79 | -------------------------------------------------------------------------------- /examples/building_up_agent_loop.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import jax.numpy as jnp\n", 10 | "import jax.tree_util as jtu\n", 11 | "from jax import random as jr\n", 12 | "from pymdp.jax.agent import Agent as AIFAgent\n", 13 | "from pymdp.utils import random_A_matrix, random_B_matrix" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "(2, 10, 5, 4)\n", 26 | "[1 1]\n", 27 | "(10, 3, 3, 3)\n", 28 | "(10, 3, 3, 2)\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "def scan(f, init, xs, length=None, axis=0):\n", 34 | " if xs is None:\n", 35 | " xs = [None] * length\n", 36 | " carry = init\n", 37 | " ys = []\n", 38 | " for x in xs:\n", 39 | " carry, y = f(carry, x)\n", 40 | " if y is not None:\n", 41 | " ys.append(y)\n", 42 | " \n", 43 | " ys = None if len(ys) < 1 else jtu.tree_map(lambda *x: jnp.stack(x,axis=axis), *ys)\n", 44 | "\n", 45 | " return carry, ys\n", 46 | "\n", 47 | "def evolve_trials(agent, env, block_idx, num_timesteps, prng_key=jr.PRNGKey(0)):\n", 48 | "\n", 49 | " batch_keys = jr.split(prng_key, batch_size)\n", 50 | " def step_fn(carry, xs):\n", 51 | " actions = carry['actions']\n", 52 | " outcomes = carry['outcomes']\n", 53 | " beliefs = agent.infer_states(outcomes, actions, *carry['args'])\n", 54 | " q_pi, _ = agent.infer_policies(beliefs)\n", 55 | " actions_t = agent.sample_action(q_pi, rng_key=batch_keys)\n", 56 | "\n", 57 | " outcome_t = env.step(actions_t)\n", 58 | " outcomes = jtu.tree_map(\n", 59 | " lambda prev_o, new_o: jnp.concatenate([prev_o, 
jnp.expand_dims(new_o, -1)], -1), outcomes, outcome_t\n", 60 | " )\n", 61 | "\n", 62 | " if actions is not None:\n", 63 | " actions = jnp.concatenate([actions, jnp.expand_dims(actions_t, -2)], -2)\n", 64 | " else:\n", 65 | " actions = jnp.expand_dims(actions_t, -2)\n", 66 | "\n", 67 | " args = agent.update_empirical_prior(actions_t, beliefs)\n", 68 | "\n", 69 | " ### @ NOTE !!!!: Shape of policy_probs = (num_blocks, num_trials, batch_size, num_policies) if scan axis = 0, but size of `actions` will \n", 70 | " ### be (num_blocks, batch_size, num_trials, num_controls) -- so we need to 1) swap axes on both so that the first three dimensions are aligned,\n", 71 | " # 2) use the action indices (the integers stored in the last dimension of `actions`) to index into the policy_probs array\n", 72 | " \n", 73 | " # args = (pred_{t+1}, [post_1, post_{2}, ..., post_{t}])\n", 74 | " # beliefs = [post_1, post_{2}, ..., post_{t}]\n", 75 | " return {'args': args, 'outcomes': outcomes, 'beliefs': beliefs, 'actions': actions}, {'policy_probs': q_pi}\n", 76 | "\n", 77 | " \n", 78 | " outcome_0 = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), env.step())\n", 79 | " # qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, -2), agent.D) # add a time dimension to the initial state prior\n", 80 | " init = {\n", 81 | " 'args': (agent.D, None,),\n", 82 | " 'outcomes': outcome_0, \n", 83 | " 'beliefs': [],\n", 84 | " 'actions': None\n", 85 | " }\n", 86 | " last, q_pis_ = scan(step_fn, init, range(num_timesteps), axis=1)\n", 87 | "\n", 88 | " return last, q_pis_, env\n", 89 | "\n", 90 | "def step_fn(carry, block_idx):\n", 91 | " agent, env = carry\n", 92 | " output, q_pis_, env = evolve_trials(agent, env, block_idx, num_timesteps)\n", 93 | " args = output.pop('args')\n", 94 | " output['beliefs'] = agent.infer_states(output['outcomes'], output['actions'], *args)\n", 95 | " output.update(q_pis_)\n", 96 | "\n", 97 | " # How to deal with contiguous blocks of trials? 
Two options we can imagine: \n", 98 | " # A) you use final posterior (over current and past timesteps) to compute the smoothing distribution over qs_{t=0} and update pD, and then pass pD as the initial state prior ($D = \\mathbb{E}_{pD}[qs_{t=0}]$);\n", 99 | " # B) we don't assume that blocks 'reset time', and are really just adjacent chunks of one long sequence, so you set the initial state prior to be the final output (`output['beliefs']`) passed through\n", 100 | " # the transition model entailed by the action taken at the last timestep of the previous block.\n", 101 | " # print(output['beliefs'].shape)\n", 102 | " agent = agent.learning(**output)\n", 103 | " \n", 104 | " return (agent, env), output\n", 105 | "\n", 106 | "# define an agent and environment here\n", 107 | "batch_size = 10\n", 108 | "num_obs = [3, 3]\n", 109 | "num_states = [3, 3]\n", 110 | "num_controls = [2, 2]\n", 111 | "num_blocks = 2\n", 112 | "num_timesteps = 5\n", 113 | "\n", 114 | "A_np = random_A_matrix(num_obs=num_obs, num_states=num_states)\n", 115 | "B_np = random_B_matrix(num_states=num_states, num_controls=num_controls)\n", 116 | "A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))\n", 117 | "B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))\n", 118 | "C = [jnp.zeros((batch_size, no)) for no in num_obs]\n", 119 | "D = [jnp.ones((batch_size, ns)) / ns for ns in num_states]\n", 120 | "E = jnp.ones((batch_size, 4 )) / 4 \n", 121 | "\n", 122 | "pA = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(A_np))\n", 123 | "pB = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), list(B_np))\n", 124 | "\n", 125 | "class TestEnv:\n", 126 | " def __init__(self, num_obs, prng_key=jr.PRNGKey(0)):\n", 127 | " self.num_obs=num_obs\n", 128 | " self.key = prng_key\n", 129 | " def step(self, actions=None):\n", 130 | " # return a list of random observations for each agent or parallel realization (each entry in batch_dim)\n", 131 | " obs = [jr.randint(self.key, (batch_size,), 0, no) for no in self.num_obs]\n", 132 | " self.key, _ = jr.split(self.key)\n", 133 | " return obs\n", 134 | "\n", 135 | "agents = AIFAgent(A, B, C, D, E, pA, pB, use_param_info_gain=True, use_inductive=False, inference_algo='fpi', sampling_mode='marginal', action_selection='stochastic')\n", 136 | "env = TestEnv(num_obs)\n", 137 | "init = (agents, env)\n", 138 | "(agents, env), sequences = scan(step_fn, init, range(num_blocks) )\n", 139 | "print(sequences['policy_probs'].shape)\n", 140 | "print(sequences['actions'][0][0][0])\n", 141 | "print(agents.A[0].shape)\n", 142 | "print(agents.B[0].shape)\n", 143 | "# def loss_fn(agents):\n", 144 | "# env = TestEnv(num_obs)\n", 145 | "# init = (agents, env)\n", 146 | "# (agents, env), sequences = scan(step_fn, init, range(num_blocks)) \n", 147 | "\n", 148 | "# return jnp.sum(jnp.log(sequences['policy_probs']))\n", 149 | "\n", 150 | "# dLoss_dAgents = jax.grad(loss_fn)(agents)\n", 151 | "# print(dLoss_dAgents.A[0].shape)\n", 152 | "\n", 153 | "\n", 154 | "# sequences = jtu.tree_map(lambda x: x.swapaxes(1, 2), sequences)\n", 155 | "\n", 156 | "# NOTE: all elements of sequences will have dimensionality blocks, trials, batch_size, ...\n" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "jax_pymdp_test", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | 
"file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.11.6" 177 | }, 178 | "orig_nbformat": 4 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 2 182 | } 183 | -------------------------------------------------------------------------------- /examples/inductive_inference_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from pymdp.jax import control\n", 17 | "import jax.numpy as jnp\n", 18 | "import jax.tree_util as jtu\n", 19 | "from jax import nn, vmap, random, lax\n", 20 | "\n", 21 | "from typing import List, Optional\n", 22 | "from jaxtyping import Array\n", 23 | "from jax import random as jr" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### Set up generative model (random one with trivial observation model)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "# Set up a generative model\n", 40 | "num_states = [5, 3]\n", 41 | "num_controls = [2, 2]\n", 42 | "\n", 43 | "# make some arbitrary policies (policy depth 3, 2 control factors)\n", 44 | "policy_1 = jnp.array([[0, 1],\n", 45 | " [1, 1],\n", 46 | " [0, 0]])\n", 47 | "policy_2 = jnp.array([[1, 0],\n", 48 | " [0, 0],\n", 49 | " [1, 1]])\n", 50 | "policy_matrix = jnp.stack([policy_1, policy_2]) \n", 51 | "\n", 52 | "# observation modalities (isomorphic/identical to hidden states, just need to include for the need to include likleihood model)\n", 53 | "num_obs = [5, 3]\n", 54 | "num_factors = len(num_states)\n", 55 | "num_modalities = len(num_obs)\n", 56 | "\n", 57 | "# sample parameters of the model (A, B, C)\n", 58 | "key = jr.PRNGKey(1)\n", 59 | "factor_keys = jr.split(key, num_factors)\n", 60 | "\n", 61 | "d = [0.1* jr.uniform(factor_key, (ns,)) for factor_key, ns in zip(factor_keys, num_states)]\n", 62 | "qs_init = [jr.dirichlet(factor_key, d_f) for factor_key, d_f in zip(factor_keys, d)]\n", 63 | "A = [jnp.eye(no) for no in num_obs]\n", 64 | "\n", 65 | "factor_keys = jr.split(factor_keys[-1], num_factors)\n", 66 | "b = [jr.uniform(factor_keys[f], shape=(num_controls[f], num_states[f], num_states[f])) for f in range(num_factors)]\n", 67 | "b_sparse = [jnp.where(b_f < 0.75, 1e-5, b_f) for b_f in b]\n", 68 | "B = [jnp.swapaxes(jr.dirichlet(factor_keys[f], b_sparse[f]), 2, 0) for f in range(num_factors)]\n", 69 | "\n", 70 | "modality_keys = jr.split(factor_keys[-1], num_modalities)\n", 71 | "C = [nn.one_hot(jr.randint(modality_keys[m], shape=(1,), minval=0, maxval=num_obs[m]), num_obs[m]) for m in range(num_modalities)]\n", 72 | "\n", 73 | "# trivial dependencies -- factor 1 drives modality 1, etc.\n", 74 | "A_dependencies = [[0], [1]]\n", 75 | "B_dependencies = [[0], [1]]" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### Generate sparse constraints vectors `H` and inductive matrix `I`, using inductive parameters like depth and threshold " 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# generate random constraints (H vector)\n", 92 | 
"factor_keys = jr.split(key, num_factors)\n", 93 | "H = [jr.uniform(factor_key, (ns,)) for factor_key, ns in zip(factor_keys, num_states)]\n", 94 | "H = [jnp.where(h < 0.75, 0., 1.) for h in H]\n", 95 | "\n", 96 | "# depth and threshold for inductive planning algorithm. I made policy-depth equal to inductive planning depth, out of ignorance -- need to ask Tim or Tommaso about this\n", 97 | "inductive_depth, inductive_threshold = 3, 0.5\n", 98 | "I = control.generate_I_matrix(H, B, inductive_threshold, inductive_depth)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "### Evaluate posterior probability of policies and negative EFE using new version of `update_posterior_policies`\n", 106 | "#### This function no longer computes info gain (for both states and parameters) since deterministic model is assumed, and includes new inductive matrix `I` and `inductive_epsilon` parameter" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 7, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# evaluate Q(pi) and negative EFE using the inductive planning algorithm\n", 116 | "\n", 117 | "E = jnp.ones(policy_matrix.shape[0])\n", 118 | "pA = jtu.tree_map(lambda a: jnp.ones_like(a), A)\n", 119 | "pB = jtu.tree_map(lambda b: jnp.ones_like(b), B)\n", 120 | "\n", 121 | "q_pi, neg_efe = control.update_posterior_policies_inductive(policy_matrix, qs_init, A, B, C, E, pA, pB, A_dependencies, B_dependencies, I, gamma=16.0, use_utility=True, use_inductive=True, inductive_epsilon=1e-3)" 122 | ] 123 | } 124 | ], 125 | "metadata": { 126 | "kernelspec": { 127 | "display_name": "atari_env", 128 | "language": "python", 129 | "name": "python3" 130 | }, 131 | "language_info": { 132 | "codemirror_mode": { 133 | "name": "ipython", 134 | "version": 3 135 | }, 136 | "file_extension": ".py", 137 | "mimetype": "text/x-python", 138 | "name": "python", 139 | "nbconvert_exporter": "python", 140 | "pygments_lexer": "ipython3", 141 | "version": "3.11.7" 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 2 146 | } 147 | -------------------------------------------------------------------------------- /examples/inductive_inference_gridworld.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import jax.numpy as jnp\n", 17 | "import jax.tree_util as jtu\n", 18 | "from jax import nn, vmap, random, lax\n", 19 | "from typing import List, Optional\n", 20 | "from jaxtyping import Array\n", 21 | "from jax import random as jr\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import numpy as np\n", 24 | "\n", 25 | "from pymdp.envs import GridWorldEnv\n", 26 | "from pymdp.jax import control as j_control\n", 27 | "from pymdp.jax.agent import Agent as AIFAgent\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "### Grid world generative model" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "num_rows, num_columns = 7, 7\n", 44 | "num_states = [num_rows*num_columns] # number of states equals the number of grid locations\n", 45 | "num_obs = [num_rows*num_columns] # number of observations equals the number of grid locations (fully 
observable)\n", 46 | "\n", 47 | "# number of agents\n", 48 | "n_batches = 5\n", 49 | "\n", 50 | "# construct A arrays\n", 51 | "A = [jnp.broadcast_to(jnp.eye(num_states[0]), (n_batches,) + (num_obs[0], num_states[0]))] # fully observable (identity observation matrix\n", 52 | "\n", 53 | "# construct B arrays\n", 54 | "grid_world = GridWorldEnv(shape=[num_rows, num_columns])\n", 55 | "B = [jnp.broadcast_to(jnp.array(grid_world.get_transition_dist()), (n_batches,) + (num_states[0], num_states[0], grid_world.n_control))] # easy way to get the generative model parameters is to extract them from one of pre-made GridWorldEnv classes\n", 56 | "num_controls = [grid_world.n_control] # number of control states equals the number of actions\n", 57 | " \n", 58 | "# create mapping from gridworld coordinates to linearly-index states\n", 59 | "grid = np.arange(grid_world.n_states).reshape(grid_world.shape)\n", 60 | "it = np.nditer(grid, flags=[\"multi_index\"])\n", 61 | "coord_to_idx_map = {}\n", 62 | "while not it.finished:\n", 63 | " coord_to_idx_map[it.multi_index] = it.iterindex\n", 64 | " it.iternext()\n", 65 | "\n", 66 | "# construct C arrays\n", 67 | "desired_position = (6,6) # lower corner\n", 68 | "desired_state_id = coord_to_idx_map[desired_position]\n", 69 | "desired_obs_id = jnp.argmax(A[0][:, desired_state_id]) # throw this in there, in case there is some indeterminism between states and observations\n", 70 | "C = [jnp.broadcast_to(nn.one_hot(desired_obs_id, num_obs[0]), (n_batches, num_obs[0]))]\n", 71 | "\n", 72 | "# construct D arrays\n", 73 | "starting_position = (3, 3) # middle\n", 74 | "# starting_position = (0, 0) # upper left corner\n", 75 | "starting_state_id = coord_to_idx_map[starting_position]\n", 76 | "starting_obs_id = jnp.argmax(A[0][:, starting_state_id]) # throw this in there, in case there is some indeterminism between states and observations\n", 77 | "D = [jnp.broadcast_to(nn.one_hot(starting_state_id, num_states[0]), (n_batches, num_states[0]))]" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "### Planning parameters" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "planning_horizon, inductive_threshold = 1, 0.1\n", 94 | "inductive_depth = 7\n", 95 | "policy_matrix = j_control.construct_policies(num_states, num_controls, policy_len=planning_horizon)\n", 96 | "\n", 97 | "# inductive planning goal states\n", 98 | "H = [jnp.broadcast_to(nn.one_hot(desired_state_id, num_states[0]), (n_batches, num_states[0]))] # list of factor-specific goal vectors (shape of each is (n_batches, num_states[f]))" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "### Initialize an `Agent()`" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 4, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# create agent\n", 115 | "agent = AIFAgent(A, B, C, D, E=None, pA=None, pB=None, policies=policy_matrix, policy_len=planning_horizon, \n", 116 | " inductive_depth=inductive_depth, inductive_threshold=inductive_threshold,\n", 117 | " H=H, use_utility=True, use_states_info_gain=False, use_param_info_gain=False, use_inductive=True)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "### Run active inference" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "metadata": {}, 131 | "outputs": [ 132 
| { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "Grid position for agent 2 at time 0: (3, 3)\n", 137 | "Grid position for agent 2 at time 1: (3, 4)\n", 138 | "Grid position for agent 2 at time 2: (3, 5)\n", 139 | "Grid position for agent 2 at time 3: (3, 6)\n", 140 | "Grid position for agent 2 at time 4: (4, 6)\n", 141 | "Grid position for agent 2 at time 5: (5, 6)\n", 142 | "Grid position for agent 2 at time 6: (6, 6)\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "# T = 14 # needed if you start further away from the goal (e.g. in upper left corner)\n", 148 | "T = 7 # can get away with fewer timesteps if you start closer to the goal (e.g. in the middle)\n", 149 | "\n", 150 | "qs_init = [jnp.broadcast_to(nn.one_hot(starting_state_id, num_states[0]), (n_batches, num_states[0]))] # same as D\n", 151 | "obs_idx = [jnp.broadcast_to(starting_obs_id, (n_batches,))] # list of len (num_modalities), each list element of shape (n_batches,)\n", 152 | "obs_idx = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), obs_idx) # list of len (num_modalities), elements each of shape (n_batches,1), this adds a trivial \"time dimension\"\n", 153 | "\n", 154 | "state = jnp.broadcast_to(starting_state_id, (n_batches,))\n", 155 | "infer_args = (agent.D, None,)\n", 156 | "batch_keys = jr.split(jr.PRNGKey(0), n_batches)\n", 157 | "batch_to_track = 1\n", 158 | "\n", 159 | "for t in range(T):\n", 160 | "\n", 161 | " print('Grid position for agent {} at time {}: {}'.format(batch_to_track+1, t, np.unravel_index(state[batch_to_track], grid_world.shape)))\n", 162 | "\n", 163 | " if t == 0:\n", 164 | " actions = None\n", 165 | " else:\n", 166 | " actions = actions_t\n", 167 | " beliefs = agent.infer_states(obs_idx, empirical_prior=infer_args[0], past_actions=actions, qs_hist=infer_args[1])\n", 168 | " q_pi, _ = agent.infer_policies(beliefs)\n", 169 | " actions_t = agent.sample_action(q_pi, rng_key=batch_keys)\n", 170 | " infer_args = agent.update_empirical_prior(actions_t, beliefs)\n", 171 | "\n", 172 | " # get next state and observation from the grid world (need to vmap everything over batches)\n", 173 | " state = vmap(lambda b, s, a: jnp.argmax(b[:, s, a]), in_axes=(0,0,0))(B[0], state, actions_t)\n", 174 | " next_obs = vmap(lambda a, s: jnp.argmax(a[:, s]), in_axes=(0,0))(A[0], state)\n", 175 | " obs_idx = [next_obs]\n", 176 | " obs_idx = jtu.tree_map(lambda x: jnp.expand_dims(x, -1), obs_idx) # add a trivial time dimension to the observation to enable indexing during agent.infer_states\n", 177 | "\n" 178 | ] 179 | } 180 | ], 181 | "metadata": { 182 | "kernelspec": { 183 | "display_name": "atari_env", 184 | "language": "python", 185 | "name": "python3" 186 | }, 187 | "language_info": { 188 | "codemirror_mode": { 189 | "name": "ipython", 190 | "version": 3 191 | }, 192 | "file_extension": ".py", 193 | "mimetype": "text/x-python", 194 | "name": "python", 195 | "nbconvert_exporter": "python", 196 | "pygments_lexer": "ipython3", 197 | "version": "3.11.7" 198 | } 199 | }, 200 | "nbformat": 4, 201 | "nbformat_minor": 2 202 | } 203 | -------------------------------------------------------------------------------- /examples/tmp_dir/my_a_matrix.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/examples/tmp_dir/my_a_matrix.xlsx -------------------------------------------------------------------------------- /pymdp/__init__.py: 
-------------------------------------------------------------------------------- 1 | from . import agent 2 | from . import envs 3 | from . import utils 4 | from . import maths 5 | from . import control 6 | from . import inference 7 | from . import learning 8 | from . import algos 9 | from . import default_models 10 | from . import jax 11 | -------------------------------------------------------------------------------- /pymdp/algos/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpi import run_vanilla_fpi, run_vanilla_fpi_factorized 2 | from .mmp import run_mmp, run_mmp_factorized, _run_mmp_testing 3 | -------------------------------------------------------------------------------- /pymdp/default_models.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from pymdp import utils, maths 4 | 5 | def generate_epistemic_MAB_model(): 6 | ''' 7 | Create the generative model matrices (A, B, C, D) for the 'epistemic multi-armed bandit', 8 | used in the `agent_demo.py` Python file and the `agent_demo.ipynb` notebook. 9 | ''' 10 | 11 | num_states = [2, 3] 12 | num_obs = [3, 3, 3] 13 | num_controls = [1, 3] 14 | A = utils.obj_array_zeros([[o] + num_states for _, o in enumerate(num_obs)]) 15 | 16 | """ 17 | MODALITY 0 -- INFORMATION-ABOUT-REWARD-STATE MODALITY 18 | """ 19 | 20 | A[0][:, :, 0] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0] 21 | A[0][:, :, 1] = np.ones( (num_obs[0], num_states[0]) ) / num_obs[0] 22 | A[0][:, :, 2] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]]) 23 | 24 | """ 25 | MODALITY 1 -- REWARD MODALITY 26 | """ 27 | 28 | A[1][2, :, 0] = np.ones(num_states[0]) 29 | A[1][0:2, :, 1] = maths.softmax(np.eye(num_obs[1] - 1)) # bandit statistics (mapping between reward-state (first hidden state factor) and rewards (Good vs Bad)) 30 | A[1][2, :, 2] = np.ones(num_states[0]) 31 | 32 | """ 33 | MODALITY 2 -- LOCATION-OBSERVATION MODALITY 34 | """ 35 | A[2][0,:,0] = 1.0 36 | A[2][1,:,1] = 1.0 37 | A[2][2,:,2] = 1.0 38 | 39 | control_fac_idx = [1] # this is the controllable control state factor, where there will be a >1-dimensional control state along this factor 40 | B = utils.obj_array_zeros([[n_s, n_s, num_controls[f]] for f, n_s in enumerate(num_states)]) 41 | 42 | """ 43 | FACTOR 0 -- REWARD STATE DYNAMICS 44 | """ 45 | 46 | p_stoch = 0.0 47 | 48 | # we cannot influence factor zero, set up the 'default' stationary dynamics - 49 | # one state just maps to itself at the next timestep with very high probability, by default. So this means the reward state can 50 | # change from one to another with some low probability (p_stoch) 51 | 52 | B[0][0, 0, 0] = 1.0 - p_stoch 53 | B[0][1, 0, 0] = p_stoch 54 | 55 | B[0][1, 1, 0] = 1.0 - p_stoch 56 | B[0][0, 1, 0] = p_stoch 57 | 58 | """ 59 | FACTOR 1 -- CONTROLLABLE LOCATION DYNAMICS 60 | """ 61 | # setup our controllable factor. 
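# construct_controllable_B follows the pymdp convention B[f][s_next, s, u] = P(s_next | s, u), with each
# action u moving any current state deterministically to state u. A minimal sketch of the equivalent
# construction by hand (assuming, as here, num_states[1] == num_controls[1] == 3; `B1` is just an illustrative name):
#
#     B1 = np.zeros((3, 3, 3))
#     for u in range(3):
#         B1[u, :, u] = 1.0   # P(s_next = u | s, u) = 1, for every current state s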
62 | B[1] = utils.construct_controllable_B(num_states, num_controls)[1] 63 | 64 | C = utils.obj_array_zeros(num_obs) 65 | C[1][0] = 1.0 # make the observation we've a priori named `REWARD` actually desirable, by building a high prior expectation of encountering it 66 | C[1][1] = -1.0 # make the observation we've a priori named `PUN` actually aversive, by building a low prior expectation of encountering it 67 | 68 | control_fac_idx = [1] 69 | 70 | return A, B, C, control_fac_idx 71 | 72 | def generate_grid_world_transitions(action_labels, num_rows = 3, num_cols = 3): 73 | """ 74 | Wrapper code for creating the controllable transition matrix 75 | that an agent can use to navigate in a 2-dimensional grid world 76 | """ 77 | 78 | num_grid_locs = num_rows * num_cols 79 | 80 | transition_matrix = np.zeros( (num_grid_locs, num_grid_locs, len(action_labels)) ) 81 | 82 | grid = np.arange(num_grid_locs).reshape(num_rows, num_cols) 83 | it = np.nditer(grid, flags=["multi_index"]) 84 | 85 | loc_list = [] 86 | while not it.finished: 87 | loc_list.append(it.multi_index) 88 | it.iternext() 89 | 90 | for action_id, action_label in enumerate(action_labels): 91 | 92 | for curr_state, grid_location in enumerate(loc_list): 93 | 94 | curr_row, curr_col = grid_location 95 | 96 | if action_label == "LEFT": 97 | next_col = curr_col - 1 if curr_col > 0 else curr_col 98 | next_row = curr_row 99 | elif action_label == "DOWN": 100 | next_row = curr_row + 1 if curr_row < (num_rows-1) else curr_row 101 | next_col = curr_col 102 | elif action_label == "RIGHT": 103 | next_col = curr_col + 1 if curr_col < (num_cols-1) else curr_col 104 | next_row = curr_row 105 | elif action_label == "UP": 106 | next_row = curr_row - 1 if curr_row > 0 else curr_row 107 | next_col = curr_col 108 | elif action_label == "STAY": 109 | next_row, next_col = curr_row, curr_col 110 | 111 | new_location = (next_row, next_col) 112 | next_state = loc_list.index(new_location) 113 | transition_matrix[next_state, curr_state, action_id] = 1.0 114 | 115 | return transition_matrix 116 | -------------------------------------------------------------------------------- /pymdp/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .env import Env 2 | from .grid_worlds import GridWorldEnv, DGridWorldEnv 3 | from .visual_foraging import VisualForagingEnv, SceneConstruction, RandomDotMotion, initialize_scene_construction_GM, initialize_RDM_GM 4 | from .tmaze import TMazeEnv, TMazeEnvNullOutcome 5 | -------------------------------------------------------------------------------- /pymdp/envs/env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Environment Base Class 5 | 6 | __author__: Conor Heins, Alexander Tschantz, Brennan Klein 7 | 8 | """ 9 | 10 | 11 | class Env(object): 12 | """ 13 | The Env base class, loosely inspired by the analogous ``env`` class of the OpenAI Gym framework.
14 | 15 | A typical workflow is as follows: 16 | 17 | >>> my_env = MyCustomEnv() 18 | >>> initial_observation = my_env.reset(initial_state) 19 | >>> my_agent.infer_states(initial_observation) 20 | >>> my_agent.infer_policies() 21 | >>> next_action = my_agent.sample_action() 22 | >>> next_observation = my_env.step(next_action) 23 | 24 | This would be the first step of an active inference process, where a sub-class of ``Env``, ``MyCustomEnv``, is initialized, 25 | an initial observation is produced, and this observation is fed into an instance of ``Agent`` to produce an action, 26 | which can then be fed back into the ``Env`` instance. 27 | 28 | """ 29 | 30 | def reset(self, state=None): 31 | """ 32 | Resets the initial state of the environment. Depending on the implementation, it may also return an initial observation. 33 | """ 34 | raise NotImplementedError 35 | 36 | def step(self, action): 37 | """ 38 | Steps the environment forward using an action. 39 | 40 | Parameters 41 | ---------- 42 | action 43 | The action, the type/format of which depends on the implementation. 44 | 45 | Returns 46 | --------- 47 | observation 48 | Sensory observations for an agent, the type/format of which depends on the implementation of ``step`` and the observation space of the agent. 49 | """ 50 | raise NotImplementedError 51 | 52 | def render(self): 53 | """ 54 | Rendering function that typically creates a visual representation of the state of the environment at the current timestep. 55 | """ 56 | pass 57 | 58 | def sample_action(self): 59 | pass 60 | 61 | def get_likelihood_dist(self): 62 | raise ValueError( 63 | "<{}> does not provide a model specification".format(type(self).__name__) 64 | ) 65 | 66 | def get_transition_dist(self): 67 | raise ValueError( 68 | "<{}> does not provide a model specification".format(type(self).__name__) 69 | ) 70 | 71 | def get_uniform_posterior(self): 72 | raise ValueError( 73 | "<{}> does not provide a model specification".format(type(self).__name__) 74 | ) 75 | 76 | def get_rand_likelihood_dist(self): 77 | raise ValueError( 78 | "<{}> does not provide a model specification".format(type(self).__name__) 79 | ) 80 | 81 | def get_rand_transition_dist(self): 82 | raise ValueError( 83 | "<{}> does not provide a model specification".format(type(self).__name__) 84 | ) 85 | 86 | def __str__(self): 87 | return "<{} instance>".format(type(self).__name__) 88 | -------------------------------------------------------------------------------- /pymdp/jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/pymdp/jax/__init__.py -------------------------------------------------------------------------------- /pymdp/jax/inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # pylint: disable=no-member 4 | 5 | import jax.numpy as jnp 6 | from .algos import run_factorized_fpi, run_mmp, run_vmp 7 | from jax import tree_util as jtu, lax 8 | from jax.experimental.sparse._base import JAXSparse 9 | from jax.experimental import sparse 10 | from jaxtyping import Array, ArrayLike 11 | 12 | eps = jnp.finfo('float').eps 13 | 14 | def update_posterior_states( 15 | A, 16 | B, 17 | obs, 18 | past_actions, 19 | prior=None, 20 | qs_hist=None, 21 | A_dependencies=None, 22 | B_dependencies=None, 23 | num_iter=16, 24 | method='fpi' 25 | ): 26 | 27 | if method == 
'fpi' or method == "ovf": 28 | # format obs to select only last observation 29 | curr_obs = jtu.tree_map(lambda x: x[-1], obs) 30 | qs = run_factorized_fpi(A, curr_obs, prior, A_dependencies, num_iter=num_iter) 31 | else: 32 | # format B matrices using action sequences here 33 | # TODO: past_actions can be None 34 | if past_actions is not None: 35 | nf = len(B) 36 | actions_tree = [past_actions[:, i] for i in range(nf)] 37 | 38 | # move time steps to the leading axis (leftmost) 39 | # this assumes that a policy is always specified as the rightmost axis of Bs 40 | B = jtu.tree_map(lambda b, a_idx: jnp.moveaxis(b[..., a_idx], -1, 0), B, actions_tree) 41 | else: 42 | B = None 43 | 44 | # outputs of both VMP and MMP should be a list of hidden state factors, where each qs[f].shape = (T, batch_dim, num_states_f) 45 | if method == 'vmp': 46 | qs = run_vmp(A, B, obs, prior, A_dependencies, B_dependencies, num_iter=num_iter) 47 | if method == 'mmp': 48 | qs = run_mmp(A, B, obs, prior, A_dependencies, B_dependencies, num_iter=num_iter) 49 | 50 | if qs_hist is not None: 51 | if method == 'fpi' or method == "ovf": 52 | qs_hist = jtu.tree_map(lambda x, y: jnp.concatenate([x, jnp.expand_dims(y, 0)], 0), qs_hist, qs) 53 | else: 54 | #TODO: return entire history of beliefs 55 | qs_hist = qs 56 | else: 57 | if method == 'fpi' or method == "ovf": 58 | qs_hist = jtu.tree_map(lambda x: jnp.expand_dims(x, 0), qs) 59 | else: 60 | qs_hist = qs 61 | 62 | return qs_hist 63 | 64 | def joint_dist_factor(b: ArrayLike, filtered_qs: list[Array], actions: Array): 65 | qs_last = filtered_qs[-1] 66 | qs_filter = filtered_qs[:-1] 67 | 68 | def step_fn(qs_smooth, xs): 69 | qs_f, action = xs 70 | time_b = b[..., action] 71 | qs_j = time_b * qs_f 72 | norm = qs_j.sum(-1, keepdims=True) 73 | if isinstance(norm, JAXSparse): 74 | norm = sparse.todense(norm) 75 | norm = jnp.where(norm == 0, eps, norm) 76 | qs_backward_cond = qs_j / norm 77 | qs_joint = qs_backward_cond * jnp.expand_dims(qs_smooth, -1) 78 | qs_smooth = qs_joint.sum(-2) 79 | if isinstance(qs_smooth, JAXSparse): 80 | qs_smooth = sparse.todense(qs_smooth) 81 | 82 | # returns q(s_t), (q(s_t), q(s_t, s_t+1)) 83 | return qs_smooth, (qs_smooth, qs_joint) 84 | 85 | # seq_qs will contain a sequence of smoothed marginals and joints 86 | _, seq_qs = lax.scan( 87 | step_fn, 88 | qs_last, 89 | (qs_filter, actions), 90 | reverse=True, 91 | unroll=2 92 | ) 93 | 94 | # we add the last filtered belief to smoothed beliefs 95 | 96 | qs_smooth_all = jnp.concatenate([seq_qs[0], jnp.expand_dims(qs_last, 0)], 0) 97 | qs_joint_all = seq_qs[1] 98 | if isinstance(qs_joint_all, JAXSparse): 99 | qs_joint_all.shape = (len(actions),) + qs_joint_all.shape 100 | return qs_smooth_all, qs_joint_all 101 | 102 | 103 | def smoothing_ovf(filtered_post, B, past_actions): 104 | assert len(filtered_post) == len(B) 105 | nf = len(B) # number of factors 106 | 107 | joint = lambda b, qs, f: joint_dist_factor(b, qs, past_actions[..., f]) 108 | 109 | marginals_and_joints = ([], []) 110 | for b, qs, f in zip(B, filtered_post, list(range(nf))): 111 | marginals, joints = joint(b, qs, f) 112 | marginals_and_joints[0].append(marginals) 113 | marginals_and_joints[1].append(joints) 114 | 115 | return marginals_and_joints 116 | 117 | 118 | -------------------------------------------------------------------------------- /pymdp/jax/likelihoods.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpyro.distributions as dist 3 | from jax import lax 
4 | from numpyro import plate, sample, deterministic 5 | from numpyro.contrib.control_flow import scan 6 | 7 | def evolve_trials(agent, data): 8 | 9 | def step_fn(carry, xs): 10 | empirical_prior = carry 11 | outcomes = xs['outcomes'] 12 | qs = agent.infer_states(outcomes, empirical_prior) 13 | q_pi, _ = agent.infer_policies(qs) 14 | 15 | probs = agent.action_probabilities(q_pi) 16 | 17 | actions = xs['actions'] 18 | empirical_prior = agent.update_empirical_prior(actions, qs) 19 | #TODO: if outcomes and actions are None, generate samples 20 | return empirical_prior, (probs, outcomes, actions) 21 | 22 | prior = agent.D 23 | _, res = lax.scan(step_fn, prior, data) 24 | 25 | return res 26 | 27 | def aif_likelihood(Nb, Nt, Na, data, agent): 28 | # Na -> batch dimension - number of different subjects/agents 29 | # Nb -> number of experimental blocks 30 | # Nt -> number of trials within each block 31 | 32 | def step_fn(carry, xs): 33 | probs, outcomes, actions = evolve_trials(agent, xs) 34 | 35 | deterministic('outcomes', outcomes) 36 | 37 | with plate('num_agents', Na): 38 | with plate('num_trials', Nt): 39 | sample('actions', dist.Categorical(logits=probs).to_event(1), obs=actions) 40 | 41 | return None, None 42 | 43 | # TODO: See if some information has to be passed from one block to the next and change init and carry accordingly 44 | init = None 45 | scan(step_fn, init, data, length=Nb) -------------------------------------------------------------------------------- /pymdp/jax/maths.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from functools import partial 4 | from typing import Optional, Tuple, List 5 | from jax import tree_util, nn, jit, vmap, lax 6 | from jax.scipy.special import xlogy 7 | from opt_einsum import contract 8 | 9 | MINVAL = jnp.finfo(float).eps 10 | 11 | def stable_xlogx(x): 12 | return xlogy(x, jnp.clip(x, MINVAL)) 13 | 14 | def stable_entropy(x): 15 | return - stable_xlogx(x).sum() 16 | 17 | def stable_cross_entropy(x, y): 18 | return - xlogy(x, y).sum() 19 | 20 | def log_stable(x): 21 | return jnp.log(jnp.clip(x, min=MINVAL)) 22 | 23 | @partial(jit, static_argnames=['keep_dims']) 24 | def factor_dot(M, xs, keep_dims: Optional[Tuple[int]] = None): 25 | """ Dot product of a multidimensional array `M` with a list of vectors `xs`. 26 | 27 | Parameters 28 | ---------- 29 | - `M` [jnp.ndarray] - the tensor to contract; `xs` [list of jnp.ndarray] - the vectors to contract it with 30 | 31 | Returns 32 | ------- 33 | - `Y` [jnp.ndarray] - the result of the dot product 34 | """ 35 | d = len(keep_dims) if keep_dims is not None else 0 36 | assert M.ndim == len(xs) + d 37 | keep_dims = () if keep_dims is None else keep_dims 38 | dims = tuple((i,) for i in range(M.ndim) if i not in keep_dims) 39 | return factor_dot_flex(M, xs, dims, keep_dims=keep_dims) 40 | 41 | @partial(jit, static_argnames=['dims', 'keep_dims']) 42 | def factor_dot_flex(M, xs, dims: List[Tuple[int]], keep_dims: Optional[Tuple[int]] = None): 43 | """ Dot product of a multidimensional array `M` with a list of arrays `xs`.
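For example, contracting a likelihood array `M` of shape `(num_obs, ns1, ns2)` with posterior vectors `xs = [qs1, qs2]` via `dims = ((1,), (2,))` and `keep_dims = (0,)` sums out both hidden state dimensions and returns a vector of length `num_obs` -- this is the pattern that `factor_dot` above reduces to.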
44 | 45 | Parameters 46 | ---------- 47 | - `M` [numpy.ndarray] - tensor 48 | - 'xs' [list of numpy.ndarray] - list of tensors 49 | - 'dims' [list of tuples] - list of dimensions of xs tensors in tensor M 50 | - 'keep_dims' [tuple] - tuple of integers denoting dimensions to keep 51 | Returns 52 | ------- 53 | - `Y` [1D numpy.ndarray] - the result of the dot product 54 | """ 55 | all_dims = tuple(range(M.ndim)) 56 | matrix = [[xs[f], dims[f]] for f in range(len(xs))] 57 | args = [M, all_dims] 58 | for row in matrix: 59 | args.extend(row) 60 | 61 | args += [keep_dims] 62 | return contract(*args, backend='jax') 63 | 64 | def get_likelihood_single_modality(o_m, A_m, distr_obs=True): 65 | """Return observation likelihood for a single observation modality m""" 66 | if distr_obs: 67 | expanded_obs = jnp.expand_dims(o_m, tuple(range(1, A_m.ndim))) 68 | likelihood = (expanded_obs * A_m).sum(axis=0) 69 | else: 70 | likelihood = A_m[o_m] 71 | 72 | return likelihood 73 | 74 | def compute_log_likelihood_single_modality(o_m, A_m, distr_obs=True): 75 | """Compute observation log-likelihood for a single modality""" 76 | return log_stable(get_likelihood_single_modality(o_m, A_m, distr_obs=distr_obs)) 77 | 78 | def compute_log_likelihood(obs, A, distr_obs=True): 79 | """ Compute likelihood over hidden states across observations from different modalities """ 80 | result = tree_util.tree_map(lambda o, a: compute_log_likelihood_single_modality(o, a, distr_obs=distr_obs), obs, A) 81 | ll = jnp.sum(jnp.stack(result), 0) 82 | 83 | return ll 84 | 85 | def compute_log_likelihood_per_modality(obs, A, distr_obs=True): 86 | """ Compute likelihood over hidden states across observations from different modalities, and return them per modality """ 87 | ll_all = tree_util.tree_map(lambda o, a: compute_log_likelihood_single_modality(o, a, distr_obs=distr_obs), obs, A) 88 | 89 | return ll_all 90 | 91 | def compute_accuracy(qs, obs, A): 92 | """ Compute the accuracy portion of the variational free energy (expected log likelihood under the variational posterior) """ 93 | 94 | log_likelihood = compute_log_likelihood(obs, A) 95 | 96 | x = qs[0] 97 | for q in qs[1:]: 98 | x = jnp.expand_dims(x, -1) * q 99 | 100 | joint = log_likelihood * x 101 | return joint.sum() 102 | 103 | def compute_free_energy(qs, prior, obs, A): 104 | """ 105 | Calculate variational free energy by breaking its computation down into three steps: 106 | 1. computation of the negative entropy of the posterior -H[Q(s)] 107 | 2. computation of the cross entropy of the posterior with the prior H_{Q(s)}[P(s)] 108 | 3. computation of the accuracy E_{Q(s)}[lnP(o|s)] 109 | 110 | Then sum the first two terms and subtract the accuracy 111 | """ 112 | 113 | vfe = 0.0 # initialize variational free energy 114 | for q, p in zip(qs, prior): 115 | negH_qs = - stable_entropy(q) 116 | xH_qp = stable_cross_entropy(q, p) 117 | vfe += (negH_qs + xH_qp) 118 | 119 | vfe -= compute_accuracy(qs, obs, A) 120 | 121 | return vfe 122 | 123 | def multidimensional_outer(arrs): 124 | """ Compute the outer product of a list of arrays by iteratively expanding the first array and multiplying it with the next array """ 125 | 126 | x = arrs[0] 127 | for q in arrs[1:]: 128 | x = jnp.expand_dims(x, -1) * q 129 | 130 | return x 131 | 132 | def spm_wnorm(A): 133 | """ 134 | Returns Expectation of logarithm of Dirichlet parameters over a set of 135 | Categorical distributions, stored in the columns of A. 136 | """ 137 | norm = 1. / A.sum(axis=0) 138 | avg = 1. / (A + MINVAL)
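    # wA = (1 / column sums) - (1 / A) is largest in magnitude where the Dirichlet counts in A are
    # smallest, so downstream expected-free-energy terms can use it to score the novelty of
    # observations (parameter information gain)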
139 | wA = norm - avg 140 | return wA 141 | 142 | def dirichlet_expected_value(dir_arr): 143 | """ 144 | Returns Expectation of Dirichlet parameters over a set of 145 | Categorical distributions, stored in the columns of dir_arr. 146 | """ 147 | dir_arr = jnp.clip(dir_arr, min=MINVAL) 148 | expected_val = jnp.divide(dir_arr, dir_arr.sum(axis=0, keepdims=True)) 149 | return expected_val 150 | 151 | if __name__ == '__main__': 152 | obs = [0, 1, 2] 153 | obs_vec = [ nn.one_hot(o, 3) for o in obs] 154 | A = [jnp.ones((3, 2)) / 3] * 3 155 | res = jit(compute_log_likelihood)(obs_vec, A) 156 | 157 | print(res) -------------------------------------------------------------------------------- /pymdp/jax/task.py: -------------------------------------------------------------------------------- 1 | # Task environment 2 | from typing import Optional, List, Dict 3 | from jaxtyping import Array, PRNGKeyArray 4 | from functools import partial 5 | 6 | from equinox import Module, field, tree_at 7 | from jax import vmap, random as jr, tree_util as jtu 8 | import jax.numpy as jnp 9 | 10 | def select_probs(positions, matrix, dependency_list, actions=None): 11 | args = tuple(p for i, p in enumerate(positions) if i in dependency_list) 12 | args += () if actions is None else (actions,) 13 | 14 | return matrix[..., *args] 15 | 16 | def cat_sample(key, p): 17 | a = jnp.arange(p.shape[-1]) 18 | if p.ndim > 1: 19 | choice = lambda key, p: jr.choice(key, a, p=p) 20 | keys = jr.split(key, len(p)) 21 | return vmap(choice)(keys, p) 22 | 23 | return jr.choice(key, a, p=p) 24 | 25 | class PyMDPEnv(Module): 26 | params: Dict 27 | states: List[List[Array]] 28 | dependencies: Dict = field(static=True) 29 | 30 | def __init__( 31 | self, params: Dict, dependencies: Dict, init_state: List[Array] = None 32 | ): 33 | self.params = params 34 | self.dependencies = dependencies 35 | 36 | if init_state is None: 37 | init_state = jtu.tree_map(lambda x: jnp.argmax(x, -1), self.params["D"]) 38 | 39 | self.states = [init_state] 40 | 41 | def reset(self, key: Optional[PRNGKeyArray] = None): 42 | if key is None: 43 | states = [self.states[0]] 44 | else: 45 | probs = self.params["D"] 46 | keys = list(jr.split(key, len(probs))) 47 | states = [jtu.tree_map(cat_sample, keys, probs)] 48 | 49 | return tree_at(lambda x: x.states, self, states) 50 | 51 | @vmap 52 | def step(self, key: PRNGKeyArray, actions: Optional[Array] = None): 53 | # return a list of random observations and states 54 | key_state, key_obs = jr.split(key) 55 | states = self.states 56 | if actions is not None: 57 | actions = list(actions) 58 | _select_probs = partial(select_probs, states[-1]) 59 | state_probs = jtu.tree_map( 60 | _select_probs, self.params["B"], self.dependencies["B"], actions 61 | ) 62 | 63 | keys = list(jr.split(key_state, len(state_probs))) 64 | new_states = jtu.tree_map(cat_sample, keys, state_probs) 65 | else: 66 | new_states = states[-1] 67 | 68 | _select_probs = partial(select_probs, new_states) 69 | obs_probs = jtu.tree_map( 70 | _select_probs, self.params["A"], self.dependencies["A"] 71 | ) 72 | 73 | keys = list(jr.split(key_obs, len(obs_probs))) 74 | new_obs = jtu.tree_map(cat_sample, keys, obs_probs) 75 | 76 | return new_obs, tree_at(lambda x: (x.states), self, [new_states]) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs>=20.3.0 2 | cycler>=0.10.0 3 | iniconfig>=1.1.1 4 | kiwisolver>=1.3.1 5 |
matplotlib>=3.1.3 6 | nose>=1.3.7 7 | numpy>=1.19.5 8 | openpyxl>=3.0.7 9 | packaging>=20.8 10 | Pillow>=8.2.0 11 | pluggy>=0.13.1 12 | py>=1.10.0 13 | pyparsing>=2.4.7 14 | pytest>=6.2.1 15 | python-dateutil>=2.8.1 16 | pytz>=2020.5 17 | scipy>=1.6.0 18 | seaborn>=0.11.1 19 | six>=1.15.0 20 | toml>=0.10.2 21 | typing-extensions>=3.7.4.3 22 | xlsxwriter>=1.4.3 23 | sphinx-rtd-theme>=0.4 24 | myst-nb>=0.13.1 25 | autograd>=1.3 26 | jax>=0.3.4 27 | jaxlib>=0.3.4 28 | equinox>=0.9 29 | numpyro>=0.1 30 | arviz>=0.13 31 | optax>=0.1 32 | multimethod>=1.11 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="inferactively-pymdp", 8 | version="0.0.7.1", 9 | author="infer-actively", 10 | author_email="conor.heins@gmail.com", 11 | description= ("A Python package for solving Markov Decision Processes with Active Inference"), 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | license='MIT', 15 | url="https://github.com/infer-actively/pymdp", 16 | python_requires='>=3.7', 17 | install_requires =[ 18 | 'attrs>=20.3.0', 19 | 'cycler>=0.10.0', 20 | 'iniconfig>=1.1.1', 21 | 'kiwisolver>=1.3.1', 22 | 'matplotlib>=3.1.3', 23 | 'nose>=1.3.7', 24 | 'numpy>=1.19.5', 25 | 'openpyxl>=3.0.7', 26 | 'packaging>=20.8', 27 | 'pandas>=1.2.4', 28 | 'Pillow>=8.2.0', 29 | 'pluggy>=0.13.1', 30 | 'py>=1.10.0', 31 | 'pyparsing>=2.4.7', 32 | 'pytest>=6.2.1', 33 | 'python-dateutil>=2.8.1', 34 | 'pytz>=2020.5', 35 | 'scipy>=1.6.0', 36 | 'seaborn>=0.11.1', 37 | 'six>=1.15.0', 38 | 'toml>=0.10.2', 39 | 'typing-extensions>=3.7.4.3', 40 | 'xlsxwriter>=1.4.3', 41 | 'sphinx-rtd-theme>=0.4', 42 | 'myst-nb>=0.13.1', 43 | 'autograd>=1.3', 44 | 'jax>=0.3.4', 45 | 'jaxlib>=0.3.4', 46 | 'equinox>=0.9', 47 | 'numpyro>=0.1', 48 | 'arviz>=0.13', 49 | 'optax>=0.1' 50 | ], 51 | packages=[ 52 | "pymdp", 53 | "pymdp.envs", 54 | "pymdp.algos", 55 | "pymdp.jax" 56 | ], 57 | include_package_data=True, 58 | keywords=[ 59 | "artificial intelligence", 60 | "active inference", 61 | "free energy principle", 62 | "information theory", 63 | "decision-making", 64 | "MDP", 65 | "Markov Decision Process", 66 | "Bayesian inference", 67 | "variational inference", 68 | "reinforcement learning" 69 | ], 70 | classifiers=[ 71 | 'Development Status :: 4 - Beta', 72 | 'Intended Audience :: Developers', 73 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 74 | 'License :: OSI Approved :: MIT License', 75 | 'Programming Language :: Python :: 3.7', 76 | ], 77 | ) 78 | 79 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/__init__.py -------------------------------------------------------------------------------- /test/matlab_crossval/generation/bmr_matlab_test_a.m: -------------------------------------------------------------------------------- 1 | %% 2 | 3 | clear all; close all; clc; 4 | 5 | cd .. 
% this brings you into the 'pymdp/tests/matlab_crossval/' super directory, since this file should be stored in 'pymdp/tests/matlab_crossval/generation' 6 | 7 | x = linspace(1,32,128); 8 | pA = [1; 1]; 9 | rA = pA; 10 | rA(2) = 8; 11 | F = zeros(numel(x), numel(x)); 12 | for i = 1:numel(x) 13 | for j = 1:numel(x) 14 | qA = [x(i);x(j)]; 15 | F(i,j) = spm_MDP_log_evidence(qA,pA,rA); 16 | end 17 | end 18 | 19 | save_dir = 'output/bmr_test_a.mat'; 20 | save(save_dir, 'F'); 21 | 22 | %% 23 | function [F,sA] = spm_MDP_log_evidence(qA,pA,rA) 24 | % Bayesian model reduction for Dirichlet hyperparameters 25 | % FORMAT [F,sA] = spm_MDP_log_evidence(qA,pA,rA) 26 | % 27 | % qA - sufficient statistics of posterior of full model 28 | % pA - sufficient statistics of prior of full model 29 | % rA - sufficient statistics of prior of reduced model 30 | % 31 | % F - free energy or (negative) log evidence of reduced model 32 | % sA - sufficient statistics of reduced posterior 33 | % 34 | % This routine computes the negative log evidence of a reduced model of a 35 | % categorical distribution parameterised in terms of Dirichlet 36 | % hyperparameters (i.e., concentration parameters encoding probabilities). 37 | % It uses Bayesian model reduction to evaluate the evidence for models with 38 | % and without a particular parameter. 39 | % 40 | % It is assumed that all the inputs are column vectors. 41 | % 42 | % A demonstration of the implicit pruning can be found at the end of this 43 | % routine 44 | %__________________________________________________________________________ 45 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 46 | 47 | % Karl Friston 48 | % $Id: spm_MDP_log_evidence.m 7326 2018-06-06 12:16:40Z karl $ 49 | 50 | 51 | % change in free energy or log model evidence 52 | %-------------------------------------------------------------------------- 53 | sA = qA + rA - pA; 54 | F = spm_betaln(qA) + spm_betaln(rA) - spm_betaln(pA) - spm_betaln(sA); 55 | 56 | end 57 | 58 | function y = spm_betaln(z) 59 | % returns the log the multivariate beta function of a vector. 60 | % FORMAT y = spm_betaln(z) 61 | % y = spm_betaln(z) computes the natural logarithm of the beta function 62 | % for corresponding elements of the vector z. if concerned is an array, 63 | % the beta functions are taken over the elements of the first to mention 64 | % (and size(y,1) equals one). 65 | % 66 | % See also BETAINC, BETA. 67 | %-------------------------------------------------------------------------- 68 | % Ref: Abramowitz & Stegun, Handbook of Mathematical Functions, sec. 6.2. 69 | % Copyright 1984-2004 The MathWorks, Inc. 
70 | %__________________________________________________________________________ 71 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 72 | 73 | % Karl Friston 74 | % $Id: spm_betaln.m 7508 2018-12-21 09:49:44Z thomas $ 75 | 76 | % log the multivariate beta function of a vector 77 | %-------------------------------------------------------------------------- 78 | if isvector(z) 79 | z = z(find(z)); %#ok 80 | y = sum(gammaln(z)) - gammaln(sum(z)); 81 | else 82 | for i = 1:size(z,2) 83 | for j = 1:size(z,3) 84 | for k = 1:size(z,4) 85 | for l = 1:size(z,5) 86 | for m = 1:size(z,6) 87 | y(1,i,j,k,l,m) = spm_betaln(z(:,i,j,k,l,m)); 88 | end 89 | end 90 | end 91 | end 92 | end 93 | end 94 | 95 | end 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /test/matlab_crossval/generation/bmr_matlab_test_b.m: -------------------------------------------------------------------------------- 1 | %% 2 | 3 | clear all; close all; clc; 4 | 5 | cd .. % this brings you into the 'pymdp/tests/matlab_crossval/' super directory, since this file should be stored in 'pymdp/tests/matlab_crossval/generation' 6 | 7 | q_dir = rand(8, 1); 8 | p_dir = ones(8, 1); 9 | r_dir = ones(8, 3); % 3 different reduced models 10 | 11 | r_dir(5,1) = 2; 12 | r_dir(6,2) = 2; 13 | r_dir(7,3) = 2; 14 | 15 | [F, s_dir] = spm_MDP_log_evidence(q_dir,p_dir,r_dir); 16 | 17 | save_dir = 'output/bmr_test_b.mat'; 18 | save(save_dir, 'q_dir', 'p_dir', 'r_dir', 's_dir', 'F'); 19 | 20 | %% 21 | function [F,sA] = spm_MDP_log_evidence(qA,pA,rA) 22 | % Bayesian model reduction for Dirichlet hyperparameters 23 | % FORMAT [F,sA] = spm_MDP_log_evidence(qA,pA,rA) 24 | % 25 | % qA - sufficient statistics of posterior of full model 26 | % pA - sufficient statistics of prior of full model 27 | % rA - sufficient statistics of prior of reduced model 28 | % 29 | % F - free energy or (negative) log evidence of reduced model 30 | % sA - sufficient statistics of reduced posterior 31 | % 32 | % This routine computes the negative log evidence of a reduced model of a 33 | % categorical distribution parameterised in terms of Dirichlet 34 | % hyperparameters (i.e., concentration parameters encoding probabilities). 35 | % It uses Bayesian model reduction to evaluate the evidence for models with 36 | % and without a particular parameter. 37 | % 38 | % It is assumed that all the inputs are column vectors. 39 | % 40 | % A demonstration of the implicit pruning can be found at the end of this 41 | % routine 42 | %__________________________________________________________________________ 43 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 44 | 45 | % Karl Friston 46 | % $Id: spm_MDP_log_evidence.m 7326 2018-06-06 12:16:40Z karl $ 47 | 48 | 49 | % change in free energy or log model evidence 50 | %-------------------------------------------------------------------------- 51 | sA = qA + rA - pA; 52 | F = spm_betaln(qA) + spm_betaln(rA) - spm_betaln(pA) - spm_betaln(sA); 53 | 54 | end 55 | 56 | function y = spm_betaln(z) 57 | % returns the log the multivariate beta function of a vector. 58 | % FORMAT y = spm_betaln(z) 59 | % y = spm_betaln(z) computes the natural logarithm of the beta function 60 | % for corresponding elements of the vector z. if concerned is an array, 61 | % the beta functions are taken over the elements of the first to mention 62 | % (and size(y,1) equals one). 63 | % 64 | % See also BETAINC, BETA. 
65 | %-------------------------------------------------------------------------- 66 | % Ref: Abramowitz & Stegun, Handbook of Mathematical Functions, sec. 6.2. 67 | % Copyright 1984-2004 The MathWorks, Inc. 68 | %__________________________________________________________________________ 69 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 70 | 71 | % Karl Friston 72 | % $Id: spm_betaln.m 7508 2018-12-21 09:49:44Z thomas $ 73 | 74 | % log the multivariate beta function of a vector 75 | %-------------------------------------------------------------------------- 76 | if isvector(z) 77 | z = z(find(z)); %#ok 78 | y = sum(gammaln(z)) - gammaln(sum(z)); 79 | else 80 | for i = 1:size(z,2) 81 | for j = 1:size(z,3) 82 | for k = 1:size(z,4) 83 | for l = 1:size(z,5) 84 | for m = 1:size(z,6) 85 | y(1,i,j,k,l,m) = spm_betaln(z(:,i,j,k,l,m)); 86 | end 87 | end 88 | end 89 | end 90 | end 91 | end 92 | 93 | end 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /test/matlab_crossval/generation/mmp_matlab_test_a.m: -------------------------------------------------------------------------------- 1 | %%% PSEUDO-CODE OVERVIEW OF MESSAGE PASSING SCHEME USED IN SPM_MDP_VB_X.m 2 | clear all; close all; clc; 3 | 4 | cd .. % this brings you into the 'pymdp/tests/matlab_crossval/' super directory, since this file should be stored in 'pymdp/tests/matlab_crossval/generation' 5 | 6 | rng(1); % ensure the saved output file for inferactively is always the same 7 | %% VARIABLE NAMES 8 | 9 | T = 10; % total length of time (generative process horizon) 10 | window_len = 3; % length of inference window (in the past) 11 | policy_horizon = 1; % temporal horizon of policies 12 | num_iter = 5; % number of variational iterations 13 | num_states = [3]; % hidden state dimensionalities 14 | num_factors = length(num_states); % number of hidden state factors 15 | num_obs = [2]; % observation modality dimensionalities 16 | num_modalities = length(num_obs); % number of hidden state factors 17 | num_actions = [3]; % control factor (action) dimensionalities 18 | num_control = length(num_actions); 19 | 20 | qs_ppd = cell(1, num_factors); % variable to store posterior predictive density for current timestep. cell array of length num_factors, where each qs_ppd{f} is the PPD for a given factor (length [num_states(f), 1]) 21 | qs_bma = cell(1, num_factors); % variable to store bayesian model average for current timestep. cell array of length num_factors, where each xq{f} is the BMA for a given factor (length [num_states(f), 1]) 22 | 23 | states = zeros(num_factors,T); % matrix of true hidden states (separated by factor and timepoint) -- size(states) == [num_factors, T] 24 | for f = 1:num_factors 25 | states(f,1) = randi(num_states(f)); 26 | end 27 | 28 | actions = zeros(num_control, T); % history of actions along each control state factor and timestep -- size(actions) == [num_factors, T] 29 | obs = zeros(num_modalities,T); % history of observations (separated by modality and timepoint) -- size (obs) == [num_modalities, T] 30 | vector_obs = cell(num_modalities,T); % history of observations expressed as one-hot vectors 31 | 32 | policy_matrix = zeros(policy_horizon, 1, num_control); % matrix of policies expressed in terms of time points, actions, and hidden state factors. size(policies) == [policy_horizon, num_policies, num_factors]. 
33 | % This gets updated over time with the actual actions/policies taken in the past 34 | 35 | U = zeros(1,1,num_factors); % matrix of allowable actions per policy at each move. size(U) == [1, num_policies, num_factors] 36 | 37 | U(1,1,:) = 1; 38 | 39 | policy_matrix(1,:,:) = U; 40 | 41 | % likelihoods and priors 42 | 43 | A = cell(1,num_modalities); % generative process observation likelihood (cell array of length num_modalities -- each A{g} is a matrix of size [num_modalities(g), num_states(:)] 44 | B = cell(1,num_factors); % generative process transition likelihood (cell array of length num_factors -- each B{f} is a matrix of size [num_states(f), num_states(f), num_actions(f)] 45 | C = cell(1,num_modalities); 46 | for g= 1:num_modalities 47 | C{g} = rand(num_obs(g),T); 48 | end 49 | 50 | D = cell(1,num_factors); % prior over hidden states -- a cell array of size [1, num_factors] where each D{f} is a vector of length [num_states(f), 1] 51 | for f = 1:num_factors 52 | D{f} = ones(num_states(f),1)/num_states(f); 53 | end 54 | 55 | 56 | for g = 1:num_modalities 57 | A{g} = spm_norm(rand([num_obs(g),num_states])); 58 | end 59 | 60 | a = A; % generative model == generative process 61 | 62 | 63 | for f = 1:num_factors 64 | B{f} = spm_norm(rand(num_states(f), num_states(f), num_actions(f))); 65 | end 66 | 67 | 68 | b = B; % generative model transition likelihood (cell array of length num_factors -- each b{f} is a matrix of size [num_states(f), num_states(f), num_actions(f)] 69 | b_t = cell(1,num_factors); 70 | 71 | for f = 1:num_factors 72 | for u = 1:num_actions(f) 73 | b_t{f}(:,:,u) = spm_norm(b{f}(:,:,u)');% transpose of generative model transition likelihood 74 | end 75 | end 76 | 77 | %% INITIALIZATION of beliefs 78 | 79 | % initialise different posterior beliefs used in message passing 80 | for f = 1:num_factors 81 | 82 | xn{f} = zeros(num_iter,num_states(f),window_len,T,1); 83 | 84 | vn{f} = zeros(num_iter,num_states(f),window_len,T,1); 85 | 86 | x{f} = zeros(num_states(f),T,1) + 1/num_states(f); 87 | qs_ppd{f} = zeros(num_states(f), T, 1) + 1/num_states(f); 88 | 89 | qs_bma{f} = repmat(D{f},1,T); 90 | 91 | x{f}(:,1,1) = D{f}; 92 | qs_ppd{f}(:,1,1) = D{f}; 93 | 94 | end 95 | 96 | %% 97 | for t = 1:T 98 | 99 | 100 | % posterior predictive density over hidden (external) states 101 | %-------------------------------------------------------------- 102 | for f = 1:num_factors 103 | % Bayesian model average (xq) 104 | %---------------------------------------------------------- 105 | xq{f} = qs_bma{f}(:,t); 106 | end 107 | 108 | % sample state, if not specified 109 | %-------------------------------------------------------------- 110 | for f = 1:num_factors 111 | 112 | % the next state is generated by action on external states 113 | %---------------------------------------------------------- 114 | 115 | if t > 1 116 | ps = B{f}(:,states(f,t - 1),actions(f,t - 1)); 117 | else 118 | ps = D{f}; 119 | end 120 | states(f,t) = find(rand < cumsum(ps),1); 121 | 122 | end 123 | 124 | % sample observations, if not specified 125 | %-------------------------------------------------------------- 126 | for g = 1:num_modalities 127 | 128 | % if observation is not given 129 | %---------------------------------------------------------- 130 | if ~obs(g,t) 131 | 132 | % sample from likelihood given hidden state 133 | %-------------------------------------------------- 134 | ind = num2cell(states(:,t)); 135 | p_obs = A{g}(:,ind{:}); % gets the probability over observations, under the current hidden state 
configuration 136 | obs(g,t) = find(rand < cumsum(p_obs),1); 137 | vector_obs{g,t} = sparse(obs(g,t),1,1,num_obs(g),1); 138 | end 139 | end 140 | 141 | % Likelihood of observation under the various configurations of hidden states 142 | %================================================================== 143 | L{t} = 1; 144 | for g = 1:num_modalities 145 | L{t} = L{t}.*spm_dot(a{g},vector_obs{g,t}); 146 | end 147 | 148 | % reset 149 | %-------------------------------------------------------------- 150 | for f = 1:num_factors 151 | x{f} = spm_softmax(spm_log(x{f})/4); 152 | end 153 | 154 | if t == 6 155 | debug_flag = true; 156 | end 157 | 158 | [F, G, x, xq, vn, xn] = run_mmp(num_iter, window_len, policy_matrix, t, xq, x, L, D, b, b_t, xn, vn); 159 | 160 | if t == 6 161 | save_dir = 'output/mmp_a.mat'; 162 | policy = squeeze(policy_matrix(end,1,:))'; 163 | previous_actions = squeeze(policy_matrix(1:(end-1),1,:)); 164 | t_horizon = window_len; 165 | qs = xq; 166 | obs_idx = obs(:,1:t); 167 | likelihoods = L(1:t); 168 | save(save_dir,'A','B','obs_idx','policy','t','t_horizon','previous_actions','qs','likelihoods') 169 | end 170 | 171 | 172 | % pretend you took a random action and supplement policy matrix with it 173 | if t < T 174 | for u = 1:num_control 175 | actions(u,t) = randi(num_actions(u)); 176 | end 177 | 178 | if (t+1) < T 179 | policy_matrix(t+1,1,:) = actions(:,t); 180 | end 181 | % and re-initialise expectations about hidden states 182 | %------------------------------------------------------ 183 | for f = 1:num_factors 184 | x{f}(:,:,1) = 1/num_states(f); 185 | end 186 | end 187 | 188 | if t == T 189 | obs = obs(:,1:T); % outcomes at 1,...,T 190 | states = states(:,1:T); % states at 1,...,T 191 | actions = actions(:,1:T - 1); % actions at 1,...,T - 1 192 | break; 193 | end 194 | 195 | end 196 | %% 197 | % auxillary functions 198 | %========================================================================== 199 | 200 | function A = spm_log(A) 201 | % log of numeric array plus a small constant 202 | %-------------------------------------------------------------------------- 203 | A = log(A + 1e-16); 204 | end 205 | 206 | function A = spm_norm(A) 207 | % normalisation of a probability transition matrix (columns) 208 | %-------------------------------------------------------------------------- 209 | A = bsxfun(@rdivide,A,sum(A,1)); 210 | A(isnan(A)) = 1/size(A,1); 211 | end 212 | 213 | function A = spm_wnorm(A) 214 | % summation of a probability transition matrix (columns) 215 | %-------------------------------------------------------------------------- 216 | A = A + 1e-16; 217 | A = bsxfun(@minus,1./sum(A,1),1./A)/2; 218 | end 219 | 220 | function sub = spm_ind2sub(siz,ndx) 221 | % subscripts from linear index 222 | %-------------------------------------------------------------------------- 223 | n = numel(siz); 224 | k = [1 cumprod(siz(1:end-1))]; 225 | for i = n:-1:1 226 | vi = rem(ndx - 1,k(i)) + 1; 227 | vj = (ndx - vi)/k(i) + 1; 228 | sub(i,1) = vj; 229 | ndx = vi; 230 | end 231 | end 232 | 233 | function [X] = spm_dot(X,x,i) 234 | % Multidimensional dot (inner) product 235 | % FORMAT [Y] = spm_dot(X,x,[DIM]) 236 | % 237 | % X - numeric array 238 | % x - cell array of numeric vectors 239 | % DIM - dimensions to omit (asumes ndims(X) = numel(x)) 240 | % 241 | % Y - inner product obtained by summing the products of X and x along DIM 242 | % 243 | % If DIM is not specified the leading dimensions of X are omitted. 
244 | % If x is a vector the inner product is over the leading dimension of X 245 | 246 | % initialise dimensions 247 | %-------------------------------------------------------------------------- 248 | if iscell(x) 249 | DIM = (1:numel(x)) + ndims(X) - numel(x); 250 | else 251 | DIM = 1; 252 | x = {x}; 253 | end 254 | 255 | % omit dimensions specified 256 | %-------------------------------------------------------------------------- 257 | if nargin > 2 258 | DIM(i) = []; 259 | x(i) = []; 260 | end 261 | 262 | % inner product using recursive summation (and bsxfun) 263 | %-------------------------------------------------------------------------- 264 | for d = 1:numel(x) 265 | s = ones(1,ndims(X)); 266 | s(DIM(d)) = numel(x{d}); 267 | X = bsxfun(@times,X,reshape(full(x{d}),s)); 268 | X = sum(X,DIM(d)); 269 | end 270 | 271 | % eliminate singleton dimensions 272 | %-------------------------------------------------------------------------- 273 | X = squeeze(X); 274 | end 275 | 276 | function [y] = spm_softmax(x,k) 277 | % softmax (e.g., neural transfer) function over columns 278 | % FORMAT [y] = spm_softmax(x,k) 279 | % 280 | % x - numeric array 281 | % k - precision, sensitivity or inverse temperature (default k = 1) 282 | % 283 | % y = exp(k*x)/sum(exp(k*x)) 284 | % 285 | % NB: If supplied with a matrix this routine will return the softmax 286 | % function over columns - so that spm_softmax([x1,x2,..]) = [1,1,...] 287 | 288 | % apply 289 | %-------------------------------------------------------------------------- 290 | if nargin > 1, x = k*x; end 291 | if size(x,1) < 2; y = ones(size(x)); return, end 292 | 293 | % exponentiate and normalise 294 | %-------------------------------------------------------------------------- 295 | x = exp(bsxfun(@minus,x,max(x))); 296 | y = bsxfun(@rdivide,x,sum(x)); 297 | end 298 | -------------------------------------------------------------------------------- /test/matlab_crossval/generation/mmp_matlab_test_b.m: -------------------------------------------------------------------------------- 1 | %%% PSEUDO-CODE OVERVIEW OF MESSAGE PASSING SCHEME USED IN SPM_MDP_VB_X.m 2 | clear all; close all; clc; 3 | 4 | cd .. % this brings you into the 'pymdp/tests/matlab_crossval/' parent directory, since this file should be stored in 'pymdp/tests/matlab_crossval/generation' 5 | 6 | rng(5); % ensure the saved output file for inferactively is always the same 7 | %% VARIABLE NAMES 8 | 9 | T = 10; % total length of time (generative process horizon) 10 | window_len = 5; % length of inference window (in the past) 11 | policy_horizon = 1; % temporal horizon of policies 12 | num_iter = 5; % number of variational iterations 13 | num_states = [2, 2]; % hidden state dimensionalities 14 | num_factors = length(num_states); % number of hidden state factors 15 | num_obs = [3, 2]; % observation modality dimensionalities 16 | num_modalities = length(num_obs); % number of observation modalities 17 | num_actions = [2, 2]; % control factor (action) dimensionalities 18 | num_control = length(num_actions); 19 | 20 | qs_ppd = cell(1, num_factors); % variable to store posterior predictive density for current timestep. cell array of length num_factors, where each qs_ppd{f} is the PPD for a given factor (length [num_states(f), 1]) 21 | qs_bma = cell(1, num_factors); % variable to store Bayesian model average for current timestep.
cell array of length num_factors, where each xq{f} is the BMA for a given factor (length [num_states(f), 1]) 22 | 23 | states = zeros(num_factors,T); % matrix of true hidden states (separated by factor and timepoint) -- size(states) == [num_factors, T] 24 | for f = 1:num_factors 25 | states(f,1) = randi(num_states(f)); 26 | end 27 | 28 | actions = zeros(num_control, T); % history of actions along each control state factor and timestep -- size(actions) == [num_factors, T] 29 | obs = zeros(num_modalities,T); % history of observations (separated by modality and timepoint) -- size (obs) == [num_modalities, T] 30 | vector_obs = cell(num_modalities,T); % history of observations expressed as one-hot vectors 31 | 32 | policy_matrix = zeros(policy_horizon, 1, num_control); % matrix of policies expressed in terms of time points, actions, and hidden state factors. size(policies) == [policy_horizon, num_policies, num_factors]. 33 | % This gets updated over time with the actual actions/policies taken in the past 34 | 35 | U = zeros(1,1,num_factors); % matrix of allowable actions per policy at each move. size(U) == [1, num_policies, num_factors] 36 | 37 | U(1,1,:) = [1, 1]; 38 | 39 | policy_matrix(1,:,:) = U; 40 | 41 | % likelihoods and priors 42 | 43 | A = cell(1,num_modalities); % generative process observation likelihood (cell array of length num_modalities -- each A{g} is a matrix of size [num_modalities(g), num_states(:)] 44 | B = cell(1,num_factors); % generative process transition likelihood (cell array of length num_factors -- each B{f} is a matrix of size [num_states(f), num_states(f), num_actions(f)] 45 | C = cell(1,num_modalities); 46 | for g= 1:num_modalities 47 | C{g} = rand(num_obs(g),T); 48 | end 49 | 50 | D = cell(1,num_factors); % prior over hidden states -- a cell array of size [1, num_factors] where each D{f} is a vector of length [num_states(f), 1] 51 | for f = 1:num_factors 52 | D{f} = ones(num_states(f),1)/num_states(f); 53 | end 54 | 55 | 56 | for g = 1:num_modalities 57 | A{g} = spm_norm(rand([num_obs(g),num_states])); 58 | end 59 | 60 | a = A; % generative model == generative process 61 | 62 | 63 | for f = 1:num_factors 64 | B{f} = spm_norm(rand(num_states(f), num_states(f), num_actions(f))); 65 | end 66 | 67 | 68 | b = B; % generative model transition likelihood (cell array of length num_factors -- each b{f} is a matrix of size [num_states(f), num_states(f), num_actions(f)] 69 | b_t = cell(1,num_factors); 70 | 71 | for f = 1:num_factors 72 | for u = 1:num_actions(f) 73 | b_t{f}(:,:,u) = spm_norm(b{f}(:,:,u)');% transpose of generative model transition likelihood 74 | end 75 | end 76 | 77 | %% INITIALIZATION of beliefs 78 | 79 | % initialise different posterior beliefs used in message passing 80 | for f = 1:num_factors 81 | 82 | xn{f} = zeros(num_iter,num_states(f),window_len,T,1); 83 | 84 | vn{f} = zeros(num_iter,num_states(f),window_len,T,1); 85 | 86 | x{f} = zeros(num_states(f),T,1) + 1/num_states(f); 87 | qs_ppd{f} = zeros(num_states(f), T, 1) + 1/num_states(f); 88 | 89 | qs_bma{f} = repmat(D{f},1,T); 90 | 91 | x{f}(:,1,1) = D{f}; 92 | qs_ppd{f}(:,1,1) = D{f}; 93 | 94 | end 95 | 96 | %% 97 | for t = 1:T 98 | 99 | 100 | % posterior predictive density over hidden (external) states 101 | %-------------------------------------------------------------- 102 | for f = 1:num_factors 103 | % Bayesian model average (xq) 104 | %---------------------------------------------------------- 105 | xq{f} = qs_bma{f}(:,t); 106 | end 107 | 108 | % sample state, if not specified 109 | 
%-------------------------------------------------------------- 110 | for f = 1:num_factors 111 | 112 | % the next state is generated by action on external states 113 | %---------------------------------------------------------- 114 | 115 | if t > 1 116 | ps = B{f}(:,states(f,t - 1),actions(f,t - 1)); 117 | else 118 | ps = D{f}; 119 | end 120 | states(f,t) = find(rand < cumsum(ps),1); 121 | 122 | end 123 | 124 | % sample observations, if not specified 125 | %-------------------------------------------------------------- 126 | for g = 1:num_modalities 127 | 128 | % if observation is not given 129 | %---------------------------------------------------------- 130 | if ~obs(g,t) 131 | 132 | % sample from likelihood given hidden state 133 | %-------------------------------------------------- 134 | ind = num2cell(states(:,t)); 135 | p_obs = A{g}(:,ind{:}); % gets the probability over observations, under the current hidden state configuration 136 | obs(g,t) = find(rand < cumsum(p_obs),1); 137 | vector_obs{g,t} = sparse(obs(g,t),1,1,num_obs(g),1); 138 | end 139 | end 140 | 141 | % Likelihood of observation under the various configurations of hidden states 142 | %================================================================== 143 | L{t} = 1; 144 | for g = 1:num_modalities 145 | L{t} = L{t}.*spm_dot(a{g},vector_obs{g,t}); 146 | end 147 | 148 | % reset 149 | %-------------------------------------------------------------- 150 | for f = 1:num_factors 151 | x{f} = spm_softmax(spm_log(x{f})/4); 152 | end 153 | 154 | if t == 3 155 | debug_flag = true; 156 | end 157 | 158 | [F, G, x, xq, vn, xn] = run_mmp(num_iter, window_len, policy_matrix, t, xq, x, L, D, b, b_t, xn, vn); 159 | 160 | if t == 3 161 | save_dir = 'output/mmp_b.mat'; 162 | policy = squeeze(policy_matrix(end,1,:))'; 163 | previous_actions = squeeze(policy_matrix(1:(end-1),1,:)); 164 | t_horizon = window_len; 165 | qs = xq; 166 | obs_idx = obs(:,1:t); 167 | likelihoods = L(1:t); 168 | save(save_dir,'A','B','obs_idx','policy','t','t_horizon','previous_actions','qs','likelihoods') 169 | end 170 | 171 | 172 | % pretend you took a random action and supplement policy matrix with it 173 | if t < T 174 | for u = 1:num_control 175 | actions(u,t) = randi(num_actions(u)); 176 | end 177 | 178 | if (t+1) < T 179 | policy_matrix(t+1,1,:) = actions(:,t); 180 | end 181 | % and re-initialise expectations about hidden states 182 | %------------------------------------------------------ 183 | for f = 1:num_factors 184 | x{f}(:,:,1) = 1/num_states(f); 185 | end 186 | end 187 | 188 | if t == T 189 | obs = obs(:,1:T); % outcomes at 1,...,T 190 | states = states(:,1:T); % states at 1,...,T 191 | actions = actions(:,1:T - 1); % actions at 1,...,T - 1 192 | break; 193 | end 194 | 195 | end 196 | %% 197 | % auxillary functions 198 | %========================================================================== 199 | 200 | function A = spm_log(A) 201 | % log of numeric array plus a small constant 202 | %-------------------------------------------------------------------------- 203 | A = log(A + 1e-16); 204 | end 205 | 206 | function A = spm_norm(A) 207 | % normalisation of a probability transition matrix (columns) 208 | %-------------------------------------------------------------------------- 209 | A = bsxfun(@rdivide,A,sum(A,1)); 210 | A(isnan(A)) = 1/size(A,1); 211 | end 212 | 213 | function A = spm_wnorm(A) 214 | % summation of a probability transition matrix (columns) 215 | %-------------------------------------------------------------------------- 216 | A = A + 
1e-16; 217 | A = bsxfun(@minus,1./sum(A,1),1./A)/2; 218 | end 219 | 220 | function sub = spm_ind2sub(siz,ndx) 221 | % subscripts from linear index 222 | %-------------------------------------------------------------------------- 223 | n = numel(siz); 224 | k = [1 cumprod(siz(1:end-1))]; 225 | for i = n:-1:1 226 | vi = rem(ndx - 1,k(i)) + 1; 227 | vj = (ndx - vi)/k(i) + 1; 228 | sub(i,1) = vj; 229 | ndx = vi; 230 | end 231 | end 232 | 233 | function [X] = spm_dot(X,x,i) 234 | % Multidimensional dot (inner) product 235 | % FORMAT [Y] = spm_dot(X,x,[DIM]) 236 | % 237 | % X - numeric array 238 | % x - cell array of numeric vectors 239 | % DIM - dimensions to omit (asumes ndims(X) = numel(x)) 240 | % 241 | % Y - inner product obtained by summing the products of X and x along DIM 242 | % 243 | % If DIM is not specified the leading dimensions of X are omitted. 244 | % If x is a vector the inner product is over the leading dimension of X 245 | 246 | % initialise dimensions 247 | %-------------------------------------------------------------------------- 248 | if iscell(x) 249 | DIM = (1:numel(x)) + ndims(X) - numel(x); 250 | else 251 | DIM = 1; 252 | x = {x}; 253 | end 254 | 255 | % omit dimensions specified 256 | %-------------------------------------------------------------------------- 257 | if nargin > 2 258 | DIM(i) = []; 259 | x(i) = []; 260 | end 261 | 262 | % inner product using recursive summation (and bsxfun) 263 | %-------------------------------------------------------------------------- 264 | for d = 1:numel(x) 265 | s = ones(1,ndims(X)); 266 | s(DIM(d)) = numel(x{d}); 267 | X = bsxfun(@times,X,reshape(full(x{d}),s)); 268 | X = sum(X,DIM(d)); 269 | end 270 | 271 | % eliminate singleton dimensions 272 | %-------------------------------------------------------------------------- 273 | X = squeeze(X); 274 | end 275 | 276 | function [y] = spm_softmax(x,k) 277 | % softmax (e.g., neural transfer) function over columns 278 | % FORMAT [y] = spm_softmax(x,k) 279 | % 280 | % x - numeric array array 281 | % k - precision, sensitivity or inverse temperature (default k = 1) 282 | % 283 | % y = exp(k*x)/sum(exp(k*x)) 284 | % 285 | % NB: If supplied with a matrix this routine will return the softmax 286 | % function over colums - so that spm_softmax([x1,x2,..]) = [1,1,...] 287 | 288 | % apply 289 | %-------------------------------------------------------------------------- 290 | if nargin > 1, x = k*x; end 291 | if size(x,1) < 2; y = ones(size(x)); return, end 292 | 293 | % exponentiate and normalise 294 | %-------------------------------------------------------------------------- 295 | x = exp(bsxfun(@minus,x,max(x))); 296 | y = bsxfun(@rdivide,x,sum(x)); 297 | end 298 | -------------------------------------------------------------------------------- /test/matlab_crossval/generation/mmp_matlab_test_c.m: -------------------------------------------------------------------------------- 1 | %%% PSEUDO-CODE OVERVIEW OF MESSAGE PASSING SCHEME USED IN SPM_MDP_VB_X.m 2 | clear all; close all; clc; 3 | 4 | cd .. 
% this brings you into the 'pymdp/tests/matlab_crossval/' parent directory, since this file should be stored in 'pymdp/tests/matlab_crossval/generation' 5 | 6 | rng(10); % ensure the saved output file for inferactively is always the same 7 | %% VARIABLE NAMES 8 | 9 | T = 10; % total length of time (generative process horizon) 10 | window_len = 5; % length of inference window (in the past) 11 | policy_horizon = 1; % temporal horizon of policies 12 | num_iter = 5; % number of variational iterations 13 | num_states = [2, 2]; % hidden state dimensionalities 14 | num_factors = length(num_states); % number of hidden state factors 15 | num_obs = [3, 2]; % observation modality dimensionalities 16 | num_modalities = length(num_obs); % number of observation modalities 17 | num_actions = [2, 2]; % control factor (action) dimensionalities 18 | num_control = length(num_actions); 19 | 20 | qs_ppd = cell(1, num_factors); % variable to store posterior predictive density for current timestep. cell array of length num_factors, where each qs_ppd{f} is the PPD for a given factor (length [num_states(f), 1]) 21 | qs_bma = cell(1, num_factors); % variable to store Bayesian model average for current timestep. cell array of length num_factors, where each qs_bma{f} is the BMA for a given factor (length [num_states(f), 1]) 22 | 23 | states = zeros(num_factors,T); % matrix of true hidden states (separated by factor and timepoint) -- size(states) == [num_factors, T] 24 | for f = 1:num_factors 25 | states(f,1) = randi(num_states(f)); 26 | end 27 | 28 | actions = zeros(num_control, T); % history of actions along each control state factor and timestep -- size(actions) == [num_control, T] 29 | obs = zeros(num_modalities,T); % history of observations (separated by modality and timepoint) -- size(obs) == [num_modalities, T] 30 | vector_obs = cell(num_modalities,T); % history of observations expressed as one-hot vectors 31 | 32 | policy_matrix = zeros(policy_horizon, 1, num_control); % matrix of policies expressed in terms of time points, actions, and hidden state factors. size(policies) == [policy_horizon, num_policies, num_factors]. 33 | % This gets updated over time with the actual actions/policies taken in the past 34 | 35 | U = zeros(1,1,num_factors); % matrix of allowable actions per policy at each move.
size(U) == [1, num_policies, num_factors] 36 | 37 | U(1,1,:) = [1, 1]; 38 | 39 | policy_matrix(1,:,:) = U; 40 | 41 | % likelihoods and priors 42 | 43 | A = cell(1,num_modalities); % generative process observation likelihood (cell array of length num_modalities -- each A{g} is a matrix of size [num_modalities(g), num_states(:)] 44 | B = cell(1,num_factors); % generative process transition likelihood (cell array of length num_factors -- each B{f} is a matrix of size [num_states(f), num_states(f), num_actions(f)] 45 | C = cell(1,num_modalities); 46 | for g= 1:num_modalities 47 | C{g} = rand(num_obs(g),T); 48 | end 49 | 50 | D = cell(1,num_factors); % prior over hidden states -- a cell array of size [1, num_factors] where each D{f} is a vector of length [num_states(f), 1] 51 | for f = 1:num_factors 52 | D{f} = ones(num_states(f),1)/num_states(f); 53 | end 54 | 55 | 56 | for g = 1:num_modalities 57 | A{g} = spm_norm(rand([num_obs(g),num_states])); 58 | end 59 | 60 | a = A; % generative model == generative process 61 | 62 | 63 | for f = 1:num_factors 64 | B{f} = spm_norm(rand(num_states(f), num_states(f), num_actions(f))); 65 | end 66 | 67 | 68 | b = B; % generative model transition likelihood (cell array of length num_factors -- each b{f} is a matrix of size [num_states(f), num_states(f), num_actions(f)] 69 | b_t = cell(1,num_factors); 70 | 71 | for f = 1:num_factors 72 | for u = 1:num_actions(f) 73 | b_t{f}(:,:,u) = spm_norm(b{f}(:,:,u)');% transpose of generative model transition likelihood 74 | end 75 | end 76 | 77 | %% INITIALIZATION of beliefs 78 | 79 | % initialise different posterior beliefs used in message passing 80 | for f = 1:num_factors 81 | 82 | xn{f} = zeros(num_iter,num_states(f),window_len,T,1); 83 | 84 | vn{f} = zeros(num_iter,num_states(f),window_len,T,1); 85 | 86 | x{f} = zeros(num_states(f),T,1) + 1/num_states(f); 87 | qs_ppd{f} = zeros(num_states(f), T, 1) + 1/num_states(f); 88 | 89 | qs_bma{f} = repmat(D{f},1,T); 90 | 91 | x{f}(:,1,1) = D{f}; 92 | qs_ppd{f}(:,1,1) = D{f}; 93 | 94 | end 95 | 96 | %% 97 | for t = 1:T 98 | 99 | 100 | % posterior predictive density over hidden (external) states 101 | %-------------------------------------------------------------- 102 | for f = 1:num_factors 103 | % Bayesian model average (xq) 104 | %---------------------------------------------------------- 105 | xq{f} = qs_bma{f}(:,t); 106 | end 107 | 108 | % sample state, if not specified 109 | %-------------------------------------------------------------- 110 | for f = 1:num_factors 111 | 112 | % the next state is generated by action on external states 113 | %---------------------------------------------------------- 114 | 115 | if t > 1 116 | ps = B{f}(:,states(f,t - 1),actions(f,t - 1)); 117 | else 118 | ps = D{f}; 119 | end 120 | states(f,t) = find(rand < cumsum(ps),1); 121 | 122 | end 123 | 124 | % sample observations, if not specified 125 | %-------------------------------------------------------------- 126 | for g = 1:num_modalities 127 | 128 | % if observation is not given 129 | %---------------------------------------------------------- 130 | if ~obs(g,t) 131 | 132 | % sample from likelihood given hidden state 133 | %-------------------------------------------------- 134 | ind = num2cell(states(:,t)); 135 | p_obs = A{g}(:,ind{:}); % gets the probability over observations, under the current hidden state configuration 136 | obs(g,t) = find(rand < cumsum(p_obs),1); 137 | vector_obs{g,t} = sparse(obs(g,t),1,1,num_obs(g),1); 138 | end 139 | end 140 | 141 | % Likelihood of observation 
under the various configurations of hidden states 142 | %================================================================== 143 | L{t} = 1; 144 | for g = 1:num_modalities 145 | L{t} = L{t}.*spm_dot(a{g},vector_obs{g,t}); 146 | end 147 | 148 | % reset 149 | %-------------------------------------------------------------- 150 | for f = 1:num_factors 151 | x{f} = spm_softmax(spm_log(x{f})/4); 152 | end 153 | 154 | if t == 3 155 | debug_flag = true; 156 | end 157 | 158 | [F, G, x, xq, vn, xn] = run_mmp(num_iter, window_len, policy_matrix, t, xq, x, L, D, b, b_t, xn, vn); 159 | 160 | if t == 1 161 | save_dir = 'output/mmp_c.mat'; 162 | policy = squeeze(policy_matrix(end,1,:))'; 163 | previous_actions = squeeze(policy_matrix(1:(end-1),1,:)); 164 | t_horizon = window_len; 165 | qs = xq; 166 | obs_idx = obs(:,1:t); 167 | likelihoods = L(1:t); 168 | save(save_dir,'A','B','obs_idx','policy','t','t_horizon','previous_actions','qs','likelihoods') 169 | end 170 | 171 | 172 | % pretend you took a random action and supplement policy matrix with it 173 | if t < T 174 | for u = 1:num_control 175 | actions(u,t) = randi(num_actions(u)); 176 | end 177 | 178 | if (t+1) < T 179 | policy_matrix(t+1,1,:) = actions(:,t); 180 | end 181 | % and re-initialise expectations about hidden states 182 | %------------------------------------------------------ 183 | for f = 1:num_factors 184 | x{f}(:,:,1) = 1/num_states(f); 185 | end 186 | end 187 | 188 | if t == T 189 | obs = obs(:,1:T); % outcomes at 1,...,T 190 | states = states(:,1:T); % states at 1,...,T 191 | actions = actions(:,1:T - 1); % actions at 1,...,T - 1 192 | break; 193 | end 194 | 195 | end 196 | %% 197 | % auxillary functions 198 | %========================================================================== 199 | 200 | function A = spm_log(A) 201 | % log of numeric array plus a small constant 202 | %-------------------------------------------------------------------------- 203 | A = log(A + 1e-16); 204 | end 205 | 206 | function A = spm_norm(A) 207 | % normalisation of a probability transition matrix (columns) 208 | %-------------------------------------------------------------------------- 209 | A = bsxfun(@rdivide,A,sum(A,1)); 210 | A(isnan(A)) = 1/size(A,1); 211 | end 212 | 213 | function A = spm_wnorm(A) 214 | % summation of a probability transition matrix (columns) 215 | %-------------------------------------------------------------------------- 216 | A = A + 1e-16; 217 | A = bsxfun(@minus,1./sum(A,1),1./A)/2; 218 | end 219 | 220 | function sub = spm_ind2sub(siz,ndx) 221 | % subscripts from linear index 222 | %-------------------------------------------------------------------------- 223 | n = numel(siz); 224 | k = [1 cumprod(siz(1:end-1))]; 225 | for i = n:-1:1 226 | vi = rem(ndx - 1,k(i)) + 1; 227 | vj = (ndx - vi)/k(i) + 1; 228 | sub(i,1) = vj; 229 | ndx = vi; 230 | end 231 | end 232 | 233 | function [X] = spm_dot(X,x,i) 234 | % Multidimensional dot (inner) product 235 | % FORMAT [Y] = spm_dot(X,x,[DIM]) 236 | % 237 | % X - numeric array 238 | % x - cell array of numeric vectors 239 | % DIM - dimensions to omit (asumes ndims(X) = numel(x)) 240 | % 241 | % Y - inner product obtained by summing the products of X and x along DIM 242 | % 243 | % If DIM is not specified the leading dimensions of X are omitted. 
244 | % If x is a vector the inner product is over the leading dimension of X 245 | 246 | % initialise dimensions 247 | %-------------------------------------------------------------------------- 248 | if iscell(x) 249 | DIM = (1:numel(x)) + ndims(X) - numel(x); 250 | else 251 | DIM = 1; 252 | x = {x}; 253 | end 254 | 255 | % omit dimensions specified 256 | %-------------------------------------------------------------------------- 257 | if nargin > 2 258 | DIM(i) = []; 259 | x(i) = []; 260 | end 261 | 262 | % inner product using recursive summation (and bsxfun) 263 | %-------------------------------------------------------------------------- 264 | for d = 1:numel(x) 265 | s = ones(1,ndims(X)); 266 | s(DIM(d)) = numel(x{d}); 267 | X = bsxfun(@times,X,reshape(full(x{d}),s)); 268 | X = sum(X,DIM(d)); 269 | end 270 | 271 | % eliminate singleton dimensions 272 | %-------------------------------------------------------------------------- 273 | X = squeeze(X); 274 | end 275 | 276 | function [y] = spm_softmax(x,k) 277 | % softmax (e.g., neural transfer) function over columns 278 | % FORMAT [y] = spm_softmax(x,k) 279 | % 280 | % x - numeric array 281 | % k - precision, sensitivity or inverse temperature (default k = 1) 282 | % 283 | % y = exp(k*x)/sum(exp(k*x)) 284 | % 285 | % NB: If supplied with a matrix this routine will return the softmax 286 | % function over columns - so that spm_softmax([x1,x2,..]) = [1,1,...] 287 | 288 | % apply 289 | %-------------------------------------------------------------------------- 290 | if nargin > 1, x = k*x; end 291 | if size(x,1) < 2; y = ones(size(x)); return, end 292 | 293 | % exponentiate and normalise 294 | %-------------------------------------------------------------------------- 295 | x = exp(bsxfun(@minus,x,max(x))); 296 | y = bsxfun(@rdivide,x,sum(x)); 297 | end 298 | -------------------------------------------------------------------------------- /test/matlab_crossval/generation/run_mmp.m: -------------------------------------------------------------------------------- 1 | function [F, G, x, xq, vn, xn] = run_mmp(num_iter, window_len, policy_matrix, t, xq, x, L, D, b, b_t, xn, vn) 2 | %run_mmp Functioned-out version of the marginal message passing routine 3 | %that happens in Karl's SPM_MDP_VB_X.m 4 | 5 | % marginal message passing (MMP) 6 | %-------------------------------------------------------------- 7 | 8 | num_factors = length(xq); 9 | num_states = zeros(1,num_factors); 10 | for f = 1:num_factors 11 | num_states(f) = size(xq{f},1); 12 | end 13 | 14 | S = size(policy_matrix,1) + 1; % horizon 15 | R = t; 16 | 17 | dF = 1; % reset criterion for this policy 18 | for iter = 1:num_iter % iterate belief updates 19 | F = 0; % reset free energy for this policy 20 | for j = max(1,t-window_len):S % loop over time points in the inference window and policy horizon 21 | 22 | % current posterior over outcome factors 23 | %-------------------------------------------------- 24 | if j <= t 25 | for f = 1:num_factors 26 | xq{f} = x{f}(:,j,1); 27 | end 28 | end 29 | 30 | for f = 1:num_factors 31 | 32 | % hidden states for this time and policy 33 | %---------------------------------------------- 34 | sx = x{f}(:,j,1); 35 | qL = zeros(num_states(f),1); 36 | v = zeros(num_states(f),1); 37 | 38 | % evaluate free energy and gradients (v = dFdx) 39 | %---------------------------------------------- 40 | if dF > exp(-8) || iter > 4 41 | 42 | % marginal likelihood over outcome factors 43 | %------------------------------------------ 44 | if j <= t 45 | qL = spm_dot(L{j},xq,f); 46 | qL = spm_log(qL(:));
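% (gloss: spm_dot(L{j},xq,f) contracts the joint observation likelihood with the current beliefs over every factor except f, so qL is the log marginal-likelihood message that drives factor f's posterior update below.)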
47 | end 48 | 49 | % entropy 50 | %------------------------------------------ 51 | qx = spm_log(sx); 52 | 53 | % empirical priors (forward messages) 54 | %------------------------------------------ 55 | if j < 2 56 | px = spm_log(D{f}); 57 | v = v + px + qL - qx; 58 | else 59 | px = spm_log(b{f}(:,:,policy_matrix(j - 1,1,f))*x{f}(:,j - 1,1)); 60 | v = v + px + qL - qx; 61 | end 62 | 63 | 64 | % empirical priors (backward messages) 65 | %------------------------------------------ 66 | if j < R 67 | px = log( b_t{f}(:,:,policy_matrix(j,1,f)) * x{f}(:,j+1,1) ); 68 | % if iter == num_iter 69 | % fprintf('inference timestep: %d, factor: %d \n',j, f) 70 | % disp(px) 71 | % end 72 | v = v + px + qL - qx; 73 | end 74 | 75 | % (negative) free energy 76 | %------------------------------------------ 77 | if j == 1 || j == S 78 | F = F + sx'*0.5*v; 79 | else 80 | F = F + sx'*(0.5*v - (num_factors-1)*qL/num_factors); 81 | end 82 | 83 | % update 84 | %----------------------------------------- 85 | 86 | v = v - mean(v); 87 | % if iter == num_iter 88 | % fprintf('inference timestep: %d, factor: %d \n',j, f) 89 | % disp(v) 90 | % end 91 | 92 | sx = softmax(qx + v/4); 93 | 94 | else 95 | F = G; 96 | end 97 | 98 | % store updated neuronal activity 99 | %---------------------------------------------- 100 | x{f}(:,j,1) = sx; 101 | xq{f} = sx; 102 | xn{f}(iter,:,j,t,1) = sx; 103 | vn{f}(iter,:,j,t,1) = v; 104 | 105 | end 106 | end 107 | 108 | % convergence 109 | %------------------------------------------------------ 110 | if iter > 1 111 | dF = F - G; 112 | end 113 | G = F; 114 | 115 | end 116 | 117 | end 118 | 119 | function A = spm_log(A) 120 | % log of numeric array plus a small constant 121 | %-------------------------------------------------------------------------- 122 | A = log(A + 1e-16); 123 | end 124 | 125 | function A = spm_norm(A) 126 | % normalisation of a probability transition matrix (columns) 127 | %-------------------------------------------------------------------------- 128 | A = bsxfun(@rdivide,A,sum(A,1)); 129 | A(isnan(A)) = 1/size(A,1); 130 | end 131 | 132 | function [X] = spm_dot(X,x,i) 133 | % Multidimensional dot (inner) product 134 | 135 | % initialise dimensions 136 | %-------------------------------------------------------------------------- 137 | if iscell(x) 138 | DIM = (1:numel(x)) + ndims(X) - numel(x); 139 | else 140 | DIM = 1; 141 | x = {x}; 142 | end 143 | 144 | % omit dimensions specified 145 | %-------------------------------------------------------------------------- 146 | if nargin > 2 147 | DIM(i) = []; 148 | x(i) = []; 149 | end 150 | 151 | % inner product using recursive summation (and bsxfun) 152 | %-------------------------------------------------------------------------- 153 | for d = 1:numel(x) 154 | s = ones(1,ndims(X)); 155 | s(DIM(d)) = numel(x{d}); 156 | X = bsxfun(@times,X,reshape(full(x{d}),s)); 157 | X = sum(X,DIM(d)); 158 | end 159 | 160 | % eliminate singleton dimensions 161 | %-------------------------------------------------------------------------- 162 | X = squeeze(X); 163 | end 164 | 165 | function [y] = spm_softmax(x,k) 166 | % softmax (e.g., neural transfer) function over columns 167 | 168 | % apply 169 | %-------------------------------------------------------------------------- 170 | if nargin > 1, x = k*x; end 171 | if size(x,1) < 2; y = ones(size(x)); return, end 172 | 173 | % exponentiate and normalise 174 | %-------------------------------------------------------------------------- 175 | x = exp(bsxfun(@minus,x,max(x))); 176 | y =
bsxfun(@rdivide,x,sum(x)); 177 | end 178 | 179 | -------------------------------------------------------------------------------- /test/matlab_crossval/output/bmr_test_a.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/bmr_test_a.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/bmr_test_b.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/bmr_test_b.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/cross_a.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/cross_a.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/cross_b.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/cross_b.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/cross_c.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/cross_c.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/cross_d.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/cross_d.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/cross_e.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/cross_e.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/dot_a.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/dot_a.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/dot_b.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/dot_b.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/dot_c.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/dot_c.mat -------------------------------------------------------------------------------- 
/test/matlab_crossval/output/dot_d.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/dot_d.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/dot_e.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/dot_e.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/mmp_a.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/mmp_a.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/mmp_b.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/mmp_b.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/mmp_c.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/mmp_c.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/mmp_d.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/mmp_d.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/vbx_test_1a.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/vbx_test_1a.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/wnorm_a.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/wnorm_a.mat -------------------------------------------------------------------------------- /test/matlab_crossval/output/wnorm_b.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/infer-actively/pymdp/b29e2693e162a247ee864e306995d6b4474a29bf/test/matlab_crossval/output/wnorm_b.mat -------------------------------------------------------------------------------- /test/test_SPM_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | from scipy.io import loadmat 6 | 7 | from pymdp.agent import Agent 8 | from pymdp.utils import to_obj_array, build_xn_vn_array, get_model_dimensions, convert_observation_array 9 | from pymdp.maths import dirichlet_log_evidence 10 | 11 | DATA_PATH = "test/matlab_crossval/output/" 12 | 13 | class TestSPM(unittest.TestCase): 14 | 15 | 
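# Each test in this class follows the same cross-validation recipe: load a
# .mat file generated by the scripts in test/matlab_crossval/generation,
# rebuild the corresponding generative model in pymdp, run the matching
# pymdp routine, and assert numerical agreement (np.isclose / np.allclose)
# with the saved SPM (MATLAB) outputs.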
def test_active_inference_SPM_1a(self): 16 | """ 17 | Test against output of SPM_MDP_VB_X.m 18 | 1A - one hidden state factor, one observation modality, backwards horizon = 3, policy_len = 1, policy-conditional prior 19 | """ 20 | array_path = os.path.join(os.getcwd(), DATA_PATH + "vbx_test_1a.mat") 21 | mat_contents = loadmat(file_name=array_path) 22 | 23 | A = mat_contents["A"][0] 24 | B = mat_contents["B"][0] 25 | C = to_obj_array(mat_contents["C"][0][0][:,0]) 26 | obs_matlab = mat_contents["obs"].astype("int64") 27 | policy = mat_contents["policies"].astype("int64") - 1 28 | t_horizon = mat_contents["t_horizon"][0, 0].astype("int64") 29 | actions_matlab = mat_contents["actions"].astype("int64") - 1 30 | qs_matlab = mat_contents["qs"][0] 31 | xn_matlab = mat_contents["xn"][0] 32 | vn_matlab = mat_contents["vn"][0] 33 | 34 | likelihoods_matlab = mat_contents["likelihoods"][0] 35 | 36 | num_obs, num_states, _, num_factors = get_model_dimensions(A, B) 37 | obs = convert_observation_array(obs_matlab, num_obs) 38 | T = len(obs) 39 | 40 | agent = Agent(A=A, B=B, C=C, inference_algo="MMP", policy_len=1, 41 | inference_horizon=t_horizon, use_BMA = False, 42 | policy_sep_prior = True) 43 | 44 | actions_python = np.zeros(T) 45 | 46 | for t in range(T): 47 | o_t = (np.where(obs[t])[0][0],) 48 | qx, xn_t, vn_t = agent._infer_states_test(o_t) 49 | q_pi, G= agent.infer_policies() 50 | action = agent.sample_action() 51 | 52 | actions_python[t] = action.item() 53 | 54 | xn_python = build_xn_vn_array(xn_t) 55 | vn_python = build_xn_vn_array(vn_t) 56 | 57 | if t == T-1: 58 | xn_python = xn_python[:,:,:-1,:] 59 | vn_python = vn_python[:,:,:-1,:] 60 | 61 | start_tstep = max(0, agent.curr_timestep - agent.inference_horizon) 62 | end_tstep = min(agent.curr_timestep + agent.policy_len, T) 63 | 64 | xn_validation = xn_matlab[0][:,:,start_tstep:end_tstep,t,:] 65 | vn_validation = vn_matlab[0][:,:,start_tstep:end_tstep,t,:] 66 | 67 | self.assertTrue(np.isclose(xn_python, xn_validation).all()) 68 | self.assertTrue(np.isclose(vn_python, vn_validation).all()) 69 | 70 | self.assertTrue(np.isclose(actions_matlab[0,:],actions_python[:-1]).all()) 71 | 72 | def test_BMR_SPM_a(self): 73 | """ 74 | Validate output of pymdp's `dirichlet_log_evidence` function 75 | against output of `spm_MDP_log_evidence` from DEM in SPM (MATLAB) 76 | Test `a` tests the log evidence calculations across for a single 77 | reduced model, stored in a vector `r_dir` 78 | """ 79 | array_path = os.path.join(os.getcwd(), DATA_PATH + "bmr_test_a.mat") 80 | mat_contents = loadmat(file_name=array_path) 81 | F_valid = mat_contents["F"] 82 | 83 | # create BMR example from MATLAB 84 | x = np.linspace(1, 32, 128) 85 | 86 | p_dir = np.ones(2) 87 | r_dir = p_dir.copy() 88 | r_dir[1] = 8. 89 | 90 | F_out = np.zeros( (len(x), len(x)) ) 91 | for i in range(len(x)): 92 | for j in range(len(x)): 93 | q_dir = np.array([x[i], x[j]]) 94 | F_out[i,j] = dirichlet_log_evidence(q_dir, p_dir, r_dir)[0] 95 | 96 | self.assertTrue(np.allclose(F_valid, F_out)) 97 | 98 | def test_BMR_SPM_b(self): 99 | """ 100 | Validate output of pymdp's `dirichlet_log_evidence` function 101 | against output of `spm_MDP_log_evidence` from DEM in SPM (MATLAB). 
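In Bayesian model reduction, this log evidence scores a reduced Dirichlet prior (`r_dir`) against the full prior (`p_dir`) directly from the full model's posterior (`q_dir`), without refitting the model.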
102 | Test `b` vectorizes the log evidence calculations across a _matrix_ of 103 | reduced models, with one reduced model prior per column of the argument `r_dir` 104 | """ 105 | array_path = os.path.join(os.getcwd(), DATA_PATH + "bmr_test_b.mat") 106 | mat_contents = loadmat(file_name=array_path) 107 | F_valid = mat_contents["F"] 108 | s_dir_valid = mat_contents['s_dir'] 109 | q_dir = mat_contents["q_dir"] 110 | p_dir = mat_contents["p_dir"] 111 | r_dir = mat_contents["r_dir"] 112 | 113 | F_out, s_dir_out = dirichlet_log_evidence(q_dir, p_dir, r_dir) 114 | 115 | self.assertTrue(np.allclose(F_valid, F_out)) 116 | 117 | self.assertTrue(np.allclose(s_dir_valid, s_dir_out)) 118 | 119 | 120 | 121 | if __name__ == "__main__": 122 | unittest.main() -------------------------------------------------------------------------------- /test/test_agent_jax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Tests 5 | __author__: Dimitrije Markovic, Conor Heins 6 | """ 7 | 8 | import os 9 | import unittest 10 | 11 | import numpy as np 12 | import jax.numpy as jnp 13 | from jax import vmap, nn, random 14 | import jax.tree_util as jtu 15 | 16 | from pymdp.jax.maths import compute_log_likelihood_single_modality 17 | from pymdp.jax.utils import norm_dist 18 | from equinox import Module 19 | from typing import Any, List 20 | 21 | class TestAgentJax(unittest.TestCase): 22 | 23 | def test_vmappable_agent_methods(self): 24 | 25 | dim, N = 5, 10 26 | sampling_key = random.PRNGKey(1) 27 | 28 | class BasicAgent(Module): 29 | A: jnp.ndarray 30 | B: jnp.ndarray 31 | qs: jnp.ndarray 32 | 33 | def __init__(self, A, B, qs=None): 34 | self.A = A 35 | self.B = B 36 | self.qs = jnp.ones((N, dim))/dim if qs is None else qs 37 | 38 | @vmap 39 | def infer_states(self, obs): 40 | qs = nn.softmax(compute_log_likelihood_single_modality(obs, self.A)) 41 | return qs, BasicAgent(self.A, self.B, qs=qs) 42 | 43 | A_key, B_key, obs_key, test_key = random.split(sampling_key, 4) 44 | 45 | all_A = vmap(norm_dist)(random.uniform(A_key, shape = (N, dim, dim))) 46 | all_B = vmap(norm_dist)(random.uniform(B_key, shape = (N, dim, dim))) 47 | all_obs = vmap(nn.one_hot, (0, None))(random.choice(obs_key, dim, shape = (N,)), dim) 48 | 49 | my_agent = BasicAgent(all_A, all_B) 50 | 51 | all_qs, my_agent = my_agent.infer_states(all_obs) 52 | 53 | assert all_qs.shape == my_agent.qs.shape 54 | self.assertTrue(jnp.allclose(all_qs, my_agent.qs)) 55 | 56 | # validate that the method broadcasted properly 57 | for id_to_check in range(N): 58 | validation_qs = nn.softmax(compute_log_likelihood_single_modality(all_obs[id_to_check], all_A[id_to_check])) 59 | self.assertTrue(jnp.allclose(validation_qs, all_qs[id_to_check])) 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /test/test_control_jax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Tests 5 | __author__: Dimitrije Markovic, Conor Heins 6 | """ 7 | 8 | import os 9 | import unittest 10 | import pytest 11 | 12 | import numpy as np 13 | import jax.numpy as jnp 14 | import jax.random as jr 15 | import jax.tree_util as jtu 16 | 17 | import pymdp.jax.control as ctl_jax 18 | import pymdp.control as ctl_np 19 | 20 | from pymdp.jax.maths import factor_dot 21 | from 
pymdp import utils 22 | 23 | cfg = {"source_key": 0, "num_models": 4} 24 | 25 | def generate_model_params(): 26 | """ 27 | Generate random model dimensions 28 | """ 29 | rng_keys = jr.split(jr.PRNGKey(cfg["source_key"]), cfg["num_models"]) 30 | num_factors_list = [ jr.randint(key, (1,), 1, 10)[0].item() for key in rng_keys ] 31 | num_states_list = [ jr.randint(key, (nf,), 1, 5).tolist() for nf, key in zip(num_factors_list, rng_keys) ] 32 | 33 | rng_keys = jr.split(rng_keys[-1], cfg["num_models"]) 34 | num_modalities_list = [ jr.randint(key, (1,), 1, 10)[0].item() for key in rng_keys ] 35 | num_obs_list = [ jr.randint(key, (nm,), 1, 5).tolist() for nm, key in zip(num_modalities_list, rng_keys) ] 36 | 37 | rng_keys = jr.split(rng_keys[-1], cfg["num_models"]) 38 | A_deps_list = [] 39 | for nf, nm, model_key in zip(num_factors_list, num_modalities_list, rng_keys): 40 | keys_model_i = jr.split(model_key, nm) 41 | A_deps_model_i = [jr.randint(key, (nm,), 0, nf).tolist() for key in keys_model_i] 42 | A_deps_list.append(A_deps_model_i) 43 | 44 | return {'nf_list': num_factors_list, 45 | 'ns_list': num_states_list, 46 | 'nm_list': num_modalities_list, 47 | 'no_list': num_obs_list, 48 | 'A_deps_list': A_deps_list} 49 | 50 | class TestControlJax(unittest.TestCase): 51 | 52 | def test_get_expected_obs_factorized(self): 53 | """ 54 | Tests the jax-ified version of computations of expected observations under some hidden states and policy 55 | """ 56 | gm_params = generate_model_params() 57 | num_factors_list, num_states_list, num_modalities_list, num_obs_list, A_deps_list = gm_params['nf_list'], gm_params['ns_list'], gm_params['nm_list'], gm_params['no_list'], gm_params['A_deps_list'] 58 | for (num_states, num_obs, A_deps) in zip(num_states_list, num_obs_list, A_deps_list): 59 | 60 | qs_numpy = utils.random_single_categorical(num_states) 61 | qs_jax = list(qs_numpy) 62 | 63 | A_np = utils.random_A_matrix(num_obs, num_states, A_factor_list=A_deps) 64 | A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np)) 65 | 66 | qo_test = ctl_jax.compute_expected_obs(qs_jax, A_jax, A_deps) 67 | qo_validation = ctl_np.get_expected_obs_factorized([qs_numpy], A_np, A_deps) # need to wrap `qs` in list because `get_expected_obs_factorized` expects a list of `qs` (representing multiple timesteps) 68 | 69 | for qo_m, qo_val_m in zip(qo_test, qo_validation[0]): # need to extract first index of `qo_validation` because `get_expected_obs_factorized` returns a list of `qo` (representing multiple timesteps) 70 | self.assertTrue(np.allclose(qo_m, qo_val_m)) 71 | 72 | def test_info_gain_factorized(self): 73 | """ 74 | Unit test the `calc_states_info_gain_factorized` function by qualitatively checking that in the T-Maze (contextual bandit) 75 | example, the state info gain is higher for the policy that leads to visiting the cue, which is higher than state info gain 76 | for visiting the bandit arm, which in turn is higher than the state info gain for the policy that leads to staying in the start state. 
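Intuition, given the likelihood construction below: the cue modality signals the context nearly deterministically (0.9 / 0.1), the arm rewards are only weakly informative about it (0.6 / 0.4), and the start location yields a context-independent observation.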
77 | """ 78 | 79 | num_states = [2, 3] 80 | num_obs = [3, 3, 3] 81 | 82 | A_dependencies = [[0, 1], [0, 1], [1]] 83 | A = [] 84 | for m, obs in enumerate(num_obs): 85 | lagging_dimensions = [ns for i, ns in enumerate(num_states) if i in A_dependencies[m]] 86 | modality_shape = [obs] + lagging_dimensions 87 | A.append(np.zeros(modality_shape)) 88 | if m == 0: 89 | A[m][:, :, 0] = np.ones( (num_obs[m], num_states[0]) ) / num_obs[m] 90 | A[m][:, :, 1] = np.ones( (num_obs[m], num_states[0]) ) / num_obs[m] 91 | A[m][:, :, 2] = np.array([[0.9, 0.1], [0.0, 0.0], [0.1, 0.9]]) # cue statistics 92 | if m == 1: 93 | A[m][2, :, 0] = np.ones(num_states[0]) 94 | A[m][0:2, :, 1] = np.array([[0.6, 0.4], [0.6, 0.4]]) # bandit statistics (mapping between reward-state (first hidden state factor) and rewards (Good vs Bad)) 95 | A[m][2, :, 2] = np.ones(num_states[0]) 96 | if m == 2: 97 | A[m] = np.eye(obs) 98 | 99 | qs_start = list(utils.obj_array_uniform(num_states)) 100 | qs_start[1] = np.array([1., 0., 0.]) # agent believes it's in the start state 101 | 102 | A = [jnp.array(A_m) for A_m in A] 103 | qs_start = [jnp.array(qs) for qs in qs_start] 104 | qo_start = ctl_jax.compute_expected_obs(qs_start, A, A_dependencies) 105 | 106 | start_info_gain = ctl_jax.compute_info_gain(qs_start, qo_start, A, A_dependencies) 107 | 108 | qs_arm = list(utils.obj_array_uniform(num_states)) 109 | qs_arm[1] = np.array([0., 1., 0.]) # agent believes it's in the arm-visiting state 110 | qs_arm = [jnp.array(qs) for qs in qs_arm] 111 | qo_arm = ctl_jax.compute_expected_obs(qs_arm, A, A_dependencies) 112 | 113 | arm_info_gain = ctl_jax.compute_info_gain(qs_arm, qo_arm, A, A_dependencies) 114 | 115 | qs_cue = utils.obj_array_uniform(num_states) 116 | qs_cue[1] = np.array([0., 0., 1.]) # agent believes it's in the cue-visiting state 117 | qs_cue = [jnp.array(qs) for qs in qs_cue] 118 | 119 | qo_cue = ctl_jax.compute_expected_obs(qs_cue, A, A_dependencies) 120 | cue_info_gain = ctl_jax.compute_info_gain(qs_cue, qo_cue, A, A_dependencies) 121 | 122 | self.assertGreater(arm_info_gain, start_info_gain) 123 | self.assertGreater(cue_info_gain, arm_info_gain) 124 | 125 | gm_params = generate_model_params() 126 | num_factors_list, num_states_list, num_modalities_list, num_obs_list, A_deps_list = gm_params['nf_list'], gm_params['ns_list'], gm_params['nm_list'], gm_params['no_list'], gm_params['A_deps_list'] 127 | for (num_states, num_obs, A_deps) in zip(num_states_list, num_obs_list, A_deps_list): 128 | 129 | qs_numpy = utils.random_single_categorical(num_states) 130 | qs_jax = list(qs_numpy) 131 | 132 | A_np = utils.random_A_matrix(num_obs, num_states, A_factor_list=A_deps) 133 | A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np)) 134 | 135 | qo = ctl_jax.compute_expected_obs(qs_jax, A_jax, A_deps) 136 | 137 | info_gain = ctl_jax.compute_info_gain(qs_jax, qo, A_jax, A_deps) 138 | info_gain_validation = ctl_np.calc_states_info_gain_factorized(A_np, [qs_numpy], A_deps) 139 | 140 | self.assertTrue(np.allclose(info_gain, info_gain_validation, atol=1e-5)) 141 | 142 | 143 | if __name__ == "__main__": 144 | unittest.main() -------------------------------------------------------------------------------- /test/test_fpi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Tests for factorized version of variational fixed point iteration (FPI or "Vanilla FPI") 5 | __author__: Conor Heins 6 | """ 7 | 8 | import os 9 | import unittest 10 | 11 | 
import numpy as np 12 | 13 | from pymdp import utils, maths 14 | from pymdp.algos import run_vanilla_fpi, run_vanilla_fpi_factorized 15 | 16 | class TestFPI(unittest.TestCase): 17 | 18 | def test_factorized_fpi_one_factor_one_modality(self): 19 | """ 20 | Test the sparsified version of `run_vanilla_fpi`, named `run_vanilla_fpi_factorized` 21 | with single hidden state factor and single observation modality. 22 | """ 23 | 24 | num_states = [3] 25 | num_obs = [3] 26 | 27 | prior = utils.random_single_categorical(num_states) 28 | 29 | A = utils.to_obj_array(maths.softmax(np.eye(num_states[0]) * 0.1)) 30 | 31 | obs_idx = np.random.choice(num_obs[0]) 32 | obs = utils.onehot(obs_idx, num_obs[0]) 33 | 34 | mb_dict = {'A_factor_list': [[0]], 35 | 'A_modality_list': [[0]]} 36 | 37 | qs_out = run_vanilla_fpi_factorized(A, obs, num_obs, num_states, mb_dict, prior=prior)[0] 38 | qs_validation_1 = run_vanilla_fpi(A, obs, num_obs, num_states, prior=prior)[0] 39 | qs_validation_2 = maths.softmax(maths.spm_log_single(A[0][obs_idx,:]) + maths.spm_log_single(prior[0])) 40 | 41 | self.assertTrue(np.isclose(qs_validation_1, qs_out).all()) 42 | self.assertTrue(np.isclose(qs_validation_2, qs_out).all()) 43 | 44 | def test_factorized_fpi_one_factor_multi_modality(self): 45 | """ 46 | Test the sparsified version of `run_vanilla_fpi`, named `run_vanilla_fpi_factorized` 47 | with single hidden state factor and multiple observation modalities. 48 | """ 49 | 50 | num_states = [3] 51 | num_obs = [3, 2] 52 | 53 | prior = utils.random_single_categorical(num_states) 54 | 55 | A = utils.random_A_matrix(num_obs, num_states) 56 | 57 | obs = utils.obj_array(len(num_obs)) 58 | for m, obs_dim in enumerate(num_obs): 59 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 60 | 61 | mb_dict = {'A_factor_list': [[0], [0]], 62 | 'A_modality_list': [[0, 1]]} 63 | 64 | qs_out = run_vanilla_fpi_factorized(A, obs, num_obs, num_states, mb_dict, prior=prior)[0] 65 | qs_validation = run_vanilla_fpi(A, obs, num_obs, num_states, prior=prior)[0] 66 | 67 | self.assertTrue(np.isclose(qs_validation, qs_out).all()) 68 | 69 | def test_factorized_fpi_multi_factor_one_modality(self): 70 | """ 71 | Test the sparsified version of `run_vanilla_fpi`, named `run_vanilla_fpi_factorized` 72 | with multiple hidden state factors and one observation modality. 73 | """ 74 | 75 | num_states = [4, 5] 76 | num_obs = [3] 77 | 78 | prior = utils.random_single_categorical(num_states) 79 | 80 | A = utils.random_A_matrix(num_obs, num_states) 81 | 82 | obs_idx = np.random.choice(num_obs[0]) 83 | obs = utils.onehot(obs_idx, num_obs[0]) 84 | 85 | mb_dict = {'A_factor_list': [[0, 1]], 86 | 'A_modality_list': [[0], [0]]} 87 | 88 | qs_out = run_vanilla_fpi_factorized(A, obs, num_obs, num_states, mb_dict, prior=prior) 89 | qs_validation = run_vanilla_fpi(A, obs, num_obs, num_states, prior=prior) 90 | 91 | for qs_f_val, qs_f_out in zip(qs_validation, qs_out): 92 | self.assertTrue(np.isclose(qs_f_val, qs_f_out).all()) 93 | 94 | def test_factorized_fpi_multi_factor_multi_modality(self): 95 | """ 96 | Test the sparsified version of `run_vanilla_fpi`, named `run_vanilla_fpi_factorized` 97 | with multiple hidden state factors and multiple observation modalities. 
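Here every modality depends on every factor (see `mb_dict`), so the factorized and dense routines share identical Markov blankets and should agree exactly.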
98 | """ 99 | 100 | num_states = [3, 4] 101 | num_obs = [3, 3, 5] 102 | 103 | prior = utils.random_single_categorical(num_states) 104 | 105 | A = utils.random_A_matrix(num_obs, num_states) 106 | 107 | obs = utils.obj_array(len(num_obs)) 108 | for m, obs_dim in enumerate(num_obs): 109 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 110 | 111 | mb_dict = {'A_factor_list': [[0, 1], [0, 1], [0, 1]], 112 | 'A_modality_list': [[0, 1, 2], [0, 1, 2]]} 113 | 114 | qs_out = run_vanilla_fpi_factorized(A, obs, num_obs, num_states, mb_dict, prior=prior) 115 | qs_validation = run_vanilla_fpi(A, obs, num_obs, num_states, prior=prior) 116 | 117 | for qs_f_val, qs_f_out in zip(qs_validation, qs_out): 118 | self.assertTrue(np.isclose(qs_f_val, qs_f_out).all()) 119 | 120 | # test it also without computing VFE (i.e. with `compute_vfe=False`) 121 | qs_out = run_vanilla_fpi_factorized(A, obs, num_obs, num_states, mb_dict, prior=prior, compute_vfe=False) 122 | qs_validation = run_vanilla_fpi(A, obs, num_obs, num_states, prior=prior, compute_vfe=False) 123 | 124 | for qs_f_val, qs_f_out in zip(qs_validation, qs_out): 125 | self.assertTrue(np.isclose(qs_f_val, qs_f_out).all()) 126 | 127 | def test_factorized_fpi_multi_factor_multi_modality_with_condind(self): 128 | """ 129 | Test the sparsified version of `run_vanilla_fpi`, named `run_vanilla_fpi_factorized` 130 | with multiple hidden state factors and multiple observation modalities, where some modalities only depend on some factors. 131 | """ 132 | 133 | num_states = [3, 4] 134 | num_obs = [3, 3, 5] 135 | 136 | prior = utils.random_single_categorical(num_states) 137 | 138 | obs = utils.obj_array(len(num_obs)) 139 | for m, obs_dim in enumerate(num_obs): 140 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 141 | 142 | mb_dict = {'A_factor_list': [[0], [1], [0, 1]], 143 | 'A_modality_list': [[0, 2], [1, 2]]} 144 | 145 | A_reduced = utils.random_A_matrix(num_obs, num_states, A_factor_list=mb_dict['A_factor_list']) 146 | 147 | qs_out = run_vanilla_fpi_factorized(A_reduced, obs, num_obs, num_states, mb_dict, prior=prior) 148 | 149 | A_full = utils.initialize_empty_A(num_obs, num_states) 150 | for m, A_m in enumerate(A_full): 151 | other_factors = list(set(range(len(num_states))) - set(mb_dict['A_factor_list'][m])) # list of the factors that modality `m` does not depend on 152 | 153 | # broadcast or tile the reduced A matrix (`A_reduced`) along the dimensions corresponding to `other_factors` 154 | expanded_dims = [num_obs[m]] + [1 if f in other_factors else ns for (f, ns) in enumerate(num_states)] 155 | tile_dims = [1] + [ns if f in other_factors else 1 for (f, ns) in enumerate(num_states)] 156 | A_full[m] = np.tile(A_reduced[m].reshape(expanded_dims), tile_dims) 157 | 158 | qs_validation = run_vanilla_fpi(A_full, obs, num_obs, num_states, prior=prior) 159 | 160 | for qs_f_val, qs_f_out in zip(qs_validation, qs_out): 161 | self.assertTrue(np.isclose(qs_f_val, qs_f_out).all()) 162 | 163 | def test_factorized_fpi_multi_factor_single_modality_with_condind(self): 164 | """ 165 | Test the sparsified version of `run_vanilla_fpi`, named `run_vanilla_fpi_factorized` 166 | with multiple hidden state factors and one observation modality, where the modality only depends on some factors.
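Since the second hidden state factor lies outside the modality's Markov blanket, its posterior should stay equal to its prior; this is asserted at the end of the test.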
167 | """ 168 | 169 | num_states = [3, 4] 170 | num_obs = [3] 171 | 172 | prior = utils.random_single_categorical(num_states) 173 | 174 | obs = utils.obj_array(len(num_obs)) 175 | for m, obs_dim in enumerate(num_obs): 176 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 177 | 178 | mb_dict = {'A_factor_list': [[0]], 179 | 'A_modality_list': [[0], []]} 180 | 181 | A_reduced = utils.random_A_matrix(num_obs, num_states, A_factor_list=mb_dict['A_factor_list']) 182 | 183 | qs_out = run_vanilla_fpi_factorized(A_reduced, obs, num_obs, num_states, mb_dict, prior=prior) 184 | 185 | A_full = utils.initialize_empty_A(num_obs, num_states) 186 | for m, A_m in enumerate(A_full): 187 | other_factors = list(set(range(len(num_states))) - set(mb_dict['A_factor_list'][m])) # list of the factors that modality `m` does not depend on 188 | 189 | # broadcast or tile the reduced A matrix (`A_reduced`) along the dimensions corresponding to `other_factors` 190 | expanded_dims = [num_obs[m]] + [1 if f in other_factors else ns for (f, ns) in enumerate(num_states)] 191 | tile_dims = [1] + [ns if f in other_factors else 1 for (f, ns) in enumerate(num_states)] 192 | A_full[m] = np.tile(A_reduced[m].reshape(expanded_dims), tile_dims) 193 | 194 | qs_validation = run_vanilla_fpi(A_full, obs, num_obs, num_states, prior=prior) 195 | 196 | for qs_f_val, qs_f_out in zip(qs_validation, qs_out): 197 | self.assertTrue(np.isclose(qs_f_val, qs_f_out).all()) 198 | 199 | self.assertTrue(np.isclose(qs_out[1], prior[1]).all()) 200 | 201 | 202 | if __name__ == "__main__": 203 | unittest.main() -------------------------------------------------------------------------------- /test/test_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Tests 5 | __author__: Conor Heins, Alexander Tschantz, Daphne Demekas, Brennan Klein 6 | """ 7 | 8 | import os 9 | import unittest 10 | 11 | import numpy as np 12 | 13 | from pymdp import utils, maths 14 | from pymdp import inference 15 | 16 | class TestInference(unittest.TestCase): 17 | 18 | def test_update_posterior_states(self): 19 | """ 20 | Tests the refactored version of `update_posterior_states` 21 | """ 22 | 23 | '''Test with single hidden state factor and single observation modality''' 24 | 25 | num_states = [3] 26 | num_obs = [3] 27 | 28 | prior = utils.random_single_categorical(num_states) 29 | 30 | A = utils.to_obj_array(maths.softmax(np.eye(num_states[0]) * 0.1)) 31 | 32 | obs_idx = 1 33 | obs = utils.onehot(obs_idx, num_obs[0]) 34 | 35 | qs_out = inference.update_posterior_states(A, obs, prior=prior) 36 | qs_validation = maths.softmax(maths.spm_log_single(A[0][obs_idx,:]) + maths.spm_log_single(prior[0])) 37 | 38 | self.assertTrue(np.isclose(qs_validation, qs_out[0]).all()) 39 | 40 | '''Try single modality inference where the observation is passed in as an int''' 41 | qs_out_2 = inference.update_posterior_states(A, obs_idx, prior=prior) 42 | self.assertTrue(np.isclose(qs_out_2[0], qs_out[0]).all()) 43 | 44 | '''Try single modality inference where the observation is a one-hot stored in an object array''' 45 | qs_out_3 = inference.update_posterior_states(A, utils.to_obj_array(obs), prior=prior) 46 | self.assertTrue(np.isclose(qs_out_3[0], qs_out[0]).all()) 47 | 48 | '''Test with multiple hidden state factors and single observation modality''' 49 | 50 | num_states = [3, 4] 51 | num_obs = [3] 52 | 53 | prior = utils.random_single_categorical(num_states) 54 | 55 | A = 
utils.random_A_matrix(num_obs, num_states) 56 | 57 | obs_idx = 1 58 | obs = utils.onehot(obs_idx, num_obs[0]) 59 | 60 | qs_out = inference.update_posterior_states(A, obs, prior=prior, num_iter = 1) 61 | 62 | # validate with a quick n' dirty implementation of FPI 63 | 64 | # initialize posterior and log prior 65 | qs_valid_init = utils.obj_array_uniform(num_states) 66 | log_prior = maths.spm_log_obj_array(prior) 67 | 68 | qs_valid_final = utils.obj_array(len(num_states)) 69 | 70 | log_likelihood = maths.spm_log_single(maths.get_joint_likelihood(A, obs, num_states)) 71 | 72 | num_factors = len(num_states) 73 | 74 | qs_valid_init_all = qs_valid_init[0] 75 | for factor in range(num_factors-1): 76 | qs_valid_init_all = qs_valid_init_all[...,None]*qs_valid_init[factor+1] 77 | LL_tensor = log_likelihood * qs_valid_init_all 78 | 79 | factor_ids = range(num_factors) 80 | 81 | for factor, qs_f in enumerate(qs_valid_init): 82 | ax2sum = tuple(set(factor_ids) - set([factor])) # which axes to sum out 83 | qL = LL_tensor.sum(axis = ax2sum) / qs_f 84 | qs_valid_final[factor] = maths.softmax(qL + log_prior[factor]) 85 | 86 | for factor, qs_f_valid in enumerate(qs_valid_final): 87 | self.assertTrue(np.isclose(qs_f_valid, qs_out[factor]).all()) 88 | 89 | '''Test with multiple hidden state factors and multiple observation modalities, for two different kinds of observation input formats''' 90 | 91 | num_states = [3, 4] 92 | num_obs = [3, 3, 5] 93 | 94 | prior = utils.random_single_categorical(num_states) 95 | 96 | A = utils.random_A_matrix(num_obs, num_states) 97 | 98 | obs_index_tuple = tuple([np.random.randint(obs_dim) for obs_dim in num_obs]) 99 | 100 | qs_out1 = inference.update_posterior_states(A, obs_index_tuple, prior=prior) 101 | 102 | obs_onehots = utils.obj_array(len(num_obs)) 103 | for g in range(len(num_obs)): 104 | obs_onehots[g] = utils.onehot(obs_index_tuple[g], num_obs[g]) 105 | 106 | qs_out2 = inference.update_posterior_states(A, obs_onehots, prior=prior) 107 | 108 | for factor in range(len(num_states)): 109 | self.assertTrue(np.isclose(qs_out1[factor], qs_out2[factor]).all()) 110 | 111 | def test_update_posterior_states_factorized_single_factor(self): 112 | """ 113 | Tests the version of `update_posterior_states` where an `mb_dict` is provided as an argument to factorize 114 | the fixed-point iteration (FPI) algorithm. Single factor version. 
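With one factor and one modality, the update has a closed-form solution (a softmax of the log-likelihood plus the log-prior), which is used as the validation target below.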
115 | """ 116 | num_states = [3] 117 | num_obs = [3] 118 | 119 | prior = utils.random_single_categorical(num_states) 120 | 121 | A = utils.to_obj_array(maths.softmax(np.eye(num_states[0]) * 0.1)) 122 | 123 | obs_idx = 1 124 | obs = utils.onehot(obs_idx, num_obs[0]) 125 | 126 | mb_dict = {'A_factor_list': [[0]], 127 | 'A_modality_list': [[0]]} 128 | 129 | qs_out = inference.update_posterior_states_factorized(A, obs, num_obs, num_states, mb_dict, prior=prior) 130 | qs_validation = maths.softmax(maths.spm_log_single(A[0][obs_idx,:]) + maths.spm_log_single(prior[0])) 131 | 132 | self.assertTrue(np.isclose(qs_validation, qs_out[0]).all()) 133 | 134 | '''Try single modality inference where the observation is passed in as an int''' 135 | qs_out_2 = inference.update_posterior_states_factorized(A, obs_idx, num_obs, num_states, mb_dict, prior=prior) 136 | self.assertTrue(np.isclose(qs_out_2[0], qs_out[0]).all()) 137 | 138 | '''Try single modality inference where the observation is a one-hot stored in an object array''' 139 | qs_out_3 = inference.update_posterior_states_factorized(A, utils.to_obj_array(obs), num_obs, num_states, mb_dict, prior=prior) 140 | self.assertTrue(np.isclose(qs_out_3[0], qs_out[0]).all()) 141 | 142 | def test_update_posterior_states_factorized(self): 143 | """ 144 | Tests the version of `update_posterior_states` where an `mb_dict` is provided as an argument to factorize 145 | the fixed-point iteration (FPI) algorithm. 146 | """ 147 | 148 | num_states = [3, 4] 149 | num_obs = [3, 3, 5] 150 | 151 | prior = utils.random_single_categorical(num_states) 152 | 153 | obs_index_tuple = tuple([np.random.randint(obs_dim) for obs_dim in num_obs]) 154 | 155 | mb_dict = {'A_factor_list': [[0], [1], [0, 1]], 156 | 'A_modality_list': [[0, 2], [1, 2]]} 157 | 158 | A_reduced = utils.random_A_matrix(num_obs, num_states, A_factor_list=mb_dict['A_factor_list']) 159 | 160 | qs_out = inference.update_posterior_states_factorized(A_reduced, obs_index_tuple, num_obs, num_states, mb_dict, prior=prior) 161 | 162 | A_full = utils.initialize_empty_A(num_obs, num_states) 163 | for m, A_m in enumerate(A_full): 164 | other_factors = list(set(range(len(num_states))) - set(mb_dict['A_factor_list'][m])) # list of the factors that modality `m` does not depend on 165 | 166 | # broadcast or tile the reduced A matrix (`A_reduced`) along the dimensions corresponding to `other_factors` 167 | expanded_dims = [num_obs[m]] + [1 if f in other_factors else ns for (f, ns) in enumerate(num_states)] 168 | tile_dims = [1] + [ns if f in other_factors else 1 for (f, ns) in enumerate(num_states)] 169 | A_full[m] = np.tile(A_reduced[m].reshape(expanded_dims), tile_dims) 170 | 171 | qs_validation = inference.update_posterior_states(A_full, obs_index_tuple, prior=prior) 172 | 173 | for qs_f_val, qs_f_out in zip(qs_validation, qs_out): 174 | self.assertTrue(np.isclose(qs_f_val, qs_f_out).all()) 175 | 176 | def test_update_posterior_states_factorized_noVFE_compute(self): 177 | """ 178 | Tests the version of `update_posterior_states` where an `mb_dict` is provided as an argument to factorize 179 | the fixed-point iteration (FPI) algorithm. 180 | 181 | In this version, we always run the total number of iterations because we don't compute the variational free energy over the course of convergence/optimization.
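Without the free energy there is no convergence criterion for early stopping, so the factorized and dense implementations run the same fixed number of iterations and should produce matching posteriors.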
182 | """ 183 | 184 | num_states = [3, 4] 185 | num_obs = [3, 3, 5] 186 | 187 | prior = utils.random_single_categorical(num_states) 188 | 189 | obs_index_tuple = tuple([np.random.randint(obs_dim) for obs_dim in num_obs]) 190 | 191 | mb_dict = {'A_factor_list': [[0], [1], [0, 1]], 192 | 'A_modality_list': [[0, 2], [1, 2]]} 193 | 194 | A_reduced = utils.random_A_matrix(num_obs, num_states, A_factor_list=mb_dict['A_factor_list']) 195 | 196 | qs_out = inference.update_posterior_states_factorized(A_reduced, obs_index_tuple, num_obs, num_states, mb_dict, prior=prior, compute_vfe=False) 197 | 198 | A_full = utils.initialize_empty_A(num_obs, num_states) 199 | for m, A_m in enumerate(A_full): 200 | other_factors = list(set(range(len(num_states))) - set(mb_dict['A_factor_list'][m])) # list of the factors that modality `m` does not depend on 201 | 202 | # broadcast or tile the reduced A matrix (`A_reduced`) along the dimensions corresponding to `other_factors` 203 | expanded_dims = [num_obs[m]] + [1 if f in other_factors else ns for (f, ns) in enumerate(num_states)] 204 | tile_dims = [1] + [ns if f in other_factors else 1 for (f, ns) in enumerate(num_states)] 205 | A_full[m] = np.tile(A_reduced[m].reshape(expanded_dims), tile_dims) 206 | 207 | qs_validation = inference.update_posterior_states(A_full, obs_index_tuple, prior=prior, compute_vfe=False) 208 | 209 | for qs_f_val, qs_f_out in zip(qs_validation, qs_out): 210 | self.assertTrue(np.isclose(qs_f_val, qs_f_out).all()) 211 | 212 | 213 | if __name__ == "__main__": 214 | unittest.main() -------------------------------------------------------------------------------- /test/test_inference_jax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Tests 5 | __author__: Dimitrije Markovic, Conor Heins 6 | """ 7 | 8 | import os 9 | import unittest 10 | 11 | import numpy as np 12 | import jax.numpy as jnp 13 | 14 | from pymdp.jax.algos import run_vanilla_fpi as fpi_jax 15 | from pymdp.algos import run_vanilla_fpi as fpi_numpy 16 | from pymdp import utils, maths 17 | 18 | class TestInferenceJax(unittest.TestCase): 19 | 20 | def test_fixed_point_iteration_singlestate_singleobs(self): 21 | """ 22 | Tests the jax-ified version of mean-field fixed-point iteration against the original numpy version.
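Both implementations are run for a fixed 16 iterations, with early stopping disabled in the numpy version via a negative `dF_tol`.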
23 | In this version there is one hidden state factor and one observation modality 24 | """ 25 | 26 | num_states_list = [ 27 | [1], 28 | [5], 29 | [10] 30 | ] 31 | 32 | num_obs_list = [ 33 | [5], 34 | [1], 35 | [2] 36 | ] 37 | 38 | for (num_states, num_obs) in zip(num_states_list, num_obs_list): 39 | 40 | # numpy version 41 | prior = utils.random_single_categorical(num_states) 42 | A = utils.random_A_matrix(num_obs, num_states) 43 | 44 | obs = utils.obj_array(len(num_obs)) 45 | for m, obs_dim in enumerate(num_obs): 46 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 47 | 48 | qs_numpy = fpi_numpy(A, obs, num_obs, num_states, prior=prior, num_iter=16, dF=1.0, dF_tol=-1.0) # set dF_tol to negative number so numpy version of FPI never stops early due to convergence 49 | 50 | # jax version 51 | prior = [jnp.array(prior_f) for prior_f in prior] 52 | A = [jnp.array(a_m) for a_m in A] 53 | obs = [jnp.array(o_m) for o_m in obs] 54 | 55 | qs_jax = fpi_jax(A, obs, prior, num_iter=16) 56 | 57 | for f, _ in enumerate(qs_jax): 58 | self.assertTrue(np.allclose(qs_numpy[f], qs_jax[f])) 59 | 60 | def test_fixed_point_iteration_singlestate_multiobs(self): 61 | """ 62 | Tests the jax-ified version of mean-field fixed-point iteration against the original numpy version. 63 | In this version there is one hidden state factor and multiple observation modalities 64 | """ 65 | 66 | num_states_list = [ 67 | [1], 68 | [5], 69 | [10] 70 | ] 71 | 72 | num_obs_list = [ 73 | [5, 2], 74 | [1, 8, 9], 75 | [2, 2, 2] 76 | ] 77 | 78 | for (num_states, num_obs) in zip(num_states_list, num_obs_list): 79 | 80 | # numpy version 81 | prior = utils.random_single_categorical(num_states) 82 | A = utils.random_A_matrix(num_obs, num_states) 83 | 84 | obs = utils.obj_array(len(num_obs)) 85 | for m, obs_dim in enumerate(num_obs): 86 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 87 | 88 | qs_numpy = fpi_numpy(A, obs, num_obs, num_states, prior=prior, num_iter=16, dF=1.0, dF_tol=-1.0) # set dF_tol to negative number so numpy version of FPI never stops early due to convergence 89 | 90 | # jax version 91 | prior = [jnp.array(prior_f) for prior_f in prior] 92 | A = [jnp.array(a_m) for a_m in A] 93 | obs = [jnp.array(o_m) for o_m in obs] 94 | 95 | qs_jax = fpi_jax(A, obs, prior, num_iter=16) 96 | 97 | for f, _ in enumerate(qs_jax): 98 | self.assertTrue(np.allclose(qs_numpy[f], qs_jax[f])) 99 | 100 | def test_fixed_point_iteration_multistate_singleobs(self): 101 | """ 102 | Tests the jax-ified version of mean-field fixed-point iteration against the original numpy version. 
103 | In this version there are multiple hidden state factors and a single observation modality 104 | """ 105 | 106 | num_states_list = [ 107 | [1, 10, 2], 108 | [5, 5, 10, 2], 109 | [10, 2] 110 | ] 111 | 112 | num_obs_list = [ 113 | [5], 114 | [1], 115 | [10] 116 | ] 117 | 118 | for (num_states, num_obs) in zip(num_states_list, num_obs_list): 119 | 120 | # numpy version 121 | prior = utils.random_single_categorical(num_states) 122 | A = utils.random_A_matrix(num_obs, num_states) 123 | 124 | obs = utils.obj_array(len(num_obs)) 125 | for m, obs_dim in enumerate(num_obs): 126 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 127 | 128 | qs_numpy = fpi_numpy(A, obs, num_obs, num_states, prior=prior, num_iter=16, dF=1.0, dF_tol=-1.0) # set dF_tol to negative number so numpy version of FPI never stops early due to convergence 129 | 130 | # jax version 131 | prior = [jnp.array(prior_f) for prior_f in prior] 132 | A = [jnp.array(a_m) for a_m in A] 133 | obs = [jnp.array(o_m) for o_m in obs] 134 | 135 | qs_jax = fpi_jax(A, obs, prior, num_iter=16) 136 | 137 | for f, _ in enumerate(qs_jax): 138 | self.assertTrue(np.allclose(qs_numpy[f], qs_jax[f])) 139 | 140 | 141 | def test_fixed_point_iteration_multistate_multiobs(self): 142 | """ 143 | Tests the jax-ified version of mean-field fixed-point iteration against the original numpy version. 144 | In this version there are multiple hidden state factors and multiple observation modalities 145 | """ 146 | 147 | ''' Start by creating a collection of random generative models with different 148 | cardinalities and dimensionalities of hidden state factors and observation modalities''' 149 | 150 | num_states_list = [ 151 | [2, 2, 5], 152 | [2, 2, 2], 153 | [4, 4] 154 | ] 155 | 156 | num_obs_list = [ 157 | [5, 10], 158 | [4, 3, 2], 159 | [5, 10, 6] 160 | ] 161 | 162 | for (num_states, num_obs) in zip(num_states_list, num_obs_list): 163 | 164 | # numpy version 165 | prior = utils.random_single_categorical(num_states) 166 | A = utils.random_A_matrix(num_obs, num_states) 167 | 168 | obs = utils.obj_array(len(num_obs)) 169 | for m, obs_dim in enumerate(num_obs): 170 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 171 | 172 | qs_numpy = fpi_numpy(A, obs, num_obs, num_states, prior=prior, num_iter=16, dF=1.0, dF_tol=-1.0) # set dF_tol to negative number so numpy version of FPI never stops early due to convergence 173 | 174 | # jax version 175 | prior = [jnp.array(prior_f) for prior_f in prior] 176 | A = [jnp.array(a_m) for a_m in A] 177 | obs = [jnp.array(o_m) for o_m in obs] 178 | 179 | qs_jax = fpi_jax(A, obs, prior, num_iter=16) 180 | 181 | for f, _ in enumerate(qs_jax): 182 | self.assertTrue(np.allclose(qs_numpy[f], qs_jax[f])) 183 | 184 | def test_fixed_point_iteration_index_observations(self): 185 | """ 186 | Tests the jax-ified version of mean-field fixed-point iteration against the original NumPy version. 187 | In this version there are multiple hidden state factors and multiple observation modalities. 
188 | 189 | Test the jax version with index-based observations (not one-hots) 190 | """ 191 | 192 | ''' Start by creating a collection of random generative models with different 193 | cardinalities and dimensionalities of hidden state factors and observation modalities''' 194 | 195 | num_states_list = [ 196 | [2, 2, 5], 197 | [2, 2, 2], 198 | [4, 4] 199 | ] 200 | 201 | num_obs_list = [ 202 | [5, 10], 203 | [4, 3, 2], 204 | [5, 10, 6] 205 | ] 206 | 207 | for (num_states, num_obs) in zip(num_states_list, num_obs_list): 208 | 209 | # numpy version 210 | prior = utils.random_single_categorical(num_states) 211 | A = utils.random_A_matrix(num_obs, num_states) 212 | 213 | obs = utils.obj_array(len(num_obs)) 214 | for m, obs_dim in enumerate(num_obs): 215 | obs[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 216 | 217 | qs_numpy = fpi_numpy(A, obs, num_obs, num_states, prior=prior, num_iter=16, dF=1.0, dF_tol=-1.0) # set dF_tol to negative number so numpy version of FPI never stops early due to convergence 218 | 219 | obs_idx = [] 220 | for ob in obs: 221 | obs_idx.append(np.where(ob)[0][0]) 222 | 223 | # jax version 224 | prior = [jnp.array(prior_f) for prior_f in prior] 225 | A = [jnp.array(a_m) for a_m in A] 226 | # obs = [jnp.array(o_m) for o_m in obs] 227 | 228 | qs_jax = fpi_jax(A, obs_idx, prior, num_iter=16, distr_obs=False) 229 | 230 | for f, _ in enumerate(qs_jax): 231 | self.assertTrue(np.allclose(qs_numpy[f], qs_jax[f])) 232 | 233 | if __name__ == "__main__": 234 | unittest.main() -------------------------------------------------------------------------------- /test/test_learning_jax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Tests 5 | __author__: Dimitrije Markovic, Conor Heins 6 | """ 7 | 8 | import unittest 9 | 10 | import numpy as np 11 | import jax.numpy as jnp 12 | import jax.tree_util as jtu 13 | 14 | from pymdp.learning import update_obs_likelihood_dirichlet as update_pA_numpy 15 | from pymdp.learning import update_obs_likelihood_dirichlet_factorized as update_pA_numpy_factorized 16 | from pymdp.jax.learning import update_obs_likelihood_dirichlet as update_pA_jax 17 | from pymdp import utils 18 | 19 | class TestLearningJax(unittest.TestCase): 20 | 21 | def test_update_observation_likelihood_fullyconnected(self): 22 | """ 23 | Testing JAX-ified version of updating Dirichlet posterior over observation likelihood parameters (qA is posterior, pA is prior, and A is expectation 24 | of the likelihood with respect to the current posterior over A, i.e. $A = E_{Q(A)}[P(o|s,A)]$). 25 | 26 | This is the so-called 'fully-connected' version where all hidden state factors drive each modality (i.e. 
A_dependencies is a list of lists of hidden state factors) 27 | """ 28 | 29 | num_obs_list = [ [5], 30 | [10, 3, 2], 31 | [2, 4, 4, 2], 32 | [10] 33 | ] 34 | num_states_list = [ [2,3,4], 35 | [2], 36 | [4,5], 37 | [3] 38 | ] 39 | 40 | A_dependencies_list = [ [ [0,1,2] ], 41 | [ [0], [0], [0] ], 42 | [ [0,1], [0,1], [0,1], [0,1] ], 43 | [ [0] ] 44 | ] 45 | 46 | for (num_obs, num_states, A_dependencies) in zip(num_obs_list, num_states_list, A_dependencies_list): 47 | # create numpy arrays to test numpy version of learning 48 | 49 | # create A matrix initialization (expected initial value of P(o|s, A)) and prior over A (pA) 50 | A_np = utils.random_A_matrix(num_obs, num_states) 51 | pA_np = utils.dirichlet_like(A_np, scale = 3.0) 52 | 53 | # create random observations 54 | obs_np = utils.obj_array(len(num_obs)) 55 | for m, obs_dim in enumerate(num_obs): 56 | obs_np[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 57 | 58 | # create random state posterior 59 | qs_np = utils.random_single_categorical(num_states) 60 | 61 | l_rate = 1.0 62 | 63 | # run numpy version of learning 64 | qA_np_test = update_pA_numpy(pA_np, A_np, obs_np, qs_np, lr=l_rate) 65 | 66 | pA_jax = jtu.tree_map(lambda x: jnp.array(x), list(pA_np)) 67 | A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np)) 68 | obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np)) 69 | qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np)) 70 | 71 | qA_jax_test, E_qA_jax_test = update_pA_jax( 72 | pA_jax, 73 | A_jax, 74 | obs_jax, 75 | qs_jax, 76 | A_dependencies=A_dependencies, 77 | onehot_obs=True, 78 | num_obs=num_obs, 79 | lr=l_rate 80 | ) 81 | 82 | for modality, obs_dim in enumerate(num_obs): 83 | self.assertTrue(np.allclose(qA_jax_test[modality], qA_np_test[modality])) 84 | 85 | def test_update_observation_likelihood_factorized(self): 86 | """ 87 | Testing JAX-ified version of updating Dirichlet posterior over observation likelihood parameters (qA is posterior, pA is prior, and A is expectation 88 | of the likelihood with respect to the current posterior over A, i.e. $A = E_{Q(A)}[P(o|s,A)]$). 89 | 90 | This is the factorized version where only some hidden state factors drive each modality (i.e. 
A_dependencies is a list of lists of hidden state factors) 91 | """ 92 | 93 | num_obs_list = [ [5], 94 | [10, 3, 2], 95 | [2, 4, 4, 2], 96 | [10] 97 | ] 98 | num_states_list = [ [2,3,4], 99 | [2, 5, 2], 100 | [4,5], 101 | [3] 102 | ] 103 | 104 | A_dependencies_list = [ [ [0,1] ], 105 | [ [0, 1], [1], [1, 2] ], 106 | [ [0,1], [0], [0,1], [1] ], 107 | [ [0] ] 108 | ] 109 | 110 | for (num_obs, num_states, A_dependencies) in zip(num_obs_list, num_states_list, A_dependencies_list): 111 | # create numpy arrays to test numpy version of learning 112 | 113 | # create A matrix initialization (expected initial value of P(o|s, A)) and prior over A (pA) 114 | A_np = utils.random_A_matrix(num_obs, num_states, A_factor_list=A_dependencies) 115 | pA_np = utils.dirichlet_like(A_np, scale = 3.0) 116 | 117 | # create random observations 118 | obs_np = utils.obj_array(len(num_obs)) 119 | for m, obs_dim in enumerate(num_obs): 120 | obs_np[m] = utils.onehot(np.random.randint(obs_dim), obs_dim) 121 | 122 | # create random state posterior 123 | qs_np = utils.random_single_categorical(num_states) 124 | 125 | l_rate = 1.0 126 | 127 | # run numpy version of learning 128 | qA_np_test = update_pA_numpy_factorized(pA_np, A_np, obs_np, qs_np, A_dependencies, lr=l_rate) 129 | 130 | pA_jax = jtu.tree_map(lambda x: jnp.array(x), list(pA_np)) 131 | A_jax = jtu.tree_map(lambda x: jnp.array(x), list(A_np)) 132 | obs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(obs_np)) 133 | qs_jax = jtu.tree_map(lambda x: jnp.array(x)[None], list(qs_np)) 134 | 135 | qA_jax_test, E_qA_jax_test = update_pA_jax( 136 | pA_jax, 137 | A_jax, 138 | obs_jax, 139 | qs_jax, 140 | A_dependencies=A_dependencies, 141 | onehot_obs=True, 142 | num_obs=num_obs, 143 | lr=l_rate 144 | ) 145 | 146 | for modality, obs_dim in enumerate(num_obs): 147 | self.assertTrue(np.allclose(qA_jax_test[modality],qA_np_test[modality])) 148 | 149 | if __name__ == "__main__": 150 | unittest.main() -------------------------------------------------------------------------------- /test/test_mmp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Test of `run_mmp` function under various parameterisations 5 | 6 | __date__: 25/11/2020 7 | __author__: Conor Heins, Alexander Tschantz 8 | """ 9 | 10 | import os 11 | import unittest 12 | 13 | import numpy as np 14 | from scipy.io import loadmat 15 | 16 | from pymdp.utils import get_model_dimensions, convert_observation_array 17 | from pymdp.algos import run_mmp 18 | from pymdp.maths import get_joint_likelihood_seq 19 | 20 | DATA_PATH = "test/matlab_crossval/output/" 21 | 22 | class MMP(unittest.TestCase): 23 | 24 | def test_mmp_a(self): 25 | """ 26 | Testing our SPM-ified version of `run_MMP` with 27 | 1 hidden state factor & 1 outcome modality, at a random fixed 28 | timestep during the generative process 29 | """ 30 | 31 | array_path = os.path.join(os.getcwd(), DATA_PATH + "mmp_a.mat") 32 | mat_contents = loadmat(file_name=array_path) 33 | 34 | A = mat_contents["A"][0] 35 | B = mat_contents["B"][0] 36 | prev_obs = mat_contents["obs_idx"].astype("int64") 37 | policy = mat_contents["policy"].astype("int64") - 1 38 | curr_t = mat_contents["t"][0, 0].astype("int64") - 1 39 | t_horizon = mat_contents["t_horizon"][0, 0].astype("int64") 40 | prev_actions = mat_contents["previous_actions"].astype("int64") - 1 41 | result_spm = mat_contents["qs"][0] 42 | likelihoods = mat_contents["likelihoods"][0] 43 | 44 | num_obs, num_states, _, 
num_factors = get_model_dimensions(A, B) 45 | prev_obs = convert_observation_array( 46 | prev_obs[:, max(0, curr_t - t_horizon) : (curr_t + 1)], num_obs 47 | ) 48 | 49 | prev_actions = prev_actions[(max(0, curr_t - t_horizon) -1) :, :] 50 | prior = np.empty(num_factors, dtype=object) 51 | for f in range(num_factors): 52 | uniform = np.ones(num_states[f]) / num_states[f] 53 | prior[f] = B[f][:, :, prev_actions[0, f]].dot(uniform) 54 | 55 | lh_seq = get_joint_likelihood_seq(A, prev_obs, num_states) 56 | qs_seq, _ = run_mmp( 57 | lh_seq, B, policy, prev_actions[1:], prior=prior, num_iter=5, grad_descent=True 58 | ) 59 | 60 | result_pymdp = qs_seq[-1] 61 | for f in range(num_factors): 62 | self.assertTrue(np.isclose(result_spm[f].squeeze(), result_pymdp[f]).all()) 63 | 64 | def test_mmp_b(self): 65 | """ Testing our SPM-ified version of `run_MMP` with 66 | 2 hidden state factors & 2 outcome modalities, at a random fixed 67 | timestep during the generative process""" 68 | 69 | array_path = os.path.join(os.getcwd(), DATA_PATH + "mmp_b.mat") 70 | mat_contents = loadmat(file_name=array_path) 71 | 72 | A = mat_contents["A"][0] 73 | B = mat_contents["B"][0] 74 | prev_obs = mat_contents["obs_idx"].astype("int64") 75 | policy = mat_contents["policy"].astype("int64") - 1 76 | curr_t = mat_contents["t"][0, 0].astype("int64") - 1 77 | t_horizon = mat_contents["t_horizon"][0, 0].astype("int64") 78 | prev_actions = mat_contents["previous_actions"].astype("int64") - 1 79 | result_spm = mat_contents["qs"][0] 80 | likelihoods = mat_contents["likelihoods"][0] 81 | 82 | num_obs, num_states, _, num_factors = get_model_dimensions(A, B) 83 | prev_obs = convert_observation_array( 84 | prev_obs[:, max(0, curr_t - t_horizon) : (curr_t + 1)], num_obs 85 | ) 86 | 87 | prev_actions = prev_actions[(max(0, curr_t - t_horizon)) :, :] 88 | lh_seq = get_joint_likelihood_seq(A, prev_obs, num_states) 89 | qs_seq, _ = run_mmp(lh_seq, 90 | B, policy, prev_actions=prev_actions, prior=None, num_iter=5, grad_descent=True 91 | ) 92 | 93 | result_pymdp = qs_seq[-1] 94 | for f in range(num_factors): 95 | self.assertTrue(np.isclose(result_spm[f].squeeze(), result_pymdp[f]).all()) 96 | 97 | def test_mmp_c(self): 98 | """ Testing our SPM-ified version of `run_MMP` with 99 | 2 hidden state factors & 2 outcome modalities, at the very first 100 | timestep of the generative process (boundary condition test). 
So there 101 | are no previous actions""" 102 | 103 | array_path = os.path.join(os.getcwd(), DATA_PATH + "mmp_c.mat") 104 | mat_contents = loadmat(file_name=array_path) 105 | 106 | A = mat_contents["A"][0] 107 | B = mat_contents["B"][0] 108 | prev_obs = mat_contents["obs_idx"].astype("int64") 109 | policy = mat_contents["policy"].astype("int64") - 1 110 | curr_t = mat_contents["t"][0, 0].astype("int64") - 1 111 | t_horizon = mat_contents["t_horizon"][0, 0].astype("int64") 112 | # prev_actions = mat_contents["previous_actions"].astype("int64") - 1 113 | result_spm = mat_contents["qs"][0] 114 | likelihoods = mat_contents["likelihoods"][0] 115 | 116 | num_obs, num_states, _, num_factors = get_model_dimensions(A, B) 117 | prev_obs = convert_observation_array( 118 | prev_obs[:, max(0, curr_t - t_horizon) : (curr_t + 1)], num_obs 119 | ) 120 | 121 | # prev_actions = prev_actions[(max(0, curr_t - t_horizon)) :, :] 122 | lh_seq = get_joint_likelihood_seq(A, prev_obs, num_states) 123 | qs_seq, _ = run_mmp( 124 | lh_seq, B, policy, prev_actions=None, prior=None, num_iter=5, grad_descent=True 125 | ) 126 | 127 | result_pymdp = qs_seq[-1] 128 | for f in range(num_factors): 129 | self.assertTrue(np.isclose(result_spm[f].squeeze(), result_pymdp[f]).all()) 130 | 131 | def test_mmp_d(self): 132 | """ Testing our SPM-ified version of `run_MMP` with 133 | 2 hidden state factors & 2 outcome modalities, at the final 134 | timestep of the generative process (boundary condition test) 135 | @NOTE: mmp_d.mat test has issues with the prediction errors. But the future messages are 136 | totally fine (even at the last timestep of variational iteration.""" 137 | 138 | array_path = os.path.join(os.getcwd(), DATA_PATH + "mmp_d.mat") 139 | mat_contents = loadmat(file_name=array_path) 140 | 141 | A = mat_contents["A"][0] 142 | B = mat_contents["B"][0] 143 | prev_obs = mat_contents["obs_idx"].astype("int64") 144 | policy = mat_contents["policy"].astype("int64") - 1 145 | curr_t = mat_contents["t"][0, 0].astype("int64") - 1 146 | t_horizon = mat_contents["t_horizon"][0, 0].astype("int64") 147 | prev_actions = mat_contents["previous_actions"].astype("int64") - 1 148 | result_spm = mat_contents["qs"][0] 149 | likelihoods = mat_contents["likelihoods"][0] 150 | 151 | num_obs, num_states, _, num_factors = get_model_dimensions(A, B) 152 | prev_obs = convert_observation_array( 153 | prev_obs[:, max(0, curr_t - t_horizon) : (curr_t + 1)], num_obs 154 | ) 155 | 156 | prev_actions = prev_actions[(max(0, curr_t - t_horizon) -1) :, :] 157 | prior = np.empty(num_factors, dtype=object) 158 | for f in range(num_factors): 159 | uniform = np.ones(num_states[f]) / num_states[f] 160 | prior[f] = B[f][:, :, prev_actions[0, f]].dot(uniform) 161 | 162 | lh_seq = get_joint_likelihood_seq(A, prev_obs, num_states) 163 | 164 | qs_seq, _ = run_mmp( 165 | lh_seq, B, policy, prev_actions[1:], prior=prior, num_iter=5, grad_descent=True, last_timestep=True 166 | ) 167 | 168 | result_pymdp = qs_seq[-1] 169 | 170 | for f in range(num_factors): 171 | self.assertTrue(np.isclose(result_spm[f].squeeze(), result_pymdp[f]).all()) 172 | 173 | """" 174 | @ NOTE (from Conor Heins 07.04.2021) 175 | Please keep this uncommented code below here. We need to figure out how to re-include optional arguments e.g. `save_vfe_seq` 176 | into `run_mmp` so that important tests like these can run again some day. My only dumb solution for now would be to just have a 'UnitTest variant' of the MMP function 177 | that has extra optional outputs that slow down run-time (e.g. 
`save_vfe_seq`), and are thus excluded from the deployable version of `pymdp`, 178 | but are useful for benchmarking the performance/accuracy of the algorithm 179 | """ 180 | # def test_mmp_fixedpoints(self): 181 | 182 | # array_path = os.path.join(os.getcwd(), DATA_PATH + "mmp_a.mat") 183 | # mat_contents = loadmat(file_name=array_path) 184 | 185 | # A = mat_contents["A"][0] 186 | # B = mat_contents["B"][0] 187 | # prev_obs = mat_contents["obs_idx"].astype("int64") 188 | # policy = mat_contents["policy"].astype("int64") - 1 189 | # curr_t = mat_contents["t"][0, 0].astype("int64") - 1 190 | # t_horizon = mat_contents["t_horizon"][0, 0].astype("int64") 191 | # prev_actions = mat_contents["previous_actions"].astype("int64") - 1 192 | # result_spm = mat_contents["qs"][0] 193 | # likelihoods = mat_contents["likelihoods"][0] 194 | 195 | # num_obs, num_states, _, num_factors = get_model_dimensions(A, B) 196 | # prev_obs = convert_observation_array( 197 | # prev_obs[:, max(0, curr_t - t_horizon) : (curr_t + 1)], num_obs 198 | # ) 199 | 200 | # prev_actions = prev_actions[(max(0, curr_t - t_horizon) -1) :, :] 201 | # prior = np.empty(num_factors, dtype=object) 202 | # for f in range(num_factors): 203 | # uniform = np.ones(num_states[f]) / num_states[f] 204 | # prior[f] = B[f][:, :, prev_actions[0, f]].dot(uniform) 205 | 206 | # lh_seq = get_joint_likelihood_seq(A, prev_obs, num_states) 207 | # qs_seq, F = run_mmp( 208 | # lh_seq, B, policy, prev_actions[1:], prior=prior, num_iter=5, grad_descent=False, save_vfe_seq=True 209 | # ) 210 | 211 | # self.assertTrue((np.diff(np.array(F)) < 0).all()) 212 | 213 | 214 | if __name__ == "__main__": 215 | unittest.main() 216 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Unit Tests 5 | 6 | __author__: Conor Heins, Alexander Tschantz, Daphne Demekas, Brennan Klein 7 | 8 | """ 9 | 10 | import unittest 11 | 12 | import numpy as np 13 | 14 | from pymdp import utils 15 | 16 | class TestUtils(unittest.TestCase): 17 | def test_obj_array_from_list(self): 18 | """ 19 | Tests `obj_array_from_list` 20 | """ 21 | # make arrays with the same leading dimension; constructing the object array naively would trigger a numpy broadcasting error. 22 | arrs = [np.zeros((3, 6)), np.zeros((3, 4, 5))] 23 | obs_arrs = utils.obj_array_from_list(arrs) 24 | 25 | self.assertTrue(all([np.all(a == b) for a, b in zip(arrs, obs_arrs)])) 26 | 27 | if __name__ == "__main__": 28 | unittest.main() -------------------------------------------------------------------------------- /test/test_wrappers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | from pymdp.utils import Dimensions, get_model_dimensions_from_labels 5 | 6 | class TestWrappers(unittest.TestCase): 7 | 8 | def test_get_model_dimensions_from_labels(self): 9 | """ 10 | Tests model dimension extraction from labels including observations, states and actions. 
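The expected `Dimensions` object is constructed by hand and compared field-by-field against the wrapper's output.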
11 | """ 12 | model_labels = { 13 | "observations": { 14 | "species_observation": [ 15 | "absent", 16 | "present", 17 | ], 18 | "budget_observation": [ 19 | "high", 20 | "medium", 21 | "low", 22 | ], 23 | }, 24 | "states": { 25 | "species_state": [ 26 | "extant", 27 | "extinct", 28 | ], 29 | }, 30 | "actions": { 31 | "conservation_action": [ 32 | "manage", 33 | "survey", 34 | "stop", 35 | ], 36 | }, 37 | } 38 | 39 | want = Dimensions( 40 | num_observations=[2, 3], 41 | num_observation_modalities=2, 42 | num_states=[2], 43 | num_state_factors=1, 44 | num_controls=[3], 45 | num_control_factors=1, 46 | ) 47 | 48 | got = get_model_dimensions_from_labels(model_labels) 49 | 50 | self.assertEqual(want.num_observations, got.num_observations) 51 | self.assertEqual(want.num_observation_modalities, got.num_observation_modalities) 52 | self.assertEqual(want.num_states, got.num_states) 53 | self.assertEqual(want.num_state_factors, got.num_state_factors) 54 | self.assertEqual(want.num_controls, got.num_controls) 55 | self.assertEqual(want.num_control_factors, got.num_control_factors) 56 | 57 | if __name__ == "__main__": 58 | unittest.main() 59 | --------------------------------------------------------------------------------