├── .coveragerc ├── .github └── workflows │ ├── build_book.yml │ ├── ci.yml │ ├── publish.yml │ └── style.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENCE.txt ├── Makefile ├── README.md ├── docs ├── _config.yml ├── _toc.yml ├── advanced_usage.md ├── architectures.rst ├── basic_usage.md ├── build_your_own_model.md ├── coders.rst ├── intro.md ├── logo.png └── references.bib ├── experiment ├── __init__.py ├── data │ ├── __init__.py │ ├── antarctica.py │ ├── eeg.py │ ├── gp.py │ ├── predprey.py │ ├── temperature.py │ └── util.py ├── plot.py └── util.py ├── neuralprocesses ├── __init__.py ├── aggregate.py ├── architectures │ ├── __init__.py │ ├── agnp.py │ ├── climate.py │ ├── convgnp.py │ ├── fullconvgnp.py │ ├── gnp.py │ └── util.py ├── augment.py ├── chain.py ├── coders │ ├── __init__.py │ ├── aggregate.py │ ├── attention.py │ ├── augment.py │ ├── copy.py │ ├── deepset.py │ ├── densecov.py │ ├── functional.py │ ├── fuse.py │ ├── inputs.py │ ├── mapdiag.py │ ├── nn.py │ ├── setconv │ │ ├── __init__.py │ │ ├── density.py │ │ ├── identity.py │ │ └── setconv.py │ └── shaping.py ├── coding.py ├── data │ ├── __init__.py │ ├── antarctica.py │ ├── batch.py │ ├── bimodal.py │ ├── data.py │ ├── eeg.py │ ├── gp.py │ ├── mixgp.py │ ├── mixture.py │ ├── predefined.py │ ├── predprey.py │ ├── sawtooth.py │ ├── temperature.py │ └── util.py ├── datadims.py ├── disc.py ├── dist │ ├── __init__.py │ ├── beta.py │ ├── dirac.py │ ├── dist.py │ ├── gamma.py │ ├── geom.py │ ├── normal.py │ ├── spikeslab.py │ ├── transformed.py │ └── uniform.py ├── likelihood.py ├── mask.py ├── materialise.py ├── model │ ├── __init__.py │ ├── ar.py │ ├── elbo.py │ ├── loglik.py │ ├── model.py │ ├── predict.py │ └── util.py ├── numdata.py ├── parallel.py ├── tensorflow │ ├── __init__.py │ └── nn.py ├── torch │ ├── __init__.py │ └── nn.py └── util.py ├── pyproject.toml ├── schedule.py ├── scripts ├── predprey.py ├── predprey_visualise.py ├── sawtooth_sample_ar.py ├── synthetic_extra.py ├── 
synthetic_parse.py ├── temperature_mae.py ├── temperature_summarise_folds.py └── temperature_visualise.py ├── tables.ipynb ├── tests ├── __init__.py ├── coders │ ├── __init__.py │ └── test_shaping.py ├── dists │ ├── __init__.py │ └── test_normal.py ├── gnp │ ├── __init__.py │ ├── autoencoding.py │ ├── gnp.py │ └── util.py ├── test_architectures.py ├── test_augment.py ├── test_chain.py ├── test_data.py ├── test_discretisation.py ├── test_distribution.py ├── test_mask.py ├── test_model.py ├── test_unet.py ├── test_util.py └── util.py ├── todo.tasks └── train.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = neuralprocesses/_version.py 3 | 4 | [report] 5 | exclude_lines = 6 | pragma: no cover 7 | pragma: specific no cover.*${PRAGMA_VERSION} 8 | -------------------------------------------------------------------------------- /.github/workflows/build_book.yml: -------------------------------------------------------------------------------- 1 | name: Build Jupyter Book 2 | 3 | on: 4 | # Trigger the workflow on push to main branch. 5 | push: 6 | branches: 7 | - main 8 | 9 | # This job installs dependencies, build the book, and pushes it to `gh-pages`. 10 | jobs: 11 | build-and-deploy-book: 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: [3.8] 17 | steps: 18 | - uses: actions/checkout@v2 19 | 20 | # Install dependencies. 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install -e '.[dev]' 28 | 29 | # Build the book. 30 | - name: Build 31 | run: | 32 | jupyter-book build docs 33 | 34 | # Deploy the book's HTML to the branch `gh-pages`. 
35 | - name: Deploy to GitHub Pages 36 | uses: peaceiris/actions-gh-pages@v3.6.1 37 | with: 38 | github_token: ${{ secrets.GITHUB_TOKEN }} 39 | publish_dir: docs/_build/html 40 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: [3.8, 3.9, "3.10", "3.11"] 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install --upgrade --no-cache-dir -e '.[dev]' 25 | 26 | - name: Test 27 | env: 28 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 29 | COVERALLS_FLAG_NAME: ${{ matrix.python-version }} 30 | COVERALLS_PARALLEL: true 31 | run: | 32 | PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \ 33 | pytest -v --cov=neuralprocesses --cov-report term-missing 34 | coveralls --service=github 35 | 36 | coveralls: 37 | name: Finish coverage 38 | needs: test 39 | runs-on: ubuntu-latest 40 | container: python:3-slim 41 | steps: 42 | - name: Finished 43 | run: | 44 | pip3 install --upgrade coveralls 45 | coveralls --finish 46 | env: 47 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 48 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python package using Twine when a release is 2 | # created. 
For more information see the following link: 3 | # https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 4 | 5 | name: Publish to PyPI 6 | 7 | on: 8 | release: 9 | types: [published] 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | # Make sure tags are fetched so we can get a version. 19 | - run: | 20 | git fetch --prune --unshallow --tags 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.x' 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install --upgrade build twine 30 | 31 | - name: Build and publish 32 | env: 33 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | 36 | run: | 37 | python -m build 38 | twine upload dist/* 39 | -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | name: Code style 2 | on: 3 | - push 4 | - pull_request 5 | 6 | jobs: 7 | check: 8 | name: Code style 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.9 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: 3.9 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | python -m pip install pre-commit 20 | pre-commit install 21 | - name: Check code style 22 | run: pre-commit run --all-files 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Autogenerated files 2 | neuralprocesses/_version.py 3 | 4 | # Byte-compiled file 5 | *.pyc 6 | 7 | # Virtual environments 8 | venv 9 | /dist 10 | pip-wheel-metadata 11 | 12 | # 
Packaging 13 | *.egg-info 14 | 15 | # Documentation and coverage 16 | docs/_build 17 | docs/_static 18 | docs/source 19 | docs/readme.rst 20 | cover 21 | 22 | # Other 23 | .DS_Store 24 | *.swp 25 | .vscode/* 26 | venv-np 27 | _experiments 28 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.9 3 | repos: 4 | - repo: https://github.com/psf/black 5 | rev: 23.7.0 6 | hooks: 7 | - id: black 8 | exclude: ^(scripts/|experiment/) 9 | - repo: https://github.com/pycqa/isort 10 | rev: 5.12.0 11 | hooks: 12 | - id: isort 13 | exclude: ^(scripts/|experiment/) 14 | args: ["--profile", "black"] 15 | # - repo: https://github.com/pycqa/flake8 16 | # rev: 5.0.4 17 | # hooks: 18 | # - id: flake8 19 | # args: ["--max-line-length=88", "--extend-ignore=E203,F811"] 20 | # additional_dependencies: 21 | # - flake8-bugbear>=22.12 22 | # - flake8-noqa>=1.3 23 | -------------------------------------------------------------------------------- /LICENCE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Wessel Bruinsma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install test 2 | 3 | PACKAGE := neuralprocesses 4 | 5 | install: 6 | pip install -e '.[dev]' 7 | 8 | test: 9 | pre-commit run --all-files 10 | PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \ 11 | pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Neural Processes](http://github.com/wesselb/neuralprocesses) 2 | 3 | [![CI](https://github.com/wesselb/neuralprocesses/workflows/CI/badge.svg)](https://github.com/wesselb/neuralprocesses/actions?query=workflow%3ACI) 4 | [![Coverage Status](https://coveralls.io/repos/github/wesselb/neuralprocesses/badge.svg?branch=main)](https://coveralls.io/github/wesselb/neuralprocesses?branch=master) 5 | [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://wesselb.github.io/neuralprocesses) 6 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 7 | 8 | A framework for composing Neural Processes in Python. 
9 | 10 | ## Installation 11 | 12 | ``` 13 | pip install neuralprocesses tensorflow tensorflow-probability # For use with TensorFlow 14 | pip install neuralprocesses torch # For use with PyTorch 15 | ``` 16 | 17 | If something is not working or unclear, please feel free to open an issue. 18 | 19 | ## Documentation 20 | 21 | See [here](https://wesselb.github.io/neuralprocesses). 22 | 23 | ## TL;DR! Just Get me Started! 24 | 25 | Here you go: 26 | 27 | ```python 28 | import torch 29 | 30 | import neuralprocesses.torch as nps 31 | 32 | # Construct a ConvCNP. 33 | convcnp = nps.construct_convgnp(dim_x=1, dim_y=2, likelihood="het") 34 | 35 | # Construct optimiser. 36 | opt = torch.optim.Adam(convcnp.parameters(), 1e-3) 37 | 38 | # Training: optimise the model for 32 batches. 39 | for _ in range(32): 40 | # Sample a batch of new context and target sets. Replace this with your data. The 41 | # shapes are `(batch_size, dimensionality, num_data)`. 42 | xc = torch.randn(16, 1, 10) # Context inputs 43 | yc = torch.randn(16, 2, 10) # Context outputs 44 | xt = torch.randn(16, 1, 15) # Target inputs 45 | yt = torch.randn(16, 2, 15) # Target output 46 | 47 | # Compute the loss and update the model parameters. 48 | loss = -torch.mean(nps.loglik(convcnp, xc, yc, xt, yt, normalise=True)) 49 | opt.zero_grad(set_to_none=True) 50 | loss.backward() 51 | opt.step() 52 | 53 | # Testing: make some predictions. 54 | mean, var, noiseless_samples, noisy_samples = nps.predict( 55 | convcnp, 56 | torch.randn(16, 1, 10), # Context inputs 57 | torch.randn(16, 2, 10), # Context outputs 58 | torch.randn(16, 1, 15), # Target inputs 59 | ) 60 | ``` 61 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | title: NeuralProcesses 3 | author: Wessel Bruinsma 4 | copyright: "2022" 5 | logo: logo.png 6 | 7 | # Force re-execution of notebooks on each build. 
8 | # See https://jupyterbook.org/content/execute.html 9 | execute: 10 | execute_notebooks: force 11 | 12 | # Define the name of the latex output file for PDF builds. 13 | latex: 14 | latex_documents: 15 | targetname: book.tex 16 | 17 | # Load AutoDoc extension. 18 | sphinx: 19 | extra_extensions: 20 | - 'sphinx.ext.autodoc' 21 | - 'sphinx.ext.napoleon' 22 | - 'sphinx.ext.viewcode' 23 | 24 | # Add a BiBTeX file so that we can create citations. 25 | bibtex_bibfiles: 26 | - references.bib 27 | 28 | # Information about where the book exists on the web. 29 | repository: 30 | url: https://github.com/wesselb/NeuralProcesses 31 | path_to_book: docs 32 | branch: main 33 | 34 | # Add GitHub buttons to your book. 35 | html: 36 | use_issues_button: true 37 | use_repository_button: true 38 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | # Learn more at https://jupyterbook.org/customize/toc.html 3 | 4 | format: jb-book 5 | root: intro 6 | chapters: 7 | - file: basic_usage 8 | - file: architectures 9 | - file: advanced_usage 10 | - file: build_your_own_model 11 | - file: coders 12 | -------------------------------------------------------------------------------- /docs/advanced_usage.md: -------------------------------------------------------------------------------- 1 | # Advanced Usage 2 | 3 | ## Masking 4 | 5 | In this section, we'll take the following ConvGNP as a running example: 6 | 7 | ```python 8 | import lab as B 9 | import torch 10 | 11 | import neuralprocesses.torch as nps 12 | 13 | cnp = nps.construct_convgnp( 14 | dim_x=2, 15 | dim_yc=(1, 1), # Two context sets, both with one channel 16 | dim_yt=1, 17 | ) 18 | 19 | # Construct two sample context sets with one on a grid. 
20 | xc = B.randn(torch.float32, 1, 2, 20) 21 | yc = B.randn(torch.float32, 1, 1, 20) 22 | xc_grid = (B.randn(torch.float32, 1, 1, 10), B.randn(torch.float32, 1, 1, 15)) 23 | yc_grid = B.randn(torch.float32, 1, 1, 10, 15) 24 | 25 | # Contruct sample target inputs 26 | xt = B.randn(torch.float32, 1, 2, 50) 27 | ``` 28 | 29 | For example, then predictions can be made via 30 | 31 | ```python 32 | >>> pred = cnp([(xc, yc), (xc_grid, yc_grid)], xt) 33 | ``` 34 | 35 | ### Masking Particular Inputs 36 | 37 | Suppose that due to a particular reason you didn't observe `yc_grid[5, 5]`. 38 | In the specification above, it is not possible to just omit that one element. 39 | The proposed solution is to use a _mask_. 40 | A mask `mask` is a tensor of the same size as the context outputs (`yc_grid` in this case) 41 | but with _only one channel_ consisting of ones and zeros. 42 | If `mask[i, 0, j, k] = 1`, then that means that `yc_grid[i, :, j, k]` is observed. 43 | On the other hand, if `mask[i, 0, j, k] = 0`, then that means that `yc_grid[i, :, j, k]` 44 | is _not_ observed. 45 | `yc_grid[i, :, j, k]` will still have values, _which must be not NaNs_, but those values 46 | will be ignored. 47 | To mask context outputs, use `nps.Masked(yc_grid, mask)`. 48 | 49 | Definition: 50 | 51 | ```python 52 | masked_yc = Masked(yc, mask) 53 | ``` 54 | 55 | Example: 56 | 57 | ```python 58 | >>> mask = B.ones(torch.float32, 1, 1, *B.shape(yc_grid, 2, 3)) 59 | 60 | >>> mask[:, :, 5, 5] = 0 61 | 62 | >>> pred = cnp([(xc, yc), (xc_grid, nps.Masked(yc_grid, mask))], xt) 63 | ``` 64 | 65 | Masking is also possible for non-gridded contexts. 66 | 67 | Example: 68 | 69 | ```python 70 | >>> mask = B.ones(torch.float32, 1, 1, B.shape(yc, 2)) 71 | 72 | >>> mask[:, :, 2:7] = 0 # Elements 3 to 7 are missing. 
73 | 74 | >>> pred = cnp([(xc, nps.Masked(yc, mask)), (xc_grid, yc_grid)], xt) 75 | ``` 76 | 77 | ### Using Masks to Batch Context Sets of Different Sizes 78 | 79 | Suppose that we also had another context set, of a different size: 80 | 81 | ```python 82 | # Construct another two sample context sets with one on a grid. 83 | xc2 = B.randn(torch.float32, 1, 2, 30) 84 | yc2 = B.randn(torch.float32, 1, 1, 30) 85 | xc2_grid = (B.randn(torch.float32, 1, 1, 5), B.randn(torch.float32, 1, 1, 20)) 86 | yc2_grid = B.randn(torch.float32, 1, 1, 5, 20) 87 | ``` 88 | 89 | Rather than running the model once for `[(xc, yc), (xc_grid, yc_grid)]` and once for 90 | `[(xc2, yc2), (xc2_grid, yc2_grid)]`, we would like to concatenate the 91 | two context sets along the batch dimension and run the model only once. 92 | This, however, doesn't work, because the twe context sets have different sizes. 93 | 94 | The proposed solution is to pad the context sets with zeros to align them, concatenate 95 | the padded contexts, and use a mask to reject the padded zeros. 96 | The function `nps.merge_contexts` can be used to do this automatically. 97 | 98 | Definition: 99 | 100 | ```python 101 | xc_merged, yc_merged = nps.merge_contexts((xc1, yc1), (xc2, yc2), ...) 102 | ``` 103 | 104 | Example: 105 | 106 | ```python 107 | xc_merged, yc_merged = nps.merge_contexts((xc, yc), (xc2, yc2)) 108 | xc_grid_merged, yc_grid_merged = nps.merge_contexts( 109 | (xc_grid, yc_grid), (xc2_grid, yc2_grid) 110 | ) 111 | ``` 112 | 113 | ```python 114 | >>> pred = cnp( 115 | [(xc_merged, yc_merged), (xc_grid_merged, yc_grid_merged)], 116 | B.concat(xt, xt, axis=0) 117 | ) 118 | ``` 119 | 120 | ## Equivalence of PyTorch and TensorFlow Architectures 121 | 122 | Jonny Taylor has a created a very helpful [Gist](https://gist.github.com/DrJonnyT/c946044591fb4ce922b0f5e7fd0f047a) 123 | which can be used to verify the equivalence of PyTorch and TensorFlow versions of architectures. 
124 | 125 | -------------------------------------------------------------------------------- /docs/architectures.rst: -------------------------------------------------------------------------------- 1 | List of Predefined Architectures 2 | ================================ 3 | 4 | Deep-Set-Based NPs 5 | ------------------ 6 | .. automodule:: neuralprocesses.architectures.gnp 7 | :members: 8 | 9 | Attentive NPs 10 | ------------- 11 | .. automodule:: neuralprocesses.architectures.agnp 12 | :members: 13 | 14 | Convolutional NPs 15 | ----------------- 16 | .. automodule:: neuralprocesses.architectures.convgnp 17 | :members: 18 | 19 | Fully Convolutional NPs 20 | ----------------------- 21 | .. automodule:: neuralprocesses.architectures.fullconvgnp 22 | :members: 23 | 24 | Specific Models for Climate Experiments 25 | --------------------------------------- 26 | .. automodule:: neuralprocesses.architectures.climate 27 | :members: 28 | -------------------------------------------------------------------------------- /docs/build_your_own_model.md: -------------------------------------------------------------------------------- 1 | # Build Your Own Model 2 | 3 | NeuralProcesses offers building blocks which can be put together in various ways to 4 | construct models suited to a particular application. 5 | 6 | ## Examples in PyTorch 7 | 8 | None yet. 
9 | 10 | ## Examples in TensorFlow 11 | 12 | ## ConvGNP 13 | 14 | ```python 15 | import lab as B 16 | import tensorflow as tf 17 | 18 | import neuralprocesses.tensorflow as nps 19 | 20 | dim_x = 1 21 | dim_y = 1 22 | 23 | # CNN architecture: 24 | unet = nps.UNet( 25 | dim=dim_x, 26 | in_channels=2 * dim_y, 27 | out_channels=(2 + 512) * dim_y, 28 | channels=(8, 16, 16, 32, 32, 64), 29 | ) 30 | 31 | # Discretisation of the functional embedding: 32 | disc = nps.Discretisation( 33 | points_per_unit=64, 34 | multiple=2**unet.num_halving_layers, 35 | margin=0.1, 36 | dim=dim_x, 37 | ) 38 | 39 | # Create the encoder and decoder and construct the model. 40 | encoder = nps.FunctionalCoder( 41 | disc, 42 | nps.Chain( 43 | nps.PrependDensityChannel(), 44 | nps.SetConv(scale=1 / disc.points_per_unit), 45 | nps.DivideByFirstChannel(), 46 | nps.DeterministicLikelihood(), 47 | ), 48 | ) 49 | decoder = nps.Chain( 50 | unet, 51 | nps.SetConv(scale=1 / disc.points_per_unit), 52 | nps.LowRankGaussianLikelihood(512), 53 | ) 54 | convgnp = nps.Model(encoder, decoder) 55 | 56 | # Run the model on some random data. 57 | dist = convgnp( 58 | B.randn(tf.float32, 16, 1, 10), 59 | B.randn(tf.float32, 16, 1, 10), 60 | B.randn(tf.float32, 16, 1, 15), 61 | ) 62 | ``` 63 | 64 | ## ConvGNP with Auxiliary Variables 65 | 66 | ```python 67 | import lab as B 68 | import tensorflow as tf 69 | 70 | import neuralprocesses.tensorflow as nps 71 | 72 | dim_x = 2 73 | # We will use two target sets with output dimensionalities `dim_y` and `dim_y2`. 74 | dim_y = 1 75 | dim_y2 = 10 76 | # We will also use auxiliary target information of dimensionality `dim_aux_t`. 77 | dim_aux_t = 7 78 | 79 | # CNN architecture: 80 | unet = nps.UNet( 81 | dim=dim_x, 82 | # The outputs are `dim_y`-dimensional, and we will use another context set 83 | # consisting of `dim_y2` variables. Both of these context sets will also have a 84 | # density channel. 
85 | in_channels=dim_y + 1 + dim_y2 + 1, 86 | out_channels=8, 87 | channels=(8, 16, 16, 32, 32, 64), 88 | ) 89 | 90 | # Discretisation of the functional embedding: 91 | disc = nps.Discretisation( 92 | points_per_unit=32, 93 | multiple=2**unet.num_halving_layers, 94 | margin=0.1, 95 | dim=dim_x, 96 | ) 97 | 98 | # Create the encoder and decoder and construct the model. 99 | encoder = nps.FunctionalCoder( 100 | disc, 101 | nps.Chain( 102 | nps.PrependDensityChannel(), 103 | # Use a separate set conv for every context set. Here we initialise the length 104 | # scales of these set convs both to `1 / disc.points_per_unit`. 105 | nps.Parallel( 106 | nps.SetConv(scale=1 / disc.points_per_unit), 107 | nps.SetConv(scale=1 / disc.points_per_unit), 108 | ), 109 | nps.DivideByFirstChannel(), 110 | # Concatenate the encodings of the context sets. 111 | nps.Concatenate(), 112 | nps.DeterministicLikelihood(), 113 | ), 114 | ) 115 | decoder = nps.Chain( 116 | unet, 117 | nps.SetConv(scale=1 / disc.points_per_unit), 118 | # `nps.Augment` will concatenate any auxiliary information to the current encoding 119 | # before proceedings. 120 | nps.Augment( 121 | nps.Chain( 122 | nps.MLP( 123 | # Input dimensionality is equal to the number of channels coming out of 124 | # `unet` plus the dimensionality of the auxiliary target information. 125 | in_dim=8 + dim_aux_t, 126 | layers=(128,) * 3, 127 | out_dim=(2 + 512) * dim_y, 128 | ), 129 | nps.LowRankGaussianLikelihood(512), 130 | ) 131 | ) 132 | ) 133 | convgnp = nps.Model(encoder, decoder) 134 | 135 | # Run the model on some random data. 136 | dist = convgnp( 137 | [ 138 | ( 139 | B.randn(tf.float32, 16, dim_x, 10), 140 | B.randn(tf.float32, 16, dim_y, 10), 141 | ), 142 | ( 143 | # The second context set is given on a grid. 
144 | (B.randn(tf.float32, 16, 1, 12), B.randn(tf.float32, 16, 1, 12)), 145 | B.randn(tf.float32, 16, dim_y2, 12, 12), 146 | ) 147 | ], 148 | B.randn(tf.float32, 16, dim_x, 15), 149 | aux_t=B.randn(tf.float32, 16, dim_aux_t, 15), 150 | ) 151 | ``` 152 | -------------------------------------------------------------------------------- /docs/coders.rst: -------------------------------------------------------------------------------- 1 | List of Coders 2 | ============== 3 | 4 | Deep Sets 5 | --------- 6 | .. automodule:: neuralprocesses.coders.deepset 7 | :members: 8 | .. automodule:: neuralprocesses.coders.inputs 9 | :members: 10 | 11 | Attention 12 | --------- 13 | .. automodule:: neuralprocesses.coders.attention 14 | :members: 15 | 16 | Convolutional Deep Sets 17 | ----------------------- 18 | .. automodule:: neuralprocesses.coders.functional 19 | :members: 20 | .. automodule:: neuralprocesses.coders.setconv.setconv 21 | :members: 22 | .. automodule:: neuralprocesses.coders.setconv.density 23 | :members: 24 | .. automodule:: neuralprocesses.coders.setconv.identity 25 | :members: 26 | 27 | Neural Networks 28 | --------------- 29 | .. automodule:: neuralprocesses.coders.nn 30 | :members: 31 | 32 | Aggregations 33 | ------------ 34 | .. automodule:: neuralprocesses.coders.aggregate 35 | :members: 36 | 37 | Auxiliary Variables 38 | ------------------- 39 | .. automodule:: neuralprocesses.coders.augment 40 | :members: 41 | 42 | Utility 43 | ------- 44 | .. automodule:: neuralprocesses.coders.copy 45 | :members: 46 | .. automodule:: neuralprocesses.coders.shaping 47 | :members: 48 | .. automodule:: neuralprocesses.coders.fuse 49 | :members: 50 | .. automodule:: neuralprocesses.coders.mapdiag 51 | :members: 52 | .. 
automodule:: neuralprocesses.coders.densecov 53 | :members: 54 | 55 | -------------------------------------------------------------------------------- /docs/intro.md: -------------------------------------------------------------------------------- 1 | # NeuralProcesses 2 | 3 | Welcome to the package! 4 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesselb/neuralprocesses/89faf25e5bfd481865d344c6c1aec256c1fd6961/docs/logo.png -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | --- 2 | --- -------------------------------------------------------------------------------- /experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .plot import * 3 | from .util import * 4 | -------------------------------------------------------------------------------- /experiment/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .gp import * 2 | from .predprey import * 3 | from .eeg import * 4 | from .antarctica import * 5 | from .temperature import * 6 | from .util import * 7 | -------------------------------------------------------------------------------- /experiment/data/antarctica.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import neuralprocesses.torch as nps 5 | from .util import register_data 6 | 7 | __all__ = [] 8 | 9 | 10 | def setup(args, config, *, num_tasks_train, num_tasks_cv, num_tasks_eval, device): 11 | 12 | root_dir = f"{os.getcwd()}/antarctica-data" 13 | 14 | config["default"]["rate"] = 1e-4 15 | config["default"]["epochs"] = 200 16 | config["dim_x"] = 2 17 
| config["dim_y"] = 2 18 | 19 | num_tasks_train = 10**4 20 | num_tasks_cv = 10**3 21 | num_tasks_eval = 10**3 22 | 23 | # Configure the convolutional models: 24 | config["points_per_unit"] = 256 25 | config["margin"] = 0.2 26 | config["conv_receptive_field"] = 1. 27 | config["unet_strides"] = (1,) + (2,) * 5 28 | 29 | config["unet_channels"] = (64, 64, 64, 64, 64, 64) 30 | config["encoder_scales"] = 2 / config["points_per_unit"] 31 | config["transform"] = None 32 | 33 | # Other settings specific to the EEG experiments: 34 | config["plot"] = {1: {"range": (0, 1), "axvline": []}} 35 | 36 | gen_train = nps.AntarcticaGenerator( 37 | root_dir=root_dir, 38 | dtype=torch.float32, 39 | seed=0, 40 | batch_size=args.batch_size, 41 | num_tasks=num_tasks_train, 42 | subset="train", 43 | device=device, 44 | ) 45 | 46 | gen_cv = lambda: ( 47 | nps.AntarcticaGenerator( 48 | root_dir=root_dir, 49 | dtype=torch.float32, 50 | seed=0, 51 | batch_size=args.batch_size, 52 | num_tasks=num_tasks_cv, 53 | device=device, 54 | ) 55 | ) 56 | 57 | def gens_eval(): 58 | return [] 59 | 60 | return gen_train, gen_cv, gens_eval 61 | 62 | register_data("antarctica", setup) 63 | -------------------------------------------------------------------------------- /experiment/data/eeg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import neuralprocesses.torch as nps 4 | from .util import register_data 5 | 6 | __all__ = [] 7 | 8 | 9 | def setup(args, config, *, num_tasks_train, num_tasks_cv, num_tasks_eval, device): 10 | config["default"]["rate"] = 5e-5 # 2e-4 11 | config["default"]["epochs"] = 1000 # 200 12 | config["dim_x"] = 1 13 | config["dim_y"] = 7 14 | 15 | # Architecture choices specific for the EEG experiments: 16 | config["transform"] = None 17 | config["epsilon"] = 1e-6 18 | config["enc_same"] = True 19 | 20 | # Configure the convolutional models: 21 | config["points_per_unit"] = 256 22 | config["margin"] = 0.1 23 | 
config["conv_receptive_field"] = 1.0 24 | config["unet_strides"] = (1,) + (2,) * 5 25 | 26 | # Increase the capacity of the ConvCNP, ConvGNP, and ConvNP to account for the many 27 | # outputs. The FullConvGNP is already large enough... 28 | if args.model in {"convcnp", "convgnp"}: 29 | config["unet_channels"] = (128,) * 6 30 | elif args.model == "convnp": 31 | config["unet_channels"] = (96,) * 6 32 | else: 33 | config["unet_channels"] = (64,) * 6 34 | config["encoder_scales"] = 0.77 / 256 35 | config["fullconvgnp_kernel_factor"] = 1 36 | 37 | # Other settings specific to the EEG experiments: 38 | config["plot"] = {1: {"range": (0, 1), "axvline": []}} 39 | 40 | gen_train = nps.EEGGenerator( 41 | dtype=torch.float32, 42 | seed=0, 43 | batch_size=args.batch_size, 44 | num_tasks=num_tasks_train, 45 | mode=config["eeg_mode"], # "random", 46 | subset="train", 47 | device=device, 48 | ) 49 | 50 | gen_cv = lambda: ( 51 | nps.EEGGenerator( 52 | dtype=torch.float32, 53 | seed=20, 54 | batch_size=args.batch_size, 55 | num_tasks=num_tasks_cv, 56 | mode=config["eeg_mode"], # "random", 57 | subset="cv", 58 | device=device, 59 | ) 60 | ) 61 | 62 | def gens_eval(): 63 | return [ 64 | ( 65 | "Interpolation", 66 | nps.EEGGenerator( 67 | dtype=torch.float32, 68 | batch_size=args.batch_size, 69 | num_tasks=num_tasks_eval, 70 | mode="interpolation", 71 | subset="eval", 72 | device=device, 73 | ), 74 | ), 75 | ( 76 | "Forecasting", 77 | nps.EEGGenerator( 78 | dtype=torch.float32, 79 | batch_size=args.batch_size, 80 | num_tasks=num_tasks_eval, 81 | mode="forecasting", 82 | subset="eval", 83 | device=device, 84 | ), 85 | ), 86 | ( 87 | "Reconstruction", 88 | nps.EEGGenerator( 89 | dtype=torch.float32, 90 | batch_size=args.batch_size, 91 | num_tasks=num_tasks_eval, 92 | mode="reconstruction", 93 | subset="eval", 94 | device=device, 95 | ), 96 | ), 97 | ] 98 | 99 | return gen_train, gen_cv, gens_eval 100 | 101 | 102 | register_data("eeg", setup) 103 | 
-------------------------------------------------------------------------------- /experiment/data/gp.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | 5 | import neuralprocesses.torch as nps 6 | from .util import register_data 7 | 8 | __all__ = [] 9 | 10 | 11 | def setup(name, args, config, *, num_tasks_train, num_tasks_cv, num_tasks_eval, device): 12 | config["dim_x"] = args.dim_x 13 | config["dim_y"] = args.dim_y 14 | 15 | # Architecture choices specific for the GP experiments: 16 | # TODO: We should use a stride of 1 in the first layer, but for compatibility 17 | # reasons with the models we already trained, we keep it like this. 18 | config["unet_strides"] = (2,) * 6 19 | config["conv_receptive_field"] = 4 20 | config["margin"] = 0.1 21 | if args.dim_x == 1: 22 | config["points_per_unit"] = 64 23 | elif args.dim_x == 2: 24 | # Reduce the PPU to reduce memory consumption. 25 | config["points_per_unit"] = 32 26 | # Since the PPU is reduced, we can also take off a layer of the UNet. 27 | config["unet_strides"] = config["unet_strides"][:-1] 28 | config["unet_channels"] = config["unet_channels"][:-1] 29 | else: 30 | raise RuntimeError(f"Invalid input dimensionality {args.dim_x}.") 31 | 32 | # Other settings specific to the GP experiments: 33 | config["plot"] = { 34 | 1: {"range": (-2, 4), "axvline": [2]}, 35 | 2: {"range": ((-2, 2), (-2, 2))}, 36 | } 37 | config["transform"] = None 38 | 39 | gen_train = nps.construct_predefined_gens( 40 | torch.float32, 41 | seed=10, 42 | batch_size=args.batch_size, 43 | num_tasks=num_tasks_train, 44 | dim_x=args.dim_x, 45 | dim_y=args.dim_y, 46 | pred_logpdf=False, 47 | pred_logpdf_diag=False, 48 | device=device, 49 | mean_diff=config["mean_diff"], 50 | )[name] 51 | 52 | gen_cv = lambda: nps.construct_predefined_gens( 53 | torch.float32, 54 | seed=20, # Use a different seed! 
55 | batch_size=args.batch_size, 56 | num_tasks=num_tasks_cv, 57 | dim_x=args.dim_x, 58 | dim_y=args.dim_y, 59 | pred_logpdf=True, 60 | pred_logpdf_diag=True, 61 | device=device, 62 | mean_diff=config["mean_diff"], 63 | )[name] 64 | 65 | def gens_eval(): 66 | return [ 67 | ( 68 | eval_name, 69 | nps.construct_predefined_gens( 70 | torch.float32, 71 | seed=30, # Use yet another seed! 72 | batch_size=args.batch_size, 73 | num_tasks=num_tasks_eval, 74 | dim_x=args.dim_x, 75 | dim_y=args.dim_y, 76 | pred_logpdf=True, 77 | pred_logpdf_diag=True, 78 | device=device, 79 | x_range_context=x_range_context, 80 | x_range_target=x_range_target, 81 | mean_diff=config["mean_diff"], 82 | )[args.data], 83 | ) 84 | for eval_name, x_range_context, x_range_target in [ 85 | ("interpolation in training range", (-2, 2), (-2, 2)), 86 | ("interpolation beyond training range", (2, 6), (2, 6)), 87 | ("extrapolation beyond training range", (-2, 2), (2, 6)), 88 | ] 89 | ] 90 | 91 | return gen_train, gen_cv, gens_eval 92 | 93 | names = [ 94 | "eq", 95 | "matern", 96 | "weakly-periodic", 97 | "mix-eq", 98 | "mix-matern", 99 | "mix-weakly-periodic", 100 | "sawtooth", 101 | "mixture", 102 | ] 103 | 104 | for name in names: 105 | register_data( 106 | name, 107 | partial(setup, name), 108 | requires_dim_x=True, 109 | requires_dim_y=True, 110 | ) 111 | -------------------------------------------------------------------------------- /experiment/data/predprey.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import neuralprocesses.torch as nps 4 | from .util import register_data 5 | 6 | __all__ = [] 7 | 8 | 9 | def setup(args, config, *, num_tasks_train, num_tasks_cv, num_tasks_eval, device): 10 | config["default"]["rate"] = 1e-4 11 | config["default"]["epochs"] = 200 12 | config["dim_x"] = 1 13 | config["dim_y"] = 2 14 | 15 | # Architecture choices specific for the predator-prey experiments: 16 | config["transform"] = "softplus" 17 | 18 | # 
Configure the convolutional models: 19 | config["points_per_unit"] = 4 20 | config["margin"] = 1 21 | config["conv_receptive_field"] = 100 22 | config["unet_strides"] = (1,) + (2,) * 6 23 | config["unet_channels"] = (64,) * 7 24 | 25 | # Other settings specific to the predator-prey experiments: 26 | config["plot"] = {1: {"range": (0, 100), "axvline": []}} 27 | 28 | gen_train = nps.PredPreyGenerator( 29 | torch.float32, 30 | seed=10, 31 | batch_size=args.batch_size, 32 | num_tasks=num_tasks_train, 33 | mode="random", 34 | device=device, 35 | ) 36 | gen_cv = lambda: nps.PredPreyGenerator( 37 | torch.float32, 38 | seed=20, 39 | batch_size=args.batch_size, 40 | num_tasks=num_tasks_cv, 41 | mode="random", 42 | device=device, 43 | ) 44 | 45 | def gens_eval(): 46 | return [ 47 | # For the real tasks, the batch size will be one. Keep the number of batches 48 | # the same. 49 | ( 50 | "Interpolation (Simulated)", 51 | nps.PredPreyGenerator( 52 | torch.float32, 53 | seed=30, 54 | batch_size=args.batch_size, 55 | num_tasks=num_tasks_eval, 56 | mode="interpolation", 57 | device=device, 58 | ), 59 | ), 60 | ( 61 | "Forecasting (Simulated)", 62 | nps.PredPreyGenerator( 63 | torch.float32, 64 | seed=30, 65 | batch_size=args.batch_size, 66 | num_tasks=num_tasks_eval, 67 | mode="forecasting", 68 | device=device, 69 | ), 70 | ), 71 | ( 72 | "Reconstruction (Simulated)", 73 | nps.PredPreyGenerator( 74 | torch.float32, 75 | seed=30, 76 | batch_size=args.batch_size, 77 | num_tasks=num_tasks_eval, 78 | mode="reconstruction", 79 | device=device, 80 | ), 81 | ), 82 | # For the real tasks, the batch size will be one. Keep the number of batches 83 | # the same. 
84 | ( 85 | "Interpolation (Real)", 86 | nps.PredPreyRealGenerator( 87 | torch.float32, 88 | seed=30, 89 | num_tasks=num_tasks_eval // args.batch_size, 90 | mode="interpolation", 91 | device=device, 92 | ), 93 | ), 94 | ( 95 | "Forecasting (Real)", 96 | nps.PredPreyRealGenerator( 97 | torch.float32, 98 | seed=30, 99 | num_tasks=num_tasks_eval // args.batch_size, 100 | mode="forecasting", 101 | device=device, 102 | ), 103 | ), 104 | ( 105 | "Reconstruction (Real)", 106 | nps.PredPreyRealGenerator( 107 | torch.float32, 108 | seed=30, 109 | num_tasks=num_tasks_eval // args.batch_size, 110 | mode="reconstruction", 111 | device=device, 112 | ), 113 | ), 114 | ] 115 | 116 | return gen_train, gen_cv, gens_eval 117 | 118 | 119 | register_data("predprey", setup) 120 | -------------------------------------------------------------------------------- /experiment/data/util.py: -------------------------------------------------------------------------------- 1 | __all__ = ["data", "register_data"] 2 | 3 | 4 | data = {} #: All data sets to train on 5 | 6 | 7 | def register_data(name, setup, requires_dim_x=False, requires_dim_y=False): 8 | """Register a data set. 9 | 10 | Args: 11 | name (str): Name. 12 | setup (function): Setup function. 13 | requires_dim_x (bool, optional): Requires the value of `--dim-x`. Defaults to 14 | `False`. 15 | requires_dim_y (bool, optional): Requires the value of `--dim-y`. Defaults to 16 | `False`. 
17 | """ 18 | data[name] = { 19 | "setup": setup, 20 | "requires_dim_x": requires_dim_x, 21 | "requires_dim_y": requires_dim_y, 22 | } 23 | -------------------------------------------------------------------------------- /neuralprocesses/__init__.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import plum 4 | 5 | _internal_dispatch = plum.Dispatcher() 6 | 7 | 8 | def _dispatch(f=None, **kw_args): 9 | if f is None: 10 | return functools.partial(_dispatch, **kw_args) 11 | 12 | if f.__name__ in {"code", "code_track", "recode"}: 13 | 14 | @functools.wraps(f) 15 | def f_wrapped(*args, **kw_args): 16 | if "root" not in kw_args or not kw_args["root"]: 17 | raise RuntimeError( 18 | "Did you not set `root = True` at the root coding call, " 19 | "or did you forget to propagate `**kw_args`?" 20 | ) 21 | return f(*args, **kw_args) 22 | 23 | return _internal_dispatch(f_wrapped, **kw_args) 24 | else: 25 | return _internal_dispatch(f, **kw_args) 26 | 27 | 28 | from .aggregate import * 29 | from .architectures import * 30 | from .augment import * 31 | from .chain import * 32 | from .coders import * 33 | from .coding import * 34 | from .data import * 35 | from .datadims import * 36 | from .disc import * 37 | from .dist import * 38 | from .likelihood import * 39 | from .mask import * 40 | from .materialise import * 41 | from .model import * 42 | from .numdata import * 43 | from .parallel import * 44 | -------------------------------------------------------------------------------- /neuralprocesses/aggregate.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import lab as B 4 | 5 | from . import _dispatch 6 | 7 | __all__ = ["Aggregate", "AggregateInput"] 8 | 9 | 10 | class Aggregate: 11 | """An ordered aggregate of things. 12 | 13 | Args: 14 | *elements (object): Elements in the aggregate. 
15 | 16 | Attributes: 17 | elements (tuple): Elements in the aggregate. 18 | """ 19 | 20 | def __init__(self, *elements): 21 | self.elements = elements 22 | 23 | def __iter__(self): 24 | return iter(self.elements) 25 | 26 | def __len__(self): 27 | return len(self.elements) 28 | 29 | def __getitem__(self, item): 30 | return self.elements[item] 31 | 32 | 33 | @B.dispatch 34 | def on_device(agg: Aggregate): 35 | return B.on_device(agg[0]) 36 | 37 | 38 | @B.dispatch 39 | def dtype(agg: Aggregate): 40 | return B.dtype(*agg) 41 | 42 | 43 | @B.dispatch 44 | def cast(dtype: B.DType, agg: Aggregate): 45 | return Aggregate(*(B.cast(dtype, x) for x in agg)) 46 | 47 | 48 | def _assert_equal_lengths(*elements): 49 | if any(len(elements[0]) != len(e) for e in elements[1:]): 50 | raise ValueError("Aggregates have unequal lengths.") 51 | 52 | 53 | def _map_f(name, num_args): 54 | method = getattr(B, name) 55 | 56 | if num_args == 1: 57 | 58 | @method.dispatch 59 | def f(a: Aggregate, **kw_args): 60 | return Aggregate(*(getattr(B, name)(ai, **kw_args) for ai in a)) 61 | 62 | elif num_args == 2: 63 | 64 | @method.dispatch 65 | def f(a: Aggregate, b: Aggregate, **kw_args): 66 | _assert_equal_lengths(a, b) 67 | return Aggregate( 68 | *(getattr(B, name)(ai, bi, **kw_args) for ai, bi in zip(a, b)) 69 | ) 70 | 71 | elif num_args == "*": 72 | 73 | @method.dispatch 74 | def f(*args: Aggregate, **kw_args): 75 | _assert_equal_lengths(*args) 76 | return Aggregate(*(getattr(B, name)(*xs, **kw_args) for xs in zip(*args))) 77 | 78 | else: 79 | raise ValueError(f"Invalid number of arguments {num_args}.") 80 | 81 | 82 | _map_f("expand_dims", 1) 83 | _map_f("exp", 1) 84 | _map_f("one", 1) 85 | _map_f("zero", 1) 86 | _map_f("mean", 1) 87 | _map_f("sum", 1) 88 | _map_f("logsumexp", 1) 89 | 90 | _map_f("add", 2) 91 | _map_f("subtract", 2) 92 | _map_f("multiply", 2) 93 | _map_f("divide", 2) 94 | 95 | _map_f("stack", "*") 96 | _map_f("concat", "*") 97 | _map_f("squeeze", "*") 98 | 99 | 100 | @B.dispatch 
101 | def max(a: Aggregate): 102 | return B.max(B.stack(*(B.max(ai) for ai in a))) 103 | 104 | 105 | @B.dispatch 106 | def min(a: Aggregate): 107 | return B.min(B.stack(*(B.min(ai) for ai in a))) 108 | 109 | 110 | class AggregateInput: 111 | """An ordered aggregate of inputs for specific outputs. This allow the user to 112 | specify different inputs for different outputs. 113 | 114 | Args: 115 | *elements (tuple[object, int]): A tuple of inputs and integers where the integer 116 | selects the particular output. 117 | 118 | Attributes: 119 | elements (tuple[object, int]): Elements in the aggregate input. 120 | """ 121 | 122 | @_dispatch 123 | def __init__(self, *elements: Tuple[object, int]): 124 | self.elements = elements 125 | 126 | def __iter__(self): 127 | return iter(self.elements) 128 | 129 | def __len__(self): 130 | return len(self.elements) 131 | 132 | def __getitem__(self, item): 133 | return self.elements[item] 134 | 135 | 136 | @B.dispatch 137 | def on_device(agg: AggregateInput): 138 | return B.on_device(agg[0][0]) 139 | 140 | 141 | @B.dispatch 142 | def dtype(agg: AggregateInput): 143 | return B.dtype(*(x for x, i in agg)) 144 | 145 | 146 | @B.dispatch 147 | def cast(dtype: B.DType, agg: AggregateInput): 148 | return Aggregate(*((B.cast(dtype, x), i) for x, i in agg)) 149 | -------------------------------------------------------------------------------- /neuralprocesses/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .agnp import * 2 | from .climate import * 3 | from .convgnp import * 4 | from .fullconvgnp import * 5 | from .gnp import * 6 | -------------------------------------------------------------------------------- /neuralprocesses/architectures/agnp.py: -------------------------------------------------------------------------------- 1 | import neuralprocesses as nps # This fixes inspection below. 
2 | 3 | from ..util import register_model 4 | 5 | __all__ = ["construct_agnp"] 6 | 7 | 8 | @register_model 9 | def construct_agnp(*args, nps=nps, num_heads=8, **kw_args): 10 | """An Attentive Gaussian Neural Process. 11 | 12 | Args: 13 | dim_x (int, optional): Dimensionality of the inputs. Defaults to 1. 14 | dim_y (int, optional): Dimensionality of the outputs. Defaults to 1. 15 | dim_yc (int or tuple[int], optional): Dimensionality of the outputs of the 16 | context set. You should set this if the dimensionality of the outputs 17 | of the context set is not equal to the dimensionality of the outputs 18 | of the target set. You should also set this if you want to use multiple 19 | context sets. In that case, set this equal to a tuple of integers 20 | indicating the respective output dimensionalities. 21 | dim_yt (int, optional): Dimensionality of the outputs of the target set. You 22 | should set this if the dimensionality of the outputs of the target set is 23 | not equal to the dimensionality of the outputs of the context set. 24 | dim_embedding (int, optional): Dimensionality of the embedding. Defaults to 128. 25 | num_heads (int, optional): Number of heads. Defaults to `8`. 26 | num_enc_layers (int, optional): Number of layers in the encoder. Defaults to 3. 27 | enc_same (bool, optional): Use the same encoder for all context sets. This 28 | only works if all context sets have the same dimensionality. Defaults to 29 | `False`. 30 | num_dec_layers (int, optional): Number of layers in the decoder. Defaults to 6. 31 | width (int, optional): Widths of all intermediate MLPs. Defaults to 512. 32 | nonlinearity (Callable or str, optional): Nonlinearity. Can also be specified 33 | as a string: `"ReLU"` or `"LeakyReLU"`. Defaults to ReLUs. 34 | likelihood (str, optional): Likelihood. Must be one of `"het"`, `"lowrank"`, 35 | `"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`. 
36 | num_basis_functions (int, optional): Number of basis functions for the 37 | low-rank likelihood. Defaults to 512. 38 | dim_lv (int, optional): Dimensionality of the latent variable. Defaults to 0. 39 | lv_likelihood (str, optional): Likelihood of the latent variable. Must be one of 40 | `"het"` or `"dense"`. Defaults to `"het"`. 41 | transform (str or tuple[float, float]): Bijection applied to the 42 | output of the model. This can help deal with positive of bounded data. 43 | Must be either `"positive"`, `"exp"`, `"softplus"`, or 44 | `"softplus_of_square"` for positive data or `(lower, upper)` for data in 45 | this open interval. 46 | dtype (dtype, optional): Data type. 47 | 48 | Returns: 49 | :class:`.model.Model`: AGNP model. 50 | """ 51 | return nps.construct_gnp( 52 | *args, 53 | nps=nps, 54 | attention=True, 55 | attention_num_heads=num_heads, 56 | **kw_args, 57 | ) 58 | -------------------------------------------------------------------------------- /neuralprocesses/architectures/util.py: -------------------------------------------------------------------------------- 1 | import neuralprocesses as nps # This fixes inspection below. 2 | 3 | __all__ = [ 4 | "construct_likelihood", 5 | "parse_transform", 6 | ] 7 | 8 | 9 | def construct_likelihood(nps=nps, *, spec, dim_y, num_basis_functions, dtype): 10 | """Construct the likelihood. 11 | 12 | Args: 13 | nps (module): Appropriate backend-specific module. 14 | spec (str, optional): Specification. Must be one of `"het"`, `"lowrank"`, 15 | `"dense"`, `"spikes-beta"`, or `"bernoulli-gamma"`. Defaults to `"lowrank"`. 16 | Must be given as a keyword argument. 17 | dim_y (int): Dimensionality of the outputs. Must be given as a keyword argument. 18 | num_basis_functions (int): Number of basis functions for the low-rank 19 | likelihood. Must be given as a keyword argument. 20 | dtype (dtype): Data type. Must be given as a keyword argument. 
21 | 22 | Returns: 23 | int: Number of channels that the likelihood requires. 24 | coder: Coder which can select a particular output channel. This coder may be 25 | `None`. 26 | coder: Coder. 27 | """ 28 | if spec == "het": 29 | num_channels = 2 * dim_y 30 | selector = nps.SelectFromChannels(dim_y, dim_y) 31 | lik = nps.HeterogeneousGaussianLikelihood() 32 | elif spec == "lowrank": 33 | num_channels = (2 + num_basis_functions) * dim_y 34 | selector = nps.SelectFromChannels(dim_y, (num_basis_functions, dim_y), dim_y) 35 | lik = nps.LowRankGaussianLikelihood(num_basis_functions) 36 | elif spec == "dense": 37 | # This is intended to only work for global variables. 38 | num_channels = 2 * dim_y + dim_y * dim_y 39 | selector = None 40 | lik = nps.Chain( 41 | nps.Splitter(2 * dim_y, dim_y * dim_y), 42 | nps.Parallel( 43 | lambda x: x, 44 | nps.Chain( 45 | nps.ToDenseCovariance(), 46 | nps.DenseCovariancePSDTransform(), 47 | ), 48 | ), 49 | nps.DenseGaussianLikelihood(), 50 | ) 51 | elif spec == "spikes-beta": 52 | num_channels = (2 + 3) * dim_y # Alpha, beta, and three log-probabilities 53 | selector = nps.SelectFromChannels(dim_y, dim_y, dim_y, dim_y, dim_y) 54 | lik = nps.SpikesBetaLikelihood() 55 | elif spec == "bernoulli-gamma": 56 | num_channels = (2 + 2) * dim_y # Shape, scale, and two log-probabilities 57 | selector = nps.SelectFromChannels(dim_y, dim_y, dim_y, dim_y) 58 | lik = nps.BernoulliGammaLikelihood() 59 | 60 | else: 61 | raise ValueError(f'Incorrect likelihood specification "{spec}".') 62 | return num_channels, selector, lik 63 | 64 | 65 | def parse_transform(nps=nps, *, transform): 66 | """Construct the likelihood. 67 | 68 | Args: 69 | nps (module): Appropriate backend-specific module. 70 | transform (str or tuple[float, float]): Bijection applied to the 71 | output of the model. This can help deal with positive of bounded data. 
72 | Must be either `"positive"`, `"exp"`, `"softplus"`, or 73 | `"softplus_of_square"` for positive data or `(lower, upper)` for data in 74 | this open interval. 75 | 76 | Returns: 77 | coder: Transform. 78 | """ 79 | if isinstance(transform, str) and transform.lower() in {"positive", "exp"}: 80 | transform = nps.Transform.exp() 81 | elif isinstance(transform, str) and transform.lower() == "softplus": 82 | transform = nps.Transform.softplus() 83 | elif isinstance(transform, str) and transform.lower() == "softplus_of_square": 84 | transform = nps.Chain( 85 | nps.Transform.signed_square(), 86 | nps.Transform.softplus(), 87 | ) 88 | elif isinstance(transform, tuple): 89 | lower, upper = transform 90 | transform = nps.Transform.bounded(lower, upper) 91 | elif transform is not None: 92 | raise ValueError(f'Cannot parse value "{transform}" for `transform`.') 93 | else: 94 | transform = lambda x: x 95 | return transform 96 | -------------------------------------------------------------------------------- /neuralprocesses/augment.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | 3 | __all__ = ["AugmentedInput"] 4 | 5 | 6 | class AugmentedInput: 7 | """An augmented input. 8 | 9 | Args: 10 | x (input): Input. 11 | augmentation (object): Augmentation. 12 | 13 | Attributes: 14 | x (input): Input. 15 | augmentation (object): Augmentation. 16 | """ 17 | 18 | def __init__(self, x, augmentation): 19 | self.x = x 20 | self.augmentation = augmentation 21 | 22 | 23 | @B.dispatch 24 | def on_device(x: AugmentedInput): 25 | return B.on_device(x.x) 26 | 27 | 28 | @B.dispatch 29 | def dtype(x: AugmentedInput): 30 | return B.dtype(x.x) 31 | -------------------------------------------------------------------------------- /neuralprocesses/chain.py: -------------------------------------------------------------------------------- 1 | from matrix.util import indent 2 | 3 | from . 
import _dispatch 4 | from .util import is_framework_module, register_module 5 | 6 | __all__ = ["Chain"] 7 | 8 | 9 | @register_module 10 | class Chain: 11 | """A chain of links. 12 | 13 | Args: 14 | *links (object): Links of the chain. 15 | 16 | Attributes: 17 | links (tuple): Links of the chain. 18 | """ 19 | 20 | def __init__(self, *links): 21 | # Filter `None`s. 22 | links = tuple(filter(None, links)) 23 | if any(is_framework_module(link) for link in links): 24 | self.links = self.nn.ModuleList(links) 25 | else: 26 | self.links = links 27 | 28 | def __call__(self, x): 29 | for link in self.links: 30 | x = link(x) 31 | return x 32 | 33 | def __getitem__(self, item): 34 | return self.links[item] 35 | 36 | def __len__(self): 37 | return len(self.links) 38 | 39 | def __iter__(self): 40 | return iter(self.links) 41 | 42 | def __str__(self): 43 | return repr(self) 44 | 45 | def __repr__(self): 46 | return ( 47 | "Chain(\n" 48 | + "".join([indent(repr(e).strip(), " " * 4) + ",\n" for e in self]) 49 | + ")" 50 | ) 51 | 52 | 53 | @_dispatch 54 | def code(chain: Chain, xz, z, x, **kw_args): 55 | for link in chain: 56 | xz, z = code(link, xz, z, x, **kw_args) 57 | return xz, z 58 | 59 | 60 | @_dispatch 61 | def code_track(chain: Chain, xz, z, x, h, **kw_args): 62 | for link in chain: 63 | xz, z, h = code_track(link, xz, z, x, h, **kw_args) 64 | return xz, z, h 65 | 66 | 67 | @_dispatch 68 | def recode(chain: Chain, xz, z, h, **kw_args): 69 | for link in chain: 70 | xz, z, h = recode(link, xz, z, h, **kw_args) 71 | return xz, z, h 72 | -------------------------------------------------------------------------------- /neuralprocesses/coders/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregate import * 2 | from .attention import * 3 | from .augment import * 4 | from .copy import * 5 | from .deepset import * 6 | from .densecov import * 7 | from .functional import * 8 | from .fuse import * 9 | from .inputs import * 10 | 
from .mapdiag import * 11 | from .nn import * 12 | from .setconv import * 13 | from .shaping import * 14 | -------------------------------------------------------------------------------- /neuralprocesses/coders/augment.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | 3 | from .. import _dispatch 4 | from ..augment import AugmentedInput 5 | from ..datadims import data_dims 6 | from ..materialise import _repeat_concat 7 | from ..util import register_composite_coder, register_module 8 | 9 | __all__ = ["Augment", "AssertNoAugmentation"] 10 | 11 | 12 | @register_composite_coder 13 | @register_module 14 | class Augment: 15 | """Concatenate the augmentation of the input to the encoding, and remove any 16 | augmentation of target inputs. 17 | 18 | Args: 19 | coder (coder): Coder to run after the augmentation. 20 | 21 | Attributes: 22 | coder (coder): Coder to run after the augmentation. 23 | """ 24 | 25 | def __init__(self, coder): 26 | self.coder = coder 27 | 28 | 29 | @_dispatch 30 | def code( 31 | coder: Augment, 32 | xz, 33 | z, 34 | x, 35 | **kw_args, 36 | ): 37 | xz, z = _augment(xz, z) 38 | x = _augment(x) 39 | return code(coder.coder, xz, z, x, **kw_args) 40 | 41 | 42 | @_dispatch 43 | def _augment(xz: AugmentedInput, z: B.Numeric): 44 | return xz.x, _repeat_concat(data_dims(xz), z, xz.augmentation) 45 | 46 | 47 | @_dispatch 48 | def _augment(xz: AugmentedInput): 49 | return xz.x 50 | 51 | 52 | @register_module 53 | class AssertNoAugmentation: 54 | """Assert no augmentation of the target inputs.""" 55 | 56 | 57 | @_dispatch 58 | def code(coder: AssertNoAugmentation, xz, z, x, **kw_args): 59 | return xz, z 60 | 61 | 62 | @_dispatch 63 | def code(coder: AssertNoAugmentation, xz, z, x: AugmentedInput, **kw_args): 64 | raise AssertionError("Did not expect augmentation of the target inputs.") 65 | -------------------------------------------------------------------------------- /neuralprocesses/coders/copy.py: 
--------------------------------------------------------------------------------
 1 | from .. import _dispatch
 2 | from ..parallel import Parallel
 3 | from ..util import register_module
 4 | 
 5 | __all__ = ["Copy"]
 6 | 
 7 | 
 8 | @register_module
 9 | class Copy:
10 |     """Copy the current encoding a number of times into a parallel.
11 | 
12 |     Args:
13 |         times (int): Number of copies.
14 | 
15 |     Attributes:
16 |         times (int): Number of copies.
17 |     """
18 | 
19 |     def __init__(self, times):
20 |         self.times = times
21 | 
22 | 
23 | @_dispatch
24 | def code(coder: Copy, xz, z, x, **kw_args):
25 |     # Duplicate both the inputs and the encoding `times` times.
26 |     return Parallel(*((xz,) * coder.times)), Parallel(*((z,) * coder.times))
-------------------------------------------------------------------------------- /neuralprocesses/coders/deepset.py: --------------------------------------------------------------------------------
 1 | import lab as B
 2 | import matrix  # noqa
 3 | 
 4 | from .. import _dispatch
 5 | from ..util import register_module
 6 | 
 7 | __all__ = ["DeepSet"]
 8 | 
 9 | 
10 | @register_module
11 | class DeepSet:
12 |     """Deep set.
13 | 
14 |     Args:
15 |         phi (object): Pre-aggregation function.
16 |         agg (object, optional): Aggregation function. Defaults to summing.
17 | 
18 |     Attributes:
19 |         phi (object): Pre-aggregation function.
20 |         agg (object): Aggregation function.
21 |     """
22 | 
23 |     def __init__(
24 |         self,
25 |         phi,
26 |         agg=lambda x: B.sum(x, axis=-1, squeeze=False),
27 |     ):
28 |         self.phi = phi
29 |         self.agg = agg
30 | 
31 | 
32 | @_dispatch
33 | def code(coder: DeepSet, xz: B.Numeric, z: B.Numeric, x, **kw_args):
34 |     z = B.concat(xz, z, axis=-2)
35 |     z = coder.phi(z)
36 |     z = coder.agg(z)  # This aggregates over the data dimension.
37 |     return None, z
-------------------------------------------------------------------------------- /neuralprocesses/coders/densecov.py: --------------------------------------------------------------------------------
 1 | import lab as B
 2 | import numpy as np
 3 | 
 4 | from ..
import _dispatch 5 | from ..datadims import data_dims 6 | from ..util import batch, register_module, split_dimension 7 | 8 | __all__ = [ 9 | "ToDenseCovariance", 10 | "FromDenseCovariance", 11 | "DenseCovariancePSDTransform", 12 | ] 13 | 14 | 15 | def _reorder_groups(z, groups, order): 16 | perm = list(range(B.rank(z))) 17 | # Assume that the groups are specified from the end, so walk through everything in 18 | # reversed order. 19 | group_axes = [] 20 | for g in reversed(groups): 21 | perm, axes = perm[:-g], perm[-g:] 22 | group_axes.append(axes) 23 | group_axes = list(reversed(group_axes)) 24 | perm = perm + sum([group_axes[i] for i in order], []) 25 | return B.transpose(z, perm=perm) 26 | 27 | 28 | @register_module 29 | class ToDenseCovariance: 30 | """Shape a regular encoding into a dense covariance encoding.""" 31 | 32 | 33 | @_dispatch 34 | def code(coder: ToDenseCovariance, xz: None, z: B.Numeric, x, **kw_args): 35 | c, n = B.shape(z, -2, -1) 36 | if n != 1: 37 | raise ValueError("Encoding is not global.") 38 | sqrt_c = int(B.sqrt(c)) 39 | # Only in this case, we also duplicate the inputs! 40 | return (None, None), B.reshape(z, *B.shape(z)[:-2], sqrt_c, 1, sqrt_c, 1) 41 | 42 | 43 | @_dispatch 44 | def code(coder: ToDenseCovariance, xz: tuple, z: B.Numeric, x, **kw_args): 45 | d = data_dims(xz) // 2 46 | c = B.shape(z, -2 * d - 1) 47 | sqrt_c = int(B.sqrt(c)) 48 | z = split_dimension(z, -2 * d - 1, (sqrt_c, sqrt_c)) 49 | 50 | # The ordering is now `(..., sqrt_c, sqrt_c, *n, *n)` where the length of 51 | # `(*n,)` is `d`. We want to swap the last `sqrt_c` with the first `*n`. 
52 | z = _reorder_groups(z, (1, 1, d, d), (0, 2, 1, 3)) 53 | 54 | return xz, z 55 | 56 | 57 | @register_module 58 | class FromDenseCovariance: 59 | """Shape a dense covariance encoding into a regular encoding.""" 60 | 61 | 62 | @_dispatch 63 | def code(coder: FromDenseCovariance, xz, z, x, **kw_args): 64 | d = data_dims(xz) // 2 65 | sqrt_c = B.shape(z, -d - 1) 66 | 67 | # The ordering is `(..., sqrt_c, *n, sqrt_c, *n)` where the length of `(*n,)` is 68 | # `d`. We want to swap the first `*n` with the last `sqrt_c`. 69 | z = _reorder_groups(z, (1, d, 1, d), (0, 2, 1, 3)) 70 | 71 | # Now merge the separate channel dimensions. 72 | z = B.reshape(z, *batch(z, 2 * d + 2), sqrt_c * sqrt_c, *B.shape(z)[-2 * d :]) 73 | 74 | return xz, z 75 | 76 | 77 | @register_module 78 | class DenseCovariancePSDTransform: 79 | """Multiply a dense covariance encoding by itself transposed to ensure that it is 80 | PSD.""" 81 | 82 | 83 | @_dispatch 84 | def code(coder: DenseCovariancePSDTransform, xz, z: B.Numeric, x, **kw_args): 85 | d = data_dims(xz) // 2 86 | 87 | # Record the original shape so we can transform back at the end. 88 | orig_shape = B.shape(z) 89 | 90 | # Compute the lengths of the sides of the covariance. 91 | len1 = np.prod(B.shape(z)[-d - 1 :]) 92 | len2 = np.prod(B.shape(z)[2 * (-d - 1) : -d - 1]) 93 | 94 | # Reshape into matrix, perform PD transform, and reshape back. 95 | z = B.reshape(z, *B.shape(z)[: 2 * (-d - 1)], len1, len2) 96 | z = B.matmul(z, z, tr_b=True) 97 | z = z / 100 # Stabilise the initialisation. 98 | z = B.reshape(z, *orig_shape) 99 | 100 | return xz, z 101 | -------------------------------------------------------------------------------- /neuralprocesses/coders/functional.py: -------------------------------------------------------------------------------- 1 | import matrix # noqa 2 | from plum import convert 3 | 4 | from .. 
import _dispatch 5 | from ..util import register_composite_coder, register_module 6 | 7 | __all__ = ["FunctionalCoder"] 8 | 9 | 10 | @register_composite_coder 11 | @register_module 12 | class FunctionalCoder: 13 | """A coder that codes to a discretisation for a functional representation. 14 | 15 | Args: 16 | disc (:class:`.discretisation.AbstractDiscretisation`): Discretisation. 17 | coder (coder): Coder. 18 | target (function, optional): Function which takes in the inputs of the current 19 | encoding and the desired inputs and which returns a tuple containing the 20 | inputs to span the discretisation over. 21 | 22 | Attributes: 23 | disc (:class:`.discretisation.AbstractDiscretisation`): Discretisation. 24 | coder (coder): Coder. 25 | target (function): Function which takes in the inputs of the current encoding 26 | and the desired inputs and which returns a tuple containing the inputs to 27 | span the discretisation over. 28 | """ 29 | 30 | def __init__(self, disc, coder, target=lambda xc, xt: (xc, xt)): 31 | self.disc = disc 32 | self.coder = coder 33 | self.target = target 34 | 35 | 36 | @_dispatch 37 | def code(coder: FunctionalCoder, xz, z, x, **kw_args): 38 | x = coder.disc(*coder.target(xz, x), **kw_args) 39 | return code(coder.coder, xz, z, x, **kw_args) 40 | 41 | 42 | @_dispatch 43 | def code_track(coder: FunctionalCoder, xz, z, x, h, **kw_args): 44 | x = coder.disc(*coder.target(xz, x), **kw_args) 45 | return code_track(coder.coder, xz, z, x, h + [x], **kw_args) 46 | 47 | 48 | @_dispatch 49 | def recode(coder: FunctionalCoder, xz, z, h, **kw_args): 50 | return recode(coder.coder, xz, z, h[1:], **kw_args) 51 | -------------------------------------------------------------------------------- /neuralprocesses/coders/fuse.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | 3 | from .. 
import _dispatch 4 | from ..datadims import data_dims 5 | from ..parallel import Parallel 6 | from ..util import register_module 7 | 8 | __all__ = ["Fuse"] 9 | 10 | 11 | @register_module 12 | class Fuse: 13 | """In a parallel of two things, interpolate the first to the inputs of the second, 14 | and concatenate the result to the second. 15 | 16 | Args: 17 | set_conv (coder): Set conv that should perform the interpolation. 18 | 19 | Attributes: 20 | set_conv (coder): Set conv that should perform the interpolation. 21 | """ 22 | 23 | def __init__(self, set_conv): 24 | self.set_conv = set_conv 25 | 26 | 27 | @_dispatch 28 | def code(coder: Fuse, xz: Parallel, z: Parallel, x, **kw_args): 29 | xz1, xz2 = xz 30 | z1, z2 = z 31 | _, z1 = code(coder.set_conv, xz1, z1, xz2, **kw_args) 32 | return xz2, B.concat(z1, z2, axis=-1 - data_dims(xz2)) 33 | -------------------------------------------------------------------------------- /neuralprocesses/coders/inputs.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import matrix # noqa 3 | 4 | from .. import _dispatch 5 | from ..aggregate import Aggregate, AggregateInput 6 | from ..util import register_module 7 | 8 | __all__ = ["InputsCoder"] 9 | 10 | 11 | @register_module 12 | class InputsCoder: 13 | """Encode with the target inputs.""" 14 | 15 | 16 | @_dispatch 17 | def code(coder: InputsCoder, xz, z, x: B.Numeric, **kw_args): 18 | return x, x 19 | 20 | 21 | @_dispatch 22 | def code(coder: InputsCoder, xz, z, x: AggregateInput, **kw_args): 23 | return x, Aggregate(*(xi for xi, i in x)) 24 | -------------------------------------------------------------------------------- /neuralprocesses/coders/mapdiag.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import matrix # noqa 3 | 4 | from .. 
@register_composite_coder
@register_module
class MapDiagonal:
    """Map to the diagonal of the squared space.

    Args:
        coder (coder): Coder to apply the mapped values to.

    Attributes:
        coder (coder): Coder to apply the mapped values to.
    """

    def __init__(self, coder):
        self.coder = coder


@_dispatch
def code(coder: MapDiagonal, xz, z, x, **kw_args):
    x_diag, d = _mapdiagonal_duplicate_target(x)
    # The encoding might already live on the diagonal, so the context inputs are
    # duplicated only when the dimensionalities do not line up.
    xz_diag = _mapdiagonal_possibly_duplicate_context(xz, d)
    return code(coder.coder, xz_diag, z, x_diag, **kw_args)


@_dispatch
def code_track(coder: MapDiagonal, xz, z, x, h, **kw_args):
    # As `code`, but record the duplicated target and its dimensionality so the
    # operation can be replayed.
    x_diag, d = _mapdiagonal_duplicate_target(x)
    xz_diag = _mapdiagonal_possibly_duplicate_context(xz, d)
    return code_track(coder.coder, xz_diag, z, x_diag, h + [(x_diag, d)], **kw_args)


@_dispatch
def recode(coder: MapDiagonal, xz, z, h, **kw_args):
    # Consume the `(target, dimensionality)` pair recorded by `code_track`.
    (_, d), remainder = h[0], h[1:]
    xz_diag = _mapdiagonal_possibly_duplicate_context(xz, d)
    return recode(coder.coder, xz_diag, z, remainder, **kw_args)


@_dispatch
def _mapdiagonal_duplicate_target(x: B.Numeric):
    # Duplicating a plain input yields a two-dimensional target.
    return (x, x), 2


@_dispatch
def _mapdiagonal_duplicate_target(x: AggregateInput):
    duplicated, dims = zip(*(_mapdiagonal_duplicate_target(xi) for xi, _ in x))
    if any(d != dims[0] for d in dims[1:]):
        raise NotImplementedError("All data dimensionalities must be equal.")
    return (
        AggregateInput(*((dup, i) for dup, (_, i) in zip(duplicated, x))),
        dims[0],
    )


@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: B.Numeric, d: B.Int):
    # If the inputs already have dimensionality `d`, leave them untouched.
    if B.shape(xz, -2) == d:
        return xz
    return B.concat(xz, xz, axis=-2)


@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: tuple, d: B.Int):
    return xz if len(xz) == d else xz * 2


@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: Parallel, d: B.Int):
    # Duplicate every element of the parallel independently.
    return Parallel(*(_mapdiagonal_possibly_duplicate_context(e, d) for e in xz))
@_dispatch
def code(coder: PrependDensityChannel, xz, z: Masked, x, **kw_args):
    # For masked data, the mask itself serves as the density channel. Set the
    # missing values to zero by multiplying with the mask: zeros in the data
    # channel do not affect the encoding.
    #
    # NOTE(review): unlike the unmasked path, this ignores `coder.multi` and
    # prepends the mask as-is — confirm that masks already carry the appropriate
    # number of channels in the multi-density case.
    d = data_dims(xz)
    return xz, B.concat(z.mask, z.y * z.mask, axis=-d - 1)
@_dispatch
def code(
    coder: DivideByChannels,
    xz,
    z: B.Numeric,
    x,
    epsilon: Optional[float] = None,
    **kw_args,
):
    # A caller-supplied `epsilon` overrides the coder's own.
    epsilon = epsilon or coder.epsilon
    d = data_dims(xz)
    # Resolve how many leading channels to divide by.
    if isinstance(coder.spec, B.Int):
        num_divide = coder.spec
    elif coder.spec == "half":
        num_divide = B.shape(z, -d - 1) // 2
    else:
        raise ValueError(f"Unknown specification `{coder.spec}`.")
    # Index the first `num_divide` channels and the remaining ones, keeping the
    # trailing `d` data dimensions intact.
    keep_data = (slice(None),) * d
    first = (Ellipsis, slice(None, num_divide)) + keep_data
    rest = (Ellipsis, slice(num_divide, None)) + keep_data
    return xz, B.concat(z[first], z[rest] / (z[first] + epsilon), axis=-d - 1)
@_dispatch
def code(coder: PrependIdentityChannel, xz, z: B.Numeric, x, **kw_args):
    d = data_dims(xz)
    b = batch(z, d + 1)
    with B.on_device(z):
        if d == 2:
            # Build an identity matrix matching the last data dimension.
            # NOTE(review): this assumes a square grid — confirm.
            eye = B.diag_construct(B.ones(B.dtype(z), B.shape(z, -1)))
        else:
            raise RuntimeError(
                f"Cannot construct identity channels for encodings of "
                f"dimensionality {d}."
            )
        # Insert batch and channel axes, then tile over the batch dimensions.
        eye = B.tile(
            B.expand_dims(eye, axis=0, times=len(b) + 1),
            *b,
            1,
            *((1,) * d),
        )
    return xz, B.concat(eye, z, axis=-d - 1)
def _dim_is_concrete(x, i):
    # A dimension is concrete if it converts to a Python `int`. Under
    # `tf.function` tracing, unknown dimensions are `None`, and `int(None)`
    # raises a `TypeError`.
    try:
        int(B.shape(x, i))
    except TypeError:
        return False
    return True


def _batch_targets(f):
    """Decorator which evaluates a coding function over the targets in batches of
    at most `batch_size` to bound memory usage."""

    @wraps(f)
    def f_wrapped(coder, xz, z, x, batch_size=1024, **kw_args):
        # If `x` is the internal discretisation and this function is compiled
        # with `tf.function`, then `B.shape(x, -1)` will be `None`. Hence, check
        # that the dimension is concrete before comparing against `batch_size`.
        if not (_dim_is_concrete(x, -1) and B.shape(x, -1) > batch_size):
            return f(coder, xz, z, x, **kw_args)
        chunks = []
        start = 0
        while start < B.shape(x, -1):
            chunks.append(
                code(
                    coder,
                    xz,
                    z,
                    x[..., start : start + batch_size],
                    batch_size=batch_size,
                    **kw_args,
                )[1]
            )
            start += batch_size
        return x, B.concat(*chunks, axis=-1)

    return f_wrapped
@_dispatch
@_batch_targets
def code(coder: SetConv, xz: tuple, z: B.Numeric, x: B.Numeric, **kw_args):
    ws = [compute_weights(coder, xzi, x[..., i : i + 1, :]) for i, xzi in enumerate(xz)]

    # Cache the einsum equation per number of grid dimensions so it is not
    # rebuilt on every call.
    n = len(xz)
    if n not in _setconv_cache_tup_num:
        grid_letters = letters[3 : 3 + n]
        lhs = "...b" + "".join(grid_letters)
        factors = "".join(f",...{let}c" for let in grid_letters)
        _setconv_cache_tup_num[n] = f"{lhs}{factors}->...bc"
    equation = _setconv_cache_tup_num[n]

    return x, B.einsum(equation, z, *ws)
@_dispatch
def code(coder: SetConv, xz, z, x: AugmentedInput, **kw_args):
    # Code against the unaugmented inputs, then re-attach the augmentation.
    xz_new, z_new = code(coder, xz, z, x.x, **kw_args)
    return AugmentedInput(xz_new, x.augmentation), z_new
@_dispatch
def _restructure(p: Parallel, current: tuple, new: tuple):
    # Record which element of `p` belongs to every label in `current`, then
    # rebuild a parallel following the layout of `new`.
    lookup = {}
    _restructure_assign(lookup, p, current)
    return _restructure_create(lookup, new)


@_dispatch
def _restructure_assign(lookup: dict, obj, label):
    # Leaf: associate the object with its label.
    lookup[label] = obj


@_dispatch
def _restructure_assign(lookup: dict, p: Parallel, labels: tuple):
    # Recurse elementwise; the structures must match exactly.
    if len(p) != len(labels):
        raise RuntimeError("Parallel does not match structure.")
    for element, label in zip(p, labels):
        _restructure_assign(lookup, element, label)


@_dispatch
def _restructure_create(lookup, label):
    # Leaf: fetch the object recorded for this label.
    return lookup[label]


@_dispatch
def _restructure_create(lookup, labels: tuple):
    # Recurse to rebuild nested parallels.
    return Parallel(*(_restructure_create(lookup, label) for label in labels))
@_dispatch
def code(p: AssertParallel, xz, z, x, **kw_args):
    # Anything that is not a parallel of elements fails the assertion.
    raise AssertionError(f"Expected a parallel of elements, but got `{xz}` and `{z}`.")


@_dispatch
def code(p: AssertParallel, xz: Parallel, z: Parallel, x, **kw_args):
    # Inputs and outputs must both be parallels of exactly `p.n` elements.
    if not len(xz) == len(z) == p.n:
        raise AssertionError(
            f"Expected a parallel of {p.n} elements, "
            f"but got {len(xz)} inputs and {len(z)} outputs."
        )
    return xz, z


@_dispatch
def code(p: SqueezeParallel, xz, z, x, **kw_args):
    # Not a parallel, so there is nothing to squeeze.
    return xz, z


@_dispatch
def code(p: SqueezeParallel, xz: Parallel, z: Parallel, x, **kw_args):
    # Unwrap singleton parallels; leave larger parallels untouched.
    if len(xz) == len(z) == 1:
        return xz[0], z[0]
    return xz, z
@_dispatch
def code(coder, xz, z, x, **kw_args):
    """Perform a coding operation.

    The default behaviour is to apply `coder` to `z` and return `(xz, coder(z))`.

    Args:
        coder (coder): Coder.
        xz (input): Current inputs corresponding to current encoding.
        z (tensor): Current encoding.
        x (input): Desired inputs.

    Returns:
        tuple[input, tensor]: New encoding.
    """
    # Guard against silently hitting this fallback when a specialised
    # implementation exists: check whether any registered method's first
    # argument type is a proper subclass of `object` that matches `coder`.
    # (A generator is used instead of a materialised list, and the error
    # message's grammar is fixed: "implementations are available".)
    if any(
        isinstance(coder, m.signature.types[0])
        and issubclass(m.signature.types[0], object)
        and not issubclass(object, m.signature.types[0])
        for m in code.methods
    ):
        raise RuntimeError(
            f"Dispatched to fallback implementation for `code`, but specialised "
            f"implementations are available. The arguments are "
            f"`({coder}, {xz}, {z}, {x})`."
        )
    return xz, coder(z)
@_dispatch
def code_track(coder, xz, z, x, h, **kw_args):
    # Composite coders must provide their own tracking, since they need to
    # record the targets of their constituents.
    if is_composite_coder(coder):
        raise RuntimeError(
            f"Dispatched to fallback implementation of `code_track` for "
            f"`{type(coder)}`, but the coder is composite."
        )
    xz_new, z_new = code(coder, xz, z, x, **kw_args)
    return xz_new, z_new, h + [x]


@_dispatch
def recode(coder, xz, z, h, **kw_args):
    """Perform a coding operation at an earlier recorded sequence of target inputs,
    called the history.

    Args:
        coder (coder): Coder.
        xz (input): Current inputs corresponding to current encoding.
        z (tensor): Current encoding.
        h (list): Target history.

    Returns:
        input: Input of encoding.
        tensor: Encoding.
        list: Remainder of the target history.
    """
    # Composite coders must provide their own replaying logic.
    if is_composite_coder(coder):
        raise RuntimeError(
            f"Dispatched to fallback implementation of `recode` for "
            f"`{type(coder)}`, but the coder is composite."
        )
    # Replay the coding operation at the first recorded target.
    xz_new, z_new = code(coder, xz, z, h[0], **kw_args)
    return xz_new, z_new, h[1:]
@_dispatch
def _choose(new: Parallel, old: Parallel):
    # Choose elementwise within parallels.
    return Parallel(*(_choose(n, o) for n, o in zip(new, old)))


@_dispatch
def _choose(new: Dirac, old: Dirac):
    # `Dirac`s are not recoded, so keep the old one.
    return old


@_dispatch
def _choose(new: AbstractDistribution, old: AbstractDistribution):
    # Every other distribution is recoded, so take the new one.
    return new
@_dispatch
def batch_index(x: AugmentedInput, index):
    # Index both the underlying inputs and the augmentation. The previous
    # implementation wrapped the augmentation together with the index in an
    # `AugmentedInput` instead of indexing the augmentation, which produced a
    # malformed augmentation.
    return AugmentedInput(
        batch_index(x.x, index),
        batch_index(x.augmentation, index),
    )
@_dispatch
def _batch_xt(x: B.Numeric, i: int):
    # A plain tensor of inputs is shared between all outputs.
    return x


@_dispatch
def _batch_xt(x: AggregateInput, i: int):
    # Locate the element whose output index equals `i` and return its inputs.
    output_indices = [i_out for _, i_out in x]
    return x[output_indices.index(i)][0]


@_dispatch
def _batch_xt(x: AugmentedInput, i: int):
    # Strip the augmentation and recurse on the underlying inputs.
    return _batch_xt(x.x, i)
class BiModalGenerator(SyntheticGenerator):
    """Bi-modal distribution generator.

    Further takes in arguments and keyword arguments from the constructor of
    :class:`.data.SyntheticGenerator`. Moreover, also has the attributes of
    :class:`.data.SyntheticGenerator`.
    """

    def __init__(self, *args, **kw_args):
        super().__init__(*args, **kw_args)

    def generate_batch(self):
        """Generate a batch of randomly parametrised sines corrupted by noise
        drawn from a two-component Gaussian mixture."""
        with B.on_device(self.device):
            set_batch, xcs, xc, nc, xts, xt, nt = new_batch(self, self.dim_y)
            x = B.concat(xc, xt, axis=1)

            # Sample a random phase, amplitude, and period for every task in
            # the batch.
            self.state, params = B.rand(
                self.state,
                self.float64,
                3,
                self.batch_size,
                1,  # Broadcast over `n`.
                1,  # There is only one input dimension.
            )
            phase = 2 * B.pi * params[0]
            amplitude = 1 + params[1]
            period = 1 + params[2]

            # The noiseless function is a randomly parametrised sine.
            signal = amplitude * B.sin(phase + (2 * B.pi / period) * x)

            # The noise mean is drawn from `means` with probabilities `probs`,
            # which makes the marginal noise distribution bi-modal.
            probs = B.cast(self.float64, np.array([0.5, 0.5]))
            means = B.cast(self.float64, np.array([-0.1, 0.1]))
            variance = 1
            self.state, mean = B.choice(self.state, means, self.batch_size, p=probs)
            self.state, eps = B.randn(
                self.state,
                self.float64,
                self.batch_size,
                # `nc` and `nt` are tensors rather than plain integers. Tell
                # dispatch that they can be interpreted as dimensions of a
                # shape.
                Dimension(nc + nt),
                1,
            )
            noise = B.sqrt(variance) * eps + mean[:, None, None]

            # The observations are the signal plus the bi-modal noise.
            y = signal + noise

            batch = {}
            set_batch(batch, y[:, :nc], y[:, nc:])
            return batch
30 | """ 31 | 32 | def __init__( 33 | self, 34 | *args, 35 | kernel=stheno.EQ().stretch(0.25), 36 | pred_logpdf=True, 37 | pred_logpdf_diag=True, 38 | **kw_args, 39 | ): 40 | self.kernel = kernel 41 | self.pred_logpdf = pred_logpdf 42 | self.pred_logpdf_diag = pred_logpdf_diag 43 | super().__init__(*args, **kw_args) 44 | 45 | def generate_batch(self): 46 | """Generate a batch. 47 | 48 | Returns: 49 | dict: A batch, which is a dictionary with keys "xc", "yc", "xt", and "yt". 50 | Also possibly contains the keys "pred_logpdf" and "pred_logpdf_diag". 51 | """ 52 | with B.on_device(self.device): 53 | set_batch, xcs, xc, nc, xts, xt, nt = new_batch(self, self.dim_y) 54 | 55 | # If `self.h` is specified, then we create a multi-output GP. Otherwise, we 56 | # use a simple regular GP. 57 | if self.h is None: 58 | with stheno.Measure() as prior: 59 | f = stheno.GP(self.kernel) 60 | # Construct FDDs for the context and target points. 61 | fc = f(xc, self.noise) 62 | ft = f(xt, self.noise) 63 | else: 64 | with stheno.Measure() as prior: 65 | # Construct latent processes and initialise output processes. 66 | us = [stheno.GP(self.kernel) for _ in range(self.dim_y_latent)] 67 | fs = [0 for _ in range(self.dim_y)] 68 | # Perform matrix multiplication. 69 | for i in range(self.dim_y): 70 | for j in range(self.dim_y_latent): 71 | fs[i] = fs[i] + self.h[i, j] * us[j] 72 | # Finally, construct the multi-output GP. 73 | f = stheno.cross(*fs) 74 | # Construct FDDs for the context and target points. 75 | fc = f( 76 | tuple(fi(xci) for fi, xci in zip(fs, xcs)), 77 | self.noise, 78 | ) 79 | ft = f( 80 | tuple(fi(xti) for fi, xti in zip(fs, xts)), 81 | self.noise, 82 | ) 83 | 84 | # Sample context and target set. 85 | self.state, yc, yt = prior.sample(self.state, fc, ft) 86 | 87 | # Make the new batch. 88 | batch = {} 89 | set_batch(batch, yc, yt) 90 | 91 | # Compute predictive logpdfs. 
92 | if self.pred_logpdf or self.pred_logpdf_diag: 93 | post = prior | (fc, yc) 94 | if self.pred_logpdf: 95 | batch["pred_logpdf"] = post(ft).logpdf(yt) 96 | if self.pred_logpdf_diag: 97 | batch["pred_logpdf_diag"] = post(ft).diagonalise().logpdf(yt) 98 | 99 | return batch 100 | -------------------------------------------------------------------------------- /neuralprocesses/data/mixgp.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import stheno 3 | 4 | from ..dist import UniformDiscrete 5 | from .gp import GPGenerator, new_batch 6 | 7 | __all__ = ["MixtureGPGenerator"] 8 | 9 | 10 | class MixtureGPGenerator(GPGenerator): 11 | def __init__( 12 | self, 13 | *args, 14 | mean_diff=0.0, 15 | **kw_args, 16 | ): 17 | super().__init__(*args, **kw_args) 18 | 19 | self.mean_diff = mean_diff 20 | 21 | def generate_batch(self): 22 | """Generate a batch. 23 | 24 | Returns: 25 | dict: A batch, which is a dictionary with keys "xc", "yc", "xt", and "yt". 26 | Also possibly contains the keys "pred_logpdf" and "pred_logpdf_diag". 27 | """ 28 | 29 | with B.on_device(self.device): 30 | set_batch, xcs, xc, nc, xts, xt, nt = new_batch(self, self.dim_y) 31 | 32 | # If `self.h` is specified, then we create a multi-output GP. Otherwise, we 33 | # use a simple regular GP. 34 | if self.h is None: 35 | with stheno.Measure() as prior: 36 | f = stheno.GP(self.kernel) 37 | # Construct FDDs for the context and target points. 38 | fc = f(xc, self.noise) 39 | ft = f(xt, self.noise) 40 | else: 41 | with stheno.Measure() as prior: 42 | # Construct latent processes and initialise output processes. 43 | us = [stheno.GP(self.kernel) for _ in range(self.dim_y_latent)] 44 | fs = [0 for _ in range(self.dim_y)] 45 | # Perform matrix multiplication. 46 | for i in range(self.dim_y): 47 | for j in range(self.dim_y_latent): 48 | fs[i] = fs[i] + self.h[i, j] * us[j] 49 | # Finally, construct the multi-output GP. 
50 | f = stheno.cross(*fs) 51 | # Construct FDDs for the context and target points. 52 | fc = f( 53 | tuple(fi(xci) for fi, xci in zip(fs, xcs)), 54 | self.noise, 55 | ) 56 | ft = f( 57 | tuple(fi(xti) for fi, xti in zip(fs, xts)), 58 | self.noise, 59 | ) 60 | 61 | # Sample context and target set. 62 | self.state, yc, yt = prior.sample(self.state, fc, ft) 63 | 64 | self.state, i = UniformDiscrete(0, 1).sample( 65 | self.state, 66 | self.int64, 67 | self.batch_size, 68 | ) 69 | mean = self.mean_diff / 2 - i * self.mean_diff 70 | 71 | yc = yc + mean[:, None, None] 72 | yt = yt + mean[:, None, None] 73 | 74 | # Make the new batch. 75 | batch = {} 76 | set_batch(batch, yc, yt) 77 | 78 | # Compute predictive logpdfs. 79 | if self.pred_logpdf or self.pred_logpdf_diag: 80 | post = prior | (fc, yc) 81 | if self.pred_logpdf: 82 | batch["pred_logpdf"] = post(ft).logpdf(yt) 83 | if self.pred_logpdf_diag: 84 | batch["pred_logpdf_diag"] = post(ft).diagonalise().logpdf(yt) 85 | 86 | return batch 87 | -------------------------------------------------------------------------------- /neuralprocesses/data/mixture.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import numpy as np 3 | 4 | from .. import _dispatch 5 | from .data import AbstractGenerator, DataGenerator 6 | 7 | __all__ = ["MixtureGenerator"] 8 | 9 | 10 | class MixtureGenerator(AbstractGenerator): 11 | """A mixture of data generators. 12 | 13 | Args: 14 | *gens (:class:`.data.DataGenerator`): Components of the mixture. 15 | seed (int, optional): Random seed. Defaults to `0`. 16 | 17 | Attributes: 18 | num_batches (int): Number of batches in an epoch. 19 | gens (tuple[:class:`.data.SyntheticGenerator`]): Components of the mixture. 20 | state (random state): Random state. 
21 | """ 22 | 23 | @_dispatch 24 | def __init__(self, *gens: DataGenerator, seed=0): 25 | if not all(gens[0].num_batches == g.num_batches for g in gens[1:]): 26 | raise ValueError( 27 | f"Attribute `num_batches` inconsistent between elements of the mixture." 28 | ) 29 | self.num_batches = gens[0].num_batches 30 | self.gens = gens 31 | self.state = B.create_random_state(np.float64, seed=seed) 32 | 33 | def generate_batch(self): 34 | self.state, i = B.randint(self.state, np.int64, upper=len(self.gens)) 35 | return self.gens[i].generate_batch() 36 | -------------------------------------------------------------------------------- /neuralprocesses/data/predefined.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | from stheno import EQ, Matern52 3 | 4 | from ..dist.uniform import UniformContinuous, UniformDiscrete 5 | from .gp import GPGenerator 6 | from .mixgp import MixtureGPGenerator 7 | from .mixture import MixtureGenerator 8 | from .sawtooth import SawtoothGenerator 9 | 10 | __all__ = ["construct_predefined_gens"] 11 | 12 | 13 | def construct_predefined_gens( 14 | dtype, 15 | seed=0, 16 | batch_size=16, 17 | num_tasks=2**14, 18 | dim_x=1, 19 | dim_y=1, 20 | x_range_context=(-2, 2), 21 | x_range_target=(-2, 2), 22 | mean_diff=0.0, 23 | pred_logpdf=True, 24 | pred_logpdf_diag=True, 25 | device="cpu", 26 | ): 27 | """Construct a number of predefined data generators. 28 | 29 | Args: 30 | dtype (dtype): Data type to generate. 31 | seed (int, optional): Seed. Defaults to `0`. 32 | batch_size (int, optional): Batch size. Defaults to 16. 33 | num_tasks (int, optional): Number of tasks to generate per epoch. Must be an 34 | integer multiple of `batch_size`. Defaults to 2^14. 35 | dim_x (int, optional): Dimensionality of the input space. Defaults to `1`. 36 | dim_y (int, optional): Dimensionality of the output space. Defaults to `1`. 
def construct_predefined_gens(
    dtype,
    seed=0,
    batch_size=16,
    num_tasks=2**14,
    dim_x=1,
    dim_y=1,
    x_range_context=(-2, 2),
    x_range_target=(-2, 2),
    mean_diff=0.0,
    pred_logpdf=True,
    pred_logpdf_diag=True,
    device="cpu",
):
    """Construct a number of predefined data generators.

    Args:
        dtype (dtype): Data type to generate.
        seed (int, optional): Seed. Defaults to `0`.
        batch_size (int, optional): Batch size. Defaults to 16.
        num_tasks (int, optional): Number of tasks to generate per epoch. Must be an
            integer multiple of `batch_size`. Defaults to 2^14.
        dim_x (int, optional): Dimensionality of the input space. Defaults to `1`.
        dim_y (int, optional): Dimensionality of the output space. Defaults to `1`.
        x_range_context (tuple[float, float], optional): Range of the inputs of the
            context points. Defaults to `(-2, 2)`.
        x_range_target (tuple[float, float], optional): Range of the inputs of the
            target points. Defaults to `(-2, 2)`.
        mean_diff (float, optional): Difference in means in the samples of
            :class:`neuralprocesses.data.mixgp.MixtureGPGenerator`. Defaults to `0.0`.
        pred_logpdf (bool, optional): Also compute the logpdf of the target set given
            the context set under the true GP. Defaults to `True`.
        pred_logpdf_diag (bool, optional): Also compute the logpdf of the target set
            given the context set under the true diagonalised GP. Defaults to `True`.
        device (str, optional): Device on which to generate data. Defaults to `cpu`.

    Returns:
        dict: A dictionary mapping names of data generators to the generators.
    """
    # Ensure that distances don't become bigger as we increase the input dimensionality.
    # We achieve this by blowing up all length scales by `sqrt(dim_x)`.
    factor = B.sqrt(dim_x)
    # Keyword arguments shared by all generators constructed below.
    config = {
        "num_tasks": num_tasks,
        "batch_size": batch_size,
        "dist_x_context": UniformContinuous(*((x_range_context,) * dim_x)),
        "dist_x_target": UniformContinuous(*((x_range_target,) * dim_x)),
        "dim_y": dim_y,
        "device": device,
    }
    kernels = {
        "eq": EQ().stretch(factor * 0.25),
        "matern": Matern52().stretch(factor * 0.25),
        "weakly-periodic": (
            EQ().stretch(factor * 0.5) * EQ().stretch(factor).periodic(factor * 0.25)
        ),
    }
    # One plain GP generator per kernel.
    gens = {
        name: GPGenerator(
            dtype,
            seed=seed,
            noise=0.05,
            kernel=kernel,
            num_context=UniformDiscrete(0, 30 * dim_x),
            num_target=UniformDiscrete(50 * dim_x, 50 * dim_x),
            pred_logpdf=pred_logpdf,
            pred_logpdf_diag=pred_logpdf_diag,
            **config,
        )
        for name, kernel in kernels.items()
    }
    # Previously, the maximum number of context points was `75 * dim_x`. However, if
    # `dim_x == 1`, then this is too high. We therefore change that case, and keep all
    # other cases the same.
    max_context = 30 if dim_x == 1 else 75 * dim_x
    gens["sawtooth"] = SawtoothGenerator(
        dtype,
        seed=seed,
        # The sawtooth is hard already as it is. Do not add noise.
        noise=0,
        dist_freq=UniformContinuous(2 / factor, 4 / factor),
        num_context=UniformDiscrete(0, max_context),
        num_target=UniformDiscrete(100 * dim_x, 100 * dim_x),
        **config,
    )
    # Be sure to use different seeds in the mixture components.
    gens["mixture"] = MixtureGenerator(
        *(
            GPGenerator(
                dtype,
                seed=seed + i,
                noise=0.05,
                kernel=kernel,
                num_context=UniformDiscrete(0, max_context),
                num_target=UniformDiscrete(100 * dim_x, 100 * dim_x),
                pred_logpdf=pred_logpdf,
                pred_logpdf_diag=pred_logpdf_diag,
                **config,
            )
            # Make sure that the order of `kernels.items()` is fixed.
            for i, (_, kernel) in enumerate(sorted(kernels.items(), key=lambda x: x[0]))
        ),
        SawtoothGenerator(
            dtype,
            seed=seed + len(kernels.items()),
            # The sawtooth is hard already as it is. Do not add noise.
            noise=0,
            dist_freq=UniformContinuous(2 / factor, 4 / factor),
            num_context=UniformDiscrete(0, max_context),
            num_target=UniformDiscrete(100 * dim_x, 100 * dim_x),
            **config,
        ),
        seed=seed,
    )

    # Also construct a mixture-of-two-GPs generator per kernel, with component means
    # separated by `mean_diff`. Again, be sure to use fresh seeds.
    for i, kernel in enumerate(kernels.keys()):
        gens[f"mix-{kernel}"] = MixtureGPGenerator(
            dtype,
            seed=seed + len(kernels.items()) + i + 1,
            noise=0.05,
            kernel=kernels[kernel],
            num_context=UniformDiscrete(0, 30 * dim_x),
            num_target=UniformDiscrete(50 * dim_x, 50 * dim_x),
            pred_logpdf=False,
            pred_logpdf_diag=False,
            mean_diff=mean_diff,
            **config,
        )

    return gens
