├── .github └── workflows │ ├── python-app.yml │ └── ufmt.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── bench_linear_float8.py ├── bench_matmul.py ├── bench_multi_gpu.py ├── bench_padding.py ├── profile_linear_float8.py └── utils.py ├── float8_experimental ├── __init__.py ├── config.py ├── distributed_utils.py ├── float8_aten_api.py ├── float8_linear.py ├── float8_linear_utils.py ├── float8_ops.py ├── float8_python_api.py ├── float8_scaling_utils.py ├── float8_tensor.py ├── float8_tensor_parallel.py ├── float8_utils.py ├── fsdp_utils.py └── inference.py ├── pyproject.toml └── test ├── test_base.py ├── test_compile.py ├── test_dtensor.py ├── test_dtensor.sh ├── test_everything.sh ├── test_fsdp.py ├── test_fsdp.sh ├── test_fsdp2 ├── test_fsdp2.py └── test_fsdp2_common.py ├── test_fsdp_compile.py ├── test_fsdp_compile.sh ├── test_inference_flows.py └── test_numerics_integration.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # Basic flak8 + pytest workflow for Python 3.10 2 | 3 | name: Python Lint and Test 4 | 5 | on: 6 | push: 7 | branches: [ "main" ] 8 | pull_request: 9 | branches: [ "main" ] 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python 3.10 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: "3.10" 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip3 install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 29 | pip install -e . 30 | pip install -e .'[dev]' 31 | pip install -e .'[test]' 32 | - name: Lint with ruff 33 | run: | 34 | ruff check . 35 | - name: Running Tests 36 | run: | 37 | ./test/test_everything.sh 38 | -------------------------------------------------------------------------------- /.github/workflows/ufmt.yml: -------------------------------------------------------------------------------- 1 | name: Ufmt 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10"] 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1 24 | - name: Analyzing the code with ufmt 25 | run: | 26 | ufmt format . 27 | git diff 28 | git restore . 29 | ufmt check . 
30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | float8_experimental/__pycache__/* 3 | finetune/__pycache__/* 4 | test/__pycache__/* 5 | torch_compile_debug/* 6 | test/tmp/* 7 | benchmarks/data/* 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'build' 2 | 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: 6306a48f7dae5861702d573c9c247e4e9498e867 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: check-ast 12 | - id: check-merge-conflict 13 | - id: no-commit-to-branch 14 | args: ['--branch=main'] 15 | - id: check-added-large-files 16 | args: ['--maxkb=500'] 17 | - id: end-of-file-fixer 18 | exclude: '^(.*\.svg)$' 19 | 20 | - repo: https://github.com/astral-sh/ruff-pre-commit 21 | # Ruff version. 22 | rev: v0.3.0 23 | hooks: 24 | # Run the linter. 25 | - id: ruff 26 | 27 | - repo: https://github.com/omnilib/ufmt 28 | rev: v2.3.0 29 | hooks: 30 | - id: ufmt 31 | additional_dependencies: 32 | - black == 23.3.0 33 | - usort == 1.0.6 34 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to float8_experimental 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Contributor License Agreement ("CLA") 6 | In order to accept your pull request, we need you to submit a CLA. You only need 7 | to do this once to work on any of Meta's open source projects. 8 | 9 | Complete your CLA here: 10 | 11 | ## Issues 12 | We use GitHub issues to track public bugs. Please ensure your description is 13 | clear and has sufficient instructions to be able to reproduce the issue. 14 | 15 | 16 | ## License 17 | By contributing to float8_experimental, you agree that your contributions will be licensed 18 | under the LICENSE file in the root directory of this source tree. 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, PyTorch Labs 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. 
Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NOTICE: float8_experimental has moved 2 | 3 | We have moved `float8_experimental` to 4 | [pytorch/ao](https://github.com/pytorch/ao/tree/main/torchao/float8) 5 | 6 | - `import float8_experimental` is now `import torchao.float8` 7 | 8 | # float8_experimental 9 | 10 | This is an early version of a library for accelerating training with float8 in native PyTorch 11 | according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf. 12 | The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling, 13 | and composable with key systems such as autograd, ```torch.compile``` and distributed. 14 | With ``torch.compile`` on, initial results show 15 | throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs. 16 | 17 | :warning: See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features. 18 | 19 | :warning: Backwards compatibility is not guaranteed at this point. The codebase is in active development and 20 | will change rapidly. 21 | 22 | # installation 23 | 24 | :warning: For now, use the latest PyTorch nightly for best results with torch.compile. 25 | 26 | ```Shell 27 | pip install . 28 | 29 | # Optionally install editable 30 | pip install -e . 31 | 32 | # Optionally Install dev tooling 33 | pip install -e ".[dev]" 34 | ``` 35 | 36 | # Single GPU User API 37 | 38 | We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`). 39 | 40 | ## float8 linear with dynamic scaling for `input`, `weight` and `grad_output` 41 | 42 | This is the most accurate recipe as every tensor is scaled dynamically. 
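To make "scaled dynamically" concrete: for each tensor, a fresh scale is derived from that tensor's current absolute maximum (amax) so its values map onto the representable float8 range. The snippet below is only an illustrative sketch, not the library's API — the helper name and the clamp epsilon are ours, and the library's own scaling helpers (see `float8_utils.py` and `float8_scaling_utils.py`) handle dtype limits and edge cases more carefully:

```python
import torch

def dynamic_scale_sketch(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # map the tensor's current amax onto the largest representable float8 value
    amax = x.abs().max().float()
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
```

The end-to-end conversion and training recipe follows.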

```python
from float8_experimental import (
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# create model
m = Model(...)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the output module
    if fqn == "output":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert all `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# optional: use FSDP
model = FSDP(model, use_orig_params=True)

# optional: enable torch.compile for improved performance
m = torch.compile(m)

# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
    # this method is optional but is highly recommended for performance
    # it calculates scales for all parameters in a single all-reduce
    precompute_float8_dynamic_scale_for_fsdp(model)

```

## float8 linear with delayed scaling

This is theoretically the most performant recipe as it minimizes memory reads.

```python
from float8_experimental import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    ScalingType,
)

# create model
m = Model(...)

# optional: configure for compatibility with FSDP. Note that workarounds
# gated with config.enable_amax_init and
# config.enable_pre_and_post_forward are needed for
# autocast + compile + FSDP + float8 to work
from float8_experimental import Float8LinearConfig, ScalingType, CastConfig
config = Float8LinearConfig(
    enable_amax_init = False,  # only needed for autocast + compile + FSDP + float8 delayed
    enable_pre_and_post_forward = False,  # only needed for autocast + compile + FSDP + float8 delayed
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)

# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
# type
convert_to_float8_training(
    m,
    config=config,
)

# optional: use FSDP
model = FSDP(model, use_orig_params=True)

# optional: enable torch.compile for improved performance
m = torch.compile(m)

# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()

    # specific to float8 with delayed scaling: separate step to sync scales/amaxes
    # in the future, this may move to a context manager
    sync_float8_amax_and_scale_history(model)

    optimizer.step()
```

# Multi GPU User API

We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html),
such as FSDP, TP and SP.
Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples 144 | on using `float8_experimental` in a distributed setting. 145 | 146 | # Testing 147 | 148 | ```bash 149 | # run single-GPU unit tests 150 | pytest test/test_base.py 151 | 152 | # run single-GPU compile tests 153 | pytest test/test_compile.py 154 | 155 | # run single-GPU numerics integration tests 156 | pytest test/test_numerics_integration.py 157 | 158 | # run a two-GPU integration test on FSDP 159 | ./test/test_fsdp.sh 160 | 161 | # run integration tests on the DTensor TP/SP integration 162 | ./test/test_dtensor.sh 163 | 164 | # run integration tests on the FSDP2 integration 165 | python test/test_fsdp2/test_fsdp2.py 166 | 167 | # run all of these tests 168 | ./test/test_everything.sh 169 | ``` 170 | 171 | # Benchmarking 172 | 173 | ```bash 174 | # benchmark the torch._scaled_mm function on LLaMa 2 70B shapes 175 | ./benchmarks/bench_matmul.py 176 | 177 | # benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes 178 | # make sure to turn on torch.compile to get the best performance 179 | ./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile 180 | ``` 181 | 182 | # License 183 | PyTorch has a BSD 3-Clause License, as found in the LICENSE file. 184 | -------------------------------------------------------------------------------- /benchmarks/bench_linear_float8.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import argparse 7 | import copy 8 | from dataclasses import dataclass 9 | from itertools import product 10 | from pathlib import Path 11 | from typing import Callable, List, Optional, Tuple 12 | 13 | import pandas as pd 14 | 15 | import torch 16 | import torch.utils.benchmark as benchmark 17 | from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType 18 | from float8_experimental.float8_linear import Float8Linear 19 | from float8_experimental.float8_linear_utils import ( 20 | linear_requires_sync, 21 | sync_float8_amax_and_scale_history, 22 | ) 23 | from float8_experimental.float8_tensor import ScaledMMConfig 24 | from tqdm import tqdm 25 | 26 | # estimating TOPs for matmuls in fp32, fp16, fp8 27 | # assuming A * B = C, with A being M * K, B being K * N, C being M * N 28 | 29 | # H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ 30 | h100_peak_flops_float32 = 67e12 31 | h100_peak_flops_fp16_tc = 1979e12 32 | h100_peak_tops_float8_tc = 3958e12 33 | 34 | dtype_to_peak_tops = { 35 | torch.float32: h100_peak_flops_float32, 36 | torch.float16: h100_peak_flops_fp16_tc, 37 | torch.bfloat16: h100_peak_flops_fp16_tc, 38 | torch.float8_e4m3fn: h100_peak_tops_float8_tc, 39 | torch.float8_e5m2: h100_peak_tops_float8_tc, 40 | } 41 | 42 | # prevent splitting columns when printing a data frame 43 | pd.set_option("display.expand_frame_repr", False) 44 | # print the entire data frame 45 | pd_print_full_ctx = pd.option_context( 46 | "display.max_rows", None, "display.max_columns", None 47 | ) 48 | 49 | 50 | def benchmark_torch_function_in_microseconds( 51 | func: Callable, 52 | *args, 53 | **kwargs, 54 | ) -> float: 55 | t0 = benchmark.Timer( 56 | stmt="func(*args, **kwargs)", 57 | globals={"args": args, "kwargs": kwargs, "func": func}, 58 | ) 59 | return 
t0.blocked_autorange().median * 1e6 60 | 61 | 62 | @dataclass 63 | class Experiment: 64 | name: str 65 | shape: Tuple[int, int, int] 66 | ref_time_sec: float 67 | float8_time_sec: float 68 | dtype: torch.dtype 69 | compiled: bool 70 | use_fast_accum: bool 71 | scaling_repr: str 72 | 73 | # 3 Times since we are calculating forward backward 74 | @property 75 | def ref_tops_sec(self): 76 | M, K, N = self.shape 77 | return float(3 * (2 * M * K * N)) / self.ref_time_sec 78 | 79 | @property 80 | def ref_pct_top_peak(self): 81 | return self.ref_tops_sec / dtype_to_peak_tops[self.dtype] 82 | 83 | @property 84 | def float8_tops_sec(self): 85 | M, K, N = self.shape 86 | return float(3 * (2 * M * K * N)) / self.float8_time_sec 87 | 88 | @property 89 | def float8_pct_top_peak(self): 90 | return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn] 91 | 92 | 93 | def main( 94 | sweep_path: Optional[Path] = None, 95 | compile: bool = True, 96 | n_limit: Optional[int] = None, 97 | fast_accum_filter: Optional[bool] = None, 98 | shape_name_filter: Optional[str] = None, 99 | scaling_type_input: str = "dynamic", 100 | scaling_type_weight: str = "dynamic", 101 | scaling_type_grad_output: str = "dynamic", 102 | ): 103 | device = "cuda" 104 | print(f"Compile is set to | {compile}") 105 | 106 | scaling_type_input = ScalingType(scaling_type_input) 107 | scaling_type_weight = ScalingType(scaling_type_weight) 108 | scaling_type_grad_output = ScalingType(scaling_type_grad_output) 109 | config = Float8LinearConfig( 110 | cast_config_input=CastConfig(scaling_type=scaling_type_input), 111 | cast_config_weight=CastConfig(scaling_type=scaling_type_weight), 112 | cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), 113 | ) 114 | 115 | # LLaMa 2 70B single-node weight shapes 116 | # assumes fused attn.wqkv and ffn.w13 117 | name_to_shapes_70b = { 118 | "attn.wqkv": (8192, 1280), 119 | "attn.w0": (1024, 8192), 120 | "ffn.w13": (8192, 7168), 121 | "ffn.w2": (3584, 8192), 122 | } 123 | input_bias = False 124 | if fast_accum_filter is not None: 125 | use_fast_accum = [fast_accum_filter] 126 | else: 127 | use_fast_accum = [True, False] 128 | if shape_name_filter is not None: 129 | k = shape_name_filter 130 | name_to_shapes_70b = {k: name_to_shapes_70b[k]} 131 | experiment_list: List[Experiment] = [] 132 | dtype = torch.bfloat16 133 | for idx, (fast_accum, (name, (K, N))) in enumerate( 134 | tqdm(list(product(use_fast_accum, name_to_shapes_70b.items()))) 135 | ): 136 | if n_limit is not None and idx >= n_limit: 137 | break 138 | linear_ref = torch.nn.Linear(K, N, bias=input_bias).to( 139 | device=device, dtype=dtype 140 | ) 141 | 142 | linear_float8 = Float8Linear.from_float( 143 | copy.deepcopy(linear_ref), 144 | config=config, 145 | ) 146 | scaling_repr = linear_float8.scaling_repr() 147 | 148 | if fast_accum: 149 | linear_float8.forward_config = ScaledMMConfig(False, True, False) 150 | else: 151 | linear_float8.forward_config = ScaledMMConfig(False, False, False) 152 | 153 | bsz, seq_len = 4, 4096 154 | M = bsz * seq_len 155 | input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True) 156 | ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() 157 | 158 | def float8_forw_backward(): 159 | if linear_requires_sync(config): 160 | sync_float8_amax_and_scale_history(linear_float8) 161 | linear_float8(input_tensor).sum().backward() 162 | 163 | def n_times(n, fn, *args, **kwargs): 164 | def wrapper(*args, **kwargs): 165 | for _ in range(n): 166 | fn(*args, 
**kwargs) 167 | 168 | return wrapper 169 | 170 | REPEAT_N = 100 171 | 172 | ref_forw_backward = n_times(REPEAT_N, ref_forw_backward) 173 | float8_forw_backward = n_times(REPEAT_N, float8_forw_backward) 174 | 175 | if compile: 176 | ref_forw_backward = torch.compile(ref_forw_backward) 177 | float8_forw_backward = torch.compile(float8_forw_backward) 178 | 179 | for _ in range(5): 180 | ref_forw_backward() 181 | float8_forw_backward() 182 | 183 | ref_time = ( 184 | benchmark_torch_function_in_microseconds(ref_forw_backward) 185 | * 1e-6 186 | / REPEAT_N 187 | ) 188 | float8_time = ( 189 | benchmark_torch_function_in_microseconds(float8_forw_backward) 190 | * 1e-6 191 | / REPEAT_N 192 | ) 193 | experiment = Experiment( 194 | name, 195 | (M, K, N), 196 | ref_time, 197 | float8_time, 198 | dtype, 199 | compile, 200 | use_fast_accum=fast_accum, 201 | scaling_repr=scaling_repr, 202 | ) 203 | print(experiment) 204 | print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec) 205 | experiment_list.append(experiment) 206 | torch._dynamo.reset() 207 | 208 | headers = [ 209 | "name", 210 | "M", 211 | "K", 212 | "N", 213 | "scaling_repr", 214 | "ref_dtype", 215 | "compiled", 216 | "use_fast_accum", 217 | "ref_time_sec", 218 | "pt_fp8_time_sec", 219 | "ref_tops_sec", 220 | "ref_pct_top_peak", 221 | "pt_fp8_tops_sec", 222 | "pt_fp8_pct_top_peak", 223 | ] 224 | data = [] 225 | for experiment in experiment_list: 226 | data.append( 227 | [ 228 | experiment.name, 229 | experiment.shape[0], 230 | experiment.shape[1], 231 | experiment.shape[2], 232 | experiment.scaling_repr, 233 | experiment.dtype, 234 | experiment.compiled, 235 | experiment.use_fast_accum, 236 | experiment.ref_time_sec, 237 | experiment.float8_time_sec, 238 | experiment.ref_tops_sec, 239 | experiment.ref_pct_top_peak, 240 | experiment.float8_tops_sec, 241 | experiment.float8_pct_top_peak, 242 | ] 243 | ) 244 | 245 | data_pd = pd.DataFrame(data, columns=headers) 246 | data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"] 247 | data_pd["shape"] = ( 248 | "(" 249 | + data_pd["M"].astype(str) 250 | + ", " 251 | + data_pd["K"].astype(str) 252 | + ", " 253 | + data_pd["N"].astype(str) 254 | + ")" 255 | ) 256 | 257 | data_pd_simple = data_pd[ 258 | [ 259 | "name", 260 | "shape", 261 | "scaling_repr", 262 | "compiled", 263 | "use_fast_accum", 264 | "ref_time_sec", 265 | "pt_fp8_time_sec", 266 | "pt_fp8_speedup", 267 | ] 268 | ] 269 | with pd_print_full_ctx: 270 | print(data_pd_simple) 271 | 272 | if sweep_path is not None: 273 | sweep_path = sweep_path.with_suffix(".csv") 274 | data_pd.to_csv(sweep_path) 275 | 276 | 277 | def invoke_main() -> None: 278 | parser = argparse.ArgumentParser() 279 | parser.add_argument("-o", "--output_path", type=str, required=False) 280 | parser.add_argument("--disable_compile", action="store_true") 281 | parser.add_argument("-n", "--n_limit", type=int, required=False) 282 | parser.add_argument("--fast_accum_filter", type=bool, required=False) 283 | parser.add_argument("--shape_name_filter", type=str, required=False) 284 | parser.add_argument("--scaling_type_input", type=str, required=False) 285 | parser.add_argument("--scaling_type_weight", type=str, required=False) 286 | parser.add_argument("--scaling_type_grad_output", type=str, required=False) 287 | args = parser.parse_args() 288 | output_path = Path(args.output_path) if args.output_path is not None else None 289 | kwargs = {} 290 | if args.scaling_type_input is not None: 291 | kwargs["scaling_type_input"] = 
args.scaling_type_input 292 | if args.scaling_type_weight is not None: 293 | kwargs["scaling_type_weight"] = args.scaling_type_weight 294 | if args.scaling_type_grad_output is not None: 295 | kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output 296 | main( 297 | output_path, 298 | not args.disable_compile, 299 | args.n_limit, 300 | args.fast_accum_filter, 301 | args.shape_name_filter, 302 | **kwargs, 303 | ) 304 | 305 | 306 | if __name__ == "__main__": 307 | invoke_main() # pragma: no cover 308 | -------------------------------------------------------------------------------- /benchmarks/bench_matmul.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import itertools 7 | from typing import Optional 8 | 9 | import fire 10 | import pandas as pd 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.utils.benchmark as benchmark 15 | 16 | # estimating TOPs for matmuls in fp32, fp16, fp8 17 | # assuming A * B = C, with A being M * K, B being K * N, C being M * N 18 | 19 | # H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ 20 | h100_peak_flops_float32 = 67e12 21 | h100_peak_flops_fp16_tc = 989e12 22 | h100_peak_tops_float8_tc = 1979e12 23 | 24 | dtype_to_peak_tops = { 25 | torch.float32: h100_peak_flops_float32, 26 | torch.float16: h100_peak_flops_fp16_tc, 27 | torch.bfloat16: h100_peak_flops_fp16_tc, 28 | torch.float8_e4m3fn: h100_peak_tops_float8_tc, 29 | torch.float8_e5m2: h100_peak_tops_float8_tc, 30 | } 31 | 32 | 33 | def benchmark_fn_in_sec(f, *args, **kwargs): 34 | # Manual warmup 35 | for _ in range(4): 36 | f(*args, **kwargs) 37 | t0 = benchmark.Timer( 38 | stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} 39 | ) 40 | measurement = t0.blocked_autorange() 41 | return measurement.mean 42 | 43 | 44 | def do_benchmarks(tops, peak_tops, f, *args, **kwargs): 45 | time_sec = benchmark_fn_in_sec(f, *args, **kwargs) 46 | tops_sec = float(tops) / time_sec 47 | pct_top_peak = tops_sec / peak_tops 48 | return time_sec, tops_sec, pct_top_peak 49 | 50 | 51 | @torch.inference_mode() 52 | def run(n_limit: Optional[int] = None): 53 | device = "cuda" 54 | 55 | # LLaMa 2 70B single-node weight shapes 56 | # assumes fused attn.wqkv and ffn.w13 57 | # source: https://fburl.com/gsheet/g8onr7rh 58 | name_to_shapes_70b = { 59 | "attn.wqkv": (8192, 1280), 60 | "attn.w0": (1024, 8192), 61 | "ffn.w13": (8192, 7168), 62 | "ffn.w2": (3584, 8192), 63 | } 64 | 65 | headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup") 66 | results = [] 67 | 68 | name_to_shapes = name_to_shapes_70b 69 | dtypes = torch.bfloat16, torch.float16 70 | 71 | for idx, (dtype, (name, (K, N))) in enumerate( 72 | itertools.product(dtypes, name_to_shapes.items()) 73 | ): 74 | if n_limit is not None and idx >= n_limit: 75 | break 76 | 77 | # source: Xiao Sun, these are realistic for LLaMa 70B training 78 | bsz, seq_len = 4, 4096 79 | 80 | M = bsz * seq_len 81 | print("M, K, N:", M, K, N) 82 | tops = 2 * M * N * K 83 | print(f"tops: {tops:.2E}") 84 | 85 | # raw torch.mm 86 | A = torch.randn(M, K, device=device, dtype=dtype) 87 | m_ref = nn.Sequential(nn.Linear(K, N, dtype=dtype, device=device, bias=False)) 88 | ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks( 89 | tops, 
dtype_to_peak_tops[dtype], m_ref, A 90 | ) 91 | print( 92 | f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}" 93 | ) 94 | 95 | del A 96 | 97 | # raw float8 matmul (upper bound for what we can achive in eager mode) 98 | # TODO(future): add e5m2 99 | d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype 100 | A = torch.zeros(M, K, device=device, dtype=d1) 101 | B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() 102 | 103 | def do_matmul(A, B): 104 | scale_a = torch.tensor([1.0], device=device) 105 | scale_b = torch.tensor([1.0], device=device) 106 | return torch._scaled_mm( 107 | A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False 108 | ) 109 | 110 | fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks( 111 | tops, dtype_to_peak_tops[d1], do_matmul, A, B 112 | ) 113 | print( 114 | f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}" 115 | ) 116 | 117 | del A, B 118 | 119 | results.append( 120 | [ 121 | name, 122 | (M, K, N), 123 | dtype, 124 | ref_time_sec, 125 | fp8_time_sec, 126 | ref_time_sec / fp8_time_sec, 127 | ] 128 | ) 129 | 130 | data_pd = pd.DataFrame(results, columns=headers) 131 | print(data_pd) 132 | 133 | 134 | def main() -> None: 135 | fire.Fire(run) 136 | 137 | 138 | if __name__ == "__main__": 139 | main() # pragma: no cover 140 | -------------------------------------------------------------------------------- /benchmarks/bench_multi_gpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | from typing import Callable 9 | 10 | import fire 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | import torch.nn as nn 16 | import torch.utils.benchmark as benchmark 17 | from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType 18 | from float8_experimental.float8_linear_utils import ( 19 | convert_to_float8_training, 20 | sync_float8_amax_and_scale_history, 21 | ) 22 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 23 | 24 | 25 | torch.manual_seed(0) 26 | 27 | # TODO: Add more shapes for the benchmark 28 | B, M, K, N = 32, 1024, 1024, 1024 29 | lr = 0.01 30 | 31 | config = Float8LinearConfig( 32 | cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), 33 | cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), 34 | cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), 35 | ) 36 | 37 | 38 | def benchmark_torch_function_in_microseconds( 39 | func: Callable, 40 | *args, 41 | **kwargs, 42 | ) -> float: 43 | t0 = benchmark.Timer( 44 | stmt="func(*args, **kwargs)", 45 | globals={"args": args, "kwargs": kwargs, "func": func}, 46 | ) 47 | return t0.blocked_autorange().median * 1e6 48 | 49 | 50 | def setup(rank, world_size): 51 | os.environ["MASTER_ADDR"] = "localhost" 52 | os.environ["MASTER_PORT"] = "12355" 53 | 54 | # initialize the process group 55 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 56 | 57 | 58 | def cleanup(): 59 | dist.destroy_process_group() 60 | 61 | 62 | def get_model(K, N, is_fp8, base_dtype=torch.float32): 63 | modules = [ 64 | nn.Linear(K, N, dtype=base_dtype), 65 | nn.ReLU(), 66 | ] 67 | N_LAYERS = 20 68 | # N linear layers 69 | for _ in range(N_LAYERS - 1): 70 | modules.append(nn.Linear(N, N, dtype=base_dtype)) 71 | modules.append(nn.ReLU()) 72 | m = nn.Sequential(*modules) 73 | if is_fp8: 74 | convert_to_float8_training( 75 | m, 76 | config=config, 77 | ) 78 | return m 79 | 80 | 81 | def fsdp_main(rank, world_size, args): 82 | setup(rank, world_size) 83 | torch.cuda.set_device(rank) 84 | 85 | base_dtype, input_global, compile = args 86 | 87 | # basic distributed data sampling 88 | assert B % world_size == 0 89 | bsz_local_start = int(rank / world_size * B) 90 | bsz_local_end = int((rank + 1) / world_size * B) 91 | input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank) 92 | 93 | fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank) 94 | # Need use_orig_params=True to compile FSDP 95 | fp8_model = FSDP(fp8_model, use_orig_params=True) 96 | fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size) 97 | 98 | # Run one iteration to make compile work, see experiments doc for more context of this issue. 
99 | fp8_optimizer.zero_grad() 100 | y_local = fp8_model(input_tensor) 101 | y_local.sum().backward() 102 | fp8_optimizer.step() 103 | sync_float8_amax_and_scale_history(fp8_model) 104 | 105 | sync_float8_func = sync_float8_amax_and_scale_history 106 | if compile: 107 | # TODO: Need to fix issues with compile 108 | fp8_model = torch.compile(fp8_model) 109 | sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) 110 | 111 | def float8_forw_backward(): 112 | fp8_optimizer.zero_grad() 113 | y_local = fp8_model(input_tensor) 114 | y_local.sum().backward() 115 | fp8_optimizer.step() 116 | sync_float8_func(fp8_model) 117 | 118 | ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank) 119 | ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size) 120 | if compile: 121 | ref_model = torch.compile(ref_model) 122 | 123 | ref_model = FSDP(ref_model, use_orig_params=True) 124 | 125 | def ref_forw_backward(): 126 | ref_optimizer.zero_grad() 127 | ref_model(input_tensor).sum().backward() 128 | ref_optimizer.step() 129 | 130 | def run_n_iterations(n, fn): 131 | for _ in range(n): 132 | fn() 133 | # make sure training is done on all ranks 134 | dist.barrier() 135 | 136 | # warmup 137 | run_n_iterations(50, ref_forw_backward) 138 | run_n_iterations(50, float8_forw_backward) 139 | 140 | N_ITER = 50 141 | ref_time = ( 142 | benchmark_torch_function_in_microseconds( 143 | run_n_iterations, N_ITER, ref_forw_backward 144 | ) 145 | * 1e-6 146 | / N_ITER 147 | ) 148 | float8_time = ( 149 | benchmark_torch_function_in_microseconds( 150 | run_n_iterations, N_ITER, float8_forw_backward 151 | ) 152 | * 1e-6 153 | / N_ITER 154 | ) 155 | 156 | if rank == 0: 157 | print("ref_time", ref_time) 158 | print("float8_time", float8_time) 159 | print("float8 speedup", ref_time / float8_time) 160 | 161 | cleanup() 162 | 163 | 164 | def run(compile: bool): 165 | base_dtype = torch.bfloat16 166 | WORLD_SIZE = torch.cuda.device_count() 167 | print(f"{base_dtype = }") 168 | print(f"{compile = }") 169 | print(f"{WORLD_SIZE = }") 170 | 171 | # generate input data 172 | ref_input = torch.randn(B, M, K).cuda().to(base_dtype) 173 | # run fsdp model 174 | args = (base_dtype, ref_input, compile) 175 | mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) 176 | 177 | 178 | # Usgae: 179 | # CUDA_VISIBLE_DEVICES=0,1 python benchmarks/bench_multi_gpu.py 180 | if __name__ == "__main__": 181 | fire.Fire(run) 182 | -------------------------------------------------------------------------------- /benchmarks/bench_padding.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import fire 5 | 6 | import torch 7 | from float8_experimental.float8_tensor import ( 8 | GemmInputRole, 9 | hp_tensor_and_scale_to_float8, 10 | LinearMMConfig, 11 | ScaledMMConfig, 12 | ) 13 | from float8_experimental.float8_utils import pad_tensor_for_matmul 14 | from tabulate import tabulate 15 | from torch._inductor.utils import do_bench_using_profiling 16 | from tqdm import tqdm 17 | 18 | # estimating TOPs for matmuls in fp32, fp16, fp8 19 | # assuming A * B = C, with A being M * K, B being K * N, C being M * N 20 | 21 | # H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ 22 | h100_peak_flops_float32 = 67e12 23 | h100_peak_flops_fp16_tc = 1979e12 24 | h100_peak_tops_float8_tc = 3958e12 25 | 26 | dtype_to_peak_tops = { 27 | torch.float32: h100_peak_flops_float32, 28 | 
torch.float16: h100_peak_flops_fp16_tc, 29 | torch.bfloat16: h100_peak_flops_fp16_tc, 30 | torch.float8_e4m3fn: h100_peak_tops_float8_tc, 31 | torch.float8_e5m2: h100_peak_tops_float8_tc, 32 | } 33 | 34 | 35 | def benchmark_fn_in_usec(f, *args, **kwargs): 36 | no_args = lambda: f(*args, **kwargs) 37 | time = do_bench_using_profiling(no_args) 38 | return time * 1e3 39 | 40 | 41 | def get_tops_info(tops, time, peak_tops): 42 | time_sec = time / 1e6 43 | tops_sec = float(tops) / time_sec 44 | pct_top_peak = tops_sec / peak_tops 45 | return tops_sec, pct_top_peak 46 | 47 | 48 | def do_fp8_matmul(A, B, fp8_dtype, out_dtype): 49 | scale_a = torch.tensor([1], device="cuda", dtype=torch.float32) 50 | scale_b = torch.tensor([1], device="cuda", dtype=torch.float32) 51 | 52 | a_config = ScaledMMConfig( 53 | emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True 54 | ) 55 | b_config = ScaledMMConfig( 56 | emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True 57 | ) 58 | a_config = LinearMMConfig(a_config, a_config, a_config) 59 | b_config = LinearMMConfig(b_config, b_config, b_config) 60 | 61 | a_fp8 = hp_tensor_and_scale_to_float8( 62 | A, 63 | scale_a, 64 | fp8_dtype, 65 | a_config, 66 | GemmInputRole.INPUT, 67 | ) 68 | b_fp8 = hp_tensor_and_scale_to_float8( 69 | B, 70 | scale_b, 71 | fp8_dtype, 72 | b_config, 73 | GemmInputRole.WEIGHT, 74 | ) 75 | 76 | return a_fp8 @ b_fp8 77 | 78 | 79 | def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype): 80 | # Breaks with compile due to trying to pad on fp8 dtype 81 | # return do_fp8_matmul(A, B, fp8_dtype, out_dtype) 82 | A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy 83 | B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy 84 | 85 | scale_a = torch.tensor([1], device="cuda", dtype=torch.float32) 86 | scale_b = torch.tensor([1], device="cuda", dtype=torch.float32) 87 | 88 | A_pad = A_pad.to(fp8_dtype) # mem copy 89 | B_pad = B_pad.to(fp8_dtype) # mem copy 90 | 91 | B_pad = B_pad.t().contiguous().t() # mem copy 92 | 93 | return torch._scaled_mm( 94 | A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True 95 | ) 96 | 97 | 98 | def do_hp_matmul(A, B): 99 | return torch.matmul(A, B) 100 | 101 | 102 | def do_aligned_bf16_matmul(A, B): 103 | A_pad = pad_tensor_for_matmul(A, dims=1) 104 | B_pad = pad_tensor_for_matmul(B, dims=0) 105 | return torch.matmul(A_pad, B_pad) 106 | 107 | 108 | @dataclass 109 | class Experiment_config: 110 | M: int 111 | K: int 112 | N: int 113 | output_dtype: torch.dtype 114 | fp8_dtype: torch.dtype 115 | 116 | def __iter__(self): 117 | return iter((self.M, self.K, self.N, self.output_dtype, self.fp8_dtype)) 118 | 119 | 120 | def gen_configs(): 121 | shapes = shapes = [ 122 | (8193, 2501, 5008), 123 | (65, 253, 4096), 124 | (1023, 1029, 2512), 125 | (4095, 511, 10000), 126 | (2047, 3073, 8192), 127 | (511, 769, 7504), 128 | (127, 4097, 12288), 129 | (32769, 15, 15024), 130 | (9217, 8191, 20480), 131 | (16385, 1025, 25008), 132 | ] 133 | output_dtype = torch.bfloat16 134 | fp8_dtype = torch.float8_e4m3fn 135 | return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes] 136 | 137 | 138 | @torch.no_grad() 139 | def run(compile: bool = False, n_limit: Optional[int] = None): 140 | device = "cuda" 141 | experiments = gen_configs() 142 | results = [] 143 | tops_table = [] 144 | tops_headers = [ 145 | "Shape", 146 | "Ref Dtype", 147 | "Ref Tops", 148 | "Aligned BF16 Tops", 149 | "FP8 Tops", 150 | "Ref % Peak", 151 | "Aligned BF16 % Peak", 152 | "FP8 % Peak", 153 | ] 154 
| 155 | for experiment in tqdm(experiments): 156 | M, K, N, output_dtype, fp8_dtype = experiment 157 | tops = 2 * M * N * K 158 | 159 | A_base = torch.rand(M, K, device=device, dtype=output_dtype) 160 | B_base = torch.rand(K, N, device=device, dtype=output_dtype) 161 | 162 | hp_func = torch.compile(do_hp_matmul) if compile else do_hp_matmul 163 | aligned_bf16_func = ( 164 | torch.compile(do_aligned_bf16_matmul) if compile else do_aligned_bf16_matmul 165 | ) 166 | fp8_func = torch.compile(do_fp8_pad_first_matmul) if compile else do_fp8_matmul 167 | 168 | ref_time = benchmark_fn_in_usec(hp_func, A_base, B_base) 169 | aligned_bf16_time = benchmark_fn_in_usec(aligned_bf16_func, A_base, B_base) 170 | fp8_time = benchmark_fn_in_usec( 171 | fp8_func, A_base, B_base, fp8_dtype, output_dtype 172 | ) 173 | 174 | ref_tops_sec, ref_pct_top_peak = get_tops_info( 175 | tops, ref_time, dtype_to_peak_tops[output_dtype] 176 | ) 177 | aligned_bf16_tops_sec, aligned_bf16_pct_top_peak = get_tops_info( 178 | tops, aligned_bf16_time, dtype_to_peak_tops[torch.bfloat16] 179 | ) 180 | fp8_tops_sec, fp8_pct_top_peak = get_tops_info( 181 | tops, fp8_time, dtype_to_peak_tops[fp8_dtype] 182 | ) 183 | tops_table.append( 184 | [ 185 | f"({M}x{K}x{N})", 186 | f"{output_dtype}", 187 | f"{ref_tops_sec:.2E}", 188 | f"{aligned_bf16_tops_sec:.2E}", 189 | f"{fp8_tops_sec:.2E}", 190 | f"{ref_pct_top_peak:.3f}", 191 | f"{aligned_bf16_pct_top_peak:.3f}", 192 | f"{fp8_pct_top_peak:.3f}", 193 | ] 194 | ) 195 | results.append( 196 | [ 197 | (M, K, N), 198 | output_dtype, 199 | ref_time, 200 | aligned_bf16_time, 201 | fp8_time, 202 | ref_time / aligned_bf16_time, 203 | ref_time / fp8_time, 204 | ] 205 | ) 206 | 207 | print("TOPs".center(80, "*")) 208 | print(tabulate(tops_table, headers=tops_headers)) 209 | print("Speed Results".center(80, "*")) 210 | headers = [ 211 | "Shape", 212 | "Ref Dtype", 213 | "Ref Time", 214 | "Aligned BF16 Time", 215 | "FP8 Time", 216 | "Aligned BF16 Speedup", 217 | "FP8 Speedup", 218 | ] 219 | print(tabulate(results, headers=headers, tablefmt="grid")) 220 | 221 | 222 | if __name__ == "__main__": 223 | fire.Fire(run) 224 | -------------------------------------------------------------------------------- /benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import collections 8 | import re 9 | 10 | 11 | def profiler_output_to_time_by_kernel_name(prof): 12 | """ 13 | Input: a profiler with captured events. 14 | Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name 15 | 16 | Note that if there are user_annotations in the captured events, `torch.profiler` 17 | will include their time in the total GPU time displayed at the bottom of 18 | `key_averages.table()`. The filter below excludes them to prevent double 19 | counting. 20 | """ 21 | key_averages = prof.key_averages() 22 | thresh = 1e-10 23 | kernel_name_to_gpu_time_us = collections.defaultdict(float) 24 | for e in key_averages: 25 | # manually filter top-level CPU events with attributed CUDA time 26 | # example CPU event row: 27 | # aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1 28 | # and it maps to this CUDA event: 29 | # sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 
0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1 30 | if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh): 31 | continue 32 | kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total 33 | return kernel_name_to_gpu_time_us 34 | 35 | 36 | def profiler_output_to_gpu_time_for_key(prof, key): 37 | """ 38 | Input: an event name 39 | Output: sum of GPU time of all events with that name in `prof` 40 | 41 | This is useful to get the total time of a user annotation 42 | """ 43 | total = 0 44 | for e in prof.profiler.function_events: 45 | if e.key == key: 46 | total += e.device_time_total 47 | return total 48 | 49 | 50 | def kernel_name_to_category(k): 51 | # number prefix is for easy sorting 52 | if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"): 53 | return "0_gemm" 54 | elif ( 55 | # max(abs(tensor)) 56 | ("abs" in k and "max" in k) 57 | or 58 | # casting pointwise to float8 59 | ("clamp" in k) 60 | or 61 | # things related to scaled_mm 62 | ("scaled_mm" in k) 63 | or 64 | # syncing amaxes and scales 65 | ("roll" in k) 66 | ): 67 | # note: the above filter is approximate and will give false 68 | # positives if model code contains other code to abs/max/clamp 69 | return "1_f8_overhead" 70 | return "2_other" 71 | 72 | 73 | def parse_bw_and_kernel_name(line): 74 | """ 75 | Input: a single line of stdout of TORCHINDUCTOR_PROFILE=1 output, such as 76 | 0.257ms 0.537 GB 2092.43GB/s triton_red_fused_native_layer_norm_0 77 | Output: the bandwidth value and the kernel name, or None and None 78 | """ 79 | result = re.search(".* ([0-9\.]+)GB/s.*(triton_[a-z_0-9]+)", line) 80 | if result: 81 | return result.group(1), result.group(2) 82 | else: 83 | return None, None 84 | -------------------------------------------------------------------------------- /float8_experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # Lets define a few top level things here 7 | from float8_experimental.config import ( 8 | CastConfig, 9 | DelayedScalingConfig, 10 | Float8GemmConfig, 11 | Float8LinearConfig, 12 | ScalingType, 13 | ) 14 | from float8_experimental.float8_linear import Float8Linear 15 | from float8_experimental.float8_linear_utils import ( 16 | convert_to_float8_training, 17 | linear_requires_sync, 18 | sync_float8_amax_and_scale_history, 19 | ) 20 | from float8_experimental.float8_tensor import ( 21 | Float8Tensor, 22 | GemmInputRole, 23 | LinearMMConfig, 24 | ScaledMMConfig, 25 | ) 26 | from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp 27 | 28 | # Needed to load Float8Tensor with weights_only = True 29 | from torch.serialization import add_safe_globals 30 | 31 | add_safe_globals([Float8Tensor, ScaledMMConfig, GemmInputRole, LinearMMConfig]) 32 | 33 | __all__ = [ 34 | # configuration 35 | "DelayedScalingConfig", 36 | "ScalingType", 37 | "Float8GemmConfig", 38 | "Float8LinearConfig", 39 | "CastConfig", 40 | # top level UX 41 | "convert_to_float8_training", 42 | "linear_requires_sync", 43 | "sync_float8_amax_and_scale_history", 44 | "precompute_float8_dynamic_scale_for_fsdp", 45 | # note: Float8Tensor and Float8Linear are not public APIs 46 | ] 47 | -------------------------------------------------------------------------------- /float8_experimental/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import enum 8 | from dataclasses import dataclass 9 | 10 | 11 | # TODO(future): consider renaming to ScalingType 12 | class ScalingType(enum.Enum): 13 | DELAYED = "delayed" 14 | DYNAMIC = "dynamic" 15 | 16 | def short_str(self): 17 | if self is ScalingType.DELAYED: 18 | return "del" 19 | else: 20 | assert self is ScalingType.DYNAMIC 21 | return "dyn" 22 | 23 | 24 | @dataclass(frozen=True) 25 | class CastConfig: 26 | """ 27 | Configuration for casting a single tensor to float8 28 | """ 29 | 30 | scaling_type: ScalingType = ScalingType.DYNAMIC 31 | 32 | 33 | @dataclass(frozen=True) 34 | class DelayedScalingConfig: 35 | """ 36 | Configuration for delayed scaling. 37 | 38 | Note: for now, `history_len` values must be the same for all layers in the 39 | model using delayed scaling. 40 | 41 | TODO(future): serialization for recipes 42 | """ 43 | 44 | # Controls the history length of amax buffers 45 | history_len: int = 16 46 | 47 | # Controls the way to calculate current scale from amax history 48 | # TODO(future): add other functions as needed, hardcoded or user defined 49 | scale_fn_name: str = "max" 50 | 51 | def __post_init__(self): 52 | assert ( 53 | self.scale_fn_name == "max" 54 | ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." 55 | 56 | 57 | @dataclass(frozen=True) 58 | class Float8GemmConfig: 59 | """ 60 | Configuration for a float8 gemm. 61 | """ 62 | 63 | # If True, fast accumulation in lower precision is used. 64 | # Note: this flag is currently a no-op if emulation is turned on. 65 | use_fast_accum: bool = False 66 | 67 | 68 | @dataclass(frozen=True) 69 | class Float8LinearConfig: 70 | """ 71 | Configuration for converting a `torch.nn.Linear` module to float8 72 | for training. 
73 | """ 74 | 75 | # 76 | # Per-tensor configuration for `input`, `weight`, `grad_output` 77 | # 78 | cast_config_input: CastConfig = CastConfig() 79 | cast_config_weight: CastConfig = CastConfig() 80 | cast_config_grad_output: CastConfig = CastConfig() 81 | 82 | # 83 | # Per-gemm configuration for gemms calculating `output`, `grad_input` and 84 | # `grad_weight` 85 | # 86 | gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) 87 | gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() 88 | gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig() 89 | 90 | # 91 | # Per-linear configuration 92 | # 93 | 94 | # If True, on the first iteration of Float8Linear the amaxes will be 95 | # initialized with the incoming data. As of 2023-12-30, this doesn't work 96 | # with autocast + torch.compile + FSDP. Enabling this option is nice for 97 | # testing, but this is not necessary for real training jobs. 98 | enable_amax_init: bool = True 99 | 100 | # If True, pre-forward and post-forward functions are run. As of 2023-12-30, 101 | # this doesn't work with autocast + torch.compile + FSDP. Enabling this 102 | # option is useful for safety, but not strictly necessary. 103 | enable_pre_and_post_forward: bool = True 104 | 105 | # If True, then uses a tensor subclass for the float8 linear module's weight that 106 | # implements pre/post-all-gather methods to do float8 all-gather with FSDP2. 107 | enable_fsdp_float8_all_gather: bool = False 108 | 109 | # If True, then prior to performing the fp8 scaled mamtmul we will pad the 110 | # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls 111 | # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16. 112 | # This can cause a memory spike however so we keep this off by default. 113 | pad_inner_dim: bool = False 114 | 115 | # If True, emulation is used instead of hardware accelerated gemm 116 | emulate: bool = False 117 | 118 | # Configuration for delayed scaling 119 | # Note: this is actually applied per-tensor, but only using the same 120 | # configuration for all tensors and layers in the model is currently 121 | # supported. If in the future we add support for a more fine grained 122 | # configuration, this field may move to per-tensor configs. 123 | delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() 124 | 125 | 126 | # If True, use 'fnuz' float8 types for calculations. 127 | # Currently, ROCm only supports fnuz variants. 128 | # TODO(future PR): move this to Float8LinearConfig 129 | use_fnuz_dtype = False 130 | -------------------------------------------------------------------------------- /float8_experimental/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | from typing import Any 7 | 8 | import torch 9 | 10 | from fairscale.nn.model_parallel.initialize import get_model_parallel_group 11 | 12 | # from float8_tensor import Float8Tensor 13 | from float8_experimental.float8_tensor import Float8Tensor 14 | 15 | # additional differentiable distributed primitives for SP which are not in 16 | # the Fairscale codebase 17 | 18 | 19 | def _gather_along_first_dim(input_: torch.Tensor): 20 | # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67, 21 | # but gather along first dim instead of last dim 22 | group = get_model_parallel_group() 23 | 24 | # Bypass the function if we are using only 1 GPU. 25 | if torch.distributed.get_world_size(group=group) == 1: 26 | return input_ 27 | 28 | # Size and dimension. 29 | first_dim = 0 30 | rank = torch.distributed.get_rank(group=group) 31 | world_size = torch.distributed.get_world_size(group=group) 32 | 33 | # If the input is a float8 tensor, we need to do the transformation on the 34 | # inner tensor and then return a new wrapper. 35 | def _transform(t): 36 | # tensors must be contiguous for all_gather to work 37 | input_contig = t.contiguous() 38 | 39 | tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)] 40 | tensor_list[rank] = input_contig 41 | torch.distributed.all_gather(tensor_list, input_contig, group=group) 42 | 43 | # Note: torch.cat already creates a contiguous tensor. 44 | output = torch.cat(tensor_list, dim=first_dim).contiguous() 45 | return output 46 | 47 | if isinstance(input_, Float8Tensor): 48 | new_data = input_._data 49 | new_data = new_data.view(torch.int8) 50 | new_data = _transform(new_data) 51 | new_data = new_data.view(input_._data.dtype) 52 | output = Float8Tensor(new_data, input_._scale, input_._orig_dtype) 53 | else: 54 | output = _transform(input_) 55 | 56 | return output 57 | 58 | 59 | def _reduce_scatter(ctx: Any, input_: torch.Tensor): 60 | group = get_model_parallel_group() 61 | world_size = torch.distributed.get_world_size(group) 62 | 63 | assert input_.shape[0] % world_size == 0 64 | output_shape = (input_.shape[0] // world_size, *input_.shape[1:]) 65 | output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype) 66 | 67 | torch.distributed.reduce_scatter_tensor(output, input_, group=group) 68 | return output 69 | 70 | 71 | def _split_along_first_dim(input_: torch.Tensor): 72 | # this is needed for testing 73 | 74 | # like fairscale.nn.model_parallel.mappings._split, but 75 | # along the first dim instead of last dim 76 | 77 | group = get_model_parallel_group() 78 | local_rank = torch.distributed.get_rank(group) 79 | world_size = torch.distributed.get_world_size(group) 80 | 81 | assert input_.shape[0] % world_size == 0 82 | input_list = torch.split(input_, input_.shape[0] // world_size) 83 | return input_list[local_rank] 84 | 85 | 86 | class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function): 87 | @staticmethod 88 | def forward(ctx, input_): 89 | return _gather_along_first_dim(input_) 90 | 91 | @staticmethod 92 | def backward(ctx, grad_output): 93 | return _reduce_scatter(ctx, grad_output) 94 | 95 | 96 | class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function): 97 | @staticmethod 98 | def forward(ctx, input_): 99 | return _reduce_scatter(ctx, input_) 100 | 101 | @staticmethod 102 | def backward(ctx, grad_output): 103 | return _gather_along_first_dim(grad_output) 104 | 105 | 106 | class _AllGatherFwSplitBw(torch.autograd.Function): 107 | @staticmethod 108 | def 
forward(ctx, input_): 109 | return _gather_along_first_dim(input_) 110 | 111 | @staticmethod 112 | def backward(ctx, grad_output): 113 | return _split_along_first_dim(grad_output) 114 | -------------------------------------------------------------------------------- /float8_experimental/float8_aten_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | This file defines the aten functions for float8. Today, all of these functions 8 | are emulated. In the future, they should be calling NVIDIA's float8 kernels. 9 | """ 10 | 11 | import torch 12 | 13 | from torch.library import Library 14 | 15 | 16 | def mm_float8_emulated( 17 | m1, # input 1 data 18 | s1, # input 1 scale 19 | m2, # input 2 data 20 | s2, # input 2 scale 21 | dtype3, # output dtype 22 | ): 23 | # naive implementation: dq -> op -> q 24 | m1_fp32 = m1.float() / s1 25 | m2_fp32 = m2.float() / s2 26 | m3_fp32 = torch.mm(m1_fp32, m2_fp32) 27 | 28 | return m3_fp32.to(dtype3) 29 | 30 | 31 | # 32 | # ATen op placeholders 33 | # 34 | 35 | # Register the aten level functions we need. 36 | # These are mostly placeholder and might need to be implemented in c++ as needed 37 | lib = Library("aten", "FRAGMENT") 38 | 39 | lib.define( 40 | "mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor" 41 | ) 42 | lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU") 43 | lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA") 44 | 45 | 46 | @torch.library.impl(lib, "mm_float8_emulated", "Meta") 47 | def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3): 48 | out = torch.mm(m1.float(), m2.float()).to(dtype3) 49 | return out 50 | -------------------------------------------------------------------------------- /float8_experimental/float8_linear_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
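# --- Editor's note: illustrative example (not part of this file) ------------
# A small sketch exercising the emulated float8 matmul registered in
# float8_experimental/float8_aten_api.py above. Shapes and scales are
# arbitrary and chosen only for illustration; this assumes a PyTorch build
# that provides the float8 dtypes.
import torch

import float8_experimental.float8_aten_api  # noqa: F401  (registers the op)

_a = torch.randn(16, 32).to(torch.float8_e4m3fn)
_b = torch.randn(32, 8).to(torch.float8_e4m3fn)
_scale = torch.tensor(1.0)
# dequantize -> fp32 matmul -> cast to the requested output dtype
_out = torch.ops.aten.mm_float8_emulated(_a, _scale, _b, _scale, torch.bfloat16)
assert _out.shape == (16, 8) and _out.dtype == torch.bfloat16
# ----------------------------------------------------------------------------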
6 | import logging 7 | from typing import Callable, List, Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torch.nn as nn 12 | from float8_experimental.config import Float8LinearConfig, ScalingType 13 | from float8_experimental.float8_linear import Float8Linear 14 | 15 | from float8_experimental.float8_utils import ( 16 | amax_history_to_scale_stack, 17 | e4m3_dtype, 18 | e5m2_dtype, 19 | ) 20 | from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor 21 | 22 | log = logging.getLogger(__name__) 23 | log.addHandler(logging.NullHandler()) 24 | 25 | 26 | def linear_requires_sync(config: Float8LinearConfig): 27 | """Returns whether the given config requires sync before forward.""" 28 | return any( 29 | [ 30 | config.cast_config_input.scaling_type is ScalingType.DELAYED, 31 | config.cast_config_weight.scaling_type is ScalingType.DELAYED, 32 | config.cast_config_grad_output.scaling_type is ScalingType.DELAYED, 33 | ] 34 | ) 35 | 36 | 37 | def _update_history_stack( 38 | new_amax: torch.Tensor, amax_history_stack: torch.Tensor 39 | ) -> None: 40 | """ 41 | Updates `amax_history` (the last N cur_amax values) inplace with the value 42 | of `new_amax`. 43 | 44 | Args: 45 | new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1) 46 | amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length) 47 | """ 48 | assert ( 49 | amax_history_stack.dim() == 2 50 | ), f"Expected amax_history_stack to be 2D, got {amax_history_stack.shape}" 51 | assert new_amax.size(0) == amax_history_stack.size( 52 | 0 53 | ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}" 54 | new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1) 55 | new_amax_history_stack[:, 0] = new_amax.squeeze(-1) 56 | amax_history_stack.copy_(new_amax_history_stack) 57 | 58 | 59 | def swap_linear_layers( 60 | module: nn.Module, 61 | from_float_func: Callable[[nn.Linear], nn.Linear], 62 | *, 63 | module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, 64 | ) -> nn.Module: 65 | """ 66 | Generic function to swap linear layers in a module with a new type of linear layer. 67 | 68 | Note: 69 | If applied to a root-level nn.Linear, the module will not be modified in place; 70 | the new module is returned instead. 71 | 72 | Args: 73 | module: Module to modify. 74 | from_float_func: Function that accepts a linear layer and returns a new type of linear layer. 75 | module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that 76 | pass the filter function will be swapped. The inputs to the 77 | filter function are the module instance and the FQN. 78 | 79 | Returns: 80 | nn.Module: The modified module with swapped linear layers.
81 | """ 82 | if isinstance(module, nn.Linear) and ( 83 | module_filter_fn is None or module_filter_fn(module, "") 84 | ): 85 | if len(list(module.children())) > 0: 86 | raise AssertionError( 87 | f"Does not support a root nn.Linear with children: {module}" 88 | ) 89 | return from_float_func( 90 | module, 91 | ) 92 | 93 | root_module = module 94 | 95 | def post_order_traversal( 96 | module: nn.Module, 97 | cur_fqn: Optional[str] = None, 98 | parent_module: Optional[nn.Module] = None, 99 | ): 100 | if cur_fqn is None: 101 | cur_fqn = "" 102 | 103 | for child_module_name, child_module in module.named_children(): 104 | if cur_fqn == "": 105 | new_fqn = child_module_name 106 | else: 107 | new_fqn = f"{cur_fqn}.{child_module_name}" 108 | 109 | post_order_traversal(child_module, new_fqn, module) 110 | 111 | if isinstance(module, nn.Linear) and ( 112 | module_filter_fn is None or module_filter_fn(module, cur_fqn) 113 | ): 114 | assert ( 115 | parent_module is not None 116 | ), f"Linear root module should return early: {module}" 117 | new_linear_module = from_float_func(module) 118 | cur_module_name = cur_fqn.split(".")[-1] 119 | setattr(parent_module, cur_module_name, new_linear_module) 120 | 121 | post_order_traversal(root_module) 122 | return root_module 123 | 124 | 125 | def convert_to_float8_training( 126 | module: nn.Module, 127 | *, 128 | module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, 129 | config: Float8LinearConfig = None, 130 | ) -> nn.Module: 131 | """ 132 | Swaps `torch.nn.Linear` in `module` with `Float8Linear`. 133 | 134 | Args: 135 | module: Module to modify. 136 | module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that 137 | that pass the filter function will be swapped. The inputs to the 138 | filter function are the module instance and the FQN. 139 | config (Float8LinearConfig): configuration for conversion to float8 140 | 141 | Returns: 142 | nn.Module: The modified module with swapped linear layers. 143 | """ 144 | if config is None: 145 | config = Float8LinearConfig() 146 | from_float = lambda m: Float8Linear.from_float( 147 | m, 148 | config=config, 149 | ) 150 | return swap_linear_layers( 151 | module, 152 | from_float, 153 | module_filter_fn=module_filter_fn, 154 | ) 155 | 156 | 157 | def get_float8_layers(model: torch.nn.Module): 158 | """Iterates through the model and returns all the Float8Linear layers. 159 | Args: 160 | model (torch.nn.Module): The model to look for Float8Linear layers in. 161 | """ 162 | 163 | # Get all fp8 layers and tensors 164 | fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)] 165 | if not torch._dynamo.is_compiling(): 166 | for layer in fp8_layers: 167 | for buf in layer.buffers(): 168 | torch._dynamo.mark_static_address(buf, guard=True) 169 | return fp8_layers 170 | 171 | 172 | @torch.no_grad() 173 | def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None: 174 | """ 175 | Manages the float8 amax and scale bookkeeping. In detail, it does the 176 | following: 177 | 1. in distributed contexts, syncs amax values across workers for activations and gradients 178 | 2. adds the `amax` values to history 179 | 3. calculates the scales to be used for next iteration 180 | 4. 
sets the `amax_and_scale_synced` flag on the Float8Linear modules 181 | to signal that they have been synced 182 | 183 | TODO(future): design the UX for this (context manager, etc) 184 | 185 | PERFORMANCE NOTE: 186 | When you can, it is much more efficient to call get_float8_layers once at 187 | the beginning of the training loop and pass the result to this function. 188 | Because of how this interacts with torch.compile 189 | 190 | Args: 191 | model (torch.nn.Module): The model to track amaxes for 192 | fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored, 193 | and we loop over all fp8_layers to sync and update amax scale histories. 194 | Users can use get_float8_layers to get all fp8 layers. 195 | """ 196 | if fp8_layers is None: 197 | fp8_layers = get_float8_layers(model) 198 | 199 | if len(fp8_layers) == 0: 200 | log.warn( 201 | "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers" 202 | ) 203 | return 204 | 205 | def inner_func(): 206 | """Why do we have this inner_function? 207 | 208 | There are two portions of the outer sync_function that cause graph_breaks: 209 | 1. The `get_float8_layers` call can cause graph breaks if the user did not pass 210 | in the fp8_layers. 211 | 2. At the end of syncing all the amaxes and scales we set the attr on the module 212 | signaling that we have synced the amaxes and scales and the next forward can be run. 213 | # TODO Maybe we should remove this safety check to remove the graph break? 214 | 215 | By having this inner function, we can ensure that although the outer function may cause graph breaks 216 | the inner function will not. 217 | """ 218 | # Loop over all fp8 layers and grab the needed tensors 219 | fp8_amax_input_tensor_list = [None] * len(fp8_layers) 220 | fp8_amax_weight_tensor_list = [None] * len(fp8_layers) 221 | fp8_amax_grad_output_tensor_list = [None] * len(fp8_layers) 222 | 223 | fp8_input_amax_history_stack = [None] * len(fp8_layers) 224 | fp8_weight_amax_history_stack = [None] * len(fp8_layers) 225 | fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) 226 | 227 | x_dtypes = set() 228 | scale_fn_recipes = set() 229 | 230 | for idx, child in enumerate(fp8_layers): 231 | fp8_amax_input_tensor_list[idx] = child.fp8_amax_input 232 | fp8_amax_weight_tensor_list[idx] = child.fp8_amax_weight 233 | fp8_amax_grad_output_tensor_list[idx] = child.fp8_amax_grad_output 234 | 235 | fp8_input_amax_history_stack[idx] = child.fp8_amax_history_input 236 | fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight 237 | fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output 238 | 239 | x_dtypes.add(child.last_seen_input_dtype) 240 | scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) 241 | 242 | # TODO This way to get the activation dtype is not ideal 243 | if len(x_dtypes) != 1: 244 | raise ValueError( 245 | f"All layers must have the same last seen input_dtype, got {x_dtypes}" 246 | ) 247 | x_dtype = next(iter(x_dtypes)) 248 | 249 | if len(scale_fn_recipes) != 1: 250 | raise ValueError( 251 | f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" 252 | ) 253 | scale_fn_recipe = next(iter(scale_fn_recipes)) 254 | 255 | assert ( 256 | len(fp8_amax_input_tensor_list) 257 | == len(fp8_amax_weight_tensor_list) 258 | == len(fp8_amax_grad_output_tensor_list) 259 | ), "Mismatched lengths of amax tensors." 
260 | 261 | if dist.is_initialized(): 262 | all_amax_tensors = torch.cat( 263 | fp8_amax_input_tensor_list 264 | + fp8_amax_weight_tensor_list 265 | + fp8_amax_grad_output_tensor_list 266 | ) 267 | all_reduced_amax_tensor = all_reduce( 268 | all_amax_tensors, "MAX", list(range(dist.get_world_size())) 269 | ) 270 | if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor): 271 | all_reduced_amax_tensor = all_reduced_amax_tensor.wait() 272 | 273 | ( 274 | reduced_fp8_amax_input_tensor, 275 | reduced_fp8_amax_weight_tensor, 276 | reduced_fp8_amax_grad_output_tensor, 277 | ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_input_tensor_list)) 278 | 279 | for idx, child in enumerate(fp8_layers): 280 | child.fp8_amax_input.copy_(reduced_fp8_amax_input_tensor[idx]) 281 | child.fp8_amax_weight.copy_(reduced_fp8_amax_weight_tensor[idx]) 282 | child.fp8_amax_grad_output.copy_( 283 | reduced_fp8_amax_grad_output_tensor[idx] 284 | ) 285 | 286 | # We create two stacked tensor groups, one for the amax history and one for the current scales 287 | fp8_amax_input_tensors = torch.vstack(fp8_amax_input_tensor_list) 288 | fp8_amax_weight_tensors = torch.vstack(fp8_amax_weight_tensor_list) 289 | fp8_amax_grad_output_tensors = torch.vstack(fp8_amax_grad_output_tensor_list) 290 | 291 | fp8_input_amax_history_stack = torch.vstack(fp8_input_amax_history_stack) 292 | fp8_weight_amax_history_stack = torch.vstack(fp8_weight_amax_history_stack) 293 | fp8_grad_output_amax_history_stack = torch.vstack( 294 | fp8_grad_output_amax_history_stack 295 | ) 296 | 297 | # Update the history stacks with the new amax values 298 | _update_history_stack(fp8_amax_input_tensors, fp8_input_amax_history_stack) 299 | _update_history_stack(fp8_amax_weight_tensors, fp8_weight_amax_history_stack) 300 | _update_history_stack( 301 | fp8_amax_grad_output_tensors, fp8_grad_output_amax_history_stack 302 | ) 303 | 304 | # Calculate the new scales from the updated history stacks 305 | new_input_scales = amax_history_to_scale_stack( 306 | fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe 307 | ) 308 | new_weight_scales = amax_history_to_scale_stack( 309 | fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe 310 | ) 311 | new_grad_output_scales = amax_history_to_scale_stack( 312 | fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe 313 | ) 314 | 315 | # Iterate through the layers and update the scales 316 | for idx, child in enumerate(fp8_layers): 317 | child.fp8_scale_input.copy_(new_input_scales[idx]) 318 | child.fp8_scale_weight.copy_(new_weight_scales[idx]) 319 | child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx]) 320 | 321 | # This allows for the compile to succede on the inner func and fail on the graph breaks 322 | # at the beginning and and of syncing 323 | inner_func() 324 | 325 | for child in fp8_layers: 326 | # Set a flag to signal amaxes/scales are ready 327 | child.amax_and_scale_synced = True 328 | -------------------------------------------------------------------------------- /float8_experimental/float8_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
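# --- Editor's note: illustrative example (not part of this file) ------------
# A minimal sketch of how the utilities in float8_linear_utils.py above are
# intended to be used in a training loop. It assumes a CUDA device with
# float8 support; the toy model, dtype and hyperparameters are placeholders,
# and the default Float8LinearConfig is assumed to use dynamic scaling.
import torch
import torch.nn as nn

from float8_experimental.config import Float8LinearConfig
from float8_experimental.float8_linear_utils import (
    convert_to_float8_training,
    get_float8_layers,
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)

example_config = Float8LinearConfig()
example_model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64)).cuda()
convert_to_float8_training(example_model, config=example_config)  # nn.Linear -> Float8Linear
example_fp8_layers = get_float8_layers(example_model)  # collect once, outside the loop
example_opt = torch.optim.SGD(example_model.parameters(), lr=1e-3)

for _ in range(3):
    out = example_model(torch.randn(32, 64, device="cuda"))
    out.sum().backward()
    if linear_requires_sync(example_config):
        # only needed when some tensor is configured for delayed scaling
        sync_float8_amax_and_scale_history(example_model, example_fp8_layers)
    example_opt.step()
    example_opt.zero_grad()
# ----------------------------------------------------------------------------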
6 | from typing import Any, Dict, Tuple 7 | 8 | import torch 9 | 10 | from float8_experimental.float8_python_api import addmm_float8_unwrapped 11 | from float8_experimental.float8_tensor import choose_scaled_mm_config, Float8Tensor 12 | from float8_experimental.float8_utils import is_row_major, pad_tensor_for_matmul 13 | 14 | from torch.utils._pytree import tree_map 15 | 16 | aten = torch.ops.aten 17 | c10d_functional = torch.ops.c10d_functional 18 | _c10d_functional = torch.ops._c10d_functional 19 | FLOAT8_OPS_TABLE: Dict[Any, Any] = {} 20 | 21 | 22 | def implements(aten_ops): 23 | """Register aten ops to the float8 op table""" 24 | 25 | def decorator(func): 26 | for op in aten_ops: 27 | FLOAT8_OPS_TABLE[op] = func 28 | return func 29 | 30 | return decorator 31 | 32 | 33 | @implements( 34 | [ 35 | aten.view.default, 36 | aten._unsafe_view.default, 37 | aten.t.default, 38 | aten.as_strided.default, 39 | aten.clone.default, 40 | aten.detach.default, 41 | aten.slice.Tensor, 42 | aten.transpose.int, 43 | aten.fill_.Scalar, 44 | ] 45 | ) 46 | def float8_desugar_op(aten_op, args, kwargs=None): 47 | new_data = aten_op(args[0]._data, *args[1:], **kwargs) 48 | return Float8Tensor( 49 | new_data, 50 | args[0]._scale, 51 | args[0]._orig_dtype, 52 | args[0]._linear_mm_config, 53 | args[0]._gemm_input_role, 54 | ) 55 | 56 | 57 | @implements([aten.split.Tensor]) 58 | def float8_split(aten_op, args, kwargs=None): 59 | new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs) 60 | 61 | def make_float8(data): 62 | return Float8Tensor( 63 | data, 64 | args[0]._scale, 65 | args[0]._orig_dtype, 66 | args[0]._linear_mm_config, 67 | args[0]._gemm_input_role, 68 | ) 69 | 70 | out = map(make_float8, new_data_tensors) 71 | return list(out) 72 | 73 | 74 | # Errors cant `cat_cuda float8 e4m3fn` 75 | @implements([aten.cat.default]) 76 | def float8_cat(aten_op, args, kwargs=None): 77 | chunked_tensors: Tuple[Float8Tensor] = args[0] 78 | 79 | orig_dtype = chunked_tensors[0]._orig_dtype 80 | scale = chunked_tensors[0]._scale 81 | mm_config = chunked_tensors[0]._linear_mm_config 82 | fp8_dtype = chunked_tensors[0]._data.dtype 83 | gemm_input_role = chunked_tensors[0]._gemm_input_role 84 | chunk_data = [] 85 | for chunk in chunked_tensors: 86 | assert isinstance( 87 | chunk, Float8Tensor 88 | ), "Expecting all chunks to be of type Float8Tensor" 89 | assert ( 90 | chunk._orig_dtype == orig_dtype 91 | ), "Expecting all chunks to be of the same dtype" 92 | assert ( 93 | chunk._scale is scale 94 | ), "Expecting all chunks to have thee same scale as a result of a split" 95 | assert ( 96 | chunk._linear_mm_config is mm_config 97 | ), "Expecting all chunks to have thee same mm config as a result of a split" 98 | assert ( 99 | chunk._data.dtype == fp8_dtype 100 | ), "Expecting all chunks to be of the same dtype as a result of a split" 101 | assert ( 102 | chunk._gemm_input_role is gemm_input_role 103 | ), "Expecting all chunks to have the same gemm_input_role as a result of a split" 104 | chunk_data.append(chunk._data.view(torch.uint8)) 105 | 106 | new_data = aten_op(chunk_data, *args[1:], **kwargs) 107 | new_data = new_data.view(fp8_dtype) 108 | return Float8Tensor(new_data, scale, orig_dtype, mm_config, gemm_input_role) 109 | 110 | 111 | @implements([aten.sum.dim_IntList]) 112 | def float8_cast_up_op(aten_op, args, kwargs=None): 113 | """Be careful with this function, this is a "fallback" op that 114 | casts the output of the op to the original precision. And performs the op. 
115 | 116 | We currently need this to support the backward for admmm bias. 117 | "addmm" -> out 118 | "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut" 119 | """ 120 | 121 | def unwrap(x): 122 | if isinstance(x, Float8Tensor): 123 | return x.to_original_precision() 124 | return x 125 | 126 | new_args = tree_map(unwrap, args) 127 | new_kwargs = tree_map(unwrap, kwargs) 128 | return aten_op(*new_args, **new_kwargs) 129 | 130 | 131 | def preprocess_addmm(a: Float8Tensor, b: Float8Tensor): 132 | a_data = a._data 133 | a_scale = a._scale 134 | b_data = b._data 135 | 136 | scaled_mm_config = choose_scaled_mm_config( 137 | a._gemm_input_role, 138 | a._linear_mm_config, 139 | b._gemm_input_role, 140 | b._linear_mm_config, 141 | ) 142 | 143 | if scaled_mm_config.pad_inner_dim: 144 | assert a._data.size(1) == b._data.size( 145 | 0 146 | ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}" 147 | a_data = pad_tensor_for_matmul(a_data, dims=1) 148 | b_data = pad_tensor_for_matmul(b_data, dims=0) 149 | 150 | if not is_row_major(a_data.stride()): 151 | a_data = a_data.contiguous() 152 | if is_row_major(b_data.stride()): 153 | b_data = b_data.t().contiguous().t() 154 | b_scale = b._scale 155 | return a_data, a_scale, b_data, b_scale 156 | 157 | 158 | @implements([aten.mm.default, aten.matmul.default]) 159 | def float8_mm(aten_op, args, kwargs=None): 160 | a = args[0] 161 | b = args[1] 162 | 163 | assert isinstance(a, Float8Tensor) and isinstance( 164 | b, Float8Tensor 165 | ), "Expecting both Float8Tensor for mm inputs but found {} and {}".format( 166 | type(a), type(b) 167 | ) 168 | a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) 169 | output_dtype = a._orig_dtype 170 | scaled_mm_config = choose_scaled_mm_config( 171 | a._gemm_input_role, 172 | a._linear_mm_config, 173 | b._gemm_input_role, 174 | b._linear_mm_config, 175 | ) 176 | if scaled_mm_config.emulate: 177 | return torch.ops.aten.mm_float8_emulated( 178 | a._data, a._scale, b._data, b._scale, output_dtype 179 | ) 180 | tensor_out = addmm_float8_unwrapped( 181 | a_data, 182 | a_scale, 183 | b_data, 184 | b_scale, 185 | output_dtype, 186 | output_scale=None, 187 | bias=None, 188 | use_fast_accum=scaled_mm_config.use_fast_accum, 189 | ) 190 | return tensor_out 191 | 192 | 193 | @implements([aten.addmm.default]) 194 | def float8_addmm(aten_op, args, kwargs=None): 195 | assert ( 196 | isinstance(args[0], torch.Tensor) 197 | and isinstance(args[1], Float8Tensor) 198 | and isinstance(args[2], Float8Tensor) 199 | ) 200 | bias = args[0] 201 | a = args[1] 202 | b = args[2] 203 | a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b) 204 | output_dtype = a._orig_dtype 205 | assert bias.dtype == output_dtype, "bias dtype must match output dtype" 206 | scaled_mm_config = choose_scaled_mm_config( 207 | a._gemm_input_role, 208 | a._linear_mm_config, 209 | b._gemm_input_role, 210 | b._linear_mm_config, 211 | ) 212 | if scaled_mm_config.emulate: 213 | out = torch.ops.aten.mm_float8_emulated( 214 | a._data, a._scale, b._data, b._scale, output_dtype 215 | ) 216 | return out + bias 217 | tensor_out = addmm_float8_unwrapped( 218 | a_data, 219 | a_scale, 220 | b_data, 221 | b_scale, 222 | output_dtype, 223 | output_scale=None, 224 | bias=bias, 225 | use_fast_accum=scaled_mm_config.use_fast_accum, 226 | ) 227 | return tensor_out 228 | 229 | 230 | @implements([aten.is_same_size.default]) 231 | def float8_is_same_size(aten_op, args, kwargs=None): 232 | return args[0].shape == args[1].shape 233 | 234 | 235 | 
@implements([aten._to_copy.default]) 236 | def autocast_to_copy(aten_op, args, kwargs=None): 237 | """This gets called when running matmul under autocast 238 | when the input is a Float8Tensor, presenting as a fp32 239 | tensor. 240 | """ 241 | assert isinstance(args[0], Float8Tensor) 242 | assert ( 243 | len(kwargs) == 1 and "dtype" in kwargs 244 | ), "Only support dtype kwarg for autocast" 245 | assert kwargs["dtype"] in { 246 | torch.float16, 247 | torch.bfloat16, 248 | }, "Only support floating point conversion for autocast w/ Float8Tensor" 249 | return Float8Tensor( 250 | args[0]._data, 251 | args[0]._scale, 252 | kwargs["dtype"], 253 | args[0]._linear_mm_config, 254 | args[0]._gemm_input_role, 255 | ) 256 | 257 | 258 | @implements( 259 | [ 260 | c10d_functional.all_gather_into_tensor.default, 261 | _c10d_functional.all_gather_into_tensor.default, 262 | ] 263 | ) 264 | def allgather_fp8(aten_op, args, kwargs=None): 265 | """ 266 | override funcol with FP8 handling 267 | """ 268 | fp8_input = args[0] 269 | assert isinstance( 270 | fp8_input, Float8Tensor 271 | ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}" 272 | 273 | fp8_data = fp8_input._data 274 | fp8_data = fp8_data.contiguous() 275 | fp8_out = aten_op(fp8_data, *args[1:], **kwargs) 276 | return Float8Tensor( 277 | fp8_out, 278 | fp8_input._scale, 279 | fp8_input._orig_dtype, 280 | fp8_input._linear_mm_config, 281 | fp8_input._gemm_input_role, 282 | ) 283 | 284 | 285 | @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default]) 286 | def wait_tensor_fp8(aten_op, args, kwargs=None): 287 | fp8_input = args[0] 288 | assert isinstance(fp8_input, Float8Tensor) 289 | 290 | fp8_data = fp8_input._data 291 | fp8_out = aten_op(fp8_data, *args[1:], **kwargs) 292 | return Float8Tensor( 293 | fp8_out, 294 | fp8_input._scale, 295 | fp8_input._orig_dtype, 296 | fp8_input._linear_mm_config, 297 | fp8_input._gemm_input_role, 298 | ) 299 | 300 | 301 | @implements([aten.index_put_.default]) 302 | def index_put_fp8(aten_op, args, kwargs=None): 303 | fp8_self = args[0] 304 | fp8_values = args[2] 305 | assert isinstance(fp8_self, Float8Tensor) 306 | assert isinstance(fp8_values, Float8Tensor) 307 | assert fp8_self._scale == fp8_values._scale 308 | assert fp8_self.dtype == fp8_values.dtype 309 | assert fp8_self._orig_dtype == fp8_values._orig_dtype 310 | 311 | fp8_data = fp8_self._data 312 | fp8_values_data = fp8_values._data 313 | fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) 314 | return Float8Tensor( 315 | fp8_out, 316 | fp8_self._scale, 317 | fp8_self._orig_dtype, 318 | fp8_self._linear_mm_config, 319 | fp8_self._gemm_input_role, 320 | ) 321 | 322 | 323 | @implements([aten.copy_.default]) 324 | def copy_fp8(aten_op, args, kwargs=None): 325 | # For a copy op with Float8Tensors involved, only the following combinations are allowed: 326 | # 1. self is a high precision (hp) tensor, src is a Float8Tensor: 327 | # in this case src is upcasted and unscaled to go into the hp tensor 328 | # 2. 
self and src are Float8Tensors: 329 | # the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat) 330 | # Every other combination is banned as the semantics are not well defined 331 | 332 | self = args[0] 333 | src = args[1] 334 | 335 | if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): 336 | src_hp = src.to_original_precision() 337 | return aten_op(self, src_hp, *args[2:], **kwargs) 338 | elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): 339 | assert ( 340 | self._orig_dtype == src._orig_dtype 341 | ), "Expecting both Float8Tensors to be of the same dtype" 342 | assert ( 343 | self._scale == src._scale 344 | ), "Expecting both Float8Tensors to have the same scale" 345 | assert ( 346 | self._linear_mm_config == src._linear_mm_config 347 | ), "Expecting both Float8Tensors to have the same mm config" 348 | assert ( 349 | self._data.dtype == src._data.dtype 350 | ), "Expecting both Float8Tensors to be of the same fp8 dtype" 351 | assert ( 352 | self._gemm_input_role == src._gemm_input_role 353 | ), "Expecting both Float8Tensors to have the same gemm_input_role" 354 | fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) 355 | return Float8Tensor( 356 | fp8_out, 357 | self._scale, 358 | self._orig_dtype, 359 | self._linear_mm_config, 360 | self._gemm_input_role, 361 | ) 362 | else: 363 | raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") 364 | -------------------------------------------------------------------------------- /float8_experimental/float8_python_api.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | This file defines the Python functions for float8 which expect inputs 8 | of class `Float8Tensor`. This is a thin wrapper on top of the aten API 9 | to simplify the product code. 10 | """ 11 | 12 | from typing import Optional 13 | 14 | import float8_experimental.float8_aten_api # noqa 15 | 16 | import torch 17 | 18 | 19 | # [Note] Usage of scales 20 | # The meaning of scale in this library can be found in the definition of the Float8Tensor 21 | # Cublas defines scale to always mean a multiplicative factor for the respective matrices 22 | # For a,b going from fp8 -> fp32 we multiply by the inverse of the scale 23 | # For output going from fp32 -> fp8 we multiply by the scale 24 | def addmm_float8_unwrapped( 25 | a_data: torch.Tensor, 26 | a_scale: torch.Tensor, 27 | b_data: torch.Tensor, 28 | b_scale: torch.Tensor, 29 | output_dtype: torch.dtype, 30 | output_scale: Optional[torch.Tensor] = None, 31 | bias: Optional[torch.Tensor] = None, 32 | use_fast_accum: bool = False, 33 | ) -> torch.Tensor: 34 | """ 35 | This is the unwrapped version of addmm_float8, which does not take in Float8Tensors 36 | as inputs. This is used to standardize the logic between subclassed and non-subclassed 37 | versions of the linear module.
38 | """ 39 | a_inverse_scale = a_scale.reciprocal() 40 | b_inverse_scale = b_scale.reciprocal() 41 | if output_dtype == torch.float32 and bias is not None: 42 | # Bias is not supported by _scaled_mm when output is fp32 43 | output = torch._scaled_mm( 44 | a_data, 45 | b_data, 46 | scale_a=a_inverse_scale, 47 | scale_b=b_inverse_scale, 48 | scale_result=output_scale, 49 | out_dtype=output_dtype, 50 | use_fast_accum=use_fast_accum, 51 | ) 52 | output += bias 53 | return output 54 | output = torch._scaled_mm( 55 | a_data, 56 | b_data, 57 | scale_a=a_inverse_scale, 58 | scale_b=b_inverse_scale, 59 | bias=bias, 60 | scale_result=output_scale, 61 | out_dtype=output_dtype, 62 | use_fast_accum=use_fast_accum, 63 | ) 64 | return output 65 | -------------------------------------------------------------------------------- /float8_experimental/float8_scaling_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Utilities for scaling high precision tensors to float8. 9 | """ 10 | 11 | from typing import Optional 12 | 13 | import torch 14 | 15 | from float8_experimental.float8_tensor import ( 16 | Float8Tensor, 17 | GemmInputRole, 18 | hp_tensor_and_scale_to_float8, 19 | LinearMMConfig, 20 | ScaledMMConfig, 21 | tensor_already_casted_to_fp8, 22 | ) 23 | 24 | from float8_experimental.float8_utils import ( 25 | amax_history_to_scale, 26 | e4m3_dtype, 27 | e5m2_dtype, 28 | tensor_to_amax, 29 | tensor_to_scale, 30 | ) 31 | 32 | 33 | def hp_tensor_to_float8_dynamic( 34 | hp_tensor: torch.Tensor, 35 | float8_dtype: torch.dtype, 36 | linear_mm_config: LinearMMConfig, 37 | reduce_amax: bool = False, 38 | gemm_input_role: GemmInputRole = GemmInputRole.INPUT, 39 | ) -> Float8Tensor: 40 | """ 41 | Given a high precision tensor `hp_tensor`, 42 | scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result. 43 | 44 | Args: 45 | hp_tensor: the tensor to convert 46 | float8_dtype: the float8 dtype to use 47 | linear_mm_config: Defines the configuration for the scaled_mm for 48 | the 3 fwd/bwd gemms of linear 49 | reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks 50 | gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in 51 | the 3 fwd/bwd gemms of linear 52 | """ 53 | if tensor_already_casted_to_fp8(hp_tensor): 54 | return hp_tensor 55 | scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax) 56 | return hp_tensor_and_scale_to_float8( 57 | hp_tensor, 58 | scale, 59 | float8_dtype, 60 | linear_mm_config, 61 | gemm_input_role, 62 | ) 63 | 64 | 65 | def hp_tensor_to_float8_delayed( 66 | hp_tensor: torch.Tensor, 67 | s: torch.Tensor, 68 | float8_dtype: torch.dtype, 69 | amax_buffer: torch.Tensor, 70 | linear_mm_config: Optional[LinearMMConfig] = None, 71 | gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, 72 | ) -> Float8Tensor: 73 | """ 74 | Given a high precision tensor `hp_tensor` and relevant metadata, scales it using 75 | delayed scaling and returns a `Float8Tensor` of the result. Specifically: 76 | 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace 77 | 2. 
scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor 78 | 79 | Args: 80 | hp_tensor: the tensor to convert 81 | s: the scale to use to convert the tensor 82 | float8_dtype: the float8 dtype to use 83 | amax_buffer: the buffer to modify inplace with max(abs(hp_tensor)) 84 | linear_mm_config: Defines the configuration for the scaled_mm for 85 | the 3 fwd/bwd gemms of linear 86 | gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in 87 | the 3 fwd/bwd gemms of linear 88 | """ 89 | amax_buffer.fill_(tensor_to_amax(hp_tensor)) 90 | return hp_tensor_and_scale_to_float8( 91 | hp_tensor, 92 | s, 93 | float8_dtype, 94 | linear_mm_config, 95 | gemm_input_role, 96 | ) 97 | 98 | 99 | def _maybe_initialize_amaxes_scales_for_float8_cast( 100 | x, 101 | cur_amax, 102 | amax_history, 103 | scale, 104 | scale_fn_name, 105 | float8_dtype, 106 | is_initialized, 107 | reduce_amax, 108 | ): 109 | """ 110 | If x is about to be cast to `float8` and the amax buffers are not initialized, 111 | initializes them inplace. 112 | """ 113 | if is_initialized: 114 | return 115 | with torch.no_grad(): 116 | # Note: we need to enable distributed reduction here in order 117 | # to match numerics between single GPU and multi GPU code for 118 | # activations and gradients 119 | new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) 120 | cur_amax.fill_(new_amax) 121 | amax_history[0] = new_amax 122 | new_scale = amax_history_to_scale( 123 | amax_history, float8_dtype, x.dtype, scale_fn_name 124 | ) 125 | scale.copy_(new_scale) 126 | 127 | 128 | @torch._dynamo.allow_in_graph 129 | class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function): 130 | """ 131 | Forward: no-op 132 | Backward: convert to float8_e5m2 with delayed scaling, initialize if needed 133 | """ 134 | 135 | @staticmethod 136 | def forward( 137 | ctx, 138 | tensor, 139 | fp8_amax_grad_output, 140 | fp8_amax_history_grad_output, 141 | fp8_scale_grad_output, 142 | scale_fn_name, 143 | is_amax_initialized, 144 | linear_mm_config: LinearMMConfig, 145 | ): 146 | ctx.save_for_backward( 147 | fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output 148 | ) 149 | ctx.scale_fn_name = scale_fn_name 150 | ctx.is_amax_initialized = is_amax_initialized 151 | ctx.linear_mm_config = linear_mm_config 152 | return tensor 153 | 154 | @staticmethod 155 | def backward(ctx, go): 156 | ( 157 | fp8_amax_grad_output, 158 | fp8_amax_history_grad_output, 159 | fp8_scale_grad_output, 160 | ) = ctx.saved_tensors 161 | scale_fn_name = ctx.scale_fn_name 162 | is_amax_initialized = ctx.is_amax_initialized 163 | 164 | _maybe_initialize_amaxes_scales_for_float8_cast( 165 | go, 166 | fp8_amax_grad_output, 167 | fp8_amax_history_grad_output, 168 | fp8_scale_grad_output, 169 | scale_fn_name, 170 | e5m2_dtype, 171 | is_amax_initialized, 172 | reduce_amax=True, 173 | ) 174 | 175 | fp8_amax_grad_output.fill_(tensor_to_amax(go)) 176 | 177 | res = hp_tensor_and_scale_to_float8( 178 | go, 179 | fp8_scale_grad_output, 180 | e5m2_dtype, 181 | ctx.linear_mm_config, 182 | GemmInputRole.GRAD_OUTPUT, 183 | ) 184 | empty_grads = None, None, None, None, None, None 185 | return res, *empty_grads 186 | 187 | 188 | @torch._dynamo.allow_in_graph 189 | class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function): 190 | """ 191 | Forward: no-op 192 | Backward: convert to float8_e5m2 with dynamic scaling 193 | """ 194 | 195 | @staticmethod 196 | def forward( 197 | ctx, 198 | tensor, 199 | linear_mm_config: LinearMMConfig, 200 | ): 201 | ctx.linear_mm_config 
= linear_mm_config 202 | return tensor 203 | 204 | @staticmethod 205 | def backward(ctx, gradY): 206 | if tensor_already_casted_to_fp8(gradY): 207 | return gradY, None 208 | gradY_scale = tensor_to_scale(gradY, e5m2_dtype) 209 | fp8_tensor = hp_tensor_and_scale_to_float8( 210 | gradY, 211 | gradY_scale, 212 | e5m2_dtype, 213 | ctx.linear_mm_config, 214 | GemmInputRole.GRAD_OUTPUT, 215 | ) 216 | return fp8_tensor, None 217 | -------------------------------------------------------------------------------- /float8_experimental/float8_tensor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import enum 7 | from collections import namedtuple 8 | from typing import Dict, Optional 9 | 10 | import torch 11 | 12 | import torch.distributed._functional_collectives as funcol 13 | from float8_experimental.float8_utils import ( 14 | e4m3_dtype, 15 | tensor_to_amax, 16 | to_fp8_saturated, 17 | ) 18 | from torch.distributed._tensor import DTensor 19 | 20 | aten = torch.ops.aten 21 | 22 | # 23 | # A note on configuration of float8 logic in a linear 24 | # TODO(future): move all the configs to separate file 25 | # TODO(future): change this to input/weight/grad_output notation, 26 | # can be separate PR because none of this is user facing 27 | # 28 | # There are three gemms in a forward + backward of a Linear layer: 29 | # 30 | # 1. input @ weight_t = output (forward pass) 31 | # 2. grad_output @ weight = grad_input (backward pass) 32 | # 3. input_t @ grad_output = grad_weight (backward pass) 33 | # 34 | # In the formulas above, there are: 35 | # A. six input tensors (input, input_t, weight, weight_t, grad_output, grad_output_t). 36 | # - Note that grad_output_t is implied because of memory format requirements 37 | # of float8 gemms 38 | # B. three output tensors (output, grad_input, grad_weight) 39 | # 40 | # We want each input tensor, gemm, and output tensor to be configurable. 41 | # The state of this configuration today is: 42 | # 43 | # i. pairs of input tensors (non-t and t variants) have their scaling 44 | # configurable via the scaling_type_* arguments to Float8Linear 45 | # ii. each gemm + output is configurable via ScaledMMConfig, which is not user facing 46 | # iii. LinearMMConfig is a container for the three ScaledMMConfig objects needed 47 | # to configure all three gemms, also not user facing 48 | 49 | 50 | # ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass. 51 | # emulate: whether to emulate the matmuls in fp32 52 | # use_fast_accum: whether to use the fast-accumulation option for scaled_mm 53 | # fp8_output: whether to output the result of the scaled_mm in fp8 54 | # pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16. 55 | ScaledMMConfig = namedtuple( 56 | "ScaledMMConfig", 57 | ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"], 58 | defaults=[False, False, False, False], 59 | ) 60 | 61 | # The object below is not user facing and exists for convenience, 62 | # to allow Float8Tensor to use 63 | # the right config based on which gemm from gemms with outputs 64 | # `output`, `grad_input`, `grad_weight` is 65 | # being called. 
66 | LinearMMConfig = namedtuple( 67 | "LinearMMConfig", 68 | ["output", "grad_input", "grad_weight"], 69 | defaults=[ 70 | ScaledMMConfig(False, True, False, False), 71 | ScaledMMConfig(False, False, False, False), 72 | ScaledMMConfig(False, False, False, False), 73 | ], 74 | ) 75 | 76 | 77 | class GemmInputRole(enum.Enum): 78 | """ 79 | Given a Float8Tensor, the enum below describes the expected role of this 80 | tensor in the three gemms present in the fw + bw pass of a Linear layer. 81 | This is used to choose the right config for a float8 gemm when the 82 | gemm is performed. 83 | """ 84 | 85 | INPUT = "input" 86 | WEIGHT = "weight" 87 | GRAD_OUTPUT = "grad_output" 88 | 89 | 90 | # choose which scaled_mm_config to use based on gemm inputs 91 | def choose_scaled_mm_config( 92 | a_role: GemmInputRole, 93 | a_linear_mm_config: LinearMMConfig, 94 | b_role: GemmInputRole, 95 | b_linear_mm_config: LinearMMConfig, 96 | ): 97 | if a_role is GemmInputRole.INPUT and b_role is GemmInputRole.WEIGHT: 98 | assert ( 99 | a_linear_mm_config.output == b_linear_mm_config.output 100 | ), f"linear_mm_config.output mismatch: {a_linear_mm_config.output} vs {b_linear_mm_config.output}" 101 | return a_linear_mm_config.output 102 | elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.WEIGHT: 103 | assert ( 104 | a_linear_mm_config.grad_input == b_linear_mm_config.grad_input 105 | ), f"linear_mm_config.grad_input mismatch: {a_linear_mm_config.grad_input} vs {b_linear_mm_config.grad_input}" 106 | return a_linear_mm_config.grad_input 107 | elif a_role is GemmInputRole.GRAD_OUTPUT and b_role is GemmInputRole.INPUT: 108 | assert ( 109 | a_linear_mm_config.grad_weight == b_linear_mm_config.grad_weight 110 | ), f"linear_mm_config.grad_weight mismatch: {a_linear_mm_config.grad_weight} vs {b_linear_mm_config.grad_weight}" 111 | return a_linear_mm_config.grad_weight 112 | else: 113 | raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}") 114 | 115 | 116 | def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool: 117 | """ 118 | Check if the tensor is already casted to fp8 119 | """ 120 | if isinstance(tensor, Float8Tensor): 121 | return True 122 | elif isinstance(tensor, DTensor): 123 | # TODO: shall we stick to public API and directly use tensor.to_local() here? 124 | return tensor_already_casted_to_fp8(tensor._local_tensor) 125 | elif isinstance(tensor, funcol.AsyncCollectiveTensor): 126 | return tensor_already_casted_to_fp8(tensor.elem) 127 | 128 | return False 129 | 130 | 131 | @torch._dynamo.allow_in_graph 132 | class _ToFloat8ConstrFunc(torch.autograd.Function): 133 | """ 134 | A differentiable conversion to fp8. 135 | * forward: convert from high precision to float8 136 | * backward: pass the gradient without changes 137 | """ 138 | 139 | @staticmethod 140 | def forward( 141 | ctx, 142 | tensor: torch.Tensor, 143 | scale: torch.Tensor, 144 | float8_dtype=e4m3_dtype, 145 | linear_mm_config: Optional[LinearMMConfig] = None, 146 | gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, 147 | ): 148 | """ 149 | This function will apply the scaling, and then convert to a Float8Tensor 150 | 151 | Note: 152 | We will call this function with a DTensor subclass. Ideally this would be an aten OP 153 | that DTensor could overload to ensure proper semantics. There are some techincal issues 154 | with that composing with FakeTensor, so we special case here. 
155 | 156 | DTensor Invariant: DTensor must always be the outer most tensor subclass 157 | """ 158 | tensor_scaled = tensor * scale 159 | bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) 160 | 161 | if isinstance(bits_fp8, DTensor): 162 | assert isinstance( 163 | scale, DTensor 164 | ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor" 165 | bits_mesh = bits_fp8.device_mesh 166 | bits_placements = bits_fp8.placements 167 | local_bits = bits_fp8.to_local() 168 | local_scale = scale.to_local() 169 | inner_float8_tensor = Float8Tensor( 170 | local_bits, 171 | local_scale, 172 | tensor.dtype, 173 | linear_mm_config=linear_mm_config, 174 | gemm_input_role=gemm_input_role, 175 | ) 176 | return DTensor.from_local( 177 | inner_float8_tensor, 178 | bits_mesh, 179 | bits_placements, 180 | run_check=False, 181 | shape=bits_fp8.size(), 182 | stride=bits_fp8.stride(), 183 | ) 184 | 185 | return Float8Tensor( 186 | bits_fp8, 187 | scale, 188 | tensor.dtype, 189 | linear_mm_config=linear_mm_config, 190 | gemm_input_role=gemm_input_role, 191 | ) 192 | 193 | @staticmethod 194 | def backward(ctx, g): 195 | return g, None, None, None, None, None 196 | 197 | 198 | @torch._dynamo.allow_in_graph 199 | class _FromFloat8ConstrFunc(torch.autograd.Function): 200 | """ 201 | A differentiable conversion from fp8. 202 | * forward: convert from float8 to high precision 203 | * backward: pass the gradient without changes 204 | """ 205 | 206 | @staticmethod 207 | def forward(ctx, tensor): 208 | return tensor._data.to(tensor._orig_dtype) / tensor._scale 209 | 210 | @staticmethod 211 | def backward(ctx, g): 212 | return g, None, None 213 | 214 | 215 | def hp_tensor_and_scale_to_float8( 216 | hp_tensor: torch.Tensor, 217 | s: torch.Tensor, 218 | float8_dtype=e4m3_dtype, 219 | linear_mm_config: Optional[LinearMMConfig] = None, 220 | gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, 221 | ): 222 | """ 223 | Given a high precision tensor `hp_tensor` and a precalculated scale `s`, 224 | scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result. 225 | 226 | Autograd-aware, the derivative is pass-through. 227 | DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor). 228 | 229 | Args: 230 | hp_tensor: the tensor to convert 231 | s: the scale to use to convert the tensor 232 | float8_dtype: the float8 dtype to use 233 | linear_mm_config: Defines the configuration for the scaled_mm for 234 | the 3 fwd/bwd gemms of linear 235 | gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in 236 | the 3 fwd/bwd gemms of linear 237 | """ 238 | return _ToFloat8ConstrFunc.apply( 239 | hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role 240 | ) 241 | 242 | 243 | class Float8Tensor(torch.Tensor): 244 | """ 245 | Note: this is **not** a public API and is only intended to be used 246 | inside of this repository. Please file an issue if you would benefit 247 | from this being a public API. 248 | 249 | A Python-only Float8 tensor subclass. Contains: 250 | * `_data`: the underlying e4m3 or e5m2 data 251 | * `_scale`: the scale used to scale the original fp32 tensor. We multiply 252 | by scale to go from fp32 range to fp8 range, and divide by scale to go 253 | from fp8 range to fp32 range. 254 | * `_orig_dtype`: the original dtype of the tensor used to create this 255 | tensor. 256 | * `_emulate`: if true using fp32 emulation for the matmuls, helpful 257 | if you don't have access to h100 hardware. 
258 | 259 | Intended usage of this abstraction: 260 | 1. to bundle raw data + fp8 metadata together for easy passing through 261 | Python PyTorch systems. 262 | 2. Float8-aware user code can use the private fields on these tensors 263 | to call into float8 operations. 264 | 3. Float8-agnostic user code can use these tensors as is - they will 265 | convert to original precision in `__torch_dispatch__`. 266 | """ 267 | 268 | _data: torch.Tensor 269 | _scale: torch.Tensor 270 | _orig_dtype: torch.dtype 271 | _linear_mm_config: LinearMMConfig 272 | __slots__ = ["_data", "_scale", "_orig_dtype", "_linear_mm_config"] 273 | 274 | def __new__( 275 | cls, 276 | data: torch.Tensor, 277 | scale: torch.Tensor, 278 | orig_dtype: torch.dtype, 279 | linear_mm_config: Optional[LinearMMConfig], 280 | gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, 281 | ): 282 | assert ( 283 | scale.numel() == 1 284 | ), "Scale should contain a single value, but got: {} elements".format( 285 | scale.numel() 286 | ) 287 | 288 | self = torch.Tensor._make_wrapper_subclass( 289 | cls, 290 | data.size(), 291 | strides=data.stride(), 292 | storage_offset=data.storage_offset(), 293 | dtype=orig_dtype, 294 | layout=data.layout, 295 | requires_grad=data.requires_grad, 296 | device=data.device, 297 | ) 298 | self._data = data 299 | self._scale = scale 300 | self._orig_dtype = orig_dtype 301 | self._linear_mm_config = ( 302 | linear_mm_config if linear_mm_config is not None else LinearMMConfig() 303 | ) 304 | self._gemm_input_role = gemm_input_role 305 | 306 | return self 307 | 308 | def __repr__(self): 309 | return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, linear_mm_config={self._linear_mm_config}\ngemm_input_role={self._gemm_input_role}\nas_orig_prec={self.to_original_precision()}" 310 | 311 | def __tensor_flatten__(self): 312 | ctx = { 313 | "_orig_dtype": self._orig_dtype, 314 | "_linear_mm_config": self._linear_mm_config, 315 | "_gemm_input_role": self._gemm_input_role, 316 | } 317 | return ["_data", "_scale"], ctx 318 | 319 | @staticmethod 320 | def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride): 321 | assert len(inner_tensors) == 2 322 | return Float8Tensor( 323 | inner_tensors["_data"], 324 | inner_tensors["_scale"], 325 | metadata["_orig_dtype"], 326 | metadata["_linear_mm_config"], 327 | metadata["_gemm_input_role"], 328 | ) 329 | 330 | def to_original_precision(self): 331 | return _FromFloat8ConstrFunc.apply(self) 332 | 333 | @classmethod 334 | def __torch_dispatch__(cls, func, types, args, kwargs=None): 335 | # 1. tracing through __torch_function__ logic is not supported yet in 336 | # PT2.0, so we explicitly disallow it here for callsites from user code. 337 | # 2. We do need to handle a couple of ops in order for 338 | # TorchDynamo tracing to succeed. 339 | 340 | # Lazy import to avoid circular dependency 341 | from float8_experimental.float8_ops import FLOAT8_OPS_TABLE 342 | 343 | # All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs 344 | # And don't support mixed tensor subclasses. 
This will trigger the handler for 345 | # the next type in the dispatch list 346 | def allowed_subclasses(type): 347 | return ( 348 | issubclass(cls, type) 349 | or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) 350 | or issubclass( 351 | torch._subclasses.functional_tensor.FunctionalTensor, type 352 | ) 353 | ) 354 | 355 | if not all(allowed_subclasses(t) for t in types): 356 | return NotImplemented 357 | 358 | if func in FLOAT8_OPS_TABLE: 359 | return FLOAT8_OPS_TABLE[func](func, args, kwargs) 360 | raise NotImplementedError(f"attempting to run {func}, this is not supported") 361 | 362 | # Do not force the Float8Tensor type on the returned tensor 363 | __torch_function__ = torch._C._disabled_torch_function_impl 364 | -------------------------------------------------------------------------------- /float8_experimental/float8_tensor_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from float8_experimental.config import ScalingType 4 | from float8_experimental.float8_scaling_utils import ( 5 | hp_tensor_to_float8_dynamic, 6 | NoopFwToFloat8E5M2BwDynamic, 7 | ) 8 | from float8_experimental.float8_tensor import GemmInputRole 9 | from float8_experimental.float8_utils import e4m3_dtype 10 | from torch.distributed._tensor import DTensor 11 | from torch.distributed.device_mesh import DeviceMesh 12 | from torch.distributed.tensor.parallel import ( 13 | ColwiseParallel, 14 | PrepareModuleInput, 15 | RowwiseParallel, 16 | ) 17 | 18 | # subclass the ColwiseParallel and RowwiseParallel classes 19 | # to add the float8 support 20 | # The parameter sharding stays the same as the core 21 | # ColwiseParallel and RowwiseParallel, the only difference 22 | # here is that in input/output handling we do casting after 23 | # creating the DTensor. 24 | 25 | # NOTE: This only works and tested with the dynamic scaling 26 | 27 | 28 | def _float8_linear_supports_float8_allgather(m): 29 | # TODO(future): add support for delayed scaling for activations 30 | # and gradients 31 | return ( 32 | m.scaling_type_input == ScalingType.DYNAMIC 33 | and m.scaling_type_grad_output == ScalingType.DYNAMIC 34 | ) 35 | 36 | 37 | class Float8ColwiseParallel(ColwiseParallel): 38 | @staticmethod 39 | def _prepare_input_fn( 40 | input_layouts, desired_input_layouts, mod, inputs, device_mesh 41 | ): 42 | # annotate module input placements/sharding with input_layouts 43 | input_tensor = inputs[0] 44 | if not isinstance(input_tensor, DTensor): 45 | input_tensor = DTensor.from_local( 46 | input_tensor, device_mesh, input_layouts, run_check=False 47 | ) 48 | 49 | input_tensor = hp_tensor_to_float8_dynamic( 50 | input_tensor, 51 | e4m3_dtype, 52 | mod.linear_mm_config, 53 | gemm_input_role=GemmInputRole.INPUT, 54 | ) # DTensor(Float8Tensor) 55 | 56 | # transform the input layouts to the desired layouts of ColwiseParallel 57 | if input_layouts != desired_input_layouts: 58 | input_tensor = input_tensor.redistribute( 59 | placements=desired_input_layouts, async_op=True 60 | ) 61 | return input_tensor 62 | 63 | @staticmethod 64 | def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): 65 | # outputs is a shard on last dimension DTensor, i.e. 
Shard(-1) 66 | if outputs.placements != output_layouts: 67 | outputs = outputs.redistribute( 68 | placements=output_layouts, async_op=True 69 | ) # DTensor(torch.Tensor) 70 | 71 | # fwd noop bwd cast to DTensor(Float8Tensor) 72 | outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) 73 | 74 | # back to local tensor 75 | return outputs.to_local() if use_local_output else outputs 76 | 77 | def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 78 | from float8_experimental.float8_linear import Float8Linear 79 | 80 | if not isinstance(module, Float8Linear): 81 | raise ValueError( 82 | f"Expecting module to be Float8Linear but found {type(module)}" 83 | ) 84 | elif isinstance( 85 | module, Float8Linear 86 | ) and not _float8_linear_supports_float8_allgather(module): 87 | raise AssertionError("unsupported") 88 | 89 | return super()._apply(module, device_mesh) 90 | 91 | 92 | class Float8RowwiseParallel(RowwiseParallel): 93 | @staticmethod 94 | def _prepare_input_fn( 95 | input_layouts, desired_input_layouts, mod, inputs, device_mesh 96 | ): 97 | input_tensor = inputs[0] 98 | if not isinstance(input_tensor, DTensor): 99 | input_tensor = DTensor.from_local( 100 | input_tensor, device_mesh, input_layouts, run_check=False 101 | ) 102 | 103 | input_tensor = hp_tensor_to_float8_dynamic( 104 | input_tensor, 105 | e4m3_dtype, 106 | mod.linear_mm_config, 107 | gemm_input_role=GemmInputRole.INPUT, 108 | ) # DTensor(Float8Tensor) 109 | 110 | if input_layouts != desired_input_layouts: 111 | input_tensor = input_tensor.redistribute( 112 | placements=desired_input_layouts, async_op=True 113 | ) 114 | return input_tensor 115 | 116 | @staticmethod 117 | def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): 118 | # Rowwise sharding produces partial output, depending on output layouts: 119 | # 1. to replicate -> allreduce 120 | # 2. to shard -> reduce_scatter 121 | if outputs.placements != output_layouts: 122 | outputs = outputs.redistribute(placements=output_layouts, async_op=True) 123 | 124 | # fwd noop bwd cast to DTensor(Float8Tensor) 125 | outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) 126 | 127 | # back to local tensor if use_local_output is True 128 | return outputs.to_local() if use_local_output else outputs 129 | 130 | def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 131 | from float8_experimental.float8_linear import Float8Linear 132 | 133 | if not isinstance(module, Float8Linear): 134 | raise ValueError( 135 | f"Expecting module to be Float8Linear but found {type(module)}" 136 | ) 137 | elif isinstance( 138 | module, Float8Linear 139 | ) and not _float8_linear_supports_float8_allgather(module): 140 | raise AssertionError("unsupported") 141 | 142 | return super()._apply(module, device_mesh) 143 | 144 | 145 | class PrepareFloat8ModuleInput(PrepareModuleInput): 146 | # subclass the PrepareModuleInput classes to implement fp8 specific logic, the only difference is that 147 | # after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) 148 | # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) 149 | # so that if there are multiple float8 users of the input activation, we perform fp8 allgather 150 | # only once. 151 | # FP8 Args: 152 | # float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, 153 | # we currently only support torch.float8_e4m3fn. 
default: torch.float8_e4m3fn 154 | # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used 155 | # for the float8 cast. If not specified, we will search for the Float8Linear in the submodules 156 | # and use the forward config from that module, in this case all module's forward config must be 157 | # the same. 158 | 159 | def __init__( 160 | self, 161 | *, 162 | input_layouts=None, 163 | desired_input_layouts=None, 164 | input_kwarg_layouts=None, 165 | desired_input_kwarg_layouts=None, 166 | use_local_output=False, 167 | float8_dtype=torch.float8_e4m3fn, 168 | fwd_config_submodule_fqn=None, 169 | ): 170 | super().__init__( 171 | input_layouts=input_layouts, 172 | desired_input_layouts=desired_input_layouts, 173 | input_kwarg_layouts=input_kwarg_layouts, 174 | desired_input_kwarg_layouts=desired_input_kwarg_layouts, 175 | use_local_output=use_local_output, 176 | ) 177 | 178 | # fp8 specific fields 179 | self.float8_dtype = float8_dtype 180 | self.linear_mm_config = None 181 | self.fwd_config_submodule_fqn = fwd_config_submodule_fqn 182 | 183 | if self.float8_dtype != torch.float8_e4m3fn: 184 | raise NotImplementedError( 185 | "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now" 186 | ) 187 | 188 | def _prepare_input_arg(self, input, mesh, input_layout, desired_layout): 189 | if input_layout is not None: 190 | if isinstance(input, DTensor): 191 | # TODO: re-enable the check once we fix the compile path 192 | # assert inp.placements[0] == input_layout 193 | dt_inp = input 194 | else: 195 | assert isinstance( 196 | input, torch.Tensor 197 | ), "expecting input to be a torch.Tensor!" 198 | dt_inp = DTensor.from_local( 199 | input, mesh, (input_layout,), run_check=False 200 | ) 201 | 202 | dt_inp = hp_tensor_to_float8_dynamic( 203 | dt_inp, 204 | e4m3_dtype, 205 | self.linear_mm_config, 206 | gemm_input_role=GemmInputRole.INPUT, 207 | ) # DTensor(Float8Tensor) 208 | if desired_layout is not None and input_layout != desired_layout: 209 | dt_inp = dt_inp.redistribute(placements=(desired_layout,)) 210 | 211 | return dt_inp.to_local() if self.use_local_output else dt_inp 212 | else: 213 | return input 214 | 215 | def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 216 | from float8_experimental.float8_linear import Float8Linear 217 | 218 | if self.fwd_config_submodule_fqn is not None: 219 | fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn) 220 | assert isinstance(fwd_linear, Float8Linear) 221 | self.linear_mm_config = fwd_linear.linear_mm_config 222 | else: 223 | # search for ScaledMM configs for all the submodules and make sure they are the same 224 | for mod in module.modules(): 225 | if isinstance(mod, Float8Linear): 226 | if self.linear_mm_config is None: 227 | self.linear_mm_config = mod.linear_mm_config 228 | else: 229 | assert ( 230 | self.linear_mm_config == mod.linear_mm_config 231 | ), "All the Float8Linear modules should have same linear_mm_config!" 232 | 233 | assert self.linear_mm_config is not None 234 | super()._apply(module, device_mesh) 235 | return module 236 | -------------------------------------------------------------------------------- /float8_experimental/float8_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Iterable, Literal, Tuple, Union 8 | 9 | import float8_experimental.config as config 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | # Helpful visualizer for debugging (only supports fp32): 15 | # https://www.h-schmidt.net/FloatConverter/IEEE754.html 16 | 17 | # avoid division by zero when calculating scale 18 | # TODO: align this value with NVIDIA's assumptions (current value is a guess) 19 | EPS = 1e-12 20 | 21 | IS_ROCM = torch.cuda.is_available() and torch.version.hip is not None 22 | FP8_TYPES = { 23 | torch.float8_e4m3fn, 24 | torch.float8_e5m2, 25 | torch.float8_e4m3fnuz, 26 | torch.float8_e5m2fnuz, 27 | } 28 | 29 | 30 | # User defined type for using the individual F8 type based on config 31 | e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz 32 | e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz 33 | 34 | 35 | @torch.no_grad() 36 | def amax_to_scale( 37 | amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype 38 | ): 39 | """Converts the amax value of a tensor to the fp8 scale. 40 | Args: 41 | amax: The amax value of the tensor. 42 | float8_dtype: The float8 dtype. 43 | orig_dtype: The original dtype of the tensor. 44 | """ 45 | scale = torch.empty_like(amax, dtype=torch.float32) 46 | if float8_dtype in FP8_TYPES: 47 | res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) 48 | else: 49 | raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") 50 | 51 | # Ensure that the scale is representable in float16, 52 | # this helps when amax is small. We are assuming that we don't need 53 | # to care about this for float32/bfloat16. 54 | if orig_dtype is torch.float16: 55 | res = torch.clamp(res, max=torch.finfo(torch.float16).max) 56 | scale.copy_(res) 57 | return scale 58 | 59 | 60 | @torch.no_grad() 61 | def amax_history_to_scale( 62 | amax_history: torch.Tensor, 63 | float8_dtype: torch.Tensor, 64 | orig_dtype: torch.dtype, 65 | history_to_scale_fn_type: Literal["max"], 66 | ): 67 | """Takes in a history of amax values and returns a scale tensor. 68 | Args: 69 | amax_history: A tensor containing the history of amax values. 70 | float8_dtype: The float8 dtype. 71 | orig_dtype: The original dtype of the tensor. 72 | history_to_scale_fn_type: The type of function to use to convert the history to a scale. 73 | """ 74 | if history_to_scale_fn_type == "max": 75 | amax = torch.max(amax_history) 76 | return amax_to_scale(amax, float8_dtype, orig_dtype) 77 | raise NotImplementedError() 78 | 79 | 80 | @torch.no_grad() 81 | def amax_history_to_scale_stack( 82 | amax_history: torch.Tensor, 83 | float8_dtype: torch.dtype, 84 | orig_dtype: torch.dtype, 85 | history_to_scale_fn_type: Literal["max"], 86 | ) -> torch.Tensor: 87 | """Takes in a stack of amax_history tensors and returns a scale tensor. 88 | Args: 89 | amax_history: A 2D tensor containing a stack of amax histories. 90 | float8_dtype: The float8 dtype. 91 | orig_dtype: The original dtype of the tensor. 92 | history_to_scale_fn_type: The type of function to use to convert the history to a scale. 
93 | """ 94 | if history_to_scale_fn_type == "max": 95 | amax_stack = torch.max(amax_history, dim=1).values 96 | return amax_to_scale(amax_stack, float8_dtype, orig_dtype) 97 | raise NotImplementedError( 98 | f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}" 99 | ) 100 | 101 | 102 | @torch.no_grad() 103 | def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: 104 | amax = torch.max(torch.abs(x)) 105 | 106 | # If the user asked for distributed reduction, do it. 107 | # If the user did not ask for it, assume that it will 108 | # happen elsewhere. 109 | if reduce_amax and dist.is_initialized(): 110 | dist.all_reduce(amax, op=dist.ReduceOp.MAX) 111 | 112 | return amax 113 | 114 | 115 | @torch.no_grad() 116 | def tensor_to_scale( 117 | x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False 118 | ) -> torch.Tensor: 119 | amax = tensor_to_amax(x, reduce_amax=reduce_amax) 120 | return amax_to_scale(amax, float8_dtype, x.dtype) 121 | 122 | 123 | def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): 124 | """Converts a tensor to a saturated fp8 tensor. 125 | 126 | Note: 127 | The default behavior in PyTorch for casting to `float8_e4m3fn` 128 | and `e5m2` is to not saturate. In this context, we should saturate. 129 | A common case where we want to saturate is when the history of a 130 | tensor has a maximum value of `amax1`, and the current amax value 131 | is `amax2`, where `amax1 < amax2`. This is common when using delayed 132 | scaling. 133 | """ 134 | if float8_dtype in FP8_TYPES: 135 | max_value = torch.finfo(float8_dtype).max 136 | x = x.clamp(min=-max_value, max=max_value) 137 | return x.to(float8_dtype) 138 | else: 139 | raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") 140 | 141 | 142 | def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 143 | """Computes the error between two tensors in dB. 144 | 145 | For more details see: 146 | https://en.wikipedia.org/wiki/Signal-to-noise_ratio 147 | 148 | Args: 149 | x: The original tensor. 150 | y: The tensor to compare to the original tensor. 151 | """ 152 | Ps = torch.norm(x) 153 | Pn = torch.norm(x - y) 154 | return 20 * torch.log10(Ps / Pn) 155 | 156 | 157 | def fp8_tensor_statistics( 158 | tensor: torch.Tensor, float8_dtype=e4m3_dtype 159 | ) -> Tuple[int, ...]: 160 | """Calculate FP8 tensor stats 161 | 162 | Args: 163 | tensor: The tensor to calculate stats for. 164 | float8_dtype: The float8 dtype. 165 | 166 | Returns: 167 | A tuple containing the number of zeros and the number of max values. 168 | """ 169 | if float8_dtype in FP8_TYPES: 170 | FP8_MAX = torch.finfo(float8_dtype).max 171 | else: 172 | raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") 173 | tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype) 174 | num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item() 175 | num_zero = (tensor_orig_type == 0).sum().item() 176 | return (num_zero, num_max) 177 | 178 | 179 | def is_row_major(stride): 180 | assert len(stride) == 2, "is_row_major only supports 2D tensors" 181 | return stride[0] > stride[1] and stride[1] == 1 182 | 183 | 184 | def _get_min_alignment(size: int, alignment_value: int) -> int: 185 | """ 186 | Returns the minimum alignment value that is greater than or equal to the given size. 187 | 188 | Args: 189 | size: The size of the data to be aligned. 190 | alignment_value: The alignment value to be used. 
191 | 192 | Returns: 193 | int: The minimum alignment value that is greater than or equal to the given size. 194 | 195 | Usage: 196 | ``` 197 | >>> _get_min_alignment(10, 8) 198 | 16 199 | ``` 200 | """ 201 | if size % alignment_value == 0: 202 | return size 203 | return (1 + (size // alignment_value)) * alignment_value 204 | 205 | 206 | def pad_tensor_for_matmul( 207 | tensor: torch.Tensor, dims: Union[int, Iterable[int]] 208 | ) -> torch.Tensor: 209 | """ 210 | Pads a 2D tensor with zeros to ensure that its dimensions are multiples of 16, which is required `torch._scaled_mm` 211 | 212 | Args: 213 | tensor: The tensor to pad. 214 | both: Whether to pad both dimensions or just the second dimension. 215 | 216 | Returns: 217 | torch.Tensor: The padded tensor. 218 | 219 | Usage: 220 | ``` 221 | >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=0).shape 222 | torch.Size([16, 10]) 223 | >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=1).shape 224 | torch.Size([10, 16]) 225 | >>> pad_tensor_for_matmul(torch.randn((10, 10)), dims=(0, 1)).shape 226 | torch.Size([16, 16]) 227 | ``` 228 | """ 229 | assert tensor.dim() == 2 230 | dim1, dim2 = tensor.shape 231 | 232 | if isinstance(dims, int): 233 | dims = (dims,) 234 | 235 | # Calculate aligned dimensions based on the specified dims 236 | dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1 237 | dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2 238 | 239 | # Check if padding is needed for either dimension 240 | if dim1 == dim1_aligned and dim2 == dim2_aligned: 241 | return tensor 242 | 243 | # Calculate padding values for both dimensions 244 | pad_dim1 = dim1_aligned - dim1 245 | pad_dim2 = dim2_aligned - dim2 246 | 247 | return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) 248 | -------------------------------------------------------------------------------- /float8_experimental/fsdp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Any, List, Optional, Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils._pytree as pytree 13 | from float8_experimental.float8_scaling_utils import ( 14 | hp_tensor_to_float8_delayed, 15 | hp_tensor_to_float8_dynamic, 16 | ) 17 | 18 | from float8_experimental.float8_tensor import ( 19 | Float8Tensor, 20 | GemmInputRole, 21 | hp_tensor_and_scale_to_float8, 22 | LinearMMConfig, 23 | ) 24 | 25 | from float8_experimental.float8_utils import e4m3_dtype, EPS 26 | from torch._prims_common import suggest_memory_format 27 | 28 | 29 | @torch.no_grad() 30 | def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: 31 | """ 32 | Calculate scale dynamically for all float8 parameters. 33 | This should be run after the optimizer step. It performs a single all-reduce to compute the 34 | scales for all float8 weights. 
35 |     Example usage:
36 |         model(input).sum().backward()
37 |         optim.step()
38 |         precompute_float8_dynamic_scale_for_fsdp(model)
39 |     """
40 |     from float8_experimental.config import ScalingType
41 |     from float8_experimental.float8_linear import Float8Linear
42 |     from torch.distributed._tensor import DTensor
43 | 
44 |     if any(
45 |         isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED
46 |         for m in module.modules()
47 |     ):
48 |         raise NotImplementedError("Only supports dynamic scaling, not delayed scaling")
49 |     float8_linears: List[Float8Linear] = [
50 |         m
51 |         for m in module.modules()
52 |         if isinstance(m, Float8Linear)
53 |         and isinstance(m.weight, DTensor)
54 |         and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
55 |     ]
56 |     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
57 | 
58 |     if not weights:
59 |         return
60 | 
61 |     # inf-norm is equivalent to max(abs(w))
62 |     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
63 |     amax_tensor = torch.stack(max_weights)  # Partial
64 |     # clamp is dispatched through DTensor
65 |     # it will issue a single all-reduce
66 |     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
67 |     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
68 |     if amax_tensor.dtype is torch.float16:
69 |         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
70 |     local_scale_tensor = scale_tensor.to_local()
71 |     for i, float8_linear in enumerate(float8_linears):
72 |         float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
73 | 
74 | 
75 | # FSDP pads its local tensor on dim-0. The subclass should be preserved such
76 | # that the padded local tensor (and any transformations like copying to GPU)
77 | # is of the subclass as well.
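# For example, while padding/resharding, FSDP may apply ops roughly like the
# following to the local shard (an illustrative sketch only, not the exact call
# sequence; `w_local` and `padded_dim0` are placeholder names):
#
#   padded = torch.ops.aten.new_zeros.default(w_local, [padded_dim0, *w_local.shape[1:]])
#   torch.ops.aten.copy_.default(
#       torch.ops.aten.slice.Tensor(padded, 0, 0, w_local.shape[0]), w_local
#   )
#
# If the ops in the set below returned plain torch.Tensors, the subclass (and its
# fsdp_pre_all_gather / fsdp_post_all_gather hooks defined further down) would be
# lost on the padded local tensor.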
78 | _ops_to_preserve_subclass = { 79 | torch.ops.aten.empty_like.default, 80 | torch.ops.aten.new_zeros.default, 81 | torch.ops.aten.slice.Tensor, 82 | torch.ops.aten.copy_.default, 83 | torch.ops.aten.view.default, 84 | torch.ops.aten.as_strided.default, 85 | torch.ops.aten._to_copy.default, 86 | torch.ops.aten._pin_memory.default, 87 | } 88 | 89 | 90 | class WeightWithDynamicFloat8CastTensor(torch.Tensor): 91 | @staticmethod 92 | def __new__( 93 | cls, 94 | tensor: torch.Tensor, 95 | linear_mm_config: LinearMMConfig, 96 | precomputed_scale: Optional[torch.Tensor] = None, 97 | ): 98 | return torch.Tensor._make_wrapper_subclass( 99 | cls, 100 | tensor.size(), 101 | strides=tensor.stride(), 102 | storage_offset=tensor.storage_offset(), 103 | memory_format=suggest_memory_format(tensor), 104 | dtype=tensor.dtype, 105 | layout=tensor.layout, 106 | device=tensor.device, 107 | pin_memory=tensor.is_pinned(), 108 | requires_grad=tensor.requires_grad, 109 | ) 110 | 111 | def __init__( 112 | self, 113 | tensor: torch.Tensor, 114 | linear_mm_config: LinearMMConfig, 115 | precomputed_scale: Optional[torch.Tensor] = None, 116 | ): 117 | self._tensor = tensor 118 | self._linear_mm_config = linear_mm_config 119 | # for dynamic scaling 120 | # `precompute_float8_dynamic_scale_for_fsdp` calculates scales 121 | # for all float8 parameters after optimizer step 122 | self._precomputed_scale = precomputed_scale 123 | 124 | @classmethod 125 | def __torch_dispatch__(cls, func, types, args, kwargs=None): 126 | if func == torch.ops.aten.detach.default: 127 | return WeightWithDynamicFloat8CastTensor( 128 | args[0]._tensor, args[0]._linear_mm_config 129 | ) 130 | mm_config: Optional[LinearMMConfig] = None 131 | 132 | def unwrap(t): 133 | nonlocal mm_config 134 | if mm_config is None: 135 | mm_config = t._linear_mm_config 136 | else: 137 | assert t._linear_mm_config == mm_config 138 | return t._tensor 139 | 140 | args, kwargs = pytree.tree_map_only( 141 | WeightWithDynamicFloat8CastTensor, unwrap, (args, kwargs or {}) 142 | ) 143 | out = func(*args, **kwargs) 144 | if func not in _ops_to_preserve_subclass: 145 | return out 146 | return pytree.tree_map_only( 147 | torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out 148 | ) 149 | 150 | def __tensor_flatten__(self): 151 | if self._precomputed_scale: 152 | return ["_tensor", "_precomputed_scale"], self._linear_mm_config 153 | else: 154 | return ["_tensor"], self._linear_mm_config 155 | 156 | @staticmethod 157 | def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): 158 | mm_config = flatten_spec 159 | return WeightWithDynamicFloat8CastTensor( 160 | inner_tensors["_tensor"], 161 | mm_config, 162 | getattr(inner_tensors, "_precomputed_scale", None), 163 | ) 164 | 165 | def __repr__(self): 166 | return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})" 167 | 168 | def fsdp_pre_all_gather(self, mesh): 169 | if self._precomputed_scale is not None: 170 | float8_tensor = hp_tensor_and_scale_to_float8( 171 | self._tensor, 172 | self._precomputed_scale, 173 | torch.float8_e4m3fn, 174 | self._linear_mm_config, 175 | GemmInputRole.WEIGHT, 176 | ) 177 | else: 178 | float8_tensor = hp_tensor_to_float8_dynamic( 179 | self._tensor, 180 | e4m3_dtype, 181 | self._linear_mm_config, 182 | reduce_amax=True, 183 | gemm_input_role=GemmInputRole.WEIGHT, 184 | ) 185 | return (float8_tensor._data,), (float8_tensor._scale,) 186 | 187 | def fsdp_post_all_gather( 188 | self, 189 | 
all_gather_outputs: Tuple[torch.Tensor, ...], 190 | metadata: Any, 191 | param_dtype: torch.dtype, 192 | *, 193 | out: Optional[torch.Tensor] = None, 194 | ): 195 | (data,) = all_gather_outputs 196 | (scale,) = metadata 197 | if out is not None: 198 | assert isinstance(out, Float8Tensor), f"{type(out)}" 199 | out._scale = scale 200 | return 201 | return Float8Tensor( 202 | data, 203 | scale, 204 | param_dtype, 205 | self._linear_mm_config, 206 | gemm_input_role=GemmInputRole.WEIGHT, 207 | ), (data,) 208 | 209 | 210 | class WeightWithDelayedFloat8CastTensor(torch.Tensor): 211 | @staticmethod 212 | def __new__( 213 | cls, 214 | tensor: torch.Tensor, 215 | amax_buffer: torch.Tensor, 216 | amax_history_buffer: torch.Tensor, 217 | scale_buffer: torch.Tensor, 218 | linear_mm_config: LinearMMConfig, 219 | is_amax_initialized: bool, 220 | ): 221 | return torch.Tensor._make_wrapper_subclass( 222 | cls, 223 | tensor.size(), 224 | strides=tensor.stride(), 225 | storage_offset=tensor.storage_offset(), 226 | memory_format=suggest_memory_format(tensor), 227 | dtype=tensor.dtype, 228 | layout=tensor.layout, 229 | device=tensor.device, 230 | pin_memory=tensor.is_pinned(), 231 | requires_grad=tensor.requires_grad, 232 | ) 233 | 234 | def __init__( 235 | self, 236 | tensor: torch.Tensor, 237 | amax_buffer: torch.Tensor, 238 | amax_history_buffer: torch.Tensor, 239 | scale_buffer: torch.Tensor, 240 | linear_mm_config: LinearMMConfig, 241 | is_amax_initialized: bool, 242 | ): 243 | self._tensor = tensor 244 | self._amax_buffer = amax_buffer 245 | self._amax_history_buffer = amax_history_buffer 246 | self._scale_buffer = scale_buffer 247 | self._linear_mm_config = linear_mm_config 248 | 249 | # Note: is_amax_initialized is not a buffer to avoid data dependent 250 | # control flow visible to dynamo 251 | # TODO(future PR): add serialization for this flag 252 | self.is_amax_initialized = is_amax_initialized 253 | 254 | @classmethod 255 | def __torch_dispatch__(cls, func, types, args, kwargs=None): 256 | if func == torch.ops.aten.detach.default: 257 | return WeightWithDelayedFloat8CastTensor( 258 | args[0]._tensor, 259 | args[0]._amax_buffer, 260 | args[0]._amax_history_buffer, 261 | args[0]._scale_buffer, 262 | args[0]._linear_mm_config, 263 | args[0].is_amax_initialized, 264 | ) 265 | mm_config: Optional[LinearMMConfig] = None 266 | amax_buffer: Optional[torch.Tensor] = None 267 | amax_history_buffer: Optional[torch.Tensor] = None 268 | scale_buffer: Optional[torch.Tensor] = None 269 | is_amax_initialized: Optional[bool] = None 270 | 271 | def unwrap(t): 272 | nonlocal mm_config 273 | if mm_config is None: 274 | mm_config = t._linear_mm_config 275 | else: 276 | assert t._linear_mm_config == mm_config 277 | nonlocal amax_buffer 278 | if amax_buffer is None: 279 | amax_buffer = t._amax_buffer 280 | nonlocal amax_history_buffer 281 | if amax_history_buffer is None: 282 | amax_history_buffer = t._amax_history_buffer 283 | nonlocal scale_buffer 284 | if scale_buffer is None: 285 | scale_buffer = t._scale_buffer 286 | nonlocal is_amax_initialized 287 | if is_amax_initialized is None: 288 | is_amax_initialized = t.is_amax_initialized 289 | return t._tensor 290 | 291 | args, kwargs = pytree.tree_map_only( 292 | WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {}) 293 | ) 294 | out = func(*args, **kwargs) 295 | if func not in _ops_to_preserve_subclass: 296 | return out 297 | return pytree.tree_map_only( 298 | torch.Tensor, 299 | lambda x: WeightWithDelayedFloat8CastTensor( 300 | x, 301 | amax_buffer, 302 | 
amax_history_buffer, 303 | scale_buffer, 304 | mm_config, 305 | is_amax_initialized, 306 | ), 307 | out, 308 | ) 309 | 310 | def __tensor_flatten__(self): 311 | return ( 312 | [ 313 | "_tensor", 314 | "_amax_buffer", 315 | "_amax_history_buffer", 316 | "_scale_buffer", 317 | ], 318 | { 319 | "mm_config": self._linear_mm_config, 320 | "is_amax_initialized": self.is_amax_initialized, 321 | }, 322 | ) 323 | 324 | @staticmethod 325 | def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): 326 | return WeightWithDelayedFloat8CastTensor( 327 | inner_tensors["_tensor"], 328 | inner_tensors["_amax_buffer"], 329 | inner_tensors["_amax_history_buffer"], 330 | inner_tensors["_scale_buffer"], 331 | metadata["mm_config"], 332 | metadata["is_amax_initialized"], 333 | ) 334 | 335 | def __repr__(self): 336 | return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})" 337 | 338 | def fsdp_pre_all_gather(self, mesh): 339 | # initialize if needed 340 | # TODO(before land): ensure settings are consistent between Float8Linear and here 341 | if not self.is_amax_initialized: 342 | from float8_experimental.float8_linear import ( 343 | _maybe_initialize_amaxes_scales_for_float8_cast, 344 | ) 345 | 346 | _maybe_initialize_amaxes_scales_for_float8_cast( 347 | self._tensor, 348 | self._amax_buffer, 349 | self._amax_history_buffer, 350 | self._scale_buffer, 351 | "max", # TODO(before land): read this from parent 352 | e4m3_dtype, 353 | self.is_amax_initialized, 354 | reduce_amax=True, 355 | ) 356 | self.is_amax_initialized = True 357 | 358 | float8_tensor = hp_tensor_to_float8_delayed( 359 | self._tensor, 360 | self._scale_buffer, 361 | e4m3_dtype, 362 | self._amax_buffer, 363 | self._linear_mm_config, 364 | GemmInputRole.WEIGHT, 365 | ) 366 | return (float8_tensor._data,), (float8_tensor._scale,) 367 | 368 | def fsdp_post_all_gather( 369 | self, 370 | all_gather_outputs: Tuple[torch.Tensor, ...], 371 | metadata: Any, 372 | param_dtype: torch.dtype, 373 | *, 374 | out: Optional[torch.Tensor] = None, 375 | ): 376 | (data,) = all_gather_outputs 377 | (scale,) = metadata 378 | if out is not None: 379 | assert isinstance(out, Float8Tensor), f"{type(out)}" 380 | out._scale = scale 381 | return 382 | return Float8Tensor( 383 | data, 384 | scale, 385 | param_dtype, 386 | self._linear_mm_config, 387 | gemm_input_role=GemmInputRole.WEIGHT, 388 | ), (data,) 389 | -------------------------------------------------------------------------------- /float8_experimental/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | """
7 | Defines an nn module designed to be used during inference
8 | """
9 | 
10 | from dataclasses import dataclass
11 | 
12 | from enum import auto, Enum
13 | from typing import Callable, List, Optional
14 | 
15 | import torch
16 | import torch.nn as nn
17 | from float8_experimental.float8_linear_utils import swap_linear_layers
18 | 
19 | from float8_experimental.float8_tensor import (
20 |     Float8Tensor,
21 |     GemmInputRole,
22 |     hp_tensor_and_scale_to_float8,
23 |     LinearMMConfig,
24 |     ScaledMMConfig,
25 |     tensor_already_casted_to_fp8,
26 | )
27 | from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale
28 | 
29 | 
30 | class ActivationCasting(Enum):
31 |     """Types of quantization to perform on the activations
32 | 
33 |     WEIGHT_ONLY: Only quantize the weight, no activation casting; the weight is dequantized in the forward pass
34 |     STATIC: Activation is quantized during model initialization with a static scale
35 |     DYNAMIC: Activation is quantized during the forward pass with a dynamic scale calculated from the input activation
36 |     """
37 | 
38 |     # TODO: A better name would be NONE, we should unify this with torchao
39 |     WEIGHT_ONLY = auto()
40 |     DYNAMIC = auto()
41 |     STATIC = auto()
42 | 
43 | 
44 | @dataclass(frozen=True)
45 | class QuantConfig:
46 |     """Defines the configuration for the quantization to fp8 of a linear module
47 | 
48 |     Args:
49 |         activation_casting: The type of quantization to perform on the activations
50 |         static_quantization_scale: The scale of the input to this linear module, used for static quantization only
51 |     """
52 | 
53 |     activation_casting: ActivationCasting
54 |     static_quantization_scale: Optional[torch.Tensor] = None
55 | 
56 |     # If True, then prior to performing the fp8 scaled matmul we will pad the
57 |     # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for
58 |     # torch._scaled_mm, which has the strong constraint that for dimensions (M, N, K), N and K must be multiples of 16.
59 |     # This can cause a memory spike, however, so we keep this off by default.
60 |     pad_inner_dim = False
61 | 
62 |     def __post_init__(self):
63 |         if self.activation_casting == ActivationCasting.STATIC:
64 |             assert isinstance(
65 |                 self.static_quantization_scale, torch.Tensor
66 |             ), "When activation_casting is 'static', static_quantization_scale must be a tensor."
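# Example usage of the config above with `quantize_to_float8` (defined at the
# end of this file). This is an illustrative sketch: `my_model` is a placeholder
# for a module containing nn.Linear layers, and the static scale value is arbitrary.
#
#   from float8_experimental.inference import (
#       ActivationCasting,
#       QuantConfig,
#       quantize_to_float8,
#   )
#
#   # dynamic activation scaling
#   model = quantize_to_float8(my_model, QuantConfig(ActivationCasting.DYNAMIC))
#
#   # static activation scaling requires a precomputed activation scale tensor
#   model = quantize_to_float8(
#       my_model,
#       QuantConfig(
#           ActivationCasting.STATIC, static_quantization_scale=torch.tensor(2.0)
#       ),
#   )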
67 | 
68 | 
69 | class Float8InferenceLinear(torch.nn.Linear):
70 |     """
71 |     This is a wrapper around torch.nn.Linear that supports FP8 inference.
72 |     Supported forms of inference:
73 |         - FP8 inference with high precision matmul (weight-only quantization)
74 |         - FP8 inference with fp8 matmul and dynamic activation casting
75 |         - FP8 inference with fp8 matmul and static activation casting
76 |     """
77 | 
78 |     def __init__(
79 |         self,
80 |         # FP8 specific arguments
81 |         quant_config: QuantConfig,
82 |         linear_mm_config: LinearMMConfig,
83 |         # nn.Linear arguments
84 |         in_features: int,
85 |         out_features: int,
86 |         bias: bool = True,
87 |         device: Optional[torch.device] = None,
88 |         dtype: Optional[torch.dtype] = None,
89 |     ) -> None:
90 |         # Construct the superclass; this will create dummy weights and biases
91 |         super().__init__(in_features, out_features, bias, device, dtype)
92 |         self.linear_mm_config = linear_mm_config
93 |         self.activation_casting = quant_config.activation_casting
94 |         if self.activation_casting == ActivationCasting.STATIC:
95 |             self.register_buffer(
96 |                 "static_quantization_scale", quant_config.static_quantization_scale
97 |             )
98 |         else:
99 |             self.static_quantization_scale = None
100 | 
101 |     def forward(self, input: torch.Tensor) -> torch.Tensor:
102 |         if self.activation_casting == ActivationCasting.WEIGHT_ONLY:
103 |             return torch.nn.functional.linear(
104 |                 input, self.weight.to_original_precision()
105 |             )
106 | 
107 |         x_fp8 = cast_to_float8_e4m3_inference(
108 |             input,
109 |             self.linear_mm_config,
110 |             static_quantization_scale=self.static_quantization_scale,
111 |         )
112 |         return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
113 | 
114 |     # Builder functions for Float8InferenceLinear
115 |     def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
116 |         """This function converts the weight to a Float8Tensor and sets its requires_grad to False.
117 | 
118 |         Args:
119 |             dtype: The dtype to quantize the weight to. Default is e4m3_dtype.
120 | 
121 |         Note:
122 |             This function is typically called during inference to quantize the weight once since
123 |             the weight is not updated during inference.
124 | 
125 |         """
126 |         assert not isinstance(
127 |             self.weight, Float8Tensor
128 |         ), "Weight has already been quantized, cannot quantize again."
129 | scale = tensor_to_scale(self.weight, dtype) 130 | quantized_weight = hp_tensor_and_scale_to_float8( 131 | self.weight, 132 | scale, 133 | dtype, 134 | self.linear_mm_config, 135 | GemmInputRole.WEIGHT, 136 | ) 137 | self.weight = nn.Parameter(quantized_weight) 138 | self.weight.requires_grad = False 139 | 140 | def set_weight_and_bias( 141 | self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter] 142 | ): 143 | self.weight = weight 144 | self.bias = bias 145 | 146 | @classmethod 147 | def from_float( 148 | cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool 149 | ) -> "Float8InferenceLinear": 150 | """ 151 | Create an nn.Linear with fp8 compute from another nn.Linear 152 | 153 | Args: 154 | mod (torch.nn.Linear): nn.Linear to convert 155 | quant_config (QuantConfig): Configuration for the weight and activation casting 156 | """ 157 | forward_config = ScaledMMConfig( 158 | False, use_fast_accum, pad_inner_dim=quant_config.pad_inner_dim 159 | ) 160 | linear_mm_config = LinearMMConfig( 161 | forward_config, forward_config, forward_config 162 | ) 163 | linear = cls( 164 | quant_config, 165 | linear_mm_config, 166 | module.in_features, 167 | module.out_features, 168 | False, 169 | device=torch.device("meta"), 170 | ) 171 | linear.set_weight_and_bias(module.weight, module.bias) 172 | linear.quantize_weight() 173 | return linear 174 | 175 | 176 | def cast_to_float8_e4m3_inference( 177 | inpt_tensor: torch.Tensor, 178 | linear_mm_config: LinearMMConfig, 179 | reduce_amax: bool = False, 180 | static_quantization_scale: Optional[torch.Tensor] = None, 181 | ) -> Float8Tensor: 182 | """Casts an input tensor to the Float8 (e4m3fn*) 183 | 184 | Args: 185 | inpt_tensor: The input tensor to be cast. 186 | linear_mm_config: Configuration settings for the matrix multiplication 187 | reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group. 188 | static_quantization_scale: Optional tensor specifying the scale for activation. Default is None. 189 | 190 | Returns: 191 | Float8Tensor: The input tensor cast to Float8 (e4m3fn) format. 192 | 193 | Note: 194 | If the input tensor is already in Float8 format, it is returned as is without re-casting. 195 | """ 196 | if tensor_already_casted_to_fp8(inpt_tensor): 197 | return inpt_tensor 198 | scale = ( 199 | static_quantization_scale 200 | if static_quantization_scale is not None 201 | else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) 202 | ) 203 | return hp_tensor_and_scale_to_float8( 204 | inpt_tensor, 205 | scale, 206 | e4m3_dtype, 207 | linear_mm_config, 208 | GemmInputRole.INPUT, 209 | ) 210 | 211 | 212 | def quantize_to_float8( 213 | module: nn.Module, 214 | quant_config: QuantConfig, 215 | *, 216 | module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, 217 | use_fast_accum: bool = True, 218 | ) -> nn.Module: 219 | """ 220 | Converts torch.nn.Linear layers in the given module to Float8InferenceLinear. 221 | 222 | Note: 223 | If applied to a root-level nn.Linear, the module will not be modified in place 224 | and returned instead 225 | 226 | Args: 227 | module (nn.Module): The module to modify. 228 | quant_config (QuantConfig): Quantization configuration for Float8 conversion. 229 | module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that 230 | that pass the filter function will be swapped. The inputs to the 231 | filter function are the module instance and the FQN. 
232 | use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True. 233 | 234 | Returns: 235 | nn.Module: The modified module with applicable Linear layers converted to Float8. 236 | 237 | Raises: 238 | AssertionError: If a root-level nn.Linear with children is encountered. 239 | """ 240 | return swap_linear_layers( 241 | module, 242 | lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum), 243 | module_filter_fn=module_filter_fn, 244 | ) 245 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "float8_experimental" 7 | version = "0.0.1" 8 | description = "This is a prototype of a float8 training UX in native PyTorch, with full PT2.0 and distributed support." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "Operating System :: OS Independent", 14 | ] 15 | 16 | dependencies = [ 17 | "torch >= 2.3", 18 | ] 19 | 20 | [project.optional-dependencies] 21 | test = [ 22 | "pandas >= 2.0", 23 | "tqdm==4.66.2", 24 | "fire==0.5.0", 25 | "expecttest", 26 | ] 27 | dev = [ 28 | "black==23.3.0", 29 | "usort==1.0.6", 30 | "ufmt==2.1.0", 31 | "libcst==1.1.0", 32 | "pytest==7.4.0", 33 | "bumpver", 34 | "pip-tools", 35 | "ruff==0.3.0" 36 | ] 37 | # ---------- TOOL CONFIGURATIONS ------------ 38 | [tool.usort] 39 | first_party_detection = false 40 | 41 | [tool.black] 42 | target-version = ["py310"] 43 | 44 | [tool.ruff] 45 | # Exclude a variety of commonly ignored directories. 46 | exclude = [ 47 | ".bzr", 48 | ".direnv", 49 | ".eggs", 50 | ".git", 51 | ".git-rewrite", 52 | ".hg", 53 | ".ipynb_checkpoints", 54 | ".mypy_cache", 55 | ".nox", 56 | ".pants.d", 57 | ".pyenv", 58 | ".pytest_cache", 59 | ".pytype", 60 | ".ruff_cache", 61 | ".svn", 62 | ".tox", 63 | ".venv", 64 | ".vscode", 65 | "__pypackages__", 66 | "_build", 67 | "buck-out", 68 | "build", 69 | "dist", 70 | "node_modules", 71 | "site-packages", 72 | "venv", 73 | ] 74 | 75 | # Same as Black. 76 | line-length = 88 77 | indent-width = 4 78 | 79 | # Assume Python 3.10 80 | target-version = "py310" 81 | 82 | [tool.ruff.lint] 83 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 84 | select = ["E4", "E7", "E9", "F"] 85 | ignore = ["E731"] 86 | 87 | # Allow fix for all enabled rules (when `--fix`) is provided. 88 | fixable = ["ALL"] 89 | unfixable = [] 90 | 91 | # Allow unused variables when underscore-prefixed. 92 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 93 | -------------------------------------------------------------------------------- /test/test_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
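# These tests are run via `pytest test/test_compile.py` (see test/test_everything.sh).
# All of them require CUDA; the inductor and amax-sync tests below additionally
# gate on `is_H100`, and the others fall back to `emulate=True` on non-H100 GPUs.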
6 | import copy 7 | import random 8 | import sys 9 | import unittest 10 | from io import StringIO 11 | 12 | import pytest 13 | 14 | import torch 15 | import torch.nn as nn 16 | from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType 17 | from float8_experimental.float8_linear import Float8Linear 18 | from float8_experimental.float8_linear_utils import ( 19 | convert_to_float8_training, 20 | get_float8_layers, 21 | sync_float8_amax_and_scale_history, 22 | ) 23 | from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_delayed 24 | from float8_experimental.float8_tensor import LinearMMConfig 25 | from float8_experimental.float8_utils import e4m3_dtype 26 | 27 | from torch._dynamo.test_case import TestCase as DynamoTestCase 28 | from torch._dynamo.testing import CompileCounterWithBackend 29 | 30 | is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) 31 | 32 | 33 | def _test_compile_base( 34 | backend: str, 35 | fullgraph: bool, 36 | config: Float8LinearConfig, 37 | dtype: torch.dtype, 38 | ): 39 | random.seed(0) 40 | torch.manual_seed(0) 41 | x_shape = (16, 16) 42 | linear_dtype = torch.bfloat16 43 | 44 | x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) 45 | m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) 46 | 47 | m_fp8 = Float8Linear.from_float( 48 | copy.deepcopy(m_ref), 49 | config, 50 | ) 51 | 52 | m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) 53 | m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) 54 | y_fp8 = m_fp8(x) 55 | y_fp8.sum().backward() 56 | y_ref = m_ref(x) 57 | y_ref.sum().backward() 58 | torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2) 59 | torch.testing.assert_close( 60 | m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1 61 | ) 62 | torch.testing.assert_close(m_fp8.bias.grad, m_ref.bias.grad, atol=8e-2, rtol=8e-2) 63 | 64 | 65 | @pytest.mark.parametrize("fullgraph", [True]) 66 | @pytest.mark.parametrize( 67 | "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] 68 | ) 69 | @pytest.mark.parametrize( 70 | "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] 71 | ) 72 | @pytest.mark.parametrize( 73 | "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] 74 | ) 75 | @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) 76 | @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) 77 | @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 78 | def test_eager_only( 79 | fullgraph, 80 | emulate: bool, 81 | scaling_type_input: ScalingType, 82 | scaling_type_weight: ScalingType, 83 | scaling_type_grad_output: ScalingType, 84 | dtype: torch.dtype, 85 | ): 86 | torch._dynamo.reset() 87 | config = Float8LinearConfig( 88 | cast_config_input=CastConfig(scaling_type=scaling_type_input), 89 | cast_config_weight=CastConfig(scaling_type=scaling_type_weight), 90 | cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), 91 | emulate=emulate, 92 | ) 93 | _test_compile_base( 94 | "eager", 95 | fullgraph, 96 | config, 97 | dtype, 98 | ) 99 | 100 | 101 | @pytest.mark.parametrize("fullgraph", [True]) 102 | @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) 103 | @pytest.mark.parametrize( 104 | "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] 105 | ) 106 | @pytest.mark.parametrize( 107 | "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] 108 | ) 109 | 
@pytest.mark.parametrize( 110 | "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] 111 | ) 112 | @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) 113 | @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 114 | def test_aot_eager( 115 | fullgraph, 116 | emulate: bool, 117 | scaling_type_input: ScalingType, 118 | scaling_type_weight: ScalingType, 119 | scaling_type_grad_output: ScalingType, 120 | dtype: torch.dtype, 121 | ): 122 | torch._dynamo.reset() 123 | config = Float8LinearConfig( 124 | cast_config_input=CastConfig(scaling_type=scaling_type_input), 125 | cast_config_weight=CastConfig(scaling_type=scaling_type_weight), 126 | cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), 127 | emulate=emulate, 128 | ) 129 | _test_compile_base( 130 | "aot_eager", 131 | fullgraph, 132 | config, 133 | dtype, 134 | ) 135 | 136 | 137 | @pytest.mark.parametrize("fullgraph", [True]) 138 | @pytest.mark.parametrize("emulate", [False]) 139 | @pytest.mark.parametrize( 140 | "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] 141 | ) 142 | @pytest.mark.parametrize( 143 | "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] 144 | ) 145 | @pytest.mark.parametrize( 146 | "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC] 147 | ) 148 | @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") 149 | @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) 150 | def test_inductor( 151 | fullgraph, 152 | emulate: bool, 153 | scaling_type_input: ScalingType, 154 | scaling_type_weight: ScalingType, 155 | scaling_type_grad_output: ScalingType, 156 | dtype: torch.dtype, 157 | ): 158 | torch._dynamo.reset() 159 | config = Float8LinearConfig( 160 | cast_config_input=CastConfig(scaling_type=scaling_type_input), 161 | cast_config_weight=CastConfig(scaling_type=scaling_type_weight), 162 | cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), 163 | emulate=emulate, 164 | ) 165 | _test_compile_base( 166 | "inductor", 167 | fullgraph, 168 | config, 169 | dtype, 170 | ) 171 | 172 | 173 | class TestGraphBreaks(DynamoTestCase): 174 | class MockLinear(torch.nn.Module): 175 | def __init__(self, graph_break: bool): 176 | super().__init__() 177 | self.register_buffer("fp8_amax_x", torch.tensor(1.0)) 178 | self.register_buffer("fp8_scale_x", torch.tensor(1.0)) 179 | self.graph_break = graph_break 180 | 181 | def forward(self, x): 182 | x_fp8 = hp_tensor_to_float8_delayed( 183 | x, 184 | self.fp8_scale_x, 185 | e4m3_dtype, 186 | self.fp8_amax_x, 187 | LinearMMConfig(), 188 | ) 189 | if self.graph_break: 190 | torch._dynamo.graph_break() 191 | x_hp = x_fp8.to_original_precision() 192 | return x_hp 193 | return x_fp8 194 | 195 | @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 196 | def test_float8_with_graph_break_in_the_middle(self): 197 | """Test that having Float8Tensor object at the boundary of a subgraph""" 198 | cnts = CompileCounterWithBackend("inductor") 199 | mod = self.MockLinear(graph_break=True).cuda() 200 | compiled_mod = copy.deepcopy(mod) 201 | compiled_mod = torch.compile(compiled_mod, backend=cnts) 202 | x = torch.randn(16, 16, device="cuda") 203 | y_eager = mod(x) 204 | y_compiled = compiled_mod(x) 205 | self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!") 206 | torch.testing.assert_close(y_eager, y_compiled) 207 | 208 | @unittest.skipIf(not 
torch.cuda.is_available(), "CUDA not available") 209 | def test_float8_graph_input(self): 210 | """Test that having Float8Tensor object as a graph input""" 211 | 212 | def to_float(x): 213 | return x.to_original_precision() 214 | 215 | cnts = CompileCounterWithBackend("inductor") 216 | mod = self.MockLinear(graph_break=False).cuda() 217 | x = torch.randn(2, 2, device="cuda") 218 | compiled_to_float = torch.compile(to_float, backend=cnts) 219 | y = mod(x) 220 | y2_eager = to_float(y) 221 | y2_compiled = compiled_to_float(y) 222 | self.assertEqual( 223 | cnts.frame_count, 224 | 1, 225 | "to_float was not compiled into 1 frame and likely encountered a skip!", 226 | ) 227 | torch.testing.assert_close(y2_eager, y2_compiled) 228 | 229 | @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") 230 | def test_float8_graph_output(self): 231 | """Test that having Float8Tensor object as a graph output works""" 232 | cnts = CompileCounterWithBackend("inductor") 233 | mod = self.MockLinear(graph_break=False).cuda() 234 | compiled_mod = torch.compile(mod, backend=cnts) 235 | x = torch.randn(16, 16, device="cuda") 236 | y_compiled = compiled_mod(x) 237 | 238 | self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") 239 | tensors, ctx = y_compiled.__tensor_flatten__() 240 | for tensor in tensors: 241 | assert not isinstance( 242 | getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor 243 | ), "Float8Tensor should not contain any FakeTensors!" 244 | assert isinstance( 245 | y_compiled._orig_dtype, torch.dtype 246 | ), "Float8Tensor._orig_dtype should be a dtype but got {}".format( 247 | type(y_compiled._orig_dtype) 248 | ) 249 | assert isinstance( 250 | y_compiled._linear_mm_config.output.emulate, bool 251 | ), "Float8Tensor._emulate should be a bool but got {}".format( 252 | type(y_compiled._linear_mm_config.output.emulate) 253 | ) 254 | 255 | 256 | @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") 257 | def test_sync_amax_func(): 258 | torch._dynamo.reset() 259 | cnts = CompileCounterWithBackend("inductor") 260 | module = torch.nn.Sequential( 261 | nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) 262 | ) 263 | config = Float8LinearConfig( 264 | cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), 265 | cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), 266 | cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), 267 | ) 268 | float8_mod = convert_to_float8_training( 269 | module, 270 | config=config, 271 | ) 272 | compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts) 273 | compiled_swap_func(float8_mod) 274 | assert cnts.frame_count == 1, "Compiled graph should have 1 frame!" 
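
# For reference, the eager-mode training-loop pattern that
# sync_float8_amax_and_scale_history targets looks roughly like this
# (an illustrative sketch; `float8_mod`, `optim` and `x` are placeholders):
#
#   y = float8_mod(x)
#   y.sum().backward()
#   sync_float8_amax_and_scale_history(float8_mod)  # needed for delayed scaling
#   optim.step()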
275 | 276 | 277 | class capture_stderr(list): 278 | """ 279 | Replace sys.stderr with a temporary StringIO 280 | """ 281 | 282 | def __enter__(self): 283 | self.sys_stderr = sys.stderr 284 | self.stringio = StringIO() 285 | sys.stderr = self.stringio 286 | return self 287 | 288 | def __exit__(self, *args): 289 | self.append(str(self.stringio.getvalue())) 290 | del self.stringio 291 | sys.stderr = self.sys_stderr 292 | 293 | 294 | @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") 295 | def test_sync_amax_func_cuda_graph_success(): 296 | torch._dynamo.reset() 297 | with capture_stderr() as stderr: 298 | my_module = nn.Sequential( 299 | nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) 300 | ).to("cuda") 301 | config = Float8LinearConfig( 302 | cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), 303 | cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), 304 | cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), 305 | ) 306 | convert_to_float8_training( 307 | my_module, 308 | config=config, 309 | ) 310 | inpt = torch.randn( 311 | 16, 16, device="cuda", dtype=torch.float32, requires_grad=True 312 | ) 313 | sync_func = torch.compile( 314 | sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True 315 | ) 316 | fp8_layers = get_float8_layers(my_module) 317 | my_module(inpt) 318 | sync_func(my_module, fp8_layers) 319 | 320 | assert "skipping cudagraphs due to mutaton on input" not in stderr[0] 321 | 322 | 323 | if __name__ == "__main__": 324 | pytest.main([__file__]) 325 | -------------------------------------------------------------------------------- /test/test_dtensor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | """ 7 | Test numerics of manually defined float16 TP vs float8 TP of toy models 8 | """ 9 | 10 | import copy 11 | import os 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from float8_experimental import Float8LinearConfig 17 | from float8_experimental.float8_linear_utils import convert_to_float8_training 18 | 19 | from float8_experimental.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic 20 | from float8_experimental.float8_tensor import ( 21 | Float8Tensor, 22 | GemmInputRole, 23 | hp_tensor_and_scale_to_float8, 24 | LinearMMConfig, 25 | ) 26 | from float8_experimental.float8_tensor_parallel import ( 27 | Float8ColwiseParallel, 28 | Float8RowwiseParallel, 29 | PrepareFloat8ModuleInput, 30 | ) 31 | from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale 32 | from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard 33 | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh 34 | from torch.distributed.tensor.parallel import parallelize_module 35 | from tqdm import tqdm 36 | 37 | 38 | def setup_distributed(): 39 | world_size = int(os.environ.get("WORLD_SIZE", -1)) 40 | device_mesh = init_device_mesh("cuda", (world_size,)) 41 | # seed must be the same in all processes 42 | torch.manual_seed(1) 43 | return device_mesh 44 | 45 | 46 | class FeedForward(nn.Module): 47 | """MLP based model""" 48 | 49 | def __init__(self): 50 | super(FeedForward, self).__init__() 51 | self.w1 = nn.Linear(16, 32, bias=False) 52 | self.w2 = nn.Linear(16, 32, bias=False) 53 | self.out_proj = nn.Linear(32, 16, bias=False) 54 | 55 | def forward(self, x): 56 | return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) 57 | 58 | 59 | class ToyModel(nn.Module): 60 | def __init__(self): 61 | super(ToyModel, self).__init__() 62 | self.ffn = FeedForward() 63 | 64 | def forward(self, x): 65 | return self.ffn(x) 66 | 67 | 68 | def test_scaled_mm(mesh: DeviceMesh, size=16): 69 | device = mesh.device_type 70 | fp8_dtype = e4m3_dtype 71 | world_size = mesh.size() 72 | 73 | x_fp32 = torch.rand(size, size, device=device) 74 | y_fp32 = torch.eye(size, device=device).t() 75 | 76 | placement_combs = ( 77 | (Shard(0), Replicate()), 78 | (Replicate(), Shard(1)), 79 | (Shard(1), Shard(0)), 80 | ) 81 | expected_dt_out_shape = ( 82 | (size * world_size, size), 83 | (size, size * world_size), 84 | (size, size), 85 | ) 86 | for idx, (lhs_placement, rhs_placement) in enumerate(placement_combs): 87 | x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() 88 | y_scale = tensor_to_scale(y_fp32, fp8_dtype).float() 89 | 90 | x_fp8 = hp_tensor_and_scale_to_float8( 91 | x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT 92 | ) 93 | y_fp8 = hp_tensor_and_scale_to_float8( 94 | y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT 95 | ) 96 | 97 | dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False) 98 | dist_y_fp8 = DTensor.from_local(y_fp8, mesh, [rhs_placement], run_check=False) 99 | 100 | assert isinstance(dist_x_fp8.to_local(), Float8Tensor) 101 | assert isinstance(dist_y_fp8.to_local(), Float8Tensor) 102 | assert dist_x_fp8.to_local()._orig_dtype == torch.float32 103 | out_fp8 = torch.mm(dist_x_fp8, dist_y_fp8) 104 | local_fp8_out = out_fp8.to_local() 105 | assert out_fp8.shape == expected_dt_out_shape[idx], (idx, local_fp8_out.shape) 106 | 107 | # after mm the out dtype should be fp32 108 | assert local_fp8_out.dtype == torch.float32 109 | 110 | 111 | def test_fp8_redistribute(mesh: DeviceMesh, size=16): 112 | 
device = mesh.device_type 113 | fp8_dtype = e4m3_dtype 114 | world_size = mesh.size() 115 | 116 | x_fp32 = torch.rand(size, size, device=device) 117 | 118 | x_scale = tensor_to_scale(x_fp32, fp8_dtype).float() 119 | 120 | x_fp8 = hp_tensor_and_scale_to_float8(x_fp32, x_scale, fp8_dtype) 121 | 122 | dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [Shard(0)], run_check=False) 123 | out_dist = dist_x_fp8.redistribute(placements=[Replicate()]) 124 | assert out_dist.shape == (size * world_size, size) 125 | assert out_dist.placements == (Replicate(),) 126 | out_local = out_dist.to_local() 127 | # after allgather the out shape should be replicate 128 | assert out_local.shape == (size * world_size, size) 129 | from torch.distributed._functional_collectives import AsyncCollectiveTensor 130 | 131 | if isinstance(out_local, AsyncCollectiveTensor): 132 | out_local = out_local.wait() 133 | 134 | assert isinstance(out_local, Float8Tensor) 135 | assert out_local._data.dtype == fp8_dtype 136 | 137 | 138 | def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16): 139 | device = mesh.device_type 140 | fp8_dtype = e4m3_dtype 141 | 142 | x_fp32 = torch.rand(size, size, device=device) 143 | dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) 144 | 145 | dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() 146 | assert isinstance(dist_x_scale, DTensor) 147 | 148 | dist_x_fp8 = hp_tensor_and_scale_to_float8(dist_x_fp32, dist_x_scale, fp8_dtype) 149 | assert isinstance(dist_x_fp8, DTensor) 150 | 151 | 152 | def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): 153 | device = mesh.device_type 154 | fp8_dtype = e4m3_dtype 155 | 156 | x_fp32 = torch.rand(size, size, device=device, requires_grad=True) 157 | local_weight = torch.rand(2 * size, size, device=device, requires_grad=True) 158 | target = torch.rand(size, 2 * size, device=device) 159 | 160 | dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) 161 | dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float() 162 | 163 | dist_wight_fp32 = distribute_tensor(local_weight, mesh, [Shard(0)]) 164 | dist_weight_scale = tensor_to_scale(dist_wight_fp32, fp8_dtype).float() 165 | dist_target = distribute_tensor(target, mesh, [Shard(0)]) 166 | 167 | dist_x_fp8 = hp_tensor_and_scale_to_float8( 168 | dist_x_fp32, 169 | dist_x_scale, 170 | fp8_dtype, 171 | None, 172 | GemmInputRole.INPUT, 173 | ) 174 | dist_weight_fp8 = hp_tensor_and_scale_to_float8( 175 | dist_wight_fp32, 176 | dist_weight_scale, 177 | fp8_dtype, 178 | None, 179 | GemmInputRole.WEIGHT, 180 | ) 181 | 182 | out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) 183 | out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig()) 184 | assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" 185 | loss = torch.sum(torch.abs(out - dist_target)) 186 | loss.backward() 187 | 188 | 189 | def _test_fp8_mlp_tensor_parallelism_base( 190 | mesh: DeviceMesh, size=16, compile: bool = False 191 | ): 192 | device = mesh.device_type 193 | # For now, only supports dynamic scaling of `x` and `dL_dY`. 194 | # TODO(future): add support for float8 all-gather with delayed scaling 195 | # for activations and gradients. 
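    # Note: emulate=True appears to run the float8 matmuls in emulation rather than
    # requiring native hardware float8 support (compare the `is_H100` gating in
    # test/test_compile.py, where emulate=False is only exercised on H100).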
196 | config = Float8LinearConfig(emulate=True) 197 | 198 | toy_model = ToyModel().to(device) 199 | toy_model_fp8 = convert_to_float8_training(toy_model, config=config) 200 | 201 | tp_model = copy.deepcopy(toy_model) 202 | tp_model = convert_to_float8_training(tp_model, config=config) 203 | sp_model = copy.deepcopy(toy_model) 204 | sp_model = convert_to_float8_training(sp_model, config=config) 205 | 206 | # vanilla TP 207 | tp_model = parallelize_module( 208 | tp_model, 209 | mesh, 210 | { 211 | "ffn.w1": Float8ColwiseParallel(), 212 | "ffn.w2": Float8ColwiseParallel(), 213 | "ffn.out_proj": Float8RowwiseParallel(), 214 | }, 215 | ) 216 | 217 | # "sequence parallel" mlp computation 218 | sp_model = parallelize_module( 219 | sp_model, 220 | mesh, 221 | { 222 | "ffn": PrepareFloat8ModuleInput( 223 | input_layouts=Shard(1), desired_input_layouts=Replicate() 224 | ), 225 | "ffn.w1": Float8ColwiseParallel(), 226 | "ffn.w2": Float8ColwiseParallel(), 227 | "ffn.out_proj": Float8RowwiseParallel( 228 | output_layouts=Shard(1), use_local_output=False 229 | ), 230 | }, 231 | ) 232 | 233 | # PrepareFloat8ModuleInput with specific submodule fqn 234 | sp_model2 = copy.deepcopy(toy_model) 235 | sp_model2 = convert_to_float8_training(sp_model2, config=config) 236 | 237 | sp_model2 = parallelize_module( 238 | sp_model2, 239 | mesh, 240 | { 241 | "ffn": PrepareFloat8ModuleInput( 242 | input_layouts=Shard(1), 243 | desired_input_layouts=Replicate(), 244 | fwd_config_submodule_fqn="w2", 245 | ), 246 | "ffn.w1": Float8ColwiseParallel(), 247 | "ffn.w2": Float8ColwiseParallel(), 248 | "ffn.out_proj": Float8RowwiseParallel( 249 | output_layouts=Shard(1), use_local_output=False 250 | ), 251 | }, 252 | ) 253 | 254 | if compile: 255 | tp_model = torch.compile(tp_model) 256 | sp_model = torch.compile(sp_model) 257 | sp_model2 = torch.compile(sp_model2) 258 | 259 | x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) 260 | x_fp32_tp_input = x_fp32.clone() 261 | x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) 262 | 263 | tp_out = tp_model(x_fp32_tp_input) 264 | tp_out.sum().backward() 265 | sp_out = sp_model(x_fp32_sp_input) 266 | sp_out.sum().backward() 267 | global_out = toy_model_fp8(x_fp32) 268 | global_out.sum().backward() 269 | torch.testing.assert_close(tp_out, global_out) 270 | torch.testing.assert_close(sp_out.full_tensor(), global_out) 271 | torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad) 272 | torch.testing.assert_close( 273 | tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad 274 | ) 275 | 276 | sp_out2 = sp_model2(x_fp32_sp_input) 277 | sp_out2.sum().backward() 278 | torch.testing.assert_close(sp_out2.full_tensor(), global_out) 279 | torch.testing.assert_close( 280 | tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad 281 | ) 282 | torch.testing.assert_close( 283 | tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad 284 | ) 285 | 286 | 287 | def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): 288 | _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False) 289 | 290 | 291 | def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): 292 | _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) 293 | 294 | 295 | if __name__ == "__main__": 296 | # float8 only works on CUDA H100 so we only test cuda and we follow 297 | # other test files to not use TestCase but instead just add the test 298 | # cases in the main func. 
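    # This file is meant to be launched with torchrun, e.g. via test/test_dtensor.sh:
    #   NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/test_dtensor.py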
299 | device_mesh = setup_distributed() 300 | tests = [ 301 | test_scaled_mm, 302 | test_fp8_redistribute, 303 | test_dtensor_cast_to_fp8, 304 | test_dtensor_fp8_autograd, 305 | test_fp8_mlp_tensor_parallelism_eager, 306 | test_fp8_mlp_tensor_parallelism_compile, 307 | ] 308 | 309 | for test in tqdm(tests, desc="Running tests"): 310 | try: 311 | test(device_mesh) 312 | except Exception as e: 313 | print(f"Test {test.__name__} failed with error: {e}") 314 | raise e 315 | 316 | torch.distributed.destroy_process_group() 317 | -------------------------------------------------------------------------------- /test/test_dtensor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # terminate script on first error 4 | set -e 5 | 6 | if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then 7 | echo "Skipping test_dtensor.sh because no CUDA devices are available." 8 | exit 9 | fi 10 | 11 | NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/test_dtensor.py 12 | -------------------------------------------------------------------------------- /test/test_everything.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # terminate script on first error 4 | set -e 5 | IS_ROCM=$(rocm-smi --version || true) 6 | 7 | pytest test/test_base.py 8 | pytest test/test_compile.py 9 | pytest test/test_inference_flows.py 10 | pytest test/test_numerics_integration.py 11 | 12 | # These tests do not work on ROCm yet 13 | if [ -z "$IS_ROCM" ] 14 | then 15 | ./test/test_fsdp.sh 16 | ./test/test_fsdp_compile.sh 17 | ./test/test_dtensor.sh 18 | pytest test/test_fsdp2/test_fsdp2.py 19 | fi 20 | 21 | echo "all tests successful" 22 | -------------------------------------------------------------------------------- /test/test_fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Test numerics of bf16 versus float8 with FSDP on. At a high level: 8 | 1. start with a reference model, with FSDP on 9 | 2. run forward + backward + optim for 2 iterations 10 | 3. repeat 2 with float8 enabled (2 iterations needed for delayed scaling) 11 | 4. 
compare outputs and state dict between (2) and (3), should be close 12 | """ 13 | 14 | import copy 15 | import os 16 | import warnings 17 | 18 | import fire 19 | 20 | import torch 21 | import torch.distributed as dist 22 | import torch.multiprocessing as mp 23 | import torch.nn as nn 24 | from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType 25 | from float8_experimental.float8_linear_utils import ( 26 | convert_to_float8_training, 27 | linear_requires_sync, 28 | sync_float8_amax_and_scale_history, 29 | ) 30 | from float8_experimental.float8_utils import compute_error 31 | from torch.distributed.fsdp import ( 32 | FullStateDictConfig, 33 | FullyShardedDataParallel as FSDP, 34 | StateDictType, 35 | ) 36 | 37 | torch.manual_seed(0) 38 | 39 | B, M, K, N = 8, 8, 32, 32 40 | lr = 0.01 41 | N_ITER = 2 42 | 43 | 44 | def setup(rank, world_size): 45 | os.environ["MASTER_ADDR"] = "localhost" 46 | os.environ["MASTER_PORT"] = "12355" 47 | 48 | # initialize the process group 49 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 50 | 51 | 52 | def cleanup(): 53 | dist.destroy_process_group() 54 | 55 | 56 | def get_model(K, N, base_dtype=torch.float32): 57 | m = nn.Sequential( 58 | nn.Linear(K, N, dtype=base_dtype), 59 | nn.ReLU(), 60 | nn.Linear(N, N, dtype=base_dtype), 61 | nn.ReLU(), 62 | ) 63 | return m 64 | 65 | 66 | # taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html 67 | # and modified 68 | def fsdp_main(rank, world_size, args): 69 | setup(rank, world_size) 70 | torch.cuda.set_device(rank) 71 | 72 | emulate, base_dtype, compile, use_weight_dynamic_scaling = args 73 | model = get_model(K, N, base_dtype=base_dtype).to(rank) 74 | model_fp8 = copy.deepcopy(model) 75 | 76 | scaling_type_weight = ( 77 | ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED 78 | ) 79 | config = Float8LinearConfig( 80 | cast_config_weight=CastConfig(scaling_type=scaling_type_weight), 81 | # TODO(future): delete this arg as it's always False 82 | emulate=False, 83 | ) 84 | 85 | # Note: we only iterate over `scaling_type_weight` because FSDP only interacts 86 | # with weights. 87 | convert_to_float8_training( 88 | model_fp8, 89 | config=config, 90 | ) 91 | 92 | # To compile FSDP, we need use_orig_params to True 93 | model = FSDP(model, use_orig_params=True) 94 | model_fp8 = FSDP(model_fp8, use_orig_params=True) 95 | # TODO: The following line doesn't work. We should fix it. 
96 | # model = FSDP(torch.compile(model), use_orig_params=True) 97 | 98 | optimizer = torch.optim.SGD(model.parameters(), lr=lr) 99 | optimizer_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr) 100 | 101 | # Note: we need two different inputs to properly measure the impact of 102 | # delayed scaling, because the first input uses dynamic scaling to 103 | # populate the buffers 104 | ref_input_global = [ 105 | torch.randn(B, M, K).cuda().to(base_dtype), 106 | torch.randn(B, M, K).cuda().to(base_dtype), 107 | ] 108 | ref_grad_global = [ 109 | torch.randn(B, M, N).cuda().to(base_dtype), 110 | torch.randn(B, M, N).cuda().to(base_dtype), 111 | ] 112 | ref_input_local = [] 113 | ref_grad_local = [] 114 | 115 | # basic distributed data sampling 116 | assert B % world_size == 0 117 | bsz_local_start = int(rank / world_size * B) 118 | bsz_local_end = int((rank + 1) / world_size * B) 119 | for idx in range(N_ITER): 120 | ref_input_local.append( 121 | ref_input_global[idx][bsz_local_start:bsz_local_end].to(rank) 122 | ) 123 | ref_grad_local.append( 124 | ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank) 125 | ) 126 | 127 | sync_float8_func = sync_float8_amax_and_scale_history 128 | if compile: 129 | sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) 130 | 131 | def forward_backward(model, optim, is_fp8, i): 132 | optim.zero_grad() 133 | y_local = model(ref_input_local[i]) 134 | y_local.backward(ref_grad_local[i]) 135 | if is_fp8 and linear_requires_sync(config): 136 | sync_float8_func(model) 137 | optim.step() 138 | return y_local 139 | 140 | for i in range(N_ITER): 141 | # We first run one iteration without compile, as a workaround to compile float8 layers. 142 | # In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False" 143 | # After that, float8 layers go to the branches of "self.is_amax_initialized == True" 144 | # TODO: Need to fix compile to run without this workaround. 
145 | if i == 1 and compile: 146 | model = torch.compile(model) 147 | model_fp8 = torch.compile(model_fp8) 148 | y_local = forward_backward(model, optimizer, is_fp8=False, i=i) 149 | y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i) 150 | local_sqnr = compute_error(y_local, y_local_fp8) # noqa: F841 151 | 152 | # get global y 153 | y_global = [ 154 | torch.zeros(*y_local.shape, dtype=base_dtype).to(rank) 155 | for r in range(world_size) 156 | ] 157 | dist.all_gather(y_global, y_local) 158 | y_global = torch.cat(y_global, dim=0) 159 | y_global_fp8 = [ 160 | torch.zeros(*y_local_fp8.shape, dtype=base_dtype).to(rank) 161 | for r in range(world_size) 162 | ] 163 | dist.all_gather(y_global_fp8, y_local_fp8) 164 | y_global_fp8 = torch.cat(y_global_fp8, dim=0) 165 | if rank == 0: 166 | sqnr = compute_error(y_global, y_global_fp8) 167 | assert sqnr > 15.0, f"SQNR of {sqnr} is too low" 168 | 169 | # get global state dict 170 | # https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html 171 | dist.barrier() 172 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 173 | with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): 174 | cpu_state = model.state_dict() 175 | with FSDP.state_dict_type(model_fp8, StateDictType.FULL_STATE_DICT, save_policy): 176 | cpu_state_fp8 = model_fp8.state_dict() 177 | if rank == 0: 178 | for k, v1 in cpu_state.items(): 179 | v2 = cpu_state_fp8[k] 180 | v1, v2 = v1.cpu(), v2.cpu() 181 | sqnr = compute_error(v1, v2) 182 | assert sqnr > 15.0, f"SQNR of {sqnr} is too low, k: {k}, v1: {v1}, v2: {v2}" 183 | 184 | cleanup() 185 | 186 | 187 | def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): 188 | base_dtype = torch.bfloat16 189 | 190 | emulate = False 191 | if not torch.cuda.is_available(): 192 | warnings.warn("CUDA not available, running in emulation_mode") 193 | emulate = True 194 | elif torch.cuda.get_device_capability() < (9, 0): 195 | warnings.warn( 196 | f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode" 197 | ) 198 | emulate = True 199 | 200 | WORLD_SIZE = torch.cuda.device_count() 201 | args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling) 202 | mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) 203 | 204 | 205 | if __name__ == "__main__": 206 | fire.Fire(run) 207 | -------------------------------------------------------------------------------- /test/test_fsdp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # terminate script on first error 4 | set -e 5 | 6 | launch() { 7 | echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING" 8 | 9 | # the NCCL_DEBUG setting is to avoid log spew 10 | # the CUDA_VISIBLE_DEVICES setting is for easy debugging 11 | NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp.py \ 12 | --compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING 13 | 14 | echo "✅ All Tests Passed ✅" 15 | } 16 | 17 | if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then 18 | echo "Skipping test_fsdp.sh because no CUDA devices are available." 
19 | exit 20 | fi 21 | 22 | # COMPILE, USE_WEIGHT_DYNAMIC_SCALING 23 | for i in False,False False,True True,False True,True 24 | do 25 | IFS=","; set -- $i; 26 | COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2 27 | launch 28 | done 29 | -------------------------------------------------------------------------------- /test/test_fsdp2/test_fsdp2_common.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import List, Optional 3 | 4 | import float8_experimental.config as config 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | from float8_experimental.config import Float8LinearConfig, ScalingType 10 | from float8_experimental.float8_linear_utils import ( 11 | linear_requires_sync, 12 | sync_float8_amax_and_scale_history, 13 | ) 14 | from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp 15 | 16 | 17 | def check_parity_no_mp( 18 | test_cls, 19 | ref_model: nn.Module, 20 | ref_optim: torch.optim.Optimizer, 21 | fsdp_model: nn.Module, 22 | fsdp_optim: torch.optim.Optimizer, 23 | local_inp: torch.Tensor, 24 | precompute: bool = False, 25 | config: Optional[Float8LinearConfig] = None, 26 | compile_transformer_block: bool = False, 27 | ): 28 | # TODO(before land): reorder args and make config not optional 29 | for iter_idx in range(10): 30 | losses: List[torch.Tensor] = [] 31 | for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): 32 | optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 33 | losses.append(model(local_inp).sum()) 34 | losses[-1].backward() 35 | if model is ref_model: 36 | for param in model.parameters(): 37 | dist.all_reduce(param.grad) 38 | param.grad.div_(dist.get_world_size()) 39 | 40 | if linear_requires_sync(config): 41 | sync_float8_amax_and_scale_history(model) 42 | 43 | optim.step() 44 | if ( 45 | model is fsdp_model 46 | and precompute 47 | and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC 48 | ): 49 | precompute_float8_dynamic_scale_for_fsdp(model) 50 | 51 | if compile_transformer_block: 52 | test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4) 53 | else: 54 | test_cls.assertEqual(losses[0], losses[1]) 55 | 56 | 57 | def check_parity_bf16_mp( 58 | test_cls, 59 | ref_model: nn.Module, 60 | ref_model_bf16: nn.Module, 61 | ref_optim: torch.optim.Optimizer, 62 | fsdp_model: nn.Module, 63 | fsdp_optim: torch.optim.Optimizer, 64 | local_inp: torch.Tensor, 65 | ): 66 | for iter_idx in range(10): 67 | losses: List[torch.Tensor] = [] 68 | for model, optim in ( 69 | (ref_model_bf16, ref_optim), 70 | (fsdp_model, fsdp_optim), 71 | ): 72 | optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) 73 | losses.append(model(local_inp).sum()) 74 | losses[-1].backward() 75 | if model is ref_model_bf16: 76 | for param_bf16, param_fp32 in zip( 77 | ref_model_bf16.parameters(), ref_model.parameters() 78 | ): 79 | dist.all_reduce(param_bf16.grad) 80 | param_bf16.grad.div_(dist.get_world_size()) 81 | param_fp32.grad = param_bf16.grad.float() 82 | param_bf16.grad = None 83 | # TODO(future): add amax syncing once delayed scaling is supported 84 | optim.step() 85 | for param_fp32, param_bf16 in zip( 86 | ref_model.parameters(), ref_model_bf16.parameters() 87 | ): 88 | param_bf16.detach().copy_(param_fp32) 89 | test_cls.assertEqual(losses[0], losses[1]) 90 | -------------------------------------------------------------------------------- /test/test_fsdp_compile.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Test autocast + torch.compile + FSDP + Float8Linear 9 | """ 10 | 11 | import os 12 | import warnings 13 | 14 | import fire 15 | 16 | import torch 17 | import torch.distributed as dist 18 | import torch.multiprocessing as mp 19 | import torch.nn as nn 20 | from float8_experimental import Float8LinearConfig 21 | from float8_experimental.config import CastConfig, ScalingType 22 | from float8_experimental.float8_linear_utils import ( 23 | convert_to_float8_training, 24 | sync_float8_amax_and_scale_history, 25 | ) 26 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 27 | 28 | torch.manual_seed(0) 29 | 30 | B, M, K, N = 8, 8, 32, 32 31 | lr = 0.01 32 | N_ITER = 1 33 | 34 | 35 | def setup(rank, world_size): 36 | os.environ["MASTER_ADDR"] = "localhost" 37 | os.environ["MASTER_PORT"] = "12355" 38 | 39 | # initialize the process group 40 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 41 | 42 | 43 | def cleanup(): 44 | dist.destroy_process_group() 45 | 46 | 47 | def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): 48 | # composability of torch.compile + FSDP + autocast + Float8Linear 49 | # as of 2023-12-30 50 | 51 | # without any changes to the Float8Linear, we get this error: 52 | # https://gist.github.com/vkuzo/3bcb81806cc92f99ac0b9c5fdf287730 53 | 54 | # if we initialize Float8Linear with is_amax_initialized=True and 55 | # amax_and_scale_synced=True, we get 56 | # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55 57 | # to get around this, we can disable amax init 58 | config = Float8LinearConfig( 59 | enable_amax_init=False, 60 | cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), 61 | cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), 62 | cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), 63 | emulate=emulate, 64 | ) 65 | 66 | m = nn.Sequential( 67 | nn.Linear(K, N, dtype=base_dtype), 68 | nn.ReLU(), 69 | ) 70 | convert_to_float8_training( 71 | m, 72 | config=config, 73 | ) 74 | return m 75 | 76 | 77 | # taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html 78 | # and modified 79 | def fsdp_main(rank, world_size, args): 80 | setup(rank, world_size) 81 | torch.cuda.set_device(rank) 82 | 83 | (emulate,) = args 84 | 85 | # finally, if we remove the usage of self.bias_dtype, then 86 | # things work e2e. Note that FSDP does not support full-graph compile 87 | # regardless of float8. 
88 | 89 | model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to( 90 | rank 91 | ) 92 | 93 | # To compile FSDP, we need use_orig_params to True 94 | model = FSDP(model, use_orig_params=True) 95 | 96 | optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) 97 | input_local = torch.randn(B, M, K, N, device="cuda") 98 | sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) 99 | 100 | model = torch.compile(model) 101 | 102 | for _iter in range(N_ITER): 103 | optimizer.zero_grad() 104 | with torch.autocast("cuda"): 105 | y_local = model(input_local) 106 | y_local.sum().backward() 107 | sync_float8_func(model) 108 | optimizer.step() 109 | 110 | print("done!") 111 | cleanup() 112 | 113 | 114 | def run(): 115 | emulate = False 116 | if not torch.cuda.is_available(): 117 | warnings.warn("CUDA not available, running in emulation_mode", stacklevel=2) 118 | emulate = True 119 | elif torch.cuda.get_device_capability() < (9, 0): 120 | warnings.warn( 121 | f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode", 122 | stacklevel=2, 123 | ) 124 | emulate = True 125 | 126 | WORLD_SIZE = torch.cuda.device_count() 127 | args = (emulate,) 128 | mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) 129 | 130 | 131 | if __name__ == "__main__": 132 | fire.Fire(run) 133 | -------------------------------------------------------------------------------- /test/test_fsdp_compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # terminate script on first error 4 | set -e 5 | if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then 6 | echo "Skipping test_fsdp_compile.sh because no CUDA devices are available." 7 | exit 8 | fi 9 | 10 | # Code to be executed if CUDA devices are available 11 | NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp_compile.py 12 | -------------------------------------------------------------------------------- /test/test_inference_flows.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import copy 7 | import io 8 | import random 9 | import unittest 10 | 11 | import pytest 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from float8_experimental.config import ScalingType 17 | from float8_experimental.float8_linear_utils import convert_to_float8_training 18 | from float8_experimental.float8_tensor import Float8Tensor 19 | from float8_experimental.float8_utils import compute_error 20 | from float8_experimental.inference import ( 21 | ActivationCasting, 22 | Float8InferenceLinear, 23 | QuantConfig, 24 | quantize_to_float8, 25 | ) 26 | 27 | 28 | random.seed(0) 29 | torch.manual_seed(0) 30 | 31 | is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) 32 | 33 | 34 | class FeedForward(nn.Module): 35 | def __init__(self) -> None: 36 | super().__init__() 37 | self.w1 = nn.Linear(4096, 14336, bias=False) 38 | self.w3 = nn.Linear(4096, 14336, bias=False) 39 | self.w2 = nn.Linear(14336, 4096, bias=False) 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 43 | 44 | def reset_parameters(self): 45 | for m in self.modules(): 46 | if isinstance(m, nn.Linear): 47 | m.reset_parameters() 48 | 49 | 50 | class TestHPTrainToFP8LinearInference: 51 | def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor): 52 | with torch.no_grad(): 53 | base_output = base_mlp(input_tensor) 54 | transformed_output = quantized_mlp(input_tensor) 55 | 56 | # Compute and check SQNR 57 | sqnr = compute_error(base_output, transformed_output) 58 | assert sqnr.item() > 20, f"SQNR is too low: {sqnr.item()} dB" 59 | 60 | @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) 61 | @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) 62 | @unittest.skipIf( 63 | not torch.cuda.is_available() or not is_H100, 64 | "CUDA not available or on non H100 machine", 65 | ) 66 | def test_dynamic_fp8_mlp(self, compile_backend, dtype): 67 | original_mlp = FeedForward().to("cuda", dtype=dtype) 68 | original_mlp.reset_parameters() 69 | 70 | dynamic_fp8_mlp = copy.deepcopy(original_mlp) 71 | 72 | quant_config = QuantConfig(ActivationCasting.DYNAMIC) 73 | quantize_to_float8(dynamic_fp8_mlp, quant_config) 74 | 75 | batch_size = 4 76 | num_tokens = 1024 77 | embedding_dim = 4096 78 | 79 | input_tensor = torch.randn( 80 | batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype 81 | ) 82 | 83 | # Compile the models 84 | compiled_original_mlp = torch.compile( 85 | original_mlp, backend=compile_backend, fullgraph=True 86 | ) 87 | compiled_dynamic_fp8_mlp = torch.compile( 88 | dynamic_fp8_mlp, backend=compile_backend, fullgraph=True 89 | ) 90 | 91 | self.base_test_mlp_transform( 92 | compiled_original_mlp, compiled_dynamic_fp8_mlp, input_tensor 93 | ) 94 | 95 | @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) 96 | @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) 97 | @unittest.skipIf( 98 | not torch.cuda.is_available() or not is_H100, 99 | "CUDA not available or on non H100 machine", 100 | ) 101 | def test_static_fp8_mlp(self, compile_backend, dtype): 102 | original_mlp = FeedForward().to("cuda", dtype=dtype) 103 | original_mlp.reset_parameters() 104 | 105 | static_fp8_mlp = copy.deepcopy(original_mlp) 106 | quant_config = QuantConfig( 107 | ActivationCasting.STATIC, 108 | static_quantization_scale=torch.tensor( 109 | [1.0], device="cuda", dtype=torch.float32 110 | ), 111 | ) 112 | quantize_to_float8(static_fp8_mlp, 
quant_config) 113 | 114 | batch_size = 4 115 | num_tokens = 1024 116 | embedding_dim = 4096 117 | 118 | input_tensor = torch.randn( 119 | batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype 120 | ) 121 | 122 | # Compile the models 123 | compiled_original_mlp = torch.compile( 124 | original_mlp, backend=compile_backend, fullgraph=True 125 | ) 126 | compiled_static_fp8_mlp = torch.compile( 127 | static_fp8_mlp, backend=compile_backend, fullgraph=True 128 | ) 129 | 130 | self.base_test_mlp_transform( 131 | compiled_original_mlp, compiled_static_fp8_mlp, input_tensor 132 | ) 133 | 134 | @pytest.mark.parametrize("compile_backend", ["eager", "inductor"]) 135 | @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) 136 | @unittest.skipIf( 137 | not torch.cuda.is_available() or not is_H100, 138 | "CUDA not available or on non H100 machine", 139 | ) 140 | def test_weight_only_fp8_mlp(self, compile_backend, dtype): 141 | original_mlp = FeedForward().to("cuda", dtype=dtype) 142 | original_mlp.reset_parameters() 143 | 144 | static_fp8_mlp = copy.deepcopy(original_mlp) 145 | quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY) 146 | quantize_to_float8(static_fp8_mlp, quant_config) 147 | 148 | batch_size = 4 149 | num_tokens = 1024 150 | embedding_dim = 4096 151 | 152 | input_tensor = torch.randn( 153 | batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype 154 | ) 155 | 156 | # Compile the models 157 | compiled_original_mlp = torch.compile( 158 | original_mlp, backend=compile_backend, fullgraph=True 159 | ) 160 | compiled_static_fp8_mlp = torch.compile( 161 | static_fp8_mlp, backend=compile_backend, fullgraph=True 162 | ) 163 | 164 | self.base_test_mlp_transform( 165 | compiled_original_mlp, compiled_static_fp8_mlp, input_tensor 166 | ) 167 | 168 | 169 | class TestFP8TrainToFP8LinearInference: 170 | def train(self, model: nn.Module, dtype: torch.dtype): 171 | model.train() 172 | optimizer = torch.optim.SGD(model.parameters(), lr=0.001) 173 | criterion = nn.MSELoss() 174 | target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) 175 | for _ in range(10): 176 | input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) 177 | optimizer.zero_grad() 178 | output = model(input_tensor) 179 | loss = criterion(output, target_tensor) 180 | loss.backward() 181 | optimizer.step() 182 | model.eval() 183 | return model 184 | 185 | @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) 186 | @unittest.skipIf( 187 | not torch.cuda.is_available() or not is_H100, 188 | "CUDA not available or on non H100 machine", 189 | ) 190 | def test_fp8_save_and_load(self, dtype: torch.dtype): 191 | # Initialize FP8 model 192 | fp8_mlp = FeedForward().to("cuda", dtype=torch.float32) 193 | fp8_mlp.reset_parameters() 194 | convert_to_float8_training(fp8_mlp) 195 | 196 | # Train the model 197 | self.train(fp8_mlp, dtype) 198 | 199 | # Generate input tensor and original out 200 | input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype) 201 | og_out = fp8_mlp(input_tensor) 202 | 203 | # Save model state dict 204 | buffer = io.BytesIO() 205 | torch.save(fp8_mlp.state_dict(), buffer) 206 | 207 | # Reset buffer position to the beginning 208 | buffer.seek(0) 209 | 210 | # Later on you load the model, will be w/ Float8Linear on meta device 211 | with torch.device("meta"): 212 | new_fp8_mlp = FeedForward().to(dtype=dtype) 213 | convert_to_float8_training(new_fp8_mlp) 214 | 215 | # Load the actual data 216 | new_fp8_mlp.load_state_dict( 217 | 
torch.load(buffer, weights_only=True), strict=True, assign=True 218 | ) 219 | 220 | quant_config = QuantConfig(ActivationCasting.DYNAMIC) 221 | quantize_to_float8(new_fp8_mlp, quant_config) 222 | 223 | fp8_mod_count = 0 224 | for module in new_fp8_mlp.modules(): 225 | if isinstance(module, Float8InferenceLinear): 226 | assert isinstance(module.weight, Float8Tensor) 227 | assert module.weight.requires_grad is False 228 | fp8_mod_count += 1 229 | assert fp8_mod_count == 3, "Expected 3 FP8 modules, got {}".format( 230 | fp8_mod_count 231 | ) 232 | 233 | new_out = new_fp8_mlp(input_tensor) 234 | 235 | # Assert exact equality 236 | assert torch.all(og_out == new_out).item() 237 | 238 | 239 | if __name__ == "__main__": 240 | pytest.main([__file__]) 241 | -------------------------------------------------------------------------------- /test/test_numerics_integration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD 3-Clause license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Tests LLaMa FeedForward numerics with float8 8 | 9 | import copy 10 | from typing import Optional 11 | 12 | import pytest 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType 18 | from float8_experimental.float8_linear_utils import ( 19 | convert_to_float8_training, 20 | linear_requires_sync, 21 | sync_float8_amax_and_scale_history, 22 | ) 23 | from float8_experimental.float8_utils import compute_error, IS_ROCM 24 | 25 | is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) 26 | 27 | 28 | torch.manual_seed(0) 29 | 30 | 31 | # copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py 32 | class FeedForward(nn.Module): 33 | """ 34 | FeedForward module 35 | 36 | Args: 37 | dim (int): Input dimension. 38 | hidden_dim (int): Hidden dimension of the feedforward layer. 39 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 40 | ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. 41 | 42 | Attributes: 43 | w1 (Linear): Linear transformation for the first layer. 44 | w2 (Linear): Linear transformation for the second layer. 45 | w3 (Linear): Linear transformation for the third layer. 
46 | 47 | """ 48 | 49 | def __init__( 50 | self, 51 | dim: int, 52 | hidden_dim: int, 53 | multiple_of: int, 54 | ffn_dim_multiplier: Optional[float], 55 | ): 56 | super().__init__() 57 | hidden_dim = int(2 * hidden_dim / 3) 58 | # custom dim factor multiplier 59 | if ffn_dim_multiplier is not None: 60 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 61 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 62 | 63 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 64 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 65 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 66 | 67 | def forward(self, x): 68 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 69 | 70 | def init_weights(self, init_std: float): 71 | nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) 72 | for linear in (self.w2, self.w3): 73 | nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) 74 | 75 | 76 | class TestFloat8NumericsIntegrationTest: 77 | @pytest.mark.parametrize( 78 | "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC] 79 | ) 80 | @pytest.mark.parametrize( 81 | "scaling_type_weight", [ScalingType.DELAYED, ScalingType.DYNAMIC] 82 | ) 83 | @pytest.mark.parametrize( 84 | "scaling_type_grad_output", 85 | [ScalingType.DELAYED, ScalingType.DYNAMIC], 86 | ) 87 | @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") 88 | @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") 89 | def test_encoder_fw_bw( 90 | self, 91 | scaling_type_input: ScalingType, 92 | scaling_type_weight: ScalingType, 93 | scaling_type_grad_output: ScalingType, 94 | ): 95 | # TODO(later): maybe add float16 back if it becomes important 96 | data_dtype = torch.bfloat16 97 | 98 | # LLaMa 3 70B shapes 99 | model_ref = ( 100 | FeedForward( 101 | dim=4096, 102 | hidden_dim=16384, 103 | multiple_of=1024, 104 | ffn_dim_multiplier=1.3, 105 | ) 106 | .cuda() 107 | .to(data_dtype) 108 | ) 109 | 110 | # for now just test the encoder to simplify things 111 | model_fp8 = copy.deepcopy(model_ref) 112 | config = Float8LinearConfig( 113 | cast_config_input=CastConfig(scaling_type=scaling_type_input), 114 | cast_config_weight=CastConfig(scaling_type=scaling_type_weight), 115 | cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), 116 | ) 117 | convert_to_float8_training( 118 | model_fp8, 119 | config=config, 120 | ) 121 | 122 | lr = 0.01 123 | optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr) 124 | optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr) 125 | 126 | # Note: you need two different inputs to properly test numerics 127 | # of delayed scaling, because the first time around the initialization 128 | # logic of delayed scaling behaves as dynamic scaling 129 | # TODO(future): also make unit tests do this properly 130 | shape = (1, 8192, 4096) 131 | data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) 132 | data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) 133 | 134 | model_ref(data1).sum().backward() 135 | # zero out grads without stepping, since we just want to compare grads 136 | # of the second datum 137 | optim_ref.zero_grad() 138 | model_ref_out = model_ref(data2) 139 | model_ref_out.sum().backward() 140 | 141 | if linear_requires_sync(config): 142 | sync_float8_amax_and_scale_history(model_fp8) 143 | model_fp8(data1).sum().backward() 144 | # zero out grads without stepping, since we just want to compare grads 145 | # of the second datum 146 | optim_fp8.zero_grad() 147 | if linear_requires_sync(config): 148 | 
sync_float8_amax_and_scale_history(model_fp8) 149 | model_fp8_out = model_fp8(data2) 150 | model_fp8_out.sum().backward() 151 | 152 | out_sqnr = compute_error(model_ref_out, model_fp8_out) 153 | assert out_sqnr > 20.0 154 | 155 | ref_name_to_grad = { 156 | name: param.grad for name, param in model_ref.named_parameters() 157 | } 158 | 159 | grad_sqnr_threshold = 20.0 160 | 161 | for name, param in model_fp8.named_parameters(): 162 | ref_grad = ref_name_to_grad[name] 163 | cur_grad = param.grad 164 | sqnr = compute_error(ref_grad, cur_grad) 165 | assert sqnr > grad_sqnr_threshold 166 | 167 | 168 | if __name__ == "__main__": 169 | pytest.main([__file__]) 170 | --------------------------------------------------------------------------------