├── .github └── workflows │ ├── check.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── README.md ├── autograd ├── __init__.py ├── builtins.py ├── core.py ├── differential_operators.py ├── extend.py ├── misc │ ├── __init__.py │ ├── fixed_points.py │ ├── flatten.py │ ├── optimizers.py │ └── tracers.py ├── numpy │ ├── __init__.py │ ├── fft.py │ ├── linalg.py │ ├── numpy_boxes.py │ ├── numpy_jvps.py │ ├── numpy_vjps.py │ ├── numpy_vspaces.py │ ├── numpy_wrapper.py │ └── random.py ├── scipy │ ├── __init__.py │ ├── integrate.py │ ├── linalg.py │ ├── signal.py │ ├── special.py │ └── stats │ │ ├── __init__.py │ │ ├── beta.py │ │ ├── chi2.py │ │ ├── dirichlet.py │ │ ├── gamma.py │ │ ├── multivariate_normal.py │ │ ├── norm.py │ │ ├── poisson.py │ │ └── t.py ├── test_util.py ├── tracer.py ├── util.py └── wrap_util.py ├── benchmarks ├── __init__.py ├── asv.conf.json.sample ├── bench_core.py ├── bench_mem.py ├── bench_numpy_vjps.py ├── bench_rnn.py └── bench_util.py ├── conda_recipe └── conda.yaml ├── docs ├── tutorial.md └── updateguide.md ├── examples ├── __init__.py ├── bayesian_neural_net.png ├── bayesian_neural_net.py ├── bayesian_optimization.py ├── black_box_svi.py ├── convnet.py ├── data.py ├── data_mnist.py ├── deep_gaussian_process.py ├── define_gradient.py ├── dot_graph.py ├── fixed_points.py ├── fluidsim │ ├── animated.gif │ ├── fluidsim.py │ ├── init_smoke.png │ ├── peace.png │ ├── skull.png │ ├── surprise.gif │ ├── wing.png │ └── wing.py ├── gaussian_process.png ├── gaussian_process.py ├── generative_adversarial_net.py ├── gmm.png ├── gmm.py ├── gplvm.png ├── gplvm.py ├── graph.pdf ├── hmm_em.py ├── ica.py ├── logistic_regression.py ├── lstm.py ├── mixture_variational_inference.py ├── natural_gradient_black_box_svi.py ├── negative_binomial_maxlike.py ├── neural_net.py ├── neural_net_regression.py ├── ode_net.py ├── ode_net_demo.png ├── print_trace.py ├── rkhs.py ├── rnn.py ├── rosenbrock.py ├── sinusoid.png ├── sinusoid.py ├── sinusoid_taylor.png ├── tanh.png ├── tanh.py ├── vae_samples.png └── variational_autoencoder.py ├── license.txt ├── noxfile.py ├── pyproject.toml └── tests ├── _test_complexity.py ├── check_examples_run.sh ├── conftest.py ├── numpy_utils.py ├── profiling.py ├── test_binary_ops.py ├── test_builtins.py ├── test_complex.py ├── test_core.py ├── test_dict.py ├── test_direct.py ├── test_fft.py ├── test_graphs.py ├── test_jacobian.py ├── test_linalg.py ├── test_list.py ├── test_logic.py ├── test_misc.py ├── test_numpy.py ├── test_performance.py ├── test_scalar_ops.py ├── test_scipy.py ├── test_systematic.py ├── test_tests.py ├── test_truediv.py ├── test_tuple.py ├── test_vspaces.py └── test_wrappers.py /.github/workflows/check.yml: -------------------------------------------------------------------------------- 1 | name: Style and package checks 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | push: 8 | branches: 9 | - master 10 | workflow_dispatch: 11 | 12 | env: 13 | PIP_DISABLE_PIP_VERSION_CHECK: "1" 14 | FORCE_COLOR: "3" 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | check: 22 | name: ${{ matrix.env }} 23 | runs-on: ubuntu-latest 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | session: 28 | # - lint 29 | - validate-package 30 | steps: 31 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 32 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 
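# No python-version input is pinned for this step; the checks below are run
# through `uvx nox`, which resolves its own isolated tool environment.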
33 | 34 | - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1 35 | 36 | - name: Run ${{ matrix.env }} 37 | run: uvx nox -s ${{ matrix.env }} 38 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published] 7 | 8 | env: 9 | PIP_DISABLE_PIP_VERSION_CHECK: '1' 10 | FORCE_COLOR: '3' 11 | 12 | jobs: 13 | build: 14 | name: Build sdist and wheel 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 18 | name: Checkout repository 19 | 20 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 21 | with: 22 | python-version: "3.12" 23 | 24 | - name: Install build tools 25 | run: | 26 | pipx run build --outdir dist 27 | 28 | - name: Upload wheel and sdist artifacts 29 | uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 30 | with: 31 | name: artifacts 32 | path: ./dist/* 33 | if-no-files-found: error 34 | 35 | publish: 36 | needs: [build] 37 | name: Upload to PyPI 38 | runs-on: ubuntu-latest 39 | environment: 40 | name: release 41 | url: https://pypi.org/p/autograd 42 | permissions: 43 | id-token: write # mandatory for trusted publishing 44 | 45 | steps: 46 | - name: Download artifacts 47 | uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 48 | with: 49 | path: dist 50 | merge-multiple: true 51 | 52 | - name: Sanity check artifacts 53 | run: ls -la dist/ 54 | 55 | - name: Publish sdist and wheel to PyPI 56 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 57 | with: 58 | packages-dir: dist/ 59 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | push: 8 | branches: 9 | - master 10 | workflow_dispatch: 11 | schedule: 12 | - cron: "0 4 * * *" 13 | 14 | env: 15 | PIP_DISABLE_PIP_VERSION_CHECK: "1" 16 | FORCE_COLOR: "3" 17 | 18 | concurrency: 19 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 20 | cancel-in-progress: true 21 | 22 | jobs: 23 | test: 24 | name: Regular tests / ${{ matrix.platform }} / Python ${{ matrix.python-version }} 25 | runs-on: ${{ matrix.platform }} 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | platform: [ubuntu-latest, ubuntu-22.04-arm, macos-13, macos-latest, windows-latest] 30 | python-version: 31 | ["3.9", "3.10", "3.11", "3.12", "3.13", "pypy-3.10"] 32 | steps: 33 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 34 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | allow-prereleases: true 38 | - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1 39 | 40 | # On PyPy, we skip SciPy because we don't have wheels 41 | # available, see noxfile.py for more details. 
42 | - name: Run tests 43 | run: uvx nox -s tests 44 | 45 | # In this job, we test against the NumPy nightly wheels hosted on 46 | # https://anaconda.org/scientific-python-nightly-wheels/numpy 47 | # on the latest Python version available across platforms, instead of 48 | # testing all Python versions and implementations on all platforms. 49 | # We do not test on PyPy. 50 | # 51 | # However, "nox -s nightly-tests" can be used locally anywhere, on 52 | # any Python version and implementation on any platform and we leave 53 | # it to the user to decide what Python version to test against, which 54 | # might or might not have a corresponding NumPy nightly wheel present. 55 | nightlies: 56 | name: Nightly tests / ${{ matrix.platform }} / Python ${{ matrix.python-version }} 57 | runs-on: ${{ matrix.platform }} 58 | strategy: 59 | fail-fast: false 60 | matrix: 61 | platform: [ubuntu-latest, ubuntu-22.04-arm, macos-13, macos-latest, windows-latest] 62 | python-version: ["3.x"] 63 | 64 | steps: 65 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 66 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 67 | with: 68 | python-version: ${{ matrix.python-version }} 69 | allow-prereleases: true 70 | - uses: yezz123/setup-uv@ab6be5a42627f19dc36e57b548592a5e52cece4a # v4.1 71 | - name: Run tests against nightly wheels for NumPy and SciPy 72 | run: uvx nox -s nightly-tests 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | # Distribution / packaging 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | MANIFEST 23 | 24 | # Installer logs 25 | pip-log.txt 26 | pip-delete-this-directory.txt 27 | 28 | # Unit test / coverage reports 29 | htmlcov/ 30 | .tox/ 31 | .nox/ 32 | .coverage 33 | .coverage.* 34 | .cache 35 | coverage.* 36 | *.cover 37 | .hypothesis/ 38 | nosetests.xml 39 | .pytest_cache/ 40 | junit-report.xml 41 | 42 | # pyenv 43 | .python-version 44 | 45 | # Environments 46 | .env 47 | .venv 48 | env/ 49 | venv/ 50 | ENV/ 51 | env.bak/ 52 | venv.bak/ 53 | 54 | # mypy 55 | .mypy_cache/ 56 | 57 | # OS and IDE config files 58 | .DS_Store 59 | .idea/ 60 | 61 | # project-specific 62 | data/ 63 | *.so 64 | *.c 65 | scratch/ 66 | examples/data 67 | 68 | .asv/ 69 | asv.conf.json 70 | benchmarks/asv.conf.js 71 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_commit_msg: "chore: update pre-commit hooks" 3 | autofix_commit_msg: "style: pre-commit fixes" 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v5.0.0 8 | hooks: 9 | - id: check-added-large-files 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-yaml 13 | exclude: conda_recipe/conda.yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: mixed-line-ending 17 | - id: trailing-whitespace 18 | 19 | - repo: https://github.com/astral-sh/ruff-pre-commit 20 | rev: "v0.11.12" 21 | hooks: 22 | - id: ruff 23 | args: ["--fix", "--show-fixes"] 24 | - id: ruff-format 25 | 26 | - repo: https://github.com/pre-commit/pygrep-hooks 27 | rev: 
v1.10.0 28 | hooks: 29 | - id: python-check-blanket-type-ignore 30 | exclude: ^src/vector/backends/_numba_object.py$ 31 | - id: rst-backticks 32 | - id: rst-directive-colons 33 | - id: rst-inline-touching-normal 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Use [Nox](https://nox.thea.codes/en/stable/) to run tests and linting, e.g., 4 | 5 | ```shell 6 | pip install nox 7 | ``` 8 | 9 | `nox` will run all checks in an isolated virtual environment with Autograd and its dependencies, including its optional dependencies, installed. 10 | 11 | ## Run tests, linting, packaging checks 12 | 13 | | Command | Description | 14 | | ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | 15 | | `nox --list` | Lists all available Nox sessions, including selected ones | 16 | | `nox -s lint` | Runs code style checks with pre-commit and pre-commit hooks as listed in `.pre-commit-config.yaml`. Accepts posargs to pass additional arguments to the linter. | 17 | | `nox -s tests` | Runs tests with your default Python interpreter. Accepts posargs to pass additional arguments and configuration to `pytest`. | 18 | | `nox -s nightly-tests` | Similar to `nox -s tests`, except that it runs tests with nightly versions of dependencies (NumPy, SciPy, etc.). | 19 | | `nox -s validate-package` | Builds a source distribution and a wheel using `pypa/build` and checks the package with `twine` in strict mode. | 20 | | `nox` | Runs all selected sessions, as listed in `nox.options.sessions` in `noxfile.py`. | 21 | 22 | Additionally, `nox` supports tags to run specific sessions, e.g., `nox --tags tests` runs all sessions tagged with `tests`. 23 | 24 | Make sure all tests pass before you push your changes to GitHub. 25 | GH Actions will run the tests across all supported Python versions. 26 | 27 | ## Using positional arguments (reformat, upload package, help) 28 | 29 | You can use additional arguments for the tools (`pytest`, `pre-commit`, etc.) called by Nox by 30 | separating them from the Nox arguments by a double-hyphen `--`, e.g., 31 | 32 | - `nox -s tests -- --tests/test_tuple.py` runs just the tests listed `tests/test_tuple.py`. 33 | - `nox -s lint -- --fix` runs the linter with the `--fix` flag. 34 | - and so on. 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Autograd [![Checks status][checks-badge]][checks-url] [![Tests status][tests-badge]][tests-url] [![Publish status][publish-badge]][publish-url] [![asv][asv-badge]](#) 2 | 3 | [publish-badge]: https://github.com/HIPS/autograd/actions/workflows/publish.yml/badge.svg 4 | [checks-badge]: https://github.com/HIPS/autograd/actions/workflows/check.yml/badge.svg 5 | [tests-badge]: https://github.com/HIPS/autograd/actions/workflows/test.yml/badge.svg 6 | [asv-badge]: http://img.shields.io/badge/benchmarked%20by-asv-green.svg?style=flat 7 | [publish-url]: https://github.com/HIPS/autograd/actions/workflows/publish.yml 8 | [checks-url]: https://github.com/HIPS/autograd/actions/workflows/check.yml 9 | [tests-url]: https://github.com/HIPS/autograd/actions/workflows/test.yml 10 | 11 | Autograd can automatically differentiate native Python and Numpy code. 
It can 12 | handle a large subset of Python's features, including loops, ifs, recursion and 13 | closures, and it can even take derivatives of derivatives of derivatives. It 14 | supports reverse-mode differentiation (a.k.a. backpropagation), which means it 15 | can efficiently take gradients of scalar-valued functions with respect to 16 | array-valued arguments, as well as forward-mode differentiation, and the two can 17 | be composed arbitrarily. The main intended application of Autograd is 18 | gradient-based optimization. For more information, check out the 19 | [tutorial](docs/tutorial.md) and the [examples directory](examples/). 20 | 21 | Example use: 22 | 23 | ```python 24 | >>> import autograd.numpy as np # Thinly-wrapped numpy 25 | >>> from autograd import grad # The only autograd function you may ever need 26 | >>> 27 | >>> def tanh(x): # Define a function 28 | ... return (1.0 - np.exp((-2 * x))) / (1.0 + np.exp(-(2 * x))) 29 | ... 30 | >>> grad_tanh = grad(tanh) # Obtain its gradient function 31 | >>> grad_tanh(1.0) # Evaluate the gradient at x = 1.0 32 | np.float64(0.419974341614026) 33 | >>> (tanh(1.0001) - tanh(0.9999)) / 0.0002 # Compare to finite differences 34 | np.float64(0.41997434264973155) 35 | ``` 36 | 37 | We can continue to differentiate as many times as we like, and use numpy's 38 | vectorization of scalar-valued functions across many different input values: 39 | 40 | ```python 41 | >>> from autograd import elementwise_grad as egrad # for functions that vectorize over inputs 42 | >>> import matplotlib.pyplot as plt 43 | >>> x = np.linspace(-7, 7, 700) 44 | >>> plt.plot(x, tanh(x), 45 | ... x, egrad(tanh)(x), # first derivative 46 | ... x, egrad(egrad(tanh))(x), # second derivative 47 | ... x, egrad(egrad(egrad(tanh)))(x), # third derivative 48 | ... x, egrad(egrad(egrad(egrad(tanh))))(x), # fourth derivative 49 | >>> plt.show() 50 | ``` 51 | 52 | 53 | 54 | See the [tanh example file](examples/tanh.py) for the code. 55 | 56 | ## Documentation 57 | 58 | You can find a tutorial [here.](docs/tutorial.md) 59 | 60 | ## End-to-end examples 61 | 62 | * [Simple neural net](examples/neural_net.py) 63 | * [Convolutional neural net](examples/convnet.py) 64 | * [Recurrent neural net](examples/rnn.py) 65 | * [LSTM](examples/lstm.py) 66 | * [Neural Turing Machine](https://github.com/DoctorTeeth/diffmem/blob/512aadeefd6dbafc1bdd253a64b6be192a435dc3/ntm/ntm.py) 67 | * [Backpropagating through a fluid simulation](examples/fluidsim/fluidsim.py) 68 | 69 | 70 | 71 | * [Variational inference in Bayesian neural network](examples/bayesian_neural_net.py) 72 | * [Gaussian process regression](examples/gaussian_process.py) 73 | * [Sampyl, a pure Python MCMC package with HMC and NUTS](https://github.com/mcleonard/sampyl) 74 | 75 | ## How to install 76 | 77 | Install Autograd using Pip: 78 | 79 | ```shell 80 | pip install autograd 81 | ``` 82 | 83 | Some features require SciPy, which you can install separately or as an 84 | optional dependency along with Autograd: 85 | 86 | ```shell 87 | pip install "autograd[scipy]" 88 | ``` 89 | 90 | ## Authors and maintainers 91 | 92 | Autograd was written by [Dougal Maclaurin](https://dougalmaclaurin.com), 93 | [David Duvenaud](https://www.cs.toronto.edu/~duvenaud/), 94 | [Matt Johnson](http://people.csail.mit.edu/mattjj/), 95 | [Jamie Townsend](https://github.com/j-towns) 96 | and many other contributors. 
The package is currently being maintained by 97 | [Agriya Khetarpal](https://github.com/agriyakhetarpal), 98 | [Fabian Joswig](https://github.com/fjosw) and 99 | [Jamie Townsend](https://github.com/j-towns). 100 | Please feel free to submit any bugs or 101 | feature requests. We'd also love to hear about your experiences with Autograd 102 | in general. Drop us an email! 103 | 104 | We want to thank Jasper Snoek and the rest of the HIPS group (led by Prof. Ryan 105 | P. Adams) for helpful contributions and advice; Barak Pearlmutter for 106 | foundational work on automatic differentiation and for guidance on our 107 | implementation; and Analog Devices Inc. (Lyric Labs) and Samsung Advanced Institute 108 | of Technology for their generous support. 109 | -------------------------------------------------------------------------------- /autograd/__init__.py: -------------------------------------------------------------------------------- 1 | from autograd.core import primitive_with_deprecation_warnings as primitive 2 | 3 | from .builtins import dict, isinstance, list, tuple, type 4 | from .differential_operators import ( 5 | checkpoint, 6 | deriv, 7 | elementwise_grad, 8 | grad, 9 | grad_and_aux, 10 | grad_named, 11 | hessian, 12 | hessian_tensor_product, 13 | hessian_vector_product, 14 | holomorphic_grad, 15 | jacobian, 16 | make_ggnvp, 17 | make_hvp, 18 | make_jvp, 19 | make_vjp, 20 | multigrad_dict, 21 | tensor_jacobian_product, 22 | value_and_grad, 23 | vector_jacobian_product, 24 | ) 25 | -------------------------------------------------------------------------------- /autograd/extend.py: -------------------------------------------------------------------------------- 1 | # Exposes API for extending autograd 2 | from .core import ( 3 | JVPNode, 4 | SparseObject, 5 | VJPNode, 6 | VSpace, 7 | def_linear, 8 | defjvp, 9 | defjvp_argnum, 10 | defjvp_argnums, 11 | defvjp, 12 | defvjp_argnum, 13 | defvjp_argnums, 14 | vspace, 15 | ) 16 | from .tracer import Box, notrace_primitive, primitive, register_notrace 17 | -------------------------------------------------------------------------------- /autograd/misc/__init__.py: -------------------------------------------------------------------------------- 1 | from .flatten import flatten 2 | from .tracers import const_graph 3 | -------------------------------------------------------------------------------- /autograd/misc/fixed_points.py: -------------------------------------------------------------------------------- 1 | from autograd import make_vjp 2 | from autograd.builtins import tuple 3 | from autograd.extend import defvjp, primitive, vspace 4 | 5 | 6 | @primitive 7 | def fixed_point(f, a, x0, distance, tol): 8 | _f = f(a) 9 | x, x_prev = _f(x0), x0 10 | while distance(x, x_prev) > tol: 11 | x, x_prev = _f(x), x 12 | return x 13 | 14 | 15 | def fixed_point_vjp(ans, f, a, x0, distance, tol): 16 | def rev_iter(params): 17 | a, x_star, x_star_bar = params 18 | vjp_x, _ = make_vjp(f(a))(x_star) 19 | vs = vspace(x_star) 20 | return lambda g: vs.add(vjp_x(g), x_star_bar) 21 | 22 | vjp_a, _ = make_vjp(lambda x, y: f(x)(y))(a, ans) 23 | return lambda g: vjp_a(fixed_point(rev_iter, tuple((a, ans, g)), vspace(x0).zeros(), distance, tol)) 24 | 25 | 26 | defvjp(fixed_point, None, fixed_point_vjp, None) 27 | -------------------------------------------------------------------------------- /autograd/misc/flatten.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handy functions for flattening nested containers 
containing numpy 3 | arrays. The main purpose is to make examples and optimizers simpler. 4 | """ 5 | 6 | import autograd.numpy as np 7 | from autograd import make_vjp 8 | from autograd.builtins import type 9 | 10 | 11 | def flatten(value): 12 | """Flattens any nesting of tuples, lists, or dicts, with numpy arrays or 13 | scalars inside. Returns 1D numpy array and an unflatten function. 14 | Doesn't preserve mixed numeric types (e.g. floats and ints). Assumes dict 15 | keys are sortable.""" 16 | unflatten, flat_value = make_vjp(_flatten)(value) 17 | return flat_value, unflatten 18 | 19 | 20 | def _flatten(value): 21 | t = type(value) 22 | if t in (list, tuple): 23 | return _concatenate(map(_flatten, value)) 24 | elif t is dict: 25 | return _concatenate(_flatten(value[k]) for k in sorted(value)) 26 | else: 27 | return np.ravel(value) 28 | 29 | 30 | def _concatenate(lst): 31 | lst = list(lst) 32 | return np.concatenate(lst) if lst else np.array([]) 33 | 34 | 35 | def flatten_func(func, example): 36 | _ex, unflatten = flatten(example) 37 | _func = lambda _x, *args: flatten(func(unflatten(_x), *args))[0] 38 | return _func, unflatten, _ex 39 | -------------------------------------------------------------------------------- /autograd/misc/optimizers.py: -------------------------------------------------------------------------------- 1 | """Some standard gradient-based stochastic optimizers. 2 | 3 | These are just standard routines that don't make any use of autograd, 4 | though you could take gradients of these functions too if you want 5 | to do meta-optimization. 6 | 7 | These routines can optimize functions whose inputs are structured 8 | objects, such as dicts of numpy arrays.""" 9 | 10 | import autograd.numpy as np 11 | from autograd.misc import flatten 12 | from autograd.wrap_util import wraps 13 | 14 | 15 | def unflatten_optimizer(optimize): 16 | """Takes an optimizer that operates on flat 1D numpy arrays and returns a 17 | wrapped version that handles trees of nested containers (lists/tuples/dicts) 18 | with arrays/scalars at the leaves.""" 19 | 20 | @wraps(optimize) 21 | def _optimize(grad, x0, callback=None, *args, **kwargs): 22 | _x0, unflatten = flatten(x0) 23 | _grad = lambda x, i: flatten(grad(unflatten(x), i))[0] 24 | if callback: 25 | _callback = lambda x, i, g: callback(unflatten(x), i, unflatten(g)) 26 | else: 27 | _callback = None 28 | return unflatten(optimize(_grad, _x0, _callback, *args, **kwargs)) 29 | 30 | return _optimize 31 | 32 | 33 | @unflatten_optimizer 34 | def sgd(grad, x, callback=None, num_iters=200, step_size=0.1, mass=0.9): 35 | """Stochastic gradient descent with momentum. 
36 | grad() must have signature grad(x, i), where i is the iteration number.""" 37 | velocity = np.zeros(len(x)) 38 | for i in range(num_iters): 39 | g = grad(x, i) 40 | if callback: 41 | callback(x, i, g) 42 | velocity = mass * velocity - (1.0 - mass) * g 43 | x = x + step_size * velocity 44 | return x 45 | 46 | 47 | @unflatten_optimizer 48 | def rmsprop(grad, x, callback=None, num_iters=100, step_size=0.1, gamma=0.9, eps=10**-8): 49 | """Root mean squared prop: See Adagrad paper for details.""" 50 | avg_sq_grad = np.ones(len(x)) 51 | for i in range(num_iters): 52 | g = grad(x, i) 53 | if callback: 54 | callback(x, i, g) 55 | avg_sq_grad = avg_sq_grad * gamma + g**2 * (1 - gamma) 56 | x = x - step_size * g / (np.sqrt(avg_sq_grad) + eps) 57 | return x 58 | 59 | 60 | @unflatten_optimizer 61 | def adam(grad, x, callback=None, num_iters=100, step_size=0.001, b1=0.9, b2=0.999, eps=10**-8): 62 | """Adam as described in http://arxiv.org/pdf/1412.6980.pdf. 63 | It's basically RMSprop with momentum and some correction terms.""" 64 | m = np.zeros(len(x)) 65 | v = np.zeros(len(x)) 66 | for i in range(num_iters): 67 | g = grad(x, i) 68 | if callback: 69 | callback(x, i, g) 70 | m = (1 - b1) * g + b1 * m # First moment estimate. 71 | v = (1 - b2) * (g**2) + b2 * v # Second moment estimate. 72 | mhat = m / (1 - b1 ** (i + 1)) # Bias correction. 73 | vhat = v / (1 - b2 ** (i + 1)) 74 | x = x - step_size * mhat / (np.sqrt(vhat) + eps) 75 | return x 76 | -------------------------------------------------------------------------------- /autograd/misc/tracers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import repeat 3 | 4 | from autograd.tracer import Node, trace 5 | from autograd.util import subvals, toposort 6 | from autograd.wrap_util import wraps 7 | 8 | 9 | class ConstGraphNode(Node): 10 | __slots__ = ["parents", "partial_fun"] 11 | 12 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents): 13 | args = subvals(args, zip(parent_argnums, repeat(None))) 14 | 15 | def partial_fun(partial_args): 16 | return fun(*subvals(args, zip(parent_argnums, partial_args)), **kwargs) 17 | 18 | self.parents = parents 19 | self.partial_fun = partial_fun 20 | 21 | def initialize_root(self): 22 | self.parents = [] 23 | 24 | 25 | def const_graph_unary(fun): 26 | graph = [] 27 | _fun = [fun] # Allow fun to be freed, since it may have bound args 28 | 29 | def maybe_cached_fun(x): 30 | if graph: 31 | _graph = graph[0] 32 | vals = {_graph[0]: x} 33 | for node in _graph[1:]: 34 | vals[node] = node.partial_fun([vals[p] for p in node.parents]) 35 | return vals[node] 36 | else: 37 | start_node = ConstGraphNode.new_root() 38 | end_value, end_node = trace(start_node, _fun.pop(), x) 39 | if end_node is None: 40 | raise Exception("Output is independent of input") 41 | graph.append(list(toposort(end_node))[::-1]) 42 | return end_value 43 | 44 | return maybe_cached_fun 45 | 46 | 47 | def const_graph(fun, *args, **kwargs): 48 | partial_fun = partial(fun, *args, **kwargs) 49 | unary_fun = lambda args: partial_fun(*args) 50 | maybe_cached_unary_fun = const_graph_unary(unary_fun) 51 | 52 | @wraps(fun) 53 | def _fun(*args): 54 | return maybe_cached_unary_fun(args) 55 | 56 | return _fun 57 | 58 | 59 | class FullGraphNode(Node): 60 | __slots__ = ["value", "recipe"] 61 | 62 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents): 63 | self.value = value 64 | self.recipe = (fun, args, kwargs, zip(parent_argnums, parents)) 
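        # `recipe` records how this value was produced: the primitive `fun`, its
        # full positional/keyword arguments, and (argnum, parent-node) pairs for
        # the traced arguments, so the complete graph can be walked afterwards.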
65 | 66 | def initialize_root(self): 67 | self.value = None 68 | self.recipe = (lambda x: x, (), {}, []) 69 | 70 | 71 | def full_graph(fun, *args, **kwargs): 72 | unary_fun = lambda args: fun(*args, **kwargs) 73 | start_node = FullGraphNode.new_root() 74 | end_value, end_node = trace(start_node, unary_fun, args) 75 | return end_node 76 | -------------------------------------------------------------------------------- /autograd/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fft, linalg, numpy_boxes, numpy_jvps, numpy_vjps, numpy_vspaces, random 2 | from .numpy_wrapper import * 3 | from .numpy_wrapper import numpy_version as __version__ 4 | -------------------------------------------------------------------------------- /autograd/numpy/numpy_boxes.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | 5 | from autograd.builtins import SequenceBox 6 | from autograd.extend import Box, primitive 7 | 8 | from . import numpy_wrapper as anp 9 | 10 | Box.__array_priority__ = 90.0 11 | 12 | 13 | class ArrayBox(Box): 14 | __slots__ = [] 15 | __array_priority__ = 100.0 16 | 17 | @primitive 18 | def __getitem__(A, idx): 19 | return A[idx] 20 | 21 | # Constants w.r.t float data just pass though 22 | shape = property(lambda self: self._value.shape) 23 | ndim = property(lambda self: self._value.ndim) 24 | size = property(lambda self: self._value.size) 25 | dtype = property(lambda self: self._value.dtype) 26 | T = property(lambda self: anp.transpose(self)) 27 | 28 | def __array_namespace__(self, *, api_version: Union[str, None] = None): 29 | return anp 30 | 31 | def __len__(self): 32 | return len(self._value) 33 | 34 | def astype(self, *args, **kwargs): 35 | return anp._astype(self, *args, **kwargs) 36 | 37 | def __neg__(self): 38 | return anp.negative(self) 39 | 40 | def __add__(self, other): 41 | return anp.add(self, other) 42 | 43 | def __sub__(self, other): 44 | return anp.subtract(self, other) 45 | 46 | def __mul__(self, other): 47 | return anp.multiply(self, other) 48 | 49 | def __pow__(self, other): 50 | return anp.power(self, other) 51 | 52 | def __div__(self, other): 53 | return anp.divide(self, other) 54 | 55 | def __mod__(self, other): 56 | return anp.mod(self, other) 57 | 58 | def __truediv__(self, other): 59 | return anp.true_divide(self, other) 60 | 61 | def __matmul__(self, other): 62 | return anp.matmul(self, other) 63 | 64 | def __radd__(self, other): 65 | return anp.add(other, self) 66 | 67 | def __rsub__(self, other): 68 | return anp.subtract(other, self) 69 | 70 | def __rmul__(self, other): 71 | return anp.multiply(other, self) 72 | 73 | def __rpow__(self, other): 74 | return anp.power(other, self) 75 | 76 | def __rdiv__(self, other): 77 | return anp.divide(other, self) 78 | 79 | def __rmod__(self, other): 80 | return anp.mod(other, self) 81 | 82 | def __rtruediv__(self, other): 83 | return anp.true_divide(other, self) 84 | 85 | def __rmatmul__(self, other): 86 | return anp.matmul(other, self) 87 | 88 | def __eq__(self, other): 89 | return anp.equal(self, other) 90 | 91 | def __ne__(self, other): 92 | return anp.not_equal(self, other) 93 | 94 | def __gt__(self, other): 95 | return anp.greater(self, other) 96 | 97 | def __ge__(self, other): 98 | return anp.greater_equal(self, other) 99 | 100 | def __lt__(self, other): 101 | return anp.less(self, other) 102 | 103 | def __le__(self, other): 104 | return anp.less_equal(self, other) 
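    # Because __eq__ above returns an elementwise array rather than a bool,
    # __hash__ below falls back to identity hashing so ArrayBoxes remain
    # usable as dictionary keys.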
105 | 106 | def __abs__(self): 107 | return anp.abs(self) 108 | 109 | def __hash__(self): 110 | return id(self) 111 | 112 | 113 | ArrayBox.register(np.ndarray) 114 | for type_ in [ 115 | float, 116 | np.longdouble, 117 | np.float64, 118 | np.float32, 119 | np.float16, 120 | complex, 121 | np.clongdouble, 122 | np.complex64, 123 | np.complex128, 124 | ]: 125 | ArrayBox.register(type_) 126 | 127 | # These numpy.ndarray methods are just refs to an equivalent numpy function 128 | nondiff_methods = [ 129 | "all", 130 | "any", 131 | "argmax", 132 | "argmin", 133 | "argpartition", 134 | "argsort", 135 | "nonzero", 136 | "searchsorted", 137 | "round", 138 | ] 139 | diff_methods = [ 140 | "clip", 141 | "compress", 142 | "cumprod", 143 | "cumsum", 144 | "diagonal", 145 | "max", 146 | "mean", 147 | "min", 148 | "prod", 149 | "ptp", 150 | "ravel", 151 | "repeat", 152 | "reshape", 153 | "squeeze", 154 | "std", 155 | "sum", 156 | "swapaxes", 157 | "take", 158 | "trace", 159 | "transpose", 160 | "var", 161 | ] 162 | for method_name in nondiff_methods + diff_methods: 163 | setattr(ArrayBox, method_name, anp.__dict__[method_name]) 164 | 165 | # Flatten has no function, only a method. 166 | setattr(ArrayBox, "flatten", anp.__dict__["ravel"]) 167 | 168 | if np.lib.NumpyVersion(np.__version__) >= "2.0.0": 169 | SequenceBox.register(np.linalg._linalg.EigResult) 170 | SequenceBox.register(np.linalg._linalg.EighResult) 171 | SequenceBox.register(np.linalg._linalg.QRResult) 172 | SequenceBox.register(np.linalg._linalg.SlogdetResult) 173 | SequenceBox.register(np.linalg._linalg.SVDResult) 174 | elif np.__version__ >= "1.25": 175 | SequenceBox.register(np.linalg.linalg.EigResult) 176 | SequenceBox.register(np.linalg.linalg.EighResult) 177 | SequenceBox.register(np.linalg.linalg.QRResult) 178 | SequenceBox.register(np.linalg.linalg.SlogdetResult) 179 | SequenceBox.register(np.linalg.linalg.SVDResult) 180 | -------------------------------------------------------------------------------- /autograd/numpy/numpy_vspaces.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from autograd.builtins import NamedTupleVSpace 4 | from autograd.extend import VSpace 5 | 6 | 7 | class ArrayVSpace(VSpace): 8 | def __init__(self, value): 9 | value = np.asarray(value) 10 | self.shape = value.shape 11 | self.dtype = value.dtype 12 | 13 | @property 14 | def size(self): 15 | return np.prod(self.shape) 16 | 17 | @property 18 | def ndim(self): 19 | return len(self.shape) 20 | 21 | def zeros(self): 22 | return np.zeros(self.shape, dtype=self.dtype) 23 | 24 | def ones(self): 25 | return np.ones(self.shape, dtype=self.dtype) 26 | 27 | def standard_basis(self): 28 | for idxs in np.ndindex(*self.shape): 29 | vect = np.zeros(self.shape, dtype=self.dtype) 30 | vect[idxs] = 1 31 | yield vect 32 | 33 | def randn(self): 34 | return np.array(np.random.randn(*self.shape)).astype(self.dtype) 35 | 36 | def _inner_prod(self, x, y): 37 | return np.dot(np.ravel(x), np.ravel(y)) 38 | 39 | 40 | class ComplexArrayVSpace(ArrayVSpace): 41 | iscomplex = True 42 | 43 | @property 44 | def size(self): 45 | return np.prod(self.shape) * 2 46 | 47 | def ones(self): 48 | return np.ones(self.shape, dtype=self.dtype) + 1.0j * np.ones(self.shape, dtype=self.dtype) 49 | 50 | def standard_basis(self): 51 | for idxs in np.ndindex(*self.shape): 52 | for v in [1.0, 1.0j]: 53 | vect = np.zeros(self.shape, dtype=self.dtype) 54 | vect[idxs] = v 55 | yield vect 56 | 57 | def randn(self): 58 | return 
np.array(np.random.randn(*self.shape)).astype(self.dtype) + 1.0j * np.array( 59 | np.random.randn(*self.shape) 60 | ).astype(self.dtype) 61 | 62 | def _inner_prod(self, x, y): 63 | return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y))) 64 | 65 | def _covector(self, x): 66 | return np.conj(x) 67 | 68 | 69 | VSpace.register(np.ndarray, lambda x: ComplexArrayVSpace(x) if np.iscomplexobj(x) else ArrayVSpace(x)) 70 | 71 | for type_ in [float, np.longdouble, np.float64, np.float32, np.float16]: 72 | ArrayVSpace.register(type_) 73 | 74 | for type_ in [complex, np.clongdouble, np.complex64, np.complex128]: 75 | ComplexArrayVSpace.register(type_) 76 | 77 | 78 | if np.lib.NumpyVersion(np.__version__) >= "2.0.0": 79 | 80 | class EigResultVSpace(NamedTupleVSpace): 81 | seq_type = np.linalg._linalg.EigResult 82 | 83 | class EighResultVSpace(NamedTupleVSpace): 84 | seq_type = np.linalg._linalg.EighResult 85 | 86 | class QRResultVSpace(NamedTupleVSpace): 87 | seq_type = np.linalg._linalg.QRResult 88 | 89 | class SlogdetResultVSpace(NamedTupleVSpace): 90 | seq_type = np.linalg._linalg.SlogdetResult 91 | 92 | class SVDResultVSpace(NamedTupleVSpace): 93 | seq_type = np.linalg._linalg.SVDResult 94 | 95 | EigResultVSpace.register(np.linalg._linalg.EigResult) 96 | EighResultVSpace.register(np.linalg._linalg.EighResult) 97 | QRResultVSpace.register(np.linalg._linalg.QRResult) 98 | SlogdetResultVSpace.register(np.linalg._linalg.SlogdetResult) 99 | SVDResultVSpace.register(np.linalg._linalg.SVDResult) 100 | elif np.__version__ >= "1.25": 101 | 102 | class EigResultVSpace(NamedTupleVSpace): 103 | seq_type = np.linalg.linalg.EigResult 104 | 105 | class EighResultVSpace(NamedTupleVSpace): 106 | seq_type = np.linalg.linalg.EighResult 107 | 108 | class QRResultVSpace(NamedTupleVSpace): 109 | seq_type = np.linalg.linalg.QRResult 110 | 111 | class SlogdetResultVSpace(NamedTupleVSpace): 112 | seq_type = np.linalg.linalg.SlogdetResult 113 | 114 | class SVDResultVSpace(NamedTupleVSpace): 115 | seq_type = np.linalg.linalg.SVDResult 116 | 117 | EigResultVSpace.register(np.linalg.linalg.EigResult) 118 | EighResultVSpace.register(np.linalg.linalg.EighResult) 119 | QRResultVSpace.register(np.linalg.linalg.QRResult) 120 | SlogdetResultVSpace.register(np.linalg.linalg.SlogdetResult) 121 | SVDResultVSpace.register(np.linalg.linalg.SVDResult) 122 | -------------------------------------------------------------------------------- /autograd/numpy/random.py: -------------------------------------------------------------------------------- 1 | import numpy.random as npr 2 | 3 | from .numpy_wrapper import wrap_namespace 4 | 5 | wrap_namespace(npr.__dict__, globals()) 6 | -------------------------------------------------------------------------------- /autograd/scipy/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import integrate, signal, special, stats 2 | -------------------------------------------------------------------------------- /autograd/scipy/integrate.py: -------------------------------------------------------------------------------- 1 | import scipy.integrate 2 | 3 | import autograd.numpy as np 4 | from autograd import make_vjp 5 | from autograd.builtins import tuple 6 | from autograd.extend import defvjp_argnums, primitive 7 | from autograd.misc import flatten 8 | 9 | odeint = primitive(scipy.integrate.odeint) 10 | 11 | 12 | def grad_odeint(yt, func, y0, t, func_args, **kwargs): 13 | # Extended from "Scalable Inference of Ordinary Differential 14 | # Equation Models of Biochemical Processes", Sec. 2.4.2 15 | # Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017 16 | # https://arxiv.org/abs/1711.08079 17 | 18 | T, D = np.shape(yt) 19 | flat_args, unflatten = flatten(func_args) 20 | 21 | def flat_func(y, t, flat_args): 22 | return func(y, t, *unflatten(flat_args)) 23 | 24 | def unpack(x): 25 | # y, vjp_y, vjp_t, vjp_args 26 | return x[0:D], x[D : 2 * D], x[2 * D], x[2 * D + 1 :] 27 | 28 | def augmented_dynamics(augmented_state, t, flat_args): 29 | # Orginal system augmented with vjp_y, vjp_t and vjp_args. 30 | y, vjp_y, _, _ = unpack(augmented_state) 31 | vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args) 32 | vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y) 33 | return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args)) 34 | 35 | def vjp_all(g): 36 | vjp_y = g[-1, :] 37 | vjp_t0 = 0 38 | time_vjp_list = [] 39 | vjp_args = np.zeros(np.size(flat_args)) 40 | 41 | for i in range(T - 1, 0, -1): 42 | # Compute effect of moving measurement time. 43 | vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :]) 44 | time_vjp_list.append(vjp_cur_t) 45 | vjp_t0 = vjp_t0 - vjp_cur_t 46 | 47 | # Run augmented system backwards to the previous observation. 48 | aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args)) 49 | aug_ans = odeint( 50 | augmented_dynamics, aug_y0, np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs 51 | ) 52 | _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1]) 53 | 54 | # Add gradient from current output. 55 | vjp_y = vjp_y + g[i - 1, :] 56 | 57 | time_vjp_list.append(vjp_t0) 58 | vjp_times = np.hstack(time_vjp_list)[::-1] 59 | 60 | return None, vjp_y, vjp_times, unflatten(vjp_args) 61 | 62 | return vjp_all 63 | 64 | 65 | def argnums_unpack(all_vjp_builder): 66 | # A generic autograd helper function. Takes a function that 67 | # builds vjps for all arguments, and wraps it to return only required vjps. 68 | def build_selected_vjps(argnums, ans, combined_args, kwargs): 69 | vjp_func = all_vjp_builder(ans, *combined_args, **kwargs) 70 | 71 | def chosen_vjps(g): # Returns whichever vjps were asked for. 
72 | all_vjps = vjp_func(g) 73 | return [all_vjps[argnum] for argnum in argnums] 74 | 75 | return chosen_vjps 76 | 77 | return build_selected_vjps 78 | 79 | 80 | defvjp_argnums(odeint, argnums_unpack(grad_odeint)) 81 | -------------------------------------------------------------------------------- /autograd/scipy/linalg.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import scipy.linalg 4 | 5 | import autograd.numpy as anp 6 | from autograd.extend import defjvp, defjvp_argnums, defvjp, defvjp_argnums 7 | from autograd.numpy.numpy_wrapper import wrap_namespace 8 | 9 | wrap_namespace(scipy.linalg.__dict__, globals()) # populates module namespace 10 | 11 | 12 | def _vjp_sqrtm(ans, A, disp=True, blocksize=64): 13 | assert disp, "sqrtm vjp not implemented for disp=False" 14 | ans_transp = anp.transpose(ans) 15 | 16 | def vjp(g): 17 | return anp.real(solve_sylvester(ans_transp, ans_transp, g)) 18 | 19 | return vjp 20 | 21 | 22 | defvjp(sqrtm, _vjp_sqrtm) 23 | 24 | 25 | def _flip(a, trans): 26 | if anp.iscomplexobj(a): 27 | return "H" if trans in ("N", 0) else "N" 28 | else: 29 | return "T" if trans in ("N", 0) else "N" 30 | 31 | 32 | def grad_solve_triangular(ans, a, b, trans=0, lower=False, **kwargs): 33 | tri = anp.tril if (lower ^ (_flip(a, trans) == "N")) else anp.triu 34 | transpose = lambda x: x if _flip(a, trans) != "N" else x.T 35 | al2d = lambda x: x if x.ndim > 1 else x[..., None] 36 | 37 | def vjp(g): 38 | v = al2d(solve_triangular(a, g, trans=_flip(a, trans), lower=lower)) 39 | return -transpose(tri(anp.dot(v, al2d(ans).T))) 40 | 41 | return vjp 42 | 43 | 44 | defvjp( 45 | solve_triangular, 46 | grad_solve_triangular, 47 | lambda ans, a, b, trans=0, lower=False, **kwargs: lambda g: solve_triangular( 48 | a, g, trans=_flip(a, trans), lower=lower 49 | ), 50 | ) 51 | 52 | 53 | def grad_solve_banded(argnum, ans, l_and_u, a, b): 54 | updim = lambda x: x if x.ndim == a.ndim else x[..., None] 55 | 56 | def transpose_banded(l_and_u, a): 57 | # Compute the transpose of a banded matrix. 58 | # The transpose is itself a banded matrix. 59 | 60 | num_rows = a.shape[0] 61 | 62 | shifts = anp.arange(-l_and_u[1], l_and_u[0] + 1) 63 | 64 | T_a = anp.roll(a[:1, :], shifts[0]) 65 | for rr in range(1, num_rows): 66 | T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr : rr + 1, :], shifts[rr]))]) 67 | T_a = anp.flipud(T_a) 68 | 69 | T_l_and_u = anp.flip(l_and_u) 70 | 71 | return T_l_and_u, T_a 72 | 73 | def banded_dot(l_and_u, uu, vv): 74 | # Compute tensor product of vectors uu and vv. 75 | # Tensor product elements are resticted to the bands specified by l_and_u. 
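        # For example, with l_and_u == (1, 1) the result has three rows
        # (super-diagonal, main diagonal, sub-diagonal), matching the `ab`
        # banded-matrix layout expected by solve_banded.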
76 | 77 | # TODO: replace the brute-force ravel() by smarter dimension handeling of uu and vv 78 | 79 | # main diagonal 80 | banded_uv = anp.ravel(uu) * anp.ravel(vv) 81 | 82 | # stack below the sub-diagonals 83 | for rr in range(1, l_and_u[0] + 1): 84 | banded_uv_rr = anp.hstack([anp.ravel(uu)[rr:] * anp.ravel(vv)[:-rr], anp.zeros(rr)]) 85 | banded_uv = anp.vstack([banded_uv, banded_uv_rr]) 86 | 87 | # stack above the sup-diagonals 88 | for rr in range(1, l_and_u[1] + 1): 89 | banded_uv_rr = anp.hstack([anp.zeros(rr), anp.ravel(uu)[:-rr] * anp.ravel(vv)[rr:]]) 90 | banded_uv = anp.vstack([banded_uv_rr, banded_uv]) 91 | 92 | return banded_uv 93 | 94 | T_l_and_u, T_a = transpose_banded(l_and_u, a) 95 | 96 | if argnum == 1: 97 | return lambda g: -banded_dot( 98 | l_and_u, updim(solve_banded(T_l_and_u, T_a, g)), anp.transpose(updim(ans)) 99 | ) 100 | elif argnum == 2: 101 | return lambda g: solve_banded(T_l_and_u, T_a, g) 102 | 103 | 104 | defvjp(solve_banded, partial(grad_solve_banded, 1), partial(grad_solve_banded, 2), argnums=[1, 2]) 105 | 106 | 107 | def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64): 108 | assert disp, "sqrtm jvp not implemented for disp=False" 109 | return solve_sylvester(ans, ans, dA) 110 | 111 | 112 | defjvp(sqrtm, _jvp_sqrtm) 113 | 114 | 115 | def _jvp_sylvester(argnums, dms, ans, args, _): 116 | a, b, q = args 117 | if 0 in argnums: 118 | da = dms[0] 119 | db = dms[1] if 1 in argnums else 0 120 | else: 121 | da = 0 122 | db = dms[0] if 1 in argnums else 0 123 | dq = dms[-1] if 2 in argnums else 0 124 | rhs = dq - anp.dot(da, ans) - anp.dot(ans, db) 125 | return solve_sylvester(a, b, rhs) 126 | 127 | 128 | defjvp_argnums(solve_sylvester, _jvp_sylvester) 129 | 130 | 131 | def _vjp_sylvester(argnums, ans, args, _): 132 | a, b, q = args 133 | 134 | def vjp(g): 135 | vjps = [] 136 | q_vjp = solve_sylvester(anp.transpose(a), anp.transpose(b), g) 137 | if 0 in argnums: 138 | vjps.append(-anp.dot(q_vjp, anp.transpose(ans))) 139 | if 1 in argnums: 140 | vjps.append(-anp.dot(anp.transpose(ans), q_vjp)) 141 | if 2 in argnums: 142 | vjps.append(q_vjp) 143 | return tuple(vjps) 144 | 145 | return vjp 146 | 147 | 148 | defvjp_argnums(solve_sylvester, _vjp_sylvester) 149 | -------------------------------------------------------------------------------- /autograd/scipy/stats/__init__.py: -------------------------------------------------------------------------------- 1 | from . import beta, chi2, gamma, norm, poisson, t 2 | 3 | # Try block needed in case the user has an 4 | # old version of scipy without multivariate normal. 5 | try: 6 | from . import multivariate_normal 7 | except AttributeError: 8 | pass 9 | 10 | try: 11 | from . 
import dirichlet 12 | except AttributeError: 13 | pass 14 | -------------------------------------------------------------------------------- /autograd/scipy/stats/beta.py: -------------------------------------------------------------------------------- 1 | import scipy.stats 2 | 3 | import autograd.numpy as np 4 | from autograd.extend import defvjp, primitive 5 | from autograd.numpy.numpy_vjps import unbroadcast_f 6 | from autograd.scipy.special import beta, psi 7 | 8 | cdf = primitive(scipy.stats.beta.cdf) 9 | logpdf = primitive(scipy.stats.beta.logpdf) 10 | pdf = primitive(scipy.stats.beta.pdf) 11 | 12 | 13 | def grad_beta_logpdf_arg0(x, a, b): 14 | return (1 + a * (x - 1) + x * (b - 2)) / (x * (x - 1)) 15 | 16 | 17 | def grad_beta_logpdf_arg1(x, a, b): 18 | return np.log(x) - psi(a) + psi(a + b) 19 | 20 | 21 | def grad_beta_logpdf_arg2(x, a, b): 22 | return np.log1p(-x) - psi(b) + psi(a + b) 23 | 24 | 25 | defvjp( 26 | cdf, 27 | lambda ans, x, a, b: unbroadcast_f( 28 | x, lambda g: g * np.power(x, a - 1) * np.power(1 - x, b - 1) / beta(a, b) 29 | ), 30 | argnums=[0], 31 | ) 32 | defvjp( 33 | logpdf, 34 | lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * grad_beta_logpdf_arg0(x, a, b)), 35 | lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * grad_beta_logpdf_arg1(x, a, b)), 36 | lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * grad_beta_logpdf_arg2(x, a, b)), 37 | ) 38 | defvjp( 39 | pdf, 40 | lambda ans, x, a, b: unbroadcast_f(x, lambda g: g * ans * grad_beta_logpdf_arg0(x, a, b)), 41 | lambda ans, x, a, b: unbroadcast_f(a, lambda g: g * ans * grad_beta_logpdf_arg1(x, a, b)), 42 | lambda ans, x, a, b: unbroadcast_f(b, lambda g: g * ans * grad_beta_logpdf_arg2(x, a, b)), 43 | ) 44 | -------------------------------------------------------------------------------- /autograd/scipy/stats/chi2.py: -------------------------------------------------------------------------------- 1 | import scipy.stats 2 | 3 | import autograd.numpy as np 4 | from autograd.extend import defvjp, primitive 5 | from autograd.numpy.numpy_vjps import unbroadcast_f 6 | from autograd.scipy.special import gamma 7 | 8 | cdf = primitive(scipy.stats.chi2.cdf) 9 | logpdf = primitive(scipy.stats.chi2.logpdf) 10 | pdf = primitive(scipy.stats.chi2.pdf) 11 | 12 | 13 | def grad_chi2_logpdf(x, df): 14 | return np.where(df % 1 == 0, (df - x - 2) / (2 * x), 0) 15 | 16 | 17 | defvjp( 18 | cdf, 19 | lambda ans, x, df: unbroadcast_f( 20 | x, lambda g: g * np.power(2.0, -df / 2) * np.exp(-x / 2) * np.power(x, df / 2 - 1) / gamma(df / 2) 21 | ), 22 | argnums=[0], 23 | ) 24 | defvjp(logpdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * grad_chi2_logpdf(x, df)), argnums=[0]) 25 | defvjp(pdf, lambda ans, x, df: unbroadcast_f(x, lambda g: g * ans * grad_chi2_logpdf(x, df)), argnums=[0]) 26 | -------------------------------------------------------------------------------- /autograd/scipy/stats/dirichlet.py: -------------------------------------------------------------------------------- 1 | import scipy.stats 2 | 3 | import autograd.numpy as np 4 | from autograd.extend import defvjp, primitive 5 | from autograd.scipy.special import digamma 6 | 7 | rvs = primitive(scipy.stats.dirichlet.rvs) 8 | pdf = primitive(scipy.stats.dirichlet.pdf) 9 | logpdf = primitive(scipy.stats.dirichlet.logpdf) 10 | 11 | defvjp( 12 | logpdf, 13 | lambda ans, x, alpha: lambda g: g * (alpha - 1) / x, 14 | lambda ans, x, alpha: lambda g: g * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)), 15 | ) 16 | 17 | # Same as log pdf, but multiplied by the pdf 
(ans). 18 | defvjp( 19 | pdf, 20 | lambda ans, x, alpha: lambda g: g * ans * (alpha - 1) / x, 21 | lambda ans, x, alpha: lambda g: g * ans * (digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)), 22 | ) 23 | -------------------------------------------------------------------------------- /autograd/scipy/stats/gamma.py: -------------------------------------------------------------------------------- 1 | import scipy.stats 2 | 3 | import autograd.numpy as np 4 | from autograd.extend import defvjp, primitive 5 | from autograd.numpy.numpy_vjps import unbroadcast_f 6 | from autograd.scipy.special import gamma, psi 7 | 8 | cdf = primitive(scipy.stats.gamma.cdf) 9 | logpdf = primitive(scipy.stats.gamma.logpdf) 10 | pdf = primitive(scipy.stats.gamma.pdf) 11 | 12 | 13 | def grad_gamma_logpdf_arg0(x, a): 14 | return (a - x - 1) / x 15 | 16 | 17 | def grad_gamma_logpdf_arg1(x, a): 18 | return np.log(x) - psi(a) 19 | 20 | 21 | defvjp( 22 | cdf, 23 | lambda ans, x, a: unbroadcast_f(x, lambda g: g * np.exp(-x) * np.power(x, a - 1) / gamma(a)), 24 | argnums=[0], 25 | ) 26 | defvjp( 27 | logpdf, 28 | lambda ans, x, a: unbroadcast_f(x, lambda g: g * grad_gamma_logpdf_arg0(x, a)), 29 | lambda ans, x, a: unbroadcast_f(a, lambda g: g * grad_gamma_logpdf_arg1(x, a)), 30 | ) 31 | defvjp( 32 | pdf, 33 | lambda ans, x, a: unbroadcast_f(x, lambda g: g * ans * grad_gamma_logpdf_arg0(x, a)), 34 | lambda ans, x, a: unbroadcast_f(a, lambda g: g * ans * grad_gamma_logpdf_arg1(x, a)), 35 | ) 36 | -------------------------------------------------------------------------------- /autograd/scipy/stats/multivariate_normal.py: -------------------------------------------------------------------------------- 1 | import scipy.stats 2 | 3 | import autograd.numpy as np 4 | from autograd.extend import defvjp, primitive 5 | from autograd.numpy.numpy_vjps import unbroadcast_f 6 | 7 | pdf = primitive(scipy.stats.multivariate_normal.pdf) 8 | logpdf = primitive(scipy.stats.multivariate_normal.logpdf) 9 | entropy = primitive(scipy.stats.multivariate_normal.entropy) 10 | 11 | # With thanks to Eric Bresch. 12 | # Some formulas are from 13 | # "An extended collection of matrix derivative results 14 | # for forward and reverse mode algorithmic differentiation" 15 | # by Mike Giles 16 | # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf 17 | 18 | 19 | def generalized_outer_product(x): 20 | if np.ndim(x) == 1: 21 | return np.outer(x, x) 22 | return np.matmul(x, np.swapaxes(x, -1, -2)) 23 | 24 | 25 | def covgrad(x, mean, cov, allow_singular=False): 26 | if allow_singular: 27 | raise NotImplementedError( 28 | "The multivariate normal pdf is not differentiable w.r.t. 
a singular covariance matix" 29 | ) 30 | J = np.linalg.inv(cov) 31 | solved = np.matmul(J, np.expand_dims(x - mean, -1)) 32 | return 1.0 / 2 * (generalized_outer_product(solved) - J) 33 | 34 | 35 | def solve(allow_singular): 36 | if allow_singular: 37 | return lambda A, x: np.dot(np.linalg.pinv(A), x) 38 | else: 39 | return np.linalg.solve 40 | 41 | 42 | defvjp( 43 | logpdf, 44 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f( 45 | x, lambda g: -np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T 46 | ), 47 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f( 48 | mean, lambda g: np.expand_dims(np.atleast_1d(g), 1) * solve(allow_singular)(cov, (x - mean).T).T 49 | ), 50 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f( 51 | cov, lambda g: np.reshape(g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular) 52 | ), 53 | ) 54 | 55 | # Same as log pdf, but multiplied by the pdf (ans). 56 | defvjp( 57 | pdf, 58 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f( 59 | x, lambda g: -np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T 60 | ), 61 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f( 62 | mean, 63 | lambda g: np.expand_dims(np.atleast_1d(ans * g), 1) * solve(allow_singular)(cov, (x - mean).T).T, 64 | ), 65 | lambda ans, x, mean, cov, allow_singular=False: unbroadcast_f( 66 | cov, lambda g: np.reshape(ans * g, np.shape(g) + (1, 1)) * covgrad(x, mean, cov, allow_singular) 67 | ), 68 | ) 69 | 70 | defvjp(entropy, None, lambda ans, mean, cov: unbroadcast_f(cov, lambda g: 0.5 * g * np.linalg.inv(cov).T)) 71 | -------------------------------------------------------------------------------- /autograd/scipy/stats/norm.py: -------------------------------------------------------------------------------- 1 | """Gradients of the normal distribution.""" 2 | 3 | import scipy.stats 4 | 5 | import autograd.numpy as anp 6 | from autograd.extend import defvjp, primitive 7 | from autograd.numpy.numpy_vjps import unbroadcast_f 8 | 9 | pdf = primitive(scipy.stats.norm.pdf) 10 | cdf = primitive(scipy.stats.norm.cdf) 11 | sf = primitive(scipy.stats.norm.sf) 12 | logpdf = primitive(scipy.stats.norm.logpdf) 13 | logcdf = primitive(scipy.stats.norm.logcdf) 14 | logsf = primitive(scipy.stats.norm.logsf) 15 | 16 | defvjp( 17 | pdf, 18 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * ans * (x - loc) / scale**2), 19 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * ans * (x - loc) / scale**2), 20 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 21 | scale, lambda g: g * ans * (((x - loc) / scale) ** 2 - 1.0) / scale 22 | ), 23 | ) 24 | 25 | defvjp( 26 | cdf, 27 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * pdf(x, loc, scale)), 28 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: -g * pdf(x, loc, scale)), 29 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 30 | scale, lambda g: -g * pdf(x, loc, scale) * (x - loc) / scale 31 | ), 32 | ) 33 | 34 | defvjp( 35 | logpdf, 36 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * (x - loc) / scale**2), 37 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * (x - loc) / scale**2), 38 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 39 | scale, lambda g: g * (-1.0 / scale + (x - loc) ** 2 / scale**3) 40 | ), 41 | ) 42 | 43 | defvjp( 44 | logcdf, 45 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 46 | x, lambda g: g * 
anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)) 47 | ), 48 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 49 | loc, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)) 50 | ), 51 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 52 | scale, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logcdf(x, loc, scale)) * (x - loc) / scale 53 | ), 54 | ) 55 | 56 | defvjp( 57 | logsf, 58 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 59 | x, lambda g: -g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale)) 60 | ), 61 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 62 | loc, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale)) 63 | ), 64 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 65 | scale, lambda g: g * anp.exp(logpdf(x, loc, scale) - logsf(x, loc, scale)) * (x - loc) / scale 66 | ), 67 | ) 68 | 69 | defvjp( 70 | sf, 71 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: -g * pdf(x, loc, scale)), 72 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: g * pdf(x, loc, scale)), 73 | lambda ans, x, loc=0.0, scale=1.0: unbroadcast_f( 74 | scale, lambda g: g * pdf(x, loc, scale) * (x - loc) / scale 75 | ), 76 | ) 77 | -------------------------------------------------------------------------------- /autograd/scipy/stats/poisson.py: -------------------------------------------------------------------------------- 1 | import scipy.stats 2 | 3 | import autograd.numpy as np 4 | from autograd.extend import defvjp, primitive 5 | from autograd.numpy.numpy_vjps import unbroadcast_f 6 | 7 | cdf = primitive(scipy.stats.poisson.cdf) 8 | logpmf = primitive(scipy.stats.poisson.logpmf) 9 | pmf = primitive(scipy.stats.poisson.pmf) 10 | 11 | 12 | def grad_poisson_logpmf(k, mu): 13 | return np.where(k % 1 == 0, k / mu - 1, 0) 14 | 15 | 16 | defvjp(cdf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * -pmf(np.floor(k), mu)), argnums=[1]) 17 | defvjp(logpmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * grad_poisson_logpmf(k, mu)), argnums=[1]) 18 | defvjp( 19 | pmf, lambda ans, k, mu: unbroadcast_f(mu, lambda g: g * ans * grad_poisson_logpmf(k, mu)), argnums=[1] 20 | ) 21 | -------------------------------------------------------------------------------- /autograd/scipy/stats/t.py: -------------------------------------------------------------------------------- 1 | """Gradients of the univariate t distribution.""" 2 | 3 | import scipy.stats 4 | 5 | import autograd.numpy as np 6 | from autograd.extend import defvjp, primitive 7 | from autograd.numpy.numpy_vjps import unbroadcast_f 8 | from autograd.scipy.special import psi 9 | 10 | pdf = primitive(scipy.stats.t.pdf) 11 | cdf = primitive(scipy.stats.t.cdf) 12 | logpdf = primitive(scipy.stats.t.logpdf) 13 | logcdf = primitive(scipy.stats.t.logcdf) 14 | 15 | 16 | def grad_tlogpdf_diff(diff, df): 17 | return -diff * (1.0 + df) / (diff**2 + df) 18 | 19 | 20 | def grad_tlogpdf_x(x, df, loc, scale): 21 | return grad_tlogpdf_diff((x - loc) / scale, df) / scale 22 | 23 | 24 | def grad_tlogpdf_loc(x, df, loc, scale): 25 | return -grad_tlogpdf_diff((x - loc) / scale, df) / scale 26 | 27 | 28 | def grad_tlogpdf_scale(x, df, loc, scale): 29 | diff = x - loc 30 | return -(df * (scale**2 - diff**2)) / (scale * (df * scale**2 + diff**2)) 31 | 32 | 33 | def grad_tlogpdf_df(x, df, loc, scale): 34 | y = (x - loc) / scale 35 | return 0.5 * ( 36 | (y**2 * (df + 1)) / (df * (y**2 + df)) 37 | - np.log(y**2 / df + 1) 38 | - 1.0 / df 39 | - psi(df / 2.0) 40 | + psi((df + 1) / 2.0) 41 | ) 42 | 43 | 44 | 
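# As with the other distributions in this package, the pdf VJPs below are the
# corresponding logpdf gradients multiplied by the pdf value (ans).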
defvjp( 45 | pdf, 46 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 47 | x, lambda g: g * ans * grad_tlogpdf_x(x, df, loc, scale) 48 | ), 49 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 50 | df, lambda g: g * ans * grad_tlogpdf_df(x, df, loc, scale) 51 | ), 52 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 53 | loc, lambda g: g * ans * grad_tlogpdf_loc(x, df, loc, scale) 54 | ), 55 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 56 | scale, lambda g: g * ans * grad_tlogpdf_scale(x, df, loc, scale) 57 | ), 58 | ) 59 | 60 | defvjp( 61 | cdf, 62 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * pdf(x, df, loc, scale)), 63 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(loc, lambda g: -g * pdf(x, df, loc, scale)), 64 | argnums=(0, 2), 65 | ) 66 | 67 | defvjp( 68 | logpdf, 69 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f(x, lambda g: g * grad_tlogpdf_x(x, df, loc, scale)), 70 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 71 | df, lambda g: g * grad_tlogpdf_df(x, df, loc, scale) 72 | ), 73 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 74 | loc, lambda g: g * grad_tlogpdf_loc(x, df, loc, scale) 75 | ), 76 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 77 | scale, lambda g: g * grad_tlogpdf_scale(x, df, loc, scale) 78 | ), 79 | ) 80 | 81 | defvjp( 82 | logcdf, 83 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 84 | x, lambda g: g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale)) 85 | ), 86 | lambda ans, x, df, loc=0.0, scale=1.0: unbroadcast_f( 87 | loc, lambda g: -g * np.exp(logpdf(x, df, loc, scale) - logcdf(x, df, loc, scale)) 88 | ), 89 | argnums=(0, 2), 90 | ) 91 | -------------------------------------------------------------------------------- /autograd/test_util.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | from .core import make_jvp, make_vjp, vspace 4 | from .wrap_util import get_name, unary_to_nary 5 | 6 | TOL = 1e-6 7 | RTOL = 1e-6 8 | 9 | 10 | def scalar_close(a, b): 11 | return abs(a - b) < TOL or abs(a - b) / abs(a + b) < RTOL 12 | 13 | 14 | EPS = 1e-6 15 | 16 | 17 | def make_numerical_jvp(f, x): 18 | y = f(x) 19 | x_vs, y_vs = vspace(x), vspace(y) 20 | 21 | def jvp(v): 22 | # (f(x + v*eps/2) - f(x - v*eps/2)) / eps 23 | f_x_plus = f(x_vs.add(x, x_vs.scalar_mul(v, EPS / 2))) 24 | f_x_minus = f(x_vs.add(x, x_vs.scalar_mul(v, -EPS / 2))) 25 | neg_f_x_minus = y_vs.scalar_mul(f_x_minus, -1.0) 26 | return y_vs.scalar_mul(y_vs.add(f_x_plus, neg_f_x_minus), 1.0 / EPS) 27 | 28 | return jvp 29 | 30 | 31 | def check_vjp(f, x): 32 | vjp, y = make_vjp(f, x) 33 | jvp = make_numerical_jvp(f, x) 34 | x_vs, y_vs = vspace(x), vspace(y) 35 | x_v, y_v = x_vs.randn(), y_vs.randn() 36 | 37 | vjp_y = x_vs.covector(vjp(y_vs.covector(y_v))) 38 | assert vspace(vjp_y) == x_vs 39 | vjv_exact = x_vs.inner_prod(x_v, vjp_y) 40 | vjv_numeric = y_vs.inner_prod(y_v, jvp(x_v)) 41 | assert scalar_close(vjv_numeric, vjv_exact), ( 42 | "Derivative (VJP) check of {} failed with arg {}:\nanalytic: {}\nnumeric: {}".format( 43 | get_name(f), x, vjv_exact, vjv_numeric 44 | ) 45 | ) 46 | 47 | 48 | def check_jvp(f, x): 49 | jvp = make_jvp(f, x) 50 | jvp_numeric = make_numerical_jvp(f, x) 51 | x_v = vspace(x).randn() 52 | check_equivalent(jvp(x_v)[1], jvp_numeric(x_v)) 53 | 54 | 55 | def check_equivalent(x, y): 56 | x_vs, y_vs = vspace(x), vspace(y) 57 | assert x_vs == y_vs, f"VSpace mismatch:\nx: {x_vs}\ny: {y_vs}" 58 | v = 
x_vs.randn() 59 | assert scalar_close(x_vs.inner_prod(x, v), x_vs.inner_prod(y, v)), f"Value mismatch:\nx: {x}\ny: {y}" 60 | 61 | 62 | @unary_to_nary 63 | def check_grads(f, x, modes=["fwd", "rev"], order=2): 64 | assert all(m in ["fwd", "rev"] for m in modes) 65 | if "fwd" in modes: 66 | check_jvp(f, x) 67 | if order > 1: 68 | grad_f = lambda x, v: make_jvp(f, x)(v)[1] 69 | grad_f.__name__ = f"jvp_{get_name(f)}" 70 | v = vspace(x).randn() 71 | check_grads(grad_f, (0, 1), modes, order=order - 1)(x, v) 72 | if "rev" in modes: 73 | check_vjp(f, x) 74 | if order > 1: 75 | grad_f = lambda x, v: make_vjp(f, x)[0](v) 76 | grad_f.__name__ = f"vjp_{get_name(f)}" 77 | v = vspace(f(x)).randn() 78 | check_grads(grad_f, (0, 1), modes, order=order - 1)(x, v) 79 | 80 | 81 | def combo_check(fun, *args, **kwargs): 82 | # Tests all combinations of args and kwargs given. 83 | _check_grads = lambda f: check_grads(f, *args, **kwargs) 84 | 85 | def _combo_check(*args, **kwargs): 86 | kwarg_key_vals = [[(k, x) for x in xs] for k, xs in kwargs.items()] 87 | for _args in product(*args): 88 | for _kwargs in product(*kwarg_key_vals): 89 | _check_grads(fun)(*_args, **dict(_kwargs)) 90 | 91 | return _combo_check 92 | -------------------------------------------------------------------------------- /autograd/tracer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | 5 | from .util import subvals, toposort 6 | from .wrap_util import wraps 7 | 8 | 9 | def trace(start_node, fun, x): 10 | with trace_stack.new_trace() as t: 11 | start_box = new_box(x, t, start_node) 12 | end_box = fun(start_box) 13 | if isbox(end_box) and end_box._trace == start_box._trace: 14 | return end_box._value, end_box._node 15 | else: 16 | warnings.warn("Output seems independent of input.") 17 | return end_box, None 18 | 19 | 20 | class Node: 21 | __slots__ = [] 22 | 23 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents): 24 | assert False 25 | 26 | def initialize_root(self, *args, **kwargs): 27 | assert False 28 | 29 | @classmethod 30 | def new_root(cls, *args, **kwargs): 31 | root = cls.__new__(cls) 32 | root.initialize_root(*args, **kwargs) 33 | return root 34 | 35 | 36 | def primitive(f_raw): 37 | """ 38 | Wraps a function so that its gradient can be specified and its invocation 39 | can be recorded. 
For examples, see the docs.""" 40 | 41 | @wraps(f_raw) 42 | def f_wrapped(*args, **kwargs): 43 | boxed_args, trace, node_constructor = find_top_boxed_args(args) 44 | if boxed_args: 45 | argvals = subvals(args, [(argnum, box._value) for argnum, box in boxed_args]) 46 | if f_wrapped in notrace_primitives[node_constructor]: 47 | return f_wrapped(*argvals, **kwargs) 48 | parents = tuple(box._node for _, box in boxed_args) 49 | argnums = tuple(argnum for argnum, _ in boxed_args) 50 | ans = f_wrapped(*argvals, **kwargs) 51 | node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents) 52 | return new_box(ans, trace, node) 53 | else: 54 | return f_raw(*args, **kwargs) 55 | 56 | f_wrapped.fun = f_raw 57 | f_wrapped._is_autograd_primitive = True 58 | return f_wrapped 59 | 60 | 61 | notrace_primitives = defaultdict(set) 62 | 63 | 64 | def register_notrace(trace_type, primitive_fun): 65 | notrace_primitives[trace_type].add(primitive_fun) 66 | 67 | 68 | def notrace_primitive(f_raw): 69 | @wraps(f_raw) 70 | def f_wrapped(*args, **kwargs): 71 | argvals = map(getval, args) 72 | return f_raw(*argvals, **kwargs) 73 | 74 | f_wrapped._is_primitive = True 75 | return f_wrapped 76 | 77 | 78 | def find_top_boxed_args(args): 79 | top_trace = -1 80 | top_boxes = [] 81 | top_node_type = None 82 | for argnum, arg in enumerate(args): 83 | if isbox(arg): 84 | trace = arg._trace 85 | if trace > top_trace: 86 | top_boxes = [(argnum, arg)] 87 | top_trace = trace 88 | top_node_type = type(arg._node) 89 | elif trace == top_trace: 90 | top_boxes.append((argnum, arg)) 91 | return top_boxes, top_trace, top_node_type 92 | 93 | 94 | class TraceStack: 95 | def __init__(self): 96 | self.top = -1 97 | 98 | @contextmanager 99 | def new_trace(self): 100 | self.top += 1 101 | yield self.top 102 | self.top -= 1 103 | 104 | 105 | trace_stack = TraceStack() 106 | 107 | 108 | class Box: 109 | type_mappings = {} 110 | types = set() 111 | 112 | __slots__ = ["_value", "_trace", "_node"] 113 | 114 | def __init__(self, value, trace, node): 115 | self._value = value 116 | self._node = node 117 | self._trace = trace 118 | 119 | def __bool__(self): 120 | return bool(self._value) 121 | 122 | __nonzero__ = __bool__ 123 | 124 | def __str__(self): 125 | return f"Autograd {type(self).__name__} with value {str(self._value)}" 126 | 127 | @classmethod 128 | def register(cls, value_type): 129 | Box.types.add(cls) 130 | Box.type_mappings[value_type] = cls 131 | Box.type_mappings[cls] = cls 132 | 133 | 134 | box_type_mappings = Box.type_mappings 135 | 136 | 137 | def new_box(value, trace, node): 138 | try: 139 | return box_type_mappings[type(value)](value, trace, node) 140 | except KeyError: 141 | raise TypeError(f"Can't differentiate w.r.t. 
type {type(value)}") 142 | 143 | 144 | box_types = Box.types 145 | isbox = lambda x: type(x) in box_types # almost 3X faster than isinstance(x, Box) 146 | getval = lambda x: getval(x._value) if isbox(x) else x 147 | -------------------------------------------------------------------------------- /autograd/util.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | 4 | def subvals(x, ivs): 5 | x_ = list(x) 6 | for i, v in ivs: 7 | x_[i] = v 8 | return tuple(x_) 9 | 10 | 11 | def subval(x, i, v): 12 | x_ = list(x) 13 | x_[i] = v 14 | return tuple(x_) 15 | 16 | 17 | def func(f): 18 | return f 19 | 20 | 21 | def toposort(end_node, parents=operator.attrgetter("parents")): 22 | child_counts = {} 23 | stack = [end_node] 24 | while stack: 25 | node = stack.pop() 26 | if node in child_counts: 27 | child_counts[node] += 1 28 | else: 29 | child_counts[node] = 1 30 | stack.extend(parents(node)) 31 | 32 | childless_nodes = [end_node] 33 | while childless_nodes: 34 | node = childless_nodes.pop() 35 | yield node 36 | for parent in parents(node): 37 | if child_counts[parent] == 1: 38 | childless_nodes.append(parent) 39 | else: 40 | child_counts[parent] -= 1 41 | 42 | 43 | # -------------------- deprecation warnings ----------------------- 44 | 45 | import warnings 46 | 47 | deprecation_msg = """ 48 | The quick_grad_check function is deprecated. See the update guide: 49 | https://github.com/HIPS/autograd/blob/master/docs/updateguide.md""" 50 | 51 | 52 | def quick_grad_check( 53 | fun, arg0, extra_args=(), kwargs={}, verbose=True, eps=1e-4, rtol=1e-4, atol=1e-6, rs=None 54 | ): 55 | warnings.warn(deprecation_msg) 56 | from autograd.test_util import check_grads 57 | 58 | fun_ = lambda arg0: fun(arg0, *extra_args, **kwargs) 59 | check_grads(fun_, modes=["rev"], order=1)(arg0) 60 | -------------------------------------------------------------------------------- /autograd/wrap_util.py: -------------------------------------------------------------------------------- 1 | from .util import subvals 2 | 3 | 4 | def unary_to_nary(unary_operator): 5 | @wraps(unary_operator) 6 | def nary_operator(fun, argnum=0, *nary_op_args, **nary_op_kwargs): 7 | assert type(argnum) in (int, tuple, list), argnum 8 | 9 | @wrap_nary_f(fun, unary_operator, argnum) 10 | def nary_f(*args, **kwargs): 11 | @wraps(fun) 12 | def unary_f(x): 13 | if isinstance(argnum, int): 14 | subargs = subvals(args, [(argnum, x)]) 15 | else: 16 | subargs = subvals(args, zip(argnum, x)) 17 | return fun(*subargs, **kwargs) 18 | 19 | if isinstance(argnum, int): 20 | x = args[argnum] 21 | else: 22 | x = tuple(args[i] for i in argnum) 23 | return unary_operator(unary_f, x, *nary_op_args, **nary_op_kwargs) 24 | 25 | return nary_f 26 | 27 | return nary_operator 28 | 29 | 30 | def wraps(fun, namestr="{fun}", docstr="{doc}", **kwargs): 31 | def _wraps(f): 32 | try: 33 | f.__name__ = namestr.format(fun=get_name(fun), **kwargs) 34 | f.__doc__ = docstr.format(fun=get_name(fun), doc=get_doc(fun), **kwargs) 35 | except BaseException: 36 | pass 37 | finally: 38 | return f 39 | 40 | return _wraps 41 | 42 | 43 | def wrap_nary_f(fun, op, argnum): 44 | namestr = "{op}_of_{fun}_wrt_argnum_{argnum}" 45 | docstr = """\ 46 | {op} of function {fun} with respect to argument number {argnum}. Takes the 47 | same arguments as {fun} but returns the {op}. 
48 | """ 49 | return wraps(fun, namestr, docstr, op=get_name(op), argnum=argnum) 50 | 51 | 52 | get_name = lambda f: getattr(f, "__name__", "[unknown name]") 53 | get_doc = lambda f: getattr(f, "__doc__", "") 54 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/benchmarks/__init__.py -------------------------------------------------------------------------------- /benchmarks/asv.conf.json.sample: -------------------------------------------------------------------------------- 1 | { 2 | "version": 1, 3 | "project": "autograd", 4 | "project_url": "http://github.com/hips/autograd", 5 | "branches": ["master"], 6 | "dvcs": "git", 7 | "environment_type": "virtualenv", 8 | "install_timeout": 600, 9 | "repo" : "..", 10 | "benchmark_dir" : ".", 11 | "env_dir" : "../.asv/env", 12 | "results_dir" : "../.asv/results", 13 | "html_dir" : "../.asv/html", 14 | } 15 | -------------------------------------------------------------------------------- /benchmarks/bench_core.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | 3 | import autograd.numpy as np 4 | from autograd import grad 5 | 6 | try: 7 | from autograd.core import VJPNode, backward_pass, vspace 8 | from autograd.tracer import new_box, trace 9 | 10 | MASTER_BRANCH = False 11 | except ImportError: 12 | from autograd.core import backward_pass, forward_pass, new_progenitor, vspace 13 | 14 | MASTER_BRANCH = True 15 | 16 | 17 | ## SHORT FUNCTION 18 | def f_short(x): 19 | return x**2 20 | 21 | 22 | def time_short_fun(): 23 | f_short(2.0) 24 | 25 | 26 | def time_short_forward_pass(): 27 | if MASTER_BRANCH: 28 | forward_pass(f_short, (2.0,), {}) 29 | else: 30 | start_node = VJPNode.new_root() 31 | trace(start_node, f_short, x) 32 | 33 | 34 | def time_short_backward_pass(): 35 | if MASTER_BRANCH: 36 | backward_pass(1.0, short_end_node, short_start_node) 37 | else: 38 | backward_pass(1.0, short_end_node) 39 | 40 | 41 | def time_short_grad(): 42 | grad(f_short)(2.0) 43 | 44 | 45 | ## LONG FUNCTION 46 | def f_long(x): 47 | for i in range(50): 48 | x = np.sin(x) 49 | return x 50 | 51 | 52 | def time_long_fun(): 53 | f_long(2.0) 54 | 55 | 56 | def time_long_forward_pass(): 57 | if MASTER_BRANCH: 58 | forward_pass(f_long, (2.0,), {}) 59 | else: 60 | start_node = VJPNode.new_root() 61 | trace(start_node, f_long, x) 62 | 63 | 64 | def time_long_backward_pass(): 65 | if MASTER_BRANCH: 66 | backward_pass(1.0, long_end_node, long_start_node) 67 | else: 68 | backward_pass(1.0, long_end_node) 69 | 70 | 71 | def time_long_grad(): 72 | grad(f_long)(2.0) 73 | 74 | 75 | ## 'PEARLMUTTER TEST' FUNCTION 76 | def fan_out_fan_in(x): 77 | for i in range(10**4): 78 | x = (x + x) / 2.0 79 | return np.sum(x) 80 | 81 | 82 | def time_fan_out_fan_in_fun(): 83 | fan_out_fan_in(2.0) 84 | 85 | 86 | def time_fan_out_fan_in_forward_pass(): 87 | if MASTER_BRANCH: 88 | forward_pass(fan_out_fan_in, (2.0,), {}) 89 | else: 90 | start_node = VJPNode.new_root() 91 | trace(start_node, fan_out_fan_in, x) 92 | 93 | 94 | def time_fan_out_fan_in_backward_pass(): 95 | if MASTER_BRANCH: 96 | backward_pass(1.0, fan_end_node, fan_start_node) 97 | else: 98 | backward_pass(1.0, fan_end_node) 99 | 100 | 101 | def time_fan_out_fan_in_grad(): 102 | grad(fan_out_fan_in)(2.0) 103 | 104 | 105 | ## UNIT BENCHMARKS 106 | def 
time_vspace_float(): 107 | vspace(1.0) 108 | 109 | 110 | A = np.array([[1.0, 2.0, 3.0]]) 111 | 112 | 113 | def time_vspace_array(): 114 | vspace(A) 115 | 116 | 117 | def time_new_box_float(): 118 | new_box(1.0, 0, start_node) 119 | 120 | 121 | def time_new_box_array(): 122 | new_box(A, 0, start_node) 123 | 124 | 125 | def time_exp_call(): 126 | onp.exp(2.0) 127 | 128 | 129 | def time_exp_primitive_call_unboxed(): 130 | np.exp(2.0) 131 | 132 | 133 | def time_exp_primitive_call_boxed(): 134 | if MASTER_BRANCH: 135 | np.exp(progenitor) 136 | else: 137 | np.exp(start_box) 138 | 139 | 140 | def time_no_autograd_control(): 141 | # Test whether the benchmarking machine is running slowly independent of autograd 142 | A = np.random.randn(200, 200) 143 | np.dot(A, A) 144 | 145 | 146 | if MASTER_BRANCH: 147 | short_start_node, short_end_node = forward_pass(f_short, (2.0,), {}) 148 | long_start_node, long_end_node = forward_pass(f_long, (2.0,), {}) 149 | fan_start_node, fan_end_node = forward_pass(fan_out_fan_in, (2.0,), {}) 150 | progenitor = new_progenitor(2.0) 151 | else: 152 | x = 2.0 153 | start_node = VJPNode.new_root() 154 | start_box = new_box(x, 0, start_node) 155 | _, short_end_node = trace(VJPNode.new_root(), f_short, x) 156 | _, long_end_node = trace(VJPNode.new_root(), f_long, x) 157 | _, fan_end_node = trace(VJPNode.new_root(), fan_out_fan_in, x) 158 | -------------------------------------------------------------------------------- /benchmarks/bench_mem.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | from autograd import grad 3 | 4 | 5 | def peakmem_needless_nodes(): 6 | N, M = 1000, 100 7 | 8 | def fun(x): 9 | for i in range(M): 10 | x = x + 1 11 | return np.sum(x) 12 | 13 | grad(fun)(np.zeros((N, N))) 14 | -------------------------------------------------------------------------------- /benchmarks/bench_numpy_vjps.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import make_vjp 4 | 5 | dot_0 = lambda a, b, g: make_vjp(np.dot, argnum=0)(a, b)[0](g) 6 | dot_1 = lambda a, b, g: make_vjp(np.dot, argnum=1)(a, b)[0](g) 7 | 8 | dot_0_0 = lambda a, b, g: make_vjp(dot_0, argnum=0)(a, b, g)[0](a) 9 | dot_0_1 = lambda a, b, g: make_vjp(dot_0, argnum=1)(a, b, g)[0](a) 10 | dot_0_2 = lambda a, b, g: make_vjp(dot_0, argnum=2)(a, b, g)[0](a) 11 | 12 | dot_1_0 = lambda a, b, g: make_vjp(dot_1, argnum=0)(a, b, g)[0](b) 13 | dot_1_1 = lambda a, b, g: make_vjp(dot_1, argnum=1)(a, b, g)[0](b) 14 | dot_1_2 = lambda a, b, g: make_vjp(dot_1, argnum=2)(a, b, g)[0](b) 15 | 16 | a = npr.randn(2, 3, 4, 5) 17 | b = npr.randn(2, 3, 5, 4) 18 | g = npr.randn(2, 3, 4, 2, 3, 4) 19 | 20 | 21 | def time_dot_0(): 22 | dot_0(a, b, g) 23 | 24 | 25 | def time_dot_1(): 26 | dot_1(a, b, g) 27 | 28 | 29 | def time_dot_0_0(): 30 | dot_0_0(a, b, g) 31 | 32 | 33 | def time_dot_0_1(): 34 | dot_0_1(a, b, g) 35 | 36 | 37 | def time_dot_0_2(): 38 | dot_0_2(a, b, g) 39 | 40 | 41 | def time_dot_1_0(): 42 | dot_1_0(a, b, g) 43 | 44 | 45 | def time_dot_1_1(): 46 | dot_1_1(a, b, g) 47 | 48 | 49 | def time_dot_1_2(): 50 | dot_1_2(a, b, g) 51 | 52 | 53 | tensordot_0 = lambda A, B, G: make_vjp(np.tensordot, argnum=0)(A, B, 2)[0](G) 54 | tensordot_1 = lambda A, B, G: make_vjp(np.tensordot, argnum=1)(A, B, 2)[0](G) 55 | 56 | tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0, argnum=0)(A, B, G)[0](A) 57 | tensordot_0_1 = lambda A, B, G: 
make_vjp(tensordot_0, argnum=1)(A, B, G)[0](A) 58 | tensordot_0_2 = lambda A, B, G: make_vjp(tensordot_0, argnum=2)(A, B, G)[0](A) 59 | 60 | tensordot_1_0 = lambda A, B, G: make_vjp(tensordot_1, argnum=0)(A, B, G)[0](B) 61 | tensordot_1_1 = lambda A, B, G: make_vjp(tensordot_1, argnum=1)(A, B, G)[0](B) 62 | tensordot_1_2 = lambda A, B, G: make_vjp(tensordot_1, argnum=2)(A, B, G)[0](B) 63 | 64 | A = npr.randn(2, 3, 5, 4) 65 | B = npr.randn(5, 4, 2, 3) 66 | G = npr.randn(2, 3, 2, 3) 67 | 68 | 69 | def time_tensordot_0(): 70 | tensordot_0(A, B, G) 71 | 72 | 73 | def time_tensordot_1(): 74 | tensordot_1(A, B, G) 75 | 76 | 77 | def time_tensordot_0_0(): 78 | tensordot_0_0(A, B, G) 79 | 80 | 81 | def time_tensordot_0_1(): 82 | tensordot_0_1(A, B, G) 83 | 84 | 85 | def time_tensordot_0_2(): 86 | tensordot_0_2(A, B, G) 87 | 88 | 89 | def time_tensordot_1_0(): 90 | tensordot_1_0(A, B, G) 91 | 92 | 93 | def time_tensordot_1_1(): 94 | tensordot_1_1(A, B, G) 95 | 96 | 97 | def time_tensordot_1_2(): 98 | tensordot_1_2(A, B, G) 99 | -------------------------------------------------------------------------------- /benchmarks/bench_util.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import grad 4 | 5 | try: 6 | from autograd.misc.flatten import flatten 7 | except ImportError: 8 | from autograd.util import flatten 9 | 10 | 11 | def time_flatten(): 12 | val = { 13 | "k": npr.random((4, 4)), 14 | "k2": npr.random((3, 3)), 15 | "k3": 3.0, 16 | "k4": [1.0, 4.0, 7.0, 9.0], 17 | "k5": np.array([4.0, 5.0, 6.0]), 18 | "k6": np.array([[7.0, 8.0], [9.0, 10.0]]), 19 | } 20 | 21 | vect, unflatten = flatten(val) 22 | val_recovered = unflatten(vect) 23 | vect_2, _ = flatten(val_recovered) 24 | 25 | 26 | # def time_vspace_flatten(): 27 | # val = {'k': npr.random((4, 4)), 28 | # 'k2': npr.random((3, 3)), 29 | # 'k3': 3.0, 30 | # 'k4': [1.0, 4.0, 7.0, 9.0], 31 | # 'k5': np.array([4., 5., 6.]), 32 | # 'k6': np.array([[7., 8.], [9., 10.]])} 33 | 34 | # vspace_flatten(val) 35 | 36 | 37 | def time_grad_flatten(): 38 | val = { 39 | "k": npr.random((4, 4)), 40 | "k2": npr.random((3, 3)), 41 | "k3": 3.0, 42 | "k4": [1.0, 4.0, 7.0, 9.0], 43 | "k5": np.array([4.0, 5.0, 6.0]), 44 | "k6": np.array([[7.0, 8.0], [9.0, 10.0]]), 45 | } 46 | 47 | vect, unflatten = flatten(val) 48 | 49 | def fun(vec): 50 | v = unflatten(vec) 51 | return np.sum(v["k5"]) + np.sum(v["k6"]) 52 | 53 | grad(fun)(vect) 54 | -------------------------------------------------------------------------------- /conda_recipe/conda.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: autograd 3 | # there are ways to derive version from other sources; for now, it's hard-coded 4 | version: 1.1.1 5 | 6 | source: 7 | {% if not environ.get('BINSTAR_PLATFORM', None) %} 8 | git_url: ../ 9 | {% else %} 10 | # we're building on binstar, we already have the repo; treat as local path 11 | path: ../ 12 | {% endif %} 13 | 14 | requirements: 15 | build: 16 | - python 17 | - hatch 18 | - hatchling 19 | - future 20 | - numpy >=1.9 21 | 22 | run: 23 | - python 24 | - future 25 | - numpy >=1.9 26 | 27 | build: 28 | script: pip install . --no-deps 29 | 30 | test: 31 | # Python imports 32 | imports: 33 | - autograd 34 | - autograd.numpy 35 | 36 | about: 37 | home: https://github.com/HIPS/autograd 38 | license: MIT 39 | summary: 'Efficiently computes derivatives of numpy code.' 
40 | -------------------------------------------------------------------------------- /docs/updateguide.md: -------------------------------------------------------------------------------- 1 | # Autograd v1.2 update guide 2 | 3 | Autograd v1.2 changed the interface for defining custom vector-Jacobian 4 | products (VJPs). Luckily the change only affects users writing custom VJPs, and 5 | should only require minor updates to the custom VJP code. 6 | 7 | This guide is meant to explain why we made these changes (and others) in 8 | Autograd v1.2, and to summarize everything you need to know to update your 9 | custom VJP code. 10 | 11 | - [Reasoning for the changes](#reasoning-for-the-changes) 12 | - [New defvjp interface](#new-defvjp-interface) 13 | - [Gradient checking](#gradient-checking) 14 | 15 | ## Reasoning for the changes 16 | 17 | Here are some of the most important reasons for this update: 18 | 1. To allow us to make Autograd faster and more memory efficient, we staged the 19 | VJP functions to allow more garbage collection and eliminated almost all of 20 | the vspace metadata checks. 21 | 1. Forward-mode now comes built-in with `make_jvp`. 22 | 1. There's now a clear extension API in `autograd.extend`, so you can write 23 | custom VJPs or wrap your own numerical libraries. 24 | 1. Autograd is now backend-independent, making it easy to wrap other numerical 25 | libraries. 26 | 1. Autograd's tracing functionality is now parameterized and easily reusable, 27 | and we added some new tracers for 28 | [computation graph visualization](https://github.com/hips/autograd/blob/master/examples/dot_graph.py) 29 | and 30 | [pure-Python constant folding](https://github.com/hips/autograd/blob/master/autograd/misc/tracers.py). 31 | 1. More exhaustive, fast reverse- and forward-mode checking with `autograd.test_util.check_grads`. 32 | 1. Expensive VJPs can share work across arguments using `defvjp_argnums`. 33 | 1. These changes enabled some internal cleanups, and more features to come! 34 | 35 | ## New defvjp interface 36 | First, here's an example of the old way to write custom primitives and VJPs: 37 | ```python 38 | import autograd.numpy as np 39 | from autograd import primitive 40 | 41 | @primitive 42 | def func(x, y, z): 43 | assert z != 0 44 | return x * y**2 45 | 46 | func.defvjp(lambda g, ans, vs, gvs, x, y, z: g * y**2) 47 | func.defvjp(lambda g, ans, vs, gvs, x, y, z: 2 * g * x * y, argnum=1) 48 | func.defvjp_is_zero(argnums=[2]) 49 | ``` 50 | 51 | Here's the new way to write custom VJPs for that same primitive: 52 | ```python 53 | import autograd.numpy as np 54 | from autograd.extend import primitive, defvjp # defvjp is now a function 55 | 56 | # primitives look the same as before 57 | @primitive 58 | def func(x, y, z): 59 | assert z != 0 60 | return x * y**2 61 | 62 | # but we call defvjp differently 63 | defvjp(func, 64 | lambda ans, x, y, z: lambda g: g * y**2, 65 | lambda ans, x, y, z: lambda g: 2 * g * x * y, 66 | None) 67 | ``` 68 | 69 | Here's a list of the `defvjp` changes illustrated in that example: 70 | 1. `defvjp` is a function, rather than a method on the `primitive` class. (Actually, `primitive` is now just a function, and no longer a class.) As a result, `func.defvjp(...)` became `defvjp(func, ...)`. 71 | 1. VJPs are staged, so that instead of writing `lambda g, ans, vs, gvs, *args: ...` we write `lambda ans, *args: lambda g: ...`. This change enables a lot of automatic garbage collection. 
In the above example, if we were differentiating only with respect to `x` argument of `func`, because the VJP for `func` with respect to argument index 0 doesn't need the values of `x` or `z` from the forward pass, those values aren't stored and can instead be immediately garbage-collected. 72 | 1. There are no more `vs` and `gvs` arguments. These usually weren't used, and computing vspace metadata for every intermediate value proved to contribute significant overhead for some programs. Autograd now avoids computing vspace metadata unless necessary. 73 | 1. `defvjp` lets you define VJPs with respect to multiple arguments at once, and the argnum(s) involved are often implicit. 74 | 75 | Here's another example, this time showing how to define VJPs with respect to 76 | specific argnums, leaving the others undefined. 77 | ```python 78 | # OLD way to leave some VJPs undefined 79 | func.defvjp(lambda g, ans, vs, gvs, x, y, z, w: ..., argnum=2) 80 | func.defvjp(lambda g, ans, vs, gvs, x, y, z, w: ..., argnum=3) 81 | 82 | # NEW way to leave some VJPs undefined 83 | defvjp(func, 84 | lambda ans, x, y, z, w: lambda g: ..., 85 | lambda ans, x, y, z, w: lambda g: ..., 86 | argnums=[2, 3]) 87 | ``` 88 | 89 | ## Gradient checking 90 | Here's how to do gradient checking, whether on a composite function or on your 91 | primitive with a custom VJP: 92 | 93 | ```python 94 | from autograd.test_util import check_grads 95 | 96 | # check reverse-mode to second order 97 | check_grads(my_func, modes=['rev'], order=2)(*args_for_my_func) 98 | ``` 99 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/__init__.py -------------------------------------------------------------------------------- /examples/bayesian_neural_net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/bayesian_neural_net.png -------------------------------------------------------------------------------- /examples/bayesian_neural_net.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from black_box_svi import black_box_variational_inference 3 | 4 | import autograd.numpy as np 5 | import autograd.numpy.random as npr 6 | from autograd.misc.optimizers import adam 7 | 8 | 9 | def make_nn_funs(layer_sizes, L2_reg, noise_variance, nonlinearity=np.tanh): 10 | """These functions implement a standard multi-layer perceptron, 11 | vectorized over both training examples and weight samples.""" 12 | shapes = list(zip(layer_sizes[:-1], layer_sizes[1:])) 13 | num_weights = sum((m + 1) * n for m, n in shapes) 14 | 15 | def unpack_layers(weights): 16 | num_weight_sets = len(weights) 17 | for m, n in shapes: 18 | yield ( 19 | weights[:, : m * n].reshape((num_weight_sets, m, n)), 20 | weights[:, m * n : m * n + n].reshape((num_weight_sets, 1, n)), 21 | ) 22 | weights = weights[:, (m + 1) * n :] 23 | 24 | def predictions(weights, inputs): 25 | """weights is shape (num_weight_samples x num_weights) 26 | inputs is shape (num_datapoints x D)""" 27 | inputs = np.expand_dims(inputs, 0) 28 | for W, b in unpack_layers(weights): 29 | outputs = np.einsum("mnd,mdo->mno", inputs, W) + b 30 | inputs = nonlinearity(outputs) 31 | return outputs 32 
| 33 | def logprob(weights, inputs, targets): 34 | log_prior = -L2_reg * np.sum(weights**2, axis=1) 35 | preds = predictions(weights, inputs) 36 | log_lik = -np.sum((preds - targets) ** 2, axis=1)[:, 0] / noise_variance 37 | return log_prior + log_lik 38 | 39 | return num_weights, predictions, logprob 40 | 41 | 42 | def build_toy_dataset(n_data=40, noise_std=0.1): 43 | D = 1 44 | rs = npr.RandomState(0) 45 | inputs = np.concatenate([np.linspace(0, 2, num=n_data / 2), np.linspace(6, 8, num=n_data / 2)]) 46 | targets = np.cos(inputs) + rs.randn(n_data) * noise_std 47 | inputs = (inputs - 4.0) / 4.0 48 | inputs = inputs.reshape((len(inputs), D)) 49 | targets = targets.reshape((len(targets), D)) 50 | return inputs, targets 51 | 52 | 53 | if __name__ == "__main__": 54 | # Specify inference problem by its unnormalized log-posterior. 55 | rbf = lambda x: np.exp(-(x**2)) 56 | relu = lambda x: np.maximum(x, 0.0) 57 | num_weights, predictions, logprob = make_nn_funs( 58 | layer_sizes=[1, 20, 20, 1], L2_reg=0.1, noise_variance=0.01, nonlinearity=rbf 59 | ) 60 | 61 | inputs, targets = build_toy_dataset() 62 | log_posterior = lambda weights, t: logprob(weights, inputs, targets) 63 | 64 | # Build variational objective. 65 | objective, gradient, unpack_params = black_box_variational_inference( 66 | log_posterior, num_weights, num_samples=20 67 | ) 68 | 69 | # Set up figure. 70 | fig = plt.figure(figsize=(12, 8), facecolor="white") 71 | ax = fig.add_subplot(111, frameon=False) 72 | plt.ion() 73 | plt.show(block=False) 74 | 75 | def callback(params, t, g): 76 | print(f"Iteration {t} lower bound {-objective(params, t)}") 77 | 78 | # Sample functions from posterior. 79 | rs = npr.RandomState(0) 80 | mean, log_std = unpack_params(params) 81 | # rs = npr.RandomState(0) 82 | sample_weights = rs.randn(10, num_weights) * np.exp(log_std) + mean 83 | plot_inputs = np.linspace(-8, 8, num=400) 84 | outputs = predictions(sample_weights, np.expand_dims(plot_inputs, 1)) 85 | 86 | # Plot data and functions. 87 | plt.cla() 88 | ax.plot(inputs.ravel(), targets.ravel(), "bx") 89 | ax.plot(plot_inputs, outputs[:, :, 0].T) 90 | ax.set_ylim([-2, 3]) 91 | plt.draw() 92 | plt.pause(1.0 / 60.0) 93 | 94 | # Initialize variational parameters 95 | rs = npr.RandomState(0) 96 | init_mean = rs.randn(num_weights) 97 | init_log_std = -5 * np.ones(num_weights) 98 | init_var_params = np.concatenate([init_mean, init_log_std]) 99 | 100 | print("Optimizing variational parameters...") 101 | variational_params = adam(gradient, init_var_params, step_size=0.1, num_iters=1000, callback=callback) 102 | -------------------------------------------------------------------------------- /examples/black_box_svi.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import autograd.numpy as np 4 | import autograd.numpy.random as npr 5 | import autograd.scipy.stats.multivariate_normal as mvn 6 | import autograd.scipy.stats.norm as norm 7 | from autograd import grad 8 | from autograd.misc.optimizers import adam 9 | 10 | 11 | def black_box_variational_inference(logprob, D, num_samples): 12 | """Implements http://arxiv.org/abs/1401.0118, and uses the 13 | local reparameterization trick from http://arxiv.org/abs/1506.02557""" 14 | 15 | def unpack_params(params): 16 | # Variational dist is a diagonal Gaussian. 
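        # `params` packs the D means first and the D log standard deviations second
        # (see init_var_params below), so slicing at index D recovers both halves.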
17 | mean, log_std = params[:D], params[D:] 18 | return mean, log_std 19 | 20 | def gaussian_entropy(log_std): 21 | return 0.5 * D * (1.0 + np.log(2 * np.pi)) + np.sum(log_std) 22 | 23 | rs = npr.RandomState(0) 24 | 25 | def variational_objective(params, t): 26 | """Provides a stochastic estimate of the variational lower bound.""" 27 | mean, log_std = unpack_params(params) 28 | samples = rs.randn(num_samples, D) * np.exp(log_std) + mean 29 | lower_bound = gaussian_entropy(log_std) + np.mean(logprob(samples, t)) 30 | return -lower_bound 31 | 32 | gradient = grad(variational_objective) 33 | 34 | return variational_objective, gradient, unpack_params 35 | 36 | 37 | if __name__ == "__main__": 38 | # Specify an inference problem by its unnormalized log-density. 39 | D = 2 40 | 41 | def log_density(x, t): 42 | mu, log_sigma = x[:, 0], x[:, 1] 43 | sigma_density = norm.logpdf(log_sigma, 0, 1.35) 44 | mu_density = norm.logpdf(mu, 0, np.exp(log_sigma)) 45 | return sigma_density + mu_density 46 | 47 | # Build variational objective. 48 | objective, gradient, unpack_params = black_box_variational_inference(log_density, D, num_samples=2000) 49 | 50 | # Set up plotting code 51 | def plot_isocontours(ax, func, xlimits=[-2, 2], ylimits=[-4, 2], numticks=101): 52 | x = np.linspace(*xlimits, num=numticks) 53 | y = np.linspace(*ylimits, num=numticks) 54 | X, Y = np.meshgrid(x, y) 55 | zs = func(np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T) 56 | Z = zs.reshape(X.shape) 57 | plt.contour(X, Y, Z) 58 | ax.set_yticks([]) 59 | ax.set_xticks([]) 60 | 61 | # Set up figure. 62 | fig = plt.figure(figsize=(8, 8), facecolor="white") 63 | ax = fig.add_subplot(111, frameon=False) 64 | plt.ion() 65 | plt.show(block=False) 66 | 67 | def callback(params, t, g): 68 | print(f"Iteration {t} lower bound {-objective(params, t)}") 69 | 70 | plt.cla() 71 | target_distribution = lambda x: np.exp(log_density(x, t)) 72 | plot_isocontours(ax, target_distribution) 73 | 74 | mean, log_std = unpack_params(params) 75 | variational_contour = lambda x: mvn.pdf(x, mean, np.diag(np.exp(2 * log_std))) 76 | plot_isocontours(ax, variational_contour) 77 | plt.draw() 78 | plt.pause(1.0 / 30.0) 79 | 80 | print("Optimizing variational parameters...") 81 | init_mean = -1 * np.ones(D) 82 | init_log_std = -5 * np.ones(D) 83 | init_var_params = np.concatenate([init_mean, init_log_std]) 84 | variational_params = adam(gradient, init_var_params, step_size=0.1, num_iters=2000, callback=callback) 85 | -------------------------------------------------------------------------------- /examples/data.py: -------------------------------------------------------------------------------- 1 | import data_mnist 2 | import matplotlib.image 3 | import matplotlib.pyplot as plt 4 | 5 | import autograd.numpy as np 6 | import autograd.numpy.random as npr 7 | 8 | 9 | def load_mnist(): 10 | partial_flatten = lambda x: np.reshape(x, (x.shape[0], np.prod(x.shape[1:]))) 11 | one_hot = lambda x, k: np.array(x[:, None] == np.arange(k)[None, :], dtype=int) 12 | train_images, train_labels, test_images, test_labels = data_mnist.mnist() 13 | train_images = partial_flatten(train_images) / 255.0 14 | test_images = partial_flatten(test_images) / 255.0 15 | train_labels = one_hot(train_labels, 10) 16 | test_labels = one_hot(test_labels, 10) 17 | N_data = train_images.shape[0] 18 | 19 | return N_data, train_images, train_labels, test_images, test_labels 20 | 21 | 22 | def plot_images( 23 | images, 24 | ax, 25 | ims_per_row=5, 26 | padding=5, 27 | 
digit_dimensions=(28, 28), 28 | cmap=matplotlib.cm.binary, 29 | vmin=None, 30 | vmax=None, 31 | ): 32 | """Images should be a (N_images x pixels) matrix.""" 33 | N_images = images.shape[0] 34 | N_rows = (N_images - 1) // ims_per_row + 1 35 | pad_value = np.min(images.ravel()) 36 | concat_images = np.full( 37 | ( 38 | (digit_dimensions[0] + padding) * N_rows + padding, 39 | (digit_dimensions[1] + padding) * ims_per_row + padding, 40 | ), 41 | pad_value, 42 | ) 43 | for i in range(N_images): 44 | cur_image = np.reshape(images[i, :], digit_dimensions) 45 | row_ix = i // ims_per_row 46 | col_ix = i % ims_per_row 47 | row_start = padding + (padding + digit_dimensions[0]) * row_ix 48 | col_start = padding + (padding + digit_dimensions[1]) * col_ix 49 | concat_images[ 50 | row_start : row_start + digit_dimensions[0], col_start : col_start + digit_dimensions[1] 51 | ] = cur_image 52 | cax = ax.matshow(concat_images, cmap=cmap, vmin=vmin, vmax=vmax) 53 | plt.xticks(np.array([])) 54 | plt.yticks(np.array([])) 55 | return cax 56 | 57 | 58 | def save_images(images, filename, **kwargs): 59 | fig = plt.figure(1) 60 | fig.clf() 61 | ax = fig.add_subplot(111) 62 | plot_images(images, ax, **kwargs) 63 | fig.patch.set_visible(False) 64 | ax.patch.set_visible(False) 65 | plt.savefig(filename) 66 | 67 | 68 | def make_pinwheel(radial_std, tangential_std, num_classes, num_per_class, rate, rs=npr.RandomState(0)): 69 | """Based on code by Ryan P. Adams.""" 70 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 71 | 72 | features = rs.randn(num_classes * num_per_class, 2) * np.array([radial_std, tangential_std]) 73 | features[:, 0] += 1 74 | labels = np.repeat(np.arange(num_classes), num_per_class) 75 | 76 | angles = rads[labels] + rate * np.exp(features[:, 0]) 77 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 78 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 79 | 80 | return np.einsum("ti,tij->tj", features, rotations) 81 | -------------------------------------------------------------------------------- /examples/data_mnist.py: -------------------------------------------------------------------------------- 1 | import array 2 | import gzip 3 | import os 4 | import struct 5 | from urllib.request import urlretrieve 6 | 7 | import numpy as np 8 | 9 | 10 | def download(url, filename): 11 | if not os.path.exists("data"): 12 | os.makedirs("data") 13 | out_file = os.path.join("data", filename) 14 | if not os.path.isfile(out_file): 15 | urlretrieve(url, out_file) 16 | 17 | 18 | def mnist(): 19 | base_url = "http://yann.lecun.com/exdb/mnist/" 20 | 21 | def parse_labels(filename): 22 | with gzip.open(filename, "rb") as fh: 23 | magic, num_data = struct.unpack(">II", fh.read(8)) 24 | return np.array(array.array("B", fh.read()), dtype=np.uint8) 25 | 26 | def parse_images(filename): 27 | with gzip.open(filename, "rb") as fh: 28 | magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) 29 | return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(num_data, rows, cols) 30 | 31 | for filename in [ 32 | "train-images-idx3-ubyte.gz", 33 | "train-labels-idx1-ubyte.gz", 34 | "t10k-images-idx3-ubyte.gz", 35 | "t10k-labels-idx1-ubyte.gz", 36 | ]: 37 | download(base_url + filename, filename) 38 | 39 | train_images = parse_images("data/train-images-idx3-ubyte.gz") 40 | train_labels = parse_labels("data/train-labels-idx1-ubyte.gz") 41 | test_images = parse_images("data/t10k-images-idx3-ubyte.gz") 42 | test_labels = 
parse_labels("data/t10k-labels-idx1-ubyte.gz") 43 | 44 | return train_images, train_labels, test_images, test_labels 45 | -------------------------------------------------------------------------------- /examples/deep_gaussian_process.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from gaussian_process import make_gp_funs, rbf_covariance 3 | from scipy.optimize import minimize 4 | 5 | import autograd.numpy as np 6 | import autograd.numpy.random as npr 7 | from autograd import value_and_grad 8 | 9 | 10 | def build_step_function_dataset(D=1, n_data=40, noise_std=0.1): 11 | rs = npr.RandomState(0) 12 | inputs = np.linspace(-2, 2, num=n_data) 13 | targets = np.sign(inputs) + rs.randn(n_data) * noise_std 14 | inputs = inputs.reshape((len(inputs), D)) 15 | return inputs, targets 16 | 17 | 18 | def build_deep_gp(input_dimension, hidden_dimension, covariance_function): 19 | # GP going from input to hidden 20 | num_params_layer1, predict_layer1, log_marginal_likelihood_layer1 = make_gp_funs( 21 | covariance_function, num_cov_params=input_dimension + 1 22 | ) 23 | 24 | # GP going from hidden to output 25 | num_params_layer2, predict_layer2, log_marginal_likelihood_layer2 = make_gp_funs( 26 | covariance_function, num_cov_params=hidden_dimension + 1 27 | ) 28 | 29 | num_hidden_params = hidden_dimension * n_data 30 | total_num_params = num_params_layer1 + num_params_layer2 + num_hidden_params 31 | 32 | def unpack_all_params(all_params): 33 | layer1_params = all_params[:num_params_layer1] 34 | layer2_params = all_params[num_params_layer1 : num_params_layer1 + num_params_layer2] 35 | hiddens = all_params[num_params_layer1 + num_params_layer2 :] 36 | return layer1_params, layer2_params, hiddens 37 | 38 | def combined_predict_fun(all_params, X, y, xs): 39 | layer1_params, layer2_params, hiddens = unpack_all_params(all_params) 40 | h_star_mean, h_star_cov = predict_layer1(layer1_params, X, hiddens, xs) 41 | y_star_mean, y_star_cov = predict_layer2( 42 | layer2_params, np.atleast_2d(hiddens).T, y, np.atleast_2d(h_star_mean).T 43 | ) 44 | return y_star_mean, y_star_cov 45 | 46 | def log_marginal_likelihood(all_params): 47 | layer1_params, layer2_params, h = unpack_all_params(all_params) 48 | return log_marginal_likelihood_layer1(layer1_params, X, h) + log_marginal_likelihood_layer2( 49 | layer2_params, np.atleast_2d(h).T, y 50 | ) 51 | 52 | predict_layer_funcs = [predict_layer1, predict_layer2] 53 | 54 | return ( 55 | total_num_params, 56 | log_marginal_likelihood, 57 | combined_predict_fun, 58 | unpack_all_params, 59 | predict_layer_funcs, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | n_data = 20 65 | input_dimension = 1 66 | hidden_dimension = 1 67 | X, y = build_step_function_dataset(D=input_dimension, n_data=n_data) 68 | 69 | ( 70 | total_num_params, 71 | log_marginal_likelihood, 72 | combined_predict_fun, 73 | unpack_all_params, 74 | predict_layer_funcs, 75 | ) = build_deep_gp(input_dimension, hidden_dimension, rbf_covariance) 76 | 77 | # Set up figure. 
78 | fig = plt.figure(figsize=(12, 8), facecolor="white") 79 | ax_end_to_end = fig.add_subplot(311, frameon=False) 80 | ax_x_to_h = fig.add_subplot(312, frameon=False) 81 | ax_h_to_y = fig.add_subplot(313, frameon=False) 82 | plt.show(block=False) 83 | 84 | def plot_gp(ax, X, y, pred_mean, pred_cov, plot_xs): 85 | ax.cla() 86 | marg_std = np.sqrt(np.diag(pred_cov)) 87 | ax.plot(plot_xs, pred_mean, "b") 88 | ax.fill( 89 | np.concatenate([plot_xs, plot_xs[::-1]]), 90 | np.concatenate([pred_mean - 1.96 * marg_std, (pred_mean + 1.96 * marg_std)[::-1]]), 91 | alpha=0.15, 92 | fc="Blue", 93 | ec="None", 94 | ) 95 | 96 | # Show samples from posterior. 97 | rs = npr.RandomState(0) 98 | sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov, size=10) 99 | ax.plot(plot_xs, sampled_funcs.T) 100 | ax.plot(X, y, "kx") 101 | ax.set_ylim([-1.5, 1.5]) 102 | ax.set_xticks([]) 103 | ax.set_yticks([]) 104 | 105 | def callback(params): 106 | print(f"Log marginal likelihood {log_marginal_likelihood(params)}") 107 | 108 | # Show posterior marginals. 109 | plot_xs = np.reshape(np.linspace(-5, 5, 300), (300, 1)) 110 | pred_mean, pred_cov = combined_predict_fun(params, X, y, plot_xs) 111 | plot_gp(ax_end_to_end, X, y, pred_mean, pred_cov, plot_xs) 112 | ax_end_to_end.set_title("X to y") 113 | 114 | layer1_params, layer2_params, hiddens = unpack_all_params(params) 115 | h_star_mean, h_star_cov = predict_layer_funcs[0](layer1_params, X, hiddens, plot_xs) 116 | y_star_mean, y_star_cov = predict_layer_funcs[0](layer2_params, np.atleast_2d(hiddens).T, y, plot_xs) 117 | 118 | plot_gp(ax_x_to_h, X, hiddens, h_star_mean, h_star_cov, plot_xs) 119 | ax_x_to_h.set_title("X to hiddens") 120 | 121 | plot_gp(ax_h_to_y, np.atleast_2d(hiddens).T, y, y_star_mean, y_star_cov, plot_xs) 122 | ax_h_to_y.set_title("hiddens to y") 123 | 124 | plt.draw() 125 | plt.pause(1.0 / 60.0) 126 | 127 | # Initialize covariance parameters and hiddens. 128 | rs = npr.RandomState(0) 129 | init_params = 0.1 * rs.randn(total_num_params) 130 | 131 | print("Optimizing covariance parameters...") 132 | objective = lambda params: -log_marginal_likelihood(params) 133 | cov_params = minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback) 134 | plt.pause(10.0) 135 | -------------------------------------------------------------------------------- /examples/define_gradient.py: -------------------------------------------------------------------------------- 1 | """This example shows how to define the gradient of your own functions. 2 | This can be useful for speed, numerical stability, or in cases where 3 | your code depends on external library calls.""" 4 | 5 | import autograd.numpy as np 6 | import autograd.numpy.random as npr 7 | from autograd import grad 8 | from autograd.extend import defvjp, primitive 9 | from autograd.test_util import check_grads 10 | 11 | 12 | # @primitive tells Autograd not to look inside this function, but instead 13 | # to treat it as a black box, whose gradient might be specified later. 14 | # Functions with this decorator can contain anything that Python knows 15 | # how to execute, and you can do things like in-place operations on arrays. 16 | @primitive 17 | def logsumexp(x): 18 | """Numerically stable log(sum(exp(x))), also defined in scipy.special""" 19 | max_x = np.max(x) 20 | return max_x + np.log(np.sum(np.exp(x - max_x))) 21 | 22 | 23 | # Next, we write a function that specifies the gradient with a closure. 
24 | # The reason for the closure is so that the gradient can depend 25 | # on both the input to the original function (x), and the output of the 26 | # original function (ans). 27 | 28 | 29 | def logsumexp_vjp(ans, x): 30 | # If you want to be able to take higher-order derivatives, then all the 31 | # code inside this function must be itself differentiable by Autograd. 32 | # This closure multiplies g with the Jacobian of logsumexp (d_ans/d_x). 33 | # Because Autograd uses reverse-mode differentiation, g contains 34 | # the gradient of the objective w.r.t. ans, the output of logsumexp. 35 | # This returned VJP function doesn't close over `x`, so Python can 36 | # garbage-collect `x` if there are no references to it elsewhere. 37 | x_shape = x.shape 38 | return lambda g: np.full(x_shape, g) * np.exp(x - np.full(x_shape, ans)) 39 | 40 | 41 | # Now we tell Autograd that logsumexp has a gradient-making function. 42 | defvjp(logsumexp, logsumexp_vjp) 43 | 44 | if __name__ == "__main__": 45 | # Now we can use logsumexp() inside a larger function that we want 46 | # to differentiate. 47 | def example_func(y): 48 | z = y**2 49 | lse = logsumexp(z) 50 | return np.sum(lse) 51 | 52 | grad_of_example = grad(example_func) 53 | print("Gradient: \n", grad_of_example(npr.randn(10))) 54 | 55 | # Check the gradients numerically, just to be safe. 56 | check_grads(example_func, modes=["rev"])(npr.randn(10)) 57 | -------------------------------------------------------------------------------- /examples/dot_graph.py: -------------------------------------------------------------------------------- 1 | """Generates a graphviz DOT file of an evaluation trace. 2 | Usage (needs the dot binary, from the graphviz package, www.graphviz.org): 3 | 4 | python dot_graph.py | dot -Tpdf -o graph.pdf 5 | """ 6 | 7 | import autograd.numpy as np 8 | from autograd.tracer import Node, trace 9 | 10 | 11 | class GraphNode(Node): 12 | # Records the full graph (could have this in tracer.py) 13 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents): 14 | self.fun_name = fun.__name__ 15 | self.args = args 16 | self.parents = dict(zip(parent_argnums, parents)) 17 | self.isroot = False 18 | 19 | def initialize_root(self, x): 20 | self.isroot = True 21 | 22 | def __repr__(self): 23 | return f"node_{id(self)}" 24 | 25 | 26 | def trace_graph(f, x): 27 | start_node = GraphNode.new_root(x) 28 | _, node = trace(start_node, f, x) 29 | return node 30 | 31 | 32 | dot_edge = "{} -> {} [color=gray30];\n".format 33 | dot_function_node = '{} [label="{}", shape=box, color=lightblue, style=filled];\n'.format 34 | dot_variable_node = '{} [label="{}", color=orange, style=filled];\n'.format 35 | dot_graph = "digraph G {{{}}}".format 36 | 37 | 38 | def graph_to_dotfile(graph): 39 | visited = set() 40 | 41 | def node_to_fragment(node): 42 | visited.add(node) 43 | if node.isroot: 44 | return dot_variable_node(node, "input") 45 | fragment = dot_function_node(node, node.fun_name) 46 | for argnum, arg in enumerate(node.args): 47 | if argnum in node.parents: 48 | parent = node.parents[argnum] 49 | fragment += dot_edge(parent, node) 50 | if parent not in visited: 51 | fragment += node_to_fragment(parent) 52 | else: 53 | argnode = f"{node}_arg_{argnum}" 54 | fragment += dot_edge(argnode, node) 55 | fragment += dot_variable_node(argnode, arg) 56 | 57 | return fragment 58 | 59 | dot_body = node_to_fragment(graph) 60 | dot_body += dot_variable_node("output", "output") 61 | dot_body += dot_edge(graph, "output") 62 | return dot_graph(dot_body) 63 |
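
# A small convenience helper (a sketch, not part of the original script): write the
# DOT text to a temporary file and call the graphviz `dot` binary directly, instead
# of piping stdout as the module docstring suggests. It assumes `dot` is on PATH;
# the helper name `save_dot_pdf` is illustrative only.
# Example, using the __main__ block below: save_dot_pdf(graph_to_dotfile(trace_graph(fun, 1.0)))
def save_dot_pdf(dotfile_text, out_path="graph.pdf"):
    import subprocess
    import tempfile

    # Persist the DOT source, then render it to PDF with graphviz.
    with tempfile.NamedTemporaryFile("w", suffix=".dot", delete=False) as f:
        f.write(dotfile_text)
    subprocess.run(["dot", "-Tpdf", f.name, "-o", out_path], check=True)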
64 | 65 | if __name__ == "__main__": 66 | 67 | def fun(x): 68 | y = np.sin(x) 69 | return (y + np.exp(x) - 0.5) * y 70 | 71 | print(graph_to_dotfile(trace_graph(fun, 1.0))) 72 | -------------------------------------------------------------------------------- /examples/fixed_points.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | from autograd import grad 3 | from autograd.misc.fixed_points import fixed_point 4 | 5 | 6 | def newton_sqrt_iter(a): 7 | return lambda x: 0.5 * (x + a / x) 8 | 9 | 10 | def grad_descent_sqrt_iter(a): 11 | return lambda x: x - 0.05 * (x**2 - a) 12 | 13 | 14 | def sqrt(a, guess=10.0): 15 | # return fixed_point(newton_sqrt_iter, a, guess, distance, 1e-4) 16 | return fixed_point(grad_descent_sqrt_iter, a, guess, distance, 1e-4) 17 | 18 | 19 | def distance(x, y): 20 | return np.abs(x - y) 21 | 22 | 23 | print(np.sqrt(2.0)) 24 | print(sqrt(2.0)) 25 | print() 26 | print(grad(np.sqrt)(2.0)) 27 | print(grad(sqrt)(2.0)) 28 | print() 29 | print(grad(grad(np.sqrt))(2.0)) 30 | print(grad(grad(sqrt))(2.0)) 31 | print() 32 | -------------------------------------------------------------------------------- /examples/fluidsim/animated.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/animated.gif -------------------------------------------------------------------------------- /examples/fluidsim/fluidsim.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | from matplotlib.pyplot import imread 6 | from scipy.optimize import minimize 7 | 8 | import autograd.numpy as np 9 | from autograd import value_and_grad 10 | 11 | # Fluid simulation code based on 12 | # "Real-Time Fluid Dynamics for Games" by Jos Stam 13 | # https://www.josstam.com/_files/ugd/cf1fd6_9989229efbd34a26ba5ccd913721a2ac.pdf 14 | 15 | 16 | def project(vx, vy): 17 | """Project the velocity field to be approximately mass-conserving, 18 | using a few iterations of Gauss-Seidel.""" 19 | p = np.zeros(vx.shape) 20 | h = 1.0 / vx.shape[0] 21 | div = ( 22 | -0.5 23 | * h 24 | * ( 25 | np.roll(vx, -1, axis=0) 26 | - np.roll(vx, 1, axis=0) 27 | + np.roll(vy, -1, axis=1) 28 | - np.roll(vy, 1, axis=1) 29 | ) 30 | ) 31 | 32 | for k in range(10): 33 | p = ( 34 | div 35 | + np.roll(p, 1, axis=0) 36 | + np.roll(p, -1, axis=0) 37 | + np.roll(p, 1, axis=1) 38 | + np.roll(p, -1, axis=1) 39 | ) / 4.0 40 | 41 | vx -= 0.5 * (np.roll(p, -1, axis=0) - np.roll(p, 1, axis=0)) / h 42 | vy -= 0.5 * (np.roll(p, -1, axis=1) - np.roll(p, 1, axis=1)) / h 43 | return vx, vy 44 | 45 | 46 | def advect(f, vx, vy): 47 | """Move field f according to x and y velocities (u and v) 48 | using an implicit Euler integrator.""" 49 | rows, cols = f.shape 50 | cell_ys, cell_xs = np.meshgrid(np.arange(rows), np.arange(cols)) 51 | center_xs = (cell_xs - vx).ravel() 52 | center_ys = (cell_ys - vy).ravel() 53 | 54 | # Compute indices of source cells. 55 | left_ix = np.floor(center_xs).astype(int) 56 | top_ix = np.floor(center_ys).astype(int) 57 | rw = center_xs - left_ix # Relative weight of right-hand cells. 58 | bw = center_ys - top_ix # Relative weight of bottom cells. 59 | left_ix = np.mod(left_ix, rows) # Wrap around edges of simulation. 
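    # The neighbouring (right/bottom) indices wrap as well, giving periodic
    # boundaries consistent with the np.roll-based stencils used in project().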
60 | right_ix = np.mod(left_ix + 1, rows) 61 | top_ix = np.mod(top_ix, cols) 62 | bot_ix = np.mod(top_ix + 1, cols) 63 | 64 | # A linearly-weighted sum of the 4 surrounding cells. 65 | flat_f = (1 - rw) * ((1 - bw) * f[left_ix, top_ix] + bw * f[left_ix, bot_ix]) + rw * ( 66 | (1 - bw) * f[right_ix, top_ix] + bw * f[right_ix, bot_ix] 67 | ) 68 | return np.reshape(flat_f, (rows, cols)) 69 | 70 | 71 | def simulate(vx, vy, smoke, num_time_steps, ax=None, render=False): 72 | print("Running simulation...") 73 | for t in range(num_time_steps): 74 | if ax: 75 | plot_matrix(ax, smoke, t, render) 76 | vx_updated = advect(vx, vx, vy) 77 | vy_updated = advect(vy, vx, vy) 78 | vx, vy = project(vx_updated, vy_updated) 79 | smoke = advect(smoke, vx, vy) 80 | if ax: 81 | plot_matrix(ax, smoke, num_time_steps, render) 82 | return smoke 83 | 84 | 85 | def plot_matrix(ax, mat, t, render=False): 86 | plt.cla() 87 | ax.matshow(mat) 88 | ax.set_xticks([]) 89 | ax.set_yticks([]) 90 | plt.draw() 91 | if render: 92 | matplotlib.image.imsave(f"step{t:03d}.png", mat) 93 | plt.pause(0.001) 94 | 95 | 96 | if __name__ == "__main__": 97 | simulation_timesteps = 100 98 | basepath = os.path.dirname(__file__) 99 | 100 | print("Loading initial and target states...") 101 | init_smoke = imread(os.path.join(basepath, "init_smoke.png"))[:, :, 0] 102 | # target = imread('peace.png')[::2,::2,3] 103 | target = imread(os.path.join(basepath, "skull.png"))[::2, ::2] 104 | rows, cols = target.shape 105 | 106 | init_dx_and_dy = np.zeros((2, rows, cols)).ravel() 107 | 108 | def distance_from_target_image(smoke): 109 | return np.mean((target - smoke) ** 2) 110 | 111 | def convert_param_vector_to_matrices(params): 112 | vx = np.reshape(params[: (rows * cols)], (rows, cols)) 113 | vy = np.reshape(params[(rows * cols) :], (rows, cols)) 114 | return vx, vy 115 | 116 | def objective(params): 117 | init_vx, init_vy = convert_param_vector_to_matrices(params) 118 | final_smoke = simulate(init_vx, init_vy, init_smoke, simulation_timesteps) 119 | return distance_from_target_image(final_smoke) 120 | 121 | # Specify gradient of objective function using autograd. 122 | objective_with_grad = value_and_grad(objective) 123 | 124 | fig = plt.figure(figsize=(8, 8)) 125 | ax = fig.add_subplot(111, frameon=False) 126 | 127 | def callback(params): 128 | init_vx, init_vy = convert_param_vector_to_matrices(params) 129 | simulate(init_vx, init_vy, init_smoke, simulation_timesteps, ax) 130 | 131 | print("Optimizing initial conditions...") 132 | result = minimize( 133 | objective_with_grad, 134 | init_dx_and_dy, 135 | jac=True, 136 | method="CG", 137 | options={"maxiter": 25, "disp": True}, 138 | callback=callback, 139 | ) 140 | 141 | print("Rendering optimized flow...") 142 | init_vx, init_vy = convert_param_vector_to_matrices(result.x) 143 | simulate(init_vx, init_vy, init_smoke, simulation_timesteps, ax, render=True) 144 | 145 | print("Converting frames to an animated GIF...") 146 | os.system("convert -delay 5 -loop 0 step*.png -delay 250 step100.png surprise.gif") # Using imagemagick. 
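    # Both shell commands assume a Unix-like environment with ImageMagick installed:
    # `convert` is the ImageMagick 6 entry point (ImageMagick 7 ships it as `magick`),
    # and the cleanup below relies on `rm`.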
147 | os.system("rm step*.png") 148 | -------------------------------------------------------------------------------- /examples/fluidsim/init_smoke.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/init_smoke.png -------------------------------------------------------------------------------- /examples/fluidsim/peace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/peace.png -------------------------------------------------------------------------------- /examples/fluidsim/skull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/skull.png -------------------------------------------------------------------------------- /examples/fluidsim/surprise.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/surprise.gif -------------------------------------------------------------------------------- /examples/fluidsim/wing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/fluidsim/wing.png -------------------------------------------------------------------------------- /examples/gaussian_process.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/gaussian_process.png -------------------------------------------------------------------------------- /examples/gaussian_process.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from scipy.optimize import minimize 3 | 4 | import autograd.numpy as np 5 | import autograd.numpy.random as npr 6 | import autograd.scipy.stats.multivariate_normal as mvn 7 | from autograd import value_and_grad 8 | from autograd.numpy.linalg import solve 9 | 10 | 11 | def make_gp_funs(cov_func, num_cov_params): 12 | """Functions that perform Gaussian process regression. 
13 | cov_func has signature (cov_params, x, x')""" 14 | 15 | def unpack_kernel_params(params): 16 | mean = params[0] 17 | cov_params = params[2:] 18 | noise_scale = np.exp(params[1]) + 0.0001 19 | return mean, cov_params, noise_scale 20 | 21 | def predict(params, x, y, xstar): 22 | """Returns the predictive mean and covariance at locations xstar, 23 | of the latent function value f (without observation noise).""" 24 | mean, cov_params, noise_scale = unpack_kernel_params(params) 25 | cov_f_f = cov_func(cov_params, xstar, xstar) 26 | cov_y_f = cov_func(cov_params, x, xstar) 27 | cov_y_y = cov_func(cov_params, x, x) + noise_scale * np.eye(len(y)) 28 | pred_mean = mean + np.dot(solve(cov_y_y, cov_y_f).T, y - mean) 29 | pred_cov = cov_f_f - np.dot(solve(cov_y_y, cov_y_f).T, cov_y_f) 30 | return pred_mean, pred_cov 31 | 32 | def log_marginal_likelihood(params, x, y): 33 | mean, cov_params, noise_scale = unpack_kernel_params(params) 34 | cov_y_y = cov_func(cov_params, x, x) + noise_scale * np.eye(len(y)) 35 | prior_mean = mean * np.ones(len(y)) 36 | return mvn.logpdf(y, prior_mean, cov_y_y) 37 | 38 | return num_cov_params + 2, predict, log_marginal_likelihood 39 | 40 | 41 | # Define an example covariance function. 42 | def rbf_covariance(kernel_params, x, xp): 43 | output_scale = np.exp(kernel_params[0]) 44 | lengthscales = np.exp(kernel_params[1:]) 45 | diffs = np.expand_dims(x / lengthscales, 1) - np.expand_dims(xp / lengthscales, 0) 46 | return output_scale * np.exp(-0.5 * np.sum(diffs**2, axis=2)) 47 | 48 | 49 | def build_toy_dataset(D=1, n_data=20, noise_std=0.1): 50 | rs = npr.RandomState(0) 51 | inputs = np.concatenate([np.linspace(0, 3, num=n_data / 2), np.linspace(6, 8, num=n_data / 2)]) 52 | targets = (np.cos(inputs) + rs.randn(n_data) * noise_std) / 2.0 53 | inputs = (inputs - 4.0) / 2.0 54 | inputs = inputs.reshape((len(inputs), D)) 55 | return inputs, targets 56 | 57 | 58 | if __name__ == "__main__": 59 | D = 1 60 | 61 | # Build model and objective function. 62 | num_params, predict, log_marginal_likelihood = make_gp_funs(rbf_covariance, num_cov_params=D + 1) 63 | 64 | X, y = build_toy_dataset(D=D) 65 | objective = lambda params: -log_marginal_likelihood(params, X, y) 66 | 67 | # Set up figure. 68 | fig = plt.figure(figsize=(12, 8), facecolor="white") 69 | ax = fig.add_subplot(111, frameon=False) 70 | plt.show(block=False) 71 | 72 | def callback(params): 73 | print(f"Log likelihood {-objective(params)}") 74 | plt.cla() 75 | 76 | # Show posterior marginals. 77 | plot_xs = np.reshape(np.linspace(-7, 7, 300), (300, 1)) 78 | pred_mean, pred_cov = predict(params, X, y, plot_xs) 79 | marg_std = np.sqrt(np.diag(pred_cov)) 80 | ax.plot(plot_xs, pred_mean, "b") 81 | ax.fill( 82 | np.concatenate([plot_xs, plot_xs[::-1]]), 83 | np.concatenate([pred_mean - 1.96 * marg_std, (pred_mean + 1.96 * marg_std)[::-1]]), 84 | alpha=0.15, 85 | fc="Blue", 86 | ec="None", 87 | ) 88 | 89 | # Show samples from posterior. 
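# The predict function above implements the standard GP posterior equations:
# pred_mean = m + K_yf^T K_yy^{-1} (y - m) and pred_cov = K_ff - K_yf^T K_yy^{-1} K_yf.
# A minimal plain-NumPy sketch of those two lines on made-up 1-D data (every number
# below is illustrative, and a zero prior mean is assumed):
import numpy as np

def rbf(xa, xb, lengthscale=1.0):
    return np.exp(-0.5 * ((xa[:, None] - xb[None, :]) / lengthscale) ** 2)

x = np.array([-1.0, 0.0, 1.0])                 # training inputs
y = np.sin(x)                                  # training targets
xstar = np.array([0.5])                        # test input
K_yy = rbf(x, x) + 1e-4 * np.eye(len(x))       # covariance of the noisy observations
K_yf = rbf(x, xstar)                           # train/test cross-covariance
K_ff = rbf(xstar, xstar)                       # prior covariance at the test point
pred_mean = K_yf.T @ np.linalg.solve(K_yy, y)
pred_cov = K_ff - K_yf.T @ np.linalg.solve(K_yy, K_yf)
print(pred_mean, np.sqrt(np.diag(pred_cov)))   # posterior mean and marginal std at xstar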
90 | rs = npr.RandomState(0) 91 | sampled_funcs = rs.multivariate_normal(pred_mean, pred_cov, size=10) 92 | ax.plot(plot_xs, sampled_funcs.T) 93 | 94 | ax.plot(X, y, "kx") 95 | ax.set_ylim([-1.5, 1.5]) 96 | ax.set_xticks([]) 97 | ax.set_yticks([]) 98 | plt.draw() 99 | plt.pause(1.0 / 60.0) 100 | 101 | # Initialize covariance parameters 102 | rs = npr.RandomState(0) 103 | init_params = 0.1 * rs.randn(num_params) 104 | 105 | print("Optimizing covariance parameters...") 106 | cov_params = minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback) 107 | plt.pause(10.0) 108 | -------------------------------------------------------------------------------- /examples/gmm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/gmm.png -------------------------------------------------------------------------------- /examples/gmm.py: -------------------------------------------------------------------------------- 1 | """Implements a Gaussian mixture model, in which parameters are fit using 2 | gradient descent. This example runs on 2-dimensional data, but the model 3 | works on arbitrarily-high dimension.""" 4 | 5 | import matplotlib.pyplot as plt 6 | from data import make_pinwheel 7 | from scipy.optimize import minimize 8 | 9 | import autograd.numpy as np 10 | import autograd.numpy.random as npr 11 | import autograd.scipy.stats.multivariate_normal as mvn 12 | from autograd import grad, hessian_vector_product 13 | from autograd.misc.flatten import flatten_func 14 | from autograd.scipy.special import logsumexp 15 | 16 | 17 | def init_gmm_params(num_components, D, scale, rs=npr.RandomState(0)): 18 | return { 19 | "log proportions": rs.randn(num_components) * scale, 20 | "means": rs.randn(num_components, D) * scale, 21 | "lower triangles": np.zeros((num_components, D, D)) + np.eye(D), 22 | } 23 | 24 | 25 | def log_normalize(x): 26 | return x - logsumexp(x) 27 | 28 | 29 | def unpack_gmm_params(params): 30 | normalized_log_proportions = log_normalize(params["log proportions"]) 31 | return normalized_log_proportions, params["means"], params["lower triangles"] 32 | 33 | 34 | def gmm_log_likelihood(params, data): 35 | cluster_lls = [] 36 | for log_proportion, mean, cov_sqrt in zip(*unpack_gmm_params(params)): 37 | cov = np.dot(cov_sqrt.T, cov_sqrt) 38 | cluster_lls.append(log_proportion + mvn.logpdf(data, mean, cov)) 39 | return np.sum(logsumexp(np.vstack(cluster_lls), axis=0)) 40 | 41 | 42 | def plot_ellipse(ax, mean, cov_sqrt, alpha, num_points=100): 43 | angles = np.linspace(0, 2 * np.pi, num_points) 44 | circle_pts = np.vstack([np.cos(angles), np.sin(angles)]).T * 2.0 45 | cur_pts = mean + np.dot(circle_pts, cov_sqrt) 46 | ax.plot(cur_pts[:, 0], cur_pts[:, 1], "-", alpha=alpha) 47 | 48 | 49 | def plot_gaussian_mixture(params, ax): 50 | for log_proportion, mean, cov_sqrt in zip(*unpack_gmm_params(params)): 51 | alpha = np.minimum(1.0, np.exp(log_proportion) * 10) 52 | plot_ellipse(ax, mean, cov_sqrt, alpha) 53 | 54 | 55 | if __name__ == "__main__": 56 | init_params = init_gmm_params(num_components=10, D=2, scale=0.1) 57 | 58 | data = make_pinwheel(radial_std=0.3, tangential_std=0.05, num_classes=3, num_per_class=100, rate=0.4) 59 | 60 | def objective(params): 61 | return -gmm_log_likelihood(params, data) 62 | 63 | flattened_obj, unflatten, flattened_init_params = flatten_func(objective, init_params) 64 | 65 | fig = plt.figure(figsize=(12, 8), 
facecolor="white") 66 | ax = fig.add_subplot(111, frameon=False) 67 | plt.show(block=False) 68 | 69 | def callback(flattened_params): 70 | params = unflatten(flattened_params) 71 | print(f"Log likelihood {-objective(params)}") 72 | ax.cla() 73 | ax.plot(data[:, 0], data[:, 1], "k.") 74 | ax.set_xticks([]) 75 | ax.set_yticks([]) 76 | plot_gaussian_mixture(params, ax) 77 | plt.draw() 78 | plt.pause(1.0 / 60.0) 79 | 80 | minimize( 81 | flattened_obj, 82 | flattened_init_params, 83 | jac=grad(flattened_obj), 84 | hessp=hessian_vector_product(flattened_obj), 85 | method="Newton-CG", 86 | callback=callback, 87 | ) 88 | -------------------------------------------------------------------------------- /examples/gplvm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/gplvm.png -------------------------------------------------------------------------------- /examples/gplvm.py: -------------------------------------------------------------------------------- 1 | # Implements a Gaussian process latent-variable model. 2 | # The (high-dimensional) data, Y is explained by some low-dimensional latent 3 | # data X, warped by a function drawn from a GP prior (f). So Y = f(X), but 4 | # we don't know X or f. 5 | # 6 | # In this example, we optimize X and the hyperparameters of the GP, but 7 | # we integrate over all possible functions f. 8 | # 9 | # Normally the observed data would be high-dimensional. 10 | # 11 | # David Duvenaud (duvenaud@gmail.com) 12 | 13 | 14 | import matplotlib.pyplot as plt 15 | from data import make_pinwheel 16 | from gaussian_process import make_gp_funs, rbf_covariance 17 | from scipy.optimize import minimize 18 | 19 | import autograd.numpy as np 20 | import autograd.numpy.random as npr 21 | from autograd import value_and_grad 22 | from autograd.scipy.stats import norm 23 | 24 | if __name__ == "__main__": 25 | data_dimension = 2 # Normally the data dimension would be much higher. 26 | latent_dimension = 2 27 | 28 | # Build model and objective function. 29 | params_per_gp, predict, log_marginal_likelihood = make_gp_funs( 30 | rbf_covariance, num_cov_params=latent_dimension + 1 31 | ) 32 | total_gp_params = data_dimension * params_per_gp 33 | 34 | data = make_pinwheel(radial_std=0.3, tangential_std=0.05, num_classes=3, num_per_class=30, rate=0.4) 35 | datalen = data.shape[0] 36 | 37 | num_latent_params = datalen * latent_dimension 38 | 39 | def unpack_params(params): 40 | gp_params = np.reshape(params[:total_gp_params], (data_dimension, params_per_gp)) 41 | latents = np.reshape(params[total_gp_params:], (datalen, latent_dimension)) 42 | return gp_params, latents 43 | 44 | def objective(params): 45 | gp_params, latents = unpack_params(params) 46 | gp_likelihood = sum( 47 | [log_marginal_likelihood(gp_params[i], latents, data[:, i]) for i in range(data_dimension)] 48 | ) 49 | latent_prior_likelihood = np.sum(norm.logpdf(latents)) 50 | return -gp_likelihood - latent_prior_likelihood 51 | 52 | # Set up figure. 
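# The flat vector handed to the optimizer below packs every GP's hyperparameters first and
# all latent coordinates second; unpack_params above just reshapes the two segments.
# A tiny sketch of that packing convention with made-up sizes (all numbers are illustrative):
import numpy as np

data_dimension, params_per_gp = 2, 5           # one GP (5 hyperparameters) per observed dimension
datalen, latent_dimension = 3, 2               # 3 data points, each with a 2-D latent location
total_gp_params = data_dimension * params_per_gp
flat = np.arange(total_gp_params + datalen * latent_dimension, dtype=float)
gp_params = np.reshape(flat[:total_gp_params], (data_dimension, params_per_gp))
latents = np.reshape(flat[total_gp_params:], (datalen, latent_dimension))
print(gp_params.shape, latents.shape)          # (2, 5) (3, 2)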
53 | fig = plt.figure(figsize=(12, 8), facecolor="white") 54 | latent_ax = fig.add_subplot(121, frameon=False) 55 | data_ax = fig.add_subplot(122, frameon=False) 56 | plt.show(block=False) 57 | 58 | def callback(params): 59 | print(f"Log likelihood {-objective(params)}") 60 | gp_params, latents = unpack_params(params) 61 | 62 | data_ax.cla() 63 | data_ax.plot(data[:, 0], data[:, 1], "bx") 64 | data_ax.set_xticks([]) 65 | data_ax.set_yticks([]) 66 | data_ax.set_title("Observed Data") 67 | 68 | latent_ax.cla() 69 | latent_ax.plot(latents[:, 0], latents[:, 1], "kx") 70 | latent_ax.set_xticks([]) 71 | latent_ax.set_yticks([]) 72 | latent_ax.set_xlim([-2, 2]) 73 | latent_ax.set_ylim([-2, 2]) 74 | latent_ax.set_title("Latent coordinates") 75 | 76 | plt.draw() 77 | plt.pause(1.0 / 60.0) 78 | 79 | # Initialize covariance parameters 80 | rs = npr.RandomState(1) 81 | init_params = rs.randn(total_gp_params + num_latent_params) * 0.1 82 | 83 | print("Optimizing covariance parameters and latent variable locations...") 84 | minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback) 85 | -------------------------------------------------------------------------------- /examples/graph.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/graph.pdf -------------------------------------------------------------------------------- /examples/hmm_em.py: -------------------------------------------------------------------------------- 1 | import string 2 | from functools import partial 3 | from os.path import dirname, join 4 | 5 | import autograd.numpy as np 6 | import autograd.numpy.random as npr 7 | from autograd import value_and_grad as vgrad 8 | from autograd.scipy.special import logsumexp 9 | 10 | 11 | def EM(init_params, data, callback=None): 12 | def EM_update(params): 13 | natural_params = list(map(np.log, params)) 14 | loglike, E_stats = vgrad(log_partition_function)(natural_params, data) # E step 15 | if callback: 16 | callback(loglike, params) 17 | return list(map(normalize, E_stats)) # M step 18 | 19 | def fixed_point(f, x0): 20 | x1 = f(x0) 21 | while different(x0, x1): 22 | x0, x1 = x1, f(x1) 23 | return x1 24 | 25 | def different(params1, params2): 26 | allclose = partial(np.allclose, atol=1e-3, rtol=1e-3) 27 | return not all(map(allclose, params1, params2)) 28 | 29 | return fixed_point(EM_update, init_params) 30 | 31 | 32 | def normalize(a): 33 | def replace_zeros(a): 34 | return np.where(a > 0.0, a, 1.0) 35 | 36 | return a / replace_zeros(a.sum(-1, keepdims=True)) 37 | 38 | 39 | def log_partition_function(natural_params, data): 40 | if isinstance(data, list): 41 | return sum(map(partial(log_partition_function, natural_params), data)) 42 | 43 | log_pi, log_A, log_B = natural_params 44 | 45 | log_alpha = log_pi 46 | for y in data: 47 | log_alpha = logsumexp(log_alpha[:, None] + log_A, axis=0) + log_B[:, y] 48 | 49 | return logsumexp(log_alpha) 50 | 51 | 52 | def initialize_hmm_parameters(num_states, num_outputs): 53 | init_pi = normalize(npr.rand(num_states)) 54 | init_A = normalize(npr.rand(num_states, num_states)) 55 | init_B = normalize(npr.rand(num_states, num_outputs)) 56 | return init_pi, init_A, init_B 57 | 58 | 59 | def build_dataset(filename, max_lines=-1): 60 | """Loads a text file, and turns each line into an encoded sequence.""" 61 | encodings = dict(list(map(reversed, enumerate(string.printable)))) 62 | digitize = lambda char: 
encodings[char] if char in encodings else len(encodings) 63 | encode_line = lambda line: np.array(list(map(digitize, line))) 64 | nonblank_line = lambda line: len(line) > 2 65 | 66 | with open(filename) as f: 67 | lines = f.readlines() 68 | 69 | encoded_lines = list(map(encode_line, list(filter(nonblank_line, lines))[:max_lines])) 70 | num_outputs = len(encodings) + 1 71 | 72 | return encoded_lines, num_outputs 73 | 74 | 75 | if __name__ == "__main__": 76 | np.random.seed(0) 77 | np.seterr(divide="ignore") 78 | 79 | # callback to print log likelihoods during training 80 | print_loglike = lambda loglike, params: print(loglike) 81 | 82 | # load training data 83 | lstm_filename = join(dirname(__file__), "lstm.py") 84 | train_inputs, num_outputs = build_dataset(lstm_filename, max_lines=60) 85 | 86 | # train with EM 87 | num_states = 20 88 | init_params = initialize_hmm_parameters(num_states, num_outputs) 89 | pi, A, B = EM(init_params, train_inputs, print_loglike) 90 | -------------------------------------------------------------------------------- /examples/ica.py: -------------------------------------------------------------------------------- 1 | import matplotlib.cm as cm 2 | import matplotlib.pyplot as plt 3 | from scipy.optimize import minimize 4 | 5 | import autograd.numpy as np 6 | import autograd.numpy.random as npr 7 | import autograd.scipy.stats.t as t 8 | from autograd import value_and_grad 9 | 10 | 11 | def make_ica_funs(observed_dimension, latent_dimension): 12 | """These functions implement independent component analysis. 13 | 14 | The model is: 15 | latents are drawn i.i.d. for each data point from a product of student-ts. 16 | weights are the same across all datapoints. 17 | each data = latents * weghts + noise.""" 18 | 19 | def sample(weights, n_samples, noise_std, rs): 20 | latents = rs.randn(latent_dimension, n_samples) 21 | latents = np.array(sorted(latents.T, key=lambda a_entry: a_entry[0])).T 22 | noise = rs.randn(n_samples, observed_dimension) * noise_std 23 | observed = predict(weights, latents) + noise 24 | return latents, observed 25 | 26 | def predict(weights, latents): 27 | return np.dot(weights, latents).T 28 | 29 | def logprob(weights, latents, noise_std, observed): 30 | preds = predict(weights, latents) 31 | log_lik = np.sum(t.logpdf(preds, 2.4, observed, noise_std)) 32 | return log_lik 33 | 34 | num_weights = observed_dimension * latent_dimension 35 | 36 | def unpack_weights(weights): 37 | return np.reshape(weights, (observed_dimension, latent_dimension)) 38 | 39 | return num_weights, sample, logprob, unpack_weights 40 | 41 | 42 | def color_scatter(ax, xs, ys): 43 | colors = cm.rainbow(np.linspace(0, 1, len(ys))) 44 | for x, y, c in zip(xs, ys, colors): 45 | ax.scatter(x, y, color=c) 46 | 47 | 48 | if __name__ == "__main__": 49 | observed_dimension = 100 50 | latent_dimension = 2 51 | true_noise_var = 1.0 52 | n_samples = 200 53 | 54 | num_weights, sample, logprob, unpack_weights = make_ica_funs(observed_dimension, latent_dimension) 55 | 56 | num_latent_params = latent_dimension * n_samples 57 | total_num_params = num_weights + num_latent_params + 1 58 | 59 | def unpack_params(params): 60 | weights = unpack_weights(params[:num_weights]) 61 | latents = np.reshape( 62 | params[num_weights : num_weights + num_latent_params], (latent_dimension, n_samples) 63 | ) 64 | noise_std = np.exp(params[-1]) 65 | return weights, latents, noise_std 66 | 67 | rs = npr.RandomState(0) 68 | true_weights = np.zeros((observed_dimension, latent_dimension)) 69 | for i in 
range(latent_dimension): 70 | true_weights[:, i] = np.sin(np.linspace(0, 4 + i * 3.2, observed_dimension)) 71 | 72 | true_latents, data = sample(true_weights, n_samples, true_noise_var, rs) 73 | 74 | # Set up figure. 75 | fig2 = plt.figure(figsize=(6, 6), facecolor="white") 76 | ax_data = fig2.add_subplot(111, frameon=False) 77 | ax_data.matshow(data) 78 | 79 | fig1 = plt.figure(figsize=(12, 16), facecolor="white") 80 | ax_true_latents = fig1.add_subplot(411, frameon=False) 81 | ax_est_latents = fig1.add_subplot(412, frameon=False) 82 | ax_true_weights = fig1.add_subplot(413, frameon=False) 83 | ax_est_weights = fig1.add_subplot(414, frameon=False) 84 | 85 | plt.show(block=False) 86 | ax_true_weights.scatter(true_weights[:, 0], true_weights[:, 1]) 87 | ax_true_weights.set_title("True weights") 88 | color_scatter(ax_true_latents, true_latents[0, :], true_latents[1, :]) 89 | ax_true_latents.set_title("True latents") 90 | ax_true_latents.set_xticks([]) 91 | ax_true_weights.set_xticks([]) 92 | ax_true_latents.set_yticks([]) 93 | ax_true_weights.set_yticks([]) 94 | 95 | def objective(params): 96 | weight_matrix, latents, noise_std = unpack_params(params) 97 | return -logprob(weight_matrix, latents, noise_std, data) / n_samples 98 | 99 | def callback(params): 100 | weights, latents, noise_std = unpack_params(params) 101 | print(f"Log likelihood {-objective(params)}, noise_std {noise_std}") 102 | ax_est_weights.cla() 103 | ax_est_weights.scatter(weights[:, 0], weights[:, 1]) 104 | ax_est_weights.set_title("Estimated weights") 105 | ax_est_latents.cla() 106 | color_scatter(ax_est_latents, latents[0, :], latents[1, :]) 107 | ax_est_latents.set_title("Estimated latents") 108 | ax_est_weights.set_yticks([]) 109 | ax_est_latents.set_yticks([]) 110 | ax_est_weights.set_xticks([]) 111 | ax_est_latents.set_xticks([]) 112 | plt.draw() 113 | plt.pause(1.0 / 60.0) 114 | 115 | # Initialize and optimize model. 116 | rs = npr.RandomState(0) 117 | init_params = rs.randn(total_num_params) 118 | minimize(value_and_grad(objective), init_params, jac=True, method="CG", callback=callback) 119 | plt.pause(20) 120 | -------------------------------------------------------------------------------- /examples/logistic_regression.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | from autograd import grad 3 | from autograd.test_util import check_grads 4 | 5 | 6 | def sigmoid(x): 7 | return 0.5 * (np.tanh(x) + 1) 8 | 9 | 10 | def logistic_predictions(weights, inputs): 11 | # Outputs probability of a label being true according to logistic model. 12 | return sigmoid(np.dot(inputs, weights)) 13 | 14 | 15 | def training_loss(weights): 16 | # Training loss is the negative log-likelihood of the training labels. 17 | preds = logistic_predictions(weights, inputs) 18 | label_probabilities = preds * targets + (1 - preds) * (1 - targets) 19 | return -np.sum(np.log(label_probabilities)) 20 | 21 | 22 | # Build a toy dataset. 23 | inputs = np.array([[0.52, 1.12, 0.77], [0.88, -1.08, 0.15], [0.52, 0.06, -1.30], [0.74, -2.49, 1.39]]) 24 | targets = np.array([True, True, False, True]) 25 | 26 | # Build a function that returns gradients of training loss using autograd. 27 | training_gradient_fun = grad(training_loss) 28 | 29 | # Check the gradients numerically, just to be safe. 30 | weights = np.array([0.0, 0.0, 0.0]) 31 | check_grads(training_loss, modes=["rev"])(weights) 32 | 33 | # Optimize weights using gradient descent. 
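# check_grads above compares reverse-mode gradients against numerical differences; here is
# a hand-rolled sketch of the same idea on a standalone toy function (the function, the
# evaluation point, and the step size are all made up):
import autograd.numpy as np
from autograd import grad

def f(w):
    return np.sum(np.tanh(w) ** 2)

w = np.array([0.1, -0.3, 0.7])
eps = 1e-6
numeric = np.array([(f(w + eps * e) - f(w - eps * e)) / (2 * eps) for e in np.eye(len(w))])
print(np.allclose(grad(f)(w), numeric))  # True: central differences agree with autograd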
34 | print("Initial loss:", training_loss(weights)) 35 | for i in range(100): 36 | weights -= training_gradient_fun(weights) * 0.01 37 | 38 | print("Trained loss:", training_loss(weights)) 39 | -------------------------------------------------------------------------------- /examples/lstm.py: -------------------------------------------------------------------------------- 1 | """Implements the long-short term memory character model. 2 | This version vectorizes over multiple examples, but each string 3 | has a fixed length.""" 4 | 5 | from os.path import dirname, join 6 | 7 | from rnn import build_dataset, concat_and_multiply, one_hot_to_string, sigmoid, string_to_one_hot 8 | 9 | import autograd.numpy as np 10 | import autograd.numpy.random as npr 11 | from autograd import grad 12 | from autograd.misc.optimizers import adam 13 | from autograd.scipy.special import logsumexp 14 | 15 | 16 | def init_lstm_params(input_size, state_size, output_size, param_scale=0.01, rs=npr.RandomState(0)): 17 | def rp(*shape): 18 | return rs.randn(*shape) * param_scale 19 | 20 | return { 21 | "init cells": rp(1, state_size), 22 | "init hiddens": rp(1, state_size), 23 | "change": rp(input_size + state_size + 1, state_size), 24 | "forget": rp(input_size + state_size + 1, state_size), 25 | "ingate": rp(input_size + state_size + 1, state_size), 26 | "outgate": rp(input_size + state_size + 1, state_size), 27 | "predict": rp(state_size + 1, output_size), 28 | } 29 | 30 | 31 | def lstm_predict(params, inputs): 32 | def update_lstm(input, hiddens, cells): 33 | change = np.tanh(concat_and_multiply(params["change"], input, hiddens)) 34 | forget = sigmoid(concat_and_multiply(params["forget"], input, hiddens)) 35 | ingate = sigmoid(concat_and_multiply(params["ingate"], input, hiddens)) 36 | outgate = sigmoid(concat_and_multiply(params["outgate"], input, hiddens)) 37 | cells = cells * forget + ingate * change 38 | hiddens = outgate * np.tanh(cells) 39 | return hiddens, cells 40 | 41 | def hiddens_to_output_probs(hiddens): 42 | output = concat_and_multiply(params["predict"], hiddens) 43 | return output - logsumexp(output, axis=1, keepdims=True) # Normalize log-probs. 44 | 45 | num_sequences = inputs.shape[1] 46 | hiddens = np.repeat(params["init hiddens"], num_sequences, axis=0) 47 | cells = np.repeat(params["init cells"], num_sequences, axis=0) 48 | 49 | output = [hiddens_to_output_probs(hiddens)] 50 | for input in inputs: # Iterate over time steps. 51 | hiddens, cells = update_lstm(input, hiddens, cells) 52 | output.append(hiddens_to_output_probs(hiddens)) 53 | return output 54 | 55 | 56 | def lstm_log_likelihood(params, inputs, targets): 57 | logprobs = lstm_predict(params, inputs) 58 | loglik = 0.0 59 | num_time_steps, num_examples, _ = inputs.shape 60 | for t in range(num_time_steps): 61 | loglik += np.sum(logprobs[t] * targets[t]) 62 | return loglik / (num_time_steps * num_examples) 63 | 64 | 65 | if __name__ == "__main__": 66 | num_chars = 128 67 | 68 | # Learn to predict our own source code. 
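# hiddens_to_output_probs above turns raw scores into log-probabilities by subtracting
# logsumexp; a quick sketch showing that each row then exponentiates to a valid
# distribution (the logits below are made up):
import autograd.numpy as np
from autograd.scipy.special import logsumexp

logits = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
logprobs = logits - logsumexp(logits, axis=1, keepdims=True)  # log-softmax
print(np.exp(logprobs).sum(axis=1))                           # [1. 1.]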
69 | text_filename = join(dirname(__file__), "lstm.py") 70 | train_inputs = build_dataset(text_filename, sequence_length=30, alphabet_size=num_chars, max_lines=60) 71 | 72 | init_params = init_lstm_params(input_size=128, output_size=128, state_size=40, param_scale=0.01) 73 | 74 | def print_training_prediction(weights): 75 | print("Training text Predicted text") 76 | logprobs = np.asarray(lstm_predict(weights, train_inputs)) 77 | for t in range(logprobs.shape[1]): 78 | training_text = one_hot_to_string(train_inputs[:, t, :]) 79 | predicted_text = one_hot_to_string(logprobs[:, t, :]) 80 | print(training_text.replace("\n", " ") + "|" + predicted_text.replace("\n", " ")) 81 | 82 | def training_loss(params, iter): 83 | return -lstm_log_likelihood(params, train_inputs, train_inputs) 84 | 85 | def callback(weights, iter, gradient): 86 | if iter % 10 == 0: 87 | print("Iteration", iter, "Train loss:", training_loss(weights, 0)) 88 | print_training_prediction(weights) 89 | 90 | # Build gradient of loss function using autograd. 91 | training_loss_grad = grad(training_loss) 92 | 93 | print("Training LSTM...") 94 | trained_params = adam(training_loss_grad, init_params, step_size=0.1, num_iters=1000, callback=callback) 95 | 96 | print() 97 | print("Generating text from LSTM...") 98 | num_letters = 30 99 | for t in range(20): 100 | text = "" 101 | for i in range(num_letters): 102 | seqs = string_to_one_hot(text, num_chars)[:, np.newaxis, :] 103 | logprobs = lstm_predict(trained_params, seqs)[-1].ravel() 104 | text += chr(npr.choice(len(logprobs), p=np.exp(logprobs))) 105 | print(text) 106 | -------------------------------------------------------------------------------- /examples/natural_gradient_black_box_svi.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # same BBSVI function! 4 | from black_box_svi import black_box_variational_inference 5 | 6 | import autograd.numpy as np 7 | import autograd.scipy.stats.norm as norm 8 | from autograd.misc.optimizers import adam, sgd 9 | 10 | if __name__ == "__main__": 11 | # Specify an inference problem by its unnormalized log-density. 12 | # it's difficult to see the benefit in low dimensions 13 | # model parameters are a mean and a log_sigma 14 | np.random.seed(42) 15 | obs_dim = 20 16 | Y = np.random.randn(obs_dim, obs_dim).dot(np.random.randn(obs_dim)) 17 | 18 | def log_density(x, t): 19 | mu, log_sigma = x[:, :obs_dim], x[:, obs_dim:] 20 | sigma_density = np.sum(norm.logpdf(log_sigma, 0, 1.35), axis=1) 21 | mu_density = np.sum(norm.logpdf(Y, mu, np.exp(log_sigma)), axis=1) 22 | return sigma_density + mu_density 23 | 24 | # Build variational objective. 25 | D = obs_dim * 2 # dimension of our posterior 26 | objective, gradient, unpack_params = black_box_variational_inference(log_density, D, num_samples=2000) 27 | 28 | # Define the natural gradient 29 | # The natural gradient of the ELBO is the gradient of the elbo, 30 | # preconditioned by the inverse Fisher Information Matrix. The Fisher, 31 | # in the case of a diagonal gaussian, is a diagonal matrix that is a 32 | # simple function of the variance. 
Intuitively, statistical distance 33 | # created by perturbing the mean of an independent Gaussian is 34 | # determined by how wide the distribution is along that dimension --- 35 | # the wider the distribution, the less sensitive statistical distances is 36 | # to perturbations of the mean; the narrower the distribution, the more 37 | # the statistical distance changes when you perturb the mean (imagine 38 | # an extremely narrow Gaussian --- basically a spike. The KL between 39 | # this Gaussian and a Gaussian $\epsilon$ away in location can be big --- 40 | # moving the Gaussian could significantly reduce overlap in support 41 | # which corresponds to a greater statistical distance). 42 | # 43 | # When we want to move in directions of steepest ascent, we multiply by 44 | # the inverse fisher --- that way we make quicker progress when the 45 | # variance is wide, and we scale down our step size when the variance 46 | # is small (which leads to more robust/less chaotic ascent). 47 | def fisher_diag(lam): 48 | mu, log_sigma = unpack_params(lam) 49 | return np.concatenate([np.exp(-2.0 * log_sigma), np.ones(len(log_sigma)) * 2]) 50 | 51 | # simple! basically free! 52 | natural_gradient = lambda lam, i: (1.0 / fisher_diag(lam)) * gradient(lam, i) 53 | 54 | # function for keeping track of callback ELBO values (for plotting below) 55 | def optimize_and_lls(optfun): 56 | num_iters = 200 57 | elbos = [] 58 | 59 | def callback(params, t, g): 60 | elbo_val = -objective(params, t) 61 | elbos.append(elbo_val) 62 | if t % 50 == 0: 63 | print(f"Iteration {t} lower bound {elbo_val}") 64 | 65 | init_mean = -1 * np.ones(D) 66 | init_log_std = -5 * np.ones(D) 67 | init_var_params = np.concatenate([init_mean, init_log_std]) 68 | variational_params = optfun(num_iters, init_var_params, callback) 69 | return np.array(elbos) 70 | 71 | # let's optimize this with a few different step sizes 72 | elbo_lists = [] 73 | step_sizes = [0.1, 0.25, 0.5] 74 | for step_size in step_sizes: 75 | # optimize with standard gradient + adam 76 | optfun = lambda n, init, cb: adam(gradient, init, step_size=step_size, num_iters=n, callback=cb) 77 | standard_lls = optimize_and_lls(optfun) 78 | 79 | # optimize with natural gradient + sgd, no momentum 80 | optnat = lambda n, init, cb: sgd( 81 | natural_gradient, init, step_size=step_size, num_iters=n, callback=cb, mass=0.001 82 | ) 83 | natural_lls = optimize_and_lls(optnat) 84 | elbo_lists.append((standard_lls, natural_lls)) 85 | 86 | # visually compare the ELBO 87 | plt.figure(figsize=(12, 8)) 88 | colors = ["b", "k", "g"] 89 | for col, ss, (stand_lls, nat_lls) in zip(colors, step_sizes, elbo_lists): 90 | plt.plot( 91 | np.arange(len(stand_lls)), 92 | stand_lls, 93 | "--", 94 | label="standard (adam, step-size = %2.2f)" % ss, 95 | alpha=0.5, 96 | c=col, 97 | ) 98 | plt.plot(np.arange(len(nat_lls)), nat_lls, "-", label="natural (sgd, step-size = %2.2f)" % ss, c=col) 99 | 100 | llrange = natural_lls.max() - natural_lls.min() 101 | plt.ylim((natural_lls.max() - llrange * 0.1, natural_lls.max() + 10)) 102 | plt.xlabel("optimization iteration") 103 | plt.ylabel("ELBO") 104 | plt.legend(loc="lower right") 105 | plt.title("%d dimensional posterior" % D) 106 | plt.show() 107 | -------------------------------------------------------------------------------- /examples/negative_binomial_maxlike.py: -------------------------------------------------------------------------------- 1 | import scipy.optimize 2 | 3 | import autograd.numpy as np 4 | import autograd.numpy.random as npr 5 | from autograd 
import grad 6 | from autograd.scipy.special import gammaln 7 | 8 | # The code in this example implements a method for finding a stationary point of 9 | # the negative binomial likelihood via Newton's method, described here: 10 | # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Maximum_likelihood_estimation 11 | 12 | 13 | def newton(f, x0): 14 | # wrap scipy.optimize.newton with our automatic derivatives 15 | return scipy.optimize.newton(f, x0, fprime=grad(f), fprime2=grad(grad(f))) 16 | 17 | 18 | def negbin_loglike(r, p, x): 19 | # the negative binomial log likelihood we want to maximize 20 | return gammaln(r + x) - gammaln(r) - gammaln(x + 1) + x * np.log(p) + r * np.log(1 - p) 21 | 22 | 23 | def negbin_sample(r, p, size): 24 | # a negative binomial is a gamma-compound-Poisson 25 | return npr.poisson(npr.gamma(r, p / (1 - p), size=size)) 26 | 27 | 28 | def fit_maxlike(x, r_guess): 29 | # follows Wikipedia's section on negative binomial max likelihood 30 | assert np.var(x) > np.mean(x), "Likelihood-maximizing parameters don't exist!" 31 | loglike = lambda r, p: np.sum(negbin_loglike(r, p, x)) 32 | p = lambda r: np.sum(x) / np.sum(r + x) 33 | rprime = lambda r: grad(loglike)(r, p(r)) 34 | r = newton(rprime, r_guess) 35 | return r, p(r) 36 | 37 | 38 | if __name__ == "__main__": 39 | # generate data 40 | npr.seed(0) 41 | data = negbin_sample(r=5, p=0.5, size=1000) 42 | 43 | # fit likelihood-extremizing parameters 44 | r, p = fit_maxlike(data, r_guess=1) 45 | 46 | # report fit 47 | print("Fit parameters:") 48 | print(f"r={r}, p={p}") 49 | 50 | print("Check that we are at a local stationary point:") 51 | loglike = lambda r, p: np.sum(negbin_loglike(r, p, data)) 52 | grad_both = grad(loglike, argnum=(0, 1)) 53 | print(grad_both(r, p)) 54 | 55 | import matplotlib.pyplot as plt 56 | 57 | xm = data.max() 58 | plt.figure() 59 | plt.hist(data, bins=np.arange(xm + 1) - 0.5, normed=True, label="normed data counts") 60 | plt.xlim(0, xm) 61 | plt.plot(np.arange(xm), np.exp(negbin_loglike(r, p, np.arange(xm))), label="maxlike fit") 62 | plt.xlabel("k") 63 | plt.ylabel("p(k)") 64 | plt.legend(loc="best") 65 | plt.show() 66 | -------------------------------------------------------------------------------- /examples/neural_net.py: -------------------------------------------------------------------------------- 1 | """A multi-layer perceptron for classification of MNIST handwritten digits.""" 2 | 3 | from data import load_mnist 4 | 5 | import autograd.numpy as np 6 | import autograd.numpy.random as npr 7 | from autograd import grad 8 | from autograd.misc.flatten import flatten 9 | from autograd.misc.optimizers import adam 10 | from autograd.scipy.special import logsumexp 11 | 12 | 13 | def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)): 14 | """Build a list of (weights, biases) tuples, 15 | one for each layer in the net.""" 16 | return [ 17 | ( 18 | scale * rs.randn(m, n), # weight matrix 19 | scale * rs.randn(n), 20 | ) # bias vector 21 | for m, n in zip(layer_sizes[:-1], layer_sizes[1:]) 22 | ] 23 | 24 | 25 | def neural_net_predict(params, inputs): 26 | """Implements a deep neural network for classification. 27 | params is a list of (weights, bias) tuples. 28 | inputs is an (N x D) matrix. 
29 | returns normalized class log-probabilities.""" 30 | for W, b in params: 31 | outputs = np.dot(inputs, W) + b 32 | inputs = np.tanh(outputs) 33 | return outputs - logsumexp(outputs, axis=1, keepdims=True) 34 | 35 | 36 | def l2_norm(params): 37 | """Computes l2 norm of params by flattening them into a vector.""" 38 | flattened, _ = flatten(params) 39 | return np.dot(flattened, flattened) 40 | 41 | 42 | def log_posterior(params, inputs, targets, L2_reg): 43 | log_prior = -L2_reg * l2_norm(params) 44 | log_lik = np.sum(neural_net_predict(params, inputs) * targets) 45 | return log_prior + log_lik 46 | 47 | 48 | def accuracy(params, inputs, targets): 49 | target_class = np.argmax(targets, axis=1) 50 | predicted_class = np.argmax(neural_net_predict(params, inputs), axis=1) 51 | return np.mean(predicted_class == target_class) 52 | 53 | 54 | if __name__ == "__main__": 55 | # Model parameters 56 | layer_sizes = [784, 200, 100, 10] 57 | L2_reg = 1.0 58 | 59 | # Training parameters 60 | param_scale = 0.1 61 | batch_size = 256 62 | num_epochs = 5 63 | step_size = 0.001 64 | 65 | print("Loading training data...") 66 | N, train_images, train_labels, test_images, test_labels = load_mnist() 67 | 68 | init_params = init_random_params(param_scale, layer_sizes) 69 | 70 | num_batches = int(np.ceil(len(train_images) / batch_size)) 71 | 72 | def batch_indices(iter): 73 | idx = iter % num_batches 74 | return slice(idx * batch_size, (idx + 1) * batch_size) 75 | 76 | # Define training objective 77 | def objective(params, iter): 78 | idx = batch_indices(iter) 79 | return -log_posterior(params, train_images[idx], train_labels[idx], L2_reg) 80 | 81 | # Get gradient of objective using autograd. 82 | objective_grad = grad(objective) 83 | 84 | print(" Epoch | Train accuracy | Test accuracy ") 85 | 86 | def print_perf(params, iter, gradient): 87 | if iter % num_batches == 0: 88 | train_acc = accuracy(params, train_images, train_labels) 89 | test_acc = accuracy(params, test_images, test_labels) 90 | print(f"{iter // num_batches:15}|{train_acc:20}|{test_acc:20}") 91 | 92 | # The optimizers provided can optimize lists, tuples, or dicts of parameters. 
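# A tiny sketch of that claim: adam happily optimizes a dict of parameter arrays.
# The toy quadratic objective below is made up purely for illustration.
import autograd.numpy as np
from autograd import grad
from autograd.misc.optimizers import adam

toy_params = {"w": np.zeros(3), "b": np.zeros(1)}

def toy_loss(params, iter):                    # the optimizers call the objective as (params, iteration)
    return np.sum((params["w"] - 1.0) ** 2) + np.sum((params["b"] + 2.0) ** 2)

fitted = adam(grad(toy_loss), toy_params, step_size=0.1, num_iters=200)
print(fitted["w"], fitted["b"])                # close to [1. 1. 1.] and [-2.]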
93 | optimized_params = adam( 94 | objective_grad, 95 | init_params, 96 | step_size=step_size, 97 | num_iters=num_epochs * num_batches, 98 | callback=print_perf, 99 | ) 100 | -------------------------------------------------------------------------------- /examples/neural_net_regression.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import autograd.numpy as np 4 | import autograd.numpy.random as npr 5 | import autograd.scipy.stats.norm as norm 6 | from autograd import grad 7 | from autograd.misc import flatten 8 | from autograd.misc.optimizers import adam 9 | 10 | 11 | def init_random_params(scale, layer_sizes, rs=npr.RandomState(0)): 12 | """Build a list of (weights, biases) tuples, one for each layer.""" 13 | return [ 14 | ( 15 | rs.randn(insize, outsize) * scale, # weight matrix 16 | rs.randn(outsize) * scale, 17 | ) # bias vector 18 | for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:]) 19 | ] 20 | 21 | 22 | def nn_predict(params, inputs, nonlinearity=np.tanh): 23 | for W, b in params: 24 | outputs = np.dot(inputs, W) + b 25 | inputs = nonlinearity(outputs) 26 | return outputs 27 | 28 | 29 | def log_gaussian(params, scale): 30 | flat_params, _ = flatten(params) 31 | return np.sum(norm.logpdf(flat_params, 0, scale)) 32 | 33 | 34 | def logprob(weights, inputs, targets, noise_scale=0.1): 35 | predictions = nn_predict(weights, inputs) 36 | return np.sum(norm.logpdf(predictions, targets, noise_scale)) 37 | 38 | 39 | def build_toy_dataset(n_data=80, noise_std=0.1): 40 | rs = npr.RandomState(0) 41 | inputs = np.concatenate([np.linspace(0, 3, num=n_data / 2), np.linspace(6, 8, num=n_data / 2)]) 42 | targets = np.cos(inputs) + rs.randn(n_data) * noise_std 43 | inputs = (inputs - 4.0) / 2.0 44 | inputs = inputs[:, np.newaxis] 45 | targets = targets[:, np.newaxis] / 2.0 46 | return inputs, targets 47 | 48 | 49 | if __name__ == "__main__": 50 | init_scale = 0.1 51 | weight_prior_variance = 10.0 52 | init_params = init_random_params(init_scale, layer_sizes=[1, 4, 4, 1]) 53 | 54 | inputs, targets = build_toy_dataset() 55 | 56 | def objective(weights, t): 57 | return -logprob(weights, inputs, targets) - log_gaussian(weights, weight_prior_variance) 58 | 59 | print(grad(objective)(init_params, 0)) 60 | 61 | # Set up figure. 62 | fig = plt.figure(figsize=(12, 8), facecolor="white") 63 | ax = fig.add_subplot(111, frameon=False) 64 | plt.show(block=False) 65 | 66 | def callback(params, t, g): 67 | print(f"Iteration {t} log likelihood {-objective(params, t)}") 68 | 69 | # Plot data and functions. 70 | plt.cla() 71 | ax.plot(inputs.ravel(), targets.ravel(), "bx", ms=12) 72 | plot_inputs = np.reshape(np.linspace(-7, 7, num=300), (300, 1)) 73 | outputs = nn_predict(params, plot_inputs) 74 | ax.plot(plot_inputs, outputs, "r", lw=3) 75 | ax.set_ylim([-1, 1]) 76 | plt.draw() 77 | plt.pause(1.0 / 60.0) 78 | 79 | print("Optimizing network parameters...") 80 | optimized_params = adam(grad(objective), init_params, step_size=0.01, num_iters=1000, callback=callback) 81 | -------------------------------------------------------------------------------- /examples/ode_net.py: -------------------------------------------------------------------------------- 1 | # A demo of gradients through scipy.integrate.odeint, 2 | # estimating the dynamics of a system given a trajectory. 
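# A minimal sketch of what "gradients through odeint" means, on a system with a known
# answer: for dy/dt = -k*y we have y(T) = y0*exp(-k*T), so d y(T)/dk = -T*y0*exp(-k*T).
# The rate, horizon, and initial value below are made up for illustration.
import autograd.numpy as np
from autograd import grad
from autograd.builtins import tuple
from autograd.scipy.integrate import odeint

def decay(y, t, k):
    return -k * y

def y_at_T(k, T=2.0):
    return odeint(decay, np.array([1.0]), np.array([0.0, T]), tuple((k,)))[-1, 0]

print(grad(y_at_T)(0.5))          # gradient computed through the ODE solver
print(-2.0 * np.exp(-0.5 * 2.0))  # analytic check: both are about -0.7358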
3 | 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as npo 7 | 8 | import autograd.numpy as np 9 | import autograd.numpy.random as npr 10 | from autograd import grad 11 | from autograd.builtins import tuple 12 | from autograd.misc.optimizers import adam 13 | from autograd.scipy.integrate import odeint 14 | 15 | N = 30 # Dataset size 16 | D = 2 # Data dimension 17 | max_T = 1.5 18 | 19 | 20 | # Two-dimensional damped oscillator 21 | def func(y, t0, A): 22 | return np.dot(y**3, A) 23 | 24 | 25 | def nn_predict(inputs, t, params): 26 | for W, b in params: 27 | outputs = np.dot(inputs, W) + b 28 | inputs = np.maximum(0, outputs) 29 | return outputs 30 | 31 | 32 | def init_nn_params(scale, layer_sizes, rs=npr.RandomState(0)): 33 | """Build a list of (weights, biases) tuples, one for each layer.""" 34 | return [ 35 | ( 36 | rs.randn(insize, outsize) * scale, # weight matrix 37 | rs.randn(outsize) * scale, 38 | ) # bias vector 39 | for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:]) 40 | ] 41 | 42 | 43 | # Define neural ODE model. 44 | def ode_pred(params, y0, t): 45 | return odeint(nn_predict, y0, t, tuple((params,)), rtol=0.01) 46 | 47 | 48 | def L1_loss(pred, targets): 49 | return np.mean(np.abs(pred - targets)) 50 | 51 | 52 | if __name__ == "__main__": 53 | # Generate data from true dynamics. 54 | true_y0 = np.array([2.0, 0.0]).T 55 | t = np.linspace(0.0, max_T, N) 56 | true_A = np.array([[-0.1, 2.0], [-2.0, -0.1]]) 57 | true_y = odeint(func, true_y0, t, args=(true_A,)) 58 | 59 | def train_loss(params, iter): 60 | pred = ode_pred(params, true_y0, t) 61 | return L1_loss(pred, true_y) 62 | 63 | # Set up figure 64 | fig = plt.figure(figsize=(12, 4), facecolor="white") 65 | ax_traj = fig.add_subplot(131, frameon=False) 66 | ax_phase = fig.add_subplot(132, frameon=False) 67 | ax_vecfield = fig.add_subplot(133, frameon=False) 68 | plt.show(block=False) 69 | 70 | # Plots data and learned dynamics. 
71 | def callback(params, iter, g): 72 | pred = ode_pred(params, true_y0, t) 73 | 74 | print(f"Iteration {iter:d} train loss {L1_loss(pred, true_y):.6f}") 75 | 76 | ax_traj.cla() 77 | ax_traj.set_title("Trajectories") 78 | ax_traj.set_xlabel("t") 79 | ax_traj.set_ylabel("x,y") 80 | ax_traj.plot(t, true_y[:, 0], "-", t, true_y[:, 1], "g-") 81 | ax_traj.plot(t, pred[:, 0], "--", t, pred[:, 1], "b--") 82 | ax_traj.set_xlim(t.min(), t.max()) 83 | ax_traj.set_ylim(-2, 2) 84 | ax_traj.xaxis.set_ticklabels([]) 85 | ax_traj.yaxis.set_ticklabels([]) 86 | ax_traj.legend() 87 | 88 | ax_phase.cla() 89 | ax_phase.set_title("Phase Portrait") 90 | ax_phase.set_xlabel("x") 91 | ax_phase.set_ylabel("y") 92 | ax_phase.plot(true_y[:, 0], true_y[:, 1], "g-") 93 | ax_phase.plot(pred[:, 0], pred[:, 1], "b--") 94 | ax_phase.set_xlim(-2, 2) 95 | ax_phase.set_ylim(-2, 2) 96 | ax_phase.xaxis.set_ticklabels([]) 97 | ax_phase.yaxis.set_ticklabels([]) 98 | 99 | ax_vecfield.cla() 100 | ax_vecfield.set_title("Learned Vector Field") 101 | ax_vecfield.set_xlabel("x") 102 | ax_vecfield.set_ylabel("y") 103 | ax_vecfield.xaxis.set_ticklabels([]) 104 | ax_vecfield.yaxis.set_ticklabels([]) 105 | 106 | # vector field plot 107 | y, x = npo.mgrid[-2:2:21j, -2:2:21j] 108 | dydt = nn_predict(np.stack([x, y], -1).reshape(21 * 21, 2), 0, params).reshape(-1, 2) 109 | mag = np.sqrt(dydt[:, 0] ** 2 + dydt[:, 1] ** 2).reshape(-1, 1) 110 | dydt = dydt / mag 111 | dydt = dydt.reshape(21, 21, 2) 112 | 113 | ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black") 114 | ax_vecfield.set_xlim(-2, 2) 115 | ax_vecfield.set_ylim(-2, 2) 116 | 117 | fig.tight_layout() 118 | plt.draw() 119 | plt.pause(0.001) 120 | 121 | # Train neural net dynamics to match data. 122 | init_params = init_nn_params(0.1, layer_sizes=[D, 150, D]) 123 | optimized_params = adam(grad(train_loss), init_params, num_iters=1000, callback=callback) 124 | -------------------------------------------------------------------------------- /examples/ode_net_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/ode_net_demo.png -------------------------------------------------------------------------------- /examples/print_trace.py: -------------------------------------------------------------------------------- 1 | """Demonstrates how to use the tracer module, independent of autodiff, by 2 | creating a trace that prints out functions and their arguments as they're being 3 | evaluated""" 4 | 5 | import autograd.numpy as np # autograd has already wrapped numpy for us 6 | from autograd.tracer import Node, trace 7 | 8 | 9 | class PrintNode(Node): 10 | def __init__(self, value, fun, args, kwargs, parent_argnums, parents): 11 | self.varname_generator = parents[0].varname_generator 12 | self.varname = next(self.varname_generator) 13 | args_or_vars = list(args) 14 | for argnum, parent in zip(parent_argnums, parents): 15 | args_or_vars[argnum] = parent.varname 16 | print("{} = {}({}) = {}".format(self.varname, fun.__name__, ",".join(map(str, args_or_vars)), value)) 17 | 18 | def initialize_root(self, x): 19 | self.varname_generator = make_varname_generator() 20 | self.varname = next(self.varname_generator) 21 | print(f"{self.varname} = {x}") 22 | 23 | 24 | def make_varname_generator(): 25 | for i in range(65, 91): 26 | yield chr(i) 27 | raise Exception("Ran out of alphabet!") 28 | 29 | 30 | def print_trace(f, x): 31 | start_node = 
PrintNode.new_root(x) 32 | trace(start_node, f, x) 33 | print() 34 | 35 | 36 | def avg(x, y): 37 | return (x + y) / 2 38 | 39 | 40 | def fun(x): 41 | y = np.sin(x + x) 42 | return avg(y, y) 43 | 44 | 45 | print_trace(fun, 1.23) 46 | 47 | # Traces can be nested, so we can also trace through grad(fun) 48 | from autograd import grad 49 | 50 | print_trace(grad(fun), 1.0) 51 | -------------------------------------------------------------------------------- /examples/rkhs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inferring a function from a reproducing kernel Hilbert space (RKHS) by taking 3 | gradients of eval with respect to the function-valued argument 4 | """ 5 | 6 | import autograd.numpy as np 7 | import autograd.numpy.random as npr 8 | from autograd import grad 9 | from autograd.extend import Box, VSpace, defvjp, primitive 10 | from autograd.util import func 11 | 12 | 13 | class RKHSFun: 14 | def __init__(self, kernel, alphas={}): 15 | self.alphas = alphas 16 | self.kernel = kernel 17 | self.vs = RKHSFunVSpace(self) 18 | 19 | @primitive 20 | def __call__(self, x): 21 | return sum([a * self.kernel(x, x_repr) for x_repr, a in self.alphas.items()], 0.0) 22 | 23 | def __add__(self, f): 24 | return self.vs.add(self, f) 25 | 26 | def __mul__(self, a): 27 | return self.vs.scalar_mul(self, a) 28 | 29 | 30 | # TODO: add vjp of __call__ wrt x (and show it in action) 31 | defvjp(func(RKHSFun.__call__), lambda ans, f, x: lambda g: RKHSFun(f.kernel, {x: 1}) * g) 32 | 33 | 34 | class RKHSFunBox(Box, RKHSFun): 35 | @property 36 | def kernel(self): 37 | return self._value.kernel 38 | 39 | 40 | RKHSFunBox.register(RKHSFun) 41 | 42 | 43 | class RKHSFunVSpace(VSpace): 44 | def __init__(self, value): 45 | self.kernel = value.kernel 46 | 47 | def zeros(self): 48 | return RKHSFun(self.kernel) 49 | 50 | def randn(self): 51 | # These arbitrary vectors are not analogous to randn in any meaningful way 52 | N = npr.randint(1, 3) 53 | return RKHSFun(self.kernel, dict(zip(npr.randn(N), npr.randn(N)))) 54 | 55 | def _add(self, f, g): 56 | assert f.kernel is g.kernel 57 | return RKHSFun(f.kernel, add_dicts(f.alphas, g.alphas)) 58 | 59 | def _scalar_mul(self, f, a): 60 | return RKHSFun(f.kernel, {x: a * a_cur for x, a_cur in f.alphas.items()}) 61 | 62 | def _inner_prod(self, f, g): 63 | assert f.kernel is g.kernel 64 | return sum( 65 | [a1 * a2 * f.kernel(x1, x2) for x1, a1 in f.alphas.items() for x2, a2 in g.alphas.items()], 0.0 66 | ) 67 | 68 | 69 | RKHSFunVSpace.register(RKHSFun) 70 | 71 | 72 | def add_dicts(d1, d2): 73 | d = {} 74 | for k, v in d1.items() + d2.items(): 75 | d[k] = d[k] + v if k in d else v 76 | return d 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | def sq_exp_kernel(x1, x2): 82 | return np.exp(-((x1 - x2) ** 2)) 83 | 84 | xs = range(5) 85 | ys = [1, 2, 3, 2, 1] 86 | 87 | def logprob(f, xs, ys): 88 | return -sum((f(x) - y) ** 2 for x, y in zip(xs, ys)) 89 | 90 | f = RKHSFun(sq_exp_kernel) 91 | for i in range(100): 92 | f = f + grad(logprob)(f, xs, ys) * 0.01 93 | 94 | for x, y in zip(xs, ys): 95 | print(f"{x}\t{y}\t{f(x)}") 96 | -------------------------------------------------------------------------------- /examples/rnn.py: -------------------------------------------------------------------------------- 1 | """Implements the long-short term memory character model. 
2 | This version vectorizes over multiple examples, but each string 3 | has a fixed length.""" 4 | 5 | from os.path import dirname, join 6 | 7 | import autograd.numpy as np 8 | import autograd.numpy.random as npr 9 | from autograd import grad 10 | from autograd.misc.optimizers import adam 11 | from autograd.scipy.special import logsumexp 12 | 13 | ### Helper functions ################# 14 | 15 | 16 | def sigmoid(x): 17 | return 0.5 * (np.tanh(x) + 1.0) # Output ranges from 0 to 1. 18 | 19 | 20 | def concat_and_multiply(weights, *args): 21 | cat_state = np.hstack(args + (np.ones((args[0].shape[0], 1)),)) 22 | return np.dot(cat_state, weights) 23 | 24 | 25 | ### Define recurrent neural net ####### 26 | 27 | 28 | def create_rnn_params(input_size, state_size, output_size, param_scale=0.01, rs=npr.RandomState(0)): 29 | return { 30 | "init hiddens": rs.randn(1, state_size) * param_scale, 31 | "change": rs.randn(input_size + state_size + 1, state_size) * param_scale, 32 | "predict": rs.randn(state_size + 1, output_size) * param_scale, 33 | } 34 | 35 | 36 | def rnn_predict(params, inputs): 37 | def update_rnn(input, hiddens): 38 | return np.tanh(concat_and_multiply(params["change"], input, hiddens)) 39 | 40 | def hiddens_to_output_probs(hiddens): 41 | output = concat_and_multiply(params["predict"], hiddens) 42 | return output - logsumexp(output, axis=1, keepdims=True) # Normalize log-probs. 43 | 44 | num_sequences = inputs.shape[1] 45 | hiddens = np.repeat(params["init hiddens"], num_sequences, axis=0) 46 | output = [hiddens_to_output_probs(hiddens)] 47 | 48 | for input in inputs: # Iterate over time steps. 49 | hiddens = update_rnn(input, hiddens) 50 | output.append(hiddens_to_output_probs(hiddens)) 51 | return output 52 | 53 | 54 | def rnn_log_likelihood(params, inputs, targets): 55 | logprobs = rnn_predict(params, inputs) 56 | loglik = 0.0 57 | num_time_steps, num_examples, _ = inputs.shape 58 | for t in range(num_time_steps): 59 | loglik += np.sum(logprobs[t] * targets[t]) 60 | return loglik / (num_time_steps * num_examples) 61 | 62 | 63 | ### Dataset setup ################## 64 | 65 | 66 | def string_to_one_hot(string, maxchar): 67 | """Converts an ASCII string to a one-of-k encoding.""" 68 | ascii = np.array([ord(c) for c in string]).T 69 | return np.array(ascii[:, None] == np.arange(maxchar)[None, :], dtype=int) 70 | 71 | 72 | def one_hot_to_string(one_hot_matrix): 73 | return "".join([chr(np.argmax(c)) for c in one_hot_matrix]) 74 | 75 | 76 | def build_dataset(filename, sequence_length, alphabet_size, max_lines=-1): 77 | """Loads a text file, and turns each line into an encoded sequence.""" 78 | with open(filename) as f: 79 | content = f.readlines() 80 | content = content[:max_lines] 81 | content = [line for line in content if len(line) > 2] # Remove blank lines 82 | seqs = np.zeros((sequence_length, len(content), alphabet_size)) 83 | for ix, line in enumerate(content): 84 | padded_line = (line + " " * sequence_length)[:sequence_length] 85 | seqs[:, ix, :] = string_to_one_hot(padded_line, alphabet_size) 86 | return seqs 87 | 88 | 89 | if __name__ == "__main__": 90 | num_chars = 128 91 | 92 | # Learn to predict our own source code. 
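# A quick sketch of the round trip through the helpers above (the string is arbitrary):
# string_to_one_hot maps each character to a length-num_chars indicator row, and
# one_hot_to_string inverts it by taking the argmax of each row.
demo = string_to_one_hot("autograd", num_chars)
print(demo.shape)                  # (8, 128): one row per character
print(one_hot_to_string(demo))     # "autograd"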
93 | text_filename = join(dirname(__file__), "rnn.py") 94 | train_inputs = build_dataset(text_filename, sequence_length=30, alphabet_size=num_chars, max_lines=60) 95 | 96 | init_params = create_rnn_params(input_size=128, output_size=128, state_size=40, param_scale=0.01) 97 | 98 | def print_training_prediction(weights): 99 | print("Training text Predicted text") 100 | logprobs = np.asarray(rnn_predict(weights, train_inputs)) 101 | for t in range(logprobs.shape[1]): 102 | training_text = one_hot_to_string(train_inputs[:, t, :]) 103 | predicted_text = one_hot_to_string(logprobs[:, t, :]) 104 | print(training_text.replace("\n", " ") + "|" + predicted_text.replace("\n", " ")) 105 | 106 | def training_loss(params, iter): 107 | return -rnn_log_likelihood(params, train_inputs, train_inputs) 108 | 109 | def callback(weights, iter, gradient): 110 | if iter % 10 == 0: 111 | print("Iteration", iter, "Train loss:", training_loss(weights, 0)) 112 | print_training_prediction(weights) 113 | 114 | # Build gradient of loss function using autograd. 115 | training_loss_grad = grad(training_loss) 116 | 117 | print("Training RNN...") 118 | trained_params = adam(training_loss_grad, init_params, step_size=0.1, num_iters=1000, callback=callback) 119 | 120 | print() 121 | print("Generating text from RNN...") 122 | num_letters = 30 123 | for t in range(20): 124 | text = "" 125 | for i in range(num_letters): 126 | seqs = string_to_one_hot(text, num_chars)[:, np.newaxis, :] 127 | logprobs = rnn_predict(trained_params, seqs)[-1].ravel() 128 | text += chr(npr.choice(len(logprobs), p=np.exp(logprobs))) 129 | print(text) 130 | -------------------------------------------------------------------------------- /examples/rosenbrock.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import minimize 2 | 3 | import autograd.numpy as np 4 | from autograd import value_and_grad 5 | 6 | 7 | def rosenbrock(x): 8 | return 100 * (x[1] - x[0] ** 2) ** 2 + (1 - x[0]) ** 2 9 | 10 | 11 | # Build a function that also returns gradients using autograd. 12 | rosenbrock_with_grad = value_and_grad(rosenbrock) 13 | 14 | # Optimize using conjugate gradients. 
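# With jac=True below, scipy's minimize expects the objective to return a (value, gradient)
# pair, which is exactly what value_and_grad produces. A quick check at the starting point:
val, g = rosenbrock_with_grad(np.array([0.0, 0.0]))
print(val, g)  # 1.0 and [-2.  0.]: the Rosenbrock value and gradient at the origin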
15 | result = minimize(rosenbrock_with_grad, x0=np.array([0.0, 0.0]), jac=True, method="CG") 16 | print(f"Found minimum at {result.x}") 17 | -------------------------------------------------------------------------------- /examples/sinusoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/sinusoid.png -------------------------------------------------------------------------------- /examples/sinusoid.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import autograd.numpy as np 4 | from autograd import grad 5 | 6 | 7 | def fun(x): 8 | return np.sin(x) 9 | 10 | 11 | d_fun = grad(fun) # First derivative 12 | dd_fun = grad(d_fun) # Second derivative 13 | 14 | x = np.linspace(-10, 10, 100) 15 | plt.plot(x, list(map(fun, x)), x, list(map(d_fun, x)), x, list(map(dd_fun, x))) 16 | 17 | plt.xlim([-10, 10]) 18 | plt.ylim([-1.2, 1.2]) 19 | plt.axis("off") 20 | plt.savefig("sinusoid.png") 21 | plt.clf() 22 | 23 | 24 | # Taylor approximation to sin function 25 | def fun(x): 26 | currterm = x 27 | ans = currterm 28 | for i in range(1000): 29 | print(i, end=" ") 30 | currterm = -currterm * x**2 / ((2 * i + 3) * (2 * i + 2)) 31 | ans = ans + currterm 32 | if np.abs(currterm) < 0.2: 33 | break # (Very generous tolerance!) 34 | 35 | return ans 36 | 37 | 38 | d_fun = grad(fun) 39 | dd_fun = grad(d_fun) 40 | 41 | x = np.linspace(-10, 10, 100) 42 | plt.plot(x, list(map(fun, x)), x, list(map(d_fun, x)), x, list(map(dd_fun, x))) 43 | 44 | plt.xlim([-10, 10]) 45 | plt.ylim([-1.2, 1.2]) 46 | plt.axis("off") 47 | plt.savefig("sinusoid_taylor.png") 48 | plt.clf() 49 | -------------------------------------------------------------------------------- /examples/sinusoid_taylor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/sinusoid_taylor.png -------------------------------------------------------------------------------- /examples/tanh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/tanh.png -------------------------------------------------------------------------------- /examples/tanh.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import autograd.numpy as np 4 | from autograd import elementwise_grad as egrad 5 | 6 | """ 7 | Mathematically we can only take gradients of scalar-valued functions, but 8 | autograd's elementwise_grad function also handles numpy's familiar vectorization 9 | of scalar functions, which is used in this example. 10 | 11 | To be precise, elementwise_grad(fun)(x) always returns the value of a 12 | vector-Jacobian product, where the Jacobian of fun is evaluated at x and the 13 | vector is an all-ones vector with the same size as the output of fun. When 14 | vectorizing a scalar-valued function over many arguments, the Jacobian of the 15 | overall vector-to-vector mapping is diagonal, and so this vector-Jacobian 16 | product simply returns the diagonal elements of the Jacobian, which is the 17 | (elementwise) gradient of the function at each input value over which the 18 | function is vectorized. 
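A small concrete check of that equivalence (a sketch; the function and evaluation
points below are arbitrary):

    import autograd.numpy as np
    from autograd import elementwise_grad, jacobian

    f = lambda x: np.sin(x) ** 2
    x = np.linspace(-1.0, 1.0, 5)
    # The Jacobian of the vectorized map is diagonal, and its diagonal matches
    # the elementwise gradient.
    assert np.allclose(np.diag(jacobian(f)(x)), elementwise_grad(f)(x))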
19 | """ 20 | 21 | 22 | def tanh(x): 23 | return (1.0 - np.exp((-2 * x))) / (1.0 + np.exp(-(2 * x))) 24 | 25 | 26 | ### Plotting 27 | plt.figure(figsize=(12, 8)) 28 | x = np.linspace(-7, 7, 700) 29 | plt.plot(x, tanh(x), label="tanh(x)") 30 | plt.plot(x, egrad(tanh)(x), label="1st derivative") 31 | plt.plot(x, egrad(egrad(tanh))(x), label="2nd derivative") 32 | plt.plot(x, egrad(egrad(egrad(tanh)))(x), label="3rd derivative") 33 | plt.plot(x, egrad(egrad(egrad(egrad(tanh))))(x), label="4th derivative") 34 | plt.xlabel("x") 35 | plt.ylabel("y") 36 | plt.ylim(-5, 5) 37 | plt.yticks(np.arange(-5, 6, 1)) 38 | plt.legend() 39 | plt.grid(True) 40 | plt.title("tanh(x) and its derivatives") 41 | plt.savefig("tanh.png") 42 | plt.show() 43 | -------------------------------------------------------------------------------- /examples/vae_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HIPS/autograd/18e8db44839276951dc9d2dfaad308e5e8a46c46/examples/vae_samples.png -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2025 by the President and Fellows of Harvard University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import nox 4 | 5 | NIGHTLY_INDEX_URL = "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" 6 | UV_NIGHTLY_ENV_VARS = { 7 | "UV_INDEX_URL": NIGHTLY_INDEX_URL, 8 | "UV_PRERELEASE": "allow", 9 | "UV_INDEX_STRATEGY": "first-index", 10 | } 11 | 12 | nox.needs_version = ">=2024.4.15" 13 | nox.options.default_venv_backend = "uv|virtualenv" 14 | nox.options.reuse_existing_virtualenvs = False 15 | nox.options.error_on_external_run = True 16 | # nox.options.sessions = ["lint", "validate-package", "tests"] 17 | nox.options.sessions = ["tests"] 18 | 19 | 20 | @nox.session(name="validate-package") 21 | def check(session): 22 | """Build source distribution, wheel, and check their metadata""" 23 | session.install("build", "twine", silent=False) 24 | session.run("python", "-m", "build") 25 | session.run("twine", "check", "--strict", "dist/*") 26 | 27 | 28 | @nox.session(name="tests", tags=["tests"]) 29 | def run_tests(session): 30 | """Run unit tests and generate a coverage report""" 31 | # SciPy doesn't have wheels on PyPy 32 | if platform.python_implementation() == "PyPy": 33 | session.install("-e", ".[test]", silent=False) 34 | else: 35 | session.install("-e", ".[test,scipy]", silent=False) 36 | session.run("pytest", "--cov=autograd", "--cov-report=xml", "--cov-append", *session.posargs) 37 | 38 | 39 | @nox.session(name="lint", reuse_venv=True) 40 | def ruff(session): 41 | """Lightning-fast linting for Python""" 42 | session.install("pre-commit", silent=False) 43 | session.run("pre-commit", "run", "--all-files", "--show-diff-on-failure") 44 | 45 | 46 | @nox.session(name="nightly-tests", tags=["tests"]) 47 | def run_nightly_tests(session): 48 | """Run tests against nightly versions of dependencies""" 49 | session.install("-e", ".[test]", silent=False) 50 | # SciPy doesn't have wheels on PyPy 51 | if platform.python_implementation() == "PyPy": 52 | session.install( 53 | "numpy", "--upgrade", "--only-binary", ":all:", silent=False, env=UV_NIGHTLY_ENV_VARS 54 | ) 55 | else: 56 | session.install( 57 | "numpy", "scipy", "--upgrade", "--only-binary", ":all:", silent=False, env=UV_NIGHTLY_ENV_VARS 58 | ) 59 | session.run("pytest", "--cov=autograd", "--cov-report=xml", "--cov-append", *session.posargs) 60 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "autograd" 7 | version = "1.8.0" 8 | requires-python = ">=3.9" 9 | description = "Efficiently computes derivatives of NumPy code." 
10 | readme = "README.md" 11 | license = {file = "license.txt"} 12 | authors = [ 13 | {name = "Dougal Maclaurin", email = "maclaurin@physics.harvard.edu"}, 14 | {name = "David Duvenaud", email = "duvenaud@cs.toronto.edu"}, 15 | {name = "Matthew Johnson", email = "mattjj@csail.mit.edu"}, 16 | {name = "Jamie Townsend", email = "j.h.n.townsend@uva.nl"}, 17 | ] 18 | maintainers = [ 19 | {name = "Jamie Townsend", email = "j.h.n.townsend@uva.nl"}, 20 | {name = "Fabian Joswig", email = "fabian.joswig@uni-muenster.de"}, 21 | {name = "Agriya Khetarpal", email = "agriyakhetarpal@outlook.com"}, 22 | ] 23 | classifiers = [ 24 | "Development Status :: 4 - Beta", 25 | "Intended Audience :: Information Technology", 26 | "Intended Audience :: Science/Research", 27 | "License :: OSI Approved :: MIT License", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Programming Language :: Python :: 3.11", 31 | "Programming Language :: Python :: 3.12", 32 | "Programming Language :: Python :: 3.13", 33 | "Topic :: Scientific/Engineering", 34 | ] 35 | keywords = [ 36 | "Automatic differentiation", 37 | "backpropagation", 38 | "gradients", 39 | "machine learning", 40 | "optimization", 41 | "neural networks", 42 | "Python", 43 | "NumPy", 44 | "SciPy", 45 | ] 46 | dependencies = [ 47 | "numpy<3", 48 | ] 49 | # dynamic = ["version"] 50 | 51 | [project.urls] 52 | Source = "https://github.com/HIPS/autograd" 53 | 54 | [project.optional-dependencies] 55 | scipy = [ 56 | "scipy", 57 | ] 58 | test = [ 59 | "pytest", 60 | "pytest-cov", 61 | "pytest-xdist", 62 | ] 63 | 64 | [tool.coverage.run] 65 | source = ["autograd"] 66 | 67 | [tool.coverage.report] 68 | show_missing = true 69 | 70 | [tool.pytest.ini_options] 71 | required_plugins = ["pytest-cov", "pytest-xdist"] 72 | # TODO: generate HTML report, upload to CodeCov 73 | addopts = "--color=yes -sra -n auto --cov=autograd --cov-report=xml --cov-report=term" 74 | 75 | [tool.ruff] 76 | extend-exclude = [] 77 | # TODO: not ignore them 78 | lint.extend-ignore = [ 79 | "E731", 80 | "F401", 81 | "F403", 82 | "F841", 83 | "F821", 84 | "E721", 85 | "E722", 86 | "E741", 87 | "E402", 88 | "F811" 89 | ] 90 | lint.extend-select = ["I", "W"] 91 | line-length = 109 92 | -------------------------------------------------------------------------------- /tests/_test_complexity.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | 4 | import autograd.numpy as np 5 | from autograd import deriv, grad 6 | from autograd.builtins import list as make_list 7 | 8 | 9 | def timefunction(f): 10 | t = time.time() 11 | f() 12 | return time.time() - t 13 | 14 | 15 | def assert_linear_time(f): 16 | t = timefunction(lambda: f(1)) 17 | t10 = timefunction(lambda: f(10)) 18 | assert t10 > 5 * t, f"Too fast: f(1) takes {t}, f(10) takes {t10}" 19 | assert t10 < 20 * t, f"Too slow: f(1) takes {t}, f(10) takes {t10}" 20 | if not (8 * t < t10 < 12 * t): 21 | warnings.warn("Borderline linearity. 
May fail on different hardware") 22 | 23 | 24 | def test_array_creation(): 25 | def fun(x, N): 26 | arr = [x for i in range(N)] 27 | return np.sum(np.array(arr)) 28 | 29 | assert_linear_time(lambda N: grad(fun)(1.0, 200 * N)) 30 | 31 | 32 | def test_array_indexing(): 33 | def fun(x): 34 | return sum([x[i] for i in range(len(x))]) 35 | 36 | assert_linear_time(lambda N: grad(fun)(np.zeros(200 * N))) 37 | 38 | 39 | def test_list_indexing(): 40 | def fun(x): 41 | return sum([x[i] for i in range(len(x))]) 42 | 43 | assert_linear_time(lambda N: grad(fun)([0.0 for i in range(50 * N)])) 44 | 45 | 46 | def test_list_creation(): 47 | def fun(x, N): 48 | return make_list(*[x for _ in range(N)]) 49 | 50 | assert_linear_time(lambda N: deriv(fun)(0.0, 20 * N)) 51 | 52 | 53 | # This fails. Need to figure out why 54 | def test_array_creation_fwd(): 55 | def fun(x, N): 56 | arr = [x for i in range(N)] 57 | return np.sum(np.array(arr)) 58 | 59 | assert_linear_time(lambda N: deriv(fun)(1.0, 400 * N)) 60 | -------------------------------------------------------------------------------- /tests/check_examples_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PYTHONPATH=".:$PYTHONPATH" 4 | trap 'kill -INT -$pid && exit 1' INT 5 | 6 | working=() 7 | failing=() 8 | 9 | examples=$(find examples -name '*.py' -not -name '__init__.py') 10 | 11 | echo 'Running all the examples...' 12 | for f in $examples; do 13 | timeout 15s python2 $f > /dev/null 2>&1 & pid=$! 14 | wait $pid 15 | status=$? 16 | if [ $status -eq 0 -o $status -eq 124 ]; then 17 | echo $f "seems to work" 18 | working+=($f) 19 | elif [ $status -eq 137 ]; then 20 | echo $f "might be working, but had to be killed" 21 | working+=($f) 22 | else 23 | echo $f "seems broken, try running manually" 24 | failing+=($f) 25 | fi 26 | done 27 | 28 | if [ ! ${#working[@]} -eq 0 ]; then 29 | echo -e '\033[01;36m' 30 | echo "These seemed to WORK:" 31 | echo -en '\033[00m' 32 | printf '%s\n' "${working[@]}" 33 | echo 34 | fi 35 | if [ ! 
${#failing[@]} -eq 0 ]; then 36 | echo -e '\033[01;31m' 37 | echo "These seemed to FAIL:" 38 | echo -en '\033[00m' 39 | printf '%s\n' "${failing[@]}" 40 | echo 41 | fi 42 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | 5 | @pytest.fixture(autouse=True) 6 | def random_seed(): 7 | np.random.seed(42) 8 | -------------------------------------------------------------------------------- /tests/numpy_utils.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy.random as npr 2 | from autograd.test_util import combo_check 3 | 4 | 5 | def stat_check(fun, test_complex=True, **kwargs): 6 | # Tests functions that compute statistics, like sum, mean, etc 7 | x = 3.5 8 | A = npr.randn() 9 | B = npr.randn(3) 10 | C = npr.randn(2, 3) 11 | D = npr.randn(1, 3) 12 | check = combo_check(fun, (0,), **kwargs) 13 | check([x, A]) 14 | check([B, C, D], axis=[None, 0], keepdims=[True, False]) 15 | check([C, D], axis=[None, 0, 1], keepdims=[True, False]) 16 | if test_complex: 17 | c = npr.randn() + 0.1j * npr.randn() 18 | E = npr.randn(2, 3) + 0.1j * npr.randn(2, 3) 19 | check([x, c, A]) 20 | check([B, C, D, E], axis=[None, 0], keepdims=[True, False]) 21 | 22 | 23 | def unary_ufunc_check(fun, lims=[-2, 2], test_complex=True, **kwargs): 24 | scalar = transform(lims, 0.4) 25 | vector = transform(lims, npr.rand(2)) 26 | mat = transform(lims, npr.rand(3, 2)) 27 | mat2 = transform(lims, npr.rand(1, 2)) 28 | check = combo_check(fun, (0,), **kwargs) 29 | check([scalar, vector, mat, mat2]) 30 | if test_complex: 31 | comp = transform(lims, 0.4) + 0.1j * transform(lims, 0.3) 32 | matc = transform(lims, npr.rand(3, 2)) + 0.1j * npr.rand(3, 2) 33 | check([comp, matc]) 34 | 35 | 36 | def binary_ufunc_check(fun, lims_A=[-2, 2], lims_B=[-2, 2], test_complex=True, **kwargs): 37 | T_A = lambda x: transform(lims_A, x) 38 | T_B = lambda x: transform(lims_B, x) 39 | scalar = 0.6 40 | vector = npr.rand(2) 41 | mat = npr.rand(3, 2) 42 | mat2 = npr.rand(1, 2) 43 | check = combo_check(fun, (0, 1), **kwargs) 44 | check([T_A(scalar), T_A(vector), T_A(mat), T_A(mat2)], [T_B(scalar), T_B(vector), T_B(mat), T_B(mat2)]) 45 | if test_complex: 46 | comp = 0.6 + 0.3j 47 | matc = npr.rand(3, 2) + 0.1j * npr.rand(3, 2) 48 | check( 49 | [T_A(scalar), T_A(comp), T_A(vector), T_A(matc), T_A(mat2)], 50 | [T_B(scalar), T_B(comp), T_B(vector), T_B(matc), T_B(mat2)], 51 | ) 52 | 53 | 54 | def binary_ufunc_check_no_same_args(fun, lims_A=[-2, 2], lims_B=[-2, 2], test_complex=True, **kwargs): 55 | T_A = lambda x: transform(lims_A, x) 56 | T_B = lambda x: transform(lims_B, x) 57 | scalar1 = 0.6 58 | scalar2 = 0.7 59 | vector1 = npr.rand(2) 60 | vector2 = npr.rand(2) 61 | mat11 = npr.rand(3, 2) 62 | mat12 = npr.rand(3, 2) 63 | mat21 = npr.rand(1, 2) 64 | mat22 = npr.rand(1, 2) 65 | check = combo_check(fun, (0, 1), **kwargs) 66 | check( 67 | [T_A(scalar1), T_A(vector1), T_A(mat11), T_A(mat21)], 68 | [T_B(scalar2), T_B(vector2), T_B(mat12), T_B(mat22)], 69 | ) 70 | if test_complex: 71 | comp1 = 0.6 + 0.3j 72 | comp2 = 0.1 + 0.2j 73 | matc1 = npr.rand(3, 2) + 0.1j * npr.rand(3, 2) 74 | matc2 = npr.rand(3, 2) + 0.1j * npr.rand(3, 2) 75 | check( 76 | [T_A(scalar1), T_A(comp1), T_A(vector1), T_A(matc1), T_A(mat21)], 77 | [T_B(scalar2), T_B(comp2), T_B(vector2), T_B(matc2), T_B(mat22)], 78 | ) 79 | 80 | 81 | def transform(lims, x): 82 
| return x * (lims[1] - lims[0]) + lims[0] 83 | -------------------------------------------------------------------------------- /tests/profiling.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from time import time 3 | 4 | import autograd.numpy as np 5 | import autograd.numpy.random as npr 6 | from autograd import grad 7 | 8 | 9 | @contextmanager 10 | def tictoc(text=""): 11 | print("--- Start clock ---") 12 | t1 = time() 13 | yield 14 | dt = time() - t1 15 | print(f"--- Stop clock {text}: {dt} seconds elapsed ---") 16 | 17 | 18 | def fan_out_fan_in(): 19 | """The 'Pearlmutter test'""" 20 | 21 | def fun(x): 22 | for i in range(10**4): 23 | x = (x + x) / 2.0 24 | return np.sum(x) 25 | 26 | with tictoc(): 27 | grad(fun)(1.0) 28 | 29 | 30 | def convolution(): 31 | # MNIST-scale convolution operation 32 | import autograd.scipy.signal 33 | 34 | convolve = autograd.scipy.signal.convolve 35 | dat = npr.randn(256, 3, 28, 28) 36 | kernel = npr.randn(3, 5, 5) 37 | with tictoc(): 38 | convolve(dat, kernel, axes=([2, 3], [1, 2]), dot_axes=([1], [0])) 39 | 40 | 41 | def dot_equivalent(): 42 | # MNIST-scale convolution operation 43 | 44 | dat = npr.randn(256, 3, 24, 5, 24, 5) 45 | kernel = npr.randn(3, 5, 5) 46 | with tictoc(): 47 | np.tensordot(dat, kernel, axes=[(1, 3, 5), (0, 1, 2)]) 48 | 49 | 50 | # fan_out_fan_in() 51 | # convolution() 52 | dot_equivalent() 53 | -------------------------------------------------------------------------------- /tests/test_binary_ops.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | import warnings 3 | 4 | import autograd.numpy as np 5 | import autograd.numpy.random as npr 6 | from autograd import grad, value_and_grad 7 | from autograd.test_util import check_grads 8 | 9 | rs = npr.RandomState(0) 10 | 11 | 12 | def arg_pairs(): 13 | scalar = 2.0 14 | vector = rs.randn(4) 15 | mat = rs.randn(3, 4) 16 | mat2 = rs.randn(1, 4) 17 | allargs = [scalar, vector, mat, mat2] 18 | yield from it.product(allargs, allargs) 19 | 20 | 21 | def test_mul(): 22 | fun = lambda x, y: x * y 23 | for arg1, arg2 in arg_pairs(): 24 | check_grads(fun)(arg1, arg2) 25 | 26 | 27 | def test_add(): 28 | fun = lambda x, y: x + y 29 | for arg1, arg2 in arg_pairs(): 30 | check_grads(fun)(arg1, arg2) 31 | 32 | 33 | def test_sub(): 34 | fun = lambda x, y: x - y 35 | for arg1, arg2 in arg_pairs(): 36 | check_grads(fun)(arg1, arg2) 37 | 38 | 39 | def test_div(): 40 | fun = lambda x, y: x / y 41 | make_gap_from_zero = lambda x: np.sqrt(x**2 + 0.5) 42 | for arg1, arg2 in arg_pairs(): 43 | arg1 = make_gap_from_zero(arg1) 44 | arg2 = make_gap_from_zero(arg2) 45 | check_grads(fun)(arg1, arg2) 46 | 47 | 48 | def test_mod(): 49 | fun = lambda x, y: x % y 50 | make_gap_from_zero = lambda x: np.sqrt(x**2 + 0.5) 51 | for arg1, arg2 in arg_pairs(): 52 | if arg1 is not arg2: # Gradient undefined at x == y 53 | arg1 = make_gap_from_zero(arg1) 54 | arg2 = make_gap_from_zero(arg2) 55 | check_grads(fun)(arg1, arg2) 56 | 57 | 58 | def test_pow(): 59 | fun = lambda x, y: x**y 60 | make_positive = lambda x: np.abs(x) + 1.1 # Numeric derivatives fail near zero 61 | for arg1, arg2 in arg_pairs(): 62 | arg1 = make_positive(arg1) 63 | check_grads(fun)(arg1, arg2) 64 | 65 | 66 | def test_arctan2(): 67 | for arg1, arg2 in arg_pairs(): 68 | check_grads(np.arctan2)(arg1, arg2) 69 | 70 | 71 | def test_hypot(): 72 | for arg1, arg2 in arg_pairs(): 73 | check_grads(np.hypot, 
modes=["rev"])(arg1, arg2) 74 | 75 | 76 | def test_comparison_grads(): 77 | compare_funs = [ 78 | lambda x, y: np.sum(x < x) + 0.0, 79 | lambda x, y: np.sum(x <= y) + 0.0, 80 | lambda x, y: np.sum(x > y) + 0.0, 81 | lambda x, y: np.sum(x >= y) + 0.0, 82 | lambda x, y: np.sum(x == y) + 0.0, 83 | lambda x, y: np.sum(x != y) + 0.0, 84 | ] 85 | 86 | with warnings.catch_warnings(record=True) as w: 87 | for arg1, arg2 in arg_pairs(): 88 | zeros = (arg1 + arg2) * 0 # get correct shape 89 | for fun in compare_funs: 90 | assert np.all(grad(fun)(arg1, arg2) == zeros) 91 | assert np.all(grad(fun, argnum=1)(arg1, arg2) == zeros) 92 | 93 | 94 | def test_comparison_values(): 95 | compare_funs = [ 96 | lambda x, y: np.sum(x < x) + 0.0, 97 | lambda x, y: np.sum(x <= y) + 0.0, 98 | lambda x, y: np.sum(x > y) + 0.0, 99 | lambda x, y: np.sum(x >= y) + 0.0, 100 | lambda x, y: np.sum(x == y) + 0.0, 101 | lambda x, y: np.sum(x != y) + 0.0, 102 | ] 103 | 104 | for arg1, arg2 in arg_pairs(): 105 | for fun in compare_funs: 106 | fun_val = fun(arg1, arg2) 107 | fun_val_from_grad, _ = value_and_grad(fun)(arg1, arg2) 108 | assert fun_val == fun_val_from_grad, (fun_val, fun_val_from_grad) 109 | -------------------------------------------------------------------------------- /tests/test_builtins.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | from autograd import grad 3 | from autograd.builtins import isinstance 4 | 5 | 6 | def test_isinstance(): 7 | def checker(ex, type_, truthval): 8 | assert isinstance(ex, type_) == truthval 9 | return 1.0 10 | 11 | examples = [ 12 | [list, [[]], [()]], 13 | [np.ndarray, [np.zeros(1)], [[]]], 14 | [(tuple, list), [[], ()], [np.zeros(1)]], 15 | ] 16 | 17 | for type_, positive_examples, negative_examples in examples: 18 | for ex in positive_examples: 19 | checker(ex, type_, True) 20 | grad(checker)(ex, type_, True) 21 | 22 | for ex in negative_examples: 23 | checker(ex, type_, False) 24 | grad(checker)(ex, type_, False) 25 | -------------------------------------------------------------------------------- /tests/test_complex.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import grad 4 | from autograd.test_util import check_grads 5 | 6 | npr.seed(1) 7 | 8 | 9 | def test_real_type(): 10 | fun = lambda x: np.sum(np.real(x)) 11 | df = grad(fun) 12 | assert np.isrealobj(df(2.0)) 13 | assert np.iscomplexobj(df(1.0j)) 14 | 15 | 16 | def test_real_if_close_type(): 17 | fun = lambda x: np.sum(np.real(x)) 18 | df = grad(fun) 19 | assert np.isrealobj(df(1.0)) 20 | assert np.iscomplexobj(df(1.0j)) 21 | 22 | 23 | def test_angle_real(): 24 | fun = lambda x: np.angle(x) 25 | d_fun = lambda x: grad(fun)(x) 26 | check_grads(fun)(npr.rand()) 27 | check_grads(d_fun)(npr.rand()) 28 | 29 | 30 | def test_angle_complex(): 31 | fun = lambda x: np.angle(x) 32 | d_fun = lambda x: grad(fun)(x) 33 | check_grads(fun)(npr.rand() + 1j * npr.rand()) 34 | check_grads(d_fun)(npr.rand() + 1j * npr.rand()) 35 | 36 | 37 | def test_abs_real(): 38 | fun = lambda x: np.abs(x) 39 | d_fun = lambda x: grad(fun)(x) 40 | check_grads(fun)(1.1) 41 | check_grads(d_fun)(2.1) 42 | 43 | 44 | def test_abs_complex(): 45 | fun = lambda x: np.abs(x) 46 | d_fun = lambda x: grad(fun)(x) 47 | check_grads(fun)(1.1 + 1.2j) 48 | check_grads(d_fun)(1.1 + 1.3j) 49 | -------------------------------------------------------------------------------- 
/tests/test_core.py: -------------------------------------------------------------------------------- 1 | """This file doesn't import the numpy wrapper, to check if core works 2 | on basic operations even without numpy.""" 3 | 4 | import warnings 5 | 6 | from autograd.core import make_vjp 7 | from autograd.wrap_util import unary_to_nary 8 | 9 | 10 | @unary_to_nary 11 | def grad(fun, x): 12 | vjp, _ = make_vjp(fun, x) 13 | return vjp(1.0) 14 | 15 | 16 | # Non-numpy gradient checking functions. 17 | def nd(f, x, eps=1e-4): 18 | return (f(x + eps / 2) - f(x - eps / 2)) / eps 19 | 20 | 21 | def check_close(a, b, atol=1e-4, rtol=1e-4): 22 | assert abs(a - b) < atol + rtol * abs(b), f"Diffs are: {a - b}" 23 | 24 | 25 | def check_binary_func(fun, independent=False): 26 | with warnings.catch_warnings(record=independent) as w: 27 | x, y = 0.7, 1.8 28 | a = grad(fun)(x, y) 29 | b = nd(lambda x: fun(x, y), x) 30 | check_close(a, b) 31 | 32 | a = grad(fun, 1)(x, y) 33 | b = nd(lambda y: fun(x, y), y) 34 | check_close(a, b) 35 | 36 | 37 | def test_add(): 38 | check_binary_func(lambda x, y: x + y) 39 | 40 | 41 | def test_sub(): 42 | check_binary_func(lambda x, y: x - y) 43 | 44 | 45 | def test_div(): 46 | check_binary_func(lambda x, y: x / y) 47 | 48 | 49 | def test_mul(): 50 | check_binary_func(lambda x, y: x * y) 51 | 52 | 53 | def test_pow(): 54 | check_binary_func(lambda x, y: x**y) 55 | 56 | 57 | def test_mod(): 58 | check_binary_func(lambda x, y: x % y) 59 | 60 | 61 | def test_eq(): 62 | check_binary_func(lambda x, y: x == y, independent=True) 63 | 64 | 65 | def test_neq(): 66 | check_binary_func(lambda x, y: x != y, independent=True) 67 | 68 | 69 | def test_leq(): 70 | check_binary_func(lambda x, y: x <= y, independent=True) 71 | 72 | 73 | def test_geq(): 74 | check_binary_func(lambda x, y: x >= y, independent=True) 75 | 76 | 77 | def test_lt(): 78 | check_binary_func(lambda x, y: x < y, independent=True) 79 | 80 | 81 | def test_gt(): 82 | check_binary_func(lambda x, y: x > y, independent=True) 83 | -------------------------------------------------------------------------------- /tests/test_dict.py: -------------------------------------------------------------------------------- 1 | import operator as op 2 | 3 | import autograd.numpy as np 4 | import autograd.numpy.random as npr 5 | from autograd import dict as ag_dict 6 | from autograd import grad 7 | from autograd import isinstance as ag_isinstance 8 | from autograd.test_util import check_grads 9 | 10 | npr.seed(0) 11 | 12 | 13 | def test_getter(): 14 | def fun(input_dict): 15 | A = np.sum(input_dict["item_1"]) 16 | B = np.sum(input_dict["item_2"]) 17 | C = np.sum(input_dict["item_2"]) 18 | return A + B + C 19 | 20 | d_fun = grad(fun) 21 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)} 22 | 23 | result = d_fun(input_dict) 24 | assert np.allclose(result["item_1"], np.ones((5, 6))) 25 | assert np.allclose(result["item_2"], 2 * np.ones((4, 3))) 26 | assert np.allclose(result["item_X"], np.zeros((2, 4))) 27 | 28 | 29 | def test_grads(): 30 | def fun(input_dict): 31 | A = np.sum(np.sin(input_dict["item_1"])) 32 | B = np.sum(np.cos(input_dict["item_2"])) 33 | return A + B 34 | 35 | def d_fun(input_dict): 36 | g = grad(fun)(input_dict) 37 | A = np.sum(g["item_1"]) 38 | B = np.sum(np.sin(g["item_1"])) 39 | C = np.sum(np.sin(g["item_2"])) 40 | return A + B + C 41 | 42 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)} 43 | 44 | check_grads(fun)(input_dict) 45 | 
check_grads(d_fun)(input_dict) 46 | 47 | 48 | def test_iter(): 49 | def fun(input_dict): 50 | A = 0.0 51 | B = 0.0 52 | for i, k in enumerate(sorted(input_dict)): 53 | A = A + np.sum(np.sin(input_dict[k])) * (i + 1.0) 54 | B = B + np.sum(np.cos(input_dict[k])) 55 | return A + B 56 | 57 | def d_fun(input_dict): 58 | g = grad(fun)(input_dict) 59 | A = np.sum(g["item_1"]) 60 | B = np.sum(np.sin(g["item_1"])) 61 | C = np.sum(np.sin(g["item_2"])) 62 | return A + B + C 63 | 64 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)} 65 | 66 | check_grads(fun)(input_dict) 67 | check_grads(d_fun)(input_dict) 68 | 69 | 70 | def test_items_values_keys(): 71 | def fun(input_dict): 72 | A = 0.0 73 | B = 0.0 74 | for i, (k, v) in enumerate(sorted(input_dict.items(), key=op.itemgetter(0))): 75 | A = A + np.sum(np.sin(v)) * (i + 1.0) 76 | B = B + np.sum(np.cos(v)) 77 | for v in input_dict.values(): 78 | A = A + np.sum(np.sin(v)) 79 | for k in sorted(input_dict.keys()): 80 | A = A + np.sum(np.cos(input_dict[k])) 81 | return A + B 82 | 83 | def d_fun(input_dict): 84 | g = grad(fun)(input_dict) 85 | A = np.sum(g["item_1"]) 86 | B = np.sum(np.sin(g["item_1"])) 87 | C = np.sum(np.sin(g["item_2"])) 88 | return A + B + C 89 | 90 | input_dict = {"item_1": npr.randn(5, 6), "item_2": npr.randn(4, 3), "item_X": npr.randn(2, 4)} 91 | 92 | check_grads(fun)(input_dict) 93 | check_grads(d_fun)(input_dict) 94 | 95 | 96 | def test_get(): 97 | def fun(d, x): 98 | return d.get("item_1", x) ** 2 99 | 100 | check_grads(fun, argnum=(0, 1))({"item_1": 3.0}, 2.0) 101 | check_grads(fun, argnum=(0, 1))({"item_2": 4.0}, 2.0) 102 | check_grads(fun, argnum=(0, 1))({}, 2.0) 103 | 104 | 105 | def test_make_dict(): 106 | def fun(x): 107 | return ag_dict([("a", x)], b=x) 108 | 109 | check_grads(fun, modes=["rev"])(1.0) 110 | 111 | def fun(x): 112 | return ag_dict({"a": x}) 113 | 114 | check_grads(fun, modes=["rev"])(1.0) 115 | 116 | # check some other forms of the constructor 117 | ag_dict() 118 | ag_dict(()) 119 | ag_dict({}) 120 | 121 | 122 | def test_isinstance(): 123 | def fun(x): 124 | assert ag_isinstance(x, dict) 125 | assert ag_isinstance(x, ag_dict) 126 | return x["x"] 127 | 128 | fun({"x": 1.0}) 129 | grad(fun)({"x": 1.0}) 130 | -------------------------------------------------------------------------------- /tests/test_direct.py: -------------------------------------------------------------------------------- 1 | """ 2 | Set of tests that are as explicit as possible, in case the test helpers like 3 | autograd.test_util break and start letting everything pass 4 | """ 5 | 6 | import numpy as onp 7 | import pytest 8 | 9 | import autograd.numpy as np 10 | from autograd import deriv, grad, holomorphic_grad 11 | 12 | 13 | def test_grad(): 14 | def fun(x): 15 | return (x + np.sin(x**2)) * x 16 | 17 | assert 3.190948746871 - 1e-6 < grad(fun)(1.3) < 3.190948746871 + 1e-6 18 | 19 | 20 | def test_deriv(): 21 | def fun(x): 22 | return (x + np.sin(x**2)) * x 23 | 24 | assert 3.190948746871 - 1e-6 < deriv(fun)(1.3) < 3.190948746871 + 1e-6 25 | 26 | 27 | def test_grad_complex_output(): 28 | def fun(x): 29 | return x * (1.0 + 0.2j) 30 | 31 | with pytest.raises(TypeError): 32 | grad(fun)(1.0) 33 | 34 | 35 | def test_holomorphic_grad(): 36 | def fun(x): 37 | return x * (1.0 + 0.2j) 38 | 39 | g = holomorphic_grad(fun)(1.0 + 0.0j) 40 | assert 0.9999 < onp.real(g) < 1.0001 41 | assert 0.1999 < onp.imag(g) < 0.2001 42 | -------------------------------------------------------------------------------- 
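test_direct.py above exercises grad and deriv on the same scalar function without going through the test helpers. The two operators are complementary: grad is the reverse-mode gradient, while deriv (also used for the forward-mode cases in _test_complexity.py earlier) is its forward-mode counterpart, and for scalar-to-scalar functions the two agree. A small sketch of that agreement (editor's illustration):

    import autograd.numpy as np
    from autograd import deriv, grad

    f = lambda x: (x + np.sin(x**2)) * x
    # Both modes return the same number, which test_grad and test_deriv above
    # pin to roughly 3.190948746871 at x = 1.3.
    assert abs(grad(f)(1.3) - deriv(f)(1.3)) < 1e-10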
/tests/test_jacobian.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import grad, jacobian 4 | from autograd.test_util import check_grads 5 | 6 | npr.seed(1) 7 | 8 | 9 | def test_jacobian_against_grad(): 10 | fun = lambda x: np.sum(np.sin(x), axis=1, keepdims=True) 11 | A = npr.randn(1, 3) 12 | assert np.allclose(grad(fun)(A), jacobian(fun)(A)) 13 | 14 | 15 | def test_jacobian_scalar_to_vector(): 16 | fun = lambda x: np.array([x, x**2, x**3]) 17 | val = npr.randn() 18 | assert np.allclose(jacobian(fun)(val), np.array([1.0, 2 * val, 3 * val**2])) 19 | 20 | 21 | def test_jacobian_against_stacked_grads(): 22 | scalar_funs = [ 23 | lambda x: np.sum(x**3), 24 | lambda x: np.prod(np.sin(x) + np.sin(x)), 25 | lambda x: grad(lambda y: np.exp(y) * np.tanh(x[0]))(x[1]), 26 | ] 27 | 28 | vector_fun = lambda x: np.array([f(x) for f in scalar_funs]) 29 | 30 | x = npr.randn(5) 31 | jac = jacobian(vector_fun)(x) 32 | grads = [grad(f)(x) for f in scalar_funs] 33 | 34 | assert np.allclose(jac, np.vstack(grads)) 35 | 36 | 37 | def test_jacobian_higher_order(): 38 | fun = lambda x: np.sin(np.outer(x, x)) + np.cos(np.dot(x, x)) 39 | 40 | assert jacobian(fun)(npr.randn(2)).shape == (2, 2, 2) 41 | assert jacobian(jacobian(fun))(npr.randn(2)).shape == (2, 2, 2, 2) 42 | # assert jacobian(jacobian(jacobian(fun)))(npr.randn(2)).shape == (2,2,2,2,2) 43 | 44 | check_grads(lambda x: np.sum(np.sin(jacobian(fun)(x))))(npr.randn(2)) 45 | check_grads(lambda x: np.sum(np.sin(jacobian(jacobian(fun))(x))))(npr.randn(2)) 46 | -------------------------------------------------------------------------------- /tests/test_list.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import grad 4 | from autograd import isinstance as ag_isinstance 5 | from autograd import list as ag_list 6 | from autograd.test_util import check_grads 7 | 8 | npr.seed(1) 9 | 10 | 11 | def test_getter(): 12 | def fun(input_list): 13 | A = np.sum(input_list[0]) 14 | B = np.sum(input_list[1]) 15 | C = np.sum(input_list[1]) 16 | return A + B + C 17 | 18 | d_fun = grad(fun) 19 | input_list = [npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4)] 20 | 21 | result = d_fun(input_list) 22 | assert np.allclose(result[0], np.ones((5, 6))) 23 | assert np.allclose(result[1], 2 * np.ones((4, 3))) 24 | assert np.allclose(result[2], np.zeros((2, 4))) 25 | 26 | 27 | def test_grads(): 28 | def fun(input_list): 29 | A = np.sum(np.sin(input_list[0])) 30 | B = np.sum(np.cos(input_list[1])) 31 | return A + B 32 | 33 | def d_fun(input_list): 34 | g = grad(fun)(input_list) 35 | A = np.sum(g[0]) 36 | B = np.sum(np.sin(g[0])) 37 | C = np.sum(np.sin(g[1])) 38 | return A + B + C 39 | 40 | input_list = [npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4)] 41 | 42 | check_grads(fun)(input_list) 43 | check_grads(d_fun)(input_list) 44 | 45 | 46 | def test_slices(): 47 | def f(x): 48 | s = slice(None, -1, None) 49 | y = x[s] 50 | return y[0] 51 | 52 | grad(f)([1.0, 2.0, 3.0]) 53 | 54 | def f(x): 55 | y = x[1:3] 56 | return y[0] 57 | 58 | grad(f)([1.0, 2.0, 3.0]) 59 | 60 | 61 | def test_nested_list(): 62 | A = [[1.0], 2.0, 1.5] 63 | 64 | def fun(x): 65 | return x[1:][0] 66 | 67 | check_grads(fun)(A) 68 | 69 | 70 | def test_make_list(): 71 | def fun(x): 72 | return ag_list((x, x)) 73 | 74 | check_grads(fun)(1.0) 75 | 76 | 77 | def 
test_isinstance(): 78 | def fun(x): 79 | assert ag_isinstance(x, list) 80 | assert ag_isinstance(x, ag_list) 81 | return x[0] 82 | 83 | fun([1.0, 2.0, 3.0]) 84 | grad(fun)([1.0, 2.0, 3.0]) 85 | -------------------------------------------------------------------------------- /tests/test_logic.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from contextlib import contextmanager 3 | 4 | import pytest 5 | 6 | import autograd.numpy as np 7 | from autograd import deriv, grad 8 | from autograd.core import primitive_vjps 9 | from autograd.extend import primitive 10 | from autograd.test_util import check_grads 11 | 12 | 13 | def test_assert(): 14 | # from https://github.com/HIPS/autograd/issues/43 15 | def fun(x): 16 | assert np.allclose(x, (x * 3.0) / 3.0) 17 | return np.sum(x) 18 | 19 | check_grads(fun)(np.array([1.0, 2.0, 3.0])) 20 | 21 | 22 | def test_nograd(): 23 | # we want this to raise non-differentiability error 24 | fun = lambda x: np.allclose(x, (x * 3.0) / 3.0) 25 | with pytest.raises(TypeError): 26 | with warnings.catch_warnings(record=True) as w: 27 | grad(fun)(np.array([1.0, 2.0, 3.0])) 28 | 29 | 30 | def test_no_vjp_def(): 31 | fun = primitive(lambda x: 2.0 * x) 32 | with pytest.raises(NotImplementedError): 33 | grad(fun)(1.0) 34 | 35 | 36 | def test_no_jvp_def(): 37 | fun = primitive(lambda x: 2.0 * x) 38 | with pytest.raises(NotImplementedError): 39 | deriv(fun)(1.0) 40 | 41 | 42 | def test_falseyness(): 43 | fun = lambda x: np.real(x**2 if np.iscomplex(x) else np.sum(x)) 44 | check_grads(fun)(2.0) 45 | check_grads(fun)(2.0 + 1j) 46 | 47 | 48 | def test_unimplemented_falseyness(): 49 | @contextmanager 50 | def remove_grad_definitions(fun): 51 | vjpmaker = primitive_vjps.pop(fun, None) 52 | yield 53 | if vjpmaker: 54 | primitive_vjps[fun] = vjpmaker 55 | 56 | with remove_grad_definitions(np.iscomplex): 57 | fun = lambda x: np.real(x**2 if np.iscomplex(x) else np.sum(x)) 58 | check_grads(fun)(5.0) 59 | check_grads(fun)(2.0 + 1j) 60 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import grad, make_vjp 4 | from autograd.misc import const_graph, flatten 5 | from autograd.test_util import scalar_close 6 | from autograd.tracer import primitive 7 | 8 | 9 | def test_const_graph(): 10 | L = [] 11 | 12 | def foo(x, y): 13 | L.append(None) 14 | return grad(lambda x: np.sin(x) + x * 2)(x * y) 15 | 16 | foo_wrapped = const_graph(foo) 17 | 18 | assert len(L) == 0 19 | assert scalar_close(foo(0.0, 0.0), foo_wrapped(0.0, 0.0)) 20 | assert len(L) == 2 21 | assert scalar_close(foo(1.0, 0.5), foo_wrapped(1.0, 0.5)) 22 | assert len(L) == 3 23 | assert scalar_close(foo(1.0, 0.5), foo_wrapped(1.0, 0.5)) 24 | assert len(L) == 4 25 | 26 | 27 | def test_const_graph_args(): 28 | L = [] 29 | 30 | @primitive 31 | def process(var, varname): 32 | L.append(varname) 33 | return var 34 | 35 | def foo(x, y, z): 36 | x = process(x, "x") 37 | y = process(y, "y") 38 | z = process(z, "z") 39 | return x + 2 * y + 3 * z 40 | 41 | foo_wrapped = const_graph(foo, 1.0, z=3.0) 42 | 43 | assert L == [] 44 | assert scalar_close(foo(1.0, 2.0, 3.0), foo_wrapped(2.0)) 45 | assert L == ["x", "y", "z", "x", "y", "z"] 46 | L = [] 47 | assert scalar_close(foo(1.0, 2.0, 3.0), foo_wrapped(2.0)) 48 | assert L == ["x", "y", "z", "y"] 49 | L = [] 50 | 
assert scalar_close(foo(1.0, 2.0, 3.0), foo_wrapped(2.0)) 51 | assert L == ["x", "y", "z", "y"] 52 | 53 | 54 | def test_flatten(): 55 | r = np.random.randn 56 | x = (1.0, r(2, 3), [r(1, 4), {"x": 2.0, "y": r(4, 2)}]) 57 | x_flat, unflatten = flatten(x) 58 | assert x_flat.shape == (20,) 59 | assert x_flat[0] == 1.0 60 | assert np.all(x_flat == flatten(unflatten(x_flat))[0]) 61 | 62 | y = (1.0, 2.0, [3.0, {"x": 2.0, "y": 4.0}]) 63 | y_flat, unflatten = flatten(y) 64 | assert y_flat.shape == (5,) 65 | assert y == unflatten(y_flat) 66 | 67 | 68 | def test_flatten_empty(): 69 | val = (npr.randn(4), [npr.randn(3, 4), 2.5], (), (2.0, [1.0, npr.randn(2)])) 70 | vect, unflatten = flatten(val) 71 | val_recovered = unflatten(vect) 72 | vect_2, _ = flatten(val_recovered) 73 | assert np.all(vect == vect_2) 74 | 75 | 76 | def test_flatten_dict(): 77 | val = {"k": npr.random((4, 4)), "k2": npr.random((3, 3)), "k3": 3.0, "k4": [1.0, 4.0, 7.0, 9.0]} 78 | 79 | vect, unflatten = flatten(val) 80 | val_recovered = unflatten(vect) 81 | vect_2, _ = flatten(val_recovered) 82 | assert np.all(vect == vect_2) 83 | 84 | 85 | def unflatten_tracing(): 86 | val = [npr.randn(4), [npr.randn(3, 4), 2.5], (), (2.0, [1.0, npr.randn(2)])] 87 | vect, unflatten = flatten(val) 88 | 89 | def f(vect): 90 | return unflatten(vect) 91 | 92 | flatten2, _ = make_vjp(f)(vect) 93 | assert np.all(vect == flatten2(val)) 94 | 95 | 96 | def test_flatten_nodes_in_containers(): 97 | # see issue #232 98 | def f(x, y): 99 | xy, _ = flatten([x, y]) 100 | return np.sum(xy) 101 | 102 | grad(f)(1.0, 2.0) 103 | 104 | 105 | def test_flatten_complex(): 106 | val = 1 + 1j 107 | flat, unflatten = flatten(val) 108 | assert np.all(val == unflatten(flat)) 109 | -------------------------------------------------------------------------------- /tests/test_performance.py: -------------------------------------------------------------------------------- 1 | # TODO: 2 | # Do a huge calculation with trivial primitive computations 3 | # and lots of diamonds and get a benchmark per-node time and 4 | # memory cost. 
5 | -------------------------------------------------------------------------------- /tests/test_scalar_ops.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import grad 4 | from autograd.test_util import check_grads 5 | 6 | npr.seed(1) 7 | 8 | 9 | def test_abs(): 10 | fun = lambda x: 3.0 * np.abs(x) 11 | check_grads(fun)(1.1) 12 | check_grads(fun)(-1.1) 13 | check_grads(fun, order=1)(0.0) 14 | 15 | 16 | def test_sin(): 17 | fun = lambda x: 3.0 * np.sin(x) 18 | check_grads(fun)(npr.randn()) 19 | 20 | 21 | def test_sign(): 22 | fun = lambda x: 3.0 * np.sign(x) 23 | check_grads(fun)(1.1) 24 | check_grads(fun)(-1.1) 25 | 26 | 27 | def test_exp(): 28 | fun = lambda x: 3.0 * np.exp(x) 29 | check_grads(fun)(npr.randn()) 30 | 31 | 32 | def test_log(): 33 | fun = lambda x: 3.0 * np.log(x) 34 | check_grads(fun)(abs(npr.randn())) 35 | 36 | 37 | def test_log2(): 38 | fun = lambda x: 3.0 * np.log2(x) 39 | check_grads(fun)(abs(npr.randn())) 40 | 41 | 42 | def test_log10(): 43 | fun = lambda x: 3.0 * np.log10(x) 44 | check_grads(fun)(abs(npr.randn())) 45 | 46 | 47 | def test_log1p(): 48 | fun = lambda x: 3.0 * np.log1p(x) 49 | check_grads(fun)(abs(npr.randn())) 50 | 51 | 52 | def test_expm1(): 53 | fun = lambda x: 3.0 * np.expm1(x) 54 | check_grads(fun)(abs(npr.randn())) 55 | 56 | 57 | def test_exp2(): 58 | fun = lambda x: 3.0 * np.exp2(x) 59 | check_grads(fun)(abs(npr.randn())) 60 | 61 | 62 | def test_neg(): 63 | fun = lambda x: 3.0 * -x 64 | check_grads(fun)(npr.randn()) 65 | 66 | 67 | def test_cos(): 68 | fun = lambda x: 3.0 * np.cos(x) 69 | check_grads(fun)(npr.randn()) 70 | 71 | 72 | def test_tan(): 73 | fun = lambda x: 3.0 * np.tan(x) 74 | check_grads(fun)(npr.randn()) 75 | 76 | 77 | def test_cosh(): 78 | fun = lambda x: 3.0 * np.cosh(x) 79 | check_grads(fun)(npr.randn()) 80 | 81 | 82 | def test_sinh(): 83 | fun = lambda x: 3.0 * np.sinh(x) 84 | check_grads(fun)(npr.randn()) 85 | 86 | 87 | def test_tanh(): 88 | fun = lambda x: 3.0 * np.tanh(x) 89 | check_grads(fun)(npr.randn()) 90 | 91 | 92 | def test_arccos(): 93 | fun = lambda x: 3.0 * np.arccos(x) 94 | check_grads(fun)(0.1) 95 | 96 | 97 | def test_arcsin(): 98 | fun = lambda x: 3.0 * np.arcsin(x) 99 | check_grads(fun)(0.1) 100 | 101 | 102 | def test_arctan(): 103 | fun = lambda x: 3.0 * np.arctan(x) 104 | check_grads(fun)(0.2) 105 | 106 | 107 | def test_arccosh(): 108 | fun = lambda x: 3.0 * np.arccosh(x) 109 | check_grads(fun)(npr.randn() ** 2 + 1.2) 110 | 111 | 112 | def test_arcsinh(): 113 | fun = lambda x: 3.0 * np.arcsinh(x) 114 | check_grads(fun)(npr.randn()) 115 | 116 | 117 | def test_arctanh(): 118 | fun = lambda x: 3.0 * np.arctanh(x) 119 | check_grads(fun)(0.2) 120 | 121 | 122 | def test_sqrt(): 123 | fun = lambda x: 3.0 * np.sqrt(x) 124 | check_grads(fun)(10.0 * npr.rand()) 125 | 126 | 127 | def test_power_arg0(): 128 | # the +1.'s here are to avoid regimes where numerical diffs fail 129 | make_fun = lambda y: lambda x: np.power(x, y) 130 | fun = make_fun(npr.randn() ** 2 + 1.0) 131 | check_grads(fun)(npr.rand() ** 2 + 1.0) 132 | 133 | # test y == 0. as a special case, c.f. 
#116 134 | fun = make_fun(0.0) 135 | assert grad(fun)(0.0) == 0.0 136 | assert grad(grad(fun))(0.0) == 0.0 137 | 138 | 139 | def test_power_arg1(): 140 | x = npr.randn() ** 2 141 | fun = lambda y: np.power(x, y) 142 | check_grads(fun)(npr.rand() ** 2) 143 | 144 | 145 | def test_power_arg1_zero(): 146 | fun = lambda y: np.power(0.0, y) 147 | check_grads(fun)(npr.rand() ** 2) 148 | 149 | 150 | def test_mod_arg0(): 151 | fun = lambda x, y: np.mod(x, y) 152 | check_grads(fun)(npr.rand(), npr.rand()) 153 | 154 | 155 | def test_mod_arg1(): 156 | fun = lambda x, y: np.mod(x, y) 157 | check_grads(fun)(npr.rand(), npr.rand()) 158 | 159 | 160 | def test_divide_arg0(): 161 | fun = lambda x, y: np.divide(x, y) 162 | check_grads(fun)(npr.rand(), npr.rand()) 163 | 164 | 165 | def test_divide_arg1(): 166 | fun = lambda x, y: np.divide(x, y) 167 | check_grads(fun)(npr.rand(), npr.rand()) 168 | 169 | 170 | def test_multiply_arg0(): 171 | fun = lambda x, y: np.multiply(x, y) 172 | check_grads(fun)(npr.rand(), npr.rand()) 173 | 174 | 175 | def test_multiply_arg1(): 176 | fun = lambda x, y: np.multiply(x, y) 177 | check_grads(fun)(npr.rand(), npr.rand()) 178 | 179 | 180 | def test_true_divide_arg0(): 181 | fun = lambda x, y: np.true_divide(x, y) 182 | check_grads(fun)(npr.rand(), npr.rand()) 183 | 184 | 185 | def test_true_divide_arg1(): 186 | fun = lambda x, y: np.true_divide(x, y) 187 | check_grads(fun)(npr.rand(), npr.rand()) 188 | 189 | 190 | def test_reciprocal(): 191 | fun = lambda x: np.reciprocal(x) 192 | check_grads(fun)(npr.rand()) 193 | 194 | 195 | def test_negative(): 196 | fun = lambda x: np.negative(x) 197 | check_grads(fun)(npr.rand()) 198 | 199 | 200 | def test_rad2deg(): 201 | fun = lambda x: 3.0 * np.rad2deg(x) 202 | check_grads(fun)(10.0 * npr.rand()) 203 | 204 | 205 | def test_deg2rad(): 206 | fun = lambda x: 3.0 * np.deg2rad(x) 207 | check_grads(fun)(10.0 * npr.rand()) 208 | 209 | 210 | def test_radians(): 211 | fun = lambda x: 3.0 * np.radians(x) 212 | check_grads(fun)(10.0 * npr.rand()) 213 | 214 | 215 | def test_degrees(): 216 | fun = lambda x: 3.0 * np.degrees(x) 217 | check_grads(fun)(10.0 * npr.rand()) 218 | 219 | 220 | def test_sinc(): 221 | fun = lambda x: 3.0 * np.sinc(x) 222 | check_grads(fun)(10.0 * npr.rand()) 223 | -------------------------------------------------------------------------------- /tests/test_tests.py: -------------------------------------------------------------------------------- 1 | from pytest import raises 2 | 3 | from autograd.extend import defvjp 4 | from autograd.test_util import check_grads 5 | from autograd.tracer import primitive 6 | 7 | 8 | def test_check_vjp_1st_order_fail(): 9 | @primitive 10 | def foo(x): 11 | return x * 2.0 12 | 13 | defvjp(foo, lambda ans, x: lambda g: g * 2.001) 14 | 15 | with raises(AssertionError, match="\\(VJP\\) check of foo failed"): 16 | check_grads(foo, modes=["rev"])(1.0) 17 | 18 | 19 | def test_check_vjp_2nd_order_fail(): 20 | @primitive 21 | def foo(x): 22 | return x * 2.0 23 | 24 | defvjp(foo, lambda ans, x: lambda g: bar(g) * 2) 25 | 26 | @primitive 27 | def bar(x): 28 | return x 29 | 30 | defvjp(bar, lambda ans, x: lambda g: g * 1.001) 31 | 32 | with raises(AssertionError, match="\\(VJP\\) check of vjp_foo failed"): 33 | check_grads(foo, modes=["rev"])(1.0) 34 | -------------------------------------------------------------------------------- /tests/test_truediv.py: -------------------------------------------------------------------------------- 1 | # This file is to check that future division works. 
2 | 3 | from test_binary_ops import arg_pairs 4 | 5 | import autograd.numpy as np 6 | from autograd.test_util import check_grads 7 | 8 | 9 | def test_div(): 10 | fun = lambda x, y: x / y 11 | make_gap_from_zero = lambda x: np.sqrt(x**2 + 0.5) 12 | for arg1, arg2 in arg_pairs(): 13 | arg1 = make_gap_from_zero(arg1) 14 | arg2 = make_gap_from_zero(arg2) 15 | check_grads(fun)(arg1, arg2) 16 | -------------------------------------------------------------------------------- /tests/test_tuple.py: -------------------------------------------------------------------------------- 1 | import autograd.numpy as np 2 | import autograd.numpy.random as npr 3 | from autograd import grad 4 | from autograd import isinstance as ag_isinstance 5 | from autograd import tuple as ag_tuple 6 | from autograd.test_util import check_grads 7 | 8 | npr.seed(1) 9 | 10 | 11 | def test_getter(): 12 | def fun(input_tuple): 13 | A = np.sum(input_tuple[0]) 14 | B = np.sum(input_tuple[1]) 15 | C = np.sum(input_tuple[1]) 16 | return A + B + C 17 | 18 | d_fun = grad(fun) 19 | input_tuple = (npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4)) 20 | 21 | result = d_fun(input_tuple) 22 | assert np.allclose(result[0], np.ones((5, 6))) 23 | assert np.allclose(result[1], 2 * np.ones((4, 3))) 24 | assert np.allclose(result[2], np.zeros((2, 4))) 25 | 26 | 27 | def test_grads(): 28 | def fun(input_tuple): 29 | A = np.sum(np.sin(input_tuple[0])) 30 | B = np.sum(np.cos(input_tuple[1])) 31 | return A + B 32 | 33 | def d_fun(input_tuple): 34 | g = grad(fun)(input_tuple) 35 | A = np.sum(g[0]) 36 | B = np.sum(np.sin(g[0])) 37 | C = np.sum(np.sin(g[1])) 38 | return A + B + C 39 | 40 | input_tuple = (npr.randn(5, 6), npr.randn(4, 3), npr.randn(2, 4)) 41 | 42 | check_grads(fun)(input_tuple) 43 | check_grads(d_fun)(input_tuple) 44 | 45 | 46 | def test_nested_higher_order(): 47 | def outer_fun(x): 48 | def inner_fun(y): 49 | return y[0] * y[1] 50 | 51 | return np.sum(np.sin(np.array(grad(inner_fun)(ag_tuple((x, x)))))) 52 | 53 | check_grads(outer_fun)(5.0) 54 | check_grads(grad(outer_fun))(10.0) 55 | check_grads(grad(grad(outer_fun)))(10.0) 56 | 57 | 58 | def test_isinstance(): 59 | def fun(x): 60 | assert ag_isinstance(x, tuple) 61 | assert ag_isinstance(x, ag_tuple) 62 | return x[0] 63 | 64 | fun((1.0, 2.0, 3.0)) 65 | grad(fun)((1.0, 2.0, 3.0)) 66 | -------------------------------------------------------------------------------- /tests/test_vspaces.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | from functools import reduce 3 | 4 | import numpy as np 5 | 6 | from autograd.core import vspace 7 | from autograd.test_util import check_grads, scalar_close 8 | 9 | 10 | def check_vspace(value): 11 | vs = vspace(value) 12 | # --- required attributes --- 13 | size = vs.size 14 | add = vs.add 15 | scalar_mul = vs.scalar_mul 16 | inner_prod = vs.inner_prod 17 | randn = vs.randn 18 | zeros = vs.zeros 19 | ones = vs.ones 20 | standard_basis = vs.standard_basis 21 | 22 | # --- util --- 23 | def randns(N=2): 24 | return [randn() for i in range(N)] 25 | 26 | def rand_scalar(): 27 | return float(np.random.randn()) 28 | 29 | def rand_scalars(N=2): 30 | return [rand_scalar() for i in range(N)] 31 | 32 | def vector_close(x, y): 33 | z = randn() 34 | return scalar_close(inner_prod(z, x), inner_prod(z, y)) 35 | 36 | # --- vector space axioms --- 37 | def associativity_of_add(x, y, z): 38 | return vector_close(add(x, add(y, z)), add(add(x, y), z)) 39 | 40 | def commutativity_of_add(x, y): 41 | return 
vector_close(add(x, y), add(y, x)) 42 | 43 | def identity_element_of_add(x): 44 | return vector_close(add(zeros(), x), x) 45 | 46 | def inverse_elements_of_add(x): 47 | return vector_close(zeros(), add(x, scalar_mul(x, -1.0))) 48 | 49 | def compatibility_of_scalar_mul_with_field_mul(x, a, b): 50 | return vector_close(scalar_mul(x, a * b), scalar_mul(scalar_mul(x, a), b)) 51 | 52 | def identity_element_of_scalar_mul(x): 53 | return vector_close(scalar_mul(x, 1.0), x) 54 | 55 | def distributivity_of_scalar_mul_wrt_vector_add(x, y, a): 56 | return vector_close(scalar_mul(add(x, y), a), add(scalar_mul(x, a), scalar_mul(y, a))) 57 | 58 | def distributivity_of_scalar_mul_wrt_scalar_add(x, a, b): 59 | return vector_close(scalar_mul(x, a + b), add(scalar_mul(x, a), scalar_mul(x, b))) 60 | 61 | # --- closure --- 62 | def add_preserves_vspace(x, y): 63 | return vs == vspace(add(x, y)) 64 | 65 | def scalar_mul_preserves_vspace(x, a): 66 | return vs == vspace(scalar_mul(x, a)) 67 | 68 | # --- inner product axioms --- 69 | def symmetry(x, y): 70 | return scalar_close(inner_prod(x, y), inner_prod(y, x)) 71 | 72 | def linearity(x, y, a): 73 | return scalar_close(inner_prod(scalar_mul(x, a), y), a * inner_prod(x, y)) 74 | 75 | def positive_definitive(x): 76 | return 0 < inner_prod(x, x) 77 | 78 | def inner_zeros(): 79 | return scalar_close(0, inner_prod(zeros(), zeros())) 80 | 81 | # --- basis vectors and special vectors--- 82 | def basis_orthonormality(): 83 | return all( 84 | [ 85 | scalar_close(inner_prod(x, y), 1.0 * (ix == iy)) 86 | for (ix, x), (iy, y) in it.product(enumerate(standard_basis()), enumerate(standard_basis())) 87 | ] 88 | ) 89 | 90 | def ones_sum_of_basis_vects(): 91 | return vector_close(reduce(add, standard_basis()), ones()) 92 | 93 | def basis_correct_size(): 94 | return len(list(standard_basis())) == size 95 | 96 | def basis_correct_vspace(): 97 | return (vs == vspace(x) for x in standard_basis()) 98 | 99 | def zeros_correct_vspace(): 100 | return vs == vspace(zeros()) 101 | 102 | def ones_correct_vspace(): 103 | return vs == vspace(ones()) 104 | 105 | def randn_correct_vspace(): 106 | return vs == vspace(randn()) 107 | 108 | assert associativity_of_add(*randns(3)) 109 | assert commutativity_of_add(*randns()) 110 | assert identity_element_of_add(randn()) 111 | assert inverse_elements_of_add(randn()) 112 | assert compatibility_of_scalar_mul_with_field_mul(randn(), *rand_scalars()) 113 | assert identity_element_of_scalar_mul(randn()) 114 | assert distributivity_of_scalar_mul_wrt_vector_add(randn(), randn(), rand_scalar()) 115 | assert distributivity_of_scalar_mul_wrt_scalar_add(randn(), *rand_scalars()) 116 | assert add_preserves_vspace(*randns()) 117 | assert scalar_mul_preserves_vspace(randn(), rand_scalar()) 118 | assert symmetry(*randns()) 119 | assert linearity(randn(), randn(), rand_scalar()) 120 | assert positive_definitive(randn()) 121 | assert inner_zeros() 122 | assert basis_orthonormality() 123 | assert ones_sum_of_basis_vects() 124 | assert basis_correct_size() 125 | assert basis_correct_vspace() 126 | assert zeros_correct_vspace() 127 | assert ones_correct_vspace() 128 | assert randn_correct_vspace() 129 | 130 | # --- grads of basic operations --- 131 | check_grads(add)(*randns()) 132 | check_grads(scalar_mul)(randn(), rand_scalar()) 133 | check_grads(inner_prod)(*randns()) 134 | 135 | 136 | def test_array_vspace(): 137 | check_vspace(np.zeros((3, 2))) 138 | 139 | 140 | def test_array_vspace_0_dim(): 141 | check_vspace(0.0) 142 | 143 | 144 | def 
test_array_vspace_complex(): 145 | check_vspace(1.0j * np.zeros((2, 1))) 146 | 147 | 148 | def test_list_vspace(): 149 | check_vspace([1.0, np.zeros((2, 1))]) 150 | 151 | 152 | def test_tuple_vspace(): 153 | check_vspace((1.0, np.zeros((2, 1)))) 154 | 155 | 156 | def test_dict_vspace(): 157 | check_vspace({"a": 1.0, "b": np.zeros((2, 1))}) 158 | 159 | 160 | def test_mixed_vspace(): 161 | check_vspace({"x": [0.0, np.zeros((3, 1))], "y": ({"a": 0.0}, [0.0])}) 162 | --------------------------------------------------------------------------------
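The check_vspace routine above drives everything through the attributes it pulls off the vspace object (size, add, scalar_mul, inner_prod, randn, zeros, ones, standard_basis). A quick sketch of such an object for a nested container, mirroring test_dict_vspace (editor's illustration):

    import numpy as np
    from autograd.core import vspace

    vs = vspace({"a": 1.0, "b": np.zeros((2, 1))})
    print(vs.size)                         # 3: one scalar plus a 2x1 array
    v = vs.randn()                         # random element with the same dict structure
    print(vs.inner_prod(v, vs.zeros()))    # 0.0
    print(len(list(vs.standard_basis())))  # 3: one basis vector per scalar component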