24 | """ 25 | 26 | def __init__(self, *args, dist_freq=UniformContinuous(3, 5), **kw_args): 27 | super().__init__(*args, **kw_args) 28 | self.dist_freq = dist_freq 29 | 30 | def generate_batch(self): 31 | with B.on_device(self.device): 32 | set_batch, xcs, xc, nc, xts, xt, nt = new_batch(self, self.dim_y) 33 | x = B.concat(xc, xt, axis=1) 34 | 35 | # Sample a frequency. 36 | self.state, freq = self.dist_freq.sample( 37 | self.state, 38 | self.float64, 39 | self.batch_size, 40 | self.dim_y_latent, 41 | ) 42 | 43 | # Sample a direction. 44 | self.state, direction = B.randn( 45 | self.state, 46 | self.float64, 47 | self.batch_size, 48 | self.dim_y_latent, 49 | B.shape(x, 2), 50 | ) 51 | norm = B.sqrt(B.sum(direction * direction, axis=2, squeeze=False)) 52 | direction = direction / norm 53 | 54 | # Sample a uniformly distributed (conditional on frequency) offset. 55 | self.state, sample = B.rand( 56 | self.state, 57 | self.float64, 58 | self.batch_size, 59 | self.dim_y_latent, 60 | 1, 61 | ) 62 | offset = sample / freq 63 | 64 | # Construct the sawtooth and add noise. 65 | f = (freq * (B.matmul(direction, x, tr_b=True) - offset)) % 1 66 | if self.h is not None: 67 | f = B.matmul(self.h, f) 68 | y = f + B.sqrt(self.noise) * B.randn(f) 69 | 70 | # Finalise batch. 71 | batch = {} 72 | set_batch(batch, y[:, :, :nc], y[:, :, nc:], transpose=False) 73 | return batch 74 | -------------------------------------------------------------------------------- /neuralprocesses/data/util.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | __all__ = ["cache"] 4 | 5 | 6 | def cache(f): 7 | """A decorator that caches the output of a function. It assumes that all arguments 8 | and keyword argument""" 9 | _f_cache = {} 10 | 11 | @wraps(f) 12 | def f_wrapped(*args, **kw_args): 13 | cache_key = (args, frozenset(kw_args.items())) 14 | try: 15 | return _f_cache[cache_key] 16 | except KeyError: 17 | # Cache miss. 
Perform computation. 18 | _f_cache[cache_key] = f(*args, **kw_args) 19 | return _f_cache[cache_key] 20 | 21 | return f_wrapped 22 | -------------------------------------------------------------------------------- /neuralprocesses/datadims.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | 3 | from . import _dispatch 4 | from .aggregate import Aggregate, AggregateInput 5 | from .augment import AugmentedInput 6 | from .parallel import Parallel 7 | 8 | __all__ = ["data_dims"] 9 | 10 | 11 | @_dispatch 12 | def data_dims(x: B.Numeric): 13 | """Check how many data dimensions the encoding corresponding to an input has. 14 | 15 | Args: 16 | x (input): Input. 17 | 18 | Returns: 19 | int: Number of data dimensions. 20 | """ 21 | return 1 22 | 23 | 24 | @_dispatch 25 | def data_dims(x: None): 26 | return 1 27 | 28 | 29 | @_dispatch 30 | def data_dims(x: tuple): 31 | return len(x) 32 | 33 | 34 | @_dispatch 35 | def data_dims(x: Parallel): 36 | return _data_dims_merge(*(data_dims(xi) for xi in x)) 37 | 38 | 39 | @_dispatch 40 | def _data_dims_merge(d1, d2, d3, *ds): 41 | d = _data_dims_merge(d1, d2) 42 | for di in (d3,) + ds: 43 | d = _data_dims_merge(d, di) 44 | return d 45 | 46 | 47 | @_dispatch 48 | def _data_dims_merge(d): 49 | return d 50 | 51 | 52 | @_dispatch 53 | def _data_dims_merge(d1, d2): 54 | if d1 == d2: 55 | return d1 56 | else: 57 | raise RuntimeError(f"Cannot reconcile data dimensionalities {d1} and {d2}.") 58 | 59 | 60 | @_dispatch 61 | def data_dims(x: AggregateInput): 62 | return Aggregate(*(data_dims(xi) for xi, _ in x)) 63 | 64 | 65 | @_dispatch 66 | def _data_dims_merge(d1: Aggregate, d2: Aggregate): 67 | return Aggregate(*(_data_dims_merge(d1i, d2i) for d1i, d2i in zip(d1, d2))) 68 | 69 | 70 | @_dispatch 71 | def _data_dims_merge(d1: Aggregate, d2): 72 | return _data_dims_merge(*(_data_dims_merge(d1i, d2) for d1i in d1)) 73 | 74 | 75 | @_dispatch 76 | def _data_dims_merge(d1, d2: Aggregate): 77 | 
@register_module
class Discretisation:
    """Discretisation.

    Args:
        points_per_unit (float): Density of the discretisation.
        multiple (int, optional): Always produce a discretisation which is a multiple
            of this number. Defaults to `1`.
        margin (float, optional): Leave this much space around the most extremal points.
            Defaults to `0.1`.
        dim (int, optional): Dimensionality of the inputs.

    Attributes:
        resolution (float): Resolution of the discretisation. Equal to the inverse of
            `points_per_unit`.
        multiple (int): Always produce a discretisation which is a multiple of this
            number.
        margin (float): Leave this much space around the most extremal points.
        dim (int): Dimensionality of the inputs.
    """

    def __init__(self, points_per_unit, multiple=1, margin=0.1, dim=None):
        self.points_per_unit = points_per_unit
        self.resolution = 1 / self.points_per_unit
        self.multiple = multiple
        self.margin = margin
        self.dim = dim

    def discretise_1d(self, *args, margin):
        """Perform the discretisation for one-dimensional inputs.

        Args:
            *args (input): One-dimensional inputs.
            margin (float): Leave this much space around the most extremal points.

        Returns:
            tensor: Discretisation.
        """
        # Filter global and empty inputs.
        args = [x for x in args if x is not None and is_nonempty(x)]
        # The grid must cover all of the given inputs.
        grid_min = B.min(B.stack(*[B.min(x) for x in args]))
        grid_max = B.max(B.stack(*[B.max(x) for x in args]))

        # Add margin.
        grid_min = grid_min - margin
        grid_max = grid_max + margin

        # Account for snapping to the grid (below): snapping may move the start by up
        # to one resolution step, so widen by one step on either side.
        grid_min = grid_min - self.resolution
        grid_max = grid_max + self.resolution

        # Ensure that the multiple is respected. Add one point to account for the end.
        n_raw = (grid_max - grid_min) / self.resolution + 1
        n = B.ceil(n_raw / self.multiple) * self.multiple

        # Nicely shift the grid to account for the extra points.
        grid_start = grid_min - (n - n_raw) * self.resolution / 2

        # Snap to the nearest grid point. We accounted for this above.
        grid_start = B.round(grid_start / self.resolution) * self.resolution

        # Produce the grid, tiled over the batch dimensions of the first input.
        b = batch(args[0], 2)
        with B.on_device(args[0]):
            return B.tile(
                B.expand_dims(
                    B.linspace(
                        B.dtype(args[0]),
                        grid_start,
                        grid_start + (n - 1) * self.resolution,
                        # Tell LAB that it can be interpreted as an integer.
                        Dimension(B.cast(B.dtype_int(n), n)),
                    ),
                    axis=0,
                    times=len(b) + 1,
                ),
                *b,
                1,
                1,
            )

    def __call__(self, *args, margin=None, **kw_args):
        """Perform the discretisation for multi-dimensional inputs.

        Args:
            *args (input): Multi-dimensional inputs.
            margin (float, optional): Leave this much space around the most extremal
                points. Defaults to `self.margin`.

        Returns:
            input: Discretisation.
        """
        if margin is None:
            margin = self.margin
        # Discretise every input dimension separately.
        coords = _split_coordinates(Parallel(*args), dim=self.dim)
        discs = tuple(self.discretise_1d(*cs, margin=margin) for cs in coords)
        return discs[0] if len(discs) == 1 else discs
106 | """ 107 | if margin is None: 108 | margin = self.margin 109 | coords = _split_coordinates(Parallel(*args), dim=self.dim) 110 | discs = tuple(self.discretise_1d(*cs, margin=margin) for cs in coords) 111 | return discs[0] if len(discs) == 1 else discs 112 | 113 | 114 | @_dispatch 115 | def _split_coordinates( 116 | x: B.Numeric, dim: Optional[int] = None 117 | ) -> List[List[B.Numeric]]: 118 | # Cast with `int` so we can safely pass it to `range` below! 119 | dim = dim or int(B.shape(x, -2)) 120 | return [[x[..., i : i + 1, :]] for i in range(dim)] 121 | 122 | 123 | @_dispatch 124 | def _split_coordinates(x: Parallel, dim: Optional[int] = None) -> List[List[B.Numeric]]: 125 | all_coords = zip(*(_split_coordinates(xi, dim=dim) for xi in x)) 126 | return [sum(coords, []) for coords in all_coords] 127 | 128 | 129 | @_dispatch 130 | def _split_coordinates(x: tuple, dim: Optional[int] = None) -> List[List[B.Numeric]]: 131 | return [[xi] for xi in x] 132 | 133 | 134 | @_dispatch 135 | def _split_coordinates( 136 | x: AugmentedInput, dim: Optional[int] = None 137 | ) -> List[List[B.Numeric]]: 138 | return _split_coordinates(x.x, dim=dim) 139 | 140 | 141 | @_dispatch 142 | def _split_coordinates( 143 | x: AggregateInput, dim: Optional[int] = None 144 | ) -> List[List[B.Numeric]]: 145 | # Can treat it like a parallel of inputs. However, be sure to remove the indices. 
146 | return _split_coordinates(Parallel(*(xi for xi, i in x)), dim=dim) 147 | -------------------------------------------------------------------------------- /neuralprocesses/dist/__init__.py: -------------------------------------------------------------------------------- 1 | from .beta import * 2 | from .dirac import * 3 | from .dist import * 4 | from .gamma import * 5 | from .geom import * 6 | from .normal import * 7 | from .spikeslab import * 8 | from .transformed import * 9 | from .uniform import * 10 | -------------------------------------------------------------------------------- /neuralprocesses/dist/beta.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | from matrix.shape import broadcast 3 | from plum import parametric 4 | 5 | from .. import _dispatch 6 | from ..aggregate import Aggregate 7 | from ..mask import Masked 8 | from .dist import AbstractDistribution, shape_batch 9 | 10 | __all__ = ["Beta"] 11 | 12 | 13 | @parametric 14 | class Beta(AbstractDistribution): 15 | """Beta distribution. 16 | 17 | Args: 18 | alpha (tensor): Shape parameter `alpha`. 19 | beta (tensor): Shape parameter `beta`. 20 | d (int): Dimensionality of the data. 21 | 22 | Attributes: 23 | alpha (tensor): Shape parameter `alpha`. 24 | beta (tensor): Shape parameter `beta`. 25 | d (int): Dimensionality of the data. 
26 | """ 27 | 28 | def __init__(self, alpha, beta, d): 29 | self.alpha = alpha 30 | self.beta = beta 31 | self.d = d 32 | 33 | @property 34 | def mean(self): 35 | return B.divide(self.alpha, B.add(self.alpha, self.beta)) 36 | 37 | @property 38 | def var(self): 39 | sum = B.add(self.alpha, self.beta) 40 | with B.on_device(sum): 41 | one = B.one(sum) 42 | return B.divide( 43 | B.multiply(self.alpha, self.beta), 44 | B.multiply(B.multiply(sum, sum), B.add(sum, one)), 45 | ) 46 | 47 | @_dispatch 48 | def sample( 49 | self: "Beta[Aggregate, Aggregate, Aggregate]", 50 | state: B.RandomState, 51 | dtype: B.DType, 52 | *shape, 53 | ): 54 | samples = [] 55 | for ai, bi, di in zip(self.alpha, self.beta, self.d): 56 | state, sample = Beta(ai, bi, di).sample(state, dtype, *shape) 57 | samples.append(sample) 58 | return state, Aggregate(*samples) 59 | 60 | @_dispatch 61 | def sample( 62 | self: "Beta[B.Numeric, B.Numeric, B.Int]", 63 | state: B.RandomState, 64 | dtype: B.DType, 65 | *shape, 66 | ): 67 | return B.randbeta(state, dtype, *shape, alpha=self.alpha, beta=self.beta) 68 | 69 | @_dispatch 70 | def logpdf(self: "Beta[Aggregate, Aggregate, Aggregate]", x: Aggregate): 71 | return sum( 72 | [ 73 | Beta(ai, bi, di).logpdf(xi) 74 | for ai, bi, di, xi in zip(self.alpha, self.beta, self.d, x) 75 | ], 76 | 0, 77 | ) 78 | 79 | @_dispatch 80 | def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: Masked): 81 | x, mask = x.y, x.mask 82 | with B.on_device(self.alpha): 83 | safe = B.to_active_device(B.cast(B.dtype(self.alpha), 0.5)) 84 | # Make inputs safe. 85 | x = mask * x + (1 - mask) * safe 86 | # Run with safe inputs, and filter out the right logpdfs. 
87 | return self.logpdf(x, mask=mask) 88 | 89 | @_dispatch 90 | def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1): 91 | logz = B.logbeta(self.alpha, self.beta) 92 | logpdf = (self.alpha - 1) * B.log(x) + (self.beta - 1) * B.log(1 - x) - logz 93 | logpdf = logpdf * mask 94 | if self.d == 0: 95 | return logpdf 96 | else: 97 | return B.sum(logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :]) 98 | 99 | def __str__(self): 100 | return f"Beta({self.alpha}, {self.beta})" 101 | 102 | def __repr__(self): 103 | return f"Beta({self.alpha!r}, {self.beta!r})" 104 | 105 | 106 | @B.dtype.dispatch 107 | def dtype(dist: Beta): 108 | return B.dtype(dist.alpha, dist.beta) 109 | 110 | 111 | @shape_batch.dispatch 112 | def shape_batch(dist: "Beta[B.Numeric, B.Numeric, B.Int]"): 113 | return B.shape_broadcast(dist.alpha, dist.beta)[: -dist.d] 114 | 115 | 116 | @shape_batch.dispatch 117 | def shape_batch(dist: "Beta[Aggregate, Aggregate, Aggregate]"): 118 | return broadcast( 119 | *( 120 | shape_batch(Beta(ai, bi, di)) 121 | for ai, bi, di in zip(dist.alpha, dist.beta, dist.d) 122 | ) 123 | ) 124 | -------------------------------------------------------------------------------- /neuralprocesses/dist/dirac.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | from plum import parametric 3 | from wbml.util import indented_kv 4 | 5 | from .. import _dispatch 6 | from ..aggregate import Aggregate 7 | from ..util import batch 8 | from .dist import AbstractDistribution 9 | 10 | __all__ = ["Dirac"] 11 | 12 | 13 | @parametric 14 | class Dirac(AbstractDistribution): 15 | """A Dirac delta. 16 | 17 | Args: 18 | x (tensor): Position of the Dirac delta. 19 | d (int): Dimensionality of the data. 20 | 21 | Attributes: 22 | x (tensor): Position of the Dirac delta. 23 | d (int): Dimensionality of the data. 
24 | """ 25 | 26 | def __init__(self, x, d): 27 | self.x = x 28 | self.d = d 29 | 30 | def __repr__(self): 31 | return " self.lower: 47 | lam = B.log(self.factor) / (self.factor_at - self.lower) 48 | lam = B.cast(dtype_float, B.to_active_device(lam)) 49 | probs = B.exp(-lam * B.cast(dtype_float, realisations)) 50 | else: 51 | probs = B.to_active_device(B.ones(dtype_float, 1)) 52 | return B.choice(state, realisations, *shape, p=probs) 53 | 54 | def __str__(self): 55 | return ( 56 | f"TruncatedGeometric(" 57 | f"{self.lower}, {self.upper}, {self.factor}, {self.factor_at}" 58 | f")" 59 | ) 60 | 61 | def __repr__(self): 62 | return ( 63 | f"TruncatedGeometric(" 64 | f"{self.lower!r}, {self.uppers!r}, {self.factor!r}, {self.factor_at!r}" 65 | f")" 66 | ) 67 | -------------------------------------------------------------------------------- /neuralprocesses/dist/uniform.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import lab as B 4 | from lab.shape import Dimension 5 | 6 | from .. import _dispatch 7 | from .dist import AbstractDistribution 8 | 9 | __all__ = ["UniformContinuous", "UniformDiscrete"] 10 | 11 | 12 | class UniformContinuous(AbstractDistribution): 13 | """A uniform continuous distribution. 14 | 15 | Also takes in tuples of its arguments. 16 | 17 | Args: 18 | lower (float): Lower bound. 19 | upper (float): Upper bound. 20 | 21 | Attributes: 22 | lowers (vector): Lower bounds. 23 | uppers (vector): Upper bounds. 
24 | """ 25 | 26 | @_dispatch 27 | def __init__(self, lower: B.Number, upper: B.Number): 28 | self.__init__((lower, upper)) 29 | 30 | @_dispatch 31 | def __init__(self, *bounds: Tuple[B.Number, B.Number]): 32 | lowers, uppers = zip(*bounds) 33 | self.lowers = B.stack(*lowers) 34 | self.uppers = B.stack(*uppers) 35 | 36 | def __str__(self): 37 | return f"UniformContinuous({self.lower}, {self.upper})" 38 | 39 | def __repr__(self): 40 | return f"UniformContinuous({self.lower!r}, {self.uppers!r})" 41 | 42 | @_dispatch 43 | def sample(self, state: B.RandomState, dtype: B.DType, *shape): 44 | lowers = B.to_active_device(B.cast(dtype, self.lowers)) 45 | uppers = B.to_active_device(B.cast(dtype, self.uppers)) 46 | # Wrap everything in `Dimension`s to make dispatch work. 47 | shape = (Dimension(d) for d in shape) 48 | state, rand = B.rand(state, dtype, *shape, B.shape(lowers, 0)) 49 | return state, lowers + (uppers - lowers) * rand 50 | 51 | 52 | class UniformDiscrete(AbstractDistribution): 53 | """A uniform discrete distribution. 54 | 55 | Args: 56 | lower (int): Lower bound. 57 | upper (int): Upper bound. 58 | 59 | Attributes: 60 | lower (int): Lower bound. 61 | upper (int): Upper bound. 62 | """ 63 | 64 | @_dispatch 65 | def __init__(self, lower: B.Int, upper: B.Int): 66 | self.lower = lower 67 | self.upper = upper 68 | 69 | @_dispatch 70 | def sample(self, state: B.RandomState, dtype: B.DType, *shape): 71 | return B.randint(state, dtype, lower=self.lower, upper=self.upper + 1, *shape) 72 | 73 | def __str__(self): 74 | return f"UniformDiscrete({self.lower}, {self.upper})" 75 | 76 | def __repr__(self): 77 | return f"UniformDiscrete({self.lower!r}, {self.uppers!r})" 78 | -------------------------------------------------------------------------------- /neuralprocesses/mask.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import lab as B 4 | from lab.util import resolve_axis 5 | 6 | from . 
import _dispatch 7 | 8 | __all__ = ["Masked", "mask_context", "merge_contexts"] 9 | 10 | 11 | class Masked: 12 | """A masked output. 13 | 14 | Args: 15 | y (tensor): Output to mask. The masked values can have any non-NaN value, but 16 | they cannot be NaN! 17 | mask (tensor): A mask consisting of zeros and ones and just one channel. 18 | 19 | Attributes: 20 | y (tensor): Masked output. 21 | mask (tensor): A mask consisting of zeros and ones and just one channel. 22 | """ 23 | 24 | def __init__(self, y, mask): 25 | self.y = y 26 | self.mask = mask 27 | 28 | 29 | @B.to_active_device.dispatch 30 | def to_active_device(masked: Masked): 31 | return Masked(B.to_active_device(masked.y), B.to_active_device(masked.mask)) 32 | 33 | 34 | @_dispatch 35 | def _pad_zeros(x: B.Numeric, up_to: int, axis: int): 36 | axis = resolve_axis(x, axis) 37 | shape = list(B.shape(x)) 38 | shape[axis] = up_to - shape[axis] 39 | with B.on_device(x): 40 | return B.concat(x, B.zeros(B.dtype(x), *shape), axis=axis) 41 | 42 | 43 | def _ceil_to_closest_multiple(n, m): 44 | d, r = divmod(n, m) 45 | # If `n` is zero, then we must also round up. 46 | if n == 0 or r > 0: 47 | return (d + 1) * m 48 | else: 49 | return d * m 50 | 51 | 52 | @_dispatch 53 | def _determine_ns(xc: tuple, multiple: Union[int, tuple]): 54 | ns = [B.shape(xci, 2) for xci in xc] 55 | 56 | if not isinstance(multiple, tuple): 57 | multiple = (multiple,) * len(ns) 58 | 59 | # Ceil to the closest multiple of `multiple`. 60 | ns = [_ceil_to_closest_multiple(n, m) for n, m in zip(ns, multiple)] 61 | 62 | return ns 63 | 64 | 65 | @_dispatch 66 | def mask_context(xc: tuple, yc: B.Numeric, multiple=1): 67 | """Mask a context set. 68 | 69 | Args: 70 | xc (input): Context inputs. 71 | yc (tensor): Context outputs. 72 | multiple (int or tuple[int], optional): Pad with zeros until the number of 73 | context points is a multiple of this number. Defaults to 1. 
@_dispatch
def merge_contexts(*contexts: Tuple[tuple, B.Numeric], multiple=1):
    """Merge context sets.

    Args:
        *contexts (tuple[input, tensor]): Contexts to merge.
        multiple (int or tuple[int], optional): Pad with zeros until the number of
            context points is a multiple of this number. Defaults to 1.

    Returns:
        tuple[input, :class:`.Masked`]: Merged context set.
    """
    # Per input dimension, take the largest padded size over all context sets.
    all_ns = [_determine_ns(xc, multiple) for xc, _ in contexts]
    ns = tuple(max(sizes) for sizes in zip(*all_ns))

    # Pad every context set to exactly `ns` points.
    padded = [mask_context(xc, yc, multiple=ns) for xc, yc in contexts]
    xcs = [xc for xc, _ in padded]
    ys = [masked.y for _, masked in padded]
    masks = [masked.mask for _, masked in padded]

    merged_inputs = tuple(B.concat(*xcsi, axis=0) for xcsi in zip(*xcs))
    merged_outputs = Masked(B.concat(*ys, axis=0), B.concat(*masks, axis=0))
    return merged_inputs, merged_outputs


