├── src ├── tests │ ├── __init__.py │ ├── test_jacobian.py │ └── test_diff.py └── udiff │ ├── __init__.py │ ├── _core.py │ ├── _uarray_plug.py │ ├── _jvp_diffs.py │ ├── _diff_array.py │ └── _vjp_diffs.py ├── docs ├── logo.png ├── _static │ ├── backward.png │ ├── approaches.png │ ├── forward_mode.png │ ├── reverse_mode.png │ ├── expression_swell.png │ ├── computational_graph.png │ └── build_computational_graph.png ├── generated │ ├── udiff.defjvp.rst │ ├── udiff.defvjp.rst │ ├── udiff.def_linear.rst │ ├── udiff.DiffArray.id.rst │ ├── udiff.DiffArray.dtype.rst │ ├── udiff.DiffArray.value.rst │ ├── udiff.JVPDiffArray.id.rst │ ├── udiff.JVPDiffArray.to.rst │ ├── udiff.VJPDiffArray.id.rst │ ├── udiff.VJPDiffArray.to.rst │ ├── udiff.JVPDiffArray.dtype.rst │ ├── udiff.JVPDiffArray.value.rst │ ├── udiff.VJPDiffArray.dtype.rst │ ├── udiff.VJPDiffArray.value.rst │ ├── udiff.DiffArray.__init__.rst │ ├── udiff.JVPDiffArray.__init__.rst │ ├── udiff.VJPDiffArray.__init__.rst │ ├── udiff.DiffArrayBackend.__init__.rst │ ├── udiff.JVPDiffArray.register_diff.rst │ ├── udiff.VJPDiffArray.register_diff.rst │ ├── udiff.DiffArrayBackend.replace_arrays.rst │ ├── udiff.DiffArrayBackend.overridden_class.rst │ ├── udiff.DiffArrayBackend.self_implementations.rst │ ├── udiff.rst │ ├── udiff.DiffArray.rst │ ├── udiff.DiffArrayBackend.rst │ ├── udiff.JVPDiffArray.rst │ └── udiff.VJPDiffArray.rst ├── _templates │ └── autosummary │ │ ├── base.rst │ │ ├── module.rst │ │ └── class.rst ├── install.rst ├── index.rst ├── notebooks │ └── linear_regression.ipynb ├── conf.py └── auto_diff.rst ├── .coveragerc ├── doc8.ini ├── CODE_OF_CONDUCT.md ├── pytest.ini ├── .conda └── environment.yml ├── conftest.py ├── README.md ├── CONTRIBUTING.md ├── setup.py ├── LICENSE ├── .gitignore └── azure-pipelines.yml /src/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/logo.png -------------------------------------------------------------------------------- /docs/_static/backward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/_static/backward.png -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = 4 | src/udiff 5 | [report] 6 | omit = 7 | **/tests/ 8 | -------------------------------------------------------------------------------- /docs/_static/approaches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/_static/approaches.png -------------------------------------------------------------------------------- /docs/_static/forward_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/_static/forward_mode.png -------------------------------------------------------------------------------- /docs/_static/reverse_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/_static/reverse_mode.png 
-------------------------------------------------------------------------------- /docs/generated/udiff.defjvp.rst: -------------------------------------------------------------------------------- 1 | defjvp 2 | ====== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autofunction:: defjvp -------------------------------------------------------------------------------- /docs/generated/udiff.defvjp.rst: -------------------------------------------------------------------------------- 1 | defvjp 2 | ====== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autofunction:: defvjp -------------------------------------------------------------------------------- /doc8.ini: -------------------------------------------------------------------------------- 1 | [doc8] 2 | max-line-length=88 3 | ignore-path=docs/generated,docs/_build,docs\_templates,_build,*.egg-info,src/ -------------------------------------------------------------------------------- /docs/_static/expression_swell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/_static/expression_swell.png -------------------------------------------------------------------------------- /docs/_static/computational_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/_static/computational_graph.png -------------------------------------------------------------------------------- /docs/generated/udiff.def_linear.rst: -------------------------------------------------------------------------------- 1 | def\_linear 2 | =========== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autofunction:: def_linear -------------------------------------------------------------------------------- /docs/_static/build_computational_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Quansight-Labs/udiff/HEAD/docs/_static/build_computational_graph.png -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArray.id.rst: -------------------------------------------------------------------------------- 1 | DiffArray.id 2 | ============ 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: DiffArray.id -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArray.dtype.rst: -------------------------------------------------------------------------------- 1 | DiffArray.dtype 2 | =============== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: DiffArray.dtype -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArray.value.rst: -------------------------------------------------------------------------------- 1 | DiffArray.value 2 | =============== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: DiffArray.value -------------------------------------------------------------------------------- /docs/generated/udiff.JVPDiffArray.id.rst: -------------------------------------------------------------------------------- 1 | JVPDiffArray.id 2 | =============== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. 
autoproperty:: JVPDiffArray.id -------------------------------------------------------------------------------- /docs/generated/udiff.JVPDiffArray.to.rst: -------------------------------------------------------------------------------- 1 | JVPDiffArray.to 2 | =============== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: JVPDiffArray.to -------------------------------------------------------------------------------- /docs/generated/udiff.VJPDiffArray.id.rst: -------------------------------------------------------------------------------- 1 | VJPDiffArray.id 2 | =============== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: VJPDiffArray.id -------------------------------------------------------------------------------- /docs/generated/udiff.VJPDiffArray.to.rst: -------------------------------------------------------------------------------- 1 | VJPDiffArray.to 2 | =============== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: VJPDiffArray.to -------------------------------------------------------------------------------- /docs/_templates/autosummary/base.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. auto{{ objtype }}:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/generated/udiff.JVPDiffArray.dtype.rst: -------------------------------------------------------------------------------- 1 | JVPDiffArray.dtype 2 | ================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: JVPDiffArray.dtype -------------------------------------------------------------------------------- /docs/generated/udiff.JVPDiffArray.value.rst: -------------------------------------------------------------------------------- 1 | JVPDiffArray.value 2 | ================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: JVPDiffArray.value -------------------------------------------------------------------------------- /docs/generated/udiff.VJPDiffArray.dtype.rst: -------------------------------------------------------------------------------- 1 | VJPDiffArray.dtype 2 | ================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: VJPDiffArray.dtype -------------------------------------------------------------------------------- /docs/generated/udiff.VJPDiffArray.value.rst: -------------------------------------------------------------------------------- 1 | VJPDiffArray.value 2 | ================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: VJPDiffArray.value -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArray.__init__.rst: -------------------------------------------------------------------------------- 1 | DiffArray.\_\_init\_\_ 2 | ====================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: DiffArray.__init__ -------------------------------------------------------------------------------- /docs/generated/udiff.JVPDiffArray.__init__.rst: -------------------------------------------------------------------------------- 1 | JVPDiffArray.\_\_init\_\_ 2 | ========================= 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. 
automethod:: JVPDiffArray.__init__ -------------------------------------------------------------------------------- /docs/generated/udiff.VJPDiffArray.__init__.rst: -------------------------------------------------------------------------------- 1 | VJPDiffArray.\_\_init\_\_ 2 | ========================= 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: VJPDiffArray.__init__ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | This repository is governed by the Quansight Repository Code of Conduct. It 2 | can be found here: 3 | https://github.com/Quansight/.github/blob/master/CODE_OF_CONDUCT.md. 4 | -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArrayBackend.__init__.rst: -------------------------------------------------------------------------------- 1 | DiffArrayBackend.\_\_init\_\_ 2 | ============================= 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: DiffArrayBackend.__init__ -------------------------------------------------------------------------------- /docs/generated/udiff.JVPDiffArray.register_diff.rst: -------------------------------------------------------------------------------- 1 | JVPDiffArray.register\_diff 2 | =========================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: JVPDiffArray.register_diff -------------------------------------------------------------------------------- /docs/generated/udiff.VJPDiffArray.register_diff.rst: -------------------------------------------------------------------------------- 1 | VJPDiffArray.register\_diff 2 | =========================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: VJPDiffArray.register_diff -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArrayBackend.replace_arrays.rst: -------------------------------------------------------------------------------- 1 | DiffArrayBackend.replace\_arrays 2 | ================================ 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: DiffArrayBackend.replace_arrays -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArrayBackend.overridden_class.rst: -------------------------------------------------------------------------------- 1 | DiffArrayBackend.overridden\_class 2 | ================================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. automethod:: DiffArrayBackend.overridden_class -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArrayBackend.self_implementations.rst: -------------------------------------------------------------------------------- 1 | DiffArrayBackend.self\_implementations 2 | ====================================== 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoproperty:: DiffArrayBackend.self_implementations -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | junit_family = xunit2 3 | addopts = --black --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-report=term --cov --cov-report html --cov . 
--cov-config .coveragerc 4 | doctest_optionflags= IGNORE_EXCEPTION_DETAIL 5 | -------------------------------------------------------------------------------- /.conda/environment.yml: -------------------------------------------------------------------------------- 1 | name: uarray 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - python=3.7 8 | - pip 9 | - sphinx 10 | - sphinx_rtd_theme 11 | - pytest 12 | - pytest-cov 13 | - mypy 14 | - pytorch-cpu 15 | - scipy 16 | - dask 17 | - sparse 18 | - doc8 19 | - black 20 | - matplotlib 21 | - nbsphinx 22 | - pip: 23 | - pytest-mypy 24 | - pytest-black 25 | - nbval 26 | -------------------------------------------------------------------------------- /docs/generated/udiff.rst: -------------------------------------------------------------------------------- 1 | .. automodule:: udiff 2 | 3 | Functions 4 | =========== 5 | 6 | .. rubric:: Functions 7 | 8 | .. autosummary:: 9 | :toctree: 10 | 11 | 12 | defvjp 13 | defjvp 14 | def_linear 15 | 16 | Classes 17 | ======== 18 | .. rubric:: Classes 19 | 20 | .. autosummary:: 21 | :toctree: 22 | 23 | DiffArrayBackend 24 | DiffArray 25 | VJPDiffArray 26 | JVPDiffArray -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import udiff 3 | import pytest # type: ignore 4 | 5 | 6 | def pytest_cmdline_preparse(args): 7 | try: 8 | import pytest_black # type: ignore 9 | except ImportError: 10 | pass 11 | else: 12 | args.append("--black") 13 | print("uarray: Enabling pytest-black") 14 | 15 | 16 | @pytest.fixture(autouse=True) 17 | def add_namespaces(doctest_namespace): 18 | doctest_namespace["udiff"] = udiff 19 | -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArray.rst: -------------------------------------------------------------------------------- 1 | DiffArray 2 | ========= 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoclass:: DiffArray 7 | 8 | 9 | 10 | .. rubric:: Attributes 11 | .. autosummary:: 12 | :toctree: 13 | 14 | DiffArray.dtype 15 | 16 | DiffArray.id 17 | 18 | DiffArray.value 19 | 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Methods 26 | .. autosummary:: 27 | :toctree: 28 | 29 | DiffArray.__init__ 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/udiff/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import uarray as ua 3 | 4 | from . import _vjp_diffs, _jvp_diffs 5 | from ._uarray_plug import DiffArrayBackend 6 | from ._core import defvjp, defvjp_argnum, defjvp, defjvp_argnum, def_linear 7 | 8 | from ._diff_array import DiffArray, JVPDiffArray, VJPDiffArray 9 | 10 | __all__ = [ 11 | "DiffArrayBackend", 12 | "DiffArray", 13 | "JVPDiffArray", 14 | "defvjp", 15 | "defvjp_argnum", 16 | "VJPDiffArray", 17 | "defjvp", 18 | "defjvp_argnum", 19 | "def_linear", 20 | ] 21 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. rubric:: Description 4 | .. automodule:: {{ fullname }} 5 | .. currentmodule:: {{ fullname }} 6 | 7 | {% if classes %} 8 | .. rubric:: Classes 9 | 10 | .. 
autosummary:: 11 | :toctree: 12 | 13 | {% for class in classes %} 14 | {{ class }} 15 | {% endfor %} 16 | 17 | {% endif %} 18 | 19 | {% if functions %} 20 | .. rubric:: Functions 21 | 22 | .. autosummary:: 23 | :toctree: 24 | 25 | {% for function in functions %} 26 | {{ function }} 27 | {% endfor %} 28 | 29 | {% endif %} 30 | -------------------------------------------------------------------------------- /docs/generated/udiff.DiffArrayBackend.rst: -------------------------------------------------------------------------------- 1 | DiffArrayBackend 2 | ================ 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoclass:: DiffArrayBackend 7 | 8 | 9 | 10 | .. rubric:: Attributes 11 | .. autosummary:: 12 | :toctree: 13 | 14 | DiffArrayBackend.self_implementations 15 | 16 | 17 | 18 | 19 | 20 | 21 | .. rubric:: Methods 22 | .. autosummary:: 23 | :toctree: 24 | 25 | DiffArrayBackend.__init__ 26 | 27 | DiffArrayBackend.overridden_class 28 | 29 | DiffArrayBackend.replace_arrays 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs/generated/udiff.JVPDiffArray.rst: -------------------------------------------------------------------------------- 1 | JVPDiffArray 2 | ============ 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoclass:: JVPDiffArray 7 | 8 | 9 | 10 | .. rubric:: Attributes 11 | .. autosummary:: 12 | :toctree: 13 | 14 | JVPDiffArray.dtype 15 | 16 | JVPDiffArray.id 17 | 18 | JVPDiffArray.value 19 | 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Methods 26 | .. autosummary:: 27 | :toctree: 28 | 29 | JVPDiffArray.__init__ 30 | 31 | JVPDiffArray.register_diff 32 | 33 | JVPDiffArray.to 34 | 35 | 36 | -------------------------------------------------------------------------------- /docs/generated/udiff.VJPDiffArray.rst: -------------------------------------------------------------------------------- 1 | VJPDiffArray 2 | ============ 3 | 4 | .. currentmodule:: udiff 5 | 6 | .. autoclass:: VJPDiffArray 7 | 8 | 9 | 10 | .. rubric:: Attributes 11 | .. autosummary:: 12 | :toctree: 13 | 14 | VJPDiffArray.dtype 15 | 16 | VJPDiffArray.id 17 | 18 | VJPDiffArray.value 19 | 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Methods 26 | .. autosummary:: 27 | :toctree: 28 | 29 | VJPDiffArray.__init__ 30 | 31 | VJPDiffArray.register_diff 32 | 33 | VJPDiffArray.to 34 | 35 | 36 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block attributes %} 8 | {% if attributes %} 9 | .. rubric:: Attributes 10 | .. autosummary:: 11 | :toctree: 12 | {% for item in attributes %} 13 | {{ name }}.{{ item }} 14 | {% endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block methods %} 19 | {% if methods %} 20 | .. rubric:: Methods 21 | .. autosummary:: 22 | :toctree: 23 | {% for item in methods %} 24 | {{ name }}.{{ item }} 25 | {% endfor %} 26 | {% endif %} 27 | {% endblock %} 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `udiff` - Automatic differentiation with uarray/unumpy. 
2 | [![Join the chat at https://gitter.im/Plures/uarray](https://badges.gitter.im/Plures/uarray.svg)](https://gitter.im/Plures/uarray?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ![language](https://img.shields.io/badge/language-python3-orange.svg) ![license](https://img.shields.io/github/license/Quansight-Labs/udiff) 3 | 4 | ## Quickstart 5 | ```python 6 | import uarray as ua 7 | import unumpy as np 8 | import udiff 9 | from unumpy import numpy_backend 10 | 11 | with ua.set_backend(udiff.DiffArrayBackend(numpy_backend), coerce=True): 12 | x1 = np.reshape(np.arange(1, 26), (5, 5)) 13 | x2 = np.reshape(np.arange(1, 26), (5, 5)) 14 | y = np.log(x1) + x1 * x2 - np.sin(x2) 15 | print(y) 16 | print(y.to(x1)) 17 | print(y.to(x2)) 18 | ``` 19 | 20 | ## Contributing 21 | 22 | See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information on how to contribute to `udiff`. 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions to `udiff` are welcome and appreciated. Contributions can take the form of bug reports, documentation, code, and more. 4 | 5 | ## Getting the code 6 | 7 | Make a fork of the main [udiff repository](https://github.com/Quansight-Labs/udiff) and clone the fork: 8 | 9 | ``` 10 | git clone https://github.com//udiff 11 | ``` 12 | 13 | ## Install 14 | 15 | Note that udiff supports Python versions >= 3.5. If you're running `conda` and would prefer to have dependencies 16 | pulled from there, use 17 | 18 | ``` 19 | conda env create -f .conda/environment.yml 20 | conda activate uarray 21 | ``` 22 | 23 | This will create an environment named `uarray` which you can use for development. 24 | 25 | `unumpy` and all development dependencies can be installed via: 26 | 27 | ``` 28 | pip install -e . 
29 | ``` 30 | 31 | ## Testing 32 | 33 | Tests can be run from the main udiff directory as follows: 34 | 35 | ``` 36 | pytest 37 | ``` 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="udiff", 5 | version="0.6.0", 6 | author="Hameer Abbasi", 7 | author_email="hameerabbasi@yahoo.com", 8 | description="Automatic differentiation with uarray/unumpy.", 9 | platforms="Posix; MacOS X; Windows", 10 | packages=setuptools.find_packages(where="src", exclude=["tests*"]), 11 | package_dir={"": "src"}, 12 | include_package_data=True, 13 | install_requires=( 14 | "uarray @ git+https://github.com/Quansight-Labs/uarray@master#egg=uarray", 15 | "unumpy @ git+https://github.com/Quansight-Labs/unumpy@master#egg=unumpy", 16 | ), 17 | classifiers=[ 18 | "Development Status :: 3 - Alpha", 19 | "Natural Language :: English", 20 | "Intended Audience :: Science/Research", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.5", 24 | "Programming Language :: Python :: 3.6", 25 | "Programming Language :: Python :: 3.7", 26 | "Programming Language :: Python :: 3.8", 27 | ], 28 | project_urls={ 29 | "Source": "https://github.com/Quansight-Labs/udiff", 30 | }, 31 | zip_safe=False, 32 | ) 33 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | Prerequisites 2 | -------------- 3 | 4 | There are some prerequisite packages you should install before installing udiff. 5 | 6 | :Python3: udiff requires Python 3.5 or higher. 7 | :`uarray <https://uarray.org/en/latest/>`_: uarray is a backend system for Python that allows you to separately define an API, 8 | along with backends that contain separate implementations of that API. 9 | 10 | .. code:: bash 11 | 12 | pip install git+https://github.com/Quansight-Labs/uarray.git 13 | 14 | :`unumpy <https://unumpy.uarray.org/en/latest/>`_: unumpy builds on top of uarray. 15 | It is an effort to specify the core NumPy API, and provide backends for the API. 16 | 17 | .. code:: bash 18 | 19 | pip install git+https://github.com/Quansight-Labs/unumpy.git 20 | 21 | 22 | Installation 23 | ------------- 24 | 25 | .. note:: 26 | :obj:`udiff` has not been published on PyPI. For now, you have to install it from source. 27 | 28 | #. Use Git to clone the :obj:`udiff` repository: 29 | 30 | .. code:: bash 31 | 32 | git clone https://github.com/Quansight-Labs/udiff.git 33 | cd udiff 34 | 35 | #. To install :obj:`udiff`, enter on the command line: 36 | 37 | .. code:: bash 38 | 39 | pip install -e . --no-deps --user 40 | 41 | If you want to install it system-wide for all users (assuming you have the necessary rights), 42 | just drop the ``--user`` flag. 43 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed under the BSD-3-Clause: https://opensource.org/licenses/BSD-3-Clause 2 | 3 | ``udiff`` 4 | =========== 5 | 6 | .. warning:: 7 | 8 | Work in progress. This is a pre-release build. 9 | 10 | :obj:`udiff` is a tool for automatic differentiation with uarray/unumpy. 11 | 12 | What's new in ``udiff``? 13 | ------------------------- 14 | 15 | :obj:`udiff` is committed to providing a universal differentiation mechanism built on a 16 | generic backend system.
It's possible to get the differential of various 17 | calculations of different data types(scalar, tensor and matrix). In addition, it's 18 | possible to change the used backend via a context manager. 19 | 20 | .. toctree:: 21 | :maxdepth: 1 22 | :caption: Installation 23 | 24 | install 25 | 26 | .. toctree:: 27 | :maxdepth: 1 28 | :caption: Tutorials 29 | 30 | notebooks/quickstart 31 | notebooks/linear_regression 32 | 33 | .. toctree:: 34 | :maxdepth: 1 35 | :caption: API documentation 36 | 37 | generated/udiff 38 | 39 | .. toctree:: 40 | :maxdepth: 1 41 | :caption: Developer documentation 42 | 43 | auto_diff 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`modindex` 50 | * :ref:`search` 51 | 52 | Bug reports are gladly accepted at the `GitHub issue tracker`_. 53 | GitHub also hosts the `code repository`_. 54 | 55 | .. _GitHub issue tracker: https://github.com/Quansight-Labs/udiff/issues 56 | .. _code repository: https://github.com/Quansight-Labs/udiff 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Quansight Labs 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | junit/ 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | default.profraw 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | _build/ 71 | make.bat 72 | Makefile 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # IDEs 112 | .vscode/ 113 | .idea/ 114 | 115 | sandbox.py 116 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | jobs: 2 | - job: Tests 3 | pool: 4 | vmImage: 'ubuntu-16.04' 5 | 6 | steps: 7 | - script: | 8 | echo "##vso[task.prependpath]$CONDA/bin" 9 | conda env create -f .conda/environment.yml 10 | displayName: Prepare conda 11 | - script: | 12 | source activate uarray 13 | pip install git+https://github.com/Quansight-Labs/uarray.git 14 | pip install git+https://github.com/Quansight-Labs/unumpy.git 15 | pip install -e . --no-deps 16 | displayName: Install package 17 | - script: | 18 | source activate uarray 19 | pytest 20 | displayName: Run tests 21 | - task: PublishCodeCoverageResults@1 22 | inputs: 23 | codeCoverageTool: Cobertura 24 | summaryFileLocation: "$(System.DefaultWorkingDirectory)/**/coverage.xml" 25 | 26 | # - job: TestsMinimalEnv 27 | # pool: 28 | # vmImage: 'ubuntu-16.04' 29 | 30 | # steps: 31 | # - script: | 32 | # echo "##vso[task.prependpath]$CONDA/bin" 33 | # conda env create -f .conda/environment_minimal.yml 34 | # displayName: Prepare conda 35 | # - script: | 36 | # source activate uarray_min 37 | # pip install git+https://github.com/Quansight-Labs/uarray.git 38 | # pip install git+https://github.com/Quansight-Labs/unumpy.git 39 | # pip install -e . 
--no-deps 40 | # displayName: Install package 41 | # - script: | 42 | # source activate uarray_min 43 | # python setup.py build 44 | # python setup.py install 45 | # pytest 46 | # displayName: Run tests 47 | # - task: PublishCodeCoverageResults@1 48 | # inputs: 49 | # codeCoverageTool: Cobertura 50 | # summaryFileLocation: "$(System.DefaultWorkingDirectory)/**/coverage.xml" 51 | 52 | - job: Docs 53 | pool: 54 | vmImage: 'ubuntu-16.04' 55 | steps: 56 | - script: | 57 | echo "##vso[task.prependpath]$CONDA/bin" 58 | conda env create -f .conda/environment.yml 59 | displayName: Prepare conda 60 | - script: | 61 | source activate uarray 62 | pip install git+https://github.com/Quansight-Labs/uarray.git 63 | pip install git+https://github.com/Quansight-Labs/unumpy.git 64 | pip install -e . --no-deps 65 | displayName: Install package 66 | - script: | 67 | source activate uarray 68 | sphinx-build -W -b html docs/ _build/html 69 | displayName: Build docs 70 | - script: | 71 | source activate uarray 72 | doc8 73 | displayName: Lint docs 74 | - task: PublishPipelineArtifact@0 75 | inputs: 76 | artifactName: 'Documentation' 77 | targetPath: '$(System.DefaultWorkingDirectory)/_build/html' 78 | 79 | trigger: 80 | branches: 81 | include: 82 | - master 83 | 84 | pr: 85 | - master -------------------------------------------------------------------------------- /docs/notebooks/linear_regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Linear Regression\n", 8 | "\n", 9 | "\n", 10 | "Linear regression is the simplest model in machine learning, while it is an important part of many complex models, such as neural network.\n", 11 | "\n", 12 | "In this section, we will implement a simple linear regression model and optimize its parameters by gradient descent." 
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import uarray as ua\n", 22 | "import unumpy as np\n", 23 | "import numpy as onp\n", 24 | "import udiff\n", 25 | "from unumpy import numpy_backend" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "def synthetic_data(w, b, num_examples):\n", 35 | " \"\"\"Generate y = Xw + b + noise.\"\"\"\n", 36 | " X = onp.random.normal(0, 1, (num_examples, len(w)))\n", 37 | " y = onp.dot(X, w) + b\n", 38 | " y += onp.random.normal(0, 0.01, y.shape)\n", 39 | " return np.asarray(X), y" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def gradient_descent(params, loss, lr=0.1):\n", 49 | " \"\"\"Gradient Descent.\"\"\"\n", 50 | " for param in params:\n", 51 | " param._value -= lr * loss.to(param).value" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": { 58 | "tags": [] 59 | }, 60 | "outputs": [ 61 | { 62 | "output_type": "stream", 63 | "name": "stdout", 64 | "text": "True\nTrue\n" 65 | } 66 | ], 67 | "source": [ 68 | "with ua.set_backend(udiff.DiffArrayBackend(numpy_backend), coerce=True):\n", 69 | " # hyper-parameters\n", 70 | " lr = 0.1\n", 71 | " epoch = 100\n", 72 | " num_examples = 1000\n", 73 | "\n", 74 | " # generate dataset\n", 75 | " true_w = onp.array([[2], [-3.4]])\n", 76 | " true_b = 4.2\n", 77 | " features, labels = synthetic_data(true_w, true_b, num_examples)\n", 78 | "\n", 79 | " # trainable parameters\n", 80 | " W = np.asarray(onp.random.normal(scale=0.01, size=(2, 1)))\n", 81 | " b = np.zeros(1)\n", 82 | " params = [W, b]\n", 83 | "\n", 84 | " # define model and loss function\n", 85 | " net = lambda X: np.matmul(X, W) + b\n", 86 | " # mean squared error\n", 87 | " loss = lambda y_hat, y: np.sum((y_hat - y) ** 2) / num_examples\n", 88 | "\n", 89 | " # train\n", 90 | " for e in range(epoch):\n", 91 | " y_hat = net(features)\n", 92 | " l = loss(y_hat, labels)\n", 93 | " gradient_descent(params, l, lr=lr)\n", 94 | "\n", 95 | " print(onp.allclose(W.value, true_w, 0.1))\n", 96 | " print(onp.allclose(b.value, true_b, 0.1))" 97 | ] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "Python 3", 103 | "language": "python", 104 | "name": "python3" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.7.7-final" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 4 121 | } -------------------------------------------------------------------------------- /src/udiff/_core.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import count 3 | from functools import reduce 4 | import unumpy as np 5 | import uarray as ua 6 | from unumpy import numpy_backend 7 | 8 | # -------------------- reverse mode -------------------- 9 | 10 | primitive_vjps = {} 11 | 12 | 13 | def defvjp_argnums(fun, vjpmaker): 14 | primitive_vjps[fun] = vjpmaker 15 | 16 | 17 | def defvjp_argnum(fun, vjpmaker): 18 | def vjp_argnums(argnums, *args): 19 | vjps = [vjpmaker(argnum, *args) for argnum in argnums] 20 | return 
lambda g: (vjp(g) for vjp in vjps) 21 | 22 | defvjp_argnums(fun, vjp_argnums) 23 | 24 | 25 | def defvjp(fun, *vjpmakers, **kwargs): 26 | """ 27 | Set up a unumpy-transformable function for a VJP rule definition. 28 | 29 | 30 | Parameters 31 | ---------- 32 | fun : np.ufunc 33 | The function need to be derived. 34 | *jvpfuns : 35 | Functions for calculating derivative. 36 | 37 | Examples 38 | -------- 39 | >>> defvjp(np.positive, lambda ans, x: lambda g: g) 40 | 41 | """ 42 | argnums = kwargs.get("argnums", count()) 43 | vjps_dict = { 44 | argnum: translate_vjp(vjpmaker, fun, argnum) 45 | for argnum, vjpmaker in zip(argnums, vjpmakers) 46 | } 47 | 48 | def vjp_argnums(argnums, ans, args, kwargs): 49 | try: 50 | vjps = [vjps_dict[argnum](ans, *args, **kwargs) for argnum in argnums] 51 | except KeyError: 52 | raise NotImplementedError("VJP of {} not defined".format(fun.name)) 53 | 54 | def ret(g): 55 | return tuple(vjp(g) for vjp in vjps) 56 | 57 | return ret 58 | 59 | defvjp_argnums(fun, vjp_argnums) 60 | 61 | 62 | def translate_vjp(vjpfun, fun, argnum): 63 | if vjpfun is None: 64 | return lambda ans, *args, **kwargs: lambda g: np.zeros_like(args[argnum]) 65 | elif callable(vjpfun): 66 | return vjpfun 67 | else: 68 | raise Exception("Bad VJP '{}' for '{}'".format(vjpfun, fun.__name__)) 69 | 70 | 71 | # -------------------- forward mode -------------------- 72 | 73 | primitive_jvps = {} 74 | 75 | 76 | def subval(x, i, v): 77 | x_ = list(x) 78 | x_[i] = v 79 | return tuple(x_) 80 | 81 | 82 | def defjvp_argnums(fun, jvpmaker): 83 | primitive_jvps[fun] = jvpmaker 84 | 85 | 86 | def defjvp_argnum(fun, jvpmaker): 87 | def jvp_argnums(argnums, ans, args, kwargs): 88 | return (jvpmaker(argnum, ans, args, kwargs) for argnum in argnums) 89 | 90 | defjvp_argnums(fun, jvp_argnums) 91 | 92 | 93 | def defjvp(fun, *jvpfuns, **kwargs): 94 | """ 95 | Set up a unumpy-transformable function for a JVP rule definition. 96 | 97 | Parameters 98 | ---------- 99 | fun : np.ufunc 100 | The function need to be derived. 101 | *jvpfuns : 102 | Functions for calculating derivative. 103 | 104 | Examples 105 | -------- 106 | >>> defjvp( 107 | ... np.arctan2, 108 | ... lambda ans, x, y: lambda g: g * y / (x ** 2 + y ** 2), 109 | ... lambda ans, x, y: lambda g: g * -x / (x ** 2 + y ** 2), 110 | ... ) 111 | 112 | """ 113 | argnums = kwargs.get("argnums", count()) 114 | jvps_dict = { 115 | argnum: translate_jvp(jvpfun, fun, argnum) 116 | for argnum, jvpfun in zip(argnums, jvpfuns) 117 | } 118 | 119 | def jvp_argnums(argnums, ans, args, kwargs): 120 | return [jvps_dict[argnum](ans, *args, **kwargs) for argnum in argnums] 121 | 122 | defjvp_argnums(fun, jvp_argnums) 123 | 124 | 125 | def translate_jvp(jvpfun, fun, argnum): 126 | if jvpfun is None: 127 | return lambda ans, *a, **k: lambda g: np.zeros_like(ans) 128 | elif jvpfun == "same": 129 | return lambda ans, *args, **kwargs: lambda g: fun( 130 | *subval(args, argnum, g), **kwargs 131 | ) 132 | elif callable(jvpfun): 133 | return jvpfun 134 | else: 135 | raise Exception("Bad JVP '{}' for '{}'".format(jvpfun, fun.__name__)) 136 | 137 | 138 | def def_linear(fun): 139 | """ 140 | Flags that a function is linear wrt all args. 141 | 142 | Parameters 143 | ---------- 144 | fun : np.ufunc 145 | The function need to be derived. 
146 | 147 | Examples 148 | -------- 149 | >>> def_linear(np.matmul) 150 | 151 | """ 152 | defjvp_argnum( 153 | fun, 154 | lambda argnum, ans, args, kwargs: lambda g: fun( 155 | *subval(args, argnum, g), **kwargs 156 | ), 157 | ) 158 | -------------------------------------------------------------------------------- /src/udiff/_uarray_plug.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from uarray import wrap_single_convertor_instance 4 | from unumpy import ufunc, ndarray, numpy_backend 5 | import unumpy 6 | 7 | import unumpy as np 8 | import uarray as ua 9 | 10 | from ._diff_array import DiffArray, VJPDiffArray, JVPDiffArray 11 | from ._vjp_diffs import nograd_functions, raw_functions 12 | 13 | from typing import Dict 14 | 15 | _ufunc_mapping: Dict[ufunc, np.ufunc] = {} 16 | 17 | 18 | class DiffArrayBackend: 19 | """ 20 | The backend used for udiff. 21 | 22 | Attributes 23 | ---------- 24 | _inner 25 | The backend used, such as numpy_backend. 26 | 27 | _mode: default vjp. 28 | The mode used to calculate gradient. It must be vjp or jvp. 29 | 30 | Examples 31 | -------- 32 | >>> with ua.set_backend(udiff.DiffArrayBackend(numpy_backend), coerce=True): 33 | ... x = np.array([2]) 34 | """ 35 | 36 | __ua_domain__ = "numpy" 37 | 38 | _implementations: Dict = { 39 | unumpy.asarray: DiffArray, 40 | } 41 | 42 | @property 43 | @functools.lru_cache(None) 44 | def self_implementations(self): 45 | """ 46 | Specify the data type to be converted. 47 | """ 48 | return {unumpy.ClassOverrideMeta.overridden_class.fget: self.overridden_class} 49 | 50 | def __init__(self, inner, mode="vjp"): 51 | mode = mode.lower() 52 | if mode not in ["vjp", "jvp"]: 53 | raise ValueError("mode must be vjp or jvp") 54 | self._inner = inner 55 | self._mode = mode 56 | 57 | def overridden_class(self, self2): 58 | """ 59 | Convert ndarray to VJPDiffArray or JVPDiffArray according to mode. 60 | """ 61 | if self is ndarray: 62 | if self._mode == "vjp": 63 | return VJPDiffArray 64 | else: 65 | return JVPDiffArray 66 | 67 | with ua.set_backend(self._inner, only=True): 68 | return self2.overridden_class 69 | 70 | def __ua_function__(self, func, args, kwargs): 71 | extracted_args = func.arg_extractor(*args, **kwargs) 72 | arr_args = tuple(x.value for x in extracted_args if x.type is np.ndarray) 73 | 74 | with ua.set_backend(self._inner, only=True): 75 | if len(arr_args) == 0: 76 | out = func(*args, **kwargs) 77 | else: 78 | a, kw = self.replace_arrays( 79 | func, 80 | args, 81 | kwargs, 82 | ( 83 | x.value if x is not None and isinstance(x, DiffArray) else x 84 | for x in arr_args 85 | ), 86 | ) 87 | out = func(*a, **kw) 88 | 89 | real_func = func 90 | if func is np.ufunc.__call__: 91 | real_func = args[0] 92 | 93 | if real_func not in raw_functions: 94 | with ua.set_backend(self._inner, coerce=True): 95 | if self._mode == "vjp": 96 | out = VJPDiffArray(out) 97 | else: 98 | out = JVPDiffArray(out) 99 | 100 | if real_func not in nograd_functions: 101 | out.register_diff(func, args, kwargs) 102 | 103 | return out 104 | 105 | def replace_arrays(self, func, a, kw, arrays): 106 | """ 107 | Convert the parameters in func to primitive types. 
108 | """ 109 | d = tuple(func.arg_extractor(*a, **kw)) 110 | arrays = tuple(arrays) 111 | new_d = [] 112 | j = 0 113 | for i in d: 114 | if i.type is np.ndarray: 115 | new_d.append(arrays[j]) 116 | j += 1 117 | else: 118 | new_d.append(i.value) 119 | 120 | return func.arg_replacer(a, kw, tuple(new_d)) 121 | 122 | @wrap_single_convertor_instance 123 | def __ua_convert__(self, value, dispatch_type, coerce): 124 | if dispatch_type is np.ndarray: 125 | if value is None: 126 | return value 127 | 128 | if isinstance(value, DiffArray): 129 | return value 130 | 131 | if coerce: 132 | with ua.set_backend(self._inner, coerce=True): 133 | if self._mode == "vjp": 134 | return VJPDiffArray(np.asarray(value)) 135 | else: 136 | return JVPDiffArray(np.asarray(value)) 137 | 138 | return NotImplemented 139 | 140 | return value 141 | 142 | __hash__ = object.__hash__ 143 | __eq__ = object.__eq__ 144 | -------------------------------------------------------------------------------- /src/tests/test_jacobian.py: -------------------------------------------------------------------------------- 1 | import uarray as ua 2 | import unumpy as np 3 | import numpy as onp 4 | import torch 5 | import dask.array as da 6 | import udiff 7 | import sparse 8 | from math import * 9 | from random import uniform, randrange 10 | import unumpy.numpy_backend as NumpyBackend 11 | 12 | import unumpy.torch_backend as TorchBackend 13 | import unumpy.dask_backend as DaskBackend 14 | import unumpy.sparse_backend as SparseBackend 15 | 16 | import numpy as onp 17 | from numpy.testing import assert_allclose 18 | 19 | import pytest 20 | 21 | ua.set_global_backend(NumpyBackend) 22 | 23 | LIST_BACKENDS = [ 24 | NumpyBackend, 25 | # DaskBackend, 26 | # SparseBackend, 27 | pytest.param( 28 | TorchBackend, 29 | marks=pytest.mark.xfail(reason="PyTorch not fully NumPy compatible."), 30 | ), 31 | ] 32 | 33 | 34 | FULLY_TESTED_BACKENDS = [NumpyBackend, DaskBackend] 35 | 36 | try: 37 | import unumpy.cupy_backend as CupyBackend 38 | import cupy as cp 39 | 40 | LIST_BACKENDS.append(pytest.param(CupyBackend)) 41 | except ImportError: 42 | LIST_BACKENDS.append( 43 | pytest.param( 44 | (None, None), marks=pytest.mark.skip(reason="cupy is not importable") 45 | ) 46 | ) 47 | 48 | 49 | EXCEPTIONS = { 50 | (DaskBackend, np.in1d), 51 | (DaskBackend, np.intersect1d), 52 | (DaskBackend, np.setdiff1d), 53 | (DaskBackend, np.setxor1d), 54 | (DaskBackend, np.union1d), 55 | (DaskBackend, np.sort), 56 | (DaskBackend, np.argsort), 57 | (DaskBackend, np.lexsort), 58 | (DaskBackend, np.partition), 59 | (DaskBackend, np.argpartition), 60 | (DaskBackend, np.sort_complex), 61 | (DaskBackend, np.msort), 62 | (DaskBackend, np.searchsorted), 63 | } 64 | 65 | 66 | @pytest.fixture(scope="session", params=LIST_BACKENDS) 67 | def backend(request): 68 | backend = request.param 69 | return backend 70 | 71 | 72 | @pytest.fixture(scope="session", params=["vjp", "jvp"]) 73 | def mode(request): 74 | mode = request.param 75 | return mode 76 | 77 | 78 | @pytest.mark.parametrize( 79 | "x, func, expect_jacobian", 80 | [ 81 | ( 82 | onp.arange(12).reshape(2, 3, 2), 83 | lambda x: np.sum(x, axis=1), 84 | [ 85 | [ 86 | [[[1, 0], [1, 0], [1, 0]], [[0, 0], [0, 0], [0, 0]]], 87 | [[[0, 1], [0, 1], [0, 1]], [[0, 0], [0, 0], [0, 0]]], 88 | ], 89 | [ 90 | [[[0, 0], [0, 0], [0, 0]], [[1, 0], [1, 0], [1, 0]]], 91 | [[[0, 0], [0, 0], [0, 0]], [[0, 1], [0, 1], [0, 1]]], 92 | ], 93 | ], 94 | ), 95 | ( 96 | onp.arange(4).reshape((2, 2)), 97 | lambda x: x, 98 | [ 99 | [[[1, 0], [0, 0]], [[0, 1], [0, 0]]], 
100 | [[[0, 0], [1, 0]], [[0, 0], [0, 1]]], 101 | ], 102 | ), 103 | ], 104 | ) 105 | def test_jacobian(backend, mode, x, func, expect_jacobian): 106 | try: 107 | with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True): 108 | x = np.asarray(x) 109 | y = func(x) 110 | x_jacobian = y.to(x, jacobian=True) 111 | except ua.BackendNotImplementedError: 112 | if backend in FULLY_TESTED_BACKENDS: 113 | raise 114 | pytest.xfail(reason="The backend has no implementation for this ufunc.") 115 | 116 | if isinstance(y, da.Array): 117 | y.compute() 118 | 119 | assert_allclose(x_jacobian.value, expect_jacobian) 120 | 121 | 122 | @pytest.mark.parametrize( 123 | "u, v, func, expect_u_jacobian, expect_v_jacobian", 124 | [ 125 | ( 126 | onp.arange(2).reshape(1, 2, 1), 127 | onp.arange(2).reshape(1, 1, 2), 128 | lambda x, y: np.matmul(x, y), 129 | [[[[[[0], [0]]], [[[1], [0]]]], [[[[0], [0]]], [[[0], [1]]]]]], 130 | [[[[[[0, 0]]], [[[0, 0]]]], [[[[1, 0]]], [[[0, 1]]]]]], 131 | ), 132 | ], 133 | ) 134 | def test_separation_binary( 135 | backend, mode, u, v, func, expect_u_jacobian, expect_v_jacobian 136 | ): 137 | try: 138 | with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True): 139 | u = np.asarray(u) 140 | v = np.asarray(v) 141 | 142 | y = func(u, v) 143 | u_jacobian = y.to(u, jacobian=True) 144 | v_jacobian = y.to(v, jacobian=True) 145 | except ua.BackendNotImplementedError: 146 | if backend in FULLY_TESTED_BACKENDS: 147 | raise 148 | pytest.xfail(reason="The backend has no implementation for this ufunc.") 149 | 150 | if isinstance(y, da.Array): 151 | y.compute() 152 | 153 | assert_allclose(u_jacobian.value, expect_u_jacobian) 154 | assert_allclose(v_jacobian.value, expect_v_jacobian) 155 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | # sys.path.insert(0, os.path.abspath(".")) 17 | from typing import List, Dict 18 | 19 | sys.path.insert(0, os.path.abspath("../src")) 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "udiff" 24 | copyright = "2020, Quansight-Labs" 25 | author = "Quansight-Labs" 26 | 27 | # The full version, including alpha/beta/rc tags 28 | release = "0.6.0-alpha" 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named "sphinx.ext.*") or your custom 35 | # ones. 36 | extensions: List[str] = [ 37 | "nbsphinx", 38 | "sphinx.ext.autodoc", 39 | "sphinx.ext.viewcode", 40 | "sphinx.ext.napoleon", 41 | "sphinx.ext.intersphinx", 42 | "sphinx.ext.autosummary", 43 | "sphinx.ext.doctest", 44 | ] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 
47 | templates_path = ["_templates"] 48 | 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffixes as a list of string: 51 | # 52 | # source_suffix = [".rst", ".md"] 53 | source_suffix = ".rst" 54 | 55 | # The master toctree document. 56 | master_doc = "index" 57 | 58 | # The language for content autogenerated by Sphinx. Refer to documentation 59 | # for a list of supported languages. 60 | # 61 | # This is also used if you do content translation via gettext catalogs. 62 | # Usually you set "language" from the command line for these cases. 63 | language = None 64 | 65 | # List of patterns, relative to source directory, that match files and 66 | # directories to ignore when looking for source files. 67 | # This pattern also affects html_static_path and html_extra_path. 68 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 69 | 70 | # The name of the Pygments (syntax highlighting) style to use. 71 | pygments_style = None 72 | 73 | # -- Options for HTML output ------------------------------------------------- 74 | 75 | # The theme to use for HTML and HTML Help pages. See the documentation for 76 | # a list of builtin themes. 77 | # 78 | html_theme = "sphinx_rtd_theme" 79 | html_logo = "logo.png" 80 | html_favicon = "logo.png" 81 | 82 | # Add any paths that contain custom static files (such as style sheets) here, 83 | # relative to this directory. They are copied after the builtin static files, 84 | # so a file named "default.css" will overwrite the builtin "default.css". 85 | html_static_path: List[str] = ["_static"] 86 | 87 | # -- Options for HTMLHelp output --------------------------------------------- 88 | 89 | # Output file base name for HTML help builder. 90 | htmlhelp_basename = "udiffdoc" 91 | 92 | 93 | # -- Options for LaTeX output ------------------------------------------------ 94 | 95 | latex_elements: Dict[str, str] = { 96 | # The paper size ("letterpaper" or "a4paper"). 97 | # 98 | # "papersize": "letterpaper", 99 | # The font size ("10pt", "11pt" or "12pt"). 100 | # 101 | # "pointsize": "10pt", 102 | # Additional stuff for the LaTeX preamble. 103 | # 104 | # "preamble": "", 105 | # Latex figure (float) alignment 106 | # 107 | # "figure_align": "htbp", 108 | } 109 | 110 | # Grouping the document tree into LaTeX files. List of tuples 111 | # (source start file, target name, title, 112 | # author, documentclass [howto, manual, or own class]). 113 | latex_documents = [ 114 | (master_doc, "udiff.tex", "udiff Documentation", "Quansight-Labs", "manual") 115 | ] 116 | 117 | 118 | # -- Options for manual page output ------------------------------------------ 119 | 120 | # One entry per manual page. List of tuples 121 | # (source start file, name, description, authors, manual section). 122 | man_pages = [(master_doc, "udiff", "udiff Documentation", [author], 1)] 123 | 124 | 125 | # -- Options for Texinfo output ---------------------------------------------- 126 | 127 | # Grouping the document tree into Texinfo files. List of tuples 128 | # (source start file, target name, title, author, 129 | # dir menu entry, description, category) 130 | texinfo_documents = [ 131 | ( 132 | master_doc, 133 | "udiff", 134 | "udiff Documentation", 135 | author, 136 | "udiff", 137 | "udiff is a tool for automatic differentiation with uarray/unumpy.", 138 | "Miscellaneous", 139 | ) 140 | ] 141 | 142 | 143 | # -- Options for Epub output ------------------------------------------------- 144 | 145 | # Bibliographic Dublin Core info. 
146 | epub_title = project 147 | 148 | # The unique identifier of the text. This can be a ISBN number 149 | # or the project homepage. 150 | # 151 | # epub_identifier = "" 152 | 153 | # A unique identification for the text. 154 | # 155 | # epub_uid = "" 156 | 157 | # A list of files that should not be packed into the epub file. 158 | epub_exclude_files = ["search.html"] 159 | 160 | autosummary_generate = True 161 | autoclass_content = "both" 162 | 163 | intersphinx_mapping = { 164 | "python": ("https://docs.python.org/3/", None), 165 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 166 | "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), 167 | "uarray": ("https://uarray.org/en/latest/", None), 168 | "unumpy": ("https://unumpy.uarray.org/en/latest/", None), 169 | } 170 | 171 | doctest_global_setup = """ 172 | import uarray as ua 173 | import unumpy as np 174 | from unumpy import numpy_backend 175 | import udiff 176 | """ 177 | -------------------------------------------------------------------------------- /docs/auto_diff.rst: -------------------------------------------------------------------------------- 1 | Automatic Differentiation 2 | ================================ 3 | 4 | The optimization process of deep learning models is based on the gradient 5 | descent method. Deep learning frameworks such as PyTorch and TensorFlow can 6 | be divided into three parts: the model API, gradient calculation, and GPU 7 | acceleration. Gradient calculation plays an important role, and the core 8 | technology behind it is automatic differentiation. 9 | 10 | Differential Methods 11 | ------------------------------------ 12 | 13 | There are four differential methods: 14 | 15 | * Manual differentiation 16 | * Numerical differentiation 17 | * Symbolic differentiation 18 | * Automatic differentiation 19 | 20 | .. image:: _static/approaches.png 21 | 22 | Manual differentiation means applying the rules of calculus by hand and writing out the 23 | derivative formula yourself. This method is accurate and effective, and the only 24 | disadvantage is that it takes effort. 25 | 26 | Numerical differentiation uses the definition of the derivative: 27 | 28 | .. math:: 29 | \frac{\partial f(\mathbf{x})}{\partial x_{i}} \approx \frac{f\left(\mathbf{x}+h \mathbf{e}_{i}\right)-f(\mathbf{x})}{h} 30 | 31 | This method is simple to implement, but it has two serious problems: 32 | truncation error and round-off error. Nevertheless, it is a good way to 33 | check whether a gradient is accurate. 34 | 35 | Another method is symbolic differentiation, which transfers the work we did in 36 | manual differentiation to the computer. The problem with this method is that 37 | the expression must be closed-form, that is, it cannot contain loops or 38 | conditional expressions, so that the entire problem can be converted into 39 | a purely symbolic mathematical problem and solved using algebra 40 | software. However, when expressions are complex, the problem of "expression 41 | swell" is prone to occur. 42 | 43 | .. image:: _static/expression_swell.png 44 | 45 | The last is our protagonist: automatic differentiation. It is also the most 46 | widely used differentiation method in programs. 47 | 48 | What is Automatic Differentiation? 49 | ------------------------------------ 50 | 51 | Automatic differentiation builds on the essence of 52 | differential calculation: 53 | **Differential calculation is a composition of a finite series of \ 54 | differentiable operators.** 55 | 56 | We can regard the formula 57 | 58 | ..
48 | What is Automatic Differentiation?
49 | ------------------------------------
50 | 
51 | Automatic differentiation exploits the essence of differential
52 | calculation:
53 | **every differentiable computation is a composition of a finite set of \
54 | elementary differentiable operators.**
55 | 
56 | We can regard the formula
57 | 
58 | .. math::
59 |     f\left(x_{1}, x_{2}\right)=\ln \left(x_{1}\right)+x_{1} x_{2}-\sin \left(x_{2}\right)
60 | 
61 | as a computational graph
62 | (what is more, it can also be regarded as a tree structure).
63 | During the forward pass we can obtain the value of each node.
64 | 
65 | .. image:: _static/computational_graph.png
66 | 
67 | Then we can express the derivation process as follows:
68 | 
69 | .. math::
70 |     \frac{d f}{d x_1} = \frac{d v_{-1}}{d x_1} \cdot \left(\frac{d v_1}{d v_{-1}} \cdot \frac{d v_4}{d v_1} + \frac{d v_2}{d v_{-1}} \cdot \frac{d v_4}{d v_2} \right) \cdot \frac{d v_5}{d v_4} \cdot \frac{d f}{d v_5}
71 | 
72 | The whole derivation thus splits into a chain of elementary
73 | differential operators. The evaluation can proceed in two
74 | ways: evaluating the chain from inputs to outputs is called
75 | forward mode, and evaluating it from outputs back to inputs is called
76 | reverse mode. The two modes are illustrated below:
77 | 
78 | .. image:: _static/forward_mode.png
79 | 
80 | .. image:: _static/reverse_mode.png
81 | 
82 | Both modes compute the same gradient values, but because the order of
83 | evaluation differs, their cost differs. In general,
84 | if the Jacobian matrix is tall (more outputs than inputs), forward mode is
85 | more efficient; if the Jacobian matrix is wide (more inputs than outputs),
86 | reverse mode is more efficient.
87 | 
88 | ``JVP``, ``VJP`` and ``vmap``
89 | ------------------------------------
90 | 
91 | If you have used PyTorch, you will have noticed that when ``y`` is a tensor
92 | rather than a scalar, you are asked to pass a ``grad_variables`` argument to
93 | ``y.backward()``, and the resulting ``x.grad`` has the same shape as ``x``.
94 | Where is the Jacobian matrix?
95 | 
96 | The reason is that deep learning frameworks such as TensorFlow and PyTorch
97 | do not differentiate a tensor with respect to a tensor; they only support
98 | differentiating a scalar with respect to a tensor. When we call
99 | ``y.backward()`` and pass in a ``grad_variables`` ``v``, the framework
100 | effectively converts ``y`` into the weighted sum ``l = torch.sum(y * v)``,
101 | where ``l`` is a scalar, so ``x.grad`` naturally has the same shape as ``x``.
102 | This is because the loss in deep learning is always a scalar, and gradient
103 | descent requires the gradient to have the same shape as ``x``.
104 | 
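To make the weighted-sum view concrete, here is a minimal PyTorch sketch
(the tensors and values are purely illustrative):

.. code-block:: python

    import torch

    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = x ** 2                 # y is a tensor, not a scalar
    v = torch.ones_like(y)     # the "grad_variables" vector

    y.backward(v)              # equivalent to torch.sum(y * v).backward()
    print(x.grad)              # tensor([2., 4., 6.]) -- same shape as x
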
105 | But what if we want to obtain the Jacobian matrix?
106 | 
107 | The answer is to differentiate with respect to ``x`` once for each element of
108 | ``y``. In addition, Google's deep learning framework JAX uses a more advanced
109 | method, the vectorizing transformation ``vmap``, to speed up this calculation.
110 | 
111 | Reference
112 | ------------------------------------
113 | * `The Autodiff Cookbook `_
114 | * `Automatic Differentiation in Machine Learning: a Survey `_
115 | 
116 | How to implement ``VJP``?
117 | ================================
118 | 
119 | This article describes how to build an automatic differentiation framework
120 | based on VJP (vector-Jacobian products).
121 | 
122 | Basic vjp differential operator
123 | ------------------------------------
124 | 
125 | The VJP differential operator is the cornerstone of an automatic
126 | differentiation system based on VJP. Because some differential operators are
127 | complicated and error-prone to write by hand, we reuse the code in `autograd `_.
128 | The differential operator of a simple binary function is defined as follows:
129 | 
130 | .. code-block:: python
131 | 
132 |     defvjp(
133 |         np.subtract,
134 |         lambda ans, x, y: unbroadcast_f(x, lambda g: g),
135 |         lambda ans, x, y: unbroadcast_f(y, lambda g: -g)
136 |     )
137 | 
138 | 
139 | ``defvjp`` registers two differential operators for the function
140 | ``np.subtract``, one per argument. Each takes at least four inputs, supplied in two stages.
141 | 
142 | In the first stage (building the calculation graph), it receives the result
143 | ``ans``, the inputs ``x`` and ``y``, and any other parameters that may
144 | influence the derivative, such as the ``axis`` of ``np.sum``. In the second
145 | stage (back propagation), it receives the gradient ``g``.
146 | 
147 | Build calculation graph
148 | ------------------------------------
149 | 
150 | We can express any calculation as a directed acyclic graph. For example, for
151 | the formula
152 | 
153 | .. math::
154 |     f\left(x_{1}, x_{2}\right)=\ln \left(x_{1}\right)+x_{1} x_{2}-\sin \left(x_{2}\right), \text{assuming } x_1=2, x_2=5,
155 | 
156 | the calculation graph can be expressed as follows:
157 | 
158 | .. image:: _static/build_computational_graph.png
159 | 
160 | For back propagation, we need to define a data structure that retains
161 | each node of the calculation graph. Assuming this data structure is
162 | ``VJPDiffArray``, the class should have the following attributes:
163 | 
164 | * ``_id``: The id of the node.
165 | 
166 | * ``_value``: The value of the node, such as ``ln 2``.
167 | 
168 | * ``_parents``: The nodes that point to the current node; for example, the
169 |   ``_parents`` of ``v4`` is ``[v2, v3]``.
170 | 
171 | * ``_vjp``: A function that calculates the gradient of the current node with
172 |   respect to its parents. It receives ``ans``, ``x``, and ``y`` during the
173 |   forward pass; during backward propagation you only need to pass in ``g``.
174 | 
175 | All these attributes are assigned by a function called ``register_diff`` during
176 | the forward calculation.
177 | 
178 | Backward propagation
179 | ------------------------------------
180 | 
181 | After the calculation graph has been constructed, the derivation itself is
182 | relatively simple: we can express it as back propagation over the calculation
183 | graph:
184 | 
185 | .. image:: _static/backward.png
186 | 
187 | For each node, we take the gradient flowing in from the downstream node and
188 | use it to compute the gradient of the current node with respect to its
189 | ``_parents``. This process is implemented by the ``_backward`` function.
190 | 
191 | How to implement ``JVP``?
192 | ================================
193 | 
194 | ``VJP`` propagates derivatives from back to front, while ``JVP`` propagates
195 | them from front to back. The problem is that higher-order derivatives are
196 | difficult to implement concisely, because we also register gradients while
197 | computing a derivative, and ``JVP``'s front-to-back evaluation scheme would
198 | make the program fall into an infinite loop.
199 | 
200 | To solve this problem, we make each ``JVPDiffArray`` self-contained,
201 | that is, it carries all the information needed for its own derivation.
202 | 
203 | Thus, when calculating the gradient,
204 | it is enough to evaluate the stored chain of JVP functions,
205 | because ``JVPDiffArray`` itself already contains all the required information.
206 | 
207 | How to implement ``Jacobian``?
208 | ================================
209 | 
210 | After completing the above two parts,
211 | we can easily obtain the Jacobian matrix from them.
212 | ``VJP`` is the derivative of ``np.sum(y * v)`` (``v`` is ``grad_variables``)
213 | with respect to ``x``, and the slice ``j[i][j]`` of the Jacobian matrix
214 | (the derivative of ``y[i][j]`` with respect to ``x``)
215 | can be obtained by setting the ``v`` in ``VJP`` to a matrix
216 | whose entry at position ``[i, j]`` is 1 and all other entries are 0.
217 | 
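In ``udiff`` this corresponds to calling the ``to`` method with
``jacobian=True``. A minimal sketch, reusing the imports from the doctest setup
in ``conf.py`` (the input values are illustrative):

.. code-block:: python

    with ua.set_backend(udiff.DiffArrayBackend(numpy_backend), coerce=True):
        x = np.array([1.0, 2.0, 3.0])
        y = np.exp(x)
        jac = y.to(x, jacobian=True)  # shape: np.shape(y) + np.shape(x)
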
218 | How to support higher-order derivative?
219 | ========================================
220 | Because we also register gradients (i.e. call ``register_diff``) while a
221 | gradient is being computed, obtaining higher-order derivatives only requires
222 | calling the ``to`` function repeatedly, e.g. ``y.to(x).to(x)``.
223 | 
224 | For more details, please refer to `_diff_array.py `_.
225 | 
--------------------------------------------------------------------------------
/src/udiff/_jvp_diffs.py:
--------------------------------------------------------------------------------
1 | import unumpy as np
2 | from ._vjp_diffs import (
3 |     balanced_eq,
4 |     match_complex,
5 |     replace_zero,
6 |     metadata,
7 |     nograd_functions,
8 | )
9 | from ._core import (
10 |     defjvp,
11 |     defjvp_argnum,
12 |     def_linear,
13 | )
14 | 
15 | # ----- Functions that are constant w.r.t. continuous inputs -----
16 | defjvp(np.nan_to_num, lambda ans, x: lambda g: np.where(np.isfinite(x), g, 0.0))
17 | 
18 | # ----- Binary ufuncs (linear) -----
19 | def_linear(np.multiply)
20 | 
21 | # ----- Binary ufuncs -----
22 | defjvp(
23 |     np.add,
24 |     lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
25 |     lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
26 | )
27 | defjvp(
28 |     np.subtract,
29 |     lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
30 |     lambda ans, x, y: lambda g: np.broadcast_to(-g, np.shape(ans)),
31 | )
32 | defjvp(
33 |     np.multiply,
34 |     lambda ans, x, y: lambda g: np.broadcast_to(g * y, np.shape(ans)),
35 |     lambda ans, x, y: lambda g: np.broadcast_to(x * g, np.shape(ans)),
36 | )
37 | defjvp(np.divide, "same", lambda ans, x, y: lambda g: -g * x / y ** 2)
38 | defjvp(
39 |     np.maximum,
40 |     lambda ans, x, y: lambda g: g * balanced_eq(x, ans, y),
41 |     lambda ans, x, y: lambda g: g * balanced_eq(y, ans, x),
42 | )
43 | defjvp(
44 |     np.minimum,
45 |     lambda ans, x, y: lambda g: g * balanced_eq(x, ans, y),
46 |     lambda ans, x, y: lambda g: g * balanced_eq(y, ans, x),
47 | )
48 | defjvp(
49 |     np.fmax,
50 |     lambda ans, x, y: lambda g: g * balanced_eq(x, ans, y),
51 |     lambda ans, x, y: lambda g: g * balanced_eq(y, ans, x),
52 | )
53 | defjvp(
54 |     np.fmin,
55 |     lambda ans, x, y: lambda g: g * balanced_eq(x, ans, y),
56 |     lambda ans, x, y: lambda g: g * balanced_eq(y, ans, x),
57 | )
58 | defjvp(
59 |     np.logaddexp,
60 |     lambda ans, x, y: lambda g: g * np.exp(x - ans),
61 |     lambda ans, x, y: lambda g: g * np.exp(y - ans),
62 | )
63 | defjvp(
64 |     np.logaddexp2,
65 |     lambda ans, x, y: lambda g: g * 2 ** (x - ans),
66 |     lambda ans, x, y: lambda g: g * 2 ** (y - ans),
67 | )
68 | defjvp(np.true_divide, "same", lambda ans, x, y: lambda g: -g * x / y ** 2)
69 | defjvp(
70 |     np.mod,
71 |     lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
72 |     lambda ans, x, y: lambda g: -g * np.floor(x / y),
73 | )
74 | defjvp(
75 |     np.remainder,
76 |     lambda ans, x, y: lambda g: np.broadcast_to(g, np.shape(ans)),
77 |     lambda ans, x, y: lambda g: -g * np.floor(x / y),
78 | )
79 | defjvp(
80 |     np.power,
81 |     lambda ans, x, y: lambda g: g * y * x ** np.where(y, y - 1, 1.0),
82 |     lambda ans, x, y: lambda g: g * np.log(replace_zero(x, 1.0)) * ans,
83 | )
84 | defjvp(
85 |     np.arctan2,
86 |     lambda
ans, x, y: lambda g: g * y / (x ** 2 + y ** 2), 87 | lambda ans, x, y: lambda g: g * -x / (x ** 2 + y ** 2), 88 | ) 89 | 90 | # ----- Simple grads (linear) ----- 91 | defjvp(np.negative, "same") 92 | defjvp(np.rad2deg, "same") 93 | defjvp(np.degrees, "same") 94 | defjvp(np.deg2rad, "same") 95 | defjvp(np.radians, "same") 96 | defjvp(np.reshape, "same") 97 | defjvp(np.roll, "same") 98 | defjvp(np.array_split, "same") 99 | defjvp(np.split, "same") 100 | defjvp(np.vsplit, "same") 101 | defjvp(np.hsplit, "same") 102 | defjvp(np.dsplit, "same") 103 | defjvp(np.ravel, "same") 104 | defjvp(np.expand_dims, "same") 105 | defjvp(np.squeeze, "same") 106 | defjvp(np.diag, "same") 107 | defjvp(np.diagonal, "same") 108 | defjvp(np.flipud, "same") 109 | defjvp(np.fliplr, "same") 110 | defjvp(np.rot90, "same") 111 | defjvp(np.full, "same", argnums=(1,)) 112 | defjvp(np.triu, "same") 113 | defjvp(np.tril, "same") 114 | defjvp(np.swapaxes, "same") 115 | defjvp(np.rollaxis, "same") 116 | defjvp(np.moveaxis, "same") 117 | defjvp(np.broadcast_to, "same") 118 | def_linear(np.cross) 119 | 120 | # ----- Simple grads ----- 121 | defjvp(np.positive, lambda ans, x: lambda g: np.ones_like(x) * g) 122 | defjvp(np.negative, lambda ans, x: lambda g: -np.ones_like(x) * g) 123 | defjvp( 124 | np.fabs, lambda ans, x: lambda g: np.sign(x) * g 125 | ) # fabs doesn't take complex numbers. 126 | defjvp(np.absolute, lambda ans, x: lambda g: np.real(g * np.conj(x)) / ans) 127 | defjvp(np.reciprocal, lambda ans, x: lambda g: -g / x ** 2) 128 | defjvp(np.exp, lambda ans, x: lambda g: ans * g) 129 | defjvp(np.exp2, lambda ans, x: lambda g: ans * np.log(2) * g) 130 | defjvp(np.expm1, lambda ans, x: lambda g: (ans + 1) * g) 131 | defjvp(np.log, lambda ans, x: lambda g: g / x) 132 | defjvp(np.log2, lambda ans, x: lambda g: g / x / np.log(2)) 133 | defjvp(np.log10, lambda ans, x: lambda g: g / x / np.log(10)) 134 | defjvp(np.log1p, lambda ans, x: lambda g: g / (x + 1)) 135 | defjvp(np.sin, lambda ans, x: lambda g: g * np.cos(x)) 136 | defjvp(np.cos, lambda ans, x: lambda g: -g * np.sin(x)) 137 | defjvp(np.tan, lambda ans, x: lambda g: g / np.cos(x) ** 2) 138 | defjvp(np.arcsin, lambda ans, x: lambda g: g / np.sqrt(1 - x ** 2)) 139 | defjvp(np.arccos, lambda ans, x: lambda g: -g / np.sqrt(1 - x ** 2)) 140 | defjvp(np.arctan, lambda ans, x: lambda g: g / (1 + x ** 2)) 141 | defjvp(np.sinh, lambda ans, x: lambda g: g * np.cosh(x)) 142 | defjvp(np.cosh, lambda ans, x: lambda g: g * np.sinh(x)) 143 | defjvp(np.tanh, lambda ans, x: lambda g: g / np.cosh(x) ** 2) 144 | defjvp(np.arcsinh, lambda ans, x: lambda g: g / np.sqrt(x ** 2 + 1)) 145 | defjvp(np.arccosh, lambda ans, x: lambda g: g / np.sqrt(x ** 2 - 1)) 146 | defjvp(np.arctanh, lambda ans, x: lambda g: g / (1 - x ** 2)) 147 | defjvp(np.square, lambda ans, x: lambda g: g * 2 * x) 148 | defjvp(np.sqrt, lambda ans, x: lambda g: g * 0.5 * x ** -0.5) 149 | defjvp( 150 | np.sinc, 151 | lambda ans, x: lambda g: g 152 | * (np.cos(np.pi * x) * np.pi * x - np.sin(np.pi * x)) 153 | / (np.pi * x ** 2), 154 | ) 155 | defjvp( 156 | np.clip, 157 | lambda ans, x, a_min, a_max: lambda g: g 158 | * np.logical_and(ans != a_min, ans != a_max), 159 | ) 160 | defjvp(np.real_if_close, lambda ans, x: lambda g: match_complex(ans, g)) 161 | defjvp(np.real, lambda ans, x: lambda g: np.real(g)) 162 | defjvp(np.imag, lambda ans, x: lambda g: match_complex(ans, -1j * g)) 163 | defjvp(np.conj, lambda ans, x: lambda g: np.conj(g)) 164 | defjvp( 165 | np.angle, 166 | lambda ans, x: lambda g: match_complex(ans, g * 
np.conj(x * 1j) / np.abs(x) ** 2), 167 | ) 168 | defjvp( 169 | np.where, 170 | None, 171 | lambda ans, c, x=None, y=None: lambda g: np.where(c, g, np.zeros(np.shape(g))), 172 | lambda ans, c, x=None, y=None: lambda g: np.where(c, np.zeros(g.shape), g), 173 | ) 174 | 175 | # ----- Trickier grads ----- 176 | # defjvp(np.kron, "same", "same") 177 | defjvp(np.diff, "same") 178 | defjvp(np.gradient, "same") 179 | defjvp(np.repeat, "same") 180 | defjvp(np.tile, "same") 181 | defjvp(np.transpose, "same") 182 | defjvp(np.sum, "same") 183 | 184 | defjvp( 185 | np.prod, 186 | lambda ans, x, axis=None, keepdims=False: lambda g: ans 187 | * np.sum(g / x, axis=axis, keepdims=keepdims), 188 | ) 189 | defjvp( 190 | np.linspace, 191 | lambda ans, start, stop, *args, **kwargs: lambda g: np.linspace( 192 | g, 0, *args, **kwargs 193 | ), 194 | lambda ans, start, stop, *args, **kwargs: lambda g: np.linspace( 195 | 0, g, *args, **kwargs 196 | ), 197 | ) 198 | 199 | 200 | def forward_grad_np_var(ans, x, axis=None, ddof=0, keepdims=False): 201 | def jvp(g): 202 | if axis is None: 203 | num_reps = np.size(g) 204 | elif isinstance(axis, int): 205 | num_reps = np.shape(g)[axis] 206 | elif isinstance(axis, tuple): 207 | num_reps = np.prod(np.array(np.shape(g))[list(axis)]) 208 | 209 | x_minus_mean = np.conj(x - np.mean(x, axis=axis, keepdims=True)) 210 | return ( 211 | 2.0 212 | * np.sum(np.real(g * x_minus_mean), axis=axis, keepdims=keepdims) 213 | / (num_reps - ddof) 214 | ) 215 | 216 | return jvp 217 | 218 | 219 | defjvp(np.var, forward_grad_np_var) 220 | 221 | 222 | def forward_grad_np_std(ans, x, axis=None, ddof=0, keepdims=False): 223 | def jvp(g): 224 | if axis is None: 225 | num_reps = np.size(g) 226 | elif isinstance(axis, int): 227 | num_reps = np.shape(g)[axis] 228 | elif isinstance(axis, tuple): 229 | num_reps = np.prod(np.array(np.shape(g))[list(axis)]) 230 | 231 | if num_reps <= 1: 232 | return np.zeros_like(ans) 233 | x_minus_mean = np.conj(x - np.mean(x, axis=axis, keepdims=True)) 234 | return np.sum(np.real(g * x_minus_mean), axis=axis, keepdims=keepdims) / ( 235 | (num_reps - ddof) * ans 236 | ) 237 | 238 | return jvp 239 | 240 | 241 | defjvp(np.std, forward_grad_np_std) 242 | 243 | 244 | def fwd_grad_chooser(ans, x, axis=None, keepdims=False): 245 | def jvp(g): 246 | if np.isscalar(x): 247 | return g 248 | if not keepdims: 249 | if isinstance(axis, int): 250 | ans = np.expand_dims(ans, axis) 251 | elif isinstance(axis, tuple): 252 | for ax in sorted(axis): 253 | ans = np.expand_dims(ans, ax) 254 | chosen_locations = x == ans 255 | return np.sum((g * chosen_locations), axis=axis, keepdims=keepdims) / np.sum( 256 | chosen_locations, axis=axis, keepdims=keepdims 257 | ) 258 | 259 | return jvp 260 | 261 | 262 | defjvp(np.max, fwd_grad_chooser) 263 | defjvp(np.min, fwd_grad_chooser) 264 | 265 | 266 | defjvp(np.cumsum, "same") 267 | 268 | def_linear(np.matmul) 269 | 270 | 271 | def fwd_grad_sort(g, ans, x, axis=-1, kind="quicksort", order=None): 272 | sort_perm = np.argsort(x, axis, kind, order) 273 | return g[sort_perm] 274 | 275 | 276 | defjvp(np.sort, fwd_grad_sort) 277 | defjvp(np.msort, lambda ans, x: lambda g: fwd_grad_sort(g, ans, x, axis=0)) 278 | 279 | 280 | def fwd_grad_partition(ans, x, kth, axis=-1, kind="introselect", order=None): 281 | def jvp(g): 282 | partition_perm = np.argpartition(x, kth, axis, kind, order) 283 | return g[partition_perm] 284 | 285 | return jvp 286 | 287 | 288 | defjvp(np.partition, fwd_grad_partition) 289 | 290 | 291 | def atleast_jvpmaker(fun): 292 | def jvp(g, ans, 
*arys): 293 | if len(arys) > 1: 294 | raise NotImplementedError("Can't handle multiple arguments yet.") 295 | return lambda g: fun(g) 296 | 297 | return jvp 298 | 299 | 300 | defjvp(np.atleast_1d, atleast_jvpmaker(np.atleast_1d)) 301 | defjvp(np.atleast_2d, atleast_jvpmaker(np.atleast_2d)) 302 | defjvp(np.atleast_3d, atleast_jvpmaker(np.atleast_3d)) 303 | 304 | 305 | defjvp( 306 | np.pad, lambda ans, array, width, mode, **kwargs: lambda g: np.pad(g, width, mode) 307 | ) 308 | 309 | 310 | def stack_diff(ans, x, axis=0): 311 | def jvp(g): 312 | ret = [] 313 | ng = np.broadcast_to(g, np.shape(ans)) 314 | shape = np.shape(ng) 315 | for idx in range(shape[axis]): 316 | ret.append(np.take(ng, idx, axis=axis)) 317 | return tuple(ret) 318 | 319 | return jvp 320 | 321 | 322 | defjvp(np.stack, stack_diff) 323 | -------------------------------------------------------------------------------- /src/tests/test_diff.py: -------------------------------------------------------------------------------- 1 | import uarray as ua 2 | import unumpy as np 3 | import numpy as onp 4 | import torch 5 | import dask.array as da 6 | import udiff 7 | import sparse 8 | from math import * 9 | from random import uniform, randrange 10 | import unumpy.numpy_backend as NumpyBackend 11 | 12 | import unumpy.torch_backend as TorchBackend 13 | import unumpy.dask_backend as DaskBackend 14 | import unumpy.sparse_backend as SparseBackend 15 | 16 | import numpy as onp 17 | from numpy.testing import * 18 | 19 | import pytest 20 | 21 | ua.set_global_backend(NumpyBackend) 22 | 23 | LIST_BACKENDS = [ 24 | NumpyBackend, 25 | # DaskBackend, 26 | # SparseBackend, 27 | pytest.param( 28 | TorchBackend, 29 | marks=pytest.mark.xfail(reason="PyTorch not fully NumPy compatible."), 30 | ), 31 | ] 32 | 33 | 34 | FULLY_TESTED_BACKENDS = [NumpyBackend, DaskBackend] 35 | 36 | 37 | try: 38 | import unumpy.cupy_backend as CupyBackend 39 | import cupy as cp 40 | 41 | LIST_BACKENDS.append(pytest.param(CupyBackend)) 42 | except ImportError: 43 | LIST_BACKENDS.append( 44 | pytest.param( 45 | (None, None), marks=pytest.mark.skip(reason="cupy is not importable") 46 | ) 47 | ) 48 | 49 | 50 | EXCEPTIONS = { 51 | (DaskBackend, np.in1d), 52 | (DaskBackend, np.intersect1d), 53 | (DaskBackend, np.setdiff1d), 54 | (DaskBackend, np.setxor1d), 55 | (DaskBackend, np.union1d), 56 | (DaskBackend, np.sort), 57 | (DaskBackend, np.argsort), 58 | (DaskBackend, np.lexsort), 59 | (DaskBackend, np.partition), 60 | (DaskBackend, np.argpartition), 61 | (DaskBackend, np.sort_complex), 62 | (DaskBackend, np.msort), 63 | (DaskBackend, np.searchsorted), 64 | } 65 | 66 | 67 | @pytest.fixture(scope="session", params=LIST_BACKENDS) 68 | def backend(request): 69 | backend = request.param 70 | return backend 71 | 72 | 73 | @pytest.fixture(scope="session", params=["vjp", "jvp"]) 74 | def mode(request): 75 | mode = request.param 76 | return mode 77 | 78 | 79 | def generate_test_data(n_elements=12, a=None, b=None): 80 | if a is None: 81 | a = -10 82 | if b is None: 83 | b = 10 84 | x_arr = [uniform(a + 1e-3, b - 1e-3) for i in range(n_elements)] 85 | return x_arr 86 | 87 | 88 | def grad_check_sparse(f, x, analytic_grad, num_checks=10, h=1e-5): 89 | """ 90 | sample a few random elements and only return numerical 91 | in this dimensions. 
92 | """ 93 | for i in range(num_checks): 94 | ix = tuple([randrange(m) for m in np.shape(x)]) 95 | 96 | oldval = x[ix] 97 | x[ix] = oldval + h # increment by h 98 | fxph = f(x) # evaluate f(x + h) 99 | x[ix] = oldval - h # increment by h 100 | fxmh = f(x) # evaluate f(x - h) 101 | x[ix] = oldval # reset 102 | 103 | grad_numerical = ((fxph - fxmh) / (2 * h))[ix] 104 | grad_analytic = analytic_grad[ix] 105 | rel_error = abs(grad_numerical - grad_analytic) / ( 106 | abs(grad_numerical) + abs(grad_analytic) 107 | ) 108 | assert_almost_equal(rel_error, 0, decimal=5) 109 | 110 | 111 | @pytest.mark.parametrize( 112 | "func, y_d, domain", 113 | [ 114 | (np.positive, lambda x: 1, None), 115 | (np.negative, lambda x: -1, None), 116 | (np.exp, lambda x: pow(e, x), None), 117 | (np.exp2, lambda x: pow(2, x) * log(2), None), 118 | (np.log, lambda x: 1 / x, (0, None)), 119 | (np.log2, lambda x: 1 / (x * log(2)), (0, None)), 120 | (np.log10, lambda x: 1 / (x * log(10)), (0, None)), 121 | (np.log1p, lambda x: 1 / (x + 1), (-1, None)), 122 | (np.sqrt, lambda x: 0.5 * pow(x, -0.5), (0, None)), 123 | (np.square, lambda x: 2 * x, None), 124 | (np.reciprocal, lambda x: -1 / pow(x, 2), (None, 0)), 125 | (np.sin, lambda x: cos(x), None), 126 | (np.cos, lambda x: -sin(x), None), 127 | ( 128 | np.tan, 129 | lambda x: 1 / cos(x) ** 2, 130 | (-5, 5), 131 | ), # Set bound to prevent numerical overflow 132 | (np.arcsin, lambda x: 1 / sqrt(1 - x ** 2), (-1, 1)), 133 | (np.arccos, lambda x: -1 / sqrt(1 - x ** 2), (-1, 1)), 134 | (np.arctan, lambda x: 1 / (1 + x ** 2), None), 135 | (np.sinh, lambda x: cosh(x), None), 136 | (np.cosh, lambda x: sinh(x), (1, None)), 137 | (np.tanh, lambda x: 1 / cosh(x) ** 2, (-1, 1)), 138 | (np.arcsinh, lambda x: 1 / sqrt(1 + x ** 2), None), 139 | (np.arccosh, lambda x: 1 / sqrt(-1 + x ** 2), (1, None)), 140 | (np.arctanh, lambda x: 1 / (1 - x ** 2), (-1, 1)), 141 | (np.absolute, lambda x: 1 if x > 0 else -1, None), 142 | (np.fabs, lambda x: 1 if x > 0 else -1, None), 143 | (np.reciprocal, lambda x: -1 / x ** 2, (1, 10)), 144 | (np.expm1, lambda x: exp(x), None), 145 | (np.rad2deg, lambda x: 1 / pi * 180.0, None), 146 | (np.deg2rad, lambda x: pi / 180.0, None), 147 | ], 148 | ) 149 | def test_unary_function(backend, mode, func, y_d, domain): 150 | if domain is None: 151 | x_arr = generate_test_data() 152 | else: 153 | x_arr = generate_test_data(a=domain[0], b=domain[1]) 154 | expect_diff = [y_d(xa) for xa in x_arr] 155 | try: 156 | with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True): 157 | x = np.asarray(x_arr) 158 | y = func(x) 159 | x_diff = y.to(x) 160 | except ua.BackendNotImplementedError: 161 | if backend in FULLY_TESTED_BACKENDS: 162 | raise 163 | pytest.xfail(reason="The backend has no implementation for this ufunc.") 164 | 165 | if isinstance(y, da.Array): 166 | y.compute() 167 | 168 | assert_allclose(x_diff.value, expect_diff) 169 | 170 | 171 | @pytest.mark.parametrize( 172 | "func, u_d, v_d, u_domain, v_domain", 173 | [ 174 | (np.add, lambda u, v: 1, lambda u, v: 1, None, None), 175 | (np.subtract, lambda u, v: 1, lambda u, v: -1, None, None), 176 | (np.multiply, lambda u, v: v, lambda u, v: u, None, None), 177 | (np.divide, lambda u, v: 1 / v, lambda u, v: -u / v ** 2, None, (0, None)), 178 | ( 179 | np.maximum, 180 | lambda u, v: 1 if u >= v else 0, 181 | lambda u, v: 1 if v > u else 0, 182 | None, 183 | None, 184 | ), 185 | ( 186 | np.minimum, 187 | lambda u, v: 1 if u <= v else 0, 188 | lambda u, v: 1 if v <= u else 0, 189 | None, 190 | None, 191 | 
), 192 | ( 193 | np.logaddexp, 194 | lambda u, v: exp(u) / (exp(u) + exp(v)), 195 | lambda u, v: exp(v) / (exp(u) + exp(v)), 196 | (-1, 1), 197 | (-1, 1), 198 | ), 199 | ( 200 | np.logaddexp2, 201 | lambda u, v: 2 ** u / (2 ** u + 2 ** v), 202 | lambda u, v: 2 ** v / (2 ** u + 2 ** v), 203 | (-1, 1), 204 | (-1, 1), 205 | ), 206 | ( 207 | np.true_divide, 208 | lambda u, v: 1 / v, 209 | lambda u, v: -u / (v ** 2), 210 | (1, 5), 211 | (1, 5), 212 | ), 213 | (np.mod, lambda u, v: 1, lambda u, v: -floor(u / v), (1, 10), (1, 10)), 214 | ( 215 | np.power, 216 | lambda u, v: pow(u, v) * v / u, 217 | lambda u, v: pow(u, v) * log(u), 218 | (1, 5), 219 | (1, 5), 220 | ), 221 | ( 222 | np.arctan2, 223 | lambda u, v: v / (u ** 2 + v ** 2), 224 | lambda u, v: -u / (u ** 2 + v ** 2), 225 | (0, 1), 226 | (0, 1), 227 | ), 228 | ( 229 | np.hypot, 230 | lambda u, v: u / sqrt(u ** 2 + v ** 2), 231 | lambda u, v: v / sqrt(u ** 2 + v ** 2), 232 | None, 233 | None, 234 | ), 235 | ], 236 | ) 237 | def test_binary_function(backend, mode, func, u_d, v_d, u_domain, v_domain): 238 | if u_domain is None: 239 | u_arr = generate_test_data() 240 | else: 241 | u_arr = generate_test_data(a=u_domain[0], b=u_domain[1]) 242 | if v_domain is None: 243 | v_arr = generate_test_data() 244 | else: 245 | v_arr = generate_test_data(a=v_domain[0], b=v_domain[1]) 246 | 247 | expect_u_diff = [u_d(ua, va) for ua, va in zip(u_arr, v_arr)] 248 | expect_v_diff = [v_d(ua, va) for ua, va in zip(u_arr, v_arr)] 249 | try: 250 | with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True): 251 | u = np.asarray(u_arr) 252 | v = np.asarray(v_arr) 253 | y = func(u, v) 254 | u_diff = y.to(u) 255 | v_diff = y.to(v) 256 | except ua.BackendNotImplementedError: 257 | if backend in FULLY_TESTED_BACKENDS: 258 | raise 259 | pytest.xfail(reason="The backend has no implementation for this ufunc.") 260 | except NotImplementedError: 261 | pytest.xfail( 262 | reason="The func has no implementation in the {} mode.".format(mode) 263 | ) 264 | 265 | if isinstance(y, da.Array): 266 | y.compute() 267 | 268 | assert_allclose(u_diff.value, expect_u_diff) 269 | assert_allclose(v_diff.value, expect_v_diff) 270 | 271 | 272 | @pytest.mark.parametrize( 273 | "func, y_d, domain", 274 | [ 275 | (lambda x: x * x, lambda x: 2 * x, None), 276 | (lambda x: (2 * x + 1) ** 3, lambda x: 6 * (2 * x + 1) ** 2, (0.5, None)), 277 | ( 278 | lambda x: np.sin(x ** 2) / np.sin(x) ** 2, 279 | lambda x: (2 * x * cos(x ** 2) * sin(x) - 2 * sin(x ** 2) * cos(x)) 280 | / sin(x) ** 3, 281 | (0, pi), 282 | ), 283 | ( 284 | lambda x: np.log(x ** 2) ** (1 / 3), 285 | lambda x: 2 * log(x ** 2) ** (-2 / 3) / (3 * x), 286 | (1, None), 287 | ), 288 | ( 289 | lambda x: np.log((1 + x) / (1 - x)) / 4 - np.arctan(x) / 2, 290 | lambda x: x ** 2 / (1 - x ** 4), 291 | (-1, 1), 292 | ), 293 | ( 294 | lambda x: np.log(1 + x ** 2) / np.arctanh(x), 295 | lambda x: ( 296 | (2 * x * atanh(x) / (1 + x ** 2)) - (log(1 + x ** 2) / (1 - x ** 2)) 297 | ) 298 | / atanh(x) ** 2, 299 | (0, 1), 300 | ), 301 | ], 302 | ) 303 | def test_arbitrary_function(backend, mode, func, y_d, domain): 304 | if domain is None: 305 | x_arr = generate_test_data() 306 | else: 307 | x_arr = generate_test_data(a=domain[0], b=domain[1]) 308 | expect_diff = [y_d(xa) for xa in x_arr] 309 | try: 310 | with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True): 311 | x = np.asarray(x_arr) 312 | y = func(x) 313 | x_diff = y.to(x) 314 | except ua.BackendNotImplementedError: 315 | if backend in FULLY_TESTED_BACKENDS: 
316 | raise 317 | pytest.xfail(reason="The backend has no implementation for this ufunc.") 318 | 319 | if isinstance(y, da.Array): 320 | y.compute() 321 | 322 | assert_allclose(x_diff.value, expect_diff) 323 | 324 | 325 | @pytest.mark.parametrize( 326 | "func, y_d, domain", 327 | [ 328 | (lambda x: x * x, lambda x: 2, None), 329 | (lambda x: (2 * x + 1) ** 3, lambda x: 24 * (2 * x + 1), (0.5, None)), 330 | ( 331 | lambda x: np.sin(x ** 2) + np.sin(x) ** 2, 332 | lambda x: 2 * cos(x ** 2) 333 | - 4 * x ** 2 * sin(x ** 2) 334 | + 2 * cos(x) ** 2 335 | - 2 * sin(x) ** 2, 336 | (0, pi), 337 | ), 338 | ( 339 | lambda x: np.log(x ** 2), 340 | lambda x: -2 / x ** 2, 341 | (1, None), 342 | ), 343 | ( 344 | lambda x: np.power(np.cos(x), 2) * np.log(x), 345 | lambda x: -2 * cos(2 * x) * log(x) 346 | - 2 * sin(2 * x) / x 347 | - cos(x) ** 2 / x ** 2, 348 | (0, None), 349 | ), 350 | ( 351 | lambda x: x / np.sqrt(1 - x ** 2), 352 | lambda x: 3 * x / (1 - x ** 2) ** (5 / 2), 353 | (-1, 1), 354 | ), 355 | ], 356 | ) 357 | def test_high_order_diff(backend, mode, func, y_d, domain): 358 | if domain is None: 359 | x_arr = generate_test_data() 360 | else: 361 | x_arr = generate_test_data(a=domain[0], b=domain[1]) 362 | expect_diff = [y_d(xa) for xa in x_arr] 363 | try: 364 | with ua.set_backend(udiff.DiffArrayBackend(backend, mode=mode), coerce=True): 365 | x = np.asarray(x_arr) 366 | y = func(x) 367 | x_diff = y.to(x).to(x) 368 | except ua.BackendNotImplementedError: 369 | if backend in FULLY_TESTED_BACKENDS: 370 | raise 371 | pytest.xfail(reason="The backend has no implementation for this ufunc.") 372 | 373 | if isinstance(y, da.Array): 374 | y.compute() 375 | 376 | assert_allclose(x_diff.value, expect_diff) 377 | -------------------------------------------------------------------------------- /src/udiff/_diff_array.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import uarray as ua 3 | import unumpy as np 4 | import itertools 5 | from functools import reduce 6 | from ._core import primitive_vjps, primitive_jvps 7 | from unumpy import numpy_backend 8 | 9 | 10 | class DiffArray(np.ndarray): 11 | """ 12 | A container with the necessary information used in derivation. 13 | 14 | Attributes 15 | ---------- 16 | arr : DiffArray 17 | A DiffArray or ndarray used to initialize the class. 18 | 19 | 20 | .. note:: DiffArray is the base class of JVPDiffArray and VJPDiffArray. Do not use it. 21 | 22 | """ 23 | 24 | def __init__(self, arr): 25 | if isinstance(arr, DiffArray): 26 | self._value = arr.value 27 | self._id = arr.id 28 | return 29 | 30 | with ua.determine_backend(arr, np.ndarray, domain="numpy", coerce=True): 31 | arr = np.asarray(arr) 32 | 33 | self._value = arr 34 | self._id = uuid.uuid4() 35 | 36 | @property 37 | def dtype(self): 38 | """ 39 | The data type of the DiffArray. 40 | """ 41 | return self._value.dtype 42 | 43 | @property 44 | def value(self): 45 | """ 46 | The value of the DiffArray. 47 | """ 48 | return self._value 49 | 50 | @property 51 | def id(self): 52 | """ 53 | The id of the DiffArray. 54 | """ 55 | return self._id 56 | 57 | def __str__(self): 58 | return "<{}, id={}, value=\n{}\n>".format( 59 | type(self).__name__, 60 | repr(self.id), 61 | str(self.value), 62 | ) 63 | 64 | __repr__ = __str__ 65 | __hash__ = object.__hash__ 66 | 67 | 68 | class VJPDiffArray(DiffArray): 69 | """ 70 | A container with the necessary information used in derivation under VJP mode. 
71 | 72 | Attributes 73 | ---------- 74 | arr : VJPDiffArray or ndarray 75 | An VJPDiffArray or ndarray used to initialize the class. 76 | 77 | Examples 78 | -------- 79 | You do not need to use VJPDiffArray explicitly. 80 | When you call a function such as ``np.array`` that could create a ndarray under vjp mode, 81 | VJPDiffArray will be created automatically. 82 | 83 | >>> with ua.set_backend(udiff.DiffArrayBackend(numpy_backend), coerce=True): 84 | ... x = np.array([2]) 85 | ... isinstance(x, VJPDiffArray) 86 | True 87 | """ 88 | 89 | def __init__(self, arr): 90 | if isinstance(arr, VJPDiffArray): 91 | self._value = arr.value 92 | self._id = arr.id 93 | self._parents = arr._parents 94 | self._vjp = arr._vjp 95 | self._diff = arr._diff 96 | self._jacobian = arr._jacobian 97 | return 98 | 99 | with ua.determine_backend(arr, np.ndarray, domain="numpy", coerce=True): 100 | arr = np.asarray(arr) 101 | 102 | self._value = arr 103 | self._id = uuid.uuid4() 104 | self._parents = None 105 | self._vjp = None 106 | self._diff = None 107 | self._jacobian = None 108 | 109 | def register_diff(self, func, args, kwargs): 110 | """ 111 | Register the derivative function used in backward propagation for the current node. 112 | 113 | Parameters 114 | ---------- 115 | func : np.ufunc 116 | The function need to be derived. 117 | args : 118 | Arguments used in func. 119 | kwargs : 120 | Keyword-only arguments used in func. 121 | """ 122 | 123 | try: 124 | if func is np.ufunc.__call__: 125 | vjpmaker = primitive_vjps[args[0]] 126 | else: 127 | vjpmaker = primitive_vjps[func] 128 | except KeyError: 129 | raise NotImplementedError("VJP of func not defined") 130 | 131 | vjp_args = [] 132 | 133 | if self._parents is None: 134 | self._parents = [] 135 | 136 | for arg in args: 137 | if isinstance(arg, VJPDiffArray): 138 | self._parents.append(arg) 139 | vjp_args.append(arg) 140 | elif not isinstance(arg, np.ufunc): 141 | vjp_args.append(arg) 142 | 143 | parent_argnums = tuple(range(len(self._parents))) 144 | self._vjp = vjpmaker(parent_argnums, self, tuple(vjp_args), kwargs) 145 | 146 | def _backward(self, grad_variables, end_node, base): 147 | """ 148 | Backpropagation. 149 | Traverse computation graph backwards in topological order from the end node. 150 | For each node, compute local gradient contribution and accumulate. 
151 | """ 152 | if grad_variables is None: 153 | grad_variables = np.ones_like(self.value) 154 | 155 | if end_node is None: 156 | end_node = self 157 | 158 | if base is None or base.id == self.id: 159 | if self._diff is None: 160 | self._diff = {} 161 | 162 | if end_node in self._diff: 163 | self._diff[end_node] = self._diff[end_node] + grad_variables 164 | else: 165 | self._diff[end_node] = grad_variables 166 | 167 | if self._vjp: 168 | diffs = list(self._vjp(grad_variables)) 169 | for i, p in enumerate(self._parents): 170 | p._backward(diffs[i], end_node, base) 171 | 172 | def _backward_jacobian(self, grad_variables, end_node, position, base): 173 | if base is None or base.id == self.id: 174 | if self._jacobian is None: 175 | self._jacobian = {} 176 | 177 | if end_node not in self._jacobian: 178 | self._jacobian[end_node] = {} 179 | 180 | if position not in self._jacobian[end_node]: 181 | self._jacobian[end_node][position] = grad_variables 182 | else: 183 | self._jacobian[end_node][position] = ( 184 | self._jacobian[end_node][position] + grad_variables 185 | ) 186 | 187 | if self._vjp: 188 | diffs = list(self._vjp(grad_variables)) 189 | for i, p in enumerate(self._parents): 190 | p._backward_jacobian(diffs[i], end_node, position, base) 191 | 192 | def to(self, x, grad_variables=None, jacobian=False): 193 | """ 194 | Calculate the VJP or Jacobian matrix of self to x. 195 | 196 | Parameters 197 | ---------- 198 | x : VJPDiffArray 199 | The denominator in derivative. 200 | grad_variables : VJPDiffArray 201 | Gradient of the numerator in derivative. 202 | jacobian : bool 203 | Flag identifies whether to calculate the jacobian logo. 204 | If set ``True``, it will return jacobian matrix instead of vjp. 205 | 206 | Examples 207 | -------- 208 | >>> with ua.set_backend(udiff.DiffArrayBackend(numpy_backend), coerce=True): 209 | ... 210 | ... x1 = np.array([2]) 211 | ... x2 = np.array([5]) 212 | ... y = np.log(x1) + x1 * x2 - np.sin(x2) 213 | ... x1_diff = y.to(x1) 214 | ... print(np.allclose(x1_diff.value, [5.5])) 215 | True 216 | """ 217 | if jacobian: 218 | if x._jacobian is None or self not in x._jacobian: 219 | for position in itertools.product(*[range(i) for i in np.shape(self)]): 220 | grad_variables = np.zeros_like(self.value) 221 | grad_variables.value[position] = 1 222 | self._backward_jacobian(grad_variables, self, position, x) 223 | 224 | x._jacobian[self] = np.reshape( 225 | np.stack(x._jacobian[self].values()), np.shape(self) + np.shape(x) 226 | ) 227 | return x._jacobian[self] 228 | else: 229 | if x._diff is None or self not in x._diff: 230 | self._backward(grad_variables, self, x) 231 | return x._diff[self] 232 | 233 | 234 | class JVPDiffArray(DiffArray): 235 | """ 236 | A container with the necessary information used in derivation under jvp mode. 237 | 238 | Attributes 239 | ---------- 240 | arr : JVPDiffArray 241 | A JVPDiffArray or ndarray used to initialize the class. 242 | 243 | Examples 244 | -------- 245 | You do not need to use JVPDiffArray explicitly. 246 | When you call a function such as ``np.array`` that could create a ndarray under jvp mode, 247 | JVPDiffArray will be created automatically. 248 | 249 | >>> with ua.set_backend(udiff.DiffArrayBackend(numpy_backend, mode="jvp"), coerce=True): 250 | ... x = np.array([2]) 251 | ... 
isinstance(x, JVPDiffArray) 252 | True 253 | """ 254 | 255 | def __init__(self, arr): 256 | if isinstance(arr, JVPDiffArray): 257 | self._value = arr.value 258 | self._id = arr.id 259 | self._diff = arr._diff 260 | self._jacobian = arr._jacobian 261 | self._jvp = arr._jvp 262 | self._vars = arr._vars 263 | return 264 | 265 | with ua.determine_backend(arr, np.ndarray, domain="numpy", coerce=True): 266 | arr = np.asarray(arr) 267 | 268 | self._value = arr 269 | self._id = uuid.uuid4() 270 | self._diff = None 271 | self._jvp = None 272 | self._jacobian = None 273 | self._vars = None 274 | 275 | def register_diff(self, func, args, kwargs): 276 | """ 277 | Register all the information used in forward propagation for the current node. 278 | 279 | Parameters 280 | ---------- 281 | func : np.ufunc 282 | The function need to be derived. 283 | args : 284 | Arguments used in func. 285 | kwargs : 286 | Keyword-only arguments used in func. 287 | """ 288 | try: 289 | if func is np.ufunc.__call__: 290 | jvpmaker = primitive_jvps[args[0]] 291 | else: 292 | jvpmaker = primitive_jvps[func] 293 | except KeyError: 294 | raise NotImplementedError("JVP of func not defined") 295 | 296 | jvp_args, parents = [], [] 297 | 298 | for arg in args: 299 | if isinstance(arg, JVPDiffArray): 300 | jvp_args.append(arg) 301 | 302 | parents.append(arg) 303 | 304 | elif not isinstance(arg, np.ufunc): 305 | jvp_args.append(arg) 306 | 307 | parent_argnums = tuple(range(len(parents))) 308 | 309 | jvps = list(jvpmaker(parent_argnums, self, tuple(jvp_args), kwargs)) 310 | 311 | if self._jvp is None: 312 | self._jvp = {} 313 | 314 | for p, jvp in zip(parents, jvps): 315 | if p not in self._jvp: 316 | self._jvp[p] = [[jvp]] 317 | else: 318 | self._jvp[p] += [[jvp]] 319 | if p._jvp: 320 | for base in p._jvp: 321 | if base not in self._jvp: 322 | self._jvp[base] = [] 323 | self._jvp[base] += [flist + [jvp] for flist in p._jvp[base]] 324 | 325 | def _forward(self, x, grad_variables): 326 | if self._jvp is None: 327 | return grad_variables 328 | 329 | result = [] 330 | 331 | for flist in self._jvp[x]: 332 | cur_result = grad_variables 333 | for f in flist: 334 | cur_result = f(cur_result) 335 | result.append(cur_result) 336 | 337 | return reduce(lambda x, y: x + y, result) 338 | 339 | def to(self, x, grad_variables=None, jacobian=False): 340 | """ 341 | Calculate the JVP or Jacobian matrix of self to x. 342 | 343 | Parameters 344 | ---------- 345 | x : JVPDiffArray 346 | The denominator in derivative. 347 | grad_variables : JVPDiffArray 348 | Gradient assigned to the x. 349 | jacobian : bool 350 | Flag identifies whether to calculate the jacobian logo. 351 | If set ``True``, it will return jacobian matrix instead of jvp. 352 | 353 | Examples 354 | -------- 355 | >>> with ua.set_backend(udiff.DiffArrayBackend(numpy_backend, mode="jvp"), coerce=True): 356 | ... 357 | ... x1 = np.array([2]) 358 | ... x2 = np.array([5]) 359 | ... y = np.log(x1) + x1 * x2 - np.sin(x2) 360 | ... x1_diff = y.to(x1) 361 | ... 
print(np.allclose(x1_diff, [5.5])) 362 | True 363 | """ 364 | if self._jvp and x not in self._jvp: 365 | raise ValueError("Please check if the base is correct.") 366 | 367 | if jacobian: 368 | if self._jacobian is None: 369 | self._jacobian = {} 370 | 371 | if x not in self._jacobian: 372 | self._jacobian[x] = {} 373 | for position in itertools.product(*[range(i) for i in np.shape(x)]): 374 | grad_variables = np.zeros_like(x) 375 | grad_variables.value[position] = 1 376 | self._jacobian[x][position] = self._forward(x, grad_variables) 377 | 378 | old_axes = tuple(range(np.ndim(self) + np.ndim(x))) 379 | new_axes = old_axes[np.ndim(x) :] + old_axes[: np.ndim(x)] 380 | self._jacobian[x] = np.transpose( 381 | np.reshape( 382 | np.stack(self._jacobian[x].values()), 383 | np.shape(x) + np.shape(self), 384 | ), 385 | new_axes, 386 | ) 387 | return self._jacobian[x] 388 | else: 389 | if self._diff is None: 390 | self._diff = {} 391 | 392 | if x not in self._diff: 393 | if grad_variables is None: 394 | grad_variables = np.ones_like(self) 395 | 396 | self._diff[x] = self._forward(x, grad_variables) 397 | 398 | return self._diff[x] 399 | -------------------------------------------------------------------------------- /src/udiff/_vjp_diffs.py: -------------------------------------------------------------------------------- 1 | from ._core import defvjp, defvjp_argnum 2 | 3 | import uarray as ua 4 | import unumpy as np 5 | from unumpy import numpy_backend 6 | from functools import partial, reduce 7 | import operator 8 | import collections.abc 9 | 10 | # ----- Non-differentiable functions ----- 11 | raw_functions = set( 12 | [ 13 | np.ClassOverrideMeta.__instancecheck__, 14 | np.ndim, 15 | np.shape, 16 | np.dtype, 17 | np.isfinite, 18 | np.isinf, 19 | np.isnan, 20 | np.equal, 21 | np.not_equal, 22 | np.size, 23 | np.array_equiv, 24 | np.greater, 25 | np.greater_equal, 26 | np.less, 27 | np.less_equal, 28 | np.logical_and, 29 | np.logical_or, 30 | np.logical_not, 31 | np.logical_xor, 32 | np.isneginf, 33 | np.isposinf, 34 | np.allclose, 35 | np.isclose, 36 | np.iscomplexobj, 37 | np.iscomplex, 38 | np.isscalar, 39 | np.isreal, 40 | ] 41 | ) 42 | 43 | 44 | nograd_functions = set( 45 | [ 46 | np.array, 47 | np.ones, 48 | np.zeros, 49 | np.floor, 50 | np.ceil, 51 | np.rint, 52 | np.trunc, 53 | np.all, 54 | np.arange, 55 | np.any, 56 | np.argmax, 57 | np.argmin, 58 | np.argpartition, 59 | np.argsort, 60 | np.argwhere, 61 | np.nonzero, 62 | np.asarray, 63 | np.flatnonzero, 64 | np.count_nonzero, 65 | np.searchsorted, 66 | np.sign, 67 | np.floor_divide, 68 | np.around, 69 | np.fix, 70 | np.zeros_like, 71 | np.ones_like, 72 | ] 73 | ) 74 | 75 | # defvjp(np.nan_to_num, lambda ans, x: lambda g: np.where(np.isfinite(x), g, 0.)) 76 | 77 | # ----- Binary ufuncs ----- 78 | 79 | defvjp( 80 | np.add, 81 | lambda ans, x, y: unbroadcast_f(x, lambda g: g), 82 | lambda ans, x, y: unbroadcast_f(y, lambda g: g), 83 | ) 84 | defvjp( 85 | np.multiply, 86 | lambda ans, x, y: unbroadcast_f(x, lambda g: y * g), 87 | lambda ans, x, y: unbroadcast_f(y, lambda g: x * g), 88 | ) 89 | defvjp( 90 | np.subtract, 91 | lambda ans, x, y: unbroadcast_f(x, lambda g: g), 92 | lambda ans, x, y: unbroadcast_f(y, lambda g: -g), 93 | ) 94 | defvjp( 95 | np.divide, 96 | lambda ans, x, y: unbroadcast_f(x, lambda g: g / y), 97 | lambda ans, x, y: unbroadcast_f(y, lambda g: -g * x / y ** 2), 98 | ) 99 | defvjp( 100 | np.maximum, 101 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), 102 | lambda ans, x, y: unbroadcast_f(y, 
lambda g: g * balanced_eq(y, ans, x)), 103 | ) 104 | defvjp( 105 | np.minimum, 106 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), 107 | lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)), 108 | ) 109 | defvjp( 110 | np.fmax, 111 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), 112 | lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)), 113 | ) 114 | defvjp( 115 | np.fmin, 116 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)), 117 | lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)), 118 | ) 119 | defvjp( 120 | np.logaddexp, 121 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * np.exp(x - ans)), 122 | lambda ans, x, y: unbroadcast_f(y, lambda g: g * np.exp(y - ans)), 123 | ) 124 | defvjp( 125 | np.logaddexp2, 126 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * 2 ** (x - ans)), 127 | lambda ans, x, y: unbroadcast_f(y, lambda g: g * 2 ** (y - ans)), 128 | ) 129 | defvjp( 130 | np.true_divide, 131 | lambda ans, x, y: unbroadcast_f(x, lambda g: g / y), 132 | lambda ans, x, y: unbroadcast_f(y, lambda g: -g * x / y ** 2), 133 | ) 134 | defvjp( 135 | np.mod, 136 | lambda ans, x, y: unbroadcast_f(x, lambda g: g), 137 | lambda ans, x, y: unbroadcast_f(y, lambda g: -g * np.floor(x / y)), 138 | ) 139 | defvjp( 140 | np.remainder, 141 | lambda ans, x, y: unbroadcast_f(x, lambda g: g), 142 | lambda ans, x, y: unbroadcast_f(y, lambda g: -g * np.floor(x / y)), 143 | ) 144 | defvjp( 145 | np.power, 146 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * y * x ** np.where(y, y - 1, 1.0)), 147 | lambda ans, x, y: unbroadcast_f( 148 | y, lambda g: g * np.log(replace_non_positive(x, 1.0)) * ans 149 | ), 150 | ) 151 | defvjp( 152 | np.arctan2, 153 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * y / (x ** 2 + y ** 2)), 154 | lambda ans, x, y: unbroadcast_f(y, lambda g: g * -x / (x ** 2 + y ** 2)), 155 | ) 156 | defvjp( 157 | np.hypot, 158 | lambda ans, x, y: unbroadcast_f(x, lambda g: g * x / ans), 159 | lambda ans, x, y: unbroadcast_f(y, lambda g: g * y / ans), 160 | ) 161 | 162 | # ----- Simple grads ----- 163 | defvjp(np.sign, lambda ans, x: lambda g: np.nan if x == 0 else 0) 164 | defvjp(np.positive, lambda ans, x: lambda g: g) 165 | defvjp(np.negative, lambda ans, x: lambda g: -g) 166 | defvjp( 167 | np.absolute, 168 | lambda ans, x: lambda g: g * replace_zero(np.conj(x), 0.0) / replace_zero(ans, 1.0), 169 | ) 170 | defvjp( 171 | np.fabs, lambda ans, x: lambda g: np.sign(x) * g 172 | ) # fabs doesn't take complex numbers. 
173 | defvjp(np.absolute, lambda ans, x: lambda g: g * np.conj(x) / ans) 174 | defvjp(np.reciprocal, lambda ans, x: lambda g: -g / x ** 2) 175 | defvjp(np.exp, lambda ans, x: lambda g: ans * g) 176 | defvjp(np.exp2, lambda ans, x: lambda g: ans * np.log(2) * g) 177 | defvjp(np.expm1, lambda ans, x: lambda g: (ans + 1) * g) 178 | defvjp(np.log, lambda ans, x: lambda g: g / x) 179 | defvjp(np.log2, lambda ans, x: lambda g: g / x / np.log(2)) 180 | defvjp(np.log10, lambda ans, x: lambda g: g / x / np.log(10)) 181 | defvjp(np.log1p, lambda ans, x: lambda g: g / (x + 1)) 182 | defvjp(np.sin, lambda ans, x: lambda g: g * np.cos(x)) 183 | defvjp(np.cos, lambda ans, x: lambda g: -g * np.sin(x)) 184 | defvjp(np.tan, lambda ans, x: lambda g: g / np.cos(x) ** 2) 185 | defvjp(np.arcsin, lambda ans, x: lambda g: g / np.sqrt(1 - x ** 2)) 186 | defvjp(np.arccos, lambda ans, x: lambda g: -g / np.sqrt(1 - x ** 2)) 187 | defvjp(np.arctan, lambda ans, x: lambda g: g / (1 + x ** 2)) 188 | defvjp(np.sinh, lambda ans, x: lambda g: g * np.cosh(x)) 189 | defvjp(np.cosh, lambda ans, x: lambda g: g * np.sinh(x)) 190 | defvjp(np.tanh, lambda ans, x: lambda g: g / np.cosh(x) ** 2) 191 | defvjp(np.arcsinh, lambda ans, x: lambda g: g / np.sqrt(x ** 2 + 1)) 192 | defvjp(np.arccosh, lambda ans, x: lambda g: g / np.sqrt(x ** 2 - 1)) 193 | defvjp(np.arctanh, lambda ans, x: lambda g: g / (1 - x ** 2)) 194 | defvjp(np.rad2deg, lambda ans, x: lambda g: g / np.pi * 180.0) 195 | defvjp(np.degrees, lambda ans, x: lambda g: g / np.pi * 180.0) 196 | defvjp(np.deg2rad, lambda ans, x: lambda g: g * np.pi / 180.0) 197 | defvjp(np.radians, lambda ans, x: lambda g: g * np.pi / 180.0) 198 | defvjp(np.square, lambda ans, x: lambda g: g * 2 * x) 199 | defvjp(np.sqrt, lambda ans, x: lambda g: g * 0.5 * x ** -0.5) 200 | defvjp( 201 | np.sinc, 202 | lambda ans, x: lambda g: g 203 | * (np.cos(np.pi * x) * np.pi * x - np.sin(np.pi * x)) 204 | / (np.pi * x ** 2), 205 | ) 206 | defvjp( 207 | np.reshape, 208 | lambda ans, x, shape, order=None: lambda g: np.reshape(g, np.shape(x), order=order), 209 | ) 210 | defvjp( 211 | np.roll, lambda ans, x, shift, axis=None: lambda g: np.roll(g, -shift, axis=axis) 212 | ) 213 | defvjp( 214 | np.array_split, 215 | lambda ans, ary, idxs, axis=0: lambda g: np.concatenate(g, axis=axis), 216 | ) 217 | defvjp(np.split, lambda ans, ary, idxs, axis=0: lambda g: np.concatenate(g, axis=axis)) 218 | defvjp(np.vsplit, lambda ans, ary, idxs: lambda g: np.concatenate(g, axis=0)) 219 | defvjp(np.hsplit, lambda ans, ary, idxs: lambda g: np.concatenate(g, axis=1)) 220 | defvjp(np.dsplit, lambda ans, ary, idxs: lambda g: np.concatenate(g, axis=2)) 221 | defvjp( 222 | np.ravel, 223 | lambda ans, x, order=None: lambda g: np.reshape(g, np.shape(x), order=order), 224 | ) 225 | defvjp(np.expand_dims, lambda ans, x, axis: lambda g: np.reshape(g, np.shape(x))) 226 | defvjp(np.squeeze, lambda ans, x, axis=None: lambda g: np.reshape(g, np.shape(x))) 227 | defvjp(np.diag, lambda ans, x, k=0: lambda g: np.diag(g, k)) 228 | defvjp(np.flipud, lambda ans, x,: lambda g: np.flipud(g)) 229 | defvjp(np.fliplr, lambda ans, x,: lambda g: np.fliplr(g)) 230 | defvjp(np.rot90, lambda ans, x, k=1: lambda g: np.rot90(g, -k)) 231 | defvjp( 232 | np.full, 233 | lambda ans, shape, fill_value, dtype=None: lambda g: np.sum(g), 234 | argnums=(1,), 235 | ) 236 | defvjp(np.triu, lambda ans, x, k=0: lambda g: np.triu(g, k=k)) 237 | defvjp(np.tril, lambda ans, x, k=0: lambda g: np.tril(g, k=k)) 238 | defvjp( 239 | np.clip, 240 | lambda ans, x, a_min, a_max: 
lambda g: g 241 | * np.logical_and(ans != a_min, ans != a_max), 242 | ) 243 | defvjp(np.swapaxes, lambda ans, x, axis1, axis2: lambda g: np.swapaxes(g, axis2, axis1)) 244 | defvjp( 245 | np.moveaxis, 246 | lambda ans, a, source, destination: lambda g: np.moveaxis(g, destination, source), 247 | ) 248 | defvjp(np.real_if_close, lambda ans, x: lambda g: match_complex(x, g)) 249 | defvjp(np.real, lambda ans, x: lambda g: match_complex(x, g)) 250 | defvjp(np.imag, lambda ans, x: lambda g: match_complex(x, -1j * g)) 251 | defvjp(np.conj, lambda ans, x: lambda g: np.conj(g)) 252 | defvjp(np.conjugate, lambda ans, x: lambda g: np.conj(g)) 253 | defvjp( 254 | np.angle, 255 | lambda ans, x: lambda g: match_complex(x, g * np.conj(x * 1j) / np.abs(x) ** 2), 256 | ) 257 | defvjp( 258 | np.where, 259 | None, 260 | lambda ans, c, x=None, y=None: lambda g: np.where(c, g, np.zeros_like(g)), 261 | lambda ans, c, x=None, y=None: lambda g: np.where(c, np.zeros_like(g), g), 262 | ) 263 | defvjp( 264 | np.cross, 265 | lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None: lambda g: np.cross( 266 | b, g, axisb, axisc, axisa, axis 267 | ), 268 | lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None: lambda g: np.cross( 269 | g, a, axisc, axisa, axisb, axis 270 | ), 271 | ) 272 | defvjp( 273 | np.linspace, 274 | lambda ans, start, stop, num: lambda g: np.dot(np.linspace(1.0, 0.0, num), g), 275 | lambda ans, start, stop, num: lambda g: np.dot(np.linspace(0.0, 1.0, num), g), 276 | ) 277 | 278 | # ----- Trickier grads ----- 279 | def grad_rollaxis(ans, a, axis, start=0): 280 | if axis < 0: 281 | raise NotImplementedError( 282 | "Gradient of rollaxis not implemented for axis < 0. " 283 | "Please use moveaxis instead." 284 | ) 285 | elif start < 0: 286 | raise NotImplementedError( 287 | "Gradient of rollaxis not implemented for start < 0. " 288 | "Please use moveaxis instead." 289 | ) 290 | return ( 291 | lambda g: np.rollaxis(g, start - 1, axis) 292 | if start > axis 293 | else np.rollaxis(g, start, axis + 1) 294 | ) 295 | 296 | 297 | defvjp(np.rollaxis, grad_rollaxis) 298 | 299 | 300 | def stack_diff(ans, x, axis=0): 301 | def vjp(g): 302 | ret = [] 303 | shape = np.shape(g) 304 | for idx in range(shape[axis]): 305 | ret.append(np.take(g, idx, axis=axis)) 306 | return tuple(ret) 307 | 308 | return vjp 309 | 310 | 311 | defvjp(np.stack, stack_diff) 312 | 313 | 314 | def grad_gradient(ans, x, *vargs, **kwargs): 315 | axis = kwargs.pop("axis", None) 316 | if vargs or kwargs: 317 | raise NotImplementedError( 318 | "The only optional argument currently supported for np.gradient " "is axis." 
319 | ) 320 | if axis is None: 321 | axis = range(np.ndim(x)) 322 | elif type(axis) is int: 323 | axis = [axis] 324 | else: 325 | axis = list(axis) 326 | 327 | x_dtype = x.dtype 328 | x_shape = x.shape 329 | nd = np.ndim(x) 330 | 331 | def vjp(g): 332 | if np.ndim(g) == nd: 333 | # add axis if gradient was along one axis only 334 | g = g[np.newaxis] 335 | 336 | # accumulate gradient 337 | out = np.zeros(x_shape, dtype=x_dtype) 338 | 339 | for i, a in enumerate(axis): 340 | # swap gradient axis to the front 341 | g_swap = np.swapaxes(g[i], 0, a)[:, np.newaxis] 342 | 343 | out_axis = np.concatenate( 344 | ( 345 | -g_swap[0] - 0.5 * g_swap[1], 346 | g_swap[0] - 0.5 * g_swap[2], 347 | (-1.0) * np.gradient(g_swap, axis=0)[2:-2, 0], 348 | 0.5 * g_swap[-3] - g_swap[-1], 349 | 0.5 * g_swap[-2] + g_swap[-1], 350 | ), 351 | axis=0, 352 | ) 353 | 354 | out = out + np.swapaxes(out_axis, 0, a) 355 | 356 | return out 357 | 358 | return vjp 359 | 360 | 361 | defvjp(np.gradient, grad_gradient) 362 | 363 | 364 | def grad_repeat(ans, x, repeats, axis=None): 365 | shape = np.shape(x) 366 | 367 | def vjp(g): 368 | if axis is None: # If axis is none, np.repeat() repeats the flattened array. 369 | expanded = np.reshape(g, (np.prod(shape),) + (repeats,)) 370 | return np.reshape(np.sum(expanded, axis=1, keepdims=False), shape) 371 | else: 372 | if shape[axis] == 1: # For this common case, the logic is simple. 373 | return np.sum(g, axis=axis, keepdims=True) 374 | else: 375 | expanded = np.reshape( 376 | g, shape[0 : axis + 1] + (repeats,) + shape[axis + 1 :] 377 | ) 378 | return np.sum(expanded, axis=axis + 1, keepdims=False) 379 | 380 | return vjp 381 | 382 | 383 | defvjp(np.repeat, grad_repeat) 384 | 385 | 386 | def grad_tile(ans, x, reps): 387 | reps = [reps] if np.isscalar(reps) else reps 388 | x_shape = np.shape(x) 389 | 390 | def vjp(g): 391 | for axis, rep in enumerate(reps): 392 | g = sum(np.split(g, rep, axis)) 393 | return np.reshape(g, x_shape) 394 | 395 | return vjp 396 | 397 | 398 | defvjp(np.tile, grad_tile) 399 | 400 | 401 | def grad_transpose(ans, x, axes=None): 402 | if axes is not None: 403 | axes = np.argsort(axes) 404 | return lambda g: np.transpose(g, axes) 405 | 406 | 407 | defvjp(np.transpose, grad_transpose) 408 | 409 | 410 | def repeat_to_match_shape(g, shape, dtype, axis, keepdims): 411 | """Returns the array g repeated along axis to fit vector space vs. 
412 | Also returns the number of repetitions of the array.""" 413 | with ua.set_backend(numpy_backend, coerce=True): 414 | if shape == (): 415 | return g, 1 416 | axis = list(axis) if isinstance(axis, tuple) else axis 417 | new_shape = np.array(shape, dtype=int) 418 | new_shape[axis] = 1 419 | num_reps = np.prod(np.array(shape)[axis]) 420 | return np.broadcast_to(np.reshape(g, new_shape), shape), num_reps 421 | 422 | 423 | def grad_broadcast_to(ans, x, new_shape): 424 | old_shape = np.shape(x) 425 | assert np.shape(ans) == new_shape 426 | assert len(old_shape) == len(new_shape), "Can't handle extra leading dims" 427 | 428 | broadcast_axes = tuple( 429 | i for i in range(len(old_shape)) if old_shape[i] == 1 and new_shape[i] > 1 430 | ) 431 | 432 | return lambda g: np.sum(g, axis=broadcast_axes, keepdims=True) 433 | 434 | 435 | defvjp(np.broadcast_to, grad_broadcast_to) 436 | 437 | 438 | def grad_np_sum(ans, x, axis=None, keepdims=False, dtype=None): 439 | shape, dtype = np.shape(x.value), x.dtype 440 | return lambda g: repeat_to_match_shape(g, shape, dtype, axis, keepdims)[0] 441 | 442 | 443 | defvjp(np.sum, grad_np_sum) 444 | 445 | 446 | def grad_np_prod(ans, x, axis=None, keepdims=False): # TODO: Support tuples of axes. 447 | shape, dtype = np.shape(x), x.dtype 448 | 449 | def vjp(g): 450 | g_repeated, _ = repeat_to_match_shape(g * ans, shape, dtype, axis, keepdims) 451 | return g_repeated / x 452 | 453 | return vjp 454 | 455 | 456 | defvjp(np.prod, grad_np_prod) 457 | 458 | 459 | def grad_np_var(ans, x, axis=None, ddof=0, keepdims=False): 460 | shape, _, dtype, iscomplex = metadata(x) 461 | 462 | def vjp(g): 463 | if iscomplex: 464 | g = g + 0j 465 | g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims) 466 | x_minus_mean = np.conj(x - x / np.sum(x, axis=axis, keepdims=True)) 467 | return 2.0 * g_repeated * x_minus_mean / (num_reps - ddof) 468 | 469 | return vjp 470 | 471 | 472 | defvjp(np.var, grad_np_var) 473 | 474 | 475 | def grad_np_std(ans, x, axis=None, ddof=0, keepdims=False): 476 | shape, _, dtype, iscomplex = metadata(x) 477 | 478 | def vjp(g): 479 | if iscomplex: 480 | g = g + 0j 481 | g_repeated, num_reps = repeat_to_match_shape( 482 | g, shape, dtype, axis, keepdims 483 | ) # Avoid division by zero. 484 | if num_reps <= 1: 485 | return g_repeated * 0.0 486 | else: 487 | g_repeated, num_reps = repeat_to_match_shape( 488 | g / ans, shape, dtype, axis, keepdims 489 | ) 490 | x_minus_mean = np.conj(x - x / np.sum(x, axis=axis, keepdims=True)) 491 | return g_repeated * x_minus_mean / (num_reps - ddof) 492 | 493 | return vjp 494 | 495 | 496 | defvjp(np.std, grad_np_std) 497 | 498 | 499 | def grad_chooser(ans, x, axis=None, keepdims=None): 500 | shape, dtype = np.shape(x), x.dtype 501 | 502 | def vjp(g): 503 | """Builds gradient of functions that choose a single item, such as min or max.""" 504 | g_repeated, _ = repeat_to_match_shape(g, shape, dtype, axis, keepdims) 505 | argmax_locations = ( 506 | x == repeat_to_match_shape(ans, shape, dtype, axis, keepdims)[0] 507 | ) 508 | return ( 509 | g_repeated 510 | * argmax_locations 511 | / np.sum(argmax_locations, axis=axis, keepdims=True) 512 | ) 513 | 514 | return vjp 515 | 516 | 517 | defvjp(np.max, grad_chooser) 518 | defvjp(np.min, grad_chooser) 519 | 520 | 521 | def reverse_axis(x, axis): 522 | x = x.swapaxes(axis, 0) 523 | x = x[::-1, ...] 
    return x.swapaxes(0, axis)


def grad_np_cumsum(ans, x, axis=None):
    def vjp(g):
        if axis:
            return reverse_axis(np.cumsum(reverse_axis(g, axis), axis), axis)
        else:
            return np.reshape(np.cumsum(g[::-1], axis)[::-1], x.shape)

    return vjp


defvjp(np.cumsum, grad_np_cumsum)


def matmul_adjoint_0(B, G, A_meta, B_ndim):
    G_ndim = np.ndim(G)
    if G_ndim == 0:  # A_ndim == B_ndim == 1
        return unbroadcast(G * B, A_meta)
    _, A_ndim, _, _ = A_meta
    if A_ndim == 1:
        G = np.expand_dims(G, G_ndim - 1)
    if B_ndim == 1:  # The result we need is an outer product
        B = np.expand_dims(B, 0)
        G = np.expand_dims(G, G_ndim)
    else:  # We need to swap the last two axes of B
        B = np.swapaxes(B, B_ndim - 2, B_ndim - 1)
    result = np.matmul(G, B)
    return unbroadcast(result, A_meta)


def matmul_adjoint_1(A, G, A_ndim, B_meta):
    G_ndim = np.ndim(G)
    if G_ndim == 0:  # A_ndim == B_ndim == 1
        return unbroadcast(G * A, B_meta)
    _, B_ndim, _, _ = B_meta
    B_is_vec = B_ndim == 1
    if B_is_vec:
        G = np.expand_dims(G, G_ndim)
    if A_ndim == 1:  # The result we need is an outer product
        A = np.expand_dims(A, 1)
        G = np.expand_dims(G, G_ndim - 1)
    else:  # We need to swap the last two axes of A
        A = np.swapaxes(A, A_ndim - 2, A_ndim - 1)
    result = np.matmul(A, G)
    if B_is_vec:
        result = np.squeeze(result, G_ndim - 1)
    return unbroadcast(result, B_meta)


def matmul_vjp_0(ans, A, B):
    A_meta = metadata(A)
    B_ndim = np.ndim(B)
    return lambda g: matmul_adjoint_0(B, g, A_meta, B_ndim)


def matmul_vjp_1(ans, A, B):
    A_ndim = np.ndim(A)
    B_meta = metadata(B)
    return lambda g: matmul_adjoint_1(A, g, A_ndim, B_meta)


defvjp(np.matmul, matmul_vjp_0, matmul_vjp_1)


def grad_sort(ans, x, axis=-1, kind="quicksort", order=None):
    if len(np.shape(x)) > 1:
        raise NotImplementedError(
            "Gradient of sort not implemented for multi-dimensional arrays."
        )
    sort_perm = np.argsort(x, axis, kind, order)
    return lambda g: unpermuter(g, sort_perm)


defvjp(np.sort, grad_sort)
defvjp(np.msort, grad_sort)  # Until multi-D is allowed, these are the same.


def grad_partition(ans, x, kth, axis=-1, kind="introselect", order=None):
    if len(np.shape(x)) > 1:
        raise NotImplementedError(
            "Gradient of partition not implemented for multi-dimensional arrays."
        )
    partition_perm = np.argpartition(x, kth, axis, kind, order)
    return lambda g: unpermuter(g, partition_perm)


defvjp(np.partition, grad_partition)


def unpermuter(g, permutation):
    unsort = np.zeros(len(permutation), dtype=int)
    unsort[permutation] = list(range(len(permutation)))
    return g[unsort]


def grad_reshape_list(ans, *arys):
    if len(arys) > 1:
        raise NotImplementedError("Can't handle multiple arguments yet.")
    return lambda g: np.reshape(g, np.shape(arys[0]))


defvjp(np.atleast_1d, grad_reshape_list)
defvjp(np.atleast_2d, grad_reshape_list)
defvjp(np.atleast_3d, grad_reshape_list)


def match_complex(target, x):
    target_iscomplex = np.iscomplexobj(target)
    x_iscomplex = np.iscomplexobj(x)
    if x_iscomplex and not target_iscomplex:
        return np.real(x)
    elif not x_iscomplex and target_iscomplex:
        return x + 0j
    else:
        return x


def metadata(A):
    return np.shape(A), np.ndim(A), A.dtype, np.iscomplexobj(A)


def unbroadcast(x, target_meta, broadcast_idx=0):
    """Sums x over the axes that were introduced by broadcasting so the
    gradient matches the shape recorded in target_meta."""
    target_shape, target_ndim, _, target_iscomplex = target_meta
    while np.ndim(x) > target_ndim:
        x = np.sum(x, axis=broadcast_idx)
    for axis, size in enumerate(target_shape):
        if size == 1:
            x = np.sum(x, axis=axis, keepdims=True)
    if np.iscomplexobj(x) and not target_iscomplex:
        x = np.real(x)
    return x


def unbroadcast_f(target, f):
    target_meta = metadata(target)
    return lambda g: unbroadcast(f(g), target_meta)


def unbroadcast_einsum(x, target_meta, subscript):
    if Ellipsis not in subscript:
        return x
    elif subscript[0] == Ellipsis:
        return unbroadcast(x, target_meta, 0)
    elif subscript[-1] == Ellipsis:
        return unbroadcast(x, target_meta, -1)
    else:
        return unbroadcast(x, target_meta, subscript.index(Ellipsis))


def balanced_eq(x, z, y):
    return (x == z) / (1.0 + (x == y))


def replace_zero(x, val):
    return np.where(x, x, val)


def replace_non_positive(x, val):
    return np.where(x.value > 0, x, val)


# ----- extra functions used internally -----


def _unpad(array, width):
    if np.isscalar(width):
        width = [[width, width]]
    elif np.shape(width) == (1,):
        width = [np.concatenate((width, width))]
    elif np.shape(width) == (2,):
        width = [width]
    if np.shape(width)[0] == 1:
        width = np.repeat(width, np.ndim(array), 0)
    idxs = tuple(slice(l, -u or None) for l, u in width)
    return array[idxs]


def pad_vjp(ans, array, pad_width, mode, **kwargs):
    assert mode == "constant", "Only constant mode padding is supported."
    return lambda g: _unpad(g, pad_width)


defvjp(np.pad, pad_vjp)

--------------------------------------------------------------------------------