├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── sct-ad-conditional.gif ├── sct-ad-debugging.png ├── sct-ad-forward.gif ├── sct-ad-insert_grad_of.gif ├── sct-ad-live.gif ├── sct-ad-loop.gif ├── sct-ad-numpy.gif ├── sct-ad-subroutine.gif ├── sct-ad-tf.gif ├── sct-ad.gif ├── small-benchmark.png └── toolspace.png ├── environment.yml ├── requirements.txt ├── setup.cfg ├── setup.py ├── tangent ├── __init__.py ├── anf.py ├── annotate.py ├── annotations.py ├── ast.py ├── cfg.py ├── comments.py ├── compile.py ├── create.py ├── desugar.py ├── errors.py ├── fence.py ├── fixes.py ├── forward_ad.py ├── funcsigs.py ├── grad_util.py ├── grads.py ├── grammar.py ├── naming.py ├── non_differentiable.py ├── optimization.py ├── quoting.py ├── reverse_ad.py ├── tangents.py ├── template.py ├── tf_extensions.py ├── tracing.py ├── transformers.py └── utils.py └── tests ├── conftest.py ├── functions.py ├── test_anf.py ├── test_annotate.py ├── test_cfg.py ├── test_comments.py ├── test_compile.py ├── test_fence.py ├── test_forward_mode.py ├── test_hessian_vector_products.py ├── test_optimization.py ├── test_reverse_mode.py ├── test_reverse_over_reverse.py ├── test_template.py ├── test_transformers.py ├── tfe_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | .static_storage/ 58 | .media/ 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 2.7 4 | - 3.5 5 | - 3.6 6 | install: 7 | - pip install -e . 
8 | - pip install -r requirements.txt 9 | - pip install pytest pytest-cov flake8 10 | script: 11 | # stop the build if there are Python syntax errors or undefined names 12 | - flake8 . --builtins=d,tangent,numpy --count --exclude=tests --select=E901,E999,F821,F822,F823 --show-source --statistics 13 | - py.test --cov=coverage --short tests 14 | after_success: 15 | - coveralls 16 | notifications: 17 | email: 18 | on_success: change 19 | on_failure: change 20 | deploy: 21 | provider: pypi 22 | user: tangent_oss 23 | password: 24 | secure: RYOgIM2/gSvFsWiiCm33OXgPxCer2kX4m6cfMd3yaTzILigUMWnCgfxnfnrZrbUj8DQDSXUoMnXTWHcHPCHo5wTrfGXcFudaBEAyhoozi/2HT1BUqLJIvDC8VgxDvXo9zrcMOhr6WnE+/c9V8JdAFprtnfwBjAR27CqIaLJ9qnPbDKPtlrCsjDMC6HYLFRR8qWqBAh7uwGHlY0Idr57kIv8vVjQbIvuVnBaLPIqK+N22m4h1z633ASwv+cNSp9MHWCBQ70ON+YrEDW/HZO2+kgyCVWFiwJpWtJxOq/YvYunsH+7cHv5GFkyBplnbqaN6gR/Xoknu7gC1YbiVxFAAyIyPdy+HCQ02nFbSf0rkTxYWRmIInq1MEruHkckIEML8DKy2XSZJkIDesBAUIUU5hAZCvKB6TieMZ8xJbD9hi/y26Wiwh4gr9AVtVM/jkJQbfGVFUfkgfIgj/FsKeRRxdx5J1sOtKtKGca4dz/83vRNGzHF6jfNMCGivqfCnyxcvZIg9Nh5xjboO3n4nwyKkbf97cVYmUHM8WKan7qRgWIPbT/WhKY1GgEdCZaHY9ZVCrCBZJ+4Z6w0Xt4WuqMACYH1sMyScY4CAh2WjgsXIL5vwSv5bwqAB6muLuj7mLjzuUQsTMU9xhwrHDJ/2AVQTpNdyV0zqbF9Tk07sBfZlz5w= 25 | on: 26 | tags: true 27 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | ## How to become a contributor and submit your own code 4 | 5 | ### Contributor License Agreements 6 | 7 | We'd love to accept your patches! Before we can take them, we have to jump a couple of legal hurdles. 8 | 9 | Please fill out either the individual or corporate Contributor License Agreement (CLA). 10 | 11 | * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html). 12 | * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html). 13 | 14 | Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests. 15 | 16 | ***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository. 17 | 18 | ### Contributing code 19 | 20 | #### Adding new derivatives 21 | 22 | We still have a lot of derivatives we need to write! To add a new derivative for a primitive operation, 23 | 24 | - [Read the docs on how to write derivatives in Tangent, and look at some examples](https://github.com/google/tangent/blob/7bf4eaffd646a5906aa15a852f117833d37fb09a/tangent/grads.py#L14-L33)! 25 | - Add your primitive op's reverse-mode derivative to grads.py (example reverse-mode derivative [for np.sin](https://github.com/google/tangent/blob/7bf4eaffd646a5906aa15a852f117833d37fb09a/tangent/grads.py#L230-L232), and for [tf.sin](https://github.com/google/tangent/blob/7bf4eaffd646a5906aa15a852f117833d37fb09a/tangent/tf_extensions.py#L183-L185)). 
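  For orientation, an entry in `grads.py` is a small template function registered against the primitive with a decorator; it writes the gradient of each input into `d[...]`. The sketch below is modeled on the linked `np.sin` entry (the decorator name and variable conventions are taken from that example, so double-check them against the file itself):

  ```python
  @adjoint(numpy.sin)
  def sin(y, x):
    # y is the primal output, x the input; d[...] holds the gradients
    d[x] = d[y] * numpy.cos(x)
  ```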
26 | - Add its forward-mode derivative to tangents.py (example forward-mode derivative for [np.sin](https://github.com/google/tangent/blob/7bf4eaffd646a5906aa15a852f117833d37fb09a/tangent/tangents.py#L144-L146), and for [tf.sin](https://github.com/google/tangent/blob/7bf4eaffd646a5906aa15a852f117833d37fb09a/tangent/tf_extensions.py#L344-L346)) 27 | - Add a function using the primitive operation in functions.py. Our tests will pick up on it automatically ([example test function](https://github.com/google/tangent/blob/7bf4eaffd646a5906aa15a852f117833d37fb09a/tests/functions.py#L406-L407)). 28 | - Make sure the tests pass. The tests will be run automatically with Travis once you submit a PR, but it's good to do this locally, so you don't have to wait as long. 29 | ``` 30 | # Make sure you have pytest installed 31 | pip install pytest 32 | # Run this command from the root of the Tangent project 33 | py.test --short tests 34 | ``` 35 | 36 | #### Adding other new functionality 37 | 38 | Tangent is a work-in-progress, so there's a lot of upgrades and tweaks that would be useful. If you've already fixed a bug, or added an enhancement, open a PR, and the team will take a look at it, and help you polish it, conform to style guidelines etc., before we merge it. If you're thinking about embarking on a feature enhancement, open a GitHub issue to start a discussion, or [talk to us on gitter](https://gitter.im/google/tangent). 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2017 Google Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /docs/sct-ad-conditional.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-conditional.gif -------------------------------------------------------------------------------- /docs/sct-ad-debugging.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-debugging.png -------------------------------------------------------------------------------- /docs/sct-ad-forward.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-forward.gif -------------------------------------------------------------------------------- /docs/sct-ad-insert_grad_of.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-insert_grad_of.gif -------------------------------------------------------------------------------- /docs/sct-ad-live.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-live.gif -------------------------------------------------------------------------------- /docs/sct-ad-loop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-loop.gif -------------------------------------------------------------------------------- /docs/sct-ad-numpy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-numpy.gif -------------------------------------------------------------------------------- /docs/sct-ad-subroutine.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-subroutine.gif -------------------------------------------------------------------------------- /docs/sct-ad-tf.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad-tf.gif -------------------------------------------------------------------------------- /docs/sct-ad.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/sct-ad.gif -------------------------------------------------------------------------------- /docs/small-benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/small-benchmark.png 
-------------------------------------------------------------------------------- /docs/toolspace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/tangent/6533e83af09de7345d1b438512679992f080dcc9/docs/toolspace.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tangent 2 | 3 | dependencies: 4 | - python>=2.7 5 | - enum34 6 | - future 7 | - nose 8 | - numpy 9 | - six 10 | - pip: 11 | - autograd 12 | - astor>=0.6 13 | - gast 14 | - tf-nightly==1.5.0.dev20171026 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autograd>=1.2 2 | astor>=0.6 3 | enum34 4 | future 5 | gast 6 | nose 7 | numpy 8 | six 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | with open('README.md') as f: 5 | readme = f.read() 6 | 7 | with open('LICENSE') as f: 8 | lic = f.read() 9 | 10 | with open('requirements.txt') as f: 11 | reqs = list(f.read().strip().split('\n')) 12 | 13 | setup( 14 | name='tangent', 15 | version='0.1.9', 16 | description=('Automatic differentiation using source code transformation ' 17 | 'for Python'), 18 | long_description=readme, 19 | author='Google Inc.', 20 | author_email='alexbw@google.com', 21 | url='https://github.com/google/tangent', 22 | license=lic, 23 | packages=find_packages(exclude=('tests')), 24 | package_data={'':['README.md','LICENSE']}, 25 | keywords=[ 26 | 'autodiff', 'automatic-differentiation', 'machine-learning', 27 | 'deep-learning' 28 | ], 29 | install_requires=reqs, 30 | ) 31 | -------------------------------------------------------------------------------- /tangent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Several imports to flatten the Tangent namespace for end users.""" 15 | from __future__ import absolute_import 16 | import functools 17 | 18 | import gast 19 | 20 | from tangent import annotate 21 | from tangent import ast as ast_ 22 | from tangent import compile as compile_ 23 | from tangent.tracing import trace 24 | from tangent.tracing import trace_grad 25 | from tangent.utils import add_grad 26 | from tangent.utils import array_size 27 | from tangent.utils import astype 28 | from tangent.utils import balanced_eq 29 | from tangent.utils import copy 30 | from tangent.utils import grad_dot 31 | from tangent.utils import init_grad 32 | from tangent.utils import insert_grad_of 33 | from tangent.utils import pop 34 | from tangent.utils import pop_stack 35 | from tangent.utils import push 36 | from tangent.utils import push_stack 37 | from tangent.utils import shapes_match 38 | from tangent.utils import Stack 39 | from tangent.utils import unbroadcast 40 | from tangent.utils import unreduce 41 | from tangent.utils import unreduce_like 42 | 43 | # Imported last to avoid circular imports 44 | from tangent.grad_util import grad, autodiff, vjp, jvp 45 | from tangent.errors import * 46 | try: 47 | from tangent.tf_extensions import * 48 | except ImportError: 49 | pass 50 | 51 | 52 | class RemoveWith(gast.NodeTransformer): 53 | """A transformer that removes `with insert_grad_of` statements.""" 54 | 55 | def visit_With(self, node): 56 | if ast_.is_insert_grad_of_statement(node): 57 | return None 58 | else: 59 | return node 60 | 61 | 62 | def tangent(f): 63 | """A decorator which removes the `with insert_grad_of` statement. 64 | 65 | This allows the function to be called as usual. 66 | 67 | Args: 68 | f: A function 69 | 70 | Returns: 71 | A function with any `with insert_grad_of` context managers removed. 72 | """ 73 | node = annotate.resolve_calls(f) 74 | RemoveWith().visit(node) 75 | wrapped = functools.wraps(f)(compile_.compile_function(node)) 76 | wrapped.tangent = f 77 | return wrapped 78 | -------------------------------------------------------------------------------- /tangent/anf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Transform AST into something similar to A-normal form. 15 | 16 | This significantly simplifies certain procedures later on. The ANF 17 | transformations guarantee the following: 18 | 19 | All nested expressions on the right hand side of assignments are expanded and 20 | reduced to the following: 21 | 22 | y = x 23 | y = f(x1, ..., xn) 24 | z = x + y 25 | y = -x 26 | y.i = x 27 | y = x.i 28 | y[i] = x 29 | y = x[i] 30 | z = x, y 31 | 32 | Note that we do not allow tuple unpacking, because statements like `x[i], y = 33 | f(x)` are difficult to process in this case. Hence, unpacking is made explicit. 
34 | 35 | The value of the return statement is reduced to either a single variable, or a 36 | tuple of variables (nested tuples are expanded). 37 | 38 | """ 39 | from __future__ import absolute_import 40 | import gast 41 | 42 | from tangent import annotations as anno 43 | from tangent import grammar 44 | from tangent import naming 45 | from tangent import quoting 46 | from tangent import transformers 47 | 48 | 49 | class ANF(transformers.TreeTransformer): 50 | """Transform a tree to an ANF-like form.""" 51 | 52 | def __init__(self): 53 | super(ANF, self).__init__() 54 | # Whether the current statement in question must be trivialized 55 | self.trivializing = False 56 | # The original line that is transformed, which is kept as an annotation 57 | self.src = '' 58 | 59 | def mark(self, node): 60 | if not anno.hasanno(node, 'pre_anf') and self.src: 61 | anno.setanno(node, 'pre_anf', self.src) 62 | 63 | def trivialize(self, node): 64 | if isinstance(node, (gast.Name, type(None)) + grammar.LITERALS): 65 | return node 66 | name = self.namer.name(node) 67 | stmt = gast.Assign( 68 | targets=[gast.Name(annotation=None, id=name, ctx=gast.Store())], 69 | value=None) 70 | self.mark(stmt) 71 | self.prepend(stmt) 72 | stmt.value = self.visit(node) 73 | return gast.Name(annotation=None, id=name, ctx=gast.Load()) 74 | 75 | def visit_Call(self, node): 76 | if self.trivializing: 77 | for i, arg in enumerate(node.args): 78 | node.args[i] = self.trivialize(arg) 79 | for keyword in node.keywords: 80 | keyword.value = self.trivialize(keyword.value) 81 | return node 82 | 83 | def visit_FunctionDef(self, node): 84 | self.namer = naming.Namer.build(node) 85 | return self.generic_visit(node) 86 | 87 | def visit_BinOp(self, node): 88 | if self.trivializing: 89 | node.left = self.trivialize(node.left) 90 | node.right = self.trivialize(node.right) 91 | return node 92 | 93 | def visit_UnaryOp(self, node): 94 | if self.trivializing: 95 | node.operand = self.trivialize(node.operand) 96 | return node 97 | 98 | def visit_Return(self, node): 99 | self.trivializing = True 100 | self.namer.target = node 101 | node.value = self.trivialize(node.value) 102 | self.trivializing = False 103 | self.namer.target = None 104 | return node 105 | 106 | def trivialize_slice(self, node): 107 | if isinstance(node, gast.Slice): 108 | name = self.namer.name(node) 109 | target = gast.Name(id=name, ctx=gast.Store(), annotation=None) 110 | stmt = gast.Assign(targets=[target], value=None) 111 | self.prepend(stmt) 112 | stmt.value = gast.Call( 113 | func=gast.Name(id='slice', ctx=gast.Load(), annotation=None), 114 | args=[ 115 | self.trivialize(arg) if arg else 116 | gast.Name(id='None', ctx=gast.Load(), annotation=None) 117 | for arg in [node.lower, node.upper, 118 | node.step]], 119 | keywords=[]) 120 | return gast.Name(id=name, ctx=gast.Load(), annotation=None) 121 | elif isinstance(node, gast.ExtSlice): 122 | name = self.namer.name(node) 123 | target = gast.Name(id=name, ctx=gast.Store(), annotation=None) 124 | stmt = gast.Assign(targets=[target], value=None) 125 | self.prepend(stmt) 126 | dim_names = [self.trivialize_slice(s).id for s in node.dims] 127 | stmt.value = gast.Tuple(elts=[ 128 | gast.Name(id=n, ctx=gast.Load(), annotation=None) 129 | for n in dim_names], ctx=gast.Load()) 130 | return gast.Name(id=name, ctx=gast.Load(), annotation=None) 131 | elif isinstance(node, gast.Index): 132 | return self.trivialize(node.value) 133 | else: 134 | raise ValueError(node) 135 | 136 | def visit_Subscript(self, node): 137 | if self.trivializing: 138 
| node.value = self.trivialize(node.value) 139 | node.slice = gast.Index(value=self.trivialize_slice(node.slice)) 140 | return node 141 | 142 | def visit_Tuple(self, node): 143 | if self.trivializing: 144 | node.elts = [self.trivialize(elt) for elt in node.elts] 145 | return node 146 | 147 | def visit_List(self, node): 148 | if self.trivializing: 149 | node.elts = [self.trivialize(elt) for elt in node.elts] 150 | return node 151 | 152 | def visit_AugAssign(self, node): 153 | self.src = quoting.unquote(node) 154 | self.trivializing = True 155 | self.namer.target = node.target 156 | right = self.trivialize(node.value) 157 | target = self.trivialize(node.target) 158 | left = gast.Name(id=target.id, ctx=gast.Load(), annotation=None) 159 | node = gast.Assign(targets=[target], 160 | value=gast.BinOp( 161 | left=left, op=node.op, right=right)) 162 | self.mark(node) 163 | node = self.generic_visit(node) 164 | self.namer.target = None 165 | self.trivializing = False 166 | return node 167 | 168 | def visit_Assign(self, node): 169 | self.src = quoting.unquote(node) 170 | self.mark(node) 171 | self.trivializing = True 172 | self.namer.target = node.targets[0] 173 | if isinstance(node.targets[0], (gast.Subscript, gast.Attribute)): 174 | node.value = self.trivialize(node.value) 175 | node.targets[0] = self.visit(node.targets[0]) 176 | elif isinstance(node.targets[0], gast.Tuple): 177 | node.value = self.visit(node.value) 178 | name = self.namer.name(node.targets[0]) 179 | target = gast.Name(id=name, ctx=gast.Store(), annotation=None) 180 | for i, elt in enumerate(node.targets[0].elts): 181 | stmt = gast.Assign( 182 | targets=[elt], 183 | value=gast.Subscript( 184 | value=gast.Name(id=name, ctx=gast.Load(), 185 | annotation=None), 186 | slice=gast.Index(value=gast.Num(n=i)), 187 | ctx=gast.Load())) 188 | self.mark(stmt) 189 | self.append(stmt) 190 | node.targets[0] = target 191 | elif not isinstance(node.targets[0], gast.Name): 192 | raise ValueError('Cannot Assign to %s' % type(node.target)) 193 | node = self.generic_visit(node) 194 | self.namer.target = None 195 | self.trivializing = False 196 | return node 197 | 198 | 199 | def anf(node): 200 | """Turn an AST into ANF-like form.""" 201 | ANF().visit(node) 202 | return node 203 | -------------------------------------------------------------------------------- /tangent/annotate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Annotate the AST. 15 | 16 | This file contains passes that walk the AST and attach annotations to various 17 | nodes. 
18 | """ 19 | from __future__ import absolute_import 20 | from collections import defaultdict 21 | import builtins 22 | 23 | import gast 24 | import six 25 | 26 | from tangent import annotations as anno 27 | from tangent import cfg 28 | from tangent import quoting 29 | from tangent import tracing 30 | from tangent import utils 31 | 32 | 33 | class ResolveCalls(gast.NodeVisitor): 34 | """Annotate Call nodes with the function being called.""" 35 | 36 | def __init__(self, func): 37 | self.func = func 38 | self.namespace = six.get_function_globals(func) 39 | if six.get_function_closure(func): 40 | self.namespace.update(dict(zip( 41 | func.__code__.co_freevars, 42 | (cell.cell_contents for cell in six.get_function_closure(func))))) 43 | 44 | def visit_FunctionDef(self, node): 45 | self.generic_visit(node) 46 | anno.setanno(node, 'func', self.func) 47 | 48 | def visit_Call(self, node): 49 | self.generic_visit(node) 50 | 51 | def resolve(node): 52 | if isinstance(node, gast.Attribute): 53 | return getattr(resolve(node.value), node.attr) 54 | if isinstance(node, gast.Name): 55 | if node.id in self.namespace: 56 | return self.namespace[node.id] 57 | else: 58 | # TODO: we should detect when tracing is a fallback. 59 | if hasattr(builtins, node.id): 60 | return getattr(builtins, node.id) 61 | else: 62 | raise AttributeError( 63 | 'Failed to resolve name "%s" used by "%s".'% ( 64 | node.id, self.func.__name__)) 65 | 66 | func = resolve(node.func) 67 | # If the user has used the @tangent.trace decorator, 68 | # then we'll switch to tracing the function. 69 | if hasattr(func, 'should_trace'): 70 | func = tracing.Traceable 71 | elif hasattr(func, 'fun'): 72 | # TODO: use a less dicey API to check if a function is autograd-wrapped 73 | # Autograd primitives keep around their original wrapped function. 74 | # We need that to be the func annotation, otherwise we'd have to 75 | # redefine derivatives for all autograd wrapped versions of NumPy. 76 | # Beyond that, autograd wrapped functions only have fn(*args,**kwargs) 77 | # for their signature. We need access tothe default values of functions 78 | # for proper code generation. 79 | func = func.fun 80 | anno.setanno(node, 'func', func) 81 | 82 | 83 | def resolve_calls(func): 84 | """Parse a function into an AST with function calls resolved. 85 | 86 | Since the calls are resolved using the global and local namespace of the 87 | function it means that procedural parameters (i.e. functions passed as 88 | arguments) won't be resolved. 89 | 90 | Similarly, functions defined inside of the function that we are trying to 91 | resolve won't be resolved, since they are not in the local namespace of the 92 | outer function. 93 | 94 | The function definition itself is also annotated, so that it can be matched 95 | to calls to it in other functions. 96 | 97 | Args: 98 | func: The function whose calls are being resolved. 99 | 100 | Returns: 101 | node: An AST where each `Call` node has a `func` annotation with the 102 | function handle that the call resolves to. 103 | 104 | Raises: 105 | AttributeError: When a function is used on the RHS of an assignment cannot 106 | be resolved (because it was passed as an argument or was defined in the 107 | body of the function). 
108 | """ 109 | node = quoting.parse_function(func) 110 | ResolveCalls(func).visit(node) 111 | return node 112 | 113 | 114 | def _get_stack_op_handle(node): 115 | assert isinstance(node, gast.Call), 'Only can get fn handles of Call nodes' 116 | fn_handle = anno.getanno(node, 'func', False) 117 | fn_map = defaultdict(lambda: False) 118 | fn_map['tangent.pop'] = utils.pop 119 | fn_map['tangent.push'] = utils.push 120 | fn_map['tangent.pop_stack'] = utils.pop_stack 121 | fn_map['tangent.push_stack'] = utils.push_stack 122 | if not fn_handle: 123 | fn_handle = fn_map[quoting.unquote(node.func)] 124 | return fn_handle 125 | 126 | 127 | class FindStackOps(gast.NodeVisitor): 128 | """Find the pushes and pops in a node, and record all matched op IDs. 129 | A necessary prerequisite to annotating the push/pop Call and containing 130 | Assign and Expr nodes in `FindStack`. 131 | """ 132 | 133 | def __init__(self): 134 | self.push_pop_pairs = dict() 135 | 136 | def visit_Call(self, node): 137 | fn_handle = _get_stack_op_handle(node) 138 | if fn_handle and fn_handle in [ 139 | utils.pop, utils.push, utils.push_stack, utils.pop_stack 140 | ]: 141 | # Retrieve the op_id, e.g. tangent.push(_stack,val,'abc') 142 | # ^^^ 143 | if fn_handle in [utils.push, utils.push_stack]: 144 | _, _, op_id_node = node.args 145 | elif fn_handle in [utils.pop, utils.pop_stack]: 146 | _, op_id_node = node.args 147 | op_id = op_id_node.s 148 | if op_id not in self.push_pop_pairs: 149 | self.push_pop_pairs[op_id] = dict() 150 | assert fn_handle not in self.push_pop_pairs, ( 151 | 'Conflicting op_ids. ' 152 | 'Already have fn %s with ' 153 | 'op_id %s') % (quoting.unquote(node.func), op_id) 154 | self.push_pop_pairs[op_id][fn_handle] = node 155 | 156 | 157 | class AnnotateStacks(gast.NodeVisitor): 158 | """A class to find pushes and pops to the stack and annotate them as such. 159 | 160 | Args: 161 | push_pop_pairs: A dict of dicts containing a mapping from op_ids to push/pop 162 | Call nodes. Compute this using `FindStackOps`. 163 | strict: A boolean indicating whether to stringently test whether each 164 | push and pop are matched. This is not always possible when taking 165 | higher-order derivatives of code generated in split-motion (e.g. 166 | a function y = primal_f(x) only pushes variables onto a stack for use 167 | within dx = adjoint_f(dy,x), taking the second-order derivative of the 168 | call tree that contains these two will only see primal_f in isolation, 169 | and thus will only see a push, and never the connected pop) 170 | 171 | Push and pop functions are paired using the no-op string argument `op_id`. 172 | We use these matched strings to annotate the Call nodes, the containing 173 | Assign (for pop) and Expr (for push) nodes. 174 | 175 | We also track which variables was moved on/off the stack by adding the 176 | 'push_var' and 'pop_var' annotations, which are used in `CleanStack` 177 | to remove pushes of variables that are never defined. 178 | 179 | Each push Expr is given a 'pop' annotation, pointing to the pop Assign node. 180 | Each pop Assign is given a 'push' annotation, pointing to the push Expr node. 
181 | """ 182 | 183 | def __init__(self, push_pop_pairs, strict): 184 | self.push_pop_pairs = push_pop_pairs 185 | self.strict = strict 186 | self.fn_map = {} 187 | self.fn_map[utils.pop] = utils.push 188 | self.fn_map[utils.push] = utils.pop 189 | self.fn_map[utils.pop_stack] = utils.pop_stack 190 | 191 | self.fn_map[utils.push_stack] = utils.pop_stack 192 | 193 | def visit_Assign(self, node): 194 | if not isinstance(node.value, gast.Call): 195 | return 196 | fn_handle = _get_stack_op_handle(node.value) 197 | if fn_handle and fn_handle in [utils.pop, utils.pop_stack]: 198 | # Retrieve the op_id, e.g. val = tangent.pop(_stack,'abc') 199 | # ^^^ 200 | _, op_id_node = node.value.args 201 | op_id = op_id_node.s 202 | anno.setanno(node, 'pop_var', node.targets[0]) 203 | 204 | if op_id not in self.push_pop_pairs: 205 | raise ValueError('op_id %s not known' % op_id) 206 | push_pop_nodes = self.push_pop_pairs[op_id] 207 | keys = push_pop_nodes.keys() 208 | # Check that the op_id is associated with only two operations 209 | if self.strict and len(keys) != 2: 210 | raise ValueError('Instead of 2 push/pop fns, found %d' % len(keys)) 211 | 212 | # Make sure that those two operations are either `push` and `pop` 213 | # or `push_stack` and `pop_stack`. 214 | if (self.strict and set(keys) != set((utils.push, utils.pop)) and 215 | set(keys) != set((utils.push_stack, utils.pop_stack))): 216 | raise ValueError('Invalid push/pop function pair. Found %s' % keys) 217 | 218 | try: 219 | matching_push = self.push_pop_pairs[op_id][self.fn_map[fn_handle]] 220 | except KeyError as e: 221 | if not self.strict: 222 | return 223 | else: 224 | raise e 225 | anno.setanno(node, 'push', matching_push, False) 226 | anno.setanno(node.value, 'push', matching_push, False) 227 | 228 | def visit_Expr(self, node): 229 | if isinstance(node.value, gast.Call): 230 | fn_handle = _get_stack_op_handle(node.value) 231 | if fn_handle and fn_handle in [utils.push, utils.push_stack]: 232 | op_id = node.value.args[-1].s 233 | anno.setanno(node, 'push_var', node.value.args[1]) 234 | try: 235 | matching_pop = self.push_pop_pairs[op_id][self.fn_map[fn_handle]] 236 | except KeyError as e: 237 | if not self.strict: 238 | return 239 | else: 240 | raise e 241 | anno.setanno(node, 'pop', matching_pop, False) 242 | anno.setanno(node.value, 'pop', matching_pop, False) 243 | 244 | 245 | def find_stacks(node, strict=False): 246 | """Find pushes and pops to the stack and annotate them as such. 247 | 248 | Args: 249 | node: An AST node that might contain stack pushes and pops. 250 | strict: A boolean indicating whether to stringently test whether each 251 | push and pop are matched. This is not always possible when taking 252 | higher-order derivatives of code generated in split-motion. 253 | 254 | Returns: 255 | node: The node passed in, but with pushes and pops annotated in AST nodes. 256 | """ 257 | # First, find all stack operation IDs. 258 | fso = FindStackOps() 259 | fso.visit(node) 260 | # Using those IDs, make annotations onto the push and pop nodes. 261 | AnnotateStacks(fso.push_pop_pairs, strict).visit(node) 262 | return node 263 | 264 | 265 | class Unused(gast.NodeVisitor): 266 | """Walks AST to find uses of variable definitions. 267 | 268 | See `unused` for details. 
269 | """ 270 | 271 | def __init__(self): 272 | # A set that contains all the definitions so far 273 | self.definitions = set() 274 | # A set of all the definitions potentially used so far 275 | self.used = set() 276 | # The definitions that reach the current statement 277 | self.reaching_definitions = () 278 | 279 | @property 280 | def unused(self): 281 | """Calculate which AST nodes are unused. 282 | 283 | Note that we have to take special care in the case of 284 | x,y = f(z) where x is used later, but y is not.""" 285 | unused = self.definitions - self.used 286 | # Filter (variable_name,node) pairs that should be removed, because 287 | # node is used elsewhere 288 | used_nodes = set([u[1] for u in self.used]) 289 | unused = set([u for u in unused if u[1] not in used_nodes]) 290 | return unused 291 | 292 | def visit(self, node): 293 | if anno.hasanno(node, 'definitions_gen'): 294 | self.definitions.update(anno.getanno(node, 'definitions_gen')) 295 | self.reaching_definitions = anno.getanno(node, 'definitions_in') 296 | if isinstance(node, gast.Name) and isinstance(node.ctx, gast.Load): 297 | self.used.update(def_ for def_ in self.reaching_definitions 298 | if def_[0] == node.id) 299 | super(Unused, self).visit(node) 300 | if anno.hasanno(node, 'definitions_gen'): 301 | self.reaching_definitions = None 302 | 303 | 304 | def unused(node): 305 | """Find unused definitions that can be remove. 306 | 307 | This runs reaching definitions analysis followed by a walk over the AST to 308 | find all variable definitions that are not used later on. 309 | 310 | Args: 311 | node: The AST of e.g. a function body to find unused variable definitions. 312 | 313 | Returns: 314 | unused: After visiting all the nodes, this attribute contanis a set of 315 | definitions in the form of `(variable_name, node)` pairs which are 316 | unused in this AST. 317 | """ 318 | cfg.forward(node, cfg.ReachingDefinitions()) 319 | unused_obj = Unused() 320 | unused_obj.visit(node) 321 | return unused_obj.unused 322 | -------------------------------------------------------------------------------- /tangent/annotations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Handling annotations on AST nodes.""" 15 | from __future__ import absolute_import 16 | 17 | import gast 18 | 19 | ANNOTATION_FIELD = '_tangent' 20 | # These annotation's won't be cleared between passes 21 | FIXED_ANNOTATIONS = set(['pop', 'push', 'add_grad', 'init_grad', 'pri', 'adj', 22 | 'push_func', 'pop_func', 'adjoint_var', 23 | 'temp_adjoint_var', 'temp_var', 'pri_call', 24 | 'adj_call', 'comment', 'pre_anf']) 25 | 26 | 27 | def setanno(node, key, value, safe=True): 28 | annotations = getattr(node, ANNOTATION_FIELD, {}) 29 | setattr(node, ANNOTATION_FIELD, annotations) 30 | if safe and hasanno(node, key): 31 | raise ValueError('annotation already present') 32 | annotations[key] = value 33 | 34 | # So that the annotations survive gast_to_ast() and ast_to_gast() 35 | if ANNOTATION_FIELD not in node._fields: 36 | node._fields += (ANNOTATION_FIELD,) 37 | 38 | 39 | def hasanno(node, key): 40 | annotations = getattr(node, ANNOTATION_FIELD, {}) 41 | return key in annotations 42 | 43 | 44 | def setdefaultanno(node, key, value=None): 45 | if not hasanno(node, key): 46 | setanno(node, key, value) 47 | return getanno(node, key) 48 | 49 | 50 | def clearanno(node): 51 | for succ in gast.walk(node): 52 | if hasattr(succ, ANNOTATION_FIELD): 53 | new = {} 54 | for anno in FIXED_ANNOTATIONS: 55 | if hasanno(succ, anno): 56 | new[anno] = getanno(succ, anno) 57 | setattr(succ, ANNOTATION_FIELD, new) 58 | return node 59 | 60 | 61 | def getanno(node, key, default=None): 62 | annotations = getattr(node, ANNOTATION_FIELD, {}) 63 | if key not in annotations and default is None: 64 | raise KeyError('Node "%s" has no annotation "%s"' % (node, key)) 65 | return annotations.get(key, default) 66 | 67 | 68 | def delanno(node, key): 69 | annotations = getattr(node, ANNOTATION_FIELD, {}) 70 | del annotations[key] 71 | -------------------------------------------------------------------------------- /tangent/ast.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities to manipulate the AST and its annotations.""" 15 | from __future__ import absolute_import 16 | import copy 17 | 18 | import gast 19 | 20 | from tangent import annotations as anno 21 | from tangent import quoting 22 | from tangent import utils 23 | 24 | 25 | def get_name(node): 26 | """Get the name of a variable. 27 | 28 | Args: 29 | node: A `Name`, `Subscript` or `Attribute` node. 30 | 31 | Returns: 32 | The name of the variable e.g. `'x'` for `x`, `x.i` and `x[i]`. 
33 | """ 34 | if isinstance(node, gast.Name): 35 | return node.id 36 | elif isinstance(node, (gast.Subscript, gast.Attribute)): 37 | return get_name(node.value) 38 | else: 39 | raise TypeError 40 | 41 | 42 | def _get_target(node): 43 | if isinstance(node, (gast.Name, gast.Subscript, gast.Attribute)): 44 | return set([get_name(node)]) 45 | elif isinstance(node, (gast.Tuple, gast.List)): 46 | return set.union(*(_get_target(target) 47 | for target in node.elts)) 48 | else: 49 | raise ValueError 50 | 51 | 52 | def get_updated(node): 53 | """Return the variable names created or mutated by this statement. 54 | 55 | This function considers assign statements, augmented assign statements, and 56 | the targets of for loops, as well as function arguments. 57 | 58 | For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and 59 | `y`, `for i in range(x)` will return `i`, etc. 60 | 61 | Args: 62 | node: An AST node 63 | 64 | Returns: 65 | A set of variable names (strings) of all the variables created or mutated. 66 | """ 67 | if isinstance(node, gast.Assign): 68 | return set.union(*(_get_target(target) 69 | for target in node.targets)) 70 | elif isinstance(node, (gast.For, gast.AugAssign)): 71 | return _get_target(node.target) 72 | elif isinstance(node, gast.arguments): 73 | targets = set(arg.id for arg in node.args + node.kwonlyargs) 74 | if node.vararg: 75 | targets.add(node.vararg.id) 76 | if node.kwarg: 77 | targets.add(node.kwarg.id) 78 | return targets 79 | else: 80 | return set() 81 | 82 | 83 | def copy_node(node): 84 | """Copy a node but keep its annotations intact.""" 85 | if not isinstance(node, gast.AST): 86 | return [copy_node(n) for n in node] 87 | new_node = copy.deepcopy(node) 88 | setattr(new_node, anno.ANNOTATION_FIELD, 89 | getattr(node, anno.ANNOTATION_FIELD, {}).copy()) 90 | return new_node 91 | 92 | 93 | class ArgAppend(gast.NodeTransformer): 94 | """Append arguments to a function definition.""" 95 | 96 | def __init__(self, node_list): 97 | self.visited = False 98 | self.node_list = node_list 99 | 100 | def visit_FunctionDef(self, node): 101 | if not self.visited: 102 | node.args.args.extend(self.node_list) 103 | self.visited = True 104 | return node 105 | 106 | 107 | def append_args(node, node_list): 108 | if not isinstance(node_list, list): 109 | raise TypeError('Please pass in a list') 110 | if all([isinstance(n, str) for n in node_list]): 111 | node_list = [quoting.quote(n) for n in node_list] 112 | return ArgAppend(node_list).visit(node) 113 | 114 | 115 | def is_insert_grad_of_statement(node): 116 | """Check whether a context manager calls `insert_grad_of`. 117 | 118 | Args: 119 | node: The context manager node. 120 | 121 | Returns: 122 | Whether or not this node contains `insert_grad_of` calls. 123 | 124 | Raises: 125 | ValueError: If the `insert_grad_of` calls are mixed with other calls. 126 | """ 127 | tangent_calls = [anno.getanno(item.context_expr, 'func', None) 128 | is utils.insert_grad_of for item in node.items] 129 | if all(tangent_calls): 130 | return True 131 | elif any(tangent_calls): 132 | raise ValueError 133 | else: 134 | return False 135 | -------------------------------------------------------------------------------- /tangent/cfg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Control flow graph analysis. 15 | 16 | Given a Python AST we construct a doubly linked control flow graph whose nodes 17 | contain the AST of the statements. We can then perform forward analysis on this 18 | CFG. 19 | 20 | """ 21 | from __future__ import absolute_import 22 | import functools 23 | import operator 24 | 25 | import gast 26 | 27 | from tangent import annotations as anno 28 | from tangent import ast as ast_ 29 | from tangent import grammar 30 | from tangent import utils 31 | 32 | 33 | class Node(object): 34 | """A node in the CFG.""" 35 | __slots__ = ['next', 'value', 'prev'] 36 | 37 | def __init__(self, value): 38 | self.next = set() 39 | self.prev = set() 40 | self.value = value 41 | 42 | 43 | class CFG(gast.NodeVisitor): 44 | """Construct a control flow graph. 45 | 46 | Each statement is represented as a node. For control flow statements such 47 | as conditionals and loops the conditional itself is a node which either 48 | branches or cycles, respectively. 49 | 50 | Attributes: 51 | entry: The entry node, which contains the `gast.arguments` node of the 52 | function definition. 53 | exit: The exit node. This node is special because it has no value (i.e. no 54 | corresponding AST node). This is because Python functions can have 55 | multiple return statements. 56 | """ 57 | 58 | def __init__(self): 59 | # The current leaves of the CFG 60 | self.head = [] 61 | # A stack of continue statements 62 | self.continue_ = [] 63 | # A stack of break nodes 64 | self.break_ = [] 65 | 66 | @staticmethod 67 | def backlink(node): 68 | """Given a CFG with outgoing links, create incoming links.""" 69 | seen = set() 70 | to_see = [node] 71 | while to_see: 72 | node = to_see.pop() 73 | seen.add(node) 74 | for succ in node.next: 75 | succ.prev.add(node) 76 | if succ not in seen: 77 | to_see.append(succ) 78 | 79 | def set_head(self, node): 80 | """Link this node to the current leaves.""" 81 | for head in self.head: 82 | head.next.add(node) 83 | self.head[:] = [] 84 | self.head.append(node) 85 | 86 | @classmethod 87 | def build_cfg(cls, node): 88 | """Build a CFG for a function. 89 | 90 | Args: 91 | node: A function definition the body of which to analyze. 92 | 93 | Returns: 94 | A CFG object. 95 | 96 | Raises: 97 | TypeError: If the input is not a function definition. 
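    A usage sketch (`fdef` is assumed to be a `gast.FunctionDef`, e.g. found by
    walking the tree returned by `quoting.parse_function`):

        cfg_obj = CFG.build_cfg(fdef)
        # cfg_obj.entry.value is the function's gast.arguments node;
        # cfg_obj.exit is the single synthetic exit node (its value is None).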
98 | """ 99 | if not isinstance(node, gast.FunctionDef): 100 | raise TypeError('input must be a function definition') 101 | cfg = cls() 102 | cfg.entry = Node(node.args) 103 | cfg.head = [cfg.entry] 104 | cfg.visit_statements(node.body) 105 | cfg.exit = Node(None) 106 | cfg.set_head(cfg.exit) 107 | cfg.backlink(cfg.entry) 108 | return cfg 109 | 110 | def visit_statements(self, nodes): 111 | for node in nodes: 112 | if isinstance(node, grammar.CONTROL_FLOW): 113 | self.visit(node) 114 | else: 115 | expr = Node(node) 116 | self.set_head(expr) 117 | 118 | def generic_visit(self, node): 119 | raise ValueError('unknown control flow') 120 | 121 | def visit_If(self, node): 122 | # The current head will hold the conditional 123 | test = Node(node.test) 124 | self.set_head(test) 125 | # Handle the body 126 | self.visit_statements(node.body) 127 | body_exit = self.head[:] 128 | self.head[:] = [] 129 | self.head.append(test) 130 | # Handle the orelse 131 | self.visit_statements(node.orelse) 132 | self.head.extend(body_exit) 133 | 134 | def visit_While(self, node): 135 | test = Node(node.test) 136 | self.set_head(test) 137 | # Start a new level of nesting 138 | self.break_.append([]) 139 | self.continue_.append([]) 140 | # Handle the body 141 | self.visit_statements(node.body) 142 | self.head.extend(self.continue_.pop()) 143 | self.set_head(test) 144 | # Handle the orelse 145 | self.visit_statements(node.orelse) 146 | # The break statements and the test go to the next node 147 | self.head.extend(self.break_.pop()) 148 | 149 | def visit_For(self, node): 150 | iter_ = Node(node) 151 | self.set_head(iter_) 152 | self.break_.append([]) 153 | self.continue_.append([]) 154 | self.visit_statements(node.body) 155 | self.head.extend(self.continue_.pop()) 156 | self.set_head(iter_) 157 | self.head.extend(self.break_.pop()) 158 | 159 | def visit_Break(self, node): 160 | self.break_[-1].extend(self.head) 161 | self.head[:] = [] 162 | 163 | def visit_Continue(self, node): 164 | self.continue_[-1].extend(self.head) 165 | self.head[:] = [] 166 | 167 | def visit_Try(self, node): 168 | self.visit_statements(node.body) 169 | body = self.head 170 | handlers = [] 171 | for handler in node.handlers: 172 | self.head = body[:] 173 | self.visit_statements(handler.body) 174 | handlers.extend(self.head) 175 | self.head = body 176 | self.visit_statements(node.orelse) 177 | self.head = handlers + self.head 178 | self.visit_statements(node.finalbody) 179 | 180 | 181 | class Forward(object): 182 | """Forward analysis on CFG. 183 | 184 | Args: 185 | label: A name for this analysis e.g. 'active' for activity analysis. The 186 | AST nodes in the CFG will be given annotations 'name_in', 'name_out', 187 | 'name_gen' and 'name_kill' which contain the incoming values, outgoing 188 | values, values generated by the statement, and values deleted by the 189 | statement respectively. 190 | gen: A function which takes the CFG node as well as a set of incoming 191 | values. It must return a set of newly generated values by the statement 192 | as well as a set of deleted (killed) values. 193 | op: Either the AND or OR operator. If the AND operator is used it turns 194 | into forward must analysis (i.e. a value will only be carried forward 195 | if it appears on all incoming paths). The OR operator means that 196 | forward may analysis is done (i.e. the union of incoming values will be 197 | taken). 
198 | """ 199 | 200 | def __init__(self, label, gen, op=operator.or_): 201 | self.gen = gen 202 | self.op = op 203 | self.out_label = label + '_out' 204 | self.in_label = label + '_in' 205 | self.gen_label = label + '_gen' 206 | self.kill_label = label + '_kill' 207 | 208 | def visit(self, node): 209 | if node.value: 210 | if anno.hasanno(node.value, self.out_label): 211 | before = hash(anno.getanno(node.value, self.out_label)) 212 | else: 213 | before = None 214 | preds = [anno.getanno(pred.value, self.out_label) 215 | for pred in node.prev 216 | if anno.hasanno(pred.value, self.out_label)] 217 | if preds: 218 | incoming = functools.reduce(self.op, preds[1:], preds[0]) 219 | else: 220 | incoming = frozenset() 221 | anno.setanno(node.value, self.in_label, incoming, safe=False) 222 | gen, kill = self.gen(node, incoming) 223 | anno.setanno(node.value, self.gen_label, gen, safe=False) 224 | anno.setanno(node.value, self.kill_label, kill, safe=False) 225 | anno.setanno(node.value, self.out_label, (incoming - kill) | gen, 226 | safe=False) 227 | if hash(anno.getanno(node.value, self.out_label)) != before: 228 | for succ in node.next: 229 | self.visit(succ) 230 | else: 231 | preds = [anno.getanno(pred.value, self.out_label) 232 | for pred in node.prev] 233 | self.exit = functools.reduce(self.op, preds[1:], preds[0]) 234 | 235 | 236 | def forward(node, analysis): 237 | """Perform a given analysis on all functions within an AST.""" 238 | if not isinstance(analysis, Forward): 239 | raise TypeError('not a valid forward analysis object') 240 | for succ in gast.walk(node): 241 | if isinstance(succ, gast.FunctionDef): 242 | cfg_obj = CFG.build_cfg(succ) 243 | analysis.visit(cfg_obj.entry) 244 | return node 245 | 246 | 247 | class ReachingDefinitions(Forward): 248 | """Perform reaching definition analysis. 249 | 250 | Each statement is annotated with a set of (variable, definition) pairs. 251 | 252 | """ 253 | 254 | def __init__(self): 255 | def definition(node, incoming): 256 | definitions = ast_.get_updated(node.value) 257 | gen = frozenset((id_, node.value) for id_ in definitions) 258 | kill = frozenset(def_ for def_ in incoming 259 | if def_[0] in definitions) 260 | return gen, kill 261 | super(ReachingDefinitions, self).__init__('definitions', definition) 262 | 263 | 264 | class Defined(Forward): 265 | """Perform defined variable analysis. 266 | 267 | Each statement is annotated with a set of variables which are guaranteed to 268 | be defined at that point. 269 | """ 270 | 271 | def __init__(self): 272 | def defined(node, incoming): 273 | gen = ast_.get_updated(node.value) 274 | return gen, frozenset() 275 | super(Defined, self).__init__('defined', defined, operator.and_) 276 | 277 | 278 | class Active(Forward): 279 | """Active variable analysis. 280 | 281 | Given a set of active arguments, find all variables that are active i.e. 282 | variables whose values possibly depend on the given set of arguments. 283 | 284 | Args: 285 | wrt: A tuple of indices of arguments that are active. 286 | """ 287 | 288 | def __init__(self, wrt): 289 | def active(node, incoming): 290 | gen = set() 291 | kill = set() 292 | if isinstance(node.value, gast.arguments): 293 | gen.update(node.value.args[i].id for i in wrt) 294 | if isinstance(node.value, gast.Assign): 295 | # Special-case e.g. x = tangent.pop(_stack) 296 | # such that all values popped off the stack are live. 
297 | if anno.getanno(node.value.value, 'func', False) == utils.pop: 298 | gen.update(ast_.get_updated(node.value)) 299 | else: 300 | for succ in gast.walk(node.value.value): 301 | if isinstance(succ, gast.Name) and succ.id in incoming: 302 | gen.update(ast_.get_updated(node.value)) 303 | break 304 | else: 305 | kill.update(ast_.get_updated(node.value)) 306 | return gen, kill 307 | super(Active, self).__init__('active', active) 308 | -------------------------------------------------------------------------------- /tangent/comments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Handling comments on nodes. 15 | 16 | To make the generated derivative source code more legible, statements are 17 | annotated with human-readable comments. 18 | 19 | """ 20 | from __future__ import absolute_import 21 | 22 | import gast 23 | 24 | from tangent import annotations as anno 25 | 26 | 27 | def add_comment(node, text, location='above'): 28 | """Add a comment to the given node. 29 | 30 | If the `SourceWithCommentGenerator` class is used these comments will be 31 | output as part of the source code. 32 | 33 | Note that a node can only contain one comment. Subsequent calls to 34 | `add_comment` will ovverride the existing comments. 35 | 36 | Args: 37 | node: The AST node whose containing statement will be commented. 38 | text: A comment string. 39 | location: Where the comment should appear. Valid values are 'above', 40 | 'below' and 'right' 41 | 42 | Returns: 43 | The node with the comment stored as an annotation. 44 | """ 45 | anno.setanno(node, 'comment', dict(location=location, text=text), safe=False) 46 | return node 47 | 48 | 49 | def remove_repeated_comments(node): 50 | """Remove comments that repeat themselves. 51 | 52 | Multiple statements might be annotated with the same comment. This way if one 53 | of the statements is deleted during optimization passes, the comment won't be 54 | lost. This pass removes sequences of identical comments, leaving only the 55 | first one. 56 | 57 | Args: 58 | node: An AST 59 | 60 | Returns: 61 | An AST where comments are not repeated in sequence. 62 | 63 | """ 64 | last_comment = {'text': None} 65 | for _node in gast.walk(node): 66 | if anno.hasanno(_node, 'comment'): 67 | comment = anno.getanno(_node, 'comment') 68 | if comment['text'] == last_comment['text']: 69 | anno.delanno(_node, 'comment') 70 | last_comment = comment 71 | return node 72 | -------------------------------------------------------------------------------- /tangent/compile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
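As a rough usage sketch (not from the repository itself), the `Defined` analysis in tangent/cfg.py above can be run via `forward` on a parsed function and its results read back from the `defined_in` annotations; the function `f` is made up for illustration:

import gast
from tangent import annotations as anno
from tangent import cfg

tree = gast.parse('def f(x):\n  y = x * 2\n  return y')
cfg.forward(tree, cfg.Defined())

# Each statement now carries 'defined_in'/'defined_out' annotations.
ret = tree.body[0].body[-1]              # the `return y` statement
print(anno.getanno(ret, 'defined_in'))   # both 'x' and 'y' are guaranteed defined here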
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Going from AST or source code to executable code.""" 15 | from __future__ import absolute_import 16 | import os 17 | import tempfile 18 | from uuid import uuid4 19 | 20 | import gast 21 | import six 22 | if six.PY3: 23 | from importlib import util 24 | else: 25 | import imp 26 | 27 | from tangent import quoting 28 | 29 | 30 | def compile_file(source, globals_=None): 31 | """Compile by saving to file and importing that. 32 | 33 | Compiling the AST/source code this way ensures that the source code is 34 | readable by e.g. `pdb` or `inspect`. 35 | 36 | Args: 37 | source: The code to compile, either as a string or as an AST. 38 | globals_: A dictionary of variables that should be available as globals in 39 | the compiled module. They will be monkey patched after importing the 40 | module. 41 | 42 | Returns: 43 | A module object containing the compiled source code. 44 | """ 45 | if isinstance(source, gast.AST): 46 | source = quoting.to_source(source) 47 | 48 | # Write source to temporary file 49 | tempdir = tempfile.mkdtemp() 50 | uuid = str(uuid4().hex[:4]) 51 | tmpname = os.path.join(tempdir, 'tangent_%s.py' % uuid) 52 | with open(tmpname, 'w') as f: 53 | f.write(source) 54 | 55 | # Load the temporary file as a module 56 | module_name = 'tangent_%s' % uuid 57 | if six.PY3: 58 | spec = util.spec_from_file_location(module_name, tmpname) 59 | m = util.module_from_spec(spec) 60 | spec.loader.exec_module(m) 61 | else: 62 | m = imp.load_source(module_name, tmpname) 63 | 64 | # Update the modules namespace 65 | if globals_: 66 | m.__dict__.update(globals_) 67 | return m 68 | 69 | 70 | def compile_function(node, globals_=None): 71 | """Convert an AST or string into a function with inspectable source. 72 | 73 | This function uses `compile_file` internally, but instead of returning the 74 | entire module it will return the function only. 75 | 76 | Args: 77 | node: A `FunctionDef` node or a `Module` node which contains at least one 78 | `FunctionDef` node. If a module contains multiple functions, a handle 79 | to the first one will be returned. 80 | globals_: See `compile_file` 81 | 82 | Returns: 83 | A handle to the compiled function. 84 | 85 | Raises: 86 | TypeError: If the input is not a string or AST. 87 | ValueError: If no function can be found. 88 | """ 89 | if not isinstance(node, gast.AST): 90 | if not isinstance(node, six.string_types): 91 | raise TypeError 92 | node = gast.parse(node) 93 | if isinstance(node, gast.Module): 94 | for succ in node.body: 95 | if isinstance(succ, gast.FunctionDef): 96 | name = succ.name 97 | break 98 | else: 99 | raise ValueError('no function found') 100 | elif isinstance(node, gast.FunctionDef): 101 | name = node.name 102 | else: 103 | raise TypeError 104 | module = compile_file(node, globals_) 105 | return getattr(module, name) 106 | -------------------------------------------------------------------------------- /tangent/create.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
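A minimal sketch of `compile_function` from tangent/compile.py above in use (the `square` function is invented for illustration). Because the code is written to a real temporary file, `inspect` can recover its source, which is the point made in the `compile_file` docstring:

import inspect
from tangent import compile as compile_

square = compile_.compile_function('def square(x):\n  return x * x')
print(square(3))                  # 9
print(inspect.getsource(square))  # shows the generated `def square(x): ...`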
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Helper functions to create gradient nodes from other nodes.""" 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | 18 | import gast 19 | 20 | from tangent import annotations as anno 21 | 22 | 23 | def create_grad(node, namer, tangent=False): 24 | """Given a variable, create a variable for the gradient. 25 | 26 | Args: 27 | node: A node to create a gradient for, can be a normal variable (`x`) or a 28 | subscript (`x[i]`). 29 | namer: The namer object which will determine the name to use for the 30 | gradient. 31 | tangent: Whether a tangent (instead of adjoint) is created. 32 | 33 | Returns: 34 | node: A node representing the gradient with the correct name e.g. the 35 | gradient of `x[i]` is `dx[i]`. 36 | 37 | Note that this returns an invalid node, with the `ctx` attribute 38 | missing. It is assumed that this attribute is filled in later. 39 | 40 | Node has an `adjoint_var` annotation referring to the node it is an 41 | adjoint of. 42 | """ 43 | if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)): 44 | raise TypeError 45 | 46 | if anno.hasanno(node, 'temp_var'): 47 | return create_grad(anno.getanno(node, 'temp_var'), namer, tangent) 48 | 49 | def _name_grad(node): 50 | if not isinstance(node, gast.Name): 51 | raise TypeError 52 | varname = node.id 53 | name = namer.grad(varname, tangent) 54 | grad_node = gast.Name( 55 | id=name, ctx=None, annotation=None) 56 | anno.setanno(grad_node, 'adjoint_var', node) 57 | return grad_node 58 | if isinstance(node, gast.Subscript): 59 | grad_node = create_grad(node.value, namer, tangent=tangent) 60 | grad_node.ctx = gast.Load() 61 | return gast.Subscript(value=grad_node, slice=node.slice, ctx=None) 62 | elif isinstance(node, gast.Str): 63 | grad_node = create_grad( 64 | gast.Name(id=node.s, ctx=None, annotation=None), namer, tangent=tangent) 65 | return gast.Str(grad_node.id) 66 | else: 67 | return _name_grad(node) 68 | 69 | 70 | def create_temp_grad(node, namer, tangent=False): 71 | """Create a variable to store partial gradients. 72 | 73 | Args: 74 | node: See `create_grad`. 75 | namer: See `create_grad`. 76 | tangent: See `create_grad`. 77 | 78 | Returns: 79 | node: See `create_grad`. Returns a node representing the partial gradient. 80 | Note that this is always a simple variable e.g. the temporary partial 81 | of `x[i]` can be something like `_dxi`. 82 | 83 | Nodes are given an annotation `temp_adjoint_var`. 
84 | """ 85 | if not isinstance(node, (gast.Subscript, gast.Name)): 86 | raise TypeError 87 | 88 | def _name_temp_grad(node): 89 | name = namer.temp_grad(node.id, tangent) 90 | temp_node = gast.Name(id=name, annotation=None, ctx=None) 91 | return temp_node 92 | if isinstance(node, gast.Subscript): 93 | temp_node = _name_temp_grad(node.value) 94 | else: 95 | temp_node = _name_temp_grad(node) 96 | anno.setanno(temp_node, 'temp_adjoint_var', node) 97 | return temp_node 98 | 99 | 100 | def create_temp(node, namer): 101 | """Create a temporary variable. 102 | 103 | Args: 104 | node: Create a temporary variable to store this variable in. 105 | namer: A naming object that guarantees the names are unique. 106 | 107 | Returns: 108 | node: See `create_grad`. Returns a temporary variable, which is always a 109 | simple variable annotated with `temp_var`. 110 | """ 111 | if isinstance(node, gast.Name): 112 | name = node.id 113 | elif isinstance(node, (gast.Attribute, gast.Subscript)): 114 | name = node.value.id 115 | else: 116 | raise TypeError 117 | temp_node = gast.Name(id=namer.temp(name), annotation=None, ctx=None) 118 | anno.setanno(temp_node, 'temp_var', node) 119 | return temp_node 120 | -------------------------------------------------------------------------------- /tangent/desugar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import absolute_import 16 | 17 | import gast 18 | import copy 19 | 20 | from tangent import ast 21 | from tangent import annotations as anno 22 | from tangent import cfg 23 | from tangent import naming 24 | from tangent import quoting 25 | from tangent import template 26 | from tangent import transformers 27 | 28 | 29 | class ExplicitLoopIndexes(transformers.TreeTransformer): 30 | 31 | def visit_FunctionDef(self, node): 32 | cfg.forward(node, cfg.Active(range(len(node.args.args)))) 33 | self.namer = naming.Namer.build(node) 34 | node = self.generic_visit(node) 35 | return node 36 | 37 | def visit_For(self, node): 38 | # If the iter is a Name that is active, 39 | # we need to rewrite the loop. 40 | # Iterators of the form `for a in x` rely on an implicit 41 | # indexing operation, which Tangent cannot reverse without 42 | # more information. So, we will create an explicit 43 | # indexing operation. Note that we will use 44 | # integer indexes, which will cause strange behavior if 45 | # the iterator's `next()` behavior deviates from 46 | # a plain incrementing index. 47 | # The right thing to do (eventually) is to write a multiple-dispatch 48 | # version of the `next` operator, and its adjoint, so that 49 | # we can handle e.g. dicts. 
50 | 51 | if isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)): 52 | iter_name = ast.get_name(node.iter) 53 | if iter_name in anno.getanno(node, 'active_in'): 54 | # for a in x: 55 | # f(a) 56 | # # becomes 57 | # for i in range(len(x)): 58 | # a = x[i] 59 | # f(a) 60 | 61 | # Get a unique iterator name 62 | old_target = copy.deepcopy(node.target) 63 | new_target = quoting.quote(self.namer.unique('_idx')) 64 | old_iter = copy.deepcopy(node.iter) 65 | 66 | item_access = template.replace( 67 | 'old_target = x[i]', 68 | old_target=old_target, 69 | x=old_iter, 70 | i=new_target) 71 | 72 | node.target = gast.Name(id=new_target.id, ctx=gast.Store(), annotation=None) 73 | node.iter = quoting.quote('range(len(%s))' % iter_name) 74 | anno.setanno(node.iter, 'func', range) 75 | anno.setanno(node.iter.args[0], 'func', len) 76 | node.body = [item_access] + node.body 77 | 78 | return node 79 | 80 | 81 | def explicit_loop_indexes(node): 82 | node = ExplicitLoopIndexes().visit(node) 83 | for n in gast.walk(node): 84 | for key in ('active_in', 'active_out', 'active_gen', 'active_kill'): 85 | if anno.hasanno(n, key): 86 | anno.delanno(n, key) 87 | return node -------------------------------------------------------------------------------- /tangent/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tangent-specific errors.""" 15 | from __future__ import absolute_import 16 | 17 | 18 | class TangentParseError(SyntaxError): 19 | """Error generated when encountering an unsupported feature.""" 20 | pass 21 | 22 | 23 | class ForwardNotImplementedError(NotImplementedError): 24 | """Error generated when encountering a @tangent_ yet to be implemented.""" 25 | 26 | def __init__(self, func): 27 | NotImplementedError.__init__( 28 | self, 'Forward mode for function "%s" is not yet implemented.' % 29 | func.__name__) 30 | 31 | 32 | class ReverseNotImplementedError(NotImplementedError): 33 | """Error generated when encountering an @adjoint yet to be implemented.""" 34 | 35 | def __init__(self, func): 36 | NotImplementedError.__init__( 37 | self, 38 | 'Reverse mode for function "%s" is not yet implemented.' % 39 | func.__name__, 40 | ) 41 | -------------------------------------------------------------------------------- /tangent/fence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License'); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
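A sketch of the rewrite described in the `visit_For` comments of tangent/desugar.py above, assuming `quoting.to_source` is available as it is used elsewhere in the package; the function `f` is made up:

import gast
from tangent import desugar, quoting

src = ('def f(x):\n'
       '  y = 0.0\n'
       '  for a in x:\n'
       '    y = y + a\n'
       '  return y\n')
tree = desugar.explicit_loop_indexes(gast.parse(src))
print(quoting.to_source(tree))
# The active loop over `x` now reads roughly:
#   for _idx in range(len(x)):
#     a = x[_idx]
#     y = y + a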
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The fence allows placing language feature restrictions on the AST. 15 | 16 | The fence works by walking an AST and raising an error if the tree contains any 17 | of the unsupported features. Only the first encountered feature is flagged. 18 | 19 | For a detailed documentation of AST nodes, see 20 | http://greentreesnakes.readthedocs.io/en/latest/nodes.html 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | 25 | import gast as ast 26 | 27 | from tangent.errors import TangentParseError 28 | 29 | 30 | def validate(node, source): 31 | """Call this function to validate an AST.""" 32 | # TODO: leaving strict checking off to support insert_grad_of 33 | lf = LanguageFence(source, strict=False) 34 | lf.visit(node) 35 | return node 36 | 37 | 38 | class LanguageFence(ast.NodeVisitor): 39 | """An AST visitor that raises an error for unsupported language features. 40 | 41 | This implementation is not thread-safe. 42 | 43 | LanguageFence instances are lightweight and tied to the AST they validate. 44 | In general, you should not attempt to reuse them. 45 | """ 46 | 47 | def __init__(self, source, strict=True): 48 | """Creates a LanguageFence. 49 | 50 | Args: 51 | source: String, the source code of the AST that will be verified. 52 | strict: Boolean, set to False to allow unsafe constructs. 53 | Raises: 54 | ValueError: if source code has not been supplied. 55 | """ 56 | self._visited_top_module = False 57 | if not source: 58 | raise ValueError('The source code of the tree is required.') 59 | self._source = source 60 | self._strict = strict 61 | 62 | # Location information is used to locate the offending elements 63 | # in the source code. 64 | self._current_lineno = None # Only consistent during a visit. 65 | self._current_offset = None # Only consistent during a visit. 66 | 67 | super(LanguageFence, self).__init__() 68 | 69 | def _raise_error(self, msg): 70 | assert self._source 71 | lineno = self._current_lineno 72 | offset = self._current_offset 73 | line = self._source.splitlines()[lineno - 1] 74 | raise TangentParseError(msg, ('', lineno, offset + 1, line)) 75 | 76 | def _track_location(self, node): 77 | # TODO: Add tests that cover all nodes. 78 | exposed_symbols = dir(node) 79 | # Not all nodes have source information. This is a generic way to collect 80 | # whenever available. 
81 | if 'lineno' in exposed_symbols and 'col_offset' in exposed_symbols: 82 | self._current_lineno = node.lineno 83 | self._current_offset = node.col_offset 84 | 85 | def _allow_and_continue(self, node): 86 | self._track_location(node) 87 | self.generic_visit(node) 88 | 89 | def _reject(self, node, msg): 90 | self._track_location(node) 91 | self._raise_error(msg) 92 | 93 | def visit_Module(self, node): 94 | self._visited_top_module = True 95 | self._allow_and_continue(node) 96 | 97 | def visit_Num(self, node): 98 | self._allow_and_continue(node) 99 | 100 | def visit_Str(self, node): 101 | self._allow_and_continue(node) 102 | 103 | def visit_FormattedValue(self, node): 104 | self._reject(node, 'F-Strings are not supported') 105 | 106 | def visit_JoinedStr(self, node): 107 | self._reject(node, 'F-Strings are not supported') 108 | 109 | def visit_Bytes(self, node): 110 | self._reject(node, 'Byte Literals are not supported') 111 | 112 | def visit_List(self, node): 113 | self._allow_and_continue(node) 114 | 115 | def visit_Tuple(self, node): 116 | # TODO: Make sure none of the original functionality was lost. 117 | self._allow_and_continue(node) 118 | 119 | def visit_Set(self, node): 120 | self._reject(node, 'Sets not are supported') 121 | 122 | def visit_Dict(self, node): 123 | self._allow_and_continue(node) 124 | 125 | def visit_Ellipsis(self, node): 126 | self._allow_and_continue(node) 127 | 128 | def visit_NameConstant(self, node): 129 | self._allow_and_continue(node) 130 | 131 | def visit_Name(self, node): 132 | self._allow_and_continue(node) 133 | 134 | def visit_Load(self, node): 135 | self._allow_and_continue(node) 136 | 137 | def visit_Store(self, node): 138 | self._allow_and_continue(node) 139 | 140 | def visit_Del(self, node): 141 | self._reject(node, 'Deleting variables is not supported') 142 | 143 | def visit_Starred(self, node): 144 | self._reject(node, 'Unpackings are not supported') 145 | 146 | def visit_Expr(self, node): 147 | self._allow_and_continue(node) 148 | 149 | def visit_UnaryOp(self, node): 150 | self._allow_and_continue(node) 151 | 152 | def visit_UAdd(self, node): 153 | self._reject(node, 'Unary Add operator is not supported') 154 | 155 | def visit_USub(self, node): 156 | self._allow_and_continue(node) 157 | 158 | def visit_Not(self, node): 159 | self._reject(node, 'Not operator is not supported') 160 | 161 | def visit_Invert(self, node): 162 | self._reject(node, 'Invert operator is not supported') 163 | 164 | def visit_BinOp(self, node): 165 | self._allow_and_continue(node) 166 | 167 | def visit_Add(self, node): 168 | self._allow_and_continue(node) 169 | 170 | def visit_Sub(self, node): 171 | self._allow_and_continue(node) 172 | 173 | def visit_Mult(self, node): 174 | self._allow_and_continue(node) 175 | 176 | def visit_Div(self, node): 177 | self._allow_and_continue(node) 178 | 179 | def visit_FloorDiv(self, node): 180 | self._reject(node, 'Floor Div operator is not supported') 181 | 182 | def visit_Mod(self, node): 183 | self._allow_and_continue(node) 184 | 185 | def visit_Pow(self, node): 186 | self._allow_and_continue(node) 187 | 188 | def visit_LShift(self, node): 189 | self._reject(node, 'Left Shift operator is not supported') 190 | 191 | def visit_RShift(self, node): 192 | self._reject(node, 'Right Shift operator is not supported') 193 | 194 | def visit_BitOr(self, node): 195 | self._reject(node, 'Bitwise Or operator is not supported') 196 | 197 | def visit_BitXor(self, node): 198 | self._reject(node, 'Bitwise Xor operator is not supported') 199 | 200 | def 
visit_BitAnd(self, node): 201 | self._reject(node, 'Bitwise And operator is not supported') 202 | 203 | def visit_MatMult(self, node): 204 | # TODO: Add support for this. 205 | self._reject(node, 'MatMult operator is not supported') 206 | 207 | def visit_BoolOp(self, node): 208 | self._allow_and_continue(node) 209 | 210 | def visit_And(self, node): 211 | self._allow_and_continue(node) 212 | 213 | def visit_Or(self, node): 214 | self._allow_and_continue(node) 215 | 216 | def visit_Compare(self, node): 217 | self._allow_and_continue(node) 218 | 219 | def visit_Eq(self, node): 220 | self._allow_and_continue(node) 221 | 222 | def visit_NotEq(self, node): 223 | self._allow_and_continue(node) 224 | 225 | def visit_Lt(self, node): 226 | self._allow_and_continue(node) 227 | 228 | def visit_LtE(self, node): 229 | self._allow_and_continue(node) 230 | 231 | def visit_Gt(self, node): 232 | self._allow_and_continue(node) 233 | 234 | def visit_GtE(self, node): 235 | self._allow_and_continue(node) 236 | 237 | def visit_Is(self, node): 238 | self._allow_and_continue(node) 239 | 240 | def visit_IsNot(self, node): 241 | self._allow_and_continue(node) 242 | 243 | def visit_In(self, node): 244 | self._reject(node, 'In operator is not supported') 245 | 246 | def visit_NotIn(self, node): 247 | self._reject(node, 'Not In operator is not supported') 248 | 249 | def visit_Call(self, node): 250 | self._allow_and_continue(node) 251 | 252 | def visit_keyword(self, node): 253 | self._allow_and_continue(node) 254 | 255 | def visit_IfExp(self, node): 256 | self._reject(node, 'Conditional Expressions are not supported') 257 | 258 | def visit_Attribute(self, node): 259 | self._allow_and_continue(node) 260 | 261 | def visit_Subscript(self, node): 262 | self._allow_and_continue(node) 263 | 264 | def visit_Index(self, node): 265 | self._allow_and_continue(node) 266 | 267 | def visit_Slice(self, node): 268 | self._allow_and_continue(node) 269 | 270 | def visit_ExtSlice(self, node): 271 | self._allow_and_continue(node) 272 | 273 | def visit_ListComp(self, node): 274 | self._allow_and_continue(node) 275 | 276 | def visit_SetComp(self, node): 277 | self._reject(node, 'Set Comprehensions are not supported') 278 | 279 | def visit_GeneratorExp(self, node): 280 | self._reject(node, 'Generator Expressions are not supported') 281 | 282 | def visit_DictComp(self, node): 283 | self._reject(node, 'Dictionary Comprehensions are not supported') 284 | 285 | def visit_comprehension(self, node): 286 | self._allow_and_continue(node) 287 | 288 | def visit_Assign(self, node): 289 | self._allow_and_continue(node) 290 | 291 | def visit_AnnAssign(self, node): 292 | self._reject(node, 'Type-annotated assignment are not supported') 293 | 294 | def visit_AugAssign(self, node): 295 | self._allow_and_continue(node) 296 | 297 | def visit_Print(self, node): 298 | self._allow_and_continue(node) 299 | 300 | def visit_Raise(self, node): 301 | self._allow_and_continue(node) 302 | 303 | def visit_Assert(self, node): 304 | if __debug__: 305 | self._allow_and_continue(node) 306 | else: 307 | assert False, 'Assert statements should not appear in optimized code' 308 | 309 | def visit_Delete(self, node): 310 | self._reject(node, 'Delete statements are not supported') 311 | 312 | def visit_Pass(self, node): 313 | self._allow_and_continue(node) 314 | 315 | def visit_Import(self, node): 316 | self._reject(node, 'Import statements are not supported') 317 | 318 | def visit_ImportFrom(self, node): 319 | self._reject(node, 'Import/From statements are not supported') 320 | 
321 | def visit_alias(self, node): 322 | self._reject(node, 'Aliases are not supported') 323 | 324 | def visit_If(self, node): 325 | self._allow_and_continue(node) 326 | 327 | def visit_For(self, node): 328 | if node.orelse: 329 | self._reject(node, 'For/Else block is not supported') 330 | else: 331 | self._allow_and_continue(node) 332 | 333 | def visit_While(self, node): 334 | self._allow_and_continue(node) 335 | 336 | def visit_Break(self, node): 337 | if self._strict: 338 | self._reject(node, 'Break statements are not supported in strict mode') 339 | else: 340 | self._allow_and_continue(node) 341 | 342 | def visit_Continue(self, node): 343 | self._reject(node, 'Continue statements are not supported') 344 | 345 | def visit_Try(self, node): 346 | self._allow_and_continue(node) 347 | 348 | def visit_TryFinally(self, node): 349 | self._reject(node, 'Try/Finally blocks are not supported') 350 | 351 | def visit_TryExcept(self, node): 352 | self._reject(node, 'Try/Except blocks are not supported') 353 | 354 | def visit_ExceptHandler(self, node): 355 | self._allow_and_continue(node) 356 | 357 | def visit_With(self, node): 358 | self._allow_and_continue(node) 359 | 360 | def visit_withitem(self, node): 361 | self._allow_and_continue(node) 362 | 363 | def visit_FunctionDef(self, node): 364 | self._allow_and_continue(node) 365 | 366 | def visit_Lambda(self, node): 367 | self._reject(node, 'Lambda functions are not supported') 368 | 369 | def visit_arguments(self, node): 370 | self._allow_and_continue(node) 371 | 372 | def visit_arg(self, node): 373 | self._allow_and_continue(node) 374 | 375 | def visit_Return(self, node): 376 | # TODO: Make sure none of the original functionality was lost. 377 | self._allow_and_continue(node) 378 | 379 | def visit_Yield(self, node): 380 | self._reject(node, 'Yield statements are not supported') 381 | 382 | def visit_YieldFrom(self, node): 383 | self._reject(node, 'Yield/From statements are not supported') 384 | 385 | def visit_Global(self, node): 386 | self._reject(node, 'Global statements are not supported') 387 | 388 | def visit_Nonlocal(self, node): 389 | self._reject(node, 'Nonlocal statements are not supported') 390 | 391 | def visit_ClassDef(self, node): 392 | self._reject(node, 'Classes are not supported') 393 | 394 | def visit_AsyncFunctionDef(self, node): 395 | self._reject(node, 'Async function definitions are not supported') 396 | 397 | def visit_Await(self, node): 398 | self._reject(node, 'Await statements are not supported') 399 | 400 | def visit_AsyncFor(self, node): 401 | self._reject(node, 'Async For loops are not supported') 402 | 403 | def visit_AsyncWith(self, node): 404 | self._reject(node, 'Async With statements are not supported') 405 | -------------------------------------------------------------------------------- /tangent/fixes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
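A small sketch (not part of the repository) of the fence in action: validating a function that uses an unsupported construct raises `TangentParseError` with the offending source location.

import gast
from tangent import fence
from tangent.errors import TangentParseError

src = 'def f(x):\n  g = lambda a: a * a\n  return g(x)'
try:
  fence.validate(gast.parse(src), src)
except TangentParseError as e:
  print(e)   # Lambda functions are not supported (line 2)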
14 | """Fix naive AD rules. 15 | 16 | Automatic differentiation proceeds by transforming each statement in isolation. 17 | In principle, this works, but there are some corner cases: 18 | 19 | Each variable gets pushed to the stack before being assigned to. However, the 20 | first time a variable get assigned this results in pushing an undefined 21 | variable. We either remove these entirely (`CleanStack`) or we ensure the 22 | variable exists by manually setting the variable to `None` (`FixStack`, e.g. 23 | for loops). 24 | 25 | Each partial gets accumulated into the gradient of that variable. The first 26 | time this happens the gradient doesn't exist yet, so we replace accumulation 27 | with assignment (`CleanGrad`) or we explicitly initialize the gradient to zeros 28 | (`FixGrad`, e.g. in loops). 29 | 30 | """ 31 | from __future__ import absolute_import 32 | import gast 33 | 34 | from tangent import annotations as anno 35 | from tangent import ast as ast_ 36 | from tangent import quoting 37 | from tangent import transformers 38 | from tangent import utils 39 | 40 | 41 | class CleanStack(transformers.TreeTransformer): 42 | """Remove stack pushes of variables that are never defined.""" 43 | 44 | def visit(self, node): 45 | # Remove all AD-generated pushes of unused variables. 46 | if anno.hasanno(node, 'push_var') and anno.hasanno( 47 | node, 'pop') and anno.hasanno(node, 'gen_push'): 48 | defs = frozenset(id_ 49 | for id_, node in anno.getanno(node, 'definitions_in')) 50 | if ast_.get_name(anno.getanno(node, 'push_var')) not in defs: 51 | self.remove(node) 52 | self.remove(anno.getanno(node, 'pop')) 53 | return super(CleanStack, self).visit(node) 54 | 55 | 56 | class FixStack(transformers.TreeTransformer): 57 | """Explicitly defines variables that might not be defined.""" 58 | 59 | def visit(self, node): 60 | if anno.hasanno(node, 'push_var'): 61 | varname = ast_.get_name(anno.getanno(node, 'push_var')) 62 | if varname not in anno.getanno(node, 'defined_in'): 63 | self.insert_top(quoting.quote('{} = None'.format(varname))) 64 | return super(FixStack, self).visit(node) 65 | 66 | 67 | class CleanGrad(gast.NodeTransformer): 68 | """Replace `dx = dx + partial` with `dx = partial` if `dx` undefined.""" 69 | 70 | def visit_Assign(self, node): 71 | if isinstance(node.value, gast.Call) and anno.hasanno(node.value.func, 72 | 'add_grad'): 73 | defs = frozenset(id_ for id_, node in anno.getanno(node, 74 | 'definitions_in')) 75 | if ast_.get_name(node.targets[0]) not in defs: 76 | node.value = node.value.args[1] 77 | return node 78 | 79 | 80 | class FixGrad(transformers.TreeTransformer): 81 | """Explicitly initialize gradient to zero if needed.""" 82 | 83 | def __init__(self): 84 | super(FixGrad, self).__init__() 85 | self.added = set() 86 | 87 | def _init(self, node): 88 | gradname = ast_.get_name(node) 89 | if anno.hasanno(node, 'adjoint_var'): 90 | var = anno.getanno(node, 'adjoint_var') 91 | else: 92 | var = anno.getanno(node, 'temp_adjoint_var') 93 | return gast.Assign( 94 | targets=[gast.Name(id=gradname, ctx=gast.Store(), annotation=None)], 95 | value=gast.Call(func=utils.INIT_GRAD, args=[var], keywords=[])) 96 | 97 | def prepend_uninitialized_grads(self, node): 98 | if anno.hasanno(node, 'defined_in'): 99 | uses = (succ for succ in gast.walk(node) if 100 | isinstance(succ, gast.Name) and 101 | isinstance(succ.ctx, gast.Load)) 102 | for use in uses: 103 | if ((anno.hasanno(use, 'adjoint_var') or 104 | anno.hasanno(use, 'temp_adjoint_var')) and 105 | use.id not in anno.getanno(node, 
'defined_in') and 106 | use.id not in self.added): 107 | self.added.add(use.id) 108 | self.insert_top(self._init(use)) 109 | return node 110 | 111 | def visit_Assign(self, node): 112 | node = self.prepend_uninitialized_grads(node) 113 | return node 114 | 115 | def visit_AugAssign(self, node): 116 | node = self.prepend_uninitialized_grads(node) 117 | return node 118 | 119 | def visit_Expr(self, node): 120 | node = self.prepend_uninitialized_grads(node) 121 | return node 122 | 123 | def visit_Return(self, node): 124 | node = self.prepend_uninitialized_grads(node) 125 | return node 126 | -------------------------------------------------------------------------------- /tangent/grads.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Templates for gradient expressions. 15 | 16 | The first argument to the adjoint must be the return value of the primal. 17 | 18 | Use `d[x]` to denote the gradient of a variable `x`. 19 | 20 | If the primal returns a tuple, the first argument to the adjoint is a tuple, 21 | and the adjoint is supposed to define `d[y]` as a tuple. 22 | 23 | Templates do not support use of `**kwargs`. 24 | 25 | If a keyword argument isn't present in the adjoint, it means that Tangent 26 | doesn't support it, and an error will be raised if it appears in user code. 27 | 28 | Adjoints have access to the inputs of the primal, output of the primal, and 29 | gradients with respect to the output. They are expected to contain expressions 30 | for the gradient with respect to the input. They don't have access to any 31 | intermediate variables from the primal. 32 | 33 | """ 34 | from __future__ import absolute_import 35 | 36 | import math 37 | import types 38 | 39 | import gast 40 | import numpy 41 | import tangent 42 | from tangent import tracing 43 | 44 | 45 | # TODO: Avoid requiring non-differentiables to define @tangent_s. 46 | # All non-differentiable function need to create shadow zero-filled variables 47 | # in forward mode. Currently we achieve that by defining identity @tangent_ 48 | # versions of those functions, but a beter approach would be to do that 49 | # automatically. 50 | 51 | # Create decorators that add templates to dictionaries 52 | adjoints = {} 53 | primals = {} 54 | 55 | 56 | def get_module_functions(modules): 57 | """Finds functions that do not have implemented derivatives. 58 | 59 | Args: 60 | modules: A list of Python modules. Functions contained in these modules 61 | will be checked for membership in 'implemented', and if not found, 62 | will be added to an 'unimplemented' set 63 | implemented: A Python object containing implemented derivatives. A function 64 | should be checkable for membership using the `fn in implemented` syntax. 65 | 66 | Returns: 67 | module_fns: A set of functions, builtins or ufuncs in `modules`. 
68 | """ 69 | module_fns = set() 70 | for module in modules: 71 | for key in dir(module): 72 | attr = getattr(module, key) 73 | if isinstance( 74 | attr, (types.BuiltinFunctionType, types.FunctionType, numpy.ufunc)): 75 | module_fns.add(attr) 76 | return module_fns 77 | 78 | 79 | def create_register(dict_): 80 | def register(key): 81 | def _(f): 82 | dict_[key] = f 83 | return f 84 | return _ 85 | return register 86 | 87 | 88 | adjoint = create_register(adjoints) 89 | primal = create_register(primals) 90 | 91 | 92 | # Functions: f => f, df 93 | @adjoint(gast.FunctionDef) 94 | def dfunction_def(adjoint_body, return_dx): 95 | def df(): 96 | adjoint_body 97 | return_dx 98 | 99 | 100 | # Control flow 101 | @primal(gast.For) 102 | def for_(body, i, iter_, target, push, push_target, _target, _stack, op_id_iter, 103 | op_id_target): 104 | i = 0 105 | for target in iter_: 106 | _target = target 107 | i += 1 108 | body 109 | push_target(_stack, _target, op_id_target) 110 | push(_stack, i, op_id_iter) 111 | 112 | 113 | @adjoint(gast.For) 114 | def dfor_(adjoint_body, i, pop, pop_target, target, _stack, op_id_iter, 115 | op_id_target): 116 | i = pop(_stack, op_id_iter) 117 | for _ in range(i): 118 | target = pop_target(_stack, op_id_target) 119 | adjoint_body 120 | 121 | 122 | @primal(gast.While) 123 | def while_(body, i, test, push, _stack, op_id): 124 | i = 0 125 | while test: 126 | i += 1 127 | body 128 | push(_stack, i, op_id) 129 | 130 | 131 | @adjoint(gast.While) 132 | def dwhile_(adjoint_body, i, pop, _stack, op_id): 133 | i = pop(_stack, op_id) 134 | for _ in range(i): 135 | adjoint_body 136 | 137 | 138 | @primal(gast.If) 139 | def if_(cond, test, body, orelse, push, _stack, op_id): 140 | cond = test 141 | if cond: 142 | body 143 | else: 144 | orelse 145 | push(_stack, cond, op_id) 146 | 147 | 148 | @adjoint(gast.If) 149 | def dif_(cond, adjoint_body, adjoint_orelse, pop, _stack, op_id): 150 | cond = pop(_stack, op_id) 151 | if cond: 152 | adjoint_body 153 | else: 154 | adjoint_orelse 155 | 156 | 157 | # Binary ops: z = op(x, y) 158 | @adjoint(gast.Mult) 159 | def mult(z, x, y): 160 | d[x] = tangent.unbroadcast(d[z] * y, x) 161 | d[y] = tangent.unbroadcast(d[z] * x, y) 162 | 163 | 164 | @adjoint(gast.Add) 165 | def add(z, x, y): 166 | d[x] = tangent.unbroadcast(d[z], x) 167 | d[y] = tangent.unbroadcast(d[z], y) 168 | 169 | 170 | @adjoint(gast.Pow) 171 | def pow(z, x, y): 172 | d[x] = y * x ** (y - 1) * d[z] 173 | d[y] = numpy.log(x) * x ** y * d[z] 174 | 175 | 176 | @adjoint(gast.Sub) 177 | def sub(z, x, y): 178 | d[x] = tangent.unbroadcast(d[z], x) 179 | d[y] = -tangent.unbroadcast(d[z], y) 180 | 181 | 182 | @adjoint(gast.Div) 183 | def div(z, x, y): 184 | d[x] = d[z] / y 185 | d[y] = -d[z] * x / (y * y) 186 | 187 | 188 | # Unary ops: y = op(x) 189 | @adjoint(gast.USub) 190 | def usub(y, x): 191 | d[x] = -d[y] 192 | 193 | 194 | @adjoint(gast.UAdd) 195 | def uadd(y, x): 196 | d[x] = d[y] 197 | 198 | 199 | # 200 | # NumPy adjoints 201 | # 202 | 203 | 204 | @adjoint(numpy.log) 205 | def log(y, x): 206 | d[x] = d[y] / x 207 | 208 | 209 | @adjoint(numpy.cos) 210 | def cos(y, x): 211 | d[x] = -d[y] * numpy.sin(x) 212 | 213 | 214 | @adjoint(numpy.sin) 215 | def sin(y, x): 216 | d[x] = d[y] * numpy.cos(x) 217 | 218 | 219 | @adjoint(numpy.tan) 220 | def tan(y, x): 221 | cx = numpy.cos(x) 222 | d[x] = d[y] / (cx * cx) 223 | 224 | 225 | @adjoint(numpy.cosh) 226 | def cosh(y, x): 227 | d[x] = d[y] * numpy.sinh(x) 228 | 229 | 230 | @adjoint(numpy.sinh) 231 | def sinh(y, x): 232 | d[x] = d[y] * 
numpy.cosh(x) 233 | 234 | 235 | @adjoint(numpy.tanh) 236 | def tanh(y, x): 237 | d[x] = d[y] * (1.0 - (y * y)) 238 | 239 | 240 | @adjoint(numpy.arccos) 241 | def arccos(y, x): 242 | d[x] = -d[y] / numpy.sqrt(1.0 - x * x) 243 | 244 | 245 | @adjoint(numpy.arcsin) 246 | def arcsin(y, x): 247 | d[x] = d[y] / numpy.sqrt(1.0 - x * x) 248 | 249 | 250 | @adjoint(numpy.arctan) 251 | def arctan(y, x): 252 | d[x] = d[y] / (1.0 + x * x) 253 | 254 | 255 | @adjoint(numpy.exp) 256 | def exp(y, x): 257 | d[x] = y * d[y] 258 | 259 | 260 | @adjoint(numpy.sqrt) 261 | def sqrt(y, x): 262 | d[x] = d[y] / (2.0 * y) 263 | 264 | 265 | @adjoint(numpy.multiply) 266 | def multiply(z, x, y): 267 | d[x] = y * d[z] 268 | d[y] = x * d[z] 269 | 270 | 271 | @adjoint(numpy.dot) 272 | def dot(y, x1, x2): 273 | d[x1] = tangent.grad_dot(d[y], x1, x2) 274 | d[x2] = numpy.transpose(tangent.grad_dot(numpy.transpose(d[y]), 275 | numpy.transpose(x2), 276 | numpy.transpose(x1))) 277 | 278 | 279 | @adjoint(numpy.atleast_1d) 280 | def atleast_1d(y, x): 281 | d[x] = numpy.reshape(d[y], numpy.shape(x)) 282 | 283 | 284 | @adjoint(numpy.atleast_2d) 285 | def atleast_2d(y, x): 286 | d[x] = numpy.reshape(d[y], numpy.shape(x)) 287 | 288 | 289 | @adjoint(numpy.atleast_3d) 290 | def atleast_3d(y, x): 291 | d[x] = numpy.reshape(d[y], numpy.shape(x)) 292 | 293 | 294 | @adjoint(numpy.reshape) 295 | def reshape(y, x, y_shape): 296 | d[x] = numpy.reshape(d[y], numpy.shape(x)) 297 | 298 | 299 | @adjoint(numpy.transpose) 300 | def transpose(y, x): 301 | d[x] = numpy.transpose(d[y]) 302 | 303 | 304 | @adjoint(numpy.broadcast_arrays) 305 | def broadcast_arrays(ys, *args): 306 | d[args] = tuple(tangent.unbroadcast_to(dy, numpy.shape(arg)) 307 | for arg, dy in zip(args, d[ys])) 308 | 309 | 310 | @adjoint(numpy.sum) 311 | def sum(y, x, axis=None, dtype=None, keepdims=False): 312 | d[x] = tangent.astype(tangent.unreduce(d[y], numpy.shape(x), 313 | axis, keepdims), x) 314 | 315 | 316 | @adjoint(numpy.mean) 317 | def mean(y, x, axis=None, dtype=None, keepdims=False): 318 | n = tangent.astype(tangent.array_size(x, axis), x) 319 | d[x] = tangent.astype(tangent.unreduce(d[y], numpy.shape(x), 320 | axis, keepdims), x) / n 321 | 322 | 323 | @adjoint(numpy.maximum) 324 | def maximum(ans, x, y): 325 | d[x] = d[ans] * tangent.balanced_eq(x, ans, y) 326 | d[y] = d[ans] * tangent.balanced_eq(y, ans, x) 327 | 328 | 329 | @adjoint(numpy.array) 330 | def aarray(ans,x): 331 | d[x] = tangent.astype(d[ans],x) 332 | 333 | 334 | @adjoint(numpy.linalg.det) 335 | def adet(z, x): 336 | """d|A|/dA = adj(A).T 337 | 338 | See Jacobi's formula: https://en.wikipedia.org/wiki/Jacobi%27s_formula 339 | """ 340 | adjugate = numpy.linalg.det(x) * numpy.linalg.pinv(x) 341 | d[x] = d[z] * numpy.transpose(adjugate) 342 | 343 | 344 | # 345 | # Tangent adjoints 346 | # 347 | 348 | 349 | @adjoint(tangent.unreduce) 350 | def aunreduce(y, x, shape, axis, keepdims): 351 | d[x] = tangent.unbroadcast(d[y], x) 352 | 353 | 354 | @adjoint(tangent.unbroadcast) 355 | def aunbroadcast(y, x, shape): 356 | d[x] = tangent.unreduce_like(d[y], x, None, False) 357 | 358 | 359 | @adjoint(tangent.add_grad) 360 | def aadd_grad(z, left, right): 361 | d[left] = tangent.unbroadcast(d[z], left) 362 | d[right] = tangent.unbroadcast(d[z], right) 363 | 364 | 365 | @adjoint(tangent.astype) 366 | def aastype(z, array, y): 367 | d[array] = tangent.astype(d[z], array) 368 | 369 | 370 | @adjoint(tangent.push) 371 | def apush(stack, val, op_id): 372 | d[val] = tangent.pop(stack, d[op_id]) 373 | 374 | 375 | 
@adjoint(tangent.pop) 376 | def apop(z, stack, op_id): 377 | tangent.push(stack, d[z], d[op_id]) 378 | 379 | 380 | @adjoint(tangent.push_stack) 381 | def apush_stack(stack, val, op_id): 382 | d[val] = tangent.pop_stack(stack, d[op_id]) 383 | 384 | 385 | @adjoint(tangent.pop_stack) 386 | def apop_stack(z, stack, op_id): 387 | tangent.push_stack(stack, d[z], d[op_id]) 388 | 389 | 390 | @adjoint(tangent.copy) 391 | def acopy(z, x): 392 | d[x] = tangent.copy(d[z]) 393 | 394 | # 395 | # Tracing primitives 396 | # 397 | 398 | 399 | @primal(tracing.Traceable) 400 | def traceable_primal(result, fn, vjp, tmp, args): 401 | result, vjp = tangent.trace_grad(fn, args) 402 | 403 | 404 | @adjoint(tracing.Traceable) 405 | def traceable_adjoint(result, vjp, dargs): 406 | dargs = vjp(d[result]) 407 | 408 | 409 | # 410 | # Blacklist unimplemented NumPy grads 411 | # 412 | 413 | # We can enumerate all of the functions that we'd like grads for. 414 | # Until we've written the adjoints of all functions we want to support, 415 | # we will throw an explicit "no grad found" error for those we have not 416 | # finished. UNIMPLEMENTED will contain the list of all of these unimplemented 417 | # grad functions 418 | UNIMPLEMENTED_ADJOINTS = get_module_functions( 419 | (numpy, numpy.fft, numpy.linalg, numpy.random, math)) - set(adjoints) 420 | -------------------------------------------------------------------------------- /tangent/grammar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
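For illustration only: the `adjoint` decorator defined near the top of tangent/grads.py simply stores the decorated template in the `adjoints` dictionary, keyed by the primal. `numpy.log1p` is used here as an example of an op with no adjoint in the file; the template body uses the `d[...]` pseudo-syntax and is never executed, only parsed and inlined by Tangent.

import numpy
from tangent.grads import adjoint, adjoints

@adjoint(numpy.log1p)
def log1p(y, x):
  # y = log(1 + x), so dy/dx = 1 / (1 + x)
  d[x] = d[y] / (1.0 + x)

print(adjoints[numpy.log1p] is log1p)   # True

Note that `UNIMPLEMENTED_ADJOINTS` is computed once at import time from the modules listed at the bottom of the file, so derivatives intended for general use live in grads.py itself rather than being registered at runtime.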
14 | """Classifications of AST nodes.""" 15 | from __future__ import absolute_import 16 | import gast 17 | 18 | LITERALS = (gast.Num, gast.Str, gast.Bytes, gast.Ellipsis, gast.NameConstant) 19 | 20 | CONTROL_FLOW = (gast.For, gast.AsyncFor, gast.While, gast.If, gast.Try, 21 | gast.Break, gast.Continue) 22 | 23 | COMPOUND_STATEMENTS = ( 24 | gast.FunctionDef, 25 | gast.ClassDef, 26 | gast.For, 27 | gast.While, 28 | gast.If, 29 | gast.With, 30 | gast.Try, 31 | gast.AsyncFunctionDef, 32 | gast.AsyncFor, 33 | gast.AsyncWith 34 | ) 35 | 36 | SIMPLE_STATEMENTS = ( 37 | gast.Return, 38 | gast.Delete, 39 | gast.Assign, 40 | gast.AugAssign, 41 | gast.Raise, 42 | gast.Assert, 43 | gast.Import, 44 | gast.ImportFrom, 45 | gast.Global, 46 | gast.Nonlocal, 47 | gast.Expr, 48 | gast.Pass, 49 | gast.Break, 50 | gast.Continue 51 | ) 52 | 53 | STATEMENTS = COMPOUND_STATEMENTS + SIMPLE_STATEMENTS 54 | 55 | BLOCKS = ( 56 | (gast.Module, 'body'), 57 | (gast.FunctionDef, 'body'), 58 | (gast.AsyncFunctionDef, 'body'), 59 | (gast.For, 'body'), 60 | (gast.For, 'orelse'), 61 | (gast.AsyncFor, 'body'), 62 | (gast.AsyncFor, 'orelse'), 63 | (gast.While, 'body'), 64 | (gast.While, 'orelse'), 65 | (gast.If, 'body'), 66 | (gast.If, 'orelse'), 67 | ) 68 | -------------------------------------------------------------------------------- /tangent/naming.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tools for naming conventions.""" 15 | from __future__ import absolute_import 16 | import random 17 | import re 18 | import types 19 | 20 | import gast 21 | import six 22 | 23 | PRIMAL_NAME = 'pri_{}{}' 24 | ADJOINT_NAME = '_d{}d{}' 25 | TANGENT_NAME = '_t{}t{}' 26 | JOINT_NAME = 'd{}d{}' 27 | STACK_NAME = '_stack' 28 | SUBSTACK_NAME = '_substack' 29 | 30 | 31 | def primal_name(func, wrt): 32 | """Name for the primal of a function.""" 33 | if not isinstance(func, types.FunctionType): 34 | raise TypeError(func) 35 | varnames = six.get_function_code(func).co_varnames 36 | return PRIMAL_NAME.format(func.__name__, ''.join(varnames[i] for i in wrt)) 37 | 38 | 39 | def _adjoint_name(func, wrt, template): 40 | if not isinstance(func, types.FunctionType): 41 | raise TypeError 42 | varnames = six.get_function_code(func).co_varnames 43 | return template.format(func.__name__, ''.join(varnames[i] for i in wrt)) 44 | 45 | 46 | def joint_name(func, wrt): 47 | """Name for a function in joint mode.""" 48 | return _adjoint_name(func, wrt, JOINT_NAME) 49 | 50 | 51 | def adjoint_name(func, wrt): 52 | """Name for the adjoint of a function.""" 53 | return _adjoint_name(func, wrt, ADJOINT_NAME) 54 | 55 | 56 | def tangent_name(func, wrt): 57 | """Name for a function in forward mode.""" 58 | return _adjoint_name(func, wrt, TANGENT_NAME) 59 | 60 | 61 | class Names(gast.NodeVisitor): 62 | 63 | def __init__(self): 64 | self.names = set() 65 | 66 | def visit_Name(self, node): 67 | if isinstance(node.ctx, (gast.Store, gast.Param)): 68 | self.names.add(node.id) 69 | 70 | 71 | def get_names(node): 72 | """Find the arguments and variables assigned to in a certain node.""" 73 | names = Names() 74 | names.visit(node) 75 | return names.names 76 | 77 | 78 | def uniqify(func): 79 | """Make sure that a method returns a unique name.""" 80 | @six.wraps(func) 81 | def unique(self, *args, **kwargs): 82 | return self.unique(func(self, *args, **kwargs)) 83 | return unique 84 | 85 | 86 | def uniqify_once(func): 87 | """Make sure that a method returns a unique name.""" 88 | @six.wraps(func) 89 | def unique_once(self, *args, **kwargs): 90 | return self.unique_once(func(self, *args, **kwargs)) 91 | return unique_once 92 | 93 | 94 | class Namer(object): 95 | """Generate human-readable names for AST nodes. 96 | 97 | Given an AST node, this class tries to produce a sensible variable name 98 | that it could be subtituted with. 99 | 100 | In principle, it will try to construct sensible names from the operands and 101 | operator e.g. `x + y` becomes `x_plus_y`. However, the length of these 102 | variable names can quickly explode. In that case, we try to back off to using 103 | the left hand side of the statement if possible e.g. in `z = f(x + y)` the 104 | expression `x + y` could be named `_z`. 105 | 106 | In case the LHS is not available (because it wasn't given by the calling 107 | code) or if the LHS name is too long, we fall back to assigning random 108 | variable names. 109 | 110 | Some methods (such as `grad`) will return the same name when called with the 111 | same inputs. 112 | 113 | Attributes: 114 | names: A set of variable names that cannot be used. Allowed to be changed. 115 | target: The node that is on the LHS of the current statement. Is `None` by 116 | default. Should be set by the calling code. 
117 | """ 118 | # Naming convention from 'Evaluating Derivatives', b is rev, d is fwd 119 | ADJOINT_VAR = 'b{}' 120 | TANGENT_VAR = 'd{}' 121 | TEMP_VAR = '_{}' 122 | TEMP_ADJOINT_VAR = '_b{}' 123 | TEMP_TANGENT_VAR = '_d{}' 124 | 125 | MAX_LENGTH = 15 126 | 127 | def __init__(self): 128 | self.names = set() 129 | self.name_mappings = dict() 130 | # The targets field of the LHS whenever a node inside an assign statement 131 | # is being named 132 | self.target = None 133 | 134 | @classmethod 135 | def build(cls, node): 136 | """Construct a namer object for a given function scope.""" 137 | if not isinstance(node, gast.FunctionDef): 138 | raise ValueError 139 | namer = cls() 140 | namer.names.update(get_names(node)) 141 | return namer 142 | 143 | def valid(self, name): 144 | """Ensure a variable name is valid. 145 | 146 | Note: Assumes variable names are ASCII, which isn't necessarily true in 147 | Python 3. 148 | 149 | Args: 150 | name: A proposed variable name. 151 | 152 | Returns: 153 | A valid version of the name. 154 | """ 155 | name = re.sub('[^0-9a-zA-Z_]', '', name) 156 | if re.match('[0-9]', name): 157 | name = '_' + name 158 | return name 159 | 160 | def trim(self, name): 161 | """When the name is too long, use the LHS or a random string instead.""" 162 | if len(name) > self.MAX_LENGTH and self.target: 163 | name = self.TEMP_VAR.format(self._name(self.target)) 164 | if len(name) > self.MAX_LENGTH: 165 | while True: 166 | name = '_{:04x}'.format(random.randint(0, 16 ** 4 - 1)) 167 | if name not in self.names: 168 | break 169 | return name 170 | 171 | def unique(self, name): 172 | """Make a variable name unique by appending a number if needed.""" 173 | # Make sure the name is valid 174 | name = self.valid(name) 175 | # Make sure it's not too long 176 | name = self.trim(name) 177 | # Now make sure it's unique 178 | unique_name = name 179 | i = 2 180 | while unique_name in self.names: 181 | unique_name = name + str(i) 182 | i += 1 183 | self.names.add(unique_name) 184 | return unique_name 185 | 186 | def unique_once(self, name): 187 | if name not in self.name_mappings: 188 | unique_name = self.unique(name) 189 | self.name_mappings[name] = unique_name 190 | return self.name_mappings[name] 191 | 192 | def __getattr__(self, attr): 193 | """Access unwrapped versions of methods. 194 | 195 | Methods are wrapped with `uniqify` to return a unique version of a 196 | name. Internally the class however might want to use the original 197 | version of these methods. This method makes those accessible by using a 198 | leading underscore. 
199 | """ 200 | if attr.startswith('_') and hasattr(self, attr[1:]): 201 | return getattr(self, attr[1:]).__wrapped__.__get__(self, Namer) 202 | raise AttributeError 203 | 204 | @uniqify 205 | def name(self, node): 206 | namer = getattr(self, 'name_' + node.__class__.__name__) 207 | return namer(node) 208 | 209 | @uniqify 210 | def counter(self): 211 | return 'i' 212 | 213 | @uniqify_once 214 | def grad(self, name, tangent=False): 215 | if tangent: 216 | var_template = self.TANGENT_VAR 217 | else: 218 | var_template = self.ADJOINT_VAR 219 | return var_template.format(name) 220 | 221 | @uniqify 222 | def temp_grad(self, name, tangent=False): 223 | if tangent: 224 | var_template = self.TEMP_TANGENT_VAR 225 | else: 226 | var_template = self.TEMP_ADJOINT_VAR 227 | return var_template.format(name) 228 | 229 | @uniqify_once 230 | def temp(self, name): 231 | return self.TEMP_VAR.format(name) 232 | 233 | @uniqify 234 | def cond(self): 235 | return 'cond' 236 | 237 | def name_Name(self, node): 238 | return node.id 239 | 240 | def name_Return(self, node): 241 | return 'return' 242 | 243 | def name_Tuple(self, node): 244 | return 't' 245 | 246 | def name_List(self, node): 247 | return 'l' 248 | 249 | def name_Call(self, node): 250 | if len(node.args) <= 2: 251 | return (self._name(node.func) + '_' + 252 | '_'.join(self._name(arg) for arg in node.args)) 253 | else: 254 | return self._name(node.func) 255 | 256 | def name_Attribute(self, node): 257 | return self._name(node.value) + '_' + node.attr 258 | 259 | def name_Subscript(self, node): 260 | return self._name(node.value) + '_' + self._name(node.slice) 261 | 262 | def name_Index(self, node): 263 | return self._name(node.value) 264 | 265 | def name_Slice(self, node): 266 | return ''.join(self._name(i) if i else '' 267 | for i in (node.lower, node.upper, node.step)) 268 | 269 | def name_ExtSlice(self, node): 270 | return '_'.join(self._name(d) for d in node.dims) 271 | 272 | def name_Num(self, node): 273 | num_str = str(node.n) 274 | num_str = num_str.replace('.', '_') 275 | num_str = num_str.replace('-', 'neg') 276 | num_str = num_str.replace('+', 'plus') 277 | return num_str 278 | 279 | def name_Str(self, node): 280 | return node.s 281 | 282 | BINOP_NAMES = { 283 | gast.Add: 'plus', 284 | gast.Sub: 'minus', 285 | gast.Mult: 'times', 286 | gast.Div: 'over', 287 | gast.FloorDiv: 'intdiv', 288 | gast.Mod: 'modulo', 289 | gast.Pow: 'to_the', 290 | gast.MatMult: 'times' 291 | } 292 | 293 | def name_BinOp(self, node): 294 | return '{left}_{op}_{right}'.format(left=self._name(node.left), 295 | right=self._name(node.right), 296 | op=self.BINOP_NAMES[type(node.op)]) 297 | 298 | UNARYOP_NAMES = { 299 | gast.UAdd: 'plus', 300 | gast.USub: 'minus', 301 | gast.Not: 'not' 302 | } 303 | 304 | def name_UnaryOp(self, node): 305 | return '{op}_{operand}'.format(op=self.UNARYOP_NAMES[type(node.op)], 306 | operand=self._name(node.operand)) 307 | -------------------------------------------------------------------------------- /tangent/non_differentiable.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Non-differentiable functions.
15 | 
16 | Not in the mathematical sense, but in the sense of them providing zero gradient
17 | because they provide meta-information (shape), do integer arithmetic, or are
18 | tensor constructors.
19 | 
20 | Note that one still needs to provide tangents for non-differentiable functions,
21 | but these should simply call the original.
22 | TODO: Remove this requirement.
23 | """
24 | from __future__ import absolute_import
25 | 
26 | import numpy
27 | import tangent
28 | 
29 | 
30 | NON_DIFFERENTIABLE = set([
31 |     len,
32 |     numpy.shape, numpy.zeros, numpy.ones, numpy.zeros_like, numpy.ones_like,
33 | ])
34 | 
35 | 
36 | def register_non_differentiable_functions(*funcs):
37 |   global NON_DIFFERENTIABLE
38 |   NON_DIFFERENTIABLE |= set(funcs)
39 | 
--------------------------------------------------------------------------------
/tangent/optimization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Functions which perform compiler-style optimizations on the AST."""
15 | from __future__ import absolute_import
16 | from collections import defaultdict
17 | import gast
18 | 
19 | from tangent import annotate
20 | from tangent import annotations as anno
21 | from tangent import cfg
22 | from tangent import quoting
23 | from tangent import transformers
24 | 
25 | 
26 | def fixed_point(f):
27 | 
28 |   def _fp(node):
29 |     while True:
30 |       before = quoting.to_source(node)
31 |       node = f(node)
32 |       after = quoting.to_source(node)
33 |       if before == after:
34 |         break
35 |     return node
36 | 
37 |   return _fp
38 | 
39 | 
40 | @fixed_point
41 | def optimize(node):
42 |   """Perform a series of optimization passes.
43 | 
44 |   This function performs a series of optimizations (dead code elimination,
45 |   constant folding, assignment propagation) on the given AST.
46 |   It optimizes the code repeatedly until reaching a fixed point. The fixed
47 |   point is determined roughly by checking whether the generated source code
48 |   changed after the latest pass.
49 | 
50 |   Args:
51 |     node: The AST to optimize.
52 |   Returns:
53 |     The optimized AST.
54 |   """
55 |   node = dead_code_elimination(node)
56 |   node = constant_folding(node)
57 |   node = assignment_propagation(node)
58 |   return node
59 | 
60 | 
61 | @fixed_point
62 | def dead_code_elimination(node):
63 |   """Perform a simple form of dead code elimination on a Python AST.
64 | 65 | This method performs reaching definitions analysis on all function 66 | definitions. It then looks for the definition of variables that are not used 67 | elsewhere and removes those definitions. 68 | 69 | This function takes into consideration push and pop statements; if a pop 70 | statement is removed, it will also try to remove the accompanying push 71 | statement. Note that this *requires dead code elimination to be performed on 72 | the primal and adjoint simultaneously*. 73 | 74 | Args: 75 | node: The AST to optimize. 76 | 77 | Returns: 78 | The optimized AST. 79 | """ 80 | to_remove = set(def_[1] for def_ in annotate.unused(node) 81 | if not isinstance(def_[1], (gast.arguments, gast.For))) 82 | for n in list(to_remove): 83 | for succ in gast.walk(n): 84 | if anno.getanno(succ, 'push', False): 85 | to_remove.add(anno.getanno(succ, 'push')) 86 | transformers.Remove(to_remove).visit(node) 87 | anno.clearanno(node) 88 | return node 89 | 90 | 91 | class ReadCounts(gast.NodeVisitor): 92 | """Find the number of times that each definition is used. 93 | 94 | Requires `ReachingDefinitions` analysis to have been performed. 95 | """ 96 | 97 | def __init__(self): 98 | self.n_read = defaultdict(int) 99 | 100 | def visit(self, node): 101 | if anno.hasanno(node, 'definitions_in'): 102 | self.reaching_definitions = anno.getanno(node, 'definitions_in') 103 | super(ReadCounts, self).visit(node) 104 | if anno.hasanno(node, 'definitions_in'): 105 | self.reaching_definitions = None 106 | 107 | def visit_Name(self, node): 108 | if isinstance(node.ctx, gast.Load): 109 | for def_ in self.reaching_definitions: 110 | if def_[0] == node.id: 111 | self.n_read[def_[1]] += 1 112 | 113 | 114 | def read_counts(node): 115 | """Check how many times a variable definition was used. 116 | 117 | Args: 118 | node: An AST to analyze. 119 | 120 | Returns: 121 | A dictionary from assignment nodes to the number of times the assigned to 122 | variable was used. 123 | """ 124 | cfg.forward(node, cfg.ReachingDefinitions()) 125 | 126 | rc = ReadCounts() 127 | rc.visit(node) 128 | return rc.n_read 129 | 130 | 131 | @fixed_point 132 | def assignment_propagation(node): 133 | """Perform assignment propagation. 134 | 135 | Assignment propagation is not a compiler optimization as much as a 136 | readability optimization. If a variable name is used only once, it gets 137 | renamed when possible e.g. `y = x; z = y` will become `z = x`. 138 | 139 | Args: 140 | node: The AST to optimize. 141 | 142 | Returns: 143 | The optimized AST. 144 | """ 145 | n_reads = read_counts(node) 146 | 147 | to_remove = [] 148 | for succ in gast.walk(node): 149 | # We found an assignment of the form a = b 150 | # - Left-hand side is a Name, right-hand side is a Name. 
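    # Illustrative example: given
    #   y = x
    #   z = y
    # the definition `y = x` reaches `z = y` and `y` is read exactly once, so
    # the second statement is rewritten to `z = x` and `y = x` is marked for
    # removal below.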
151 | if (isinstance(succ, gast.Assign) and isinstance(succ.value, gast.Name) and 152 | len(succ.targets) == 1 and isinstance(succ.targets[0], gast.Name)): 153 | rhs_name = succ.value.id 154 | # We now find all the places that b was defined 155 | rhs_defs = [def_[1] for def_ in anno.getanno(succ, 'definitions_in') 156 | if def_[0] == rhs_name] 157 | # If b was defined in only one place (not an argument), and wasn't used 158 | # anywhere else but in a == b, and was defined as b = x, then we can fold 159 | # the statements 160 | if (len(rhs_defs) == 1 and isinstance(rhs_defs[0], gast.Assign) and 161 | n_reads[rhs_defs[0]] == 1 and 162 | isinstance(rhs_defs[0].value, gast.Name) and 163 | isinstance(rhs_defs[0].targets[0], gast.Name)): 164 | # Mark rhs_def for deletion 165 | to_remove.append(rhs_defs[0]) 166 | # Propagate the definition 167 | succ.value = rhs_defs[0].value 168 | 169 | # Remove the definitions we folded 170 | transformers.Remove(to_remove).visit(node) 171 | anno.clearanno(node) 172 | return node 173 | 174 | 175 | class ConstantFolding(gast.NodeTransformer): 176 | 177 | def visit_BinOp(self, node): 178 | self.generic_visit(node) 179 | left_val = node.left 180 | right_val = node.right 181 | left_is_num = isinstance(left_val, gast.Num) 182 | right_is_num = isinstance(right_val, gast.Num) 183 | 184 | if isinstance(node.op, gast.Mult): 185 | if left_is_num and right_is_num: 186 | return gast.Num(left_val.n * right_val.n) 187 | if left_is_num: 188 | if left_val.n == 0: 189 | return gast.Num(0) 190 | elif left_val.n == 1: 191 | return right_val 192 | if right_is_num: 193 | if right_val.n == 0: 194 | return gast.Num(0) 195 | elif right_val.n == 1: 196 | return left_val 197 | elif isinstance(node.op, gast.Add): 198 | if left_is_num and right_is_num: 199 | return gast.Num(left_val.n + right_val.n) 200 | if left_is_num and left_val.n == 0: 201 | return right_val 202 | if right_is_num and right_val.n == 0: 203 | return left_val 204 | elif isinstance(node.op, gast.Sub): 205 | if left_is_num and right_is_num: 206 | return gast.Num(left_val.n - right_val.n) 207 | if left_is_num and left_val.n == 0: 208 | return gast.UnaryOp(op=gast.USub(), operand=right_val) 209 | if right_is_num and right_val.n == 0: 210 | return left_val 211 | elif isinstance(node.op, gast.Div): 212 | if left_is_num and right_is_num: 213 | return gast.Num(left_val.n / right_val.n) 214 | if right_is_num and right_val.n == 1: 215 | return left_val 216 | elif isinstance(node.op, gast.Pow): 217 | if left_is_num and right_is_num: 218 | return gast.Num(left_val.n ** right_val.n) 219 | if left_is_num: 220 | if left_val.n == 0: 221 | return gast.Num(0) 222 | elif left_val.n == 1: 223 | return gast.Num(1) 224 | if right_is_num: 225 | if right_val.n == 0: 226 | return gast.Num(1) 227 | elif right_val.n == 1: 228 | return left_val 229 | return node 230 | 231 | 232 | @fixed_point 233 | def constant_folding(node): 234 | """Perform constant folding. 235 | 236 | This function also uses arithmetic identities (like multiplying with one or 237 | adding zero) to simplify statements. However, it doesn't inline constants in 238 | expressions, so the simplifications don't propagate. 239 | 240 | Args: 241 | node: The AST to optimize. 242 | 243 | Returns: 244 | The optimized AST. 
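
  Example (illustrative):
    `y = x * 1` and `y = x + 0` both simplify to `y = x`, and `y = 2 ** 3` is
    folded to `y = 8`.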
245 | """ 246 | f = ConstantFolding() 247 | return f.visit(node) 248 | -------------------------------------------------------------------------------- /tangent/quoting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Moving between source code and AST.""" 15 | from __future__ import absolute_import 16 | import inspect 17 | import textwrap 18 | 19 | import astor 20 | import gast 21 | 22 | from tangent import annotations as anno 23 | 24 | 25 | class TangentParseError(SyntaxError): 26 | pass 27 | 28 | 29 | class SourceWithCommentGenerator(astor.codegen.SourceGenerator): 30 | """Source code generator that outputs comments.""" 31 | 32 | def __init__(self, *args, **kwargs): 33 | super(SourceWithCommentGenerator, self).__init__(*args, **kwargs) 34 | self.new_indentation = True 35 | 36 | def body(self, statements): 37 | self.new_indentation = True 38 | super(SourceWithCommentGenerator, self).body(statements) 39 | 40 | def visit(self, node, abort=astor.codegen.SourceGenerator.abort_visit): 41 | if anno.hasanno(node, 'comment'): 42 | comment = anno.getanno(node, 'comment') 43 | # Preprocess the comment to fit to maximum line width of 80 characters 44 | linewidth = 78 45 | if comment['location'] in ('above', 'below'): 46 | comment['text'] = comment['text'][:linewidth] 47 | n_newlines = 1 if self.new_indentation else 2 48 | if comment['location'] == 'above': 49 | self.result.append('\n' * n_newlines) 50 | self.result.append(self.indent_with * self.indentation) 51 | self.result.append('# %s' % comment['text']) 52 | super(SourceWithCommentGenerator, self).visit(node) 53 | elif comment['location'] == 'below': 54 | super(SourceWithCommentGenerator, self).visit(node) 55 | self.result.append('\n') 56 | self.result.append(self.indent_with * self.indentation) 57 | self.result.append('# %s' % comment['text']) 58 | self.result.append('\n' * (n_newlines - 1)) 59 | elif comment['location'] == 'right': 60 | super(SourceWithCommentGenerator, self).visit(node) 61 | self.result.append(' # %s' % comment['text']) 62 | else: 63 | raise TangentParseError('Only valid comment locations are ' 64 | 'above, below, right') 65 | else: 66 | self.new_indentation = False 67 | super(SourceWithCommentGenerator, self).visit(node) 68 | 69 | 70 | def to_source(node, indentation=' ' * 4): 71 | """Return source code of a given AST.""" 72 | if isinstance(node, gast.AST): 73 | node = gast.gast_to_ast(node) 74 | generator = SourceWithCommentGenerator(indentation, False, 75 | astor.string_repr.pretty_string) 76 | generator.visit(node) 77 | generator.result.append('\n') 78 | return astor.source_repr.pretty_source(generator.result).lstrip() 79 | 80 | 81 | def parse_function(fn): 82 | """Get the source of a function and return its AST.""" 83 | try: 84 | return parse_string(inspect.getsource(fn)) 85 | except (IOError, OSError) as e: 86 | raise ValueError( 87 | 'Cannot differentiate 
function: %s. Tangent must be able to access the '
88 |         'source code of the function. Functions defined in a Python '
89 |         'interpreter and functions backed by C extension modules do not '
90 |         'have accessible source code.' % e)
91 | 
92 | 
93 | def parse_string(src):
94 |   """Parse a string into an AST."""
95 |   return gast.parse(textwrap.dedent(src))
96 | 
97 | 
98 | def quote(src_string, return_expr=False):
99 |   """Go from source code to AST nodes.
100 | 
101 |   This function returns a tree without enclosing `Module` or `Expr` nodes.
102 | 
103 |   Args:
104 |     src_string: The source code to parse.
105 |     return_expr: Whether or not to return a containing expression. This can be
106 |       set to `True` if the result is to be part of a series of statements.
107 | 
108 |   Returns:
109 |     An AST of the given source code.
110 | 
111 |   """
112 |   node = parse_string(src_string)
113 |   body = node.body
114 |   if len(body) == 1:
115 |     if isinstance(body[0], gast.Expr) and not return_expr:
116 |       out = body[0].value
117 |     else:
118 |       out = body[0]
119 |   else:
120 |     out = node
121 |   return out
122 | 
123 | 
124 | def unquote(node):
125 |   """Go from an AST to source code."""
126 |   return to_source(node).strip()
127 | 
--------------------------------------------------------------------------------
/tangent/tangents.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Templates for tangent expressions.
15 | 
16 | The first argument to the tangent must be the return value of the primal.
17 | 
18 | Use `d[x]` to denote the derivative of a variable `x`.
19 | 
20 | If the primal returns a tuple, the first argument to the tangent is a tuple,
21 | and the tangent is supposed to define `d[y]` as a tuple.
22 | 
23 | Templates do not support use of `**kwargs`.
24 | 
25 | If a keyword argument isn't present in the tangent definition, it means that
26 | Tangent doesn't support it, and an error will be raised if it appears in
27 | user code.
28 | 
29 | Tangents have access to the inputs and outputs of the primal. They are expected
30 | to contain expressions for the derivative of the output. They don't
31 | have access to any intermediate variables from the primal.
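
For example, the tangent template for multiplication defined below implements
the product rule:

  @tangent_(gast.Mult)
  def tmult(z, x, y):
    d[z] = d[x] * y + x * d[y]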
32 | """ 33 | from __future__ import absolute_import 34 | 35 | import math 36 | 37 | import gast 38 | import numpy 39 | import tangent 40 | from tangent import grads 41 | 42 | tangents = {} 43 | tangent_ = grads.create_register(tangents) 44 | 45 | 46 | # 47 | # AST tangents 48 | # 49 | 50 | 51 | @tangent_(gast.Assign) 52 | def tassign(temp, tangent, target, value): 53 | temp = value 54 | tangent 55 | target = temp 56 | 57 | 58 | @tangent_(gast.Num) 59 | def tnum(z, x): 60 | d[z] = tangent.init_grad(x) 61 | 62 | 63 | @tangent_(gast.Name) 64 | def tname(z, x): 65 | d[z] = d[x] 66 | 67 | 68 | @tangent_(gast.Attribute) 69 | def tattr(z, x): 70 | d[z] = tangent.init_grad(x) 71 | 72 | 73 | @tangent_(gast.Subscript) 74 | def tsubscript(z, x): 75 | d[z] = d[x] 76 | 77 | 78 | # For a reference for primitive tangents, see: 79 | # https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers 80 | # or 81 | # https://en.wikipedia.org/wiki/Differentiation_rules 82 | # Note that we don't use "dual numbers", that's a data structure that's useful 83 | # for doing run-time forward-mode automatic differentiation. We're doing 84 | # compile-time autodiff, and we can keep track of the directional derivatives 85 | # in individual variables, with no need to store them alongside the 86 | # original values. 87 | 88 | 89 | @tangent_(gast.Add) 90 | def tadd(z, x, y): 91 | d[z] = d[x] + d[y] 92 | 93 | 94 | @tangent_(gast.Mult) 95 | def tmult(z, x, y): 96 | d[z] = d[x] * y + x * d[y] 97 | 98 | 99 | @tangent_(gast.Sub) 100 | def tsub(z, x, y): 101 | d[z] = d[x] - d[y] 102 | 103 | 104 | @tangent_(gast.Div) 105 | def tdiv(z, x, y): 106 | d[z] = (d[x] * y - x * d[y]) / (y * y) 107 | 108 | 109 | @tangent_(gast.Pow) 110 | def tpow(z, x, y): 111 | d[z] = y * (x ** (y - 1.0)) * d[x] 112 | 113 | 114 | @tangent_(gast.USub) 115 | def tusub(z, x): 116 | d[z] = -d[x] 117 | 118 | 119 | # 120 | # Collection tangents 121 | # 122 | 123 | 124 | @tangent_(tuple) 125 | def ttangent(z, x): 126 | d[z] = tuple(d[x]) 127 | 128 | 129 | @tangent_(list) 130 | def tlist(z, x): 131 | d[z] = list(d[x]) 132 | 133 | 134 | # 135 | # NumPy tangents 136 | # 137 | 138 | 139 | @tangent_(numpy.cos) 140 | def tcos(z, x): 141 | d[z] = -d[x] * numpy.sin(x) 142 | 143 | 144 | @tangent_(numpy.sin) 145 | def tsin(z, x): 146 | d[z] = d[x] * numpy.cos(x) 147 | 148 | 149 | @tangent_(numpy.tan) 150 | def ttan(z, x): 151 | cx = numpy.cos(x) 152 | d[z] = d[x] / (cx * cx) 153 | 154 | 155 | @tangent_(numpy.cosh) 156 | def tcosh(z, x): 157 | d[z] = d[x] * numpy.sinh(x) 158 | 159 | 160 | @tangent_(numpy.sinh) 161 | def tsinh(z, x): 162 | d[z] = d[x] * numpy.cosh(x) 163 | 164 | 165 | @tangent_(numpy.tanh) 166 | def ttanh(z, x): 167 | cx = numpy.cosh(x) 168 | d[z] = d[x] / (cx * cx) 169 | 170 | 171 | @tangent_(numpy.arccos) 172 | def tarccos(z, x): 173 | d[z] = -d[x] / numpy.sqrt(1.0 - x * x) 174 | 175 | 176 | @tangent_(numpy.arcsin) 177 | def tarcsin(z, x): 178 | d[z] = d[x] / numpy.sqrt(1.0 - x * x) 179 | 180 | 181 | @tangent_(numpy.arctan) 182 | def tarctan(z, x): 183 | d[z] = d[x] / (1.0 + x * x) 184 | 185 | 186 | @tangent_(numpy.exp) 187 | def texp(z, x): 188 | d[z] = d[x] * z 189 | 190 | 191 | @tangent_(numpy.log) 192 | def tlog(z, x): 193 | d[z] = d[x] / x 194 | 195 | 196 | @tangent_(numpy.sqrt) 197 | def tsqrt(z, x): 198 | d[z] = d[x] / (2 * z) 199 | 200 | 201 | @tangent_(numpy.dot) 202 | def tdot(z, x, y): 203 | d[z] = numpy.dot(d[x], y) + numpy.dot(x, d[y]) 204 | 205 | 206 | @tangent_(numpy.atleast_1d) 207 | def 
tatleast_1d(z, x): 208 | d[z] = numpy.atleast_1d(d[x]) 209 | 210 | 211 | @tangent_(numpy.atleast_2d) 212 | def tatleast_2d(z, x): 213 | d[z] = numpy.atleast_2d(d[x]) 214 | 215 | 216 | @tangent_(numpy.atleast_3d) 217 | def tatleast_3d(z, x): 218 | d[z] = numpy.atleast_3d(d[x]) 219 | 220 | 221 | @tangent_(numpy.transpose) 222 | def ttranspose(z, x): 223 | d[z] = numpy.transpose(d[x]) 224 | 225 | 226 | @tangent_(numpy.sum) 227 | def tsum(y, x, axis=None, dtype=None, keepdims=False): 228 | d[y] = numpy.sum(d[x], axis=axis, dtype=dtype, keepdims=keepdims) 229 | 230 | 231 | @tangent_(numpy.mean) 232 | def tmean( 233 | y, x, axis=None, dtype=None, keepdims=False): 234 | d[y] = numpy.mean(d[x], axis=axis, dtype=dtype, keepdims=keepdims) 235 | 236 | 237 | @tangent_(numpy.multiply) 238 | def tmultiply(z, x, y): 239 | d[z] = numpy.multiply(d[x], y) + numpy.multiply(x, d[y]) 240 | 241 | 242 | @tangent_(numpy.arange) 243 | def tarange(z, stop): 244 | d[z] = numpy.zeros_like(z) 245 | 246 | 247 | @tangent_(numpy.ndim) 248 | def tndim(z, x): 249 | d[z] = numpy.ndim(d[x]) 250 | 251 | 252 | @tangent_(numpy.rollaxis) 253 | def trollaxis(z, a, axis, start=0): 254 | d[z] = numpy.rollaxis(d[a], axis, start) 255 | 256 | 257 | @tangent_(numpy.shape) 258 | def tshape(z, x): 259 | d[z] = numpy.shape(d[x]) 260 | 261 | 262 | @tangent_(numpy.array) 263 | def tarray(z, x): 264 | d[z] = numpy.array(d[x]) 265 | 266 | 267 | # 268 | # Tangent tangents 269 | # 270 | 271 | 272 | @tangent_(tangent.add_grad) 273 | def tadd_grad(z, x, y): 274 | d[z] = tangent.add_grad(d[x], d[y]) 275 | 276 | 277 | @tangent_(tangent.init_grad) 278 | def tinit_grad(z, x, allow_lazy_initializer=False): 279 | d[z] = tangent.init_grad(d[x], allow_lazy_initializer=False) 280 | 281 | 282 | @tangent_(tangent.push) 283 | def tpush(x, stack, op_id): 284 | tangent.push(d[stack], d[x], d[op_id]) 285 | 286 | 287 | @tangent_(tangent.push_stack) 288 | def tpush_stack(x, stack, op_id): 289 | tangent.push_stack(d[stack], d[x], d[op_id]) 290 | 291 | 292 | @tangent_(tangent.pop) 293 | def tpop(x, stack, op_id): 294 | d[x] = tangent.pop(d[stack], d[op_id]) 295 | 296 | 297 | @tangent_(tangent.pop_stack) 298 | def tpop_stack(x, stack, op_id): 299 | d[x] = tangent.pop_stack(d[stack], d[op_id]) 300 | 301 | 302 | @tangent_(tangent.unbroadcast) 303 | def tunbroadcast(z, x, y): 304 | d[z] = tangent.unbroadcast(d[x], d[y]) 305 | 306 | 307 | @tangent_(tangent.Stack) 308 | def tstack(z): 309 | d[z] = tangent.Stack() 310 | 311 | 312 | @tangent_(tangent.astype) 313 | def tastype(z, x, y): 314 | d[z] = tangent.astype(d[x], d[y]) 315 | 316 | 317 | @tangent_(tangent.unreduce) 318 | def tunreduce(z, array, shape, axis, keepdims): 319 | d[z] = tangent.unreduce(d[array], d[shape], axis, keepdims) 320 | 321 | 322 | 323 | # Until we've written the adjoints of all functions we want to support, 324 | # we will throw an explicit "no tangent found" error for those we have not 325 | # finished. UNIMPLEMENTED will contain the list of all of these unimplemented 326 | # tangent functions 327 | 328 | UNIMPLEMENTED_TANGENTS = grads.get_module_functions( 329 | (numpy, numpy.fft, numpy.linalg, numpy.random, math)) - set(tangents) 330 | -------------------------------------------------------------------------------- /tangent/template.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Helper functions and classes for filling in templates. 15 | 16 | Functions can be used as templates. In this case, all the variables to be 17 | replaced should be function arguments. This allows static analysis to still 18 | work. For simple templates nodes can be passed as well. 19 | 20 | """ 21 | from __future__ import absolute_import 22 | 23 | import types 24 | import enum 25 | 26 | import gast 27 | import six 28 | from tangent import annotations as anno 29 | from tangent import ast as ast_ 30 | from tangent import create 31 | from tangent import naming 32 | from tangent import quoting 33 | from tangent import transformers 34 | 35 | 36 | class ReplaceTransformer(gast.NodeTransformer): 37 | """Replace variables with AST nodes. 38 | 39 | The context of the replacements is automatically set to load or store. 40 | 41 | """ 42 | 43 | def __init__(self, replacements): 44 | self.replacements = replacements 45 | self.seen = set() 46 | self.is_top = True 47 | 48 | def visit_Expr(self, node): 49 | if (isinstance(node.value, gast.Name) and 50 | node.value.id in self.replacements): 51 | return self.visit(node.value) 52 | self.generic_visit(node) 53 | return node 54 | 55 | def visit_FunctionDef(self, node): 56 | node = self.generic_visit(node) 57 | if node.name in self.replacements: 58 | node.name = self.replacements[node.name].id 59 | return node 60 | 61 | def visit_Name(self, node): 62 | if node.id in self.replacements: 63 | # NOTE In principle we don't want to copy, because it might break 64 | # references held in annotations, but we will copy if we have to to 65 | # avoid duplicate nodes 66 | if node.id in self.seen: 67 | new_nodes = ast_.copy_node(self.replacements[node.id]) 68 | else: 69 | self.seen.add(node.id) 70 | new_nodes = self.replacements[node.id] 71 | if isinstance(new_nodes, gast.AST): 72 | new_nodes = [new_nodes] 73 | for new_node in new_nodes: 74 | anno.setanno(new_node, 'replacement', node, safe=False) 75 | if 'ctx' in new_node._fields: 76 | new_node.ctx = node.ctx 77 | if len(new_nodes) == 1: 78 | new_nodes, = new_nodes 79 | return new_nodes 80 | else: 81 | return node 82 | 83 | 84 | Replace = enum.Enum('Replace', ['NONE', 'PARTIAL', 'FULL', 'TANGENT']) 85 | 86 | 87 | class ReplaceGradTransformer(transformers.TreeTransformer): 88 | """Interpret the gradient operator `d[x]` in templates. 89 | 90 | The gradient of a temporary variable is the normal gradient i.e. d[_x] = 91 | dx. 92 | 93 | Args: 94 | replace_grad: One of the enumerated `Replace` values. If `PARTIAL` then 95 | `d[x]` will be transformed into the gradient `bx` when read, but 96 | transformed into a temporary variable (e.g. `_bx`) when written to. 97 | This ensures that the gradient `bx` doesn't get overwritten if it 98 | already exists. If the mode is `FULL` then `d[x]` becomes the gradient 99 | `bx` everywhere. `TANGENT` functions as `FULL` but creates the tangent 100 | instead of the adjoint i.e. `dx`. 101 | namer: A `Namer` object which decides on the names to give to the 102 | gradients. This guarantess temporaries receiving unique names. 
103 |     tangent: Whether to create tangents or adjoints i.e. whether we are in
104 |       reverse or forward mode.
105 |   """
106 | 
107 |   def __init__(self, replace_grad, namer=None, tangent=False):
108 |     self.replace_grad = replace_grad
109 |     if namer is None:
110 |       namer = naming.Namer()
111 |     self.namer = namer
112 | 
113 |     self.tangent = tangent
114 |     super(ReplaceGradTransformer, self).__init__()
115 | 
116 |   def visit_Subscript(self, node):
117 |     if isinstance(node.value, gast.Name) and node.value.id == 'd':
118 |       if (not isinstance(node.slice, gast.Index) or
119 |           not isinstance(node.slice.value,
120 |                          (gast.Subscript, gast.Name, gast.Str))):
121 |         # This happens when the gradient of a constant is taken
122 |         if self.replace_grad == Replace.TANGENT:
123 |           new_node = gast.Num(0)
124 |         else:
125 |           new_node = gast.Name(id='_', ctx=None, annotation=None)
126 |           self.remove(new_node)
127 |       elif (self.replace_grad in (Replace.FULL, Replace.TANGENT) or
128 |             isinstance(node.ctx, gast.Load)):
129 |         new_node = create.create_grad(node.slice.value, self.namer,
130 |                                       self.tangent)
131 |       elif isinstance(node.ctx, gast.Store):
132 |         new_node = create.create_temp_grad(node.slice.value, self.namer,
133 |                                            self.tangent)
134 |       else:
135 |         raise ValueError
136 |       new_node.ctx = node.ctx
137 |       if isinstance(new_node, gast.Tuple):
138 |         for elt in new_node.elts:
139 |           elt.ctx = node.ctx
140 |       node = new_node
141 |     return node
142 | 
143 | 
144 | def replace(template, replace_grad=Replace.PARTIAL,
145 |             namer=None, **replacements):
146 |   """Replace placeholders in a Python template (quote).
147 | 
148 |   Args:
149 |     template: A function, AST node or string to be used as a template. Note
150 |       that if a function is passed, any placeholder is expected to also be a
151 |       function argument. If a string is passed, it must represent valid
152 |       Python code, and any variable it references is a placeholder.
153 |     replace_grad: If Replace.NONE, statements of the form `d[x]` are ignored.
154 |       For the other possible values, see `ReplaceGradTransformer`.
155 |     namer: See `ReplaceGradTransformer`.
156 |     **replacements: A mapping from placeholder names to (lists of) AST nodes
157 |       that these placeholders will be replaced by. If a string is passed,
158 |       `quote` will be called on it to turn it into a node.
159 | 
160 |   Returns:
161 |     body: An AST node or list of AST nodes with the replacements made. If the
162 |       template was a function, a list will be returned. If the template was a
163 |       node, the same node will be returned. If the template was a string, an
164 |       AST node will be returned (a `Module` node in the case of a multi-line
165 |       string, an `Expr` node otherwise).
166 | 
167 |   Raises:
168 |     ValueError: If a function is used as a template and an incorrect set of
169 |       replacements was passed.
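
  Example (with illustrative placeholder names):
    def template_fn(a, b):
      d[a] = d[b]

    stmts = replace(template_fn, a='x', b='y')
    # With the default PARTIAL mode this returns, roughly, the statement list
    # for `_bx = by`: the stored gradient becomes a temporary while the loaded
    # gradient keeps its adjoint name.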
170 | """ 171 | # Handle the 3 different types of templates: funcs, nodes, and strings 172 | is_function = isinstance(template, types.FunctionType) 173 | if is_function: 174 | tree = quoting.parse_function(template).body[0] 175 | placeholders = set(arg.id for arg in tree.args.args) 176 | tree.args.args = [] 177 | if tree.args.vararg: 178 | placeholders.add(tree.args.vararg) 179 | tree.args.vararg = None 180 | if set(replacements.keys()) != placeholders: 181 | raise ValueError('too many or few replacements') 182 | elif isinstance(template, gast.AST): 183 | tree = template 184 | else: 185 | tree = quoting.quote(template, return_expr=True) 186 | # If the replacements are strings, turn them into nodes 187 | for k, v in replacements.items(): 188 | if isinstance(v, six.string_types): 189 | replacements[k] = quoting.quote(v) 190 | # Perform the replacement 191 | ReplaceTransformer(replacements).visit(tree) 192 | # Handle the d[x] operator 193 | if replace_grad is not Replace.NONE: 194 | rgt = ReplaceGradTransformer( 195 | replace_grad=replace_grad, 196 | namer=namer, 197 | tangent=replace_grad is Replace.TANGENT) 198 | rgt.visit(tree) 199 | # Return the AST node with replacements made 200 | if is_function: 201 | return tree.body 202 | else: 203 | return tree 204 | -------------------------------------------------------------------------------- /tangent/tracing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for tracing code, a useful fallback when ahead-of-time AD fails. 15 | """ 16 | 17 | 18 | class Traceable(object): 19 | pass 20 | 21 | 22 | def trace_grad(fn, args): 23 | """Trace a function, and return a VJP and the function's output.""" 24 | from tensorflow.python.eager.backprop import make_vjp 25 | result, vjp = make_vjp(fn)(*args) 26 | return result, vjp 27 | 28 | 29 | def trace(fn): 30 | """Decorator that marks a function to be traced.""" 31 | fn.should_trace = True 32 | return fn 33 | -------------------------------------------------------------------------------- /tangent/transformers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """AST visiting and transformation patterns.""" 15 | 16 | from __future__ import absolute_import 17 | 18 | from collections import deque 19 | from copy import copy 20 | 21 | import gast 22 | from tangent import annotations as anno 23 | from tangent import grammar 24 | 25 | 26 | class TreeTransformer(gast.NodeTransformer): 27 | """A transformer that allows for non-local changes. 28 | 29 | An extension of the standard `NodeTransformer` in Python's `ast` package. 30 | This transformer can insert statements right before or after the current 31 | statement, at the end or beginning of the current block, or at the top of the 32 | function. 33 | 34 | This class is meant to be subclassed in the same way as Python's 35 | `NodeTransformer` class. The subclasses can then call the `append`, 36 | `prepend`, etc. methods as appropriate to transform the AST. 37 | 38 | Note that nodes that are appended or prepended using the `append` and 39 | `prepend` methods will be visited by the transformer. This means that they 40 | can recursively append or prepend statements of their own. This doesn't hold 41 | for statements that are appended/prepended to the block or function body; 42 | these inserted statements are not visited after being inserted. 43 | 44 | To see which nodes classify as statements or which node fields classify as 45 | blocks, please see `grammar.py`. 46 | 47 | Attributes: 48 | to_remove: After the initial pass, this contains a set of nodes that will 49 | be removed. A second pass is automatically performed using the `Remove` 50 | transformer to actually remove those nodes. 51 | 52 | """ 53 | 54 | def __init__(self): 55 | self.to_insert = [] 56 | self.to_prepend = [] 57 | self.to_append = [] 58 | self.to_prepend_block = [] 59 | self.to_append_block = [] 60 | self.to_insert_top = deque() 61 | self.to_remove = set() 62 | self._top = True 63 | 64 | def prepend(self, node): 65 | """Prepend a statement to the current statement. 66 | 67 | Note that multiple calls to prepend will result in the last statement to be 68 | prepended to end up at the top. 69 | 70 | Args: 71 | node: The statement to prepend. 72 | 73 | Raises: 74 | ValueError: If the given node is not a statement. 75 | 76 | """ 77 | if not isinstance(node, grammar.STATEMENTS): 78 | raise ValueError 79 | self.to_prepend[-1].appendleft(node) 80 | 81 | def append(self, node): 82 | """Append a statement to the current statement. 83 | 84 | Note that multiple calls to append will result in the last statement to be 85 | appended to end up at the bottom. 86 | 87 | Args: 88 | node: The statement to append. 89 | 90 | Raises: 91 | ValueError: If the given node is not a statement. 92 | 93 | """ 94 | if not isinstance(node, grammar.STATEMENTS): 95 | raise ValueError 96 | self.to_append[-1].append(node) 97 | 98 | def remove(self, node): 99 | """Remove the given node.""" 100 | self.to_remove.add(node) 101 | 102 | def insert_top(self, node): 103 | """Insert statements at the top of the function body. 104 | 105 | Note that multiple calls to `insert_top` will result in the statements 106 | being prepended in that order; this is different behavior from `prepend`. 107 | 108 | Args: 109 | node: The statement to prepend. 110 | 111 | Raises: 112 | ValueError: If the given node is not a statement. 113 | 114 | """ 115 | if not isinstance(node, grammar.STATEMENTS): 116 | raise ValueError 117 | self.to_insert_top.append(node) 118 | 119 | def prepend_block(self, node, reverse=False): 120 | """Prepend a statement to the current block. 
121 | 122 | Args: 123 | node: The statement to prepend. 124 | reverse: When called multiple times, this flag determines whether the 125 | statement should be prepended or appended to the already inserted 126 | statements. 127 | 128 | Raises: 129 | ValueError: If the given node is not a statement. 130 | 131 | """ 132 | if not isinstance(node, grammar.STATEMENTS): 133 | raise ValueError 134 | if reverse: 135 | self.to_prepend_block[-1].appendleft(node) 136 | else: 137 | self.to_prepend_block[-1].append(node) 138 | 139 | def append_block(self, node, reverse=False): 140 | """Append a statement to the current block. 141 | 142 | Args: 143 | node: The statement to prepend. 144 | reverse: When called multiple times, this flag determines whether the 145 | statement should be prepended or appended to the already inserted 146 | statements. 147 | 148 | Raises: 149 | ValueError: If the given node is not a statement. 150 | 151 | """ 152 | if not isinstance(node, grammar.STATEMENTS): 153 | raise ValueError 154 | if reverse: 155 | self.to_append_block[-1].appendleft(node) 156 | else: 157 | self.to_append_block[-1].append(node) 158 | 159 | def visit_statements(self, nodes): 160 | """Visit a series of nodes in a node body. 161 | 162 | This function is factored out so that it can be called recursively on 163 | statements that are appended or prepended. This allows e.g. a nested 164 | expression to prepend a statement, and that statement can prepend a 165 | statement again, etc. 166 | 167 | Args: 168 | nodes: A list of statements. 169 | 170 | Returns: 171 | A list of transformed statements. 172 | """ 173 | for node in nodes: 174 | if isinstance(node, gast.AST): 175 | self.to_prepend.append(deque()) 176 | self.to_append.append(deque()) 177 | node = self.visit(node) 178 | self.visit_statements(self.to_prepend.pop()) 179 | if isinstance(node, gast.AST): 180 | self.to_insert[-1].append(node) 181 | elif node: 182 | self.to_insert[-1].extend(node) 183 | self.visit_statements(self.to_append.pop()) 184 | else: 185 | self.to_insert[-1].append(node) 186 | return self.to_insert[-1] 187 | 188 | def generic_visit(self, node): 189 | is_top = False 190 | if self._top: 191 | is_top = True 192 | self._top = False 193 | for field, old_value in gast.iter_fields(node): 194 | if isinstance(old_value, list): 195 | if (type(node), field) in grammar.BLOCKS: 196 | self.to_prepend_block.append(deque()) 197 | self.to_append_block.append(deque()) 198 | self.to_insert.append(deque()) 199 | new_values = copy(self.visit_statements(old_value)) 200 | self.to_insert.pop() 201 | else: 202 | new_values = [] 203 | for value in old_value: 204 | if isinstance(value, gast.AST): 205 | value = self.visit(value) 206 | if value is None: 207 | continue 208 | elif not isinstance(value, gast.AST): 209 | new_values.extend(value) 210 | continue 211 | new_values.append(value) 212 | if isinstance(node, gast.FunctionDef) and field == 'body': 213 | new_values.extendleft(self.to_insert_top) 214 | self.to_insert_top = deque([]) 215 | if (type(node), field) in grammar.BLOCKS: 216 | new_values.extendleft(self.to_prepend_block.pop()) 217 | return_ = None 218 | if new_values and isinstance(new_values[-1], gast.Return): 219 | return_ = new_values.pop() 220 | new_values.extend(self.to_append_block.pop()) 221 | if return_: 222 | new_values.append(return_) 223 | old_value[:] = new_values 224 | elif isinstance(old_value, gast.AST): 225 | new_node = self.visit(old_value) 226 | if new_node is None: 227 | delattr(node, field) 228 | else: 229 | setattr(node, field, new_node) 
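    # Once the outermost call has finished visiting, run the deferred removal
    # pass over the transformed tree.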
230 | if is_top and self.to_remove: 231 | Remove(self.to_remove).visit(node) 232 | return node 233 | 234 | 235 | class Remove(gast.NodeTransformer): 236 | """Remove statements containing given nodes. 237 | 238 | If an entire block was deleted, it will delete the relevant conditional or 239 | loop entirely. Note that deleting an entire function body will result in an 240 | invalid AST. 241 | 242 | Calls to user functions that were generated by Tangent will not be removed 243 | because this might result in incorrect writing and reading from the tape. 244 | 245 | Args: 246 | to_remove: A set of nodes that need to be removed. Note that the entire 247 | statement containing this node will be removed e.g. `y = f(x)` with `x` 248 | being in `to_remove` will result in the entire statement being removed. 249 | 250 | """ 251 | 252 | def __init__(self, to_remove): 253 | self.to_remove = to_remove 254 | self.remove = False 255 | self.is_call = False 256 | 257 | def visit(self, node): 258 | if node in self.to_remove: 259 | self.remove = True 260 | if anno.hasanno(node, 'pri_call') or anno.hasanno(node, 'adj_call'): 261 | # We don't remove function calls for now; removing them also 262 | # removes the push statements inside of them, but not the 263 | # corresponding pop statements 264 | self.is_call = True 265 | new_node = super(Remove, self).visit(node) 266 | if isinstance(node, grammar.STATEMENTS): 267 | if self.remove and not self.is_call: 268 | new_node = None 269 | self.remove = self.is_call = False 270 | if isinstance(node, gast.If) and not node.body: 271 | # If we optimized away an entire if block, we need to handle that 272 | if not node.orelse: 273 | return 274 | else: 275 | node.test = gast.UnaryOp(op=gast.Not(), operand=node.test) 276 | node.body, node.orelse = node.orelse, node.body 277 | elif isinstance(node, (gast.While, gast.For)) and not node.body: 278 | return node.orelse 279 | return new_node 280 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Automatically test gradients with multiple inputs, modes and motions.""" 15 | import numpy as np 16 | import six 17 | 18 | import functions 19 | import tfe_utils 20 | 21 | 22 | blacklisted = [ 23 | 'inlining_contextmanager', 24 | 'listcomp', 25 | 'cart2polar', 26 | 'iterpower_with_nested_def', 27 | 'fn_multiple_return', 28 | 'insert_grad_of', 29 | '_trace_mul', 30 | '_nontrace_mul', 31 | 'active_subscript', # TODO: fix then remove from blacklist 32 | 'init_array_grad_maybe_active', # TODO: fix then remove from blacklist 33 | ] 34 | 35 | funcs = [f for f in functions.__dict__.values() if callable(f)] 36 | whitelist = [f for f in funcs if f.__name__ not in blacklisted] 37 | blacklist = [f for f in funcs if f.__name__ in blacklisted] 38 | 39 | 40 | def pytest_addoption(parser): 41 | # Only test with one input 42 | parser.addoption('--short', action='store_true') 43 | # Only test with all inputs 44 | parser.addoption('--all', action='store_true') 45 | # Restrict to certain functions by name 46 | parser.addoption('--func_filter', action='store') 47 | 48 | 49 | def pytest_generate_tests(metafunc): 50 | # Parametrize the functions 51 | if 'func' in metafunc.fixturenames: 52 | func_filter = metafunc.config.option.func_filter 53 | 54 | # Test takes args, only pass funcs with same signature 55 | args = tuple( 56 | arg for arg in metafunc.fixturenames 57 | if arg not in ('func', 'motion', 'optimized', 'preserve_result')) 58 | if args: 59 | func_args = [] 60 | for f in whitelist: 61 | fc = six.get_function_code(f) 62 | if fc.co_varnames[:fc.co_argcount] == args: 63 | func_args.append(f) 64 | else: 65 | func_args = funcs 66 | 67 | if func_filter: 68 | func_args = [f for f in func_args if func_filter in f.__name__] 69 | 70 | func_names = [f.__name__ for f in func_args] 71 | metafunc.parametrize('func', func_args, ids=func_names) 72 | 73 | if 'motion' in metafunc.fixturenames: 74 | metafunc.parametrize('motion', ('split', 'joint')) 75 | 76 | if 'optimized' in metafunc.fixturenames: 77 | metafunc.parametrize('optimized', (True, False), 78 | ids=('optimized', 'unoptimized')) 79 | 80 | if 'preserve_result' in metafunc.fixturenames: 81 | metafunc.parametrize('preserve_result', (True, False)) 82 | 83 | # Parametrize the arguments 84 | short = metafunc.config.option.short 85 | 86 | bools = [True, False] 87 | for arg in ['boolean', 'boolean1', 'boolean2']: 88 | if arg in metafunc.fixturenames: 89 | metafunc.parametrize(arg, bools) 90 | 91 | scalars = [2.] if short else [ 92 | -2., -1.5, -1., -0.5, -0.1, 0.1, 0.5, 1., 1.5, 2. 93 | ] 94 | for arg in 'abc': 95 | if arg in metafunc.fixturenames: 96 | metafunc.parametrize(arg, scalars) 97 | 98 | integers = [1] if short else [1, 2, 3] 99 | if 'n' in metafunc.fixturenames: 100 | metafunc.parametrize('n', integers) 101 | 102 | vectors = [np.random.randn(i) for i in ((3,) if short else (3, 5, 10))] 103 | if 'x' in metafunc.fixturenames: 104 | metafunc.parametrize('x', vectors) 105 | 106 | square_matrices = [np.random.randn(*i) for i in (((3, 3),) if short else ((1, 1), (5, 5)))] 107 | if 'sqm' in metafunc.fixturenames: 108 | metafunc.parametrize('sqm', square_matrices) 109 | 110 | tfe_utils.register_parametrizations(metafunc, short) 111 | -------------------------------------------------------------------------------- /tests/test_anf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import gast 15 | import pytest 16 | 17 | from tangent import anf 18 | from tangent import quoting 19 | 20 | 21 | def anf_lines(f): 22 | """Return the ANF transformed source code as lines.""" 23 | return quoting.unquote(anf.anf(quoting.parse_function(f))).split('\n') 24 | 25 | 26 | def anf_function(f, globals_=None): 27 | m = gast.gast_to_ast(anf.anf(quoting.parse_function(f))) 28 | m = gast.fix_missing_locations(m) 29 | exec(compile(m, '', 'exec'), globals_) 30 | return f 31 | 32 | 33 | def test_anf(): 34 | def g(x): 35 | return x * 2 36 | 37 | h = g 38 | 39 | def f(x): 40 | y = g(h(x)) 41 | return y 42 | 43 | assert anf_lines(f)[1].strip() == "h_x = h(x)" 44 | assert anf_function(f, locals())(2) == 8 45 | 46 | def f(x): 47 | return x * x * x 48 | 49 | assert 'return' in anf_lines(f)[-1] and '*' not in anf_lines(f)[-1] 50 | assert anf_function(f)(2) == 8 51 | 52 | def f(x): 53 | y = [(x.y[0],), 3] 54 | y += x * f(x[g(x)].b, (3, x / -2)) 55 | 56 | assert anf.anf(quoting.parse_function(f)) 57 | 58 | 59 | def test_long(): 60 | def f(x): 61 | return some_very_long_name(some_other_long_name(x)) 62 | 63 | # If a function name is long, we use the LHS or return statement for the name 64 | # instead 65 | assert anf_lines(f)[-1].strip() == 'return _return' 66 | assert anf_lines(f)[-3].strip().startswith('_return2 = ') 67 | 68 | def f(x): 69 | some_very_long_variable_name_here = f(some_very_long_function_name(x)) 70 | return some_very_long_variable_name_here 71 | 72 | # If both the target and function name are long, we should back off to short, 73 | # random variable names 74 | assert len(anf_lines(f)[-3].strip()) < 40 75 | 76 | 77 | if __name__ == '__main__': 78 | assert not pytest.main([__file__]) 79 | -------------------------------------------------------------------------------- /tests/test_annotate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import pytest 15 | 16 | from tangent import annotate 17 | from tangent import annotations as anno 18 | from tangent import quoting 19 | 20 | 21 | def test_resolve(): 22 | def g(x): 23 | return 2 * x 24 | 25 | def f(x): 26 | return g(x) 27 | 28 | node = annotate.resolve_calls(f) 29 | assert anno.getanno(node.body[0].body[0].value, 'func') == g 30 | 31 | def f(x): 32 | return h(x) 33 | 34 | node = quoting.parse_function(f) 35 | with pytest.raises(AttributeError): 36 | annotate.resolve_calls(f) 37 | 38 | 39 | def test_unused(): 40 | def f(x): 41 | y = x * 2 42 | return x 43 | 44 | node = quoting.parse_function(f) 45 | unused = annotate.unused(node) 46 | assert unused == set([('y', node.body[0].body[0])]) 47 | 48 | def f(x): 49 | y = x * 2 50 | return y 51 | 52 | unused = annotate.unused(quoting.parse_function(f)) 53 | assert not unused 54 | 55 | def f(x): 56 | while True: 57 | y = x * 2 58 | x = 3 59 | return y 60 | 61 | unused = annotate.unused(quoting.parse_function(f)) 62 | assert not unused 63 | 64 | 65 | if __name__ == '__main__': 66 | assert not pytest.main([__file__]) 67 | -------------------------------------------------------------------------------- /tests/test_cfg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import pytest 15 | 16 | import tangent 17 | from tangent import annotations as anno 18 | from tangent import cfg 19 | 20 | 21 | def f(x): 22 | x 23 | while True: 24 | x = x 25 | x = x 26 | return x 27 | 28 | 29 | def g(x): 30 | if x: 31 | y = 2 32 | return x 33 | 34 | 35 | def h(x, y): 36 | y = f(x) 37 | return y 38 | 39 | 40 | def i(x, y): 41 | z = h(x, y) 42 | x = z[0] 43 | return z 44 | 45 | 46 | def test_reaching(): 47 | node = tangent.quoting.parse_function(f) 48 | cfg.forward(node, cfg.ReachingDefinitions()) 49 | body = node.body[0].body 50 | # Only the argument reaches the expression 51 | assert len(anno.getanno(body[0], 'definitions_in')) == 1 52 | while_body = body[1].body 53 | # x can be either the argument here, or from the previous loop 54 | assert len(anno.getanno(while_body[0], 'definitions_in')) == 2 55 | # x can only be the previous line here 56 | assert len(anno.getanno(while_body[1], 'definitions_in')) == 1 57 | # x can be the argument here or the last definition from the while body 58 | assert len(anno.getanno(body[2], 'definitions_in')) == 2 59 | 60 | 61 | def test_defined(): 62 | node = tangent.quoting.parse_function(g) 63 | cfg.forward(node, cfg.Defined()) 64 | body = node.body[0].body 65 | # only x is for sure defined at the end 66 | assert len(anno.getanno(body[1], 'defined_in')) == 1 67 | # at the end of the if body both x and y are defined 68 | if_body = body[0].body 69 | assert len(anno.getanno(if_body[0], 'defined_out')) == 2 70 | 71 | 72 | def test_active(): 73 | node = tangent.quoting.parse_function(h) 74 | cfg.forward(node, cfg.Active(wrt=(1,))) 75 | body = node.body[0].body 76 | # y has been overwritten here, so nothing is active anymore 77 | assert not anno.getanno(body[-1], 'active_out') 78 | 79 | 80 | def test_active2(): 81 | node = tangent.quoting.parse_function(i) 82 | cfg.forward(node, cfg.Active(wrt=(1,))) 83 | body = node.body[0].body 84 | # through y both x and z are now active 85 | assert len(anno.getanno(body[-1], 'active_out')) == 3 86 | 87 | 88 | if __name__ == '__main__': 89 | assert not pytest.main([__file__]) 90 | -------------------------------------------------------------------------------- /tests/test_comments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import pytest 15 | 16 | from tangent import comments 17 | from tangent import quoting 18 | 19 | 20 | def f(x): 21 | y = x 22 | return y 23 | 24 | 25 | def test_comment(): 26 | node = quoting.parse_function(f).body[0] 27 | 28 | comments.add_comment(node.body[0], 'foo', 'above') 29 | source = quoting.to_source(node) 30 | lines = source.split('\n') 31 | assert lines[1].strip() == '# foo' 32 | 33 | comments.add_comment(node.body[0], 'foo', 'right') 34 | source = quoting.to_source(node) 35 | lines = source.split('\n') 36 | assert lines[1].strip() == 'y = x # foo' 37 | 38 | comments.add_comment(node.body[0], 'foo', 'below') 39 | source = quoting.to_source(node) 40 | lines = source.split('\n') 41 | assert lines[2].strip() == '# foo' 42 | 43 | 44 | if __name__ == '__main__': 45 | assert not pytest.main([__file__]) 46 | -------------------------------------------------------------------------------- /tests/test_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import inspect 15 | 16 | import gast 17 | import pytest 18 | 19 | from tangent import compile as compile_ 20 | from tangent import quoting 21 | 22 | 23 | def test_compile(): 24 | def f(x): 25 | return x * 2 26 | 27 | f = compile_.compile_function(quoting.parse_function(f)) 28 | assert f(2) == 4 29 | assert inspect.getsource(f).split('\n')[0] == 'def f(x):' 30 | 31 | def f(x): 32 | return y * 2 33 | 34 | f = compile_.compile_function(quoting.parse_function(f), {'y': 3}) 35 | assert f(2) == 6 36 | 37 | 38 | def test_function_compile(): 39 | with pytest.raises(TypeError): 40 | compile_.compile_function(quoting.quote('x = y')) 41 | with pytest.raises(ValueError): 42 | compile_.compile_function(gast.parse('x = y')) 43 | 44 | 45 | if __name__ == '__main__': 46 | assert not pytest.main([__file__]) 47 | -------------------------------------------------------------------------------- /tests/test_fence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Fence tests.""" 15 | import inspect 16 | import sys 17 | 18 | import pytest 19 | from tangent import fence 20 | from tangent import quoting 21 | 22 | # TODO: Add version-specific checks. 
23 | # Not tested (mainly because Python 3.6 is not yet supported): 24 | # matmult operator 25 | # f-strings 26 | # type annotations 27 | # try/finally (version-specific) 28 | # try/except (version-specific) 29 | # yield/from 30 | # await 31 | # asyncfor 32 | # asyncwith 33 | # nonlocal 34 | # async def 35 | 36 | # Note: currently, these tests only cover the rejection cases. Positive cases 37 | # should normally be caught by the main tests. 38 | 39 | testglobal = 0 40 | 41 | 42 | def _assert_tangent_parse_error(func, fragment): 43 | try: 44 | fence.validate(quoting.parse_function(func), inspect.getsource(func)) 45 | assert False 46 | except fence.TangentParseError as expected: 47 | assert fragment in str(expected) 48 | 49 | 50 | def test_bytes(): 51 | if sys.version_info >= (3, 0): 52 | def f(_): 53 | return b'foo' 54 | _assert_tangent_parse_error(f, 'Byte Literals') 55 | 56 | 57 | def test_set(): 58 | 59 | def f(_): 60 | return set({1}) 61 | 62 | _assert_tangent_parse_error(f, 'Sets') 63 | 64 | 65 | def test_del(): 66 | 67 | def f(x): 68 | del x 69 | return 1 70 | 71 | _assert_tangent_parse_error(f, 'Del') 72 | 73 | 74 | def test_starred(): 75 | 76 | def f(x): 77 | return zip(*x) 78 | 79 | _assert_tangent_parse_error(f, 'Unpack') 80 | 81 | 82 | def test_uadd(): 83 | 84 | def f(x): 85 | return +x 86 | 87 | _assert_tangent_parse_error(f, 'Unary Add') 88 | 89 | 90 | def test_not(): 91 | 92 | def f(x): 93 | return not x 94 | 95 | _assert_tangent_parse_error(f, 'Not operator') 96 | 97 | 98 | def test_invert(): 99 | 100 | def f(x): 101 | return ~x 102 | 103 | _assert_tangent_parse_error(f, 'Invert') 104 | 105 | 106 | def test_floordiv(): 107 | 108 | def f(x): 109 | return x // 2 110 | 111 | _assert_tangent_parse_error(f, 'Floor Div') 112 | 113 | 114 | def test_lshift(): 115 | 116 | def f(x): 117 | return x << 1 118 | 119 | _assert_tangent_parse_error(f, 'Left Shift') 120 | 121 | 122 | def test_rshift(): 123 | 124 | def f(x): 125 | return x >> 1 126 | 127 | _assert_tangent_parse_error(f, 'Right Shift') 128 | 129 | 130 | def test_bitor(): 131 | 132 | def f(x): 133 | return x | 1 134 | 135 | _assert_tangent_parse_error(f, 'Bitwise Or') 136 | 137 | 138 | def test_bitxor(): 139 | 140 | def f(x): 141 | return x ^ 1 142 | 143 | _assert_tangent_parse_error(f, 'Bitwise Xor') 144 | 145 | 146 | def test_bitand(): 147 | 148 | def f(x): 149 | return x & 1 150 | 151 | _assert_tangent_parse_error(f, 'Bitwise And') 152 | 153 | 154 | def test_in(): 155 | 156 | def f(x): 157 | return 1 in x 158 | 159 | _assert_tangent_parse_error(f, 'In operator') 160 | 161 | 162 | def test_notin(): 163 | 164 | def f(x): 165 | return 1 not in x 166 | 167 | _assert_tangent_parse_error(f, 'Not In operator') 168 | 169 | 170 | def test_ifexp(): 171 | 172 | def f(x): 173 | return 1 if x else 2 174 | 175 | _assert_tangent_parse_error(f, 'Conditional') 176 | 177 | 178 | def test_setcomp(): 179 | 180 | def f(x): 181 | return {i for i in x} 182 | 183 | _assert_tangent_parse_error(f, 'Set Comprehensions') 184 | 185 | 186 | def test_generatorexp(): 187 | 188 | def f(x): 189 | return (i for i in x) 190 | 191 | _assert_tangent_parse_error(f, 'Generator') 192 | 193 | 194 | def test_dictcomp(): 195 | 196 | def f(x): 197 | return {i: 1 for i in x} 198 | 199 | _assert_tangent_parse_error(f, 'Dictionary Comprehensions') 200 | 201 | 202 | def test_delete(): 203 | 204 | def f(x): 205 | del x[1] 206 | return x 207 | 208 | _assert_tangent_parse_error(f, 'Delete statements') 209 | 210 | 211 | def test_import(): 212 | 213 | def f(x): 214 | import 
tangent 215 | return x 216 | 217 | _assert_tangent_parse_error(f, 'Import statements') 218 | 219 | 220 | def test_importfrom(): 221 | 222 | def f(x): 223 | from tangent import grad 224 | return x 225 | 226 | _assert_tangent_parse_error(f, 'Import/From statements') 227 | 228 | 229 | def test_alias(): 230 | 231 | def f(x): 232 | import tangent as tg 233 | return x 234 | 235 | # The checker should never reach alias nodes as long as it blocks imports. 236 | _assert_tangent_parse_error(f, 'Import statements') 237 | 238 | 239 | def test_for(): 240 | 241 | def f(x): 242 | for _ in range(2): 243 | x += 1 244 | break 245 | else: 246 | x = 0 247 | return x 248 | 249 | _assert_tangent_parse_error(f, 'Else block') 250 | 251 | 252 | def test_continue(): 253 | 254 | def f(x): 255 | for _ in range(2): 256 | continue 257 | return x 258 | 259 | _assert_tangent_parse_error(f, 'Continue') 260 | 261 | 262 | def test_lambda(): 263 | 264 | def f(_): 265 | return lambda x: x + 1 266 | 267 | _assert_tangent_parse_error(f, 'Lambda') 268 | 269 | 270 | def test_yield(): 271 | 272 | def f(x): 273 | yield x + 1 274 | 275 | _assert_tangent_parse_error(f, 'Yield') 276 | 277 | 278 | def test_global(): 279 | 280 | def f(x): 281 | global testglobal 282 | testglobal = 0 283 | return x 284 | 285 | _assert_tangent_parse_error(f, 'Global') 286 | 287 | 288 | def test_classdef(): 289 | 290 | def f(_): 291 | 292 | class Foo(object): 293 | pass 294 | 295 | return Foo 296 | 297 | _assert_tangent_parse_error(f, 'Class') 298 | 299 | 300 | if __name__ == '__main__': 301 | assert not pytest.main([__file__]) 302 | -------------------------------------------------------------------------------- /tests/test_forward_mode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Forward-mode tests. 15 | 16 | Notes 17 | ----- 18 | Arguments func, a, b, c, x, and n are automatically filled in. 19 | 20 | Pass --short for a quick run. 
21 | 22 | """ 23 | import sys 24 | 25 | import pytest 26 | import tfe_utils 27 | import utils 28 | 29 | 30 | def test_deriv_unary(func, preserve_result, a): 31 | """Test derivatives of single-argument scalar functions.""" 32 | utils.test_forward_array(func, (0,), preserve_result, a) 33 | 34 | 35 | def test_deriv_binary(func, preserve_result, a, b): 36 | """Test derivatives of two-argument scalar functions.""" 37 | utils.test_forward_array(func, (0,), preserve_result, a, b) 38 | 39 | 40 | def test_deriv_ternary(func, preserve_result, a, b, c): 41 | """Test derivatives of three-argument scalar functions.""" 42 | utils.test_forward_array(func, (0,), preserve_result, a, b, c) 43 | 44 | 45 | def test_deriv_binary_int(func, preserve_result, a, n): 46 | """Test derivatives of functions with scalar and integer input.""" 47 | utils.test_forward_array(func, (0,), preserve_result, a, n) 48 | 49 | 50 | def test_deriv_unary_tensor(func, t): 51 | """Test derivatives of functions with single tensor input.""" 52 | # TODO: remove trace test exemption when tests are consolidated. 53 | if 'trace' in func.__name__: 54 | return 55 | if any(n in func.__name__ for n in ('tfe_rsqrt',)): 56 | utils.assert_forward_not_implemented(func, (0,)) 57 | return 58 | tfe_utils.test_forward_tensor(func, (0,), t) 59 | 60 | 61 | def test_deriv_binary_tensor(func, t1, t2): 62 | """Test derivatives of functions with binary tensor inputs.""" 63 | if any(n in func.__name__ for n in ('tfe_squared_difference',)): 64 | utils.assert_forward_not_implemented(func, (0,)) 65 | return 66 | tfe_utils.test_forward_tensor(func, (0,), t1, t2) 67 | tfe_utils.test_forward_tensor(func, (1,), t1, t2) 68 | 69 | 70 | def test_deriv_image(func, timage, tkernel, conv2dstrides): 71 | """Test derivatives of image functions.""" 72 | utils.assert_forward_not_implemented(func, (0,)) 73 | 74 | 75 | if __name__ == '__main__': 76 | assert not pytest.main([__file__, '--short'] + sys.argv[1:]) 77 | -------------------------------------------------------------------------------- /tests/test_hessian_vector_products.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for Hessian-vector products on a few limited functions. 15 | 16 | HVPs are run in three configurations: 17 | - Reverse-over-reverse (autograd-style) 18 | - Forward-over-reverse (traditional AD style, most efficient) 19 | - Reverse-over-forward 20 | """ 21 | 22 | from autograd import hessian_vector_product 23 | import autograd.numpy as np 24 | import pytest 25 | import tangent 26 | import tfe_utils 27 | 28 | @pytest.fixture 29 | def tf(): 30 | try: 31 | import tensorflow as tf 32 | except ImportError: 33 | pytest.skip() 34 | 35 | 36 | # This test function broke HVPs several times 37 | # during development, so we're using it as a unit test. 
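# How each configuration is built (see _test_hvp below): the inner
# tangent.autodiff call differentiates the function under test in mode1, and
# the outer call differentiates that derivative in mode2, so
# forward-over-reverse corresponds to mode1='reverse' followed by
# mode2='forward'. A minimal sketch of that pairing, assuming a function f of
# a single array argument x and a vector v:
#
#   df = tangent.autodiff(f, mode='reverse', motion='joint', check_dims=False)
#   ddf = tangent.autodiff(df, mode='forward', motion='joint', check_dims=False)
#   hvp = ddf(x, 1, v)  # Hessian of f at x, multiplied by v
#
# The functions below are the regression cases fed through those pairings.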
38 | def f_straightline(x): 39 | a = x * x * x 40 | b = a * x**2.0 41 | return np.sum(b) 42 | 43 | 44 | def cube(a): 45 | b = a * a * a 46 | return b 47 | 48 | 49 | def f_calltree(x): 50 | a = cube(x) 51 | b = a * x**2.0 52 | return np.sum(b) 53 | 54 | 55 | def tf_straightline(x, tf): 56 | a = x * x * x 57 | b = a * x ** 2.0 58 | return tf.reduce_sum(b) 59 | 60 | 61 | def _test_hvp(func, optimized): 62 | np.random.seed(0) 63 | a = np.random.normal(scale=1, size=(300,)).astype('float32') 64 | v = a.ravel() 65 | 66 | modes = ['forward', 'reverse'] 67 | for mode1 in modes: 68 | for mode2 in modes: 69 | if mode1 == mode2 == 'forward': 70 | continue 71 | df = tangent.autodiff( 72 | func, 73 | mode=mode1, 74 | motion='joint', 75 | optimized=optimized, 76 | check_dims=False) 77 | ddf = tangent.autodiff( 78 | df, mode=mode2, motion='joint', optimized=optimized, check_dims=False) 79 | dx = ddf(a, 1, v) 80 | hvp_ag = hessian_vector_product(func) 81 | dx_ag = hvp_ag(a, v) 82 | assert np.allclose(dx, dx_ag) 83 | 84 | 85 | def _test_tf_hvp(func, optimized, tf): 86 | a = tf.random_normal(shape=(300,)) 87 | v = tf.reshape(a, shape=(-1,)) 88 | 89 | modes = ['forward', 'reverse'] 90 | for mode1 in modes: 91 | for mode2 in modes: 92 | if mode1 == mode2 == 'forward': 93 | continue 94 | df = tangent.autodiff( 95 | func, 96 | mode=mode1, 97 | motion='joint', 98 | optimized=optimized, 99 | check_dims=False) 100 | ddf = tangent.autodiff( 101 | df, mode=mode2, motion='joint', optimized=optimized, check_dims=False) 102 | dx = ddf(a, tf.constant(1.0), v) 103 | # We just ensure it computes something in this case. 104 | assert dx.shape == a.shape 105 | 106 | 107 | def test_hvp_complex_tf(optimized, tf): 108 | _test_tf_hvp(tf_straightline, optimized, tf) 109 | 110 | 111 | def test_hvp_straightline(optimized): 112 | _test_hvp(f_straightline, optimized) 113 | 114 | 115 | def test_hvp_calltree(optimized): 116 | _test_hvp(f_calltree, optimized) 117 | 118 | 119 | if __name__ == '__main__': 120 | assert not pytest.main([__file__, '--short']) 121 | -------------------------------------------------------------------------------- /tests/test_optimization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import gast 15 | import pytest 16 | 17 | from tangent import optimization 18 | from tangent import quoting 19 | 20 | 21 | def test_assignment_propagation(): 22 | def f(x): 23 | y = x 24 | z = y 25 | return z 26 | 27 | node = quoting.parse_function(f) 28 | node = optimization.assignment_propagation(node) 29 | assert len(node.body[0].body) == 2 30 | 31 | 32 | def test_dce(): 33 | def f(x): 34 | y = 2 * x 35 | return x 36 | 37 | node = quoting.parse_function(f) 38 | node = optimization.dead_code_elimination(node) 39 | assert isinstance(node.body[0].body[0], gast.Return) 40 | 41 | 42 | def test_fixed_point(): 43 | def f(x): 44 | y = g(x) 45 | z = h(y) 46 | return x 47 | 48 | node = quoting.parse_function(f) 49 | node = optimization.optimize(node) 50 | assert isinstance(node.body[0].body[0], gast.Return) 51 | 52 | 53 | def test_constant_folding(): 54 | def f(x): 55 | x = 1 * x 56 | x = 0 * x 57 | x = x * 1 58 | x = x * 0 59 | x = x * 2 60 | x = 2 * x 61 | x = 2 * 3 62 | x = 1 + x 63 | x = 0 + x 64 | x = x + 1 65 | x = x + 0 66 | x = x + 2 67 | x = 2 + x 68 | x = 2 + 3 69 | x = 1 - x 70 | x = 0 - x 71 | x = x - 1 72 | x = x - 0 73 | x = x - 2 74 | x = 2 - x 75 | x = 2 - 3 76 | x = 1 / x 77 | x = 0 / x 78 | x = x / 1 79 | x = x / 0 80 | x = x / 2 81 | x = 2 / x 82 | x = 2 / 8 83 | x = 1 ** x 84 | x = 0 ** x 85 | x = x ** 1 86 | x = x ** 0 87 | x = x ** 2 88 | x = 2 ** x 89 | x = 2 ** 3 90 | 91 | def f_opt(x): 92 | x = x 93 | x = 0 94 | x = x 95 | x = 0 96 | x = x * 2 97 | x = 2 * x 98 | x = 6 99 | x = 1 + x 100 | x = x 101 | x = x + 1 102 | x = x 103 | x = x + 2 104 | x = 2 + x 105 | x = 5 106 | x = 1 - x 107 | x = -x 108 | x = x - 1 109 | x = x 110 | x = x - 2 111 | x = 2 - x 112 | x = -1 113 | x = 1 / x 114 | x = 0 / x 115 | x = x 116 | x = x / 0 117 | x = x / 2 118 | x = 2 / x 119 | x = 0.25 120 | x = 1 121 | x = 0 122 | x = x 123 | x = 1 124 | x = x ** 2 125 | x = 2 ** x 126 | x = 8 127 | 128 | node = quoting.parse_function(f) 129 | node = optimization.constant_folding(node) 130 | node_opt = quoting.parse_function(f_opt) 131 | lines = quoting.to_source(node).strip().split('\n')[1:] 132 | lines_opt = quoting.to_source(node_opt).strip().split('\n')[1:] 133 | # In Python 2 integer division could be on, in which case... 134 | if 1 / 2 == 0: 135 | lines_opt[27] = ' x = 0' 136 | assert lines == lines_opt 137 | 138 | 139 | if __name__ == '__main__': 140 | assert not pytest.main([__file__]) 141 | -------------------------------------------------------------------------------- /tests/test_reverse_mode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Reverse-mode tests. 15 | 16 | Notes 17 | ----- 18 | Arguments func, a, b, c, x, and n are automatically filled in. 19 | 20 | Pass --short for a quick run. 
21 | 22 | """ 23 | import sys 24 | 25 | from autograd import grad as ag_grad 26 | from autograd.misc.flatten import flatten 27 | import autograd.numpy as ag_np 28 | import numpy as np 29 | import pytest 30 | 31 | import tangent 32 | from tangent.grad_util import INPUT_DERIVATIVE 33 | from tangent import quoting 34 | import tfe_utils 35 | import utils 36 | from functions import bilinear 37 | from functions import dict_saxpy 38 | from functions import inlining_contextmanager 39 | from functions import logistic_regression 40 | from functions import nested_dict 41 | from functions import rnn 42 | from functions import unpacking_args_saxpy 43 | 44 | 45 | def test_parses(func): 46 | """Test all functions parse.""" 47 | quoting.parse_function(func) 48 | 49 | 50 | def test_logistic_regression(motion, optimized): 51 | func = logistic_regression 52 | w = np.random.randn(3, 5) 53 | b = np.random.randn(5) 54 | input_ = np.random.rand(3) 55 | label = np.zeros(5) 56 | label[1] = 1 57 | 58 | func.__globals__['np'] = np 59 | df = tangent.autodiff( 60 | func, 61 | wrt=(2, 3), 62 | motion=motion, 63 | optimized=optimized, 64 | verbose=True, 65 | input_derivative=INPUT_DERIVATIVE.DefaultOne) 66 | dw, db = df(input_, label, w, b) 67 | 68 | func.__globals__['np'] = ag_np 69 | ag_dw = ag_grad(func, argnum=2)(input_, label, w, b) 70 | ag_db = ag_grad(func, argnum=3)(input_, label, w, b) 71 | assert np.allclose(ag_dw, dw) 72 | assert np.allclose(ag_db, db) 73 | 74 | 75 | def test_rnn(motion, optimized): 76 | func = rnn 77 | w = np.random.randn(2, 3) 78 | inputs = np.random.randn(3, 2) 79 | 80 | func.__globals__['np'] = np 81 | df = tangent.autodiff( 82 | func, 83 | wrt=(0, 1), 84 | motion=motion, 85 | optimized=optimized, 86 | verbose=True, 87 | input_derivative=INPUT_DERIVATIVE.DefaultOne) 88 | dinputs, dw = df(inputs, w) 89 | 90 | num_dinputs = utils.numeric_grad(func)(inputs, w) 91 | num_dw = utils.numeric_grad(lambda w, x: func(x, w))(w, inputs) 92 | assert np.allclose(num_dw, dw) 93 | assert np.allclose(num_dinputs, dinputs) 94 | 95 | 96 | def test_bilinear(optimized): 97 | func = bilinear 98 | D = 3 99 | np.random.seed(0) 100 | x = np.random.randn(1, D) 101 | h = np.random.randn(1, D) 102 | U = np.random.randn(D, D) 103 | w = np.random.randn(D, D) 104 | b = np.random.randn(1, D) 105 | 106 | func.__globals__['np'] = np 107 | df = tangent.autodiff( 108 | func, 109 | wrt=(0,), 110 | motion='joint', 111 | optimized=optimized, 112 | verbose=True, 113 | input_derivative=INPUT_DERIVATIVE.DefaultOne) 114 | dx = df(x, h, U, w, b) 115 | 116 | num_dx = utils.numeric_grad(func)(x, h, U, w, b) 117 | assert np.allclose(num_dx, dx) 118 | 119 | 120 | def test_attributes(): 121 | def f(x): 122 | return x.shape 123 | try: 124 | utils.test_reverse_array(f, 'JOINT', False, False, np.array([1.0, 2.0])) 125 | assert False 126 | except ValueError as expected: 127 | assert 'attributes are not yet supported' in str(expected) 128 | 129 | 130 | def test_grad_unary(func, motion, optimized, preserve_result, a): 131 | """Test gradients of single-argument scalar functions.""" 132 | utils.test_reverse_array(func, motion, optimized, preserve_result, a) 133 | 134 | 135 | def test_grad_binary(func, motion, optimized, preserve_result, a, b): 136 | """Test gradients of two-argument scalar functions.""" 137 | utils.test_reverse_array(func, motion, optimized, preserve_result, a, b) 138 | 139 | 140 | def test_grad_ternary(func, motion, optimized, preserve_result, a, b, c): 141 | """Test gradients of three-argument scalar functions.""" 142 | 
utils.test_reverse_array(func, motion, optimized, preserve_result, a, b, c) 143 | 144 | 145 | def test_grad_vector(func, motion, optimized, preserve_result, x): 146 | """Test gradients of vector functions.""" 147 | utils.test_reverse_array(func, motion, optimized, preserve_result, x) 148 | 149 | 150 | def test_grad_square_matrix(func, motion, optimized, preserve_result, sqm): 151 | """Test gradients of square matrix functions.""" 152 | utils.test_reverse_array(func, motion, optimized, preserve_result, sqm) 153 | 154 | 155 | def test_grad_binary_int(func, motion, optimized, preserve_result, a, n): 156 | """Test gradients of functions with scalar and integer input.""" 157 | utils.test_reverse_array(func, motion, optimized, preserve_result, a, n) 158 | 159 | 160 | def test_inlining_contextmanager(motion, optimized, a): 161 | func = inlining_contextmanager 162 | func = tangent.tangent(func) 163 | 164 | func.__globals__['np'] = np 165 | df = tangent.autodiff( 166 | func, 167 | motion=motion, 168 | optimized=optimized, 169 | verbose=True, 170 | input_derivative=INPUT_DERIVATIVE.DefaultOne) 171 | dx = df(a) 172 | 173 | func.__globals__['np'] = ag_np 174 | df_ag = ag_grad(func) 175 | df_ag(a) 176 | assert np.allclose(dx, 2.9 * a**2) 177 | 178 | 179 | def test_dict_saxpy(motion, optimized, a, b, c): 180 | func = dict_saxpy 181 | func = tangent.tangent(func) 182 | 183 | func.__globals__['np'] = np 184 | df = tangent.autodiff( 185 | func, 186 | motion=motion, 187 | optimized=optimized, 188 | verbose=True, 189 | input_derivative=INPUT_DERIVATIVE.DefaultOne) 190 | dx = df(dict(a=a, b=b, c=c)) 191 | 192 | df_num = utils.numeric_grad(func) 193 | dx_num = df_num(dict(a=float(a), b=float(b), c=float(c))) 194 | flat_dx, _ = flatten(dx) 195 | flat_dx_num, _ = flatten(dx_num) 196 | assert np.allclose(flat_dx, flat_dx_num) 197 | 198 | 199 | def test_unpacking_args_saxpy(motion, optimized, a, b, c): 200 | func = unpacking_args_saxpy 201 | func = tangent.tangent(func) 202 | 203 | func.__globals__['np'] = np 204 | df = tangent.autodiff( 205 | func, 206 | motion=motion, 207 | optimized=optimized, 208 | verbose=True, 209 | input_derivative=INPUT_DERIVATIVE.DefaultOne) 210 | dx = df((a, b, c)) 211 | 212 | df_num = utils.numeric_grad(func) 213 | dx_num = df_num((a, b, c)) 214 | assert np.allclose(dx, dx_num) 215 | 216 | 217 | def test_nested_dict(motion, optimized): 218 | p = dict(i=dict(j=3.0, k=4.0)) 219 | func = nested_dict 220 | df = tangent.autodiff( 221 | func, 222 | motion=motion, 223 | optimized=optimized, 224 | verbose=True, 225 | input_derivative=INPUT_DERIVATIVE.DefaultOne) 226 | dx = df(p) 227 | 228 | df_ag = ag_grad(func) 229 | dx_ag = df_ag(p) 230 | for k in p['i']: 231 | assert np.allclose(dx['i'][k], dx_ag['i'][k]) 232 | 233 | 234 | def test_grad_unary_tensor(func, motion, optimized, preserve_result, t): 235 | """Test gradients of functions with single tensor input.""" 236 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0,), t) 237 | 238 | 239 | def test_grad_unary_reduction(func, motion, optimized, preserve_result, 240 | timage, boolean): 241 | """Test gradients of reduction functions.""" 242 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0,), 243 | timage, boolean) 244 | 245 | 246 | def test_grad_binary_tensor(func, motion, optimized, preserve_result, t1, t2): 247 | """Test gradients of functions with binary tensor inputs.""" 248 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0, 1), 249 | t1, t2) 250 | 251 | 252 | def 
test_grad_matmul(func, motion, optimized, preserve_result, mat1, mat2, 253 | boolean1, boolean2): 254 | """Test gradients of functions with binary matrix inputs.""" 255 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0, 1), 256 | mat1, mat2, boolean1, boolean2) 257 | 258 | 259 | def test_grad_matmul_higherdim(func, motion, optimized, preserve_result, 260 | timage1, timage2, boolean1, boolean2): 261 | """Test gradients of functions with binary matrix inputs.""" 262 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0, 1), 263 | timage1, timage2, boolean1, boolean2) 264 | 265 | 266 | def test_grad_tensor_broadcast(func, motion, optimized, preserve_result, s, 267 | t): 268 | """Test gradients of functions with binary tensor inputs.""" 269 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0, 1), 270 | s, t) 271 | 272 | 273 | def test_grad_image(func, motion, optimized, preserve_result, timage, tkernel, 274 | conv2dstrides): 275 | """Test gradients of image functions.""" 276 | # TODO: Upgrade utils.py to allow simultaneous testing of uneven args. 277 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0,), 278 | timage, tkernel, conv2dstrides) 279 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (1,), 280 | timage, tkernel, conv2dstrides) 281 | 282 | 283 | def test_grad_image_pooling(func, motion, optimized, preserve_result, timage, 284 | pool2dsizes, conv2dstrides): 285 | tfe_utils.test_rev_tensor(func, motion, optimized, preserve_result, (0,), 286 | timage, pool2dsizes, conv2dstrides) 287 | 288 | 289 | if __name__ == '__main__': 290 | assert not pytest.main([__file__, '--short'] + sys.argv[1:]) 291 | -------------------------------------------------------------------------------- /tests/test_reverse_over_reverse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for reverse-over-reverse automatic differentiation. 15 | 16 | Notes 17 | ----- 18 | Arguments func, a, b, c, x, and n are automatically filled in. 19 | 20 | Pass --short for a quick run. 
21 | 22 | """ 23 | from autograd import grad as ag_grad 24 | import autograd.numpy as ag_np 25 | import numpy as np 26 | import pytest 27 | import tangent 28 | import utils 29 | 30 | 31 | def _test_gradgrad_array(func, optimized, *args): 32 | """Test gradients of functions with NumPy-compatible signatures.""" 33 | 34 | def tangent_func(): 35 | func.__globals__['np'] = np 36 | df = tangent.grad(func, optimized=optimized, verbose=True) 37 | ddf = tangent.grad(df, optimized=optimized, verbose=True) 38 | return ddf(*args) 39 | 40 | def reference_func(): 41 | func.__globals__['np'] = ag_np 42 | return ag_grad(ag_grad(func))(*args) 43 | 44 | def backup_reference_func(): 45 | return utils.numeric_grad(utils.numeric_grad(func))(*args) 46 | 47 | utils.assert_result_matches_reference( 48 | tangent_func, reference_func, backup_reference_func, 49 | tolerance=1e-2) # extra loose bounds for 2nd order grad 50 | 51 | 52 | def test_reverse_over_reverse_unary(func, a, optimized): 53 | _test_gradgrad_array(func, optimized, a) 54 | 55 | 56 | def test_reverse_over_reverse_binary(func, a, b, optimized): 57 | _test_gradgrad_array(func, optimized, a, b) 58 | 59 | 60 | def test_reverse_over_reverse_ternary(func, optimized, a, b, c): 61 | _test_gradgrad_array(func, optimized, a, b, c) 62 | 63 | 64 | if __name__ == '__main__': 65 | assert not pytest.main([__file__, '--short']) 66 | -------------------------------------------------------------------------------- /tests/test_template.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import gast 15 | import pytest 16 | 17 | from tangent import compile as compile_ 18 | from tangent import quoting 19 | from tangent import template 20 | 21 | 22 | def _wrap(body): 23 | """Take a list of statements and wrap them in a function to compile.""" 24 | def f(): 25 | pass 26 | tree = quoting.parse_function(f) 27 | tree.body[0].body = body 28 | return tree 29 | 30 | 31 | def test_variable_replace(): 32 | def f(x): 33 | x = 2 34 | return x 35 | 36 | body = template.replace(f, x=gast.Name(id='y', ctx=None, annotation=None)) 37 | assert body[0].targets[0].id == 'y' 38 | assert isinstance(body[0].targets[0].ctx, gast.Store) 39 | assert isinstance(body[1].value.ctx, gast.Load) 40 | compile_.compile_function(_wrap(body)) 41 | 42 | 43 | def test_statement_replace(): 44 | def f(body): 45 | body 46 | 47 | body = [gast.Expr(value=gast.Name(id=var, ctx=gast.Load(), annotation=None)) 48 | for var in 'xy'] 49 | new_body = template.replace(f, body=body) 50 | assert len(new_body) == 2 51 | assert isinstance(new_body[0], gast.Expr) 52 | compile_.compile_function(_wrap(new_body)) 53 | 54 | 55 | def test_function_replace(): 56 | def f(f, args): 57 | def f(args): 58 | pass 59 | body = template.replace( 60 | f, f='g', args=[gast.Name(id=arg, ctx=None, annotation=None) 61 | for arg in 'ab']) 62 | assert isinstance(body[0], gast.FunctionDef) 63 | assert body[0].name == 'g' 64 | assert len(body[0].args.args) == 2 65 | assert isinstance(body[0].args.args[0].ctx, gast.Param) 66 | assert body[0].args.args[1].id == 'b' 67 | compile_.compile_function(_wrap(body)) 68 | 69 | 70 | def test_partial_gradient_replace(): 71 | def f(x, y): 72 | d[x] = d[y] 73 | 74 | tree = quoting.parse_function(f) 75 | transformer = template.ReplaceGradTransformer(template.Replace.PARTIAL) 76 | new_tree = transformer.visit(tree) 77 | assert isinstance(new_tree.body[0].body[0].targets[0], gast.Name) 78 | assert new_tree.body[0].body[0].targets[0].id == '_bx' 79 | assert new_tree.body[0].body[0].value.id == 'by' 80 | compile_.compile_function(new_tree) 81 | 82 | 83 | def test_full_gradient_replace(): 84 | def f(x, y): 85 | d[x] = d[y] 86 | 87 | tree = quoting.parse_function(f) 88 | transformer = template.ReplaceGradTransformer(template.Replace.FULL) 89 | new_tree = transformer.visit(tree) 90 | assert isinstance(new_tree.body[0].body[0].targets[0], gast.Name) 91 | assert new_tree.body[0].body[0].targets[0].id == 'bx' 92 | assert new_tree.body[0].body[0].value.id == 'by' 93 | compile_.compile_function(new_tree) 94 | 95 | 96 | def test_node_replace(): 97 | node = template.replace(quoting.quote("a = b"), a="y", b="x * 2") 98 | assert quoting.unquote(node) == "y = x * 2" 99 | 100 | 101 | def test_string_replace(): 102 | node = template.replace("a = b", a="y", b="x * 2") 103 | assert quoting.unquote(node) == "y = x * 2" 104 | 105 | 106 | if __name__ == '__main__': 107 | assert not pytest.main([__file__]) 108 | -------------------------------------------------------------------------------- /tests/test_transformers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pytest 15 | 16 | from tangent import quoting 17 | from tangent import transformers 18 | 19 | 20 | def test_insert(): 21 | def f(x): 22 | y = x 23 | return y 24 | node = quoting.parse_function(f) 25 | 26 | class Prepend(transformers.TreeTransformer): 27 | def visit_Assign(self, node): 28 | # If the target is y, then prepend this statement 29 | # NOTE Without this test, we'd have an infinite loop 30 | if node.targets[0].id == 'y': 31 | statement = quoting.quote("x = 2 * x") 32 | self.prepend(statement) 33 | return node 34 | 35 | Prepend().visit(node) 36 | assert quoting.unquote(node).split('\n')[1].strip() == "x = 2 * x" 37 | 38 | 39 | def test_insert_block(): 40 | def f(x): 41 | while True: 42 | y = x 43 | z = y 44 | return y 45 | node = quoting.parse_function(f) 46 | 47 | class PrependBlock(transformers.TreeTransformer): 48 | def visit_Assign(self, node): 49 | # If the target is y, then prepend this statement 50 | # NOTE Without this test, we'd have an infinite loop 51 | if node.targets[0].id == 'z': 52 | statement = quoting.quote("x = 2 * x") 53 | self.prepend_block(statement) 54 | return node 55 | 56 | PrependBlock().visit(node) 57 | assert quoting.unquote(node).split('\n')[2].strip() == "x = 2 * x" 58 | 59 | 60 | def test_insert_top(): 61 | def f(x): 62 | while True: 63 | y = x 64 | z = y 65 | return y 66 | node = quoting.parse_function(f) 67 | 68 | class InsertTop(transformers.TreeTransformer): 69 | def visit_Assign(self, node): 70 | # If the target is y, then prepend this statement 71 | # NOTE Without this test, we'd have an infinite loop 72 | if node.targets[0].id == 'z': 73 | statement = quoting.quote("x = 2 * x") 74 | self.insert_top(statement) 75 | return node 76 | 77 | InsertTop().visit(node) 78 | assert quoting.unquote(node).split('\n')[1].strip() == "x = 2 * x" 79 | 80 | 81 | def test_remove(): 82 | def f(x): 83 | while True: 84 | y = x 85 | z = y 86 | return y 87 | node = quoting.parse_function(f) 88 | 89 | class InsertTop(transformers.TreeTransformer): 90 | def visit_Assign(self, node): 91 | # If the target is y, then prepend this statement 92 | # NOTE Without this test, we'd have an infinite loop 93 | if node.targets[0].id == 'z': 94 | self.remove(node) 95 | return node 96 | 97 | InsertTop().visit(node) 98 | assert quoting.unquote(node).split('\n')[3].strip() == "return y" 99 | 100 | 101 | if __name__ == '__main__': 102 | assert not pytest.main([__file__]) 103 | -------------------------------------------------------------------------------- /tests/tfe_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """TFE-specific test utils.""" 15 | import numpy as np 16 | 17 | import pytest 18 | from tangent.grad_util import autodiff, jvp 19 | import utils 20 | 21 | try: 22 | import tensorflow as tf 23 | from tensorflow.contrib.eager.python import tfe 24 | except ImportError: 25 | tf = None 26 | tfe = None 27 | else: 28 | tfe.enable_eager_execution() 29 | 30 | 31 | def register_parametrizations(metafunc, short): 32 | """Create additional parametrizations required for TF tests.""" 33 | 34 | for arg in ['t', 't1', 't2']: 35 | # Note: care must be exercised when sharing tensor objects. Although 36 | # immutable, references to the same value will be interpreted as the same 37 | # variable, with unexpected side effects. 38 | if tf: 39 | vectors = [ 40 | np.random.randn(i) 41 | for i in ( 42 | (3,) if short else (3, 5, 10))] 43 | tensors = [tf.constant(v, dtype=tf.float32) for v in vectors] 44 | else: 45 | tensors = [pytest.mark.skip(None, reason='tensorflow not present')(None)] 46 | if arg in metafunc.fixturenames: 47 | metafunc.parametrize(arg, tensors) 48 | 49 | for arg in ['mat1', 'mat2']: 50 | if tf: 51 | matrices = [ 52 | np.random.randn(*i) 53 | for i in ( 54 | ((3, 3),) if short else ( 55 | (1, 1), 56 | (3, 3), 57 | (5, 5)))] 58 | tensors = [tf.constant(m, dtype=tf.float32) for m in matrices] 59 | else: 60 | tensors = [pytest.mark.skip(None, reason='tensorflow not present')(None)] 61 | if arg in metafunc.fixturenames: 62 | metafunc.parametrize(arg, tensors) 63 | 64 | if 's' in metafunc.fixturenames: 65 | if tf: 66 | if short: 67 | scalars = [tf.constant(1.0)] 68 | else: 69 | scalars = [tf.constant(c) for c in (0.0, 1.0, 2.0)] 70 | else: 71 | scalars = [pytest.mark.skip(reason='tensorflow not present')(None)] 72 | metafunc.parametrize('s', scalars) 73 | 74 | for arg in ['timage', 'timage1', 'timage2']: 75 | if arg in metafunc.fixturenames: 76 | if tf: 77 | images = [ 78 | np.random.randn(*i) 79 | for i in ( 80 | ((2, 3, 3, 3),) if short else ( 81 | (2, 1, 1, 3), 82 | (2, 3, 3, 3), 83 | (2, 5, 5, 3), 84 | )) 85 | ] 86 | timages = [tf.constant(v, dtype=tf.float32) for v in images] 87 | else: 88 | timages = [pytest.mark.skip(reason='tensorflow not present')(None)] 89 | metafunc.parametrize(arg, timages) 90 | 91 | if 'tkernel' in metafunc.fixturenames: 92 | if tf: 93 | kernels = [ 94 | np.random.randn(*i) 95 | for i in ( 96 | ((3, 3, 3, 1),) if short else ( 97 | (3, 3, 3, 1), 98 | (3, 3, 3, 2), 99 | (5, 5, 3, 3), 100 | )) 101 | ] 102 | tkernels = [tf.constant(v, dtype=tf.float32) for v in kernels] 103 | else: 104 | tkernels = [pytest.mark.skip(reason='tensorflow not present')(None)] 105 | metafunc.parametrize('tkernel', tkernels) 106 | 107 | if 'conv2dstrides' in metafunc.fixturenames: 108 | strides = [(1, 2, 2, 1)] if short else [ 109 | (1, 1, 1, 1), 110 | (1, 2, 2, 1), 111 | (1, 2, 2, 2), 112 | ] 113 | metafunc.parametrize('conv2dstrides', strides) 114 | 115 | if 'pool2dsizes' in metafunc.fixturenames: 116 | sizes = [(1, 2, 2, 1)] if short else [ 117 | (1, 1, 1, 1), 118 | (1, 2, 2, 1), 119 | (1, 3, 3, 1), 120 | ] 121 | metafunc.parametrize('pool2dsizes', 
sizes) 122 | 123 | 124 | def tensors_to_numpy(tensors): 125 | if isinstance(tensors, (tuple, list)): 126 | return tuple(tensors_to_numpy(t) for t in tensors) 127 | if isinstance(tensors, tf.Tensor): 128 | return tensors.numpy() 129 | raise ValueError('Don\'t know how to handle %s' % type(tensors)) 130 | 131 | 132 | def as_numpy_sig(func): 133 | """Wrap a TF Eager function into a signature that uses NumPy arrays.""" 134 | def wrapped(*args): 135 | np_args = [tf.constant(a) if isinstance(a, np.ndarray) else a for a in args] 136 | return tensors_to_numpy(func(*np_args)) 137 | return wrapped 138 | 139 | 140 | def test_forward_tensor(func, wrt, *args): 141 | """Test gradients of functions with TFE signatures.""" 142 | 143 | def tangent_func(): 144 | df = jvp(func, wrt=wrt, optimized=True, verbose=True) 145 | args_ = args + tuple(tf.ones_like(args[i]) for i in wrt) # seed gradient 146 | return tensors_to_numpy(df(*args_)) 147 | 148 | def reference_func(): 149 | return tensors_to_numpy(tfe.gradients_function(func, params=wrt)(*args)) 150 | 151 | def backup_reference_func(): 152 | func_ = as_numpy_sig(func) 153 | args_ = tensors_to_numpy(args) 154 | return utils.numeric_grad(utils.numeric_grad(func_))(*args_) 155 | 156 | # TODO: Should results really be that far off? 157 | utils.assert_result_matches_reference( 158 | tangent_func, reference_func, backup_reference_func, 159 | tolerance=1e-4) 160 | 161 | 162 | def test_gradgrad_tensor(func, optimized, *args): 163 | """Test gradients of functions with TFE signatures.""" 164 | 165 | def tangent_func(): 166 | df = tangent.autodiff(func, motion='joint', optimized=optimized, verbose=True) 167 | ddf = tangent.autodiff(df, motion='joint', optimized=optimized, verbose=True) 168 | dxx = ddf(*args) 169 | return tuple(t.numpy() for t in dxx) 170 | 171 | def reference_func(): 172 | dxx = tfe.gradients_function(tfe.gradients_function(func))(*args) 173 | return tensors_to_numpy(tuple(t.numpy() for t in dxx)) 174 | 175 | def backup_reference_func(): 176 | func_ = as_numpy_sig(func) 177 | args_ = tensors_to_numpy(args) 178 | return utils.numeric_grad(utils.numeric_grad(func_))(*args_) 179 | 180 | utils.assert_result_matches_reference( 181 | tangent_func, reference_func, backup_reference_func, 182 | tolerance=1e-2) # extra loose bounds for 2nd order grad 183 | 184 | 185 | def test_rev_tensor(func, motion, optimized, preserve_result, wrt, *args): 186 | """Test gradients of functions with TFE signatures.""" 187 | 188 | def tangent_func(): 189 | y = func(*args) 190 | if isinstance(y, (tuple, list)): 191 | init_grad = tuple(tf.ones_like(t) for t in y) 192 | else: 193 | init_grad = tf.ones_like(y) 194 | df = autodiff( 195 | func, 196 | motion=motion, 197 | optimized=optimized, 198 | preserve_result=preserve_result, 199 | wrt=wrt, 200 | verbose=True) 201 | if motion == 'joint': 202 | # TODO: This won't work if func has default args unspecified. 
203 | dx = df(*args + (init_grad,)) 204 | else: 205 | dx = df(*args, init_grad=init_grad) 206 | return tensors_to_numpy(dx) 207 | 208 | def reference_func(): 209 | gradval = tensors_to_numpy(tfe.gradients_function(func, params=wrt)(*args)) 210 | if preserve_result: 211 | val = tensors_to_numpy(func(*args)) 212 | if isinstance(gradval, (tuple)): 213 | return gradval + (val,) 214 | return gradval, val 215 | else: 216 | return gradval 217 | 218 | def backup_reference_func(): 219 | func_ = as_numpy_sig(func) 220 | args_ = tensors_to_numpy(args) 221 | gradval = utils.numeric_grad(utils.numeric_grad(func_))(*args_) 222 | if preserve_result: 223 | val = tensors_to_numpy(func(*args)) 224 | return gradval, val 225 | else: 226 | return gradval 227 | 228 | utils.assert_result_matches_reference( 229 | tangent_func, reference_func, backup_reference_func, 230 | # Some ops like tf.divide diverge significantly due to what looks like 231 | # numerical instability. 232 | tolerance=1e-5) 233 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common testing utilities.""" 15 | from copy import deepcopy 16 | 17 | from autograd import grad as ag_grad 18 | from autograd import value_and_grad as ag_value_and_grad 19 | from autograd.misc.flatten import flatten 20 | import autograd.numpy as ag_np 21 | import numpy as np 22 | import tangent 23 | 24 | # Autograd's NumPy implementation may be missing the definition for _NoValue. 25 | if not hasattr(ag_np, '_NoValue'): 26 | ag_np._NoValue = np._NoValue # pylint: disable=protected-access 27 | 28 | 29 | def assert_forward_not_implemented(func, wrt): 30 | try: 31 | tangent.autodiff(func, mode='forward', preserve_result=False, wrt=wrt) 32 | assert False, 'Remove this when implementing.' 33 | except NotImplementedError: 34 | pass 35 | 36 | 37 | def _assert_allclose(a, b, tol=1e-5): 38 | if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)): 39 | for ia, ib in zip(a, b): 40 | _assert_allclose(ia, ib, tol) 41 | else: 42 | try: 43 | a = np.nan_to_num(a) 44 | b = np.nan_to_num(b) 45 | assert np.allclose(a, b, tol), ('Expected: %s\nGot: %s' % (b, a)) 46 | except TypeError: 47 | raise TypeError('Could not compare values %s and %s' % (a, b)) 48 | 49 | 50 | def assert_result_matches_reference( 51 | tangent_func, 52 | reference_func, 53 | backup_reference_func, 54 | tolerance=1e-7): 55 | """Test Tangent functionality against reference implementation. 56 | 57 | Args: 58 | tangent_func: Returns the Tangent derivative. 59 | reference_func: Returns the derivative calculated by the reference 60 | implementation. 61 | backup_reference_func: Returns the derivative calculated by a catch-all 62 | implementation, should the reference be unavailable. 63 | tolerance: Absolute tolerance override for FP comparisons. 
64 | """ 65 | tangent_value = tangent_func() 66 | try: 67 | reference_value = reference_func() 68 | except (ImportError, TypeError) as e: 69 | if __debug__: 70 | print('WARNING: Reference function call failed. The test will revert to ' 71 | 'the backup reference.\nReason for failure: %s' % e) 72 | # TODO: Try to narrow the exception handler. 73 | reference_value = backup_reference_func() 74 | _assert_allclose(tangent_value, reference_value, tolerance) 75 | 76 | 77 | def numeric_grad(func, eps=1e-6): 78 | """Generate a finite-differences gradient of function `f`. 79 | 80 | def f(x, *args): 81 | ... 82 | return scalar 83 | 84 | g = numeric_grad(f, eps=1e-4) 85 | finite_difference_grad_of_x = g(x, *args) 86 | 87 | Adapted from github.com/hips/autograd 88 | """ 89 | def g(x, *args): 90 | fd_grad, unflatten_fd = flatten(tangent.init_grad(x)) 91 | y = func(deepcopy(x), *args) 92 | seed = np.ones_like(y) 93 | for d in range(fd_grad.size): 94 | x_flat, unflatten_x = flatten(deepcopy(x)) 95 | x_flat[d] += eps / 2 96 | a = np.array(func(unflatten_x(x_flat), *args)) 97 | x_flat, unflatten_x = flatten(deepcopy(x)) 98 | x_flat[d] -= eps / 2 99 | b = np.array(func(unflatten_x(x_flat), *args)) 100 | fd_grad[d] = np.dot((a - b) / eps, seed) 101 | return unflatten_fd(fd_grad) 102 | 103 | return g 104 | 105 | 106 | def test_reverse_array(func, motion, optimized, preserve_result, *args): 107 | """Test gradients of functions with NumPy-compatible signatures.""" 108 | 109 | def tangent_func(): 110 | y = func(*deepcopy(args)) 111 | if np.array(y).size > 1: 112 | init_grad = np.ones_like(y) 113 | else: 114 | init_grad = 1 115 | func.__globals__['np'] = np 116 | df = tangent.autodiff( 117 | func, 118 | mode='reverse', 119 | motion=motion, 120 | optimized=optimized, 121 | preserve_result=preserve_result, 122 | verbose=1) 123 | if motion == 'joint': 124 | return df(*deepcopy(args) + (init_grad,)) 125 | return df(*deepcopy(args), init_grad=init_grad) 126 | 127 | def reference_func(): 128 | func.__globals__['np'] = ag_np 129 | if preserve_result: 130 | val, gradval = ag_value_and_grad(func)(*deepcopy(args)) 131 | return gradval, val 132 | else: 133 | return ag_grad(func)(*deepcopy(args)) 134 | 135 | def backup_reference_func(): 136 | func.__globals__['np'] = np 137 | df_num = numeric_grad(func) 138 | gradval = df_num(*deepcopy(args)) 139 | if preserve_result: 140 | val = func(*deepcopy(args)) 141 | return gradval, val 142 | else: 143 | return gradval 144 | 145 | assert_result_matches_reference(tangent_func, reference_func, 146 | backup_reference_func) 147 | 148 | 149 | def test_forward_array(func, wrt, preserve_result, *args): 150 | """Test derivatives of functions with NumPy-compatible signatures.""" 151 | 152 | def tangent_func(): 153 | func.__globals__['np'] = np 154 | df = tangent.autodiff( 155 | func, 156 | mode='forward', 157 | preserve_result=preserve_result, 158 | wrt=wrt, 159 | optimized=True, 160 | verbose=1) 161 | args_ = args + (1.0,) # seed gradient 162 | return df(*deepcopy(args_)) 163 | 164 | def reference_func(): 165 | func.__globals__['np'] = ag_np 166 | if preserve_result: 167 | # Note: ag_value_and_grad returns (val, grad) but we need (grad, val) 168 | val, gradval = ag_value_and_grad(func)(*deepcopy(args)) 169 | return gradval, val 170 | else: 171 | return ag_grad(func)(*deepcopy(args)) 172 | 173 | def backup_reference_func(): 174 | func.__globals__['np'] = np 175 | df_num = numeric_grad(func) 176 | gradval = df_num(*deepcopy(args)) 177 | if preserve_result: 178 | val = func(*deepcopy(args)) 179 | 
return gradval, val 180 | else: 181 | return gradval 182 | 183 | assert_result_matches_reference(tangent_func, reference_func, 184 | backup_reference_func) 185 | --------------------------------------------------------------------------------
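A minimal usage sketch (not a file from the repository) showing how the helpers in tests/utils.py are typically driven outside of the parametrized tests. The functions square_sum and square below are hypothetical examples; the sketch assumes numpy, autograd and tangent are installed and that it is run from the tests/ directory so that `utils` is importable.

import numpy as np

import utils


def square_sum(x):
  # d/dx sum(x ** 2) = 2 * x; the harness compares tangent's reverse-mode
  # gradient against autograd, with utils.numeric_grad as a finite-difference
  # fallback if the reference call fails.
  return np.sum(x * x)


def square(a):
  # d/da (a * a) = 2 * a, used for the forward-mode (JVP) check below.
  return a * a


# Reverse mode, joint motion, no optimization passes, gradient only
# (preserve_result=False).
utils.test_reverse_array(square_sum, 'joint', False, False, np.array([1.0, 2.0]))

# Forward mode with respect to the first (and only) argument.
utils.test_forward_array(square, (0,), False, 3.0)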