├── .clang-format
├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
├── PULL_REQUEST_TEMPLATE.md
└── workflows
│ ├── blossom-ci.yml
│ ├── build.yml
│ ├── deploy_nightly_docs.yml
│ ├── docs.yml
│ ├── license.yml
│ ├── lint.yml
│ ├── trigger-ci.yml
│ └── upload-ci-logs.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── Acknowledgements.txt
├── CONTRIBUTING.rst
├── CPPLINT.cfg
├── LICENSE
├── README.rst
├── SECURITY.md
├── benchmarks
└── attention
│ └── benchmark_attention.py
├── build_tools
├── VERSION.txt
├── __init__.py
├── build_ext.py
├── jax.py
├── pytorch.py
├── te_version.py
├── utils.py
└── wheel_utils
│ ├── Dockerfile.aarch
│ ├── Dockerfile.x86
│ ├── build_wheels.sh
│ ├── launch_aarch.sh
│ └── launch_x86.sh
├── docs
├── .gitignore
├── Doxyfile
├── Makefile
├── _static
│ ├── NVIDIA-LogoBlack.svg
│ └── css
│ │ ├── nvidia_font.css
│ │ └── nvidia_footer.css
├── _templates
│ ├── footer.html
│ └── layout.html
├── api
│ ├── c
│ │ ├── activation.rst
│ │ ├── cast.rst
│ │ ├── cast_transpose_noop.rst
│ │ ├── cudnn.rst
│ │ ├── fused_attn.rst
│ │ ├── fused_rope.rst
│ │ ├── gemm.rst
│ │ ├── index.rst
│ │ ├── multi_tensor.rst
│ │ ├── normalization.rst
│ │ ├── padding.rst
│ │ ├── permutation.rst
│ │ ├── recipe.rst
│ │ ├── softmax.rst
│ │ ├── swizzle.rst
│ │ ├── transformer_engine.rst
│ │ └── transpose.rst
│ ├── common.rst
│ ├── framework.rst
│ ├── jax.rst
│ └── pytorch.rst
├── conf.py
├── debug.rst
├── debug
│ ├── 1_getting_started.rst
│ ├── 2_config_file_structure.rst
│ ├── 3_api_debug_setup.rst
│ ├── 3_api_features.rst
│ ├── 3_api_te_calls.rst
│ ├── 4_distributed.rst
│ ├── api.rst
│ └── img
│ │ ├── api_calls1.svg
│ │ ├── api_calls2.svg
│ │ ├── fake_quant.svg
│ │ ├── introduction.svg
│ │ ├── names.svg
│ │ ├── pipeline_logging.svg
│ │ ├── reduction1.svg
│ │ ├── reduction2.svg
│ │ ├── reduction3.svg
│ │ ├── scaling_factors.svg
│ │ └── tensorboard.png
├── examples
│ ├── E8M0.png
│ ├── H200-NeMo-performance.png
│ ├── MXFP8_FP8_comparison_1.png
│ ├── MXFP8_FP8_comparison_2.png
│ ├── advanced_optimizations.ipynb
│ ├── attention
│ │ ├── arbitrary_mask_to_post_scale_bias.py
│ │ ├── attention.ipynb
│ │ ├── dot_product_attention.png
│ │ └── example_attention.py
│ ├── comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg
│ ├── delayed_scaling.png
│ ├── fp8_formats.png
│ ├── fp8_primer.ipynb
│ ├── linear_mxfp8.png
│ ├── loss_scaling.png
│ ├── quickstart.ipynb
│ ├── quickstart_utils.py
│ ├── te_llama
│ │ ├── media
│ │ │ ├── llama_for_causal_lm.svg
│ │ │ ├── llama_zoom.svg
│ │ │ ├── llamadecoderlayer.svg
│ │ │ ├── model_change.svg
│ │ │ ├── swiglu.svg
│ │ │ ├── swiglu_te.svg
│ │ │ ├── tellamadecoderlayer.svg
│ │ │ ├── transformer_llama.png
│ │ │ ├── transformer_vs_llama.svg
│ │ │ └── weight_swap.svg
│ │ ├── te_llama.py
│ │ ├── tutorial_accelerate_hf_llama_with_te.ipynb
│ │ └── utils.py
│ └── transformer_layer.png
├── faq.rst
├── index.rst
├── installation.rst
└── version_select.patch
├── examples
├── README.md
├── jax
│ ├── README.md
│ ├── encoder
│ │ ├── README.md
│ │ ├── common.py
│ │ ├── conftest.py
│ │ ├── requirements.txt
│ │ ├── run_test_multiprocessing_encoder.sh
│ │ ├── test_model_parallel_encoder.py
│ │ ├── test_multigpu_encoder.py
│ │ ├── test_multiprocessing_encoder.py
│ │ └── test_single_gpu_encoder.py
│ └── mnist
│ │ ├── README.md
│ │ ├── requirements.txt
│ │ └── test_single_gpu_mnist.py
└── pytorch
│ ├── comm_gemm_overlap
│ ├── README.md
│ └── te_layer_with_overlap.py
│ ├── fsdp
│ ├── README.md
│ └── fsdp.py
│ └── mnist
│ ├── README.md
│ └── main.py
├── pylintrc
├── qa
├── L0_cppunittest
│ └── test.sh
├── L0_jax_distributed_unittest
│ └── test.sh
├── L0_jax_lint
│ └── test.sh
├── L0_jax_unittest
│ └── test.sh
├── L0_jax_wheel
│ └── test.sh
├── L0_license
│ ├── config.json
│ ├── copyright_checker.py
│ └── test.sh
├── L0_pytorch_debug_unittest
│ └── test.sh
├── L0_pytorch_lint
│ └── test.sh
├── L0_pytorch_unittest
│ └── test.sh
├── L0_pytorch_wheel
│ └── test.sh
├── L1_jax_distributed_unittest
│ └── test.sh
├── L1_pytorch_distributed_unittest
│ └── test.sh
├── L1_pytorch_mcore_integration
│ ├── .gitignore
│ ├── merges.txt
│ └── test.sh
├── L1_pytorch_thunder_integration
│ └── test.sh
├── L2_jax_unittest
│ └── test.sh
├── L3_pytorch_FA_versions_test
│ └── test.sh
└── format.sh
├── setup.py
├── tests
├── cpp
│ ├── CMakeLists.txt
│ ├── operator
│ │ ├── CMakeLists.txt
│ │ ├── test_act.cu
│ │ ├── test_cast.cu
│ │ ├── test_cast_current_scaling.cu
│ │ ├── test_cast_dbias.cu
│ │ ├── test_cast_dbias_dgelu.cu
│ │ ├── test_cast_float8blockwise.cu
│ │ ├── test_cast_gated_swiglu.cu
│ │ ├── test_cast_mxfp8.cu
│ │ ├── test_cast_mxfp8_gated_swiglu.cu
│ │ ├── test_cast_transpose.cu
│ │ ├── test_cast_transpose_current_scaling.cu
│ │ ├── test_cast_transpose_dbias.cu
│ │ ├── test_cast_transpose_dbias_dgelu.cu
│ │ ├── test_cast_transpose_dgeglu.cu
│ │ ├── test_causal_softmax.cu
│ │ ├── test_dequantize_mxfp8.cu
│ │ ├── test_memset.cu
│ │ ├── test_multi_cast_transpose.cu
│ │ ├── test_multi_padding.cu
│ │ ├── test_normalization.cu
│ │ ├── test_normalization.h
│ │ ├── test_normalization_mxfp8.cu
│ │ ├── test_qdq.cu
│ │ ├── test_swizzle.cu
│ │ └── test_transpose.cu
│ ├── test_common.cu
│ ├── test_common.h
│ └── util
│ │ ├── CMakeLists.txt
│ │ ├── test_nvrtc.cpp
│ │ └── test_string.cpp
├── jax
│ ├── conftest.py
│ ├── distributed_test_base.py
│ ├── pytest.ini
│ ├── test_custom_call_compute.py
│ ├── test_distributed_fused_attn.py
│ ├── test_distributed_layernorm.py
│ ├── test_distributed_layernorm_mlp.py
│ ├── test_distributed_softmax.py
│ ├── test_functions.py
│ ├── test_fused_attn.py
│ ├── test_helper.py
│ ├── test_layer.py
│ ├── test_misc.py
│ ├── test_sanity_import.py
│ ├── test_sharding.py
│ ├── test_softmax.py
│ └── utils.py
└── pytorch
│ ├── debug
│ ├── conftest.py
│ ├── run_distributed.py
│ ├── test_api_features.py
│ ├── test_config.py
│ ├── test_configs
│ │ ├── disable_fp8_gemms.yaml
│ │ ├── disable_fp8_layer.yaml
│ │ ├── dummy_feature.yaml
│ │ ├── fake_quantization_config.yaml
│ │ ├── per_tensor_scaling.yaml
│ │ ├── stats_collection_test_config.yaml
│ │ └── tensor_manipulation_transformer_engine.yaml
│ ├── test_distributed.py
│ ├── test_numerics.py
│ ├── test_sanity.py
│ └── utils.py
│ ├── distributed
│ ├── run_cast_master_weights_to_fp8.py
│ ├── run_fsdp2_model.py
│ ├── run_gemm_with_overlap.py
│ ├── run_layer_with_overlap.py
│ ├── run_numerics.py
│ ├── test_cast_master_weights_to_fp8.py
│ ├── test_comm_gemm_overlap.py
│ ├── test_fusible_ops.py
│ ├── test_fusible_ops_with_userbuffers.py
│ ├── test_numerics.py
│ └── test_torch_fsdp2.py
│ ├── fused_attn
│ ├── run_fused_attn_with_cp.py
│ ├── test_fused_attn.py
│ ├── test_fused_attn_with_cp.py
│ └── test_kv_cache.py
│ ├── references
│ ├── blockwise_fp8_gemm_reference.py
│ ├── blockwise_quantizer_reference.py
│ ├── quantize_scale_calc.py
│ └── ref_per_tensor_cs.py
│ ├── test_cpu_offloading.py
│ ├── test_cuda_graphs.py
│ ├── test_deferred_init.py
│ ├── test_float8_blockwise_gemm_exact.py
│ ├── test_float8_blockwise_scaling_exact.py
│ ├── test_float8_current_scaling_exact.py
│ ├── test_float8blockwisetensor.py
│ ├── test_float8tensor.py
│ ├── test_fused_optimizer.py
│ ├── test_fused_rope.py
│ ├── test_fusible_ops.py
│ ├── test_gqa.py
│ ├── test_hf_integration.py
│ ├── test_jit.py
│ ├── test_multi_tensor.py
│ ├── test_numerics.py
│ ├── test_parallel_cross_entropy.py
│ ├── test_permutation.py
│ ├── test_recipe.py
│ ├── test_sanity.py
│ ├── test_sanity_import.py
│ └── utils.py
└── transformer_engine
├── __init__.py
├── common
├── CMakeLists.txt
├── __init__.py
├── activation
│ ├── activation_template.h
│ ├── gelu.cu
│ ├── relu.cu
│ └── swiglu.cu
├── comm_gemm_overlap
│ ├── comm_gemm_overlap.cpp
│ └── userbuffers
│ │ ├── ipcsocket.cc
│ │ ├── ipcsocket.h
│ │ ├── userbuffers-host.cpp
│ │ ├── userbuffers.cu
│ │ └── userbuffers.h
├── common.cu
├── common.h
├── cudnn_utils.cpp
├── cudnn_utils.h
├── fused_attn
│ ├── context_parallel.cu
│ ├── flash_attn.cu
│ ├── fused_attn.cpp
│ ├── fused_attn_f16_arbitrary_seqlen.cu
│ ├── fused_attn_f16_arbitrary_seqlen.h
│ ├── fused_attn_f16_max512_seqlen.cu
│ ├── fused_attn_f16_max512_seqlen.h
│ ├── fused_attn_fp8.cu
│ ├── fused_attn_fp8.h
│ ├── kv_cache.cu
│ ├── utils.cu
│ └── utils.h
├── fused_rope
│ └── fused_rope.cu
├── fused_softmax
│ ├── scaled_aligned_causal_masked_softmax.cu
│ ├── scaled_masked_softmax.cu
│ └── scaled_upper_triang_masked_softmax.cu
├── gemm
│ └── cublaslt_gemm.cu
├── include
│ └── transformer_engine
│ │ ├── activation.h
│ │ ├── cast.h
│ │ ├── cast_transpose_noop.h
│ │ ├── comm_gemm_overlap.h
│ │ ├── cudnn.h
│ │ ├── fused_attn.h
│ │ ├── fused_rope.h
│ │ ├── gemm.h
│ │ ├── multi_tensor.h
│ │ ├── normalization.h
│ │ ├── padding.h
│ │ ├── permutation.h
│ │ ├── recipe.h
│ │ ├── softmax.h
│ │ ├── swizzle.h
│ │ ├── transformer_engine.h
│ │ └── transpose.h
├── libtransformer_engine.version
├── multi_tensor
│ ├── adam.cu
│ ├── compute_scale.cu
│ ├── l2norm.cu
│ ├── multi_tensor_apply.cuh
│ ├── scale.cu
│ └── sgd.cu
├── normalization
│ ├── common.cpp
│ ├── common.h
│ ├── kernel_traits.h
│ ├── layernorm
│ │ ├── ln_api.cpp
│ │ ├── ln_bwd_kernels.cuh
│ │ ├── ln_bwd_semi_cuda_kernel.cu
│ │ ├── ln_fwd_cuda_kernel.cu
│ │ └── ln_fwd_kernels.cuh
│ └── rmsnorm
│ │ ├── rmsnorm_api.cpp
│ │ ├── rmsnorm_bwd_kernels.cuh
│ │ ├── rmsnorm_bwd_semi_cuda_kernel.cu
│ │ ├── rmsnorm_fwd_cuda_kernel.cu
│ │ └── rmsnorm_fwd_kernels.cuh
├── nvshmem_api
│ ├── CMakeLists.txt
│ ├── nvshmem_waitkernel.cu
│ └── nvshmem_waitkernel.h
├── nvtx.h
├── permutation
│ └── permutation.cu
├── recipe
│ ├── __init__.py
│ ├── current_scaling.cu
│ ├── delayed_scaling.cu
│ ├── fp8_block_scaling.cu
│ └── recipe_common.cuh
├── swizzle
│ └── swizzle.cu
├── transformer_engine.cpp
├── transpose
│ ├── cast_transpose.cu
│ ├── cast_transpose.h
│ ├── cast_transpose_fusion.cu
│ ├── multi_cast_transpose.cu
│ ├── quantize_transpose_square_blockwise.cu
│ ├── quantize_transpose_vector_blockwise.cu
│ ├── rtc
│ │ ├── cast_transpose.cu
│ │ ├── cast_transpose_fusion.cu
│ │ └── transpose.cu
│ ├── transpose.cu
│ └── transpose_fusion.cu
├── util
│ ├── cast.cu
│ ├── cast_gated_kernels.cuh
│ ├── cast_kernels.cuh
│ ├── cuda_driver.cpp
│ ├── cuda_driver.h
│ ├── cuda_nvml.cpp
│ ├── cuda_nvml.h
│ ├── cuda_runtime.cpp
│ ├── cuda_runtime.h
│ ├── dequantize_kernels.cuh
│ ├── handle_manager.h
│ ├── logging.h
│ ├── math.h
│ ├── padding.cu
│ ├── ptx.cuh
│ ├── pybind_helper.h
│ ├── rtc.cpp
│ ├── rtc.h
│ ├── shared_lib_wrapper.h
│ ├── string.h
│ ├── string_header.h.in
│ ├── system.h
│ └── vectorized_pointwise.h
├── utils.cuh
└── utils.py
├── debug
├── __init__.py
├── features
│ ├── __init__.py
│ ├── _test_dummy_feature.py
│ ├── api.py
│ ├── disable_fp8_gemm.py
│ ├── disable_fp8_layer.py
│ ├── fake_quant.py
│ ├── log_fp8_tensor_stats.py
│ ├── log_tensor_stats.py
│ ├── per_tensor_scaling.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── stats_buffer.py
│ │ └── stats_computation.py
└── pytorch
│ ├── __init__.py
│ ├── debug_quantization.py
│ ├── debug_state.py
│ └── utils.py
├── jax
├── MANIFEST.in
├── __init__.py
├── activation.py
├── attention.py
├── cpp_extensions
│ ├── __init__.py
│ ├── activation.py
│ ├── attention.py
│ ├── base.py
│ ├── gemm.py
│ ├── misc.py
│ ├── normalization.py
│ ├── quantization.py
│ └── softmax.py
├── csrc
│ ├── extensions.h
│ └── extensions
│ │ ├── activation.cpp
│ │ ├── attention.cpp
│ │ ├── cublas.cpp
│ │ ├── cudnn.cpp
│ │ ├── ffi.cpp
│ │ ├── ffi.h
│ │ ├── gemm.cpp
│ │ ├── misc.cpp
│ │ ├── misc.h
│ │ ├── normalization.cpp
│ │ ├── pybind.cpp
│ │ ├── quantization.cpp
│ │ ├── softmax.cpp
│ │ ├── utils.cpp
│ │ └── utils.h
├── dense.py
├── flax
│ ├── __init__.py
│ ├── module.py
│ └── transformer.py
├── layernorm.py
├── layernorm_dense.py
├── layernorm_mlp.py
├── quantize
│ ├── __init__.py
│ ├── dequantizer.py
│ ├── helper.py
│ ├── metadata.py
│ ├── quantizer.py
│ ├── scaling_modes.py
│ └── tensor.py
├── setup.py
├── sharding.py
└── softmax.py
└── pytorch
├── MANIFEST.in
├── __init__.py
├── attention
├── __init__.py
├── dot_product_attention
│ ├── __init__.py
│ ├── backends.py
│ ├── context_parallel.py
│ ├── dot_product_attention.py
│ ├── softmax.py
│ └── utils.py
├── inference.py
├── multi_head_attention.py
└── rope.py
├── constants.py
├── cpp_extensions
├── __init__.py
├── fused_attn.py
└── gemm.py
├── cpu_offload.py
├── cross_entropy.py
├── csrc
├── common.cpp
├── common.h
├── extensions.h
├── extensions
│ ├── activation.cpp
│ ├── apply_rope.cpp
│ ├── attention.cpp
│ ├── bias.cpp
│ ├── cast.cpp
│ ├── comm_gemm_overlap.cpp
│ ├── fp8_block_scaling_partial_cast.cpp
│ ├── gemm.cpp
│ ├── misc.cpp
│ ├── multi_tensor
│ │ ├── adam.cpp
│ │ ├── compute_scale.cpp
│ │ ├── l2norm.cpp
│ │ ├── scale.cpp
│ │ └── sgd.cpp
│ ├── normalization.cpp
│ ├── nvshmem_comm.cpp
│ ├── padding.cpp
│ ├── permutation.cpp
│ ├── pybind.cpp
│ ├── recipe.cpp
│ ├── softmax.cpp
│ └── transpose.cpp
├── pybind.h
├── quantizer.cpp
├── type_converters.cpp
├── util.cpp
└── util.h
├── distributed.py
├── float8_tensor.py
├── fp8.py
├── graph.py
├── jit.py
├── module
├── __init__.py
├── _common.py
├── base.py
├── fp8_padding.py
├── fp8_unpadding.py
├── grouped_linear.py
├── layernorm.py
├── layernorm_linear.py
├── layernorm_mlp.py
├── linear.py
└── rmsnorm.py
├── numerics_debug.py
├── ops
├── __init__.py
├── _common.py
├── basic
│ ├── __init__.py
│ ├── activation.py
│ ├── add_in_place.py
│ ├── all_gather.py
│ ├── all_reduce.py
│ ├── basic_linear.py
│ ├── bias.py
│ ├── identity.py
│ ├── layer_norm.py
│ ├── make_extra_output.py
│ ├── quantize.py
│ ├── reduce_scatter.py
│ ├── reshape.py
│ └── rmsnorm.py
├── fused
│ ├── __init__.py
│ ├── backward_linear_add.py
│ ├── forward_linear_bias_activation.py
│ ├── forward_linear_bias_add.py
│ ├── userbuffers_backward_linear.py
│ └── userbuffers_forward_linear.py
├── fuser.py
├── linear.py
├── op.py
└── sequential.py
├── optimizers
├── __init__.py
├── fused_adam.py
├── fused_sgd.py
└── multi_tensor_apply.py
├── permutation.py
├── setup.py
├── tensor
├── __init__.py
├── _internal
│ ├── __init__.py
│ ├── float8_blockwise_tensor_base.py
│ ├── float8_tensor_base.py
│ └── mxfp8_tensor_base.py
├── float8_blockwise_tensor.py
├── float8_tensor.py
├── mxfp8_tensor.py
├── quantized_tensor.py
└── utils.py
├── transformer.py
├── triton
├── __init__.py
├── cross_entropy.py
└── permutation.py
└── utils.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 |
12 | A clear and concise description of what the bug is.
13 |
14 | **Steps/Code to reproduce bug**
15 |
16 | Please list *minimal* steps or code snippet for us to be able to reproduce the bug.
17 |
18 | A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
19 |
20 |
21 | **Expected behavior**
22 |
23 | A clear and concise description of what you expected to happen.
24 |
25 | **Environment overview (please complete the following information)**
26 |
27 | - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider - AWS, Azure, GCP, Colab)]
28 | - Method of Transformer Engine install: [pip install or from source]. Please specify exact commands you used to install.
29 | - If method of install is [Docker], provide `docker pull` & `docker run` commands used
30 |
31 | **Environment details**
32 |
33 | If an NVIDIA Docker image is used, you don't need to specify these.
34 | Otherwise, please provide:
35 | - OS version
36 | - PyTorch version
37 | - Python version
38 | - Transformer Engine version
39 | - CUDA version
40 | - CUDNN version
41 |
42 | **Device details**
43 | - GPU model
44 |
45 | **Additional context**
46 |
47 | Add any other context about the problem here.
48 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: feature request
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 |
12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
13 |
14 | **Describe the solution you'd like**
15 |
16 | A clear and concise description of what you want to happen.
17 | Provide a code snippet on how new APIs/changes would be used by others.
18 |
19 | **Describe alternatives you've considered**
20 |
21 | A clear and concise description of any alternative solutions or features you've considered.
22 |
23 | **Additional context**
24 |
25 | Add any other context or screenshots about the feature request here.
26 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Please include a brief summary of the changes, relevant motivation and context.
4 |
5 | Fixes # (issue)
6 |
7 | ## Type of change
8 |
9 | - [ ] Documentation change (change only to the documentation, either a fix or new content)
10 | - [ ] Bug fix (non-breaking change which fixes an issue)
11 | - [ ] New feature (non-breaking change which adds functionality)
12 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
13 | - [ ] Infra/Build change
14 | - [ ] Code refactoring
15 |
16 | ## Changes
17 |
18 | Please list the changes introduced in this PR:
19 |
20 | - Change A
21 | - Change B
22 |
23 | # Checklist:
24 |
25 | - [ ] I have read and followed the [contributing guidelines](https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst)
26 | - [ ] The functionality is complete
27 | - [ ] I have commented my code, particularly in hard-to-understand areas
28 | - [ ] I have made corresponding changes to the documentation
29 | - [ ] My changes generate no new warnings
30 | - [ ] I have added tests that prove my fix is effective or that my feature works
31 | - [ ] New and existing unit tests pass locally with my changes
32 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | # A workflow to trigger TE build on GitHub
6 | name: 'Build'
7 | on:
8 | pull_request:
9 | workflow_dispatch:
10 | jobs:
11 | core:
12 | name: 'Core'
13 | runs-on: ubuntu-latest
14 | container:
15 | image: nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04
16 | options: --user root
17 | steps:
18 | - name: 'Dependencies'
19 | run: |
20 | apt-get update
21 | apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
22 | pip install cmake==3.21.0
23 | - name: 'Checkout'
24 | uses: actions/checkout@v3
25 | with:
26 | submodules: recursive
27 | - name: 'Build'
28 | run: pip install --no-build-isolation . -v
29 | env:
30 | NVTE_FRAMEWORK: none
31 | MAX_JOBS: 1
32 | - name: 'Sanity check'
33 | run: python3 -c "import transformer_engine"
34 | working-directory: /
35 | pytorch:
36 | name: 'PyTorch'
37 | runs-on: ubuntu-latest
38 | container:
39 | image: nvcr.io/nvidia/cuda:12.8.0-devel-ubuntu22.04
40 | options: --user root
41 | steps:
42 | - name: 'Dependencies'
43 | run: |
44 | apt-get update
45 | apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
46 | pip install cmake torch pydantic "importlib-metadata>=1.0" packaging pybind11
47 | - name: 'Checkout'
48 | uses: actions/checkout@v3
49 | with:
50 | submodules: recursive
51 | - name: 'Build'
52 | run: pip install --no-build-isolation . -v --no-deps
53 | env:
54 | NVTE_FRAMEWORK: pytorch
55 | MAX_JOBS: 1
56 | - name: 'Sanity check'
57 | if: false # Sanity import test requires Flash Attention
58 | run: python3 tests/pytorch/test_sanity_import.py
59 | jax:
60 | name: 'JAX'
61 | runs-on: ubuntu-latest
62 | container:
63 | image: ghcr.io/nvidia/jax:jax
64 | options: --user root
65 | steps:
66 | - name: 'Checkout'
67 | uses: actions/checkout@v3
68 | with:
69 | submodules: recursive
70 | - name: 'Build'
71 | run: pip install --no-build-isolation . -v
72 | env:
73 | NVTE_FRAMEWORK: jax
74 | MAX_JOBS: 1
75 | - name: 'Sanity check'
76 | run: python tests/jax/test_sanity_import.py
77 |
--------------------------------------------------------------------------------
/.github/workflows/deploy_nightly_docs.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | # A workflow to deploy the nightly version of TE documentation to GitHub Pages
6 | name: Deploy nightly docs
7 | on:
8 | push:
9 | branches: [ "main" ]
10 | jobs:
11 | build:
12 | uses: ./.github/workflows/docs.yml
13 |
14 | prepare:
15 | needs: build
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Download artifact
19 | uses: actions/download-artifact@v4
20 | with:
21 | name: "te_docs"
22 | path: "html"
23 | - name: Prepare for pages
24 | uses: actions/upload-pages-artifact@v1.0.7
25 | with:
26 | name: github-pages
27 | path: "html"
28 | deploy:
29 | needs: prepare
30 | environment:
31 | name: github-pages
32 | url: ${{ steps.deployment.outputs.page_url }}
33 | permissions:
34 | pages: write
35 | id-token: write
36 | runs-on: ubuntu-latest
37 | steps:
38 | - name: Deploy
39 | uses: actions/deploy-pages@v2.0.0
40 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | # A workflow to trigger the build of TE documentation on GitHub
6 | name: 'Documentation'
7 | on:
8 | pull_request:
9 | workflow_dispatch:
10 | workflow_call:
11 | jobs:
12 | build_docs:
13 | name: 'Build'
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: 'Checkout'
17 | uses: actions/checkout@v3
18 | - name: 'Install dependencies'
19 | run: |
20 | pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2
21 | pip install breathe==4.35.0 sphinx-autoapi==3.3.2
22 | sudo apt-get install -y pandoc graphviz doxygen
23 | export GIT_SHA=$(git show-ref --hash HEAD)
24 | - name: 'Build docs'
25 | run: |
26 | doxygen docs/Doxyfile
27 | cd docs
28 | make html
29 | - name: 'Upload docs'
30 | uses: actions/upload-artifact@v4
31 | with:
32 | name: te_docs
33 | path: docs/_build/html
34 | retention-days: 7
35 |
--------------------------------------------------------------------------------
/.github/workflows/license.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | # A workflow to trigger the TE license check on GitHub
6 | name: 'License'
7 | on:
8 | pull_request:
9 | workflow_dispatch:
10 | jobs:
11 | check:
12 | name: 'Check'
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: 'Checkout'
16 | uses: actions/checkout@v3
17 | - name: 'Check License'
18 | run: |
19 | export TE_PATH=.
20 | bash ./qa/L0_license/test.sh
21 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | # A workflow to trigger lint tests on GitHub
6 | name: 'Lint'
7 | on:
8 | pull_request:
9 | workflow_dispatch:
10 | jobs:
11 | pytorch_cpplint:
12 | name: 'PyTorch C++'
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Checkout
16 | uses: actions/checkout@v3
17 | - name: 'Lint'
18 | run: |
19 | sudo apt-get update
20 | sudo apt-get install pip -y
21 | export CPP_ONLY=1
22 | export TE_PATH=.
23 | bash ./qa/L0_pytorch_lint/test.sh
24 | pytorch_pylint:
25 | name: 'PyTorch Python'
26 | runs-on: ubuntu-latest
27 | steps:
28 | - name: 'Checkout'
29 | uses: actions/checkout@v3
30 | - name: 'Lint'
31 | run: |
32 | sudo apt-get update
33 | sudo apt-get install pip -y
34 | pip install torch numpy
35 | export PYTHON_ONLY=1
36 | export TE_PATH=.
37 | bash ./qa/L0_pytorch_lint/test.sh
38 | jax_cpplint:
39 | name: 'JAX C++'
40 | runs-on: ubuntu-latest
41 | steps:
42 | - name: 'Checkout'
43 | uses: actions/checkout@v3
44 | - name: 'Lint'
45 | run: |
46 | sudo apt-get update
47 | sudo apt-get install pip -y
48 | export CPP_ONLY=1
49 | export TE_PATH=.
50 | bash ./qa/L0_jax_lint/test.sh
51 | jax_pylint:
52 | name: 'JAX Python'
53 | runs-on: ubuntu-latest
54 | steps:
55 | - name: 'Checkout'
56 | uses: actions/checkout@v3
57 | - name: 'Lint'
58 | run: |
59 | sudo apt-get update
60 | sudo apt-get install pip -y
61 | export PYTHON_ONLY=1
62 | export TE_PATH=.
63 | bash ./qa/L0_jax_lint/test.sh
64 |
--------------------------------------------------------------------------------
/.github/workflows/upload-ci-logs.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | # A workflow to post-process CI logs and update the commit status on hybrid infra (github + self hosted runner)
6 | name: TE-CI Logs
7 | on:
8 | workflow_dispatch:
9 | inputs:
10 | platform:
11 | description: 'runs-on argument'
12 | required: false
13 | args:
14 | description: 'argument'
15 | required: false
16 | job_name:
17 | description: 'name of the job'
18 | required: true
19 | commit_sha:
20 | description: 'SHA of the commit that was tested.'
21 | required: true
22 | result:
23 | description: 'Job result'
24 | required: true
25 | run-name: PR ${{ fromJson(github.event.inputs.args).pr }} - ${{ inputs.job_name }}
26 | jobs:
27 | Upload-Log:
28 | name: Upload log
29 | runs-on: blossom
30 | steps:
31 | - name: Log
32 | run: blossom-ci
33 | env:
34 | OPERATION: 'POST-PROCESSING'
35 | CI_SERVER: ${{ secrets.CI_SERVER }}
36 | REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
37 | status_update:
38 | name: Update commit status
39 | runs-on: ubuntu-latest
40 | permissions:
41 | statuses: write
42 | needs: [Upload-Log]
43 | if: ${{ always() }}
44 | steps:
45 | - name: Set status
46 | run: |
47 | curl \
48 | -X POST \
49 | -H "Accept: application/vnd.github+json" \
50 | -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
51 | https://api.github.com/repos/${{ github.repository }}/statuses/${{ inputs.commit_sha }} \
52 | -d "{\"state\":\"${{ inputs.result }}\",\"target_url\":\"${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}\",\"description\":\"\",\"context\":\"te-ci/${{ inputs.job_name }}\"}"
53 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.o
2 | *.swp
3 | *.ii
4 | *.ptx
5 | *.cubin
6 | *.fatbin*
7 | *.module_id
8 | *.nsys-rep
9 | *.ncu-rep
10 | *.sqlite
11 | *.eggs
12 | build/
13 | *.so
14 | *.egg-info
15 | __pycache__
16 | .ycm_extra_conf.py
17 | .vimrc
18 | .vs
19 | .vscode
20 | .cache
21 | .hypothesis
22 | .devcontainer.json
23 | tests/cpp/build/
24 | .ipynb_checkpoints
25 | *.log
26 | CMakeFiles/CMakeSystem.cmake
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 | develop-eggs/
36 | dist/
37 | downloads/
38 | .pytest_cache/
39 | compile_commands.json
40 | .nfs
41 | tensor_dumps/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "3rdparty/googletest"]
2 | path = 3rdparty/googletest
3 | url = https://github.com/google/googletest.git
4 | [submodule "3rdparty/cudnn-frontend"]
5 | path = 3rdparty/cudnn-frontend
6 | url = https://github.com/NVIDIA/cudnn-frontend.git
7 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | default_language_version:
6 | python: python3
7 |
8 | ci:
9 | autofix_prs: true
10 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
11 | autoupdate_schedule: quarterly
12 | submodules: false
13 | skip: []
14 |
15 | repos:
16 | - repo: https://github.com/pre-commit/pre-commit-hooks
17 | rev: v4.6.0
18 | hooks:
19 | - id: check-merge-conflict
20 | - id: check-added-large-files
21 | - id: end-of-file-fixer
22 | files: .*\.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$
23 | - id: trailing-whitespace
24 | files: .*\.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$
25 |
26 | - repo: https://github.com/psf/black
27 | rev: 24.4.2
28 | hooks:
29 | - id: black
30 | name: Format python code
31 | args: [--line-length=100, --preview, --enable-unstable-feature=string_processing]
32 | types: [python]
33 |
34 | - repo: https://github.com/pre-commit/mirrors-clang-format
35 | rev: v18.1.6
36 | hooks:
37 | - id: clang-format
38 | entry: clang-format -i
39 | args: ["-style=file"]
40 | files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$
41 |
--------------------------------------------------------------------------------
/CPPLINT.cfg:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | # Stop searching for additional config files.
6 | set noparent
7 |
8 | # Limit line length.
9 | linelength=100
10 |
11 | # Ignore the following errors.
12 | filter=-build/include_subdir
13 | filter=-build/namespaces
14 | filter=-readability/todo
15 | filter=-build/header_guard
16 | filter=-build/include
17 | filter=-build/c++11
18 | filter=-runtime/references
19 | filter=-whitespace
20 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization.
4 |
5 | If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub/GitLab.**
6 |
7 | ## Reporting Potential Security Vulnerability in an NVIDIA Product
8 |
9 | To report a potential security vulnerability in any NVIDIA product:
10 | - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html)
11 | - E-Mail: psirt@nvidia.com
12 | - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key)
13 | - Please include the following information:
14 | - Product/Driver name and version/branch that contains the vulnerability
15 | - Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
16 | - Instructions to reproduce the vulnerability
17 | - Proof-of-concept or exploit code
18 | - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability
19 |
20 | While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information.
21 |
22 | ## NVIDIA Product Security
23 |
24 | For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security
25 |
--------------------------------------------------------------------------------
/build_tools/VERSION.txt:
--------------------------------------------------------------------------------
1 | 2.5.0.dev0
2 |
--------------------------------------------------------------------------------
/build_tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | """Build related infrastructure."""
6 |
--------------------------------------------------------------------------------
/build_tools/jax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | """JAX related extensions."""
6 | import os
7 | import shutil
8 | from pathlib import Path
9 |
10 | import setuptools
11 |
12 | from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled
13 | from typing import List
14 |
15 |
def xla_path() -> str:
    """XLA root path lookup.

    Lookup order:

    1. The include directory reported by JAX's FFI module. ``jax.ffi`` is
       tried first and ``jax.extend.ffi`` second, so both newer and older
       JAX releases are supported.
    2. If neither module can be imported: the ``XLA_HOME`` environment
       variable, falling back to ``/opt/xla`` when it is unset or empty.

    Returns:
        The resolved directory, always as a ``str`` (matching the annotation)
        regardless of which lookup produced it.

    Raises:
        FileNotFoundError: If the resolved path is not an existing directory.
    """
    try:
        try:
            # Location of the FFI helpers in recent JAX releases.
            from jax import ffi
        except ImportError:
            # Location of the FFI helpers in older JAX releases.
            from jax.extend import ffi
    except ImportError:
        # JAX (or its FFI module) is unavailable: rely on the environment.
        xla_home = os.getenv("XLA_HOME") or "/opt/xla"
    else:
        xla_home = ffi.include_dir()

    # Normalize to ``str`` so every branch honours the declared return type.
    xla_home = str(xla_home)
    if not os.path.isdir(xla_home):
        raise FileNotFoundError("Could not find xla source.")
    return xla_home
