├── .github
│   └── workflows
│       ├── check.yml
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── README.md
├── autograd
│   ├── __init__.py
│   ├── builtins.py
│   ├── core.py
│   ├── differential_operators.py
│   ├── extend.py
│   ├── misc
│   │   ├── __init__.py
│   │   ├── fixed_points.py
│   │   ├── flatten.py
│   │   ├── optimizers.py
│   │   └── tracers.py
│   ├── numpy
│   │   ├── __init__.py
│   │   ├── fft.py
│   │   ├── linalg.py
│   │   ├── numpy_boxes.py
│   │   ├── numpy_jvps.py
│   │   ├── numpy_vjps.py
│   │   ├── numpy_vspaces.py
│   │   ├── numpy_wrapper.py
│   │   └── random.py
│   ├── scipy
│   │   ├── __init__.py
│   │   ├── integrate.py
│   │   ├── linalg.py
│   │   ├── signal.py
│   │   ├── special.py
│   │   └── stats
│   │       ├── __init__.py
│   │       ├── beta.py
│   │       ├── chi2.py
│   │       ├── dirichlet.py
│   │       ├── gamma.py
│   │       ├── multivariate_normal.py
│   │       ├── norm.py
│   │       ├── poisson.py
│   │       └── t.py
│   ├── test_util.py
│   ├── tracer.py
│   ├── util.py
│   └── wrap_util.py
├── benchmarks
│   ├── __init__.py
│   ├── asv.conf.json.sample
│   ├── bench_core.py
│   ├── bench_mem.py
│   ├── bench_numpy_vjps.py
│   ├── bench_rnn.py
│   └── bench_util.py
├── conda_recipe
│   └── conda.yaml
├── docs
│   ├── tutorial.md
│   └── updateguide.md
├── examples
│   ├── __init__.py
│   ├── bayesian_neural_net.png
│   ├── bayesian_neural_net.py
│   ├── bayesian_optimization.py
│   ├── black_box_svi.py
│   ├── convnet.py
│   ├── data.py
│   ├── data_mnist.py
│   ├── deep_gaussian_process.py
│   ├── define_gradient.py
│   ├── dot_graph.py
│   ├── fixed_points.py
│   ├── fluidsim
│   │   ├── animated.gif
│   │   ├── fluidsim.py
│   │   ├── init_smoke.png
│   │   ├── peace.png
│   │   ├── skull.png
│   │   ├── surprise.gif
│   │   ├── wing.png
│   │   └── wing.py
│   ├── gaussian_process.png
│   ├── gaussian_process.py
│   ├── generative_adversarial_net.py
│   ├── gmm.png
│   ├── gmm.py
│   ├── gplvm.png
│   ├── gplvm.py
│   ├── graph.pdf
│   ├── hmm_em.py
│   ├── ica.py
│   ├── logistic_regression.py
│   ├── lstm.py
│   ├── mixture_variational_inference.py
│   ├── natural_gradient_black_box_svi.py
│   ├── negative_binomial_maxlike.py
│   ├── neural_net.py
│   ├── neural_net_regression.py
│   ├── ode_net.py
│   ├── ode_net_demo.png
│   ├── print_trace.py
│   ├── rkhs.py
│   ├── rnn.py
│   ├── rosenbrock.py
│   ├── sinusoid.png
│   ├── sinusoid.py
│   ├── sinusoid_taylor.png
│   ├── tanh.png
│   ├── tanh.py
│   ├── vae_samples.png
│   └── variational_autoencoder.py
├── license.txt
├── noxfile.py
├── pyproject.toml
└── tests
    ├── _test_complexity.py
    ├── check_examples_run.sh
    ├── conftest.py
    ├── numpy_utils.py
    ├── profiling.py
    ├── test_binary_ops.py
    ├── test_builtins.py
    ├── test_complex.py
    ├── test_core.py
    ├── test_dict.py
    ├── test_direct.py
    ├── test_fft.py
    ├── test_graphs.py
    ├── test_jacobian.py
    ├── test_linalg.py
    ├── test_list.py
    ├── test_logic.py
    ├── test_misc.py
    ├── test_numpy.py
    ├── test_performance.py
    ├── test_scalar_ops.py
    ├── test_scipy.py
    ├── test_systematic.py
    ├── test_tests.py
    ├── test_truediv.py
    ├── test_tuple.py
    ├── test_vspaces.py
    └── test_wrappers.py
/.github/workflows/check.yml:
--------------------------------------------------------------------------------
1 | name: Style and package checks
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - master
7 | push:
8 | branches:
9 | - master
10 | workflow_dispatch:
11 |
12 | env:
13 | PIP_DISABLE_PIP_VERSION_CHECK: "1"
14 | FORCE_COLOR: "3"
15 |
16 | concurrency:
17 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
18 | cancel-in-progress: true
19 |
20 | jobs:
21 | check:
22 | name: ${{ matrix.session }}
23 | runs-on: ubuntu-latest
24 | strategy:
25 | fail-fast: false
26 | matrix:
27 | session:
28 | # - lint
29 | - validate-package
30 | steps:
31 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
32 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
33 |
34 | - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1
35 |
36 | - name: Run ${{ matrix.session }}
37 | run: uvx nox -s ${{ matrix.session }}
38 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | workflow_dispatch:
5 | release:
6 | types: [published]
7 |
8 | env:
9 | PIP_DISABLE_PIP_VERSION_CHECK: '1'
10 | FORCE_COLOR: '3'
11 |
12 | jobs:
13 | build:
14 | name: Build sdist and wheel
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
18 | name: Checkout repository
19 |
20 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
21 | with:
22 | python-version: "3.12"
23 |
24 | - name: Build sdist and wheel with pypa/build
25 | run: |
26 | pipx run build --outdir dist
27 |
28 | - name: Upload wheel and sdist artifacts
29 | uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
30 | with:
31 | name: artifacts
32 | path: ./dist/*
33 | if-no-files-found: error
34 |
35 | publish:
36 | needs: [build]
37 | name: Upload to PyPI
38 | runs-on: ubuntu-latest
39 | environment:
40 | name: release
41 | url: https://pypi.org/p/autograd
42 | permissions:
43 | id-token: write # mandatory for trusted publishing
44 |
45 | steps:
46 | - name: Download artifacts
47 | uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
48 | with:
49 | path: dist
50 | merge-multiple: true
51 |
52 | - name: Sanity check artifacts
53 | run: ls -la dist/
54 |
55 | - name: Publish sdist and wheel to PyPI
56 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
57 | with:
58 | packages-dir: dist/
59 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - master
7 | push:
8 | branches:
9 | - master
10 | workflow_dispatch:
11 | schedule:
12 | - cron: "0 4 * * *"
13 |
14 | env:
15 | PIP_DISABLE_PIP_VERSION_CHECK: "1"
16 | FORCE_COLOR: "3"
17 |
18 | concurrency:
19 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
20 | cancel-in-progress: true
21 |
22 | jobs:
23 | test:
24 | name: Regular tests / ${{ matrix.platform }} / Python ${{ matrix.python-version }}
25 | runs-on: ${{ matrix.platform }}
26 | strategy:
27 | fail-fast: false
28 | matrix:
29 | platform: [ubuntu-latest, ubuntu-22.04-arm, macos-13, macos-latest, windows-latest]
30 | python-version:
31 | ["3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.10"]
32 | steps:
33 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
34 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
35 | with:
36 | python-version: ${{ matrix.python-version }}
37 | allow-prereleases: true
38 | - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1
39 |
40 | # On PyPy, we skip SciPy because we don't have wheels
41 | # available, see noxfile.py for more details.
42 | - name: Run tests
43 | run: uvx nox -s tests
44 |
45 | # In this job, we test against the NumPy nightly wheels hosted on
46 | # https://anaconda.org/scientific-python-nightly-wheels/numpy
47 | # on the latest Python version available across platforms, instead of
48 | # testing all Python versions and implementations on all platforms.
49 | # We do not test on PyPy.
50 | #
51 | # However, "nox -s nightly-tests" can be used locally anywhere, on
52 | # any Python version and implementation on any platform and we leave
53 | # it to the user to decide what Python version to test against, which
54 | # might or might not have a corresponding NumPy nightly wheel present.
55 | nightlies:
56 | name: Nightly tests / ${{ matrix.platform }} / Python ${{ matrix.python-version }}
57 | runs-on: ${{ matrix.platform }}
58 | strategy:
59 | fail-fast: false
60 | matrix:
61 | platform: [ubuntu-latest, ubuntu-22.04-arm, macos-13, macos-latest, windows-latest]
62 | python-version: ["3.x"]
63 |
64 | steps:
65 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
66 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
67 | with:
68 | python-version: ${{ matrix.python-version }}
69 | allow-prereleases: true
70 | - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1
71 | - name: Run tests against nightly wheels for NumPy and SciPy
72 | run: uvx nox -s nightly-tests
73 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 |
5 | # Distribution / packaging
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | MANIFEST
23 |
24 | # Installer logs
25 | pip-log.txt
26 | pip-delete-this-directory.txt
27 |
28 | # Unit test / coverage reports
29 | htmlcov/
30 | .tox/
31 | .nox/
32 | .coverage
33 | .coverage.*
34 | .cache
35 | coverage.*
36 | *.cover
37 | .hypothesis/
38 | nosetests.xml
39 | .pytest_cache/
40 | junit-report.xml
41 |
42 | # pyenv
43 | .python-version
44 |
45 | # Environments
46 | .env
47 | .venv
48 | env/
49 | venv/
50 | ENV/
51 | env.bak/
52 | venv.bak/
53 |
54 | # mypy
55 | .mypy_cache/
56 |
57 | # OS and IDE config files
58 | .DS_Store
59 | .idea/
60 |
61 | # project-specific
62 | data/
63 | *.so
64 | *.c
65 | scratch/
66 | examples/data
67 |
68 | .asv/
69 | asv.conf.json
70 | benchmarks/asv.conf.js
71 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autoupdate_commit_msg: "chore: update pre-commit hooks"
3 | autofix_commit_msg: "style: pre-commit fixes"
4 |
5 | repos:
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: v5.0.0
8 | hooks:
9 | - id: check-added-large-files
10 | - id: check-case-conflict
11 | - id: check-merge-conflict
12 | - id: check-yaml
13 | exclude: conda_recipe/conda.yaml
14 | - id: debug-statements
15 | - id: end-of-file-fixer
16 | - id: mixed-line-ending
17 | - id: trailing-whitespace
18 |
19 | - repo: https://github.com/astral-sh/ruff-pre-commit
20 | rev: "v0.11.12"
21 | hooks:
22 | - id: ruff
23 | args: ["--fix", "--show-fixes"]
24 | - id: ruff-format
25 |
26 | - repo: https://github.com/pre-commit/pygrep-hooks
27 | rev: v1.10.0
28 | hooks:
29 | - id: python-check-blanket-type-ignore
30 | exclude: ^src/vector/backends/_numba_object.py$
31 | - id: rst-backticks
32 | - id: rst-directive-colons
33 | - id: rst-inline-touching-normal
34 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Use [Nox](https://nox.thea.codes/en/stable/) to run tests and linting. Install it first, e.g.,
4 |
5 | ```shell
6 | pip install nox
7 | ```
8 |
9 | `nox` will run all checks in an isolated virtual environment with Autograd and its dependencies, including its optional dependencies, installed.
10 |
11 | ## Run tests, linting, packaging checks
12 |
13 | | Command | Description |
14 | | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
15 | | `nox --list` | Lists all available Nox sessions, including selected ones |
16 | | `nox -s lint` | Runs code style checks with pre-commit and pre-commit hooks as listed in `.pre-commit-config.yaml`. Accepts posargs to pass additional arguments to the linter. |
17 | | `nox -s tests` | Runs tests with your default Python interpreter. Accepts posargs to pass additional arguments and configuration to `pytest`. |
18 | | `nox -s nightly-tests` | Similar to `nox -s tests`, except that it runs tests with nightly versions of dependencies (NumPy, SciPy, etc.). |
19 | | `nox -s validate-package` | Builds a source distribution and a wheel using `pypa/build` and checks the package with `twine` in strict mode. |
20 | | `nox` | Runs all selected sessions, as listed in `nox.options.sessions` in `noxfile.py`. |
21 |
22 | Additionally, `nox` supports tags to run specific sessions, e.g., `nox --tags tests` runs all sessions tagged with `tests`.
23 |
24 | Make sure all tests pass before you push your changes to GitHub.
25 | GH Actions will run the tests across all supported Python versions.
26 |
27 | ## Using positional arguments (passing extra arguments to pytest, pre-commit, etc.)
28 |
29 | You can use additional arguments for the tools (`pytest`, `pre-commit`, etc.) called by Nox by
30 | separating them from the Nox arguments by a double-hyphen `--`, e.g.,
31 |
32 | - `nox -s tests -- tests/test_tuple.py` runs just the tests in `tests/test_tuple.py`.
33 | - `nox -s lint -- --fix` runs the linter with the `--fix` flag.
34 | - and so on.
35 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Autograd [![Checks status][checks-badge]][checks-url] [![Tests status][tests-badge]][tests-url] [![Publish status][publish-badge]][publish-url] [![asv][asv-badge]](#)
2 |
3 | [publish-badge]: https://github.com/HIPS/autograd/actions/workflows/publish.yml/badge.svg
4 | [checks-badge]: https://github.com/HIPS/autograd/actions/workflows/check.yml/badge.svg
5 | [tests-badge]: https://github.com/HIPS/autograd/actions/workflows/test.yml/badge.svg
6 | [asv-badge]: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat
7 | [publish-url]: https://github.com/HIPS/autograd/actions/workflows/publish.yml
8 | [checks-url]: https://github.com/HIPS/autograd/actions/workflows/check.yml
9 | [tests-url]: https://github.com/HIPS/autograd/actions/workflows/test.yml
10 |
11 | Autograd can automatically differentiate native Python and Numpy code. It can
12 | handle a large subset of Python's features, including loops, ifs, recursion and
13 | closures, and it can even take derivatives of derivatives of derivatives. It
14 | supports reverse-mode differentiation (a.k.a. backpropagation), which means it
15 | can efficiently take gradients of scalar-valued functions with respect to
16 | array-valued arguments, as well as forward-mode differentiation, and the two can
17 | be composed arbitrarily. The main intended application of Autograd is
18 | gradient-based optimization. For more information, check out the
19 | [tutorial](docs/tutorial.md) and the [examples directory](examples/).
20 |
21 | Example use:
22 |
23 | ```python
24 | >>> import autograd.numpy as np # Thinly-wrapped numpy
25 | >>> from autograd import grad # The only autograd function you may ever need
26 | >>>
27 | >>> def tanh(x): # Define a function
28 | ... return (1.0 - np.exp((-2 * x))) / (1.0 + np.exp(-(2 * x)))
29 | ...
30 | >>> grad_tanh = grad(tanh) # Obtain its gradient function
31 | >>> grad_tanh(1.0) # Evaluate the gradient at x = 1.0
32 | np.float64(0.419974341614026)
33 | >>> (tanh(1.0001) - tanh(0.9999)) / 0.0002 # Compare to finite differences
34 | np.float64(0.41997434264973155)
35 | ```
36 |
37 | We can continue to differentiate as many times as we like, and use numpy's
38 | vectorization of scalar-valued functions across many different input values:
39 |
40 | ```python
41 | >>> from autograd import elementwise_grad as egrad # for functions that vectorize over inputs
42 | >>> import matplotlib.pyplot as plt
43 | >>> x = np.linspace(-7, 7, 700)
44 | >>> plt.plot(x, tanh(x),
45 | ... x, egrad(tanh)(x), # first derivative
46 | ... x, egrad(egrad(tanh))(x), # second derivative
47 | ... x, egrad(egrad(egrad(tanh)))(x), # third derivative
48 | ... x, egrad(egrad(egrad(egrad(tanh))))(x)) # fourth derivative
49 | >>> plt.show()
50 | ```
51 |
52 |
53 |
54 | See the [tanh example file](examples/tanh.py) for the code.
55 |
56 | ## Documentation
57 |
58 | You can find a tutorial [here.](docs/tutorial.md)
59 |
60 | ## End-to-end examples
61 |
62 | * [Simple neural net](examples/neural_net.py)
63 | * [Convolutional neural net](examples/convnet.py)
64 | * [Recurrent neural net](examples/rnn.py)
65 | * [LSTM](examples/lstm.py)
66 | * [Neural Turing Machine](https://github.com/DoctorTeeth/diffmem/blob/512aadeefd6dbafc1bdd253a64b6be192a435dc3/ntm/ntm.py)
67 | * [Backpropagating through a fluid simulation](examples/fluidsim/fluidsim.py)
68 |
69 |
70 |
71 | * [Variational inference in Bayesian neural network](examples/bayesian_neural_net.py)
72 | * [Gaussian process regression](examples/gaussian_process.py)
73 | * [Sampyl, a pure Python MCMC package with HMC and NUTS](https://github.com/mcleonard/sampyl)
74 |
75 | ## How to install
76 |
77 | Install Autograd using Pip:
78 |
79 | ```shell
80 | pip install autograd
81 | ```
82 |
83 | Some features require SciPy, which you can install separately or as an
84 | optional dependency along with Autograd:
85 |
86 | ```shell
87 | pip install "autograd[scipy]"
88 | ```
89 |
90 | ## Authors and maintainers
91 |
92 | Autograd was written by [Dougal Maclaurin](https://dougalmaclaurin.com),
93 | [David Duvenaud](https://www.cs.toronto.edu/~duvenaud/),
94 | [Matt Johnson](http://people.csail.mit.edu/mattjj/),
95 | [Jamie Townsend](https://github.com/j-towns)
96 | and many other contributors. The package is currently being maintained by
97 | [Agriya Khetarpal](https://github.com/agriyakhetarpal),
98 | [Fabian Joswig](https://github.com/fjosw) and
99 | [Jamie Townsend](https://github.com/j-towns).
100 | Please feel free to submit any bugs or
101 | feature requests. We'd also love to hear about your experiences with Autograd
102 | in general. Drop us an email!
103 |
104 | We want to thank Jasper Snoek and the rest of the HIPS group (led by Prof. Ryan
105 | P. Adams) for helpful contributions and advice; Barak Pearlmutter for
106 | foundational work on automatic differentiation and for guidance on our
107 | implementation; and Analog Devices Inc. (Lyric Labs) and Samsung Advanced Institute
108 | of Technology for their generous support.
109 |
--------------------------------------------------------------------------------
/autograd/__init__.py:
--------------------------------------------------------------------------------
1 | from autograd.core import primitive_with_deprecation_warnings as primitive
2 |
3 | from .builtins import dict, isinstance, list, tuple, type
4 | from .differential_operators import (
5 | checkpoint,
6 | deriv,
7 | elementwise_grad,
8 | grad,
9 | grad_and_aux,
10 | grad_named,
11 | hessian,
12 | hessian_tensor_product,
13 | hessian_vector_product,
14 | holomorphic_grad,
15 | jacobian,
16 | make_ggnvp,
17 | make_hvp,
18 | make_jvp,
19 | make_vjp,
20 | multigrad_dict,
21 | tensor_jacobian_product,
22 | value_and_grad,
23 | vector_jacobian_product,
24 | )
25 |
--------------------------------------------------------------------------------
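A minimal sketch of the public entry points re-exported above (`grad`, `jacobian`, `hessian`); the toy function below is hypothetical:

```python
import autograd.numpy as np
from autograd import grad, hessian, jacobian

# A scalar-valued function of an array-valued argument.
f = lambda x: np.sum(np.sin(x) ** 2)

x = np.array([0.1, 0.2, 0.3])
print(grad(f)(x))            # gradient, shape (3,)
print(jacobian(grad(f))(x))  # Jacobian of the gradient, shape (3, 3)
print(hessian(f)(x))         # the same matrix, computed directly
```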
/autograd/extend.py:
--------------------------------------------------------------------------------
1 | # Exposes API for extending autograd
2 | from .core import (
3 | JVPNode,
4 | SparseObject,
5 | VJPNode,
6 | VSpace,
7 | def_linear,
8 | defjvp,
9 | defjvp_argnum,
10 | defjvp_argnums,
11 | defvjp,
12 | defvjp_argnum,
13 | defvjp_argnums,
14 | vspace,
15 | )
16 | from .tracer import Box, notrace_primitive, primitive, register_notrace
17 |
--------------------------------------------------------------------------------
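A sketch of the usual extension pattern built on these exports, following `docs/tutorial.md`: wrap a raw-NumPy function with `primitive` and register its gradient with `defvjp`. The `logsumexp` function below is a hypothetical example, not part of this module:

```python
import numpy as onp  # raw NumPy, invisible to autograd's tracer
import autograd.numpy as anp
from autograd import grad
from autograd.extend import defvjp, primitive


@primitive
def logsumexp(x):
    """Numerically stable log(sum(exp(x))), computed with raw NumPy."""
    m = onp.max(x)
    return m + onp.log(onp.sum(onp.exp(x - m)))


def logsumexp_vjp(ans, x):
    # The gradient of logsumexp is softmax(x); `ans` is the forward-pass value.
    return lambda g: g * anp.exp(x - ans)


defvjp(logsumexp, logsumexp_vjp)

print(grad(logsumexp)(anp.array([0.0, 1.0, 2.0])))  # softmax of the input
```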
/autograd/misc/__init__.py:
--------------------------------------------------------------------------------
1 | from .flatten import flatten
2 | from .tracers import const_graph
3 |
--------------------------------------------------------------------------------
/autograd/misc/fixed_points.py:
--------------------------------------------------------------------------------
1 | from autograd import make_vjp
2 | from autograd.builtins import tuple
3 | from autograd.extend import defvjp, primitive, vspace
4 |
5 |
6 | @primitive
7 | def fixed_point(f, a, x0, distance, tol):
8 | _f = f(a)
9 | x, x_prev = _f(x0), x0
10 | while distance(x, x_prev) > tol:
11 | x, x_prev = _f(x), x
12 | return x
13 |
14 |
15 | def fixed_point_vjp(ans, f, a, x0, distance, tol):
16 | def rev_iter(params):
17 | a, x_star, x_star_bar = params
18 | vjp_x, _ = make_vjp(f(a))(x_star)
19 | vs = vspace(x_star)
20 | return lambda g: vs.add(vjp_x(g), x_star_bar)
21 |
22 | vjp_a, _ = make_vjp(lambda x, y: f(x)(y))(a, ans)
23 | return lambda g: vjp_a(fixed_point(rev_iter, tuple((a, ans, g)), vspace(x0).zeros(), distance, tol))
24 |
25 |
26 | defvjp(fixed_point, None, fixed_point_vjp, None)
27 |
--------------------------------------------------------------------------------
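A usage sketch mirroring `examples/fixed_points.py`: `f(a)` returns the iteration map, and the VJP above lets gradients flow through the converged fixed point. The Babylonian iteration below converges to sqrt(a), so the derivative should be close to 0.5 / sqrt(a):

```python
import autograd.numpy as np
from autograd import grad
from autograd.misc.fixed_points import fixed_point


def sqrt_iter(a):
    # Babylonian (Newton) iteration whose fixed point is sqrt(a).
    return lambda x: 0.5 * (x + a / x)


def sqrt(a):
    return fixed_point(sqrt_iter, a, 1.0, lambda x, y: np.abs(x - y), 1e-8)


print(sqrt(2.0))        # ~1.4142136
print(grad(sqrt)(2.0))  # ~0.3535534, i.e. 0.5 / sqrt(2)
```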
/autograd/misc/flatten.py:
--------------------------------------------------------------------------------
1 | """
2 | Handy functions for flattening nested containers containing numpy
3 | arrays. The main purpose is to make examples and optimizers simpler.
4 | """
5 |
6 | import autograd.numpy as np
7 | from autograd import make_vjp
8 | from autograd.builtins import type
9 |
10 |
11 | def flatten(value):
12 | """Flattens any nesting of tuples, lists, or dicts, with numpy arrays or
13 | scalars inside. Returns 1D numpy array and an unflatten function.
14 | Doesn't preserve mixed numeric types (e.g. floats and ints). Assumes dict
15 | keys are sortable."""
16 | unflatten, flat_value = make_vjp(_flatten)(value)
17 | return flat_value, unflatten
18 |
19 |
20 | def _flatten(value):
21 | t = type(value)
22 | if t in (list, tuple):
23 | return _concatenate(map(_flatten, value))
24 | elif t is dict:
25 | return _concatenate(_flatten(value[k]) for k in sorted(value))
26 | else:
27 | return np.ravel(value)
28 |
29 |
30 | def _concatenate(lst):
31 | lst = list(lst)
32 | return np.concatenate(lst) if lst else np.array([])
33 |
34 |
35 | def flatten_func(func, example):
36 | _ex, unflatten = flatten(example)
37 | _func = lambda _x, *args: flatten(func(unflatten(_x), *args))[0]
38 | return _func, unflatten, _ex
39 |
--------------------------------------------------------------------------------
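A small illustration of `flatten` on a nested container (the parameter values are hypothetical):

```python
import autograd.numpy as np
from autograd.misc.flatten import flatten

params = {"weights": np.ones((2, 3)), "bias": np.zeros(3), "scale": 1.5}
flat, unflatten = flatten(params)

print(flat.shape)                        # (10,): 3 + 1 + 6 entries, dict keys in sorted order
print(unflatten(flat)["weights"].shape)  # (2, 3), original structure restored
```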
/autograd/misc/optimizers.py:
--------------------------------------------------------------------------------
1 | """Some standard gradient-based stochastic optimizers.
2 |
3 | These are just standard routines that don't make any use of autograd,
4 | though you could take gradients of these functions too if you want
5 | to do meta-optimization.
6 |
7 | These routines can optimize functions whose inputs are structured
8 | objects, such as dicts of numpy arrays."""
9 |
10 | import autograd.numpy as np
11 | from autograd.misc import flatten
12 | from autograd.wrap_util import wraps
13 |
14 |
15 | def unflatten_optimizer(optimize):
16 | """Takes an optimizer that operates on flat 1D numpy arrays and returns a
17 | wrapped version that handles trees of nested containers (lists/tuples/dicts)
18 | with arrays/scalars at the leaves."""
19 |
20 | @wraps(optimize)
21 | def _optimize(grad, x0, callback=None, *args, **kwargs):
22 | _x0, unflatten = flatten(x0)
23 | _grad = lambda x, i: flatten(grad(unflatten(x), i))[0]
24 | if callback:
25 | _callback = lambda x, i, g: callback(unflatten(x), i, unflatten(g))
26 | else:
27 | _callback = None
28 | return unflatten(optimize(_grad, _x0, _callback, *args, **kwargs))
29 |
30 | return _optimize
31 |
32 |
33 | @unflatten_optimizer
34 | def sgd(grad, x, callback=None, num_iters=200, step_size=0.1, mass=0.9):
35 | """Stochastic gradient descent with momentum.
36 | grad() must have signature grad(x, i), where i is the iteration number."""
37 | velocity = np.zeros(len(x))
38 | for i in range(num_iters):
39 | g = grad(x, i)
40 | if callback:
41 | callback(x, i, g)
42 | velocity = mass * velocity - (1.0 - mass) * g
43 | x = x + step_size * velocity
44 | return x
45 |
46 |
47 | @unflatten_optimizer
48 | def rmsprop(grad, x, callback=None, num_iters=100, step_size=0.1, gamma=0.9, eps=10**-8):
49 | """Root mean squared prop: See Adagrad paper for details."""
50 | avg_sq_grad = np.ones(len(x))
51 | for i in range(num_iters):
52 | g = grad(x, i)
53 | if callback:
54 | callback(x, i, g)
55 | avg_sq_grad = avg_sq_grad * gamma + g**2 * (1 - gamma)
56 | x = x - step_size * g / (np.sqrt(avg_sq_grad) + eps)
57 | return x
58 |
59 |
60 | @unflatten_optimizer
61 | def adam(grad, x, callback=None, num_iters=100, step_size=0.001, b1=0.9, b2=0.999, eps=10**-8):
62 | """Adam as described in http://arxiv.org/pdf/1412.6980.pdf.
63 | It's basically RMSprop with momentum and some correction terms."""
64 | m = np.zeros(len(x))
65 | v = np.zeros(len(x))
66 | for i in range(num_iters):
67 | g = grad(x, i)
68 | if callback:
69 | callback(x, i, g)
70 | m = (1 - b1) * g + b1 * m # First moment estimate.
71 | v = (1 - b2) * (g**2) + b2 * v # Second moment estimate.
72 | mhat = m / (1 - b1 ** (i + 1)) # Bias correction.
73 | vhat = v / (1 - b2 ** (i + 1))
74 | x = x - step_size * mhat / (np.sqrt(vhat) + eps)
75 | return x
76 |
--------------------------------------------------------------------------------
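A sketch of driving `adam` with an Autograd gradient; the toy objective and its parameters are hypothetical. The gradient callable must accept `(x, i)`, and thanks to `unflatten_optimizer` the parameters may be any nested container:

```python
import autograd.numpy as np
from autograd import grad
from autograd.misc.optimizers import adam


def objective(params, iteration):
    # Toy quadratic in a dict of parameters; `iteration` is unused here.
    return np.sum(params["w"] ** 2) + (params["b"] - 3.0) ** 2


init_params = {"w": np.ones(5), "b": 0.0}
trained = adam(grad(objective), init_params, step_size=0.1, num_iters=200)
print(trained["b"])  # close to 3.0
```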
/autograd/misc/tracers.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import repeat
3 |
4 | from autograd.tracer import Node, trace
5 | from autograd.util import subvals, toposort
6 | from autograd.wrap_util import wraps
7 |
8 |
9 | class ConstGraphNode(Node):
10 | __slots__ = ["parents", "partial_fun"]
11 |
12 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
13 | args = subvals(args, zip(parent_argnums, repeat(None)))
14 |
15 | def partial_fun(partial_args):
16 | return fun(*subvals(args, zip(parent_argnums, partial_args)), **kwargs)
17 |
18 | self.parents = parents
19 | self.partial_fun = partial_fun
20 |
21 | def initialize_root(self):
22 | self.parents = []
23 |
24 |
25 | def const_graph_unary(fun):
26 | graph = []
27 | _fun = [fun] # Allow fun to be freed, since it may have bound args
28 |
29 | def maybe_cached_fun(x):
30 | if graph:
31 | _graph = graph[0]
32 | vals = {_graph[0]: x}
33 | for node in _graph[1:]:
34 | vals[node] = node.partial_fun([vals[p] for p in node.parents])
35 | return vals[node]
36 | else:
37 | start_node = ConstGraphNode.new_root()
38 | end_value, end_node = trace(start_node, _fun.pop(), x)
39 | if end_node is None:
40 | raise Exception("Output is independent of input")
41 | graph.append(list(toposort(end_node))[::-1])
42 | return end_value
43 |
44 | return maybe_cached_fun
45 |
46 |
47 | def const_graph(fun, *args, **kwargs):
48 | partial_fun = partial(fun, *args, **kwargs)
49 | unary_fun = lambda args: partial_fun(*args)
50 | maybe_cached_unary_fun = const_graph_unary(unary_fun)
51 |
52 | @wraps(fun)
53 | def _fun(*args):
54 | return maybe_cached_unary_fun(args)
55 |
56 | return _fun
57 |
58 |
59 | class FullGraphNode(Node):
60 | __slots__ = ["value", "recipe"]
61 |
62 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
63 | self.value = value
64 | self.recipe = (fun, args, kwargs, zip(parent_argnums, parents))
65 |
66 | def initialize_root(self):
67 | self.value = None
68 | self.recipe = (lambda x: x, (), {}, [])
69 |
70 |
71 | def full_graph(fun, *args, **kwargs):
72 | unary_fun = lambda args: fun(*args, **kwargs)
73 | start_node = FullGraphNode.new_root()
74 | end_value, end_node = trace(start_node, unary_fun, args)
75 | return end_node
76 |
--------------------------------------------------------------------------------
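`const_graph` traces a function once and replays the cached graph on later calls, which saves tracing overhead when the same computation is applied repeatedly. A minimal sketch with a hypothetical function:

```python
import autograd.numpy as np
from autograd.misc import const_graph


def step(x):
    for _ in range(3):
        x = np.sin(x) + np.cos(x)
    return x


cached_step = const_graph(step)
x = np.array([0.1, 0.2])
print(cached_step(x))  # first call traces and caches the graph
print(cached_step(x))  # subsequent calls replay the cached graph
```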
/autograd/numpy/__init__.py:
--------------------------------------------------------------------------------
1 | from . import fft, linalg, numpy_boxes, numpy_jvps, numpy_vjps, numpy_vspaces, random
2 | from .numpy_wrapper import *
3 | from .numpy_wrapper import numpy_version as __version__
4 |
--------------------------------------------------------------------------------
/autograd/numpy/numpy_boxes.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 |
5 | from autograd.builtins import SequenceBox
6 | from autograd.extend import Box, primitive
7 |
8 | from . import numpy_wrapper as anp
9 |
10 | Box.__array_priority__ = 90.0
11 |
12 |
13 | class ArrayBox(Box):
14 | __slots__ = []
15 | __array_priority__ = 100.0
16 |
17 | @primitive
18 | def __getitem__(A, idx):
19 | return A[idx]
20 |
21 | # Constants w.r.t. float data just pass through
22 | shape = property(lambda self: self._value.shape)
23 | ndim = property(lambda self: self._value.ndim)
24 | size = property(lambda self: self._value.size)
25 | dtype = property(lambda self: self._value.dtype)
26 | T = property(lambda self: anp.transpose(self))
27 |
28 | def __array_namespace__(self, *, api_version: Union[str, None] = None):
29 | return anp
30 |
31 | def __len__(self):
32 | return len(self._value)
33 |
34 | def astype(self, *args, **kwargs):
35 | return anp._astype(self, *args, **kwargs)
36 |
37 | def __neg__(self):
38 | return anp.negative(self)
39 |
40 | def __add__(self, other):
41 | return anp.add(self, other)
42 |
43 | def __sub__(self, other):
44 | return anp.subtract(self, other)
45 |
46 | def __mul__(self, other):
47 | return anp.multiply(self, other)
48 |
49 | def __pow__(self, other):
50 | return anp.power(self, other)
51 |
52 | def __div__(self, other):
53 | return anp.divide(self, other)
54 |
55 | def __mod__(self, other):
56 | return anp.mod(self, other)
57 |
58 | def __truediv__(self, other):
59 | return anp.true_divide(self, other)
60 |
61 | def __matmul__(self, other):
62 | return anp.matmul(self, other)
63 |
64 | def __radd__(self, other):
65 | return anp.add(other, self)
66 |
67 | def __rsub__(self, other):
68 | return anp.subtract(other, self)
69 |
70 | def __rmul__(self, other):
71 | return anp.multiply(other, self)
72 |
73 | def __rpow__(self, other):
74 | return anp.power(other, self)
75 |
76 | def __rdiv__(self, other):
77 | return anp.divide(other, self)
78 |
79 | def __rmod__(self, other):
80 | return anp.mod(other, self)
81 |
82 | def __rtruediv__(self, other):
83 | return anp.true_divide(other, self)
84 |
85 | def __rmatmul__(self, other):
86 | return anp.matmul(other, self)
87 |
88 | def __eq__(self, other):
89 | return anp.equal(self, other)
90 |
91 | def __ne__(self, other):
92 | return anp.not_equal(self, other)
93 |
94 | def __gt__(self, other):
95 | return anp.greater(self, other)
96 |
97 | def __ge__(self, other):
98 | return anp.greater_equal(self, other)
99 |
100 | def __lt__(self, other):
101 | return anp.less(self, other)
102 |
103 | def __le__(self, other):
104 | return anp.less_equal(self, other)
105 |
106 | def __abs__(self):
107 | return anp.abs(self)
108 |
109 | def __hash__(self):
110 | return id(self)
111 |
112 |
113 | ArrayBox.register(np.ndarray)
114 | for type_ in [
115 | float,
116 | np.longdouble,
117 | np.float64,
118 | np.float32,
119 | np.float16,
120 | complex,
121 | np.clongdouble,
122 | np.complex64,
123 | np.complex128,
124 | ]:
125 | ArrayBox.register(type_)
126 |
127 | # These numpy.ndarray methods are just refs to an equivalent numpy function
128 | nondiff_methods = [
129 | "all",
130 | "any",
131 | "argmax",
132 | "argmin",
133 | "argpartition",
134 | "argsort",
135 | "nonzero",
136 | "searchsorted",
137 | "round",
138 | ]
139 | diff_methods = [
140 | "clip",
141 | "compress",
142 | "cumprod",
143 | "cumsum",
144 | "diagonal",
145 | "max",
146 | "mean",
147 | "min",
148 | "prod",
149 | "ptp",
150 | "ravel",
151 | "repeat",
152 | "reshape",
153 | "squeeze",
154 | "std",
155 | "sum",
156 | "swapaxes",
157 | "take",
158 | "trace",
159 | "transpose",
160 | "var",
161 | ]
162 | for method_name in nondiff_methods + diff_methods:
163 | setattr(ArrayBox, method_name, anp.__dict__[method_name])
164 |
165 | # Flatten has no function, only a method.
166 | setattr(ArrayBox, "flatten", anp.__dict__["ravel"])
167 |
168 | if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
169 | SequenceBox.register(np.linalg._linalg.EigResult)
170 | SequenceBox.register(np.linalg._linalg.EighResult)
171 | SequenceBox.register(np.linalg._linalg.QRResult)
172 | SequenceBox.register(np.linalg._linalg.SlogdetResult)
173 | SequenceBox.register(np.linalg._linalg.SVDResult)
174 | elif np.__version__ >= "1.25":
175 | SequenceBox.register(np.linalg.linalg.EigResult)
176 | SequenceBox.register(np.linalg.linalg.EighResult)
177 | SequenceBox.register(np.linalg.linalg.QRResult)
178 | SequenceBox.register(np.linalg.linalg.SlogdetResult)
179 | SequenceBox.register(np.linalg.linalg.SVDResult)
180 |
--------------------------------------------------------------------------------
/autograd/numpy/numpy_vspaces.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from autograd.builtins import NamedTupleVSpace
4 | from autograd.extend import VSpace
5 |
6 |
7 | class ArrayVSpace(VSpace):
8 | def __init__(self, value):
9 | value = np.asarray(value)
10 | self.shape = value.shape
11 | self.dtype = value.dtype
12 |
13 | @property
14 | def size(self):
15 | return np.prod(self.shape)
16 |
17 | @property
18 | def ndim(self):
19 | return len(self.shape)
20 |
21 | def zeros(self):
22 | return np.zeros(self.shape, dtype=self.dtype)
23 |
24 | def ones(self):
25 | return np.ones(self.shape, dtype=self.dtype)
26 |
27 | def standard_basis(self):
28 | for idxs in np.ndindex(*self.shape):
29 | vect = np.zeros(self.shape, dtype=self.dtype)
30 | vect[idxs] = 1
31 | yield vect
32 |
33 | def randn(self):
34 | return np.array(np.random.randn(*self.shape)).astype(self.dtype)
35 |
36 | def _inner_prod(self, x, y):
37 | return np.dot(np.ravel(x), np.ravel(y))
38 |
39 |
40 | class ComplexArrayVSpace(ArrayVSpace):
41 | iscomplex = True
42 |
43 | @property
44 | def size(self):
45 | return np.prod(self.shape) * 2
46 |
47 | def ones(self):
48 | return np.ones(self.shape, dtype=self.dtype) + 1.0j * np.ones(self.shape, dtype=self.dtype)
49 |
50 | def standard_basis(self):
51 | for idxs in np.ndindex(*self.shape):
52 | for v in [1.0, 1.0j]:
53 | vect = np.zeros(self.shape, dtype=self.dtype)
54 | vect[idxs] = v
55 | yield vect
56 |
57 | def randn(self):
58 | return np.array(np.random.randn(*self.shape)).astype(self.dtype) + 1.0j * np.array(
59 | np.random.randn(*self.shape)
60 | ).astype(self.dtype)
61 |
62 | def _inner_prod(self, x, y):
63 | return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))
64 |
65 | def _covector(self, x):
66 | return np.conj(x)
67 |
68 |
69 | VSpace.register(np.ndarray, lambda x: ComplexArrayVSpace(x) if np.iscomplexobj(x) else ArrayVSpace(x))
70 |
71 | for type_ in [float, np.longdouble, np.float64, np.float32, np.float16]:
72 | ArrayVSpace.register(type_)
73 |
74 | for type_ in [complex, np.clongdouble, np.complex64, np.complex128]:
75 | ComplexArrayVSpace.register(type_)
76 |
77 |
78 | if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
79 |
80 | class EigResultVSpace(NamedTupleVSpace):
81 | seq_type = np.linalg._linalg.EigResult
82 |
83 | class EighResultVSpace(NamedTupleVSpace):
84 | seq_type = np.linalg._linalg.EighResult
85 |
86 | class QRResultVSpace(NamedTupleVSpace):
87 | seq_type = np.linalg._linalg.QRResult
88 |
89 | class SlogdetResultVSpace(NamedTupleVSpace):
90 | seq_type = np.linalg._linalg.SlogdetResult
91 |
92 | class SVDResultVSpace(NamedTupleVSpace):
93 | seq_type = np.linalg._linalg.SVDResult
94 |
95 | EigResultVSpace.register(np.linalg._linalg.EigResult)
96 | EighResultVSpace.register(np.linalg._linalg.EighResult)
97 | QRResultVSpace.register(np.linalg._linalg.QRResult)
98 | SlogdetResultVSpace.register(np.linalg._linalg.SlogdetResult)
99 | SVDResultVSpace.register(np.linalg._linalg.SVDResult)
100 | elif np.__version__ >= "1.25":
101 |
102 | class EigResultVSpace(NamedTupleVSpace):
103 | seq_type = np.linalg.linalg.EigResult
104 |
105 | class EighResultVSpace(NamedTupleVSpace):
106 | seq_type = np.linalg.linalg.EighResult
107 |
108 | class QRResultVSpace(NamedTupleVSpace):
109 | seq_type = np.linalg.linalg.QRResult
110 |
111 | class SlogdetResultVSpace(NamedTupleVSpace):
112 | seq_type = np.linalg.linalg.SlogdetResult
113 |
114 | class SVDResultVSpace(NamedTupleVSpace):
115 | seq_type = np.linalg.linalg.SVDResult
116 |
117 | EigResultVSpace.register(np.linalg.linalg.EigResult)
118 | EighResultVSpace.register(np.linalg.linalg.EighResult)
119 | QRResultVSpace.register(np.linalg.linalg.QRResult)
120 | SlogdetResultVSpace.register(np.linalg.linalg.SlogdetResult)
121 | SVDResultVSpace.register(np.linalg.linalg.SVDResult)
122 |
--------------------------------------------------------------------------------
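For reference, a quick sketch of how these vector spaces look when queried through `vspace`; note that a complex array counts real and imaginary components separately in its `size` and standard basis:

```python
import numpy as np
from autograd.extend import vspace

vs = vspace(np.zeros((2, 3)))
print(vs.shape, vs.size)      # (2, 3) 6
print(vs.zeros().shape)       # (2, 3)

cvs = vspace(np.zeros(2, dtype=complex))
print(cvs.size)                         # 4
print(len(list(cvs.standard_basis())))  # 4: one real and one imaginary direction per entry
```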
/autograd/numpy/random.py:
--------------------------------------------------------------------------------
1 | import numpy.random as npr
2 |
3 | from .numpy_wrapper import wrap_namespace
4 |
5 | wrap_namespace(npr.__dict__, globals())
6 |
--------------------------------------------------------------------------------
/autograd/scipy/__init__.py:
--------------------------------------------------------------------------------
1 | from . import integrate, signal, special, stats
2 |
--------------------------------------------------------------------------------
/autograd/scipy/integrate.py:
--------------------------------------------------------------------------------
1 | import scipy.integrate
2 |
3 | import autograd.numpy as np
4 | from autograd import make_vjp
5 | from autograd.builtins import tuple
6 | from autograd.extend import defvjp_argnums, primitive
7 | from autograd.misc import flatten
8 |
9 | odeint = primitive(scipy.integrate.odeint)
10 |
11 |
12 | def grad_odeint(yt, func, y0, t, func_args, **kwargs):
13 | # Extended from "Scalable Inference of Ordinary Differential
14 | # Equation Models of Biochemical Processes", Sec. 2.4.2
15 | # Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017
16 | # https://arxiv.org/abs/1711.08079
17 |
18 | T, D = np.shape(yt)
19 | flat_args, unflatten = flatten(func_args)
20 |
21 | def flat_func(y, t, flat_args):
22 | return func(y, t, *unflatten(flat_args))
23 |
24 | def unpack(x):
25 | # y, vjp_y, vjp_t, vjp_args
26 | return x[0:D], x[D : 2 * D], x[2 * D], x[2 * D + 1 :]
27 |
28 | def augmented_dynamics(augmented_state, t, flat_args):
29 | # Original system augmented with vjp_y, vjp_t and vjp_args.
30 | y, vjp_y, _, _ = unpack(augmented_state)
31 | vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
32 | vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
33 | return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))
34 |
35 | def vjp_all(g):
36 | vjp_y = g[-1, :]
37 | vjp_t0 = 0
38 | time_vjp_list = []
39 | vjp_args = np.zeros(np.size(flat_args))
40 |
41 | for i in range(T - 1, 0, -1):
42 | # Compute effect of moving measurement time.
43 | vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
44 | time_vjp_list.append(vjp_cur_t)
45 | vjp_t0 = vjp_t0 - vjp_cur_t
46 |
47 | # Run augmented system backwards to the previous observation.
48 | aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
49 | aug_ans = odeint(
50 | augmented_dynamics, aug_y0, np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs
51 | )
52 | _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])
53 |
54 | # Add gradient from current output.
55 | vjp_y = vjp_y + g[i - 1, :]
56 |
57 | time_vjp_list.append(vjp_t0)
58 | vjp_times = np.hstack(time_vjp_list)[::-1]
59 |
60 | return None, vjp_y, vjp_times, unflatten(vjp_args)
61 |
62 | return vjp_all
63 |
64 |
65 | def argnums_unpack(all_vjp_builder):
66 | # A generic autograd helper function. Takes a function that
67 | # builds vjps for all arguments, and wraps it to return only required vjps.
68 | def build_selected_vjps(argnums, ans, combined_args, kwargs):
69 | vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)
70 |
71 | def chosen_vjps(g): # Returns whichever vjps were asked for.
72 | all_vjps = vjp_func(g)
73 | return [all_vjps[argnum] for argnum in argnums]
74 |
75 | return chosen_vjps
76 |
77 | return build_selected_vjps
78 |
79 |
80 | defvjp_argnums(odeint, argnums_unpack(grad_odeint))
81 |
--------------------------------------------------------------------------------
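A sketch of differentiating through the wrapped `odeint`; the dynamics and parameter below are hypothetical, and extra ODE parameters are passed inside an `autograd.builtins.tuple`, as in `examples/ode_net.py`. For dy/dt = -a*y with y(0) = 1, the final value is exp(-a*T), so its derivative with respect to `a` should be about -T*exp(-a*T):

```python
import autograd.numpy as np
from autograd import grad
from autograd.builtins import tuple as ag_tuple
from autograd.scipy.integrate import odeint


def dynamics(y, t, a):
    return -a * y


def final_value(a):
    t = np.linspace(0.0, 1.0, 20)
    yt = odeint(dynamics, np.array([1.0]), t, ag_tuple((a,)))
    return yt[-1, 0]


print(final_value(0.5))        # ~exp(-0.5) ~ 0.607
print(grad(final_value)(0.5))  # ~-exp(-0.5) ~ -0.607
```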
/autograd/scipy/linalg.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import scipy.linalg
4 |
5 | import autograd.numpy as anp
6 | from autograd.extend import defjvp, defjvp_argnums, defvjp, defvjp_argnums
7 | from autograd.numpy.numpy_wrapper import wrap_namespace
8 |
9 | wrap_namespace(scipy.linalg.__dict__, globals()) # populates module namespace
10 |
11 |
12 | def _vjp_sqrtm(ans, A, disp=True, blocksize=64):
13 | assert disp, "sqrtm vjp not implemented for disp=False"
14 | ans_transp = anp.transpose(ans)
15 |
16 | def vjp(g):
17 | return anp.real(solve_sylvester(ans_transp, ans_transp, g))
18 |
19 | return vjp
20 |
21 |
22 | defvjp(sqrtm, _vjp_sqrtm)
23 |
24 |
25 | def _flip(a, trans):
26 | if anp.iscomplexobj(a):
27 | return "H" if trans in ("N", 0) else "N"
28 | else:
29 | return "T" if trans in ("N", 0) else "N"
30 |
31 |
32 | def grad_solve_triangular(ans, a, b, trans=0, lower=False, **kwargs):
33 | tri = anp.tril if (lower ^ (_flip(a, trans) == "N")) else anp.triu
34 | transpose = lambda x: x if _flip(a, trans) != "N" else x.T
35 | al2d = lambda x: x if x.ndim > 1 else x[..., None]
36 |
37 | def vjp(g):
38 | v = al2d(solve_triangular(a, g, trans=_flip(a, trans), lower=lower))
39 | return -transpose(tri(anp.dot(v, al2d(ans).T)))
40 |
41 | return vjp
42 |
43 |
44 | defvjp(
45 | solve_triangular,
46 | grad_solve_triangular,
47 | lambda ans, a, b, trans=0, lower=False, **kwargs: lambda g: solve_triangular(
48 | a, g, trans=_flip(a, trans), lower=lower
49 | ),
50 | )
51 |
52 |
53 | def grad_solve_banded(argnum, ans, l_and_u, a, b):
54 | updim = lambda x: x if x.ndim == a.ndim else x[..., None]
55 |
56 | def transpose_banded(l_and_u, a):
57 | # Compute the transpose of a banded matrix.
58 | # The transpose is itself a banded matrix.
59 |
60 | num_rows = a.shape[0]
61 |
62 | shifts = anp.arange(-l_and_u[1], l_and_u[0] + 1)
63 |
64 | T_a = anp.roll(a[:1, :], shifts[0])
65 | for rr in range(1, num_rows):
66 | T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr : rr + 1, :], shifts[rr]))])
67 | T_a = anp.flipud(T_a)
68 |
69 | T_l_and_u = anp.flip(l_and_u)
70 |
71 | return T_l_and_u, T_a
72 |
73 | def banded_dot(l_and_u, uu, vv):
74 | # Compute tensor product of vectors uu and vv.
75 | # Tensor product elements are restricted to the bands specified by l_and_u.
76 |
77 | # TODO: replace the brute-force ravel() by smarter dimension handling of uu and vv
78 |
79 | # main diagonal
80 | banded_uv = anp.ravel(uu) * anp.ravel(vv)
81 |
82 | # stack the sub-diagonal bands below the main diagonal
83 | for rr in range(1, l_and_u[0] + 1):
84 | banded_uv_rr = anp.hstack([anp.ravel(uu)[rr:] * anp.ravel(vv)[:-rr], anp.zeros(rr)])
85 | banded_uv = anp.vstack([banded_uv, banded_uv_rr])
86 |
87 | # stack the super-diagonal bands above the main diagonal
88 | for rr in range(1, l_and_u[1] + 1):
89 | banded_uv_rr = anp.hstack([anp.zeros(rr), anp.ravel(uu)[:-rr] * anp.ravel(vv)[rr:]])
90 | banded_uv = anp.vstack([banded_uv_rr, banded_uv])
91 |
92 | return banded_uv
93 |
94 | T_l_and_u, T_a = transpose_banded(l_and_u, a)
95 |
96 | if argnum == 1:
97 | return lambda g: -banded_dot(
98 | l_and_u, updim(solve_banded(T_l_and_u, T_a, g)), anp.transpose(updim(ans))
99 | )
100 | elif argnum == 2:
101 | return lambda g: solve_banded(T_l_and_u, T_a, g)
102 |
103 |
104 | defvjp(solve_banded, partial(grad_solve_banded, 1), partial(grad_solve_banded, 2), argnums=[1, 2])
105 |
106 |
107 | def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
108 | assert disp, "sqrtm jvp not implemented for disp=False"
109 | return solve_sylvester(ans, ans, dA)
110 |
111 |
112 | defjvp(sqrtm, _jvp_sqrtm)
113 |
114 |
115 | def _jvp_sylvester(argnums, dms, ans, args, _):
116 | a, b, q = args
117 | if 0 in argnums:
118 | da = dms[0]
119 | db = dms[1] if 1 in argnums else 0
120 | else:
121 | da = 0
122 | db = dms[0] if 1 in argnums else 0
123 | dq = dms[-1] if 2 in argnums else 0
124 | rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
125 | return solve_sylvester(a, b, rhs)
126 |
127 |
128 | defjvp_argnums(solve_sylvester, _jvp_sylvester)
129 |
130 |
131 | def _vjp_sylvester(argnums, ans, args, _):
132 | a, b, q = args
133 |
134 | def vjp(g):
135 | vjps = []
136 | q_vjp = solve_sylvester(anp.transpose(a), anp.transpose(b), g)
137 | if 0 in argnums:
138 | vjps.append(-anp.dot(q_vjp, anp.transpose(ans)))
139 | if 1 in argnums:
140 | vjps.append(-anp.dot(anp.transpose(ans), q_vjp))
141 | if 2 in argnums:
142 | vjps.append(q_vjp)
143 | return tuple(vjps)
144 |
145 | return vjp
146 |
147 |
148 | defvjp_argnums(solve_sylvester, _vjp_sylvester)
149 |
--------------------------------------------------------------------------------
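As an illustration of the `sqrtm` adjoint defined above, a sketch that differentiates through the matrix square root of a small symmetric positive-definite matrix (the parameterization is hypothetical):

```python
import autograd.numpy as np
import autograd.scipy.linalg as spla
from autograd import grad


def f(a):
    # Symmetric positive-definite matrix built from a scalar parameter.
    A = np.array([[a, 0.1], [0.1, 1.0]])
    return np.sum(spla.sqrtm(A))


print(grad(f)(2.0))  # derivative of sum(sqrtm(A)) with respect to a
```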
/autograd/scipy/stats/__init__.py:
--------------------------------------------------------------------------------
1 | from . import beta, chi2, gamma, norm, poisson, t
2 |
3 | # Try block needed in case the user has an
4 | # old version of scipy without multivariate normal.
5 | try:
6 | from . import multivariate_normal
7 | except AttributeError:
8 | pass
9 |
10 | try:
11 | from . import dirichlet
12 | except AttributeError:
13 | pass
14 |
--------------------------------------------------------------------------------
/autograd/scipy/stats/beta.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 |
3 | import autograd.numpy as np
4 | from autograd.extend import defvjp, primitive
5 | from autograd.numpy.numpy_vjps import unbroadcast_f
6 | from autograd.scipy.special import beta, psi
7 |
8 | cdf = primitive(scipy.stats.beta.cdf)
9 | logpdf = primitive(scipy.stats.beta.logpdf)
10 | pdf = primitive(scipy.stats.beta.pdf)
11 |
12 |
13 | def grad_beta_logpdf_arg0(x, a, b):
14 | return (1 + a * (x - 1) + x * (b - 2)) / (x * (x - 1))
15 |
16 |
17 | def grad_beta_logpdf_arg1(x, a, b):
18 | return np.log(x) - psi(a) + psi(a + b)
19 |
20 |
21 | def grad_beta_logpdf_arg2(x, a, b):
22 | return np.log1p(-x) - psi(b) + psi(a + b)
23 |
24 |
25 | defvjp(
26 | cdf,
27 | lambda ans, x, a, b: unbroadcast_f(
28 | x, lambda g: g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b)
29 | ),
30 | argnums=[0],
31 | )
32 | defvjp(
33 | logpdf,
34 | lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * grad_beta_logpdf_arg0(x, a, b)),
35 | lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * grad_beta_logpdf_arg1(x, a, b)),
36 | lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * grad_beta_logpdf_arg2(x, a, b)),
37 | )
38 | defvjp(
39 | pdf,
40 | lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * ans * grad_beta_logpdf_arg0(x, a, b)),
41 | lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * ans * grad_beta_logpdf_arg1(x, a, b)),
42 | lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * ans * grad_beta_logpdf_arg2(x, a, b)),
43 | )
44 |
--------------------------------------------------------------------------------
/autograd/scipy/stats/chi2.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 |
3 | import autograd.numpy as np
4 | from autograd.extend import defvjp, primitive
5 | from autograd.numpy.numpy_vjps import unbroadcast_f
6 | from autograd.scipy.special import gamma
7 |
8 | cdf = primitive(scipy.stats.chi2.cdf)
9 | logpdf = primitive(scipy.stats.chi2.logpdf)
10 | pdf = primitive(scipy.stats.chi2.pdf)
11 |
12 |
13 | def grad_chi2_logpdf(x, df):
14 | return np.where(df % 1 == 0, (df - x - 2) / (2 * x), 0)
15 |
16 |
17 | defvjp(
18 | cdf,
19 | lambda ans, x, df: unbroadcast_f(
20 | x, lambda g: g * np.power(2.0, -df / 2) * np.exp(-x / 2) * np.power(x, df / 2 - 1) / gamma(df / 2)
21 | ),
22 | argnums=[0],
23 | )
24 | defvjp(logpdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * grad_chi2_logpdf(x, df)), argnums=[0])
25 | defvjp(pdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * ans * grad_chi2_logpdf(x, df)), argnums=[0])
26 |
--------------------------------------------------------------------------------
/autograd/scipy/stats/dirichlet.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 |
3 | import autograd.numpy as np
4 | from autograd.extend import defvjp, primitive
5 | from autograd.scipy.special import digamma
6 |
7 | rvs = primitive(scipy.stats.dirichlet.rvs)
8 | pdf = primitive(scipy.stats.dirichlet.pdf)
9 | logpdf = primitive(scipy.stats.dirichlet.logpdf)
10 |
11 | defvjp(
12 | logpdf,
13 | lambda ans, x, alpha: lambda g: g * (alpha - 1) / x,
14 | lambda ans, x, alpha: lambda g: g * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)),
15 | )
16 |
17 | # Same as log pdf, but multiplied by the pdf (ans).
18 | defvjp(
19 | pdf,
20 | lambda ans, x, alpha: lambda g: g * ans * (alpha - 1) / x,
21 | lambda ans, x, alpha: lambda g: g * ans * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)),
22 | )
23 |
--------------------------------------------------------------------------------
/autograd/scipy/stats/gamma.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 |
3 | import autograd.numpy as np
4 | from autograd.extend import defvjp, primitive
5 | from autograd.numpy.numpy_vjps import unbroadcast_f
6 | from autograd.scipy.special import gamma, psi
7 |
8 | cdf = primitive(scipy.stats.gamma.cdf)
9 | logpdf = primitive(scipy.stats.gamma.logpdf)
10 | pdf = primitive(scipy.stats.gamma.pdf)
11 |
12 |
13 | def grad_gamma_logpdf_arg0(x, a):
14 | return (a - x - 1) / x
15 |
16 |
17 | def grad_gamma_logpdf_arg1(x, a):
18 | return np.log(x) - psi(a)
19 |
20 |
21 | defvjp(
22 | cdf,
23 | lambda ans, x, a: unbroadcast_f(x, lambda g: g * np.exp(-x) * np.power(x, a - 1) / gamma(a)),
24 | argnums=[0],
25 | )
26 | defvjp(
27 | logpdf,
28 | lambda ans, x, a: unbroadcast_f(x, lambda g: g * grad_gamma_logpdf_arg0(x, a)),
29 | lambda ans, x, a: unbroadcast_f(a, lambda g: g * grad_gamma_logpdf_arg1(x, a)),
30 | )
31 | defvjp(
32 | pdf,
33 | lambda ans, x, a: unbroadcast_f(x, lambda g: g * ans * grad_gamma_logpdf_arg0(x, a)),
34 | lambda ans, x, a: unbroadcast_f(a, lambda g: g * ans * grad_gamma_logpdf_arg1(x, a)),
35 | )
36 |
--------------------------------------------------------------------------------
/autograd/scipy/stats/multivariate_normal.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 |
3 | import autograd.numpy as np
4 | from autograd.extend import defvjp, primitive
5 | from autograd.numpy.numpy_vjps import unbroadcast_f
6 |
7 | pdf = primitive(scipy.stats.multivariate_normal.pdf)
8 | logpdf = primitive(scipy.stats.multivariate_normal.logpdf)
9 | entropy = primitive(scipy.stats.multivariate_normal.entropy)
10 |
11 | # With thanks to Eric Bresch.
12 | # Some formulas are from
13 | # "An extended collection of matrix derivative results
14 | # for forward and reverse mode algorithmic differentiation"
15 | # by Mike Giles
16 | # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
17 |
18 |
19 | def generalized_outer_product(x):
20 | if np.ndim(x) == 1:
21 | return np.outer(x, x)
22 | return np.matmul(x, np.swapaxes(x, -1, -2))
23 |
24 |
25 | def covgrad(x, mean, cov, allow_singular=False):
26 | if allow_singular:
27 | raise NotImplementedError(
28 | "The multivariate normal pdf is not differentiable w.r.t. a singular covariance matix"
29 | )
30 | J = np.linalg.inv(cov)
31 | solved = np.matmul(J, np.expand_dims(x - mean, -1))
32 | return 1.0 / 2 * (generalized_outer_product(solved) - J)
33 |
34 |
35 | def solve(allow_singular):
36 | if allow_singular:
37 | return lambda A, x: np.dot(np.linalg.pinv(A), x)
38 | else:
39 | return np.linalg.solve
40 |
41 |
42 | defvjp(
43 | logpdf,
44 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
45 | x, lambda g: -np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T
46 | ),
47 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
48 | mean, lambda g: np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T
49 | ),
50 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
51 | cov, lambda g: np.reshape(g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular)
52 | ),
53 | )
54 |
55 | # Same as log pdf, but multiplied by the pdf (ans).
56 | defvjp(
57 | pdf,
58 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
59 | x, lambda g: -np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T
60 | ),
61 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
62 | mean,
63 | lambda g: np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T,
64 | ),
65 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f(
66 | cov, lambda g: np.reshape(ans * g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular)
67 | ),
68 | )
69 |
70 | defvjp(entropy, None, lambda ans, mean, cov: unbroadcast_f(cov, lambda g: 0.5 * g * np.linalg.inv(cov).T))
71 |
--------------------------------------------------------------------------------
/autograd/scipy/stats/norm.py:
--------------------------------------------------------------------------------
1 | """Gradients of the normal distribution."""
2 |
3 | import scipy.stats
4 |
5 | import autograd.numpy as anp
6 | from autograd.extend import defvjp, primitive
7 | from autograd.numpy.numpy_vjps import unbroadcast_f
8 |
9 | pdf = primitive(scipy.stats.norm.pdf)
10 | cdf = primitive(scipy.stats.norm.cdf)
11 | sf = primitive(scipy.stats.norm.sf)
12 | logpdf = primitive(scipy.stats.norm.logpdf)
13 | logcdf = primitive(scipy.stats.norm.logcdf)
14 | logsf = primitive(scipy.stats.norm.logsf)
15 |
16 | defvjp(
17 | pdf,
18 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * ans * (x - loc) / scale**2),
19 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * ans * (x - loc) / scale**2),
20 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
21 | scale, lambda g: g * ans * (((x - loc) / scale) ** 2 - 1.0) / scale
22 | ),
23 | )
24 |
25 | defvjp(
26 | cdf,
27 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * pdf(x, loc, scale)),
28 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: -g * pdf(x, loc, scale)),
29 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
30 | scale, lambda g: -g * pdf(x, loc, scale) * (x - loc) / scale
31 | ),
32 | )
33 |
34 | defvjp(
35 | logpdf,
36 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * (x - loc) / scale**2),
37 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * (x - loc) / scale**2),
38 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
39 | scale, lambda g: g * (-1.0 / scale + (x - loc) ** 2 / scale**3)
40 | ),
41 | )
42 |
43 | defvjp(
44 | logcdf,
45 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
46 | x, lambda g: g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))
47 | ),
48 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
49 | loc, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale))
50 | ),
51 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
52 | scale, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)) * (x - loc) / scale
53 | ),
54 | )
55 |
56 | defvjp(
57 | logsf,
58 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
59 | x, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))
60 | ),
61 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
62 | loc, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale))
63 | ),
64 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
65 | scale, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale)) * (x - loc) / scale
66 | ),
67 | )
68 |
69 | defvjp(
70 | sf,
71 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * pdf(x, loc, scale)),
72 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * pdf(x, loc, scale)),
73 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(
74 | scale, lambda g: g * pdf(x, loc, scale) * (x - loc) / scale
75 | ),
76 | )
77 |
--------------------------------------------------------------------------------
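A short sketch of these gradients in use: a Gaussian negative log-likelihood differentiated with respect to its location and (log) scale. The data and parameterization are hypothetical:

```python
import autograd.numpy as np
from autograd import grad
from autograd.scipy.stats import norm


def neg_log_likelihood(params, data):
    loc, log_scale = params[0], params[1]
    return -np.sum(norm.logpdf(data, loc, np.exp(log_scale)))


data = np.array([0.5, -1.2, 0.3, 2.0])
print(grad(neg_log_likelihood)(np.array([0.0, 0.0]), data))  # d/d(loc), d/d(log_scale)
```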
/autograd/scipy/stats/poisson.py:
--------------------------------------------------------------------------------
1 | import scipy.stats
2 |
3 | import autograd.numpy as np
4 | from autograd.extend import defvjp, primitive
5 | from autograd.numpy.numpy_vjps import unbroadcast_f
6 |
7 | cdf = primitive(scipy.stats.poisson.cdf)
8 | logpmf = primitive(scipy.stats.poisson.logpmf)
9 | pmf = primitive(scipy.stats.poisson.pmf)
10 |
11 |
12 | def grad_poisson_logpmf(k, mu):
13 | return np.where(k % 1 == 0, k / mu - 1, 0)
14 |
15 |
16 | defvjp(cdf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * -pmf(np.floor(k), mu)), argnums=[1])
17 | defvjp(logpmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * grad_poisson_logpmf(k, mu)), argnums=[1])
18 | defvjp(
19 | pmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * ans * grad_poisson_logpmf(k, mu)), argnums=[1]
20 | )
21 |
--------------------------------------------------------------------------------
/autograd/scipy/stats/t.py:
--------------------------------------------------------------------------------
1 | """Gradients of the univariate t distribution."""
2 |
3 | import scipy.stats
4 |
5 | import autograd.numpy as np
6 | from autograd.extend import defvjp, primitive
7 | from autograd.numpy.numpy_vjps import unbroadcast_f
8 | from autograd.scipy.special import psi
9 |
10 | pdf = primitive(scipy.stats.t.pdf)
11 | cdf = primitive(scipy.stats.t.cdf)
12 | logpdf = primitive(scipy.stats.t.logpdf)
13 | logcdf = primitive(scipy.stats.t.logcdf)
14 |
15 |
16 | def grad_tlogpdf_diff(diff, df):
17 | return -diff * (1.0 + df) / (diff**2 + df)
18 |
19 |
20 | def grad_tlogpdf_x(x, df, loc, scale):
21 | return grad_tlogpdf_diff((x - loc) / scale, df) / scale
22 |
23 |
24 | def grad_tlogpdf_loc(x, df, loc, scale):
25 | return -grad_tlogpdf_diff((x - loc) / scale, df) / scale
26 |
27 |
28 | def grad_tlogpdf_scale(x, df, loc, scale):
29 | diff = x - loc
30 | return -(df * (scale**2 - diff**2)) / (scale * (df * scale**2 + diff**2))
31 |
32 |
33 | def grad_tlogpdf_df(x, df, loc, scale):
34 | y = (x - loc) / scale
35 | return 0.5 * (
36 | (y**2 * (df + 1)) / (df * (y**2 + df))
37 | - np.log(y**2 / df + 1)
38 | - 1.0 / df
39 | - psi(df / 2.0)
40 | + psi((df + 1) / 2.0)
41 | )
42 |
43 |
44 | defvjp(
45 | pdf,
46 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
47 | x, lambda g: g * ans * grad_tlogpdf_x(x, df, loc, scale)
48 | ),
49 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
50 | df, lambda g: g * ans * grad_tlogpdf_df(x, df, loc, scale)
51 | ),
52 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
53 | loc, lambda g: g * ans * grad_tlogpdf_loc(x, df, loc, scale)
54 | ),
55 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
56 | scale, lambda g: g * ans * grad_tlogpdf_scale(x, df, loc, scale)
57 | ),
58 | )
59 |
60 | defvjp(
61 | cdf,
62 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * pdf(x, df, loc, scale)),
63 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: -g * pdf(x, df, loc, scale)),
64 | argnums=(0, 2),
65 | )
66 |
67 | defvjp(
68 | logpdf,
69 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * grad_tlogpdf_x(x, df, loc, scale)),
70 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
71 | df, lambda g: g * grad_tlogpdf_df(x, df, loc, scale)
72 | ),
73 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
74 | loc, lambda g: g * grad_tlogpdf_loc(x, df, loc, scale)
75 | ),
76 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
77 | scale, lambda g: g * grad_tlogpdf_scale(x, df, loc, scale)
78 | ),
79 | )
80 |
81 | defvjp(
82 | logcdf,
83 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
84 | x, lambda g: g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))
85 | ),
86 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(
87 | loc, lambda g: -g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale))
88 | ),
89 | argnums=(0, 2),
90 | )
91 |
--------------------------------------------------------------------------------
/autograd/test_util.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | from .core import make_jvp, make_vjp, vspace
4 | from .wrap_util import get_name, unary_to_nary
5 |
6 | TOL = 1e-6
7 | RTOL = 1e-6
8 |
9 |
10 | def scalar_close(a, b):
11 | return abs(a - b) < TOL or abs(a - b) / abs(a + b) < RTOL
12 |
13 |
14 | EPS = 1e-6
15 |
16 |
17 | def make_numerical_jvp(f, x):
18 | y = f(x)
19 | x_vs, y_vs = vspace(x), vspace(y)
20 |
21 | def jvp(v):
22 | # (f(x + v*eps/2) - f(x - v*eps/2)) / eps
23 | f_x_plus = f(x_vs.add(x, x_vs.scalar_mul(v, EPS / 2)))
24 | f_x_minus = f(x_vs.add(x, x_vs.scalar_mul(v, -EPS / 2)))
25 | neg_f_x_minus = y_vs.scalar_mul(f_x_minus, -1.0)
26 | return y_vs.scalar_mul(y_vs.add(f_x_plus, neg_f_x_minus), 1.0 / EPS)
27 |
28 | return jvp
29 |
30 |
31 | def check_vjp(f, x):
32 | vjp, y = make_vjp(f, x)
33 | jvp = make_numerical_jvp(f, x)
34 | x_vs, y_vs = vspace(x), vspace(y)
35 | x_v, y_v = x_vs.randn(), y_vs.randn()
36 |
37 | vjp_y = x_vs.covector(vjp(y_vs.covector(y_v)))
38 | assert vspace(vjp_y) == x_vs
39 | vjv_exact = x_vs.inner_prod(x_v, vjp_y)
40 | vjv_numeric = y_vs.inner_prod(y_v, jvp(x_v))
41 | assert scalar_close(vjv_numeric, vjv_exact), (
42 | "Derivative (VJP) check of {} failed with arg {}:\nanalytic: {}\nnumeric: {}".format(
43 | get_name(f), x, vjv_exact, vjv_numeric
44 | )
45 | )
46 |
47 |
48 | def check_jvp(f, x):
49 | jvp = make_jvp(f, x)
50 | jvp_numeric = make_numerical_jvp(f, x)
51 | x_v = vspace(x).randn()
52 | check_equivalent(jvp(x_v)[1], jvp_numeric(x_v))
53 |
54 |
55 | def check_equivalent(x, y):
56 | x_vs, y_vs = vspace(x), vspace(y)
57 | assert x_vs == y_vs, f"VSpace mismatch:\nx: {x_vs}\ny: {y_vs}"
58 | v = x_vs.randn()
59 | assert scalar_close(x_vs.inner_prod(x, v), x_vs.inner_prod(y, v)), f"Value mismatch:\nx: {x}\ny: {y}"
60 |
61 |
62 | @unary_to_nary
63 | def check_grads(f, x, modes=["fwd", "rev"], order=2):
64 | assert all(m in ["fwd", "rev"] for m in modes)
65 | if "fwd" in modes:
66 | check_jvp(f, x)
67 | if order > 1:
68 | grad_f = lambda x, v: make_jvp(f, x)(v)[1]
69 | grad_f.__name__ = f"jvp_{get_name(f)}"
70 | v = vspace(x).randn()
71 | check_grads(grad_f, (0, 1), modes, order=order - 1)(x, v)
72 | if "rev" in modes:
73 | check_vjp(f, x)
74 | if order > 1:
75 | grad_f = lambda x, v: make_vjp(f, x)[0](v)
76 | grad_f.__name__ = f"vjp_{get_name(f)}"
77 | v = vspace(f(x)).randn()
78 | check_grads(grad_f, (0, 1), modes, order=order - 1)(x, v)
79 |
80 |
81 | def combo_check(fun, *args, **kwargs):
82 | # Tests all combinations of args and kwargs given.
83 | _check_grads = lambda f: check_grads(f, *args, **kwargs)
84 |
85 | def _combo_check(*args, **kwargs):
86 | kwarg_key_vals = [[(k, x) for x in xs] for k, xs in kwargs.items()]
87 | for _args in product(*args):
88 | for _kwargs in product(*kwarg_key_vals):
89 | _check_grads(fun)(*_args, **dict(_kwargs))
90 |
91 | return _combo_check
92 |
--------------------------------------------------------------------------------
/autograd/tracer.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import defaultdict
3 | from contextlib import contextmanager
4 |
5 | from .util import subvals, toposort
6 | from .wrap_util import wraps
7 |
8 |
9 | def trace(start_node, fun, x):
10 | with trace_stack.new_trace() as t:
11 | start_box = new_box(x, t, start_node)
12 | end_box = fun(start_box)
13 | if isbox(end_box) and end_box._trace == start_box._trace:
14 | return end_box._value, end_box._node
15 | else:
16 | warnings.warn("Output seems independent of input.")
17 | return end_box, None
18 |
19 |
20 | class Node:
21 | __slots__ = []
22 |
23 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
24 | assert False
25 |
26 | def initialize_root(self, *args, **kwargs):
27 | assert False
28 |
29 | @classmethod
30 | def new_root(cls, *args, **kwargs):
31 | root = cls.__new__(cls)
32 | root.initialize_root(*args, **kwargs)
33 | return root
34 |
35 |
36 | def primitive(f_raw):
37 | """
38 | Wraps a function so that its gradient can be specified and its invocation
39 | can be recorded. For examples, see the docs."""
40 |
41 | @wraps(f_raw)
42 | def f_wrapped(*args, **kwargs):
43 | boxed_args, trace, node_constructor = find_top_boxed_args(args)
44 | if boxed_args:
45 | argvals = subvals(args, [(argnum, box._value) for argnum, box in boxed_args])
46 | if f_wrapped in notrace_primitives[node_constructor]:
47 | return f_wrapped(*argvals, **kwargs)
48 | parents = tuple(box._node for _, box in boxed_args)
49 | argnums = tuple(argnum for argnum, _ in boxed_args)
50 | ans = f_wrapped(*argvals, **kwargs)
51 | node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents)
52 | return new_box(ans, trace, node)
53 | else:
54 | return f_raw(*args, **kwargs)
55 |
56 | f_wrapped.fun = f_raw
57 | f_wrapped._is_autograd_primitive = True
58 | return f_wrapped
59 |
60 |
61 | notrace_primitives = defaultdict(set)
62 |
63 |
64 | def register_notrace(trace_type, primitive_fun):
65 | notrace_primitives[trace_type].add(primitive_fun)
66 |
67 |
68 | def notrace_primitive(f_raw):
69 | @wraps(f_raw)
70 | def f_wrapped(*args, **kwargs):
71 | argvals = map(getval, args)
72 | return f_raw(*argvals, **kwargs)
73 |
74 | f_wrapped._is_primitive = True
75 | return f_wrapped
76 |
77 |
78 | def find_top_boxed_args(args):
79 | top_trace = -1
80 | top_boxes = []
81 | top_node_type = None
82 | for argnum, arg in enumerate(args):
83 | if isbox(arg):
84 | trace = arg._trace
85 | if trace > top_trace:
86 | top_boxes = [(argnum, arg)]
87 | top_trace = trace
88 | top_node_type = type(arg._node)
89 | elif trace == top_trace:
90 | top_boxes.append((argnum, arg))
91 | return top_boxes, top_trace, top_node_type
92 |
93 |
94 | class TraceStack:
95 | def __init__(self):
96 | self.top = -1
97 |
98 | @contextmanager
99 | def new_trace(self):
100 | self.top += 1
101 | yield self.top
102 | self.top -= 1
103 |
104 |
105 | trace_stack = TraceStack()
106 |
107 |
108 | class Box:
109 | type_mappings = {}
110 | types = set()
111 |
112 | __slots__ = ["_value", "_trace", "_node"]
113 |
114 | def __init__(self, value, trace, node):
115 | self._value = value
116 | self._node = node
117 | self._trace = trace
118 |
119 | def __bool__(self):
120 | return bool(self._value)
121 |
122 | __nonzero__ = __bool__
123 |
124 | def __str__(self):
125 | return f"Autograd {type(self).__name__} with value {str(self._value)}"
126 |
127 | @classmethod
128 | def register(cls, value_type):
129 | Box.types.add(cls)
130 | Box.type_mappings[value_type] = cls
131 | Box.type_mappings[cls] = cls
132 |
133 |
134 | box_type_mappings = Box.type_mappings
135 |
136 |
137 | def new_box(value, trace, node):
138 | try:
139 | return box_type_mappings[type(value)](value, trace, node)
140 | except KeyError:
141 | raise TypeError(f"Can't differentiate w.r.t. type {type(value)}")
142 |
143 |
144 | box_types = Box.types
145 | isbox = lambda x: type(x) in box_types # almost 3X faster than isinstance(x, Box)
146 | getval = lambda x: getval(x._value) if isbox(x) else x
147 |
--------------------------------------------------------------------------------
/autograd/util.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 |
4 | def subvals(x, ivs):
5 | x_ = list(x)
6 | for i, v in ivs:
7 | x_[i] = v
8 | return tuple(x_)
9 |
10 |
11 | def subval(x, i, v):
12 | x_ = list(x)
13 | x_[i] = v
14 | return tuple(x_)
15 |
16 |
17 | def func(f):
18 | return f
19 |
20 |
21 | def toposort(end_node, parents=operator.attrgetter("parents")):
22 | child_counts = {}
23 | stack = [end_node]
24 | while stack:
25 | node = stack.pop()
26 | if node in child_counts:
27 | child_counts[node] += 1
28 | else:
29 | child_counts[node] = 1
30 | stack.extend(parents(node))
31 |
32 | childless_nodes = [end_node]
33 | while childless_nodes:
34 | node = childless_nodes.pop()
35 | yield node
36 | for parent in parents(node):
37 | if child_counts[parent] == 1:
38 | childless_nodes.append(parent)
39 | else:
40 | child_counts[parent] -= 1
41 |
42 |
43 | # -------------------- deprecation warnings -----------------------
44 |
45 | import warnings
46 |
47 | deprecation_msg = """
48 | The quick_grad_check function is deprecated. See the update guide:
49 | https://github.com/HIPS/autograd/blob/master/docs/updateguide.md"""
50 |
51 |
52 | def quick_grad_check(
53 | fun, arg0, extra_args=(), kwargs={}, verbose=True, eps=1e-4, rtol=1e-4, atol=1e-6, rs=None
54 | ):
55 | warnings.warn(deprecation_msg)
56 | from autograd.test_util import check_grads
57 |
58 | fun_ = lambda arg0: fun(arg0, *extra_args, **kwargs)
59 | check_grads(fun_, modes=["rev"], order=1)(arg0)
60 |
--------------------------------------------------------------------------------
/autograd/wrap_util.py:
--------------------------------------------------------------------------------
1 | from .util import subvals
2 |
3 |
4 | def unary_to_nary(unary_operator):
5 | @wraps(unary_operator)
6 | def nary_operator(fun, argnum=0, *nary_op_args, **nary_op_kwargs):
7 | assert type(argnum) in (int, tuple, list), argnum
8 |
9 | @wrap_nary_f(fun, unary_operator, argnum)
10 | def nary_f(*args, **kwargs):
11 | @wraps(fun)
12 | def unary_f(x):
13 | if isinstance(argnum, int):
14 | subargs = subvals(args, [(argnum, x)])
15 | else:
16 | subargs = subvals(args, zip(argnum, x))
17 | return fun(*subargs, **kwargs)
18 |
19 | if isinstance(argnum, int):
20 | x = args[argnum]
21 | else:
22 | x = tuple(args[i] for i in argnum)
23 | return unary_operator(unary_f, x, *nary_op_args, **nary_op_kwargs)
24 |
25 | return nary_f
26 |
27 | return nary_operator
28 |
29 |
30 | def wraps(fun, namestr="{fun}", docstr="{doc}", **kwargs):
31 | def _wraps(f):
32 | try:
33 | f.__name__ = namestr.format(fun=get_name(fun), **kwargs)
34 | f.__doc__ = docstr.format(fun=get_name(fun), doc=get_doc(fun), **kwargs)
35 | except BaseException:
36 | pass
37 | finally:
38 | return f
39 |
40 | return _wraps
41 |
42 |
43 | def wrap_nary_f(fun, op, argnum):
44 | namestr = "{op}_of_{fun}_wrt_argnum_{argnum}"
45 | docstr = """\
46 | {op} of function {fun} with respect to argument number {argnum}. Takes the
47 | same arguments as {fun} but returns the {op}.
48 | """
49 | return wraps(fun, namestr, docstr, op=get_name(op), argnum=argnum)
50 |
51 |
52 | get_name = lambda f: getattr(f, "__name__", "[unknown name]")
53 | get_doc = lambda f: getattr(f, "__doc__", "")
54 |
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/benchmarks/__init__.py
--------------------------------------------------------------------------------
/benchmarks/asv.conf.json.sample:
--------------------------------------------------------------------------------
1 | {
2 | "version": 1,
3 | "project": "autograd",
4 | "project_url": "http://github.com/hips/autograd",
5 | "branches": ["master"],
6 | "dvcs": "git",
7 | "environment_type": "virtualenv",
8 | "install_timeout": 600,
9 | "repo" : "..",
10 | "benchmark_dir" : ".",
11 | "env_dir" : "../.asv/env",
12 | "results_dir" : "../.asv/results",
13 |     "html_dir" : "../.asv/html"
14 | }
15 |
--------------------------------------------------------------------------------
/benchmarks/bench_core.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 |
3 | import autograd.numpy as np
4 | from autograd import grad
5 |
6 | try:
7 | from autograd.core import VJPNode, backward_pass, vspace
8 | from autograd.tracer import new_box, trace
9 |
10 | MASTER_BRANCH = False
11 | except ImportError:
12 | from autograd.core import backward_pass, forward_pass, new_progenitor, vspace
13 |
14 | MASTER_BRANCH = True
15 |
16 |
17 | ## SHORT FUNCTION
18 | def f_short(x):
19 | return x**2
20 |
21 |
22 | def time_short_fun():
23 | f_short(2.0)
24 |
25 |
26 | def time_short_forward_pass():
27 | if MASTER_BRANCH:
28 | forward_pass(f_short, (2.0,), {})
29 | else:
30 | start_node = VJPNode.new_root()
31 | trace(start_node, f_short, x)
32 |
33 |
34 | def time_short_backward_pass():
35 | if MASTER_BRANCH:
36 | backward_pass(1.0, short_end_node, short_start_node)
37 | else:
38 | backward_pass(1.0, short_end_node)
39 |
40 |
41 | def time_short_grad():
42 | grad(f_short)(2.0)
43 |
44 |
45 | ## LONG FUNCTION
46 | def f_long(x):
47 | for i in range(50):
48 | x = np.sin(x)
49 | return x
50 |
51 |
52 | def time_long_fun():
53 | f_long(2.0)
54 |
55 |
56 | def time_long_forward_pass():
57 | if MASTER_BRANCH:
58 | forward_pass(f_long, (2.0,), {})
59 | else:
60 | start_node = VJPNode.new_root()
61 | trace(start_node, f_long, x)
62 |
63 |
64 | def time_long_backward_pass():
65 | if MASTER_BRANCH:
66 | backward_pass(1.0, long_end_node, long_start_node)
67 | else:
68 | backward_pass(1.0, long_end_node)
69 |
70 |
71 | def time_long_grad():
72 | grad(f_long)(2.0)
73 |
74 |
75 | ## 'PEARLMUTTER TEST' FUNCTION
76 | def fan_out_fan_in(x):
77 | for i in range(10**4):
78 | x = (x + x) / 2.0
79 | return np.sum(x)
80 |
81 |
82 | def time_fan_out_fan_in_fun():
83 | fan_out_fan_in(2.0)
84 |
85 |
86 | def time_fan_out_fan_in_forward_pass():
87 | if MASTER_BRANCH:
88 | forward_pass(fan_out_fan_in, (2.0,), {})
89 | else:
90 | start_node = VJPNode.new_root()
91 | trace(start_node, fan_out_fan_in, x)
92 |
93 |
94 | def time_fan_out_fan_in_backward_pass():
95 | if MASTER_BRANCH:
96 | backward_pass(1.0, fan_end_node, fan_start_node)
97 | else:
98 | backward_pass(1.0, fan_end_node)
99 |
100 |
101 | def time_fan_out_fan_in_grad():
102 | grad(fan_out_fan_in)(2.0)
103 |
104 |
105 | ## UNIT BENCHMARKS
106 | def time_vspace_float():
107 | vspace(1.0)
108 |
109 |
110 | A = np.array([[1.0, 2.0, 3.0]])
111 |
112 |
113 | def time_vspace_array():
114 | vspace(A)
115 |
116 |
117 | def time_new_box_float():
118 | new_box(1.0, 0, start_node)
119 |
120 |
121 | def time_new_box_array():
122 | new_box(A, 0, start_node)
123 |
124 |
125 | def time_exp_call():
126 | onp.exp(2.0)
127 |
128 |
129 | def time_exp_primitive_call_unboxed():
130 | np.exp(2.0)
131 |
132 |
133 | def time_exp_primitive_call_boxed():
134 | if MASTER_BRANCH:
135 | np.exp(progenitor)
136 | else:
137 | np.exp(start_box)
138 |
139 |
140 | def time_no_autograd_control():
141 |     # Test whether the benchmarking machine is running slowly, independently of autograd
142 | A = np.random.randn(200, 200)
143 | np.dot(A, A)
144 |
145 |
146 | if MASTER_BRANCH:
147 | short_start_node, short_end_node = forward_pass(f_short, (2.0,), {})
148 | long_start_node, long_end_node = forward_pass(f_long, (2.0,), {})
149 | fan_start_node, fan_end_node = forward_pass(fan_out_fan_in, (2.0,), {})
150 | progenitor = new_progenitor(2.0)
151 | else:
152 | x = 2.0
153 | start_node = VJPNode.new_root()
154 | start_box = new_box(x, 0, start_node)
155 | _, short_end_node = trace(VJPNode.new_root(), f_short, x)
156 | _, long_end_node = trace(VJPNode.new_root(), f_long, x)
157 | _, fan_end_node = trace(VJPNode.new_root(), fan_out_fan_in, x)
158 |
--------------------------------------------------------------------------------
/benchmarks/bench_mem.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | from autograd import grad
3 |
4 |
5 | def peakmem_needless_nodes():
6 | N, M = 1000, 100
7 |
8 | def fun(x):
9 | for i in range(M):
10 | x = x + 1
11 | return np.sum(x)
12 |
13 | grad(fun)(np.zeros((N, N)))
14 |
--------------------------------------------------------------------------------
/benchmarks/bench_numpy_vjps.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import make_vjp
4 |
5 | dot_0 = lambda a, b, g: make_vjp(np.dot, argnum=0)(a, b)[0](g)
6 | dot_1 = lambda a, b, g: make_vjp(np.dot, argnum=1)(a, b)[0](g)
7 |
8 | dot_0_0 = lambda a, b, g: make_vjp(dot_0, argnum=0)(a, b, g)[0](a)
9 | dot_0_1 = lambda a, b, g: make_vjp(dot_0, argnum=1)(a, b, g)[0](a)
10 | dot_0_2 = lambda a, b, g: make_vjp(dot_0, argnum=2)(a, b, g)[0](a)
11 |
12 | dot_1_0 = lambda a, b, g: make_vjp(dot_1, argnum=0)(a, b, g)[0](b)
13 | dot_1_1 = lambda a, b, g: make_vjp(dot_1, argnum=1)(a, b, g)[0](b)
14 | dot_1_2 = lambda a, b, g: make_vjp(dot_1, argnum=2)(a, b, g)[0](b)
15 |
16 | a = npr.randn(2, 3, 4, 5)
17 | b = npr.randn(2, 3, 5, 4)
18 | g = npr.randn(2, 3, 4, 2, 3, 4)
19 |
20 |
21 | def time_dot_0():
22 | dot_0(a, b, g)
23 |
24 |
25 | def time_dot_1():
26 | dot_1(a, b, g)
27 |
28 |
29 | def time_dot_0_0():
30 | dot_0_0(a, b, g)
31 |
32 |
33 | def time_dot_0_1():
34 | dot_0_1(a, b, g)
35 |
36 |
37 | def time_dot_0_2():
38 | dot_0_2(a, b, g)
39 |
40 |
41 | def time_dot_1_0():
42 | dot_1_0(a, b, g)
43 |
44 |
45 | def time_dot_1_1():
46 | dot_1_1(a, b, g)
47 |
48 |
49 | def time_dot_1_2():
50 | dot_1_2(a, b, g)
51 |
52 |
53 | tensordot_0 = lambda A, B, G: make_vjp(np.tensordot, argnum=0)(A, B, 2)[0](G)
54 | tensordot_1 = lambda A, B, G: make_vjp(np.tensordot, argnum=1)(A, B, 2)[0](G)
55 |
56 | tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0, argnum=0)(A, B, G)[0](A)
57 | tensordot_0_1 = lambda A, B, G: make_vjp(tensordot_0, argnum=1)(A, B, G)[0](A)
58 | tensordot_0_2 = lambda A, B, G: make_vjp(tensordot_0, argnum=2)(A, B, G)[0](A)
59 |
60 | tensordot_1_0 = lambda A, B, G: make_vjp(tensordot_1, argnum=0)(A, B, G)[0](B)
61 | tensordot_1_1 = lambda A, B, G: make_vjp(tensordot_1, argnum=1)(A, B, G)[0](B)
62 | tensordot_1_2 = lambda A, B, G: make_vjp(tensordot_1, argnum=2)(A, B, G)[0](B)
63 |
64 | A = npr.randn(2, 3, 5, 4)
65 | B = npr.randn(5, 4, 2, 3)
66 | G = npr.randn(2, 3, 2, 3)
67 |
68 |
69 | def time_tensordot_0():
70 | tensordot_0(A, B, G)
71 |
72 |
73 | def time_tensordot_1():
74 | tensordot_1(A, B, G)
75 |
76 |
77 | def time_tensordot_0_0():
78 | tensordot_0_0(A, B, G)
79 |
80 |
81 | def time_tensordot_0_1():
82 | tensordot_0_1(A, B, G)
83 |
84 |
85 | def time_tensordot_0_2():
86 | tensordot_0_2(A, B, G)
87 |
88 |
89 | def time_tensordot_1_0():
90 | tensordot_1_0(A, B, G)
91 |
92 |
93 | def time_tensordot_1_1():
94 | tensordot_1_1(A, B, G)
95 |
96 |
97 | def time_tensordot_1_2():
98 | tensordot_1_2(A, B, G)
99 |
--------------------------------------------------------------------------------
/benchmarks/bench_util.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import grad
4 |
5 | try:
6 | from autograd.misc.flatten import flatten
7 | except ImportError:
8 | from autograd.util import flatten
9 |
10 |
11 | def time_flatten():
12 | val = {
13 | "k": npr.random((4, 4)),
14 | "k2": npr.random((3, 3)),
15 | "k3": 3.0,
16 | "k4": [1.0, 4.0, 7.0, 9.0],
17 | "k5": np.array([4.0, 5.0, 6.0]),
18 | "k6": np.array([[7.0, 8.0], [9.0, 10.0]]),
19 | }
20 |
21 | vect, unflatten = flatten(val)
22 | val_recovered = unflatten(vect)
23 | vect_2, _ = flatten(val_recovered)
24 |
25 |
26 | # def time_vspace_flatten():
27 | # val = {'k': npr.random((4, 4)),
28 | # 'k2': npr.random((3, 3)),
29 | # 'k3': 3.0,
30 | # 'k4': [1.0, 4.0, 7.0, 9.0],
31 | # 'k5': np.array([4., 5., 6.]),
32 | # 'k6': np.array([[7., 8.], [9., 10.]])}
33 |
34 | # vspace_flatten(val)
35 |
36 |
37 | def time_grad_flatten():
38 | val = {
39 | "k": npr.random((4, 4)),
40 | "k2": npr.random((3, 3)),
41 | "k3": 3.0,
42 | "k4": [1.0, 4.0, 7.0, 9.0],
43 | "k5": np.array([4.0, 5.0, 6.0]),
44 | "k6": np.array([[7.0, 8.0], [9.0, 10.0]]),
45 | }
46 |
47 | vect, unflatten = flatten(val)
48 |
49 | def fun(vec):
50 | v = unflatten(vec)
51 | return np.sum(v["k5"]) + np.sum(v["k6"])
52 |
53 | grad(fun)(vect)
54 |
--------------------------------------------------------------------------------
/conda_recipe/conda.yaml:
--------------------------------------------------------------------------------
1 | package:
2 | name: autograd
3 | # there are ways to derive version from other sources; for now, it's hard-coded
4 | version: 1.1.1
5 |
6 | source:
7 | {% if not environ.get('BINSTAR_PLATFORM', None) %}
8 | git_url: ../
9 | {% else %}
10 | # we're building on binstar, we already have the repo; treat as local path
11 | path: ../
12 | {% endif %}
13 |
14 | requirements:
15 | build:
16 | - python
17 | - hatch
18 | - hatchling
19 | - future
20 | - numpy >=1.9
21 |
22 | run:
23 | - python
24 | - future
25 | - numpy >=1.9
26 |
27 | build:
28 | script: pip install . --no-deps
29 |
30 | test:
31 | # Python imports
32 | imports:
33 | - autograd
34 | - autograd.numpy
35 |
36 | about:
37 | home: https://github.com/HIPS/autograd
38 | license: MIT
39 | summary: 'Efficiently computes derivatives of numpy code.'
40 |
--------------------------------------------------------------------------------
/docs/updateguide.md:
--------------------------------------------------------------------------------
1 | # Autograd v1.2 update guide
2 |
3 | Autograd v1.2 changed the interface for defining custom vector-Jacobian
4 | products (VJPs). Fortunately, the change only affects users who write custom
5 | VJPs, and it should require only minor updates to that code.
6 |
7 | This guide is meant to explain why we made these changes (and others) in
8 | Autograd v1.2, and to summarize everything you need to know to update your
9 | custom VJP code.
10 |
11 | - [Reasoning for the changes](#reasoning-for-the-changes)
12 | - [New defvjp interface](#new-defvjp-interface)
13 | - [Gradient checking](#gradient-checking)
14 |
15 | ## Reasoning for the changes
16 |
17 | Here are some of the most important reasons for this update:
18 | 1. To make Autograd faster and more memory efficient, we staged the VJP
19 |    functions (enabling more aggressive garbage collection) and eliminated almost
20 |    all of the vspace metadata checks.
21 | 1. Forward-mode now comes built-in with `make_jvp` (see the sketch after this list).
22 | 1. There's now a clear extension API in `autograd.extend`, so you can write
23 | custom VJPs or wrap your own numerical libraries.
24 | 1. Autograd is now backend-independent, making it easy to wrap other numerical
25 | libraries.
26 | 1. Autograd's tracing functionality is now parameterized and easily reusable,
27 | and we added some new tracers for
28 | [computation graph visualization](https://github.com/hips/autograd/blob/master/examples/dot_graph.py)
29 | and
30 | [pure-Python constant folding](https://github.com/hips/autograd/blob/master/autograd/misc/tracers.py).
31 | 1. More exhaustive, fast reverse- and forward-mode checking with `autograd.test_util.check_grads`.
32 | 1. Expensive VJPs can share work across arguments using `defvjp_argnums`.
33 | 1. These changes enabled some internal cleanups, and more features to come!
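
As a quick illustration of forward mode, here's a minimal sketch of `make_jvp` usage
(the function `f` below is a made-up example, not part of this guide's code):

```python
import autograd.numpy as np
from autograd import make_jvp

def f(x):
    return np.sin(x) * x

# Sketch: make_jvp(f)(x) returns a function that maps a tangent v to the pair (f(x), J @ v).
pushforward = make_jvp(f)(2.0)
value, tangent_out = pushforward(1.0)  # directional derivative of f at x = 2.0 along v = 1.0
```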
34 |
35 | ## New defvjp interface
36 | First, here's an example of the old way to write custom primitives and VJPs:
37 | ```python
38 | import autograd.numpy as np
39 | from autograd import primitive
40 |
41 | @primitive
42 | def func(x, y, z):
43 | assert z != 0
44 | return x * y**2
45 |
46 | func.defvjp(lambda g, ans, vs, gvs, x, y, z: g * y**2)
47 | func.defvjp(lambda g, ans, vs, gvs, x, y, z: 2 * g * x * y, argnum=1)
48 | func.defvjp_is_zero(argnums=[2])
49 | ```
50 |
51 | Here's the new way to write custom VJPs for that same primitive:
52 | ```python
53 | import autograd.numpy as np
54 | from autograd.extend import primitive, defvjp # defvjp is now a function
55 |
56 | # primitives look the same as before
57 | @primitive
58 | def func(x, y, z):
59 | assert z != 0
60 | return x * y**2
61 |
62 | # but we call defvjp differently
63 | defvjp(func,
64 | lambda ans, x, y, z: lambda g: g * y**2,
65 | lambda ans, x, y, z: lambda g: 2 * g * x * y,
66 | None)
67 | ```
68 |
69 | Here's a list of the `defvjp` changes illustrated in that example:
70 | 1. `defvjp` is a function, rather than a method on the `primitive` class. (Actually, `primitive` is now just a function, and no longer a class.) As a result, `func.defvjp(...)` became `defvjp(func, ...)`.
71 | 1. VJPs are staged: instead of writing `lambda g, ans, vs, gvs, *args: ...` we write `lambda ans, *args: lambda g: ...`. This change enables a lot of automatic garbage collection. In the example above, if we were differentiating only with respect to the `x` argument of `func`, then because the VJP for `func` with respect to argument index 0 doesn't need the values of `x` or `z` from the forward pass, those values aren't stored and can be garbage-collected immediately.
72 | 1. There are no more `vs` and `gvs` arguments. These were rarely used, and computing vspace metadata for every intermediate value added significant overhead for some programs. Autograd now avoids computing vspace metadata unless it is needed.
73 | 1. `defvjp` lets you define VJPs with respect to multiple arguments at once, and the argnum(s) involved are often implicit.
74 |
75 | Here's another example, this time showing how to define VJPs with respect to
76 | specific argnums, leaving the others undefined.
77 | ```python
78 | # OLD way to leave some VJPs undefined
79 | func.defvjp(lambda g, ans, vs, gvs, x, y, z, w: ..., argnum=2)
80 | func.defvjp(lambda g, ans, vs, gvs, x, y, z, w: ..., argnum=3)
81 |
82 | # NEW way to leave some VJPs undefined
83 | defvjp(func,
84 | lambda ans, x, y, z, w: lambda g: ...,
85 | lambda ans, x, y, z, w: lambda g: ...,
86 | argnums=[2, 3])
87 | ```
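
For a concrete instance from this codebase, `autograd/scipy/stats/poisson.py` defines
VJPs only with respect to `mu` (argument index 1), leaving the gradient with respect to
the discrete argument `k` undefined:

```python
defvjp(logpmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * grad_poisson_logpmf(k, mu)), argnums=[1])
```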
88 |
89 | ## Gradient checking
90 | Here's how to do gradient checking, whether on a composite function or on your
91 | primitive with a custom VJP:
92 |
93 | ```python
94 | from autograd.test_util import check_grads
95 |
96 | # check reverse-mode to second order
97 | check_grads(my_func, modes=['rev'], order=2)(*args_for_my_func)
98 | ```
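
If forward-mode derivatives are also available for your function (for example, a
composition of built-in primitives, or a custom primitive with a JVP defined via
`defjvp`), the same call can check both modes at once. This is just a sketch of the
call with both modes enabled:

```python
# check forward- and reverse-mode to second order
check_grads(my_func, modes=['fwd', 'rev'], order=2)(*args_for_my_func)
```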
99 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/__init__.py
--------------------------------------------------------------------------------
/examples/bayesian_neural_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/bayesian_neural_net.png
--------------------------------------------------------------------------------
/examples/bayesian_neural_net.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from black_box_svi import black_box_variational_inference
3 |
4 | import autograd.numpy as np
5 | import autograd.numpy.random as npr
6 | from autograd.misc.optimizers import adam
7 |
8 |
9 | def make_nn_funs(layer_sizes, L2_reg, noise_variance, nonlinearity=np.tanh):
10 | """These functions implement a standard multi-layer perceptron,
11 | vectorized over both training examples and weight samples."""
12 | shapes = list(zip(layer_sizes[:-1], layer_sizes[1:]))
13 | num_weights = sum((m + 1) * n for m, n in shapes)
14 |
15 | def unpack_layers(weights):
16 | num_weight_sets = len(weights)
17 | for m, n in shapes:
18 | yield (
19 | weights[:, : m * n].reshape((num_weight_sets, m, n)),
20 | weights[:, m * n : m * n + n].reshape((num_weight_sets, 1, n)),
21 | )
22 | weights = weights[:, (m + 1) * n :]
23 |
24 | def predictions(weights, inputs):
25 | """weights is shape (num_weight_samples x num_weights)
26 | inputs is shape (num_datapoints x D)"""
27 | inputs = np.expand_dims(inputs, 0)
28 | for W, b in unpack_layers(weights):
29 | outputs = np.einsum("mnd,mdo->mno", inputs, W) + b
30 | inputs = nonlinearity(outputs)
31 | return outputs
32 |
33 | def logprob(weights, inputs, targets):
34 | log_prior = -L2_reg * np.sum(weights**2, axis=1)
35 | preds = predictions(weights, inputs)
36 | log_lik = -np.sum((preds - targets) ** 2, axis=1)[:, 0] / noise_variance
37 | return log_prior + log_lik
38 |
39 | return num_weights, predictions, logprob
40 |
41 |
42 | def build_toy_dataset(n_data=40, noise_std=0.1):
43 | D = 1
44 | rs = npr.RandomState(0)
45 |     inputs = np.concatenate([np.linspace(0, 2, num=n_data // 2), np.linspace(6, 8, num=n_data // 2)])  # integer division: np.linspace needs an int num
46 | targets = np.cos(inputs) + rs.randn(n_data) * noise_std
47 | inputs = (inputs - 4.0) / 4.0
48 | inputs = inputs.reshape((len(inputs), D))
49 | targets = targets.reshape((len(targets), D))
50 | return inputs, targets
51 |
52 |
53 | if __name__ == "__main__":
54 | # Specify inference problem by its unnormalized log-posterior.
55 | rbf = lambda x: np.exp(-(x**2))
56 | relu = lambda x: np.maximum(x, 0.0)
57 | num_weights, predictions, logprob = make_nn_funs(
58 | layer_sizes=[1, 20, 20, 1], L2_reg=0.1, noise_variance=0.01, nonlinearity=rbf
59 | )
60 |
61 | inputs, targets = build_toy_dataset()
62 | log_posterior = lambda weights, t: logprob(weights, inputs, targets)
63 |
64 | # Build variational objective.
65 | objective, gradient, unpack_params = black_box_variational_inference(
66 | log_posterior, num_weights, num_samples=20
67 | )
68 |
69 | # Set up figure.
70 | fig = plt.figure(figsize=(12, 8), facecolor="white")
71 | ax = fig.add_subplot(111, frameon=False)
72 | plt.ion()
73 | plt.show(block=False)
74 |
75 | def callback(params, t, g):
76 | print(f"Iteration {t} lower bound {-objective(params, t)}")
77 |
78 | # Sample functions from posterior.
79 | rs = npr.RandomState(0)
80 | mean, log_std = unpack_params(params)
81 | # rs = npr.RandomState(0)
82 | sample_weights = rs.randn(10, num_weights) * np.exp(log_std) + mean
83 | plot_inputs = np.linspace(-8, 8, num=400)
84 | outputs = predictions(sample_weights, np.expand_dims(plot_inputs, 1))
85 |
86 | # Plot data and functions.
87 | plt.cla()
88 | ax.plot(inputs.ravel(), targets.ravel(), "bx")
89 | ax.plot(plot_inputs, outputs[:, :, 0].T)
90 | ax.set_ylim([-2, 3])
91 | plt.draw()
92 | plt.pause(1.0 / 60.0)
93 |
94 | # Initialize variational parameters
95 | rs = npr.RandomState(0)
96 | init_mean = rs.randn(num_weights)
97 | init_log_std = -5 * np.ones(num_weights)
98 | init_var_params = np.concatenate([init_mean, init_log_std])
99 |
100 | print("Optimizing variational parameters...")
101 | variational_params = adam(gradient, init_var_params, step_size=0.1, num_iters=1000, callback=callback)
102 |
--------------------------------------------------------------------------------
/examples/black_box_svi.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import autograd.numpy as np
4 | import autograd.numpy.random as npr
5 | import autograd.scipy.stats.multivariate_normal as mvn
6 | import autograd.scipy.stats.norm as norm
7 | from autograd import grad
8 | from autograd.misc.optimizers import adam
9 |
10 |
11 | def black_box_variational_inference(logprob, D, num_samples):
12 | """Implements http://arxiv.org/abs/1401.0118, and uses the
13 | local reparameterization trick from http://arxiv.org/abs/1506.02557"""
14 |
15 | def unpack_params(params):
16 | # Variational dist is a diagonal Gaussian.
17 | mean, log_std = params[:D], params[D:]
18 | return mean, log_std
19 |
20 | def gaussian_entropy(log_std):
21 | return 0.5 * D * (1.0 + np.log(2 * np.pi)) + np.sum(log_std)
22 |
23 | rs = npr.RandomState(0)
24 |
25 | def variational_objective(params, t):
26 | """Provides a stochastic estimate of the variational lower bound."""
27 | mean, log_std = unpack_params(params)
28 | samples = rs.randn(num_samples, D) * np.exp(log_std) + mean
29 | lower_bound = gaussian_entropy(log_std) + np.mean(logprob(samples, t))
30 | return -lower_bound
31 |
32 | gradient = grad(variational_objective)
33 |
34 | return variational_objective, gradient, unpack_params
35 |
36 |
37 | if __name__ == "__main__":
38 | # Specify an inference problem by its unnormalized log-density.
39 | D = 2
40 |
41 | def log_density(x, t):
42 | mu, log_sigma = x[:, 0], x[:, 1]
43 | sigma_density = norm.logpdf(log_sigma, 0, 1.35)
44 | mu_density = norm.logpdf(mu, 0, np.exp(log_sigma))
45 | return sigma_density + mu_density
46 |
47 | # Build variational objective.
48 | objective, gradient, unpack_params = black_box_variational_inference(log_density, D, num_samples=2000)
49 |
50 | # Set up plotting code
51 | def plot_isocontours(ax, func, xlimits=[-2, 2], ylimits=[-4, 2], numticks=101):
52 | x = np.linspace(*xlimits, num=numticks)
53 | y = np.linspace(*ylimits, num=numticks)
54 | X, Y = np.meshgrid(x, y)
55 | zs = func(np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T)
56 | Z = zs.reshape(X.shape)
57 | plt.contour(X, Y, Z)
58 | ax.set_yticks([])
59 | ax.set_xticks([])
60 |
61 | # Set up figure.
62 | fig = plt.figure(figsize=(8, 8), facecolor="white")
63 | ax = fig.add_subplot(111, frameon=False)
64 | plt.ion()
65 | plt.show(block=False)
66 |
67 | def callback(params, t, g):
68 | print(f"Iteration {t} lower bound {-objective(params, t)}")
69 |
70 | plt.cla()
71 | target_distribution = lambda x: np.exp(log_density(x, t))
72 | plot_isocontours(ax, target_distribution)
73 |
74 | mean, log_std = unpack_params(params)
75 | variational_contour = lambda x: mvn.pdf(x, mean, np.diag(np.exp(2 * log_std)))
76 | plot_isocontours(ax, variational_contour)
77 | plt.draw()
78 | plt.pause(1.0 / 30.0)
79 |
80 | print("Optimizing variational parameters...")
81 | init_mean = -1 * np.ones(D)
82 | init_log_std = -5 * np.ones(D)
83 | init_var_params = np.concatenate([init_mean, init_log_std])
84 | variational_params = adam(gradient, init_var_params, step_size=0.1, num_iters=2000, callback=callback)
85 |
--------------------------------------------------------------------------------
/examples/data.py:
--------------------------------------------------------------------------------
1 | import data_mnist
2 | import matplotlib.image
3 | import matplotlib.pyplot as plt
4 |
5 | import autograd.numpy as np
6 | import autograd.numpy.random as npr
7 |
8 |
9 | def load_mnist():
10 | partial_flatten = lambda x: np.reshape(x, (x.shape[0], np.prod(x.shape[1:])))
11 | one_hot = lambda x, k: np.array(x[:, None] == np.arange(k)[None, :], dtype=int)
12 | train_images, train_labels, test_images, test_labels = data_mnist.mnist()
13 | train_images = partial_flatten(train_images) / 255.0
14 | test_images = partial_flatten(test_images) / 255.0
15 | train_labels = one_hot(train_labels, 10)
16 | test_labels = one_hot(test_labels, 10)
17 | N_data = train_images.shape[0]
18 |
19 | return N_data, train_images, train_labels, test_images, test_labels
20 |
21 |
22 | def plot_images(
23 | images,
24 | ax,
25 | ims_per_row=5,
26 | padding=5,
27 | digit_dimensions=(28, 28),
28 | cmap=matplotlib.cm.binary,
29 | vmin=None,
30 | vmax=None,
31 | ):
32 | """Images should be a (N_images x pixels) matrix."""
33 | N_images = images.shape[0]
34 | N_rows = (N_images - 1) // ims_per_row + 1
35 | pad_value = np.min(images.ravel())
36 | concat_images = np.full(
37 | (
38 | (digit_dimensions[0] + padding) * N_rows + padding,
39 | (digit_dimensions[1] + padding) * ims_per_row + padding,
40 | ),
41 | pad_value,
42 | )
43 | for i in range(N_images):
44 | cur_image = np.reshape(images[i, :], digit_dimensions)
45 | row_ix = i // ims_per_row
46 | col_ix = i % ims_per_row
47 | row_start = padding + (padding + digit_dimensions[0]) * row_ix
48 | col_start = padding + (padding + digit_dimensions[1]) * col_ix
49 | concat_images[
50 | row_start : row_start + digit_dimensions[0], col_start : col_start + digit_dimensions[1]
51 | ] = cur_image
52 | cax = ax.matshow(concat_images, cmap=cmap, vmin=vmin, vmax=vmax)
53 | plt.xticks(np.array([]))
54 | plt.yticks(np.array([]))
55 | return cax
56 |
57 |
58 | def save_images(images, filename, **kwargs):
59 | fig = plt.figure(1)
60 | fig.clf()
61 | ax = fig.add_subplot(111)
62 | plot_images(images, ax, **kwargs)
63 | fig.patch.set_visible(False)
64 | ax.patch.set_visible(False)
65 | plt.savefig(filename)
66 |
67 |
68 | def make_pinwheel(radial_std, tangential_std, num_classes, num_per_class, rate, rs=npr.RandomState(0)):
69 | """Based on code by Ryan P. Adams."""
70 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
71 |
72 | features = rs.randn(num_classes * num_per_class, 2) * np.array([radial_std, tangential_std])
73 | features[:, 0] += 1
74 | labels = np.repeat(np.arange(num_classes), num_per_class)
75 |
76 | angles = rads[labels] + rate * np.exp(features[:, 0])
77 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
78 | rotations = np.reshape(rotations.T, (-1, 2, 2))
79 |
80 | return np.einsum("ti,tij->tj", features, rotations)
81 |
--------------------------------------------------------------------------------
/examples/data_mnist.py:
--------------------------------------------------------------------------------
1 | import array
2 | import gzip
3 | import os
4 | import struct
5 | from urllib.request import urlretrieve
6 |
7 | import numpy as np
8 |
9 |
10 | def download(url, filename):
11 | if not os.path.exists("data"):
12 | os.makedirs("data")
13 | out_file = os.path.join("data", filename)
14 | if not os.path.isfile(out_file):
15 | urlretrieve(url, out_file)
16 |
17 |
18 | def mnist():
19 | base_url = "http://yann.lecun.com/exdb/mnist/"
20 |
21 | def parse_labels(filename):
22 | with gzip.open(filename, "rb") as fh:
23 | magic, num_data = struct.unpack(">II", fh.read(8))
24 | return np.array(array.array("B", fh.read()), dtype=np.uint8)
25 |
26 | def parse_images(filename):
27 | with gzip.open(filename, "rb") as fh:
28 | magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
29 | return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(num_data, rows, cols)
30 |
31 | for filename in [
32 | "train-images-idx3-ubyte.gz",
33 | "train-labels-idx1-ubyte.gz",
34 | "t10k-images-idx3-ubyte.gz",
35 | "t10k-labels-idx1-ubyte.gz",
36 | ]:
37 | download(base_url + filename, filename)
38 |
39 | train_images = parse_images("data/train-images-idx3-ubyte.gz")
40 | train_labels = parse_labels("data/train-labels-idx1-ubyte.gz")
41 | test_images = parse_images("data/t10k-images-idx3-ubyte.gz")
42 | test_labels = parse_labels("data/t10k-labels-idx1-ubyte.gz")
43 |
44 | return train_images, train_labels, test_images, test_labels
45 |
--------------------------------------------------------------------------------
/examples/deep_gaussian_process.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from gaussian_process import make_gp_funs, rbf_covariance
3 | from scipy.optimize import minimize
4 |
5 | import autograd.numpy as np
6 | import autograd.numpy.random as npr
7 | from autograd import value_and_grad
8 |
9 |
10 | def build_step_function_dataset(D=1, n_data=40, noise_std=0.1):
11 | rs = npr.RandomState(0)
12 | inputs = np.linspace(-2, 2, num=n_data)
13 | targets = np.sign(inputs) + rs.randn(n_data) * noise_std
14 | inputs = inputs.reshape((len(inputs), D))
15 | return inputs, targets
16 |
17 |
18 | def build_deep_gp(input_dimension, hidden_dimension, covariance_function):
19 | # GP going from input to hidden
20 | num_params_layer1, predict_layer1, log_marginal_likelihood_layer1 = make_gp_funs(
21 | covariance_function, num_cov_params=input_dimension + 1
22 | )
23 |
24 | # GP going from hidden to output
25 | num_params_layer2, predict_layer2, log_marginal_likelihood_layer2 = make_gp_funs(
26 | covariance_function, num_cov_params=hidden_dimension + 1
27 | )
28 |
29 | num_hidden_params = hidden_dimension * n_data
30 | total_num_params = num_params_layer1 + num_params_layer2 + num_hidden_params
31 |
32 | def unpack_all_params(all_params):
33 | layer1_params = all_params[:num_params_layer1]
34 | layer2_params = all_params[num_params_layer1 : num_params_layer1 + num_params_layer2]
35 | hiddens = all_params[num_params_layer1 + num_params_layer2 :]
36 | return layer1_params, layer2_params, hiddens
37 |
38 | def combined_predict_fun(all_params, X, y, xs):
39 | layer1_params, layer2_params, hiddens = unpack_all_params(all_params)
40 | h_star_mean, h_star_cov = predict_layer1(layer1_params, X, hiddens, xs)
41 | y_star_mean, y_star_cov = predict_layer2(
42 | layer2_params, np.atleast_2d(hiddens).T, y, np.atleast_2d(h_star_mean).T
43 | )
44 | return y_star_mean, y_star_cov
45 |
46 | def log_marginal_likelihood(all_params):
47 | layer1_params, layer2_params, h = unpack_all_params(all_params)
48 | return log_marginal_likelihood_layer1(layer1_params, X, h) + log_marginal_likelihood_layer2(
49 | layer2_params, np.atleast_2d(h).T, y
50 | )
51 |
52 | predict_layer_funcs = [predict_layer1, predict_layer2]
53 |
54 | return (
55 | total_num_params,
56 | log_marginal_likelihood,
57 | combined_predict_fun,
58 | unpack_all_params,
59 | predict_layer_funcs,
60 | )
61 |
62 |
63 | if __name__ == "__main__":
64 | n_data = 20
65 | input_dimension = 1
66 | hidden_dimension = 1
67 | X, y = build_step_function_dataset(D=input_dimension, n_data=n_data)
68 |
69 | (
70 | total_num_params,
71 | log_marginal_likelihood,
72 | combined_predict_fun,
73 | unpack_all_params,
74 | predict_layer_funcs,
75 | ) = build_deep_gp(input_dimension, hidden_dimension, rbf_covariance)
76 |
77 | # Set up figure.
78 | fig = plt.figure(figsize=(12, 8), facecolor="white")
79 | ax_end_to_end = fig.add_subplot(311, frameon=False)
80 | ax_x_to_h = fig.add_subplot(312, frameon=False)
81 | ax_h_to_y = fig.add_subplot(313, frameon=False)
82 | plt.show(block=False)
83 |
84 | def plot_gp(ax, X, y, pred_mean, pred_cov, plot_xs):
85 | ax.cla()
86 | marg_std = np.sqrt(np.diag(pred_cov))
87 | ax.plot(plot_xs, pred_mean, "b")
88 | ax.fill(
89 | np.concatenate([plot_xs, plot_xs[::-1]]),
90 | np.concatenate([pred_mean - 1.96 * marg_std, (pred_mean + 1.96 * marg_std)[::-1]]),
91 | alpha=0.15,
92 | fc="Blue",
93 | ec="None",
94 | )
95 |
96 | # Show samples from posterior.
97 | rs = npr.RandomState(0)
98 | sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov, size=10)
99 | ax.plot(plot_xs, sampled_funcs.T)
100 | ax.plot(X, y, "kx")
101 | ax.set_ylim([-1.5, 1.5])
102 | ax.set_xticks([])
103 | ax.set_yticks([])
104 |
105 | def callback(params):
106 | print(f"Log marginal likelihood {log_marginal_likelihood(params)}")
107 |
108 | # Show posterior marginals.
109 | plot_xs = np.reshape(np.linspace(-5, 5, 300), (300, 1))
110 | pred_mean, pred_cov = combined_predict_fun(params, X, y, plot_xs)
111 | plot_gp(ax_end_to_end, X, y, pred_mean, pred_cov, plot_xs)
112 | ax_end_to_end.set_title("X to y")
113 |
114 | layer1_params, layer2_params, hiddens = unpack_all_params(params)
115 | h_star_mean, h_star_cov = predict_layer_funcs[0](layer1_params, X, hiddens, plot_xs)
116 |         y_star_mean, y_star_cov = predict_layer_funcs[1](layer2_params, np.atleast_2d(hiddens).T, y, plot_xs)
117 |
118 | plot_gp(ax_x_to_h, X, hiddens, h_star_mean, h_star_cov, plot_xs)
119 | ax_x_to_h.set_title("X to hiddens")
120 |
121 | plot_gp(ax_h_to_y, np.atleast_2d(hiddens).T, y, y_star_mean, y_star_cov, plot_xs)
122 | ax_h_to_y.set_title("hiddens to y")
123 |
124 | plt.draw()
125 | plt.pause(1.0 / 60.0)
126 |
127 | # Initialize covariance parameters and hiddens.
128 | rs = npr.RandomState(0)
129 | init_params = 0.1 * rs.randn(total_num_params)
130 |
131 | print("Optimizing covariance parameters...")
132 | objective = lambda params: -log_marginal_likelihood(params)
133 | cov_params = minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback)
134 | plt.pause(10.0)
135 |
--------------------------------------------------------------------------------
/examples/define_gradient.py:
--------------------------------------------------------------------------------
1 | """This example shows how to define the gradient of your own functions.
2 | This can be useful for speed, numerical stability, or in cases where
3 | your code depends on external library calls."""
4 |
5 | import autograd.numpy as np
6 | import autograd.numpy.random as npr
7 | from autograd import grad
8 | from autograd.extend import defvjp, primitive
9 | from autograd.test_util import check_grads
10 |
11 |
12 | # @primitive tells Autograd not to look inside this function, but instead
13 | # to treat it as a black box, whose gradient might be specified later.
14 | # Functions with this decorator can contain anything that Python knows
15 | # how to execute, and you can do things like in-place operations on arrays.
16 | @primitive
17 | def logsumexp(x):
18 | """Numerically stable log(sum(exp(x))), also defined in scipy.special"""
19 | max_x = np.max(x)
20 | return max_x + np.log(np.sum(np.exp(x - max_x)))
21 |
22 |
23 | # Next, we write a function that specifies the gradient with a closure.
24 | # The reason for the closure is so that the gradient can depend
25 | # on both the input to the original function (x), and the output of the
26 | # original function (ans).
27 |
28 |
29 | def logsumexp_vjp(ans, x):
30 | # If you want to be able to take higher-order derivatives, then all the
31 | # code inside this function must be itself differentiable by Autograd.
32 | # This closure multiplies g with the Jacobian of logsumexp (d_ans/d_x).
33 | # Because Autograd uses reverse-mode differentiation, g contains
34 | # the gradient of the objective w.r.t. ans, the output of logsumexp.
35 |     # Note that the returned VJP function still closes over `x` and `ans`, since
36 |     # it needs exp(x - ans) to apply the Jacobian; only `x_shape` is extracted eagerly.
37 | x_shape = x.shape
38 | return lambda g: np.full(x_shape, g) * np.exp(x - np.full(x_shape, ans))
39 |
40 |
41 | # Now we tell Autograd that logsumexp has a gradient-making function.
42 | defvjp(logsumexp, logsumexp_vjp)
43 |
44 | if __name__ == "__main__":
45 | # Now we can use logsumexp() inside a larger function that we want
46 | # to differentiate.
47 | def example_func(y):
48 | z = y**2
49 | lse = logsumexp(z)
50 | return np.sum(lse)
51 |
52 | grad_of_example = grad(example_func)
53 | print("Gradient: \n", grad_of_example(npr.randn(10)))
54 |
55 | # Check the gradients numerically, just to be safe.
56 | check_grads(example_func, modes=["rev"])(npr.randn(10))
57 |
--------------------------------------------------------------------------------
/examples/dot_graph.py:
--------------------------------------------------------------------------------
1 | """Generates a graphviz DOT file of an evaluation trace.
2 | Usage (requires the dot binary from the graphviz package, www.graphviz.org):
3 |
4 | python dot_graph.py | dot -Tpdf -o graph.pdf
5 | """
6 |
7 | import autograd.numpy as np
8 | from autograd.tracer import Node, trace
9 |
10 |
11 | class GraphNode(Node):
12 |     # Records the full graph (this could live in tracer.py instead)
13 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
14 | self.fun_name = fun.__name__
15 | self.args = args
16 | self.parents = dict(zip(parent_argnums, parents))
17 | self.isroot = False
18 |
19 | def initialize_root(self, x):
20 | self.isroot = True
21 |
22 | def __repr__(self):
23 | return f"node_{id(self)}"
24 |
25 |
26 | def trace_graph(f, x):
27 | start_node = GraphNode.new_root(x)
28 | _, node = trace(start_node, f, x)
29 | return node
30 |
31 |
32 | dot_edge = "{} -> {} [color=gray30];\n".format
33 | dot_function_node = '{} [label="{}", shape=box, color=lightblue, style=filled];\n'.format
34 | dot_variable_node = '{} [label="{}", color=orange, style=filled];\n'.format
35 | dot_graph = "digraph G {{{}}}".format
36 |
37 |
38 | def graph_to_dotfile(graph):
39 | visited = set()
40 |
41 | def node_to_fragment(node):
42 | visited.add(node)
43 | if node.isroot:
44 | return dot_variable_node(node, "input")
45 | fragment = dot_function_node(node, node.fun_name)
46 | for argnum, arg in enumerate(node.args):
47 | if argnum in node.parents:
48 | parent = node.parents[argnum]
49 | fragment += dot_edge(parent, node)
50 | if parent not in visited:
51 | fragment += node_to_fragment(parent)
52 | else:
53 | argnode = f"{node}_arg_{argnum}"
54 | fragment += dot_edge(argnode, node)
55 | fragment += dot_variable_node(argnode, arg)
56 |
57 | return fragment
58 |
59 | dot_body = node_to_fragment(graph)
60 | dot_body += dot_variable_node("output", "output")
61 | dot_body += dot_edge(graph, "output")
62 | return dot_graph(dot_body)
63 |
64 |
65 | if __name__ == "__main__":
66 |
67 | def fun(x):
68 | y = np.sin(x)
69 | return (y + np.exp(x) - 0.5) * y
70 |
71 | print(graph_to_dotfile(trace_graph(fun, 1.0)))
72 |
--------------------------------------------------------------------------------
/examples/fixed_points.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | from autograd import grad
3 | from autograd.misc.fixed_points import fixed_point
4 |
5 |
6 | def newton_sqrt_iter(a):
7 | return lambda x: 0.5 * (x + a / x)
8 |
9 |
10 | def grad_descent_sqrt_iter(a):
11 | return lambda x: x - 0.05 * (x**2 - a)
12 |
13 |
14 | def sqrt(a, guess=10.0):
15 | # return fixed_point(newton_sqrt_iter, a, guess, distance, 1e-4)
16 | return fixed_point(grad_descent_sqrt_iter, a, guess, distance, 1e-4)
17 |
18 |
19 | def distance(x, y):
20 | return np.abs(x - y)
21 |
22 |
23 | print(np.sqrt(2.0))
24 | print(sqrt(2.0))
25 | print()
26 | print(grad(np.sqrt)(2.0))
27 | print(grad(sqrt)(2.0))
28 | print()
29 | print(grad(grad(np.sqrt))(2.0))
30 | print(grad(grad(sqrt))(2.0))
31 | print()
32 |
--------------------------------------------------------------------------------
/examples/fluidsim/animated.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/animated.gif
--------------------------------------------------------------------------------
/examples/fluidsim/fluidsim.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | from matplotlib.pyplot import imread
6 | from scipy.optimize import minimize
7 |
8 | import autograd.numpy as np
9 | from autograd import value_and_grad
10 |
11 | # Fluid simulation code based on
12 | # "Real-Time Fluid Dynamics for Games" by Jos Stam
13 | # https://www.josstam.com/_files/ugd/cf1fd6_9989229efbd34a26ba5ccd913721a2ac.pdf
14 |
15 |
16 | def project(vx, vy):
17 | """Project the velocity field to be approximately mass-conserving,
18 | using a few iterations of Gauss-Seidel."""
19 | p = np.zeros(vx.shape)
20 | h = 1.0 / vx.shape[0]
21 | div = (
22 | -0.5
23 | * h
24 | * (
25 | np.roll(vx, -1, axis=0)
26 | - np.roll(vx, 1, axis=0)
27 | + np.roll(vy, -1, axis=1)
28 | - np.roll(vy, 1, axis=1)
29 | )
30 | )
31 |
32 | for k in range(10):
33 | p = (
34 | div
35 | + np.roll(p, 1, axis=0)
36 | + np.roll(p, -1, axis=0)
37 | + np.roll(p, 1, axis=1)
38 | + np.roll(p, -1, axis=1)
39 | ) / 4.0
40 |
41 | vx -= 0.5 * (np.roll(p, -1, axis=0) - np.roll(p, 1, axis=0)) / h
42 | vy -= 0.5 * (np.roll(p, -1, axis=1) - np.roll(p, 1, axis=1)) / h
43 | return vx, vy
44 |
45 |
46 | def advect(f, vx, vy):
47 |     """Move field f according to x and y velocities (vx and vy)
48 |     using an implicit Euler integrator."""
49 | rows, cols = f.shape
50 | cell_ys, cell_xs = np.meshgrid(np.arange(rows), np.arange(cols))
51 | center_xs = (cell_xs - vx).ravel()
52 | center_ys = (cell_ys - vy).ravel()
53 |
54 | # Compute indices of source cells.
55 | left_ix = np.floor(center_xs).astype(int)
56 | top_ix = np.floor(center_ys).astype(int)
57 | rw = center_xs - left_ix # Relative weight of right-hand cells.
58 | bw = center_ys - top_ix # Relative weight of bottom cells.
59 | left_ix = np.mod(left_ix, rows) # Wrap around edges of simulation.
60 | right_ix = np.mod(left_ix + 1, rows)
61 | top_ix = np.mod(top_ix, cols)
62 | bot_ix = np.mod(top_ix + 1, cols)
63 |
64 | # A linearly-weighted sum of the 4 surrounding cells.
65 | flat_f = (1 - rw) * ((1 - bw) * f[left_ix, top_ix] + bw * f[left_ix, bot_ix]) + rw * (
66 | (1 - bw) * f[right_ix, top_ix] + bw * f[right_ix, bot_ix]
67 | )
68 | return np.reshape(flat_f, (rows, cols))
69 |
70 |
71 | def simulate(vx, vy, smoke, num_time_steps, ax=None, render=False):
72 | print("Running simulation...")
73 | for t in range(num_time_steps):
74 | if ax:
75 | plot_matrix(ax, smoke, t, render)
76 | vx_updated = advect(vx, vx, vy)
77 | vy_updated = advect(vy, vx, vy)
78 | vx, vy = project(vx_updated, vy_updated)
79 | smoke = advect(smoke, vx, vy)
80 | if ax:
81 | plot_matrix(ax, smoke, num_time_steps, render)
82 | return smoke
83 |
84 |
85 | def plot_matrix(ax, mat, t, render=False):
86 | plt.cla()
87 | ax.matshow(mat)
88 | ax.set_xticks([])
89 | ax.set_yticks([])
90 | plt.draw()
91 | if render:
92 | matplotlib.image.imsave(f"step{t:03d}.png", mat)
93 | plt.pause(0.001)
94 |
95 |
96 | if __name__ == "__main__":
97 | simulation_timesteps = 100
98 | basepath = os.path.dirname(__file__)
99 |
100 | print("Loading initial and target states...")
101 | init_smoke = imread(os.path.join(basepath, "init_smoke.png"))[:, :, 0]
102 | # target = imread('peace.png')[::2,::2,3]
103 | target = imread(os.path.join(basepath, "skull.png"))[::2, ::2]
104 | rows, cols = target.shape
105 |
106 | init_dx_and_dy = np.zeros((2, rows, cols)).ravel()
107 |
108 | def distance_from_target_image(smoke):
109 | return np.mean((target - smoke) ** 2)
110 |
111 | def convert_param_vector_to_matrices(params):
112 | vx = np.reshape(params[: (rows * cols)], (rows, cols))
113 | vy = np.reshape(params[(rows * cols) :], (rows, cols))
114 | return vx, vy
115 |
116 | def objective(params):
117 | init_vx, init_vy = convert_param_vector_to_matrices(params)
118 | final_smoke = simulate(init_vx, init_vy, init_smoke, simulation_timesteps)
119 | return distance_from_target_image(final_smoke)
120 |
121 | # Specify gradient of objective function using autograd.
122 | objective_with_grad = value_and_grad(objective)
123 |
124 | fig = plt.figure(figsize=(8, 8))
125 | ax = fig.add_subplot(111, frameon=False)
126 |
127 | def callback(params):
128 | init_vx, init_vy = convert_param_vector_to_matrices(params)
129 | simulate(init_vx, init_vy, init_smoke, simulation_timesteps, ax)
130 |
131 | print("Optimizing initial conditions...")
132 | result = minimize(
133 | objective_with_grad,
134 | init_dx_and_dy,
135 | jac=True,
136 | method="CG",
137 | options={"maxiter": 25, "disp": True},
138 | callback=callback,
139 | )
140 |
141 | print("Rendering optimized flow...")
142 | init_vx, init_vy = convert_param_vector_to_matrices(result.x)
143 | simulate(init_vx, init_vy, init_smoke, simulation_timesteps, ax, render=True)
144 |
145 | print("Converting frames to an animated GIF...")
146 | os.system("convert -delay 5 -loop 0 step*.png -delay 250 step100.png surprise.gif") # Using imagemagick.
147 | os.system("rm step*.png")
148 |
--------------------------------------------------------------------------------
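A minimal sanity check of the Gauss-Seidel pressure projection above. It reuses the same roll-based divergence that project() computes internally; the only assumption is that the file above is saved as fluidsim.py and is importable from the working directory. The mean absolute divergence of a random velocity field should drop noticeably after projection (the projection is only approximate, so it will not reach exactly zero).

import autograd.numpy as np
import autograd.numpy.random as npr

from fluidsim import project  # assumed module name for the file above


def mean_abs_divergence(vx, vy):
    # Central-difference divergence with periodic boundaries, matching project().
    h = 1.0 / vx.shape[0]
    div = -0.5 * h * (
        np.roll(vx, -1, axis=0) - np.roll(vx, 1, axis=0)
        + np.roll(vy, -1, axis=1) - np.roll(vy, 1, axis=1)
    )
    return np.mean(np.abs(div))


rs = npr.RandomState(0)
vx, vy = rs.randn(2, 32, 32)
print("mean |div| before projection:", mean_abs_divergence(vx, vy))
vx, vy = project(vx, vy)
print("mean |div| after projection: ", mean_abs_divergence(vx, vy))
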
/examples/fluidsim/init_smoke.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/init_smoke.png
--------------------------------------------------------------------------------
/examples/fluidsim/peace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/peace.png
--------------------------------------------------------------------------------
/examples/fluidsim/skull.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/skull.png
--------------------------------------------------------------------------------
/examples/fluidsim/surprise.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/surprise.gif
--------------------------------------------------------------------------------
/examples/fluidsim/wing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/wing.png
--------------------------------------------------------------------------------
/examples/gaussian_process.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/gaussian_process.png
--------------------------------------------------------------------------------
/examples/gaussian_process.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from scipy.optimize import minimize
3 |
4 | import autograd.numpy as np
5 | import autograd.numpy.random as npr
6 | import autograd.scipy.stats.multivariate_normal as mvn
7 | from autograd import value_and_grad
8 | from autograd.numpy.linalg import solve
9 |
10 |
11 | def make_gp_funs(cov_func, num_cov_params):
12 | """Functions that perform Gaussian process regression.
13 | cov_func has signature (cov_params, x, x')"""
14 |
15 | def unpack_kernel_params(params):
16 | mean = params[0]
17 | cov_params = params[2:]
18 | noise_scale = np.exp(params[1]) + 0.0001
19 | return mean, cov_params, noise_scale
20 |
21 | def predict(params, x, y, xstar):
22 | """Returns the predictive mean and covariance at locations xstar,
23 | of the latent function value f (without observation noise)."""
24 | mean, cov_params, noise_scale = unpack_kernel_params(params)
25 | cov_f_f = cov_func(cov_params, xstar, xstar)
26 | cov_y_f = cov_func(cov_params, x, xstar)
27 | cov_y_y = cov_func(cov_params, x, x) + noise_scale * np.eye(len(y))
28 | pred_mean = mean + np.dot(solve(cov_y_y, cov_y_f).T, y - mean)
29 | pred_cov = cov_f_f - np.dot(solve(cov_y_y, cov_y_f).T, cov_y_f)
30 | return pred_mean, pred_cov
31 |
32 | def log_marginal_likelihood(params, x, y):
33 | mean, cov_params, noise_scale = unpack_kernel_params(params)
34 | cov_y_y = cov_func(cov_params, x, x) + noise_scale * np.eye(len(y))
35 | prior_mean = mean * np.ones(len(y))
36 | return mvn.logpdf(y, prior_mean, cov_y_y)
37 |
38 | return num_cov_params + 2, predict, log_marginal_likelihood
39 |
40 |
41 | # Define an example covariance function.
42 | def rbf_covariance(kernel_params, x, xp):
43 | output_scale = np.exp(kernel_params[0])
44 | lengthscales = np.exp(kernel_params[1:])
45 | diffs = np.expand_dims(x / lengthscales, 1) - np.expand_dims(xp / lengthscales, 0)
46 | return output_scale * np.exp(-0.5 * np.sum(diffs**2, axis=2))
47 |
48 |
49 | def build_toy_dataset(D=1, n_data=20, noise_std=0.1):
50 | rs = npr.RandomState(0)
51 | inputs = np.concatenate([np.linspace(0, 3, num=n_data // 2), np.linspace(6, 8, num=n_data // 2)])
52 | targets = (np.cos(inputs) + rs.randn(n_data) * noise_std) / 2.0
53 | inputs = (inputs - 4.0) / 2.0
54 | inputs = inputs.reshape((len(inputs), D))
55 | return inputs, targets
56 |
57 |
58 | if __name__ == "__main__":
59 | D = 1
60 |
61 | # Build model and objective function.
62 | num_params, predict, log_marginal_likelihood = make_gp_funs(rbf_covariance, num_cov_params=D + 1)
63 |
64 | X, y = build_toy_dataset(D=D)
65 | objective = lambda params: -log_marginal_likelihood(params, X, y)
66 |
67 | # Set up figure.
68 | fig = plt.figure(figsize=(12, 8), facecolor="white")
69 | ax = fig.add_subplot(111, frameon=False)
70 | plt.show(block=False)
71 |
72 | def callback(params):
73 | print(f"Log likelihood {-objective(params)}")
74 | plt.cla()
75 |
76 | # Show posterior marginals.
77 | plot_xs = np.reshape(np.linspace(-7, 7, 300), (300, 1))
78 | pred_mean, pred_cov = predict(params, X, y, plot_xs)
79 | marg_std = np.sqrt(np.diag(pred_cov))
80 | ax.plot(plot_xs, pred_mean, "b")
81 | ax.fill(
82 | np.concatenate([plot_xs, plot_xs[::-1]]),
83 | np.concatenate([pred_mean - 1.96 * marg_std, (pred_mean + 1.96 * marg_std)[::-1]]),
84 | alpha=0.15,
85 | fc="Blue",
86 | ec="None",
87 | )
88 |
89 | # Show samples from posterior.
90 | rs = npr.RandomState(0)
91 | sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov, size=10)
92 | ax.plot(plot_xs, sampled_funcs.T)
93 |
94 | ax.plot(X, y, "kx")
95 | ax.set_ylim([-1.5, 1.5])
96 | ax.set_xticks([])
97 | ax.set_yticks([])
98 | plt.draw()
99 | plt.pause(1.0 / 60.0)
100 |
101 | # Initialize covariance parameters
102 | rs = npr.RandomState(0)
103 | init_params = 0.1 * rs.randn(num_params)
104 |
105 | print("Optimizing covariance parameters...")
106 | cov_params = minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback)
107 | plt.pause(10.0)
108 |
--------------------------------------------------------------------------------
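A short usage sketch for make_gp_funs and rbf_covariance defined in gaussian_process.py above (gplvm.py below imports them under that module name, so the examples directory is assumed to be on the Python path). It evaluates the log marginal likelihood and its gradient on toy one-dimensional data, then queries the predictive mean at a single test input.

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import value_and_grad

from gaussian_process import make_gp_funs, rbf_covariance

# One input dimension: the kernel parameters are an output scale and one lengthscale.
num_params, predict, log_marginal_likelihood = make_gp_funs(rbf_covariance, num_cov_params=2)

rs = npr.RandomState(0)
x = rs.randn(10, 1)
y = np.sin(x).ravel()
params = 0.1 * rs.randn(num_params)

ll, ll_grad = value_and_grad(lambda p: log_marginal_likelihood(p, x, y))(params)
print("log marginal likelihood:", ll, " gradient norm:", np.linalg.norm(ll_grad))

pred_mean, pred_cov = predict(params, x, y, np.zeros((1, 1)))
print("predictive mean at x* = 0:", pred_mean)
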
/examples/gmm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/gmm.png
--------------------------------------------------------------------------------
/examples/gmm.py:
--------------------------------------------------------------------------------
1 | """Implements a Gaussian mixture model, in which parameters are fit using
2 | gradient descent. This example runs on 2-dimensional data, but the model
3 | works on arbitrarily-high dimension."""
4 |
5 | import matplotlib.pyplot as plt
6 | from data import make_pinwheel
7 | from scipy.optimize import minimize
8 |
9 | import autograd.numpy as np
10 | import autograd.numpy.random as npr
11 | import autograd.scipy.stats.multivariate_normal as mvn
12 | from autograd import grad, hessian_vector_product
13 | from autograd.misc.flatten import flatten_func
14 | from autograd.scipy.special import logsumexp
15 |
16 |
17 | def init_gmm_params(num_components, D, scale, rs=npr.RandomState(0)):
18 | return {
19 | "log proportions": rs.randn(num_components) * scale,
20 | "means": rs.randn(num_components, D) * scale,
21 | "lower triangles": np.zeros((num_components, D, D)) + np.eye(D),
22 | }
23 |
24 |
25 | def log_normalize(x):
26 | return x - logsumexp(x)
27 |
28 |
29 | def unpack_gmm_params(params):
30 | normalized_log_proportions = log_normalize(params["log proportions"])
31 | return normalized_log_proportions, params["means"], params["lower triangles"]
32 |
33 |
34 | def gmm_log_likelihood(params, data):
35 | cluster_lls = []
36 | for log_proportion, mean, cov_sqrt in zip(*unpack_gmm_params(params)):
37 | cov = np.dot(cov_sqrt.T, cov_sqrt)
38 | cluster_lls.append(log_proportion + mvn.logpdf(data, mean, cov))
39 | return np.sum(logsumexp(np.vstack(cluster_lls), axis=0))
40 |
41 |
42 | def plot_ellipse(ax, mean, cov_sqrt, alpha, num_points=100):
43 | angles = np.linspace(0, 2 * np.pi, num_points)
44 | circle_pts = np.vstack([np.cos(angles), np.sin(angles)]).T * 2.0
45 | cur_pts = mean + np.dot(circle_pts, cov_sqrt)
46 | ax.plot(cur_pts[:, 0], cur_pts[:, 1], "-", alpha=alpha)
47 |
48 |
49 | def plot_gaussian_mixture(params, ax):
50 | for log_proportion, mean, cov_sqrt in zip(*unpack_gmm_params(params)):
51 | alpha = np.minimum(1.0, np.exp(log_proportion) * 10)
52 | plot_ellipse(ax, mean, cov_sqrt, alpha)
53 |
54 |
55 | if __name__ == "__main__":
56 | init_params = init_gmm_params(num_components=10, D=2, scale=0.1)
57 |
58 | data = make_pinwheel(radial_std=0.3, tangential_std=0.05, num_classes=3, num_per_class=100, rate=0.4)
59 |
60 | def objective(params):
61 | return -gmm_log_likelihood(params, data)
62 |
63 | flattened_obj, unflatten, flattened_init_params = flatten_func(objective, init_params)
64 |
65 | fig = plt.figure(figsize=(12, 8), facecolor="white")
66 | ax = fig.add_subplot(111, frameon=False)
67 | plt.show(block=False)
68 |
69 | def callback(flattened_params):
70 | params = unflatten(flattened_params)
71 | print(f"Log likelihood {-objective(params)}")
72 | ax.cla()
73 | ax.plot(data[:, 0], data[:, 1], "k.")
74 | ax.set_xticks([])
75 | ax.set_yticks([])
76 | plot_gaussian_mixture(params, ax)
77 | plt.draw()
78 | plt.pause(1.0 / 60.0)
79 |
80 | minimize(
81 | flattened_obj,
82 | flattened_init_params,
83 | jac=grad(flattened_obj),
84 | hessp=hessian_vector_product(flattened_obj),
85 | method="Newton-CG",
86 | callback=callback,
87 | )
88 |
--------------------------------------------------------------------------------
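gmm.py above hands scipy's Newton-CG solver a Hessian-vector product instead of an explicit Hessian. A small self-contained check of autograd's hessian_vector_product on a quadratic whose Hessian is known exactly:

import autograd.numpy as np
from autograd import hessian_vector_product

A = np.array([[2.0, 0.5], [0.5, 1.0]])
f = lambda x: 0.5 * np.dot(x, np.dot(A, x))  # the Hessian of f is A

x0 = np.array([1.0, -1.0])
v = np.array([0.3, 0.7])
print(hessian_vector_product(f)(x0, v))  # matches A @ v
print(np.dot(A, v))
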
/examples/gplvm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/gplvm.png
--------------------------------------------------------------------------------
/examples/gplvm.py:
--------------------------------------------------------------------------------
1 | # Implements a Gaussian process latent-variable model.
2 | # The (high-dimensional) data Y is explained by some low-dimensional latent
3 | # data X, warped by a function drawn from a GP prior (f). So Y = f(X), but
4 | # we don't know X or f.
5 | #
6 | # In this example, we optimize X and the hyperparameters of the GP, but
7 | # we integrate over all possible functions f.
8 | #
9 | # Normally the observed data would be high-dimensional.
10 | #
11 | # David Duvenaud (duvenaud@gmail.com)
12 |
13 |
14 | import matplotlib.pyplot as plt
15 | from data import make_pinwheel
16 | from gaussian_process import make_gp_funs, rbf_covariance
17 | from scipy.optimize import minimize
18 |
19 | import autograd.numpy as np
20 | import autograd.numpy.random as npr
21 | from autograd import value_and_grad
22 | from autograd.scipy.stats import norm
23 |
24 | if __name__ == "__main__":
25 | data_dimension = 2 # Normally the data dimension would be much higher.
26 | latent_dimension = 2
27 |
28 | # Build model and objective function.
29 | params_per_gp, predict, log_marginal_likelihood = make_gp_funs(
30 | rbf_covariance, num_cov_params=latent_dimension + 1
31 | )
32 | total_gp_params = data_dimension * params_per_gp
33 |
34 | data = make_pinwheel(radial_std=0.3, tangential_std=0.05, num_classes=3, num_per_class=30, rate=0.4)
35 | datalen = data.shape[0]
36 |
37 | num_latent_params = datalen * latent_dimension
38 |
39 | def unpack_params(params):
40 | gp_params = np.reshape(params[:total_gp_params], (data_dimension, params_per_gp))
41 | latents = np.reshape(params[total_gp_params:], (datalen, latent_dimension))
42 | return gp_params, latents
43 |
44 | def objective(params):
45 | gp_params, latents = unpack_params(params)
46 | gp_likelihood = sum(
47 | [log_marginal_likelihood(gp_params[i], latents, data[:, i]) for i in range(data_dimension)]
48 | )
49 | latent_prior_likelihood = np.sum(norm.logpdf(latents))
50 | return -gp_likelihood - latent_prior_likelihood
51 |
52 | # Set up figure.
53 | fig = plt.figure(figsize=(12, 8), facecolor="white")
54 | latent_ax = fig.add_subplot(121, frameon=False)
55 | data_ax = fig.add_subplot(122, frameon=False)
56 | plt.show(block=False)
57 |
58 | def callback(params):
59 | print(f"Log likelihood {-objective(params)}")
60 | gp_params, latents = unpack_params(params)
61 |
62 | data_ax.cla()
63 | data_ax.plot(data[:, 0], data[:, 1], "bx")
64 | data_ax.set_xticks([])
65 | data_ax.set_yticks([])
66 | data_ax.set_title("Observed Data")
67 |
68 | latent_ax.cla()
69 | latent_ax.plot(latents[:, 0], latents[:, 1], "kx")
70 | latent_ax.set_xticks([])
71 | latent_ax.set_yticks([])
72 | latent_ax.set_xlim([-2, 2])
73 | latent_ax.set_ylim([-2, 2])
74 | latent_ax.set_title("Latent coordinates")
75 |
76 | plt.draw()
77 | plt.pause(1.0 / 60.0)
78 |
79 | # Initialize covariance parameters
80 | rs = npr.RandomState(1)
81 | init_params = rs.randn(total_gp_params + num_latent_params) * 0.1
82 |
83 | print("Optimizing covariance parameters and latent variable locations...")
84 | minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback)
85 |
--------------------------------------------------------------------------------
/examples/graph.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/graph.pdf
--------------------------------------------------------------------------------
/examples/hmm_em.py:
--------------------------------------------------------------------------------
1 | import string
2 | from functools import partial
3 | from os.path import dirname, join
4 |
5 | import autograd.numpy as np
6 | import autograd.numpy.random as npr
7 | from autograd import value_and_grad as vgrad
8 | from autograd.scipy.special import logsumexp
9 |
10 |
11 | def EM(init_params, data, callback=None):
12 | def EM_update(params):
13 | natural_params = list(map(np.log, params))
14 | loglike, E_stats = vgrad(log_partition_function)(natural_params, data) # E step
15 | if callback:
16 | callback(loglike, params)
17 | return list(map(normalize, E_stats)) # M step
18 |
19 | def fixed_point(f, x0):
20 | x1 = f(x0)
21 | while different(x0, x1):
22 | x0, x1 = x1, f(x1)
23 | return x1
24 |
25 | def different(params1, params2):
26 | allclose = partial(np.allclose, atol=1e-3, rtol=1e-3)
27 | return not all(map(allclose, params1, params2))
28 |
29 | return fixed_point(EM_update, init_params)
30 |
31 |
32 | def normalize(a):
33 | def replace_zeros(a):
34 | return np.where(a > 0.0, a, 1.0)
35 |
36 | return a / replace_zeros(a.sum(-1, keepdims=True))
37 |
38 |
39 | def log_partition_function(natural_params, data):
40 | if isinstance(data, list):
41 | return sum(map(partial(log_partition_function, natural_params), data))
42 |
43 | log_pi, log_A, log_B = natural_params
44 |
45 | log_alpha = log_pi
46 | for y in data:
47 | log_alpha = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_B[:, y]
48 |
49 | return logsumexp(log_alpha)
50 |
51 |
52 | def initialize_hmm_parameters(num_states, num_outputs):
53 | init_pi = normalize(npr.rand(num_states))
54 | init_A = normalize(npr.rand(num_states, num_states))
55 | init_B = normalize(npr.rand(num_states, num_outputs))
56 | return init_pi, init_A, init_B
57 |
58 |
59 | def build_dataset(filename, max_lines=-1):
60 | """Loads a text file, and turns each line into an encoded sequence."""
61 | encodings = dict(list(map(reversed, enumerate(string.printable))))
62 | digitize = lambda char: encodings[char] if char in encodings else len(encodings)
63 | encode_line = lambda line: np.array(list(map(digitize, line)))
64 | nonblank_line = lambda line: len(line) > 2
65 |
66 | with open(filename) as f:
67 | lines = f.readlines()
68 |
69 | encoded_lines = list(map(encode_line, list(filter(nonblank_line, lines))[:max_lines]))
70 | num_outputs = len(encodings) + 1
71 |
72 | return encoded_lines, num_outputs
73 |
74 |
75 | if __name__ == "__main__":
76 | np.random.seed(0)
77 | np.seterr(divide="ignore")
78 |
79 | # callback to print log likelihoods during training
80 | print_loglike = lambda loglike, params: print(loglike)
81 |
82 | # load training data
83 | lstm_filename = join(dirname(__file__), "lstm.py")
84 | train_inputs, num_outputs = build_dataset(lstm_filename, max_lines=60)
85 |
86 | # train with EM
87 | num_states = 20
88 | init_params = initialize_hmm_parameters(num_states, num_outputs)
89 | pi, A, B = EM(init_params, train_inputs, print_loglike)
90 |
--------------------------------------------------------------------------------
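The E step in hmm_em.py above exploits a standard exponential-family identity: the gradient of a log partition function with respect to the natural parameters equals the expected sufficient statistics. A tiny self-contained illustration with logsumexp, whose gradient is the softmax distribution:

import autograd.numpy as np
from autograd import grad
from autograd.scipy.special import logsumexp

theta = np.array([0.5, -1.0, 2.0])
expected_stats = grad(logsumexp)(theta)      # gradient of the log partition function
softmax = np.exp(theta - logsumexp(theta))   # the corresponding categorical distribution
print(np.allclose(expected_stats, softmax))  # True
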
/examples/ica.py:
--------------------------------------------------------------------------------
1 | import matplotlib.cm as cm
2 | import matplotlib.pyplot as plt
3 | from scipy.optimize import minimize
4 |
5 | import autograd.numpy as np
6 | import autograd.numpy.random as npr
7 | import autograd.scipy.stats.t as t
8 | from autograd import value_and_grad
9 |
10 |
11 | def make_ica_funs(observed_dimension, latent_dimension):
12 | """These functions implement independent component analysis.
13 |
14 | The model is:
15 | latents are drawn i.i.d. for each data point from a product of student-ts.
16 | weights are the same across all datapoints.
17 | each datapoint = latents * weights + noise."""
18 |
19 | def sample(weights, n_samples, noise_std, rs):
20 | latents = rs.randn(latent_dimension, n_samples)
21 | latents = np.array(sorted(latents.T, key=lambda a_entry: a_entry[0])).T
22 | noise = rs.randn(n_samples, observed_dimension) * noise_std
23 | observed = predict(weights, latents) + noise
24 | return latents, observed
25 |
26 | def predict(weights, latents):
27 | return np.dot(weights, latents).T
28 |
29 | def logprob(weights, latents, noise_std, observed):
30 | preds = predict(weights, latents)
31 | log_lik = np.sum(t.logpdf(preds, 2.4, observed, noise_std))
32 | return log_lik
33 |
34 | num_weights = observed_dimension * latent_dimension
35 |
36 | def unpack_weights(weights):
37 | return np.reshape(weights, (observed_dimension, latent_dimension))
38 |
39 | return num_weights, sample, logprob, unpack_weights
40 |
41 |
42 | def color_scatter(ax, xs, ys):
43 | colors = cm.rainbow(np.linspace(0, 1, len(ys)))
44 | for x, y, c in zip(xs, ys, colors):
45 | ax.scatter(x, y, color=c)
46 |
47 |
48 | if __name__ == "__main__":
49 | observed_dimension = 100
50 | latent_dimension = 2
51 | true_noise_var = 1.0
52 | n_samples = 200
53 |
54 | num_weights, sample, logprob, unpack_weights = make_ica_funs(observed_dimension, latent_dimension)
55 |
56 | num_latent_params = latent_dimension * n_samples
57 | total_num_params = num_weights + num_latent_params + 1
58 |
59 | def unpack_params(params):
60 | weights = unpack_weights(params[:num_weights])
61 | latents = np.reshape(
62 | params[num_weights : num_weights + num_latent_params], (latent_dimension, n_samples)
63 | )
64 | noise_std = np.exp(params[-1])
65 | return weights, latents, noise_std
66 |
67 | rs = npr.RandomState(0)
68 | true_weights = np.zeros((observed_dimension, latent_dimension))
69 | for i in range(latent_dimension):
70 | true_weights[:, i] = np.sin(np.linspace(0, 4 + i * 3.2, observed_dimension))
71 |
72 | true_latents, data = sample(true_weights, n_samples, true_noise_var, rs)
73 |
74 | # Set up figure.
75 | fig2 = plt.figure(figsize=(6, 6), facecolor="white")
76 | ax_data = fig2.add_subplot(111, frameon=False)
77 | ax_data.matshow(data)
78 |
79 | fig1 = plt.figure(figsize=(12, 16), facecolor="white")
80 | ax_true_latents = fig1.add_subplot(411, frameon=False)
81 | ax_est_latents = fig1.add_subplot(412, frameon=False)
82 | ax_true_weights = fig1.add_subplot(413, frameon=False)
83 | ax_est_weights = fig1.add_subplot(414, frameon=False)
84 |
85 | plt.show(block=False)
86 | ax_true_weights.scatter(true_weights[:, 0], true_weights[:, 1])
87 | ax_true_weights.set_title("True weights")
88 | color_scatter(ax_true_latents, true_latents[0, :], true_latents[1, :])
89 | ax_true_latents.set_title("True latents")
90 | ax_true_latents.set_xticks([])
91 | ax_true_weights.set_xticks([])
92 | ax_true_latents.set_yticks([])
93 | ax_true_weights.set_yticks([])
94 |
95 | def objective(params):
96 | weight_matrix, latents, noise_std = unpack_params(params)
97 | return -logprob(weight_matrix, latents, noise_std, data) / n_samples
98 |
99 | def callback(params):
100 | weights, latents, noise_std = unpack_params(params)
101 | print(f"Log likelihood {-objective(params)}, noise_std {noise_std}")
102 | ax_est_weights.cla()
103 | ax_est_weights.scatter(weights[:, 0], weights[:, 1])
104 | ax_est_weights.set_title("Estimated weights")
105 | ax_est_latents.cla()
106 | color_scatter(ax_est_latents, latents[0, :], latents[1, :])
107 | ax_est_latents.set_title("Estimated latents")
108 | ax_est_weights.set_yticks([])
109 | ax_est_latents.set_yticks([])
110 | ax_est_weights.set_xticks([])
111 | ax_est_latents.set_xticks([])
112 | plt.draw()
113 | plt.pause(1.0 / 60.0)
114 |
115 | # Initialize and optimize model.
116 | rs = npr.RandomState(0)
117 | init_params = rs.randn(total_num_params)
118 | minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback)
119 | plt.pause(20)
120 |
--------------------------------------------------------------------------------
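A small usage sketch for make_ica_funs from ica.py above (assuming the examples directory is on the Python path): it draws a tiny dataset from the generative model described in the docstring and evaluates its log-probability, mainly to make the shapes concrete.

import autograd.numpy.random as npr

from ica import make_ica_funs

num_weights, sample, logprob, unpack_weights = make_ica_funs(observed_dimension=5, latent_dimension=2)
rs = npr.RandomState(0)
weights = unpack_weights(rs.randn(num_weights))        # shape (5, 2)
latents, observed = sample(weights, n_samples=3, noise_std=0.1, rs=rs)
print(latents.shape, observed.shape)                   # (2, 3) and (3, 5)
print(logprob(weights, latents, 0.1, observed))
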
/examples/logistic_regression.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | from autograd import grad
3 | from autograd.test_util import check_grads
4 |
5 |
6 | def sigmoid(x):
7 | return 0.5 * (np.tanh(x) + 1)
8 |
9 |
10 | def logistic_predictions(weights, inputs):
11 | # Outputs probability of a label being true according to logistic model.
12 | return sigmoid(np.dot(inputs, weights))
13 |
14 |
15 | def training_loss(weights):
16 | # Training loss is the negative log-likelihood of the training labels.
17 | preds = logistic_predictions(weights, inputs)
18 | label_probabilities = preds * targets + (1 - preds) * (1 - targets)
19 | return -np.sum(np.log(label_probabilities))
20 |
21 |
22 | # Build a toy dataset.
23 | inputs = np.array([[0.52, 1.12, 0.77], [0.88, -1.08, 0.15], [0.52, 0.06, -1.30], [0.74, -2.49, 1.39]])
24 | targets = np.array([True, True, False, True])
25 |
26 | # Build a function that returns gradients of training loss using autograd.
27 | training_gradient_fun = grad(training_loss)
28 |
29 | # Check the gradients numerically, just to be safe.
30 | weights = np.array([0.0, 0.0, 0.0])
31 | check_grads(training_loss, modes=["rev"])(weights)
32 |
33 | # Optimize weights using gradient descent.
34 | print("Initial loss:", training_loss(weights))
35 | for i in range(100):
36 | weights -= training_gradient_fun(weights) * 0.01
37 |
38 | print("Trained loss:", training_loss(weights))
39 |
--------------------------------------------------------------------------------
/examples/lstm.py:
--------------------------------------------------------------------------------
1 | """Implements the long-short term memory character model.
2 | This version vectorizes over multiple examples, but each string
3 | has a fixed length."""
4 |
5 | from os.path import dirname, join
6 |
7 | from rnn import build_dataset, concat_and_multiply, one_hot_to_string, sigmoid, string_to_one_hot
8 |
9 | import autograd.numpy as np
10 | import autograd.numpy.random as npr
11 | from autograd import grad
12 | from autograd.misc.optimizers import adam
13 | from autograd.scipy.special import logsumexp
14 |
15 |
16 | def init_lstm_params(input_size, state_size, output_size, param_scale=0.01, rs=npr.RandomState(0)):
17 | def rp(*shape):
18 | return rs.randn(*shape) * param_scale
19 |
20 | return {
21 | "init cells": rp(1, state_size),
22 | "init hiddens": rp(1, state_size),
23 | "change": rp(input_size + state_size + 1, state_size),
24 | "forget": rp(input_size + state_size + 1, state_size),
25 | "ingate": rp(input_size + state_size + 1, state_size),
26 | "outgate": rp(input_size + state_size + 1, state_size),
27 | "predict": rp(state_size + 1, output_size),
28 | }
29 |
30 |
31 | def lstm_predict(params, inputs):
32 | def update_lstm(input, hiddens, cells):
33 | change = np.tanh(concat_and_multiply(params["change"], input, hiddens))
34 | forget = sigmoid(concat_and_multiply(params["forget"], input, hiddens))
35 | ingate = sigmoid(concat_and_multiply(params["ingate"], input, hiddens))
36 | outgate = sigmoid(concat_and_multiply(params["outgate"], input, hiddens))
37 | cells = cells * forget + ingate * change
38 | hiddens = outgate * np.tanh(cells)
39 | return hiddens, cells
40 |
41 | def hiddens_to_output_probs(hiddens):
42 | output = concat_and_multiply(params["predict"], hiddens)
43 | return output - logsumexp(output, axis=1, keepdims=True) # Normalize log-probs.
44 |
45 | num_sequences = inputs.shape[1]
46 | hiddens = np.repeat(params["init hiddens"], num_sequences, axis=0)
47 | cells = np.repeat(params["init cells"], num_sequences, axis=0)
48 |
49 | output = [hiddens_to_output_probs(hiddens)]
50 | for input in inputs: # Iterate over time steps.
51 | hiddens, cells = update_lstm(input, hiddens, cells)
52 | output.append(hiddens_to_output_probs(hiddens))
53 | return output
54 |
55 |
56 | def lstm_log_likelihood(params, inputs, targets):
57 | logprobs = lstm_predict(params, inputs)
58 | loglik = 0.0
59 | num_time_steps, num_examples, _ = inputs.shape
60 | for t in range(num_time_steps):
61 | loglik += np.sum(logprobs[t] * targets[t])
62 | return loglik / (num_time_steps * num_examples)
63 |
64 |
65 | if __name__ == "__main__":
66 | num_chars = 128
67 |
68 | # Learn to predict our own source code.
69 | text_filename = join(dirname(__file__), "lstm.py")
70 | train_inputs = build_dataset(text_filename, sequence_length=30, alphabet_size=num_chars, max_lines=60)
71 |
72 | init_params = init_lstm_params(input_size=128, output_size=128, state_size=40, param_scale=0.01)
73 |
74 | def print_training_prediction(weights):
75 | print("Training text Predicted text")
76 | logprobs = np.asarray(lstm_predict(weights, train_inputs))
77 | for t in range(logprobs.shape[1]):
78 | training_text = one_hot_to_string(train_inputs[:, t, :])
79 | predicted_text = one_hot_to_string(logprobs[:, t, :])
80 | print(training_text.replace("\n", " ") + "|" + predicted_text.replace("\n", " "))
81 |
82 | def training_loss(params, iter):
83 | return -lstm_log_likelihood(params, train_inputs, train_inputs)
84 |
85 | def callback(weights, iter, gradient):
86 | if iter % 10 == 0:
87 | print("Iteration", iter, "Train loss:", training_loss(weights, 0))
88 | print_training_prediction(weights)
89 |
90 | # Build gradient of loss function using autograd.
91 | training_loss_grad = grad(training_loss)
92 |
93 | print("Training LSTM...")
94 | trained_params = adam(training_loss_grad, init_params, step_size=0.1, num_iters=1000, callback=callback)
95 |
96 | print()
97 | print("Generating text from LSTM...")
98 | num_letters = 30
99 | for t in range(20):
100 | text = ""
101 | for i in range(num_letters):
102 | seqs = string_to_one_hot(text, num_chars)[:, np.newaxis, :]
103 | logprobs = lstm_predict(trained_params, seqs)[-1].ravel()
104 | text += chr(npr.choice(len(logprobs), p=np.exp(logprobs)))
105 | print(text)
106 |
--------------------------------------------------------------------------------
/examples/natural_gradient_black_box_svi.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | # same BBSVI function!
4 | from black_box_svi import black_box_variational_inference
5 |
6 | import autograd.numpy as np
7 | import autograd.scipy.stats.norm as norm
8 | from autograd.misc.optimizers import adam, sgd
9 |
10 | if __name__ == "__main__":
11 | # Specify an inference problem by its unnormalized log-density.
12 | # it's difficult to see the benefit in low dimensions
13 | # model parameters are a mean and a log_sigma
14 | np.random.seed(42)
15 | obs_dim = 20
16 | Y = np.random.randn(obs_dim, obs_dim).dot(np.random.randn(obs_dim))
17 |
18 | def log_density(x, t):
19 | mu, log_sigma = x[:, :obs_dim], x[:, obs_dim:]
20 | sigma_density = np.sum(norm.logpdf(log_sigma, 0, 1.35), axis=1)
21 | mu_density = np.sum(norm.logpdf(Y, mu, np.exp(log_sigma)), axis=1)
22 | return sigma_density + mu_density
23 |
24 | # Build variational objective.
25 | D = obs_dim * 2 # dimension of our posterior
26 | objective, gradient, unpack_params = black_box_variational_inference(log_density, D, num_samples=2000)
27 |
28 | # Define the natural gradient
29 | # The natural gradient of the ELBO is the gradient of the elbo,
30 | # preconditioned by the inverse Fisher Information Matrix. The Fisher,
31 | # in the case of a diagonal gaussian, is a diagonal matrix that is a
32 | # simple function of the variance. Intuitively, statistical distance
33 | # created by perturbing the mean of an independent Gaussian is
34 | # determined by how wide the distribution is along that dimension ---
35 | # the wider the distribution, the less sensitive the statistical distance is
36 | # to perturbations of the mean; the narrower the distribution, the more
37 | # the statistical distance changes when you perturb the mean (imagine
38 | # an extremely narrow Gaussian --- basically a spike. The KL between
39 | # this Gaussian and a Gaussian $\epsilon$ away in location can be big ---
40 | # moving the Gaussian could significantly reduce overlap in support
41 | # which corresponds to a greater statistical distance).
42 | #
43 | # When we want to move in directions of steepest ascent, we multiply by
44 | # the inverse fisher --- that way we make quicker progress when the
45 | # variance is wide, and we scale down our step size when the variance
46 | # is small (which leads to more robust/less chaotic ascent).
47 | def fisher_diag(lam):
48 | mu, log_sigma = unpack_params(lam)
49 | return np.concatenate([np.exp(-2.0 * log_sigma), np.ones(len(log_sigma)) * 2])
50 |
51 | # simple! basically free!
52 | natural_gradient = lambda lam, i: (1.0 / fisher_diag(lam)) * gradient(lam, i)
53 |
54 | # function for keeping track of callback ELBO values (for plotting below)
55 | def optimize_and_lls(optfun):
56 | num_iters = 200
57 | elbos = []
58 |
59 | def callback(params, t, g):
60 | elbo_val = -objective(params, t)
61 | elbos.append(elbo_val)
62 | if t % 50 == 0:
63 | print(f"Iteration {t} lower bound {elbo_val}")
64 |
65 | init_mean = -1 * np.ones(D)
66 | init_log_std = -5 * np.ones(D)
67 | init_var_params = np.concatenate([init_mean, init_log_std])
68 | variational_params = optfun(num_iters, init_var_params, callback)
69 | return np.array(elbos)
70 |
71 | # let's optimize this with a few different step sizes
72 | elbo_lists = []
73 | step_sizes = [0.1, 0.25, 0.5]
74 | for step_size in step_sizes:
75 | # optimize with standard gradient + adam
76 | optfun = lambda n, init, cb: adam(gradient, init, step_size=step_size, num_iters=n, callback=cb)
77 | standard_lls = optimize_and_lls(optfun)
78 |
79 | # optimize with natural gradient + sgd, no momentum
80 | optnat = lambda n, init, cb: sgd(
81 | natural_gradient, init, step_size=step_size, num_iters=n, callback=cb, mass=0.001
82 | )
83 | natural_lls = optimize_and_lls(optnat)
84 | elbo_lists.append((standard_lls, natural_lls))
85 |
86 | # visually compare the ELBO
87 | plt.figure(figsize=(12, 8))
88 | colors = ["b", "k", "g"]
89 | for col, ss, (stand_lls, nat_lls) in zip(colors, step_sizes, elbo_lists):
90 | plt.plot(
91 | np.arange(len(stand_lls)),
92 | stand_lls,
93 | "--",
94 | label="standard (adam, step-size = %2.2f)" % ss,
95 | alpha=0.5,
96 | c=col,
97 | )
98 | plt.plot(np.arange(len(nat_lls)), nat_lls, "-", label="natural (sgd, step-size = %2.2f)" % ss, c=col)
99 |
100 | llrange = natural_lls.max() - natural_lls.min()
101 | plt.ylim((natural_lls.max() - llrange * 0.1, natural_lls.max() + 10))
102 | plt.xlabel("optimization iteration")
103 | plt.ylabel("ELBO")
104 | plt.legend(loc="lower right")
105 | plt.title("%d dimensional posterior" % D)
106 | plt.show()
107 |
--------------------------------------------------------------------------------
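The comment above claims that, for a Gaussian parameterized by a mean and a log standard deviation, the Fisher information is diagonal with entries exp(-2 * log_sigma) for the mean and 2 for log_sigma, which is exactly what fisher_diag returns. A rough one-dimensional Monte Carlo check of that claim, estimating the Fisher as the expected outer product of the score function with autograd (the sample size and variable names here are just for illustration):

import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.norm as norm
from autograd import grad


def log_q(params, x):
    mu, log_sigma = params[0], params[1]
    return norm.logpdf(x, mu, np.exp(log_sigma))


mu, log_sigma = 0.3, -0.5
score = grad(log_q)  # gradient of log q with respect to (mu, log_sigma)

rs = npr.RandomState(0)
samples = rs.randn(5000) * np.exp(log_sigma) + mu
scores = np.array([score(np.array([mu, log_sigma]), x) for x in samples])

print("Monte Carlo Fisher diagonal:", np.mean(scores**2, axis=0))
print("analytic Fisher diagonal:   ", [np.exp(-2 * log_sigma), 2.0])
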
/examples/negative_binomial_maxlike.py:
--------------------------------------------------------------------------------
1 | import scipy.optimize
2 |
3 | import autograd.numpy as np
4 | import autograd.numpy.random as npr
5 | from autograd import grad
6 | from autograd.scipy.special import gammaln
7 |
8 | # The code in this example implements a method for finding a stationary point of
9 | # the negative binomial likelihood via Newton's method, described here:
10 | # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Maximum_likelihood_estimation
11 |
12 |
13 | def newton(f, x0):
14 | # wrap scipy.optimize.newton with our automatic derivatives
15 | return scipy.optimize.newton(f, x0, fprime=grad(f), fprime2=grad(grad(f)))
16 |
17 |
18 | def negbin_loglike(r, p, x):
19 | # the negative binomial log likelihood we want to maximize
20 | return gammaln(r + x) - gammaln(r) - gammaln(x + 1) + x * np.log(p) + r * np.log(1 - p)
21 |
22 |
23 | def negbin_sample(r, p, size):
24 | # a negative binomial is a gamma-compound-Poisson
25 | return npr.poisson(npr.gamma(r, p / (1 - p), size=size))
26 |
27 |
28 | def fit_maxlike(x, r_guess):
29 | # follows Wikipedia's section on negative binomial max likelihood
30 | assert np.var(x) > np.mean(x), "Likelihood-maximizing parameters don't exist!"
31 | loglike = lambda r, p: np.sum(negbin_loglike(r, p, x))
32 | p = lambda r: np.sum(x) / np.sum(r + x)
33 | rprime = lambda r: grad(loglike)(r, p(r))
34 | r = newton(rprime, r_guess)
35 | return r, p(r)
36 |
37 |
38 | if __name__ == "__main__":
39 | # generate data
40 | npr.seed(0)
41 | data = negbin_sample(r=5, p=0.5, size=1000)
42 |
43 | # fit likelihood-extremizing parameters
44 | r, p = fit_maxlike(data, r_guess=1)
45 |
46 | # report fit
47 | print("Fit parameters:")
48 | print(f"r={r}, p={p}")
49 |
50 | print("Check that we are at a local stationary point:")
51 | loglike = lambda r, p: np.sum(negbin_loglike(r, p, data))
52 | grad_both = grad(loglike, argnum=(0, 1))
53 | print(grad_both(r, p))
54 |
55 | import matplotlib.pyplot as plt
56 |
57 | xm = data.max()
58 | plt.figure()
59 | plt.hist(data, bins=np.arange(xm + 1) - 0.5, density=True, label="normalized data counts")
60 | plt.xlim(0, xm)
61 | plt.plot(np.arange(xm), np.exp(negbin_loglike(r, p, np.arange(xm))), label="maxlike fit")
62 | plt.xlabel("k")
63 | plt.ylabel("p(k)")
64 | plt.legend(loc="best")
65 | plt.show()
66 |
--------------------------------------------------------------------------------
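The newton wrapper above captures the pattern this example relies on: scipy.optimize.newton only needs callables for the first and second derivatives, and autograd supplies them. A self-contained check on a function whose stationary point is known, f(x) = log(x) - x, which is maximized at x = 1:

import scipy.optimize

import autograd.numpy as np
from autograd import grad

f = lambda x: np.log(x) - x  # f'(x) = 1/x - 1, so the stationary point is x = 1
fprime = grad(f)
root = scipy.optimize.newton(fprime, x0=0.5, fprime=grad(fprime), fprime2=grad(grad(fprime)))
print(root)  # approximately 1.0
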
/examples/neural_net.py:
--------------------------------------------------------------------------------
1 | """A multi-layer perceptron for classification of MNIST handwritten digits."""
2 |
3 | from data import load_mnist
4 |
5 | import autograd.numpy as np
6 | import autograd.numpy.random as npr
7 | from autograd import grad
8 | from autograd.misc.flatten import flatten
9 | from autograd.misc.optimizers import adam
10 | from autograd.scipy.special import logsumexp
11 |
12 |
13 | def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)):
14 | """Build a list of (weights, biases) tuples,
15 | one for each layer in the net."""
16 | return [
17 | (
18 | scale * rs.randn(m, n), # weight matrix
19 | scale * rs.randn(n),
20 | ) # bias vector
21 | for m, n in zip(layer_sizes[:-1], layer_sizes[1:])
22 | ]
23 |
24 |
25 | def neural_net_predict(params, inputs):
26 | """Implements a deep neural network for classification.
27 | params is a list of (weights, bias) tuples.
28 | inputs is an (N x D) matrix.
29 | returns normalized class log-probabilities."""
30 | for W, b in params:
31 | outputs = np.dot(inputs, W) + b
32 | inputs = np.tanh(outputs)
33 | return outputs - logsumexp(outputs, axis=1, keepdims=True)
34 |
35 |
36 | def l2_norm(params):
37 | """Computes l2 norm of params by flattening them into a vector."""
38 | flattened, _ = flatten(params)
39 | return np.dot(flattened, flattened)
40 |
41 |
42 | def log_posterior(params, inputs, targets, L2_reg):
43 | log_prior = -L2_reg * l2_norm(params)
44 | log_lik = np.sum(neural_net_predict(params, inputs) * targets)
45 | return log_prior + log_lik
46 |
47 |
48 | def accuracy(params, inputs, targets):
49 | target_class = np.argmax(targets, axis=1)
50 | predicted_class = np.argmax(neural_net_predict(params, inputs), axis=1)
51 | return np.mean(predicted_class == target_class)
52 |
53 |
54 | if __name__ == "__main__":
55 | # Model parameters
56 | layer_sizes = [784, 200, 100, 10]
57 | L2_reg = 1.0
58 |
59 | # Training parameters
60 | param_scale = 0.1
61 | batch_size = 256
62 | num_epochs = 5
63 | step_size = 0.001
64 |
65 | print("Loading training data...")
66 | N, train_images, train_labels, test_images, test_labels = load_mnist()
67 |
68 | init_params = init_random_params(param_scale, layer_sizes)
69 |
70 | num_batches = int(np.ceil(len(train_images) / batch_size))
71 |
72 | def batch_indices(iter):
73 | idx = iter % num_batches
74 | return slice(idx * batch_size, (idx + 1) * batch_size)
75 |
76 | # Define training objective
77 | def objective(params, iter):
78 | idx = batch_indices(iter)
79 | return -log_posterior(params, train_images[idx], train_labels[idx], L2_reg)
80 |
81 | # Get gradient of objective using autograd.
82 | objective_grad = grad(objective)
83 |
84 | print(" Epoch | Train accuracy | Test accuracy ")
85 |
86 | def print_perf(params, iter, gradient):
87 | if iter % num_batches == 0:
88 | train_acc = accuracy(params, train_images, train_labels)
89 | test_acc = accuracy(params, test_images, test_labels)
90 | print(f"{iter // num_batches:15}|{train_acc:20}|{test_acc:20}")
91 |
92 | # The optimizers provided can optimize lists, tuples, or dicts of parameters.
93 | optimized_params = adam(
94 | objective_grad,
95 | init_params,
96 | step_size=step_size,
97 | num_iters=num_epochs * num_batches,
98 | callback=print_perf,
99 | )
100 |
--------------------------------------------------------------------------------
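The comment above about optimizing lists, tuples, or dicts of parameters works because the optimizers flatten the parameter container internally (via autograd.misc.flatten) and unflatten it for each callback and for the final result. A minimal illustration with a dict-valued parameter (the toy loss here is made up for the sketch):

import autograd.numpy as np
from autograd import grad
from autograd.misc.optimizers import adam


def toy_loss(params, t):
    return np.sum(params["w"] ** 2) + np.sum((params["b"] - 1.0) ** 2)


init = {"w": np.ones(3), "b": np.zeros(2)}
trained = adam(grad(toy_loss), init, step_size=0.1, num_iters=200)
print(trained["w"], trained["b"])  # w -> ~0, b -> ~1
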
/examples/neural_net_regression.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import autograd.numpy as np
4 | import autograd.numpy.random as npr
5 | import autograd.scipy.stats.norm as norm
6 | from autograd import grad
7 | from autograd.misc import flatten
8 | from autograd.misc.optimizers import adam
9 |
10 |
11 | def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)):
12 | """Build a list of (weights, biases) tuples, one for each layer."""
13 | return [
14 | (
15 | rs.randn(insize, outsize) * scale, # weight matrix
16 | rs.randn(outsize) * scale,
17 | ) # bias vector
18 | for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])
19 | ]
20 |
21 |
22 | def nn_predict(params, inputs, nonlinearity=np.tanh):
23 | for W, b in params:
24 | outputs = np.dot(inputs, W) + b
25 | inputs = nonlinearity(outputs)
26 | return outputs
27 |
28 |
29 | def log_gaussian(params, scale):
30 | flat_params, _ = flatten(params)
31 | return np.sum(norm.logpdf(flat_params, 0, scale))
32 |
33 |
34 | def logprob(weights, inputs, targets, noise_scale=0.1):
35 | predictions = nn_predict(weights, inputs)
36 | return np.sum(norm.logpdf(predictions, targets, noise_scale))
37 |
38 |
39 | def build_toy_dataset(n_data=80, noise_std=0.1):
40 | rs = npr.RandomState(0)
41 | inputs = np.concatenate([np.linspace(0, 3, num=n_data // 2), np.linspace(6, 8, num=n_data // 2)])
42 | targets = np.cos(inputs) + rs.randn(n_data) * noise_std
43 | inputs = (inputs - 4.0) / 2.0
44 | inputs = inputs[:, np.newaxis]
45 | targets = targets[:, np.newaxis] / 2.0
46 | return inputs, targets
47 |
48 |
49 | if __name__ == "__main__":
50 | init_scale = 0.1
51 | weight_prior_variance = 10.0
52 | init_params = init_random_params(init_scale, layer_sizes=[1, 4, 4, 1])
53 |
54 | inputs, targets = build_toy_dataset()
55 |
56 | def objective(weights, t):
57 | return -logprob(weights, inputs, targets) - log_gaussian(weights, weight_prior_variance)
58 |
59 | print(grad(objective)(init_params, 0))
60 |
61 | # Set up figure.
62 | fig = plt.figure(figsize=(12, 8), facecolor="white")
63 | ax = fig.add_subplot(111, frameon=False)
64 | plt.show(block=False)
65 |
66 | def callback(params, t, g):
67 | print(f"Iteration {t} log likelihood {-objective(params, t)}")
68 |
69 | # Plot data and functions.
70 | plt.cla()
71 | ax.plot(inputs.ravel(), targets.ravel(), "bx", ms=12)
72 | plot_inputs = np.reshape(np.linspace(-7, 7, num=300), (300, 1))
73 | outputs = nn_predict(params, plot_inputs)
74 | ax.plot(plot_inputs, outputs, "r", lw=3)
75 | ax.set_ylim([-1, 1])
76 | plt.draw()
77 | plt.pause(1.0 / 60.0)
78 |
79 | print("Optimizing network parameters...")
80 | optimized_params = adam(grad(objective), init_params, step_size=0.01, num_iters=1000, callback=callback)
81 |
--------------------------------------------------------------------------------
/examples/ode_net.py:
--------------------------------------------------------------------------------
1 | # A demo of gradients through scipy.integrate.odeint,
2 | # estimating the dynamics of a system given a trajectory.
3 |
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as npo
7 |
8 | import autograd.numpy as np
9 | import autograd.numpy.random as npr
10 | from autograd import grad
11 | from autograd.builtins import tuple
12 | from autograd.misc.optimizers import adam
13 | from autograd.scipy.integrate import odeint
14 |
15 | N = 30 # Dataset size
16 | D = 2 # Data dimension
17 | max_T = 1.5
18 |
19 |
20 | # Two-dimensional damped oscillator
21 | def func(y, t0, A):
22 | return np.dot(y**3, A)
23 |
24 |
25 | def nn_predict(inputs, t, params):
26 | for W, b in params:
27 | outputs = np.dot(inputs, W) + b
28 | inputs = np.maximum(0, outputs)
29 | return outputs
30 |
31 |
32 | def init_nn_params(scale, layer_sizes, rs=npr.RandomState(0)):
33 | """Build a list of (weights, biases) tuples, one for each layer."""
34 | return [
35 | (
36 | rs.randn(insize, outsize) * scale, # weight matrix
37 | rs.randn(outsize) * scale,
38 | ) # bias vector
39 | for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])
40 | ]
41 |
42 |
43 | # Define neural ODE model.
44 | def ode_pred(params, y0, t):
45 | return odeint(nn_predict, y0, t, tuple((params,)), rtol=0.01)
46 |
47 |
48 | def L1_loss(pred, targets):
49 | return np.mean(np.abs(pred - targets))
50 |
51 |
52 | if __name__ == "__main__":
53 | # Generate data from true dynamics.
54 | true_y0 = np.array([2.0, 0.0]).T
55 | t = np.linspace(0.0, max_T, N)
56 | true_A = np.array([[-0.1, 2.0], [-2.0, -0.1]])
57 | true_y = odeint(func, true_y0, t, args=(true_A,))
58 |
59 | def train_loss(params, iter):
60 | pred = ode_pred(params, true_y0, t)
61 | return L1_loss(pred, true_y)
62 |
63 | # Set up figure
64 | fig = plt.figure(figsize=(12, 4), facecolor="white")
65 | ax_traj = fig.add_subplot(131, frameon=False)
66 | ax_phase = fig.add_subplot(132, frameon=False)
67 | ax_vecfield = fig.add_subplot(133, frameon=False)
68 | plt.show(block=False)
69 |
70 | # Plots data and learned dynamics.
71 | def callback(params, iter, g):
72 | pred = ode_pred(params, true_y0, t)
73 |
74 | print(f"Iteration {iter:d} train loss {L1_loss(pred, true_y):.6f}")
75 |
76 | ax_traj.cla()
77 | ax_traj.set_title("Trajectories")
78 | ax_traj.set_xlabel("t")
79 | ax_traj.set_ylabel("x,y")
80 | ax_traj.plot(t, true_y[:, 0], "-", t, true_y[:, 1], "g-")
81 | ax_traj.plot(t, pred[:, 0], "--", t, pred[:, 1], "b--")
82 | ax_traj.set_xlim(t.min(), t.max())
83 | ax_traj.set_ylim(-2, 2)
84 | ax_traj.xaxis.set_ticklabels([])
85 | ax_traj.yaxis.set_ticklabels([])
86 | ax_traj.legend()
87 |
88 | ax_phase.cla()
89 | ax_phase.set_title("Phase Portrait")
90 | ax_phase.set_xlabel("x")
91 | ax_phase.set_ylabel("y")
92 | ax_phase.plot(true_y[:, 0], true_y[:, 1], "g-")
93 | ax_phase.plot(pred[:, 0], pred[:, 1], "b--")
94 | ax_phase.set_xlim(-2, 2)
95 | ax_phase.set_ylim(-2, 2)
96 | ax_phase.xaxis.set_ticklabels([])
97 | ax_phase.yaxis.set_ticklabels([])
98 |
99 | ax_vecfield.cla()
100 | ax_vecfield.set_title("Learned Vector Field")
101 | ax_vecfield.set_xlabel("x")
102 | ax_vecfield.set_ylabel("y")
103 | ax_vecfield.xaxis.set_ticklabels([])
104 | ax_vecfield.yaxis.set_ticklabels([])
105 |
106 | # vector field plot
107 | y, x = npo.mgrid[-2:2:21j, -2:2:21j]
108 | dydt = nn_predict(np.stack([x, y], -1).reshape(21 * 21, 2), 0, params).reshape(-1, 2)
109 | mag = np.sqrt(dydt[:, 0] ** 2 + dydt[:, 1] ** 2).reshape(-1, 1)
110 | dydt = dydt / mag
111 | dydt = dydt.reshape(21, 21, 2)
112 |
113 | ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
114 | ax_vecfield.set_xlim(-2, 2)
115 | ax_vecfield.set_ylim(-2, 2)
116 |
117 | fig.tight_layout()
118 | plt.draw()
119 | plt.pause(0.001)
120 |
121 | # Train neural net dynamics to match data.
122 | init_params = init_nn_params(0.1, layer_sizes=[D, 150, D])
123 | optimized_params = adam(grad(train_loss), init_params, num_iters=1000, callback=callback)
124 |
--------------------------------------------------------------------------------
/examples/ode_net_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/ode_net_demo.png
--------------------------------------------------------------------------------
/examples/print_trace.py:
--------------------------------------------------------------------------------
1 | """Demonstrates how to use the tracer module, independent of autodiff, by
2 | creating a trace that prints out functions and their arguments as they're being
3 | evaluated"""
4 |
5 | import autograd.numpy as np # autograd has already wrapped numpy for us
6 | from autograd.tracer import Node, trace
7 |
8 |
9 | class PrintNode(Node):
10 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
11 | self.varname_generator = parents[0].varname_generator
12 | self.varname = next(self.varname_generator)
13 | args_or_vars = list(args)
14 | for argnum, parent in zip(parent_argnums, parents):
15 | args_or_vars[argnum] = parent.varname
16 | print("{} = {}({}) = {}".format(self.varname, fun.__name__, ",".join(map(str, args_or_vars)), value))
17 |
18 | def initialize_root(self, x):
19 | self.varname_generator = make_varname_generator()
20 | self.varname = next(self.varname_generator)
21 | print(f"{self.varname} = {x}")
22 |
23 |
24 | def make_varname_generator():
25 | for i in range(65, 91):
26 | yield chr(i)
27 | raise Exception("Ran out of alphabet!")
28 |
29 |
30 | def print_trace(f, x):
31 | start_node = PrintNode.new_root(x)
32 | trace(start_node, f, x)
33 | print()
34 |
35 |
36 | def avg(x, y):
37 | return (x + y) / 2
38 |
39 |
40 | def fun(x):
41 | y = np.sin(x + x)
42 | return avg(y, y)
43 |
44 |
45 | print_trace(fun, 1.23)
46 |
47 | # Traces can be nested, so we can also trace through grad(fun)
48 | from autograd import grad
49 |
50 | print_trace(grad(fun), 1.0)
51 |
--------------------------------------------------------------------------------
/examples/rkhs.py:
--------------------------------------------------------------------------------
1 | """
2 | Inferring a function from a reproducing kernel Hilbert space (RKHS) by taking
3 | gradients of eval with respect to the function-valued argument
4 | """
5 |
6 | import autograd.numpy as np
7 | import autograd.numpy.random as npr
8 | from autograd import grad
9 | from autograd.extend import Box, VSpace, defvjp, primitive
10 | from autograd.util import func
11 |
12 |
13 | class RKHSFun:
14 | def __init__(self, kernel, alphas={}):
15 | self.alphas = alphas
16 | self.kernel = kernel
17 | self.vs = RKHSFunVSpace(self)
18 |
19 | @primitive
20 | def __call__(self, x):
21 | return sum([a * self.kernel(x, x_repr) for x_repr, a in self.alphas.items()], 0.0)
22 |
23 | def __add__(self, f):
24 | return self.vs.add(self, f)
25 |
26 | def __mul__(self, a):
27 | return self.vs.scalar_mul(self, a)
28 |
29 |
30 | # TODO: add vjp of __call__ wrt x (and show it in action)
31 | defvjp(func(RKHSFun.__call__), lambda ans, f, x: lambda g: RKHSFun(f.kernel, {x: 1}) * g)
32 |
33 |
34 | class RKHSFunBox(Box, RKHSFun):
35 | @property
36 | def kernel(self):
37 | return self._value.kernel
38 |
39 |
40 | RKHSFunBox.register(RKHSFun)
41 |
42 |
43 | class RKHSFunVSpace(VSpace):
44 | def __init__(self, value):
45 | self.kernel = value.kernel
46 |
47 | def zeros(self):
48 | return RKHSFun(self.kernel)
49 |
50 | def randn(self):
51 | # These arbitrary vectors are not analogous to randn in any meaningful way
52 | N = npr.randint(1, 3)
53 | return RKHSFun(self.kernel, dict(zip(npr.randn(N), npr.randn(N))))
54 |
55 | def _add(self, f, g):
56 | assert f.kernel is g.kernel
57 | return RKHSFun(f.kernel, add_dicts(f.alphas, g.alphas))
58 |
59 | def _scalar_mul(self, f, a):
60 | return RKHSFun(f.kernel, {x: a * a_cur for x, a_cur in f.alphas.items()})
61 |
62 | def _inner_prod(self, f, g):
63 | assert f.kernel is g.kernel
64 | return sum(
65 | [a1 * a2 * f.kernel(x1, x2) for x1, a1 in f.alphas.items() for x2, a2 in g.alphas.items()], 0.0
66 | )
67 |
68 |
69 | RKHSFunVSpace.register(RKHSFun)
70 |
71 |
72 | def add_dicts(d1, d2):
73 | d = {}
74 | for k, v in list(d1.items()) + list(d2.items()):
75 | d[k] = d[k] + v if k in d else v
76 | return d
77 |
78 |
79 | if __name__ == "__main__":
80 |
81 | def sq_exp_kernel(x1, x2):
82 | return np.exp(-((x1 - x2) ** 2))
83 |
84 | xs = range(5)
85 | ys = [1, 2, 3, 2, 1]
86 |
87 | def logprob(f, xs, ys):
88 | return -sum((f(x) - y) ** 2 for x, y in zip(xs, ys))
89 |
90 | f = RKHSFun(sq_exp_kernel)
91 | for i in range(100):
92 | f = f + grad(logprob)(f, xs, ys) * 0.01
93 |
94 | for x, y in zip(xs, ys):
95 | print(f"{x}\t{y}\t{f(x)}")
96 |
--------------------------------------------------------------------------------
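The reproducing property that drives the example above: the gradient of the evaluation functional f -> f(x) is the representer k(x, .), which with this class shows up as an RKHSFun carrying a single representer point with coefficient 1. A short check (assuming the examples directory is on the Python path so rkhs is importable):

import autograd.numpy as np
from autograd import grad

from rkhs import RKHSFun

kernel = lambda x1, x2: np.exp(-((x1 - x2) ** 2))
f = RKHSFun(kernel, {0.0: 1.0, 2.0: -0.5})

df = grad(lambda fun: fun(1.0))(f)
print(df.alphas)  # {1.0: 1.0}, i.e. the representer of evaluation at x = 1.0
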
/examples/rnn.py:
--------------------------------------------------------------------------------
1 | """Implements the long-short term memory character model.
2 | This version vectorizes over multiple examples, but each string
3 | has a fixed length."""
4 |
5 | from os.path import dirname, join
6 |
7 | import autograd.numpy as np
8 | import autograd.numpy.random as npr
9 | from autograd import grad
10 | from autograd.misc.optimizers import adam
11 | from autograd.scipy.special import logsumexp
12 |
13 | ### Helper functions #################
14 |
15 |
16 | def sigmoid(x):
17 | return 0.5 * (np.tanh(x) + 1.0) # Output ranges from 0 to 1.
18 |
19 |
20 | def concat_and_multiply(weights, *args):
21 | cat_state = np.hstack(args + (np.ones((args[0].shape[0], 1)),))
22 | return np.dot(cat_state, weights)
23 |
24 |
25 | ### Define recurrent neural net #######
26 |
27 |
28 | def create_rnn_params(input_size, state_size, output_size, param_scale=0.01, rs=npr.RandomState(0)):
29 | return {
30 | "init hiddens": rs.randn(1, state_size) * param_scale,
31 | "change": rs.randn(input_size + state_size + 1, state_size) * param_scale,
32 | "predict": rs.randn(state_size + 1, output_size) * param_scale,
33 | }
34 |
35 |
36 | def rnn_predict(params, inputs):
37 | def update_rnn(input, hiddens):
38 | return np.tanh(concat_and_multiply(params["change"], input, hiddens))
39 |
40 | def hiddens_to_output_probs(hiddens):
41 | output = concat_and_multiply(params["predict"], hiddens)
42 | return output - logsumexp(output, axis=1, keepdims=True) # Normalize log-probs.
43 |
44 | num_sequences = inputs.shape[1]
45 | hiddens = np.repeat(params["init hiddens"], num_sequences, axis=0)
46 | output = [hiddens_to_output_probs(hiddens)]
47 |
48 | for input in inputs: # Iterate over time steps.
49 | hiddens = update_rnn(input, hiddens)
50 | output.append(hiddens_to_output_probs(hiddens))
51 | return output
52 |
53 |
54 | def rnn_log_likelihood(params, inputs, targets):
55 | logprobs = rnn_predict(params, inputs)
56 | loglik = 0.0
57 | num_time_steps, num_examples, _ = inputs.shape
58 | for t in range(num_time_steps):
59 | loglik += np.sum(logprobs[t] * targets[t])
60 | return loglik / (num_time_steps * num_examples)
61 |
62 |
63 | ### Dataset setup ##################
64 |
65 |
66 | def string_to_one_hot(string, maxchar):
67 | """Converts an ASCII string to a one-of-k encoding."""
68 | ascii = np.array([ord(c) for c in string]).T
69 | return np.array(ascii[:, None] == np.arange(maxchar)[None, :], dtype=int)
70 |
71 |
72 | def one_hot_to_string(one_hot_matrix):
73 | return "".join([chr(np.argmax(c)) for c in one_hot_matrix])
74 |
75 |
76 | def build_dataset(filename, sequence_length, alphabet_size, max_lines=-1):
77 | """Loads a text file, and turns each line into an encoded sequence."""
78 | with open(filename) as f:
79 | content = f.readlines()
80 | content = content[:max_lines]
81 | content = [line for line in content if len(line) > 2] # Remove blank lines
82 | seqs = np.zeros((sequence_length, len(content), alphabet_size))
83 | for ix, line in enumerate(content):
84 | padded_line = (line + " " * sequence_length)[:sequence_length]
85 | seqs[:, ix, :] = string_to_one_hot(padded_line, alphabet_size)
86 | return seqs
87 |
88 |
89 | if __name__ == "__main__":
90 | num_chars = 128
91 |
92 | # Learn to predict our own source code.
93 | text_filename = join(dirname(__file__), "rnn.py")
94 | train_inputs = build_dataset(text_filename, sequence_length=30, alphabet_size=num_chars, max_lines=60)
95 |
96 | init_params = create_rnn_params(input_size=128, output_size=128, state_size=40, param_scale=0.01)
97 |
98 | def print_training_prediction(weights):
99 | print("Training text Predicted text")
100 | logprobs = np.asarray(rnn_predict(weights, train_inputs))
101 | for t in range(logprobs.shape[1]):
102 | training_text = one_hot_to_string(train_inputs[:, t, :])
103 | predicted_text = one_hot_to_string(logprobs[:, t, :])
104 | print(training_text.replace("\n", " ") + "|" + predicted_text.replace("\n", " "))
105 |
106 | def training_loss(params, iter):
107 | return -rnn_log_likelihood(params, train_inputs, train_inputs)
108 |
109 | def callback(weights, iter, gradient):
110 | if iter % 10 == 0:
111 | print("Iteration", iter, "Train loss:", training_loss(weights, 0))
112 | print_training_prediction(weights)
113 |
114 | # Build gradient of loss function using autograd.
115 | training_loss_grad = grad(training_loss)
116 |
117 | print("Training RNN...")
118 | trained_params = adam(training_loss_grad, init_params, step_size=0.1, num_iters=1000, callback=callback)
119 |
120 | print()
121 | print("Generating text from RNN...")
122 | num_letters = 30
123 | for t in range(20):
124 | text = ""
125 | for i in range(num_letters):
126 | seqs = string_to_one_hot(text, num_chars)[:, np.newaxis, :]
127 | logprobs = rnn_predict(trained_params, seqs)[-1].ravel()
128 | text += chr(npr.choice(len(logprobs), p=np.exp(logprobs)))
129 | print(text)
130 |
--------------------------------------------------------------------------------
/examples/rosenbrock.py:
--------------------------------------------------------------------------------
1 | from scipy.optimize import minimize
2 |
3 | import autograd.numpy as np
4 | from autograd import value_and_grad
5 |
6 |
7 | def rosenbrock(x):
8 | return 100 * (x[1] - x[0] ** 2) ** 2 + (1 - x[0]) ** 2
9 |
10 |
11 | # Build a function that also returns gradients using autograd.
12 | rosenbrock_with_grad = value_and_grad(rosenbrock)
13 |
14 | # Optimize using conjugate gradients.
15 | result = minimize(rosenbrock_with_grad, x0=np.array([0.0, 0.0]), jac=True, method="CG")
16 | print(f"Found minimum at {result.x}")
17 |
--------------------------------------------------------------------------------
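Note on the example above: value_and_grad(rosenbrock) returns a function that yields a (value, gradient) pair, which is exactly the interface scipy.optimize.minimize expects from its objective when jac=True. A minimal illustrative sketch of that contract:

import autograd.numpy as np
from autograd import value_and_grad


def rosenbrock(x):
    return 100 * (x[1] - x[0] ** 2) ** 2 + (1 - x[0]) ** 2


# value_and_grad(f)(x) computes f(x) and its gradient in one pass and
# returns them as a single (value, gradient) tuple.
val, g = value_and_grad(rosenbrock)(np.array([0.0, 0.0]))
print(val, g)  # 1.0 and [-2.  0.] at the starting point used above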
/examples/sinusoid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/sinusoid.png
--------------------------------------------------------------------------------
/examples/sinusoid.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import autograd.numpy as np
4 | from autograd import grad
5 |
6 |
7 | def fun(x):
8 | return np.sin(x)
9 |
10 |
11 | d_fun = grad(fun) # First derivative
12 | dd_fun = grad(d_fun) # Second derivative
13 |
14 | x = np.linspace(-10, 10, 100)
15 | plt.plot(x, list(map(fun, x)), x, list(map(d_fun, x)), x, list(map(dd_fun, x)))
16 |
17 | plt.xlim([-10, 10])
18 | plt.ylim([-1.2, 1.2])
19 | plt.axis("off")
20 | plt.savefig("sinusoid.png")
21 | plt.clf()
22 |
23 |
24 | # Taylor approximation to sin function
25 | def fun(x):
26 | currterm = x
27 | ans = currterm
28 | for i in range(1000):
29 | print(i, end=" ")
30 | currterm = -currterm * x**2 / ((2 * i + 3) * (2 * i + 2))
31 | ans = ans + currterm
32 | if np.abs(currterm) < 0.2:
33 | break # (Very generous tolerance!)
34 |
35 | return ans
36 |
37 |
38 | d_fun = grad(fun)
39 | dd_fun = grad(d_fun)
40 |
41 | x = np.linspace(-10, 10, 100)
42 | plt.plot(x, list(map(fun, x)), x, list(map(d_fun, x)), x, list(map(dd_fun, x)))
43 |
44 | plt.xlim([-10, 10])
45 | plt.ylim([-1.2, 1.2])
46 | plt.axis("off")
47 | plt.savefig("sinusoid_taylor.png")
48 | plt.clf()
49 |
--------------------------------------------------------------------------------
/examples/sinusoid_taylor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/sinusoid_taylor.png
--------------------------------------------------------------------------------
/examples/tanh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/tanh.png
--------------------------------------------------------------------------------
/examples/tanh.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import autograd.numpy as np
4 | from autograd import elementwise_grad as egrad
5 |
6 | """
7 | Mathematically we can only take gradients of scalar-valued functions, but
8 | autograd's elementwise_grad function also handles numpy's familiar vectorization
9 | of scalar functions, which is used in this example.
10 |
11 | To be precise, elementwise_grad(fun)(x) always returns the value of a
12 | vector-Jacobian product, where the Jacobian of fun is evaluated at x and the
13 | vector is an all-ones vector with the same size as the output of fun. When
14 | vectorizing a scalar-valued function over many arguments, the Jacobian of the
15 | overall vector-to-vector mapping is diagonal, and so this vector-Jacobian
16 | product simply returns the diagonal elements of the Jacobian, which is the
17 | (elementwise) gradient of the function at each input value over which the
18 | function is vectorized.
19 | """
20 |
21 |
22 | def tanh(x):
23 | return (1.0 - np.exp(-2 * x)) / (1.0 + np.exp(-2 * x))
24 |
25 |
26 | ### Plotting
27 | plt.figure(figsize=(12, 8))
28 | x = np.linspace(-7, 7, 700)
29 | plt.plot(x, tanh(x), label="tanh(x)")
30 | plt.plot(x, egrad(tanh)(x), label="1st derivative")
31 | plt.plot(x, egrad(egrad(tanh))(x), label="2nd derivative")
32 | plt.plot(x, egrad(egrad(egrad(tanh)))(x), label="3rd derivative")
33 | plt.plot(x, egrad(egrad(egrad(egrad(tanh))))(x), label="4th derivative")
34 | plt.xlabel("x")
35 | plt.ylabel("y")
36 | plt.ylim(-5, 5)
37 | plt.yticks(np.arange(-5, 6, 1))
38 | plt.legend()
39 | plt.grid(True)
40 | plt.title("tanh(x) and its derivatives")
41 | plt.savefig("tanh.png")
42 | plt.show()
43 |
--------------------------------------------------------------------------------
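As the docstring in tanh.py above explains, vectorizing a scalar function over many inputs gives a diagonal Jacobian, so elementwise_grad simply returns that Jacobian's diagonal. A minimal sketch checking this claim with autograd's jacobian, using np.sin as the vectorized function:

import autograd.numpy as np
from autograd import elementwise_grad as egrad
from autograd import jacobian

x = np.linspace(-1.0, 1.0, 5)
# The Jacobian of the vectorized map x -> sin(x) is a 5x5 diagonal matrix;
# elementwise_grad recovers its diagonal, i.e. cos(x) at each input point.
assert np.allclose(egrad(np.sin)(x), np.diag(jacobian(np.sin)(x)))
assert np.allclose(egrad(np.sin)(x), np.cos(x))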
/examples/vae_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/vae_samples.png
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2025 by the President and Fellows of Harvard University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/noxfile.py:
--------------------------------------------------------------------------------
1 | import platform
2 |
3 | import nox
4 |
5 | NIGHTLY_INDEX_URL = "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple"
6 | UV_NIGHTLY_ENV_VARS = {
7 | "UV_INDEX_URL": NIGHTLY_INDEX_URL,
8 | "UV_PRERELEASE": "allow",
9 | "UV_INDEX_STRATEGY": "first-index",
10 | }
11 |
12 | nox.needs_version = ">=2024.4.15"
13 | nox.options.default_venv_backend = "uv|virtualenv"
14 | nox.options.reuse_existing_virtualenvs = False
15 | nox.options.error_on_external_run = True
16 | # nox.options.sessions = ["lint", "validate-package", "tests"]
17 | nox.options.sessions = ["tests"]
18 |
19 |
20 | @nox.session(name="validate-package")
21 | def check(session):
22 | """Build source distribution, wheel, and check their metadata"""
23 | session.install("build", "twine", silent=False)
24 | session.run("python", "-m", "build")
25 | session.run("twine", "check", "--strict", "dist/*")
26 |
27 |
28 | @nox.session(name="tests", tags=["tests"])
29 | def run_tests(session):
30 | """Run unit tests and generate a coverage report"""
31 | # SciPy doesn't have wheels on PyPy
32 | if platform.python_implementation() == "PyPy":
33 | session.install("-e", ".[test]", silent=False)
34 | else:
35 | session.install("-e", ".[test,scipy]", silent=False)
36 | session.run("pytest", "--cov=autograd", "--cov-report=xml", "--cov-append", *session.posargs)
37 |
38 |
39 | @nox.session(name="lint", reuse_venv=True)
40 | def ruff(session):
41 | """Lightning-fast linting for Python"""
42 | session.install("pre-commit", silent=False)
43 | session.run("pre-commit", "run", "--all-files", "--show-diff-on-failure")
44 |
45 |
46 | @nox.session(name="nightly-tests", tags=["tests"])
47 | def run_nightly_tests(session):
48 | """Run tests against nightly versions of dependencies"""
49 | session.install("-e", ".[test]", silent=False)
50 | # SciPy doesn't have wheels on PyPy
51 | if platform.python_implementation() == "PyPy":
52 | session.install(
53 | "numpy", "--upgrade", "--only-binary", ":all:", silent=False, env=UV_NIGHTLY_ENV_VARS
54 | )
55 | else:
56 | session.install(
57 | "numpy", "scipy", "--upgrade", "--only-binary", ":all:", silent=False, env=UV_NIGHTLY_ENV_VARS
58 | )
59 | session.run("pytest", "--cov=autograd", "--cov-report=xml", "--cov-append", *session.posargs)
60 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "autograd"
7 | version = "1.8.0"
8 | requires-python = ">=3.9"
9 | description = "Efficiently computes derivatives of NumPy code."
10 | readme = "README.md"
11 | license = {file = "license.txt"}
12 | authors = [
13 | {name = "Dougal Maclaurin", email = "maclaurin@physics.harvard.edu"},
14 | {name = "David Duvenaud", email = "duvenaud@cs.toronto.edu"},
15 | {name = "Matthew Johnson", email = "mattjj@csail.mit.edu"},
16 | {name = "Jamie Townsend", email = "j.h.n.townsend@uva.nl"},
17 | ]
18 | maintainers = [
19 | {name = "Jamie Townsend", email = "j.h.n.townsend@uva.nl"},
20 | {name = "Fabian Joswig", email = "fabian.joswig@uni-muenster.de"},
21 | {name = "Agriya Khetarpal", email = "agriyakhetarpal@outlook.com"},
22 | ]
23 | classifiers = [
24 | "Development Status :: 4 - Beta",
25 | "Intended Audience :: Information Technology",
26 | "Intended Audience :: Science/Research",
27 | "License :: OSI Approved :: MIT License",
28 | "Programming Language :: Python :: 3.9",
29 | "Programming Language :: Python :: 3.10",
30 | "Programming Language :: Python :: 3.11",
31 | "Programming Language :: Python :: 3.12",
32 | "Programming Language :: Python :: 3.13",
33 | "Topic :: Scientific/Engineering",
34 | ]
35 | keywords = [
36 | "Automatic differentiation",
37 | "backpropagation",
38 | "gradients",
39 | "machine learning",
40 | "optimization",
41 | "neural networks",
42 | "Python",
43 | "NumPy",
44 | "SciPy",
45 | ]
46 | dependencies = [
47 | "numpy<3",
48 | ]
49 | # dynamic = ["version"]
50 |
51 | [project.urls]
52 | Source = "https://github.com/HIPS/autograd"
53 |
54 | [project.optional-dependencies]
55 | scipy = [
56 | "scipy",
57 | ]
58 | test = [
59 | "pytest",
60 | "pytest-cov",
61 | "pytest-xdist",
62 | ]
63 |
64 | [tool.coverage.run]
65 | source = ["autograd"]
66 |
67 | [tool.coverage.report]
68 | show_missing = true
69 |
70 | [tool.pytest.ini_options]
71 | required_plugins = ["pytest-cov", "pytest-xdist"]
72 | # TODO: generate HTML report, upload to CodeCov
73 | addopts = "--color=yes -sra -n auto --cov=autograd --cov-report=xml --cov-report=term"
74 |
75 | [tool.ruff]
76 | extend-exclude = []
77 | # TODO: stop ignoring these rules
78 | lint.extend-ignore = [
79 | "E731",
80 | "F401",
81 | "F403",
82 | "F841",
83 | "F821",
84 | "E721",
85 | "E722",
86 | "E741",
87 | "E402",
88 | "F811"
89 | ]
90 | lint.extend-select = ["I", "W"]
91 | line-length = 109
92 |
--------------------------------------------------------------------------------
/tests/_test_complexity.py:
--------------------------------------------------------------------------------
1 | import time
2 | import warnings
3 |
4 | import autograd.numpy as np
5 | from autograd import deriv, grad
6 | from autograd.builtins import list as make_list
7 |
8 |
9 | def timefunction(f):
10 | t = time.time()
11 | f()
12 | return time.time() - t
13 |
14 |
15 | def assert_linear_time(f):
16 | t = timefunction(lambda: f(1))
17 | t10 = timefunction(lambda: f(10))
18 | assert t10 > 5 * t, f"Too fast: f(1) takes {t}, f(10) takes {t10}"
19 | assert t10 < 20 * t, f"Too slow: f(1) takes {t}, f(10) takes {t10}"
20 | if not (8 * t < t10 < 12 * t):
21 | warnings.warn("Borderline linearity. May fail on different hardware")
22 |
23 |
24 | def test_array_creation():
25 | def fun(x, N):
26 | arr = [x for i in range(N)]
27 | return np.sum(np.array(arr))
28 |
29 | assert_linear_time(lambda N: grad(fun)(1.0, 200 * N))
30 |
31 |
32 | def test_array_indexing():
33 | def fun(x):
34 | return sum([x[i] for i in range(len(x))])
35 |
36 | assert_linear_time(lambda N: grad(fun)(np.zeros(200 * N)))
37 |
38 |
39 | def test_list_indexing():
40 | def fun(x):
41 | return sum([x[i] for i in range(len(x))])
42 |
43 | assert_linear_time(lambda N: grad(fun)([0.0 for i in range(50 * N)]))
44 |
45 |
46 | def test_list_creation():
47 | def fun(x, N):
48 | return make_list(*[x for _ in range(N)])
49 |
50 | assert_linear_time(lambda N: deriv(fun)(0.0, 20 * N))
51 |
52 |
53 | # This fails. Need to figure out why
54 | def test_array_creation_fwd():
55 | def fun(x, N):
56 | arr = [x for i in range(N)]
57 | return np.sum(np.array(arr))
58 |
59 | assert_linear_time(lambda N: deriv(fun)(1.0, 400 * N))
60 |
--------------------------------------------------------------------------------
/tests/check_examples_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export PYTHONPATH=".:$PYTHONPATH"
4 | trap 'kill -INT -$pid && exit 1' INT
5 |
6 | working=()
7 | failing=()
8 |
9 | examples=$(find examples -name '*.py' -not -name '__init__.py')
10 |
11 | echo 'Running all the examples...'
12 | for f in $examples; do
13 | timeout 15s python $f > /dev/null 2>&1 & pid=$!
14 | wait $pid
15 | status=$?
16 | if [ $status -eq 0 -o $status -eq 124 ]; then
17 | echo $f "seems to work"
18 | working+=($f)
19 | elif [ $status -eq 137 ]; then
20 | echo $f "might be working, but had to be killed"
21 | working+=($f)
22 | else
23 | echo $f "seems broken, try running manually"
24 | failing+=($f)
25 | fi
26 | done
27 |
28 | if [ ! ${#working[@]} -eq 0 ]; then
29 | echo -e '\033[01;36m'
30 | echo "These seemed to WORK:"
31 | echo -en '\033[00m'
32 | printf '%s\n' "${working[@]}"
33 | echo
34 | fi
35 | if [ ! ${#failing[@]} -eq 0 ]; then
36 | echo -e '\033[01;31m'
37 | echo "These seemed to FAIL:"
38 | echo -en '\033[00m'
39 | printf '%s\n' "${failing[@]}"
40 | echo
41 | fi
42 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 |
5 | @pytest.fixture(autouse=True)
6 | def random_seed():
7 | np.random.seed(42)
8 |
--------------------------------------------------------------------------------
/tests/numpy_utils.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy.random as npr
2 | from autograd.test_util import combo_check
3 |
4 |
5 | def stat_check(fun, test_complex=True, **kwargs):
6 | # Tests functions that compute statistics, like sum, mean, etc
7 | x = 3.5
8 | A = npr.randn()
9 | B = npr.randn(3)
10 | C = npr.randn(2, 3)
11 | D = npr.randn(1, 3)
12 | check = combo_check(fun, (0,), **kwargs)
13 | check([x, A])
14 | check([B, C, D], axis=[None, 0], keepdims=[True, False])
15 | check([C, D], axis=[None, 0, 1], keepdims=[True, False])
16 | if test_complex:
17 | c = npr.randn() + 0.1j * npr.randn()
18 | E = npr.randn(2, 3) + 0.1j * npr.randn(2, 3)
19 | check([x, c, A])
20 | check([B, C, D, E], axis=[None, 0], keepdims=[True, False])
21 |
22 |
23 | def unary_ufunc_check(fun, lims=[-2, 2], test_complex=True, **kwargs):
24 | scalar = transform(lims, 0.4)
25 | vector = transform(lims, npr.rand(2))
26 | mat = transform(lims, npr.rand(3, 2))
27 | mat2 = transform(lims, npr.rand(1, 2))
28 | check = combo_check(fun, (0,), **kwargs)
29 | check([scalar, vector, mat, mat2])
30 | if test_complex:
31 | comp = transform(lims, 0.4) + 0.1j * transform(lims, 0.3)
32 | matc = transform(lims, npr.rand(3, 2)) + 0.1j * npr.rand(3, 2)
33 | check([comp, matc])
34 |
35 |
36 | def binary_ufunc_check(fun, lims_A=[-2, 2], lims_B=[-2, 2], test_complex=True, **kwargs):
37 | T_A = lambda x: transform(lims_A, x)
38 | T_B = lambda x: transform(lims_B, x)
39 | scalar = 0.6
40 | vector = npr.rand(2)
41 | mat = npr.rand(3, 2)
42 | mat2 = npr.rand(1, 2)
43 | check = combo_check(fun, (0, 1), **kwargs)
44 | check([T_A(scalar), T_A(vector), T_A(mat), T_A(mat2)], [T_B(scalar), T_B(vector), T_B(mat), T_B(mat2)])
45 | if test_complex:
46 | comp = 0.6 + 0.3j
47 | matc = npr.rand(3, 2) + 0.1j * npr.rand(3, 2)
48 | check(
49 | [T_A(scalar), T_A(comp), T_A(vector), T_A(matc), T_A(mat2)],
50 | [T_B(scalar), T_B(comp), T_B(vector), T_B(matc), T_B(mat2)],
51 | )
52 |
53 |
54 | def binary_ufunc_check_no_same_args(fun, lims_A=[-2, 2], lims_B=[-2, 2], test_complex=True, **kwargs):
55 | T_A = lambda x: transform(lims_A, x)
56 | T_B = lambda x: transform(lims_B, x)
57 | scalar1 = 0.6
58 | scalar2 = 0.7
59 | vector1 = npr.rand(2)
60 | vector2 = npr.rand(2)
61 | mat11 = npr.rand(3, 2)
62 | mat12 = npr.rand(3, 2)
63 | mat21 = npr.rand(1, 2)
64 | mat22 = npr.rand(1, 2)
65 | check = combo_check(fun, (0, 1), **kwargs)
66 | check(
67 | [T_A(scalar1), T_A(vector1), T_A(mat11), T_A(mat21)],
68 | [T_B(scalar2), T_B(vector2), T_B(mat12), T_B(mat22)],
69 | )
70 | if test_complex:
71 | comp1 = 0.6 + 0.3j
72 | comp2 = 0.1 + 0.2j
73 | matc1 = npr.rand(3, 2) + 0.1j * npr.rand(3, 2)
74 | matc2 = npr.rand(3, 2) + 0.1j * npr.rand(3, 2)
75 | check(
76 | [T_A(scalar1), T_A(comp1), T_A(vector1), T_A(matc1), T_A(mat21)],
77 | [T_B(scalar2), T_B(comp2), T_B(vector2), T_B(matc2), T_B(mat22)],
78 | )
79 |
80 |
81 | def transform(lims, x):
82 | return x * (lims[1] - lims[0]) + lims[0]
83 |
--------------------------------------------------------------------------------
/tests/profiling.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from time import time
3 |
4 | import autograd.numpy as np
5 | import autograd.numpy.random as npr
6 | from autograd import grad
7 |
8 |
9 | @contextmanager
10 | def tictoc(text=""):
11 | print("--- Start clock ---")
12 | t1 = time()
13 | yield
14 | dt = time() - t1
15 | print(f"--- Stop clock {text}: {dt} seconds elapsed ---")
16 |
17 |
18 | def fan_out_fan_in():
19 | """The 'Pearlmutter test'"""
20 |
21 | def fun(x):
22 | for i in range(10**4):
23 | x = (x + x) / 2.0
24 | return np.sum(x)
25 |
26 | with tictoc():
27 | grad(fun)(1.0)
28 |
29 |
30 | def convolution():
31 | # MNIST-scale convolution operation
32 | import autograd.scipy.signal
33 |
34 | convolve = autograd.scipy.signal.convolve
35 | dat = npr.randn(256, 3, 28, 28)
36 | kernel = npr.randn(3, 5, 5)
37 | with tictoc():
38 | convolve(dat, kernel, axes=([2, 3], [1, 2]), dot_axes=([1], [0]))
39 |
40 |
41 | def dot_equivalent():
42 | # MNIST-scale convolution operation
43 |
44 | dat = npr.randn(256, 3, 24, 5, 24, 5)
45 | kernel = npr.randn(3, 5, 5)
46 | with tictoc():
47 | np.tensordot(dat, kernel, axes=[(1, 3, 5), (0, 1, 2)])
48 |
49 |
50 | # fan_out_fan_in()
51 | # convolution()
52 | dot_equivalent()
53 |
--------------------------------------------------------------------------------
/tests/test_binary_ops.py:
--------------------------------------------------------------------------------
1 | import itertools as it
2 | import warnings
3 |
4 | import autograd.numpy as np
5 | import autograd.numpy.random as npr
6 | from autograd import grad, value_and_grad
7 | from autograd.test_util import check_grads
8 |
9 | rs = npr.RandomState(0)
10 |
11 |
12 | def arg_pairs():
13 | scalar = 2.0
14 | vector = rs.randn(4)
15 | mat = rs.randn(3, 4)
16 | mat2 = rs.randn(1, 4)
17 | allargs = [scalar, vector, mat, mat2]
18 | yield from it.product(allargs, allargs)
19 |
20 |
21 | def test_mul():
22 | fun = lambda x, y: x * y
23 | for arg1, arg2 in arg_pairs():
24 | check_grads(fun)(arg1, arg2)
25 |
26 |
27 | def test_add():
28 | fun = lambda x, y: x + y
29 | for arg1, arg2 in arg_pairs():
30 | check_grads(fun)(arg1, arg2)
31 |
32 |
33 | def test_sub():
34 | fun = lambda x, y: x - y
35 | for arg1, arg2 in arg_pairs():
36 | check_grads(fun)(arg1, arg2)
37 |
38 |
39 | def test_div():
40 | fun = lambda x, y: x / y
41 | make_gap_from_zero = lambda x: np.sqrt(x**2 + 0.5)
42 | for arg1, arg2 in arg_pairs():
43 | arg1 = make_gap_from_zero(arg1)
44 | arg2 = make_gap_from_zero(arg2)
45 | check_grads(fun)(arg1, arg2)
46 |
47 |
48 | def test_mod():
49 | fun = lambda x, y: x % y
50 | make_gap_from_zero = lambda x: np.sqrt(x**2 + 0.5)
51 | for arg1, arg2 in arg_pairs():
52 | if arg1 is not arg2: # Gradient undefined at x == y
53 | arg1 = make_gap_from_zero(arg1)
54 | arg2 = make_gap_from_zero(arg2)
55 | check_grads(fun)(arg1, arg2)
56 |
57 |
58 | def test_pow():
59 | fun = lambda x, y: x**y
60 | make_positive = lambda x: np.abs(x) + 1.1 # Numeric derivatives fail near zero
61 | for arg1, arg2 in arg_pairs():
62 | arg1 = make_positive(arg1)
63 | check_grads(fun)(arg1, arg2)
64 |
65 |
66 | def test_arctan2():
67 | for arg1, arg2 in arg_pairs():
68 | check_grads(np.arctan2)(arg1, arg2)
69 |
70 |
71 | def test_hypot():
72 | for arg1, arg2 in arg_pairs():
73 | check_grads(np.hypot, modes=["rev"])(arg1, arg2)
74 |
75 |
76 | def test_comparison_grads():
77 | compare_funs = [
78 | lambda x, y: np.sum(x < x) + 0.0,
79 | lambda x, y: np.sum(x <= y) + 0.0,
80 | lambda x, y: np.sum(x > y) + 0.0,
81 | lambda x, y: np.sum(x >= y) + 0.0,
82 | lambda x, y: np.sum(x == y) + 0.0,
83 | lambda x, y: np.sum(x != y) + 0.0,
84 | ]
85 |
86 | with warnings.catch_warnings(record=True) as w:
87 | for arg1, arg2 in arg_pairs():
88 | zeros = (arg1 + arg2) * 0 # get correct shape
89 | for fun in compare_funs:
90 | assert np.all(grad(fun)(arg1, arg2) == zeros)
91 | assert np.all(grad(fun, argnum=1)(arg1, arg2) == zeros)
92 |
93 |
94 | def test_comparison_values():
95 | compare_funs = [
96 | lambda x, y: np.sum(x < x) + 0.0,
97 | lambda x, y: np.sum(x <= y) + 0.0,
98 | lambda x, y: np.sum(x > y) + 0.0,
99 | lambda x, y: np.sum(x >= y) + 0.0,
100 | lambda x, y: np.sum(x == y) + 0.0,
101 | lambda x, y: np.sum(x != y) + 0.0,
102 | ]
103 |
104 | for arg1, arg2 in arg_pairs():
105 | for fun in compare_funs:
106 | fun_val = fun(arg1, arg2)
107 | fun_val_from_grad, _ = value_and_grad(fun)(arg1, arg2)
108 | assert fun_val == fun_val_from_grad, (fun_val, fun_val_from_grad)
109 |
--------------------------------------------------------------------------------
/tests/test_builtins.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | from autograd import grad
3 | from autograd.builtins import isinstance
4 |
5 |
6 | def test_isinstance():
7 | def checker(ex, type_, truthval):
8 | assert isinstance(ex, type_) == truthval
9 | return 1.0
10 |
11 | examples = [
12 | [list, [[]], [()]],
13 | [np.ndarray, [np.zeros(1)], [[]]],
14 | [(tuple, list), [[], ()], [np.zeros(1)]],
15 | ]
16 |
17 | for type_, positive_examples, negative_examples in examples:
18 | for ex in positive_examples:
19 | checker(ex, type_, True)
20 | grad(checker)(ex, type_, True)
21 |
22 | for ex in negative_examples:
23 | checker(ex, type_, False)
24 | grad(checker)(ex, type_, False)
25 |
--------------------------------------------------------------------------------
/tests/test_complex.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import grad
4 | from autograd.test_util import check_grads
5 |
6 | npr.seed(1)
7 |
8 |
9 | def test_real_type():
10 | fun = lambda x: np.sum(np.real(x))
11 | df = grad(fun)
12 | assert np.isrealobj(df(2.0))
13 | assert np.iscomplexobj(df(1.0j))
14 |
15 |
16 | def test_real_if_close_type():
17 | fun = lambda x: np.sum(np.real(x))
18 | df = grad(fun)
19 | assert np.isrealobj(df(1.0))
20 | assert np.iscomplexobj(df(1.0j))
21 |
22 |
23 | def test_angle_real():
24 | fun = lambda x: np.angle(x)
25 | d_fun = lambda x: grad(fun)(x)
26 | check_grads(fun)(npr.rand())
27 | check_grads(d_fun)(npr.rand())
28 |
29 |
30 | def test_angle_complex():
31 | fun = lambda x: np.angle(x)
32 | d_fun = lambda x: grad(fun)(x)
33 | check_grads(fun)(npr.rand() + 1j * npr.rand())
34 | check_grads(d_fun)(npr.rand() + 1j * npr.rand())
35 |
36 |
37 | def test_abs_real():
38 | fun = lambda x: np.abs(x)
39 | d_fun = lambda x: grad(fun)(x)
40 | check_grads(fun)(1.1)
41 | check_grads(d_fun)(2.1)
42 |
43 |
44 | def test_abs_complex():
45 | fun = lambda x: np.abs(x)
46 | d_fun = lambda x: grad(fun)(x)
47 | check_grads(fun)(1.1 + 1.2j)
48 | check_grads(d_fun)(1.1 + 1.3j)
49 |
--------------------------------------------------------------------------------
/tests/test_core.py:
--------------------------------------------------------------------------------
1 | """This file doesn't import the numpy wrapper, to check if core works
2 | on basic operations even without numpy."""
3 |
4 | import warnings
5 |
6 | from autograd.core import make_vjp
7 | from autograd.wrap_util import unary_to_nary
8 |
9 |
10 | @unary_to_nary
11 | def grad(fun, x):
12 | vjp, _ = make_vjp(fun, x)
13 | return vjp(1.0)
14 |
15 |
16 | # Non-numpy gradient checking functions.
17 | def nd(f, x, eps=1e-4):
18 | return (f(x + eps / 2) - f(x - eps / 2)) / eps
19 |
20 |
21 | def check_close(a, b, atol=1e-4, rtol=1e-4):
22 | assert abs(a - b) < atol + rtol * abs(b), f"Diffs are: {a - b}"
23 |
24 |
25 | def check_binary_func(fun, independent=False):
26 | with warnings.catch_warnings(record=independent) as w:
27 | x, y = 0.7, 1.8
28 | a = grad(fun)(x, y)
29 | b = nd(lambda x: fun(x, y), x)
30 | check_close(a, b)
31 |
32 | a = grad(fun, 1)(x, y)
33 | b = nd(lambda y: fun(x, y), y)
34 | check_close(a, b)
35 |
36 |
37 | def test_add():
38 | check_binary_func(lambda x, y: x + y)
39 |
40 |
41 | def test_sub():
42 | check_binary_func(lambda x, y: x - y)
43 |
44 |
45 | def test_div():
46 | check_binary_func(lambda x, y: x / y)
47 |
48 |
49 | def test_mul():
50 | check_binary_func(lambda x, y: x * y)
51 |
52 |
53 | def test_pow():
54 | check_binary_func(lambda x, y: x**y)
55 |
56 |
57 | def test_mod():
58 | check_binary_func(lambda x, y: x % y)
59 |
60 |
61 | def test_eq():
62 | check_binary_func(lambda x, y: x == y, independent=True)
63 |
64 |
65 | def test_neq():
66 | check_binary_func(lambda x, y: x != y, independent=True)
67 |
68 |
69 | def test_leq():
70 | check_binary_func(lambda x, y: x <= y, independent=True)
71 |
72 |
73 | def test_geq():
74 | check_binary_func(lambda x, y: x >= y, independent=True)
75 |
76 |
77 | def test_lt():
78 | check_binary_func(lambda x, y: x < y, independent=True)
79 |
80 |
81 | def test_gt():
82 | check_binary_func(lambda x, y: x > y, independent=True)
83 |
--------------------------------------------------------------------------------
/tests/test_dict.py:
--------------------------------------------------------------------------------
1 | import operator as op
2 |
3 | import autograd.numpy as np
4 | import autograd.numpy.random as npr
5 | from autograd import dict as ag_dict
6 | from autograd import grad
7 | from autograd import isinstance as ag_isinstance
8 | from autograd.test_util import check_grads
9 |
10 | npr.seed(0)
11 |
12 |
13 | def test_getter():
14 | def fun(input_dict):
15 | A = np.sum(input_dict["item_1"])
16 | B = np.sum(input_dict["item_2"])
17 | C = np.sum(input_dict["item_2"])
18 | return A + B + C
19 |
20 | d_fun = grad(fun)
21 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)}
22 |
23 | result = d_fun(input_dict)
24 | assert np.allclose(result["item_1"], np.ones((5, 6)))
25 | assert np.allclose(result["item_2"], 2 * np.ones((4, 3)))
26 | assert np.allclose(result["item_X"], np.zeros((2, 4)))
27 |
28 |
29 | def test_grads():
30 | def fun(input_dict):
31 | A = np.sum(np.sin(input_dict["item_1"]))
32 | B = np.sum(np.cos(input_dict["item_2"]))
33 | return A + B
34 |
35 | def d_fun(input_dict):
36 | g = grad(fun)(input_dict)
37 | A = np.sum(g["item_1"])
38 | B = np.sum(np.sin(g["item_1"]))
39 | C = np.sum(np.sin(g["item_2"]))
40 | return A + B + C
41 |
42 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)}
43 |
44 | check_grads(fun)(input_dict)
45 | check_grads(d_fun)(input_dict)
46 |
47 |
48 | def test_iter():
49 | def fun(input_dict):
50 | A = 0.0
51 | B = 0.0
52 | for i, k in enumerate(sorted(input_dict)):
53 | A = A + np.sum(np.sin(input_dict[k])) * (i + 1.0)
54 | B = B + np.sum(np.cos(input_dict[k]))
55 | return A + B
56 |
57 | def d_fun(input_dict):
58 | g = grad(fun)(input_dict)
59 | A = np.sum(g["item_1"])
60 | B = np.sum(np.sin(g["item_1"]))
61 | C = np.sum(np.sin(g["item_2"]))
62 | return A + B + C
63 |
64 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)}
65 |
66 | check_grads(fun)(input_dict)
67 | check_grads(d_fun)(input_dict)
68 |
69 |
70 | def test_items_values_keys():
71 | def fun(input_dict):
72 | A = 0.0
73 | B = 0.0
74 | for i, (k, v) in enumerate(sorted(input_dict.items(), key=op.itemgetter(0))):
75 | A = A + np.sum(np.sin(v)) * (i + 1.0)
76 | B = B + np.sum(np.cos(v))
77 | for v in input_dict.values():
78 | A = A + np.sum(np.sin(v))
79 | for k in sorted(input_dict.keys()):
80 | A = A + np.sum(np.cos(input_dict[k]))
81 | return A + B
82 |
83 | def d_fun(input_dict):
84 | g = grad(fun)(input_dict)
85 | A = np.sum(g["item_1"])
86 | B = np.sum(np.sin(g["item_1"]))
87 | C = np.sum(np.sin(g["item_2"]))
88 | return A + B + C
89 |
90 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)}
91 |
92 | check_grads(fun)(input_dict)
93 | check_grads(d_fun)(input_dict)
94 |
95 |
96 | def test_get():
97 | def fun(d, x):
98 | return d.get("item_1", x) ** 2
99 |
100 | check_grads(fun, argnum=(0, 1))({"item_1": 3.0}, 2.0)
101 | check_grads(fun, argnum=(0, 1))({"item_2": 4.0}, 2.0)
102 | check_grads(fun, argnum=(0, 1))({}, 2.0)
103 |
104 |
105 | def test_make_dict():
106 | def fun(x):
107 | return ag_dict([("a", x)], b=x)
108 |
109 | check_grads(fun, modes=["rev"])(1.0)
110 |
111 | def fun(x):
112 | return ag_dict({"a": x})
113 |
114 | check_grads(fun, modes=["rev"])(1.0)
115 |
116 | # check some other forms of the constructor
117 | ag_dict()
118 | ag_dict(())
119 | ag_dict({})
120 |
121 |
122 | def test_isinstance():
123 | def fun(x):
124 | assert ag_isinstance(x, dict)
125 | assert ag_isinstance(x, ag_dict)
126 | return x["x"]
127 |
128 | fun({"x": 1.0})
129 | grad(fun)({"x": 1.0})
130 |
--------------------------------------------------------------------------------
/tests/test_direct.py:
--------------------------------------------------------------------------------
1 | """
2 | Set of tests that are as explicit as possible, in case the test helpers like
3 | autograd.test_util break and start letting everything pass
4 | """
5 |
6 | import numpy as onp
7 | import pytest
8 |
9 | import autograd.numpy as np
10 | from autograd import deriv, grad, holomorphic_grad
11 |
12 |
13 | def test_grad():
14 | def fun(x):
15 | return (x + np.sin(x**2)) * x
16 |
17 | assert 3.190948746871 - 1e-6 < grad(fun)(1.3) < 3.190948746871 + 1e-6
18 |
19 |
20 | def test_deriv():
21 | def fun(x):
22 | return (x + np.sin(x**2)) * x
23 |
24 | assert 3.190948746871 - 1e-6 < deriv(fun)(1.3) < 3.190948746871 + 1e-6
25 |
26 |
27 | def test_grad_complex_output():
28 | def fun(x):
29 | return x * (1.0 + 0.2j)
30 |
31 | with pytest.raises(TypeError):
32 | grad(fun)(1.0)
33 |
34 |
35 | def test_holomorphic_grad():
36 | def fun(x):
37 | return x * (1.0 + 0.2j)
38 |
39 | g = holomorphic_grad(fun)(1.0 + 0.0j)
40 | assert 0.9999 < onp.real(g) < 1.0001
41 | assert 0.1999 < onp.imag(g) < 0.2001
42 |
--------------------------------------------------------------------------------
/tests/test_jacobian.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import grad, jacobian
4 | from autograd.test_util import check_grads
5 |
6 | npr.seed(1)
7 |
8 |
9 | def test_jacobian_against_grad():
10 | fun = lambda x: np.sum(np.sin(x), axis=1, keepdims=True)
11 | A = npr.randn(1, 3)
12 | assert np.allclose(grad(fun)(A), jacobian(fun)(A))
13 |
14 |
15 | def test_jacobian_scalar_to_vector():
16 | fun = lambda x: np.array([x, x**2, x**3])
17 | val = npr.randn()
18 | assert np.allclose(jacobian(fun)(val), np.array([1.0, 2 * val, 3 * val**2]))
19 |
20 |
21 | def test_jacobian_against_stacked_grads():
22 | scalar_funs = [
23 | lambda x: np.sum(x**3),
24 | lambda x: np.prod(np.sin(x) + np.sin(x)),
25 | lambda x: grad(lambda y: np.exp(y) * np.tanh(x[0]))(x[1]),
26 | ]
27 |
28 | vector_fun = lambda x: np.array([f(x) for f in scalar_funs])
29 |
30 | x = npr.randn(5)
31 | jac = jacobian(vector_fun)(x)
32 | grads = [grad(f)(x) for f in scalar_funs]
33 |
34 | assert np.allclose(jac, np.vstack(grads))
35 |
36 |
37 | def test_jacobian_higher_order():
38 | fun = lambda x: np.sin(np.outer(x, x)) + np.cos(np.dot(x, x))
39 |
40 | assert jacobian(fun)(npr.randn(2)).shape == (2, 2, 2)
41 | assert jacobian(jacobian(fun))(npr.randn(2)).shape == (2, 2, 2, 2)
42 | # assert jacobian(jacobian(jacobian(fun)))(npr.randn(2)).shape == (2,2,2,2,2)
43 |
44 | check_grads(lambda x: np.sum(np.sin(jacobian(fun)(x))))(npr.randn(2))
45 | check_grads(lambda x: np.sum(np.sin(jacobian(jacobian(fun))(x))))(npr.randn(2))
46 |
--------------------------------------------------------------------------------
/tests/test_list.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import grad
4 | from autograd import isinstance as ag_isinstance
5 | from autograd import list as ag_list
6 | from autograd.test_util import check_grads
7 |
8 | npr.seed(1)
9 |
10 |
11 | def test_getter():
12 | def fun(input_list):
13 | A = np.sum(input_list[0])
14 | B = np.sum(input_list[1])
15 | C = np.sum(input_list[1])
16 | return A + B + C
17 |
18 | d_fun = grad(fun)
19 | input_list = [npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4)]
20 |
21 | result = d_fun(input_list)
22 | assert np.allclose(result[0], np.ones((5, 6)))
23 | assert np.allclose(result[1], 2 * np.ones((4, 3)))
24 | assert np.allclose(result[2], np.zeros((2, 4)))
25 |
26 |
27 | def test_grads():
28 | def fun(input_list):
29 | A = np.sum(np.sin(input_list[0]))
30 | B = np.sum(np.cos(input_list[1]))
31 | return A + B
32 |
33 | def d_fun(input_list):
34 | g = grad(fun)(input_list)
35 | A = np.sum(g[0])
36 | B = np.sum(np.sin(g[0]))
37 | C = np.sum(np.sin(g[1]))
38 | return A + B + C
39 |
40 | input_list = [npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4)]
41 |
42 | check_grads(fun)(input_list)
43 | check_grads(d_fun)(input_list)
44 |
45 |
46 | def test_slices():
47 | def f(x):
48 | s = slice(None, -1, None)
49 | y = x[s]
50 | return y[0]
51 |
52 | grad(f)([1.0, 2.0, 3.0])
53 |
54 | def f(x):
55 | y = x[1:3]
56 | return y[0]
57 |
58 | grad(f)([1.0, 2.0, 3.0])
59 |
60 |
61 | def test_nested_list():
62 | A = [[1.0], 2.0, 1.5]
63 |
64 | def fun(x):
65 | return x[1:][0]
66 |
67 | check_grads(fun)(A)
68 |
69 |
70 | def test_make_list():
71 | def fun(x):
72 | return ag_list((x, x))
73 |
74 | check_grads(fun)(1.0)
75 |
76 |
77 | def test_isinstance():
78 | def fun(x):
79 | assert ag_isinstance(x, list)
80 | assert ag_isinstance(x, ag_list)
81 | return x[0]
82 |
83 | fun([1.0, 2.0, 3.0])
84 | grad(fun)([1.0, 2.0, 3.0])
85 |
--------------------------------------------------------------------------------
/tests/test_logic.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from contextlib import contextmanager
3 |
4 | import pytest
5 |
6 | import autograd.numpy as np
7 | from autograd import deriv, grad
8 | from autograd.core import primitive_vjps
9 | from autograd.extend import primitive
10 | from autograd.test_util import check_grads
11 |
12 |
13 | def test_assert():
14 | # from https://github.com/HIPS/autograd/issues/43
15 | def fun(x):
16 | assert np.allclose(x, (x * 3.0) / 3.0)
17 | return np.sum(x)
18 |
19 | check_grads(fun)(np.array([1.0, 2.0, 3.0]))
20 |
21 |
22 | def test_nograd():
23 | # we want this to raise non-differentiability error
24 | fun = lambda x: np.allclose(x, (x * 3.0) / 3.0)
25 | with pytest.raises(TypeError):
26 | with warnings.catch_warnings(record=True) as w:
27 | grad(fun)(np.array([1.0, 2.0, 3.0]))
28 |
29 |
30 | def test_no_vjp_def():
31 | fun = primitive(lambda x: 2.0 * x)
32 | with pytest.raises(NotImplementedError):
33 | grad(fun)(1.0)
34 |
35 |
36 | def test_no_jvp_def():
37 | fun = primitive(lambda x: 2.0 * x)
38 | with pytest.raises(NotImplementedError):
39 | deriv(fun)(1.0)
40 |
41 |
42 | def test_falseyness():
43 | fun = lambda x: np.real(x**2 if np.iscomplex(x) else np.sum(x))
44 | check_grads(fun)(2.0)
45 | check_grads(fun)(2.0 + 1j)
46 |
47 |
48 | def test_unimplemented_falseyness():
49 | @contextmanager
50 | def remove_grad_definitions(fun):
51 | vjpmaker = primitive_vjps.pop(fun, None)
52 | yield
53 | if vjpmaker:
54 | primitive_vjps[fun] = vjpmaker
55 |
56 | with remove_grad_definitions(np.iscomplex):
57 | fun = lambda x: np.real(x**2 if np.iscomplex(x) else np.sum(x))
58 | check_grads(fun)(5.0)
59 | check_grads(fun)(2.0 + 1j)
60 |
--------------------------------------------------------------------------------
/tests/test_misc.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import grad, make_vjp
4 | from autograd.misc import const_graph, flatten
5 | from autograd.test_util import scalar_close
6 | from autograd.tracer import primitive
7 |
8 |
9 | def test_const_graph():
10 | L = []
11 |
12 | def foo(x, y):
13 | L.append(None)
14 | return grad(lambda x: np.sin(x) + x * 2)(x * y)
15 |
16 | foo_wrapped = const_graph(foo)
17 |
18 | assert len(L) == 0
19 | assert scalar_close(foo(0.0, 0.0), foo_wrapped(0.0, 0.0))
20 | assert len(L) == 2
21 | assert scalar_close(foo(1.0, 0.5), foo_wrapped(1.0, 0.5))
22 | assert len(L) == 3
23 | assert scalar_close(foo(1.0, 0.5), foo_wrapped(1.0, 0.5))
24 | assert len(L) == 4
25 |
26 |
27 | def test_const_graph_args():
28 | L = []
29 |
30 | @primitive
31 | def process(var, varname):
32 | L.append(varname)
33 | return var
34 |
35 | def foo(x, y, z):
36 | x = process(x, "x")
37 | y = process(y, "y")
38 | z = process(z, "z")
39 | return x + 2 * y + 3 * z
40 |
41 | foo_wrapped = const_graph(foo, 1.0, z=3.0)
42 |
43 | assert L == []
44 | assert scalar_close(foo(1.0, 2.0, 3.0), foo_wrapped(2.0))
45 | assert L == ["x", "y", "z", "x", "y", "z"]
46 | L = []
47 | assert scalar_close(foo(1.0, 2.0, 3.0), foo_wrapped(2.0))
48 | assert L == ["x", "y", "z", "y"]
49 | L = []
50 | assert scalar_close(foo(1.0, 2.0, 3.0), foo_wrapped(2.0))
51 | assert L == ["x", "y", "z", "y"]
52 |
53 |
54 | def test_flatten():
55 | r = np.random.randn
56 | x = (1.0, r(2, 3), [r(1, 4), {"x": 2.0, "y": r(4, 2)}])
57 | x_flat, unflatten = flatten(x)
58 | assert x_flat.shape == (20,)
59 | assert x_flat[0] == 1.0
60 | assert np.all(x_flat == flatten(unflatten(x_flat))[0])
61 |
62 | y = (1.0, 2.0, [3.0, {"x": 2.0, "y": 4.0}])
63 | y_flat, unflatten = flatten(y)
64 | assert y_flat.shape == (5,)
65 | assert y == unflatten(y_flat)
66 |
67 |
68 | def test_flatten_empty():
69 | val = (npr.randn(4), [npr.randn(3, 4), 2.5], (), (2.0, [1.0, npr.randn(2)]))
70 | vect, unflatten = flatten(val)
71 | val_recovered = unflatten(vect)
72 | vect_2, _ = flatten(val_recovered)
73 | assert np.all(vect == vect_2)
74 |
75 |
76 | def test_flatten_dict():
77 | val = {"k": npr.random((4, 4)), "k2": npr.random((3, 3)), "k3": 3.0, "k4": [1.0, 4.0, 7.0, 9.0]}
78 |
79 | vect, unflatten = flatten(val)
80 | val_recovered = unflatten(vect)
81 | vect_2, _ = flatten(val_recovered)
82 | assert np.all(vect == vect_2)
83 |
84 |
85 | def unflatten_tracing():
86 | val = [npr.randn(4), [npr.randn(3, 4), 2.5], (), (2.0, [1.0, npr.randn(2)])]
87 | vect, unflatten = flatten(val)
88 |
89 | def f(vect):
90 | return unflatten(vect)
91 |
92 | flatten2, _ = make_vjp(f)(vect)
93 | assert np.all(vect == flatten2(val))
94 |
95 |
96 | def test_flatten_nodes_in_containers():
97 | # see issue #232
98 | def f(x, y):
99 | xy, _ = flatten([x, y])
100 | return np.sum(xy)
101 |
102 | grad(f)(1.0, 2.0)
103 |
104 |
105 | def test_flatten_complex():
106 | val = 1 + 1j
107 | flat, unflatten = flatten(val)
108 | assert np.all(val == unflatten(flat))
109 |
--------------------------------------------------------------------------------
/tests/test_performance.py:
--------------------------------------------------------------------------------
1 | # TODO:
2 | # Do a huge calculation with trivial primitive computations
3 | # and lots of diamonds and get a benchmark per-node time and
4 | # memory cost.
5 |
--------------------------------------------------------------------------------
/tests/test_scalar_ops.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import grad
4 | from autograd.test_util import check_grads
5 |
6 | npr.seed(1)
7 |
8 |
9 | def test_abs():
10 | fun = lambda x: 3.0 * np.abs(x)
11 | check_grads(fun)(1.1)
12 | check_grads(fun)(-1.1)
13 | check_grads(fun, order=1)(0.0)
14 |
15 |
16 | def test_sin():
17 | fun = lambda x: 3.0 * np.sin(x)
18 | check_grads(fun)(npr.randn())
19 |
20 |
21 | def test_sign():
22 | fun = lambda x: 3.0 * np.sign(x)
23 | check_grads(fun)(1.1)
24 | check_grads(fun)(-1.1)
25 |
26 |
27 | def test_exp():
28 | fun = lambda x: 3.0 * np.exp(x)
29 | check_grads(fun)(npr.randn())
30 |
31 |
32 | def test_log():
33 | fun = lambda x: 3.0 * np.log(x)
34 | check_grads(fun)(abs(npr.randn()))
35 |
36 |
37 | def test_log2():
38 | fun = lambda x: 3.0 * np.log2(x)
39 | check_grads(fun)(abs(npr.randn()))
40 |
41 |
42 | def test_log10():
43 | fun = lambda x: 3.0 * np.log10(x)
44 | check_grads(fun)(abs(npr.randn()))
45 |
46 |
47 | def test_log1p():
48 | fun = lambda x: 3.0 * np.log1p(x)
49 | check_grads(fun)(abs(npr.randn()))
50 |
51 |
52 | def test_expm1():
53 | fun = lambda x: 3.0 * np.expm1(x)
54 | check_grads(fun)(abs(npr.randn()))
55 |
56 |
57 | def test_exp2():
58 | fun = lambda x: 3.0 * np.exp2(x)
59 | check_grads(fun)(abs(npr.randn()))
60 |
61 |
62 | def test_neg():
63 | fun = lambda x: 3.0 * -x
64 | check_grads(fun)(npr.randn())
65 |
66 |
67 | def test_cos():
68 | fun = lambda x: 3.0 * np.cos(x)
69 | check_grads(fun)(npr.randn())
70 |
71 |
72 | def test_tan():
73 | fun = lambda x: 3.0 * np.tan(x)
74 | check_grads(fun)(npr.randn())
75 |
76 |
77 | def test_cosh():
78 | fun = lambda x: 3.0 * np.cosh(x)
79 | check_grads(fun)(npr.randn())
80 |
81 |
82 | def test_sinh():
83 | fun = lambda x: 3.0 * np.sinh(x)
84 | check_grads(fun)(npr.randn())
85 |
86 |
87 | def test_tanh():
88 | fun = lambda x: 3.0 * np.tanh(x)
89 | check_grads(fun)(npr.randn())
90 |
91 |
92 | def test_arccos():
93 | fun = lambda x: 3.0 * np.arccos(x)
94 | check_grads(fun)(0.1)
95 |
96 |
97 | def test_arcsin():
98 | fun = lambda x: 3.0 * np.arcsin(x)
99 | check_grads(fun)(0.1)
100 |
101 |
102 | def test_arctan():
103 | fun = lambda x: 3.0 * np.arctan(x)
104 | check_grads(fun)(0.2)
105 |
106 |
107 | def test_arccosh():
108 | fun = lambda x: 3.0 * np.arccosh(x)
109 | check_grads(fun)(npr.randn() ** 2 + 1.2)
110 |
111 |
112 | def test_arcsinh():
113 | fun = lambda x: 3.0 * np.arcsinh(x)
114 | check_grads(fun)(npr.randn())
115 |
116 |
117 | def test_arctanh():
118 | fun = lambda x: 3.0 * np.arctanh(x)
119 | check_grads(fun)(0.2)
120 |
121 |
122 | def test_sqrt():
123 | fun = lambda x: 3.0 * np.sqrt(x)
124 | check_grads(fun)(10.0 * npr.rand())
125 |
126 |
127 | def test_power_arg0():
128 | # the +1.'s here are to avoid regimes where numerical diffs fail
129 | make_fun = lambda y: lambda x: np.power(x, y)
130 | fun = make_fun(npr.randn() ** 2 + 1.0)
131 | check_grads(fun)(npr.rand() ** 2 + 1.0)
132 |
133 | # test y == 0. as a special case, cf. #116
134 | fun = make_fun(0.0)
135 | assert grad(fun)(0.0) == 0.0
136 | assert grad(grad(fun))(0.0) == 0.0
137 |
138 |
139 | def test_power_arg1():
140 | x = npr.randn() ** 2
141 | fun = lambda y: np.power(x, y)
142 | check_grads(fun)(npr.rand() ** 2)
143 |
144 |
145 | def test_power_arg1_zero():
146 | fun = lambda y: np.power(0.0, y)
147 | check_grads(fun)(npr.rand() ** 2)
148 |
149 |
150 | def test_mod_arg0():
151 | fun = lambda x, y: np.mod(x, y)
152 | check_grads(fun)(npr.rand(), npr.rand())
153 |
154 |
155 | def test_mod_arg1():
156 | fun = lambda x, y: np.mod(x, y)
157 | check_grads(fun)(npr.rand(), npr.rand())
158 |
159 |
160 | def test_divide_arg0():
161 | fun = lambda x, y: np.divide(x, y)
162 | check_grads(fun)(npr.rand(), npr.rand())
163 |
164 |
165 | def test_divide_arg1():
166 | fun = lambda x, y: np.divide(x, y)
167 | check_grads(fun)(npr.rand(), npr.rand())
168 |
169 |
170 | def test_multiply_arg0():
171 | fun = lambda x, y: np.multiply(x, y)
172 | check_grads(fun)(npr.rand(), npr.rand())
173 |
174 |
175 | def test_multiply_arg1():
176 | fun = lambda x, y: np.multiply(x, y)
177 | check_grads(fun)(npr.rand(), npr.rand())
178 |
179 |
180 | def test_true_divide_arg0():
181 | fun = lambda x, y: np.true_divide(x, y)
182 | check_grads(fun)(npr.rand(), npr.rand())
183 |
184 |
185 | def test_true_divide_arg1():
186 | fun = lambda x, y: np.true_divide(x, y)
187 | check_grads(fun)(npr.rand(), npr.rand())
188 |
189 |
190 | def test_reciprocal():
191 | fun = lambda x: np.reciprocal(x)
192 | check_grads(fun)(npr.rand())
193 |
194 |
195 | def test_negative():
196 | fun = lambda x: np.negative(x)
197 | check_grads(fun)(npr.rand())
198 |
199 |
200 | def test_rad2deg():
201 | fun = lambda x: 3.0 * np.rad2deg(x)
202 | check_grads(fun)(10.0 * npr.rand())
203 |
204 |
205 | def test_deg2rad():
206 | fun = lambda x: 3.0 * np.deg2rad(x)
207 | check_grads(fun)(10.0 * npr.rand())
208 |
209 |
210 | def test_radians():
211 | fun = lambda x: 3.0 * np.radians(x)
212 | check_grads(fun)(10.0 * npr.rand())
213 |
214 |
215 | def test_degrees():
216 | fun = lambda x: 3.0 * np.degrees(x)
217 | check_grads(fun)(10.0 * npr.rand())
218 |
219 |
220 | def test_sinc():
221 | fun = lambda x: 3.0 * np.sinc(x)
222 | check_grads(fun)(10.0 * npr.rand())
223 |
--------------------------------------------------------------------------------
/tests/test_tests.py:
--------------------------------------------------------------------------------
1 | from pytest import raises
2 |
3 | from autograd.extend import defvjp
4 | from autograd.test_util import check_grads
5 | from autograd.tracer import primitive
6 |
7 |
8 | def test_check_vjp_1st_order_fail():
9 | @primitive
10 | def foo(x):
11 | return x * 2.0
12 |
13 | defvjp(foo, lambda ans, x: lambda g: g * 2.001)
14 |
15 | with raises(AssertionError, match="\\(VJP\\) check of foo failed"):
16 | check_grads(foo, modes=["rev"])(1.0)
17 |
18 |
19 | def test_check_vjp_2nd_order_fail():
20 | @primitive
21 | def foo(x):
22 | return x * 2.0
23 |
24 | defvjp(foo, lambda ans, x: lambda g: bar(g) * 2)
25 |
26 | @primitive
27 | def bar(x):
28 | return x
29 |
30 | defvjp(bar, lambda ans, x: lambda g: g * 1.001)
31 |
32 | with raises(AssertionError, match="\\(VJP\\) check of vjp_foo failed"):
33 | check_grads(foo, modes=["rev"])(1.0)
34 |
--------------------------------------------------------------------------------
/tests/test_truediv.py:
--------------------------------------------------------------------------------
1 | # This file is to check that future division works.
2 |
3 | from test_binary_ops import arg_pairs
4 |
5 | import autograd.numpy as np
6 | from autograd.test_util import check_grads
7 |
8 |
9 | def test_div():
10 | fun = lambda x, y: x / y
11 | make_gap_from_zero = lambda x: np.sqrt(x**2 + 0.5)
12 | for arg1, arg2 in arg_pairs():
13 | arg1 = make_gap_from_zero(arg1)
14 | arg2 = make_gap_from_zero(arg2)
15 | check_grads(fun)(arg1, arg2)
16 |
--------------------------------------------------------------------------------
/tests/test_tuple.py:
--------------------------------------------------------------------------------
1 | import autograd.numpy as np
2 | import autograd.numpy.random as npr
3 | from autograd import grad
4 | from autograd import isinstance as ag_isinstance
5 | from autograd import tuple as ag_tuple
6 | from autograd.test_util import check_grads
7 |
8 | npr.seed(1)
9 |
10 |
11 | def test_getter():
12 | def fun(input_tuple):
13 | A = np.sum(input_tuple[0])
14 | B = np.sum(input_tuple[1])
15 | C = np.sum(input_tuple[1])
16 | return A + B + C
17 |
18 | d_fun = grad(fun)
19 | input_tuple = (npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4))
20 |
21 | result = d_fun(input_tuple)
22 | assert np.allclose(result[0], np.ones((5, 6)))
23 | assert np.allclose(result[1], 2 * np.ones((4, 3)))
24 | assert np.allclose(result[2], np.zeros((2, 4)))
25 |
26 |
27 | def test_grads():
28 | def fun(input_tuple):
29 | A = np.sum(np.sin(input_tuple[0]))
30 | B = np.sum(np.cos(input_tuple[1]))
31 | return A + B
32 |
33 | def d_fun(input_tuple):
34 | g = grad(fun)(input_tuple)
35 | A = np.sum(g[0])
36 | B = np.sum(np.sin(g[0]))
37 | C = np.sum(np.sin(g[1]))
38 | return A + B + C
39 |
40 | input_tuple = (npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4))
41 |
42 | check_grads(fun)(input_tuple)
43 | check_grads(d_fun)(input_tuple)
44 |
45 |
46 | def test_nested_higher_order():
47 | def outer_fun(x):
48 | def inner_fun(y):
49 | return y[0] * y[1]
50 |
51 | return np.sum(np.sin(np.array(grad(inner_fun)(ag_tuple((x, x))))))
52 |
53 | check_grads(outer_fun)(5.0)
54 | check_grads(grad(outer_fun))(10.0)
55 | check_grads(grad(grad(outer_fun)))(10.0)
56 |
57 |
58 | def test_isinstance():
59 | def fun(x):
60 | assert ag_isinstance(x, tuple)
61 | assert ag_isinstance(x, ag_tuple)
62 | return x[0]
63 |
64 | fun((1.0, 2.0, 3.0))
65 | grad(fun)((1.0, 2.0, 3.0))
66 |
--------------------------------------------------------------------------------
/tests/test_vspaces.py:
--------------------------------------------------------------------------------
1 | import itertools as it
2 | from functools import reduce
3 |
4 | import numpy as np
5 |
6 | from autograd.core import vspace
7 | from autograd.test_util import check_grads, scalar_close
8 |
9 |
10 | def check_vspace(value):
11 | vs = vspace(value)
12 | # --- required attributes ---
13 | size = vs.size
14 | add = vs.add
15 | scalar_mul = vs.scalar_mul
16 | inner_prod = vs.inner_prod
17 | randn = vs.randn
18 | zeros = vs.zeros
19 | ones = vs.ones
20 | standard_basis = vs.standard_basis
21 |
22 | # --- util ---
23 | def randns(N=2):
24 | return [randn() for i in range(N)]
25 |
26 | def rand_scalar():
27 | return float(np.random.randn())
28 |
29 | def rand_scalars(N=2):
30 | return [rand_scalar() for i in range(N)]
31 |
32 | def vector_close(x, y):
33 | z = randn()
34 | return scalar_close(inner_prod(z, x), inner_prod(z, y))
35 |
36 | # --- vector space axioms ---
37 | def associativity_of_add(x, y, z):
38 | return vector_close(add(x, add(y, z)), add(add(x, y), z))
39 |
40 | def commutativity_of_add(x, y):
41 | return vector_close(add(x, y), add(y, x))
42 |
43 | def identity_element_of_add(x):
44 | return vector_close(add(zeros(), x), x)
45 |
46 | def inverse_elements_of_add(x):
47 | return vector_close(zeros(), add(x, scalar_mul(x, -1.0)))
48 |
49 | def compatibility_of_scalar_mul_with_field_mul(x, a, b):
50 | return vector_close(scalar_mul(x, a * b), scalar_mul(scalar_mul(x, a), b))
51 |
52 | def identity_element_of_scalar_mul(x):
53 | return vector_close(scalar_mul(x, 1.0), x)
54 |
55 | def distributivity_of_scalar_mul_wrt_vector_add(x, y, a):
56 | return vector_close(scalar_mul(add(x, y), a), add(scalar_mul(x, a), scalar_mul(y, a)))
57 |
58 | def distributivity_of_scalar_mul_wrt_scalar_add(x, a, b):
59 | return vector_close(scalar_mul(x, a + b), add(scalar_mul(x, a), scalar_mul(x, b)))
60 |
61 | # --- closure ---
62 | def add_preserves_vspace(x, y):
63 | return vs == vspace(add(x, y))
64 |
65 | def scalar_mul_preserves_vspace(x, a):
66 | return vs == vspace(scalar_mul(x, a))
67 |
68 | # --- inner product axioms ---
69 | def symmetry(x, y):
70 | return scalar_close(inner_prod(x, y), inner_prod(y, x))
71 |
72 | def linearity(x, y, a):
73 | return scalar_close(inner_prod(scalar_mul(x, a), y), a * inner_prod(x, y))
74 |
75 | def positive_definite(x):
76 | return 0 < inner_prod(x, x)
77 |
78 | def inner_zeros():
79 | return scalar_close(0, inner_prod(zeros(), zeros()))
80 |
81 | # --- basis vectors and special vectors---
82 | def basis_orthonormality():
83 | return all(
84 | [
85 | scalar_close(inner_prod(x, y), 1.0 * (ix == iy))
86 | for (ix, x), (iy, y) in it.product(enumerate(standard_basis()), enumerate(standard_basis()))
87 | ]
88 | )
89 |
90 | def ones_sum_of_basis_vects():
91 | return vector_close(reduce(add, standard_basis()), ones())
92 |
93 | def basis_correct_size():
94 | return len(list(standard_basis())) == size
95 |
96 | def basis_correct_vspace():
97 | return all(vs == vspace(x) for x in standard_basis())
98 |
99 | def zeros_correct_vspace():
100 | return vs == vspace(zeros())
101 |
102 | def ones_correct_vspace():
103 | return vs == vspace(ones())
104 |
105 | def randn_correct_vspace():
106 | return vs == vspace(randn())
107 |
108 | assert associativity_of_add(*randns(3))
109 | assert commutativity_of_add(*randns())
110 | assert identity_element_of_add(randn())
111 | assert inverse_elements_of_add(randn())
112 | assert compatibility_of_scalar_mul_with_field_mul(randn(), *rand_scalars())
113 | assert identity_element_of_scalar_mul(randn())
114 | assert distributivity_of_scalar_mul_wrt_vector_add(randn(), randn(), rand_scalar())
115 | assert distributivity_of_scalar_mul_wrt_scalar_add(randn(), *rand_scalars())
116 | assert add_preserves_vspace(*randns())
117 | assert scalar_mul_preserves_vspace(randn(), rand_scalar())
118 | assert symmetry(*randns())
119 | assert linearity(randn(), randn(), rand_scalar())
120 | assert positive_definite(randn())
121 | assert inner_zeros()
122 | assert basis_orthonormality()
123 | assert ones_sum_of_basis_vects()
124 | assert basis_correct_size()
125 | assert basis_correct_vspace()
126 | assert zeros_correct_vspace()
127 | assert ones_correct_vspace()
128 | assert randn_correct_vspace()
129 |
130 | # --- grads of basic operations ---
131 | check_grads(add)(*randns())
132 | check_grads(scalar_mul)(randn(), rand_scalar())
133 | check_grads(inner_prod)(*randns())
134 |
135 |
136 | def test_array_vspace():
137 | check_vspace(np.zeros((3, 2)))
138 |
139 |
140 | def test_array_vspace_0_dim():
141 | check_vspace(0.0)
142 |
143 |
144 | def test_array_vspace_complex():
145 | check_vspace(1.0j * np.zeros((2, 1)))
146 |
147 |
148 | def test_list_vspace():
149 | check_vspace([1.0, np.zeros((2, 1))])
150 |
151 |
152 | def test_tuple_vspace():
153 | check_vspace((1.0, np.zeros((2, 1))))
154 |
155 |
156 | def test_dict_vspace():
157 | check_vspace({"a": 1.0, "b": np.zeros((2, 1))})
158 |
159 |
160 | def test_mixed_vspace():
161 | check_vspace({"x": [0.0, np.zeros((3, 1))], "y": ({"a": 0.0}, [0.0])})
162 |
--------------------------------------------------------------------------------