├── .gitignore ├── RFC-0038-assets └── cpython_support.png ├── RFC-0030-assets └── fp8_binary_formats.png ├── CONTRIBUTING.md ├── LICENSE ├── CODE_OF_CONDUCT.md ├── README.md ├── RFC-0000-template.md ├── RFC-0038-cpython-support.md ├── RFC-0025-sdpa-optm-cpu.md ├── RFC-0024-rfc-process.md ├── RFC-0030-native-fp8-dtype.md ├── RFC-0011-InferenceMode.md ├── RFC-0017-PyTorch-Operator-Versioning.md ├── RFC-0012-profile-directed-typing.md ├── RFC-0032-numpy-support-in-dynamo.md ├── RFC-0006-conda-distribution.md ├── RFC-0020-Lightweight-Dispatch.md ├── RFC-0024-assets └── rfc-lifecycle.svg ├── RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md └── RFC-0001-torch-function-for-methods.md /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_STORE -------------------------------------------------------------------------------- /RFC-0038-assets/cpython_support.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rfcs/HEAD/RFC-0038-assets/cpython_support.png -------------------------------------------------------------------------------- /RFC-0030-assets/fp8_binary_formats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/rfcs/HEAD/RFC-0030-assets/fp8_binary_formats.png -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to PyTorch RFCs 2 | 3 | This repository is meant for Request For Comments (RFCs) - design proposals 4 | for topics that are too large to discuss on a standard feature request issue 5 | on the main PyTorch repository. 6 | 7 | If you are unsure whether something should be an RFC or a feature request 8 | issue, please open an issue in the main PyTorch repository first to discuss. 9 | 10 | ## License 11 | By contributing to rfcs, you agree that your contributions will be licensed 12 | under the LICENSE file in the root directory of this source tree. 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | All contributions by Facebook: 2 | Copyright (c) 2020 Facebook Inc. 3 | 4 | All other contributions: 5 | Copyright(c) 2020 the respective contributors 6 | 7 | All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | 1. Redistributions of source code must retain the above copyright 13 | notice, this list of conditions and the following disclaimer. 14 | 15 | 2. Redistributions in binary form must reproduce the above copyright 16 | notice, this list of conditions and the following disclaimer in the 17 | documentation and/or other materials provided with the distribution. 18 | 19 | 3. Neither the names of Facebook, nor the names of any contributors may be 20 | used to endorse or promote products derived from this software without 21 | specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 26 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 28 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 29 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 30 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 31 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 32 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 33 | POSSIBILITY OF SUCH DAMAGE. 34 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . 
All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This page contains instructions on how to propose and implement feature changes to PyTorch. 2 | 3 | # Proposing a Feature to PyTorch 4 | 5 | To propose a new feature, you’ll submit a Request For Comments (RFC). This RFC is basically a design proposal where you can share a detailed description of what change you want to make, why it’s needed, and how you propose to implement it. 6 | 7 | It’s easier to make changes while your feature is in the ideation phase than in the PR phase, and this doc gives core maintainers an opportunity to suggest refinements before you start coding. For example, they may know of other planned efforts that your work would otherwise collide with, or they may suggest implementation changes that make your feature more broadly usable. 8 | 9 | Smaller changes, including bug fixes and documentation improvements, can be 10 | implemented and reviewed via the normal GitHub pull request workflow on the main PyTorch repo. 11 | 12 | RFCs are more suitable for design proposals that are too large to discuss on a feature-request issue, 13 | like adding a new abstraction, or if a discussion about the tradeoffs involved in a new addition is non-trivial. 14 | 15 | If you are unsure whether something should be an RFC or a feature-request issue, you can ask by 16 | opening an issue in the main PyTorch/PyTorch repository. 17 | 18 | # The Request for Comments 19 | 20 | ## Step 1: Create an RFC 21 | RFCs are located in their own repository. 22 | 23 | To create one: 24 | 25 | 1. Fork the https://github.com/pytorch/rfcs repository 26 | 2. Copy the template file `RFC-0000-template.md` to `RFC-00xx-your-feature.md` and fill it out with your proposal. The template is a guideline; feel free to add sections as appropriate 27 | 3. You may also have the template simply link to another editor, like a Google Docs file, but please ensure that the document is publicly visible. This can make the proposal easier to edit, but commenting doesn’t scale very well, so please use this option with caution. 28 | 29 | ## Step 2: Get Feedback on the RFC 30 | 1. Submit a pull request titled `RFC-00xx-your-feature.md` 31 | 2. Before your PR is ready for review, give it the `draft` label. 32 | 3. Once it’s ready for review, remove the `draft` label and give it the `commenting` label 33 | 4. File an issue against the https://github.com/pytorch/pytorch repository to review your proposal. 34 | 5. 
In the description, include a short summary of your feature and a link to your RFC PR 35 | 6. PyTorch Triage review will route your issue to core contributors with the appropriate expertise. 36 | 7. Build consensus. Those core contributors will review your PR and offer feedback. Revise your proposal as needed until everyone agrees on a path forward. Additional forums you can share the proposal on include the [developer forum](https://dev-discuss.pytorch.org/c/rfc-chatter), and the [Slack channel](https://bit.ly/ptslack). Tagging interested stakeholders (identifiable via [CODEOWNERS](https://github.com/pytorch/pytorch/blob/master/CODEOWNERS)) can help with consensus building. 37 | 38 | _(Note: A proposal may get rejected if it comes with unresolvable drawbacks or if it’s against the long-term plans of the PyTorch maintainers.)_ 39 | 40 | ## Step 3: Implement your Feature 41 | 1. If your RFC PR is accepted, you can merge it into the [pytorch/rfcs](https://github.com/pytorch/rfcs) repository and begin working on the implementation. 42 | 2. When you submit PRs to implement your proposal, remember to link to your RFC to help reviewers catch up on the context. 43 | 44 | 45 | 46 | ## Implementing an RFC 47 | Every accepted RFC has an associated issue tracking its implementation in the PyTorch repository; thus that 48 | associated issue can be assigned a priority via the triage process that the team uses for all issues. 49 | 50 | The author of an RFC is not obligated to implement it. Of course, the RFC 51 | author (like any other developer) is welcome to post an implementation for 52 | review after the RFC has been accepted. 53 | 54 | If you are interested in working on the implementation for an accepted RFC, but 55 | cannot determine if someone else is already working on it, feel free to ask 56 | (e.g. by leaving a comment on the associated issue). 57 | 58 | 59 | ## RFC Rejection 60 | Some RFC pull requests are tagged with the "shelved" label when they are 61 | closed (as part of the rejection process). An RFC closed with "shelved" is 62 | marked as such because we want neither to think about evaluating the proposal 63 | nor about implementing the described feature until some time in the future, and 64 | we believe that we can afford to wait until then to do so. 65 | 66 | ## Inspiration 67 | PyTorch's RFC process owes inspiration to the [Rust RFC Process](https://github.com/rust-lang/rfcs) and [React RFC Process](https://github.com/reactjs/rfcs/), and the [Artsy RFC process](https://github.com/artsy/README/blob/main/playbooks/rfcs.md#resolution) for the resolution template. 68 | 69 | ## License 70 | By contributing to rfcs, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. 71 | -------------------------------------------------------------------------------- /RFC-0000-template.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | Instructions - click to expand 5 | 6 | - Fork the rfcs repo: https://github.com/pytorch/rfcs 7 | - Copy `RFC-0000-template.md` to `RFC-00xx-my-feature.md`, or write your own open-ended proposal. Put care into the details. 8 | - Submit a pull request titled `RFC-00xx-my-feature`. 9 | - Assign the `draft` label while composing the RFC. You may find it easier to use a WYSIWYG editor (like Google Docs) when working with a few close collaborators; feel free to use whatever platform you like. Ideally this document is publicly visible and is linked to from the PR. 10 | - When opening the RFC for general discussion, copy your document into the `RFC-00xx-my-feature.md` file on the PR and assign the `commenting` label. 11 | - Build consensus for your proposal, integrate feedback and revise it as needed, and summarize the outcome of the discussion via a [resolution template](https://github.com/pytorch/rfcs/blob/master/RFC-0000-template.md#resolution). 12 | - If the RFC is idle here (no activity for 2 weeks), assign the label `stalled` to the PR. 13 | - Once the discussion has settled, assign a new label based on the level of support: 14 | - `accepted` if a decision has been made in the RFC 15 | - `draft` if the author needs to rework the RFC’s proposal 16 | - `shelved` if there are no plans to move ahead with the current RFC’s proposal. We want neither to think about evaluating the proposal 17 | nor about implementing the described feature until some time in the future. 18 | - A state of `accepted` means that the core team has agreed in principle to the proposal, and it is ready for implementation. 19 | - The author (or any interested developer) should next open a tracking issue on Github corresponding to the RFC. 20 | - This tracking issue should contain the implementation next steps. Link to this tracking issue on the RFC (in the Resolution > Next Steps section) 21 | - Once all relevant PRs are merged, the RFC’s status label can be finally updated to `closed`. 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | # [Title] 30 | 31 | **Authors:** 32 | * @nickname 33 | * @nickname 34 | 35 | 36 | ## **Summary** 37 | A short paragraph or bullet list that quickly explains what you're trying to do. 38 | 39 | 40 | ## **Motivation** 41 | What motivates this proposal and why is it important? 42 | How should users and developers think about this feature, how would it impact the way PyTorch is used? 43 | Explain impact and value of this feature 44 | 45 | 46 | ## **Proposed Implementation** 47 | This is the bulk of the RFC. Explain the design in enough detail for somebody familiar with PyTorch to understand, and for somebody familiar with the implementation to implement. 48 | This should get into specifics and corner-cases, and include examples of how the feature is used, and how it will interact with other features. Any new terminology should be defined here. 49 | Consider: 50 | * using examples and diagrams to help illustrate your ideas. 51 | * including code examples, if you're proposing an interface or system contract. 52 | * linking to project briefs or wireframes that are relevant. 53 | 54 | 55 | ## **Metrics ** 56 | What are the main metrics to measure the value of this feature? 57 | 58 | 59 | ## **Drawbacks** 60 | Are there any reasons why we should not do this? Here we aim to evaluate risk and check ourselves. 61 | 62 | Please consider: 63 | * is it a breaking change? 64 | * Impact on UX 65 | * implementation cost, both in terms of code size and complexity 66 | * integration of this feature with other existing and planned features 67 | 68 | 69 | ## **Alternatives** 70 | What other designs have been considered? What is the impact of not doing this? 71 | 72 | 73 | ## **Prior Art** 74 | Discuss prior art (both good and bad) in relation to this proposal: 75 | * Does this feature exist in other libraries? What experience has their community had? 76 | * What lessons can be learned from other implementations of this feature? 77 | * Published papers or great posts that discuss this 78 | 79 | 80 | ## **How we teach this** 81 | * What names and terminology work best for these concepts and why? How is this idea best presented? 82 | * Would the acceptance of this proposal mean the PyTorch documentation must be re-organized or altered? 83 | * How should this feature be taught to existing PyTorch users? 84 | 85 | 86 | ## **Unresolved questions** 87 | * What parts of the design do you expect to resolve through the RFC process before this gets merged? 88 | * What parts of the design do you expect to resolve through the implementation of this feature before stabilization? 89 | * What related issues do you consider out of scope for this RFC that could be addressed in the future independently of the solution that comes out of this RFC? 90 | 91 | 92 | ## Resolution 93 | We decided to do it. X% of the engineering team actively approved of this change. 94 | 95 | ### Level of Support 96 | Choose one of the following: 97 | * 1: Overwhelming positive feedback. 98 | * 2: Positive feedback. 99 | * 3: Majority Acceptance, with conflicting Feedback. 100 | * 4: Acceptance, with Little Feedback. 101 | * 5: Unclear Resolution. 102 | * 6: RFC Rejected. 103 | * 7: RFC Rejected, with Conflicting Feedback. 104 | 105 | 106 | #### Additional Context 107 | Some people were in favor of it, but some people didn’t want it for project X. 108 | 109 | 110 | ### Next Steps 111 | Will implement it. 112 | 113 | 114 | #### Tracking issue 115 | 116 | 117 | 118 | #### Exceptions 119 | Not implementing on project X now. 
Will revisit the decision in 1 year. 120 | -------------------------------------------------------------------------------- /RFC-0038-cpython-support.md: -------------------------------------------------------------------------------- 1 | 2 | # [CPython version support] 3 | 4 | **Authors:** 5 | * @albanD 6 | 7 | ## **Motivation** 8 | 9 | CPython follows an annual release cycle. Given that the amount of work to enable and sunset versions yearly is fixed, we should be proactive in how we handle these tasks. 10 | This RFC proposes an updated policy for CPython version support, the CI/CD requirements, and a yearly timeline for enablement/sunsetting. 11 | 12 | The key requirements driving this proposal are: 13 | - Enable new versions as soon as possible 14 | - Sunset old versions in a timely manner 15 | - Set clear long-term expectations for our users 16 | - Mitigate the risk that a CPython change which would render PyTorch non-functional makes it into the final CPython release 17 | - Minimize any work that would not have been needed if the enablement was done post-final release 18 | 19 | 20 | ## **Proposed Implementation** 21 | 22 | ### Proposed timeline 23 | 24 | ![Summary timeline](./RFC-0038-assets/cpython_support.png) 25 | 26 | ### Which CPython versions are supported 27 | 28 | PyTorch supports all CPython versions that are fully released and have not reached end of life: https://devguide.python.org/versions/ 29 | 30 | Note: This is an update from the current policy at https://github.com/pytorch/pytorch/blob/main/RELEASE.md#python which is following NumPy’s approach. In practice, we are not following the rules stated in that document but rather the rule stated just above. We are updating the rule (instead of enforcing the NEP above) as the cost of supporting more versions is minimal. 31 | Also, the split-build work to have a single (large) C++ package shared by all CPython versions is well underway (https://github.com/pytorch/pytorch/pull/129011) and will reduce the binary size used on PyPI even though we support a large number of versions. 32 | 33 | ### What is tested in CI/CD 34 | 35 | The goal here is to ensure test coverage while reducing the cost on CI. 36 | 37 | At any given time: 38 | - General CI on PR and trunk should run on the oldest supported version. 39 | - Maintainers can ask for a specific CI shard to run on specific versions: 40 | - Long term for testing features tightly bound to CPython versions (for example Dynamo). 41 | - Temporarily for enablement work (for example while a new CPython version is being enabled). 42 | - CD for docker and binaries should run for all supported versions. 43 | - Wheel/release smoke test should run on all supported versions. 44 | 45 | 46 | ### Detailed CPython new version enablement 47 | 48 | - Enable new version basic compilation 49 | - When: Once the first beta version is released (usually in May) 50 | - Planned for 3.13: 5/07 51 | - ETA: 1-2 weeks. Before beta 2 52 | - Goal: Fix compilation issues so that PyTorch can be compiled locally. Report any needed patch in CPython before beta 2. 53 | - Who: Core Maintainers for pytorch/pytorch + relevant maintainers for managed domains 54 | - Note: If a patch is needed in CPython, it will be made available easily until the next beta release happens 55 | - Enable new version CI/CD Infrastructure 56 | - When: As soon as basic compilation is done 57 | - ETA: 2-3 weeks. 
58 | - Goal: Generate nightly wheels and scaffolding for new version testing 59 | - Who: Core Infra 60 | - High risk enablement work for the new version for submodules 61 | - When: as soon as the CI/CD infrastructure is done 62 | - ETA: before CPython first RC 63 | - Planned for 3.13: 07/30 64 | - Goal: Verify high risk systems and report any needed patch in CPython such that they can be fixed before RC1 65 | - Who: High-risk system owners. As of today, this is the Dynamo C++ code and the Python binding subsystems. 66 | - Low risk enablement work for the new version for submodules 67 | - When: as soon as the CI/CD infrastructure is done 68 | - ETA: before CPython final RC 69 | - Planned for 3.13: 09/03 70 | - Goal: Enable all testing for the new CPython version 71 | - Who: Core Maintainers handle the long tail of issues. Specific submodule owners handle larger migration efforts (Dynamo, TorchScript, etc.) 72 | - Full new version wheel release 73 | - When: When the new CPython version is officially released (usually in October) 74 | - Planned for 3.13: 10/01 75 | - ETA: 2 weeks 76 | - Goal: Update nightlies to track the final release and advertise it as fully working 77 | - Who: Infra + Comms 78 | - Full new version conda release 79 | - When: When the new CPython version and our runtime dependencies are supported on conda 80 | - ETA: 2 weeks 81 | - Goal: Push binaries that are officially supported 82 | - Who: Infra + Comms 83 | 84 | 85 | ### Detailed CPython old version sunset 86 | 87 | - Deprecate old version 88 | - When: During the release process of the last PyTorch release before the oldest version of Python goes EOL 89 | - Planned for 3.8: October 90 | - Goal: announce deprecation of the oldest CPython version on dev-discuss + release 91 | - Who: Comms 92 | - Stop releasing nightly binaries for oldest version 93 | - When: Once the last supported release happened 94 | - ETA: 2 weeks 95 | - Goal: remove all wheel and conda binaries for the EOL version 96 | - Who: Infra 97 | - Upgrade oldest version CI 98 | - When: After nightlies are not published anymore 99 | - ETA: 2 weeks 100 | - Goal: migrate all the CI jobs running on the oldest version to the soon-to-be oldest version 101 | - Who: Infra 102 | - Remove unused code and use new features 103 | - When: After the version is fully dropped 104 | - ETA: N/A 105 | - Goal: clean up tech debt related to older CPython versions and simplify code using new features. 106 | - Who: All 107 | -------------------------------------------------------------------------------- /RFC-0025-sdpa-optm-cpu.md: -------------------------------------------------------------------------------- 1 | # [PT2.1 Feature Proposal] SDPA (Scaled-Dot Product Attention) CPU Optimization 2 | 3 | This ticket is part of the PT 2.1 feature proposal process. 4 | 5 | ## **Motivation** 6 | As LLMs tend to use large batch sizes and long context lengths, the resulting large memory requirement may lead to OOM issues or bad performance. To reduce memory usage and provide a substantial speedup for attention-related models, it is important to optimize SDPA. The fused SDPA, e.g. flash attention, is one type of optimized SDPA algorithm designed for memory-bound problems, with better parallelism and memory access patterns. In PT 2.0, there exist both the basic unfused SDPA and the fused SDPA for CUDA, while only the unfused SDPA has a CPU implementation. To fill the gap between CPU and CUDA, it is proposed to optimize SDPA by implementing the fused SDPA for CPU in PT 2.1.
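For reference, the operator this proposal targets is exposed in Python as `torch.nn.functional.scaled_dot_product_attention` (available since PT 2.0). The snippet below is only an illustrative sketch of how the op is invoked on CPU; the shapes mirror the NanoGPT benchmark later in this document, and which backend actually runs (the proposed fused flash-attention kernel vs. the unfused math path) would be decided internally by the CPU selection function described in the implementation section.

```python
import torch
import torch.nn.functional as F

# Shapes follow the NanoGPT benchmark below:
# batch size 1, 25 heads, sequence length 1024, head size 64.
q = torch.randn(1, 25, 1024, 64)   # query (float32, CPU)
k = torch.randn(1, 25, 1024, 64)   # key
v = torch.randn(1, 25, 1024, 64)   # value

# With the proposed selection function, the fused CPU flash-attention kernel
# would be chosen automatically here; otherwise the unfused SDPA runs.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 25, 1024, 64])
```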
7 | 8 | ## **Implementation** 9 | We submitted PRs for CPU SDPA optimization and demonstrated up to a 3x performance speedup on attention-related benchmarks. 10 | 11 | Here are the detailed implementation items: 12 | 13 | * The flash attention CPU kernel is added, in which both forward and backward paths are implemented for the data types float32 and bfloat16. Blocking is applied on the query-length and kv-length dimensions, and the fusion of gemm + softmax update + gemm is done at once for each block. Specifically, FP32In-FP32Out and BF16In-FP32Out adopt the MKL gemm and BF16In-BF16Out adopts the oneDNN one. Parallelization is on the dimensions of batch size, head number and query length for the forward path, and on the dimensions of batch size and head number for the backward path. In addition, the causal attention mask is supported. As the attention is masked for the unseen tokens, early termination is applied and we only calculate the blocks in the lower triangular part. 14 | * An SDPA selection function for CPU is added to automatically choose among the available SDPA implementations. There are two CPU implementations which could be chosen: the unfused SDPA and flash attention. In general, flash attention has a higher priority than the unfused SDPA. For cases where flash attention is not applicable, such as when flash attention is manually disabled or the inputs are not 4-dimensional, the unfused SDPA is chosen. 15 | 16 | The following will be nice to have for PT 2.1: 17 | 18 | * Support the float16 data type. 19 | * Enable the SDPA graph rewriting for Inductor. 20 | * Further block-related tuning for the fused SDPA. 21 | * Support Dropout for the fused SDPA. 22 | 23 | 24 | ## **Performance** 25 | All validations are run on an SPR (Sapphire Rapids) machine. 26 | 27 | ### NanoGPT's SDPA kernel 28 | Using the benchmark [repo](https://github.com/mingfeima/bench_sdpa/blob/main/README.md), with one socket. 29 | Shape: Batch size 1, Sequence length 1024, Head number 25, Head size 64. 30 | 31 | | Dtype | Causal | Mode | SDPA | Time (ms per iter) | Speedup | 32 | | -------- | -------- | ------- | ------- | ------- | ------- | 33 | | float32 | FALSE | Inference | Unfused | 3.081 | | 34 | | | | | Flash attention | 1.665 | 1.85045 | 35 | | float32 | TRUE | Inference | Unfused | 3.463 | | 36 | | | | | Flash attention | 1.662 | 2.083634| 37 | | bfloat16 | FALSE | Inference | Unfused | 1.203 | | 38 | | | | | Flash attention | 1.154 | 1.042461| 39 | | bfloat16 | TRUE | Inference | Unfused | 1.543 | | 40 | | | | | Flash attention | 1.154 | 1.337088| 41 | | float32 | FALSE | Training | Unfused | 54.938 | | 42 | | | | | Flash attention | 23.029 | 2.385601| 43 | | float32 | TRUE | Training | Unfused | 58.266 | | 44 | | | | | Flash attention | 17.835 | 3.266947| 45 | | bfloat16 | FALSE | Training | Unfused | 18.924 | | 46 | | | | | Flash attention | 18.886 | 1.002012| 47 | | bfloat16 | TRUE | Training | Unfused | 21.08 | | 48 | | | | | Flash attention | 14.172 | 1.48744 | 49 | 50 | ### Stable Diffusion 51 | Following the model's [BKM](https://github.com/intel-innersource/frameworks.ai.models.intel-models/blob/develop/quickstart/diffusion/pytorch/stable_diffusion/inference/cpu/README.md). 
52 | 53 | | Dtype | SDPA | Throughput (fps) | Speedup SDPA | Total Time (ms) | Speedup | 54 | | -------- | -------- | ------- | ------- | ------- | ------- | 55 | | float32 | Unfused | 1.63 | | 1139 | | 56 | | | Flash attention | 1.983 | 1.216564 | 547.488 | 2.080411| 57 | | bfloat16 | Flash attention in IPEX | 4.784 | | 429.051 | | 58 | | | Flash attention | 4.857 | 1.015259 | 408.823 | 1.049479| 59 | 60 | ## **Related PRs** 61 | Flash attention Implementation: 62 | * [#103826 Flash attention forward](https://github.com/pytorch/pytorch/pull/103826) 63 | * [#104693 Flash attention backward](https://github.com/pytorch/pytorch/pull/104693) 64 | * [#104863 enable bfloat16 on flash attention](https://github.com/pytorch/pytorch/pull/104863) 65 | 66 | SDPA selection function: 67 | * [#105131 add sdpa choice and UT](https://github.com/pytorch/pytorch/pull/105131) 68 | 69 | Some additional work: 70 | * [#104583 enable mkl’s bfloat16 gemm on PT](https://github.com/pytorch/pytorch/pull/104583) 71 | * [#104584 expand functional utils on CPU vectorization path](https://github.com/pytorch/pytorch/pull/104584) 72 | 73 | ## **Discussion** 74 | 75 | For the SDPA optimization, there are two things that need to be discussed, and we would appreciate your opinions. 76 | 77 | One is about the util functions for SDPA selection. The current util functions are under the CUDA folder, i.e. `transformers/cuda/sdp_utils`. For CPU, we have similar functions in `transformers/sdp_utils_cpp` (see #105131). It would be good to know whether we need to unify them into a single API. 78 | 79 | The other one is about GQA (Grouped-Query Attention), used in llama2. It interpolates between multi-head and multi-query attention and should be presented as a new feature in SDPA. If this feature is regarded as necessary, we can do this later. 80 | -------------------------------------------------------------------------------- /RFC-0024-rfc-process.md: -------------------------------------------------------------------------------- 1 | 2 | # Process for RFC shepherding 3 | 4 | **Authors:** 5 | * @suraj813 6 | 7 | 8 | ## **Summary** 9 | This document proposes a controlled and consistent RFC shepherding process from proposal to implementation, along with a standard RFC template for authors to follow. 10 | 11 | ## **Context** 12 | An RFC is a structured document that allows community members to propose an idea to everyone before it is implemented. RFCs enable stakeholders to be aware and confident about the direction the library is evolving in. 13 | 14 | ### Features of RFCs 15 | * Good for sparking a discussion around features larger than regular issues 16 | * Enable participation and visibility across the community 17 | * Not a consensus device; synchronous tools like meetings are better for this. 18 | * Bakes in silence as positive indifference / absence of opinions or context. 19 | 20 | 21 | ## **Motivation** 22 | * To standardize a lightweight process for discussions around new features (design, goals, risks and decisions/next steps). 23 | * To ease the process of shepherding RFCs by aligning them in a workflow 24 | * To clarify the decision-making process leading to stronger technical governance in an open source foundation 25 | * To promote a shared understanding of the library by sharing institutional knowledge across different domains. 26 | * To serve as a reference repository of prior discussions that can facilitate future onboarding. 
27 | 28 | Some projects redirect feature requests to their RFC flow to encourage requesters to make a stronger thoughtful proposal and open it for discussion among the community. This automatically filtered out low-effort requests. 29 | 30 | Maintainers also found the RFC process useful because it gave them a sense of what the community was asking for, even if the proposals weren’t accepted. 31 | 32 | 33 | ## **Proposed Implementation** 34 | ### RFC Lifecycle 35 | ![Lifecycle](./RFC-0024-assets/rfc-lifecycle.svg) 36 | * All RFCs start as a **draft**. 37 | * When it is ready to be discussed, the RFC is in **commenting** stage. 38 | * If there is no activity, the RFC is **stalled**. 39 | * After the commenting stage, the RFC can enter into a stage of: 40 | * **accepted** if a decision has been made in the RFC 41 | * **draft** if the author needs to rework the RFC’s proposal 42 | * **shelved** if there are no plans to move ahead with the current RFC’s proposal 43 | * After an accepted RFC is implemented and merged into the codebase, it can be **closed** 44 | 45 | ### Platform 46 | We can easily implement this lifecycle on the existing rfcs repo (https://github.com/pytorch/rfcs) via PR labels. The full workflow looks like: 47 | - Fork the rfcs repo: https://github.com/pytorch/rfcs 48 | - Copy `RFC-0000-template.md` to `RFC-00xx-my-feature.md`, or write your own open-ended proposal. Put care into the details. 49 | - Submit a pull request titled `RFC-00xx-my-feature`. 50 | - Assign the `draft` label while composing the RFC. You may find it easier to use a WYSIWYG editor (like Google Docs) when working with a few close collaborators; feel free to use whatever platform you like. Ideally this document is publicly visible and is linked to from the PR. 51 | - When opening the RFC for general discussion, copy your document into the `RFC-00xx-my-feature.md` file on the PR and assign the `commenting` label. 52 | - Build consensus for your proposal, integrate feedback and revise it as needed, and summarize the outcome of the discussion via a [resolution template](https://github.com/pytorch/rfcs/blob/master/RFC-0000-template.md#resolution). 53 | - If the RFC is idle here (no activity for 2 weeks), assign the label `stalled` to the PR. 54 | - Once the discussion has settled, assign a new label based on the level of support: 55 | - `accepted` if a decision has been made in the RFC 56 | - `draft` if the author needs to rework the RFC’s proposal 57 | - `shelved` if there are no plans to move ahead with the current RFC’s proposal. We want neither to think about evaluating the proposal 58 | nor about implementing the described feature until some time in the future. 59 | - A state of `accepted` means that the core team has agreed in principle to the proposal, and it is ready for implementation. 60 | - The author (or any interested developer) should next open a tracking issue on Github corresponding to the RFC. 61 | - This tracking issue should contain the implementation next steps. Link to this tracking issue on the RFC (in the Resolution > Next Steps section) 62 | - Once all relevant PRs are merged, the RFC’s status label can be finally updated to `closed`. 
63 | 64 | ### Improving Visibility of RFCs 65 | When an RFC is in `commenting`, we can highlight it in a few ways by sharing a weekly/bi-weekly digest on: 66 | * PyTorch Slack 67 | * Developer Forums (dev-discuss.pytorch.org) 68 | * Workplace (Internal groups at Meta) 69 | * Twitter 70 | 71 | ### RFC Template 72 | The provided [RFC template](https://github.com/pytorch/rfcs/blob/master/RFC-0000-template.md) contains sections to help draft a detailed RFC, and workflow instructions so that contributors don't need to refer to any other place to understand the process. 73 | 74 | ### Advantages of Github as a platform for RFCs: 75 | * Review comments are threaded, so conversations around a particular sentence are colocated 76 | * Writing code is better on Github 77 | * Resolved comments don’t disappear 78 | * Can easily cross-link to issues on Github 79 | * One single platform for all development-related work 80 | 81 | 82 | ## Alternative Platforms 83 | 84 | ### Alt 1: Google Docs only 85 | 86 | Recently, the team introduced a public GDrive folder called ‘PyTorch Design Docs’. Documents here are open for public commenting. Write access is limited to the `[contributors@pytorch.org](mailto:contributors@pytorch.org)` mailing list; anyone currently on the mailing list can add new authors. 87 | 88 | Each RFC will have a status label in the doc title, like **_[STATUS] RFC #[number] : [Title]_** which helps viewers identify what stage the discussion is in. The RFC author is responsible for updating the status label of the document. 89 | 90 | **Advantages:** 91 | 92 | Quoting a post on Workplace: 93 | 94 | > Anecdotally, engineers at Meta create GDocs (and formerly Quips) today, and it is difficult to migrate this behavior entirely to the RFC repo. Google docs have a number of advantages: 95 | > * smooth inline comments with notifications, 96 | > * rich formatting, 97 | > * embeddable diagrams compared to plain markdown + review UI of the RFCs repo. 98 | > * There's an 'Export to Markdown' plugin for gdocs, so it's easy to commit the final version after the discussion is done. 99 | 100 | **Drawbacks:** 101 | * Discoverability: Contributors are used to working in Github; having a separate location is prone to discoverability issues. 102 | * With GDocs, developers need to check two places for notifications. Cannot use github labels for autoping. 103 | * Long comments can be difficult to read in the Google UI. Resolved comments are difficult to find in the UI. 104 | * Centralizing all RFC discussions implies that existing discussion threads must be imported into the GDrive folder. 105 | * Introducing time-bound state changes can stifle the spontaneity of the discussion process. 106 | 107 | 108 | ### Alt 2: GDocs + RFC repo hybrid approach 109 | 110 | Many opensource projects maintain an RFC presence on Github (either via a separate RFCs repo, or on the project’s Issues). A GDocs + Github hybrid solution will help the visibility of RFCs in a more familiar way, but at the cost of complicating the process and increasing overheads. 111 | 112 | When authors open a new PR on the RFCs repo, they are asked to create a document on the GDrive. The PR serves as the tracking PR and the objective is to improve visibility. Once they create a new RFC GDoc, they link it in the PR. All subsequent comments and feedback will be directed to the GDoc. When the RFC is **[Accepted]** on GDocs, it can be exported to markdown, committed to the tracking PR on Github. 
113 | 114 | 115 | ## **Open Questions** 116 | 117 | * What changes will need an RFC as opposed to directly opening an issue or a PR? 118 | * @dzhulgakov: We require RFCs when adding a new abstraction, or if a discussion about the tradeoffs involved in a new addition are non-trivial. 119 | 120 | 121 | ## **Prior Art** 122 | The proposals in this document are inspired by 123 | * [Rust](https://github.com/rust-lang/rfcs) and [React](https://github.com/reactjs/rfcs/) RFC process that use Github, 124 | * the Artsy RFC process for the [resolution template](https://github.com/artsy/README/blob/main/playbooks/rfcs.md#resolve-rfc), and 125 | * the [Sourcegraph RFC](https://sourcegraph.notion.site/Requests-for-comments-RFCs-3e9cb5c238f04042893d449572ca02bd) process that uses Google Docs. 126 | 127 | 128 | ## Resolution 129 | We decided to implement the RFC workflow on the github rfcs repo. 130 | 131 | Level of support: 2 (positive feedback) 132 | 133 | ### Additional Context 134 | **@mruberry:** 135 | > We actually do receive RFCs from people who aren't from Facebook/Meta and I don't think the principal problem is that they don't know where to post but that we often ignore their posts or their posts are unfunded feature requests where our response is "that's a neat idea" but we don't have a team of engineers able to immediately tackle the issue. Here's an example: https://github.com/pytorch/pytorch/issues/64327. Here's us ignoring Intel: https://github.com/pytorch/pytorch/issues/63556. 136 | > 137 | > RFCs are inherently controversial significant changes and I would expect only a small community to reasonably create and fund them, and we probably have existing relationships with that community or should develop relationships beyond just forum posts. 138 | > 139 | > There are very rare exceptions to the above where people write entire position papers, but those are usually published as blogs, anyway (like Sasha Rush's named tensor post). And then those blogs percolate through the community and we eventually decide to fund them (or not). 140 | > 141 | > The other major issue with RFCs, as this doc points out, is discoverability. Whatever we do for RFCs we should probably think to post in Workplace, GitHub Issues, PyTorch Dev Discussions and, if we're using Google Drive, Google Drive. 142 | > 143 | > The final consideration is how easy it is to read and comment on the RFC. Personally I think GitHub issues, forum posts and Google Drive are all about the same and offer the same access to the community. 144 | 145 | 146 | **@rgommers:** 147 | > Compared to GDocs, the RFC repo has better support for code formatting and in-line commenting, and a better developer experience. 148 | > PEP and NumPy Enhancement Proposal also use Github repos for RFC (alluding to a standard way of doing RFCs in the community). 149 | 150 | 151 | **@dzhulgakov:** 152 | > I was debating between Repo and GDocs + Repo options. My hope was that with GDocs it'd be easier to incentivize Meta developers to write them. But if we want a high-quality document structure, it requires additional work on top of regular 'just some notes gdoc' anyway. So we might as well require to write the content as markdown and put it in the repo. 153 | 154 | 155 | ## Next Steps 156 | * Update [this section in PyTorch Governance](https://github.com/pytorch/pytorch/blob/master/docs/source/community/governance.rst#controversial-decision-process) to point to this repo. 
157 | * Create RFC lifecycle labels and assign them to existing RFC PRs 158 | * Create #rfc-chat channel on the dev-discuss forum and PyTorch slack 159 | * Automate label assignment with GitHub bots (first label, stalebot) -------------------------------------------------------------------------------- /RFC-0030-native-fp8-dtype.md: -------------------------------------------------------------------------------- 1 | # Proposal of fp8 dtype introduction to PyTorch 2 | 3 | **Authors:** 4 | * @australopitek 5 | 6 | 7 | ## **Summary** 8 | More and more companies working on Deep Learning accelerators are experimenting with the use of 8-bit floating point numbers in training and inference. Results of these experiments are presented in many papers published in the last few years. 9 | 10 | Since the fp8 data type seems to be a natural evolution of the currently used fp16/bf16 for reducing the computation cost of big DL models, it’s worth standardizing this type. A few attempts at this were made recently: 11 | 12 | * Nvidia, Arm and Intel - https://arxiv.org/pdf/2209.05433.pdf 13 | * GraphCore, AMD and Qualcomm - https://arxiv.org/pdf/2206.02915.pdf 14 | * Tesla - https://web.archive.org/web/20230503235751/https://tesla-cdn.thron.com/static/MXMU3S_tesla-dojo-technology_1WDVZN.pdf 15 | 16 | This RFC proposes adding two 8-bit floating point data type variants to PyTorch, based on the Nvidia/Arm/Intel paper. It’s important to consider these two variants, because they’re already known to be used by Nvidia H100 and Intel Gaudi2 accelerators. 17 | 18 | 19 | ## **Motivation** 20 | The existence of native fp8 dtypes in PyTorch would simplify research and development of DL models using 8-bit floating point precision. 21 | 22 | It would be simpler to create high-level libraries on top of PyTorch. Potential automatic mixed-precision frameworks, like Nvidia’s TransformerEngine, could use this type directly instead of emulating it. 23 | 24 | Built-in fp8 would also increase performance and optimize memory usage of such frameworks, by avoiding the overhead caused by type emulation. 25 | 26 | In addition, it’s worth noting that an fp8e5m2 type was recently added to MLIR https://github.com/llvm/llvm-project/blob/fd90f542cf60c2a4e735f35513268c052686dbd6/mlir/include/mlir/IR/BuiltinTypes.td#L80 and a similar RFC is already being discussed in XLA https://github.com/openxla/xla/discussions/22 27 | 28 | 29 | ## **Proposed Implementation** 30 | Both new data types would be added as separate PyTorch dtypes, on the Python and C++ levels, similarly to float16 and bfloat16. They would have all properties of a floating-point datatype, including PyTorch type promotion, C++ std::numeric and math arithmetic. 31 | 32 | ### **Basic example of fp8 usage with type promotion** 33 | 34 | Some ops in PyTorch allow for inputs with different dtypes. In such cases, these inputs are internally cast to a common dtype determined by the type promotion matrix and the rules below: 35 | 36 | * floating point types take precedence over integer types, 37 | * dimensioned tensors over zero-dim tensors, 38 | * types with more bits over types with fewer bits.
39 | 40 | ```python 41 | input_fp32 = torch.rand((3, 3)) 42 | input_bf16 = torch.rand((3, 3), dtype=torch.bfloat16) 43 | input_i32 = torch.rand((3, 3)).to(torch.int32) 44 | input_fp8 = torch.rand((3, 3)).to(torch.fp8e5m2) 45 | 46 | res_fp32_fp8 = torch.add(input_fp32, input_fp8) # dtype == torch.float32 47 | res_bf16_fp8 = torch.add(input_bf16, input_fp8) # dtype == torch.bfloat16 48 | res_i32_fp8 = torch.add(input_i32, input_fp8) # dtype == torch.fp8e5m2 49 | res_fp8_fp8 = torch.add(input_fp8, input_fp8) # dtype == torch.fp8e5m2 50 | ``` 51 | A full description can be found in the documentation: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc 52 | 53 | ### **Binary codes configuration** 54 | 55 | The table below is taken from https://arxiv.org/pdf/2209.05433.pdf. 56 | 57 | ![Binary codes configuration](./RFC-0030-assets/fp8_binary_formats.png) 58 | 59 | ### **HF8 (E4M3) - hybrid float 8** 60 | 61 | Applicable mainly to weight and activation tensors, i.e. the forward pass of training and inference. With only 4 exponent bits, the dynamic range of this variant is small. To extend it from 240 to 448, the paper proposes to deviate from the IEEE special-value encoding in the following ways: 62 | 63 | * give up encoding Infinities, 64 | * reduce the NaN encodings to only the all-ones mantissa. 65 | 66 | The PyTorch implementation of this variant is similar to the existing float16, with the main difference being the handling of special values. 67 | 68 | ### **BF8 (E5M2) - brain float 8** 69 | 70 | Applicable mainly to gradients in the backward pass of training; however, there are many models that can be trained only with this variant. Its dynamic range is big, at least compared to the E4M3 dtype above, so its special-value encoding can stay aligned with IEEE. 71 | 72 | ### **Small range and precision issues** 73 | 74 | **Scaling** 75 | 76 | The small precision and range of the fp8 dtypes make them susceptible to underflows and overflows. To avoid this, tensors should be scaled into the fp8 range. 77 | 78 | One way is to scale gradients the same way as is already done for the fp16 dtype. The existing `torch.cuda.amp.GradScaler` could be moved/copied to the `torch.amp` namespace and adapted to the new dtypes and devices. Examples of `GradScaler` usage can be found in https://pytorch.org/docs/stable/notes/amp_examples.html 79 | 80 | However, with both precision and dynamic range limited, a single scale factor may be insufficient for fp8 training. Considering bigger models, it’s easy to imagine a set of gradient tensors whose values cannot all be scaled into the fp8 format without part of them being clamped to zero. A proper, more complex solution should be used for more efficient scaling: collecting maximum-value statistics from a few iterations back and keeping a separate scale factor for each gradient tensor. This way, every few iterations each tensor is scaled by its individual factor and the risk of over/underflow is minimized. 81 | 82 | Finally, the above per-tensor scaling solution may be extended to all tensors, not only gradients. 83 | 84 | ### **Automatic conversion** 85 | 86 | It’s important to consider adding a module for automatic mixed-precision training to make fp8 usage more user-friendly. Since fp8 dtypes have very small range and precision, more advanced solutions than a simple `torch.autocast` + `GradScaler` will be needed for the most efficient training of big, complex DL models. 87 | 88 | One way of implementing an efficient solution for automatic mixed precision is as a higher-level library on top of PyTorch. 
Such a module would take care of executing applicable ops in fp8 precision and of gradient scaling. 89 | 90 | Details of this module are not in the scope of this RFC. 91 | 92 | ### **Basic CPU support** 93 | 94 | It will probably be a long time before CPUs have fp8 support (or maybe never), but it’s worth adding fp8 support for a few PyTorch ops, like MatMul and Conv, at least for testing purposes. 95 | 96 | It’s enough to have basic math operations implemented via casting to float: 97 | 98 | ```cpp 99 | inline C10_HOST_DEVICE float operator+(FP8 a, float b) { 100 | return static_cast<float>(a) + b; 101 | } 102 | 103 | inline C10_HOST_DEVICE float operator*(FP8 a, float b) { 104 | return static_cast<float>(a) * b; 105 | } 106 | ``` 107 | 108 | and to register kernels by adding the fp8 dtypes into the dispatch macro: 109 | 110 | ```cpp 111 | AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6 112 | ``` 113 | 114 | 115 | ## **Open points** 116 | ### **Stochastic rounding** 117 | 118 | **Overview** 119 | 120 | It’s an open question whether stochastic rounding should be used in training with the fp8 dtype. 121 | 122 | Due to the very small precision of the fp8 data type, conversion from higher-precision dtypes is susceptible to bias error when the input data is not uniformly distributed. To avoid it, casting to fp8 may use stochastic rounding instead of RNE (round-to-nearest-even). 123 | 124 | On the other hand, most papers don’t mention it at all, which suggests models also converge without it. Moreover, stochastic rounding may cause harm when used in the wrong places, e.g. in the optimizer step. 125 | 126 | **How should it be exposed?** 127 | 128 | On the Python level it can be exposed as an optional parameter of the casting operator, e.g. `rounding_mode`. It would be ignored by backends that don’t support it. To make it additionally deterministic, one more parameter `seed` can be added. 129 | 130 | In addition to explicit rounding-mode selection via the `.to` operator API, it’s worth providing a more general selection at a higher level, e.g. stochastic rounding could be enabled for the whole model or part of it. In such a case, some mixed-precision module or hardware backend could apply stochastic rounding in an optimized way. Not only casting operations but also the internal accumulation in matmul/conv ops could benefit from that. 131 | 132 | **Implementation proposal** 133 | 134 | The excess bits of the higher-precision number’s mantissa are compared to randomly generated bits. If the mantissa bits are bigger, the higher-precision number is rounded up, otherwise down. 135 | 136 | Example for float32 to fp8e5m2 conversion with stochastic rounding: 137 | 138 | ```cpp 139 | uint32_t r = random_number & UINT32_C(0x1FFFFF); 140 | uint32_t m = float_bits & UINT32_C(0x1FFFFF); 141 | if (m > r) { 142 | float_bits += UINT32_C(0x200000); 143 | } 144 | ``` 145 | 146 | **Example** 147 | 148 | The code below shows a naïve example of the bias error caused by conversion from fp32 to fp8 and how stochastic rounding solves it. 149 | 150 | ```python 151 | # consecutive fp8e5m2 representable numbers are 40.0 and 48.0 152 | input_fp32 = torch.full((1000, 1000), 42.5) 153 | 154 | res_fp8 = input_fp32.to(torch.fp8e5m2) 155 | torch.mean(res_fp8) 156 | # mean = 40.0, bias error = 2.5 157 | 158 | res_fp8_stochastic = input_fp32.to(torch.fp8e5m2, rounding_mode="stochastic") 159 | torch.mean(res_fp8_stochastic) 160 | # mean ~=42.5, bias error ~= 0.0 161 | ``` 162 | 163 | ### **Automatic conversion placement** 164 | 165 | It has to be decided where the module for fp8 automatic mixed precision should be stored. 
Initially, for development, it could be an external Python package. Finally, the stable version could land in PyTorch in the `torch.amp` namespace, together with `autocast` and `GradScaler`. 166 | 167 | ### **Autocast** 168 | 169 | As mentioned before, the current torch.autocast is too simple to handle fp8 mixed precision efficiently. 170 | 171 | The question is, does it make sense to add fp8 support to autocast anyway? It could help with testing and debugging simple models. 172 | 173 | **Autocast implementation proposal** 174 | 175 | * `torch.autocast()` gets a new optional parameter `second_dtype` 176 | * new casting policy - `CastPolicy::second_lower_precision_fp` (or a better name) 177 | * ops supporting the fp8 dtype are registered for autocast with this new CastPolicy 178 | * if `second_dtype` is not specified, it falls back to the first dtype (fp16/bf16) 179 | * additionally, a configurable list of autocast ops could be added to increase flexibility, but this is not trivial in the current design of autocast. 180 | 181 | **Example** 182 | 183 | The code below presents a simple example of autocast with fp8 support and two dtypes at once. 184 | 185 | Assumptions: 186 | 187 | * `torch.matmul` is registered to `CastPolicy::second_lower_precision_fp`, 188 | * `torch.addmm` is registered to `CastPolicy::lower_precision_fp`, 189 | * `torch.log_softmax` is registered to `CastPolicy::fp32` 190 | 191 | ```python 192 | a = torch.rand((5, 5), dtype=torch.float32) 193 | b = torch.rand((5, 5), dtype=torch.float32) 194 | c = torch.rand((5, 5), dtype=torch.float32) 195 | 196 | with torch.autocast(dtype=torch.bfloat16, second_dtype=torch.fp8e5m2): 197 | mm = torch.matmul(a, b) # dtype == torch.fp8e5m2 198 | res_log_softmax = torch.log_softmax(mm, 0) # dtype == torch.float32 199 | res_addmm = torch.addmm(a, b, c) # dtype == torch.bfloat16 200 | 201 | with torch.autocast(dtype=torch.bfloat16): 202 | mm = torch.matmul(a, b) # dtype == torch.bfloat16 203 | res_log_softmax = torch.log_softmax(mm, 0) # dtype == torch.float32 204 | res_addmm = torch.addmm(a, b, c) # dtype == torch.bfloat16 205 | ``` 206 | 207 | ### **Generic bits for fp8 prototyping** 208 | 209 | In the recent discussion about the fp8 type in PyTorch (https://dev-discuss.pytorch.org/t/fp8-datatype-in-pytorch/719/6), an intermediate solution was mentioned. 210 | 211 | > we are going to add some generic bits8/16/etc type to PyTorch so you can easily prototype FP8 in a tensor subclass 212 | 213 | There are a few questions regarding the details of this solution as an alternative to a true dtype. 214 | 215 | * What are the limitations compared to a native built-in type? 216 | * Does it have the properties of a floating-point format, like infs/nans, underflow numbers and rounding modes? 217 | * Is it configurable in terms of exponent/mantissa size, bias, and special-value encoding? 218 | * Can it be included in the type promotion matrix? 219 | * Is it possible to register and use many such types at once? 220 | * Does it support Autograd? How? 221 | * What is the reason for not customizing an int8-based dtype into fp8? 222 | * Will bit8 be a base class off which the int8 and fp8 both get created, or is it specifically for fp8? 
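To make the rounding discussion above concrete, below is a minimal sketch (not part of the proposal itself) showing how fp32 → fp8e5m2 rounding, including the optional stochastic mode, can be emulated with existing PyTorch tensor ops, e.g. as a starting point for the bits8/tensor-subclass prototyping route mentioned above. The function name is made up for illustration, and the sketch deliberately ignores saturation to the fp8 range and special-value (Inf/NaN) handling.

```python
import torch

def emulate_fp8e5m2_rounding(x: torch.Tensor, stochastic: bool = False) -> torch.Tensor:
    """Round fp32 values onto the fp8e5m2 grid, keeping fp32 storage (illustration only)."""
    bits = x.float().contiguous().view(torch.int32)
    if stochastic:
        # Mirror the C++ snippet above: compare the 21 excess mantissa bits
        # against random bits and round up when they are larger.
        r = torch.randint(0, 1 << 21, x.shape, dtype=torch.int32)
        m = bits & 0x1FFFFF
        bits = torch.where(m > r, bits + 0x200000, bits)
    # Truncate the 21 mantissa bits that fp8e5m2 cannot represent.
    bits = bits & ~0x1FFFFF
    return bits.view(torch.float32)

# Reproduces the bias-error example above: the fp8e5m2 neighbours are 40.0 and 48.0.
inp = torch.full((1000, 1000), 42.5)
print(emulate_fp8e5m2_rounding(inp).mean())                    # ~40.0 (biased)
print(emulate_fp8e5m2_rounding(inp, stochastic=True).mean())   # ~42.5 on average
```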
223 | -------------------------------------------------------------------------------- /RFC-0011-InferenceMode.md: -------------------------------------------------------------------------------- 1 | ## Summary 2 | 3 | `InferenceMode` is a new context manager / RAII guard analogous to 4 | `NoGradMode` to be used when you are certain your operations will have no 5 | interactions with autograd (e.g., model training). Code run under this 6 | mode gets better performance by disabling view tracking and version 7 | counter bumps. 8 | 9 | ## Motivation 10 | 11 | In production use of PyTorch for inference, we have seen a proliferation 12 | of uses of the C++ guard `AutoNonVariableTypeMode`, which disables 13 | autograd, view tracking and version counter bumps. Unfortunately, 14 | current colloquial use of this guard is unsafe: it is possible to use 15 | `AutoNonVariableTypeMode` to bypass PyTorch's safety checks for, e.g., 16 | ensuring tensors saved for backwards are not subsequently mutated. 17 | 18 | `InferenceMode` offers a drop in replacement for 19 | `AutoNonVariableTypeMode` which: 20 | 21 | 1. Preserves the performance characteristics of 22 | `AutoNonVariableTypeMode` (Autograd, view tracking and version 23 | counter bumps are skipped for all tensors allocated within the 24 | inference mode region), but 25 | 26 | 2. Is safe, in the sense that it is not possible to bypass version 27 | counter updates on tensors which may alias with tensors which 28 | have been saved for backwards. 29 | 30 | For now, this guard is to only be made available inside C++, although 31 | we could also introduce a Python guard `torch.inference_mode` as well. 32 | 33 | Some goals and non-goals: 34 | 35 | * Goal: `InferenceMode` is semantically equivalent to `NoGradMode`, 36 | except some operations may not be supported. (In other words, this is 37 | a partial equivalence: *if* inference mode does not throw an error, 38 | then it behaves the same way as no grad mode). Caveat: this 39 | equivalence does not extend to methods that expose private 40 | implementation details; esp., `Tensor._is_view` and `Tensor._base`. 41 | 42 | * Goal: It should be possible to run code that allocates parameters 43 | (tensors with `requires_grad=True`) unchanged inside of an inference 44 | mode block. 45 | 46 | * Goal: Don't be a global or compile time flag. This makes 47 | `InferenceMode` widely applicable as it can still be used in processes 48 | where there may be training going on in another thread (e.g., 49 | federated learning on mobile). 50 | 51 | * Non-goal: `InferenceMode` doesn't affect computation beyond its scope. 52 | Indeed, the capacity for tensors allocated in `InferenceMode` (so 53 | called "inference tensors") to behave differently even outside of 54 | `InferenceMode` is one of the key implementation tools to ensuring 55 | that `InferenceMode` is safe. 56 | 57 | * Non-goal: Make operations on inference tensors fast outside of 58 | `InferenceMode`; nor, be maximally expressive with inference 59 | tensor outside of `InferenceMode`. 60 | 61 | * Non-goal: Avoid performance slowdown for view/inplace operations 62 | outside of `InferenceMode`. Benchmarking on popular models reveal 63 | that a slight slowdown on these operations is acceptable; in our 64 | case, this slowdown will be due to an extra redispatch in these cases. 65 | 66 | ## User description 67 | 68 | `InferenceMode` is an RAII guard which can be enabled for a given block 69 | of code. 
Inside inference mode, all newly allocated (non-view) tensors 70 | are marked as **inference tensors**; these tensors are guaranteed not to 71 | alias with tensors that may have been saved for backwards (or are 72 | otherwise making use of version counters--perhaps more accurately, 73 | you could call these "non version counter tracked tensors"). Inference 74 | tensors: 75 | 76 | * Do not have a version counter. 77 | * Raise an error if you try to read their version (e.g., because you 78 | saved this tensor for backwards.) 79 | * Raise an error if you try to mutate them into requiring gradients 80 | (e.g., directly set `requires_grad=True` or mutate them with a tensor 81 | that `requires_grad=True`.) 82 | 83 | A non-view tensor is an inference tensor if and only if it was 84 | allocated during inference mode. A view tensor is an inference 85 | tensor if and only if the tensor it is a view of is an inference tensor. 86 | 87 | Inside an `InferenceMode` block, we make the following performance 88 | guarantees: 89 | 90 | * All operations do not record `grad_fn`, even if their `requires_grad=True` 91 | (like `NoGradMode`). This applies for both inference tensors and 92 | normal tensors (also like `NoGradMode`). 93 | * View operations on inference tensors do not do view tracking; views 94 | and base inference tensors are indistinguishable. 95 | * Inplace operations on inference tensors are guaranteed not to do 96 | a version counter bump (which is equivalent to an atomic increment). 97 | Inplace operations on normal tensors still do version counter bumps. 98 | 99 | ## Implementation description 100 | 101 | **Dispatcher.** The dispatcher decides what implementation of a kernel 102 | to call when an operator is invoked. The set of possible options is 103 | controlled by several sources: 104 | 105 | * Tensor inputs (keys are unioned from all inputs) 106 | * TLS included set 107 | * TLS excluded set (which removes keys from the above two sources) 108 | 109 | **Autograd.** This is a preexisting dispatch key which is responsible 110 | for recording `grad_fn` on output tensors when any of their inputs 111 | `require_grad`. 112 | 113 | Autograd dispatch key is associated with tensors. Prior to this 114 | proposal, all tensors unconditionally have an autograd key. 115 | (Technically, the autograd dispatch key is not a single key, 116 | but a set of keys per backend; for the purposes of this proposal, 117 | this doesn't matter.) 118 | 119 | **InplaceOrView.** This is a new dispatch key which is responsible for 120 | doing version counter bumps on inplace operations, and view metadata 121 | tracking for view ops. Previously, this functionality was also done 122 | as part of the Autograd kernel. For all other operators, it is a fallthrough 123 | kernel. 
Here are example kernels for an inplace op and a view op prior 124 | to this proposal: 125 | 126 | ``` 127 | Tensor & add__Tensor(c10::DispatchKeySet ks, Tensor & self, const Tensor & other, Scalar alpha) { 128 | { 129 | at::AutoDispatchBelowInplaceOrView guard; 130 | at::redispatch::add_(ks & c10::after_InplaceOrView_keyset, self, other, alpha); 131 | } 132 | increment_version(self); // version counter bump for the inplace op 133 | return self; 134 | } 135 | 136 | Tensor expand(c10::DispatchKeySet ks, const Tensor & self, IntArrayRef size, bool implicit) { 137 | auto _tmp = ([&]() { 138 | at::AutoDispatchBelowInplaceOrView guard; 139 | return at::redispatch::expand(ks & c10::after_InplaceOrView_keyset, self, size, implicit); 140 | })(); 141 | std::function<at::Tensor(const at::Tensor&)> func = nullptr; 142 | if (false || !self.unsafeGetTensorImpl()->support_as_strided()) { 143 | auto size_vec = size.vec(); 144 | func = [=](const at::Tensor& input_base) { 145 | return input_base.expand(size_vec, implicit); 146 | }; 147 | } 148 | auto result = as_view( // records view metadata on the output 149 | /* base */ self, /* output */ _tmp, /* is_bw_differentiable */ true, 150 | /* is_fw_differentiable */ true, /* view_func */ func, 151 | /* creation_meta */ at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE); 152 | return result; 153 | } 154 | ``` 155 | 156 | InplaceOrView is considered part of the default TLS included set; i.e., 157 | it is always run. It is also associated with normal tensors (like Autograd), 158 | so that these kernels get run even if InplaceOrView is not in the 159 | default TLS included set. 160 | 161 | **The algorithm.** At a high level, we would like to skip both the 162 | Autograd and InplaceOrView kernels while in inference mode, whenever 163 | it is safe to do so. Whether or not this is safe is maintained by 164 | a pair of invariants: 165 | 166 | **The no-aliasing invariant:** Inference tensors are guaranteed 167 | not to alias with any tensor which is saved for backwards (or 168 | otherwise depends on accurate version counter tracking) 169 | 170 | **The immutable invariant:** Inference tensors are immutable outside of 171 | inference mode. 172 | 173 | The no-aliasing invariant guarantees it is safe to skip version counter 174 | bumps when mutating inference tensors, as the set of tensors affected by 175 | mutation is precisely the set of aliases to that tensor. The immutable 176 | invariant guarantees it is safe to skip view metadata, as view metadata 177 | is only used to enable inplace updates on tensors that require 178 | gradients. 179 | 180 | **Inference mode** is defined to be the state when: 181 | 182 | * Autograd is added to the TLS excluded set 183 | * InplaceOrView is removed from the TLS included set (recall that by 184 | default, InplaceOrView is part of the TLS included set) 185 | * If view metadata is recorded (e.g., because a tensor has InplaceOrView 186 | directly recorded on it), the creation metadata of the view is 187 | set to forbid subsequent inplace modification with 188 | `requires_grad=True` tensors (`CreationMeta::NO_GRAD_MODE`) 189 | 190 | It is legal for only Autograd to be excluded (this happens during normal 191 | processing of Autograd kernels), but it is illegal for InplaceOrView to 192 | be removed from the TLS included set if Autograd is not also excluded. 193 | 194 | An **inference tensor** is a tensor that does not have the Autograd or 195 | InplaceOrView dispatch keys and has no version counter.
Whether or not 196 | the result of a functional/view operation is an inference tensor (e.g., 197 | that omit these keys) is the result of the following rules: 198 | 199 | * If a functional operation, the output tensor is an inference 200 | tensor if and only if we are running in inference mode. In practice, 201 | this is implemented by only adding the Autograd+InplaceOrView keys 202 | in the TensorImpl constructor if inference mode is off. 203 | * If a view operation, the output tensor is an inference tensor 204 | if and only if the input tensor is an inference tensor. In practice, 205 | this is implemented by propagating the dispatch key set from the 206 | base tensor to the view tensor. 207 | 208 | These rules guarantee half of the no-aliasing invariant: functional 209 | operations are guaranteed to have non-aliasing outputs and are safe to 210 | mark as inference tensors; view operations introducing aliasing 211 | relationships, and it is only safe for inference tensors to alias other 212 | inference tensors. 213 | 214 | Furthermore, the following operations on inference tensors are disabled: 215 | 216 | * Inplace modifications on inference tensors outside of inference mode 217 | (tested at the point we do version counter increments; this code is 218 | guaranteed to run outside of inference mode because InplaceOrView is 219 | part of default included TLS). This guarantees the immutability 220 | invariant. (TODO: Also need to prevent `requires_grad` from being 221 | explicitly toggled) 222 | * Saving an inference tensor for backwards (tested in the constructor 223 | of SavedVariable). This guarantees the other half of the no-aliasing 224 | invariant. 225 | 226 | **Examples.** Given the rules above, we can describe the behavior 227 | for each combination of possibilities: 228 | 229 | * In inference mode... 230 | * Inplace operation... 231 | * On a normal tensor - version counter will increment (due 232 | to InplaceOrView key on the normal tensor) 233 | * On an inference tensor - no increment 234 | * View operation... 235 | * On a normal tensor - view metadata is recorded, creation 236 | meta is set to `INFERENCE_MODE`, version counter is propagated, 237 | result is a normal tensor 238 | * On an inference tensor - view metadata is not recorded, 239 | result is an inference tensor 240 | * Functional operation... 241 | * On a normal tensor - produces an inference tensor 242 | * On an inference tensor - produces an inference tensor 243 | * Outside of inference mode... 244 | * Inplace operation... 245 | * On an inference tensor - forbidden 246 | * View operation... 247 | * On an inference tensor - allowed, view metadata is not 248 | recorded, result is an inference tensor 249 | 250 | **Edge case: explicit `requires_grad` setting.** One might expect that in 251 | no grad mode that it is impossible to allocate a tensor with 252 | `requires_grad=True`. However, this is not true: any tensor that 253 | is explicitly allocated with `requires_grad=True` preserves this 254 | property outside of no grad mode: 255 | 256 | ``` 257 | >>> with torch.no_grad(): 258 | ... x = torch.empty(2, requires_grad=True) 259 | ... 260 | >>> x 261 | tensor([-1.3667e-17, 4.5801e-41], requires_grad=True) 262 | ``` 263 | 264 | This can also be achieved by explicitly setting 265 | `x.requires_grad = True`. Furthermore, in no grad mode, this requires 266 | grad setting propagates to views 267 | 268 | ``` 269 | >>> with torch.no_grad(): 270 | ... x = torch.empty(2) 271 | ... y = x.view(2) 272 | ... 
x.requires_grad = True 273 | ... 274 | >>> y.requires_grad 275 | True 276 | ``` 277 | 278 | This poses a problem for inference mode, which doesn't track view 279 | metadata and cannot implement this propagation. Our proposed solution 280 | is to forbid setting `requires_grad` (but permit tensors to be directly 281 | constructed with `requires_grad=True`). This cannot be easily 282 | implemented today as internally the `requires_grad=True` factory is 283 | implemented by first constructing a tensor, and then setting its 284 | `requires_grad=True`. 285 | 286 | ## Future work: skipping Autograd kernels when `requires_grad=False` 287 | 288 | As view and inplace handling has been moved out of Autograd kernels, a 289 | tantalizing possibility is to remove the Autograd dispatch keys from 290 | tensors with `requires_grad=False`, thus skipping this kernel entirely. 291 | 292 | But this work is currently blocked for the following reason: 293 | 294 | - If `requires_grad=False` skips the Autograd kernel, functional ops won't 295 | be able to go through the `AutoDispatchBelowInplaceOrView` guard which 296 | suppresses both the Autograd and InplaceOrView keys in the TLS excluded set. Not 297 | suppressing the InplaceOrView key means unnecessary calls to 298 | `as_view/increment_version` if any view/inplace ops are used in the 299 | kernel implementation, which adds a lot of overhead. To avoid this overhead, 300 | instead of the fallthrough kernel being a backend fallback, we'll want to 301 | use a real kernel that suppresses the InplaceOrView key. But compared to 302 | the current implementation, which only adds an extra dispatch for 303 | view/inplace ops, it forces all functional ops to have an extra 304 | dispatch as well. That's why it's blocked. 305 | - Unblocking this requires some fixes, like identifying `at::` callsites in 306 | backend-specific kernels (static analysis?) and replacing these with 307 | `at::native::`, which should unblock us from linking `requires_grad` with the 308 | VariableType kernel. Alternatively, do 309 | https://github.com/pytorch/pytorch/issues/54614 310 | -------------------------------------------------------------------------------- /RFC-0017-PyTorch-Operator-Versioning.md: -------------------------------------------------------------------------------- 1 | # PyTorch Operator Versioning 2 | 3 | 4 | PyTorch is a framework that allows creating and executing programs expressed with a set of operators. 5 | 6 | These operators sometimes require changes to maintain the high quality user experience (UX) that PyTorch is known for. These changes are spread out across program representation as well as execution. This poses a challenge since PyTorch programs created at a point in time may need to run in newer implementations of the PyTorch runtime. When this is not possible due to a change in the operator set, it is said the change is not backwards compatible (aka BC-breaking). In the opposite direction, it is also possible that PyTorch programs may need to be executed in an older implementation of the PyTorch runtime, and some changes in the operators may break this forward compatibility (aka FC-breaking). 7 | 8 | BC and FC breaking changes have been challenging to coordinate across PyTorch because there are multiple consumers of PyTorch’s op set and we promise to keep models running in production working as expected. 9 | 10 | This document proposes a new BC and FC policy based on operator versioning. 11 | 12 | Moving forward, we will no longer have a difference between Meta-internal and Open Source (OSS) guarantees.
They will move under *the same Service Level Agreement (SLA)* for both internal and external use cases. 13 | 14 | 15 | ## History 16 | 17 | 18 | ### Backwards Compatibility 19 | 20 | Backwards compatibility (BC), the ability for PyTorch to continue running programs from older versions, is important so programs don’t need to be forcefully updated to comply with the new runtime implementation. 21 | 22 | PyTorch's current SLA on backwards compatibility: 23 | 24 | * **OSS** — “stable” features will be deprecated for one release before a BC-breaking change is made. [PyTorch OSS BC-breaking policy](https://pytorch.org/docs/master/) 25 | * **Meta Internal** — we will not break a serialized torchscript program running in production at Meta (to be replaced with a more generic SLA) 26 | 27 | BC-breaking operator changes were previously governed by the [Backward-compatibility Breaking Change Review Process](https://fb.quip.com/gydOArylrcKd), but this only covered torchscript and eager. A generic process needs to be visible from OSS. 28 | 29 | 30 | ### Forwards Compatibility 31 | 32 | Forwards compatibility (FC), the ability for older versions of PyTorch to run programs from newer versions, is important so users don’t need to update PyTorch. 33 | 34 | PyTorch's current SLA on forward compatibility: 35 | 36 | 37 | 38 | * **OSS** — no promise 39 | * **Meta Internal** — PyTorch commits can run existing PyTorch eager, package/deploy, and serialized torchscript programs for at least two weeks 40 | * The addition of a new kwarg-only argument at the end of an op’s parameter list (but before out=, if present) with a default value is FC-compatible for serialized [torchscript](https://fb.workplace.com/groups/pytorch.dev/permalink/909079013003913/) and [mobile](https://fb.workplace.com/groups/pytorch.dev/permalink/912379562673858/). 41 | 42 | 43 | ## Goals 44 | 45 | 46 | 47 | We aim to establish a policy that can support and is consistent across both server and edge use cases, including TorchScript, package/deploy and Edge. More specifically: 48 | * Support backward and (some) forward compatibility for an arbitrary BC- or FC-breaking update (schema, functional, etc.) on an operator by [versioning](https://docs.google.com/document/d/1nyXmss2O003ZgKrhDmd-kyLNjjMqEXww_2skOXqkks4/edit). 49 | * Update and expand our existing SLAs (Service-Level Agreements). 50 | * A systematic flow to prevent BC/FC breakage at both the deploy and runtime stages (a deploy-time compatibility check is sketched after the Non-goals list below). 51 | * Provide testing that accurately detects dangerous BC and FC-breaking changes. 52 | 53 | ## Non-goals 54 | 55 | 56 | 57 | * It does not mean that models with an old operator schema can **always** run successfully on a new runtime and vice versa. 58 | * Supporting old models outside of the BC SLA is not guaranteed 59 | * Using new features is not supported on old runtimes outside of the 2-week server FC SLA 60 | * It’s not for the “automatic” BC/FC support that can be done without any developer’s manual work (for example, the number of arguments mentioned in the Context). To apply versioning on the updated operator, the author needs to manually add a line in the version table and provide the resolution function for BC. This proposal is for BC/FC breakages where the automatic supports don’t apply. 61 | * It does not include the BC/FC for package/deploy itself. The Python-only operators are transparent to TS and Edge clients, with the TS compilation.
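To make the goal of "a systematic flow to prevent BC/FC breakage" concrete, the following is a minimal sketch of the kind of deploy-time compatibility check a model delivery system could perform. All names here are hypothetical and the sketch is not part of the proposed API; the actual mechanism is described in the Proposal section below.

```python
# Hypothetical sketch: only deliver a model if every operator version it uses
# falls within the range the target runtime supports (possibly via upgraders).
def is_deliverable(model_op_versions, runtime_supported_ranges):
    """model_op_versions: {op_name: version};
    runtime_supported_ranges: {op_name: (min_version, max_version)}."""
    for op, version in model_op_versions.items():
        if op not in runtime_supported_ranges:
            return False  # unknown operator on the target runtime
        lo, hi = runtime_supported_ranges[op]
        if not (lo <= version <= hi):
            return False  # outside the range the runtime can upgrade/execute
    return True

# Example: a model using v25 of `foo` cannot be delivered to a runtime
# that only supports versions 0-10 of `foo`.
print(is_deliverable({"foo": 25}, {"foo": (0, 10)}))  # False
```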
62 | 63 | # Glossary 64 | 65 | * Backwards Compatibility (BC) — The ability to run programs from older versions of PyTorch 66 | * Version — A number that describes the format of the PyTorch program being read as well as providing partial information about which OpSet is required to run the program properly. (More precisely, it counts the number of BC-breaking changes.) (See the dynamic versioning note (https://github.com/pytorch/pytorch/blob/6db8f7a70920f91418078fe09477eed0b0adefdb/caffe2/serialize/versions.h#L11).) 67 | * Forwards Compatibility (FC)* — The ability to run programs from future versions of PyTorch 68 | * Operator — A pair of a string (the operator’s “name” or “symbol”) and a mathematical function, e.g. (div, /) 69 | * OpSet — A set of PyTorch operators (including upgraders) 70 | 71 | 72 | # Proposal 73 | 74 | We propose operator versioning that works across eager, TorchScript, torch.package and mobile. It uses a version number + corresponding upgraders in torchscript to avoid breakage due to BC/FC-breaking operator updates. 75 | 76 | * **Eager changes** 77 | * `operator_versions.yaml` and `operator_upgraders.py` are added to register operator upgrades that are BC/FC breaking. 78 | * Note: this will not cover functional operators 79 | * The default value is zero 80 | * A version bump is also required for an FC-only break. This is good for compatibility analysis: if the client is running an old runtime, we don't deliver the new model with the incompatible operator, to avoid unexpected crashes. 81 | * **Use a single operator version number for all operators** 82 | * This number may be shared by the deploy version, but separate from other file format versions 83 | * **A newer version of the operator registry must specify an upgrader** that conforms to the older version of the operator schema. Its body is a TorchScript-able function that uses the newer version of the operator to implement the old semantics. 84 | * One upgrader per historic signature. The registry specifies the symbol and the file formats those upgraders are applied to. 85 | * [Improved BC testing] Tests that the old serialized version of the operator can still be loaded on the new runtime and run as expected need to be easy to add 86 | * This seems straightforward for Torchscript and Edge testing but I’m not sure how it would work for deploy/package 87 | * [Improved FC testing] Tests that the new version of the operator can still be loaded on old runtimes and run as expected need to be easy to add 88 | * This might require a new test job, which could be tricky to set up. We have no plans to support this. 89 | * **Torchscript changes** 90 | * Reuse the _version_ record in the model file as the version number for operators. In the code it's `kProducedFileFormatVersion` 91 | * During loading into the TorchScript compiler, TorchScript needs to match operator schema according to the table of operator versions stored in the package. This would generate IR that conforms to the older schema. 92 | * TorchScript takes the IR of the older schema and uses the upgrader as a builtin function. 93 | * Out-of-support operator versions (i.e.
those no longer defined in `native_functions.yaml` with a valid upgrader) need to throw an error 94 | * **Edge runtime and mobile delivery service changes** 95 | * Delivery compatibility: communicate the operator version table deployed on device and deliver models appropriately 96 | * Runtime: load upgraders at model loading time, to ensure older models always work after updating the runtime 97 | * Unknown operator versions need to throw an error 98 | * The operator version and upgraders are built into the runtime for BC. 99 | * Allow for the addition of optional keyword-only arguments without a version bump or FC concern 100 | * Since additional operators can be introduced in upgraders, tracing-based selective build should also cover upgraders; this is easier for BC because the new runtime ships with the upgraders. 101 | * **torch.package changes** 102 | * Each torch.package package contains a table of operators and their corresponding versions according to the PyTorch build used to package the model 103 | * Q: How does the torch.package scenario for mapping old versions to current PyTorch operators work? 104 | * A: Operator versioning, by design, can’t cover all torch.package use cases. So this should be out of scope. 105 | * **New documentation required** 106 | * e2e BC-breaking guide 107 | * To make a BC-breaking change, update the version and write a torchscript adaptor and a mobile adaptor 108 | * e2e FC-breaking guide 109 | * It’s OK to add new optional keyword-only arguments as long as their default semantics preserve the operator’s current semantics 110 | * **SLA window** 111 | * The PyTorch SLA will ensure that models developed using a certain version and with non-deprecated APIs will be runnable (with a slight performance regression allowed) for *up to one more release or 180 days* (from the version release date that introduced the BC-breaking change), whichever is later. 112 | 113 | Note that the proposal does not introduce an explicit version for _all_ PyTorch operators. Instead, code changes are only required for updated operators with BC/FC breakage that cannot be handled by automatic BC/FC methods. For other operators, the implicit version is v0. 114 | 115 | As an example, suppose there is a BC/FC-breaking update to operator foo. 116 | 117 | Before: 118 | ``` 119 | foo(Tensor self, Scalar alpha=1, Tensor b) -> Tensor 120 | ``` 121 | After: 122 | ``` 123 | foo(Tensor self, Tensor c, Scalar alpha=1, Tensor b, *, Tensor(a!) out) -> Tensor(a!) 124 | ``` 125 | In the schema, a Tensor argument, c, is added. Note that it’s not added as a “trailing default argument”, so that BC/FC cannot be handled automatically. 126 | 127 | Accordingly, in the kernel of foo, the implementation is updated based on the new argument c. The pseudo code (it’s in Python format, but can be written in C++ as well) looks like: 128 | 129 | 130 | ```python 131 | def foo(Tensor self, Tensor c, Scalar alpha=1, Tensor b) -> Tensor: 132 | # The original kernel implementation 133 | ... 134 | if not c.empty(): 135 | a.add_(c) 136 | ``` 137 | 138 | 139 | 140 | ## Code changes (minimize the work of the developer) 141 | 142 | If there is a BC/FC break with a schema change in a PR, a lint error can be automatically generated, with instructions below to update the PR. 143 | 144 | 145 | ### Version bump 146 | 147 | Update the version field in _operator_versions.yaml_. 148 | The current table that torchscript uses should be migrated to _operator_versions.yaml_. 149 | 150 | ### BC updates 151 | 152 | The developer needs to implement a BC “upgrader” in Python.
The upgrader code is put in a centralized python file, _operator_upgraders.py_, in TorchScript format. 153 | 154 | _operator_versions.yaml_ 155 | ```yaml 156 | - func: foo 157 | version: 10 158 | upgrader: foo_upgrader_0_9 159 | version: 25 160 | upgrader: foo_upgrader_10_24 161 | ``` 162 | _operator_upgraders.py_ 163 | ```python 164 | def foo_upgrader_0_9(Tensor self, Scalar alpha=1, Tensor b): 165 | c = at.empty() # synthesize the new argument `c` to call the new op with the old semantics 166 | return foo(self, c, alpha, b) 167 | 168 | def foo_upgrader_10_24(...): 169 | ... 170 | ``` 171 | 172 | * Different from the upgraders defined [here](https://github.com/pytorch/pytorch/blob/8a094e3270d2fbec6060099b7059898f4a1c104a/torch/csrc/jit/frontend/builtin_functions.cpp#L98), 173 | * For some operator updates, it’s not possible to have a BC adapter. If it's an FC-only break, an upgrader is not needed. In such a case, the operator version number could still help to check compatibility and to quickly detect the source of failure with a meaningful error. 174 | * For most of the operator changes, the upgrader code is not expected to be heavy. However, the performance overhead should be observed, especially for edge use cases. 175 | * There can be multiple upgraders for one operator (for example, a v20 runtime loading both v0 and v10 models). 176 | 177 | 178 | ### FC updates 179 | 180 | * Except for the version bump, no FC update is needed for: 181 | * The BC/FC break op is not in the model (using [Dynamic versioning](https://github.com/pytorch/pytorch/pull/40279/files)) 182 | * Server with 2-week FC window 183 | * Mobile delivery systems use [compatibility analysis](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/mobile/model_compatibility.cpp) to prevent FC-breaking models from being delivered to clients. 184 | * For internal refactors where an operator may call different transient operators, re-tracing is required for mobile to make sure the inputs of tracing-based selective build don't have missing operators. Currently, it's guarded by a CI lint and a command line to retrace. 185 | 186 | If needed as a temporary hotfix, an optional downgrader can be used in the backport function shown in the diagram below. It's not planned in this RFC and the discussion is left in the "Open Questions" section. 187 | 188 | 189 | ## How does it handle BC and FC? 190 | 191 | Aligned with the “Compatibility scenarios” in the [doc of versioning in general](https://docs.google.com/document/d/1nyXmss2O003ZgKrhDmd-kyLNjjMqEXww_2skOXqkks4/edit), we do BC and FC validation at both deploy time and runtime. 192 | 193 | ### BC 194 | 195 | Deploying a new runtime that needs to execute an existing on-device model. 196 | 197 | 198 | 199 | * Deploy time 200 | * The upgraders must be delivered together with the new runtime. 201 | * Runtime 202 | * The new runtime loads the upgraders lazily: it only loads the corresponding upgraders when the version matches. 203 | * The new runtime should always be able to run the old model, unless the old model is “retired”. Error out when the runtime’s op min version > the model's op version. 204 | 205 | 206 | ### FC 207 | 208 | Deploying a new model to an existing runtime. 209 | 210 | 211 | 212 | * Deploy time. Assuming we **keep the op version table of the old runtime [query runtime]**, 213 | * For each operator in the “new” model, check that the version number is in the range of versions in the “old” runtime. 214 | * Otherwise, if possible, backport the model with the old operator schema 215 | * Runtime 216 | * The “old” runtime can run the “new” model, and errors out at load time: 217 | * When an unknown op appears.
218 | * When reaching an op whose minimum runtime version >= current 219 | 220 | # Open Questions 221 | 222 | ## Downgraders for FC 223 | Dual to upgraders for BC on the client, downgraders can be used for FC on the server. There are several options: 224 | * We set a 2-week (maybe 3-week) FC window. The FC-breaking update is split into two PRs. The first PR with the new operator readers is rolled out. After the FC window (supposing all client runtimes are updated to be able to read the new operator), the producer of the new operators is turned on to generate models with the new operator schema. 225 | * The 2-week window may not be sufficient for mobile, where the runtime update cannot be fully controlled. In such a case, we need to "backport" the new model to an old format. 226 | * We save the historical models that were exported in earlier code. If a compatible older model can be found, the older model is delivered to the old runtime. 227 | * We could apply a downgrader on the server side to rewrite the new model in the old format. Since the downgrading happens at model delivery time, it can be done outside of the major export flow. 228 | * Keep old PyTorch release binaries (for example, PyTorch 1.10). Mobile can backport some operators to PyTorch 1.10 opportunistically. It's more challenging to maintain historical binaries than historical models. 229 | -------------------------------------------------------------------------------- /RFC-0012-profile-directed-typing.md: -------------------------------------------------------------------------------- 1 | # Profile Directed Typing (PDT) 2 | 3 | ## Overview 4 | 5 | A common reason for failed model scripting is missing type annotations. For a long time, the process of fixing incorrect type annotations has been through trial and error—that is, by slowly going through the various type-checking errors generated from `torch.jit.script` and fixing them one by one. This workflow is both inefficient and frustrating. 6 | 7 | This process is not necessary because TorchScript can observe the types by running the unscripted program with a set of example inputs. In our preliminary experiments, we were able to leverage existing tools like MonkeyType to greatly reduce—or even get rid of—the need to annotate types by hand. We call this process **Profile Directed Typing (PDT)**. 8 | 9 | ## Background and Motivation 10 | 11 | `torch.jit.script` is a decorator or function call that compiles a PyTorch program. It first inspects the Python source code, then it constructs a semantics-preserving intermediate representation (IR). The TorchScript IR can then be transformed, packaged and eventually deployed. 12 | 13 | The TorchScript IR is statically typed, which means that all values must have a known type at runtime. Python, on the other hand, is not statically typed. This means that Python source code may not contain complete or accurate type annotations, as they aren't necessary for the Python interpreter. 14 | 15 | `torch.jit.script` relies on explicit type annotations to determine the types of functions. In the absence of type annotations, TorchScript uses a simple type inference algorithm: everything is `torch.Tensor`. Unfortunately, even basic types like `int`, `bool`, and `float` are not inferred by TorchScript, which leads to easily-avoidable compilation errors.
The tradeoffs of moving to a more sophisticated bidirectional type inference algorithm have been discussed by the team, and we have ultimately concluded that the benefits of a bidirectional type inference algorithm would not outweigh the added complexity. 16 | 17 | Although TorchScript often needs annotated source code to successfully compile, this requirement poses a challenge to users. 18 | 19 | ## Challenges of Type Annotations 20 | 21 | Adding type annotations is challenging for a few reasons: 22 | 23 | * **Effort:** There is a high developer cost in adding missing annotations to an entire codebase. 24 | * **Knowledge Gap**: Users may not know the code base well enough to determine the correct types. This is especially common for models downloaded from ModelZoo or a Github repo. 25 | * **Third-party Library Dependencies**: Users may not have the access rights necessary to add type annotations to a third-party library. 26 | 27 | Because of these challenges, many users give up on using TorchScript or even PyTorch entirely. 28 | 29 | ## MonkeyType 30 | 31 | [MonkeyType](https://github.com/Instagram/MonkeyType) is a Python tool that was designed by Instagram to add type annotations to millions of lines of legacy Python code. 32 | 33 | The general workflow of MonkeyType is as follows: 34 | 35 | * A Python profiling hook is registered via [sys.setprofile](https://docs.python.org/3/library/sys.html#sys.setprofile). The hook intercepts function calls and records any values that enter or exit the intercepted calls. 36 | * A Python function is run with example input provided by the user under registered profiling hooks. 37 | * The registered hooks generates a series of [CallTrace](https://github.com/Instagram/MonkeyType/blob/08ee7d2aa268d4dd00f071535522c70f794fd05b/monkeytype/tracing.py#L42)s, which contain a mapping from function objects to their argument and return types. 38 | * The collected types are coalesced into a single type (namely, a `Union`). This allows MonkeyType to form a single, distinct signature for each function. 39 | * The computed function signatures are dumped into a data store. (The data store defaults to a sqlite database saved on disk.) 40 | * The computed function signatures are applied back to the source code in one of two ways: 41 | * Source code is directly modified by leveraging [libcst](https://pypi.python.org/pypi/libcst) 42 | * [Stub files](https://mypy.readthedocs.io/en/latest/getting_started.html#library-stubs-and-typeshed) (that are honored by Python type checkers!) are generated 43 | 44 | 45 | Though MonkeyType is most often used as an independent tool, it provides a relatively stable API that allows it to be used as a Python library. Many of its behaviors can be customized, including: 46 | 47 | * Defining a custom data store for profiling data 48 | * Defining a Callable to rewrite inferred types to something else 49 | * Defining a code filter that teaches MonkeyType to selectively ignore certain Python modules in tracing 50 | 51 | ## Proposal 52 | 53 | The painful process of adding and fixing type annotations should not be necessary when we have access to the values that will be used at runtime. By using MonkeyType, we can profile the compilation target and extract the arguments and return types of any involved functions. 54 | 55 | In every call to `torch.jit.script`, TorchScript should take the following steps: 56 | 57 | * Run the compilation target (a function or `nn.Module`) in eager mode. 
During execution, MonkeyType will profile the types of all values passed into and returned from functions and methods. 58 | * Infer a set of coherent types that each function/method can accept, which we call “profiled types”. The profiled types are ephemeral and stored in a custom memory structure. 59 | * Combine the profiled types and explicitly annotated types to form a set of comprehensive function signatures. 60 | * Compile the target with the assistance of the combined set of function signatures. 61 | 62 | The following sections will describe in detail how to implement the proposed functionality. 63 | 64 | ### API Changes 65 | 66 | In addition to the function or nn.Module to compile, `torch.jit.script` accepts an `Optional[List[Tuple]]` `example_inputs`. Each tuple in the list represents a set of valid inputs to the Callable. 67 | 68 | The new version of `torch.jit.script` can be invoked like the following: 69 | 70 | ``` 71 | m = SomeModule() 72 | 73 | example_inputs = [ 74 | (torch.rand(2, 3), True, 3), 75 | (torch.rand(2, 3), False, 6), 76 | (torch.rand(2, 3), False, "Some Input Text") 77 | ] 78 | 79 | # Script module with PDT turned on. `example_inputs` is used to infer 80 | # the arg/return types of any functions touched by `SomeModule::forward` 81 | scripted_m = torch.jit.script(m, example_inputs) 82 | 83 | # Run the newly scripted and annotated model with real inputs 84 | for i in range(10): 85 | scripted_m(real_inputs[i]) 86 | ``` 87 | 88 | Note that the example inputs do not have to be similarly typed. In fact, they should be as different as possible to cover a greater number of execution paths in the compilation target. 89 | 90 | ### MonkeyType Dependency 91 | 92 | We now require PyTorch users to manually install MonkeyType if they want to use PDT. 93 | 94 | An alternative approach would be to add MonkeyType as a submodule, but doing so would require us to modify the PyTorch build system. Furthermore, PyTorch only allows critical build dependencies in submodules, and MonkeyType doesn’t qualify as a “critical build dependency”. 95 | 96 | ### Customized `MonkeyType::CallTraceStore` 97 | 98 | We implemented a custom [CallTraceStore](https://github.com/Instagram/MonkeyType/blob/f680c783c3aec6b0f613c4ea0268032cab23e788/monkeytype/db/base.py#L29), which we call `JitTypeTraceStore`. This data structure holds the traced function signatures generated by MonkeyType and provides an interface for TorchScript to query later. 99 | 100 | ``` 101 | class JitTypeTraceStore(CallTraceStore): 102 | def __init__(self): 103 | super().__init__() 104 | # key - the fully-qualified name of the function 105 | # value - a list of all the corresponding CallTraces 106 | self.trace_records: Dict[str, List[CallTrace]] = {} 107 | 108 | def add(self, traces: Iterable[CallTrace]): 109 | for t in traces: 110 | qualified_name = get_qualified_name(t.func) 111 | self.trace_records.setdefault(qualified_name, []).append(t) 112 | 113 | # ... other boilerplate methods ... 114 | ``` 115 | 116 | `JitTypeTraceStore` holds the traced function signatures for fast lookup by qualified name. `CallTrace` then contains a pointer to the raw Callable, which means that it's always possible to disambiguate different functions. 117 | 118 | ### Customized `MonkeyType::Config` 119 | 120 | To customize our tracing behavior, we subclass the MonkeyType configurable tracing API [Config](https://monkeytype.readthedocs.io/en/latest/configuration.html) and override certain key methods.
121 | 122 | ``` 123 | class JitTypeTraceConfig(monkeytype.config.Config): 124 | def __init__(self, s: JitTypeTraceStore): 125 | super().__init__() 126 | self.s = s 127 | 128 | def trace_store(self) -> CallTraceStore: 129 | return self.s 130 | 131 | def code_filter(self) -> Optional[CodeFilter]: 132 | return default_code_filter 133 | ``` 134 | 135 | This config effectively tells MonkeyType to use our customized `JitTypeTraceStore` to hold recorded function signatures. We use the [default code filter](https://github.com/Instagram/MonkeyType/blob/f680c783c3aec6b0f613c4ea0268032cab23e788/monkeytype/config.py#L111) (`default_code_filter`) and avoid the excessive trace records that would result from recording Python builtins and third-party libraries. 136 | 137 | ### Tracer Invocation 138 | 139 | After performing basic script eligibility checks in `torch.jit.script`, the following code begins the tracing process: 140 | 141 | ``` 142 | from monkeytype import trace as monkeytype_trace 143 | 144 | s = JitTypeTraceStore() 145 | monkeytype_config = JitTypeTraceConfig(s) 146 | with monkeytype_trace(monkeytype_config): 147 | for example_input in example_inputs: 148 | obj(*example_input) 149 | ``` 150 | 151 | ### Type Rewriting 152 | 153 | TorchScript only supports a subset of Python types, so there’s the danger that MonkeyType could gather an unscriptable type. To prevent this from happening, we scan through the trace records and remove any types that are invalid in TorchScript. 154 | 155 | There is also one interesting situation in which TorchScript’s default inference algorithm is more sophisticated than PDT: function return types. TorchScript can deduce more accurate return types than MonkeyType can observe, so, in this case, we simply discard the types gathered by MonkeyType. 156 | 157 | ### Aggregation 158 | 159 | MonkeyType tracing records a function signature for every function invocation. For example, given the following function and a set of sample inputs: 160 | 161 | ``` 162 | def fn(cond, x): 163 | if cond: 164 | return x 165 | else: 166 | return x + 1 167 | ``` 168 | 169 | MonkeyType may yield the following TraceRecords: 170 | 171 | ``` 172 | TraceRecord1: Arguments: {cond: Bool, x: Int}, Return: Int 173 | TraceRecord2: Arguments: {cond: Bool, x: Float}, Return: Float 174 | TraceRecord3: Arguments: {cond: Bool, x: Float}, Return: Float 175 | ``` 176 | 177 | In other words, the type of the argument `x` is dynamic and can’t be expressed with a single type. To account for this, we can simply use `Union` to express the function signature as: 178 | 179 | ``` 180 | fn: Arguments: {cond: Bool, x: Union[Int, Float]}, Return: Union[Int, Float] 181 | ``` 182 | 183 | In order to aggregate types in this way, we will add an additional method `analyze` to `JitTypeTraceStore` to consolidate the collected trace records. 184 | 185 | ``` 186 | class JitTypeTraceStore(CallTraceStore): 187 | # ... other methods ... 188 | def analyze(self): 189 | self.consolidated_types = {} 190 | # Perform analysis described above 191 | 192 | def query(self, qualified_name): 193 | # Return types from `consolidated_types` 194 | ``` 195 | 196 | ### Compiling With Observed Types 197 | 198 | TorchScript currently relies on `[annotations::get_signature](https://github.com/pytorch/pytorch/blob/758fb94fcb8e20d41df6b055e80725e37ddb4854/torch/jit/annotations.py#L62)` to get the signature of functions.
`get_signature` works by either using `inspect::signature` for Python3-style type annotations or parsing the source code to find any Python2-style type comments. 199 | 200 | We enhance `get_signature` to look up types from `JitTypeTraceStore` and to use those types for arguments that are not manually annotated by users in source code. 201 | 202 | We will clearly indicate that a type is coming from profiling-based inference for easier user-side debugging. Concretely, this will be implemented by adding a flag to every type in the JIT type system to denote whether or not a given instance of that type was inferred. This flag allows us to have more specific and actionable error messages. 203 | 204 | ## Backward Compatibility 205 | 206 | This feature is fully backward compatible: 207 | 208 | * The feature is only enabled when the user provides an additional argument to `torch.jit.script`. It should have no impact on the execution of legacy code. 209 | * When the feature is enabled, it honors manual type annotations from users and only provides additional typing information for arguments that were previously unannotated. Without assistance from MonkeyType, these arguments would have caused compilation failure anyway. 210 | 211 | ## Limitations 212 | 213 | This approach has some limitations by design: 214 | 215 | * Type inference is only as good as the input examples provided. There is no way for TorchScript to infer the types for code paths that are not hit with the provided input. 216 | * Fully running the compilation target in eager mode adds overhead to the compilation and causes longer overall compile times. 217 | * Counterpoint: This shouldn't be a problem because `jit.trace` has similar (or higher) overhead. 218 | * Currently, `torch.jit.script` also compiles custom Python classes, which may be invoked in multiple ways. We plan to change our method of class compilation soon, but, if this doesn't happen, we'll need to revisit the design of PDT. 219 | * MonkeyType only observes the types of arguments and return values. MonkeyType cannot infer the types in other situations, e.g. variable type assignment. 220 | * MonkeyType does not work for `torch.jit.interface` classes because their methods are never actually invoked during execution, thus MonkeyType cannot infer their type annotations. 221 | * Counterpoint: `torch.jit.interface` is a way to allow users to specify a module that contains certain functions. It’s hard to imagine that a user would create a `torch.jit.interface` class without knowing the exact signatures for all required methods. 222 | * Counterpoint: `torch.jit.interface` is not publicized and rarely used. 223 | 224 | ## Alternatives Considered 225 | 226 | ### Profiling in FX-based Interpreter 227 | 228 | Instead of using MonkeyType, we discussed creating an [FX-based interpreter](https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern) that could be customized to observe the types of the values that pass through the program during interpretation. However, there are several reasons why this approach is suboptimal: 229 | 230 | * FX-traced modules do not capture control flow, which means that this implementation wouldn’t work for all use cases of `torch.jit.script`. 231 | * FX doesn’t trace into the underlying dunder methods, which could be customized by users and contain functions/methods that need typing.
232 | 233 | ### PyAnnotate Profiler 234 | 235 | [PyAnnotate](https://github.com/dropbox/pyannotate) is another open source tool that has similar functionality to MonkeyType—that is, it runs a preliminary version of the user’s code to observe the types, dumps the function signatures into an ephemeral file, and finally applies the function signatures back to the original program. However, when compared to MonkeyType, PyAnnotate has following disadvantages: 236 | 237 | * PyAnnotate only generates Python2-style type comments, which requires additional parsing and may not be expressive enough for later versions of Python. 238 | * PyAnnotate is less customizable—for example, it doesn’t support filtering functions/methods from third-party modules. It would be possible to fork and modify their source code, but this would represent a significant time cost. 239 | * PyAnnotate is not as well-maintained, while MonkeyType has commits almost every month. 240 | * PyAnnotate is owned by Dropbox, while MonkeyType is maintained by Facebook. Using an in-house tool will likely lead to better support. 241 | 242 | ### Bidirectional Type Inference Algorithm 243 | 244 | Another approach would be to devise an algorithm that infers the type of an argument based on how it is used in the function body. However, this solution carries a higher complexity and may not be a good fit for TorchScript language. 245 | 246 | ### Human-in-the-Loop Tool 247 | 248 | We thought about creating a tool that, like PDT, would be based on MonkeyType; the difference would be that the type annotations would be suggested and it would be up to the user to actually make the proposed changes. However, a standalone tool like this would have some major drawbacks: 249 | 250 | * Both MonkeyType and our user could be unaware of which Python types are not valid TorchScript. It’s likely that a more inexperienced user would be forced to make several passes through their code. 251 | * User code may reference a third-party library that the user has installed in a system-wide location. If we blindly apply type annotations to all code touched by the user’s program, we could pollute the user’s global environment. 252 | * Using a standalone tool is less streamlined and degrades our user experience. 253 | -------------------------------------------------------------------------------- /RFC-0032-numpy-support-in-dynamo.md: -------------------------------------------------------------------------------- 1 | # A PyTorch - NumPy compatibility layer 2 | 3 | **Authors:** 4 | * @ev-br 5 | * @lezcano 6 | * @rgommers 7 | 8 | ## Summary 9 | This RFC describes a proposal for a translation layer from NumPy into PyTorch. 10 | In simple terms, this accounts for implementing most of NumPy's API (`ndarray`, 11 | the `numpy`, `numpy.linalg`, `numpy.fft` modules, etc) using `torch.Tensor` 12 | and PyTorch ops as backend. 13 | 14 | The main goal is: **make TorchDynamo understand NumPy calls**. 15 | This should enable an end user to combine code that uses the PyTorch API with 16 | code that uses the NumPy API, in a way that allows TorchDynamo to understand 17 | those function calls and build up an execution graph. To enable this, it is key 18 | that there is a translation layer from NumPy to PyTorch function calls, which 19 | TorchDynamo can use in order to build up its execution graph from PyTorch 20 | functions/primitives only. 
For niche functions in NumPy that don’t have a 21 | PyTorch equivalent, it’s okay to graph break and still call NumPy to execute 22 | the function call. 23 | 24 | The work is currently being done at [numpy_pytorch_interop](https://github.com/Quansight-Labs/numpy_pytorch_interop/). 25 | 26 | 27 | ## Motivation 28 | 29 | ### Introductory examples 30 | 31 | Consider the following snippet: 32 | ```python 33 | import numpy as np 34 | 35 | x = np.random.randn(3, 4) 36 | y = np.random.randn(4, 3) 37 | z = np.dot(x, y) 38 | w = z.sum() 39 | ``` 40 | 41 | When we trace this program with the compat layer, the semantics of the 42 | program would stay the same, but the implementation would be equivalent to 43 | 44 | ```python 45 | import torch 46 | x = torch.randn(3, 4, dtype=torch.float64) 47 | y = torch.randn(4, 3, dtype=torch.float64) 48 | z = torch.matmul(x, y) 49 | w = z.sum() 50 | ``` 51 | 52 | Here, we can already spot a couple of differences between NumPy and PyTorch. 53 | The most obvious one is that the default dtype in NumPy is `float64` rather than 54 | `float32`. The less obvious one is sneakily hiding in the last line. 55 | 56 | ```python 57 | >>> type(w) 58 | <class 'numpy.float64'> 59 | ``` 60 | 61 | Reductions and similar operations in NumPy return the infamous NumPy scalars. 62 | We'll discuss these and other NumPy quirks and how we dealt with them in the 63 | [design decision section](#design-decisions). 64 | 65 | 66 | Let's now have a look at a toy example of how this layer would be used. 67 | ```python 68 | import torch 69 | import numpy as np 70 | t1 = torch.tensor([1, 3, 5]) 71 | t2 = torch.exp(t1) 72 | # Now say the user has some code lying around which uses NumPy: 73 | def fn(x, y): 74 | return np.multiply(x, y).sum() 75 | 76 | result = fn(t1, t2) 77 | t_results = torch.empty(5, dtype=torch.float64) 78 | t_results[0] = result # store the result in a torch.Tensor 79 | ``` 80 | 81 | Note that this code mixing NumPy and PyTorch already works in eager mode with 82 | CPU tensors, as `torch.Tensor` implements the `__array__` method. Now, the 83 | compatibility layer allows us to trace through it. In order to do that, no 84 | changes would be necessary, other than simply asking `torch.compile` to trace 85 | through it: 86 | 87 | ```python 88 | @torch.compile 89 | def fn(x, y): 90 | return np.multiply(x, y).sum() 91 | ``` 92 | 93 | Then, TorchDynamo will cast `x` and `y` to our internal implementation of `ndarray`, 94 | and will dispatch `np.multiply` and `sum` to our implementations in terms of `torch` 95 | functions, effectively turning this function into a pure PyTorch function. 96 | 97 | ### Design decisions 98 | 99 | The main ideas driving the design of this compatibility layer are the following: 100 | 101 | 1. The goal is to transform valid NumPy and mixed PyTorch-NumPy programs into 102 | their equivalent PyTorch-only execution. 103 | 2. The behavior of the layer should be as close to that of NumPy as possible 104 | 3. The layer follows the most recent NumPy release 105 | 106 | The following design decisions follow from these: 107 | 108 | **A superset of NumPy**. NumPy has a number of well-known edge-cases (as does 109 | PyTorch, like spotty support for `float16` on CPU and `complex32` in general). 110 | The decision to translate only valid NumPy programs often allows us to 111 | implement a superset of the functionality of NumPy with more predictable and 112 | consistent behavior than NumPy itself has. 113 | 114 | 115 | **Exceptions may be different**.
We avoid entirely modelling the exception 115 | system in NumPy. As seen in the implementation of PrimTorch, modelling the 116 | error cases of a given system is terribly difficult. We avoid this altogether 117 | and we choose not to offer any guarantee here. 118 | 119 | **Default dtypes**. One of the most common issues that bites people when migrating their 120 | codebases from NumPy to JAX is the default dtype changing from `float64` to 121 | `float32`. So much so that this is noted as one of 122 | [JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). 123 | Following the spirit of making everything match NumPy by default, we choose the 124 | NumPy default dtype whenever the `dtype` was not made explicit in a factory function. 125 | We also provide a function `set_default_dtype` that allows to change this behavior 126 | dynamically. 127 | 128 | **NumPy scalars**. NumPy's type system is tricky. At first sight, it looks 129 | like PyTorch's, but with few more dtypes like `np.uint16` or `np.longdouble`. 130 | Upon closer inspection, one finds that it also has 131 | [NumPy scalar](https://numpy.org/doc/stable/reference/arrays.scalars.html) objects. 132 | NumPy scalars are similar to Python scalars but with a fixed precision and 133 | array-like methods attached. NumPy scalars are NumPy's preferred return class 134 | for reductions and other operations that return just one element. 135 | NumPy scalars do not play particularly well with 136 | computations on devices like GPUs, as they live on CPU. Implementing NumPy 137 | scalars would mean that we need to synchronize after every `sum()` call, which 138 | would be terrible performance-wise. In this implementation, we choose to represent 139 | NumPy scalars as 0-D arrays. This may cause small divergences in some cases. For example, 140 | consider the following NumPy behavior: 141 | 142 | ```python 143 | >>> np.int32(2) * [1, 2, 3] # scalar decays to a python int 144 | [1, 2, 3, 1, 2, 3] 145 | 146 | >>> np.asarray(2) * [1, 2, 3] # zero-dim array is an array-like 147 | array([2, 4, 6]) 148 | ``` 149 | 150 | We don't expect these to pose a big issue in practice. Note that in the 151 | proposed implementation `np.int32(2)` would return the same as `np.asarray(2)`. 152 | In general, we try to avoid unnecessary graph breaks whenever we can. For 153 | example, we may choose to return a tensor of shape `(2, *)` rather than a list 154 | of pairs, to avoid a graph break. 155 | 156 | **Type promotion**. Another not-so-well-known fact of NumPy's dtype system and casting rules 157 | is that it is data-dependent. Python scalars can be used in pretty much any NumPy 158 | operation, being able to call any operation that accepts a 0-D array with a 159 | Python scalar. If you provide an operation with a Python scalar, these will be 160 | cast to the smallest dtype they can be represented in, and only then will they 161 | participate in type promotion. This allows for for some rather interesting behaviour 162 | ```python 163 | >>> np.asarray([1], dtype=np.int8) + 127 164 | array([128], dtype=int8) 165 | >>> np.asarray([1], dtype=np.int8) + 128 166 | array([129], dtype=int16) 167 | ``` 168 | This data-dependent type promotion will be removed in NumPy 2.0 (planned for Dec'23), and will be 169 | replaced with [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) 170 | (already implemented in NumPy, it needs to be enabled via a private global switch now). 
171 | For simplicity and to be forward-looking, we chose to implement the 172 | type promotion behaviour proposed in NEP 50, which is much closer to that of 173 | PyTorch. 174 | 175 | Note that the decision of going with NEP 50 complements the previous one of 176 | returning 0-D arrays in place of NumPy scalars as, currently, 0-D arrays do not 177 | participate in type promotion in NumPy (but will do in NumPy 2.0 under NEP 50): 178 | ```python 179 | int64_0d_array = np.array(1, dtype=np.int64) 180 | np.result_type(np.int8, int64_0d_array) == np.int8 181 | ``` 182 | 183 | **Versioning**. It should be clear from the previous points that NumPy has a 184 | fair amount of questionable behavior and legacy pain points. It is for this reason that 185 | we decided that rather than fighting these, we would declare that the compat 186 | layer follows the behavior of NumPy's most recent release (even, in some cases, 187 | of NumPy 2.0). Given the stability of NumPy's API and how battle-tested its 188 | main functions are, we do not expect this to become a big maintenance burden. 189 | If anything, it should make our lives easier, as some parts of NumPy will soon 190 | be simplified, saving us the pain of having to implement all the pre-existing 191 | corner-cases. 192 | 193 | **Randomness**. PyTorch and NumPy use different random number generation methods. 194 | In particular, NumPy recently moved to a [new API](https://numpy.org/doc/stable/reference/random/index.html) 195 | with a `Generator` object which has sampling methods on it. The current compat 196 | layer does not implement this new API, as the default bit generator in NumPy is 197 | `PCG64`, while on PyTorch we use `MT19937` on CPU and `Philox` on non-CPU devices. 198 | From this, it follows that this API will not give any reproducibility 199 | guarantees when it comes to randomness. 200 | 201 | **Accuracy**. For deterministic operations, we would expect to give accuracy 202 | guarantees similar to those in `torch.compile`. In particular, we would expect 203 | these decompositions to be as precise as those from NumPy when compared to an 204 | `fp64` baseline minus perhaps a small relative error. 205 | 206 | 207 | ## The `torch_np` module 208 | 209 | The bulk of the work went into implementing a system that allows us to 210 | implement NumPy operations in terms of those of PyTorch. The main design goals 211 | here were 212 | 213 | 1. Implement *most* of NumPy's API 214 | 2. Preserve NumPy semantics as much as possible 215 | 216 | We say *most* of NumPy's API, because NumPy's API is not only massive, but also 217 | there are parts of it which cannot be implemented in PyTorch. For example, 218 | NumPy has support for arrays of string, datetime, structured and other dtypes. 219 | Negative strides are another example of a feature that is not supported in PyTorch. 220 | We put together a list of things that are out of the scope of this project in the 221 | [following issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73). 222 | 223 | For the bulk of the functions, we started by prioritizing the most common 224 | operations. Then, when bringing tests from the NumPy test suite, we triaged 225 | and prioritized how important it was to fix each failure we found. Doing this 226 | iteratively, we ended up with a small list of differences between the NumPy and 227 | PyTorch APIs, which we prioritized by hand. 
That list and the prioritization 228 | discussion can be found in [this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/87). 229 | 230 | **Visibility of the module** For simplicity, this RFC assumes that the 231 | `torch_np` module will not be public, as the initial suggestion for it to be 232 | made public was met with mixed opinions. This topic can be revisited in the 233 | future if desired. 234 | We discuss these in the section [unresolved questions](#unresolved-questions). 235 | 236 | ### Annotation-based preprocessing 237 | 238 | NumPy accepts virtually anything that smells like an array as an input. 239 | 240 | ```python 241 | >>> np.add(1, 3) 242 | 4 243 | >>> np.add([1., 2., 3.], 5) 244 | array([6., 7., 8.]) 245 | >>> np.concatenate([[1, 2, 3], [4, 5, 6]]) 246 | array([1, 2, 3, 4, 5, 6]) 247 | ``` 248 | 249 | NumPy calls all these objects `array_like` objects. 250 | To implement NumPy in terms of PyTorch, for any operation we would need to map 251 | inputs into tensors, perform the operations, and then wrap the tensor into 252 | a `torch_np.ndarray` (more on this class later). 253 | 254 | To avoid all this code repetition, we implement the functions in two steps. 255 | 256 | First, we implement functions with the NumPy signature, but assuming that in 257 | place of NumPy-land elements (`np.array`, array-like functions, `np.dtype`s, etc) 258 | they simply accept `torch.Tensor` and PyTorch-land objects and return 259 | `torch.Tensor`s. For example, we would implement `np.diag` as 260 | 261 | ```python 262 | def diag(v, k=0): 263 | return torch.diag(v, k) 264 | ``` 265 | 266 | In this layer, if a NumPy function is composite (calls other NumPy functions 267 | internally), we can simply vendor its implementation, and have it call our 268 | PyTorch-land implementations of these functions. In other words, at this level, 269 | functions are composable, as they are simply regular PyTorch functions. 270 | All these implementations are internal, and are not meant to be seen or used 271 | by the end user. 272 | 273 | The second step is then done via type annotations and a decorator. Each type 274 | annotation has an associated function from NumPy-land into PyTorch-land. This 275 | function converts the set of inputs accepted by NumPy for that argument into a 276 | PyTorch-land object (think a `torch.Tensor` or a PyTorch dtype). For example, 277 | for `np.diag` we would write 278 | 279 | ```python 280 | @normalize 281 | def diag(v: ArrayLike, k=0): 282 | return torch.diag(v, k) 283 | ``` 284 | 285 | Then, we wrap these Python-land functions with a `normalizer` decorator and 286 | expose them in the `torch_np` module. This decorator is in charge of gathering 287 | all the inputs at runtime and normalizing them (i.e., converting `torch_np` 288 | objects to PyTorch counterparts) according to their annotations. 289 | 290 | We currently have four annotations (and small variations of them): 291 | - `ArrayLike`: The input can be a `torch_np.array`, a list of lists, a 292 | scalar, or anything that NumPy would accept. It returns a `torch.Tensor`. 293 | - `DTypeLike`: Takes a `torch_np` dtype, and any other object that Numpy dtypes 294 | accept (strings, typecodes...) and returns a PyTorch dtype. 295 | - `AxisLike`: Takes anything that can be accepted as an axis (e.g. a tuple or 296 | an `ndarray`) and returns a tuple. 297 | - `OutArray`: Asserts that the input is a `torch_np.ndarray`. This is used 298 | to implement the `out` keyword. 
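
To make this two-step design more concrete, the following is a minimal sketch of what an annotation-driven `normalizer` decorator could look like. It is illustrative only: the `ArrayLike` marker class, the `NORMALIZERS` table and the `_to_tensor` helper are assumptions made for this example, not the actual `torch_np` implementation.

```python
import functools
from typing import get_type_hints

import torch


class ArrayLike:
    """Marker annotation: the argument should be converted to a torch.Tensor."""


def _to_tensor(value):
    # A real implementation would also unwrap torch_np.ndarray objects;
    # here we only handle tensors, lists and scalars.
    return value if isinstance(value, torch.Tensor) else torch.as_tensor(value)


# Each annotation maps to a NumPy-land -> PyTorch-land conversion function.
NORMALIZERS = {ArrayLike: _to_tensor}


def normalizer(func):
    """Normalize annotated arguments before calling the PyTorch-land function."""
    hints = get_type_hints(func)
    arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        new_args = []
        for name, value in zip(arg_names, args):
            convert = NORMALIZERS.get(hints.get(name))
            new_args.append(convert(value) if convert else value)
        return func(*new_args, **kwargs)

    return wrapper


@normalizer
def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)


print(diag([[1, 2], [3, 4]]))  # a list of lists is accepted, as in NumPy
```

The real decorator additionally has to gather keyword arguments and handle the `DTypeLike`, `AxisLike` and `OutArray` annotations (including the `out=` semantics described below), but the overall mechanism is the same.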
299 | 
300 | Note that none of the code in this implementation makes use of NumPy. We are
301 | writing `torch_np.ndarray` above to make our intent more explicit, but there
302 | shouldn't be any ambiguity.
303 | 
304 | **Implementing `out`**: In PyTorch, the `out` kwarg is a keyword-only argument.
305 | It is for this reason that, in PrimTorch, we were able to implement it as [a
306 | decorator](https://github.com/pytorch/pytorch/blob/ce4df4cc596aa10534ac6d54912f960238264dfd/torch/_prims_common/wrappers.py#L187-L282).
307 | This is not the case in NumPy. In NumPy, `out` can be used both as a positional
308 | and a keyword argument, and is often interleaved with other parameters. This is
309 | the reason why we use the `OutArray` annotation to mark these arguments. We then
310 | implement the `out` semantics in the `@normalizer` wrapper in a generic way.
311 | 
312 | **Ufuncs and reductions**: Ufuncs (unary and binary) and reductions are two
313 | sets of functions that are particularly regular. For these functions, we
314 | implement support for their arguments in a generic way as a preprocessing or
315 | postprocessing step.
316 | 
317 | **The `ndarray` class**: Once we have all the free functions implemented as
318 | functions from `torch_np.ndarray`s to `torch_np.ndarray`s, implementing the
319 | methods of the `ndarray` class is rather simple. We simply register all the
320 | free functions as methods or dunder methods as appropriate. We also forward the
321 | properties of `ndarray` to the corresponding properties of `torch.Tensor` and we
322 | are done. This creates a circular dependency, which we break with a local
323 | import.
324 | 
325 | ### Testing
326 | 
327 | The testing of the framework was done via ~~copying~~ vendoring tests from the
328 | NumPy test suite. Then, we replaced the NumPy imports with `torch_np`
329 | imports. The failures in these tests were then triaged, and either fixed or marked
330 | `xfail` depending on our assessment of the priority of implementing a fix.
331 | 
332 | In the end, as a last check that this tool was sound, we pulled five
333 | examples of NumPy code from different sources and ran them with this library (eager mode execution).
334 | We were able to run the five examples successfully with close to no code changes.
335 | You can read about these in the [README](https://github.com/Quansight-Labs/numpy_pytorch_interop).
336 | 
337 | ### Limitations
338 | 
339 | A number of known limitations are tracked in the second part of the
340 | [OP of this issue](https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/73).
341 | When landing this RFC, we will create a comprehensive document with the differences
342 | between NumPy and `torch_np`.
343 | 
344 | ### Beyond plain NumPy
345 | 
346 | **GPU**. The implementation allows for running NumPy code on the GPU simply by
347 | adding a `torch.set_default_device("cuda")` call at the top of the relevant
348 | file/code. We tested this on the code examples that we drew from the internet
349 | and they run just fine on the GPU.
350 | 
351 | **Gradients**. We have not yet tested gradient tracking, as we have yet to
352 | find good examples on which to test it, but it should be a simple
353 | corollary of all this effort. If the original tensors fed into a function
354 | have `requires_grad=True`, the tensors will track the gradients of the internal
355 | implementation and the user can then differentiate through their NumPy code.
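
As an illustration of the GPU point above, here is a minimal sketch. It assumes the compat layer is importable as `torch_np` and that the NumPy functions it uses are among those implemented; it is not a committed API.

```python
import torch
import torch_np as np  # assumed import name for the compatibility layer

# A single line at the top of the script moves the whole NumPy program to
# the GPU: tensors created by the compat layer pick up the default device.
torch.set_default_device("cuda")

x = np.linspace(0.0, 1.0, 10_000)
y = np.sin(x) ** 2 + np.cos(x) ** 2

# Reductions come back as 0-D arrays (see the "NumPy scalars" section above),
# and the data never leaves the GPU.
print(y.mean())
```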
356 | 357 | ### Bindings to TorchDynamo 358 | 359 | The bindings for NumPy at the TorchDynamo level are currently being developed in 360 | [pytorch#95849](https://github.com/pytorch/pytorch/pull/95849). 361 | 362 | 363 | ## Unresolved questions 364 | 365 | A question was left open in the initial discussion. Should the module 366 | `torch_np` be publicly exposed as `torch.numpy` or not? 367 | 368 | A few arguments in favor of making it public: 369 | * People could use it in their NumPy programs just by changing the import to 370 | `import torch.numpy as np`. This could be a selling point similar to JAX's 371 | `jax.numpy`, which could incentivize adoption. 372 | * People would not need to use the whole `torch.compile` stack to start using 373 | PyTorch as a backend for their NumPy code in their codebase. 374 | * See [this experiment in scikit-learn](https://github.com/scikit-learn/scikit-learn/pull/25956) 375 | where they got a 7x speed-up on CPU on a layer just by using `torch.linalg`. 376 | * Since the layer is rather thin and in pure Python, if there are bugs, 377 | external contributors could easily help fixing them or extend the supported 378 | functionality. 379 | 380 | A few arguments against: 381 | * The compat introduces a number of type conversions that may produce somewhat 382 | slow code when used in eager mode. 383 | * [Note] Keeping this in mind, we tried to use as few operators as possible, 384 | in the implementation, to make it reasonably fast in eager mode. 385 | * Exposing `torch.numpy` would create a less performant secondary entry point 386 | to many of the functions in PyTorch. This could be a trap for new users. 387 | -------------------------------------------------------------------------------- /RFC-0006-conda-distribution.md: -------------------------------------------------------------------------------- 1 | # A PyTorch conda "distribution" 2 | 3 | | | | 4 | | ---------- | --------------- | 5 | | Authors | Ralf Gommers | 6 | | Status | Rejected | 7 | | Type | Process | 8 | | Created | 2020-11-26 | 9 | 10 | This proposal addresses the need for a PyTorch conda distribution, meaning a 11 | collection of integration-tested packages that can be installed from a single 12 | channel, to enable package authors to release packages that depend on PyTorch 13 | and let users install them in a reliable way. 14 | 15 | 16 | ## Motivation and Scope 17 | 18 | For developers of libraries that depend on PyTorch, it is currently (Nov'20) 19 | quite difficult to express that dependency in a way that makes their package 20 | easily installable with `conda` (or `pip`) by end users. With the PyTorch 21 | ecosystem growing and the dependency graphs of sets of packages users use in 22 | a single environment becoming more complex, streamlining the package 23 | distribution and installation experience is important. 24 | 25 | Examples of packages for which there's interest in making them more easily 26 | available to end users: 27 | 28 | - [fastai](https://docs.fast.ai/): Jeremy Howard expressed interest, and 29 | plans to copy `pytorch` and other dependencies of fastai over to the `fastai` 30 | channel in case this proposal doesn't work out. 
31 | - [fairseq](https://github.com/pytorch/fairseq): a fairseq developer inquired 32 | about being added to the `pytorch` channel 33 | [here](https://github.com/pytorch/builder/issues/563), and a conda-forge 34 | contributor wanted to package both PyTorch and fairseq in conda-forge, see 35 | [here](https://github.com/conda-forge/pytorch-cpu-feedstock/issues/7#issuecomment-688467743). 36 | - [TorchANI](https://github.com/aiqm/torchani): see a TorchANI user's recent 37 | attempt to add a conda-forge package 38 | [here](https://github.com/conda-forge/torchani-feedstock/pull/1). 39 | 40 | In scope for this proposal are: 41 | 42 | - Processes related to adding new packages to the `pytorch` conda channel. 43 | - CI infrastructure needed for integration testing and moving already built 44 | packages to the `pytorch` channel. 45 | 46 | _Note: using the `pytorch` channel seems like the most obvious choice for a 47 | single integration channel; using a new channel is also possible, it won't 48 | change the rest of this proposal materially._ 49 | 50 | Out of scope are: 51 | 52 | - Changes related to how libraries are built or packages for conda are created. 53 | - Updating PyTorch packaging in `defaults` or `conda-forge`. 54 | - Improvements to installing with pip or wheel builds. 55 | 56 | 57 | ### The current state of affairs 58 | 59 | PyTorch is packaged in the `pytorch` channel; users must either add that 60 | channel to the channels list globally or in an environment (using, e.g., 61 | `conda config --env --add channels pytorch`), or add `-c pytorch` to every 62 | `conda` command they run. Note that the channels method is preferred over `-c 63 | pytorch` but installation instructions invariably use the latter, which can 64 | lead to problems when it's forgotten by the user at some point. 65 | 66 | PyTorch is also packaged in `defaults`, but it's really outdated (1.4.0 for 67 | CUDA-enabled packages, 1.5.0 for CPU-only). The `conda-forge` channel doesn't 68 | have PyTorch packages - there's a desire to add them, however it's unclear if 69 | and how that will happen. 70 | 71 | Authors of _pure Python packages_ tend to use their own conda channel to 72 | distribute their own package. Installation instructions will then have both 73 | the `pytorch` and their own channel in them. For example for fastai and 74 | BoTorch: 75 | 76 | ``` 77 | conda install -c fastai -c pytorch fastai 78 | ``` 79 | 80 | ``` 81 | conda install botorch -c pytorch -c gpytorch 82 | ``` 83 | 84 | When a user needs multiple packages, that becomes unwieldy quickly with each 85 | package adding its own channel. Note: alternatively, pure Python packages can 86 | choose to distribute on PyPI only (see the _PyPI, pip and wheels_ section 87 | further down) - Kornia is an example of a package that does this. 88 | 89 | Authors of _packages containing C++ or CUDA code_ which use the PyTorch C++ 90 | API have an additional issue: they need to release new package versions in 91 | sync with PyTorch itself, because there's no stable ABI that would allow 92 | depending on multiple PyTorch versions. For example, the torchvision 93 | `install_requires` dependency is determined like: 94 | 95 | ```python 96 | pytorch_dep = 'torch' 97 | if os.getenv('PYTORCH_VERSION'): 98 | pytorch_dep += "==" + os.getenv('PYTORCH_VERSION') 99 | 100 | requirements = [ 101 | 'numpy', 102 | pytorch_dep, 103 | ] 104 | ``` 105 | and its build script ensure a one-to-one correspondence of `pytorch` and 106 | `torchvision` versions of packages. 
107 | 108 | The `pytorch` channel currently already contains other packages that depend 109 | on PyTorch. Those fall into two categories: needed dependencies (e.g., 110 | `magma-cuda`, `ffmpeg`) , and PyTorch-branded and Facebook-owned projects 111 | like `torchvision`, `torchtext`, `torchaudio`, `captum`, `faiss`, `ignite`, etc. 112 | See https://anaconda.org/pytorch/repo for a complete list. 113 | 114 | Those packages maintain their own build and packaging scripts (see 115 | [this comment](https://github.com/pytorch/builder/issues/563#issuecomment-722667815)), 116 | and the integration testing and uploading to the `pytorch` conda channel is done 117 | via scripts in the [pytorch/builder](https://github.com/pytorch/builder) repo. 118 | 119 | There's more integration testing happening already: 120 | - The `test_community_repos/` directory in the `builder` repo contains a 121 | significantly larger set of packages that's tested in addition to the packages 122 | that are distributed on the `pytorch` conda channel. 123 | - The [pytorch-integration-testing](https://github.com/pytorch/pytorch-integration-testing) 124 | repo contains tooling to test PyTorch release candidates. 125 | - An overview of integration test results from the `builder` repo (last updated Oct'19, 126 | so perhaps no longer maintained) can be found 127 | [here](https://web.archive.org/web/20201222195552/http://ossci-integration-test-results.s3-website-us-east-1.amazonaws.com/test-results.html). 128 | 129 | 130 | ## Usage and Impact 131 | 132 | ### End users 133 | 134 | The intended outcome for end users is that they will be able to install many 135 | of the most commonly packages easily with `conda` from a single channel, 136 | e.g.: 137 | 138 | ``` 139 | conda install pytorch torchvision kornia fastai mmf -c pytorch 140 | ``` 141 | 142 | or, a little more complete: 143 | 144 | ``` 145 | # Use a new environment for a new project 146 | conda create -n myenv 147 | conda activate myenv 148 | # Add channel to env, so all conda commands will now pick up packages 149 | # in the pytorch channel: 150 | conda config --env --add channels pytorch 151 | conda install pytorch torchvision kornia fastai mmf 152 | ``` 153 | 154 | ### Maintainers of packages depending on PyTorch 155 | 156 | The intended outcome for maintainers is that: 157 | 158 | 1. They have clear documentation on how to add their package to the `pytorch` channel, 159 | including the criteria their packages should meet, how to run integration tests, 160 | and how to release new versions. 161 | 2. They can declare their dependencies correctly 162 | 3. They will still need their own channel or some staging channel to host packages 163 | before they get `anaconda copy`'d to the `pytorch` channel. 164 | 4. They can provide a single install command to their users, `conda install mypkg -c pytorch`, 165 | that will work reliably. 166 | 167 | 168 | ## Processes 169 | 170 | ### Proposing a new package for inclusion 171 | 172 | Prerequisites for a package being considered for inclusion in the `pytorch` channel are: 173 | 174 | 1. The package naturally belongs in the PyTorch ecosystem. I.e., PyTorch is a 175 | key dependency, and the package is focused on an area like deep learning, 176 | machine learning or scientific computing. 177 | 2. All runtime dependencies of the package are available in the `defaults` or 178 | `pytorch` channel, or adding them to the `pytorch` is possible with a 179 | reasonable amount of effort. 180 | 3. 
A working recipe for creating a conda package is available. 181 | 182 | A GitHub repository (working name `conda-distro`) will be used for managing 183 | proposals for new packages as well as integration configuration and tooling. 184 | To propose a new package, open an issue and fill out the instructions in the 185 | GitHub issue template. When a maintainer approves the request, the proposer 186 | can open a PR to that same repo to add the package to the integration 187 | testing. 188 | 189 | 190 | ### Integration testing infrastructure 191 | 192 | The CI connected to the `conda-distro` repo has to do the following: 193 | 194 | 1. Trigger on PRs that add or update an individual package, running the tests 195 | for that package _and_ downstream dependencies of that package. 196 | 2. If tests for (1) are successful, sync the conda packages in question to 197 | the `pytorch` channel with `anaconda copy`. 198 | 3. Provide a way to run the tests of all packages together. 199 | 4. Send notifications if a package releases requires an update (e.g. a 200 | version bump) to a downstream package. 201 | 202 | The individual packages have to do the following: 203 | 204 | 1. Ensure there are _upper bounds on dependency versions_, so new releases of 205 | PyTorch or another dependency cannot break already released versions of 206 | the individual package in question. Note that that does mean that a new 207 | PyTorch releases requires version bumps on existing packages - more detail 208 | in strategy will be needed here. 209 | 2. Tests for a package should be _runnable in a standardized way_, via 210 | `conda-build --test`. This is easy to achieve via either a `test:` section 211 | in the recipe (`meta.yaml`) or a `run_test.py` file. See [this section of 212 | the conda-build docs](https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#test-section) 213 | for details. An advantage of this method is that `conda-build` is already 214 | aware of channels and dependencies, so it should work with very little 215 | extra effort. 216 | 217 | 218 | ### What happens when a new PyTorch release is made? 219 | 220 | For minor or major versions of PyTorch, new releases of downstream packages 221 | will also be necessary. A number of packages, such as `torchvision`, 222 | `torchaudio` and `torchtext`, are anyway released in sync. Other packages in 223 | the `pytorch` channel may need to be manually released via a PR to the 224 | `conda-distro` repo). 225 | 226 | Version constraints should be set such that a bugfix release of PyTorch does 227 | not require any new downstream package releases. 228 | 229 | 230 | ### Dealing with packages that aren't maintained 231 | 232 | Proposing a package for inclusion in the `pytorch` channel implies a 233 | commitment to keep maintaining the package. There wil be a place to list one 234 | or more maintainers for each package so they can be pinged if needed. In case 235 | a package is not up-to-date or broken and it does not get fixed, after a 236 | certain duration (length TBD) it may be removed from the channel. 237 | 238 | 239 | ## Alternatives 240 | 241 | ### Conda-forge 242 | 243 | The main alternative to making the `pytorch` channel an integration channel 244 | that distributes many packages that depend on PyTorch is to have a 245 | (GPU-enabled) PyTorch package in conda-forge, and tell users and package 246 | authors that that is the place to go. 
It will require working with 247 | conda-forge in order to ensure that the `pytorch` package is of high quality, 248 | either by copying over the binaries from the `pytorch` channel or by 249 | migrating recipes and keeping them in sync. See 250 | [this very long discussion](https://github.com/conda-forge/pytorch-cpu-feedstock/issues/7) 251 | for details (and issues). 252 | 253 | Advantages of this alternative are: 254 | 255 | - Conda-forge has a lot of packages, so it will be easier to install PyTorch 256 | in combination with other non-deep learning packages (e.g. the geo-science 257 | stack). 258 | - Conda-forge already has established tools and processes for adding and 259 | updating them. Which means it's less likely for there to be issues with 260 | dependencies (e.g. packages with many or unusual dependencies may not be 261 | accepted into the `pytorch` channel, while `conda-forge` will be fine with 262 | them). 263 | - Users are likely already familiar with using the `conda-forge` channel. 264 | 265 | Disadvantages of this alternative are: 266 | 267 | - As of today, conda-forge doesn't have GPU hardware. Building is stil 268 | possible using CUDA stubs, however testing cannot really be done inside CI, 269 | only manually (which is a pain, especially when having to test multiple 270 | hardware and OS platforms). 271 | _Note that there are packages that follow this approach (mostly without 272 | problems so far), for example `arrow-cpp` and `cupy`. To obtain a full list of packages, clone https://github.com/conda-forge/feedstocks and run 273 | `grep 'compiler(' feedstocks/*/meta.yaml | grep cuda`._ 274 | - `conda-forge` and `defaults` aren't guaranteed to be compatible, so 275 | standardizing on `conda-forge` may cause problems for people who prefer 276 | `defaults`. 277 | - Exotic hardware support may be difficult. PyTorch has support for TPUs (via 278 | XLA), AMD ROCm, Linux on ARM64, Vulkan, Metal, Android NNAPI - this list 279 | will continue to grow. Most of this is experimental and hence not present 280 | in official binaries (and/or in the C++/Java packages which aren't 281 | distributed with conda), but this is likely to change and present issues 282 | with compilers or dependencies not present in conda-forge. 283 | For more details, see [this comment by Soumith](https://github.com/conda-forge/pytorch-cpu-feedstock/issues/7#issuecomment-538253388). 284 | - Release coordination is more difficult. For a PyTorch release, packages for 285 | `pytorch`, `torchvision`, `torchtext`, `torchaudio` will all be built 286 | together and then released. There may be manual quality assurance steps 287 | before uploading the packages. 288 | Building a set of packages like that depend on each other and releasing 289 | them in a coordinated fashion is hard to do on conda-forge, given that if 290 | everything is in feedstocks, the new pytorch package must already be 291 | available before the next build can start. It may be possible to do this 292 | with channel labels (build sequentially, then move all packages to the 293 | `main` label at once), but either way all the released artifacts will be 294 | publicly visible before the official release. 295 | 296 | Other points: 297 | 298 | - If the PyTorch team does not package for conda-forge, someone else will do 299 | that at some point. 
300 | - Conda-forge no longer uses a single compiler toolchain for all packages it 301 | builds for a given platform - it is now possible to use a newer compiler, 302 | which itself is built with an older glibc/binutils (that does need to be 303 | common). See 304 | [this example](https://github.com/conda-forge/omniscidb-feedstock/blob/master/recipe/conda_build_config.yaml) 305 | for how to specify using GCC 8. So not having a recent enough compiler 306 | available is unlikely to be a relevant concern. 307 | - Mirroring packages in the `pytorch` channel to the `conda-forge` channel 308 | would alleviate worries about the disadvantages here, however there's no 309 | conda-forge tooling currently to verify ABI compatibility of the packages, 310 | which is the main worry of the conda-forge team with this approach. 311 | 312 | 313 | ### DIY for every package 314 | 315 | Letting authors of every package depending on PyTorch find their own solution 316 | is basically the status quo of today. The most likely outcome longer-term is 317 | that PyTorch plus those packages depending on it will be packaged in 318 | conda-forge independently. At that point there are two competing `pytorch` 319 | packages, one in the `pytorch` and one in the `conda-forge` channel. And 320 | users who need a prebuilt version of other packages not available in the 321 | `pytorch` channel will likely migrate to `conda-forge`. 322 | 323 | The advantage is: no need to do any work to implement this proposal. The 324 | disadvantage is: depending on PyTorch will remain difficult for downstream 325 | packages. 326 | 327 | 328 | ## Related work and issues 329 | 330 | ### Conda channels 331 | 332 | Mixing multiple conda channels is rarely a good idea. It isn't even 333 | completely clear what a channel is for, opinions of conda and conda-forge 334 | maintainers differ - see 335 | https://github.com/conda-forge/conda-forge.github.io/issues/883. 336 | 337 | 338 | ### RAPIDS 339 | 340 | RAPIDS has a really complex setup for distributing conda packages. Its install instructions currently look like: 341 | ``` 342 | conda create -n rapids-0.16 -c rapidsai -c nvidia -c conda-forge \ 343 | -c defaults rapids=0.16 python=3.7 cudatoolkit=10.1 344 | ``` 345 | 346 | Depending on a user's config (e.g. having `channel_priority: strict` in 347 | `.condarc`), this may not work even in a clean environment. If one would add 348 | the `pytorch` channel as well, for users that need both PyTorch and RAPIDS, 349 | it's even less likely to work - the conda solver cannot handle that many 350 | channels and will fail to find a solution. 351 | 352 | 353 | ### Cudatoolkit 354 | 355 | CUDA libraries are distributed for conda users via the `cudatoolkit` package. 356 | That package is only available in the `nvidia`, `defaults` and `conda-forge` 357 | channels. The license of the package prohibits redistribution, and an 358 | exception is difficult to obtain. Therefore it should not be added to the 359 | `pytorch` channel (also not necessary, obtaining it from `defaults` is fine). 360 | 361 | 362 | ### PyPI, pip and wheels 363 | 364 | The experience installing PyTorch with `pip` is suboptimal, mainly because 365 | there's no way to control CUDA versions via `pip`, so the user gets whatever 366 | the default CUDA version is (10.2 at the time of writing) when running `pip 367 | install torch`. 
In case the user needs a different CUDA version or the 368 | CPU-only package, the install instruction looks like: 369 | ``` 370 | pip install torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 371 | ``` 372 | There's the [pytorch-pip-shim](https://github.com/pmeier/pytorch-pip-shim) 373 | tool to handle auto-detecting CUDA versions and retrieving the right wheel. 374 | It relies on monkeypatching pip though, so it may break when new versions of 375 | pip are released. 376 | 377 | For package authors wanting to add a dependency on PyTorch, the above 378 | usability issue is a serious problem. If they add a runtime dependency on 379 | PyTorch (via `install_requires` in `setup.py` or via `pyproject.toml`), the 380 | only thing they can add is `torch` and there's no good way of signalling to 381 | the user that there's a CUDA version issue or how to deal with it. 382 | 383 | Finally note that `pip` and `conda` work together reasonably well, so for 384 | package authors that want to release packages that _do not contain C++ or 385 | CUDA code_, releasing on PyPI only and telling their users to install PyTorch 386 | with `conda` and their package with `pip` will work best. As soon as C++/CUDA 387 | code gets added, that's no longer reliable though. 388 | 389 | 390 | ## Effort estimate 391 | 392 | TODO 393 | 394 | ### Initial setup 395 | 396 | 397 | ### Ongoing effort 398 | 399 | -------------------------------------------------------------------------------- /RFC-0020-Lightweight-Dispatch.md: -------------------------------------------------------------------------------- 1 | # Context 2 | Currently we rely on `TORCH_LIBRARY` [API](https://pytorch.org/tutorials/advanced/dispatcher.html) to register operators into dispatcher. As PyTorch aims to support a wider set of devices and use cases, the issue of op registration/dispatching static initialization time and runtime overhead as well as build-time complexity has risen to the fore, we see the need for a lighter version of our operator dispatching mechanism. 3 | 4 | We thought about possible solutions: 5 | * One option is to keep using the torch-library C++ API to register/dispatch ops, but this will necessitate careful cost-cutting for these use cases with more and more intrusive “#ifdef” customizations in the core framework. 6 | * The other option (which is what we propose here) is to utilize the function schema yaml file to "declare" ops and then use the codegen framework to generate lightweight code for runtime execution. 7 | 8 | The essential point of this proposal is that the function schema DSL (which we use to declare the standard ATen ops) combined with the codegen framework (which we use to generate the registrations, dispatch stubs and other “glue” code for ATen ops) is the bare minimum set of reusable tools for building custom extensions that are compatible with the PyTorch ecosystem. 9 | 10 | # Motivation 11 | * **Performance** 12 | * For recent use cases of mobile interpreter, we need to satisfy more and more strict initialization latency requirements, where analysis shows op registration contributes to a large portion of it. 13 | * With existing meta-programming based unboxing logic shared between mobile and server, it’s relatively inflexible to introduce optimizations. 14 | * Also with static dispatch, we don’t have to register all of the ops into the JIT op registry, which saves runtime memory usage and further reduces static initialization time. 15 | * It is possible to avoid dispatching at runtime. 
16 | * **Modularity and binary size** 17 | * Currently the mobile runtime consists of both JIT op registry and c10 dispatcher. This project will make it possible to not depend on the c10 dispatcher (opt-in), delivering a cleaner runtime library. 18 | * This project creates an opportunity to reduce binary size by getting rid of the dispatcher and enables further size optimization on unboxing wrappers. 19 | * **Ability to incorporate custom implementation of ATen ops** 20 | * For some of the mobile use cases, we need to support custom implementations of ATen ops. With an extra op registration path such as codegen unboxing it is easier to hookup ops with custom native functions. 21 | 22 | # Overview 23 | ![codegen drawio](https://user-images.githubusercontent.com/8188269/154173938-baad9ee6-0e3c-40bb-a9d6-649137e3f3f9.png) 24 | 25 | 26 | Currently the mobile interpreter registers all ATen ops into the dispatcher and some other ops into the JIT op registry. At model inference time, the interpreter will look for the operator name in the JIT op registry first, if not found then it will look into the dispatcher. This proposal **adds a build flavor that moves these ATen ops from dispatcher to JIT op registry** so that it’s easier to optimize (e.g. avoid schema parsing) and reduce dependencies. 27 | 28 | The interpreter is looking for a boxed function but our native implementation is unboxed. We need “glue code” to hook up these two. This proposal **extends the capabilities of codegen to generate the unboxing wrappers for operators**, as well as the code to register them into the JIT op registry. The interpreter will call generated unboxing wrappers, inside these wrappers we pop out values from the stack, and delegate to the unboxed API. 29 | 30 | To avoid hitting the dispatcher from the unboxed API, we will choose static dispatch so that we hit native functions from the unboxed API directly. To make sure we have feature parity as the default build, this proposal **adds support for multiple backends in static dispatch**. 31 | 32 | In addition to that, this proposal also supports features critical to mobile use cases, such as **tracing based selective build** and **runtime modularization** work. 33 | 34 | # Step by step walkthrough 35 | 36 | How will our new codegen unboxing wrapper fit into the picture of op registration and dispatching? For these use cases, we only need per-op codegen unboxing (red box on the left) as well as static dispatch. This way we can avoid all dependencies on c10::Dispatcher. 37 | 38 | We are going to break the project down into three parts, for **step 1 we are going to implement the codegen logic** and generate code based on [native_functions.yaml](https://fburl.com/code/2wkgwyoq), then we are going to verify the flow that we are able to find jit op in the registry and eventually call codegen unboxing wrapper (the red flow on the left). **Step 2 will focus on how to make sure we have feature parity** with the original op registration and dispatch system, with tasks like supporting multiple backends in static dispatch, supporting custom ops as well as custom kernels for ATen ops. For **step 3 we are going to integrate with some target hardware platforms** to validate latency and binary size improvements. These are the problems we need to address in step 3 including: avoiding schema parsing at library init time, supporting tracing based selective build. 
The goal of step 3 is to make sure per-op codegen unboxing works for our target hardware platforms and is ready to ship to production use cases. 39 | 40 | 41 | ### Step 1 42 | 43 | Bring back the unboxing kernel codegen using the new codegen framework. And make the registration no-op when we turn on the static root-op dispatch for lightweight dispatch use cases. All tasks in step 1 are based on the server version of PyTorch interpreter. 44 | 45 | 46 | #### Codegen core logic 47 | 48 | These tasks will generate C++ code that pops IValues out from a stack and casts them to their corresponding C++ types. This core logic should be shared across two types of codegens so that it can be covered by all the existing tests on server side. 49 | 50 | 51 | 52 | * **JIT type -> C++ type**. This is necessary for some of the optional C++ types, e.g., we need to map `int` to `int64_t` for the last argument in the example. 53 | * This is already done in [types.py](https://github.com/pytorch/pytorch/blob/master/tools/codegen/api/types.py), and we need to integrate it into our new codegen. 54 | * **JIT type -> IValue to basic type conversion C++ code.** E.g., the first argument of this operator: `Tensor(a) self` needs to be translated to: `(std::move(peek(stack, 0, 4))).toTensor()` 55 | * IValue provides APIs to directly convert an IValue to these basic types. See [ivalue_inl.h](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/ivalue_inl.h#L1453-L1493) 56 | * Here’s a [list](#bookmark=id.deyvpbsb5yel) of all the JIT types appearing in native_functions.yaml, most of them can be converted using IValue’s API. 57 | * Add a binding function between a JIT type to a piece of C++ code that converts IValue to a specific C++ type. 58 | * **JIT type -> IValue to ArrayRef type conversion C++ code. **IValue doesn’t provide explicit APIs for these ArrayRef types, but they are widely used in native_functions.yaml. 59 | * We can use the meta programming logic ([make_boxed_from_unboxed_functor.h](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h#L354)) as reference, convert the ivalue to vector then to ArrayRef. 60 | * **JIT type -> IValue to TensorOptions type conversion C++ code.** 61 | * Handle TensorOptions (that is not 1-1 mapping across two types of arguments), we can refer to [python.py](https://github.com/pytorch/pytorch/blob/master/tools/codegen/api/python.py#L999-L1068), maybe follow the logic over there. 62 | * **JIT schema -> unboxed function**. With all the arguments being translated, generate the C++ code to call the correct unboxed function and return the result (push it back to stack). 63 | * Figure out how to map schema to unboxed C++ function. Reference [python.py](https://github.com/pytorch/pytorch/blob/master/tools/codegen/api/python.py#L955) 64 | * Deal with method and function separately, also handle the `out` cases. 65 | 66 | 67 | #### Codegen source file details 68 | 69 | With the logic from the previous section, we should be able to wrap the code into a function pointer and register it into [torch::jit::OperatorRegistry](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/operator.cpp#L19). 70 | 71 | 72 | 73 | * Wrap generated C++ code in [OperatorGenerator](https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/operator.h#L221) so that it gets registered into the registry. 
Generate code for all functions in [native_functions.yaml](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml). Code snippet as an example: 74 | ```cpp 75 | RegisterCodegenUnboxedKernels.cpp 76 | =================== 77 | RegisterOperators reg({ 78 | OperatorGenerator( 79 | TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"), 80 | [](Stack & stack) { 81 | RECORD_FUNCTION("get_device", std::vector()); 82 | at::unboxing::get_device(stack); 83 | }, 84 | aliasAnalysisFromSchema() 85 | ), 86 | ... 87 | }) 88 | UnboxingFunctions.h 89 | =================== 90 | namespace at { 91 | namespace unboxing { 92 | 93 | TORCH_API at::Tensor get_device(Stack & stack); 94 | 95 | } // namespace unboxing 96 | } // namespace at 97 | 98 | UnboxingFunctions.cpp 99 | ===================== 100 | namespace at { 101 | namespace unboxing { 102 | 103 | TORCH_API at::Tensor get_device(Stack & stack) { 104 | auto result = at::get_device( 105 | (std::move(peek(stack, 0, 1))).toTensor() 106 | ); 107 | drop(stack, 1); 108 | pack(stack, std::move(result)); 109 | } 110 | 111 | } // namespace unboxing 112 | } // namespace at 113 | ``` 114 | 115 | 116 | 117 | 118 | * Generate separate header/cpp for codegen unboxing wrapper. We should put the codegen unboxing body into a separate function with dedicated namespace so that it is on par with the other codegen (Functions.h). 119 | * Compile generated code with current runtime and make sure the calls to ATen ops are getting dispatched to our codegen’d unboxing wrapper. 120 | * The easiest way to test is to generate a wrapper that prints out/throws an exception. Then we can execute a scripted module to trigger the dispatch. 121 | #### Server & OSS integration 122 | 123 | **Bringing codegen unboxing to server build is out of the scope of this project.** We evaluated the option of replacing JIT op registration hook (`register_c10_ops.cpp`) with codegen unboxing wrappers, but realized that effort needs proper design and a lot of effort and only brings in small value: 124 | 125 | 126 | 127 | * Having two op registration mechanisms brings more confusion. 128 | * For the scenario of adding a new operator (not to `native_functions.yaml`), we need to provide clear guidance to add it to the JIT op registry as well, otherwise JIT execution will break. 129 | * We can add tests on the mobile build for the sake of coverage. 130 | 131 | For OSS mobile integration, we will need to have a new build flavor to switch between c10 dispatcher vs jit op registry. This new flavor will include codegen source files (`UnboxingFunctions.h, UnboxingFunctions.cpp, RegisterCodegenUnboxedKernels.cpp`) instead of existing dispatcher related source files: `Operators.cpp`, `RegisterSchema.cpp `etc, similar to the internal build configuration. Again, this will be delivered as a build flavor for user to opt in, the dispatcher will be used by default. 132 | 133 | 134 | 135 | ### Step 2 136 | 137 | With step 1 we already have a working codegen unboxing + static dispatch system working but it only works for the `CPU` backend. Nowadays most models being deployed on edge devices are quantized models so we will need to support both `CPU` and `QuantizedCPU` backend. In addition to that, a lot of our models feature custom ops, however we can’t register custom ops through the old dispatcher (`TORCH_LIBRARY`) APIs any more. 
Here I’m proposing a solution that exposes the `native_function.yaml` syntax to the internal developers targeting this runtime mode: allow them to use the yaml file format to declare their custom ops and/or custom kernels. 138 | 139 | 140 | 141 | 142 | #### Support multiple backends in static dispatch 143 | 144 | **NOTE: this may be optional if we enabled backend tracing for ops.** For the vast majority of models, we will only have 1 backend per operator, meaning that if we can pass the backend info into codegen, we don’t have to do dispatch based on dispatch key. 145 | 146 | In the scenario that a model contains both floating point ops and quantized ops, our codegen should be able to statically dispatch to the correct backend. The following diagram shows what will be generated and included in the build and demonstrates the dependency relationship. 147 | 148 | Let’s take `acosh` as an example: 149 | ```yaml 150 | native_functions.yaml 151 | ===================== 152 | - func: acosh(Tensor self) -> Tensor 153 | variants: function, method 154 | structured_delegate: acosh.out 155 | 156 | - func: acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 157 | structured: True 158 | structured_inherits: TensorIteratorBase 159 | dispatch: 160 | CPU, CUDA: acosh_out 161 | ``` 162 | 163 | And if we pass the backends we want to codegen for both `CPU` and `QuantizedCPU` backends, our `Functions.h` will be something like this (borrowing from Jiakai’s [PR](https://github.com/pytorch/pytorch/pull/51554/commits/0ba3d4cc42187f69f17e0c382f0ab51e071a4a44)): 164 | 165 | ```cpp 166 | Functions.h 167 | =========== 168 | // aten::acosh(Tensor self) -> Tensor 169 | TORCH_API inline at::Tensor acosh(const at::Tensor & self) { 170 | DispatchKeySet _dk_set = c10::detail::multi_dispatch_key_set(tensor); 171 | DispatchKey _dk = _dk_set.highestPriorityBackendTypeId(); 172 | switch (_dk) { 173 | case DispatchKey::CPU: 174 | return at::cpu::acosh(self); 175 | case DispatchKey::QuantizedCPU: 176 | default: 177 | TORCH_CHECK(false, "Unsupported static dispatch", _dk); 178 | } 179 | } 180 | ``` 181 | Also we will generate these files: 182 | 183 | 184 | 185 | * `CPUFunctions_inl.h` (does not contain `acosh` declaration) 186 | * `CPUFunctions.h` 187 | * `RegisterCPU.cpp` (without `TORCH_LIBRARY` calls) 188 | * `QuantizedFunctions_inl.h` (contains acosh declaration) 189 | * `QuantizedFunctions.h` 190 | * `RegisterQuantizedCPU.cpp` (without `TORCH_LIBRARY` calls, contains `acosh` definition) 191 | 192 | 193 | 194 | ### Step 3 195 | 196 | With step 2 finished we should have feature parity as the existing op registration & dispatch system. Now we need to consider the problems specific to edge devices. How do we support custom kernels for different edge devices? How do we make sure the performance is improved as expected? How do we make the binary size as small as possible? This step is aiming to tackle these problems and the end goal is to ship this codegen unboxing + static dispatch approach. 197 | 198 | 199 | #### Bring codegen to target platform 200 | 201 | 202 | 203 | * Consider adding ops to our new `custom_ops.yaml` created in step 2 (maybe also rename), let the codegen read from the new yaml. The benefit of doing this is that we can easily support ATen ops with custom kernels (not to be confused with custom ops) and we only have a single source of truth. 204 | * There are two options, either we figure out all the dependencies for all the ops required, or we leverage tracing based selective build. 
205 | * Bring everything to our target hardware platform to make sure it builds and runs. 206 | * Disable current Dispatcher. Avoid linking any `TORCH_LIBRARY` API calls in the build, we can only profile the performance this way. 207 | * With codegen unboxing + static dispatch, we hope that we can reach a much smaller percentage of cycle count for the op registration step. 208 | 209 | 210 | 211 | #### Avoid schema parsing at runtime 212 | 213 | As mentioned in step 1, we are registering an operator into the registry along with a schema string. We realized at the library initialization time we need to spend a lot of resources on schema parsing, according to the profiling results based on our prototype. We also noticed that the required information to instantiate a schema object are all available at codegen time, we can pass these data to the registry directly so that we can save time at runtime. For example: 214 | 215 | 216 | ``` 217 | CodegenUnboxing.cpp 218 | =================== 219 | RegisterOperators reg({ 220 | OperatorGenerator( 221 | "aten::get_device", // name 222 | "", // overload_name 223 | arguments, // a vector of arguments 224 | returns, // a vector of returns 225 | [](Stack & stack) { 226 | RECORD_FUNCTION("get_device", std::vector()); 227 | at::unboxing::get_device(stack); 228 | }, 229 | aliasAnalysisFromSchema() 230 | ), 231 | ... 232 | }) 233 | ``` 234 | 235 | 236 | This way we can directly instantiate `FunctionSchema` objects without parsing at runtime. Of course we need to change APIs in `operator.h` to make this happen. 237 | 238 | Q: Can we completely get rid of `FunctionSchema` and only register name/overload_name? 239 | 240 | A: No, because we should have feature parity to the current system and backward compatibility for mobile models is a feature we need to support for the lightweight dispatch system. Currently we rely on the number of arguments to let the new runtime be able to run the old model. 241 | 242 | 243 | #### Support tracing based selective build 244 | 245 | * In [gen.py](https://github.com/pytorch/pytorch/blob/master/tools/codegen/gen.py) the files we generate will go through the selector similar to what we are doing to `RegisterSchema.cpp` right now. 246 | * We need to make sure the binary size is on-par with or even better than existing tracing based selective build. 247 | 248 | ## Risks 249 | 250 | There are 3 risks: 251 | 252 | 253 | 254 | 1. Performance gain of using JIT op registry is insignificant or even worse than dispatcher. 255 | 1. De-risked: from the prototype running on a target platform it is proved to save latency on initial load. 256 | 2. Binary size regression. Need to make sure selective build works. 257 | 3. Mobile use case requires features only available on dispatcher. 258 | 1. E.g., boxed fallback mechanism for [conj](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/MathBitsFallback.h#L78) operator. 259 | 260 | ## Testing & Tooling Plan 261 | 262 | Expand existing tests on the mobile interpreter to cover codegen logic. Let the cpp test target depending on codegen unboxing library, test if a module forward result from JIT execution equals to the mobile interpreter execution. Since JIT execution goes through metaprogramming unboxing and mobile interpreter execution goes through codegen unboxing, we can make sure the correctness of codegen unboxing. 
Example: 263 | 264 | ```cpp 265 | TEST(LiteInterpreterTest, UpsampleNearest2d) { 266 | Module m("m"); 267 | m.define(R"( 268 | def forward(self, input: Tensor, scale:float): 269 | return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale)) 270 | )"); 271 | 272 | std::vector inputs; 273 | inputs.emplace_back(torch::rand({1, 3, 128, 128})); 274 | inputs.emplace_back(at::Scalar(2.0)); 275 | auto ref = m.forward(inputs); 276 | 277 | std::stringstream ss; 278 | m._save_for_mobile(ss); 279 | mobile::Module bc = _load_for_mobile(ss); 280 | IValue res; 281 | res = bc.forward(inputs); 282 | 283 | auto resd = res.toTensor(); 284 | auto refd = ref.toTensor(); 285 | ASSERT_TRUE(resd.equal(refd)); 286 | } 287 | ``` 288 | ## Appendix 289 | List of IValue types we need to support in unboxing: 290 | 291 | 'Device', 292 | 'Device?', 293 | 'Dimname', 294 | 'Dimname[1]', 295 | 'Dimname[]', 296 | 'Dimname[]?', 297 | 'Generator?', 298 | 'Layout?', 299 | 'MemoryFormat', 300 | 'MemoryFormat?', 301 | 'Scalar', 302 | 'Scalar?', 303 | 'ScalarType', 304 | 'ScalarType?', 305 | 'Scalar[]', 306 | 'Storage', 307 | 'Stream', 308 | 'Tensor', 309 | 'Tensor(a!)', 310 | 'Tensor(a!)[]', 311 | 'Tensor(a)', 312 | 'Tensor(b!)', 313 | 'Tensor(c!)', 314 | 'Tensor(d!)', 315 | 'Tensor?', 316 | 'Tensor?[]', 317 | 'Tensor[]', 318 | 'bool', 319 | 'bool?', 320 | 'bool[2]', 321 | 'bool[3]', 322 | 'bool[4]', 323 | 'float', 324 | 'float?', 325 | 'float[]?', 326 | 'int', 327 | 'int?', 328 | 'int[1]', 329 | 'int[1]?', 330 | 'int[2]', 331 | 'int[2]?', 332 | 'int[3]', 333 | 'int[4]', 334 | 'int[5]', 335 | 'int[6]', 336 | 'int[]', 337 | 'int[]?', 338 | 'str', 339 | 'str?' 340 | -------------------------------------------------------------------------------- /RFC-0024-assets/rfc-lifecycle.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 15 | 16 | COMMENTINGDRAFTSTALLEDACCEPTEDSHELVEDCLOSEDis idlemerge featurePRsneeds rework -------------------------------------------------------------------------------- /RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md: -------------------------------------------------------------------------------- 1 | # Summary 2 | Today PyTorch Quantization has well documented support for two backends: fbgemm (x86) or qnnpack (arm or x86), with some customizations exposed but no easy way for customers to quantize models for other backends. This document proposes how to extend PyTorch Quantization to properly support custom backends, such as Intel NNPI (A*), NVIDIA V-100/A-100 and others. 
We hope that this design will: 3 | * Allow for pytorch users to perform Quantization Aware Training or Post Training quantization for backends beyond server and mobile CPUs 4 | * Provide a simple and clear API for custom backend developers to integrate a custom backend with PyTorch Quantization 5 | * Provide a simple and clear API for model developers to quantize models targeting custom backends 6 | 7 | The workflow for custom backend developers who want to extend PyTorch Quantization to work on their backend looks like following: 8 | * Define the configurations for quantized operators (including fused and quantized operators) with backend_config_dict api 9 | * Define a lowering pass that transforms a model with reference quantized functions to a model that can be understood by the custom backend 10 | 11 | The workflow for model developers who want to quantize for a particular backend need to do the following: 12 | * Get the backend configuration for the custom backend 13 | * Specifying how the model should be quantized by defining the qconfig_dict, the qconfig_dict should be valid given the backend configuration 14 | Quantize the model 15 | * Lower the model to custom backend by calling the lowering function for the backend 16 | 17 | Note: This design is based on [FX Graph Mode Quantization](https://pytorch.org/docs/stable/quantization.html#quantization-api-summary). 18 | 19 | # Reference Quantized Model 20 | We introduce the concept of a reference pattern which serves as a standard format for quantized operators in all backends, a reference quantized model is a quantized model with these reference pattern. Reference patterns provide a close approximation to backends using fp32 ops and type conversion ops. If a more accurate match is desired, we need emulation operators that accurately model numerics for a backend. A reference quantized model serves two purposes: 21 | 1. Standard format for lowering quantized models 22 | 2. Emulate model numerics with approximate reference operators on a devserver for debugging. 23 | 24 | The property of a quantized operator can be decomposed into two dimensions: 25 | 1. Signature 26 | 2. Numerics 27 | 28 | Currently PyTorch quantization supports two backends: fbgemm (server x86 CPU) and qnnpack (ARM CPU and x86 on mobile) and they (almost) match in both dimensions, however, there might be other backends that differ in signature or numerics than what is provided by PyTorch Quantization right now. 29 | 30 | In general, when we have a new backend, there is no guarantee that the quantized operators supported by the new backend would match in either of the dimensions, so we propose to add an extra layer of indirection between the model produced by the quantization flow and the model that’s actually used for execution in the backends. 31 | 32 | Therefore, quantization flow can produce a model with reference patterns (dequant - float_op - quant). And there will be an extra step to lower this to a model that is executable by various backends, including the current native backends in PyTorch (fbgemm/qnnpack). One thing to call out here is that we did not address some behavior of the custom backend, for example fp16/int16 accumulations since that would require changes to the implementation of reference pattern. 33 | 34 | Here are examples of reference pattern for single operator and fused operators: 35 | ```python 36 | # single operator with torch op/functional 37 | def forward(x): 38 | ... 
39 | x = x.dequantize() 40 | x = torch.sigmoid(x) 41 | x = torch.quantize_per_tensor(x, scale, zero_point, dtype) 42 | ... 43 | 44 | # single operator with module 45 | def forward(x): 46 | ... 47 | x = x.dequantize() 48 | x = self.sigmoid(x) 49 | x = torch.quantize_per_tensor(x, scale, zero_point, dtype) 50 | ... 51 | 52 | # fused operators 53 | def forward(x): 54 | ... 55 | x = x.dequantize() 56 | x = self.conv2d(x) 57 | x = torch.nn.functional.relu(x) 58 | x = torch.quantize_per_tensor(x, scale, zero_point, dtype) 59 | ... 60 | ``` 61 | 62 | 63 | 64 | # Quantization Workflow 65 | ![Quantization Workflow](https://docs.google.com/drawings/d/e/2PACX-1vQ6tgl6MOkSPLcVV4ZRWwCBLebj-ugMIpBJgw8OL2FxqYg2u5rpp8UKSVQUg_Ie1HsHyVJf3A5dPIb_/pub?w=950&h=700) 66 | 67 | As we can see from the above diagram, we’ll separate the generation of a quantized model in PyTorch and the actual runnable quantized model on a specific backend. PyTorch will produce a reference quantized model which contains reference patterns that can be fused into quantized functions, this model will act as a unified representation of a quantized model and we do not give guarantees on either numerics or performance. 68 | 69 | Accuracy of a model on a specific backend can be emulated using reference ops as long as the numerics of the backend are well approximated by reference patterns (i.e by a sequence of dequant-fp32-quant ops). If this is not the case, then reference patterns need to be lowered to numerically accurate emulation functions for purposes of emulating accuracy. 70 | 71 | To get a model runnable on a specific backends, we will need to have an extra lowering step that transforms this reference model to a backend specific model (a model that only runs on that backend, for example: fbgemm/qnnpack). We may also transform the reference model to a backend with fake ops that simulates the numerics of the ops that run on backends. 72 | Backend Configurations 73 | A backend is a hardware or kernel library (NNPI, FBGEMM, QNNPACK etc.), each hardware/kernel library has a set of settings that can differ from the default setting, We can define them by following: 74 | 1. **Quantization Scheme** (symmetric vs asymmetric, per-channel vs per-tensor) 75 | 2. **Data Type** (float32, float16, int8, int8, bfloat16, etc) 76 | 3. **Quantized (and Fused) Operators and Mapping** The quantized operators supported by the backend. For example: quantized conv2d, quantized linear etc. 77 | Some quantized operators may have different numerics compared to a naive (dequant - float_op - quant) implementation 78 | For weighted operators (conv and linear) we need to define a reference module and a mapping 79 | 4. **QAT Module Mapping** For modules with weights, e.g. Conv2d and Linear, we need to swap them with qat (quantization aware training) module that adds fake quantization to the weights 80 | 81 | Note that this is general to all backends, not just custom backends. Current default backends in PyTorch (fbgemm/qnnpack) can be defined in the above terms as well. 

| Setting | fbgemm/qnnpack (supported by default) |
| ---- | ---- |
| Quantization Scheme | activation: per tensor, weight: per tensor or per channel |
| Data Type | activation: quint8 (reduce_range for fbgemm), weight: qint8 |
| Quantized (and Fused) Operators and Mapping | For example: nn.Conv2d → torch.ao.nn.quantized.reference.Conv2d |
| QAT Module Mapping | Conv, Conv - Bn, Linear etc. |

# Proposed APIs for Custom Backend Developers (Demo Purpose Only)
```python
from enum import Enum
# Demonstration purpose only, the API is subject to change
class QuantizedOperatorType(Enum):
    NEED_OBSERVER_FOR_BOTH_INPUTS_AND_OUTPUTS = 0
    OUTPUT_IS_SHARING_OBSERVER_WITH_INPUT = 1

conv_module_config = {
    "type": QuantizedOperatorType.NEED_OBSERVER_FOR_BOTH_INPUTS_AND_OUTPUTS,
    "float_to_quantized_operator_mapping": {
        # contains mapping for float op and pattern to reference function or reference module
        "static": {
            torch.nn.Conv2d: torch.ao.nn.quantized.reference.Conv2d,
            (torch.nn.ReLU, torch.nn.Conv2d): torch.ao.nn.quantized.reference.ConvReLU2d,
            (torch.nn.ReLU, torch.nn.qat.Conv2d): torch.ao.nn.quantized.reference.ConvReLU2d,
            torch.nn.intrinsic.qat.ConvBn2d: torch.ao.nn.quantized.reference.Conv2d
        },
        "qat_mapping": {
            "static": {
                torch.nn.Conv2d: torch.nn.qat.Conv2d,
                (torch.nn.BatchNorm2d, torch.nn.Conv2d): torch.nn.intrinsic.qat.ConvBn2d
            }
        }
    }
}

conv_functional_config = {
    "type": QuantizedOperatorType.NEED_OBSERVER_FOR_BOTH_INPUTS_AND_OUTPUTS,
}

bmm_config = {
    "type": QuantizedOperatorType.NEED_OBSERVER_FOR_BOTH_INPUTS_AND_OUTPUTS,
    "fusions": [(torch.nn.Softmax, torch.bmm)]  # quantized_bmm_softmax
}

custom_backend_config_dict = {
    # optional
    "name": "custom_backend",
    # quantized operator config is a map from
    # module/functional/torch ops to their configurations
    "operator": {
        torch.nn.Conv2d: conv_module_config,
        torch.nn.functional.conv2d: conv_functional_config,
        torch.bmm: bmm_config
    }
}

# define a function to return the backend config dict
def get_custom_backend_config_dict():
    return custom_backend_config_dict

# We'll also provide utility functions to get the backend configurations for a given backend:
my_backend_config_dict = get_my_backend_config_dict()
```

Note that the APIs here are for demo purposes only; for the most up-to-date APIs, please refer to the code: https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/fx/backend_config_dict. We'll also have tutorials later when this is fully implemented.

# Proposed APIs for Model Developers
```python
from torch.quantization.quantize_fx import prepare_fx, convert_to_reference_fx
from custom_backend_library import get_my_backend_config_dict, lower_to_custom_backend

backend_config_dict = get_my_backend_config_dict()
model = prepare_fx(model, qconfig_dict, prepare_custom_config_dict=..., backend_config_dict=backend_config_dict)
# calibration
...
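# The calibration step typically just runs the observed model on representative
# data so the inserted observers can record activation statistics. A sketch,
# where `calibration_data` is a hypothetical iterable of example inputs:
#
#     with torch.no_grad():
#         for example_inputs in calibration_data:
#             model(example_inputs)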
model = convert_to_reference_fx(model, convert_custom_config_dict=..., backend_config_dict=backend_config_dict)

# get the lower_to_custom_backend function defined by custom backend developers and call it
# to lower a Reference Quantized Model to a model that runs on the custom backend
model = lower_to_custom_backend(model)
```

# Use Cases
## Use Case 1: Quantizing a Model for Inference on Server/Mobile
```python
from torch.quantization.quantize_fx import prepare_fx, convert_to_reference_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping

model = model.eval()
qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_qconfig)
model = prepare_fx(model, qconfig_mapping)
calibration(model, ...)
model = convert_to_reference_fx(model)
```

The model produced here is a reference model that contains reference patterns. It is runnable, since it uses quantize_per_tensor/dequantize/floating point operators to simulate quantized operators. For numerics, it provides an approximation to backend numerics, even though it may not have exactly the same numerics as any backend, nor the same speedup as the backends.

### Backend Lowering (fbgemm/qnnpack)
```python
from torch.quantization.quantize_fx import prepare_fx, convert_to_reference_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping

model = model.eval()
qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_qconfig)
model = prepare_fx(model, qconfig_mapping)
calibration(model, ...)
model = convert_to_reference_fx(model)

# This step will transform a model with reference patterns to a model with
# fbgemm/qnnpack ops, e.g. torch.ops.quantized.conv2d
model = lower_to_fbgemm_fx(model)  # or lower_to_qnnpack_fx(model)
```


### Example Implementation lower_to_fbgemm_fx

```python
import torch.fx
from torch.fx import subgraph_rewriter
from torch.quantization.fx.graph_module import QuantizedGraphModule

def relu_pattern(x, scale, zero_point):
    x = x.dequantize()
    x = torch.nn.functional.relu(x)
    x = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    return x

def relu_replacement(x, scale, zero_point):
    x = torch.nn.functional.relu(x)
    return x


def _get_all_patterns_and_replacements():
    return [
        (relu_pattern, relu_replacement)
    ]


def get_fbgemm_patterns_and_replacements():
    return _get_all_patterns_and_replacements()


def _lower_to_native_backend(model: QuantizedGraphModule) -> torch.nn.Module:
    """ Lower a quantized reference model (with reference quantized operator patterns)
    to the native backend in PyTorch (fbgemm/qnnpack); both backends share the same
    operator signature so they can be lowered with the same function
    """
    module_dict = dict(model.named_modules())
    for pattern, replacement in get_fbgemm_patterns_and_replacements():
        subgraph_rewriter.replace_pattern(model, pattern, replacement)
    model.graph.lint()
    return model

def lower_to_fbgemm_fx(model: QuantizedGraphModule) -> torch.nn.Module:
    return _lower_to_native_backend(model)

def lower_to_qnnpack_fx(model: QuantizedGraphModule) -> torch.nn.Module:
    return _lower_to_native_backend(model)
```

As shown in the code, to add a
new lowering pass for a specific backend we need to register the backend quantized operator in the `torch.ops` namespace; this can be achieved with PyTorch custom operator registration: https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/aten/src/ATen/native/quantized/cpu/qconv.cpp#L881.

## Use Case 2: Quantizing the Same Model for Inference on Custom Backend
If a backend does not need to modify the quantization flow, that is, it does not have extra quantized operators, fused quantized operators or customizations of the quantization flow, we only need to write a lowering pass to transform a reference quantized model to a model runnable on the custom backend.

```python
from torch.quantization.quantize_fx import prepare_fx, convert_to_reference_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping

model = model.eval()
qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_qconfig)
model = prepare_fx(model, qconfig_mapping)
calibration(model, ...)
model = convert_to_reference_fx(model)

# This optional step will transform a model with reference
# functions to a model with fakeNNPI ops, e.g. torch.ops.fakeNNPI.sigmoidFP16
# This is useful for bit exact model emulation on a server.

fake_nnpi_model = lower_to_fakennpi_fx(model)

# This function will transform the model with reference patterns to a model with nnpi ops (need to register NNPI ops in the torch.ops namespace)
nnpi_model = lower_to_nnpi_fx(model)
```


There are multiple lowering options:
* **Lowering with `torch.fx`** We can lower directly with `torch.fx` transformations; see the example implementation of `lower_to_fbgemm_fx` above for lowering in FX. For this we need to make sure all backend operators are exposed in the `torch.ops` namespace, for example: `torch.ops.custom_backend.quantized_conv2d`
* **Lowering with TorchScript** We can also first script/trace the model and then lower with TorchScript. We need to make sure the TorchScript pass can recognize the reference quantized functions; for example, this can be done by looking at the name of a CallFunction node
* **Lowering with a combination of 1 and 2** People can also do some transformations in `torch.fx`, then script/trace the model and do some other transformations in TorchScript. For example, in `torch.fx` we can transform reference quantized functions to the actual (dequant - float_op - quant) pattern, and in TorchScript we can fuse the patterns
* **Custom Lowering** Users can also provide their own lowering functions that do not use `torch.fx` or TorchScript and that transform a Reference Quantized Model to a model that’s runnable on the target backend. Basically, the Reference Quantized Model is the standard format that backend developers can expect as input.

## Use Case 3: Extending Quantization to Support Custom Quantized Operators and Fusions for Inference in Custom Backend

If a backend supports more quantized operators, more fusions, or different ways of quantization, we will need to customize the flow. We only provide limited support for quantized (and fused) operator configurations at this point, and we'll work on improving the support in the future to allow arbitrary custom behaviors.

Here we'll give an example of supporting custom quantized operators and fusions.

### Custom Operator Support
If you need to support an operator that does not have a reference operator implementation, you will need to implement a custom function/module to specify the quantized operator implementation. In addition, you will need to write a handler class that specifies how to convert an observed node to a quantized node.

#### Example: Quantized BMM
To add support for new quantized ops, we need to do the following:
* Get the fbgemm config dictionary, add the new entry for the operator to it, and expose the backend config with a function
* Define/modify the lowering function of the custom backend so that it can fuse the pattern

```python
from torch.quantization.quantize_fx import prepare_fx
from torch.quantization.quantize_fx import get_backend_config_dict, QuantizedOperatorType, update_operator_config
import torch.fx


# Let's say we want to modify the backend_config of fbgemm and add a new quantized operator
custom_backend_config_dict = get_backend_config_dict("fbgemm")
bmm_config = {
    "type": QuantizedOperatorType.NEED_OBSERVER_FOR_BOTH_INPUTS_AND_OUTPUTS,
}
update_operator_config(custom_backend_config_dict, torch.bmm, bmm_config)
custom_backend_config_dict["name"] = "custom_backend"

# Define/modify the lowering function to support lowering quantized_bmm to the custom backend.
# We now expect quantized models to have `dequant - torch.bmm - quant` patterns in the model
# if the input model contains a `torch.bmm` operator.
```



### Custom Fusion Support
For custom fusions, we need to do the following:
* Get the fbgemm config dictionary, add the new entry for the fused op to it, and expose the backend config with a function
* Define/modify the lowering function of the custom backend so that it can fuse the pattern

#### Example: Quantized Fused BMM and Softmax

```python
from torch.quantization.quantize_fx import prepare_fx
from torch.quantization.quantize_fx import get_backend_config_dict, QuantizedOperatorType, update_operator_config
import torch.fx

# Let's say we want to modify the backend_config of fbgemm and add new quantized operators
custom_backend_config_dict = get_backend_config_dict("fbgemm").copy()
# Notice that the order of ops in the fused pattern is reversed: patterns are specified in
# reverse execution order, so (torch.softmax, torch.bmm) matches a graph where torch.bmm
# is followed by torch.softmax
bmm_config = {
    "type": QuantizedOperatorType.NEED_OBSERVER_FOR_BOTH_INPUTS_AND_OUTPUTS,
}
update_operator_config(custom_backend_config_dict, torch.bmm, bmm_config)
update_operator_config(custom_backend_config_dict, (torch.softmax, torch.bmm), bmm_config)
custom_backend_config_dict["name"] = "custom_backend"

# Define/modify the lowering function to support lowering the patterns for quantized_bmm
# and quantized_bmm_softmax to the custom backend.
# We now expect quantized models to have `dequant - torch.bmm - quant`
# and `dequant - torch.bmm - torch.softmax - quant` patterns.
```

# QConfig, QConfigMapping and BackendConfig
## [QConfig](https://pytorch.org/docs/master/generated/torch.quantization.qconfig.QConfig.html#torch.quantization.qconfig.QConfig)
QConfig is what we use to specify how to observe an operator (e.g. conv) or operator pattern (e.g.
conv - relu) in the model, for example:
```
qconfig = QConfig(activation=HistogramObserver(dtype=torch.quint8, quant_min=0, quant_max=255), weight=PerChannelMinMaxObserver(dtype=torch.qint8, quant_min=-128, quant_max=127))
```
means we want to insert `HistogramObserver` for the input and output of the operator/operator pattern (in FX Graph Mode Quantization), and insert `PerChannelMinMaxObserver` for the weight of the operator/operator pattern. Note that this is relatively restrictive today, since we can only specify a single observer that is shared by the input and the output, plus one for the weight; we plan to make this more general in the future to support specifying different observers for different inputs and outputs. Example:
```
QConfig(input_args=(input_act_obs,), input_kwargs={}, attribute={"weight": weight_obs, "bias": bias_obs}, output=output_act_obs)
```

## [QConfigMapping](https://pytorch.org/docs/master/generated/torch.ao.quantization.qconfig_mapping.QConfigMapping.html#torch.ao.quantization.qconfig_mapping.QConfigMapping)
`QConfigMapping` is used to configure how to quantize a model. It is a set of rules applied to the model graph to decide what the `QConfig` should be for each operator or operator pattern; right now we mostly support operators, but in the future we'll support operator patterns as well.
Example:
```
qconfig_mapping = (QConfigMapping()
    .set_global(global_qconfig)
    .set_object_type(torch.nn.Linear, qconfig1)
    .set_object_type(torch.nn.ReLU, qconfig1)
    .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
    .set_module_name_regex("foo.*", qconfig2)
    .set_module_name("module1", qconfig1)
    .set_module_name("module2", qconfig2)
    .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3))
```
Setting `qconfig1` for `torch.nn.Linear` means that whenever we see a `call_module` node that refers to an instance of `torch.nn.Linear`, we'll apply `qconfig1`: we'll insert input and output observers for this `call_module` node based on the observer constructors specified in `qconfig1`.

Note that we only apply a QConfig if the configuration is supported by `BackendConfig`; for more details please see https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/README.md#13-quantdequantstub-and-observerfakequantize-insertion

## [BackendConfig](https://pytorch.org/docs/master/generated/torch.ao.quantization.backend_config.BackendConfig.html#torch.ao.quantization.backend_config.BackendConfig)
`BackendConfig` is used to configure which ways of quantization are supported for an operator or operator pattern; e.g. a backend may support both static and dynamic quantized linear, and it can express this information with `BackendConfig`. A modeling user may also request a type of quantization that is not supported by the backend. In general, there are three possible cases:
(1). The operator or operator pattern is recognized (the entry exists in BackendConfig) and the requested type of quantization is supported (e.g. static int8 quantization): we'll insert observers based on the `QConfig`
(2). The operator or operator pattern is recognized, but the requested type of quantization (e.g. requested fp16, but only static int8 is supported) is not supported in BackendConfig: we'll print a warning
(3).
The operator or operator pattern is not recognized, we'll ignore the request 363 | 364 | # Appendix 365 | * More explanations on patterns: https://docs.google.com/document/d/1kSM0n05vI5Y939n2U3YPD5MWOQNkgXMTpt10ASU-Kns/edit 366 | -------------------------------------------------------------------------------- /RFC-0001-torch-function-for-methods.md: -------------------------------------------------------------------------------- 1 | | | | 2 | | ---------- | ----------------------------------------------- | 3 | | Authors | Hameer Abbasi, Edward Z. Yang and Ralf Gommers | 4 | | Status | Accepted | 5 | | Type | Proposal | 6 | | Created | 2020-01-24 | 7 | | Resolution | TBD | 8 | 9 | # Improving subclassing Tensor by propagating subclass instances 10 | This RFC describes changes necessary to allow `__torch_function__` to be used 11 | by methods of `torch.Tensor` in an attempt to make subclassing more accessible 12 | to the users of the class. This entails making an API for subclass views 13 | public, and a change in the signature of `__torch_function__`. 14 | 15 | ## Motivation and Scope 16 | Quoting [[1]], [[2]] and [[3]], the goals of this proposal are: 17 | 18 | 1. Support subclassing `torch.Tensor` in Python 19 | 2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them 20 | 3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor` 21 | subclasses 22 | 4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. 23 | 5. Propagating subclass instances correctly also with operators, using 24 | views/slices/indexing/etc. 25 | 6. Preserve subclass attributes when using methods or views/slices/indexing. 26 | 7. A way to insert code that operates on both functions and methods uniformly 27 | (so we can write a single function that overrides all operators). 28 | 8. The ability to give external libraries a way to also define 29 | functions/methods that follow the `__torch_function__` protocol. 30 | 31 | Goals 1‒6 are explicitly about subclassing, goal 7 is already partially achieved via the `__torch_function__` protocol (which we're proposing to extend to methods), and goal 8 is a by-product required to make overridden `torch.Tensor` subclass methods behave similar to `torch.Tensor` methods. 32 | 33 | Achieving interoperability with NumPy and adopting its array protocols is out 34 | of scope for this proposal and we propose to defer it to a later proposal. 35 | 36 | We propose to solve this problem with the following changes to PyTorch: 37 | 38 | 1. Make methods, operators and properties of `torch.Tensor` go through the 39 | `__torch_function__` machinery. 40 | 2. Add a `types` argument to `__torch_function__`, to make it match NumPy's 41 | `__array_function__`. 42 | 3. Add a new method to `torch.Tensor`, `as_subclass`, which creates a subtype 43 | _view_ into the original object. 44 | 4. Make `torch.Tensor` gain a generic implementation of `__torch_function__`. 45 | 46 | ## Usage and Impact 47 | Once this proposal is merged, users of subclasses of `torch.Tensor` will have 48 | a much more streamlined experience. 
Namely, the following code example will 49 | work as-is, without the need for any further modification: 50 | 51 | ```python 52 | class SubTensor(torch.Tensor): 53 | a = 1 54 | 55 | t = SubTensor([1]) 56 | s = t.sum() 57 | isinstance(s, SubTensor) # True 58 | s.a # 1 59 | i = t[0] 60 | isinstance(i, SubTensor) # True 61 | i.a # 1 62 | 63 | s2 = t + torch.Tensor(1) 64 | isinstance(s2, SubTensor) # True 65 | s2.a # 1 66 | 67 | s3 = torch.Tensor(1) + t 68 | isinstance(s3, SubTensor) # True 69 | s3.a # 1 70 | ``` 71 | 72 | Additionally, it will provide subclass authors the ability to also modify the 73 | results of methods, operators and properties in `__torch_function__`, along with 74 | regular function calls, and to modify the result to their specific use-case, 75 | perform logging, or otherwise change the result or the action of the method. 76 | For example: 77 | 78 | ```python 79 | import logging 80 | 81 | class LoggingTensor(torch.Tensor): 82 | @classmethod 83 | def __torch_function__(cls, func, types, args, kwargs): 84 | logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}") 85 | return super().__torch_function__( 86 | func, 87 | types, 88 | args, 89 | kwargs 90 | ) 91 | ``` 92 | 93 | Assuming minimum logging level is set to `logging.INFO`, the following 94 | indicates the code run, with the logging output in the comments. 95 | 96 | ```python 97 | t = LoggingTensor([1]) 98 | 99 | t.sum() # Tensor.sum, (LoggingTensor([1]),), {} 100 | t[0] # Tensor.__getitem__, (LoggingTensor([1]), 0,), {} 101 | 102 | # This is already possible 103 | torch.sum(t) # sum, (LoggingTensor([1]),), {} 104 | ``` 105 | 106 | To make the protocol operate only on functions rather than methods, one can 107 | check for `func not in type(self).__dict__.values()`. To check for operators 108 | and/or indexing, one can check `func.__name__.endswith("__")`. 109 | 110 | ### Performance 111 | There are a few requirements for the performance of this proposal, when 112 | implemented: 113 | 114 | 1. No deterioration for function/method calls on `torch.Tensor` objects. 115 | 2. No deterioration of current `__torch_function__` overhead 116 | 3. Sub-µs impact on the performance of subclasses not implementing 117 | `__torch_function__`. 118 | 119 | Requirement 1 seems unachievable due to the structure of the code at this 120 | point, as: 121 | 122 | 1. In methods defined in C++, `self` is excluded from the argument processing 123 | that gathers `Tensor`-likes in C++. 124 | 2. Similar to point 1, C++ methods that take only `self` as a `Tensor`-like don't 125 | pass through this processing, and they will be required to. 126 | 3. For methods defined in Python, the processing for handling `__torch_function__` 127 | will need to be added, similar to the original `__torch_function__` PR [[5]]. 128 | 129 | We think an overhead of sub-100 ns per method call is feasible. 130 | 131 | ## Backwards Compatibility 132 | ### With PyTorch `master` as of writing 133 | PyTorch `master` pointed to commit hash 134 | `957a07ffbd13d8a805f4d718e0282efc5d2bff85` at the time of writing. Any classes 135 | implementing `__torch_function__` based on the usage in this commit hash will 136 | break completely, due to the differing signature of the protocol. However, as a 137 | release hasn't been made with `__torch_function__` in it, this is a minor- 138 | impact issue. 
This brings the design of `__torch_function__` more in line with 139 | NumPy's `__array_function__`, and one familiar with NumPy's protocol could 140 | transition to PyTorch's take on it without too many surprises, with the caveat 141 | that it could also receive methods rather than functions. The release that 142 | `__torch_function__` will make it into PyTorch is expected to be 1.5.0. 143 | 144 | ### With NumPy 145 | The implementation of this proposal will have no effect on how things interact with NumPy. 146 | 147 | ## Detailed Description 148 | ### Introduction 149 | Subclasses are an important way to override functionality of classes. Given the 150 | popularity of PyTorch, a number of subclasses have sprung up, both within and 151 | outside PyTorch. It is important that functions operating on `torch.Tensor`, as 152 | well as methods on it, support passing through the appropriate subclasses, 153 | otherwise information about which type was passed into the function is lost. 154 | The same applies equally, if not more so, to operators and indexing. 155 | 156 | In addition, there has been interest in adding a "universal hook" that operated 157 | on both functions and methods, perhaps modifying the control flow before 158 | returning the result. Such a hook already exists today in the form of 159 | `__torch_function__`, however, it only operates on functions and not on 160 | methods, and support for subclassed `torch.Tensor` objects in this protocol is 161 | limited. 162 | 163 | ### Proposal 164 | We propose the following signature change to `__torch_function__`, to make it 165 | match NumPy, other than the `@classmethod` decorator: [[4]] 166 | 167 | ```python 168 | class SubTensor(torch.Tensor): 169 | @classmethod 170 | def __torch_function__(cls, func, types, args, kwargs): 171 | # Implementation here 172 | ``` 173 | 174 | The reason for adding `types` to the signature is necessitated so we can 175 | check for support of the types if `Tensor`-likes coming in and we do not 176 | mix unrelated class trees. 177 | 178 | ### Process followed during a function/method call 179 | 180 | The process followed during a function/method call would be equivalent to: 181 | 182 | 1. The dispatcher is called to extract the `Tensor`-likes. 183 | 2. All `Tensor`-likes are checked for `__torch_function__`. If none exist, the 184 | internal implementation is called, and the final result is returned. 185 | 3. A collection of types that implement `__torch_function__` is created, with 186 | no guaranteed order other than that subclasses come before superclasses. 187 | 4. For one instance of each type in `types`, `__torch_function__` is called. 188 | The first such function or method to return something other than 189 | `NotImplemented` will be the final result. All exceptions will be propagated 190 | upward. 191 | 5. If all `__torch_function__` implementations return `NotImplemented`, a 192 | `TypeError` is raised with an appropriate error message. 193 | 194 | In practice, for most PyTorch functions, the list of tensor-likes is already 195 | available and the dispatcher doesn't need to be called. Additionally, while 196 | equivalent to the code above, if the `Tensor`-likes are all `Tensor` or don't have 197 | an `__torch_function__` implementation, the internal implementation is called 198 | immediately. This is done as a performance optimisation to avoid overhead for 199 | concrete `Tensor` objects. 
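For illustration only, here is a rough Python sketch of the five steps above. It is not the actual implementation (the real dispatch lives in C++), the function name is made up for the sketch, `func._implementation` is the same internal handle used by the generic implementation shown in the next section, and the `tensor_likes` argument is assumed to already be the output of the dispatcher (step 1):

```python
import torch

def run_torch_function_protocol(func, tensor_likes, args, kwargs):
    # Step 2: if nothing overrides __torch_function__, call the internal implementation.
    overriding = [a for a in tensor_likes
                  if type(a) is not torch.Tensor and hasattr(type(a), "__torch_function__")]
    if not overriding:
        return func._implementation(*args, **kwargs)

    # Step 3: collect the implementing types, subclasses before superclasses.
    types = []
    for a in overriding:
        t = type(a)
        if t in types:
            continue
        index = next((i for i, u in enumerate(types) if issubclass(t, u)), len(types))
        types.insert(index, t)

    # Step 4: since __torch_function__ is a classmethod in this proposal, calling it
    # on the type is equivalent to calling it on one instance of that type; the first
    # result that is not NotImplemented wins.
    for t in types:
        result = t.__torch_function__(func, tuple(types), args, kwargs)
        if result is not NotImplemented:
            return result

    # Step 5: no implementation handled the call.
    raise TypeError(f"no implementation of {func} found for types {types}")
```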

It will be the job of the dispatcher to extract `Tensor`-like objects from the
argument list; however, arguments of type `Optional[Tensor]` will be considered
`Tensor`-like. If one gets a compound or dependent type such as `List[Tensor]`
or `Tuple[Tensor, ...]` or `Tuple[Tensor, int]`, the dispatcher will have the job
of extracting an iterable of objects that *could* be `Tensor`-like.

### Generic implementation of `__torch_function__`
`torch.Tensor` will gain a generic `__torch_function__` of the following form:

```python
class Tensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Defer to internal implementation
        ret = func._implementation(*args, **kwargs)
        if cls is not Tensor and isinstance(ret, Tensor):
            ret = ret.as_subclass(cls)
        return ret
```

This method has the effect of passing subclasses through all
functions/methods as intended.

This corresponds roughly to the implementation `numpy.ndarray` gains in [[4]],
except for the fact that subclasses are passed through via another internal
mechanism (namely the `__array_finalize__` protocol) there, as well as the fact
that we are checking subclassing against `cls` instead of `Tensor`. This
has the side-effect of ensuring unrelated class trees are not merged, which is
an inconsistency in NumPy's own design. Specifically, consider the example of
two direct subclasses of `torch.Tensor`. Both will return `NotImplemented`, and
therefore, the check will fail and `TypeError` will be raised.

Since subclasses are checked before superclasses in `__torch_function__`, it is
guaranteed that the subclass implementation will be called first. In this
instance, since `cls` is a subclass of all types, the code will
continue. Since `cls` is not `torch.Tensor`, a view into the original
data is created and returned.

This also works for all operators: `__add__`, `__getitem__` and so on, since in
Python these operators are just dunder methods of the corresponding class.

### Checking for compatibility
One can check for compatibility with supported classes in the following manner:

```python
class MyTensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if not all(issubclass(t, HANDLED_CLASSES) for t in types):
            return NotImplemented
        # Do further processing here.

# Defined after the class body so that the tuple can refer to MyTensor itself.
HANDLED_CLASSES = (MyTensor, Tensor, ...)
```

### Implementing a subset of the API
One can follow the procedure below to implement a subset of the
API, using a hash map from torch functions to your own implementations:

```python
_TORCH_IMPLEMENTATIONS = {}

def implements(torch_function):
    def inner(f):
        _TORCH_IMPLEMENTATIONS[torch_function] = f
        return f
    return inner

@implements(torch.add)
def my_add(self, other):
    # Implementation here
    ...

class MyTensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        compatible = ...
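        # One possible (hypothetical) check, mirroring the "Checking for
        # compatibility" section above:
        # compatible = all(issubclass(t, HANDLED_CLASSES) for t in types)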
        if not compatible:
            return NotImplemented

        if func not in _TORCH_IMPLEMENTATIONS:
            return NotImplemented

        return _TORCH_IMPLEMENTATIONS[func](*args, **kwargs)
```

### The need for `super().__torch_function__`
To access super, one would do the following:
```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Pre-processing here
        val = super().__torch_function__(
            func,
            types,
            args,
            kwargs
        )
        # Post processing here
        return val
```

To make the need for `super()` concrete, let's consider the
following scenario:

```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(...):
        # Pre-processing
        ret = super().__torch_function__(
            func,
            types,
            args,
            kwargs
        )
        # Post processing
        return ret

class SubSubTensor(SubTensor):
    def __add__(self, other):
        # Pre-processing
        ret = super().__add__(other)
        # Post-processing
        return ret
```

In this instance, processing would follow the `__torch_function__` protocol.
This means that control would end up in `SubSubTensor.__add__`, go to `Tensor.__add__`,
then to `SubTensor.__torch_function__`, and then come to
`Tensor.__torch_function__`, from where it would go to the internal implementation
of `Tensor.__add__`, and then back up the stack in the reverse order. This means that
great care needs to be taken when writing `SubTensor.__torch_function__`
to take into account the fact that it has to handle subclass methods.

In general, control flow will follow this pattern:

![Control flow diagram](./RFC-0001-assets/dispatch-flow.svg)

The reason we use `super().__torch_function__` instead of `func` directly is

1. We do not know if there are other `Tensor`-likes that may need to be
   handled.
2. Calling `func` directly would dispatch back to `__torch_function__`,
   leading to an infinite recursion.

### Protocol support for external libraries
We will also recommend that all `Tensor` subclasses make their own methods that
do not exist on `torch.Tensor` go through `__torch_function__` via a decorator
`@torch_function_dispatch`. This decorator was added and then removed for
performance reasons, however it will be added back to allow external libraries
to interface with the protocol. It will take a single argument: a dispatcher,
i.e. a callable that returns an iterable of all the "duck-`Tensor`s", or
possible candidates for classes that may implement `__torch_function__`.

If a library forgets to add the aforementioned decorator, then the method will
no longer dispatch at all to any form of `__torch_function__`. In other words,
it will lose support for the protocol. This can lead to confusion, as some
methods of the subclass will pass through `__torch_function__` (the ones
inherited from `torch.Tensor`), and some won't.

Note that subclasses will still be passed through due to the default
implementation of `__torch_function__`, but any `__torch_function__` defined on
the class itself (or any of its subclasses) won't have an effect on its
methods.
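As an illustration (not part of the RFC's own text), here is a sketch of how a library author might opt one of their own methods into the protocol with the proposed decorator. The decorator is not public API today, so the import below is hypothetical, as are the `my_norm` and `_my_norm_dispatcher` names:

```python
import torch
# Hypothetical import path; the decorator is only proposed in this RFC.
from torch.overrides import torch_function_dispatch

def _my_norm_dispatcher(self, p=2):
    # The dispatcher returns the arguments that may implement __torch_function__.
    return (self,)

class MyLibraryTensor(torch.Tensor):
    @torch_function_dispatch(_my_norm_dispatcher)
    def my_norm(self, p=2):
        # A method that does not exist on torch.Tensor; the decorator routes
        # the call through __torch_function__ before this body runs.
        return torch.linalg.norm(self, ord=p)
```

With this in place, a `__torch_function__` defined on `MyLibraryTensor` (or a further subclass) would see calls to `my_norm` just like it sees calls to methods inherited from `torch.Tensor`.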

This is a design choice that a subclass author will have to make: whether they
prefer their own functions/methods to pass through `__torch_function__` like
PyTorch's implementations, or whether they'd ultimately like to not support the
protocol and accept having a mix of overridable and non-overridable methods.

We do not propose automatic marking of functions with this decorator due to the
potential backwards-compatibility break it could cause, as well as the
parameters that are needed in order to allow this to happen (namely the
dispatcher, which isn't in our control).

### Getting the method from its `__name__` and `__module__`
To construct the function given its `__name__` and `__module__`, one can do
the following, as an example:

```python
def get_function(name, module):
    func = __import__(module)
    for n in name.split('.'):
        func = getattr(func, n)
    return func
```

### Adding the `torch.Tensor.as_subclass` method
The `torch.Tensor.as_subclass` method will be added, taking a single non-`self`
argument: `cls`, the class for which an instance will be created with a view
into the data of the original `Tensor`. It will become public API. This method
will create an object that has the same data pointer as the original object,
which means that modifications to this will be reflected in the original object.
More or less, it will have the same effect as modifying an object's `__class__`
attribute in Python.

This method is already used in external libraries, and they may need it as a
way to e.g. bypass the processing of `torch.Tensor.__torch_function__`
entirely, while still creating `torch.Tensor` subclasses in their own code.

## Implementation
Implementing this proposal requires three main steps:

1. Add a `types` argument to `__torch_function__` and make sure that _only_
   arguments that are instances of a type in `types` are processed.
2. Make sure that all `Tensor` methods except `__new__` and `__init__` go
   through `__torch_function__`.
3. Add `Tensor.as_subclass` and `@torch_function_dispatch` as public API.

### Implementing only some methods but not others
One can use the dictionary idiom to only implement some methods but not others.
A code example follows:

```python
HANDLED_FUNCTIONS = {}

def implements(func):
    def inner(implementation):
        HANDLED_FUNCTIONS[func] = implementation
        return implementation
    return inner

@implements(torch.add)
def my_add(self, other):
    ...

class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        implementation = HANDLED_FUNCTIONS.get(func, None)
        if implementation is None:
            return NotImplemented

        return implementation(*args, **kwargs)
```

For subclasses, one can also choose to use the fallback implementation if
a specialized implementation isn't available using `super`, as shown below.

```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        implementation = HANDLED_FUNCTIONS.get(func, None)
        if implementation is None:
            return super().__torch_function__(
                func, types, args, kwargs
            )

        return implementation(*args, **kwargs)
```

A call to `super().__torch_function__` can also be used to call the fallback
implementation within any other function.

The examples we have seen here actually specify what we anticipate will be two
common patterns of using `__torch_function__`: `LoggingTensor` is an example
of a global hook, and the two examples above show a way to achieve specialised
implementations of particular functions.

### Wrapping `torch.Tensor`
Sometimes it's useful to wrap `torch.Tensor` rather than have a subclass.
The following class shows how this is possible in practice:

```python
import functools

import torch
from torch import Tensor

def wrap(f):
    @functools.wraps(f)
    def inner(self, *a, **kw):
        # Call `f` with all-unwrapped args
        self = self._wrapped if isinstance(self, WrappedTensor) else self
        a = tuple(x._wrapped if isinstance(x, WrappedTensor) else x for x in a)
        kw = {k: (v._wrapped if isinstance(v, WrappedTensor) else v) for k, v in kw.items()}
        ret = f(self, *a, **kw)
        # Possibly wrap back result before returning
        return WrappedTensor(ret) if isinstance(ret, Tensor) else ret
    return inner

class WrappedTensor:
    def __init__(self, towrap: Tensor):
        self._wrapped = towrap

    def __getattr__(self, name):
        base = getattr(torch.Tensor, name)
        if not callable(base):
            return property(wrap(base.__get__))

        return wrap(base)

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        return wrap(func)(*args, **kwargs)
```

## Proposed alternatives
One alternative that has been proposed is to automatically pass through
subclasses a-la NumPy and provide a `__torch_finalize__` method that allows for
any post-processing of the result. While this would achieve most goals, it
would miss out on the one to provide a hook for methods and operators.

### Appendix: Special handling for `torch.Tensor` properties/methods
Both functions and methods/properties on `torch.Tensor` will be possible arguments to
`__torch_function__`. These are different in subtle but important ways, and
in some cases it is required to handle them differently. For instance,
`torch.Tensor` methods/properties have the following properties:

1. They can only accept `torch.Tensor` instances as the first argument.
2. They *may or may not* have a `__module__` defined.

Even classes implementing `__torch_function__` that aren't subclasses
can have methods passed in. It is required to treat this case with care.
Consider the following code:

```python
class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        print(func.__name__)

torch.tensor([5]) + TensorLike()  # prints "add"
```
If, in this case, we are using the default implementation of `func`, and a
`torch.Tensor` instance is not passed in, an error will be raised. To handle
this case, we have provided a utility method,
`torch.overrides.is_tensor_method_or_property`, to determine whether something
is a `torch.Tensor` method/property.

For properties, their `__get__` method is passed in. For example, for
`torch.Tensor.grad`, `torch.Tensor.grad.__get__` is passed in as `func`.
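To make the distinction concrete, here is a small sketch (not from the RFC itself) that uses the utility named above inside `__torch_function__`; the class name and its behavior are purely illustrative:

```python
import torch
from torch.overrides import is_tensor_method_or_property

class MethodAwareTensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if is_tensor_method_or_property(func):
            # e.g. Tensor.add or Tensor.grad.__get__: these expect a real
            # torch.Tensor as their first argument, so they may need special
            # handling instead of being forwarded to a default implementation.
            print(f"method/property: {func.__name__}")
        else:
            print(f"function: {func.__name__}")
        # Returning NotImplemented lets other handlers (or the error path) decide.
        return NotImplemented
```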
524 | 525 | 526 | [1]: https://github.com/pytorch/pytorch/issues/22402 "GitHub Issue 22402 on pytorch/pytorch" 527 | [2]: https://github.com/pytorch/pytorch/issues/28361#issuecomment-544520934 "Comment on GitHub Issue 28361 on pytorch/pytorch" 528 | [3]: https://github.com/pytorch/pytorch/issues/28361#issuecomment-557285807 "Comment on GitHub Issue 28361 on pytorch/pytorch" 529 | [4]: https://numpy.org/neps/nep-0018-array-function-protocol.html "NEP 18 — A dispatch mechanism for NumPy’s high level array functions" 530 | [5]: https://github.com/pytorch/pytorch/pull/32194 "GitHub Pull request 32194 on pytorch/pytorch" --------------------------------------------------------------------------------