33 |
34 |
def setup_jax_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup PyBind11 extension for JAX support.

    Args:
        csrc_source_files: Directory of the JAX C++ sources. Sources are
            collected from its ``extensions`` subdirectory via
            ``all_files_in_dir(..., name_extension="cpp")``.
        csrc_header_files: Include directory for the JAX C++ sources.
        common_header_files: Root include directory; its ``common`` and
            ``common/include`` subdirectories are added as include paths too.

        Each argument may be a ``str`` or a ``pathlib.Path``.

    Returns:
        An extension object named ``transformer_engine_jax``.

    Raises:
        FileNotFoundError: Propagated from ``xla_path()`` when no XLA include
            directory can be located.
    """
    # Normalize every path argument up front. Previously only
    # ``csrc_source_files`` was converted, so passing ``common_header_files``
    # as a plain string failed on the ``/`` operator below.
    csrc_source_files = Path(csrc_source_files)
    csrc_header_files = Path(csrc_header_files)
    common_header_files = Path(common_header_files)

    # Source files
    extensions_dir = csrc_source_files / "extensions"
    sources = all_files_in_dir(extensions_dir, name_extension="cpp")

    # Header files
    # Copy the helper's result so that extending it here cannot mutate a list
    # the helper might share or cache (its implementation is not visible here).
    include_dirs = list(get_cuda_include_dirs())
    include_dirs.extend(
        [
            common_header_files,
            common_header_files / "common",
            common_header_files / "common" / "include",
            csrc_header_files,
            xla_path(),
        ]
    )

    # Compile flags
    cxx_flags = ["-O3"]
    if debug_build_enabled():
        # Keep debug symbols and undefine NDEBUG for debug builds.
        cxx_flags.append("-g")
        cxx_flags.append("-UNDEBUG")
    else:
        cxx_flags.append("-g0")

    # Define TE/JAX as a Pybind11Extension
    from pybind11.setup_helpers import Pybind11Extension

    class Pybind11CPPExtension(Pybind11Extension):
        """Modified Pybind11Extension to allow custom CXX flags."""

        def _add_cflags(self, flags: List[str]) -> None:
            # ``extra_compile_args`` is passed below as a dict keyed by "cxx".
            # In that case merge the extra flags into the "cxx" entry;
            # otherwise prepend them to the plain list.
            if isinstance(self.extra_compile_args, dict):
                cxx_flags = self.extra_compile_args.pop("cxx", [])
                cxx_flags += flags
                self.extra_compile_args["cxx"] = cxx_flags
            else:
                self.extra_compile_args[:0] = flags

    return Pybind11CPPExtension(
        "transformer_engine_jax",
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        extra_compile_args={"cxx": cxx_flags},
    )
86 |
--------------------------------------------------------------------------------
/build_tools/te_version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
5 | """Transformer Engine version string."""
6 | import os
7 | from pathlib import Path
8 | import subprocess
9 |
10 |
def te_version() -> str:
    """Transformer Engine version string

    The base version is the first line of ``VERSION.txt`` located next to
    this file. The short Git commit hash of ``HEAD`` is appended as a local
    version (``<version>+<commit>``) unless one of the following holds:

    * the ``NVTE_NO_LOCAL_VERSION`` environment variable is a non-zero integer,
    * the ``NVTE_RELEASE_BUILD`` environment variable is a non-zero integer,
    * ``git`` cannot be executed or exits with an error.

    An environment variable that is unset or set to the empty string counts
    as ``0``.

    Returns:
        The version string, with or without the ``+<commit>`` suffix.

    """

    def _env_flag(name: str) -> bool:
        # ``or "0"`` makes an empty value behave like an unset one instead of
        # raising ValueError inside ``int``.
        return bool(int(os.getenv(name) or "0"))

    root_path = Path(__file__).resolve().parent
    # Explicit encoding keeps the result independent of the build host's locale.
    with open(root_path / "VERSION.txt", "r", encoding="utf-8") as f:
        version = f.readline().strip()

    if _env_flag("NVTE_NO_LOCAL_VERSION") or _env_flag("NVTE_RELEASE_BUILD"):
        return version

    try:
        output = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            capture_output=True,
            cwd=root_path,
            check=True,
            text=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # Best effort: no usable git checkout, so keep the plain version.
        return version

    commit = output.stdout.strip()
    return f"{version}+{commit}"
38 |
--------------------------------------------------------------------------------
/build_tools/wheel_utils/Dockerfile.aarch:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
# Image used to build Transformer Engine wheels for aarch64 on top of the
# PyPA manylinux_2_28 base image.
FROM quay.io/pypa/manylinux_2_28_aarch64

WORKDIR /TransformerEngine/
# Copy the whole build context into the image. Docker cannot read outside the
# build context (leading "../" components of a source path are stripped), so
# the context root is spelled "." explicitly. launch_aarch.sh passes "." as
# the context while referring to this file as
# build_tools/wheel_utils/Dockerfile.aarch.
COPY . /TransformerEngine/

# CUDA package version suffix and architecture used in the dnf package names below.
ARG VER="12-3"
ARG ARCH="aarch64"
RUN dnf -y install vim

# Cuda toolkit, cudnn, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
    cuda-libraries-${VER}.${ARCH} \
    cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
# Register the CUDA library directory with the dynamic linker configuration.
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.aarch64 glog-devel.aarch64

ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
# Read by build_tools/te_version.py: a non-zero value suppresses the
# "+<git commit>" local version suffix.
ENV NVTE_RELEASE_BUILD=1

# Arguments: target platform tag followed by five boolean switches.
# NOTE(review): the meaning of each switch is defined in build_wheels.sh, which
# is not visible here - confirm there before changing them.
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"]
37 |
--------------------------------------------------------------------------------
/build_tools/wheel_utils/Dockerfile.x86:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
# Image used to build Transformer Engine wheels for x86_64 on top of the
# PyPA manylinux_2_28 base image.
FROM quay.io/pypa/manylinux_2_28_x86_64

WORKDIR /TransformerEngine/
# Copy the whole build context into the image. Docker cannot read outside the
# build context (leading "../" components of a source path are stripped), so
# the context root is spelled "." explicitly. launch_x86.sh passes "." as the
# context while referring to this file as
# build_tools/wheel_utils/Dockerfile.x86.
COPY . /TransformerEngine/

# CUDA package version suffix and architecture used in the dnf package names below.
ARG VER="12-3"
ARG ARCH="x86_64"
RUN dnf -y install vim

# Cuda toolkit, cudnn, driver.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
    cuda-libraries-${VER}.${ARCH} \
    cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
# Register the CUDA library directory with the dynamic linker configuration.
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.x86_64 glog-devel.x86_64

ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
# Read by build_tools/te_version.py: a non-zero value suppresses the
# "+<git commit>" local version suffix.
ENV NVTE_RELEASE_BUILD=1

# Arguments: target platform tag followed by five boolean switches.
# NOTE(review): the meaning of each switch is defined in build_wheels.sh, which
# is not visible here - confirm there before changing them.
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"]
37 |
--------------------------------------------------------------------------------
/build_tools/wheel_utils/launch_aarch.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
# Build the aarch64 wheel image, run it, and copy /wheelhouse out of the
# container into ./aarch_wheelhouse.
# The relative Dockerfile path and the "." build context mean this script is
# expected to be run from the repository root.

# Stop at the first failing command so a failed build or run never ends with
# the old wheelhouse deleted or artifacts copied from a stale image.
set -e

docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .

# Create the container explicitly and remember its ID. Picking "the newest
# container" via `docker ps` afterwards can select an unrelated container
# that was started in the meantime.
container_id=$(docker create --runtime=nvidia --gpus=all --ipc=host "aarch_wheel")
docker start --attach "${container_id}"

rm -rf aarch_wheelhouse
docker cp "${container_id}:/wheelhouse/" aarch_wheelhouse
9 |
--------------------------------------------------------------------------------
/build_tools/wheel_utils/launch_x86.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # See LICENSE for license information.
4 |
# Build the x86_64 wheel image, run it, and copy /wheelhouse out of the
# container into ./x86_wheelhouse.
# The relative Dockerfile path and the "." build context mean this script is
# expected to be run from the repository root.

# Stop at the first failing command so a failed build or run never ends with
# the old wheelhouse deleted or artifacts copied from a stale image.
set -e

docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .

# Create the container explicitly and remember its ID. Picking "the newest
# container" via `docker ps` afterwards can select an unrelated container
# that was started in the meantime.
container_id=$(docker create --runtime=nvidia --gpus=all --ipc=host "x86_wheel")
docker start --attach "${container_id}"

rm -rf x86_wheelhouse
docker cp "${container_id}:/wheelhouse/" x86_wheelhouse
9 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _build
2 | doxygen
3 | sphinx_rtd_theme
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# The patched theme checkout is a prerequisite and is put on PYTHONPATH so
# that Sphinx picks it up.
%: Makefile sphinx_rtd_theme
	PYTHONPATH=sphinx_rtd_theme:$(PYTHONPATH) $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# Patch Sphinx RTD theme 3.0.1 to add version selector in sidebar.
# If the patch does not apply, remove the checkout again: otherwise the
# directory would already exist on the next run and the unpatched theme would
# be used silently.
sphinx_rtd_theme:
	git clone --depth=1 -b 3.0.1 --single-branch https://github.com/readthedocs/sphinx_rtd_theme.git
	(cd sphinx_rtd_theme && git apply ../version_select.patch) || { rm -rf sphinx_rtd_theme; exit 1; }
26 |
--------------------------------------------------------------------------------
/docs/_static/NVIDIA-LogoBlack.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/_static/css/nvidia_footer.css:
--------------------------------------------------------------------------------
/* Image shown in the footer: fixed width, nudged left, spaced from the text below. */
footer img {
    position: relative;
    left: -9px;
    display: block;
    width: 137.5px;
    margin: 0 0 15px;
}

/* Footer paragraphs: small grey text. */
footer p {
    font-size: 12px;
    font-weight: normal;
    line-height: 1.25em;
    color: #666;
}

/* Every footer paragraph except ".notices" flows inline without margins. */
footer p:not(.notices) {
    margin: 0;
    display: inline;
}

/* Footer links keep the text colour in every state, including hover. */
footer p a,
footer p a:link,
footer p a:visited,
footer p a:hover {
    color: #666;
}
30 |
--------------------------------------------------------------------------------
/docs/_templates/footer.html:
--------------------------------------------------------------------------------
1 | {% extends '!footer.html' %}
2 |
3 | {% block contentinfo %}
4 |
5 |
6 | Privacy Policy 7 | | 8 | Manage My Privacy 9 | | 10 | Do Not Sell or Share My Data 11 | | 12 | Terms of Service 13 | | 14 | Accessibility 15 | | 16 | Corporate Policies 17 | | 18 | Product Security 19 | | 20 | Contact 21 |
22 | {{ super() }} 23 | {% endblock %} 24 | -------------------------------------------------------------------------------- /docs/api/c/activation.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | activation.h 7 | ============ 8 | 9 | .. doxygenfile:: activation.h 10 | -------------------------------------------------------------------------------- /docs/api/c/cast.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | cast.h 7 | ====== 8 | 9 | .. doxygenfile:: cast.h 10 | -------------------------------------------------------------------------------- /docs/api/c/cast_transpose_noop.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | cast_transpose_noop.h 7 | ===================== 8 | 9 | .. doxygenfile:: cast_transpose_noop.h 10 | -------------------------------------------------------------------------------- /docs/api/c/cudnn.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | cudnn.h 7 | ======= 8 | 9 | .. doxygenfile:: cudnn.h 10 | -------------------------------------------------------------------------------- /docs/api/c/fused_attn.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 
5 | 6 | fused_attn.h 7 | ============ 8 | 9 | .. doxygenfile:: fused_attn.h 10 | -------------------------------------------------------------------------------- /docs/api/c/fused_rope.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | fused_rope.h 7 | ============ 8 | 9 | .. doxygenfile:: fused_rope.h 10 | 11 | -------------------------------------------------------------------------------- /docs/api/c/gemm.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | gemm.h 7 | ====== 8 | 9 | .. doxygenfile:: gemm.h 10 | -------------------------------------------------------------------------------- /docs/api/c/index.rst: -------------------------------------------------------------------------------- 1 | .. 2 | Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | 4 | See LICENSE for license information. 5 | 6 | C/C++ API 7 | ========= 8 | 9 | The C/C++ API allows you to access the custom kernels defined in `libtransformer_engine.so` library 10 | directly from C/C++, without Python. 11 | 12 | .. toctree:: 13 | :caption: Headers 14 | 15 | transformer_engine.h