├── docs
│   ├── FAQs.md
│   ├── examples
│   ├── .htaccess
│   ├── _static
│   │   ├── favicon.png
│   │   ├── mathjax.js
│   │   ├── custom_css.css
│   │   ├── logo-light.svg
│   │   └── logo-dark.svg
│   ├── api
│   │   ├── linear_net_theoretical_energy.png
│   │   ├── Continuous-time Inference.md
│   │   ├── Training.md
│   │   ├── Utils.md
│   │   ├── Testing.md
│   │   ├── Discrete updates.md
│   │   ├── Initialisation.md
│   │   ├── Theoretical tools.md
│   │   ├── Energy functions.md
│   │   └── Gradients.md
│   ├── _overrides
│   │   └── partials
│   │       ├── logo.html
│   │       └── source.html
│   ├── requirements.txt
│   ├── advanced_usage.md
│   ├── basic_usage.md
│   └── index.md
├── experiments
│   ├── library_paper
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── test_theory_energies.py
│   │   └── train_mlp.py
│   ├── mupc_paper
│   │   ├── __init__.py
│   │   ├── spotlight_fig.png
│   │   ├── requirements.txt
│   │   ├── README.md
│   │   ├── test_energy_theory.py
│   │   ├── analyse_activity_hessian.py
│   │   ├── test_mlp_fwd_pass.py
│   │   └── train_bpn.py
│   └── datasets.py
├── tests
│   ├── __init__.py
│   ├── test_errors.py
│   ├── README.md
│   ├── conftest.py
│   ├── test_init.py
│   ├── test_infer.py
│   ├── test_test_functions.py
│   ├── test_analytical.py
│   ├── test_utils.py
│   └── test_train.py
├── .gitattributes
├── .gitignore
├── jpc
│   ├── _core
│   │   ├── _errors.py
│   │   ├── __init__.py
│   │   ├── _infer.py
│   │   └── _init.py
│   ├── __init__.py
│   ├── _utils.py
│   └── _test.py
├── .github
│   ├── workflows
│   │   ├── tests.yml
│   │   └── build_docs.yml
│   └── logo-with-background.svg
├── LICENSE
├── pyproject.toml
├── mkdocs.yml
└── README.md
/docs/FAQs.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/examples:
--------------------------------------------------------------------------------
1 | ../examples
--------------------------------------------------------------------------------
/experiments/library_paper/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/mupc_paper/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/.htaccess:
--------------------------------------------------------------------------------
1 | ErrorDocument 404 /jpc/404.html
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Test suite for the jpc library."""
2 |
3 |
--------------------------------------------------------------------------------
/docs/_static/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thebuckleylab/jpc/HEAD/docs/_static/favicon.png
--------------------------------------------------------------------------------
/experiments/mupc_paper/spotlight_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thebuckleylab/jpc/HEAD/experiments/mupc_paper/spotlight_fig.png
--------------------------------------------------------------------------------
/docs/api/linear_net_theoretical_energy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thebuckleylab/jpc/HEAD/docs/api/linear_net_theoretical_energy.png
--------------------------------------------------------------------------------
/docs/_overrides/partials/logo.html:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Override jupyter in GitHub language stats for a more accurate estimate of repo code languages
2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
3 | *.ipynb linguist-generated
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/data/
3 | **/datasets/
4 | examples/data/
5 | *.egg-info/
6 | *.pdf
7 | *.npy
8 | *.DS_Store
9 | *.pkl
10 | *.iml
11 | *.xml
12 | __pycache__/
13 | .vscode/
14 | .all_objects.cache
15 | site/
16 | venv*/
17 | *.log
18 | .pytest_cache/
19 | .coverage
20 | coverage.json
21 | htmlcov/
22 |
--------------------------------------------------------------------------------
/docs/api/Continuous-time Inference.md:
--------------------------------------------------------------------------------
1 | # Continuous-time inference
2 |
3 | The inference or activity dynamics of PC networks can be solved in either
4 | discrete or continuous time. [jpc.solve_inference()](https://thebuckleylab.github.io/jpc/api/Continuous-time%20Inference/#jpc.solve_inference)
5 | leverages ODE solvers to integrate the continuous-time dynamics.
6 |
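A minimal usage sketch, mirroring the custom training step in the advanced
usage guide (it assumes a PC-compatible `model` and data `x`/`y` are already
defined):
```py
import jpc

# initialise the activities with a feedforward pass
activities = jpc.init_activities_with_ffwd(model=model, input=x)

# integrate the continuous-time inference dynamics to equilibrium
equilibrated_activities = jpc.solve_inference(
    params=(model, None),
    activities=activities,
    output=y,
    input=x
)
```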
7 | ::: jpc.solve_inference
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | jax==0.5.2 # to avoid conflict for `jax.core.ClosedJaxpr`
2 | hippogriffe==0.2.2
3 | mkdocs==1.6.1
4 | mkdocs-include-exclude-files==0.1.0
5 | mkdocs-ipynb==0.1.1
6 | mkdocs-material==9.6.7
7 | mkdocstrings==0.28.3
8 | mkdocstrings-python==1.16.8
9 | pymdown-extensions==10.14.3
10 |
11 | # Dependencies of JPC itself.
12 | # Always use most up-to-date versions.
13 | jax[cpu]
--------------------------------------------------------------------------------
/docs/_static/mathjax.js:
--------------------------------------------------------------------------------
1 | window.MathJax = {
2 | tex: {
3 | inlineMath: [["\\(", "\\)"]],
4 | displayMath: [["\\[", "\\]"]],
5 | processEscapes: true,
6 | processEnvironments: true
7 | },
8 | options: {
9 | ignoreHtmlClass: ".*|",
10 | processHtmlClass: "arithmatex"
11 | }
12 | };
13 |
14 | document$.subscribe(() => {
15 | MathJax.startup.output.clearCache()
16 | MathJax.typesetClear()
17 | MathJax.texReset()
18 | MathJax.typesetPromise()
19 | })
--------------------------------------------------------------------------------
/jpc/_core/_errors.py:
--------------------------------------------------------------------------------
1 | _PARAM_TYPES = ["sp", "mupc", "ntp"]
2 |
3 |
4 | def _check_param_type(param_type):
5 | if param_type not in _PARAM_TYPES:
6 | raise ValueError(
7 | 'Invalid parameterisation. Options are `"sp"` (standard '
8 | 'parameterisation), `"mupc"` (μPC), or `"ntp"` (neural tangent '
9 | 'parameterisation). See `_get_param_scalings()` (https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings) '
10 | 'for the specific scalings of these different parameterisations.'
11 | )
12 |
--------------------------------------------------------------------------------
/docs/api/Training.md:
--------------------------------------------------------------------------------
1 | # Training
2 |
3 | JPC provides two convenience functions to update the parameters of any
4 | PC-compatible model with PC:
5 |
6 | * [jpc.make_pc_step()](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_pc_step) to
7 | perform an update using standard PC, and
8 | * [jpc.make_hpc_step()](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_hpc_step)
9 | to use hybrid PC ([Tschantz et al., 2023](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011280)).
10 |
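For example, a single PC training step (a minimal sketch assuming a
PC-compatible `model`, an optax optimiser `optim` with state `opt_state`, and
data `x`/`y` set up as in the basic usage guide):
```py
import jpc

update_result = jpc.make_pc_step(
    model=model,
    optim=optim,
    opt_state=opt_state,
    output=y,
    input=x
)

# updated model and optimiser state
model, opt_state = update_result["model"], update_result["opt_state"]
```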
11 | ::: jpc.make_pc_step
12 |
13 | ---
14 |
15 | ::: jpc.make_hpc_step
--------------------------------------------------------------------------------
/docs/api/Utils.md:
--------------------------------------------------------------------------------
1 | # Utils
2 |
3 | JPC provides several standard utilities for neural network training, including
4 | creation of simple models, losses, and metrics.
5 |
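For example, creating a simple MLP (a minimal sketch; this mirrors how
`jpc.make_mlp()` is called in the test suite):
```py
import jax.random as jr
import jpc

key = jr.PRNGKey(0)
model = jpc.make_mlp(
    key=key,
    input_dim=10,
    width=20,
    depth=3,
    output_dim=5,
    act_fn="relu",
    use_bias=False,
    param_type="sp"
)
```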
6 | ::: jpc.make_mlp
7 |
8 | ---
9 |
10 | ::: jpc.make_skip_model
11 |
12 | ---
13 |
14 | ::: jpc.get_act_fn
15 |
16 | ---
17 |
18 | ::: jpc.mse_loss
19 |
20 | ---
21 |
22 | ::: jpc.cross_entropy_loss
23 |
24 | ---
25 |
26 | ::: jpc.compute_accuracy
27 |
28 | ---
29 |
30 | ::: jpc.get_t_max
31 |
32 | ---
33 |
34 | ::: jpc.compute_infer_energies
35 |
36 | ---
37 |
38 | ::: jpc.compute_activity_norms
39 |
40 | ---
41 |
42 | ::: jpc.compute_param_norms
--------------------------------------------------------------------------------
/docs/api/Testing.md:
--------------------------------------------------------------------------------
1 | # Testing
2 |
3 | JPC provides a few convenience functions to test different types of PC networks (PCNs):
4 |
5 | * [jpc.test_discriminative_pc()](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_discriminative_pc)
6 | for test loss and accuracy of discriminative PCNs;
7 | * [jpc.test_generative_pc()](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_generative_pc)
8 | for accuracy and output predictions of generative PCNs; and
9 | * [jpc.test_hpc()](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_hpc) for accuracy
10 | of all models (amortiser, generator, & hybrid) as well as output predictions.
11 |
12 | ::: jpc.test_discriminative_pc
13 |
14 | ---
15 |
16 | ::: jpc.test_generative_pc
17 |
18 | ---
19 |
20 | ::: jpc.test_hpc
--------------------------------------------------------------------------------
/tests/test_errors.py:
--------------------------------------------------------------------------------
1 | """Tests for error checking functions."""
2 |
3 | import pytest
4 | from jpc._core._errors import _check_param_type
5 |
6 |
7 | def test_check_param_type_valid():
8 | """Test that valid parameter types pass."""
9 | _check_param_type("sp")
10 | _check_param_type("mupc")
11 | _check_param_type("ntp")
12 |
13 |
14 | def test_check_param_type_invalid():
15 | """Test that invalid parameter types raise ValueError."""
16 | with pytest.raises(ValueError, match="Invalid parameterisation"):
17 | _check_param_type("invalid")
18 |
19 | with pytest.raises(ValueError, match="Invalid parameterisation"):
20 | _check_param_type("")
21 |
22 | with pytest.raises(ValueError, match="Invalid parameterisation"):
23 | _check_param_type("ntk")
24 |
--------------------------------------------------------------------------------
/docs/_overrides/partials/source.html:
--------------------------------------------------------------------------------
1 | {% import "partials/language.html" as lang with context %}
2 |
3 |
4 | {% set icon = config.theme.icon.repo or "fontawesome/brands/git-alt" %}
5 | {% include ".icons/" ~ icon ~ ".svg" %}
6 |
7 |
8 | {{ config.repo_name }}
9 |
10 |
11 | {% if config.theme.twitter_url %}
12 |
13 |
14 | {% include ".icons/fontawesome/brands/twitter.svg" %}
15 |
16 |
17 | {{ config.theme.twitter_name }}
18 |
19 |
20 | {% endif %}
--------------------------------------------------------------------------------
/docs/api/Discrete updates.md:
--------------------------------------------------------------------------------
1 | # Discrete updates
2 |
3 | JPC provides access to standard discrete optimisers both to update the parameters
4 | of PC networks ([jpc.update_pc_params](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_pc_params))
5 | and to solve the PC inference or activity dynamics ([jpc.update_pc_activities](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_pc_activities)).
6 | To solve the inference dynamics in continuous time instead, see
7 | [jpc.solve_inference](https://thebuckleylab.github.io/jpc/api/Continuous-time%20Inference/#jpc.solve_inference).
8 |
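For example, one discrete inference (activity) step (a minimal sketch assuming
a PC-compatible `model`, data `x`/`y`, current `activities`, and an optax
optimiser `activity_optim` with state `activity_opt_state`):
```py
import jpc

activity_update_result = jpc.update_pc_activities(
    params=(model, None),
    activities=activities,
    optim=activity_optim,
    opt_state=activity_opt_state,
    output=y,
    input=x
)

# updated activities and optimiser state
activities = activity_update_result["activities"]
activity_opt_state = activity_update_result["opt_state"]
```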
9 | ::: jpc.update_pc_activities
10 |
11 | ---
12 |
13 | ::: jpc.update_pc_params
14 |
15 | ---
16 |
17 | ::: jpc.update_bpc_activities
18 |
19 | ---
20 |
21 | ::: jpc.update_bpc_params
22 |
23 | ---
24 |
25 | ::: jpc.update_epc_errors
26 |
27 | ---
28 |
29 | ::: jpc.update_epc_params
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 |
12 | jobs:
13 | test:
14 | strategy:
15 | matrix:
16 | python-version: ["3.10", "3.11"]
17 | os: [ubuntu-latest]
18 | runs-on: ${{ matrix.os }}
19 | steps:
20 | - name: Checkout code
21 | uses: actions/checkout@v4
22 |
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v5
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 |
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | python -m pip install -e .
32 |
33 | - name: Run tests with coverage
34 | run: |
35 | pytest tests/ --cov=jpc --cov-report=term -v
--------------------------------------------------------------------------------
/.github/workflows/build_docs.yml:
--------------------------------------------------------------------------------
1 | name: Build docs
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | permissions:
9 | contents: write
10 |
11 | jobs:
12 | build:
13 | strategy:
14 | matrix:
15 | python-version: [ 3.11 ]
16 | os: [ ubuntu-latest ]
17 | runs-on: ${{ matrix.os }}
18 | steps:
19 | - name: Checkout code
20 | uses: actions/checkout@v4
21 |
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 |
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install -e .
31 | python -m pip install -r docs/requirements.txt
32 |
33 | - name: Build docs
34 | run: |
35 | mkdocs build
36 |
37 | - name: Deploy docs
38 | run: mkdocs gh-deploy --force
--------------------------------------------------------------------------------
/docs/api/Initialisation.md:
--------------------------------------------------------------------------------
1 | # Initialisation
2 |
3 | JPC provides 4 ways of initialising the activities of a PC network:
4 |
5 | * [jpc.init_activities_with_ffwd()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_ffwd) for a feedforward pass (standard),
6 | * [jpc.init_activities_from_normal()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_from_normal) for random initialisation,
7 | * [jpc.init_activities_with_amort()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_amort) for use of an amortised network, and
8 | * [jpc.init_epc_errors()](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_epc_errors) for zero-initialisation of prediction errors in [ePC](https://arxiv.org/abs/2505.20137).
9 |
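For example, feedforward vs random initialisation (a minimal sketch assuming a
PC-compatible `model`, input `x`, a random `key`, and `layer_sizes` matching
the model's layer widths):
```py
import jpc

# standard feedforward initialisation
activities = jpc.init_activities_with_ffwd(
    model=model,
    input=x,
    param_type="sp"
)

# random Gaussian initialisation
random_activities = jpc.init_activities_from_normal(
    key=key,
    layer_sizes=layer_sizes,
    mode="supervised",
    batch_size=x.shape[0],
    sigma=0.05
)
```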
10 | ::: jpc.init_activities_with_ffwd
11 |
12 | ---
13 |
14 | ::: jpc.init_activities_from_normal
15 |
16 | ---
17 |
18 | ::: jpc.init_activities_with_amort
19 |
20 | ---
21 |
22 | ::: jpc.init_epc_errors
--------------------------------------------------------------------------------
/docs/api/Theoretical tools.md:
--------------------------------------------------------------------------------
1 | # Theoretical tools
2 |
3 | JPC provides the following theoretical tools that can be used to study
4 | **deep linear networks** (DLNs) trained with PC:
5 |
6 | * [jpc.compute_linear_equilib_energy()](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/#jpc.compute_linear_equilib_energy)
7 | to compute the theoretical PC energy at the solution of the activities for DLNs;
8 | * [jpc.compute_linear_activity_hessian()](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/#jpc.compute_linear_activity_hessian)
9 | to compute the theoretical Hessian of the energy with respect to the activities of DLNs; and
10 | * [jpc.compute_linear_activity_solution()](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/#jpc.compute_linear_activity_solution)
11 | to compute the analytical PC inference solution for DLNs.
12 |
13 | ::: jpc.compute_linear_equilib_energy
14 |
15 | ---
16 |
17 | ::: jpc.compute_linear_activity_hessian
18 |
19 | ---
20 |
21 | ::: jpc.compute_linear_activity_solution
--------------------------------------------------------------------------------
/docs/api/Energy functions.md:
--------------------------------------------------------------------------------
1 | # Energy functions
2 |
3 | JPC provides four main PC energy functions:
4 |
5 | * [jpc.pc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.pc_energy_fn)
6 | for standard PC networks,
7 | * [jpc.hpc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.hpc_energy_fn)
8 | for hybrid PC models ([Tschantz et al., 2023](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011280)),
9 | * [jpc.bpc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.bpc_energy_fn)
10 | for bidirectional PC models ([Oliviers et al., 2025](https://arxiv.org/abs/2505.23415)), and
11 | * [jpc.epc_energy_fn()](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc.epc_energy_fn)
12 | for error-reparameterised PC ([Goemaere et al., 2025](https://arxiv.org/abs/2505.20137)).
13 |
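For example, evaluating the standard PC energy at some activities (a sketch:
the `params`/`activities`/`output`/`input` call pattern here is an assumption
carried over from the gradient and update functions, so see the reference
below for the exact signature):
```py
import jpc

# energy of the current activities given data x/y
# NOTE: call pattern assumed from jpc's gradient/update functions
energy = jpc.pc_energy_fn(
    params=(model, None),
    activities=activities,
    output=y,
    input=x
)
```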
14 | ::: jpc.pc_energy_fn
15 |
16 | ---
17 |
18 | ::: jpc.hpc_energy_fn
19 |
20 | ---
21 |
22 | ::: jpc.bpc_energy_fn
23 |
24 | ---
25 |
26 | ::: jpc.epc_energy_fn
27 |
28 | ---
29 |
30 | ::: jpc._get_param_scalings
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 thebuckleylab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/api/Gradients.md:
--------------------------------------------------------------------------------
1 | # Gradients
2 |
3 | !!! info
4 | There are two similar functions to compute the gradient of the energy with
5 | respect to the activities of a standard PC energy: [`jpc.neg_pc_activity_grad()`](https://thebuckleylab.github.io/jpc/api/Gradients/#jpc.neg_pc_activity_grad)
6 | and [`jpc.compute_pc_activity_grad()`](https://thebuckleylab.github.io/jpc/api/Gradients/#jpc.compute_pc_activity_grad).
7 | The first is used by [`jpc.solve_inference()`](https://thebuckleylab.github.io/jpc/api/Continuous-time%20Inference/#jpc.solve_inference)
8 | as gradient flow, while the second is for compatibility with discrete
9 | [optax](https://github.com/google-deepmind/optax) optimisers such as
10 | gradient descent.
11 |
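For instance, a discrete gradient step on the activities with optax might look
as follows (a sketch: the `params`/`activities`/`output`/`input` call pattern
is an assumption carried over from the update functions, so see the reference
below for the exact signature):
```py
import optax
import jpc

activity_optim = optax.sgd(1e-2)
activity_opt_state = activity_optim.init(activities)

# gradient of the energy w.r.t. the activities
# NOTE: call pattern assumed from jpc's update functions
activity_grads = jpc.compute_pc_activity_grad(
    params=(model, None),
    activities=activities,
    output=y,
    input=x
)
updates, activity_opt_state = activity_optim.update(
    activity_grads, activity_opt_state
)
activities = optax.apply_updates(activities, updates)
```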
12 | ::: jpc.neg_pc_activity_grad
13 |
14 | ---
15 |
16 | ::: jpc.compute_pc_activity_grad
17 |
18 | ---
19 |
20 | ::: jpc.compute_bpc_activity_grad
21 |
22 | ---
23 |
24 | ::: jpc.compute_epc_error_grad
25 |
26 | ---
27 |
28 | ::: jpc.compute_pc_param_grads
29 |
30 | ---
31 |
32 | ::: jpc.compute_hpc_param_grads
33 |
34 | ---
35 |
36 | ::: jpc.compute_bpc_param_grads
37 |
38 | ---
39 |
40 | ::: jpc.compute_epc_param_grads
--------------------------------------------------------------------------------
/jpc/_core/__init__.py:
--------------------------------------------------------------------------------
1 | from ._init import (
2 | init_activities_with_ffwd as init_activities_with_ffwd,
3 | init_activities_from_normal as init_activities_from_normal,
4 | init_activities_with_amort as init_activities_with_amort,
5 | init_epc_errors as init_epc_errors
6 | )
7 | from ._energies import (
8 | pc_energy_fn as pc_energy_fn,
9 | hpc_energy_fn as hpc_energy_fn,
10 | bpc_energy_fn as bpc_energy_fn,
11 | epc_energy_fn as epc_energy_fn,
12 | pdm_energy_fn as pdm_energy_fn,
13 | _get_param_scalings as _get_param_scalings
14 | )
15 | from ._grads import (
16 | neg_pc_activity_grad as neg_pc_activity_grad,
17 | compute_pc_activity_grad as compute_pc_activity_grad,
18 | compute_pc_param_grads as compute_pc_param_grads,
19 | compute_hpc_param_grads as compute_hpc_param_grads,
20 | compute_bpc_activity_grad as compute_bpc_activity_grad,
21 | compute_bpc_param_grads as compute_bpc_param_grads,
22 | compute_epc_error_grad as compute_epc_error_grad,
23 | compute_epc_param_grads as compute_epc_param_grads,
24 | compute_pdm_activity_grad as compute_pdm_activity_grad,
25 | compute_pdm_param_grads as compute_pdm_param_grads
26 | )
27 | from ._infer import solve_inference as solve_inference
28 | from ._updates import (
29 | update_pc_activities as update_pc_activities,
30 | update_pc_params as update_pc_params,
31 | update_bpc_activities as update_bpc_activities,
32 | update_bpc_params as update_bpc_params,
33 | update_epc_errors as update_epc_errors,
34 | update_epc_params as update_epc_params,
35 | update_pdm_activities as update_pdm_activities,
36 | update_pdm_params as update_pdm_params
37 | )
38 | from ._analytical import (
39 | compute_linear_equilib_energy as compute_linear_equilib_energy,
40 | compute_linear_activity_hessian as compute_linear_activity_hessian,
41 | compute_linear_activity_solution as compute_linear_activity_solution
42 | )
43 | from ._errors import _check_param_type as _check_param_type
44 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # JPC Test Suite
2 |
3 | This directory contains comprehensive tests for the jpc library.
4 |
5 | ## Running Tests
6 |
7 | To run all tests:
8 | ```bash
9 | pytest tests/
10 | ```
11 |
12 | To run a specific test file:
13 | ```bash
14 | pytest tests/test_energies.py
15 | ```
16 |
17 | To run with verbose output:
18 | ```bash
19 | pytest tests/ -v
20 | ```
21 |
22 | To run with coverage:
23 | ```bash
24 | pytest tests/ --cov=jpc --cov-report=term
25 | ```
26 |
27 | To generate an HTML coverage report:
28 | ```bash
29 | pytest tests/ --cov=jpc --cov-report=html
30 | ```
31 | Then open `htmlcov/index.html` in your browser.
32 |
33 | To see coverage percentage only:
34 | ```bash
35 | pytest tests/ --cov=jpc --cov-report=term-missing
36 | ```
37 |
38 | ## Test Structure
39 |
40 | - `conftest.py`: Shared fixtures and pytest configuration
41 | - `test_errors.py`: Tests for error checking functions
42 | - `test_init.py`: Tests for initialization functions
43 | - `test_energies.py`: Tests for energy functions (PC, HPC, BPC, PDM)
44 | - `test_grads.py`: Tests for gradient computation functions
45 | - `test_infer.py`: Tests for inference solving functions
46 | - `test_updates.py`: Tests for update functions (activities and parameters)
47 | - `test_analytical.py`: Tests for analytical tools
48 | - `test_utils.py`: Tests for utility functions
49 | - `test_train.py`: Tests for training functions
50 | - `test_test_functions.py`: Tests for test utility functions
51 |
52 | ## Automated Testing
53 |
54 | Tests run automatically on every push and pull request via GitHub Actions (`.github/workflows/tests.yml`). The workflow:
55 | - Runs tests on Python 3.10 and 3.11
56 | - Generates coverage reports
57 | - Automatically updates the coverage badge in the main README on pushes to `main`
58 |
59 | ## Requirements
60 |
61 | The tests require:
62 | - pytest
63 | - pytest-cov (for coverage reporting)
64 | - jax (including jax.numpy)
66 | - equinox
67 | - diffrax
68 | - optax
69 |
70 | All dependencies should be installed when installing jpc. Coverage configuration is defined in `pyproject.toml` under `[tool.coverage.*]`.
71 |
--------------------------------------------------------------------------------
/docs/advanced_usage.md:
--------------------------------------------------------------------------------
1 | # Advanced usage
2 |
3 | Advanced users can access all the underlying functions of `jpc.make_pc_step()`
4 | as well as additional features. A custom PC training step looks like the
5 | following:
6 | ```py
7 | import jpc
8 |
9 | # 1. initialise activities with a feedforward pass
10 | activities = jpc.init_activities_with_ffwd(model=model, input=x)
11 |
12 | # 2. run inference to equilibrium
13 | equilibrated_activities = jpc.solve_inference(
14 | params=(model, None),
15 | activities=activities,
16 | output=y,
17 | input=x
18 | )
19 |
20 | # 3. update parameters at the activities' solution with PC
21 | param_update_result = jpc.update_pc_params(
22 | params=(model, None),
23 | activities=equilibrated_activities,
24 | optim=param_optim,
25 | opt_state=param_opt_state,
26 | output=y,
27 | input=x
28 | )
29 |
30 | # updated model and optimiser
31 | model = param_update_result["model"]
32 | param_opt_state = param_update_result["opt_state"]
33 | ```
34 | which can be embedded in a jitted function together with any additional
35 | computations (a jitted sketch is given below). One can also use any [optax
36 | ](https://optax.readthedocs.io/en/latest/api/optimizers.html) optimiser to
37 | equilibrate the inference dynamics by replacing the function in step 2, as
38 | shown below.
39 | ```py
40 | activity_optim = optax.adam(1e-3)
41 |
42 | # 1. initialise activities
43 | ...
44 |
45 | # 2. infer with adam
46 | activity_opt_state = activity_optim.init(activities)
47 |
48 | for t in range(T):
49 | activity_update_result = jpc.update_pc_activities(
50 | params=(model, None),
51 | activities=activities,
52 | optim=activity_optim,
53 | opt_state=activity_opt_state,
54 | output=y,
55 | input=x
56 | )
57 | # updated activities and optimiser
58 | activities = activity_update_result["activities"]
59 | activity_opt_state = activity_update_result["opt_state"]
60 |
61 | # 3. update parameters at the activities' solution with PC
62 | ...
63 | ```
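Putting the pieces together, a jitted training step can be written with
[equinox](https://github.com/patrick-kidger/equinox)'s `filter_jit` (a minimal
sketch; `param_optim` is closed over, and all other names are those from the
snippets above):
```py
import equinox as eqx
import jpc

@eqx.filter_jit
def train_step(model, param_opt_state, x, y):
    # 1. initialise activities with a feedforward pass
    activities = jpc.init_activities_with_ffwd(model=model, input=x)

    # 2. run inference to equilibrium
    equilibrated_activities = jpc.solve_inference(
        params=(model, None),
        activities=activities,
        output=y,
        input=x
    )

    # 3. update parameters at the activities' solution with PC
    result = jpc.update_pc_params(
        params=(model, None),
        activities=equilibrated_activities,
        optim=param_optim,
        opt_state=param_opt_state,
        output=y,
        input=x
    )
    return result["model"], result["opt_state"]
```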
64 | See the [updates docs
65 | ](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/) for more details. JPC also
66 | comes with some theoretical tools that can be used to study and potentially
67 | diagnose issues with PCNs
68 | (see [docs
69 | ](https://thebuckleylab.github.io/jpc/api/Theoretical%20tools/)
70 | and [example notebook
71 | ](https://thebuckleylab.github.io/jpc/examples/theoretical_energy_with_linear_net/)).
72 |
--------------------------------------------------------------------------------
/experiments/library_paper/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | from torch import manual_seed
5 | from diffrax import Euler, Heun, Midpoint, Ralston, Bosh3, Tsit5, Dopri5, Dopri8
6 |
7 |
8 | def set_seed(seed):
9 | np.random.seed(seed)
10 | random.seed(seed)
11 | manual_seed(seed)
12 |
13 |
14 | def get_ode_solver(name):
15 | if name == "Euler":
16 | return Euler()
17 | elif name == "Heun":
18 | return Heun()
19 | elif name == "Midpoint":
20 | return Midpoint()
21 | elif name == "Ralston":
22 | return Ralston()
23 | elif name == "Bosh3":
24 | return Bosh3()
25 | elif name == "Tsit5":
26 | return Tsit5()
27 | elif name == "Dopri5":
28 | return Dopri5()
29 | elif name == "Dopri8":
30 | return Dopri8()
31 |
32 |
33 | def setup_mlp_experiment(
34 | results_dir,
35 | dataset,
36 | width,
37 | n_hidden,
38 | act_fn,
39 | max_t1,
40 | activity_lr,
41 | param_lr,
42 | activity_optim_id,
43 | seed
44 | ):
45 | print(
46 | f"""
47 | Starting experiment with configuration:
48 |
49 | Dataset: {dataset}
50 | Width: {width}
51 | N hidden: {n_hidden}
52 | Act fn: {act_fn}
53 | Max t1: {max_t1}
54 | Activity step size: {activity_lr}
55 | Param learning rate: {param_lr}
56 | Activity optim: {activity_optim_id}
57 | Seed: {seed}
58 | """
59 | )
60 | return os.path.join(
61 | results_dir,
62 | dataset,
63 | f"width_{width}",
64 | f"{n_hidden}_n_hidden",
65 | act_fn,
66 | f"max_t1_{max_t1}",
67 | f"activity_lr_{activity_lr}",
68 | f"param_lr_{param_lr}",
69 | activity_optim_id,
70 | str(seed)
71 | )
72 |
73 |
74 | def get_min_iter(lists):
75 | min_iter = 100000
76 | for i in lists:
77 | if len(i) < min_iter:
78 | min_iter = len(i)
79 | return min_iter
80 |
81 |
82 | def get_min_iter_metrics(metrics):
83 | n_seeds = len(metrics)
84 | min_iter = get_min_iter(lists=metrics)
85 |
86 | min_iter_metrics = np.zeros((n_seeds, min_iter))
87 | for seed in range(n_seeds):
88 | min_iter_metrics[seed, :] = metrics[seed][:min_iter]
89 |
90 | return min_iter_metrics
91 |
92 |
93 | def compute_metric_stats(metric):
94 | min_iter_metrics = get_min_iter_metrics(metrics=metric)
95 | metric_means = min_iter_metrics.mean(axis=0)
96 | metric_stds = min_iter_metrics.std(axis=0)
97 | return metric_means, metric_stds
98 |
--------------------------------------------------------------------------------
/jpc/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 |
3 | from ._core import (
4 | init_activities_with_ffwd as init_activities_with_ffwd,
5 | init_activities_from_normal as init_activities_from_normal,
6 | init_activities_with_amort as init_activities_with_amort,
7 | init_epc_errors as init_epc_errors,
8 | pc_energy_fn as pc_energy_fn,
9 | hpc_energy_fn as hpc_energy_fn,
10 | bpc_energy_fn as bpc_energy_fn,
11 | epc_energy_fn as epc_energy_fn,
12 | pdm_energy_fn as pdm_energy_fn,
13 | _get_param_scalings as _get_param_scalings,
14 | neg_pc_activity_grad as neg_pc_activity_grad,
15 | solve_inference as solve_inference,
16 | compute_pc_activity_grad as compute_pc_activity_grad,
17 | compute_pc_param_grads as compute_pc_param_grads,
18 | compute_hpc_param_grads as compute_hpc_param_grads,
19 | compute_bpc_activity_grad as compute_bpc_activity_grad,
20 | compute_bpc_param_grads as compute_bpc_param_grads,
21 | compute_epc_error_grad as compute_epc_error_grad,
22 | compute_epc_param_grads as compute_epc_param_grads,
23 | compute_pdm_activity_grad as compute_pdm_activity_grad,
24 | compute_pdm_param_grads as compute_pdm_param_grads,
25 | update_pc_activities as update_pc_activities,
26 | update_pc_params as update_pc_params,
27 | update_bpc_activities as update_bpc_activities,
28 | update_bpc_params as update_bpc_params,
29 | update_epc_errors as update_epc_errors,
30 | update_epc_params as update_epc_params,
31 | update_pdm_activities as update_pdm_activities,
32 | update_pdm_params as update_pdm_params,
33 | compute_linear_equilib_energy as compute_linear_equilib_energy,
34 | compute_linear_activity_hessian as compute_linear_activity_hessian,
35 | compute_linear_activity_solution as compute_linear_activity_solution,
36 | _check_param_type as _check_param_type
37 | )
38 | from ._utils import (
39 | make_mlp as make_mlp,
40 | make_skip_model as make_skip_model,
41 | get_act_fn as get_act_fn,
42 | mse_loss as mse_loss,
43 | cross_entropy_loss as cross_entropy_loss,
44 | compute_accuracy as compute_accuracy,
45 | get_t_max as get_t_max,
46 | compute_activity_norms as compute_activity_norms,
47 | compute_infer_energies as compute_infer_energies,
48 | compute_param_norms as compute_param_norms
49 | )
50 | from ._train import (
51 | make_pc_step as make_pc_step,
52 | make_hpc_step as make_hpc_step
53 | )
54 | from ._test import (
55 | test_discriminative_pc as test_discriminative_pc,
56 | test_generative_pc as test_generative_pc,
57 | test_hpc as test_hpc
58 | )
59 |
60 |
61 | __version__ = importlib.metadata.version("jpc")
62 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "jpc"
3 | version = "1.0.0"
4 | description = "Flexible Inference for Predictive Coding Networks in JAX."
5 | readme = "README.md"
6 | requires-python = ">=3.10"
7 | license = {file = "LICENSE"}
8 | authors = [
9 | {name = "Francesco Innocenti", email = "F.Innocenti@sussex.ac.uk"},
10 | ]
11 | keywords = [
12 | "jax",
13 | "predictive-coding",
14 | "neural-networks",
15 | "hybrid-predictive-coding",
16 | "deep-learning",
17 | "local-learning",
18 | "inference-learning",
19 | "mupc"
20 | ]
21 | classifiers = [
22 | "Development Status :: 3 - Alpha",
23 | "Intended Audience :: Developers",
24 | "Intended Audience :: Financial and Insurance Industry",
25 | "Intended Audience :: Information Technology",
26 | "Intended Audience :: Science/Research",
27 | "License :: OSI Approved :: MIT License",
28 | "Natural Language :: English",
29 | "Programming Language :: Python :: 3",
30 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
31 | "Topic :: Scientific/Engineering :: Information Analysis",
32 | "Topic :: Scientific/Engineering :: Mathematics",
33 | ]
34 | urls = {repository = "https://github.com/thebuckleylab/jpc"}
35 | dependencies = [
36 | "jax>=0.4.38,<=0.5.2", # to prevent jaxlib import error
37 | "equinox>=0.11.2",
38 | "diffrax>=0.6.0",
39 | "optax>=0.2.4",
40 | "jaxtyping>=0.2.24",
41 | "pytest-cov>=4.0.0"
42 | ]
43 |
44 | [build-system]
45 | build-backend = "hatchling.build"
46 | requires = ["hatchling"]
47 |
48 | [tool.hatch.build]
49 | include = ["jpc/*"]
50 |
51 | [tool.ruff]
52 | extend-include = ["*.ipynb"]
53 | src = []
54 |
55 | [tool.ruff.lint]
56 | fixable = ["I001", "F401"]
57 | ignore = ["E402", "E721", "E731", "E741", "F722"]
58 | ignore-init-module-imports = true
59 | select = ["E", "F", "I001"]
60 |
61 | [tool.ruff.lint.isort]
62 | combine-as-imports = true
63 | extra-standard-library = ["typing_extensions"]
64 | lines-after-imports = 2
65 | order-by-type = false
66 |
67 | [tool.pytest.ini_options]
68 | testpaths = ["tests"]
69 | python_files = ["test_*.py"]
70 | python_classes = ["Test*"]
71 | python_functions = ["test_*"]
72 | norecursedirs = [".git", "venv", "site", "jpc.egg-info", "examples", "experiments", "jpc"]
73 |
74 | [tool.coverage.run]
75 | source = ["jpc"]
76 | omit = [
77 | "*/tests/*",
78 | "*/test_*.py",
79 | "*/__pycache__/*",
80 | "*/site/*",
81 | "*/examples/*",
82 | "*/experiments/*",
83 | ]
84 |
85 | [tool.coverage.report]
86 | exclude_lines = [
87 | "pragma: no cover",
88 | "def __repr__",
89 | "raise AssertionError",
90 | "raise NotImplementedError",
91 | "if __name__ == .__main__.:",
92 | "if TYPE_CHECKING:",
93 | "@abstractmethod",
94 | ]
--------------------------------------------------------------------------------
/experiments/mupc_paper/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | anyio==4.8.0
3 | appnope==0.1.4
4 | argon2-cffi==23.1.0
5 | argon2-cffi-bindings==21.2.0
6 | arrow==1.3.0
7 | asttokens==3.0.0
8 | async-lru==2.0.4
9 | attrs==25.1.0
10 | babel==2.17.0
11 | beautifulsoup4==4.13.3
12 | bleach==6.2.0
13 | certifi==2025.1.31
14 | cffi==1.17.1
15 | charset-normalizer==3.4.1
16 | chex==0.1.89
17 | comm==0.2.2
18 | contourpy==1.3.1
19 | cycler==0.12.1
20 | debugpy==1.8.13
21 | decorator==5.2.1
22 | defusedxml==0.7.1
23 | diffrax==0.6.2
24 | equinox==0.11.12
25 | etils==1.12.0
26 | exceptiongroup==1.2.2
27 | executing==2.2.0
28 | fastjsonschema==2.21.1
29 | filelock==3.17.0
30 | fonttools==4.56.0
31 | fqdn==1.5.1
32 | fsspec==2025.2.0
33 | h11==0.14.0
34 | httpcore==1.0.7
35 | httpx==0.28.1
36 | idna==3.10
37 | ipykernel==6.29.5
38 | ipython==8.33.0
39 | ipywidgets==8.1.5
40 | isoduration==20.11.0
41 | jax==0.5.2
42 | jaxlib==0.5.1
43 | jaxtyping==0.2.38
44 | jedi==0.19.2
45 | Jinja2==3.1.5
46 | json5==0.10.0
47 | jsonpointer==3.0.0
48 | jsonschema==4.23.0
49 | jsonschema-specifications==2024.10.1
50 | jupyter==1.1.1
51 | jupyter-console==6.6.3
52 | jupyter-events==0.12.0
53 | jupyter-lsp==2.2.5
54 | jupyter_client==8.6.3
55 | jupyter_core==5.7.2
56 | jupyter_server==2.15.0
57 | jupyter_server_terminals==0.5.3
58 | jupyterlab==4.3.5
59 | jupyterlab_pygments==0.3.0
60 | jupyterlab_server==2.27.3
61 | jupyterlab_widgets==3.0.13
62 | kaleido==0.2.1
63 | kiwisolver==1.4.8
64 | lineax==0.0.7
65 | MarkupSafe==3.0.2
66 | matplotlib==3.10.1
67 | matplotlib-inline==0.1.7
68 | mistune==3.1.2
69 | ml_dtypes==0.5.1
70 | mpmath==1.3.0
71 | narwhals==1.29.0
72 | nbclient==0.10.2
73 | nbconvert==7.16.6
74 | nbformat==5.10.4
75 | nest-asyncio==1.6.0
76 | networkx==3.4.2
77 | notebook==7.3.2
78 | notebook_shim==0.2.4
79 | numpy==2.2.3
80 | opt_einsum==3.4.0
81 | optax==0.2.4
82 | optimistix==0.0.10
83 | overrides==7.7.0
84 | packaging==24.2
85 | pandocfilters==1.5.1
86 | parso==0.8.4
87 | pexpect==4.9.0
88 | pillow==11.1.0
89 | platformdirs==4.3.6
90 | plotly==5.24.1
91 | prometheus_client==0.21.1
92 | prompt_toolkit==3.0.50
93 | psutil==7.0.0
94 | ptyprocess==0.7.0
95 | pure_eval==0.2.3
96 | pycparser==2.22
97 | Pygments==2.19.1
98 | pyparsing==3.2.1
99 | python-dateutil==2.9.0.post0
100 | python-json-logger==3.2.1
101 | PyYAML==6.0.2
102 | pyzmq==26.2.1
103 | referencing==0.36.2
104 | requests==2.32.3
105 | rfc3339-validator==0.1.4
106 | rfc3986-validator==0.1.1
107 | rpds-py==0.23.1
108 | scipy==1.15.2
109 | Send2Trash==1.8.3
110 | six==1.17.0
111 | sniffio==1.3.1
112 | soupsieve==2.6
113 | stack-data==0.6.3
114 | sympy==1.13.1
115 | tenacity==9.0.0
116 | terminado==0.18.1
117 | tinycss2==1.4.0
118 | tomli==2.2.1
119 | toolz==1.0.0
120 | torch==2.5.1+cu121
121 | torchvision==0.20.1+cu121
122 | tornado==6.4.2
123 | traitlets==5.14.3
124 | typeguard==2.13.3
125 | types-python-dateutil==2.9.0.20241206
126 | typing_extensions==4.12.2
127 | uri-template==1.3.0
128 | urllib3==2.3.0
129 | wadler_lindig==0.1.3
130 | wcwidth==0.2.13
131 | webcolors==24.11.1
132 | webencodings==0.5.1
133 | websocket-client==1.8.0
134 | widgetsnbextension==4.0.13
135 |
--------------------------------------------------------------------------------
/experiments/mupc_paper/README.md:
--------------------------------------------------------------------------------
1 | # μPC paper
2 |
3 | [](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/mupc.ipynb) [](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | This folder contains code to reproduce all the experiments for the NeurIPS 2025
12 | paper ["μPC": Scaling Predictive Coding to 100+ Layer Networks](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1)). For a high-level summary, see
13 | [this blog post](https://francesco-innocenti.github.io/posts/2025/05/20/Scaling-Predictive-Coding-to-100+-Layer-Networks/).
14 | And for a tutorial, see [this example notebook](https://thebuckleylab.github.io/jpc/examples/mupc/).
15 |
16 |
17 | ## Setup
18 | Clone the `jpc` repo. We recommend using a virtual environment, e.g.
19 | ```
20 | python3 -m venv venv
21 | ```
22 | Activate it (`source venv/bin/activate`), then install `jpc`
23 | ```
24 | pip install -e .
25 | ```
26 | For GPU usage, upgrade jax to the appropriate cuda version (12 as an example
27 | here).
28 |
29 | ```
30 | pip install --upgrade "jax[cuda12]==0.5.2"
31 | ```
32 | Now navigate to `experiments/mupc_paper` and install all the requirements
33 |
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 |
38 |
39 | ## Compute resources
40 | We recommend using a GPU for the experiments with 64- and 128-layer networks.
41 |
42 |
43 | ## Scripts
44 | * `train_pcn_no_metrics.py`: This is the main script that was used to produce
45 | results for Figs. 1, 5, & A.16-A.18.
46 | * `analyse_activity_hessian.py`: This script can be used to reproduce results
47 | related to spectral properties of the activity Hessian at initialisation
48 | (Figs. 2 & 4, & Figs. A.1-A.7, A.12 & A.21).
49 | * `train_pcn.py`: This was mainly used to monitor the condition number of the
50 | activity Hessian during training (Figs. 3, A.8-A.9, A.13-A.14, A.22-A.28).
51 | * `test_energy_theory.py`: This can be used to reproduce results related to
52 | the convergence behaviour of μPC to BP shown in Section 6
53 | (Figs. 6 & A.32-A.33).
54 | * `train_bpn.py`: Used to obtain all the results with backprop.
55 | * `test_mlp_fwd_pass.py`: Used for results of Fig. A.29.
56 | * `toy_experiments.ipynb`: Used for many secondary results in the Appendix.
57 |
58 | The majority of results are plotted in `plot_results.ipynb` under informative
59 | headings. For details of all the experiments, see Section A.4 of the paper.
60 |
61 |
62 | ## Citation
63 | If you use μPC in your work, please cite the paper:
64 |
65 | ```bibtex
66 | @article{innocenti2025mu,
67 | title={$\mu$PC: Scaling Predictive Coding to 100+ Layer Networks},
68 | author={Innocenti, Francesco and Achour, El Mehdi and Buckley, Christopher L},
69 | journal={arXiv preprint arXiv:2505.13124},
70 | year={2025}
71 | }
72 | ```
73 | Also consider starring the repo! ⭐️
74 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Pytest configuration and shared fixtures for jpc tests."""
2 |
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | import equinox.nn as nn
7 | from jpc import make_mlp
8 |
9 |
10 | @pytest.fixture
11 | def key():
12 | """Random key for testing."""
13 | return jax.random.PRNGKey(42)
14 |
15 |
16 | @pytest.fixture
17 | def batch_size():
18 | """Batch size for testing."""
19 | return 4
20 |
21 |
22 | @pytest.fixture
23 | def input_dim():
24 | """Input dimension for testing."""
25 | return 10
26 |
27 |
28 | @pytest.fixture
29 | def hidden_dim():
30 | """Hidden dimension for testing."""
31 | return 20
32 |
33 |
34 | @pytest.fixture
35 | def output_dim():
36 | """Output dimension for testing."""
37 | return 5
38 |
39 |
40 | @pytest.fixture
41 | def depth():
42 | """Network depth for testing."""
43 | return 3
44 |
45 |
46 | @pytest.fixture
47 | def simple_model(key, input_dim, hidden_dim, output_dim, depth):
48 | """Create a simple MLP model for testing."""
49 | return make_mlp(
50 | key=key,
51 | input_dim=input_dim,
52 | width=hidden_dim,
53 | depth=depth,
54 | output_dim=output_dim,
55 | act_fn="relu",
56 | use_bias=False,
57 | param_type="sp"
58 | )
59 |
60 |
61 | @pytest.fixture
62 | def x(key, batch_size, input_dim):
63 | """Sample input data."""
64 | return jax.random.normal(key, (batch_size, input_dim))
65 |
66 |
67 | @pytest.fixture
68 | def y(key, batch_size, output_dim):
69 | """Sample output/target data."""
70 | return jax.random.normal(key, (batch_size, output_dim))
71 |
72 |
73 | @pytest.fixture
74 | def y_onehot(key, batch_size, output_dim):
75 | """Sample one-hot encoded target data."""
76 | indices = jax.random.randint(key, (batch_size,), 0, output_dim)
77 | return jax.nn.one_hot(indices, output_dim)
78 |
79 |
80 | @pytest.fixture
81 | def layer_sizes(input_dim, hidden_dim, output_dim, depth):
82 | """Layer sizes for testing."""
83 | sizes = [input_dim]
84 | for _ in range(depth - 1):
85 | sizes.append(hidden_dim)
86 | sizes.append(output_dim)
87 | return sizes
88 |
89 |
90 | def pytest_ignore_collect(collection_path, config):
91 | """Skip jpc/_test.py which contains library functions, not test functions."""
92 | # Skip _test.py files in the jpc directory (library functions, not test functions)
93 | if collection_path.name == "_test.py":
94 | file_str = str(collection_path.resolve())
95 | # Check if it's in jpc directory but not in tests
96 | if "/jpc/" in file_str and "/tests/" not in file_str:
97 | return True # Ignore this file
98 | return None # Use default behavior for other files
99 |
100 |
101 | def pytest_collection_modifyitems(config, items):
102 | """Remove any test items that were incorrectly collected from jpc._test module."""
103 | filtered_items = []
104 | for item in items:
105 | # Check file path
106 | item_path = None
107 | if hasattr(item, 'fspath') and item.fspath:
108 | item_path = str(item.fspath)
109 | elif hasattr(item, 'path') and item.path:
110 | item_path = str(item.path)
111 | elif hasattr(item, 'nodeid'):
112 | # Extract path from nodeid (format: path/to/file.py::test_function)
113 | nodeid = item.nodeid
114 | if '::' in nodeid:
115 | item_path = nodeid.split('::')[0]
116 |
117 | # Check if it's from jpc/_test.py
118 | if item_path and ('jpc/_test.py' in item_path or 'jpc\\_test.py' in item_path):
119 | continue
120 |
121 | # Check module name
122 | if hasattr(item, 'module') and item.module:
123 | module_name = getattr(item.module, '__name__', None)
124 | if module_name == 'jpc._test':
125 | continue
126 |
127 | # Check function location
128 | if hasattr(item, 'function') and hasattr(item.function, '__module__'):
129 | if item.function.__module__ == 'jpc._test':
130 | continue
131 |
132 | filtered_items.append(item)
133 |
134 | items[:] = filtered_items
135 |
--------------------------------------------------------------------------------
/docs/basic_usage.md:
--------------------------------------------------------------------------------
1 | JPC provides two types of API depending on the use case:
2 |
3 | * a simple, high-level API that allows you to train and test models with predictive
4 | coding in a few lines of code, and
5 | * a more advanced API offering greater flexibility as well as additional features.
6 |
7 | # Basic usage
8 | At a high level, JPC provides a single convenience function `jpc.make_pc_step()`
9 | to update the parameters of a neural network with PC.
10 | ```py
11 | import jax.random as jr
12 | import jax.numpy as jnp
13 | import equinox as eqx
14 | import optax
15 | import jpc
16 |
17 | # toy data
18 | x = jnp.array([1., 1., 1.])
19 | y = -x
20 |
21 | # define model and optimiser
22 | key = jr.PRNGKey(0)
23 | model = jpc.make_mlp(
24 | key,
25 | input_dim=3,
26 | width=50,
27 | depth=5,
28 | output_dim=3,
29 | act_fn="relu"
30 | )
31 | optim = optax.adam(1e-3)
32 | opt_state = optim.init(
33 | (eqx.filter(model, eqx.is_array), None)
34 | )
35 |
36 | # perform one training step with PC
37 | update_result = jpc.make_pc_step(
38 | model=model,
39 | optim=optim,
40 | opt_state=opt_state,
41 | output=y,
42 | input=x
43 | )
44 |
45 | # updated model and optimiser
46 | model, opt_state = update_result["model"], update_result["opt_state"]
47 | ```
48 | As shown above, at a minimum `jpc.make_pc_step()` takes a model, an [optax
49 | ](https://github.com/google-deepmind/optax) optimiser and its
50 | state, and some data. The model needs to be compatible with PC updates in the
51 | sense that it's split into callable layers (see the
52 | [example notebooks
53 | ](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)). Also note
54 | that the `input` is actually not needed for unsupervised training. In fact,
55 | `jpc.make_pc_step()` can be used for classification and generation tasks, for
56 | supervised as well as unsupervised training (again see the [example notebooks
57 | ](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)).
58 |
59 | Under the hood, `jpc.make_pc_step()` uses [diffrax
60 | ](https://github.com/patrick-kidger/diffrax) to solve the activity (inference)
61 | dynamics of PC. Many default arguments can be changed, including the ODE
62 | solver, and there is an option to record a
63 | variety of metrics such as loss, accuracy, and energies. See the [docs
64 | ](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_pc_step) for more
65 | details.
66 |
67 | A similar convenience function `jpc.make_hpc_step()` is provided for updating the
68 | parameters of a hybrid PCN ([Tschantz et al., 2023
69 | ](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011280)).
70 | ```py
71 | import jax.random as jr
72 | import equinox as eqx
73 | import optax
74 | import jpc
75 |
76 | # models
77 | key = jr.PRNGKey(0)
78 | subkeys = jr.split(key, 2)
79 |
80 | input_dim, output_dim = 10, 3
81 | width, depth = 100, 5
82 | generator = jpc.make_mlp(
83 | subkeys[0],
84 | input_dim=input_dim,
85 | width=width,
86 | depth=depth,
87 | output_dim=output_dim,
88 | act_fn="tanh"
89 | )
90 | # NOTE that the input and output of the amortiser are reversed
91 | amortiser = jpc.make_mlp(
92 | subkeys[1],
93 | input_dim=output_dim,
94 | width=width,
95 | depth=depth,
96 | output_dim=input_dim,
97 | act_fn="tanh"
98 | )
99 |
100 | # optimisers
101 | gen_optim = optax.adam(1e-3)
102 | amort_optim = optax.adam(1e-3)
103 | gen_opt_state = gen_optim.init(
104 | (eqx.filter(generator, eqx.is_array), None)
105 | )
106 | amort_opt_state = amort_optim.init(
107 | eqx.filter(amortiser, eqx.is_array)
108 | )
109 |
110 | update_result = jpc.make_hpc_step(
111 | generator=generator,
112 | amortiser=amortiser,
113 | optims=[gen_optim, amort_optim],
114 | opt_states=[gen_opt_state, amort_opt_state],
115 | output=y,
116 | input=x
117 | )
118 | generator, amortiser = update_result["generator"], update_result["amortiser"]
119 | opt_states = update_result["opt_states"]
120 | gen_loss, amort_loss = update_result["losses"]
121 | ```
122 | See the [docs
123 | ](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_hpc_step) and the
124 | [example notebook
125 | ](https://thebuckleylab.github.io/jpc/examples/hybrid_pc/) for more details.
126 |
--------------------------------------------------------------------------------
/tests/test_init.py:
--------------------------------------------------------------------------------
1 | """Tests for initialization functions."""
2 |
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | from jpc import (
7 | init_activities_with_ffwd,
8 | init_activities_from_normal,
9 | init_activities_with_amort,
10 | init_epc_errors
11 | )
12 |
13 |
14 | def test_init_activities_with_ffwd(simple_model, x):
15 | """Test feedforward initialization."""
16 | activities = init_activities_with_ffwd(
17 | model=simple_model,
18 | input=x,
19 | param_type="sp"
20 | )
21 |
22 | assert len(activities) == len(simple_model)
23 | assert activities[0].shape == (x.shape[0], simple_model[0][1].weight.shape[0])
24 | assert activities[-1].shape == (x.shape[0], simple_model[-1][1].weight.shape[0])
25 |
26 |
27 | def test_init_activities_with_ffwd_mupc(key, x, input_dim, hidden_dim, output_dim, depth):
28 | """Test feedforward initialization with mupc parameterization."""
29 | from jpc import make_mlp
30 |
31 | model = make_mlp(
32 | key=key,
33 | input_dim=input_dim,
34 | width=hidden_dim,
35 | depth=depth,
36 | output_dim=output_dim,
37 | act_fn="relu",
38 | use_bias=False,
39 | param_type="mupc"
40 | )
41 |
42 | activities = init_activities_with_ffwd(
43 | model=model,
44 | input=x,
45 | param_type="mupc"
46 | )
47 |
48 | assert len(activities) == len(model)
49 |
50 |
51 | def test_init_activities_with_ffwd_skip_connections(simple_model, x):
52 | """Test feedforward initialization with skip connections."""
53 | from jpc import make_skip_model
54 |
55 | skip_model = make_skip_model(len(simple_model))
56 | activities = init_activities_with_ffwd(
57 | model=simple_model,
58 | input=x,
59 | skip_model=skip_model,
60 | param_type="sp"
61 | )
62 |
63 | assert len(activities) == len(simple_model)
64 |
65 |
66 | def test_init_activities_from_normal_supervised(key, layer_sizes, batch_size):
67 | """Test random initialization in supervised mode."""
68 | activities = init_activities_from_normal(
69 | key=key,
70 | layer_sizes=layer_sizes,
71 | mode="supervised",
72 | batch_size=batch_size,
73 | sigma=0.05
74 | )
75 |
76 | # In supervised mode, input layer is not initialized
77 | assert len(activities) == len(layer_sizes) - 1
78 | for i, (act, size) in enumerate(zip(activities, layer_sizes[1:]), 1):
79 | assert act.shape == (batch_size, size)
80 |
81 |
82 | def test_init_activities_from_normal_unsupervised(key, layer_sizes, batch_size):
83 | """Test random initialization in unsupervised mode."""
84 | activities = init_activities_from_normal(
85 | key=key,
86 | layer_sizes=layer_sizes,
87 | mode="unsupervised",
88 | batch_size=batch_size,
89 | sigma=0.05
90 | )
91 |
92 | # In unsupervised mode, all layers including input are initialized
93 | assert len(activities) == len(layer_sizes)
94 | for act, size in zip(activities, layer_sizes):
95 | assert act.shape == (batch_size, size)
96 |
97 |
98 | def test_init_activities_with_amort(key, simple_model, y, input_dim, hidden_dim, output_dim, depth):
99 | """Test amortized initialization."""
100 | from jpc import make_mlp
101 |
102 | amortiser = make_mlp(
103 | key=key,
104 | input_dim=output_dim,
105 | width=hidden_dim,
106 | depth=depth,
107 | output_dim=input_dim,
108 | act_fn="relu",
109 | use_bias=False,
110 | param_type="sp"
111 | )
112 |
113 | activities = init_activities_with_amort(
114 | amortiser=amortiser,
115 | generator=simple_model,
116 | input=y
117 | )
118 |
119 | # Should return reversed activities plus dummy target prediction
120 | assert len(activities) == len(amortiser) + 1
121 |
122 |
123 | def test_init_epc_errors_supervised(layer_sizes, batch_size):
124 | """Test EPC error initialization in supervised mode."""
125 | errors = init_epc_errors(
126 | layer_sizes=layer_sizes,
127 | batch_size=batch_size,
128 | mode="supervised"
129 | )
130 |
131 | # In supervised mode, errors are initialized for layers 1 to L-1 (hidden layers only)
132 | assert len(errors) == len(layer_sizes) - 1
133 | for i, (err, size) in enumerate(zip(errors, layer_sizes[1:]), 1):
134 | assert err.shape == (batch_size, size)
135 | assert jnp.allclose(err, 0.0) # Should be zero-initialized
136 |
137 |
138 |
--------------------------------------------------------------------------------
/docs/_static/custom_css.css:
--------------------------------------------------------------------------------
1 | /* Fix /page#foo going to the top of the viewport and being hidden by the navbar */
2 | html {
3 | scroll-padding-top: 50px;
4 | }
5 |
6 | /* Hide the dark logo by default */
7 | #logo_dark_mode {
8 | display: none;
9 | }
10 |
11 | /* Show the light logo by default */
12 | #logo_light_mode {
13 | display: block;
14 | }
15 |
16 | /* Switch display property based on color scheme */
17 | [data-md-color-scheme="default"] {
18 | --md-footer-logo-dark-mode: none;
19 | --md-footer-logo-light-mode: block;
20 | }
21 |
22 | [data-md-color-scheme="slate"] {
23 | --md-footer-logo-dark-mode: block;
24 | --md-footer-logo-light-mode: none;
25 | }
26 |
27 | /* Apply the custom variables */
28 | #logo_light_mode {
29 | display: var(--md-footer-logo-light-mode);
30 | }
31 |
32 | #logo_dark_mode {
33 | display: var(--md-footer-logo-dark-mode);
34 | }
35 |
36 | /* Adjust logo size */
37 | .md-header__button.md-logo {
38 | margin: 0;
39 | padding: 0.5rem;
40 | }
41 | .md-header__button.md-logo img, .md-header__button.md-logo svg {
42 | height: 2.5rem;
43 | width: auto;
44 | }
45 |
46 | /* Fit the Twitter handle alongside the GitHub one in the top right. */
47 |
48 | div.md-header__source {
49 | width: revert;
50 | max-width: revert;
51 | }
52 |
53 | a.md-source {
54 | display: inline-block;
55 | }
56 |
57 | .md-source__repository {
58 | max-width: 100%;
59 | }
60 |
61 | /* Emphasise sections of nav on left hand side */
62 |
63 | nav.md-nav {
64 | padding-left: 5px;
65 | }
66 |
67 | nav.md-nav--secondary {
68 | border-left: revert !important;
69 | }
70 |
71 | .md-nav__title {
72 | font-size: 0.9rem;
73 | }
74 |
75 | .md-nav__item--section > .md-nav__link {
76 | font-size: 0.9rem;
77 | }
78 |
79 | /* Indent autogenerated documentation */
80 |
81 | div.doc-contents {
82 | padding-left: 25px;
83 | border-left: 4px solid rgb(230, 230, 230);
84 | }
85 |
86 | /* Increase visibility of splitters "---" */
87 |
88 | [data-md-color-scheme="default"] .md-typeset hr {
89 | border-bottom-color: rgb(0, 0, 0);
90 | border-bottom-width: 1pt;
91 | }
92 |
93 | [data-md-color-scheme="slate"] .md-typeset hr {
94 | border-bottom-color: rgb(230, 230, 230);
95 | }
96 |
97 | /* More space at the bottom of the page */
98 |
99 | .md-main__inner {
100 | margin-bottom: 1.5rem;
101 | }
102 |
103 | /* Remove prev/next footer buttons */
104 |
105 | .md-footer__inner {
106 | display: none;
107 | }
108 |
109 | /* Change font sizes */
110 |
111 | html {
112 | /* Decrease font size for overall webpage
113 | Down from 137.5% which is the Material default */
114 | font-size: 110%;
115 | }
116 |
117 | .md-typeset .admonition {
118 | /* Increase font size in admonitions */
119 | font-size: 100% !important;
120 | }
121 |
122 | .md-typeset details {
123 | /* Increase font size in details */
124 | font-size: 100% !important;
125 | }
126 |
127 | .md-typeset h1 {
128 | font-size: 1.6rem;
129 | }
130 |
131 | .md-typeset h2 {
132 | font-size: 1.5rem;
133 | }
134 |
135 | .md-typeset h3 {
136 | font-size: 1.3rem;
137 | }
138 |
139 | .md-typeset h4 {
140 | font-size: 1.1rem;
141 | }
142 |
143 | .md-typeset h5 {
144 | font-size: 0.9rem;
145 | }
146 |
147 | .md-typeset h6 {
148 | font-size: 0.8rem;
149 | }
150 |
151 | /* Bugfix: remove the superfluous parts generated when doing:
152 |
153 | ??? Blah
154 |
155 | ::: library.something
156 | */
157 |
158 | .md-typeset details .mkdocstrings > h4 {
159 | display: none;
160 | }
161 |
162 | .md-typeset details .mkdocstrings > h5 {
163 | display: none;
164 | }
165 |
166 | /* Change default colours for tags */
167 |
168 | [data-md-color-scheme="default"] {
169 | --md-typeset-a-color: rgb(0, 189, 164) !important;
170 | }
171 | [data-md-color-scheme="slate"] {
172 | --md-typeset-a-color: rgb(0, 189, 164) !important;
173 | }
174 |
175 | /* Highlight functions, classes etc. type signatures. Really helps to make clear where
176 | one item ends and another begins. */
177 |
178 | [data-md-color-scheme="default"] {
179 | --doc-heading-color: #DDD;
180 | --doc-heading-border-color: #CCC;
181 | --doc-heading-color-alt: #F0F0F0;
182 | }
183 | [data-md-color-scheme="slate"] {
184 | --doc-heading-color: rgb(25,25,33);
185 | --doc-heading-border-color: rgb(25,25,33);
186 | --doc-heading-color-alt: rgb(33,33,44);
187 | --md-code-bg-color: rgb(38,38,50);
188 | }
189 |
190 | h4.doc-heading {
191 | /* NOT var(--md-code-bg-color) as that's not visually distinct from other code blocks.*/
192 | background-color: var(--doc-heading-color);
193 | border: solid var(--doc-heading-border-color);
194 | border-width: 1.5pt;
195 | border-radius: 2pt;
196 | padding: 0pt 5pt 2pt 5pt;
197 | }
198 | h5.doc-heading, h6.heading {
199 | background-color: var(--doc-heading-color-alt);
200 | border-radius: 2pt;
201 | padding: 0pt 5pt 2pt 5pt;
202 | }
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | theme:
2 | name: material
3 | features:
4 | - navigation.sections # Sections are included in the navigation on the left.
5 | - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right.
6 | - header.autohide # header disappears as you scroll
7 | palette:
8 | # Light mode / dark mode
9 | # We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as
10 | # (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.
11 | - scheme: default
12 | primary: white
13 | accent: amber
14 | toggle:
15 | icon: material/weather-night
16 | name: Switch to dark mode
17 | - scheme: slate
18 | primary: black
19 | accent: amber
20 | toggle:
21 | icon: material/weather-sunny
22 | name: Switch to light mode
23 | icon:
24 | repo: fontawesome/brands/github # GitHub logo in top right
25 | logo_light_mode: "_static/logo-light.svg" # jpc logo in top left
26 | logo_dark_mode: "_static/logo-dark.svg" # jpc logo in top left
27 | favicon: "_static/favicon.png"
28 | custom_dir: "docs/_overrides" # Overriding part of the HTML
29 |
30 | site_name: jpc
31 | site_description: The documentation for the jpc software library.
32 | site_author: Francesco Innocenti
33 | site_url: https://thebuckleylab.github.io/jpc/
34 |
35 | repo_url: https://github.com/thebuckleylab/jpc
36 | repo_name: thebuckleylab/jpc
37 | edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate
38 |
39 | strict: true # Don't allow warnings during the build process
40 |
41 | extra_javascript:
42 | # The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/
43 | - _static/mathjax.js
44 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
45 |
46 | extra_css:
47 | - _static/custom_css.css
48 |
49 | markdown_extensions:
50 | - pymdownx.arithmatex: # Render LaTeX via MathJax
51 | generic: true
52 | - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme.
53 | - pymdownx.details # Allowing hidden expandable regions denoted by ???
54 | - pymdownx.snippets: # Include one Markdown file into another
55 | base_path: docs
56 | - admonition
57 | - toc:
58 | permalink: "¤" # Adds a clickable permalink to each section heading
59 | toc_depth: 4
60 |
61 | plugins:
62 | - search:
63 | separator: '[\s\-,:!=\[\]()"/]+|(?!\b)(?=[A-Z][a-z])|\.(?!\d)|&[lg]t;'
64 | - include_exclude_files:
65 | include:
66 | - ".htaccess"
67 | exclude:
68 | - "_overrides"
69 | - "examples/.ipynb_checkpoints/"
70 | - "examples/analytical_inference_with_linear_net.ipynb"
71 | - ipynb
72 | - mkdocstrings:
73 | handlers:
74 | python:
75 | options:
76 | force_inspection: true
77 | heading_level: 4
78 | inherited_members: true
79 | members_order: source
80 | show_bases: false
81 | show_if_no_docstring: true
82 | show_overloads: false
83 | show_root_heading: true
84 | show_signature_annotations: true
85 | show_source: false
86 | show_symbol_type_heading: true
87 | show_symbol_type_toc: true
88 |
89 | nav:
90 | - 'index.md'
91 | - ⚙️ How it works:
92 | - Basic usage: 'basic_usage.md'
93 | - Advanced usage: 'advanced_usage.md'
94 | - 📚 Examples:
95 | - Introductory:
96 | - Discriminative PC: 'examples/discriminative_pc.ipynb'
97 | - Supervised generative PC: 'examples/supervised_generative_pc.ipynb'
98 | - Unsupervised generative PC: 'examples/unsupervised_generative_pc.ipynb'
99 | - Advanced:
100 | - μPC: 'examples/mupc.ipynb'
101 | - ePC: 'examples/epc.ipynb'
102 | - Hybrid PC: 'examples/hybrid_pc.ipynb'
103 | - Bidirectional PC: 'examples/bidirectional_pc.ipynb'
104 | - Linear theoretical energy: 'examples/linear_net_theoretical_energy.ipynb'
105 | - JPC from scratch: 'examples/jpc_from_scratch.ipynb'
106 | - 🌱 Basic API:
107 | - 'api/Training.md'
108 | - 'api/Testing.md'
109 | - 'api/Utils.md'
110 | - 🚀 Advanced API:
111 | - 'api/Initialisation.md'
112 | - 'api/Energy functions.md'
113 | - 'api/Gradients.md'
114 | - 'api/Continuous-time Inference.md'
115 | - 'api/Discrete updates.md'
116 | - 'api/Theoretical tools.md'
117 |
118 | copyright: |
119 | © 2024 thebuckleylab
--------------------------------------------------------------------------------
/docs/_static/logo-light.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/test_infer.py:
--------------------------------------------------------------------------------
1 | """Tests for inference solving functions."""
2 |
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | from diffrax import Heun, PIDController
7 | from jpc import solve_inference
8 |
9 |
10 | def test_solve_inference_supervised(simple_model, x, y):
11 | """Test inference solving in supervised mode."""
12 | from jpc import init_activities_with_ffwd
13 |
14 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
15 |
16 | solution = solve_inference(
17 | params=(simple_model, None),
18 | activities=activities,
19 | output=y,
20 | input=x,
21 | loss_id="mse",
22 | param_type="sp",
23 | solver=Heun(),
24 | max_t1=10,
25 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
26 | )
27 |
28 | assert len(solution) == len(activities)
29 | # Solution shape depends on whether recording is enabled
30 | # Without record_iters, it's just the final state
31 | for sol, act in zip(solution, activities):
32 | # Solution may have time dimension if recorded, or match activity shape
33 | assert sol.shape[-2:] == act.shape or sol.shape == act.shape
34 |
35 |
36 | def test_solve_inference_unsupervised(simple_model, y, key, layer_sizes, batch_size):
37 | """Test inference solving in unsupervised mode."""
38 | from jpc import init_activities_from_normal
39 |
40 | activities = init_activities_from_normal(
41 | key=key,
42 | layer_sizes=layer_sizes,
43 | mode="unsupervised",
44 | batch_size=batch_size,
45 | sigma=0.05
46 | )
47 |
48 | solution = solve_inference(
49 | params=(simple_model, None),
50 | activities=activities,
51 | output=y,
52 | input=None,
53 | loss_id="mse",
54 | param_type="sp",
55 | solver=Heun(),
56 | max_t1=10,
57 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
58 | )
59 |
60 | assert len(solution) == len(activities)
61 |
62 |
63 | def test_solve_inference_cross_entropy(simple_model, x, y_onehot):
64 | """Test inference solving with cross-entropy loss."""
65 | from jpc import init_activities_with_ffwd
66 |
67 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
68 |
69 | solution = solve_inference(
70 | params=(simple_model, None),
71 | activities=activities,
72 | output=y_onehot,
73 | input=x,
74 | loss_id="ce",
75 | param_type="sp",
76 | solver=Heun(),
77 | max_t1=10,
78 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
79 | )
80 |
81 | assert len(solution) == len(activities)
82 |
83 |
84 | def test_solve_inference_with_regularization(simple_model, x, y):
85 | """Test inference solving with regularization."""
86 | from jpc import init_activities_with_ffwd
87 |
88 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
89 |
90 | solution = solve_inference(
91 | params=(simple_model, None),
92 | activities=activities,
93 | output=y,
94 | input=x,
95 | loss_id="mse",
96 | param_type="sp",
97 | solver=Heun(),
98 | max_t1=10,
99 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
100 | weight_decay=0.01,
101 | spectral_penalty=0.01,
102 | activity_decay=0.01
103 | )
104 |
105 | assert len(solution) == len(activities)
106 |
107 |
108 | def test_solve_inference_record_iters(simple_model, x, y):
109 | """Test inference solving with iteration recording."""
110 | from jpc import init_activities_with_ffwd
111 |
112 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
113 |
114 | solution = solve_inference(
115 | params=(simple_model, None),
116 | activities=activities,
117 | output=y,
118 | input=x,
119 | loss_id="mse",
120 | param_type="sp",
121 | solver=Heun(),
122 | max_t1=10,
123 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
124 | record_iters=True
125 | )
126 |
127 | assert len(solution) == len(activities)
128 | # When recording, solution should have an extra dimension for time
129 | assert solution[0].ndim == 3 # (time, batch, features)
130 |
131 |
132 | def test_solve_inference_record_every(simple_model, x, y):
133 | """Test inference solving with record_every parameter."""
134 | from jpc import init_activities_with_ffwd
135 |
136 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
137 |
138 | solution = solve_inference(
139 | params=(simple_model, None),
140 | activities=activities,
141 | output=y,
142 | input=x,
143 | loss_id="mse",
144 | param_type="sp",
145 | solver=Heun(),
146 | max_t1=10,
147 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
148 | record_every=2
149 | )
150 |
151 | assert len(solution) == len(activities)
152 |
153 |
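154 | def test_solve_inference_fixed_step(simple_model, x, y):
155 |     """Sketch: fixed-step integration. Assumes `dt` is honoured when paired with
156 |     a non-adaptive controller (diffrax's `ConstantStepSize`), in which case the
157 |     steady-state event falls back to its default 1e-3 tolerances."""
158 |     from diffrax import ConstantStepSize
159 |     from jpc import init_activities_with_ffwd
160 | 
161 |     activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
162 | 
163 |     solution = solve_inference(
164 |         params=(simple_model, None),
165 |         activities=activities,
166 |         output=y,
167 |         input=x,
168 |         loss_id="mse",
169 |         param_type="sp",
170 |         solver=Heun(),
171 |         max_t1=10,
172 |         dt=0.1,  # explicit step size instead of adaptive control
173 |         stepsize_controller=ConstantStepSize()
174 |     )
175 | 
176 |     assert len(solution) == len(activities)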
--------------------------------------------------------------------------------
/.github/logo-with-background.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Getting started
2 | JPC is a [**J**AX](https://github.com/google/jax) library for training neural
3 | networks with **P**redictive **C**oding (PC).
4 |
5 | JPC provides a **simple**, **fast** and **flexible** API for training
6 | a variety of PCNs, including discriminative, generative and hybrid
7 | models.
8 |
9 | * Like JAX, JPC is completely functional in design, and the core library
10 | is <1000 lines of code.
11 |
12 | * Unlike existing implementations, JPC provides a wide range of optimisers, both
13 | discrete and continuous, to solve the inference dynamics of PC, including
14 | ordinary differential equation (ODE) solvers.
15 |
16 | * JPC also provides some analytical tools that can be used to study and
17 | potentially diagnose issues with PCNs.
18 |
19 | If you're new to JPC, we recommend starting from the
20 | [example notebooks](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)
21 | and checking the [documentation](https://thebuckleylab.github.io/jpc/).
22 |
23 | ## 💻 Installation
24 | Clone the repo and in the project's directory run
25 | ```
26 | pip install .
27 | ```
28 |
29 | Requires Python 3.10+ and JAX 0.4.38–0.5.2 (inclusive). For GPU usage,
30 | install JAX with the appropriate CUDA version (CUDA 12 in the example below):
31 |
32 | ```
33 | pip install --upgrade "jax[cuda12]"
34 | ```
35 |
36 | ## ⚡️ Quick example
37 | Use `jpc.make_pc_step()` to update the parameters of any neural network
38 | compatible with PC updates (see [examples](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/))
39 | ```py
40 | import jax.random as jr
41 | import jax.numpy as jnp
42 | import equinox as eqx
43 | import optax
44 | import jpc
45 |
46 | # toy data
47 | x = jnp.array([1., 1., 1.])
48 | y = -x
49 |
50 | # define model and optimiser
51 | key = jr.PRNGKey(0)
52 | model = jpc.make_mlp(
53 | key,
54 | input_dim=3,
55 | width=50,
56 | depth=5,
57 |     output_dim=3,
58 | act_fn="relu"
59 | )
60 | optim = optax.adam(1e-3)
61 | opt_state = optim.init(
62 | (eqx.filter(model, eqx.is_array), None)
63 | )
64 |
65 | # perform one training step with PC
66 | result = jpc.make_pc_step(
67 | model=model,
68 | optim=optim,
69 | opt_state=opt_state,
70 | output=y,
71 | input=x
72 | )
73 |
74 | # updated model and optimiser
75 | model, opt_state = result["model"], result["opt_state"]
76 | ```
77 | Under the hood, `jpc.make_pc_step()`
78 |
79 | 1. integrates the inference (activity) dynamics using a [diffrax](https://github.com/patrick-kidger/diffrax) ODE solver, and
80 | 2. updates model parameters at the numerical solution of the activities with a given [optax](https://github.com/google-deepmind/optax) optimiser.
81 |
82 | > **NOTE**: All convenience training and test functions such as `make_pc_step()`
83 | > are already jitted for optimised performance.
84 |
85 | ## 🚀 Advanced usage
86 | Advanced users can access all the underlying functions of `jpc.make_pc_step()`
87 | as well as additional features. A custom PC training step looks like the
88 | following:
89 | ```py
90 | import jpc
91 | import optax
92 | # 1. initialise activities with a feedforward pass
93 | activities = jpc.init_activities_with_ffwd(model=model, input=x)
94 |
95 | # 2. perform inference (state optimisation)
96 | activity_optim = optax.sgd(5e-1)  # illustrative optimiser and learning rate
97 | activity_opt_state = activity_optim.init(activities)
98 | for _ in range(20):  # e.g. 20 inference iterations
99 |     activity_update_result = jpc.update_pc_activities(
100 |         params=(model, None),
101 |         activities=activities,
102 |         optim=activity_optim,
103 |         opt_state=activity_opt_state,
104 |         output=y,
105 |         input=x
106 |     )
107 |     activities = activity_update_result["activities"]
108 |     activity_opt_state = activity_update_result["opt_state"]
109 | # 3. update parameters at the activities' solution with PC
110 | result = jpc.update_params(
111 |     params=(model, None),
112 |     activities=activities,
113 |     optim=optim,
114 |     opt_state=opt_state,
115 |     output=y,
116 |     input=x
117 | )
118 | ```
119 | which can be embedded in a jitted function with any other additional
120 | computations.
121 |
122 | ## 📄 Citation
123 | If you found this library useful in your work, please cite ([paper link](https://arxiv.org/abs/2412.03676)):
124 |
125 | ```bibtex
126 | @article{innocenti2024jpc,
127 | title={JPC: Flexible Inference for Predictive Coding Networks in JAX},
128 | author={Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Varona, Miguel De Llanza and Singh, Ryan and Buckley, Christopher L},
129 | journal={arXiv preprint arXiv:2412.03676},
130 | year={2024}
131 | }
132 | ```
133 | Also consider starring the project [on GitHub](https://github.com/thebuckleylab/jpc)! ⭐️
134 |
135 | ## 🙏 Acknowledgements
136 | We are grateful to Patrick Kidger for early advice on how to use Diffrax.
137 |
138 | ## See also: other PC libraries
139 | * [ngc-learn](https://github.com/NACLab/ngc-learn) (jax & pytorch)
140 | * [pcx](https://github.com/liukidar/pcx) (jax)
141 | * [pyhgf](https://github.com/ComputationalPsychiatry/pyhgf) (jax)
142 | * [Torch2PC](https://github.com/RobertRosenbaum/Torch2PC) (pytorch)
143 | * [pypc](https://github.com/infer-actively/pypc) (pytorch)
144 | * [pybrid](https://github.com/alec-tschantz/pybrid) (pytorch)
145 |
--------------------------------------------------------------------------------
/docs/_static/logo-dark.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/tests/test_test_functions.py:
--------------------------------------------------------------------------------
1 | """Tests for test utility functions."""
2 |
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | from diffrax import Heun, PIDController
7 | from jpc import test_discriminative_pc, test_generative_pc, test_hpc
8 |
9 |
10 | def test_test_discriminative_pc(simple_model, x, y_onehot):
11 | """Test discriminative PC testing function."""
12 | # For accuracy computation, we need one-hot encoded targets
13 | loss, acc = test_discriminative_pc(
14 | model=simple_model,
15 | output=y_onehot,
16 | input=x,
17 | loss="ce", # Use cross-entropy for one-hot targets
18 | param_type="sp"
19 | )
20 |
21 | assert jnp.isfinite(loss)
22 | assert jnp.isfinite(acc)
23 | assert 0 <= acc <= 100
24 |
25 |
26 | def test_test_discriminative_pc_cross_entropy(simple_model, x, y_onehot):
27 | """Test discriminative PC testing with cross-entropy."""
28 | loss, acc = test_discriminative_pc(
29 | model=simple_model,
30 | output=y_onehot,
31 | input=x,
32 | loss="ce",
33 | param_type="sp"
34 | )
35 |
36 | assert jnp.isfinite(loss)
37 | assert jnp.isfinite(acc)
38 |
39 |
40 | def test_test_discriminative_pc_with_skip(simple_model, x, y_onehot):
41 | """Test discriminative PC testing with skip connections."""
42 | from jpc import make_skip_model
43 |
44 | skip_model = make_skip_model(len(simple_model))
45 |
46 | # compute_accuracy requires one-hot encoded targets
47 | loss, acc = test_discriminative_pc(
48 | model=simple_model,
49 | output=y_onehot,
50 | input=x,
51 | skip_model=skip_model,
52 | loss="ce", # Use cross-entropy for one-hot targets
53 | param_type="sp"
54 | )
55 |
56 | assert jnp.isfinite(loss)
57 | assert jnp.isfinite(acc)
58 |
59 |
60 | def test_test_generative_pc(simple_model, x, y_onehot, key, layer_sizes, batch_size, input_dim):
61 | """Test generative PC testing function."""
62 | # For compute_accuracy, input needs to be one-hot encoded
63 | key1, key2 = jax.random.split(key)
64 | input_onehot = jax.nn.one_hot(
65 | jax.random.randint(key1, (batch_size,), 0, input_dim),
66 | input_dim
67 | )
68 | input_acc, output_preds = test_generative_pc(
69 | model=simple_model,
70 | output=y_onehot,
71 | input=input_onehot,
72 | key=key2,
73 | layer_sizes=layer_sizes,
74 | batch_size=batch_size,
75 | loss_id="ce",
76 | param_type="sp",
77 | sigma=0.05,
78 | ode_solver=Heun(),
79 | max_t1=10,
80 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
81 | )
82 |
83 | assert jnp.isfinite(input_acc)
84 | assert 0 <= input_acc <= 100
85 | assert output_preds.shape == y_onehot.shape
86 |
87 |
88 | def test_test_generative_pc_cross_entropy(simple_model, x, y_onehot, key, layer_sizes, batch_size, input_dim):
89 | """Test generative PC testing with cross-entropy."""
90 | # For compute_accuracy, input needs to be one-hot encoded
91 | key1, key2 = jax.random.split(key)
92 | input_onehot = jax.nn.one_hot(
93 | jax.random.randint(key1, (batch_size,), 0, input_dim),
94 | input_dim
95 | )
96 | input_acc, output_preds = test_generative_pc(
97 | model=simple_model,
98 | output=y_onehot,
99 | input=input_onehot,
100 | key=key2,
101 | layer_sizes=layer_sizes,
102 | batch_size=batch_size,
103 | loss_id="ce",
104 | param_type="sp",
105 | sigma=0.05,
106 | ode_solver=Heun(),
107 | max_t1=10,
108 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
109 | )
110 |
111 | assert jnp.isfinite(input_acc)
112 | assert output_preds.shape == y_onehot.shape
113 |
114 |
115 | def test_test_hpc(key, simple_model, x, y_onehot, output_dim, hidden_dim, input_dim, depth, layer_sizes, batch_size):
116 | """Test HPC testing function."""
117 | from jpc import make_mlp
118 |
119 | # Split keys first
120 | key1, key2, key3 = jax.random.split(key, 3)
121 |
122 | generator = simple_model
123 | amortiser = make_mlp(
124 | key=key2,
125 | input_dim=output_dim,
126 | width=hidden_dim,
127 | depth=depth,
128 | output_dim=input_dim,
129 | act_fn="relu",
130 | use_bias=False,
131 | param_type="sp"
132 | )
133 |
134 | # For compute_accuracy, input needs to be one-hot encoded
135 | input_onehot = jax.nn.one_hot(
136 | jax.random.randint(key1, (batch_size,), 0, input_dim),
137 | input_dim
138 | )
139 | amort_acc, hpc_acc, gen_acc, output_preds = test_hpc(
140 | generator=generator,
141 | amortiser=amortiser,
142 | output=y_onehot,
143 | input=input_onehot,
144 | key=key3,
145 | layer_sizes=layer_sizes,
146 | batch_size=batch_size,
147 | sigma=0.05,
148 | ode_solver=Heun(),
149 | max_t1=10,
150 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
151 | )
152 |
153 | assert jnp.isfinite(amort_acc)
154 | assert jnp.isfinite(hpc_acc)
155 | assert jnp.isfinite(gen_acc)
156 | assert 0 <= amort_acc <= 100
157 | assert 0 <= hpc_acc <= 100
158 | assert 0 <= gen_acc <= 100
159 | assert output_preds.shape == y_onehot.shape
160 |
161 |
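162 | def test_test_generative_pc_mse(simple_model, y_onehot, key, layer_sizes, batch_size, input_dim):
163 |     """Sketch: generative PC testing with MSE loss. Assumes `loss_id="mse"` is
164 |     accepted by `test_generative_pc`, as it is by the inference functions."""
165 |     key1, key2 = jax.random.split(key)
166 |     input_onehot = jax.nn.one_hot(
167 |         jax.random.randint(key1, (batch_size,), 0, input_dim),
168 |         input_dim
169 |     )
170 |     input_acc, output_preds = test_generative_pc(
171 |         model=simple_model,
172 |         output=y_onehot,
173 |         input=input_onehot,
174 |         key=key2,
175 |         layer_sizes=layer_sizes,
176 |         batch_size=batch_size,
177 |         loss_id="mse",
178 |         param_type="sp",
179 |         sigma=0.05,
180 |         ode_solver=Heun(),
181 |         max_t1=10,
182 |         stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
183 |     )
184 | 
185 |     assert jnp.isfinite(input_acc)
186 |     assert output_preds.shape == y_onehot.shape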
--------------------------------------------------------------------------------
/experiments/mupc_paper/test_energy_theory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import jax.random as jr
4 |
5 | import jpc
6 | import optax
7 | import equinox as eqx
8 | from experiments.library_paper.utils import set_seed  # used by run_test below
9 | from experiments.datasets import get_dataloaders
10 |
11 |
12 | def run_test(
13 | seed,
14 | dataset,
15 | width,
16 | n_hidden,
17 | act_fn,
18 | use_skips,
19 | param_type,
20 | param_optim_id,
21 | param_lr,
22 | batch_size,
23 | save_dir
24 | ):
25 | set_seed(seed)
26 | key = jr.PRNGKey(seed)
27 | model_key, init_key = jr.split(key, 2)
28 | os.makedirs(save_dir, exist_ok=True)
29 |
30 | # create and initialise model
31 | d_in, d_out = 784, 10
32 | L = n_hidden + 1
33 | model = jpc.make_mlp(
34 | key=model_key,
35 | input_dim=d_in,
36 | width=width,
37 | depth=L,
38 | output_dim=d_out,
39 | act_fn=act_fn,
40 | use_bias=False,
41 | param_type=param_type
42 | )
43 | skip_model = jpc.make_skip_model(L) if use_skips else None
44 |
45 | # optimisers
46 | if param_optim_id == "sgd":
47 | param_optim = optax.sgd(param_lr)
48 | elif param_optim_id == "adam":
49 | param_optim = optax.adam(param_lr)
50 |
51 | param_opt_state = param_optim.init(
52 | (eqx.filter(model, eqx.is_array), skip_model)
53 | )
54 |
55 | # data & metrics
56 | train_loader, _ = get_dataloaders(dataset, batch_size)
57 | train_losses, train_energies = [], []
58 | loss_energy_ratios = []
59 |
60 | for t, (img_batch, label_batch) in enumerate(train_loader):
61 | x, y = img_batch.numpy(), label_batch.numpy()
62 |
63 | # compute loss
64 | activities = jpc.init_activities_with_ffwd(
65 | model=model,
66 | input=x,
67 | skip_model=skip_model,
68 | param_type=param_type
69 | )
70 | loss = 0.5 * np.sum((y - activities[-1])**2) / batch_size
71 |
72 | # compute theoretical activities & energy
73 | activities = jpc.compute_linear_activity_solution(
74 | network=model,
75 | x=x,
76 | y=y,
77 | use_skips=use_skips,
78 | param_type=param_type
79 | )
80 | energy = jpc.pc_energy_fn(
81 | params=(model, skip_model),
82 | activities=activities,
83 | y=y,
84 | x=x,
85 | param_type=param_type
86 | )
87 |
88 | # update parameters
89 | param_update_result = jpc.update_params(
90 | params=(model, skip_model),
91 | activities=activities,
92 | optim=param_optim,
93 | opt_state=param_opt_state,
94 | output=y,
95 | input=x,
96 | param_type=param_type
97 | )
98 | model = param_update_result["model"]
99 | skip_model = param_update_result["skip_model"]
100 | param_opt_state = param_update_result["opt_state"]
101 |
102 | train_losses.append(loss)
103 | train_energies.append(energy)
104 | loss_energy_ratios.append(loss/energy)
105 |
106 | if t % 200 == 0:
107 | print(
108 | f"\t{t * len(img_batch)}/{len(train_loader.dataset)}, "
109 | f"loss: {loss:.4f}, energy: {energy:.4f}, ratio: {loss/energy:.4f} "
110 | )
111 |
112 | np.save(f"{save_dir}/train_losses.npy", train_losses)
113 | np.save(f"{save_dir}/train_energies.npy", train_energies)
114 | np.save(f"{save_dir}/loss_energy_ratios.npy", loss_energy_ratios)
115 |
116 |
117 | if __name__ == "__main__":
118 |
119 | RESULTS_DIR = "energy_theory_results"
120 | DATASET = "MNIST"
121 | USE_SKIPS = True
122 | PARAM_OPTIM_ID = "adam"
123 | PARAM_LR = 1e-2
124 | BATCH_SIZE = 64
125 |
126 | ACT_FNS = ["linear"]
127 |     PARAM_TYPES = ["mupc"]  # optionally also "sp"
128 | WIDTHS = [2**i for i in range(7)] # 7 or 10 max
129 | N_HIDDENS = [2**i for i in range(7)] # 4 or 7 max
130 | SEED = 4320
131 |
132 | for act_fn in ACT_FNS:
133 | for param_type in PARAM_TYPES:
134 | for width in WIDTHS:
135 | for n_hidden in N_HIDDENS:
136 | print(
137 | f"\nAct fn: {act_fn}\n"
138 | f"Param type: {param_type}\n"
139 | f"Use skips: {USE_SKIPS}\n"
140 | f"Param optim: {PARAM_OPTIM_ID}\n"
141 | f"Param lr: {PARAM_LR}\n"
142 | f"Width: {width}\n"
143 | f"N hidden: {n_hidden}\n"
144 | f"Seed: {SEED}\n"
145 | )
146 | save_dir = os.path.join(
147 | RESULTS_DIR,
148 | act_fn,
149 | param_type,
150 | "skips" if USE_SKIPS else "no_skips",
151 | PARAM_OPTIM_ID,
152 | f"param_lr_{PARAM_LR}",
153 | f"width_{width}",
154 | f"{n_hidden}_n_hidden",
155 | str(SEED)
156 | )
157 | run_test(
158 | seed=SEED,
159 | dataset=DATASET,
160 | width=width,
161 | n_hidden=n_hidden,
162 | act_fn=act_fn,
163 | use_skips=USE_SKIPS,
164 | param_type=param_type,
165 | param_optim_id=PARAM_OPTIM_ID,
166 | param_lr=PARAM_LR,
167 | batch_size=BATCH_SIZE,
168 | save_dir=save_dir
169 | )
170 |
--------------------------------------------------------------------------------
/jpc/_core/_infer.py:
--------------------------------------------------------------------------------
1 | """Function to solve the inference (activity) dynamics of PC networks."""
2 |
3 | from jaxtyping import PyTree, ArrayLike, Array, Scalar
4 | import jax.numpy as jnp
5 | from typing import Tuple, Callable, Optional
6 | from ._grads import neg_pc_activity_grad
7 | from optimistix import rms_norm
8 | from diffrax import (
9 | AbstractSolver,
10 | AbstractStepSizeController,
11 | Heun,
12 | PIDController,
13 | diffeqsolve,
14 | ODETerm,
15 | Event,
16 | SaveAt
17 | )
18 |
19 |
20 | def solve_inference(
21 | params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
22 | activities: PyTree[ArrayLike],
23 | output: ArrayLike,
24 | *,
25 | input: Optional[ArrayLike] = None,
26 | loss_id: str = "mse",
27 | param_type: str = "sp",
28 | solver: AbstractSolver = Heun(),
29 | max_t1: int = 20,
30 |     dt: Optional[float | int] = None,
31 | stepsize_controller: AbstractStepSizeController = PIDController(
32 | rtol=1e-3, atol=1e-3
33 | ),
34 | weight_decay: Scalar = 0.,
35 | spectral_penalty: Scalar = 0.,
36 | activity_decay: Scalar = 0.,
37 | record_iters: bool = False,
38 |     record_every: Optional[int] = None
39 | ) -> PyTree[Array]:
40 | """Solves the inference (activity) dynamics of a predictive coding network.
41 |
42 | This is a wrapper around [`diffrax.diffeqsolve()`](https://docs.kidger.site/diffrax/api/diffeqsolve/#diffrax.diffeqsolve)
43 | to integrate the gradient ODE system [`jpc.neg_activity_grad()`](https://thebuckleylab.github.io/jpc/api/Gradients/#jpc.neg_activity_grad)
44 | defining the PC inference dynamics.
45 |
46 | $$
47 |     d\mathbf{z} / dt = - \nabla_{\mathbf{z}} \mathcal{F}
48 | $$
49 |
50 | where $\mathcal{F}$ is the free energy, $\mathbf{z}$ are the activities,
51 | with $\mathbf{z}_L$ clamped to some target and $\mathbf{z}_0$ optionally
52 | set to some prior.
53 |
54 | **Main arguments:**
55 |
56 | - `params`: Tuple with callable model layers and optional skip connections.
57 | - `activities`: List of activities for each layer free to vary.
58 | - `output`: Observation or target of the generative model.
59 |
60 | **Other arguments:**
61 |
62 | - `input`: Optional prior of the generative model.
63 | - `loss_id`: Loss function to use at the output layer. Options are mean squared
64 | error `"mse"` (default) or cross-entropy `"ce"`.
65 | - `param_type`: Determines the parameterisation. Options are `"sp"`
66 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))),
67 | or `"ntp"` (neural tangent parameterisation).
68 | See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings)
69 | for the specific scalings of these different parameterisations. Defaults
70 | to `"sp"`.
71 | - `solver`: [diffrax ODE solver](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/)
72 | to be used. Default is [`Heun`](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/#diffrax.Heun),
73 |       a 2nd-order explicit Runge-Kutta method.
74 | - `max_t1`: Maximum end of integration region (20 by default).
75 | - `dt`: Integration step size. Defaults to `None` since the default
76 | `stepsize_controller` will automatically determine it.
77 | - `stepsize_controller`: [diffrax controller](https://docs.kidger.site/diffrax/api/stepsize_controller/)
78 | for step size integration. Defaults to [`PIDController`](https://docs.kidger.site/diffrax/api/stepsize_controller/#diffrax.PIDController).
79 | Note that the relative and absolute tolerances of the controller will
80 | also determine the steady state to terminate the solver.
81 | - `weight_decay`: $\ell^2$ regulariser for the weights (0 by default).
82 | - `spectral_penalty`: Weight spectral penalty of the form
83 | $||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2$ (0 by default).
84 | - `activity_decay`: $\ell^2$ regulariser for the activities (0 by default).
85 | - `record_iters`: If `True`, returns all integration steps.
86 | - `record_every`: int determining the sampling frequency of the integration
87 | steps.
88 |
89 | **Returns:**
90 |
91 | List with solution of the activity dynamics for each layer.
92 |
93 | """
94 | if record_every is not None:
95 | ts = jnp.arange(0, max_t1, record_every)
96 | saveat = SaveAt(t1=True, ts=ts)
97 | else:
98 | saveat = SaveAt(t1=True, steps=record_iters)
99 |
100 | solution = diffeqsolve(
101 | terms=ODETerm(neg_pc_activity_grad),
102 | solver=solver,
103 | t0=0,
104 | t1=max_t1,
105 | dt0=dt,
106 | y0=activities,
107 | args=(
108 | params,
109 | output,
110 | input,
111 | loss_id,
112 | param_type,
113 | weight_decay,
114 | spectral_penalty,
115 | activity_decay,
116 | stepsize_controller
117 | ),
118 | stepsize_controller=stepsize_controller,
119 | event=Event(steady_state_event_with_timeout),
120 | saveat=saveat
121 | )
122 | return solution.ys
123 |
124 |
125 | def steady_state_event_with_timeout(t, y, args, **kwargs):
126 | _stepsize_controller = args[-1]
127 | try:
128 | _atol = _stepsize_controller.atol
129 | _rtol = _stepsize_controller.rtol
130 |     except AttributeError:  # controller without rtol/atol attributes
131 | _atol, _rtol = 1e-3, 1e-3
132 |     dy = neg_pc_activity_grad(t, y, args)  # ODE vector field (negative energy gradient)
133 |     steady_state_reached = rms_norm(dy) < _atol + _rtol * rms_norm(y)
134 |     timeout_reached = jnp.array(t >= 4096, dtype=jnp.bool_)
135 |     return jnp.logical_or(steady_state_reached, timeout_reached)
136 | 
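137 | # Example usage (sketch), assuming `model` is a list of callable layers such as
138 | # one built with `jpc.make_mlp()`, and `x`/`y` are batched input/target arrays:
139 | #
140 | #     from jpc import init_activities_with_ffwd, solve_inference
141 | #
142 | #     activities = init_activities_with_ffwd(model, x)
143 | #     equilibrated = solve_inference(
144 | #         params=(model, None), activities=activities, output=y, input=x
145 | #     )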
--------------------------------------------------------------------------------
/tests/test_analytical.py:
--------------------------------------------------------------------------------
1 | """Tests for analytical tools."""
2 |
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | import equinox.nn as nn
7 | from jpc import (
8 | compute_linear_equilib_energy,
9 | compute_linear_activity_hessian,
10 | compute_linear_activity_solution
11 | )
12 |
13 |
14 | def test_compute_linear_equilib_energy(key, x, y, input_dim, hidden_dim, output_dim, depth):
15 | """Test computation of linear equilibrium energy."""
16 | # Create a linear network (no activation functions)
17 | subkeys = jax.random.split(key, depth)
18 | network = []
19 | for i in range(depth):
20 | _in = input_dim if i == 0 else hidden_dim
21 | _out = output_dim if (i + 1) == depth else hidden_dim
22 | linear = nn.Linear(_in, _out, use_bias=False, key=subkeys[i])
23 | network.append(nn.Sequential([nn.Lambda(lambda x: x), linear]))
24 |
25 | energy = compute_linear_equilib_energy(
26 | network=network,
27 | x=x,
28 | y=y
29 | )
30 |
31 | assert jnp.isfinite(energy)
32 | assert energy >= 0
33 |
34 |
35 | def test_compute_linear_activity_hessian(key, input_dim, hidden_dim, output_dim, depth):
36 | """Test computation of linear activity Hessian."""
37 | # Extract weight matrices
38 | subkeys = jax.random.split(key, depth)
39 | Ws = []
40 | for i in range(depth):
41 | _in = input_dim if i == 0 else hidden_dim
42 | _out = output_dim if (i + 1) == depth else hidden_dim
43 | W = jax.random.normal(subkeys[i], (_out, _in))
44 | Ws.append(W)
45 |
46 | hessian = compute_linear_activity_hessian(
47 | Ws=Ws,
48 | use_skips=False,
49 | param_type="sp",
50 | activity_decay=False,
51 | diag=True,
52 | off_diag=True
53 | )
54 |
55 | # Check shape: should be (sum of hidden layer sizes) x (sum of hidden layer sizes)
56 | hidden_sizes = [hidden_dim] * (depth - 1)
57 | expected_size = sum(hidden_sizes)
58 | assert hessian.shape == (expected_size, expected_size)
59 | assert jnp.all(jnp.isfinite(hessian))
60 |
61 |
62 | def test_compute_linear_activity_hessian_with_skips(key, input_dim, hidden_dim, output_dim, depth):
63 | """Test computation of linear activity Hessian with skip connections."""
64 | subkeys = jax.random.split(key, depth)
65 | Ws = []
66 | for i in range(depth):
67 | _in = input_dim if i == 0 else hidden_dim
68 | _out = output_dim if (i + 1) == depth else hidden_dim
69 | W = jax.random.normal(subkeys[i], (_out, _in))
70 | Ws.append(W)
71 |
72 | hessian = compute_linear_activity_hessian(
73 | Ws=Ws,
74 | use_skips=True,
75 | param_type="sp",
76 | activity_decay=False,
77 | diag=True,
78 | off_diag=True
79 | )
80 |
81 | hidden_sizes = [hidden_dim] * (depth - 1)
82 | expected_size = sum(hidden_sizes)
83 | assert hessian.shape == (expected_size, expected_size)
84 |
85 |
86 | def test_compute_linear_activity_hessian_different_param_types(key, input_dim, hidden_dim, output_dim, depth):
87 | """Test Hessian computation with different parameter types."""
88 | subkeys = jax.random.split(key, depth)
89 | Ws = []
90 | for i in range(depth):
91 | _in = input_dim if i == 0 else hidden_dim
92 | _out = output_dim if (i + 1) == depth else hidden_dim
93 | W = jax.random.normal(subkeys[i], (_out, _in))
94 | Ws.append(W)
95 |
96 | # Note: The library code uses "ntp" in _analytical.py but "ntk" in _errors.py
97 | # For now, test with "sp" and "mupc" which work, and skip "ntk"/"ntp" due to inconsistency
98 | for param_type in ["sp", "mupc"]:
99 | hessian = compute_linear_activity_hessian(
100 | Ws=Ws,
101 | use_skips=False,
102 | param_type=param_type,
103 | activity_decay=False,
104 | diag=True,
105 | off_diag=True
106 | )
107 |
108 | hidden_sizes = [hidden_dim] * (depth - 1)
109 | expected_size = sum(hidden_sizes)
110 | assert hessian.shape == (expected_size, expected_size)
111 |
112 |
113 | def test_compute_linear_activity_solution(key, x, y, input_dim, hidden_dim, output_dim, depth):
114 | """Test computation of linear activity solution."""
115 | # Create a linear network
116 | subkeys = jax.random.split(key, depth)
117 | network = []
118 | for i in range(depth):
119 | _in = input_dim if i == 0 else hidden_dim
120 | _out = output_dim if (i + 1) == depth else hidden_dim
121 | linear = nn.Linear(_in, _out, use_bias=False, key=subkeys[i])
122 | network.append(nn.Sequential([nn.Lambda(lambda x: x), linear]))
123 |
124 | activities = compute_linear_activity_solution(
125 | network=network,
126 | x=x,
127 | y=y,
128 | use_skips=False,
129 | param_type="sp",
130 | activity_decay=False
131 | )
132 |
133 | # Should return activities for hidden layers plus dummy target prediction
134 | assert len(activities) == depth
135 | assert activities[0].shape == (x.shape[0], hidden_dim)
136 | assert activities[-1].shape == (x.shape[0], output_dim)
137 |
138 |
139 | def test_compute_linear_activity_solution_with_skips(key, x, y, input_dim, hidden_dim, output_dim, depth):
140 | """Test computation of linear activity solution with skip connections."""
141 | subkeys = jax.random.split(key, depth)
142 | network = []
143 | for i in range(depth):
144 | _in = input_dim if i == 0 else hidden_dim
145 | _out = output_dim if (i + 1) == depth else hidden_dim
146 | linear = nn.Linear(_in, _out, use_bias=False, key=subkeys[i])
147 | network.append(nn.Sequential([nn.Lambda(lambda x: x), linear]))
148 |
149 | activities = compute_linear_activity_solution(
150 | network=network,
151 | x=x,
152 | y=y,
153 | use_skips=True,
154 | param_type="sp",
155 | activity_decay=False
156 | )
157 |
158 | assert len(activities) == depth
159 |
160 |
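161 | def test_equilib_energy_matches_energy_at_solution(key, x, y, input_dim, hidden_dim, output_dim, depth):
162 |     """Sketch: for a linear network, the theoretical equilibrated energy should
163 |     agree with the energy evaluated at the analytical activity solution, as
164 |     compared in the experiment scripts. The tolerance is an assumption."""
165 |     from jpc import pc_energy_fn
166 | 
167 |     subkeys = jax.random.split(key, depth)
168 |     network = []
169 |     for i in range(depth):
170 |         _in = input_dim if i == 0 else hidden_dim
171 |         _out = output_dim if (i + 1) == depth else hidden_dim
172 |         linear = nn.Linear(_in, _out, use_bias=False, key=subkeys[i])
173 |         network.append(nn.Sequential([nn.Lambda(lambda x: x), linear]))
174 | 
175 |     theory_energy = compute_linear_equilib_energy(network=network, x=x, y=y)
176 |     activities = compute_linear_activity_solution(
177 |         network=network,
178 |         x=x,
179 |         y=y,
180 |         use_skips=False,
181 |         param_type="sp",
182 |         activity_decay=False
183 |     )
184 |     energy_at_solution = pc_energy_fn(
185 |         params=(network, None),
186 |         activities=activities,
187 |         y=y,
188 |         x=x,
189 |         param_type="sp"
190 |     )
191 | 
192 |     assert jnp.isclose(theory_energy, energy_at_solution, rtol=1e-3)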
--------------------------------------------------------------------------------
/jpc/_core/_init.py:
--------------------------------------------------------------------------------
1 | """Functions to initialise the layer activities of PC networks."""
2 |
3 | from jax import vmap, random
4 | import jax.numpy as jnp
5 | import equinox as eqx
6 | from ._energies import _get_param_scalings
7 | from jaxtyping import PyTree, ArrayLike, Array, PRNGKeyArray, Scalar
8 | from typing import Callable, Optional
9 | from ._errors import _check_param_type
10 |
11 |
12 | @eqx.filter_jit
13 | def init_activities_with_ffwd(
14 | model: PyTree[Callable],
15 | input: ArrayLike,
16 | *,
17 | skip_model: Optional[PyTree[Callable]] = None,
18 | param_type: str = "sp"
19 | ) -> PyTree[Array]:
20 | """Initialises the layers' activity with a feedforward pass
21 | $\{ f_\ell(\mathbf{z}_{\ell-1}) \}_{\ell=1}^L$ where $f_\ell(\cdot)$ is some
22 | callable layer transformation and $\mathbf{z}_0 = \mathbf{x}$ is the input.
23 |
24 | !!! warning
25 |
26 | `param_type = "mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))) assumes
27 | that one is using [`jpc.make_mlp()`](https://thebuckleylab.github.io/jpc/api/Utils/#jpc.make_mlp)
28 | to create the model.
29 |
30 | **Main arguments:**
31 |
32 | - `model`: List of callable model (e.g. neural network) layers.
33 | - `input`: input to the model.
34 |
35 | **Other arguments:**
36 |
37 | - `skip_model`: Optional skip connection model.
38 | - `param_type`: Determines the parameterisation. Options are `"sp"`
39 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))),
40 | or `"ntp"` (neural tangent parameterisation).
41 | See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings)
42 | for the specific scalings of these different parameterisations. Defaults
43 | to `"sp"`.
44 |
45 | **Returns:**
46 |
47 | List with activity values of each layer.
48 |
49 | """
50 | _check_param_type(param_type)
51 |
52 | L = len(model)
53 | if skip_model is None:
54 | skip_model = [None] * len(model)
55 |
56 | scalings = _get_param_scalings(
57 | model=model,
58 | input=input,
59 | skip_model=skip_model,
60 | param_type=param_type
61 | )
62 |
63 | z1 = scalings[0] * vmap(model[0])(input)
64 | if skip_model[0] is not None:
65 | z1 += vmap(skip_model[0])(input)
66 |
67 | activities = [z1]
68 | for l in range(1, L):
69 | zl = scalings[l] * vmap(model[l])(activities[l - 1])
70 |
71 | if skip_model[l] is not None:
72 | skip_output = vmap(skip_model[l])(activities[l - 1])
73 | zl += skip_output
74 |
75 | activities.append(zl)
76 |
77 | return activities
78 |
79 |
80 | def init_activities_from_normal(
81 | key: PRNGKeyArray,
82 | layer_sizes: PyTree[int],
83 | mode: str,
84 | batch_size: int,
85 | sigma: Scalar = 0.05
86 | ) -> PyTree[Array]:
87 | """Initialises network activities from a zero-mean Gaussian
88 | $z_i \sim \mathcal{N}(0, \sigma^2)$.
89 |
90 | **Main arguments:**
91 |
92 | - `key`: `jax.random.PRNGKey` for sampling.
93 | - `layer_sizes`: List with dimension of all layers (input, hidden and
94 | output).
95 | - `mode`: If `"supervised"`, all hidden layers are initialised. If
96 | `"unsupervised"` the input layer $\mathbf{z}_0$ is also initialised.
97 | - `batch_size`: Dimension of data batch.
98 | - `sigma`: Standard deviation for Gaussian to sample activities from.
99 | Defaults to 5e-2.
100 |
101 | **Returns:**
102 |
103 | List of randomly initialised activities for each layer.
104 |
105 | """
106 | start_l = 0 if mode == "unsupervised" else 1
107 | n_layers = len(layer_sizes) if mode == "unsupervised" else len(layer_sizes)-1
108 | activities = []
109 |     for l, subkey in zip(
110 |         range(start_l, len(layer_sizes)),
111 |         random.split(key, num=n_layers)
112 |     ):
113 | activities.append(sigma * random.normal(
114 | subkey,
115 | shape=(batch_size, layer_sizes[l])
116 | )
117 | )
118 | return activities
119 |
120 |
121 | def init_activities_with_amort(
122 | amortiser: PyTree[Callable],
123 | generator: PyTree[Callable],
124 | input: ArrayLike
125 | ) -> PyTree[Array]:
126 | """Initialises layers' activity with an amortised network
127 | $\{ f_{L-\ell+1}(\mathbf{z}_{L-\ell}) \}_{\ell=1}^L$ where $\mathbf{z}_0 = \mathbf{y}$ is
128 | the input or generator's target.
129 |
130 | !!! note
131 |
132 | The output order is reversed for downstream use by the generator.
133 |
134 | **Main arguments:**
135 |
136 | - `amortiser`: List of callable layers for model amortising the inference
137 | of the `generator`.
138 | - `generator`: List of callable layers for the generative model.
139 | - `input`: Input to the amortiser.
140 |
141 | **Returns:**
142 |
143 | List with amortised initialisation of each layer.
144 |
145 | """
146 | activities = [vmap(amortiser[0])(input)]
147 | for l in range(1, len(amortiser)):
148 | activities.append(vmap(amortiser[l])(activities[l - 1]))
149 |
150 | activities = activities[::-1]
151 |
152 | # NOTE: this dummy activity for the last layer is added in case one is
153 | # interested in inspecting the generator's target prediction during inference.
154 | activities.append(
155 | vmap(generator[-1])(activities[-1])
156 | )
157 | return activities
158 |
159 |
160 | def init_epc_errors(
161 | layer_sizes: PyTree[int],
162 | batch_size: int,
163 | mode: str = "supervised"
164 | ) -> PyTree[Array]:
165 |     """Initialises zero errors for use with ePC $\{ \epsilon_\ell = 0 \}_{\ell=1}^L$.
166 |
167 | **Main arguments:**
168 |
169 | - `layer_sizes`: List with dimension of all layers (input, hidden and
170 | output).
171 | - `batch_size`: Dimension of data batch.
172 |     - `mode`: If `"supervised"`, errors are initialised for every layer
173 |       except the input (layers 1 to L). If `"unsupervised"`, an error is
174 |       also initialised for the input layer 0. Defaults to `"supervised"`.
175 |
176 | **Returns:**
177 |
178 | List of zero-initialised error arrays for each layer.
179 |
180 | """
181 |     # the input layer (0) only gets an error in unsupervised mode
182 |     start_l = 0 if mode == "unsupervised" else 1
183 |     errors = []
184 |     for l in range(start_l, len(layer_sizes)):
185 |         errors.append(jnp.zeros(shape=(batch_size, layer_sizes[l])))
186 |     return errors
187 |
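188 | 
189 | # Example (sketch): for `layer_sizes = [784, 300, 10]` and `batch_size = 2`,
190 | # supervised initialisation yields one array per non-input layer:
191 | #
192 | #     acts = init_activities_from_normal(
193 | #         key, layer_sizes=[784, 300, 10], mode="supervised", batch_size=2
194 | #     )
195 | #     errs = init_epc_errors(layer_sizes=[784, 300, 10], batch_size=2)
196 | #     # acts[i].shape == errs[i].shape == (2, [300, 10][i])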
--------------------------------------------------------------------------------
/experiments/datasets.py:
--------------------------------------------------------------------------------
1 | import jax.random as jr
2 | import jax.numpy as jnp
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torchvision import datasets, transforms
7 |
8 |
9 | DATA_DIR = "datasets"
10 | IMAGENET_DIR = f"~/projects/jpc/experiments/{DATA_DIR}/ImageNet"
11 |
12 |
13 | def make_gaussian_dataset(key, mean, std, shape):
14 | x = mean + std * jr.normal(key, shape)
15 | y = x
16 | return (x, y)
17 |
18 |
19 | def get_dataloaders(dataset_id, batch_size, flatten=True):
20 | train_data = get_dataset(
21 | id=dataset_id,
22 | train=True,
23 | normalise=True,
24 | flatten=flatten
25 | )
26 | test_data = get_dataset(
27 | id=dataset_id,
28 | train=False,
29 | normalise=True,
30 | flatten=flatten
31 | )
32 | train_loader = DataLoader(
33 | dataset=train_data,
34 | batch_size=batch_size,
35 | shuffle=True,
36 | drop_last=True
37 | )
38 | test_loader = DataLoader(
39 | dataset=test_data,
40 | batch_size=batch_size,
41 | shuffle=True,
42 | drop_last=True
43 | )
44 | return train_loader, test_loader
45 |
46 |
47 | def get_dataset(id, train, normalise, flatten=True):
48 | if id == "MNIST":
49 | dataset = MNIST(train=train, normalise=normalise, flatten=flatten)
50 | elif id == "Fashion-MNIST":
51 | dataset = FashionMNIST(train=train, normalise=normalise, flatten=flatten)
52 | elif id == "CIFAR10":
53 | dataset = CIFAR10(train=train, normalise=normalise, flatten=flatten)
54 | else:
55 | raise ValueError(
56 | "Invalid dataset ID. Options are `MNIST`, `Fashion-MNIST` and `CIFAR10`"
57 | )
58 | return dataset
59 |
60 |
61 | def get_imagenet_loaders(batch_size):
62 | train_data, val_data = ImageNet(split="train"), ImageNet(split="val")
63 | train_loader = DataLoader(
64 | dataset=train_data,
65 | batch_size=batch_size,
66 | shuffle=True,
67 | drop_last=True,
68 | num_workers=32,
69 | persistent_workers=True
70 | )
71 | val_loader = DataLoader(
72 | dataset=val_data,
73 | batch_size=batch_size,
74 | shuffle=True,
75 | drop_last=True,
76 | num_workers=32,
77 | persistent_workers=True
78 | )
79 | return train_loader, val_loader
80 |
81 |
82 | class MNIST(datasets.MNIST):
83 | def __init__(self, train, normalise=True, flatten=True, save_dir=DATA_DIR):
84 | self.flatten = flatten
85 | if normalise:
86 | transform = transforms.Compose(
87 | [
88 | transforms.ToTensor(),
89 | transforms.Normalize(
90 |                         mean=(0.1307,), std=(0.3081,)
91 | )
92 | ]
93 | )
94 | else:
95 | transform = transforms.Compose([transforms.ToTensor()])
96 | super().__init__(save_dir, download=True, train=train, transform=transform)
97 |
98 | def __getitem__(self, index):
99 | img, label = super().__getitem__(index)
100 | if self.flatten:
101 | img = torch.flatten(img)
102 | label = one_hot(label, n_classes=10)
103 | return img, label
104 |
105 |
106 | class FashionMNIST(datasets.FashionMNIST):
107 | def __init__(self, train, normalise=True, flatten=True, save_dir=DATA_DIR):
108 | self.flatten = flatten
109 | if normalise:
110 | transform = transforms.Compose(
111 | [
112 | transforms.ToTensor(),
113 | transforms.Normalize(
114 |                         mean=(0.5,), std=(0.5,)
115 | )
116 | ]
117 | )
118 | else:
119 | transform = transforms.Compose([transforms.ToTensor()])
120 | super().__init__(save_dir, download=True, train=train, transform=transform)
121 |
122 | def __getitem__(self, index):
123 | img, label = super().__getitem__(index)
124 | if self.flatten:
125 | img = torch.flatten(img)
126 | label = one_hot(label)
127 | return img, label
128 |
129 |
130 | class CIFAR10(datasets.CIFAR10):
131 | def __init__(self, train, normalise=True, flatten=True, save_dir=f"{DATA_DIR}/CIFAR10"):
132 | self.flatten = flatten
133 | if normalise:
134 | if train:
135 | transform = transforms.Compose(
136 | [
137 | transforms.Resize((32,32)),
138 | transforms.RandomCrop(32, padding=4),
139 | transforms.RandomHorizontalFlip(),
140 | transforms.RandomRotation(10),
141 | transforms.ToTensor(),
142 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
143 | ]
144 | )
145 | else:
146 | transform = transforms.Compose(
147 | [
148 | transforms.Resize((32,32)),
149 | transforms.ToTensor(),
150 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
151 | ]
152 | )
153 | else:
154 | transform = transforms.Compose([transforms.ToTensor()])
155 | super().__init__(save_dir, download=True, train=train, transform=transform)
156 |
157 | def __getitem__(self, index):
158 | img, label = super().__getitem__(index)
159 | if self.flatten:
160 | img = torch.flatten(img)
161 | label = one_hot(label)
162 | return img, label
163 |
164 |
165 | class ImageNet(datasets.ImageNet):
166 | def __init__(self, split):
167 | if split == "train":
168 | transform = transforms.Compose([
169 | transforms.RandomResizedCrop(224),
170 | transforms.RandomHorizontalFlip(),
171 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
172 | transforms.ToTensor(),
173 | transforms.Normalize(
174 | mean=[0.485, 0.456, 0.406],
175 | std=[0.229, 0.224, 0.225])
176 | ])
177 | elif split == "val":
178 | transform = transforms.Compose(
179 | [
180 | transforms.Resize(256),
181 | transforms.CenterCrop(224),
182 | transforms.ToTensor(),
183 | transforms.Normalize(
184 | mean=[0.485, 0.456, 0.406],
185 | std=[0.229, 0.224, 0.225]
186 | ),
187 | ]
188 | )
189 | super().__init__(root=IMAGENET_DIR, split=split, transform=transform)
190 |
191 | def __getitem__(self, index):
192 | img, label = super().__getitem__(index)
193 | label = one_hot(label, n_classes=1000)
194 | return img, label
195 |
196 |
197 | def one_hot(labels, n_classes=10):
198 | arr = torch.eye(n_classes)
199 | return arr[labels]
200 |
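201 | 
202 | if __name__ == "__main__":
203 |     # Smoke test (sketch): fetch one MNIST batch and check the expected shapes.
204 |     # Assumes MNIST can be downloaded to DATA_DIR on first use.
205 |     train_loader, _ = get_dataloaders("MNIST", batch_size=64)
206 |     imgs, labels = next(iter(train_loader))
207 |     assert imgs.shape == (64, 784)   # flattened 28x28 images
208 |     assert labels.shape == (64, 10)  # one-hot labels
209 |     print(f"imgs: {tuple(imgs.shape)}, labels: {tuple(labels.shape)}")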
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | """Tests for utility functions."""
2 |
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | from jpc import (
7 | make_mlp,
8 | make_skip_model,
9 | get_act_fn,
10 | mse_loss,
11 | cross_entropy_loss,
12 | compute_accuracy,
13 | get_t_max,
14 | compute_activity_norms,
15 | compute_param_norms,
16 | compute_infer_energies
17 | )
18 |
19 |
20 | def test_make_mlp(key, input_dim, hidden_dim, output_dim, depth):
21 | """Test MLP creation."""
22 | model = make_mlp(
23 | key=key,
24 | input_dim=input_dim,
25 | width=hidden_dim,
26 | depth=depth,
27 | output_dim=output_dim,
28 | act_fn="relu",
29 | use_bias=False,
30 | param_type="sp"
31 | )
32 |
33 | assert len(model) == depth
34 | assert model[0][1].weight.shape == (hidden_dim, input_dim)
35 | assert model[-1][1].weight.shape == (output_dim, hidden_dim)
36 |
37 |
38 | def test_make_mlp_different_activations(key, input_dim, hidden_dim, output_dim, depth):
39 | """Test MLP creation with different activation functions."""
40 | for act_fn in ["linear", "tanh", "relu", "gelu", "silu"]:
41 | model = make_mlp(
42 | key=key,
43 | input_dim=input_dim,
44 | width=hidden_dim,
45 | depth=depth,
46 | output_dim=output_dim,
47 | act_fn=act_fn,
48 | use_bias=False,
49 | param_type="sp"
50 | )
51 |
52 | assert len(model) == depth
53 |
54 |
55 | def test_make_mlp_with_bias(key, input_dim, hidden_dim, output_dim, depth):
56 | """Test MLP creation with bias."""
57 | model = make_mlp(
58 | key=key,
59 | input_dim=input_dim,
60 | width=hidden_dim,
61 | depth=depth,
62 | output_dim=output_dim,
63 | act_fn="relu",
64 | use_bias=True,
65 | param_type="sp"
66 | )
67 |
68 | assert len(model) == depth
69 |
70 |
71 | def test_make_mlp_different_param_types(key, input_dim, hidden_dim, output_dim, depth):
72 | """Test MLP creation with different parameter types."""
73 | for param_type in ["sp", "mupc", "ntp"]:
74 | model = make_mlp(
75 | key=key,
76 | input_dim=input_dim,
77 | width=hidden_dim,
78 | depth=depth,
79 | output_dim=output_dim,
80 | act_fn="relu",
81 | use_bias=False,
82 | param_type=param_type
83 | )
84 |
85 | assert len(model) == depth
86 |
87 |
88 | def test_make_skip_model(depth):
89 | """Test skip model creation."""
90 | skip_model = make_skip_model(depth)
91 |
92 | assert len(skip_model) == depth
93 | # First and last should be None
94 | assert skip_model[0] is None
95 | assert skip_model[-1] is None
96 |
97 |
98 | def test_get_act_fn():
99 | """Test activation function retrieval."""
100 | act_fns = ["linear", "tanh", "hard_tanh", "relu", "leaky_relu", "gelu", "selu", "silu"]
101 |
102 | for act_fn_name in act_fns:
103 | act_fn = get_act_fn(act_fn_name)
104 | assert callable(act_fn)
105 |
106 | # Test invalid activation function
107 | with pytest.raises(ValueError):
108 | get_act_fn("invalid")
109 |
110 |
111 | def test_mse_loss(key, batch_size, output_dim):
112 | """Test MSE loss computation."""
113 | preds = jax.random.normal(key, (batch_size, output_dim))
114 | labels = jax.random.normal(key, (batch_size, output_dim))
115 |
116 | loss = mse_loss(preds, labels)
117 |
118 | assert jnp.isfinite(loss)
119 | assert loss >= 0
120 |
121 |
122 | def test_cross_entropy_loss(key, batch_size, output_dim):
123 | """Test cross-entropy loss computation."""
124 | logits = jax.random.normal(key, (batch_size, output_dim))
125 | labels = jax.nn.one_hot(
126 | jax.random.randint(key, (batch_size,), 0, output_dim),
127 | output_dim
128 | )
129 |
130 | loss = cross_entropy_loss(logits, labels)
131 |
132 | assert jnp.isfinite(loss)
133 | assert loss >= 0
134 |
135 |
136 | def test_compute_accuracy(key, batch_size, output_dim):
137 | """Test accuracy computation."""
138 | truths = jax.nn.one_hot(
139 | jax.random.randint(key, (batch_size,), 0, output_dim),
140 | output_dim
141 | )
142 | preds = jax.nn.one_hot(
143 | jax.random.randint(key, (batch_size,), 0, output_dim),
144 | output_dim
145 | )
146 |
147 | acc = compute_accuracy(truths, preds)
148 |
149 | assert 0 <= acc <= 100
150 | assert jnp.isfinite(acc)
151 |
152 |
153 | def test_get_t_max(key, batch_size, hidden_dim):
154 | """Test t_max computation."""
155 | # Create fake activities_iters with time dimension
156 | # The function looks for argmax in activities_iters[0][:, 0, 0] then subtracts 1
157 | # We need to ensure there's a valid maximum
158 | activities_iters = [
159 | jnp.zeros((100, batch_size, hidden_dim))
160 | ]
161 | # Set a value at index 10 to be non-zero so argmax returns 10, then t_max = 10 - 1 = 9
162 | activities_iters[0] = activities_iters[0].at[10, 0, 0].set(1.0)
163 |
164 | t_max = get_t_max(activities_iters)
165 |
166 | assert jnp.isfinite(t_max)
167 | # t_max is argmax - 1, so with value at index 10, argmax is 10, so t_max is 9
168 | assert t_max >= 0
169 |
170 |
171 | def test_compute_activity_norms(simple_model, x):
172 | """Test activity norm computation."""
173 | from jpc import init_activities_with_ffwd
174 |
175 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
176 |
177 | norms = compute_activity_norms(activities)
178 |
179 | assert len(norms) == len(activities)
180 | assert all(jnp.isfinite(n) and n >= 0 for n in norms)
181 |
182 |
183 | def test_compute_param_norms(simple_model):
184 | """Test parameter norm computation."""
185 | # Skip model contains Lambda functions which don't have .weight attributes
186 | # and cause issues with compute_param_norms. Test without skip_model instead.
187 | model_norms, skip_norms = compute_param_norms((simple_model, None))
188 |
189 | assert len(model_norms) > 0
190 | assert skip_norms is None
191 | assert all(jnp.isfinite(n) and n >= 0 for n in model_norms)
192 |
193 |
194 | def test_compute_infer_energies(simple_model, x, y):
195 | """Test inference energy computation."""
196 | from jpc import init_activities_with_ffwd, solve_inference
197 | from diffrax import Heun, PIDController
198 |
199 | activities = init_activities_with_ffwd(simple_model, x, param_type="sp")
200 |
201 | activities_iters = solve_inference(
202 | params=(simple_model, None),
203 | activities=activities,
204 | output=y,
205 | input=x,
206 | loss_id="mse",
207 | param_type="sp",
208 | solver=Heun(),
209 | max_t1=10,
210 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
211 | record_iters=True
212 | )
213 |
214 | t_max = get_t_max(activities_iters)
215 |
216 | energies = compute_infer_energies(
217 | params=(simple_model, None),
218 | activities_iters=activities_iters,
219 | t_max=t_max,
220 | y=y,
221 | x=x,
222 | loss="mse",
223 | param_type="sp"
224 | )
225 |
226 | assert energies.shape[0] == len(simple_model)
227 | assert all(jnp.all(jnp.isfinite(e)) for e in energies)
228 |
--------------------------------------------------------------------------------
/experiments/library_paper/test_theory_energies.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import jax.random as jr
5 | import jax.numpy as jnp
6 | import equinox as eqx
7 | import optax
8 | import jpc
9 |
10 | from utils import set_seed
11 | from experiments.datasets import get_dataloaders
12 |
13 | import plotly.graph_objs as go
14 | import plotly.colors as pc
15 |
16 |
17 | def plot_accuracies(accuracies, save_path):
18 |     n_train_iters = len(next(iter(accuracies.values())))
19 | train_iters = [t+1 for t in range(n_train_iters)]
20 |
21 | colorscale = "Blues"
22 | colors = pc.sample_colorscale(colorscale, len(accuracies)+2)[2:][::-1]
23 | fig = go.Figure()
24 | for i, (max_t1, accuracy) in enumerate(accuracies.items()):
25 | fig.add_trace(
26 | go.Scatter(
27 | x=train_iters,
28 | y=accuracy,
29 | mode="lines",
30 | line=dict(width=2, color=colors[i]),
31 | name=f"$t = {max_t1}$"
32 | )
33 | )
34 | fig.update_layout(
35 | height=350,
36 | width=500,
37 | xaxis=dict(
38 | title="Training iteration",
39 | tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],
40 | ticktext=[1, int(train_iters[-1]/2)*10, train_iters[-1]*10]
41 | ),
42 | yaxis=dict(title="Test accuracy (%)"),
43 | font=dict(size=16),
44 | margin=dict(r=120)
45 | )
46 | fig.write_image(save_path)
47 |
48 |
49 | def plot_energies_across_ts(theory_energies, num_energies, save_path):
50 | n_train_iters = len(theory_energies)
51 | train_iters = [t+1 for t in range(n_train_iters)]
52 |
53 | colorscale = "Greens"
54 | colors = pc.sample_colorscale(colorscale, len(num_energies)+3)[2:][::-1]
55 | fig = go.Figure()
56 | fig.add_traces(
57 | go.Scatter(
58 | x=train_iters,
59 | y=theory_energies,
60 | name="theory",
61 | mode="lines",
62 | line=dict(
63 | width=3,
64 | dash="dash",
65 | color=colors[0]
66 | ),
67 | )
68 | )
69 | for i, (max_t1, num_energy) in enumerate(num_energies.items()):
70 | fig.add_trace(
71 | go.Scatter(
72 | x=train_iters,
73 | y=num_energy,
74 | mode="lines",
75 | line=dict(width=2, color=colors[i+1]),
76 | name=f"$t = {max_t1}$"
77 | )
78 | )
79 | fig.update_layout(
80 | height=350,
81 | width=500,
82 | xaxis=dict(
83 | title="Training iteration",
84 | tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],
85 | ticktext=[1, int(train_iters[-1]/2), train_iters[-1]]
86 | ),
87 | yaxis=dict(title="Energy"),
88 | font=dict(size=16),
89 | margin=dict(r=120)
90 | )
91 | fig.write_image(save_path)
92 |
93 |
94 | def evaluate(model, test_loader):
95 | avg_test_loss, avg_test_acc = 0, 0
96 | for _, (img_batch, label_batch) in enumerate(test_loader):
97 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
98 |
99 | test_loss, test_acc = jpc.test_discriminative_pc(
100 | model=model,
101 | output=label_batch,
102 | input=img_batch
103 | )
104 | avg_test_loss += test_loss
105 | avg_test_acc += test_acc
106 |
107 | return avg_test_loss / len(test_loader), avg_test_acc / len(test_loader)
108 |
109 |
110 | def train(
111 | dataset,
112 | width,
113 | n_hidden,
114 | lr,
115 | batch_size,
116 | max_t1,
117 | test_every,
118 | n_train_iters,
119 | save_dir
120 | ):
121 | key = jr.PRNGKey(0)
122 | input_dim = 3072 if dataset == "CIFAR10" else 784
123 | model = jpc.make_mlp(
124 | key,
125 | input_dim=input_dim,
126 | width=width,
127 | depth=n_hidden+1,
128 | output_dim=10,
129 | act_fn="linear",
130 | use_bias=False
131 | )
132 | optim = optax.adam(lr)
133 | opt_state = optim.init(
134 | (eqx.filter(model, eqx.is_array), None)
135 | )
136 | train_loader, test_loader = get_dataloaders(dataset, batch_size)
137 |
138 | test_accs = []
139 | theory_energies, num_energies = [], []
140 | for batch_id, (img_batch, label_batch) in enumerate(train_loader):
141 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
142 |
143 | theory_energies.append(
144 | jpc.compute_linear_equilib_energy(
145 | network=model,
146 | x=img_batch,
147 | y=label_batch
148 | )
149 | )
150 | result = jpc.make_pc_step(
151 | model,
152 | optim,
153 | opt_state,
154 | output=label_batch,
155 | input=img_batch,
156 | max_t1=max_t1,
157 | record_energies=True
158 | )
159 | model, opt_state = result["model"], result["opt_state"]
160 | train_loss, t_max = result["loss"], result["t_max"]
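161 |         # sum the layer energies at the last recorded inference step, for
162 |         # comparison with the theoretical equilibrium energy computed above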
161 | num_energies.append(result["energies"][:, t_max-1].sum())
162 |
163 | if ((batch_id+1) % test_every) == 0:
164 | _, avg_test_acc = evaluate(model, test_loader)
165 | test_accs.append(avg_test_acc)
166 | print(
167 |                 f"Train iter {batch_id+1}, train loss={train_loss:.4f}, "
168 |                 f"avg test accuracy={avg_test_acc:.4f}"
169 | )
170 | if (batch_id+1) >= n_train_iters:
171 | break
172 |
173 | np.save(f"{save_dir}/test_accs.npy", test_accs)
174 | np.save(f"{save_dir}/theory_energies.npy", theory_energies)
175 | np.save(f"{save_dir}/num_energies.npy", num_energies)
176 |
177 | return test_accs, jnp.array(theory_energies), jnp.array(num_energies)
178 |
179 |
180 | if __name__ == "__main__":
181 | RESULTS_DIR = "theory_energies_results"
182 | DATASETS = ["MNIST", "Fashion-MNIST"]
183 | SEED = 916
184 | WIDTH = 300
185 | N_HIDDEN = 10
186 | LR = 1e-3
187 | BATCH_SIZE = 64
188 | MAX_T1S = [200, 100, 50, 20, 10]
189 | TEST_EVERY = 10
190 | N_TRAIN_ITERS = 100
191 |
192 | for dataset in DATASETS:
193 | set_seed(SEED)
194 | all_test_accs, all_theory_energies, all_num_energies = {}, {}, {}
195 | for max_t1 in MAX_T1S:
196 | print(f"\nmax_t1: {max_t1}")
197 | save_dir = os.path.join(RESULTS_DIR, dataset, f"max_t1_{max_t1}")
198 | os.makedirs(save_dir, exist_ok=True)
199 | test_accs, theory_energies, num_energies = train(
200 | dataset=dataset,
201 | width=WIDTH,
202 | n_hidden=N_HIDDEN,
203 | lr=LR,
204 | batch_size=BATCH_SIZE,
205 | max_t1=max_t1,
206 | test_every=TEST_EVERY,
207 | n_train_iters=N_TRAIN_ITERS,
208 | save_dir=save_dir
209 | )
210 | all_test_accs[max_t1] = test_accs
211 | all_theory_energies[max_t1] = theory_energies
212 | all_num_energies[max_t1] = num_energies
213 |
214 | plot_accuracies(
215 | all_test_accs,
216 | f"{RESULTS_DIR}/{dataset}/test_accs.pdf"
217 | )
218 | plot_energies_across_ts(
219 | all_theory_energies[MAX_T1S[0]],
220 | all_num_energies,
221 | f"{RESULTS_DIR}/{dataset}/energies.pdf"
222 | )
223 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
7 | 🧠 Flexible Inference for Predictive Coding Networks in JAX ⚡️
8 |
9 | [Paper](https://arxiv.org/abs/2412.03676)
10 | 
11 |
12 | ## 📢 Updates
13 | * [Open in Colab](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/epc.ipynb) Nov 2025: Added "ePC" ([Goemaere et al., 2025](https://arxiv.org/abs/2505.20137))
14 | * [Open in Colab](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/bidirectional_pc.ipynb) Oct 2025: Added bidirectional PC (bPC; [Oliviers et al., 2025](https://arxiv.org/abs/2505.23415))
15 | * [Open in Colab](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/mupc.ipynb) May 2025: Added "μPC" ([Innocenti et al., 2025](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1)))
16 |
17 | ---
18 |
19 | JPC is a [**J**AX](https://github.com/google/jax) library for training neural
20 | networks with **P**redictive **C**oding (PC).
21 |
22 | JPC provides a **simple**, **fast** and **flexible** API for training a
23 | variety of PCNs, including discriminative, generative and hybrid models.
25 | * Like JAX, JPC is completely functional in design, and the core library is
26 |   under 1,000 lines of code.
27 | * Unlike existing implementations, JPC provides a wide range of optimisers, both
28 | discrete and continuous, to solve the inference dynamics of PC, including
29 | ordinary differential equation (ODE) solvers.
30 | * JPC also provides some analytical tools that can be used to study and
31 | potentially diagnose issues with PCNs.
32 |
33 | If you're new to JPC, we recommend starting from the
34 | [example notebooks](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/)
35 | and checking the [documentation](https://thebuckleylab.github.io/jpc/).
36 |
37 | ## Overview
38 | * [Installation](#-installation)
39 | * [Documentation](#-documentation)
40 | * [Quick example](#-quick-example)
41 | * [Advanced usage](#-advanced-usage)
42 | * [Contributing](#-contributing)
43 | * [Citation](#-citation)
44 |
45 | ## ️💻 Installation
46 | Clone the repo and in the project's directory run
47 | ```
48 | pip install .
49 | ```
50 |
51 | Requires Python 3.10+ and JAX 0.4.38–0.5.2 (inclusive). For GPU usage,
52 | upgrade JAX to the appropriate CUDA version (CUDA 12 in the example below):
53 |
54 | ```
55 | pip install --upgrade "jax[cuda12]"
56 | ```
57 |
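58 | To check that JAX can see the GPU (a generic JAX sanity check, not specific
59 | to JPC):
60 | ```py
61 | import jax
62 | print(jax.devices())  # should list a CUDA/GPU device rather than only CPU
63 | ```
64 | 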
58 | ## 📖 [Documentation](https://thebuckleylab.github.io/jpc/)
59 | Available at https://thebuckleylab.github.io/jpc/.
60 |
61 | ## ⚡️ Quick example
62 | Use `jpc.make_pc_step()` to update the parameters of any neural network
63 | compatible with PC updates (see the
64 | [notebook examples](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/))
65 | ```py
66 | import jax.random as jr
67 | import jax.numpy as jnp
68 | import equinox as eqx
69 | import optax
70 | import jpc
71 |
72 | # toy data
73 | x = jnp.array([1., 1., 1.])
74 | y = -x
75 |
76 | # define model and optimiser
77 | key = jr.PRNGKey(0)
78 | model = jpc.make_mlp(
79 | key,
80 | input_dim=3,
81 | width=50,
82 | depth=5,
83 |     output_dim=3,
84 |     act_fn="relu"
85 | )
86 | optim = optax.adam(1e-3)
87 | opt_state = optim.init(
88 | (eqx.filter(model, eqx.is_array), None)
89 | )
90 |
91 | # perform one training step with PC
92 | result = jpc.make_pc_step(
93 | model=model,
94 | optim=optim,
95 | opt_state=opt_state,
96 | output=y,
97 | input=x
98 | )
99 |
100 | # updated model and optimiser
101 | model, opt_state = result["model"], result["opt_state"]
102 | ```
103 | Under the hood, `jpc.make_pc_step()`
104 | 1. integrates the inference (activity) dynamics using a [diffrax](https://github.com/patrick-kidger/diffrax) ODE solver, and
105 | 2. updates model parameters at the numerical solution of the activities with a given [optax](https://github.com/google-deepmind/optax) optimiser.
106 |
107 | See the [documentation](https://thebuckleylab.github.io/jpc/) for more details.
108 |
109 | > **NOTE**: All convenience training and test functions such as `make_pc_step()`
110 | > are already jitted for optimised performance.
111 |
112 | ## 🚀 Advanced usage
113 | Advanced users can access all the underlying functions of `jpc.make_pc_step()`
114 | as well as additional features. A custom PC training step looks like the
115 | following:
116 | ```py
117 | import optax
118 | import jpc
119 | 
120 | # optimiser for the inference (activity) dynamics
121 | activity_optim = optax.sgd(1e-1)
122 | 
123 | # 1. initialise activities with a feedforward pass
124 | activities = jpc.init_activities_with_ffwd(model=model, input=x)
125 | 
126 | # 2. perform inference (state optimisation)
127 | activity_opt_state = activity_optim.init(activities)
128 | for _ in range(len(model)):  # inference iterations
129 |     activity_update_result = jpc.update_activities(
130 |         params=(model, None),
131 |         activities=activities,
132 |         optim=activity_optim,
133 |         opt_state=activity_opt_state,
134 |         output=y,
135 |         input=x
136 |     )
137 |     activities = activity_update_result["activities"]
138 |     activity_opt_state = activity_update_result["opt_state"]
139 | 
140 | # 3. update parameters at the activities' solution with PC
141 | result = jpc.update_params(
142 |     params=(model, None),
143 |     activities=activities,
144 |     optim=optim,
145 |     opt_state=opt_state,
146 |     output=y,
147 |     input=x
148 | )
149 | ```
150 | which can be embedded in a jitted function with any other additional
151 | computations, as sketched below. Again, see the [docs](https://thebuckleylab.github.io/jpc/)
152 | for details.
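153 | 
154 | For instance, a minimal sketch of one jitted PC step (assuming the `model`,
155 | `optim`, `opt_state` and toy data from the quick example, and the
156 | `activity_optim` defined above):
157 | ```py
158 | import equinox as eqx
159 | 
160 | @eqx.filter_jit
161 | def custom_pc_step(model, opt_state, x, y):
162 |     # inference (state optimisation) from a feedforward initialisation
163 |     activities = jpc.init_activities_with_ffwd(model=model, input=x)
164 |     activity_opt_state = activity_optim.init(activities)
165 |     for _ in range(len(model)):
166 |         infer_result = jpc.update_activities(
167 |             params=(model, None),
168 |             activities=activities,
169 |             optim=activity_optim,
170 |             opt_state=activity_opt_state,
171 |             output=y,
172 |             input=x
173 |         )
174 |         activities = infer_result["activities"]
175 |         activity_opt_state = infer_result["opt_state"]
176 |     # parameter update at the inferred activities
177 |     return jpc.update_params(
178 |         params=(model, None),
179 |         activities=activities,
180 |         optim=optim,
181 |         opt_state=opt_state,
182 |         output=y,
183 |         input=x
184 |     )
185 | ```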
149 |
150 | ## 🤝 Contributing
151 | Contributions are welcome! Fork the repo, install in editable mode (`pip install -e .`), then:
152 | * Run `ruff check .` before committing (auto-fix with `ruff check --fix .`)
153 | * Ensure all tests pass: `pytest tests/`
154 | * Add docstrings to public functions and update `docs/` for user-facing changes
155 | * Open a PR with a clear description
156 |
157 | For major features, open an issue first to discuss.
158 |
159 | ## 📄 Citation
160 | If you found this library useful in your work, please cite ([paper link](https://arxiv.org/abs/2412.03676)):
161 |
162 | ```bibtex
163 | @article{innocenti2024jpc,
164 | title={JPC: Flexible Inference for Predictive Coding Networks in JAX},
165 | author={Innocenti, Francesco and Kinghorn, Paul and Yun-Farmbrough, Will and Varona, Miguel De Llanza and Singh, Ryan and Buckley, Christopher L},
166 | journal={arXiv preprint arXiv:2412.03676},
167 | year={2024}
168 | }
169 | ```
170 | Also consider starring the repo! ⭐️
171 |
172 | ## 🙏 Acknowledgements
173 | We are grateful to Patrick Kidger for early advice on how to use Diffrax.
174 |
175 | ## See also: other PC libraries
176 | * [ngc-learn](https://github.com/NACLab/ngc-learn) (jax & pytorch)
177 | * [pcx](https://github.com/liukidar/pcx) (jax)
178 | * [pyhgf](https://github.com/ComputationalPsychiatry/pyhgf) (jax)
179 | * [Torch2PC](https://github.com/RobertRosenbaum/Torch2PC) (pytorch)
180 | * [pypc](https://github.com/infer-actively/pypc) (pytorch)
181 | * [pybrid](https://github.com/alec-tschantz/pybrid) (pytorch)
182 |
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
1 | """Tests for training functions."""
2 |
3 | import pytest
4 | import jax
5 | import jax.numpy as jnp
6 | import optax
7 | from diffrax import Heun, PIDController
8 | from jpc import make_pc_step, make_hpc_step
9 |
10 |
11 | def test_make_pc_step_supervised(simple_model, x, y):
12 | """Test PC training step in supervised mode."""
13 | optim = optax.sgd(learning_rate=0.01)
14 | opt_state = optim.init((simple_model, None))
15 |
16 | result = make_pc_step(
17 | model=simple_model,
18 | optim=optim,
19 | opt_state=opt_state,
20 | output=y,
21 | input=x,
22 | loss_id="mse",
23 | param_type="sp",
24 | ode_solver=Heun(),
25 | max_t1=5,
26 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
27 | )
28 |
29 | assert "model" in result
30 | assert "skip_model" in result
31 | assert "opt_state" in result
32 | assert "loss" in result
33 | assert len(result["model"]) == len(simple_model)
34 |
35 |
36 | def test_make_pc_step_unsupervised(simple_model, y, key, layer_sizes, batch_size):
37 | """Test PC training step in unsupervised mode."""
38 | optim = optax.sgd(learning_rate=0.01)
39 | opt_state = optim.init((simple_model, None))
40 |
41 | result = make_pc_step(
42 | model=simple_model,
43 | optim=optim,
44 | opt_state=opt_state,
45 | output=y,
46 | input=None,
47 | loss_id="mse",
48 | param_type="sp",
49 | ode_solver=Heun(),
50 | max_t1=5,
51 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
52 | key=key,
53 | layer_sizes=layer_sizes,
54 | batch_size=batch_size,
55 | sigma=0.05
56 | )
57 |
58 | assert "model" in result
59 | assert "opt_state" in result
60 |
61 |
62 | def test_make_pc_step_cross_entropy(simple_model, x, y_onehot):
63 | """Test PC training step with cross-entropy loss."""
64 | optim = optax.sgd(learning_rate=0.01)
65 | opt_state = optim.init((simple_model, None))
66 |
67 | result = make_pc_step(
68 | model=simple_model,
69 | optim=optim,
70 | opt_state=opt_state,
71 | output=y_onehot,
72 | input=x,
73 | loss_id="ce",
74 | param_type="sp",
75 | ode_solver=Heun(),
76 | max_t1=5,
77 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
78 | )
79 |
80 | assert "model" in result
81 | assert "loss" in result
82 |
83 |
84 | def test_make_pc_step_with_regularization(simple_model, x, y):
85 | """Test PC training step with regularization."""
86 | optim = optax.sgd(learning_rate=0.01)
87 | opt_state = optim.init((simple_model, None))
88 |
89 | result = make_pc_step(
90 | model=simple_model,
91 | optim=optim,
92 | opt_state=opt_state,
93 | output=y,
94 | input=x,
95 | loss_id="mse",
96 | param_type="sp",
97 | ode_solver=Heun(),
98 | max_t1=5,
99 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
100 | weight_decay=0.01,
101 | spectral_penalty=0.01,
102 | activity_decay=0.01
103 | )
104 |
105 | assert "model" in result
106 |
107 |
108 | def test_make_pc_step_with_metrics(simple_model, x, y):
109 | """Test PC training step with metrics recording."""
110 | optim = optax.sgd(learning_rate=0.01)
111 | opt_state = optim.init((simple_model, None))
112 |
113 | result = make_pc_step(
114 | model=simple_model,
115 | optim=optim,
116 | opt_state=opt_state,
117 | output=y,
118 | input=x,
119 | loss_id="mse",
120 | param_type="sp",
121 | ode_solver=Heun(),
122 | max_t1=5,
123 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
124 | record_activities=True,
125 | record_energies=True,
126 | activity_norms=True,
127 | param_norms=True,
128 | grad_norms=True,
129 | calculate_accuracy=True
130 | )
131 |
132 | assert "model" in result
133 | assert "activities" in result
134 | assert "energies" in result
135 | assert "activity_norms" in result
136 | assert "model_param_norms" in result
137 | assert "acc" in result
138 |
139 |
140 | def test_make_pc_step_invalid_input():
141 | """Test PC training step with invalid input (missing required args for unsupervised)."""
142 |     from jpc import make_mlp
146 |
147 | key = jax.random.PRNGKey(42)
148 | model = make_mlp(key, 10, 20, 5, 3, "relu", False, "sp")
149 | optim = optax.sgd(learning_rate=0.01)
150 | opt_state = optim.init((model, None))
151 | y = jax.random.normal(key, (4, 5))
152 |
153 | with pytest.raises(ValueError):
154 | make_pc_step(
155 | model=model,
156 | optim=optim,
157 | opt_state=opt_state,
158 | output=y,
159 | input=None,
160 | loss_id="mse",
161 | param_type="sp",
162 | ode_solver=Heun(),
163 | max_t1=5,
164 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
165 | )
166 |
167 |
168 | def test_make_hpc_step(key, simple_model, x, y, output_dim, hidden_dim, input_dim, depth):
169 | """Test HPC training step."""
170 | from jpc import make_mlp
171 |
172 | generator = simple_model
173 | amortiser = make_mlp(
174 | key=key,
175 | input_dim=output_dim,
176 | width=hidden_dim,
177 | depth=depth,
178 | output_dim=input_dim,
179 | act_fn="relu",
180 | use_bias=False,
181 | param_type="sp"
182 | )
183 |
184 | gen_optim = optax.sgd(learning_rate=0.01)
185 | amort_optim = optax.sgd(learning_rate=0.01)
186 | gen_opt_state = gen_optim.init((generator, None))
187 | amort_opt_state = amort_optim.init(amortiser)
188 |
189 | result = make_hpc_step(
190 | generator=generator,
191 | amortiser=amortiser,
192 | optims=(gen_optim, amort_optim),
193 | opt_states=(gen_opt_state, amort_opt_state),
194 | output=y,
195 | input=x,
196 | ode_solver=Heun(),
197 | max_t1=5,
198 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
199 | )
200 |
201 | assert "generator" in result
202 | assert "amortiser" in result
203 | assert "opt_states" in result
204 | assert "losses" in result
205 | assert len(result["opt_states"]) == 2
206 |
207 |
208 | def test_make_hpc_step_unsupervised(key, simple_model, y, output_dim, hidden_dim, input_dim, depth):
209 | """Test HPC training step in unsupervised mode."""
210 | from jpc import make_mlp
211 |
212 | generator = simple_model
213 | amortiser = make_mlp(
214 | key=key,
215 | input_dim=output_dim,
216 | width=hidden_dim,
217 | depth=depth,
218 | output_dim=input_dim,
219 | act_fn="relu",
220 | use_bias=False,
221 | param_type="sp"
222 | )
223 |
224 | gen_optim = optax.sgd(learning_rate=0.01)
225 | amort_optim = optax.sgd(learning_rate=0.01)
226 | gen_opt_state = gen_optim.init((generator, None))
227 | amort_opt_state = amort_optim.init(amortiser)
228 |
229 | result = make_hpc_step(
230 | generator=generator,
231 | amortiser=amortiser,
232 | optims=(gen_optim, amort_optim),
233 | opt_states=(gen_opt_state, amort_opt_state),
234 | output=y,
235 | input=None,
236 | ode_solver=Heun(),
237 | max_t1=5,
238 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3)
239 | )
240 |
241 | assert "generator" in result
242 | assert "amortiser" in result
243 |
244 |
245 | def test_make_hpc_step_with_recording(key, simple_model, x, y, output_dim, hidden_dim, input_dim, depth):
246 | """Test HPC training step with activity/energy recording."""
247 | from jpc import make_mlp
248 |
249 | generator = simple_model
250 | amortiser = make_mlp(
251 | key=key,
252 | input_dim=output_dim,
253 | width=hidden_dim,
254 | depth=depth,
255 | output_dim=input_dim,
256 | act_fn="relu",
257 | use_bias=False,
258 | param_type="sp"
259 | )
260 |
261 | gen_optim = optax.sgd(learning_rate=0.01)
262 | amort_optim = optax.sgd(learning_rate=0.01)
263 | gen_opt_state = gen_optim.init((generator, None))
264 | amort_opt_state = amort_optim.init(amortiser)
265 |
266 | result = make_hpc_step(
267 | generator=generator,
268 | amortiser=amortiser,
269 | optims=(gen_optim, amort_optim),
270 | opt_states=(gen_opt_state, amort_opt_state),
271 | output=y,
272 | input=x,
273 | ode_solver=Heun(),
274 | max_t1=5,
275 | stepsize_controller=PIDController(rtol=1e-3, atol=1e-3),
276 | record_activities=True,
277 | record_energies=True
278 | )
279 |
280 | assert "activities" in result
281 | assert "energies" in result
282 |
283 |
--------------------------------------------------------------------------------
/experiments/mupc_paper/analyse_activity_hessian.py:
--------------------------------------------------------------------------------
1 | import os
2 | import jax
3 | import jax.random as jr
4 | import jax.numpy as jnp
5 | import numpy as np
6 |
7 | from experiments.datasets import make_gaussian_dataset, get_dataloaders
8 | from experiments.mupc_paper.utils import (
9 | setup_hessian_analysis,
10 | set_seed,
11 | init_weights,
12 | get_network_weights,
13 | unwrap_hessian_pytree
14 | )
15 | import jpc
16 | from experiments.mupc_paper.plotting import plot_activity_hessian
17 |
18 |
19 | def compute_hessian_metrics(
20 | network,
21 | act_fn,
22 | skip_model,
23 | y,
24 | x,
25 | use_skips,
26 | param_type,
27 | activity_decay,
28 | mode,
29 | layer_sizes,
30 | key
31 | ):
32 | if act_fn == "linear":
33 | # theoretical activity Hessian
34 | weights = get_network_weights(network)
35 | theory_H = jpc.compute_linear_activity_hessian(
36 | weights,
37 | param_type=param_type,
38 | use_skips=use_skips,
39 | activity_decay=activity_decay
40 | )
41 | D = jpc.compute_linear_activity_hessian(
42 | weights,
43 | param_type=param_type,
44 | off_diag=False,
45 | use_skips=use_skips,
46 | activity_decay=activity_decay
47 | )
48 | O = jpc.compute_linear_activity_hessian(
49 | weights,
50 | param_type=param_type,
51 | diag=False,
52 | use_skips=use_skips,
53 | activity_decay=activity_decay
54 | )
55 |
56 |     # numerical activity Hessian
57 |     if mode == "supervised":
58 |         activities = jpc.init_activities_with_ffwd(
59 |             network,
60 |             x,
61 |             skip_model=skip_model,
62 |             param_type=param_type
63 |         )
64 |     elif mode == "unsupervised":
65 |         activities = jpc.init_activities_from_normal(
66 |             key=key,
67 |             layer_sizes=layer_sizes,
68 |             mode=mode,
69 |             batch_size=1,
70 |             sigma=1
71 |         )
72 | 
73 |     hessian_pytree = jax.hessian(jpc.pc_energy_fn, argnums=1)(
74 |         (network, skip_model),
75 |         activities,
76 |         y,
77 |         x=x,
78 |         param_type=param_type,
79 |         activity_decay=activity_decay
80 |     )
81 |     num_H = unwrap_hessian_pytree(
82 |         hessian_pytree,
83 |         activities,
84 |     )
85 |
86 | # compute eigenthings
87 | num_H_eigenvals, _ = jnp.linalg.eigh(num_H)
88 | cond_num = jnp.linalg.cond(num_H)
89 | if act_fn == "linear":
90 | theory_H_eigenvals, _ = jnp.linalg.eigh(theory_H)
91 | D_eigenvals, _ = jnp.linalg.eigh(D)
92 | O_eigenvals, _ = jnp.linalg.eigh(O)
93 |
94 | return {
95 | "hessian": num_H,
96 | "num": num_H_eigenvals,
97 | "cond_num": cond_num,
98 | "theory": theory_H_eigenvals if act_fn == "linear" else None,
99 | "D": D_eigenvals if act_fn == "linear" else None,
100 | "O": O_eigenvals if act_fn == "linear" else None
101 | }
102 |
103 |
104 | def run_analysis(
105 | seed,
106 | in_out_dims,
107 | act_fn,
108 | use_biases,
109 | mode,
110 | use_skips,
111 | weight_init,
112 | param_type,
113 | activity_decay,
114 | width,
115 | n_hidden,
116 | save_dir
117 | ):
118 | set_seed(seed)
119 | key = jr.PRNGKey(seed)
120 | keys = jr.split(key, 4)
121 |
122 | d_in = width if in_out_dims == "width" else in_out_dims[0]
123 | d_out = width if in_out_dims == "width" else in_out_dims[1]
124 |
125 | # create and initialise model
126 | L = n_hidden+1
127 | network = jpc.make_mlp(
128 | key=keys[0],
129 | input_dim=d_in,
130 | width=width,
131 | depth=L,
132 | output_dim=d_out,
133 | act_fn=act_fn,
134 | use_bias=use_biases,
135 | param_type=param_type
136 | )
137 | if weight_init != "standard":
138 | network = init_weights(
139 | key=keys[1],
140 | model=network,
141 | init_fn_id=weight_init
142 | )
143 | skip_model = jpc.make_skip_model(L) if use_skips else None
144 |
145 | # data
146 | if in_out_dims != "width":
147 | train_loader, _ = get_dataloaders("MNIST", batch_size=1)
148 | img_batch, label_batch = next(iter(train_loader))
149 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
150 | x, y = (img_batch, label_batch) if d_in == 784 else (label_batch, img_batch)
151 | else:
152 | x, y = make_gaussian_dataset(keys[2], 1, 0.1, (1, width))
153 | if mode == "unsupervised":
154 | x = None
155 |
156 | layer_sizes = [d_in] + [width]*n_hidden + [d_out]
157 | metrics = compute_hessian_metrics(
158 | network=network,
159 | act_fn=act_fn,
160 | skip_model=skip_model,
161 | y=y,
162 | x=x,
163 | use_skips=use_skips,
164 | param_type=param_type,
165 | activity_decay=activity_decay,
166 | mode=mode,
167 | layer_sizes=layer_sizes,
168 | key=keys[3]
169 | )
170 | plot_activity_hessian(
171 | metrics["hessian"],
172 | f"{save_dir}/hessian_matrix.pdf"
173 | )
174 | np.save(
175 | f"{save_dir}/num_hessian_eigenvals",
176 | metrics["num"]
177 | )
178 | np.save(
179 | f"{save_dir}/cond_num",
180 | metrics["cond_num"]
181 | )
182 | if act_fn == "linear":
183 | np.save(
184 | f"{save_dir}/theory_hessian_eigenvals",
185 | metrics["theory"]
186 | )
187 | np.save(
188 | f"{save_dir}/theory_D_eigenvals",
189 | metrics["D"]
190 | )
191 | np.save(
192 | f"{save_dir}/theory_O_eigenvals",
193 | metrics["O"]
194 | )
195 |
196 |
197 | if __name__ == "__main__":
198 | RESULTS_DIR = "activity_hessian_results"
199 | IN_OUT_DIMS = [[784, 10]] #, [784, 10], [10, 784]]
200 | ACT_FNS = ["linear"]#, "tanh", "relu"]
201 | USE_BIASES = [False]
202 | MODES = ["supervised"] #,"unsupervised"]
203 | USE_SKIPS = [False, True]
204 | WEIGHT_INITS = ["standard"]#["one_over_N", "standard", "orthogonal"]
205 | PARAM_TYPES = ["sp"]#, "mupc", "ntp"]
206 | ACTIVITY_DECAY = [False]#, True]
207 | WIDTHS = [2 ** i for i in range(11)]
208 | N_HIDDENS = [2 ** i for i in range(4)]
209 | N_SEEDS = 3
210 |
211 | for in_out_dims in IN_OUT_DIMS:
212 | for act_fn in ACT_FNS:
213 | for use_biases in USE_BIASES:
214 | for mode in MODES:
215 | for use_skips in USE_SKIPS:
216 | for weight_init in WEIGHT_INITS:
217 | for param_type in PARAM_TYPES:
218 | for activity_decay in ACTIVITY_DECAY:
219 | for width in WIDTHS:
220 | for n_hidden in N_HIDDENS:
221 | for seed in range(N_SEEDS):
222 | save_dir = setup_hessian_analysis(
223 | results_dir=RESULTS_DIR,
224 | in_out_dims=in_out_dims,
225 | act_fn=act_fn,
226 | use_biases=use_biases,
227 | mode=mode,
228 | use_skips=use_skips,
229 | weight_init=weight_init,
230 | param_type=param_type,
231 | activity_decay=activity_decay,
232 | width=width,
233 | n_hidden=n_hidden,
234 | seed=seed
235 | )
236 | os.makedirs(save_dir, exist_ok=True)
237 | run_analysis(
238 | seed=seed,
239 | in_out_dims=in_out_dims,
240 | act_fn=act_fn,
241 | use_biases=use_biases,
242 | mode=mode,
243 | use_skips=use_skips,
244 | weight_init=weight_init,
245 | param_type=param_type,
246 | activity_decay=activity_decay,
247 | width=width,
248 | n_hidden=n_hidden,
249 | save_dir=save_dir
250 | )
251 |
--------------------------------------------------------------------------------
/experiments/mupc_paper/test_mlp_fwd_pass.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | import jax.random as jr
6 | import jax.numpy as jnp
7 | from jax import vmap
8 |
9 | import equinox as eqx
10 | import equinox.nn as nn
11 | import optax
12 | import jpc
13 |
14 | from experiments.datasets import get_dataloaders
15 | from experiments.mupc_paper.utils import (
16 | set_seed,
17 | init_weights,
18 | compute_param_l2_norms,
19 | compute_param_spectral_norms
20 | )
21 |
22 |
23 | class MLP(eqx.Module):
24 | D: int
25 | N: int
26 | L: int
27 | param_type: str
28 | use_skips: bool
29 | layers: list
30 |
31 | def __init__(
32 | self,
33 | key,
34 | d_in,
35 | N,
36 | L,
37 | d_out,
38 | act_fn,
39 | param_type,
40 | use_bias=False,
41 | use_skips=False
42 | ):
43 | self.D = d_in
44 | self.N = N
45 | self.L = L
46 | self.param_type = param_type
47 | self.use_skips = use_skips
48 |
49 | keys = jr.split(key, L)
50 | self.layers = []
51 | for i in range(L):
52 | act_fn_l = nn.Identity() if i == 0 else jpc.get_act_fn(act_fn)
53 | _in = d_in if i == 0 else N
54 | _out = d_out if (i + 1) == L else N
55 | layer = nn.Sequential(
56 | [
57 | nn.Lambda(act_fn_l),
58 | nn.Linear(
59 | _in,
60 | _out,
61 | use_bias=use_bias,
62 | key=keys[i]
63 | )
64 | ]
65 | )
66 | self.layers.append(layer)
67 |
68 | def __call__(self, x):
69 | pre_activs = []
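70 |         # depth-μP forward rescalings (implemented below): first layer scaled
71 |         # by 1/sqrt(D), hidden layers by 1/sqrt(N) (1/sqrt(N*L) with skips),
72 |         # and the readout by 1/N; any other param_type gives a plain
73 |         # (optionally residual) forward pass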
70 |
71 | if self.param_type == "depth_mup":
72 | for i, f in enumerate(self.layers):
73 | if (i + 1) == 1:
74 | x = f(x) / jnp.sqrt(self.D)
75 | elif 1 < (i + 1) < self.L:
76 | residual = x if self.use_skips else 0
77 | rescaling = jnp.sqrt(
78 | self.N * self.L
79 | ) if self.use_skips else jnp.sqrt(self.N)
80 | x = (f(x) / rescaling) + residual
81 | elif (i + 1) == self.L:
82 | x = f(x) / self.N
83 |
84 | pre_activs.append(x)
85 |
86 | else:
87 | for i, f in enumerate(self.layers):
88 | residual = x if self.use_skips and (1 < (i + 1) < self.L) else 0
89 |
90 | x = f(x) + residual
91 |
92 | pre_activs.append(x)
93 |
94 | return pre_activs
95 |
96 |
97 | def mse_loss(model, x, y):
98 | y_pred = vmap(model)(x)[-1]
99 | return jnp.mean((y - y_pred) ** 2)
100 |
101 |
102 | @eqx.filter_jit
103 | def make_step(model, optim, opt_state, x, y):
104 | loss, grads = eqx.filter_value_and_grad(mse_loss)(model, x, y)
105 | updates, opt_state = optim.update(
106 | updates=grads,
107 | state=opt_state,
108 | params=eqx.filter(model, eqx.is_array)
109 | )
110 | model = eqx.apply_updates(model, updates)
111 | return model, opt_state, loss
112 |
113 |
114 | def test_fwd_pass(
115 | seed,
116 | dataset,
117 | width,
118 | depth,
119 | act_fn,
120 | optim_id,
121 | param_type,
122 | use_skips,
123 | lr,
124 | batch_size,
125 | n_checks
126 | ):
127 | set_seed(seed)
128 |
129 | key = jr.PRNGKey(seed)
130 | keys = jr.split(key, 2)
131 | model = MLP(
132 | key=keys[0],
133 | d_in=784,
134 | N=width,
135 | L=depth,
136 | d_out=10,
137 | act_fn=act_fn,
138 | param_type=param_type,
139 | use_bias=False,
140 | use_skips=use_skips
141 | )
142 |
143 | if param_type == "depth_mup":
144 | model = init_weights(
145 | model=model,
146 | init_fn_id="standard_gauss",
147 | key=keys[1]
148 | )
149 | elif param_type == "orthogonal":
150 | model = init_weights(
151 | model=model,
152 | init_fn_id="orthogonal",
153 | key=keys[1],
154 | gain=1.05 if act_fn == "tanh" else 1
155 | )
156 |
157 | optim = optax.sgd(lr) if optim_id == "sgd" else optax.adam(lr)
158 | opt_state = optim.init(eqx.filter(model, eqx.is_array))
159 |
160 | layer_idxs = [0, int(depth/4)-1, int(depth/2)-1, int(depth*3/4)-1, depth-1]
161 | avg_activity_l1 = np.zeros((len(layer_idxs), n_checks))
162 | avg_activity_l2 = np.zeros_like(avg_activity_l1)
163 | param_l2_norms = np.zeros_like(avg_activity_l1)
164 | param_spectral_norms = np.zeros_like(avg_activity_l1)
165 |
166 | train_loader, _ = get_dataloaders(dataset, batch_size)
167 | for t, (img_batch, label_batch) in enumerate(train_loader):
168 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
169 |
170 | pre_activities = vmap(model)(img_batch)
171 | i = 0
172 | for l, pre_act in enumerate(pre_activities):
173 | if l in layer_idxs:
174 | avg_activity_l1[i, t] = jnp.abs(pre_act).mean()
175 | avg_activity_l2[i, t] = jnp.sqrt(jnp.mean(pre_act**2))
176 | i += 1
177 |
178 | param_l2_norms[:, t] = compute_param_l2_norms(
179 | model=model.layers,
180 | act_fn=act_fn,
181 | layer_idxs=layer_idxs
182 | )
183 | param_spectral_norms[:, t] = compute_param_spectral_norms(
184 | model=model.layers,
185 | act_fn=act_fn,
186 | layer_idxs=layer_idxs
187 | )
188 | model, opt_state, _ = make_step(
189 | model=model,
190 | optim=optim,
191 | opt_state=opt_state,
192 | x=img_batch,
193 | y=label_batch
194 | )
195 | if t >= (n_checks - 1):
196 | break
197 |
198 | return (
199 | avg_activity_l1,
200 | avg_activity_l2,
201 | param_l2_norms,
202 | param_spectral_norms
203 | )
204 |
205 |
206 | if __name__ == "__main__":
207 | RESULTS_DIR = "mlp_fwd_pass_results"
208 | DATASET = "MNIST"
209 | WIDTHS = [2 ** i for i in range(7, 11)]
210 | DEPTHS = [2 ** i for i in range(4, 10)]
211 | LR = 1e-3
212 | BATCH_SIZE = 64
213 | N_RECORDED_LAYERS = 5
214 | N_CHECKS = 5
215 |
216 | parser = argparse.ArgumentParser()
217 | parser.add_argument("--act_fns", type=str, nargs='+', default=["linear", "tanh", "relu"])
218 | parser.add_argument("--optim_ids", type=str, nargs='+', default=["sgd", "adam"])
219 | parser.add_argument("--param_types", type=str, nargs='+', default=["sp", "depth_mup", "orthogonal"])
220 | parser.add_argument("--seed", type=int, default=54638)
221 | args = parser.parse_args()
222 |
223 | for act_fn in args.act_fns:
224 | print(f"\nact_fn: {act_fn}")
225 |
226 | for optim_id in args.optim_ids:
227 | print(f"\n\toptim: {optim_id}")
228 |
229 | for param_type in args.param_types:
230 | print(f"\n\t\tparam_type: {param_type}")
231 |
232 | skip_uses = [False, True] if param_type != "orthogonal" else [False]
233 | for use_skips in skip_uses:
234 | print(f"\n\t\t\tuse_skips: {use_skips}")
235 |
236 | save_dir = os.path.join(
237 | RESULTS_DIR,
238 | act_fn,
239 | optim_id,
240 | param_type,
241 | "skips" if use_skips else "no_skips",
242 | str(args.seed)
243 | )
244 | os.makedirs(save_dir, exist_ok=True)
245 |
246 | avg_activity_l1_per_N_L = np.zeros((N_RECORDED_LAYERS, N_CHECKS, len(WIDTHS), len(DEPTHS)))
247 | avg_activity_l2_per_N_L = np.zeros_like(avg_activity_l1_per_N_L)
248 | param_l2_norms_per_N_L = np.zeros_like(avg_activity_l1_per_N_L)
249 | param_spectral_norms_per_N_L = np.zeros_like(avg_activity_l1_per_N_L)
250 |
251 | for w, width in enumerate(WIDTHS):
252 | print(f"\n\t\t\t\tN = {width}\n")
253 | for d, depth in enumerate(DEPTHS):
254 | print(f"\t\t\t\t\tL = {depth}")
255 |
256 | avg_activity_l1, avg_activity_l2, param_l2_norms, param_spectral_norms = test_fwd_pass(
257 | seed=args.seed,
258 | dataset=DATASET,
259 | width=width,
260 | depth=depth,
261 | act_fn=act_fn,
262 | optim_id=optim_id,
263 | param_type=param_type,
264 | use_skips=use_skips,
265 | lr=LR,
266 | batch_size=BATCH_SIZE,
267 | n_checks=N_CHECKS
268 | )
269 | avg_activity_l1_per_N_L[:, :, w, d] = avg_activity_l1
270 | avg_activity_l2_per_N_L[:, :, w, d] = avg_activity_l2
271 | param_l2_norms_per_N_L[:, :, w, d] = param_l2_norms
272 | param_spectral_norms_per_N_L[:, :, w, d] = param_spectral_norms
273 |
274 | np.save(f"{save_dir}/avg_activity_l1_per_N_L.npy", avg_activity_l1_per_N_L)
275 | np.save(f"{save_dir}/avg_activity_l2_per_N_L.npy", avg_activity_l2_per_N_L)
276 | np.save(f"{save_dir}/param_l2_norms_per_N_L.npy", param_l2_norms_per_N_L)
277 | np.save(f"{save_dir}/param_spectral_norms_per_N_L.npy", param_spectral_norms_per_N_L)
278 |
--------------------------------------------------------------------------------
/experiments/mupc_paper/train_bpn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | import jax.random as jr
6 | import jax.numpy as jnp
7 | from jax import vmap
8 | from jax.nn import log_softmax
9 |
10 | import equinox as eqx
11 | import equinox.nn as nn
12 | import optax
13 | import jpc
14 |
15 | from experiments.datasets import get_dataloaders
16 | from experiments.mupc_paper.utils import set_seed, setup_logger, init_weights
17 |
18 |
19 | class MLP(eqx.Module):
20 | D: int
21 | N: int
22 | L: int
23 | param_type: str
24 | use_skips: bool
25 | layers: list
26 |
27 | def __init__(
28 | self,
29 | key,
30 | d_in,
31 | N,
32 | L,
33 | d_out,
34 | act_fn,
35 | param_type,
36 | use_bias=False,
37 | use_skips=False
38 | ):
39 | self.D = d_in
40 | self.N = N
41 | self.L = L
42 | self.param_type = param_type
43 | self.use_skips = use_skips
44 |
45 | keys = jr.split(key, L)
46 | self.layers = []
47 | for i in range(L):
48 | act_fn_l = nn.Identity() if i == 0 else jpc.get_act_fn(act_fn)
49 | _in = d_in if i == 0 else N
50 | _out = d_out if (i + 1) == L else N
51 | layer = nn.Sequential(
52 | [
53 | nn.Lambda(act_fn_l),
54 | nn.Linear(
55 | _in,
56 | _out,
57 | use_bias=use_bias,
58 | key=keys[i]
59 | )
60 | ]
61 | )
62 | self.layers.append(layer)
63 |
64 | def __call__(self, x):
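65 |         # same depth-μP rescalings as the MLP in test_mlp_fwd_pass.py:
66 |         # 1/sqrt(D) for the first layer, 1/sqrt(N) (or 1/sqrt(N*L) with
67 |         # skips) for hidden layers, and 1/N for the readout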
65 | if self.param_type == "depth_mup":
66 | for i, f in enumerate(self.layers):
67 | if (i + 1) == 1:
68 | x = f(x) / jnp.sqrt(self.D)
69 | elif 1 < (i + 1) < self.L:
70 | residual = x if self.use_skips else 0
71 | rescaling = jnp.sqrt(
72 | self.N * self.L
73 | ) if self.use_skips else jnp.sqrt(self.N)
74 | x = (f(x) / rescaling) + residual
75 | elif (i + 1) == self.L:
76 | x = f(x) / self.N
77 |
78 | else:
79 | for i, f in enumerate(self.layers):
80 | residual = x if self.use_skips and (1 < (i + 1) < self.L) else 0
81 |
82 | x = f(x) + residual
83 |
84 | return x
85 |
86 |
87 | def evaluate(model, testloader, loss_id):
88 | loss_fn = get_loss_fn(loss_id)
89 | avg_test_loss, avg_test_acc = 0, 0
90 | for x, y in testloader:
91 | x, y = x.numpy(), y.numpy()
92 | avg_test_loss += loss_fn(model, x, y)
93 | avg_test_acc += compute_accuracy(model, x, y)
94 | return avg_test_loss / len(testloader), avg_test_acc / len(testloader)
95 |
96 |
97 | @eqx.filter_jit
98 | def mse_loss(model, x, y):
99 | y_pred = vmap(model)(x)
100 | return jnp.mean((y - y_pred) ** 2)
101 |
102 |
103 | @eqx.filter_jit
104 | def cross_entropy_loss(model, x, y):
105 | logits = vmap(model)(x)
106 | log_probs = log_softmax(logits)
107 | return - jnp.mean(jnp.sum(y * log_probs, axis=-1))
108 |
109 |
110 | def get_loss_fn(loss_id):
111 | if loss_id == "mse":
112 | return mse_loss
113 | elif loss_id == "ce":
114 | return cross_entropy_loss
115 |
116 |
117 | @eqx.filter_jit
118 | def compute_accuracy(model, x, y):
119 | pred_y = vmap(model)(x)
120 | return jnp.mean(
121 | jnp.argmax(y, axis=1) == jnp.argmax(pred_y, axis=1)
122 | ) * 100
123 |
124 |
125 | @eqx.filter_jit
126 | def make_step(model, optim, opt_state, x, y, loss_id="mse"):
127 | loss_fn = get_loss_fn(loss_id)
128 | loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
129 | updates, opt_state = optim.update(
130 | updates=grads,
131 | state=opt_state,
132 | params=eqx.filter(model, eqx.is_array)
133 | )
134 | model = eqx.apply_updates(model, updates)
135 | return model, opt_state, loss
136 |
137 |
138 | def train_mlp(
139 | seed,
140 | dataset,
141 | loss_id,
142 | width,
143 | n_hidden,
144 | act_fn,
145 | param_type,
146 | optim_id,
147 | lr,
148 | batch_size,
149 | max_epochs,
150 | test_every,
151 | save_dir
152 | ):
153 | set_seed(seed)
154 | key = jr.PRNGKey(seed)
155 | model_key, init_key = jr.split(key, 2)
156 | logger = setup_logger(save_dir)
157 |
158 | model = MLP(
159 | key=model_key,
160 |         d_in=3072 if dataset == "CIFAR10" else 784,
161 | N=width,
162 | L=n_hidden+1,
163 | d_out=10,
164 | act_fn=act_fn,
165 | param_type=param_type,
166 | use_bias=False,
167 | use_skips=True if param_type == "depth_mup" else False
168 | )
169 | if param_type == "depth_mup":
170 | model = init_weights(
171 | model=model,
172 | init_fn_id="standard_gauss",
173 | key=init_key
174 | )
175 |
176 | optim = optax.sgd(lr) if optim_id == "sgd" else optax.adam(lr)
177 | opt_state = optim.init(eqx.filter(model, eqx.is_array))
178 |
179 | # data
180 | train_loader, test_loader = get_dataloaders(dataset, batch_size)
181 |
182 | # key metrics
183 | train_losses = []
184 | test_losses, test_accs = [], []
185 |
186 | diverged = no_learning = False
187 | global_batch_id = 0
188 | for epoch in range(1, max_epochs + 1):
189 | print(f"\nEpoch {epoch}\n-------------------------------")
190 |
191 | for train_iter, (img_batch, label_batch) in enumerate(train_loader):
192 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
193 |
194 | model, opt_state, train_loss = make_step(
195 | model=model,
196 | optim=optim,
197 | opt_state=opt_state,
198 | x=img_batch,
199 | y=label_batch,
200 | loss_id=loss_id
201 | )
202 | train_losses.append(train_loss)
203 | global_batch_id += 1
204 |
205 | if global_batch_id % test_every == 0:
206 | print(
207 | f"Train loss: {train_loss:.7f} [{train_iter * len(img_batch)}/{len(train_loader.dataset)}]"
208 | )
209 | avg_test_loss, avg_test_acc = evaluate(
210 | model,
211 | test_loader,
212 | loss_id
213 | )
214 | test_losses.append(avg_test_loss)
215 | test_accs.append(avg_test_acc)
216 | print(f"Avg test accuracy: {avg_test_acc:.4f}\n")
217 |
218 | if np.isinf(train_loss) or np.isnan(train_loss):
219 | diverged = True
220 | break
221 |
222 | if global_batch_id >= test_every and avg_test_acc < 15:
223 | no_learning = True
224 | break
225 |
226 | if diverged:
227 | print(
228 | f"Stopping training because of diverging loss: {train_loss}"
229 | )
230 | break
231 |
232 | if no_learning:
233 | print(
234 | f"Stopping training because of close-to-chance accuracy (no learning): {avg_test_acc}"
235 | )
236 | break
237 |
238 | np.save(f"{save_dir}/train_losses.npy", train_losses)
239 | np.save(f"{save_dir}/test_losses.npy", test_losses)
240 | np.save(f"{save_dir}/test_accs.npy", test_accs)
241 |
242 |
243 | if __name__ == "__main__":
244 |
245 | parser = argparse.ArgumentParser()
246 | parser.add_argument("--results_dir", type=str, default="bp_results")
247 | parser.add_argument("--dataset", type=str, default="CIFAR10")
248 | parser.add_argument("--loss_id", type=str, default="ce")
249 | parser.add_argument("--width", type=int, default=512)
250 | parser.add_argument("--n_hidden", type=int, default=8)
251 | parser.add_argument("--act_fns", type=str, nargs='+', default=["relu"])
252 | parser.add_argument("--param_type", type=str, default="depth_mup")
253 | parser.add_argument("--optim_id", type=str, default="adam")
254 | parser.add_argument("--lrs", type=float, nargs='+', default=[1e-2])
255 | parser.add_argument("--batch_size", type=int, default=128)
256 | parser.add_argument("--max_epochs", type=int, default=20)
257 | parser.add_argument("--test_every", type=int, default=389)
258 | parser.add_argument("--n_seeds", type=int, default=3)
259 | args = parser.parse_args()
260 |
261 | for act_fn in args.act_fns:
262 | for lr in args.lrs:
263 | for seed in range(args.n_seeds):
264 | save_dir = os.path.join(
265 | args.results_dir,
266 | args.dataset,
267 | f"{args.loss_id}_loss",
268 | f"width_{args.width}",
269 | f"{args.n_hidden}_n_hidden",
270 | act_fn,
271 | f"{args.param_type}_param",
272 | args.optim_id,
273 | f"lr_{lr}",
274 | f"batch_size_{args.batch_size}",
275 | f"{args.max_epochs}_epochs",
276 | str(seed)
277 | )
278 | print(f"Starting training with config: {save_dir} with seed: {seed}")
279 | train_mlp(
280 | seed=seed,
281 | dataset=args.dataset,
282 | loss_id=args.loss_id,
283 | width=args.width,
284 | n_hidden=args.n_hidden,
285 | act_fn=act_fn,
286 | param_type=args.param_type,
287 | optim_id=args.optim_id,
288 | lr=lr,
289 | batch_size=args.batch_size,
290 | max_epochs=args.max_epochs,
291 | test_every=args.test_every,
292 | save_dir=save_dir
293 | )
294 |
--------------------------------------------------------------------------------
/jpc/_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.tree_util import tree_map, tree_leaves
4 | from equinox import tree_at
5 | import equinox.nn as nn
6 | from jpc import pc_energy_fn, _check_param_type
7 | from jaxtyping import PRNGKeyArray, PyTree, ArrayLike, Scalar, Array
8 | from jaxlib.xla_extension import PjitFunction
9 | from typing import Callable, Optional, Tuple
10 | from dataclasses import dataclass
11 |
12 |
13 | _ACT_FNS = [
14 | "linear", "tanh", "hard_tanh", "relu", "leaky_relu", "gelu", "selu", "silu"
15 | ]
16 |
17 |
18 | def get_act_fn(name: str) -> Callable:
19 | if name == "linear":
20 | return nn.Identity()
21 | elif name == "tanh":
22 | return jnp.tanh
23 | elif name == "hard_tanh":
24 | return jax.nn.hard_tanh
25 | elif name == "relu":
26 | return jax.nn.relu
27 | elif name == "leaky_relu":
28 | return jax.nn.leaky_relu
29 | elif name == "gelu":
30 | return jax.nn.gelu
31 | elif name == "selu":
32 | return jax.nn.selu
33 | elif name == "silu":
34 | return jax.nn.silu
35 | else:
36 |         raise ValueError(
37 |             f"Invalid activation function ID. Options are {_ACT_FNS}."
38 |         )
39 |
40 |
41 | def make_mlp(
42 | key: PRNGKeyArray,
43 | input_dim: int,
44 | width: int,
45 | depth: int,
46 | output_dim: int,
47 | act_fn: str,
48 | use_bias: bool = False,
49 | param_type: str = "sp"
50 | ) -> PyTree[Callable]:
51 | """Creates a multi-layer perceptron compatible with predictive coding updates.
52 |
53 | !!! note
54 |
55 | This implementation places the activation function before the linear
56 | transformation, $\mathbf{W}_\ell \phi(\mathbf{z}_{\ell-1})$, for
57 | compatibility with the [μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))
58 | scalings when `param_type = "mupc"` in functions including
59 | [`jpc.init_activities_with_ffwd()`](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_ffwd),
60 | [`jpc.update_activities()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_activities),
61 | and [`jpc.update_params()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_params).
62 |
63 | **Main arguments:**
64 |
65 | - `key`: `jax.random.PRNGKey` for parameter initialisation.
66 | - `input_dim`: Input dimension.
67 | - `width`: Network width.
68 | - `depth`: Network depth.
69 | - `output_dim`: Output dimension.
70 | - `act_fn`: Activation function (for all layers except the output).
71 | - `use_bias`: `False` by default.
72 | - `param_type`: Determines the parameterisation. Options are `"sp"`
73 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))),
74 | or `"ntp"` (neural tangent parameterisation). See [`jpc._get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings)
75 | for the specific scalings of these different parameterisations. Defaults
76 | to `"sp"`.
77 |
78 | **Returns:**
79 |
80 | List of callable fully connected layers.
81 |
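82 |     **Example:**
83 | 
84 |     ```py
85 |     import jax.random as jr
86 |     import jpc
87 | 
88 |     # a 3-layer MLP: 784 -> 128 -> 128 -> 10, with ReLU before each
89 |     # linear layer except the first
90 |     model = jpc.make_mlp(
91 |         jr.PRNGKey(0),
92 |         input_dim=784,
93 |         width=128,
94 |         depth=3,
95 |         output_dim=10,
96 |         act_fn="relu"
97 |     )
98 |     ```
99 | 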
82 | """
83 | _check_param_type(param_type)
84 |
85 | subkeys = jax.random.split(key, depth)
86 | layers = []
87 | for i in range(depth):
88 | act_fn_l = nn.Identity() if i == 0 else get_act_fn(act_fn)
89 | _in = input_dim if i == 0 else width
90 | _out = output_dim if (i + 1) == depth else width
91 |
92 | linear = nn.Linear(
93 | _in,
94 | _out,
95 | use_bias=use_bias,
96 | key=subkeys[i]
97 | )
98 | if param_type == "mupc":
99 | W = jax.random.normal(subkeys[i], linear.weight.shape)
100 | linear = tree_at(lambda l: l.weight, linear, W)
101 |
102 | layers.append(
103 | nn.Sequential(
104 | [nn.Lambda(act_fn_l), linear]
105 | )
106 | )
107 |
108 | return layers
109 |
110 |
111 | def make_skip_model(depth: int) -> PyTree[Callable]:
112 | """Creates a residual network with one-layer skip connections at every layer
113 | except from the input to the next layer and from the penultimate layer to
114 | the output.
115 |
116 | This is used for compatibility with the [μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))
117 | parameterisation when `param_type = "mupc"` in functions including
118 | [`jpc.init_activities_with_ffwd()`](https://thebuckleylab.github.io/jpc/api/Initialisation/#jpc.init_activities_with_ffwd),
119 | [`jpc.update_activities()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_activities),
120 | and [`jpc.update_params()`](https://thebuckleylab.github.io/jpc/api/Discrete%20updates/#jpc.update_params).
121 | """
122 | skips = [None] * depth
123 | for l in range(1, depth-1):
124 | skips[l] = nn.Lambda(nn.Identity())
125 |
126 | return skips
127 |
128 |
129 | def mse_loss(preds: ArrayLike, labels: ArrayLike) -> Scalar:
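130 |     # the 1/2 factor follows the usual predictive coding energy convention,
131 |     # E = 1/2 * sum of squared prediction errors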
130 | return 0.5 * jnp.mean((labels - preds)**2)
131 |
132 |
133 | def cross_entropy_loss(logits: ArrayLike, labels: ArrayLike) -> Scalar:
134 |     # log_softmax is numerically stabler than log(softmax(x))
135 |     log_probs = jax.nn.log_softmax(logits, axis=-1)
136 |     return - jnp.mean(jnp.sum(labels * log_probs, axis=-1))
137 |
138 |
139 | def compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Scalar:
140 | return jnp.mean(
141 | jnp.argmax(truths, axis=1) == jnp.argmax(preds, axis=1)
142 | ) * 100
143 |
144 |
145 | def get_t_max(activities_iters: PyTree[Array]) -> Array:
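146 |     # diffrax pads steps beyond the solver's endpoint with `inf`, so the
147 |     # argmax along the time axis is the first padded index, and the step
148 |     # before it is the last actual inference iteration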
146 | return jnp.argmax(activities_iters[0][:, 0, 0]) - 1
147 |
148 |
149 | def compute_infer_energies(
150 | params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
151 | activities_iters: PyTree[Array],
152 | t_max: Array,
153 | y: ArrayLike,
154 | *,
155 | x: Optional[ArrayLike] = None,
156 | loss: str = "mse",
157 | param_type: str = "sp",
158 | weight_decay: Scalar = 0.,
159 | spectral_penalty: Scalar = 0.,
160 | activity_decay: Scalar = 0.
161 | ) -> PyTree[Scalar]:
162 | """Calculates layer energies during predictive coding inference.
163 |
164 | **Main arguments:**
165 |
166 | - `params`: Tuple with callable model layers and optional skip connections.
167 | - `activities_iters`: Layer-wise activities at every inference iteration.
168 |         Note that each set of activities will have 4096 steps as its first
169 |         dimension (the diffrax default for `max_steps`).
170 | - `t_max`: Maximum number of inference iterations to compute energies for.
171 | - `y`: Observation or target of the generative model.
172 |
173 | **Other arguments:**
174 |
175 | - `x`: Optional prior of the generative model.
176 | - `loss`: Loss function to use at the output layer (mean squared error
177 | `"mse"` vs cross-entropy `"ce"`).
178 | - `param_type`: Determines the parameterisation. Options are `"sp"`,
179 | `"mupc"`, or `"ntp"`.
180 | - `weight_decay`: Weight decay for the weights.
181 | - `spectral_penalty`: Spectral penalty for the weights.
182 | - `activity_decay`: Activity decay for the activities.
183 |
184 | **Returns:**
185 |
186 | List of layer-wise energies at every inference iteration.
187 |
188 | """
189 | model, _ = params
190 |
191 | def loop_body(state):
192 | t, energies_iters = state
193 |
194 | energies = pc_energy_fn(
195 | params=params,
196 | activities=tree_map(lambda act: act[t], activities_iters),
197 | y=y,
198 | x=x,
199 | loss=loss,
200 | param_type=param_type,
201 | weight_decay=weight_decay,
202 | spectral_penalty=spectral_penalty,
203 | activity_decay=activity_decay,
204 | record_layers=True
205 | )
206 | energies_iters = energies_iters.at[:, t].set(energies)
207 | return t + 1, energies_iters
208 |
209 | # for memory reasons, we set 500 as the max iters to record
210 | energies_iters = jnp.zeros((len(model), 500))
211 | _, energies_iters = jax.lax.while_loop(
212 | lambda state: state[0] < t_max,
213 | loop_body,
214 | (0, energies_iters)
215 | )
216 | return energies_iters[::-1, :]
217 |
218 |
219 | def compute_activity_norms(activities: PyTree[Array]) -> Array:
220 | """Calculates $\ell^2$ norm of activities at each layer."""
221 | return jnp.array([
222 | jnp.mean(
223 | jnp.linalg.norm(
224 | a,
225 | axis=-1,
226 | ord=2
227 | )
228 | ) for a in tree_leaves(activities)
229 | ])
230 |
231 |
232 | def compute_param_norms(params):
233 | """Calculates $\ell^2$ norm of all model parameters."""
234 | def process_model_params(model_params):
235 | norms = []
236 | for p in tree_leaves(model_params):
237 | if p is None or isinstance(p, PjitFunction):
238 | norms.append(0.)
239 | elif callable(p) and not hasattr(p, 'shape'):
240 | # Skip callable functions (like Lambda-wrapped activations) that don't have shape
241 | # But keep arrays which might be callable in some JAX contexts
242 | norms.append(0.)
243 | else:
244 | try:
245 | # Check if p is a JAX array-like object
246 | if hasattr(p, 'shape') and hasattr(p, 'dtype'):
247 | norms.append(jnp.linalg.norm(jnp.ravel(p), ord=2))
248 | else:
249 | norms.append(0.)
250 | except (TypeError, AttributeError):
251 | # If ravel fails, it's not an array
252 | norms.append(0.)
253 | return jnp.array(norms)
254 |
255 | model_params, skip_model_params = params
256 | model_norms = process_model_params(model_params)
257 | skip_model_norms = (process_model_params(skip_model_params) if
258 | skip_model_params is not None else None)
259 |
260 | return model_norms, skip_model_norms
261 |
--------------------------------------------------------------------------------
/jpc/_test.py:
--------------------------------------------------------------------------------
1 | """Utility functions to test predictive coding networks."""
2 |
3 | import equinox as eqx
4 | from jpc import (
5 | init_activities_from_normal,
6 | init_activities_with_ffwd,
7 | init_activities_with_amort,
8 | mse_loss,
9 | cross_entropy_loss,
10 | compute_accuracy,
11 | solve_inference,
12 | _check_param_type
13 | )
14 | from diffrax import (
15 | AbstractSolver,
16 | AbstractStepSizeController,
17 | Heun,
18 | PIDController
19 | )
20 | from jaxtyping import PRNGKeyArray, PyTree, ArrayLike, Array, Scalar
21 | from typing import Callable, Tuple, Optional
22 |
23 |
24 | @eqx.filter_jit
25 | def test_discriminative_pc(
26 | model: PyTree[Callable],
27 | output: ArrayLike,
28 | input: ArrayLike,
29 | *,
30 | skip_model: Optional[PyTree[Callable]] = None,
31 | loss: str = "mse",
32 | param_type: str = "sp"
33 | ) -> Tuple[Scalar, Scalar]:
34 | """Computes test metrics for a discriminative predictive coding network.
35 |
36 | **Main arguments:**
37 |
38 | - `model`: List of callable model (e.g. neural network) layers.
39 |     - `output`: Target (e.g. labels) for the network predictions.
40 |     - `input`: Input (e.g. images) to the network.
41 |
42 | **Other arguments:**
43 |
44 | - `skip_model`: Optional skip connection model.
45 | - `loss`: Loss function to use at the output layer. Options are mean squared
46 | error `"mse"` (default) or cross-entropy `"ce"`.
47 | - `param_type`: Determines the parameterisation. Options are `"sp"`
48 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))),
49 | or `"ntp"` (neural tangent parameterisation). See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings)
50 | for the specific scalings of these different parameterisations. Defaults
51 | to `"sp"`.
52 |
53 | **Returns:**
54 |
55 | Test loss and accuracy of output predictions.
56 |
57 | """
58 | _check_param_type(param_type)
59 |
60 | preds = init_activities_with_ffwd(
61 | model=model,
62 | input=input,
63 | skip_model=skip_model,
64 | param_type=param_type
65 | )[-1]
66 |
67 |     if loss == "mse":
68 |         loss = mse_loss(preds, output)
69 |     elif loss == "ce":
70 |         loss = cross_entropy_loss(preds, output)
71 |     else:
72 |         raise ValueError(f"Invalid loss: {loss}. Options are 'mse' or 'ce'.")
73 |     acc = compute_accuracy(output, preds)
74 |     return loss, acc
75 |
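A minimal usage sketch of `test_discriminative_pc`, mirroring the `evaluate` loop in `experiments/library_paper/train_mlp.py`. The random batch below is purely illustrative and stands in for real (image, one-hot label) data:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)

# Same MLP factory as in the experiment scripts.
model = jpc.make_mlp(
    model_key,
    input_dim=784,
    width=300,
    depth=4,
    output_dim=10,
    act_fn="tanh"
)

# Illustrative stand-ins for a batch of images and one-hot labels.
img_batch = jax.random.normal(data_key, (64, 784))
label_batch = jnp.eye(10)[jnp.zeros(64, dtype=int)]

loss, acc = jpc.test_discriminative_pc(
    model=model,
    output=label_batch,
    input=img_batch
)
```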
76 | @eqx.filter_jit
77 | def test_generative_pc(
78 | model: PyTree[Callable],
79 | output: ArrayLike,
80 | input: ArrayLike,
81 | key: PRNGKeyArray,
82 | layer_sizes: PyTree[int],
83 | batch_size: int,
84 | *,
85 | skip_model: Optional[PyTree[Callable]] = None,
86 | loss_id: str = "mse",
87 | param_type: str = "sp",
88 | sigma: Scalar = 0.05,
89 | ode_solver: AbstractSolver = Heun(),
90 | max_t1: int = 500,
91 |     dt: Scalar | int | None = None,
92 | stepsize_controller: AbstractStepSizeController = PIDController(
93 | rtol=1e-3, atol=1e-3
94 | ),
95 | weight_decay: Scalar = 0.,
96 | spectral_penalty: Scalar = 0.,
97 | activity_decay: Scalar = 0.
98 | ) -> Tuple[Scalar, Array]:
99 | """Computes test metrics for a generative predictive coding network.
100 |
101 | Gets output predictions (e.g. of an image given a label) with a feedforward
102 | pass and calculates accuracy of inferred input (e.g. of a label given an
103 | image).
104 |
105 | **Main arguments:**
106 |
107 | - `model`: List of callable model (e.g. neural network) layers.
108 | - `output`: Observation or target of the generative model.
109 | - `input`: Prior of the generative model.
110 | - `key`: `jax.random.PRNGKey` for random initialisation of activities.
111 | - `layer_sizes`: Dimension of all layers (input, hidden and output).
112 | - `batch_size`: Dimension of data batch for activity initialisation.
113 |
114 | **Other arguments:**
115 |
116 | - `skip_model`: Optional skip connection model.
117 | - `loss_id`: Loss function to use at the output layer. Options are mean squared
118 | error `"mse"` (default) or cross-entropy `"ce"`.
119 | - `param_type`: Determines the parameterisation. Options are `"sp"`
120 | (standard parameterisation), `"mupc"` ([μPC](https://openreview.net/forum?id=lSLSzYuyfX&referrer=%5Bthe%20profile%20of%20Francesco%20Innocenti%5D(%2Fprofile%3Fid%3D~Francesco_Innocenti1))),
121 | or `"ntp"` (neural tangent parameterisation). See [`_get_param_scalings()`](https://thebuckleylab.github.io/jpc/api/Energy%20functions/#jpc._get_param_scalings)
122 | for the specific scalings of these different parameterisations. Defaults
123 | to `"sp"`.
124 | - `sigma`: Standard deviation for Gaussian to sample activities from.
125 | Defaults to 5e-2.
126 | - `ode_solver`: [diffrax ODE solver](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/)
127 | to be used. Default is [`Heun`](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/#diffrax.Heun),
128 | a 2nd order explicit Runge--Kutta method.
129 | - `max_t1`: Maximum end of integration region (500 by default).
130 | - `dt`: Integration step size. Defaults to None since the default
131 | `stepsize_controller` will automatically determine it.
132 | - `stepsize_controller`: [diffrax controller](https://docs.kidger.site/diffrax/api/stepsize_controller/)
133 | for step size integration. Defaults to [`PIDController`](https://docs.kidger.site/diffrax/api/stepsize_controller/#diffrax.PIDController).
134 | Note that the relative and absolute tolerances of the controller will
135 | also determine the steady state to terminate the solver.
136 |     - `weight_decay`: Coefficient of the weight decay regulariser (0 by default).
137 | - `spectral_penalty`: Weight spectral penalty of the form
138 | $||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2$ (0 by default).
139 |     - `activity_decay`: Coefficient of the activity decay regulariser (0 by default).
140 |
141 | **Returns:**
142 |
143 | Accuracy and output predictions.
144 |
145 | """
146 | _check_param_type(param_type)
147 |
148 | params = model, skip_model
149 | activities = init_activities_from_normal(
150 | key=key,
151 | layer_sizes=layer_sizes,
152 | mode="unsupervised",
153 | batch_size=batch_size,
154 | sigma=sigma
155 | )
156 | input_preds = solve_inference(
157 | params=params,
158 | activities=activities,
159 | output=output,
160 | loss_id=loss_id,
161 | param_type=param_type,
162 | solver=ode_solver,
163 | max_t1=max_t1,
164 | dt=dt,
165 | stepsize_controller=stepsize_controller,
166 | weight_decay=weight_decay,
167 | spectral_penalty=spectral_penalty,
168 | activity_decay=activity_decay
169 | )[0][0]
170 | input_acc = compute_accuracy(input, input_preds)
171 | output_preds = init_activities_with_ffwd(
172 | model=model,
173 | input=input,
174 | skip_model=skip_model,
175 | param_type=param_type
176 | )[-1]
177 | return input_acc, output_preds
178 |
179 |
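A minimal usage sketch of `test_generative_pc`, assuming the generative direction runs from a 10-dimensional prior (e.g. a label) to a 784-dimensional observation (e.g. an image), so `layer_sizes` lists the dimensions in that order. The data and dimensions are illustrative:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
model_key, act_key, data_key = jax.random.split(key, 3)

# Generative direction: label (prior) -> image (observation).
layer_sizes = [10, 300, 300, 784]
model = jpc.make_mlp(
    model_key,
    input_dim=10,
    width=300,
    depth=3,
    output_dim=784,
    act_fn="tanh"
)

batch_size = 64
img_batch = jax.random.normal(data_key, (batch_size, 784))
label_batch = jnp.eye(10)[jnp.zeros(batch_size, dtype=int)]

input_acc, output_preds = jpc.test_generative_pc(
    model=model,
    output=img_batch,
    input=label_batch,
    key=act_key,
    layer_sizes=layer_sizes,
    batch_size=batch_size
)
```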
180 | @eqx.filter_jit
181 | def test_hpc(
182 | generator: PyTree[Callable],
183 | amortiser: PyTree[Callable],
184 | output: ArrayLike,
185 | input: ArrayLike,
186 | key: PRNGKeyArray,
187 | layer_sizes: PyTree[int],
188 | batch_size: int,
189 | sigma: Scalar = 0.05,
190 | ode_solver: AbstractSolver = Heun(),
191 | max_t1: int = 500,
192 |     dt: Scalar | int | None = None,
193 | stepsize_controller: AbstractStepSizeController = PIDController(
194 | rtol=1e-3, atol=1e-3
195 | )
196 | ) -> Tuple[Scalar, Scalar, Scalar, Array]:
197 | """Computes test metrics for hybrid predictive coding trained in a supervised manner.
198 |
199 | Calculates input accuracy of (i) amortiser, (ii) generator, and (iii)
200 | hybrid (amortiser + generator). Also returns output predictions (e.g. of
201 | an image given a label) with a feedforward pass of the generator.
202 |
203 | !!! note
204 |
205 | The input and output of the generator are the output and input of the
206 | amortiser, respectively.
207 |
208 | **Main arguments:**
209 |
210 | - `generator`: List of callable layers for the generative model.
211 |     - `amortiser`: List of callable layers for the model amortising the inference
212 | of the `generator`.
213 | - `output`: Observation or target of the generative model.
214 |     - `input`: Prior of the generator, target for the amortiser.
215 | - `key`: `jax.random.PRNGKey` for random initialisation of activities.
216 | - `layer_sizes`: Dimension of all layers (input, hidden and output).
217 | - `batch_size`: Dimension of data batch for initialisation of activities.
218 |
219 | **Other arguments:**
220 |
221 | - `sigma`: Standard deviation for Gaussian to sample activities from.
222 | Defaults to 5e-2.
223 | - `ode_solver`: [diffrax ODE solver](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/)
224 | to be used. Default is [`Heun`](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/#diffrax.Heun),
225 | a 2nd order explicit Runge--Kutta method.
226 | - `max_t1`: Maximum end of integration region (500 by default).
227 | - `dt`: Integration step size. Defaults to None since the default
228 | `stepsize_controller` will automatically determine it.
229 | - `stepsize_controller`: [diffrax controller](https://docs.kidger.site/diffrax/api/stepsize_controller/)
230 | for step size integration. Defaults to [`PIDController`](https://docs.kidger.site/diffrax/api/stepsize_controller/#diffrax.PIDController).
231 | Note that the relative and absolute tolerances of the controller will
232 | also determine the steady state to terminate the solver.
233 |
234 | **Returns:**
235 |
236 | Accuracies of all models and output predictions.
237 |
238 | """
239 | gen_params = (generator, None)
240 | amort_activities = init_activities_with_amort(
241 | amortiser=amortiser,
242 | generator=generator,
243 | input=output
244 | )
245 | amort_preds = amort_activities[0]
246 | hpc_preds = solve_inference(
247 | params=gen_params,
248 | activities=amort_activities,
249 | output=output,
250 | solver=ode_solver,
251 | max_t1=max_t1,
252 | dt=dt,
253 | stepsize_controller=stepsize_controller
254 | )[0][0]
255 | activities = init_activities_from_normal(
256 | key=key,
257 | layer_sizes=layer_sizes,
258 | mode="unsupervised",
259 | batch_size=batch_size,
260 | sigma=sigma
261 | )
262 | gen_preds = solve_inference(
263 | params=gen_params,
264 | activities=activities,
265 | output=output,
266 | solver=ode_solver,
267 | max_t1=max_t1,
268 | dt=dt,
269 | stepsize_controller=stepsize_controller
270 | )[0][0]
271 | amort_acc = compute_accuracy(input, amort_preds)
272 | hpc_acc = compute_accuracy(input, hpc_preds)
273 | gen_acc = compute_accuracy(input, gen_preds)
274 | output_preds = init_activities_with_ffwd(
275 | model=generator,
276 | input=input,
277 | skip_model=None
278 | )[-1]
279 | return amort_acc, hpc_acc, gen_acc, output_preds
280 |
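A minimal usage sketch of `test_hpc`. Per the note above, the amortiser runs in the reverse direction of the generator, so its input/output dimensions are mirrored; the data and dimensions are illustrative:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
gen_key, amort_key, act_key, data_key = jax.random.split(key, 4)

# Generator: label -> image; amortiser: image -> label.
layer_sizes = [10, 300, 300, 784]
generator = jpc.make_mlp(
    gen_key, input_dim=10, width=300, depth=3, output_dim=784, act_fn="tanh"
)
amortiser = jpc.make_mlp(
    amort_key, input_dim=784, width=300, depth=3, output_dim=10, act_fn="tanh"
)

batch_size = 64
img_batch = jax.random.normal(data_key, (batch_size, 784))
label_batch = jnp.eye(10)[jnp.zeros(batch_size, dtype=int)]

amort_acc, hpc_acc, gen_acc, output_preds = jpc.test_hpc(
    generator=generator,
    amortiser=amortiser,
    output=img_batch,
    input=label_batch,
    key=act_key,
    layer_sizes=layer_sizes,
    batch_size=batch_size
)
```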
--------------------------------------------------------------------------------
/experiments/library_paper/train_mlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 |
5 | import jax
6 | import equinox as eqx
7 | import jpc
8 | import optax
9 | from diffrax import PIDController, ConstantStepSize
10 |
11 | from utils import (
12 | setup_mlp_experiment,
13 | get_ode_solver,
14 | set_seed
15 | )
16 | from plotting import (
17 | plot_loss,
18 | plot_loss_and_accuracy,
19 | plot_runtimes,
20 | plot_norms
21 | )
22 | from experiments.datasets import get_dataloaders
23 |
24 |
25 | def evaluate(model, test_loader):
26 | avg_test_loss, avg_test_acc = 0, 0
27 | for batch_id, (img_batch, label_batch) in enumerate(test_loader):
28 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
29 |
30 | test_loss, test_acc = jpc.test_discriminative_pc(
31 | model=model,
32 | output=label_batch,
33 | input=img_batch
34 | )
35 | avg_test_loss += test_loss
36 | avg_test_acc += test_acc
37 |
38 | return avg_test_loss / len(test_loader), avg_test_acc / len(test_loader)
39 |
40 |
41 | def train_mlp(
42 | seed,
43 | dataset,
44 | width,
45 | n_hidden,
46 | act_fn,
47 | max_t1,
48 | activity_lr,
49 | param_lr,
50 | batch_size,
51 | activity_optim_id,
52 | max_epochs,
53 | test_every,
54 | save_dir
55 | ):
56 | set_seed(seed)
57 | os.makedirs(save_dir, exist_ok=True)
58 |
59 | key = jax.random.PRNGKey(seed)
60 | model = jpc.make_mlp(
61 | key,
62 | input_dim=784,
63 | width=width,
64 | depth=n_hidden+1,
65 | output_dim=10,
66 | act_fn=act_fn
67 | )
68 |
69 | param_optim = optax.adam(param_lr)
70 | param_opt_state = param_optim.init(
71 | (eqx.filter(model, eqx.is_array), None)
72 | )
73 | train_loader, test_loader = get_dataloaders(dataset, batch_size)
74 |
75 | train_losses = []
76 | test_losses, test_accs = [], []
77 | activity_norms, param_norms, param_grad_norms = [], [], []
78 | inference_runtimes = []
79 |
80 | if activity_optim_id != "SGD":
81 | stepsize_controller = ConstantStepSize() if (
82 | activity_optim_id == "Euler"
83 | ) else PIDController(rtol=1e-3, atol=1e-3)
84 | ode_solver = get_ode_solver(activity_optim_id)
85 |
86 | elif activity_optim_id == "SGD":
87 | activity_optim = optax.sgd(activity_lr)
88 |
89 | global_batch_id = 0
90 | for epoch in range(1, max_epochs + 1):
91 | print(f"\nEpoch {epoch}\n-------------------------------")
92 |
93 | for batch_id, (img_batch, label_batch) in enumerate(train_loader):
94 | img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
95 |
96 | if activity_optim_id != "SGD":
97 | result = jpc.make_pc_step(
98 | model,
99 | param_optim,
100 | param_opt_state,
101 | output=label_batch,
102 | input=img_batch,
103 | ode_solver=ode_solver,
104 | max_t1=max_t1,
105 | dt=activity_lr,
106 | stepsize_controller=stepsize_controller,
107 | activity_norms=True,
108 | param_norms=True,
109 | grad_norms=True
110 | )
111 | model, param_opt_state = result["model"], result["opt_state"]
112 | train_loss = result["loss"]
113 | activity_norms.append(result["activity_norms"][:-1])
114 |                 param_norms.append(  # assumed intent: keep even-indexed (weight) norms
115 |                     [p for i, p in enumerate(result["model_param_norms"]) if i % 2 == 0]
116 |                 )
117 |                 param_grad_norms.append(
118 |                     [g for i, g in enumerate(result["model_grad_norms"]) if i % 2 == 0]
119 |                 )
120 |
121 | elif activity_optim_id == "SGD":
122 | activities = jpc.init_activities_with_ffwd(
123 | model=model,
124 | input=img_batch
125 | )
126 | activity_opt_state = activity_optim.init(activities)
127 | train_loss = jpc.mse_loss(activities[-1], label_batch)
128 |
129 | for _ in range(max_t1):
130 | activity_update_result = jpc.update_activities(
131 | params=(model, None),
132 | activities=activities,
133 | optim=activity_optim,
134 | opt_state=activity_opt_state,
135 | output=label_batch,
136 | input=img_batch
137 | )
138 | activities = activity_update_result["activities"]
139 | activity_optim = activity_update_result["activity_optim"]
140 | activity_opt_state = activity_update_result["activity_opt_state"]
141 |
142 | param_update_result = jpc.update_params(
143 | params=(model, None),
144 | activities=activities,
145 | optim=param_optim,
146 | opt_state=param_opt_state,
147 | output=label_batch,
148 | input=img_batch
149 | )
150 | model = param_update_result["model"]
151 | param_grads = param_update_result["param_grads"]
152 | param_optim = param_update_result["param_optim"]
153 | param_opt_state = param_update_result["param_opt_state"]
154 |
155 | activity_norms.append(jpc.compute_activity_norms(activities[:-1]))
156 | param_norms.append(jpc.compute_param_norms((model, None))[0])
157 | param_grad_norms.append(jpc.compute_param_norms(param_grads)[0])
158 |
159 |         if activity_optim_id != "SGD":  # re-run inference just to time it; result is discarded
160 | activities0 = jpc.init_activities_with_ffwd(model, img_batch)
161 | start_time = time.time()
162 | jax.block_until_ready(
163 | jpc.solve_inference(
164 | (model, None),
165 | activities0,
166 | output=label_batch,
167 | input=img_batch,
168 | solver=ode_solver,
169 | max_t1=max_t1,
170 | dt=activity_lr,
171 | stepsize_controller=stepsize_controller
172 | )
173 | )
174 | end_time = time.time()
175 |
176 | elif activity_optim_id == "SGD":
177 | activities = jpc.init_activities_with_ffwd(
178 | model=model,
179 | input=img_batch
180 | )
181 | activity_opt_state = activity_optim.init(activities)
182 | start_time = time.time()
183 |             for _ in range(max_t1):  # timing only; the updated activities are discarded
184 | jax.block_until_ready(
185 | jpc.update_activities(
186 | (model, None),
187 | activities,
188 | activity_optim,
189 | activity_opt_state,
190 | label_batch,
191 | img_batch
192 | )
193 | )
194 | end_time = time.time()
195 |
196 | train_losses.append(train_loss)
197 | inference_runtimes.append((end_time - start_time) * 1000)
198 | global_batch_id += 1
199 |
200 | if global_batch_id % test_every == 0:
201 | print(f"Train loss: {train_loss:.7f} [{batch_id * len(img_batch)}/{len(train_loader.dataset)}]")
202 |
203 | avg_test_loss, avg_test_acc = evaluate(model, test_loader)
204 | test_losses.append(avg_test_loss)
205 | test_accs.append(avg_test_acc)
206 | print(f"Avg test accuracy: {avg_test_acc:.4f}\n")
207 |
208 | plot_loss(
209 | loss=train_losses,
210 | yaxis_title="Train loss",
211 | xaxis_title="Iteration",
212 | save_path=f"{save_dir}/train_losses.pdf"
213 | )
214 | plot_loss_and_accuracy(
215 | loss=test_losses,
216 | accuracy=test_accs,
217 | mode="test",
218 | xaxis_title="Training iteration",
219 | save_path=f"{save_dir}/test_losses_and_accs.pdf"
220 | )
221 | plot_norms(
222 | norms=param_norms,
223 | norm_type="param",
224 | save_path=f"{save_dir}/param_norms.pdf"
225 | )
226 | plot_norms(
227 | norms=param_grad_norms,
228 | norm_type="param_grad",
229 | save_path=f"{save_dir}/param_grad_norms.pdf"
230 | )
231 | plot_norms(
232 | norms=activity_norms,
233 | norm_type="activity",
234 | save_path=f"{save_dir}/activity_norms.pdf"
235 | )
236 | plot_runtimes(
237 | runtimes=inference_runtimes,
238 | save_path=f"{save_dir}/inference_runtimes.pdf"
239 | )
240 |
241 | np.save(f"{save_dir}/batch_train_losses.npy", train_losses)
242 | np.save(f"{save_dir}/test_losses.npy", test_losses)
243 | np.save(f"{save_dir}/test_accs.npy", test_accs)
244 |
245 | np.save(f"{save_dir}/activity_norms.npy", activity_norms)
246 | np.save(f"{save_dir}/param_norms.npy", param_norms)
247 | np.save(f"{save_dir}/param_grad_norms.npy", param_grad_norms)
248 |
249 | np.save(f"{save_dir}/inference_runtimes.npy", inference_runtimes)
250 |
251 |
252 | if __name__ == "__main__":
253 | RESULTS_DIR = "mlp_results"
254 | DATASETS = ["MNIST", "Fashion-MNIST"]
255 | N_SEEDS = 3
256 |
257 | WIDTH = 300
258 | N_HIDDENS = [3, 5, 10]
259 | ACT_FN = "tanh"
260 |
261 | ACTIVITY_OPTIMS_ID = ["Euler", "Heun"]
262 | MAX_T1S = [5, 10, 20, 50, 100, 200, 500]
263 | ACTIVITY_LRS = [5e-1, 1e-1, 5e-2]
264 |
265 | PARAM_LR = 1e-3
266 | BATCH_SIZE = 64
267 | MAX_EPOCHS = 1
268 | TEST_EVERY = 100
269 |
270 | for dataset in DATASETS:
271 | for n_hidden in N_HIDDENS:
272 | for activity_optim_id in ACTIVITY_OPTIMS_ID:
273 | for max_t1 in MAX_T1S:
274 | for activity_lr in ACTIVITY_LRS:
275 | for seed in range(N_SEEDS):
276 | save_dir = setup_mlp_experiment(
277 | results_dir=RESULTS_DIR,
278 | dataset=dataset,
279 | width=WIDTH,
280 | n_hidden=n_hidden,
281 | act_fn=ACT_FN,
282 | max_t1=max_t1,
283 | activity_lr=activity_lr,
284 | param_lr=PARAM_LR,
285 | activity_optim_id=activity_optim_id,
286 | seed=seed
287 | )
288 | train_mlp(
289 | seed=seed,
290 | dataset=dataset,
291 | width=WIDTH,
292 | n_hidden=n_hidden,
293 | act_fn=ACT_FN,
294 | max_t1=max_t1,
295 | activity_lr=activity_lr,
296 | param_lr=PARAM_LR,
297 | batch_size=BATCH_SIZE,
298 | activity_optim_id=activity_optim_id,
299 | max_epochs=MAX_EPOCHS,
300 | test_every=TEST_EVERY,
301 | save_dir=save_dir
302 | )
303 |
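For a single configuration, the nested sweep above reduces to one direct call; a minimal sketch with hyperparameters taken from the constants above (the `save_dir` is illustrative):

```python
train_mlp(
    seed=0,
    dataset="MNIST",
    width=300,
    n_hidden=3,
    act_fn="tanh",
    max_t1=20,
    activity_lr=5e-1,
    param_lr=1e-3,
    batch_size=64,
    activity_optim_id="Euler",
    max_epochs=1,
    test_every=100,
    save_dir="mlp_results/single_run"
)
```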
--------------------------------------------------------------------------------