@_dispatch
def merge_contexts(*contexts: Tuple[B.Numeric, B.Numeric], **kw_args):
    """Merge context sets whose inputs are plain tensors."""
    packed = [((xc,), yc) for xc, yc in contexts]  # Pack inputs.
    xc, yc = merge_contexts(*packed, **kw_args)
    return xc[0], yc  # Unpack inputs.
@_dispatch
def _merge(z1, z2, z3, *zs):
    # Fold the variadic case down to repeated binary merges.
    merged = _merge(z1, z2)
    for extra in (z3,) + zs:
        merged = _merge(merged, extra)
    return merged


@_dispatch
def _merge(zs: Parallel):
    return _merge(*zs)


@_dispatch(precedence=2)
def _merge(z1: Parallel, z2):
    return _merge(_merge(*z1), z2)


@_dispatch(precedence=1)
def _merge(z1, z2: Parallel):
    return _merge(z1, _merge(*z2))


@_dispatch
def _merge(z):
    # Merging a single element is the identity.
    return z


@_dispatch
def _merge(z1: None, z2: None):
    return None


@_dispatch
def _merge(z1: None, z2):
    return z2


@_dispatch
def _merge(z1, z2: None):
    return z1


@_dispatch
def _merge(z1: B.Numeric, z2: B.Numeric):
    # The inputs must agree numerically; tolerate only floating-point error.
    discrepancy = B.jit_to_numpy(B.mean(B.abs(z1 - z2)))
    if discrepancy > B.epsilon:
        raise ValueError("Cannot merge inputs.")
    return z1


