├── .clang-format ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── blossom-ci.yml │ ├── build.yml │ ├── deploy_nightly_docs.yml │ ├── docs.yml │ ├── license.yml │ ├── lint.yml │ ├── rocm-ci.yml │ ├── trigger-ci.yml │ └── upload-ci-logs.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── Acknowledgements.txt ├── CONTRIBUTING.rst ├── CPPLINT.cfg ├── LICENSE ├── README.rst ├── SECURITY.md ├── benchmarks └── attention │ ├── benchmark_attention.py │ └── benchmark_attention_rocm.py ├── ci ├── README.md ├── _utils.sh ├── ci_config.json ├── core.sh ├── jax.sh └── pytorch.sh ├── docs ├── .gitignore ├── Doxyfile ├── Makefile ├── _static │ ├── NVIDIA-LogoBlack.svg │ └── css │ │ ├── nvidia_font.css │ │ └── nvidia_footer.css ├── _templates │ ├── footer.html │ └── layout.html ├── api │ ├── c │ │ ├── activation.rst │ │ ├── cast.rst │ │ ├── fused_attn.rst │ │ ├── fused_rope.rst │ │ ├── gemm.rst │ │ ├── index.rst │ │ ├── normalization.rst │ │ ├── padding.rst │ │ ├── permutation.rst │ │ ├── recipe.rst │ │ ├── softmax.rst │ │ ├── swizzle.rst │ │ ├── transformer_engine.rst │ │ └── transpose.rst │ ├── common.rst │ ├── framework.rst │ ├── jax.rst │ └── pytorch.rst ├── conf.py ├── examples │ ├── E8M0.png │ ├── H200-NeMo-performance.png │ ├── MXFP8_FP8_comparison_1.png │ ├── MXFP8_FP8_comparison_2.png │ ├── advanced_optimizations.ipynb │ ├── attention │ │ ├── arbitrary_mask_to_post_scale_bias.py │ │ ├── attention.ipynb │ │ ├── dot_product_attention.png │ │ └── example_attention.py │ ├── comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg │ ├── delayed_scaling.png │ ├── fp8_formats.png │ ├── fp8_primer.ipynb │ ├── linear_mxfp8.png │ ├── loss_scaling.png │ ├── quickstart.ipynb │ ├── quickstart_utils.py │ ├── te_llama │ │ ├── media │ │ │ ├── llama_for_causal_lm.svg │ │ │ ├── llama_zoom.svg │ │ │ ├── llamadecoderlayer.svg │ │ │ ├── model_change.svg │ │ │ ├── swiglu.svg │ │ │ ├── swiglu_te.svg │ │ │ ├── tellamadecoderlayer.svg │ │ │ ├── transformer_llama.png │ │ │ ├── transformer_vs_llama.svg │ │ │ └── weight_swap.svg │ │ ├── te_llama.py │ │ ├── tutorial_accelerate_hf_llama_with_te.ipynb │ │ └── utils.py │ └── transformer_layer.png ├── faq.rst ├── index.rst ├── installation.rst └── version_select.patch ├── examples ├── README.md ├── jax │ ├── README.md │ ├── encoder │ │ ├── README.md │ │ ├── common.py │ │ ├── conftest.py │ │ ├── requirements.txt │ │ ├── run_test_multiprocessing_encoder.sh │ │ ├── test_model_parallel_encoder.py │ │ ├── test_multigpu_encoder.py │ │ ├── test_multiprocessing_encoder.py │ │ └── test_single_gpu_encoder.py │ └── mnist │ │ ├── README.md │ │ ├── requirements.txt │ │ └── test_single_gpu_mnist.py └── pytorch │ ├── comm_gemm_overlap │ ├── README.md │ └── te_layer_with_overlap.py │ ├── fsdp │ ├── README.md │ └── fsdp.py │ └── mnist │ ├── README.md │ └── main.py ├── hipify_custom_map.json ├── pylintrc ├── qa ├── L0_cppunittest │ └── test.sh ├── L0_jax_distributed_unittest │ └── test.sh ├── L0_jax_lint │ └── test.sh ├── L0_jax_unittest │ └── test.sh ├── L0_jax_wheel │ └── test.sh ├── L0_license │ ├── config.json │ ├── copyright_checker.py │ └── test.sh ├── L0_pytorch_lint │ └── test.sh ├── L0_pytorch_unittest │ └── test.sh ├── L0_pytorch_wheel │ └── test.sh ├── L1_jax_distributed_unittest │ └── test.sh ├── L1_pytorch_distributed_unittest │ └── test.sh ├── L1_pytorch_mcore_integration │ ├── .gitignore │ ├── merges.txt │ └── test.sh ├── L1_pytorch_thunder_integration │ └── test.sh ├── L2_jax_unittest │ └── test.sh ├── L3_pytorch_FA_versions_test │ └── test.sh └── format.sh ├── setup.py ├── tests ├── cpp │ ├── CMakeLists.txt │ ├── operator │ │ ├── CMakeLists.txt │ │ ├── test_act.cu │ │ ├── test_cast.cu │ │ ├── test_cast_current_scaling.cu │ │ ├── test_cast_dbias.cu │ │ ├── test_cast_dbias_dgelu.cu │ │ ├── test_cast_float8blockwise.cu │ │ ├── test_cast_gated_swiglu.cu │ │ ├── test_cast_mxfp8.cu │ │ ├── test_cast_mxfp8_gated_swiglu.cu │ │ ├── test_cast_transpose.cu │ │ ├── test_cast_transpose_current_scaling.cu │ │ ├── test_cast_transpose_dbias.cu │ │ ├── test_cast_transpose_dbias_dgelu.cu │ │ ├── test_cast_transpose_dgeglu.cu │ │ ├── test_causal_softmax.cu │ │ ├── test_cublaslt_gemm.cu │ │ ├── test_dequantize_mxfp8.cu │ │ ├── test_memset.cu │ │ ├── test_multi_cast_transpose.cu │ │ ├── test_multi_padding.cu │ │ ├── test_normalization.cu │ │ ├── test_normalization.h │ │ ├── test_normalization_mxfp8.cu │ │ ├── test_qdq.cu │ │ ├── test_swizzle.cu │ │ └── test_transpose.cu │ ├── test_common.cu │ ├── test_common.h │ └── util │ │ ├── CMakeLists.txt │ │ ├── test_nvrtc.cpp │ │ └── test_string.cpp ├── jax │ ├── conftest.py │ ├── distributed_test_base.py │ ├── pytest.ini │ ├── test_custom_call_compute.py │ ├── test_distributed_fused_attn.py │ ├── test_distributed_layernorm.py │ ├── test_distributed_layernorm_mlp.py │ ├── test_distributed_softmax.py │ ├── test_functions.py │ ├── test_fused_attn.py │ ├── test_helper.py │ ├── test_layer.py │ ├── test_misc.py │ ├── test_sanity_import.py │ ├── test_sharding.py │ ├── test_softmax.py │ └── utils.py └── pytorch │ ├── distributed │ ├── run_cast_master_weights_to_fp8.py │ ├── run_fsdp2_fp8_model.py │ ├── run_fsdp2_model.py │ ├── run_gemm_with_overlap.py │ ├── run_layer_with_overlap.py │ ├── run_numerics.py │ ├── test_cast_master_weights_to_fp8.py │ ├── test_comm_gemm_overlap.py │ ├── test_fusible_ops.py │ ├── test_fusible_ops_with_userbuffers.py │ ├── test_numerics.py │ ├── test_torch_fsdp2.py │ └── test_torch_fsdp2_fp8.py │ ├── fused_attn │ ├── run_fused_attn_with_cp.py │ ├── test_fused_attn.py │ ├── test_fused_attn_with_cp.py │ └── test_kv_cache.py │ ├── references │ ├── blockwise_fp8_gemm_reference.py │ ├── blockwise_quantizer_reference.py │ ├── quantize_scale_calc.py │ └── ref_per_tensor_cs.py │ ├── test_cpu_offloading.py │ ├── test_cuda_graphs.py │ ├── test_deferred_init.py │ ├── test_float8_blockwise_gemm_exact.py │ ├── test_float8_blockwise_scaling_exact.py │ ├── test_float8_current_scaling_exact.py │ ├── test_float8blockwisetensor.py │ ├── test_float8tensor.py │ ├── test_fused_optimizer.py │ ├── test_fused_rope.py │ ├── test_fusible_ops.py │ ├── test_gemm_autotune.py │ ├── test_gemm_sm_count.py │ ├── test_gqa.py │ ├── test_jit.py │ ├── test_layernorm_saved_tensors_logic.py │ ├── test_multi_tensor.py │ ├── test_numerics.py │ ├── test_parallel_cross_entropy.py │ ├── test_permutation.py │ ├── test_recipe.py │ ├── test_sanity.py │ ├── test_sanity_import.py │ ├── triton_kernels │ ├── test_cast.py │ ├── test_cast_mxfp8.py │ ├── test_common.py │ ├── test_norm_common.py │ └── test_norms.py │ └── utils.py └── transformer_engine ├── __init__.py ├── common ├── CMakeLists.txt ├── __init__.py ├── activation │ ├── activation_template.h │ ├── gelu.cu │ ├── relu.cu │ └── swiglu.cu ├── amd_detail │ ├── hip_f8_impl.h │ ├── hip_float8.h │ └── system.cpp ├── aotriton │ ├── CMakeLists.txt │ └── aotriton_custom.cmake ├── ck_fused_attn │ ├── CMakeLists.txt │ ├── aiter_prebuilt.cmake │ ├── include │ │ └── ck_fused_attn │ │ │ └── ck_fused_attn.hpp │ └── src │ │ ├── ck_fused_attn_bwd.cpp │ │ ├── ck_fused_attn_fwd.cpp │ │ ├── ck_fused_attn_utils.cpp │ │ └── ck_fused_attn_utils.hpp ├── comm_gemm_overlap │ ├── comm_gemm_overlap.cpp │ └── userbuffers │ │ ├── ipcsocket.cc │ │ ├── ipcsocket.h │ │ ├── userbuffers-host.cpp │ │ ├── userbuffers.cu │ │ └── userbuffers.h ├── common.cu ├── common.h ├── cudnn_utils.cpp ├── cudnn_utils.h ├── fused_attn │ ├── context_parallel.cu │ ├── flash_attn.cu │ ├── fused_attn.cpp │ ├── fused_attn_f16_arbitrary_seqlen.cu │ ├── fused_attn_f16_arbitrary_seqlen.h │ ├── fused_attn_f16_max512_seqlen.cu │ ├── fused_attn_f16_max512_seqlen.h │ ├── fused_attn_fp8.cu │ ├── fused_attn_fp8.h │ ├── kv_cache.cu │ ├── utils.cu │ └── utils.h ├── fused_attn_rocm │ ├── fused_attn.cpp │ ├── fused_attn_aotriton.cpp │ ├── fused_attn_aotriton.h │ ├── fused_attn_ck.cpp │ ├── fused_attn_ck.h │ ├── utils.cpp │ └── utils.h ├── fused_rope │ └── fused_rope.cu ├── fused_softmax │ ├── scaled_aligned_causal_masked_softmax.cu │ ├── scaled_masked_softmax.cu │ └── scaled_upper_triang_masked_softmax.cu ├── gemm │ ├── cublaslt_gemm.cu │ └── rocm_gemm.cu ├── include │ └── transformer_engine │ │ ├── activation.h │ │ ├── cast.h │ │ ├── cast_transpose_noop.h │ │ ├── comm_gemm_overlap.h │ │ ├── cudnn.h │ │ ├── fused_attn.h │ │ ├── fused_rope.h │ │ ├── gemm.h │ │ ├── multi_tensor.h │ │ ├── normalization.h │ │ ├── padding.h │ │ ├── permutation.h │ │ ├── recipe.h │ │ ├── softmax.h │ │ ├── swizzle.h │ │ ├── transformer_engine.h │ │ └── transpose.h ├── libtransformer_engine.version ├── multi_tensor │ ├── adam.cu │ ├── compute_scale.cu │ ├── l2norm.cu │ ├── multi_tensor_apply.cuh │ ├── scale.cu │ └── sgd.cu ├── normalization │ ├── common.cpp │ ├── common.h │ ├── kernel_traits.h │ ├── layernorm │ │ ├── ln_api.cpp │ │ ├── ln_bwd_kernels.cuh │ │ ├── ln_bwd_semi_cuda_kernel.cu │ │ ├── ln_fwd_cuda_kernel.cu │ │ └── ln_fwd_kernels.cuh │ └── rmsnorm │ │ ├── rmsnorm_api.cpp │ │ ├── rmsnorm_bwd_kernels.cuh │ │ ├── rmsnorm_bwd_semi_cuda_kernel.cu │ │ ├── rmsnorm_fwd_cuda_kernel.cu │ │ └── rmsnorm_fwd_kernels.cuh ├── nvshmem_api │ ├── CMakeLists.txt │ ├── nvshmem_waitkernel.cu │ └── nvshmem_waitkernel.h ├── nvtx.h ├── permutation │ └── permutation.cu ├── recipe │ ├── __init__.py │ ├── current_scaling.cu │ ├── delayed_scaling.cu │ ├── fp8_block_scaling.cu │ └── recipe_common.cuh ├── rocshmem_api │ ├── CMakeLists.txt │ ├── rocshmem_waitkernel.hip │ └── rocshmem_waitkernel.hpp ├── swizzle │ └── swizzle.cu ├── transformer_engine.cpp ├── transpose │ ├── cast_transpose.cu │ ├── cast_transpose.h │ ├── cast_transpose_fusion.cu │ ├── multi_cast_transpose.cu │ ├── quantize_transpose_square_blockwise.cu │ ├── quantize_transpose_vector_blockwise.cu │ ├── rtc │ │ ├── cast_transpose.cu │ │ ├── cast_transpose_fusion.cu │ │ └── transpose.cu │ ├── transpose.cu │ └── transpose_fusion.cu ├── util │ ├── cast.cu │ ├── cast_gated_kernels.cuh │ ├── cast_kernels.cuh │ ├── cuda_driver.cpp │ ├── cuda_driver.h │ ├── cuda_nvml.cpp │ ├── cuda_nvml.h │ ├── cuda_runtime.cpp │ ├── cuda_runtime.h │ ├── dequantize_kernels.cuh │ ├── handle_manager.h │ ├── logging.h │ ├── math.h │ ├── padding.cu │ ├── ptx.cuh │ ├── pybind_helper.h │ ├── rocm_cast_gated_kernels.cuh │ ├── rocm_cast_kernels.cuh │ ├── rocm_dequantize_kernels.cuh │ ├── rocm_vectorized_2d.cuh │ ├── rtc.cpp │ ├── rtc.h │ ├── shared_lib_wrapper.h │ ├── string.h │ ├── string_header.h.in │ ├── system.h │ └── vectorized_pointwise.h ├── utils.cuh └── utils.py ├── debug ├── __init__.py ├── features │ ├── __init__.py │ ├── _test_dummy_feature.py │ ├── api.py │ ├── disable_fp8_gemm.py │ ├── disable_fp8_layer.py │ ├── fake_quant.py │ ├── log_fp8_tensor_stats.py │ ├── log_tensor_stats.py │ ├── per_tensor_scaling.py │ └── utils │ │ ├── __init__.py │ │ ├── stats_buffer.py │ │ └── stats_computation.py └── pytorch │ ├── __init__.py │ ├── debug_quantization.py │ ├── debug_state.py │ └── utils.py ├── jax ├── MANIFEST.in ├── __init__.py ├── activation.py ├── attention.py ├── cpp_extensions │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── base.py │ ├── gemm.py │ ├── misc.py │ ├── normalization.py │ ├── quantization.py │ └── softmax.py ├── csrc │ ├── extensions.h │ └── extensions │ │ ├── activation.cpp │ │ ├── attention.cpp │ │ ├── cublas.cpp │ │ ├── cudnn.cpp │ │ ├── ffi.cpp │ │ ├── ffi.h │ │ ├── gemm.cpp │ │ ├── misc.cpp │ │ ├── misc.h │ │ ├── normalization.cpp │ │ ├── pybind.cpp │ │ ├── quantization.cpp │ │ ├── softmax.cpp │ │ ├── utils.cpp │ │ └── utils.h ├── dense.py ├── flax │ ├── __init__.py │ ├── module.py │ └── transformer.py ├── layernorm.py ├── layernorm_dense.py ├── layernorm_mlp.py ├── quantize │ ├── __init__.py │ ├── dequantizer.py │ ├── helper.py │ ├── metadata.py │ ├── quantizer.py │ ├── scaling_modes.py │ └── tensor.py ├── setup.py ├── sharding.py ├── softmax.py └── util.py └── pytorch ├── MANIFEST.in ├── __init__.py ├── attention ├── __init__.py ├── dot_product_attention │ ├── __init__.py │ ├── backends.py │ ├── context_parallel.py │ ├── dot_product_attention.py │ ├── softmax.py │ └── utils.py ├── inference.py ├── multi_head_attention.py └── rope.py ├── constants.py ├── cpp_extensions ├── __init__.py ├── fused_attn.py └── gemm.py ├── cpu_offload.py ├── cross_entropy.py ├── csrc ├── common.cpp ├── common.h ├── extensions.h ├── extensions │ ├── activation.cpp │ ├── apply_rope.cpp │ ├── attention.cpp │ ├── bias.cpp │ ├── cast.cpp │ ├── comm_gemm_overlap.cpp │ ├── fp8_block_scaling_partial_cast.cpp │ ├── gemm.cpp │ ├── misc.cpp │ ├── multi_tensor │ │ ├── adam.cpp │ │ ├── compute_scale.cpp │ │ ├── l2norm.cpp │ │ ├── scale.cpp │ │ └── sgd.cpp │ ├── normalization.cpp │ ├── nvshmem_comm.cpp │ ├── padding.cpp │ ├── permutation.cpp │ ├── pybind.cpp │ ├── recipe.cpp │ ├── rocshmem_comm.cpp │ ├── softmax.cpp │ └── transpose.cpp ├── pybind.h ├── quantizer.cpp ├── type_converters.cpp ├── util.cpp └── util.h ├── distributed.py ├── float8_tensor.py ├── fp8.py ├── graph.py ├── jit.py ├── module ├── __init__.py ├── _common.py ├── base.py ├── fp8_padding.py ├── fp8_unpadding.py ├── grouped_linear.py ├── layernorm.py ├── layernorm_linear.py ├── layernorm_mlp.py ├── linear.py └── rmsnorm.py ├── numerics_debug.py ├── ops ├── __init__.py ├── _common.py ├── basic │ ├── __init__.py │ ├── activation.py │ ├── add_in_place.py │ ├── all_gather.py │ ├── all_reduce.py │ ├── basic_linear.py │ ├── bias.py │ ├── identity.py │ ├── layer_norm.py │ ├── make_extra_output.py │ ├── quantize.py │ ├── reduce_scatter.py │ ├── reshape.py │ └── rmsnorm.py ├── fused │ ├── __init__.py │ ├── backward_linear_add.py │ ├── forward_linear_bias_activation.py │ ├── forward_linear_bias_add.py │ ├── userbuffers_backward_linear.py │ └── userbuffers_forward_linear.py ├── fuser.py ├── linear.py ├── op.py └── sequential.py ├── optimizers ├── __init__.py ├── fused_adam.py ├── fused_sgd.py └── multi_tensor_apply.py ├── permutation.py ├── setup.py ├── tensor ├── __init__.py ├── _internal │ ├── __init__.py │ ├── float8_blockwise_tensor_base.py │ ├── float8_tensor_base.py │ └── mxfp8_tensor_base.py ├── float8_blockwise_tensor.py ├── float8_tensor.py ├── fsdp2_allgather_tensor.py ├── mxfp8_tensor.py ├── quantized_tensor.py └── utils.py ├── transformer.py ├── triton ├── __init__.py ├── cross_entropy.py └── permutation.py ├── triton_kernels ├── __init__.py ├── cast.py ├── cast_transpose.py ├── common.py ├── layernorm.py ├── norm_common.py └── rmsnorm.py └── utils.py /.clang-format: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.clang-format -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/CODEOWNERS -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/ISSUE_TEMPLATE/bug_report.md -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/ISSUE_TEMPLATE/feature_request.md -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/PULL_REQUEST_TEMPLATE.md -------------------------------------------------------------------------------- /.github/workflows/blossom-ci.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/blossom-ci.yml -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/build.yml -------------------------------------------------------------------------------- /.github/workflows/deploy_nightly_docs.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/deploy_nightly_docs.yml -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/docs.yml -------------------------------------------------------------------------------- /.github/workflows/license.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/license.yml -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/lint.yml -------------------------------------------------------------------------------- /.github/workflows/rocm-ci.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/rocm-ci.yml -------------------------------------------------------------------------------- /.github/workflows/trigger-ci.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/trigger-ci.yml -------------------------------------------------------------------------------- /.github/workflows/upload-ci-logs.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.github/workflows/upload-ci-logs.yml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.gitignore -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.gitmodules -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/.pre-commit-config.yaml -------------------------------------------------------------------------------- /Acknowledgements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/Acknowledgements.txt -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/CONTRIBUTING.rst -------------------------------------------------------------------------------- /CPPLINT.cfg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/CPPLINT.cfg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/LICENSE -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/README.rst -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/SECURITY.md -------------------------------------------------------------------------------- /benchmarks/attention/benchmark_attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/benchmarks/attention/benchmark_attention.py -------------------------------------------------------------------------------- /benchmarks/attention/benchmark_attention_rocm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/benchmarks/attention/benchmark_attention_rocm.py -------------------------------------------------------------------------------- /ci/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/ci/README.md -------------------------------------------------------------------------------- /ci/_utils.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/ci/_utils.sh -------------------------------------------------------------------------------- /ci/ci_config.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/ci/ci_config.json -------------------------------------------------------------------------------- /ci/core.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/ci/core.sh -------------------------------------------------------------------------------- /ci/jax.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/ci/jax.sh -------------------------------------------------------------------------------- /ci/pytorch.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/ci/pytorch.sh -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | doxygen 3 | sphinx_rtd_theme -------------------------------------------------------------------------------- /docs/Doxyfile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/Doxyfile -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/Makefile -------------------------------------------------------------------------------- /docs/_static/NVIDIA-LogoBlack.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/_static/NVIDIA-LogoBlack.svg -------------------------------------------------------------------------------- /docs/_static/css/nvidia_font.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/_static/css/nvidia_font.css -------------------------------------------------------------------------------- /docs/_static/css/nvidia_footer.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/_static/css/nvidia_footer.css -------------------------------------------------------------------------------- /docs/_templates/footer.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/_templates/footer.html -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/_templates/layout.html -------------------------------------------------------------------------------- /docs/api/c/activation.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/activation.rst -------------------------------------------------------------------------------- /docs/api/c/cast.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/cast.rst -------------------------------------------------------------------------------- /docs/api/c/fused_attn.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/fused_attn.rst -------------------------------------------------------------------------------- /docs/api/c/fused_rope.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/fused_rope.rst -------------------------------------------------------------------------------- /docs/api/c/gemm.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/gemm.rst -------------------------------------------------------------------------------- /docs/api/c/index.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/index.rst -------------------------------------------------------------------------------- /docs/api/c/normalization.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/normalization.rst -------------------------------------------------------------------------------- /docs/api/c/padding.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/padding.rst -------------------------------------------------------------------------------- /docs/api/c/permutation.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/permutation.rst -------------------------------------------------------------------------------- /docs/api/c/recipe.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/recipe.rst -------------------------------------------------------------------------------- /docs/api/c/softmax.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/softmax.rst -------------------------------------------------------------------------------- /docs/api/c/swizzle.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/swizzle.rst -------------------------------------------------------------------------------- /docs/api/c/transformer_engine.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/transformer_engine.rst -------------------------------------------------------------------------------- /docs/api/c/transpose.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/c/transpose.rst -------------------------------------------------------------------------------- /docs/api/common.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/common.rst -------------------------------------------------------------------------------- /docs/api/framework.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/framework.rst -------------------------------------------------------------------------------- /docs/api/jax.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/jax.rst -------------------------------------------------------------------------------- /docs/api/pytorch.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/api/pytorch.rst -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/conf.py -------------------------------------------------------------------------------- /docs/examples/E8M0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/E8M0.png -------------------------------------------------------------------------------- /docs/examples/H200-NeMo-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/H200-NeMo-performance.png -------------------------------------------------------------------------------- /docs/examples/MXFP8_FP8_comparison_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/MXFP8_FP8_comparison_1.png -------------------------------------------------------------------------------- /docs/examples/MXFP8_FP8_comparison_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/MXFP8_FP8_comparison_2.png -------------------------------------------------------------------------------- /docs/examples/advanced_optimizations.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/advanced_optimizations.ipynb -------------------------------------------------------------------------------- /docs/examples/attention/arbitrary_mask_to_post_scale_bias.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py -------------------------------------------------------------------------------- /docs/examples/attention/attention.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/attention/attention.ipynb -------------------------------------------------------------------------------- /docs/examples/attention/dot_product_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/attention/dot_product_attention.png -------------------------------------------------------------------------------- /docs/examples/attention/example_attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/attention/example_attention.py -------------------------------------------------------------------------------- /docs/examples/comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg -------------------------------------------------------------------------------- /docs/examples/delayed_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/delayed_scaling.png -------------------------------------------------------------------------------- /docs/examples/fp8_formats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/fp8_formats.png -------------------------------------------------------------------------------- /docs/examples/fp8_primer.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/fp8_primer.ipynb -------------------------------------------------------------------------------- /docs/examples/linear_mxfp8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/linear_mxfp8.png -------------------------------------------------------------------------------- /docs/examples/loss_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/loss_scaling.png -------------------------------------------------------------------------------- /docs/examples/quickstart.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/quickstart.ipynb -------------------------------------------------------------------------------- /docs/examples/quickstart_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/quickstart_utils.py -------------------------------------------------------------------------------- /docs/examples/te_llama/media/llama_for_causal_lm.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/llama_for_causal_lm.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/llama_zoom.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/llama_zoom.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/llamadecoderlayer.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/llamadecoderlayer.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/model_change.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/model_change.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/swiglu.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/swiglu.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/swiglu_te.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/swiglu_te.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/tellamadecoderlayer.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/tellamadecoderlayer.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/transformer_llama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/transformer_llama.png -------------------------------------------------------------------------------- /docs/examples/te_llama/media/transformer_vs_llama.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/transformer_vs_llama.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/media/weight_swap.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/media/weight_swap.svg -------------------------------------------------------------------------------- /docs/examples/te_llama/te_llama.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/te_llama.py -------------------------------------------------------------------------------- /docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb -------------------------------------------------------------------------------- /docs/examples/te_llama/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/te_llama/utils.py -------------------------------------------------------------------------------- /docs/examples/transformer_layer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/examples/transformer_layer.png -------------------------------------------------------------------------------- /docs/faq.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/faq.rst -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/index.rst -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/installation.rst -------------------------------------------------------------------------------- /docs/version_select.patch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/docs/version_select.patch -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/README.md -------------------------------------------------------------------------------- /examples/jax/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/README.md -------------------------------------------------------------------------------- /examples/jax/encoder/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/README.md -------------------------------------------------------------------------------- /examples/jax/encoder/common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/common.py -------------------------------------------------------------------------------- /examples/jax/encoder/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/conftest.py -------------------------------------------------------------------------------- /examples/jax/encoder/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets<4.0.0 2 | flax>=0.7.1 3 | nltk>=3.8.2 4 | optax 5 | -------------------------------------------------------------------------------- /examples/jax/encoder/run_test_multiprocessing_encoder.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/run_test_multiprocessing_encoder.sh -------------------------------------------------------------------------------- /examples/jax/encoder/test_model_parallel_encoder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/test_model_parallel_encoder.py -------------------------------------------------------------------------------- /examples/jax/encoder/test_multigpu_encoder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/test_multigpu_encoder.py -------------------------------------------------------------------------------- /examples/jax/encoder/test_multiprocessing_encoder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/test_multiprocessing_encoder.py -------------------------------------------------------------------------------- /examples/jax/encoder/test_single_gpu_encoder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/encoder/test_single_gpu_encoder.py -------------------------------------------------------------------------------- /examples/jax/mnist/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/mnist/README.md -------------------------------------------------------------------------------- /examples/jax/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets<4.0.0 2 | flax>=0.7.1 3 | optax 4 | Pillow 5 | -------------------------------------------------------------------------------- /examples/jax/mnist/test_single_gpu_mnist.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/jax/mnist/test_single_gpu_mnist.py -------------------------------------------------------------------------------- /examples/pytorch/comm_gemm_overlap/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/pytorch/comm_gemm_overlap/README.md -------------------------------------------------------------------------------- /examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py -------------------------------------------------------------------------------- /examples/pytorch/fsdp/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/pytorch/fsdp/README.md -------------------------------------------------------------------------------- /examples/pytorch/fsdp/fsdp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/pytorch/fsdp/fsdp.py -------------------------------------------------------------------------------- /examples/pytorch/mnist/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/pytorch/mnist/README.md -------------------------------------------------------------------------------- /examples/pytorch/mnist/main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/examples/pytorch/mnist/main.py -------------------------------------------------------------------------------- /hipify_custom_map.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/hipify_custom_map.json -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/pylintrc -------------------------------------------------------------------------------- /qa/L0_cppunittest/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_cppunittest/test.sh -------------------------------------------------------------------------------- /qa/L0_jax_distributed_unittest/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_jax_distributed_unittest/test.sh -------------------------------------------------------------------------------- /qa/L0_jax_lint/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_jax_lint/test.sh -------------------------------------------------------------------------------- /qa/L0_jax_unittest/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_jax_unittest/test.sh -------------------------------------------------------------------------------- /qa/L0_jax_wheel/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_jax_wheel/test.sh -------------------------------------------------------------------------------- /qa/L0_license/config.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_license/config.json -------------------------------------------------------------------------------- /qa/L0_license/copyright_checker.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_license/copyright_checker.py -------------------------------------------------------------------------------- /qa/L0_license/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_license/test.sh -------------------------------------------------------------------------------- /qa/L0_pytorch_lint/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_pytorch_lint/test.sh -------------------------------------------------------------------------------- /qa/L0_pytorch_unittest/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_pytorch_unittest/test.sh -------------------------------------------------------------------------------- /qa/L0_pytorch_wheel/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L0_pytorch_wheel/test.sh -------------------------------------------------------------------------------- /qa/L1_jax_distributed_unittest/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L1_jax_distributed_unittest/test.sh -------------------------------------------------------------------------------- /qa/L1_pytorch_distributed_unittest/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L1_pytorch_distributed_unittest/test.sh -------------------------------------------------------------------------------- /qa/L1_pytorch_mcore_integration/.gitignore: -------------------------------------------------------------------------------- 1 | Megatron-LM 2 | vocab.json -------------------------------------------------------------------------------- /qa/L1_pytorch_mcore_integration/merges.txt: -------------------------------------------------------------------------------- 1 | #version: 0.2 2 | -------------------------------------------------------------------------------- /qa/L1_pytorch_mcore_integration/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L1_pytorch_mcore_integration/test.sh -------------------------------------------------------------------------------- /qa/L1_pytorch_thunder_integration/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L1_pytorch_thunder_integration/test.sh -------------------------------------------------------------------------------- /qa/L2_jax_unittest/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L2_jax_unittest/test.sh -------------------------------------------------------------------------------- /qa/L3_pytorch_FA_versions_test/test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/L3_pytorch_FA_versions_test/test.sh -------------------------------------------------------------------------------- /qa/format.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/qa/format.sh -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/setup.py -------------------------------------------------------------------------------- /tests/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/CMakeLists.txt -------------------------------------------------------------------------------- /tests/cpp/operator/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/CMakeLists.txt -------------------------------------------------------------------------------- /tests/cpp/operator/test_act.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_act.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_current_scaling.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_current_scaling.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_dbias.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_dbias.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_dbias_dgelu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_dbias_dgelu.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_float8blockwise.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_float8blockwise.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_gated_swiglu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_gated_swiglu.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_mxfp8.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_mxfp8.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_transpose.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_transpose_current_scaling.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_transpose_current_scaling.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_transpose_dbias.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_transpose_dbias.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cast_transpose_dgeglu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cast_transpose_dgeglu.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_causal_softmax.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_causal_softmax.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_cublaslt_gemm.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_cublaslt_gemm.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_dequantize_mxfp8.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_dequantize_mxfp8.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_memset.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_memset.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_multi_cast_transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_multi_cast_transpose.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_multi_padding.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_multi_padding.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_normalization.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_normalization.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_normalization.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_normalization.h -------------------------------------------------------------------------------- /tests/cpp/operator/test_normalization_mxfp8.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_normalization_mxfp8.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_qdq.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_qdq.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_swizzle.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_swizzle.cu -------------------------------------------------------------------------------- /tests/cpp/operator/test_transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/operator/test_transpose.cu -------------------------------------------------------------------------------- /tests/cpp/test_common.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/test_common.cu -------------------------------------------------------------------------------- /tests/cpp/test_common.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/test_common.h -------------------------------------------------------------------------------- /tests/cpp/util/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/util/CMakeLists.txt -------------------------------------------------------------------------------- /tests/cpp/util/test_nvrtc.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/util/test_nvrtc.cpp -------------------------------------------------------------------------------- /tests/cpp/util/test_string.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/cpp/util/test_string.cpp -------------------------------------------------------------------------------- /tests/jax/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/conftest.py -------------------------------------------------------------------------------- /tests/jax/distributed_test_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/distributed_test_base.py -------------------------------------------------------------------------------- /tests/jax/pytest.ini: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/pytest.ini -------------------------------------------------------------------------------- /tests/jax/test_custom_call_compute.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_custom_call_compute.py -------------------------------------------------------------------------------- /tests/jax/test_distributed_fused_attn.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_distributed_fused_attn.py -------------------------------------------------------------------------------- /tests/jax/test_distributed_layernorm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_distributed_layernorm.py -------------------------------------------------------------------------------- /tests/jax/test_distributed_layernorm_mlp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_distributed_layernorm_mlp.py -------------------------------------------------------------------------------- /tests/jax/test_distributed_softmax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_distributed_softmax.py -------------------------------------------------------------------------------- /tests/jax/test_functions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_functions.py -------------------------------------------------------------------------------- /tests/jax/test_fused_attn.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_fused_attn.py -------------------------------------------------------------------------------- /tests/jax/test_helper.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_helper.py -------------------------------------------------------------------------------- /tests/jax/test_layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_layer.py -------------------------------------------------------------------------------- /tests/jax/test_misc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_misc.py -------------------------------------------------------------------------------- /tests/jax/test_sanity_import.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_sanity_import.py -------------------------------------------------------------------------------- /tests/jax/test_sharding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_sharding.py -------------------------------------------------------------------------------- /tests/jax/test_softmax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/test_softmax.py -------------------------------------------------------------------------------- /tests/jax/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/jax/utils.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/run_cast_master_weights_to_fp8.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/run_fsdp2_fp8_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/run_fsdp2_fp8_model.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/run_fsdp2_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/run_fsdp2_model.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/run_gemm_with_overlap.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/run_gemm_with_overlap.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/run_layer_with_overlap.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/run_layer_with_overlap.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/run_numerics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/run_numerics.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_cast_master_weights_to_fp8.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_comm_gemm_overlap.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/test_comm_gemm_overlap.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_fusible_ops.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/test_fusible_ops.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_numerics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/test_numerics.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_torch_fsdp2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/test_torch_fsdp2.py -------------------------------------------------------------------------------- /tests/pytorch/distributed/test_torch_fsdp2_fp8.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/distributed/test_torch_fsdp2_fp8.py -------------------------------------------------------------------------------- /tests/pytorch/fused_attn/run_fused_attn_with_cp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/fused_attn/run_fused_attn_with_cp.py -------------------------------------------------------------------------------- /tests/pytorch/fused_attn/test_fused_attn.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/fused_attn/test_fused_attn.py -------------------------------------------------------------------------------- /tests/pytorch/fused_attn/test_fused_attn_with_cp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/fused_attn/test_fused_attn_with_cp.py -------------------------------------------------------------------------------- /tests/pytorch/fused_attn/test_kv_cache.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/fused_attn/test_kv_cache.py -------------------------------------------------------------------------------- /tests/pytorch/references/blockwise_fp8_gemm_reference.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/references/blockwise_fp8_gemm_reference.py -------------------------------------------------------------------------------- /tests/pytorch/references/blockwise_quantizer_reference.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/references/blockwise_quantizer_reference.py -------------------------------------------------------------------------------- /tests/pytorch/references/quantize_scale_calc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/references/quantize_scale_calc.py -------------------------------------------------------------------------------- /tests/pytorch/references/ref_per_tensor_cs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/references/ref_per_tensor_cs.py -------------------------------------------------------------------------------- /tests/pytorch/test_cpu_offloading.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_cpu_offloading.py -------------------------------------------------------------------------------- /tests/pytorch/test_cuda_graphs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_cuda_graphs.py -------------------------------------------------------------------------------- /tests/pytorch/test_deferred_init.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_deferred_init.py -------------------------------------------------------------------------------- /tests/pytorch/test_float8_blockwise_gemm_exact.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_float8_blockwise_gemm_exact.py -------------------------------------------------------------------------------- /tests/pytorch/test_float8_blockwise_scaling_exact.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_float8_blockwise_scaling_exact.py -------------------------------------------------------------------------------- /tests/pytorch/test_float8_current_scaling_exact.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_float8_current_scaling_exact.py -------------------------------------------------------------------------------- /tests/pytorch/test_float8blockwisetensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_float8blockwisetensor.py -------------------------------------------------------------------------------- /tests/pytorch/test_float8tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_float8tensor.py -------------------------------------------------------------------------------- /tests/pytorch/test_fused_optimizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_fused_optimizer.py -------------------------------------------------------------------------------- /tests/pytorch/test_fused_rope.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_fused_rope.py -------------------------------------------------------------------------------- /tests/pytorch/test_fusible_ops.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_fusible_ops.py -------------------------------------------------------------------------------- /tests/pytorch/test_gemm_autotune.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_gemm_autotune.py -------------------------------------------------------------------------------- /tests/pytorch/test_gemm_sm_count.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_gemm_sm_count.py -------------------------------------------------------------------------------- /tests/pytorch/test_gqa.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_gqa.py -------------------------------------------------------------------------------- /tests/pytorch/test_jit.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_jit.py -------------------------------------------------------------------------------- /tests/pytorch/test_layernorm_saved_tensors_logic.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_layernorm_saved_tensors_logic.py -------------------------------------------------------------------------------- /tests/pytorch/test_multi_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_multi_tensor.py -------------------------------------------------------------------------------- /tests/pytorch/test_numerics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_numerics.py -------------------------------------------------------------------------------- /tests/pytorch/test_parallel_cross_entropy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_parallel_cross_entropy.py -------------------------------------------------------------------------------- /tests/pytorch/test_permutation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_permutation.py -------------------------------------------------------------------------------- /tests/pytorch/test_recipe.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_recipe.py -------------------------------------------------------------------------------- /tests/pytorch/test_sanity.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_sanity.py -------------------------------------------------------------------------------- /tests/pytorch/test_sanity_import.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/test_sanity_import.py -------------------------------------------------------------------------------- /tests/pytorch/triton_kernels/test_cast.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/triton_kernels/test_cast.py -------------------------------------------------------------------------------- /tests/pytorch/triton_kernels/test_cast_mxfp8.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/triton_kernels/test_cast_mxfp8.py -------------------------------------------------------------------------------- /tests/pytorch/triton_kernels/test_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/triton_kernels/test_common.py -------------------------------------------------------------------------------- /tests/pytorch/triton_kernels/test_norm_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/triton_kernels/test_norm_common.py -------------------------------------------------------------------------------- /tests/pytorch/triton_kernels/test_norms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/triton_kernels/test_norms.py -------------------------------------------------------------------------------- /tests/pytorch/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/tests/pytorch/utils.py -------------------------------------------------------------------------------- /transformer_engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/__init__.py -------------------------------------------------------------------------------- /transformer_engine/common/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/CMakeLists.txt -------------------------------------------------------------------------------- /transformer_engine/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/__init__.py -------------------------------------------------------------------------------- /transformer_engine/common/activation/activation_template.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/activation/activation_template.h -------------------------------------------------------------------------------- /transformer_engine/common/activation/gelu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/activation/gelu.cu -------------------------------------------------------------------------------- /transformer_engine/common/activation/relu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/activation/relu.cu -------------------------------------------------------------------------------- /transformer_engine/common/activation/swiglu.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/activation/swiglu.cu -------------------------------------------------------------------------------- /transformer_engine/common/amd_detail/hip_f8_impl.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/amd_detail/hip_f8_impl.h -------------------------------------------------------------------------------- /transformer_engine/common/amd_detail/hip_float8.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/amd_detail/hip_float8.h -------------------------------------------------------------------------------- /transformer_engine/common/amd_detail/system.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/amd_detail/system.cpp -------------------------------------------------------------------------------- /transformer_engine/common/aotriton/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/aotriton/CMakeLists.txt -------------------------------------------------------------------------------- /transformer_engine/common/aotriton/aotriton_custom.cmake: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/aotriton/aotriton_custom.cmake -------------------------------------------------------------------------------- /transformer_engine/common/ck_fused_attn/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/ck_fused_attn/CMakeLists.txt -------------------------------------------------------------------------------- /transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake -------------------------------------------------------------------------------- /transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp -------------------------------------------------------------------------------- /transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp -------------------------------------------------------------------------------- /transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp -------------------------------------------------------------------------------- /transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp -------------------------------------------------------------------------------- /transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp -------------------------------------------------------------------------------- /transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp -------------------------------------------------------------------------------- /transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc -------------------------------------------------------------------------------- /transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h -------------------------------------------------------------------------------- /transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp -------------------------------------------------------------------------------- /transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu -------------------------------------------------------------------------------- /transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h -------------------------------------------------------------------------------- /transformer_engine/common/common.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/common.cu -------------------------------------------------------------------------------- /transformer_engine/common/common.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/common.h -------------------------------------------------------------------------------- /transformer_engine/common/cudnn_utils.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/cudnn_utils.cpp -------------------------------------------------------------------------------- /transformer_engine/common/cudnn_utils.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/cudnn_utils.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/context_parallel.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/context_parallel.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/flash_attn.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/flash_attn.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/fused_attn.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/fused_attn.cpp -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/fused_attn_fp8.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/fused_attn_fp8.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/fused_attn_fp8.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/fused_attn_fp8.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/kv_cache.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/kv_cache.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/utils.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/utils.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn/utils.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn/utils.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn_rocm/fused_attn.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn_rocm/fused_attn.cpp -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn_rocm/fused_attn_ck.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn_rocm/utils.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn_rocm/utils.cpp -------------------------------------------------------------------------------- /transformer_engine/common/fused_attn_rocm/utils.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_attn_rocm/utils.h -------------------------------------------------------------------------------- /transformer_engine/common/fused_rope/fused_rope.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_rope/fused_rope.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_softmax/scaled_masked_softmax.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu -------------------------------------------------------------------------------- /transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu -------------------------------------------------------------------------------- /transformer_engine/common/gemm/cublaslt_gemm.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/gemm/cublaslt_gemm.cu -------------------------------------------------------------------------------- /transformer_engine/common/gemm/rocm_gemm.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/gemm/rocm_gemm.cu -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/activation.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/activation.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/cast.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/cast.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/cast_transpose_noop.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/cudnn.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/cudnn.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/fused_attn.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/fused_attn.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/fused_rope.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/fused_rope.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/gemm.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/gemm.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/multi_tensor.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/multi_tensor.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/normalization.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/normalization.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/padding.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/padding.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/permutation.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/permutation.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/recipe.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/recipe.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/softmax.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/softmax.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/swizzle.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/swizzle.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/transformer_engine.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/transformer_engine.h -------------------------------------------------------------------------------- /transformer_engine/common/include/transformer_engine/transpose.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/include/transformer_engine/transpose.h -------------------------------------------------------------------------------- /transformer_engine/common/libtransformer_engine.version: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/libtransformer_engine.version -------------------------------------------------------------------------------- /transformer_engine/common/multi_tensor/adam.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/multi_tensor/adam.cu -------------------------------------------------------------------------------- /transformer_engine/common/multi_tensor/compute_scale.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/multi_tensor/compute_scale.cu -------------------------------------------------------------------------------- /transformer_engine/common/multi_tensor/l2norm.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/multi_tensor/l2norm.cu -------------------------------------------------------------------------------- /transformer_engine/common/multi_tensor/multi_tensor_apply.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh -------------------------------------------------------------------------------- /transformer_engine/common/multi_tensor/scale.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/multi_tensor/scale.cu -------------------------------------------------------------------------------- /transformer_engine/common/multi_tensor/sgd.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/multi_tensor/sgd.cu -------------------------------------------------------------------------------- /transformer_engine/common/normalization/common.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/common.cpp -------------------------------------------------------------------------------- /transformer_engine/common/normalization/common.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/common.h -------------------------------------------------------------------------------- /transformer_engine/common/normalization/kernel_traits.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/kernel_traits.h -------------------------------------------------------------------------------- /transformer_engine/common/normalization/layernorm/ln_api.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/layernorm/ln_api.cpp -------------------------------------------------------------------------------- /transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu -------------------------------------------------------------------------------- /transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu -------------------------------------------------------------------------------- /transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp -------------------------------------------------------------------------------- /transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu -------------------------------------------------------------------------------- /transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu -------------------------------------------------------------------------------- /transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/nvshmem_api/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/nvshmem_api/CMakeLists.txt -------------------------------------------------------------------------------- /transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu -------------------------------------------------------------------------------- /transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h -------------------------------------------------------------------------------- /transformer_engine/common/nvtx.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/nvtx.h -------------------------------------------------------------------------------- /transformer_engine/common/permutation/permutation.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/permutation/permutation.cu -------------------------------------------------------------------------------- /transformer_engine/common/recipe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/recipe/__init__.py -------------------------------------------------------------------------------- /transformer_engine/common/recipe/current_scaling.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/recipe/current_scaling.cu -------------------------------------------------------------------------------- /transformer_engine/common/recipe/delayed_scaling.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/recipe/delayed_scaling.cu -------------------------------------------------------------------------------- /transformer_engine/common/recipe/fp8_block_scaling.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/recipe/fp8_block_scaling.cu -------------------------------------------------------------------------------- /transformer_engine/common/recipe/recipe_common.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/recipe/recipe_common.cuh -------------------------------------------------------------------------------- /transformer_engine/common/rocshmem_api/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/rocshmem_api/CMakeLists.txt -------------------------------------------------------------------------------- /transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hip -------------------------------------------------------------------------------- /transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/rocshmem_api/rocshmem_waitkernel.hpp -------------------------------------------------------------------------------- /transformer_engine/common/swizzle/swizzle.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/swizzle/swizzle.cu -------------------------------------------------------------------------------- /transformer_engine/common/transformer_engine.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transformer_engine.cpp -------------------------------------------------------------------------------- /transformer_engine/common/transpose/cast_transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/cast_transpose.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/cast_transpose.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/cast_transpose.h -------------------------------------------------------------------------------- /transformer_engine/common/transpose/cast_transpose_fusion.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/cast_transpose_fusion.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/multi_cast_transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/multi_cast_transpose.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/rtc/cast_transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/rtc/cast_transpose.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/rtc/transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/rtc/transpose.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/transpose.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/transpose.cu -------------------------------------------------------------------------------- /transformer_engine/common/transpose/transpose_fusion.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/transpose/transpose_fusion.cu -------------------------------------------------------------------------------- /transformer_engine/common/util/cast.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cast.cu -------------------------------------------------------------------------------- /transformer_engine/common/util/cast_gated_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cast_gated_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/cast_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cast_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_driver.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cuda_driver.cpp -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_driver.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cuda_driver.h -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_nvml.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cuda_nvml.cpp -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_nvml.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cuda_nvml.h -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_runtime.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cuda_runtime.cpp -------------------------------------------------------------------------------- /transformer_engine/common/util/cuda_runtime.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/cuda_runtime.h -------------------------------------------------------------------------------- /transformer_engine/common/util/dequantize_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/dequantize_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/handle_manager.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/handle_manager.h -------------------------------------------------------------------------------- /transformer_engine/common/util/logging.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/logging.h -------------------------------------------------------------------------------- /transformer_engine/common/util/math.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/math.h -------------------------------------------------------------------------------- /transformer_engine/common/util/padding.cu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/padding.cu -------------------------------------------------------------------------------- /transformer_engine/common/util/ptx.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/ptx.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/pybind_helper.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/pybind_helper.h -------------------------------------------------------------------------------- /transformer_engine/common/util/rocm_cast_gated_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/rocm_cast_gated_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/rocm_cast_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/rocm_cast_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/rocm_dequantize_kernels.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/rocm_dequantize_kernels.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/rocm_vectorized_2d.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/rocm_vectorized_2d.cuh -------------------------------------------------------------------------------- /transformer_engine/common/util/rtc.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/rtc.cpp -------------------------------------------------------------------------------- /transformer_engine/common/util/rtc.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/rtc.h -------------------------------------------------------------------------------- /transformer_engine/common/util/shared_lib_wrapper.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/shared_lib_wrapper.h -------------------------------------------------------------------------------- /transformer_engine/common/util/string.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/string.h -------------------------------------------------------------------------------- /transformer_engine/common/util/string_header.h.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/string_header.h.in -------------------------------------------------------------------------------- /transformer_engine/common/util/system.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/system.h -------------------------------------------------------------------------------- /transformer_engine/common/util/vectorized_pointwise.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/util/vectorized_pointwise.h -------------------------------------------------------------------------------- /transformer_engine/common/utils.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/utils.cuh -------------------------------------------------------------------------------- /transformer_engine/common/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/common/utils.py -------------------------------------------------------------------------------- /transformer_engine/debug/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/__init__.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/__init__.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/_test_dummy_feature.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/_test_dummy_feature.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/api.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/api.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/disable_fp8_gemm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/disable_fp8_gemm.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/disable_fp8_layer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/disable_fp8_layer.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/fake_quant.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/fake_quant.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/log_fp8_tensor_stats.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/log_fp8_tensor_stats.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/log_tensor_stats.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/log_tensor_stats.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/per_tensor_scaling.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/per_tensor_scaling.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/utils/__init__.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/utils/stats_buffer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/utils/stats_buffer.py -------------------------------------------------------------------------------- /transformer_engine/debug/features/utils/stats_computation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/features/utils/stats_computation.py -------------------------------------------------------------------------------- /transformer_engine/debug/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/pytorch/__init__.py -------------------------------------------------------------------------------- /transformer_engine/debug/pytorch/debug_quantization.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/pytorch/debug_quantization.py -------------------------------------------------------------------------------- /transformer_engine/debug/pytorch/debug_state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/pytorch/debug_state.py -------------------------------------------------------------------------------- /transformer_engine/debug/pytorch/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/debug/pytorch/utils.py -------------------------------------------------------------------------------- /transformer_engine/jax/MANIFEST.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/MANIFEST.in -------------------------------------------------------------------------------- /transformer_engine/jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/__init__.py -------------------------------------------------------------------------------- /transformer_engine/jax/activation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/activation.py -------------------------------------------------------------------------------- /transformer_engine/jax/attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/attention.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/__init__.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/activation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/activation.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/attention.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/base.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/gemm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/gemm.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/misc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/misc.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/normalization.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/normalization.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/quantization.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/quantization.py -------------------------------------------------------------------------------- /transformer_engine/jax/cpp_extensions/softmax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/cpp_extensions/softmax.py -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions.h -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/activation.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/activation.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/attention.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/attention.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/cublas.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/cublas.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/cudnn.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/cudnn.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/ffi.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/ffi.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/ffi.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/ffi.h -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/gemm.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/gemm.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/misc.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/misc.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/misc.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/misc.h -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/normalization.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/normalization.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/pybind.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/pybind.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/quantization.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/quantization.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/softmax.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/softmax.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/utils.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/utils.cpp -------------------------------------------------------------------------------- /transformer_engine/jax/csrc/extensions/utils.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/csrc/extensions/utils.h -------------------------------------------------------------------------------- /transformer_engine/jax/dense.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/dense.py -------------------------------------------------------------------------------- /transformer_engine/jax/flax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/flax/__init__.py -------------------------------------------------------------------------------- /transformer_engine/jax/flax/module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/flax/module.py -------------------------------------------------------------------------------- /transformer_engine/jax/flax/transformer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/flax/transformer.py -------------------------------------------------------------------------------- /transformer_engine/jax/layernorm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/layernorm.py -------------------------------------------------------------------------------- /transformer_engine/jax/layernorm_dense.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/layernorm_dense.py -------------------------------------------------------------------------------- /transformer_engine/jax/layernorm_mlp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/layernorm_mlp.py -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/quantize/__init__.py -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/dequantizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/quantize/dequantizer.py -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/helper.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/quantize/helper.py -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/metadata.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/quantize/metadata.py -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/quantizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/quantize/quantizer.py -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/scaling_modes.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/quantize/scaling_modes.py -------------------------------------------------------------------------------- /transformer_engine/jax/quantize/tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/quantize/tensor.py -------------------------------------------------------------------------------- /transformer_engine/jax/setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/setup.py -------------------------------------------------------------------------------- /transformer_engine/jax/sharding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/sharding.py -------------------------------------------------------------------------------- /transformer_engine/jax/softmax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/softmax.py -------------------------------------------------------------------------------- /transformer_engine/jax/util.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/jax/util.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/MANIFEST.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/MANIFEST.in -------------------------------------------------------------------------------- /transformer_engine/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/dot_product_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/dot_product_attention/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/dot_product_attention/backends.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/dot_product_attention/backends.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/dot_product_attention/softmax.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/dot_product_attention/softmax.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/dot_product_attention/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/dot_product_attention/utils.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/inference.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/inference.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/multi_head_attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/multi_head_attention.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/attention/rope.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/attention/rope.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/constants.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/constants.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/cpp_extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/cpp_extensions/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/cpp_extensions/fused_attn.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/cpp_extensions/fused_attn.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/cpp_extensions/gemm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/cpp_extensions/gemm.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/cpu_offload.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/cpu_offload.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/cross_entropy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/cross_entropy.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/common.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/common.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/common.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/common.h -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions.h -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/activation.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/activation.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/apply_rope.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/attention.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/attention.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/bias.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/bias.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/cast.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/cast.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/gemm.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/gemm.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/misc.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/misc.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/normalization.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/normalization.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/padding.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/padding.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/permutation.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/permutation.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/pybind.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/pybind.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/recipe.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/recipe.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/rocshmem_comm.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/softmax.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/softmax.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/extensions/transpose.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/extensions/transpose.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/pybind.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/pybind.h -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/quantizer.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/quantizer.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/type_converters.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/type_converters.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/util.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/util.cpp -------------------------------------------------------------------------------- /transformer_engine/pytorch/csrc/util.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/csrc/util.h -------------------------------------------------------------------------------- /transformer_engine/pytorch/distributed.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/distributed.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/float8_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/float8_tensor.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/fp8.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/fp8.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/graph.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/graph.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/jit.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/jit.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/_common.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/base.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/fp8_padding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/fp8_padding.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/fp8_unpadding.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/fp8_unpadding.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/grouped_linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/grouped_linear.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/layernorm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/layernorm.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/layernorm_linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/layernorm_linear.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/layernorm_mlp.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/layernorm_mlp.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/linear.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/module/rmsnorm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/module/rmsnorm.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/numerics_debug.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/numerics_debug.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/_common.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/activation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/activation.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/add_in_place.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/add_in_place.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/all_gather.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/all_gather.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/all_reduce.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/all_reduce.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/basic_linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/basic_linear.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/bias.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/bias.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/identity.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/identity.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/layer_norm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/layer_norm.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/make_extra_output.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/make_extra_output.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/quantize.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/quantize.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/reduce_scatter.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/reduce_scatter.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/reshape.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/reshape.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/basic/rmsnorm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/basic/rmsnorm.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/fused/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/fused/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/fused/backward_linear_add.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/fused/backward_linear_add.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/fuser.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/fuser.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/linear.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/op.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/op.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/ops/sequential.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/ops/sequential.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/optimizers/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/optimizers/fused_adam.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/optimizers/fused_adam.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/optimizers/fused_sgd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/optimizers/fused_sgd.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/optimizers/multi_tensor_apply.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/optimizers/multi_tensor_apply.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/permutation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/permutation.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/setup.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/_internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/_internal/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/float8_blockwise_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/float8_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/float8_tensor.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/mxfp8_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/mxfp8_tensor.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/quantized_tensor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/quantized_tensor.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/tensor/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/tensor/utils.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/transformer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/transformer.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton/cross_entropy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton/cross_entropy.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton/permutation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton/permutation.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton_kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton_kernels/__init__.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton_kernels/cast.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton_kernels/cast.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton_kernels/cast_transpose.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton_kernels/cast_transpose.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton_kernels/common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton_kernels/common.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton_kernels/layernorm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton_kernels/layernorm.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton_kernels/norm_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton_kernels/norm_common.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/triton_kernels/rmsnorm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/triton_kernels/rmsnorm.py -------------------------------------------------------------------------------- /transformer_engine/pytorch/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROCm/TransformerEngine/HEAD/transformer_engine/pytorch/utils.py --------------------------------------------------------------------------------