├── .clang-format ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── blossom-ci.yml │ ├── build.yml │ ├── deploy_nightly_docs.yml │ ├── docs.yml │ ├── license.yml │ ├── lint.yml │ ├── trigger-ci.yml │ └── upload-ci-logs.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── Acknowledgements.txt ├── CONTRIBUTING.rst ├── CPPLINT.cfg ├── LICENSE ├── README.rst ├── SECURITY.md ├── benchmarks └── attention │ └── benchmark_attention.py ├── build_tools ├── VERSION.txt ├── __init__.py ├── build_ext.py ├── jax.py ├── pytorch.py ├── te_version.py ├── utils.py └── wheel_utils │ ├── Dockerfile.aarch │ ├── Dockerfile.x86 │ ├── build_wheels.sh │ ├── launch_aarch.sh │ └── launch_x86.sh ├── docs ├── .gitignore ├── Doxyfile ├── Makefile ├── _static │ ├── NVIDIA-LogoBlack.svg │ └── css │ │ ├── nvidia_font.css │ │ └── nvidia_footer.css ├── _templates │ ├── footer.html │ └── layout.html ├── api │ ├── c │ │ ├── activation.rst │ │ ├── cast.rst │ │ ├── cast_transpose_noop.rst │ │ ├── cudnn.rst │ │ ├── fused_attn.rst │ │ ├── fused_rope.rst │ │ ├── gemm.rst │ │ ├── index.rst │ │ ├── multi_tensor.rst │ │ ├── normalization.rst │ │ ├── padding.rst │ │ ├── permutation.rst │ │ ├── recipe.rst │ │ ├── softmax.rst │ │ ├── swizzle.rst │ │ ├── transformer_engine.rst │ │ └── transpose.rst │ ├── common.rst │ ├── framework.rst │ ├── jax.rst │ └── pytorch.rst ├── conf.py ├── debug.rst ├── debug │ ├── 1_getting_started.rst │ ├── 2_config_file_structure.rst │ ├── 3_api_debug_setup.rst │ ├── 3_api_features.rst │ ├── 3_api_te_calls.rst │ ├── 4_distributed.rst │ ├── api.rst │ └── img │ │ ├── api_calls1.svg │ │ ├── api_calls2.svg │ │ ├── fake_quant.svg │ │ ├── introduction.svg │ │ ├── names.svg │ │ ├── pipeline_logging.svg │ │ ├── reduction1.svg │ │ ├── reduction2.svg │ │ ├── reduction3.svg │ │ ├── scaling_factors.svg │ │ └── tensorboard.png ├── examples │ ├── E8M0.png │ ├── H200-NeMo-performance.png │ ├── MXFP8_FP8_comparison_1.png │ ├── MXFP8_FP8_comparison_2.png │ ├── advanced_optimizations.ipynb │ ├── attention │ │ ├── arbitrary_mask_to_post_scale_bias.py │ │ ├── attention.ipynb │ │ ├── dot_product_attention.png │ │ └── example_attention.py │ ├── comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg │ ├── delayed_scaling.png │ ├── fp8_formats.png │ ├── fp8_primer.ipynb │ ├── linear_mxfp8.png │ ├── loss_scaling.png │ ├── quickstart.ipynb │ ├── quickstart_utils.py │ ├── te_llama │ │ ├── media │ │ │ ├── llama_for_causal_lm.svg │ │ │ ├── llama_zoom.svg │ │ │ ├── llamadecoderlayer.svg │ │ │ ├── model_change.svg │ │ │ ├── swiglu.svg │ │ │ ├── swiglu_te.svg │ │ │ ├── tellamadecoderlayer.svg │ │ │ ├── transformer_llama.png │ │ │ ├── transformer_vs_llama.svg │ │ │ └── weight_swap.svg │ │ ├── te_llama.py │ │ ├── tutorial_accelerate_hf_llama_with_te.ipynb │ │ └── utils.py │ └── transformer_layer.png ├── faq.rst ├── index.rst ├── installation.rst └── version_select.patch ├── examples ├── README.md ├── jax │ ├── README.md │ ├── encoder │ │ ├── README.md │ │ ├── common.py │ │ ├── conftest.py │ │ ├── requirements.txt │ │ ├── run_test_multiprocessing_encoder.sh │ │ ├── test_model_parallel_encoder.py │ │ ├── test_multigpu_encoder.py │ │ ├── test_multiprocessing_encoder.py │ │ └── test_single_gpu_encoder.py │ └── mnist │ │ ├── README.md │ │ ├── requirements.txt │ │ └── test_single_gpu_mnist.py └── pytorch │ ├── comm_gemm_overlap │ ├── README.md │ └── te_layer_with_overlap.py │ ├── fsdp │ ├── README.md │ └── fsdp.py │ └── mnist │ ├── 
README.md │ └── main.py ├── pylintrc ├── qa ├── L0_cppunittest │ └── test.sh ├── L0_jax_distributed_unittest │ └── test.sh ├── L0_jax_lint │ └── test.sh ├── L0_jax_unittest │ └── test.sh ├── L0_jax_wheel │ └── test.sh ├── L0_license │ ├── config.json │ ├── copyright_checker.py │ └── test.sh ├── L0_pytorch_debug_unittest │ └── test.sh ├── L0_pytorch_lint │ └── test.sh ├── L0_pytorch_unittest │ └── test.sh ├── L0_pytorch_wheel │ └── test.sh ├── L1_jax_distributed_unittest │ └── test.sh ├── L1_pytorch_distributed_unittest │ └── test.sh ├── L1_pytorch_mcore_integration │ ├── .gitignore │ ├── merges.txt │ └── test.sh ├── L1_pytorch_thunder_integration │ └── test.sh ├── L2_jax_unittest │ └── test.sh ├── L3_pytorch_FA_versions_test │ └── test.sh └── format.sh ├── setup.py ├── tests ├── cpp │ ├── CMakeLists.txt │ ├── operator │ │ ├── CMakeLists.txt │ │ ├── test_act.cu │ │ ├── test_cast.cu │ │ ├── test_cast_current_scaling.cu │ │ ├── test_cast_dbias.cu │ │ ├── test_cast_dbias_dgelu.cu │ │ ├── test_cast_float8blockwise.cu │ │ ├── test_cast_gated_swiglu.cu │ │ ├── test_cast_mxfp8.cu │ │ ├── test_cast_mxfp8_gated_swiglu.cu │ │ ├── test_cast_transpose.cu │ │ ├── test_cast_transpose_current_scaling.cu │ │ ├── test_cast_transpose_dbias.cu │ │ ├── test_cast_transpose_dbias_dgelu.cu │ │ ├── test_cast_transpose_dgeglu.cu │ │ ├── test_causal_softmax.cu │ │ ├── test_dequantize_mxfp8.cu │ │ ├── test_memset.cu │ │ ├── test_multi_cast_transpose.cu │ │ ├── test_multi_padding.cu │ │ ├── test_normalization.cu │ │ ├── test_normalization.h │ │ ├── test_normalization_mxfp8.cu │ │ ├── test_qdq.cu │ │ ├── test_swizzle.cu │ │ └── test_transpose.cu │ ├── test_common.cu │ ├── test_common.h │ └── util │ │ ├── CMakeLists.txt │ │ ├── test_nvrtc.cpp │ │ └── test_string.cpp ├── jax │ ├── conftest.py │ ├── distributed_test_base.py │ ├── pytest.ini │ ├── test_custom_call_compute.py │ ├── test_distributed_fused_attn.py │ ├── test_distributed_layernorm.py │ ├── test_distributed_layernorm_mlp.py │ ├── test_distributed_softmax.py │ ├── test_functions.py │ ├── test_fused_attn.py │ ├── test_helper.py │ ├── test_layer.py │ ├── test_misc.py │ ├── test_sanity_import.py │ ├── test_sharding.py │ ├── test_softmax.py │ └── utils.py └── pytorch │ ├── debug │ ├── conftest.py │ ├── run_distributed.py │ ├── test_api_features.py │ ├── test_config.py │ ├── test_configs │ │ ├── disable_fp8_gemms.yaml │ │ ├── disable_fp8_layer.yaml │ │ ├── dummy_feature.yaml │ │ ├── fake_quantization_config.yaml │ │ ├── per_tensor_scaling.yaml │ │ ├── stats_collection_test_config.yaml │ │ └── tensor_manipulation_transformer_engine.yaml │ ├── test_distributed.py │ ├── test_numerics.py │ ├── test_sanity.py │ └── utils.py │ ├── distributed │ ├── run_cast_master_weights_to_fp8.py │ ├── run_fsdp2_model.py │ ├── run_gemm_with_overlap.py │ ├── run_layer_with_overlap.py │ ├── run_numerics.py │ ├── test_cast_master_weights_to_fp8.py │ ├── test_comm_gemm_overlap.py │ ├── test_fusible_ops.py │ ├── test_fusible_ops_with_userbuffers.py │ ├── test_numerics.py │ └── test_torch_fsdp2.py │ ├── fused_attn │ ├── run_fused_attn_with_cp.py │ ├── test_fused_attn.py │ ├── test_fused_attn_with_cp.py │ └── test_kv_cache.py │ ├── references │ ├── blockwise_fp8_gemm_reference.py │ ├── blockwise_quantizer_reference.py │ ├── quantize_scale_calc.py │ └── ref_per_tensor_cs.py │ ├── test_cpu_offloading.py │ ├── test_cuda_graphs.py │ ├── test_deferred_init.py │ ├── test_float8_blockwise_gemm_exact.py │ ├── test_float8_blockwise_scaling_exact.py │ ├── test_float8_current_scaling_exact.py │ ├── 
test_float8blockwisetensor.py │ ├── test_float8tensor.py │ ├── test_fused_optimizer.py │ ├── test_fused_rope.py │ ├── test_fusible_ops.py │ ├── test_gqa.py │ ├── test_hf_integration.py │ ├── test_jit.py │ ├── test_multi_tensor.py │ ├── test_numerics.py │ ├── test_parallel_cross_entropy.py │ ├── test_permutation.py │ ├── test_recipe.py │ ├── test_sanity.py │ ├── test_sanity_import.py │ └── utils.py └── transformer_engine ├── __init__.py ├── common ├── CMakeLists.txt ├── __init__.py ├── activation │ ├── activation_template.h │ ├── gelu.cu │ ├── relu.cu │ └── swiglu.cu ├── comm_gemm_overlap │ ├── comm_gemm_overlap.cpp │ └── userbuffers │ │ ├── ipcsocket.cc │ │ ├── ipcsocket.h │ │ ├── userbuffers-host.cpp │ │ ├── userbuffers.cu │ │ └── userbuffers.h ├── common.cu ├── common.h ├── cudnn_utils.cpp ├── cudnn_utils.h ├── fused_attn │ ├── context_parallel.cu │ ├── flash_attn.cu │ ├── fused_attn.cpp │ ├── fused_attn_f16_arbitrary_seqlen.cu │ ├── fused_attn_f16_arbitrary_seqlen.h │ ├── fused_attn_f16_max512_seqlen.cu │ ├── fused_attn_f16_max512_seqlen.h │ ├── fused_attn_fp8.cu │ ├── fused_attn_fp8.h │ ├── kv_cache.cu │ ├── utils.cu │ └── utils.h ├── fused_rope │ └── fused_rope.cu ├── fused_softmax │ ├── scaled_aligned_causal_masked_softmax.cu │ ├── scaled_masked_softmax.cu │ └── scaled_upper_triang_masked_softmax.cu ├── gemm │ └── cublaslt_gemm.cu ├── include │ └── transformer_engine │ │ ├── activation.h │ │ ├── cast.h │ │ ├── cast_transpose_noop.h │ │ ├── comm_gemm_overlap.h │ │ ├── cudnn.h │ │ ├── fused_attn.h │ │ ├── fused_rope.h │ │ ├── gemm.h │ │ ├── multi_tensor.h │ │ ├── normalization.h │ │ ├── padding.h │ │ ├── permutation.h │ │ ├── recipe.h │ │ ├── softmax.h │ │ ├── swizzle.h │ │ ├── transformer_engine.h │ │ └── transpose.h ├── libtransformer_engine.version ├── multi_tensor │ ├── adam.cu │ ├── compute_scale.cu │ ├── l2norm.cu │ ├── multi_tensor_apply.cuh │ ├── scale.cu │ └── sgd.cu ├── normalization │ ├── common.cpp │ ├── common.h │ ├── kernel_traits.h │ ├── layernorm │ │ ├── ln_api.cpp │ │ ├── ln_bwd_kernels.cuh │ │ ├── ln_bwd_semi_cuda_kernel.cu │ │ ├── ln_fwd_cuda_kernel.cu │ │ └── ln_fwd_kernels.cuh │ └── rmsnorm │ │ ├── rmsnorm_api.cpp │ │ ├── rmsnorm_bwd_kernels.cuh │ │ ├── rmsnorm_bwd_semi_cuda_kernel.cu │ │ ├── rmsnorm_fwd_cuda_kernel.cu │ │ └── rmsnorm_fwd_kernels.cuh ├── nvshmem_api │ ├── CMakeLists.txt │ ├── nvshmem_waitkernel.cu │ └── nvshmem_waitkernel.h ├── nvtx.h ├── permutation │ └── permutation.cu ├── recipe │ ├── __init__.py │ ├── current_scaling.cu │ ├── delayed_scaling.cu │ ├── fp8_block_scaling.cu │ └── recipe_common.cuh ├── swizzle │ └── swizzle.cu ├── transformer_engine.cpp ├── transpose │ ├── cast_transpose.cu │ ├── cast_transpose.h │ ├── cast_transpose_fusion.cu │ ├── multi_cast_transpose.cu │ ├── quantize_transpose_square_blockwise.cu │ ├── quantize_transpose_vector_blockwise.cu │ ├── rtc │ │ ├── cast_transpose.cu │ │ ├── cast_transpose_fusion.cu │ │ └── transpose.cu │ ├── transpose.cu │ └── transpose_fusion.cu ├── util │ ├── cast.cu │ ├── cast_gated_kernels.cuh │ ├── cast_kernels.cuh │ ├── cuda_driver.cpp │ ├── cuda_driver.h │ ├── cuda_nvml.cpp │ ├── cuda_nvml.h │ ├── cuda_runtime.cpp │ ├── cuda_runtime.h │ ├── dequantize_kernels.cuh │ ├── handle_manager.h │ ├── logging.h │ ├── math.h │ ├── padding.cu │ ├── ptx.cuh │ ├── pybind_helper.h │ ├── rtc.cpp │ ├── rtc.h │ ├── shared_lib_wrapper.h │ ├── string.h │ ├── string_header.h.in │ ├── system.h │ └── vectorized_pointwise.h ├── utils.cuh └── utils.py ├── debug ├── __init__.py ├── features │ ├── __init__.py │ ├── 
_test_dummy_feature.py │ ├── api.py │ ├── disable_fp8_gemm.py │ ├── disable_fp8_layer.py │ ├── fake_quant.py │ ├── log_fp8_tensor_stats.py │ ├── log_tensor_stats.py │ ├── per_tensor_scaling.py │ └── utils │ │ ├── __init__.py │ │ ├── stats_buffer.py │ │ └── stats_computation.py └── pytorch │ ├── __init__.py │ ├── debug_quantization.py │ ├── debug_state.py │ └── utils.py ├── jax ├── MANIFEST.in ├── __init__.py ├── activation.py ├── attention.py ├── cpp_extensions │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── base.py │ ├── gemm.py │ ├── misc.py │ ├── normalization.py │ ├── quantization.py │ └── softmax.py ├── csrc │ ├── extensions.h │ └── extensions │ │ ├── activation.cpp │ │ ├── attention.cpp │ │ ├── cublas.cpp │ │ ├── cudnn.cpp │ │ ├── ffi.cpp │ │ ├── ffi.h │ │ ├── gemm.cpp │ │ ├── misc.cpp │ │ ├── misc.h │ │ ├── normalization.cpp │ │ ├── pybind.cpp │ │ ├── quantization.cpp │ │ ├── softmax.cpp │ │ ├── utils.cpp │ │ └── utils.h ├── dense.py ├── flax │ ├── __init__.py │ ├── module.py │ └── transformer.py ├── layernorm.py ├── layernorm_dense.py ├── layernorm_mlp.py ├── quantize │ ├── __init__.py │ ├── dequantizer.py │ ├── helper.py │ ├── metadata.py │ ├── quantizer.py │ ├── scaling_modes.py │ └── tensor.py ├── setup.py ├── sharding.py └── softmax.py └── pytorch ├── MANIFEST.in ├── __init__.py ├── attention ├── __init__.py ├── dot_product_attention │ ├── __init__.py │ ├── backends.py │ ├── context_parallel.py │ ├── dot_product_attention.py │ ├── softmax.py │ └── utils.py ├── inference.py ├── multi_head_attention.py └── rope.py ├── constants.py ├── cpp_extensions ├── __init__.py ├── fused_attn.py └── gemm.py ├── cpu_offload.py ├── cross_entropy.py ├── csrc ├── common.cpp ├── common.h ├── extensions.h ├── extensions │ ├── activation.cpp │ ├── apply_rope.cpp │ ├── attention.cpp │ ├── bias.cpp │ ├── cast.cpp │ ├── comm_gemm_overlap.cpp │ ├── fp8_block_scaling_partial_cast.cpp │ ├── gemm.cpp │ ├── misc.cpp │ ├── multi_tensor │ │ ├── adam.cpp │ │ ├── compute_scale.cpp │ │ ├── l2norm.cpp │ │ ├── scale.cpp │ │ └── sgd.cpp │ ├── normalization.cpp │ ├── nvshmem_comm.cpp │ ├── padding.cpp │ ├── permutation.cpp │ ├── pybind.cpp │ ├── recipe.cpp │ ├── softmax.cpp │ └── transpose.cpp ├── pybind.h ├── quantizer.cpp ├── type_converters.cpp ├── util.cpp └── util.h ├── distributed.py ├── float8_tensor.py ├── fp8.py ├── graph.py ├── jit.py ├── module ├── __init__.py ├── _common.py ├── base.py ├── fp8_padding.py ├── fp8_unpadding.py ├── grouped_linear.py ├── layernorm.py ├── layernorm_linear.py ├── layernorm_mlp.py ├── linear.py └── rmsnorm.py ├── numerics_debug.py ├── ops ├── __init__.py ├── _common.py ├── basic │ ├── __init__.py │ ├── activation.py │ ├── add_in_place.py │ ├── all_gather.py │ ├── all_reduce.py │ ├── basic_linear.py │ ├── bias.py │ ├── identity.py │ ├── layer_norm.py │ ├── make_extra_output.py │ ├── quantize.py │ ├── reduce_scatter.py │ ├── reshape.py │ └── rmsnorm.py ├── fused │ ├── __init__.py │ ├── backward_linear_add.py │ ├── forward_linear_bias_activation.py │ ├── forward_linear_bias_add.py │ ├── userbuffers_backward_linear.py │ └── userbuffers_forward_linear.py ├── fuser.py ├── linear.py ├── op.py └── sequential.py ├── optimizers ├── __init__.py ├── fused_adam.py ├── fused_sgd.py └── multi_tensor_apply.py ├── permutation.py ├── setup.py ├── tensor ├── __init__.py ├── _internal │ ├── __init__.py │ ├── float8_blockwise_tensor_base.py │ ├── float8_tensor_base.py │ └── mxfp8_tensor_base.py ├── float8_blockwise_tensor.py ├── float8_tensor.py ├── mxfp8_tensor.py ├── 
quantized_tensor.py └── utils.py ├── transformer.py ├── triton ├── __init__.py ├── cross_entropy.py └── permutation.py └── utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | 12 | A clear and concise description of what the bug is. 13 | 14 | **Steps/Code to reproduce bug** 15 | 16 | Please list the *minimal* steps or a code snippet for us to be able to reproduce the bug. 17 | 18 | A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports. 19 | 20 | 21 | **Expected behavior** 22 | 23 | A clear and concise description of what you expected to happen. 24 | 25 | **Environment overview (please complete the following information)** 26 | 27 | - Environment location: [Bare-metal, Docker, Cloud (specify cloud provider - AWS, Azure, GCP, Colab)] 28 | - Method of Transformer Engine install: [pip install or from source]. Please specify the exact commands you used to install. 29 | - If the method of install is [Docker], provide the `docker pull` & `docker run` commands used 30 | 31 | **Environment details** 32 | 33 | If an NVIDIA Docker image is used, you don't need to specify these. 34 | Otherwise, please provide: 35 | - OS version 36 | - PyTorch version 37 | - Python version 38 | - Transformer Engine version 39 | - CUDA version 40 | - cuDNN version 41 | 42 | **Device details** 43 | - GPU model 44 | 45 | **Additional context** 46 | 47 | Add any other context about the problem here. 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | 12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 13 | 14 | **Describe the solution you'd like** 15 | 16 | A clear and concise description of what you want to happen. 17 | Provide a code snippet showing how the new APIs/changes would be used by others. 18 | 19 | **Describe alternatives you've considered** 20 | 21 | A clear and concise description of any alternative solutions or features you've considered. 22 | 23 | **Additional context** 24 | 25 | Add any other context or screenshots about the feature request here. 26 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Please include a brief summary of the changes, relevant motivation and context. 
4 | 5 | Fixes # (issue) 6 | 7 | ## Type of change 8 | 9 | - [ ] Documentation change (a change only to the documentation, either a fix or new content) 10 | - [ ] Bug fix (non-breaking change which fixes an issue) 11 | - [ ] New feature (non-breaking change which adds functionality) 12 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 13 | - [ ] Infra/Build change 14 | - [ ] Code refactoring 15 | 16 | ## Changes 17 | 18 | Please list the changes introduced in this PR: 19 | 20 | - Change A 21 | - Change B 22 | 23 | # Checklist: 24 | 25 | - [ ] I have read and followed the [contributing guidelines](https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst) 26 | - [ ] The functionality is complete 27 | - [ ] I have commented my code, particularly in hard-to-understand areas 28 | - [ ] I have made corresponding changes to the documentation 29 | - [ ] My changes generate no new warnings 30 | - [ ] I have added tests that prove my fix is effective or that my feature works 31 | - [ ] New and existing unit tests pass locally with my changes 32 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # A workflow to trigger the TE build on GitHub 6 | name: 'Build' 7 | on: 8 | pull_request: 9 | workflow_dispatch: 10 | jobs: 11 | core: 12 | name: 'Core' 13 | runs-on: ubuntu-latest 14 | container: 15 | image: nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 16 | options: --user root 17 | steps: 18 | - name: 'Dependencies' 19 | run: | 20 | apt-get update 21 | apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 22 | pip install cmake==3.21.0 23 | - name: 'Checkout' 24 | uses: actions/checkout@v3 25 | with: 26 | submodules: recursive 27 | - name: 'Build' 28 | run: pip install --no-build-isolation . -v 29 | env: 30 | NVTE_FRAMEWORK: none 31 | MAX_JOBS: 1 32 | - name: 'Sanity check' 33 | run: python3 -c "import transformer_engine" 34 | working-directory: / 35 | pytorch: 36 | name: 'PyTorch' 37 | runs-on: ubuntu-latest 38 | container: 39 | image: nvcr.io/nvidia/cuda:12.8.0-devel-ubuntu22.04 40 | options: --user root 41 | steps: 42 | - name: 'Dependencies' 43 | run: | 44 | apt-get update 45 | apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12 46 | pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 47 | - name: 'Checkout' 48 | uses: actions/checkout@v3 49 | with: 50 | submodules: recursive 51 | - name: 'Build' 52 | run: pip install --no-build-isolation . -v --no-deps 53 | env: 54 | NVTE_FRAMEWORK: pytorch 55 | MAX_JOBS: 1 56 | - name: 'Sanity check' 57 | if: false # Sanity import test requires Flash Attention 58 | run: python3 tests/pytorch/test_sanity_import.py 59 | jax: 60 | name: 'JAX' 61 | runs-on: ubuntu-latest 62 | container: 63 | image: ghcr.io/nvidia/jax:jax 64 | options: --user root 65 | steps: 66 | - name: 'Checkout' 67 | uses: actions/checkout@v3 68 | with: 69 | submodules: recursive 70 | - name: 'Build' 71 | run: pip install --no-build-isolation . 
-v 72 | env: 73 | NVTE_FRAMEWORK: jax 74 | MAX_JOBS: 1 75 | - name: 'Sanity check' 76 | run: python tests/jax/test_sanity_import.py 77 | -------------------------------------------------------------------------------- /.github/workflows/deploy_nightly_docs.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # A workflow to deploy the nightly version of TE documentation to GitHub Pages 6 | name: Deploy nightly docs 7 | on: 8 | push: 9 | branches: [ "main" ] 10 | jobs: 11 | build: 12 | uses: ./.github/workflows/docs.yml 13 | 14 | prepare: 15 | needs: build 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Download artifact 19 | uses: actions/download-artifact@v4 20 | with: 21 | name: "te_docs" 22 | path: "html" 23 | - name: Prepare for pages 24 | uses: actions/upload-pages-artifact@v1.0.7 25 | with: 26 | name: github-pages 27 | path: "html" 28 | deploy: 29 | needs: prepare 30 | environment: 31 | name: github-pages 32 | url: ${{ steps.deployment.outputs.page_url }} 33 | permissions: 34 | pages: write 35 | id-token: write 36 | runs-on: ubuntu-latest 37 | steps: 38 | - name: Deploy 39 | uses: actions/deploy-pages@v2.0.0 40 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # A workflow to trigger the build of TE documentation on GitHub 6 | name: 'Documentation' 7 | on: 8 | pull_request: 9 | workflow_dispatch: 10 | workflow_call: 11 | jobs: 12 | build_docs: 13 | name: 'Build' 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: 'Checkout' 17 | uses: actions/checkout@v3 18 | - name: 'Install dependencies' 19 | run: | 20 | pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 21 | pip install breathe==4.35.0 sphinx-autoapi==3.3.2 22 | sudo apt-get install -y pandoc graphviz doxygen 23 | export GIT_SHA=$(git show-ref --hash HEAD) 24 | - name: 'Build docs' 25 | run: | 26 | doxygen docs/Doxyfile 27 | cd docs 28 | make html 29 | - name: 'Upload docs' 30 | uses: actions/upload-artifact@v4 31 | with: 32 | name: te_docs 33 | path: docs/_build/html 34 | retention-days: 7 35 | -------------------------------------------------------------------------------- /.github/workflows/license.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # A workflow to trigger the TE license check on GitHub 6 | name: 'License' 7 | on: 8 | pull_request: 9 | workflow_dispatch: 10 | jobs: 11 | check: 12 | name: 'Check' 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: 'Checkout' 16 | uses: actions/checkout@v3 17 | - name: 'Check License' 18 | run: | 19 | export TE_PATH=. 20 | bash ./qa/L0_license/test.sh 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # See LICENSE for license information. 4 | 5 | # A workflow to trigger lint tests on GitHub 6 | name: 'Lint' 7 | on: 8 | pull_request: 9 | workflow_dispatch: 10 | jobs: 11 | pytorch_cpplint: 12 | name: 'PyTorch C++' 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v3 17 | - name: 'Lint' 18 | run: | 19 | sudo apt-get update 20 | sudo apt-get install pip -y 21 | export CPP_ONLY=1 22 | export TE_PATH=. 23 | bash ./qa/L0_pytorch_lint/test.sh 24 | pytorch_pylint: 25 | name: 'PyTorch Python' 26 | runs-on: ubuntu-latest 27 | steps: 28 | - name: 'Checkout' 29 | uses: actions/checkout@v3 30 | - name: 'Lint' 31 | run: | 32 | sudo apt-get update 33 | sudo apt-get install pip -y 34 | pip install torch numpy 35 | export PYTHON_ONLY=1 36 | export TE_PATH=. 37 | bash ./qa/L0_pytorch_lint/test.sh 38 | jax_cpplint: 39 | name: 'JAX C++' 40 | runs-on: ubuntu-latest 41 | steps: 42 | - name: 'Checkout' 43 | uses: actions/checkout@v3 44 | - name: 'Lint' 45 | run: | 46 | sudo apt-get update 47 | sudo apt-get install pip -y 48 | export CPP_ONLY=1 49 | export TE_PATH=. 50 | bash ./qa/L0_jax_lint/test.sh 51 | jax_pylint: 52 | name: 'JAX Python' 53 | runs-on: ubuntu-latest 54 | steps: 55 | - name: 'Checkout' 56 | uses: actions/checkout@v3 57 | - name: 'Lint' 58 | run: | 59 | sudo apt-get update 60 | sudo apt-get install pip -y 61 | export PYTHON_ONLY=1 62 | export TE_PATH=. 63 | bash ./qa/L0_jax_lint/test.sh 64 | -------------------------------------------------------------------------------- /.github/workflows/upload-ci-logs.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # A workflow to trigger ci on hybrid infra (github + self hosted runner) 6 | name: TE-CI Logs 7 | on: 8 | workflow_dispatch: 9 | inputs: 10 | platform: 11 | description: 'runs-on argument' 12 | required: false 13 | args: 14 | description: 'argument' 15 | required: false 16 | job_name: 17 | description: 'name of the job' 18 | required: true 19 | commit_sha: 20 | description: 'SHA of the commit that was tested.' 
21 | required: true 22 | result: 23 | description: 'Job result' 24 | required: true 25 | run-name: PR ${{ fromJson(github.event.inputs.args).pr }} - ${{ inputs.job_name }} 26 | jobs: 27 | Upload-Log: 28 | name: Upload log 29 | runs-on: blossom 30 | steps: 31 | - name: Log 32 | run: blossom-ci 33 | env: 34 | OPERATION: 'POST-PROCESSING' 35 | CI_SERVER: ${{ secrets.CI_SERVER }} 36 | REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} 37 | status_update: 38 | name: Update commit status 39 | runs-on: ubuntu-latest 40 | permissions: 41 | statuses: write 42 | needs: [Upload-Log] 43 | if: ${{ always() }} 44 | steps: 45 | - name: Set status 46 | run: | 47 | curl \ 48 | -X POST \ 49 | -H "Accept: application/vnd.github+json" \ 50 | -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ 51 | https://api.github.com/repos/${{ github.repository }}/statuses/${{ inputs.commit_sha }} \ 52 | -d "{\"state\":\"${{ inputs.result }}\",\"target_url\":\"${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}\",\"description\":\"\",\"context\":\"te-ci/${{ inputs.job_name }}\"}" 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.swp 3 | *.ii 4 | *.ptx 5 | *.cubin 6 | *.fatbin* 7 | *.module_id 8 | *.nsys-rep 9 | *.ncu-rep 10 | *.sqlite 11 | *.eggs 12 | build/ 13 | *.so 14 | *.egg-info 15 | __pycache__ 16 | .ycm_extra_conf.py 17 | .vimrc 18 | .vs 19 | .vscode 20 | .cache 21 | .hypothesis 22 | .devcontainer.json 23 | tests/cpp/build/ 24 | .ipynb_checkpoints 25 | *.log 26 | CMakeFiles/CMakeSystem.cmake 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | .pytest_cache/ 39 | compile_commands.json 40 | .nfs 41 | tensor_dumps/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/googletest"] 2 | path = 3rdparty/googletest 3 | url = https://github.com/google/googletest.git 4 | [submodule "3rdparty/cudnn-frontend"] 5 | path = 3rdparty/cudnn-frontend 6 | url = https://github.com/NVIDIA/cudnn-frontend.git 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | default_language_version: 6 | python: python3 7 | 8 | ci: 9 | autofix_prs: true 10 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 11 | autoupdate_schedule: quarterly 12 | submodules: false 13 | skip: [] 14 | 15 | repos: 16 | - repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v4.6.0 18 | hooks: 19 | - id: check-merge-conflict 20 | - id: check-added-large-files 21 | - id: end-of-file-fixer 22 | files: .*.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$ 23 | - id: trailing-whitespace 24 | files: .*.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$ 25 | 26 | - repo: https://github.com/psf/black 27 | rev: 24.4.2 28 | hooks: 29 | - id: black 30 | name: Format python code 31 | args: [--line-length=100, --preview, --enable-unstable-feature=string_processing] 32 | types: [python] 33 | 34 | - repo: https://github.com/pre-commit/mirrors-clang-format 35 | rev: v18.1.6 36 | hooks: 37 | - id: clang-format 38 | entry: clang-format -i 39 | args: ["-style=file"] 40 | files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$ 41 | -------------------------------------------------------------------------------- /CPPLINT.cfg: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # Stop searching for additional config files. 6 | set noparent 7 | 8 | # Limit line length. 9 | linelength=100 10 | 11 | # Ignore the following errors. 12 | filter=-build/include_subdir 13 | filter=-build/namespaces 14 | filter=-readability/todo 15 | filter=-build/header_guard 16 | filter=-build/include 17 | filter=-build/c++11 18 | filter=-runtime/references 19 | filter=-whitespace 20 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. 4 | 5 | If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub/GitLab.** 6 | 7 | ## Reporting Potential Security Vulnerability in an NVIDIA Product 8 | 9 | To report a potential security vulnerability in any NVIDIA product: 10 | - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) 11 | - E-Mail: psirt@nvidia.com 12 | - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) 13 | - Please include the following information: 14 | - Product/Driver name and version/branch that contains the vulnerability 15 | - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) 16 | - Instructions to reproduce the vulnerability 17 | - Proof-of-concept or exploit code 18 | - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability 19 | 20 | While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. 
21 | 22 | ## NVIDIA Product Security 23 | 24 | For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security 25 | -------------------------------------------------------------------------------- /build_tools/VERSION.txt: -------------------------------------------------------------------------------- 1 | 2.5.0.dev0 2 | -------------------------------------------------------------------------------- /build_tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Build related infrastructure.""" 6 | -------------------------------------------------------------------------------- /build_tools/jax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """JAX related extensions.""" 6 | import os 7 | import shutil 8 | from pathlib import Path 9 | 10 | import setuptools 11 | 12 | from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled 13 | from typing import List 14 | 15 | 16 | def xla_path() -> str: 17 | """XLA root path lookup. 18 | Throws FileNotFoundError if XLA source is not found.""" 19 | 20 | try: 21 | from jax.extend import ffi 22 | except ImportError: 23 | if os.getenv("XLA_HOME"): 24 | xla_home = Path(os.getenv("XLA_HOME")) 25 | else: 26 | xla_home = "/opt/xla" 27 | else: 28 | xla_home = ffi.include_dir() 29 | 30 | if not os.path.isdir(xla_home): 31 | raise FileNotFoundError("Could not find xla source.") 32 | return xla_home 33 | 34 | 35 | def setup_jax_extension( 36 | csrc_source_files, 37 | csrc_header_files, 38 | common_header_files, 39 | ) -> setuptools.Extension: 40 | """Setup PyBind11 extension for JAX support""" 41 | # Source files 42 | csrc_source_files = Path(csrc_source_files) 43 | extensions_dir = csrc_source_files / "extensions" 44 | sources = all_files_in_dir(extensions_dir, name_extension="cpp") 45 | 46 | # Header files 47 | include_dirs = get_cuda_include_dirs() 48 | include_dirs.extend( 49 | [ 50 | common_header_files, 51 | common_header_files / "common", 52 | common_header_files / "common" / "include", 53 | csrc_header_files, 54 | xla_path(), 55 | ] 56 | ) 57 | 58 | # Compile flags 59 | cxx_flags = ["-O3"] 60 | if debug_build_enabled(): 61 | cxx_flags.append("-g") 62 | cxx_flags.append("-UNDEBUG") 63 | else: 64 | cxx_flags.append("-g0") 65 | 66 | # Define TE/JAX as a Pybind11Extension 67 | from pybind11.setup_helpers import Pybind11Extension 68 | 69 | class Pybind11CPPExtension(Pybind11Extension): 70 | """Modified Pybind11Extension to allow custom CXX flags.""" 71 | 72 | def _add_cflags(self, flags: List[str]) -> None: 73 | if isinstance(self.extra_compile_args, dict): 74 | cxx_flags = self.extra_compile_args.pop("cxx", []) 75 | cxx_flags += flags 76 | self.extra_compile_args["cxx"] = cxx_flags 77 | else: 78 | self.extra_compile_args[:0] = flags 79 | 80 | return Pybind11CPPExtension( 81 | "transformer_engine_jax", 82 | sources=[str(path) for path in sources], 83 | include_dirs=[str(path) for path in include_dirs], 84 | extra_compile_args={"cxx": cxx_flags}, 85 | ) 86 | -------------------------------------------------------------------------------- /build_tools/te_version.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Transformer Engine version string.""" 6 | import os 7 | from pathlib import Path 8 | import subprocess 9 | 10 | 11 | def te_version() -> str: 12 | """Transformer Engine version string 13 | 14 | Includes Git commit as local version, unless suppressed with 15 | NVTE_NO_LOCAL_VERSION environment variable. 16 | 17 | """ 18 | root_path = Path(__file__).resolve().parent 19 | with open(root_path / "VERSION.txt", "r") as f: 20 | version = f.readline().strip() 21 | if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool( 22 | int(os.getenv("NVTE_RELEASE_BUILD", "0")) 23 | ): 24 | try: 25 | output = subprocess.run( 26 | ["git", "rev-parse", "--short", "HEAD"], 27 | capture_output=True, 28 | cwd=root_path, 29 | check=True, 30 | universal_newlines=True, 31 | ) 32 | except (subprocess.CalledProcessError, OSError): 33 | pass 34 | else: 35 | commit = output.stdout.strip() 36 | version += f"+{commit}" 37 | return version 38 | -------------------------------------------------------------------------------- /build_tools/wheel_utils/Dockerfile.aarch: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | FROM quay.io/pypa/manylinux_2_28_aarch64 6 | 7 | WORKDIR /TransformerEngine/ 8 | COPY ../.. /TransformerEngine/ 9 | 10 | ARG VER="12-3" 11 | ARG ARCH="aarch64" 12 | RUN dnf -y install vim 13 | 14 | # Cuda toolkit, cudnn, driver. 15 | RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo 16 | RUN dnf -y install epel-release 17 | RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ 18 | cuda-libraries-${VER}.${ARCH} \ 19 | cuda-libraries-devel-${VER}.${ARCH} 20 | RUN dnf -y install --allowerasing cudnn9-cuda-12 21 | RUN dnf clean all 22 | RUN rm -rf /var/cache/dnf/* 23 | RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf 24 | RUN dnf -y install cuda-toolkit 25 | RUN dnf clean all 26 | RUN dnf -y install glog.aarch64 glog-devel.aarch64 27 | 28 | ENV PATH="/usr/local/cuda/bin:${PATH}" 29 | ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" 30 | ENV CUDA_HOME=/usr/local/cuda 31 | ENV CUDA_ROOT=/usr/local/cuda 32 | ENV CUDA_PATH=/usr/local/cuda 33 | ENV CUDADIR=/usr/local/cuda 34 | ENV NVTE_RELEASE_BUILD=1 35 | 36 | CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] 37 | -------------------------------------------------------------------------------- /build_tools/wheel_utils/Dockerfile.x86: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | FROM quay.io/pypa/manylinux_2_28_x86_64 6 | 7 | WORKDIR /TransformerEngine/ 8 | COPY ../.. /TransformerEngine/ 9 | 10 | ARG VER="12-3" 11 | ARG ARCH="x86_64" 12 | RUN dnf -y install vim 13 | 14 | # Cuda toolkit, cudnn, driver. 
15 | RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo 16 | RUN dnf -y install epel-release 17 | RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ 18 | cuda-libraries-${VER}.${ARCH} \ 19 | cuda-libraries-devel-${VER}.${ARCH} 20 | RUN dnf -y install --allowerasing cudnn9-cuda-12 21 | RUN dnf clean all 22 | RUN rm -rf /var/cache/dnf/* 23 | RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf 24 | RUN dnf -y install cuda-toolkit 25 | RUN dnf clean all 26 | RUN dnf -y install glog.x86_64 glog-devel.x86_64 27 | 28 | ENV PATH="/usr/local/cuda/bin:${PATH}" 29 | ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" 30 | ENV CUDA_HOME=/usr/local/cuda 31 | ENV CUDA_ROOT=/usr/local/cuda 32 | ENV CUDA_PATH=/usr/local/cuda 33 | ENV CUDADIR=/usr/local/cuda 34 | ENV NVTE_RELEASE_BUILD=1 35 | 36 | CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] 37 | -------------------------------------------------------------------------------- /build_tools/wheel_utils/launch_aarch.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . 6 | docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" 7 | rm -rf aarch_wheelhouse 8 | docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse 9 | -------------------------------------------------------------------------------- /build_tools/wheel_utils/launch_x86.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . 6 | docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" 7 | rm -rf x86_wheelhouse 8 | docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse 9 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | doxygen 3 | sphinx_rtd_theme -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile sphinx_rtd_theme 20 | PYTHONPATH=sphinx_rtd_theme:$(PYTHONPATH) $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | # Patch Sphinx RTD theme 3.0.1 to add version selector in sidebar 23 | sphinx_rtd_theme: 24 | git clone --depth=1 -b 3.0.1 --single-branch https://github.com/readthedocs/sphinx_rtd_theme.git 25 | bash -c "cd sphinx_rtd_theme; git apply ../version_select.patch" 26 | -------------------------------------------------------------------------------- /docs/_static/NVIDIA-LogoBlack.svg: -------------------------------------------------------------------------------- 1 | NVIDIA-LogoBlack -------------------------------------------------------------------------------- /docs/_static/css/nvidia_footer.css: -------------------------------------------------------------------------------- 1 | footer img { 2 | display: block; 3 | width: 137.5px; 4 | position: relative; 5 | left: -9px; 6 | margin: 0 0 15px 0; 7 | } 8 | 9 | footer p { 10 | color: #666666; 11 | font-weight: normal; 12 | font-size: 12px; 13 | line-height: 1.25em; 14 | } 15 | 16 | footer p:not(.notices) { 17 | display: inline; 18 | margin: 0; 19 | } 20 | 21 | footer p a, 22 | footer p a:link, 23 | footer p a:visited { 24 | color: #666666; 25 | } 26 | 27 | footer p a:hover { 28 | color: #666666; 29 | } 30 | -------------------------------------------------------------------------------- /docs/_templates/footer.html: -------------------------------------------------------------------------------- 1 | {% extends '!footer.html' %} 2 | 3 | {% block contentinfo %} 4 | 5 |

6 | Privacy Policy 7 | | 8 | Manage My Privacy 9 | | 10 | Do Not Sell or Share My Data 11 | | 12 | Terms of Service 13 | | 14 | Accessibility 15 | | 16 | Corporate Policies 17 | | 18 | Product Security 19 | | 20 | Contact 21 |

22 | {{ super() }} 23 | {% endblock %} 24 | -------------------------------------------------------------------------------- /docs/api/c/activation.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | activation.h 7 | ============ 8 | 9 | .. doxygenfile:: activation.h 10 | -------------------------------------------------------------------------------- /docs/api/c/cast.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | cast.h 7 | ====== 8 | 9 | .. doxygenfile:: cast.h 10 | -------------------------------------------------------------------------------- /docs/api/c/cast_transpose_noop.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | cast_transpose_noop.h 7 | ===================== 8 | 9 | .. doxygenfile:: cast_transpose_noop.h 10 | -------------------------------------------------------------------------------- /docs/api/c/cudnn.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | cudnn.h 7 | ======= 8 | 9 | .. doxygenfile:: cudnn.h 10 | -------------------------------------------------------------------------------- /docs/api/c/fused_attn.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | fused_attn.h 7 | ============ 8 | 9 | .. doxygenfile:: fused_attn.h 10 | -------------------------------------------------------------------------------- /docs/api/c/fused_rope.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | fused_rope.h 7 | ============ 8 | 9 | .. doxygenfile:: fused_rope.h 10 | 11 | -------------------------------------------------------------------------------- /docs/api/c/gemm.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | gemm.h 7 | ====== 8 | 9 | .. doxygenfile:: gemm.h 10 | -------------------------------------------------------------------------------- /docs/api/c/index.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | C/C++ API 7 | ========= 8 | 9 | The C/C++ API allows you to access the custom kernels defined in `libtransformer_engine.so` library 10 | directly from C/C++, without Python. 11 | 12 | .. 
toctree:: 13 | :caption: Headers 14 | 15 | transformer_engine.h 16 | activation.h 17 | cast_transpose_noop.h 18 | cast.h 19 | cudnn.h 20 | fused_attn.h 21 | fused_rope.h 22 | gemm.h 23 | multi_tensor.h 24 | normalization.h 25 | padding.h 26 | permutation.h 27 | recipe.h 28 | softmax.h 29 | swizzle.h 30 | transpose.h 31 | -------------------------------------------------------------------------------- /docs/api/c/multi_tensor.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | multi_tensor.h 7 | ============== 8 | 9 | .. doxygenfile:: multi_tensor.h 10 | -------------------------------------------------------------------------------- /docs/api/c/normalization.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | normalization.h 7 | =============== 8 | 9 | .. doxygenfile:: normalization.h 10 | -------------------------------------------------------------------------------- /docs/api/c/padding.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | padding.h 7 | ========= 8 | 9 | .. doxygenfile:: padding.h 10 | 11 | -------------------------------------------------------------------------------- /docs/api/c/permutation.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | permutation.h 7 | ============= 8 | 9 | .. doxygenfile:: permutation.h 10 | 11 | -------------------------------------------------------------------------------- /docs/api/c/recipe.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | recipe.h 7 | ======== 8 | 9 | .. doxygenfile:: recipe.h 10 | 11 | -------------------------------------------------------------------------------- /docs/api/c/softmax.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | softmax.h 7 | ========= 8 | 9 | .. doxygenfile:: softmax.h 10 | -------------------------------------------------------------------------------- /docs/api/c/swizzle.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | swizzle.h 7 | ========= 8 | 9 | .. doxygenfile:: swizzle.h 10 | 11 | -------------------------------------------------------------------------------- /docs/api/c/transformer_engine.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | transformer_engine.h 7 | ==================== 8 | 9 | .. 
doxygenfile:: transformer_engine.h 10 | -------------------------------------------------------------------------------- /docs/api/c/transpose.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | transpose.h 7 | =========== 8 | 9 | .. doxygenfile:: transpose.h 10 | -------------------------------------------------------------------------------- /docs/api/common.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | Common API 7 | ========== 8 | 9 | .. autoapiclass:: transformer_engine.common.recipe.Format 10 | 11 | .. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None) 12 | 13 | .. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) 14 | 15 | .. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID) 16 | 17 | .. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3) 18 | -------------------------------------------------------------------------------- /docs/api/framework.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | Framework-specific API 7 | ====================== 8 | 9 | .. toctree:: 10 | 11 | pytorch 12 | jax 13 | -------------------------------------------------------------------------------- /docs/api/jax.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | Jax 7 | ======= 8 | 9 | Pre-defined Variable of Logical Axes 10 | ------------------------------------ 11 | Variables are available in `transformer_engine.jax.sharding`. 12 | 13 | * BATCH_AXES: The logical axis of batch dimension. It is usually sharded along DP + FSDP on Mesh. 14 | * SEQLEN_AXES: The logical axis of sequence length dimension. It is usually not sharded. 15 | * SEQLEN_TP_AXES: The logical axis of sequence length dimension. It is usually sharded along TP on Mesh. 16 | * HEAD_AXES: The logical axis of head dimension of MHA. It is usually sharded along TP on Mesh. 17 | * HIDDEN_AXES: The logical axis of hidden dimension. It is usually not sharded. 18 | * HIDDEN_TP_AXES: The logical axis of hidden dimension. It is usually sharded along TP on Mesh. 19 | * JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded. 20 | 21 | 22 | Modules 23 | ------------------------------------ 24 | .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType 25 | .. autoapiclass:: transformer_engine.jax.MeshResource() 26 | 27 | 28 | .. autoapifunction:: transformer_engine.jax.fp8_autocast 29 | .. autoapifunction:: transformer_engine.jax.update_collections 30 | 31 | 32 | .. autoapiclass:: transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs) 33 | :members: __call__ 34 | 35 | .. 
autoapiclass:: transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs) 36 | :members: __call__ 37 | 38 | .. autoapiclass:: transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs) 39 | :members: __call__ 40 | 41 | .. autoapiclass:: transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs) 42 | :members: __call__ 43 | 44 | .. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs) 45 | :members: __call__ 46 | 47 | .. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs) 48 | :members: __call__ 49 | 50 | .. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs) 51 | :members: __call__ 52 | 53 | .. autoapiclass:: transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs) 54 | :members: __call__ 55 | 56 | .. autoapifunction:: transformer_engine.jax.flax.extend_logical_axis_rules 57 | -------------------------------------------------------------------------------- /docs/debug.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | Precision debug tools 6 | ============================================== 7 | 8 | .. toctree:: 9 | :caption: Precision debug tools 10 | 11 | debug/1_getting_started.rst 12 | debug/2_config_file_structure.rst 13 | debug/api 14 | debug/4_distributed.rst -------------------------------------------------------------------------------- /docs/debug/3_api_features.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | Debug features 7 | ========== 8 | 9 | .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats 10 | .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats 11 | .. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM 12 | .. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer 13 | .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling 14 | .. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant 15 | -------------------------------------------------------------------------------- /docs/debug/api.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | API 6 | ============ 7 | 8 | .. 
toctree:: 9 | :caption: Precision debug tools API 10 | 11 | 3_api_debug_setup.rst 12 | 3_api_features.rst 13 | 3_api_te_calls.rst -------------------------------------------------------------------------------- /docs/debug/img/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/debug/img/tensorboard.png -------------------------------------------------------------------------------- /docs/examples/E8M0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/E8M0.png -------------------------------------------------------------------------------- /docs/examples/H200-NeMo-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/H200-NeMo-performance.png -------------------------------------------------------------------------------- /docs/examples/MXFP8_FP8_comparison_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/MXFP8_FP8_comparison_1.png -------------------------------------------------------------------------------- /docs/examples/MXFP8_FP8_comparison_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/MXFP8_FP8_comparison_2.png -------------------------------------------------------------------------------- /docs/examples/attention/dot_product_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/attention/dot_product_attention.png -------------------------------------------------------------------------------- /docs/examples/comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg -------------------------------------------------------------------------------- /docs/examples/delayed_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/delayed_scaling.png -------------------------------------------------------------------------------- /docs/examples/fp8_formats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/fp8_formats.png -------------------------------------------------------------------------------- /docs/examples/linear_mxfp8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/linear_mxfp8.png -------------------------------------------------------------------------------- /docs/examples/loss_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/loss_scaling.png -------------------------------------------------------------------------------- /docs/examples/te_llama/media/transformer_llama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/te_llama/media/transformer_llama.png -------------------------------------------------------------------------------- /docs/examples/transformer_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/TransformerEngine/c6a9e2610edddb73e28f78defe94cc0ae7845af1/docs/examples/transformer_layer.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | Transformer Engine documentation 7 | ============================================== 8 | 9 | .. ifconfig:: "dev" in release 10 | 11 | .. warning:: 12 | You are currently viewing unstable developer preview of the documentation. 13 | To see the documentation for the latest stable release, refer to: 14 | 15 | * `Release Notes `_ 16 | * `Developer Guide `_ (stable version of this page) 17 | 18 | .. include:: ../README.rst 19 | :start-after: overview-begin-marker-do-not-remove 20 | :end-before: overview-end-marker-do-not-remove 21 | 22 | .. toctree:: 23 | :hidden: 24 | 25 | Home 26 | 27 | .. toctree:: 28 | :hidden: 29 | :caption: Getting Started 30 | 31 | installation 32 | examples/quickstart.ipynb 33 | faq 34 | 35 | .. toctree:: 36 | :hidden: 37 | :caption: Python API documentation 38 | 39 | api/common 40 | api/framework 41 | 42 | .. toctree:: 43 | :hidden: 44 | :caption: Examples and Tutorials 45 | 46 | examples/fp8_primer.ipynb 47 | examples/advanced_optimizations.ipynb 48 | examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb 49 | 50 | .. toctree:: 51 | :hidden: 52 | :caption: Advanced 53 | 54 | api/c/index 55 | debug 56 | examples/attention/attention.ipynb 57 | -------------------------------------------------------------------------------- /docs/version_select.patch: -------------------------------------------------------------------------------- 1 | diff --git a/sphinx_rtd_theme/layout.html b/sphinx_rtd_theme/layout.html 2 | index e6a38b1..579eaec 100644 3 | --- a/sphinx_rtd_theme/layout.html 4 | +++ b/sphinx_rtd_theme/layout.html 5 | @@ -124,6 +124,16 @@ 6 | {%- endif %} 7 | 8 | 9 | + {# Show TE version and version selector #} 10 | +
11 | + {{ version }} 12 | +
13 | + Version select: 17 | +
18 | + 19 | {%- if READTHEDOCS or DEBUG %} 20 | {%- if theme_version_selector or theme_language_selector %} 21 |
22 | -------------------------------------------------------------------------------- /examples/jax/README.md: -------------------------------------------------------------------------------- 1 | # Transformer Engine Examples # 2 | 3 | This folder contains simple examples introducing Transformer Engine and FP8 training usage. 4 | 5 | **Examples Outline** 6 | * MNIST training: Training MNIST dataset is a good start point to learn how use Transformer Engine and enable FP8 training 7 | * Encoder training: The encoder examples introduce more about how to scale up training on multiple GPUs with Transformer Engine -------------------------------------------------------------------------------- /examples/jax/encoder/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """config for test_multiprocessing_encoder""" 6 | import pytest 7 | 8 | 9 | def pytest_addoption(parser): 10 | """Pytest hook for test_multiprocessing_encoder""" 11 | parser.addoption("--num-process", action="store", default=0) 12 | parser.addoption("--process-id", action="store", default=0) 13 | 14 | 15 | @pytest.fixture(autouse=True) 16 | def multiprocessing_parses(request): 17 | """Fixture for querying num-process and process-id""" 18 | if request.cls: 19 | request.cls.num_process = int(request.config.getoption("--num-process")) 20 | request.cls.process_id = int(request.config.getoption("--process-id")) 21 | -------------------------------------------------------------------------------- /examples/jax/encoder/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | flax>=0.7.1 3 | nltk>=3.8.2 4 | optax 5 | -------------------------------------------------------------------------------- /examples/jax/encoder/run_test_multiprocessing_encoder.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} 6 | 7 | # Define the test cases to run 8 | TEST_CASES=( 9 | "test_te_bf16" 10 | "test_te_delayed_scaling_fp8" 11 | "test_te_current_scaling_fp8" 12 | "test_te_mxfp8" 13 | "test_te_bf16_shardy" 14 | "test_te_delayed_scaling_fp8_shardy" 15 | "test_te_current_scaling_fp8_shardy" 16 | ) 17 | 18 | echo 19 | echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***" 20 | 21 | HAS_FAILURE=0 # Global failure flag 22 | 23 | # Run each test case across all GPUs 24 | for TEST_CASE in "${TEST_CASES[@]}"; do 25 | echo 26 | echo "=== Starting test: $TEST_CASE ..." 27 | 28 | for i in $(seq 0 $(($NUM_GPUS - 1))); do 29 | # Define output file for logs 30 | LOG_FILE="${TEST_CASE}_gpu_${i}.log" 31 | 32 | # Run pytest and redirect stdout and stderr to the log file 33 | pytest -c "$TE_PATH/tests/jax/pytest.ini" \ 34 | -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ 35 | --num-process=$NUM_GPUS \ 36 | --process-id=$i > "$LOG_FILE" 2>&1 & 37 | done 38 | 39 | # Wait for the process to finish 40 | wait 41 | 42 | # Check and print the log content accordingly 43 | if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then 44 | HAS_FAILURE=1 45 | echo "... 
$TEST_CASE FAILED" 46 | tail -n +7 "${TEST_CASE}_gpu_0.log" 47 | elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then 48 | echo "... $TEST_CASE SKIPPED" 49 | elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then 50 | echo "... $TEST_CASE PASSED" 51 | else 52 | echo "Invalid ${TEST_CASE}_gpu_0.log" 53 | fi 54 | 55 | # Remove the log file after processing it 56 | rm ${TEST_CASE}_gpu_*.log 57 | done 58 | 59 | exit $HAS_FAILURE 60 | -------------------------------------------------------------------------------- /examples/jax/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | flax>=0.7.1 3 | optax 4 | Pillow 5 | -------------------------------------------------------------------------------- /examples/pytorch/fsdp/README.md: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # Basic Example for Using PyTorch Fully Sharded Data Parallel mode with Transformer Engine 6 | 7 | ```bash 8 | # FSDP without deferred initialization: 9 | # Duplicate modules initialized on each device. Load on device memory reduced only after 10 | # torch.distributed.fsdp.FullyShardedDataParallel mode shards model parameters. 11 | $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --no-defer-init 12 | # Sample output on 8xL40S: 13 | # [GPU-0] WORLD_SIZE = 8 14 | # [GPU-0] TransformerEngine Model: 15 | # TransformerLayer( 16 | # (self_attention): MultiheadAttention( 17 | # (layernorm_qkv): LayerNormLinear() 18 | # (core_attention): DotProductAttention( 19 | # (flash_attention): FlashAttention() 20 | # (fused_attention): FusedAttention() 21 | # (unfused_attention): UnfusedDotProductAttention( 22 | # (scale_mask_softmax): FusedScaleMaskSoftmax() 23 | # (attention_dropout): Dropout(p=0.1, inplace=False) 24 | # ) 25 | # ) 26 | # (proj): Linear() 27 | # ) 28 | # (layernorm_mlp): LayerNormMLP() 29 | # ) 30 | # [GPU-0] Pre-FSDP memory use = 83.935232MiB 31 | # [GPU-0] Post-FSDP memory use = 10.491904MiB 32 | # [GPU-0] Iter. 1 33 | # [GPU-0] Iter. 2 34 | # [GPU-0] Iter. 3 35 | # [GPU-0] Training Time: 6.647654296875s 36 | # [GPU-0] Avg. Iter. Time: 2.2158847656250003s 37 | # [GPU-0] Peak memory use = 3000MiB 38 | 39 | # FSDP with deferred initialization: 40 | # Modules initialized with empty parameters via `device='meta'` option. Zero load on device 41 | # memory until torch.distributed.fsdp.FullyShardedDataParallel mode triggers a reset on 42 | # on already sharded model parameters. 43 | $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py 44 | # Sample output on 8xL40S: 45 | # [GPU-0] WORLD_SIZE = 8 46 | # ... 47 | # [GPU-0] Pre-FSDP memory use = 0.0MiB 48 | # [GPU-0] Post-FSDP memory use = 10.491904MiB 49 | # ... 50 | ``` 51 | 52 | **NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support 53 | (e.g.: A100), add the `--no-fp8` option to the commands shown above. 
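
For reference, below is a minimal Python sketch of the deferred-initialization pattern that `fsdp.py` implements. It is illustrative only: the layer sizes, the `init_meta_params` helper, and the assumption that Transformer Engine modules accept `device="meta"` and expose `reset_parameters()` are simplifications; consult `fsdp.py` in this directory for the complete, tested implementation.

```python
# Minimal sketch (see assumptions above); launch with torchrun as in the commands above.
import os
import torch
import transformer_engine.pytorch as te
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)

defer_init = True  # mirrors the default behavior (i.e. without --no-defer-init)

# With deferred init, parameters live on the 'meta' device, so no GPU memory
# is allocated until FSDP has already decided how to shard them.
layer = te.TransformerLayer(
    hidden_size=1024,        # hypothetical sizes, for illustration only
    ffn_hidden_size=4096,
    num_attention_heads=16,
    device="meta" if defer_init else "cuda",
)

def init_meta_params(module: torch.nn.Module) -> None:
    # Materialize the empty (meta) parameters on the local GPU, then reset them.
    module.to_empty(device=torch.cuda.current_device())
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

model = FSDP(
    layer,
    device_id=torch.cuda.current_device(),
    param_init_fn=init_meta_params if defer_init else None,
)

# Forward/backward; drop fp8_autocast (or pass enabled=False) on GPUs without FP8 support.
with te.fp8_autocast(enabled=True):
    out = model(torch.randn(32, 4, 1024, device="cuda"))
out.sum().backward()
```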
54 | -------------------------------------------------------------------------------- /examples/pytorch/mnist/README.md: -------------------------------------------------------------------------------- 1 | # Basic MNIST Example with optional FP8 2 | 3 | ```bash 4 | python main.py 5 | python main.py --use-te # Linear layers from TransformerEngine 6 | python main.py --use-fp8 # FP8 + TransformerEngine for Linear layers 7 | ``` 8 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | extension-pkg-whitelist=flash_attn_2_cuda, 3 | torch, 4 | transformer_engine_torch, 5 | transformer_engine_jax 6 | 7 | disable=too-many-locals, 8 | too-few-public-methods, 9 | too-many-public-methods, 10 | too-many-positional-arguments, 11 | invalid-name, 12 | too-many-arguments, 13 | abstract-method, 14 | arguments-differ, 15 | too-many-instance-attributes, 16 | unsubscriptable-object, 17 | import-outside-toplevel, 18 | too-many-statements, 19 | import-error, 20 | too-many-lines, 21 | use-maxsplit-arg, 22 | protected-access, 23 | pointless-string-statement, 24 | cyclic-import, 25 | duplicate-code, 26 | no-member, 27 | attribute-defined-outside-init, 28 | global-statement, 29 | too-many-branches, 30 | global-variable-not-assigned, 31 | redefined-argument-from-local, 32 | line-too-long, 33 | too-many-return-statements, 34 | too-many-nested-blocks 35 | 36 | [TYPECHECK] 37 | ignored-modules=torch 38 | ignored-classes=torch 39 | -------------------------------------------------------------------------------- /qa/L0_cppunittest/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | set -e 6 | 7 | # Find TE 8 | : ${TE_PATH:=/opt/transformerengine} 9 | TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2` 10 | export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH 11 | 12 | # Set parallelization parameters 13 | NUM_PHYSICAL_CORES=$(nproc) 14 | NUM_PARALLEL_JOBS=4 15 | 16 | cd $TE_PATH/tests/cpp 17 | cmake -GNinja -Bbuild . 18 | cmake --build build 19 | export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) 20 | ctest --test-dir build -j$NUM_PARALLEL_JOBS 21 | -------------------------------------------------------------------------------- /qa/L0_jax_distributed_unittest/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | function error_exit() { 6 | echo "Error: $1" 7 | exit 1 8 | } 9 | 10 | function test_fail() { 11 | RET=1 12 | FAILED_CASES="$FAILED_CASES $1" 13 | echo "Error: sub-test failed: $1" 14 | } 15 | 16 | RET=0 17 | FAILED_CASES="" 18 | 19 | : ${TE_PATH:=/opt/transformerengine} 20 | : ${XML_LOG_DIR:=/logs} 21 | mkdir -p "$XML_LOG_DIR" 22 | 23 | pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements" 24 | 25 | # Make encoder tests to have run-to-run deterministic to have the stable CI results 26 | export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" 27 | python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" 28 | wait 29 | python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" 30 | wait 31 | . $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" 32 | 33 | if [ $RET -ne 0 ]; then 34 | echo "Error: some sub-tests failed: $FAILED_CASES" 35 | exit 1 36 | fi 37 | echo "All tests passed" 38 | exit 0 39 | -------------------------------------------------------------------------------- /qa/L0_jax_lint/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | set -e 6 | 7 | : "${TE_PATH:=/opt/transformerengine}" 8 | 9 | pip3 install cpplint==1.6.0 pylint==3.3.1 10 | if [ -z "${PYTHON_ONLY}" ] 11 | then 12 | cd $TE_PATH 13 | echo "Checking common API headers" 14 | python3 -m cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include 15 | echo "Checking C++ files" 16 | python3 -m cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common 17 | python3 -m cpplint --recursive transformer_engine/jax 18 | fi 19 | if [ -z "${CPP_ONLY}" ] 20 | then 21 | cd $TE_PATH 22 | echo "Checking Python files" 23 | python3 -m pylint --recursive=y transformer_engine/common transformer_engine/jax 24 | fi 25 | -------------------------------------------------------------------------------- /qa/L0_jax_unittest/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | set -x 6 | 7 | function error_exit() { 8 | echo "Error: $1" 9 | exit 1 10 | } 11 | 12 | function test_fail() { 13 | RET=1 14 | FAILED_CASES="$FAILED_CASES $1" 15 | echo "Error: sub-test failed: $1" 16 | } 17 | 18 | RET=0 19 | FAILED_CASES="" 20 | 21 | pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" 22 | pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" 23 | 24 | : ${TE_PATH:=/opt/transformerengine} 25 | : ${XML_LOG_DIR:=/logs} 26 | mkdir -p "$XML_LOG_DIR" 27 | 28 | python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" 29 | 30 | pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" 31 | python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" 32 | 33 | pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" 34 | # Make encoder tests to have run-to-run deterministic to have the stable CI results 35 | export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" 36 | python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" 37 | # Test without custom calls 38 | export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" 39 | NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" 40 | 41 | if [ $RET -ne 0 ]; then 42 | echo "Error: some sub-tests failed: $FAILED_CASES" 43 | exit 1 44 | fi 45 | echo "All tests passed" 46 | exit 0 47 | -------------------------------------------------------------------------------- /qa/L0_jax_wheel/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | function error_exit() { 6 | echo "Error: $1" 7 | exit 1 8 | } 9 | 10 | function test_fail() { 11 | RET=1 12 | FAILED_CASES="$FAILED_CASES $1" 13 | echo "Error: sub-test failed: $1" 14 | } 15 | 16 | RET=0 17 | FAILED_CASES="" 18 | 19 | : "${TE_PATH:=/opt/transformerengine}" 20 | 21 | pip3 install wheel || error_exit "Failed to install wheel" 22 | 23 | cd $TE_PATH 24 | pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-jax" 25 | 26 | VERSION=`cat $TE_PATH/build_tools/VERSION.txt` 27 | WHL_BASE="transformer_engine-${VERSION}" 28 | 29 | # Core wheel. 
30 | NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" 31 | wheel unpack dist/* || error_exit "Failed to unpack dist/*" 32 | sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" 33 | sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" 34 | mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" 35 | wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" 36 | rm dist/*.whl || error_exit "Failed to remove dist/*.whl" 37 | mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" 38 | NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" 39 | 40 | cd transformer_engine/jax 41 | NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" 42 | 43 | pip3 install dist/* || error_exit "Failed to install dist/*" 44 | cd $TE_PATH 45 | pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" 46 | 47 | python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py" 48 | 49 | if [ $RET -ne 0 ]; then 50 | echo "Error: some sub-tests failed: $FAILED_CASES" 51 | exit 1 52 | fi 53 | echo "All tests passed" 54 | exit 0 55 | -------------------------------------------------------------------------------- /qa/L0_license/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "initial_year": 2022, 3 | "copyright": "Copyright (c) , NVIDIA CORPORATION & AFFILIATES. All rights reserved.", 4 | "license": "See LICENSE for license information.", 5 | "exclude": ["3rdparty", 6 | "Dockerfile", 7 | "Dockerfile.base", 8 | "Dockerfile.qa", 9 | "Dockerfile.devel", 10 | "Dockerfile.docs", 11 | "docker-build.sh", 12 | ".png", 13 | ".ipynb", 14 | "docs/Makefile", 15 | "layout.html", 16 | "LICENSE", 17 | "VERSION", 18 | "Doxyfile", 19 | "pylintrc", 20 | ".json", 21 | ".md", 22 | ".txt" 23 | ], 24 | "exclude_copyright": [], 25 | "copyright_only": false 26 | } 27 | -------------------------------------------------------------------------------- /qa/L0_license/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | set -e 6 | 7 | : "${TE_PATH:=/opt/transformerengine}" 8 | 9 | python3 $TE_PATH/qa/L0_license/copyright_checker.py $TE_PATH 10 | -------------------------------------------------------------------------------- /qa/L0_pytorch_debug_unittest/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | 6 | 7 | : ${TE_PATH:=/opt/transformerengine} 8 | : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} 9 | : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} 10 | 11 | # Config with the dummy feature which prevents nvinspect from being disabled. 12 | # Nvinspect will be disabled if no feature is active. 
13 | : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} 14 | 15 | FAIL=0 16 | 17 | pip install pytest==8.2.1 18 | pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 19 | pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 20 | pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 21 | NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 22 | 23 | # standard numerics tests with initialized debug 24 | NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 25 | 26 | exit $FAIL 27 | -------------------------------------------------------------------------------- /qa/L0_pytorch_lint/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | set -e 6 | 7 | : "${TE_PATH:=/opt/transformerengine}" 8 | 9 | pip3 install cpplint==1.6.0 pylint==3.3.1 10 | if [ -z "${PYTHON_ONLY}" ] 11 | then 12 | cd $TE_PATH 13 | echo "Checking common API headers" 14 | python3 -m cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include 15 | echo "Checking C++ files" 16 | python3 -m cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common 17 | python3 -m cpplint --recursive transformer_engine/pytorch 18 | fi 19 | if [ -z "${CPP_ONLY}" ] 20 | then 21 | cd $TE_PATH 22 | echo "Checking Python files" 23 | python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug 24 | fi 25 | -------------------------------------------------------------------------------- /qa/L0_pytorch_wheel/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | function error_exit() { 6 | echo "Error: $1" 7 | exit 1 8 | } 9 | 10 | function test_fail() { 11 | RET=1 12 | FAILED_CASES="$FAILED_CASES $1" 13 | echo "Error: sub-test failed: $1" 14 | } 15 | 16 | RET=0 17 | FAILED_CASES="" 18 | 19 | : "${TE_PATH:=/opt/transformerengine}" 20 | 21 | pip3 install wheel || error_exit "Failed to install wheel" 22 | 23 | cd $TE_PATH 24 | pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-torch" 25 | 26 | VERSION=`cat $TE_PATH/build_tools/VERSION.txt` 27 | WHL_BASE="transformer_engine-${VERSION}" 28 | 29 | # Core wheel. 
30 | NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" 31 | wheel unpack dist/* || error_exit "Failed to unpack dist/*" 32 | sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" 33 | sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" 34 | mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" 35 | wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" 36 | rm dist/*.whl || error_exit "Failed to remove dist/*.whl" 37 | mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" 38 | NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" 39 | 40 | cd transformer_engine/pytorch 41 | NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" 42 | 43 | pip3 install dist/* || error_exit "Failed to install dist/*" 44 | cd $TE_PATH 45 | pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" 46 | 47 | python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py" 48 | 49 | if [ "$RET" -ne 0 ]; then 50 | echo "Error in the following test cases:$FAILED_CASES" 51 | exit 1 52 | fi 53 | echo "All tests passed" 54 | exit 0 55 | -------------------------------------------------------------------------------- /qa/L1_jax_distributed_unittest/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | set -xe 6 | 7 | : ${TE_PATH:=/opt/transformerengine} 8 | : ${XML_LOG_DIR:=/logs} 9 | mkdir -p "$XML_LOG_DIR" 10 | 11 | python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* 12 | -------------------------------------------------------------------------------- /qa/L1_pytorch_mcore_integration/.gitignore: -------------------------------------------------------------------------------- 1 | Megatron-LM 2 | vocab.json -------------------------------------------------------------------------------- /qa/L1_pytorch_mcore_integration/merges.txt: -------------------------------------------------------------------------------- 1 | #version: 0.2 2 | -------------------------------------------------------------------------------- /qa/L1_pytorch_mcore_integration/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | set -e 6 | 7 | # Paths 8 | : ${TE_PATH:=/opt/transformerengine} 9 | : ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} 10 | 11 | # Check whether FP8 is supported 12 | DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g') 13 | if [[ ${DEVICE_ARCH} -ge 89 ]]; then 14 | WITH_FP8=1 15 | fi 16 | 17 | # Download Megatron-LM if needed 18 | if [ ! 
-d "${MCORE_PATH}" ]; then 19 | pushd $(dirname ${MCORE_PATH}) 20 | git clone -b core_r0.12.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM 21 | popd 22 | fi 23 | 24 | # Create mock vocab 25 | VOCAB_FILE=${TE_PATH}/qa/L1_pytorch_mcore_integration/vocab.json 26 | printf "" > ${VOCAB_FILE} 27 | printf "{" >> ${VOCAB_FILE} 28 | printf "\"<|endoftext|>\": 0" >> ${VOCAB_FILE} 29 | seq 1 4095 | awk '{ printf(", \"%d\": %d", $1, $1) }' >> ${VOCAB_FILE} 30 | printf "}" >> ${VOCAB_FILE} 31 | 32 | # Megatron-LM invocation 33 | COMMAND=" 34 | NVTE_TORCH_COMPILE=0 35 | NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 36 | NVTE_FLASH_ATTN=1 37 | NVTE_FWD_LAYERNORM_SM_MARGIN=0 38 | NVTE_BWD_LAYERNORM_SM_MARGIN=0 39 | CUDA_DEVICE_MAX_CONNECTIONS=1 40 | NVTE_BIAS_GELU_NVFUSION=0 41 | NVTE_BIAS_DROPOUT_FUSION=0 42 | 43 | python3 44 | -m torch.distributed.launch 45 | --use_env 46 | --nnodes=1 47 | --nproc_per_node=1 48 | 49 | ${MCORE_PATH}/pretrain_gpt.py 50 | --tensor-model-parallel-size 1 51 | --pipeline-model-parallel-size 1 52 | --use-cpu-initialization 53 | --num-layers 2 54 | --hidden-size 128 55 | --num-attention-heads 8 56 | --seq-length 128 57 | --max-position-embeddings 128 58 | --micro-batch-size 1 59 | --global-batch-size 8 60 | --train-iters 10 61 | --eval-iters 10 62 | --lr 1e-4 63 | --mock-data 64 | --vocab-file ${VOCAB_FILE} 65 | --merge-file ${TE_PATH}/qa/L1_pytorch_mcore_integration/merges.txt 66 | --transformer-impl transformer_engine 67 | ${WITH_FP8:+--fp8-format hybrid} 68 | " 69 | COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') 70 | 71 | # Launch Megatron-LM 72 | bash -c "${COMMAND}" 73 | -------------------------------------------------------------------------------- /qa/L1_pytorch_thunder_integration/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | set -x 6 | 7 | : ${THUNDER_PATH:=/opt/pytorch/lightning-thunder} 8 | : ${XML_LOG_DIR:=/logs} 9 | mkdir -p "$XML_LOG_DIR" 10 | 11 | pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 12 | python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py 13 | 14 | # Check return code 15 | # Note: Return code 5 is fine. Lightning tests are skipped on systems 16 | # without FP8 support and Pytest returns 5 if no tests are run. 17 | RC=$? 18 | if [ ${RC} -eq 5 ]; then 19 | RC=0 20 | fi 21 | exit ${RC} 22 | -------------------------------------------------------------------------------- /qa/L2_jax_unittest/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | set -x 6 | 7 | function error_exit() { 8 | echo "Error: $1" 9 | exit 1 10 | } 11 | 12 | function test_fail() { 13 | RET=1 14 | FAILED_CASES="$FAILED_CASES $1" 15 | echo "Error: sub-test failed: $1" 16 | } 17 | 18 | RET=0 19 | FAILED_CASES="" 20 | 21 | pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" 22 | pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" 23 | 24 | : ${TE_PATH:=/opt/transformerengine} 25 | : ${XML_LOG_DIR:=/logs} 26 | mkdir -p "$XML_LOG_DIR" 27 | 28 | NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" 29 | 30 | pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" 31 | NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" 32 | 33 | pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" 34 | # Make encoder tests to have run-to-run deterministic to have the stable CI results 35 | export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" 36 | NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" 37 | # Test without custom calls 38 | export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" 39 | NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" 40 | 41 | if [ $RET -ne 0 ]; then 42 | echo "Error: some sub-tests failed: $FAILED_CASES" 43 | exit 1 44 | fi 45 | echo "All tests passed" 46 | exit 0 47 | -------------------------------------------------------------------------------- /qa/L3_pytorch_FA_versions_test/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | set -e 6 | 7 | : ${TE_PATH:=/opt/transformerengine} 8 | : ${XML_LOG_DIR:=/logs} 9 | mkdir -p "$XML_LOG_DIR" 10 | 11 | pip3 install pytest==8.2.1 12 | 13 | # Limit parallel build jobs to avoid overwhelming system resources 14 | export MAX_JOBS=32 15 | 16 | # Iterate over Flash Attention versions 17 | sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` 18 | export FLASH_ATTN_CUDA_ARCHS=$sm_arch 19 | if [ $sm_arch -gt 90 ] 20 | then 21 | FA_versions=(2.7.3) 22 | elif [ $sm_arch -eq 90 ] 23 | then 24 | FA_versions=(2.5.7 2.7.3 3.0.0b1) 25 | fi 26 | 27 | for fa_version in "${FA_versions[@]}" 28 | do 29 | 30 | # Build Flash Attention 31 | if [ "${fa_version}" \< "3.0.0" ] 32 | then 33 | pip3 install flash-attn==${fa_version} 34 | else 35 | git clone https://github.com/Dao-AILab/flash-attention.git 36 | cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install 37 | python_path=`python -c "import site; print(site.getsitepackages()[0])"` 38 | mkdir -p $python_path/flash_attn_3 39 | wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py 40 | cd ../../ 41 | fi 42 | 43 | # Run tests 44 | NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py 45 | 46 | done 47 | -------------------------------------------------------------------------------- /qa/format.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | # Utility file to run pre-commit hooks locally 6 | # Usage: bash qa/format.sh 7 | 8 | set -e 9 | 10 | : "${TE_PATH:=.}" 11 | 12 | cd $TE_PATH 13 | 14 | pip3 install pre-commit 15 | python3 -m pre_commit run --all-files 16 | -------------------------------------------------------------------------------- /tests/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | cmake_minimum_required(VERSION 3.18) 6 | 7 | if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) 8 | if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) 9 | set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) 10 | else () 11 | set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) 12 | endif() 13 | endif() 14 | 15 | 16 | set(CMAKE_CXX_STANDARD 17) 17 | set(CMAKE_CUDA_STANDARD 17) 18 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 19 | 20 | project(transformer_engine_tests LANGUAGES CUDA CXX) 21 | 22 | add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) 23 | 24 | enable_testing() 25 | 26 | include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) 27 | 28 | if(NOT DEFINED TE_LIB_PATH) 29 | execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" 30 | OUTPUT_VARIABLE TE_LIB_PATH) 31 | endif() 32 | 33 | find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) 34 | 35 | message(STATUS "Found transformer_engine library: ${TE_LIB}") 36 | include_directories(../../transformer_engine/common/include) 37 | include_directories(../../transformer_engine/common) 38 | include_directories(${CMAKE_SOURCE_DIR}) 39 | 40 | find_package(CUDAToolkit REQUIRED) 41 | include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) 42 | 43 | add_subdirectory(operator) 44 | add_subdirectory(util) 45 | -------------------------------------------------------------------------------- /tests/cpp/operator/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | add_executable(test_operator 6 | test_cast.cu 7 | test_cast_current_scaling.cu 8 | test_cast_dbias.cu 9 | test_cast_dbias_dgelu.cu 10 | test_cast_gated_swiglu.cu 11 | test_cast_mxfp8_gated_swiglu.cu 12 | test_qdq.cu 13 | test_cast_mxfp8.cu 14 | test_cast_float8blockwise.cu 15 | test_dequantize_mxfp8.cu 16 | test_transpose.cu 17 | test_cast_transpose.cu 18 | test_cast_transpose_current_scaling.cu 19 | test_cast_transpose_dbias.cu 20 | test_cast_transpose_dbias_dgelu.cu 21 | test_cast_transpose_dgeglu.cu 22 | test_act.cu 23 | test_normalization.cu 24 | test_normalization_mxfp8.cu 25 | test_multi_cast_transpose.cu 26 | test_multi_padding.cu 27 | test_causal_softmax.cu 28 | test_swizzle.cu 29 | ../test_common.cu) 30 | 31 | find_package(OpenMP REQUIRED) 32 | list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) 33 | 34 | target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) 35 | target_compile_options(test_operator PRIVATE -O2 -fopenmp) 36 | 37 | include(GoogleTest) 38 | gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600) 39 | -------------------------------------------------------------------------------- /tests/cpp/operator/test_memset.cu: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include 20 | #include "../test_common.h" 21 | 22 | using namespace transformer_engine; 23 | 24 | 25 | class MemsetTestSuite : public ::testing::TestWithParam> {}; 27 | 28 | TEST_P(MemsetTestSuite, TestMemset) { 29 | using namespace transformer_engine; 30 | using namespace test; 31 | 32 | int value = std::get<0>(GetParam()); 33 | size_t size_in_bytes = std::get<1>(GetParam()); 34 | 35 | std::vector h_buffer{}; 36 | h_buffer.resize(size_in_bytes); 37 | for (size_t i = 0; i < size_in_bytes; ++i) { 38 | h_buffer[i] = value + 1; // Initialize host buffer to a different value than memset value to verify memset is working correctly 39 | } 40 | 41 | char* d_ptr; 42 | NVTE_CHECK_CUDA(cudaMalloc(&d_ptr, size_in_bytes)); 43 | 44 | NVTE_CHECK_CUDA(cudaMemcpy(d_ptr, h_buffer.data(), size_in_bytes, cudaMemcpyHostToDevice)); 45 | 46 | nvte_memset(d_ptr, value, size_in_bytes, 0 /* stream */); 47 | 48 | NVTE_CHECK_CUDA(cudaMemcpy( 49 | h_buffer.data(), d_ptr, size_in_bytes, cudaMemcpyDeviceToHost)); 50 | NVTE_CHECK_CUDA(cudaFree(d_ptr)); 51 | 52 | NVTE_CHECK_CUDA(cudaDeviceSynchronize()); 53 | 54 | for (size_t i = 0; i < size_in_bytes; ++i) { 55 | EXPECT_EQ(h_buffer[i], static_cast(value)) 56 | << "Mismatch at index " << i << ": expected " << static_cast(value) 57 | << ", got " << static_cast(h_buffer[i]); 58 | } 59 | } 60 | 61 | namespace { 62 | 63 | std::vector memset_test_sizes = { 64 | 1, 65 | 4, 66 | 9, 67 | 16, 68 | 128, 69 | 4096, 70 | 4097, 71 | 8192, 72 | }; 73 | 74 | } // namespace 75 | 76 | INSTANTIATE_TEST_SUITE_P( 77 | OperatorTest, 78 | MemsetTestSuite, 79 | ::testing::Combine( 80 | ::testing::Values(0, 6), 81 | ::testing::ValuesIn(memset_test_sizes)), 82 | [](const testing::TestParamInfo& info) { 83 | std::string name = std::to_string(std::get<0>(info.param)) + "X" + 84 | std::to_string(std::get<1>(info.param)); 85 | return name; 86 | }); 87 | -------------------------------------------------------------------------------- /tests/cpp/util/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | add_executable(test_util 6 | test_nvrtc.cpp 7 | test_string.cpp 8 | ../test_common.cu) 9 | 10 | 11 | find_package(OpenMP REQUIRED) 12 | target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX) 13 | target_compile_options(test_util PRIVATE -O2 -fopenmp) 14 | 15 | include(GoogleTest) 16 | gtest_discover_tests(test_util DISCOVERY_TIMEOUT 600) 17 | -------------------------------------------------------------------------------- /tests/jax/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | """conftest for tests/jax""" 5 | import os 6 | import jax 7 | import pytest 8 | 9 | 10 | import transformer_engine.jax 11 | from transformer_engine_jax import get_device_compute_capability 12 | 13 | 14 | @pytest.fixture(autouse=True, scope="function") 15 | def clear_live_arrays(): 16 | """ 17 | Clear all live arrays to keep the resource clean 18 | """ 19 | yield 20 | for arr in jax.live_arrays(): 21 | arr.delete() 22 | 23 | 24 | @pytest.fixture(autouse=True, scope="module") 25 | def enable_fused_attn_after_hopper(): 26 | """ 27 | Enable fused attn for hopper+ arch. 28 | Fused attn kernels on pre-hopper arch are not deterministic. 29 | """ 30 | if get_device_compute_capability(0) >= 90: 31 | os.environ["NVTE_FUSED_ATTN"] = "1" 32 | yield 33 | if "NVTE_FUSED_ATTN" in os.environ: 34 | del os.environ["NVTE_FUSED_ATTN"] 35 | -------------------------------------------------------------------------------- /tests/jax/pytest.ini: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | [pytest] 6 | filterwarnings= 7 | ignore:Fused attention is not enabled.*:UserWarning 8 | ignore:The hookimpl.*:DeprecationWarning 9 | ignore:xmap is an experimental feature and probably has bugs! 10 | ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning 11 | ignore:can't resolve package from __spec__ or __package__:ImportWarning 12 | ignore:Using or importing the ABCs.*:DeprecationWarning 13 | ignore:numpy.ufunc size changed 14 | ignore:.*experimental feature 15 | ignore:The distutils.* is deprecated.*:DeprecationWarning 16 | ignore:backend and device argument on jit is deprecated.*:DeprecationWarning 17 | ignore:ml_dtypes.float8_e4m3b11 is deprecated. 18 | ignore:np.find_common_type is deprecated.*:DeprecationWarning 19 | ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning 20 | ignore:The numpy.array_api submodule is still experimental.*:UserWarning 21 | ignore:case not machine-readable.*:UserWarning 22 | ignore:not machine-readable.*:UserWarning 23 | ignore:Special cases found for .* but none were parsed.*:UserWarning 24 | ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning 25 | ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning 26 | ignore:The host_callback APIs are deprecated .*:DeprecationWarning 27 | ignore:Scan loop is disabled for fused ring attention.*:UserWarning 28 | ignore:jax.extend.ffi.register_ffi_target is deprecated 29 | ignore:jax.extend.ffi.ffi_lowering is deprecated 30 | -------------------------------------------------------------------------------- /tests/jax/test_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | import pytest 6 | from functools import partial 7 | import os 8 | 9 | from transformer_engine.jax.cpp_extensions.misc import get_xla_flag 10 | 11 | 12 | @pytest.fixture(autouse=True, scope="function") 13 | def preserve_xla_flags(): 14 | """Ensures the XLA flags environment variable is restored after any tests in this file run.""" 15 | old_flags = os.getenv("XLA_FLAGS") 16 | yield 17 | if old_flags is not None: 18 | os.environ["XLA_FLAGS"] = old_flags 19 | 20 | 21 | def test_get_xla_flag(request): 22 | os.environ["XLA_FLAGS"] = "" 23 | assert get_xla_flag("") is None 24 | assert get_xla_flag("--foo") is None 25 | assert get_xla_flag("--bar=1") is None 26 | 27 | os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz" 28 | assert get_xla_flag("--foo") == True 29 | assert get_xla_flag("--bar") == "1" 30 | assert get_xla_flag("--bar", cast=int) == 1 31 | assert get_xla_flag("--bar", cast=bool) == True 32 | assert get_xla_flag("--baz") == "biz" 33 | with pytest.raises(ValueError): 34 | # cast will fail 35 | assert get_xla_flag("--baz", cast=int) 36 | assert get_xla_flag("--xla") is None 37 | 38 | os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb" 39 | assert get_xla_flag("--xla_abc") == True 40 | assert get_xla_flag("--xla_abb") == True 41 | -------------------------------------------------------------------------------- /tests/jax/test_sanity_import.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import transformer_engine.jax 6 | 7 | print("OK") 8 | -------------------------------------------------------------------------------- /tests/jax/test_sharding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import pytest 6 | 7 | from transformer_engine.jax.flax import extend_logical_axis_rules 8 | from transformer_engine.jax.sharding import global_shard_guard, MeshResource 9 | 10 | LOGICAL_RULES = [ 11 | [(("a1", None), ("a2", "ma2")), False], 12 | [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True], 13 | [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False], 14 | [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True], 15 | [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True], 16 | ] 17 | 18 | MeshS = [ 19 | MeshResource(), 20 | MeshResource("data", None), 21 | MeshResource(None, "model"), 22 | MeshResource("data", "model"), 23 | ] 24 | 25 | 26 | class TestShardingSideAPI: 27 | 28 | @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES) 29 | @pytest.mark.parametrize("sr", MeshS) 30 | def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): 31 | with global_shard_guard(sr): 32 | try: 33 | target_te_rules = extend_logical_axis_rules(tuple()) 34 | extended_rules = extend_logical_axis_rules(base_rules) 35 | assert extended_rules == (*base_rules, *target_te_rules) 36 | assert not need_assert 37 | except AssertionError as ae: 38 | assert need_assert, f"{ae.args}" 39 | -------------------------------------------------------------------------------- /tests/pytorch/debug/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # See LICENSE for license information. 4 | import pytest 5 | 6 | 7 | def pytest_addoption(parser): 8 | parser.addoption( 9 | "--feature_dirs", nargs="+", action="store", default="", help="List of feature directories" 10 | ) 11 | parser.addoption( 12 | "--configs_dir", 13 | action="store", 14 | default="", 15 | type=str, 16 | help="Path to the directory with configs.", 17 | ) 18 | 19 | 20 | @pytest.fixture 21 | def feature_dirs(request): 22 | return request.config.getoption("--feature_dirs") 23 | 24 | 25 | @pytest.fixture 26 | def configs_dir(request): 27 | return request.config.getoption("--configs_dir") 28 | -------------------------------------------------------------------------------- /tests/pytorch/debug/test_configs/disable_fp8_gemms.yaml: -------------------------------------------------------------------------------- 1 | test_disable_fp8_gemm_1: 2 | enabled: True 3 | layers: 4 | layer_types: [qkv, fc2] 5 | transformer_engine: 6 | DisableFP8GEMM: 7 | enabled: True 8 | gemms: [dgrad, wgrad] -------------------------------------------------------------------------------- /tests/pytorch/debug/test_configs/disable_fp8_layer.yaml: -------------------------------------------------------------------------------- 1 | test_disable_fp8_layer: 2 | enabled: True 3 | layers: 4 | layer_types: [qkv] 5 | transformer_engine: 6 | DisableFP8Layer: 7 | enabled: True -------------------------------------------------------------------------------- /tests/pytorch/debug/test_configs/dummy_feature.yaml: -------------------------------------------------------------------------------- 1 | deummy_feature_everywhere: 2 | enabled: True 3 | layers: 4 | layer_name_regex_pattern: .* 5 | transformer_engine: 6 | TestDummyFeature: 7 | enabled: True 8 | tensors: [weight, activation, gradient, output, wgrad, dgrad] 9 | gemms: [wgrad, dgrad, fprop] -------------------------------------------------------------------------------- /tests/pytorch/debug/test_configs/fake_quantization_config.yaml: -------------------------------------------------------------------------------- 1 | test_fake_quant_fp8: 2 | enabled: True 3 | layers: 4 | layer_numbers: [1] 5 | layer_types: [fc1, fc2] 6 | transformer_engine: 7 | FakeQuant: 8 | enabled: True 9 | gemms: [fprop, dgrad] 10 | tensors_struct: 11 | - tensor: activation 12 | quant_format: FP8E4M3 13 | - tensor: gradient 14 | quant_format: FP8E5M2 -------------------------------------------------------------------------------- /tests/pytorch/debug/test_configs/per_tensor_scaling.yaml: -------------------------------------------------------------------------------- 1 | test_per_tensor_scaling: 2 | enabled: True 3 | layers: 4 | layer_numbers: [1] 5 | layer_types: [fc1, fc2] 6 | transformer_engine: 7 | DisableFP8GEMM: 8 | enabled: True 9 | gemms: [wgrad] 10 | PerTensorScaling: 11 | enabled: True 12 | gemms_struct: 13 | - gemm: fprop 14 | tensors_struct: 15 | - tensor: activation 16 | - tensor: weight 17 | - gemm: dgrad 18 | tensors_struct: 19 | - tensor: gradient -------------------------------------------------------------------------------- /tests/pytorch/debug/test_configs/stats_collection_test_config.yaml: -------------------------------------------------------------------------------- 1 | stat_collection_test_1: 2 | enabled: True 3 | layers: 4 | layer_numbers: [1, 3] 5 | LogTensorStats: 6 | enabled: True 7 | stats: [mean, std, l1_norm, l2_norm] 8 | tensors: [activation] 9 | freq: 1 10 | start_step: 100 11 | end_step: 500 12 | transformer_engine: 13 | LogTensorStats: 14 | 
enabled: True 15 | stats: [cur_amax, dynamic_range] 16 | tensors: [activation] 17 | freq: 2 18 | start_step: 100 19 | end_step: 500 20 | LogFp8TensorStats: 21 | enabled: True 22 | stats: [underflows%] 23 | tensors: [gradient] 24 | freq: 5 25 | start_step: 100 26 | end_step: 500 27 | 28 | stat_collection_test_2: 29 | enabled: True 30 | layers: 31 | layer_numbers: [6, 7] 32 | transformer_engine: 33 | LogTensorStats: 34 | enabled: True 35 | tensors_struct: 36 | - tensor: activation 37 | stats: [cur_amax, dynamic_range, mean, std, l1_norm] 38 | freq: 2 39 | start_step: 100 40 | end_step: 500 41 | - tensor: weight 42 | stats: [mean, std, l1_norm, min, max] 43 | freq: 5 44 | start_step: 100 45 | end_step: 500 46 | 47 | stat_collection_test_4: 48 | enabled: True 49 | layers: 50 | layer_numbers: [5] 51 | transformer_engine: 52 | LogTensorStats: 53 | enabled: True 54 | tensors: [activation] 55 | stats: [cur_amax, dynamic_range, mean, std, l1_norm] 56 | LogFp8TensorStats: 57 | enabled: True 58 | stats: [underflows%] 59 | tensors: [activation] -------------------------------------------------------------------------------- /tests/pytorch/debug/test_configs/tensor_manipulation_transformer_engine.yaml: -------------------------------------------------------------------------------- 1 | # This config is used when FP8 training is ON 2 | 3 | transformer_engine_fc1_manipulation: 4 | enabled: True 5 | layers: 6 | layer_name_regex_pattern: .*(fc1) # Select layers if they end in fc1 7 | transformer_engine: # namespace 8 | DisableFP8GEMM: # Disable FP8 GEMM. FProp run in high precision 9 | enabled: True 10 | gemms: [fprop] 11 | PerTensorScaling: # Scale DGrad gradients using per tensor current scaling and run FP8 GEMM 12 | enabled: True 13 | gemms: [dgrad] 14 | tensors: [gradient] 15 | FakeQuant: # Disable FP8 GEMM for Wgrad. Fake quantize activations to Wgrad and run high precision GEMM 16 | enabled: True 17 | gemms: [fprop] 18 | tensors_struct: 19 | - tensor: activation 20 | quant_format: FP8E4M3 21 | - tensor: weight 22 | quant_format: FP8E4M3 23 | 24 | transformer_engine_fc2_manipulation: 25 | enabled: True 26 | layers: 27 | layer_name_regex_pattern: .*(fc2) # Select layers if they end in fc2 28 | transformer_engine: # namespace 29 | PerTensorScaling: # Scale WGrad and Fprop inputs using per tensor current scaling and run FP8 GEMM 30 | enabled: True 31 | gemms_struct: 32 | - gemm: fprop 33 | tensors_struct: 34 | - tensor: activation 35 | - tensor: weight 36 | - gemm: wgrad 37 | tensors_struct: 38 | - tensor: activation 39 | - tensor: gradient 40 | FakeQuant: # Disable FP8 GEMM for DGrad. Fake quantize weights and gradients to DGrad and run high precision GEMM 41 | enabled: True 42 | gemms_struct: 43 | - gemm: dgrad 44 | tensors: [weight, gradient] 45 | quant_format: FP8E5M2 -------------------------------------------------------------------------------- /tests/pytorch/debug/test_distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import os 6 | import subprocess 7 | from pathlib import Path 8 | 9 | import pytest 10 | import torch 11 | 12 | """ 13 | Distributed numerics tests 14 | 15 | These tests test the numerical corectness of the TransformerEngine layers. 16 | Tests are parametrized by the layer and fp8 precision. 
17 | One test consists of running multiple configurations from file run_numerics.py 18 | Such design is due to the fact the initialization of one test is long 19 | - 2 processes need to start and load torch and TE. Multiple configurations 20 | are run in one test - this reduces the initialization overhead. 21 | 22 | """ 23 | 24 | 25 | if torch.cuda.device_count() < 2: 26 | pytest.skip("Distributed training needs at least 2 GPUs.") 27 | 28 | TEST_ROOT = Path(__file__).parent.resolve() 29 | NUM_PROCS: int = min(4, torch.cuda.device_count()) 30 | LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] 31 | 32 | 33 | def test_debug_distributed(feature_dirs): 34 | test_path = TEST_ROOT / "run_distributed.py" 35 | test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"] 36 | 37 | result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) 38 | if result.returncode != 0: 39 | raise AssertionError(result.stderr.decode()) 40 | -------------------------------------------------------------------------------- /tests/pytorch/debug/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import os 6 | 7 | LOG_FILE = os.path.join("nvdlfw_inspect_logs", "nvdlfw_inspect_globalrank-0.log") 8 | 9 | 10 | def reset_debug_log(): 11 | if os.path.isfile(LOG_FILE): 12 | # delete all content 13 | with open(LOG_FILE, "w") as f: 14 | pass 15 | 16 | 17 | def check_debug_log(msg): 18 | with open(LOG_FILE, "r") as f: 19 | for line in f.readlines(): 20 | if msg in line: 21 | return True 22 | return False 23 | -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_cast_master_weights_to_fp8.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | import os 6 | import subprocess 7 | from pathlib import Path 8 | 9 | import pytest 10 | import torch 11 | from transformer_engine.pytorch.fp8 import FP8GlobalStateManager 12 | 13 | 14 | if torch.cuda.device_count() < 2: 15 | pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") 16 | 17 | fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() 18 | fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( 19 | FP8GlobalStateManager.is_fp8_block_scaling_available() 20 | ) 21 | 22 | TEST_ROOT = Path(__file__).parent.resolve() 23 | NUM_PROCS: int = min(2, torch.cuda.device_count()) 24 | LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] 25 | 26 | 27 | def _run_test(quantization): 28 | test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py" 29 | test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization] 30 | result = subprocess.run(test_cmd, env=os.environ, check=False) 31 | assert result.returncode == 0 32 | 33 | 34 | @pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"]) 35 | def test_cast_master_weights_to_fp8(quantization): 36 | if quantization in ("fp8", "fp8_cs") and not fp8_available: 37 | pytest.skip(reason_for_no_fp8) 38 | if quantization == "fp8_block" and not fp8_block_scaling_available: 39 | pytest.skip(reason_for_no_fp8_block_scaling) 40 | _run_test(quantization) 41 | -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_numerics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import os 6 | import subprocess 7 | from pathlib import Path 8 | 9 | import pytest 10 | import torch 11 | from transformer_engine.pytorch.fp8 import FP8GlobalStateManager 12 | 13 | """ 14 | Distributed numerics tests 15 | 16 | These tests test the numerical corectness of the TransformerEngine layers. 17 | Tests are parametrized by the layer and fp8 precision. 18 | One test consists of running multiple configurations from file run_numerics.py 19 | Such design is due to the fact the initialization of one test is long 20 | - 2 processes need to start and load torch and TE. Multiple configurations 21 | are run in one test - this reduces the initialization overhead. 
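For example, a single pytest case below effectively launches something like "torchrun --nproc_per_node=2 run_numerics.py --quantization fp8"; the exact process count is min(4, number of visible GPUs).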
22 | 23 | """ 24 | 25 | 26 | if torch.cuda.device_count() < 2: 27 | pytest.skip("Distributed training needs at least 2 GPUs.") 28 | 29 | fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() 30 | mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() 31 | fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( 32 | FP8GlobalStateManager.is_fp8_block_scaling_available() 33 | ) 34 | 35 | TEST_ROOT = Path(__file__).parent.resolve() 36 | NUM_PROCS: int = min(4, torch.cuda.device_count()) 37 | LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] 38 | 39 | 40 | def _run_test(quantization): 41 | test_path = TEST_ROOT / "run_numerics.py" 42 | test_cmd = LAUNCH_CMD + [str(test_path)] 43 | 44 | if quantization is not None: 45 | test_cmd += ["--quantization", quantization] 46 | 47 | result = subprocess.run(test_cmd, env=os.environ, check=False) 48 | assert result.returncode == 0 49 | 50 | 51 | all_boolean = [True, False] 52 | 53 | 54 | @pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) 55 | def test_distributed(quantization): 56 | if quantization == "fp8" and not fp8_available: 57 | pytest.skip(reason_for_no_fp8) 58 | if quantization == "fp8_cs" and not fp8_available: 59 | pytest.skip(reason_for_no_fp8) 60 | if quantization == "mxfp8" and not mxfp8_available: 61 | pytest.skip(reason_for_no_mxfp8) 62 | if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: 63 | pytest.skip(reason_for_no_fp8_block_scaling) 64 | _run_test(quantization) 65 | -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_torch_fsdp2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information.
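# Note (added for clarity, inferred from the test body below): this file launches
# run_fsdp2_model.py under torchrun across all visible GPUs and exercises both plain
# FSDP (one sharding dimension) and HSDP (two sharding dimensions). It requires an
# even number of at least 4 GPUs and PyTorch 2.4 or newer.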
4 | 5 | import os 6 | import pytest 7 | import subprocess 8 | from pathlib import Path 9 | from transformer_engine.pytorch import torch_version 10 | from transformer_engine.pytorch.fp8 import FP8GlobalStateManager 11 | 12 | import torch 13 | 14 | 15 | fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() 16 | 17 | NUM_PROCS: int = torch.cuda.device_count() 18 | 19 | 20 | def _run_test(fp_init, sharding_dims): 21 | test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" 22 | test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] 23 | 24 | if fp_init: 25 | test_cmd += ["--fp8-init"] 26 | if len(sharding_dims) == 1: 27 | test_cmd += ["--sharding-dims", str(sharding_dims[0])] 28 | elif len(sharding_dims) == 2: 29 | test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] 30 | else: 31 | assert False 32 | result = subprocess.run(test_cmd, env=os.environ, check=True) 33 | 34 | 35 | @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") 36 | @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") 37 | @pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") 38 | @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) 39 | @pytest.mark.parametrize("fp8_init", (False, True)) 40 | def test_distributed(fp8_init, sharding_dims): 41 | 42 | # Skip invalid configurations 43 | if torch.cuda.device_count() < 4: 44 | pytest.skip("FSDP2 test requires at least 4 GPUs") 45 | 46 | if fp8_init and not fp8_available: 47 | pytest.skip(reason_for_no_fp8) 48 | 49 | _run_test(fp8_init, sharding_dims) 50 | 51 | 52 | def test_dummy() -> None: 53 | """Dummy test 54 | 55 | pytest returns exit code 5 if all tests are skipped. 56 | 57 | """ 58 | pass 59 | -------------------------------------------------------------------------------- /tests/pytorch/references/quantize_scale_calc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | from typing import Tuple 6 | import torch 7 | 8 | 9 | def scale_from_amax_tensor( 10 | x_dtype: torch.dtype, 11 | amax: torch.Tensor, 12 | quant_dtype: torch.dtype, 13 | *, 14 | eps: float, 15 | pow_2_scales: bool, 16 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 17 | """Derives quantization and dequantization from amax and options. 18 | 19 | Reference implementation for scale calculation. 20 | 21 | Returns: 22 | - scale: quantization scales 23 | - scale_inv: dequantization scales 24 | - amax: Amax tensor with updates made for extrema values. 25 | """ 26 | assert amax.dtype == torch.float, "amax must be a float tensor." 27 | fp8_max = torch.finfo(quant_dtype).max 28 | # Clamping amax to avoid division by small numbers 29 | amax = torch.max(amax, torch.tensor(eps)) 30 | 31 | # Compute scale factor 32 | scale = torch.div(fp8_max, amax) 33 | # Note frexp doesn't give back inf for exponent with an inf input 34 | # We take care of inf before pow_2_scales 35 | scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) 36 | if pow_2_scales: 37 | # Calculate rounded down exponent 38 | _, exp = torch.frexp(scale) 39 | # Positive numbers are always returned as mant, exp with 40 | # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with 41 | # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because 42 | # of the shift. 
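# (Illustrative example, not part of the original comment: for scale = 6.0,
# torch.frexp returns mantissa 0.75 and exponent 3, since 6.0 == 0.75 * 2**3.
# Subtracting one gives exp == 2, and torch.ldexp(1.0, 2) yields 4.0, the
# rounded-down power-of-two scale, i.e. 2**floor(log2(scale)).)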
Subnormal and zero cases need not be considered because 43 | # the smallest possible result of fp8_max / amax is still normal. 44 | exp = exp - 1 45 | # No subnormals and zero. 46 | assert (exp > -127).all() 47 | unity = torch.tensor([1.0], device=exp.device) 48 | torch.ldexp(unity, exp, out=scale) 49 | # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales 50 | # Return 0.0 for 0.0 scale for consistency with non-pow2 scale 51 | # calculation. 52 | scale = torch.where(amax == float("inf"), 0.0, scale) 53 | 54 | # Handle overflow cases for amax zero causing NaN 55 | scale = torch.where(amax == 0, 1.0, scale) 56 | 57 | # Compute scale_inv 58 | scale_inv = torch.reciprocal(scale) 59 | 60 | return scale, scale_inv, amax 61 | -------------------------------------------------------------------------------- /tests/pytorch/references/ref_per_tensor_cs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import torch 6 | import transformer_engine_torch as tex 7 | 8 | from transformer_engine.pytorch.constants import TE_DType_To_Torch 9 | from references.quantize_scale_calc import scale_from_amax_tensor 10 | 11 | 12 | # compute amax and scale 13 | def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): 14 | x_fp32 = x.to(torch.float32) 15 | amax = torch.amax(torch.abs(x_fp32)).view(1) 16 | return scale_from_amax_tensor( 17 | torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales 18 | ) 19 | 20 | 21 | def _multi_dim_transpose(tensor): 22 | # Get the number of dimensions 23 | dims = list(range(len(tensor.shape))) 24 | 25 | if len(dims) <= 1: 26 | return tensor 27 | 28 | # circular shift of shapes 29 | new_order = [] 30 | new_order.append(dims[-1]) 31 | for i in range(len(dims) - 1): 32 | new_order.append(dims[i]) 33 | 34 | # Permute the tensor according to the new order 35 | output_tensor = tensor.permute(new_order).contiguous() 36 | 37 | return output_tensor 38 | 39 | 40 | # current scaling reference quantization 41 | def ref_per_tensor_cs_cast( 42 | tensor: torch.Tensor, 43 | fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, 44 | return_transpose: bool = False, 45 | force_pow_2_scales: bool = False, 46 | amax_epsilon: float = 0.0, 47 | ) -> torch.Tensor: 48 | 49 | quant_dtype_torch = TE_DType_To_Torch[fp8_dtype] 50 | scale, scale_inv, _ = _ref_compute_amax_scale( 51 | tensor, 52 | quant_dtype_torch, 53 | amax_epsilon, 54 | force_pow_2_scales, 55 | ) 56 | 57 | qx = (tensor.float() * scale).to(quant_dtype_torch) 58 | sx = scale_inv 59 | qx_t = None 60 | sx_t = None 61 | 62 | if tensor.shape == torch.Size([]): 63 | qx = qx.view([]) 64 | 65 | if return_transpose: 66 | qx_t = _multi_dim_transpose(qx) 67 | sx_t = sx 68 | return qx, sx, qx_t, sx_t 69 | -------------------------------------------------------------------------------- /tests/pytorch/test_gqa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
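# Note (added for clarity): grouped-query attention uses num_gqa_groups key/value
# heads shared across num_attn_head query heads, so the assertions below expect the
# query weight to have shape [kv_channels * num_attn_head, hidden_size] and the
# key/value weights to have shape [kv_channels * num_gqa_groups, hidden_size].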
4 | 5 | import pytest 6 | import torch 7 | 8 | import transformer_engine.pytorch as te 9 | 10 | batch_size = 32 11 | seq_length = 2048 12 | num_heads = 16 13 | head_dim = 64 14 | dtype = torch.bfloat16 15 | num_attn_head = 16 16 | ffn_hidden_size = 1024 17 | 18 | 19 | @pytest.mark.parametrize("kv_channels", [128, 256]) 20 | @pytest.mark.parametrize("hidden_size", [128, 256]) 21 | @pytest.mark.parametrize("num_gqa_groups", [1, 2, 4, 8, 16]) 22 | def test_gqa(kv_channels, hidden_size, num_gqa_groups) -> None: 23 | 24 | model = te.TransformerLayer( 25 | hidden_size, ffn_hidden_size, num_attn_head, num_gqa_groups, kv_channels=kv_channels 26 | ) 27 | 28 | # Run forward pass 29 | x = torch.randn((batch_size, 1, hidden_size)).cuda() 30 | model(x) 31 | 32 | # Check shapes of weights. 33 | assert model.self_attention.layernorm_qkv.key_weight.shape[0] == kv_channels * num_gqa_groups 34 | assert model.self_attention.layernorm_qkv.key_weight.shape[1] == hidden_size 35 | 36 | assert model.self_attention.layernorm_qkv.query_weight.shape[0] == kv_channels * num_attn_head 37 | assert model.self_attention.layernorm_qkv.query_weight.shape[1] == hidden_size 38 | 39 | assert model.self_attention.layernorm_qkv.value_weight.shape[0] == kv_channels * num_gqa_groups 40 | assert model.self_attention.layernorm_qkv.value_weight.shape[1] == hidden_size 41 | 42 | assert model.self_attention.proj.weight.shape[0] == hidden_size 43 | assert model.self_attention.proj.weight.shape[1] == kv_channels * num_attn_head 44 | -------------------------------------------------------------------------------- /tests/pytorch/test_hf_integration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import pytest 6 | from transformers.configuration_utils import PretrainedConfig 7 | from transformers.modeling_utils import PreTrainedModel 8 | 9 | from transformer_engine.pytorch.transformer import TransformerLayer 10 | from transformer_engine.pytorch.utils import is_bf16_compatible 11 | 12 | 13 | class SimpleTEModel(PreTrainedModel): 14 | config_class = PretrainedConfig 15 | 16 | def __init__(self, config: PretrainedConfig): 17 | super().__init__(config) 18 | self.my_layer = TransformerLayer( 19 | hidden_size=320, 20 | num_attention_heads=16, 21 | ffn_hidden_size=1024, 22 | layer_number=None, 23 | ) 24 | 25 | def forward(self, hidden_states, attention_mask): 26 | return self.my_layer(hidden_states, attention_mask) 27 | 28 | 29 | def test_save_hf_model(tmp_path): 30 | model = SimpleTEModel(PretrainedConfig()) 31 | model.save_pretrained(tmp_path / "simple_te_model") 32 | 33 | 34 | @pytest.mark.xfail(reason="This test is failing until huggingface/transformers#38155 is merged.") 35 | def test_save_and_load_hf_model(tmp_path): 36 | model = SimpleTEModel(PretrainedConfig()) 37 | model.save_pretrained(tmp_path / "simple_te_model") 38 | del model 39 | model = SimpleTEModel.from_pretrained(tmp_path / "simple_te_model") 40 | assert model is not None 41 | -------------------------------------------------------------------------------- /tests/pytorch/test_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | from typing import Tuple 6 | 7 | import pytest 8 | import torch 9 | 10 | import transformer_engine.pytorch as te 11 | 12 | # Model names for test_torch_dynamo 13 | _model_factory = { 14 | "Linear": [(lambda: te.Linear(16, 16)), [16, 16]], 15 | "LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]], 16 | "LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]], 17 | "LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]], 18 | "TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]], 19 | } 20 | 21 | 22 | @pytest.mark.skipif(torch.__version__ < "2", reason="torch.compile not available") 23 | @pytest.mark.parametrize("model_name", list(_model_factory.keys())) 24 | def test_torch_dynamo(model_name: str): 25 | """Test compatibility with Torch Dynamo 26 | 27 | Construct model, optimize with Torch Dynamo, and perform a single 28 | forward and backward pass. 29 | 30 | """ 31 | 32 | # Helper function to construct tensor with default options 33 | def make_tensor( 34 | dims: Tuple[int], 35 | dtype: torch.dtype = torch.float32, 36 | device: torch.device = "cuda", 37 | requires_grad: bool = True, 38 | **kwargs, 39 | ): 40 | return torch.zeros( 41 | dims, 42 | dtype=dtype, 43 | device=device, 44 | requires_grad=requires_grad, 45 | **kwargs, 46 | ) 47 | 48 | # Construct model and input tensors 49 | model_builder, input_builder = _model_factory[model_name] 50 | model = model_builder() 51 | inputs = [make_tensor(input_builder)] 52 | 53 | # Optimize model with TorchDynamo 54 | torch.compile(model) 55 | 56 | # Forward and backward pass 57 | out = model(*inputs) 58 | out.backward(torch.zeros_like(out)) 59 | 60 | 61 | def test_lazy_compile(): 62 | """Smoke test to ensure lazy compilation is working.""" 63 | from transformer_engine.pytorch.jit import dgelu_fused_ 64 | 65 | dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10)) 66 | -------------------------------------------------------------------------------- /tests/pytorch/test_sanity_import.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | import transformer_engine.pytorch 6 | 7 | print("OK") 8 | -------------------------------------------------------------------------------- /transformer_engine/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Top level package""" 6 | 7 | # pylint: disable=unused-import 8 | 9 | import os 10 | from importlib import metadata 11 | import transformer_engine.common 12 | 13 | try: 14 | from . import pytorch 15 | except ImportError: 16 | pass 17 | except FileNotFoundError as e: 18 | if "Could not find shared object file" in str(e): 19 | if os.getenv("NVTE_FRAMEWORK") is None: 20 | # If we got here, we could import `torch` but could not load the framework extension. 21 | # This can happen when a user wants to work only with `transformer_engine.jax` on a system that 22 | # also has a PyTorch installation. In order to enable that use case, we issue a warning here 23 | # about the missing PyTorch extension in case the user hasn't set NVTE_FRAMEWORK. 
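# (Added note: if NVTE_FRAMEWORK is set to a value that does not request PyTorch,
# e.g. NVTE_FRAMEWORK=jax, neither branch below fires and the missing PyTorch
# extension is ignored silently.)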
24 | import warnings 25 | 26 | warnings.warn( 27 | "Detected a PyTorch installation but could not find the shared object file for the " 28 | "Transformer Engine PyTorch extension library. If this is not intentional, please " 29 | "reinstall Transformer Engine with `pip install transformer_engine[pytorch]` or " 30 | "build from source with `NVTE_FRAMEWORK=pytorch`.", 31 | category=RuntimeWarning, 32 | ) 33 | elif os.getenv("NVTE_FRAMEWORK") in ("pytorch", "all"): 34 | raise e 35 | 36 | try: 37 | from . import jax 38 | except ImportError: 39 | pass 40 | except FileNotFoundError as e: 41 | if "Could not find shared object file" in str(e): 42 | if os.getenv("NVTE_FRAMEWORK") is None: 43 | # If we got here, we could import `jax` but could not load the framework extension. 44 | # This can happen when a user wants to work only with `transformer_engine.pytorch` on a system 45 | # that also has a Jax installation. In order to enable that use case, we issue a warning here 46 | # about the missing Jax extension in case the user hasn't set NVTE_FRAMEWORK. 47 | import warnings 48 | 49 | warnings.warn( 50 | "Detected a Jax installation but could not find the shared object file for the " 51 | "Transformer Engine Jax extension library. If this is not intentional, please " 52 | "reinstall Transformer Engine with `pip install transformer_engine[jax]` or " 53 | "build from source with `NVTE_FRAMEWORK=jax`.", 54 | category=RuntimeWarning, 55 | ) 56 | elif os.getenv("NVTE_FRAMEWORK") in ("jax", "all"): 57 | raise e 58 | 59 | __version__ = str(metadata.version("transformer_engine")) 60 | -------------------------------------------------------------------------------- /transformer_engine/common/activation/gelu.cu: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include "../util/math.h" 8 | #include "./activation_template.h" 9 | 10 | void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 11 | NVTE_API_CALL(nvte_gelu); 12 | using namespace transformer_engine; 13 | act_fn>(input, output, stream); 14 | } 15 | 16 | void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 17 | cudaStream_t stream) { 18 | NVTE_API_CALL(nvte_dgelu); 19 | using namespace transformer_engine; 20 | dact_fn>(grad, input, output, stream); 21 | } 22 | 23 | void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 24 | NVTE_API_CALL(nvte_geglu); 25 | using namespace transformer_engine; 26 | gated_act_fn>(input, output, stream); 27 | } 28 | 29 | void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 30 | cudaStream_t stream) { 31 | NVTE_API_CALL(nvte_dgeglu); 32 | using namespace transformer_engine; 33 | dgated_act_fn, dgelu>(grad, input, output, stream); 34 | } 35 | 36 | void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 37 | NVTE_API_CALL(nvte_qgelu); 38 | using namespace transformer_engine; 39 | act_fn>(input, output, stream); 40 | } 41 | 42 | void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 43 | cudaStream_t stream) { 44 | NVTE_API_CALL(nvte_dqgelu); 45 | using namespace transformer_engine; 46 | dact_fn>(grad, input, output, stream); 47 | } 48 | 49 | void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 50 | NVTE_API_CALL(nvte_qgeglu); 51 | using namespace transformer_engine; 52 | gated_act_fn>(input, output, stream); 53 | } 54 | 55 | void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 56 | cudaStream_t stream) { 57 | NVTE_API_CALL(nvte_dqgeglu); 58 | using namespace transformer_engine; 59 | dgated_act_fn, dqgelu>(grad, input, output, stream); 60 | } 61 | -------------------------------------------------------------------------------- /transformer_engine/common/activation/relu.cu: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include "../util/math.h" 8 | #include "./activation_template.h" 9 | 10 | void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 11 | NVTE_API_CALL(nvte_relu); 12 | using namespace transformer_engine; 13 | act_fn>(input, output, stream); 14 | } 15 | 16 | void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 17 | cudaStream_t stream) { 18 | NVTE_API_CALL(nvte_drelu); 19 | using namespace transformer_engine; 20 | dact_fn>(grad, input, output, stream); 21 | } 22 | 23 | void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 24 | NVTE_API_CALL(nvte_reglu); 25 | using namespace transformer_engine; 26 | gated_act_fn>(input, output, stream); 27 | } 28 | 29 | void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 30 | cudaStream_t stream) { 31 | NVTE_API_CALL(nvte_dreglu); 32 | using namespace transformer_engine; 33 | dgated_act_fn, drelu>(grad, input, output, stream); 34 | } 35 | 36 | void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 37 | NVTE_API_CALL(nvte_srelu); 38 | using namespace transformer_engine; 39 | act_fn>(input, output, stream); 40 | } 41 | 42 | void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 43 | cudaStream_t stream) { 44 | NVTE_API_CALL(nvte_dsrelu); 45 | using namespace transformer_engine; 46 | dact_fn>(grad, input, output, stream); 47 | } 48 | 49 | void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 50 | NVTE_API_CALL(nvte_sreglu); 51 | using namespace transformer_engine; 52 | gated_act_fn>(input, output, stream); 53 | } 54 | 55 | void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 56 | cudaStream_t stream) { 57 | NVTE_API_CALL(nvte_dsreglu); 58 | using namespace transformer_engine; 59 | dgated_act_fn, dsrelu>(grad, input, output, stream); 60 | } 61 | -------------------------------------------------------------------------------- /transformer_engine/common/activation/swiglu.cu: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include "../util/math.h" 8 | #include "./activation_template.h" 9 | 10 | void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 11 | NVTE_API_CALL(nvte_silu); 12 | using namespace transformer_engine; 13 | act_fn>(input, output, stream); 14 | } 15 | 16 | void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 17 | cudaStream_t stream) { 18 | NVTE_API_CALL(nvte_dsilu); 19 | using namespace transformer_engine; 20 | dact_fn>(grad, input, output, stream); 21 | } 22 | 23 | void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { 24 | NVTE_API_CALL(nvte_swiglu); 25 | using namespace transformer_engine; 26 | gated_act_fn>(input, output, stream); 27 | } 28 | 29 | void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, 30 | cudaStream_t stream) { 31 | NVTE_API_CALL(nvte_dswiglu); 32 | using namespace transformer_engine; 33 | dgated_act_fn, dsilu>(grad, input, output, stream); 34 | } 35 | -------------------------------------------------------------------------------- /transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H 8 | #define TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | typedef enum { 23 | ipcSocketSuccess = 0, 24 | ipcSocketUnhandledCudaError = 1, 25 | ipcSocketSystemError = 2, 26 | ipcSocketInternalError = 3, 27 | ipcSocketInvalidArgument = 4, 28 | ipcSocketInvalidUsage = 5, 29 | ipcSocketRemoteError = 6, 30 | ipcSocketInProgress = 7, 31 | ipcSocketNumResults = 8 32 | } ipcSocketResult_t; 33 | 34 | const char *ipcSocketGetErrorString(ipcSocketResult_t res); 35 | 36 | #define IPC_SOCKNAME_LEN 64 37 | 38 | struct IpcSocketHandle { 39 | int fd; 40 | char socketName[IPC_SOCKNAME_LEN]; 41 | volatile uint32_t *abortFlag; 42 | }; 43 | 44 | ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash, 45 | volatile uint32_t *abortFlag); 46 | ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle); 47 | ipcSocketResult_t ipcSocketGetFd(IpcSocketHandle *handle, int *fd); 48 | 49 | ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *fd); 50 | ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int fd, int rank, uint64_t hash); 51 | 52 | #endif /* TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H */ 53 | -------------------------------------------------------------------------------- /transformer_engine/common/cudnn_utils.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include "cudnn_utils.h" 8 | 9 | #include "./util/logging.h" 10 | #include "transformer_engine/cudnn.h" 11 | 12 | namespace transformer_engine { 13 | 14 | // get cuDNN data type 15 | cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { 16 | using namespace transformer_engine; 17 | switch (t) { 18 | case DType::kInt32: 19 | return CUDNN_DATA_INT32; 20 | case DType::kInt64: 21 | return CUDNN_DATA_INT64; 22 | case DType::kFloat16: 23 | return CUDNN_DATA_HALF; 24 | case DType::kFloat32: 25 | return CUDNN_DATA_FLOAT; 26 | case DType::kBFloat16: 27 | return CUDNN_DATA_BFLOAT16; 28 | case DType::kFloat8E4M3: 29 | return CUDNN_DATA_FP8_E4M3; 30 | case DType::kFloat8E5M2: 31 | return CUDNN_DATA_FP8_E5M2; 32 | default: 33 | NVTE_ERROR("Invalid cuDNN data type. \n"); 34 | } 35 | } 36 | 37 | // get cuDNN data type 38 | cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) { 39 | using namespace transformer_engine; 40 | switch (t) { 41 | case DType::kInt32: 42 | return cudnn_frontend::DataType_t::INT32; 43 | case DType::kInt64: 44 | return cudnn_frontend::DataType_t::INT64; 45 | case DType::kFloat16: 46 | return cudnn_frontend::DataType_t::HALF; 47 | case DType::kFloat32: 48 | return cudnn_frontend::DataType_t::FLOAT; 49 | case DType::kBFloat16: 50 | return cudnn_frontend::DataType_t::BFLOAT16; 51 | case DType::kFloat8E4M3: 52 | return cudnn_frontend::DataType_t::FP8_E4M3; 53 | case DType::kFloat8E5M2: 54 | return cudnn_frontend::DataType_t::FP8_E5M2; 55 | default: 56 | NVTE_ERROR("Invalid cuDNN data type. \n"); 57 | } 58 | } 59 | 60 | void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } 61 | 62 | namespace detail { 63 | 64 | void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); } 65 | 66 | } // namespace detail 67 | 68 | } // namespace transformer_engine 69 | 70 | namespace cudnn_frontend { 71 | 72 | // This is needed to define the symbol `cudnn_dlhandle` 73 | // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING 74 | // to enable dynamic loading. 75 | void* cudnn_dlhandle = nullptr; 76 | 77 | } // namespace cudnn_frontend 78 | -------------------------------------------------------------------------------- /transformer_engine/common/cudnn_utils.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_ 8 | #define TRANSFORMER_ENGINE_CUDNN_UTILS_H_ 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "transformer_engine/transformer_engine.h" 16 | #include "util/handle_manager.h" 17 | 18 | namespace transformer_engine { 19 | 20 | namespace detail { 21 | 22 | void CreateCuDNNHandle(cudnnHandle_t* handle); 23 | 24 | } // namespace detail 25 | 26 | cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); 27 | 28 | cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); 29 | 30 | using cudnnExecutionPlanManager = detail::HandleManager; 31 | 32 | } // namespace transformer_engine 33 | 34 | #endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_ 35 | -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/cast_transpose_noop.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | /*! \file transpose_with_noop.h 8 | * \brief Functions handling transposes with no-op. 9 | */ 10 | 11 | #ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ 12 | #define TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ 13 | 14 | #include "transformer_engine.h" 15 | 16 | #ifdef __cplusplus 17 | extern "C" { 18 | #endif 19 | 20 | /*! \brief Transposes the input. 21 | * 22 | * \param[in] input Input tensor to be cast. 23 | * \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately. 24 | * \param[in,out] output Output tensor. 25 | * \param[in] stream CUDA stream used for the operation. 26 | */ 27 | void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, 28 | cudaStream_t stream); 29 | 30 | /*! \brief Casts and transposes the input. 31 | * 32 | * \param[in] input Input tensor to be cast. 33 | * \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately. 34 | * \param[in,out] output Output quantized tensor. 35 | * \param[in] stream CUDA stream used for the operation. 36 | */ 37 | void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, 38 | cudaStream_t stream); 39 | 40 | #ifdef __cplusplus 41 | } // extern "C" 42 | #endif 43 | 44 | #endif // TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ 45 | -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/cudnn.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | /*! \file cudnn.h 8 | * \brief Helper for cuDNN initialization 9 | */ 10 | 11 | #ifndef TRANSFORMER_ENGINE_CUDNN_H_ 12 | #define TRANSFORMER_ENGINE_CUDNN_H_ 13 | 14 | #include "transformer_engine.h" 15 | 16 | /*! 
\namespace transformer_engine 17 | */ 18 | namespace transformer_engine { 19 | 20 | /*! \brief TE/JAX cudaGraph requires the cuDNN initialization to happen outside of the capturing 21 | * region. This function is a helper to call cudnnCreate() which allocate memory for the handle. 22 | * The function will be called in the initialize() phase of the related XLA custom calls. 23 | */ 24 | 25 | void nvte_cudnn_handle_init(); 26 | 27 | } // namespace transformer_engine 28 | 29 | #endif // TRANSFORMER_ENGINE_CUDNN_H_ 30 | -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/padding.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | /*! \file padding.h 8 | * \brief Functions handling padding. 9 | */ 10 | 11 | #ifndef TRANSFORMER_ENGINE_PADDING_H_ 12 | #define TRANSFORMER_ENGINE_PADDING_H_ 13 | 14 | #include "transformer_engine.h" 15 | 16 | #ifdef __cplusplus 17 | extern "C" { 18 | #endif 19 | 20 | /*! \brief Padding multiple tensors. 21 | * 22 | * NOTE: Padding mode only support bottom. 23 | * 24 | * For example, 3x3 matrix pad to 4x3 matrix. 25 | * 26 | * source 27 | * | 1 | 2 | 3 | 28 | * | 4 | 5 | 6 | 29 | * | 7 | 8 | 9 | 30 | * 31 | * destination 32 | * | 1 | 2 | 3 | 33 | * | 4 | 5 | 6 | 34 | * | 7 | 8 | 9 | 35 | * | 0 | 0 | 0 | 36 | * 37 | * \param[in] num_tensors Number of tensors. 38 | * \param[in] input_list List of 2D input tensors. 39 | * \param[in,out] output_list List of padded tensors. Dimensions 40 | * match tensors in input_list. 41 | * \param[in] padded_num_rows_list List of padded num rows corresponding to input tensors. 42 | * \param[in] stream CUDA stream used for the operation. 43 | */ 44 | void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, 45 | const int* padded_num_rows_list, cudaStream_t stream); 46 | 47 | #ifdef __cplusplus 48 | } // extern "C" 49 | #endif 50 | 51 | #endif // TRANSFORMER_ENGINE_PADDING_H_ 52 | -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/permutation.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_PERMUTATION_H_ 8 | #define TRANSFORMER_ENGINE_PERMUTATION_H_ 9 | 10 | #include "transformer_engine.h" 11 | 12 | void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor sorted_row_id, 13 | NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, 14 | const NVTETensor input_fwd, const int num_rows, const int topK, 15 | const int num_cols, const int num_out_tokens, cudaStream_t stream = nullptr); 16 | 17 | void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, 18 | const NVTETensor prob, const int num_rows, const int topK, const int num_cols, 19 | cudaStream_t stream = nullptr); 20 | 21 | void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in, 22 | int *keys_out, int *values_in, int *values_out, size_t num_items); 23 | 24 | #endif // TRANSFORMER_ENGINE_PERMUTATION_H_ 25 | -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/swizzle.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | /*! \file cast.h 8 | * \brief Functions to cast to/from FP8. 9 | */ 10 | 11 | #ifndef TRANSFORMER_ENGINE_SWIZZLE_H_ 12 | #define TRANSFORMER_ENGINE_SWIZZLE_H_ 13 | 14 | #include "transformer_engine.h" 15 | 16 | #ifdef __cplusplus 17 | extern "C" { 18 | #endif 19 | 20 | /*! \brief Swizzling scaling factors into the required interleaved layout for GEMM 21 | * 22 | * \param[in] input Input tensor with non-swizzled scale_inv. 23 | * \param[in,out] output Output tensor which hosts swizzled scale_inv. 24 | * \param[in] stream CUDA stream used for the operation. 25 | * 26 | * Requirements: 27 | * - scale_inv is stored in row-major. 28 | * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. 29 | * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. 
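 *
 * Illustrative example (not from the original header): for an MXFP8 tensor of
 * logical shape [M, K] quantized with 1x32 blocks along K, the row-wise scale_inv
 * holds roughly M x (K/32) entries before the 128x4 padding described above.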
30 | */ 31 | void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); 32 | 33 | #ifdef __cplusplus 34 | } // extern "C" 35 | #endif 36 | 37 | #endif // TRANSFORMER_ENGINE_SWIZZLE_H_ 38 | -------------------------------------------------------------------------------- /transformer_engine/common/libtransformer_engine.version: -------------------------------------------------------------------------------- 1 | { 2 | global: 3 | extern "C++" { 4 | nvte_*; 5 | transformer_engine::cuda::sm_count*; 6 | transformer_engine::cuda::sm_arch*; 7 | transformer_engine::cuda::supports_multicast*; 8 | transformer_engine::cuda::stream_priority_range*; 9 | transformer_engine::cuda::current_device*; 10 | transformer_engine::cuda_driver::get_symbol*; 11 | transformer_engine::ubuf_built_with_mpi*; 12 | *transformer_engine::rtc*; 13 | transformer_engine::nvte_cudnn_handle_init*; 14 | transformer_engine::nvte_cublas_handle_init*; 15 | transformer_engine::typeToSize*; 16 | transformer_engine::is_fp8_dtype*; 17 | *transformer_engine::CommOverlapBase*; 18 | *transformer_engine::CommOverlapP2PBase*; 19 | *transformer_engine::CommOverlapCore*; 20 | *nvshmem_wait_on_stream*; 21 | *nvshmemi_init_thread* 22 | }; 23 | local: *; 24 | }; 25 | -------------------------------------------------------------------------------- /transformer_engine/common/nvshmem_api/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | ########################################################################## 2 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # See LICENSE for license information. 5 | ########################################################################## 6 | cmake_minimum_required (VERSION 3.18) 7 | project(nvshmemapi LANGUAGES CXX CUDA) 8 | 9 | # Configure dependencies 10 | find_package(CUDAToolkit REQUIRED) 11 | # find_package(MPI REQUIRED) 12 | set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation") 13 | 14 | add_library(nvshmemapi STATIC nvshmem_waitkernel.cu) 15 | set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE) 16 | target_link_directories(nvshmemapi PUBLIC ${NVSHMEM_HOME}/lib) 17 | target_link_libraries(nvshmemapi PUBLIC -static-libstdc++ nvshmem_device nvshmem_host CUDA::nvml CUDA::cublas CUDA::cuda_driver) 18 | target_include_directories(nvshmemapi PRIVATE 19 | ${NVSHMEM_HOME}/include/) 20 | target_include_directories(nvshmemapi PUBLIC 21 | ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} 22 | "${CMAKE_CURRENT_SOURCE_DIR}") 23 | 24 | set_target_properties(nvshmemapi PROPERTIES 25 | CUDA_STANDARD 17 26 | POSITION_INDEPENDENT_CODE ON 27 | CUDA_SEPARABLE_COMPILATION ON) -------------------------------------------------------------------------------- /transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "../util/logging.h" 19 | #include "nvshmem_waitkernel.h" 20 | 21 | __global__ void __launch_bounds__(1) 22 | wait_until_on_stream_and_reset(uint64_t* wait_flag, uint64_t wait_value, 23 | uint64_t signal_reset) { 24 | nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value); 25 | *wait_flag = signal_reset; 26 | } 27 | void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream) { 28 | uint64_t wait_value = 1; 29 | uint64_t signal_reset = 0; 30 | cudaStream_t cur_stream = stream; 31 | 32 | NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && wait_kind <= WaitKind::STREAM_WAIT, 33 | "Invalid wait kind: ", static_cast(wait_kind)); 34 | 35 | switch (wait_kind) { 36 | case WaitKind::KERNEL_WAIT: 37 | wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset); 38 | break; 39 | case WaitKind::NVSHMEM_WAIT: 40 | nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream); 41 | cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, 42 | CU_STREAM_WRITE_VALUE_DEFAULT); 43 | break; 44 | case WaitKind::STREAM_WAIT: 45 | cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value, 46 | CU_STREAM_WAIT_VALUE_GEQ); 47 | cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, 48 | CU_STREAM_WRITE_VALUE_DEFAULT); 49 | break; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H 8 | #define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H 9 | 10 | #ifdef __cplusplus 11 | #include 12 | extern "C" { 13 | #else 14 | #include 15 | #endif 16 | 17 | /*! \enum WaitKind 18 | * \brief Types of wait operations that can be performed. 19 | */ 20 | enum class WaitKind { 21 | KERNEL_WAIT = 0, /*!< Wait using a CUDA kernel */ 22 | NVSHMEM_WAIT = 1, /*!< Wait using NVSHMEM wait operation */ 23 | STREAM_WAIT = 2 /*!< Wait using CUDA stream synchronization */ 24 | }; 25 | 26 | /*! \brief Wait on a signal until a certain condition is met. 27 | * 28 | * \param[in] sig_addr The address of the signal to wait on. 29 | * \param[in] wait_kind The kind of wait to perform. 30 | * \param[in] stream The stream to wait on. 31 | */ 32 | void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream); 33 | 34 | #ifdef __cplusplus 35 | } // extern "C" 36 | #endif 37 | 38 | #endif // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H 39 | -------------------------------------------------------------------------------- /transformer_engine/common/nvtx.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_COMMON_NVTX_H_ 8 | #define TRANSFORMER_ENGINE_COMMON_NVTX_H_ 9 | 10 | #include 11 | 12 | #include 13 | 14 | namespace transformer_engine::nvtx { 15 | 16 | struct NVTXWrapper { 17 | explicit NVTXWrapper(const std::string &name) { nvtxRangePush(name.c_str()); } 18 | 19 | ~NVTXWrapper() { nvtxRangePop(); } 20 | }; 21 | 22 | } // namespace transformer_engine::nvtx 23 | 24 | #endif // TRANSFORMER_ENGINE_COMMON_NVTX_H_ 25 | -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_driver.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #include 8 | 9 | #include "../common.h" 10 | #include "../util/cuda_runtime.h" 11 | 12 | namespace transformer_engine { 13 | 14 | namespace cuda_driver { 15 | 16 | void *get_symbol(const char *symbol) { 17 | void *entry_point; 18 | cudaDriverEntryPointQueryResult driver_result; 19 | NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result)); 20 | NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess, 21 | "Could not find CUDA driver entry point for ", symbol); 22 | return entry_point; 23 | } 24 | 25 | } // namespace cuda_driver 26 | 27 | } // namespace transformer_engine 28 | -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_nvml.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #include "cuda_nvml.h" 8 | 9 | #include "shared_lib_wrapper.h" 10 | 11 | namespace transformer_engine { 12 | 13 | namespace cuda_nvml { 14 | 15 | /*! \brief Lazily-initialized shared library for CUDA NVML */ 16 | Library &cuda_nvml_lib() { 17 | constexpr char lib_name[] = "libnvidia-ml.so.1"; 18 | static Library lib(lib_name); 19 | return lib; 20 | } 21 | 22 | void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); } 23 | 24 | } // namespace cuda_nvml 25 | 26 | } // namespace transformer_engine 27 | -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_runtime.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ 8 | #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ 9 | 10 | #include 11 | 12 | #include 13 | 14 | namespace transformer_engine { 15 | 16 | namespace cuda { 17 | 18 | /* \brief Number of accessible devices */ 19 | int num_devices(); 20 | 21 | /* \brief Which device is currently being used */ 22 | int current_device(); 23 | 24 | /* \brief Compute capability of device 25 | * 26 | * \param[in] device_id CUDA device (default is current device) 27 | * 28 | * \return Compute capability as int. Last digit is minor revision, 29 | * remaining digits are major revision. 30 | */ 31 | int sm_arch(int device_id = -1); 32 | 33 | /* \brief Number of multiprocessors on a device 34 | * 35 | * \param[in] device_id CUDA device (default is current device) 36 | * 37 | * \return Number of multiprocessors 38 | */ 39 | int sm_count(int device_id = -1); 40 | 41 | /* \brief Minimum and maximum stream priorities supported on device 42 | * 43 | * \param[in] device_id CUDA device (default is current device) 44 | * 45 | * \param[out] low_priority Lowest priority value on device. 46 | * 47 | * \param[out] high_priority Highest priority value on device. 48 | */ 49 | void stream_priority_range(int *low_priority, int *high_priority, int device_id = -1); 50 | 51 | /* \brief CUDA Multicast support status for device 52 | * 53 | * \param[in] device_id CUDA device (default is current device) 54 | * 55 | * \return CUDA multicast support flag 56 | */ 57 | bool supports_multicast(int device_id = -1); 58 | 59 | /* \brief Path to CUDA Toolkit headers 60 | * 61 | * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the 62 | * environment. Otherwise searches in common install paths. 63 | * 64 | * \param[in] required Whether to throw exception if not found 65 | * 66 | * \return Path to include directory, or an empty string if not found 67 | */ 68 | const std::string &include_directory(bool required = false); 69 | 70 | } // namespace cuda 71 | 72 | } // namespace transformer_engine 73 | 74 | #endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ 75 | -------------------------------------------------------------------------------- /transformer_engine/common/util/handle_manager.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ 8 | #define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ 9 | 10 | #include 11 | 12 | #include "cuda_runtime.h" 13 | #include "logging.h" 14 | 15 | namespace transformer_engine::detail { 16 | 17 | template 18 | class HandleManager { 19 | public: 20 | static HandleManager& Instance() { 21 | static thread_local HandleManager instance; 22 | return instance; 23 | } 24 | 25 | Handle GetHandle() { 26 | static thread_local std::vector initialized(handles_.size(), false); 27 | const int device_id = cuda::current_device(); 28 | NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); 29 | if (!initialized[device_id]) { 30 | Create(&(handles_[device_id])); 31 | initialized[device_id] = true; 32 | } 33 | return handles_[device_id]; 34 | } 35 | 36 | ~HandleManager() { 37 | if (Destroy != nullptr) { 38 | for (auto& handle : handles_) { 39 | Destroy(handle); 40 | } 41 | } 42 | } 43 | 44 | private: 45 | HandleManager() : handles_(cuda::num_devices(), nullptr) {} 46 | 47 | std::vector handles_ = nullptr; 48 | }; 49 | 50 | } // namespace transformer_engine::detail 51 | 52 | #endif // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ 53 | -------------------------------------------------------------------------------- /transformer_engine/common/util/shared_lib_wrapper.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ 8 | #define TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ 9 | 10 | #include 11 | 12 | namespace transformer_engine { 13 | 14 | /*! \brief Wrapper class for a shared library 15 | * 16 | * \todo Windows support 17 | */ 18 | class Library { 19 | public: 20 | explicit Library(const char *filename) { 21 | #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 22 | // TODO Windows support 23 | NVTE_ERROR("Shared library initialization is not supported with Windows"); 24 | #else 25 | handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); 26 | NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); 27 | #endif // _WIN32 or _WIN64 or __WINDOW__ 28 | } 29 | 30 | ~Library() { 31 | #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 32 | // TODO Windows support 33 | #else 34 | if (handle_ != nullptr) { 35 | dlclose(handle_); 36 | } 37 | #endif // _WIN32 or _WIN64 or __WINDOW__ 38 | } 39 | 40 | Library(const Library &) = delete; // move-only 41 | 42 | void *get() noexcept { return handle_; } 43 | 44 | const void *get() const noexcept { return handle_; } 45 | 46 | /*! 
47 |   void *get_symbol(const char *symbol) {
48 | #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
49 |     // TODO Windows support
50 |     NVTE_ERROR("Shared library initialization is not supported with Windows");
51 | #else
52 |     void *ptr = dlsym(handle_, symbol);
53 |     NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library");
54 |     return ptr;
55 | #endif  // _WIN32 or _WIN64 or __WINDOWS__
56 |   }
57 |
58 |  private:
59 |   void *handle_ = nullptr;
60 | };
61 |
62 | }  // namespace transformer_engine
63 |
64 | #endif  // TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_
65 |
--------------------------------------------------------------------------------
/transformer_engine/common/util/string.h:
--------------------------------------------------------------------------------
1 | /*************************************************************************
2 |  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  *
4 |  * See LICENSE for license information.
5 |  ************************************************************************/
6 |
7 | #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_STRING_H_
8 | #define TRANSFORMER_ENGINE_COMMON_UTIL_STRING_H_
9 |
10 | #include <regex>  // NOLINT(*)
11 | #include <string>
12 | #include <type_traits>
13 |
14 | namespace transformer_engine {
15 |
16 | inline const std::string &to_string_like(const std::string &val) noexcept { return val; }
17 |
18 | constexpr const char *to_string_like(const char *val) noexcept { return val; }
19 |
20 | /*! \brief Convert arithmetic type to string */
21 | template <typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
22 | inline std::string to_string_like(const T &val) {
23 |   return std::to_string(val);
24 | }
25 |
26 | /*! \brief Convert container to string */
27 | template <typename T, typename = typename std::enable_if<!std::is_arithmetic<T>::value>::type,
28 |           typename = decltype(std::declval<T>().begin())>
29 | inline std::string to_string_like(const T &container) {
30 |   std::string str;
31 |   str.reserve(1024);  // Assume strings are <1 KB
32 |   str += "(";
33 |   bool first = true;
34 |   for (const auto &val : container) {
35 |     if (!first) {
36 |       str += ",";
37 |     }
38 |     str += to_string_like(val);
39 |     first = false;
40 |   }
41 |   str += ")";
42 |   return str;
43 | }
44 |
45 | /*! \brief Convert arguments to strings and concatenate */
46 | template <typename... Ts>
47 | inline std::string concat_strings(const Ts &...args) {
48 |   std::string str;
49 |   str.reserve(1024);  // Assume strings are <1 KB
50 |   (..., (str += to_string_like(args)));
51 |   return str;
52 | }
53 |
54 | /*! \brief Substitute regex occurrences in string
55 |  *
56 |  * This is a convenience wrapper around std::regex_replace.
57 |  */
58 | template <typename T>
59 | inline std::string regex_replace(const std::string &str, const std::string &pattern,
60 |                                  const T &replacement) {
61 |   return std::regex_replace(str, std::regex(pattern), to_string_like(replacement));
62 | }
63 |
64 | }  // namespace transformer_engine
65 |
66 | #endif  // TRANSFORMER_ENGINE_COMMON_UTIL_STRING_H_
67 |
--------------------------------------------------------------------------------
/transformer_engine/common/util/string_header.h.in:
--------------------------------------------------------------------------------
1 | /*************************************************************************
2 |  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  *
4 |  * See LICENSE for license information.
5 | ************************************************************************/ 6 | 7 | static constexpr char @STRING_NAME@[] 8 | = R"__STRING_DELIM__(@STRING@)__STRING_DELIM__"; 9 | -------------------------------------------------------------------------------- /transformer_engine/common/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | """The utilities for Transformer Engine""" 5 | import inspect 6 | import warnings 7 | from enum import Enum 8 | 9 | warnings.filterwarnings( 10 | "module", category=DeprecationWarning, module="transformer_engine.common.utils" 11 | ) 12 | 13 | 14 | class DeprecatedEnum: # pylint: disable=too-few-public-methods 15 | """DeprecatedEnum""" 16 | 17 | def __init__(self, enum_cls, msg): 18 | self.enum_cls = enum_cls 19 | self.msg = msg 20 | 21 | def __iter__(self): 22 | return iter(list(self.enum_cls.__members__.values())) 23 | 24 | def __getattr__(self, name): 25 | if name in self.enum_cls.__members__: 26 | warnings.warn(self.msg, DeprecationWarning) 27 | return self.enum_cls.__members__[name] 28 | raise AttributeError(f"{self.enum_cls} does not contain {name}") 29 | 30 | 31 | def deprecate_wrapper(obj, msg): 32 | """Deprecate wrapper""" 33 | if inspect.isclass(obj): 34 | if issubclass(obj, Enum): 35 | return DeprecatedEnum(obj, msg) 36 | 37 | class DeprecatedCls(obj): # pylint: disable=too-few-public-methods 38 | """DeprecatedCls""" 39 | 40 | def __init__(self, *args, **kwargs): 41 | warnings.warn(msg, DeprecationWarning) 42 | super().__init__(*args, **kwargs) 43 | 44 | return DeprecatedCls 45 | 46 | if inspect.isfunction(obj): 47 | 48 | def deprecated(*args, **kwargs): 49 | warnings.warn(msg, DeprecationWarning) 50 | return obj(*args, **kwargs) 51 | 52 | return deprecated 53 | 54 | raise NotImplementedError( 55 | f"deprecate_cls_wrapper only support Class and Function, but got {type(obj)}." 56 | ) 57 | -------------------------------------------------------------------------------- /transformer_engine/debug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Top level package for numerical debugging.""" 6 | 7 | try: 8 | from . import pytorch 9 | from .pytorch.debug_state import set_weight_tensor_tp_group_reduce 10 | except ImportError as e: 11 | pass 12 | -------------------------------------------------------------------------------- /transformer_engine/debug/features/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Debug features.""" 6 | -------------------------------------------------------------------------------- /transformer_engine/debug/features/_test_dummy_feature.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
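
The deprecate_wrapper and DeprecatedEnum helpers in common/utils.py above wrap an existing class, Enum, or function so that using it emits a DeprecationWarning. A minimal usage sketch follows; the Color enum and the warning message are invented for illustration, only the wrapper behavior is taken from the source.

import warnings
from enum import Enum

from transformer_engine.common.utils import deprecate_wrapper


class Color(Enum):  # hypothetical enum, stands in for a deprecated public type
    RED = 1
    GREEN = 2


OldColor = deprecate_wrapper(Color, "Color is deprecated; use the new name instead.")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    member = OldColor.RED  # attribute access warns, then returns the real member
assert member is Color.RED
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
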
4 | 5 | """Feature doing nothing, used for testing purposes.""" 6 | 7 | from nvdlfw_inspect.registry import Registry, api_method 8 | from transformer_engine.debug.features.api import TEConfigAPIMapper 9 | 10 | 11 | @Registry.register_feature(namespace="transformer_engine") 12 | class TestDummyFeature(TEConfigAPIMapper): 13 | """ 14 | This is feature used only in tests. It invokes look_at_tensor_before_process 15 | and does nothing. 16 | 17 | If no features are used, then TE layer automatically switches to the non-debug mode. 18 | This feature is invoked for each GEMM to prevent this behavior. 19 | """ 20 | 21 | @api_method 22 | def inspect_tensor_enabled(self, *_args, **_kwargs): 23 | """API call used to determine whether to run look_at_tensor_before_process 24 | in the forward pass.""" 25 | return True 26 | -------------------------------------------------------------------------------- /transformer_engine/debug/features/disable_fp8_gemm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """DisableFP8GEMM Feature support for nvidia-dlframework-inspect""" 6 | 7 | from nvdlfw_inspect.registry import Registry, api_method 8 | from transformer_engine.debug.features.api import TEConfigAPIMapper 9 | 10 | 11 | @Registry.register_feature(namespace="transformer_engine") 12 | class DisableFP8GEMM(TEConfigAPIMapper): 13 | """ 14 | GEMM operations are executed in higher precision, even when FP8 autocast is enabled. 15 | 16 | Parameters 17 | ---------- 18 | 19 | gemms: List[str] 20 | list of gemms to disable 21 | 22 | - fprop 23 | - dgrad 24 | - wgrad 25 | 26 | Example 27 | ------- 28 | .. code-block:: yaml 29 | 30 | example_disable_fp8_gemm: 31 | enabled: True 32 | layers: 33 | layer_types: [fc1] 34 | transformer_engine: 35 | DisableFP8GEMM: 36 | enabled: True 37 | gemms: [dgrad, wgrad] 38 | """ 39 | 40 | @api_method 41 | def fp8_gemm_enabled( 42 | self, config, layer_name: str, gemm: str, iteration: int 43 | ): # pylint: disable=unused-argument 44 | """API call responsible for choice between high-precision and FP8 GEMM execution.""" 45 | 46 | for key in config: 47 | if key != "gemm": 48 | raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') 49 | 50 | # If this feature is invoked, then FP8 GEMM is disabled. 51 | # If not, then default behaviour in TransformerEngineAPI 52 | # is that fp8_gemm() API call returns True. 53 | return False 54 | -------------------------------------------------------------------------------- /transformer_engine/debug/features/disable_fp8_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """DisableFP8Layer Feature support for nvidia-dlframework-inspect""" 6 | 7 | import nvdlfw_inspect.api as debug_api 8 | from nvdlfw_inspect.registry import Registry, api_method 9 | 10 | 11 | @Registry.register_feature(namespace="transformer_engine") 12 | class DisableFP8Layer: 13 | """ 14 | Disables all FP8 GEMMs in the layer. 15 | 16 | 17 | Example 18 | ------- 19 | .. 
code-block:: yaml 20 | 21 | example_disable_fp8_layer: 22 | enabled: True 23 | layers: 24 | layer_types: [fc1] 25 | transformer_engine: 26 | DisableFP8Layer: 27 | enabled: True 28 | """ 29 | 30 | @api_method 31 | def fp8_gemm_enabled( 32 | self, config, layer_name: str, gemm: str, iteration: int 33 | ): # pylint: disable=unused-argument 34 | """API call responsible for selecting between high-precision and FP8 GEMM execution.""" 35 | for key in config: 36 | if key not in ["enabled", "gemm"]: 37 | raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".') 38 | # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config. 39 | debug_api.log_message("FP8 Disabled", layer_name) 40 | 41 | # If this feature is invoked, then FP8 GEMM is disabled. 42 | # If not, then default behavior in TransformerEngineAPI 43 | # is that fp8_gemm() API call returns True. 44 | return False 45 | 46 | def parse_config_and_api(self, config, **_kwargs): 47 | """Determines whether to run the API 48 | DisableFP8Layer is the only feature provided by the Transformer Engine 49 | which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for 50 | parsing gemms and tensors fields from the config, which are not needed for this feature. 51 | 52 | Explanation of the parse_config_and_api can be found in the 53 | nvidia-dlframework-inspect documentation. 54 | """ 55 | return config["enabled"], None 56 | -------------------------------------------------------------------------------- /transformer_engine/debug/features/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """ 6 | Utils for the debug features. 7 | """ 8 | -------------------------------------------------------------------------------- /transformer_engine/debug/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | -------------------------------------------------------------------------------- /transformer_engine/debug/pytorch/debug_state.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """ 6 | Managing the state of all the debugged layers. 7 | """ 8 | 9 | import sys 10 | 11 | 12 | class TEDebugState: 13 | """ 14 | A class to manage the state of debug layers. 15 | """ 16 | 17 | layer_count = 1 18 | layers_initialized = {} 19 | weight_tensor_tp_group_reduce = True 20 | debug_enabled = None 21 | 22 | @classmethod 23 | def initialize(cls): 24 | """ 25 | If debug_api module is initialized, then sets cls.debug_enabled to True. 26 | """ 27 | 28 | if "nvdlfw_inspect" in sys.modules: 29 | import nvdlfw_inspect.api as debug_api 30 | 31 | if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None: 32 | # This method is invoked when initializing TE modules. 33 | # If this error is thrown, it means that some TE module had been initialized before 34 | # debug_api was initialized, and now a new TE module is being initialized. 35 | # This is likely to be a bug. 
36 | raise RuntimeError( 37 | "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before" 38 | " initialization of the first TE module" 39 | ) 40 | cls.debug_enabled = debug_api.DEBUG_MANAGER is not None 41 | 42 | @classmethod 43 | def _reset(cls): 44 | """Resets layer count and stats buffers.""" 45 | from ..features.utils.stats_buffer import STATS_BUFFERS 46 | 47 | STATS_BUFFERS.reset() 48 | cls.debug_enabled = None 49 | cls.layers_initialized.clear() 50 | 51 | @classmethod 52 | def get_layer_count(cls): 53 | """ 54 | Layer counter is used when layer names are not provided to modules by the user. 55 | """ 56 | lc = cls.layer_count 57 | cls.layer_count += 1 58 | return lc 59 | 60 | @classmethod 61 | def set_weight_tensor_tp_group_reduce(cls, enabled): 62 | """Sets weight tensor reduction mode.""" 63 | cls.weight_tensor_tp_group_reduce = enabled 64 | 65 | 66 | def set_weight_tensor_tp_group_reduce(enabled): 67 | """Sets weight tensor reduction mode.""" 68 | TEDebugState.set_weight_tensor_tp_group_reduce(enabled) 69 | -------------------------------------------------------------------------------- /transformer_engine/debug/pytorch/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Utils functions for the debug module.""" 6 | 7 | 8 | def any_feature_enabled(quantizers): 9 | """Returns True if at least one API call is made from DebugQuantizer.""" 10 | return any(q.any_feature_enabled() for q in quantizers) 11 | -------------------------------------------------------------------------------- /transformer_engine/jax/MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include build_tools *.* 2 | recursive-include common_headers *.* 3 | recursive-include csrc *.* 4 | -------------------------------------------------------------------------------- /transformer_engine/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | """Transformer Engine bindings for JAX. 5 | 6 | This module provides JAX bindings for NVIDIA's Transformer Engine, enabling 7 | high-performance transformer operations with mixed precision and quantization 8 | support. It includes implementations of key transformer components like attention, 9 | linear layers, and layer normalization, optimized for NVIDIA GPUs. 10 | 11 | The module exports various transformer operations and utilities: 12 | - Attention mechanisms (self-attention, cross-attention) 13 | - Linear transformations with optional quantization 14 | - Layer normalization operations 15 | - Activation functions 16 | - Softmax operations 17 | - Sharding utilities for distributed training 18 | 19 | All operations are designed to work seamlessly with JAX's functional programming 20 | model and support automatic differentiation. 21 | """ 22 | 23 | # pylint: disable=wrong-import-position 24 | 25 | # This unused import is needed because the top level `transformer_engine/__init__.py` 26 | # file catches an `ImportError` as a guard for cases where the given framework's 27 | # extensions are not available. 28 | import jax 29 | 30 | from transformer_engine.common import load_framework_extension 31 | 32 | load_framework_extension("jax") 33 | 34 | from . 
import flax
35 | from . import quantize
36 |
37 | from .quantize import fp8_autocast, update_collections, get_delayed_scaling
38 | from .quantize import NVTE_FP8_COLLECTION_NAME
39 |
40 | from .sharding import MeshResource
41 | from .sharding import MajorShardingType, ShardingResource, ShardingType
42 |
43 | from ..common.utils import deprecate_wrapper
44 | from ..common.utils import DeprecatedEnum
45 |
46 | MajorShardingType = DeprecatedEnum(
47 |     MajorShardingType, "MajorShardingType is deprecated and will be removed in the near future."
48 | )
49 | ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecated and will be removed in the near future.")
50 | ShardingResource = deprecate_wrapper(
51 |     ShardingResource,
52 |     "ShardingResource has been renamed to MeshResource and will be removed in the near future.",
53 | )
54 |
55 | __all__ = [
56 |     "NVTE_FP8_COLLECTION_NAME",
57 |     "fp8_autocast",
58 |     "update_collections",
59 |     "get_delayed_scaling",
60 |     "MeshResource",
61 |     "MajorShardingType",
62 |     "ShardingResource",
63 |     "ShardingType",
64 |     "flax",
65 |     "quantize",
66 | ]
67 |
--------------------------------------------------------------------------------
/transformer_engine/jax/cpp_extensions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 | """Python interface for c++ extensions"""
5 | from .activation import *
6 | from .attention import *
7 | from .normalization import *
8 | from .quantization import *
9 | from .softmax import *
10 | from .gemm import *
11 |
--------------------------------------------------------------------------------
/transformer_engine/jax/csrc/extensions/cublas.cpp:
--------------------------------------------------------------------------------
1 | /*************************************************************************
2 |  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  *
4 |  * See LICENSE for license information.
5 |  ************************************************************************/
6 |
7 | #include "extensions.h"
8 | #include "transformer_engine/gemm.h"
9 | #include "xla/ffi/api/c_api.h"
10 |
11 | namespace transformer_engine {
12 | namespace jax {
13 |
14 | Error_Type CublasHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
15 |                                Dictionary attrs) {
16 |   nvte_cublas_handle_init();
17 |   return ffi_with_cuda_error_check();
18 | }
19 |
20 | XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasHandleInitHandler, CublasHandleInitFFI,
21 |                               FFI::Bind().RemainingArgs().RemainingRets().Attrs());
22 | }  // namespace jax
23 | }  // namespace transformer_engine
24 |
--------------------------------------------------------------------------------
/transformer_engine/jax/csrc/extensions/cudnn.cpp:
--------------------------------------------------------------------------------
1 | /*************************************************************************
2 |  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 |  *
4 |  * See LICENSE for license information.
5 | ************************************************************************/ 6 | 7 | #include "transformer_engine/cudnn.h" 8 | 9 | #include "extensions.h" 10 | #include "xla/ffi/api/c_api.h" 11 | 12 | namespace transformer_engine { 13 | namespace jax { 14 | 15 | Error_Type CudnnHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets, 16 | Dictionary attrs) { 17 | nvte_cudnn_handle_init(); 18 | return ffi_with_cuda_error_check(); 19 | } 20 | 21 | XLA_FFI_DEFINE_HANDLER_SYMBOL(CudnnHandleInitHandler, CudnnHandleInitFFI, 22 | FFI::Bind().RemainingArgs().RemainingRets().Attrs()); 23 | } // namespace jax 24 | } // namespace transformer_engine 25 | -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/ffi.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | #include "extensions/ffi.h" 7 | 8 | #include 9 | 10 | namespace transformer_engine { 11 | namespace jax { 12 | 13 | // For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186 14 | DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { 15 | switch (type) { 16 | // Using this for E8M0 17 | case xla::ffi::DataType::U8: 18 | return DType::kFloat8E8M0; 19 | break; 20 | case xla::ffi::DataType::S32: 21 | return DType::kInt32; 22 | break; 23 | case xla::ffi::DataType::S64: 24 | return DType::kInt64; 25 | break; 26 | case xla::ffi::DataType::F32: 27 | return DType::kFloat32; 28 | break; 29 | case xla::ffi::DataType::F16: 30 | return DType::kFloat16; 31 | break; 32 | case xla::ffi::DataType::BF16: 33 | return DType::kBFloat16; 34 | break; 35 | case xla::ffi::DataType::F8E5M2: 36 | return DType::kFloat8E5M2; 37 | break; 38 | case xla::ffi::DataType::F8E4M3FN: 39 | return DType::kFloat8E4M3; 40 | break; 41 | // case xla::ffi::DataType::F8E8M0FNU: 42 | // return DType::kFloat8E8M0; 43 | // break; 44 | default: 45 | auto type_num = static_cast(type); 46 | if (type_num == 33) return DType::kFloat8E8M0; 47 | NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", 48 | static_cast(type_num)); 49 | break; 50 | } 51 | } 52 | 53 | Error_Type ffi_with_cuda_error_check() { 54 | cudaError_t last_error = cudaGetLastError(); 55 | if (last_error != cudaSuccess) { 56 | return Error_Type(XLA_FFI_Error_Code_INTERNAL, 57 | std::string("CUDA error: ") + cudaGetErrorString(last_error)); 58 | } 59 | return Error_Type::Success(); 60 | } 61 | 62 | } // namespace jax 63 | } // namespace transformer_engine 64 | -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/misc.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include "extensions.h" 8 | 9 | namespace transformer_engine { 10 | namespace jax { 11 | 12 | std::vector MakeShapeVector(NVTEShape shape) { 13 | return std::vector(shape.data, shape.data + shape.ndim); 14 | } 15 | 16 | void Shape::from_vector(const std::vector &shape) { 17 | num_dim = shape.size(); 18 | assert(num_dim <= kMaxNumDim); 19 | std::memcpy(dims, shape.data(), num_dim * sizeof(size_t)); 20 | } 21 | 22 | std::vector Shape::to_vector() const { 23 | assert(num_dim <= kMaxNumDim); 24 | std::vector shape(num_dim); 25 | std::memcpy(shape.data(), dims, num_dim * sizeof(size_t)); 26 | return shape; 27 | } 28 | 29 | } // namespace jax 30 | } // namespace transformer_engine 31 | -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/misc.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace transformer_engine { 14 | namespace jax { 15 | 16 | constexpr int kMaxNumDim = 8; 17 | 18 | struct Shape { 19 | int num_dim; 20 | size_t dims[kMaxNumDim]; 21 | 22 | void from_vector(const std::vector &shape); 23 | 24 | std::vector to_vector() const; 25 | }; 26 | 27 | std::vector MakeShapeVector(NVTEShape shape); 28 | 29 | inline size_t product(const std::vector &shape) { 30 | size_t ret = 1; 31 | for (const auto &elem : shape) { 32 | ret *= elem; 33 | } 34 | return ret; 35 | } 36 | 37 | enum class QuantizeLayout { 38 | ROWWISE, 39 | COLWISE, 40 | ROWWISE_COLWISE, 41 | }; 42 | 43 | enum class JAXX_Scaling_Mode : int64_t { 44 | NO_SCALING = 0, 45 | DELAYED_TENSOR_SCALING = 1, 46 | MXFP8_1D_SCALING = 2, 47 | CURRENT_TENSOR_SCALING = 3, 48 | }; 49 | 50 | static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { 51 | switch (mode) { 52 | case JAXX_Scaling_Mode::NO_SCALING: 53 | return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; 54 | break; 55 | case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING: 56 | return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; 57 | break; 58 | case JAXX_Scaling_Mode::MXFP8_1D_SCALING: 59 | return NVTEScalingMode::NVTE_MXFP8_1D_SCALING; 60 | break; 61 | case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING: 62 | return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; 63 | break; 64 | default: 65 | NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); 66 | break; 67 | } 68 | } 69 | 70 | } // namespace jax 71 | } // namespace transformer_engine 72 | -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/utils.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | #include "utils.h" 7 | 8 | #include 9 | 10 | #include 11 | 12 | #include "common/util/cuda_runtime.h" 13 | 14 | namespace transformer_engine { 15 | namespace jax { 16 | 17 | int GetCudaRuntimeVersion() { 18 | int ver = 0; 19 | NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver)); 20 | return ver; 21 | } 22 | 23 | size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); } 24 | 25 | int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); } 26 | 27 | } // namespace jax 28 | } // namespace transformer_engine 29 | -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/utils.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "common/util/logging.h" 17 | 18 | namespace transformer_engine { 19 | namespace jax { 20 | 21 | int GetCudaRuntimeVersion(); 22 | size_t GetCudnnRuntimeVersion(); 23 | int GetDeviceComputeCapability(int gpu_id); 24 | 25 | class cudaDevicePropertiesManager { 26 | public: 27 | static cudaDevicePropertiesManager &Instance() { 28 | static thread_local cudaDevicePropertiesManager instance; 29 | return instance; 30 | } 31 | 32 | int GetMultiProcessorCount() { 33 | if (!prop_queried_) { 34 | int device_id; 35 | NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); 36 | cudaGetDeviceProperties(&prop_, device_id); 37 | prop_queried_ = true; 38 | } 39 | return prop_.multiProcessorCount; 40 | } 41 | 42 | int GetMajor() { 43 | if (!prop_queried_) { 44 | int device_id; 45 | NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); 46 | cudaGetDeviceProperties(&prop_, device_id); 47 | prop_queried_ = true; 48 | } 49 | return prop_.major; 50 | } 51 | 52 | private: 53 | bool prop_queried_ = false; 54 | cudaDeviceProp prop_; 55 | }; 56 | 57 | } // namespace jax 58 | } // namespace transformer_engine 59 | -------------------------------------------------------------------------------- /transformer_engine/jax/flax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | """Transformer Engine bindings for JAX""" 5 | from .module import DenseGeneral, LayerNorm 6 | from .module import LayerNormDenseGeneral, LayerNormMLP 7 | from .transformer import extend_logical_axis_rules 8 | from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases 9 | from .transformer import TransformerLayer, TransformerLayerType 10 | 11 | __all__ = [ 12 | "DenseGeneral", 13 | "LayerNorm", 14 | "LayerNormDenseGeneral", 15 | "LayerNormMLP", 16 | "extend_logical_axis_rules", 17 | "DotProductAttention", 18 | "MultiHeadAttention", 19 | "RelativePositionBiases", 20 | "TransformerLayer", 21 | "TransformerLayerType", 22 | ] 23 | -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | """ 5 | Python interface for quantization helpers. 6 | 7 | This module provides a high-level interface for tensor quantization in JAX, 8 | including support for various scaling modes and quantization strategies. 9 | It exports all the necessary classes and functions from the underlying 10 | implementation modules. 11 | """ 12 | from .tensor import * 13 | from .quantizer import * 14 | from .dequantizer import * 15 | from .scaling_modes import * 16 | from .metadata import * 17 | from .helper import * 18 | -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """ 6 | Metadata classes for quantization in JAX. 7 | 8 | This module provides classes for managing quantization metadata, including 9 | scale factors and amax history for different tensor types. 10 | """ 11 | from dataclasses import dataclass 12 | import jax.numpy as jnp 13 | 14 | 15 | __all__ = ["QuantizeMeta", "QuantizeMetaSet"] 16 | 17 | 18 | @dataclass 19 | class QuantizeMeta: 20 | """Metadata for quantization parameters. 21 | 22 | Attributes: 23 | scale: The scaling factor for quantization 24 | amax_history: History of maximum absolute values 25 | """ 26 | 27 | scale: jnp.ndarray 28 | amax_history: jnp.ndarray 29 | 30 | 31 | @dataclass 32 | class QuantizeMetaSet: 33 | """Set of quantization metadata for different tensor types. 34 | 35 | Attributes: 36 | x: Quantization metadata for input tensors 37 | kernel: Quantization metadata for kernel tensors 38 | grad: Quantization metadata for gradient tensors 39 | """ 40 | 41 | x: QuantizeMeta 42 | kernel: QuantizeMeta 43 | grad: QuantizeMeta 44 | -------------------------------------------------------------------------------- /transformer_engine/jax/softmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | """JAX softmax modules""" 5 | from enum import Enum 6 | from functools import partial 7 | from typing import Optional 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | 12 | from . 
import cpp_extensions as tex 13 | 14 | 15 | class SoftmaxType(Enum): 16 | """SoftmaxType.""" 17 | 18 | SCALED = "scaled" 19 | SCALED_MASKED = "scaled_masked" 20 | SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked" 21 | 22 | 23 | def softmax( 24 | logits: jnp.ndarray, 25 | mask: Optional[jnp.ndarray] = None, 26 | scale_factor: Optional[float] = 1.0, 27 | softmax_type: Optional[SoftmaxType] = SoftmaxType.SCALED, 28 | ): 29 | """ 30 | Softmax wrapper 31 | """ 32 | output = _softmax(logits, mask, scale_factor, softmax_type) 33 | return output 34 | 35 | 36 | @partial(jax.custom_vjp, nondiff_argnums=(2, 3)) 37 | def _softmax(logits, mask, scale_factor, softmax_type): 38 | 39 | output, _ = _softmax_fwd_rule(logits, mask, scale_factor, softmax_type) 40 | return output 41 | 42 | 43 | def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type): 44 | if softmax_type is SoftmaxType.SCALED_MASKED: 45 | assert mask is not None 46 | output = tex.scaled_masked_softmax_fwd(logits, mask, scale_factor) 47 | elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: 48 | output = tex.scaled_upper_triang_masked_softmax_fwd(logits, scale_factor) 49 | else: 50 | output = tex.scaled_softmax_fwd(logits, scale_factor) 51 | 52 | return output, (output, logits, mask) 53 | 54 | 55 | def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz): 56 | (softmax_output, logits, mask) = ctx 57 | 58 | if softmax_type is SoftmaxType.SCALED_MASKED: 59 | dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor) 60 | elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: 61 | dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor) 62 | else: 63 | dgrad = tex.scaled_softmax_bwd(dz, softmax_output, logits, scale_factor) 64 | 65 | return (dgrad, None) 66 | 67 | 68 | _softmax.defvjp(_softmax_fwd_rule, _softmax_bwd_rule) 69 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include build_tools *.* 2 | recursive-include common_headers *.* 3 | recursive-include csrc *.* 4 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Transformer Engine bindings for pyTorch""" 6 | 7 | # pylint: disable=wrong-import-position 8 | 9 | import functools 10 | from packaging.version import Version as PkgVersion 11 | 12 | import torch 13 | 14 | from transformer_engine.common import load_framework_extension 15 | 16 | 17 | @functools.lru_cache(maxsize=None) 18 | def torch_version() -> tuple[int, ...]: 19 | """Get PyTorch version""" 20 | return PkgVersion(str(torch.__version__)).release 21 | 22 | 23 | assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." 
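
The torch_version helper above reduces the installed PyTorch version to its numeric release tuple, so pre-release and local-build suffixes (e.g. nightly builds) do not trip the minimum-version assert. A standalone sketch of that behavior; the version strings are arbitrary examples and no torch install is needed.

from packaging.version import Version as PkgVersion


def parse_release(version_string: str) -> tuple[int, ...]:
    # ".release" keeps only the numeric segment, e.g. "2.3.0a0+gitabc" -> (2, 3, 0)
    return PkgVersion(version_string).release


assert parse_release("2.3.0a0+gitabc123") >= (2, 1)
assert parse_release("2.1.0") >= (2, 1)
assert parse_release("2.0.1") < (2, 1)
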
24 | 25 | 26 | load_framework_extension("torch") 27 | from transformer_engine.pytorch.module import LayerNormLinear 28 | from transformer_engine.pytorch.module import Linear 29 | from transformer_engine.pytorch.module import LayerNormMLP 30 | from transformer_engine.pytorch.module import LayerNorm 31 | from transformer_engine.pytorch.module import RMSNorm 32 | from transformer_engine.pytorch.module import GroupedLinear 33 | from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding 34 | from transformer_engine.pytorch.module import initialize_ub 35 | from transformer_engine.pytorch.module import destroy_ub 36 | from transformer_engine.pytorch.attention import DotProductAttention 37 | from transformer_engine.pytorch.attention import MultiheadAttention 38 | from transformer_engine.pytorch.attention import InferenceParams 39 | from transformer_engine.pytorch.attention import RotaryPositionEmbedding 40 | from transformer_engine.pytorch.transformer import TransformerLayer 41 | from transformer_engine.pytorch.permutation import ( 42 | moe_permute, 43 | moe_permute_with_probs, 44 | moe_unpermute, 45 | moe_sort_chunks_by_index, 46 | moe_sort_chunks_by_index_with_probs, 47 | ) 48 | from transformer_engine.pytorch.fp8 import fp8_autocast 49 | from transformer_engine.pytorch.fp8 import fp8_model_init 50 | from transformer_engine.pytorch.graph import make_graphed_callables 51 | from transformer_engine.pytorch.distributed import checkpoint 52 | from transformer_engine.pytorch.distributed import CudaRNGStatesTracker 53 | from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context 54 | from transformer_engine.pytorch import ops 55 | from transformer_engine.pytorch import optimizers 56 | from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy 57 | 58 | try: 59 | torch._dynamo.config.error_on_nested_jit_trace = False 60 | except AttributeError: 61 | pass # error_on_nested_jit_trace was added in PyTorch 2.2.0 62 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Python interface for attention""" 6 | 7 | from .dot_product_attention import DotProductAttention 8 | from .multi_head_attention import MultiheadAttention 9 | from .inference import InferenceParams 10 | from .rope import RotaryPositionEmbedding 11 | 12 | __all__ = [ 13 | "DotProductAttention", 14 | "MultiheadAttention", 15 | "InferenceParams", 16 | "RotaryPositionEmbedding", 17 | ] 18 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/dot_product_attention/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
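
A short usage sketch of the PyTorch API re-exported above, combining Linear with fp8_autocast; it assumes a CUDA build of Transformer Engine, an FP8-capable GPU, and the delayed-scaling recipe from transformer_engine.common.recipe. Sizes are arbitrary but chosen to satisfy FP8 GEMM alignment.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(768, 3072, bias=True).cuda()
x = torch.randn(16, 768, device="cuda")

# GEMMs inside the context run in FP8 with delayed scaling.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
y.sum().backward()
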
4 | 5 | """Python interface for dot product attention""" 6 | 7 | from .dot_product_attention import DotProductAttention, _attention_backends 8 | 9 | __all__ = ["DotProductAttention", "_attention_backends"] 10 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Enums for e2e transformer""" 6 | import torch 7 | import torch.distributed 8 | import transformer_engine_torch as tex 9 | 10 | 11 | """ 12 | This is a map: torch.dtype -> int 13 | Used for passing dtypes into cuda 14 | extension. Has one to one mapping 15 | with enum in transformer_engine.h 16 | """ 17 | TE_DType = { 18 | torch.uint8: tex.DType.kByte, 19 | torch.float8_e4m3fn: tex.DType.kFloat8E4M3, 20 | torch.float8_e5m2: tex.DType.kFloat8E5M2, 21 | torch.int32: tex.DType.kInt32, 22 | torch.float32: tex.DType.kFloat32, 23 | torch.half: tex.DType.kFloat16, 24 | torch.bfloat16: tex.DType.kBFloat16, 25 | } 26 | 27 | """ 28 | This is a map: int -> torch.dtype 29 | Used for resolving cuda extension types to torch. 30 | Has one to one mapping with enum in 31 | transformer_engine.h 32 | """ 33 | TE_DType_To_Torch = { 34 | tex.DType.kByte: torch.uint8, 35 | tex.DType.kFloat8E4M3: torch.float8_e4m3fn, 36 | tex.DType.kFloat8E5M2: torch.float8_e5m2, 37 | tex.DType.kInt32: torch.int32, 38 | tex.DType.kFloat32: torch.float32, 39 | tex.DType.kFloat16: torch.half, 40 | tex.DType.kBFloat16: torch.bfloat16, 41 | } 42 | 43 | AttnMaskTypes = ( 44 | "no_mask", 45 | "padding", 46 | "causal", 47 | "padding_causal", 48 | "causal_bottom_right", 49 | "padding_causal_bottom_right", 50 | "arbitrary", 51 | ) 52 | 53 | AttnTypes = ("self", "cross") 54 | 55 | AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias", "alibi") 56 | 57 | QKVLayouts = ( 58 | "sb3hd", 59 | "sbh3d", 60 | "sbhd_sb2hd", 61 | "sbhd_sbh2d", 62 | "sbhd_sbhd_sbhd", 63 | "bs3hd", 64 | "bsh3d", 65 | "bshd_bs2hd", 66 | "bshd_bsh2d", 67 | "bshd_bshd_bshd", 68 | "t3hd", 69 | "th3d", 70 | "thd_t2hd", 71 | "thd_th2d", 72 | "thd_thd_thd", 73 | "sbhd_bshd_bshd", 74 | "bshd_sbhd_sbhd", 75 | "thd_bshd_bshd", 76 | "thd_sbhd_sbhd", 77 | "paged_kv_bshd_bshd_bshd", 78 | "paged_kv_bshd_sbhd_sbhd", 79 | "paged_kv_sbhd_bshd_bshd", 80 | "paged_kv_sbhd_sbhd_sbhd", 81 | "paged_kv_thd_bshd_bshd", 82 | "paged_kv_thd_sbhd_sbhd", 83 | ) 84 | 85 | LayerTypes = ("encoder", "decoder") 86 | 87 | GemmParallelModes = ("row", "column", None) 88 | 89 | dist_group_type = torch.distributed.ProcessGroup 90 | 91 | MXFP8_BLOCK_SCALING_SIZE = 32 92 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/cpp_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
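
TE_DType and TE_DType_To_Torch in constants.py above are meant to be mutual inverses between torch dtypes and the extension's DType enum. A small sketch that checks the round trip; it requires the transformer_engine_torch extension to be importable.

import torch
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch

for torch_dtype, te_dtype in TE_DType.items():
    # every torch dtype maps to a TE enum value and back to the same torch dtype
    assert TE_DType_To_Torch[te_dtype] is torch_dtype
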
4 | 5 | """Python interface for c++ extensions""" 6 | from transformer_engine_torch import * 7 | 8 | from .fused_attn import * 9 | from .gemm import * 10 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/misc.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #include "extensions.h" 8 | 9 | namespace transformer_engine::pytorch { 10 | 11 | size_t get_cublasLt_version() { return cublasLtGetVersion(); } 12 | 13 | size_t get_cudnn_version() { return cudnnGetVersion(); } 14 | 15 | } // namespace transformer_engine::pytorch 16 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #include "extensions.h" 8 | 9 | namespace transformer_engine::pytorch { 10 | 11 | void multi_tensor_compute_scale_and_scale_inv_cuda( 12 | int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, 13 | float max_fp8, bool force_pow_2_scales, float epsilon) { 14 | auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); 15 | auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = 16 | makeTransformerEngineTensorList(tensor_lists); 17 | int device_id = tensor_lists[0][0].device().index(); 18 | 19 | nvte_multi_tensor_compute_scale_and_scale_inv_cuda( 20 | chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8, 21 | force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream()); 22 | } 23 | 24 | } // namespace transformer_engine::pytorch 25 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 
5 | ************************************************************************/ 6 | 7 | #include "extensions.h" 8 | 9 | namespace transformer_engine::pytorch { 10 | 11 | void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, 12 | std::vector> tensor_lists, float scale) { 13 | auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); 14 | auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = 15 | makeTransformerEngineTensorList(tensor_lists); 16 | int device_id = tensor_lists[0][0].device().index(); 17 | 18 | nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, 19 | num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream()); 20 | } 21 | 22 | } // namespace transformer_engine::pytorch 23 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #include "extensions.h" 8 | 9 | namespace transformer_engine::pytorch { 10 | 11 | void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, 12 | std::vector> tensor_lists, float wd, 13 | float momentum, float dampening, float lr, bool nesterov, bool first_run, 14 | bool wd_after_momentum, float scale) { 15 | auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); 16 | auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = 17 | makeTransformerEngineTensorList(tensor_lists); 18 | int device_id = tensor_lists[0][0].device().index(); 19 | 20 | nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, 21 | num_tensors, wd, momentum, dampening, lr, nesterov, first_run, 22 | wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream()); 23 | } 24 | 25 | } // namespace transformer_engine::pytorch 26 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/util.h: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * See LICENSE for license information. 5 | ************************************************************************/ 6 | 7 | #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ 8 | #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ 9 | 10 | #include 11 | 12 | #include 13 | 14 | #include "transformer_engine/transformer_engine.h" 15 | 16 | /* Swizzle the scaling factor of the input tensor. 17 | * 18 | * The returned swizzled scaling factor tensor should be kept alive during the GEMM. 19 | */ 20 | std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, 21 | bool trans); 22 | 23 | #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ 24 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/float8_tensor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | 5 | """Tensor class with FP8 data""" 6 | 7 | from .tensor.float8_tensor import Float8Tensor 8 | 9 | __all__ = ["Float8Tensor"] 10 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Module level PyTorch APIs""" 6 | from .layernorm_linear import LayerNormLinear 7 | from .linear import Linear 8 | from .grouped_linear import GroupedLinear 9 | from .layernorm_mlp import LayerNormMLP 10 | from .layernorm import LayerNorm 11 | from .rmsnorm import RMSNorm 12 | from .fp8_padding import Fp8Padding 13 | from .fp8_unpadding import Fp8Unpadding 14 | from .base import initialize_ub, destroy_ub 15 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/numerics_debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Utilities for debugging numerical issues with FP8""" 6 | from typing import Tuple 7 | import torch 8 | from transformer_engine.common import recipe 9 | 10 | _NUMERICS_DEBUG = False 11 | 12 | 13 | def debug(enabled: bool = True) -> None: 14 | """Set FP8 debug mode""" 15 | global _NUMERICS_DEBUG 16 | _NUMERICS_DEBUG = enabled 17 | 18 | 19 | def fp8_tensor_statistics(tensor: torch.Tensor, fp8_format: str = "E4M3") -> Tuple[int, ...]: 20 | """Print FP8 tensor stats""" 21 | fp8_format = fp8_format.upper() 22 | assert fp8_format in ( 23 | "E4M3", 24 | "E5M2", 25 | ), "fp8_format must be 'E4M3' or 'E5M2' for amax" 26 | 27 | fmt = recipe.Format[fp8_format] 28 | FP8_MAX = fmt.value.max_fwd 29 | 30 | num_overflows = (tensor == FP8_MAX).sum().item() 31 | num_underflows = (tensor == 0).sum().item() 32 | return (num_underflows, num_overflows) 33 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Fusible operations. 6 | 7 | This operation-based API is experimental and subject to change. 8 | 9 | """ 10 | 11 | from transformer_engine.pytorch.ops.basic import * 12 | from transformer_engine.pytorch.ops.linear import Linear 13 | from transformer_engine.pytorch.ops.op import FusibleOperation 14 | from transformer_engine.pytorch.ops.sequential import Sequential 15 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
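
fp8_tensor_statistics in numerics_debug.py above counts underflowed (== 0) and saturated (== max_fwd) entries of a tensor. An illustrative sketch; the input values are made up and 448 is the E4M3 max_fwd constant.

import torch
from transformer_engine.pytorch.numerics_debug import fp8_tensor_statistics

t = torch.tensor([0.0, 1.5, 448.0, 448.0])  # pretend this is dequantized E4M3 data
num_underflows, num_overflows = fp8_tensor_statistics(t, fp8_format="E4M3")
assert (num_underflows, num_overflows) == (1, 2)
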
4 | 5 | """Single tensor operations supported by the operation fuser.""" 6 | 7 | from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU 8 | from .add_in_place import AddInPlace 9 | from .all_gather import AllGather 10 | from .all_reduce import AllReduce 11 | from .basic_linear import BasicLinear 12 | from .bias import Bias 13 | from .identity import Identity 14 | from .layer_norm import LayerNorm 15 | from .make_extra_output import MakeExtraOutput 16 | from .quantize import Quantize 17 | from .reduce_scatter import ReduceScatter 18 | from .reshape import Reshape 19 | from .rmsnorm import RMSNorm 20 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/all_reduce.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Fusible operation for all-reduce.""" 6 | 7 | from __future__ import annotations 8 | from typing import Optional 9 | 10 | import torch 11 | 12 | from ...tensor import QuantizedTensor 13 | from ..op import BasicOperation, OperationContext 14 | 15 | 16 | class AllReduce(BasicOperation): 17 | """All-reduce tensor 18 | 19 | Equivalent to summing tensors from all processes. It is assumed 20 | that the output is used in operations that are redundantly 21 | computed on all processes, and hence that gradients are identical 22 | between processes. 23 | 24 | Parameters 25 | ---------- 26 | process_group: torch.distributed.ProcessGroup, default = world group 27 | Process group for communication 28 | 29 | """ 30 | 31 | def __init__( 32 | self, 33 | process_group: Optional[torch.distributed.ProcessGroup] = None, 34 | reduce_in_backward: bool = True, 35 | ) -> None: 36 | super().__init__() 37 | self.process_group: Optional[torch.distributed.ProcessGroup] = process_group 38 | self._reduce_in_backward: bool = reduce_in_backward 39 | 40 | def op_forward( 41 | self, 42 | ctx: OperationContext, 43 | input_: torch.Tensor, 44 | prev_op: Optional[BasicOperation] = None, 45 | next_op: Optional[BasicOperation] = None, 46 | ) -> torch.Tensor: 47 | 48 | # Trivial case 49 | if torch.distributed.get_world_size(self.process_group) == 1: 50 | return input_ 51 | 52 | # Perform all-reduce 53 | x = input_ 54 | if isinstance(x, QuantizedTensor): 55 | x = x.dequantize() 56 | x = x.contiguous() 57 | torch.distributed.all_reduce(x, group=self.process_group) 58 | return x 59 | 60 | def op_backward( 61 | self, 62 | ctx: OperationContext, 63 | grad_output: torch.Tensor, 64 | ) -> tuple[torch.Tensor, tuple[()]]: 65 | return grad_output, () 66 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/identity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
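
A hypothetical sketch of the AllReduce op's contract documented above: forward sums the input across the process group (a no-op for a single rank) and backward passes gradients through unchanged. It assumes torch.distributed is initialized (e.g. under torchrun) and that the op can be called standalone like a regular module.

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

all_reduce = te.ops.AllReduce()  # defaults to the world process group
x = torch.ones(4, device="cuda")
y = all_reduce(x)  # sum over all ranks; returns the input unchanged for a single rank
assert torch.equal(y, torch.full_like(x, dist.get_world_size()))
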
4 | 5 | """Fusible operation for identity.""" 6 | 7 | from __future__ import annotations 8 | from typing import Optional 9 | 10 | import torch 11 | 12 | from transformer_engine.pytorch.ops.op import ( 13 | BasicOperation, 14 | OperationContext, 15 | ) 16 | 17 | 18 | class Identity(BasicOperation): 19 | """Return input tensor""" 20 | 21 | def op_forward( 22 | self, 23 | ctx: OperationContext, 24 | input_: torch.Tensor, 25 | prev_op: Optional[BasicOperation] = None, 26 | next_op: Optional[BasicOperation] = None, 27 | ) -> torch.Tensor: 28 | return input_ 29 | 30 | def op_backward( 31 | self, 32 | ctx: OperationContext, 33 | grad_output: torch.Tensor, 34 | ) -> tuple[torch.Tensor, tuple[()]]: 35 | return grad_output, () 36 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Fusible operation for quantization.""" 6 | 7 | from __future__ import annotations 8 | from typing import Optional 9 | 10 | import torch 11 | 12 | from ...fp8 import FP8GlobalStateManager 13 | from ...tensor import QuantizedTensor 14 | from ..op import BasicOperation, OperationContext 15 | 16 | 17 | class Quantize(BasicOperation): 18 | """Quantize tensor data 19 | 20 | Uses FP8 recipe from `fp8_autocast` context. When called outside 21 | of an `fp8_autocast` context, this is an identity operation. 22 | 23 | Parameters 24 | ---------- 25 | forward: bool, default = `True` 26 | Perform quantization in forward pass 27 | backward: bool, default = `False` 28 | Perform quantization in backward pass 29 | 30 | """ 31 | 32 | def __init__( 33 | self, 34 | forward: bool = True, 35 | backward: bool = False, 36 | ) -> None: 37 | super().__init__() 38 | self._quantize_forward = forward 39 | self._quantize_backward = backward 40 | 41 | def num_quantizers(self, mode: str) -> int: 42 | if mode == "forward" and self._quantize_forward: 43 | return 1 44 | if mode == "backward" and self._quantize_backward: 45 | return 1 46 | return 0 47 | 48 | def op_forward( 49 | self, 50 | ctx: OperationContext, 51 | input_: torch.Tensor, 52 | prev_op: Optional[BasicOperation] = None, 53 | next_op: Optional[BasicOperation] = None, 54 | ) -> torch.Tensor: 55 | 56 | # Check if FP8 is enabled 57 | fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() 58 | quantize_forward = fp8_enabled and self._quantize_forward 59 | quantize_backward = fp8_enabled and self._quantize_backward 60 | 61 | # Quantize if needed 62 | out = input_ 63 | if quantize_forward and not isinstance(out, QuantizedTensor): 64 | out = self.get_quantizer("forward", 0)(out) 65 | 66 | ctx.quantize_backward = quantize_backward 67 | return out 68 | 69 | def op_backward( 70 | self, 71 | ctx: OperationContext, 72 | grad_output: torch.Tensor, 73 | ) -> tuple[torch.Tensor, tuple[()]]: 74 | grad_input = grad_output 75 | if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): 76 | grad_input = self.get_quantizer("backward", 0)(grad_input) 77 | return grad_input, () 78 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/reshape.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
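
A sketch of the Quantize op's documented behavior above: outside fp8_autocast it is an identity, inside it returns a QuantizedTensor built by the recipe's quantizer. FP8-capable hardware and standalone (module-style) invocation of the op are assumed.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor

quantize = te.ops.Quantize(forward=True, backward=False)
x = torch.randn(32, 32, device="cuda")

y = quantize(x)                      # no fp8_autocast: input passes through unchanged
assert torch.equal(y, x)

with te.fp8_autocast(enabled=True):  # inside autocast: data is quantized
    y_fp8 = quantize(x)
assert isinstance(y_fp8, QuantizedTensor)
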
--------------------------------------------------------------------------------
/transformer_engine/pytorch/ops/basic/reshape.py:
--------------------------------------------------------------------------------
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for reshape."""

from __future__ import annotations
from collections.abc import Iterable
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)


class Reshape(BasicOperation):
    """Reshape tensor

    See `torch.reshape`.

    Parameters
    ----------
    shape: iterable of int
        Output tensor dimensions. If one dimension is -1, it is
        inferred based on input tensor dimensions.

    """

    def __init__(self, shape: Iterable[int]) -> None:
        super().__init__()
        self._shape = tuple(shape)

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op: Optional[BasicOperation] = None,
        next_op: Optional[BasicOperation] = None,
    ) -> torch.Tensor:
        ctx.input_shape = input_.size()
        return input_.reshape(*self._shape)

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        return grad_output.reshape(*ctx.input_shape), ()

--------------------------------------------------------------------------------
/transformer_engine/pytorch/ops/fused/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Compound tensor operations supported by the operation fuser."""

from .backward_linear_add import (
    BackwardLinearAdd,
    fuse_backward_linear_add,
)
from .forward_linear_bias_activation import (
    ForwardLinearBiasActivation,
    fuse_forward_linear_bias_activation,
)
from .forward_linear_bias_add import (
    ForwardLinearBiasAdd,
    fuse_forward_linear_bias_add,
)
from .userbuffers_backward_linear import (
    UserbuffersBackwardLinear,
    fuse_userbuffers_backward_linear,
)
from .userbuffers_forward_linear import (
    UserbuffersForwardLinear,
    fuse_userbuffers_forward_linear,
)

--------------------------------------------------------------------------------
/transformer_engine/pytorch/optimizers/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fused optimizers and multi-tensor kernels."""
from transformer_engine_torch import (
    multi_tensor_scale,
    multi_tensor_l2norm,
    multi_tensor_unscale_l2norm,
    multi_tensor_adam,
    multi_tensor_adam_fp8,
    multi_tensor_adam_capturable,
    multi_tensor_adam_capturable_master,
    multi_tensor_sgd,
)
from .fused_adam import FusedAdam
from .fused_sgd import FusedSGD
from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier

--------------------------------------------------------------------------------
/transformer_engine/pytorch/optimizers/multi_tensor_apply.py:
--------------------------------------------------------------------------------
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Multi-tensor apply entry."""
from torch.distributed._tensor import DTensor


class MultiTensorApply:  # pylint: disable=too-few-public-methods
    """Multi-tensor apply entry."""

    def __init__(self, chunk_size):
        self.chunk_size = chunk_size

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        for i, ts in enumerate(tensor_lists):
            for j, t in enumerate(ts):
                if isinstance(t, DTensor):
                    tensor_lists[i][j] = t._local_tensor

        return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)


multi_tensor_applier = MultiTensorApply(2048 * 32)
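
`multi_tensor_applier` replaces any `DTensor` entries with their local shards and then prepends its chunk size, so each kernel sees `(chunk_size, noop_flag, tensor_lists, *args)`. The sketch below assumes the kernels follow the Apex-style calling convention with `multi_tensor_scale` taking a trailing scale argument; the exact kernel signatures are not part of this listing.

# Usage sketch under the Apex-style signature assumption
# (chunk_size, noop_flag, tensor_lists, *args).
import torch
from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_scale

srcs = [torch.randn(1024, device="cuda") for _ in range(4)]
dsts = [torch.empty_like(t) for t in srcs]
noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")  # kernels skip work if set nonzero

# Scale each src into the matching dst with a single fused kernel launch.
multi_tensor_applier(multi_tensor_scale, noop_flag, [srcs, dsts], 0.5)
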
4 | 5 | """Multi-tensor apply entry.""" 6 | from torch.distributed._tensor import DTensor 7 | 8 | 9 | class MultiTensorApply: # pylint: disable=too-few-public-methods 10 | """Multi-tensor apply entry.""" 11 | 12 | def __init__(self, chunk_size): 13 | self.chunk_size = chunk_size 14 | 15 | def __call__(self, op, noop_flag_buffer, tensor_lists, *args): 16 | for i, ts in enumerate(tensor_lists): 17 | for j, t in enumerate(ts): 18 | if isinstance(t, DTensor): 19 | tensor_lists[i][j] = t._local_tensor 20 | 21 | return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) 22 | 23 | 24 | multi_tensor_applier = MultiTensorApply(2048 * 32) 25 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Custom tensor classes""" 6 | 7 | import torch 8 | 9 | from .quantized_tensor import QuantizedTensor, Quantizer 10 | from .utils import cast_master_weights_to_fp8, replace_raw_data 11 | 12 | __all__ = [ 13 | "QuantizedTensor", 14 | "Quantizer", 15 | ] 16 | 17 | 18 | def _make_module_cast_func(dtype): 19 | """Make module cast function that can handle QuantizedTensor""" 20 | cast_func_name = { 21 | torch.float32: "float", 22 | torch.float16: "half", 23 | torch.bfloat16: "bfloat16", 24 | }[dtype] 25 | 26 | def tensor_cast_func(tensor: torch.Tensor) -> torch.Tensor: 27 | """Cast tensor dtype""" 28 | if isinstance(tensor, QuantizedTensor): 29 | return tensor.__class__.make_like(tensor, dtype=dtype) 30 | if tensor.is_floating_point(): 31 | return getattr(tensor, cast_func_name)() 32 | return tensor 33 | 34 | def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: 35 | """Cast module dtype""" 36 | return self._apply(tensor_cast_func) 37 | 38 | return module_cast_func 39 | 40 | 41 | # Monkey-patch module cast functions to handle QuantizedTensor 42 | torch.nn.Module.float = _make_module_cast_func(torch.float32) 43 | torch.nn.Module.half = _make_module_cast_func(torch.float16) 44 | torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) 45 | 46 | 47 | def get_all_tensor_types(): 48 | """ 49 | Get all tensor-like types that can be used in TE. 50 | """ 51 | from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase 52 | from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase 53 | from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( 54 | Float8BlockwiseQTensor, 55 | Float8BlockwiseQTensorBase, 56 | ) 57 | 58 | all_tensor_types = [ 59 | torch.Tensor, 60 | torch.nn.Parameter, 61 | Float8Tensor, 62 | Float8TensorBase, 63 | MXFP8Tensor, 64 | MXFP8TensorBase, 65 | Float8BlockwiseQTensor, 66 | Float8BlockwiseQTensorBase, 67 | ] 68 | return all_tensor_types 69 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/_internal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 
4 | """Internal data structures for quantized tensors.""" 5 | -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # See LICENSE for license information. 4 | 5 | """Kernels written with OpenAI Triton.""" 6 | --------------------------------------------------------------------------------