@_dispatch
def _merge(z1: tuple, z2: tuple):
    return tuple(_merge(z1i, z2i) for z1i, z2i in zip(z1, z2))
@_dispatch
def _merge(z1: AggregateInput, z2: AggregateInput):
    """Merge two aggregate inputs whose output indices line up exactly."""
    inds1 = tuple(i for _, i in z1)
    inds2 = tuple(i for _, i in z2)
    if inds1 != inds2:
        raise ValueError("Cannot merge aggregate targets.")

    # Merge the inputs elementwise and reattach the (matching) indices.
    inputs1 = tuple(x for x, _ in z1)
    inputs2 = tuple(x for x, _ in z2)
    merged = (_merge(a, b) for a, b in zip(inputs1, inputs2))
    return AggregateInput(*((m, i) for m, i in zip(merged, inds1)))
@register_module
class Sum:
    """Materialise an aggregate encoding by summing."""


@_dispatch
def code(coder: Sum, xz, z, x, **kw_args):
    return _merge(xz), _sum(z)


@_dispatch
def _sum(z: B.Numeric):
    # A single tensor encoding is its own sum.
    return z


@_dispatch
def _sum(zs: Parallel):
    # Sum the elements of the parallel, starting from zero.
    total = 0
    for element in zs:
        total = total + element
    return total
