├── .nojekyll ├── apex ├── contrib │ ├── __init__.py │ ├── test │ │ ├── __init__.py │ │ ├── fmha │ │ │ └── __init__.py │ │ ├── xentropy │ │ │ └── __init__.py │ │ ├── bottleneck │ │ │ └── __init__.py │ │ ├── clip_grad │ │ │ └── __init__.py │ │ ├── cudnn_gbn │ │ │ └── __init__.py │ │ ├── focal_loss │ │ │ ├── __init__.py │ │ │ └── test_focal_loss.py │ │ ├── group_norm │ │ │ └── __init__.py │ │ ├── index_mul_2d │ │ │ └── __init__.py │ │ ├── layer_norm │ │ │ └── __init__.py │ │ ├── optimizers │ │ │ └── __init__.py │ │ ├── peer_memory │ │ │ └── __init__.py │ │ ├── transducer │ │ │ └── __init__.py │ │ ├── conv_bias_relu │ │ │ └── __init__.py │ │ ├── multihead_attn │ │ │ ├── __init__.py │ │ │ ├── test_mha_fused_softmax.py │ │ │ ├── test_self_multihead_attn_norm_add.py │ │ │ └── test_fast_self_multihead_attn_bias.py │ │ └── fused_dense │ │ │ └── test_fused_dense.py │ ├── fmha │ │ └── __init__.py │ ├── group_norm │ │ └── __init__.py │ ├── nccl_allocator │ │ ├── __init__.py │ │ ├── README.md │ │ └── nccl_allocator.py │ ├── clip_grad │ │ └── __init__.py │ ├── layer_norm │ │ ├── __init__.py │ │ └── layer_norm.py │ ├── cudnn_gbn │ │ └── __init__.py │ ├── index_mul_2d │ │ └── __init__.py │ ├── sparsity │ │ ├── COPYRIGHT │ │ ├── __init__.py │ │ ├── permutation_search_kernels │ │ │ └── __init__.py │ │ └── permutation_tests │ │ │ ├── runtime_table.sh │ │ │ └── unstructured_study.sh │ ├── multihead_attn │ │ ├── MHA_bwd.png │ │ ├── MHA_fwd.png │ │ ├── __init__.py │ │ ├── README.md │ │ └── mask_softmax_dropout_func.py │ ├── peer_memory │ │ └── __init__.py │ ├── optimizers │ │ └── __init__.py │ ├── xentropy │ │ ├── __init__.py │ │ └── softmax_xentropy.py │ ├── transducer │ │ └── __init__.py │ ├── conv_bias_relu │ │ └── __init__.py │ ├── csrc │ │ ├── group_norm_v2 │ │ │ ├── gn_cuda_inst_1024_1280.cu │ │ │ ├── gn_cuda_inst_1024_1920.cu │ │ │ ├── gn_cuda_inst_1024_320.cu │ │ │ ├── gn_cuda_inst_1024_640.cu │ │ │ ├── gn_cuda_inst_1024_960.cu │ │ │ ├── gn_cuda_inst_256_1280.cu │ │ │ ├── 
gn_cuda_inst_256_1920.cu │ │ │ ├── gn_cuda_inst_256_2560.cu │ │ │ ├── gn_cuda_inst_256_640.cu │ │ │ ├── gn_cuda_inst_4096_320.cu │ │ │ ├── gn_cuda_inst_4096_640.cu │ │ │ ├── gn_cuda_inst_4096_960.cu │ │ │ ├── gn_cuda_inst_64_1280.cu │ │ │ ├── gn_cuda_inst_64_2560.cu │ │ │ ├── gn_utils.cpp │ │ │ ├── gn.hpp │ │ │ ├── generate_gn_cuda_inst.py │ │ │ ├── gn_cuda.cu │ │ │ └── gn_utils.hpp │ │ ├── nccl_p2p │ │ │ ├── nccl_version_check.cu │ │ │ ├── nccl_version.cpp │ │ │ ├── nccl_p2p.cpp │ │ │ └── nccl_p2p_cuda.cuh │ │ ├── group_norm │ │ │ ├── group_norm_nhwc_one_pass_10.cu │ │ │ ├── group_norm_nhwc_one_pass_112.cu │ │ │ ├── group_norm_nhwc_one_pass_12.cu │ │ │ ├── group_norm_nhwc_one_pass_120.cu │ │ │ ├── group_norm_nhwc_one_pass_128.cu │ │ │ ├── group_norm_nhwc_one_pass_14.cu │ │ │ ├── group_norm_nhwc_one_pass_16.cu │ │ │ ├── group_norm_nhwc_one_pass_160.cu │ │ │ ├── group_norm_nhwc_one_pass_20.cu │ │ │ ├── group_norm_nhwc_one_pass_24.cu │ │ │ ├── group_norm_nhwc_one_pass_26.cu │ │ │ ├── group_norm_nhwc_one_pass_28.cu │ │ │ ├── group_norm_nhwc_one_pass_30.cu │ │ │ ├── group_norm_nhwc_one_pass_32.cu │ │ │ ├── group_norm_nhwc_one_pass_4.cu │ │ │ ├── group_norm_nhwc_one_pass_40.cu │ │ │ ├── group_norm_nhwc_one_pass_42.cu │ │ │ ├── group_norm_nhwc_one_pass_48.cu │ │ │ ├── group_norm_nhwc_one_pass_56.cu │ │ │ ├── group_norm_nhwc_one_pass_60.cu │ │ │ ├── group_norm_nhwc_one_pass_64.cu │ │ │ ├── group_norm_nhwc_one_pass_70.cu │ │ │ ├── group_norm_nhwc_one_pass_8.cu │ │ │ ├── group_norm_nhwc_one_pass_80.cu │ │ │ ├── group_norm_nhwc_one_pass_84.cu │ │ │ ├── group_norm_nhwc_one_pass_96.cu │ │ │ └── group_norm_nhwc_one_pass_98.cu │ │ ├── groupbn │ │ │ └── cuda_utils.h │ │ ├── optimizers │ │ │ ├── fused_lamb_cuda.cpp │ │ │ ├── multi_tensor_distopt_lamb.cpp │ │ │ └── multi_tensor_distopt_adam.cpp │ │ ├── gpu_direct_storage │ │ │ ├── gds_pybind.cpp │ │ │ └── gds.h │ │ ├── nccl_allocator │ │ │ └── NCCLAllocator.cpp │ │ ├── peer_memory │ │ │ ├── peer_memory.cpp │ │ │ └── 
peer_memory_cuda.cuh │ │ ├── focal_loss │ │ │ └── focal_loss_cuda.cpp │ │ ├── xentropy │ │ │ └── interface.cpp │ │ ├── transducer │ │ │ ├── transducer_joint.cpp │ │ │ └── transducer_loss.cpp │ │ └── fmha │ │ │ └── src │ │ │ └── fmha │ │ │ └── mask.h │ ├── torchsched │ │ ├── inductor │ │ │ └── __init__.py │ │ ├── ops │ │ │ └── __init__.py │ │ ├── passes │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── config.py │ ├── bottleneck │ │ └── __init__.py │ ├── examples │ │ ├── gpu_direct_storage │ │ │ ├── example_save.py │ │ │ ├── example_load.py │ │ │ ├── benchmark_save.py │ │ │ └── benchmark_load.py │ │ └── nccl_allocator │ │ │ ├── change_cuda_allocator.py │ │ │ ├── allreduce.py │ │ │ ├── cache.py │ │ │ └── toy_ddp.py │ ├── groupbn │ │ └── __init__.py │ ├── focal_loss │ │ ├── __init__.py │ │ └── focal_loss.py │ └── gpu_direct_storage │ │ ├── README.md │ │ └── __init__.py ├── transformer │ ├── testing │ │ └── __init__.py │ ├── amp │ │ └── __init__.py │ ├── pipeline_parallel │ │ ├── __init__.py │ │ ├── schedules │ │ │ └── __init__.py │ │ └── _timers.py │ ├── _ucc_util.py │ ├── _data │ │ └── __init__.py │ ├── layers │ │ └── __init__.py │ ├── log_util.py │ ├── functional │ │ └── __init__.py │ ├── __init__.py │ ├── enums.py │ ├── utils.py │ ├── tensor_parallel │ │ ├── utils.py │ │ └── __init__.py │ └── README.md ├── mlp │ ├── __init__.py │ └── mlp.py ├── fused_dense │ └── __init__.py ├── multi_tensor_apply │ ├── __init__.py │ └── multi_tensor_apply.py ├── normalization │ └── __init__.py ├── optimizers │ └── __init__.py ├── _autocast_utils.py └── __init__.py ├── tests ├── L0 │ ├── run_optimizers │ │ └── __init__.py │ └── run_transformer │ │ ├── __init__.py │ │ ├── test_transformer_utils.py │ │ └── test_data.py ├── L1 │ ├── cross_product_distributed │ │ └── run.sh │ ├── cross_product │ │ └── run.sh │ └── common │ │ └── compare.py ├── distributed │ ├── amp_master_params │ │ ├── run.sh │ │ ├── compare.py │ │ └── amp_master_params.py │ ├── DDP │ │ ├── run_race_test.sh │ │ └── 
ddp_race_condition_test.py │ └── synced_batchnorm │ │ ├── unit_test.sh │ │ └── test_batchnorm1d.py └── docker_extension_builds │ └── run.sh ├── requirements_dev.txt ├── .clang-format ├── docs ├── source │ ├── _static │ │ ├── img │ │ │ └── nv-pytorch2.png │ │ └── css │ │ │ └── pytorch_theme.css │ ├── layernorm.rst │ ├── optimizers.rst │ ├── _templates │ │ └── layout.html │ └── index.rst └── Makefile ├── examples ├── simple │ └── distributed │ │ ├── run.sh │ │ ├── README.md │ │ └── distributed_data_parallel.py ├── README.md ├── docker │ ├── Dockerfile │ └── README.md └── dcgan │ └── README.md ├── requirements.txt ├── .gitmodules ├── .git-blame-ignore-revs ├── .pre-commit-config.yaml ├── csrc ├── megatron │ ├── fused_weight_gradient_dense.cpp │ ├── scaled_softmax.cpp │ ├── scaled_upper_triang_masked_softmax.cpp │ ├── generic_scaled_masked_softmax.cpp │ └── scaled_upper_triang_masked_softmax_cuda.cu ├── flatten_unflatten.cpp ├── static_switch.h ├── update_scale_hysteresis.cu └── multi_tensor_adagrad.cu ├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── pyproject.toml ├── LICENSE └── .gitignore /.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/fmha/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/xentropy/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/transformer/testing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/L0/run_optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/L0/run_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/bottleneck/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/clip_grad/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/cudnn_gbn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/focal_loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/group_norm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/index_mul_2d/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /apex/contrib/test/layer_norm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/peer_memory/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/transducer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/mlp/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import * 2 | -------------------------------------------------------------------------------- /apex/contrib/test/conv_bias_relu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/test/multihead_attn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /apex/contrib/fmha/__init__.py: -------------------------------------------------------------------------------- 1 | from .fmha import FMHAFun 2 | -------------------------------------------------------------------------------- /apex/fused_dense/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_dense import * 2 | 
-------------------------------------------------------------------------------- /apex/contrib/group_norm/__init__.py: -------------------------------------------------------------------------------- 1 | from .group_norm import * 2 | -------------------------------------------------------------------------------- /apex/contrib/nccl_allocator/__init__.py: -------------------------------------------------------------------------------- 1 | from .nccl_allocator import * 2 | -------------------------------------------------------------------------------- /apex/contrib/clip_grad/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_grad import clip_grad_norm_ 2 | -------------------------------------------------------------------------------- /apex/contrib/layer_norm/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer_norm import FastLayerNorm 2 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | flake8>=3.7.9 3 | Sphinx>=3.0.3 -------------------------------------------------------------------------------- /apex/contrib/cudnn_gbn/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_norm import GroupBatchNorm2d 2 | -------------------------------------------------------------------------------- /apex/contrib/index_mul_2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .index_mul_2d import index_mul_2d 2 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 
2 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparse_masklib import create_mask 2 | from .asp import ASP 3 | -------------------------------------------------------------------------------- /tests/L1/cross_product_distributed/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cp ../common/* . 4 | bash run_test.sh distributed $1 5 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | # Start with a built-in style and modify it 2 | BasedOnStyle: Google 3 | 4 | # Overrides 5 | ColumnLimit: 120 6 | -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/MHA_bwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/master/apex/contrib/multihead_attn/MHA_bwd.png -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/MHA_fwd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/master/apex/contrib/multihead_attn/MHA_fwd.png -------------------------------------------------------------------------------- /docs/source/_static/img/nv-pytorch2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/apex/master/docs/source/_static/img/nv-pytorch2.png -------------------------------------------------------------------------------- /examples/simple/distributed/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m 
torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py 3 | -------------------------------------------------------------------------------- /apex/contrib/peer_memory/__init__.py: -------------------------------------------------------------------------------- 1 | from .peer_memory import PeerMemoryPool 2 | from .peer_halo_exchanger_1d import PeerHaloExchanger1d 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cxxfilt>=0.2.0 2 | tqdm>=4.28.1 3 | numpy>=1.15.3 4 | PyYAML>=5.1 5 | pytest>=3.5.1 6 | packaging>=14.0 7 | torch>=2.6.0 8 | -------------------------------------------------------------------------------- /apex/multi_tensor_apply/__init__.py: -------------------------------------------------------------------------------- 1 | from .multi_tensor_apply import MultiTensorApply 2 | 3 | multi_tensor_applier = MultiTensorApply(2048 * 32) 4 | -------------------------------------------------------------------------------- /apex/transformer/amp/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer.amp.grad_scaler import GradScaler 2 | 3 | 4 | __all__ = [ 5 | "GradScaler", 6 | ] 7 | -------------------------------------------------------------------------------- /apex/contrib/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .fp16_optimizer import FP16_Optimizer 2 | from .fused_adam import FusedAdam 3 | from .fused_lamb import FusedLAMB 4 | -------------------------------------------------------------------------------- /apex/contrib/xentropy/__init__.py: -------------------------------------------------------------------------------- 1 | from .softmax_xentropy import SoftmaxCrossEntropyLoss 2 | 3 | 4 | __all__ = [ 5 | "SoftmaxCrossEntropyLoss", 6 | ] 7 | 
-------------------------------------------------------------------------------- /apex/contrib/transducer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transducer import TransducerJoint 2 | from .transducer import TransducerLoss 3 | from . import _transducer_ref 4 | -------------------------------------------------------------------------------- /tests/distributed/amp_master_params/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py 3 | 4 | python compare.py 5 | -------------------------------------------------------------------------------- /tests/distributed/DDP/run_race_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py 4 | -------------------------------------------------------------------------------- /apex/contrib/conv_bias_relu/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_bias_relu import ( 2 | ConvBiasReLU, 3 | ConvBias, 4 | ConvBiasMaskReLU, 5 | ConvFrozenScaleBiasReLU, 6 | ) 7 | -------------------------------------------------------------------------------- /apex/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_layer_norm import ( 2 | FusedLayerNorm, 3 | MixedFusedLayerNorm, 4 | FusedRMSNorm, 5 | MixedFusedRMSNorm, 6 | ) 7 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/permutation_search_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .call_permutation_search_kernels import accelerated_search_for_good_permutation 2 | from .permutation_utilities 
import sum_after_2_to_4 3 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1280.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(1024, 1280) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_1920.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(1024, 1920) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_320.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(1024, 320) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_640.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(1024, 640) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_1024_960.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(1024, 960) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- 
/apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1280.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(256, 1280) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_1920.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(256, 1920) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_2560.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(256, 2560) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_256_640.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(256, 640) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_320.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(4096, 320) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_640.cu: 
-------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(4096, 640) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_4096_960.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(4096, 960) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_1280.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(64, 1280) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda_inst_64_2560.cu: -------------------------------------------------------------------------------- 1 | #include "gn_cuda_host_template.cuh" 2 | 3 | namespace group_norm_v2 { 4 | 5 | GN_CUDA_INST_DEFINE(64, 2560) 6 | 7 | } // namespace group_norm_v2 8 | -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/__init__.py: -------------------------------------------------------------------------------- 1 | from .self_multihead_attn import SelfMultiheadAttn 2 | from .encdec_multihead_attn import EncdecMultiheadAttn 3 | from .mask_softmax_dropout_func import fast_mask_softmax_dropout_func 4 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/inductor/__init__.py: -------------------------------------------------------------------------------- 1 | """Scheduling abstractions on 
PyTorch Inductor level.""" 2 | 3 | from apex.contrib.torchsched.inductor.graph import patch_graph_lowering 4 | 5 | __all__ = ["patch_graph_lowering"] 6 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/ops/__init__.py: -------------------------------------------------------------------------------- 1 | """Custom PyTorch operators.""" 2 | 3 | import torch 4 | 5 | __all__: list[str] = [] 6 | 7 | # Register custom operators 8 | torch.ops.import_module("apex.contrib.torchsched.ops.layer_norm") 9 | -------------------------------------------------------------------------------- /apex/contrib/bottleneck/__init__.py: -------------------------------------------------------------------------------- 1 | from .bottleneck import Bottleneck, SpatialBottleneck 2 | from .halo_exchangers import ( 3 | HaloExchangerNoComm, 4 | HaloExchangerAllGather, 5 | HaloExchangerSendRecv, 6 | HaloExchangerPeer, 7 | ) 8 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/passes/__init__.py: -------------------------------------------------------------------------------- 1 | """Customized compiler passes.""" 2 | 3 | from __future__ import annotations 4 | 5 | from apex.contrib.torchsched.passes.pre_grad_passes import pre_grad_custom_pass 6 | 7 | __all__ = ["pre_grad_custom_pass"] 8 | -------------------------------------------------------------------------------- /tests/L1/cross_product/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/" 4 | # DATADIR="/opt/home/apex/examples/imagenet/" 5 | cp ../common/* . 
6 | bash run_test.sh single_gpu $1 7 | -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/example_save.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import apex.contrib.gpu_direct_storage as gds 3 | 4 | for size in [128, 1024, 8192]: 5 | x = torch.linspace(0, 1, size, device = "cuda") 6 | with gds.GDSFile(f"{size}.data", "w") as f: 7 | f.save_data(x) 8 | -------------------------------------------------------------------------------- /apex/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_sgd import FusedSGD 2 | from .fused_adam import FusedAdam 3 | from .fused_novograd import FusedNovoGrad 4 | from .fused_lamb import FusedLAMB 5 | from .fused_adagrad import FusedAdagrad 6 | from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb 7 | -------------------------------------------------------------------------------- /apex/transformer/pipeline_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func 2 | from apex.transformer.pipeline_parallel.schedules.common import build_model 3 | 4 | 5 | __all__ = [ 6 | "get_forward_backward_func", 7 | "build_model", 8 | ] 9 | -------------------------------------------------------------------------------- /apex/transformer/_ucc_util.py: -------------------------------------------------------------------------------- 1 | from torch import distributed as dist 2 | 3 | HAS_UCC = hasattr(dist, "is_ucc_available") and dist.is_ucc_available() 4 | if not HAS_UCC: 5 | try: 6 | import torch_ucc 7 | 8 | HAS_UCC = True 9 | except ImportError: 10 | HAS_UCC = False 11 | -------------------------------------------------------------------------------- /apex/contrib/groupbn/__init__.py: 
-------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | import bnp 4 | from .batch_norm import BatchNorm2d_NHWC 5 | 6 | del torch 7 | del bnp 8 | del batch_norm 9 | except ImportError: 10 | print("apex was installed without --bnp flag, contrib.groupbn is not available") 11 | -------------------------------------------------------------------------------- /apex/transformer/_data/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler 2 | from apex.transformer._data._batchsampler import MegatronPretrainingSampler 3 | 4 | 5 | __all__ = [ 6 | "MegatronPretrainingRandomSampler", 7 | "MegatronPretrainingSampler", 8 | ] 9 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_version_check.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | 3 | // This file is used to check the version of NCCL detected. 
4 | #include 5 | 6 | #include 7 | 8 | std::tuple get_nccl_version() { return {int(NCCL_MAJOR), int(NCCL_MINOR)}; } 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "apex/contrib/csrc/multihead_attn/cutlass"] 2 | path = apex/contrib/csrc/multihead_attn/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | branch = v1.2.0 5 | [submodule "apex/contrib/csrc/cudnn-frontend"] 6 | path = apex/contrib/csrc/cudnn-frontend 7 | url = https://github.com/NVIDIA/cudnn-frontend.git 8 | -------------------------------------------------------------------------------- /apex/contrib/focal_loss/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | import focal_loss_cuda 4 | from .focal_loss import focal_loss 5 | 6 | del torch 7 | del focal_loss_cuda 8 | del focal_loss 9 | except ImportError: 10 | print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available") 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_version.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | // This file is used to check the version of NCCL detected. 
3 | #include 4 | 5 | #include 6 | 7 | std::tuple get_nccl_version(); 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_nccl_version", &get_nccl_version); } 10 | -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/example_load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import apex.contrib.gpu_direct_storage as gds 3 | 4 | for size in [128, 1024, 8192]: 5 | x = torch.empty(size, device = "cuda") 6 | with gds.GDSFile(f"{size}.data", "r") as f: 7 | f.load_data(x) 8 | xx = torch.linspace(0, 1, size, device = "cuda") 9 | assert(torch.allclose(x, xx)) 10 | -------------------------------------------------------------------------------- /docs/source/layernorm.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.normalization.fused_layer_norm 5 | =================================== 6 | 7 | .. automodule:: apex.normalization 8 | .. currentmodule:: apex.normalization 9 | 10 | .. FusedAdam 11 | ---------- 12 | 13 | .. autoclass:: FusedLayerNorm 14 | :members: 15 | 16 | .. autoclass:: FusedRMSNorm 17 | :members: 18 | -------------------------------------------------------------------------------- /apex/transformer/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
2 | from apex.transformer.layers.layer_norm import FastLayerNorm 3 | from apex.transformer.layers.layer_norm import FusedLayerNorm 4 | from apex.transformer.layers.layer_norm import MixedFusedLayerNorm 5 | 6 | 7 | __all__ = [ 8 | "FastLayerNorm", 9 | "FusedLayerNorm", 10 | "MixedFusedLayerNorm", 11 | ] 12 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_10.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 10, /* THREADS_PER_BLOCK */ 640) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_112.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 112, /* THREADS_PER_BLOCK */ 448) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_12.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 12, /* THREADS_PER_BLOCK */ 384) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_120.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 120, /* THREADS_PER_BLOCK */ 480) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_128.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 128, /* THREADS_PER_BLOCK */ 512) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_14.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 14, /* THREADS_PER_BLOCK */ 224) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_16.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 16, /* THREADS_PER_BLOCK */ 256) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_160.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 160, /* THREADS_PER_BLOCK */ 640) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_20.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 20, /* THREADS_PER_BLOCK */ 640) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_24.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 24, /* THREADS_PER_BLOCK */ 384) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_26.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 26, /* THREADS_PER_BLOCK */ 416) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_28.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 28, /* THREADS_PER_BLOCK */ 448) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_30.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 30, /* THREADS_PER_BLOCK */ 480) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_32.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 32, /* THREADS_PER_BLOCK */ 512) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_4.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 4, /* THREADS_PER_BLOCK */ 128) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_40.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 40, /* THREADS_PER_BLOCK */ 640) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_42.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 42, /* THREADS_PER_BLOCK */ 672) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_48.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 48, /* THREADS_PER_BLOCK */ 384) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_56.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 56, /* THREADS_PER_BLOCK */ 448) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_60.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 60, /* THREADS_PER_BLOCK */ 480) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_64.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 64, /* THREADS_PER_BLOCK */ 512) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_70.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 70, /* THREADS_PER_BLOCK */ 560) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_8.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 8, /* THREADS_PER_BLOCK */ 128) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_80.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 80, /* THREADS_PER_BLOCK */ 640) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_84.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 84, /* THREADS_PER_BLOCK */ 672) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_96.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 96, /* THREADS_PER_BLOCK */ 768) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm/group_norm_nhwc_one_pass_98.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
3 | * SPDX-License-Identifier: BSD-3-Clause 4 | */ 5 | 6 | #include "group_norm_nhwc_bwd_one_pass_kernel.cuh" 7 | #include "group_norm_nhwc_fwd_one_pass_kernel.cuh" 8 | #include "macros.h" 9 | 10 | GN_FWD_BWD_ONE_PASS_DEFINITION(/* CHANNELS_PER_GROUP */ 98, /* THREADS_PER_BLOCK */ 392) 11 | -------------------------------------------------------------------------------- /apex/contrib/csrc/groupbn/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | #ifndef CUDA_UTILS_H 3 | #define CUDA_UTILS_H 4 | 5 | namespace at { 6 | namespace cuda { 7 | 8 | namespace utils { 9 | 10 | static inline int MaxSharedMemoryPerMultiprocessor(int device_id) { 11 | return getDeviceProperties(device_id)->sharedMemPerMultiprocessor; 12 | } 13 | 14 | } // namespace utils 15 | } // namespace cuda 16 | } // namespace at 17 | 18 | #endif 19 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | This directory contains examples illustrating Apex mixed precision and distributed tools. 2 | 3 | **Note for users of the pre-unification API**: 4 | `deprecated_api` contains examples illustrating the old (pre-unified) APIs. These APIs will be removed soon, and users are strongly encouraged to switch. The separate mixed precision tools called `Amp` and `FP16_Optimizer` in the old API are exposed via different flags/optimization levels in the new API. 
5 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/change_cuda_allocator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import apex.contrib.nccl_allocator as nccl_allocator 3 | 4 | nccl_allocator.init() 5 | nrep = 6 6 | pool = nccl_allocator.create_nccl_mem_pool() 7 | with nccl_allocator.nccl_mem(pool): 8 | for i in range(nrep): 9 | out = torch.randn(1024).cuda() 10 | 11 | for i in range(nrep): 12 | out = torch.randn(1024).cuda() 13 | 14 | torch.cuda.empty_cache() 15 | torch.cuda.empty_cache() 16 | -------------------------------------------------------------------------------- /apex/transformer/log_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def get_transformer_logger(name: str) -> logging.Logger: 6 | """Return ``logging.getLogger`` for ``name`` with its final extension stripped. 7 | 8 | NOTE(review): ``os.path.splitext`` removes everything from the last ``.`` onward, so a 9 | dotted module name such as ``"a.b.c"`` maps to logger ``"a.b"`` -- confirm this is intended. 10 | """ 11 | name_wo_ext = os.path.splitext(name)[0] 12 | return logging.getLogger(name_wo_ext) 13 | 14 | 15 | def set_logging_level(verbosity) -> None: 16 | """Change logging severity. 17 | 18 | Args: 19 | verbosity: level passed unchanged to ``setLevel`` of apex's ``_library_root_logger`` 20 | (anything ``logging.Logger.setLevel`` accepts, e.g. ``logging.INFO``). 21 | """ 22 | from apex import _library_root_logger 23 | 24 | _library_root_logger.setLevel(verbosity) 25 | -------------------------------------------------------------------------------- /docs/source/optimizers.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | apex.optimizers 5 | =================================== 6 | 7 | .. automodule:: apex.optimizers 8 | .. currentmodule:: apex.optimizers 9 | 10 | .. FusedAdam 11 | ---------- 12 | 13 | .. autoclass:: FusedAdam 14 | :members: 15 | 16 | .. autoclass:: FusedLAMB 17 | :members: 18 | 19 | .. autoclass:: FusedNovoGrad 20 | :members: 21 | 22 | .. 
autoclass:: FusedSGD 23 | :members: 24 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Commits to ignore in git-blame 2 | # These commits are bulk formatting or refactoring changes that should be skipped when viewing blame history 3 | 4 | # Add pre-commit and GitHub Actions workflow for it (#1949) 5 | 1f20398756f0eeba37d6887a2d3f65e0687ec94f 6 | # Remove github actions config of pre-commit in favor of pre-commit ci (#1958) 7 | 27e0e8951352d9d58c88b2895cd8f2c752bda963 8 | # Enable Ruff pre-commit hooks (#1957) 9 | 16fadfe71c0d57312351c2d8b056251a0c8ce1ef 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-clang-format 3 | rev: v21.1.7 # Or pin to your preferred clang-format version 4 | hooks: 5 | - id: clang-format 6 | files: \.(c|h|cpp|hpp|proto|cu|cuh)$ 7 | exclude: ^(apex/contrib/csrc/multihead_attn/cutlass|apex/contrib/csrc/cudnn-frontend)/ 8 | 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: v0.14.9 11 | hooks: 12 | - id: ruff-check 13 | args: ["--fix"] 14 | - id: ruff-format 15 | types_or: [python] 16 | exclude: "examples" 17 | -------------------------------------------------------------------------------- /apex/transformer/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer.functional.fused_rope import ( 2 | fused_apply_rotary_pos_emb, 3 | fused_apply_rotary_pos_emb_cached, 4 | fused_apply_rotary_pos_emb_thd, 5 | fused_apply_rotary_pos_emb_2d, 6 | ) 7 | from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax 8 | 9 | __all__ = [ 10 | "FusedScaleMaskSoftmax", 11 | "fused_apply_rotary_pos_emb", 12 | 
"fused_apply_rotary_pos_emb_cached", 13 | "fused_apply_rotary_pos_emb_thd", 14 | "fused_apply_rotary_pos_emb_2d", 15 | ] 16 | -------------------------------------------------------------------------------- /tests/distributed/synced_batchnorm/unit_test.sh: -------------------------------------------------------------------------------- 1 | python python_single_gpu_unit_test.py 2 | python single_gpu_unit_test.py 3 | python test_batchnorm1d.py 4 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py 5 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16 6 | python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex 7 | #beware, you need a system with at least 4 gpus to test group_size 4 | #include 5 | 6 | namespace group_norm_v2 { 7 | 8 | cudaDeviceProp const& get_device_prop(int device_id) { 9 | static std::vector device_props; 10 | static std::once_flag flag; 11 | std::call_once(flag, [&] { 12 | int count; 13 | CUDA_CHECK(cudaGetDeviceCount(&count)); 14 | device_props.resize(count); 15 | for (int i = 0; i < count; i++) { 16 | CUDA_CHECK(cudaGetDeviceProperties(&device_props[i], i)); 17 | } 18 | }); 19 | return device_props.at(device_id); 20 | } 21 | 22 | } // namespace group_norm_v2 23 | -------------------------------------------------------------------------------- /apex/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer import amp 2 | from apex.transformer import functional 3 | from apex.transformer import parallel_state 4 | from apex.transformer import pipeline_parallel 5 | from apex.transformer import tensor_parallel 6 | from apex.transformer import utils 7 | from apex.transformer.enums import LayerType 8 | from apex.transformer.enums import AttnType 9 | from apex.transformer.enums import AttnMaskType 10 | 11 | 12 | __all__ = [ 13 | "amp", 14 | "functional", 15 | "parallel_state", 16 | 
"pipeline_parallel", 17 | "tensor_parallel", 18 | "utils", 19 | # enums.py 20 | "LayerType", 21 | "AttnType", 22 | "AttnMaskType", 23 | ] 24 | -------------------------------------------------------------------------------- /apex/contrib/gpu_direct_storage/README.md: -------------------------------------------------------------------------------- 1 | # APEX GPUDirect Storage 2 | 3 | This module aims to add a PyTorch extension for [GPUDirect Storage](https://developer.nvidia.com/blog/gpudirect-storage/) (GDS) support through utilizing the [cuFile](https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html) library. 4 | 5 | # Build command 6 | ``` 7 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--gpu_direct_storage" ./ 8 | ``` 9 | 10 | Alternatively: 11 | ``` 12 | python setup.py install --gpu_direct_storage 13 | ``` 14 | 15 | Check installation: 16 | ``` 17 | python -c "import torch; import apex.contrib.gpu_direct_storage" 18 | ``` 19 | -------------------------------------------------------------------------------- /csrc/megatron/fused_weight_gradient_dense.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | void wgrad_gemm_accum_fp32_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_2d, at::Tensor& d_weight); 7 | 8 | void wgrad_gemm_accum_fp16_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_2d, at::Tensor& d_weight); 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32", 12 | py::call_guard()); 13 | m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16", 14 | py::call_guard()); 15 | } 16 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/allreduce.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import apex.contrib.nccl_allocator as nccl_allocator 5 | 6 | assert os.getenv("WORLD_SIZE") is not None, "Please use: torchrun --nproc-per-node=8 allreduce.py" 7 | 8 | rank = int(os.getenv("RANK")) 9 | local_rank = int(os.getenv("LOCAL_RANK")) 10 | world_size = int(os.getenv("WORLD_SIZE")) 11 | 12 | nccl_allocator.init() 13 | 14 | torch.cuda.set_device(local_rank) 15 | dist.init_process_group(backend="nccl") 16 | pool = nccl_allocator.create_nccl_mem_pool() 17 | with nccl_allocator.nccl_mem(pool): 18 | a = torch.ones(1024 * 1024 * 2, device="cuda") 19 | dist.all_reduce(a) 20 | 21 | torch.cuda.synchronize() 22 | 23 | -------------------------------------------------------------------------------- /csrc/flatten_unflatten.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h 4 | 5 | at::Tensor flatten(std::vector tensors) { return torch::utils::flatten_dense_tensors(tensors); } 6 | 7 | std::vector unflatten(at::Tensor flat, std::vector tensors) { 8 | return torch::utils::unflatten_dense_tensors(flat, tensors); 9 | } 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("flatten", &flatten, "Flatten dense tensors", py::call_guard()); 13 | m.def("unflatten", &unflatten, "Unflatten dense tensors", py::call_guard()); 14 | } 15 | -------------------------------------------------------------------------------- /apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, 4 | const float lr, const float beta1, const float beta2, const float epsilon, const int step, 5 | const int 
bias_correction, const float weight_decay, const int grad_averaging, 6 | const int mode, const float global_grad_norm, const float max_grad_norm); 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", 10 | py::call_guard()); 11 | } 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve apex 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the Bug** 11 | 12 | **Minimal Steps/Code to Reproduce the Bug** 13 | 18 | 19 | **Expected Behavior** 20 | 21 | 22 | **Environment** 23 | 24 | -------------------------------------------------------------------------------- /apex/_autocast_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import torch 4 | 5 | 6 | __all__ = ["_cast_if_autocast_enabled"] 7 | 8 | 9 | def _get_autocast_dtypes() -> Sequence[torch.dtype]: 10 | if torch.cuda.is_bf16_supported(): 11 | return [torch.half, torch.bfloat16] 12 | return [torch.half] 13 | 14 | 15 | def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype: 16 | if not torch.is_autocast_enabled(): 17 | return torch.float or dtype 18 | else: 19 | return torch.get_autocast_gpu_dtype() 20 | 21 | 22 | def _cast_if_autocast_enabled(*args): 23 | if not torch.is_autocast_enabled(): 24 | return args 25 | else: 26 | return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) 27 | -------------------------------------------------------------------------------- /examples/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Base image must at least have pytorch and CUDA installed. 
2 | ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.03-py3 3 | FROM $BASE_IMAGE 4 | ARG BASE_IMAGE 5 | RUN echo "Installing Apex on top of ${BASE_IMAGE}" 6 | # make sure we don't overwrite some existing directory called "apex" 7 | WORKDIR /tmp/unique_for_apex 8 | # uninstall Apex if present, twice to make absolutely sure :) 9 | RUN pip uninstall -y apex || : 10 | RUN pip uninstall -y apex || : 11 | # SHA is something the user can touch to force recreation of this Docker layer, 12 | # and therefore force cloning of the latest version of Apex 13 | RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git 14 | WORKDIR /tmp/unique_for_apex/apex 15 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 16 | WORKDIR /workspace 17 | -------------------------------------------------------------------------------- /examples/simple/distributed/README.md: -------------------------------------------------------------------------------- 1 | **distributed_data_parallel.py** and **run.sh** show an example using Amp with 2 | [apex.parallel.DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) or 3 | [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) 4 | and the Pytorch multiprocess launcher script, 5 | [torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility). 6 | The use of `Amp` with DistributedDataParallel does not need to change from ordinary 7 | single-process use. The only gotcha is that wrapping your model with `DistributedDataParallel` must 8 | come after the call to `amp.initialize`. 
Test via 9 | ```bash 10 | bash run.sh 11 | ``` 12 | 13 | **This is intended purely as an instructional example, not a performance showcase.** 14 | -------------------------------------------------------------------------------- /apex/contrib/gpu_direct_storage/__init__.py: -------------------------------------------------------------------------------- 1 | from _apex_gpu_direct_storage import _GDSFile 2 | from contextlib import contextmanager 3 | 4 | 5 | @contextmanager 6 | def GDSFile(filename, mode): 7 | assert type(filename) == str 8 | assert type(mode) == str 9 | try: 10 | from apex import deprecated_warning 11 | 12 | deprecated_warning( 13 | "`gpu_direct_storage.GDSFile` is deprecated and will be removed in September 2025. " 14 | "We encourage you to use `torch.cuda.gds` module of PyTorch as a replacement. " 15 | "Its documentation is available at https://docs.pytorch.org/docs/stable/cuda.html#gpudirect-storage-prototype" 16 | ) 17 | file_handle = _GDSFile(filename, mode) 18 | yield file_handle 19 | finally: 20 | file_handle.close() 21 | del file_handle 22 | -------------------------------------------------------------------------------- /apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | // python bindings 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | py::class_>(m, "_GDSFile") 12 | .def(py::init<>()) 13 | .def(py::init()) 14 | .def("open", &apex::contrib::gds::File::open) 15 | .def("close", &apex::contrib::gds::File::close) 16 | .def("load_data", &apex::contrib::gds::File::load_data) 17 | .def("save_data", &apex::contrib::gds::File::save_data) 18 | .def("load_data_no_gds", &apex::contrib::gds::File::load_data_no_gds) 19 | .def("save_data_no_gds", &apex::contrib::gds::File::save_data_no_gds); 20 | } 21 | -------------------------------------------------------------------------------- /csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // From 2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr static bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr static bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /apex/contrib/csrc/gpu_direct_storage/gds.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace apex::contrib::gds { 11 | class File { 12 | public: 13 | File(); 14 | File(const std::string& filename, const std::string& mode); 15 | ~File(); 16 | 17 | void open(const std::string& filename, const std::string& mode); 18 | void close(); 19 | 20 | void load_data(const torch::Tensor& tensor); 21 | void save_data(const torch::Tensor& tensor); 22 | void load_data_no_gds(const torch::Tensor& tensor); 23 | void save_data_no_gds(const torch::Tensor& tensor); 24 | 25 | private: 26 | std::string filename; 27 | std::string mode; 28 | 29 | CUfileDescr_t cf_descr; 30 | CUfileHandle_t cf_handle; 31 | CUfileError_t status; 32 | 33 | int fd = -1; 34 | bool is_open = false; 35 | bool maybe_register = true; 36 | }; 37 | } // namespace apex::contrib::gds 38 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.ruff] 9 | line-length = 100 10 | ignore = [ 11 | # Sorted by occurrence count (ascending) - easier to fix first 12 | "E731", # lambda assignment (6 occurrences) 13 | "E721", # type comparison should use isinstance (8 occurrences) 14 | "E741", # ambiguous variable name (8 occurrences) 15 | "E712", # comparison to True/False (9 occurrences) 16 | "F403", # star imports used (9 occurrences) 17 | "E701", # multiple statements on one line (10 occurrences) 18 | "E711", # comparison to None should be `cond is None` (11 occurrences) 19 | "F821", # undefined name (14 occurrences) 20 | "E722", # bare except (15 occurrences) 21 | "E402", # module level import not at top of file (41 occurrences) 22 | "F401", # imported but unused (45 occurrences) 23 | "F841", # local variable assigned but never used (52 occurrences) 24 | "F405", # star imports (80 
occurrences) 25 | ] 26 | -------------------------------------------------------------------------------- /apex/transformer/enums.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import enum 16 | 17 | 18 | class LayerType(enum.Enum): 19 | encoder = 1 20 | decoder = 2 21 | 22 | 23 | class AttnType(enum.Enum): 24 | self_attn = 1 25 | cross_attn = 2 26 | 27 | 28 | class AttnMaskType(enum.Enum): 29 | padding = 1 30 | causal = 2 31 | 32 | 33 | class ModelType(enum.Enum): 34 | encoder_or_decoder = 1 35 | encoder_and_decoder = 2 36 | -------------------------------------------------------------------------------- /apex/multi_tensor_apply/multi_tensor_apply.py: -------------------------------------------------------------------------------- 1 | class MultiTensorApply(object): 2 | available = False 3 | warned = False 4 | 5 | def __init__(self, chunk_size): 6 | try: 7 | import amp_C 8 | 9 | MultiTensorApply.available = True 10 | self.chunk_size = chunk_size 11 | except ImportError as err: 12 | MultiTensorApply.available = False 13 | MultiTensorApply.import_err = err 14 | 15 | def check_avail(self): 16 | if MultiTensorApply.available == False: 17 | raise RuntimeError( 18 | "Attempted to call MultiTensorApply method, but MultiTensorApply " 19 | "is not available, 
possibly because Apex was installed without " 20 | "--cpp_ext --cuda_ext. Original import error message:", 21 | MultiTensorApply.import_err, 22 | ) 23 | 24 | def __call__(self, op, noop_flag_buffer, tensor_lists, *args): 25 | self.check_avail() 26 | 27 | return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args) 28 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = NVIDIAAPEX 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | gh-pages: 16 | git checkout gh-pages 17 | rm -rf build 18 | rm -rf source 19 | git checkout master -- . 20 | make html 21 | rm -rf ../_modules ../_sources ../_static 22 | mv -fv build/html/* ../ 23 | rm -rf build 24 | git add -A 25 | git commit -m "Generated gh-pages for `git log master -1 --pretty=short --abbrev-commit`" && git push origin gh-pages ; git checkout master 26 | 27 | .PHONY: help Makefile 28 | 29 | # Catch-all target: route all unknown targets to Sphinx using the new 30 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
31 | %: Makefile 32 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 33 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace group_norm_v2 { 8 | 9 | struct Meta { 10 | int64_t red_buffer_size; 11 | int64_t barrier_size; 12 | int BLOCK_DIM_X; 13 | int C_PER_BLOCK; 14 | int ROWS_PER_BLOCK; 15 | int VEC_ELEMS; 16 | bool LOAD_TWICE; 17 | int BLOCKS_PER_SM; 18 | bool HARDWARE_CLUSTER; 19 | int wgrad_sync_method; 20 | }; 21 | 22 | template 23 | void gn_cuda(T* out, T* x, T* w, T* b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, 24 | int channels_per_group, float* mean_var_out, float* red_buffer, unsigned* barrier, int sm_margin, 25 | cudaStream_t stream, int device_id, Meta* meta_ptr, bool meta_only); 26 | 27 | template 28 | void gn_bwd_cuda(T* grad_input, T* grad_weight, T* grad_bias, T* grad_output, T* x, T* w, T* b, float* mean_var, 29 | float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float* red_buffer, 30 | unsigned* barrier, int sm_margin, cudaStream_t stream, int device_id, Meta* meta_ptr, bool meta_only); 31 | 32 | } // namespace group_norm_v2 33 | -------------------------------------------------------------------------------- /apex/contrib/xentropy/softmax_xentropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import xentropy_cuda 4 | 5 | 6 | class SoftmaxCrossEntropyLoss(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False): 9 | losses, max_log_sum_exp = xentropy_cuda.forward(logits, labels, smoothing, half_to_float) 10 | losses.masked_fill_(labels == padding_idx, 0) 11 | 12 | ctx.save_for_backward( 13 | logits, 14 | 
max_log_sum_exp, 15 | labels, 16 | torch.FloatTensor([smoothing]), 17 | torch.LongTensor([padding_idx]), 18 | ) 19 | 20 | return losses 21 | 22 | @staticmethod 23 | def backward(ctx, grad_loss): 24 | logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors 25 | 26 | if not grad_loss.is_contiguous(): 27 | grad_loss = grad_loss.contiguous() 28 | grad_loss.masked_fill_(labels == padding_idx.item(), 0) 29 | grad_logits = xentropy_cuda.backward( 30 | grad_loss.contiguous(), logits, max_log_sum_exp, labels, smoothing.item() 31 | ) 32 | 33 | return grad_logits, None, None, None, None 34 | -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% block sidebartitle %} {{ super() }} 3 | 4 | 32 | {% endblock %} 33 | 34 | {% block footer %} {{ super() }} 35 | 36 | 51 | {% endblock %} 52 | -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/benchmark_save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timeit 3 | import torch 4 | import apex.contrib.gpu_direct_storage as gds 5 | 6 | def run_benchmark(func): 7 | sizes = [2 ** i for i in range(16, 28)] 8 | for size in sizes: 9 | torch.cuda.empty_cache() 10 | s = torch.cuda.Stream() 11 | x = torch.linspace(0, 1, size, device = "cuda") 12 | 13 | # warmup 14 | torch.cuda.synchronize() 15 | for _ in range(10): 16 | func(x, f"{size}.data") 17 | os.remove(f"{size}.data") 18 | 19 | torch.cuda.synchronize() 20 | start_time = timeit.default_timer() 21 | for _ in range(10): 22 | func(x, f"{size}.data") 23 | os.remove(f"{size}.data") 24 | torch.cuda.synchronize() 25 | end_time = timeit.default_timer() 26 | print(f"{func.__name__}: size = {size}, {end_time - start_time}") 27 | 28 | def save_data_yes_gds(tensor, 
filename): 29 | with gds.GDSFile(filename, "w") as f: 30 | f.save_data(tensor) 31 | 32 | def save_data_no_gds(tensor, filename): 33 | with gds.GDSFile(filename, "wn") as f: 34 | f.save_data_no_gds(tensor) 35 | 36 | if __name__ == '__main__': 37 | run_benchmark(torch.save) 38 | run_benchmark(save_data_yes_gds) 39 | run_benchmark(save_data_no_gds) 40 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. PyTorch documentation master file, created by 2 | sphinx-quickstart on Fri Dec 23 13:31:47 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :github_url: https://github.com/nvidia/apex 7 | 8 | Apex (A PyTorch Extension) 9 | =================================== 10 | 11 | This site contains the API documentation for Apex (https://github.com/nvidia/apex), 12 | a Pytorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible. 13 | 14 | Installation instructions can be found here: https://github.com/NVIDIA/apex#quick-start. 15 | 16 | Some other useful material, including GTC 2019 and Pytorch DevCon 2019 Slides, can be found here: https://github.com/mcarilli/mixed_precision_references. 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: Fused Optimizers 21 | 22 | optimizers 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Fused Layer Norm 27 | 28 | layernorm 29 | 30 | .. .. toctree:: 31 | :maxdepth: 1 32 | :caption: Deprecated mixed precision API 33 | fp16_util 34 | 35 | .. 
RNN 36 | 37 | Indices and tables 38 | ================== 39 | 40 | * :ref:`genindex` 41 | * :ref:`modindex` 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /apex/contrib/nccl_allocator/README.md: -------------------------------------------------------------------------------- 1 | ## General information 2 | 3 | `nccl_allocator` is a module that enables `ncclMemAlloc`[^1] to be used within PyTorch for faster NCCL NVLS collective communications. 4 | It is mainly based on `CUDAPluggableAllocator`. 5 | The context manager `nccl_allocator.nccl_mem(enabled=True)` is used as a switch between `cudaMalloc` and `ncclMemAlloc` (if `enabled=False` it will use `cudaMalloc`). 6 | 7 | [^1]: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html 8 | 9 | ### Example usage: 10 | 11 | Here is a minimal example: 12 | 13 | ``` 14 | import os 15 | import torch 16 | import torch.distributed as dist 17 | import apex.contrib.nccl_allocator as nccl_allocator 18 | 19 | rank = int(os.getenv("RANK")) 20 | local_rank = int(os.getenv("LOCAL_RANK")) 21 | world_size = int(os.getenv("WORLD_SIZE")) 22 | 23 | nccl_allocator.init() 24 | 25 | torch.cuda.set_device(local_rank) 26 | dist.init_process_group(backend="nccl") 27 | 28 | with nccl_allocator.nccl_mem(): 29 | a = torch.ones(1024 * 1024 * 2, device="cuda") 30 | dist.all_reduce(a) 31 | 32 | torch.cuda.synchronize() 33 | ``` 34 | 35 | Please visit `apex/contrib/examples/nccl_allocator` for more examples.
36 | 37 | 38 | ### IMPORTANT 39 | 40 | There are several strict requirements: 41 | - PyTorch must include PR [#112850](https://github.com/pytorch/pytorch/pull/112850) 42 | - NCCL v2.19.4 and newer 43 | - NCCL NVLS requires CUDA Driver 530 and newer (tested on 535) 44 | 45 | -------------------------------------------------------------------------------- /tests/distributed/amp_master_params/compare.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | model_params_rank0 = torch.load("rank0model.pth", map_location=lambda storage, loc: storage.cuda(0)) 4 | model_params_rank1 = torch.load("rank1model.pth", map_location=lambda storage, loc: storage.cuda(0)) 5 | master_params_rank0 = torch.load( 6 | "rank0master.pth", map_location=lambda storage, loc: storage.cuda(0) 7 | ) 8 | master_params_rank1 = torch.load( 9 | "rank1master.pth", map_location=lambda storage, loc: storage.cuda(0) 10 | ) 11 | 12 | for model_rank0, model_rank1, master_rank0, master_rank1 in zip( 13 | model_params_rank0, model_params_rank1, master_params_rank0, master_params_rank1 14 | ): 15 | assert torch.allclose(model_rank0, model_rank1), "Model param mismatch" 16 | assert torch.allclose(master_rank0, master_rank1), "Master param mismatch" 17 | # Some debugging/investigation assistance code: 18 | # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0) 19 | # offending_val_half = model_rank0.view(-1)[maxind.item()] 20 | # offending_val_float = master_rank0.view(-1)[maxind.item()] 21 | # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(), 22 | # offending_val_float.half().item()) 23 | # rtol needs to be > 2^-11 because of denormals... 
24 | assert torch.allclose(model_rank0, master_rank0.half(), rtol=0.005), "Model-master mismatch" 25 | 26 | print("OK: Model and master params match across ranks.") 27 | -------------------------------------------------------------------------------- /apex/transformer/pipeline_parallel/schedules/__init__.py: -------------------------------------------------------------------------------- 1 | from apex.transformer import parallel_state 2 | from apex.transformer.pipeline_parallel.utils import get_num_microbatches 3 | from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( 4 | forward_backward_no_pipelining, 5 | ) 6 | from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( 7 | _forward_backward_pipelining_with_interleaving, 8 | ) 9 | from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( 10 | forward_backward_pipelining_without_interleaving, 11 | ) 12 | 13 | __all__ = [ 14 | "get_forward_backward_func", 15 | ] 16 | 17 | 18 | class ExperimentalWarning(Warning): 19 | pass 20 | 21 | 22 | def get_forward_backward_func( 23 | virtual_pipeline_model_parallel_size, 24 | pipeline_model_parallel_size, 25 | ): 26 | if parallel_state.get_pipeline_model_parallel_world_size() > 1: 27 | if virtual_pipeline_model_parallel_size is not None: 28 | if get_num_microbatches() % pipeline_model_parallel_size != 0: 29 | msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule" 30 | raise RuntimeError(msg) 31 | forward_backward_func = _forward_backward_pipelining_with_interleaving 32 | else: 33 | forward_backward_func = forward_backward_pipelining_without_interleaving 34 | else: 35 | forward_backward_func = forward_backward_no_pipelining 36 | return forward_backward_func 37 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "nccl_p2p_cuda.cuh" 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id", 21 | py::call_guard()); 22 | m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm", 23 | py::call_guard()); 24 | m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, 25 | "left_right_halo_exchange_inplace", py::call_guard()); 26 | m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange", 27 | py::call_guard()); 28 | m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay", py::call_guard()); 29 | } 30 | -------------------------------------------------------------------------------- /tests/L0/run_transformer/test_transformer_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch.testing._internal import common_utils 5 | 6 | logging.getLogger("torch").setLevel(logging.WARNING) 7 | 8 | from apex.transformer import parallel_state 9 | from apex.transformer.tensor_parallel import utils 10 | from 
apex.transformer.testing.distributed_test_base import NcclDistributedTestBase 11 | 12 | logging.getLogger("apex").setLevel(logging.WARNING) 13 | 14 | 15 | class TransformerUtilsTest(NcclDistributedTestBase): 16 | def test_split_tensor_along_last_dim(self): 17 | for tensor_model_paralell_world_size in range(1, self.world_size + 1): 18 | if self.world_size % tensor_model_paralell_world_size > 0: 19 | continue 20 | parallel_state.initialize_model_parallel( 21 | tensor_model_parallel_size_=tensor_model_paralell_world_size 22 | ) 23 | 24 | device = "cpu" 25 | input_tensor = torch.randn((100, 100, 100), device=device) 26 | splits = utils.split_tensor_along_last_dim(input_tensor, 10) 27 | last_dim_shapes = torch.tensor([int(split.size()[-1]) for split in splits]) 28 | 29 | self.assertTrue( 30 | torch.equal( 31 | last_dim_shapes, 32 | torch.full((10,), 10), 33 | ), 34 | msg=f"tensor_model_paralell_world_size: {tensor_model_paralell_world_size}", 35 | ) 36 | 37 | parallel_state.destroy_model_parallel() 38 | 39 | 40 | if __name__ == "__main__": 41 | common_utils.run_tests() 42 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #ifndef _nccl_p2p_h_ 20 | #define _nccl_p2p_h_ 21 | 22 | namespace apex { 23 | namespace contrib { 24 | namespace nccl_p2p { 25 | at::Tensor get_unique_nccl_id(int n); 26 | int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks); 27 | void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, 28 | at::Tensor right_output_halo, at::Tensor left_input_halo, 29 | at::Tensor right_input_halo); 30 | std::vector left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, 31 | at::Tensor right_output_halo); 32 | void add_delay(int delay); 33 | } // namespace nccl_p2p 34 | } // namespace contrib 35 | } // namespace apex 36 | #endif 37 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import apex.contrib.nccl_allocator as nccl_allocator 3 | from pynvml.smi import nvidia_smi 4 | 5 | def set_device(dev): 6 | import ctypes 7 | handle = ctypes.CDLL("libcudart.so") 8 | result = handle.cudaSetDevice(ctypes.c_int(dev)) 9 | assert result == 0 10 | 11 | def print_used_mem(string, nvsmi, device_id = 0): 12 | print(f"{string}:", nvsmi.DeviceQuery('memory.used')['gpu'][device_id]) 13 | 14 | nccl_allocator.init() 15 | nrep = 6 16 | nccl_mem = [] 17 | 18 | set_device(0) 19 | nvsmi = nvidia_smi.getInstance() 20 | 21 | print_used_mem("", nvsmi) 22 | 23 | pool = nccl_allocator.create_nccl_mem_pool() 24 | with nccl_allocator.nccl_mem(pool): 25 | for i in range(nrep): 26 | out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB 27 | nccl_mem.append(out) 28 | 29 | print_used_mem("after nccl alloc (+>=2400)", nvsmi) # + 2400+ MB 30 | 31 | cudart_mem = [] 32 | for i in range(nrep): 33 | out = torch.randn(1024 * 1024 * 50 ).cuda() # == 200 MB 34 | 
cudart_mem.append(out) 35 | 36 | print_used_mem("after cudart alloc (+1200)", nvsmi) 37 | 38 | del cudart_mem 39 | torch.cuda.empty_cache() 40 | torch.cuda.empty_cache() 41 | print_used_mem("release cudart mem (-1200)", nvsmi) # - 1200 MB 42 | 43 | del nccl_mem 44 | nccl_mem2 = [] 45 | with nccl_allocator.nccl_mem(pool): 46 | for i in range(nrep): 47 | out = torch.randn(1024 * 1024 * 100).cuda() # >= 400 MB 48 | nccl_mem2.append(out) 49 | print_used_mem("reuse nccl cache (same)", nvsmi) # + 0 MB 50 | del nccl_mem2 51 | torch.cuda.empty_cache() 52 | print_used_mem("release nccl_mem (-2400)", nvsmi) # - 2400 MB 53 | 54 | torch.cuda.empty_cache() 55 | -------------------------------------------------------------------------------- /apex/contrib/focal_loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import focal_loss_cuda 4 | 5 | 6 | class FocalLoss(torch.autograd.Function): 7 | @staticmethod 8 | def forward( 9 | ctx, 10 | cls_output, 11 | cls_targets_at_level, 12 | num_positives_sum, 13 | num_real_classes, 14 | alpha, 15 | gamma, 16 | label_smoothing=0.0, 17 | ): 18 | loss, partial_grad = focal_loss_cuda.forward( 19 | cls_output, 20 | cls_targets_at_level, 21 | num_positives_sum, 22 | num_real_classes, 23 | alpha, 24 | gamma, 25 | label_smoothing, 26 | ) 27 | 28 | ctx.save_for_backward(partial_grad, num_positives_sum) 29 | return loss 30 | 31 | @staticmethod 32 | def backward(ctx, grad_loss): 33 | partial_grad, num_positives_sum = ctx.saved_tensors 34 | 35 | # The backward kernel is actually in-place to save memory space, 36 | # partial_grad and grad_input are the same tensor. 
37 | grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum) 38 | 39 | return grad_input, None, None, None, None, None, None 40 | 41 | 42 | def focal_loss( 43 | cls_output: torch.Tensor, 44 | cls_targets_at_level: torch.Tensor, 45 | num_positive_sum: torch.Tensor, 46 | num_real_classes: int, 47 | alpha: float, 48 | gamma: float, 49 | label_smoothing: float = 0.0, 50 | ) -> torch.Tensor: 51 | """Fused focal loss function.""" 52 | return FocalLoss.apply( 53 | cls_output, 54 | cls_targets_at_level, 55 | num_positive_sum, 56 | num_real_classes, 57 | alpha, 58 | gamma, 59 | label_smoothing, 60 | ) 61 | -------------------------------------------------------------------------------- /apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag, 4 | std::vector> tensor_lists, 5 | at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, 6 | at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction, 7 | at::Tensor step, at::Tensor per_tensor_epsilon, const int mode, 8 | at::Tensor per_tensor_decay, at::Tensor global_scale, 9 | at::Tensor global_grad_norm, const float max_grad_norm); 10 | 11 | void multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag, 12 | std::vector> tensor_lists, 13 | at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, 14 | at::Tensor update_norm_offset, at::Tensor learning_rate, 15 | at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb); 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda, 19 | "Computes update term for LAMB optimizer", py::call_guard()); 20 | m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda, 21 | "Applies update term for LAMB 
optimizer", py::call_guard()); 22 | } 23 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/generate_gn_cuda_inst.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | 4 | hw_c_list = [ 5 | (8 * 8, 1280), 6 | (8 * 8, 2560), 7 | (16 * 16, 640), 8 | (16 * 16, 1280), 9 | (16 * 16, 1920), 10 | (16 * 16, 2560), 11 | (32 * 32, 320), 12 | (32 * 32, 640), 13 | (32 * 32, 960), 14 | (32 * 32, 1280), 15 | (32 * 32, 1920), 16 | (64 * 64, 320), 17 | (64 * 64, 640), 18 | (64 * 64, 960), 19 | ] 20 | 21 | 22 | def run(): 23 | src_path = pathlib.Path(__file__).parent.absolute() 24 | 25 | for f in src_path.glob("gn_cuda_inst_*.cu"): 26 | f.unlink() 27 | 28 | for hw, c in hw_c_list: 29 | print(f"GN_CUDA_INST_DEFINE({hw}, {c})") 30 | with open(src_path / f"gn_cuda_inst_{hw}_{c}.cu", "w") as f: 31 | f.write('#include "gn_cuda_host_template.cuh"\n') 32 | f.write("\n") 33 | f.write("\n") 34 | f.write("namespace group_norm_v2 {\n") 35 | f.write("\n") 36 | f.write(f"GN_CUDA_INST_DEFINE({hw}, {c})\n") 37 | f.write("\n") 38 | f.write("} // namespace group_norm_v2\n") 39 | 40 | with open(src_path / "gn_dispatch_hw_c.hpp", "w") as f: 41 | f.write("#pragma once\n") 42 | f.write("\n") 43 | f.write("#define DISPATCH_HW_C(hw, c, HW, C, ...) 
[&] { \\\n") 44 | for hw, c in hw_c_list: 45 | f.write( 46 | f" if (hw == {hw} && c == {c}) {{ constexpr int HW = {hw}, C = {c}; return __VA_ARGS__(); }} \\\n" 47 | ) 48 | f.write( 49 | ' throw std::invalid_argument("DISPATCH_HW_C " + std::to_string(hw) + " " + std::to_string(c)); \\\n' 50 | ) 51 | f.write(" }()\n") 52 | 53 | 54 | if __name__ == "__main__": 55 | run() 56 | -------------------------------------------------------------------------------- /apex/contrib/examples/nccl_allocator/toy_ddp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.distributed as dist 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | 8 | import apex.contrib.nccl_allocator as nccl_allocator 9 | 10 | assert os.getenv("WORLD_SIZE") is not None, "Please use: torchrun --nproc-per-node=8 toy_ddp.py" 11 | 12 | class ToyModel(nn.Module): 13 | def __init__(self): 14 | super(ToyModel, self).__init__() 15 | self.net1 = nn.Linear(10, 10) 16 | self.relu = nn.ReLU() 17 | self.net2 = nn.Linear(10, 5) 18 | 19 | def forward(self, x): 20 | return self.net2(self.relu(self.net1(x))) 21 | 22 | 23 | rank = int(os.getenv("RANK")) 24 | local_rank = int(os.getenv("LOCAL_RANK")) 25 | world_size = int(os.getenv("WORLD_SIZE")) 26 | 27 | nccl_allocator.init() 28 | 29 | torch.cuda.set_device(local_rank) 30 | dist.init_process_group(backend="nccl") 31 | 32 | device = torch.device("cuda", local_rank) 33 | model = ToyModel().to(device) 34 | ddp_model = DDP(model, device_ids=[rank]) 35 | loss_fn = nn.MSELoss() 36 | optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) 37 | 38 | data_ptrs = [] 39 | pool = nccl_allocator.create_nccl_mem_pool() 40 | with nccl_allocator.nccl_mem(pool): 41 | for param in ddp_model.parameters(): 42 | param.grad = torch.empty_like(param) 43 | data_ptrs.append(param.grad.data_ptr()) 44 | 45 | for _ in range(10): 46 | 
optimizer.zero_grad(set_to_none=False) 47 | outputs = ddp_model(torch.randn(20, 10)) 48 | labels = torch.randn(20, 5).to(rank) 49 | loss_fn(outputs, labels).backward() 50 | optimizer.step() 51 | 52 | for data_ptr, param in zip(data_ptrs, ddp_model.parameters()): 53 | assert(data_ptr == param.grad.data_ptr()) 54 | dist.destroy_process_group() 55 | -------------------------------------------------------------------------------- /tests/docker_extension_builds/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | print_banner() { 4 | printf "\n\n\n\e[30m\e[42m$1\e[0m\n\n\n\n" 5 | } 6 | 7 | print_green() { 8 | printf "\e[30m\e[42m$1\e[0m\n" 9 | } 10 | 11 | print_red() { 12 | printf "\e[30m\e[41m$1\e[0m\n" 13 | } 14 | 15 | images=( 16 | "pytorch/pytorch:nightly-devel-cuda10.0-cudnn7" 17 | "pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-devel" 18 | "pytorch/pytorch:1.0.1-cuda10.0-cudnn7-devel" 19 | "pytorch/pytorch:1.0-cuda10.0-cudnn7-devel" 20 | "pytorch/pytorch:nightly-devel-cuda9.2-cudnn7" 21 | ) 22 | 23 | branch="master" 24 | 25 | # Associative array for exit codes 26 | declare -A exit_codes 27 | for image in images 28 | do 29 | exit_codes[$image]="None" 30 | done 31 | 32 | for image in "${images[@]}" 33 | do 34 | print_banner "$image" 35 | set -x 36 | docker pull $image 37 | # Trying python setup.py install instead of pip install to ensure direct access to error codes. 38 | # Maybe pip install would be ok too but this works. 39 | docker run --runtime=nvidia --rm $image /bin/bash -c "yes | pip uninstall apex; yes | pip uninstall apex; git clone https://github.com/NVIDIA/apex.git; cd apex; git checkout $branch; set -e; python setup.py install --cuda_ext --cpp_ext" 40 | exit_code=$? 
41 | set +x 42 | if [ $exit_code != 0 ] 43 | then 44 | print_red "Exit code: $exit_code" 45 | else 46 | print_green "Exit code: $exit_code" 47 | fi 48 | exit_codes[$image]=$exit_code 49 | done 50 | 51 | success=0 52 | for image in "${images[@]}" 53 | do 54 | exit_code=${exit_codes[$image]} 55 | if [ $exit_code != 0 ] 56 | then 57 | print_red "$image : $exit_code" 58 | success=1 59 | else 60 | print_green "$image : $exit_code" 61 | fi 62 | done 63 | 64 | if [ $success != 0 ] 65 | then 66 | print_red "Overall status: failure" 67 | else 68 | print_green "Overall status: success" 69 | fi 70 | 71 | exit $success 72 | -------------------------------------------------------------------------------- /apex/contrib/csrc/nccl_allocator/NCCLAllocator.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #define NCCL_CHECK(cmd) \ 8 | do { \ 9 | ncclResult_t result = cmd; \ 10 | if (result != ncclSuccess) { \ 11 | std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " + \ 12 | std::string(ncclGetErrorString(result)); \ 13 | TORCH_CHECK(false, err); \ 14 | } \ 15 | } while (0) 16 | 17 | void* nccl_alloc_plug(size_t size, int device, void* stream) { 18 | void* ptr; 19 | NCCL_CHECK(ncclMemAlloc(&ptr, size)); 20 | return ptr; 21 | } 22 | 23 | void nccl_free_plug(void* ptr, std::size_t size, int device, void* stream) { NCCL_CHECK(ncclMemFree(ptr)); } 24 | 25 | std::shared_ptr nccl_allocator; 26 | 27 | void maybe_init() { 28 | if (!nccl_allocator) { 29 | nccl_allocator = 30 | std::make_shared(nccl_alloc_plug, nccl_free_plug); 31 | } 32 | } 33 | 34 | std::shared_ptr get_nccl_allocator() { 35 | maybe_init(); 36 | return nccl_allocator; 37 | } 38 | 39 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 40 | m.def("get_nccl_allocator", []() { return get_nccl_allocator(); }); 41 | }; 42 | 
-------------------------------------------------------------------------------- /apex/transformer/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions used by both `pipeline_parallel` and `tensor_parallel`""" 2 | 3 | import torch 4 | 5 | from apex.transformer import parallel_state 6 | 7 | # `all_gather_into_tensor` is the newer name for `_all_gather_base`; 8 | # it only exists in recent versions of PyTorch. 9 | # The following 2 lines alias it for backward compatibility with 10 | # older PyTorch. 11 | if "all_gather_into_tensor" not in dir(torch.distributed): 12 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base 13 | 14 | 15 | def ensure_divisibility(numerator, denominator): 16 | """Ensure that numerator is divisible by the denominator.""" 17 | assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) 18 | 19 | 20 | def divide(numerator, denominator): 21 | """Ensure that numerator is divisible by the denominator and return 22 | the division value.""" 23 | ensure_divisibility(numerator, denominator) 24 | return numerator // denominator 25 | 26 | 27 | def split_tensor_into_1d_equal_chunks(tensor): 28 | """Return this tensor-model-parallel rank's equal-sized slice of `tensor` flattened to 1D (any remainder elements are dropped).""" 29 | data = tensor.view(-1) 30 | partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size() 31 | start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() 32 | end_index = start_index + partition_size 33 | return data[start_index:end_index] 34 | 35 | 36 | def gather_split_1d_tensor(tensor): 37 | """Inverse of `split_tensor_into_1d_equal_chunks`: all-gather the 1D chunks of every tensor-model-parallel rank into one flat tensor on the current CUDA device.""" 38 | world_size = parallel_state.get_tensor_model_parallel_world_size() 39 | numel = torch.numel(tensor) 40 | numel_gathered = world_size * numel 41 | gathered = torch.empty( 42 | numel_gathered, 43 | dtype=tensor.dtype, 44 | device=torch.cuda.current_device(), 45 | 
requires_grad=False, 46 | ) 47 | torch.distributed.all_gather_into_tensor( 48 | gathered, tensor, group=parallel_state.get_tensor_model_parallel_group() 49 | ) 50 | return gathered 51 | -------------------------------------------------------------------------------- /apex/contrib/csrc/peer_memory/peer_memory.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "peer_memory_cuda.cuh" 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw", 21 | py::call_guard()); 22 | m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw", py::call_guard()); 23 | m.def("zero", &apex::contrib::peer_memory::zero, "zero", py::call_guard()); 24 | m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address", 25 | py::call_guard()); 26 | m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers", 27 | py::call_guard()); 28 | m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half", 29 | py::call_guard()); 30 | m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float", 31 | py::call_guard()); 32 | m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int", 33 | py::call_guard()); 34 | m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d", 35 | py::call_guard()); 36 | } 37 | -------------------------------------------------------------------------------- /examples/docker/README.md: -------------------------------------------------------------------------------- 1 | ## Option 1: Create a new container with Apex 2 | 3 | **Dockerfile** installs the latest Apex on top of an existing image. Run 4 | ``` 5 | docker build -t new_image_with_apex . 6 | ``` 7 | By default, **Dockerfile** uses NVIDIA's Pytorch container as the base image, 8 | which requires an NVIDIA GPU Cloud (NGC) account. If you don't have an NGC account, you can sign up for free by following the instructions [here](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html#generating-api-key). 9 | 10 | Alternatively, you can supply your own base image via the `BASE_IMAGE` build-arg. 11 | `BASE_IMAGE` must have Pytorch and Cuda installed. 
For example, any 12 | `-devel` image for Pytorch 1.0 and later from the 13 | [official Pytorch Dockerhub](https://hub.docker.com/r/pytorch/pytorch) may be used: 14 | ``` 15 | docker build --build-arg BASE_IMAGE=pytorch/pytorch:1.3-cuda10.1-cudnn7-devel -t new_image_with_apex . 16 | ``` 17 | 18 | If you want to rebuild your image and force the latest Apex to be cloned and installed, make any small change to the `SHA` variable in **Dockerfile**. 19 | 20 | **Warning:** 21 | Currently, the non-`-devel` images on Pytorch Dockerhub do not contain the Cuda compiler `nvcc`. Therefore, 22 | images whose name does not contain `-devel` are not eligible candidates for `BASE_IMAGE`. 23 | 24 | ### Running your Apex container 25 | 26 | Like any Cuda-enabled Pytorch container, a container with Apex should be run via [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), for example: 27 | ``` 28 | docker run --runtime=nvidia -it --rm --ipc=host new_image_with_apex 29 | ``` 30 | 31 | ## Option 2: Install Apex in a running container 32 | 33 | Instead of building a new container, you can also `git clone https://github.com/NVIDIA/apex.git` on bare metal and mount the Apex repo into your container at launch, for example by running 34 | ``` 35 | docker run --runtime=nvidia -it --rm --ipc=host -v /bare/metal/apex:/apex/in/container <your_base_image> 36 | ``` 37 | then go to /apex/in/container within the running container and run 38 | ``` 39 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 
40 | ``` 41 | -------------------------------------------------------------------------------- /apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag, 4 | std::vector> tensor_lists, at::Tensor grad_scale, float lr, 5 | float beta1, float beta2, float eps, int step, int mode, int bias_correction, 6 | float weight_decay); 7 | 8 | void multi_tensor_fused_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, 9 | std::vector> tensor_lists, at::Tensor grad_scale, 10 | at::Tensor lr, float beta1, float beta2, float eps, at::Tensor step, 11 | int mode, int bias_correction, float weight_decay); 12 | 13 | void multi_tensor_fused_adam_with_param_remainders_cuda(int chunk_size, at::Tensor noop_flag, 14 | std::vector> tensor_lists, 15 | at::Tensor grad_scale, float lr, float beta1, float beta2, 16 | float eps, int step, int mode, int bias_correction, 17 | float weight_decay); 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, 21 | "CUDA kernels for multi-tensor Adam, " 22 | "with param copy", 23 | py::call_guard()); 24 | m.def("multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable_cuda, 25 | "CUDA kernels for multi-tensor Adam, " 26 | "with param copy, capturable for CUDA graph", 27 | py::call_guard()); 28 | m.def("multi_tensor_fused_adam_with_param_remainders", &multi_tensor_fused_adam_with_param_remainders_cuda, 29 | "CUDA kernel for multi-tensor Adam, " 30 | "with stored param remainders and param copy", 31 | py::call_guard()); 32 | } 33 | -------------------------------------------------------------------------------- /apex/contrib/test/multihead_attn/test_mha_fused_softmax.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 
import torch.nn.functional as F 5 | 6 | SKIP_TEST = None 7 | try: 8 | from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func 9 | except ImportError as e: 10 | SKIP_TEST = e 11 | 12 | 13 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 14 | class FusedSoftmaxTest(unittest.TestCase): 15 | def setUp(self, seed=1234): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | self.seq_length = 80 20 | self.sequences = 10 21 | self.hidden_dim = 1024 22 | self.heads = 16 23 | self.dropout_prob = 0.0 24 | 25 | self.mask = (torch.randn(self.sequences, self.seq_length) > 0).cuda() 26 | self.mask = self.mask.half() * -10000 27 | self.ref_inputs = torch.randn( 28 | self.heads * self.sequences, 29 | self.seq_length, 30 | self.seq_length, 31 | dtype=torch.float16, 32 | device=torch.device("cuda"), 33 | ).requires_grad_(True) 34 | 35 | self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True) 36 | 37 | def test_fused_softmax(self): 38 | grads = torch.randn_like(self.tst_inputs) 39 | y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length) 40 | y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2) 41 | y_ref = y_ref.view(self.sequences * self.heads, self.seq_length, self.seq_length) 42 | y_ref = F.softmax(y_ref, dim=-1) 43 | y_ref = torch._fused_dropout(y_ref, 1.0) 44 | 45 | y_tst = fast_mask_softmax_dropout_func( 46 | True, self.heads, self.tst_inputs, self.mask, True, 0.0 47 | ) 48 | y_ref[0].backward(grads) 49 | y_tst.backward(grads) 50 | 51 | torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5) 52 | torch.testing.assert_close(y_ref[0], y_tst, atol=1e-3, rtol=1e-3) 53 | torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3) 54 | 55 | 56 | if __name__ == "__main__": 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /apex/contrib/test/fused_dense/test_fused_dense.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | 4 | import torch 5 | from torch.testing._internal import common_utils 6 | from torch.testing._internal.common_device_type import instantiate_device_type_tests 7 | 8 | SKIP_TEST = None 9 | try: 10 | from apex import fused_dense 11 | except ImportError as e: 12 | SKIP_TEST = e 13 | 14 | 15 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 16 | class FusedDenseTest(common_utils.TestCase): 17 | def _test_fused_dense(self, dtype, seed=0): 18 | os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] = "0" 19 | torch.manual_seed(seed) 20 | 21 | seq_length = 512 22 | sequences = 3 23 | hidden_dim = 1024 24 | 25 | ref_inputs = torch.randn( 26 | sequences * seq_length, hidden_dim, dtype=dtype, device=torch.device("cuda") 27 | ).requires_grad_(True) 28 | 29 | tst_inputs = ref_inputs.clone().detach().requires_grad_(True) 30 | dense = fused_dense.FusedDense(1024, 3072) 31 | dense.to(dtype=dtype) 32 | dense.cuda() 33 | 34 | y_tst = dense(tst_inputs) 35 | y_ref = torch.matmul(ref_inputs, dense.weight.t()) + dense.bias 36 | dy = torch.randn_like(y_tst).to(dtype=dtype) 37 | y_tst.backward(dy) 38 | dw_ref = torch.matmul(dy.t(), ref_inputs) 39 | dx_ref = torch.matmul(dy, dense.weight.clone()) 40 | db_ref = dy.sum(0, False) 41 | 42 | torch.testing.assert_close(ref_inputs, tst_inputs, atol=1e-5, rtol=1e-5) 43 | torch.testing.assert_close(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True) 44 | torch.testing.assert_close(dw_ref, dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True) 45 | torch.testing.assert_close(dx_ref, tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True) 46 | torch.testing.assert_close(db_ref, dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True) 47 | 48 | @common_utils.parametrize("dtype", [torch.half, torch.float, torch.bfloat16]) 49 | def test_fused_dense(self, dtype): 50 | self._test_fused_dense(dtype) 51 | 52 | 53 | 
instantiate_device_type_tests(FusedDenseTest, globals(), only_for=("cuda",)) 54 | 55 | if __name__ == "__main__": 56 | common_utils.run_tests() 57 | -------------------------------------------------------------------------------- /apex/contrib/examples/gpu_direct_storage/benchmark_load.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | import torch 3 | import apex.contrib.gpu_direct_storage as gds 4 | 5 | def run_benchmark_torch_load(): 6 | sizes = [2 ** i for i in range(16, 28)] 7 | for size in sizes: 8 | torch.cuda.empty_cache() 9 | s = torch.cuda.Stream() 10 | x = torch.empty(size, device = "cuda") 11 | y = torch.linspace(0, 1, size, device = "cuda") 12 | torch.save(y, f"{size}.data") 13 | 14 | # warmup 15 | torch.cuda.synchronize() 16 | for _ in range(10): 17 | x = torch.load(f"{size}.data") 18 | 19 | torch.cuda.synchronize() 20 | start_time = timeit.default_timer() 21 | for _ in range(10): 22 | x = torch.load(f"{size}.data") 23 | torch.cuda.synchronize() 24 | end_time = timeit.default_timer() 25 | print(f"torch.load: size = {size}, {end_time - start_time}") 26 | assert(torch.allclose(x, y)) 27 | 28 | def run_benchmark(func): 29 | sizes = [2 ** i for i in range(16, 28)] 30 | for size in sizes: 31 | torch.cuda.empty_cache() 32 | s = torch.cuda.Stream() 33 | x = torch.empty(size, device = "cuda") 34 | y = torch.linspace(0, 1, size, device = "cuda") 35 | 36 | with gds.GDSFile(f"{size}.data", "w") as f: 37 | f.save_data(y) 38 | 39 | # warmup 40 | torch.cuda.synchronize() 41 | for _ in range(10): 42 | func(x, f"{size}.data") 43 | 44 | torch.cuda.synchronize() 45 | start_time = timeit.default_timer() 46 | for _ in range(10): 47 | func(x, f"{size}.data") 48 | torch.cuda.synchronize() 49 | end_time = timeit.default_timer() 50 | print(f"{func.__name__}: size = {size}, {end_time - start_time}") 51 | assert(torch.allclose(x, y)) 52 | 53 | def load_data_yes_gds(tensor, filename): 54 | with gds.GDSFile(filename, "r") 
as f: 55 | f.load_data(tensor) 56 | 57 | def load_data_no_gds(tensor, filename): 58 | with gds.GDSFile(filename, "rn") as f: 59 | f.load_data_no_gds(tensor) 60 | 61 | if __name__ == '__main__': 62 | run_benchmark_torch_load() 63 | run_benchmark(load_data_yes_gds) 64 | run_benchmark(load_data_no_gds) 65 | -------------------------------------------------------------------------------- /apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | // CUDA forward declarations 7 | 8 | std::vector focal_loss_forward_cuda(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level, 9 | const at::Tensor& num_positives_sum, const int64_t num_real_classes, 10 | const float alpha, const float gamma, const float smoothing_factor); 11 | 12 | at::Tensor focal_loss_backward_cuda(const at::Tensor& grad_output, const at::Tensor& partial_grad, 13 | const at::Tensor& num_positives_sum); 14 | 15 | // C++ interface 16 | 17 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 18 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 19 | #define CHECK_INPUT(x) \ 20 | CHECK_CUDA(x); \ 21 | CHECK_CONTIGUOUS(x) 22 | 23 | std::vector focal_loss_forward(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level, 24 | const at::Tensor& num_positives_sum, const int64_t num_real_classes, 25 | const float alpha, const float gamma, const float smoothing_factor) { 26 | CHECK_INPUT(cls_output); 27 | CHECK_INPUT(cls_targets_at_level); 28 | CHECK_INPUT(num_positives_sum); 29 | 30 | return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma, 31 | smoothing_factor); 32 | } 33 | 34 | at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& partial_grad, 35 | const at::Tensor& num_positives_sum) { 36 | CHECK_INPUT(grad_output); 37 | 
CHECK_INPUT(partial_grad); 38 | 39 | return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum); 40 | } 41 | 42 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 43 | m.def("forward", &focal_loss_forward, "Focal loss calculation forward (CUDA)", 44 | py::call_guard()); 45 | m.def("backward", &focal_loss_backward, "Focal loss calculation backward (CUDA)", 46 | py::call_guard()); 47 | } 48 | -------------------------------------------------------------------------------- /examples/dcgan/README.md: -------------------------------------------------------------------------------- 1 | # Mixed Precision DCGAN Training in PyTorch 2 | 3 | `main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/dcgan](https://github.com/pytorch/examples/tree/master/dcgan). 4 | It implements Automatic Mixed Precision (Amp) training of the DCGAN example for different datasets. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html). 5 | 6 | We introduce these changes to the PyTorch DCGAN example as described in the [Multiple models/optimizers/losses](https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses) section of the documentation:: 7 | ``` 8 | # Added after models and optimizers construction 9 | [netD, netG], [optimizerD, optimizerG] = amp.initialize( 10 | [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3) 11 | ... 12 | # loss.backward() changed to: 13 | with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled: 14 | errD_real_scaled.backward() 15 | ... 16 | with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled: 17 | errD_fake_scaled.backward() 18 | ... 
19 | with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled: 20 | errG_scaled.backward() 21 | ``` 22 | 23 | Note that we use different `loss_scalers` for each computed loss. 24 | Using a separate loss scaler per loss is [optional, not required](https://nvidia.github.io/apex/advanced.html#optionally-have-amp-use-a-different-loss-scaler-per-loss). 25 | 26 | To improve the numerical stability, we swapped `nn.Sigmoid() + nn.BCELoss()` to `nn.BCEWithLogitsLoss()`. 27 | 28 | With the new Amp API **you never need to explicitly convert your model, or the input data, to half().** 29 | 30 | "Pure FP32" training: 31 | ``` 32 | $ python main_amp.py --opt_level O0 33 | ``` 34 | Recommended mixed precision training: 35 | ``` 36 | $ python main_amp.py --opt_level O1 37 | ``` 38 | 39 | Have a look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the used arguments. 40 | 41 | To enable mixed precision training, we introduce the `--opt_level` argument. 
42 | -------------------------------------------------------------------------------- /apex/contrib/layer_norm/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import init 3 | 4 | from apex._autocast_utils import _cast_if_autocast_enabled 5 | import fast_layer_norm 6 | 7 | 8 | class FastLayerNormFN(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, x, gamma, beta, epsilon, memory_efficient=False): 11 | ctx.x_shape = x.shape 12 | ctx.memory_efficient = memory_efficient 13 | 14 | x = x.contiguous() 15 | gamma = gamma.contiguous() 16 | beta = beta.contiguous() 17 | hidden_size = gamma.numel() 18 | xmat = x.view((-1, hidden_size)) 19 | ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon) 20 | if ctx.memory_efficient: 21 | ctx.save_for_backward(ymat, gamma, None, rsigma, beta) 22 | else: 23 | ctx.save_for_backward(xmat, gamma, mu, rsigma, None) 24 | return ymat.view(x.shape) 25 | 26 | @staticmethod 27 | def backward(ctx, dy): 28 | # assert dy.is_contiguous() 29 | dy = dy.contiguous() # this happens! 
30 | x_or_y_mat, gamma, mu, rsigma, beta = ctx.saved_tensors 31 | dymat = dy.view(x_or_y_mat.shape) 32 | dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd( 33 | dymat, x_or_y_mat, mu, rsigma, gamma, beta, ctx.memory_efficient 34 | ) 35 | dx = dxmat.view(ctx.x_shape) 36 | return dx, dgamma, dbeta, None, None 37 | 38 | 39 | def _fast_layer_norm(x, weight, bias, epsilon, memory_efficient): 40 | args = _cast_if_autocast_enabled(x, weight, bias, epsilon, memory_efficient) 41 | with torch.amp.autocast("cuda", enabled=False): 42 | return FastLayerNormFN.apply(*args) 43 | 44 | 45 | class FastLayerNorm(torch.nn.Module): 46 | def __init__(self, hidden_size, eps=1e-5, memory_efficient=False): 47 | super().__init__() 48 | self.epsilon = eps 49 | self.memory_efficient = memory_efficient 50 | self.weight = torch.nn.Parameter(torch.empty(hidden_size)) 51 | self.bias = torch.nn.Parameter(torch.empty(hidden_size)) 52 | self.reset_parameters() 53 | 54 | def reset_parameters(self): 55 | init.ones_(self.weight) 56 | init.zeros_(self.bias) 57 | 58 | def forward(self, x): 59 | return _fast_layer_norm(x, self.weight, self.bias, self.epsilon, self.memory_efficient) 60 | -------------------------------------------------------------------------------- /apex/contrib/csrc/xentropy/interface.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | // CUDA forward declarations 6 | 7 | std::vector softmax_xentropy_cuda(const at::Tensor& input, const at::Tensor& labels, const float smoothing, 8 | const bool half_to_float); 9 | 10 | at::Tensor softmax_xentropy_backward_cuda(const at::Tensor& grad_loss, const at::Tensor& logits, 11 | const at::Tensor& max_log_sum_exp, const at::Tensor& labels, 12 | const float smoothing); 13 | 14 | // C++ interface 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 18 | 
#define CHECK_INPUT(x) \ 19 | CHECK_CUDA(x); \ 20 | CHECK_CONTIGUOUS(x) 21 | 22 | std::vector softmax_xentropy_forward(const at::Tensor& input, const at::Tensor& labels, 23 | const float smoothing, const bool half_to_float) { 24 | CHECK_CUDA(input); 25 | CHECK_INPUT(labels); 26 | 27 | return softmax_xentropy_cuda(input, labels, smoothing, half_to_float); 28 | } 29 | 30 | at::Tensor softmax_xentropy_backward(const at::Tensor& grad_loss, const at::Tensor& logits, 31 | const at::Tensor& max_log_sum_exp, const at::Tensor& labels, 32 | const float smoothing) { 33 | CHECK_CUDA(grad_loss); 34 | CHECK_CUDA(logits); 35 | CHECK_INPUT(max_log_sum_exp); 36 | CHECK_INPUT(labels); 37 | 38 | return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); 39 | } 40 | 41 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 42 | m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", 43 | py::call_guard()); 44 | m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", 45 | py::call_guard()); 46 | // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables 47 | py::object version = py::cast( 48 | #ifdef XENTROPY_VER 49 | XENTROPY_VER 50 | #else 51 | std::string{} 52 | #endif 53 | ); 54 | m.attr("__version__") = version; 55 | } 56 | -------------------------------------------------------------------------------- /csrc/update_scale_hysteresis.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | __global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, int* growth_tracker, int* hysteresis_tracker, 6 | const float* found_inf, double growth_factor, double backoff_factor, 7 | int growth_interval, int hysteresis) { 8 | if (*found_inf > 0) { 9 | *hysteresis_tracker -= 1; 10 | 11 | // Only reset the growth tracker when hysteresis is larger than 
zero 12 | if (*hysteresis_tracker > 0) { 13 | *growth_tracker = 0; 14 | return; 15 | } 16 | } 17 | 18 | if (*found_inf) { 19 | *current_scale = (*current_scale) * backoff_factor; 20 | *growth_tracker = 0; 21 | } else { 22 | // Entering this branch means we just carried out a successful step, 23 | // so growth_tracker is incremented before comparing to growth_interval. 24 | auto successful = (*growth_tracker) + 1; 25 | if (successful == growth_interval) { 26 | auto new_scale = static_cast((*current_scale) * growth_factor); 27 | // Do not grow the scale past fp32 bounds to inf. 28 | if (isfinite(new_scale)) { 29 | *current_scale = new_scale; 30 | } 31 | *growth_tracker = 0; 32 | } else { 33 | *growth_tracker = successful; 34 | } 35 | } 36 | 37 | // Reset the hysteresis tracker if no infs are found 38 | if (*found_inf <= 0) { 39 | *hysteresis_tracker = hysteresis; 40 | } 41 | } 42 | 43 | at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, at::Tensor growth_tracker, 44 | at::Tensor hysteresis_tracker, at::Tensor found_inf, const double growth_factor, 45 | const double backoff_factor, const int64_t growth_interval, 46 | const int hysteresis) { 47 | update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( 48 | current_scale.mutable_data_ptr(), growth_tracker.mutable_data_ptr(), 49 | hysteresis_tracker.mutable_data_ptr(), found_inf.const_data_ptr(), growth_factor, backoff_factor, 50 | growth_interval, hysteresis); 51 | 52 | AT_CUDA_CHECK(cudaGetLastError()); 53 | 54 | return current_scale; 55 | } 56 | -------------------------------------------------------------------------------- /tests/L0/run_transformer/test_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.testing 4 | from torch.testing._internal import common_utils 5 | 6 | logging.getLogger("torch").setLevel(logging.WARNING) 7 | 8 | from apex.transformer import parallel_state 9 | from 
apex.transformer.tensor_parallel import data as data_utils 10 | from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase 11 | from apex.transformer.testing.distributed_test_base import UccDistributedTestBase 12 | 13 | logging.getLogger("torch").setLevel(logging.WARNING) 14 | 15 | 16 | class BroadcastDataTestBase: 17 | def test_broadcast_data(self): 18 | tensor_model_parallel_world_size: int = self.world_size // (1 + self.world_size > 1) 19 | parallel_state.initialize_model_parallel( 20 | tensor_model_parallel_size_=tensor_model_parallel_world_size 21 | ) 22 | 23 | target_key_size = { 24 | "key1": [7, 11], 25 | "key2": [8, 2, 1], 26 | "key3": [13], 27 | "key4": [5, 1, 2], 28 | "key5": [5, 12], 29 | } 30 | keys = [k for k in target_key_size] 31 | 32 | data = {} 33 | data_t = {} 34 | with torch.no_grad(): 35 | for key in target_key_size: 36 | data[key] = torch.randint(0, 1000, size=target_key_size[key]) 37 | data_t[key] = data[key].clone() 38 | # "key_x" is supposed to be ignored. 
39 | data["key_x"] = torch.rand(5) 40 | data_t["key_x"] = data["key_x"].clone() 41 | if parallel_state.get_tensor_model_parallel_rank() != 0: 42 | data = None 43 | 44 | data_utils._check_data_types(keys, data_t, torch.int64) 45 | key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data) 46 | 47 | for key in keys: 48 | self.assertEqual(target_key_size[key], key_size[key]) 49 | 50 | broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64) 51 | for key in keys: 52 | self.assertEqual(broadcasted_data[key], data_t[key].cuda()) 53 | 54 | parallel_state.destroy_model_parallel() 55 | 56 | 57 | class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): 58 | pass 59 | 60 | 61 | class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): 62 | pass 63 | 64 | 65 | if __name__ == "__main__": 66 | common_utils.run_tests() 67 | -------------------------------------------------------------------------------- /apex/contrib/multihead_attn/README.md: -------------------------------------------------------------------------------- 1 | # Fast Multihead Attention 2 | 3 | This implementation has two main features : 4 | * A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes. 5 | * The removal of all copies and transposes found in standard implementations of Multihead Attention. 
6 | 7 | | | Python Version | C++ Version | 8 | | :----------------------------------------- | :------------: | :---------: | 9 | | Layer Norm and Residual Add Variant | X | X | 10 | | Includes Linear Biases | X | | 11 | | Reduces CPU Overheads | | X | 12 | | Fuses masking with Softmax | | X | 13 | | Removes Transposes and Copies | X | X | 14 | | Includes Self and Encoder/Decoder Variants | X | X | 15 | 16 | ## How to Instantiate 17 | 18 | `SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)` 19 | `EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)` 20 | 21 | `impl` has two options: 22 | * `fast` uses C++ Version 23 | * `default` uses Python Version 24 | 25 | ## Instructions to build on Linux 26 | 27 | ``` 28 | $ git clone https://github.com/NVIDIA/apex 29 | $ cd apex 30 | $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./ 31 | ``` 32 | ## Try Performance Tests Yourself! 33 | Perf test script is found here! 34 | ``` 35 | cd contrib/examples/multihead_attn 36 | ``` 37 | #### Fast Multihead Attention 38 | ``` 39 | python perf_test_multihead_attn.py --ref 40 | ``` 41 | #### Fast Multihead Attention with C++ Implementation 42 | ``` 43 | python perf_test_multihead_attn.py 44 | ``` 45 | #### Compare with `torch.nn.MultiheadAttention` 46 | ``` 47 | python perf_test_multihead_attn.py --native 48 | ``` 49 | #### Test your own range! 50 | ``` 51 | python perf_test_multihead_attn.py --seq-length 64 --num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5 52 | ``` 53 | 54 | ## Performance Comparisons 55 | 56 | * Performance was measured with 64-token sequence lengths on an NVIDIA TitanV card. 57 | * Time is measured across multiple layers to simulate an in-model scenario. 
58 | 59 | ![Multihead Attention Forward](MHA_fwd.png) 60 | ![Multihead Attention Backward](MHA_bwd.png) 61 | -------------------------------------------------------------------------------- /tests/L1/common/compare.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | parser = argparse.ArgumentParser(description="Compare") 5 | parser.add_argument("--opt-level", type=str) 6 | parser.add_argument("--keep-batchnorm-fp32", type=str, default=None) 7 | parser.add_argument("--loss-scale", type=str, default=None) 8 | parser.add_argument("--fused-adam", action="store_true") 9 | parser.add_argument("--use_baseline", action="store_true") 10 | args = parser.parse_args() 11 | 12 | base_file = ( 13 | str(args.opt_level) 14 | + "_" 15 | + str(args.loss_scale) 16 | + "_" 17 | + str(args.keep_batchnorm_fp32) 18 | + "_" 19 | + str(args.fused_adam) 20 | ) 21 | 22 | file_e = "True_" + base_file 23 | file_p = "False_" + base_file 24 | if args.use_baseline: 25 | file_b = "baselines/True_" + base_file 26 | 27 | dict_e = torch.load(file_e) 28 | dict_p = torch.load(file_p) 29 | if args.use_baseline: 30 | dict_b = torch.load(file_b) 31 | 32 | torch.set_printoptions(precision=10) 33 | 34 | print(file_e) 35 | print(file_p) 36 | if args.use_baseline: 37 | print(file_b) 38 | 39 | # ugly duplication here... 
40 | if not args.use_baseline: 41 | for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): 42 | assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) 43 | 44 | loss_e = dict_e["Loss"][n] 45 | loss_p = dict_p["Loss"][n] 46 | assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format( 47 | i_e, loss_e, loss_p 48 | ) 49 | print( 50 | "{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( 51 | i_e, loss_e, loss_p, dict_e["Speed"][n], dict_p["Speed"][n] 52 | ) 53 | ) 54 | else: 55 | for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): 56 | assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) 57 | 58 | loss_e = dict_e["Loss"][n] 59 | loss_p = dict_p["Loss"][n] 60 | loss_b = dict_b["Loss"][n] 61 | assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format( 62 | i_e, loss_e, loss_p 63 | ) 64 | assert loss_e == loss_b, "Iteration {}, loss_e = {}, loss_b = {}".format( 65 | i_e, loss_e, loss_b 66 | ) 67 | print( 68 | "{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( 69 | i_e, 70 | loss_b, 71 | loss_e, 72 | loss_p, 73 | dict_b["Speed"][n], 74 | dict_e["Speed"][n], 75 | dict_p["Speed"][n], 76 | ) 77 | ) 78 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/__init__.py: -------------------------------------------------------------------------------- 1 | """Graph scheduler package.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import torch 8 | import torch._inductor 9 | from torch._dynamo import list_backends 10 | from torch._dynamo import register_backend 11 | from torch._inductor.compile_fx import compile_fx_inner 12 | 13 | from .backend import get_backend 14 | 15 | if TYPE_CHECKING: 16 | from collections.abc import Callable 17 | from typing import Any 18 | 19 | from torch._ops import OpOverload 20 | 21 | __all__ = ["get_backend", "set_default_backend"] 22 | 23 | # 
Register custom operators 24 | torch.ops.import_module("apex.contrib.torchsched.ops") 25 | 26 | 27 | # Register torch-sched backend 28 | # Same API as torch._inductor.compile_fx 29 | @register_backend 30 | def torchsched( 31 | model_: torch.fx.GraphModule, 32 | example_inputs_: list[torch.Tensor], 33 | inner_compile: Callable[..., Any] = compile_fx_inner, 34 | config_patches: dict[str, Any] | None = None, 35 | decompositions: dict[OpOverload, Callable[..., Any]] | None = None, 36 | ) -> Callable: 37 | backend = get_backend(backend="torchsched", scheme="dwb") 38 | return backend(model_, example_inputs_, inner_compile, config_patches, decompositions) 39 | 40 | 41 | _SUPPORTED_BACKENDS = list_backends() 42 | _DEFAULT_BACKEND = "inductor" 43 | __torch_compile__ = torch.compile 44 | 45 | 46 | def set_default_backend(backend: str) -> None: 47 | """ 48 | Set the default backend for torch.compile. 49 | 50 | Parameters: 51 | backend (str): The backend to use as the default for torch.compile. 52 | """ 53 | global _DEFAULT_BACKEND 54 | assert backend in _SUPPORTED_BACKENDS, f"Unknown backend {backend}" 55 | _DEFAULT_BACKEND = backend 56 | 57 | 58 | def torchsched_compile( 59 | *args: object, 60 | backend: str | Callable | None = None, 61 | **kwargs: object, 62 | ) -> object: 63 | """ 64 | Wrap around the original torch.compile to support default backend. 65 | 66 | Parameters: 67 | *args (object): Positional arguments for torch.compile. 68 | backend (Union[str, Callable, None]): The backend to use. 69 | If None, the default backend is used. 70 | **kwargs (object): Additional keyword arguments for torch.compile. 71 | 72 | Returns: 73 | object: Compiler or compiled model. 
74 | """ 75 | if backend is None: 76 | backend = _DEFAULT_BACKEND 77 | return __torch_compile__(*args, backend=backend, **kwargs) 78 | 79 | 80 | # Monkey patch torch.compile to set default backend 81 | torch.compile = torchsched_compile 82 | -------------------------------------------------------------------------------- /apex/transformer/tensor_parallel/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from typing import List, Sequence 16 | 17 | import torch 18 | 19 | from apex.transformer.utils import divide 20 | 21 | 22 | def split_tensor_along_last_dim( 23 | tensor: torch.Tensor, 24 | num_partitions: int, 25 | contiguous_split_chunks: bool = False, 26 | ) -> Sequence[torch.Tensor]: 27 | """Split a tensor along its last dimension and return the chunks as a tuple. 28 | Arguments: 29 | tensor: input tensor. 30 | num_partitions: number of partitions to split the tensor 31 | contiguous_split_chunks: If True, make each chunk contiguous 32 | in memory. 33 | """ 34 | # Get the size and dimension. 35 | last_dim = tensor.dim() - 1 36 | last_dim_size = divide(tensor.size()[last_dim], num_partitions) 37 | # Split. 38 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 39 | # Note: torch.split does not create contiguous tensors by default. 
40 | if contiguous_split_chunks: 41 | return tuple(chunk.contiguous() for chunk in tensor_list) 42 | 43 | return tensor_list 44 | 45 | 46 | class VocabUtility: 47 | """Split the vocabulary into `world_size` chunks and return the 48 | first and last index of the vocabulary belonging to the `rank` 49 | partition. Note that indices are in [first, last)""" 50 | 51 | @staticmethod 52 | def vocab_range_from_per_partition_vocab_size( 53 | per_partition_vocab_size: int, rank: int, world_size: int 54 | ) -> Sequence[int]: 55 | index_f = rank * per_partition_vocab_size 56 | index_l = index_f + per_partition_vocab_size 57 | return index_f, index_l 58 | 59 | @staticmethod 60 | def vocab_range_from_global_vocab_size( 61 | global_vocab_size: int, rank: int, world_size: int 62 | ) -> Sequence[int]: 63 | per_partition_vocab_size = divide(global_vocab_size, world_size) 64 | return VocabUtility.vocab_range_from_per_partition_vocab_size( 65 | per_partition_vocab_size, rank, world_size 66 | ) 67 | -------------------------------------------------------------------------------- /apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #pragma once 18 | #include 19 | #ifndef _peer_memory_h_ 20 | #define _peer_memory_h_ 21 | 22 | namespace apex { 23 | namespace contrib { 24 | namespace peer_memory { 25 | int64_t allocate_raw(int64_t size); 26 | void free_raw(int64_t raw); 27 | void zero(int64_t raw, int64_t size); 28 | at::Tensor get_raw_ipc_address(int64_t raw); 29 | std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw); 30 | at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last); 31 | at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last); 32 | at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last); 33 | void push_pull_halos_1d( 34 | bool diagnostics, bool explicit_nhwc, 35 | int numSM, // number of SMs to use 36 | int peer_rank, // rank in spatial parallel group 37 | bool top_zero, // if top halo should be zeroed 38 | at::Tensor top_out_halo, // top output halo buffer (in local device memory, received from top neighbor) 39 | at::Tensor top_inp_transfer, // top input transfer buffer (in local peer memory) 40 | at::Tensor top_out_transfer, // top output transfer buffer (in top neighbor peer memory) 41 | at::Tensor top_inp_halo, // top input halo buffer (in local device memory, sent to top neighbor) 42 | bool btm_zero, // if btm halo should be zeroed 43 | at::Tensor btm_out_halo, // btm output halo buffer (in local device memory, received from btm neighbor) 44 | at::Tensor btm_inp_transfer, // btm input transfer buffer (in local peer memory) 45 | at::Tensor btm_out_transfer, // btm output transfer buffer (in btm neighbor peer memory) 46 | at::Tensor btm_inp_halo // btm input halo buffer (in local device memory, sent to btm neighbor) 47 | ); 48 | } // namespace peer_memory 49 | } // namespace contrib 50 | } // namespace apex 51 | #endif 52 | -------------------------------------------------------------------------------- /csrc/megatron/scaled_softmax.cpp: 
-------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | 20 | #include 21 | 22 | namespace multihead_attn { 23 | namespace fused_softmax { 24 | namespace scaled_softmax { 25 | 26 | torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); 27 | 28 | torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); 29 | 30 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { 31 | TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); 32 | TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), 33 | "Only fp16 and bf16 are supported"); 34 | 35 | return fwd_cuda(input, scale_factor); 36 | } 37 | 38 | torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { 39 | TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); 40 | TORCH_CHECK(softmax_results.dim() == 4, "expected 4D tensor"); 41 | 42 | TORCH_CHECK( 43 | (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16), 44 | "Only fp16 and bf16 are supported"); 45 | TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || 
46 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 47 | "Only fp16 and bf16 are supported"); 48 | 49 | return bwd_cuda(output_grads, softmax_results, scale_factor); 50 | } 51 | 52 | } // end namespace scaled_softmax 53 | } // end namespace fused_softmax 54 | } // end namespace multihead_attn 55 | 56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 57 | m.def("forward", &multihead_attn::fused_softmax::scaled_softmax::fwd, 58 | "Self Multihead Attention scaled, softmax -- Forward.", py::call_guard()); 59 | m.def("backward", &multihead_attn::fused_softmax::scaled_softmax::bwd, 60 | "Self Multihead Attention scaled, softmax -- Backward.", py::call_guard()); 61 | } 62 | -------------------------------------------------------------------------------- /tests/distributed/DDP/ddp_race_condition_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch.nn import Module 4 | from apex.parallel import DistributedDataParallel as DDP 5 | import argparse 6 | import os 7 | 8 | 9 | parser = argparse.ArgumentParser(description="allreduce hook example") 10 | parser.add_argument("--local_rank", default=0, type=int) 11 | args = parser.parse_args() 12 | 13 | args.distributed = False 14 | if "WORLD_SIZE" in os.environ: 15 | args.distributed = int(os.environ["WORLD_SIZE"]) > 1 16 | 17 | if args.distributed: 18 | args.gpu = args.local_rank % torch.cuda.device_count() 19 | torch.cuda.set_device(args.gpu) 20 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 21 | args.world_size = torch.distributed.get_world_size() 22 | 23 | torch.set_printoptions(precision=10) 24 | torch.manual_seed(args.local_rank) 25 | 26 | 27 | class Model(Module): 28 | def __init__(self): 29 | super(Model, self).__init__() 30 | self.a = Parameter(torch.cuda.FloatTensor(4096 * 4096).fill_(1.0)) 31 | self.b = Parameter(torch.cuda.FloatTensor(4096 * 4096).fill_(2.0)) 32 | 33 | def 
forward(self, input): 34 | return (input * self.a) * self.b 35 | 36 | 37 | model = Model() 38 | # model = DDP(model, message_size=1, gradient_predivide_factor=8.0) 39 | # model = DDP(model, delay_allreduce=True) 40 | # model = DDP(model, message_size=1, allreduce_trigger_params=[model.b]) 41 | model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3) 42 | 43 | x = torch.cuda.FloatTensor(4096 * 4096) 44 | 45 | passed = True 46 | torch.cuda.cudart().cudaProfilerStart() 47 | for i in range(10): 48 | x.fill_(i + args.local_rank) # fill x with new values every iteration for sanity 49 | model.zero_grad() 50 | out = model(x) 51 | loss = out.sum() 52 | # torch.cuda.nvtx.range_push("backward") 53 | loss.backward() 54 | # torch.cuda.nvtx.range_pop() 55 | 56 | # torch.cuda.nvtx.range_push("synchronize() + info") 57 | # torch.cuda.synchronize() 58 | print("i = {}".format(i)) 59 | 60 | def info(name, param, val): 61 | expected = val * 4096 * 4096 * (2.0 * i + 1) / 2.0 62 | actual = param.grad.data.sum().item() 63 | print( 64 | name 65 | + ": grad.data_ptr() = {}, expected sum {}, got {}".format( 66 | param.grad.data_ptr(), expected, actual 67 | ) 68 | ) 69 | return expected == actual 70 | 71 | if not info("model.a", model.module.a, 2.0): 72 | passed = False 73 | if not info("model.b", model.module.b, 1.0): 74 | passed = False 75 | # torch.cuda.nvtx.range_pop() 76 | torch.cuda.cudart().cudaProfilerStop() 77 | 78 | print("passed = ", passed) 79 | -------------------------------------------------------------------------------- /apex/contrib/test/focal_loss/test_focal_loss.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | reference_available = True 7 | try: 8 | from torchvision.ops.focal_loss import sigmoid_focal_loss 9 | except ImportError: 10 | reference_available = False 11 | 12 | SKIP_TEST = None 13 | try: 14 | from 
apex.contrib.focal_loss import focal_loss 15 | except ImportError as e: 16 | SKIP_TEST = e 17 | 18 | 19 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 20 | @unittest.skipIf( 21 | not reference_available, 22 | "Reference implementation `torchvision.ops.focal_loss.sigmoid_focal_loss` is not available.", 23 | ) 24 | class FocalLossTest(unittest.TestCase): 25 | N_SAMPLES = 12 26 | N_CLASSES = 8 27 | ALPHA = 0.24 28 | GAMMA = 2.0 29 | REDUCTION = "sum" 30 | 31 | def test_focal_loss(self) -> None: 32 | if not reference_available: 33 | self.skipTest( 34 | "This test needs `torchvision` for `torchvision.ops.focal_loss.sigmoid_focal_loss`." 35 | ) 36 | else: 37 | x = torch.randn(FocalLossTest.N_SAMPLES, FocalLossTest.N_CLASSES).cuda() 38 | with torch.no_grad(): 39 | x_expected = x.clone() 40 | x_actual = x.clone() 41 | x_expected.requires_grad_() 42 | x_actual.requires_grad_() 43 | 44 | classes = torch.randint(0, FocalLossTest.N_CLASSES, (FocalLossTest.N_SAMPLES,)).cuda() 45 | with torch.no_grad(): 46 | y = F.one_hot(classes, FocalLossTest.N_CLASSES).float() 47 | 48 | expected = sigmoid_focal_loss( 49 | x_expected, 50 | y, 51 | alpha=FocalLossTest.ALPHA, 52 | gamma=FocalLossTest.GAMMA, 53 | reduction=FocalLossTest.REDUCTION, 54 | ) 55 | 56 | actual = sum( 57 | [ 58 | focal_loss.FocalLoss.apply( 59 | x_actual[i : i + 1], 60 | classes[i : i + 1].long(), 61 | torch.ones([], device="cuda"), 62 | FocalLossTest.N_CLASSES, 63 | FocalLossTest.ALPHA, 64 | FocalLossTest.GAMMA, 65 | 0.0, 66 | ) 67 | for i in range(FocalLossTest.N_SAMPLES) 68 | ] 69 | ) 70 | 71 | # forward parity 72 | torch.testing.assert_close(expected, actual) 73 | 74 | expected.backward() 75 | actual.backward() 76 | 77 | # grad parity 78 | torch.testing.assert_close(x_expected.grad, x_actual.grad) 79 | 80 | 81 | if __name__ == "__main__": 82 | torch.manual_seed(42) 83 | unittest.main() 84 | -------------------------------------------------------------------------------- 
/examples/simple/distributed/distributed_data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from apex import amp 5 | # FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead) 6 | from apex.parallel import DistributedDataParallel 7 | 8 | parser = argparse.ArgumentParser() 9 | # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied 10 | # automatically by torch.distributed.launch. 11 | parser.add_argument("--local_rank", default=0, type=int) 12 | args = parser.parse_args() 13 | 14 | # FOR DISTRIBUTED: If we are running under torch.distributed.launch, 15 | # the 'WORLD_SIZE' environment variable will also be set automatically. 16 | args.distributed = False 17 | if 'WORLD_SIZE' in os.environ: 18 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 19 | 20 | if args.distributed: 21 | # FOR DISTRIBUTED: Set the device according to local_rank. 22 | torch.cuda.set_device(args.local_rank) 23 | 24 | # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide 25 | # environment variables, and requires that you use init_method=`env://`. 26 | torch.distributed.init_process_group(backend='nccl', 27 | init_method='env://') 28 | 29 | torch.backends.cudnn.benchmark = True 30 | 31 | N, D_in, D_out = 64, 1024, 16 32 | 33 | # Each process receives its own batch of "fake input data" and "fake target data." 34 | # The "training loop" in each process just uses this fake batch over and over. 35 | # https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic 36 | # example of distributed data sampling for both training and validation. 
37 | x = torch.randn(N, D_in, device='cuda') 38 | y = torch.randn(N, D_out, device='cuda') 39 | 40 | model = torch.nn.Linear(D_in, D_out).cuda() 41 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 42 | 43 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 44 | 45 | if args.distributed: 46 | # FOR DISTRIBUTED: After amp.initialize, wrap the model with 47 | # apex.parallel.DistributedDataParallel. 48 | model = DistributedDataParallel(model) 49 | # torch.nn.parallel.DistributedDataParallel is also fine, with some added args: 50 | # model = torch.nn.parallel.DistributedDataParallel(model, 51 | # device_ids=[args.local_rank], 52 | # output_device=args.local_rank) 53 | 54 | loss_fn = torch.nn.MSELoss() 55 | 56 | for t in range(500): 57 | optimizer.zero_grad() 58 | y_pred = model(x) 59 | loss = loss_fn(y_pred, y) 60 | with amp.scale_loss(loss, optimizer) as scaled_loss: 61 | scaled_loss.backward() 62 | optimizer.step() 63 | 64 | if args.local_rank == 0: 65 | print("final loss = ", loss) 66 | -------------------------------------------------------------------------------- /apex/transformer/tensor_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Model parallel utility interface.""" 16 | 17 | from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy 18 | 19 | from apex.transformer.tensor_parallel.data import broadcast_data 20 | 21 | from apex.transformer.tensor_parallel.layers import ( 22 | ColumnParallelLinear, 23 | RowParallelLinear, 24 | VocabParallelEmbedding, 25 | set_tensor_model_parallel_attributes, 26 | set_defaults_if_not_set_tensor_model_parallel_attributes, 27 | copy_tensor_model_parallel_attributes, 28 | ) 29 | 30 | from apex.transformer.tensor_parallel.mappings import ( 31 | copy_to_tensor_model_parallel_region, 32 | gather_from_tensor_model_parallel_region, 33 | reduce_from_tensor_model_parallel_region, 34 | scatter_to_tensor_model_parallel_region, 35 | scatter_to_sequence_parallel_region, 36 | ) 37 | 38 | from .random import ( 39 | checkpoint, 40 | get_cuda_rng_tracker, 41 | init_checkpointed_activations_memory_buffer, 42 | model_parallel_cuda_manual_seed, 43 | reset_checkpointed_activations_memory_buffer, 44 | ) 45 | 46 | from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim 47 | 48 | 49 | __all__ = [ 50 | # cross_entropy.py 51 | "vocab_parallel_cross_entropy", 52 | # data.py 53 | "broadcast_data", 54 | # layers.py 55 | "ColumnParallelLinear", 56 | "RowParallelLinear", 57 | "VocabParallelEmbedding", 58 | "set_tensor_model_parallel_attributes", 59 | "set_defaults_if_not_set_tensor_model_parallel_attributes", 60 | "copy_tensor_model_parallel_attributes", 61 | # mappings.py 62 | "copy_to_tensor_model_parallel_region", 63 | "gather_from_tensor_model_parallel_region", 64 | "reduce_from_tensor_model_parallel_region", 65 | "scatter_to_tensor_model_parallel_region", 66 | "scatter_to_sequence_parallel_region", 67 | # random.py 68 | "checkpoint", 69 | "get_cuda_rng_tracker", 70 | "init_checkpointed_activations_memory_buffer", 71 | "model_parallel_cuda_manual_seed", 72 | "reset_checkpointed_activations_memory_buffer", 73 | # utils.py 
74 | "split_tensor_along_last_dim", 75 | ] 76 | -------------------------------------------------------------------------------- /apex/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | # May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten 5 | import torch 6 | 7 | # For optimizers and normalization there is no Python fallback. 8 | # Absence of cuda backend is a hard error. 9 | # I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda 10 | # to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext 11 | # so they expect those backends to be available, but for some reason they actually aren't 12 | # available (for example because they built improperly in a way that isn't revealed until 13 | # load time) the error message is timely and visible. 14 | from . import optimizers 15 | from . import normalization 16 | 17 | if torch.distributed.is_available(): 18 | from . 
import transformer 19 | 20 | __all__ = ["optimizers", "normalization", "transformer"] 21 | 22 | # Logging utilities for apex.transformer module 23 | class RankInfoFormatter(logging.Formatter): 24 | def format(self, record): 25 | from apex.transformer.parallel_state import get_rank_info 26 | 27 | record.rank_info = get_rank_info() 28 | return super().format(record) 29 | 30 | _library_root_logger = logging.getLogger(__name__) 31 | handler = logging.StreamHandler() 32 | handler.setFormatter( 33 | RankInfoFormatter( 34 | "%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", 35 | "%y-%m-%d %H:%M:%S", 36 | ) 37 | ) 38 | _library_root_logger.addHandler(handler) 39 | _library_root_logger.propagate = False 40 | else: 41 | # Transformers require PyTorch built with distributed support 42 | __all__ = ["optimizers", "normalization"] 43 | 44 | 45 | def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: 46 | cudnn_available = torch.backends.cudnn.is_available() 47 | cudnn_version = torch.backends.cudnn.version() if cudnn_available else None 48 | if not (cudnn_available and (cudnn_version >= required_cudnn_version)): 49 | warnings.warn( 50 | f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, " 51 | f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}" 52 | ) 53 | return False 54 | return True 55 | 56 | 57 | class DeprecatedFeatureWarning(FutureWarning): 58 | pass 59 | 60 | 61 | def deprecated_warning(msg: str) -> None: 62 | if ( 63 | not torch.distributed.is_available 64 | or not torch.distributed.is_initialized() 65 | or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) 66 | ): 67 | warnings.warn(msg, DeprecatedFeatureWarning) 68 | -------------------------------------------------------------------------------- /csrc/megatron/scaled_upper_triang_masked_softmax.cpp: 
-------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | 20 | #include 21 | 22 | namespace multihead_attn { 23 | namespace fused_softmax { 24 | namespace scaled_upper_triang_masked_softmax { 25 | 26 | torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); 27 | 28 | torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); 29 | 30 | torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { 31 | TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); 32 | TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), 33 | "Only fp16 and bf16 are supported"); 34 | 35 | return fwd_cuda(input, scale_factor); 36 | } 37 | 38 | torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { 39 | TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); 40 | TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); 41 | 42 | TORCH_CHECK( 43 | (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16), 44 | "Only fp16 and bf16 are supported"); 45 | TORCH_CHECK((softmax_results.scalar_type() == 
at::ScalarType::Half) || 46 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 47 | "Only fp16 and bf16 are supported"); 48 | 49 | return bwd_cuda(output_grads, softmax_results, scale_factor); 50 | } 51 | 52 | } // end namespace scaled_upper_triang_masked_softmax 53 | } // end namespace fused_softmax 54 | } // end namespace multihead_attn 55 | 56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 57 | m.def("forward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, 58 | "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); 59 | m.def("backward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, 60 | "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); 61 | } 62 | -------------------------------------------------------------------------------- /apex/contrib/csrc/transducer/transducer_joint.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 5 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 6 | #define CHECK_INPUT(x) \ 7 | CHECK_CUDA(x); \ 8 | CHECK_CONTIGUOUS(x) 9 | 10 | std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, 11 | torch::Tensor gLen, torch::Tensor batchOffset, 12 | int64_t packedBatch, int opt, bool packOutput, bool relu, 13 | bool dropout, float dropoutProb, int tileSize); 14 | 15 | std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, 16 | torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, 17 | int maxGLen, bool packOutput, float scale); 18 | 19 | std::vector transducer_joint_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, 20 | torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch, 21 | int opt, bool packOutput, bool relu, bool dropout, 22 | float 
dropoutProb, int tileSize) { 23 | CHECK_INPUT(f); 24 | CHECK_INPUT(g); 25 | CHECK_INPUT(fLen); 26 | CHECK_INPUT(gLen); 27 | if (packOutput) CHECK_INPUT(batchOffset); 28 | return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout, 29 | dropoutProb, tileSize); 30 | } 31 | 32 | std::vector transducer_joint_backward(std::vector in, torch::Tensor fLen, 33 | torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, 34 | int maxGLen, bool packOutput, float scale) { 35 | for (auto t : in) { 36 | CHECK_INPUT(t); 37 | } 38 | CHECK_INPUT(fLen); 39 | CHECK_INPUT(gLen); 40 | if (packOutput) CHECK_INPUT(batchOffset); 41 | return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); 42 | } 43 | 44 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 45 | m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)", 46 | py::call_guard()); 47 | m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)", 48 | py::call_guard()); 49 | } 50 | -------------------------------------------------------------------------------- /apex/transformer/pipeline_parallel/_timers.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | 5 | 6 | class _Timer: 7 | """Timer.""" 8 | 9 | def __init__(self, name): 10 | self.name_ = name 11 | self.elapsed_ = 0.0 12 | self.started_ = False 13 | self.start_time = time.time() 14 | 15 | def start(self): 16 | """Start the timer.""" 17 | assert not self.started_, "timer has already been started" 18 | torch.cuda.synchronize() 19 | self.start_time = time.time() 20 | self.started_ = True 21 | 22 | def stop(self): 23 | """Stop the timer.""" 24 | assert self.started_, "timer is not started" 25 | torch.cuda.synchronize() 26 | self.elapsed_ += time.time() - self.start_time 27 | self.started_ = False 28 | 29 | def reset(self): 30 | """Reset timer.""" 31 | self.elapsed_ = 0.0 
32 | self.started_ = False 33 | 34 | def elapsed(self, reset=True): 35 | """Calculate the elapsed time.""" 36 | started_ = self.started_ 37 | # If the timing in progress, end it first. 38 | if self.started_: 39 | self.stop() 40 | # Get the elapsed time. 41 | elapsed_ = self.elapsed_ 42 | # Reset the elapsed time 43 | if reset: 44 | self.reset() 45 | # If timing was in progress, set it back. 46 | if started_: 47 | self.start() 48 | return elapsed_ 49 | 50 | 51 | class _Timers: 52 | """Group of timers.""" 53 | 54 | def __init__(self): 55 | self.timers = {} 56 | 57 | def __call__(self, name): 58 | if name not in self.timers: 59 | self.timers[name] = _Timer(name) 60 | return self.timers[name] 61 | 62 | def write(self, names, writer, iteration, normalizer=1.0, reset=False): 63 | """Write timers to a tensorboard writer""" 64 | # currently when using add_scalars, 65 | # torch.utils.add_scalars makes each timer its own run, which 66 | # polutes the runs list, so we just add each as a scalar 67 | assert normalizer > 0.0 68 | for name in names: 69 | value = self.timers[name].elapsed(reset=reset) / normalizer 70 | writer.add_scalar(name + "-time", value, iteration) 71 | 72 | def log(self, names, normalizer=1.0, reset=True): 73 | """Log a group of timers.""" 74 | assert normalizer > 0.0 75 | string = "time (ms)" 76 | for name in names: 77 | elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer 78 | string += " | {}: {:.2f}".format(name, elapsed_time) 79 | if torch.distributed.is_initialized(): 80 | if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): 81 | print(string, flush=True) 82 | else: 83 | print(string, flush=True) 84 | -------------------------------------------------------------------------------- /apex/contrib/csrc/transducer/transducer_loss.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be 
a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) \ 8 | CHECK_CUDA(x); \ 9 | CHECK_CONTIGUOUS(x) 10 | 11 | std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, 12 | torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, 13 | int blankIdx, int opt, bool packedInput); 14 | 15 | torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, 16 | torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, 17 | torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, 18 | int opt, bool fuseSoftmaxBackward, bool packedInput); 19 | 20 | std::vector transducer_loss_forward(torch::Tensor x, torch::Tensor label, torch::Tensor fLen, 21 | torch::Tensor yLen, torch::Tensor batchOffset, int maxFLen, 22 | int blankIdx, int opt, bool packedInput) { 23 | CHECK_INPUT(x); 24 | CHECK_INPUT(label); 25 | CHECK_INPUT(fLen); 26 | CHECK_INPUT(yLen); 27 | if (packedInput) CHECK_INPUT(batchOffset); 28 | return transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset, maxFLen, blankIdx, opt, packedInput); 29 | } 30 | 31 | torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta, 32 | torch::Tensor fLen, torch::Tensor yLen, torch::Tensor label, 33 | torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, 34 | bool fuseSoftmaxBackward, bool packedInput) { 35 | CHECK_INPUT(x); 36 | CHECK_INPUT(label); 37 | CHECK_INPUT(lossGrad); 38 | CHECK_INPUT(alpha); 39 | CHECK_INPUT(beta); 40 | CHECK_INPUT(fLen); 41 | CHECK_INPUT(yLen); 42 | if (packedInput) CHECK_INPUT(batchOffset); 43 | 44 | return transducer_loss_cuda_backward(x, lossGrad, alpha, beta, fLen, yLen, label, batchOffset, maxFLen, blankIdx, opt, 45 | fuseSoftmaxBackward, packedInput); 46 | } 47 | 48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 49 | m.def("forward", 
&transducer_loss_forward, "transducer loss forward (CUDA)", 50 | py::call_guard()); 51 | m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", 52 | py::call_guard()); 53 | } 54 | -------------------------------------------------------------------------------- /docs/source/_static/css/pytorch_theme.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 3 | } 4 | 5 | /* Default header fonts are ugly */ 6 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 7 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 8 | } 9 | 10 | /* Use white for docs background */ 11 | .wy-side-nav-search { 12 | background-color: #fff; 13 | } 14 | 15 | .wy-nav-content-wrap, .wy-menu li.current > a { 16 | background-color: #fff; 17 | } 18 | 19 | @media screen and (min-width: 1400px) { 20 | .wy-nav-content-wrap { 21 | background-color: rgba(0, 0, 0, 0.0470588); 22 | } 23 | 24 | .wy-nav-content { 25 | background-color: #fff; 26 | } 27 | } 28 | 29 | /* Fixes for mobile */ 30 | .wy-nav-top { 31 | background-color: #fff; 32 | background-image: url('../img/apex.jpg'); 33 | background-repeat: no-repeat; 34 | background-position: center; 35 | padding: 0; 36 | margin: 0.4045em 0.809em; 37 | color: #333; 38 | } 39 | 40 | .wy-nav-top > a { 41 | display: none; 42 | } 43 | 44 | @media screen and (max-width: 768px) { 45 | .wy-side-nav-search>a img.logo { 46 | height: 60px; 47 | } 48 | } 49 | 50 | /* This is needed to ensure that logo above search scales properly */ 51 | .wy-side-nav-search a { 52 | display: block; 53 | } 54 | 55 | /* This ensures that multiple constructors will remain in separate lines. 
*/ 56 | .rst-content dl:not(.docutils) dt { 57 | display: table; 58 | } 59 | 60 | /* Use our red for literals (it's very similar to the original color) */ 61 | .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { 62 | color: #F05732; 63 | } 64 | 65 | .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, 66 | .rst-content code.xref, a .rst-content tt, a .rst-content code { 67 | color: #404040; 68 | } 69 | 70 | /* Change link colors (except for the menu) */ 71 | 72 | a { 73 | color: #F05732; 74 | } 75 | 76 | a:hover { 77 | color: #F05732; 78 | } 79 | 80 | 81 | a:visited { 82 | color: #D44D2C; 83 | } 84 | 85 | .wy-menu a { 86 | color: #b3b3b3; 87 | } 88 | 89 | .wy-menu a:hover { 90 | color: #b3b3b3; 91 | } 92 | 93 | /* Default footer text is quite big */ 94 | footer { 95 | font-size: 80%; 96 | } 97 | 98 | footer .rst-footer-buttons { 99 | font-size: 125%; /* revert footer settings - 1/80% = 125% */ 100 | } 101 | 102 | footer p { 103 | font-size: 100%; 104 | } 105 | 106 | /* For hidden headers that appear in TOC tree */ 107 | /* see http://stackoverflow.com/a/32363545/3343043 */ 108 | .rst-content .hidden-section { 109 | display: none; 110 | } 111 | 112 | nav .hidden-section { 113 | display: inherit; 114 | } 115 | 116 | .wy-side-nav-search>div.version { 117 | color: #000; 118 | } 119 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/permutation_tests/runtime_table.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUTDIR="results/runtime_logs" 4 | mkdir -p $OUTDIR 5 | 6 | R1000=random,1000 7 | CS=channel_swap,0 8 | CS_100=channel_swap,100 9 | OSG2=optimize_stripe_groups,8,0 10 | OSG2_100=optimize_stripe_groups,8,100 11 | OSG2_1000=optimize_stripe_groups,8,1000 12 | OSG3=optimize_stripe_groups,12,0 13 | OSG3_100=optimize_stripe_groups,12,100 14 | OSG3_1000=optimize_stripe_groups,12,1000 15 | 16 | for cols in "32" "64" "128"
"256"; do 17 | echo "$cols x $cols" 18 | python3 permutation_test.py --channels $cols --filters $cols --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 | tee "${OUTDIR}/runtime_${cols}x${cols}.log" 19 | let "rows = $cols * 2" 20 | echo "$cols x $rows" 21 | python3 permutation_test.py --channels $cols --filters $rows --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 $OSG3 $OSG3_100 $OSG3_1000 | tee "${OUTDIR}/runtime_${cols}x${rows}.log" 22 | done 23 | 24 | # 2048x2048 is too large for OSG3 25 | echo "2048 x 2048" 26 | python3 permutation_test.py --channels 2048 --filters 2048 --pretty_print=False $R1000 $CS $CS_100 $OSG2 $OSG2_100 $OSG2_1000 | tee "${OUTDIR}/runtime_2048x2048.log" 27 | 28 | 29 | ############### collect results into a .csv file 30 | echo "Gathering results ..." 31 | 32 | # efficacy and runtime from one strategy and size 33 | get_results() { 34 | local strategy=$1 35 | local cols=$2 36 | local rows=$3 37 | local OUTFILE=$4 38 | 39 | grep "$strategy," "$OUTDIR/runtime_${cols}x${rows}.log" | awk -F "," '{printf "%s,%s,",$3,$4}' >> $OUTFILE 40 | } 41 | 42 | # prepare output file headers 43 | OUTFILE="results/runtimes.csv" 44 | printf "Columns," > $OUTFILE 45 | for cols in "32" "64" "128" "256"; do 46 | printf "$cols,$cols,$cols,$cols," >> $OUTFILE 47 | done 48 | printf "2048,2048\n" >> $OUTFILE 49 | 50 | printf "Rows," >> $OUTFILE 51 | for cols in "32" "64" "128" "256"; do 52 | let "rows = $cols * 2" 53 | printf "$cols,$cols,$rows,$rows," >> $OUTFILE 54 | done 55 | printf "2048,2048\n" >> $OUTFILE 56 | 57 | printf "Metric," >> $OUTFILE 58 | for cols in "32" "64" "128" "256"; do # one Efficacy,Runtime pair for the cols x cols log and one for the cols x rows log 59 | printf "Efficacy,Runtime,Efficacy,Runtime," >> $OUTFILE 60 | done 61 | printf "Efficacy,Runtime\n" >> $OUTFILE 62 | 63 | # gather data in a reasonable order 64 | for strategy in "$R1000" "$CS" "$CS_100" "$OSG2" "$OSG2_100" "$OSG2_1000" "$OSG3" "$OSG3_100" "$OSG3_1000"; do 65 | strategy=$(echo "$strategy" | sed
's/,/_/g') # replace commas with underscores, as they'll appear in the results logs 66 | printf "$strategy," >> $OUTFILE 67 | for cols in "32" "64" "128" "256"; do 68 | get_results "$strategy" "$cols" "$cols" "$OUTFILE" 69 | let "rows = $cols * 2" 70 | get_results "$strategy" "$cols" "$rows" "$OUTFILE" 71 | done 72 | 73 | get_results "$strategy" "2048" "2048" "$OUTFILE" 74 | 75 | printf "\n" >> $OUTFILE 76 | done 77 | 78 | echo "Done! $OUTFILE" 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | apex.egg-info 2 | dist 3 | build 4 | docs/build 5 | *~ 6 | __pycache__ 7 | .vscode 8 | 9 | # Copied from https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | -------------------------------------------------------------------------------- /csrc/megatron/generic_scaled_masked_softmax.cpp: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include 18 | #include 19 | 20 | #include 21 | 22 | namespace multihead_attn { 23 | namespace fused_softmax { 24 | namespace generic_scaled_masked_softmax { 25 | 26 | torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); 27 | 28 | torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); 29 | 30 | torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { 31 | TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); 32 | TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), 33 | "Only fp16 and bf16 are supported"); 34 | TORCH_CHECK(mask.dim() == 4, "expected 4D tensor"); 35 | 36 | return fwd_cuda(input, mask, scale_factor); 37 | } 38 | 39 | torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { // both inputs must be 4D, matching the shape checks in fwd 40 | TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); 41 | TORCH_CHECK(softmax_results.dim() == 4, "expected 4D tensor"); 42 | 43 | TORCH_CHECK( 44 | (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16), 45 | "Only fp16 and bf16 are supported"); 46 | TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || 47 | (softmax_results.scalar_type() == at::ScalarType::BFloat16), 48 | "Only fp16 and bf16 are supported"); 49 | 50 | return bwd_cuda(output_grads, softmax_results, scale_factor); 51 | } 52 | 53 | } // end namespace generic_scaled_masked_softmax 54 | } // end namespace fused_softmax 55 | } // end namespace multihead_attn 56 | 57 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 58 | m.def("forward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd, 59 | "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); 60 | 61 | m.def("backward",
&multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd, 62 | "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); 63 | } 64 | -------------------------------------------------------------------------------- /apex/contrib/nccl_allocator/nccl_allocator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import _apex_nccl_allocator 4 | 5 | from contextlib import nullcontext 6 | 7 | 8 | __all__ = ["init", "nccl_mem", "create_nccl_mem_pool"] 9 | 10 | 11 | def get_func_args(func): 12 | import inspect 13 | 14 | sig = inspect.signature(func) 15 | return [arg.name for arg in sig.parameters.values()] 16 | 17 | 18 | def create_nccl_mem_pool(symmetric: bool | None = None) -> torch.cuda.MemPool: 19 | _allocator = _apex_nccl_allocator.get_nccl_allocator() 20 | if symmetric is None: 21 | _pool = torch.cuda.MemPool(_allocator) 22 | else: 23 | if "symmetric" in get_func_args(torch.cuda.MemPool): 24 | _pool = torch.cuda.MemPool(_allocator, symmetric=symmetric) 25 | elif "symm_mem" in get_func_args(torch.cuda.MemPool): 26 | # This path handles argument name divergence between 27 | # nvidia pytorch and the official pytorch. 
28 | _pool = torch.cuda.MemPool(_allocator, symm_mem=symmetric) 29 | else: 30 | raise ValueError( 31 | "symmetric setting with torch.cuda.MemPool requires higher PyTorch version" 32 | ) 33 | return _pool 34 | 35 | 36 | def init() -> None: 37 | os.environ["NCCL_NVLS_ENABLE"] = "1" 38 | os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "0" 39 | 40 | 41 | class nccl_mem: 42 | def __init__(self, pool, enabled=True, device=None, group=None): 43 | self.device = None 44 | self.group = None 45 | self.mem_context = None 46 | self.pool = pool 47 | 48 | if enabled: 49 | if device is None: 50 | self.device = torch.device("cuda", torch.cuda.current_device()) 51 | elif isinstance(device, int): 52 | self.device = torch.device("cuda", device) 53 | elif isinstance(device, str): 54 | assert "cuda" in device, "only cuda devices are supported" 55 | self.device = torch.device(device) 56 | 57 | if group is None: 58 | self.group = torch.distributed.distributed_c10d._get_default_group() 59 | else: 60 | self.group = group 61 | 62 | self.mem_context = torch.cuda.use_mem_pool(self.pool) 63 | else: 64 | self.mem_context = nullcontext() 65 | 66 | def __enter__(self): 67 | self.mem_context.__enter__() 68 | if self.group is not None: 69 | backend = self.group._get_backend(self.device) 70 | try: 71 | backend.deregister_mem_pool(self.pool) 72 | except RuntimeError: 73 | pass 74 | 75 | def __exit__(self, *args): 76 | if self.group is not None: 77 | backend = self.group._get_backend(self.device) 78 | try: 79 | backend.register_mem_pool(self.pool) 80 | except RuntimeError: 81 | pass 82 | self.mem_context.__exit__(*args) 83 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "gn.hpp" 10 | #include "gn_dispatch_hw_c.hpp" 11 | #include 
"gn_utils.hpp" 12 | 13 | #define DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, NUM_GROUPS, SILU, ...) \ 14 | [&] { \ 15 | if (num_groups == 16 && silu == true) { \ 16 | constexpr int NUM_GROUPS = 16; \ 17 | constexpr bool SILU = true; \ 18 | return __VA_ARGS__(); \ 19 | } \ 20 | if (num_groups == 32 && silu == false) { \ 21 | constexpr int NUM_GROUPS = 32; \ 22 | constexpr bool SILU = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | throw std::invalid_argument("DISPATCH_NUM_GROUPS_AND_SILU " + std::to_string(num_groups) + " " + \ 26 | std::to_string(silu)); \ 27 | }() 28 | 29 | namespace group_norm_v2 { 30 | 31 | template 32 | void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T)); 33 | 34 | template 35 | void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T)); 36 | 37 | template 38 | void gn_cuda(GN_CUDA_HOST_PARAMS(T)) { 39 | DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { 40 | DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, 41 | [&] { return gn_cuda_single_shape(GN_CUDA_HOST_ARGS); }); 42 | }); 43 | } 44 | 45 | template 46 | void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(T)) { 47 | DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] { 48 | DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, 49 | [&] { return gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_ARGS); }); 50 | }); 51 | } 52 | 53 | template void gn_cuda(GN_CUDA_HOST_PARAMS(half)); 54 | template void gn_cuda(GN_CUDA_HOST_PARAMS(__nv_bfloat16)); 55 | 56 | template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(half)); 57 | template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16)); 58 | 59 | } // namespace group_norm_v2 60 | -------------------------------------------------------------------------------- /apex/contrib/torchsched/config.py: -------------------------------------------------------------------------------- 1 | """Configurations for graph scheduler.""" 2 | 3 | import functools 4 | import os 5 | import re 6 | import sys 7 | 8 | # Debug info and dump grpahs 9 | debug 
= os.getenv("TORCH_SCHED_DEBUG", "0") == "1" 10 | 11 | # Toggle pre_grad_pass for various pattern matches 12 | enable_pre_grad_pass = False 13 | 14 | # Pre grad pass patterns 15 | pre_grad_pass_options: list[str] = ["cudnn_layer_norm"] 16 | 17 | # Number of CUDA streams used for multi-stream scheduling. 18 | # The first stream will be critical path stream, operators on non-critical path will be 19 | # scheduled to other streams in a round-robin way. 20 | num_streams = int(os.getenv("TORCH_SCHED_NUM_STREAMS", "8")) 21 | 22 | 23 | def _get_skip_post_grad_graph_ids() -> set[int]: 24 | if ids := os.environ.get("TORCH_SCHED_SKIP_GRAPH_IDS"): 25 | result: set[int] = set() 26 | for part in ids.split(","): 27 | if "-" in part: 28 | start, end = map(int, part.split("-")) 29 | result.update(range(start, end + 1)) 30 | else: 31 | result.add(int(part)) 32 | return result 33 | else: 34 | return set() 35 | 36 | 37 | # IDs of post AOT-autograd graphs that should be skipped for multi-stream scheduling. Can be 38 | # specified via TORCH_SCHED_SKIP_GRAPH_IDS environment variable in a SLURM-like scheme, e.g., 39 | # TORCH_SCHED_SKIP_GRAPH_IDS=1,2,3-5,7-10 40 | skip_post_grad_graph_ids: set[int] = _get_skip_post_grad_graph_ids() 41 | 42 | # Reduce the number of allocated CUDA Events in the generated program by: 43 | # 1. Track reference count of each CUDA Event in the scheduling phase. Skip generating CUDA Events 44 | # that have no reference counts, i.e., have not been waited by other streams; 45 | # 2. Reuse allocated CUDA Events when feasible. 46 | # This option is enable by default. 
47 | reuse_cuda_event: bool = os.getenv("TORCH_SCHED_REUSE_CUDA_EVENT", "1") == "1" 48 | 49 | 50 | @functools.lru_cache 51 | def __get_dump_code_backends_and_dir( 52 | dump_code: str | None, 53 | ) -> tuple[list[str], str | None]: 54 | pattern = r"(?:\+(?P\w+),)?(?P[\w\/\.\-\s@#~]+)" 55 | backends, dir = ["torchsched"], None 56 | if dump_code and (match := re.match(pattern, dump_code)): 57 | if backend := match.group("backend"): 58 | backends.append(backend) 59 | dir = os.path.abspath(match.group("dir")) 60 | return backends, dir 61 | 62 | 63 | # Specify dump code backend types and output directory by:: 64 | # 65 | # TORCH_SCHED_DUMP_CODE='+inductor,/dir/to/save/code' 66 | # 67 | # Where `+inductor` enables dump both Inductor and torchsched code. If omitted, only dump 68 | # torchsched code. `/dir/to/save/code` specifies a directory to dump code to. 69 | ( 70 | dump_code_backends, 71 | dump_code_dir, 72 | ) = __get_dump_code_backends_and_dir(os.getenv("TORCH_SCHED_DUMP_CODE")) 73 | 74 | from torch.utils._config_module import install_config_module # noqa: E402 75 | 76 | # adds patch, save_config, etc 77 | install_config_module(sys.modules[__name__]) 78 | -------------------------------------------------------------------------------- /tests/distributed/amp_master_params/amp_master_params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from apex import amp 5 | 6 | # FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead) 7 | from apex.parallel import DistributedDataParallel 8 | 9 | parser = argparse.ArgumentParser() 10 | # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied 11 | # automatically by torch.distributed.launch. 
12 | parser.add_argument("--local_rank", default=0, type=int) 13 | args = parser.parse_args() 14 | 15 | # FOR DISTRIBUTED: If we are running under torch.distributed.launch, 16 | # the 'WORLD_SIZE' environment variable will also be set automatically. 17 | args.distributed = False 18 | if "WORLD_SIZE" in os.environ: 19 | args.distributed = int(os.environ["WORLD_SIZE"]) > 1 20 | 21 | if args.distributed: 22 | # FOR DISTRIBUTED: Set the device according to local_rank. 23 | torch.cuda.set_device(args.local_rank) 24 | 25 | # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide 26 | # environment variables, and requires that you use init_method=`env://`. 27 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 28 | 29 | torch.manual_seed(torch.distributed.get_rank()) 30 | 31 | torch.backends.cudnn.benchmark = True 32 | 33 | N, D_in, D_out = 64, 1024, 16 34 | 35 | # Each process receives its own batch of "fake input data" and "fake target data." 36 | # The "training loop" in each process just uses this fake batch over and over. 37 | # https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic 38 | # example of distributed data sampling for both training and validation. 39 | x = torch.randn(N, D_in, device="cuda") 40 | y = torch.randn(N, D_out, device="cuda") 41 | 42 | model = torch.nn.Linear(D_in, D_out).cuda() 43 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 44 | 45 | model, optimizer = amp.initialize(model, optimizer, opt_level="O2") 46 | 47 | if args.distributed: 48 | # FOR DISTRIBUTED: After amp.initialize, wrap the model with 49 | # apex.parallel.DistributedDataParallel. 
50 | model = DistributedDataParallel(model) 51 | # torch.nn.parallel.DistributedDataParallel is also fine, with some added args: 52 | # model = torch.nn.parallel.DistributedDataParallel(model, 53 | # device_ids=[args.local_rank], 54 | # output_device=args.local_rank) 55 | 56 | loss_fn = torch.nn.MSELoss() 57 | 58 | for t in range(500): 59 | optimizer.zero_grad() 60 | y_pred = model(x) 61 | loss = loss_fn(y_pred, y) 62 | with amp.scale_loss(loss, optimizer) as scaled_loss: 63 | scaled_loss.backward() 64 | optimizer.step() 65 | 66 | if args.local_rank == 0: 67 | print("final loss = ", loss) 68 | 69 | torch.save(list(model.parameters()), "rank{}model.pth".format(torch.distributed.get_rank())) 70 | torch.save( 71 | list(amp.master_params(optimizer)), 72 | "rank{}master.pth".format(torch.distributed.get_rank()), 73 | ) 74 | -------------------------------------------------------------------------------- /apex/transformer/README.md: -------------------------------------------------------------------------------- 1 | # apex.transformer 2 | 3 | `apex.transformer` is a module which enables efficient large Transformer models at scale. 4 | 5 | `apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s module. 6 | The former is based on `megatron.mpu` and the latter is on `megatron.schedules` and `megatron.p2p_communication`. 7 | 8 | ## Tensor Model Parallel (TP) 9 | 10 | APEX's tensor model parallel utilities provides some `torch.nn.Module`'s, custom fused kernels, and PRNG state handling. 11 | See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of 12 | PRNG state handling. 13 | 14 | ## Pipeline Model Parallel (PP) 15 | APEX's pipeline model parallel functions require models to have `.set_input_tensor` because 16 | the input tensor for `.forward` method can be `None`. 
17 | 18 | The following is a rough sketch of a training script using apex PP (pipeline model parallelism). 19 | 20 | ```python 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | from apex.transformer import parallel_state 26 | from apex.transformer.pipeline_parallel import get_forward_backward_func 27 | 28 | 29 | class Model(nn.Module): 30 | 31 | ... 32 | 33 | def __init__(self, *args, **kwargs): 34 | super().__init__() 35 | pre_process = kwargs.pop("pre_process") 36 | post_process = kwargs.pop("post_process") 37 | 38 | def set_input_tensor(self, tensor): 39 | self.input_tensor = tensor 40 | 41 | def forward(self, x, ...): 42 | if parallel_state.is_pipeline_first_stage(): 43 | input = x 44 | else: 45 | input = self.input_tensor 46 | ... 47 | 48 | 49 | def model_provider_func(*args, **kwargs): 50 | return Model(*args, **kwargs) 51 | 52 | 53 | def loss_func(pred, label): 54 | loss = ... 55 | averaged_loss = average_losses_across_data_parallel_group([loss]) 56 | return loss, {'nice_loss': averaged_loss} 57 | 58 | 59 | def forward_step_func(batch, model): 60 | input, label = process_batch(batch) 61 | out = model(input) 62 | return out, partial(loss_func, label) 63 | 64 | 65 | forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size) 66 | 67 | 68 | parallel_state.initialize_model_parallel( 69 | tensor_model_parallel_size, 70 | pipeline_model_parallel_size, 71 | virtual_pipeline_model_parallel_size, 72 | ) 73 | # The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)` 74 | model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs) 75 | optimizer = ... 76 | data_loader = ...
77 | for epoch in range(num_epochs): 78 | for batch in data_loader: 79 | forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape) 80 | optimizer.step() 81 | ``` 82 | -------------------------------------------------------------------------------- /apex/mlp/mlp.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from apex._autocast_utils import _cast_if_autocast_enabled 8 | import mlp_cuda 9 | 10 | 11 | class MlpFunction(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, bias, activation, *args): 14 | output = mlp_cuda.forward(bias, activation, args) 15 | ctx.save_for_backward(*args) 16 | ctx.outputs = output 17 | ctx.bias = bias 18 | ctx.activation = activation 19 | return output[0] 20 | 21 | @staticmethod 22 | def backward(ctx, grad_o): 23 | grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors) 24 | del ctx.outputs 25 | return (None, None, *grads) 26 | 27 | 28 | def mlp_function(bias, activation, *args): 29 | autocast_args = _cast_if_autocast_enabled(bias, activation, *args) 30 | return MlpFunction.apply(*autocast_args) 31 | 32 | 33 | class MLP(torch.nn.Module): 34 | """Launch MLP in C++ 35 | 36 | Args: 37 | mlp_sizes (list of int): MLP sizes. 
Example: [1024,1024,1024] will create 2 MLP layers with shape 1024x1024 38 | bias (bool): Default True: 39 | relu (bool): Default True 40 | """ 41 | 42 | def __init__(self, mlp_sizes, bias=True, activation="relu"): 43 | super().__init__() 44 | self.num_layers = len(mlp_sizes) - 1 45 | self.mlp_sizes = copy(mlp_sizes) 46 | self.bias = 1 if bias else 0 47 | 48 | if activation == "none": 49 | self.activation = 0 50 | elif activation == "relu": 51 | self.activation = 1 52 | elif activation == "sigmoid": 53 | self.activation = 2 54 | else: 55 | raise TypeError("activation must be relu or none.") 56 | 57 | self.weights = [] 58 | self.biases = [] 59 | for i in range(self.num_layers): 60 | w = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1], mlp_sizes[i])) 61 | self.weights.append(w) 62 | name = "weight_{}".format(i) 63 | setattr(self, name, w) 64 | if self.bias: 65 | b = torch.nn.Parameter(torch.empty(mlp_sizes[i + 1])) 66 | self.biases.append(b) 67 | name = "bias_{}".format(i) 68 | setattr(self, name, b) 69 | 70 | self.reset_parameters() 71 | 72 | def reset_parameters(self): 73 | for weight in self.weights: 74 | dimsum = weight.size(0) + weight.size(1) 75 | std = math.sqrt(2.0 / float(dimsum)) 76 | nn.init.normal_(weight, 0.0, std) 77 | if self.bias: 78 | for bias in self.biases: 79 | std = math.sqrt(1.0 / float(bias.size(0))) 80 | nn.init.normal_(bias, 0.0, std) 81 | 82 | def forward(self, input): 83 | return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases) 84 | 85 | def extra_repr(self): 86 | s = f"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}" 87 | return s 88 | -------------------------------------------------------------------------------- /apex/contrib/sparsity/permutation_tests/unstructured_study.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$#" -ne 2 ]; then 4 | echo "Please specify both the source directory and a run tag: bash 
unstructured_study.sh " 5 | exit 6 | fi 7 | 8 | dir=$1 # or set to the directory containing .npy files of interest 9 | tag=$2 # or set to an identifier, e.g. "network_name" 10 | 11 | resdir="results/unstructured_logs/${tag}" 12 | mkdir -p $resdir 13 | 14 | CS=channel_swap,0 15 | OSG2=optimize_stripe_groups,8,0 16 | OSG2_100=optimize_stripe_groups,8,100 17 | OSG2_1000=optimize_stripe_groups,8,1000 18 | OSG3=optimize_stripe_groups,12,0 19 | 20 | CS_successes=() 21 | OSG2_successes=() 22 | OSG2_100_successes=() 23 | OSG2_1000_successes=() 24 | OSG3_successes=() 25 | 26 | for sparsity in {50..100}; do 27 | CS_successes+=(0) 28 | OSG2_successes+=(0) 29 | OSG2_100_successes+=(0) 30 | OSG2_1000_successes+=(0) 31 | OSG3_successes+=(0) 32 | done 33 | 34 | update_successes () { 35 | strategy=$1 36 | local -n _successes=$2 37 | logfile=$3 38 | 39 | limit=$(grep "${strategy}," $logfile | awk -F "," '{print $3}') 40 | 41 | echo $logfile, $strategy, $limit 42 | for (( sparsity=$limit; sparsity<=100; sparsity++ )); do 43 | let "entry = $sparsity - 50" 44 | let "value = ${_successes[$entry]} + 1" 45 | _successes[$entry]=$value 46 | done 47 | } 48 | 49 | # Figure 4 50 | for filename in $dir/*.npy; do 51 | out=$(basename -- "$filename") 52 | echo "Searching for minimum sparsities for $out" 53 | out=$resdir/$out.unstructured 54 | python3 permutation_test.py --infile=$filename --pretty_print=False --unstructured=-1 $CS $OSG2 $OSG2_100 $OSG2_1000 $OSG3 > $out 55 | 56 | update_successes "channel_swap_0" CS_successes "$out" 57 | update_successes "optimize_stripe_groups_8_0" OSG2_successes "$out" 58 | update_successes "optimize_stripe_groups_8_100" OSG2_100_successes "$out" 59 | update_successes "optimize_stripe_groups_8_1000" OSG2_1000_successes "$out" 60 | update_successes "optimize_stripe_groups_12_0" OSG3_successes "$out" 61 | done 62 | 63 | #################### save the table 64 | # log a single strategy in as a row in the table 65 | log_success () { 66 | strategy=$1 67 | local -n 
_successes=$2 68 | OUTFILE=$3 69 | 70 | printf "$strategy," >> $OUTFILE 71 | for sparsity in {50..100}; do 72 | let "entry = $sparsity - 50" 73 | printf "%d," ${_successes[$entry]} >> $OUTFILE 74 | done 75 | printf "\n" >> $OUTFILE 76 | } 77 | 78 | # prepare the header 79 | OUTFILE="results/unstructured.csv" 80 | printf "Sparsity," > $OUTFILE 81 | for sparsity in {50..100}; do 82 | printf "%d," $sparsity >> $OUTFILE 83 | done 84 | printf "\n" >> $OUTFILE 85 | 86 | # add data for each strategy 87 | log_success "channel_swap_0" CS_successes "$OUTFILE" 88 | log_success "optimize_stripe_groups_8_0" OSG2_successes "$OUTFILE" 89 | log_success "optimize_stripe_groups_8_100" OSG2_100_successes "$OUTFILE" 90 | log_success "optimize_stripe_groups_8_1000" OSG2_1000_successes "$OUTFILE" 91 | log_success "optimize_stripe_groups_12_0" OSG3_successes "$OUTFILE" 92 | 93 | echo "Done! ${OUTFILE}" 94 | -------------------------------------------------------------------------------- /apex/contrib/csrc/group_norm_v2/gn_utils.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "gn.hpp" 10 | 11 | // Definition of CUDA_CHECK macro 12 | #define CUDA_CHECK(call) \ 13 | do { \ 14 | cudaError_t err_ = call; \ 15 | if (err_ != cudaSuccess) { \ 16 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", __FILE__, __LINE__, err_, cudaGetErrorString(err_), \ 17 | #call); \ 18 | exit(EXIT_FAILURE); \ 19 | } \ 20 | } while (0) 21 | 22 | #define GN_CUDA_HOST_PARAMS(T) \ 23 | T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, \ 24 | float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, \ 25 | Meta *meta_ptr, bool meta_only 26 | 27 | #define GN_BWD_CUDA_HOST_PARAMS(T) \ 28 | T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, 
float *mean_var, float eps, \ 29 | bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, \ 30 | int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only 31 | 32 | #define GN_CUDA_HOST_ARGS \ 33 | out, x, w, b, eps, silu, n, hw, num_groups, channels_per_group, mean_var_out, red_buffer, barrier, sm_margin, \ 34 | stream, device_id, meta_ptr, meta_only 35 | 36 | #define GN_BWD_CUDA_HOST_ARGS \ 37 | grad_input, grad_weight, grad_bias, grad_output, x, w, b, mean_var, eps, silu, n, hw, num_groups, \ 38 | channels_per_group, red_buffer, barrier, sm_margin, stream, device_id, meta_ptr, meta_only 39 | 40 | namespace group_norm_v2 { 41 | 42 | cudaDeviceProp const& get_device_prop(int device_id); 43 | 44 | #ifdef __CUDA_ARCH__ 45 | 46 | template 47 | __host__ __device__ inline int print_rank_0(char const* fmt, Ts&&... args) { 48 | if (threadIdx.x + threadIdx.y + threadIdx.z == 0 && blockIdx.x + blockIdx.y + blockIdx.z == 0) { 49 | return printf(fmt, std::forward(args)...); 50 | } 51 | return 0; 52 | } 53 | 54 | #endif 55 | 56 | } // namespace group_norm_v2 57 | -------------------------------------------------------------------------------- /apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | SKIP_TEST = None 6 | try: 7 | from apex.contrib.multihead_attn import SelfMultiheadAttn 8 | except ImportError as e: 9 | SKIP_TEST = e 10 | 11 | 12 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 13 | class SelfMultiheadAttnNormAddTest(unittest.TestCase): 14 | def setUp(self, seed=1234): 15 | torch.manual_seed(seed) 16 | 17 | self.seq_length = 80 18 | self.sequences = 10 19 | self.hidden_dim = 1024 20 | self.heads = 16 21 | self.dropout_prob = 0.0 22 | 23 | self.ref_layer = SelfMultiheadAttn( 24 | self.hidden_dim, 25 | self.heads, 26 | dropout=self.dropout_prob, 
27 | bias=False, 28 | include_norm_add=True, 29 | impl="default", 30 | ) 31 | self.ref_layer.cuda().half() 32 | self.ref_layer.reset_parameters() 33 | self.ref_inputs = torch.randn( 34 | self.seq_length, 35 | self.sequences, 36 | self.hidden_dim, 37 | dtype=torch.float16, 38 | device=torch.device("cuda"), 39 | ).requires_grad_(True) 40 | 41 | # Reset seed so parameters are identical 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | 45 | self.tst_layer = SelfMultiheadAttn( 46 | self.hidden_dim, 47 | self.heads, 48 | dropout=self.dropout_prob, 49 | bias=False, 50 | include_norm_add=True, 51 | impl="fast", 52 | ) 53 | self.tst_layer.cuda().half() 54 | self.tst_layer.reset_parameters() 55 | 56 | self.tst_inputs = torch.randn( 57 | self.seq_length, 58 | self.sequences, 59 | self.hidden_dim, 60 | dtype=torch.float16, 61 | device=torch.device("cuda"), 62 | ).requires_grad_(True) 63 | 64 | def test_self_multihead_attn_norm_add(self): 65 | grads = torch.randn_like(self.tst_inputs) 66 | 67 | for _ in range(0, 5): 68 | ref_outputs, _ = self.ref_layer.forward( 69 | self.ref_inputs, 70 | self.ref_inputs, 71 | self.ref_inputs, 72 | key_padding_mask=None, 73 | need_weights=False, 74 | attn_mask=None, 75 | is_training=True, 76 | ) 77 | 78 | tst_outputs, _ = self.tst_layer.forward( 79 | self.tst_inputs, 80 | self.tst_inputs, 81 | self.tst_inputs, 82 | key_padding_mask=None, 83 | need_weights=False, 84 | attn_mask=None, 85 | is_training=True, 86 | ) 87 | 88 | self.ref_inputs.backward(grads) 89 | self.tst_inputs.backward(grads) 90 | 91 | torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5) 92 | torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3) 93 | torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3) 94 | 95 | 96 | if __name__ == "__main__": 97 | unittest.main() 98 | -------------------------------------------------------------------------------- 
/apex/contrib/multihead_attn/mask_softmax_dropout_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import fast_multihead_attn 4 | 5 | 6 | class MaskSoftmaxDropout(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob): 9 | heads_t = torch.tensor([heads]) 10 | dropout_prob_t = torch.tensor([dropout_prob]) 11 | null_tensor = torch.tensor([]) 12 | use_mask = pad_mask is not None 13 | use_mask_t = torch.tensor([use_mask]) 14 | mask_additive_t = torch.tensor([mask_additive]) 15 | 16 | if mask_additive: 17 | dropout_results, dropout_mask, softmax_results = ( 18 | fast_multihead_attn.additive_mask_softmax_dropout_forward( 19 | use_mask, 20 | is_training, 21 | heads, 22 | inputs, 23 | pad_mask if use_mask else null_tensor, 24 | dropout_prob, 25 | ) 26 | ) 27 | # fast_additive_mask_softmax_dropout.forward( \ 28 | else: 29 | dropout_results, dropout_mask, softmax_results = ( 30 | fast_multihead_attn.mask_softmax_dropout_forward( 31 | use_mask, 32 | is_training, 33 | heads, 34 | inputs, 35 | pad_mask if use_mask else null_tensor, 36 | dropout_prob, 37 | ) 38 | ) 39 | # fast_mask_softmax_dropout.forward( \ 40 | 41 | ctx.save_for_backward( # heads, dropout_prob and the two flags were wrapped in 1-element tensors above so they can be saved here 42 | use_mask_t, 43 | heads_t, 44 | softmax_results, 45 | dropout_mask, 46 | pad_mask if use_mask else null_tensor, 47 | mask_additive_t, 48 | dropout_prob_t, 49 | ) 50 | 51 | return dropout_results.detach() 52 | 53 | @staticmethod 54 | def backward(ctx, output_grads): 55 | ( 56 | use_mask_t, 57 | heads_t, 58 | softmax_results, 59 | dropout_mask, 60 | pad_mask, 61 | mask_additive_t, 62 | dropout_prob_t, 63 | ) = ctx.saved_tensors 64 | 65 | if mask_additive_t[0]: 66 | input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward( 67 | use_mask_t[0], 68 | heads_t[0], 69 | output_grads, 70 | softmax_results, 71 | dropout_mask, 72 | dropout_prob_t[0], 73 | ) 74 | #
fast_additive_mask_softmax_dropout.backward( \ 75 | else: 76 | input_grads = fast_multihead_attn.mask_softmax_dropout_backward( 77 | use_mask_t[0], 78 | heads_t[0], 79 | output_grads, 80 | softmax_results, 81 | dropout_mask, 82 | pad_mask, 83 | dropout_prob_t[0], 84 | ) 85 | # fast_mask_softmax_dropout.backward( \ 86 | return None, None, input_grads, None, None, None # only `inputs` gets a gradient; all other forward arguments are non-differentiable 87 | 88 | 89 | fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply 90 | -------------------------------------------------------------------------------- /csrc/multi_tensor_adagrad.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | // Another possibility: 6 | // #include 7 | 8 | #include 9 | 10 | #include "multi_tensor_apply.cuh" 11 | #include "type_shim.h" 12 | 13 | #define BLOCK_SIZE 1024 14 | #define ILP 4 15 | 16 | typedef enum { 17 | ADAGRAD_MODE_0 = 0, // L2 regularization mode. 18 | ADAGRAD_MODE_1 = 1, // AdamW-style weight decay.
19 | 20 | } adagradMode_t; 21 | 22 | using MATH_T = float; 23 | 24 | template 25 | struct AdagradFunctor { 26 | __device__ __forceinline__ void operator()(int chunk_size, volatile int* noop_gmem, TensorListMetadata<3>& tl, 27 | const float epsilon, const float lr, adagradMode_t mode, 28 | const float weight_decay) { 29 | int tensor_loc = tl.block_to_tensor[blockIdx.x]; 30 | int chunk_idx = tl.block_to_chunk[blockIdx.x]; 31 | int n = tl.sizes[tensor_loc]; 32 | 33 | T* g = (T*)tl.addresses[0][tensor_loc]; 34 | g += chunk_idx * chunk_size; 35 | 36 | T* p = (T*)tl.addresses[1][tensor_loc]; 37 | p += chunk_idx * chunk_size; 38 | 39 | T* h = (T*)tl.addresses[2][tensor_loc]; 40 | h += chunk_idx * chunk_size; 41 | 42 | n -= chunk_idx * chunk_size; 43 | 44 | // see note in multi_tensor_scale_kernel.cu 45 | for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { // each thread handles ILP elements per iteration, spaced blockDim.x apart 46 | MATH_T r_g[ILP]; 47 | MATH_T r_p[ILP]; 48 | MATH_T r_h[ILP]; 49 | #pragma unroll 50 | for (int ii = 0; ii < ILP; ii++) { 51 | int i = i_start + threadIdx.x + ii * blockDim.x; 52 | if (i < n && i < chunk_size) { 53 | r_g[ii] = g[i]; 54 | r_p[ii] = p[i]; 55 | r_h[ii] = h[i]; 56 | } else { 57 | r_g[ii] = MATH_T(0); 58 | r_p[ii] = MATH_T(0); 59 | r_h[ii] = MATH_T(0); 60 | } 61 | } 62 | #pragma unroll 63 | for (int ii = 0; ii < ILP; ii++) { 64 | if (mode == ADAGRAD_MODE_0) { // L2: weight decay is folded into the gradient before it is accumulated into h 65 | r_g[ii] = r_g[ii] + weight_decay * r_p[ii]; 66 | r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii]; 67 | r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon)); 68 | } else { // AdamW-style: weight decay is applied directly in the parameter update, not through h 69 | r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii]; 70 | r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon) + weight_decay * r_p[ii]); 71 | } 72 | } 73 | #pragma unroll 74 | for (int ii = 0; ii < ILP; ii++) { 75 | int i = i_start + threadIdx.x + ii * blockDim.x; 76 | if (i < n && i < chunk_size) { 77 | p[i] = r_p[ii]; 78 | h[i] = r_h[ii]; 79 | } 80 | } 81 | } 82 | } 83 | }; 84 | 85 | void
multi_tensor_adagrad_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, 86 | const float lr, const float epsilon, const int mode, const float weight_decay) { 87 | using namespace at; 88 | 89 | // Assume a single dtype across p, g, h: dispatch uses the first tensor's type. Note the functor takes (epsilon, lr), the reverse of this function's (lr, epsilon) order. 90 | DISPATCH_DOUBLE_FLOAT_AND_HALF( 91 | tensor_lists[0][0].scalar_type(), 0, "adagrad", 92 | multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdagradFunctor(), epsilon, lr, 93 | (adagradMode_t)mode, weight_decay);) 94 | 95 | AT_CUDA_CHECK(cudaGetLastError()); 96 | } 97 | -------------------------------------------------------------------------------- /apex/contrib/csrc/fmha/src/fmha/mask.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED.
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | namespace fmha { 31 | 32 | template 33 | struct Mask { 34 | using Mma_tile = fmha::Hmma_tile; 35 | 36 | template 37 | __device__ Mask(const Params& params, const BInfo& blockInfo, int tidx) { 38 | actual_seqlen = blockInfo.actual_seqlen; 39 | 40 | const int warp = tidx / Cta_tile::THREADS_PER_WARP; 41 | const int lane = tidx % Cta_tile::THREADS_PER_WARP; 42 | 43 | static_assert(Cta_tile::WARPS_K == 1, ""); 44 | 45 | // find the warp in the Cta tile 46 | const int warp_n = (warp / Cta_tile::WARPS_M); 47 | const int warp_m = (warp % Cta_tile::WARPS_M); 48 | // decompose warp into 8x4 tile 49 | const int quad = lane / 4; 50 | const int tid = (lane % 4) * 2; 51 | row = warp_m * 16 + quad; 52 | col = warp_n * 16 + tid; 53 | } 54 | 55 | inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const { 56 | // ii and jj iterate over the 2x4 fragment; only the column index is checked against actual_seqlen (the row check below is commented out, so mi and ii are unused) 57 | const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen; 58 | //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen; 59 | return col_valid; 60 | // return row_valid && col_valid; 61 | } 62 | 63 | // BERT Mask: if upper left is invalid, none are valid 64 | inline __device__ bool any_valid(int mi, int ni) const { return is_valid(mi, ni, 0, 0); } 65 | 66 | inline __device__ void
load(int it) { row_offset = it * Cta_tile::M + row; } 67 | int row_offset; 68 | 69 | int row; 70 | int col; 71 | int actual_seqlen; 72 | }; 73 | 74 | } // namespace fmha 75 | -------------------------------------------------------------------------------- /csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include "scaled_upper_triang_masked_softmax.h" 26 | #include "type_shim.h" 27 | 28 | namespace multihead_attn { 29 | namespace fused_softmax { 30 | namespace scaled_upper_triang_masked_softmax { 31 | 32 | torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { 33 | // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]; only size(0) and size(1) are read, so the square shape is assumed here rather than asserted 34 | const int attn_batches = input.size(0); 35 | const int seq_len = input.size(1); 36 | TORCH_INTERNAL_ASSERT(seq_len <= 16384); 37 | 38 | // Output 39 | auto act_options = input.options().requires_grad(false); 40 | torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); 41 | 42 | // Raw pointers handed to the kernel. NOTE(review): unlike bwd_cuda, input is not made contiguous here 43 | void* input_ptr = static_cast(input.data_ptr()); 44 | void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); 45 | 46 | DISPATCH_HALF_AND_BFLOAT( 47 | input.scalar_type(), "dispatch_scaled_upper_triang_masked_softmax_forward", 48 | dispatch_scaled_upper_triang_masked_softmax_forward( 49 | reinterpret_cast(softmax_results_ptr), reinterpret_cast(input_ptr), scale_factor, 50 | seq_len, seq_len, attn_batches);); 51 | return softmax_results; 52 | } 53 | 54 | torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { 55 | auto output_grads = output_grads_.contiguous(); 56 | auto softmax_results = softmax_results_.contiguous(); 57 | 58 | // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] 59 | const int attn_batches = output_grads.size(0); 60 | const int seq_len = output_grads.size(1); 61 | TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); 62 | 63 | void* output_grads_ptr = static_cast(output_grads.data_ptr()); 64 | 65 | // Softmax Grad 66 | DISPATCH_HALF_AND_BFLOAT( 67 | output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward", 68 | dispatch_scaled_upper_triang_masked_softmax_backward( 69 | reinterpret_cast(output_grads_ptr), reinterpret_cast(output_grads_ptr), 70 | reinterpret_cast(softmax_results.data_ptr()), scale_factor, seq_len, seq_len, 71 | attn_batches);); 72 | 73 | // backward pass is completely in-place: the result overwrites output_grads, which shares storage with output_grads_ when that tensor was already contiguous 74 | return output_grads; 75 | } 76 | } // namespace scaled_upper_triang_masked_softmax 77 | } // namespace fused_softmax 78 | } // namespace multihead_attn 79 | -------------------------------------------------------------------------------- /apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | SKIP_TEST = None 6 | try: 7 | from apex.contrib.multihead_attn import SelfMultiheadAttn 8 | except ImportError as e: 9 | SKIP_TEST = e 10 | 11 | 12 | @unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") 13 | class SelfMultiheadAttnTest(unittest.TestCase): 14 | def setUp(self, seed=1234): 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | self.seq_length = 80 19 | self.sequences = 10 20 | self.hidden_dim = 1024 21 | self.heads = 16 22 | self.dropout_prob = 0.0 23 | 24 | self.ref_layer = SelfMultiheadAttn( 25 | self.hidden_dim, 26 | self.heads, 27 | dropout=self.dropout_prob, 28 | bias=True, 29 | include_norm_add=False, 30 | separate_qkv_params=True, 31 | mask_additive=True, 32 | impl="default", 33 | ) 34 | self.ref_layer.cuda().half() 35 | self.ref_layer.reset_parameters() 36 | self.ref_inputs = torch.randn( 37 | self.seq_length, 38 | self.sequences, 39 | self.hidden_dim, 40 | dtype=torch.float16, 41 | device=torch.device("cuda"), 42 | ).requires_grad_(True) 43 | # Reset seed so parameters are identical 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | 47 | self.tst_layer = SelfMultiheadAttn( 48 | self.hidden_dim, 49 | self.heads, 50 |
dropout=self.dropout_prob, 51 | bias=True, 52 | include_norm_add=False, 53 | separate_qkv_params=True, 54 | mask_additive=True, 55 | impl="fast", 56 | ) 57 | self.tst_layer.cuda().half() 58 | self.tst_layer.reset_parameters() 59 | 60 | self.tst_inputs = torch.randn( 61 | self.seq_length, 62 | self.sequences, 63 | self.hidden_dim, 64 | dtype=torch.float16, 65 | device=torch.device("cuda"), 66 | ).requires_grad_(True) 67 | 68 | def test_self_multihead_attn_additive_mask(self): 69 | grads = torch.randn_like(self.tst_inputs) 70 | mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda() 71 | 72 | ref_outputs, _ = self.ref_layer.forward( 73 | self.ref_inputs, 74 | self.ref_inputs, 75 | self.ref_inputs, 76 | key_padding_mask=mask, 77 | need_weights=False, 78 | attn_mask=None, 79 | is_training=True, 80 | ) 81 | 82 | tst_outputs, _ = self.tst_layer.forward( 83 | self.tst_inputs, 84 | self.tst_inputs, 85 | self.tst_inputs, 86 | key_padding_mask=mask, 87 | need_weights=False, 88 | attn_mask=None, 89 | is_training=True, 90 | ) 91 | 92 | self.ref_inputs.backward(grads) 93 | self.tst_inputs.backward(grads) 94 | 95 | torch.testing.assert_close(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5) 96 | torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3) 97 | torch.testing.assert_close(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3) 98 | 99 | 100 | if __name__ == "__main__": 101 | unittest.main() 102 | --------------------------------------------------------------------------------