├── docs ├── FAQs.md ├── examples ├── .htaccess ├── _static │ ├── favicon.png │ ├── mathjax.js │ ├── custom_css.css │ ├── logo-light.svg │ └── logo-dark.svg ├── api │ ├── linear_net_theoretical_energy.png │ ├── Continuous-time Inference.md │ ├── Training.md │ ├── Utils.md │ ├── Testing.md │ ├── Discrete updates.md │ ├── Initialisation.md │ ├── Theoretical tools.md │ ├── Energy functions.md │ └── Gradients.md ├── _overrides │ └── partials │ │ ├── logo.html │ │ └── source.html ├── requirements.txt ├── advanced_usage.md ├── basic_usage.md └── index.md ├── experiments ├── library_paper │ ├── __init__.py │ ├── utils.py │ ├── test_theory_energies.py │ └── train_mlp.py ├── mupc_paper │ ├── __init__.py │ ├── spotlight_fig.png │ ├── requirements.txt │ ├── README.md │ ├── test_energy_theory.py │ ├── analyse_activity_hessian.py │ ├── test_mlp_fwd_pass.py │ └── train_bpn.py └── datasets.py ├── tests ├── __init__.py ├── test_errors.py ├── README.md ├── conftest.py ├── test_init.py ├── test_infer.py ├── test_test_functions.py ├── test_analytical.py ├── test_utils.py └── test_train.py ├── .gitattributes ├── .gitignore ├── jpc ├── _core │ ├── _errors.py │ ├── __init__.py │ ├── _infer.py │ └── _init.py ├── __init__.py ├── _utils.py └── _test.py ├── .github ├── workflows │ ├── tests.yml │ └── build_docs.yml └── logo-with-background.svg ├── LICENSE ├── pyproject.toml ├── mkdocs.yml └── README.md /docs/FAQs.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/examples: -------------------------------------------------------------------------------- 1 | ../examples -------------------------------------------------------------------------------- /experiments/library_paper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/mupc_paper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.htaccess: -------------------------------------------------------------------------------- 1 | ErrorDocument 404 /jpc/404.html -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Test suite for the jpc library.""" 2 | 3 | -------------------------------------------------------------------------------- /docs/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thebuckleylab/jpc/HEAD/docs/_static/favicon.png -------------------------------------------------------------------------------- /experiments/mupc_paper/spotlight_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thebuckleylab/jpc/HEAD/experiments/mupc_paper/spotlight_fig.png -------------------------------------------------------------------------------- /docs/api/linear_net_theoretical_energy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thebuckleylab/jpc/HEAD/docs/api/linear_net_theoretical_energy.png -------------------------------------------------------------------------------- /docs/_overrides/partials/logo.html: 
--------------------------------------------------------------------------------
1 | <img id="logo_light_mode" src="{{ config.theme.logo_light_mode | url }}" alt="logo">
2 | <img id="logo_dark_mode" src="{{ config.theme.logo_dark_mode | url }}" alt="logo">
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages
2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
3 | *.ipynb linguist-generated
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/data/
3 | **/datasets/
4 | examples/data/
5 | *.egg-info/
6 | *.pdf
7 | *.npy
8 | *.DS_Store
9 | *.pkl
10 | *.iml
11 | *.xml
12 | __pycache__/
13 | .vscode/
14 | .all_objects.cache
15 | site/
16 | venv*/
17 | *.log
18 | .pytest_cache/
19 | .coverage
20 | coverage.json
21 | htmlcov/
--------------------------------------------------------------------------------
/docs/api/Continuous-time Inference.md:
--------------------------------------------------------------------------------
1 | # Continuous-time inference
2 | 
3 | The inference or activity dynamics of PC networks can be solved in either
4 | discrete or continuous time. [jpc.solve_inference()](https://thebuckleylab.github.io/jpc/api/Continuous-time%20Inference/#jpc.solve_inference)
5 | leverages ODE solvers to integrate the continuous-time dynamics.
6 | 
7 | ::: jpc.solve_inference
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | jax==0.5.2  # to avoid conflict for `jax.core.ClosedJaxpr`
2 | hippogriffe==0.2.2
3 | mkdocs==1.6.1
4 | mkdocs-include-exclude-files==0.1.0
5 | mkdocs-ipynb==0.1.1
6 | mkdocs-material==9.6.7
7 | mkdocstrings==0.28.3
8 | mkdocstrings-python==1.16.8
9 | pymdown-extensions==10.14.3
10 | 
11 | # Dependencies of JPC itself.
12 | # Always use most up-to-date versions.
13 | jax[cpu]
--------------------------------------------------------------------------------
/docs/_static/mathjax.js:
--------------------------------------------------------------------------------
1 | window.MathJax = {
2 |   tex: {
3 |     inlineMath: [["\\(", "\\)"]],
4 |     displayMath: [["\\[", "\\]"]],
5 |     processEscapes: true,
6 |     processEnvironments: true
7 |   },
8 |   options: {
9 |     ignoreHtmlClass: ".*|",
10 |     processHtmlClass: "arithmatex"
11 |   }
12 | };
13 | 
14 | document$.subscribe(() => {
15 |   MathJax.startup.output.clearCache()
16 |   MathJax.typesetClear()
17 |   MathJax.texReset()
18 |   MathJax.typesetPromise()
19 | })
--------------------------------------------------------------------------------
/jpc/_core/_errors.py:
--------------------------------------------------------------------------------
1 | _PARAM_TYPES = ["sp", "mupc", "ntp"]
2 | 
3 | 
4 | def _check_param_type(param_type):
5 |     if param_type not in _PARAM_TYPES:
6 |         raise ValueError(
7 |             'Invalid parameterisation. Options are `"sp"` (standard '
8 |             'parameterisation), `"mupc"` (μPC), or `"ntp"` (neural tangent '
9 |             'parameterisation). See `_get_param_scalings()` (https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings) '
10 |             'for the specific scalings of these different parameterisations.'
11 |         )
12 | 
--------------------------------------------------------------------------------
/docs/api/Training.md:
--------------------------------------------------------------------------------
1 | # Training
2 | 
3 | JPC provides two convenience functions to update the parameters of any
4 | PC-compatible model with PC:
5 | 
6 | * [jpc.make_pc_step()](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_pc_step) to
7 | perform an update using standard PC, and
8 | * [jpc.make_hpc_step()](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_hpc_step)
9 | to use hybrid PC ([Tschantz et al., 2023](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011280)).
10 | 
11 | ::: jpc.make_pc_step
12 | 
13 | ---
14 | 
15 | ::: jpc.make_hpc_step
--------------------------------------------------------------------------------
/docs/api/Utils.md:
--------------------------------------------------------------------------------
1 | # Utils
2 | 
3 | JPC provides several standard utilities for neural network training, including
4 | creation of simple models, losses, and metrics.
5 | 
6 | ::: jpc.make_mlp
7 | 
8 | ---
9 | 
10 | ::: jpc.make_skip_model
11 | 
12 | ---
13 | 
14 | ::: jpc.get_act_fn
15 | 
16 | ---
17 | 
18 | ::: jpc.mse_loss
19 | 
20 | ---
21 | 
22 | ::: jpc.cross_entropy_loss
23 | 
24 | ---
25 | 
26 | ::: jpc.compute_accuracy
27 | 
28 | ---
29 | 
30 | ::: jpc.get_t_max
31 | 
32 | ---
33 | 
34 | ::: jpc.compute_infer_energies
35 | 
36 | ---
37 | 
38 | ::: jpc.compute_activity_norms
39 | 
40 | ---
41 | 
42 | ::: jpc.compute_param_norms
--------------------------------------------------------------------------------
/docs/api/Testing.md:
--------------------------------------------------------------------------------
1 | # Testing
2 | 
3 | JPC provides a few convenience functions to test different types of PC networks (PCNs):
4 | 
5 | * [jpc.test_discriminative_pc()](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_discriminative_pc)
6 | for test loss and accuracy of discriminative PCNs;
7 | * [jpc.test_generative_pc()](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_generative_pc)
8 | for accuracy and output predictions of generative PCNs; and
9 | * [jpc.test_hpc()](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_hpc) for accuracy
10 | of all models (amortiser, generator, & hybrid) as well as output predictions.
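As a minimal sketch, one of these can be called on a toy batch as follows — the keyword names (`model`, `output`, `input`) and the returned values are assumptions based on `jpc.make_pc_step()` and the description above, so defer to the exact signatures rendered below:

```py
import jax
import jax.random as jr
import jpc

key = jr.PRNGKey(0)
subkeys = jr.split(key, 3)

# toy discriminative PCN: 784-dimensional inputs, 10 classes
model = jpc.make_mlp(
    subkeys[0],
    input_dim=784,
    width=300,
    depth=3,
    output_dim=10,
    act_fn="relu"
)

# dummy batch of inputs and one-hot labels
x = jr.normal(subkeys[1], (32, 784))
labels = jr.randint(subkeys[2], (32,), 0, 10)
y = jax.nn.one_hot(labels, 10)

# hypothetical call — assumed to return the test loss and accuracy
# described above; the argument names are also assumptions
loss, acc = jpc.test_discriminative_pc(model=model, output=y, input=x)
```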
11 | 12 | ::: jpc.test_discriminative_pc 13 | 14 | --- 15 | 16 | ::: jpc.test_generative_pc 17 | 18 | --- 19 | 20 | ::: jpc.test_hpc -------------------------------------------------------------------------------- /tests/test_errors.py: -------------------------------------------------------------------------------- 1 | """Tests for error checking functions.""" 2 | 3 | import pytest 4 | from jpc._core._errors import _check_param_type 5 | 6 | 7 | def test_check_param_type_valid(): 8 | """Test that valid parameter types pass.""" 9 | _check_param_type("sp") 10 | _check_param_type("mupc") 11 | _check_param_type("ntp") 12 | 13 | 14 | def test_check_param_type_invalid(): 15 | """Test that invalid parameter types raise ValueError.""" 16 | with pytest.raises(ValueError, match="Invalid parameterisation"): 17 | _check_param_type("invalid") 18 | 19 | with pytest.raises(ValueError, match="Invalid parameterisation"): 20 | _check_param_type("") 21 | 22 | with pytest.raises(ValueError, match="Invalid parameterisation"): 23 | _check_param_type("ntk") 24 | -------------------------------------------------------------------------------- /docs/_overrides/partials/source.html: -------------------------------------------------------------------------------- 1 | {% import "partials/language.html" as lang with context %} 2 | 3 |
4 | <a href="{{ config.repo_url }}" title="{{ lang.t('source.link.title') }}" class="md-source" data-md-component="source">
5 |     {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %}
6 |     {% include ".icons/" ~ icon ~ ".svg" %}
7 |     <div class="md-source__repository">
8 |         {{ config.repo_name }}
9 |     </div>
10 | </a>
11 | 
12 | {% if config.theme.twitter_url %}
13 | <a href="{{ config.theme.twitter_url }}" title="Go to Twitter" class="md-source">
14 |     {% include ".icons/fontawesome/brands/twitter.svg" %}
15 |     <div class="md-source__repository">
16 |         {{ config.theme.twitter_name }}
17 |     </div>
18 | </a>
19 | 
20 | {% endif %} -------------------------------------------------------------------------------- /docs/api/Discrete updates.md: -------------------------------------------------------------------------------- 1 | # Discrete updates 2 | 3 | JPC provides access to standard discrete optimisers to update the parameters of 4 | PC networks ([jpc.update_pc_params](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_pc_params)), 5 | and to both discrete ([jpc.update_pc_activities](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_pc_activities)) 6 | and continuous optimisers ([jpc.solve_inference](https://thebuckleylab.github.io/jpc/api/Continuous-time%20Inference/#jpc.solve_inference)) 7 | to solve the PC inference or activity dynamics. 8 | 9 | ::: jpc.update_pc_activities 10 | 11 | --- 12 | 13 | ::: jpc.update_pc_params 14 | 15 | --- 16 | 17 | ::: jpc.update_bpc_activities 18 | 19 | --- 20 | 21 | ::: jpc.update_bpc_params 22 | 23 | --- 24 | 25 | ::: jpc.update_epc_errors 26 | 27 | --- 28 | 29 | ::: jpc.update_epc_params -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | 12 | jobs: 13 | test: 14 | strategy: 15 | matrix: 16 | python-version: ["3.10", "3.11"] 17 | os: [ubuntu-latest] 18 | runs-on: ${{ matrix.os }} 19 | steps: 20 | - name: Checkout code 21 | uses: actions/checkout@v4 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | python -m pip install -e . 32 | 33 | - name: Run tests with coverage 34 | run: | 35 | pytest tests/ --cov=jpc --cov-report=term -v -------------------------------------------------------------------------------- /.github/workflows/build_docs.yml: -------------------------------------------------------------------------------- 1 | name: Build docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | build: 13 | strategy: 14 | matrix: 15 | python-version: [ 3.11 ] 16 | os: [ ubuntu-latest ] 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v2 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install -e . 
31 | python -m pip install -r docs/requirements.txt 32 | 33 | - name: Build docs 34 | run: | 35 | mkdocs build 36 | 37 | - name: Deploy docs 38 | run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /docs/api/Initialisation.md: -------------------------------------------------------------------------------- 1 | # Initialisation 2 | 3 | JPC provides 4 ways of initialising the activities of a PC network: 4 | 5 | * [jpc.init_activities_with_ffwd()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_ffwd) for a feedforward pass (standard), 6 | * [jpc.init_activities_from_normal()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_from_normal) for random initialisation, 7 | * [jpc.init_activities_with_amort()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_amort) for use of an amortised network, and 8 | * [jpc.init_epc_errors()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_epc_errors) for zero-initialisation of prediction errors in [ePC](https://arxiv.org/abs/2505.20137). 9 | 10 | ::: jpc.init_activities_with_ffwd 11 | 12 | --- 13 | 14 | ::: jpc.init_activities_from_normal 15 | 16 | --- 17 | 18 | ::: jpc.init_activities_with_amort 19 | 20 | --- 21 | 22 | ::: jpc.init_epc_errors -------------------------------------------------------------------------------- /docs/api/Theoretical tools.md: -------------------------------------------------------------------------------- 1 | # Theoretical tools 2 | 3 | JPC provides the following theoretical tools that can be used to study 4 | **deep linear networks** (DLNs) trained with PC: 5 | 6 | * [jpc.compute_linear_equilib_energy()](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/#jpc.compute_linear_equilib_energy) 7 | to compute the theoretical PC energy at the solution of the activities for DLNs; 8 | * [jpc.compute_linear_activity_hessian()](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/#jpc.compute_linear_activity_hessian) 9 | to compute the theoretical Hessian of the energy with respect to the activities of DLNs; 10 | * [jpc.compute_linear_activity_solution()](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/#jpc.compute_linear_activity_solution) 11 | to compute the analytical PC inference solution for DLNs. 
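As a minimal sketch, one might compare a deep linear network against its theoretical equilibrated energy as follows — the keyword names (`network`, `x`, `y`) and the `act_fn="linear"` option are assumptions, so defer to the exact signatures rendered below:

```py
import jax.random as jr
import jpc

key = jr.PRNGKey(0)
subkeys = jr.split(key, 3)

# a deep *linear* network (DLN), which these tools assume
dln = jpc.make_mlp(
    subkeys[0],
    input_dim=10,
    width=32,
    depth=4,
    output_dim=5,
    act_fn="linear"  # assumed identifier for a linear/identity activation
)

# dummy regression batch
x = jr.normal(subkeys[1], (64, 10))
y = jr.normal(subkeys[2], (64, 5))

# hypothetical call — argument names are assumptions, see the signature below
theory_energy = jpc.compute_linear_equilib_energy(network=dln, x=x, y=y)
```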
12 | 
13 | ::: jpc.compute_linear_equilib_energy
14 | 
15 | ---
16 | 
17 | ::: jpc.compute_linear_activity_hessian
18 | 
19 | ---
20 | 
21 | ::: jpc.compute_linear_activity_solution
--------------------------------------------------------------------------------
/docs/api/Energy functions.md:
--------------------------------------------------------------------------------
1 | # Energy functions
2 | 
3 | JPC provides four main PC energy functions:
4 | 
5 | * [jpc.pc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.pc_energy_fn)
6 | for standard PC networks,
7 | * [jpc.hpc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.hpc_energy_fn)
8 | for hybrid PC models ([Tschantz et al., 2023](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011280)),
9 | * [jpc.bpc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.bpc_energy_fn)
10 | for bidirectional PC models ([Oliviers et al., 2025](https://arxiv.org/abs/2505.23415)), and
11 | * [jpc.epc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.epc_energy_fn)
12 | for error-reparameterised PC ([Goemaere et al., 2025](https://arxiv.org/abs/2505.20137)).
13 | 
14 | ::: jpc.pc_energy_fn
15 | 
16 | ---
17 | 
18 | ::: jpc.hpc_energy_fn
19 | 
20 | ---
21 | 
22 | ::: jpc.bpc_energy_fn
23 | 
24 | ---
25 | 
26 | ::: jpc.epc_energy_fn
27 | 
28 | ---
29 | 
30 | ::: jpc._get_param_scalings
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 thebuckleylab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/docs/api/Gradients.md:
--------------------------------------------------------------------------------
1 | # Gradients
2 | 
3 | !!! info
4 |     There are two similar functions to compute the gradient of a standard PC
5 |     energy with respect to the activities: [`jpc.neg_pc_activity_grad()`](https://thebuckleylab.github.io/jpc/api/Gradients/#jpc.neg_pc_activity_grad)
6 |     and [`jpc.compute_pc_activity_grad()`](https://thebuckleylab.github.io/jpc/api/Gradients/#jpc.compute_pc_activity_grad).
7 | The first is used by [`jpc.solve_inference()`](https://thebuckleylab.github.io/jpc/api/Continuous-time%20Inference/#jpc.solve_inference) 8 | as gradient flow, while the second is for compatibility with discrete 9 | [optax](https://github.com/google-deepmind/optax) optimisers such as 10 | gradient descent. 11 | 12 | ::: jpc.neg_pc_activity_grad 13 | 14 | --- 15 | 16 | ::: jpc.compute_pc_activity_grad 17 | 18 | --- 19 | 20 | ::: jpc.compute_bpc_activity_grad 21 | 22 | --- 23 | 24 | ::: jpc.compute_epc_error_grad 25 | 26 | --- 27 | 28 | ::: jpc.compute_pc_param_grads 29 | 30 | --- 31 | 32 | ::: jpc.compute_hpc_param_grads 33 | 34 | --- 35 | 36 | ::: jpc.compute_bpc_param_grads 37 | 38 | --- 39 | 40 | ::: jpc.compute_epc_param_grads -------------------------------------------------------------------------------- /jpc/_core/__init__.py: -------------------------------------------------------------------------------- 1 | from ._init import ( 2 | init_activities_with_ffwd as init_activities_with_ffwd, 3 | init_activities_from_normal as init_activities_from_normal, 4 | init_activities_with_amort as init_activities_with_amort, 5 | init_epc_errors as init_epc_errors 6 | ) 7 | from ._energies import ( 8 | pc_energy_fn as pc_energy_fn, 9 | hpc_energy_fn as hpc_energy_fn, 10 | bpc_energy_fn as bpc_energy_fn, 11 | epc_energy_fn as epc_energy_fn, 12 | pdm_energy_fn as pdm_energy_fn, 13 | _get_param_scalings as _get_param_scalings 14 | ) 15 | from ._grads import ( 16 | neg_pc_activity_grad as neg_pc_activity_grad, 17 | compute_pc_activity_grad as compute_pc_activity_grad, 18 | compute_pc_param_grads as compute_pc_param_grads, 19 | compute_hpc_param_grads as compute_hpc_param_grads, 20 | compute_bpc_activity_grad as compute_bpc_activity_grad, 21 | compute_bpc_param_grads as compute_bpc_param_grads, 22 | compute_epc_error_grad as compute_epc_error_grad, 23 | compute_epc_param_grads as compute_epc_param_grads, 24 | compute_pdm_activity_grad as compute_pdm_activity_grad, 25 | compute_pdm_param_grads as compute_pdm_param_grads 26 | ) 27 | from ._infer import solve_inference as solve_inference 28 | from ._updates import ( 29 | update_pc_activities as update_pc_activities, 30 | update_pc_params as update_pc_params, 31 | update_bpc_activities as update_bpc_activities, 32 | update_bpc_params as update_bpc_params, 33 | update_epc_errors as update_epc_errors, 34 | update_epc_params as update_epc_params, 35 | update_pdm_activities as update_pdm_activities, 36 | update_pdm_params as update_pdm_params 37 | ) 38 | from ._analytical import ( 39 | compute_linear_equilib_energy as compute_linear_equilib_energy, 40 | compute_linear_activity_hessian as compute_linear_activity_hessian, 41 | compute_linear_activity_solution as compute_linear_activity_solution 42 | ) 43 | from ._errors import _check_param_type as _check_param_type 44 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # JPC Test Suite 2 | 3 | This directory contains comprehensive tests for the jpc library. 
4 | 
5 | ## Running Tests
6 | 
7 | To run all tests:
8 | ```bash
9 | pytest tests/
10 | ```
11 | 
12 | To run a specific test file:
13 | ```bash
14 | pytest tests/test_train.py
15 | ```
16 | 
17 | To run with verbose output:
18 | ```bash
19 | pytest tests/ -v
20 | ```
21 | 
22 | To run with coverage:
23 | ```bash
24 | pytest tests/ --cov=jpc --cov-report=term
25 | ```
26 | 
27 | To generate an HTML coverage report:
28 | ```bash
29 | pytest tests/ --cov=jpc --cov-report=html
30 | ```
31 | Then open `htmlcov/index.html` in your browser.
32 | 
33 | To also see which lines are missing coverage:
34 | ```bash
35 | pytest tests/ --cov=jpc --cov-report=term-missing
36 | ```
37 | 
38 | ## Test Structure
39 | 
40 | - `conftest.py`: Shared fixtures and pytest configuration
41 | - `test_errors.py`: Tests for error checking functions
42 | - `test_init.py`: Tests for initialization functions
43 | - `test_infer.py`: Tests for inference solving functions
44 | - `test_analytical.py`: Tests for analytical tools
45 | - `test_utils.py`: Tests for utility functions
46 | - `test_train.py`: Tests for training functions
47 | - `test_test_functions.py`: Tests for test utility functions
48 | 
49 | ## Automated Testing
50 | 
51 | Tests run automatically on every push and pull request via GitHub Actions (`.github/workflows/tests.yml`). The workflow:
52 | - Runs tests on Python 3.10 and 3.11
53 | - Generates a coverage report
54 | 
55 | ## Requirements
56 | 
57 | The tests require:
58 | - pytest
59 | - pytest-cov (for coverage reporting)
60 | - jax
61 | - equinox
62 | - diffrax
63 | - optax
64 | 
65 | All dependencies should be installed when installing jpc. Coverage configuration is defined in `pyproject.toml` under `[tool.coverage.*]`.
66 | 
--------------------------------------------------------------------------------
/docs/advanced_usage.md:
--------------------------------------------------------------------------------
1 | # Advanced usage
2 | 
3 | Advanced users can access all the underlying functions of `jpc.make_pc_step()`
4 | as well as additional features. A custom PC training step looks like the
5 | following:
6 | ```py
7 | import jpc
8 | 
9 | # 1. initialise activities with a feedforward pass
10 | activities = jpc.init_activities_with_ffwd(model=model, input=x)
11 | 
12 | # 2. run inference to equilibrium
13 | equilibrated_activities = jpc.solve_inference(
14 |     params=(model, None),
15 |     activities=activities,
16 |     output=y,
17 |     input=x
18 | )
19 | 
20 | # 3. update parameters at the activities' solution with PC
21 | param_update_result = jpc.update_pc_params(
22 |     params=(model, None),
23 |     activities=equilibrated_activities,
24 |     optim=param_optim,
25 |     opt_state=param_opt_state,
26 |     output=y,
27 |     input=x
28 | )
29 | 
30 | # updated model and optimiser
31 | model = param_update_result["model"]
32 | param_opt_state = param_update_result["opt_state"]
33 | ```
34 | which can be embedded in a jitted function along with any additional
35 | computations. One can also use any [optax
36 | ](https://optax.readthedocs.io/en/latest/api/optimizers.html) optimiser to
37 | equilibrate the inference dynamics by replacing the function in step 2, as
38 | shown below.
39 | ```py
40 | activity_optim = optax.adam(1e-3)
41 | 
42 | # 1. initialise activities
43 | ...
44 | 
45 | # 2. infer with adam
46 | activity_opt_state = activity_optim.init(activities)
47 | 
48 | for t in range(T):
49 |     activity_update_result = jpc.update_pc_activities(
50 |         params=(model, None),
51 |         activities=activities,
52 |         optim=activity_optim,
53 |         opt_state=activity_opt_state,
54 |         output=y,
55 |         input=x
56 |     )
57 |     # updated activities and optimiser
58 |     activities = activity_update_result["activities"]
59 |     activity_opt_state = activity_update_result["opt_state"]
60 | 
61 | # 3. update parameters at the activities' solution with PC
62 | ...
63 | ```
64 | See the [discrete updates docs
65 | ](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/) for more details. JPC also
66 | comes with some theoretical tools that can be used to study and potentially
67 | diagnose issues with PCNs
68 | (see [docs
69 | ](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/)
70 | and [example notebook
71 | ](https://thebuckleylab.github.io/jpc/examples/linear_net_theoretical_energy/)).
72 | 
--------------------------------------------------------------------------------
/experiments/library_paper/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from torch import manual_seed
5 | from diffrax import Euler, Heun, Midpoint, Ralston, Bosh3, Tsit5, Dopri5, Dopri8
6 | 
7 | 
8 | def set_seed(seed):
9 |     np.random.seed(seed)
10 |     random.seed(seed)
11 |     manual_seed(seed)
12 | 
13 | 
14 | def get_ode_solver(name):
15 |     if name == "Euler":
16 |         return Euler()
17 |     elif name == "Heun":
18 |         return Heun()
19 |     elif name == "Midpoint":
20 |         return Midpoint()
21 |     elif name == "Ralston":
22 |         return Ralston()
23 |     elif name == "Bosh3":
24 |         return Bosh3()
25 |     elif name == "Tsit5":
26 |         return Tsit5()
27 |     elif name == "Dopri5":
28 |         return Dopri5()
29 |     elif name == "Dopri8":
30 |         return Dopri8()
31 | 
32 | 
33 | def setup_mlp_experiment(
34 |         results_dir,
35 |         dataset,
36 |         width,
37 |         n_hidden,
38 |         act_fn,
39 |         max_t1,
40 |         activity_lr,
41 |         param_lr,
42 |         activity_optim_id,
43 |         seed
44 | ):
45 |     print(
46 |         f"""
47 | Starting experiment with configuration:
48 | 
49 |   Dataset: {dataset}
50 |   Width: {width}
51 |   N hidden: {n_hidden}
52 |   Act fn: {act_fn}
53 |   Max t1: {max_t1}
54 |   Activity step size: {activity_lr}
55 |   Param learning rate: {param_lr}
56 |   Activity optim: {activity_optim_id}
57 |   Seed: {seed}
58 | """
59 |     )
60 |     return os.path.join(
61 |         results_dir,
62 |         dataset,
63 |         f"width_{width}",
64 |         f"{n_hidden}_n_hidden",
65 |         act_fn,
66 |         f"max_t1_{max_t1}",
67 |         f"activity_lr_{activity_lr}",
68 |         f"param_lr_{param_lr}",
69 |         activity_optim_id,
70 |         str(seed)
71 |     )
72 | 
73 | 
74 | def get_min_iter(lists):
75 |     min_iter = 100000
76 |     for i in lists:
77 |         if len(i) < min_iter:
78 |             min_iter = len(i)
79 |     return min_iter
80 | 
81 | 
82 | def get_min_iter_metrics(metrics):
83 |     n_seeds = len(metrics)
84 |     min_iter = get_min_iter(lists=metrics)
85 | 
86 |     min_iter_metrics = np.zeros((n_seeds, min_iter))
87 |     for seed in range(n_seeds):
88 |         min_iter_metrics[seed, :] = metrics[seed][:min_iter]
89 | 
90 |     return min_iter_metrics
91 | 
92 | 
93 | def compute_metric_stats(metric):
94 |     min_iter_metrics = get_min_iter_metrics(metrics=metric)
95 |     metric_means = min_iter_metrics.mean(axis=0)
96 |     metric_stds = min_iter_metrics.std(axis=0)
97 |     return metric_means, metric_stds
98 | 
--------------------------------------------------------------------------------
/jpc/__init__.py:
-------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | from ._core import ( 4 | init_activities_with_ffwd as init_activities_with_ffwd, 5 | init_activities_from_normal as init_activities_from_normal, 6 | init_activities_with_amort as init_activities_with_amort, 7 | init_epc_errors as init_epc_errors, 8 | pc_energy_fn as pc_energy_fn, 9 | hpc_energy_fn as hpc_energy_fn, 10 | bpc_energy_fn as bpc_energy_fn, 11 | epc_energy_fn as epc_energy_fn, 12 | pdm_energy_fn as pdm_energy_fn, 13 | _get_param_scalings as _get_param_scalings, 14 | neg_pc_activity_grad as neg_pc_activity_grad, 15 | solve_inference as solve_inference, 16 | compute_pc_activity_grad as compute_pc_activity_grad, 17 | compute_pc_param_grads as compute_pc_param_grads, 18 | compute_hpc_param_grads as compute_hpc_param_grads, 19 | compute_bpc_activity_grad as compute_bpc_activity_grad, 20 | compute_bpc_param_grads as compute_bpc_param_grads, 21 | compute_epc_error_grad as compute_epc_error_grad, 22 | compute_epc_param_grads as compute_epc_param_grads, 23 | compute_pdm_activity_grad as compute_pdm_activity_grad, 24 | compute_pdm_param_grads as compute_pdm_param_grads, 25 | update_pc_activities as update_pc_activities, 26 | update_pc_params as update_pc_params, 27 | update_bpc_activities as update_bpc_activities, 28 | update_bpc_params as update_bpc_params, 29 | update_epc_errors as update_epc_errors, 30 | update_epc_params as update_epc_params, 31 | update_pdm_activities as update_pdm_activities, 32 | update_pdm_params as update_pdm_params, 33 | compute_linear_equilib_energy as compute_linear_equilib_energy, 34 | compute_linear_activity_hessian as compute_linear_activity_hessian, 35 | compute_linear_activity_solution as compute_linear_activity_solution, 36 | _check_param_type as _check_param_type 37 | ) 38 | from ._utils import ( 39 | make_mlp as make_mlp, 40 | make_skip_model as make_skip_model, 41 | get_act_fn as get_act_fn, 42 | mse_loss as mse_loss, 43 | cross_entropy_loss as cross_entropy_loss, 44 | compute_accuracy as compute_accuracy, 45 | get_t_max as get_t_max, 46 | compute_activity_norms as compute_activity_norms, 47 | compute_infer_energies as compute_infer_energies, 48 | compute_param_norms as compute_param_norms 49 | ) 50 | from ._train import ( 51 | make_pc_step as make_pc_step, 52 | make_hpc_step as make_hpc_step 53 | ) 54 | from ._test import ( 55 | test_discriminative_pc as test_discriminative_pc, 56 | test_generative_pc as test_generative_pc, 57 | test_hpc as test_hpc 58 | ) 59 | 60 | 61 | __version__ = importlib.metadata.version("jpc") 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "jpc" 3 | version = "1.0.0" 4 | description = "Flexible Inference for Predictive Coding Networks in JAX." 
5 | readme = "README.md" 6 | requires-python =">=3.10" 7 | license = {file = "LICENSE"} 8 | authors = [ 9 | {name = "Francesco Innocenti", email = "F.Innocenti@sussex.ac.uk"}, 10 | ] 11 | keywords = [ 12 | "jax", 13 | "predictive-coding", 14 | "neural-networks", 15 | "hybrid-predictive-coding", 16 | "deep-learning", 17 | "local-learning", 18 | "inference-learning", 19 | "mupc" 20 | ] 21 | classifiers = [ 22 | "Development Status :: 3 - Alpha", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Financial and Insurance Industry", 25 | "Intended Audience :: Information Technology", 26 | "Intended Audience :: Science/Research", 27 | "License :: OSI Approved :: MIT License", 28 | "Natural Language :: English", 29 | "Programming Language :: Python :: 3", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | "Topic :: Scientific/Engineering :: Information Analysis", 32 | "Topic :: Scientific/Engineering :: Mathematics", 33 | ] 34 | urls = {repository = "https://github.com/thebuckleylab/jpc"} 35 | dependencies = [ 36 | "jax>=0.4.38,<=0.5.2", # to prevent jaxlib import error 37 | "equinox>=0.11.2", 38 | "diffrax>=0.6.0", 39 | "optax>=0.2.4", 40 | "jaxtyping>=0.2.24", 41 | "pytest-cov>=4.0.0" 42 | ] 43 | 44 | [build-system] 45 | build-backend = "hatchling.build" 46 | requires = ["hatchling"] 47 | 48 | [tool.hatch.build] 49 | include = ["jpc/*"] 50 | 51 | [tool.ruff] 52 | extend-include = ["*.ipynb"] 53 | src = [] 54 | 55 | [tool.ruff.lint] 56 | fixable = ["I001", "F401"] 57 | ignore = ["E402", "E721", "E731", "E741", "F722"] 58 | ignore-init-module-imports = true 59 | select = ["E", "F", "I001"] 60 | 61 | [tool.ruff.lint.isort] 62 | combine-as-imports = true 63 | extra-standard-library = ["typing_extensions"] 64 | lines-after-imports = 2 65 | order-by-type = false 66 | 67 | [tool.pytest.ini_options] 68 | testpaths = ["tests"] 69 | python_files = ["test_*.py"] 70 | python_classes = ["Test*"] 71 | python_functions = ["test_*"] 72 | norecursedirs = [".git", "venv", "site", "jpc.egg-info", "examples", "experiments", "jpc"] 73 | 74 | [tool.coverage.run] 75 | source = ["jpc"] 76 | omit = [ 77 | "*/tests/*", 78 | "*/test_*.py", 79 | "*/__pycache__/*", 80 | "*/site/*", 81 | "*/examples/*", 82 | "*/experiments/*", 83 | ] 84 | 85 | [tool.coverage.report] 86 | exclude_lines = [ 87 | "pragma: no cover", 88 | "def __repr__", 89 | "raise AssertionError", 90 | "raise NotImplementedError", 91 | "if __name__ == .__main__.:", 92 | "if TYPE_CHECKING:", 93 | "@abstractmethod", 94 | ] -------------------------------------------------------------------------------- /experiments/mupc_paper/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | anyio==4.8.0 3 | appnope==0.1.4 4 | argon2-cffi==23.1.0 5 | argon2-cffi-bindings==21.2.0 6 | arrow==1.3.0 7 | asttokens==3.0.0 8 | async-lru==2.0.4 9 | attrs==25.1.0 10 | babel==2.17.0 11 | beautifulsoup4==4.13.3 12 | bleach==6.2.0 13 | certifi==2025.1.31 14 | cffi==1.17.1 15 | charset-normalizer==3.4.1 16 | chex==0.1.89 17 | comm==0.2.2 18 | contourpy==1.3.1 19 | cycler==0.12.1 20 | debugpy==1.8.13 21 | decorator==5.2.1 22 | defusedxml==0.7.1 23 | diffrax==0.6.2 24 | equinox==0.11.12 25 | etils==1.12.0 26 | exceptiongroup==1.2.2 27 | executing==2.2.0 28 | fastjsonschema==2.21.1 29 | filelock==3.17.0 30 | fonttools==4.56.0 31 | fqdn==1.5.1 32 | fsspec==2025.2.0 33 | h11==0.14.0 34 | httpcore==1.0.7 35 | httpx==0.28.1 36 | idna==3.10 37 | ipykernel==6.29.5 38 | ipython==8.33.0 39 | 
ipywidgets==8.1.5 40 | isoduration==20.11.0 41 | jax==0.5.2 42 | jaxlib==0.5.1 43 | jaxtyping==0.2.38 44 | jedi==0.19.2 45 | Jinja2==3.1.5 46 | json5==0.10.0 47 | jsonpointer==3.0.0 48 | jsonschema==4.23.0 49 | jsonschema-specifications==2024.10.1 50 | jupyter==1.1.1 51 | jupyter-console==6.6.3 52 | jupyter-events==0.12.0 53 | jupyter-lsp==2.2.5 54 | jupyter_client==8.6.3 55 | jupyter_core==5.7.2 56 | jupyter_server==2.15.0 57 | jupyter_server_terminals==0.5.3 58 | jupyterlab==4.3.5 59 | jupyterlab_pygments==0.3.0 60 | jupyterlab_server==2.27.3 61 | jupyterlab_widgets==3.0.13 62 | kaleido==0.2.1 63 | kiwisolver==1.4.8 64 | lineax==0.0.7 65 | MarkupSafe==3.0.2 66 | matplotlib==3.10.1 67 | matplotlib-inline==0.1.7 68 | mistune==3.1.2 69 | ml_dtypes==0.5.1 70 | mpmath==1.3.0 71 | narwhals==1.29.0 72 | nbclient==0.10.2 73 | nbconvert==7.16.6 74 | nbformat==5.10.4 75 | nest-asyncio==1.6.0 76 | networkx==3.4.2 77 | notebook==7.3.2 78 | notebook_shim==0.2.4 79 | numpy==2.2.3 80 | opt_einsum==3.4.0 81 | optax==0.2.4 82 | optimistix==0.0.10 83 | overrides==7.7.0 84 | packaging==24.2 85 | pandocfilters==1.5.1 86 | parso==0.8.4 87 | pexpect==4.9.0 88 | pillow==11.1.0 89 | platformdirs==4.3.6 90 | plotly==5.24.1 91 | prometheus_client==0.21.1 92 | prompt_toolkit==3.0.50 93 | psutil==7.0.0 94 | ptyprocess==0.7.0 95 | pure_eval==0.2.3 96 | pycparser==2.22 97 | Pygments==2.19.1 98 | pyparsing==3.2.1 99 | python-dateutil==2.9.0.post0 100 | python-json-logger==3.2.1 101 | PyYAML==6.0.2 102 | pyzmq==26.2.1 103 | referencing==0.36.2 104 | requests==2.32.3 105 | rfc3339-validator==0.1.4 106 | rfc3986-validator==0.1.1 107 | rpds-py==0.23.1 108 | scipy==1.15.2 109 | Send2Trash==1.8.3 110 | six==1.17.0 111 | sniffio==1.3.1 112 | soupsieve==2.6 113 | stack-data==0.6.3 114 | sympy==1.13.1 115 | tenacity==9.0.0 116 | terminado==0.18.1 117 | tinycss2==1.4.0 118 | tomli==2.2.1 119 | toolz==1.0.0 120 | torch==2.5.1+cu121 121 | torchvision==0.20.1+cu121 122 | tornado==6.4.2 123 | traitlets==5.14.3 124 | typeguard==2.13.3 125 | types-python-dateutil==2.9.0.20241206 126 | typing_extensions==4.12.2 127 | uri-template==1.3.0 128 | urllib3==2.3.0 129 | wadler_lindig==0.1.3 130 | wcwidth==0.2.13 131 | webcolors==24.11.1 132 | webencodings==0.5.1 133 | websocket-client==1.8.0 134 | widgetsnbextension==4.0.13 135 | -------------------------------------------------------------------------------- /experiments/mupc_paper/README.md: -------------------------------------------------------------------------------- 1 | # μPC paper 2 | 3 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/mupc.ipynb) [![Paper](https://img.shields.io/badge/Paper-arXiv:2508.01191-%23f2806bff.svg)](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1)) 4 | 5 |

6 | <img src="spotlight_fig.png" alt="μPC spotlight figure">
7 | 
8 | 
9 | 

10 | 
11 | This folder contains code to reproduce all the experiments for the NeurIPS 2025
12 | paper ["μPC": Scaling Predictive Coding to 100+ Layer Networks](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1)). For a high-level summary, see
13 | [this blog post](https://francesco-innocenti.github.io/posts/2025/05/20/Scaling-Predictive-Coding-to-100+-Layer-Networks/).
14 | And for a tutorial, see [this example notebook](https://thebuckleylab.github.io/jpc/examples/mupc/).
15 | 
16 | 
17 | ## Setup
18 | Clone the `jpc` repo. We recommend using a virtual environment, e.g.
19 | ```
20 | python3 -m venv venv
21 | ```
22 | Install `jpc`:
23 | ```
24 | pip install -e .
25 | ```
26 | For GPU usage, upgrade jax to the appropriate cuda version (12 as an example
27 | here).
28 | 
29 | ```
30 | pip install --upgrade "jax[cuda12]==0.5.2"
31 | ```
32 | Now navigate to `experiments/mupc_paper` and install all the requirements:
33 | 
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 | 
38 | 
39 | ## Compute resources
40 | We recommend using a GPU for the experiments with 64- and 128-layer networks.
41 | 
42 | 
43 | ## Scripts
44 | * `train_pcn_no_metrics.py`: This is the main script that was used to produce
45 | results for Figs. 1, 5, & A.16-A.18.
46 | * `analyse_activity_hessian.py`: This script can be used to reproduce results
47 | related to spectral properties of the activity Hessian at initialisation
48 | (Figs. 2 & 4, & Figs. A.1-A.7, A.12 & A.21).
49 | * `train_pcn.py`: This was mainly used to monitor the condition number of the
50 | activity Hessian during training (Figs. 3, A.8-A.9, A.13-A.14, A.22-A.28).
51 | * `test_energy_theory.py`: This can be used to reproduce results related to
52 | the convergence behaviour of μPC to BP shown in Section 6
53 | (Figs. 6 & A.32-A.33).
54 | * `train_bpn.py`: Used to obtain all the results with backprop.
55 | * `test_mlp_fwd_pass.py`: Used for results of Fig. A.29.
56 | * `toy_experiments.ipynb`: Used for many secondary results in the Appendix.
57 | 
58 | The majority of results are plotted in `plot_results.ipynb` under informative
59 | headings. For details of all the experiments, see Section A.4 of the paper.
60 | 
61 | 
62 | ## Citation
63 | If you use μPC in your work, please cite the paper:
64 | 
65 | ```bibtex
66 | @article{innocenti2025mu,
67 |   title={$\mu$PC: Scaling Predictive Coding to 100+ Layer Networks},
68 |   author={Innocenti, Francesco and Achour, El Mehdi and Buckley, Christopher L},
69 |   journal={arXiv preprint arXiv:2505.13124},
70 |   year={2025}
71 | }
72 | ```
73 | Also consider starring the repo!
⭐️ 74 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Pytest configuration and shared fixtures for jpc tests.""" 2 | 3 | import pytest 4 | import jax 5 | import jax.numpy as jnp 6 | import equinox.nn as nn 7 | from jpc import make_mlp 8 | 9 | 10 | @pytest.fixture 11 | def key(): 12 | """Random key for testing.""" 13 | return jax.random.PRNGKey(42) 14 | 15 | 16 | @pytest.fixture 17 | def batch_size(): 18 | """Batch size for testing.""" 19 | return 4 20 | 21 | 22 | @pytest.fixture 23 | def input_dim(): 24 | """Input dimension for testing.""" 25 | return 10 26 | 27 | 28 | @pytest.fixture 29 | def hidden_dim(): 30 | """Hidden dimension for testing.""" 31 | return 20 32 | 33 | 34 | @pytest.fixture 35 | def output_dim(): 36 | """Output dimension for testing.""" 37 | return 5 38 | 39 | 40 | @pytest.fixture 41 | def depth(): 42 | """Network depth for testing.""" 43 | return 3 44 | 45 | 46 | @pytest.fixture 47 | def simple_model(key, input_dim, hidden_dim, output_dim, depth): 48 | """Create a simple MLP model for testing.""" 49 | return make_mlp( 50 | key=key, 51 | input_dim=input_dim, 52 | width=hidden_dim, 53 | depth=depth, 54 | output_dim=output_dim, 55 | act_fn="relu", 56 | use_bias=False, 57 | param_type="sp" 58 | ) 59 | 60 | 61 | @pytest.fixture 62 | def x(key, batch_size, input_dim): 63 | """Sample input data.""" 64 | return jax.random.normal(key, (batch_size, input_dim)) 65 | 66 | 67 | @pytest.fixture 68 | def y(key, batch_size, output_dim): 69 | """Sample output/target data.""" 70 | return jax.random.normal(key, (batch_size, output_dim)) 71 | 72 | 73 | @pytest.fixture 74 | def y_onehot(key, batch_size, output_dim): 75 | """Sample one-hot encoded target data.""" 76 | indices = jax.random.randint(key, (batch_size,), 0, output_dim) 77 | return jax.nn.one_hot(indices, output_dim) 78 | 79 | 80 | @pytest.fixture 81 | def layer_sizes(input_dim, hidden_dim, output_dim, depth): 82 | """Layer sizes for testing.""" 83 | sizes = [input_dim] 84 | for _ in range(depth - 1): 85 | sizes.append(hidden_dim) 86 | sizes.append(output_dim) 87 | return sizes 88 | 89 | 90 | def pytest_ignore_collect(collection_path, config): 91 | """Skip jpc/_test.py which contains library functions, not test functions.""" 92 | # Skip _test.py files in the jpc directory (library functions, not test functions) 93 | if collection_path.name == "_test.py": 94 | file_str = str(collection_path.resolve()) 95 | # Check if it's in jpc directory but not in tests 96 | if "/jpc/" in file_str and "/tests/" not in file_str: 97 | return True # Ignore this file 98 | return None # Use default behavior for other files 99 | 100 | 101 | def pytest_collection_modifyitems(config, items): 102 | """Remove any test items that were incorrectly collected from jpc._test module.""" 103 | filtered_items = [] 104 | for item in items: 105 | # Check file path 106 | item_path = None 107 | if hasattr(item, 'fspath') and item.fspath: 108 | item_path = str(item.fspath) 109 | elif hasattr(item, 'path') and item.path: 110 | item_path = str(item.path) 111 | elif hasattr(item, 'nodeid'): 112 | # Extract path from nodeid (format: path/to/file.py::test_function) 113 | nodeid = item.nodeid 114 | if '::' in nodeid: 115 | item_path = nodeid.split('::')[0] 116 | 117 | # Check if it's from jpc/_test.py 118 | if item_path and ('jpc/_test.py' in item_path or 'jpc\\_test.py' in item_path): 119 | continue 120 | 121 | # Check module name 
122 |         if hasattr(item, 'module') and item.module:
123 |             module_name = getattr(item.module, '__name__', None)
124 |             if module_name == 'jpc._test':
125 |                 continue
126 | 
127 |         # Check function location
128 |         if hasattr(item, 'function') and hasattr(item.function, '__module__'):
129 |             if item.function.__module__ == 'jpc._test':
130 |                 continue
131 | 
132 |         filtered_items.append(item)
133 | 
134 |     items[:] = filtered_items
135 | 
--------------------------------------------------------------------------------
/docs/basic_usage.md:
--------------------------------------------------------------------------------
1 | JPC provides two types of API depending on the use case:
2 | 
3 | * a simple, high-level API that lets you train and test models with predictive
4 | coding in a few lines of code, and
5 | * a more advanced API offering greater flexibility as well as additional features.
6 | 
7 | # Basic usage
8 | At a high level, JPC provides a single convenience function `jpc.make_pc_step()`
9 | to update the parameters of a neural network with PC.
10 | ```py
11 | import jax.random as jr
12 | import jax.numpy as jnp
13 | import equinox as eqx
14 | import optax
15 | import jpc
16 | 
17 | # toy data
18 | x = jnp.array([1., 1., 1.])
19 | y = -x
20 | 
21 | # define model and optimiser
22 | key = jr.PRNGKey(0)
23 | model = jpc.make_mlp(
24 |     key,
25 |     input_dim=3,
26 |     width=50,
27 |     depth=5,
28 |     output_dim=3,
29 |     act_fn="relu"
30 | )
31 | optim = optax.adam(1e-3)
32 | opt_state = optim.init(
33 |     (eqx.filter(model, eqx.is_array), None)
34 | )
35 | 
36 | # perform one training step with PC
37 | update_result = jpc.make_pc_step(
38 |     model=model,
39 |     optim=optim,
40 |     opt_state=opt_state,
41 |     output=y,
42 |     input=x
43 | )
44 | 
45 | # updated model and optimiser
46 | model, opt_state = update_result["model"], update_result["opt_state"]
47 | ```
48 | As shown above, at a minimum `jpc.make_pc_step()` takes a model, an [optax
49 | ](https://github.com/google-deepmind/optax) optimiser and its
50 | state, and some data. The model needs to be compatible with PC updates in the
51 | sense that it's split into callable layers (see the
52 | [example notebooks
53 | ](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)). Also note
54 | that the `input` is actually not needed for unsupervised training. In fact,
55 | `jpc.make_pc_step()` can be used for classification and generation tasks, for
56 | supervised as well as unsupervised training (again see the [example notebooks
57 | ](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)).
58 | 
59 | Under the hood, `jpc.make_pc_step()` uses [diffrax
60 | ](https://github.com/patrick-kidger/diffrax) to solve the activity (inference)
61 | dynamics of PC. Many default arguments can be changed, including the ODE
62 | solver, and there is an option to record a variety of metrics such as loss,
63 | accuracy, and energies. See the [docs
64 | ](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_pc_step) for more
65 | details.
66 | 
67 | A similar convenience function `jpc.make_hpc_step()` is provided for updating the
68 | parameters of a hybrid PCN ([Tschantz et al., 2023
69 | ](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011280)).
70 | ```py
71 | import jax.random as jr
72 | import jax.numpy as jnp
73 | import equinox as eqx
74 | import optax
75 | import jpc
76 | 
77 | # toy data
78 | x = jnp.ones((1, 10))
79 | y = jnp.ones((1, 3))
80 | 
81 | # models
82 | key = jr.PRNGKey(0)
83 | subkeys = jr.split(key, 2)
84 | 
85 | input_dim, output_dim = 10, 3
86 | width, depth = 100, 5
87 | generator = jpc.make_mlp(
88 |     subkeys[0],
89 |     input_dim=input_dim,
90 |     width=width,
91 |     depth=depth,
92 |     output_dim=output_dim,
93 |     act_fn="tanh"
94 | )
95 | # NOTE that the input and output of the amortiser are reversed
96 | amortiser = jpc.make_mlp(
97 |     subkeys[1],
98 |     input_dim=output_dim,
99 |     width=width,
100 |     depth=depth,
101 |     output_dim=input_dim,
102 |     act_fn="tanh"
103 | )
104 | 
105 | # optimisers
106 | gen_optim = optax.adam(1e-3)
107 | amort_optim = optax.adam(1e-3)
108 | gen_opt_state = gen_optim.init(
109 |     (eqx.filter(generator, eqx.is_array), None)
110 | )
111 | amort_opt_state = amort_optim.init(
112 |     eqx.filter(amortiser, eqx.is_array)
113 | )
114 | 
115 | update_result = jpc.make_hpc_step(
116 |     generator=generator,
117 |     amortiser=amortiser,
118 |     optims=[gen_optim, amort_optim],
119 |     opt_states=[gen_opt_state, amort_opt_state],
120 |     output=y,
121 |     input=x
122 | )
123 | generator, amortiser = update_result["generator"], update_result["amortiser"]
124 | opt_states = update_result["opt_states"]
125 | gen_loss, amort_loss = update_result["losses"]
126 | ```
127 | See the [docs
128 | ](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_hpc_step) and the
129 | [example notebook
130 | ](https://thebuckleylab.github.io/jpc/examples/hybrid_pc/) for more details.
131 | 
--------------------------------------------------------------------------------
/tests/test_init.py:
--------------------------------------------------------------------------------
1 | """Tests for initialization functions."""
2 | 
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | from jpc import (
7 |     init_activities_with_ffwd,
8 |     init_activities_from_normal,
9 |     init_activities_with_amort,
10 |     init_epc_errors
11 | )
12 | 
13 | 
14 | def test_init_activities_with_ffwd(simple_model, x):
15 |     """Test feedforward initialization."""
16 |     activities = init_activities_with_ffwd(
17 |         model=simple_model,
18 |         input=x,
19 |         param_type="sp"
20 |     )
21 | 
22 |     assert len(activities) == len(simple_model)
23 |     assert activities[0].shape == (x.shape[0], simple_model[0][1].weight.shape[0])
24 |     assert activities[-1].shape == (x.shape[0], simple_model[-1][1].weight.shape[0])
25 | 
26 | 
27 | def test_init_activities_with_ffwd_mupc(key, x, input_dim, hidden_dim, output_dim, depth):
28 |     """Test feedforward initialization with mupc parameterization."""
29 |     from jpc import make_mlp
30 | 
31 |     model = make_mlp(
32 |         key=key,
33 |         input_dim=input_dim,
34 |         width=hidden_dim,
35 |         depth=depth,
36 |         output_dim=output_dim,
37 |         act_fn="relu",
38 |         use_bias=False,
39 |         param_type="mupc"
40 |     )
41 | 
42 |     activities = init_activities_with_ffwd(
43 |         model=model,
44 |         input=x,
45 |         param_type="mupc"
46 |     )
47 | 
48 |     assert len(activities) == len(model)
49 | 
50 | 
51 | def test_init_activities_with_ffwd_skip_connections(simple_model, x):
52 |     """Test feedforward initialization with skip connections."""
53 |     from jpc import make_skip_model
54 | 
55 |     skip_model = make_skip_model(len(simple_model))
56 |     activities = init_activities_with_ffwd(
57 |         model=simple_model,
58 |         input=x,
59 |         skip_model=skip_model,
60 |         param_type="sp"
61 |     )
62 | 
63 |     assert len(activities) == len(simple_model)
64 | 
65 | 
66 | def test_init_activities_from_normal_supervised(key, layer_sizes,
batch_size): 67 | """Test random initialization in supervised mode.""" 68 | activities = init_activities_from_normal( 69 | key=key, 70 | layer_sizes=layer_sizes, 71 | mode="supervised", 72 | batch_size=batch_size, 73 | sigma=0.05 74 | ) 75 | 76 | # In supervised mode, input layer is not initialized 77 | assert len(activities) == len(layer_sizes) - 1 78 | for i, (act, size) in enumerate(zip(activities, layer_sizes[1:]), 1): 79 | assert act.shape == (batch_size, size) 80 | 81 | 82 | def test_init_activities_from_normal_unsupervised(key, layer_sizes, batch_size): 83 | """Test random initialization in unsupervised mode.""" 84 | activities = init_activities_from_normal( 85 | key=key, 86 | layer_sizes=layer_sizes, 87 | mode="unsupervised", 88 | batch_size=batch_size, 89 | sigma=0.05 90 | ) 91 | 92 | # In unsupervised mode, all layers including input are initialized 93 | assert len(activities) == len(layer_sizes) 94 | for act, size in zip(activities, layer_sizes): 95 | assert act.shape == (batch_size, size) 96 | 97 | 98 | def test_init_activities_with_amort(key, simple_model, y, input_dim, hidden_dim, output_dim, depth): 99 | """Test amortized initialization.""" 100 | from jpc import make_mlp 101 | 102 | amortiser = make_mlp( 103 | key=key, 104 | input_dim=output_dim, 105 | width=hidden_dim, 106 | depth=depth, 107 | output_dim=input_dim, 108 | act_fn="relu", 109 | use_bias=False, 110 | param_type="sp" 111 | ) 112 | 113 | activities = init_activities_with_amort( 114 | amortiser=amortiser, 115 | generator=simple_model, 116 | input=y 117 | ) 118 | 119 | # Should return reversed activities plus dummy target prediction 120 | assert len(activities) == len(amortiser) + 1 121 | 122 | 123 | def test_init_epc_errors_supervised(layer_sizes, batch_size): 124 | """Test EPC error initialization in supervised mode.""" 125 | errors = init_epc_errors( 126 | layer_sizes=layer_sizes, 127 | batch_size=batch_size, 128 | mode="supervised" 129 | ) 130 | 131 | # In supervised mode, errors are initialized for layers 1 to L-1 (hidden layers only) 132 | assert len(errors) == len(layer_sizes) - 1 133 | for i, (err, size) in enumerate(zip(errors, layer_sizes[1:]), 1): 134 | assert err.shape == (batch_size, size) 135 | assert jnp.allclose(err, 0.0) # Should be zero-initialized 136 | 137 | 138 | -------------------------------------------------------------------------------- /docs/_static/custom_css.css: -------------------------------------------------------------------------------- 1 | /* Fix /page#foo going to the top of the viewport and being hidden by the navbar */ 2 | html { 3 | scroll-padding-top: 50px; 4 | } 5 | 6 | /* Hide the dark logo by default */ 7 | #logo_dark_mode { 8 | display: none; 9 | } 10 | 11 | /* Show the light logo by default */ 12 | #logo_light_mode { 13 | display: block; 14 | } 15 | 16 | /* Switch display property based on color scheme */ 17 | [data-md-color-scheme="default"] { 18 | --md-footer-logo-dark-mode: none; 19 | --md-footer-logo-light-mode: block; 20 | } 21 | 22 | [data-md-color-scheme="slate"] { 23 | --md-footer-logo-dark-mode: block; 24 | --md-footer-logo-light-mode: none; 25 | } 26 | 27 | /* Apply the custom variables */ 28 | #logo_light_mode { 29 | display: var(--md-footer-logo-light-mode); 30 | } 31 | 32 | #logo_dark_mode { 33 | display: var(--md-footer-logo-dark-mode); 34 | } 35 | 36 | /* Adjust logo size */ 37 | .md-header__button.md-logo { 38 | margin: 0; 39 | padding: 0.5; 40 | } 41 | .md-header__button.md-logo img, .md-header__button.md-logo svg { 42 | height: 2.5rem; 43 | width: 
auto; 44 | } 45 | 46 | /* Fit the Twitter handle alongside the GitHub one in the top right. */ 47 | 48 | div.md-header__source { 49 | width: revert; 50 | max-width: revert; 51 | } 52 | 53 | a.md-source { 54 | display: inline-block; 55 | } 56 | 57 | .md-source__repository { 58 | max-width: 100%; 59 | } 60 | 61 | /* Emphasise sections of nav on left hand side */ 62 | 63 | nav.md-nav { 64 | padding-left: 5px; 65 | } 66 | 67 | nav.md-nav--secondary { 68 | border-left: revert !important; 69 | } 70 | 71 | .md-nav__title { 72 | font-size: 0.9rem; 73 | } 74 | 75 | .md-nav__item--section > .md-nav__link { 76 | font-size: 0.9rem; 77 | } 78 | 79 | /* Indent autogenerated documentation */ 80 | 81 | div.doc-contents { 82 | padding-left: 25px; 83 | border-left: 4px solid rgba(230, 230, 230); 84 | } 85 | 86 | /* Increase visibility of splitters "---" */ 87 | 88 | [data-md-color-scheme="default"] .md-typeset hr { 89 | border-bottom-color: rgb(0, 0, 0); 90 | border-bottom-width: 1pt; 91 | } 92 | 93 | [data-md-color-scheme="slate"] .md-typeset hr { 94 | border-bottom-color: rgb(230, 230, 230); 95 | } 96 | 97 | /* More space at the bottom of the page */ 98 | 99 | .md-main__inner { 100 | margin-bottom: 1.5rem; 101 | } 102 | 103 | /* Remove prev/next footer buttons */ 104 | 105 | .md-footer__inner { 106 | display: none; 107 | } 108 | 109 | /* Change font sizes */ 110 | 111 | html { 112 | /* Decrease font size for overall webpage 113 | Down from 137.5% which is the Material default */ 114 | font-size: 110%; 115 | } 116 | 117 | .md-typeset .admonition { 118 | /* Increase font size in admonitions */ 119 | font-size: 100% !important; 120 | } 121 | 122 | .md-typeset details { 123 | /* Increase font size in details */ 124 | font-size: 100% !important; 125 | } 126 | 127 | .md-typeset h1 { 128 | font-size: 1.6rem; 129 | } 130 | 131 | .md-typeset h2 { 132 | font-size: 1.5rem; 133 | } 134 | 135 | .md-typeset h3 { 136 | font-size: 1.3rem; 137 | } 138 | 139 | .md-typeset h4 { 140 | font-size: 1.1rem; 141 | } 142 | 143 | .md-typeset h5 { 144 | font-size: 0.9rem; 145 | } 146 | 147 | .md-typeset h6 { 148 | font-size: 0.8rem; 149 | } 150 | 151 | /* Bugfix: remove the superfluous parts generated when doing: 152 | 153 | ??? Blah 154 | 155 | ::: library.something 156 | */ 157 | 158 | .md-typeset details .mkdocstrings > h4 { 159 | display: none; 160 | } 161 | 162 | .md-typeset details .mkdocstrings > h5 { 163 | display: none; 164 | } 165 | 166 | /* Change default colours for tags */ 167 | 168 | [data-md-color-scheme="default"] { 169 | --md-typeset-a-color: rgb(0, 189, 164) !important; 170 | } 171 | [data-md-color-scheme="slate"] { 172 | --md-typeset-a-color: rgb(0, 189, 164) !important; 173 | } 174 | 175 | /* Highlight functions, classes etc. type signatures. Really helps to make clear where 176 | one item ends and another begins. 
*/ 177 | 178 | [data-md-color-scheme="default"] { 179 | --doc-heading-color: #DDD; 180 | --doc-heading-border-color: #CCC; 181 | --doc-heading-color-alt: #F0F0F0; 182 | } 183 | [data-md-color-scheme="slate"] { 184 | --doc-heading-color: rgb(25,25,33); 185 | --doc-heading-border-color: rgb(25,25,33); 186 | --doc-heading-color-alt: rgb(33,33,44); 187 | --md-code-bg-color: rgb(38,38,50); 188 | } 189 | 190 | h4.doc-heading { 191 | /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/ 192 | background-color: var(--doc-heading-color); 193 | border: solid var(--doc-heading-border-color); 194 | border-width: 1.5pt; 195 | border-radius: 2pt; 196 | padding: 0pt 5pt 2pt 5pt; 197 | } 198 | h5.doc-heading, h6.heading { 199 | background-color: var(--doc-heading-color-alt); 200 | border-radius: 2pt; 201 | padding: 0pt 5pt 2pt 5pt; 202 | } -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | theme: 2 | name: material 3 | features: 4 | - navigation.sections # Sections are included in the navigation on the left. 5 | - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. 6 | - header.autohide # header disappears as you scroll 7 | palette: 8 | # Light mode / dark mode 9 | # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as 10 | # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle. 11 | - scheme: default 12 | primary: white 13 | accent: amber 14 | toggle: 15 | icon: material/weather-night 16 | name: Switch to dark mode 17 | - scheme: slate 18 | primary: black 19 | accent: amber 20 | toggle: 21 | icon: material/weather-sunny 22 | name: Switch to light mode 23 | icon: 24 | repo: fontawesome/brands/github # GitHub logo in top right 25 | logo_light_mode: "_static/logo-light.svg" # jpc logo in top left 26 | logo_dark_mode: "_static/logo-dark.svg" # jpc logo in top left 27 | favicon: "_static/favicon.png" 28 | custom_dir: "docs/_overrides" # Overriding part of the HTML 29 | 30 | site_name: jpc 31 | site_description: The documentation for the jpc software library. 32 | site_author: Francesco Innocenti 33 | site_url: https://thebuckleylab.github.io/jpc/ 34 | 35 | repo_url: https://github.com/thebuckleylab/jpc 36 | repo_name: thebuckleylab/jpc 37 | edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate 38 | 39 | strict: true # Don't allow warnings during the build process 40 | 41 | extra_javascript: 42 | # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/ 43 | - _static/mathjax.js 44 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 45 | 46 | extra_css: 47 | - _static/custom_css.css 48 | 49 | markdown_extensions: 50 | - pymdownx.arithmatex: # Render LaTeX via MathJax 51 | generic: true 52 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. 53 | - pymdownx.details # Allowing hidden expandable regions denoted by ??? 
54 | - pymdownx.snippets: # Include one Markdown file into another 55 | base_path: docs 56 | - admonition 57 | - toc: 58 | permalink: "¤" # Adds a clickable permalink to each section heading 59 | toc_depth: 4 60 | 61 | plugins: 62 | - search: 63 | separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;' 64 | - include_exclude_files: 65 | include: 66 | - ".htaccess" 67 | exclude: 68 | - "_overrides" 69 | - "examples/.ipynb_checkpoints/" 70 | - "examples/analytical_inference_with_linear_net.ipynb" 71 | - ipynb 72 | - mkdocstrings: 73 | handlers: 74 | python: 75 | options: 76 | force_inspection: true 77 | heading_level: 4 78 | inherited_members: true 79 | members_order: source 80 | show_bases: false 81 | show_if_no_docstring: true 82 | show_overloads: false 83 | show_root_heading: true 84 | show_signature_annotations: true 85 | show_source: false 86 | show_symbol_type_heading: true 87 | show_symbol_type_toc: true 88 | 89 | nav: 90 | - 'index.md' 91 | - ⚙️ How it works: 92 | - Basic usage: 'basic_usage.md' 93 | - Advanced usage: 'advanced_usage.md' 94 | - 📚 Examples: 95 | - Introductory: 96 | - Discriminative PC: 'examples/discriminative_pc.ipynb' 97 | - Supervised generative PC: 'examples/supervised_generative_pc.ipynb' 98 | - Unsupervised generative PC: 'examples/unsupervised_generative_pc.ipynb' 99 | - Advanced: 100 | - μPC: 'examples/mupc.ipynb' 101 | - ePC: 'examples/epc.ipynb' 102 | - Hybrid PC: 'examples/hybrid_pc.ipynb' 103 | - Bidirectional PC: 'examples/bidirectional_pc.ipynb' 104 | - Linear theoretical energy: 'examples/linear_net_theoretical_energy.ipynb' 105 | - JPC from scratch: 'examples/jpc_from_scratch.ipynb' 106 | - 🌱 Basic API: 107 | - 'api/Training.md' 108 | - 'api/Testing.md' 109 | - 'api/Utils.md' 110 | - 🚀 Advanced API: 111 | - 'api/Initialisation.md' 112 | - 'api/Energy functions.md' 113 | - 'api/Gradients.md' 114 | - 'api/Continuous-time Inference.md' 115 | - 'api/Discrete updates.md' 116 | - 'api/Theoretical tools.md' 117 | 118 | copyright: | 119 | © 2024 thebuckleylab -------------------------------------------------------------------------------- /docs/_static/logo-light.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_infer.py: -------------------------------------------------------------------------------- 1 | """Tests for inference solving functions.""" 2 | 3 | import pytest 4 | import jax 5 | import jax.numpy as jnp 6 | from diffrax import Heun, PIDController 7 | from jpc import solve_inference 8 | 9 | 10 | def test_solve_inference_supervised(simple_model, x, y): 11 | """Test inference solving in supervised mode.""" 12 | from jpc import init_activities_with_ffwd 13 | 14 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp") 15 | 16 | solution = solve_inference( 17 | params=(simple_model, None), 18 | activities=activities, 19 | output=y, 20 | input=x, 21 | loss_id="mse", 22 | param_type="sp", 23 | solver=Heun(), 24 | max_t1=10, 25 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 26 | ) 27 | 28 | assert len(solution) == len(activities) 29 | # Solution shape depends on whether recording is enabled 30 | # Without record_iters, it's just the final state 31 | for sol, act in zip(solution, activities): 32 | # Solution may have time dimension if recorded, or match activity shape 33 | assert sol.shape[-2:] == act.shape or sol.shape == act.shape 34 | 35 | 36 | def 
test_solve_inference_unsupervised(simple_model, y, key, layer_sizes, batch_size): 37 | """Test inference solving in unsupervised mode.""" 38 | from jpc import init_activities_from_normal 39 | 40 | activities = init_activities_from_normal( 41 | key=key, 42 | layer_sizes=layer_sizes, 43 | mode="unsupervised", 44 | batch_size=batch_size, 45 | sigma=0.05 46 | ) 47 | 48 | solution = solve_inference( 49 | params=(simple_model, None), 50 | activities=activities, 51 | output=y, 52 | input=None, 53 | loss_id="mse", 54 | param_type="sp", 55 | solver=Heun(), 56 | max_t1=10, 57 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 58 | ) 59 | 60 | assert len(solution) == len(activities) 61 | 62 | 63 | def test_solve_inference_cross_entropy(simple_model, x, y_onehot): 64 | """Test inference solving with cross-entropy loss.""" 65 | from jpc import init_activities_with_ffwd 66 | 67 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp") 68 | 69 | solution = solve_inference( 70 | params=(simple_model, None), 71 | activities=activities, 72 | output=y_onehot, 73 | input=x, 74 | loss_id="ce", 75 | param_type="sp", 76 | solver=Heun(), 77 | max_t1=10, 78 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 79 | ) 80 | 81 | assert len(solution) == len(activities) 82 | 83 | 84 | def test_solve_inference_with_regularization(simple_model, x, y): 85 | """Test inference solving with regularization.""" 86 | from jpc import init_activities_with_ffwd 87 | 88 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp") 89 | 90 | solution = solve_inference( 91 | params=(simple_model, None), 92 | activities=activities, 93 | output=y, 94 | input=x, 95 | loss_id="mse", 96 | param_type="sp", 97 | solver=Heun(), 98 | max_t1=10, 99 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 100 | weight_decay=0.01, 101 | spectral_penalty=0.01, 102 | activity_decay=0.01 103 | ) 104 | 105 | assert len(solution) == len(activities) 106 | 107 | 108 | def test_solve_inference_record_iters(simple_model, x, y): 109 | """Test inference solving with iteration recording.""" 110 | from jpc import init_activities_with_ffwd 111 | 112 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp") 113 | 114 | solution = solve_inference( 115 | params=(simple_model, None), 116 | activities=activities, 117 | output=y, 118 | input=x, 119 | loss_id="mse", 120 | param_type="sp", 121 | solver=Heun(), 122 | max_t1=10, 123 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 124 | record_iters=True 125 | ) 126 | 127 | assert len(solution) == len(activities) 128 | # When recording, solution should have an extra dimension for time 129 | assert solution[0].ndim == 3 # (time, batch, features) 130 | 131 | 132 | def test_solve_inference_record_every(simple_model, x, y): 133 | """Test inference solving with record_every parameter.""" 134 | from jpc import init_activities_with_ffwd 135 | 136 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp") 137 | 138 | solution = solve_inference( 139 | params=(simple_model, None), 140 | activities=activities, 141 | output=y, 142 | input=x, 143 | loss_id="mse", 144 | param_type="sp", 145 | solver=Heun(), 146 | max_t1=10, 147 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 148 | record_every=2 149 | ) 150 | 151 | assert len(solution) == len(activities) 152 | 153 | -------------------------------------------------------------------------------- /.github/logo-with-background.svg: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/index.md: --------------------------------------------------------------------------------
1 | # Getting started
2 | JPC is a [**J**AX](https://github.com/google/jax) library for training neural
3 | networks with **P**redictive **C**oding (PC).
4 | 
5 | JPC provides a **simple**, **fast** and **flexible** API for
6 | training a variety of PCNs including discriminative, generative and hybrid
7 | models.
8 | 
9 | * Like JAX, JPC is completely functional in design, and the core library is
10 | <1000 lines of code.
11 | 
12 | * Unlike existing implementations, JPC provides a wide range of optimisers, both
13 | discrete and continuous, to solve the inference dynamics of PC, including
14 | ordinary differential equation (ODE) solvers.
15 | 
16 | * JPC also provides some analytical tools that can be used to study and
17 | potentially diagnose issues with PCNs.
18 | 
19 | If you're new to JPC, we recommend starting from the [
20 | example notebooks](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)
21 | and checking the [documentation](https://thebuckleylab.github.io/jpc/).
22 | 
23 | ## 💻 Installation
24 | Clone the repo and in the project's directory run
25 | ```
26 | pip install .
27 | ```
28 | 
29 | Requires Python 3.10+ and JAX 0.4.38–0.5.2 (inclusive). For GPU usage, upgrade
30 | JAX to the appropriate CUDA version (CUDA 12 in the example below).
31 | 
32 | ```
33 | pip install --upgrade "jax[cuda12]"
34 | ```
35 | 
36 | ## ⚡️ Quick example
37 | Use `jpc.make_pc_step()` to update the parameters of any neural network
38 | compatible with PC updates (see [examples](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/))
39 | ```py
40 | import jax.random as jr
41 | import jax.numpy as jnp
42 | import equinox as eqx
43 | import optax
44 | import jpc
45 | 
46 | # toy data
47 | x = jnp.array([1., 1., 1.])
48 | y = -x
49 | 
50 | # define model and optimiser
51 | key = jr.PRNGKey(0)
52 | model = jpc.make_mlp(
53 |     key,
54 |     input_dim=3,
55 |     width=50,
56 |     depth=5,
57 |     output_dim=3,
58 |     act_fn="relu"
59 | )
60 | optim = optax.adam(1e-3)
61 | opt_state = optim.init(
62 |     (eqx.filter(model, eqx.is_array), None)
63 | )
64 | 
65 | # perform one training step with PC
66 | result = jpc.make_pc_step(
67 |     model=model,
68 |     optim=optim,
69 |     opt_state=opt_state,
70 |     output=y,
71 |     input=x
72 | )
73 | 
74 | # updated model and optimiser
75 | model, opt_state = result["model"], result["opt_state"]
76 | ```
77 | Under the hood, `jpc.make_pc_step()`
78 | 
79 | 1. integrates the inference (activity) dynamics using a [diffrax](https://github.com/patrick-kidger/diffrax) ODE solver, and
80 | 2. updates model parameters at the numerical solution of the activities with a given [optax](https://github.com/google-deepmind/optax) optimiser.
81 | 
82 | > **NOTE**: All convenience training and test functions such as `make_pc_step()`
83 | > are already "jitted" for optimised performance.
84 | 
85 | ## 🚀 Advanced usage
86 | Advanced users can access all the underlying functions of `jpc.make_pc_step()`
87 | as well as additional features. A custom PC training step looks like the
88 | following:
89 | ```py
90 | import jpc
91 | activity_optim = optax.sgd(5e-1)  # optimiser for the activities (learning rate is illustrative)
92 | 
93 | # 1. initialise activities with a feedforward pass
94 | activities = jpc.init_activities_with_ffwd(model=model, input=x)
95 | 
96 | # 2. perform inference (state optimisation)
97 | activity_opt_state = activity_optim.init(activities)
98 | for _ in range(len(model)):  # number of inference iterations (here equal to the depth)
99 |     activity_update_result = jpc.update_pc_activities(
100 |         params=(model, None),
101 |         activities=activities,
102 |         optim=activity_optim,
103 |         opt_state=activity_opt_state,
104 |         output=y,
105 |         input=x
106 |     )
107 |     activities = activity_update_result["activities"]
108 |     activity_opt_state = activity_update_result["opt_state"]
109 | 
110 | # 3. update parameters at the activities' solution with PC
111 | result = jpc.update_params(
112 |     params=(model, None),
113 |     activities=activities,  # now equilibrated by step 2
114 |     optim=optim,
115 |     opt_state=opt_state,
116 |     output=y,
117 |     input=x
118 | )
119 | ```
120 | which can be embedded in a jitted function with any other additional
121 | computations.
122 | 
123 | ## 📄 Citation
124 | If you found this library useful in your work, please cite ([paper link](https://arxiv.org/abs/2412.03676)):
125 | 
126 | ```bibtex
127 | @article{innocenti2024jpc,
128 |   title={JPC: Flexible Inference for Predictive Coding Networks in JAX},
129 |   author={Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Varona, Miguel De Llanza and Singh, Ryan and Buckley, Christopher L},
130 |   journal={arXiv preprint arXiv:2412.03676},
131 |   year={2024}
132 | }
133 | ```
134 | Also consider starring the project [on GitHub](https://github.com/thebuckleylab/jpc)! ⭐️
135 | 
136 | ## 🙏 Acknowledgements
137 | We are grateful to Patrick Kidger for early advice on how to use Diffrax.
138 | 
139 | ## See also: other PC libraries
140 | * [ngc-learn](https://github.com/NACLab/ngc-learn) (jax & pytorch)
141 | * [pcx](https://github.com/liukidar/pcx) (jax)
142 | * [pyhgf](https://github.com/ComputationalPsychiatry/pyhgf) (jax)
143 | * [Torch2PC](https://github.com/RobertRosenbaum/Torch2PC) (pytorch)
144 | * [pypc](https://github.com/infer-actively/pypc) (pytorch)
145 | * [pybrid](https://github.com/alec-tschantz/pybrid) (pytorch)
146 | 
-------------------------------------------------------------------------------- /docs/_static/logo-dark.svg: --------------------------------------------------------------------------------
1 | 
2 | 
3 | 
-------------------------------------------------------------------------------- /tests/test_test_functions.py: --------------------------------------------------------------------------------
1 | """Tests for test utility functions."""
2 | 
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | from diffrax import Heun, PIDController
7 | from jpc import test_discriminative_pc, test_generative_pc, test_hpc
8 | 
9 | 
10 | def test_test_discriminative_pc(simple_model, x, y_onehot):
11 |     """Test discriminative PC testing function."""
12 |     # For accuracy computation, we need one-hot encoded targets
13 |     loss, acc = test_discriminative_pc(
14 |         model=simple_model,
15 |         output=y_onehot,
16 |         input=x,
17 |         loss="ce",  # Use cross-entropy for one-hot targets
18 |         param_type="sp"
19 |     )
20 | 
21 |     assert jnp.isfinite(loss)
22 |     assert jnp.isfinite(acc)
23 |     assert 0 <= acc <= 100
24 | 
25 | 
26 | def test_test_discriminative_pc_cross_entropy(simple_model, x, y_onehot):
27 |     """Test discriminative PC testing with cross-entropy."""
28 |     loss, acc = test_discriminative_pc(
29 |         model=simple_model,
30 |         output=y_onehot,
31 |         input=x,
32 |         loss="ce",
33 |         param_type="sp"
34 |     )
35 | 
36 |     assert jnp.isfinite(loss)
37 |     assert jnp.isfinite(acc)
38 | 
39 | 
40 | def test_test_discriminative_pc_with_skip(simple_model, x, y_onehot):
41 |     """Test
discriminative PC testing with skip connections.""" 42 | from jpc import make_skip_model 43 | 44 | skip_model = make_skip_model(len(simple_model)) 45 | 46 | # compute_accuracy requires one-hot encoded targets 47 | loss, acc = test_discriminative_pc( 48 | model=simple_model, 49 | output=y_onehot, 50 | input=x, 51 | skip_model=skip_model, 52 | loss="ce", # Use cross-entropy for one-hot targets 53 | param_type="sp" 54 | ) 55 | 56 | assert jnp.isfinite(loss) 57 | assert jnp.isfinite(acc) 58 | 59 | 60 | def test_test_generative_pc(simple_model, x, y_onehot, key, layer_sizes, batch_size, input_dim): 61 | """Test generative PC testing function.""" 62 | # For compute_accuracy, input needs to be one-hot encoded 63 | key1, key2 = jax.random.split(key) 64 | input_onehot = jax.nn.one_hot( 65 | jax.random.randint(key1, (batch_size,), 0, input_dim), 66 | input_dim 67 | ) 68 | input_acc, output_preds = test_generative_pc( 69 | model=simple_model, 70 | output=y_onehot, 71 | input=input_onehot, 72 | key=key2, 73 | layer_sizes=layer_sizes, 74 | batch_size=batch_size, 75 | loss_id="ce", 76 | param_type="sp", 77 | sigma=0.05, 78 | ode_solver=Heun(), 79 | max_t1=10, 80 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 81 | ) 82 | 83 | assert jnp.isfinite(input_acc) 84 | assert 0 <= input_acc <= 100 85 | assert output_preds.shape == y_onehot.shape 86 | 87 | 88 | def test_test_generative_pc_cross_entropy(simple_model, x, y_onehot, key, layer_sizes, batch_size, input_dim): 89 | """Test generative PC testing with cross-entropy.""" 90 | # For compute_accuracy, input needs to be one-hot encoded 91 | key1, key2 = jax.random.split(key) 92 | input_onehot = jax.nn.one_hot( 93 | jax.random.randint(key1, (batch_size,), 0, input_dim), 94 | input_dim 95 | ) 96 | input_acc, output_preds = test_generative_pc( 97 | model=simple_model, 98 | output=y_onehot, 99 | input=input_onehot, 100 | key=key2, 101 | layer_sizes=layer_sizes, 102 | batch_size=batch_size, 103 | loss_id="ce", 104 | param_type="sp", 105 | sigma=0.05, 106 | ode_solver=Heun(), 107 | max_t1=10, 108 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 109 | ) 110 | 111 | assert jnp.isfinite(input_acc) 112 | assert output_preds.shape == y_onehot.shape 113 | 114 | 115 | def test_test_hpc(key, simple_model, x, y_onehot, output_dim, hidden_dim, input_dim, depth, layer_sizes, batch_size): 116 | """Test HPC testing function.""" 117 | from jpc import make_mlp 118 | 119 | # Split keys first 120 | key1, key2, key3 = jax.random.split(key, 3) 121 | 122 | generator = simple_model 123 | amortiser = make_mlp( 124 | key=key2, 125 | input_dim=output_dim, 126 | width=hidden_dim, 127 | depth=depth, 128 | output_dim=input_dim, 129 | act_fn="relu", 130 | use_bias=False, 131 | param_type="sp" 132 | ) 133 | 134 | # For compute_accuracy, input needs to be one-hot encoded 135 | input_onehot = jax.nn.one_hot( 136 | jax.random.randint(key1, (batch_size,), 0, input_dim), 137 | input_dim 138 | ) 139 | amort_acc, hpc_acc, gen_acc, output_preds = test_hpc( 140 | generator=generator, 141 | amortiser=amortiser, 142 | output=y_onehot, 143 | input=input_onehot, 144 | key=key3, 145 | layer_sizes=layer_sizes, 146 | batch_size=batch_size, 147 | sigma=0.05, 148 | ode_solver=Heun(), 149 | max_t1=10, 150 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 151 | ) 152 | 153 | assert jnp.isfinite(amort_acc) 154 | assert jnp.isfinite(hpc_acc) 155 | assert jnp.isfinite(gen_acc) 156 | assert 0 <= amort_acc <= 100 157 | assert 0 <= hpc_acc <= 100 158 | assert 0 <= gen_acc <= 100 159 | assert 
output_preds.shape == y_onehot.shape
160 | 
161 | 
-------------------------------------------------------------------------------- /experiments/mupc_paper/test_energy_theory.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import jax.random as jr
4 | 
5 | import jpc
6 | import optax
7 | import equinox as eqx
8 | 
9 | from experiments.datasets import get_dataloaders
10 | from experiments.library_paper.utils import set_seed  # assumed import path: set_seed is defined in library_paper/utils.py
11 | 
12 | def run_test(
13 |     seed,
14 |     dataset,
15 |     width,
16 |     n_hidden,
17 |     act_fn,
18 |     use_skips,
19 |     param_type,
20 |     param_optim_id,
21 |     param_lr,
22 |     batch_size,
23 |     save_dir
24 | ):
25 |     set_seed(seed)
26 |     key = jr.PRNGKey(seed)
27 |     model_key, init_key = jr.split(key, 2)
28 |     os.makedirs(save_dir, exist_ok=True)
29 | 
30 |     # create and initialise model
31 |     d_in, d_out = 784, 10
32 |     L = n_hidden + 1
33 |     model = jpc.make_mlp(
34 |         key=model_key,
35 |         input_dim=d_in,
36 |         width=width,
37 |         depth=L,
38 |         output_dim=d_out,
39 |         act_fn=act_fn,
40 |         use_bias=False,
41 |         param_type=param_type
42 |     )
43 |     skip_model = jpc.make_skip_model(L) if use_skips else None
44 | 
45 |     # optimisers
46 |     if param_optim_id == "sgd":
47 |         param_optim = optax.sgd(param_lr)
48 |     elif param_optim_id == "adam":
49 |         param_optim = optax.adam(param_lr)
50 | 
51 |     param_opt_state = param_optim.init(
52 |         (eqx.filter(model, eqx.is_array), skip_model)
53 |     )
54 | 
55 |     # data & metrics
56 |     train_loader, _ = get_dataloaders(dataset, batch_size)
57 |     train_losses, train_energies = [], []
58 |     loss_energy_ratios = []
59 | 
60 |     for t, (img_batch, label_batch) in enumerate(train_loader):
61 |         x, y = img_batch.numpy(), label_batch.numpy()
62 | 
63 |         # compute loss
64 |         activities = jpc.init_activities_with_ffwd(
65 |             model=model,
66 |             input=x,
67 |             skip_model=skip_model,
68 |             param_type=param_type
69 |         )
70 |         loss = 0.5 * np.sum((y - activities[-1])**2) / batch_size
71 | 
72 |         # compute theoretical activities & energy
73 |         activities = jpc.compute_linear_activity_solution(
74 |             network=model,
75 |             x=x,
76 |             y=y,
77 |             use_skips=use_skips,
78 |             param_type=param_type
79 |         )
80 |         energy = jpc.pc_energy_fn(
81 |             params=(model, skip_model),
82 |             activities=activities,
83 |             y=y,
84 |             x=x,
85 |             param_type=param_type
86 |         )
87 | 
88 |         # update parameters
89 |         param_update_result = jpc.update_params(
90 |             params=(model, skip_model),
91 |             activities=activities,
92 |             optim=param_optim,
93 |             opt_state=param_opt_state,
94 |             output=y,
95 |             input=x,
96 |             param_type=param_type
97 |         )
98 |         model = param_update_result["model"]
99 |         skip_model = param_update_result["skip_model"]
100 |         param_opt_state = param_update_result["opt_state"]
101 | 
102 |         train_losses.append(loss)
103 |         train_energies.append(energy)
104 |         loss_energy_ratios.append(loss/energy)
105 | 
106 |         if t % 200 == 0:
107 |             print(
108 |                 f"\t{t * len(img_batch)}/{len(train_loader.dataset)}, "
109 |                 f"loss: {loss:.4f}, energy: {energy:.4f}, ratio: {loss/energy:.4f} "
110 |             )
111 | 
112 |     np.save(f"{save_dir}/train_losses.npy", train_losses)
113 |     np.save(f"{save_dir}/train_energies.npy", train_energies)
114 |     np.save(f"{save_dir}/loss_energy_ratios.npy", loss_energy_ratios)
115 | 
116 | 
117 | if __name__ == "__main__":
118 | 
119 |     RESULTS_DIR = "energy_theory_results"
120 |     DATASET = "MNIST"
121 |     USE_SKIPS = True
122 |     PARAM_OPTIM_ID = "adam"
123 |     PARAM_LR = 1e-2
124 |     BATCH_SIZE = 64
125 | 
126 |     ACT_FNS = ["linear"]
127 |     PARAM_TYPES = ["mupc"]  # add "sp" to also run the standard parameterisation
128 |     WIDTHS = [2**i for i in range(7)]  # 7 or 10 max
129 |     N_HIDDENS = [2**i for i
in range(7)] # 4 or 7 max 130 | SEED = 4320 131 | 132 | for act_fn in ACT_FNS: 133 | for param_type in PARAM_TYPES: 134 | for width in WIDTHS: 135 | for n_hidden in N_HIDDENS: 136 | print( 137 | f"\nAct fn: {act_fn}\n" 138 | f"Param type: {param_type}\n" 139 | f"Use skips: {USE_SKIPS}\n" 140 | f"Param optim: {PARAM_OPTIM_ID}\n" 141 | f"Param lr: {PARAM_LR}\n" 142 | f"Width: {width}\n" 143 | f"N hidden: {n_hidden}\n" 144 | f"Seed: {SEED}\n" 145 | ) 146 | save_dir = os.path.join( 147 | RESULTS_DIR, 148 | act_fn, 149 | param_type, 150 | "skips" if USE_SKIPS else "no_skips", 151 | PARAM_OPTIM_ID, 152 | f"param_lr_{PARAM_LR}", 153 | f"width_{width}", 154 | f"{n_hidden}_n_hidden", 155 | str(SEED) 156 | ) 157 | run_test( 158 | seed=SEED, 159 | dataset=DATASET, 160 | width=width, 161 | n_hidden=n_hidden, 162 | act_fn=act_fn, 163 | use_skips=USE_SKIPS, 164 | param_type=param_type, 165 | param_optim_id=PARAM_OPTIM_ID, 166 | param_lr=PARAM_LR, 167 | batch_size=BATCH_SIZE, 168 | save_dir=save_dir 169 | ) 170 | -------------------------------------------------------------------------------- /jpc/_core/_infer.py: -------------------------------------------------------------------------------- 1 | """Function to solve the inference (activity) dynamics of PC networks.""" 2 | 3 | from jaxtyping import PyTree, ArrayLike, Array, Scalar 4 | import jax.numpy as jnp 5 | from typing import Tuple, Callable, Optional 6 | from ._grads import neg_pc_activity_grad 7 | from optimistix import rms_norm 8 | from diffrax import ( 9 | AbstractSolver, 10 | AbstractStepSizeController, 11 | Heun, 12 | PIDController, 13 | diffeqsolve, 14 | ODETerm, 15 | Event, 16 | SaveAt 17 | ) 18 | 19 | 20 | def solve_inference( 21 | params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], 22 | activities: PyTree[ArrayLike], 23 | output: ArrayLike, 24 | *, 25 | input: Optional[ArrayLike] = None, 26 | loss_id: str = "mse", 27 | param_type: str = "sp", 28 | solver: AbstractSolver = Heun(), 29 | max_t1: int = 20, 30 | dt: float | int = None, 31 | stepsize_controller: AbstractStepSizeController = PIDController( 32 | rtol=1e-3, atol=1e-3 33 | ), 34 | weight_decay: Scalar = 0., 35 | spectral_penalty: Scalar = 0., 36 | activity_decay: Scalar = 0., 37 | record_iters: bool = False, 38 | record_every: int = None 39 | ) -> PyTree[Array]: 40 | """Solves the inference (activity) dynamics of a predictive coding network. 41 | 42 | This is a wrapper around [`diffrax.diffeqsolve()`](https://docs.kidger.site/diffrax/api/diffeqsolve/#diffrax.diffeqsolve) 43 | to integrate the gradient ODE system [`jpc.neg_activity_grad()`](https://thebuckleylab.github.io/jpc/api/Gradients/#jpc.neg_activity_grad) 44 | defining the PC inference dynamics. 45 | 46 | $$ 47 | d\mathbf{z} / dt = - ∇_{\mathbf{z}} \mathcal{F} 48 | $$ 49 | 50 | where $\mathcal{F}$ is the free energy, $\mathbf{z}$ are the activities, 51 | with $\mathbf{z}_L$ clamped to some target and $\mathbf{z}_0$ optionally 52 | set to some prior. 53 | 54 | **Main arguments:** 55 | 56 | - `params`: Tuple with callable model layers and optional skip connections. 57 | - `activities`: List of activities for each layer free to vary. 58 | - `output`: Observation or target of the generative model. 59 | 60 | **Other arguments:** 61 | 62 | - `input`: Optional prior of the generative model. 63 | - `loss_id`: Loss function to use at the output layer. Options are mean squared 64 | error `"mse"` (default) or cross-entropy `"ce"`. 65 | - `param_type`: Determines the parameterisation. 
Options are `"sp"` 66 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))), 67 | or `"ntp"` (neural tangent parameterisation). 68 | See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings) 69 | for the specific scalings of these different parameterisations. Defaults 70 | to `"sp"`. 71 | - `solver`: [diffrax ODE solver](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/) 72 | to be used. Default is [`Heun`](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/#diffrax.Heun), 73 | a 2nd order explicit Runge--Kutta method. 74 | - `max_t1`: Maximum end of integration region (20 by default). 75 | - `dt`: Integration step size. Defaults to `None` since the default 76 | `stepsize_controller` will automatically determine it. 77 | - `stepsize_controller`: [diffrax controller](https://docs.kidger.site/diffrax/api/stepsize_controller/) 78 | for step size integration. Defaults to [`PIDController`](https://docs.kidger.site/diffrax/api/stepsize_controller/#diffrax.PIDController). 79 | Note that the relative and absolute tolerances of the controller will 80 | also determine the steady state to terminate the solver. 81 | - `weight_decay`: $\ell^2$ regulariser for the weights (0 by default). 82 | - `spectral_penalty`: Weight spectral penalty of the form 83 | $||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2$ (0 by default). 84 | - `activity_decay`: $\ell^2$ regulariser for the activities (0 by default). 85 | - `record_iters`: If `True`, returns all integration steps. 86 | - `record_every`: int determining the sampling frequency of the integration 87 | steps. 88 | 89 | **Returns:** 90 | 91 | List with solution of the activity dynamics for each layer. 
92 | 
93 |     """
94 |     if record_every is not None:
95 |         ts = jnp.arange(0, max_t1, record_every)
96 |         saveat = SaveAt(t1=True, ts=ts)
97 |     else:
98 |         saveat = SaveAt(t1=True, steps=record_iters)
99 | 
100 |     solution = diffeqsolve(
101 |         terms=ODETerm(neg_pc_activity_grad),
102 |         solver=solver,
103 |         t0=0,
104 |         t1=max_t1,
105 |         dt0=dt,
106 |         y0=activities,
107 |         args=(
108 |             params,
109 |             output,
110 |             input,
111 |             loss_id,
112 |             param_type,
113 |             weight_decay,
114 |             spectral_penalty,
115 |             activity_decay,
116 |             stepsize_controller
117 |         ),
118 |         stepsize_controller=stepsize_controller,
119 |         event=Event(steady_state_event_with_timeout),
120 |         saveat=saveat
121 |     )
122 |     return solution.ys
123 | 
124 | 
125 | def steady_state_event_with_timeout(t, y, args, **kwargs):
126 |     _stepsize_controller = args[-1]
127 |     try:
128 |         _atol = _stepsize_controller.atol
129 |         _rtol = _stepsize_controller.rtol
130 |     except AttributeError:  # controller does not expose atol/rtol
131 |         _atol, _rtol = 1e-3, 1e-3
132 |     steady_state_reached = rms_norm(y) < _atol + _rtol * rms_norm(y)
133 |     timeout_reached = jnp.array(t >= 4096, dtype=jnp.bool_)
134 |     return jnp.logical_or(steady_state_reached, timeout_reached)
135 | 
-------------------------------------------------------------------------------- /tests/test_analytical.py: --------------------------------------------------------------------------------
1 | """Tests for analytical tools."""
2 | 
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | import equinox.nn as nn
7 | from jpc import (
8 |     compute_linear_equilib_energy,
9 |     compute_linear_activity_hessian,
10 |     compute_linear_activity_solution
11 | )
12 | 
13 | 
14 | def test_compute_linear_equilib_energy(key, x, y, input_dim, hidden_dim, output_dim, depth):
15 |     """Test computation of linear equilibrium energy."""
16 |     # Create a linear network (no activation functions)
17 |     subkeys = jax.random.split(key, depth)
18 |     network = []
19 |     for i in range(depth):
20 |         _in = input_dim if i == 0 else hidden_dim
21 |         _out = output_dim if (i + 1) == depth else hidden_dim
22 |         linear = nn.Linear(_in, _out, use_bias=False, key=subkeys[i])
23 |         network.append(nn.Sequential([nn.Lambda(lambda x: x), linear]))
24 | 
25 |     energy = compute_linear_equilib_energy(
26 |         network=network,
27 |         x=x,
28 |         y=y
29 |     )
30 | 
31 |     assert jnp.isfinite(energy)
32 |     assert energy >= 0
33 | 
34 | 
35 | def test_compute_linear_activity_hessian(key, input_dim, hidden_dim, output_dim, depth):
36 |     """Test computation of linear activity Hessian."""
37 |     # Extract weight matrices
38 |     subkeys = jax.random.split(key, depth)
39 |     Ws = []
40 |     for i in range(depth):
41 |         _in = input_dim if i == 0 else hidden_dim
42 |         _out = output_dim if (i + 1) == depth else hidden_dim
43 |         W = jax.random.normal(subkeys[i], (_out, _in))
44 |         Ws.append(W)
45 | 
46 |     hessian = compute_linear_activity_hessian(
47 |         Ws=Ws,
48 |         use_skips=False,
49 |         param_type="sp",
50 |         activity_decay=False,
51 |         diag=True,
52 |         off_diag=True
53 |     )
54 | 
55 |     # Check shape: should be (sum of hidden layer sizes) x (sum of hidden layer sizes)
56 |     hidden_sizes = [hidden_dim] * (depth - 1)
57 |     expected_size = sum(hidden_sizes)
58 |     assert hessian.shape == (expected_size, expected_size)
59 |     assert jnp.all(jnp.isfinite(hessian))
60 | 
61 | 
62 | def test_compute_linear_activity_hessian_with_skips(key, input_dim, hidden_dim, output_dim, depth):
63 |     """Test computation of linear activity Hessian with skip connections."""
64 |     subkeys = jax.random.split(key, depth)
65 |     Ws = []
66 |     for i in range(depth):
67 |         _in = input_dim if i == 0 else
hidden_dim 68 | _out = output_dim if (i + 1) == depth else hidden_dim 69 | W = jax.random.normal(subkeys[i], (_out, _in)) 70 | Ws.append(W) 71 | 72 | hessian = compute_linear_activity_hessian( 73 | Ws=Ws, 74 | use_skips=True, 75 | param_type="sp", 76 | activity_decay=False, 77 | diag=True, 78 | off_diag=True 79 | ) 80 | 81 | hidden_sizes = [hidden_dim] * (depth - 1) 82 | expected_size = sum(hidden_sizes) 83 | assert hessian.shape == (expected_size, expected_size) 84 | 85 | 86 | def test_compute_linear_activity_hessian_different_param_types(key, input_dim, hidden_dim, output_dim, depth): 87 | """Test Hessian computation with different parameter types.""" 88 | subkeys = jax.random.split(key, depth) 89 | Ws = [] 90 | for i in range(depth): 91 | _in = input_dim if i == 0 else hidden_dim 92 | _out = output_dim if (i + 1) == depth else hidden_dim 93 | W = jax.random.normal(subkeys[i], (_out, _in)) 94 | Ws.append(W) 95 | 96 | # Note: The library code uses "ntp" in _analytical.py but "ntk" in _errors.py 97 | # For now, test with "sp" and "mupc" which work, and skip "ntk"/"ntp" due to inconsistency 98 | for param_type in ["sp", "mupc"]: 99 | hessian = compute_linear_activity_hessian( 100 | Ws=Ws, 101 | use_skips=False, 102 | param_type=param_type, 103 | activity_decay=False, 104 | diag=True, 105 | off_diag=True 106 | ) 107 | 108 | hidden_sizes = [hidden_dim] * (depth - 1) 109 | expected_size = sum(hidden_sizes) 110 | assert hessian.shape == (expected_size, expected_size) 111 | 112 | 113 | def test_compute_linear_activity_solution(key, x, y, input_dim, hidden_dim, output_dim, depth): 114 | """Test computation of linear activity solution.""" 115 | # Create a linear network 116 | subkeys = jax.random.split(key, depth) 117 | network = [] 118 | for i in range(depth): 119 | _in = input_dim if i == 0 else hidden_dim 120 | _out = output_dim if (i + 1) == depth else hidden_dim 121 | linear = nn.Linear(_in, _out, use_bias=False, key=subkeys[i]) 122 | network.append(nn.Sequential([nn.Lambda(lambda x: x), linear])) 123 | 124 | activities = compute_linear_activity_solution( 125 | network=network, 126 | x=x, 127 | y=y, 128 | use_skips=False, 129 | param_type="sp", 130 | activity_decay=False 131 | ) 132 | 133 | # Should return activities for hidden layers plus dummy target prediction 134 | assert len(activities) == depth 135 | assert activities[0].shape == (x.shape[0], hidden_dim) 136 | assert activities[-1].shape == (x.shape[0], output_dim) 137 | 138 | 139 | def test_compute_linear_activity_solution_with_skips(key, x, y, input_dim, hidden_dim, output_dim, depth): 140 | """Test computation of linear activity solution with skip connections.""" 141 | subkeys = jax.random.split(key, depth) 142 | network = [] 143 | for i in range(depth): 144 | _in = input_dim if i == 0 else hidden_dim 145 | _out = output_dim if (i + 1) == depth else hidden_dim 146 | linear = nn.Linear(_in, _out, use_bias=False, key=subkeys[i]) 147 | network.append(nn.Sequential([nn.Lambda(lambda x: x), linear])) 148 | 149 | activities = compute_linear_activity_solution( 150 | network=network, 151 | x=x, 152 | y=y, 153 | use_skips=True, 154 | param_type="sp", 155 | activity_decay=False 156 | ) 157 | 158 | assert len(activities) == depth 159 | 160 | -------------------------------------------------------------------------------- /jpc/_core/_init.py: -------------------------------------------------------------------------------- 1 | """Functions to initialise the layer activities of PC networks.""" 2 | 3 | from jax import vmap, random 4 | import 
jax.numpy as jnp 5 | import equinox as eqx 6 | from ._energies import _get_param_scalings 7 | from jaxtyping import PyTree, ArrayLike, Array, PRNGKeyArray, Scalar 8 | from typing import Callable, Optional 9 | from ._errors import _check_param_type 10 | 11 | 12 | @eqx.filter_jit 13 | def init_activities_with_ffwd( 14 | model: PyTree[Callable], 15 | input: ArrayLike, 16 | *, 17 | skip_model: Optional[PyTree[Callable]] = None, 18 | param_type: str = "sp" 19 | ) -> PyTree[Array]: 20 | """Initialises the layers' activity with a feedforward pass 21 | $\{ f_\ell(\mathbf{z}_{\ell-1}) \}_{\ell=1}^L$ where $f_\ell(\cdot)$ is some 22 | callable layer transformation and $\mathbf{z}_0 = \mathbf{x}$ is the input. 23 | 24 | !!! warning 25 | 26 | `param_type = "mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))) assumes 27 | that one is using [`jpc.make_mlp()`](https://thebuckleylab.github.io/jpc/api/Utils/#jpc.make_mlp) 28 | to create the model. 29 | 30 | **Main arguments:** 31 | 32 | - `model`: List of callable model (e.g. neural network) layers. 33 | - `input`: input to the model. 34 | 35 | **Other arguments:** 36 | 37 | - `skip_model`: Optional skip connection model. 38 | - `param_type`: Determines the parameterisation. Options are `"sp"` 39 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))), 40 | or `"ntp"` (neural tangent parameterisation). 41 | See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings) 42 | for the specific scalings of these different parameterisations. Defaults 43 | to `"sp"`. 44 | 45 | **Returns:** 46 | 47 | List with activity values of each layer. 48 | 49 | """ 50 | _check_param_type(param_type) 51 | 52 | L = len(model) 53 | if skip_model is None: 54 | skip_model = [None] * len(model) 55 | 56 | scalings = _get_param_scalings( 57 | model=model, 58 | input=input, 59 | skip_model=skip_model, 60 | param_type=param_type 61 | ) 62 | 63 | z1 = scalings[0] * vmap(model[0])(input) 64 | if skip_model[0] is not None: 65 | z1 += vmap(skip_model[0])(input) 66 | 67 | activities = [z1] 68 | for l in range(1, L): 69 | zl = scalings[l] * vmap(model[l])(activities[l - 1]) 70 | 71 | if skip_model[l] is not None: 72 | skip_output = vmap(skip_model[l])(activities[l - 1]) 73 | zl += skip_output 74 | 75 | activities.append(zl) 76 | 77 | return activities 78 | 79 | 80 | def init_activities_from_normal( 81 | key: PRNGKeyArray, 82 | layer_sizes: PyTree[int], 83 | mode: str, 84 | batch_size: int, 85 | sigma: Scalar = 0.05 86 | ) -> PyTree[Array]: 87 | """Initialises network activities from a zero-mean Gaussian 88 | $z_i \sim \mathcal{N}(0, \sigma^2)$. 89 | 90 | **Main arguments:** 91 | 92 | - `key`: `jax.random.PRNGKey` for sampling. 93 | - `layer_sizes`: List with dimension of all layers (input, hidden and 94 | output). 95 | - `mode`: If `"supervised"`, all hidden layers are initialised. If 96 | `"unsupervised"` the input layer $\mathbf{z}_0$ is also initialised. 97 | - `batch_size`: Dimension of data batch. 98 | - `sigma`: Standard deviation for Gaussian to sample activities from. 99 | Defaults to 5e-2. 100 | 101 | **Returns:** 102 | 103 | List of randomly initialised activities for each layer. 
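
    **Example:**

    A minimal sketch (the layer sizes and batch size are illustrative
    assumptions, mirroring the test suite):

    ```py
    import jax.random as jr
    import jpc

    layer_sizes = [10, 32, 32, 3]  # input, two hidden, output
    activities = jpc.init_activities_from_normal(
        key=jr.PRNGKey(0),
        layer_sizes=layer_sizes,
        mode="supervised",
        batch_size=8,
        sigma=0.05
    )
    # in supervised mode the input layer is clamped to the data,
    # so only len(layer_sizes) - 1 layers are sampled
    assert len(activities) == len(layer_sizes) - 1
    ```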
104 | 
105 |     """
106 |     start_l = 0 if mode == "unsupervised" else 1
107 |     n_layers = len(layer_sizes) if mode == "unsupervised" else len(layer_sizes)-1
108 |     activities = []
109 |     for l, subkey in zip(
110 |             range(start_l, len(layer_sizes)),
111 |             random.split(key, num=n_layers)
112 |     ):
113 |         activities.append(sigma * random.normal(
114 |             subkey,
115 |             shape=(batch_size, layer_sizes[l])
116 |         )
117 |     )
118 |     return activities
119 | 
120 | 
121 | def init_activities_with_amort(
122 |     amortiser: PyTree[Callable],
123 |     generator: PyTree[Callable],
124 |     input: ArrayLike
125 | ) -> PyTree[Array]:
126 |     """Initialises layers' activity with an amortised network
127 |     $\{ f_{L-\ell+1}(\mathbf{z}_{L-\ell}) \}_{\ell=1}^L$ where $\mathbf{z}_0 = \mathbf{y}$ is
128 |     the input or generator's target.
129 | 
130 |     !!! note
131 | 
132 |         The output order is reversed for downstream use by the generator.
133 | 
134 |     **Main arguments:**
135 | 
136 |     - `amortiser`: List of callable layers for model amortising the inference
137 |         of the `generator`.
138 |     - `generator`: List of callable layers for the generative model.
139 |     - `input`: Input to the amortiser.
140 | 
141 |     **Returns:**
142 | 
143 |     List with amortised initialisation of each layer.
144 | 
145 |     """
146 |     activities = [vmap(amortiser[0])(input)]
147 |     for l in range(1, len(amortiser)):
148 |         activities.append(vmap(amortiser[l])(activities[l - 1]))
149 | 
150 |     activities = activities[::-1]
151 | 
152 |     # NOTE: this dummy activity for the last layer is added in case one is
153 |     # interested in inspecting the generator's target prediction during inference.
154 |     activities.append(
155 |         vmap(generator[-1])(activities[-1])
156 |     )
157 |     return activities
158 | 
159 | 
160 | def init_epc_errors(
161 |     layer_sizes: PyTree[int],
162 |     batch_size: int,
163 |     mode: str = "supervised"
164 | ) -> PyTree[Array]:
165 |     """Initialises zero errors for use with ePC $\{ \epsilon_\ell = 0 \}_{\ell=1}^L$.
166 | 
167 |     **Main arguments:**
168 | 
169 |     - `layer_sizes`: List with dimension of all layers (input, hidden and
170 |         output).
171 |     - `batch_size`: Dimension of data batch.
172 |     - `mode`: If `"supervised"`, errors are initialised for layers 1 to L
173 |         (all layers except the input). If `"unsupervised"`, an error is also
174 |         initialised for the input layer 0. Defaults to `"supervised"`.
175 | 
176 |     **Returns:**
177 | 
178 |     List of zero-initialised error arrays for each layer.
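
    **Example:**

    A minimal sketch (sizes are illustrative assumptions, mirroring the test
    suite):

    ```py
    import jax.numpy as jnp
    import jpc

    errors = jpc.init_epc_errors(
        layer_sizes=[10, 32, 32, 3],
        batch_size=8,
        mode="supervised"
    )
    # one zero-initialised error array per layer except the input
    assert len(errors) == 3
    assert all(jnp.allclose(e, 0.0) for e in errors)
    ```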
179 | 
180 |     """
181 |     start_l = 0 if mode == "unsupervised" else 1
182 |     # one error array per layer, excluding the input layer in supervised mode
183 |     errors = []
184 |     for l in range(start_l, len(layer_sizes)):
185 |         errors.append(jnp.zeros(shape=(batch_size, layer_sizes[l])))
186 |     return errors
187 | 
-------------------------------------------------------------------------------- /experiments/datasets.py: --------------------------------------------------------------------------------
1 | import jax.random as jr
2 | import jax.numpy as jnp
3 | 
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torchvision import datasets, transforms
7 | 
8 | 
9 | DATA_DIR = "datasets"
10 | IMAGENET_DIR = f"~/projects/jpc/experiments/{DATA_DIR}/ImageNet"
11 | 
12 | 
13 | def make_gaussian_dataset(key, mean, std, shape):
14 |     x = mean + std * jr.normal(key, shape)
15 |     y = x
16 |     return (x, y)
17 | 
18 | 
19 | def get_dataloaders(dataset_id, batch_size, flatten=True):
20 |     train_data = get_dataset(
21 |         id=dataset_id,
22 |         train=True,
23 |         normalise=True,
24 |         flatten=flatten
25 |     )
26 |     test_data = get_dataset(
27 |         id=dataset_id,
28 |         train=False,
29 |         normalise=True,
30 |         flatten=flatten
31 |     )
32 |     train_loader = DataLoader(
33 |         dataset=train_data,
34 |         batch_size=batch_size,
35 |         shuffle=True,
36 |         drop_last=True
37 |     )
38 |     test_loader = DataLoader(
39 |         dataset=test_data,
40 |         batch_size=batch_size,
41 |         shuffle=True,
42 |         drop_last=True
43 |     )
44 |     return train_loader, test_loader
45 | 
46 | 
47 | def get_dataset(id, train, normalise, flatten=True):
48 |     if id == "MNIST":
49 |         dataset = MNIST(train=train, normalise=normalise, flatten=flatten)
50 |     elif id == "Fashion-MNIST":
51 |         dataset = FashionMNIST(train=train, normalise=normalise, flatten=flatten)
52 |     elif id == "CIFAR10":
53 |         dataset = CIFAR10(train=train, normalise=normalise, flatten=flatten)
54 |     else:
55 |         raise ValueError(
56 |             "Invalid dataset ID. 
Options are `MNIST`, `Fashion-MNIST` and `CIFAR10`" 57 | ) 58 | return dataset 59 | 60 | 61 | def get_imagenet_loaders(batch_size): 62 | train_data, val_data = ImageNet(split="train"), ImageNet(split="val") 63 | train_loader = DataLoader( 64 | dataset=train_data, 65 | batch_size=batch_size, 66 | shuffle=True, 67 | drop_last=True, 68 | num_workers=32, 69 | persistent_workers=True 70 | ) 71 | val_loader = DataLoader( 72 | dataset=val_data, 73 | batch_size=batch_size, 74 | shuffle=True, 75 | drop_last=True, 76 | num_workers=32, 77 | persistent_workers=True 78 | ) 79 | return train_loader, val_loader 80 | 81 | 82 | class MNIST(datasets.MNIST): 83 | def __init__(self, train, normalise=True, flatten=True, save_dir=DATA_DIR): 84 | self.flatten = flatten 85 | if normalise: 86 | transform = transforms.Compose( 87 | [ 88 | transforms.ToTensor(), 89 | transforms.Normalize( 90 | mean=(0.1307), std=(0.3081) 91 | ) 92 | ] 93 | ) 94 | else: 95 | transform = transforms.Compose([transforms.ToTensor()]) 96 | super().__init__(save_dir, download=True, train=train, transform=transform) 97 | 98 | def __getitem__(self, index): 99 | img, label = super().__getitem__(index) 100 | if self.flatten: 101 | img = torch.flatten(img) 102 | label = one_hot(label, n_classes=10) 103 | return img, label 104 | 105 | 106 | class FashionMNIST(datasets.FashionMNIST): 107 | def __init__(self, train, normalise=True, flatten=True, save_dir=DATA_DIR): 108 | self.flatten = flatten 109 | if normalise: 110 | transform = transforms.Compose( 111 | [ 112 | transforms.ToTensor(), 113 | transforms.Normalize( 114 | mean=(0.5), std=(0.5) 115 | ) 116 | ] 117 | ) 118 | else: 119 | transform = transforms.Compose([transforms.ToTensor()]) 120 | super().__init__(save_dir, download=True, train=train, transform=transform) 121 | 122 | def __getitem__(self, index): 123 | img, label = super().__getitem__(index) 124 | if self.flatten: 125 | img = torch.flatten(img) 126 | label = one_hot(label) 127 | return img, label 128 | 129 | 130 | class CIFAR10(datasets.CIFAR10): 131 | def __init__(self, train, normalise=True, flatten=True, save_dir=f"{DATA_DIR}/CIFAR10"): 132 | self.flatten = flatten 133 | if normalise: 134 | if train: 135 | transform = transforms.Compose( 136 | [ 137 | transforms.Resize((32,32)), 138 | transforms.RandomCrop(32, padding=4), 139 | transforms.RandomHorizontalFlip(), 140 | transforms.RandomRotation(10), 141 | transforms.ToTensor(), 142 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 143 | ] 144 | ) 145 | else: 146 | transform = transforms.Compose( 147 | [ 148 | transforms.Resize((32,32)), 149 | transforms.ToTensor(), 150 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 151 | ] 152 | ) 153 | else: 154 | transform = transforms.Compose([transforms.ToTensor()]) 155 | super().__init__(save_dir, download=True, train=train, transform=transform) 156 | 157 | def __getitem__(self, index): 158 | img, label = super().__getitem__(index) 159 | if self.flatten: 160 | img = torch.flatten(img) 161 | label = one_hot(label) 162 | return img, label 163 | 164 | 165 | class ImageNet(datasets.ImageNet): 166 | def __init__(self, split): 167 | if split == "train": 168 | transform = transforms.Compose([ 169 | transforms.RandomResizedCrop(224), 170 | transforms.RandomHorizontalFlip(), 171 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), 172 | transforms.ToTensor(), 173 | transforms.Normalize( 174 | mean=[0.485, 0.456, 0.406], 175 | std=[0.229, 0.224, 0.225]) 176 | ]) 177 | elif split == "val": 178 | transform = transforms.Compose( 179 | [ 180 | 
transforms.Resize(256), 181 | transforms.CenterCrop(224), 182 | transforms.ToTensor(), 183 | transforms.Normalize( 184 | mean=[0.485, 0.456, 0.406], 185 | std=[0.229, 0.224, 0.225] 186 | ), 187 | ] 188 | ) 189 | super().__init__(root=IMAGENET_DIR, split=split, transform=transform) 190 | 191 | def __getitem__(self, index): 192 | img, label = super().__getitem__(index) 193 | label = one_hot(label, n_classes=1000) 194 | return img, label 195 | 196 | 197 | def one_hot(labels, n_classes=10): 198 | arr = torch.eye(n_classes) 199 | return arr[labels] 200 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for utility functions.""" 2 | 3 | import pytest 4 | import jax 5 | import jax.numpy as jnp 6 | from jpc import ( 7 | make_mlp, 8 | make_skip_model, 9 | get_act_fn, 10 | mse_loss, 11 | cross_entropy_loss, 12 | compute_accuracy, 13 | get_t_max, 14 | compute_activity_norms, 15 | compute_param_norms, 16 | compute_infer_energies 17 | ) 18 | 19 | 20 | def test_make_mlp(key, input_dim, hidden_dim, output_dim, depth): 21 | """Test MLP creation.""" 22 | model = make_mlp( 23 | key=key, 24 | input_dim=input_dim, 25 | width=hidden_dim, 26 | depth=depth, 27 | output_dim=output_dim, 28 | act_fn="relu", 29 | use_bias=False, 30 | param_type="sp" 31 | ) 32 | 33 | assert len(model) == depth 34 | assert model[0][1].weight.shape == (hidden_dim, input_dim) 35 | assert model[-1][1].weight.shape == (output_dim, hidden_dim) 36 | 37 | 38 | def test_make_mlp_different_activations(key, input_dim, hidden_dim, output_dim, depth): 39 | """Test MLP creation with different activation functions.""" 40 | for act_fn in ["linear", "tanh", "relu", "gelu", "silu"]: 41 | model = make_mlp( 42 | key=key, 43 | input_dim=input_dim, 44 | width=hidden_dim, 45 | depth=depth, 46 | output_dim=output_dim, 47 | act_fn=act_fn, 48 | use_bias=False, 49 | param_type="sp" 50 | ) 51 | 52 | assert len(model) == depth 53 | 54 | 55 | def test_make_mlp_with_bias(key, input_dim, hidden_dim, output_dim, depth): 56 | """Test MLP creation with bias.""" 57 | model = make_mlp( 58 | key=key, 59 | input_dim=input_dim, 60 | width=hidden_dim, 61 | depth=depth, 62 | output_dim=output_dim, 63 | act_fn="relu", 64 | use_bias=True, 65 | param_type="sp" 66 | ) 67 | 68 | assert len(model) == depth 69 | 70 | 71 | def test_make_mlp_different_param_types(key, input_dim, hidden_dim, output_dim, depth): 72 | """Test MLP creation with different parameter types.""" 73 | for param_type in ["sp", "mupc", "ntp"]: 74 | model = make_mlp( 75 | key=key, 76 | input_dim=input_dim, 77 | width=hidden_dim, 78 | depth=depth, 79 | output_dim=output_dim, 80 | act_fn="relu", 81 | use_bias=False, 82 | param_type=param_type 83 | ) 84 | 85 | assert len(model) == depth 86 | 87 | 88 | def test_make_skip_model(depth): 89 | """Test skip model creation.""" 90 | skip_model = make_skip_model(depth) 91 | 92 | assert len(skip_model) == depth 93 | # First and last should be None 94 | assert skip_model[0] is None 95 | assert skip_model[-1] is None 96 | 97 | 98 | def test_get_act_fn(): 99 | """Test activation function retrieval.""" 100 | act_fns = ["linear", "tanh", "hard_tanh", "relu", "leaky_relu", "gelu", "selu", "silu"] 101 | 102 | for act_fn_name in act_fns: 103 | act_fn = get_act_fn(act_fn_name) 104 | assert callable(act_fn) 105 | 106 | # Test invalid activation function 107 | with pytest.raises(ValueError): 108 | get_act_fn("invalid") 109 | 110 | 111 | def 
test_mse_loss(key, batch_size, output_dim):
112 |     """Test MSE loss computation."""
113 |     preds = jax.random.normal(key, (batch_size, output_dim))
114 |     labels = jax.random.normal(jax.random.split(key)[1], (batch_size, output_dim))  # different key so labels != preds
115 | 
116 |     loss = mse_loss(preds, labels)
117 | 
118 |     assert jnp.isfinite(loss)
119 |     assert loss >= 0
120 | 
121 | 
122 | def test_cross_entropy_loss(key, batch_size, output_dim):
123 |     """Test cross-entropy loss computation."""
124 |     logits = jax.random.normal(key, (batch_size, output_dim))
125 |     labels = jax.nn.one_hot(
126 |         jax.random.randint(jax.random.split(key)[1], (batch_size,), 0, output_dim),  # different key to avoid reuse
127 |         output_dim
128 |     )
129 | 
130 |     loss = cross_entropy_loss(logits, labels)
131 | 
132 |     assert jnp.isfinite(loss)
133 |     assert loss >= 0
134 | 
135 | 
136 | def test_compute_accuracy(key, batch_size, output_dim):
137 |     """Test accuracy computation."""
138 |     truths = jax.nn.one_hot(
139 |         jax.random.randint(key, (batch_size,), 0, output_dim),
140 |         output_dim
141 |     )
142 |     preds = jax.nn.one_hot(
143 |         jax.random.randint(jax.random.split(key)[1], (batch_size,), 0, output_dim),  # different key so preds can differ from truths
144 |         output_dim
145 |     )
146 | 
147 |     acc = compute_accuracy(truths, preds)
148 | 
149 |     assert 0 <= acc <= 100
150 |     assert jnp.isfinite(acc)
151 | 
152 | 
153 | def test_get_t_max(key, batch_size, hidden_dim):
154 |     """Test t_max computation."""
155 |     # Create fake activities_iters with time dimension
156 |     # The function looks for argmax in activities_iters[0][:, 0, 0] then subtracts 1
157 |     # We need to ensure there's a valid maximum
158 |     activities_iters = [
159 |         jnp.zeros((100, batch_size, hidden_dim))
160 |     ]
161 |     # Set a value at index 10 to be non-zero so argmax returns 10, then t_max = 10 - 1 = 9
162 |     activities_iters[0] = activities_iters[0].at[10, 0, 0].set(1.0)
163 | 
164 |     t_max = get_t_max(activities_iters)
165 | 
166 |     assert jnp.isfinite(t_max)
167 |     # t_max is argmax - 1, so with value at index 10, argmax is 10, so t_max is 9
168 |     assert t_max >= 0
169 | 
170 | 
171 | def test_compute_activity_norms(simple_model, x):
172 |     """Test activity norm computation."""
173 |     from jpc import init_activities_with_ffwd
174 | 
175 |     activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
176 | 
177 |     norms = compute_activity_norms(activities)
178 | 
179 |     assert len(norms) == len(activities)
180 |     assert all(jnp.isfinite(n) and n >= 0 for n in norms)
181 | 
182 | 
183 | def test_compute_param_norms(simple_model):
184 |     """Test parameter norm computation."""
185 |     # Skip model contains Lambda functions which don't have .weight attributes
186 |     # and cause issues with compute_param_norms. Test without skip_model instead.
187 | model_norms, skip_norms = compute_param_norms((simple_model, None)) 188 | 189 | assert len(model_norms) > 0 190 | assert skip_norms is None 191 | assert all(jnp.isfinite(n) and n >= 0 for n in model_norms) 192 | 193 | 194 | def test_compute_infer_energies(simple_model, x, y): 195 | """Test inference energy computation.""" 196 | from jpc import init_activities_with_ffwd, solve_inference 197 | from diffrax import Heun, PIDController 198 | 199 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp") 200 | 201 | activities_iters = solve_inference( 202 | params=(simple_model, None), 203 | activities=activities, 204 | output=y, 205 | input=x, 206 | loss_id="mse", 207 | param_type="sp", 208 | solver=Heun(), 209 | max_t1=10, 210 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 211 | record_iters=True 212 | ) 213 | 214 | t_max = get_t_max(activities_iters) 215 | 216 | energies = compute_infer_energies( 217 | params=(simple_model, None), 218 | activities_iters=activities_iters, 219 | t_max=t_max, 220 | y=y, 221 | x=x, 222 | loss="mse", 223 | param_type="sp" 224 | ) 225 | 226 | assert energies.shape[0] == len(simple_model) 227 | assert all(jnp.all(jnp.isfinite(e)) for e in energies) 228 | -------------------------------------------------------------------------------- /experiments/library_paper/test_theory_energies.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import jax.random as jr 5 | import jax.numpy as jnp 6 | import equinox as eqx 7 | import optax 8 | import jpc 9 | 10 | from utils import set_seed 11 | from experiments.datasets import get_dataloaders 12 | 13 | import plotly.graph_objs as go 14 | import plotly.colors as pc 15 | 16 | 17 | def plot_accuracies(accuracies, save_path): 18 | n_train_iters = len(accuracies[10]) 19 | train_iters = [t+1 for t in range(n_train_iters)] 20 | 21 | colorscale = "Blues" 22 | colors = pc.sample_colorscale(colorscale, len(accuracies)+2)[2:][::-1] 23 | fig = go.Figure() 24 | for i, (max_t1, accuracy) in enumerate(accuracies.items()): 25 | fig.add_trace( 26 | go.Scatter( 27 | x=train_iters, 28 | y=accuracy, 29 | mode="lines", 30 | line=dict(width=2, color=colors[i]), 31 | name=f"$t = {max_t1}$" 32 | ) 33 | ) 34 | fig.update_layout( 35 | height=350, 36 | width=500, 37 | xaxis=dict( 38 | title="Training iteration", 39 | tickvals=[1, int(train_iters[-1]/2), train_iters[-1]], 40 | ticktext=[1, int(train_iters[-1]/2)*10, train_iters[-1]*10] 41 | ), 42 | yaxis=dict(title="Test accuracy (%)"), 43 | font=dict(size=16), 44 | margin=dict(r=120) 45 | ) 46 | fig.write_image(save_path) 47 | 48 | 49 | def plot_energies_across_ts(theory_energies, num_energies, save_path): 50 | n_train_iters = len(theory_energies) 51 | train_iters = [t+1 for t in range(n_train_iters)] 52 | 53 | colorscale = "Greens" 54 | colors = pc.sample_colorscale(colorscale, len(num_energies)+3)[2:][::-1] 55 | fig = go.Figure() 56 | fig.add_traces( 57 | go.Scatter( 58 | x=train_iters, 59 | y=theory_energies, 60 | name="theory", 61 | mode="lines", 62 | line=dict( 63 | width=3, 64 | dash="dash", 65 | color=colors[0] 66 | ), 67 | ) 68 | ) 69 | for i, (max_t1, num_energy) in enumerate(num_energies.items()): 70 | fig.add_trace( 71 | go.Scatter( 72 | x=train_iters, 73 | y=num_energy, 74 | mode="lines", 75 | line=dict(width=2, color=colors[i+1]), 76 | name=f"$t = {max_t1}$" 77 | ) 78 | ) 79 | fig.update_layout( 80 | height=350, 81 | width=500, 82 | xaxis=dict( 83 | title="Training iteration", 84 | 
tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],
85 |             ticktext=[1, int(train_iters[-1]/2), train_iters[-1]]
86 |         ),
87 |         yaxis=dict(title="Energy"),
88 |         font=dict(size=16),
89 |         margin=dict(r=120)
90 |     )
91 |     fig.write_image(save_path)
92 | 
93 | 
94 | def evaluate(model, test_loader):
95 |     avg_test_loss, avg_test_acc = 0, 0
96 |     for img_batch, label_batch in test_loader:
97 |         img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
98 | 
99 |         test_loss, test_acc = jpc.test_discriminative_pc(
100 |             model=model,
101 |             output=label_batch,
102 |             input=img_batch
103 |         )
104 |         avg_test_loss += test_loss
105 |         avg_test_acc += test_acc
106 | 
107 |     return avg_test_loss / len(test_loader), avg_test_acc / len(test_loader)
108 | 
109 | 
110 | def train(
111 |     dataset,
112 |     width,
113 |     n_hidden,
114 |     lr,
115 |     batch_size,
116 |     max_t1,
117 |     test_every,
118 |     n_train_iters,
119 |     save_dir
120 | ):
121 |     key = jr.PRNGKey(0)
122 |     input_dim = 3072 if dataset == "CIFAR10" else 784
123 |     model = jpc.make_mlp(
124 |         key,
125 |         input_dim=input_dim,
126 |         width=width,
127 |         depth=n_hidden+1,
128 |         output_dim=10,
129 |         act_fn="linear",
130 |         use_bias=False
131 |     )
132 |     optim = optax.adam(lr)
133 |     opt_state = optim.init(
134 |         (eqx.filter(model, eqx.is_array), None)
135 |     )
136 |     train_loader, test_loader = get_dataloaders(dataset, batch_size)
137 | 
138 |     test_accs = []
139 |     theory_energies, num_energies = [], []
140 |     for batch_id, (img_batch, label_batch) in enumerate(train_loader):
141 |         img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
142 | 
143 |         theory_energies.append(
144 |             jpc.compute_linear_equilib_energy(
145 |                 network=model,
146 |                 x=img_batch,
147 |                 y=label_batch
148 |             )
149 |         )
150 |         result = jpc.make_pc_step(
151 |             model,
152 |             optim,
153 |             opt_state,
154 |             output=label_batch,
155 |             input=img_batch,
156 |             max_t1=max_t1,
157 |             record_energies=True
158 |         )
159 |         model, opt_state = result["model"], result["opt_state"]
160 |         train_loss, t_max = result["loss"], result["t_max"]
161 |         num_energies.append(result["energies"][:, t_max-1].sum())
162 | 
163 |         if ((batch_id+1) % test_every) == 0:
164 |             _, avg_test_acc = evaluate(model, test_loader)
165 |             test_accs.append(avg_test_acc)
166 |             print(
167 |                 f"Train iter {batch_id+1}, train loss={train_loss:.4f}, "
168 |                 f"avg test accuracy={avg_test_acc:.4f}"
169 |             )
170 |         if (batch_id+1) >= n_train_iters:
171 |             break
172 | 
173 |     np.save(f"{save_dir}/test_accs.npy", test_accs)
174 |     np.save(f"{save_dir}/theory_energies.npy", theory_energies)
175 |     np.save(f"{save_dir}/num_energies.npy", num_energies)
176 | 
177 |     return test_accs, jnp.array(theory_energies), jnp.array(num_energies)
178 | 
179 | 
180 | if __name__ == "__main__":
181 |     RESULTS_DIR = "theory_energies_results"
182 |     DATASETS = ["MNIST", "Fashion-MNIST"]
183 |     SEED = 916
184 |     WIDTH = 300
185 |     N_HIDDEN = 10
186 |     LR = 1e-3
187 |     BATCH_SIZE = 64
188 |     MAX_T1S = [200, 100, 50, 20, 10]
189 |     TEST_EVERY = 10
190 |     N_TRAIN_ITERS = 100
191 | 
192 |     for dataset in DATASETS:
193 |         set_seed(SEED)
194 |         all_test_accs, all_theory_energies, all_num_energies = {}, {}, {}
195 |         for max_t1 in MAX_T1S:
196 |             print(f"\nmax_t1: {max_t1}")
197 |             save_dir = os.path.join(RESULTS_DIR, dataset, f"max_t1_{max_t1}")
198 |             os.makedirs(save_dir, exist_ok=True)
199 |             test_accs, theory_energies, num_energies = train(
200 |                 dataset=dataset,
201 |                 width=WIDTH,
202 |                 n_hidden=N_HIDDEN,
203 |                 lr=LR,
204 |                 batch_size=BATCH_SIZE,
205 |                 max_t1=max_t1,
206 |                 test_every=TEST_EVERY,
207 | 
n_train_iters=N_TRAIN_ITERS, 208 | save_dir=save_dir 209 | ) 210 | all_test_accs[max_t1] = test_accs 211 | all_theory_energies[max_t1] = theory_energies 212 | all_num_energies[max_t1] = num_energies 213 | 214 | plot_accuracies( 215 | all_test_accs, 216 | f"{RESULTS_DIR}/{dataset}/test_accs.pdf" 217 | ) 218 | plot_energies_across_ts( 219 | all_theory_energies[MAX_T1S[0]], 220 | all_num_energies, 221 | f"{RESULTS_DIR}/{dataset}/energies.pdf" 222 | ) 223 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | <!-- logo (HTML image tags lost in extraction) -->
3 | 
4 | 
5 | 
6 | 
7 | 🧠 Flexible Inference for Predictive Coding Networks in JAX ⚡️
8 | 
9 | ![status](https://img.shields.io/badge/status-active-green) [![Paper](https://img.shields.io/badge/Paper-arXiv:2412.03676-%23f2806bff.svg)](https://arxiv.org/abs/2412.03676)
10 | ![coverage](https://img.shields.io/badge/coverage-93%25-brightgreen)
11 | 
12 | ## 📢 Updates
13 | * [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/epc.ipynb) Nov 2025: Added "ePC" ([Goemaere et al., 2025](https://arxiv.org/abs/2505.20137))
14 | * [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/bidirectional_pc.ipynb) Oct 2025: Added bidirectional PC (bPC, [Oliviers et al., 2025](https://arxiv.org/abs/2505.23415))
15 | * [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/mupc.ipynb) May 2025: Added "μPC" ([Innocenti et al., 2025](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1)))
16 | 
17 | ---
18 | 
19 | JPC is a [**J**AX](https://github.com/google/jax) library for training neural
20 | networks with **P**redictive **C**oding (PC).
21 | 
22 | JPC provides a **simple**, **fast** and **flexible** API for training a
23 | variety of PCNs, including discriminative, generative and hybrid
24 | models.
25 | * Like JAX, JPC is completely functional in design, and the core library is
26 | <1000 lines of code.
27 | * Unlike existing implementations, JPC provides a wide range of both discrete
28 | and continuous optimisers for solving the inference dynamics of PC, including
29 | ordinary differential equation (ODE) solvers.
30 | * JPC also provides analytical tools that can be used to study and
31 | potentially diagnose issues with PCNs.
32 | 
33 | If you're new to JPC, we recommend starting from the
34 | [example notebooks](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)
35 | and checking the [documentation](https://thebuckleylab.github.io/jpc/).
36 | 
37 | ## Overview
38 | * [Installation](#-installation)
39 | * [Documentation](#-documentation)
40 | * [Quick example](#-quick-example)
41 | * [Advanced usage](#-advanced-usage)
42 | * [Contributing](#-contributing)
43 | * [Citation](#-citation)
44 | 
45 | ## ️💻 Installation
46 | Clone the repo and, in the project's directory, run
47 | ```
48 | pip install .
49 | ```
50 | 
51 | Requires Python 3.10+ and JAX 0.4.38–0.5.2 (inclusive). For GPU usage, upgrade
52 | jax to the build matching your CUDA version (CUDA 12 shown here):
53 | 
54 | ```
55 | pip install --upgrade "jax[cuda12]"
56 | ```
57 | 
58 | ## 📖 [Documentation](https://thebuckleylab.github.io/jpc/)
59 | Available at https://thebuckleylab.github.io/jpc/.
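To quickly verify the installation, a minimal sketch that just builds a small model with `jpc.make_mlp` (the same function used in the Quick example below; the specific dimensions here are arbitrary):
```py
import jax.random as jr
import jpc

# a tiny 3-layer MLP; make_mlp returns one callable per layer
model = jpc.make_mlp(
    jr.PRNGKey(0),
    input_dim=3,
    width=8,
    depth=3,
    output_dim=2,
    act_fn="relu"
)
print(len(model))  # 3
```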
60 | 
61 | ## ⚡️ Quick example
62 | Use `jpc.make_pc_step()` to update the parameters of any neural network
63 | compatible with PC updates (see the [notebook examples](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)).
64 | ```py
65 | import jax.random as jr
66 | import jax.numpy as jnp
67 | import equinox as eqx
68 | import optax
69 | import jpc
70 | 
71 | # toy data
72 | x = jnp.array([1., 1., 1.])
73 | y = -x
74 | 
75 | # define model and optimiser
76 | key = jr.PRNGKey(0)
77 | model = jpc.make_mlp(
78 |     key,
79 |     input_dim=3,
80 |     width=50,
81 |     depth=5,
82 |     output_dim=3,
83 |     act_fn="relu"
84 | )
85 | optim = optax.adam(1e-3)
86 | opt_state = optim.init(
87 |     (eqx.filter(model, eqx.is_array), None)
88 | )
89 | 
90 | # perform one training step with PC
91 | result = jpc.make_pc_step(
92 |     model=model,
93 |     optim=optim,
94 |     opt_state=opt_state,
95 |     output=y,
96 |     input=x
97 | )
98 | 
99 | # updated model and optimiser
100 | model, opt_state = result["model"], result["opt_state"]
101 | ```
102 | Under the hood, `jpc.make_pc_step()`
103 | 1. integrates the inference (activity) dynamics with a [diffrax](https://github.com/patrick-kidger/diffrax) ODE solver, and
104 | 2. updates the model parameters at the numerical solution of the activities with a given [optax](https://github.com/google-deepmind/optax) optimiser.
105 | 
106 | See the [documentation](https://thebuckleylab.github.io/jpc/) for more details.
107 | 
108 | > **NOTE**: All convenience training and test functions such as `make_pc_step()`
109 | > are already jitted for performance.
110 | 
111 | ## 🚀 Advanced usage
112 | Advanced users can access all the underlying functions of `jpc.make_pc_step()`
113 | as well as additional features. A custom PC training step looks like the
114 | following:
115 | ```py
116 | import optax
117 | import jpc
118 | 
119 | # 1. initialise activities with a feedforward pass
120 | activities = jpc.init_activities_with_ffwd(model=model, input=x)
121 | 
122 | # 2. perform inference (state optimisation)
123 | activity_optim = optax.sgd(1e-2)  # optimiser for the inference dynamics
124 | activity_opt_state = activity_optim.init(activities)
125 | for _ in range(len(model)):
126 |     activity_update_result = jpc.update_activities(
127 |         params=(model, None),
128 |         activities=activities,
129 |         optim=activity_optim,
130 |         opt_state=activity_opt_state,
131 |         output=y,
132 |         input=x
133 |     )
134 |     activities = activity_update_result["activities"]
135 |     activity_opt_state = activity_update_result["opt_state"]
136 | 
137 | # 3. update parameters at the converged activities with PC
138 | result = jpc.update_params(
139 |     params=(model, None),
140 |     activities=activities,
141 |     optim=optim,
142 |     opt_state=opt_state,
143 |     output=y,
144 |     input=x
145 | )
146 | ```
147 | which can be embedded in a jitted function along with any additional
148 | computations. Again, see the [docs](https://thebuckleylab.github.io/jpc/)
149 | for details.
150 | 
151 | ## 🤝 Contributing
152 | Contributions are welcome! Fork the repo, install in editable mode (`pip install -e .`), then:
153 | * Run `ruff check .` before committing (auto-fix with `ruff check --fix .`)
154 | * Ensure all tests pass: `pytest tests/`
155 | * Add docstrings to public functions and update `docs/` for user-facing changes
156 | * Open a PR with a clear description
157 | For major features, open an issue first to discuss. A typical contributor session is sketched below.
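For reference, a typical contributor session might look like this (a sketch; replace `<your-username>` with your GitHub fork):
```
git clone https://github.com/<your-username>/jpc.git
cd jpc
pip install -e .
ruff check --fix .   # lint and auto-fix before committing
pytest tests/        # make sure the test suite passes
```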
158 | 159 | ## 📄 Citation 160 | If you found this library useful in your work, please cite ([paper link](https://arxiv.org/abs/2412.03676)): 161 | 162 | ```bibtex 163 | @article{innocenti2024jpc, 164 | title={JPC: Flexible Inference for Predictive Coding Networks in JAX}, 165 | author={Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Varona, Miguel De Llanza and Singh, Ryan and Buckley, Christopher L}, 166 | journal={arXiv preprint arXiv:2412.03676}, 167 | year={2024} 168 | } 169 | ``` 170 | Also consider starring the repo! ⭐️ 171 | 172 | ## 🙏 Acknowledgements 173 | We are grateful to Patrick Kidger for early advice on how to use Diffrax. 174 | 175 | ## See also: other PC libraries 176 | * [ngc-learn](https://github.com/NACLab/ngc-learn) (jax & pytorch) 177 | * [pcx](https://github.com/liukidar/pcx) (jax) 178 | * [pyhgf](https://github.com/ComputationalPsychiatry/pyhgf) (jax) 179 | * [Torch2PC](https://github.com/RobertRosenbaum/Torch2PC) (pytorch) 180 | * [pypc](https://github.com/infer-actively/pypc) (pytorch) 181 | * [pybrid](https://github.com/alec-tschantz/pybrid) (pytorch) 182 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | """Tests for training functions.""" 2 | 3 | import pytest 4 | import jax 5 | import jax.numpy as jnp 6 | import optax 7 | from diffrax import Heun, PIDController 8 | from jpc import make_pc_step, make_hpc_step 9 | 10 | 11 | def test_make_pc_step_supervised(simple_model, x, y): 12 | """Test PC training step in supervised mode.""" 13 | optim = optax.sgd(learning_rate=0.01) 14 | opt_state = optim.init((simple_model, None)) 15 | 16 | result = make_pc_step( 17 | model=simple_model, 18 | optim=optim, 19 | opt_state=opt_state, 20 | output=y, 21 | input=x, 22 | loss_id="mse", 23 | param_type="sp", 24 | ode_solver=Heun(), 25 | max_t1=5, 26 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 27 | ) 28 | 29 | assert "model" in result 30 | assert "skip_model" in result 31 | assert "opt_state" in result 32 | assert "loss" in result 33 | assert len(result["model"]) == len(simple_model) 34 | 35 | 36 | def test_make_pc_step_unsupervised(simple_model, y, key, layer_sizes, batch_size): 37 | """Test PC training step in unsupervised mode.""" 38 | optim = optax.sgd(learning_rate=0.01) 39 | opt_state = optim.init((simple_model, None)) 40 | 41 | result = make_pc_step( 42 | model=simple_model, 43 | optim=optim, 44 | opt_state=opt_state, 45 | output=y, 46 | input=None, 47 | loss_id="mse", 48 | param_type="sp", 49 | ode_solver=Heun(), 50 | max_t1=5, 51 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 52 | key=key, 53 | layer_sizes=layer_sizes, 54 | batch_size=batch_size, 55 | sigma=0.05 56 | ) 57 | 58 | assert "model" in result 59 | assert "opt_state" in result 60 | 61 | 62 | def test_make_pc_step_cross_entropy(simple_model, x, y_onehot): 63 | """Test PC training step with cross-entropy loss.""" 64 | optim = optax.sgd(learning_rate=0.01) 65 | opt_state = optim.init((simple_model, None)) 66 | 67 | result = make_pc_step( 68 | model=simple_model, 69 | optim=optim, 70 | opt_state=opt_state, 71 | output=y_onehot, 72 | input=x, 73 | loss_id="ce", 74 | param_type="sp", 75 | ode_solver=Heun(), 76 | max_t1=5, 77 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 78 | ) 79 | 80 | assert "model" in result 81 | assert "loss" in result 82 | 83 | 84 | def test_make_pc_step_with_regularization(simple_model, x, y): 85 | 
"""Test PC training step with regularization.""" 86 | optim = optax.sgd(learning_rate=0.01) 87 | opt_state = optim.init((simple_model, None)) 88 | 89 | result = make_pc_step( 90 | model=simple_model, 91 | optim=optim, 92 | opt_state=opt_state, 93 | output=y, 94 | input=x, 95 | loss_id="mse", 96 | param_type="sp", 97 | ode_solver=Heun(), 98 | max_t1=5, 99 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 100 | weight_decay=0.01, 101 | spectral_penalty=0.01, 102 | activity_decay=0.01 103 | ) 104 | 105 | assert "model" in result 106 | 107 | 108 | def test_make_pc_step_with_metrics(simple_model, x, y): 109 | """Test PC training step with metrics recording.""" 110 | optim = optax.sgd(learning_rate=0.01) 111 | opt_state = optim.init((simple_model, None)) 112 | 113 | result = make_pc_step( 114 | model=simple_model, 115 | optim=optim, 116 | opt_state=opt_state, 117 | output=y, 118 | input=x, 119 | loss_id="mse", 120 | param_type="sp", 121 | ode_solver=Heun(), 122 | max_t1=5, 123 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 124 | record_activities=True, 125 | record_energies=True, 126 | activity_norms=True, 127 | param_norms=True, 128 | grad_norms=True, 129 | calculate_accuracy=True 130 | ) 131 | 132 | assert "model" in result 133 | assert "activities" in result 134 | assert "energies" in result 135 | assert "activity_norms" in result 136 | assert "model_param_norms" in result 137 | assert "acc" in result 138 | 139 | 140 | def test_make_pc_step_invalid_input(): 141 | """Test PC training step with invalid input (missing required args for unsupervised).""" 142 | import jax 143 | from jpc import make_mlp 144 | import optax 145 | from diffrax import Heun, PIDController 146 | 147 | key = jax.random.PRNGKey(42) 148 | model = make_mlp(key, 10, 20, 5, 3, "relu", False, "sp") 149 | optim = optax.sgd(learning_rate=0.01) 150 | opt_state = optim.init((model, None)) 151 | y = jax.random.normal(key, (4, 5)) 152 | 153 | with pytest.raises(ValueError): 154 | make_pc_step( 155 | model=model, 156 | optim=optim, 157 | opt_state=opt_state, 158 | output=y, 159 | input=None, 160 | loss_id="mse", 161 | param_type="sp", 162 | ode_solver=Heun(), 163 | max_t1=5, 164 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 165 | ) 166 | 167 | 168 | def test_make_hpc_step(key, simple_model, x, y, output_dim, hidden_dim, input_dim, depth): 169 | """Test HPC training step.""" 170 | from jpc import make_mlp 171 | 172 | generator = simple_model 173 | amortiser = make_mlp( 174 | key=key, 175 | input_dim=output_dim, 176 | width=hidden_dim, 177 | depth=depth, 178 | output_dim=input_dim, 179 | act_fn="relu", 180 | use_bias=False, 181 | param_type="sp" 182 | ) 183 | 184 | gen_optim = optax.sgd(learning_rate=0.01) 185 | amort_optim = optax.sgd(learning_rate=0.01) 186 | gen_opt_state = gen_optim.init((generator, None)) 187 | amort_opt_state = amort_optim.init(amortiser) 188 | 189 | result = make_hpc_step( 190 | generator=generator, 191 | amortiser=amortiser, 192 | optims=(gen_optim, amort_optim), 193 | opt_states=(gen_opt_state, amort_opt_state), 194 | output=y, 195 | input=x, 196 | ode_solver=Heun(), 197 | max_t1=5, 198 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 199 | ) 200 | 201 | assert "generator" in result 202 | assert "amortiser" in result 203 | assert "opt_states" in result 204 | assert "losses" in result 205 | assert len(result["opt_states"]) == 2 206 | 207 | 208 | def test_make_hpc_step_unsupervised(key, simple_model, y, output_dim, hidden_dim, input_dim, depth): 209 | """Test HPC training 
step in unsupervised mode.""" 210 | from jpc import make_mlp 211 | 212 | generator = simple_model 213 | amortiser = make_mlp( 214 | key=key, 215 | input_dim=output_dim, 216 | width=hidden_dim, 217 | depth=depth, 218 | output_dim=input_dim, 219 | act_fn="relu", 220 | use_bias=False, 221 | param_type="sp" 222 | ) 223 | 224 | gen_optim = optax.sgd(learning_rate=0.01) 225 | amort_optim = optax.sgd(learning_rate=0.01) 226 | gen_opt_state = gen_optim.init((generator, None)) 227 | amort_opt_state = amort_optim.init(amortiser) 228 | 229 | result = make_hpc_step( 230 | generator=generator, 231 | amortiser=amortiser, 232 | optims=(gen_optim, amort_optim), 233 | opt_states=(gen_opt_state, amort_opt_state), 234 | output=y, 235 | input=None, 236 | ode_solver=Heun(), 237 | max_t1=5, 238 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3) 239 | ) 240 | 241 | assert "generator" in result 242 | assert "amortiser" in result 243 | 244 | 245 | def test_make_hpc_step_with_recording(key, simple_model, x, y, output_dim, hidden_dim, input_dim, depth): 246 | """Test HPC training step with activity/energy recording.""" 247 | from jpc import make_mlp 248 | 249 | generator = simple_model 250 | amortiser = make_mlp( 251 | key=key, 252 | input_dim=output_dim, 253 | width=hidden_dim, 254 | depth=depth, 255 | output_dim=input_dim, 256 | act_fn="relu", 257 | use_bias=False, 258 | param_type="sp" 259 | ) 260 | 261 | gen_optim = optax.sgd(learning_rate=0.01) 262 | amort_optim = optax.sgd(learning_rate=0.01) 263 | gen_opt_state = gen_optim.init((generator, None)) 264 | amort_opt_state = amort_optim.init(amortiser) 265 | 266 | result = make_hpc_step( 267 | generator=generator, 268 | amortiser=amortiser, 269 | optims=(gen_optim, amort_optim), 270 | opt_states=(gen_opt_state, amort_opt_state), 271 | output=y, 272 | input=x, 273 | ode_solver=Heun(), 274 | max_t1=5, 275 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3), 276 | record_activities=True, 277 | record_energies=True 278 | ) 279 | 280 | assert "activities" in result 281 | assert "energies" in result 282 | 283 | -------------------------------------------------------------------------------- /experiments/mupc_paper/analyse_activity_hessian.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jax 3 | import jax.random as jr 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | from experiments.datasets import make_gaussian_dataset, get_dataloaders 8 | from experiments.mupc_paper.utils import ( 9 | setup_hessian_analysis, 10 | set_seed, 11 | init_weights, 12 | get_network_weights, 13 | unwrap_hessian_pytree 14 | ) 15 | import jpc 16 | from experiments.mupc_paper.plotting import plot_activity_hessian 17 | 18 | 19 | def compute_hessian_metrics( 20 | network, 21 | act_fn, 22 | skip_model, 23 | y, 24 | x, 25 | use_skips, 26 | param_type, 27 | activity_decay, 28 | mode, 29 | layer_sizes, 30 | key 31 | ): 32 | if act_fn == "linear": 33 | # theoretical activity Hessian 34 | weights = get_network_weights(network) 35 | theory_H = jpc.compute_linear_activity_hessian( 36 | weights, 37 | param_type=param_type, 38 | use_skips=use_skips, 39 | activity_decay=activity_decay 40 | ) 41 | D = jpc.compute_linear_activity_hessian( 42 | weights, 43 | param_type=param_type, 44 | off_diag=False, 45 | use_skips=use_skips, 46 | activity_decay=activity_decay 47 | ) 48 | O = jpc.compute_linear_activity_hessian( 49 | weights, 50 | param_type=param_type, 51 | diag=False, 52 | use_skips=use_skips, 53 | 
activity_decay=activity_decay 54 | ) 55 | 56 | # numerical activity Hessian 57 | if mode == "supervised": 58 | activities = jpc.init_activities_with_ffwd( 59 | network, 60 | x, 61 | skip_model=skip_model, 62 | param_type=param_type 63 | ) 64 | elif mode == "unsupervised": 65 | activities = jpc.init_activities_from_normal( 66 | key=key, 67 | layer_sizes=layer_sizes, 68 | mode=mode, 69 | batch_size=1, 70 | sigma=1 71 | ) 72 | 73 | hessian_pytree = jax.hessian(jpc.pc_energy_fn, argnums=1)( 74 | (network, skip_model), 75 | activities, 76 | y, 77 | x=x, 78 | param_type=param_type, 79 | activity_decay=activity_decay 80 | ) 81 | num_H = unwrap_hessian_pytree( 82 | hessian_pytree, 83 | activities, 84 | ) 85 | 86 | # compute eigenthings 87 | num_H_eigenvals, _ = jnp.linalg.eigh(num_H) 88 | cond_num = jnp.linalg.cond(num_H) 89 | if act_fn == "linear": 90 | theory_H_eigenvals, _ = jnp.linalg.eigh(theory_H) 91 | D_eigenvals, _ = jnp.linalg.eigh(D) 92 | O_eigenvals, _ = jnp.linalg.eigh(O) 93 | 94 | return { 95 | "hessian": num_H, 96 | "num": num_H_eigenvals, 97 | "cond_num": cond_num, 98 | "theory": theory_H_eigenvals if act_fn == "linear" else None, 99 | "D": D_eigenvals if act_fn == "linear" else None, 100 | "O": O_eigenvals if act_fn == "linear" else None 101 | } 102 | 103 | 104 | def run_analysis( 105 | seed, 106 | in_out_dims, 107 | act_fn, 108 | use_biases, 109 | mode, 110 | use_skips, 111 | weight_init, 112 | param_type, 113 | activity_decay, 114 | width, 115 | n_hidden, 116 | save_dir 117 | ): 118 | set_seed(seed) 119 | key = jr.PRNGKey(seed) 120 | keys = jr.split(key, 4) 121 | 122 | d_in = width if in_out_dims == "width" else in_out_dims[0] 123 | d_out = width if in_out_dims == "width" else in_out_dims[1] 124 | 125 | # create and initialise model 126 | L = n_hidden+1 127 | network = jpc.make_mlp( 128 | key=keys[0], 129 | input_dim=d_in, 130 | width=width, 131 | depth=L, 132 | output_dim=d_out, 133 | act_fn=act_fn, 134 | use_bias=use_biases, 135 | param_type=param_type 136 | ) 137 | if weight_init != "standard": 138 | network = init_weights( 139 | key=keys[1], 140 | model=network, 141 | init_fn_id=weight_init 142 | ) 143 | skip_model = jpc.make_skip_model(L) if use_skips else None 144 | 145 | # data 146 | if in_out_dims != "width": 147 | train_loader, _ = get_dataloaders("MNIST", batch_size=1) 148 | img_batch, label_batch = next(iter(train_loader)) 149 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy() 150 | x, y = (img_batch, label_batch) if d_in == 784 else (label_batch, img_batch) 151 | else: 152 | x, y = make_gaussian_dataset(keys[2], 1, 0.1, (1, width)) 153 | if mode == "unsupervised": 154 | x = None 155 | 156 | layer_sizes = [d_in] + [width]*n_hidden + [d_out] 157 | metrics = compute_hessian_metrics( 158 | network=network, 159 | act_fn=act_fn, 160 | skip_model=skip_model, 161 | y=y, 162 | x=x, 163 | use_skips=use_skips, 164 | param_type=param_type, 165 | activity_decay=activity_decay, 166 | mode=mode, 167 | layer_sizes=layer_sizes, 168 | key=keys[3] 169 | ) 170 | plot_activity_hessian( 171 | metrics["hessian"], 172 | f"{save_dir}/hessian_matrix.pdf" 173 | ) 174 | np.save( 175 | f"{save_dir}/num_hessian_eigenvals", 176 | metrics["num"] 177 | ) 178 | np.save( 179 | f"{save_dir}/cond_num", 180 | metrics["cond_num"] 181 | ) 182 | if act_fn == "linear": 183 | np.save( 184 | f"{save_dir}/theory_hessian_eigenvals", 185 | metrics["theory"] 186 | ) 187 | np.save( 188 | f"{save_dir}/theory_D_eigenvals", 189 | metrics["D"] 190 | ) 191 | np.save( 192 | f"{save_dir}/theory_O_eigenvals", 
193 | metrics["O"] 194 | ) 195 | 196 | 197 | if __name__ == "__main__": 198 | RESULTS_DIR = "activity_hessian_results" 199 | IN_OUT_DIMS = [[784, 10]] #, [784, 10], [10, 784]] 200 | ACT_FNS = ["linear"]#, "tanh", "relu"] 201 | USE_BIASES = [False] 202 | MODES = ["supervised"] #,"unsupervised"] 203 | USE_SKIPS = [False, True] 204 | WEIGHT_INITS = ["standard"]#["one_over_N", "standard", "orthogonal"] 205 | PARAM_TYPES = ["sp"]#, "mupc", "ntp"] 206 | ACTIVITY_DECAY = [False]#, True] 207 | WIDTHS = [2 ** i for i in range(11)] 208 | N_HIDDENS = [2 ** i for i in range(4)] 209 | N_SEEDS = 3 210 | 211 | for in_out_dims in IN_OUT_DIMS: 212 | for act_fn in ACT_FNS: 213 | for use_biases in USE_BIASES: 214 | for mode in MODES: 215 | for use_skips in USE_SKIPS: 216 | for weight_init in WEIGHT_INITS: 217 | for param_type in PARAM_TYPES: 218 | for activity_decay in ACTIVITY_DECAY: 219 | for width in WIDTHS: 220 | for n_hidden in N_HIDDENS: 221 | for seed in range(N_SEEDS): 222 | save_dir = setup_hessian_analysis( 223 | results_dir=RESULTS_DIR, 224 | in_out_dims=in_out_dims, 225 | act_fn=act_fn, 226 | use_biases=use_biases, 227 | mode=mode, 228 | use_skips=use_skips, 229 | weight_init=weight_init, 230 | param_type=param_type, 231 | activity_decay=activity_decay, 232 | width=width, 233 | n_hidden=n_hidden, 234 | seed=seed 235 | ) 236 | os.makedirs(save_dir, exist_ok=True) 237 | run_analysis( 238 | seed=seed, 239 | in_out_dims=in_out_dims, 240 | act_fn=act_fn, 241 | use_biases=use_biases, 242 | mode=mode, 243 | use_skips=use_skips, 244 | weight_init=weight_init, 245 | param_type=param_type, 246 | activity_decay=activity_decay, 247 | width=width, 248 | n_hidden=n_hidden, 249 | save_dir=save_dir 250 | ) 251 | -------------------------------------------------------------------------------- /experiments/mupc_paper/test_mlp_fwd_pass.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | import jax.random as jr 6 | import jax.numpy as jnp 7 | from jax import vmap 8 | 9 | import equinox as eqx 10 | import equinox.nn as nn 11 | import optax 12 | import jpc 13 | 14 | from experiments.datasets import get_dataloaders 15 | from experiments.mupc_paper.utils import ( 16 | set_seed, 17 | init_weights, 18 | compute_param_l2_norms, 19 | compute_param_spectral_norms 20 | ) 21 | 22 | 23 | class MLP(eqx.Module): 24 | D: int 25 | N: int 26 | L: int 27 | param_type: str 28 | use_skips: bool 29 | layers: list 30 | 31 | def __init__( 32 | self, 33 | key, 34 | d_in, 35 | N, 36 | L, 37 | d_out, 38 | act_fn, 39 | param_type, 40 | use_bias=False, 41 | use_skips=False 42 | ): 43 | self.D = d_in 44 | self.N = N 45 | self.L = L 46 | self.param_type = param_type 47 | self.use_skips = use_skips 48 | 49 | keys = jr.split(key, L) 50 | self.layers = [] 51 | for i in range(L): 52 | act_fn_l = nn.Identity() if i == 0 else jpc.get_act_fn(act_fn) 53 | _in = d_in if i == 0 else N 54 | _out = d_out if (i + 1) == L else N 55 | layer = nn.Sequential( 56 | [ 57 | nn.Lambda(act_fn_l), 58 | nn.Linear( 59 | _in, 60 | _out, 61 | use_bias=use_bias, 62 | key=keys[i] 63 | ) 64 | ] 65 | ) 66 | self.layers.append(layer) 67 | 68 | def __call__(self, x): 69 | pre_activs = [] 70 | 71 | if self.param_type == "depth_mup": 72 | for i, f in enumerate(self.layers): 73 | if (i + 1) == 1: 74 | x = f(x) / jnp.sqrt(self.D) 75 | elif 1 < (i + 1) < self.L: 76 | residual = x if self.use_skips else 0 77 | rescaling = jnp.sqrt( 78 | self.N * self.L 79 | ) if self.use_skips else 
jnp.sqrt(self.N) 80 | x = (f(x) / rescaling) + residual 81 | elif (i + 1) == self.L: 82 | x = f(x) / self.N 83 | 84 | pre_activs.append(x) 85 | 86 | else: 87 | for i, f in enumerate(self.layers): 88 | residual = x if self.use_skips and (1 < (i + 1) < self.L) else 0 89 | 90 | x = f(x) + residual 91 | 92 | pre_activs.append(x) 93 | 94 | return pre_activs 95 | 96 | 97 | def mse_loss(model, x, y): 98 | y_pred = vmap(model)(x)[-1] 99 | return jnp.mean((y - y_pred) ** 2) 100 | 101 | 102 | @eqx.filter_jit 103 | def make_step(model, optim, opt_state, x, y): 104 | loss, grads = eqx.filter_value_and_grad(mse_loss)(model, x, y) 105 | updates, opt_state = optim.update( 106 | updates=grads, 107 | state=opt_state, 108 | params=eqx.filter(model, eqx.is_array) 109 | ) 110 | model = eqx.apply_updates(model, updates) 111 | return model, opt_state, loss 112 | 113 | 114 | def test_fwd_pass( 115 | seed, 116 | dataset, 117 | width, 118 | depth, 119 | act_fn, 120 | optim_id, 121 | param_type, 122 | use_skips, 123 | lr, 124 | batch_size, 125 | n_checks 126 | ): 127 | set_seed(seed) 128 | 129 | key = jr.PRNGKey(seed) 130 | keys = jr.split(key, 2) 131 | model = MLP( 132 | key=keys[0], 133 | d_in=784, 134 | N=width, 135 | L=depth, 136 | d_out=10, 137 | act_fn=act_fn, 138 | param_type=param_type, 139 | use_bias=False, 140 | use_skips=use_skips 141 | ) 142 | 143 | if param_type == "depth_mup": 144 | model = init_weights( 145 | model=model, 146 | init_fn_id="standard_gauss", 147 | key=keys[1] 148 | ) 149 | elif param_type == "orthogonal": 150 | model = init_weights( 151 | model=model, 152 | init_fn_id="orthogonal", 153 | key=keys[1], 154 | gain=1.05 if act_fn == "tanh" else 1 155 | ) 156 | 157 | optim = optax.sgd(lr) if optim_id == "sgd" else optax.adam(lr) 158 | opt_state = optim.init(eqx.filter(model, eqx.is_array)) 159 | 160 | layer_idxs = [0, int(depth/4)-1, int(depth/2)-1, int(depth*3/4)-1, depth-1] 161 | avg_activity_l1 = np.zeros((len(layer_idxs), n_checks)) 162 | avg_activity_l2 = np.zeros_like(avg_activity_l1) 163 | param_l2_norms = np.zeros_like(avg_activity_l1) 164 | param_spectral_norms = np.zeros_like(avg_activity_l1) 165 | 166 | train_loader, _ = get_dataloaders(dataset, batch_size) 167 | for t, (img_batch, label_batch) in enumerate(train_loader): 168 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy() 169 | 170 | pre_activities = vmap(model)(img_batch) 171 | i = 0 172 | for l, pre_act in enumerate(pre_activities): 173 | if l in layer_idxs: 174 | avg_activity_l1[i, t] = jnp.abs(pre_act).mean() 175 | avg_activity_l2[i, t] = jnp.sqrt(jnp.mean(pre_act**2)) 176 | i += 1 177 | 178 | param_l2_norms[:, t] = compute_param_l2_norms( 179 | model=model.layers, 180 | act_fn=act_fn, 181 | layer_idxs=layer_idxs 182 | ) 183 | param_spectral_norms[:, t] = compute_param_spectral_norms( 184 | model=model.layers, 185 | act_fn=act_fn, 186 | layer_idxs=layer_idxs 187 | ) 188 | model, opt_state, _ = make_step( 189 | model=model, 190 | optim=optim, 191 | opt_state=opt_state, 192 | x=img_batch, 193 | y=label_batch 194 | ) 195 | if t >= (n_checks - 1): 196 | break 197 | 198 | return ( 199 | avg_activity_l1, 200 | avg_activity_l2, 201 | param_l2_norms, 202 | param_spectral_norms 203 | ) 204 | 205 | 206 | if __name__ == "__main__": 207 | RESULTS_DIR = "mlp_fwd_pass_results" 208 | DATASET = "MNIST" 209 | WIDTHS = [2 ** i for i in range(7, 11)] 210 | DEPTHS = [2 ** i for i in range(4, 10)] 211 | LR = 1e-3 212 | BATCH_SIZE = 64 213 | N_RECORDED_LAYERS = 5 214 | N_CHECKS = 5 215 | 216 | parser = argparse.ArgumentParser() 
217 | parser.add_argument("--act_fns", type=str, nargs='+', default=["linear", "tanh", "relu"]) 218 | parser.add_argument("--optim_ids", type=str, nargs='+', default=["sgd", "adam"]) 219 | parser.add_argument("--param_types", type=str, nargs='+', default=["sp", "depth_mup", "orthogonal"]) 220 | parser.add_argument("--seed", type=int, default=54638) 221 | args = parser.parse_args() 222 | 223 | for act_fn in args.act_fns: 224 | print(f"\nact_fn: {act_fn}") 225 | 226 | for optim_id in args.optim_ids: 227 | print(f"\n\toptim: {optim_id}") 228 | 229 | for param_type in args.param_types: 230 | print(f"\n\t\tparam_type: {param_type}") 231 | 232 | skip_uses = [False, True] if param_type != "orthogonal" else [False] 233 | for use_skips in skip_uses: 234 | print(f"\n\t\t\tuse_skips: {use_skips}") 235 | 236 | save_dir = os.path.join( 237 | RESULTS_DIR, 238 | act_fn, 239 | optim_id, 240 | param_type, 241 | "skips" if use_skips else "no_skips", 242 | str(args.seed) 243 | ) 244 | os.makedirs(save_dir, exist_ok=True) 245 | 246 | avg_activity_l1_per_N_L = np.zeros((N_RECORDED_LAYERS, N_CHECKS, len(WIDTHS), len(DEPTHS))) 247 | avg_activity_l2_per_N_L = np.zeros_like(avg_activity_l1_per_N_L) 248 | param_l2_norms_per_N_L = np.zeros_like(avg_activity_l1_per_N_L) 249 | param_spectral_norms_per_N_L = np.zeros_like(avg_activity_l1_per_N_L) 250 | 251 | for w, width in enumerate(WIDTHS): 252 | print(f"\n\t\t\t\tN = {width}\n") 253 | for d, depth in enumerate(DEPTHS): 254 | print(f"\t\t\t\t\tL = {depth}") 255 | 256 | avg_activity_l1, avg_activity_l2, param_l2_norms, param_spectral_norms = test_fwd_pass( 257 | seed=args.seed, 258 | dataset=DATASET, 259 | width=width, 260 | depth=depth, 261 | act_fn=act_fn, 262 | optim_id=optim_id, 263 | param_type=param_type, 264 | use_skips=use_skips, 265 | lr=LR, 266 | batch_size=BATCH_SIZE, 267 | n_checks=N_CHECKS 268 | ) 269 | avg_activity_l1_per_N_L[:, :, w, d] = avg_activity_l1 270 | avg_activity_l2_per_N_L[:, :, w, d] = avg_activity_l2 271 | param_l2_norms_per_N_L[:, :, w, d] = param_l2_norms 272 | param_spectral_norms_per_N_L[:, :, w, d] = param_spectral_norms 273 | 274 | np.save(f"{save_dir}/avg_activity_l1_per_N_L.npy", avg_activity_l1_per_N_L) 275 | np.save(f"{save_dir}/avg_activity_l2_per_N_L.npy", avg_activity_l2_per_N_L) 276 | np.save(f"{save_dir}/param_l2_norms_per_N_L.npy", param_l2_norms_per_N_L) 277 | np.save(f"{save_dir}/param_spectral_norms_per_N_L.npy", param_spectral_norms_per_N_L) 278 | -------------------------------------------------------------------------------- /experiments/mupc_paper/train_bpn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | import jax.random as jr 6 | import jax.numpy as jnp 7 | from jax import vmap 8 | from jax.nn import log_softmax 9 | 10 | import equinox as eqx 11 | import equinox.nn as nn 12 | import optax 13 | import jpc 14 | 15 | from experiments.datasets import get_dataloaders 16 | from experiments.mupc_paper.utils import set_seed, setup_logger, init_weights 17 | 18 | 19 | class MLP(eqx.Module): 20 | D: int 21 | N: int 22 | L: int 23 | param_type: str 24 | use_skips: bool 25 | layers: list 26 | 27 | def __init__( 28 | self, 29 | key, 30 | d_in, 31 | N, 32 | L, 33 | d_out, 34 | act_fn, 35 | param_type, 36 | use_bias=False, 37 | use_skips=False 38 | ): 39 | self.D = d_in 40 | self.N = N 41 | self.L = L 42 | self.param_type = param_type 43 | self.use_skips = use_skips 44 | 45 | keys = jr.split(key, L) 46 | self.layers = [] 47 | for i 
in range(L): 48 | act_fn_l = nn.Identity() if i == 0 else jpc.get_act_fn(act_fn) 49 | _in = d_in if i == 0 else N 50 | _out = d_out if (i + 1) == L else N 51 | layer = nn.Sequential( 52 | [ 53 | nn.Lambda(act_fn_l), 54 | nn.Linear( 55 | _in, 56 | _out, 57 | use_bias=use_bias, 58 | key=keys[i] 59 | ) 60 | ] 61 | ) 62 | self.layers.append(layer) 63 | 64 | def __call__(self, x): 65 | if self.param_type == "depth_mup": 66 | for i, f in enumerate(self.layers): 67 | if (i + 1) == 1: 68 | x = f(x) / jnp.sqrt(self.D) 69 | elif 1 < (i + 1) < self.L: 70 | residual = x if self.use_skips else 0 71 | rescaling = jnp.sqrt( 72 | self.N * self.L 73 | ) if self.use_skips else jnp.sqrt(self.N) 74 | x = (f(x) / rescaling) + residual 75 | elif (i + 1) == self.L: 76 | x = f(x) / self.N 77 | 78 | else: 79 | for i, f in enumerate(self.layers): 80 | residual = x if self.use_skips and (1 < (i + 1) < self.L) else 0 81 | 82 | x = f(x) + residual 83 | 84 | return x 85 | 86 | 87 | def evaluate(model, testloader, loss_id): 88 | loss_fn = get_loss_fn(loss_id) 89 | avg_test_loss, avg_test_acc = 0, 0 90 | for x, y in testloader: 91 | x, y = x.numpy(), y.numpy() 92 | avg_test_loss += loss_fn(model, x, y) 93 | avg_test_acc += compute_accuracy(model, x, y) 94 | return avg_test_loss / len(testloader), avg_test_acc / len(testloader) 95 | 96 | 97 | @eqx.filter_jit 98 | def mse_loss(model, x, y): 99 | y_pred = vmap(model)(x) 100 | return jnp.mean((y - y_pred) ** 2) 101 | 102 | 103 | @eqx.filter_jit 104 | def cross_entropy_loss(model, x, y): 105 | logits = vmap(model)(x) 106 | log_probs = log_softmax(logits) 107 | return - jnp.mean(jnp.sum(y * log_probs, axis=-1)) 108 | 109 | 110 | def get_loss_fn(loss_id): 111 | if loss_id == "mse": 112 | return mse_loss 113 | elif loss_id == "ce": 114 | return cross_entropy_loss 115 | 116 | 117 | @eqx.filter_jit 118 | def compute_accuracy(model, x, y): 119 | pred_y = vmap(model)(x) 120 | return jnp.mean( 121 | jnp.argmax(y, axis=1) == jnp.argmax(pred_y, axis=1) 122 | ) * 100 123 | 124 | 125 | @eqx.filter_jit 126 | def make_step(model, optim, opt_state, x, y, loss_id="mse"): 127 | loss_fn = get_loss_fn(loss_id) 128 | loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y) 129 | updates, opt_state = optim.update( 130 | updates=grads, 131 | state=opt_state, 132 | params=eqx.filter(model, eqx.is_array) 133 | ) 134 | model = eqx.apply_updates(model, updates) 135 | return model, opt_state, loss 136 | 137 | 138 | def train_mlp( 139 | seed, 140 | dataset, 141 | loss_id, 142 | width, 143 | n_hidden, 144 | act_fn, 145 | param_type, 146 | optim_id, 147 | lr, 148 | batch_size, 149 | max_epochs, 150 | test_every, 151 | save_dir 152 | ): 153 | set_seed(seed) 154 | key = jr.PRNGKey(seed) 155 | model_key, init_key = jr.split(key, 2) 156 | logger = setup_logger(save_dir) 157 | 158 | model = MLP( 159 | key=model_key, 160 | d_in=3072, #784, 3072 161 | N=width, 162 | L=n_hidden+1, 163 | d_out=10, 164 | act_fn=act_fn, 165 | param_type=param_type, 166 | use_bias=False, 167 | use_skips=True if param_type == "depth_mup" else False 168 | ) 169 | if param_type == "depth_mup": 170 | model = init_weights( 171 | model=model, 172 | init_fn_id="standard_gauss", 173 | key=init_key 174 | ) 175 | 176 | optim = optax.sgd(lr) if optim_id == "sgd" else optax.adam(lr) 177 | opt_state = optim.init(eqx.filter(model, eqx.is_array)) 178 | 179 | # data 180 | train_loader, test_loader = get_dataloaders(dataset, batch_size) 181 | 182 | # key metrics 183 | train_losses = [] 184 | test_losses, test_accs = [], [] 185 | 186 | diverged = 
no_learning = False 187 | global_batch_id = 0 188 | for epoch in range(1, max_epochs + 1): 189 | print(f"\nEpoch {epoch}\n-------------------------------") 190 | 191 | for train_iter, (img_batch, label_batch) in enumerate(train_loader): 192 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy() 193 | 194 | model, opt_state, train_loss = make_step( 195 | model=model, 196 | optim=optim, 197 | opt_state=opt_state, 198 | x=img_batch, 199 | y=label_batch, 200 | loss_id=loss_id 201 | ) 202 | train_losses.append(train_loss) 203 | global_batch_id += 1 204 | 205 | if global_batch_id % test_every == 0: 206 | print( 207 | f"Train loss: {train_loss:.7f} [{train_iter * len(img_batch)}/{len(train_loader.dataset)}]" 208 | ) 209 | avg_test_loss, avg_test_acc = evaluate( 210 | model, 211 | test_loader, 212 | loss_id 213 | ) 214 | test_losses.append(avg_test_loss) 215 | test_accs.append(avg_test_acc) 216 | print(f"Avg test accuracy: {avg_test_acc:.4f}\n") 217 | 218 | if np.isinf(train_loss) or np.isnan(train_loss): 219 | diverged = True 220 | break 221 | 222 | if global_batch_id >= test_every and avg_test_acc < 15: 223 | no_learning = True 224 | break 225 | 226 | if diverged: 227 | print( 228 | f"Stopping training because of diverging loss: {train_loss}" 229 | ) 230 | break 231 | 232 | if no_learning: 233 | print( 234 | f"Stopping training because of close-to-chance accuracy (no learning): {avg_test_acc}" 235 | ) 236 | break 237 | 238 | np.save(f"{save_dir}/train_losses.npy", train_losses) 239 | np.save(f"{save_dir}/test_losses.npy", test_losses) 240 | np.save(f"{save_dir}/test_accs.npy", test_accs) 241 | 242 | 243 | if __name__ == "__main__": 244 | 245 | parser = argparse.ArgumentParser() 246 | parser.add_argument("--results_dir", type=str, default="bp_results") 247 | parser.add_argument("--dataset", type=str, default="CIFAR10") 248 | parser.add_argument("--loss_id", type=str, default="ce") 249 | parser.add_argument("--width", type=int, default=512) 250 | parser.add_argument("--n_hidden", type=int, default=8) 251 | parser.add_argument("--act_fns", type=str, nargs='+', default=["relu"]) 252 | parser.add_argument("--param_type", type=str, default="depth_mup") 253 | parser.add_argument("--optim_id", type=str, default="adam") 254 | parser.add_argument("--lrs", type=float, nargs='+', default=[1e-2]) 255 | parser.add_argument("--batch_size", type=int, default=128) 256 | parser.add_argument("--max_epochs", type=int, default=20) 257 | parser.add_argument("--test_every", type=int, default=389) 258 | parser.add_argument("--n_seeds", type=int, default=3) 259 | args = parser.parse_args() 260 | 261 | for act_fn in args.act_fns: 262 | for lr in args.lrs: 263 | for seed in range(args.n_seeds): 264 | save_dir = os.path.join( 265 | args.results_dir, 266 | args.dataset, 267 | f"{args.loss_id}_loss", 268 | f"width_{args.width}", 269 | f"{args.n_hidden}_n_hidden", 270 | act_fn, 271 | f"{args.param_type}_param", 272 | args.optim_id, 273 | f"lr_{lr}", 274 | f"batch_size_{args.batch_size}", 275 | f"{args.max_epochs}_epochs", 276 | str(seed) 277 | ) 278 | print(f"Starting training with config: {save_dir} with seed: {seed}") 279 | train_mlp( 280 | seed=seed, 281 | dataset=args.dataset, 282 | loss_id=args.loss_id, 283 | width=args.width, 284 | n_hidden=args.n_hidden, 285 | act_fn=act_fn, 286 | param_type=args.param_type, 287 | optim_id=args.optim_id, 288 | lr=lr, 289 | batch_size=args.batch_size, 290 | max_epochs=args.max_epochs, 291 | test_every=args.test_every, 292 | save_dir=save_dir 293 | ) 294 | 
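295 | # Example invocation (a sketch for reference; the flags mirror the argparse
296 | # defaults above, with results saved under --results_dir, "bp_results" by default):
297 | #   python train_bpn.py --dataset CIFAR10 --loss_id ce --width 512 \
298 | #       --n_hidden 8 --act_fns relu --param_type depth_mup --optim_id adam \
299 | #       --lrs 1e-2 --batch_size 128 --max_epochs 20 --n_seeds 3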
-------------------------------------------------------------------------------- /jpc/_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax.tree_util import tree_map, tree_leaves 4 | from equinox import tree_at 5 | import equinox.nn as nn 6 | from jpc import pc_energy_fn, _check_param_type 7 | from jaxtyping import PRNGKeyArray, PyTree, ArrayLike, Scalar, Array 8 | from jaxlib.xla_extension import PjitFunction 9 | from typing import Callable, Optional, Tuple 10 | from dataclasses import dataclass 11 | 12 | 13 | _ACT_FNS = [ 14 | "linear", "tanh", "hard_tanh", "relu", "leaky_relu", "gelu", "selu", "silu" 15 | ] 16 | 17 | 18 | def get_act_fn(name: str) -> Callable: 19 | if name == "linear": 20 | return nn.Identity() 21 | elif name == "tanh": 22 | return jnp.tanh 23 | elif name == "hard_tanh": 24 | return jax.nn.hard_tanh 25 | elif name == "relu": 26 | return jax.nn.relu 27 | elif name == "leaky_relu": 28 | return jax.nn.leaky_relu 29 | elif name == "gelu": 30 | return jax.nn.gelu 31 | elif name == "selu": 32 | return jax.nn.selu 33 | elif name == "silu": 34 | return jax.nn.silu 35 | else: 36 | raise ValueError(f""" 37 | Invalid activation function ID. Options are {_ACT_FNS}. 38 | """) 39 | 40 | 41 | def make_mlp( 42 | key: PRNGKeyArray, 43 | input_dim: int, 44 | width: int, 45 | depth: int, 46 | output_dim: int, 47 | act_fn: str, 48 | use_bias: bool = False, 49 | param_type: str = "sp" 50 | ) -> PyTree[Callable]: 51 | """Creates a multi-layer perceptron compatible with predictive coding updates. 52 | 53 | !!! note 54 | 55 | This implementation places the activation function before the linear 56 | transformation, $\mathbf{W}_\ell \phi(\mathbf{z}_{\ell-1})$, for 57 | compatibility with the [μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1)) 58 | scalings when `param_type = "mupc"` in functions including 59 | [`jpc.init_activities_with_ffwd()`](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_ffwd), 60 | [`jpc.update_activities()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_activities), 61 | and [`jpc.update_params()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_params). 62 | 63 | **Main arguments:** 64 | 65 | - `key`: `jax.random.PRNGKey` for parameter initialisation. 66 | - `input_dim`: Input dimension. 67 | - `width`: Network width. 68 | - `depth`: Network depth. 69 | - `output_dim`: Output dimension. 70 | - `act_fn`: Activation function (for all layers except the output). 71 | - `use_bias`: `False` by default. 72 | - `param_type`: Determines the parameterisation. Options are `"sp"` 73 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))), 74 | or `"ntp"` (neural tangent parameterisation). See [`jpc._get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings) 75 | for the specific scalings of these different parameterisations. Defaults 76 | to `"sp"`. 77 | 78 | **Returns:** 79 | 80 | List of callable fully connected layers. 
81 | 
82 |     """
83 |     _check_param_type(param_type)
84 | 
85 |     subkeys = jax.random.split(key, depth)
86 |     layers = []
87 |     for i in range(depth):
88 |         act_fn_l = nn.Identity() if i == 0 else get_act_fn(act_fn)
89 |         _in = input_dim if i == 0 else width
90 |         _out = output_dim if (i + 1) == depth else width
91 | 
92 |         linear = nn.Linear(
93 |             _in,
94 |             _out,
95 |             use_bias=use_bias,
96 |             key=subkeys[i]
97 |         )
98 |         if param_type == "mupc":
99 |             W = jax.random.normal(subkeys[i], linear.weight.shape)
100 |             linear = tree_at(lambda l: l.weight, linear, W)
101 | 
102 |         layers.append(
103 |             nn.Sequential(
104 |                 [nn.Lambda(act_fn_l), linear]
105 |             )
106 |         )
107 | 
108 |     return layers
109 | 
110 | 
111 | def make_skip_model(depth: int) -> PyTree[Callable]:
112 |     """Creates a residual network with one-layer skip connections at every layer
113 |     except from the input to the next layer and from the penultimate layer to
114 |     the output.
115 | 
116 |     This is used for compatibility with the [μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))
117 |     parameterisation when `param_type = "mupc"` in functions including
118 |     [`jpc.init_activities_with_ffwd()`](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_ffwd),
119 |     [`jpc.update_activities()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_activities),
120 |     and [`jpc.update_params()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_params).
121 |     """
122 |     skips = [None] * depth
123 |     for l in range(1, depth-1):
124 |         skips[l] = nn.Lambda(nn.Identity())
125 | 
126 |     return skips
127 | 
128 | 
129 | def mse_loss(preds: ArrayLike, labels: ArrayLike) -> Scalar:
130 |     return 0.5 * jnp.mean((labels - preds)**2)
131 | 
132 | 
133 | def cross_entropy_loss(logits: ArrayLike, labels: ArrayLike) -> Scalar:
134 |     # log_softmax is numerically stabler than taking log(softmax(logits))
135 |     log_probs = jax.nn.log_softmax(logits, axis=-1)
136 |     return - jnp.mean(jnp.sum(labels * log_probs, axis=-1))
137 | 
138 | 
139 | def compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Scalar:
140 |     return jnp.mean(
141 |         jnp.argmax(truths, axis=1) == jnp.argmax(preds, axis=1)
142 |     ) * 100
143 | 
144 | 
145 | def get_t_max(activities_iters: PyTree[Array]) -> Array:
146 |     return jnp.argmax(activities_iters[0][:, 0, 0]) - 1
147 | 
148 | 
149 | def compute_infer_energies(
150 |     params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
151 |     activities_iters: PyTree[Array],
152 |     t_max: Array,
153 |     y: ArrayLike,
154 |     *,
155 |     x: Optional[ArrayLike] = None,
156 |     loss: str = "mse",
157 |     param_type: str = "sp",
158 |     weight_decay: Scalar = 0.,
159 |     spectral_penalty: Scalar = 0.,
160 |     activity_decay: Scalar = 0.
161 | ) -> PyTree[Scalar]:
162 |     """Calculates layer energies during predictive coding inference.
163 | 
164 |     **Main arguments:**
165 | 
166 |     - `params`: Tuple with callable model layers and optional skip connections.
167 |     - `activities_iters`: Layer-wise activities at every inference iteration.
168 |         Note that each set of activities will have 4096 steps as the first
169 |         dimension (diffrax default).
170 |     - `t_max`: Maximum number of inference iterations to compute energies for.
171 |     - `y`: Observation or target of the generative model.
172 | 
173 |     **Other arguments:**
174 | 
175 |     - `x`: Optional prior of the generative model.
176 |     - `loss`: Loss function to use at the output layer (mean squared error
177 |         `"mse"` vs cross-entropy `"ce"`).
178 | - `param_type`: Determines the parameterisation. Options are `"sp"`, 179 | `"mupc"`, or `"ntp"`. 180 | - `weight_decay`: Weight decay for the weights. 181 | - `spectral_penalty`: Spectral penalty for the weights. 182 | - `activity_decay`: Activity decay for the activities. 183 | 184 | **Returns:** 185 | 186 | List of layer-wise energies at every inference iteration. 187 | 188 | """ 189 | model, _ = params 190 | 191 | def loop_body(state): 192 | t, energies_iters = state 193 | 194 | energies = pc_energy_fn( 195 | params=params, 196 | activities=tree_map(lambda act: act[t], activities_iters), 197 | y=y, 198 | x=x, 199 | loss=loss, 200 | param_type=param_type, 201 | weight_decay=weight_decay, 202 | spectral_penalty=spectral_penalty, 203 | activity_decay=activity_decay, 204 | record_layers=True 205 | ) 206 | energies_iters = energies_iters.at[:, t].set(energies) 207 | return t + 1, energies_iters 208 | 209 | # for memory reasons, we set 500 as the max iters to record 210 | energies_iters = jnp.zeros((len(model), 500)) 211 | _, energies_iters = jax.lax.while_loop( 212 | lambda state: state[0] < t_max, 213 | loop_body, 214 | (0, energies_iters) 215 | ) 216 | return energies_iters[::-1, :] 217 | 218 | 219 | def compute_activity_norms(activities: PyTree[Array]) -> Array: 220 | """Calculates $\ell^2$ norm of activities at each layer.""" 221 | return jnp.array([ 222 | jnp.mean( 223 | jnp.linalg.norm( 224 | a, 225 | axis=-1, 226 | ord=2 227 | ) 228 | ) for a in tree_leaves(activities) 229 | ]) 230 | 231 | 232 | def compute_param_norms(params): 233 | """Calculates $\ell^2$ norm of all model parameters.""" 234 | def process_model_params(model_params): 235 | norms = [] 236 | for p in tree_leaves(model_params): 237 | if p is None or isinstance(p, PjitFunction): 238 | norms.append(0.) 239 | elif callable(p) and not hasattr(p, 'shape'): 240 | # Skip callable functions (like Lambda-wrapped activations) that don't have shape 241 | # But keep arrays which might be callable in some JAX contexts 242 | norms.append(0.) 243 | else: 244 | try: 245 | # Check if p is a JAX array-like object 246 | if hasattr(p, 'shape') and hasattr(p, 'dtype'): 247 | norms.append(jnp.linalg.norm(jnp.ravel(p), ord=2)) 248 | else: 249 | norms.append(0.) 250 | except (TypeError, AttributeError): 251 | # If ravel fails, it's not an array 252 | norms.append(0.) 
253 |         return jnp.array(norms)
254 | 
255 |     model_params, skip_model_params = params
256 |     model_norms = process_model_params(model_params)
257 |     skip_model_norms = (process_model_params(skip_model_params) if
258 |                         skip_model_params is not None else None)
259 | 
260 |     return model_norms, skip_model_norms
261 | 
--------------------------------------------------------------------------------
/jpc/_test.py:
--------------------------------------------------------------------------------
1 | """Utility functions to test predictive coding networks."""
2 | 
3 | import equinox as eqx
4 | from jpc import (
5 |     init_activities_from_normal,
6 |     init_activities_with_ffwd,
7 |     init_activities_with_amort,
8 |     mse_loss,
9 |     cross_entropy_loss,
10 |     compute_accuracy,
11 |     solve_inference,
12 |     _check_param_type
13 | )
14 | from diffrax import (
15 |     AbstractSolver,
16 |     AbstractStepSizeController,
17 |     Heun,
18 |     PIDController
19 | )
20 | from jaxtyping import PRNGKeyArray, PyTree, ArrayLike, Array, Scalar
21 | from typing import Callable, Tuple, Optional
22 | 
23 | 
24 | @eqx.filter_jit
25 | def test_discriminative_pc(
26 |     model: PyTree[Callable],
27 |     output: ArrayLike,
28 |     input: ArrayLike,
29 |     *,
30 |     skip_model: Optional[PyTree[Callable]] = None,
31 |     loss: str = "mse",
32 |     param_type: str = "sp"
33 | ) -> Tuple[Scalar, Scalar]:
34 |     """Computes test metrics for a discriminative predictive coding network.
35 | 
36 |     **Main arguments:**
37 | 
38 |     - `model`: List of callable model (e.g. neural network) layers.
39 |     - `output`: Observation or target of the generative model.
40 |     - `input`: Prior of the generative model.
41 | 
42 |     **Other arguments:**
43 | 
44 |     - `skip_model`: Optional skip connection model.
45 |     - `loss`: Loss function to use at the output layer. Options are mean squared
46 |         error `"mse"` (default) or cross-entropy `"ce"`.
47 |     - `param_type`: Determines the parameterisation. Options are `"sp"`
48 |         (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))),
49 |         or `"ntp"` (neural tangent parameterisation). See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings)
50 |         for the specific scalings of these different parameterisations. Defaults
51 |         to `"sp"`.
52 | 
53 |     **Returns:**
54 | 
55 |     Test loss and accuracy of output predictions.
56 | 
57 |     """
58 |     _check_param_type(param_type)
59 | 
60 |     preds = init_activities_with_ffwd(
61 |         model=model,
62 |         input=input,
63 |         skip_model=skip_model,
64 |         param_type=param_type
65 |     )[-1]
66 | 
67 |     if loss == "mse":
68 |         loss = mse_loss(preds, output)
69 |     elif loss == "ce":
70 |         loss = cross_entropy_loss(preds, output)
71 | 
72 |     acc = compute_accuracy(output, preds)
73 |     return loss, acc
74 | 
75 | 
76 | @eqx.filter_jit
77 | def test_generative_pc(
78 |     model: PyTree[Callable],
79 |     output: ArrayLike,
80 |     input: ArrayLike,
81 |     key: PRNGKeyArray,
82 |     layer_sizes: PyTree[int],
83 |     batch_size: int,
84 |     *,
85 |     skip_model: Optional[PyTree[Callable]] = None,
86 |     loss_id: str = "mse",
87 |     param_type: str = "sp",
88 |     sigma: Scalar = 0.05,
89 |     ode_solver: AbstractSolver = Heun(),
90 |     max_t1: int = 500,
91 |     dt: Optional[Scalar | int] = None,
92 |     stepsize_controller: AbstractStepSizeController = PIDController(
93 |         rtol=1e-3, atol=1e-3
94 |     ),
95 |     weight_decay: Scalar = 0.,
96 |     spectral_penalty: Scalar = 0.,
97 |     activity_decay: Scalar = 0.
98 | ) -> Tuple[Scalar, Array]: 99 | """Computes test metrics for a generative predictive coding network. 100 | 101 | Gets output predictions (e.g. of an image given a label) with a feedforward 102 | pass and calculates accuracy of inferred input (e.g. of a label given an 103 | image). 104 | 105 | **Main arguments:** 106 | 107 | - `model`: List of callable model (e.g. neural network) layers. 108 | - `output`: Observation or target of the generative model. 109 | - `input`: Prior of the generative model. 110 | - `key`: `jax.random.PRNGKey` for random initialisation of activities. 111 | - `layer_sizes`: Dimension of all layers (input, hidden and output). 112 | - `batch_size`: Dimension of data batch for activity initialisation. 113 | 114 | **Other arguments:** 115 | 116 | - `skip_model`: Optional skip connection model. 117 | - `loss_id`: Loss function to use at the output layer. Options are mean squared 118 | error `"mse"` (default) or cross-entropy `"ce"`. 119 | - `param_type`: Determines the parameterisation. Options are `"sp"` 120 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))), 121 | or `"ntp"` (neural tangent parameterisation). See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings) 122 | for the specific scalings of these different parameterisations. Defaults 123 | to `"sp"`. 124 | - `sigma`: Standard deviation for Gaussian to sample activities from. 125 | Defaults to 5e-2. 126 | - `ode_solver`: [diffrax ODE solver](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/) 127 | to be used. Default is [`Heun`](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/#diffrax.Heun), 128 | a 2nd order explicit Runge--Kutta method. 129 | - `max_t1`: Maximum end of integration region (500 by default). 130 | - `dt`: Integration step size. Defaults to None since the default 131 | `stepsize_controller` will automatically determine it. 132 | - `stepsize_controller`: [diffrax controller](https://docs.kidger.site/diffrax/api/stepsize_controller/) 133 | for step size integration. Defaults to [`PIDController`](https://docs.kidger.site/diffrax/api/stepsize_controller/#diffrax.PIDController). 134 | Note that the relative and absolute tolerances of the controller will 135 | also determine the steady state to terminate the solver. 136 | - `weight_decay`: Weight decay for the weights (0 by default). 137 | - `spectral_penalty`: Weight spectral penalty of the form 138 | $||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2$ (0 by default). 139 | - `activity_decay`: Activity decay for the activities (0 by default). 140 | 141 | **Returns:** 142 | 143 | Accuracy and output predictions. 
144 | 
145 |     """
146 |     _check_param_type(param_type)
147 | 
148 |     params = model, skip_model
149 |     activities = init_activities_from_normal(
150 |         key=key,
151 |         layer_sizes=layer_sizes,
152 |         mode="unsupervised",
153 |         batch_size=batch_size,
154 |         sigma=sigma
155 |     )
156 |     input_preds = solve_inference(
157 |         params=params,
158 |         activities=activities,
159 |         output=output,
160 |         loss_id=loss_id,
161 |         param_type=param_type,
162 |         solver=ode_solver,
163 |         max_t1=max_t1,
164 |         dt=dt,
165 |         stepsize_controller=stepsize_controller,
166 |         weight_decay=weight_decay,
167 |         spectral_penalty=spectral_penalty,
168 |         activity_decay=activity_decay
169 |     )[0][0]
170 |     input_acc = compute_accuracy(input, input_preds)
171 |     output_preds = init_activities_with_ffwd(
172 |         model=model,
173 |         input=input,
174 |         skip_model=skip_model,
175 |         param_type=param_type
176 |     )[-1]
177 |     return input_acc, output_preds
178 | 
179 | 
180 | @eqx.filter_jit
181 | def test_hpc(
182 |     generator: PyTree[Callable],
183 |     amortiser: PyTree[Callable],
184 |     output: ArrayLike,
185 |     input: ArrayLike,
186 |     key: PRNGKeyArray,
187 |     layer_sizes: PyTree[int],
188 |     batch_size: int,
189 |     sigma: Scalar = 0.05,
190 |     ode_solver: AbstractSolver = Heun(),
191 |     max_t1: int = 500,
192 |     dt: Optional[Scalar | int] = None,
193 |     stepsize_controller: AbstractStepSizeController = PIDController(
194 |         rtol=1e-3, atol=1e-3
195 |     )
196 | ) -> Tuple[Scalar, Scalar, Scalar, Array]:
197 |     """Computes test metrics for hybrid predictive coding trained in a supervised manner.
198 | 
199 |     Calculates the input accuracy of (i) the amortiser, (ii) the generator,
200 |     and (iii) the hybrid model (amortiser + generator). Also returns output
201 |     predictions (e.g. of an image given a label) with a feedforward pass of
202 |     the generator.
203 | 
204 |     !!! note
205 | 
206 |         The input and output of the generator are the output and input of the
207 |         amortiser, respectively.
208 | 
209 |     **Main arguments:**
210 | 
211 |     - `generator`: List of callable layers for the generative model.
212 |     - `amortiser`: List of callable layers for the model amortising the
213 |       inference of the `generator`.
214 |     - `output`: Observation or target of the generative model.
215 |     - `input`: Prior of the generator and target for the amortiser.
216 |     - `key`: `jax.random.PRNGKey` for random initialisation of activities.
217 |     - `layer_sizes`: Dimensions of all layers (input, hidden and output).
218 |     - `batch_size`: Dimension of data batch for initialisation of activities.
219 | 
220 |     **Other arguments:**
221 | 
222 |     - `sigma`: Standard deviation for Gaussian to sample activities from.
223 |       Defaults to 5e-2.
224 |     - `ode_solver`: [diffrax ODE solver](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/)
225 |       to be used. Default is [`Heun`](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/#diffrax.Heun),
226 |       a 2nd-order explicit Runge-Kutta method.
227 |     - `max_t1`: Maximum end of integration region (500 by default).
228 |     - `dt`: Integration step size. Defaults to `None` since the default
229 |       `stepsize_controller` will automatically determine it.
230 |     - `stepsize_controller`: [diffrax controller](https://docs.kidger.site/diffrax/api/stepsize_controller/)
231 |       for step size integration. Defaults to [`PIDController`](https://docs.kidger.site/diffrax/api/stepsize_controller/#diffrax.PIDController).
232 |       Note that the relative and absolute tolerances of the controller will
233 |       also determine the steady state used to terminate the solver.
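
    **Example** (a minimal usage sketch; the `jpc.make_mlp` arguments are
    assumed for illustration, with the amortiser mirroring the generator in
    the reverse direction):

    ```python
    import jax
    import jax.numpy as jnp
    import jpc

    key = jax.random.PRNGKey(0)
    gen_key, amort_key, test_key = jax.random.split(key, 3)

    # Generator maps a 10-d label prior to a 784-d observation;
    # the amortiser runs in the opposite direction.
    layer_sizes = [10, 300, 300, 784]
    generator = jpc.make_mlp(
        gen_key, input_dim=10, width=300, depth=3, output_dim=784, act_fn="tanh"
    )
    amortiser = jpc.make_mlp(
        amort_key, input_dim=784, width=300, depth=3, output_dim=10, act_fn="tanh"
    )

    labels = jax.nn.one_hot(jnp.zeros(64, dtype=int), 10)
    imgs = jnp.ones((64, 784))

    amort_acc, hpc_acc, gen_acc, output_preds = jpc.test_hpc(
        generator=generator,
        amortiser=amortiser,
        output=imgs,
        input=labels,
        key=test_key,
        layer_sizes=layer_sizes,
        batch_size=64
    )
    ```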
234 | 
235 |     **Returns:**
236 | 
237 |     Accuracies of all models and output predictions.
238 | 
239 |     """
240 |     gen_params = (generator, None)
241 |     amort_activities = init_activities_with_amort(
242 |         amortiser=amortiser,
243 |         generator=generator,
244 |         input=output
245 |     )
246 |     amort_preds = amort_activities[0]
247 |     hpc_preds = solve_inference(
248 |         params=gen_params,
249 |         activities=amort_activities,
250 |         output=output,
251 |         solver=ode_solver,
252 |         max_t1=max_t1,
253 |         dt=dt,
254 |         stepsize_controller=stepsize_controller
255 |     )[0][0]
256 |     activities = init_activities_from_normal(
257 |         key=key,
258 |         layer_sizes=layer_sizes,
259 |         mode="unsupervised",
260 |         batch_size=batch_size,
261 |         sigma=sigma
262 |     )
263 |     gen_preds = solve_inference(
264 |         params=gen_params,
265 |         activities=activities,
266 |         output=output,
267 |         solver=ode_solver,
268 |         max_t1=max_t1,
269 |         dt=dt,
270 |         stepsize_controller=stepsize_controller
271 |     )[0][0]
272 |     amort_acc = compute_accuracy(input, amort_preds)
273 |     hpc_acc = compute_accuracy(input, hpc_preds)
274 |     gen_acc = compute_accuracy(input, gen_preds)
275 |     output_preds = init_activities_with_ffwd(
276 |         model=generator,
277 |         input=input,
278 |         skip_model=None
279 |     )[-1]
280 |     return amort_acc, hpc_acc, gen_acc, output_preds
281 | 
--------------------------------------------------------------------------------
/experiments/library_paper/train_mlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | 
5 | import jax
6 | import equinox as eqx
7 | import jpc
8 | import optax
9 | from diffrax import PIDController, ConstantStepSize
10 | 
11 | from utils import (
12 |     setup_mlp_experiment,
13 |     get_ode_solver,
14 |     set_seed
15 | )
16 | from plotting import (
17 |     plot_loss,
18 |     plot_loss_and_accuracy,
19 |     plot_runtimes,
20 |     plot_norms
21 | )
22 | from experiments.datasets import get_dataloaders
23 | 
24 | 
25 | def evaluate(model, test_loader):
26 |     avg_test_loss, avg_test_acc = 0, 0
27 |     for batch_id, (img_batch, label_batch) in enumerate(test_loader):
28 |         img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
29 | 
30 |         test_loss, test_acc = jpc.test_discriminative_pc(
31 |             model=model,
32 |             output=label_batch,
33 |             input=img_batch
34 |         )
35 |         avg_test_loss += test_loss
36 |         avg_test_acc += test_acc
37 | 
38 |     return avg_test_loss / len(test_loader), avg_test_acc / len(test_loader)
39 | 
40 | 
41 | def train_mlp(
42 |         seed,
43 |         dataset,
44 |         width,
45 |         n_hidden,
46 |         act_fn,
47 |         max_t1,
48 |         activity_lr,
49 |         param_lr,
50 |         batch_size,
51 |         activity_optim_id,
52 |         max_epochs,
53 |         test_every,
54 |         save_dir
55 | ):
56 |     set_seed(seed)
57 |     os.makedirs(save_dir, exist_ok=True)
58 | 
59 |     key = jax.random.PRNGKey(seed)
60 |     model = jpc.make_mlp(
61 |         key,
62 |         input_dim=784,
63 |         width=width,
64 |         depth=n_hidden+1,
65 |         output_dim=10,
66 |         act_fn=act_fn
67 |     )
68 | 
69 |     param_optim = optax.adam(param_lr)
70 |     param_opt_state = param_optim.init(
71 |         (eqx.filter(model, eqx.is_array), None)
72 |     )
73 |     train_loader, test_loader = get_dataloaders(dataset, batch_size)
74 | 
75 |     train_losses = []
76 |     test_losses, test_accs = [], []
77 |     activity_norms, param_norms, param_grad_norms = [], [], []
78 |     inference_runtimes = []
79 | 
80 |     if activity_optim_id != "SGD":
81 |         stepsize_controller = ConstantStepSize() if (
82 |             activity_optim_id == "Euler"
83 |         ) else PIDController(rtol=1e-3, atol=1e-3)
84 |         ode_solver = get_ode_solver(activity_optim_id)
85 | 
86 |     elif activity_optim_id == "SGD":
87 |         activity_optim = optax.sgd(activity_lr)
88 | 
89 |     global_batch_id = 0
90 |     for epoch in range(1, max_epochs + 1):
91 |         print(f"\nEpoch {epoch}\n-------------------------------")
92 | 
93 |         for batch_id, (img_batch, label_batch) in enumerate(train_loader):
94 |             img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
95 | 
96 |             if activity_optim_id != "SGD":
97 |                 result = jpc.make_pc_step(
98 |                     model,
99 |                     param_optim,
100 |                     param_opt_state,
101 |                     output=label_batch,
102 |                     input=img_batch,
103 |                     ode_solver=ode_solver,
104 |                     max_t1=max_t1,
105 |                     dt=activity_lr,
106 |                     stepsize_controller=stepsize_controller,
107 |                     activity_norms=True,
108 |                     param_norms=True,
109 |                     grad_norms=True
110 |                 )
111 |                 model, param_opt_state = result["model"], result["opt_state"]
112 |                 train_loss = result["loss"]
113 |                 activity_norms.append(result["activity_norms"][:-1])
114 |                 param_norms.append(
115 |                     [p for i, p in enumerate(result["model_param_norms"]) if i % 2 == 0]  # weight entries assumed at even indices
116 |                 )
117 |                 param_grad_norms.append(
118 |                     [p for i, p in enumerate(result["model_grad_norms"]) if i % 2 == 0]  # weight entries assumed at even indices
119 |                 )
120 | 
121 |             elif activity_optim_id == "SGD":
122 |                 activities = jpc.init_activities_with_ffwd(
123 |                     model=model,
124 |                     input=img_batch
125 |                 )
126 |                 activity_opt_state = activity_optim.init(activities)
127 |                 train_loss = jpc.mse_loss(activities[-1], label_batch)
128 | 
129 |                 for _ in range(max_t1):
130 |                     activity_update_result = jpc.update_activities(
131 |                         params=(model, None),
132 |                         activities=activities,
133 |                         optim=activity_optim,
134 |                         opt_state=activity_opt_state,
135 |                         output=label_batch,
136 |                         input=img_batch
137 |                     )
138 |                     activities = activity_update_result["activities"]
139 |                     activity_optim = activity_update_result["activity_optim"]
140 |                     activity_opt_state = activity_update_result["activity_opt_state"]
141 | 
142 |                 param_update_result = jpc.update_params(
143 |                     params=(model, None),
144 |                     activities=activities,
145 |                     optim=param_optim,
146 |                     opt_state=param_opt_state,
147 |                     output=label_batch,
148 |                     input=img_batch
149 |                 )
150 |                 model = param_update_result["model"]
151 |                 param_grads = param_update_result["param_grads"]
152 |                 param_optim = param_update_result["param_optim"]
153 |                 param_opt_state = param_update_result["param_opt_state"]
154 | 
155 |                 activity_norms.append(jpc.compute_activity_norms(activities[:-1]))
156 |                 param_norms.append(jpc.compute_param_norms((model, None))[0])
157 |                 param_grad_norms.append(jpc.compute_param_norms(param_grads)[0])
158 | 
159 |             if activity_optim_id != "SGD":
160 |                 activities0 = jpc.init_activities_with_ffwd(model, img_batch)
161 |                 start_time = time.time()
162 |                 jax.block_until_ready(
163 |                     jpc.solve_inference(
164 |                         (model, None),
165 |                         activities0,
166 |                         output=label_batch,
167 |                         input=img_batch,
168 |                         solver=ode_solver,
169 |                         max_t1=max_t1,
170 |                         dt=activity_lr,
171 |                         stepsize_controller=stepsize_controller
172 |                     )
173 |                 )
174 |                 end_time = time.time()
175 | 
176 |             elif activity_optim_id == "SGD":
177 |                 activities = jpc.init_activities_with_ffwd(
178 |                     model=model,
179 |                     input=img_batch
180 |                 )
181 |                 activity_opt_state = activity_optim.init(activities)
182 |                 start_time = time.time()
183 |                 for t in range(max_t1):
184 |                     jax.block_until_ready(
185 |                         jpc.update_activities(
186 |                             (model, None),
187 |                             activities,
188 |                             activity_optim,
189 |                             activity_opt_state,
190 |                             label_batch,
191 |                             img_batch
192 |                         )
193 |                     )
194 |                 end_time = time.time()
195 | 
196 |             train_losses.append(train_loss)
197 |             inference_runtimes.append((end_time - start_time) * 1000)
198 |             global_batch_id += 1
199 | 
200 |             if global_batch_id % test_every == 0:
201 |                 print(f"Train loss: {train_loss:.7f} [{batch_id * len(img_batch)}/{len(train_loader.dataset)}]")
202 | 
203 |                 avg_test_loss, avg_test_acc = evaluate(model, test_loader)
204 |                 test_losses.append(avg_test_loss)
205 |                 test_accs.append(avg_test_acc)
206 |                 print(f"Avg test accuracy: {avg_test_acc:.4f}\n")
207 | 
208 |     plot_loss(
209 |         loss=train_losses,
210 |         yaxis_title="Train loss",
211 |         xaxis_title="Iteration",
212 |         save_path=f"{save_dir}/train_losses.pdf"
213 |     )
214 |     plot_loss_and_accuracy(
215 |         loss=test_losses,
216 |         accuracy=test_accs,
217 |         mode="test",
218 |         xaxis_title="Training iteration",
219 |         save_path=f"{save_dir}/test_losses_and_accs.pdf"
220 |     )
221 |     plot_norms(
222 |         norms=param_norms,
223 |         norm_type="param",
224 |         save_path=f"{save_dir}/param_norms.pdf"
225 |     )
226 |     plot_norms(
227 |         norms=param_grad_norms,
228 |         norm_type="param_grad",
229 |         save_path=f"{save_dir}/param_grad_norms.pdf"
230 |     )
231 |     plot_norms(
232 |         norms=activity_norms,
233 |         norm_type="activity",
234 |         save_path=f"{save_dir}/activity_norms.pdf"
235 |     )
236 |     plot_runtimes(
237 |         runtimes=inference_runtimes,
238 |         save_path=f"{save_dir}/inference_runtimes.pdf"
239 |     )
240 | 
241 |     np.save(f"{save_dir}/batch_train_losses.npy", train_losses)
242 |     np.save(f"{save_dir}/test_losses.npy", test_losses)
243 |     np.save(f"{save_dir}/test_accs.npy", test_accs)
244 | 
245 |     np.save(f"{save_dir}/activity_norms.npy", activity_norms)
246 |     np.save(f"{save_dir}/param_norms.npy", param_norms)
247 |     np.save(f"{save_dir}/param_grad_norms.npy", param_grad_norms)
248 | 
249 |     np.save(f"{save_dir}/inference_runtimes.npy", inference_runtimes)
250 | 
251 | 
252 | if __name__ == "__main__":
253 |     RESULTS_DIR = "mlp_results"
254 |     DATASETS = ["MNIST", "Fashion-MNIST"]
255 |     N_SEEDS = 3
256 | 
257 |     WIDTH = 300
258 |     N_HIDDENS = [3, 5, 10]
259 |     ACT_FN = "tanh"
260 | 
261 |     ACTIVITY_OPTIMS_ID = ["Euler", "Heun"]
262 |     MAX_T1S = [5, 10, 20, 50, 100, 200, 500]
263 |     ACTIVITY_LRS = [5e-1, 1e-1, 5e-2]
264 | 
265 |     PARAM_LR = 1e-3
266 |     BATCH_SIZE = 64
267 |     MAX_EPOCHS = 1
268 |     TEST_EVERY = 100
269 | 
270 |     for dataset in DATASETS:
271 |         for n_hidden in N_HIDDENS:
272 |             for activity_optim_id in ACTIVITY_OPTIMS_ID:
273 |                 for max_t1 in MAX_T1S:
274 |                     for activity_lr in ACTIVITY_LRS:
275 |                         for seed in range(N_SEEDS):
276 |                             save_dir = setup_mlp_experiment(
277 |                                 results_dir=RESULTS_DIR,
278 |                                 dataset=dataset,
279 |                                 width=WIDTH,
280 |                                 n_hidden=n_hidden,
281 |                                 act_fn=ACT_FN,
282 |                                 max_t1=max_t1,
283 |                                 activity_lr=activity_lr,
284 |                                 param_lr=PARAM_LR,
285 |                                 activity_optim_id=activity_optim_id,
286 |                                 seed=seed
287 |                             )
288 |                             train_mlp(
289 |                                 seed=seed,
290 |                                 dataset=dataset,
291 |                                 width=WIDTH,
292 |                                 n_hidden=n_hidden,
293 |                                 act_fn=ACT_FN,
294 |                                 max_t1=max_t1,
295 |                                 activity_lr=activity_lr,
296 |                                 param_lr=PARAM_LR,
297 |                                 batch_size=BATCH_SIZE,
298 |                                 activity_optim_id=activity_optim_id,
299 |                                 max_epochs=MAX_EPOCHS,
300 |                                 test_every=TEST_EVERY,
301 |                                 save_dir=save_dir
302 |                             )
303 | 
--------------------------------------------------------------------------------