├── .flake8 ├── .github └── workflows │ └── wheels.yml ├── .gitignore ├── .lintrunner.toml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── notebooks ├── README.md └── colab │ ├── aot_autograd_optimizations.ipynb │ ├── jacobians_hessians_colab.ipynb │ └── per_sample_grads_colab.ipynb ├── packaging └── windows │ └── internal │ ├── cuda_install.bat │ └── driver_update.bat ├── pull_request_template.md ├── setup.cfg ├── setup.py └── version.txt /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,C,E,F,P,T4,W,B9 3 | max-line-length = 120 4 | # C408 ignored because we like the dict keyword argument syntax 5 | # E501 is not flexible enough, we're using B950 instead 6 | ignore = 7 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, 8 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 9 | # to line this up with executable bit 10 | EXE001, 11 | # these ignores are from flake8-bugbear; please fix! 12 | B007,B008, 13 | # these ignores are from flake8-comprehensions; please fix! 14 | C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 15 | exclude = 16 | ./.git, 17 | ./benchmarks, 18 | ./docs, 19 | ./examples, 20 | ./notebooks 21 | -------------------------------------------------------------------------------- /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Wheels 2 | on: 3 | pull_request: 4 | types: [opened, synchronize, reopened] 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | 11 | build-wheel: 12 | runs-on: ubuntu-22.04 13 | steps: 14 | - name: Setup Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.9 18 | architecture: x64 19 | - name: Checkout functorch 20 | uses: actions/checkout@v2 21 | - name: Install PyTorch Nightly 22 | run: | 23 | python3 -mpip install --pre torch>=1.13.0.dev -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html 24 | - name: Build wheel 25 | run: | 26 | python3 -mpip install wheel 27 | python3 setup.py bdist_wheel 28 | - name: Upload wheel as GHA artifact 29 | uses: actions/upload-artifact@v2 30 | with: 31 | name: functorch.whl 32 | path: dist/*.whl 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | functorch.egg-info/ 4 | *__pycache__* 5 | functorch/version.py 6 | functorch/_C.so 7 | .gdbinit 8 | t.py 9 | .vscode/ 10 | ccache.sh 11 | docs/build 12 | docs/src 13 | docs/source/generated 14 | .DS_Store 15 | op_analysis/*.txt 16 | 17 | # Editor temporaries 18 | *.swn 19 | *.swo 20 | *.swp 21 | *.swm 22 | -------------------------------------------------------------------------------- /.lintrunner.toml: -------------------------------------------------------------------------------- 1 | [[linter]] 2 | code = 'FLAKE8' 3 | include_patterns = ['**/*.py'] 4 | exclude_patterns = [ 5 | '.git/**', 6 | 'benchmarks/**', 7 | 'docs/**', 8 | 'examples/**', 9 | 'notebooks/**', 10 | ] 11 | command = [ 12 | 'python3', 13 | 'tools/lint/flake8_linter.py', 14 | '--', 15 | '@{{PATHSFILE}}' 16 | ] 17 | init_command = [ 18 | 'python3', 19 | 'tools/lint/pip_init.py', 20 | '--dry-run={{DRYRUN}}', 21 | 'flake8==3.8.2', 22 | 'flake8-bugbear==20.1.4', 23 | 'flake8-comprehensions==3.3.0', 24 | 'flake8-executable==2.0.4', 25 | 'flake8-pyi==20.5.0', 26 | 'mccabe==0.6.1', 27 | 
'pycodestyle==2.6.0', 28 | 'pyflakes==2.2.0', 29 | ] 30 | 31 | # [[linter]] 32 | # code = 'BLACK' 33 | # include_patterns = [ 34 | # '**/*.py', 35 | # ] 36 | # command = [ 37 | # 'python3', 38 | # 'tools/lint/black_linter.py', 39 | # '--', 40 | # '@{{PATHSFILE}}' 41 | # ] 42 | # init_command = [ 43 | # 'python3', 44 | # 'tools/lint/pip_init.py', 45 | # '--dry-run={{DRYRUN}}', 46 | # 'black==22.3.0', 47 | # ] 48 | # is_formatter = true 49 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. 
The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | Feedback on our APIs, as well as finding bugs, would be very helpful. 3 | 4 | Please feel free to chat us up on the PyTorch Slack, or open an issue 5 | at https://github.com/pytorch/functorch if you're interested in 6 | contributing. 7 | 8 | To contribute a change to functorch, please make sure you are submitting a 9 | Pull Request to the functorch folder in https://github.com/pytorch/pytorch 10 | repository. The source of truth for functorch has moved there from 11 | https://github.com/pytorch/functorch ; the code in the pytorch/functorch 12 | repository is read-only. 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, 7 | this list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software 15 | without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # functorch 2 | 3 | [**Why functorch?**](#why-composable-function-transforms) 4 | | [**Install guide**](#install) 5 | | [**Transformations**](#what-are-the-transforms) 6 | | [**Documentation**](#documentation) 7 | | [**Future Plans**](#future-plans) 8 | 9 | **This library is currently under heavy development - if you have suggestions 10 | on the API or use-cases you'd like to be covered, please open an github issue 11 | or reach out. We'd love to hear about how you're using the library.** 12 | 13 | `functorch` is [JAX-like](https://github.com/google/jax) composable function 14 | transforms for PyTorch. 15 | 16 | It aims to provide composable `vmap` and `grad` transforms that work with 17 | PyTorch modules and PyTorch autograd with good eager-mode performance. 18 | 19 | In addition, there is experimental functionality to trace through these 20 | transformations using FX in order to capture the results of these transforms 21 | ahead of time. This would allow us to compile the results of vmap or grad 22 | to improve performance. 23 | 24 | ## Why composable function transforms? 25 | 26 | There are a number of use cases that are tricky to do in 27 | PyTorch today: 28 | - computing per-sample-gradients (or other per-sample quantities) 29 | - running ensembles of models on a single machine 30 | - efficiently batching together tasks in the inner-loop of MAML 31 | - efficiently computing Jacobians and Hessians 32 | - efficiently computing batched Jacobians and Hessians 33 | 34 | Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above 35 | without designing a separate subsystem for each. This idea of composable function 36 | transforms comes from the [JAX framework](https://github.com/google/jax). 37 | 38 | ## Install 39 | 40 | There are two ways to install functorch: 41 | 1. functorch from source 42 | 2. functorch beta (compatible with recent PyTorch releases) 43 | 44 | We recommend trying out the functorch beta first. 45 | 46 | ### Installing functorch from source 47 | 48 |
<details><summary>Click to expand</summary> 49 | <p>
50 | 51 | #### Using Colab 52 | 53 | Follow the instructions [in this Colab notebook](https://colab.research.google.com/drive/1CrLkqIrydBYP_svnF89UUO-aQEqNPE8x?usp=sharing) 54 | 55 | #### Locally 56 | 57 | As of 9/21/2022, `functorch` comes installed alongside a nightly PyTorch binary. 58 | Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/ 59 | for instructions. 60 | 61 | Once you've done that, run a quick sanity check in Python: 62 | ```py 63 | import torch 64 | from functorch import vmap 65 | x = torch.randn(3) 66 | y = vmap(torch.sin)(x) 67 | assert torch.allclose(y, x.sin()) 68 | ``` 69 | 70 | #### functorch development setup 71 | 72 | As of 9/21/2022, `functorch` comes installed alongside PyTorch and is in the 73 | PyTorch source tree. Please install 74 | [PyTorch from source](https://github.com/pytorch/pytorch#from-source), then, 75 | you will be able to `import functorch`. 76 | 77 | Try to run some tests to make sure all is OK: 78 | ```bash 79 | pytest test/test_vmap.py -v 80 | pytest test/test_eager_transforms.py -v 81 | ``` 82 | 83 | AOTAutograd has some additional optional requirements. You can install them via: 84 | ```bash 85 | pip install networkx 86 | ``` 87 | 88 | To run functorch tests, please install our test dependencies (`expecttest`, `pyyaml`). 89 | 90 | 91 |

</p> 92 | </details>
93 | 94 | ### Installing functorch beta (compatible with recent PyTorch releases) 95 | 96 |
<details><summary>Click to expand</summary> 97 | <p>
98 | 99 | #### Using Colab 100 | 101 | Follow the instructions [here](https://colab.research.google.com/drive/1GNfb01W_xf8JRu78ZKoNnLqiwcrJrbYG#scrollTo=HJ1srOGeNCGA) 102 | 103 | #### pip 104 | 105 | Prerequisite: [Install PyTorch](https://pytorch.org/get-started/locally/) 106 | 107 | 108 | ```bash 109 | pip install functorch 110 | ``` 111 | 112 | Finally, run a quick sanity check in python: 113 | ```py 114 | import torch 115 | from functorch import vmap 116 | x = torch.randn(3) 117 | y = vmap(torch.sin)(x) 118 | assert torch.allclose(y, x.sin()) 119 | ``` 120 | 121 |

</p> 122 | </details>
123 | 124 | ## What are the transforms? 125 | 126 | Right now, we support the following transforms: 127 | - `grad`, `vjp`, `jvp`, 128 | - `jacrev`, `jacfwd`, `hessian` 129 | - `vmap` 130 | 131 | Furthermore, we have some utilities for working with PyTorch modules. 132 | - `make_functional(model)` 133 | - `make_functional_with_buffers(model)` 134 | 135 | ### vmap 136 | 137 | Note: `vmap` imposes restrictions on the code that it can be used on. 138 | For more details, please read its docstring. 139 | 140 | `vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor 141 | operations in `func`. `vmap(func)` returns a new function that maps `func` over 142 | some dimension (default: 0) of each Tensor in `inputs`. 143 | 144 | `vmap` is useful for hiding batch dimensions: one can write a function `func` 145 | that runs on examples and then lift it to a function that can take batches of 146 | examples with `vmap(func)`, leading to a simpler modeling experience: 147 | 148 | ```py 149 | from functorch import vmap 150 | batch_size, feature_size = 3, 5 151 | weights = torch.randn(feature_size, requires_grad=True) 152 | 153 | def model(feature_vec): 154 | # Very simple linear model with activation 155 | assert feature_vec.dim() == 1 156 | return feature_vec.dot(weights).relu() 157 | 158 | examples = torch.randn(batch_size, feature_size) 159 | result = vmap(model)(examples) 160 | ``` 161 | 162 | ### grad 163 | 164 | `grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It compute 165 | the gradients of the output of func w.r.t. to `inputs[0]`. 166 | 167 | ```py 168 | from functorch import grad 169 | x = torch.randn([]) 170 | cos_x = grad(lambda x: torch.sin(x))(x) 171 | assert torch.allclose(cos_x, x.cos()) 172 | 173 | # Second-order gradients 174 | neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) 175 | assert torch.allclose(neg_sin_x, -x.sin()) 176 | ``` 177 | 178 | When composed with `vmap`, `grad` can be used to compute per-sample-gradients: 179 | ```py 180 | from functorch import vmap 181 | batch_size, feature_size = 3, 5 182 | 183 | def model(weights,feature_vec): 184 | # Very simple linear model with activation 185 | assert feature_vec.dim() == 1 186 | return feature_vec.dot(weights).relu() 187 | 188 | def compute_loss(weights, example, target): 189 | y = model(weights, example) 190 | return ((y - target) ** 2).mean() # MSELoss 191 | 192 | weights = torch.randn(feature_size, requires_grad=True) 193 | examples = torch.randn(batch_size, feature_size) 194 | targets = torch.randn(batch_size) 195 | inputs = (weights,examples, targets) 196 | grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) 197 | ``` 198 | 199 | ### vjp 200 | 201 | The `vjp` transform applies `func` to `inputs` and returns a new function that 202 | computes vjps given some `cotangents` Tensors. 203 | ```py 204 | from functorch import vjp 205 | outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents) 206 | ``` 207 | 208 | ### jvp 209 | 210 | The `jvp` transforms computes Jacobian-vector-products and is also known as 211 | "forward-mode AD". It is not a higher-order function unlike most other transforms, 212 | but it returns the outputs of `func(inputs)` as well as the `jvp`s. 
213 | ```py 214 | from functorch import jvp 215 | x = torch.randn(5) 216 | y = torch.randn(5) 217 | f = lambda x, y: (x * y) 218 | _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) 219 | assert torch.allclose(output, x + y) 220 | ``` 221 | 222 | ### jacrev, jacfwd, and hessian 223 | 224 | The `jacrev` transform returns a new function that takes in `x` and returns the 225 | Jacobian of `torch.sin` with respect to `x` using reverse-mode AD. 226 | ```py 227 | from functorch import jacrev 228 | x = torch.randn(5) 229 | jacobian = jacrev(torch.sin)(x) 230 | expected = torch.diag(torch.cos(x)) 231 | assert torch.allclose(jacobian, expected) 232 | ``` 233 | Use `jacrev` to compute the jacobian. This can be composed with vmap to produce 234 | batched jacobians: 235 | 236 | ```py 237 | x = torch.randn(64, 5) 238 | jacobian = vmap(jacrev(torch.sin))(x) 239 | assert jacobian.shape == (64, 5, 5) 240 | ``` 241 | 242 | `jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using 243 | forward-mode AD: 244 | ```py 245 | from functorch import jacfwd 246 | x = torch.randn(5) 247 | jacobian = jacfwd(torch.sin)(x) 248 | expected = torch.diag(torch.cos(x)) 249 | assert torch.allclose(jacobian, expected) 250 | ``` 251 | 252 | Composing `jacrev` with itself or `jacfwd` can produce hessians: 253 | ```py 254 | def f(x): 255 | return x.sin().sum() 256 | 257 | x = torch.randn(5) 258 | hessian0 = jacrev(jacrev(f))(x) 259 | hessian1 = jacfwd(jacrev(f))(x) 260 | ``` 261 | 262 | The `hessian` is a convenience function that combines `jacfwd` and `jacrev`: 263 | ```py 264 | from functorch import hessian 265 | 266 | def f(x): 267 | return x.sin().sum() 268 | 269 | x = torch.randn(5) 270 | hess = hessian(f)(x) 271 | ``` 272 | 273 | ### Tracing through the transformations 274 | We can also trace through these transformations in order to capture the results as new code using `make_fx`. There is also experimental integration with the NNC compiler (only works on CPU for now!). 275 | 276 | ```py 277 | from functorch import make_fx, grad 278 | def f(x): 279 | return torch.sin(x).sum() 280 | x = torch.randn(100) 281 | grad_f = make_fx(grad(f))(x) 282 | print(grad_f.code) 283 | 284 | def forward(self, x_1): 285 | sin = torch.ops.aten.sin(x_1) 286 | sum_1 = torch.ops.aten.sum(sin, None); sin = None 287 | cos = torch.ops.aten.cos(x_1); x_1 = None 288 | _tensor_constant0 = self._tensor_constant0 289 | mul = torch.ops.aten.mul(_tensor_constant0, cos); _tensor_constant0 = cos = None 290 | return mul 291 | ``` 292 | 293 | ### Working with NN modules: make_functional and friends 294 | 295 | Sometimes you may want to perform a transform with respect to the parameters 296 | and/or buffers of an nn.Module. This can happen for example in: 297 | - model ensembling, where all of your weights and buffers have an additional 298 | dimension 299 | - per-sample-gradient computation where you want to compute per-sample-grads 300 | of the loss with respect to the model parameters 301 | 302 | Our solution to this right now is an API that, given an nn.Module, creates a 303 | stateless version of it that can be called like a function. 304 | 305 | - `make_functional(model)` returns a functional version of `model` and the 306 | `model.parameters()` 307 | - `make_functional_with_buffers(model)` returns a functional version of 308 | `model` and the `model.parameters()` and `model.buffers()`. 
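As a concrete illustration of the buffers variant described in the bullets above, here is a minimal sketch (the model, layer choice, and shapes are made up for the example; `BatchNorm1d` is used only because it owns buffers). `make_functional_with_buffers` returns the buffers as a third value, and the stateless function expects them right after the parameters:

```py
import torch
from functorch import make_functional_with_buffers, grad

# Hypothetical model that owns buffers (BatchNorm keeps running statistics as buffers).
model = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.BatchNorm1d(3))
model.eval()  # read the stored running statistics instead of updating them in-place
data = torch.randn(8, 3)
targets = torch.randn(8, 3)

func_model, params, buffers = make_functional_with_buffers(model)

def compute_loss(params, buffers, data, targets):
    preds = func_model(params, buffers, data)
    return torch.mean((preds - targets) ** 2)

# Gradients of the loss with respect to the (now explicit) parameters.
grads = grad(compute_loss)(params, buffers, data, targets)
```
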
309 | 310 | Here's an example where we compute per-sample-gradients using an nn.Linear 311 | layer: 312 | 313 | ```py 314 | import torch 315 | from functorch import make_functional, vmap, grad 316 | 317 | model = torch.nn.Linear(3, 3) 318 | data = torch.randn(64, 3) 319 | targets = torch.randn(64, 3) 320 | 321 | func_model, params = make_functional(model) 322 | 323 | def compute_loss(params, data, targets): 324 | preds = func_model(params, data) 325 | return torch.mean((preds - targets) ** 2) 326 | 327 | per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets) 328 | ``` 329 | 330 | If you're making an ensemble of models, you may find 331 | `combine_state_for_ensemble` useful. 332 | 333 | ## Documentation 334 | 335 | For more documentation, see [our docs website](https://pytorch.org/functorch). 336 | 337 | ## Debugging 338 | `torch._C._functorch.dump_tensor`: Dumps dispatch keys on stack 339 | `torch._C._functorch._set_vmap_fallback_warning_enabled(False)` if the vmap warning spam bothers you. 340 | 341 | ## Future Plans 342 | 343 | In the end state, we'd like to upstream this into PyTorch once we iron out the 344 | design details. To figure out the details, we need your help -- please send us 345 | your use cases by starting a conversation in the issue tracker or trying our 346 | project out. 347 | 348 | ## License 349 | Functorch has a BSD-style license, as found in the [LICENSE](LICENSE) file. 350 | 351 | ## Citing functorch 352 | 353 | If you use functorch in your publication, please cite it by using the following BibTeX entry. 354 | 355 | ```bibtex 356 | @Misc{functorch2021, 357 | author = {Horace He, Richard Zou}, 358 | title = {functorch: JAX-like composable function transforms for PyTorch}, 359 | howpublished = {\url{https://github.com/pytorch/functorch}}, 360 | year = {2021} 361 | } 362 | ``` 363 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | The new, updated versions of these notebooks may be found in the pytorch/pytorch repo. 2 | 3 | We're leaving the old notebooks here as a temporary solution so that our website still 4 | points to the correct thing. We plan to rewrite the links on the website to point to 5 | their newer counterparts soon. 6 | -------------------------------------------------------------------------------- /notebooks/colab/aot_autograd_optimizations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# AOT Autograd - How to use and optimize?\n", 8 | "\n", 9 | "\n", 10 | " \"Open\n", 11 | "\n", 12 | "\n", 13 | "## Background\n", 14 | "In this tutorial, we will learn how to use AOT Autograd to speedup training of deep learning models.\n", 15 | "\n", 16 | "For background, AOT Autograd is a toolkit to assist developers in accelerating training on PyTorch. Broadly, it has two key features\n", 17 | "* AOT Autograd traces the forward and backward graph ahead of time. 
Presence of forward and backward graph ahead of time facilitates joint graph optimizations such as recomputation or activation checkpointing.\n", 18 | "* AOT Autograd provides simple mechanisms to compile the extracted forward and backward graphs through deep learning compilers, such as NVFuser, NNC, TVM and others.\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "\n", 26 | "## What will you learn?\n", 27 | "In this tutorial, we will look at how AOT Autograd can be used, in conjunction with backend compilers, to accelerate the training of PyTorch models. More specifically, you will learn\n", 28 | "* How to use AOT Autograd?\n", 29 | "* How AOT Autograd uses backend compilers to perform operation fusion?\n", 30 | "* How AOT Autograd enables training-specific optimizations such as Recomputation?\n", 31 | "\n", 32 | "So, lets get started.\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Setup\n", 40 | "\n", 41 | "Let's setup a simple model.\n" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 1, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "import torch\n", 51 | "\n", 52 | "def fn(a, b, c, d):\n", 53 | " x = a + b + c + d\n", 54 | " return x.cos().cos()" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# Test that it works\n", 64 | "a, b, c, d = [torch.randn(2, 4, requires_grad=True) for _ in range(4)]\n", 65 | "ref = fn(a, b, c, d)\n", 66 | "loss = ref.sum()\n", 67 | "loss.backward()" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Use AOT Autograd\n", 75 | "\n", 76 | "Now, lets use AOT Autograd and look at the extracted forward and backward graphs. Internally, AOT uses `__torch_dispatch__` based tracing mechanism to extract forward and backward graphs, and wraps them in `torch.Fx` GraphModule containers. Note that AOT Autograd tracing is different from the usual Fx symbolic tracing. AOT Autograd uses Fx GraphModule just to represent the traced graphs (and not for tracing).\n", 77 | "\n", 78 | "AOT Autograd then sends these forward and backward graphs to the user supplied compilers. So, lets write a compiler that just prints the graph." 
79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "\n", 91 | "\n", 92 | "\n", 93 | "def forward(self, primals_1, primals_2, primals_3, primals_4):\n", 94 | " add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None\n", 95 | " add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None\n", 96 | " add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None\n", 97 | " cos = torch.ops.aten.cos(add_2)\n", 98 | " cos_1 = torch.ops.aten.cos(cos)\n", 99 | " return [cos_1, add_2, cos]\n", 100 | " \n", 101 | "\n", 102 | "\n", 103 | "\n", 104 | "def forward(self, add_2, cos, tangents_1):\n", 105 | " sin = torch.ops.aten.sin(cos); cos = None\n", 106 | " neg = torch.ops.aten.neg(sin); sin = None\n", 107 | " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", 108 | " sin_1 = torch.ops.aten.sin(add_2); add_2 = None\n", 109 | " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", 110 | " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", 111 | " return [mul_1, mul_1, mul_1, mul_1]\n", 112 | " \n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "from functorch.compile import aot_function\n", 118 | "\n", 119 | "# The compiler_fn is called after the forward and backward graphs are extracted.\n", 120 | "# Here, we just print the code in the compiler_fn. Return of this function is a callable.\n", 121 | "def compiler_fn(fx_module: torch.fx.GraphModule, _):\n", 122 | " print(fx_module.code)\n", 123 | " return fx_module\n", 124 | "\n", 125 | "# Pass on the compiler_fn to the aot_function API\n", 126 | "aot_print_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn)\n", 127 | "\n", 128 | "# Run the aot_print_fn once to trigger the compilation and print the graphs\n", 129 | "res = aot_print_fn(a, b, c, d).sum().backward()\n", 130 | "assert torch.allclose(ref, res)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "The above code prints the Fx graph for the forward and backward graph. You can see that in addition to the original input of the forward pass, the forward graph outputs some additional tensors. These tensors are saved for the backward pass for gradient calculation. We will come back to these later while talking about recomputation." 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "## Operator Fusion\n", 145 | "Now that we understand how to use AOT Autograd to print forward and backward graphs, let us use AOT Autograd to use some actual deep learning compiler. In this tutorial, we use PyTorch Neural Network Compiler (NNC) to perform pointwise operator fusion for CPU devices. For CUDA devices, a suitable alternative is NvFuser. So, lets use NNC" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 4, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# AOT Autograd has a suite of already integrated backends. Lets import the NNC compiler backend - ts_compile\n", 155 | "from functorch.compile import ts_compile\n", 156 | "\n", 157 | "# Lets compile the forward and backward through ts_compile.\n", 158 | "aot_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)\n", 159 | "\n", 160 | "# Correctness checking. 
Lets clone the input so that we can check grads.\n", 161 | "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n", 162 | "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n", 163 | "\n", 164 | "res = aot_nnc_fn(*cloned_inputs)\n", 165 | "loss = res.sum()\n", 166 | "loss.backward()\n", 167 | "assert torch.allclose(ref, res)\n", 168 | "assert torch.allclose(a.grad, cloned_a.grad)\n", 169 | "assert torch.allclose(b.grad, cloned_b.grad)\n", 170 | "assert torch.allclose(c.grad, cloned_c.grad)\n", 171 | "assert torch.allclose(d.grad, cloned_d.grad)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "Lets benchmark the original and AOT Autograd + NNC compiled function." 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 5, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "# Lets write a function to benchmark the forward and backward pass\n", 188 | "import time\n", 189 | "import statistics\n", 190 | "\n", 191 | "def bench(fn, args, prefix):\n", 192 | " warmup = 10\n", 193 | " iterations = 100\n", 194 | "\n", 195 | " for _ in range(warmup):\n", 196 | " ref = fn(*args)\n", 197 | " ref.sum().backward()\n", 198 | " \n", 199 | " fw_latencies = []\n", 200 | " bw_latencies = []\n", 201 | " for _ in range(iterations):\n", 202 | " for arg in args:\n", 203 | " arg.grad = None\n", 204 | "\n", 205 | " fw_begin = time.perf_counter()\n", 206 | " ref = fn(*args)\n", 207 | " fw_end = time.perf_counter()\n", 208 | "\n", 209 | " loss = ref.sum() \n", 210 | "\n", 211 | " bw_begin = time.perf_counter()\n", 212 | " loss.backward()\n", 213 | " bw_end = time.perf_counter()\n", 214 | "\n", 215 | " fw_latencies.append(fw_end - fw_begin)\n", 216 | " bw_latencies.append(bw_end - bw_begin)\n", 217 | " \n", 218 | " avg_fw_latency = statistics.mean(fw_latencies) * 10**6\n", 219 | " avg_bw_latency = statistics.mean(bw_latencies) * 10**6\n", 220 | " print(prefix, \"Fwd = \" + str(avg_fw_latency) + \" us\", \"Bwd = \" + str(avg_bw_latency) + \" us\", sep=', ')\n" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 6, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "Eager, Fwd = 982.6959593920038 us, Bwd = 1899.7003795811906 us\n", 233 | "AOT, Fwd = 734.2723174951971 us, Bwd = 831.1696897726506 us\n" 234 | ] 235 | } 236 | ], 237 | "source": [ 238 | "large_inputs = [torch.randn(1024, 2048, requires_grad=True) for _ in range(4)]\n", 239 | "\n", 240 | "# Benchmark the Eager and AOT Autograd functions\n", 241 | "bench(fn, large_inputs, \"Eager\")\n", 242 | "bench(aot_nnc_fn, large_inputs, \"AOT\")" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "With the help of NNC, AOT Autograd speeds up both the forward and backward pass. If we look at the printed graphs earlier, all the operators are pointwise. The pointwise operators are memory bandwidth bound, and thus benefit from operator fusion. Looking closely at the numbers, the backward pass gets higher speedup. This is because forward pass has to output some intermediate tensors for gradient calculation for the backward pass, preventing it from saving some memory reads and writes. However, such restriction does not exist in the backward graph." 
250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "## Recomputation (aka Activation Checkpointing)\n", 257 | "Recomputation (often called activation checkpointing) is a technique in which, instead of saving some activations for use in backwards, we recompute them **during** the backwards pass. Recomputing saves memory, but we incur performance overhead.\n", 258 | "\n", 259 | "However, in the presence of fusing compiler, we can do better that that. We can recompute the fusion-friendly operators to save memory, and then rely on the fusing compiler to fuse the recomputed operators. This reduces both memory and runtime. Please refer to this [discuss post](https://dev-discuss.pytorch.org/t/min-cut-optimal-recomputation-i-e-activation-checkpointing-with-aotautograd/467) for more details.\n", 260 | "\n", 261 | "Here, we use AOT Autograd with NNC to perform similar type of recomputation. At the end of `__torch_dispatch__` tracing, AOT Autograd has a forward graph and joint forward-backward graph. AOT Autograd then uses a partitioner to isolate the forward and backward graph. In the example above, we used a default partitioner. For this experiment, we will use another partitioner called `min_cut_rematerialization_partition` to perform smarter fusion-aware recomputation. The partitioner is configurable and one can write their own partitioner to plug it in AOT Autograd." 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 7, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stdout", 271 | "output_type": "stream", 272 | "text": [ 273 | "\n", 274 | "\n", 275 | "\n", 276 | "def forward(self, primals_1, primals_2, primals_3, primals_4):\n", 277 | " add = torch.ops.aten.add(primals_1, primals_2); primals_1 = primals_2 = None\n", 278 | " add_1 = torch.ops.aten.add(add, primals_3); add = primals_3 = None\n", 279 | " add_2 = torch.ops.aten.add(add_1, primals_4); add_1 = primals_4 = None\n", 280 | " cos = torch.ops.aten.cos(add_2)\n", 281 | " cos_1 = torch.ops.aten.cos(cos); cos = None\n", 282 | " return [cos_1, add_2]\n", 283 | " \n", 284 | "\n", 285 | "\n", 286 | "\n", 287 | "def forward(self, add_2, tangents_1):\n", 288 | " cos = torch.ops.aten.cos(add_2)\n", 289 | " sin = torch.ops.aten.sin(cos); cos = None\n", 290 | " neg = torch.ops.aten.neg(sin); sin = None\n", 291 | " mul = torch.ops.aten.mul(tangents_1, neg); tangents_1 = neg = None\n", 292 | " sin_1 = torch.ops.aten.sin(add_2); add_2 = None\n", 293 | " neg_1 = torch.ops.aten.neg(sin_1); sin_1 = None\n", 294 | " mul_1 = torch.ops.aten.mul(mul, neg_1); mul = neg_1 = None\n", 295 | " return [mul_1, mul_1, mul_1, mul_1]\n", 296 | " \n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "from functorch.compile import min_cut_rematerialization_partition\n", 302 | "\n", 303 | "# Lets set up the partitioner. Also set the fwd and bwd compilers to the printer function that we used earlier.\n", 304 | "# This will show us how the recomputation has modified the graph.\n", 305 | "aot_fn = aot_function(fn, fw_compiler=compiler_fn, bw_compiler=compiler_fn, partition_fn=min_cut_rematerialization_partition)\n", 306 | "res = aot_fn(a, b, c, d).sum().backward()" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "We can see that compared to default partitioner, forward pass now outputs fewer tensors, and recomputes some operations in the backward pass. 
Let us try NNC compiler now to perform operator fusions (note that we also have a wrapper function - `memory_efficient_fusion` which internally uses `min_cut_rematerialization_partition` and Torchscript compiler to achieve the same effect as following code)." 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 8, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "\n", 323 | "# Lets set up the partitioner and NNC compiler.\n", 324 | "aot_recompute_nnc_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile, partition_fn=min_cut_rematerialization_partition)\n", 325 | "\n", 326 | "# Correctness checking. Lets clone the input so that we can check grads.\n", 327 | "cloned_inputs = [x.clone().detach().requires_grad_(True) for x in (a, b, c, d)]\n", 328 | "cloned_a, cloned_b, cloned_c, cloned_d = cloned_inputs\n", 329 | "\n", 330 | "res = aot_recompute_nnc_fn(*cloned_inputs)\n", 331 | "loss = res.sum()\n", 332 | "loss.backward()\n", 333 | "assert torch.allclose(ref, res)\n", 334 | "assert torch.allclose(a.grad, cloned_a.grad)\n", 335 | "assert torch.allclose(b.grad, cloned_b.grad)\n", 336 | "assert torch.allclose(c.grad, cloned_c.grad)\n", 337 | "assert torch.allclose(d.grad, cloned_d.grad)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "markdown", 342 | "metadata": {}, 343 | "source": [ 344 | "Finally, lets benchmark the different functions" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 14, 350 | "metadata": {}, 351 | "outputs": [ 352 | { 353 | "name": "stdout", 354 | "output_type": "stream", 355 | "text": [ 356 | "Eager, Fwd = 740.7676504226401 us, Bwd = 1560.5240693548694 us\n", 357 | "AOT, Fwd = 713.8530415249988 us, Bwd = 909.1200679540634 us\n", 358 | "AOT_Recomp, Fwd = 712.2249767417088 us, Bwd = 791.4606417762116 us\n" 359 | ] 360 | } 361 | ], 362 | "source": [ 363 | "bench(fn, large_inputs, \"Eager\")\n", 364 | "bench(aot_nnc_fn, large_inputs, \"AOT\")\n", 365 | "bench(aot_recompute_nnc_fn, large_inputs, \"AOT_Recomp\")" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "We observe that both forward and backward latency improve over the default partitioner (and a lot better than eager). Fewer outputs in the forward pass and fewer inputs in the backward pass, along with fusion, allows better memory bandwidth utilization leading to further speedups." 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": {}, 378 | "source": [ 379 | "## Actual Usage\n", 380 | "For actual usage on CUDA devices, we've wrapped AOTAutograd in a convenient wrapper - `memory_efficient_fusion`. 
Use this for fusion on GPU!\n", 381 | "\n", 382 | "```\n", 383 | "from functorch.compile import memory_efficient_fusion\n", 384 | "```\n" 385 | ] 386 | } 387 | ], 388 | "metadata": { 389 | "kernelspec": { 390 | "display_name": "Python 3.9.5 ('base')", 391 | "language": "python", 392 | "name": "python3" 393 | }, 394 | "language_info": { 395 | "codemirror_mode": { 396 | "name": "ipython", 397 | "version": 3 398 | }, 399 | "file_extension": ".py", 400 | "mimetype": "text/x-python", 401 | "name": "python", 402 | "nbconvert_exporter": "python", 403 | "pygments_lexer": "ipython3", 404 | "version": "3.9.5" 405 | }, 406 | "orig_nbformat": 4, 407 | "vscode": { 408 | "interpreter": { 409 | "hash": "73b6e0ee7c860e06bb349c72324473b318d6cb6c97bcad772bce0703fb8d0dfb" 410 | } 411 | } 412 | }, 413 | "nbformat": 4, 414 | "nbformat_minor": 2 415 | } 416 | -------------------------------------------------------------------------------- /notebooks/colab/jacobians_hessians_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms\n", 7 | "\n", 8 | "\n", 9 | " \"Open\n", 10 | "\n", 11 | "\n", 12 | "Computing jacobians or hessians are useful in a number of non-traditional\n", 13 | "deep learning models. It is difficult (or annoying) to compute these quantities\n", 14 | "efficiently using a standard autodiff system like PyTorch Autograd; functorch\n", 15 | "provides ways of computing various higher-order autodiff quantities efficiently." 16 | ], 17 | "metadata": { 18 | "id": "zPbR6-eP51fe" 19 | }, 20 | "id": "zPbR6-eP51fe" 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "source": [ 25 | "## Computing the Jacobian" 26 | ], 27 | "metadata": { 28 | "id": "3kDj8fhn52j3" 29 | }, 30 | "id": "3kDj8fhn52j3" 31 | }, 32 | { 33 | "cell_type": "code", 34 | "source": [ 35 | "import torch\n", 36 | "import torch.nn as nn\n", 37 | "import torch.nn.functional as F\n", 38 | "from functools import partial\n", 39 | "_ = torch.manual_seed(0)" 40 | ], 41 | "metadata": { 42 | "id": "w_IinyjzflUH" 43 | }, 44 | "execution_count": null, 45 | "outputs": [], 46 | "id": "w_IinyjzflUH" 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "source": [ 51 | "Let’s start with a function that we’d like to compute the jacobian of. 
This is a simple linear function with non-linear activation.\n", 52 | "\n" 53 | ], 54 | "metadata": { 55 | "id": "cibF_PEYflUH" 56 | }, 57 | "id": "cibF_PEYflUH" 58 | }, 59 | { 60 | "cell_type": "code", 61 | "source": [ 62 | "def predict(weight, bias, x):\n", 63 | " return F.linear(x, weight, bias).tanh()" 64 | ], 65 | "metadata": { 66 | "id": "qhcD9hWYflUH" 67 | }, 68 | "execution_count": null, 69 | "outputs": [], 70 | "id": "qhcD9hWYflUH" 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "source": [ 75 | "Let's add some dummy data: a weight, a bias, and a feature vector x.\n", 76 | "\n" 77 | ], 78 | "metadata": { 79 | "id": "G8tqQrO_flUH" 80 | }, 81 | "id": "G8tqQrO_flUH" 82 | }, 83 | { 84 | "cell_type": "code", 85 | "source": [ 86 | "D = 16\n", 87 | "weight = torch.randn(D, D)\n", 88 | "bias = torch.randn(D)\n", 89 | "x = torch.randn(D) # feature vector" 90 | ], 91 | "metadata": { 92 | "id": "FZ4uJfZGflUH" 93 | }, 94 | "execution_count": null, 95 | "outputs": [], 96 | "id": "FZ4uJfZGflUH" 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "source": [ 101 | "Let's think of `predict` as a function that maps the input `x` from $R^D -> R^D$.\n", 102 | "PyTorch Autograd computes vector-Jacobian products. In order to compute the full\n", 103 | "Jacobian of this $R^D -> R^D$ function, we would have to compute it row-by-row\n", 104 | "by using a different unit vector each time." 105 | ], 106 | "metadata": { 107 | "id": "uMAW-ArQflUH" 108 | }, 109 | "id": "uMAW-ArQflUH" 110 | }, 111 | { 112 | "cell_type": "code", 113 | "source": [ 114 | "def compute_jac(xp):\n", 115 | " jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n", 116 | " for vec in unit_vectors]\n", 117 | " return torch.stack(jacobian_rows)" 118 | ], 119 | "metadata": { 120 | "id": "z-BJPtbpflUI" 121 | }, 122 | "execution_count": null, 123 | "outputs": [], 124 | "id": "z-BJPtbpflUI" 125 | }, 126 | { 127 | "cell_type": "code", 128 | "source": [ 129 | "xp = x.clone().requires_grad_()\n", 130 | "unit_vectors = torch.eye(D)\n", 131 | "\n", 132 | "jacobian = compute_jac(xp)\n", 133 | "\n", 134 | "print(jacobian.shape)\n", 135 | "print(jacobian[0]) # show first row" 136 | ], 137 | "metadata": { 138 | "colab": { 139 | "base_uri": "https://localhost:8080/" 140 | }, 141 | "outputId": "f1f1ec12-56ef-40f7-8c3c-cbad7bf86644", 142 | "id": "zuWGSXspflUI" 143 | }, 144 | "execution_count": null, 145 | "outputs": [ 146 | { 147 | "output_type": "stream", 148 | "name": "stdout", 149 | "text": [ 150 | "torch.Size([16, 16])\n", 151 | "tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,\n", 152 | " 0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])\n" 153 | ] 154 | } 155 | ], 156 | "id": "zuWGSXspflUI" 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "source": [ 161 | "Instead of computing the jacobian row-by-row, we can use vmap to get rid of the for-loop and vectorize the computation. 
\n", 162 | "We can’t directly apply vmap to PyTorch Autograd; instead, functorch provides a vjp transform:\n", 163 | "\n" 164 | ], 165 | "metadata": { 166 | "id": "mxlEOUieflUI" 167 | }, 168 | "id": "mxlEOUieflUI" 169 | }, 170 | { 171 | "cell_type": "code", 172 | "source": [ 173 | "from functorch import vmap, vjp\n", 174 | "\n", 175 | "_, vjp_fn = vjp(partial(predict, weight, bias), x)\n", 176 | "\n", 177 | "ft_jacobian, = vmap(vjp_fn)(unit_vectors)\n", 178 | "\n", 179 | "# lets confirm both methods compute the same result\n", 180 | "assert torch.allclose(ft_jacobian, jacobian)" 181 | ], 182 | "metadata": { 183 | "id": "DeF6uy4WflUI" 184 | }, 185 | "execution_count": null, 186 | "outputs": [], 187 | "id": "DeF6uy4WflUI" 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "source": [ 192 | "In future tutorial a composition of reverse-mode AD and vmap will give us per-sample-gradients. \n", 193 | "In this tutorial, composing reverse-mode AD and vmap gives us Jacobian computation! \n", 194 | "Various compositions of vmap and autodiff transforms can give us different interesting quantities.\n", 195 | "\n", 196 | "functorch provides **jacrev** as a convenience function that performs the vmap-vjp composition to compute jacobians. **jacrev** accepts an argnums argument that says which argument we would like to compute Jacobians with respect to.\n", 197 | "\n" 198 | ], 199 | "metadata": { 200 | "id": "Hy4REmwDflUI" 201 | }, 202 | "id": "Hy4REmwDflUI" 203 | }, 204 | { 205 | "cell_type": "code", 206 | "source": [ 207 | "from functorch import jacrev\n", 208 | "\n", 209 | "ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n", 210 | "\n", 211 | "# confirm \n", 212 | "assert torch.allclose(ft_jacobian, jacobian)" 213 | ], 214 | "metadata": { 215 | "id": "Rt7i6_YlflUI" 216 | }, 217 | "execution_count": null, 218 | "outputs": [], 219 | "id": "Rt7i6_YlflUI" 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "source": [ 224 | "Let’s compare the performance of the two ways to compute the jacobian. The functorch version is much faster (and becomes even faster the more outputs there are). \n", 225 | "\n", 226 | "In general, we expect that vectorization via vmap can help eliminate overhead and give better utilization of your hardware.\n", 227 | "\n", 228 | "Vmap does this magic by pushing the outer loop down into the functions primitive operations in order to obtain better performance.\n", 229 | "\n", 230 | "\n" 231 | ], 232 | "metadata": { 233 | "id": "JYe2H1UcflUJ" 234 | }, 235 | "id": "JYe2H1UcflUJ" 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "source": [ 240 | "Let's make a quick function to evaluate performance and deal with microseconds and milliseconds measurements:" 241 | ], 242 | "metadata": { 243 | "id": "i_143LZwflUJ" 244 | }, 245 | "id": "i_143LZwflUJ" 246 | }, 247 | { 248 | "cell_type": "code", 249 | "source": [ 250 | "def get_perf(first, first_descriptor, second, second_descriptor):\n", 251 | " \"\"\" takes torch.benchmark objects and compares delta of second vs first. 
\"\"\"\n", 252 | " faster = second.times[0]\n", 253 | " slower = first.times[0]\n", 254 | " gain = (slower-faster)/slower\n", 255 | " if gain < 0: gain *=-1 \n", 256 | " final_gain = gain*100\n", 257 | " print(f\" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} \")" 258 | ], 259 | "metadata": { 260 | "id": "II7r6jBtflUJ" 261 | }, 262 | "execution_count": null, 263 | "outputs": [], 264 | "id": "II7r6jBtflUJ" 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "source": [ 269 | "And then run the performance comparison:" 270 | ], 271 | "metadata": { 272 | "id": "r4clPnPKflUJ" 273 | }, 274 | "id": "r4clPnPKflUJ" 275 | }, 276 | { 277 | "cell_type": "code", 278 | "source": [ 279 | "from torch.utils.benchmark import Timer\n", 280 | "\n", 281 | "without_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\n", 282 | "with_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", 283 | "\n", 284 | "no_vmap_timer = without_vmap.timeit(500)\n", 285 | "with_vmap_timer = with_vmap.timeit(500)\n", 286 | "\n", 287 | "print(no_vmap_timer)\n", 288 | "print(with_vmap_timer)" 289 | ], 290 | "metadata": { 291 | "colab": { 292 | "base_uri": "https://localhost:8080/" 293 | }, 294 | "outputId": "cbf77a19-aac9-428d-eba1-74d337c53e49", 295 | "id": "ZPtoxF6eflUJ" 296 | }, 297 | "execution_count": null, 298 | "outputs": [ 299 | { 300 | "output_type": "stream", 301 | "name": "stdout", 302 | "text": [ 303 | "\n", 304 | "compute_jac(xp)\n", 305 | " 2.25 ms\n", 306 | " 1 measurement, 500 runs , 1 thread\n", 307 | "\n", 308 | "jacrev(predict, argnums=2)(weight, bias, x)\n", 309 | " 884.34 us\n", 310 | " 1 measurement, 500 runs , 1 thread\n" 311 | ] 312 | } 313 | ], 314 | "id": "ZPtoxF6eflUJ" 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "source": [ 319 | "Lets do a relative performance comparison of the above with our get_perf function:" 320 | ], 321 | "metadata": { 322 | "id": "nGBBi4dZflUJ" 323 | }, 324 | "id": "nGBBi4dZflUJ" 325 | }, 326 | { 327 | "cell_type": "code", 328 | "source": [ 329 | "get_perf(no_vmap_timer, \"without vmap\", with_vmap_timer, \"vmap\");" 330 | ], 331 | "metadata": { 332 | "colab": { 333 | "base_uri": "https://localhost:8080/" 334 | }, 335 | "outputId": "85d0bc5f-34aa-4826-f953-6c637404490c", 336 | "id": "zqV2RzEXflUJ" 337 | }, 338 | "execution_count": null, 339 | "outputs": [ 340 | { 341 | "output_type": "stream", 342 | "name": "stdout", 343 | "text": [ 344 | " Performance delta: 60.7170 percent improvement with vmap \n" 345 | ] 346 | } 347 | ], 348 | "id": "zqV2RzEXflUJ" 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "source": [ 353 | "Furthemore, it’s pretty easy to flip the problem around and say we want to compute Jacobians of the parameters to our model (weight, bias) instead of the input." 
354 | ], 355 | "metadata": { 356 | "id": "EQAB99EQflUJ" 357 | }, 358 | "id": "EQAB99EQflUJ" 359 | }, 360 | { 361 | "cell_type": "code", 362 | "source": [ 363 | "# note the change in input via argnums params of 0,1 to map to weight and bias\n", 364 | "ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)" 365 | ], 366 | "metadata": { 367 | "id": "8UZpC8DnflUK" 368 | }, 369 | "execution_count": null, 370 | "outputs": [], 371 | "id": "8UZpC8DnflUK" 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "source": [ 376 | "## reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)\n" 377 | ], 378 | "metadata": { 379 | "id": "F3USYENIflUK" 380 | }, 381 | "id": "F3USYENIflUK" 382 | }, 383 | { 384 | "cell_type": "markdown", 385 | "source": [ 386 | "We offer two APIs to compute jacobians: **jacrev** and **jacfwd**: \n", 387 | "- jacrev uses reverse-mode AD. As you saw above it is a composition of our vjp and vmap transforms. \n", 388 | "- jacfwd uses forward-mode AD. It is implemented as a composition of our jvp and vmap transforms. \n", 389 | "\n", 390 | "jacfwd and jacrev can be substituted for each other but they have different performance characteristics.\n", 391 | "\n", 392 | "As a general rule of thumb, if you’re computing the jacobian of an $𝑅^N \\to R^M$ function, and there are many more outputs than inputs (i.e. $M > N$) then jacfwd is preferred, otherwise use jacrev. There are exceptions to this rule, but a non-rigorous argument for this follows:\n", 393 | "\n", 394 | "In reverse-mode AD, we are computing the jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we are computing it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way we may prefer the method that deals with fewer rows or columns.\n", 395 | "\n" 396 | ], 397 | "metadata": { 398 | "id": "V7B3vE8dflUK" 399 | }, 400 | "id": "V7B3vE8dflUK" 401 | }, 402 | { 403 | "cell_type": "code", 404 | "source": [ 405 | "from functorch import jacrev, jacfwd" 406 | ], 407 | "metadata": { 408 | "id": "k7Tok7m3flUK" 409 | }, 410 | "execution_count": null, 411 | "outputs": [], 412 | "id": "k7Tok7m3flUK" 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "source": [ 417 | "First, let's benchmark with more inputs than outputs:\n", 418 | "\n" 419 | ], 420 | "metadata": { 421 | "id": "YrV-gZAaflUL" 422 | }, 423 | "id": "YrV-gZAaflUL" 424 | }, 425 | { 426 | "cell_type": "code", 427 | "source": [ 428 | "Din = 32\n", 429 | "Dout = 2048\n", 430 | "weight = torch.randn(Dout, Din)\n", 431 | "\n", 432 | "bias = torch.randn(Dout)\n", 433 | "x = torch.randn(Din)\n", 434 | "\n", 435 | "# remember the general rule about taller vs wider...here we have a taller matrix:\n", 436 | "print(weight.shape)\n", 437 | "\n", 438 | "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", 439 | "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", 440 | "\n", 441 | "jacfwd_timing = using_fwd.timeit(500)\n", 442 | "jacrev_timing = using_bwd.timeit(500)\n", 443 | "\n", 444 | "print(f'jacfwd time: {jacfwd_timing}')\n", 445 | "print(f'jacrev time: {jacrev_timing}')\n" 446 | ], 447 | "metadata": { 448 | "colab": { 449 | "base_uri": "https://localhost:8080/" 450 | }, 451 | "outputId": "dd882726-9723-47c0-a72f-3c7835a85aa1", 452 | "id": "m5j-4hSxflUL" 453 | }, 454 | "execution_count": null, 455 | "outputs": [ 456 | { 457 | "output_type": "stream", 458 | "name": "stdout", 459 | "text": 
[ 460 | "torch.Size([2048, 32])\n", 461 | "jacfwd time: \n", 462 | "jacfwd(predict, argnums=2)(weight, bias, x)\n", 463 | " 1.32 ms\n", 464 | " 1 measurement, 500 runs , 1 thread\n", 465 | "jacrev time: \n", 466 | "jacrev(predict, argnums=2)(weight, bias, x)\n", 467 | " 12.46 ms\n", 468 | " 1 measurement, 500 runs , 1 thread\n" 469 | ] 470 | } 471 | ], 472 | "id": "m5j-4hSxflUL" 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "source": [ 477 | "and then do a relative benchmark:" 478 | ], 479 | "metadata": { 480 | "id": "k_Sg-4tVflUL" 481 | }, 482 | "id": "k_Sg-4tVflUL" 483 | }, 484 | { 485 | "cell_type": "code", 486 | "source": [ 487 | "get_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\", );" 488 | ], 489 | "metadata": { 490 | "colab": { 491 | "base_uri": "https://localhost:8080/" 492 | }, 493 | "outputId": "3a6586a1-269d-46d8-d119-e24f6d46277f", 494 | "id": "_4T96zGjflUL" 495 | }, 496 | "execution_count": null, 497 | "outputs": [ 498 | { 499 | "output_type": "stream", 500 | "name": "stdout", 501 | "text": [ 502 | " Performance delta: 842.8274 percent improvement with jacrev \n" 503 | ] 504 | } 505 | ], 506 | "id": "_4T96zGjflUL" 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "source": [ 511 | "and now the reverse - more outputs (M) than inputs (N):" 512 | ], 513 | "metadata": { 514 | "id": "RCDPot1yflUL" 515 | }, 516 | "id": "RCDPot1yflUL" 517 | }, 518 | { 519 | "cell_type": "code", 520 | "source": [ 521 | "Din = 2048\n", 522 | "Dout = 32\n", 523 | "weight = torch.randn(Dout, Din)\n", 524 | "bias = torch.randn(Dout)\n", 525 | "x = torch.randn(Din)\n", 526 | "\n", 527 | "using_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", 528 | "using_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n", 529 | "\n", 530 | "jacfwd_timing = using_fwd.timeit(500)\n", 531 | "jacrev_timing = using_bwd.timeit(500)\n", 532 | "\n", 533 | "print(f'jacfwd time: {jacfwd_timing}')\n", 534 | "print(f'jacrev time: {jacrev_timing}')" 535 | ], 536 | "metadata": { 537 | "colab": { 538 | "base_uri": "https://localhost:8080/" 539 | }, 540 | "outputId": "913e9ccd-3d4f-472a-a749-19cee36d0a16", 541 | "id": "_DRFqzqZflUM" 542 | }, 543 | "execution_count": null, 544 | "outputs": [ 545 | { 546 | "output_type": "stream", 547 | "name": "stdout", 548 | "text": [ 549 | "jacfwd time: \n", 550 | "jacfwd(predict, argnums=2)(weight, bias, x)\n", 551 | " 7.99 ms\n", 552 | " 1 measurement, 500 runs , 1 thread\n", 553 | "jacrev time: \n", 554 | "jacrev(predict, argnums=2)(weight, bias, x)\n", 555 | " 1.09 ms\n", 556 | " 1 measurement, 500 runs , 1 thread\n" 557 | ] 558 | } 559 | ], 560 | "id": "_DRFqzqZflUM" 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "source": [ 565 | "and a relative perf comparison:" 566 | ], 567 | "metadata": { 568 | "id": "5SRbMCNsflUM" 569 | }, 570 | "id": "5SRbMCNsflUM" 571 | }, 572 | { 573 | "cell_type": "code", 574 | "source": [ 575 | "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")" 576 | ], 577 | "metadata": { 578 | "colab": { 579 | "base_uri": "https://localhost:8080/" 580 | }, 581 | "outputId": "c282ce25-4f6e-44cd-aed7-60f6f5010e5b", 582 | "id": "uF_9GaoiflUM" 583 | }, 584 | "execution_count": null, 585 | "outputs": [ 586 | { 587 | "output_type": "stream", 588 | "name": "stdout", 589 | "text": [ 590 | " Performance delta: 635.2095 percent improvement with jacfwd \n" 591 | ] 592 | } 593 | ], 594 | "id": "uF_9GaoiflUM" 595 | }, 596 | { 597 | "cell_type": "markdown", 598 | "source": [ 599 | 
"## Hessian computation with functorch.hessian\n" 600 | ], 601 | "metadata": { 602 | "id": "J29FQaBQflUM" 603 | }, 604 | "id": "J29FQaBQflUM" 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "source": [ 609 | "We offer a convenience API to compute hessians: `functorch.hessian`. \n", 610 | "Hessians are the jacobian of the jacobian (or the partial derivative of the partial derivative, aka second order).\n", 611 | "\n", 612 | "This suggests that one can just compose functorch’s jacobian transforms to compute the Hessian. \n", 613 | "Indeed, under the hood, `hessian(f)` is simply `jacfwd(jacrev(f))`.\n", 614 | "\n" 615 | ], 616 | "metadata": { 617 | "id": "My4DPH97flUM" 618 | }, 619 | "id": "My4DPH97flUM" 620 | }, 621 | { 622 | "cell_type": "markdown", 623 | "source": [ 624 | "Note: to boost performance: depending on your model, you may also want to use `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead to compute hessians leveraging the rule of thumb above regarding wider vs taller matrices.\n", 625 | "\n" 626 | ], 627 | "metadata": { 628 | "id": "FJt038l5flUM" 629 | }, 630 | "id": "FJt038l5flUM" 631 | }, 632 | { 633 | "cell_type": "code", 634 | "source": [ 635 | "from functorch import hessian\n", 636 | "\n", 637 | "# lets reduce the size in order not to blow out colab. Hessians require significant memory:\n", 638 | "Din = 512\n", 639 | "Dout = 32\n", 640 | "weight = torch.randn(Dout, Din)\n", 641 | "bias = torch.randn(Dout)\n", 642 | "x = torch.randn(Din)\n", 643 | "\n", 644 | "hess_api = hessian(predict, argnums=2)(weight, bias, x)\n", 645 | "hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\n", 646 | "#hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)\n" 647 | ], 648 | "metadata": { 649 | "id": "jEqr2ywZflUM" 650 | }, 651 | "execution_count": null, 652 | "outputs": [], 653 | "id": "jEqr2ywZflUM" 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "source": [ 658 | "Let's verify we have the same result regardless of using hessian api or using jacfwd(jacfwd())" 659 | ], 660 | "metadata": { 661 | "id": "n9BHcICQflUN" 662 | }, 663 | "id": "n9BHcICQflUN" 664 | }, 665 | { 666 | "cell_type": "code", 667 | "source": [ 668 | "torch.allclose(hess_api, hess_fwdfwd)" 669 | ], 670 | "metadata": { 671 | "colab": { 672 | "base_uri": "https://localhost:8080/" 673 | }, 674 | "outputId": "e457e3bc-f085-4f90-966d-f98893b98ea8", 675 | "id": "eHiWRkjJflUN" 676 | }, 677 | "execution_count": null, 678 | "outputs": [ 679 | { 680 | "output_type": "execute_result", 681 | "data": { 682 | "text/plain": [ 683 | "True" 684 | ] 685 | }, 686 | "metadata": {}, 687 | "execution_count": 18 688 | } 689 | ], 690 | "id": "eHiWRkjJflUN" 691 | }, 692 | { 693 | "cell_type": "markdown", 694 | "source": [ 695 | "## Batch Jacobian and Batch Hessian\n" 696 | ], 697 | "metadata": { 698 | "id": "Gjt1RO8HflUN" 699 | }, 700 | "id": "Gjt1RO8HflUN" 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "source": [ 705 | "In the above examples we’ve been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape `(B, N)` and a function that goes from $R^N \\to R^M$, we would like a Jacobian of shape `(B, M, N)`. 
\n", 706 | "\n", 707 | "The easiest way to do this is to use vmap:" 708 | ], 709 | "metadata": { 710 | "id": "RjIzdoQNflUN" 711 | }, 712 | "id": "RjIzdoQNflUN" 713 | }, 714 | { 715 | "cell_type": "code", 716 | "source": [ 717 | "batch_size = 64\n", 718 | "Din = 31\n", 719 | "Dout = 33\n", 720 | "\n", 721 | "weight = torch.randn(Dout, Din)\n", 722 | "print(f\"weight shape = {weight.shape}\")\n", 723 | "\n", 724 | "bias = torch.randn(Dout)\n", 725 | "\n", 726 | "x = torch.randn(batch_size, Din)" 727 | ], 728 | "metadata": { 729 | "colab": { 730 | "base_uri": "https://localhost:8080/" 731 | }, 732 | "outputId": "561eb618-e00f-40d5-bd99-fa51ab82051f", 733 | "id": "B1eoEO4UflUN" 734 | }, 735 | "execution_count": null, 736 | "outputs": [ 737 | { 738 | "output_type": "stream", 739 | "name": "stdout", 740 | "text": [ 741 | "weight shape = torch.Size([33, 31])\n" 742 | ] 743 | } 744 | ], 745 | "id": "B1eoEO4UflUN" 746 | }, 747 | { 748 | "cell_type": "code", 749 | "source": [ 750 | "compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\n", 751 | "batch_jacobian0 = compute_batch_jacobian(weight, bias, x)" 752 | ], 753 | "metadata": { 754 | "id": "nZ_V02NhflUN" 755 | }, 756 | "execution_count": null, 757 | "outputs": [], 758 | "id": "nZ_V02NhflUN" 759 | }, 760 | { 761 | "cell_type": "markdown", 762 | "source": [ 763 | "If you have a function that goes from (B, N) -> (B, M) instead and are certain that each input produces an independent output, then it’s also sometimes possible to do this without using vmap by summing the outputs and then computing the Jacobian of that function:\n", 764 | "\n" 765 | ], 766 | "metadata": { 767 | "id": "_OLDiY3MflUN" 768 | }, 769 | "id": "_OLDiY3MflUN" 770 | }, 771 | { 772 | "cell_type": "code", 773 | "source": [ 774 | "def predict_with_output_summed(weight, bias, x):\n", 775 | " return predict(weight, bias, x).sum(0)\n", 776 | "\n", 777 | "batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\n", 778 | "assert torch.allclose(batch_jacobian0, batch_jacobian1)" 779 | ], 780 | "metadata": { 781 | "id": "_QH4hD8PflUO" 782 | }, 783 | "execution_count": null, 784 | "outputs": [], 785 | "id": "_QH4hD8PflUO" 786 | }, 787 | { 788 | "cell_type": "markdown", 789 | "source": [ 790 | "If you instead have a function that goes from $𝑅^𝑁 \\to 𝑅^𝑀$ but inputs that are batched, you compose vmap with jacrev to compute batched jacobians:\n", 791 | "\n", 792 | "Finally, batch hessians can be computed similarly. 
It’s easiest to think about them by using vmap to batch over hessian computation, but in some cases the sum trick also works.\n", 793 | "\n" 794 | ], 795 | "metadata": { 796 | "id": "eUjw65cCflUO" 797 | }, 798 | "id": "eUjw65cCflUO" 799 | }, 800 | { 801 | "cell_type": "code", 802 | "source": [ 803 | "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n", 804 | "\n", 805 | "batch_hess = compute_batch_hessian(weight, bias, x)\n", 806 | "batch_hess.shape" 807 | ], 808 | "metadata": { 809 | "colab": { 810 | "base_uri": "https://localhost:8080/" 811 | }, 812 | "outputId": "f3135cfa-e9e5-4f18-8cb7-0655e8a37cb5", 813 | "id": "3vAyQjMsflUO" 814 | }, 815 | "execution_count": null, 816 | "outputs": [ 817 | { 818 | "output_type": "execute_result", 819 | "data": { 820 | "text/plain": [ 821 | "torch.Size([64, 33, 31, 31])" 822 | ] 823 | }, 824 | "metadata": {}, 825 | "execution_count": 22 826 | } 827 | ], 828 | "id": "3vAyQjMsflUO" 829 | }, 830 | { 831 | "cell_type": "markdown", 832 | "source": [ 833 | "## Computing Hessian-vector products\n", 834 | "\n", 835 | "The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot-product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:\n", 836 | "- composing reverse-mode AD with reverse-mode AD\n", 837 | "- composing reverse-mode AD with forward-mode AD\n", 838 | "\n", 839 | "Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory efficient way to compute a hvp because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for backward:" 840 | ], 841 | "metadata": { 842 | "id": "Wa8E48sQgpkb" 843 | }, 844 | "id": "Wa8E48sQgpkb" 845 | }, 846 | { 847 | "cell_type": "code", 848 | "source": [ 849 | "from functorch import jvp, grad, vjp\n", 850 | "\n", 851 | "def hvp(f, primals, tangents):\n", 852 | " return jvp(grad(f), primals, tangents)[1]" 853 | ], 854 | "metadata": { 855 | "id": "trw6WbAth6BM" 856 | }, 857 | "execution_count": null, 858 | "outputs": [], 859 | "id": "trw6WbAth6BM" 860 | }, 861 | { 862 | "cell_type": "markdown", 863 | "source": [ 864 | "Here's some sample usage." 
865 | ], 866 | "metadata": { 867 | "id": "DQMpRo6nitfr" 868 | }, 869 | "id": "DQMpRo6nitfr" 870 | }, 871 | { 872 | "cell_type": "code", 873 | "source": [ 874 | "def f(x):\n", 875 | " return x.sin().sum()\n", 876 | "\n", 877 | "x = torch.randn(2048)\n", 878 | "tangent = torch.randn(2048)\n", 879 | "\n", 880 | "result = hvp(f, (x,), (tangent,))" 881 | ], 882 | "metadata": { 883 | "id": "sPwg8SOdiVAK" 884 | }, 885 | "execution_count": null, 886 | "outputs": [], 887 | "id": "sPwg8SOdiVAK" 888 | }, 889 | { 890 | "cell_type": "markdown", 891 | "source": [ 892 | "If PyTorch forward-AD does not have coverage for your operations, then we can instead compose reverse-mode AD with reverse-mode AD:" 893 | ], 894 | "metadata": { 895 | "id": "zGvUIcB0j1Ez" 896 | }, 897 | "id": "zGvUIcB0j1Ez" 898 | }, 899 | { 900 | "cell_type": "code", 901 | "source": [ 902 | "def hvp_revrev(f, primals, tangents):\n", 903 | " _, vjp_fn = vjp(grad(f), *primals)\n", 904 | " return vjp_fn(*tangents)" 905 | ], 906 | "metadata": { 907 | "id": "mdDFZdlekAOK" 908 | }, 909 | "execution_count": null, 910 | "outputs": [], 911 | "id": "mdDFZdlekAOK" 912 | }, 913 | { 914 | "cell_type": "code", 915 | "source": [ 916 | "result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\n", 917 | "assert torch.allclose(result, result_hvp_revrev[0])" 918 | ], 919 | "metadata": { 920 | "id": "_CuCk9X0lW7C" 921 | }, 922 | "execution_count": null, 923 | "outputs": [], 924 | "id": "_CuCk9X0lW7C" 925 | } 926 | ], 927 | "metadata": { 928 | "kernelspec": { 929 | "display_name": "Python 3", 930 | "language": "python", 931 | "name": "python3" 932 | }, 933 | "language_info": { 934 | "codemirror_mode": { 935 | "name": "ipython", 936 | "version": 3 937 | }, 938 | "file_extension": ".py", 939 | "mimetype": "text/x-python", 940 | "name": "python", 941 | "nbconvert_exporter": "python", 942 | "pygments_lexer": "ipython3", 943 | "version": "3.8.3" 944 | }, 945 | "colab": { 946 | "name": "jacobians_hessians.ipynb", 947 | "provenance": [] 948 | } 949 | }, 950 | "nbformat": 4, 951 | "nbformat_minor": 5 952 | } 953 | -------------------------------------------------------------------------------- /notebooks/colab/per_sample_grads_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a474c143-05c4-43b6-b12c-17b592d07a6a", 6 | "metadata": { 7 | "id": "a474c143-05c4-43b6-b12c-17b592d07a6a" 8 | }, 9 | "source": [ 10 | "# Per-sample-gradients\n", 11 | "\n", 12 | "\n", 13 | " \"Open\n", 14 | "\n", 15 | "\n", 16 | "## What is it?\n", 17 | "\n", 18 | "Per-sample-gradient computation is computing the gradient for each and every\n", 19 | "sample in a batch of data. 
It is a useful quantity in differential privacy, meta-learning,\n", 20 | "and optimization research.\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "source": [ 26 | "import torch\n", 27 | "import torch.nn as nn\n", 28 | "import torch.nn.functional as F\n", 29 | "from functools import partial\n", 30 | "\n", 31 | "torch.manual_seed(0);" 32 | ], 33 | "metadata": { 34 | "id": "Gb-yt4VKUUuc" 35 | }, 36 | "execution_count": null, 37 | "outputs": [], 38 | "id": "Gb-yt4VKUUuc" 39 | }, 40 | { 41 | "cell_type": "code", 42 | "source": [ 43 | "# Here's a simple CNN and loss function:\n", 44 | "\n", 45 | "class SimpleCNN(nn.Module):\n", 46 | " def __init__(self):\n", 47 | " super(SimpleCNN, self).__init__()\n", 48 | " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", 49 | " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", 50 | " self.fc1 = nn.Linear(9216, 128)\n", 51 | " self.fc2 = nn.Linear(128, 10)\n", 52 | "\n", 53 | " def forward(self, x):\n", 54 | " x = self.conv1(x)\n", 55 | " x = F.relu(x)\n", 56 | " x = self.conv2(x)\n", 57 | " x = F.relu(x)\n", 58 | " x = F.max_pool2d(x, 2)\n", 59 | " x = torch.flatten(x, 1)\n", 60 | " x = self.fc1(x)\n", 61 | " x = F.relu(x)\n", 62 | " x = self.fc2(x)\n", 63 | " output = F.log_softmax(x, dim=1)\n", 64 | " output = x\n", 65 | " return output\n", 66 | "\n", 67 | "def loss_fn(predictions, targets):\n", 68 | " return F.nll_loss(predictions, targets)" 69 | ], 70 | "metadata": { 71 | "id": "tf-HKHjUUbyY" 72 | }, 73 | "execution_count": null, 74 | "outputs": [], 75 | "id": "tf-HKHjUUbyY" 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "source": [ 80 | "Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. \n", 81 | "\n", 82 | "The dummy images are 28 by 28 and we use a minibatch of size 64.\n", 83 | "\n" 84 | ], 85 | "metadata": { 86 | "id": "VEDPe-EoU5Fa" 87 | }, 88 | "id": "VEDPe-EoU5Fa" 89 | }, 90 | { 91 | "cell_type": "code", 92 | "source": [ 93 | "device = 'cuda'\n", 94 | "\n", 95 | "num_models = 10\n", 96 | "batch_size = 64\n", 97 | "data = torch.randn(batch_size, 1, 28, 28, device=device)\n", 98 | "\n", 99 | "targets = torch.randint(10, (64,), device=device)" 100 | ], 101 | "metadata": { 102 | "id": "WB2Qe3AHUvPN" 103 | }, 104 | "execution_count": null, 105 | "outputs": [], 106 | "id": "WB2Qe3AHUvPN" 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "source": [ 111 | "In regular model training, one would forward the minibatch through the model, and then call .backward() to compute gradients. 
This would generate an 'average' gradient of the entire mini-batch:\n", 112 | "\n" 113 | ], 114 | "metadata": { 115 | "id": "GOGJ-OUxVcT5" 116 | }, 117 | "id": "GOGJ-OUxVcT5" 118 | }, 119 | { 120 | "cell_type": "code", 121 | "source": [ 122 | "model = SimpleCNN().to(device=device)\n", 123 | "predictions = model(data) # move the entire mini-batch through the model\n", 124 | "\n", 125 | "loss = loss_fn(predictions, targets)\n", 126 | "loss.backward() # back propogate the 'average' gradient of this mini-batch" 127 | ], 128 | "metadata": { 129 | "id": "WYjMx8QTUvRu" 130 | }, 131 | "execution_count": null, 132 | "outputs": [], 133 | "id": "WYjMx8QTUvRu" 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "source": [ 138 | "In contrast to the above approach, per-sample-gradient computation is equivalent to: \n", 139 | "- for each individual sample of the data, perform a forward and a backward pass to get an individual (per-sample) gradient.\n", 140 | "\n" 141 | ], 142 | "metadata": { 143 | "id": "HNw4_IVzU5Pz" 144 | }, 145 | "id": "HNw4_IVzU5Pz" 146 | }, 147 | { 148 | "cell_type": "code", 149 | "source": [ 150 | "def compute_grad(sample, target):\n", 151 | " \n", 152 | " sample = sample.unsqueeze(0) # prepend batch dimension for processing\n", 153 | " target = target.unsqueeze(0)\n", 154 | "\n", 155 | " prediction = model(sample)\n", 156 | " loss = loss_fn(prediction, target)\n", 157 | "\n", 158 | " return torch.autograd.grad(loss, list(model.parameters()))\n", 159 | "\n", 160 | "\n", 161 | "def compute_sample_grads(data, targets):\n", 162 | " \"\"\" manually process each sample with per sample gradient \"\"\"\n", 163 | " sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]\n", 164 | " sample_grads = zip(*sample_grads)\n", 165 | " sample_grads = [torch.stack(shards) for shards in sample_grads]\n", 166 | " return sample_grads\n", 167 | "\n", 168 | "per_sample_grads = compute_sample_grads(data, targets)" 169 | ], 170 | "metadata": { 171 | "id": "vUsb3VfexJrY" 172 | }, 173 | "execution_count": null, 174 | "outputs": [], 175 | "id": "vUsb3VfexJrY" 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "source": [ 180 | "`sample_grads[0]` is the per-sample-grad for model.conv1.weight. `model.conv1.weight.shape` is `[32, 1, 3, 3]`; notice how there is one gradient, per sample, in the batch for a total of 64.\n", 181 | "\n", 182 | "\n", 183 | "\n" 184 | ], 185 | "metadata": { 186 | "id": "aNkX6lFIxzcm" 187 | }, 188 | "id": "aNkX6lFIxzcm" 189 | }, 190 | { 191 | "cell_type": "code", 192 | "source": [ 193 | "print(per_sample_grads[0].shape)" 194 | ], 195 | "metadata": { 196 | "id": "C3a9_clvyPho", 197 | "colab": { 198 | "base_uri": "https://localhost:8080/" 199 | }, 200 | "outputId": "407abc1a-846f-4e50-83bc-c90719a26073" 201 | }, 202 | "execution_count": null, 203 | "outputs": [ 204 | { 205 | "output_type": "stream", 206 | "name": "stdout", 207 | "text": [ 208 | "torch.Size([64, 32, 1, 3, 3])\n" 209 | ] 210 | } 211 | ], 212 | "id": "C3a9_clvyPho" 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "source": [ 217 | "## Per-sample-grads, *the efficient way*, using functorch\n", 218 | "\n", 219 | "\n" 220 | ], 221 | "metadata": { 222 | "id": "mFJDWMM9yaYZ" 223 | }, 224 | "id": "mFJDWMM9yaYZ" 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "source": [ 229 | "We can compute per-sample-gradients efficiently by using function transforms. \n", 230 | "\n", 231 | "First, let’s create a stateless functional version of `model` by using `functorch.make_functional_with_buffers`. 
\n", 232 | "\n", 233 | "This will separate state (the parameters) from the model and turn the model into a pure function:\n", 234 | "\n" 235 | ], 236 | "metadata": { 237 | "id": "tlkmyQyfY6XU" 238 | }, 239 | "id": "tlkmyQyfY6XU" 240 | }, 241 | { 242 | "cell_type": "code", 243 | "source": [ 244 | "from functorch import make_functional_with_buffers, vmap, grad\n", 245 | "\n", 246 | "fmodel, params, buffers = make_functional_with_buffers(model)" 247 | ], 248 | "metadata": { 249 | "id": "WiSMupvCyecd" 250 | }, 251 | "execution_count": null, 252 | "outputs": [], 253 | "id": "WiSMupvCyecd" 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "source": [ 258 | "Let's review the changes - first, the model has become the stateless FunctionalModuleWithBuffers:" 259 | ], 260 | "metadata": { 261 | "id": "wMsbppPNZklo" 262 | }, 263 | "id": "wMsbppPNZklo" 264 | }, 265 | { 266 | "cell_type": "code", 267 | "source": [ 268 | "fmodel" 269 | ], 270 | "metadata": { 271 | "colab": { 272 | "base_uri": "https://localhost:8080/" 273 | }, 274 | "id": "Xj0cZOJMZbbB", 275 | "outputId": "2e87dfde-3af2-4e1f-cd91-5c232446fb53" 276 | }, 277 | "execution_count": null, 278 | "outputs": [ 279 | { 280 | "output_type": "execute_result", 281 | "data": { 282 | "text/plain": [ 283 | "FunctionalModuleWithBuffers(\n", 284 | " (stateless_model): SimpleCNN(\n", 285 | " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", 286 | " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", 287 | " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n", 288 | " (fc2): Linear(in_features=128, out_features=10, bias=True)\n", 289 | " )\n", 290 | ")" 291 | ] 292 | }, 293 | "metadata": {}, 294 | "execution_count": 15 295 | } 296 | ], 297 | "id": "Xj0cZOJMZbbB" 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "source": [ 302 | "And the model parameters now exist independently of the model, stored as a tuple:" 303 | ], 304 | "metadata": { 305 | "id": "zv4_YYPxZvvg" 306 | }, 307 | "id": "zv4_YYPxZvvg" 308 | }, 309 | { 310 | "cell_type": "code", 311 | "source": [ 312 | "for x in params:\n", 313 | " print(f\"{x.shape}\")\n", 314 | "\n", 315 | "print(f\"\\n{type(params)}\")" 316 | ], 317 | "metadata": { 318 | "colab": { 319 | "base_uri": "https://localhost:8080/" 320 | }, 321 | "id": "tH0TAZhBZ3bS", 322 | "outputId": "97c4401f-cccb-43f6-b071-c85a18fc439b" 323 | }, 324 | "execution_count": null, 325 | "outputs": [ 326 | { 327 | "output_type": "stream", 328 | "name": "stdout", 329 | "text": [ 330 | "torch.Size([32, 1, 3, 3])\n", 331 | "torch.Size([32])\n", 332 | "torch.Size([64, 32, 3, 3])\n", 333 | "torch.Size([64])\n", 334 | "torch.Size([128, 9216])\n", 335 | "torch.Size([128])\n", 336 | "torch.Size([10, 128])\n", 337 | "torch.Size([10])\n", 338 | "\n", 339 | "\n" 340 | ] 341 | } 342 | ], 343 | "id": "tH0TAZhBZ3bS" 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "source": [ 348 | "Next, let’s define a function to compute the loss of the model given a single input rather than a batch of inputs. It is important that this function accepts the parameters, the input, and the target, because we will be transforming over them. 
\n", 349 | "\n", 350 | "Note - because the model was originally written to handle batches, we’ll use `torch.unsqueeze` to add a batch dimension.\n", 351 | "\n" 352 | ], 353 | "metadata": { 354 | "id": "cTgIIZ9Wyih8" 355 | }, 356 | "id": "cTgIIZ9Wyih8" 357 | }, 358 | { 359 | "cell_type": "code", 360 | "source": [ 361 | "def compute_loss_stateless_model (params, buffers, sample, target):\n", 362 | " batch = sample.unsqueeze(0)\n", 363 | " targets = target.unsqueeze(0)\n", 364 | "\n", 365 | " predictions = fmodel(params, buffers, batch) \n", 366 | " loss = loss_fn(predictions, targets)\n", 367 | " return loss" 368 | ], 369 | "metadata": { 370 | "id": "ItURFU3M-p98" 371 | }, 372 | "execution_count": null, 373 | "outputs": [], 374 | "id": "ItURFU3M-p98" 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "source": [ 379 | "Now, let’s use functorch's `grad` to create a new function that computes the gradient with respect to the first argument of `compute_loss` (i.e. the params)." 380 | ], 381 | "metadata": { 382 | "id": "Qo3sbDK2i_bH" 383 | }, 384 | "id": "Qo3sbDK2i_bH" 385 | }, 386 | { 387 | "cell_type": "code", 388 | "source": [ 389 | "ft_compute_grad = grad(compute_loss_stateless_model)" 390 | ], 391 | "metadata": { 392 | "id": "sqRp_Sxni-Xm" 393 | }, 394 | "execution_count": null, 395 | "outputs": [], 396 | "id": "sqRp_Sxni-Xm" 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "source": [ 401 | "The `ft_compute_grad` function computes the gradient for a single (sample, target) pair. We can use vmap to get it to compute the gradient over an entire batch of samples and targets. Note that `in_dims=(None, None, 0, 0)` because we wish to map `ft_compute_grad` over the 0th dimension of the data and targets, and use the same params and buffers for each.\n", 402 | "\n" 403 | ], 404 | "metadata": { 405 | "id": "2pG3Ofqjjc8O" 406 | }, 407 | "id": "2pG3Ofqjjc8O" 408 | }, 409 | { 410 | "cell_type": "code", 411 | "source": [ 412 | "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))" 413 | ], 414 | "metadata": { 415 | "id": "62ecNMO6inqX" 416 | }, 417 | "execution_count": null, 418 | "outputs": [], 419 | "id": "62ecNMO6inqX" 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "source": [ 424 | "Finally, let’s used our transformed function to compute per-sample-gradients:\n", 425 | "\n" 426 | ], 427 | "metadata": { 428 | "id": "_alXdQ3QkETu" 429 | }, 430 | "id": "_alXdQ3QkETu" 431 | }, 432 | { 433 | "cell_type": "code", 434 | "source": [ 435 | "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)\n", 436 | "\n", 437 | "# we can double check that the results using functorch grad and vmap match the results of hand processing each one individually:\n", 438 | "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads):\n", 439 | " assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)" 440 | ], 441 | "metadata": { 442 | "id": "1gehVA1c-BHd" 443 | }, 444 | "execution_count": null, 445 | "outputs": [], 446 | "id": "1gehVA1c-BHd" 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "source": [ 451 | "A quick note: there are limitations around what types of functions can be transformed by vmap. The best functions to transform are ones that are pure functions: a function where the outputs are only determined by the inputs, and that have no side effects (e.g. mutation). 
vmap is unable to handle mutation of arbitrary Python data structures, but it is able to handle many in-place PyTorch operations.\n", 452 | "\n", 453 | "\n", 454 | "\n" 455 | ], 456 | "metadata": { 457 | "id": "BEZaNt1d_bc1" 458 | }, 459 | "id": "BEZaNt1d_bc1" 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "source": [ 464 | "## Performance comparison" 465 | ], 466 | "metadata": { 467 | "id": "BASP151Iml7B" 468 | }, 469 | "id": "BASP151Iml7B" 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "source": [ 474 | "Curious about how the performance of vmap compares?\n", 475 | "\n", 476 | "Currently the best results are obtained on newer GPU's such as the A100 (Ampere) where we've seen up to 25x speedups on this example, but here are some results done in Colab:" 477 | ], 478 | "metadata": { 479 | "id": "jr1xNpV4nJ7u" 480 | }, 481 | "id": "jr1xNpV4nJ7u" 482 | }, 483 | { 484 | "cell_type": "code", 485 | "source": [ 486 | "def get_perf(first, first_descriptor, second, second_descriptor):\n", 487 | " \"\"\" takes torch.benchmark objects and compares delta of second vs first. \"\"\"\n", 488 | " second_res = second.times[0]\n", 489 | " first_res = first.times[0]\n", 490 | "\n", 491 | " gain = (first_res-second_res)/first_res\n", 492 | " if gain < 0: gain *=-1 \n", 493 | " final_gain = gain*100\n", 494 | "\n", 495 | " print(f\" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} \")" 496 | ], 497 | "metadata": { 498 | "id": "GnAnMkYmoc-j" 499 | }, 500 | "execution_count": null, 501 | "outputs": [], 502 | "id": "GnAnMkYmoc-j" 503 | }, 504 | { 505 | "cell_type": "code", 506 | "source": [ 507 | "from torch.utils.benchmark import Timer\n", 508 | "\n", 509 | "without_vmap = Timer( stmt=\"compute_sample_grads(data, targets)\", globals=globals())\n", 510 | "with_vmap = Timer(stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\",globals=globals())\n", 511 | "no_vmap_timing = without_vmap.timeit(100)\n", 512 | "with_vmap_timing = with_vmap.timeit(100)\n", 513 | "\n", 514 | "print(f'Per-sample-grads without vmap {no_vmap_timing}')\n", 515 | "print(f'Per-sample-grads with vmap {with_vmap_timing}')" 516 | ], 517 | "metadata": { 518 | "id": "Zfnn2C2g-6Fb", 519 | "colab": { 520 | "base_uri": "https://localhost:8080/" 521 | }, 522 | "outputId": "922f3901-773f-446b-b562-88e78f49036c" 523 | }, 524 | "execution_count": null, 525 | "outputs": [ 526 | { 527 | "output_type": "stream", 528 | "name": "stdout", 529 | "text": [ 530 | "Per-sample-grads without vmap \n", 531 | "compute_sample_grads(data, targets)\n", 532 | " 79.86 ms\n", 533 | " 1 measurement, 100 runs , 1 thread\n", 534 | "Per-sample-grads with vmap \n", 535 | "ft_compute_sample_grad(params, buffers, data, targets)\n", 536 | " 12.93 ms\n", 537 | " 1 measurement, 100 runs , 1 thread\n" 538 | ] 539 | } 540 | ], 541 | "id": "Zfnn2C2g-6Fb" 542 | }, 543 | { 544 | "cell_type": "code", 545 | "source": [ 546 | "get_perf(with_vmap_timing, \"vmap\", no_vmap_timing,\"no vmap\" )" 547 | ], 548 | "metadata": { 549 | "colab": { 550 | "base_uri": "https://localhost:8080/" 551 | }, 552 | "id": "NV9R3LZQoavl", 553 | "outputId": "e11e8be9-287d-4e60-e517-e08f8d6909bd" 554 | }, 555 | "execution_count": null, 556 | "outputs": [ 557 | { 558 | "output_type": "stream", 559 | "name": "stdout", 560 | "text": [ 561 | " Performance delta: 517.5791 percent improvement with vmap \n" 562 | ] 563 | } 564 | ], 565 | "id": "NV9R3LZQoavl" 566 | }, 567 | { 568 | "cell_type": "markdown", 569 | "source": [ 570 | "There are other optimized solutions 
(like in https://github.com/pytorch/opacus) to computing per-sample-gradients in PyTorch that also perform better than the naive method. But it’s cool that composing `vmap` and `grad` give us a nice speedup.\n", 571 | "\n", 572 | "\n", 573 | "In general, vectorization with vmap should be faster than running a function in a for-loop and competitive with manual batching. There are some exceptions though, like if we haven’t implemented the vmap rule for a particular operation or if the underlying kernels weren’t optimized for older hardware (GPUs). If you see any of these cases, please let us know by opening an issue at our [GitHub](https://github.com/pytorch/functorch)!\n", 574 | "\n" 575 | ], 576 | "metadata": { 577 | "id": "UI74G9JarQU8" 578 | }, 579 | "id": "UI74G9JarQU8" 580 | } 581 | ], 582 | "metadata": { 583 | "kernelspec": { 584 | "display_name": "Python 3", 585 | "language": "python", 586 | "name": "python3" 587 | }, 588 | "language_info": { 589 | "codemirror_mode": { 590 | "name": "ipython", 591 | "version": 3 592 | }, 593 | "file_extension": ".py", 594 | "mimetype": "text/x-python", 595 | "name": "python", 596 | "nbconvert_exporter": "python", 597 | "pygments_lexer": "ipython3", 598 | "version": "3.8.5" 599 | }, 600 | "colab": { 601 | "name": "per_sample_grads.ipynb", 602 | "provenance": [] 603 | } 604 | }, 605 | "nbformat": 4, 606 | "nbformat_minor": 5 607 | } 608 | -------------------------------------------------------------------------------- /packaging/windows/internal/cuda_install.bat: -------------------------------------------------------------------------------- 1 | @echo on 2 | 3 | if "%CU_VERSION%" == "cpu" ( 4 | echo Skipping for CPU builds 5 | exit /b 0 6 | ) 7 | 8 | set SRC_DIR=%~dp0\.. 9 | 10 | if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" 11 | 12 | rem in unit test workflow, we get CUDA_VERSION, for example 11.1 13 | if defined CUDA_VERSION ( 14 | set CUDA_VER=%CUDA_VERSION:.=% 15 | ) else ( 16 | set CUDA_VER=%CU_VERSION:cu=% 17 | ) 18 | 19 | set /a CUDA_VER=%CU_VERSION:cu=% 20 | set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% 21 | set CUDA_VER_MINOR=%CUDA_VER:~-1,1% 22 | set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% 23 | 24 | 25 | if %CUDA_VER% EQU 92 goto cuda92 26 | if %CUDA_VER% EQU 100 goto cuda100 27 | if %CUDA_VER% EQU 101 goto cuda101 28 | if %CUDA_VER% EQU 102 goto cuda102 29 | if %CUDA_VER% EQU 110 goto cuda110 30 | if %CUDA_VER% EQU 111 goto cuda111 31 | if %CUDA_VER% EQU 112 goto cuda112 32 | if %CUDA_VER% EQU 113 goto cuda113 33 | if %CUDA_VER% EQU 115 goto cuda115 34 | 35 | 36 | echo CUDA %CUDA_VERSION_STR% is not supported 37 | exit /b 1 38 | 39 | :cuda92 40 | if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( 41 | curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" 42 | if errorlevel 1 exit /b 1 43 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" 44 | set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" 45 | ) 46 | 47 | if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( 48 | curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" 49 | if errorlevel 1 exit /b 1 
50 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" 51 | ) 52 | 53 | goto cuda_common 54 | 55 | :cuda100 56 | 57 | if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( 58 | curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" 59 | if errorlevel 1 exit /b 1 60 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" 61 | set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" 62 | ) 63 | 64 | if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( 65 | curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" 66 | if errorlevel 1 exit /b 1 67 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" 68 | ) 69 | 70 | goto cuda_common 71 | 72 | :cuda101 73 | 74 | if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( 75 | curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" 76 | if errorlevel 1 exit /b 1 77 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" 78 | set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" 79 | ) 80 | 81 | if not exist "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( 82 | curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" 83 | if errorlevel 1 exit /b 1 84 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" 85 | ) 86 | 87 | goto cuda_common 88 | 89 | :cuda102 90 | 91 | if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( 92 | curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" 93 | if errorlevel 1 exit /b 1 94 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" 95 | set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" 96 | ) 97 | 98 | if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( 99 | curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" 100 | if errorlevel 1 exit /b 1 101 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" 102 | ) 103 | 104 | rem The below only for cu102, if it's used in other version, e.g. cu111, torch.cuda.is_availabe() would be False. 
105 | if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( 106 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" 107 | if errorlevel 1 exit /b 1 108 | ) 109 | 110 | echo Installing GPU driver DLLs 111 | 7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" 112 | 113 | goto cuda_common 114 | 115 | :cuda110 116 | 117 | if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( 118 | curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" 119 | if errorlevel 1 exit /b 1 120 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" 121 | set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" 122 | ) 123 | 124 | if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( 125 | curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" 126 | if errorlevel 1 exit /b 1 127 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" 128 | ) 129 | 130 | goto cuda_common 131 | 132 | :cuda111 133 | 134 | if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" ( 135 | curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" 136 | if errorlevel 1 exit /b 1 137 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" 138 | set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" 139 | ) 140 | 141 | if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( 142 | curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" 143 | if errorlevel 1 exit /b 1 144 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" 145 | ) 146 | 147 | goto cuda_common 148 | 149 | :cuda112 150 | 151 | if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( 152 | curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" 153 | if errorlevel 1 exit /b 1 154 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" 155 | set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" 156 | ) 157 | 158 | if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( 159 | curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" 160 | if 
errorlevel 1 exit /b 1 161 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" 162 | ) 163 | 164 | goto cuda_common 165 | 166 | :cuda113 167 | 168 | set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe 169 | if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( 170 | curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" 171 | if errorlevel 1 exit /b 1 172 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" 173 | set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" 174 | 175 | ) 176 | 177 | set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip 178 | if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( 179 | curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" 180 | if errorlevel 1 exit /b 1 181 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" 182 | ) 183 | 184 | goto cuda_common 185 | 186 | :cuda115 187 | 188 | set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe 189 | if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( 190 | curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" 191 | if errorlevel 1 exit /b 1 192 | set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" 193 | set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5" 194 | ) 195 | 196 | set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip 197 | if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( 198 | curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" 199 | if errorlevel 1 exit /b 1 200 | set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" 201 | ) 202 | 203 | goto cuda_common 204 | 205 | :cuda_common 206 | 207 | if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( 208 | curl -k -L https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" 209 | if errorlevel 1 exit /b 1 210 | ) 211 | 212 | echo Installing CUDA toolkit... 213 | 7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" 214 | pushd "%SRC_DIR%\temp_build\cuda" 215 | sc config wuauserv start= disabled 216 | sc stop wuauserv 217 | sc query wuauserv 218 | 219 | start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs" 220 | echo %errorlevel% 221 | 222 | popd 223 | 224 | echo Installing VS integration... 
225 | rem It's for VS 2019 226 | if "%CUDA_VER_MAJOR%" == "10" ( 227 | xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" 228 | ) 229 | if "%CUDA_VER_MAJOR%" == "11" ( 230 | xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" 231 | ) 232 | 233 | echo Installing NvToolsExt... 234 | 7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" 235 | mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" 236 | mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" 237 | mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" 238 | xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" 239 | xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" 240 | xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" 241 | 242 | echo Setting up environment... 243 | set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" 244 | set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" 245 | set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" 246 | set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" 247 | 248 | if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( 249 | echo CUDA %CUDA_VERSION_STR% installed failed. 250 | echo --------- RunDll32.exe.log 251 | type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log" 252 | echo --------- setup.exe.log ------- 253 | type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log" 254 | exit /b 1 255 | ) 256 | 257 | echo Installing cuDNN... 
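rem The cuDNN zip unpacks into a cuda\ folder; the xcopy calls below merge its bin, lib\x64 and include contents into the CUDA toolkit install so the build can locate the cuDNN headers and DLLs.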
258 | 7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" 259 | xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" 260 | xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" 261 | xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" 262 | 263 | echo Cleaning temp files 264 | rd /s /q "%SRC_DIR%\temp_build" || ver > nul 265 | -------------------------------------------------------------------------------- /packaging/windows/internal/driver_update.bat: -------------------------------------------------------------------------------- 1 | set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" 2 | curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe 3 | if errorlevel 1 exit /b 1 4 | 5 | start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot 6 | if errorlevel 1 exit /b 1 7 | 8 | del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL 9 | 10 | setlocal EnableDelayedExpansion 11 | set NVIDIA_GPU_EXISTS=0 12 | for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( 13 | set GPUS=%%i 14 | if not "x!GPUS:NVIDIA=!" == "x!GPUS!" ( 15 | SET NVIDIA_GPU_EXISTS=1 16 | goto gpu_check_end 17 | ) 18 | ) 19 | :gpu_check_end 20 | endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% 21 | 22 | if "%NVIDIA_GPU_EXISTS%" == "0" ( 23 | echo "CUDA Driver installation Failed" 24 | exit /b 1 25 | ) 26 | -------------------------------------------------------------------------------- /pull_request_template.md: -------------------------------------------------------------------------------- 1 | To contribute a change to functorch, please make sure you are submitting a 2 | Pull Request to the functorch folder in https://github.com/pytorch/pytorch 3 | repository. The source of truth for functorch has moved there from 4 | https://github.com/pytorch/functorch ; the pytorch/functorch repository 5 | is now read-only. 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [metadata] 5 | license_file = LICENSE 6 | 7 | [pep8] 8 | max-line-length = 120 9 | 10 | [flake8] 11 | max-line-length = 120 12 | exclude = docs, benchmarks, notebooks, tools 13 | per-file-ignores = 14 | __init__.py: F401 15 | functorch/_src/decompositions.py: E501 16 | 17 | [pydocstyle] 18 | select = D417 # Missing argument descriptions in the docstring 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import sys 7 | import os 8 | import subprocess 9 | from setuptools import setup 10 | 11 | cwd = os.path.dirname(os.path.abspath(__file__)) 12 | version_txt = os.path.join(cwd, 'version.txt') 13 | with open(version_txt, 'r') as f: 14 | version = f.readline().strip() 15 | 16 | try: 17 | sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() 18 | except Exception: 19 | sha = 'Unknown' 20 | package_name = 'functorch' 21 | 22 | if os.getenv('BUILD_VERSION'): 23 | version = os.getenv('BUILD_VERSION') 24 | elif sha != 'Unknown': 25 | version += '+' + sha[:7] 26 | 27 | 28 | requirements = [ 29 | # This represents a nightly version of PyTorch. 30 | # It can be installed as a binary or from source. 31 | "torch>=1.14.0.dev", 32 | ] 33 | 34 | extras = {} 35 | extras["aot"] = ["networkx", ] 36 | 37 | 38 | if __name__ == '__main__': 39 | try: 40 | setup( 41 | # Metadata 42 | name=package_name, 43 | version=version, 44 | author='PyTorch Core Team', 45 | url="https://github.com/pytorch/functorch", 46 | description='JAX-like composable function transforms for PyTorch', 47 | license='BSD', 48 | 49 | # Package info 50 | packages=[], 51 | install_requires=requirements, 52 | extras_require=extras, 53 | ) 54 | except Exception as e: 55 | print(e, file=sys.stderr) 56 | sys.exit(1) 57 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 1.14.0a0 2 | --------------------------------------------------------------------------------
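A minimal, self-contained sketch of the version-derivation logic that setup.py above applies (read version.txt, let a BUILD_VERSION environment variable override it, otherwise append the short git SHA). The helper name derive_version is an assumption for illustration and does not exist in the repository:

# Illustrative sketch only -- mirrors the version logic in setup.py above.
# "derive_version" is a hypothetical name, not part of the functorch repo.
import os
import subprocess

def derive_version(repo_root):
    with open(os.path.join(repo_root, 'version.txt'), 'r') as f:
        version = f.readline().strip()        # e.g. "1.14.0a0"
    if os.getenv('BUILD_VERSION'):
        return os.getenv('BUILD_VERSION')     # explicit override wins (used by CI)
    try:
        sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_root).decode('ascii').strip()
        return version + '+' + sha[:7]        # e.g. "1.14.0a0+abc1234"
    except Exception:
        return version                        # not a git checkout

On a git checkout of this repository the sketch would yield something like "1.14.0a0" plus "+" and the short commit hash, which is the version string the wheel build embeds.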