@_dispatch
def loglik(
    state: B.RandomState,
    model: Model,
    contexts: list,
    xt,
    yt,
    *,
    num_samples=1,
    batch_size=16,
    normalise=False,
    fix_noise=None,
    dtype_lik=None,
    **kw_args,
):
    """Log-likelihood objective.

    Args:
        state (random state, optional): Random state.
        model (:class:`.Model`): Model.
        contexts (list[tuple[input, tensor]]): Context sets as `(x, y)` pairs.
        xt (input): Inputs of the target set.
        yt (tensor): Outputs of the target set.
        num_samples (int, optional): Number of samples. Defaults to 1.
        batch_size (int, optional): Batch size to use for sampling. Defaults to 16.
        normalise (bool, optional): Normalise the objective by the number of targets.
            Defaults to `False`.
        fix_noise (float, optional): Fix the likelihood variance to this value.
        dtype_lik (dtype, optional): Data type to use for the likelihood computation.
            Defaults to the 64-bit variant of the data type of `yt`.

    Returns:
        random state, optional: Random state.
        tensor: Log-likelihoods.
    """
    float = B.dtype_float(yt)
    float64 = B.promote_dtypes(float, np.float64)

    # For the likelihood computation, default to using a 64-bit version of the data
    # type of `yt`.
    if not dtype_lik:
        dtype_lik = float64

    # Sample in batches to alleviate memory requirements.
    logpdfs = None
    done_num_samples = 0
    while done_num_samples < num_samples:
        # Limit the number of samples at the batch size.
        this_num_samples = min(num_samples - done_num_samples, batch_size)

        # Perform batch.
        state, pred = model(
            state,
            contexts,
            xt,
            num_samples=this_num_samples,
            dtype_enc_sample=float,
            dtype_lik=dtype_lik,
            **kw_args,
        )
        pred = fix_noise_in_pred(pred, fix_noise)
        this_logpdfs = pred.logpdf(B.cast(dtype_lik, yt))

        # If the number of samples is equal to one but `num_samples > 1`, then the
        # encoding was a `Dirac`, so we can stop batching. Also, set
        # `num_samples = 1` because we only have one sample now. We also don't
        # need to do the `logsumexp` anymore.
        if num_samples > 1 and B.shape(this_logpdfs, 0) == 1:
            logpdfs = this_logpdfs
            num_samples = 1
            break

        # Record current samples.
        if logpdfs is None:
            logpdfs = this_logpdfs
        else:
            # Concatenate at the sample dimension.
            logpdfs = B.concat(logpdfs, this_logpdfs, axis=0)

        # Increase the counter.
        done_num_samples += this_num_samples

    # Average over samples. Sample dimension should always be the first.
    logpdfs = B.logsumexp(logpdfs, axis=0) - B.cast(dtype_lik, B.log(num_samples))

    if normalise:
        # Normalise by the number of targets.
        logpdfs = logpdfs / B.cast(dtype_lik, num_data(xt, yt))

    return state, logpdfs


