├── .coveragerc ├── .github └── workflows │ ├── build_book.yml │ ├── ci.yml │ ├── publish.yml │ └── style.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENCE.txt ├── MANIFEST.in ├── Makefile ├── README.md ├── benchmark.py ├── docs ├── _config.yml ├── _toc.yml ├── api.rst ├── chapter.md ├── intro.md ├── logo.png └── references.bib ├── lab ├── __init__.py ├── autograd │ ├── __init__.py │ ├── custom.py │ ├── generic.py │ ├── linear_algebra.py │ ├── random.py │ └── shaping.py ├── bvn_cdf │ ├── bvn_cdf.pyx │ ├── tvpack.f │ └── tvpack.h ├── control_flow.py ├── custom.py ├── generic.py ├── jax │ ├── __init__.py │ ├── custom.py │ ├── generic.py │ ├── linear_algebra.py │ ├── random.py │ └── shaping.py ├── linear_algebra.py ├── numpy │ ├── __init__.py │ ├── generic.py │ ├── linear_algebra.py │ ├── random.py │ └── shaping.py ├── random.py ├── shape.py ├── shaping.py ├── tensorflow │ ├── __init__.py │ ├── custom.py │ ├── generic.py │ ├── linear_algebra.py │ ├── random.py │ └── shaping.py ├── torch │ ├── __init__.py │ ├── custom.py │ ├── generic.py │ ├── linear_algebra.py │ ├── random.py │ └── shaping.py ├── types.py └── util.py ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── test_control_flow.py ├── test_custom.py ├── test_generic.py ├── test_linear_algebra.py ├── test_random.py ├── test_shape.py ├── test_shaping.py ├── test_types.py ├── test_util.py └── util.py └── todo.tasks /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = lab/_version.py 3 | 4 | [report] 5 | exclude_lines = 6 | pragma: no cover 7 | pragma: specific no cover.*${PRAGMA_VERSION} 8 | -------------------------------------------------------------------------------- /.github/workflows/build_book.yml: -------------------------------------------------------------------------------- 1 | name: Build Jupyter Book 2 | 3 | on: 4 | # Trigger the workflow on push to main branch. 5 | push: 6 | branches: 7 | - main 8 | 9 | # This job installs dependencies, build the book, and pushes it to `gh-pages`. 10 | jobs: 11 | build-and-deploy-book: 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: [3.8] 17 | steps: 18 | - uses: actions/checkout@v2 19 | 20 | # Install dependencies. 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | sudo apt-get install gfortran 28 | pip install --upgrade pip setuptools 'setuptools_scm[toml]' setuptools_scm_git_archive numpy Cython 29 | python setup.py --version 30 | pip install --no-cache-dir -U -r requirements.txt | cat 31 | pip install --upgrade numpy 32 | 33 | # Build the book. 34 | - name: Build 35 | run: | 36 | jupyter-book build docs 37 | 38 | # Deploy the book's HTML to the branch `gh-pages`. 
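# The step below publishes the HTML that `jupyter-book build docs` wrote to
# `docs/_build/html`; it only runs on pushes to `main`, the sole trigger configured above.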
39 | - name: Deploy to GitHub Pages 40 | uses: peaceiris/actions-gh-pages@v3.6.1 41 | with: 42 | github_token: ${{ secrets.GITHUB_TOKEN }} 43 | publish_dir: docs/_build/html 44 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: [3.8, 3.9, "3.10", "3.11"] 13 | steps: 14 | - uses: actions/checkout@v2 15 | with: 16 | fetch-depth: 0 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | sudo apt-get install gfortran 24 | # JAX isn't yet NumPy 2 compatible. 25 | pip install --upgrade pip setuptools 'setuptools_scm[toml]' setuptools_scm_git_archive numpy Cython 26 | python setup.py --version 27 | LAB_BUILD=1 pip install --no-cache-dir -U -r requirements.txt | cat 28 | - name: Test 29 | run: | 30 | JAX_ENABLE_X64=1 pytest -v --cov=lab --cov-report term-missing 31 | coveralls --service=github 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | COVERALLS_FLAG_NAME: ${{ matrix.test-name }} 35 | COVERALLS_PARALLEL: true 36 | 37 | finish: 38 | name: Finish Coveralls 39 | needs: test 40 | runs-on: ubuntu-latest 41 | steps: 42 | - name: Finish Coveralls 43 | uses: coverallsapp/github-action@v1 44 | with: 45 | github-token: ${{ secrets.github_token }} 46 | parallel-finished: true 47 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python package using Twine when a release is 2 | # created. For more information see the following link: 3 | # https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 4 | 5 | name: Publish to PyPI 6 | 7 | on: 8 | release: 9 | types: [published] 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | # Make sure tags are fetched so we can get a version. 
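# `setuptools_scm` (installed below) derives the package version from git tags,
# so a shallow clone without tags would produce an incorrect version.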
19 | - run: | 20 | git fetch --prune --unshallow --tags 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: '3.x' 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -U setuptools 'setuptools_scm[toml]' setuptools_scm_git_archive wheel twine 30 | pip install -U numpy Cython 31 | - name: Build and publish 32 | env: 33 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 34 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 35 | 36 | run: | 37 | python setup.py sdist 38 | twine upload dist/* 39 | -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | name: Code style 2 | on: 3 | - push 4 | - pull_request 5 | 6 | jobs: 7 | check: 8 | name: Check style 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.9 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: 3.9 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | python -m pip install pre-commit 20 | pre-commit install 21 | - name: Check code style 22 | run: pre-commit run --all-files 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Autogenerated files 2 | lab/_version.py 3 | 4 | # Byte-compiled file 5 | *.pyc 6 | 7 | # Virtual environments 8 | venv 9 | 10 | # Packaging 11 | *.egg-info 12 | dist 13 | pip-wheel-metadata 14 | 15 | # Documentation and coverage 16 | docs/_build 17 | docs/_static 18 | docs/source 19 | docs/readme.rst 20 | cover 21 | 22 | # Other 23 | .DS_Store 24 | *.swp 25 | 26 | # Cython build files 27 | *.html 28 | *.c 29 | *.so 30 | *.o 31 | 32 | # IDE 33 | .vscode 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.9 3 | repos: 4 | - repo: https://github.com/psf/black 5 | rev: 23.7.0 6 | hooks: 7 | - id: black 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.12.0 10 | hooks: 11 | - id: isort 12 | args: ["--profile", "black"] 13 | # - repo: https://github.com/pycqa/flake8 14 | # rev: 5.0.4 15 | # hooks: 16 | # - id: flake8 17 | # args: ["--max-line-length=88", "--extend-ignore=E203,F811"] 18 | # additional_dependencies: 19 | # - flake8-bugbear>=22.12 20 | # - flake8-noqa>=1.3 21 | -------------------------------------------------------------------------------- /LICENCE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Wessel Bruinsma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyproject.toml 2 | include lab/bvn_cdf/tvpack.h 3 | include lab/bvn_cdf/tvpack.f -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | 3 | PACKAGE := lab 4 | 5 | test: 6 | pre-commit run --all-files 7 | PRAGMA_VERSION=`python -c "import sys; print('.'.join(map(str, sys.version_info[:2])))"` \ 8 | pytest tests -v --cov=$(PACKAGE) --cov-report html:cover --cov-report term-missing 9 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import autograd.numpy as np 4 | 5 | import lab as B 6 | 7 | n = 20 8 | m = 1 9 | 10 | t = np.float64 11 | eps = B.cast(t, B.epsilon) 12 | 13 | 14 | def f1(x): 15 | dists2 = (x - B.transpose(x)) ** 2 16 | K = B.exp(-0.5 * dists2) 17 | K = K + B.epsilon * B.eye(t, n) 18 | L = B.cholesky(K) 19 | return B.matmul(L, B.ones(t, n, m)) 20 | 21 | 22 | def f2(x): 23 | dists2 = (x - np.transpose(x)) ** 2 24 | K = np.exp(-0.5 * dists2) 25 | K = K + B.epsilon * np.eye(n, dtype=t) 26 | L = np.linalg.cholesky(K) 27 | return np.matmul(L, np.ones((n, m))) 28 | 29 | 30 | # Perform computation once. 31 | x = np.linspace(0, 1, n, dtype=t)[:, None] 32 | f1(x) 33 | f2(x) 34 | 35 | its = 10000 36 | 37 | s = time() 38 | for _ in range(its): 39 | z = f2(x) 40 | us_native = (time() - s) / its * 1e6 41 | 42 | s = time() 43 | for _ in range(its): 44 | z = f1(x) 45 | us_lab = (time() - s) / its * 1e6 46 | 47 | print( 48 | "Overhead: {:.1f} us / {:.1f} %" 49 | "".format(us_lab - us_native, 100 * (us_lab / us_native - 1)) 50 | ) 51 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | title: Linear Algebra Backends 3 | author: Wessel Bruinsma 4 | copyright: "2023" 5 | logo: logo.png 6 | 7 | # Enable definition lists. 8 | parse: 9 | myst_enable_extensions: 10 | - deflist 11 | 12 | # Force re-execution of notebooks on each build. 13 | # See https://jupyterbook.org/content/execute.html 14 | execute: 15 | execute_notebooks: force 16 | 17 | # Define the name of the latex output file for PDF builds. 18 | latex: 19 | latex_documents: 20 | targetname: book.tex 21 | 22 | # Load AutoDoc extension. 23 | sphinx: 24 | extra_extensions: 25 | - 'sphinx.ext.autodoc' 26 | - 'sphinx.ext.napoleon' 27 | - 'sphinx.ext.viewcode' 28 | 29 | # Add a BiBTeX file so that we can create citations. 30 | bibtex_bibfiles: 31 | - references.bib 32 | 33 | # Information about where the book exists on the web. 
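# These repository settings are used by the GitHub buttons enabled under `html` below.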
34 | repository: 35 | url: https://github.com/wesselb/lab 36 | path_to_book: docs 37 | branch: master 38 | 39 | # Add GitHub buttons to your book. 40 | html: 41 | use_issues_button: true 42 | use_repository_button: true 43 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | # Learn more at https://jupyterbook.org/customize/toc.html 3 | 4 | format: jb-book 5 | root: intro 6 | chapters: 7 | - file: chapter 8 | - file: api 9 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Application Programming Interface 2 | ================================= 3 | 4 | Generic 5 | ------- 6 | .. automodule:: lab.generic 7 | :members: 8 | 9 | Linear Algebra 10 | -------------- 11 | .. automodule:: lab.linear_algebra 12 | :members: 13 | 14 | Random 15 | ------ 16 | .. automodule:: lab.random 17 | :members: 18 | 19 | Shaping 20 | ------- 21 | .. automodule:: lab.shaping 22 | :members: 23 | 24 | Control Flow 25 | ------------ 26 | .. automodule:: lab.control_flow 27 | :members: 28 | 29 | Types 30 | ----- 31 | .. automodule:: lab.types 32 | :members: 33 | 34 | Shape 35 | ----- 36 | .. automodule:: lab.shape 37 | :members: 38 | 39 | Util 40 | ---- 41 | .. automodule:: lab.util 42 | :members: 43 | -------------------------------------------------------------------------------- /docs/chapter.md: -------------------------------------------------------------------------------- 1 | # Chapter 2 | 3 | Coming soon. 4 | -------------------------------------------------------------------------------- /docs/intro.md: -------------------------------------------------------------------------------- 1 | # Linear Algebra Backends 2 | 3 | Welcome to the package! 4 | Please click next. 5 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesselb/lab/5d1f9634128fd43262b21746bc3a237eac6db165/docs/logo.png -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | --- 2 | --- -------------------------------------------------------------------------------- /lab/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from plum import Dispatcher 4 | 5 | B = sys.modules[__name__] # Allow both import styles. 6 | dispatch = Dispatcher() # This dispatch namespace will be used everywhere. 7 | 8 | from .control_flow import * 9 | from .generic import * 10 | from .linear_algebra import * 11 | from .numpy import * 12 | from .random import * 13 | from .shaping import * 14 | from .types import * 15 | 16 | # Fix namespace issues with `B.bvn_cdf` simply by setting it explicitly. 17 | B.bvn_cdf = B.generic.bvn_cdf 18 | -------------------------------------------------------------------------------- /lab/autograd/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .. import * 3 | from .. 
import dispatch as dispatch_original 4 | from ..shape import dispatch_unwrap_dimensions 5 | from ..types import AGNumeric, NPNumeric, Number 6 | 7 | dispatch = dispatch_unwrap_dimensions(dispatch_original) 8 | 9 | from typing import Union 10 | 11 | Numeric = Union[Number, NPNumeric, AGNumeric] 12 | 13 | import autograd # Load `autograd` to load all new types. 14 | from plum import clear_all_cache as _clear_all_cache 15 | 16 | # noinspection PyUnresolvedReferences 17 | from .generic import * 18 | from .linear_algebra import * 19 | from .random import * 20 | from .shaping import * 21 | 22 | # Clear cache to make sure that all newly loaded types are available. 23 | _clear_all_cache() 24 | 25 | # Alias to actual module. 26 | sys.modules[__name__] = B 27 | -------------------------------------------------------------------------------- /lab/autograd/custom.py: -------------------------------------------------------------------------------- 1 | from autograd.extend import defvjp_argnums, primitive 2 | from plum import Dispatcher 3 | 4 | from ..util import as_tuple 5 | 6 | __all__ = ["autograd_register"] 7 | 8 | _dispatch = Dispatcher() 9 | 10 | 11 | def autograd_register(f, s_f): 12 | """Register a function and its sensitivity for AutoGrad. 13 | 14 | Args: 15 | f (function): Function to register. 16 | s_f (function): Sensitivity of `f`. 17 | 18 | Returns: 19 | function: AutoGrad primitive. 20 | """ 21 | # Create a primitive for `f`. 22 | f_primitive = primitive(f) 23 | 24 | # Register the sensitivity. 25 | def vjp_argnums(nums, y, args, kw_args): 26 | def vjp(s_y): 27 | grads = as_tuple(s_f(s_y, y, *args, **kw_args)) 28 | return tuple([grads[i] for i in nums]) 29 | 30 | return vjp 31 | 32 | defvjp_argnums(f_primitive, vjp_argnums) 33 | 34 | # Return the AutoGrad primitive. 35 | return f_primitive 36 | -------------------------------------------------------------------------------- /lab/autograd/generic.py: -------------------------------------------------------------------------------- 1 | from types import FunctionType 2 | from typing import Union 3 | 4 | import autograd.numpy as anp 5 | import autograd.scipy.special as asps 6 | 7 | from ..custom import bvn_cdf, s_bvn_cdf 8 | from ..types import AGDType, AGNumeric, AGRandomState, Int 9 | from . import Numeric, dispatch 10 | from .custom import autograd_register 11 | 12 | __all__ = [] 13 | 14 | 15 | @dispatch 16 | def isabstract(a: Numeric): 17 | return False 18 | 19 | 20 | @dispatch 21 | def _jit_run( 22 | f: FunctionType, 23 | compilation_cache: dict, 24 | jit_kw_args: dict, 25 | *args: Union[Numeric, AGRandomState], 26 | **kw_args 27 | ): 28 | # There is no JIT for AutoGrad, so just run the function. 29 | return f(*args, **kw_args) 30 | 31 | 32 | @dispatch 33 | def isnan(a: Numeric): 34 | return anp.isnan(a) 35 | 36 | 37 | @dispatch 38 | def real(a: Numeric): 39 | return anp.real(a) 40 | 41 | 42 | @dispatch 43 | def imag(a: Numeric): 44 | return anp.imag(a) 45 | 46 | 47 | @dispatch 48 | def device(a: AGNumeric): 49 | return "cpu" 50 | 51 | 52 | @dispatch 53 | def to_active_device(a: AGNumeric): 54 | return a 55 | 56 | 57 | @dispatch 58 | def cast(dtype: AGDType, a: AGNumeric): 59 | # AutoGrad does not respect the `copy` flag, so check that manually. 
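# If the data type already matches, return `a` unchanged so no copy is made;
# otherwise `astype` allocates a new array of the requested type.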
60 | if dtype == a.dtype: 61 | return a 62 | else: 63 | return a.astype(dtype) 64 | 65 | 66 | @dispatch 67 | def identity(a: Numeric): 68 | return 1 * a 69 | 70 | 71 | @dispatch 72 | def round(a: Numeric): 73 | return anp.round(a) 74 | 75 | 76 | @dispatch 77 | def floor(a: Numeric): 78 | return anp.floor(a) 79 | 80 | 81 | @dispatch 82 | def ceil(a: Numeric): 83 | return anp.ceil(a) 84 | 85 | 86 | @dispatch 87 | def negative(a: Numeric): 88 | return anp.negative(a) 89 | 90 | 91 | @dispatch 92 | def abs(a: Numeric): 93 | return anp.abs(a) 94 | 95 | 96 | @dispatch 97 | def sign(a: Numeric): 98 | return anp.sign(a) 99 | 100 | 101 | @dispatch 102 | def sqrt(a: Numeric): 103 | return anp.sqrt(a) 104 | 105 | 106 | @dispatch 107 | def exp(a: Numeric): 108 | return anp.exp(a) 109 | 110 | 111 | @dispatch 112 | def log(a: Numeric): 113 | return anp.log(a) 114 | 115 | 116 | @dispatch 117 | def log1p(a: Numeric): 118 | return anp.log1p(a) 119 | 120 | 121 | @dispatch 122 | def sin(a: Numeric): 123 | return anp.sin(a) 124 | 125 | 126 | @dispatch 127 | def arcsin(a: Numeric): 128 | return anp.arcsin(a) 129 | 130 | 131 | @dispatch 132 | def cos(a: Numeric): 133 | return anp.cos(a) 134 | 135 | 136 | @dispatch 137 | def arccos(a: Numeric): 138 | return anp.arccos(a) 139 | 140 | 141 | @dispatch 142 | def tan(a: Numeric): 143 | return anp.tan(a) 144 | 145 | 146 | @dispatch 147 | def arctan(a: Numeric): 148 | return anp.arctan(a) 149 | 150 | 151 | @dispatch 152 | def tanh(a: Numeric): 153 | return anp.tanh(a) 154 | 155 | 156 | @dispatch 157 | def arctanh(a: Numeric): 158 | return anp.arctanh(a) 159 | 160 | 161 | @dispatch 162 | def loggamma(a: Numeric): 163 | return asps.gammaln(a) 164 | 165 | 166 | @dispatch 167 | def erf(a: Numeric): 168 | return asps.erf(a) 169 | 170 | 171 | @dispatch 172 | def add(a: Numeric, b: Numeric): 173 | return anp.add(a, b) 174 | 175 | 176 | @dispatch 177 | def subtract(a: Numeric, b: Numeric): 178 | return anp.subtract(a, b) 179 | 180 | 181 | @dispatch 182 | def multiply(a: Numeric, b: Numeric): 183 | return anp.multiply(a, b) 184 | 185 | 186 | @dispatch 187 | def divide(a: Numeric, b: Numeric): 188 | return anp.divide(a, b) 189 | 190 | 191 | @dispatch 192 | def power(a: Numeric, b: Numeric): 193 | return anp.power(a, b) 194 | 195 | 196 | @dispatch 197 | def minimum(a: Numeric, b: Numeric): 198 | return anp.minimum(a, b) 199 | 200 | 201 | @dispatch 202 | def maximum(a: Numeric, b: Numeric): 203 | return anp.maximum(a, b) 204 | 205 | 206 | @dispatch 207 | def min(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 208 | return anp.min(a, axis=axis, keepdims=not squeeze) 209 | 210 | 211 | @dispatch 212 | def argmin(a: Numeric, axis: Union[Int, None] = None): 213 | return anp.argmin(a, axis=axis) 214 | 215 | 216 | @dispatch 217 | def max(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 218 | return anp.max(a, axis=axis, keepdims=not squeeze) 219 | 220 | 221 | @dispatch 222 | def argmax(a: Numeric, axis: Union[Int, None] = None): 223 | return anp.argmax(a, axis=axis) 224 | 225 | 226 | @dispatch 227 | def sum(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 228 | return anp.sum(a, axis=axis, keepdims=not squeeze) 229 | 230 | 231 | @dispatch 232 | def prod(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 233 | return anp.prod(a, axis=axis, keepdims=not squeeze) 234 | 235 | 236 | @dispatch 237 | def mean(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 238 | return anp.mean(a, axis=axis, keepdims=not 
squeeze) 239 | 240 | 241 | @dispatch 242 | def std(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 243 | return anp.std(a, axis=axis, ddof=0, keepdims=not squeeze) 244 | 245 | 246 | @dispatch 247 | def all(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 248 | return anp.all(a, axis=axis, keepdims=not squeeze) 249 | 250 | 251 | @dispatch 252 | def any(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 253 | return anp.any(a, axis=axis, keepdims=not squeeze) 254 | 255 | 256 | @dispatch 257 | def lt(a: Numeric, b: Numeric): 258 | return anp.less(a, b) 259 | 260 | 261 | @dispatch 262 | def le(a: Numeric, b: Numeric): 263 | return anp.less_equal(a, b) 264 | 265 | 266 | @dispatch 267 | def gt(a: Numeric, b: Numeric): 268 | return anp.greater(a, b) 269 | 270 | 271 | @dispatch 272 | def ge(a: Numeric, b: Numeric): 273 | return anp.greater_equal(a, b) 274 | 275 | 276 | @dispatch 277 | def eq(a: Numeric, b: Numeric): 278 | return anp.equal(a, b) 279 | 280 | 281 | @dispatch 282 | def ne(a: Numeric, b: Numeric): 283 | return anp.not_equal(a, b) 284 | 285 | 286 | _bvn_cdf = autograd_register(bvn_cdf, s_bvn_cdf) 287 | 288 | 289 | @dispatch 290 | def bvn_cdf(a: Numeric, b: Numeric, c: Numeric): 291 | return _bvn_cdf(a, b, c) 292 | 293 | 294 | @dispatch 295 | def where(condition: Numeric, a: Numeric, b: Numeric): 296 | return anp.where(condition, a, b) 297 | 298 | 299 | @dispatch 300 | def sort(a: Numeric, axis: Int = -1, descending: bool = False): 301 | if descending: 302 | return -anp.sort(-a, axis=axis) 303 | else: 304 | return anp.sort(a, axis=axis) 305 | 306 | 307 | @dispatch 308 | def argsort(a: Numeric, axis: Int = -1, descending: bool = False): 309 | if descending: 310 | return anp.argsort(-a, axis=axis) 311 | else: 312 | return anp.argsort(a, axis=axis) 313 | 314 | 315 | @dispatch 316 | def quantile(a: Numeric, q: Numeric, axis: Union[Int, None] = None): # pragma: no cover 317 | raise NotImplementedError("Function `quantile` is not available for AutoGrad.") 318 | -------------------------------------------------------------------------------- /lab/autograd/linear_algebra.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Union 3 | 4 | import autograd.numpy as anp 5 | import autograd.scipy.linalg as asla 6 | import opt_einsum as oe 7 | 8 | from ..custom import expm, logm, s_expm, s_logm, s_toeplitz_solve, toeplitz_solve 9 | from ..linear_algebra import _default_perm 10 | from ..types import Int 11 | from ..util import batch_computation, resolve_axis 12 | from . import B, Numeric, dispatch 13 | from .custom import autograd_register 14 | 15 | __all__ = [] 16 | log = logging.getLogger(__name__) 17 | 18 | 19 | @dispatch 20 | def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False): 21 | a = transpose(a) if tr_a else a 22 | b = transpose(b) if tr_b else b 23 | return anp.matmul(a, b) 24 | 25 | 26 | @dispatch 27 | def einsum(equation: str, *elements: Numeric): 28 | return oe.contract(equation, *elements, backend="autograd") 29 | 30 | 31 | @dispatch 32 | def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None): 33 | # Correctly handle special cases. 
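# A rank-0 input is returned as is, and a rank-1 input without an explicit
# permutation is promoted to a row vector, e.g. shape (n,) becomes (1, n).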
34 | rank_a = B.rank(a) 35 | if rank_a == 0: 36 | return a 37 | elif rank_a == 1 and perm is None: 38 | return a[None, :] 39 | 40 | if perm is None: 41 | perm = _default_perm(a) 42 | return anp.transpose(a, axes=perm) 43 | 44 | 45 | @dispatch 46 | def trace(a: Numeric, axis1: Int = -2, axis2: Int = -1): 47 | axis1 = resolve_axis(a, axis1) 48 | axis2 = resolve_axis(a, axis2) 49 | 50 | if axis1 == axis2: 51 | raise ValueError("Keyword arguments `axis1` and `axis2` cannot be the same.") 52 | 53 | # AutoGrad does not support the `axis1` and `axis2` arguments... 54 | 55 | # Order the axis as `axis1 < axis`. 56 | if axis2 < axis1: 57 | axis1, axis2 = axis2, axis1 58 | 59 | # Bring the trace axes forward. 60 | if (axis1, axis2) != (0, 1): 61 | perm = [axis1, axis2] 62 | perm += [i for i in range(B.rank(a)) if i != axis1 and i != axis2] 63 | a = anp.transpose(a, axes=perm) 64 | 65 | return anp.trace(a) 66 | 67 | 68 | @dispatch 69 | def svd(a: Numeric, compute_uv: bool = True): 70 | res = anp.linalg.svd(a, full_matrices=False, compute_uv=compute_uv) 71 | return (res[0], res[1], anp.conj(transpose(res[2]))) if compute_uv else res 72 | 73 | 74 | @dispatch 75 | def eig(a: Numeric, compute_eigvecs: bool = True): # pragma: no cover 76 | raise NotImplementedError("Function `quantile` is not available for AutoGrad.") 77 | 78 | 79 | @dispatch 80 | def solve(a: Numeric, b: Numeric): 81 | return anp.linalg.solve(a, b) 82 | 83 | 84 | @dispatch 85 | def inv(a: Numeric): 86 | return anp.linalg.inv(a) 87 | 88 | 89 | @dispatch 90 | def det(a: Numeric): 91 | return anp.linalg.det(a) 92 | 93 | 94 | @dispatch 95 | def logdet(a: Numeric): 96 | return anp.linalg.slogdet(a)[1] 97 | 98 | 99 | _expm = autograd_register(expm, s_expm) 100 | 101 | 102 | @dispatch 103 | def expm(a: Numeric): 104 | return _expm(a) 105 | 106 | 107 | _logm = autograd_register(logm, s_logm) 108 | 109 | 110 | @dispatch 111 | def logm(a: Numeric): 112 | return _logm(a) 113 | 114 | 115 | @dispatch 116 | def _cholesky(a: Numeric): 117 | return anp.linalg.cholesky(a) 118 | 119 | 120 | @dispatch 121 | def cholesky_solve(a: Numeric, b: Numeric): 122 | return triangular_solve(transpose(a), triangular_solve(a, b), lower_a=False) 123 | 124 | 125 | @dispatch 126 | def triangular_solve(a: Numeric, b: Numeric, lower_a: bool = True): 127 | def _triangular_solve(a_, b_): 128 | return asla.solve_triangular( 129 | a_, b_, trans="N", lower=lower_a, check_finite=False 130 | ) 131 | 132 | return batch_computation(_triangular_solve, (a, b), (2, 2)) 133 | 134 | 135 | _toeplitz_solve = autograd_register(toeplitz_solve, s_toeplitz_solve) 136 | 137 | 138 | @dispatch 139 | def toeplitz_solve(a: Numeric, b: Numeric, c: Numeric): 140 | return _toeplitz_solve(a, b, c) 141 | -------------------------------------------------------------------------------- /lab/autograd/random.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as anp 2 | import numpy as np 3 | 4 | from ..types import AGNumeric, AGRandomState, Int 5 | from . import B, dispatch 6 | 7 | __all__ = [] 8 | 9 | 10 | @dispatch 11 | def randcat(state: AGRandomState, p: AGNumeric, n: Int): 12 | # Probabilities must sum to one. 13 | p = p / anp.sum(p, axis=-1, keepdims=True) 14 | # Perform sampling routine. 15 | cdf = anp.cumsum(p, axis=-1) 16 | u = state.rand(n, *p.shape[:-1]) 17 | inds = anp.sum(u[..., None] < cdf[None], axis=-1) - 1 18 | # Be sure to return the right data type. 
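# The inverse-CDF sampling above yields plain integer indices, so cast them to
# the integer data type associated with `p`.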
19 | return state, B.cast(B.dtype_int(p), inds) 20 | 21 | 22 | @dispatch 23 | def randcat(p: AGNumeric, *shape: Int): 24 | return randcat(np.random.random.__self__, p, *shape)[1] 25 | -------------------------------------------------------------------------------- /lab/autograd/shaping.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import autograd.numpy as anp 4 | 5 | from ..types import Int 6 | from . import Numeric, dispatch 7 | 8 | __all__ = [] 9 | 10 | 11 | @dispatch 12 | def length(a: Numeric): 13 | return anp.size(a) 14 | 15 | 16 | @dispatch 17 | def _expand_dims(a: Numeric, axis: Int = 0): 18 | return anp.expand_dims(a, axis=axis) 19 | 20 | 21 | @dispatch 22 | def squeeze(a: Numeric, axis: Union[Int, None] = None): 23 | return anp.squeeze(a, axis=axis) 24 | 25 | 26 | @dispatch 27 | def broadcast_to(a: Numeric, *shape: Int): 28 | return anp.broadcast_to(a, shape) 29 | 30 | 31 | @dispatch 32 | def diag(a: Numeric): 33 | return anp.diag(a) 34 | 35 | 36 | @dispatch 37 | def diag_extract(a: Numeric): 38 | return anp.diagonal(a, axis1=-2, axis2=-1) 39 | 40 | 41 | @dispatch 42 | def stack(*elements: Numeric, axis: Int = 0): 43 | return anp.stack(elements, axis=axis) 44 | 45 | 46 | @dispatch 47 | def _unstack(a: Numeric, axis: Int = 0): 48 | out = anp.split(a, anp.arange(1, a.shape[axis]), axis) 49 | return [x.squeeze(axis=axis) for x in out] 50 | 51 | 52 | @dispatch 53 | def reshape(a: Numeric, *shape: Int): 54 | return anp.reshape(a, shape) 55 | 56 | 57 | @dispatch 58 | def concat(*elements: Numeric, axis: Int = 0): 59 | return anp.concatenate(elements, axis=axis) 60 | 61 | 62 | @dispatch 63 | def tile(a: Numeric, *repeats: Int): 64 | return anp.tile(a, repeats) 65 | -------------------------------------------------------------------------------- /lab/bvn_cdf/bvn_cdf.pyx: -------------------------------------------------------------------------------- 1 | # Apparently this future import is needed to fix the NumPy import. 2 | from __future__ import absolute_import 3 | 4 | cimport numpy as np 5 | 6 | import numpy as np 7 | 8 | cimport cython 9 | 10 | from cython.parallel import prange 11 | 12 | 13 | cdef extern from "math.h" nogil: 14 | double log(double x) 15 | double exp(double x) 16 | double sqrt(double x) 17 | 18 | cdef extern from "./tvpack.h" nogil: 19 | double phid_(double* x) 20 | double bvnd_(double* x, double* y, double* rho) 21 | 22 | 23 | @cython.boundscheck(False) 24 | @cython.wraparound(False) 25 | def bvn_cdf(np.ndarray[np.float64_t, ndim=1] x, 26 | np.ndarray[np.float64_t, ndim=1] y, 27 | np.ndarray[np.float64_t, ndim=1] rho): 28 | cdef int n = x.shape[0] 29 | 30 | # Initialise output. 31 | cdef np.ndarray[np.float64_t, ndim=1] out = np.empty([n], dtype=np.float64) 32 | 33 | # Define views for access in the parallel loop. 
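# Typed memoryviews can be indexed without the GIL, which the `prange` loop below
# requires; `neg_x` and `neg_y` are per-iteration doubles whose addresses can be
# passed to the Fortran routine `bvnd_`.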
34 | cdef np.float64_t [:] x_view = x 35 | cdef np.float64_t [:] y_view = y 36 | cdef np.float64_t [:] rho_view = rho 37 | cdef np.float64_t [:] out_view = out 38 | 39 | cdef int i 40 | cdef double neg_x 41 | cdef double neg_y 42 | 43 | for i in prange(n, nogil=True): 44 | neg_x = -x_view[i] 45 | neg_y = -y_view[i] 46 | out_view[i] = bvnd_(&neg_x, &neg_y, &rho_view[i]) 47 | 48 | return out 49 | 50 | 51 | cdef double uvn_pdf(double x) nogil: 52 | cdef double pi = 3.141592653589793 53 | return exp(-0.5 * x * x) / sqrt(2 * pi) 54 | 55 | 56 | cdef double bvn_pdf(double x, double y, double rho) nogil: 57 | cdef double pi = 3.141592653589793 58 | cdef double determinant = 2 * pi * sqrt(1 - rho * rho) 59 | cdef double quad_form = (x * x - 2 * rho * x * y + y * y) / \ 60 | (2 * (1 - rho * rho)) 61 | return exp(-quad_form) / determinant 62 | 63 | 64 | @cython.boundscheck(False) 65 | @cython.wraparound(False) 66 | def s_bvn_cdf(np.ndarray[np.float64_t, ndim=1] s_out, 67 | np.ndarray[np.float64_t, ndim=1] out, 68 | np.ndarray[np.float64_t, ndim=1] x, 69 | np.ndarray[np.float64_t, ndim=1] y, 70 | np.ndarray[np.float64_t, ndim=1] rho): 71 | cdef int n = x.shape[0]; 72 | 73 | # Initialise output. 74 | cdef np.ndarray[np.float64_t, ndim=1] s_x = np.empty([n], dtype=np.float64) 75 | cdef np.ndarray[np.float64_t, ndim=1] s_y = np.empty([n], dtype=np.float64) 76 | cdef np.ndarray[np.float64_t, ndim=1] s_rho = np.empty([n], dtype=np.float64) 77 | 78 | # Define views for access in the parallel loop. 79 | cdef np.float64_t [:] s_out_view = s_out 80 | cdef np.float64_t [:] out_view = out 81 | cdef np.float64_t [:] x_view = x 82 | cdef np.float64_t [:] y_view = y 83 | cdef np.float64_t [:] rho_view = rho 84 | 85 | cdef np.float64_t [:] s_x_view = s_x 86 | cdef np.float64_t [:] s_y_view = s_y 87 | cdef np.float64_t [:] s_rho_view = s_rho 88 | 89 | cdef int i 90 | cdef double q 91 | cdef double pdf 92 | cdef double x_normalised 93 | cdef double y_normalised 94 | 95 | for i in prange(n, nogil=True): 96 | q = sqrt(1 - rho_view[i] * rho_view[i]) 97 | pdf = bvn_pdf(x_view[i], y_view[i], rho_view[i]) 98 | x_normalised = (x_view[i] - rho_view[i] * y_view[i]) / q 99 | y_normalised = (y_view[i] - rho_view[i] * x_view[i]) / q 100 | 101 | s_x_view[i] = s_out_view[i] * uvn_pdf(x_view[i]) * phid_(&y_normalised) 102 | s_y_view[i] = s_out_view[i] * uvn_pdf(y_view[i]) * phid_(&x_normalised) 103 | s_rho_view[i] = s_out_view[i] * pdf 104 | out_view[i] = bvn_pdf(x[i], x[i], 0.5) 105 | 106 | return s_x, s_y, s_rho 107 | 108 | -------------------------------------------------------------------------------- /lab/bvn_cdf/tvpack.h: -------------------------------------------------------------------------------- 1 | extern double phid_(double* x); 2 | extern double bvnd_(double* x, double* y, double* rho); 3 | -------------------------------------------------------------------------------- /lab/control_flow.py: -------------------------------------------------------------------------------- 1 | __all__ = ["control_flow", "ControlFlowCache"] 2 | 3 | 4 | class ControlFlow: 5 | """Control flow. 6 | 7 | Attributes: 8 | caching (bool): Are we currently caching? 9 | use_cache (bool): Are we currently using a cache? 10 | """ 11 | 12 | def __init__(self): 13 | self._cache = None 14 | self._counter = -1 15 | self.caching = False 16 | self.use_cache = False 17 | 18 | def start_caching(self, cache): 19 | """Start caching. 20 | 21 | Args: 22 | cache (:class:`.control_flow.ControlFlowCache`): Cache to populate. 
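Once populated, the cache can be replayed via :meth:`start_using_cache`.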
23 | """ 24 | self._cache = cache 25 | self._counter = -1 26 | self.caching = True 27 | 28 | def stop_caching(self): 29 | """Stop caching.""" 30 | self.caching = False 31 | 32 | def start_using_cache(self, cache): 33 | """Start using a cache. 34 | 35 | Args: 36 | cache (:class:`.control_flow.ControlFlowCache`): Cache to use. 37 | """ 38 | self._cache = cache 39 | self._counter = -1 40 | self.use_cache = True 41 | 42 | def stop_using_cache(self): 43 | """Stop using a cache.""" 44 | self.use_cache = False 45 | 46 | def get_outcome(self, name): 47 | """Get an outcome. 48 | 49 | Args: 50 | name (str): Name of the operation. 51 | """ 52 | if self.use_cache: 53 | self._counter += 1 54 | return self._cache.outcomes[name, self._counter] 55 | else: 56 | raise RuntimeError("Can only get an outcome when a cache is used.") 57 | 58 | def set_outcome(self, name, outcome, type=None): 59 | """Set an outcome. 60 | 61 | Args: 62 | name (str): Name of the operation. 63 | outcome (object): Outcome. 64 | type (type, optional): Type to convert the outcome to. 65 | """ 66 | if self.caching: 67 | self._counter += 1 68 | if type: 69 | outcome = type(outcome) 70 | self._cache.outcomes[name, self._counter] = outcome 71 | 72 | 73 | control_flow = ControlFlow() 74 | 75 | 76 | class ControlFlowCache: 77 | """A control flow cache. 78 | 79 | Attributes: 80 | populated (bool): Is the cache already populated? 81 | outcomes (dict): Outcomes. 82 | """ 83 | 84 | def __init__(self): 85 | self.populated = False 86 | self.outcomes = {} 87 | 88 | def __enter__(self): 89 | if self.populated: 90 | control_flow.start_using_cache(self) 91 | else: 92 | control_flow.start_caching(self) 93 | return self 94 | 95 | def __exit__(self, exc_type, exc_val, exc_tb): 96 | if self.populated: 97 | control_flow.stop_using_cache() 98 | else: 99 | self.populated = True 100 | control_flow.stop_caching() 101 | 102 | def __str__(self): 103 | return repr(self) 104 | 105 | def __repr__(self): 106 | return f"" 107 | -------------------------------------------------------------------------------- /lab/custom.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import namedtuple 3 | from functools import reduce 4 | 5 | import numpy as np 6 | import scipy.linalg as sla 7 | 8 | TensorDescription = namedtuple("TensorDescription", "shape dtype") 9 | """namedtuple: Description of a tensor in terms of the tensor's shape and data type.""" 10 | 11 | 12 | def promote_dtype_of_tensors(*xs): 13 | """Promote the data types of a number of tensors. 14 | 15 | Args: 16 | *xs (tensor): Tensors to take data types of and then promote those data types. 17 | 18 | Returns: 19 | dtype: Promoted data type. 
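For example, promoting a `float32` tensor with a `float64` tensor gives `float64`.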
20 | """ 21 | return reduce(np.promote_types, [x.dtype for x in xs]) 22 | 23 | 24 | try: 25 | # noinspection PyUnresolvedReferences 26 | from .bvn_cdf import bvn_cdf as bvn_cdf_ 27 | from .bvn_cdf import s_bvn_cdf as s_bvn_cdf_ 28 | 29 | def i_bvn_cdf(a, b, c): 30 | if a.shape != b.shape or a.shape != c.shape: 31 | raise ValueError("Shapes of the inputs to `bvn_cdf` must all be equal.") 32 | return TensorDescription(a.shape, promote_dtype_of_tensors(a, b, c)) 33 | 34 | def i_s_bvn_cdf(s_y, y, a, b, c): 35 | dtype = promote_dtype_of_tensors(s_y, y, a, b, c) 36 | return ( 37 | TensorDescription(a.shape, dtype), 38 | TensorDescription(b.shape, dtype), 39 | TensorDescription(c.shape, dtype), 40 | ) 41 | 42 | except ImportError: # pragma: no cover 43 | 44 | def bvn_cdf_(*args, **kw_args): 45 | raise RuntimeError( 46 | "`bvn_cdf` was not compiled. Please reinstall LAB with `LAB_BUILD=1`." 47 | ) 48 | 49 | def i_bvn_cdf(*args, **kw_args): 50 | raise RuntimeError( 51 | "`bvn_cdf` was not compiled. Please reinstall LAB with `LAB_BUILD=1`." 52 | ) 53 | 54 | def s_bvn_cdf_(*args, **kw_args): 55 | raise RuntimeError( 56 | "`bvn_cdf` was not compiled. Please reinstall LAB with `LAB_BUILD=1`." 57 | ) 58 | 59 | def i_s_bvn_cdf(*args, **kw_args): 60 | raise RuntimeError( 61 | "`bvn_cdf` was not compiled. Please reinstall LAB with `LAB_BUILD=1`." 62 | ) 63 | 64 | 65 | __all__ = [ 66 | "toeplitz_solve", 67 | "i_toeplitz_solve", 68 | "s_toeplitz_solve", 69 | "i_s_toeplitz_solve", 70 | "bvn_cdf", 71 | "i_bvn_cdf", 72 | "s_bvn_cdf", 73 | "i_s_bvn_cdf", 74 | "expm", 75 | "i_expm", 76 | "s_expm", 77 | "i_s_expm", 78 | "logm", 79 | "i_logm", 80 | "s_logm", 81 | "i_s_logm", 82 | ] 83 | 84 | log = logging.getLogger(__name__) 85 | 86 | 87 | def _mm(a, b): 88 | """Short hand for `np.matmul`. 89 | 90 | Args: 91 | a (tensor): First tensor in product. 92 | b (tensor): Second tensor in product. 93 | 94 | Return: 95 | tensor: Matrix product of `a` and `b`. 96 | """ 97 | 98 | return np.matmul(a, b) 99 | 100 | 101 | def _t(a): 102 | """Transpose `a`, correctly handling the case where `a` is rank one. 103 | 104 | Args: 105 | a (tensor): Tensor to transpose. 106 | 107 | Returns: 108 | tensor: Transposition of `a`. 109 | """ 110 | if a.ndim == 1: 111 | return a[None, :] 112 | else: 113 | return np.transpose(a) 114 | 115 | 116 | def _uprank(a): 117 | """Get `a` as a rank-two tensor, correctly handling the case where `a` is 118 | rank one. 119 | 120 | Args: 121 | a (tensor): Tensor to get as a rank-two tensor. 122 | 123 | Returns: 124 | tensor: `a` as a rank-two vector. 125 | """ 126 | if a.ndim == 1: 127 | return a[:, None] 128 | else: 129 | return a 130 | 131 | 132 | def toeplitz_solve(a, b, c): 133 | # For some reason, `sla.solve_toeplitz` sometimes fails with a `ValueError`, saying 134 | # that the buffer source array is read-only. We resolve this issue by copying the 135 | # inputs.... 136 | # TODO: Resolve this properly. 137 | a = np.copy(a) 138 | b = np.copy(b) 139 | c = np.copy(c) 140 | res_dtype = promote_dtype_of_tensors(a, b, c) 141 | row = np.concatenate((a[:1], b)) # First row of the Toeplitz matrix 142 | return sla.solve_toeplitz((a, row), c).astype(res_dtype) 143 | 144 | 145 | def i_toeplitz_solve(a, b, c): 146 | return TensorDescription(c.shape, promote_dtype_of_tensors(a, b, c)) 147 | 148 | 149 | def s_toeplitz_solve(s_y, y, a, b, c): 150 | # Compute `a` and `b` to get the transpose of the Toeplitz matrix. 151 | a_t = np.concatenate((a[:1], b)) 152 | b_t = a[1:] 153 | 154 | # Compute the sensitivity w.r.t `c`. 
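# This solves the transposed Toeplitz system (given by `a_t` and `b_t` above)
# against the incoming sensitivity `s_y`, reusing `toeplitz_solve`.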
155 | s_c = toeplitz_solve(a_t, b_t, s_y) 156 | 157 | # Compute the sensitivity w.r.t. the transposed inverse of the Toeplitz 158 | # matrix. 159 | s_inv = -_mm(_uprank(s_c), _t(y)) 160 | 161 | # Finally, compute the sensitivities w.r.t. `a` and `c`. 162 | n = a.shape[0] 163 | s_a = np.array([s_inv.diagonal(-i).sum() for i in range(n)]) 164 | s_b = np.array([s_inv.diagonal(i).sum() for i in range(1, n)]) 165 | 166 | return s_a, s_b, s_c 167 | 168 | 169 | def i_s_toeplitz_solve(s_y, y, a, b, c): 170 | dtype = promote_dtype_of_tensors(s_y, y, a, b, c) 171 | return ( 172 | TensorDescription(a.shape, dtype), 173 | TensorDescription(b.shape, dtype), 174 | TensorDescription(c.shape, dtype), 175 | ) 176 | 177 | 178 | def bvn_cdf(a, b, c): 179 | # We do not directly use `bvn_cdf_` to not have `inspect.signature` fail, which 180 | # does not work for `bvn_cdf_`. Moreover, we need to ensure that the function 181 | # runs on `float64s`. 182 | res_dtype = reduce(np.promote_types, [x.dtype for x in (a, b, c)]) 183 | # The C interface requires NumPy objects of the right data type. 184 | res = bvn_cdf_( 185 | np.asarray(a).astype(np.float64), 186 | np.asarray(b).astype(np.float64), 187 | np.asarray(c).astype(np.float64), 188 | ) 189 | return res.astype(res_dtype) 190 | 191 | 192 | def s_bvn_cdf(s_y, y, a, b, c): 193 | res_dtype = reduce(np.promote_types, [x.dtype for x in (s_y, y, a, b, c)]) 194 | # The C interface requires NumPy objects of the right data type. 195 | res = s_bvn_cdf_( 196 | np.asarray(s_y).astype(np.float64), 197 | np.asarray(y).astype(np.float64), 198 | np.asarray(a).astype(np.float64), 199 | np.asarray(b).astype(np.float64), 200 | np.asarray(c).astype(np.float64), 201 | ) 202 | return tuple(x.astype(res_dtype) for x in res) 203 | 204 | 205 | def expm(a): 206 | # This sometimes fails that the buffer source array is read-only. We resolve this 207 | # issue by copying the inputs. See also `toeplitz_solve`. 208 | # TODO: Resolve this properly. 209 | a = np.copy(a) 210 | return sla.expm(a) 211 | 212 | 213 | def i_expm(a): 214 | return TensorDescription(a.shape, a.dtype) 215 | 216 | 217 | def s_expm(s_y, y, a): 218 | return sla.expm_frechet(a, s_y.T, compute_expm=False).T 219 | 220 | 221 | def i_s_expm(s_y, y, a): 222 | return TensorDescription(a.shape, promote_dtype_of_tensors(s_y, y, a)) 223 | 224 | 225 | def logm(a): 226 | # This sometimes fails that the buffer source array is read-only. We resolve this 227 | # issue by copying the inputs. See also `toeplitz_solve`. 228 | # TODO: Resolve this properly. 229 | a = np.copy(a) 230 | return sla.logm(a) 231 | 232 | 233 | def i_logm(a): 234 | return TensorDescription(a.shape, a.dtype) 235 | 236 | 237 | def s_logm(a): # pragma: no cover 238 | raise NotImplementedError( 239 | "The derivative for the matrix logarithm is current not implemented." 240 | ) 241 | 242 | 243 | def i_s_logm(s_y, y, a): # pragma: no cover 244 | raise NotImplementedError( 245 | "The derivative for the matrix logarithm is current not implemented." 246 | ) 247 | -------------------------------------------------------------------------------- /lab/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .. import * 3 | from .. 
import dispatch as dispatch_original 4 | from ..shape import dispatch_unwrap_dimensions 5 | from ..types import JAXNumeric, NPNumeric, Number 6 | 7 | dispatch = dispatch_unwrap_dimensions(dispatch_original) 8 | 9 | from typing import Union 10 | 11 | Numeric = Union[Number, NPNumeric, JAXNumeric] 12 | 13 | import jax # Load `jax` to load all new types. 14 | from packaging.version import Version 15 | from plum import clear_all_cache as _clear_all_cache 16 | 17 | # In version before `0.5.1`, the type of JAX data types is located elsewhere. 18 | if Version(jax.__version__) < Version("0.5.1"): # pragma: no cover 19 | from ..types import _jax_dtype 20 | 21 | _jax_dtype._module = "jax._src.numpy.lax_numpy" 22 | 23 | 24 | # noinspection PyUnresolvedReferences 25 | from .generic import * 26 | from .linear_algebra import * 27 | from .random import * 28 | from .shaping import * 29 | 30 | # Clear cache to make sure that all newly loaded types are available. 31 | _clear_all_cache() 32 | 33 | # Alias to actual module. 34 | sys.modules[__name__] = B 35 | -------------------------------------------------------------------------------- /lab/jax/custom.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | from jax import ShapeDtypeStruct 4 | from jax import __version__ as jax_version 5 | from jax import custom_vjp 6 | from packaging.version import Version 7 | from plum import Dispatcher, convert 8 | 9 | if Version(jax_version) >= Version("0.5"): 10 | from jax.experimental import io_callback 11 | 12 | # IO callbacks do not support JVP in JAX before `0.5`, so we emulate the call with 13 | # `host_callback`. 14 | else: # pragma: no cover 15 | from jax.experimental import host_callback 16 | 17 | def io_callback(f, shapes, x): 18 | return host_callback.call(f, x, result_shape=shapes) 19 | 20 | 21 | from ..custom import TensorDescription 22 | 23 | __all__ = ["jax_register"] 24 | 25 | _dispatch = Dispatcher() 26 | 27 | 28 | @_dispatch 29 | def parse_inference_result(x: TensorDescription): 30 | """Parse the result inference functions to a PyTree (JAX terminology) of 31 | :class:`jax.ShapeDtypeStruct`s. 32 | 33 | Args: 34 | x (PyTree): Input to parse. 35 | 36 | Returns: 37 | PyTree: Parsed input. 38 | """ 39 | return ShapeDtypeStruct(x.shape, x.dtype) 40 | 41 | 42 | @_dispatch 43 | def parse_inference_result(xs: tuple): 44 | return tuple(parse_inference_result(x) for x in xs) 45 | 46 | 47 | def _wrap_cb(f, i_f): 48 | @wraps(f) 49 | def f_wrapped(*args, **kw_args): 50 | return io_callback( 51 | lambda x: f(*x[0], **x[1]), 52 | parse_inference_result(i_f(*args, **kw_args)), 53 | (args, kw_args), 54 | ) 55 | 56 | return f_wrapped 57 | 58 | 59 | def jax_register(f, i_f, s_f, i_s_f): 60 | """Register a function and its sensitivity for JAX. 61 | 62 | Args: 63 | f (function): Function to register. 64 | i_f (function): Function that infers the shape of the output. 65 | s_f (function): Sensitivity of `f`. 66 | i_s_f (function): Function that infers the shape of the output of the 67 | sensitivity of `f`. 68 | 69 | Returns: 70 | function: JAX function. 71 | """ 72 | f = _wrap_cb(f, i_f) 73 | s_f = _wrap_cb(s_f, i_s_f) 74 | 75 | f = custom_vjp(f) 76 | 77 | # Define and register the forward and backward pass. 
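# `forward` returns the primal output plus residuals (the output and the original
# arguments); `backward` feeds the output cotangent into `s_f` and must return a
# tuple with one gradient per positional argument, as `custom_vjp` requires.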
78 | 79 | def forward(*args, **kw_args): 80 | y = f(*args, **kw_args) 81 | return y, (y, args, kw_args) 82 | 83 | def backward(res, s_y): 84 | y, args, kw_args = res 85 | return convert(s_f(s_y, y, *args, **kw_args), tuple) 86 | 87 | f.defvjp(forward, backward) 88 | 89 | return f 90 | -------------------------------------------------------------------------------- /lab/jax/generic.py: -------------------------------------------------------------------------------- 1 | from types import FunctionType 2 | from typing import Union 3 | 4 | import jax 5 | import jax.nn as jnn 6 | import jax.numpy as jnp 7 | import jax.scipy.special as jsps 8 | from plum import isinstance 9 | 10 | from ..custom import bvn_cdf, i_bvn_cdf, i_s_bvn_cdf, s_bvn_cdf 11 | from ..types import ( 12 | Int, 13 | JAXDType, 14 | JAXNumeric, 15 | JAXRandomState, 16 | NPNumeric, 17 | Number, 18 | _jax_tracer, 19 | ) 20 | from . import B, Numeric, dispatch 21 | from .custom import jax_register 22 | 23 | __all__ = [] 24 | 25 | 26 | @dispatch 27 | def isabstract(a: Numeric): 28 | return isinstance(a, _jax_tracer) 29 | 30 | 31 | @dispatch 32 | def _jit_run( 33 | f: FunctionType, 34 | compilation_cache: dict, 35 | jit_kw_args: dict, 36 | *args: Union[Numeric, JAXRandomState], 37 | **kw_args, 38 | ): 39 | if "jax" not in compilation_cache: 40 | # Run once to populate the control flow cache. 41 | f(*args, **kw_args) 42 | # Compile. 43 | compilation_cache["jax"] = jax.jit(f, **jit_kw_args) 44 | 45 | return compilation_cache["jax"](*args, **kw_args) 46 | 47 | 48 | @dispatch 49 | def isnan(a: Numeric): 50 | return jnp.isnan(a) 51 | 52 | 53 | @dispatch 54 | def real(a: Numeric): 55 | return jnp.real(a) 56 | 57 | 58 | @dispatch 59 | def imag(a: Numeric): 60 | return jnp.imag(a) 61 | 62 | 63 | @dispatch 64 | def device(a: JAXNumeric): 65 | devices = list(a.devices()) 66 | if len(devices) != 1: 67 | raise RuntimeError("Could not determine device of JAX array.") 68 | return devices[0] 69 | 70 | 71 | @dispatch 72 | def to_active_device(a: JAXNumeric): 73 | if B.ActiveDevice.active_name: 74 | parts = B.ActiveDevice.active_name.lower().split(":") 75 | if len(parts) == 1: 76 | return jax.device_put(a, jax.devices(parts[0])[0]) 77 | elif len(parts) == 2: 78 | return jax.device_put(a, jax.devices(parts[0])[int(parts[1])]) 79 | else: 80 | raise ValueError( 81 | f'Cannot parse device specification "{B.ActiveDevice.active_name}".' 82 | ) 83 | else: 84 | return a 85 | 86 | 87 | @dispatch 88 | def zeros(dtype: JAXDType, *shape: Int): 89 | return to_active_device(jnp.zeros(shape, dtype=dtype)) 90 | 91 | 92 | @dispatch 93 | def ones(dtype: JAXDType, *shape: Int): 94 | return to_active_device(jnp.ones(shape, dtype=dtype)) 95 | 96 | 97 | @dispatch 98 | def _eye2(dtype: JAXDType, *shape: Int): 99 | return to_active_device(jnp.eye(shape[0], shape[1], dtype=dtype)) 100 | 101 | 102 | @dispatch 103 | def linspace(dtype: JAXDType, a, b, num: Int): 104 | return to_active_device(jnp.linspace(a, b, num, dtype=dtype)) 105 | 106 | 107 | @dispatch 108 | def range(dtype: JAXDType, start, stop, step): 109 | return to_active_device(jnp.arange(start, stop, step, dtype=dtype)) 110 | 111 | 112 | @dispatch 113 | def cast(dtype: JAXDType, a: JAXNumeric): 114 | return a.astype(dtype) 115 | 116 | 117 | @dispatch 118 | def cast(dtype: JAXDType, a: Union[Number, NPNumeric]): 119 | return to_active_device(jnp.array(a, dtype=dtype)) 120 | 121 | 122 | @dispatch 123 | def identity(a: Numeric): 124 | # Do not return `a` identically. 
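# Multiplying by one produces a fresh value rather than the input object itself,
# which is the point of `identity`.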
125 | return jnp.multiply(1, a) 126 | 127 | 128 | @dispatch 129 | def round(a: Numeric): 130 | return jnp.round(a) 131 | 132 | 133 | @dispatch 134 | def floor(a: Numeric): 135 | return jnp.floor(a) 136 | 137 | 138 | @dispatch 139 | def ceil(a: Numeric): 140 | return jnp.ceil(a) 141 | 142 | 143 | @dispatch 144 | def negative(a: Numeric): 145 | return jnp.negative(a) 146 | 147 | 148 | @dispatch 149 | def abs(a: Numeric): 150 | return jnp.abs(a) 151 | 152 | 153 | @dispatch 154 | def sign(a: Numeric): 155 | return jnp.sign(a) 156 | 157 | 158 | @dispatch 159 | def sqrt(a: Numeric): 160 | return jnp.sqrt(a) 161 | 162 | 163 | @dispatch 164 | def exp(a: Numeric): 165 | return jnp.exp(a) 166 | 167 | 168 | @dispatch 169 | def log(a: Numeric): 170 | return jnp.log(a) 171 | 172 | 173 | @dispatch 174 | def log1p(a: Numeric): 175 | return jnp.log1p(a) 176 | 177 | 178 | @dispatch 179 | def sin(a: Numeric): 180 | return jnp.sin(a) 181 | 182 | 183 | @dispatch 184 | def arcsin(a: Numeric): 185 | return jnp.arcsin(a) 186 | 187 | 188 | @dispatch 189 | def cos(a: Numeric): 190 | return jnp.cos(a) 191 | 192 | 193 | @dispatch 194 | def arccos(a: Numeric): 195 | return jnp.arccos(a) 196 | 197 | 198 | @dispatch 199 | def tan(a: Numeric): 200 | return jnp.tan(a) 201 | 202 | 203 | @dispatch 204 | def arctan(a: Numeric): 205 | return jnp.arctan(a) 206 | 207 | 208 | @dispatch 209 | def tanh(a: Numeric): 210 | return jnp.tanh(a) 211 | 212 | 213 | @dispatch 214 | def arctanh(a: Numeric): 215 | return jnp.arctanh(a) 216 | 217 | 218 | @dispatch 219 | def loggamma(a: Numeric): 220 | return jsps.gammaln(a) 221 | 222 | 223 | @dispatch 224 | def erf(a: Numeric): 225 | return jsps.erf(a) 226 | 227 | 228 | @dispatch 229 | def softplus(a: JAXNumeric): 230 | return jnn.softplus(a) 231 | 232 | 233 | @dispatch 234 | def add(a: Numeric, b: Numeric): 235 | return jnp.add(a, b) 236 | 237 | 238 | @dispatch 239 | def subtract(a: Numeric, b: Numeric): 240 | return jnp.subtract(a, b) 241 | 242 | 243 | @dispatch 244 | def multiply(a: Numeric, b: Numeric): 245 | return jnp.multiply(a, b) 246 | 247 | 248 | @dispatch 249 | def divide(a: Numeric, b: Numeric): 250 | return jnp.divide(a, b) 251 | 252 | 253 | @dispatch 254 | def power(a: Numeric, b: Numeric): 255 | return jnp.power(a, b) 256 | 257 | 258 | @dispatch 259 | def minimum(a: Numeric, b: Numeric): 260 | return jnp.minimum(a, b) 261 | 262 | 263 | @dispatch 264 | def maximum(a: Numeric, b: Numeric): 265 | return jnp.maximum(a, b) 266 | 267 | 268 | @dispatch 269 | def min(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 270 | return jnp.min(a, axis=axis, keepdims=not squeeze) 271 | 272 | 273 | @dispatch 274 | def argmin(a: Numeric, axis: Union[Int, None] = None): 275 | return jnp.argmin(a, axis=axis) 276 | 277 | 278 | @dispatch 279 | def max(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 280 | return jnp.max(a, axis=axis, keepdims=not squeeze) 281 | 282 | 283 | @dispatch 284 | def argmax(a: Numeric, axis: Union[Int, None] = None): 285 | return jnp.argmax(a, axis=axis) 286 | 287 | 288 | @dispatch 289 | def sum(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 290 | return jnp.sum(a, axis=axis, keepdims=not squeeze) 291 | 292 | 293 | @dispatch 294 | def prod(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 295 | return jnp.prod(a, axis=axis, keepdims=not squeeze) 296 | 297 | 298 | @dispatch 299 | def mean(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 300 | return jnp.mean(a, axis=axis, keepdims=not 
squeeze) 301 | 302 | 303 | @dispatch 304 | def std(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 305 | return jnp.std(a, axis=axis, ddof=0, keepdims=not squeeze) 306 | 307 | 308 | @dispatch 309 | def all(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 310 | return jnp.all(a, axis=axis, keepdims=not squeeze) 311 | 312 | 313 | @dispatch 314 | def any(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 315 | return jnp.any(a, axis=axis, keepdims=not squeeze) 316 | 317 | 318 | @dispatch 319 | def lt(a: Numeric, b: Numeric): 320 | return jnp.less(a, b) 321 | 322 | 323 | @dispatch 324 | def le(a: Numeric, b: Numeric): 325 | return jnp.less_equal(a, b) 326 | 327 | 328 | @dispatch 329 | def gt(a: Numeric, b: Numeric): 330 | return jnp.greater(a, b) 331 | 332 | 333 | @dispatch 334 | def ge(a: Numeric, b: Numeric): 335 | return jnp.greater_equal(a, b) 336 | 337 | 338 | @dispatch 339 | def eq(a: Numeric, b: Numeric): 340 | return jnp.equal(a, b) 341 | 342 | 343 | @dispatch 344 | def ne(a: Numeric, b: Numeric): 345 | return jnp.not_equal(a, b) 346 | 347 | 348 | _bvn_cdf = jax_register(bvn_cdf, i_bvn_cdf, s_bvn_cdf, i_s_bvn_cdf) 349 | 350 | 351 | @dispatch 352 | def bvn_cdf(a: Numeric, b: Numeric, c: Numeric): 353 | return _bvn_cdf(a, b, c) 354 | 355 | 356 | @dispatch 357 | def _cond(condition: JAXNumeric, f_true: FunctionType, f_false: FunctionType, *args): 358 | # We could use `jax.lax.cond` here, but that invokes compilation, which makes 359 | # repeated application of `B.cond` extremely slow. 360 | if condition: 361 | return f_true(*args) 362 | else: 363 | return f_false(*args) 364 | 365 | 366 | @dispatch 367 | def where(condition: Numeric, a: Numeric, b: Numeric): 368 | return jnp.where(condition, a, b) 369 | 370 | 371 | @dispatch 372 | def sort(a: Numeric, axis: Int = -1, descending: bool = False): 373 | if descending: 374 | return -jnp.sort(-a, axis=axis) 375 | else: 376 | return jnp.sort(a, axis=axis) 377 | 378 | 379 | @dispatch 380 | def argsort(a: Numeric, axis: Int = -1, descending: bool = False): 381 | if descending: 382 | return jnp.argsort(-a, axis=axis) 383 | else: 384 | return jnp.argsort(a, axis=axis) 385 | 386 | 387 | @dispatch 388 | def quantile(a: Numeric, q: Numeric, axis: Union[Int, None] = None): 389 | q = B.cast(B.dtype_float(q), q) # JAX requires this to be a float. 390 | return jnp.quantile(a, q, axis=axis, interpolation="linear") 391 | -------------------------------------------------------------------------------- /lab/jax/linear_algebra.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Union 3 | 4 | import jax.numpy as jnp 5 | import jax.scipy.linalg as jsla 6 | import opt_einsum as oe 7 | 8 | from ..custom import ( 9 | expm, 10 | i_expm, 11 | i_logm, 12 | i_s_expm, 13 | i_s_logm, 14 | i_s_toeplitz_solve, 15 | i_toeplitz_solve, 16 | logm, 17 | s_expm, 18 | s_logm, 19 | s_toeplitz_solve, 20 | toeplitz_solve, 21 | ) 22 | from ..linear_algebra import _default_perm 23 | from ..types import Int 24 | from ..util import batch_computation 25 | from . 
import B, Numeric, dispatch 26 | from .custom import jax_register 27 | 28 | __all__ = [] 29 | log = logging.getLogger(__name__) 30 | 31 | 32 | @dispatch 33 | def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False): 34 | a = transpose(a) if tr_a else a 35 | b = transpose(b) if tr_b else b 36 | return jnp.matmul(a, b) 37 | 38 | 39 | @dispatch 40 | def einsum(equation: str, *elements: Numeric): 41 | return oe.contract(equation, *elements, backend="jax") 42 | 43 | 44 | @dispatch 45 | def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None): 46 | # Correctly handle special cases. 47 | rank_a = B.rank(a) 48 | if rank_a == 0: 49 | return a 50 | elif rank_a == 1 and perm is None: 51 | return a[None, :] 52 | 53 | if perm is None: 54 | perm = _default_perm(a) 55 | return jnp.transpose(a, axes=perm) 56 | 57 | 58 | @dispatch 59 | def trace(a: Numeric, axis1: Int = -2, axis2: Int = -1): 60 | return jnp.trace(a, axis1=axis1, axis2=axis2) 61 | 62 | 63 | @dispatch 64 | def svd(a: Numeric, compute_uv: bool = True): 65 | res = jnp.linalg.svd(a, full_matrices=False, compute_uv=compute_uv) 66 | return (res[0], res[1], jnp.conj(transpose(res[2]))) if compute_uv else res 67 | 68 | 69 | @dispatch 70 | def eig(a: Numeric, compute_eigvecs: bool = True): 71 | vals, vecs = jnp.linalg.eig(a) 72 | return (vals, vecs) if compute_eigvecs else vals 73 | 74 | 75 | @dispatch 76 | def solve(a: Numeric, b: Numeric): 77 | return jnp.linalg.solve(a, b) 78 | 79 | 80 | @dispatch 81 | def inv(a: Numeric): 82 | return jnp.linalg.inv(a) 83 | 84 | 85 | @dispatch 86 | def det(a: Numeric): 87 | return jnp.linalg.det(a) 88 | 89 | 90 | @dispatch 91 | def logdet(a: Numeric): 92 | return jnp.linalg.slogdet(a)[1] 93 | 94 | 95 | _expm = jax_register(expm, i_expm, s_expm, i_s_expm) 96 | 97 | 98 | @dispatch 99 | def expm(a: Numeric): 100 | return _expm(a) 101 | 102 | 103 | _logm = jax_register(logm, i_logm, s_logm, i_s_logm) 104 | 105 | 106 | @dispatch 107 | def logm(a: Numeric): 108 | return _logm(a) 109 | 110 | 111 | @dispatch 112 | def _cholesky(a: Numeric): 113 | return jnp.linalg.cholesky(a) 114 | 115 | 116 | @dispatch 117 | def cholesky_solve(a: Numeric, b: Numeric): 118 | return triangular_solve(transpose(a), triangular_solve(a, b), lower_a=False) 119 | 120 | 121 | @dispatch 122 | def triangular_solve(a: Numeric, b: Numeric, lower_a: bool = True): 123 | def _triangular_solve(a_, b_): 124 | return jsla.solve_triangular( 125 | a_, b_, trans="N", lower=lower_a, check_finite=False 126 | ) 127 | 128 | return batch_computation(_triangular_solve, (a, b), (2, 2)) 129 | 130 | 131 | _toeplitz_solve = jax_register( 132 | toeplitz_solve, i_toeplitz_solve, s_toeplitz_solve, i_s_toeplitz_solve 133 | ) 134 | 135 | 136 | @dispatch 137 | def toeplitz_solve(a: Numeric, b: Numeric, c: Numeric): 138 | return _toeplitz_solve(a, b, c) 139 | -------------------------------------------------------------------------------- /lab/jax/random.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from plum import Dispatcher 6 | 7 | from ..types import Int, JAXDType, JAXNumeric, JAXRandomState 8 | from ..util import broadcast_shapes 9 | from . 
import B, Numeric, dispatch 10 | 11 | __all__ = [] 12 | 13 | _dispatch = Dispatcher() 14 | 15 | 16 | @dispatch 17 | def create_random_state(_: JAXDType, seed: Int = 0): 18 | return jax.random.PRNGKey(seed=seed) 19 | 20 | 21 | B.jax_global_random_state = jax.random.PRNGKey(seed=0) 22 | 23 | 24 | @dispatch 25 | def global_random_state(_: JAXDType): 26 | return B.jax_global_random_state 27 | 28 | 29 | @dispatch 30 | def set_global_random_state(state: JAXRandomState): 31 | B.jax_global_random_state = state 32 | 33 | 34 | @dispatch 35 | def rand(state: JAXRandomState, dtype: JAXDType, *shape: Int): 36 | state, key = jax.random.split(state) 37 | return state, B.to_active_device(jax.random.uniform(key, shape, dtype=dtype)) 38 | 39 | 40 | @dispatch 41 | def rand(dtype: JAXDType, *shape: Int): 42 | state, res = rand(global_random_state(dtype), dtype, *shape) 43 | B.jax_global_random_state = state 44 | return res 45 | 46 | 47 | @dispatch 48 | def randn(state: JAXRandomState, dtype: JAXDType, *shape: Int): 49 | state, key = jax.random.split(state) 50 | return state, B.to_active_device(jax.random.normal(key, shape, dtype=dtype)) 51 | 52 | 53 | @dispatch 54 | def randn(dtype: JAXDType, *shape: Int): 55 | state, res = randn(global_random_state(dtype), dtype, *shape) 56 | B.jax_global_random_state = state 57 | return res 58 | 59 | 60 | @dispatch 61 | def randcat(state: JAXRandomState, p: JAXNumeric, *shape: Int): 62 | state, key = jax.random.split(state) 63 | # We need to tile to make the batching work. 64 | p = B.tile( 65 | B.expand_dims(p, axis=0, times=len(shape)), 66 | *shape, 67 | *((1,) * len(p.shape)), 68 | ) 69 | inds = jax.random.categorical(key, jnp.log(p)) 70 | return state, inds 71 | 72 | 73 | @dispatch 74 | def randcat(p: JAXNumeric, *shape: Int): 75 | state, res = randcat(global_random_state(p), p, *shape) 76 | B.jax_global_random_state = state 77 | return res 78 | 79 | 80 | @dispatch 81 | def choice(a: JAXNumeric, *shape: Int, p: Union[Numeric, None] = None): 82 | # This method is necessary to break ambiguity. 
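    # (Editorial note: this mirrors the NumPy backend's stateless `choice`
    # overload: it threads the global JAX random state through the stateful
    # `choice(state, a, *shape, p=...)` defined in `lab/random.py` and then
    # writes the advanced state back to `B.jax_global_random_state`.)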
83 | state, res = choice(global_random_state(a), a, *shape, p=p) 84 | B.jax_global_random_state = state 85 | return res 86 | 87 | 88 | @dispatch 89 | def randint( 90 | state: JAXRandomState, 91 | dtype: JAXDType, 92 | *shape: Int, 93 | lower: Int = 0, 94 | upper: Int, 95 | ): 96 | dtype = B.dtype_int(dtype) 97 | state, key = jax.random.split(state) 98 | return state, B.to_active_device( 99 | jax.random.randint(key, shape, lower, upper, dtype=dtype) 100 | ) 101 | 102 | 103 | @dispatch 104 | def randint( 105 | dtype: JAXDType, 106 | *shape: Int, 107 | lower: Int = 0, 108 | upper: Int, 109 | ): 110 | state, res = randint( 111 | global_random_state(dtype), 112 | dtype, 113 | *shape, 114 | lower=lower, 115 | upper=upper, 116 | ) 117 | B.jax_global_random_state = state 118 | return res 119 | 120 | 121 | @dispatch 122 | def randperm(state: JAXRandomState, dtype: JAXDType, n: Int): 123 | dtype = B.dtype_int(dtype) 124 | state, key = jax.random.split(state) 125 | return state, B.to_active_device(B.cast(dtype, jax.random.permutation(key, n))) 126 | 127 | 128 | @dispatch 129 | def randperm(dtype: JAXDType, n: Int): 130 | state, res = randperm(global_random_state(dtype), dtype, n) 131 | B.jax_global_random_state = state 132 | return res 133 | 134 | 135 | @dispatch 136 | def randgamma( 137 | state: JAXRandomState, 138 | dtype: JAXDType, 139 | *shape: Int, 140 | alpha: Numeric, 141 | scale: Numeric, 142 | ): 143 | state, key = jax.random.split(state) 144 | shape = shape + broadcast_shapes(B.shape(alpha), B.shape(scale)) 145 | sample = B.to_active_device(jax.random.gamma(key, alpha, shape, dtype=dtype)) 146 | sample = sample * B.to_active_device(B.cast(dtype, scale)) 147 | return state, sample 148 | 149 | 150 | @dispatch 151 | def randgamma(dtype: JAXDType, *shape: Int, alpha: Numeric, scale: Numeric): 152 | state = global_random_state(dtype) 153 | state, res = randgamma(state, dtype, *shape, alpha=alpha, scale=scale) 154 | B.jax_global_random_state = state 155 | return res 156 | -------------------------------------------------------------------------------- /lab/jax/shaping.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import jax.numpy as jnp 4 | 5 | from ..types import Int 6 | from . 
import Numeric, dispatch 7 | 8 | __all__ = [] 9 | 10 | 11 | @dispatch 12 | def length(a: Numeric): 13 | return jnp.size(a) 14 | 15 | 16 | @dispatch 17 | def _expand_dims(a: Numeric, axis: Int = 0): 18 | return jnp.expand_dims(a, axis=axis) 19 | 20 | 21 | @dispatch 22 | def squeeze(a: Numeric, axis: Union[Int, None] = None): 23 | return jnp.squeeze(a, axis=axis) 24 | 25 | 26 | @dispatch 27 | def broadcast_to(a: Numeric, *shape: Int): 28 | return jnp.broadcast_to(a, shape) 29 | 30 | 31 | @dispatch 32 | def diag(a: Numeric): 33 | return jnp.diag(a) 34 | 35 | 36 | @dispatch 37 | def diag_extract(a: Numeric): 38 | return jnp.diagonal(a, axis1=-2, axis2=-1) 39 | 40 | 41 | @dispatch 42 | def stack(*elements: Numeric, axis: Int = 0): 43 | return jnp.stack(elements, axis=axis) 44 | 45 | 46 | @dispatch 47 | def _unstack(a: Numeric, axis: Int = 0): 48 | out = jnp.split(a, jnp.arange(1, a.shape[axis]), axis) 49 | return [x.squeeze(axis=axis) for x in out] 50 | 51 | 52 | @dispatch 53 | def reshape(a: Numeric, *shape: Int): 54 | return jnp.reshape(a, shape) 55 | 56 | 57 | @dispatch 58 | def concat(*elements: Numeric, axis: Int = 0): 59 | return jnp.concatenate(elements, axis=axis) 60 | 61 | 62 | @dispatch 63 | def tile(a: Numeric, *repeats: Int): 64 | return jnp.tile(a, repeats) 65 | -------------------------------------------------------------------------------- /lab/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .. import * 3 | from .. import dispatch as dispatch_original 4 | from ..shape import dispatch_unwrap_dimensions 5 | from ..types import NPNumeric, Number 6 | 7 | # All methods here should have precedence, because NumPy forms the base of everything. 8 | dispatch_original = dispatch_original(precedence=1) 9 | 10 | dispatch = dispatch_unwrap_dimensions(dispatch_original) 11 | 12 | from typing import Union 13 | 14 | Numeric = Union[Number, NPNumeric] 15 | 16 | from .generic import * 17 | from .linear_algebra import * 18 | from .random import * 19 | from .shaping import * 20 | 21 | # Alias to actual module. 22 | sys.modules[__name__] = B 23 | -------------------------------------------------------------------------------- /lab/numpy/generic.py: -------------------------------------------------------------------------------- 1 | from types import FunctionType 2 | from typing import Union 3 | 4 | import numpy as np 5 | import scipy.special as sps 6 | 7 | from ..custom import bvn_cdf as _bvn_cdf 8 | from ..types import Int, NPDType, NPNumeric, NPRandomState 9 | from . import B, Numeric, dispatch 10 | 11 | __all__ = [] 12 | 13 | 14 | @dispatch 15 | def isabstract(a: Numeric): 16 | return False 17 | 18 | 19 | @dispatch 20 | def _jit_run( 21 | f: FunctionType, 22 | compilation_cache: dict, 23 | jit_kw_args: dict, 24 | *args: Union[Numeric, NPRandomState], 25 | **kw_args, 26 | ): 27 | # There is no JIT for NumPy, so just run the function. 
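    # (Editorial note: `compilation_cache` and `jit_kw_args` are accepted only
    # so that this signature lines up with the other backends' `_jit_run`
    # implementations, which do compile; for NumPy they are deliberately unused.)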
28 | return f(*args, **kw_args) 29 | 30 | 31 | @dispatch 32 | def isnan(a: Numeric): 33 | return np.isnan(a) 34 | 35 | 36 | @dispatch 37 | def real(a: Numeric): 38 | return np.real(a) 39 | 40 | 41 | @dispatch 42 | def imag(a: Numeric): 43 | return np.imag(a) 44 | 45 | 46 | @dispatch 47 | def device(a: NPNumeric): 48 | return "cpu" 49 | 50 | 51 | @dispatch 52 | def to_active_device(a: NPNumeric): 53 | return a 54 | 55 | 56 | @dispatch 57 | def zeros(dtype: NPDType, *shape: Int): 58 | return np.zeros(shape, dtype=dtype) 59 | 60 | 61 | @dispatch 62 | def ones(dtype: NPDType, *shape: Int): 63 | return np.ones(shape, dtype=dtype) 64 | 65 | 66 | @dispatch 67 | def _eye2(dtype: NPDType, *shape: Int): 68 | return np.eye(shape[0], shape[1], dtype=dtype) 69 | 70 | 71 | @dispatch 72 | def linspace(dtype: NPDType, a, b, num: Int): 73 | return np.linspace(a, b, num, dtype=dtype) 74 | 75 | 76 | @dispatch 77 | def range(dtype: NPDType, start, stop, step): 78 | return np.arange(start, stop, step, dtype=dtype) 79 | 80 | 81 | @dispatch 82 | def cast(dtype: NPDType, a: Numeric): 83 | if B.dtype(a) == dtype: 84 | return a 85 | if hasattr(a, "astype"): 86 | return a.astype(dtype, copy=False) 87 | else: 88 | return np.array(a, dtype=dtype) 89 | 90 | 91 | @dispatch 92 | def identity(a: Numeric): 93 | return np.array(a) 94 | 95 | 96 | @dispatch 97 | def round(a: Numeric): 98 | return np.round(a) 99 | 100 | 101 | @dispatch 102 | def floor(a: Numeric): 103 | return np.floor(a) 104 | 105 | 106 | @dispatch 107 | def ceil(a: Numeric): 108 | return np.ceil(a) 109 | 110 | 111 | @dispatch 112 | def negative(a: Numeric): 113 | return np.negative(a) 114 | 115 | 116 | @dispatch 117 | def abs(a: Numeric): 118 | return np.abs(a) 119 | 120 | 121 | @dispatch 122 | def sign(a: Numeric): 123 | return np.sign(a) 124 | 125 | 126 | @dispatch 127 | def sqrt(a: Numeric): 128 | return np.sqrt(a) 129 | 130 | 131 | @dispatch 132 | def exp(a: Numeric): 133 | return np.exp(a) 134 | 135 | 136 | @dispatch 137 | def log(a: Numeric): 138 | return np.log(a) 139 | 140 | 141 | @dispatch 142 | def log1p(a: Numeric): 143 | return np.log1p(a) 144 | 145 | 146 | @dispatch 147 | def sin(a: Numeric): 148 | return np.sin(a) 149 | 150 | 151 | @dispatch 152 | def arcsin(a: Numeric): 153 | return np.arcsin(a) 154 | 155 | 156 | @dispatch 157 | def cos(a: Numeric): 158 | return np.cos(a) 159 | 160 | 161 | @dispatch 162 | def arccos(a: Numeric): 163 | return np.arccos(a) 164 | 165 | 166 | @dispatch 167 | def tan(a: Numeric): 168 | return np.tan(a) 169 | 170 | 171 | @dispatch 172 | def arctan(a: Numeric): 173 | return np.arctan(a) 174 | 175 | 176 | @dispatch 177 | def tanh(a: Numeric): 178 | return np.tanh(a) 179 | 180 | 181 | @dispatch 182 | def arctanh(a: Numeric): 183 | return np.arctanh(a) 184 | 185 | 186 | @dispatch 187 | def loggamma(a: Numeric): 188 | return sps.gammaln(a) 189 | 190 | 191 | @dispatch 192 | def erf(a: Numeric): 193 | return sps.erf(a) 194 | 195 | 196 | @dispatch 197 | def add(a: Numeric, b: Numeric): 198 | return np.add(a, b) 199 | 200 | 201 | @dispatch 202 | def subtract(a: Numeric, b: Numeric): 203 | return np.subtract(a, b) 204 | 205 | 206 | @dispatch 207 | def multiply(a: Numeric, b: Numeric): 208 | return np.multiply(a, b) 209 | 210 | 211 | @dispatch 212 | def divide(a: Numeric, b: Numeric): 213 | return np.divide(a, b) 214 | 215 | 216 | @dispatch 217 | def power(a: Numeric, b: Numeric): 218 | return np.power(a, b) 219 | 220 | 221 | @dispatch 222 | def minimum(a: Numeric, b: Numeric): 223 | return np.minimum(a, b) 224 | 225 | 226 | 
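# Editorial sketch, not part of the original file: the dispatched functions in
# this module back the framework-agnostic `B.*` API. Assuming the usual
# `import lab as B` entry point, the `squeeze` keyword on the reductions defined
# further down simply maps to `keepdims=not squeeze`:
#
#     >>> import numpy as np
#     >>> import lab as B
#     >>> x = np.ones((3, 4))
#     >>> B.mean(x, axis=0).shape              # reduced axis dropped by default
#     (4,)
#     >>> B.mean(x, axis=0, squeeze=False).shape
#     (1, 4)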
@dispatch 227 | def maximum(a: Numeric, b: Numeric): 228 | return np.maximum(a, b) 229 | 230 | 231 | @dispatch 232 | def min(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 233 | return np.min(a, axis=axis, keepdims=not squeeze) 234 | 235 | 236 | @dispatch 237 | def argmin(a: Numeric, axis: Union[Int, None] = None): 238 | return np.argmin(a, axis=axis) 239 | 240 | 241 | @dispatch 242 | def max(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 243 | return np.max(a, axis=axis, keepdims=not squeeze) 244 | 245 | 246 | @dispatch 247 | def argmax(a: Numeric, axis: Union[Int, None] = None): 248 | return np.argmax(a, axis=axis) 249 | 250 | 251 | @dispatch 252 | def sum(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 253 | return np.sum(a, axis=axis, keepdims=not squeeze) 254 | 255 | 256 | @dispatch 257 | def prod(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 258 | return np.prod(a, axis=axis, keepdims=not squeeze) 259 | 260 | 261 | @dispatch 262 | def mean(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 263 | return np.mean(a, axis=axis, keepdims=not squeeze) 264 | 265 | 266 | @dispatch 267 | def std(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 268 | return np.std(a, axis=axis, ddof=0, keepdims=not squeeze) 269 | 270 | 271 | @dispatch 272 | def all(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 273 | return np.all(a, axis=axis, keepdims=not squeeze) 274 | 275 | 276 | @dispatch 277 | def any(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 278 | return np.any(a, axis=axis, keepdims=not squeeze) 279 | 280 | 281 | @dispatch 282 | def lt(a: Numeric, b: Numeric): 283 | return np.less(a, b) 284 | 285 | 286 | @dispatch 287 | def le(a: Numeric, b: Numeric): 288 | return np.less_equal(a, b) 289 | 290 | 291 | @dispatch 292 | def gt(a: Numeric, b: Numeric): 293 | return np.greater(a, b) 294 | 295 | 296 | @dispatch 297 | def ge(a: Numeric, b: Numeric): 298 | return np.greater_equal(a, b) 299 | 300 | 301 | @dispatch 302 | def eq(a: Numeric, b: Numeric): 303 | return np.equal(a, b) 304 | 305 | 306 | @dispatch 307 | def ne(a: Numeric, b: Numeric): 308 | return np.not_equal(a, b) 309 | 310 | 311 | @dispatch 312 | def bvn_cdf(a: Numeric, b: Numeric, c: Numeric): 313 | return _bvn_cdf(a, b, c) 314 | 315 | 316 | @dispatch 317 | def where(condition: Numeric, a: Numeric, b: Numeric): 318 | return np.where(condition, a, b) 319 | 320 | 321 | @dispatch 322 | def sort(a: Numeric, axis: Int = -1, descending: bool = False): 323 | if descending: 324 | return -np.sort(-a, axis=axis) 325 | else: 326 | return np.sort(a, axis=axis) 327 | 328 | 329 | @dispatch 330 | def argsort(a: Numeric, axis: Int = -1, descending: bool = False): 331 | if descending: 332 | return np.argsort(-a, axis=axis) 333 | else: 334 | return np.argsort(a, axis=axis) 335 | 336 | 337 | @dispatch 338 | def quantile(a: Numeric, q: Numeric, axis: Union[Int, None] = None): 339 | if tuple(map(int, np.__version__.split("."))) >= (1, 22): # pragma: no cover 340 | method = {"method": "linear"} 341 | else: # pragma: no cover 342 | method = {"interpolation": "linear"} 343 | return np.quantile(a, q, axis=axis, **method) 344 | -------------------------------------------------------------------------------- /lab/numpy/linear_algebra.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Union 3 | 4 | import numpy as np 5 | import opt_einsum as 
oe 6 | import scipy.linalg as sla 7 | 8 | from ..custom import expm as _expm 9 | from ..custom import logm as _logm 10 | from ..custom import toeplitz_solve as _toeplitz_solve 11 | from ..linear_algebra import _default_perm 12 | from ..types import Int 13 | from ..util import batch_computation 14 | from . import B, Numeric, dispatch 15 | 16 | __all__ = [] 17 | 18 | log = logging.getLogger(__name__) 19 | 20 | 21 | @dispatch 22 | def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False): 23 | a = transpose(a) if tr_a else a 24 | b = transpose(b) if tr_b else b 25 | return np.matmul(a, b) 26 | 27 | 28 | @dispatch 29 | def einsum(equation: str, *elements: Numeric): 30 | return oe.contract(equation, *elements, backend="numpy") 31 | 32 | 33 | @dispatch 34 | def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None): 35 | # Correctly handle special cases. 36 | rank_a = B.rank(a) 37 | if rank_a == 0: 38 | return a 39 | elif rank_a == 1 and perm is None: 40 | return a[None, :] 41 | 42 | if perm is None: 43 | perm = _default_perm(a) 44 | return np.transpose(a, axes=perm) 45 | 46 | 47 | @dispatch 48 | def trace(a: Numeric, axis1: Int = -2, axis2: Int = -1): 49 | return np.trace(a, axis1=axis1, axis2=axis2) 50 | 51 | 52 | @dispatch 53 | def svd(a: Numeric, compute_uv: bool = True): 54 | res = np.linalg.svd(a, full_matrices=False, compute_uv=compute_uv) 55 | return (res[0], res[1], np.conj(transpose(res[2]))) if compute_uv else res 56 | 57 | 58 | @dispatch 59 | def eig(a: Numeric, compute_eigvecs: bool = True): 60 | vals, vecs = np.linalg.eig(a) 61 | return (vals, vecs) if compute_eigvecs else vals 62 | 63 | 64 | @dispatch 65 | def solve(a: Numeric, b: Numeric): 66 | return np.linalg.solve(a, b) 67 | 68 | 69 | @dispatch 70 | def inv(a: Numeric): 71 | return np.linalg.inv(a) 72 | 73 | 74 | @dispatch 75 | def det(a: Numeric): 76 | return np.linalg.det(a) 77 | 78 | 79 | @dispatch 80 | def logdet(a: Numeric): 81 | return np.linalg.slogdet(a)[1] 82 | 83 | 84 | @dispatch 85 | def expm(a: Numeric): 86 | return _expm(a) 87 | 88 | 89 | @dispatch 90 | def logm(a: Numeric): 91 | return _logm(a) 92 | 93 | 94 | @dispatch 95 | def _cholesky(a: Numeric): 96 | return np.linalg.cholesky(a) 97 | 98 | 99 | @dispatch 100 | def cholesky_solve(a: Numeric, b: Numeric): 101 | return triangular_solve(transpose(a), triangular_solve(a, b), lower_a=False) 102 | 103 | 104 | @dispatch 105 | def triangular_solve(a: Numeric, b: Numeric, lower_a: bool = True): 106 | def _triangular_solve(a_, b_): 107 | return sla.solve_triangular( 108 | a_, b_, trans="N", lower=lower_a, check_finite=False 109 | ) 110 | 111 | return batch_computation(_triangular_solve, (a, b), (2, 2)) 112 | 113 | 114 | @dispatch 115 | def toeplitz_solve(a: Numeric, b: Numeric, c: Numeric): 116 | return _toeplitz_solve(a, b, c) 117 | -------------------------------------------------------------------------------- /lab/numpy/random.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Union 3 | 4 | import numpy as np 5 | 6 | from ..types import Int, NPDType, NPRandomState 7 | from ..util import broadcast_shapes 8 | from . 
import B, Numeric, dispatch 9 | 10 | __all__ = [] 11 | 12 | 13 | @dispatch 14 | def create_random_state(_: NPDType, seed: Int = 0): 15 | return np.random.RandomState(seed=seed) 16 | 17 | 18 | @dispatch 19 | def global_random_state(_: NPDType): 20 | return np.random.random.__self__ 21 | 22 | 23 | @dispatch 24 | def set_global_random_state(state: NPRandomState): 25 | np.random.random.__self__.set_state(state.get_state()) 26 | 27 | 28 | def _warn_dtype(dtype): 29 | if B.issubdtype(dtype, np.integer): 30 | warnings.warn("Casting random number of type float to type integer.") 31 | 32 | 33 | @dispatch 34 | def rand(state: NPRandomState, dtype: NPDType, *shape: Int): 35 | _warn_dtype(dtype) 36 | return state, B.cast(dtype, state.rand(*shape)) 37 | 38 | 39 | @dispatch 40 | def rand(dtype: NPDType, *shape: Int): 41 | return rand(global_random_state(dtype), dtype, *shape)[1] 42 | 43 | 44 | @dispatch 45 | def randn(state: NPRandomState, dtype: NPDType, *shape: Int): 46 | _warn_dtype(dtype) 47 | return state, B.cast(dtype, state.randn(*shape)) 48 | 49 | 50 | @dispatch 51 | def randn(dtype: NPDType, *shape: Int): 52 | return randn(global_random_state(dtype), dtype, *shape)[1] 53 | 54 | 55 | @dispatch 56 | def randcat(state: NPRandomState, p: Numeric, n: Int): 57 | # Probabilities must sum to one. 58 | p = p / np.sum(p, axis=-1, keepdims=True) 59 | # Perform sampling routine. 60 | cdf = np.cumsum(p, axis=-1) 61 | u = state.rand(n, *p.shape[:-1]) 62 | inds = np.sum(u[..., None] >= cdf[None], axis=-1) 63 | # Be sure to return the right data type. 64 | return state, B.cast(B.dtype_int(p), inds) 65 | 66 | 67 | @dispatch 68 | def randcat(p: Numeric, *shape: Int): 69 | return randcat(global_random_state(p), p, *shape)[1] 70 | 71 | 72 | @dispatch 73 | def choice(a: Numeric, *shape: Int, p: Union[Numeric, None] = None): 74 | # This method is necessary to break ambiguity. 75 | return choice(global_random_state(a), a, *shape, p=p)[1] 76 | 77 | 78 | @dispatch 79 | def randint(state: NPRandomState, dtype: NPDType, *shape: Int, lower: Int = 0, upper): 80 | dtype = B.dtype_int(dtype) 81 | return state, state.randint(lower, upper, shape, dtype=dtype) 82 | 83 | 84 | @dispatch 85 | def randint(dtype: NPDType, *shape: Int, lower: Int = 0, upper): 86 | state = global_random_state(dtype) 87 | return randint(state, dtype, *shape, lower=lower, upper=upper)[1] 88 | 89 | 90 | @dispatch 91 | def randperm(state: NPRandomState, dtype: NPDType, n: Int): 92 | dtype = B.dtype_int(dtype) 93 | return state, B.cast(dtype, state.permutation(n)) 94 | 95 | 96 | @dispatch 97 | def randperm(dtype: NPDType, n: Int): 98 | return randperm(global_random_state(dtype), dtype, n)[1] 99 | 100 | 101 | @dispatch 102 | def randgamma( 103 | state: NPRandomState, 104 | dtype: NPDType, 105 | *shape: Int, 106 | alpha: Numeric, 107 | scale: Numeric, 108 | ): 109 | _warn_dtype(dtype) 110 | shape = shape + broadcast_shapes(B.shape(alpha), B.shape(scale)) 111 | return state, B.cast(dtype, state.gamma(alpha, size=shape) * scale) 112 | 113 | 114 | @dispatch 115 | def randgamma(dtype: NPDType, *shape: Int, alpha: Numeric, scale: Numeric): 116 | state = global_random_state(dtype) 117 | return randgamma(state, dtype, *shape, alpha=alpha, scale=scale)[1] 118 | -------------------------------------------------------------------------------- /lab/numpy/shaping.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | 5 | from ..types import Int 6 | from . 
import Numeric, dispatch 7 | 8 | __all__ = [] 9 | 10 | 11 | @dispatch 12 | def length(a: Numeric): 13 | return np.size(a) 14 | 15 | 16 | @dispatch 17 | def _expand_dims(a: Numeric, axis: Int = 0): 18 | return np.expand_dims(a, axis=axis) 19 | 20 | 21 | @dispatch 22 | def squeeze(a: Numeric, axis: Union[Int, None] = None): 23 | return np.squeeze(a, axis=axis) 24 | 25 | 26 | @dispatch 27 | def broadcast_to(a: Numeric, *shape: Int): 28 | return np.broadcast_to(a, shape) 29 | 30 | 31 | @dispatch 32 | def diag(a: Numeric): 33 | return np.diag(a) 34 | 35 | 36 | @dispatch 37 | def diag_extract(a: Numeric): 38 | return np.diagonal(a, axis1=-2, axis2=-1) 39 | 40 | 41 | @dispatch 42 | def stack(*elements: Numeric, axis: Int = 0): 43 | return np.stack(elements, axis=axis) 44 | 45 | 46 | @dispatch 47 | def _unstack(a: Numeric, axis: Int = 0): 48 | out = np.split(a, np.arange(1, a.shape[axis]), axis) 49 | return [x.squeeze(axis=axis) for x in out] 50 | 51 | 52 | @dispatch 53 | def reshape(a: Numeric, *shape: Int): 54 | return np.reshape(a, shape) 55 | 56 | 57 | @dispatch 58 | def concat(*elements: Numeric, axis: Int = 0): 59 | return np.concatenate(elements, axis=axis) 60 | 61 | 62 | @dispatch 63 | def tile(a: Numeric, *repeats: Int): 64 | return np.tile(a, repeats) 65 | -------------------------------------------------------------------------------- /lab/random.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from functools import reduce 3 | from operator import mul 4 | from typing import Union 5 | 6 | import numpy as np 7 | 8 | from . import B, dispatch 9 | from .types import DType, Int, Numeric, RandomState 10 | from .util import abstract 11 | 12 | __all__ = [ 13 | "set_random_seed", 14 | "create_random_state", 15 | "global_random_state", 16 | "set_global_random_state", 17 | "rand", 18 | "randn", 19 | "randcat", 20 | "choice", 21 | "randint", 22 | "randperm", 23 | "randgamma", 24 | "randbeta", 25 | ] 26 | 27 | 28 | @dispatch 29 | def set_random_seed(seed: Int): 30 | """Set the random seed for all frameworks. 31 | 32 | Args: 33 | seed (int): Seed. 34 | """ 35 | # Set seed in NumPy. 36 | np.random.seed(seed) 37 | 38 | # Set seed for TensorFlow, if it is loaded. 39 | if "tensorflow" in sys.modules: 40 | import tensorflow as tf 41 | 42 | tf.random.set_seed(seed) 43 | tf.random.set_global_generator(tf.random.Generator.from_seed(seed)) 44 | 45 | # Set seed for PyTorch, if it is loaded. 46 | if "torch" in sys.modules: 47 | import torch 48 | 49 | torch.manual_seed(seed) 50 | 51 | # Set seed for JAX, if it is loaded. 52 | if hasattr(B, "jax_global_random_state"): 53 | import jax 54 | 55 | B.jax_global_random_state = jax.random.PRNGKey(seed=seed) 56 | 57 | 58 | @dispatch 59 | @abstract() 60 | def create_random_state(dtype: DType, seed: Int = 0): 61 | """Create a random state. 62 | 63 | Args: 64 | dtype (dtype): Data type of the desired framework to create a random state 65 | for. 66 | seed (int, optional): Seed to initialise the random state with. Defaults 67 | to `0`. 68 | 69 | Returns: 70 | random state: Random state. 71 | """ 72 | 73 | 74 | @dispatch 75 | @abstract() 76 | def global_random_state(dtype: DType): 77 | """Get the global random state. 78 | 79 | Args: 80 | dtype (dtype): Data type of the desired framework for which to get the global 81 | random state. 82 | 83 | Returns: 84 | random state: Global random state. 
85 | """ 86 | 87 | 88 | @dispatch 89 | @abstract() 90 | def set_global_random_state(state: RandomState): 91 | """Set the global random state. 92 | 93 | NOTE: 94 | In TensorFlow, setting the global random state does NOT fix the randomness 95 | for non-LAB random calls, like `tf.random.normal`. Use `B.set_random_seed` 96 | instead! 97 | 98 | Args: 99 | state (random state): Random state to set. 100 | """ 101 | 102 | 103 | @dispatch 104 | def global_random_state(a): 105 | return global_random_state(B.dtype(a)) 106 | 107 | 108 | @dispatch 109 | @abstract() 110 | def rand(state: RandomState, dtype: DType, *shape: Int): # pragma: no cover 111 | """Construct a U[0, 1] random tensor. 112 | 113 | Args: 114 | state (random state, optional): Random state. 115 | dtype (dtype, optional): Data type. Defaults to the default data type. 116 | *shape (shape, optional): Shape of the sample. Defaults to `()`. 117 | 118 | Returns: 119 | state (random state, optional): Random state. 120 | tensor: Random tensor. 121 | """ 122 | 123 | 124 | @dispatch 125 | def rand(*shape: Int): 126 | return rand(B.default_dtype, *shape) 127 | 128 | 129 | @dispatch 130 | def rand(state: RandomState, ref: Numeric): 131 | return rand(state, B.dtype(ref), *B.shape(ref)) 132 | 133 | 134 | @dispatch 135 | def rand(ref: Numeric): 136 | return rand(B.dtype(ref), *B.shape(ref)) 137 | 138 | 139 | @dispatch 140 | def rand(shape: Int): 141 | # Single integer is not a reference. 142 | return rand(B.default_dtype, shape) 143 | 144 | 145 | @dispatch 146 | @abstract() 147 | def randn(state: RandomState, dtype: DType, *shape: Int): # pragma: no cover 148 | """Construct a N(0, 1) random tensor. 149 | 150 | Args: 151 | state (random state, optional): Random state. 152 | dtype (dtype, optional): Data type. Defaults to the default data type. 153 | *shape (shape, optional): Shape of the sample. Defaults to `()`. 154 | 155 | Returns: 156 | state (random state, optional): Random state. 157 | tensor: Random tensor. 158 | """ 159 | 160 | 161 | @dispatch 162 | def randn(*shape: Int): 163 | return randn(B.default_dtype, *shape) 164 | 165 | 166 | @dispatch 167 | def randn(state: RandomState, ref: Numeric): 168 | return randn(state, B.dtype(ref), *B.shape(ref)) 169 | 170 | 171 | @dispatch 172 | def randn(ref: Numeric): 173 | return randn(B.dtype(ref), *B.shape(ref)) 174 | 175 | 176 | @dispatch 177 | def randn(shape: Int): 178 | return randn(B.default_dtype, shape) 179 | 180 | 181 | @dispatch 182 | def randcat(state: RandomState, p: Union[Numeric, None], *shape: Int): 183 | """Randomly draw from a categorical random variable. 184 | 185 | Args: 186 | state (random state, optional): Random state. 187 | p (tensor): Probabilities. The last axis determines the probabilities and 188 | any prior axes add to the shap of the sample. 189 | *shape (int): Shape of the sample. Defaults to `()`. 190 | 191 | Returns: 192 | state (random state, optional): Random state. 193 | tensor: Realisation. 194 | """ 195 | n = reduce(mul, shape, 1) 196 | state, sample = randcat(state, p, n) 197 | return state, B.reshape(sample, *shape, *B.shape(sample)[1:]) 198 | 199 | 200 | def _randcat_last_first(a): 201 | """Put the last dimension first. 202 | 203 | Args: 204 | a (tensor): Tensor. 205 | 206 | Returns: 207 | tensor: `a`, but with last dimension first. 
208 | """ 209 | perm = list(range(B.rank(a))) 210 | return B.transpose(a, perm=perm[-1:] + perm[:-1]) 211 | 212 | 213 | @dispatch 214 | def choice( 215 | state: RandomState, 216 | a: Numeric, 217 | *shape: Int, 218 | p: Union[Numeric, None] = None, 219 | ): 220 | """Randomly choose from a tensor *with* replacement. 221 | 222 | Args: 223 | state (random state, optional): Random state. 224 | a (tensor): Tensor to choose from. Choices will be made along the first 225 | dimension. 226 | *shape (int): Shape of the sample. Defaults to `()`. 227 | p (tensor, optional): Probabilities to sample with. 228 | 229 | Returns: 230 | state (random state, optional): Random state. 231 | tensor: Choices. 232 | """ 233 | if p is None: 234 | with B.on_device(a): 235 | p = B.ones(B.dtype_float(a), B.shape(a, 0)) 236 | state, inds = B.randcat(state, p, *shape) 237 | choices = B.reshape( 238 | B.take(a, B.flatten(inds), axis=0), 239 | *B.shape(inds), 240 | *B.shape(a)[1:], 241 | ) 242 | return state, choices 243 | 244 | 245 | @dispatch 246 | def choice( 247 | a: Numeric, 248 | *shape: Int, 249 | p: Union[Numeric, None] = None, 250 | ): 251 | state = B.global_random_state(a) 252 | state, choices = choice(state, a, *shape, p=p) 253 | B.set_global_random_state(state) 254 | return choices 255 | 256 | 257 | @dispatch 258 | @abstract() 259 | def randint( 260 | state: RandomState, 261 | dtype: DType, 262 | *shape: Int, 263 | lower: Int = 0, 264 | upper: Int, 265 | ): # pragma: no cover 266 | """Construct a tensor of random integers in [`lower`, `upper`). 267 | 268 | Args: 269 | state (random state, optional): Random state. 270 | dtype (dtype, optional): Data type. Defaults to the default data type. 271 | *shape (shape, optional): Shape of the tensor. Defaults to `()`. 272 | lower (int, optional): Lower bound. Defaults to `0`. 273 | upper (int): Upper bound. Must be given as a keyword argument. 274 | 275 | Returns: 276 | state (random state, optional): Random state. 277 | tensor: Random tensor. 278 | """ 279 | 280 | 281 | @dispatch 282 | def randint(*shape: Int, lower: Int = 0, upper: Int): 283 | return randint(B.default_dtype, *shape, lower=lower, upper=upper) 284 | 285 | 286 | @dispatch 287 | def randint(state: RandomState, ref: Numeric, *, lower: Int = 0, upper: Int): 288 | return randint(state, B.dtype(ref), *B.shape(ref), lower=lower, upper=upper) 289 | 290 | 291 | @dispatch 292 | def randint(ref: Numeric, *, lower: Int = 0, upper: Int): 293 | return randint(B.dtype(ref), *B.shape(ref), lower=lower, upper=upper) 294 | 295 | 296 | @dispatch 297 | def randint(shape: Int, *, lower: Int = 0, upper: Int): 298 | # Single integer is not a reference. 299 | return randint(B.default_dtype, shape, lower=lower, upper=upper) 300 | 301 | 302 | @dispatch 303 | @abstract() 304 | def randperm(state: RandomState, dtype: DType, n: Int): # pragma: no cover 305 | """Construct a random permutation counting to `n`. 306 | 307 | Args: 308 | state (random state, optional): Random state. 309 | dtype (dtype, optional): Data type. Defaults to the default data type. 310 | n (int): Length of the permutation. 311 | 312 | Returns: 313 | state (random state, optional): Random state. 314 | tensor: Random permutation. 
315 | """ 316 | 317 | 318 | @dispatch 319 | def randperm(n: Int): 320 | return randperm(B.default_dtype, n) 321 | 322 | 323 | @dispatch 324 | @abstract() 325 | def randgamma( 326 | state: RandomState, 327 | dtype: DType, 328 | *shape: Int, 329 | alpha: Numeric, 330 | scale: Numeric, 331 | ): # pragma: no cover 332 | """Construct a tensor of gamma random variables with shape parameter `alpha` and 333 | scale `scale`. 334 | 335 | Args: 336 | state (random state, optional): Random state. 337 | dtype (dtype, optional): Data type. Defaults to the default data type. 338 | *shape (shape, optional): Shape of the tensor. Defaults to `()`. 339 | alpha (scalar): Shape parameter. 340 | scale (scalar): Scale parameter. 341 | 342 | Returns: 343 | state (random state, optional): Random state. 344 | tensor: Random tensor. 345 | """ 346 | 347 | 348 | @dispatch 349 | def randgamma(*shape: Int, alpha: Numeric, scale: Numeric): 350 | return randgamma(B.default_dtype, *shape, alpha=alpha, scale=scale) 351 | 352 | 353 | @dispatch 354 | def randgamma(state: RandomState, ref: Numeric, *, alpha: Numeric, scale: Numeric): 355 | return randgamma(state, B.dtype(ref), *B.shape(ref), alpha=alpha, scale=scale) 356 | 357 | 358 | @dispatch 359 | def randgamma(ref: Numeric, *, alpha: Numeric, scale: Numeric): 360 | return randgamma(B.dtype(ref), *B.shape(ref), alpha=alpha, scale=scale) 361 | 362 | 363 | @dispatch 364 | def randgamma(shape: Int, *, alpha: Numeric, scale: Numeric): 365 | # Single integer is a not a reference. 366 | return randgamma(B.default_dtype, shape, alpha=alpha, scale=scale) 367 | 368 | 369 | @dispatch 370 | def randbeta( 371 | state: RandomState, 372 | dtype: DType, 373 | *shape: Int, 374 | alpha: Numeric, 375 | beta: Numeric, 376 | ): 377 | """Construct a tensor of beta random variables with shape parameters `alpha` and 378 | `beta`. 379 | 380 | Args: 381 | state (random state, optional): Random state. 382 | dtype (dtype, optional): Data type. Defaults to the default data type. 383 | *shape (shape, optional): Shape of the tensor. Defaults to `()`. 384 | alpha (scalar): Shape parameter `alpha`. 385 | beta (scalar): Shape parameter `beta`. 386 | 387 | Returns: 388 | state (random state, optional): Random state. 389 | tensor: Random tensor. 390 | """ 391 | state, x = randgamma(state, dtype, *shape, alpha=alpha, scale=1) 392 | state, y = randgamma(state, dtype, *shape, alpha=beta, scale=1) 393 | return state, x / (x + y) 394 | 395 | 396 | @dispatch 397 | def randbeta(dtype: DType, *shape: Int, alpha: Numeric, beta: Numeric): 398 | return randbeta( 399 | B.global_random_state(dtype), 400 | dtype, 401 | *shape, 402 | alpha=alpha, 403 | beta=beta, 404 | )[1] 405 | 406 | 407 | @dispatch 408 | def randbeta(*shape: Int, alpha: Numeric, beta: Numeric): 409 | return randbeta(B.default_dtype, *shape, alpha=alpha, beta=beta) 410 | 411 | 412 | @dispatch 413 | def randbeta(state: RandomState, ref: Numeric, *, alpha: Numeric, beta: Numeric): 414 | return randbeta(state, B.dtype(ref), *B.shape(ref), alpha=alpha, beta=beta) 415 | 416 | 417 | @dispatch 418 | def randbeta(ref: Numeric, *, alpha: Numeric, beta: Numeric): 419 | return randbeta(B.dtype(ref), *B.shape(ref), alpha=alpha, beta=beta) 420 | 421 | 422 | @dispatch 423 | def randbeta(shape: Int, *, alpha: Numeric, beta: Numeric): 424 | # Single integer is not a reference. 
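    # (Editorial note: as for `rand`, `randint`, and `randgamma` above, a bare
    # integer argument is interpreted as a shape to sample, not as a reference
    # tensor whose data type and shape should be copied.)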
425 | return randbeta(B.default_dtype, shape, alpha=alpha, beta=beta) 426 | -------------------------------------------------------------------------------- /lab/shape.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | from plum import Dispatcher 4 | 5 | from . import B, dispatch 6 | 7 | __all__ = ["Shape", "Dimension", "unwrap_dimension", "dispatch_unwrap_dimensions"] 8 | 9 | _dispatch = Dispatcher() 10 | 11 | 12 | class Shape: 13 | """A shape. 14 | 15 | Args: 16 | *dims (number): Dimensions of the shape. 17 | 18 | Attributes: 19 | dims (tuple[number]): Dimensions of the shape. 20 | """ 21 | 22 | def __init__(self, *dims): 23 | # Be careful to not wrap dimensions twice. 24 | self.dims = tuple(unwrap_dimension(dim) for dim in dims) 25 | 26 | @_dispatch 27 | def __getitem__(self, item): 28 | return Dimension(self.dims[item]) 29 | 30 | @_dispatch 31 | def __getitem__(self, item: slice): 32 | return Shape(*self.dims[item]) 33 | 34 | def __len__(self): 35 | return len(self.dims) 36 | 37 | def __iter__(self): 38 | for dim in self.dims: 39 | yield Dimension(dim) 40 | 41 | def __add__(self, other): 42 | return Shape(*(tuple(self) + tuple(other))) 43 | 44 | def __radd__(self, other): 45 | return Shape(*(tuple(other) + tuple(self))) 46 | 47 | def __eq__(self, other): 48 | return len(self) == len(other) and all(x == y for x, y in zip(self, other)) 49 | 50 | def __reversed__(self): 51 | return Shape(*reversed(self.dims)) 52 | 53 | def __repr__(self): 54 | return "Shape(" + ", ".join(repr(x) for x in self) + ")" 55 | 56 | def __str__(self): 57 | if len(self) == 0: 58 | return "()" 59 | elif len(self) == 1: 60 | return f"({self[0]!r},)" 61 | else: 62 | return "(" + ", ".join(repr(x) for x in self) + ")" 63 | 64 | def __hash__(self): 65 | return hash(self.dims) 66 | 67 | 68 | @dispatch 69 | def to_numpy(shape: Shape): 70 | return B.to_numpy(shape.dims) 71 | 72 | 73 | class Dimension: 74 | """A dimension in a shape. 75 | 76 | Args: 77 | dim (number): Dimension. 78 | 79 | Attributes: 80 | dim (number): Dimension. 
81 | """ 82 | 83 | def __init__(self, dim): 84 | self.dim = dim 85 | 86 | def __int__(self): 87 | return int(self.dim) 88 | 89 | def __len__(self): 90 | return len(self.dim) 91 | 92 | def __iter__(self): 93 | return iter(self.dim) 94 | 95 | def __eq__(self, other): 96 | return self.dim == other 97 | 98 | def __ge__(self, other): 99 | return self.dim >= other 100 | 101 | def __gt__(self, other): 102 | return self.dim > other 103 | 104 | def __le__(self, other): 105 | return self.dim <= other 106 | 107 | def __lt__(self, other): 108 | return self.dim < other 109 | 110 | def __add__(self, other): 111 | return self.dim + other 112 | 113 | def __radd__(self, other): 114 | return other + self.dim 115 | 116 | def __sub__(self, other): 117 | return self.dim - other 118 | 119 | def __rsub__(self, other): 120 | return other - self.dim 121 | 122 | def __mul__(self, other): 123 | return self.dim * other 124 | 125 | def __rmul__(self, other): 126 | return other * self.dim 127 | 128 | def __truediv__(self, other): 129 | return self.dim / other 130 | 131 | def __rtruediv__(self, other): 132 | return other / self.dim 133 | 134 | def __floordiv__(self, other): 135 | return self.dim // other 136 | 137 | def __rfloordiv__(self, other): 138 | return other // self.dim 139 | 140 | def __neg__(self): 141 | return -self.dim 142 | 143 | def __pow__(self, power, modulo=None): 144 | return self.dim.__pow__(power, modulo) 145 | 146 | def __repr__(self): 147 | return repr(self.dim) 148 | 149 | def __str__(self): 150 | return str(self.dim) 151 | 152 | def __hash__(self): 153 | return hash(self.dim) 154 | 155 | 156 | @_dispatch 157 | def unwrap_dimension(a): 158 | """Unwrap a dimension. 159 | 160 | Args: 161 | a (object): Dimension to unwrap. 162 | 163 | Returns: 164 | number: If `a` was wrapped with :class:`.shape.Dimension`, then this will be 165 | `a.dim`. Otherwise, the result is just `a`. 166 | """ 167 | return a 168 | 169 | 170 | @_dispatch 171 | def unwrap_dimension(a: Dimension): 172 | return a.dim 173 | 174 | 175 | def dispatch_unwrap_dimensions(dispatch): 176 | """Unwrap all dimensions after performing dispatch. 177 | 178 | Args: 179 | dispatch (decorator): Dispatch decorator. 180 | """ 181 | 182 | def unwrapped_dispatch(f): 183 | @wraps(f) 184 | def f_wrapped(*args, **kw_args): 185 | return f(*(unwrap_dimension(arg) for arg in args), **kw_args) 186 | 187 | return dispatch(f_wrapped) 188 | 189 | return unwrapped_dispatch 190 | -------------------------------------------------------------------------------- /lab/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .. import * 3 | from .. import dispatch as dispatch_original 4 | from ..shape import dispatch_unwrap_dimensions 5 | from ..types import NPNumeric, Number, TFNumeric 6 | 7 | dispatch = dispatch_unwrap_dimensions(dispatch_original) 8 | 9 | from typing import Union 10 | 11 | Numeric = Union[Number, NPNumeric, TFNumeric] 12 | 13 | import tensorflow as tf # Load `tensorflow` to load all new types. 14 | from plum import clear_all_cache as _clear_all_cache 15 | 16 | # noinspection PyUnresolvedReferences 17 | from .generic import * 18 | from .linear_algebra import * 19 | from .random import * 20 | from .shaping import * 21 | 22 | # Clear cache to make sure that all newly loaded types are available. 23 | _clear_all_cache() 24 | 25 | # Set TF device manager. 26 | B.ActiveDevice._tf_manager = tf.device 27 | 28 | # Alias to actual module. 
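# (Editorial note: this re-registers `lab.tensorflow` in `sys.modules` as the
# central `B` proxy, so `import lab.tensorflow as B` hands back the same
# namespace as `import lab as B`, now with the TensorFlow implementations above
# registered on it. For example, `B.sum(tf.ones((2, 3)))` then dispatches to
# `tf.reduce_sum`.)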
29 | sys.modules[__name__] = B 30 | -------------------------------------------------------------------------------- /lab/tensorflow/custom.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import tensorflow as tf 4 | from plum import Dispatcher, convert 5 | 6 | from . import B 7 | 8 | __all__ = ["tensorflow_register", "as_tf"] 9 | 10 | _dispatch = Dispatcher() 11 | 12 | 13 | @_dispatch 14 | def as_tf(x: B.Numeric): 15 | """Convert object to TensorFlow. 16 | 17 | Args: 18 | x (object): Object to convert. 19 | 20 | Returns: 21 | object: `x` as a TensorFlow object. 22 | """ 23 | dtype = convert(B.dtype(x), B.TFDType) 24 | return tf.constant(x, dtype=dtype) 25 | 26 | 27 | @_dispatch 28 | def as_tf(xs: tuple): 29 | return tuple([as_tf(x) for x in xs]) 30 | 31 | 32 | def _np_apply(f, out_dtypes, *args, **kw_args): 33 | """Apply a NumPy function in TensorFlow. 34 | 35 | Args: 36 | f (function): NumPy function. 37 | out_dtypes (list[dtype]): List of data types of the output. 38 | *args (object): Argument to `f`. 39 | **kw_args (object): Keyword arguments to `f`. 40 | 41 | Returns: 42 | tensor: Result as a TensorFlow operation. 43 | """ 44 | return tf.py_function( 45 | lambda *args_: f(*[arg.numpy() for arg in args_], **kw_args), args, out_dtypes 46 | ) 47 | 48 | 49 | def tensorflow_register(f, s_f): 50 | """Register a function and its sensitivity for TensorFlow. 51 | 52 | Args: 53 | f (function): Function to register. 54 | s_f (function): Sensitivity of `f`. 55 | 56 | Returns: 57 | function: TensorFlow primitive. 58 | """ 59 | 60 | @wraps(f) 61 | def primitive(*args, **kw_args): 62 | # TODO: This assumes that the output is of the data type of the first input. 63 | # Generally, this is *not* true. How to best approach this? 64 | y = _np_apply(f, args[0].dtype, *args, **kw_args) 65 | 66 | def grad(s_y): 67 | # TODO: This assumes that the sensitivities of the inputs are of the data 68 | # types of the inputs. Again, generally, this is *not* true. How to best 69 | # approach this? 70 | return _np_apply( 71 | s_f, [arg.dtype for arg in args], *((s_y, y) + args), **kw_args 72 | ) 73 | 74 | return y, grad 75 | 76 | return tf.custom_gradient(primitive) 77 | -------------------------------------------------------------------------------- /lab/tensorflow/generic.py: -------------------------------------------------------------------------------- 1 | from types import FunctionType 2 | from typing import Callable, Union 3 | 4 | import tensorflow as tf 5 | import tensorflow_probability as tfp 6 | 7 | from ..custom import bvn_cdf, s_bvn_cdf 8 | from ..types import Int, TFDType, TFRandomState 9 | from . import B, Numeric, TFNumeric, dispatch 10 | from .custom import tensorflow_register 11 | 12 | __all__ = [] 13 | 14 | 15 | @dispatch 16 | def isabstract(a: Numeric): 17 | return not tf.executing_eagerly() 18 | 19 | 20 | @dispatch 21 | def _jit_run( 22 | f: FunctionType, 23 | compilation_cache: dict, 24 | jit_kw_args: dict, 25 | *args: Union[Numeric, TFRandomState], 26 | **kw_args, 27 | ): 28 | if "tensorflow" not in compilation_cache: 29 | # Run once to populate the control flow cache. 30 | f(*args, **kw_args) 31 | # Default `autograph` to `False`. 32 | jit_kw_args = dict(jit_kw_args) 33 | if "autograph" not in jit_kw_args: 34 | jit_kw_args["autograph"] = False 35 | # Compile. 
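        # (Editorial note: the wrapped `tf.function` is stored in
        # `compilation_cache`, so subsequent calls through the same cache reuse
        # it instead of wrapping and tracing `f` again.)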
36 | compilation_cache["tensorflow"] = tf.function(f, **jit_kw_args) 37 | 38 | return compilation_cache["tensorflow"](*args, **kw_args) 39 | 40 | 41 | @dispatch 42 | def isnan(a: Numeric): 43 | return tf.math.is_nan(a) 44 | 45 | 46 | @dispatch 47 | def real(a: Numeric): 48 | return tf.math.real(a) 49 | 50 | 51 | @dispatch 52 | def imag(a: Numeric): 53 | return tf.math.imag(a) 54 | 55 | 56 | @dispatch 57 | def device(a: TFNumeric): 58 | return a.device 59 | 60 | 61 | @dispatch 62 | def to_active_device(a: TFNumeric): 63 | return a 64 | 65 | 66 | @dispatch 67 | def zeros(dtype: TFDType, *shape: Int): 68 | return tf.zeros(shape, dtype=dtype) 69 | 70 | 71 | @dispatch 72 | def ones(dtype: TFDType, *shape: Int): 73 | return tf.ones(shape, dtype=dtype) 74 | 75 | 76 | @dispatch 77 | def _eye2(dtype: TFDType, *shape: Int): 78 | return tf.eye(shape[0], shape[1], dtype=dtype) 79 | 80 | 81 | @dispatch 82 | def linspace(dtype: TFDType, a, b, num: Int): 83 | return tf.linspace(cast(dtype, a), cast(dtype, b), num) 84 | 85 | 86 | @dispatch 87 | def range(dtype: TFDType, start, stop, step): 88 | return tf.range(start, stop, step, dtype=dtype) 89 | 90 | 91 | @dispatch 92 | def cast(dtype: TFDType, a: Numeric): 93 | return tf.cast(a, dtype=dtype) 94 | 95 | 96 | @dispatch 97 | def identity(a: Numeric): 98 | return tf.identity(a) 99 | 100 | 101 | @dispatch 102 | def round(a: Numeric): 103 | return tf.math.round(a) 104 | 105 | 106 | @dispatch 107 | def floor(a: Numeric): 108 | return tf.math.floor(a) 109 | 110 | 111 | @dispatch 112 | def ceil(a: Numeric): 113 | return tf.math.ceil(a) 114 | 115 | 116 | @dispatch 117 | def negative(a: Numeric): 118 | return tf.negative(a) 119 | 120 | 121 | @dispatch 122 | def abs(a: Numeric): 123 | return tf.abs(a) 124 | 125 | 126 | @dispatch 127 | def sign(a: Numeric): 128 | return tf.sign(a) 129 | 130 | 131 | @dispatch 132 | def sqrt(a: Numeric): 133 | return tf.sqrt(a) 134 | 135 | 136 | @dispatch 137 | def exp(a: Numeric): 138 | return tf.exp(a) 139 | 140 | 141 | @dispatch 142 | def log(a: Numeric): 143 | return tf.math.log(a) 144 | 145 | 146 | @dispatch 147 | def log1p(a: Numeric): 148 | return tf.math.log1p(a) 149 | 150 | 151 | @dispatch 152 | def sin(a: Numeric): 153 | return tf.sin(a) 154 | 155 | 156 | @dispatch 157 | def arcsin(a: Numeric): 158 | return tf.asin(a) 159 | 160 | 161 | @dispatch 162 | def cos(a: Numeric): 163 | return tf.cos(a) 164 | 165 | 166 | @dispatch 167 | def arccos(a: Numeric): 168 | return tf.acos(a) 169 | 170 | 171 | @dispatch 172 | def tan(a: Numeric): 173 | return tf.tan(a) 174 | 175 | 176 | @dispatch 177 | def arctan(a: Numeric): 178 | return tf.atan(a) 179 | 180 | 181 | @dispatch 182 | def tanh(a: Numeric): 183 | return tf.tanh(a) 184 | 185 | 186 | @dispatch 187 | def arctanh(a: Numeric): 188 | return tf.atanh(a) 189 | 190 | 191 | @dispatch 192 | def loggamma(a: Numeric): 193 | return tf.math.lgamma(a) 194 | 195 | 196 | @dispatch 197 | def erf(a: Numeric): 198 | return tf.math.erf(a) 199 | 200 | 201 | @dispatch 202 | def softplus(a: TFNumeric): 203 | return tf.math.softplus(a) 204 | 205 | 206 | @dispatch 207 | def add(a: Numeric, b: Numeric): 208 | return tf.add(a, b) 209 | 210 | 211 | @dispatch 212 | def subtract(a: Numeric, b: Numeric): 213 | return tf.subtract(a, b) 214 | 215 | 216 | @dispatch 217 | def multiply(a: Numeric, b: Numeric): 218 | return tf.multiply(a, b) 219 | 220 | 221 | @dispatch 222 | def divide(a: Numeric, b: Numeric): 223 | return tf.divide(a, b) 224 | 225 | 226 | @dispatch 227 | def power(a: Numeric, b: Numeric): 228 | 
return tf.pow(a, b) 229 | 230 | 231 | @dispatch 232 | def minimum(a: Numeric, b: Numeric): 233 | return tf.minimum(a, b) 234 | 235 | 236 | @dispatch 237 | def maximum(a: Numeric, b: Numeric): 238 | return tf.maximum(a, b) 239 | 240 | 241 | @dispatch 242 | def min(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 243 | return tf.reduce_min(a, axis=axis, keepdims=not squeeze) 244 | 245 | 246 | @dispatch 247 | def argmin(a: Numeric, axis: Union[Int, None] = None): 248 | if axis is None: 249 | # The default `None` reduces over the last dimension. 250 | return tf.argmin(tf.reshape(a, -1), axis=0) 251 | else: 252 | return tf.argmin(a, axis=axis) 253 | 254 | 255 | @dispatch 256 | def max(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 257 | return tf.reduce_max(a, axis=axis, keepdims=not squeeze) 258 | 259 | 260 | @dispatch 261 | def argmax(a: Numeric, axis: Union[Int, None] = None): 262 | if axis is None: 263 | # The default `None` reduces over the last dimension. 264 | return tf.argmax(tf.reshape(a, -1), axis=0) 265 | else: 266 | return tf.argmax(a, axis=axis) 267 | 268 | 269 | @dispatch 270 | def sum(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 271 | return tf.reduce_sum(a, axis=axis, keepdims=not squeeze) 272 | 273 | 274 | @dispatch 275 | def prod(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 276 | return tf.reduce_prod(a, axis=axis, keepdims=not squeeze) 277 | 278 | 279 | @dispatch 280 | def mean(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 281 | return tf.reduce_mean(a, axis=axis, keepdims=not squeeze) 282 | 283 | 284 | @dispatch 285 | def std(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 286 | if axis is None: 287 | axes = list(range(B.rank(a))) 288 | else: 289 | axes = [axis] 290 | _, var = tf.nn.moments(a, axes=axes, keepdims=not squeeze) 291 | return tf.sqrt(var) 292 | 293 | 294 | @dispatch 295 | def all(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 296 | return tf.reduce_all(a, axis=axis, keepdims=not squeeze) 297 | 298 | 299 | @dispatch 300 | def any(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 301 | return tf.reduce_any(a, axis=axis, keepdims=not squeeze) 302 | 303 | 304 | @dispatch 305 | def lt(a: Numeric, b: Numeric): 306 | return tf.less(a, b) 307 | 308 | 309 | @dispatch 310 | def le(a: Numeric, b: Numeric): 311 | return tf.less_equal(a, b) 312 | 313 | 314 | @dispatch 315 | def gt(a: Numeric, b: Numeric): 316 | return tf.greater(a, b) 317 | 318 | 319 | @dispatch 320 | def ge(a: Numeric, b: Numeric): 321 | return tf.greater_equal(a, b) 322 | 323 | 324 | @dispatch 325 | def eq(a: Numeric, b: Numeric): 326 | return tf.equal(a, b) 327 | 328 | 329 | @dispatch 330 | def ne(a: Numeric, b: Numeric): 331 | return tf.not_equal(a, b) 332 | 333 | 334 | _bvn_cdf = tensorflow_register(bvn_cdf, s_bvn_cdf) 335 | 336 | 337 | @dispatch 338 | def bvn_cdf(a: Numeric, b: Numeric, c: Numeric): 339 | return _bvn_cdf(a, b, c) 340 | 341 | 342 | @dispatch 343 | def _cond(condition: TFNumeric, f_true: FunctionType, f_false: FunctionType, *args): 344 | return tf.cond(condition, lambda: f_true(*args), lambda: f_false(*args)) 345 | 346 | 347 | @dispatch 348 | def where(condition: Numeric, a: Numeric, b: Numeric): 349 | return tf.where(condition, a, b) 350 | 351 | 352 | # If `Numeric` types are used here, this implementation is more specific than the 353 | # generic implementation, which will use TensorFlow unnecessarily. 
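# (Editorial note on the above: restricting `xs` and `init_state` to `TFNumeric`
# keeps this TensorFlow-specific `scan` from being picked for plain NumPy
# inputs, which the generic implementation handles without dragging them into
# TensorFlow.)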
354 | @dispatch 355 | def scan(f: Callable, xs: TFNumeric, *init_state: TFNumeric): 356 | return tf.scan(f, xs, initializer=init_state) 357 | 358 | 359 | @dispatch 360 | def sort(a: Numeric, axis: Int = -1, descending: bool = False): 361 | if descending: 362 | direction = "DESCENDING" 363 | else: 364 | direction = "ASCENDING" 365 | return tf.sort(a, axis=axis, direction=direction) 366 | 367 | 368 | @dispatch 369 | def argsort(a: Numeric, axis: Int = -1, descending: bool = False): 370 | if descending: 371 | direction = "DESCENDING" 372 | else: 373 | direction = "ASCENDING" 374 | return tf.argsort(a, axis=axis, direction=direction) 375 | 376 | 377 | @dispatch 378 | def quantile(a: Numeric, q: Numeric, axis: Union[Int, None] = None): 379 | return tfp.stats.percentile(a, 100 * q, axis=axis, interpolation="linear") 380 | -------------------------------------------------------------------------------- /lab/tensorflow/linear_algebra.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import opt_einsum as oe 4 | import tensorflow as tf 5 | 6 | from ..custom import expm, logm, s_expm, s_logm, s_toeplitz_solve, toeplitz_solve 7 | from ..linear_algebra import _default_perm 8 | from ..types import Int 9 | from ..util import resolve_axis 10 | from . import B, Numeric, dispatch 11 | from .custom import tensorflow_register 12 | 13 | __all__ = [] 14 | 15 | 16 | @dispatch 17 | def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False): 18 | return tf.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b) 19 | 20 | 21 | @dispatch 22 | def einsum(equation: str, *elements: Numeric): 23 | return oe.contract(equation, *elements, backend="tensorflow") 24 | 25 | 26 | @dispatch 27 | def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None): 28 | # Correctly handle special cases. 
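    # (Editorial note: a rank-0 input is returned unchanged, and a rank-1 input
    # without an explicit `perm` is promoted to a row matrix via `a[None, :]`,
    # matching the NumPy and JAX backends.)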
29 | rank_a = B.rank(a) 30 | if rank_a == 0: 31 | return a 32 | elif rank_a == 1 and perm is None: 33 | return a[None, :] 34 | 35 | if perm is None: 36 | perm = _default_perm(a) 37 | return tf.transpose(a, perm=perm) 38 | 39 | 40 | @dispatch 41 | def trace(a: Numeric, axis1: Int = -2, axis2: Int = -1): 42 | axis1 = resolve_axis(a, axis1) 43 | axis2 = resolve_axis(a, axis2) 44 | perm = [i for i in range(B.rank(a)) if i not in [axis1, axis2]] 45 | perm += [axis1, axis2] 46 | a = tf.transpose(a, perm=perm) 47 | return tf.linalg.trace(a) 48 | 49 | 50 | @dispatch 51 | def svd(a: Numeric, compute_uv: bool = True): 52 | res = tf.linalg.svd(a, full_matrices=False, compute_uv=compute_uv) 53 | return (res[1], res[0], res[2]) if compute_uv else res 54 | 55 | 56 | @dispatch 57 | def eig(a: Numeric, compute_eigvecs: bool = True): 58 | vals, vecs = tf.linalg.eig(a) 59 | return (vals, vecs) if compute_eigvecs else vals 60 | 61 | 62 | @dispatch 63 | def solve(a: Numeric, b: Numeric): 64 | return tf.linalg.solve(a, b) 65 | 66 | 67 | @dispatch 68 | def inv(a: Numeric): 69 | return tf.linalg.inv(a) 70 | 71 | 72 | @dispatch 73 | def det(a: Numeric): 74 | return tf.linalg.det(a) 75 | 76 | 77 | @dispatch 78 | def logdet(a: Numeric): 79 | return tf.linalg.logdet(a) 80 | 81 | 82 | _expm = tensorflow_register(expm, s_expm) 83 | 84 | 85 | @dispatch 86 | def expm(a: Numeric): 87 | return _expm(a) 88 | 89 | 90 | _logm = tensorflow_register(logm, s_logm) 91 | 92 | 93 | @dispatch 94 | def logm(a: Numeric): 95 | return _logm(a) 96 | 97 | 98 | @dispatch 99 | def _cholesky(a: Numeric): 100 | return tf.linalg.cholesky(a) 101 | 102 | 103 | @dispatch 104 | def cholesky_solve(a: Numeric, b: Numeric): 105 | return tf.linalg.cholesky_solve(a, b) 106 | 107 | 108 | @dispatch 109 | def triangular_solve(a: Numeric, b: Numeric, lower_a: bool = True): 110 | return tf.linalg.triangular_solve(a, b, lower=lower_a) 111 | 112 | 113 | _toeplitz_solve = tensorflow_register(toeplitz_solve, s_toeplitz_solve) 114 | 115 | 116 | @dispatch 117 | def toeplitz_solve(a: Numeric, b: Numeric, c: Numeric): 118 | return _toeplitz_solve(a, b, c) 119 | -------------------------------------------------------------------------------- /lab/tensorflow/random.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import tensorflow as tf 4 | 5 | from ..random import _randcat_last_first 6 | from ..types import Int, TFDType, TFNumeric, TFRandomState 7 | from ..util import broadcast_shapes, compress_batch 8 | from . 
import B, Numeric, dispatch 9 | 10 | __all__ = [] 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | 15 | @dispatch 16 | def create_random_state(_: TFDType, seed: Int = 0): 17 | return tf.random.Generator.from_seed(seed) 18 | 19 | 20 | @dispatch 21 | def global_random_state(_: TFDType): 22 | return tf.random.get_global_generator() 23 | 24 | 25 | @dispatch 26 | def set_global_random_state(state: TFRandomState): 27 | tf.random.set_global_generator(state) 28 | 29 | 30 | @dispatch 31 | def rand(state: TFRandomState, dtype: TFDType, *shape: Int): 32 | return state, state.uniform(shape, dtype=dtype) 33 | 34 | 35 | @dispatch 36 | def rand(dtype: TFDType, *shape: Int): 37 | return rand(global_random_state(dtype), dtype, *shape)[1] 38 | 39 | 40 | @dispatch 41 | def randn(state: TFRandomState, dtype: TFDType, *shape: Int): 42 | return state, state.normal(shape, dtype=dtype) 43 | 44 | 45 | @dispatch 46 | def randn(dtype: TFDType, *shape: Int): 47 | return randn(global_random_state(dtype), dtype, *shape)[1] 48 | 49 | 50 | @dispatch 51 | def randcat(state: TFRandomState, p: TFNumeric, n: Int): 52 | # `p` must be at least rank two. 53 | if B.rank(p) == 1: 54 | p = B.expand_dims(p, axis=0) 55 | extra_dim = True 56 | else: 57 | extra_dim = False 58 | 59 | p, uncompress = compress_batch(p, 1) 60 | inds = tf.random.stateless_categorical( 61 | tf.math.log(p), 62 | n, 63 | state.make_seeds()[:, 0], 64 | ) 65 | inds = uncompress(inds) 66 | 67 | # Possibly remove the extra dimension. Do this before moving the last dimension 68 | # first! 69 | if extra_dim: 70 | inds = inds[0, :] 71 | 72 | inds = _randcat_last_first(inds) 73 | 74 | return state, inds 75 | 76 | 77 | @dispatch 78 | def randcat(p: TFNumeric, *shape: Int): 79 | return randcat(global_random_state(p), p, *shape)[1] 80 | 81 | 82 | @dispatch 83 | def randint( 84 | state: TFRandomState, 85 | dtype: TFDType, 86 | *shape: Int, 87 | lower: Int = 0, 88 | upper: Int, 89 | ): 90 | dtype = B.dtype_int(dtype) 91 | return state, state.uniform(shape, lower, upper, dtype=dtype) 92 | 93 | 94 | @dispatch 95 | def randint(dtype: TFDType, *shape: Int, lower: Int = 0, upper: Int): 96 | state = global_random_state(dtype) 97 | return randint(state, dtype, *shape, lower=lower, upper=upper)[1] 98 | 99 | 100 | @dispatch 101 | def randperm(state: TFRandomState, dtype: TFDType, n: Int): 102 | dtype = B.dtype_int(dtype) 103 | # TF does not have a function to generate a random permutation. One way to do it 104 | # manually is to generate a range of length `n` and then shuffle it, but TF also 105 | # does not have a stateless shuffle. Hence, to get a stateless random permutation, 106 | # we generate random numbers and sort them... 107 | # TODO: Do this in a better way. 
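    # Sorting `n` i.i.d. uniform draws gives a uniformly random ranking, so the
    # `argsort` below yields a valid random permutation (ties are vanishingly
    # unlikely).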
108 | perm = tf.argsort(state.uniform((n,), dtype=tf.float32)) 109 | return state, B.cast(dtype, perm) 110 | 111 | 112 | @dispatch 113 | def randperm(dtype: TFDType, n: Int): 114 | return randperm(global_random_state(dtype), dtype, n)[1] 115 | 116 | 117 | @dispatch 118 | def randgamma( 119 | state: TFRandomState, 120 | dtype: TFDType, 121 | *shape: Int, 122 | alpha: Numeric, 123 | scale: Numeric, 124 | ): 125 | sample = tf.random.stateless_gamma( 126 | shape + broadcast_shapes(B.shape(alpha), B.shape(scale)), 127 | alpha=alpha, 128 | seed=state.make_seeds()[:, 0], 129 | dtype=dtype, 130 | ) 131 | return state, sample * B.to_active_device(B.cast(dtype, scale)) 132 | 133 | 134 | @dispatch 135 | def randgamma(dtype: TFDType, *shape: Int, alpha: Numeric, scale: Numeric): 136 | state = global_random_state(dtype) 137 | return randgamma(state, dtype, *shape, alpha=alpha, scale=scale)[1] 138 | -------------------------------------------------------------------------------- /lab/tensorflow/shaping.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import tensorflow as tf 4 | 5 | from ..shape import unwrap_dimension 6 | from ..types import Int, NPNumeric, TFNumeric 7 | from ..util import resolve_axis 8 | from . import B, Numeric, dispatch 9 | 10 | __all__ = [] 11 | 12 | 13 | @dispatch 14 | def length(a: Numeric): 15 | return tf.size(a) 16 | 17 | 18 | @dispatch 19 | def _expand_dims(a: Numeric, axis: Int = 0): 20 | return tf.expand_dims(a, axis=axis) 21 | 22 | 23 | @dispatch 24 | def squeeze(a: Numeric, axis: Union[Int, None] = None): 25 | return tf.squeeze(a, axis=axis) 26 | 27 | 28 | @dispatch 29 | def broadcast_to(a: Numeric, *shape: Int): 30 | return tf.broadcast_to(a, shape) 31 | 32 | 33 | @dispatch 34 | def diag(a: Numeric): 35 | if B.rank(a) == 1: 36 | return tf.linalg.diag(a) 37 | elif B.rank(a) == 2: 38 | return tf.linalg.diag_part(a) 39 | else: 40 | raise ValueError("Input must have rank 1 or 2.") 41 | 42 | 43 | @dispatch 44 | def diag_extract(a: Numeric): 45 | return tf.linalg.diag_part(a) 46 | 47 | 48 | @dispatch 49 | def diag_construct(a: TFNumeric): 50 | return tf.linalg.diag(a) 51 | 52 | 53 | @dispatch 54 | def stack(*elements: Numeric, axis: Int = 0): 55 | return tf.stack(elements, axis=axis) 56 | 57 | 58 | @dispatch 59 | def _unstack(a: Numeric, axis: Int = 0): 60 | return tf.unstack(a, axis=axis) 61 | 62 | 63 | @dispatch 64 | def reshape(a: Numeric, *shape: Int): 65 | return tf.reshape(a, shape=shape) 66 | 67 | 68 | @dispatch 69 | def concat(*elements: Numeric, axis: Int = 0): 70 | return tf.concat(elements, axis=axis) 71 | 72 | 73 | @dispatch 74 | def tile(a: Numeric, *repeats: Int): 75 | return tf.tile(a, repeats) 76 | 77 | 78 | @dispatch 79 | def take(a: TFNumeric, indices_or_mask, axis: Int = 0): 80 | if B.rank(indices_or_mask) != 1: 81 | raise ValueError("Indices or mask must be rank 1.") 82 | is_mask, indices_or_mask, shape_hint = _is_mask_convert_shape_hint(indices_or_mask) 83 | 84 | # Perform taking operation. 85 | if is_mask: 86 | # `tf.boolean_mask` isn't happy with negative axes. 87 | result = tf.boolean_mask(a, indices_or_mask, axis=resolve_axis(a, axis)) 88 | else: 89 | result = tf.gather(a, indices_or_mask, axis=axis) 90 | 91 | # Apply the shape hint, if it is available. 92 | if shape_hint is not None: 93 | # Carefully unwrap to deal with lazy shapes. 
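        # For example, selecting four of ten rows of a `(10, 3)` tensor with a
        # NumPy mask sets the static shape to `(4, 3)`, which graph tracing would
        # otherwise lose.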
94 | shape = list(map(unwrap_dimension, B.shape(a))) 95 | shape[axis] = shape_hint 96 | result.set_shape(shape) 97 | 98 | return result 99 | 100 | 101 | @dispatch 102 | def _is_mask_convert_shape_hint(indices_or_mask: TFNumeric): 103 | return indices_or_mask.dtype == bool, indices_or_mask, None 104 | 105 | 106 | @dispatch 107 | def _is_mask_convert_shape_hint(indices_or_mask: NPNumeric): 108 | is_mask = indices_or_mask.dtype == bool 109 | if is_mask: 110 | shape_hint = sum(indices_or_mask) 111 | else: 112 | shape_hint = len(indices_or_mask) 113 | return is_mask, tf.constant(indices_or_mask), shape_hint 114 | 115 | 116 | @dispatch 117 | def _is_mask_convert_shape_hint(indices_or_mask: Union[tuple, list]): 118 | if len(indices_or_mask) == 0: 119 | # Treat an empty tuple or list as a list of no indices. The data type does not 120 | # matter, except that it must be integer. 121 | return False, tf.constant([], dtype=tf.int32), 0 122 | else: 123 | is_mask = B.dtype(indices_or_mask[0]) == bool 124 | if is_mask: 125 | shape_hint = sum(indices_or_mask) 126 | else: 127 | shape_hint = len(indices_or_mask) 128 | return is_mask, indices_or_mask, shape_hint 129 | -------------------------------------------------------------------------------- /lab/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | from .. import * 3 | from .. import dispatch as dispatch_original 4 | from ..shape import dispatch_unwrap_dimensions 5 | from ..types import Number, TorchNumeric 6 | 7 | dispatch = dispatch_unwrap_dimensions(dispatch_original) 8 | 9 | from typing import Union 10 | 11 | Numeric = Union[Number, TorchNumeric] 12 | 13 | import torch # Load `torch` to load all new types. 14 | from plum import clear_all_cache as _clear_all_cache 15 | 16 | # noinspection PyUnresolvedReferences 17 | from .generic import * 18 | from .linear_algebra import * 19 | from .random import * 20 | from .shaping import * 21 | 22 | # Clear cache to make sure that all newly loaded types are available. 23 | _clear_all_cache() 24 | 25 | # Alias to actual module. 26 | sys.modules[__name__] = B 27 | -------------------------------------------------------------------------------- /lab/torch/custom.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import torch 4 | from plum import Dispatcher, convert 5 | 6 | from . import B 7 | 8 | __all__ = ["torch_register", "as_torch"] 9 | 10 | _dispatch = Dispatcher() 11 | 12 | 13 | @_dispatch 14 | def as_torch(x: B.Numeric, grad: bool = False): 15 | """Convert object to PyTorch. 16 | 17 | Args: 18 | x (object): Object to convert. 19 | grad (bool, optional): Requires gradient. Defaults to `False`. 20 | 21 | Returns: 22 | object: `x` as a PyTorch object. 23 | """ 24 | dtype = convert(B.dtype(x), B.TorchDType) 25 | return torch.tensor(x, dtype=dtype, requires_grad=grad) 26 | 27 | 28 | @_dispatch 29 | def as_torch(xs: tuple, grad: bool = False): 30 | return tuple([as_torch(x, grad=grad) for x in xs]) 31 | 32 | 33 | def torch_register(f, s_f): 34 | """Register a function and its sensitivity for PyTorch. 35 | 36 | Args: 37 | f (function): Function to register. 38 | s_f (function): Sensitivity of `f`. 39 | 40 | Returns: 41 | function: PyTorch primitive. 42 | """ 43 | 44 | # Create a custom PyTorch function. 
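    # `forward` runs the NumPy implementation `f` on the inputs converted to NumPy
    # and saves the result together with the inputs; `backward` then maps the
    # incoming cotangent through the user-supplied sensitivity `s_f`.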
45 | class Function(torch.autograd.Function): 46 | @staticmethod 47 | def forward(ctx, *args): 48 | y = f(*B.to_numpy(args)) 49 | ctx.save_for_backward(as_torch(y), *args) 50 | return as_torch(y) 51 | 52 | @staticmethod 53 | def backward(ctx, s_y): # pragma: no cover 54 | # The profiler does not catch that this is tested. 55 | y = ctx.saved_tensors[0] 56 | args = ctx.saved_tensors[1:] 57 | return as_torch(s_f(s_y.numpy(), y.numpy(), *B.to_numpy(args))) 58 | 59 | # Wrap it to preserve the function name. 60 | 61 | @wraps(f) 62 | def f_wrapped(*args, **kw_args): 63 | return Function.apply(*args, **kw_args) 64 | 65 | return f_wrapped 66 | -------------------------------------------------------------------------------- /lab/torch/generic.py: -------------------------------------------------------------------------------- 1 | from types import FunctionType 2 | from typing import Union 3 | 4 | import torch 5 | from torch.jit import is_tracing, trace 6 | 7 | from ..custom import bvn_cdf, s_bvn_cdf 8 | from ..shape import Dimension 9 | from ..types import Int, NPNumeric, Number, TorchDType, TorchNumeric, TorchRandomState 10 | from . import B, Numeric, dispatch 11 | from .custom import torch_register 12 | 13 | __all__ = [] 14 | 15 | 16 | @dispatch 17 | def isabstract(a: Numeric): 18 | return is_tracing() 19 | 20 | 21 | @dispatch 22 | def _jit_run( 23 | f: FunctionType, 24 | compilation_cache: dict, 25 | jit_kw_args: dict, 26 | *args: Union[Numeric, TorchRandomState], 27 | ): 28 | if "torch" not in compilation_cache: 29 | # Run once to populate the control flow cache. 30 | f(*args) 31 | # Compile. 32 | compilation_cache["torch"] = trace(f, args, **jit_kw_args) 33 | 34 | return compilation_cache["torch"](*args) 35 | 36 | 37 | @dispatch 38 | def isnan(a: Numeric): 39 | return torch.isnan(a) 40 | 41 | 42 | @dispatch 43 | def real(a: Numeric): 44 | return torch.real(a) 45 | 46 | 47 | @dispatch 48 | def imag(a: Numeric): 49 | return torch.imag(a) 50 | 51 | 52 | @dispatch 53 | def device(a: TorchNumeric): 54 | return a.device 55 | 56 | 57 | @dispatch 58 | def to_active_device(a: TorchNumeric): 59 | return a.to(B.ActiveDevice.active_name) 60 | 61 | 62 | @dispatch 63 | def zeros(dtype: TorchDType, *shape: Int): 64 | return torch.zeros(shape, dtype=dtype, device=B.ActiveDevice.active_name) 65 | 66 | 67 | @dispatch 68 | def ones(dtype: TorchDType, *shape: Int): 69 | return torch.ones(shape, dtype=dtype, device=B.ActiveDevice.active_name) 70 | 71 | 72 | @dispatch 73 | def _eye2(dtype: TorchDType, *shape: Int): 74 | return torch.eye(shape[0], shape[1], dtype=dtype, device=B.ActiveDevice.active_name) 75 | 76 | 77 | @dispatch 78 | def linspace(dtype: TorchDType, a, b, num: Int): 79 | return torch.linspace(a, b, num, dtype=dtype, device=B.ActiveDevice.active_name) 80 | 81 | 82 | @dispatch 83 | def range(dtype: TorchDType, start, stop, step): 84 | return torch.arange( 85 | start, stop, step, dtype=dtype, device=B.ActiveDevice.active_name 86 | ) 87 | 88 | 89 | @dispatch 90 | def cast(dtype: TorchDType, a: TorchNumeric): 91 | return a.type(dtype) 92 | 93 | 94 | @dispatch 95 | def cast(dtype: TorchDType, a: Union[Number, NPNumeric]): 96 | return torch.tensor(a, dtype=dtype, device=B.ActiveDevice.active_name) 97 | 98 | 99 | @dispatch 100 | def cast(dtype: TorchDType, a: Dimension): 101 | # A dimension may automatically unwrap to a PyTorch tensor. 
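    # Because `dispatch` unwraps `Dimension` arguments, this re-dispatch should
    # resolve to the `cast` method for the unwrapped value rather than recurse.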
102 | return cast(dtype, a) 103 | 104 | 105 | @dispatch 106 | def identity(a: Numeric): 107 | return torch.multiply(1, a) 108 | 109 | 110 | @dispatch 111 | def round(a: Numeric): 112 | return torch.round(a) 113 | 114 | 115 | @dispatch 116 | def floor(a: Numeric): 117 | return torch.floor(a) 118 | 119 | 120 | @dispatch 121 | def ceil(a: Numeric): 122 | return torch.ceil(a) 123 | 124 | 125 | @dispatch 126 | def negative(a: Numeric): 127 | return torch.neg(a) 128 | 129 | 130 | @dispatch 131 | def abs(a: Numeric): 132 | return torch.abs(a) 133 | 134 | 135 | @dispatch 136 | def sign(a: Numeric): 137 | return torch.sign(a) 138 | 139 | 140 | @dispatch 141 | def sqrt(a: Numeric): 142 | return torch.sqrt(a) 143 | 144 | 145 | @dispatch 146 | def exp(a: Numeric): 147 | return torch.exp(a) 148 | 149 | 150 | @dispatch 151 | def log(a: Numeric): 152 | return torch.log(a) 153 | 154 | 155 | @dispatch 156 | def log1p(a: Numeric): 157 | return torch.log1p(a) 158 | 159 | 160 | @dispatch 161 | def sin(a: Numeric): 162 | return torch.sin(a) 163 | 164 | 165 | @dispatch 166 | def arcsin(a: Numeric): 167 | return torch.arcsin(a) 168 | 169 | 170 | @dispatch 171 | def cos(a: Numeric): 172 | return torch.cos(a) 173 | 174 | 175 | @dispatch 176 | def arccos(a: Numeric): 177 | return torch.arccos(a) 178 | 179 | 180 | @dispatch 181 | def tan(a: Numeric): 182 | return torch.tan(a) 183 | 184 | 185 | @dispatch 186 | def arctan(a: Numeric): 187 | return torch.arctan(a) 188 | 189 | 190 | @dispatch 191 | def tanh(a: Numeric): 192 | return torch.tanh(a) 193 | 194 | 195 | @dispatch 196 | def arctanh(a: Numeric): 197 | return torch.arctanh(a) 198 | 199 | 200 | @dispatch 201 | def loggamma(a: Numeric): 202 | return torch.lgamma(a) 203 | 204 | 205 | @dispatch 206 | def erf(a: Numeric): 207 | return torch.erf(a) 208 | 209 | 210 | @dispatch 211 | def softplus(a: TorchNumeric): 212 | return torch.nn.functional.softplus(a) 213 | 214 | 215 | @dispatch 216 | def add(a: Numeric, b: Numeric): 217 | return torch.add(a, b) 218 | 219 | 220 | @dispatch 221 | def subtract(a: Numeric, b: Numeric): 222 | return torch.subtract(a, b) 223 | 224 | 225 | @dispatch 226 | def multiply(a: Numeric, b: Numeric): 227 | return torch.multiply(a, b) 228 | 229 | 230 | @dispatch 231 | def divide(a: Numeric, b: Numeric): 232 | return torch.divide(a, b) 233 | 234 | 235 | @dispatch 236 | def power(a: Numeric, b: Numeric): 237 | return torch.pow(a, b) 238 | 239 | 240 | @dispatch 241 | def minimum(a: Numeric, b: Numeric): 242 | return torch.min(a, b) 243 | 244 | 245 | @dispatch 246 | def maximum(a: Numeric, b: Numeric): 247 | return torch.max(a, b) 248 | 249 | 250 | @dispatch 251 | def min(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 252 | if axis is None: 253 | return torch.min(a) 254 | else: 255 | return torch.min(a, dim=axis, keepdim=not squeeze)[0] 256 | 257 | 258 | @dispatch 259 | def argmin(a: Numeric, axis: Union[Int, None] = None): 260 | return torch.argmin(a, dim=axis) 261 | 262 | 263 | @dispatch 264 | def max(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 265 | if axis is None: 266 | return torch.max(a) 267 | else: 268 | return torch.max(a, dim=axis, keepdim=not squeeze)[0] 269 | 270 | 271 | @dispatch 272 | def argmax(a: Numeric, axis: Union[Int, None] = None): 273 | return torch.argmax(a, dim=axis) 274 | 275 | 276 | @dispatch 277 | def sum(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 278 | if axis is None: 279 | return torch.sum(a) 280 | else: 281 | return torch.sum(a, dim=axis, keepdim=not 
squeeze) 282 | 283 | 284 | @dispatch 285 | def prod(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 286 | if axis is None: 287 | return torch.prod(a) 288 | else: 289 | return torch.prod(a, dim=axis, keepdim=not squeeze) 290 | 291 | 292 | @dispatch 293 | def mean(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 294 | # Only `torch.mean` allows `dim=None`. The other functions don't. 295 | return torch.mean(a, dim=axis, keepdim=not squeeze) 296 | 297 | 298 | @dispatch 299 | def std(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 300 | if axis is None: 301 | return torch.std(a, unbiased=False) 302 | else: 303 | return torch.std(a, dim=axis, unbiased=False, keepdim=not squeeze) 304 | 305 | 306 | @dispatch 307 | def all(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 308 | if axis is None: 309 | return torch.all(a) 310 | else: 311 | return torch.all(a, dim=axis, keepdim=not squeeze) 312 | 313 | 314 | @dispatch 315 | def any(a: Numeric, axis: Union[Int, None] = None, squeeze: bool = True): 316 | if axis is None: 317 | return torch.any(a) 318 | else: 319 | return torch.any(a, dim=axis, keepdim=not squeeze) 320 | 321 | 322 | @dispatch 323 | def lt(a: Numeric, b: Numeric): 324 | return torch.lt(a, b) 325 | 326 | 327 | @dispatch 328 | def le(a: Numeric, b: Numeric): 329 | return torch.le(a, b) 330 | 331 | 332 | @dispatch 333 | def gt(a: Numeric, b: Numeric): 334 | return torch.gt(a, b) 335 | 336 | 337 | @dispatch 338 | def ge(a: Numeric, b: Numeric): 339 | return torch.ge(a, b) 340 | 341 | 342 | @dispatch 343 | def eq(a: Numeric, b: Numeric): 344 | return torch.eq(a, b) 345 | 346 | 347 | @dispatch 348 | def ne(a: Numeric, b: Numeric): 349 | return torch.ne(a, b) 350 | 351 | 352 | _bvn_cdf = torch_register(bvn_cdf, s_bvn_cdf) 353 | 354 | 355 | @dispatch 356 | def bvn_cdf(a: Numeric, b: Numeric, c: Numeric): 357 | return _bvn_cdf(a, b, c) 358 | 359 | 360 | @dispatch 361 | def where(condition: Numeric, a: Numeric, b: Numeric): 362 | return torch.where(condition, a, b) 363 | 364 | 365 | @dispatch 366 | def sort(a: Numeric, axis: Int = -1, descending: bool = False): 367 | return torch.sort(a, dim=axis, descending=descending)[0] 368 | 369 | 370 | @dispatch 371 | def argsort(a: Numeric, axis: Int = -1, descending: bool = False): 372 | return torch.argsort(a, dim=axis, descending=descending) 373 | 374 | 375 | @dispatch 376 | def quantile(a: Numeric, q: Numeric, axis: Union[Int, None] = None): 377 | return torch.quantile(a, q, dim=axis) 378 | -------------------------------------------------------------------------------- /lab/torch/linear_algebra.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import opt_einsum as oe 4 | import torch 5 | 6 | from ..custom import expm, logm, s_expm, s_logm, s_toeplitz_solve, toeplitz_solve 7 | from ..linear_algebra import _default_perm 8 | from ..types import Int 9 | from . 
import B, Numeric, dispatch 10 | from .custom import torch_register 11 | 12 | __all__ = [] 13 | 14 | 15 | @dispatch 16 | def matmul(a: Numeric, b: Numeric, tr_a: bool = False, tr_b: bool = False): 17 | a = transpose(a) if tr_a else a 18 | b = transpose(b) if tr_b else b 19 | return torch.matmul(a, b) 20 | 21 | 22 | @dispatch 23 | def einsum(equation: str, *elements: Numeric): 24 | return oe.contract(equation, *elements, backend="torch") 25 | 26 | 27 | @dispatch 28 | def transpose(a: Numeric, perm: Optional[Union[tuple, list]] = None): 29 | # Correctly handle special cases. 30 | rank_a = B.rank(a) 31 | if rank_a == 0: 32 | return a 33 | elif rank_a == 1 and perm is None: 34 | return a[None, :] 35 | 36 | if perm is None: 37 | perm = _default_perm(a) 38 | return a.permute(*perm) 39 | 40 | 41 | @dispatch 42 | def trace(a: Numeric, axis1: Int = -2, axis2: Int = -1): 43 | return torch.sum(torch.diagonal(a, dim1=axis1, dim2=axis2), dim=-1) 44 | 45 | 46 | @dispatch 47 | def svd(a: Numeric, compute_uv: bool = True): 48 | u, s, v = torch.linalg.svd(a, full_matrices=False) 49 | return (u, s, v) if compute_uv else s 50 | 51 | 52 | @dispatch 53 | def eig(a: Numeric, compute_eigvecs: bool = True): 54 | vals, vecs = torch.linalg.eig(a) 55 | return (vals, vecs) if compute_eigvecs else vals 56 | 57 | 58 | @dispatch 59 | def solve(a: Numeric, b: Numeric): 60 | return torch.linalg.solve(a, b) 61 | 62 | 63 | @dispatch 64 | def inv(a: Numeric): 65 | return torch.inverse(a) 66 | 67 | 68 | @dispatch 69 | def det(a: Numeric): 70 | return torch.linalg.det(a) 71 | 72 | 73 | @dispatch 74 | def logdet(a: Numeric): 75 | return torch.logdet(a) 76 | 77 | 78 | _expm = torch_register(expm, s_expm) 79 | 80 | 81 | @dispatch 82 | def expm(a: Numeric): 83 | return _expm(a) 84 | 85 | 86 | _logm = torch_register(logm, s_logm) 87 | 88 | 89 | @dispatch 90 | def logm(a: Numeric): 91 | return _logm(a) 92 | 93 | 94 | @dispatch 95 | def _cholesky(a: Numeric): 96 | return torch.linalg.cholesky(a) 97 | 98 | 99 | @dispatch 100 | def cholesky_solve(a: Numeric, b: Numeric): 101 | return torch.cholesky_solve(b, a, upper=False) 102 | 103 | 104 | @dispatch 105 | def triangular_solve(a: Numeric, b: Numeric, lower_a: bool = True): 106 | return torch.linalg.solve_triangular(a, b, upper=not lower_a) 107 | 108 | 109 | _toeplitz_solve = torch_register(toeplitz_solve, s_toeplitz_solve) 110 | 111 | 112 | @dispatch 113 | def toeplitz_solve(a: Numeric, b: Numeric, c: Numeric): 114 | return _toeplitz_solve(a, b, c) 115 | -------------------------------------------------------------------------------- /lab/torch/random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..random import _randcat_last_first 4 | from ..types import Int, TorchDType, TorchNumeric, TorchRandomState 5 | from ..util import compress_batch 6 | from . 
import B, Numeric, dispatch 7 | 8 | __all__ = [] 9 | 10 | 11 | @dispatch 12 | def create_random_state(_: TorchDType, seed: Int = 0): 13 | state = torch.Generator(device=B.ActiveDevice.active_name) 14 | state.manual_seed(seed) 15 | return state 16 | 17 | 18 | @dispatch 19 | def global_random_state(_: TorchDType): 20 | if B.ActiveDevice.active_name in {None, "cpu"}: 21 | return torch.random.default_generator 22 | else: 23 | parts = B.ActiveDevice.active_name.lower().split(":", 1) 24 | 25 | if len(parts) == 0 or parts[0] not in {"cuda", "gpu", "mps"}: 26 | raise RuntimeError(f'Unknown active device "{B.ActiveDevice.active_name}".') 27 | 28 | if parts[0] == "mps": 29 | if len(parts) == 2 and int(parts[1]) != 0: 30 | raise ValueError("Cannot specify a device number for PyTorch MPS.") 31 | 32 | import torch.mps as mps 33 | 34 | return mps._get_default_mps_generator() 35 | else: 36 | # Ensure that the generators are available. 37 | if len(torch.cuda.default_generators) == 0: 38 | torch.cuda.init() 39 | 40 | if len(parts) == 1: 41 | return torch.cuda.default_generators[0] 42 | else: 43 | return torch.cuda.default_generators[int(parts[1])] 44 | 45 | 46 | @dispatch 47 | def set_global_random_state(state: TorchRandomState): 48 | global_gen = global_random_state.invoke(TorchDType)(None) 49 | global_gen.set_state(state.get_state()) 50 | 51 | 52 | @dispatch 53 | def rand(state: TorchRandomState, dtype: TorchDType, *shape: Int): 54 | return state, torch.rand( 55 | shape, 56 | dtype=dtype, 57 | device=B.ActiveDevice.active_name, 58 | generator=state, 59 | ) 60 | 61 | 62 | @dispatch 63 | def rand(dtype: TorchDType, *shape: Int): 64 | return rand(global_random_state(dtype), dtype, *shape)[1] 65 | 66 | 67 | @dispatch 68 | def randn(state: TorchRandomState, dtype: TorchDType, *shape: Int): 69 | return state, torch.randn( 70 | shape, 71 | dtype=dtype, 72 | device=B.ActiveDevice.active_name, 73 | generator=state, 74 | ) 75 | 76 | 77 | @dispatch 78 | def randn(dtype: TorchDType, *shape: Int): 79 | return randn(global_random_state(dtype), dtype, *shape)[1] 80 | 81 | 82 | @dispatch 83 | def randcat(state: TorchRandomState, p: TorchNumeric, n: Int): 84 | p, uncompress = compress_batch(p, 1) 85 | inds = torch.multinomial(p, n, replacement=True, generator=state) 86 | inds = uncompress(inds) 87 | inds = _randcat_last_first(inds) 88 | return state, inds 89 | 90 | 91 | @dispatch 92 | def randcat(p: TorchNumeric, *shape: Int): 93 | return randcat(B.global_random_state(p), p, *shape)[1] 94 | 95 | 96 | @dispatch 97 | def randint( 98 | state: TorchRandomState, 99 | dtype: TorchDType, 100 | *shape: Int, 101 | lower: Int = 0, 102 | upper: Int, 103 | ): 104 | dtype = B.dtype_int(dtype) 105 | return state, torch.randint( 106 | lower, 107 | upper, 108 | shape, 109 | dtype=dtype, 110 | device=B.ActiveDevice.active_name, 111 | generator=state, 112 | ) 113 | 114 | 115 | @dispatch 116 | def randint(dtype: TorchDType, *shape: Int, lower: Int = 0, upper): 117 | state = global_random_state(dtype) 118 | return randint(state, dtype, *shape, lower=lower, upper=upper)[1] 119 | 120 | 121 | @dispatch 122 | def randperm(state: TorchRandomState, dtype: TorchDType, n: Int): 123 | dtype = B.dtype_int(dtype) 124 | return state, torch.randperm( 125 | n, 126 | dtype=dtype, 127 | device=B.ActiveDevice.active_name, 128 | generator=state, 129 | ) 130 | 131 | 132 | @dispatch 133 | def randperm(dtype: TorchDType, n: Int): 134 | return randperm(global_random_state(dtype), dtype, n)[1] 135 | 136 | 137 | @dispatch 138 | def randgamma( 139 | state: 
TorchRandomState, 140 | dtype: TorchDType, 141 | *shape: Int, 142 | alpha: Numeric, 143 | scale: Numeric, 144 | ): 145 | alpha = B.to_active_device(B.cast(dtype, alpha)) 146 | scale = B.to_active_device(B.cast(dtype, scale)) 147 | alpha, scale = torch.broadcast_tensors(alpha, scale) 148 | alpha = B.repeat(alpha, *shape) 149 | scale = B.repeat(scale, *shape) 150 | return state, torch._standard_gamma(alpha, generator=state) * scale 151 | 152 | 153 | @dispatch 154 | def randgamma(dtype: TorchDType, *shape: Int, alpha: Numeric, scale: Numeric): 155 | state = global_random_state(dtype) 156 | return randgamma(state, dtype, *shape, alpha=alpha, scale=scale)[1] 157 | -------------------------------------------------------------------------------- /lab/torch/shaping.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | from plum import convert 6 | 7 | from ..types import Int, NPDType, issubdtype 8 | from . import Numeric, TorchNumeric, dispatch 9 | 10 | __all__ = [] 11 | 12 | 13 | @dispatch 14 | def length(a: Numeric): 15 | return a.numel() 16 | 17 | 18 | @dispatch 19 | def _expand_dims(a: Numeric, axis: Int = 0): 20 | return torch.unsqueeze(a, dim=axis) 21 | 22 | 23 | @dispatch 24 | def squeeze(a: Numeric, axis: Union[Int, None] = None): 25 | if axis is None: 26 | return torch.squeeze(a) 27 | else: 28 | return torch.squeeze(a, dim=axis) 29 | 30 | 31 | @dispatch 32 | def broadcast_to(a: Numeric, *shape: Int): 33 | return torch.broadcast_to(a, shape) 34 | 35 | 36 | @dispatch 37 | def diag(a: Numeric): 38 | return torch.diag(a) 39 | 40 | 41 | @dispatch 42 | def diag_extract(a: Numeric): 43 | return torch.diagonal(a, dim1=-2, dim2=-1) 44 | 45 | 46 | @dispatch 47 | def diag_construct(a: Numeric): 48 | return torch.diag_embed(a, dim1=-2, dim2=-1) 49 | 50 | 51 | @dispatch 52 | def stack(*elements: Numeric, axis: Int = 0): 53 | return torch.stack(elements, dim=axis) 54 | 55 | 56 | @dispatch 57 | def _unstack(a: Numeric, axis: Int = 0): 58 | return torch.unbind(a, dim=axis) 59 | 60 | 61 | @dispatch 62 | def reshape(a: Numeric, *shape: Int): 63 | return torch.reshape(a, shape=shape) 64 | 65 | 66 | @dispatch 67 | def concat(*elements: Numeric, axis: Int = 0): 68 | return torch.cat(elements, dim=axis) 69 | 70 | 71 | @dispatch 72 | def tile(a: Numeric, *repeats: Int): 73 | return a.repeat(*repeats) 74 | 75 | 76 | @dispatch 77 | def _take_convert(indices_or_mask: Union[list, tuple]): 78 | return indices_or_mask 79 | 80 | 81 | @dispatch 82 | def _take_convert(indices_or_mask: TorchNumeric): 83 | if issubdtype(convert(indices_or_mask.dtype, NPDType), np.integer): 84 | # Indices must be on the CPU and `int64`s! 85 | return indices_or_mask.cpu().type(torch.int64) 86 | else: 87 | return indices_or_mask 88 | -------------------------------------------------------------------------------- /lab/util.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import numpy as np 4 | import plum 5 | import plum.signature 6 | import plum.type 7 | 8 | from . import B 9 | 10 | __all__ = [ 11 | "resolve_axis", 12 | "as_tuple", 13 | "batch_computation", 14 | "abstract", 15 | "compress_batch", 16 | "broadcast_shapes", 17 | ] 18 | 19 | _dispatch = plum.Dispatcher() 20 | 21 | 22 | def resolve_axis(a, axis, negative=False): 23 | """Resolve axis for a tensor `a`. 24 | 25 | Args: 26 | a (tensor): Tensor of the axis. 27 | axis (int or None): Axis to resolve. 
28 | negative (bool, optional): Resolve the axis to a negative integer rather than 29 | a positive integer. Defaults to `False`. 30 | 31 | Return: 32 | int: Resolved axis. 33 | """ 34 | # Let `None`s pass through. 35 | if axis is None: 36 | return None 37 | 38 | given_axis = axis 39 | 40 | # If it isn't a `None` we should resolve it. 41 | a_rank = B.rank(a) 42 | if not negative: 43 | if axis < 0: 44 | axis = axis + a_rank 45 | if not (0 <= axis < a_rank): 46 | raise ValueError( 47 | f"Axis {given_axis} cannot be resolved for tensor of shape " 48 | f"{B.shape(a)}." 49 | ) 50 | else: 51 | if axis >= 0: 52 | axis = axis - a_rank 53 | if not (-a_rank <= axis < 0): 54 | raise ValueError( 55 | f"Axis {given_axis} cannot be resolved for tensor of shape " 56 | f"{B.shape(a)}." 57 | ) 58 | return axis 59 | 60 | 61 | @_dispatch 62 | def as_tuple(x: tuple): 63 | """Get `x` as a tuple. Will be wrapped in a one-tuple if it is not a tuple. 64 | 65 | Args: 66 | x (object): Object to get as a tuple. 67 | 68 | Returns: 69 | tuple: `x` as a tuple. 70 | """ 71 | return x 72 | 73 | 74 | @_dispatch 75 | def as_tuple(x): 76 | return (x,) 77 | 78 | 79 | def _common_shape(*shapes): 80 | common_shape = shapes[0] 81 | for shape in shapes[1:]: 82 | # Add empty dimensions to either shape if it is shorter. 83 | diff = len(common_shape) - len(shape) 84 | shape = (1,) * max(diff, 0) + shape 85 | common_shape = (1,) * max(-diff, 0) + common_shape 86 | 87 | # Resolve the shapes. 88 | new_common_shape = () 89 | for d1, d2 in zip(common_shape, shape): 90 | if d1 == d2: 91 | new_common_shape += (d1,) 92 | elif d1 == 1: 93 | new_common_shape += (d2,) 94 | elif d2 == 1: 95 | new_common_shape += (d1,) 96 | else: 97 | raise RuntimeError( 98 | f"Cannot reconcile running common shape {common_shape} " 99 | f"with {shape}." 100 | ) 101 | common_shape = new_common_shape 102 | return common_shape 103 | 104 | 105 | def _translate_index(index, batch_shape): 106 | # Remove superfluous index dimensions and cast to tuple. 107 | index = tuple(index[-len(batch_shape) :]) 108 | 109 | # Resolve the index. 110 | translated_index = () 111 | for i, s in zip(index, batch_shape): 112 | if i < s: 113 | translated_index += (i,) 114 | elif s == 1: 115 | translated_index += (0,) 116 | else: 117 | raise RuntimeError( 118 | f"Cannot translate index {index} to batch shape {batch_shape}." 119 | ) 120 | return translated_index 121 | 122 | 123 | def batch_computation(f, xs, ranks): 124 | """Apply a function over all batches of arguments. 125 | 126 | Args: 127 | f (function): Function that performs the computation. 128 | xs (tuple): Matrices or batches of matrices. 129 | ranks (tuple): Ranks of the arguments. 130 | 131 | Returns: 132 | tensor: Result in batched form. 133 | """ 134 | # Reshape arguments for batched computation. 135 | batch_shapes = [B.shape(x)[:-rank] for x, rank in zip(xs, ranks)] 136 | 137 | # Find the common shape. 138 | batch_shape = _common_shape(*batch_shapes) 139 | # Force evaluation of the element of the shape: if the shapes are lazy or when 140 | # a function is evaluated abstractly, the dimensions of the shape may still be 141 | # wrapped. 142 | indices = np.indices(tuple(int(x) for x in batch_shape)) 143 | 144 | # Handle the edge case that there is no batching. 145 | if len(indices) == 0: 146 | indices = [()] 147 | else: 148 | # Put the index dimension last. 149 | perm = tuple(list(range(1, len(batch_shape) + 1))) + (0,) 150 | indices = indices.transpose(perm) 151 | # Turn into a list of indices. 
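    # E.g., for `batch_shape == (2, 3)` this produces the six index tuples
    # `(0, 0), (0, 1), ..., (1, 2)`, one per batch element.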
152 | indices = indices.reshape(-1, len(batch_shape)) 153 | 154 | # Loop over batches. 155 | batches = [] 156 | for index in indices: 157 | batches.append( 158 | f(*[x[_translate_index(index, s)] for x, s in zip(xs, batch_shapes)]) 159 | ) 160 | 161 | # Construct result, reshape, and return. 162 | res = B.stack(*batches, axis=0) 163 | return B.reshape(res, *(batch_shape + B.shape(res)[1:])) 164 | 165 | 166 | def abstract(promote=None, promote_from=None): 167 | """Create a decorator for an abstract function. 168 | 169 | Args: 170 | promote (int, optional): Number of arguments to promote. Set to `-1` to promote 171 | all arguments, and set to `None` or `0` to promote no arguments. Defaults to 172 | `None`. Cannot be specified in conjunction with `promote_from`. 173 | promote_from (int, optional): Index from which to promote argument. Set to `-1` 174 | or `None` to promote no arguments, and set to `0` to promote all arguments. 175 | Defaults to `None`. Cannot be specified in conjunction with `promote`. 176 | 177 | Returns: 178 | function: Decorator. 179 | """ 180 | if promote is not None and promote_from is not None: 181 | raise ValueError("Specify either `promote` or `promote_from`.") 182 | 183 | # If `promote` isn't given, we can safely give it the value of 184 | # `promote_from`: either `promote_from` is given, which is fine; or 185 | # `promote_from` isn't given, so `promote` remains at `None`. 186 | if promote is None: 187 | promote = promote_from 188 | 189 | def decorator(f): 190 | @wraps(f) 191 | def wrapper(*args, **kw_args): 192 | # Determine splitting index. 193 | if promote is None or promote == 0: 194 | promote_index = 0 195 | elif promote < 0: 196 | promote_index = len(args) + 1 197 | else: 198 | promote_index = promote 199 | 200 | # Record types. 201 | types_before = tuple(type(arg) for arg in args) 202 | 203 | # Promote. 204 | if promote_from is None: 205 | args = plum.promote(*args[:promote_index]) + args[promote_index:] 206 | else: 207 | args = args[:promote_index] + plum.promote(*args[promote_index:]) 208 | 209 | # Enforce a change in types. Otherwise, the call will recurse, which 210 | # means that an implementation is not available. 211 | types_after = tuple(type(arg) for arg in args) 212 | if types_before == types_after: 213 | signature = plum.signature.Signature(*types_after) 214 | raise plum.NotFoundLookupError(f.__name__, signature, []) 215 | 216 | # Retry call. 217 | return getattr(B, f.__name__)(*args, **kw_args) 218 | 219 | return wrapper 220 | 221 | return decorator 222 | 223 | 224 | def compress_batch(x, n): 225 | """Compress batch dimensions. 226 | 227 | Args: 228 | x (tensor): Tensor to compress. 229 | n (int): Number of non-batch dimensions. 230 | 231 | Return: 232 | tensor: Tensor with compressed batch dimensions. 233 | function: Function to uncompress the batch dimensions. 234 | """ 235 | shape = B.shape(x) 236 | 237 | def uncompress(y): 238 | return B.reshape(y, *shape[:-n], *B.shape(y)[1:]) 239 | 240 | return B.reshape(x, -1, *shape[-n:]), uncompress 241 | 242 | 243 | @_dispatch 244 | def broadcast_shapes(*shapes): 245 | """Broadcast shapes. 246 | 247 | Args: 248 | *shapes (shape): Shapes to broadcast. 249 | 250 | Return: 251 | tuple[int]: Broadcasted shape. 
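    Example:
        Under the usual NumPy broadcasting rules, `broadcast_shapes((5, 1), (1, 3))`
        gives `(5, 3)`.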
252 | """ 253 | shapes = [tuple(int(d) for d in shape) for shape in shapes] 254 | return np.broadcast_shapes(*shapes) 255 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=50.0", 4 | "setuptools_scm[toml]>=6.0", 5 | "setuptools_scm_git_archive", 6 | "wheel>=0.33", 7 | "numpy>=1.16", 8 | "cython>=0.29", 9 | ] 10 | 11 | [tool.setuptools_scm] 12 | write_to = "lab/_version.py" -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_level = DEBUG 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements for development, testing, and documentation. 2 | -e . 3 | sphinx 4 | sphinx-rtd-theme 5 | pytest 6 | pytest-cov 7 | pytest-mock 8 | coveralls 9 | fdm 10 | autograd>=1.3 11 | tensorflow>=2 12 | tensorflow-probability[tf] 13 | torch 14 | jax 15 | jaxlib 16 | setuptools_scm[toml] 17 | setuptools_scm_git_archive 18 | isort 19 | black==24.3.0 20 | jupyter-book 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = backends 3 | author = Wessel Bruinsma 4 | author_email = wessel.p.bruinsma@gmail.com 5 | description = A generic interface for linear algebra backends 6 | url = https://github.com/wesselb/lab 7 | license = MIT 8 | license_file = LICENCE.txt 9 | long_description = file: README.md 10 | long_description_content_type = text/markdown 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import numpy as np 5 | from Cython.Build import build_ext 6 | from setuptools import Extension, find_packages, setup 7 | 8 | # Only compile if `LAB_BUILD=1`. 9 | 10 | if os.environ.get("LAB_BUILD", "0") != "1": 11 | ext_modules = [] 12 | 13 | else: 14 | # Include libraries from the OS X Command Line Tools. On OS X Big Sur, these 15 | # libraries are not automatically included anymore. 16 | osx_library_path = "/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/lib" 17 | if os.path.exists(osx_library_path): 18 | if "LIBRARY_PATH" in os.environ and os.environ["LIBRARY_PATH"]: 19 | os.environ["LIBRARY_PATH"] += ":" + osx_library_path 20 | else: 21 | os.environ["LIBRARY_PATH"] = osx_library_path 22 | 23 | # If `xcrun` is available, make sure the includes are added to CPATH. 24 | if subprocess.call("which xcrun", shell=True) == 0: 25 | path = ( 26 | subprocess.check_output("xcrun --show-sdk-path", shell=True) 27 | .strip() 28 | .decode("ascii") 29 | ) 30 | path += "/usr/include" 31 | 32 | # Add to CPATH. 33 | if "CPATH" not in os.environ: 34 | os.environ["CPATH"] = "" 35 | os.environ["CPATH"] += path 36 | 37 | # Default to use gcc as the compiler if `$CC` is not set. 38 | if "CC" not in os.environ or not os.environ["CC"]: 39 | os.environ["CC"] = "gcc" 40 | 41 | # Check whether `gfortran` is available. 
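    # If `gfortran` is not on the PATH, a binary can be given explicitly via the
    # `LAB_GFORTRAN` environment variable, e.g. (hypothetical path)
    # `LAB_GFORTRAN=/usr/local/bin/gfortran-13 LAB_BUILD=1 pip install -e .`.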
42 | if subprocess.call("which gfortran", shell=True) != 0: 43 | if "LAB_GFORTRAN" in os.environ and os.environ["LAB_GFORTRAN"]: 44 | gfortran = os.environ["LAB_GFORTRAN"] 45 | else: 46 | raise RuntimeError( 47 | "`gfortran` cannot be found." 48 | "Please install `gfortran` or specify a binary with `LAB_GFORTRAN`. " 49 | "On OS X, this can be done with `brew install gcc`." 50 | ) 51 | else: 52 | gfortran = "gfortran" 53 | 54 | # Ensure that `$CC` is not symlinked to `clang`, because the default shipped 55 | # one often does not support OpenMP, but `gcc` does. 56 | out = subprocess.check_output("$CC --version", shell=True) 57 | if "clang" in out.decode("ascii"): 58 | # It is. Now try to find a `gcc` to replace it with. 59 | found = False 60 | for i in range(100, 3, -1): 61 | gcci = "gcc-{}".format(i) 62 | if subprocess.call(["which", gcci]) == 0: 63 | # Set both `$CC` and `$CXX` in this case, just to be sure. 64 | os.environ["CC"] = gcci 65 | os.environ["CXX"] = "g++-{}".format(i) 66 | found = True 67 | break 68 | 69 | # Ensure that one was found. 70 | if not found: 71 | raise RuntimeError( 72 | "Your `gcc` runs clang, and no version of `gcc` could be found. " 73 | "Please install `gcc`. " 74 | "On OS X, this can be done with `brew install gcc`." 75 | ) 76 | 77 | # Compile TVPACK if `gfortran` is available. 78 | if gfortran: 79 | if ( 80 | subprocess.call( 81 | f"{gfortran} -fPIC -O2 -c lab/bvn_cdf/tvpack.f -o lab/bvn_cdf/tvpack.o", 82 | shell=True, 83 | ) 84 | != 0 85 | ): 86 | raise RuntimeError("Compilation of TVPACK failed.") 87 | 88 | # Determine which external modules to compile. 89 | ext_modules = [] 90 | 91 | if gfortran: 92 | extra_objects = ["lab/bvn_cdf/tvpack.o"] 93 | extra_link_args = ["-fopenmp"] 94 | 95 | # Allow the libraries for `gfortran` to be explicitly linked. 96 | if "LAB_LIBGFORTRAN" in os.environ and os.environ["LAB_LIBGFORTRAN"]: 97 | extra_objects += [os.environ["LAB_LIBGFORTRAN"]] 98 | else: 99 | extra_link_args += ["-lgfortran"] 100 | 101 | ext_modules.append( 102 | Extension( 103 | "lab.bvn_cdf", 104 | sources=["lab/bvn_cdf/bvn_cdf.pyx"], 105 | include_dirs=[np.get_include()], 106 | extra_compile_args=["-fPIC", "-O2"], 107 | extra_objects=extra_objects, 108 | extra_link_args=extra_link_args, 109 | ) 110 | ) 111 | 112 | requirements = [ 113 | "numpy>=1.16", 114 | "scipy>=1.3", 115 | "plum-dispatch>=2.5.7", 116 | "opt-einsum", 117 | ] 118 | 119 | setup( 120 | packages=find_packages(exclude=["docs"]), 121 | python_requires=">=3.8", 122 | install_requires=requirements, 123 | cmdclass={"build_ext": build_ext}, 124 | ext_modules=ext_modules, 125 | include_package_data=True, 126 | ) 127 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Add package path. 5 | file_dir = os.path.dirname(__file__) 6 | sys.path.insert(0, os.path.abspath(os.path.join(file_dir, ".."))) 7 | 8 | import jax 9 | 10 | # noinspection PyUnresolvedReferences 11 | import lab.autograd 12 | 13 | # noinspection PyUnresolvedReferences 14 | import lab.jax 15 | 16 | # noinspection PyUnresolvedReferences 17 | import lab.tensorflow 18 | 19 | # noinspection PyUnresolvedReferences 20 | import lab.torch 21 | 22 | # We need `float64`s for testing. 
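# JAX computes in 32-bit precision by default; enabling x64 keeps its results
# comparable to the `float64` NumPy, TensorFlow, and PyTorch references.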
23 | jax.config.update("jax_enable_x64", True) 24 | -------------------------------------------------------------------------------- /tests/test_control_flow.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import lab.jax as B 4 | 5 | # noinspection PyUnresolvedReferences 6 | from .util import check_lazy_shapes 7 | 8 | 9 | def test_controlflowcache(check_lazy_shapes): 10 | cache = B.ControlFlowCache() 11 | 12 | assert not cache.populated 13 | assert str(cache) == repr(cache) == "" 14 | 15 | # Test populating the cache. 16 | with cache: 17 | B.ones(5) 18 | assert str(cache) == repr(cache) == "" 19 | assert cache.populated 20 | 21 | # Test that you can only get an outcome when using a cache. 22 | with pytest.raises(RuntimeError): 23 | B.control_flow.get_outcome("test") 24 | 25 | 26 | def test_cache_cond(check_lazy_shapes): 27 | outcome = {} 28 | 29 | def f_true(x, y): 30 | outcome[0] = True 31 | return x + y 32 | 33 | def f_false(x, y): 34 | outcome[0] = False 35 | return 2 * (x + y) 36 | 37 | def f(x): 38 | return B.cond(x > 0, f_true, f_false, x, x) 39 | 40 | cache_true = B.ControlFlowCache() 41 | cache_false = B.ControlFlowCache() 42 | 43 | # Populate caches: 44 | 45 | with cache_true: 46 | assert f(1) == 2 47 | assert outcome[0] 48 | assert f(-1) == -4 49 | assert not outcome[0] 50 | 51 | with cache_false: 52 | assert f(-1) == -4 53 | assert not outcome[0] 54 | assert f(1) == 2 55 | assert outcome[0] 56 | 57 | # Use caches: 58 | 59 | with cache_true: 60 | assert f(-1) == -2 61 | assert outcome[0] 62 | assert f(1) == 4 63 | assert not outcome[0] 64 | 65 | with cache_false: 66 | assert f(1) == 4 67 | assert not outcome[0] 68 | assert f(-1) == -2 69 | assert outcome[0] 70 | 71 | 72 | def test_control_flow_outcome_conversion(): 73 | def f(x): 74 | B.control_flow.set_outcome("f/x", x, type=str) 75 | if B.control_flow.use_cache: 76 | return B.control_flow.get_outcome("f/x") 77 | else: 78 | return x 79 | 80 | control_flow_cache = B.ControlFlowCache() 81 | 82 | # Populate cache. The `1` will be converted to a string when it is saved. 83 | with control_flow_cache: 84 | assert f(1) == 1 85 | 86 | # Run with cache. Check that the conversion to a string indeed happened. 
87 | with control_flow_cache: 88 | assert f(1) == "1" 89 | -------------------------------------------------------------------------------- /tests/test_custom.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pytest 5 | import tensorflow as tf 6 | import torch 7 | from autograd import grad 8 | from fdm import check_sensitivity, gradient 9 | from plum import isinstance 10 | 11 | import lab as B 12 | from lab.custom import ( 13 | bvn_cdf, 14 | expm, 15 | logm, 16 | s_bvn_cdf, 17 | s_expm, 18 | s_logm, 19 | s_toeplitz_solve, 20 | toeplitz_solve, 21 | ) 22 | from lab.tensorflow.custom import as_tf 23 | from lab.torch.custom import as_torch 24 | 25 | # noinspection PyUnresolvedReferences 26 | from .util import PSD, approx, check_function, check_lazy_shapes 27 | 28 | 29 | def test_as_tf(check_lazy_shapes): 30 | assert isinstance(as_tf(B.randn()), B.TFNumeric) 31 | assert isinstance(as_tf((B.randn(),))[0], B.TFNumeric) 32 | 33 | 34 | def test_as_torch(check_lazy_shapes): 35 | assert isinstance(as_torch(B.randn()), B.TorchNumeric) 36 | assert isinstance(as_torch((B.randn(),))[0], B.TorchNumeric) 37 | 38 | 39 | def check_grad(f, args, kw_args=None, rtol=1e-8): 40 | """Check the gradients of a function. 41 | 42 | Args: 43 | f (function): Function to check gradients of. 44 | args (tuple): Arguments to check `f` at. 45 | kw_args (tuple, optional): Keyword arguments to check `f` at. Defaults 46 | to no keyword arguments. 47 | rtol (float, optional): Relative tolerance. Defaults to `1e-8`. 48 | """ 49 | # Default to no keyword arguments. 50 | if kw_args is None: 51 | kw_args = {} 52 | 53 | # Get the associated function in LAB. 54 | lab_f = getattr(B, f.__name__) 55 | 56 | def create_f_i(i, args_): 57 | # Create a function that only varies the `i`th argument. 58 | def f_i(x): 59 | return B.mean(lab_f(*(args_[:i] + (x,) + args_[i + 1 :]), **kw_args)) 60 | 61 | return f_i 62 | 63 | # Walk through the arguments. 64 | for i in range(len(args)): 65 | # Numerically compute gradient. 66 | f_i = create_f_i(i, args) 67 | numerical_grad = gradient(f_i)(args[i]) 68 | 69 | # Check AutoGrad gradient. 70 | autograd_grad = grad(f_i)(args[i]) 71 | approx(numerical_grad, autograd_grad, rtol=rtol) 72 | 73 | # Check TensorFlow gradient. 74 | tf_args = tuple([as_tf(arg) for arg in args]) 75 | f_i = tf.function(create_f_i(i, tf_args), autograph=False) 76 | with tf.GradientTape() as t: 77 | t.watch(tf_args[i]) 78 | tf_grad = t.gradient(f_i(tf_args[i]), tf_args[i]).numpy() 79 | approx(numerical_grad, tf_grad, rtol=rtol) 80 | 81 | # Check PyTorch gradient. 82 | torch_args = tuple([as_torch(arg, grad=False) for arg in args]) 83 | f_i = torch.jit.trace(create_f_i(i, torch_args), torch_args[i]) 84 | arg = torch_args[i].requires_grad_(True) 85 | f_i(arg).backward() 86 | approx(numerical_grad, arg.grad, rtol=rtol) 87 | 88 | # Check JAX gradient. 
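        # `jax.grad` differentiates the scalar-valued `f_i` with respect to its
        # argument, and wrapping it in `jax.jit` also exercises the custom
        # primitive under tracing.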
89 | jax_args = tuple([jnp.asarray(arg) for arg in args]) 90 | f_i = create_f_i(i, jax_args) 91 | jax_grad = jax.jit(jax.grad(f_i))(jax_args[i]) 92 | approx(numerical_grad, jax_grad, rtol=rtol) 93 | 94 | 95 | def test_toeplitz_solve(check_lazy_shapes): 96 | check_sensitivity( 97 | toeplitz_solve, s_toeplitz_solve, (B.randn(3), B.randn(2), B.randn(3)) 98 | ) 99 | check_sensitivity( 100 | toeplitz_solve, s_toeplitz_solve, (B.randn(3), B.randn(2), B.randn(3, 4)) 101 | ) 102 | check_grad(toeplitz_solve, (B.randn(3), B.randn(2), B.randn(3))) 103 | check_grad(toeplitz_solve, (B.randn(3), B.randn(2), B.randn(3, 4))) 104 | 105 | 106 | def test_bvn_cdf(check_lazy_shapes): 107 | check_sensitivity(bvn_cdf, s_bvn_cdf, (B.rand(3), B.rand(3), B.rand(3))) 108 | check_grad(bvn_cdf, (B.rand(3), B.rand(3), B.rand(3))) 109 | 110 | # Check that function runs on both `float32`s and `float64`s. 111 | a, b, c = B.rand(3), B.rand(3), B.rand(3) 112 | approx( 113 | B.bvn_cdf(a, b, c), 114 | B.bvn_cdf(B.cast(np.float32, a), B.cast(np.float32, b), B.cast(np.float32, c)), 115 | ) 116 | 117 | # Check that, in JAX, the function check the shape of the inputs. 118 | with pytest.raises(ValueError): 119 | B.bvn_cdf( 120 | B.rand(jnp.float32, 2), B.rand(jnp.float32, 3), B.rand(jnp.float32, 3) 121 | ) 122 | with pytest.raises(ValueError): 123 | B.bvn_cdf( 124 | B.rand(jnp.float32, 3), B.rand(jnp.float32, 2), B.rand(jnp.float32, 3) 125 | ) 126 | with pytest.raises(ValueError): 127 | B.bvn_cdf( 128 | B.rand(jnp.float32, 3), B.rand(jnp.float32, 3), B.rand(jnp.float32, 2) 129 | ) 130 | 131 | 132 | def test_expm(check_lazy_shapes): 133 | check_sensitivity(expm, s_expm, (B.randn(3, 3),)) 134 | check_grad(expm, (B.randn(3, 3),)) 135 | 136 | 137 | def test_logm_forward(check_lazy_shapes): 138 | # This test can be removed once the gradient is implemented and the below test 139 | # passes. 140 | check_function(B.logm, (PSD(3),)) 141 | 142 | 143 | @pytest.mark.xfail 144 | def test_logm(check_lazy_shapes): 145 | mat = B.eye(3) + 0.1 * B.randn(3, 3) 146 | check_sensitivity(logm, s_logm, (mat,)) 147 | check_grad(logm, (mat,)) 148 | -------------------------------------------------------------------------------- /tests/test_random.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import pytest 6 | import tensorflow as tf 7 | import torch 8 | from plum import isinstance 9 | 10 | import lab as B 11 | import lab.autograd 12 | import lab.jax 13 | import lab.tensorflow 14 | import lab.torch 15 | 16 | from .util import PositiveTensor, Tensor, approx, check_lazy_shapes, to_np # noqa 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "dtype, f_plain", 21 | [ 22 | (np.float32, np.random.randn), 23 | (tf.float32, lambda: tf.random.normal(())), 24 | (torch.float32, lambda: torch.randn(())), 25 | (jnp.float32, lambda: 1), 26 | ], 27 | ) 28 | def test_set_seed_set_global_random_state(dtype, f_plain, check_lazy_shapes): 29 | B.set_random_seed(0) 30 | x1 = to_np(B.rand(dtype)) 31 | x2 = to_np(f_plain()) 32 | B.set_random_seed(0) 33 | y1 = to_np(B.rand(dtype)) 34 | y2 = to_np(f_plain()) 35 | assert x1 == y1 36 | assert x2 == y2 37 | 38 | B.set_global_random_state(B.create_random_state(dtype, seed=0)) 39 | x1 = to_np(B.rand(dtype)) 40 | x2 = to_np(f_plain()) 41 | B.set_global_random_state(B.create_random_state(dtype, seed=0)) 42 | y1 = to_np(B.rand(dtype)) 43 | y2 = to_np(f_plain()) 44 | assert x1 == y1 45 | # TODO: Make this work with TF! 
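    # Setting LAB's global state installs a `tf.random.Generator`, which does not
    # appear to reseed the stateful op behind `tf.random.normal`, hence the
    # exclusion below.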
46 | if not isinstance(dtype, B.TFDType): 47 | assert x2 == y2 48 | 49 | 50 | @pytest.mark.parametrize("dtype", [np.float32, tf.float32, torch.float32, jnp.float32]) 51 | def test_create_random_state(dtype): 52 | # Test specification without argument. 53 | B.create_random_state(dtype) 54 | 55 | # Check that it does the right thing. 56 | state = B.create_random_state(dtype, seed=0) 57 | state, x1 = B.rand(state, dtype) 58 | state, x2 = B.rand(state, dtype) 59 | x1, x2 = to_np(x1), to_np(x2) 60 | 61 | state = B.create_random_state(dtype, seed=0) 62 | state, y1 = B.rand(state, dtype) 63 | state, y2 = B.rand(state, dtype) 64 | y1, y2 = to_np(y1), to_np(y2) 65 | 66 | assert x1 != x2 67 | assert x1 == y1 68 | assert x2 == y2 69 | 70 | 71 | @pytest.mark.parametrize( 72 | "f, dtype_transform, just_single_arg", 73 | [ 74 | (B.rand, lambda x: x, False), 75 | (B.randn, lambda x: x, False), 76 | (lambda *args: B.randint(*args, lower=0, upper=10), B.dtype_int, False), 77 | (B.randperm, B.dtype_int, True), 78 | (lambda *args: B.randgamma(*args, alpha=0.5, scale=0.5), lambda x: x, False), 79 | (lambda *args: B.randbeta(*args, alpha=0.5, beta=0.5), lambda x: x, False), 80 | ], 81 | ) 82 | @pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32]) 83 | def test_random_generators(f, t, dtype_transform, just_single_arg, check_lazy_shapes): 84 | # Test without specifying data type. 85 | if not just_single_arg: 86 | assert B.dtype(f()) is dtype_transform(B.default_dtype) 87 | assert B.shape(f()) == () 88 | assert B.dtype(f(2)) is dtype_transform(B.default_dtype) 89 | assert B.shape(f(2)) == (2,) 90 | if not just_single_arg: 91 | assert B.dtype(f(2, 3)) is dtype_transform(B.default_dtype) 92 | assert B.shape(f(2, 3)) == (2, 3) 93 | 94 | # Test with specifying data type. 95 | state = B.create_random_state(t, 0) 96 | 97 | # Test direct specification. 98 | if not just_single_arg: 99 | assert B.dtype(f(t)) is dtype_transform(t) 100 | assert B.shape(f(t)) == () 101 | assert B.dtype(f(t, 2)) is dtype_transform(t) 102 | assert B.shape(f(t, 2)) == (2,) 103 | if not just_single_arg: 104 | assert B.dtype(f(t, 2, 3)) is dtype_transform(t) 105 | assert B.shape(f(t, 2, 3)) == (2, 3) 106 | 107 | # Test state specification. 108 | if not just_single_arg: 109 | assert isinstance(f(state, t)[0], B.RandomState) 110 | assert B.dtype(f(state, t)[1]) is dtype_transform(t) 111 | assert B.shape(f(state, t)[1]) == () 112 | assert isinstance(f(state, t, 2)[0], B.RandomState) 113 | assert B.dtype(f(state, t, 2)[1]) is dtype_transform(t) 114 | assert B.shape(f(state, t, 2)[1]) == (2,) 115 | if not just_single_arg: 116 | assert isinstance(f(state, t, 2, 3)[0], B.RandomState) 117 | assert B.dtype(f(state, t, 2, 3)[1]) is dtype_transform(t) 118 | assert B.shape(f(state, t, 2, 3)[1]) == (2, 3) 119 | 120 | if not just_single_arg: 121 | # Test reference specification. 122 | assert B.dtype(f(f(t))) is dtype_transform(t) 123 | assert B.shape(f(f())) == () 124 | assert B.dtype(f(f(t, 2))) is dtype_transform(t) 125 | assert B.shape(f(f(t, 2))) == (2,) 126 | assert B.dtype(f(f(t, 2, 3))) is dtype_transform(t) 127 | assert B.shape(f(f(t, 2, 3))) == (2, 3) 128 | 129 | # Test state and reference specification. 
130 | assert isinstance(f(state, f(t))[0], B.RandomState) 131 | assert B.dtype(f(state, f(t))[1]) is dtype_transform(t) 132 | assert B.shape(f(state, f(t))[1]) == () 133 | assert isinstance(f(state, f(t, 2))[0], B.RandomState) 134 | assert B.dtype(f(state, f(t, 2))[1]) is dtype_transform(t) 135 | assert B.shape(f(state, f(t, 2))[1]) == (2,) 136 | assert isinstance(f(state, f(t, 2, 3))[0], B.RandomState) 137 | assert B.dtype(f(state, f(t, 2, 3))[1]) is dtype_transform(t) 138 | assert B.shape(f(state, f(t, 2, 3))[1]) == (2, 3) 139 | 140 | 141 | @pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32]) 142 | def test_randcat_correctness(t, check_lazy_shapes): 143 | assert int(B.randcat(B.cast(t, np.array([1.0, 0.0, 0.0])))) == 0 144 | assert int(B.randcat(B.cast(t, np.array([0.0, 1.0, 0.0])))) == 1 145 | assert int(B.randcat(B.cast(t, np.array([0.0, 0.0, 1.0])))) == 2 146 | 147 | 148 | @pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32]) 149 | def test_randint_bounds(t, check_lazy_shapes): 150 | assert B.randint(t, lower=10, upper=11) == 10 151 | 152 | 153 | @pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32]) 154 | def test_randgamma_parameters(t, check_lazy_shapes): 155 | approx(B.randgamma(t, alpha=1, scale=0), 0, atol=1e-6) 156 | 157 | 158 | @pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32]) 159 | def test_randgamma_broadcasting(t, check_lazy_shapes): 160 | assert B.shape(B.randgamma(t, alpha=1, scale=0)) == () 161 | assert B.shape(B.randgamma(t, alpha=B.rand(5), scale=0)) == (5,) 162 | assert B.shape(B.randgamma(t, alpha=B.rand(5), scale=B.rand(5))) == (5,) 163 | assert B.shape(B.randgamma(t, alpha=1, scale=B.rand(5))) == (5,) 164 | assert B.shape(B.randgamma(t, 3, alpha=B.rand(5), scale=0)) == (3, 5) 165 | assert B.shape(B.randgamma(t, 3, alpha=B.rand(5), scale=B.rand(5))) == (3, 5) 166 | assert B.shape(B.randgamma(t, 3, alpha=1, scale=B.rand(5))) == (3, 5) 167 | 168 | 169 | @pytest.mark.parametrize("t", [np.float32, tf.float32, torch.float32, jnp.float32]) 170 | def test_randbeta_parameters(t, check_lazy_shapes): 171 | approx(B.randbeta(t, alpha=1e-6, beta=1), 0, atol=1e-6) 172 | approx(B.randbeta(t, alpha=1, beta=1e-6), 1, atol=1e-6) 173 | 174 | 175 | def test_torch_global_random_state(mocker, monkeypatch): 176 | # Check CPU specifications. 177 | B.ActiveDevice.active_name = None 178 | assert B.global_random_state(torch.float32) is torch.random.default_generator 179 | B.ActiveDevice.active_name = "cpu" 180 | assert B.global_random_state(torch.float32) is torch.random.default_generator 181 | 182 | # Test that `cuda.seed` is called to initialise the default generators. 183 | torch_cuda_init = mocker.patch("torch.cuda.init") 184 | B.ActiveDevice.active_name = "cuda" 185 | # The call is allowed to fail, because `torch.cuda.seed` is mocked, so it won't 186 | # actually populate `torch.cuda.default_generators`. 187 | with pytest.raises(IndexError): 188 | B.global_random_state(torch.float32) 189 | assert torch_cuda_init.called_once() 190 | 191 | # Now set some fake default generators. 192 | monkeypatch.setattr("torch.cuda.default_generators", (33, 34)) 193 | monkeypatch.setattr("torch.mps._get_default_mps_generator", lambda: 35) 194 | 195 | # Check GPU specifications. 
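    # The fake generators patched in above (33, 34, and 35) let us verify the
    # device-to-generator mapping without requiring actual GPU or MPS hardware.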
196 | B.ActiveDevice.active_name = "cuda" 197 | assert B.global_random_state(torch.float32) == 33 198 | B.ActiveDevice.active_name = "gpu" 199 | assert B.global_random_state(torch.float32) == 33 200 | B.ActiveDevice.active_name = "gpu:0" 201 | assert B.global_random_state(torch.float32) == 33 202 | B.ActiveDevice.active_name = "gpu:1" 203 | assert B.global_random_state(torch.float32) == 34 204 | with pytest.raises(RuntimeError): 205 | B.ActiveDevice.active_name = "weird-device" 206 | B.global_random_state(torch.float32) 207 | 208 | # Check MPS specification. 209 | B.ActiveDevice.active_name = "mps" 210 | assert B.global_random_state(torch.float32) == 35 211 | B.ActiveDevice.active_name = "mps:0" 212 | assert B.global_random_state(torch.float32) == 35 213 | with pytest.raises( 214 | ValueError, 215 | match="(?i)cannot specify a device number for PyTorch MPS", 216 | ): 217 | B.ActiveDevice.active_name = "mps:1" 218 | B.global_random_state(torch.float32) 219 | 220 | # Reset back to defaults. 221 | B.ActiveDevice.active_name = None 222 | 223 | 224 | @pytest.mark.parametrize("f", [B.rand, B.randn]) 225 | def test_conversion_warnings(f, check_lazy_shapes): 226 | with warnings.catch_warnings(record=True) as w: 227 | warnings.simplefilter("always") 228 | 229 | # Trigger the warning! 230 | f(int, 5) 231 | 232 | assert len(w) == 1 233 | 234 | 235 | _test_randcat_ps = [PositiveTensor(2).forms(), PositiveTensor(3, 2).forms()] 236 | 237 | 238 | @pytest.mark.parametrize("p", sum(_test_randcat_ps, [])) 239 | def test_randcat(p, check_lazy_shapes): 240 | state = B.create_random_state(B.dtype(p)) 241 | 242 | # Determine the shape of a single sample. 243 | if p is not None: 244 | sample_shape = B.shape(p)[:-1] 245 | else: 246 | sample_shape = () 247 | 248 | # Check shape. 249 | assert B.shape(B.randcat(p)) == sample_shape 250 | assert B.shape(B.randcat(p, 5)) == (5,) + sample_shape 251 | assert B.shape(B.randcat(p, 5, 5)) == (5, 5) + sample_shape 252 | 253 | assert isinstance(B.randcat(state, p)[0], B.RandomState) 254 | assert B.shape(B.randcat(state, p)[1]) == sample_shape 255 | assert B.shape(B.randcat(state, p, 5)[1]) == (5,) + sample_shape 256 | assert B.shape(B.randcat(state, p, 5, 5)[1]) == (5, 5) + sample_shape 257 | 258 | # Check correctness. 259 | dtype = B.dtype(p) 260 | choices = set(to_np(B.randcat(B.ones(dtype, 5), 1000))) 261 | assert choices == set(to_np(B.range(dtype, 5))) 262 | 263 | 264 | def _test_choice_with_p(forms): 265 | pairs = [(form, None) for form in forms] 266 | for alternate in _test_randcat_ps: 267 | pairs += list(zip(forms, alternate)) 268 | return pairs 269 | 270 | 271 | @pytest.mark.parametrize( 272 | "x,p", 273 | _test_choice_with_p(Tensor(2).forms()) 274 | + _test_choice_with_p(Tensor(2, 3).forms()) 275 | + _test_choice_with_p(Tensor(2, 3, 4).forms()), 276 | ) 277 | def test_choice(x, p, check_lazy_shapes): 278 | state = B.create_random_state(B.dtype(x)) 279 | 280 | # Determine the shape of a single sample. 281 | sample_shape = B.shape(x)[1:] 282 | if p is not None: 283 | sample_shape = B.shape(p)[:-1] + sample_shape 284 | 285 | # Make `p` a dictionary so that we can optionally give it. 286 | p = {"p": p} 287 | 288 | # Check shape. 
289 | assert B.shape(B.choice(x, **p)) == sample_shape 290 | assert B.shape(B.choice(x, 5, **p)) == (5,) + sample_shape 291 | assert B.shape(B.choice(x, 5, 5, **p)) == (5, 5) + sample_shape 292 | 293 | assert isinstance(B.choice(state, x, **p)[0], B.RandomState) 294 | assert B.shape(B.choice(state, x, **p)[1]) == sample_shape 295 | assert B.shape(B.choice(state, x, 5, **p)[1]) == (5,) + sample_shape 296 | assert B.shape(B.choice(state, x, 5, 5, **p)[1]) == (5, 5) + sample_shape 297 | 298 | # Check correctness. 299 | dtype = B.dtype(x) 300 | choices = set(to_np(B.choice(B.range(dtype, 5), 1000))) 301 | assert choices == set(to_np(B.range(dtype, 5))) 302 | -------------------------------------------------------------------------------- /tests/test_shape.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import lab as B 4 | from lab.shape import Dimension, Shape 5 | 6 | 7 | def test_shape(): 8 | shape = Shape(5, 2, 3) 9 | 10 | # Test indexing. 11 | assert shape[0] == 5 12 | assert shape[1] == 2 13 | assert shape[2] == 3 14 | assert isinstance(shape[0:1], Shape) 15 | assert shape[0:2] == Shape(5, 2) 16 | 17 | # Test comparisons. 18 | assert shape == Shape(5, 2, 3) 19 | assert shape != Shape(5, 2, 4) 20 | 21 | # Test concatenation with another shape. 22 | shape2 = Shape(7, 8, 9) 23 | assert shape + shape2 == Shape(5, 2, 3, 7, 8, 9) 24 | assert shape.__radd__(shape2) == Shape(7, 8, 9, 5, 2, 3) 25 | assert isinstance((shape + shape2).dims[0], int) 26 | assert isinstance((shape.__radd__(shape2)).dims[0], int) 27 | 28 | # Test concatenation with a tuple. 29 | assert shape + (7, 8, 9) == Shape(5, 2, 3, 7, 8, 9) 30 | assert (7, 8, 9) + shape == Shape(7, 8, 9, 5, 2, 3) 31 | assert isinstance((shape + (7, 8, 9)).dims[0], int) 32 | assert isinstance(((7, 8, 9) + shape).dims[0], int) 33 | 34 | # Test conversion of doubly wrapped indices. 35 | assert isinstance(Shape(Dimension(1)).dims[0], int) 36 | 37 | # Test other operations. 38 | assert reversed(shape) == Shape(3, 2, 5) 39 | assert len(shape) == 3 40 | assert tuple(shape) == (Dimension(5), Dimension(2), Dimension(3)) 41 | 42 | # Test representation. 43 | assert str(Shape()) == "()" 44 | assert repr(Shape()) == "Shape()" 45 | assert str(Shape(1)) == "(1,)" 46 | assert repr(Shape(1)) == "Shape(1)" 47 | assert str(Shape(1, 2)) == "(1, 2)" 48 | assert repr(Shape(1, 2)) == "Shape(1, 2)" 49 | 50 | # Test hashing. 51 | assert hash(Shape(1, 2)) == hash((1, 2)) 52 | 53 | # Test conversion to NumPy. 54 | assert isinstance(B.to_numpy(Shape(1, 2)), tuple) 55 | assert B.to_numpy(Shape(1, 2)) == (1, 2) 56 | 57 | 58 | def test_dimension(): 59 | d = Dimension(5) 60 | 61 | assert int(d) is 5 62 | with pytest.raises(TypeError) as e: 63 | len(d) 64 | assert "object of type 'int' has no len()" in str(e.value) 65 | with pytest.raises(TypeError) as e: 66 | iter(d) 67 | assert "'int' object is not iterable" in str(e.value) 68 | 69 | # Test comparisons. 70 | assert d == 5 71 | assert d >= 5 72 | assert d > 4 73 | assert d <= 5 74 | assert d < 6 75 | 76 | # Test that the dimension automatically unwraps. 77 | assert d + 1 is 6 78 | assert 1 + d is 6 79 | assert d - 1 is 4 80 | assert 1 - d is -4 81 | assert d * 1 is 5 82 | assert 1 * d is 5 83 | assert isinstance(d / 5, float) 84 | assert d / 5 == 1 85 | assert 5 / d == 1 86 | assert d // 2 == 2 87 | assert 11 // d == 2 88 | assert -d is -5 89 | assert d**2 is 25 90 | 91 | # Test representation. 92 | assert repr(d) == str(d) == "5" 93 | 94 | # Test hashing. 
95 | assert hash(d) == hash(5) 96 | -------------------------------------------------------------------------------- /tests/test_shaping.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pytest 4 | import tensorflow as tf 5 | import torch 6 | from plum import NotFoundLookupError 7 | 8 | import lab as B 9 | from lab.shape import Shape 10 | 11 | # noinspection PyUnresolvedReferences 12 | from .util import ( 13 | List, 14 | Matrix, 15 | Tensor, 16 | Tuple, 17 | Value, 18 | approx, 19 | check_function, 20 | check_lazy_shapes, 21 | ) 22 | 23 | 24 | @pytest.mark.parametrize("f", [B.shape, B.rank, B.length, B.size]) 25 | def test_sizing(f, check_lazy_shapes): 26 | check_function(f, (Tensor(),), {}, assert_dtype=False) 27 | check_function( 28 | f, 29 | ( 30 | Tensor( 31 | 3, 32 | ), 33 | ), 34 | {}, 35 | assert_dtype=False, 36 | ) 37 | check_function(f, (Tensor(3, 4),), {}, assert_dtype=False) 38 | check_function(f, (Tensor(3, 4, 5),), {}, assert_dtype=False) 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "x,shape", 43 | [ 44 | (1, ()), 45 | ([], (0,)), 46 | ([5], (1,)), 47 | ([[5], [6]], (2, 1)), 48 | ((), (0,)), 49 | ((5,), (1,)), 50 | (((5,), (2,)), (2, 1)), 51 | ], 52 | ) 53 | def test_shape(x, shape, check_lazy_shapes): 54 | assert B.shape(x) == shape 55 | 56 | 57 | def test_subshape(check_lazy_shapes): 58 | assert B.shape(B.zeros(2), 0) == 2 59 | assert B.shape(B.zeros(2, 3, 4), 1) == 3 60 | assert B.shape(B.zeros(2, 3, 4), 0, 2) == (2, 4) 61 | assert B.shape(B.zeros(2, 3, 4), 0, 1, 2) == (2, 3, 4) 62 | 63 | # Check for possible infinite recursion. 64 | with pytest.raises(NotFoundLookupError): 65 | B.shape(None, 1) 66 | 67 | 68 | def test_lazy_shape(): 69 | a = B.randn(2, 2) 70 | 71 | # By default, it should be off. 72 | assert isinstance(B.shape(a), tuple) 73 | 74 | # Turn on. 75 | with B.lazy_shapes(): 76 | assert isinstance(B.shape(a), Shape) 77 | 78 | # Force lazy shapes to be off again. 79 | B.lazy_shapes.enabled = False 80 | assert isinstance(B.shape(a), tuple) 81 | 82 | # Turn on again. 83 | with B.lazy_shapes(): 84 | assert isinstance(B.shape(a), Shape) 85 | 86 | # Should remain off. 87 | assert isinstance(B.shape(a), tuple) 88 | 89 | 90 | def test_is_scalar(check_lazy_shapes): 91 | assert B.is_scalar(1.0) 92 | assert not B.is_scalar(np.array([1.0])) 93 | 94 | 95 | def test_expand_dims(check_lazy_shapes): 96 | check_function(B.expand_dims, (Tensor(3, 4, 5),), {"axis": Value(0, 1)}) 97 | 98 | # Test keyword `times`. 99 | assert B.shape(B.expand_dims(B.ones(2), axis=-1, times=1)) == (2, 1) 100 | assert B.shape(B.expand_dims(B.ones(2), axis=-1, times=2)) == (2, 1, 1) 101 | assert B.shape(B.expand_dims(B.ones(2), axis=-1, times=3)) == (2, 1, 1, 1) 102 | assert B.shape(B.expand_dims(B.ones(2), axis=0, times=1)) == (1, 2) 103 | assert B.shape(B.expand_dims(B.ones(2), axis=0, times=2)) == (1, 1, 2) 104 | assert B.shape(B.expand_dims(B.ones(2), axis=0, times=3)) == (1, 1, 1, 2) 105 | 106 | # Test keyword `ignore_scalar`. 
107 | assert B.expand_dims(1, axis=-1, ignore_scalar=False) is not 1 108 | assert B.expand_dims(1, axis=-1, ignore_scalar=True) is 1 109 | 110 | 111 | def test_squeeze(check_lazy_shapes): 112 | check_function(B.squeeze, (Tensor(3, 4, 5),)) 113 | check_function(B.squeeze, (Tensor(1, 4, 5),)) 114 | check_function(B.squeeze, (Tensor(1, 4, 5),), {"axis": Value(None, 0)}) 115 | check_function(B.squeeze, (Tensor(3, 1, 5),)) 116 | check_function(B.squeeze, (Tensor(3, 1, 5),), {"axis": Value(None, 1)}) 117 | check_function(B.squeeze, (Tensor(1, 4, 1),)) 118 | check_function(B.squeeze, (Tensor(1, 4, 1),), {"axis": Value(None, 0, 2)}) 119 | 120 | # Test squeezing lists and tuples 121 | assert B.squeeze((1,)) == 1 122 | assert B.squeeze((1, 2)) == (1, 2) 123 | assert B.squeeze([1]) == 1 124 | assert B.squeeze([1, 2]) == [1, 2] 125 | 126 | 127 | @pytest.mark.parametrize( 128 | "rank, shape, expected_shape", 129 | [ 130 | # `rank=2`, the default: 131 | (None, (), (1, 1)), 132 | (None, (2,), (2, 1)), 133 | (None, (2, 3), (2, 3)), 134 | (None, (2, 3, 4), (2, 3, 4)), 135 | # `rank=1`: 136 | (1, (), (1,)), 137 | (1, (2,), (2,)), 138 | (1, (2, 3), (2, 3)), 139 | ], 140 | ) 141 | def test_uprank(rank, shape, expected_shape, check_lazy_shapes): 142 | kw_args = {} 143 | if rank is not None: 144 | kw_args["rank"] = rank 145 | approx(B.uprank(B.ones(*shape), **kw_args), B.ones(*expected_shape)) 146 | 147 | 148 | @pytest.mark.parametrize( 149 | "rank, preserve, shape, expected_shape", 150 | [ 151 | # `rank = 2`, the default: 152 | (None, None, (), ()), 153 | (None, None, (2,), (2,)), 154 | (None, None, (2, 1), (2, 1)), 155 | (None, None, (2, 3, 4), (2, 3, 4)), 156 | (None, None, (2, 3, 1), (2, 3)), 157 | (None, None, (2, 1, 3), (2, 3)), 158 | (None, None, (1, 2, 3), (2, 3)), 159 | (None, False, (2, 3, 1), (2, 3)), 160 | (None, False, (2, 1, 3), (2, 3)), 161 | (None, False, (1, 2, 3), (2, 3)), 162 | (None, True, (2, 3, 1), (2, 3)), 163 | (None, True, (2, 1, 3), (2, 1, 3)), 164 | (None, True, (1, 2, 3), (1, 2, 3)), 165 | # `rank = 1`: 166 | (1, None, (), ()), 167 | (1, None, (2,), (2,)), 168 | (1, None, (2, 2), (2, 2)), 169 | (1, None, (2, 1), (2,)), 170 | (1, None, (1, 2), (2,)), 171 | (1, False, (2, 1), (2,)), 172 | (1, False, (1, 2), (2,)), 173 | (1, True, (2, 1), (2,)), 174 | (1, True, (1, 2), (1, 2)), 175 | ], 176 | ) 177 | def test_downrank(rank, preserve, shape, expected_shape, check_lazy_shapes): 178 | kw_args = {} 179 | if rank is not None: 180 | kw_args["rank"] = rank 181 | if preserve is not None: 182 | kw_args["preserve"] = preserve 183 | approx( 184 | B.downrank(B.ones(*shape), **kw_args), 185 | B.ones(*expected_shape), 186 | ) 187 | 188 | 189 | @pytest.mark.parametrize("source_shape", [(1, 1, 1), (1, 1, 4), (1, 3, 4), (2, 3, 4)]) 190 | def test_broadcast_to(check_lazy_shapes, source_shape): 191 | def f(x): 192 | return B.broadcast_to(x, 2, 3, 4) 193 | 194 | check_function(f, (Tensor(*source_shape),)) 195 | 196 | 197 | def test_diag(check_lazy_shapes): 198 | check_function(B.diag, (Tensor(3),)) 199 | check_function(B.diag, (Tensor(3, 3),)) 200 | # Test rank check for TensorFlow. 
201 | with pytest.raises(ValueError): 202 | B.diag(Tensor().tf()) 203 | 204 | 205 | def test_diag_extract(check_lazy_shapes): 206 | check_function(B.diag_extract, (Tensor(3, 3),)) 207 | check_function(B.diag_extract, (Tensor(2, 3, 3),)) 208 | 209 | 210 | def test_diag_construct(check_lazy_shapes): 211 | check_function(B.diag_construct, (Tensor(3),)) 212 | check_function(B.diag_construct, (Tensor(2, 3),)) 213 | # Test rank check for fallback. 214 | with pytest.raises(ValueError): 215 | B.diag_construct(Tensor().np()) 216 | 217 | 218 | def test_flatten(check_lazy_shapes): 219 | check_function(B.flatten, (Tensor(3),)) 220 | check_function(B.flatten, (Tensor(3, 4),)) 221 | 222 | 223 | @pytest.mark.parametrize("offset", [-2, -1, 0, 1, 2]) 224 | @pytest.mark.parametrize("batch_shape", [(), (5,)]) 225 | def test_vec_to_tril(offset, batch_shape, check_lazy_shapes): 226 | n = B.length(B.tril_to_vec(B.ones(7, 7), offset=offset)) 227 | check_function(B.vec_to_tril, (Tensor(*batch_shape, n),), {"offset": Value(offset)}) 228 | 229 | 230 | @pytest.mark.parametrize("batch_shape", [(), (5,)]) 231 | def test_tril_to_vec(batch_shape, check_lazy_shapes): 232 | check_function( 233 | B.tril_to_vec, (Tensor(*batch_shape, 6, 6),), {"offset": Value(-1, 0, 1)} 234 | ) 235 | 236 | 237 | @pytest.mark.parametrize("offset", [-2, -1, 0, 1, 2]) 238 | @pytest.mark.parametrize("batch_shape", [(), (5,)]) 239 | def test_vec_to_tril_and_back_correctness(offset, batch_shape, check_lazy_shapes): 240 | n = B.length(B.tril_to_vec(B.ones(7, 7), offset=offset)) 241 | for vec in Tensor(*batch_shape, n).forms(): 242 | mat = B.vec_to_tril(vec, offset=offset) 243 | approx(B.tril_to_vec(mat, offset=offset), vec) 244 | 245 | 246 | def test_vec_to_tril_and_back_exceptions(check_lazy_shapes): 247 | # Check rank checks. 248 | for x in Tensor().forms(): 249 | with pytest.raises(ValueError): 250 | B.vec_to_tril(x) 251 | with pytest.raises(ValueError): 252 | B.tril_to_vec(x) 253 | for x in Tensor(3).forms(): 254 | with pytest.raises(ValueError): 255 | B.tril_to_vec(x) 256 | 257 | # Check square checks. 258 | for x in Tensor(3, 4).forms(): 259 | with pytest.raises(ValueError): 260 | B.tril_to_vec(x) 261 | for x in Tensor(3, 4, 5).forms(): 262 | with pytest.raises(ValueError): 263 | B.tril_to_vec(x) 264 | 265 | 266 | def test_stack(check_lazy_shapes): 267 | check_function(B.stack, (Matrix(3), Matrix(3), Matrix(3)), {"axis": Value(0, 1)}) 268 | 269 | 270 | def test_unstack(check_lazy_shapes): 271 | check_function( 272 | B.unstack, 273 | (Tensor(3, 4, 5),), 274 | {"axis": Value(0, 1, 2), "squeeze": Value(True, False)}, 275 | ) 276 | 277 | 278 | def test_reshape(check_lazy_shapes): 279 | check_function(B.reshape, (Tensor(3, 4, 5), Value(3), Value(20))) 280 | check_function(B.reshape, (Tensor(3, 4, 5), Value(12), Value(5))) 281 | 282 | 283 | def test_concat(check_lazy_shapes): 284 | check_function(B.concat, (Matrix(3), Matrix(3), Matrix(3)), {"axis": Value(0, 1)}) 285 | 286 | 287 | def test_concat2d(check_lazy_shapes): 288 | check_function(B.concat2d, (List(Matrix(3), Matrix(3)), List(Matrix(3), Matrix(3)))) 289 | 290 | 291 | @pytest.mark.parametrize("r1", [1, 2]) 292 | @pytest.mark.parametrize("r2", [1, 2]) 293 | def test_tile(r1, r2, check_lazy_shapes): 294 | check_function(B.tile, (Tensor(3, 4), Value(r1), Value(r2))) 295 | 296 | 297 | def test_take_consistency(check_lazy_shapes): 298 | # Check consistency between indices and mask. 
299 | check_function( 300 | B.take, 301 | (Matrix(3, 3), Value([0, 1], [True, True, False])), 302 | {"axis": Value(0, 1, -1)}, 303 | ) 304 | 305 | # Test PyTorch separately, because it has a separate implementation for framework 306 | # masks or indices. 307 | for indices_or_mask in [ 308 | torch.tensor([True, True, False], dtype=torch.bool), 309 | torch.tensor([0, 1], dtype=torch.int32), 310 | torch.tensor([0, 1], dtype=torch.int64), 311 | ]: 312 | a = B.randn(torch.float32, 3, 3) 313 | approx(B.take(a, indices_or_mask), a[[0, 1]]) 314 | 315 | 316 | def test_take_consistency_order(check_lazy_shapes): 317 | # Check order of indices. 318 | check_function(B.take, (Matrix(3, 4), Value([2, 1])), {"axis": Value(0, 1, -1)}) 319 | 320 | 321 | def test_take_indices_rank(check_lazy_shapes): 322 | # Check that indices must be rank 1. 323 | for a in Matrix(3, 4).forms(): 324 | with pytest.raises(ValueError): 325 | B.take(a, [[0], [1]]) 326 | 327 | 328 | @pytest.mark.parametrize( 329 | "indices_or_mask", 330 | [[], [0, 2], [True, False, True], (), (0, 2), (True, False, True)], 331 | ) 332 | def test_take_list_tuple(check_lazy_shapes, indices_or_mask): 333 | check_function( 334 | B.take, (Matrix(3, 3, 3), Value(indices_or_mask)), {"axis": Value(0, 1, 2, -1)} 335 | ) 336 | 337 | 338 | def test_take_tf(check_lazy_shapes): 339 | # Check that TensorFlow also takes in tensors. 340 | a = Matrix(3, 4, 5) 341 | ref = Tensor(3) 342 | approx(B.take(a.tf(), ref.tf() > 0), B.take(a.np(), ref.np() > 0)) 343 | approx(B.take(a.tf(), ref.np() > 0), B.take(a.np(), ref.np() > 0)) 344 | approx(B.take(a.tf(), B.range(tf.int64, 2)), B.take(a.np(), B.range(2))) 345 | approx(B.take(a.tf(), B.range(np.int64, 2)), B.take(a.np(), B.range(2))) 346 | 347 | 348 | @pytest.mark.parametrize( 349 | "dtype", 350 | [ 351 | np.float32, 352 | np.float64, 353 | jnp.float32, 354 | jnp.float64, 355 | tf.float32, 356 | tf.float64, 357 | torch.float32, 358 | torch.float64, 359 | ], 360 | ) 361 | def test_take_perm(dtype, check_lazy_shapes): 362 | a = B.range(dtype, 10) 363 | perm = B.randperm(B.dtype_int(dtype), 10) 364 | a2 = B.take(a, perm) 365 | assert B.dtype(perm) == B.dtype_int(dtype) 366 | assert B.shape(a) == B.shape(a2) 367 | assert B.dtype(a) == B.dtype(a2) 368 | 369 | 370 | def test_submatrix(check_lazy_shapes): 371 | a = Matrix(4, 5).np() 372 | approx(B.submatrix(a, [0, 1]), a[[0, 1], :][:, [0, 1]]) 373 | a = Matrix(3, 4, 5).np() 374 | approx(B.submatrix(a, [0, 1]), a[:, [0, 1], :][:, :, [0, 1]]) 375 | -------------------------------------------------------------------------------- /tests/test_types.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pytest 5 | import tensorflow as tf 6 | import torch 7 | from autograd import grad 8 | from plum import isinstance 9 | from plum.promotion import _promotion_rule, convert 10 | 11 | import lab as B 12 | 13 | # noinspection PyUnresolvedReferences 14 | from .util import autograd_box, check_lazy_shapes 15 | 16 | 17 | def test_numeric(check_lazy_shapes): 18 | # Test convenient types. 
19 | assert isinstance(1, B.Int) 20 | assert isinstance(np.int32(1), B.Int) 21 | assert isinstance(np.uint64(1), B.Int) 22 | 23 | assert isinstance(1.0, B.Float) 24 | assert isinstance(np.float32(1), B.Float) 25 | 26 | assert isinstance(1 + 0j, B.Complex) 27 | assert isinstance(np.complex64(1), B.Complex) 28 | 29 | assert isinstance(True, B.Bool) 30 | assert isinstance(np.bool_(True), B.Bool) 31 | 32 | assert isinstance(np.uint(1), B.Number) 33 | assert isinstance(np.float64(1), B.Number) 34 | assert isinstance(np.complex64(1), B.Number) 35 | 36 | # Test NumPy. 37 | assert isinstance(np.array(1), B.NPNumeric) 38 | 39 | # Test TensorFlow. 40 | assert isinstance(tf.constant(1), B.TFNumeric) 41 | assert isinstance(tf.Variable(1), B.TFNumeric) 42 | 43 | # Test Torch. 44 | assert isinstance(torch.tensor(1), B.TorchNumeric) 45 | 46 | # Test JAX. 47 | assert isinstance(jnp.array(1), B.JAXNumeric) 48 | 49 | # Test general numeric type. 50 | assert isinstance(1, B.Numeric) 51 | assert isinstance(np.bool_(1), B.Numeric) 52 | assert isinstance(np.float64(1), B.Numeric) 53 | assert isinstance(np.array(1), B.Numeric) 54 | assert isinstance(tf.constant(1), B.Numeric) 55 | assert isinstance(torch.tensor(1), B.Numeric) 56 | 57 | # Test promotion. 58 | assert _promotion_rule(np.array(1), tf.constant(1)) == B.TFNumeric 59 | assert _promotion_rule(np.array(1), tf.Variable(1)) == B.TFNumeric 60 | assert _promotion_rule(tf.constant(1), tf.Variable(1)) == B.TFNumeric 61 | assert _promotion_rule(np.array(1), torch.tensor(1)) == B.TorchNumeric 62 | assert _promotion_rule(np.array(1), jnp.array(1)) == B.JAXNumeric 63 | with pytest.raises(TypeError): 64 | _promotion_rule(B.TFNumeric, B.TorchNumeric) 65 | 66 | # Test conversion. 67 | assert isinstance(convert(np.array(1), B.TFNumeric), B.TFNumeric) 68 | assert isinstance(convert(np.array(1), B.TorchNumeric), B.TorchNumeric) 69 | assert isinstance(convert(np.array(1), B.JAXNumeric), B.JAXNumeric) 70 | 71 | 72 | def test_autograd_tracing(check_lazy_shapes): 73 | found_objs = [] 74 | 75 | def f(x): 76 | found_objs.append(x) 77 | return B.sum(x) 78 | 79 | # Test that function runs. 80 | f(np.ones(5)) 81 | found_objs[:] = [] # Clear found objects. 82 | 83 | # Catch AutoGrad object. 84 | grad(f)(np.ones(5)) 85 | 86 | # Test that objects are of the right type. 87 | for obj in found_objs: 88 | assert isinstance(obj, B.AGNumeric) 89 | 90 | 91 | def test_jax_tracing(check_lazy_shapes): 92 | found_objs = [] 93 | 94 | def f(x): 95 | found_objs.append(x) 96 | return B.sum(x) 97 | 98 | # Catch JAX object during JIT and during gradient computation. 99 | jax.grad(f)(np.ones(5)) 100 | jax.jit(f)(np.ones(5)) 101 | 102 | # Test that objects are of the right type. 103 | for obj in found_objs: 104 | assert isinstance(obj, B.JAXNumeric) 105 | 106 | 107 | def test_data_type(check_lazy_shapes): 108 | assert isinstance(np.float32, B.NPDType) 109 | assert isinstance(np.float32, B.DType) 110 | assert isinstance(tf.float32, B.TFDType) 111 | assert isinstance(tf.float32, B.DType) 112 | assert isinstance(torch.float32, B.TorchDType) 113 | assert isinstance(torch.float32, B.DType) 114 | assert isinstance(jnp.float32, B.JAXDType) 115 | assert isinstance(jnp.float32, B.DType) 116 | 117 | # Check that the AutoGrad and JAX data types are just the NumPy data type. Then 118 | # there is nothing left to check. 119 | assert B.AGDType == B.NPDType 120 | 121 | # Test conversion between data types. 
122 | assert convert(np.float32, B.TFDType) is tf.float32 123 | assert convert(np.float32, B.TorchDType) is torch.float32 124 | assert convert(np.float32, B.JAXDType) is jnp.float32 125 | assert convert(tf.float32, B.NPDType) is np.float32 126 | assert convert(tf.float32, B.TorchDType) is torch.float32 127 | assert convert(tf.float32, B.JAXDType) is jnp.float32 128 | assert convert(torch.float32, B.NPDType) is np.float32 129 | assert convert(torch.float32, B.TFDType) is tf.float32 130 | assert convert(torch.float32, B.JAXDType) is jnp.float32 131 | assert convert(jnp.float32, B.NPDType) is np.float32 132 | assert convert(jnp.float32, B.TFDType) is tf.float32 133 | assert convert(jnp.float32, B.TorchDType) is torch.float32 134 | 135 | # `torch.bool` has a manual addition, so test it separately. 136 | assert convert(torch.bool, B.NPDType) is bool 137 | 138 | 139 | def test_dtype(check_lazy_shapes): 140 | assert B.dtype(1) is int 141 | assert B.dtype(1.0) is float 142 | assert B.dtype(np.array(1, dtype=np.int32)) is np.int32 143 | assert B.dtype(np.array(1.0, dtype=np.float32)) is np.float32 144 | assert B.dtype(tf.constant(1, dtype=tf.int32)) is tf.int32 145 | assert B.dtype(tf.constant(1.0, dtype=tf.float32)) is tf.float32 146 | assert B.dtype(torch.tensor(1, dtype=torch.int32)) is torch.int32 147 | assert B.dtype(torch.tensor(1.0, dtype=torch.float32)) is torch.float32 148 | assert B.dtype(jnp.array(1, dtype=jnp.int32)) is jnp.int32 149 | assert B.dtype(jnp.array(1.0, dtype=jnp.float32)) is jnp.float32 150 | 151 | # Test tuples, which promote. 152 | assert B.dtype(1, 1) is np.int64 153 | assert B.dtype((1, 1)) is np.int64 154 | assert B.dtype(1, 1.0) is np.float64 155 | assert B.dtype((1, 1.0)) is np.float64 156 | 157 | 158 | def test_issubdtype(check_lazy_shapes): 159 | assert B.issubdtype(np.float32, np.floating) 160 | assert B.issubdtype(tf.float32, np.floating) 161 | assert B.issubdtype(torch.float32, np.floating) 162 | assert B.issubdtype(jnp.float32, np.floating) 163 | assert not B.issubdtype(np.float32, np.integer) 164 | assert not B.issubdtype(tf.float32, np.integer) 165 | assert not B.issubdtype(torch.float32, np.integer) 166 | assert not B.issubdtype(jnp.float32, np.integer) 167 | 168 | 169 | def test_promote_dtypes(check_lazy_shapes): 170 | # Check one-argument case. 171 | assert B.promote_dtypes(int) is int 172 | assert B.promote_dtypes(float) is float 173 | 174 | # Check multi-argument case. 175 | for t_int, t_float in [ 176 | (np.int64, np.float64), 177 | (tf.int64, tf.float64), 178 | (torch.int64, torch.float64), 179 | (jnp.int64, jnp.float64), 180 | ]: 181 | # Also check that the conversion back is right. 
182 | assert B.promote_dtypes(t_int, int) is t_int 183 | assert B.promote_dtypes(t_int, int, int) is t_int 184 | assert B.promote_dtypes(t_int, float) is t_float 185 | assert B.promote_dtypes(t_int, int, float) is t_float 186 | 187 | 188 | def test_dtype_float(check_lazy_shapes): 189 | assert B.dtype_float(np.float32) is np.float32 190 | assert B.dtype_float(np.float32(1)) is np.float32 191 | assert B.dtype_float(np.float64) is np.float64 192 | assert B.dtype_float(np.float64(1)) is np.float64 193 | assert B.dtype_float(int) is np.float64 194 | assert B.dtype_float(1) is np.float64 195 | 196 | 197 | def test_dtype_int(check_lazy_shapes): 198 | assert B.dtype_int(np.float32) is np.int32 199 | assert B.dtype_int(np.float32(1)) is np.int32 200 | assert B.dtype_int(np.float64) is np.int64 201 | assert B.dtype_int(np.float64(1)) is np.int64 202 | assert B.dtype_int(int) is int 203 | assert B.dtype_int(1) is int 204 | # Test conversion back to right framework type. This conversion is thoroughly 205 | # tested for `B.promote_dtypes`. 206 | assert B.dtype_int(tf.float32) is tf.int32 207 | assert B.dtype_int(tf.constant(1.0, dtype=tf.float32)) is tf.int32 208 | assert B.dtype_int(tf.float64) is tf.int64 209 | assert B.dtype_int(tf.constant(1.0, dtype=tf.float64)) is tf.int64 210 | 211 | 212 | @pytest.mark.parametrize( 213 | "t, FWRandomState", 214 | [ 215 | (np.float64, B.NPRandomState), 216 | (tf.float64, B.TFRandomState), 217 | (torch.float64, B.TorchRandomState), 218 | (jnp.float64, B.JAXRandomState), 219 | ], 220 | ) 221 | def test_random_state(t, FWRandomState, check_lazy_shapes): 222 | assert isinstance(B.create_random_state(t), FWRandomState) 223 | 224 | 225 | def test_random_state_jax(check_lazy_shapes): 226 | # Splitting a JAX random state gives a NumPy array. 227 | assert isinstance(np.array(1), B.JAXRandomState) 228 | 229 | 230 | @pytest.mark.parametrize( 231 | "t, FWDevice", 232 | [ 233 | (tf.float64, B.TFDevice), 234 | (torch.float64, B.TorchDevice), 235 | (jnp.float64, B.JAXDevice), 236 | ], 237 | ) 238 | def test_device(t, FWDevice, check_lazy_shapes): 239 | a = B.randn(t, 2, 2) 240 | assert isinstance(B.device(a), FWDevice) 241 | assert isinstance(B.device(a), B.Device) 242 | 243 | # Test conversion to string. 
244 | assert isinstance(convert(B.device(a), str), str) 245 | 246 | 247 | @pytest.mark.parametrize("t", [B.NP, B.Framework]) 248 | def test_framework_np(t, check_lazy_shapes): 249 | assert isinstance(np.array(1), t) 250 | assert isinstance(np.float32, t) 251 | assert isinstance(B.create_random_state(np.float32), t) 252 | 253 | 254 | @pytest.mark.parametrize("t", [B.AG, B.Framework]) 255 | def test_framework_ag(t, check_lazy_shapes): 256 | assert isinstance(autograd_box(np.array(1)), t) 257 | assert isinstance(np.float32, t) 258 | assert isinstance(B.create_random_state(np.float32), t) 259 | 260 | 261 | @pytest.mark.parametrize("t", [B.TF, B.Framework]) 262 | def test_framework_tf(t, check_lazy_shapes): 263 | assert isinstance(tf.constant(1), t) 264 | assert isinstance(tf.float32, t) 265 | assert isinstance(B.create_random_state(tf.float32), t) 266 | assert isinstance(B.device(tf.constant(1)), t) 267 | 268 | 269 | @pytest.mark.parametrize("t", [B.Torch, B.Framework]) 270 | def test_framework_torch(t, check_lazy_shapes): 271 | assert isinstance(torch.tensor(1), t) 272 | assert isinstance(torch.float32, t) 273 | assert isinstance(B.create_random_state(torch.float32), t) 274 | assert isinstance(B.device(torch.tensor(1)), t) 275 | 276 | 277 | @pytest.mark.parametrize("t", [B.JAX, B.Framework]) 278 | def test_framework_jax(t, check_lazy_shapes): 279 | assert isinstance(jnp.asarray(1), t) 280 | assert isinstance(jnp.float32, t) 281 | assert isinstance(B.create_random_state(jnp.float32), t) 282 | assert isinstance(B.device(jnp.asarray(1)), t) 283 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import plum 3 | import pytest 4 | from plum import NotFoundLookupError 5 | 6 | import lab as B 7 | import lab.autograd as B_autograd 8 | import lab.jax as B_jax 9 | import lab.tensorflow as B_tf 10 | import lab.torch as B_torch 11 | from lab.util import ( 12 | _common_shape, 13 | _translate_index, 14 | abstract, 15 | as_tuple, 16 | batch_computation, 17 | resolve_axis, 18 | ) 19 | 20 | # noinspections PyUnresolvedReferences 21 | from .util import approx, check_lazy_shapes 22 | 23 | 24 | def test_resolve_axis(check_lazy_shapes): 25 | a = B.randn(2, 2, 2) 26 | 27 | # `None`s should just pass through. 28 | assert resolve_axis(a, None) is None 29 | 30 | # Test `negative = False`. 31 | with pytest.raises(ValueError): 32 | resolve_axis(a, -4) 33 | assert resolve_axis(a, -3) == 0 34 | assert resolve_axis(a, -2) == 1 35 | assert resolve_axis(a, -1) == 2 36 | assert resolve_axis(a, 0) == 0 37 | assert resolve_axis(a, 1) == 1 38 | assert resolve_axis(a, 2) == 2 39 | with pytest.raises(ValueError): 40 | resolve_axis(a, 3) 41 | 42 | # Test `negative = True`. 
43 | with pytest.raises(ValueError): 44 | resolve_axis(a, -4, negative=True) 45 | assert resolve_axis(a, -3, negative=True) == -3 46 | assert resolve_axis(a, -2, negative=True) == -2 47 | assert resolve_axis(a, -1, negative=True) == -1 48 | assert resolve_axis(a, 0, negative=True) == -3 49 | assert resolve_axis(a, 1, negative=True) == -2 50 | assert resolve_axis(a, 2, negative=True) == -1 51 | with pytest.raises(ValueError): 52 | resolve_axis(a, 3, negative=True) 53 | 54 | 55 | @pytest.mark.parametrize("other", [B_tf, B_torch, B_autograd, B_jax]) 56 | def test_module_mapping(other, check_lazy_shapes): 57 | assert B is other 58 | 59 | 60 | def test_as_tuple(check_lazy_shapes): 61 | assert as_tuple(1) == (1,) 62 | assert as_tuple((1,)) == (1,) 63 | assert as_tuple((1, 2)) == (1, 2) 64 | 65 | 66 | @pytest.mark.parametrize( 67 | "shapes,common_shape", 68 | [ 69 | ([(), ()], ()), 70 | ([(5,), (1,)], (5,)), 71 | ([(2, 5), (1, 5)], (2, 5)), 72 | ([(5,), (1, 5)], (1, 5)), 73 | ([(3, 5), (1,)], (3, 5)), 74 | ], 75 | ) 76 | def test_common_shape(shapes, common_shape, check_lazy_shapes): 77 | assert _common_shape(*shapes) == common_shape 78 | assert _common_shape(*reversed(shapes)) == common_shape 79 | 80 | 81 | @pytest.mark.parametrize("shapes", [[(5,), (6,)], [(5, 2), (5, 3)], [(5, 2), (3,)]]) 82 | def test_common_shape_errors(shapes, check_lazy_shapes): 83 | with pytest.raises(RuntimeError): 84 | _common_shape(*shapes) 85 | with pytest.raises(RuntimeError): 86 | _common_shape(*reversed(shapes)) 87 | 88 | 89 | @pytest.mark.parametrize( 90 | "index,batch_shape,translated_index", 91 | [ 92 | ((5, 2), (3,), (2,)), 93 | ((2, 3, 4), (5, 5), (3, 4)), 94 | ((2, 3, 4), (1, 5), (0, 4)), 95 | ((2, 3, 4), (5, 1), (3, 0)), 96 | ], 97 | ) 98 | def test_translate_index(index, batch_shape, translated_index, check_lazy_shapes): 99 | assert _translate_index(index, batch_shape) == translated_index 100 | 101 | 102 | @pytest.mark.parametrize("index,batch_shape", [((5, 3), (3,)), ((2, 3, 4), (4, 4))]) 103 | def test_translate_index_errors(index, batch_shape, check_lazy_shapes): 104 | with pytest.raises(RuntimeError): 105 | _translate_index(index, batch_shape) 106 | 107 | 108 | @pytest.mark.parametrize("x1_batch", [(), (1,), (2,), (2, 2), (2, 1), (1, 2)]) 109 | @pytest.mark.parametrize("x2_batch", [(), (1,), (2,), (2, 2), (2, 1), (1, 2)]) 110 | def test_batch_computation(x1_batch, x2_batch, check_lazy_shapes): 111 | x1 = np.random.randn(*(x1_batch + (3, 4))) 112 | x2 = np.random.randn(*(x2_batch + (4, 5))) 113 | approx(batch_computation(np.matmul, (x1, x2), (2, 2)), np.matmul(x1, x2)) 114 | 115 | 116 | def test_metadata(check_lazy_shapes): 117 | # Test that the name and docstrings for functions are available. 118 | assert B.transpose.__name__ == "transpose" 119 | assert B.transpose.__doc__ != "" 120 | 121 | 122 | def test_abstract(check_lazy_shapes): 123 | # Test that `promote` and `promote_from` cannot be specified at the same time. 124 | with pytest.raises(ValueError): 125 | abstract(promote=1, promote_from=1)(lambda: None) 126 | 127 | class General: 128 | pass 129 | 130 | class Specific: 131 | pass 132 | 133 | a = General() 134 | b = Specific() 135 | 136 | # Temporarily mock Plum's promotion function. 137 | plum_promote = plum.promote 138 | plum.promote = lambda *args: (b,) * len(args) 139 | 140 | # Define some abstract functions. 
141 | 142 | @B.dispatch 143 | @abstract() 144 | def f1(*args: General): 145 | pass 146 | 147 | @B.dispatch 148 | def f1(*args: Specific): 149 | return args 150 | 151 | @B.dispatch 152 | @abstract(promote=None) 153 | def f2(*args: General): 154 | pass 155 | 156 | @B.dispatch 157 | def f2(*args: Specific): 158 | return args 159 | 160 | @B.dispatch 161 | @abstract(promote=-1) 162 | def f3(*args: General): 163 | pass 164 | 165 | @B.dispatch 166 | def f3(*args: Specific): 167 | return args 168 | 169 | @B.dispatch 170 | @abstract(promote_from=-1) 171 | def f3_from(*args: General): 172 | pass 173 | 174 | @B.dispatch 175 | def f3_from(*args: Specific): 176 | return args 177 | 178 | @B.dispatch 179 | @abstract(promote=0) 180 | def f4(*args: General): 181 | pass 182 | 183 | @B.dispatch 184 | def f4(*args: Specific): 185 | return args 186 | 187 | @B.dispatch 188 | @abstract(promote_from=0) 189 | def f4_from(*args: General): 190 | pass 191 | 192 | @B.dispatch 193 | def f4_from(*args: Specific): 194 | return args 195 | 196 | @B.dispatch 197 | @abstract(promote=1) 198 | def f5(*args: General): 199 | pass 200 | 201 | @B.dispatch 202 | def f5(arg: Specific, *args: General): 203 | return (arg,) + args 204 | 205 | @B.dispatch 206 | @abstract(promote_from=1) 207 | def f5_from(*args: General): 208 | pass 209 | 210 | @B.dispatch 211 | def f5_from(arg: General, *args: Specific): 212 | return (arg,) + args 213 | 214 | @B.dispatch 215 | @abstract(promote=2) 216 | def f6(*args: General): 217 | pass 218 | 219 | @B.dispatch 220 | def f6(arg1: Specific, arg2: Specific, *args: General): 221 | return (arg1, arg2) + args 222 | 223 | @B.dispatch 224 | @abstract(promote_from=2) 225 | def f6_from(*args: General): 226 | pass 227 | 228 | @B.dispatch 229 | def f6_from(arg1: General, arg2: General, *args: Specific): 230 | return (arg1, arg2) + args 231 | 232 | # Register methods. 233 | B.f1 = f1 234 | B.f2 = f2 235 | B.f3 = f3 236 | B.f3_from = f3_from 237 | B.f4 = f4 238 | B.f4_from = f4_from 239 | B.f5 = f5 240 | B.f5_from = f5_from 241 | B.f6 = f6 242 | B.f6_from = f6_from 243 | 244 | # Test promotion. 245 | with pytest.raises(NotFoundLookupError): 246 | f1(a, a, a) 247 | 248 | with pytest.raises(NotFoundLookupError): 249 | f2(a, a, a) 250 | 251 | assert f3(a, a, a) == (b, b, b) 252 | with pytest.raises(NotFoundLookupError): 253 | f3_from(a, a, a) 254 | 255 | with pytest.raises(NotFoundLookupError): 256 | f4(a, a, a) 257 | assert f4_from(a, a, a) == (b, b, b) 258 | 259 | assert f5(a, a, a) == (b, a, a) 260 | assert f5(a) == (b,) 261 | assert f5_from(a, a, a) == (a, b, b) 262 | assert f5_from(a, a) == (a, b) 263 | 264 | assert f6(a, a, a) == (b, b, a) 265 | assert f6(a, a) == (b, b) 266 | assert f6_from(a, a, a, a) == (a, a, b, b) 267 | assert f6_from(a, a, a) == (a, a, b) 268 | 269 | # Put back promotion function. 
270 | plum.promote = plum_promote 271 | -------------------------------------------------------------------------------- /tests/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | from itertools import product 4 | from typing import Union 5 | 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import pytest 9 | import tensorflow as tf 10 | import torch 11 | from autograd.core import VJPNode, getval 12 | from autograd.tracer import new_box, trace_stack 13 | from plum import Dispatcher, isinstance 14 | 15 | import lab as B 16 | from lab.shape import Dimension, Shape, unwrap_dimension 17 | 18 | __all__ = [ 19 | "check_lazy_shapes", 20 | "autograd_box", 21 | "to_np", 22 | "approx", 23 | "check_function", 24 | "Tensor", 25 | "PositiveTensor", 26 | "ComplexTensor", 27 | "BoolTensor", 28 | "NaNTensor", 29 | "Matrix", 30 | "PSD", 31 | "PSDTriangular", 32 | "Tuple", 33 | "List", 34 | "Value", 35 | "Bool", 36 | ] 37 | 38 | log = logging.getLogger("lab." + __name__) 39 | 40 | _dispatch = Dispatcher() 41 | 42 | 43 | @pytest.fixture(params=[False, True]) 44 | def check_lazy_shapes(request): 45 | if request.param: 46 | with B.lazy_shapes(): 47 | yield 48 | else: 49 | yield 50 | 51 | 52 | def autograd_box(x): 53 | """Box a tensor in AutoGrad.""" 54 | t = trace_stack.new_trace().__enter__() 55 | n = VJPNode.new_root() 56 | return new_box(x, t, n) 57 | 58 | 59 | @_dispatch(precedence=1) 60 | def to_np(x: Union[B.NPNumeric, B.Number]): 61 | """Convert a tensor to NumPy.""" 62 | return x 63 | 64 | 65 | @_dispatch 66 | def to_np(x: Dimension): 67 | return unwrap_dimension(x) 68 | 69 | 70 | @_dispatch 71 | def to_np(x: B.AGNumeric): 72 | return getval(x) 73 | 74 | 75 | @_dispatch 76 | def to_np(x: Union[B.TorchNumeric, B.TFNumeric]): 77 | return x.numpy() 78 | 79 | 80 | @_dispatch 81 | def to_np(x: B.JAXNumeric): 82 | return np.array(x) 83 | 84 | 85 | @_dispatch 86 | def to_np(tup: Union[tuple, tf.TensorShape, torch.Size, Shape]): 87 | return tuple(to_np(x) for x in tup) 88 | 89 | 90 | @_dispatch 91 | def to_np(lst: list): 92 | return to_np(tuple(lst)) 93 | 94 | 95 | @_dispatch 96 | def approx(x, y, assert_dtype: bool = False, **kw_args): 97 | """Assert that two numeric objects are close.""" 98 | x, y = to_np(x), to_np(y) 99 | 100 | # Assert that data types are equal if required. 101 | if assert_dtype: 102 | dtype_x = np.array(x).dtype 103 | dtype_y = np.array(y).dtype 104 | if dtype_x != dtype_y: 105 | raise AssertionError( 106 | f"Data types not equal: `{dtype_x}` versus `{dtype_y}`." 107 | ) 108 | 109 | np.testing.assert_allclose(x, y, **kw_args) 110 | 111 | 112 | @_dispatch 113 | def approx(x: tuple, y: tuple, assert_dtype: bool = False, **kw_args): 114 | assert len(x) == len(y) 115 | for xi, yi in zip(x, y): 116 | approx(xi, yi, assert_dtype=assert_dtype, **kw_args) 117 | 118 | 119 | def check_function( 120 | f, 121 | args_spec, 122 | kw_args_spec=None, 123 | assert_dtype=True, 124 | skip=None, 125 | contains_nans=None, 126 | ): 127 | """Check that a function produces consistent output. Moreover, if the first 128 | argument is a data type, check that the result is exactly of that type.""" 129 | skip = [] if skip is None else skip 130 | 131 | if kw_args_spec is None: 132 | kw_args_spec = {} 133 | 134 | # Construct product of keyword arguments. 
135 | kw_args_prod = list( 136 | product(*[[(k, v) for v in vs.forms()] for k, vs in kw_args_spec.items()]) 137 | ) 138 | kw_args_prod = [{k: v for k, v in kw_args} for kw_args in kw_args_prod] 139 | 140 | # Add default call. 141 | kw_args_prod += [{}] 142 | 143 | # Construct product of arguments. 144 | args_prod = list(product(*[arg.forms() for arg in args_spec])) 145 | 146 | # Construct framework types to skip mixes of. 147 | fw_types = [ 148 | Union[t, typing.List[t], typing.Tuple[t, ...]] 149 | for t in [B.AGNumeric, B.TorchNumeric, B.TFNumeric, B.JAXNumeric] 150 | ] 151 | 152 | # Construct other types to skip entirely. 153 | skip_types = [Union[t, typing.List[t], typing.Tuple[t, ...]] for t in skip] 154 | 155 | def exempt(arg): 156 | """Allow empty tuples and lists.""" 157 | return isinstance(arg, (tuple, list)) and len(arg) == 0 158 | 159 | # Check consistency of results. 160 | for kw_args in kw_args_prod: 161 | # Compare everything against the first result. 162 | first_result = f(*args_prod[0], **kw_args) 163 | 164 | # If first argument is a data type, then check that. 165 | if isinstance(args_prod[0][0], B.DType): 166 | assert B.dtype(first_result) is args_prod[0][0] 167 | 168 | for args in args_prod: 169 | # Skip mixes of FW types. 170 | fw_count = sum( 171 | [ 172 | any(not exempt(arg) and isinstance(arg, t) for arg in args) 173 | for t in fw_types 174 | ] 175 | ) 176 | 177 | # Skip all skips. 178 | skip_count = sum( 179 | [ 180 | any(not exempt(arg) and isinstance(arg, t) for arg in args) 181 | for t in skip_types 182 | ] 183 | ) 184 | 185 | if fw_count >= 2 or skip_count >= 1: 186 | log.debug( 187 | f"Skipping call with arguments {args} and keyword " 188 | f"arguments {kw_args}." 189 | ) 190 | continue 191 | 192 | # Check consistency. 193 | log.debug(f"Call with arguments {args} and keyword arguments {kw_args}.") 194 | result = f(*args, **kw_args) 195 | approx(first_result, result, assert_dtype=assert_dtype) 196 | 197 | # If first argument is a data type, then again check that. 198 | if isinstance(args[0], B.DType): 199 | assert B.dtype(result) is args[0] 200 | 201 | # Check NaNs. 
202 | if contains_nans is not None: 203 | assert B.any(B.isnan(result)) == contains_nans 204 | 205 | 206 | class Tensor: 207 | """Tensor placeholder.""" 208 | 209 | def __init__(self, *dims, **kw_args): 210 | if "mat" not in kw_args or kw_args["mat"] is None: 211 | self.mat = np.array(np.random.randn(*dims)) 212 | else: 213 | self.mat = kw_args["mat"] 214 | 215 | def forms(self): 216 | return [self.np(), self.tf(), self.torch(), self.ag(), self.jax()] 217 | 218 | def np(self): 219 | return self.mat 220 | 221 | def tf(self): 222 | return tf.constant(self.mat) 223 | 224 | def torch(self): 225 | return torch.tensor(self.mat) 226 | 227 | def ag(self): 228 | return autograd_box(self.mat) 229 | 230 | def jax(self): 231 | return jnp.array(self.mat) 232 | 233 | 234 | class PositiveTensor(Tensor): 235 | """Positive tensor placeholder.""" 236 | 237 | def __init__(self, *dims, upper=1, **kw_args): 238 | if "mat" not in kw_args or kw_args["mat"] is None: 239 | mat = np.array(upper * np.random.rand(*dims)) 240 | else: 241 | mat = kw_args["mat"] 242 | Tensor.__init__(self, mat=mat) 243 | 244 | 245 | class ComplexTensor(Tensor): 246 | """Complex tensor placeholder.""" 247 | 248 | def __init__(self, *dims, **kw_args): 249 | if "mat" not in kw_args or kw_args["mat"] is None: 250 | mat = np.array(np.random.randn(*dims), dtype=np.complex128) 251 | else: 252 | mat = kw_args["mat"] 253 | Tensor.__init__(self, mat=mat) 254 | 255 | 256 | class BoolTensor(Tensor): 257 | """Boolean tensor placeholder.""" 258 | 259 | def __init__(self, *dims, **kw_args): 260 | if "mat" not in kw_args or kw_args["mat"] is None: 261 | mat = np.array(np.random.rand(*dims) > 0.5) 262 | else: 263 | mat = kw_args["mat"] 264 | Tensor.__init__(self, mat=mat) 265 | 266 | def torch(self): 267 | return torch.tensor(self.mat.astype(np.uint8)) 268 | 269 | 270 | class NaNTensor(Tensor): 271 | """Tensor containing NaNs placeholder.""" 272 | 273 | def __init__(self, *dims, **kw_args): 274 | if "mat" not in kw_args or kw_args["mat"] is None: 275 | mat = np.array(np.random.randn(*dims)) 276 | if len(dims) > 0: 277 | # Checkboard from https://stackoverflow.com/q/2169478. 278 | checkerboard = np.indices(dims).sum(axis=0) % 2 279 | mat[checkerboard == 1] = np.nan 280 | else: 281 | mat = kw_args["mat"] 282 | Tensor.__init__(self, mat=mat) 283 | 284 | 285 | class Matrix(Tensor): 286 | """Matrix placeholder.""" 287 | 288 | def __init__(self, *shape, **kw_args): 289 | # Handle shorthands. 290 | if shape == (): 291 | shape = (3, 3) 292 | elif len(shape) == 1: 293 | shape = shape * 2 294 | 295 | Tensor.__init__(self, *shape, **kw_args) 296 | 297 | 298 | class PSD(Matrix): 299 | """Positive-definite tensor placeholder.""" 300 | 301 | def __init__(self, *shape): 302 | # Handle shorthands. 303 | if shape == (): 304 | shape = (3, 3) 305 | elif len(shape) == 1: 306 | shape = shape * 2 307 | 308 | if not shape[-2] == shape[-1]: 309 | raise ValueError("PSD matrix must be square.") 310 | 311 | a = np.random.randn(*shape) 312 | perm = list(range(len(a.shape))) 313 | perm[-2], perm[-1] = perm[-1], perm[-2] 314 | a_t = np.transpose(a, perm) 315 | Matrix.__init__(self, mat=np.matmul(a, a_t)) 316 | 317 | 318 | class PSDTriangular(PSD): 319 | def __init__(self, *shape, **kw_args): 320 | PSD.__init__(self, *shape) 321 | 322 | # Zero upper triangular part. 323 | for i in range(self.mat.shape[0]): 324 | for j in range(i + 1, self.mat.shape[1]): 325 | self.mat[..., i, j] = 0 326 | 327 | # Create upper-triangular matrices, if asked for. 
328 | if kw_args.get("upper", False): 329 | perm = list(range(len(self.mat.shape))) 330 | perm[-2], perm[-1] = perm[-1], perm[-2] 331 | self.mat = np.transpose(self.mat, perm) 332 | 333 | 334 | class Tuple: 335 | """Tuple placeholder.""" 336 | 337 | def __init__(self, *xs): 338 | self.xs = xs 339 | 340 | def forms(self): 341 | return map(tuple, zip(*(x.forms() for x in self.xs))) 342 | 343 | 344 | class List: 345 | """List placeholder for in argument specification.""" 346 | 347 | def __init__(self, *xs): 348 | self.xs = xs 349 | 350 | def forms(self): 351 | return map(list, zip(*(x.forms() for x in self.xs))) 352 | 353 | 354 | class Value: 355 | """Value placeholder.""" 356 | 357 | def __init__(self, *values): 358 | self._values = values 359 | 360 | def forms(self): 361 | return self._values 362 | 363 | 364 | class Bool(Value): 365 | """Boolean placeholder.""" 366 | 367 | def __init__(self): 368 | Value.__init__(self, False, True) 369 | -------------------------------------------------------------------------------- /todo.tasks: -------------------------------------------------------------------------------- 1 | TODO: 2 | ☐ Borrow global state so that you never can forget to set it. @high 3 | ☐ Make `B.jit` work with Torch and TF generators to ensure uniform patterns. @high 4 | 5 | ☐ Check optimality of `move_to_device`. @low 6 | 7 | ☐ Reuse Plum's error message. 8 | 9 | Bugs: 10 | ☐ Is broadcasting of shapes in `B.randgamma` safe with the JIT? @high 11 | 12 | Functions: 13 | ☐ eigvals 14 | ☐ norm 15 | ☐ dot 16 | ☐ cos_sim 17 | 18 | ___________________ 19 | Archive: 20 | ✓ Allow to index with `int32` for Torch @high @done (22-04-28 15:50) @project(TODO) 21 | ✓ Add test like this: @high @done (22-04-28 15:50) @project(TODO) 22 | import lab as B 23 | import tensorflow as tf 24 | import lab.tensorflow 25 | import jax.numpy as jnp 26 | import lab.jax 27 | import torch 28 | import lab.torch 29 | for dtype in [np.float32, jnp.float32, tf.float32, torch.float64] 30 | ✓ Let `cholesky_solve` for PyTorch use `torch.cholesky_solve` once the derivative is implemented. @done (22-03-30 19:33) @project(Future) 31 | ✓ Jax @done (22-03-30 19:32) @project(TODO / Support) 32 | ✓ Design with AutoGrad as well? @done (22-03-30 19:32) @project(TODO / Support) 33 | x Refactor tests to use PyTest: remove raises, fixtures, and parametrisation. @high @cancelled (22-03-30 19:32) @project(Functions) 34 | x Refactor scan once TF2.0 is integrated. @cancelled (19-07-07 18:53) @project(TODO) 35 | ✓ Port bvn_cdf. @done (19-05-16 18:06) @project(TODO) 36 | ✓ Check Python 2 and Python 3 compatibility. @done (19-05-16 18:06) @project(TODO) 37 | ✓ Documentation. @critical @done (19-05-02 13:29) @project(TODO) 38 | ✓ Add support for AutoGrad. @done (19-05-01 13:25) @project(TODO) 39 | x = B.range(dtype, 10) @project(TODO) 40 | perm = B.randperm(B.dtype_int(dtype), 10) 41 | B.take(x, B.cast(B.promote_dtypes(B.dtype_int(dtype), int), perm)) 42 | --------------------------------------------------------------------------------
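As a closing illustration of the stateful sampling API exercised in `tests/test_random.py` above, the following minimal sketch is not part of the repository; the dtype, shapes, and variable names are illustrative, and only functions demonstrated in those tests are used.

import numpy as np

import lab as B

# Create a seeded random state for the NumPy backend.
state = B.create_random_state(np.float64, seed=0)

# Functional-style sampling: each call returns a new state together with the sample.
state, x = B.rand(state, np.float64, 2, 3)   # Uniform samples of shape (2, 3).
state, y = B.randn(state, np.float64, 2, 3)  # Standard-normal samples of shape (2, 3).

# Re-creating the state with the same seed reproduces the draws, as the tests above check.
state2 = B.create_random_state(np.float64, seed=0)
state2, x2 = B.rand(state2, np.float64, 2, 3)
assert np.all(x == x2)

Threading an explicit state through every sampler is what lets the same test body run unchanged across the NumPy, TensorFlow, PyTorch, and JAX backends that the test suite parametrises over.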