@_dispatch
def loglik(state: B.RandomState, model: Model, xc, yc, xt, yt, **kw_args):
    # Wrap a single context set in a list for the general implementation.
    return loglik(state, model, [(xc, yc)], xt, yt, **kw_args)


@_dispatch
def loglik(model: Model, *args, **kw_args):
    # No random state given: use and update the global one. `args[-2]` is `xt`.
    state = B.global_random_state(B.dtype(args[-2]))
    state, logpdfs = loglik(state, model, *args, **kw_args)
    B.set_global_random_state(state)
    return logpdfs
55 | dtype_enc_sample (dtype, optional): Data type to convert the sampled 56 | encoding to. 57 | 58 | Returns: 59 | random state, optional: Random state. 60 | input: Target inputs. 61 | object: Prediction for target outputs. 62 | """ 63 | # Perform augmentation of `xt` with auxiliary target information. 64 | if aux_t is not None: 65 | xt = AugmentedInput(xt, aux_t) 66 | 67 | # If the keyword `noiseless` is set to `True`, then that only applies to the 68 | # decoder. 69 | enc_kw_args = dict(kw_args) 70 | if "noiseless" in enc_kw_args: 71 | del enc_kw_args["noiseless"] 72 | xz, pz = code(self.encoder, xc, yc, xt, root=True, **enc_kw_args) 73 | 74 | # Sample and convert sample to the right data type. 75 | shape = () if num_samples is None else (num_samples,) 76 | state, z = sample(state, pz, *shape) 77 | if dtype_enc_sample: 78 | z = B.cast(dtype_enc_sample, z) 79 | 80 | _, d = code(self.decoder, xz, z, xt, root=True, **kw_args) 81 | 82 | return state, d 83 | 84 | @_dispatch 85 | def __call__(self, xc, yc, xt, **kw_args): 86 | state = B.global_random_state(B.dtype(xt)) 87 | state, d = self(state, xc, yc, xt, **kw_args) 88 | B.set_global_random_state(state) 89 | return d 90 | 91 | @_dispatch 92 | def __call__( 93 | self, 94 | state: B.RandomState, 95 | contexts: List[ 96 | Tuple[Union[None, B.Numeric, tuple], Union[None, B.Numeric, Masked]], 97 | ], 98 | xt, 99 | **kw_args, 100 | ): 101 | return self( 102 | state, 103 | *compress_contexts(contexts), 104 | xt, 105 | **kw_args, 106 | ) 107 | 108 | @_dispatch 109 | def __call__( 110 | self, 111 | contexts: List[ 112 | Tuple[Union[None, B.Numeric, tuple], Union[None, B.Numeric, Masked]] 113 | ], 114 | xt, 115 | **kw_args, 116 | ): 117 | state = B.global_random_state(B.dtype(xt)) 118 | state, d = self(state, contexts, xt, **kw_args) 119 | B.set_global_random_state(state) 120 | return d 121 | 122 | def __str__(self): 123 | return ( 124 | f"Model(\n" 125 | + indent(str(self.encoder), " " * 4) 126 | + ",\n" 127 | + 
indent(str(self.decoder), " " * 4) 128 | + "\n)" 129 | ) 130 | 131 | def __repr__(self): 132 | return ( 133 | f"Model(\n" 134 | + indent(repr(self.encoder), " " * 4) 135 | + ",\n" 136 | + indent(repr(self.decoder), " " * 4) 137 | + "\n)" 138 | ) 139 | -------------------------------------------------------------------------------- /neuralprocesses/model/predict.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import numpy as np 3 | 4 | from .. import _dispatch 5 | from ..aggregate import Aggregate 6 | from ..dist import shape_batch 7 | from .model import Model 8 | 9 | __all__ = ["predict"] 10 | 11 | 12 | @_dispatch 13 | def predict( 14 | state: B.RandomState, 15 | model: Model, 16 | contexts: list, 17 | xt, 18 | *, 19 | num_samples=50, 20 | batch_size=16, 21 | dtype_lik=None, 22 | ): 23 | """Use a model to predict. 24 | 25 | Args: 26 | state (random state, optional): Random state. 27 | model (:class:`.Model`): Model. 28 | xc (input): Inputs of the context set. 29 | yc (tensor): Output of the context set. 30 | xt (input): Inputs of the target set. 31 | num_samples (int, optional): Number of samples to produce. Defaults to 50. 32 | batch_size (int, optional): Batch size. Defaults to 16. 33 | dtype_lik (dtype, optional): Data type to use for the likelihood computation. 34 | Defaults to the 64-bit variant of the data type of `xt`. 35 | 36 | Returns: 37 | random state, optional: Random state. 38 | tensor: Marginal mean. 39 | tensor: Marginal variance. 40 | tensor: `num_samples` noiseless samples. 41 | tensor: `num_samples` noisy samples. 42 | """ 43 | float = B.dtype_float(xt) 44 | 45 | # For the likelihood computation, default to using a 64-bit version of the data 46 | # type of `xt`. 47 | if not dtype_lik: 48 | dtype_lik = B.promote_dtypes(float, np.float64) 49 | 50 | # Collect noiseless samples, noisy samples, first moments, and second moments. 
51 | ft, yt = [], [] 52 | m1s, m2s = [], [] 53 | 54 | done_num_samples = 0 55 | while done_num_samples < num_samples: 56 | # Limit the number of samples at the batch size. 57 | this_num_samples = min(num_samples - done_num_samples, batch_size) 58 | 59 | state, pred = model( 60 | state, 61 | contexts, 62 | xt, 63 | dtype_enc_sample=float, 64 | dtype_lik=dtype_lik, 65 | num_samples=this_num_samples, 66 | ) 67 | 68 | # If the number of samples is equal to one but `num_samples > 1`, then the 69 | # encoding was a `Dirac`, so we can stop batching. In this case, we can 70 | # efficiently compute everything that we need and exit. 71 | if this_num_samples > 1 and shape_batch(pred, 0) == 1: 72 | state, ft = pred.noiseless.sample(state, num_samples) 73 | state, yt = pred.sample(state, num_samples) 74 | # If `pred` or `pred.noiseless` were `Dirac`s, then `ft` or `yt` might not 75 | # have the right number of samples. 76 | ft = _possibly_tile(ft, num_samples) 77 | yt = _possibly_tile(yt, num_samples) 78 | return ( 79 | state, 80 | # Squeeze the newly introduced sample dimension. 81 | B.squeeze(pred.mean, axis=0), 82 | B.squeeze(pred.var, axis=0), 83 | # Squeeze the previously introduced sample dimension. 84 | B.squeeze(ft, axis=1), 85 | B.squeeze(yt, axis=1), 86 | ) 87 | 88 | # Produce samples. 89 | state, sample = pred.noiseless.sample(state) 90 | ft.append(sample) 91 | state, sample = pred.sample(state) 92 | yt.append(sample) 93 | 94 | # Produce moments. 95 | m1s.append(pred.mean) 96 | m2s.append(B.add(pred.var, B.multiply(m1s[-1], m1s[-1]))) 97 | 98 | done_num_samples += this_num_samples 99 | 100 | # Stack samples. 101 | ft = B.concat(*ft, axis=0) 102 | yt = B.concat(*yt, axis=0) 103 | 104 | # Compute marginal statistics. 
@_dispatch
def _possibly_tile(x: B.Numeric, n: B.Int):
    """Tile `x` `n` times along the sample axis if it has just one sample."""
    if n > 1 and B.shape(x, 0) == 1:
        reps = (n,) + (1,) * (B.rank(x) - 1)
        return B.tile(x, *reps)
    return x


@_dispatch
def _possibly_tile(x: Aggregate, n: B.Int):
    # Tile every element of the aggregate independently.
    return Aggregate(*(_possibly_tile(element, n) for element in x))
33 | """ 34 | return x.sample(state, *shape) 35 | 36 | 37 | @_dispatch 38 | def sample(state: B.RandomState, x: Parallel, *shape: B.Int): 39 | samples = [] 40 | for xi in x: 41 | state, s = sample(state, xi, *shape) 42 | samples.append(s) 43 | return state, Parallel(*samples) 44 | 45 | 46 | @_dispatch 47 | def fix_noise(d, value: None): 48 | """Fix the noise of a prediction. 49 | 50 | Args: 51 | d (:class:`neuralprocesses.dist.dist.AbstractDistribution`): 52 | Prediction. 53 | value (float or None): Value to fix it to. 54 | 55 | Returns: 56 | :class:`neuralprocesses.dist.dist.AbstractDistribution`: Prediction 57 | with noise fixed. 58 | """ 59 | return d 60 | 61 | 62 | @_dispatch 63 | def fix_noise(d: MultiOutputNormal, value: float): 64 | with B.on_device(d.vectorised_normal.var_diag): 65 | return MultiOutputNormal( 66 | d._mean, 67 | B.zeros(d._var), 68 | value * Diagonal(B.ones(d.vectorised_normal.var_diag)), 69 | d.shape, 70 | ) 71 | 72 | 73 | @_dispatch 74 | def fix_noise(d: TransformedMultiOutputDistribution, value: float): 75 | return TransformedMultiOutputDistribution( 76 | fix_noise(d.dist, value), 77 | d.transform, 78 | ) 79 | 80 | 81 | @_dispatch 82 | def fix_noise(d: SpikesSlab, value: float): 83 | return d 84 | 85 | 86 | @_dispatch 87 | def compress_contexts(contexts: list): 88 | """Compress multiple context sets into a single `(x, y)` pair. 89 | 90 | Args: 91 | contexts (list): Context sets. 92 | 93 | Returns: 94 | input: Context inputs. 95 | object: Context outputs. 96 | """ 97 | # Don't unnecessarily wrap things in a `Parallel`. 98 | if len(contexts) == 1: 99 | return contexts[0] 100 | else: 101 | return ( 102 | Parallel(*(c[0] for c in contexts)), 103 | Parallel(*(c[1] for c in contexts)), 104 | ) 105 | 106 | 107 | @_dispatch 108 | def tile_for_sampling(x: B.Numeric, num_samples: int): 109 | """Tile to setup batching to produce multiple samples. 110 | 111 | Args: 112 | x (object): Object to tile. 113 | num_samples (int): Number of samples. 
@_dispatch
def tile_for_sampling(y: Aggregate, num_samples: int):
    # Tile every output of the aggregate independently.
    return Aggregate(*(tile_for_sampling(yi, num_samples) for yi in y))


@_dispatch
def tile_for_sampling(x: AggregateInput, num_samples: int):
    # Fix: the result must remain an `AggregateInput` of `(input, index)` pairs;
    # wrapping in a plain `Aggregate` loses the input semantics that downstream
    # dispatches (e.g. `num_data`) key on.
    return AggregateInput(
        *((tile_for_sampling(xi, num_samples), i) for xi, i in x)
    )


@_dispatch
def tile_for_sampling(contexts: list, num_samples: int):
    # Tile inputs and outputs of every context set.
    return [
        (tile_for_sampling(xi, num_samples), tile_for_sampling(yi, num_samples))
        for xi, yi in contexts
    ]
@register_module
class Parallel:
    """A parallel of elements.

    Args:
        *elements (object): Objects to put in parallel.

    Attributes:
        elements (tuple): Objects in parallel.
    """

    def __init__(self, *elements):
        # Framework modules must live in a `ModuleList` so that their parameters
        # are registered with the backend.
        contains_module = any(is_framework_module(e) for e in elements)
        if contains_module:
            self.elements = self.nn.ModuleList(elements)
        else:
            self.elements = elements

    def __call__(self, x):
        return Parallel(*(element(x) for element in self.elements))

    def __getitem__(self, item):
        return self.elements[item]

    def __len__(self):
        return len(self.elements)

    def __iter__(self):
        return iter(self.elements)

    def __str__(self):
        return repr(self)

    def __repr__(self):
        inner = "".join(
            indent(repr(element).strip(), " " * 4) + ",\n" for element in self
        )
        return "Parallel(\n" + inner + ")"
def broadcast_coder_over_parallel(coder_type):
    """Broadcast a coder over parallel encodings.

    Args:
        coder_type (type): Type of coder.
    """

    @_dispatch
    def code(p: coder_type, xz: Parallel, z: Parallel, x, **kw_args):
        # Apply the coder to every element of the parallel independently.
        results = [code(p, xzi, zi, x, **kw_args) for xzi, zi in zip(xz, z)]
        new_xz, new_z = zip(*results)
        return Parallel(*new_xz), Parallel(*new_z)
def create_tf_call(module):
    """Build a `call` method that forwards to `module.__call__`, passing the
    Keras `training` flag through only when the wrapped module accepts it."""

    def call(self, *args, training=False, **kw_args):
        try:
            # First attempt: forward the `training` keyword.
            return module.__call__(self, *args, training=training, **kw_args)
        except TypeError:
            # The wrapped module does not take `training`; call without it.
            return module.__call__(self, *args, **kw_args)

    return call
def create_forward(Module):
    """Create a Torch `forward` method which defers to `Module.__call__`.

    Args:
        Module (type): Class whose `__call__` implements the computation.

    Returns:
        function: Method `forward(self, *args, **kw_args)`.
    """

    def forward(self, *args, **kw_args):
        # Invoke the underlying implementation directly on `self`.
        implementation = Module.__call__
        return implementation(self, *args, **kw_args)

    return forward
] 45 | 46 | [project.urls] 47 | repository = "https://github.com/wesselb/neuralprocesses" 48 | 49 | [tool.hatch.build] 50 | include = ["neuralprocesses*"] 51 | 52 | [tool.hatch.version] 53 | source = "vcs" 54 | 55 | [tool.hatch.build.hooks.vcs] 56 | version-file = "neuralprocesses/_version.py" 57 | 58 | # Tests: 59 | 60 | [tool.coverage.run] 61 | branch = true 62 | command_line = "-m pytest --verbose test" 63 | source = ["neuralprocesses"] 64 | 65 | [tool.pytest.ini_options] 66 | testpaths = [ 67 | "tests", 68 | ] 69 | 70 | # Formatting tools: 71 | 72 | [tool.black] 73 | line-length = 88 74 | target-version = ["py38", "py39"] 75 | include = '\.pyi?$' 76 | exclude = ''' 77 | /( 78 | \.eggs 79 | | \.git 80 | | \.hg 81 | | \.mypy_cache 82 | | \.tox 83 | | \.venv 84 | | _build 85 | | buck-out 86 | | build 87 | | dist 88 | )/ 89 | ''' 90 | 91 | [tool.isort] 92 | profile = "black" 93 | src_paths = ["neuralprocesses", "tests"] 94 | -------------------------------------------------------------------------------- /scripts/predprey.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import lab as B 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import wbml.out 8 | import wbml.out as out 9 | from wbml.data.predprey import load 10 | from wbml.experiment import WorkingDirectory 11 | from wbml.plot import tweak, tex, pdfcrop 12 | 13 | import neuralprocesses.torch as nps 14 | from train import main 15 | 16 | # Setup script. 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model", type=str, required=True) 19 | parser.add_argument("--ar", action="store_true") 20 | args = parser.parse_args() 21 | 22 | wbml.out.report_time = True 23 | 24 | # Load experiment. 25 | with out.Section("Loading experiment"): 26 | exp = main( 27 | # The keyword arguments here must line up with the arguments you provide on the 28 | # command line. 
29 | model=args.model, 30 | data="predprey", 31 | # Point `root` to where the `_experiments` directory is located. In my case, 32 | # I'm `rsync`ing it to the directory `server`. 33 | root="server/_experiments", 34 | load=True, 35 | ) 36 | model = exp["model"] 37 | model.load_state_dict( 38 | torch.load(exp["wd"].file("model-last.torch"), map_location="cpu")["weights"] 39 | ) 40 | 41 | # Setup another working directory to save output of the evaluation in. 42 | wd = WorkingDirectory("_experiments", "eval", "predprey", args.model) 43 | tex() 44 | 45 | # Increase regularisation. 46 | B.epsilon = 1e-6 47 | 48 | # Load the data. 49 | df = load() 50 | x = torch.tensor(np.array(df.index), dtype=torch.float32) 51 | x = x - x[0] 52 | y = torch.tensor(np.array(df[["hare", "lynx"]]), dtype=torch.float32) 53 | 54 | # Construct mask for the hares. 55 | mask_hare = x < 0 56 | for i in range(20, 80, 8): 57 | mask_hare |= (x >= i) & (x <= i + 3) 58 | mask_hare = ~mask_hare 59 | 60 | # Construct a mask for the lynxes. 61 | mask_lynx = x < 0 62 | for i in range(25, 80, 8): 63 | mask_lynx |= (x >= i) & (x <= i + 3) 64 | mask_lynx = ~mask_lynx 65 | 66 | # Share tensors into the standard formats. 67 | x = x[None, None, :] 68 | y = y.T[None, :, :] 69 | xt = torch.linspace(-20, 120, 141, dtype=torch.float32)[None, None, :] 70 | 71 | contexts = [ 72 | (x[:, :, mask_hare], y[:, 0:1, mask_hare]), 73 | (x[:, :, mask_lynx], y[:, 1:2, mask_lynx]), 74 | ] 75 | 76 | # `torch.no_grad` is necessary to prevent memory from accumulating. 77 | with torch.no_grad(): 78 | 79 | # Perform evaluation. 
80 | xt_eval = nps.AggregateInput((x[:, :, ~mask_hare], 0), (x[:, :, ~mask_lynx], 1)) 81 | yt_eval = nps.Aggregate(y[:, 0:1, ~mask_hare], y[:, 1:2, ~mask_lynx]) 82 | out.kv( 83 | "Logpdf", 84 | nps.loglik(model, contexts, xt_eval, yt_eval, normalise=True), 85 | ) 86 | if args.ar: 87 | out.kv( 88 | "Logpdf (AR)", 89 | nps.ar_loglik(model, contexts, xt_eval, yt_eval, normalise=True), 90 | ) 91 | 92 | # Make predictions. 93 | predict = nps.ar_predict if args.ar else nps.predict 94 | mean, _, noiseless_samples, noisy_samples = predict( 95 | model, 96 | contexts, 97 | nps.AggregateInput((xt, 0), (xt, 1)), 98 | num_samples=100, 99 | ) 100 | 101 | # Plot the result. 102 | 103 | plt.figure(figsize=(10, 8)) 104 | 105 | plt.subplot(2, 1, 1) 106 | plt.scatter( 107 | x[0, 0, mask_hare], 108 | y[0, 0, mask_hare], 109 | marker="o", 110 | style="train", 111 | s=20, 112 | label="Hare", 113 | ) 114 | plt.scatter(x[0, 0, ~mask_hare], y[0, 0, ~mask_hare], marker="o", style="test", s=20) 115 | plt.plot(xt[0, 0], mean[0][0, 0, :], style="pred") 116 | plt.plot(xt[0, 0], noiseless_samples[0][:10, 0, 0, :].T, style="pred", ls="-", lw=0.5) 117 | plt.fill_between( 118 | xt[0, 0], 119 | B.quantile(noisy_samples[0][:, 0, 0, :], 2.5 / 100, axis=0), 120 | B.quantile(noisy_samples[0][:, 0, 0, :], (100 - 2.5) / 100, axis=0), 121 | style="pred", 122 | ) 123 | plt.ylim(0, 300) 124 | tweak() 125 | 126 | plt.subplot(2, 1, 2) 127 | plt.scatter( 128 | x[0, 0, mask_lynx], 129 | y[0, 1, mask_lynx], 130 | marker="o", 131 | style="train", 132 | s=20, 133 | label="Lynx", 134 | ) 135 | plt.scatter(x[0, 0, ~mask_lynx], y[0, 1, ~mask_lynx], marker="o", style="test", s=20) 136 | plt.plot(xt[0, 0], mean[1][0, 0, :], style="pred") 137 | plt.plot(xt[0, 0], noiseless_samples[1][:10, 0, 0, :].T, style="pred", ls="-", lw=0.5) 138 | plt.fill_between( 139 | xt[0, 0], 140 | B.quantile(noisy_samples[1][:, 0, 0, :], 2.5 / 100, axis=0), 141 | B.quantile(noisy_samples[1][:, 0, 0, :], (100 - 2.5) / 100, axis=0), 142 | 
style="pred", 143 | ) 144 | plt.ylim(0, 300) 145 | tweak() 146 | 147 | plt.savefig(wd.file("predprey.pdf")) 148 | pdfcrop(wd.file("predprey.pdf")) 149 | plt.show() 150 | -------------------------------------------------------------------------------- /scripts/predprey_visualise.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import matplotlib.pyplot as plt 4 | import torch 5 | from wbml.plot import tweak 6 | 7 | import neuralprocesses.torch as nps 8 | 9 | # Parse arguments. 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--mode", type=str) 12 | args = parser.parse_args() 13 | 14 | gen = nps.PredPreyGenerator(torch.float32, seed=0, mode=args.mode) 15 | batch = gen.generate_batch() 16 | 17 | plt.figure(figsize=(12, 8)) 18 | 19 | for i in range(16): 20 | # Plot the preys (output 0). 21 | plt.subplot(4, 4, i + 1) 22 | xc = nps.batch_xc(batch, 0)[i, 0] 23 | yc = nps.batch_yc(batch, 0)[i] 24 | xt = nps.batch_xt(batch, 0)[i, 0] 25 | yt = nps.batch_yt(batch, 0)[i] 26 | plt.scatter(xc, yc, c="tab:red", marker="x", s=5) 27 | plt.scatter(xt, yt, c="tab:orange", marker="x", s=5) 28 | 29 | # Plot the predators (output 1). 
30 | xc = nps.batch_xc(batch, 1)[i, 0] 31 | yc = nps.batch_yc(batch, 1)[i] 32 | xt = nps.batch_xt(batch, 1)[i, 0] 33 | yt = nps.batch_yt(batch, 1)[i] 34 | plt.scatter(xc, yc, c="tab:blue", marker="o", s=5) 35 | plt.scatter(xt, yt, c="tab:cyan", marker="o", s=5) 36 | 37 | plt.xlim(0, 100) 38 | tweak() 39 | 40 | plt.show() 41 | -------------------------------------------------------------------------------- /scripts/sawtooth_sample_ar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import lab as B 4 | import matplotlib.pyplot as plt 5 | import neuralprocesses.torch as nps 6 | import torch 7 | import wbml.out as out 8 | from wbml.experiment import WorkingDirectory 9 | from wbml.plot import tweak 10 | 11 | # Parse the arguments, which should include the path to the weights. 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--weights", required=True) 14 | args = parser.parse_args() 15 | 16 | # Setup a directory to store the results in. 17 | out.report_time = True 18 | wd = WorkingDirectory("_experiments", "sawtooth_sample_ar") 19 | 20 | # Construct the model and load the weights. 21 | model = nps.construct_convgnp( 22 | dim_x=1, 23 | dim_y=1, 24 | unet_channels=(64,) * 6, 25 | points_per_unit=64, 26 | likelihood="het", 27 | ) 28 | model.load_state_dict(torch.load(args.weights, map_location="cpu")["weights"]) 29 | 30 | # Construct the data generator the model was trained on. 31 | gen = nps.SawtoothGenerator( 32 | torch.float32, 33 | seed=2, 34 | batch_size=1, # Only need one sample. 35 | # Use only two context points to introduce ambiguity. 36 | num_context=nps.UniformDiscrete(2, 2), 37 | # Be sure to use the same distribution of frequencies we used during training. 38 | dist_freq=nps.UniformContinuous(2, 4), 39 | ) 40 | batch = gen.generate_batch() # Sample a batch of data. 41 | 42 | # Predict at the following points. 
43 | x = B.linspace(torch.float32, -2, 2, 400)[None, None, :] 44 | pred = model(batch["contexts"], x) 45 | 46 | plt.figure(figsize=(12, 6)) 47 | 48 | for i, order in enumerate(["random", "left-to-right"]): 49 | # We can predict autoregressively by using `nps.ar_predict`. 50 | mean, var, noiseless_samples, noisy_samples = nps.ar_predict( 51 | model, 52 | batch["contexts"], 53 | # The inputs to predict at need to be wrapped in a `nps.AggregateInput`. For 54 | # this particular problem, this extra functionality is not needed, but it is 55 | # needed in the case of multiple outputs. Below, `(x, 0)` means to predict 56 | # output 0 at inputs `x`. The return values will also be wrapped in 57 | # `nps.Aggregate`s, which can be accessed by indexing with the output index, 58 | # 0 in this case. 59 | nps.AggregateInput((x, 0)), 60 | num_samples=6, 61 | order=order, 62 | ) 63 | 64 | for j, (title_suffix, samples) in enumerate( 65 | [("noiseless", noiseless_samples), ("noisy", noisy_samples)] 66 | ): 67 | plt.subplot(2, 2, 2 * i + 1 + j) 68 | plt.title(order.capitalize() + f" ({title_suffix})") 69 | # Plot the context points. 70 | plt.scatter( 71 | nps.batch_xc(batch, 0)[0, 0], 72 | nps.batch_yc(batch, 0)[0], 73 | style="train", 74 | ) 75 | # Plot the mean and variance of the non-AR predction. 76 | plt.plot(x[0, 0], pred.mean[0, 0], style="pred") 77 | err = 1.96 * B.sqrt(pred.var[0, 0]) 78 | plt.fill_between( 79 | x[0, 0], 80 | pred.mean[0, 0] - err, 81 | pred.mean[0, 0] + err, 82 | style="pred", 83 | ) 84 | # Plot the samples. 
85 | plt.plot(x[0, 0], samples[0][:, 0, 0, :].T, style="pred", lw=0.5, ls="-") 86 | plt.ylim(-0.2, 1.2) 87 | plt.xlim(-2, 2) 88 | tweak() 89 | 90 | plt.savefig(wd.file("samples.pdf")) 91 | plt.show() 92 | -------------------------------------------------------------------------------- /scripts/synthetic_extra.py: -------------------------------------------------------------------------------- 1 | import lab.torch as B 2 | import torch 3 | import wbml.out as out 4 | from wbml.experiment import WorkingDirectory 5 | 6 | import neuralprocesses.torch as nps 7 | from experiment import with_err 8 | 9 | wd = WorkingDirectory("_experiments", "synthetic_extra") 10 | 11 | 12 | def gens_eval(data, dim_x, dim_y): 13 | return [ 14 | ( 15 | eval_name, 16 | nps.construct_predefined_gens( 17 | torch.float32, 18 | seed=30, 19 | batch_size=16, 20 | num_tasks=2**12, 21 | dim_x=dim_x, 22 | dim_y=dim_y, 23 | pred_logpdf=True, 24 | pred_logpdf_diag=True, 25 | device="cuda", 26 | x_range_context=x_range_context, 27 | x_range_target=x_range_target, 28 | )[data], 29 | ) 30 | for eval_name, x_range_context, x_range_target in [ 31 | ("interpolation in training range", (-2, 2), (-2, 2)), 32 | ("interpolation beyond training range", (2, 6), (2, 6)), 33 | ("extrapolation beyond training range", (-2, 2), (2, 6)), 34 | ] 35 | ] 36 | 37 | 38 | for data in ["eq", "matern", "weakly-periodic", "sawtooth", "mixture"]: 39 | for dim_x in [1, 2]: 40 | for dim_y in [1, 2]: 41 | with out.Section(f"{data}-{dim_x}-{dim_y}"): 42 | for task, gen in gens_eval(data, dim_x, dim_y): 43 | with out.Section(task.capitalize()): 44 | 45 | logpdfs = [] 46 | logpdfs_diag = [] 47 | m1s = [0] * dim_y 48 | m2s = [0] * dim_y 49 | ns = [0] * dim_y 50 | 51 | # Loop over the epoch and compute statistics. 
52 | for batch in gen.epoch(): 53 | if "pred_logpdf" in batch: 54 | logpdfs.append( 55 | batch["pred_logpdf"] 56 | / nps.num_data(batch["xt"], batch["yt"]) 57 | ) 58 | logpdfs_diag.append( 59 | batch["pred_logpdf_diag"] 60 | / nps.num_data(batch["xt"], batch["yt"]) 61 | ) 62 | if dim_y == 1: 63 | m1s[0] += B.sum(batch["yt"]) 64 | m2s[0] += B.sum(batch["yt"] ** 2) 65 | ns[0] += B.length(batch["yt"]) 66 | else: 67 | for i in range(dim_y): 68 | m1s[i] += B.sum(batch["yt"][i]) 69 | m2s[i] += B.sum(batch["yt"][i] ** 2) 70 | ns[i] += B.length(batch["yt"][i]) 71 | 72 | # Compute the trivial logpdf. 73 | logpdfs_trivial = [] 74 | for i in range(dim_y): 75 | m1 = m1s[i] / ns[i] 76 | m2 = m2s[i] / ns[i] 77 | emp_var = m2 - m1**2 78 | logpdfs_trivial.append( 79 | -0.5 * B.log(2 * B.pi * B.exp(1) * emp_var) 80 | ) 81 | logpdf_trivial = B.mean(B.stack(*logpdfs_trivial)) 82 | out.kv("Logpdf (trivial)", logpdf_trivial, fmt=".5f") 83 | 84 | # Report KLs. 85 | if logpdfs: 86 | out.kv("Logpdf (diag)", with_err(B.stack(*logpdfs_diag))) 87 | out.kv( 88 | "KL (diag)", 89 | with_err(B.stack(*logpdfs) - B.stack(*logpdfs_diag)), 90 | ) 91 | out.kv( 92 | "KL (trivial)", 93 | with_err(B.stack(*logpdfs) - logpdf_trivial), 94 | ) 95 | -------------------------------------------------------------------------------- /scripts/temperature_mae.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import experiment 4 | import lab as B 5 | import torch 6 | import wbml.out as out 7 | from neuralprocesses import AugmentedInput 8 | from train import main 9 | 10 | # Load experiment. 11 | sys.argv += ["--load"] 12 | exp = main() 13 | 14 | # Setup model. 
def reindex(mae, xt):
    """Re-index the MAEs so that randomly sampled batches line up station-wise.

    Each batch contains targets at a subset of all stations `xt_all`. This maps
    the per-target MAEs onto the full station grid, filling in NaNs for
    stations without a target in the batch, so that MAEs can be concatenated
    across batches along the first axis.

    Args:
        mae (tensor): MAEs of shape `(*b, c, n)` for the `n` targets in `xt`.
        xt (tensor): Target inputs of the batch, possibly augmented.

    Returns:
        tensor: MAEs of shape `(*b, c, n_all)` aligned with `xt_all`.
    """
    # Row of NaNs used for stations without a matching target in this batch.
    nan_row = mae[:, :, :1] * B.nan
    xt = strip_augmentation(xt)[0]

    # Precomputing the distances like this allows us to get away with a simple
    # `for`-loop below. No need to optimise that further.
    dists = B.to_numpy(B.pw_dists(B.t(xt_all), B.t(xt)).cpu())

    rows = []
    for i in range(B.shape(xt_all, -1)):
        for j in range(B.shape(xt, -1)):
            if dists[i, j] < 1e-6:
                # Target `j` of the batch coincides with station `i`.
                rows.append(mae[:, :, j : j + 1])
                break
        else:
            # No target in the batch is within tolerance of station `i`.
            rows.append(nan_row)
    return B.concat(*rows, axis=-1)
65 | maes = B.nanmean(maes, axis=(0, 1)) 66 | 67 | out.kv("Station-wise MAEs", maes) 68 | out.kv("MAE", experiment.with_err(maes[~B.isnan(maes)])) 69 | out.kv("MAE (median)", experiment.with_err(*experiment.median_and_err(maes))) 70 | -------------------------------------------------------------------------------- /scripts/temperature_summarise_folds.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import wbml.out as out 3 | import torch 4 | from wbml.experiment import WorkingDirectory 5 | import neuralprocesses.torch as nps 6 | 7 | import experiment 8 | from train import main 9 | 10 | wd = WorkingDirectory("_experiments", "temperature_summarise_folds") 11 | 12 | 13 | def compute_logpdfs_maes(data, model): 14 | # Load experiment. 15 | exp = main(data=data, model=model, load=True) 16 | 17 | # Setup model. 18 | model = exp["model"] 19 | model.load_state_dict(torch.load(exp["wd"].file("model-best.torch"))["weights"]) 20 | 21 | # Setup generator. 22 | gens = exp["gens_eval"]() 23 | _, gen = gens[0] # The first one corresponds to downscaling. 
24 | 25 | state = B.create_random_state(torch.float32, seed=0) 26 | logpdfs, maes = [], [] 27 | with torch.no_grad(): 28 | for batch in gen.epoch(): 29 | state, pred = model(state, batch["contexts"], batch["xt"]) 30 | n = nps.num_data(batch["xt"], batch["yt"]) 31 | logpdfs.append(pred.logpdf(batch["yt"]) / n) 32 | maes.append(B.abs(pred.mean - batch["yt"])) 33 | return B.concat(*logpdfs), B.concat(*maes) 34 | 35 | 36 | for model in ["convcnp-mlp", "convgnp-mlp"]: 37 | with out.Section(model): 38 | for data in ["temperature-germany", "temperature-value"]: 39 | with out.Section(data): 40 | logpdfs, maes = [], [] 41 | for fold in [1, 2, 3, 4, 5]: 42 | fold_logpdfs, fold_maes = compute_logpdfs_maes( 43 | f"{data}-{fold}", 44 | model, 45 | ) 46 | maes.append(fold_maes) 47 | logpdfs.append(fold_logpdfs) 48 | logpdfs = B.concat(*logpdfs) 49 | maes = B.concat(*maes) 50 | 51 | # Compute the average MAE per station, and then take the median over 52 | # stations. This lines up with the VALUE protocol. 
53 | maes = B.nanmean(maes, axis=(0, 1)) 54 | 55 | out.kv("Loglik", experiment.with_err(logpdfs, and_upper=True)) 56 | out.kv("MAE", experiment.with_err(maes)) 57 | out.kv("MAE (median)", experiment.with_err(*experiment.median_and_err(maes))) 58 | -------------------------------------------------------------------------------- /scripts/temperature_visualise.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import lab as B 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import wbml.out as out 7 | from wbml.experiment import WorkingDirectory 8 | from wbml.plot import tweak, tex, pdfcrop 9 | 10 | import neuralprocesses.torch as nps 11 | from neuralprocesses.mask import Masked 12 | from train import main 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--sample", action="store_true") 16 | args = parser.parse_args() 17 | 18 | 19 | device = "cuda" if torch.cuda.is_available() else "cpu" 20 | B.set_global_device(device) 21 | tex() 22 | 23 | wd = WorkingDirectory("_experiments", "temperature_visualise") 24 | 25 | # Load experiment. 26 | exp = main( 27 | data="temperature", 28 | model="convcnp-multires", 29 | root="aws_run_2022-05-17_temperature/_experiments", 30 | load=True, 31 | ) 32 | model = exp["model"] 33 | model.load_state_dict( 34 | torch.load(exp["wd"].file("model-last.torch"), map_location=device)["weights"] 35 | ) 36 | gen = nps.TemperatureGenerator( 37 | torch.float32, 38 | seed=41, 39 | batch_size=1, 40 | subset="train", 41 | context_sample=True, 42 | device=device, 43 | ) 44 | b = gen.generate_batch(nc=20) 45 | 46 | # Ensure that the contexts are masked for compatibility with the below. 47 | xc_fuse, yc_fuse = b["contexts"][0] 48 | out.kv("Num context", B.shape(xc_fuse, -1)) 49 | if not isinstance(yc_fuse, Masked): 50 | yc_fuse = Masked(yc_fuse, B.ones(yc_fuse)) 51 | b["contexts"][0] = (xc_fuse, yc_fuse) 52 | 53 | # Make predictions on a grid. 
54 | lons = B.linspace(torch.float32, 6, 16, 200)[None, None, :] 55 | lats = B.linspace(torch.float32, 55, 47, 200)[None, None, :] 56 | pred = model(b["contexts"], (lons, lats)) 57 | 58 | mean = pred.mean # Mean 59 | 60 | # Specify coarse grid to AR sample on. 61 | n = 30 62 | ar_lons = B.linspace(torch.float32, 6, 16, n)[:, None] 63 | ar_lons = B.flatten(B.broadcast_to(ar_lons, n, n))[None, None, :] 64 | ar_lats = B.linspace(torch.float32, 55, 47, n)[None, :] 65 | ar_lats = B.flatten(B.broadcast_to(ar_lats, n, n))[None, None, :] 66 | ar_xs = B.concat(ar_lons, ar_lats, axis=-2) 67 | 68 | state = B.create_random_state(torch.float32, seed=0) 69 | state, perm = B.randperm(state, torch.int64, n * n) 70 | ar_xs = B.take(ar_xs, perm, axis=-1) 71 | 72 | if args.sample: 73 | samples = [] 74 | for _ in range(3): 75 | for i in range(B.shape(ar_xs, -1)): 76 | # Sample target. 77 | x = ar_xs[:, :, i : i + 1] 78 | y = model(b["contexts"], x).sample() 79 | 80 | # Append target to contexts. 81 | xc, yc = b["contexts"][0] 82 | xc = B.concat(xc, x, axis=-1) 83 | mask = B.concat(yc.mask, B.ones(y), axis=-1) 84 | yc = B.concat(yc.y, y, axis=-1) 85 | b["contexts"][0] = (xc, Masked(yc, mask)) 86 | 87 | pred = model(b["contexts"], (lons, lats)) 88 | # state, sample = pred.noiseless.sample(state) 89 | sample = pred.mean 90 | samples.append(sample) 91 | 92 | # Reset contexts. 
93 | b["contexts"][0] = (xc_fuse, yc_fuse) 94 | wd.save(samples, "samples.pickle") 95 | else: 96 | samples = wd.load("samples.pickle") 97 | 98 | mask = yc_fuse.mask 99 | yc_fuse = yc_fuse.y 100 | yc_fuse[~mask] = B.nan 101 | 102 | vmin = -15 103 | vmax = 15 104 | cmap = "bwr" 105 | 106 | plt.figure(figsize=(14, 4)) 107 | 108 | plt.subplot(1, 4, 1) 109 | plt.title("Mean") 110 | plt.imshow( 111 | mean[0].T, 112 | extent=(6, 16, 47, 55), 113 | vmin=vmin, 114 | vmax=vmax, 115 | cmap=cmap, 116 | ) 117 | plt.scatter( 118 | xc_fuse[0, 0, :], 119 | xc_fuse[0, 1, :], 120 | c=yc_fuse[0, 0, :], 121 | vmin=vmin, 122 | vmax=vmax, 123 | edgecolor="white", 124 | lw=0.5, 125 | cmap=cmap, 126 | ) 127 | tweak(legend=False, grid=False) 128 | 129 | for i in range(len(samples)): 130 | plt.subplot(1, 4, 2 + i) 131 | plt.title(f"Sample {i + 1}") 132 | plt.imshow( 133 | samples[i][0, 0].T, 134 | extent=(6, 16, 47, 55), 135 | vmin=vmin, 136 | vmax=vmax, 137 | cmap=cmap, 138 | ) 139 | plt.scatter( 140 | xc_fuse[0, 0, :], 141 | xc_fuse[0, 1, :], 142 | c=yc_fuse[0, 0, :], 143 | vmin=vmin, 144 | vmax=vmax, 145 | edgecolor="white", 146 | lw=0.5, 147 | cmap=cmap, 148 | ) 149 | tweak(legend=False, grid=False) 150 | 151 | plt.savefig(wd.file("temperature.pdf")) 152 | pdfcrop(wd.file("temperature.pdf")) 153 | plt.show() 154 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Add package to path. 
5 | file_dir = os.path.dirname(__file__) 6 | sys.path.insert(0, os.path.abspath(os.path.join(file_dir, ".."))) 7 | -------------------------------------------------------------------------------- /tests/coders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesselb/neuralprocesses/89faf25e5bfd481865d344c6c1aec256c1fd6961/tests/coders/__init__.py -------------------------------------------------------------------------------- /tests/coders/test_shaping.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from plum import Dispatcher 3 | 4 | import neuralprocesses as nps_ 5 | 6 | from ..util import nps # noqa 7 | 8 | _dispatch = Dispatcher() 9 | 10 | 11 | @_dispatch 12 | def _to_tuple(x): 13 | return x 14 | 15 | 16 | @_dispatch 17 | def _to_tuple(x: tuple): 18 | return tuple(_to_tuple(xi) for xi in x) 19 | 20 | 21 | @_dispatch 22 | def _to_tuple(p: nps_.Parallel): 23 | return tuple(_to_tuple(pi) for pi in p) 24 | 25 | 26 | def test_restructure_parallel(nps): 27 | reorg = nps.RestructureParallel((0, (1, 2)), (0, (2,), 1)) 28 | 29 | res = nps.code( 30 | reorg, 31 | nps_.Parallel("x1", nps_.Parallel("x2", "x3")), 32 | nps_.Parallel("y1", nps_.Parallel("y2", "y3")), 33 | None, 34 | root=True, 35 | ) 36 | assert _to_tuple(res) == (("x1", ("x3",), "x2"), ("y1", ("y3",), "y2")) 37 | 38 | # Check that the structure must be right. 
39 | with pytest.raises(RuntimeError, match="Parallel does not match structure."): 40 | nps.code( 41 | reorg, 42 | nps_.Parallel("x1", "x2", nps_.Parallel("x3")), 43 | nps_.Parallel("y1", nps_.Parallel("y2", "y3")), 44 | None, 45 | root=True, 46 | ) 47 | -------------------------------------------------------------------------------- /tests/dists/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesselb/neuralprocesses/89faf25e5bfd481865d344c6c1aec256c1fd6961/tests/dists/__init__.py -------------------------------------------------------------------------------- /tests/dists/test_normal.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import pytest 3 | 4 | from neuralprocesses.dist import MultiOutputNormal 5 | 6 | from ..util import approx, nps # noqa 7 | 8 | _missing_dists = [] 9 | 10 | # Dense: 11 | mean = B.randn(4, 3, 10) 12 | var = B.randn(4, 3, 10, 3, 10) 13 | # Make variance positive definite. 
14 | var = B.reshape(var, 4, 30, 30) 15 | var = var @ B.transpose(var) 16 | var = B.reshape(var, 4, 3, 10, 3, 10) 17 | noise = B.rand(4, 3, 10) 18 | _missing_dists.append( 19 | ( 20 | MultiOutputNormal.dense( 21 | B.reshape(mean, 4, 30), 22 | B.reshape(var, 4, 30, 30), 23 | B.reshape(noise, 4, 30), 24 | (3, 10), 25 | ), 26 | MultiOutputNormal.dense( 27 | B.reshape(mean[:, :2, :], 4, 20), 28 | B.reshape(var[:, :2, :, :2, :], 4, 20, 20), 29 | B.reshape(noise[:, :2, :], 4, 20), 30 | (2, 10), 31 | ), 32 | ) 33 | ) 34 | 35 | # Diagonal: 36 | mean = B.randn(4, 3, 10) 37 | noise = B.rand(4, 3, 10) 38 | _missing_dists.append( 39 | ( 40 | MultiOutputNormal.diagonal( 41 | B.reshape(mean, 4, 30), 42 | B.reshape(noise, 4, 30), 43 | (3, 10), 44 | ), 45 | MultiOutputNormal.diagonal( 46 | B.reshape(mean[:, :2, :], 4, 20), 47 | B.reshape(noise[:, :2, :], 4, 20), 48 | (2, 10), 49 | ), 50 | ) 51 | ) 52 | 53 | # Low rank: 54 | mean = B.randn(4, 3, 10) 55 | var_factor = B.rand(4, 3, 10, 7) 56 | noise = B.rand(4, 3, 10) 57 | _missing_dists.append( 58 | ( 59 | MultiOutputNormal.lowrank( 60 | B.reshape(mean, 4, 30), 61 | B.reshape(var_factor, 4, 30, 7), 62 | B.reshape(noise, 4, 30), 63 | (3, 10), 64 | ), 65 | MultiOutputNormal.lowrank( 66 | B.reshape(mean[:, :2, :], 4, 20), 67 | B.reshape(var_factor[:, :2, :, :], 4, 20, 7), 68 | B.reshape(noise[:, :2, :], 4, 20), 69 | (2, 10), 70 | ), 71 | ) 72 | ) 73 | 74 | 75 | @pytest.mark.parametrize("d, d_ref", _missing_dists) 76 | def test_monormal_missing(nps, d, d_ref): 77 | y_ref = B.randn(4, 3, 10) 78 | y = y_ref.copy() 79 | y[:, 2, :] = B.nan 80 | approx(d.logpdf(y), d_ref.logpdf(y_ref[:, :2, :])) 81 | -------------------------------------------------------------------------------- /tests/gnp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesselb/neuralprocesses/89faf25e5bfd481865d344c6c1aec256c1fd6961/tests/gnp/__init__.py 
-------------------------------------------------------------------------------- /tests/gnp/autoencoding.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import pytest 3 | 4 | import neuralprocesses.gnp as gnp 5 | 6 | # noinspection PyUnresolvedReferences 7 | from .util import context_set, target_set 8 | 9 | 10 | @pytest.fixture() 11 | def disc(): 12 | return gnp.Discretisation1d(points_per_unit=32, multiple=4, margin=0.1) 13 | 14 | 15 | def test_autoencoding_1d(disc, context_set, target_set): 16 | enc = gnp.SetConv1dEncoder(disc) 17 | dec = gnp.SetConv1dDecoder(disc) 18 | 19 | xz, z = enc.forward(*context_set, target_set[0]) 20 | xz, z = dec.forward(xz, z, target_set[0]) 21 | 22 | target_shape = ( 23 | B.shape(target_set[1])[0], 24 | B.shape(target_set[1])[1], 25 | B.shape(target_set[1])[2] + 1, 26 | ) 27 | assert B.shape(z) == target_shape 28 | 29 | 30 | def test_autoencoding_2d(disc, context_set, target_set): 31 | enc = gnp.SetConv1dPDEncoder(disc) 32 | dec = gnp.SetConv1dPDDecoder(disc) 33 | 34 | xz, z = enc.forward(*context_set, target_set[0]) 35 | xz, z = dec.forward(xz, z, target_set[0]) 36 | 37 | target_shape = ( 38 | B.shape(target_set[1])[0], 39 | B.shape(target_set[1])[2] + 2, 40 | B.shape(target_set[1])[1], 41 | B.shape(target_set[1])[1], 42 | ) 43 | assert B.shape(z) == target_shape 44 | -------------------------------------------------------------------------------- /tests/gnp/gnp.py: -------------------------------------------------------------------------------- 1 | import lab.torch as B 2 | import pytest 3 | 4 | import neuralprocesses.gnp as gnp 5 | import torch 6 | 7 | # noinspection PyUnresolvedReferences 8 | from .util import context_set, target_set 9 | 10 | 11 | def test_gnp(context_set, target_set): 12 | model = gnp.GNP(y_target_dim=B.shape(target_set[1])[2]) 13 | mean, cov = model(*context_set, target_set[0]) 14 | pred = torch.distributions.MultivariateNormal(loc=mean, 
covariance_matrix=cov) 15 | 16 | # Check logpdf computation. This only works for one-dimensional outputs. 17 | assert B.shape(target_set[1])[2] == 1 18 | logpdf = pred.log_prob(target_set[1][:, :, 0]) 19 | 20 | # Check output. 21 | assert B.all(torch.isfinite(logpdf)) 22 | assert B.shape(logpdf) == B.shape(target_set[1])[:1] 23 | 24 | 25 | def test_gnp_x_target_check(context_set, target_set): 26 | # Must provide target inputs: 27 | model = gnp.GNP() 28 | model(*context_set, target_set[0]) 29 | with pytest.raises(ValueError): 30 | model(*context_set) 31 | 32 | # May provide target inputs: 33 | model = gnp.GNP(x_target=0.5 * target_set[0]) 34 | diff = model(*context_set)[0] - model(*context_set, target_set[0])[0] 35 | assert B.sum(B.abs(diff)) > 1e-2 36 | -------------------------------------------------------------------------------- /tests/gnp/util.py: -------------------------------------------------------------------------------- 1 | import lab.torch as B 2 | import numpy.testing 3 | import pytest 4 | 5 | import torch 6 | 7 | __all__ = ["approx", "context_set", "target_set"] 8 | 9 | 10 | approx = numpy.testing.assert_allclose 11 | 12 | 13 | @pytest.fixture() 14 | def context_set(): 15 | batch_size = 2 16 | n = 15 17 | x = B.randn(torch.float32, batch_size, n, 1) 18 | y = B.randn(torch.float32, batch_size, n, 1) 19 | return x, y 20 | 21 | 22 | @pytest.fixture() 23 | def target_set(): 24 | batch_size = 2 25 | n = 10 26 | x = B.randn(torch.float32, batch_size, n, 1) 27 | y = B.randn(torch.float32, batch_size, n, 1) 28 | return x, y 29 | -------------------------------------------------------------------------------- /tests/test_augment.py: -------------------------------------------------------------------------------- 1 | import lab as B 2 | import pytest 3 | from plum import NotFoundLookupError 4 | 5 | from .test_architectures import check_prediction 6 | from .util import nps # noqa 7 | 8 | 9 | @pytest.mark.flaky(reruns=3) 10 | def 
def test_convgnp_auxiliary_variable_given_but_not_specified(nps):
    """Passing `aux_t` to a model built without `dim_aux_t` must raise an error."""
    model = nps.construct_convgnp(points_per_unit=4)
    with pytest.raises(AssertionError, match="(?i)did not expect augmentation"):
        model(
            B.randn(nps.dtype, 4, 1, 15),
            B.randn(nps.dtype, 4, 1, 15),
            B.randn(nps.dtype, 4, 1, 10),
            aux_t=B.randn(nps.dtype, 4, 2, 10),
        )


# tests/test_chain.py
import pytest

import neuralprocesses as nps


def test_chain():
    """`Chain` must apply its links from first to last."""
    chain = nps.Chain(lambda x: x - 1)
    assert chain(3) == 2

    # Order matters: subtract-then-square differs from square-then-subtract.
    chain = nps.Chain(lambda x: x - 1, lambda x: x**2)
    assert chain(3) == 4
    chain = nps.Chain(lambda x: x**2, lambda x: x - 1)
    assert chain(3) == 8
# tests/test_discretisation.py
import lab as B

from .util import approx, nps  # noqa


def test_discretisation(nps):
    """The discretisation must cover the inputs, with margin, on a global grid."""
    disc = nps.Discretisation(points_per_unit=33, multiple=5, margin=0.05)

    x1 = B.linspace(nps.dtype, 0.1, 0.5, 10)[None, None, :]
    x2 = B.linspace(nps.dtype, 0.2, 0.6, 15)[None, None, :]

    grid = disc(x1, x2)

    # The grid must extend at least `margin` beyond the data on both sides.
    assert B.min(grid) <= 0.1 - 0.05
    assert B.max(grid) >= 0.6 + 0.05

    # Neighbouring points must be `1 / points_per_unit` apart.
    approx(grid[1:] - grid[:-1], 1 / 33, atol=1e-8)

    # Every grid point must land on an integer multiple of the resolution.
    approx(grid * 33, B.to_numpy(grid * 33).astype(int), atol=1e-8)

    # The overshoot must be split evenly between the two ends.
    overshoot_left = (0.1 - 0.05) - B.min(grid)
    overshoot_right = B.max(grid) - (0.6 + 0.05)
    assert B.abs(overshoot_left - overshoot_right) <= 1 / 33


# tests/test_distribution.py
import lab as B
import scipy.stats as stats

import torch
from neuralprocesses.dist.beta import Beta
from neuralprocesses.dist.gamma import Gamma

from .test_architectures import check_prediction, generate_data
from .util import approx, nps  # noqa


def test_transform_positive(nps):
    """A model with a `positive` transform must predict strictly positive values."""
    model = nps.construct_convgnp(
        dim_x=1,
        dim_y=1,
        points_per_unit=16,
        unet_channels=(8, 16),
        transform="positive",
    )
    xc, yc, xt, yt = generate_data(nps, dim_x=1, dim_y=1)
    # Exponentiate to make the data positive.
    yc = B.exp(yc)
    yt = B.exp(yt)
    pred = model(xc, yc, xt)

    check_prediction(nps, pred, yt)
    # Predictions and samples must respect the positivity constraint.
    assert B.all(pred.mean > 0)
    assert B.all(pred.sample(2) > 0)
def test_transform_bounded(nps):
    """A model with a bounded transform must predict inside the bounds."""
    model = nps.construct_convgnp(
        dim_x=1,
        dim_y=1,
        points_per_unit=16,
        unet_channels=(8, 16),
        transform=(10, 11),
    )
    xc, yc, xt, yt = generate_data(nps, dim_x=1, dim_y=1)
    # Squash the data into the range `(10, 11)` with a shifted sigmoid.
    yc = 10 + 1 / (1 + B.exp(yc))
    yt = 10 + 1 / (1 + B.exp(yt))

    pred = model(xc, yc, xt)
    check_prediction(nps, pred, yt)
    # Predictions and samples must respect the bounds.
    assert B.all(pred.mean > 10) and B.all(pred.mean < 11)
    assert B.all(pred.sample() > 10) and B.all(pred.sample() < 11)


def test_beta_correctness():
    """The beta distribution must agree with the SciPy reference."""
    beta = Beta(B.cast(torch.float64, 0.2), B.cast(torch.float64, 0.8), 0)
    beta_ref = stats.beta(0.2, 0.8)

    sample = beta.sample()
    approx(beta.logpdf(sample), beta_ref.logpdf(sample))
    approx(beta.mean, beta_ref.mean())
    approx(beta.var, beta_ref.var())

    # A dimensionality of `d` must collapse the last `d` sample dimensions.
    for d in range(4):
        beta = Beta(beta.alpha, beta.beta, d)
        assert beta.logpdf(beta.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d]


def test_gamma():
    """The gamma distribution must agree with the SciPy reference."""
    gamma = Gamma(B.cast(torch.float64, 2), B.cast(torch.float64, 0.8), 0)
    gamma_ref = stats.gamma(2, scale=0.8)

    sample = gamma.sample()
    approx(gamma.logpdf(sample), gamma_ref.logpdf(sample))
    approx(gamma.mean, gamma_ref.mean())
    approx(gamma.var, gamma_ref.var())

    # A dimensionality of `d` must collapse the last `d` sample dimensions.
    for d in range(4):
        gamma = Gamma(gamma.k, gamma.scale, d)
        assert gamma.logpdf(gamma.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d]
# tests/test_mask.py
import lab as B
import pytest

from .test_architectures import generate_data
from .util import approx, nps  # noqa


@pytest.mark.flaky(reruns=3)
def test_convgnp_mask(nps):
    """Masking out points must be equivalent to dropping them from the context."""
    model = nps.construct_convgnp(
        num_basis_functions=16,
        points_per_unit=16,
        conv_arch="conv",
        conv_receptive_field=0.5,
        conv_layers=1,
        conv_channels=1,
        # A large margin and `float64`s help with numerical stability.
        margin=2,
        dtype=nps.dtype64,
    )
    xc, yc, xt, yt = generate_data(nps, dtype=nps.dtype64)

    # Predict without the final three points.
    pred = model(xc[:, :, :-3], yc[:, :, :-3], xt)

    # Predict with all points, masking out the final three instead.
    mask = B.to_numpy(B.ones(yc))  # Perform assignment in NumPy.
    mask[:, :, -3:] = 0
    mask = B.cast(B.dtype(yc), mask)
    pred_masked = model(xc, nps.Masked(yc, mask), xt)

    # Both approaches must give the same prediction.
    approx(pred.mean, pred_masked.mean)
    approx(pred.var, pred_masked.var)


@pytest.mark.parametrize("ns", [(10,), (0,), (10, 5), (10, 0), (0, 10), (15, 5, 10)])
@pytest.mark.parametrize("multiple", [1, 2, 3, 5])
def test_mask_contexts(nps, ns, multiple):
    """Merging contexts must pad up to a multiple and mask the padding."""
    x, y = nps.merge_contexts(
        *((B.randn(nps.dtype, 2, 3, n), B.randn(nps.dtype, 2, 4, n)) for n in ns),
        multiple=multiple,
    )

    # The merged outputs must be padded to the next multiple of `multiple`.
    if max(ns) == 0:
        assert B.shape(y.y, 2) == multiple
    else:
        assert B.shape(y.y, 2) == ((max(ns) - 1) // multiple + 1) * multiple

    # Unmasked entries must be non-zero; masked (padding) entries must be zero.
    mask = y.mask == 1  # Convert mask to booleans.
    assert B.all(B.take(B.flatten(y.y), B.flatten(mask)) != 0)
    assert B.all(B.take(B.flatten(y.y), B.flatten(~mask)) == 0)
# tests/test_model.py
import lab as B
import numpy as np
import pytest

# NOTE: the original additionally did `from .test_architectures import
# generate_data`, which was immediately shadowed by the import below;
# the redundant, shadowed import has been removed.
from .util import approx, generate_data, nps  # noqa


@pytest.mark.parametrize("dim_lv", [0, 4])
def test_loglik_batching(nps, dim_lv):
    """Batched log-likelihood estimation must remain finite."""
    xc, yc, xt, yt = generate_data(nps)
    model = nps.construct_gnp(dim_lv=dim_lv, dtype=nps.dtype)
    # Test a high number of samples, a number which also isn't a multiple of the
    # batch size.
    logpdfs = B.mean(
        nps.loglik(model, xc, yc, xt, yt, num_samples=4000, batch_size=128)
    )
    assert np.isfinite(B.to_numpy(logpdfs))


@pytest.mark.parametrize("dim_lv", [0, 4])
@pytest.mark.parametrize("dtype_lik", [None, "dtype32", "dtype64"])
@pytest.mark.parametrize("objective", ["loglik", "elbo"])
@pytest.mark.parametrize("normalise", [True, False])
def test_loglik_dtype_lik(nps, dim_lv, dtype_lik, objective, normalise):
    """The objectives must respect the requested likelihood data type."""
    xc, yc, xt, yt = generate_data(nps)
    dtype_lik = None if dtype_lik is None else getattr(nps, dtype_lik)
    model = nps.construct_gnp(dim_lv=dim_lv, dtype=nps.dtype)

    logpdfs = B.mean(
        getattr(nps, objective)(
            model,
            xc,
            yc,
            xt,
            yt,
            normalise=normalise,
            dtype_lik=dtype_lik,
        )
    )
    if dtype_lik:
        assert B.dtype(logpdfs) == dtype_lik
    assert np.isfinite(B.to_numpy(logpdfs))


def test_ar_predict_without_aggregate(nps):
    """AR prediction without an explicit `AggregateInput` must match the manual route."""
    xc, yc, xt, yt = generate_data(nps, dim_x=2, dim_y=3)
    convcnp = nps.construct_gnp(dim_x=2, dim_yc=(1, 1, 1), dim_yt=3, dtype=nps.dtype)

    # Perform AR by using `AggregateInput` manually.
    _, mean1, var1, ft1, yt1 = nps.ar_predict(
        B.create_random_state(nps.dtype, seed=0),
        convcnp,
        [(xc, yc[:, 0:1, :]), (xc, yc[:, 1:2, :]), (xc, yc[:, 2:3, :])],
        nps.AggregateInput((xt, 0), (xt, 1), (xt, 2)),
    )
    mean1 = B.concat(*mean1, axis=-2)
    var1 = B.concat(*var1, axis=-2)
    ft1 = B.concat(*ft1, axis=-2)
    yt1 = B.concat(*yt1, axis=-2)

    # Let the package work its magic.
    _, mean2, var2, ft2, yt2 = nps.ar_predict(
        B.create_random_state(nps.dtype, seed=0),
        convcnp,
        xc,
        yc,
        xt,
    )

    # The two routes must give identical results.
    approx(mean1, mean2)
    approx(var1, var2)
    approx(ft1, ft2)
    approx(yt1, yt2)
# tests/test_unet.py
import lab as B
import numpy as np
from plum import isinstance

from .util import nps  # noqa


def test_unet_1d(nps):
    """A 1D UNet must map to the right shape with a sane parameter count."""
    unet = nps.UNet(
        dim=1,
        in_channels=3,
        out_channels=4,
        channels=(8, 16, 16, 32, 32, 64),
    )
    n = 2 * 2**unet.num_halving_layers
    z = B.randn(nps.dtype, 2, 3, n)
    assert B.shape(unet(z)) == (2, 4, n)
    assert 40_000 <= nps.num_params(unet) <= 60_000


def test_unet_1d_receptive_field(nps):
    """The UNet's reported receptive field must match the empirical one."""
    unet = nps.UNet(
        dim=1,
        in_channels=1,
        out_channels=1,
        channels=(3, 5, 7, 5, 3),
        activations=(B.identity,) * 5,
    )
    # Run the model once.
    mult = 2**unet.num_halving_layers
    x = B.zeros(nps.dtype, 1, 1, mult)
    unet(x)
    # Set all weights to one so that perturbations propagate maximally.
    if isinstance(nps.dtype, B.TFDType):
        unet.set_weights(0 * np.array(unet.get_weights(), dtype=object) + 1)
    elif isinstance(nps.dtype, B.TorchDType):
        for p in unet.parameters():
            p.data = p.data * 0 + 1
    else:
        raise RuntimeError("I don't know how to set the weights of the model.")
    for offset in range(unet.receptive_field):
        # Create a one-hot perturbation at a varying offset.
        x = B.zeros(1, 1, int(10 * unet.receptive_field / mult) * mult)
        x[0, 0, 5 * unet.receptive_field + offset] = 1
        x = B.cast(nps.dtype, x)
        # The perturbation must affect exactly `receptive_field` outputs.
        n = B.sum(B.cast(B.dtype(x), B.abs(B.flatten(unet(x * 0) - unet(x))) > 0))
        assert n == unet.receptive_field
def test_unet_2d(nps):
    """A 2D UNet must map to the right shape with a sane parameter count."""
    unet = nps.UNet(
        dim=2,
        in_channels=3,
        out_channels=4,
        channels=(8, 16, 16, 32, 32, 64),
    )
    n = 2 * 2**unet.num_halving_layers
    z = B.randn(nps.dtype, 2, 3, n, n)
    assert B.shape(unet(z)) == (2, 4, n, n)
    assert 200_000 <= nps.num_params(unet) <= 300_000


# tests/test_util.py
from .util import generate_data, nps  # noqa


def test_num_params(nps):
    """`num_params` must return a positive integer for an initialised model."""
    model = nps.construct_gnp()
    model(*generate_data(nps)[:3])  # Run forward to initialise parameters.
    num_params = nps.num_params(model)
    assert isinstance(num_params, int)
    assert num_params > 0
# tests/util.py
import socket
from typing import Union

import lab as B
import pytest
from numpy.testing import assert_allclose
from plum import Dispatcher

import neuralprocesses
import tensorflow as tf
import torch

__all__ = ["approx", "nps", "generate_data", "remote_xfail", "remote_skip"]

_dispatch = Dispatcher()

# Stabilise numerics during tests.
B.epsilon = 1e-6
B.cholesky_retry_factor = 1e4


@_dispatch
def approx(a, b, **kw_args):
    """Assert that `a` and `b` are numerically close."""
    assert_allclose(B.to_numpy(a), B.to_numpy(b), **kw_args)


@_dispatch
def approx(a: None, b: None, **kw_args):
    """Two `None`s trivially agree."""
    assert True


@_dispatch
def approx(
    a: Union[neuralprocesses.Parallel, tuple],
    b: Union[neuralprocesses.Parallel, tuple],
    **kw_args,
):
    """Compare collections elementwise and recursively."""
    assert len(a) == len(b)
    for ai, bi in zip(a, b):
        approx(ai, bi, **kw_args)


import neuralprocesses.tensorflow as nps_tf
import neuralprocesses.torch as nps_torch

# Attach default data types to the backend modules for use in the tests.
nps_torch.dtype = torch.float32
nps_torch.dtype32 = torch.float32
nps_torch.dtype64 = torch.float64
nps_tf.dtype = tf.float32
nps_tf.dtype32 = tf.float32
nps_tf.dtype64 = tf.float64


@pytest.fixture(params=[nps_tf, nps_torch], scope="module")
def nps(request):
    """Parametrise tests over the TensorFlow and PyTorch backends."""
    return request.param


def generate_data(
    nps,
    batch_size=4,
    dim_x=1,
    dim_y=1,
    n_context=5,
    n_target=7,
    binary=False,
    dtype=None,
):
    """Generate random context and target sets.

    Args:
        nps (module): Backend package; provides the default `dtype`.
        batch_size (int, optional): Batch size. Defaults to 4.
        dim_x (int, optional): Input dimensionality. Defaults to 1.
        dim_y (int, optional): Output dimensionality. Defaults to 1.
        n_context (int, optional): Number of context points. Defaults to 5.
        n_target (int, optional): Number of target points. Defaults to 7.
        binary (bool, optional): Threshold the outputs to zeros and ones.
            Defaults to `False`.
        dtype (dtype, optional): Data type. Defaults to `nps.dtype`.

    Returns:
        tuple: Tuple `(xc, yc, xt, yt)` of context inputs, context outputs,
            target inputs, and target outputs.
    """
    if dtype is None:
        dtype = nps.dtype
    xc = B.randn(dtype, batch_size, dim_x, n_context)
    yc = B.randn(dtype, batch_size, dim_y, n_context)
    xt = B.randn(dtype, batch_size, dim_x, n_target)
    yt = B.randn(dtype, batch_size, dim_y, n_target)
    if binary:
        yc = B.cast(dtype, yc >= 0)
        yt = B.cast(dtype, yt >= 0)
    return xc, yc, xt, yt
# On the development machine the remote markers are no-ops; everywhere else
# (i.e. on CI) they are the real pytest markers.
if socket.gethostname().lower().startswith("wessel"):

    def _identity(f):
        """Leave the decorated test untouched."""
        return f

    remote_xfail = _identity  #: `xfail` only on CI.
    remote_skip = _identity  #: `skip` only on CI.
else:
    remote_xfail = pytest.mark.xfail  #: `xfail` only on CI.
    remote_skip = pytest.mark.skip  #: `skip` only on CI.